top_k_mask = jax.lax.convert_element_type(
    dots < bdots_thresh[..., None], np.float32)
dots = dots - 1e7 * jax.lax.stop_gradient(top_k_mask)

# Softmax.
dots_logsumexp = backend.logsumexp(dots, axis=-1, keepdims=True)
dots = np.exp(dots - dots_logsumexp)

if self._dropout > 0.0:
  # Dropout is broadcast across the bin dimension
  dropout_shape = (1, dots.shape[-2], dots.shape[-1])
  keep_prob = jax.lax.tie_in(dots, 1.0 - self._dropout)
  keep = backend.random.bernoulli(rng, keep_prob, dropout_shape)
  multiplier = keep.astype(dots.dtype) / jax.lax.tie_in(keep, keep_prob)
  dots = dots * multiplier

bo = np.matmul(dots, bv)
so = np.reshape(bo, (-1, bo.shape[-1]))
slogits = np.reshape(dots_logsumexp, (-1,))
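# --- Illustrative sketch (not part of the source above) ---
# A minimal, self-contained example of the inverted-dropout pattern used in the
# snippet above: sample one keep mask over the last two axes, broadcast it
# across the leading (bin) axis, and divide by keep_prob so the expected value
# of `dots` is unchanged. Shapes are toy values; `jax.lax.tie_in` is omitted
# because it is a no-op in (and later removed from) current JAX releases.
import jax
import jax.numpy as jnp

rng = jax.random.PRNGKey(0)
dots = jnp.ones((4, 8, 8))                           # (bins, queries, keys), toy data
keep_prob = 0.9
dropout_shape = (1, dots.shape[-2], dots.shape[-1])  # shared across the bin axis
keep = jax.random.bernoulli(rng, keep_prob, dropout_shape)
dots = dots * (keep.astype(dots.dtype) / keep_prob)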
def unsort_for_output_impl(so, slogits):
  o = np.take(so, undo_sort, axis=0)
  # Sorting is considerably faster than gather, but first we need to get the
  # XLA compiler to abandon the idea of fusing this sort with the input sort
  # (which introduces a computation cycle and leads to a crash).
  # TODO(kitaev): remove "sticker_" variable if XLA is fixed.
  sticker_ = sticker + jax.lax.convert_element_type(
      slogits[0] > 0, sticker.dtype)
  _, logits = jax.lax.sort_key_val(sticker_, slogits, dimension=-1)
  return o, logits

def unsort_for_output_vjp(so, slogits):
  """Custom gradient for unsort_for_output."""
"""Core dot product self-attention.
Args:
query: array of representations
key: array of representations
value: array of representations
mask: attention-mask, gates attention
dropout: float: dropout rate
mode: 'eval' or 'train': whether to use dropout
rng: JAX PRNGKey: subkey for disposable use
Returns:
Self attention for q, k, v arrays.
"""
  depth = np.shape(query)[-1]
  dots = np.matmul(query, np.swapaxes(key, -1, -2)) / np.sqrt(depth)
  if mask is not None:
    # TODO(kitaev): workaround for https://github.com/google/jax/issues/850
    # We must ensure that both mask and the -1e9 constant have a data dependency
    # on the input. Broadcasted copies of these use a lot of memory, so they
    # should be computed at runtime (rather than being global constants).
    if backend.get_name() == 'jax':
      mask = jax.lax.tie_in(dots, mask)
    # JAX's `full_like` already ties in -1e9 to dots.
    dots = np.where(mask, dots, np.full_like(dots, -1e9))
  # Softmax.
  dots = np.exp(dots - backend.logsumexp(dots, axis=-1, keepdims=True))
  if dropout >= 1.0:
    raise ValueError('Dropout rates must be lower than 1.')
  if dropout is not None and dropout > 0.0 and mode == 'train':
    keep = backend.random.bernoulli(rng, 1.0 - dropout, dots.shape)
    dots = np.where(keep, dots / (1.0 - dropout), np.zeros_like(dots))
  out = np.matmul(dots, value)
  return out
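# --- Illustrative sketch (not part of the source above) ---
# A self-contained toy version of the same recipe (scale, mask, softmax,
# weighted sum) written with plain jax.numpy instead of the trax `backend`
# module so it runs on its own. Function and array names are hypothetical.
import jax
import jax.numpy as jnp

def toy_dot_product_attention(q, k, v, mask):
  depth = q.shape[-1]
  dots = jnp.matmul(q, jnp.swapaxes(k, -1, -2)) / jnp.sqrt(depth)
  dots = jnp.where(mask, dots, jnp.full_like(dots, -1e9))
  dots = jnp.exp(dots - jax.scipy.special.logsumexp(dots, axis=-1, keepdims=True))
  return jnp.matmul(dots, v)

q = jax.random.normal(jax.random.PRNGKey(0), (1, 4, 8))  # (batch, length, depth)
causal_mask = jnp.tril(jnp.ones((4, 4), dtype=bool))     # attend to the past only
out = toy_dot_product_attention(q, q, q, causal_mask)    # shared q/k/v self-attention
print(out.shape)                                         # (1, 4, 8)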
def forward_slice(query_slice, q_loop_idx, key, value):  # pylint: disable=invalid-name
  """Forward pass for a subset of the query vectors."""
  if self._share_qk:
    key = self.make_unit_length(key)

  dots = np.matmul(
      query_slice, np.swapaxes(key, -1, -2)) / np.sqrt(depth)

  # Causal masking
  mask = make_mask(dots.shape[-2], dots.shape[-1], q_loop_idx)
  dots = dots - 1e9 * mask

  # Mask out attention to self except when no other targets are available.
  if self._share_qk:
    self_mask = make_self_mask(dots.shape[-2], dots.shape[-1], q_loop_idx)
    dots = dots - 1e5 * self_mask

  # Softmax.
  dots = np.exp(dots - backend.logsumexp(dots, axis=-1, keepdims=True))

  if self.dropout is not None and self.dropout > 0.0:
    # Dropout is broadcast across the batch+head dimension
    dropout_shape = (1, dots.shape[-2], dots.shape[-1])
    slice_rng = jax.random.fold_in(rng, q_loop_idx)
    keep_prob = jax.lax.tie_in(dots, 1.0 - self.dropout)
    keep = backend.random.bernoulli(slice_rng, keep_prob, dropout_shape)
    multiplier = keep.astype(dots.dtype) / jax.lax.tie_in(keep, keep_prob)
    dots = dots * multiplier

  if self._hard_k > 0:
    top_k = np.sort(dots)[..., -self._hard_k]  # Get the top-kth weight.
    top_k = jax.lax.stop_gradient(top_k)
    dots -= top_k[..., np.newaxis]  # Subtract (be 0 for lower ones).
    dots = np.maximum(dots, 0)
    dots_sum = np.sum(dots, axis=-1, keepdims=True)  # Re-normalize.
    dots /= dots_sum  # Re-normalize.

  out_slice = np.matmul(dots, value)
  return out_slice


# Separate snippet: mask out attention to self, apply softmax and broadcast
# dropout, then reshape the binned output and trim padding to the original length.
self_mask = jax.lax.broadcasted_eye(dots.dtype, dots.shape, (2, 3))
self_mask = jax.lax.tie_in(dots, self_mask)
dots = dots - 1e5 * self_mask

# Softmax.
dots = np.exp(dots - backend.logsumexp(dots, axis=-1, keepdims=True))

if self.dropout > 0.0:
  # Dropout is broadcast across the batch+head dimension
  dropout_shape = (1, dots.shape[-3], dots.shape[-2], dots.shape[-1])
  keep_prob = jax.lax.tie_in(dots, 1.0 - self.dropout)
  keep = backend.random.bernoulli(rng, keep_prob, dropout_shape)
  multiplier = keep.astype(dots.dtype) / jax.lax.tie_in(keep, keep_prob)
  dots = dots * multiplier

bo = np.matmul(dots, bv)
output = np.reshape(bo, (bo.shape[0], -1, bo.shape[-1]))
assert output.shape == v.shape
return output[..., :original_len, :]
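# --- Illustrative sketch (not part of the source above) ---
# What the `_hard_k` branch in forward_slice does, on toy numbers: subtract the
# k-th largest weight, clamp at zero, and renormalize, so only the strongest
# targets keep any attention mass.
import jax.numpy as jnp

probs = jnp.array([0.50, 0.30, 0.15, 0.05])
hard_k = 3
kth = jnp.sort(probs)[..., -hard_k]          # 0.15, the k-th largest weight
trimmed = jnp.maximum(probs - kth, 0.0)      # [0.35, 0.15, 0.0, 0.0]
renorm = trimmed / jnp.sum(trimmed, axis=-1, keepdims=True)
print(renorm)                                # approximately [0.7, 0.3, 0.0, 0.0]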