How to use the `jax.lax` module in JAX

To help you get started, we’ve selected a few JAX examples based on popular ways `jax.lax` is used in public projects.

Secure your code as it's written. Use Snyk Code to scan source code in minutes - no build needed - and fix issues immediately.

github google / jax / polysimp.py View on Github external
from jax.interpreters import polysimp
from jax import lax
from jax import make_jaxpr
import jax.numpy as np

import jax.linear_util as lu

# Classify these lax primitives for polysimp by adding each one to the
# matching registry (addition / multiplication / linear).
polysimp.addition_primitives.add(lax.add_p)
polysimp.multiplication_primitives.add(lax.mul_p)
polysimp.multiplication_primitives.add(lax.dot_p)
polysimp.linear_primitives.add(lax.broadcast_p)
polysimp.linear_primitives.add(lax.convert_element_type_p)
polysimp.linear_primitives.add(lax.reshape_p)


def f(x):
  """Example polynomial with repeated terms for polysimp to work on."""
  return x + x * x * x + 3 * x + 4 * x * x * x


# Baseline: the traced jaxpr and the concrete value of the unmodified function.
# print() calls (rather than Python 2 print statements) keep this script
# valid under Python 3.
print(make_jaxpr(f)(2))
print(f(2))

# The same function after the polysimp transformation. A fresh
# lu.wrap_init(f) is built for each call, exactly as before.
print(make_jaxpr(polysimp.polysimp(lu.wrap_init(f)).call_wrapped)((2,)))
print(polysimp.polysimp(lu.wrap_init(f)).call_wrapped((2,)))


import numpy as onp
github team-ocean / veros / veros / core / operators.py View on Github external
update_add = update_add_numpy
    update_multiply = update_multiply_numpy
    at = Index()
    solve_tridiagonal = solve_tridiagonal_numpy
    scan = scan_numpy

elif runtime_settings.backend == 'jax':
    import jax
    import jax.numpy
    numpy = jax.numpy
    update = jax.ops.index_update
    update_add = jax.ops.index_add
    update_multiply = update_multiply_jax
    at = jax.ops.index
    solve_tridiagonal = solve_tridiagonal_jax
    scan = jax.lax.scan

else:
    raise ValueError()
github tensorflow / probability / discussion / fun_mcmc / tf_on_jax.py View on Github external
def _while_loop(cond, body, loop_vars, **kwargs):  # pylint: disable=missing-docstring
  del kwargs

  # JAX doesn't do the automatic unwrapping of variables.
  def cond_wrapper(loop_vars):
    return cond(*loop_vars)

  def body_wrapper(loop_vars):
    return body(*loop_vars)

  return lax.while_loop(cond_wrapper, body_wrapper, loop_vars)
github google / jax / jax / initial_style.py View on Github external
def _index_arrays(i, aval, xs):
  """Take element `i` of `xs`, recursing through tuple-structured values.

  When `aval` is a `core.AbstractTuple`, `xs` is treated as the matching
  tuple of values: each child is indexed recursively and the results are
  re-packed with `core.pack`. Otherwise `xs` is a single array and
  `lax.dynamic_index_in_dim` pulls out slice `i` (along axis 0, that
  function's default) with the indexed axis dropped.
  """
  if not isinstance(aval, core.AbstractTuple):
    # Leaf value: extract slice `i` and drop the indexed axis.
    return lax.dynamic_index_in_dim(xs, i, keepdims=False)
  index_child = partial(_index_arrays, i)
  return core.pack(map(index_child, aval, xs))
github google / jax / jax / numpy / lax_numpy.py View on Github external
# NumPy-compatible element-wise functions. Each entry passes the original
# NumPy function (`onp.<name>`) and the `lax` operation that implements it to
# a factory (`_one_to_one_unop` / `_one_to_one_binop` / `_maybe_bool_binop`)
# defined elsewhere in this module.
bitwise_not = _one_to_one_unop(onp.bitwise_not, lax.bitwise_not)
negative = _one_to_one_unop(onp.negative, lax.neg)
positive = _one_to_one_unop(onp.positive, lambda x: x)  # identity
sign = _one_to_one_unop(onp.sign, lax.sign)

# NOTE(review): the trailing `True` presumably tells `_one_to_one_unop` to
# promote integer inputs to an inexact dtype before applying these
# transcendental ops — confirm against the factory's definition.
floor = _one_to_one_unop(onp.floor, lax.floor, True)
ceil = _one_to_one_unop(onp.ceil, lax.ceil, True)
exp = _one_to_one_unop(onp.exp, lax.exp, True)
log = _one_to_one_unop(onp.log, lax.log, True)
expm1 = _one_to_one_unop(onp.expm1, lax.expm1, True)
log1p = _one_to_one_unop(onp.log1p, lax.log1p, True)
sin = _one_to_one_unop(onp.sin, lax.sin, True)
cos = _one_to_one_unop(onp.cos, lax.cos, True)
tan = _one_to_one_unop(onp.tan, lax.tan, True)
arcsin = _one_to_one_unop(onp.arcsin, lax.asin, True)
arccos = _one_to_one_unop(onp.arccos, lax.acos, True)
arctan = _one_to_one_unop(onp.arctan, lax.atan, True)
sinh = _one_to_one_unop(onp.sinh, lax.sinh, True)
cosh = _one_to_one_unop(onp.cosh, lax.cosh, True)
tanh = _one_to_one_unop(onp.tanh, lax.tanh, True)
sqrt = _one_to_one_unop(onp.sqrt, lax.sqrt, True)


# `add` and `multiply` take a third lax op; presumably `_maybe_bool_binop`
# uses it for boolean inputs (OR for add, AND for multiply), matching
# NumPy's bool semantics — confirm against the factory.
add = _maybe_bool_binop(onp.add, lax.add, lax.bitwise_or)
bitwise_and = _one_to_one_binop(onp.bitwise_and, lax.bitwise_and)
bitwise_or = _one_to_one_binop(onp.bitwise_or, lax.bitwise_or)
bitwise_xor = _one_to_one_binop(onp.bitwise_xor, lax.bitwise_xor)
# NOTE(review): NumPy's right_shift is a logical shift for unsigned integer
# dtypes, but this always uses `shift_right_arithmetic`. Results may differ
# for unsigned values with the high bit set. Consider dispatching to
# `lax.shift_right_logical` for unsigned dtypes.
right_shift = _one_to_one_binop(onp.right_shift, lax.shift_right_arithmetic)
left_shift = _one_to_one_binop(onp.left_shift, lax.shift_left)
equal = _one_to_one_binop(onp.equal, lax.eq)
multiply = _maybe_bool_binop(onp.multiply, lax.mul, lax.bitwise_and)
not_equal = _one_to_one_binop(onp.not_equal, lax.ne)
github google / trax / trax / layers / research / efficient_attention.py View on Github external
"""Masks out elements attending to self.

      Args:
        N: number of query positions
        M: number of key positions
        k: position of the initial query element

      Returns:
        N x M mask, where 1.0 indicates that attention is not allowed.
      """
      x = jax.lax.tie_in(k, np.arange(N, dtype=np.int32))
      y = jax.lax.tie_in(k, np.arange(M, dtype=np.int32))
      mask = jax.lax.eq(
          (jax.lax.broadcast_in_dim(
              x, shape=(N, M), broadcast_dimensions=(0,)) + k),
          jax.lax.broadcast(y, [N]))
      mask = jax.lax.convert_element_type(mask, np.float32)
      return mask
github google / jax / jax / scipy / stats / uniform.py View on Github external
def logpdf(x, loc=0, scale=1):
  """Log-density of the uniform distribution on ``[loc, loc + scale]``.

  Mirrors ``scipy.stats.uniform.logpdf``.

  Args:
    x: points at which to evaluate the log-density.
    loc: lower bound of the support.
    scale: width of the support (upper bound is ``loc + scale``).

  Returns:
    ``-log(scale)`` where ``loc <= x <= loc + scale``, ``-inf`` outside that
    interval, and NaN wherever ``x`` is NaN.
  """
  x, loc, scale = _promote_args_inexact("uniform.logpdf", x, loc, scale)
  log_probs = lax.neg(lax.log(scale))
  outside = logical_or(lax.gt(x, lax.add(loc, scale)), lax.lt(x, loc))
  result = where(outside, -inf, log_probs)
  # Every comparison with NaN is False, so a NaN ``x`` would otherwise fall
  # through to ``log_probs``. ``x != x`` is True exactly for NaN, so the NaN
  # is propagated instead, as SciPy does.
  return where(lax.ne(x, x), x, result)
github tensorflow / tensor2tensor / tensor2tensor / trax / layers / attention.py View on Github external
query_slice = jax.lax.dynamic_slice_in_dim(
          query, q_loop_idx, q_loop_stride, axis=-2)

      if do_backprop:
        ct_slice = jax.lax.dynamic_slice_in_dim(
            ct, q_loop_idx, q_loop_stride, axis=-2)
        out_slice, partial_ct = forward_and_vjp_slice(
            query_slice, q_loop_idx, key, value, ct_slice)
        query_ct_accum = jax.lax.dynamic_update_slice_in_dim(
            query_ct_accum, partial_ct[0], q_loop_idx, axis=-2)
        key_ct_accum = key_ct_accum + partial_ct[1]
        value_ct_accum = value_ct_accum + partial_ct[2]
      else:
        out_slice = forward_slice(query_slice, q_loop_idx, key, value)

      out_accum = jax.lax.dynamic_update_slice_in_dim(
          out_accum, out_slice, q_loop_idx, axis=-2)
      q_loop_idx = q_loop_idx + q_loop_stride

      if do_backprop:
        return (q_loop_idx, out_accum,
                query_ct_accum, key_ct_accum, value_ct_accum)
      else:
        return (q_loop_idx, out_accum)
github pyro-ppl / numpyro / numpyro / distributions / util.py View on Github external
def vec_to_tril_matrix(t, diagonal=0):
    """Scatter a vector (or batch of vectors) into lower-triangular matrices.

    The last axis of ``t`` supplies the entries at the positions given by
    ``jnp.tril_indices(n, diagonal)``, in that order. Every other position
    is zero.

    :param t: array whose last axis holds the triangle entries; any leading
        axes are treated as batch axes.
    :param int diagonal: diagonal offset bounding the triangle; must be
        ``<= 0``.
    :return: array of shape ``t.shape[:-1] + (n, n)`` with the same dtype
        as ``t``.
    """
    # NB: the following formula only works for diagonal <= 0
    n = round((math.sqrt(1 + 8 * t.shape[-1]) - 1) / 2) - diagonal
    n2 = n * n
    # Flat (row-major) positions of the lower-triangle entries of an n x n matrix.
    idx = jnp.reshape(jnp.arange(n2), (n, n))[jnp.tril_indices(n, diagonal)]
    # Allocate the output in t's dtype. scatter_add needs operand and updates
    # to share a dtype, and the default float32 zeros broke int/float64 input.
    x = lax.scatter_add(
        jnp.zeros(t.shape[:-1] + (n2,), dtype=t.dtype),
        jnp.expand_dims(idx, axis=-1),
        t,
        lax.ScatterDimensionNumbers(
            update_window_dims=tuple(range(t.ndim - 1)),
            inserted_window_dims=(t.ndim - 1,),
            scatter_dims_to_operand_dims=(t.ndim - 1,)))
    return jnp.reshape(x, x.shape[:-1] + (n, n))
github JuliusKunze / jaxnet / examples / wavenet.py View on Github external
|             |-> (*) -|
        input -|-> [filter] -|        |-> 1x1 conv -|
               |                                    |-> (+) -> dense output
               |------------------------------------|

        Where `[gate]` and `[filter]` are causal convolutions with a
        non-linear activation at the output
        """
        gated = Sequential(Conv1D(dilation_channels, (filter_width,),
                                  dilation=(dilation,)), sigmoid)(inputs)
        filtered = Sequential(Conv1D(dilation_channels, (filter_width,),
                                     dilation=(dilation,)), np.tanh)(inputs)
        p = gated * filtered
        out = Conv1D(residual_channels, (1,), padding='SAME')(p)
        # Add the transformed output of the resblock to the sliced input:
        sliced_inputs = lax.dynamic_slice(
            inputs, [0, inputs.shape[1] - out.shape[1], 0],
            [inputs.shape[0], out.shape[1], inputs.shape[2]])
        new_out = sum(out, sliced_inputs)
        skip = Conv1D(residual_channels, (1,), padding='SAME')(skip_slice(p, output_width))
        return new_out, skip