span_end_logits = util.replace_masked_values(span_end_logits, passage_mask, -1e7)
best_span = self.get_best_span(span_start_logits, span_end_logits)

output_dict = {
    "span_start_logits": span_start_logits,
    "span_start_probs": span_start_probs,
    "span_end_logits": span_end_logits,
    "span_end_probs": span_end_probs,
    "best_span": best_span,
}
batch_size = span_start_logits.size(0)

# Compute the loss for training.
if span_start is not None:
    loss = nll_loss(util.masked_log_softmax(span_start_logits, passage_mask),
                    span_start.squeeze(-1))
    self._span_start_accuracy(span_start_logits, span_start.squeeze(-1))
    loss += nll_loss(util.masked_log_softmax(span_end_logits, passage_mask),
                     span_end.squeeze(-1))
    self._span_end_accuracy(span_end_logits, span_end.squeeze(-1))
    self._span_accuracy(best_span, torch.stack([span_start, span_end], -1))
    output_dict["loss"] = loss

# Compute the EM and F1 on SQuAD and add the tokenized input to the output.
if metadata is not None:
    output_dict["best_span_str"] = []
    question_tokens = []
    passage_tokens = []
    if get_sample_level_information:
        output_dict["em_samples"] = []
        output_dict["f1_samples"] = []
    for i in range(batch_size):
        question_tokens.append(metadata[i]["question_tokens"])
        passage_tokens.append(metadata[i]["passage_tokens"])
        # Map the predicted token span back to a string in the original passage.
        passage_str = metadata[i]["original_passage"]
        offsets = metadata[i]["token_offsets"]
        predicted_span = tuple(best_span[i].detach().cpu().numpy())
        start_offset = offsets[predicted_span[0]][0]
        end_offset = offsets[predicted_span[1]][1]
        best_span_string = passage_str[start_offset:end_offset]
        output_dict["best_span_str"].append(best_span_string)
        answer_texts = metadata[i].get("answer_texts", [])
        if answer_texts:
            self._squad_metrics(best_span_string, answer_texts)
            if get_sample_level_information:
                em_sample, f1_sample = bidut.get_em_f1_metrics(best_span_string, answer_texts)
                output_dict["em_samples"].append(em_sample)
                output_dict["f1_samples"].append(f1_sample)
    output_dict["question_tokens"] = question_tokens
    output_dict["passage_tokens"] = passage_tokens

if get_sample_level_information:
    # Add information about the individual samples for future analysis.
    output_dict["span_start_sample_loss"] = []
    output_dict["span_end_sample_loss"] = []
    for i in range(batch_size):
        span_start_loss = nll_loss(util.masked_log_softmax(span_start_logits[[i], :], passage_mask[[i], :]),
                                   span_start.squeeze(-1)[[i]])
        span_end_loss = nll_loss(util.masked_log_softmax(span_end_logits[[i], :], passage_mask[[i], :]),
                                 span_end.squeeze(-1)[[i]])
        output_dict["span_start_sample_loss"].append(float(span_start_loss.detach().cpu().numpy()))
        output_dict["span_end_sample_loss"].append(float(span_end_loss.detach().cpu().numpy()))

if get_attentions:
    output_dict["C2Q_attention"] = passage_question_attention
    output_dict["Q2C_attention"] = question_passage_attention
    output_dict["similarity"] = passage_question_similarity

return output_dict
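
# The forward pass above relies on a `get_best_span` helper. As a reference, this
# is a minimal sketch of the usual decoding rule (not necessarily the library's
# exact implementation): pick the pair (start, end) with start <= end that
# maximizes span_start_logits[start] + span_end_logits[end], in O(passage_length)
# per example.
import torch

def get_best_span_sketch(span_start_logits: torch.Tensor,
                         span_end_logits: torch.Tensor) -> torch.Tensor:
    batch_size, passage_length = span_start_logits.shape
    best_spans = span_start_logits.new_zeros((batch_size, 2), dtype=torch.long)
    for b in range(batch_size):
        best_score = float("-inf")
        best_start = 0  # best start index seen so far
        best_start_score = span_start_logits[b, 0].item()
        for end in range(passage_length):
            # The start index may be anywhere up to and including `end`.
            if span_start_logits[b, end].item() > best_start_score:
                best_start_score = span_start_logits[b, end].item()
                best_start = end
            score = best_start_score + span_end_logits[b, end].item()
            if score > best_score:
                best_score = score
                best_spans[b, 0] = best_start
                best_spans[b, 1] = end
    return best_spans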
        mask: torch.Tensor,
        alias_indices: torch.Tensor,
        alias_tokens: torch.Tensor):
    mention_mask = target_alias_indices.gt(0)
    batch_size, sequence_length, vocab_size = generate_scores.shape
    copy_sequence_length = copy_scores.shape[-1]

    # Flat sequences make life **much** easier.
    flattened_targets = target_tokens.view(batch_size * sequence_length, 1)
    flattened_mask = mask.view(-1, 1).bool()
    alias_mask = alias_indices.view(batch_size, sequence_length, -1).gt(0)

    # The log-probability distribution is then given by taking the masked log softmax.
    generate_log_probs = masked_log_softmax(generate_scores,
                                            torch.ones_like(generate_scores))
    copy_log_probs = masked_log_softmax(copy_scores, alias_mask)

    # GENERATE LOSS ###
    # The generated token loss is a simple cross-entropy calculation; we can just gather
    # the log probabilities...
    flattened_log_probs = generate_log_probs.view(batch_size * sequence_length, -1)
    generate_log_probs = flattened_log_probs.gather(1, flattened_targets)
    # ...except we need to ignore the contribution of UNK tokens that are
    # copied (always, in the simplified model). To do that we create a mask
    # which is 1 only if the token is not a copied UNK (or padding).
    unks = target_tokens.eq(self._unk_index).view(-1, 1)
    copied = target_alias_indices.gt(0).view(-1, 1)
    generate_mask = ~copied & flattened_mask
    # Since we are in log-space we apply the mask by addition.
    generate_log_probs = generate_log_probs + (generate_mask.float() + 1e-45).log()

    # COPY LOSS ###
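
# An aside on the `(mask.float() + 1e-45).log()` trick used in the generate loss
# above: adding the log of the mask leaves kept positions unchanged
# (log(1 + 1e-45) ~ 0) and pushes masked positions to log(1e-45) ~ -103, which is
# effectively log 0. A tiny self-contained demonstration:
import torch

log_probs = torch.log_softmax(torch.tensor([[2.0, 1.0, 0.5]]), dim=-1)
keep = torch.tensor([[1.0, 1.0, 0.0]])   # last position is masked out
masked = log_probs + (keep + 1e-45).log()
print(masked.exp())                       # last probability is ~ 0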
def _question_span_module(self, passage_vector, question_out, question_mask):
    encoded_question_for_span_prediction = \
        torch.cat([question_out,
                   passage_vector.unsqueeze(1).repeat(1, question_out.size(1), 1)], -1)
    # Shape: (batch_size, question_length)
    question_span_start_logits = \
        self._question_span_start_predictor(encoded_question_for_span_prediction).squeeze(-1)
    # Shape: (batch_size, question_length)
    question_span_end_logits = \
        self._question_span_end_predictor(encoded_question_for_span_prediction).squeeze(-1)
    question_span_start_log_probs = util.masked_log_softmax(question_span_start_logits, question_mask)
    question_span_end_log_probs = util.masked_log_softmax(question_span_end_logits, question_mask)

    # Info about the best question span prediction
    question_span_start_logits = \
        util.replace_masked_values(question_span_start_logits, question_mask, -1e7)
    question_span_end_logits = \
        util.replace_masked_values(question_span_end_logits, question_mask, -1e7)
    # Shape: (batch_size, 2)
    best_question_span = get_best_span(question_span_start_logits, question_span_end_logits)
    return question_span_start_log_probs, question_span_end_log_probs, best_question_span
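
# A quick shape check of `_question_span_module` with dummy tensors. The span
# predictors are assumed here to be linear layers mapping the concatenated
# representation to one logit per token (hypothetical stand-ins for the model's
# actual `_question_span_start_predictor` / `_question_span_end_predictor`):
import torch

batch_size, question_length, hidden_dim = 2, 7, 16
question_out = torch.randn(batch_size, question_length, hidden_dim)
passage_vector = torch.randn(batch_size, hidden_dim)
encoded = torch.cat([question_out,
                     passage_vector.unsqueeze(1).repeat(1, question_length, 1)], -1)
start_predictor = torch.nn.Linear(2 * hidden_dim, 1)   # hypothetical predictor
start_logits = start_predictor(encoded).squeeze(-1)
assert start_logits.shape == (batch_size, question_length)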
def _passage_span_module(self, passage_out, passage_mask):
    # Shape: (batch_size, passage_length)
    passage_span_start_logits = self._passage_span_start_predictor(passage_out).squeeze(-1)
    # Shape: (batch_size, passage_length)
    passage_span_end_logits = self._passage_span_end_predictor(passage_out).squeeze(-1)
    # Shape: (batch_size, passage_length)
    passage_span_start_log_probs = util.masked_log_softmax(passage_span_start_logits, passage_mask)
    passage_span_end_log_probs = util.masked_log_softmax(passage_span_end_logits, passage_mask)

    # Info about the best passage span prediction
    passage_span_start_logits = util.replace_masked_values(passage_span_start_logits, passage_mask, -1e7)
    passage_span_end_logits = util.replace_masked_values(passage_span_end_logits, passage_mask, -1e7)
    # Shape: (batch_size, 2)
    best_passage_span = get_best_span(passage_span_start_logits, passage_span_end_logits)
    return passage_span_start_log_probs, passage_span_end_log_probs, best_passage_span
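
# `util.replace_masked_values(logits, mask, -1e7)` above fills masked positions
# with a large negative value so they can never win the argmax during span
# decoding. A one-line equivalent sketch, assuming a boolean-convertible mask:
import torch

def replace_masked_values_sketch(tensor: torch.Tensor,
                                 mask: torch.Tensor,
                                 replace_with: float) -> torch.Tensor:
    # Positions where mask == 0 are overwritten with `replace_with`.
    return tensor.masked_fill(~mask.bool(), replace_with)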
if self.num_special_numbers > 0:
    special_numbers = self.special_embedding(torch.arange(self.num_special_numbers,
                                                          device=number_indices.device))
    special_numbers = special_numbers.expand(number_indices.shape[0], -1, -1)
    encoded_numbers = torch.cat([special_numbers, encoded_numbers], 1)
    mask = torch.ones((number_indices.shape[0], self.num_special_numbers),
                      device=number_indices.device).long()
    number_mask = torch.cat([mask, number_mask], -1)

# Shape: (batch_size, # of numbers, 2 * bert_dim)
encoded_numbers = torch.cat(
    [encoded_numbers, arithmetic_passage_vector.unsqueeze(1).repeat(1, encoded_numbers.size(1), 1)], -1)

# Shape: (batch_size, #templates, #slots, #numbers)
arithmetic_template_slot_logits = self._arithmetic_template_slot_predictor(encoded_numbers).transpose(1, 2)
arithmetic_template_slot_log_probs = util.masked_log_softmax(arithmetic_template_slot_logits, number_mask)
arithmetic_template_slot_log_probs = arithmetic_template_slot_log_probs.reshape(number_mask.shape[0],
                                                                                self.num_arithmetic_templates,
                                                                                self.num_template_slots,
                                                                                number_mask.shape[-1])
# Shape: (batch_size, #templates, #slots)
arithmetic_best_template_slots = arithmetic_template_slot_log_probs.argmax(-1)
return arithmetic_template_slot_log_probs, arithmetic_best_template_slots, number_mask
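
# Broadcast check for the template-slot block above: `number_mask` has one entry
# per number, and the masked log-softmax aligns it with the flattened
# (#templates * #slots) dimension before the reshape (mimicked below with an
# explicit unsqueeze). Dummy shapes only:
import torch

batch, templates, slots, numbers = 2, 3, 2, 5
logits = torch.randn(batch, templates * slots, numbers)
number_mask = torch.ones(batch, numbers)
log_probs = torch.log_softmax(logits + (number_mask.unsqueeze(1) + 1e-45).log(), dim=-1)
log_probs = log_probs.reshape(batch, templates, slots, numbers)
best_slots = log_probs.argmax(-1)   # one number index per (template, slot)
assert best_slots.shape == (batch, templates, slots)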
best_span = BidirectionalAttentionFlow_1.get_best_span(span_start_logits, span_end_logits)
print("best_spans", best_span)

"""
------------------------------ GET LOSSES AND ACCURACIES -----------------------------------
"""
span_start_accuracy_function = CategoricalAccuracy()
span_end_accuracy_function = CategoricalAccuracy()
span_accuracy_function = BooleanAccuracy()
squad_metrics_function = SquadEmAndF1()

# Compute the loss for training.
if span_start is not None:
    span_start_loss = nll_loss(util.masked_log_softmax(span_start_logits, passage_mask), span_start.squeeze(-1))
    span_end_loss = nll_loss(util.masked_log_softmax(span_end_logits, passage_mask), span_end.squeeze(-1))
    loss = span_start_loss + span_end_loss

    span_start_accuracy_function(span_start_logits, span_start.squeeze(-1))
    span_end_accuracy_function(span_end_logits, span_end.squeeze(-1))
    span_accuracy_function(best_span, torch.stack([span_start, span_end], -1))

    span_start_accuracy = span_start_accuracy_function.get_metric()
    span_end_accuracy = span_end_accuracy_function.get_metric()
    span_accuracy = span_accuracy_function.get_metric()

    print("Loss: ", loss)
    print("span_start_accuracy: ", span_start_accuracy)
    print("span_end_accuracy: ", span_end_accuracy)
    print("span_accuracy: ", span_accuracy)
Tuple[torch.Tensor, torch.Tensor]
    Shape: `(batch_size,), (batch_size, max_input_sequence_length)`
"""
_, target_size = generation_scores.size()

# This mask zeroes out the scores of source tokens that are just padding. We
# apply it to the concatenation of the generation scores and the copy scores so
# that the softmax normalizes over both correctly.
# shape: (batch_size, target_vocab_size + trimmed_source_length)
mask = torch.cat((generation_scores_mask, copy_mask), dim=-1)
# shape: (batch_size, target_vocab_size + trimmed_source_length)
all_scores = torch.cat((generation_scores, copy_scores), dim=-1)
# Normalize generation and copy scores.
# shape: (batch_size, target_vocab_size + trimmed_source_length)
log_probs = util.masked_log_softmax(all_scores, mask)
# Calculate the log probability (`copy_log_probs`) for each token in the source sentence
# that matches the current target token. We use the sum of these copy probabilities
# for matching tokens in the source sentence to get the total probability
# for the target token. We also need to normalize the individual copy probabilities
# to create `selective_weights`, which are used in the next timestep to create
# a selective read state.
# shape: (batch_size, trimmed_source_length)
copy_log_probs = log_probs[:, target_size:] + (target_to_source.float() + 1e-45).log()
# Since `log_probs[:, target_size:]` gives us the raw copy log probabilities,
# we use a non-log softmax to get the normalized non-log copy probabilities.
selective_weights = util.masked_softmax(log_probs[:, target_size:], target_to_source)
# This mask ensures that an item in the batch has a non-zero generation probability
# for this timestep only when the gold target token is not OOV, or there are no
# matching tokens in the source sentence.
# shape: (batch_size,)
gen_mask = ((target_tokens != self._oov_index) | (target_to_source.sum(-1) == 0)).float()
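
# How these quantities typically combine in CopyNet-style models (a sketch of the
# likely next step, not necessarily this model's exact code): the generation
# log-probability is masked in log space with `gen_mask`, then merged with the
# per-source-token copy log-probabilities via logsumexp, so duplicate source
# tokens contribute their summed probability mass.
# shape: (batch_size, 1)
log_gen_mask = (gen_mask + 1e-45).log().unsqueeze(-1)
generation_log_probs = log_probs.gather(1, target_tokens.unsqueeze(1)) + log_gen_mask
# shape: (batch_size, 1 + trimmed_source_length)
combined_gen_and_copy = torch.cat((generation_log_probs, copy_log_probs), dim=-1)
# shape: (batch_size,)
step_log_likelihood = util.logsumexp(combined_gen_and_copy)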