# Secure your code as it's written. Use Snyk Code to scan source code in minutes - no build needed - and fix issues immediately.
def FeedForward(d_model, d_ff, dropout, activation, mode):
    """Feed-forward block with layer normalization at start.

    Args:
        d_model: Depth of the block's output (final Dense layer size).
        d_ff: Depth of the hidden Dense layer.
        dropout: Dropout rate applied after each Dense layer.
        activation: Constructor for the nonlinearity layer.
        mode: 'train'/'eval'/'predict'; forwarded to the dropout layers.

    Returns:
        A list of layers: LayerNorm -> Dense(d_ff) -> Dropout -> activation
        -> Dense(d_model) -> Dropout.
    """
    layers = [tl.LayerNorm(), tl.Dense(d_ff)]
    layers.append(
        BroadcastedDropout(rate=dropout, mode=mode))  # pylint: disable=no-value-for-parameter
    layers.append(activation())
    layers.append(tl.Dense(d_model))
    layers.append(
        BroadcastedDropout(rate=dropout, mode=mode))  # pylint: disable=no-value-for-parameter
    return layers
def AppendLearnedPosOperation(vec, q1, q2, q3, q4, q5):
    """Get (vec, q1, ...) and return new_pos.

    Builds five scalar weights (length-1 vectors) from the first input
    component, then combines the five candidate positions q1..q5 with a
    softmax over those weights.
    """
    # Create 5 scalar weights (length 1 vectors) from first component of input.
    weights = []
    for _ in range(5):
        weights.append(tl.Dense(1) @ vec)
    new_pos = Softmax5Branches() @ (weights + [q1, q2, q3, q4, q5])
    return new_pos
Args:
d_feature: Number of memory channels (dimensionality of feature embedding).
steps: Number of times depthwise recurrence steps.
vocab_size: Vocabulary size.
mode: Whether we are training or evaluating or doing inference.
Returns:
A NeuralGPU Stax model.
"""
del mode
core = ConvDiagonalGRU(units=d_feature)
return tl.Serial(
tl.Embedding(d_feature=d_feature, vocab_size=vocab_size),
[core] * steps,
tl.Dense(vocab_size),
tl.LogSoftmax(),
)
def FeedForward(d_model, d_ff, dropout, activation, mode):
    """Feed-forward block with layer normalization at start.

    Returns a list of layers: LayerNorm -> Dense(d_ff) -> Dropout ->
    activation -> Dense(d_model) -> Dropout. The `mode` flag is forwarded
    to the dropout layers.
    """
    # Fresh dropout layer per call site; keeps both instances independent.
    def drop():
        return BroadcastedDropout(rate=dropout, mode=mode)  # pylint: disable=no-value-for-parameter

    return [
        tl.LayerNorm(),
        tl.Dense(d_ff),
        drop(),
        activation(),
        tl.Dense(d_model),
        drop(),
    ]
[rnn_cell(n_units=d_model) for _ in range(n_layers)]),
tl.Parallel([], tl.Concatenate(n_items=n_layers))
)
zero_state = tl.MakeZeroState( # pylint: disable=no-value-for-parameter
depth_multiplier=n_layers * rnn_cell_d_state_multiplier
)
return tl.Serial(
tl.ShiftRight(mode=mode),
tl.Embedding(d_model, vocab_size),
tl.Dropout(rate=dropout, name='embedding', mode=mode),
tl.Branch([], zero_state),
tl.Scan(MultiRNNCell(), axis=1),
tl.Select([0], n_in=2), # Drop RNN state.
tl.Dense(vocab_size),
tl.LogSoftmax()
)