How to use the allennlp.nn.util.masked_log_softmax function in allennlp

To help you get started, we've selected a few allennlp examples, based on popular ways it is used in public projects.

Secure your code as it's written. Use Snyk Code to scan source code in minutes - no build needed - and fix issues immediately.

github matthew-z / R-net / qa / squad / rnet.py View on Github external
span_end_logits = util.replace_masked_values(
            span_end_logits, passage_mask, -1e7)

        best_span = self.get_best_span(span_start_logits, span_end_logits)

        output_dict = {
            "span_start_logits": span_start_logits,
            "span_start_probs": span_start_probs,
            "span_end_logits": span_end_logits,
            "span_end_probs": span_end_probs,
            "best_span": best_span,
        }


        if span_start is not None:
            loss = nll_loss(util.masked_log_softmax(
                span_start_logits, passage_mask), span_start.squeeze(-1))
            self._span_start_accuracy(
                span_start_logits, span_start.squeeze(-1))
            loss += nll_loss(util.masked_log_softmax(span_end_logits,
                                                     passage_mask), span_end.squeeze(-1))
            self._span_end_accuracy(span_end_logits, span_end.squeeze(-1))
            self._span_accuracy(best_span, torch.stack(
                [span_start, span_end], -1))
            output_dict["loss"] = loss

        # Compute the EM and F1 on SQuAD and add the tokenized input to the output.
        if metadata is not None:
            output_dict['best_span_str'] = []
            question_tokens = []
            passage_tokens = []
            batch_size = question_embeded.size(0)
github manuwhs / Trapyng / libs / AllenNLP_lib / bidaf_model.py View on Github external
self._squad_metrics(best_span_string, answer_texts)
                    if (get_sample_level_information):
                        em_sample, f1_sample = bidut.get_em_f1_metrics(best_span_string,answer_texts)
                        output_dict["em_samples"].append(em_sample)
                        output_dict["f1_samples"].append(f1_sample)
                        
            output_dict['question_tokens'] = question_tokens
            output_dict['passage_tokens'] = passage_tokens
            
        if (get_sample_level_information):
            # Add information about the individual samples for future analysis
            output_dict["span_start_sample_loss"] = []
            output_dict["span_end_sample_loss"] = []
            for i in range (batch_size):
                span_start_loss = nll_loss(util.masked_log_softmax(span_start_logits[[i],:], passage_mask[[i],:]), span_start.squeeze(-1)[[i]])
                span_end_loss = nll_loss(util.masked_log_softmax(span_end_logits[[i],:], passage_mask[[i],:]), span_end.squeeze(-1)[[i]])
                
                output_dict["span_start_sample_loss"].append(float(span_start_loss.detach().cpu().numpy()))
                output_dict["span_end_sample_loss"].append(float(span_end_loss.detach().cpu().numpy()))
        if(get_attentions):
            output_dict["C2Q_attention"] = passage_question_attention
            output_dict["Q2C_attention"] = question_passage_attention
            output_dict["simmilarity"] = passage_question_similarity
            
        return output_dict
github rloganiv / kglm-model / kglm / models / simplified.py View on Github external
mask: torch.Tensor,
                    alias_indices: torch.Tensor,
                    alias_tokens: torch.Tensor):
        mention_mask = target_alias_indices.gt(0)
        batch_size, sequence_length, vocab_size = generate_scores.shape
        copy_sequence_length = copy_scores.shape[-1]

        # Flat sequences make life **much** easier.
        flattened_targets = target_tokens.view(batch_size * sequence_length, 1)
        flattened_mask = mask.view(-1, 1).byte()
        alias_mask = alias_indices.view(batch_size, sequence_length, -1).gt(0)

        # The log-probability distribution is then given by taking the masked log softmax.
        generate_log_probs = masked_log_softmax(generate_scores,
                                                torch.ones_like(generate_scores))
        copy_log_probs = masked_log_softmax(copy_scores, alias_mask)

        # GENERATE LOSS ###
        # The generated token loss is a simple cross-entropy calculation, we can just gather
        # the log probabilities...
        flattened_log_probs = generate_log_probs.view(batch_size * sequence_length, -1)
        generate_log_probs = flattened_log_probs.gather(1, flattened_targets)
        # ...except we need to ignore the contribution of UNK tokens that are
        # copied (always in the simplified model). To do that we create a mask
        # which is 1 only if the token is not a copied UNK (or padding).
        unks = target_tokens.eq(self._unk_index).view(-1, 1)
        copied = target_alias_indices.gt(0).view(-1, 1)
        generate_mask = ~copied & flattened_mask
        # Since we are in log-space we apply the mask by addition.
        generate_log_probs = generate_log_probs + (generate_mask.float() + 1e-45).log()

        # COPY LOSS ###
github raylin1000 / drop-bert / drop_bert / augmented_bert_plus.py View on Github external
def _question_span_module(self, passage_vector, question_out, question_mask):
        """Score question tokens as span boundaries and select the best span.

        The passage vector is tiled along the question-length axis and
        concatenated onto every question token encoding. The start and end
        predictors are then applied to that joint encoding and their trailing
        dimension is squeezed away.

        Returns a 3-tuple of
        ``(start_log_probs, end_log_probs, best_question_span)``: the two
        masked log-softmax outputs over question positions, and the result of
        ``get_best_span`` on the logits after ``replace_masked_values`` has
        been applied with ``-1e7``.
        """
        question_length = question_out.size(1)
        # NOTE(review): the 3-argument repeat implies passage_vector is 2-D
        # (batch_size, dim) before unsqueeze -- confirm against the caller.
        tiled_passage = passage_vector.unsqueeze(1).repeat(1, question_length, 1)
        joint_encoding = torch.cat([question_out, tiled_passage], -1)

        # Shape: (batch_size, question_length)
        start_logits = self._question_span_start_predictor(joint_encoding).squeeze(-1)
        # Shape: (batch_size, question_length)
        end_logits = self._question_span_end_predictor(joint_encoding).squeeze(-1)

        start_log_probs = util.masked_log_softmax(start_logits, question_mask)
        end_log_probs = util.masked_log_softmax(end_logits, question_mask)

        # The span search runs on logits passed through replace_masked_values
        # with -1e7, not on the log-probabilities returned above.
        # Shape: (batch_size, 2)
        best_question_span = get_best_span(
            util.replace_masked_values(start_logits, question_mask, -1e7),
            util.replace_masked_values(end_logits, question_mask, -1e7),
        )
        return start_log_probs, end_log_probs, best_question_span
github raylin1000 / drop-bert / drop_bert / augmented_bert_templated_old.py View on Github external
def _passage_span_module(self, passage_out, passage_mask):
        """Score passage tokens as span boundaries and select the best span.

        Applies the start and end predictors to ``passage_out``, squeezes the
        trailing dimension, and returns
        ``(start_log_probs, end_log_probs, best_passage_span)``: the two
        masked log-softmax outputs over passage positions, and the result of
        ``get_best_span`` on the logits after ``replace_masked_values`` has
        been applied with ``-1e7``.
        """
        # Shape: (batch_size, passage_length)
        start_logits = self._passage_span_start_predictor(passage_out).squeeze(-1)
        # Shape: (batch_size, passage_length)
        end_logits = self._passage_span_end_predictor(passage_out).squeeze(-1)

        # Shape: (batch_size, passage_length)
        start_log_probs = util.masked_log_softmax(start_logits, passage_mask)
        end_log_probs = util.masked_log_softmax(end_logits, passage_mask)

        # The span search uses the logits after replace_masked_values (-1e7),
        # not the log-probabilities computed above.
        search_start = util.replace_masked_values(start_logits, passage_mask, -1e7)
        search_end = util.replace_masked_values(end_logits, passage_mask, -1e7)

        # Shape: (batch_size, 2)
        best_passage_span = get_best_span(search_start, search_end)
        return start_log_probs, end_log_probs, best_passage_span
github raylin1000 / drop-bert / drop_bert / augmented_bert_templated.py View on Github external
def _passage_span_module(self, passage_out, passage_mask):
        """Produce passage span log-probabilities and the best passage span.

        The start and end predictors are each applied to ``passage_out`` and
        the trailing dimension is squeezed. The squeezed logits feed two
        things: a masked log-softmax (returned as the first two tuple
        elements) and, after ``replace_masked_values`` with ``-1e7``,
        ``get_best_span`` (returned as the third element).
        """
        # Shape: (batch_size, passage_length)
        logits_for_start = self._passage_span_start_predictor(passage_out).squeeze(-1)
        # Shape: (batch_size, passage_length)
        logits_for_end = self._passage_span_end_predictor(passage_out).squeeze(-1)

        # Shape: (batch_size, passage_length)
        log_probs_for_start = util.masked_log_softmax(logits_for_start, passage_mask)
        log_probs_for_end = util.masked_log_softmax(logits_for_end, passage_mask)

        # Info about the best passage span prediction
        # Shape: (batch_size, 2)
        best_passage_span = get_best_span(
            util.replace_masked_values(logits_for_start, passage_mask, -1e7),
            util.replace_masked_values(logits_for_end, passage_mask, -1e7),
        )
        return log_probs_for_start, log_probs_for_end, best_passage_span
github raylin1000 / drop-bert / drop_bert / augmented_bert_templated_old.py View on Github external
if self.num_special_numbers > 0:
            special_numbers = self.special_embedding(torch.arange(self.num_special_numbers, device=number_indices.device))
            special_numbers = special_numbers.expand(number_indices.shape[0],-1,-1)
            encoded_numbers = torch.cat([special_numbers, encoded_numbers], 1)
            
            mask = torch.ones((number_indices.shape[0],self.num_special_numbers), device=number_indices.device).long()
            number_mask = torch.cat([mask, number_mask], -1)
            
        # Shape: (batch_size, # of numbers, 2*bert_dim)
        encoded_numbers = torch.cat(
                [encoded_numbers, arithmetic_passage_vector.unsqueeze(1).repeat(1, encoded_numbers.size(1), 1)], -1)
        
        # Shape: (batch_size, #templates, #slots, #numbers)
        arithmetic_template_slot_logits = self._arithmetic_template_slot_predictor(encoded_numbers).transpose(1,2)
        arithmetic_template_slot_log_probs = util.masked_log_softmax(arithmetic_template_slot_logits, number_mask)
        arithmetic_template_slot_log_probs = arithmetic_template_slot_log_probs.reshape(number_mask.shape[0],
                                                                                        self.num_arithmetic_templates, 
                                                                                        self.num_template_slots, 
                                                                                        number_mask.shape[-1])
        # Shape: (batch_size, #templates, #slots)
        arithmetic_best_template_slots = arithmetic_template_slot_log_probs.argmax(-1)
        return arithmetic_template_slot_log_probs, arithmetic_best_template_slots, number_mask
github manuwhs / Trapyng / Examples / 5.1 AllenNLP / 3.main_BiDAF_Experiments.py View on Github external
# Script fragment: pick the best answer span from the start/end logits, then
# (when gold spans are available) compute the span losses and accuracy metrics.
# span_start_logits, span_end_logits, passage_mask, span_start and span_end are
# defined earlier in the script, outside this excerpt.
best_span = BidirectionalAttentionFlow_1.get_best_span(span_start_logits, span_end_logits)

print ("best_spans", best_span)

"""
------------------------------ GET LOSES AND ACCURACIES -----------------------------------
"""
# One metric object per quantity being tracked.
span_start_accuracy_function = CategoricalAccuracy()
span_end_accuracy_function = CategoricalAccuracy()
span_accuracy_function = BooleanAccuracy()
# NOTE(review): squad_metrics_function is not used in the visible part of the script.
squad_metrics_function = SquadEmAndF1()

# Compute the loss for training (only possible when gold span indices exist).
if span_start is not None:
    # Each loss is the NLL of the gold index under the masked log-softmax of the logits.
    span_start_loss = nll_loss(util.masked_log_softmax(span_start_logits, passage_mask), span_start.squeeze(-1))
    span_end_loss = nll_loss(util.masked_log_softmax(span_end_logits, passage_mask), span_end.squeeze(-1))
    loss = span_start_loss + span_end_loss
    
    # Feed predictions and gold labels to the metric objects.
    span_start_accuracy_function(span_start_logits, span_start.squeeze(-1))
    span_end_accuracy_function(span_end_logits, span_end.squeeze(-1))
    span_accuracy_function(best_span, torch.stack([span_start, span_end], -1))

    # Read the metric values back for reporting.
    span_start_accuracy = span_start_accuracy_function.get_metric()
    span_end_accuracy =  span_end_accuracy_function.get_metric()
    span_accuracy = span_accuracy_function.get_metric()


    print ("Loss: ", loss)
    # NOTE(review): span_start_accuracy is printed twice below and span_accuracy
    # is never printed in the visible lines -- the second print was probably
    # meant to report span_accuracy; confirm before changing.
    print ("span_start_accuracy: ", span_start_accuracy)
    print ("span_start_accuracy: ", span_start_accuracy)
    print ("span_end_accuracy: ", span_end_accuracy)
github raylin1000 / drop-bert / drop_bert / augmented_bert_templated.py View on Github external
if self.num_special_numbers > 0:
            special_numbers = self.special_embedding(torch.arange(self.num_special_numbers, device=number_indices.device))
            special_numbers = special_numbers.expand(number_indices.shape[0],-1,-1)
            encoded_numbers = torch.cat([special_numbers, encoded_numbers], 1)
            
            mask = torch.ones((number_indices.shape[0],self.num_special_numbers), device=number_indices.device).long()
            number_mask = torch.cat([mask, number_mask], -1)
            
        # Shape: (batch_size, # of numbers, 2*bert_dim)
        encoded_numbers = torch.cat(
                [encoded_numbers, arithmetic_passage_vector.unsqueeze(1).repeat(1, encoded_numbers.size(1), 1)], -1)
        
        # Shape: (batch_size, #templates, #slots, #numbers)
        arithmetic_template_slot_logits = self._arithmetic_template_slot_predictor(encoded_numbers).transpose(1,2)
        arithmetic_template_slot_log_probs = util.masked_log_softmax(arithmetic_template_slot_logits, number_mask)
        arithmetic_template_slot_log_probs = arithmetic_template_slot_log_probs.reshape(number_mask.shape[0],
                                                                                        self.num_arithmetic_templates, 
                                                                                        self.num_template_slots, 
                                                                                        number_mask.shape[-1])
        # Shape: (batch_size, #templates, #slots)
        arithmetic_best_template_slots = arithmetic_template_slot_log_probs.argmax(-1)
        return arithmetic_template_slot_log_probs, arithmetic_best_template_slots, number_mask
github allenai / allennlp / allennlp / models / encoder_decoders / copynet_seq2seq.py View on Github external
Tuple[torch.Tensor, torch.Tensor]
            Shape: `(batch_size,), (batch_size, max_input_sequence_length)`
        """
        _, target_size = generation_scores.size()

        # The point of this mask is to just mask out all source token scores
        # that just represent padding. We apply the mask to the concatenation
        # of the generation scores and the copy scores to normalize the scores
        # correctly during the softmax.
        # shape: (batch_size, target_vocab_size + trimmed_source_length)
        mask = torch.cat((generation_scores_mask, copy_mask), dim=-1)
        # shape: (batch_size, target_vocab_size + trimmed_source_length)
        all_scores = torch.cat((generation_scores, copy_scores), dim=-1)
        # Normalize generation and copy scores.
        # shape: (batch_size, target_vocab_size + trimmed_source_length)
        log_probs = util.masked_log_softmax(all_scores, mask)
        # Calculate the log probability (`copy_log_probs`) for each token in the source sentence
        # that matches the current target token. We use the sum of these copy probabilities
        # for matching tokens in the source sentence to get the total probability
        # for the target token. We also need to normalize the individual copy probabilities
        # to create `selective_weights`, which are used in the next timestep to create
        # a selective read state.
        # shape: (batch_size, trimmed_source_length)
        copy_log_probs = log_probs[:, target_size:] + (target_to_source.float() + 1e-45).log()
        # Since `log_probs[:, target_size:]` gives us the raw copy log probabilities,
        # we use a non-log softmax to get the normalized non-log copy probabilities.
        selective_weights = util.masked_softmax(log_probs[:, target_size:], target_to_source)
        # This mask ensures that each item in the batch has a non-zero generation probability
        # for this timestep only when the gold target token is not OOV or there are no
        # matching tokens in the source sentence.
        # shape: (batch_size, 1)
        gen_mask = ((target_tokens != self._oov_index) | (target_to_source.sum(-1) == 0)).float()