# Setup state.
rng, init_rng = jax_random.split(rng)
self._rngs = np.stack(jax_random.split(rng, self._n_devices))
first_shape = inputs.input_shape[0]
# If the inputs are a tuple/list, add [None] (batch) to each element.
if isinstance(first_shape, (list, tuple)):
  model_input_shape = tuple(
      tuple([None] + list(shape)) for shape in inputs.input_shape)
  model_target_shape = tuple(
      tuple([None] + list(shape)) for shape in inputs.target_shape)
else:  # Otherwise just add [None] to the input shape.
  model_input_shape = tuple([None] + list(inputs.input_shape))
  model_target_shape = tuple([None] + list(inputs.target_shape))
# Change all None to 1 in input and target shape.
model_input_shape = backend.nested_map(lambda x: x or 1, model_input_shape)
model_target_shape = backend.nested_map(lambda x: x or 1,
                                        model_target_shape)
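
# A minimal standalone sketch of the shape handling above. `nested_map` and
# the `lambda x: x or 1` trick come from the snippet; the helper below is
# illustrative only, not Trax's actual implementation.
def nested_map_sketch(fn, obj):
  """Apply fn to every leaf of a (possibly nested) tuple/list structure."""
  if isinstance(obj, (list, tuple)):
    return type(obj)(nested_map_sketch(fn, x) for x in obj)
  return fn(obj)

input_shape = (None, 32, 32, 3)  # batch dimension left unspecified
init_shape = nested_map_sketch(lambda x: x or 1, input_shape)
# init_shape == (1, 32, 32, 3): None becomes 1 so the model can be
# initialized on a dummy batch of size 1.
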
def new_opt_state_and_model_state(input_shape, input_dtype, target_shape,
                                  target_dtype, rng):
  """Returns optimizer and model states suitable for training a model."""
  # Combine inputs and targets on the stack.
  if not isinstance(input_dtype, (list, tuple)):
    input_dtype = [input_dtype]
    input_shape = [input_shape]
  if not isinstance(target_dtype, (list, tuple)):
    target_dtype = [target_dtype]
    target_shape = [target_shape]
  dtypes = list(input_dtype) + list(target_dtype)
  shapes = list(input_shape) + list(target_shape)
  if self._has_weights:
def _jit_update_fn(predict_fn, loss_fn, optimizer, n_devices, jit=True):
  """Returns a (JIT-compiled) function that computes updates for one step."""
  model_and_loss = tl.Serial(predict_fn, loss_fn)
  # Gradients are always taken w.r.t. the first argument, so weights go first.
  def model_and_loss_call(weights, batch, state, rng):
    res = model_and_loss(batch, weights=weights, state=state, rng=rng)
    return res, model_and_loss.state
  if n_devices == 1:  # TODO(lukaszkaiser): remove branch when not needed.
    def single_update(i, opt_state, batch, state, rng):
      weights, slots, opt_params = opt_state
      rng, subrng = jax_random.split(rng[0])
      grad_fn = backend.grad(model_and_loss_call, has_aux=True)
      grads, state = grad_fn(weights, batch, state, rng)
      return optimizer.tree_update(
          i, grads, weights, slots, opt_params), state, [subrng]
    return backend.jit(single_update) if jit else single_update

  # Else, for n_devices > 1:
  @functools.partial(backend.pmap, axis_name='batch')
  def mapped_update(i, opt_state, batch, state, rng):
    """Multi-device version of the update function above."""
    # We assume all tensors have the first dimension equal to n_devices.
    weights, slots, opt_params = opt_state
    rng, subrng = jax_random.split(rng)
    grad_fn = backend.grad(model_and_loss_call, has_aux=True)
    grads, state = grad_fn(weights, batch, state, rng)
    # We divide by psum(1.0) rather than by `n_devices` because `n_devices`
    # only counts the devices on this host, whereas psum sums over all devices
    # of all hosts (e.g. a TPU pod), and we need to average over all of them.
    grads = jax.tree_util.tree_map(
        lambda g: backend.psum(g, 'batch') / backend.psum(1.0, 'batch'), grads)
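
# Hedged sketch of the same psum-averaging idea in plain JAX, without Trax:
# `pmap` with axis_name='batch' gives each device its own gradient, and
# dividing by psum(1.0) (which counts every participating device, across
# hosts) turns the summed gradients into a true global mean. The loss and
# names below are illustrative only.
import functools
import jax
import jax.numpy as jnp

def _toy_loss(w, x):
  return jnp.mean((x @ w) ** 2)

@functools.partial(jax.pmap, axis_name='batch')
def _averaged_grads(w, x):  # w and x carry a leading per-device dimension
  g = jax.grad(_toy_loss)(w, x)
  return jax.tree_util.tree_map(
      lambda t: jax.lax.psum(t, 'batch') / jax.lax.psum(1.0, 'batch'), g)
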
def forward(self, inputs, weights):
  del weights
  return tuple(backend.numpy.split(inputs, self._n_items, self._axis))
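
# For reference, `backend.numpy.split` follows numpy/jax.numpy semantics; an
# equivalent standalone call (shapes are illustrative):
import jax.numpy as jnp
parts = tuple(jnp.split(jnp.arange(12).reshape(2, 6), 3, axis=-1))
# parts is a tuple of three (2, 2) arrays.
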
dots = np.matmul(query, np.swapaxes(key, -1, -2)) / np.sqrt(depth)
if mask is not None:
  # TODO(kitaev): workaround for https://github.com/google/jax/issues/850
  # We must ensure that both mask and the -1e9 constant have a data dependency
  # on the input. Broadcasted copies of these use a lot of memory, so they
  # should be computed at runtime (rather than being global constants).
  if backend.get_name() == 'jax':
    mask = jax.lax.tie_in(dots, mask)
  # JAX's `full_like` already ties in -1e9 to dots.
  dots = np.where(mask, dots, np.full_like(dots, -1e9))
# Softmax.
dots = np.exp(dots - backend.logsumexp(dots, axis=-1, keepdims=True))
if dropout is not None and dropout >= 1.0:
  raise ValueError('Dropout rates must be lower than 1.')
if dropout is not None and dropout > 0.0 and mode == 'train':
  keep = backend.random.bernoulli(rng, 1.0 - dropout, dots.shape)
  dots = np.where(keep, dots / (1.0 - dropout), np.zeros_like(dots))
out = np.matmul(dots, value)
return out
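
# Self-contained sketch of the same masked, scaled dot-product attention in
# plain jax.numpy (dropout omitted; names and shapes are illustrative):
import jax.numpy as jnp
from jax.scipy.special import logsumexp

def dot_product_attention_sketch(query, key, value, mask=None):
  depth = query.shape[-1]
  dots = jnp.matmul(query, jnp.swapaxes(key, -1, -2)) / jnp.sqrt(depth)
  if mask is not None:
    dots = jnp.where(mask, dots, jnp.full_like(dots, -1e9))
  dots = jnp.exp(dots - logsumexp(dots, axis=-1, keepdims=True))  # softmax
  return jnp.matmul(dots, value)

# Example: batch=2, heads=4, length=8, per-head depth=16, causal mask.
q = k = v = jnp.ones((2, 4, 8, 16))
causal = jnp.tril(jnp.ones((8, 8), dtype=bool))
out = dot_product_attention_sketch(q, k, v, mask=causal)  # (2, 4, 8, 16)
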
def _combine_devices(x_tuple):
  """Combine multi-device tensors into a single batch."""
  def f(x):
    if len(x.shape) < 2:
      return x  # No extra batch dimension: use devices as batch, so return.
    batch_size = x.shape[0] * x.shape[1]
    return backend.numpy.reshape(x, [batch_size] + list(x.shape[2:]))
  return backend.nested_map(f, x_tuple)
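
# Illustrative example of the device-folding reshape above: a tensor shaped
# [n_devices, batch_per_device, ...] is flattened into one batch dimension.
import jax.numpy as jnp
x = jnp.zeros((8, 16, 128))                          # 8 devices x 16 examples
combined = jnp.reshape(x, (8 * 16,) + x.shape[2:])   # -> (128, 128)
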
def forward(self, xs, weights):
  del weights
  return backend.numpy.concatenate(xs, self._axis)
def forward(self, x, weights):
  w, b = weights
  x_shape = list(x.shape)
  if len(x_shape) > 4:
    self._check_nhwc()
    new_batch_dim = six.moves.reduce(operator.mul, x_shape[:-3])
    x = np.reshape(x, [new_batch_dim] + x_shape[-3:])
  res = backend.conv(
      x, w, self._strides, self._padding, self._dimension_numbers,
      self._one) + b
  if len(x_shape) > 4:
    res = np.reshape(res, x_shape[:-3] + list(res.shape[-3:]))
  return res
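
# Hedged sketch of the "fold extra leading dims into the batch" trick using
# jax.lax.conv_general_dilated directly (NHWC input, HWIO kernel; shapes,
# strides and padding below are illustrative, not the layer's configuration):
import numpy as onp
import jax.numpy as jnp
from jax import lax

x = jnp.zeros((2, 3, 28, 28, 8))      # e.g. [batch, time, H, W, C]
w = jnp.zeros((3, 3, 8, 16))          # [H, W, C_in, C_out]
lead = int(onp.prod(x.shape[:-3]))    # 2 * 3 = 6 folded batch entries
x2 = jnp.reshape(x, (lead,) + x.shape[-3:])
res = lax.conv_general_dilated(
    x2, w, window_strides=(1, 1), padding='SAME',
    dimension_numbers=('NHWC', 'HWIO', 'NHWC'))
res = jnp.reshape(res, x.shape[:-3] + res.shape[-3:])  # restore leading dims
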
def forward_and_backward(self, inputs, ct, state, new_state, **kwargs):
  assert backend.get_name() == 'jax', (
      'JAX backend is required to use forward_and_backward.')
  # Simultaneous forward pass and backprop through the attention mechanism.
  def _do_forward(x):  # pylint: disable=invalid-name
    res, _ = self.forward_with_state(x, state=state, **kwargs)
    return res
  output, vjpfun = jax.vjp(_do_forward, inputs)
  return output, vjpfun(ct)[0]
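
# Minimal jax.vjp illustration of the pattern above: a single call returns
# both the forward output and a pullback mapping an output cotangent `ct` to
# an input cotangent (the function below is illustrative):
import jax
import jax.numpy as jnp

def _f(x):
  return jnp.sin(x) * x

x = jnp.ones((3,))
ct = jnp.ones((3,))            # cotangent w.r.t. the output
output, vjpfun = jax.vjp(_f, x)
grad_wrt_x = vjpfun(ct)[0]     # mirrors `return output, vjpfun(ct)[0]` above
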
def MaxPool(x, weights, pool_size=(2, 2), strides=None, padding='VALID', **kw):
  del weights, kw
  return backend.max_pool(x, pool_size=pool_size, strides=strides,
                          padding=padding)
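
# For reference, a plain-JAX max pool with the same interface can be written
# with lax.reduce_window (NHWC layout assumed; this is a sketch, not the
# backend's actual implementation):
import jax.numpy as jnp
from jax import lax

def max_pool_sketch(x, pool_size=(2, 2), strides=None, padding='VALID'):
  strides = strides or pool_size
  dims = (1,) + pool_size + (1,)     # never pool over batch or channels
  strs = (1,) + strides + (1,)
  return lax.reduce_window(x, -jnp.inf, lax.max, dims, strs, padding)

y = max_pool_sketch(jnp.zeros((4, 28, 28, 3)))   # -> (4, 14, 14, 3)
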
def forward_with_state(self, inputs, weights=base.EMPTY_WEIGHTS,
                       state=base.EMPTY_STATE, rng=None, **kwargs):
  if self._mode in ('train', 'eval'):
    x = inputs
    symbol_size = np.shape(x)[1]
    px = weights[:, :symbol_size, :]
    if self._dropout == 0:
      return (x + px, state)
    else:
      noise_shape = list(px.shape)
      for dim in self._dropout_broadcast_dims:
        noise_shape[dim] = 1
      keep_prob = 1.0 - self._dropout
      if backend.get_name() == 'jax':
        keep_prob = jax.lax.tie_in(x, np.full((), keep_prob, dtype=x.dtype))
      keep = backend.random.bernoulli(rng, keep_prob, tuple(noise_shape))
      multiplier = keep.astype(x.dtype) / keep_prob
      return (x + px * multiplier, state)
  else:
    assert self._mode == 'predict'
    assert self._dropout == 0
    # State in this class is only used for fast inference. In that case the
    # model is called with consecutive elements position-by-position, so this
    # layer stores the index of the current position in its state and
    # increments it on each call -- that is how state is used and updated
    # below.
    return (inputs + np.expand_dims(weights[:, state, :], 1), state + 1)
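
# Tiny illustration of the predict-mode bookkeeping above: the stored integer
# position selects one row of the encoding table per call and is then
# incremented, so consecutive single-token calls see consecutive positions
# (shapes below are illustrative):
import jax.numpy as jnp

pe = jnp.zeros((1, 100, 512))     # [1, max_len, d_feature] positional table
state = 0                         # current position index
token = jnp.zeros((1, 1, 512))    # one decoding step: [batch, 1, d_feature]
out = token + jnp.expand_dims(pe[:, state, :], 1)  # add encoding for `state`
state = state + 1                 # the next call uses the next position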