Secure your code as it's written. Use Snyk Code to scan source code in minutes - no build needed - and fix issues immediately.
def get_data(n_images, axes, shape):
    """Build a random data source whose targets are axis-wise means of the inputs.

    A random selection of axis labels is drawn from ``axes`` (between 1 and
    ``len(axes) - 1`` of them, given NumPy ``choice`` semantics for ``rng``),
    together with a random ``keepdims`` flag. Each generated target is the
    mean of a freshly drawn uniform array over the selected axes.

    ``rng``, ``np``, ``axes_dict`` and ``RawData`` are taken from the
    enclosing scope.

    Args:
        n_images: Number of ``(x, y, axes, None)`` tuples the generator yields;
            also passed to ``RawData`` as its second argument.
        axes: Axis labels of the generated arrays; yielded unchanged.
        shape: Passed as ``size`` to ``rng.uniform`` for every image.

    Returns:
        A 3-tuple ``(raw_data, red_axes, keepdims)``: the ``RawData`` object
        wrapping the generator, the joined labels of the reduced axes, and the
        ``keepdims`` flag used for the reduction.
    """
    # Keep the three random draws in this exact order so that a seeded
    # ``rng`` produces the same sequence as before.
    n_reduced = 1 + rng.choice(len(axes) - 1)
    picked_labels = rng.choice(tuple(axes), n_reduced, replace=False)
    red_axes = ''.join(picked_labels)
    keepdims = rng.choice((True, False))

    def _axis_index(label):
        # Translate one axis label into its positional index.
        return axes_dict(axes)[label]

    def _gen():
        for _ in range(n_images):
            x = rng.uniform(size=shape)
            reduce_over = tuple(map(_axis_index, red_axes))
            y = np.mean(x, axis=reduce_over, keepdims=keepdims)
            yield x, y, axes, None

    source = RawData(_gen, n_images, '')
    return source, red_axes, keepdims
def get_data(n_images, axes, shape):
    """Build a random data source whose targets are an affine map of the inputs.

    Every generated input ``x`` is drawn with ``rng.uniform(size=shape)`` and
    paired with the target ``5 + 3*x``.

    ``rng`` and ``RawData`` are taken from the enclosing scope.

    Args:
        n_images: Number of ``(x, y, axes, None)`` tuples the generator yields;
            also passed to ``RawData`` as its second argument.
        axes: Axis labels of the generated arrays; yielded unchanged.
        shape: Passed as ``size`` to ``rng.uniform`` for every image.

    Returns:
        The ``RawData`` object wrapping the generator.
    """
    offset, slope = 5, 3

    def _gen():
        for _ in range(n_images):
            x = rng.uniform(size=shape)
            y = offset + slope * x
            yield x, y, axes, None

    return RawData(_gen, n_images, '')
def _create(img_size, img_axes, patch_size, patch_axes):
    """Create patches from random images, save them, and verify the reload.

    Two random image stacks are turned into patches via ``create_patches``
    (which is asked to write ``save_file``), then read back with
    ``load_training_data`` and compared against the in-memory result.

    ``rng``, ``np``, ``n_images``, ``n_patches_per_image``, ``save_file``,
    ``backend_channels_last``, ``RawData``, ``create_patches``,
    ``load_training_data`` and ``move_image_axes`` are taken from the
    enclosing scope.

    Args:
        img_size: Tuple appended to ``(n_images,)`` to form the stack shape.
        img_axes: Axes argument forwarded to ``RawData.from_arrays``.
        patch_size: Forwarded to ``create_patches``.
        patch_axes: Forwarded to ``create_patches``.

    Raises:
        AssertionError: If any of the round-trip checks fails.
    """
    # Draw source first, then target, to keep the random sequence unchanged.
    stack_shape = (n_images,) + img_size
    U = rng.uniform(size=stack_shape)
    V = rng.uniform(size=stack_shape)

    X, Y, XYaxes = create_patches(
        raw_data=RawData.from_arrays(U, V, img_axes),
        patch_size=patch_size,
        patch_axes=patch_axes,
        n_patches_per_image=n_patches_per_image,
        save_file=save_file,
    )

    loaded = load_training_data(save_file, verbose=True)
    (_X, _Y), val_data, _XYaxes = loaded
    assert val_data is None

    # NOTE(review): ``backend_channels_last`` is used as a value, not called.
    # If it is a function, this always selects index -1 - confirm intent.
    channel_pos = -1 if backend_channels_last else 1
    assert _XYaxes[channel_pos] == 'C'

    # Bring the reloaded arrays into the axis order of the in-memory patches
    # before comparing values.
    _X = move_image_axes(_X, fr=_XYaxes, to=XYaxes)
    _Y = move_image_axes(_Y, fr=_XYaxes, to=XYaxes)
    assert np.allclose(X, _X, atol=1e-6)
    assert np.allclose(Y, _Y, atol=1e-6)
    assert set(XYaxes) == set(_XYaxes)

    # Requesting a validation split must yield validation data.
    with_validation = load_training_data(save_file, validation_split=0.5)
    assert with_validation[2] is not None

    # Limiting the number of images must be reflected in every loaded array.
    limited = load_training_data(save_file, n_images=3)
    assert all(len(x) == 3 for x in limited[0])