def SRU(n_units, activation=None):
  """SRU (Simple Recurrent Unit) layer, as in https://arxiv.org/abs/1709.02755.

  With r_t, f_t, y_t computed from the input x_t by the Dense layer below (and
  sigmoids applied to r_t and f_t), the recurrence is:
(4) c_t = f_t * c_{t-1} + (1 - f_t) * y_t
(5) h_t = r_t * activation(c_t) + (1 - r_t) * x_t
  We assume the input is of shape [batch, length, depth] and recurrence happens
  on the length dimension. This returns a single layer; the paper recommends
  stacking at least two of them, except when used inside a Transformer.
Args:
n_units: output depth of the SRU layer.
activation: Optional activation function.
Returns:
The SRU layer.
"""
  # pylint: disable=no-value-for-parameter
  return cb.Serial(                                           # x
      cb.Branch(core.Dense(3 * n_units), []),                 # r_f_y, x
      cb.Split(n_items=3),                                    # r, f, y, x
      cb.Parallel(core.Sigmoid(), core.Sigmoid()),            # r, f, y, x
      base.Fn(lambda r, f, y: (y * (1.0 - f), f, r)),         # y * (1 - f), f, r, x
      cb.Parallel([], [], cb.Branch(MakeZeroState(), [])),    # y * (1 - f), f, c0, r, x
      cb.Scan(InnerSRUCell(), axis=1),                        # c, c_final, r, x
      cb.Select([0], n_in=2),                                 # c, r, x
      activation or [],                                       # act(c), r, x
      base.Fn(lambda c, r, x: c * r + x * (1 - r))            # h = r * act(c) + (1 - r) * x
  )
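

# For reference, a plain-NumPy sketch of what Scan(InnerSRUCell(), axis=1) plus
# the final Fn compute for a single batch element, i.e. the recurrence in
# (4)-(5). This is an illustrative helper, not part of the library; it assumes
# r, f, y have already been produced by the Dense/Sigmoid layers above and that
# the initial state c_0 is zero (as MakeZeroState() provides).
import numpy as np


def sru_recurrence_reference(r, f, y, x, activation=np.tanh):
  """r, f, y, x: arrays of shape [length, depth]; returns h of the same shape."""
  c = np.zeros_like(x[0])
  h = np.empty_like(x)
  for t in range(x.shape[0]):
    c = f[t] * c + (1.0 - f[t]) * y[t]                   # eq. (4)
    h[t] = r[t] * activation(c) + (1.0 - r[t]) * x[t]    # eq. (5)
  return h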
def Attention(d_feature, n_heads=1, dropout=0.0, mode='train'):
"""Transformer-style multi-headed attention.
Accepts inputs of the form (x, mask) and constructs (q, k, v) from x.
Args:
d_feature: int: dimensionality of feature embedding
n_heads: int: number of attention heads
dropout: float: dropout rate
mode: str: 'train' or 'eval'
Returns:
Multi-headed self-attention result and the mask.
"""
  return cb.Serial(
      cb.Dup(), cb.Dup(),  # from (x, mask) make (x, x, x, mask): q = k = v = x
      AttentionQKV(d_feature, n_heads=n_heads, dropout=dropout, mode=mode),
  )
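

# Illustration (plain Python, not library code) of how the two Dup() calls above
# rearrange the data stack so that AttentionQKV receives q = k = v = x. The
# helper below is hypothetical and exists only to show the stack effect.
def _attention_stack_example():
  """Shows the stack effect of Dup(), Dup() on an (x, mask) input."""
  def dup_top(stack):                 # what a single cb.Dup() does to the stack
    return (stack[0],) + stack
  stack = ('x', 'mask')
  assert dup_top(dup_top(stack)) == ('x', 'x', 'x', 'mask')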
def MaskedScalar(metric_layer, mask_id=None, has_weights=False):
"""Metric as scalar compatible with Trax masking."""
# Stack of (inputs, targets) --> (metric, weight-mask).
metric_and_mask = [
cb.Parallel(
[],
cb.Dup() # Duplicate targets
),
cb.Parallel(
metric_layer, # Metric: (inputs, targets) --> metric
WeightMask(mask_id=mask_id) # pylint: disable=no-value-for-parameter
)
]
if not has_weights:
# Take (metric, weight-mask) and return the weighted mean.
return cb.Serial(metric_and_mask, WeightedMean()) # pylint: disable=no-value-for-parameter
return cb.Serial(
metric_and_mask,
cb.Parallel(
[],
cb.Multiply() # Multiply given weights by mask_id weights
),
WeightedMean() # pylint: disable=no-value-for-parameter
)
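

# Reference sketch of what the pipeline above computes, in plain NumPy. This is
# an illustrative helper, not library code; it assumes WeightMask assigns weight
# 1.0 to positions where the target differs from mask_id and 0.0 elsewhere.
import numpy as np


def masked_scalar_reference(metric_values, targets, mask_id=None, weights=None):
  """metric_values, targets (and optional weights) share the same shape."""
  if mask_id is None:
    mask = np.ones_like(metric_values)
  else:
    mask = (targets != mask_id).astype(metric_values.dtype)
  if weights is not None:   # the has_weights=True branch: combine both weightings
    mask = mask * weights
  return np.sum(metric_values * mask) / np.sum(mask)     # WeightedMean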
def _validate(self, layers):
if not layers or len(layers) < 2:
raise ValueError(
'layers ({}) must be a list with at least two elements'.format(
layers))
layers = list(layers) # Ensure we can modify layers.
for i, obj in enumerate(layers):
if obj is None or obj == []: # pylint: disable=g-explicit-bool-comparison
layers[i] = Serial(None)
elif isinstance(obj, (list, tuple)):
layers[i] = Serial(obj)
else:
if not isinstance(obj, base.Layer):
raise ValueError(
'Found nonlayer object ({}) in layers list: [{}].'.format(
obj, layers))
if layers[i].n_in == 0:
raise ValueError(
'Sublayer with n_in = 0 not allowed in Parallel:'
' {}'.format(layers[i]))
return layers
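
  # Illustration of the rules _validate enforces (hypothetical calls, shown as
  # comments; `one_input_layer`, `layer_a`, `layer_b` stand for any layers with
  # n_in == 1):
  #
  #   Parallel(one_input_layer, one_input_layer)     # ok: two sublayers
  #   Parallel(one_input_layer, None)                # ok: None becomes a one-input no-op
  #   Parallel(one_input_layer, [layer_a, layer_b])  # ok: the list is wrapped in Serial
  #   Parallel(one_input_layer)                      # ValueError: fewer than two elements
  #   Parallel(one_input_layer, 3)                   # ValueError: 3 is not a layer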
def Branch(*layers):
  """Combinator that applies each of `layers` in parallel to copies of the
  inputs. Every sublayer sees the same initial inputs, taking as many of them
  as its n_in requires. For example, if F takes 1 input, G takes 3, and H takes
  2 (returning h1 and h2), then for Branch(F, G, H):

  - inputs: a, b, c
  - outputs: F(a), G(a, b, c), h1, h2 where h1, h2 = H(a, b)
As an important special case, a None argument to Branch acts as if it takes
one argument, which it leaves unchanged. (It acts as a one-arg no-op.)
Args:
*layers: list of layers
Returns:
the branch layer
"""
parallel_layer = Parallel(*layers)
indices = [list(range(layer.n_in)) for layer in parallel_layer.sublayers]
return Serial(Select(_deep_flatten(indices)), parallel_layer)
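

# Worked example of the index construction above, for the F, G, H case from the
# docstring (n_in values 1, 3, 2):
#
#   indices                 -> [[0], [0, 1, 2], [0, 1]]
#   _deep_flatten(indices)  -> [0, 0, 1, 2, 0, 1]
#
# Select([0, 0, 1, 2, 0, 1]) turns the input stack (a, b, c) into
# (a, a, b, c, a, b); Parallel(F, G, H) then consumes 1 + 3 + 2 items, leaving
# F(a), G(a, b, c), h1, h2 on the stack.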
"""Metric as scalar compatible with Trax masking."""
# Stack of (inputs, targets) --> (metric, weight-mask).
metric_and_mask = [
cb.Parallel(
[],
cb.Dup() # Duplicate targets
),
cb.Parallel(
metric_layer, # Metric: (inputs, targets) --> metric
WeightMask(mask_id=mask_id) # pylint: disable=no-value-for-parameter
)
]
if not has_weights:
# Take (metric, weight-mask) and return the weighted mean.
return cb.Serial(metric_and_mask, WeightedMean()) # pylint: disable=no-value-for-parameter
return cb.Serial(
metric_and_mask,
cb.Parallel(
[],
cb.Multiply() # Multiply given weights by mask_id weights
),
WeightedMean() # pylint: disable=no-value-for-parameter
)
def AttentionQKV(d_feature, n_heads=1, dropout=0.0, mode='train'):
"""Transformer-style multi-headed attention.
Accepts inputs of the form q, k, v, mask.
Args:
d_feature: int: dimensionality of feature embedding
n_heads: int: number of attention heads
dropout: float: dropout rate
mode: str: 'train' or 'eval'
Returns:
Multi-headed attention result and the mask.
"""
  return cb.Serial(
      cb.Parallel(
          core.Dense(d_feature),    # query projection
          core.Dense(d_feature),    # key projection
          core.Dense(d_feature),    # value projection
      ),
      PureAttention(  # pylint: disable=no-value-for-parameter
          n_heads=n_heads, dropout=dropout, mode=mode),
      core.Dense(d_feature),        # output projection
  )
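

# Reference sketch of the per-head computation inside PureAttention: plain-NumPy
# scaled dot-product attention for a single head, without dropout. Illustrative
# only, not the library implementation.
import numpy as np


def dot_product_attention_reference(q, k, v, mask=None):
  """q, k, v: [length, depth]; mask: optional [length, length] boolean array,
  True where attending is allowed. Returns an array of shape [length, depth]."""
  depth = q.shape[-1]
  scores = q @ k.T / np.sqrt(depth)
  if mask is not None:
    scores = np.where(mask, scores, -1e9)                  # block masked positions
  scores = scores - scores.max(axis=-1, keepdims=True)     # numerical stability
  weights = np.exp(scores)
  weights = weights / weights.sum(axis=-1, keepdims=True)
  return weights @ v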