How to use the poutyne.framework.callbacks.ModelCheckpoint class in Poutyne

To help you get started, we’ve selected a few Poutyne examples, based on popular ways it is used in public projects.

Secure your code as it's written. Use Snyk Code to scan source code in minutes - no build needed - and fix issues immediately.

github GRAAL-Research / poutyne / tests / framework / callbacks / test_checkpoint.py View on Github external
def test_temporary_filename_arg_with_differing_checkpoint_filename(self):
        """Train with an ``{epoch}``-templated checkpoint name and a fixed temporary name.

        After training, asserts that the temporary file does not exist and
        that one checkpoint file exists for every epoch from 1 to ``epochs``.
        """
        n_epochs = 10
        work_dir = self.temp_dir_obj.name
        tmp_path = os.path.join(work_dir, 'my_checkpoint.tmp.ckpt')
        # The '{epoch}' placeholder is filled in below when checking the files.
        ckpt_template = os.path.join(work_dir, 'my_checkpoint_{epoch}.ckpt')

        training_data = some_data_generator(ModelCheckpointTest.batch_size)
        validation_data = some_data_generator(ModelCheckpointTest.batch_size)
        callback = ModelCheckpoint(
            ckpt_template,
            monitor='val_loss',
            verbose=True,
            period=1,
            temporary_filename=tmp_path,
        )
        self.model.fit_generator(
            training_data,
            validation_data,
            epochs=n_epochs,
            steps_per_epoch=5,
            callbacks=[callback],
        )

        self.assertFalse(os.path.isfile(tmp_path))
        for epoch in range(1, n_epochs + 1):
            expected_file = ckpt_template.format(epoch=epoch)
            self.assertTrue(os.path.isfile(expected_file))
github GRAAL-Research / poutyne / tests / framework / callbacks / test_checkpoint.py View on Github external
def test_periodic_with_period_of_2(self):
        """Check the checkpoint pattern for ``period=2`` with ``save_best_only=False``.

        Feeds ten identical validation losses to the shared helper and expects
        the checkpoint flags to alternate, starting with ``False``.
        """
        n_epochs = 10
        callback = ModelCheckpoint(
            self.checkpoint_filename,
            monitor='val_loss',
            verbose=True,
            period=2,
            save_best_only=False,
        )

        constant_losses = [1 for _ in range(n_epochs)]
        # False, True, False, True, ... across the ten entries.
        expected_flags = [position % 2 == 1 for position in range(n_epochs)]
        self._test_checkpointer_with_val_losses(callback, constant_losses, expected_flags)
github GRAAL-Research / poutyne / tests / framework / callbacks / test_checkpoint.py View on Github external
def test_restore_best_without_save_best_only(self):
        """Check that ``restore_best=True`` raises ``ValueError`` without ``save_best_only=True``.

        Two constructions are exercised: one passing ``save_best_only=False``
        explicitly and one leaving ``save_best_only`` unspecified.
        """
        shared_kwargs = dict(monitor='val_loss', verbose=True, restore_best=True)
        # First case sets save_best_only=False; second case omits it entirely.
        for overrides in ({'save_best_only': False}, {}):
            with self.assertRaises(ValueError):
                ModelCheckpoint(self.checkpoint_filename, **shared_kwargs, **overrides)
github GRAAL-Research / poutyne / tests / framework / callbacks / test_checkpoint.py View on Github external
def test_integration(self):
        """Run ``fit_generator`` for ten epochs with a best-only ``ModelCheckpoint`` attached.

        The test makes no explicit assertion; it passes when training
        completes without raising.
        """
        training_data = some_data_generator(ModelCheckpointTest.batch_size)
        validation_data = some_data_generator(ModelCheckpointTest.batch_size)
        callback = ModelCheckpoint(
            self.checkpoint_filename,
            monitor='val_loss',
            verbose=True,
            save_best_only=True,
        )
        self.model.fit_generator(
            training_data,
            validation_data,
            epochs=10,
            steps_per_epoch=5,
            callbacks=[callback],
        )
github GRAAL-Research / poutyne / tests / framework / callbacks / test_checkpoint.py View on Github external
def test_periodic_with_period_of_1(self):
        """Check the checkpoint pattern for ``period=1`` with ``save_best_only=False``.

        Feeds ten identical validation losses to the shared helper and expects
        a checkpoint flag of ``True`` for every one of them.
        """
        n_epochs = 10
        callback = ModelCheckpoint(
            self.checkpoint_filename,
            monitor='val_loss',
            verbose=True,
            period=1,
            save_best_only=False,
        )

        constant_losses = [1 for _ in range(n_epochs)]
        expected_flags = [True for _ in range(n_epochs)]
        self._test_checkpointer_with_val_losses(callback, constant_losses, expected_flags)
github GRAAL-Research / poutyne / tests / framework / callbacks / test_checkpoint.py View on Github external
def test_save_best_only(self):
        """Check which epochs are expected to checkpoint when ``save_best_only=True``.

        Each pair below is (validation loss, checkpoint expected); the flag is
        ``True`` exactly where the loss is lower than every earlier loss.
        """
        callback = ModelCheckpoint(
            self.checkpoint_filename,
            monitor='val_loss',
            verbose=True,
            save_best_only=True,
        )

        loss_and_flag = [
            (10, True),
            (3, True),
            (8, False),
            (5, False),
            (2, True),
        ]
        # Split the pairs back into the two parallel lists the helper takes.
        losses, expected_flags = (list(column) for column in zip(*loss_and_flag))
        self._test_checkpointer_with_val_losses(callback, losses, expected_flags)
github GRAAL-Research / poutyne / tests / framework / callbacks / test_checkpoint.py View on Github external
def test_non_atomic_write(self):
        """Train with ``atomic_write=False`` and assert the checkpoint file exists afterwards."""
        ckpt_path = os.path.join(self.temp_dir_obj.name, 'my_checkpoint.ckpt')

        training_data = some_data_generator(ModelCheckpointTest.batch_size)
        validation_data = some_data_generator(ModelCheckpointTest.batch_size)
        callback = ModelCheckpoint(
            ckpt_path,
            monitor='val_loss',
            verbose=True,
            period=1,
            atomic_write=False,
        )
        self.model.fit_generator(
            training_data,
            validation_data,
            epochs=10,
            steps_per_epoch=5,
            callbacks=[callback],
        )

        checkpoint_written = os.path.isfile(ckpt_path)
        self.assertTrue(checkpoint_written)
github GRAAL-Research / poutyne / tests / framework / callbacks / test_checkpoint.py View on Github external
def test_temporary_filename_arg(self):
        """Train with an explicit ``temporary_filename`` and check the resulting files.

        After training, asserts that the temporary file does not exist and
        that the final checkpoint file does.
        """
        work_dir = self.temp_dir_obj.name
        tmp_path = os.path.join(work_dir, 'my_checkpoint.tmp.ckpt')
        ckpt_path = os.path.join(work_dir, 'my_checkpoint.ckpt')

        training_data = some_data_generator(ModelCheckpointTest.batch_size)
        validation_data = some_data_generator(ModelCheckpointTest.batch_size)
        callback = ModelCheckpoint(
            ckpt_path,
            monitor='val_loss',
            verbose=True,
            period=1,
            temporary_filename=tmp_path,
        )
        self.model.fit_generator(
            training_data,
            validation_data,
            epochs=10,
            steps_per_epoch=5,
            callbacks=[callback],
        )

        self.assertFalse(os.path.isfile(tmp_path))
        self.assertTrue(os.path.isfile(ckpt_path))
github GRAAL-Research / poutyne / poutyne / framework / experiment.py View on Github external
def _init_model_restoring_callbacks(self, initial_epoch, save_every_epoch):
        callbacks = []
        best_checkpoint = ModelCheckpoint(self.best_checkpoint_filename,
                                          monitor=self.monitor_metric,
                                          mode=self.monitor_mode,
                                          save_best_only=not save_every_epoch,
                                          restore_best=not save_every_epoch,
                                          verbose=not save_every_epoch,
                                          temporary_filename=self.best_checkpoint_tmp_filename)
        callbacks.append(best_checkpoint)

        if save_every_epoch:
            best_restore = BestModelRestore(monitor=self.monitor_metric, mode=self.monitor_mode, verbose=True)
            callbacks.append(best_restore)

        if initial_epoch > 1:
            # We set the current best metric score in the ModelCheckpoint so that
            # it does not save checkpoint it would not have saved if the
            # optimization was not stopped.