How to use the trax.backend.numpy.array function in trax

To help you get started, we’ve selected a few trax examples, based on popular ways the trax.backend.numpy.array function is used in public projects.

Secure your code as it's written. Use Snyk Code to scan source code in minutes - no build needed - and fix issues immediately.

github google / trax / trax / models / research / position_lookup_transformer.py View on Github external
def PerformPositionOperations(pos, positions=None):
  """Runs five position queries against `pos` and returns their results.

  The lookup tables are built from rows of `positions`:
    * successor: row k maps to row k + 1;
    * subtract-one: row k + 1 maps to row k;
    * addition: the concatenation of rows i and j maps to row i + j;
    * subtraction: the concatenation of rows i and j maps to row
      max(i - j, 0), i.e. the difference is clamped at the first row.
  The two binary tables range over i, j in the first half of the rows.

  Args:
    pos: the operand combined with every query through the `@` operator.
    positions: array indexed as `positions[row, :]`. Although it defaults to
      None, it is indexed unconditionally, so a real array must be passed.

  Returns:
    A list of five results, ordered as: the default `QueryPositionKV()`,
    successor, subtract-one, addition, subtraction.
  """
  half = int(positions.shape[0]) // 2

  def _joined_rows(first, second):
    # Key of a binary query: the two operand rows laid end to end.
    return np.concatenate([positions[first, :], positions[second, :]])

  # Addition enumerates the (i, j) pairs with i as the outer index.
  add_pairs = [(i, j) for i in range(half) for j in range(half)]
  add_keys = np.array([_joined_rows(i, j) for i, j in add_pairs])
  add_values = np.array([positions[i + j, :] for i, j in add_pairs])

  # TODO(lukaszkaiser): try this below: "for j in range(i) for i in range(2*l)"
  # Subtraction enumerates the same pairs with j as the outer index.
  sub_pairs = [(i, j) for j in range(half) for i in range(half)]
  sub_keys = np.array([_joined_rows(i, j) for i, j in sub_pairs])
  sub_values = np.array([positions[max(i - j, 0), :] for i, j in sub_pairs])

  identity = QueryPositionKV()
  successor = QueryPositionKV(
      keys=positions[:-1, :], values=positions[1:, :])
  subtract_one = QueryPositionKV(
      keys=positions[1:, :], values=positions[:-1, :])
  addition = QueryPositionKV(keys=add_keys, values=add_values, binary=True)
  subtraction = QueryPositionKV(keys=sub_keys, values=sub_values, binary=True)

  results = []
  for query in (identity, successor, subtract_one, addition, subtraction):
    results.append(query @ pos)  # pylint: disable=syntax-error
  return results
github google / trax / trax / models / research / position_lookup_transformer.py View on Github external
def PerformPositionOperations(pos, positions=None):
  """Runs five position queries against `pos` and returns their results.

  The lookup tables are built from rows of `positions`:
    * successor: row k maps to row k + 1;
    * subtract-one: row k + 1 maps to row k;
    * addition: the concatenation of rows i and j maps to row i + j;
    * subtraction: the concatenation of rows i and j maps to row
      max(i - j, 0), i.e. the difference is clamped at the first row.
  The two binary tables range over i, j in the first half of the rows.

  Args:
    pos: the operand combined with every query through the `@` operator.
    positions: array indexed as `positions[row, :]`. Although it defaults to
      None, it is indexed unconditionally, so a real array must be passed.

  Returns:
    A list of five results, ordered as: the default `QueryPositionKV()`,
    successor, subtract-one, addition, subtraction.
  """
  half = int(positions.shape[0]) // 2

  def _joined_rows(first, second):
    # Key of a binary query: the two operand rows laid end to end.
    return np.concatenate([positions[first, :], positions[second, :]])

  # Addition enumerates the (i, j) pairs with i as the outer index.
  add_pairs = [(i, j) for i in range(half) for j in range(half)]
  add_keys = np.array([_joined_rows(i, j) for i, j in add_pairs])
  add_values = np.array([positions[i + j, :] for i, j in add_pairs])

  # TODO(lukaszkaiser): try this below: "for j in range(i) for i in range(2*l)"
  # Subtraction enumerates the same pairs with j as the outer index.
  sub_pairs = [(i, j) for j in range(half) for i in range(half)]
  sub_keys = np.array([_joined_rows(i, j) for i, j in sub_pairs])
  sub_values = np.array([positions[max(i - j, 0), :] for i, j in sub_pairs])

  identity = QueryPositionKV()
  successor = QueryPositionKV(
      keys=positions[:-1, :], values=positions[1:, :])
  subtract_one = QueryPositionKV(
      keys=positions[1:, :], values=positions[:-1, :])
  addition = QueryPositionKV(keys=add_keys, values=add_values, binary=True)
  subtraction = QueryPositionKV(keys=sub_keys, values=sub_values, binary=True)

  results = []
  for query in (identity, successor, subtract_one, addition, subtraction):
    results.append(query @ pos)  # pylint: disable=syntax-error
  return results
github google / trax / trax / models / research / position_lookup_transformer.py View on Github external
def PerformPositionOperations(pos, positions=None):
  """Runs five position queries against `pos` and returns their results.

  The lookup tables are built from rows of `positions`:
    * successor: row k maps to row k + 1;
    * subtract-one: row k + 1 maps to row k;
    * addition: the concatenation of rows i and j maps to row i + j;
    * subtraction: the concatenation of rows i and j maps to row
      max(i - j, 0), i.e. the difference is clamped at the first row.
  The two binary tables range over i, j in the first half of the rows.

  Args:
    pos: the operand combined with every query through the `@` operator.
    positions: array indexed as `positions[row, :]`. Although it defaults to
      None, it is indexed unconditionally, so a real array must be passed.

  Returns:
    A list of five results, ordered as: the default `QueryPositionKV()`,
    successor, subtract-one, addition, subtraction.
  """
  half = int(positions.shape[0]) // 2

  def _joined_rows(first, second):
    # Key of a binary query: the two operand rows laid end to end.
    return np.concatenate([positions[first, :], positions[second, :]])

  # Addition enumerates the (i, j) pairs with i as the outer index.
  add_pairs = [(i, j) for i in range(half) for j in range(half)]
  add_keys = np.array([_joined_rows(i, j) for i, j in add_pairs])
  add_values = np.array([positions[i + j, :] for i, j in add_pairs])

  # TODO(lukaszkaiser): try this below: "for j in range(i) for i in range(2*l)"
  # Subtraction enumerates the same pairs with j as the outer index.
  sub_pairs = [(i, j) for j in range(half) for i in range(half)]
  sub_keys = np.array([_joined_rows(i, j) for i, j in sub_pairs])
  sub_values = np.array([positions[max(i - j, 0), :] for i, j in sub_pairs])

  identity = QueryPositionKV()
  successor = QueryPositionKV(
      keys=positions[:-1, :], values=positions[1:, :])
  subtract_one = QueryPositionKV(
      keys=positions[1:, :], values=positions[:-1, :])
  addition = QueryPositionKV(keys=add_keys, values=add_values, binary=True)
  subtraction = QueryPositionKV(keys=sub_keys, values=sub_values, binary=True)

  results = []
  for query in (identity, successor, subtract_one, addition, subtraction):
    results.append(query @ pos)  # pylint: disable=syntax-error
  return results
github google / trax / trax / layers / core.py View on Github external
def one_hot(x, size, dtype=np.float32):  # pylint: disable=invalid-name
  """Builds a one-hot encoding of `x` along a new trailing axis.

  Args:
    x: array of integer categories; a trailing axis is appended to it before
      comparing against the category indices.
    size: number of categories, i.e. the length of the new trailing axis.
    dtype: element type of the returned array.

  Returns:
    An array of `dtype` holding the result of comparing each entry of `x`
    with every index in `0 .. size - 1`.
  """
  category_ids = np.arange(size)
  running_on_jax = backend.get_name() == 'jax'
  if running_on_jax:
    # Work around a jax broadcasting issue by tying the constant to `x`.
    category_ids = jax.lax.tie_in(x, category_ids)
  matches = x[..., np.newaxis] == category_ids
  return np.array(matches, dtype)
github google / trax / trax / optimizers / base.py View on Github external
"""Initialize the optimizer.

    Takes the initial optimizer parameters as positional arguments. They are fed
    back to the optimizer in tree_update, in the same order. They can be changed
    between updates, e.g. for learning rate schedules.

    The constructor should be overridden in derived classes to give names to the
    optimizer parameters, so the gin configuration can set them.

    Args:
      learning_rate: The initial learning rate.
      **init_opt_params: Initial values of any additional optimizer parameters.
    """
    init_opt_params['learning_rate'] = learning_rate
    self._init_opt_params = {
        name: np.array(value) for (name, value) in init_opt_params.items()
    }
github google / trax / trax / layers / core.py View on Github external
def new_weights_and_state(self, input_signature):
    """Returns empty weights plus a state dict holding the initial rate.

    Args:
      input_signature: ignored; the result does not depend on it.

    Returns:
      A pair `(base.EMPTY_WEIGHTS, state)`, where `state` maps `self._name`
      to `self._initial_rate` converted to an array.
    """
    del input_signature  # Deliberately unused.
    state_key = self._name
    initial_rate = np.array(self._initial_rate)
    empty_weights = base.EMPTY_WEIGHTS
    return empty_weights, {state_key: initial_rate}
github google / trax / trax / optimizers / adafactor.py View on Github external
def _decay_rate_pow(i, exponent=0.8):
    """Default Adafactor second-moment decay schedule."""
    t = np.array(i, np.float32) + 1.0
    return 1.0 - t**(-exponent)