random_rotations = self._sample_rotation(rotations_shape, vecs, rng)
# TODO(lukaszkaiser): the dropout mask will be used for all rounds of
# hashing, so it's shared between them. Check if that's what we want.
dropped_vecs = self.drop_for_hash(vecs, subrng)
rotated_vecs = np.einsum('tf,fhb->htb', dropped_vecs, random_rotations)
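# The einsum above projects every (dropout-masked) input vector onto a bank of
# random rotation directions, one bank per hash round. A minimal standalone
# sketch of that contraction, assuming `np` here is jax.numpy as in trax; the
# shapes and variable names below are illustrative, not taken from the library.
import jax
import jax.numpy as jnp

seqlen, d_feature, n_hashes, rot_size = 8, 16, 2, 4
k1, k2 = jax.random.split(jax.random.PRNGKey(0))
demo_vecs = jax.random.normal(k1, (seqlen, d_feature))                # (t, f)
demo_rots = jax.random.normal(k2, (d_feature, n_hashes, rot_size))    # (f, h, b)
demo_rotated = jnp.einsum('tf,fhb->htb', demo_vecs, demo_rots)        # (h, t, b)
assert demo_rotated.shape == (n_hashes, seqlen, rot_size)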
def forward(self, inp, weights):
"""Reshape input to have heads dimension and concatenate positions there."""
x = inp[0]
n_batches, seqlen = x.shape[0], x.shape[1]
d_head = x.shape[-1] // self._n_heads
res = np.reshape(x, (n_batches, seqlen, self._n_heads, d_head))
res = np.transpose(res, (0, 2, 1, 3)) # (batch, heads, len, depth)
if self._n_pos == 1: # Just one position given, tile into each head.
pos_shape = list(res.shape)[:-1] + [inp[1].shape[-1]]
pos = inp[1][:, None, :, :] + np.zeros(pos_shape) # Add 0 to broadcast.
else: # As many positions as heads, concatenate them in.
pos = [p[:, None, :, :] for p in inp[1:]]
pos = np.concatenate(pos, axis=1)
res = np.concatenate([res, pos], axis=-1)
# n_batch, n_heads, seqlen, d_head + POS -> n_batch*n_heads, seqlen, d_head + POS_VECTOR_SIZE
res = np.reshape(res, (-1, seqlen, d_head + POS_VECTOR_SIZE))
return res
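# forward() above splits the feature axis into heads, appends a position vector
# to every head, and folds the head axis into the batch axis. A hedged shape
# walk-through with made-up sizes; POS_VECTOR_SIZE and the (batch, heads, len,
# depth) layout come from the code above, everything else is illustrative.
import jax.numpy as jnp

POS_VECTOR_SIZE = 64
n_batches, seqlen, n_heads, d_head = 2, 16, 4, 32
demo_x = jnp.zeros((n_batches, seqlen, n_heads * d_head))
demo_pos = jnp.zeros((n_batches, seqlen, POS_VECTOR_SIZE))

demo_res = jnp.reshape(demo_x, (n_batches, seqlen, n_heads, d_head))
demo_res = jnp.transpose(demo_res, (0, 2, 1, 3))              # (batch, heads, len, depth)
tiled_pos = demo_pos[:, None, :, :] + jnp.zeros(demo_res.shape[:-1] + (POS_VECTOR_SIZE,))
demo_res = jnp.concatenate([demo_res, tiled_pos], axis=-1)
demo_res = jnp.reshape(demo_res, (-1, seqlen, d_head + POS_VECTOR_SIZE))
assert demo_res.shape == (n_batches * n_heads, seqlen, d_head + POS_VECTOR_SIZE)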
if self._rehash_each_round:
if self._factorize_hash and len(factor_list) > 1:
# We factorized self.n_buckets as the product of factor_list.
# Get the buckets for them and combine.
buckets, cur_sum, cur_product = None, 0, 1
for factor in factor_list:
rv = rotated_vecs[..., cur_sum:cur_sum + (factor // 2)]
cur_sum += factor // 2
rv = np.concatenate([rv, -rv], axis=-1)
if buckets is None:
buckets = np.argmax(rv, axis=-1)
else:
buckets += cur_product * np.argmax(rv, axis=-1)
cur_product *= factor
else:
rotated_vecs = np.concatenate([rotated_vecs, -rotated_vecs], axis=-1)
buckets = np.argmax(rotated_vecs, axis=-1)
# buckets is now (self.n_hashes, seqlen). Next we add offsets so that
# bucket numbers from different hashing rounds don't overlap.
offsets = jax.lax.tie_in(buckets, np.arange(self.n_hashes))
offsets = np.reshape(offsets * self.n_buckets, (-1, 1))
buckets = np.reshape(buckets + offsets, (-1,))
else:
assert not self._factorize_hash
rotated_vecs = np.concatenate([rotated_vecs, -rotated_vecs], axis=-1)
# In this configuration, we map each item to the top self.n_hashes buckets
rotated_vecs = np.squeeze(rotated_vecs, 0)
bucket_range = jax.lax.tie_in(vecs, np.arange(rotated_vecs.shape[-1]))
bucket_range = np.reshape(bucket_range, (1, -1))
bucket_range = np.broadcast_to(bucket_range, rotated_vecs.shape)
_, buckets = jax.lax.sort_key_val(
    rotated_vecs, bucket_range, dimension=-1)
# Keep only the top self.n_hashes buckets for each item, as described above.
buckets = buckets[..., -self.n_hashes:]
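# Two tricks from the branch above, checked on tiny arrays (jax.numpy assumed,
# sizes purely illustrative): concatenating [rv, -rv] turns argmax over k random
# projections into 2*k signed-hyperplane buckets, and adding round * n_buckets
# keeps bucket ids from different hash rounds disjoint.
import jax
import jax.numpy as jnp

n_hashes, seqlen, n_buckets = 3, 5, 8
demo_rot = jax.random.normal(jax.random.PRNGKey(1), (n_hashes, seqlen, n_buckets // 2))
demo_rot = jnp.concatenate([demo_rot, -demo_rot], axis=-1)    # (h, t, n_buckets)
demo_buckets = jnp.argmax(demo_rot, axis=-1)                  # ids in [0, n_buckets)
demo_offsets = jnp.reshape(jnp.arange(n_hashes) * n_buckets, (-1, 1))
demo_buckets = jnp.reshape(demo_buckets + demo_offsets, (-1,))
assert demo_buckets.shape == (n_hashes * seqlen,)
assert int(demo_buckets.max()) < n_hashes * n_buckets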
def reverse(self, output, weights=(), state=(), new_state=(), **kwargs):
del weights, kwargs
x1_split = []
x2_split = []
for y in output:
y1, y2 = np.split(y, 2, -1)
x1_split.append(y1)
x2_split.append(y2)
x1 = np.concatenate(x1_split, self._axis)
x2 = np.concatenate(x2_split, self._axis)
return (x1, x2)
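# reverse() above splits every output chunk in two along the feature axis and
# re-assembles x1 and x2 along self._axis. The round trip below pairs it with a
# hypothetical forward (split both inputs into chunks, concatenate matching
# chunks feature-wise); this mirrors, rather than reproduces, the layer's own
# forward pass.
import jax.numpy as jnp

axis, n_chunks = 1, 2
demo_x1 = jnp.arange(24.0).reshape(2, 4, 3)
demo_x2 = -demo_x1
demo_out = [jnp.concatenate(pair, axis=-1)
            for pair in zip(jnp.split(demo_x1, n_chunks, axis),
                            jnp.split(demo_x2, n_chunks, axis))]
x1_parts = [jnp.split(y, 2, -1)[0] for y in demo_out]
x2_parts = [jnp.split(y, 2, -1)[1] for y in demo_out]
assert jnp.allclose(jnp.concatenate(x1_parts, axis), demo_x1)
assert jnp.allclose(jnp.concatenate(x2_parts, axis), demo_x2)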
def look_one_back(x):
# Output: pairs [ bin_i bin_{i-1} ] concatenated on the time axis.
if len(x.shape) == 2:
x_extra = np.concatenate([x[-1:, :], x[:-1, :]], axis=0)
return np.concatenate([x, x_extra], axis=1)
else:
assert len(x.shape) == 4
x_extra = np.concatenate([x[:, -1:, :, :], x[:, :-1, :, :]], axis=1)
return np.concatenate([x, x_extra], axis=2)
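# look_one_back above gives every bin access to itself plus the previous bin,
# with wrap-around for bin 0. A tiny numeric check of the 2-D branch (jax.numpy
# assumed, values purely illustrative).
import jax.numpy as jnp

bins = jnp.array([[0, 0], [1, 1], [2, 2]])                    # 3 bins of width 2
shifted = jnp.concatenate([bins[-1:, :], bins[:-1, :]], axis=0)
paired = jnp.concatenate([bins, shifted], axis=1)
# paired[i] == [bin_i, bin_{i-1}]:
#   [[0, 0, 2, 2],
#    [1, 1, 0, 0],
#    [2, 2, 1, 1]]
assert paired.shape == (3, 4)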
def look_one_back(x):
if len(x.shape) == 2:
x_extra = np.concatenate([x[-1:, :], x[:-1, :]], axis=0)
else:
x_extra = np.concatenate([x[-1:, :, :], x[:-1, :, :]], axis=0)
return np.concatenate([x, x_extra], axis=1)
def QueryPositionKV(x, keys=None, values=None, binary=False, **unused_kwargs):
"""Query a table with a position vector."""
if keys is None:
return x
k = np.array(keys)
v = np.array(values)
q = x
if binary:
q = np.concatenate([x, x], axis=-1)
return tl.DotProductAttention(q, k, v, None, 0.0, None, None)
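# QueryPositionKV above runs dot-product attention from a position query into a
# fixed keys/values table (doubling the query when binary=True, since the table
# keys then hold two concatenated position vectors). A jnp-only sketch of that
# lookup, with plain softmax attention standing in for tl.DotProductAttention;
# batch dimensions and the binary case are omitted, and all sizes are made up.
import jax
import jax.numpy as jnp

table_size, d_pos, seqlen = 10, 4, 3
demo_keys = jax.random.normal(jax.random.PRNGKey(0), (table_size, d_pos))
demo_values = jax.random.normal(jax.random.PRNGKey(1), (table_size, d_pos))
demo_q = jax.random.normal(jax.random.PRNGKey(2), (seqlen, d_pos))

logits = jnp.dot(demo_q, demo_keys.T) / jnp.sqrt(d_pos)       # (seqlen, table_size)
weights = jax.nn.softmax(logits, axis=-1)
looked_up = jnp.dot(weights, demo_values)                      # (seqlen, d_pos)
assert looked_up.shape == (seqlen, d_pos)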
# NOTE: name and signature below are reconstructed; the snippet begins
# mid-docstring, and the body consumes x and n_heads.
def CombineHeadsPos(x, n_heads=1, **unused_kwargs):
  """Combine per-head (x, position) slices into joined xs and one position.

  Returns:
    the vector with combined xs and one with combined positions.
  """
seqlen = x.shape[1]
d_head = x.shape[2]
x = np.reshape(x, (-1, n_heads, seqlen, d_head))
x = np.transpose(x, (0, 2, 1, 3)) # -> n_batch, seqlen, n_heads, d_head
x = np.reshape(x, (-1, seqlen, n_heads * d_head))
head_size = int(d_head) - POS_VECTOR_SIZE
res, positions, idx = [], [], 0
for _ in range(n_heads):
res.append(x[:, :, idx:idx+head_size])
idx += head_size
positions.append(x[:, :, idx:idx+POS_VECTOR_SIZE])
idx += POS_VECTOR_SIZE
combined_position = sum(positions) / float(len(positions))
return np.concatenate(res, axis=-1), combined_position
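# The combine step above is the inverse of the earlier head split: each head's
# slice is separated back into a content part and a POS_VECTOR_SIZE position
# part, and the per-head positions are averaged into one vector. A hedged shape
# check with made-up sizes (names below are illustrative only).
import jax.numpy as jnp

POS_VECTOR_SIZE = 64
n_batch, n_heads, seqlen, head_size = 2, 4, 8, 32
d_head = head_size + POS_VECTOR_SIZE
demo_x = jnp.zeros((n_batch * n_heads, seqlen, d_head))

demo_x = jnp.reshape(demo_x, (-1, n_heads, seqlen, d_head))
demo_x = jnp.transpose(demo_x, (0, 2, 1, 3))                   # batch, len, heads, depth
demo_x = jnp.reshape(demo_x, (-1, seqlen, n_heads * d_head))
content = [demo_x[:, :, i * d_head:i * d_head + head_size] for i in range(n_heads)]
positions = [demo_x[:, :, i * d_head + head_size:(i + 1) * d_head] for i in range(n_heads)]
combined_position = sum(positions) / float(len(positions))
assert jnp.concatenate(content, axis=-1).shape == (n_batch, seqlen, n_heads * head_size)
assert combined_position.shape == (n_batch, seqlen, POS_VECTOR_SIZE)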