    ff_activation: the non-linearity in feed-forward layer

  Returns:
    A list of layers that maps (activations, mask) to (activations, mask).
  """
  attention = tl.Attention(
      d_model, n_heads=n_heads, dropout=dropout, mode=mode)

  dropout_ = tl.Dropout(
      rate=dropout, name='dropout_enc_attn', mode=mode)

  feed_forward = _FeedForwardBlock(
      d_model, d_ff, dropout, layer_idx, mode, ff_activation)

  return [
      tl.Residual(
          tl.LayerNorm(),
          attention,
          dropout_,
      ),
      tl.Residual(
          feed_forward
      ),
  ]
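# --- Usage sketch (not from the source) ----------------------------------
# The function above (assumed name: EncoderBlock) returns a list of layers;
# tl.Serial flattens such lists when composing a model. Argument values
# below are illustrative, not from the repo.
import trax.layers as tl

block = EncoderBlock(d_model=512, d_ff=2048, n_heads=8, dropout=0.1,
                     layer_idx=0, mode='train', ff_activation=tl.Relu)
encoder_block = tl.Serial(block)  # (activations, mask) -> (activations, mask)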
    d_ff: int: depth of feed-forward layer
    n_heads: int: number of attention heads
    dropout: float: dropout rate (how much to drop out)
    mode: str: 'train' or 'eval'
  """
  return tl.Serial(
      tl.Residual(  # Self-attention block.
          tl.LayerNorm(),
          AttentionPosition(positions=positions,
                            d_model=d_model,
                            n_heads=n_heads,
                            dropout=dropout,
                            mode=mode),
          tl.Dropout(rate=dropout, mode=mode)
      ),
      tl.Residual(  # Feed-forward block.
          tl.LayerNorm(),
          tl.Dense(d_ff),
          tl.Relu(),
          tl.Dropout(rate=dropout, mode=mode),
          tl.Dense(d_model),
          tl.Dropout(rate=dropout, mode=mode),
      ),
  )
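# --- Illustrative check (not from the source) -----------------------------
# Both residual branches above follow the pre-norm pattern: tl.Residual(f)
# computes y = x + f(x). A minimal, self-contained check of that behavior:
import numpy as np
import trax.layers as tl
from trax import shapes

x = np.ones((2, 4), dtype=np.float32)
layer = tl.Residual(tl.Dense(4))     # y = x + Dense(x)
layer.init(shapes.signature(x))      # create Dense weights for this shape
y = layer(x)
assert y.shape == x.shape            # the residual addition preserves shape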
  Returns:
    A list of layers that maps an activation tensor to an activation tensor.
  """
  causal_attention = tl.CausalAttention(
      d_model, n_heads=n_heads, d_attention_key=d_attn_key,
      d_attention_value=d_attn_value, attention_type=attn_type,
      share_qk=share_qk, mode=mode)  # no trailing comma: a comma here would make this a 1-tuple

  dropout_ = tl.Dropout(
      rate=dropout, name='attention_%d' % layer_idx, mode=mode)

  feed_forward = _FeedForwardBlock(
      d_model, d_ff, dropout, layer_idx, mode, ff_activation)

  return [
      tl.Residual(
          tl.LayerNorm(),
          causal_attention,
          dropout_,
      ),
      tl.Residual(
          feed_forward
      ),
  ]
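# --- Usage sketch (not from the source) ----------------------------------
# Assumed name (DecoderBlock) and parameter order, inferred from the
# variables used above; tl.DotProductAttention stands in as a plausible
# inner attention type. All values are illustrative.
import trax.layers as tl

block = DecoderBlock(d_model=512, d_ff=2048, n_heads=8,
                     d_attn_key=64, d_attn_value=64,
                     attn_type=tl.DotProductAttention,
                     dropout=0.1, share_qk=False, layer_idx=0,
                     mode='train', ff_activation=tl.Relu)
decoder_block = tl.Serial(block)  # activation tensor -> activation tensor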
  Returns:
    A list of layers which maps triples (decoder_activations, mask,
    encoder_activations) to triples of the same sort.
  """
  def _Dropout():
    return tl.Dropout(rate=dropout, mode=mode)

  attention_qkv = tl.AttentionQKV(
      d_model, n_heads=n_heads, dropout=dropout, mode=mode)

  causal_attention = tl.CausalAttention(
      d_model, n_heads=n_heads, mode=mode)

  feed_forward = _FeedForwardBlock(
      d_model, d_ff, dropout, layer_idx, mode, ff_activation)

  return [                             # vec_d masks vec_e
      tl.Residual(
          tl.LayerNorm(),              # vec_d ..... .....
          causal_attention,            # vec_d ..... .....
          _Dropout(),                  # vec_d ..... .....
      ),
      tl.Residual(
          tl.LayerNorm(),              # vec_d ..... .....
          tl.Select([0, 2, 2, 1, 2]),  # vec_d vec_e vec_e masks vec_e
          attention_qkv,               # vec_d masks vec_e
          _Dropout(),                  # vec_d masks vec_e
      ),
      tl.Residual(
          feed_forward                 # vec_d masks vec_e
      ),
  ]
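# --- Illustrative check (not from the source) -----------------------------
# tl.Select(indices) reorders/copies items on the data stack by position.
# Starting from (vec_d, masks, vec_e) at indices (0, 1, 2), the list
# [0, 2, 2, 1, 2] yields (vec_d, vec_e, vec_e, masks, vec_e): decoder
# activations become the queries and encoder activations the keys and
# values for attention_qkv, matching the inline stack comments above.
import numpy as np
import trax.layers as tl

vec_d = np.zeros((2, 3))          # decoder activations (illustrative)
masks = np.ones((2, 3))           # attention mask
vec_e = np.full((2, 3), 2.0)      # encoder activations
out = tl.Select([0, 2, 2, 1, 2])((vec_d, masks, vec_e))
assert len(out) == 5
assert (out[1] == vec_e).all() and (out[3] == masks).all()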
"""ResNet identical size block."""
# TODO(jonni): Use good defaults so Resnet50 code is cleaner / less redundant.
ks = kernel_size
filters1, filters2, filters3 = filters
main = [
tl.Conv(filters1, (1, 1)),
norm(mode=mode),
non_linearity(),
tl.Conv(filters2, (ks, ks), padding='SAME'),
norm(mode=mode),
non_linearity(),
tl.Conv(filters3, (1, 1)),
norm(mode=mode),
]
return [
tl.Residual(main),
non_linearity(),
]
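# --- Usage sketch (not from the source) ----------------------------------
# The snippet omits the function signature, so the name IdentityBlock and
# its parameters are assumptions. norm and non_linearity are passed as
# layer constructors, matching the norm(mode=mode) / non_linearity() calls
# in the body.
import trax.layers as tl

block = IdentityBlock(kernel_size=3, filters=(64, 64, 256),
                      norm=tl.BatchNorm, non_linearity=tl.Relu,
                      mode='train')
stage = tl.Serial(block)  # input must already have filters3 (=256) channels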