def init(self, params):
  shape = params.shape
  slots = []
  if self._factored and len(shape) >= 2:
    v_row = np.zeros(shape[:-1], dtype=np.float32)
    v_col = np.zeros(shape[:-2] + shape[-1:], dtype=np.float32)
    slots.extend([v_row, v_col])
  else:
    v = np.zeros_like(params)
    slots.append(v)
  if self._do_momentum:
    m = np.zeros_like(params)
    slots.append(m)
  return slots
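The slot layout above stores a factored second moment (one vector per row, one per column) for matrices, and a full-shape slot otherwise, plus an optional momentum slot. A minimal standalone sketch of just the shape arithmetic follows; the helper name and keyword arguments are hypothetical, only the shape logic mirrors the snippet above.

# Hypothetical helper illustrating the slot shapes produced by init() above.
def factored_slot_shapes(shape, factored=True, do_momentum=False):
  """Returns the shapes of the accumulator slots for a parameter of `shape`."""
  slots = []
  if factored and len(shape) >= 2:
    slots.append(shape[:-1])               # v_row: drop the last dim
    slots.append(shape[:-2] + shape[-1:])  # v_col: drop the second-to-last dim
  else:
    slots.append(shape)                    # full-shape second moment
  if do_momentum:
    slots.append(shape)                    # momentum keeps the parameter shape
  return slots

# Example: a (vocab, d_model) embedding matrix.
print(factored_slot_shapes((32000, 512)))           # [(32000,), (512,)]
print(factored_slot_shapes((512,), factored=True))  # [(512,)] -- 1-D falls back to full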
def _fast_inference_init_state(input_signature, buffer_length):
  """Returns an initial state for causal attention layer fast inference."""
  def zeros_for(batch_size, shape_dtype):
    shape, dtype = shape_dtype.as_tuple()
    depth = shape[-1]
    return np.zeros((batch_size, buffer_length, depth), dtype=dtype)

  batch_size = input_signature[0].shape[0]
  k = zeros_for(batch_size, input_signature[1])
  v = zeros_for(batch_size, input_signature[2])
  mask = np.zeros((batch_size, 1, buffer_length))
  index = 0
  return (k, v, mask, index)
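The returned (k, v, mask, index) tuple is a pre-allocated key/value cache for one-token-at-a-time decoding. Below is a rough, self-contained sketch of how such a cache might be consumed: write the new key/value at position index, unmask that position, and advance the index. The update function is illustrative, not the layer's actual fast-inference code.

import numpy as np

def append_to_cache(state, new_k, new_v):
  """Illustrative cache update: write one new (k, v) slot and advance the index."""
  k, v, mask, index = state
  k, v, mask = k.copy(), v.copy(), mask.copy()
  k[:, index, :] = new_k           # new_k: (batch, depth)
  v[:, index, :] = new_v
  mask[:, :, index] = 1.0          # unmask the newly written position
  return (k, v, mask, index + 1)

# Example with batch_size=2, buffer_length=4, depth=8.
batch, buf, depth = 2, 4, 8
state = (np.zeros((batch, buf, depth), np.float32),
         np.zeros((batch, buf, depth), np.float32),
         np.zeros((batch, 1, buf)),
         0)
state = append_to_cache(state, np.ones((batch, depth)), np.ones((batch, depth)))
assert state[3] == 1 and state[2][0, 0, 0] == 1.0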
def new_weights(self, input_signature):
  # Usually (B, W, H, C)
  shape = input_signature.shape
  num_channels = shape[-1]

  gamma = np.ones((num_channels,), dtype=np.float32)
  beta = np.zeros((num_channels,), dtype=np.float32)

  epsilon_l = base.EMPTY_WEIGHTS
  if self._learn_epsilon:
    epsilon_l = (self._init_learnt_epsilon,)

  return gamma, beta, epsilon_l
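gamma and beta here are per-channel scale and shift vectors of shape (C,) for a (B, W, H, C) input, with an optional learnt epsilon. The sketch below shows how such weights typically broadcast over the channel axis; the RMS-style normalization over the spatial dims is an assumption used only to make the example concrete, and may differ from the exact layer this snippet comes from.

import numpy as np

def apply_channel_norm(x, gamma, beta, epsilon=1e-6):
  """Channel-wise scale/shift after an RMS-style normalization over (W, H).

  x: (B, W, H, C); gamma, beta: (C,). Illustrative only.
  """
  nu2 = np.mean(np.square(x), axis=(1, 2), keepdims=True)  # (B, 1, 1, C)
  x_norm = x / np.sqrt(nu2 + epsilon)
  return gamma * x_norm + beta   # (C,) broadcasts over the channel axis

x = np.random.randn(2, 4, 4, 3).astype(np.float32)
gamma = np.ones((3,), np.float32)
beta = np.zeros((3,), np.float32)
y = apply_channel_norm(x, gamma, beta)
assert y.shape == x.shape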
def NewPositionalEncoding(x, positions=None, **kwargs):
  """Implements new positional encoding."""
  del kwargs
  x_length = np.shape(x)[1]
  pos = np.array(positions)[np.newaxis, :x_length, :]
  pos += np.zeros((np.shape(x)[0], 1, 1))  # Broadcast on batch.
  return pos
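This truncates a table of positions to the input length and tiles it over the batch. A standalone re-creation of the broadcast logic is below; the positions table is a made-up example. Note that the snippet presumably runs under a JAX-backed numpy (it uses jax.lax and jax.random elsewhere), where `pos += zeros` is out-of-place and so can grow the batch dimension; with plain numpy the explicit `pos = pos + ...` form used here is needed.

import numpy as np

positions = np.array([[i, i % 4] for i in range(16)], dtype=np.float32)  # (16, 2) table
x = np.zeros((3, 10, 8), dtype=np.float32)  # (batch, length, features)

x_length = x.shape[1]
pos = positions[np.newaxis, :x_length, :]              # (1, 10, 2)
pos = pos + np.zeros((x.shape[0], 1, 1), np.float32)   # broadcast to (3, 10, 2)
print(pos.shape)  # (3, 10, 2)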
assert return_output or ct is not None, 'No work to perform!'

if new_state is not None and new_state is not base.EMPTY_STATE:
  buckets = new_state
else:
  buckets = None

# The approach here is to perform attention for one batch element and head
# at a time. Note that there is absolutely no interaction across examples or
# heads: this layer has no parameters, and hashing patterns are also
# different across examples/heads. As a result, batching doesn't give any
# performance gains except in the case of accelerator under-utilization. We
# assume that hash-based attention will be applied primarily to long
# sequences, where unbatched attention for a single head has sufficient
# computation to fill up the accelerator.
batch_loop_idx = np.zeros((), dtype=np.int32)
batch_loop_max = qk.shape[0]

init_vals = (batch_loop_idx,)
if return_output:
  out_accum = np.zeros_like(qk)
  init_vals = init_vals + (out_accum,)
if return_state:
  buckets_accum = np.zeros(
      [qk.shape[0], self.n_hashes * qk.shape[1]], dtype=np.int32)
  init_vals = init_vals + (buckets_accum,)
if ct is not None:
  qk_ct_accum = np.zeros_like(qk)
  v_ct_accum = np.zeros_like(v)
  init_vals = init_vals + (qk_ct_accum, v_ct_accum)

def cond_fun(vals):
  batch_loop_idx = vals[0]
  return jax.lax.lt(batch_loop_idx, batch_loop_max)

def body_fun(vals):
  """Performs attention for a single batch element and head."""
  batch_loop_idx = vals[0]
  if self._prng is None:
    hash_slice_rng = jax.random.fold_in(rng, batch_loop_idx)
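The cond_fun/body_fun pair above is the standard setup for a jax.lax.while_loop: the carry is a tuple whose first element is the batch index, the condition compares it against batch_loop_max, and the body computes one example's result and writes it into an accumulator. Below is a minimal, self-contained sketch of that loop pattern with a stand-in per-example function; the real attention math, hashing, and gradient accumulators are omitted.

import jax
import jax.numpy as jnp

def per_example_fn(x_row):
  # Stand-in for the per-example/per-head computation.
  return x_row * 2.0

def batched_loop(x):
  """Processes one batch element per iteration, accumulating into `out`."""
  batch_loop_max = x.shape[0]

  def cond_fun(vals):
    batch_loop_idx, _ = vals
    return jax.lax.lt(batch_loop_idx, batch_loop_max)

  def body_fun(vals):
    batch_loop_idx, out_accum = vals
    row_out = per_example_fn(x[batch_loop_idx])
    out_accum = out_accum.at[batch_loop_idx].set(row_out)
    return (batch_loop_idx + 1, out_accum)

  init_vals = (jnp.zeros((), dtype=jnp.int32), jnp.zeros_like(x))
  _, out = jax.lax.while_loop(cond_fun, body_fun, init_vals)
  return out

x = jnp.arange(12.0).reshape(3, 4)
print(batched_loop(x))  # each row doubled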
def new_weights_and_state(self, input_signature):
  """Helper to initialize batch norm weights."""
  axis = self._axis
  axis = (axis,) if np.isscalar(axis) else axis
  input_shape = input_signature.shape
  shape = tuple(d for i, d in enumerate(input_shape) if i not in axis)
  beta = np.zeros(shape, dtype='float32') if self._center else ()
  gamma = np.ones(shape, dtype='float32') if self._scale else ()

  def get_stats_axis(i, d):
    if i in axis:
      return 1
    else:
      return d

  stats_shape = tuple(get_stats_axis(i, d) for i, d in enumerate(input_shape))
  running_mean = np.zeros(stats_shape, dtype=np.float32)
  running_var = np.ones(stats_shape, dtype=np.float32)
  n_batches = np.zeros((), dtype=np.int64)
  weights = (beta, gamma)
  state = (running_mean, running_var, n_batches)
  return weights, state
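The weights are (beta, gamma), shaped like the non-reduced dims, and the state holds running statistics shaped with 1s along the reduction axes plus a batch counter. The sketch below shows one way such a layout is typically consumed in a training-mode forward pass; it assumes both center and scale are enabled, and the incremental running-average rule is one common choice rather than the exact update used by the layer above.

import numpy as np

def batch_norm_train_step(x, weights, state, axis=(0, 1, 2), epsilon=1e-5):
  """Illustrative training-mode forward for the weight/state layout above."""
  beta, gamma = weights
  running_mean, running_var, n_batches = state

  mean = np.mean(x, axis=axis, keepdims=True)   # matches stats_shape
  var = np.var(x, axis=axis, keepdims=True)
  x_hat = (x - mean) / np.sqrt(var + epsilon)
  out = gamma * x_hat + beta                    # (C,) broadcasts over channels

  # Incremental average of the batch statistics.
  n = n_batches + 1
  new_mean = running_mean + (mean - running_mean) / n
  new_var = running_var + (var - running_var) / n
  return out, (new_mean, new_var, n)

x = np.random.randn(8, 4, 4, 16).astype(np.float32)
weights = (np.zeros((16,), np.float32), np.ones((16,), np.float32))  # (beta, gamma)
state = (np.zeros((1, 1, 1, 16), np.float32),
         np.ones((1, 1, 1, 16), np.float32),
         np.zeros((), np.int64))
out, state = batch_norm_train_step(x, weights, state)
assert out.shape == x.shape and state[2] == 1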
"""Helper to initialize batch norm weights."""
axis = self._axis
axis = (axis,) if np.isscalar(axis) else axis
input_shape = input_signature.shape
shape = tuple(d for i, d in enumerate(input_shape) if i not in axis)
beta = np.zeros(shape, dtype='float32') if self._center else ()
gamma = np.ones(shape, dtype='float32') if self._scale else ()
def get_stats_axis(i, d):
if i in axis:
return 1
else:
return d
stats_shape = tuple(get_stats_axis(i, d) for i, d in enumerate(input_shape))
running_mean = np.zeros(stats_shape, dtype=np.float32)
running_var = np.ones(stats_shape, dtype=np.float32)
n_batches = np.zeros((), dtype=np.int64)
weights = (beta, gamma)
state = (running_mean, running_var, n_batches)
return weights, state