def init(self, params):
# One 1-D accumulator per parameter dimension, plus a buffer shaped like the
# parameters themselves.
vs = [np.zeros(sz, dtype=params.dtype) for sz in params.shape]
return (np.zeros_like(params), vs)
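# Small, self-contained illustration (an assumption, not part of the library
# code): for a parameter tensor of shape (3, 4), the init above yields a buffer
# of the same shape plus one accumulator vector per dimension. Plain NumPy is
# imported as onp to avoid clashing with the backend-bound np used above.
import numpy as onp

params = onp.zeros((3, 4), dtype=onp.float32)
vs = [onp.zeros(sz, dtype=params.dtype) for sz in params.shape]
state = (onp.zeros_like(params), vs)
assert state[0].shape == (3, 4)
assert [v.shape for v in state[1]] == [(3,), (4,)]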
def forward(self, inputs, weights):
x, gru_state = inputs
# Dense layer on the concatenation of x and h.
w1, b1, w2, b2 = weights
y = np.dot(np.concatenate([x, gru_state], axis=-1), w1) + b1
# Update and reset gates.
u, r = np.split(math.sigmoid(y), 2, axis=-1)
# Candidate.
c = np.dot(np.concatenate([x, r * gru_state], axis=-1), w2) + b2
new_gru_state = u * gru_state + (1 - u) * np.tanh(c)
return new_gru_state, new_gru_state
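# Illustrative sketch (an assumption, not library code): the GRU forward above
# written out in plain NumPy to make the expected weight shapes explicit. Here
# math.sigmoid is taken to be the backend's sigmoid, so a hand-rolled _sigmoid
# stands in for it; d_in, d_model and the random weights are made up.
import numpy as onp

def _sigmoid(x):
  return 1.0 / (1.0 + onp.exp(-x))

d_in, d_model, batch = 8, 16, 4
rand = onp.random.RandomState(0)
w1 = rand.normal(size=(d_in + d_model, 2 * d_model)).astype(onp.float32)
b1 = onp.zeros(2 * d_model, dtype=onp.float32)
w2 = rand.normal(size=(d_in + d_model, d_model)).astype(onp.float32)
b2 = onp.zeros(d_model, dtype=onp.float32)

x = rand.normal(size=(batch, d_in)).astype(onp.float32)
h = onp.zeros((batch, d_model), dtype=onp.float32)

y = onp.dot(onp.concatenate([x, h], axis=-1), w1) + b1
u, r = onp.split(_sigmoid(y), 2, axis=-1)          # update and reset gates
c = onp.dot(onp.concatenate([x, r * h], axis=-1), w2) + b2
new_h = u * h + (1 - u) * onp.tanh(c)
assert new_h.shape == (batch, d_model)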
def DotProductAttention(query, key, value, mask, dropout, mode, rng):
"""Core dot-product self-attention.
Args:
query: array of query representations
key: array of key representations
value: array of value representations
mask: attention mask (True where attention is allowed), or None
dropout: float: dropout rate
mode: 'eval' or 'train': whether to apply dropout
rng: JAX PRNGKey: subkey for disposable use
Returns:
Self attention for q, k, v arrays.
"""
depth = np.shape(query)[-1]
dots = np.matmul(query, np.swapaxes(key, -1, -2)) / np.sqrt(depth)
if mask is not None:
# TODO(kitaev): workaround for https://github.com/google/jax/issues/850
# We must ensure that both mask and the -1e9 constant have a data dependency
# on the input. Broadcasted copies of these use a lot of memory, so they
# should be computed at runtime (rather than being global constants).
if backend.get_name() == 'jax':
mask = jax.lax.tie_in(dots, mask)
# JAX's `full_like` already ties in -1e9 to dots.
dots = np.where(mask, dots, np.full_like(dots, -1e9))
# Softmax.
dots = np.exp(dots - backend.logsumexp(dots, axis=-1, keepdims=True))
if dropout is not None and dropout >= 1.0:
raise ValueError('Dropout rates must be lower than 1.')
if dropout is not None and dropout > 0.0 and mode == 'train':
keep = backend.random.bernoulli(rng, 1.0 - dropout, dots.shape)
dots = np.where(keep, dots / (1.0 - dropout), np.zeros_like(dots))
out = np.matmul(dots, value)
return out
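# Standalone sketch (an assumption, not library code): the same scaled
# dot-product attention with a causal (lower-triangular) mask, in plain NumPy
# and without the dropout or JAX tie_in details, so the masking and softmax
# steps can be checked in isolation.
import numpy as onp

def attention_sketch(q, k, v, causal=True):
  depth = q.shape[-1]
  dots = onp.matmul(q, onp.swapaxes(k, -1, -2)) / onp.sqrt(depth)
  if causal:
    seq_len = q.shape[-2]
    mask = onp.tril(onp.ones((seq_len, seq_len), dtype=bool))
    dots = onp.where(mask, dots, -1e9)        # block attention to the future
  # Softmax over the key axis (max-subtracted for numerical stability).
  dots = dots - dots.max(axis=-1, keepdims=True)
  weights = onp.exp(dots)
  weights = weights / weights.sum(axis=-1, keepdims=True)
  return onp.matmul(weights, v)

q = onp.random.randn(2, 5, 8)                 # (batch, length, depth)
out = attention_sketch(q, q, q)
assert out.shape == (2, 5, 8)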
def forward(self, x, weights):
# Embedding lookup: select rows of the weight table by integer token id.
return np.take(weights, x, axis=0)
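# Tiny illustration (an assumption, not library code): np.take with axis=0 is a
# row lookup, so a (vocab, d_feature) table indexed by integer ids returns one
# d_feature vector per id, preserving the ids' batch shape.
import numpy as onp

table = onp.arange(12.0).reshape(4, 3)        # vocab=4, d_feature=3
ids = onp.array([[0, 3], [2, 2]])             # (batch, length)
vecs = onp.take(table, ids, axis=0)
assert vecs.shape == (2, 2, 3)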
dots_thresh = np.sort(merged_top_dots)[:, -self._hard_k]
# It's possible to compute the partition function at this point, but right
# now this codepath isn't set up for backprop, and there might also be
# issues computing it this way if two dot-products are exactly equal.
sdots_thresh = dots_thresh[st]
bdots_thresh = np.reshape(sdots_thresh, (self.n_hashes * self.n_bins, -1))
bdots_thresh = jax.lax.stop_gradient(bdots_thresh)
top_k_mask = jax.lax.convert_element_type(
dots < bdots_thresh[..., None], np.float32)
dots = dots - 1e7 * jax.lax.stop_gradient(top_k_mask)
# Softmax.
dots_logsumexp = backend.logsumexp(dots, axis=-1, keepdims=True)
dots = np.exp(dots - dots_logsumexp)
if self._dropout > 0.0:
# Dropout is broadcast across the bin dimension
dropout_shape = (1, dots.shape[-2], dots.shape[-1])
keep_prob = jax.lax.tie_in(dots, 1.0 - self._dropout)
keep = backend.random.bernoulli(rng, keep_prob, dropout_shape)
multiplier = keep.astype(dots.dtype) / jax.lax.tie_in(keep, keep_prob)
dots = dots * multiplier
bo = np.matmul(dots, bv)
so = np.reshape(bo, (-1, bo.shape[-1]))
slogits = np.reshape(dots_logsumexp, (-1,))
def unsort_for_output_impl(so, slogits):
o = np.take(so, undo_sort, axis=0)
# Sorting is considerably faster than gather, but first we need to get the
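# Sketch of the hard top-k masking used above (an assumption, plain NumPy):
# find the k-th largest score in each row, flag everything below it, and push
# the flagged entries far down before the softmax so only the top k survive.
import numpy as onp

hard_k = 2
dots = onp.array([[0.1, 0.9, 0.5, 0.3],
                  [0.7, 0.2, 0.4, 0.6]])
thresh = onp.sort(dots, axis=-1)[:, -hard_k]          # k-th largest per row
mask = (dots < thresh[:, None]).astype(onp.float32)   # 1.0 where below threshold
masked = dots - 1e7 * mask
weights = onp.exp(masked - masked.max(axis=-1, keepdims=True))
weights /= weights.sum(axis=-1, keepdims=True)
# Each row now places essentially all of its probability on its top-2 scores.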
def update_model_state(self, key, value):
"""Updates model state based on nontrainable_params."""
# Translate model state keys to nontrainable param names.
if key in self._nontrainable_param_map:
p_name = self._nontrainable_param_map[key]
else:
# If a key is not in mapping, it stays the same.
p_name = key
if p_name in self.nontrainable_params:
if self._step == 0:
log('Mapping model state key {} to nontrainable param {}.'
''.format(key, p_name))
return self._for_n_devices(np.array(self.nontrainable_params[p_name]))
return value
def _minimum(self, tensor_list):
minimum = tensor_list[0]
for i in range(1, len(tensor_list)):
minimum = np.minimum(minimum, tensor_list[i])
return minimum
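# Equivalent formulation (an assumption, not library code): the loop above is
# an elementwise reduction over the list, which functools.reduce expresses in
# one line for any non-empty tensor_list.
import functools
import numpy as onp

tensors = [onp.array([3.0, 1.0]), onp.array([2.0, 5.0]), onp.array([4.0, 0.5])]
minimum = functools.reduce(onp.minimum, tensors)      # -> array([2. , 0.5])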
n_cols = cur_shape[-1]
flat_shape = (n_cols, n_rows) if n_rows < n_cols else (n_rows, n_cols)
# Generate a random matrix
a = random.normal(rng, flat_shape, dtype=np.float32)
# Compute the qr factorization
q, r = np.linalg.qr(a)
# Make Q uniform
d = np.diag(r)
q *= np.sign(d)
# Transpose and reshape back q if needed.
if n_rows < n_cols:
q = np.transpose(q)
q = np.reshape(q, shape)
# Return scaled as requested.
return stddev * q
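# Illustrative check (an assumption, plain NumPy): the QR construction above
# produces a matrix with orthonormal columns (or rows, after the transpose), so
# with stddev == 1.0, Q^T Q is close to the identity.
import numpy as onp

shape = (6, 3)
a = onp.random.randn(*shape).astype(onp.float32)
q, r = onp.linalg.qr(a)
q *= onp.sign(onp.diag(r))        # same sign correction as "Make Q uniform"
err = onp.abs(q.T @ q - onp.eye(shape[1])).max()
assert err < 1e-5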
emb = np.concatenate(embs, -1)
if self._mode == 'predict':
assert self._dropout == 0.0
emb = np.reshape(emb, (inputs.shape[0], -1, emb.shape[-1]))
return inputs + emb[:, state, :][:, None, :], state + 1
elif self._dropout == 0:
return inputs + np.reshape(emb, inputs.shape), state
else:
noise_shape = list(emb.shape)
for dim in self._dropout_broadcast_dims:
noise_shape[dim] = 1
keep_prob = 1.0 - self._dropout
if backend.get_name() == 'jax':
keep_prob = jax.lax.tie_in(
inputs, np.full((), keep_prob, dtype=inputs.dtype))
keep = backend.random.bernoulli(rng, keep_prob, tuple(noise_shape))
multiplier = keep.astype(inputs.dtype) / keep_prob
return inputs + np.reshape(emb * multiplier, inputs.shape), state
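# Sketch of the broadcast dropout above (an assumption, plain NumPy): setting a
# dimension of noise_shape to 1 samples a single keep/drop decision that is
# shared across that whole axis, and survivors are scaled by 1/keep_prob so the
# expected value of the output is unchanged (inverted dropout).
import numpy as onp

rand = onp.random.RandomState(0)
emb = onp.ones((2, 5, 4))                 # (batch, length, d_feature)
keep_prob = 0.8
noise_shape = (1, 5, 1)                   # share the mask across batch and features
keep = rand.uniform(size=noise_shape) < keep_prob
dropped = emb * keep / keep_prob          # keep broadcasts over axes 0 and 2
assert dropped.shape == emb.shape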
def dataset_to_stream(dataset, input_name):
"""Takes a tf.Dataset and creates a numpy stream of ready batches."""
for example in backend.dataset_as_numpy(dataset):
features = example[0]
inp, out = features[input_name], example[1]
mask = features['mask'] if 'mask' in features else None
# All input-pipeline processing should be on CPU.
with tf.device('cpu:0'):
# Some accelerators don't handle uint8 well, cast to int.
if inp.dtype == np.uint8:
inp = inp.astype(np.int32)
if out.dtype == np.uint8:
out = out.astype(np.int32)
if len(out.shape) > 1 and out.shape[-1] == 1:
out = np.squeeze(out, axis=-1)
yield (inp, out) if mask is None else (inp, out, mask)
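# Minimal sketch of the per-batch post-processing above (an assumption, plain
# NumPy, no tf.data pipeline): uint8 batches are cast to int32 and a trailing
# singleton target axis is squeezed away before (inputs, targets) is yielded.
import numpy as onp

inp = onp.zeros((4, 16), dtype=onp.uint8)
out = onp.zeros((4, 1), dtype=onp.uint8)
if inp.dtype == onp.uint8:
  inp = inp.astype(onp.int32)
if out.dtype == onp.uint8:
  out = out.astype(onp.int32)
if out.ndim > 1 and out.shape[-1] == 1:
  out = onp.squeeze(out, axis=-1)
assert inp.dtype == onp.int32 and out.shape == (4,)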