Secure your code as it's written. Use Snyk Code to scan source code in minutes - no build needed - and fix issues immediately.
def rand_default(scale=3):
    """Build a ``_rand_dtype`` sampler backed by a freshly seeded ``randn``.

    A new ``npr.RandomState(0)`` is created on every call, so every sampler
    returned by this factory starts from the same seed.

    Args:
      scale: forwarded to ``_rand_dtype`` as the ``scale`` keyword.

    Returns:
      ``functools.partial`` of ``_rand_dtype`` with the generator's ``randn``
      bound as the first positional argument.
    """
    seeded_rng = npr.RandomState(0)
    return partial(_rand_dtype, seeded_rng.randn, scale=scale)
# NOTE(review): orphaned tail of a function whose `def` line is not part of
# this listing; `xs` and `ys` are presumably that function's parameters.
# Maps `_assert_numpy_allclose` over corresponding pytree leaves and folds
# the per-leaf results with `tree_all`.
tree_all(tree_multimap(_assert_numpy_allclose, xs, ys))
def check_close(xs, ys, atol=None, rtol=None):
    """Compare two pytrees leaf by leaf with ``_assert_numpy_close``.

    Every pair of corresponding leaves is passed to ``_assert_numpy_close``
    together with the ``atol`` / ``rtol`` tolerances (presumably it raises on
    a mismatch), and the per-leaf results are folded with ``tree_all``.

    Args:
      xs, ys: pytrees to compare.
      atol, rtol: tolerances forwarded unchanged; ``None`` means "use the
        helper's own default".
    """
    def _leaf_close(x, y):
        return _assert_numpy_close(x, y, atol=atol, rtol=rtol)

    tree_all(tree_multimap(_leaf_close, xs, ys))
def inner_prod(xs, ys):
    """Return the real part of the conjugate-linear inner product of two pytrees.

    Each pair of corresponding leaves ``(x, y)`` contributes
    ``real(dot(ravel(conj(x)), ravel(y)))``; the per-leaf values are summed
    with ``tree_reduce(onp.add, ...)``.

    Args:
      xs: pytree whose leaves are conjugated before contraction.
      ys: pytree expected to have the same structure as ``xs``.

    Returns:
      The real-valued sum of the per-leaf contractions.
    """
    def contract(x, y):
        # Flatten both sides with onp.ravel instead of the ``.reshape``
        # method: the original called ``y.reshape`` directly, which raised
        # AttributeError for Python-scalar / list leaves in ``ys`` (``x`` was
        # already safe because onp.conj converts it). For ndarrays the
        # flattened values are identical to ``reshape(-1)``.
        return onp.real(onp.dot(onp.ravel(onp.conj(x)), onp.ravel(y)))

    return tree_reduce(onp.add, tree_multimap(contract, xs, ys))
# Leafwise pytree arithmetic. Each operation is computed with
# ``dtype=_dtype(...)`` of the (first) operand leaf.
add = partial(
    tree_multimap,
    lambda lhs, rhs: onp.add(lhs, rhs, dtype=_dtype(lhs)),
)
sub = partial(
    tree_multimap,
    lambda lhs, rhs: onp.subtract(lhs, rhs, dtype=_dtype(lhs)),
)
conj = partial(
    tree_map,
    lambda leaf: onp.conj(leaf, dtype=_dtype(leaf)),
)
def scalar_mul(xs, a):
    """Multiply every leaf of pytree ``xs`` by ``a``.

    Each product is computed by ``onp.multiply`` with ``dtype=_dtype(leaf)``,
    so the leaf's dtype (as reported by ``_dtype``) drives the computation.
    """
    def _scale_leaf(leaf):
        return onp.multiply(leaf, a, dtype=_dtype(leaf))

    return tree_map(_scale_leaf, xs)
def rand_like(rng, x):
    """Draw ``rng.randn`` samples with the shape of ``x`` and dtype ``_dtype(x)``.

    For complex dtypes two independent draws are made, in order: the first
    becomes the real part and the second the imaginary part.

    Args:
      rng: object providing ``randn(*shape)`` (presumably a
        ``numpy.random.RandomState``).
      x: array-like whose shape and dtype are mirrored.
    """
    target_shape = onp.shape(x)
    target_dtype = _dtype(x)

    def draw():
        return onp.asarray(rng.randn(*target_shape), dtype=target_dtype)

    if not dtypes.issubdtype(target_dtype, onp.complexfloating):
        return draw()

    real_part = draw()
    imag_part = draw()
    return real_part + target_dtype.type(1.0j) * imag_part
def check_close(xs, ys, atol=None, rtol=None):
    """Compare two pytrees leaf by leaf with ``_assert_numpy_close``.

    NOTE(review): an identical ``check_close`` appears earlier in this
    listing; if both live in the same module, this later definition is the
    one that ends up bound.
    """
    compare_leaf = partial(_assert_numpy_close, rtol=rtol, atol=atol)
    per_leaf = tree_multimap(compare_leaf, xs, ys)
    tree_all(per_leaf)
def defvectorized(prim):
    """Register ``vectorized_batcher`` (bound to ``prim``) as ``prim``'s batcher."""
    batcher = partial(vectorized_batcher, prim)
    primitive_batchers[prim] = batcher
# Pad `array` by `pad_width` along every dimension.
# Under jit, static_argnums=(1, 2) makes `pad_width` and `mode` static
# (trace-time) arguments, so `mode == "constant"` below is a Python-level
# branch; presumably callers must pass them in hashable form (e.g. tuples).
# NOTE(review): only the "constant" mode branch is visible here and the body
# ends without a `return`; the rest of the function is outside this listing.
@partial(jit, static_argnums=(1, 2))
def _pad(array, pad_width, mode, constant_values):
array = asarray(array)
nd = ndim(array)
# Normalize pad_width into an (nd, 2) table of (before, after) widths per axis.
pad_width = onp.broadcast_to(onp.asarray(pad_width), (nd, 2))
# NOTE(review): `pad_width < 0` is 2-D here; if `any` were Python's builtin,
# truth-testing each length-2 row would raise ValueError. Presumably `any`
# is a module-level array reduction in the full file — confirm.
if any(pad_width < 0):
raise ValueError("index can't contain negative values")
if mode == "constant":
# One (before, after) fill value per axis, cast to the array's dtype.
constant_values = broadcast_to(asarray(constant_values), (nd, 2))
constant_values = lax.convert_element_type(constant_values, array.dtype)
# Pad one axis at a time, leading edge then trailing edge, each with its own
# fill value. The triples are presumably lax.pad's (low, high, interior).
# NOTE(review): `xrange` is the Python 2 name; presumably imported from a
# compatibility shim (e.g. six.moves) in the full file.
for i in xrange(nd):
widths = [(0, 0, 0)] * nd
widths[i] = (pad_width[i, 0], 0, 0)
array = lax.pad(array, constant_values[i, 0], widths)
widths[i] = (0, pad_width[i, 1], 0)
array = lax.pad(array, constant_values[i, 1], widths)
# Builds PixelCNN++ layer constructors from partially applied building blocks.
# NOTE(review): truncated excerpt — the body stops inside `ResnetDownBlock`
# and no `return` of PixelCNNPP is visible in this listing.
def PixelCNNPP(nr_resnet=5, nr_filters=160, nr_logistic_mix=10, dropout_p=.5):
# Gated residual block with dropout fixed to `dropout_p`, specialised below
# by the shifted-convolution type it uses.
Resnet = partial(GatedResnet, dropout_p=dropout_p)
ResnetDown = partial(Resnet, Conv=DownShiftedConv)
ResnetDownRight = partial(Resnet, Conv=DownRightShiftedConv)
# Shifted convolutions with `nr_filters` output channels.
ConvDown = partial(DownShiftedConv, out_chan=nr_filters)
ConvDownRight = partial(DownRightShiftedConv, out_chan=nr_filters)
# strides=(2, 2) variants: HalveDown* reuse the shifted convs, DoubleDown*
# use the transposed shifted convs.
HalveDown = partial(ConvDown, strides=(2, 2))
HalveDownRight = partial(ConvDownRight, strides=(2, 2))
DoubleDown = partial(DownShiftedConvTranspose, out_chan=nr_filters, strides=(2, 2))
DoubleDownRight = partial(DownRightShiftedConvTranspose, out_chan=nr_filters, strides=(2, 2))
# NOTE(review): ResnetUpBlock reads `nr_resnet` from the enclosing scope,
# whereas ResnetDownBlock below takes it as a parameter.
def ResnetUpBlock():
@parametrized
def resnet_up_block(us, uls):
# Apply `nr_resnet` residual layers, appending each output to the `us` and
# `uls` stacks. NOTE(review): this appends to the caller's lists in place,
# while resnet_down_block copies `us` first — confirm this is intended.
for _ in range(nr_resnet):
us.append(ResnetDown()(us[-1]))
uls.append(ResnetDownRight()(uls[-1], us[-1]))
return us, uls
return resnet_up_block
def ResnetDownBlock(nr_resnet):
@parametrized
def resnet_down_block(u, ul, us, uls):
# Work on a copy of `us`, so later changes to it do not touch the caller's
# list (the rest of this function is not in this listing).
us = us.copy()
# NOTE(review): orphaned tail of a batching rule whose `def` line is not in
# this listing (presumably `_reduce_window_max_batch_rule`, registered below).
# `bdim`, `operand`, `window_dimensions`, `window_strides` and `padding` come
# from that missing header.
if bdim is not None:
# Insert a size-1 window with stride 1 at the batch axis, so each window
# covers a single batch index and the reduction never mixes batch elements.
window_dimensions = \
window_dimensions[:bdim] + (1,) + window_dimensions[bdim:]
window_strides = window_strides[:bdim] + (1,) + window_strides[bdim:]
operand = _reduce_window_max(
operand, window_dimensions, window_strides, padding)
# NOTE(review): the output batch dim is reported as 0 regardless of `bdim`.
# The window layout above keeps the batch axis at position `bdim`, and when
# `bdim` is None the result is not batched at all — this looks wrong unless
# bdim is always 0; confirm (likely should be `return operand, bdim`).
return operand, 0
# reduce_window_max: the translation rule is the shared "chooser" rule
# specialised with max_p and its identity value.
_reduce_window_max_translation_rule = partial(
_reduce_window_chooser_translation_rule, max_p, _get_max_identity)
# Primitive using the common reduce-window shape rule; `_input_dtype` is the
# dtype rule (presumably: output dtype equals the operand dtype).
reduce_window_max_p = standard_primitive(
_common_reduce_window_shape_rule, _input_dtype, 'reduce_window_max',
_reduce_window_max_translation_rule)
# Register the JVP (shared chooser JVP with max_p) and the batching rule.
ad.defjvp(reduce_window_max_p, partial(_reduce_window_chooser_jvp_rule, max_p))
batching.primitive_batchers[reduce_window_max_p] = _reduce_window_max_batch_rule
# reduce_window_min: the translation rule is the shared "chooser" rule
# specialised with min_p and its identity value.
_reduce_window_min_translation_rule = partial(
_reduce_window_chooser_translation_rule, min_p, _get_min_identity)
# Primitive using the common reduce-window shape rule and `_input_dtype`.
reduce_window_min_p = standard_primitive(
_common_reduce_window_shape_rule, _input_dtype, 'reduce_window_min',
_reduce_window_min_translation_rule)
# NOTE(review): only the JVP is registered in these lines; unlike
# reduce_window_max, no batching rule is registered for reduce_window_min_p
# here — confirm it is registered elsewhere.
ad.defjvp(reduce_window_min_p, partial(_reduce_window_chooser_jvp_rule, min_p))
# Shape rule for select_and_scatter: validates that window_dimensions and
# window_strides are shape-like and have matching lengths.
# NOTE(review): truncated — the `if` below has no body in this listing and the
# rest of the function (including its return value) is missing.
def _select_and_scatter_shape_rule(
operand, source, init_value, select_jaxpr, select_consts, scatter_jaxpr,
scatter_consts, window_dimensions, window_strides, padding):
_check_shapelike("select_and_scatter", "window_dimensions", window_dimensions)
_check_shapelike("select_and_scatter", "window_strides", window_strides)
if len(window_dimensions) != len(window_strides):
# Shared implementation of a cumulative reduction of `a` along one axis.
# Under jit, static_argnums=(1, 2) makes `axis` and `dtype` static.
@partial(jit, static_argnums=(1, 2))
def _cumulative_reduction(a, axis, dtype):
# axis=None (or a scalar input): flatten `a` and reduce along axis 0.
if axis is None or isscalar(a):
a = ravel(a)
axis = 0
a_shape = list(shape(a))
num_dims = len(a_shape)
# Normalize a negative axis, then bounds-check it.
if axis < 0:
axis = axis + num_dims
if axis < 0 or axis >= num_dims:
raise ValueError(
"axis {} is out of bounds for array of dimension {}".format(
axis, num_dims))
# NOTE(review): truncated — `squash_nan` is not a parameter here (presumably
# bound by an enclosing scope), and this `if` body and the rest of the
# function are not in this listing.
if squash_nan:
# NOTE(review): orphaned tail of a while-loop batching rule (presumably
# `_while_loop_batching_rule`, registered below). `dims`, `consts`, `init`,
# `size`, `init_bat`, `carry_bat`, `cond_nconsts`, `body_nconsts` and the
# batched jaxprs come from the part of the function not in this listing.
# Split incoming batch dims into those of the cond/body consts and those of
# the initial carry values.
const_dims, init_dims = split_list(dims, [cond_nconsts + body_nconsts])
# Move each mapped const's batch axis to position 0; unmapped consts pass through.
new_consts = [batching.moveaxis(x, d, 0) if d is not batching.not_mapped and d != 0
else x for x, d in zip(consts, const_dims)]
# Carries: broadcast to a new leading batch axis if they must be batched but
# were not batched on entry; otherwise move the existing batch axis to 0 if
# batched; leave never-batched carries unchanged.
new_init = [batching.broadcast(x, size, 0) if now_bat and not was_bat
else batching.moveaxis(x, d, 0) if now_bat else x
for x, d, was_bat, now_bat in zip(init, init_dims, init_bat, carry_bat)]
# Re-bind the while primitive with the batched cond/body jaxprs.
outs = while_p.bind(*(new_consts + new_init),
cond_nconsts=cond_nconsts, cond_jaxpr=cond_jaxpr_batched,
body_nconsts=body_nconsts, body_jaxpr=body_jaxpr_batched)
# Batched carries come out with batch axis 0; the others are unmapped.
out_bdims = [0 if b else batching.not_mapped for b in carry_bat]
return outs, out_bdims
# Register the `while` primitive. It returns multiple results, evaluates
# eagerly via xla.apply_primitive, and gets an abstract-eval rule, an
# initial-style XLA translation rule, and a batching rule.
while_p = lax.Primitive('while')
while_p.multiple_results = True
while_p.def_impl(partial(xla.apply_primitive, while_p))
while_p.def_abstract_eval(_while_loop_abstract_eval)
xla.initial_style_translations[while_p] = _while_loop_translation_rule
batching.primitive_batchers[while_p] = _while_loop_batching_rule
### cond
# NOTE(review): truncated — the docstring below is never closed in this
# listing (its reference implementation stops after `else:`), and the body
# of `cond` is missing.
def cond(pred, true_operand, true_fun, false_operand, false_fun):
"""Conditionally apply ``true_fun`` or ``false_fun``.
Has equivalent semantics to this Python implementation::
def cond(pred, true_operand, true_fun, false_operand, false_fun):
if pred:
return true_fun(true_operand)
else:
def PmapPrimitive(name):
    """Create a ``Primitive`` called ``name`` with placeholder rules.

    The concrete implementation is ``partial(unbound_name_error, name)``
    (presumably raising when the primitive is evaluated outside a context
    that binds ``name``). The default abstract-eval rule returns its first
    argument unchanged.
    """
    def _passthrough_abstract_eval(x, *args, **kwargs):
        return x  # default: the output abstract value is the first input's

    primitive = Primitive(name)
    primitive.def_impl(partial(unbound_name_error, name))
    primitive.def_abstract_eval(_passthrough_abstract_eval)
    return primitive