# Imports assumed by this fragment (not shown in the original excerpt).
import numpy as np
import torch
from fairseq.data import data_utils as fairseq_data_utils


def get_dummy_input(T=40, D=80, B=5, K=100):
    # The signature and default sizes above are assumptions added so the
    # fragment runs standalone.
    # T max sequence length
    # B batch size
    # D feature vector dimension
    # K target dimension size
    forward_input = {}
    feature = torch.randn(B, T, D)
    # this (B, T, D) layout is just a convention; you can override it by
    # writing your own _prepare_forward_input function
    src_lengths = torch.from_numpy(
        np.random.randint(low=1, high=T, size=B).astype(np.int64)
    )
    src_lengths[0] = T  # make sure the maximum length matches
    prev_output_tokens = []
    for b in range(B):
        token_length = np.random.randint(low=1, high=src_lengths[b].item() + 1)
        tokens = np.random.randint(low=0, high=K, size=token_length)
        prev_output_tokens.append(torch.from_numpy(tokens))
    # pad the ragged per-sample token lists into one (B, max_len) tensor
    prev_output_tokens = fairseq_data_utils.collate_tokens(
        prev_output_tokens,
        pad_idx=1,
        eos_idx=2,
        left_pad=False,
        move_eos_to_beginning=False,
    )
    src_lengths, sorted_order = src_lengths.sort(descending=True)
    forward_input["src_tokens"] = feature.index_select(0, sorted_order)
    forward_input["src_lengths"] = src_lengths
    forward_input["prev_output_tokens"] = prev_output_tokens
    return forward_input
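
# Added illustration (not in the original fragment): a minimal sketch of what
# collate_tokens returns for ragged inputs like prev_output_tokens above. The
# token values, pad_idx=1 and eos_idx=2 are assumptions for the demo only.
import torch
from fairseq.data import data_utils

batch = data_utils.collate_tokens(
    [torch.tensor([4, 5, 2]), torch.tensor([6, 2])],
    pad_idx=1,
    eos_idx=2,
    left_pad=False,
)
# batch -> tensor([[4, 5, 2],
#                  [6, 2, 1]])  # shorter row right-padded with pad_idx
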
# Variant 1: merge by sample key (the common fairseq collater helper);
# samples, pad_idx and eos_idx are free variables of the enclosing collater.
def merge(key, left_pad, move_eos_to_beginning=False):
    return data_utils.collate_tokens(
        [s[key] for s in samples],
        pad_idx, eos_idx, left_pad, move_eos_to_beginning,
    )

# Variant 2: same helper, but takes the list of token tensors directly
# instead of a sample key.
def merge(x, left_pad, move_eos_to_beginning=False):
    return data_utils.collate_tokens(
        x, pad_idx, eos_idx, left_pad, move_eos_to_beginning
    )

# Variant 3: also handles fields whose value is a list of token tensors by
# collating each list position across the batch separately.
def merge(key, is_list=False):
    if is_list:
        res = []
        for i in range(len(samples[0][key])):
            res.append(
                data_utils.collate_tokens(
                    [s[key][i] for s in samples],
                    pad_idx, eos_idx, left_pad=False,
                )
            )
        return res
    else:
        return data_utils.collate_tokens(
            [s[key] for s in samples], pad_idx, eos_idx, left_pad=False,
        )
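
# Added illustration (assumed values): the move_eos_to_beginning flag that the
# merge helpers forward to collate_tokens rotates the trailing EOS to the
# front, which is how fairseq builds decoder inputs from targets.
import torch
from fairseq.data import data_utils

targets = [torch.tensor([7, 8, 2]), torch.tensor([9, 2])]
tgt = data_utils.collate_tokens(targets, pad_idx=1, eos_idx=2, left_pad=False)
# tgt  -> tensor([[7, 8, 2], [9, 2, 1]])
prev = data_utils.collate_tokens(
    targets, pad_idx=1, eos_idx=2, left_pad=False, move_eos_to_beginning=True
)
# prev -> tensor([[2, 7, 8], [2, 9, 1]])  # EOS first, tokens shifted right
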
# Fragment of a speech-recognition collater method (parsed_samples and the
# self.* attributes come from surrounding code that is not shown here).
samples = parsed_samples
id = torch.LongTensor([s["id"] for s in samples])
frames = self._collate_frames([s["source"] for s in samples])
# sort samples by descending number of frames
frames_lengths = torch.LongTensor([s["source"].size(0) for s in samples])
frames_lengths, sort_order = frames_lengths.sort(descending=True)
id = id.index_select(0, sort_order)
frames = frames.index_select(0, sort_order)
target = None
target_lengths = None
prev_output_tokens = None
if samples[0].get("target", None) is not None:
    ntokens = sum(len(s["target"]) for s in samples)
    target = fairseq_data_utils.collate_tokens(
        [s["target"] for s in samples],
        self.pad_index,
        self.eos_index,
        left_pad=False,
        move_eos_to_beginning=False,
    )
    target = target.index_select(0, sort_order)
    target_lengths = torch.LongTensor(
        [s["target"].size(0) for s in samples]
    ).index_select(0, sort_order)
    prev_output_tokens = fairseq_data_utils.collate_tokens(
        [s["target"] for s in samples],
        self.pad_index,
        self.eos_index,
        left_pad=False,
        move_eos_to_beginning=self.move_eos_to_beginning,
    )
    # keep prev_output_tokens aligned with the length-sorted batch
    prev_output_tokens = prev_output_tokens.index_select(0, sort_order)
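
# Added illustration (assumed values): the descending-length sort above must
# be mirrored with index_select on every per-batch tensor so rows stay
# aligned across id, frames, target, and prev_output_tokens.
import torch

lengths = torch.LongTensor([3, 7, 5])
ids = torch.LongTensor([10, 11, 12])
lengths, sort_order = lengths.sort(descending=True)
ids = ids.index_select(0, sort_order)
# lengths -> tensor([7, 5, 3]); ids -> tensor([11, 12, 10])
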
def _prepare_batch_for_alignment(self, sample, hypothesis):
    src_tokens = sample['net_input']['src_tokens']
    bsz = src_tokens.shape[0]
    # repeat every source row beam_size times so it lines up with the
    # flattened (bsz * beam_size) hypothesis list
    src_tokens = (
        src_tokens[:, None, :]
        .expand(-1, self.beam_size, -1)
        .contiguous()
        .view(bsz * self.beam_size, -1)
    )
    src_lengths = sample['net_input']['src_lengths']
    src_lengths = (
        src_lengths[:, None]
        .expand(-1, self.beam_size)
        .contiguous()
        .view(bsz * self.beam_size)
    )
    # decoder input: hypotheses with EOS moved to the front
    prev_output_tokens = data_utils.collate_tokens(
        [beam['tokens'] for example in hypothesis for beam in example],
        self.pad, self.eos, self.left_pad_target, move_eos_to_beginning=True,
    )
    # reference side: the same hypotheses as-is
    tgt_tokens = data_utils.collate_tokens(
        [beam['tokens'] for example in hypothesis for beam in example],
        self.pad, self.eos, self.left_pad_target, move_eos_to_beginning=False,
    )
    return src_tokens, src_lengths, prev_output_tokens, tgt_tokens
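
# Added side note (assumed shapes): the expand/contiguous/view pattern in
# _prepare_batch_for_alignment duplicates each source row beam_size times;
# it is equivalent to repeat_interleave, as this toy check shows.
import torch

beam_size = 2
src_tokens = torch.tensor([[4, 5, 2], [6, 2, 1]])  # (bsz=2, src_len=3)
bsz = src_tokens.shape[0]
expanded = (
    src_tokens[:, None, :]
    .expand(-1, beam_size, -1)
    .contiguous()
    .view(bsz * beam_size, -1)
)
assert torch.equal(expanded, src_tokens.repeat_interleave(beam_size, dim=0))
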
# Fragment from inside a dataset-loading loop over examples (the loop header
# and the branch handling examples that do have a query are not shown).
query_toks, query_mask, query_len = None, None, 0
query_tokens.append(query_toks)
query_masks.append(query_mask)
query_lengths.append(query_len)
cand_toks, cand_masks = [], []
for cand_span in cand_spans:
    toks, mask = self.binarize_with_mask(
        cand_span.text, prefix, suffix, leading_space, trailing_space,
    )
    cand_toks.append(toks)
    cand_masks.append(mask)
# collate candidates
cand_toks = data_utils.collate_tokens(cand_toks, pad_idx=self.vocab.pad())
cand_masks = data_utils.collate_tokens(cand_masks, pad_idx=0)
assert cand_toks.size() == cand_masks.size()
candidate_tokens.append(cand_toks)
candidate_masks.append(cand_masks)
candidate_lengths.append(cand_toks.size(1))
labels.append(label)
# after the loop: wrap the accumulated lists as datasets
query_lengths = np.array(query_lengths)
query_tokens = ListDataset(query_tokens, query_lengths)
query_masks = ListDataset(query_masks, query_lengths)
candidate_lengths = np.array(candidate_lengths)
candidate_tokens = ListDataset(candidate_tokens, candidate_lengths)
candidate_masks = ListDataset(candidate_masks, candidate_lengths)
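
# Added illustration (made-up token and mask values): collating the masks
# with pad_idx=0 keeps padded positions masked out, so the token and mask
# batches come out the same shape, as the assert above requires.
import torch
from fairseq.data import data_utils

cand_toks = data_utils.collate_tokens(
    [torch.tensor([11, 12, 13]), torch.tensor([14, 15])], pad_idx=1
)
cand_masks = data_utils.collate_tokens(
    [torch.tensor([1, 1, 0]), torch.tensor([0, 1])], pad_idx=0
)
# cand_toks  -> tensor([[11, 12, 13], [14, 15, 1]])
# cand_masks -> tensor([[1, 1, 0], [0, 1, 0]])
assert cand_toks.size() == cand_masks.size()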