Secure your code as it's written. Use Snyk Code to scan source code in minutes - no build needed - and fix issues immediately.
def add_scalar(self, name, value, **kwargs):
    """Forward one scalar to ``self.writer.add_scalar``.

    Any ``name`` that contains the substring ``'covariance'`` is dropped
    without being forwarded. Extra keyword arguments are accepted but
    ignored.

    :param name: name passed through as the first argument of the writer call
    :param value: value passed through as the second argument
    """
    is_covariance = 'covariance' in name
    if not is_covariance:
        self.writer.add_scalar(name, value)
def add_scalars(self, prefix, value_dict, **kwargs):
    """Forward a group of scalars to ``self.writer.add_scalars``.

    :param prefix: passed through as the first argument of the writer call
    :param value_dict: passed through unchanged as the second argument
    Extra keyword arguments are accepted but ignored; the writer's return
    value is discarded.
    """
    target = self.writer
    target.add_scalars(prefix, value_dict)
def save(self):
    """Intentionally a no-op; always returns ``None``."""
    return None
def close(self):
    """Close the wrapped writer by delegating to ``self.writer.close()``.

    The wrapped writer's return value is discarded.
    """
    underlying = self.writer
    underlying.close()
class CSVandPlottingWriter(CSVWriter):
    def __init__(self, savepath: str, plot_manipulation_func: Callable[[plt.Axes], plt.Axes] = None, **kwargs):
        """
        This writer produces CSV files and plots.

        :param savepath: Path to store plots and CSV files; forwarded as the
            only argument to the parent ``CSVWriter.__init__``.
        :param plot_manipulation_func: A function mapping an axis object to an
            axis object by using pyplot code. May be ``None`` (the default),
            in which case an identity function is stored instead.
        :param kwargs: Optional settings read by this initializer:
            ``primary_metric`` (default ``None``),
            ``fontsize`` (default ``16``),
            ``figsize`` (default ``None``).
            Keyword arguments are not forwarded to the parent initializer.
        """
        super().__init__(savepath)
        # Fall back to an identity function so self.plot_man_func is always callable.
        self.plot_man_func = plot_manipulation_func if plot_manipulation_func is not None else lambda x: x
        self.primary_metric = kwargs.get('primary_metric')
        self.fontsize = kwargs.get('fontsize', 16)
        self.figsize = kwargs.get('figsize')
        self.epoch_counter: int = 0
        self.stats = []
def __init__(self, savepath: str, **kwargs):
    """
    This writer produces a csv file with all saturation values.
    The csv-file is overwritten with
    an updated version every time save() is called.

    :param savepath: CSV file path
    :param kwargs: Accepted for signature compatibility; not used by the
        lines of this initializer shown here.
    """
    # Initialise the parent class (defined elsewhere) with no arguments.
    super(CSVWriter, self).__init__()
    # Starts empty; presumably filled by the add_* methods -- confirm
    # against the rest of the class.
    self.value_dict = {}
    # Destination path for the CSV output.
    self.savepath = savepath