####################################
# Perform Q by A attention
# [batch_size, 4, question_length, answer_length]
qa_similarity = self.span_attention(
q_rep.view(q_rep.shape[0] * q_rep.shape[1], q_rep.shape[2], q_rep.shape[3]),
a_rep.view(a_rep.shape[0] * a_rep.shape[1], a_rep.shape[2], a_rep.shape[3]),
).view(a_rep.shape[0], a_rep.shape[1], q_rep.shape[2], a_rep.shape[2])
qa_attention_weights = masked_softmax(qa_similarity, question_mask[..., None], dim=2)
attended_q = torch.einsum('bnqa,bnqd->bnad', (qa_attention_weights, q_rep))
# Have a second attention over the objects, do A by Objs
# [batch_size, 4, answer_length, num_objs]
atoo_similarity = self.obj_attention(a_rep.view(a_rep.shape[0], a_rep.shape[1] * a_rep.shape[2], -1),
obj_reps['obj_reps']).view(a_rep.shape[0], a_rep.shape[1],
a_rep.shape[2], obj_reps['obj_reps'].shape[1])
atoo_attention_weights = masked_softmax(atoo_similarity, box_mask[:,None,None])
attended_o = torch.einsum('bnao,bod->bnad', (atoo_attention_weights, obj_reps['obj_reps']))
reasoning_inp = torch.cat([x for x, to_pool in [(a_rep, self.reasoning_use_answer),
(attended_o, self.reasoning_use_obj),
(attended_q, self.reasoning_use_question)]
if to_pool], -1)
if self.rnn_input_dropout is not None:
reasoning_inp = self.rnn_input_dropout(reasoning_inp)
reasoning_output = self.reasoning_encoder(reasoning_inp, answer_mask)
###########################################
# The snippet was truncated mid-expression here; the closing of the torch.cat is
# restored to mirror the reasoning_inp construction above (the third flag,
# self.pool_question, is inferred).
things_to_pool = torch.cat([x for x, to_pool in [(reasoning_output, self.pool_reasoning),
                                                 (a_rep, self.pool_answer),
                                                 (attended_q, self.pool_question)]
                            if to_pool], -1)
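# ---------------------------------------------------------------------------
# Illustrative sketch (not part of the model above): how a masked softmax over a
# similarity tensor plus an einsum contraction implements the Q-by-A attention at
# the top of the snippet. The toy shapes and tensors are assumptions made only
# for the demo; the masked softmax is re-implemented inline in plain PyTorch.
import torch

def demo_masked_softmax(scores, mask, dim):
    # Padded positions get a large negative score so they receive ~zero weight.
    return torch.softmax(scores.masked_fill(mask == 0, -1e9), dim=dim)

batch, n_answers, q_len, a_len, dim_h = 2, 4, 5, 7, 16
q_rep = torch.randn(batch, n_answers, q_len, dim_h)
qa_similarity = torch.randn(batch, n_answers, q_len, a_len)   # Q-by-A scores
question_mask = torch.ones(batch, n_answers, q_len)
question_mask[:, :, -2:] = 0        # pretend the last two question tokens are padding

# Normalise over the question dimension (dim=2), as in the snippet.
qa_weights = demo_masked_softmax(qa_similarity, question_mask[..., None], dim=2)
# For every answer token, take a weighted sum of question representations.
attended_q = torch.einsum('bnqa,bnqd->bnad', qa_weights, q_rep)
print(attended_q.shape)             # torch.Size([2, 4, 7, 16])
# ---------------------------------------------------------------------------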
####################################
# Attention pooling ("summary vector") over an encoding; in_type selects which
# weight predictor to use. The opening branch is restored here (the snippet was
# truncated before it).
if in_type == "passage":
    # Shape: (batch_size, seqlen)
    alpha = self._passage_weights_predictor(encoding).squeeze()
elif in_type == "question":
    # Shape: (batch_size, seqlen)
    alpha = self._question_weights_predictor(encoding).squeeze()
elif in_type == "arithmetic":
    # Shape: (batch_size, seqlen)
    alpha = self._arithmetic_weights_predictor(encoding).squeeze()
else:
    # Shape: (batch_size, # of numbers, seqlen)
    alpha = torch.zeros(encoding.shape[:-1], device=encoding.device)
    if self.number_rep == 'attention':
        alpha = self._number_weights_predictor(encoding).squeeze()
# Shape: (batch_size, seqlen),
# or (batch_size, # of numbers, seqlen) for numbers
alpha = masked_softmax(alpha, mask)
# Shape: (batch_size, out),
# or (batch_size, # of numbers, out) for numbers
h = util.weighted_sum(encoding, alpha)
return h
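# ---------------------------------------------------------------------------
# Minimal, self-contained sketch of the pooling pattern above: score each
# position, run a masked softmax, then take a weighted sum. Names and shapes are
# assumptions for the demo only; the weighted sum is written out inline rather
# than imported from allennlp.nn.util.
import torch
import torch.nn as nn

def summary_vector_demo(encoding, mask, scorer):
    # encoding: (batch, seqlen, dim); mask: (batch, seqlen), 1 marks real tokens.
    alpha = scorer(encoding).squeeze(-1)                       # (batch, seqlen)
    alpha = torch.softmax(alpha.masked_fill(mask == 0, -1e9), dim=-1)
    return torch.bmm(alpha.unsqueeze(1), encoding).squeeze(1)  # (batch, dim)

scorer = nn.Linear(8, 1)
encoding = torch.randn(3, 6, 8)
mask = torch.tensor([[1, 1, 1, 1, 0, 0]] * 3)
print(summary_vector_demo(encoding, mask, scorer).shape)       # torch.Size([3, 8])
# ---------------------------------------------------------------------------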
####################################
# The recurrent modeling layers. Since these layers share the same parameters,
# we don't construct them conditioned on answering abilities.
modeled_passage_list = [self._modeling_proj_layer(merged_passage_attention_vectors)]
for _ in range(4):
modeled_passage = self._dropout(
self._modeling_layer(modeled_passage_list[-1], passage_mask)
)
modeled_passage_list.append(modeled_passage)
# Pop the first one, which is input
modeled_passage_list.pop(0)
# The first modeling layer is used to calculate the vector representation of passage
passage_weights = self._passage_weights_predictor(modeled_passage_list[0]).squeeze(-1)
passage_weights = masked_softmax(passage_weights, passage_mask)
passage_vector = util.weighted_sum(modeled_passage_list[0], passage_weights)
# The vector representation of question is calculated based on the unmatched encoding,
# because we may want to infer the answer ability only based on the question words.
question_weights = self._question_weights_predictor(encoded_question).squeeze(-1)
question_weights = masked_softmax(question_weights, question_mask)
question_vector = util.weighted_sum(encoded_question, question_weights)
if len(self.answering_abilities) > 1:
# Shape: (batch_size, number_of_abilities)
answer_ability_logits = self._answer_ability_predictor(
torch.cat([passage_vector, question_vector], -1)
)
answer_ability_log_probs = torch.nn.functional.log_softmax(answer_ability_logits, -1)
best_answer_ability = torch.argmax(answer_ability_log_probs, 1)
if "counting" in self.answering_abilities:
    ...  # the body of this branch was truncated in the snippet
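# ---------------------------------------------------------------------------
# Sketch (assumed toy shapes, not the original module): predicting which
# answering ability to use from the pooled passage and question vectors, as in
# the _answer_ability_predictor call above. nn.Linear stands in for the model's
# feed-forward predictor.
import torch
import torch.nn as nn

batch, dim, n_abilities = 2, 16, 4
passage_vector = torch.randn(batch, dim)
question_vector = torch.randn(batch, dim)
answer_ability_predictor = nn.Linear(2 * dim, n_abilities)

answer_ability_logits = answer_ability_predictor(
    torch.cat([passage_vector, question_vector], -1))
answer_ability_log_probs = torch.log_softmax(answer_ability_logits, -1)
best_answer_ability = torch.argmax(answer_ability_log_probs, 1)
print(best_answer_ability.shape)    # torch.Size([2])
# ---------------------------------------------------------------------------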
####################################
# Expected-risk-minimization decoding. The parameter list of this method was
# truncated in the snippet; it is restored from the names used in the body
# (the original type annotations are not recoverable here).
def decode(self, initial_state, transition_function,
           supervision) -> Dict[str, torch.Tensor]:
cost_function = supervision
finished_states = self._get_finished_states(initial_state, transition_function)
loss = initial_state.score[0].new_zeros(1)
finished_model_scores = self._get_model_scores_by_batch(finished_states)
finished_costs = self._get_costs_by_batch(finished_states, cost_function)
for batch_index in finished_model_scores:
# Finished model scores are log-probabilities of the predicted sequences. We convert
# log probabilities into probabilities and re-normalize them to compute expected cost under
# the distribution approximated by the beam search.
costs = torch.cat([tensor.view(-1) for tensor in finished_costs[batch_index]])
logprobs = torch.cat([tensor.view(-1) for tensor in finished_model_scores[batch_index]])
# Unmasked softmax of log probabilities will convert them into probabilities and
# renormalize them.
renormalized_probs = nn_util.masked_softmax(logprobs, None)
loss += renormalized_probs.dot(costs)
mean_loss = loss / len(finished_model_scores)
return {
"loss": mean_loss,
"best_final_states": self._get_best_final_states(finished_states),
}
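# ---------------------------------------------------------------------------
# Tiny numeric sketch of the expected-cost loss above: a softmax over the beam's
# sequence log-probabilities re-normalizes them into a distribution, and the dot
# product with the costs gives the expected cost. The values are made up.
import torch

logprobs = torch.tensor([-0.5, -1.2, -2.0])   # log P(sequence) for 3 finished beam states
costs = torch.tensor([1.0, 0.0, 2.0])         # cost of each predicted sequence
renormalized_probs = torch.softmax(logprobs, dim=-1)  # same as masked_softmax with mask=None
loss = renormalized_probs.dot(costs)
print(renormalized_probs, loss)
# ---------------------------------------------------------------------------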
####################################
# Simple Pooling layers
max_masked_integrated_encodings = util.replace_masked_values(
integrated_encodings, text_mask.unsqueeze(2), -1e7)
max_pool = torch.max(max_masked_integrated_encodings, 1)[0]
min_masked_integrated_encodings = util.replace_masked_values(
integrated_encodings, text_mask.unsqueeze(2), +1e7)
min_pool = torch.min(min_masked_integrated_encodings, 1)[0]
mean_pool = torch.sum(integrated_encodings, 1) / torch.sum(text_mask, 1, keepdim=True)
# Self-attentive pooling layer
# Run through linear projection. Shape: (batch_size, sequence length, 1)
# Then remove the last dimension to get the proper attention shape (batch_size, sequence length).
self_attentive_logits = self._self_attentive_pooling_projection(
integrated_encodings).squeeze(2)
self_weights = util.masked_softmax(self_attentive_logits, text_mask)
self_attentive_pool = util.weighted_sum(integrated_encodings, self_weights)
pooled_representations = torch.cat([max_pool, min_pool, mean_pool, self_attentive_pool], 1)
pooled_representations_dropped = self._integrator_dropout(pooled_representations)
logits = self._output_layer(pooled_representations_dropped)
class_probabilities = F.softmax(logits, dim=-1)
output_dict = {'logits': logits, 'class_probabilities': class_probabilities}
if label is not None:
loss = self.loss(logits, label)
for metric in list(self.metrics.values()):
metric(logits, label)
output_dict["loss"] = loss
return output_dict
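# ---------------------------------------------------------------------------
# Sketch of the masked pooling tricks above in plain PyTorch (toy shapes, no
# allennlp): masked positions are pushed toward -1e7 before the max and toward
# +1e7 before the min so padding never wins either reduction; the mean is taken
# over real tokens only (the usual safe variant of the mean pool).
import torch

batch, seqlen, dim = 2, 5, 4
encodings = torch.randn(batch, seqlen, dim)
text_mask = torch.tensor([[1, 1, 1, 0, 0],
                          [1, 1, 1, 1, 1]]).float()

mask3 = text_mask.unsqueeze(2)                                # (batch, seqlen, 1)
max_pool = (encodings * mask3 + (1 - mask3) * -1e7).max(1)[0]
min_pool = (encodings * mask3 + (1 - mask3) * 1e7).min(1)[0]
mean_pool = (encodings * mask3).sum(1) / text_mask.sum(1, keepdim=True)
print(max_pool.shape, min_pool.shape, mean_pool.shape)        # each (2, 4)
# ---------------------------------------------------------------------------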
####################################
# Shape: (batch_size, passage_length, question_length)
passage_question_similarity = self._matrix_attention(encoded_passage, encoded_question)
# Shape: (batch_size, passage_length, question_length)
passage_question_attention = util.masked_softmax(passage_question_similarity, question_mask)
# Shape: (batch_size, passage_length, encoding_dim)
passage_question_vectors = util.weighted_sum(encoded_question, passage_question_attention)
# We replace masked values with something really negative here, so they don't affect the
# max below.
masked_similarity = util.replace_masked_values(passage_question_similarity,
question_mask.unsqueeze(1),
-1e7)
# Shape: (batch_size, passage_length)
question_passage_similarity = masked_similarity.max(dim=-1)[0].squeeze(-1)
# Shape: (batch_size, passage_length)
question_passage_attention = util.masked_softmax(question_passage_similarity, passage_mask)
# Shape: (batch_size, encoding_dim)
question_passage_vector = util.weighted_sum(encoded_passage, question_passage_attention)
# Shape: (batch_size, passage_length, encoding_dim)
tiled_question_passage_vector = question_passage_vector.unsqueeze(1).expand(batch_size,
passage_length,
encoding_dim)
# Shape: (batch_size, passage_length, encoding_dim * 4)
final_merged_passage = torch.cat([encoded_passage,
passage_question_vectors,
encoded_passage * passage_question_vectors,
encoded_passage * tiled_question_passage_vector],
dim=-1)
modeled_passage = self._dropout_modeling_passage(self._modeling_layer(final_merged_passage, passage_lstm_mask))
modeling_dim = modeled_passage.size(-1)
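# ---------------------------------------------------------------------------
# Compact sketch of the bidirectional attention above (toy tensors, plain
# PyTorch; a simple dot product stands in for the learned matrix attention):
# passage-to-question attention attends over question words for each passage
# word; question-to-passage attention takes the best-matching question score per
# passage word, attends over the passage, and tiles the result back over the
# passage before the 4-way concatenation.
import torch

batch, p_len, q_len, dim = 2, 6, 4, 8
encoded_passage = torch.randn(batch, p_len, dim)
encoded_question = torch.randn(batch, q_len, dim)
question_mask = torch.tensor([[1, 1, 1, 0]] * batch).float()   # no passage padding in this toy

similarity = torch.bmm(encoded_passage, encoded_question.transpose(1, 2))        # (batch, p_len, q_len)
p2q_attention = torch.softmax(similarity + (1 - question_mask[:, None, :]) * -1e7, dim=-1)
passage_question_vectors = torch.bmm(p2q_attention, encoded_question)            # (batch, p_len, dim)

q2p_scores = (similarity + (1 - question_mask[:, None, :]) * -1e7).max(-1)[0]    # (batch, p_len)
q2p_attention = torch.softmax(q2p_scores, dim=-1)
question_passage_vector = torch.bmm(q2p_attention.unsqueeze(1), encoded_passage).squeeze(1)
tiled = question_passage_vector.unsqueeze(1).expand(batch, p_len, dim)

final_merged_passage = torch.cat([encoded_passage,
                                  passage_question_vectors,
                                  encoded_passage * passage_question_vectors,
                                  encoded_passage * tiled], dim=-1)
print(final_merged_passage.shape)   # torch.Size([2, 6, 32])
# ---------------------------------------------------------------------------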
####################################
# options : (batch, opnumlen, bert_dim)
# options_mask : (batch, opnumlen)
# bert_out : (batch, seqlen, bert_dim)
# bert_mask: (batch, seqlen)
# summary_vector : (batch, explen, bert_dim)
summary_vector = summary_vector.unsqueeze(1).expand(-1,maxlen,-1)
# out: (batch, explen, rnn_hdim)
out, _ = self.rnn(summary_vector)
out = self.rnndropout(out)
out = self.Wst(out)
# alpha : (batch, explen, seqlen)
alpha = torch.bmm(out, bert_out.transpose(1,2))
alpha = util.masked_softmax(alpha, bert_mask)
# context : (batch, explen, bert_dim)
context = util.weighted_sum(bert_out, alpha)
# context = self.Wcon(context)
# logits : (batch, explen, opnumlen)
logits = torch.bmm(context, options.transpose(1,2))
logits = util.replace_masked_values(logits, options_mask.unsqueeze(1).expand_as(logits), -1e7)
number_mask = options_mask.clone()
number_mask[:,:self.num_ops] = 0
op_mask = options_mask.clone()
op_mask[:,self.num_ops:] = 0
best_expression = beam_search(self.arithmetic_K, logits.softmax(-1),
                              number_mask, op_mask, self.END, self.num_ops)
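# ---------------------------------------------------------------------------
# Sketch of the option masking above (toy sizes): the first num_ops slots of the
# options axis are operators and the rest are numbers, so zeroing one half of a
# copied mask yields an operator-only mask and a number-only mask, and padded
# options are pushed to -1e7 so they can never be chosen.
import torch

batch, num_ops, opnumlen = 1, 3, 8
options_mask = torch.tensor([[1, 1, 1, 1, 1, 1, 0, 0]])   # last two number slots are padding

number_mask = options_mask.clone()
number_mask[:, :num_ops] = 0        # keep only the number slots
op_mask = options_mask.clone()
op_mask[:, num_ops:] = 0            # keep only the operator slots

logits = torch.randn(batch, 4, opnumlen)                  # (batch, explen, opnumlen)
masked_logits = logits.masked_fill(options_mask.unsqueeze(1).expand_as(logits) == 0, -1e7)
print(number_mask, op_mask, masked_logits.softmax(-1)[0, 0])
# ---------------------------------------------------------------------------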
####################################
# A later variant of the summary-vector pooling above, with an extra branch for
# multi-span answers. The opening "passage" branch is restored from the earlier
# snippet (this copy was truncated before it).
if in_type == "passage":
    # Shape: (batch_size, seqlen)
    alpha = self._passage_weights_predictor(encoding).squeeze()
elif in_type == "question":
    # Shape: (batch_size, seqlen)
    alpha = self._question_weights_predictor(encoding).squeeze()
elif in_type == "arithmetic":
    # Shape: (batch_size, seqlen)
    alpha = self._arithmetic_weights_predictor(encoding).squeeze()
elif in_type == "multiple_spans":
    # TODO: currently not using it...
    alpha = self._multispan_weights_predictor(encoding).squeeze()
else:
    # Shape: (batch_size, # of numbers, seqlen)
    alpha = torch.zeros(encoding.shape[:-1], device=encoding.device)
    if self.number_rep == 'attention':
        alpha = self._number_weights_predictor(encoding).squeeze()
# Shape: (batch_size, seqlen),
# or (batch_size, # of numbers, seqlen) for numbers
alpha = masked_softmax(alpha, mask)
# Shape: (batch_size, out),
# or (batch_size, # of numbers, out) for numbers
h = util.weighted_sum(encoding, alpha)
return h
####################################
def _question_pooling(self, question, question_mask):
    # The softmax and the sum both run over dim=0, i.e. the sequence-length axis
    # in this length-first layout. A learned vector V_q is concatenated to every
    # position, scored by a linear layer, and the masked softmax over the length
    # dimension gives the pooling weights.
    V_q = self.V_q.expand(question.size(0), question.size(1), -1)
    logits = self.question_linear(torch.cat([question, V_q], dim=-1)).squeeze(-1)
    score = masked_softmax(logits, question_mask, dim=0)
    state = torch.sum(score.unsqueeze(-1) * question, dim=0)
    return state
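# ---------------------------------------------------------------------------
# Sketch of the length-first pooling above, assuming a (length, batch, dim)
# layout with toy tensors: a learned query vector is concatenated to every
# position, a linear layer scores each position, and a masked softmax over dim=0
# produces the pooling weights.
import torch
import torch.nn as nn

q_len, batch, dim = 5, 2, 8
question = torch.randn(q_len, batch, dim)
question_mask = torch.tensor([[1, 1], [1, 1], [1, 0], [0, 0], [0, 0]]).float()
V_q = nn.Parameter(torch.randn(1, 1, dim))
question_linear = nn.Linear(2 * dim, 1)

v = V_q.expand(q_len, batch, -1)
logits = question_linear(torch.cat([question, v], dim=-1)).squeeze(-1)   # (q_len, batch)
score = torch.softmax(logits + (1 - question_mask) * -1e9, dim=0)
state = torch.sum(score.unsqueeze(-1) * question, dim=0)                 # (batch, dim)
print(state.shape)   # torch.Size([2, 8])
# ---------------------------------------------------------------------------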