Secure your code as it's written. Use Snyk Code to scan source code in minutes - no build needed - and fix issues immediately.
def _get_writer(self, save_to, writers_args) -> \
        delve.writers.AbstractWriter:
    """Create the writer(s) used to log saturation history.

    :param save_to: one of
        - an ``AbstractWriter`` instance, returned unchanged;
        - a list (or tuple) of such specifications, each resolved
          recursively and wrapped together in a ``CompositWriter``;
        - a string naming an attribute of the ``delve`` package, which is
          called with ``writers_args`` as keyword arguments.
    :param writers_args: keyword arguments forwarded to every writer created
        from a string; ``None`` is treated as "no extra arguments".
    :return: the resulting writer
    :raises ValueError: if ``save_to`` matches none of the forms above
    """
    if isinstance(save_to, delve.writers.AbstractWriter):
        return save_to
    if writers_args is None:
        writers_args = {}
    if isinstance(save_to, (list, tuple)):
        return CompositWriter([
            self._get_writer(save_to=saver, writers_args=writers_args)
            for saver in save_to
        ])
    # hasattr() raises TypeError for non-string names (e.g. None); report
    # those the same way as unknown strings instead of leaking a TypeError.
    if not isinstance(save_to, str) or not hasattr(delve, save_to):
        raise ValueError('Illegal argument for save_to "{}"'.
                         format(save_to))
    # NOTE(review): any attribute of ``delve`` is accepted here, not only
    # writer classes; consider restricting this to AbstractWriter subclasses.
    return getattr(delve, save_to)(**writers_args)
w.add_scalar(name, value, **kwargs)
def add_scalars(self, prefix, value_dict, **kwargs):
    """Forward a group of named scalars to every wrapped writer.

    :param prefix: group prefix, passed through unchanged
    :param value_dict: mapping of names to values, passed through unchanged
    :param kwargs: extra keyword arguments forwarded to each writer
    """
    for child in self.writers:
        child.add_scalars(prefix, value_dict, **kwargs)
def save(self):
    """Ask each wrapped writer, in list order, to persist its data."""
    for child in self.writers:
        child.save()
def close(self):
    """Close every wrapped writer.

    Each writer's ``close()`` is called even if an earlier one raises, so a
    single failing writer cannot leave the others' resources open. After all
    writers have been visited, the first exception encountered is re-raised.
    """
    first_error = None
    for w in self.writers:
        try:
            w.close()
        except Exception as err:
            # Remember the failure but keep closing the remaining writers.
            if first_error is None:
                first_error = err
    if first_error is not None:
        raise first_error
class CSVWriter(AbstractWriter):
def __init__(self, savepath: str, **kwargs):
    """
    Writer that keeps every saturation value and dumps them to one CSV file.

    Each call to save() rewrites the file with all values collected so far.

    :param savepath: path of the CSV file without the ``.csv`` extension
        (save() appends it)
    :param kwargs: ignored; lets callers pass one shared set of keyword
        arguments to different writers
    """
    super(CSVWriter, self).__init__()
    self.savepath = savepath
    # stat name -> list of recorded values
    self.value_dict = {}
def resume_from_saved_state(self, initial_epoch: int):
    """Reload previously saved values so a resumed run keeps appending.

    Stores ``initial_epoch`` in ``self.epoch_counter``. If
    ``self._check_savestate_ok`` accepts ``<savepath>.csv`` (presumably it
    verifies that the saved file is present and usable — confirm in the base
    class), the ';'-separated CSV is read with its first column as index and
    loaded into ``self.value_dict`` as ``{column: [values...]}``.

    :param initial_epoch: epoch from which training resumes
    """
    self.epoch_counter = initial_epoch
    csv_path = self.savepath + '.csv'
    if not self._check_savestate_ok(csv_path):
        return
    frame = pd.read_csv(csv_path, sep=';', index_col=0)
    self.value_dict = frame.to_dict(orient='list')
raise NotImplementedError()
@abstractmethod
def add_scalars(self, prefix, value_dict, global_step, **kwargs):
    """Record several named scalars at once; concrete writers must override.

    :raises NotImplementedError: always, in this base implementation
    """
    raise NotImplementedError
@abstractmethod
def save(self):
    """Persist buffered data; the base implementation does nothing."""
@abstractmethod
def close(self):
    """Release writer resources; the base implementation does nothing."""
class CompositWriter(AbstractWriter):
def __init__(self, writers: List[AbstractWriter]):
    """
    Writer that fans every call out to several other writers.

    :param writers: the wrapped writers; each one is invoked whenever this
        CompositWriter is invoked.
    """
    super(CompositWriter, self).__init__()
    self.writers = writers
def resume_from_saved_state(self, initial_epoch: int):
    """Ask every wrapped writer to resume from its saved state.

    A writer that raises ``NotImplementedError`` is skipped with a warning
    rather than aborting the resume for the remaining writers.

    :param initial_epoch: epoch from which training resumes
    """
    for w in self.writers:
        try:
            w.resume_from_saved_state(initial_epoch)
        except NotImplementedError:
            warnings.warn(
                f'Writer {w.__class__.__name__} raised a NotImplementedError '
                'when attempting to resume training. '
                'This may result in corrupted or overwritten data.',
                stacklevel=2,
            )
def save(self):
    """Persist the epoch counters and optionally zip the output.

    Pickles ``self.epoch_counter`` to ``<savepath>/epoch_counter.pkl``.
    If ``self.zip`` is truthy, also builds a zip archive with
    ``shutil.make_archive``, rooted at the parent directory of ``savepath``.
    """
    counter_path = os.path.join(self.savepath, 'epoch_counter.pkl')
    # A context manager guarantees the handle is closed (and the pickle is
    # flushed to disk) even if pickling raises.
    with open(counter_path, 'wb') as fh:
        pkl.dump(self.epoch_counter, fh)
    if self.zip:
        # NOTE(review): base_name is only the basename, so the archive is
        # created relative to the current working directory; and with no
        # base_dir, everything under dirname(savepath) is archived, not just
        # savepath itself — confirm this is intended.
        make_archive(
            base_name=os.path.basename(self.savepath),
            format='zip',
            root_dir=os.path.dirname(self.savepath),
            verbose=True
        )
def close(self):
    """No resources to release; intentionally a no-op."""
class PrintWriter(AbstractWriter):
def __init__(self, **kwargs):
    """
    Writer that prints every scalar to standard output.

    :param kwargs: ignored
    """
    super(PrintWriter, self).__init__()
def resume_from_saved_state(self, initial_epoch: int):
    """Nothing to restore: printed output keeps no state."""
def add_scalar(self, name, value, **kwargs):
    """Print one line of the form ``<name> : <value>``; kwargs are ignored."""
    print(name, value, sep=' : ')
def add_scalars(self, prefix, value_dict, **kwargs):
    """Forward each entry to add_scalar under the name ``<prefix>_<key>``.

    Extra keyword arguments are not forwarded.
    """
    for key in value_dict:
        full_name = prefix + '_' + key
        self.add_scalar(full_name, value_dict[key])
else:
self.value_dict[name] = [value]
return
def add_scalars(self, prefix, value_dict, **kwargs):
    """Record every entry of value_dict through add_scalar.

    Keys are used as names verbatim; ``prefix`` and ``kwargs`` are unused.
    """
    # NOTE(review): the prefix is ignored here, so keys become column names
    # as-is — confirm this is intended.
    for name, value in value_dict.items():
        self.add_scalar(name, value)
def save(self):
    """Write all collected values to ``<savepath>.csv`` using ';' as separator."""
    frame = pd.DataFrame.from_dict(self.value_dict)
    frame.to_csv(self.savepath + '.csv', sep=';')
def close(self):
    """Nothing is held open between saves, so closing does nothing."""
class NPYWriter(AbstractWriter):
def __init__(self, savepath: str, zip: bool = False, **kwargs):
    """
    Writer that stores non-scalar statistics as .npy files.

    It creates a folder with one subfolder per stat; each subfolder holds an
    npy-file with the saturation value for each epoch. Because values need
    not be scalars, this writer can also save the covariance matrix.

    :param savepath: root folder the folder structure is written to
    :param zip: whether to zip the output folder after every invocation
    :param kwargs: ignored
    """
    super(NPYWriter, self).__init__()
    self.zip = zip
    self.epoch_counter = {}
    self.savepath = savepath
def resume_from_saved_state(self, initial_epoch: int):
def add_scalar(self, name, value, **kwargs):
    """Emit ``<name> : <value>`` on standard output (kwargs unused)."""
    print(name, value, sep=' : ')
def add_scalars(self, prefix, value_dict, **kwargs):
    """Pass each item to add_scalar, naming it ``<prefix>_<key>``.

    Extra keyword arguments are not forwarded.
    """
    for key, val in value_dict.items():
        self.add_scalar(prefix + '_' + key, val)
def save(self):
    """Nothing is buffered, so there is nothing to persist."""
def close(self):
    """No resources are held; closing is a no-op."""
class TensorBoardWriter(AbstractWriter):
def __init__(self, savepath: str, **kwargs):
    """
    Writer that logs values through a TensorBoard SummaryWriter.

    :param savepath: location passed to the SummaryWriter for its logs
    :param kwargs: ignored
    """
    super(TensorBoardWriter, self).__init__()
    self.savepath = savepath
    self.writer = SummaryWriter(self.savepath)
def resume_from_saved_state(self, initial_epoch: int):
    """Resuming is unsupported for this writer.

    :raises NotImplementedError: always
    """
    message = 'Resuming is not yet implemented for TensorBoardWriter'
    raise NotImplementedError(message)
def add_scalar(self, name, value, **kwargs):
if 'covariance' in name:
return