Secure your code as it's written. Use Snyk Code to scan source code in minutes - no build needed - and fix issues immediately.
def AttentionPosition(vec, pos,
                      positions=None, d_model=None, n_heads=8,
                      dropout=0.0, mode='train'):
    """Multi-headed causal attention applied to vectors paired with positions.

    Args:
      vec: the vector input; feeds the query, key and value projections.
      pos: the position input accompanying `vec`.
      positions: forwarded to `LearnedPosOperations`.
      d_model: width of the dense projections; `d_model // n_heads` is the
        per-head size used for the value heads.
      n_heads: number of attention heads.
      dropout: dropout rate handed to `tl.DotProductCausalAttention`.
      mode: mode string handed to `tl.DotProductCausalAttention`.

    Returns:
      The pair `(x, pos)` produced by attention, head recombination and the
      final dense projection.
    """
    # Derive `n_heads` position combinations from (vec, pos).
    pos_ops = LearnedPosOperations(positions=positions, n_combinations=n_heads)
    combined_positions = list(pos_ops @ (vec, pos))

    # Queries see the derived positions (tile=False); keys see the raw
    # positions (tile=True); values are computed from `vec` alone.
    query_layer = tl.Serial(tl.Dense(d_model),
                            CopyPosToHeads(n_heads, tile=False))
    queries = query_layer @ ([vec] + combined_positions)

    key_layer = tl.Serial(tl.Dense(d_model),
                          CopyPosToHeads(n_heads, tile=True))
    keys = key_layer @ (vec, pos)

    value_layer = tl.ComputeAttentionHeads(
        n_heads=n_heads, d_head=d_model // n_heads)
    values = value_layer @ vec

    attend_and_merge = tl.Serial(
        tl.DotProductCausalAttention(dropout=dropout, mode=mode),
        CombineHeadsPos(n_heads=n_heads),
        tl.Dense(d_model))
    out_vec, out_pos = attend_and_merge @ (queries, keys, values)
    return out_vec, out_pos
def AttentionPosition(vec, pos,
                      positions=None, d_model=None, n_heads=8,
                      dropout=0.0, mode='train'):
    """Multi-headed causal attention applied to vectors paired with positions.

    Args:
      vec: the vector input; feeds the query, key and value projections.
      pos: the position input accompanying `vec`.
      positions: forwarded to `LearnedPosOperations`.
      d_model: width of the dense projections; `d_model // n_heads` is the
        per-head size used for the value heads.
      n_heads: number of attention heads.
      dropout: dropout rate handed to `tl.DotProductCausalAttention`.
      mode: mode string handed to `tl.DotProductCausalAttention`.

    Returns:
      The pair `(x, pos)` produced by attention, head recombination and the
      final dense projection.
    """
    # Derive `n_heads` position combinations from (vec, pos).
    pos_ops = LearnedPosOperations(positions=positions, n_combinations=n_heads)
    combined_positions = list(pos_ops @ (vec, pos))

    # Queries see the derived positions (tile=False); keys see the raw
    # positions (tile=True); values are computed from `vec` alone.
    query_layer = tl.Serial(tl.Dense(d_model),
                            CopyPosToHeads(n_heads, tile=False))
    queries = query_layer @ ([vec] + combined_positions)

    key_layer = tl.Serial(tl.Dense(d_model),
                          CopyPosToHeads(n_heads, tile=True))
    keys = key_layer @ (vec, pos)

    value_layer = tl.ComputeAttentionHeads(
        n_heads=n_heads, d_head=d_model // n_heads)
    values = value_layer @ vec

    attend_and_merge = tl.Serial(
        tl.DotProductCausalAttention(dropout=dropout, mode=mode),
        CombineHeadsPos(n_heads=n_heads),
        tl.Dense(d_model))
    out_vec, out_pos = attend_and_merge @ (queries, keys, values)
    return out_vec, out_pos
def MultiRNNCell():
    """Build a multi-layer RNN cell as one serial layer.

    `n_layers`, `rnn_cell` and `d_model` are not defined here; they are
    resolved from an outer scope. Only `n_layers == 2` is accepted.

    Returns:
      A `tl.Serial` that splits the second input into `n_layers` items, runs
      one `rnn_cell(n_units=d_model)` per layer with side outputs, and
      concatenates the `n_layers` items back together.
    """
    assert n_layers == 2
    split_second = tl.Parallel([], tl.Split(n_items=n_layers))
    stacked_cells = tl.SerialWithSideOutputs(
        [rnn_cell(n_units=d_model) for _ in range(n_layers)])
    concat_second = tl.Parallel([], tl.Concatenate(n_items=n_layers))
    return tl.Serial(split_second, stacked_cells, concat_second)
# NOTE(review): fragment -- the enclosing method header and the condition
# guarding the next two assignments are outside this view.
# Wrap a single dtype/shape in a one-element list so it can be concatenated.
input_dtype = [input_dtype]
input_shape = [input_shape]
# Same normalization for the targets when they are not already a list/tuple.
if not isinstance(target_dtype, (list, tuple)):
target_dtype = [target_dtype]
target_shape = [target_shape]
dtypes = list(input_dtype) + list(target_dtype)
shapes = list(input_shape) + list(target_shape)
# With weights, append one extra float32 entry per target, reusing the
# target shapes.
if self._has_weights:
shapes += list(target_shape)
dtypes += [np.float32 for _ in target_dtype]
# Pair every shape with its dtype; this tuple is what `m.init` receives.
input_signature = tuple(ShapeDtype(s, d)
for (s, d) in zip(shapes, dtypes))
# We need to create a new model instance and not reuse `model_train` here,
# because `m.initialize` puts cached parameter values in `m` and hence the
# next call of `m.initialize` will give wrong results.
m = tl.Serial(model(mode='train'), loss_fn)
m._set_rng_recursive(rng) # pylint: disable=protected-access
weights, state = m.init(input_signature)
# Optimizer slots and params are derived from the freshly initialized weights.
(slots, opt_params) = opt.tree_init(weights)
# Returns the optimizer state bundle together with the model state.
return (OptState(weights, slots, opt_params), state)
else PositionalEncoder(output_vocab_size))
# NOTE(review): fragment -- the function header, the definitions of
# `in_encoder`/`out_encoder`, and the end of the returned `tl.Serial(...)`
# are outside this view.
# Fall back to the input vocabulary size when no output size was given.
if output_vocab_size is None:
output_vocab_size = input_vocab_size
# One encoder block per encoder layer; the layer index `i` is passed through.
encoder_blocks = [
_EncoderBlock(
d_model, d_ff, n_heads, dropout, i, mode, ff_activation)
for i in range(n_encoder_layers)]
# One encoder-decoder block per decoder layer, built with the same arguments.
encoder_decoder_blocks = [
_EncoderDecoderBlock(
d_model, d_ff, n_heads, dropout, i, mode, ff_activation)
for i in range(n_decoder_layers)]
# Assemble and return the model.
# The trailing comments track the data stack after each layer, top first.
return tl.Serial(
# Input: encoder_side_tokens, decoder_side_tokens
# Copy decoder tokens for use in loss.
tl.Select([0, 1, 1]), # tok_e tok_d tok_d
# Encode.
tl.Branch(
in_encoder, tl.PaddingMask()), # vec_e masks ..... .....
encoder_blocks, # vec_e masks ..... .....
tl.LayerNorm(), # vec_e ..... ..... .....
# Decode.
tl.Select([2, 1, 0]), # tok_d masks vec_e .....
tl.ShiftRight(), # tok_d ..... ..... .....
out_encoder, # vec_d ..... ..... .....
tl.Branch(
[], tl.EncoderDecoderMask()), # vec_d masks ..... .....
def __init__(self, layer, n_sections=1, check_shapes=True):
    """Set up a combinator that maps one layer over every input element.

    Args:
      layer: the layer to apply to each element. `None` or a list/tuple is
        wrapped in `tl.Serial` first.
      n_sections: how many sections to map to (default: 1); also used as both
        `n_in` and `n_out` of this layer.
      check_shapes: whether to check that shapes are identical
        (default: true).
    """
    super(Map, self).__init__(n_in=n_sections, n_out=n_sections)
    must_wrap = layer is None or isinstance(layer, (list, tuple))
    self._layer = tl.Serial(layer) if must_wrap else layer
    self._n_sections = n_sections
    # self._layer is initialized only once, so its parameters are tied to one
    # input shape; normally every mapped element should share that shape.
    # When self._layer has no parameters, differing shapes are fine -- callers
    # pass check_shapes=False for those cases.
    self._check_shapes = check_shapes
tl.Dense(n_preds_per_input),
tl.Flatten()],
)
]
# NOTE(review): fragment -- the function header and the matching `if` branch
# are outside this view.
else:
# Shared bottom layers, duplicated so two parallel heads can consume them.
layers = [
bottom_layers_fn(**kwargs),
tl.Dup(),
tl.Parallel(
# First head: dense over n_preds_per_input * n_actions, reshaped by
# FlattenControlsIntoTime, then log-softmax.
[tl.Dense(n_preds_per_input * n_actions),
FlattenControlsIntoTime(), # pylint: disable=no-value-for-parameter
tl.LogSoftmax()],
# Second head: dense over n_preds_per_input, then flattened.
[tl.Dense(n_preds_per_input), tl.Flatten()],
)
]
# Both branches end by wrapping their layer list in a single Serial.
return tl.Serial(layers)
def AtariCnn(n_frames=4, hidden_sizes=(32, 32), output_size=128, mode='train'):
    """Build a small convolutional network for stacked Atari frames.

    Args:
      n_frames: number of frames handed to `_FrameStack`.
      hidden_sizes: filter counts; element 0 feeds the first convolution and
        element 1 the second.
      output_size: width of the final dense layer.
      mode: accepted for interface compatibility and discarded.

    Returns:
      A `tl.Serial` model. Input shape: (B, T, H, W, C); output shape:
      (B, T, output_size).
    """
    del mode
    # TODO(jonni): Include link to paper?
    kernel_size = (5, 5)
    strides = (2, 2)
    padding = 'SAME'
    stack = [
        tl.Fn(lambda x: x / 255.0),  # Convert unsigned bytes to float.
        _FrameStack(n_frames=n_frames),  # (B, T, H, W, 4C)
        tl.Conv(hidden_sizes[0], kernel_size, strides, padding),
        tl.Relu(),
        tl.Conv(hidden_sizes[1], kernel_size, strides, padding),
        tl.Relu(),
        tl.Flatten(n_axes_to_keep=2),  # B, T and rest.
        tl.Dense(output_size),
        tl.Relu(),
    ]
    return tl.Serial(*stack)
def __init__(self, pre_attention, attention, post_attention):
    """Assemble the pre-attention, attention, post-attention and add stages.

    Args:
      pre_attention: layer(s) applied to the top input before attention.
      attention: attention layer; must expose `forward_and_backward`.
      post_attention: layer(s) applied to the top value after attention.
    """
    # (x1_or_y1, x2) -> (x2, x1_or_y1, x2)
    dup_second = tl.Parallel([], tl.Dup())
    swap_top_two = tl.Swap()
    run_pre = tl.Parallel(pre_attention, [], [])
    self.pre_attention = tl.Serial(dup_second, swap_top_two, run_pre)

    # The wrapper is only built for attention layers that provide
    # `forward_and_backward`.
    assert hasattr(attention, 'forward_and_backward')
    self.attention = ApplyAttentionWrapper(attention)

    self.post_attention = tl.Parallel(post_attention, [], [])

    # Final stage adds the top two values and passes the next one through.
    add_top_two = tl.Parallel(tl.Add(), [])
    super(ReversibleAttentionHalfResidual, self).__init__([
        self.pre_attention,
        self.attention,
        self.post_attention,
        add_top_two,
    ])