Secure your code as it's written. Use Snyk Code to scan source code in minutes—no build needed—and fix issues immediately.
n_neighbors=self.n_neighbors, weights=self.weights,
algorithm='brute', metric=_dtw_classic,
metric_params=self.metric_params,
n_jobs=self.n_jobs, **self.kwargs
)
elif self.metric == 'dtw_sakoechiba':
n_timestamps = X.shape[1]
if self.metric_params is None:
region = sakoe_chiba_band(n_timestamps)
else:
if 'window_size' not in self.metric_params.keys():
window_size = 0.1
else:
window_size = self.metric_params['window_size']
region = sakoe_chiba_band(n_timestamps,
window_size=window_size)
self._clf = SklearnKNN(
n_neighbors=self.n_neighbors, weights=self.weights,
algorithm='brute', metric=_dtw_region,
metric_params={'region': region},
n_jobs=self.n_jobs, **self.kwargs
)
elif self.metric == 'dtw_itakura':
n_timestamps = X.shape[1]
if self.metric_params is None:
region = itakura_parallelogram(n_timestamps)
else:
if 'max_slope' not in self.metric_params.keys():
max_slope = 2.
else:
metric_params=self.metric_params,
n_jobs=self.n_jobs, **self.kwargs
)
elif self.metric == 'dtw_classic':
self._clf = SklearnKNN(
n_neighbors=self.n_neighbors, weights=self.weights,
algorithm='brute', metric=_dtw_classic,
metric_params=self.metric_params,
n_jobs=self.n_jobs, **self.kwargs
)
elif self.metric == 'dtw_sakoechiba':
n_timestamps = X.shape[1]
if self.metric_params is None:
region = sakoe_chiba_band(n_timestamps)
else:
if 'window_size' not in self.metric_params.keys():
window_size = 0.1
else:
window_size = self.metric_params['window_size']
region = sakoe_chiba_band(n_timestamps,
window_size=window_size)
self._clf = SklearnKNN(
n_neighbors=self.n_neighbors, weights=self.weights,
algorithm='brute', metric=_dtw_region,
metric_params={'region': region},
n_jobs=self.n_jobs, **self.kwargs
)
elif self.metric == 'dtw_itakura':
n_timestamps = X.shape[1]
# Panel 1 of a 2x2 grid: the matrix for classic DTW (``matrix_classic`` and
# ``dtw_classic`` are computed earlier, outside this view).
# NOTE(review): ``matrix_sakoechiba`` is transposed before ``pcolor`` in panel 2
# below, but ``matrix_classic`` is plotted as-is here — confirm it was built
# with the orientation pcolor expects (C[row, col] with rows along y).
plt.subplot(2, 2, 1)
plt.pcolor(timestamps_1, timestamps_2, matrix_classic,
           edgecolors='k', cmap='Greys')
plt.xlabel('x', fontsize=12)
plt.ylabel('y', fontsize=12)
# Title: method name on the first line, DTW value (2 decimals) on the second.
plt.title("{0}\nDTW(x, y) = {1:.2f}".format('classic', dtw_classic),
          fontsize=14)
# Dynamic Time Warping: sakoechiba
# The same ``window_size`` is passed to ``dtw`` and to ``sakoe_chiba_band``
# so that the band drawn below corresponds to the constraint used for the
# DTW computation (presumably — dtw internals are not visible here).
window_size = 0.1
dtw_sakoechiba, path_sakoechiba = dtw(
    x, y, dist='square', method='sakoechiba',
    options={'window_size': window_size}, return_path=True
)
band = sakoe_chiba_band(n_timestamps_1, n_timestamps_2,
                        window_size=window_size)
# Visualisation grid with one extra row and column (presumably because
# pcolor takes cell corners, so an (n+1)-sized grid shows all n cells).
# Cell values: 0 = outside the band, 0.5 = inside the band, 1 = on the path.
matrix_sakoechiba = np.zeros((n_timestamps_1 + 1, n_timestamps_2 + 1))
for i in range(n_timestamps_1):
    # band[:, i] is unpacked into np.arange, presumably as (start, end),
    # marking cells start..end-1 of row i as inside the band.
    matrix_sakoechiba[i, np.arange(*band[:, i])] = 0.5
# assumes path_sakoechiba is (2, path_length): row indices then column
# indices — TODO confirm against dtw(return_path=True).
matrix_sakoechiba[tuple(path_sakoechiba)] = 1.
# Panel 2: transpose so rows of the plotted array follow y and columns follow x.
plt.subplot(2, 2, 2)
plt.pcolor(timestamps_1, timestamps_2, matrix_sakoechiba.T,
           edgecolors='k', cmap='Greys')
plt.xlabel('x', fontsize=12)
plt.ylabel('y', fontsize=12)
plt.title("{0}\nDTW(x, y) = {1:.2f}".format('sakoechiba', dtw_sakoechiba),
          fontsize=14)
# Dynamic Time Warping: itakura
# NOTE(review): only the slope assignment of the Itakura section is visible
# here; the code that uses ``slope`` is not part of this view.
slope = 1.2
def plot_sakoe_chiba(n_timestamps_1, n_timestamps_2, window_size=0.5, ax=None):
    """Plot the Sakoe-Chiba band.

    Draws, on the current pyplot axes, an image of the cells allowed by the
    Sakoe-Chiba band, the band's lower (blue) and upper (green) boundary
    lines, and a thin black straight line through (0, 0) and
    (n_timestamps_1 - 1, n_timestamps_2 - 1).

    Parameters
    ----------
    n_timestamps_1 : int
        Number of timestamps of the first series; sets the number of columns
        of the plotted mask.
    n_timestamps_2 : int
        Number of timestamps of the second series; sets the number of rows
        of the plotted mask.
    window_size : default = 0.5
        Passed unchanged to ``sakoe_chiba_band`` and
        ``_check_sakoe_chiba_params``.
    ax : default = None
        NOTE(review): not used in the lines visible here; all drawing goes
        through the pyplot state machine (``plt.*``). Confirm whether the
        rest of the function, if any, relies on it.
    """
    # Per-column band limits; presumably shape (2, n_timestamps_1) holding
    # (start, end) row indices for each column — TODO confirm.
    region = sakoe_chiba_band(n_timestamps_1, n_timestamps_2, window_size)
    # Slope and shifts used below to draw the continuous band edges.
    scale, horizontal_shift, vertical_shift = \
        _check_sakoe_chiba_params(n_timestamps_1, n_timestamps_2, window_size)
    # Mask: rows index the second series, columns the first. Rows j..k-1 of
    # column i are set to 1 (inside the band); everything else stays 0.
    mask = np.zeros((n_timestamps_2, n_timestamps_1))
    for i, (j, k) in enumerate(region.T):
        mask[j:k, i] = 1.
    # origin='lower' puts row 0 at the bottom, so the plot reads like a graph.
    plt.imshow(mask, origin='lower', cmap='Wistia', vmin=0, vmax=1)
    # x runs from -1 to sz, one unit past the grid on each side (presumably so
    # the boundary lines cover the whole image).
    sz = max(n_timestamps_1, n_timestamps_2)
    x = np.arange(-1, sz + 1)
    lower_bound = scale * (x - horizontal_shift) - vertical_shift
    upper_bound = scale * (x + horizontal_shift) + vertical_shift
    plt.plot(x, lower_bound, 'b', lw=2)
    plt.plot(x, upper_bound, 'g', lw=2)
    # Straight line through (0, 0) and (n_timestamps_1 - 1, n_timestamps_2 - 1).
    # NOTE(review): with Python ints this raises ZeroDivisionError when
    # n_timestamps_1 == 1; the arange below equals ``x`` and could reuse it.
    diag = (n_timestamps_2 - 1) / (n_timestamps_1 - 1) * np.arange(-1, sz + 1)
    plt.plot(x, diag, 'black', lw=1)