How to use the fairseq.data.data_utils.collate_tokens function in fairseq

To help you get started, we’ve selected a few fairseq examples based on popular ways collate_tokens is used in public projects.
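Before the project snippets, here is a minimal, self-contained sketch of what collate_tokens does. The token ids below are made up for illustration; pad_idx=1 and eos_idx=2 match the defaults of fairseq's Dictionary, but indices from your own dictionary work the same way.

import torch
from fairseq.data import data_utils

# three variable-length sequences, each ending in eos (id 2 here)
batch = [
    torch.LongTensor([4, 5, 2]),
    torch.LongTensor([6, 2]),
    torch.LongTensor([7, 8, 9, 10, 2]),  # longest row sets the batch width
]

# right-padded (batch, max_len) matrix; shorter rows are filled with pad_idx
target = data_utils.collate_tokens(batch, pad_idx=1, eos_idx=2, left_pad=False)
# tensor([[ 4,  5,  2,  1,  1],
#         [ 6,  2,  1,  1,  1],
#         [ 7,  8,  9, 10,  2]])

# move_eos_to_beginning=True rotates the trailing eos to the front of each
# row; this is how fairseq builds prev_output_tokens for teacher forcing
prev_output_tokens = data_utils.collate_tokens(
    batch, pad_idx=1, eos_idx=2, left_pad=False, move_eos_to_beginning=True,
)
# tensor([[ 2,  4,  5,  1,  1],
#         [ 2,  6,  1,  1,  1],
#         [ 2,  7,  8,  9, 10]])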


github freewym / espresso / tests / speech_recognition / asr_test_base.py
    # B: batch size; T: max source length; D: feature dimension;
    # K: target dimension (vocabulary) size -- all defined earlier in this
    # test helper, along with the forward_input dict populated below
    feature = torch.randn(B, T, D)
    # this (B, T, D) layout is just a convention; you can override it by
    # writing your own _prepare_forward_input function
    src_lengths = torch.from_numpy(
        np.random.randint(low=1, high=T, size=B).astype(np.int64)
    )
    src_lengths[0] = T  # make sure the longest sequence spans the full length T
    prev_output_tokens = []
    for b in range(B):
        token_length = np.random.randint(low=1, high=src_lengths[b].item() + 1)
        tokens = np.random.randint(low=0, high=K, size=token_length)
        prev_output_tokens.append(torch.from_numpy(tokens))

    prev_output_tokens = fairseq_data_utils.collate_tokens(
        prev_output_tokens,
        pad_idx=1,
        eos_idx=2,
        left_pad=False,
        move_eos_to_beginning=False,
    )
    src_lengths, sorted_order = src_lengths.sort(descending=True)
    forward_input["src_tokens"] = feature.index_select(0, sorted_order)
    forward_input["src_lengths"] = src_lengths
    forward_input["prev_output_tokens"] = prev_output_tokens

    return forward_input
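In this test helper, collate_tokens turns the list of variable-length random target sequences into one right-padded (B, max_len) tensor (left_pad=False pads on the right). Sorting src_lengths in descending order and reordering the features with index_select keeps the batch in the length-sorted layout that fairseq's recurrent speech encoders typically expect.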
github microsoft / MASS / MASS-supNMT / mass / noisy_language_pair_dataset.py
def merge(key, left_pad, move_eos_to_beginning=False):
    return data_utils.collate_tokens(
        [s[key] for s in samples],
        pad_idx, eos_idx, left_pad, move_eos_to_beginning,
    )
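merge here is a closure defined inside the dataset's collater: samples, pad_idx, and eos_idx come from the enclosing scope, so each field of the batch can be collated with a one-line call. The next two examples follow the same pattern.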
github microsoft / MASS / MASS-summarization / mass / masked_dataset.py
def merge(x, left_pad, move_eos_to_beginning=False):
    return data_utils.collate_tokens(
        x, pad_idx, eos_idx, left_pad, move_eos_to_beginning
    )
github freewym / espresso / fairseq / data / monolingual_dataset.py
def merge(key, is_list=False):
    if is_list:
        res = []
        for i in range(len(samples[0][key])):
            res.append(data_utils.collate_tokens(
                [s[key][i] for s in samples], pad_idx, eos_idx, left_pad=False,
            ))
        return res
    else:
        return data_utils.collate_tokens(
            [s[key] for s in samples], pad_idx, eos_idx, left_pad=False,
        )
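This variant of merge also handles the case where each sample stores a list of tensors under one key (the monolingual dataset can carry several target streams per sample); each list position is then collated into its own padded batch.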
github freewym / espresso / examples / speech_recognition / data / collaters.py
samples = parsed_samples

id = torch.LongTensor([s["id"] for s in samples])
frames = self._collate_frames([s["source"] for s in samples])
# sort samples by descending number of frames
frames_lengths = torch.LongTensor([s["source"].size(0) for s in samples])
frames_lengths, sort_order = frames_lengths.sort(descending=True)
id = id.index_select(0, sort_order)
frames = frames.index_select(0, sort_order)

target = None
target_lengths = None
prev_output_tokens = None
if samples[0].get("target", None) is not None:
    ntokens = sum(len(s["target"]) for s in samples)
    target = fairseq_data_utils.collate_tokens(
        [s["target"] for s in samples],
        self.pad_index,
        self.eos_index,
        left_pad=False,
        move_eos_to_beginning=False,
    )
    target = target.index_select(0, sort_order)
    target_lengths = torch.LongTensor(
        [s["target"].size(0) for s in samples]
    ).index_select(0, sort_order)
    prev_output_tokens = fairseq_data_utils.collate_tokens(
        [s["target"] for s in samples],
        self.pad_index,
        self.eos_index,
        left_pad=False,
        move_eos_to_beginning=self.move_eos_to_beginning,
    )
    prev_output_tokens = prev_output_tokens.index_select(0, sort_order)
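The collater calls collate_tokens twice on the same target list: once to build the padded target batch, and once with move_eos_to_beginning=self.move_eos_to_beginning to produce prev_output_tokens, the shifted decoder input for teacher forcing. Every tensor is re-ordered with index_select(0, sort_order) so it stays aligned with the frames, which were sorted by descending length.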
github freewym / espresso / fairseq / sequence_generator.py
def _prepare_batch_for_alignment(self, sample, hypothesis):
    src_tokens = sample['net_input']['src_tokens']
    bsz = src_tokens.shape[0]
    src_tokens = (
        src_tokens[:, None, :]
        .expand(-1, self.beam_size, -1)
        .contiguous()
        .view(bsz * self.beam_size, -1)
    )
    src_lengths = sample['net_input']['src_lengths']
    src_lengths = (
        src_lengths[:, None]
        .expand(-1, self.beam_size)
        .contiguous()
        .view(bsz * self.beam_size)
    )
    prev_output_tokens = data_utils.collate_tokens(
        [beam['tokens'] for example in hypothesis for beam in example],
        self.pad, self.eos, self.left_pad_target, move_eos_to_beginning=True,
    )
    tgt_tokens = data_utils.collate_tokens(
        [beam['tokens'] for example in hypothesis for beam in example],
        self.pad, self.eos, self.left_pad_target, move_eos_to_beginning=False,
    )
    return src_tokens, src_lengths, prev_output_tokens, tgt_tokens
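Both collate_tokens calls flatten the same nested hypothesis list into bsz * beam_size rows; the only difference is move_eos_to_beginning, which yields the shifted decoder inputs (prev_output_tokens) in one call and the unshifted alignment targets (tgt_tokens) in the other.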
github freewym / espresso / examples / roberta / wsc / wsc_task.py
    query_toks, query_mask, query_len = None, None, 0

    query_tokens.append(query_toks)
    query_masks.append(query_mask)
    query_lengths.append(query_len)

    cand_toks, cand_masks = [], []
    for cand_span in cand_spans:
        toks, mask = self.binarize_with_mask(
            cand_span.text, prefix, suffix, leading_space, trailing_space,
        )
        cand_toks.append(toks)
        cand_masks.append(mask)

    # collate candidates
    cand_toks = data_utils.collate_tokens(cand_toks, pad_idx=self.vocab.pad())
    cand_masks = data_utils.collate_tokens(cand_masks, pad_idx=0)
    assert cand_toks.size() == cand_masks.size()

    candidate_tokens.append(cand_toks)
    candidate_masks.append(cand_masks)
    candidate_lengths.append(cand_toks.size(1))

    labels.append(label)

query_lengths = np.array(query_lengths)
query_tokens = ListDataset(query_tokens, query_lengths)
query_masks = ListDataset(query_masks, query_lengths)

candidate_lengths = np.array(candidate_lengths)
candidate_tokens = ListDataset(candidate_tokens, candidate_lengths)
candidate_masks = ListDataset(candidate_masks, candidate_lengths)
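Note the pad indices in the two calls: candidate token batches are padded with the vocabulary's pad index, while the boolean candidate masks are padded with pad_idx=0, so padded positions simply drop out of the mask.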
github zhiqwang / sightseq / sightseq / data / text_recognition_dataset.py
def merge(key, left_pad, move_eos_to_beginning=False):
    return data_utils.collate_tokens(
        [s[key] for s in samples],
        pad_idx, eos_idx, left_pad, move_eos_to_beginning,
    )