def _forward_loop(self,
source: Dict[str, torch.Tensor],
alias_database: AliasDatabase,
mention_type: torch.Tensor,
raw_entity_ids: Dict[str, torch.Tensor],
entity_ids: Dict[str, torch.Tensor],
parent_ids: Dict[str, torch.Tensor],
relations: Dict[str, torch.Tensor],
shortlist: Dict[str, torch.Tensor],
shortlist_inds: torch.Tensor) -> Dict[str, torch.Tensor]:
# Get the token mask and extract indexed text fields.
# shape: (batch_size, sequence_length)
target_mask = get_text_field_mask(source)
source = source['tokens']
raw_entity_ids = raw_entity_ids['raw_entity_ids']
entity_ids = entity_ids['entity_ids']
parent_ids = parent_ids['entity_ids']
relations = relations['relations']
logger.debug('Source & Target shape: %s', source.shape)
logger.debug('Entity ids shape: %s', entity_ids.shape)
logger.debug('Relations & Parent ids shape: %s', relations.shape)
logger.debug('Shortlist shape: %s', shortlist['entity_ids'].shape)
# Embed source tokens.
# shape: (batch_size, sequence_length, embedding_dim)
encoded, alpha_loss, beta_loss = self._encode_source(source)
splits = [self.token_embedding_dim] + [self.entity_embedding_dim] * 2
encoded_token, encoded_head, encoded_relation = encoded.split(splits, dim=-1)
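# --- Standalone sketch (not part of the snippet above): what get_text_field_mask
# returns for an indexed TextField. Assumes AllenNLP 0.x, where padded token ids
# are 0 and an indexed TextField arrives as a Dict[str, torch.Tensor].
import torch
from allennlp.nn.util import get_text_field_mask

tokens = {'tokens': torch.tensor([[4, 7, 9, 0, 0],
                                  [3, 5, 0, 0, 0]])}
mask = get_text_field_mask(tokens)
# tensor([[1, 1, 1, 0, 0],
#         [1, 1, 0, 0, 0]])  -- 1 for real tokens, 0 for padding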
the index (with respect to top_spans) of the possible antecedents the model considered.
predicted_antecedents : ``torch.IntTensor``
A tensor of shape ``(batch_size, num_spans_to_keep)`` representing, for each top span, the
index (with respect to antecedent_indices) of the most likely antecedent. -1 means there
was no predicted link.
loss : ``torch.FloatTensor``, optional
A scalar loss to be optimised.
"""
# Shape: (batch_size, document_length, embedding_size)
text_embeddings = self._lexical_dropout(self._text_field_embedder(text))
document_length = text_embeddings.size(1)
num_spans = span_starts.size(1)
# Shape: (batch_size, document_length)
text_mask = util.get_text_field_mask(text).float()
# Shape: (batch_size, num_spans, 1)
span_mask = (span_starts >= 0).float()
# IndexFields return -1 when they are used as padding. As we do
# some comparisons based on span widths when we attend over the
# span representations that we generate from these indices, we
# need them to be >= 0. This is only relevant in edge cases where
# the number of spans we consider after the pruning stage is >= the
# total number of spans, because in this case, it is possible we might
# consider a masked span.
span_starts = F.relu(span_starts.float()).long()
span_ends = F.relu(span_ends.float()).long()
# Shape: (batch_size, num_spans, 2)
span_indices = torch.cat([span_starts, span_ends], -1)
# Shape: (batch_size, document_length, encoding_dim)
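# --- Toy illustration of the padding trick above (pure PyTorch): padded span
# indices arrive as -1, so the span mask is computed first and the indices are
# then clamped to valid (non-negative) values with relu.
import torch
import torch.nn.functional as F

span_starts = torch.tensor([[0, 3, -1]])          # last "span" is padding
span_mask = (span_starts >= 0).float()            # tensor([[1., 1., 0.]])
span_starts = F.relu(span_starts.float()).long()  # tensor([[0, 3, 0]])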
best_span : torch.IntTensor
The result of a constrained inference over ``span_start_logits`` and
``span_end_logits`` to find the most probable span. Shape is ``(batch_size, 2)``
and each offset is a token index.
loss : torch.FloatTensor, optional
A scalar loss to be optimised.
best_span_str : List[str]
If sufficient metadata was provided for the instances in the batch, we also return the
string from the original passage that the model thinks is the best answer to the
question.
"""
embedded_question = self._highway_layer(self._text_field_embedder(question))
embedded_passage = self._highway_layer(self._text_field_embedder(passage))
batch_size = embedded_question.size(0)
passage_length = embedded_passage.size(1)
question_mask = util.get_text_field_mask(question).float()
passage_mask = util.get_text_field_mask(passage).float()
question_lstm_mask = question_mask if self._mask_lstms else None
passage_lstm_mask = passage_mask if self._mask_lstms else None
encoded_question = self._dropout(self._phrase_layer(embedded_question, question_lstm_mask))
encoded_passage = self._dropout(self._phrase_layer(embedded_passage, passage_lstm_mask))
encoding_dim = encoded_question.size(-1)
# Shape: (batch_size, passage_length, question_length)
passage_question_similarity = self._matrix_attention(encoded_passage, encoded_question)
# Shape: (batch_size, passage_length, question_length)
passage_question_attention = util.masked_softmax(passage_question_similarity, question_mask)
# Shape: (batch_size, passage_length, encoding_dim)
passage_question_vectors = util.weighted_sum(encoded_question, passage_question_attention)
# We replace masked values with something really negative here, so they don't affect the
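# --- Standalone sketch of the masked attention pattern used above: masked_softmax
# suppresses attention over padded question tokens, and weighted_sum pools the
# encoded question into one vector per passage token. Shapes below are made up.
import torch
from allennlp.nn import util

encoded_question = torch.randn(2, 4, 8)                   # (batch, question_len, dim)
question_mask = torch.tensor([[1, 1, 1, 0],
                              [1, 1, 0, 0]]).float()
similarity = torch.randn(2, 6, 4)                          # (batch, passage_len, question_len)
attention = util.masked_softmax(similarity, question_mask)
vectors = util.weighted_sum(encoded_question, attention)   # (batch, passage_len, dim)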
def forward(self,
context: Dict[str, torch.LongTensor],
response: Dict[str, torch.LongTensor],
label: torch.LongTensor = None) -> Dict[str, torch.Tensor]:
embedded_context = self.text_field_embedder(context)
context_mask = get_text_field_mask(context).float()
embedded_response = self.text_field_embedder(response)
response_mask = get_text_field_mask(response).float()
if self.context_encoder:
embedded_context = self.context_encoder(embedded_context, context_mask)
if self.response_encoder:
embedded_response = self.response_encoder(embedded_response, response_mask)
projected_context = self.attend_feedforward(embedded_context)
projected_response = self.attend_feedforward(embedded_response)
# batch x context_length x response_length
similarity_matrix = self.matrix_attention(projected_context, projected_response)
# batch x context_length x response_length
c2r_attention = last_dim_softmax(similarity_matrix, response_mask)
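# Note: last_dim_softmax is an older AllenNLP helper (since deprecated in favour
# of masked_softmax) that applies a masked softmax over the last dimension of a
# higher-order tensor; newer code uses util.masked_softmax directly, as in the
# other snippets here.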
-------
An output dictionary consisting of:
label_logits : torch.FloatTensor
A tensor of shape ``(batch_size, num_labels)`` representing unnormalised log
probabilities of the entailment label.
label_probs : torch.FloatTensor
A tensor of shape ``(batch_size, num_labels)`` representing probabilities of the
entailment label.
loss : torch.FloatTensor, optional
A scalar loss to be optimised.
"""
embedded_premise = self._text_field_embedder(premise)
embedded_hypothesis = self._text_field_embedder(hypothesis)
premise_mask = get_text_field_mask(premise).float()
hypothesis_mask = get_text_field_mask(hypothesis).float()
if self._premise_encoder:
embedded_premise = self._premise_encoder(embedded_premise, premise_mask)
if self._hypothesis_encoder:
embedded_hypothesis = self._hypothesis_encoder(embedded_hypothesis, hypothesis_mask)
projected_premise = self._attend_feedforward(embedded_premise)
projected_hypothesis = self._attend_feedforward(embedded_hypothesis)
# Shape: (batch_size, premise_length, hypothesis_length)
similarity_matrix = self._matrix_attention(projected_premise, projected_hypothesis)
# Shape: (batch_size, premise_length, hypothesis_length)
p2h_attention = masked_softmax(similarity_matrix, hypothesis_mask)
# Shape: (batch_size, premise_length, embedding_dim)
attended_hypothesis = weighted_sum(embedded_hypothesis, p2h_attention)
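# --- Self-contained sketch of the "attend" step in both directions (the reverse
# direction just transposes the similarity matrix). A plain Linear layer stands in
# for the attend feedforward and a dot product for the matrix attention; the real
# model wires these through configured AllenNLP modules.
import torch
from allennlp.nn.util import masked_softmax, weighted_sum

embedded_premise = torch.randn(2, 5, 16)
embedded_hypothesis = torch.randn(2, 7, 16)
premise_mask = torch.ones(2, 5)
hypothesis_mask = torch.ones(2, 7)

attend = torch.nn.Linear(16, 16)
similarity = torch.matmul(attend(embedded_premise),
                          attend(embedded_hypothesis).transpose(1, 2))
p2h = masked_softmax(similarity, hypothesis_mask)                            # (2, 5, 7)
h2p = masked_softmax(similarity.transpose(1, 2).contiguous(), premise_mask)  # (2, 7, 5)
attended_hypothesis = weighted_sum(embedded_hypothesis, p2h)                 # (2, 5, 16)
attended_premise = weighted_sum(embedded_premise, h2p)                       # (2, 7, 16)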
def _get_initial_state(self,
utterance: Dict[str, torch.LongTensor],
worlds: List[SpiderWorld],
schema: Dict[str, torch.LongTensor],
actions: List[List[ProductionRule]]) -> GrammarBasedState:
schema_text = schema['text']
embedded_schema = self._question_embedder(schema_text, num_wrapping_dims=1)
schema_mask = util.get_text_field_mask(schema_text, num_wrapping_dims=1).float()
embedded_utterance = self._question_embedder(utterance)
utterance_mask = util.get_text_field_mask(utterance).float()
batch_size, num_entities, num_entity_tokens, _ = embedded_schema.size()
num_entities = max([len(world.db_context.knowledge_graph.entities) for world in worlds])
num_question_tokens = utterance['tokens'].size(1)
# entity_types: tensor with shape (batch_size, num_entities), where each entry is the
# entity's type id.
# entity_type_dict: Dict[int, int], mapping flattened_entity_index -> type_index
# These encode the same information, but for efficiency reasons later it's nice
# to have one version as a tensor and one that's accessible on the cpu.
entity_types, entity_type_dict = self._get_type_vector(worlds, num_entities, embedded_schema.device)
entity_type_embeddings = self._entity_type_encoder_embedding(entity_types)
# Compute entity and question word similarity. We tried using cosine distance here, but
# because this similarity is the main mechanism that the model can use to push apart logit
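# --- Standalone sketch of num_wrapping_dims (used above for the schema): when a
# TextField is nested inside a ListField (one entry per schema entity), its index
# tensors carry an extra dimension, and get_text_field_mask is told to skip it.
import torch
from allennlp.nn import util

# (batch_size=1, num_entities=2, num_entity_tokens=3); padded ids are 0
schema_text = {'tokens': torch.tensor([[[5, 2, 0],
                                        [7, 0, 0]]])}
schema_mask = util.get_text_field_mask(schema_text, num_wrapping_dims=1)
# schema_mask.shape == (1, 2, 3), with a 1 wherever a real sub-token is present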
def _get_initial_state(
self,
utterance: Dict[str, torch.LongTensor],
worlds: List[AtisWorld],
actions: List[List[ProductionRule]],
linking_scores: torch.Tensor,
) -> GrammarBasedState:
embedded_utterance = self._utterance_embedder(utterance)
utterance_mask = util.get_text_field_mask(utterance).float()
batch_size = embedded_utterance.size(0)
num_entities = max([len(world.entities) for world in worlds])
# entity_types: tensor with shape (batch_size, num_entities)
entity_types, _ = self._get_type_vector(worlds, num_entities, embedded_utterance)
# (batch_size, num_utterance_tokens, embedding_dim)
encoder_input = embedded_utterance
# (batch_size, utterance_length, encoder_output_dim)
encoder_outputs = self._dropout(self._encoder(encoder_input, utterance_mask))
# This will be our initial hidden state and memory cell for the decoder LSTM.
final_encoder_output = util.get_final_encoder_states(
encoder_outputs, utterance_mask, self._encoder.is_bidirectional()
)
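# --- Quick sketch of get_final_encoder_states: for each sequence it picks the
# encoder output at the last unmasked timestep (and, when bidirectional=True,
# splices in the backward half of the output taken from timestep 0).
import torch
from allennlp.nn import util

encoder_outputs = torch.randn(2, 5, 8)    # (batch, seq_len, encoder_dim)
mask = torch.tensor([[1, 1, 1, 0, 0],
                     [1, 1, 1, 1, 1]])
final_states = util.get_final_encoder_states(encoder_outputs, mask, bidirectional=False)
# final_states.shape == (2, 8): row 0 comes from timestep 2, row 1 from timestep 4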
def forward(self, text_field_input: Dict[str, torch.Tensor], num_wrapping_dims: int = 0) -> torch.Tensor:
text_field_embeddings = self._base_text_field_embedder.forward(text_field_input, num_wrapping_dims)
base_representation = text_field_embeddings
mask = util.get_text_field_mask(text_field_input)
for encoder in self._previous_encoders:
text_field_embeddings = encoder(text_field_embeddings, mask)
text_field_embeddings = torch.cat([base_representation, text_field_embeddings], dim=-1)
return torch.cat([text_field_embeddings], dim=-1)
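# Note on the snippet above: the loop already concatenates the base representation
# with each previous encoder's output, so the final torch.cat over a single-element
# list is a no-op and simply returns text_field_embeddings unchanged.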
def _get_initial_rnn_state(self, sentence: Dict[str, torch.LongTensor]):
embedded_input = self._sentence_embedder(sentence)
# (batch_size, sentence_length)
sentence_mask = util.get_text_field_mask(sentence).float()
batch_size = embedded_input.size(0)
# (batch_size, sentence_length, encoder_output_dim)
encoder_outputs = self._dropout(self._encoder(embedded_input, sentence_mask))
final_encoder_output = util.get_final_encoder_states(
encoder_outputs, sentence_mask, self._encoder.is_bidirectional()
)
memory_cell = encoder_outputs.new_zeros(batch_size, self._encoder.get_output_dim())
attended_sentence, _ = self._decoder_step.attend_on_question(
final_encoder_output, encoder_outputs, sentence_mask
)
encoder_outputs_list = [encoder_outputs[i] for i in range(batch_size)]
sentence_mask_list = [sentence_mask[i] for i in range(batch_size)]
initial_rnn_state = []
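# --- A plausible continuation (hedged; not necessarily this model's exact code),
# assuming AllenNLP 0.x's RnnStatelet (from allennlp.state_machines.states) and a
# learned first-action embedding; self._first_action_embedding is hypothetical
# here. One decoder state per batch element, seeded with the final encoder state
# and a zero memory cell.
for i in range(batch_size):
    initial_rnn_state.append(
        RnnStatelet(final_encoder_output[i],
                    memory_cell[i],
                    self._first_action_embedding,
                    attended_sentence[i],
                    encoder_outputs_list,
                    sentence_mask_list))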