How to use the poutyne.utils._concat function in Poutyne

To help you get started, we’ve selected a few Poutyne examples that show popular ways this function is used in public projects.

Secure your code as it's written. Use Snyk Code to scan source code in minutes — no build needed — and fix issues immediately.

github GRAAL-Research / poutyne / tests / test_utils.py View on Github external
def test_tuple_2(self):
        """
        Test the concatenation of a [([], ([], []))]

        Each of the 5 elements is a tuple ``(a, (b, c))`` with ``a = arange(5)``,
        ``b = 2 * ones(5)`` and ``c = 3 * ones(5)``. The result is expected to keep
        the nested tuple layout, with each leaf array concatenated to length 25.
        """
        obj = [(np.arange(5), (np.ones(5) * 2, np.ones(5) * 3))] * 5
        concat = _concat(obj)
        self.assertEqual(concat[0].shape, (25, ))
        self.assertEqual(concat[1][0].shape, (25, ))
        self.assertEqual(concat[1][1].shape, (25, ))
        # Whole-array comparisons report the mismatching indices/values on failure,
        # unlike assertTrue(a == b), which only reports "False is not true".
        # The first leaf is arange(5) repeated 5 times: 0..4, 0..4, ...
        np.testing.assert_array_equal(concat[0], np.tile(np.arange(5), 5))
        # A scalar expected value is compared against every element of the array.
        np.testing.assert_array_equal(concat[1][0], 2)
        np.testing.assert_array_equal(concat[1][1], 3)
github GRAAL-Research / poutyne / tests / framework / test_model.py View on Github external
else:
                self.assertEqual(type(obj), np.ndarray)
                self.assertEqual(obj.shape, (cur_batch_size, 1))

        for pred in pred_y:
            if remaning_example < ModelTest.batch_size:
                cur_batch_size = remaning_example
                remaning_example = 0
            else:
                remaning_example -= ModelTest.batch_size
            self._test_size_and_type_for_generator(pred, (cur_batch_size, 1))
        if multi_output:
            for pred in _concat(pred_y):
                self.assertEqual(pred.shape, (ModelTest.evaluate_dataset_len, 1))
        else:
            self.assertEqual(_concat(pred_y).shape, (ModelTest.evaluate_dataset_len, 1))
github GRAAL-Research / poutyne / tests / framework / test_model.py View on Github external
def test_predict_generator_multi_output(self):
        """
        Test ``predict_generator`` on the multi-output model.

        Each predicted batch is checked against ``(batch_size, 1)``. After the
        batches are merged with ``_concat``, every output must span all
        ``num_steps`` batches.
        """
        num_steps = 10
        generator = some_data_tensor_generator_multi_output(ModelTest.batch_size)
        # Only the inputs are fed to ``predict_generator``; the targets are dropped.
        generator = (x for x, _ in generator)
        pred_y = self.multi_output_model.predict_generator(generator, steps=num_steps)

        for pred in pred_y:
            self._test_size_and_type_for_generator(pred, (ModelTest.batch_size, 1))
        for pred in _concat(pred_y):
            self.assertEqual(pred.shape, (num_steps * ModelTest.batch_size, 1))
github GRAAL-Research / poutyne / tests / framework / test_model.py View on Github external
def test_predict_generator_multi_io(self):
        """
        Check ``predict_generator`` on the multi-input/multi-output model.

        Every returned batch must match ``(batch_size, 1)``, and each output
        merged by ``_concat`` must cover all ``num_steps`` batches.
        """
        num_steps = 10
        batch_size = ModelTest.batch_size
        # Keep the inputs and discard the targets of each generated pair.
        inputs_only = (x for x, _ in some_data_tensor_generator_multi_io(batch_size))
        predictions = self.multi_io_model.predict_generator(inputs_only, steps=num_steps)

        for batch_pred in predictions:
            self._test_size_and_type_for_generator(batch_pred, (batch_size, 1))

        expected_shape = (num_steps * batch_size, 1)
        for merged in _concat(predictions):
            self.assertEqual(merged.shape, expected_shape)
github GRAAL-Research / poutyne / tests / test_utils.py View on Github external
def test_tuple_1(self):
        """
        Test the concatenation of a [([], [])]

        Each of the 5 elements is a tuple ``(arange(5), 2 * ones(5))``. The result
        is expected to be indexable by position, with each array concatenated to
        length 25.
        """
        obj = [(np.arange(5), np.ones(5) * 2)] * 5
        concat = _concat(obj)
        self.assertEqual(concat[0].shape, (25, ))
        self.assertEqual(concat[1].shape, (25, ))
        # Whole-array comparisons report the mismatching indices/values on failure.
        # The first component is arange(5) repeated 5 times: 0..4, 0..4, ...
        np.testing.assert_array_equal(concat[0], np.tile(np.arange(5), 5))
        # A scalar expected value is compared against every element of the array.
        np.testing.assert_array_equal(concat[1], 2)
github GRAAL-Research / poutyne / tests / test_utils.py View on Github external
def test_array_1(self):
        """
        Test the concatenation of a [[[], []]]

        Same as the tuple case, but each of the 5 elements is a list
        ``[arange(5), 2 * ones(5)]``. Each position is expected to be
        concatenated to length 25.
        """
        obj = [[np.arange(5), np.ones(5) * 2]] * 5
        concat = _concat(obj)
        self.assertEqual(concat[0].shape, (25, ))
        self.assertEqual(concat[1].shape, (25, ))
        # Whole-array comparisons report the mismatching indices/values on failure.
        # The first component is arange(5) repeated 5 times: 0..4, 0..4, ...
        np.testing.assert_array_equal(concat[0], np.tile(np.arange(5), 5))
        # A scalar expected value is compared against every element of the array.
        np.testing.assert_array_equal(concat[1], 2)
github GRAAL-Research / poutyne / tests / framework / test_model.py View on Github external
def test_evaluate_generator_multi_output(self):
        """
        Check ``evaluate_generator`` with ``return_pred=True`` on the multi-output model.

        The loss must be exactly a Python ``float``. The metrics must be a NumPy
        array of the two batch metrics followed by the constant epoch metric. The
        predictions must be ``(batch_size, 1)`` batches, which ``_concat`` merges
        into outputs spanning all ``num_steps`` batches.
        """
        num_steps = 10
        batch_size = ModelTest.batch_size
        data = some_data_tensor_generator_multi_output(batch_size)
        loss, metrics, pred_y = self.multi_output_model.evaluate_generator(data, steps=num_steps, return_pred=True)

        # Exact type checks on purpose: a NumPy scalar subclass must not pass as float.
        self.assertEqual(type(loss), float)
        self.assertEqual(type(metrics), np.ndarray)
        expected_metrics = [some_metric_1_value, some_metric_2_value, some_constant_epoch_metric_value]
        self.assertEqual(metrics.tolist(), expected_metrics)

        self._test_size_and_type_for_generator(pred_y, (batch_size, 1))
        expected_shape = (num_steps * batch_size, 1)
        for output in _concat(pred_y):
            self.assertEqual(output.shape, expected_shape)
github GRAAL-Research / poutyne / tests / test_utils.py View on Github external
def test_single_array(self):
        """
        Test the concatenation of a single array

        Five copies of ``arange(5)`` are expected to be concatenated into a single
        array of length 25 whose values repeat 0..4.
        """
        obj = [np.arange(5)] * 5
        concat = _concat(obj)
        self.assertEqual(concat.shape, (25, ))
        # Also check the values, as the tuple/list variants of this test do.
        np.testing.assert_array_equal(concat, np.tile(np.arange(5), 5))
github GRAAL-Research / poutyne / poutyne / framework / model.py View on Github external
``metrics`` is a Numpy array of size ``n``, where ``n`` is the
            number of batch metrics plus the number of epoch metrics if ``n > 1``. If
            ``n == 1``, then ``metrics`` is a float. If ``n == 0``, the ``metrics`` is
            omitted. The first elements of ``metrics`` are the batch metrics and are
            followed by the epoch metrics. See the :func:`~Model.fit_generator()` method
            for examples with batch metrics and epoch metrics.

            If ``return_pred`` is True, ``pred_y`` is the list of the predictions
            of each batch with tensors converted into Numpy arrays. It is otherwise ommited.


        """
        generator = self._dataloader_from_data((x, y), batch_size=batch_size)
        ret = self.evaluate_generator(generator, steps=len(generator), return_pred=return_pred)
        if return_pred:
            ret = (*ret[:-1], _concat(ret[-1]))
        return ret