How to use the trax.backend.numpy module in trax

To help you get started, we've selected a few trax examples based on popular ways trax.backend.numpy is used in public projects.


From google/trax: trax/optimizers/sm3.py

def init(self, params):
    # One accumulator vector per axis of the parameter tensor, plus a
    # momentum buffer shaped like the parameters themselves.
    vs = [np.zeros(sz, dtype=params.dtype) for sz in params.shape]
    return (np.zeros_like(params), vs)
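
SM3 keeps one accumulator vector per tensor axis instead of a full second-moment tensor, which is where its memory saving comes from. A minimal plain-NumPy sketch of the same initializer (backend.numpy mirrors the NumPy API; the standalone function and toy shape are ours for illustration):

import numpy as np

def sm3_init(params):
  # One accumulator vector per axis of the parameter tensor,
  # plus a momentum buffer with the parameter's full shape.
  vs = [np.zeros(sz, dtype=params.dtype) for sz in params.shape]
  return np.zeros_like(params), vs

m, vs = sm3_init(np.ones((3, 4), dtype=np.float32))
# m.shape == (3, 4); [v.shape for v in vs] == [(3,), (4,)]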

From google/trax: trax/layers/rnn.py

def forward(self, inputs, weights):
    x, gru_state = inputs

    # Dense layer on the concatenation of x and h.
    w1, b1, w2, b2 = weights
    y = np.dot(np.concatenate([x, gru_state], axis=-1), w1) + b1

    # Update and reset gates.
    # `math` here is trax's backend math module (neither NumPy nor Python's
    # stdlib `math` provides a sigmoid).
    u, r = np.split(math.sigmoid(y), 2, axis=-1)

    # Candidate.
    c = np.dot(np.concatenate([x, r * gru_state], axis=-1), w2) + b2

    new_gru_state = u * gru_state + (1 - u) * np.tanh(c)
    return new_gru_state, new_gru_state
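
This is the standard GRU update: one fused dense layer produces the update and reset gates, a second produces the candidate state. A self-contained NumPy sketch with random toy weights (the shapes and the local sigmoid helper are assumptions for illustration):

import numpy as np

def sigmoid(x):
  return 1.0 / (1.0 + np.exp(-x))

def gru_step(x, h, w1, b1, w2, b2):
  # Update (u) and reset (r) gates from one fused dense layer.
  y = np.concatenate([x, h], axis=-1) @ w1 + b1
  u, r = np.split(sigmoid(y), 2, axis=-1)
  # Candidate state uses the reset-gated hidden state.
  c = np.concatenate([x, r * h], axis=-1) @ w2 + b2
  return u * h + (1 - u) * np.tanh(c)

d = 8  # toy hidden size
rng = np.random.default_rng(0)
x, h = rng.normal(size=(1, d)), np.zeros((1, d))
w1, b1 = rng.normal(size=(2 * d, 2 * d)), np.zeros(2 * d)
w2, b2 = rng.normal(size=(2 * d, d)), np.zeros(d)
h = gru_step(x, h, w1, b1, w2, b2)  # new state, shape (1, 8)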

From google/trax: trax/layers/attention.py

rng: JAX PRNGKey: subkey for disposable use

  Returns:
    Self attention for q, k, v arrays.
  """
  depth = np.shape(query)[-1]
  dots = np.matmul(query, np.swapaxes(key, -1, -2)) / np.sqrt(depth)
  if mask is not None:
    # TODO(kitaev): workaround for https://github.com/google/jax/issues/850
    # We must ensure that both mask and the -1e9 constant have a data dependency
    # on the input. Broadcasted copies of these use a lot of memory, so they
    # should be computed at runtime (rather than being global constants).
    if backend.get_name() == 'jax':
      mask = jax.lax.tie_in(dots, mask)
    # JAX's `full_like` already ties in -1e9 to dots.
    dots = np.where(mask, dots, np.full_like(dots, -1e9))
  # Softmax.
  dots = np.exp(dots - backend.logsumexp(dots, axis=-1, keepdims=True))
  if dropout is not None and dropout >= 1.0:
    raise ValueError('Dropout rates must be lower than 1.')
  if dropout is not None and dropout > 0.0 and mode == 'train':
    keep = backend.random.bernoulli(rng, 1.0 - dropout, dots.shape)
    dots = np.where(keep, dots / (1.0 - dropout), np.zeros_like(dots))
  out = np.matmul(dots, value)
  return out
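
The core of the snippet is scaled dot-product attention with a numerically stable softmax (subtract the logsumexp, then exponentiate). A plain-NumPy sketch of that core with the masking and dropout branches left out:

import numpy as np

def logsumexp(x, axis=-1):
  m = np.max(x, axis=axis, keepdims=True)
  return m + np.log(np.sum(np.exp(x - m), axis=axis, keepdims=True))

def dot_product_attention(q, k, v):
  depth = q.shape[-1]
  dots = q @ np.swapaxes(k, -1, -2) / np.sqrt(depth)
  # Stable softmax: subtracting the max inside logsumexp avoids
  # overflow in np.exp.
  dots = np.exp(dots - logsumexp(dots))
  return dots @ v

rng = np.random.default_rng(0)
q = k = v = rng.normal(size=(2, 5, 4))  # (batch, length, depth)
out = dot_product_attention(q, k, v)    # shape (2, 5, 4)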

From google/trax: trax/layers/core.py

def forward(self, x, weights):
    # Embedding lookup: select one row of the weight matrix per token id.
    return np.take(weights, x, axis=0)
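
An embedding layer is just a row lookup: np.take selects one weight row per token id. A tiny NumPy illustration with made-up sizes:

import numpy as np

vocab, d_model = 10, 4
weights = np.arange(vocab * d_model, dtype=np.float32).reshape(vocab, d_model)
token_ids = np.array([[1, 3, 3]])
emb = np.take(weights, token_ids, axis=0)  # one row per id, shape (1, 3, 4)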

From google/trax: trax/layers/research/efficient_attention.py

      dots_thresh = np.sort(merged_top_dots)[:, -self._hard_k]
      # It's possible to compute the partition function at this point, but right
      # now this codepath isn't set up for backprop, and there might also be
      # issues computing it this way if two dot-products are exactly equal.

      sdots_thresh = dots_thresh[st]
      bdots_thresh = np.reshape(sdots_thresh, (self.n_hashes * self.n_bins, -1))
      bdots_thresh = jax.lax.stop_gradient(bdots_thresh)

      top_k_mask = jax.lax.convert_element_type(
          dots < bdots_thresh[..., None], np.float32)
      dots = dots - 1e7 * jax.lax.stop_gradient(top_k_mask)

    # Softmax.
    dots_logsumexp = backend.logsumexp(dots, axis=-1, keepdims=True)
    dots = np.exp(dots - dots_logsumexp)

    if self._dropout > 0.0:
      # Dropout is broadcast across the bin dimension
      dropout_shape = (1, dots.shape[-2], dots.shape[-1])
      keep_prob = jax.lax.tie_in(dots, 1.0 - self._dropout)
      keep = backend.random.bernoulli(rng, keep_prob, dropout_shape)
      multiplier = keep.astype(dots.dtype) / jax.lax.tie_in(keep, keep_prob)
      dots = dots * multiplier

    bo = np.matmul(dots, bv)
    so = np.reshape(bo, (-1, bo.shape[-1]))
    slogits = np.reshape(dots_logsumexp, (-1,))

    def unsort_for_output_impl(so, slogits):
      o = np.take(so, undo_sort, axis=0)
      # Sorting is considerably faster than gather, but first we need to get the
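
The hard-k branch finds each query's k-th largest dot product and subtracts a large constant from every logit below it, so the softmax that follows effectively keeps only the top k entries. A NumPy sketch of that masking step alone (the hashing and bin bookkeeping are omitted):

import numpy as np

def hard_top_k(dots, k):
  # The k-th largest value per row is the cutoff (sort is ascending).
  thresh = np.sort(dots, axis=-1)[..., -k][..., None]
  mask = (dots < thresh).astype(np.float32)
  # Push sub-threshold logits far down so softmax ignores them.
  return dots - 1e7 * mask

rng = np.random.default_rng(0)
dots = rng.normal(size=(2, 6))
masked = hard_top_k(dots, k=2)  # only 2 logits per row survive softmax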

From google/trax: trax/supervised/trainer_lib.py

def update_model_state(self, key, value):
    """Updates model state based on nontrainable_params."""
    # Translate model state keys to nontrainable param names.
    if key in self._nontrainable_param_map:
      p_name = self._nontrainable_param_map[key]
    else:
      # If a key is not in mapping, it stays the same.
      p_name = key
    if p_name in self.nontrainable_params:
      if self._step == 0:
        log('Mapping model state key {} to nontrainable param {}.'
            ''.format(key, p_name))
        return self._for_n_devices(np.array(self.nontrainable_params[p_name]))
    return value
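
The method maps a model-state key to a nontrainable-parameter name and, if such a parameter exists, overrides the state value with it. A hypothetical standalone version that makes the control flow easy to test:

import numpy as np

def update_model_state(key, value, param_map, nontrainable):
  p_name = param_map.get(key, key)          # unmapped keys stay the same
  if p_name in nontrainable:
    return np.array(nontrainable[p_name])   # override state from params
  return value

state = update_model_state('lr', 0.5, {'lr': 'learning_rate'},
                           {'learning_rate': 0.1})
# state == array(0.1): the model-state value is replaced.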

From google/trax: trax/optimizers/sm3.py

def _minimum(self, tensor_list):
    # Elementwise minimum folded across a list of tensors.
    minimum = tensor_list[0]
    for i in range(1, len(tensor_list)):
      minimum = np.minimum(minimum, tensor_list[i])
    return minimum
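
The loop is an elementwise fold over the list, so the same thing can be written as a one-liner with functools.reduce:

import functools
import numpy as np

tensors = [np.array([3., 1.]), np.array([2., 5.]), np.array([0., 4.])]
minimum = functools.reduce(np.minimum, tensors)  # array([0., 1.])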

From google/trax: trax/layers/initializers.py

    n_cols = cur_shape[-1]
    flat_shape = (n_cols, n_rows) if n_rows < n_cols else (n_rows, n_cols)

    # Generate a random matrix
    a = random.normal(rng, flat_shape, dtype=np.float32)

    # Compute the qr factorization
    q, r = np.linalg.qr(a)

    # Make Q uniform
    d = np.diag(r)
    q *= np.sign(d)

    # Transpose and reshape back q if needed.
    if n_rows < n_cols:
      q = np.transpose(q)
    q = np.reshape(q, shape)

    # Return scaled as requested.
    return stddev * q
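
This is the usual QR-based orthogonal initializer: factor a random Gaussian matrix and multiply Q by the signs of R's diagonal so the result is drawn uniformly from the orthogonal group. A self-contained NumPy version (the function name and seed argument are ours):

import numpy as np

def orthogonal(shape, stddev=1.0, seed=0):
  n_rows = int(np.prod(shape[:-1]))
  n_cols = shape[-1]
  flat_shape = (n_cols, n_rows) if n_rows < n_cols else (n_rows, n_cols)
  a = np.random.default_rng(seed).normal(size=flat_shape)
  q, r = np.linalg.qr(a)
  # Sign correction makes Q uniform over the orthogonal group.
  q *= np.sign(np.diag(r))
  if n_rows < n_cols:
    q = q.T
  return stddev * np.reshape(q, shape)

w = orthogonal((4, 4))
# Columns are orthonormal: w.T @ w is close to the identity matrix.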

From google/trax: trax/layers/attention.py

    emb = np.concatenate(embs, -1)

    if self._mode == 'predict':
      assert self._dropout == 0.0
      emb = np.reshape(emb, (inputs.shape[0], -1, emb.shape[-1]))
      return inputs + emb[:, state, :][:, None, :], state + 1
    elif self._dropout == 0:
      return inputs + np.reshape(emb, inputs.shape), state
    else:
      noise_shape = list(emb.shape)
      for dim in self._dropout_broadcast_dims:
        noise_shape[dim] = 1
      keep_prob = 1.0 - self._dropout
      if backend.get_name() == 'jax':
        keep_prob = jax.lax.tie_in(
            inputs, np.full((), keep_prob, dtype=inputs.dtype))
      keep = backend.random.bernoulli(rng, keep_prob, tuple(noise_shape))
      multiplier = keep.astype(inputs.dtype) / keep_prob

      return inputs + np.reshape(emb * multiplier, inputs.shape), state
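
The dropout branch broadcasts a single keep/drop decision along each axis in self._dropout_broadcast_dims, which saves memory versus a full-size noise mask, and rescales kept values by 1/keep_prob so the expected activation is unchanged (inverted dropout). A NumPy sketch of just that trick:

import numpy as np

def broadcast_dropout(x, rate, broadcast_dims=(0,), seed=0):
  noise_shape = list(x.shape)
  for dim in broadcast_dims:
    noise_shape[dim] = 1  # one decision shared along this axis
  keep_prob = 1.0 - rate
  keep = np.random.default_rng(seed).random(tuple(noise_shape)) < keep_prob
  # Inverted dropout: scale survivors so the expectation matches eval mode.
  return x * keep / keep_prob

y = broadcast_dropout(np.ones((2, 3, 4)), rate=0.5, broadcast_dims=(1,))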

From google/trax: trax/supervised/inputs.py

def dataset_to_stream(dataset, input_name):
  """Takes a tf.Dataset and creates a numpy stream of ready batches."""
  for example in backend.dataset_as_numpy(dataset):
    features = example[0]
    inp, out = features[input_name], example[1]
    mask = features['mask'] if 'mask' in features else None
    # All input-pipeline processing should be on CPU.
    with tf.device('cpu:0'):
      # Some accelerators don't handle uint8 well; cast to int.
      # (Compare dtypes: batched arrays are never *instances* of np.uint8.)
      if inp.dtype == np.uint8:
        inp = inp.astype(np.int32)
      if out.dtype == np.uint8:
        out = out.astype(np.int32)
      if len(out.shape) > 1 and out.shape[-1] == 1:
        out = np.squeeze(out, axis=-1)
    yield (inp, out) if mask is None else (inp, out, mask)
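
A usage sketch, assuming the snippet's own imports (backend, tf, and np) are available and the dtype-based cast above; the dataset structure and the 'inputs' feature name are made up for illustration:

import numpy as np
import tensorflow as tf

ds = tf.data.Dataset.from_tensor_slices(
    ({'inputs': np.zeros((8, 5), np.uint8)}, np.zeros((8, 1), np.uint8)))

for inp, out in dataset_to_stream(ds.batch(4), input_name='inputs'):
  print(inp.dtype, out.shape)  # int32 (4,): cast to int and squeezed
  break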