pre_encoded_sent = self._pre_encode_feedforward(dropped_embedded_sent)
encoded_tokens = self._encoder(pre_encoded_sent, sentence_mask)
# Compute biattention. This is a special case since the inputs are the same.
attention_logits = encoded_tokens.bmm(encoded_tokens.permute(0, 2, 1).contiguous())
attention_weights = util.last_dim_softmax(attention_logits, sentence_mask)
encoded_sentence = util.weighted_sum(encoded_tokens, attention_weights)
# Build the input to the integrator
integrator_input = torch.cat([encoded_tokens,
encoded_tokens - encoded_sentence,
encoded_tokens * encoded_sentence], 2)
integrated_encodings = self._integrator(integrator_input, sentence_mask)
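# --- Illustration (not part of the snippet above) ----------------------------------
# A minimal sketch of what the biattention step computes, using plain PyTorch in
# place of the util helpers (assumed here to be AllenNLP's nn.util).
# `masked_last_dim_softmax` below is an assumed stand-in for `util.last_dim_softmax`
# (softmax over the last dimension that ignores padded key positions); the final bmm
# plays the role of `util.weighted_sum`.
import torch

def masked_last_dim_softmax(logits: torch.Tensor, mask: torch.Tensor) -> torch.Tensor:
    # mask: (batch, seq_len), 1 for real tokens, 0 for padding (applied to the keys).
    masked_logits = logits.masked_fill(mask.unsqueeze(1) == 0, float("-inf"))
    return torch.softmax(masked_logits, dim=-1)

encoded = torch.randn(2, 5, 8)                          # (batch, seq_len, dim)
mask = torch.tensor([[1, 1, 1, 0, 0], [1, 1, 1, 1, 1]])  # (batch, seq_len)
logits = encoded.bmm(encoded.permute(0, 2, 1))          # (batch, seq_len, seq_len)
weights = masked_last_dim_softmax(logits, mask)         # attention over the same sentence
attended = weights.bmm(encoded)                         # == util.weighted_sum(encoded, weights)
# ------------------------------------------------------------------------------------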
# Simple Pooling layers
max_masked_integrated_encodings = util.replace_masked_values(
integrated_encodings, sentence_mask.unsqueeze(2), -1e7)
max_pool = torch.max(max_masked_integrated_encodings, 1)[0]
min_masked_integrated_encodings = util.replace_masked_values(
integrated_encodings, sentence_mask.unsqueeze(2), +1e7)
min_pool = torch.min(min_masked_integrated_encodings, 1)[0]
mean_pool = torch.sum(integrated_encodings, 1) / torch.sum(sentence_mask, 1, keepdim=True)
# Self-attentive pooling layer
# Run through linear projection. Shape: (batch_size, sequence length, 1)
# Then remove the last dimension to get the proper attention shape (batch_size, sequence length).
self_attentive_logits = self._self_attentive_pooling_projection(integrated_encodings).squeeze(2)
self_weights = util.masked_softmax(self_attentive_logits, sentence_mask)
self_attentive_pool = util.weighted_sum(integrated_encodings, self_weights)
pooled_representations = torch.cat([max_pool, min_pool, mean_pool, self_attentive_pool], 1)
pooled_representations_dropped = self._integrator_dropout(pooled_representations).squeeze(1)
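# --- Illustration (assumption, not the library implementation) ----------------------
# `util.replace_masked_values` can be read as "fill the padded positions with this
# value", so padded tokens can never win the max/min pools and are excluded from the
# mean. A minimal equivalent of the pooling above:
import torch

def replace_masked_values_sketch(tensor, mask, value):
    # mask broadcasts against `tensor`: 1 keeps the original entry, 0 is replaced by `value`.
    return tensor * mask + value * (1 - mask)

encodings = torch.randn(2, 4, 3)                                 # (batch, seq_len, dim)
mask = torch.tensor([[1., 1., 0., 0.], [1., 1., 1., 1.]])         # (batch, seq_len)
max_pool = replace_masked_values_sketch(encodings, mask.unsqueeze(2), -1e7).max(dim=1)[0]
min_pool = replace_masked_values_sketch(encodings, mask.unsqueeze(2), +1e7).min(dim=1)[0]
mean_pool = (encodings * mask.unsqueeze(2)).sum(dim=1) / mask.sum(dim=1, keepdim=True)
# ------------------------------------------------------------------------------------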
number_sign_log_probs = torch.nn.functional.log_softmax(number_sign_logits, -1)
# Shape: (batch_size, # of numbers in passage).
best_signs_for_numbers = torch.argmax(number_sign_log_probs, -1)
# For padding numbers, the best sign is masked to 0 (i.e. not included).
best_signs_for_numbers = util.replace_masked_values(
best_signs_for_numbers, number_mask, 0
)
# Shape: (batch_size, # of numbers in passage)
best_signs_log_probs = torch.gather(
number_sign_log_probs, 2, best_signs_for_numbers.unsqueeze(-1)
).squeeze(-1)
# The probability at each masked position should be 1 so that it does not affect the joint probability.
# TODO: this is not quite right, since if there are many numbers in the passage,
# TODO: the joint probability would be very small.
best_signs_log_probs = util.replace_masked_values(best_signs_log_probs, number_mask, 0)
# Shape: (batch_size,)
best_combination_log_prob = best_signs_log_probs.sum(-1)
if len(self.answering_abilities) > 1:
    best_combination_log_prob += answer_ability_log_probs[
        :, self._addition_subtraction_index
    ]
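# --- Tiny worked example (illustrative values only) ---------------------------------
# The gather/masking pattern above: each number gets a sign in {0: zero, 1: plus,
# 2: minus}; padded numbers are forced to sign 0 and contribute a log-probability of
# 0 to the sum, so they do not change the joint probability.
import torch

number_sign_log_probs = torch.log(torch.tensor([[[0.1, 0.8, 0.1],
                                                 [0.2, 0.3, 0.5],
                                                 [0.2, 0.4, 0.4]]]))  # (1, 3 numbers, 3 signs)
number_mask = torch.tensor([[1, 1, 0]])                               # the third number is padding
best_signs = number_sign_log_probs.argmax(-1) * number_mask           # -> [[1, 2, 0]]
best_log_probs = number_sign_log_probs.gather(2, best_signs.unsqueeze(-1)).squeeze(-1)
best_log_probs = best_log_probs * number_mask                         # padded position contributes 0
best_combination_log_prob = best_log_probs.sum(-1)                    # log(0.8) + log(0.5)
# ------------------------------------------------------------------------------------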
output_dict = {}
# If answer is given, compute the loss.
if (
    answer_as_passage_spans is not None
    or answer_as_question_spans is not None
    or answer_as_add_sub_expressions is not None
    or answer_as_counts is not None
):
for _ in range(3):
    modeled_passage = self._dropout(
        self._modeling_layer(modeled_passage_list[-1], passage_mask)
    )
    modeled_passage_list.append(modeled_passage)
# Shape: (batch_size, passage_length, modeling_dim * 2)
span_start_input = torch.cat([modeled_passage_list[-3], modeled_passage_list[-2]], dim=-1)
# Shape: (batch_size, passage_length)
span_start_logits = self._span_start_predictor(span_start_input).squeeze(-1)
# Shape: (batch_size, passage_length, modeling_dim * 2)
span_end_input = torch.cat([modeled_passage_list[-3], modeled_passage_list[-1]], dim=-1)
span_end_logits = self._span_end_predictor(span_end_input).squeeze(-1)
span_start_logits = util.replace_masked_values(span_start_logits, passage_mask, -1e32)
span_end_logits = util.replace_masked_values(span_end_logits, passage_mask, -1e32)
# Shape: (batch_size, passage_length)
span_start_probs = torch.nn.functional.softmax(span_start_logits, dim=-1)
span_end_probs = torch.nn.functional.softmax(span_end_logits, dim=-1)
best_span = get_best_span(span_start_logits, span_end_logits)
output_dict = {
"passage_question_attention": passage_question_attention,
"span_start_logits": span_start_logits,
"span_start_probs": span_start_probs,
"span_end_logits": span_end_logits,
"span_end_probs": span_end_probs,
"best_span": best_span,
}
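# --- Illustration (a sketch, not the AllenNLP implementation) -----------------------
# `get_best_span` as used above returns, for each example, the (start, end) pair with
# start <= end that maximises span_start_logits[start] + span_end_logits[end].
# A minimal O(passage_length^2) version of that behaviour:
import torch

def get_best_span_sketch(span_start_logits: torch.Tensor,
                         span_end_logits: torch.Tensor) -> torch.Tensor:
    batch_size, passage_length = span_start_logits.size()
    # (batch, start, end) score for every candidate span.
    span_scores = span_start_logits.unsqueeze(2) + span_end_logits.unsqueeze(1)
    # Invalidate spans whose end precedes their start.
    valid = torch.triu(torch.ones(passage_length, passage_length)).bool()
    span_scores = span_scores.masked_fill(~valid, float("-inf"))
    best_flat = span_scores.view(batch_size, -1).argmax(dim=-1)
    return torch.stack([best_flat // passage_length, best_flat % passage_length], dim=-1)
# ------------------------------------------------------------------------------------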
log_marginal_likelihood_for_passage_span = util.logsumexp(
log_likelihood_for_passage_spans
)
log_marginal_likelihood_list.append(log_marginal_likelihood_for_passage_span)
elif answering_ability == "question_span_extraction":
# Shape: (batch_size, # of answer spans)
gold_question_span_starts = answer_as_question_spans[:, :, 0]
gold_question_span_ends = answer_as_question_spans[:, :, 1]
# Some spans are padded with index -1,
# so we clamp those paddings to 0 and then mask after `torch.gather()`.
gold_question_span_mask = (gold_question_span_starts != -1).long()
clamped_gold_question_span_starts = util.replace_masked_values(
gold_question_span_starts, gold_question_span_mask, 0
)
clamped_gold_question_span_ends = util.replace_masked_values(
gold_question_span_ends, gold_question_span_mask, 0
)
# Shape: (batch_size, # of answer spans)
log_likelihood_for_question_span_starts = torch.gather(
question_span_start_log_probs, 1, clamped_gold_question_span_starts
)
log_likelihood_for_question_span_ends = torch.gather(
question_span_end_log_probs, 1, clamped_gold_question_span_ends
)
# Shape: (batch_size, # of answer spans)
log_likelihood_for_question_spans = (
log_likelihood_for_question_span_starts
+ log_likelihood_for_question_span_ends
)
# For the padded spans, we set their log probabilities to a very negative value
log_likelihood_for_question_spans = util.replace_masked_values(
log_likelihood_for_question_spans, gold_question_span_mask, -1e7
)
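# --- Illustration --------------------------------------------------------------------
# The marginalisation pattern above: padded gold spans get a log-likelihood of -1e7
# (probability ~ 0), so the subsequent util.logsumexp over the span axis effectively
# sums probability mass over the real gold spans only. With made-up values:
import torch

log_likelihood_for_spans = torch.tensor([[-1.2, -0.7, -1e7]])  # (batch, # spans); last span is padding
log_marginal_likelihood = torch.logsumexp(log_likelihood_for_spans, dim=-1)
# log(exp(-1.2) + exp(-0.7) + exp(-1e7)) ~= log(exp(-1.2) + exp(-0.7)) ~= -0.23
# ------------------------------------------------------------------------------------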
start_rep = self._span_start_encoder(final_merged_passage, repeated_passage_mask)
span_start_logits = self._span_start_predictor(start_rep).squeeze(-1)
end_rep = self._span_end_encoder(
torch.cat([final_merged_passage, start_rep], dim=-1), repeated_passage_mask
)
span_end_logits = self._span_end_predictor(end_rep).squeeze(-1)
span_yesno_logits = self._span_yesno_predictor(end_rep).squeeze(-1)
span_followup_logits = self._span_followup_predictor(end_rep).squeeze(-1)
span_start_logits = util.replace_masked_values(
span_start_logits, repeated_passage_mask, -1e7
)
# batch_size * maxqa_len_pair, max_document_len
span_end_logits = util.replace_masked_values(span_end_logits, repeated_passage_mask, -1e7)
best_span = self._get_best_span_yesno_followup(
span_start_logits,
span_end_logits,
span_yesno_logits,
span_followup_logits,
self._max_span_length,
)
output_dict: Dict[str, Any] = {}
# Compute the loss.
if span_start is not None:
    loss = nll_loss(
        util.masked_log_softmax(span_start_logits, repeated_passage_mask),
        span_start.view(-1),
        ignore_index=-1,
    )
# contextual_q = self._variational_dropout(self.contextual_layer_q(fused_q, question_mask))
contextual_q = self.contextual_layer_q(fused_q, question_mask)
# cnt * m
gamma = util.masked_softmax(self.linear_self_align(contextual_q).squeeze(2), question_mask, dim=1)
# cnt * h
weighted_q = torch.bmm(gamma.unsqueeze(1), contextual_q).squeeze(1)
span_start_logits = self.bilinear_layer_s(weighted_q, contextual_p)
span_end_logits = self.bilinear_layer_e(weighted_q, contextual_p)
# Shapes: (cnt, n, 1) and (cnt, 1, h)
span_yesno_logits = self.yesno_predictor(torch.bmm(span_end_logits.unsqueeze(2), weighted_q.unsqueeze(1)))
# span_yesno_logits = self.yesno_predictor(contextual_p)
span_start_logits = util.replace_masked_values(span_start_logits, repeated_passage_mask, -1e7)
span_end_logits = util.replace_masked_values(span_end_logits, repeated_passage_mask, -1e7)
best_span = self._get_best_span_yesno_followup(span_start_logits, span_end_logits, span_yesno_logits, self._max_span_length)
output_dict: Dict[str, Any] = {}
# Compute the loss for training
if span_start is not None:
    loss = nll_loss(util.masked_log_softmax(span_start_logits, repeated_passage_mask), span_start.view(-1), ignore_index=-1)
    self._span_start_accuracy(span_start_logits, span_start.view(-1), mask=qa_mask)
    loss += nll_loss(util.masked_log_softmax(span_end_logits, repeated_passage_mask), span_end.view(-1), ignore_index=-1)
    self._span_end_accuracy(span_end_logits, span_end.view(-1), mask=qa_mask)
    self._span_accuracy(best_span[:, 0:2],
                        torch.stack([span_start, span_end], -1).view(total_qa_count, 2),
                        mask=qa_mask.unsqueeze(1).expand(-1, 2).long())
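# --- Illustration (plain PyTorch, not the model code) --------------------------------
# The loss above is an NLL over the (masked) log-softmax of the start/end logits, with
# `ignore_index=-1` skipping question/answer pairs that have no gold span. A minimal
# example with an unmasked log_softmax standing in for util.masked_log_softmax:
import torch
from torch.nn.functional import log_softmax, nll_loss

span_start_logits = torch.randn(4, 10)          # (batch * qa pairs, passage_length)
span_start = torch.tensor([3, 7, -1, 0])        # -1 marks a pair without a gold start
loss = nll_loss(log_softmax(span_start_logits, dim=-1), span_start, ignore_index=-1)
# ------------------------------------------------------------------------------------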
# Shape: (extended_batch_size, passage_length, question_length)
# these are the a_ij in the paper
passage_question_similarity = self._matrix_attention(encoded_passage, encoded_question)
# Shape: (extended_batch_size, passage_length, question_length)
# these are the p_ij in the paper
passage_question_attention = util.last_dim_softmax(passage_question_similarity, question_mask)
# Shape: (extended_batch_size, passage_length, encoding_dim)
# these are the c_i in the paper
passage_question_vectors = util.weighted_sum(encoded_question, passage_question_attention)
# We replace masked values with something really negative here, so they don't affect the
# max below.
# Shape: (extended_batch_size, passage_length, question_length)
masked_similarity = util.replace_masked_values(passage_question_similarity,
question_mask.unsqueeze(1),
-1e7)
# Take the max over the last dimension (all question words)
# Shape: (extended_batch_size, passage_length)
question_passage_similarity = masked_similarity.max(dim=-1)[0]
# masked_softmax operates over the last (i.e. passage_length) dimension
# Shape: (extended_batch_size, passage_length)
question_passage_attention = util.masked_softmax(question_passage_similarity, passage_mask)
# Shape: (extended_batch_size, encoding_dim)
# these are the q_c in the paper
question_passage_vector = util.weighted_sum(encoded_passage, question_passage_attention)
# Shape: (extended_batch_size, passage_length, encoding_dim)
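# --- Illustration of the question-to-passage attention above -------------------------
# Padded question words are filled with a very negative value so they cannot win the
# max over the question dimension; the result is then softmaxed over the unmasked
# passage positions. With made-up tensors:
import torch

similarity = torch.randn(2, 6, 4)                                 # (batch, passage_len, question_len)
question_mask = torch.tensor([[1, 1, 1, 0], [1, 1, 1, 1]])
passage_mask = torch.tensor([[1, 1, 1, 1, 0, 0], [1, 1, 1, 1, 1, 1]])
masked_similarity = similarity.masked_fill(question_mask.unsqueeze(1) == 0, -1e7)
question_passage_similarity = masked_similarity.max(dim=-1)[0]    # (batch, passage_len)
question_passage_attention = torch.softmax(
    question_passage_similarity.masked_fill(passage_mask == 0, float("-inf")), dim=-1)
# ------------------------------------------------------------------------------------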
def _question_span_module(self, passage_vector, question_out, question_mask):
    # Shape: (batch_size, question_length)
    encoded_question_for_span_prediction = \
        torch.cat([question_out,
                   passage_vector.unsqueeze(1).repeat(1, question_out.size(1), 1)], -1)
    question_span_start_logits = \
        self._question_span_start_predictor(encoded_question_for_span_prediction).squeeze(-1)
    # Shape: (batch_size, question_length)
    question_span_end_logits = \
        self._question_span_end_predictor(encoded_question_for_span_prediction).squeeze(-1)
    question_span_start_log_probs = util.masked_log_softmax(question_span_start_logits, question_mask)
    question_span_end_log_probs = util.masked_log_softmax(question_span_end_logits, question_mask)
    # Info about the best question span prediction
    question_span_start_logits = \
        util.replace_masked_values(question_span_start_logits, question_mask, -1e7)
    question_span_end_logits = \
        util.replace_masked_values(question_span_end_logits, question_mask, -1e7)
    # Shape: (batch_size, 2)
    best_question_span = get_best_span(question_span_start_logits, question_span_end_logits)
    return question_span_start_log_probs, question_span_end_log_probs, best_question_span
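# --- Hypothetical usage note (shapes inferred from the code above; illustrative only) -
# The helper tiles the pooled passage vector across question positions so each question
# token sees the passage summary when scoring span boundaries, e.g.:
#     start_log_probs, end_log_probs, best_span = self._question_span_module(
#         passage_vector,   # (batch_size, encoding_dim)
#         question_out,     # (batch_size, question_length, encoding_dim)
#         question_mask,    # (batch_size, question_length)
#     )
# ------------------------------------------------------------------------------------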
clamped_gold_passage_span_ends = util.replace_masked_values(
gold_passage_span_ends, gold_passage_span_mask, 0
)
# Shape: (batch_size, # of answer spans)
log_likelihood_for_passage_span_starts = torch.gather(
passage_span_start_log_probs, 1, clamped_gold_passage_span_starts
)
log_likelihood_for_passage_span_ends = torch.gather(
passage_span_end_log_probs, 1, clamped_gold_passage_span_ends
)
# Shape: (batch_size, # of answer spans)
log_likelihood_for_passage_spans = (
log_likelihood_for_passage_span_starts
+ log_likelihood_for_passage_span_ends
)
# For the padded spans, we set their log probabilities to a very negative value
log_likelihood_for_passage_spans = util.replace_masked_values(
log_likelihood_for_passage_spans, gold_passage_span_mask, -1e7
)
# Shape: (batch_size, )
log_marginal_likelihood_for_passage_span = util.logsumexp(
log_likelihood_for_passage_spans
)
log_marginal_likelihood_list.append(log_marginal_likelihood_for_passage_span)