def __eq__(self, other):
if isinstance(self, other.__class__):
return all(
[
self.batch_indices == other.batch_indices,
self.action_history == other.action_history,
util.tensors_equal(self.score, other.score, tolerance=1e-3),
util.tensors_equal(self.rnn_state, other.rnn_state, tolerance=1e-4),
self.grammar_state == other.grammar_state,
self.checklist_state == other.checklist_state,
self.possible_actions == other.possible_actions,
self.extras == other.extras,
util.tensors_equal(self.debug_info, other.debug_info, tolerance=1e-6),
]
)
return NotImplemented
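# The equality check above leans on ``util.tensors_equal`` to compare tensors (and nested
# containers of tensors) up to a tolerance. A minimal, self-contained sketch of such a helper
# in plain PyTorch follows; the recursive handling of lists and dicts is an assumption for
# illustration, not a transcription of the library's implementation.
import torch

def tensors_equal_sketch(a, b, tolerance: float = 1e-12) -> bool:
    """Approximate, structure-aware equality: tensors within ``tolerance``, containers recursively."""
    if isinstance(a, torch.Tensor) and isinstance(b, torch.Tensor):
        return a.shape == b.shape and bool(torch.allclose(a, b, atol=tolerance))
    if isinstance(a, (list, tuple)) and isinstance(b, (list, tuple)):
        return len(a) == len(b) and all(tensors_equal_sketch(x, y, tolerance) for x, y in zip(a, b))
    if isinstance(a, dict) and isinstance(b, dict):
        return a.keys() == b.keys() and all(tensors_equal_sketch(a[k], b[k], tolerance) for k in a)
    return a == b

# Example: values that differ by less than the tolerance compare equal.
print(tensors_equal_sketch([torch.tensor([1.0, 2.0])], [torch.tensor([1.0, 2.0005])], tolerance=1e-3))  # True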
def batch_tensors(self, tensor_list: List[Dict[str, torch.Tensor]]) -> Dict[str, torch.Tensor]:
    # pylint: disable=no-self-use
    batched_text = nn_util.batch_tensor_dicts(tensor['text'] for tensor in tensor_list)  # type: ignore
    batched_linking = torch.stack([tensor['linking'] for tensor in tensor_list])
    return {'text': batched_text, 'linking': batched_linking}
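# ``batch_tensors`` turns a list of per-instance tensor dicts into one batched dict:
# ``torch.stack`` handles the fixed-shape 'linking' tensors, and ``nn_util.batch_tensor_dicts``
# does the analogous key-by-key stacking for the nested 'text' dicts. A hedged standalone
# sketch of that stacking pattern (assuming every dict has the same keys and the tensors under
# a given key are already padded to a common shape):
import torch
from typing import Dict, List

def batch_tensor_dicts_sketch(tensor_dicts: List[Dict[str, torch.Tensor]]) -> Dict[str, torch.Tensor]:
    # Stack tensors key by key along a new leading batch dimension.
    keys = tensor_dicts[0].keys()
    return {key: torch.stack([d[key] for d in tensor_dicts]) for key in keys}

# Example: two "instances", each carrying a token-id tensor of length 4.
batched = batch_tensor_dicts_sketch([
    {"tokens": torch.tensor([1, 2, 3, 0])},
    {"tokens": torch.tensor([4, 5, 0, 0])},
])
print(batched["tokens"].shape)  # torch.Size([2, 4])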
def attend_on_question(self,
                       query: torch.Tensor,
                       encoder_outputs: torch.Tensor,
                       encoder_output_mask: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
    """
    Given a query (which is typically the decoder hidden state), compute an attention over the
    output of the question encoder, and return a weighted sum of the question representations
    given this attention. We also return the attention weights themselves.

    This is a simple computation, but we have it as a separate method so that the ``forward``
    method on the main parser module can call it on the initial hidden state, to simplify the
    logic in ``take_step``.
    """
    # (group_size, question_length)
    question_attention_weights = self._input_attention(query,
                                                       encoder_outputs,
                                                       encoder_output_mask)
    # (group_size, encoder_output_dim)
    attended_question = util.weighted_sum(encoder_outputs, question_attention_weights)
    return attended_question, question_attention_weights
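# The method above is just "score each encoder position, softmax with the mask, weighted sum".
# A self-contained sketch of that computation using dot-product scoring in plain PyTorch; the
# masked softmax and weighted sum stand in for ``self._input_attention`` and
# ``util.weighted_sum``, and the dot-product scorer itself is an assumption for illustration.
import torch

def attend_sketch(query, encoder_outputs, encoder_output_mask):
    # query: (group_size, encoder_output_dim)
    # encoder_outputs: (group_size, question_length, encoder_output_dim)
    # encoder_output_mask: (group_size, question_length), 1 for real tokens, 0 for padding
    scores = torch.bmm(encoder_outputs, query.unsqueeze(-1)).squeeze(-1)
    scores = scores.masked_fill(encoder_output_mask == 0, float("-inf"))
    weights = torch.softmax(scores, dim=-1)                                  # (group_size, question_length)
    attended = torch.bmm(weights.unsqueeze(1), encoder_outputs).squeeze(1)   # (group_size, encoder_output_dim)
    return attended, weights

query = torch.randn(2, 8)
outputs = torch.randn(2, 5, 8)
mask = torch.tensor([[1, 1, 1, 0, 0], [1, 1, 1, 1, 1]])
attended, weights = attend_sketch(query, outputs, mask)
print(attended.shape, weights.shape)  # torch.Size([2, 8]) torch.Size([2, 5])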
def __eq__(self, other):
if isinstance(self, other.__class__):
return all(
[
self._nonterminal_stack == other._nonterminal_stack,
self._lambda_stacks == other._lambda_stacks,
util.tensors_equal(self._valid_actions, other._valid_actions),
util.tensors_equal(self._context_actions, other._context_actions),
self._is_nonterminal == other._is_nonterminal,
]
)
return NotImplemented
def _decoder_step(
self,
last_predictions: torch.Tensor,
selective_weights: torch.Tensor,
state: Dict[str, torch.Tensor],
) -> Dict[str, torch.Tensor]:
# shape: (group_size, max_input_sequence_length)
encoder_outputs_mask = state["source_mask"]
# shape: (group_size, target_embedding_dim)
embedded_input = self._target_embedder(last_predictions)
# shape: (group_size, max_input_sequence_length)
attentive_weights = self._attention(
state["decoder_hidden"], state["encoder_outputs"], encoder_outputs_mask
)
# shape: (group_size, encoder_output_dim)
attentive_read = util.weighted_sum(state["encoder_outputs"], attentive_weights)
# shape: (group_size, encoder_output_dim)
selective_read = util.weighted_sum(
state["encoder_outputs"][:, 1:-1], selective_weights
)
# shape: (group_size, target_embedding_dim + encoder_output_dim * 2)
decoder_input = torch.cat((embedded_input, attentive_read, selective_read), -1)
# shape: (group_size, decoder_input_dim)
projected_decoder_input = self._input_projection_layer(decoder_input)
state["decoder_hidden"], state["decoder_context"] = self._decoder_cell(
projected_decoder_input, (state["decoder_hidden"], state["decoder_context"])
)
return state
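# Each decoder step concatenates the embedded previous prediction with two context vectors
# (an attentive read over all encoder outputs and a selective read over the source tokens
# proper, i.e. with the first and last positions sliced off), projects the result down to the
# decoder input size, and advances an LSTM cell. A shape-only sketch of that wiring with
# made-up dimensions; the layer names here are placeholders, not the model's real attributes.
import torch
import torch.nn as nn

group_size, enc_dim, emb_dim, dec_dim = 4, 16, 10, 16

input_projection = nn.Linear(emb_dim + 2 * enc_dim, dec_dim)   # stands in for self._input_projection_layer
decoder_cell = nn.LSTMCell(dec_dim, dec_dim)                   # stands in for self._decoder_cell

embedded_input = torch.randn(group_size, emb_dim)              # embedded last predictions
attentive_read = torch.randn(group_size, enc_dim)              # weighted sum over all source positions
selective_read = torch.randn(group_size, enc_dim)              # weighted sum over source tokens [1:-1]

decoder_input = torch.cat((embedded_input, attentive_read, selective_read), -1)
projected = input_projection(decoder_input)                    # (group_size, dec_dim)

hidden = torch.zeros(group_size, dec_dim)
context = torch.zeros(group_size, dec_dim)
hidden, context = decoder_cell(projected, (hidden, context))
print(hidden.shape)  # torch.Size([4, 16])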
def _init_decoder_state(self, state: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
batch_size = state["source_mask"].size(0)
# shape: (batch_size, encoder_output_dim)
final_encoder_output = util.get_final_encoder_states(
state["encoder_outputs"],
state["source_mask"],
self._encoder.is_bidirectional())
# Initialize the decoder hidden state with the final output of the encoder.
# shape: (batch_size, decoder_output_dim)
state["decoder_hidden"] = final_encoder_output
encoder_outputs = state["encoder_outputs"]
state["decoder_context"] = encoder_outputs.new_zeros(batch_size, self._decoder_output_dim)
if self._embed_attn_to_output:
state["attn_context"] = encoder_outputs.new_zeros(encoder_outputs.size(0), encoder_outputs.size(2))
if self._use_coverage:
state["coverage"] = encoder_outputs.new_zeros(batch_size, encoder_outputs.size(1))
return state
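# The decoder hidden state is seeded with the final encoder output, which
# ``util.get_final_encoder_states`` selects using the source mask (for a bidirectional encoder
# it also splices the backward direction's first timestep onto the forward direction's last
# real timestep). A hedged sketch of that selection for the unidirectional case only:
import torch

def final_encoder_states_sketch(encoder_outputs: torch.Tensor, mask: torch.Tensor) -> torch.Tensor:
    # encoder_outputs: (batch_size, seq_len, encoder_output_dim)
    # mask: (batch_size, seq_len) with 1 for real tokens, 0 for padding
    last_indices = mask.sum(dim=1).long() - 1                   # last real timestep per sequence
    batch_indices = torch.arange(encoder_outputs.size(0))
    return encoder_outputs[batch_indices, last_indices]         # (batch_size, encoder_output_dim)

outputs = torch.randn(3, 5, 8)
mask = torch.tensor([[1, 1, 1, 0, 0], [1, 1, 1, 1, 1], [1, 0, 0, 0, 0]])
print(final_encoder_states_sketch(outputs, mask).shape)  # torch.Size([3, 8])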
best_span: torch.FloatTensor
A tensor of shape ``()``
true_span: torch.FloatTensor
loss: torch.FloatTensor
"""
# batchsize * listLength * paragraphSize * embeddingSize
input_embedding_paragraph = self._text_field_embedder(tokens_list)
input_pos_embedding_paragraph = self._pos_field_embedder(positions_list)
input_sent_pos_embedding_paragraph = self._sent_pos_field_embedder(sent_positions_list)
# batchsize * listLength * paragraphSize * (embeddingSize*2)
embedding_paragraph = torch.cat([input_embedding_paragraph, input_pos_embedding_paragraph,
input_sent_pos_embedding_paragraph], dim=-1)
# batchsize * listLength * paragraphSize, this mask is shared with the text fields and sequence label fields
para_mask = util.get_text_field_mask(tokens_list, num_wrapping_dims=1).float()
# batchsize * listLength, this mask is shared with the index fields
para_index_mask, para_index_mask_indices = torch.max(para_mask, 2)
# apply mask to update the index values, padded instances will be 0
after_loc_start_list = (after_loc_start_list.float() * para_index_mask.unsqueeze(2)).long()
after_loc_end_list = (after_loc_end_list.float() * para_index_mask.unsqueeze(2)).long()
after_category_list = (after_category_list.float() * para_index_mask.unsqueeze(2)).long()
after_category_mask_list = (after_category_mask_list.float() * para_index_mask.unsqueeze(2)).long()
batch_size, list_size, paragraph_size, input_dim = embedding_paragraph.size()
# to store the values passed to the next step
tmp_category_probability = torch.zeros(batch_size, 3)
tmp_start_probability = torch.zeros(batch_size, paragraph_size)
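# The index fields that accompany each list item have to be neutralized wherever an entire
# item is padding: the token-level mask is reduced to a per-item mask with ``torch.max`` and
# then multiplied into the (cast) index tensors, forcing padded items to index 0. A small
# self-contained example of that pattern with toy values:
import torch

para_mask = torch.tensor([[[1., 1., 1., 0.],
                           [1., 1., 0., 0.],
                           [0., 0., 0., 0.]]])                  # (batch, list_size, paragraph_size)
para_index_mask, _ = torch.max(para_mask, 2)                    # (batch, list_size): [[1., 1., 0.]]

after_loc_start = torch.tensor([[[2], [1], [3]]])               # (batch, list_size, 1), junk in the padded item
cleaned = (after_loc_start.float() * para_index_mask.unsqueeze(2)).long()
print(cleaned.squeeze(-1))  # tensor([[2, 1, 0]]) - the fully padded item is zeroed out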
def get_output_dim(self) -> int:
unidirectional_dim = int(self._input_dim / 2)
forward_combined_dim = util.get_combined_dim(self._forward_combination,
[unidirectional_dim, unidirectional_dim])
backward_combined_dim = util.get_combined_dim(self._backward_combination,
[unidirectional_dim, unidirectional_dim])
if self._span_width_embedding is not None:
return forward_combined_dim + backward_combined_dim +\
self._span_width_embedding.get_output_dim()
return forward_combined_dim + backward_combined_dim
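# The output dimension follows from the combination strings: comma-separated pieces are
# concatenated (their dims add), while elementwise combinations such as "x*y" or "x-y" keep a
# single dim. ``util.get_combined_dim`` does this for the forward and backward halves of the
# bidirectional representation, and the optional span-width embedding adds its own dim on top.
# A toy version of that dimension arithmetic (a simplification of the real helper):
def combined_dim_sketch(combination: str, dims: list) -> int:
    def piece_dim(piece: str) -> int:
        if piece == "x":
            return dims[0]
        if piece == "y":
            return dims[1]
        return dims[0]  # elementwise combination of equal-sized inputs keeps one dim
    return sum(piece_dim(piece.strip()) for piece in combination.split(","))

unidirectional_dim = 100
forward_dim = combined_dim_sketch("y-x", [unidirectional_dim, unidirectional_dim])   # 100
backward_dim = combined_dim_sketch("x-y", [unidirectional_dim, unidirectional_dim])  # 100
print(forward_dim + backward_dim)  # 200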
passage_lstm_mask = passage_mask if self._mask_lstms else None
# Phrase layer is the shared Bi-GRU in the paper
# (extended_batch_size, sequence_length, input_dim) -> (extended_batch_size, sequence_length, encoding_dim)
encoded_question = self._dropout(self._phrase_layer(embedded_question, question_lstm_mask))
encoded_passage = self._dropout(self._phrase_layer(embedded_passage, passage_lstm_mask))
batch_size, num_tokens, _ = encoded_passage.size()
encoding_dim = encoded_question.size(-1)
# Shape: (extended_batch_size, passage_length, question_length)
# these are the a_ij in the paper
passage_question_similarity = self._matrix_attention(encoded_passage, encoded_question)
# Shape: (extended_batch_size, passage_length, question_length)
# these are the p_ij in the paper
passage_question_attention = util.last_dim_softmax(passage_question_similarity, question_mask)
# Shape: (extended_batch_size, passage_length, encoding_dim)
# these are the c_i in the paper
passage_question_vectors = util.weighted_sum(encoded_question, passage_question_attention)
# We replace masked values with something really negative here, so they don't affect the
# max below.
# Shape: (extended_batch_size, passage_length, question_length)
masked_similarity = util.replace_masked_values(passage_question_similarity,
question_mask.unsqueeze(1),
-1e7)
# Take the max over the last dimension (all question words)
# Shape: (extended_batch_size, passage_length)
question_passage_similarity = masked_similarity.max(dim=-1)[0]
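# The block above builds passage-to-question attention (the p_ij from the a_ij, then the
# context vectors c_i), and then prepares question-to-passage attention by pushing padded
# question positions to a large negative value before taking the max over question words.
# That masked-fill-then-max step in isolation, with toy numbers:
import torch

similarity = torch.tensor([[[0.2, 0.9, 5.0],
                            [1.5, 0.1, 9.0],
                            [0.3, 0.4, 7.0],
                            [2.0, 0.8, 6.0]]])                  # (batch, passage_length, question_length)
question_mask = torch.tensor([[1, 1, 0]])                       # last question token is padding

# Padded question words can never win the max once they are set to -1e7.
masked = similarity.masked_fill(question_mask.unsqueeze(1) == 0, -1e7)
question_passage_similarity = masked.max(dim=-1)[0]             # (batch, passage_length)
print(question_passage_similarity)  # tensor([[0.9000, 1.5000, 0.4000, 2.0000]])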