Secure your code as it's written. Use Snyk Code to scan source code in minutes - no build needed - and fix issues immediately.
self.model = model
self.X = X
self.Y = Y
def predict(self, model, X):
    """Return the predictions of ``model`` for the inputs ``X``.

    Override this for models that lack a ``predict`` method, e.g. a torch
    module: ``model(torch.from_numpy(X)).detach().numpy()``.
    """
    predict_fn = model.predict
    return predict_fn(X)
def draw(self, *args, **kwargs):
    """Plot the ground-truth points and the model's predictions over ``self.X``.

    Positional and keyword arguments are accepted for call compatibility
    but are not used.
    """
    xs = self.X
    plt.plot(xs, self.Y, 'r.', label="Ground truth")
    predicted = self.predict(self.model, xs)
    plt.plot(xs, predicted, '-', label="Model")
    plt.title("Prediction")
    plt.legend(loc='lower right')
class Plot2d(BaseSubplot):
def __init__(self, model, X, Y, valiation_data=(None, None), h=0.02, margin=0.25):
super().__init__()
self.model = model
self.X = X
self.Y = Y
self.X_test, self.Y_test = valiation_data
# add size assertions
self.cm_bg = plt.cm.RdBu
self.cm_points = ListedColormap(['#FF0000', '#0000FF'])
h = .02 # step size in the mesh
x_min = X[:, 0].min() - margin
x_max = X[:, 0].max() + margin
serie_metric_name = serie_fmt.format(self.metric)
serie_metric_logs = [(log.get('_i', i + 1), log[serie_metric_name])
for i, log in enumerate(logs[skip:])
if serie_metric_name in log]
if len(serie_metric_logs) > 0:
xs, ys = zip(*serie_metric_logs)
plt.plot(xs, ys, label=serie_label)
plt.title(self.title)
plt.xlabel('epoch')
plt.legend(loc='center right')
class Plot1D(BaseSubplot):
    """Subplot comparing a model's predictions with ground truth.

    ``X`` is plotted on the horizontal axis against both the targets ``Y``
    (as red dots) and the output of :meth:`predict` (as a line).
    """

    def __init__(self, model, X, Y):
        """Store the model and the data to plot.

        Args:
            model: Object whose predictions are plotted; by default it must
                expose ``predict(X)`` (see :meth:`predict`).
            X: Inputs, used as the horizontal-axis values.
            Y: Ground-truth targets matching ``X``.
        """
        # BaseSubplot.__init__ accepts only ``self``, so it must be called
        # without arguments (passing ``self`` again raises TypeError).
        super().__init__()
        self.model = model
        self.X = X
        self.Y = Y

    def predict(self, model, X):
        """Return ``model``'s predictions for ``X``.

        Override this for models without a ``predict`` method, e.g. a torch
        module: ``model(torch.from_numpy(X)).detach().numpy()``.
        """
        return model.predict(X)

    def draw(self, *args, **kwargs):
        """Plot ground truth and model predictions on the current axes.

        Arguments are accepted so the subplot can be invoked through
        ``BaseSubplot.__call__`` with any arguments; they are not used.
        """
        plt.plot(self.X, self.Y, 'r.', label="Ground truth")
        plt.plot(self.X, self.predict(self.model, self.X), '-', label="Model")
        plt.title("Prediction")
        plt.legend(loc='lower right')
import matplotlib.pyplot as plt
from matplotlib.colors import ListedColormap
class BaseSubplot:
    """Base class for a drawable subplot.

    Subclasses implement :meth:`draw`; calling an instance forwards all
    arguments to ``draw``.
    """

    def __init__(self):
        pass

    def draw(self, *args, **kwargs):
        """Draw the subplot; must be implemented by subclasses.

        Accepts the same arguments that :meth:`__call__` forwards, so an
        unimplemented subclass fails with NotImplementedError rather than a
        misleading TypeError about the argument count.

        Raises:
            NotImplementedError: always, in this base class.
        """
        # NotImplementedError subclasses Exception, so existing
        # ``except Exception`` handlers still catch it.
        raise NotImplementedError(
            "{} must implement draw()".format(type(self).__name__))

    def __call__(self, *args, **kwargs):
        """Draw the subplot by forwarding all arguments to :meth:`draw`."""
        self.draw(*args, **kwargs)
class LossSubplot(BaseSubplot):
def __init__(self,
             metric,
             title="",
             series_fmt=None,
             skip_first=2,
             max_epoch=None):
    """Configure a subplot for the logged history of one metric.

    Args:
        metric: Name of the metric to plot. It is substituted into each
            format string of ``series_fmt`` to build the log key.
        title: Title shown above the subplot.
        series_fmt: Mapping from series label to a log-key format string.
            ``None`` (the default) means
            ``{'training': '{}', 'validation': 'val_{}'}``, built freshly
            for each instance.
        skip_first: Stored as-is; presumably the number of leading log
            entries skipped when plotting -- TODO confirm against draw().
        max_epoch: Stored as-is; not used in the visible code, presumably
            an x-axis limit -- TODO confirm.
    """
    # BaseSubplot.__init__ accepts only ``self``, so it must be called
    # without arguments (passing ``self`` again raises TypeError).
    super().__init__()
    self.metric = metric
    self.title = title
    # Build the default per call: a dict literal in the signature would be
    # a single mutable object shared by every instance.
    if series_fmt is None:
        series_fmt = {'training': '{}', 'validation': 'val_{}'}
    self.series_fmt = series_fmt
    self.skip_first = skip_first
    self.max_epoch = max_epoch
def _how_many_to_skip(self, log_length, skip_first):
if log_length < skip_first:
return 0