How to use the trax.layers.Residual function in trax

To help you get started, we’ve selected a few trax examples based on popular ways trax.layers.Residual is used in public projects.

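Before the examples, here is a minimal sketch of what tl.Residual does: it runs its sublayers in series and adds the original input back to their output (y = x + f(x)), so the wrapped computation must preserve the input's shape. The layer sizes and input below are illustrative and assume a recent trax release; they are not taken from the projects below.

import numpy as np

from trax import layers as tl
from trax import shapes

# tl.Residual(f1, f2, ...) computes x + f(x), where f is the serial
# composition of the given sublayers; the output shape must match x.
block = tl.Residual(
    tl.LayerNorm(),
    tl.Dense(8),
    tl.Relu(),
    tl.Dense(8),
)

x = np.ones((2, 8), dtype=np.float32)  # toy batch: 2 examples, width 8
block.init(shapes.signature(x))        # create weights for this input shape
y = block(x)                           # same shape as x: (2, 8)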

google/trax: trax/models/transformer.py (view on GitHub)
ff_activation: the non-linearity in feed-forward layer

  Returns:
    A list of layers that maps (activations, mask) to (activations, mask).
  """
  attention = tl.Attention(
      d_model, n_heads=n_heads, dropout=dropout, mode=mode)

  dropout_ = tl.Dropout(
      rate=dropout, name='dropout_enc_attn', mode=mode)

  feed_forward = _FeedForwardBlock(
      d_model, d_ff, dropout, layer_idx, mode, ff_activation)

  return [
      tl.Residual(
          tl.LayerNorm(),
          attention,
          dropout_,
      ),
      tl.Residual(
          feed_forward
      ),
  ]
google/trax: trax/models/research/position_lookup_transformer.py (view on GitHub)
d_ff: int: depth of feed-forward layer
    n_heads: int: number of attention heads
    dropout: float: dropout rate (how much to drop out)
    mode: str: 'train' or 'eval'
  """
  return tl.Serial(
      tl.Residual(  # Self-attention block.
          tl.LayerNorm(),
          AttentionPosition(positions=positions,
                            d_model=d_model,
                            n_heads=n_heads,
                            dropout=dropout,
                            mode=mode),
          tl.Dropout(rate=dropout, mode=mode)
      ),
      tl.Residual(
          tl.LayerNorm(),
          tl.Dense(d_ff),
          tl.Relu(),
          tl.Dropout(rate=dropout, mode=mode),
          tl.Dense(d_model),
          tl.Dropout(rate=dropout, mode=mode),
      )
  )
google/trax: trax/models/transformer.py (view on GitHub)
Returns:
    A list of layers that maps an activation tensor to an activation tensor.
  """
  causal_attention = tl.CausalAttention(
      d_model, n_heads=n_heads, d_attention_key=d_attn_key,
      d_attention_value=d_attn_value, attention_type=attn_type,
      share_qk=share_qk, mode=mode)

  dropout_ = tl.Dropout(
      rate=dropout, name='attention_%d' % layer_idx, mode=mode)

  feed_forward = _FeedForwardBlock(
      d_model, d_ff, dropout, layer_idx, mode, ff_activation)

  return [
      tl.Residual(
          tl.LayerNorm(),
          causal_attention,
          dropout_,
      ),
      tl.Residual(
          feed_forward
      ),
  ]
google/trax: trax/models/transformer.py (view on GitHub)
encoder_activations) to triples of the same sort.
  """
  def _Dropout():
    return tl.Dropout(rate=dropout, mode=mode)

  attention_qkv = tl.AttentionQKV(
      d_model, n_heads=n_heads, dropout=dropout, mode=mode)

  causal_attention = tl.CausalAttention(
      d_model, n_heads=n_heads, mode=mode)

  feed_forward = _FeedForwardBlock(
      d_model, d_ff, dropout, layer_idx, mode, ff_activation)

  return [                             # vec_d masks vec_e
      tl.Residual(
          tl.LayerNorm(),              # vec_d ..... .....
          causal_attention,            # vec_d ..... .....
          _Dropout(),                  # vec_d ..... .....
      ),
      tl.Residual(
          tl.LayerNorm(),              # vec_d ..... .....
          tl.Select([0, 2, 2, 1, 2]),  # vec_d vec_e vec_e masks vec_e
          attention_qkv,               # vec_d masks vec_e
          _Dropout(),                  # vec_d masks vec_e
      ),
      tl.Residual(
          feed_forward                 # vec_d masks vec_e
      ),
  ]
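In the encoder-decoder block above, tl.Select is what feeds the attention layer: Select([0, 2, 2, 1, 2]) reads the top three items on the data stack (decoder activations, mask, encoder activations) and re-emits them by index, so AttentionQKV sees queries from the decoder, keys and values from the encoder, and the mask alongside. Here is a small sketch of Select on its own; the toy arrays are illustrative, not part of the trax code.

import numpy as np

from trax import layers as tl
from trax import shapes

# Select copies and reorders stack items by index. With inputs
# (vec_d, masks, vec_e) at indices 0, 1, 2, the indices [0, 2, 2, 1, 2]
# produce (vec_d, vec_e, vec_e, masks, vec_e), as in the comments above.
select = tl.Select([0, 2, 2, 1, 2])

vec_d = np.zeros((2, 3), dtype=np.float32)      # toy decoder activations
masks = np.ones((2, 3), dtype=np.float32)       # toy mask
vec_e = np.full((2, 3), 2.0, dtype=np.float32)  # toy encoder activations

select.init(shapes.signature((vec_d, masks, vec_e)))
outputs = select((vec_d, masks, vec_e))  # 5-tuple in the selected order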
google/trax: trax/models/resnet.py (view on GitHub)
"""ResNet identical size block."""
  # TODO(jonni): Use good defaults so Resnet50 code is cleaner / less redundant.
  ks = kernel_size
  filters1, filters2, filters3 = filters
  main = [
      tl.Conv(filters1, (1, 1)),
      norm(mode=mode),
      non_linearity(),
      tl.Conv(filters2, (ks, ks), padding='SAME'),
      norm(mode=mode),
      non_linearity(),
      tl.Conv(filters3, (1, 1)),
      norm(mode=mode),
  ]
  return [
      tl.Residual(main),
      non_linearity(),
  ]
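As the resnet.py example shows, tl.Residual also accepts a plain Python list of layers: the combinator flattens it, so tl.Residual(main) behaves the same as passing the layers of main individually. A short sketch (the layer sizes are illustrative):

from trax import layers as tl

# Equivalent residual blocks: a list of layers is flattened, so these two
# constructions build the same computation.
body = [tl.Dense(8), tl.Relu(), tl.Dense(8)]
block_from_list = tl.Residual(body)
block_from_args = tl.Residual(tl.Dense(8), tl.Relu(), tl.Dense(8))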