How to use the trax.layers.Dense function in trax

To help you get started, we've selected a few trax examples based on popular ways trax.layers.Dense is used in public projects.

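As a quick primer before the examples: tl.Dense(n_units) is a fully connected layer that projects the last axis of its input to n_units outputs. A minimal standalone sketch (the shapes here are illustrative assumptions):

import numpy as np
from trax import layers as tl
from trax import shapes

layer = tl.Dense(4)                    # project the last axis to 4 units
x = np.ones((2, 8), dtype=np.float32)  # batch of 2, feature size 8
layer.init(shapes.signature(x))        # initialize weights for this input shape
y = layer(x)
print(y.shape)                         # (2, 4)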

From google/trax: trax/models/research/reformer.py (view on GitHub)
def FeedForward(d_model, d_ff, dropout, activation, mode):
  """Feed-forward block with layer normalization at start."""
  return [
      tl.LayerNorm(),
      tl.Dense(d_ff),
      BroadcastedDropout(rate=dropout, mode=mode),  # pylint: disable=no-value-for-parameter
      activation(),
      tl.Dense(d_model),
      BroadcastedDropout(rate=dropout, mode=mode),  # pylint: disable=no-value-for-parameter
  ]
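To try this block outside reformer.py, the layer list can be dropped straight into tl.Serial. A minimal sketch, substituting tl.Dropout for the file-internal BroadcastedDropout helper and tl.Relu for the activation (both stand-ins are assumptions for illustration; eval mode keeps dropout inert):

import numpy as np
from trax import layers as tl
from trax import shapes

block = tl.Serial(
    tl.LayerNorm(),
    tl.Dense(32),                       # d_ff
    tl.Dropout(rate=0.1, mode='eval'),  # stand-in for BroadcastedDropout
    tl.Relu(),                          # stand-in for activation()
    tl.Dense(16),                       # d_model: output width matches input
    tl.Dropout(rate=0.1, mode='eval'),
)
x = np.ones((2, 5, 16), dtype=np.float32)
block.init(shapes.signature(x))
print(block(x).shape)                   # (2, 5, 16)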
From google/trax: trax/models/research/position_lookup_transformer.py (view on GitHub)
def AppendLearnedPosOperation(vec, q1, q2, q3, q4, q5):
  """Get (vec, q1, ...) and return new_pos."""
  # Create 5 scalar weights (length 1 vectors) from first component of input.
  ws = [tl.Dense(1) @ vec for _ in range(5)]
  new_pos = Softmax5Branches() @ (ws + [q1, q2, q3, q4, q5])
  return new_pos
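In this file, `@` applies a layer to a symbolic input, so each entry of ws is vec projected down to one scalar per position by tl.Dense(1). A minimal standalone sketch of that projection (shapes are illustrative assumptions):

import numpy as np
from trax import layers as tl
from trax import shapes

score = tl.Dense(1)
vec = np.ones((2, 7, 16), dtype=np.float32)  # (batch, length, d_feature)
score.init(shapes.signature(vec))
print(score(vec).shape)                      # (2, 7, 1): one scalar per position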
From google/trax: trax/models/neural_gpu.py (view on GitHub)
def NeuralGPU(d_feature=96, steps=16, vocab_size=2, mode='train'):
  """Implementation of the Neural GPU model.

  Args:
    d_feature: Number of memory channels (dimensionality of feature embedding).
    steps: Number of depthwise recurrence steps.
    vocab_size: Vocabulary size.
    mode: Whether we are training or evaluating or doing inference.

  Returns:
    A NeuralGPU Stax model.
  """
  del mode

  core = ConvDiagonalGRU(units=d_feature)
  return tl.Serial(
      tl.Embedding(d_feature=d_feature, vocab_size=vocab_size),
      [core] * steps,
      tl.Dense(vocab_size),
      tl.LogSoftmax(),
  )
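To exercise the model end to end, instantiate it and initialize against a dummy input. A minimal sketch, assuming the Neural GPU operates on a 2D grid of token ids (the grid shape below is an illustrative assumption):

import numpy as np
from trax import shapes
from trax.models.neural_gpu import NeuralGPU

model = NeuralGPU(d_feature=32, steps=4, vocab_size=2)
x = np.zeros((1, 4, 4), dtype=np.int32)  # (batch, height, width) token grid; shape is an assumption
model.init(shapes.signature(x))
print(model(x).shape)                    # (1, 4, 4, 2): per-cell log-probabilities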
From google/trax: trax/models/rnn.py (view on GitHub)
def RNNLM(vocab_size, d_model=512, n_layers=2, rnn_cell=tl.LSTMCell,
          rnn_cell_d_state_multiplier=2, dropout=0.1, mode='train'):
  """Returns an RNN language model."""
  def MultiRNNCell():
    """Multi-layer RNN cell."""
    return tl.Serial(
        tl.Parallel([], tl.Split(n_items=n_layers)),
        tl.SerialWithSideOutputs(
            [rnn_cell(n_units=d_model) for _ in range(n_layers)]),
        tl.Parallel([], tl.Concatenate(n_items=n_layers))
    )

  zero_state = tl.MakeZeroState(  # pylint: disable=no-value-for-parameter
      depth_multiplier=n_layers * rnn_cell_d_state_multiplier
  )

  return tl.Serial(
      tl.ShiftRight(mode=mode),
      tl.Embedding(d_model, vocab_size),
      tl.Dropout(rate=dropout, name='embedding', mode=mode),
      tl.Branch([], zero_state),
      tl.Scan(MultiRNNCell(), axis=1),
      tl.Select([0], n_in=2),  # Drop RNN state.
      tl.Dense(vocab_size),
      tl.LogSoftmax()
  )
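Here tl.Dense(vocab_size) followed by tl.LogSoftmax turns the final RNN activations into per-token log-probabilities. A minimal sketch of calling the full model (eval mode keeps dropout inert; sizes are illustrative assumptions):

import numpy as np
from trax import shapes
from trax.models.rnn import RNNLM

model = RNNLM(vocab_size=100, d_model=32, n_layers=2, mode='eval')
x = np.zeros((2, 10), dtype=np.int32)  # (batch, length) token ids
model.init(shapes.signature(x))
print(model(x).shape)                  # (2, 10, 100): log-probs per token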