Secure your code as it's written. Use Snyk Code to scan source code in minutes - no build needed - and fix issues immediately.
from jax.interpreters import polysimp
from jax import lax
from jax import make_jaxpr
import jax.numpy as np
import jax.linear_util as lu
# Add lax primitives to polysimp's addition / multiplication / linear
# primitive sets before running the simplifier on the demo function below.
polysimp.addition_primitives.add(lax.add_p)
polysimp.multiplication_primitives.add(lax.mul_p)
polysimp.multiplication_primitives.add(lax.dot_p)
polysimp.linear_primitives.add(lax.broadcast_p)
polysimp.linear_primitives.add(lax.convert_element_type_p)
polysimp.linear_primitives.add(lax.reshape_p)


def f(x):
    """Demo polynomial with repeated terms (algebraically 4*x + 5*x**3)."""
    return x + x * x * x + 3 * x + 4 * x * x * x


# `print(expr)` with one parenthesized argument behaves the same on Python 2
# and Python 3, whereas the bare `print expr` statement is a SyntaxError on
# Python 3.
# Show the jaxpr and value of the plain function ...
print(make_jaxpr(f)(2))
print(f(2))
# ... then of the polysimp-transformed function. A fresh wrapped function is
# built for each call, exactly as the original script did.
print(make_jaxpr(polysimp.polysimp(lu.wrap_init(f)).call_wrapped)((2,)))
print(polysimp.polysimp(lu.wrap_init(f)).call_wrapped((2,)))
import numpy as onp
# NOTE(review): this span is a fragment -- the `if` that opens this chain (and
# the original indentation) is not visible here. The assignments below bind
# the `*_numpy` implementations, so they presumably sit in a NumPy-backend
# branch; confirm against the full file.
update_add = update_add_numpy
update_multiply = update_multiply_numpy
at = Index()
solve_tridiagonal = solve_tridiagonal_numpy
scan = scan_numpy
# JAX backend: rebind the same public names to JAX-based implementations.
elif runtime_settings.backend == 'jax':
import jax
import jax.numpy
numpy = jax.numpy
# NOTE(review): jax.ops.index_update / index_add / index were removed from
# later JAX releases in favour of the `x.at[...]` API -- confirm the pinned
# JAX version still provides them.
update = jax.ops.index_update
update_add = jax.ops.index_add
update_multiply = update_multiply_jax
at = jax.ops.index
solve_tridiagonal = solve_tridiagonal_jax
scan = jax.lax.scan
else:
# Any other backend value is rejected; the exception carries no message.
raise ValueError()
def _while_loop(cond, body, loop_vars, **kwargs): # pylint: disable=missing-docstring
del kwargs
# JAX doesn't do the automatic unwrapping of variables.
def cond_wrapper(loop_vars):
return cond(*loop_vars)
def body_wrapper(loop_vars):
return body(*loop_vars)
return lax.while_loop(cond_wrapper, body_wrapper, loop_vars)
def _index_arrays(i, aval, xs):
    """Index ``xs`` at position ``i`` along its leading dimension.

    When ``aval`` is a ``core.AbstractTuple`` the function recurses over the
    paired elements of ``aval`` and ``xs`` and packs the results with
    ``core.pack``; otherwise it takes ``xs[i]`` via
    ``lax.dynamic_index_in_dim`` with the indexed dimension dropped.
    """
    if not isinstance(aval, core.AbstractTuple):
        return lax.dynamic_index_in_dim(xs, i, keepdims=False)
    index_at_i = partial(_index_arrays, i)
    return core.pack(map(index_at_i, aval, xs))
# Elementwise unary functions. Each call pairs a NumPy ufunc (`onp.*`) with
# the lax callable that implements it; `positive` uses the identity function.
# NOTE(review): `_one_to_one_unop`, `_one_to_one_binop` and `_maybe_bool_binop`
# are defined outside this span -- their exact contract should be confirmed
# there.
bitwise_not = _one_to_one_unop(onp.bitwise_not, lax.bitwise_not)
negative = _one_to_one_unop(onp.negative, lax.neg)
positive = _one_to_one_unop(onp.positive, lambda x: x)
sign = _one_to_one_unop(onp.sign, lax.sign)
# The wrappers below pass a third positional argument of True.
# NOTE(review): presumably this asks the wrapper to promote integer inputs to
# a floating-point dtype before calling lax -- confirm against
# `_one_to_one_unop`.
floor = _one_to_one_unop(onp.floor, lax.floor, True)
ceil = _one_to_one_unop(onp.ceil, lax.ceil, True)
exp = _one_to_one_unop(onp.exp, lax.exp, True)
log = _one_to_one_unop(onp.log, lax.log, True)
expm1 = _one_to_one_unop(onp.expm1, lax.expm1, True)
log1p = _one_to_one_unop(onp.log1p, lax.log1p, True)
sin = _one_to_one_unop(onp.sin, lax.sin, True)
cos = _one_to_one_unop(onp.cos, lax.cos, True)
tan = _one_to_one_unop(onp.tan, lax.tan, True)
arcsin = _one_to_one_unop(onp.arcsin, lax.asin, True)
arccos = _one_to_one_unop(onp.arccos, lax.acos, True)
arctan = _one_to_one_unop(onp.arctan, lax.atan, True)
sinh = _one_to_one_unop(onp.sinh, lax.sinh, True)
cosh = _one_to_one_unop(onp.cosh, lax.cosh, True)
tanh = _one_to_one_unop(onp.tanh, lax.tanh, True)
sqrt = _one_to_one_unop(onp.sqrt, lax.sqrt, True)
# Elementwise binary functions. `add` and `multiply` are built with
# `_maybe_bool_binop`, which receives an extra bitwise op (OR for add, AND for
# multiply). NOTE(review): presumably that op is used for boolean operands --
# confirm against `_maybe_bool_binop`.
add = _maybe_bool_binop(onp.add, lax.add, lax.bitwise_or)
bitwise_and = _one_to_one_binop(onp.bitwise_and, lax.bitwise_and)
bitwise_or = _one_to_one_binop(onp.bitwise_or, lax.bitwise_or)
bitwise_xor = _one_to_one_binop(onp.bitwise_xor, lax.bitwise_xor)
# right_shift maps to the arithmetic (not logical) lax shift.
right_shift = _one_to_one_binop(onp.right_shift, lax.shift_right_arithmetic)
left_shift = _one_to_one_binop(onp.left_shift, lax.shift_left)
equal = _one_to_one_binop(onp.equal, lax.eq)
multiply = _maybe_bool_binop(onp.multiply, lax.mul, lax.bitwise_and)
not_equal = _one_to_one_binop(onp.not_equal, lax.ne)
"""Masks out elements attending to self.
Args:
N: number of query positions
M: number of key positions
k: position of the initial query element
Returns:
N x M mask, where 1.0 indicates that attention is not allowed.
"""
# NOTE(review): fragment -- the `def` line that owns this docstring and body
# is not visible in this span.
# Query offsets 0..N-1 and key positions 0..M-1 as int32 vectors, each passed
# through lax.tie_in together with k.
# NOTE(review): lax.tie_in was deprecated and later removed from JAX --
# confirm the pinned JAX version still provides it.
x = jax.lax.tie_in(k, np.arange(N, dtype=np.int32))
y = jax.lax.tie_in(k, np.arange(M, dtype=np.int32))
# Compare the (N, M) grid of absolute query positions (row index + k) with
# the key positions broadcast across N rows: the entry at [i, j] is true
# exactly when i + k == j.
mask = jax.lax.eq(
(jax.lax.broadcast_in_dim(
x, shape=(N, M), broadcast_dimensions=(0,)) + k),
jax.lax.broadcast(y, [N]))
# Convert the boolean comparison result to float32, as promised by the
# docstring above.
mask = jax.lax.convert_element_type(mask, np.float32)
return mask
def logpdf(x, loc=0, scale=1):
    """Log-density of the uniform distribution on ``[loc, loc + scale]``.

    Args:
      x: points at which to evaluate the log-density.
      loc: lower bound of the support.
      scale: width of the support.

    Returns:
      ``-log(scale)`` wherever ``x`` is neither greater than ``loc + scale``
      nor less than ``loc``, and ``-inf`` elsewhere.
    """
    x, loc, scale = _promote_args_inexact("uniform.logpdf", x, loc, scale)
    log_probs = lax.neg(lax.log(scale))
    upper = lax.add(loc, scale)
    above_support = lax.gt(x, upper)
    below_support = lax.lt(x, loc)
    out_of_support = logical_or(above_support, below_support)
    return where(out_of_support, -inf, log_probs)
# NOTE(review): fragment -- the enclosing function header (and the unpacking
# of the loop state into the names used below) is not visible in this span.
# Take the current window of `q_loop_stride` positions, starting at
# `q_loop_idx`, along axis -2 of `query`.
query_slice = jax.lax.dynamic_slice_in_dim(
query, q_loop_idx, q_loop_stride, axis=-2)
if do_backprop:
# Same window of `ct`, then one combined forward/VJP call on this window.
ct_slice = jax.lax.dynamic_slice_in_dim(
ct, q_loop_idx, q_loop_stride, axis=-2)
out_slice, partial_ct = forward_and_vjp_slice(
query_slice, q_loop_idx, key, value, ct_slice)
# partial_ct[0] is written into the matching window of `query_ct_accum`;
# partial_ct[1] and partial_ct[2] are summed into the key/value accumulators.
query_ct_accum = jax.lax.dynamic_update_slice_in_dim(
query_ct_accum, partial_ct[0], q_loop_idx, axis=-2)
key_ct_accum = key_ct_accum + partial_ct[1]
value_ct_accum = value_ct_accum + partial_ct[2]
else:
# Forward-only path: just compute this window's output.
out_slice = forward_slice(query_slice, q_loop_idx, key, value)
# Store the window's output at the same offset and advance by one stride.
out_accum = jax.lax.dynamic_update_slice_in_dim(
out_accum, out_slice, q_loop_idx, axis=-2)
q_loop_idx = q_loop_idx + q_loop_stride
# The returned loop state carries the three extra accumulators only when
# `do_backprop` is set.
if do_backprop:
return (q_loop_idx, out_accum,
query_ct_accum, key_ct_accum, value_ct_accum)
else:
return (q_loop_idx, out_accum)
def vec_to_tril_matrix(t, diagonal=0):
    """Scatter the last axis of ``t`` into the lower triangle of a square matrix.

    Args:
      t: array whose last axis holds the entries on and below the
        ``diagonal``-th diagonal, ordered as ``jnp.tril_indices(n, diagonal)``.
        All leading axes are carried through unchanged.
      diagonal: offset of the highest diagonal that is filled. The size
        formula below is only valid for ``diagonal <= 0``.

    Returns:
      Array of shape ``t.shape[:-1] + (n, n)`` in the (canonicalized) dtype of
      ``t``, zero everywhere outside the filled triangle.
    """
    # NB: the following formula only works for diagonal <= 0
    n = int(round((math.sqrt(1 + 8 * t.shape[-1]) - 1) / 2)) - diagonal
    n2 = n * n
    # Flat (row-major) positions of the lower-triangular entries of an n x n
    # matrix.
    idx = jnp.reshape(jnp.arange(n2), (n, n))[jnp.tril_indices(n, diagonal)]
    # lax.scatter_add requires operand and updates to share a dtype, so the
    # zero buffer is created in t's canonical dtype instead of the default
    # float dtype (which rejected integer / float16 inputs).
    flat = jnp.zeros(t.shape[:-1] + (n2,), dtype=jnp.result_type(t))
    batch_ndim = t.ndim - 1
    dnums = lax.ScatterDimensionNumbers(
        update_window_dims=tuple(range(batch_ndim)),
        inserted_window_dims=(batch_ndim,),
        scatter_dims_to_operand_dims=(batch_ndim,))
    x = lax.scatter_add(flat, jnp.expand_dims(idx, axis=-1), t, dnums)
    return jnp.reshape(x, x.shape[:-1] + (n, n))
| |-> (*) -|
input -|-> [filter] -| |-> 1x1 conv -|
| |-> (+) -> dense output
|------------------------------------|
Where `[gate]` and `[filter]` are causal convolutions with a
non-linear activation at the output.
"""
# NOTE(review): fragment -- the enclosing function header is not visible in
# this span.
# Gate branch: dilated Conv1D followed by `sigmoid`, applied to `inputs`.
gated = Sequential(Conv1D(dilation_channels, (filter_width,),
dilation=(dilation,)), sigmoid)(inputs)
# Filter branch: same convolution configuration followed by `np.tanh`.
filtered = Sequential(Conv1D(dilation_channels, (filter_width,),
dilation=(dilation,)), np.tanh)(inputs)
# Gated activation: product of the two branches.
p = gated * filtered
# 1x1 convolution producing `residual_channels` channels.
out = Conv1D(residual_channels, (1,), padding='SAME')(p)
# Add the transformed output of the resblock to the sliced input:
# the slice keeps the last `out.shape[1]` positions along axis 1 of `inputs`
# and the full extent of axes 0 and 2.
sliced_inputs = lax.dynamic_slice(
inputs, [0, inputs.shape[1] - out.shape[1], 0],
[inputs.shape[0], out.shape[1], inputs.shape[2]])
# NOTE(review): if `sum` is the Python builtin here, this call iterates over
# `out` using `sliced_inputs` as the start value rather than adding the two
# arrays elementwise; presumably a project-defined `sum` is in scope --
# confirm.
new_out = sum(out, sliced_inputs)
# Skip connection: 1x1 convolution over `skip_slice(p, output_width)`.
skip = Conv1D(residual_channels, (1,), padding='SAME')(skip_slice(p, output_width))
return new_out, skip