Secure your code as it's written. Use Snyk Code to scan source code in minutes—no build needed—and fix issues immediately.
def loop(self):
    """
    Run the training loop for ``self.__epochs`` epochs.

    Every epoch trains (``self._train``), tests (``self._test``), logs the
    model parameters and writes the accumulated log output to the console.
    A new console line is started every ``self.__log_new_line_interval``
    epochs, and a checkpoint is saved after each epoch when
    ``self.__is_save_models`` is set.

    NOTE(review): the indentation of this snippet was lost upstream; the
    checkpoint save is placed inside the loop (once per epoch) — confirm.
    """
    # Loop through the monitored iterator
    for epoch in logger.loop(range(self.__epochs)):
        self._train()
        self._test()
        self.__log_model_params()

        # Clear line and output to console
        logger.write()

        # Clear the line and move to the next one, so the current line is
        # kept in the output. This happens only every
        # `__log_new_line_interval` epochs, not every epoch; `epoch + 1`
        # because the range starts at 0. An interval of 0 would raise
        # ZeroDivisionError here.
        if (epoch + 1) % self.__log_new_line_interval == 0:
            logger.new_line()

        if self.__is_save_models:
            logger.save_checkpoint()
def loop(self):
    """
    Run the training loop for ``self.__epochs`` epochs.

    Every epoch trains (``self._train``), tests (``self._test``), logs the
    model parameters and writes the accumulated log output to the console.
    A new console line is started every ``self.__log_new_line_interval``
    epochs, and a checkpoint is saved after each epoch when
    ``self.__is_save_models`` is set.

    NOTE(review): the indentation of this snippet was lost upstream; the
    checkpoint save is placed inside the loop (once per epoch) — confirm.
    """
    # Loop through the monitored iterator
    for epoch in logger.loop(range(self.__epochs)):
        self._train()
        self._test()
        self.__log_model_params()

        # Clear line and output to console
        logger.write()

        # Clear the line and move to the next one, so the current line is
        # kept in the output. This happens only every
        # `__log_new_line_interval` epochs, not every epoch; `epoch + 1`
        # because the range starts at 0. An interval of 0 would raise
        # ZeroDivisionError here.
        if (epoch + 1) % self.__log_new_line_interval == 0:
            logger.new_line()

        if self.__is_save_models:
            logger.save_checkpoint()
# Per-epoch body of a training loop (snippet).
# NOTE(review): this is a fragment — the enclosing `for epoch in ...:` loop
# and the `try:` that the `except KeyboardInterrupt:` below belongs to are
# not shown, and the original indentation has been lost.
# Training and testing
train(args, model, device, train_loader, optimizer, epoch)
test(model, device, test_loader)
# Add histograms with model parameter values and gradients
# NOTE(review): presumably PyTorch parameters; `param.grad` can be None for a
# parameter that received no gradient (e.g. unused in the forward pass), and
# then `param.grad.cpu()` raises AttributeError — consider guarding.
for name, param in model.named_parameters():
if param.requires_grad:
logger.store(name, param.data.cpu().numpy())
logger.store(f"{name}_grad", param.grad.cpu().numpy())
# Clear line and output to console
logger.write()
# Output the progress summaries to `trial.yaml` and
# to the python file header
logger.save_progress()
# Clear line and go to the next line;
# that is, we add a new line to the output
# at the end of each epoch
logger.new_line()
# Handle the delayed keyboard interrupt: finish the loop, report it,
# and leave the epoch loop.
except KeyboardInterrupt:
logger.finish_loop()
logger.new_line()
logger.log("\nKilling loop...")
break
# Per-epoch body of the training loop (snippet).
# NOTE(review): fragment — the enclosing `for epoch in ...:` header that
# defines `epoch` is not shown, and the original indentation has been lost.
self._train()
self._test()
self.__log_model_params()
# Clear line and output to console
logger.write()
# Clear line and go to the next line;
# that is, we keep the current line in the output.
# This happens every `__log_new_line_interval` epochs, not every epoch
# (`epoch + 1` presumably because epochs are counted from 0 — confirm
# against the loop header).
if (epoch + 1) % self.__log_new_line_interval == 0:
logger.new_line()
# Save a checkpoint when model saving is enabled.
if self.__is_save_models:
logger.save_checkpoint()
def __next__(self):
    """
    Advance the training loop by one epoch and return that epoch.

    If ``self.__signal_received`` is set (presumably by a SIGINT handler
    elsewhere — confirm), logs a "Killing Loop." message, finishes the
    logger loop and stops iterating. Otherwise pulls the next epoch from
    ``self.__loop``, writes the log and starts a new console line at the
    configured intervals, and returns the epoch.

    Raises:
        StopIteration: when a signal was received (message ``"SIGINT"``), or
            when ``self.__loop`` is exhausted. ``self.__finish()`` is called
            before raising in both cases.
    """
    if self.__signal_received is not None:
        logger.log('\nKilling Loop.',
                   color=colors.Color.red)
        logger.finish_loop()
        self.__finish()
        raise StopIteration("SIGINT")

    try:
        epoch = next(self.__loop)
    except StopIteration:
        # Finish the loop, then propagate the end of iteration unchanged.
        self.__finish()
        raise

    if self.__is_interval(epoch, self.__log_write_interval):
        logger.write()
    if self.__is_interval(epoch, self.__log_new_line_interval):
        logger.new_line()

    # The iterator protocol requires __next__ to return the next item;
    # without this every `for epoch in ...` consumer would receive None.
    return epoch
# Snippet: create a new run, record the git state of the code, and attach
# the checkpoint saver and the selected log writers.
# NOTE(review): fragment of a method body; `python_file`, `comment` and
# `writers` are presumably parameters of the enclosing method (not shown),
# and the original indentation has been lost.
self.run = Run.create(
experiment_path=self.experiment_path,
python_file=python_file,
trial_time=time.localtime(),
comment=comment)
# Record the HEAD commit, its message, whether the tree is dirty, and the
# diff, so the run can be traced back to the code that produced it.
# NOTE(review): `git.Repo(path)` without `search_parent_directories=True`
# presumably requires `self.lab.path` to be the repository root, and raises
# otherwise — confirm this is intended.
# NOTE(review): `repo.git.diff()` runs plain `git diff`, which shows only
# unstaged changes, while `is_dirty()` also counts staged changes; consider
# `repo.git.diff('HEAD')` if the full uncommitted diff is wanted.
repo = git.Repo(self.lab.path)
self.run.commit = repo.head.commit.hexsha
self.run.commit_message = repo.head.commit.message.strip()
self.run.is_dirty = repo.is_dirty()
self.run.diff = repo.git.diff()
checkpoint_saver = self._create_checkpoint_saver()
logger.internal().set_checkpoint_saver(checkpoint_saver)
# When the caller did not choose writers, enable both sqlite and tensorboard.
if writers is None:
writers = {'sqlite', 'tensorboard'}
if 'sqlite' in writers:
logger.internal().add_writer(sqlite.Writer(self.run.sqlite_path))
if 'tensorboard' in writers:
logger.internal().add_writer(tensorboard.Writer(self.run.tensorboard_log_path))
def create_writer(self, session: tf_compat.Session):
    """
    ## Create TensorFlow summary writer

    Opens a TensorFlow ``FileWriter`` on ``self.info.summary_path`` (with
    ``session.graph`` attached), wraps it in ``tensorboard_writer.Writer``
    and registers the result with ``logger``.
    """
    summary_dir = str(self.info.summary_path)
    file_writer = tf_compat.summary.FileWriter(summary_dir, session.graph)
    wrapped = tensorboard_writer.Writer(file_writer)
    logger.add_writer(wrapped)