@tl.layer()
def Chunk(x, weights, n_sections=2, **kwargs):
  del weights, kwargs
  assert x.shape[1] % n_sections == 0
  return np.reshape(x, (
      x.shape[0] * n_sections,
      x.shape[1] // n_sections,
  ) + x.shape[2:])
@tl.layer()
def Unchunk(x, weights, n_sections=2, **kwargs):
  del weights, kwargs
  assert x.shape[0] % n_sections == 0
  return np.reshape(x, (
      x.shape[0] // n_sections,
      x.shape[1] * n_sections,
  ) + x.shape[2:])
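# The Chunk/Unchunk pair above are pure reshapes, so their behavior is easy to
# check outside Trax. A minimal NumPy sketch (toy (batch, length, depth) array
# assumed) that mirrors the same reshapes and verifies they are inverses:
import numpy as np

def chunk(x, n_sections=2):
  # Split the length axis into n_sections and fold them into the batch axis.
  assert x.shape[1] % n_sections == 0
  return np.reshape(
      x, (x.shape[0] * n_sections, x.shape[1] // n_sections) + x.shape[2:])

def unchunk(x, n_sections=2):
  # Inverse: unfold the batch axis back into the length axis.
  assert x.shape[0] % n_sections == 0
  return np.reshape(
      x, (x.shape[0] // n_sections, x.shape[1] * n_sections) + x.shape[2:])

x = np.arange(2 * 8 * 3).reshape(2, 8, 3)   # (batch, length, depth)
y = chunk(x)                                # -> (4, 4, 3)
assert y.shape == (4, 4, 3)
assert np.array_equal(unchunk(y), x)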
class ReversibleHalfResidual(tl.ReversibleLayer, tl.Serial):
  """Half of a RevNet-style residual (only updates part of the hidden state)."""

  def __init__(self, residual_layers):
    self.compute_residual = tl.Serial(
        # (x1_or_y1, x2) -> (x2, x1_or_y1, x2)
        tl.Parallel([], tl.Dup()),
        tl.Swap(),
        tl.Parallel(residual_layers, [], []),
    )

    layers = [
        self.compute_residual,
        tl.Parallel(tl.Add(), [])
    ]
    super(ReversibleHalfResidual, self).__init__(layers)

    self.subtract_top = tl.Parallel(tl.SubtractTop(), [])
    self.reverse_layers = [self.compute_residual, self.subtract_top]
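# Reversibility here is plain arithmetic: y1 = x1 + F(x2) with x2 passed
# through unchanged, so x1 is recovered as y1 - F(x2). A minimal NumPy sketch,
# with an arbitrary toy function standing in for residual_layers:
import numpy as np

def f(x2):
  # Toy stand-in for the residual sub-network; any deterministic map works.
  return 0.5 * np.tanh(x2)

def half_residual_forward(x1, x2):
  # Mirrors compute_residual followed by Add: only the first half is updated.
  return x1 + f(x2), x2

def half_residual_reverse(y1, y2):
  # Mirrors compute_residual followed by SubtractTop: recovers the input.
  return y1 - f(y2), y2

x1, x2 = np.random.rand(4, 8), np.random.rand(4, 8)
y1, y2 = half_residual_forward(x1, x2)
r1, r2 = half_residual_reverse(y1, y2)
assert np.allclose(r1, x1) and np.array_equal(r2, x2)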
def MultiRNNCell():
  """Multi-layer RNN cell."""
  # n_layers, rnn_cell and d_model come from the enclosing model builder.
  assert n_layers == 2
  return tl.Serial(
      tl.Parallel([], tl.Split(n_items=n_layers)),
      tl.SerialWithSideOutputs(
          [rnn_cell(n_units=d_model) for _ in range(n_layers)]),
      tl.Parallel([], tl.Concatenate(n_items=n_layers))
  )
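# The cell maps (inputs, concatenated per-layer states) to (outputs, new
# concatenated states): Split separates the states, each rnn_cell consumes one
# state and emits its replacement as a side output, and Concatenate rejoins
# them. A rough NumPy sketch of that plumbing, with a toy tanh cell standing
# in for rnn_cell:
import numpy as np

def toy_cell(x, state, w):
  # Toy recurrent cell: output and new state are both tanh(x + w * state).
  new_state = np.tanh(x + w * state)
  return new_state, new_state

def multi_rnn_cell(x, states, n_layers=2, weights=(0.5, 0.25)):
  per_layer = np.split(states, n_layers, axis=-1)    # Split(n_items=n_layers)
  new_states = []
  h = x
  for state, w in zip(per_layer, weights):
    h, new_state = toy_cell(h, state, w)             # state as side output
    new_states.append(new_state)
  return h, np.concatenate(new_states, axis=-1)      # Concatenate(n_items=..)

x = np.random.rand(3, 8)
states = np.zeros((3, 16))                # two layers x 8 units, concatenated
y, new_states = multi_rnn_cell(x, states)
assert y.shape == (3, 8) and new_states.shape == (3, 16)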
def _EncoderDecoderBlock(d_model, d_ff, n_heads, dropout, layer_idx, mode,
                         ff_activation):
  """Transformer encoder-decoder block.

  The input is a triple (decoder_activations, mask, encoder_activations),
  where the mask was created from the original source tokens to prevent
  attending to the padding part of the encoder input.

  Args:
    d_model: int: depth of embedding
    d_ff: int: depth of feed-forward layer
    n_heads: int: number of attention heads
    dropout: float: dropout rate (how much to drop out)
    layer_idx: which layer are we at (for bookkeeping)
    mode: str: 'train' or 'eval'
    ff_activation: the non-linearity in feed-forward layer

  Returns:
    A list of layers which maps triples (decoder_activations, mask,
    encoder_activations) to triples of the same sort.
  """
  def _Dropout():
    return tl.Dropout(rate=dropout, mode=mode)

  attention_qkv = tl.AttentionQKV(
      d_model, n_heads=n_heads, dropout=dropout, mode=mode)

  causal_attention = tl.CausalAttention(
      d_model, n_heads=n_heads, mode=mode)

  feed_forward = _FeedForwardBlock(
      d_model, d_ff, dropout, layer_idx, mode, ff_activation)

  return [                             # vec_d masks vec_e
      tl.Residual(
          tl.LayerNorm(),              # vec_d ..... .....
          causal_attention,            # vec_d ..... .....
          _Dropout(),                  # vec_d ..... .....
      ),
      tl.Residual(
          tl.LayerNorm(),              # vec_d ..... .....
          tl.Select([0, 2, 2, 1, 2]),  # vec_d vec_e vec_e masks vec_e
          attention_qkv,               # vec_d masks vec_e
          _Dropout(),                  # vec_d masks vec_e
      ),
      tl.Residual(
          feed_forward                 # vec_d masks vec_e
      ),
  ]
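# The middle residual attends from decoder activations to encoder activations,
# with the padding mask blocking encoder positions that came from padding. A
# minimal single-head NumPy sketch of that scaled dot-product attention (toy
# shapes assumed; tl.AttentionQKV additionally projects q/k/v and splits them
# into heads):
import numpy as np

def cross_attention(q, k, v, mask):
  # Queries from the decoder, keys/values from the encoder; masked positions
  # get a large negative score so they receive ~zero attention weight.
  scores = q @ k.swapaxes(-1, -2) / np.sqrt(q.shape[-1])    # (batch, L_d, L_e)
  scores = np.where(mask[:, None, :], scores, -1e9)
  weights = np.exp(scores - scores.max(axis=-1, keepdims=True))
  weights /= weights.sum(axis=-1, keepdims=True)
  return weights @ v                                        # (batch, L_d, d)

batch, L_d, L_e, d = 2, 5, 7, 16
decoder_acts = np.random.rand(batch, L_d, d)
encoder_acts = np.random.rand(batch, L_e, d)
mask = np.arange(L_e)[None, :] < np.array([[7], [4]])       # 2nd example padded
out = cross_attention(decoder_acts, encoder_acts, encoder_acts, mask)
assert out.shape == (batch, L_d, d)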
def ConvBlock(kernel_size, filters, strides, norm, non_linearity,
              mode='train'):
  """ResNet convolutional striding block."""
  # TODO(jonni): Use good defaults so Resnet50 code is cleaner / less redundant.
  ks = kernel_size
  filters1, filters2, filters3 = filters
  main = [
      tl.Conv(filters1, (1, 1), strides),
      norm(mode=mode),
      non_linearity(),
      tl.Conv(filters2, (ks, ks), padding='SAME'),
      norm(mode=mode),
      non_linearity(),
      tl.Conv(filters3, (1, 1)),
      norm(mode=mode),
  ]
  shortcut = [
      tl.Conv(filters3, (1, 1), strides),
      norm(mode=mode),
  ]
  return [
      tl.Residual(main, shortcut=shortcut),
      non_linearity()
  ]
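# The block is a three-stage main path plus a 1x1 projection shortcut, summed
# before the final non-linearity. A toy NumPy sketch of that wiring, with the
# convolutions replaced by per-pixel channel-mixing matrices and striding by
# plain subsampling (only the residual bookkeeping is real here, not the
# convolutions or the norm):
import numpy as np

def conv1x1(x, w, stride=1):
  # Toy 1x1 "convolution": subsample spatially, then mix channels.
  return x[::stride, ::stride, :] @ w

def conv_block(x, c_mid, c_out, stride=2, seed=0):
  c_in = x.shape[-1]
  rng = np.random.default_rng(seed)
  w1 = rng.normal(size=(c_in, c_mid))
  w2 = rng.normal(size=(c_mid, c_mid))        # stands in for the ks x ks conv
  w3 = rng.normal(size=(c_mid, c_out))
  w_short = rng.normal(size=(c_in, c_out))    # projection shortcut

  main = np.maximum(conv1x1(x, w1, stride), 0)       # conv -> norm -> relu
  main = np.maximum(main @ w2, 0)
  main = main @ w3
  shortcut = conv1x1(x, w_short, stride)             # matches main's shape
  return np.maximum(main + shortcut, 0)              # Residual, then relu

x = np.random.rand(8, 8, 16)                  # (height, width, channels)
y = conv_block(x, c_mid=32, c_out=64)
assert y.shape == (4, 4, 64)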
def _FeedForwardBlock(d_model, d_ff, dropout, layer_idx, mode, activation):
  """Feed-forward block with layer normalization at the start.

  Args:
    d_model: int: depth of embedding
    d_ff: int: depth of feed-forward layer
    dropout: float: dropout rate (how much to drop out)
    layer_idx: which layer are we at (for bookkeeping)
    mode: str: 'train' or 'eval'
    activation: the non-linearity in feed-forward layer

  Returns:
    A list of layers which maps vectors to vectors.
  """
  dropout_middle = tl.Dropout(
      rate=dropout, name='ff_middle_%d' % layer_idx, mode=mode)
  dropout_final = tl.Dropout(
      rate=dropout, name='ff_final_%d' % layer_idx, mode=mode)
  return [
      tl.LayerNorm(),
      tl.Dense(d_ff),
      activation(),
      dropout_middle,
      tl.Dense(d_model),
      dropout_final,
  ]
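# With dropout disabled (eval mode) the block is just layer norm followed by
# two dense layers with the activation in between. A minimal NumPy sketch of
# that computation, assuming ReLU as the activation:
import numpy as np

def layer_norm(x, eps=1e-6):
  mean = x.mean(axis=-1, keepdims=True)
  var = x.var(axis=-1, keepdims=True)
  return (x - mean) / np.sqrt(var + eps)

def feed_forward_block(x, w1, b1, w2, b2):
  # LayerNorm -> Dense(d_ff) -> ReLU -> Dense(d_model); dropout omitted.
  h = np.maximum(layer_norm(x) @ w1 + b1, 0)
  return h @ w2 + b2

d_model, d_ff = 16, 64
x = np.random.rand(2, 5, d_model)
w1, b1 = np.random.rand(d_model, d_ff), np.zeros(d_ff)
w2, b2 = np.random.rand(d_ff, d_model), np.zeros(d_model)
y = feed_forward_block(x, w1, b1, w2, b2)
assert y.shape == x.shape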
@tl.layer()
def FlattenControlsIntoTime(x, **unused_kwargs):  # pylint: disable=invalid-name
  """Splits logits for actions in different controls and flattens controls."""
  # n_actions comes from the enclosing policy-network builder's scope.
  return np.reshape(x, (x.shape[0], -1, n_actions))
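# A quick NumPy check of the reshape, assuming a toy batch of logits with 3
# controls and 4 actions per control: the control axis is folded into the time
# axis, so every (time step, control) pair becomes its own time step.
import numpy as np

n_actions, n_controls = 4, 3
batch, time = 2, 5

# Logits with all controls' action scores packed into the last axis.
logits = np.random.rand(batch, time, n_controls * n_actions)

flat = np.reshape(logits, (logits.shape[0], -1, n_actions))
assert flat.shape == (batch, time * n_controls, n_actions)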
def _EncoderBlock(d_model, d_ff, n_heads, dropout, layer_idx, mode,
                  ff_activation):
  """Transformer encoder block.

  The input is a pair (activations, mask), where the mask was
  created from the original source tokens to prevent attending to the padding
  part of the input.

  Args:
    d_model: int: depth of embedding
    d_ff: int: depth of feed-forward layer
    n_heads: int: number of attention heads
    dropout: float: dropout rate (how much to drop out)
    layer_idx: which layer are we at (for bookkeeping)
    mode: str: 'train' or 'eval'
    ff_activation: the non-linearity in feed-forward layer

  Returns:
    A list of layers that maps (activations, mask) to (activations, mask).
  """
  attention = tl.Attention(
      d_model, n_heads=n_heads, dropout=dropout, mode=mode)

  dropout_ = tl.Dropout(
      rate=dropout, name='dropout_enc_attn', mode=mode)

  feed_forward = _FeedForwardBlock(
      d_model, d_ff, dropout, layer_idx, mode, ff_activation)

  return [
      tl.Residual(
          tl.LayerNorm(),
          attention,
          dropout_,
      ),
      tl.Residual(
          feed_forward,
      ),
  ]
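# The mask the docstring refers to is typically just "token is not padding",
# broadcast against the attention scores. A small NumPy sketch of building
# such a mask from padded token ids (pad id 0 assumed) and applying it to toy
# self-attention scores:
import numpy as np

tokens = np.array([[11, 42, 7, 0, 0],       # padded to length 5
                   [3, 9, 15, 27, 1]])

mask = tokens != 0                           # (batch, length) padding mask
attn_mask = mask[:, None, :]                 # broadcast over query positions

scores = np.random.rand(2, 5, 5)             # toy self-attention scores
masked = np.where(attn_mask, scores, -1e9)   # padded keys get ~zero weight
assert masked.shape == (2, 5, 5)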