# Simple Pooling layers
max_masked_integrated_encodings = util.replace_masked_values(
    integrated_encodings, sentence_mask.unsqueeze(2), -1e7)
max_pool = torch.max(max_masked_integrated_encodings, 1)[0]
min_masked_integrated_encodings = util.replace_masked_values(
    integrated_encodings, sentence_mask.unsqueeze(2), +1e7)
min_pool = torch.min(min_masked_integrated_encodings, 1)[0]
mean_pool = torch.sum(integrated_encodings, 1) / torch.sum(sentence_mask, 1, keepdim=True)
# Self-attentive pooling layer
# Run through linear projection. Shape: (batch_size, sequence length, 1)
# Then remove the last dimension to get the proper attention shape (batch_size, sequence length).
self_attentive_logits = self._self_attentive_pooling_projection(integrated_encodings).squeeze(2)
self_weights = util.masked_softmax(self_attentive_logits, sentence_mask)
self_attentive_pool = util.weighted_sum(integrated_encodings, self_weights)
pooled_representations = torch.cat([max_pool, min_pool, mean_pool, self_attentive_pool], 1)
pooled_representations_dropped = self._integrator_dropout(pooled_representations).squeeze(1)
logits = self._output_layer(pooled_representations_dropped)
output_dict = {'logits': logits}
if label is not None:
    loss = self.loss(logits, label.squeeze(-1))
    for metric in self.metrics.values():
        metric(logits, label.squeeze(-1))
    output_dict["loss"] = loss
return output_dict
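# Illustration (not from the original model): a minimal standalone sketch of the masked
# pooling trick above, in plain torch. Padded positions are overwritten with -1e7 (for max)
# or +1e7 (for min) so they can never win the reduction over the time dimension.
import torch

def masked_max_min_mean(encodings: torch.Tensor, mask: torch.Tensor):
    # encodings: (batch, seq_len, dim); mask: (batch, seq_len), 1.0 = real token, 0.0 = padding.
    pad = mask.unsqueeze(2) == 0
    max_pool = encodings.masked_fill(pad, -1e7).max(dim=1)[0]
    min_pool = encodings.masked_fill(pad, 1e7).min(dim=1)[0]
    mean_pool = (encodings * mask.unsqueeze(2)).sum(dim=1) / mask.sum(dim=1, keepdim=True).clamp(min=1)
    return max_pool, min_pool, mean_pool

# Example: two sequences, the second padded to length 4; each pooled output is (2, 8).
enc = torch.randn(2, 4, 8)
msk = torch.tensor([[1., 1., 1., 1.], [1., 1., 0., 0.]])
mx, mn, mean = masked_max_min_mean(enc, msk)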
    # Shape: (batch_size, seqlen)
    alpha = self._question_weights_predictor(encoding).squeeze()
elif in_type == "arithmetic":
    # Shape: (batch_size, seqlen)
    alpha = self._arithmetic_weights_predictor(encoding).squeeze()
else:
    # Shape: (batch_size, #num of numbers, seqlen)
    alpha = torch.zeros(encoding.shape[:-1], device=encoding.device)
    if self.number_rep == 'attention':
        alpha = self._number_weights_predictor(encoding).squeeze()
# Shape: (batch_size, seqlen)
# (batch_size, #num of numbers, seqlen) for numbers
alpha = masked_softmax(alpha, mask)
# Shape: (batch_size, out)
# (batch_size, #num of numbers, out) for numbers
h = util.weighted_sum(encoding, alpha)
return h
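# Illustration (not from the original code): what the masked_softmax + weighted_sum pair
# above computes in the 2-D-weights case, written with plain torch. Padded positions get
# a huge negative score so they receive ~zero attention mass.
import torch

def summary_vector_sketch(encoding: torch.Tensor, scores: torch.Tensor, mask: torch.Tensor):
    # encoding: (batch, seq_len, dim); scores, mask: (batch, seq_len).
    alpha = torch.softmax(scores.masked_fill(mask == 0, -1e32), dim=-1)  # (batch, seq_len)
    return torch.bmm(alpha.unsqueeze(1), encoding).squeeze(1)            # (batch, dim)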
    neighbor_mask = util.get_text_field_mask({'ignored': neighbor_indices + 1},
                                             num_wrapping_dims=1).float()
    # Encoder initialized to easily obtain a masked average.
    neighbor_encoder = TimeDistributed(BagOfEmbeddingsEncoder(self._embedding_dim, averaged=True))
    # (batch_size, num_entities, embedding_dim)
    embedded_neighbors = neighbor_encoder(embedded_neighbors, neighbor_mask)
    projected_neighbor_embeddings = self._neighbor_params(embedded_neighbors.float())
    # (batch_size, num_entities, embedding_dim)
    entity_embeddings = torch.tanh(entity_type_embeddings + projected_neighbor_embeddings)
else:
    # (batch_size, num_entities, embedding_dim)
    entity_embeddings = torch.tanh(entity_type_embeddings)
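# Illustration (not from the original code): the masked neighbor average that
# TimeDistributed(BagOfEmbeddingsEncoder(..., averaged=True)) produces above, spelled out
# in plain torch; the shapes are assumptions for the schema-entity setting.
import torch

def masked_neighbor_average(embedded_neighbors: torch.Tensor, neighbor_mask: torch.Tensor):
    # embedded_neighbors: (batch, num_entities, num_neighbors, dim)
    # neighbor_mask:      (batch, num_entities, num_neighbors), 1.0 for real neighbors.
    mask = neighbor_mask.unsqueeze(-1)
    summed = (embedded_neighbors * mask).sum(dim=2)      # (batch, num_entities, dim)
    counts = mask.sum(dim=2).clamp(min=1)                # avoid division by zero for isolated entities
    return summed / counts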
link_embedding = util.weighted_sum(entity_embeddings, linking_probabilities)
encoder_input = torch.cat([link_embedding, embedded_utterance], 2)
# (batch_size, utterance_length, encoder_output_dim)
encoder_outputs = self._dropout(self._encoder(encoder_input, utterance_mask))
# compute the relevance of each entity with the relevance GNN
ent_relevance, ent_relevance_logits, ent_to_qst_lnk_probs = self._graph_pruning(
    worlds,
    encoder_outputs,
    entity_type_embeddings,
    linking_scores,
    utterance_mask,
    self._get_graph_adj_lists)
# save this for loss calculation
self.predicted_relevance_logits = ent_relevance_logits
# multiply the embedding with the computed relevance
hypothesis_mask: torch.Tensor = None,
premise_token_weights: torch.Tensor = None) -> Dict[str, torch.Tensor]: # pylint: disable=unused-argument
p2h_attention = masked_softmax(similarity_matrix, hypothesis_mask)
# Shape: (batch_size, premise_length, embedding_dim)
attended_hypothesis = weighted_sum(encoded_hypothesis, p2h_attention)
h2p_attention = masked_softmax(similarity_matrix.transpose(1, 2).contiguous(), premise_mask)
if premise_token_weights is not None:
    h2p_attention = premise_token_weights.unsqueeze(1) * h2p_attention
    h2p_attention = masked_divide(h2p_attention, h2p_attention.sum(dim=-1, keepdim=True))
# Shape: (batch_size, hypothesis_length, embedding_dim)
attended_premise = weighted_sum(encoded_premise, h2p_attention)
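# Illustration (not from the original code): the premise_token_weights re-normalization
# used above, in plain torch. Each hypothesis-to-premise attention row is rescaled by
# per-premise-token weights and renormalized to sum to one; masked_divide above is
# presumably a zero-safe division, approximated here with a clamp.
import torch

def reweight_attention(h2p_attention: torch.Tensor, premise_token_weights: torch.Tensor):
    # h2p_attention:         (batch, hypothesis_len, premise_len), rows sum to 1.
    # premise_token_weights: (batch, premise_len), e.g. down-weighting uninformative tokens.
    reweighted = premise_token_weights.unsqueeze(1) * h2p_attention
    return reweighted / reweighted.sum(dim=-1, keepdim=True).clamp(min=1e-13)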
# the "enhancement" layer
premise_enhanced = torch.cat([encoded_premise, attended_hypothesis,
                              encoded_premise - attended_hypothesis,
                              encoded_premise * attended_hypothesis], dim=-1)
hypothesis_enhanced = torch.cat([encoded_hypothesis, attended_premise,
                                 encoded_hypothesis - attended_premise,
                                 encoded_hypothesis * attended_premise], dim=-1)
# The projection layer down to the model dimension. Dropout is not applied before
# projection.
projected_enhanced_premise = self._projection_feedforward(premise_enhanced)
projected_enhanced_hypothesis = self._projection_feedforward(hypothesis_enhanced)
# Run the inference layer
if self.rnn_input_dropout:
# Shape: (batch_size, passage_length, question_length)
passage_question_attention = util.masked_softmax(passage_question_similarity, question_mask)
# Shape: (batch_size, passage_length, encoding_dim)
passage_question_vectors = util.weighted_sum(encoded_question, passage_question_attention)
# We replace masked values with something really negative here, so they don't affect the
# max below.
masked_similarity = util.replace_masked_values(passage_question_similarity,
                                               question_mask.unsqueeze(1),
                                               -1e7)
# Shape: (batch_size, passage_length)
question_passage_similarity = masked_similarity.max(dim=-1)[0].squeeze(-1)
# Shape: (batch_size, passage_length)
question_passage_attention = util.masked_softmax(question_passage_similarity, passage_mask)
# Shape: (batch_size, encoding_dim)
question_passage_vector = util.weighted_sum(encoded_passage, question_passage_attention)
# Shape: (batch_size, passage_length, encoding_dim)
tiled_question_passage_vector = question_passage_vector.unsqueeze(1).expand(batch_size,
                                                                            passage_length,
                                                                            encoding_dim)
# Shape: (batch_size, passage_length, encoding_dim * 4)
final_merged_passage = torch.cat([encoded_passage,
                                  passage_question_vectors,
                                  encoded_passage * passage_question_vectors,
                                  encoded_passage * tiled_question_passage_vector],
                                 dim=-1)
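# Illustration (not from the original code): the question-to-passage ("query2context")
# attention path above in plain torch: take the best-matching question word for each
# passage word, softmax those maxima over the passage, and pool the passage with them.
import torch

def question_to_passage_vector(similarity: torch.Tensor,
                               question_mask: torch.Tensor,
                               passage_mask: torch.Tensor,
                               encoded_passage: torch.Tensor):
    # similarity: (batch, passage_len, question_len); encoded_passage: (batch, passage_len, dim).
    masked_sim = similarity.masked_fill(question_mask.unsqueeze(1) == 0, -1e7)
    scores = masked_sim.max(dim=-1)[0].masked_fill(passage_mask == 0, -1e7)   # (batch, passage_len)
    attention = torch.softmax(scores, dim=-1)
    return torch.bmm(attention.unsqueeze(1), encoded_passage).squeeze(1)      # (batch, dim)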
"""
modeling_layer : ``Seq2SeqEncoder``
The encoder (with its own internal stacking) that we will use in between the bidirectional
attention and predicting span start and end.
# multiply the embedding with the computed relevance
graph_initial_embedding = entity_type_embeddings * ent_relevance
encoder_output_dim = self._encoder.get_output_dim()
if self._gnn:
    entities_graph_encoding = self._get_schema_graph_encoding(worlds,
                                                              graph_initial_embedding)
    graph_link_embedding = util.weighted_sum(entities_graph_encoding, linking_probabilities)
    encoder_outputs = torch.cat((
        encoder_outputs,
        graph_link_embedding
    ), dim=-1)
    encoder_output_dim = self._action_embedding_dim + self._encoder.get_output_dim()
else:
    entities_graph_encoding = None

if self._self_attend:
    # linked_actions_linking_scores = self._get_linked_actions_linking_scores(actions, entities_graph_encoding)
    entities_ff = self._ent2ent_ff(entities_graph_encoding)
    linked_actions_linking_scores = torch.bmm(entities_ff, entities_ff.transpose(1, 2))
else:
    linked_actions_linking_scores = [None] * batch_size
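# Illustration (not from the original code): the entity-to-entity linking scores above
# reduce to projecting each entity's graph encoding and taking all pairwise dot products
# with a batched matmul; the projection module here is an assumption (e.g. nn.Linear).
import torch

def entity_self_attention_scores(entity_encodings: torch.Tensor, projection: torch.nn.Module):
    # entity_encodings: (batch, num_entities, dim)
    projected = projection(entity_encodings)                 # (batch, num_entities, proj_dim)
    return torch.bmm(projected, projected.transpose(1, 2))   # (batch, num_entities, num_entities)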
# This will be our initial hidden state and memory cell for the decoder LSTM.
passage_question_similarity = self._matrix_attention(encoded_passage, encoded_question)
# Shape: (batch_size, passage_length, question_length)
passage_question_attention = masked_softmax(
    passage_question_similarity, question_mask, memory_efficient=True
)
# Shape: (batch_size, passage_length, encoding_dim)
passage_question_vectors = util.weighted_sum(encoded_question, passage_question_attention)
# Shape: (batch_size, question_length, passage_length)
question_passage_attention = masked_softmax(
    passage_question_similarity.transpose(1, 2), passage_mask, memory_efficient=True
)
# Shape: (batch_size, passage_length, passage_length)
attention_over_attention = torch.bmm(passage_question_attention, question_passage_attention)
# Shape: (batch_size, passage_length, encoding_dim)
passage_passage_vectors = util.weighted_sum(encoded_passage, attention_over_attention)
# Shape: (batch_size, passage_length, encoding_dim * 4)
merged_passage_attention_vectors = self._dropout(
    torch.cat(
        [
            encoded_passage,
            passage_question_vectors,
            encoded_passage * passage_question_vectors,
            encoded_passage * passage_passage_vectors,
        ],
        dim=-1,
    )
)
modeled_passage_list = [self._modeling_proj_layer(merged_passage_attention_vectors)]
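# Illustration (not from the original code): the attention-over-attention step above in
# plain torch: chaining the passage->question and question->passage attention matrices
# with a batched matmul yields a passage->passage attention used to pool the passage itself.
import torch

def attention_over_attention_pool(p2q_attention: torch.Tensor,
                                  q2p_attention: torch.Tensor,
                                  encoded_passage: torch.Tensor):
    # p2q_attention: (batch, passage_len, question_len), rows sum to 1.
    # q2p_attention: (batch, question_len, passage_len), rows sum to 1.
    p2p_attention = torch.bmm(p2q_attention, q2p_attention)  # (batch, passage_len, passage_len)
    return torch.bmm(p2p_attention, encoded_passage)         # (batch, passage_len, dim)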
passage_length = embedded_passage.size(1)
question_mask = util.get_text_field_mask(question).float()
passage_mask = util.get_text_field_mask(passage).float()
question_lstm_mask = question_mask if self._mask_lstms else None
passage_lstm_mask = passage_mask if self._mask_lstms else None
encoded_question = self._dropout(self._phrase_layer(embedded_question, question_lstm_mask))
encoded_passage = self._dropout(self._phrase_layer(embedded_passage, passage_lstm_mask))
encoding_dim = encoded_question.size(-1)
# Shape: (batch_size, passage_length, question_length)
passage_question_similarity = self._matrix_attention(encoded_passage, encoded_question)
# Shape: (batch_size, passage_length, question_length)
passage_question_attention = util.masked_softmax(passage_question_similarity, question_mask)
# Shape: (batch_size, passage_length, encoding_dim)
passage_question_vectors = util.weighted_sum(encoded_question, passage_question_attention)
# We replace masked values with something really negative here, so they don't affect the
# max below.
masked_similarity = util.replace_masked_values(
    passage_question_similarity, question_mask.unsqueeze(1), -1e7
)
# Shape: (batch_size, passage_length)
question_passage_similarity = masked_similarity.max(dim=-1)[0].squeeze(-1)
# Shape: (batch_size, passage_length)
question_passage_attention = util.masked_softmax(question_passage_similarity, passage_mask)
# Shape: (batch_size, encoding_dim)
question_passage_vector = util.weighted_sum(encoded_passage, question_passage_attention)
# Shape: (batch_size, passage_length, encoding_dim)
tiled_question_passage_vector = question_passage_vector.unsqueeze(1).expand(
    batch_size, passage_length, encoding_dim
)
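# Illustration (not from the original code): the tiling step above relies on expand, which
# broadcasts the single question-aware vector across every passage position without copying
# memory, so it can be concatenated with the per-token passage encodings.
import torch

question_vec = torch.randn(2, 8)                   # (batch, encoding_dim), illustrative sizes
tiled = question_vec.unsqueeze(1).expand(2, 5, 8)  # (batch, passage_len, encoding_dim)
assert tiled.shape == (2, 5, 8)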
self,
decoder_hidden_state: torch.Tensor = None,
encoder_outputs: torch.Tensor = None,
encoder_outputs_mask: torch.Tensor = None,
) -> torch.Tensor:
"""Apply attention over encoder outputs and decoder state."""
# Ensure the mask is also a FloatTensor; otherwise the multiplication inside
# the attention module will complain.
# shape: (batch_size, max_input_sequence_length, encoder_output_dim)
encoder_outputs_mask = encoder_outputs_mask.float()
# shape: (batch_size, max_input_sequence_length)
input_weights = self._attention(decoder_hidden_state, encoder_outputs, encoder_outputs_mask)
# shape: (batch_size, encoder_output_dim)
attended_input = util.weighted_sum(encoder_outputs, input_weights)
return attended_input
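# Illustration (not from the original code): the same attend-then-weighted-sum pattern with
# a plain dot-product attention instead of an AllenNLP Attention module; all names and
# shapes here are assumptions for the sketch.
import torch

def attend_over_encoder_outputs(decoder_hidden: torch.Tensor,
                                encoder_outputs: torch.Tensor,
                                encoder_mask: torch.Tensor):
    # decoder_hidden: (batch, dim); encoder_outputs: (batch, src_len, dim); encoder_mask: (batch, src_len).
    scores = torch.bmm(encoder_outputs, decoder_hidden.unsqueeze(2)).squeeze(2)   # (batch, src_len)
    weights = torch.softmax(scores.masked_fill(encoder_mask == 0, -1e32), dim=-1)
    return torch.bmm(weights.unsqueeze(1), encoder_outputs).squeeze(1)            # (batch, dim)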