    action_sequence = []
elif action_sequence is None:
    return None

index_fields: List[Field] = []
production_rule_fields: List[Field] = []

for production_rule in all_actions:
    nonterminal, _ = production_rule.split(" ->")
    production_rule = " ".join(production_rule.split(" "))
    field = ProductionRuleField(
        production_rule, self._world.is_global_rule(nonterminal), nonterminal=nonterminal
    )
    production_rule_fields.append(field)

valid_actions_field = ListField(production_rule_fields)
fields["valid_actions"] = valid_actions_field

action_map = {
    action.rule: i  # type: ignore
    for i, action in enumerate(valid_actions_field.field_list)
}

for production_rule in action_sequence:
    index_fields.append(IndexField(action_map[production_rule], valid_actions_field))
if not action_sequence:
    index_fields = [IndexField(-1, valid_actions_field)]

action_sequence_field = ListField(index_fields)
fields["action_sequence"] = action_sequence_field
return Instance(fields)
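
# --- Illustrative sketch (invented grammar rules): the same "point each gold action at its
# position in the list of valid actions" pattern, using core AllenNLP fields only.
# ProductionRuleField comes from allennlp-semparse; plain LabelFields stand in for it here.
from allennlp.data.fields import IndexField, LabelField, ListField
from allennlp.data.instance import Instance

all_actions = ["s -> [expr]", "expr -> [expr, op, expr]", "expr -> number"]
action_sequence = ["s -> [expr]", "expr -> number"]

valid_actions_field = ListField(
    [LabelField(rule, label_namespace="rule_labels") for rule in all_actions]
)
# Map each rule string to its position so the gold sequence can index back into the list.
action_map = {rule: i for i, rule in enumerate(all_actions)}
index_fields = [IndexField(action_map[rule], valid_actions_field) for rule in action_sequence]
if not index_fields:
    index_fields = [IndexField(-1, valid_actions_field)]   # sentinel when there is no parse

instance = Instance(
    {"valid_actions": valid_actions_field, "action_sequence": ListField(index_fields)}
)
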
targets : ``List[str]``, optional
    Contains the target tokens to be predicted. The length of this list should be the same
    as the number of [MASK] tokens in the input.
"""
if not tokens:
    tokens = self._tokenizer.tokenize(sentence)
input_field = TextField(tokens, self._token_indexers)

mask_positions = []
for i, token in enumerate(tokens):
    if token.text == "[MASK]":
        mask_positions.append(i)
if not mask_positions:
    raise ValueError("No [MASK] tokens found!")
if targets and len(targets) != len(mask_positions):
    raise ValueError(f"Found {len(mask_positions)} mask tokens and {len(targets)} targets")

mask_position_field = ListField([IndexField(i, input_field) for i in mask_positions])
fields: Dict[str, Field] = {"tokens": input_field, "mask_positions": mask_position_field}

# TODO(mattg): there's a problem if the targets get split into multiple word pieces...
# (maksym-del): if we index a word that was not split into wordpieces with
# PretrainedTransformerTokenizer, we will get an OOV token ID...
# Until this is handled, let's use the first wordpiece id for each token, since tokens should
# contain text_ids to be indexed with PretrainedTransformerIndexer. It also requires a hack
# to avoid adding special tokens...
if targets is not None:
    # target_field = TextField([Token(target) for target in targets], self._token_indexers)
    first_wordpieces = [self._targets_tokenizer.tokenize(target)[0] for target in targets]
    target_tokens = []
    for wordpiece, target in zip(first_wordpieces, targets):
        target_tokens.append(
            Token(text=target, text_id=wordpiece.text_id, type_id=wordpiece.type_id)
        )
    fields["target_ids"] = TextField(target_tokens, self._token_indexers)
return Instance(fields)
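
# --- Illustrative sketch (invented sentence and target): building a masked-LM instance by
# hand with core AllenNLP fields. A real setup would use a pretrained-transformer tokenizer
# and indexer so the target Token carries the right text_id.
from allennlp.data.fields import IndexField, ListField, TextField
from allennlp.data.instance import Instance
from allennlp.data.token_indexers import SingleIdTokenIndexer
from allennlp.data.tokenizers import Token

token_indexers = {"tokens": SingleIdTokenIndexer()}
tokens = [Token(t) for t in "the [MASK] sat on the mat".split()]
input_field = TextField(tokens, token_indexers)

# One IndexField per [MASK] position, pointing into the token sequence.
mask_positions = [i for i, t in enumerate(tokens) if t.text == "[MASK]"]
mask_position_field = ListField([IndexField(i, input_field) for i in mask_positions])

instance = Instance({
    "tokens": input_field,
    "mask_positions": mask_position_field,
    "target_ids": TextField([Token("cat")], token_indexers),   # one target per [MASK]
})
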
if target is not None:
    fields['target'] = TextField(_tokenize(target), self._token_indexers)
    metadata['target_tokens'] = target
if shortlist is not None:
    fields['shortlist'] = TextField(_tokenize(shortlist), self._entity_indexers)
if raw_entity_ids is not None:
    fields['raw_entity_ids'] = TextField(_tokenize(raw_entity_ids), self._raw_entity_indexers)
if entity_ids is not None:
    fields['entity_ids'] = TextField(_tokenize(entity_ids), self._entity_indexers)
if parent_ids is not None:
    fields['parent_ids'] = ListField([
        TextField(_tokenize(sublist),
                  token_indexers=self._entity_indexers)
        for sublist in parent_ids])
if relations is not None:
    fields['relations'] = ListField([
        TextField(_tokenize(sublist),
                  token_indexers=self._relation_indexers)
        for sublist in relations])
if mention_type is not None:
    fields['mention_type'] = SequentialArrayField(mention_type, dtype=np.int64)
if shortlist_inds is not None:
    fields['shortlist_inds'] = SequentialArrayField(shortlist_inds, dtype=np.int64)
if alias_copy_inds is not None:
    fields['alias_copy_inds'] = SequentialArrayField(alias_copy_inds, dtype=np.int64)
return Instance(fields)
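
# --- Illustrative sketch (invented values): the "add each field only when its value is
# present" pattern above, using core AllenNLP fields. SequentialArrayField and _tokenize are
# specific to that codebase; ArrayField and a plain whitespace split stand in for them here.
import numpy as np
from allennlp.data.fields import ArrayField, TextField
from allennlp.data.instance import Instance
from allennlp.data.token_indexers import SingleIdTokenIndexer
from allennlp.data.tokenizers import Token

token_indexers = {"tokens": SingleIdTokenIndexer()}
fields = {}

target = "the capital of France"                   # hypothetical target text
shortlist_inds = np.array([0, 2, 1], dtype=np.int64)  # hypothetical index array

if target is not None:
    fields["target"] = TextField([Token(t) for t in target.split()], token_indexers)
if shortlist_inds is not None:
    fields["shortlist_inds"] = ArrayField(shortlist_inds)

instance = Instance(fields)
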
fields['s_ent_def'] = TextField(
    self._tokenizer.tokenize(s_ent['definition']), self._token_only_indexer
) if s_ent['definition'] else self._empty_token_text_field
fields['t_ent_def'] = TextField(
    self._tokenizer.tokenize(t_ent['definition']), self._token_only_indexer
) if t_ent['definition'] else self._empty_token_text_field

# add entity context fields
s_contexts = sample_n(s_ent['other_contexts'], 16, 256)
t_contexts = sample_n(t_ent['other_contexts'], 16, 256)
fields['s_ent_context'] = ListField(
    [TextField(self._tokenizer.tokenize(c), self._token_only_indexer)
     for c in s_contexts]
)
fields['t_ent_context'] = ListField(
    [TextField(self._tokenizer.tokenize(c), self._token_only_indexer)
     for c in t_contexts]
)

# add boolean label (0 = no match, 1 = match)
fields['label'] = BooleanField(label)
return Instance(fields)
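
# --- Illustrative sketch (invented texts): a missing definition falling back to an empty
# TextField, and variable-length contexts wrapped in a ListField. BooleanField and sample_n
# are specific to that codebase; a LabelField and a simple slice stand in for them here.
from allennlp.data.fields import LabelField, ListField, TextField
from allennlp.data.instance import Instance
from allennlp.data.token_indexers import SingleIdTokenIndexer
from allennlp.data.tokenizers import Token

token_indexers = {"tokens": SingleIdTokenIndexer()}
empty_text_field = TextField([], token_indexers)   # stand-in for self._empty_token_text_field

definition = ""                                     # this entity has no definition
contexts = ["aspirin reduces fever", "commonly used as an analgesic"]

fields = {
    "s_ent_def": TextField([Token(t) for t in definition.split()], token_indexers)
    if definition else empty_text_field,
    "s_ent_context": ListField(
        [TextField([Token(t) for t in c.split()], token_indexers) for c in contexts[:16]]
    ),
    "label": LabelField(1, skip_indexing=True),     # 1 = match, 0 = no match
}
instance = Instance(fields)
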
    question_span_fields.append(SpanField(-1, -1, question_passage_field))
fields["answer_as_question_spans"] = ListField(question_span_fields)

if self.exp_search == 'add_sub':
    add_sub_signs_field: List[Field] = []
    extra_signs_field: List[Field] = []
    for signs_for_one_add_sub_expressions in valid_expressions:
        extra_signs = signs_for_one_add_sub_expressions[:len(self.extra_numbers)]
        normal_signs = signs_for_one_add_sub_expressions[len(self.extra_numbers):]
        add_sub_signs_field.append(SequenceLabelField(normal_signs, numbers_in_passage_field))
        extra_signs_field.append(SequenceLabelField(extra_signs, extra_numbers_field))
    if not add_sub_signs_field:
        add_sub_signs_field.append(
            SequenceLabelField([0] * len(number_tokens), numbers_in_passage_field))
    if not extra_signs_field:
        extra_signs_field.append(
            SequenceLabelField([0] * len(self.extra_numbers), extra_numbers_field))
    fields["answer_as_expressions"] = ListField(add_sub_signs_field)
    if self.extra_numbers:
        fields["answer_as_expressions_extra"] = ListField(extra_signs_field)
elif self.exp_search in ['template', 'full']:
    expression_indices = []
    for expression in valid_expressions:
        if not expression:
            expression.append(3 * [-1])
        expression_indices.append(ArrayField(np.array(expression), padding_value=-1))
    if not expression_indices:
        expression_indices = [ArrayField(np.array([3 * [-1]]), padding_value=-1)
                              for _ in range(len(self.templates))]
    fields["answer_as_expressions"] = ListField(expression_indices)

count_fields: List[Field] = [LabelField(count_label, skip_indexing=True)
                             for count_label in valid_counts]
if not count_fields:
    count_fields.append(LabelField(-1, skip_indexing=True))
fields["answer_as_counts"] = ListField(count_fields)

fields["num_spans"] = LabelField(num_spans, skip_indexing=True)
fields["metadata"] = MetadataField(metadata)
return Instance(fields)
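
# --- Illustrative sketch (invented numbers and sign combinations): the two answer
# representations built above, using core AllenNLP fields. In the reader only one of them
# ends up under "answer_as_expressions", depending on self.exp_search; both are shown here
# under separate keys purely for illustration.
import numpy as np
from allennlp.data.fields import ArrayField, ListField, SequenceLabelField, TextField
from allennlp.data.instance import Instance
from allennlp.data.token_indexers import SingleIdTokenIndexer
from allennlp.data.tokenizers import Token

numbers_in_passage_field = TextField(
    [Token("3"), Token("7"), Token("12")], {"tokens": SingleIdTokenIndexer()}
)

# add_sub search: one sign per passage number (0 = unused, 1 = plus, 2 = minus).
sign_combinations = [[1, 1, 0], [0, 1, 2]]
add_sub_signs = [SequenceLabelField(signs, numbers_in_passage_field) for signs in sign_combinations]

# template search: each expression is a list of index triples, padded with -1 at batching time.
expressions = [[[0, 1, 2]], [[1, 2, 0], [2, 0, 1]]]
expression_indices = [ArrayField(np.array(e), padding_value=-1) for e in expressions]

instance = Instance({
    "answer_as_expressions": ListField(add_sub_signs),
    "answer_as_template_expressions": ListField(expression_indices),
})
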
def text_to_instance(self,  # type: ignore
                     item_id: Any,
                     question_text: str,
                     choice_text_list: List[str],
                     answer_id: int
                     ) -> Instance:
    # pylint: disable=arguments-differ
    fields: Dict[str, Field] = {}
    question_tokens = self._tokenizer.tokenize(question_text)
    choices_tokens_list = [self._tokenizer.tokenize(x) for x in choice_text_list]
    fields['question'] = TextField(question_tokens, self._token_indexers)
    fields['choices_list'] = ListField([TextField(x, self._token_indexers) for x in choices_tokens_list])
    fields['label'] = LabelField(answer_id, skip_indexing=True)

    metadata = {
        "id": item_id,
        "question_text": question_text,
        "choice_text_list": choice_text_list,
        "question_tokens": [x.text for x in question_tokens],
        "choice_tokens_list": [[x.text for x in ct] for ct in choices_tokens_list],
    }
    fields["metadata"] = MetadataField(metadata)
    return Instance(fields)
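
# --- Illustrative usage sketch (invented question and choices): building an instance with
# the same shape as the one returned above, then indexing it with a vocabulary and turning
# it into padded tensors.
from allennlp.data.fields import LabelField, ListField, MetadataField, TextField
from allennlp.data.instance import Instance
from allennlp.data.token_indexers import SingleIdTokenIndexer
from allennlp.data.tokenizers import Token
from allennlp.data.vocabulary import Vocabulary


def tokenize(text):
    return [Token(t) for t in text.split()]


token_indexers = {"tokens": SingleIdTokenIndexer()}
question = "which gas do plants absorb"
choices = ["oxygen", "carbon dioxide", "nitrogen"]

instance = Instance({
    "question": TextField(tokenize(question), token_indexers),
    "choices_list": ListField([TextField(tokenize(c), token_indexers) for c in choices]),
    "label": LabelField(1, skip_indexing=True),   # index of the correct choice
    "metadata": MetadataField({"id": "q1", "question_text": question}),
})

vocab = Vocabulary.from_instances([instance])
instance.index_fields(vocab)
tensors = instance.as_tensor_dict()               # nested dict of padded tensors
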
def _process_sentence(self, sent: Sentence):
    # Get the sentence text and define the `text_field`.
    sentence_text = [self._normalize_word(word) for word in sent.text]
    text_field = TextField([Token(word) for word in sentence_text], self._token_indexers)

    # Enumerate spans.
    spans = []
    for start, end in enumerate_spans(sentence_text, max_span_width=self._max_span_width):
        spans.append(SpanField(start, end, text_field))
    span_field = ListField(spans)
    span_tuples = [(span.span_start, span.span_end) for span in spans]

    # Convert data to fields.
    # NOTE: The `ner_labels` and `coref_labels` would ideally have type
    # `ListField[SequenceLabelField]`, where the sequence labels are over the `SpanField`s in
    # `spans`. But calling `as_tensor_dict()` fails on this specific data type. Matt G
    # recognized that this is an AllenNLP API issue and suggested that we represent these as
    # `ListField[ListField[LabelField]]` instead.
    dataset = sent.doc.dataset
    fields = {}
    fields["text"] = text_field
    fields["spans"] = span_field
    if sent.ner is not None:
        ner_labels = self._process_ner(span_tuples, sent)
        fields["ner_labels"] = ListField(
            [LabelField(entry, label_namespace=f"{dataset}__ner_labels")
             for entry in ner_labels])
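
# --- Illustrative sketch (invented sentence and labels): enumerating candidate spans with
# AllenNLP's enumerate_spans helper and attaching one label per span, as in the `ner_labels`
# field above.
from allennlp.data.dataset_readers.dataset_utils import enumerate_spans
from allennlp.data.fields import LabelField, ListField, SpanField, TextField
from allennlp.data.instance import Instance
from allennlp.data.token_indexers import SingleIdTokenIndexer
from allennlp.data.tokenizers import Token

sentence = "Steve Jobs founded Apple".split()
text_field = TextField([Token(w) for w in sentence], {"tokens": SingleIdTokenIndexer()})

spans = [SpanField(start, end, text_field)
         for start, end in enumerate_spans(sentence, max_span_width=2)]
span_field = ListField(spans)

# One NER label per candidate span; "" marks spans that are not entities.
ner_labels = ["" for _ in spans]
ner_labels[1] = "PER"    # the (0, 1) span, "Steve Jobs"
ner_labels[-1] = "ORG"   # the (3, 3) span, "Apple"
ner_field = ListField([LabelField(label, label_namespace="ner_labels") for label in ner_labels])

instance = Instance({"text": text_field, "spans": span_field, "ner_labels": ner_field})
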
# Depending on the type of supervision used for training the parser, we may want either
# target action sequences or an agenda in our instance. We check if target sequences are
# provided, and include them if they are. If not, we'll get an agenda for the sentence, and
# include that in the instance.
if target_sequences:
    action_sequence_fields: List[Field] = []
    for target_sequence in target_sequences:
        index_fields = ListField(
            [
                IndexField(instance_action_ids[action], action_field)
                for action in target_sequence
            ]
        )
        action_sequence_fields.append(index_fields)
        # TODO(pradeep): Define a max length for this field.
    fields["target_action_sequences"] = ListField(action_sequence_fields)
elif self._output_agendas:
    # TODO(pradeep): Assuming every world gives the same agenda for a sentence. This is true
    # now, but may change later too.
    agenda = worlds[0].get_agenda_for_sentence(sentence)
    assert agenda, "No agenda found for sentence: %s" % sentence
    # agenda_field contains indices into actions.
    agenda_field = ListField(
        [IndexField(instance_action_ids[action], action_field) for action in agenda]
    )
    fields["agenda"] = agenda_field
if labels:
    labels_field = ListField(
        [LabelField(label, label_namespace="denotations") for label in labels]
    )
    fields["labels"] = labels_field
if answer_texts is None or len(answer_texts) > 0:
    metadata['answer_texts'] = answer_texts
else:
    metadata['answer_texts'] = ['']

if token_spans:
    # There may be multiple answer annotations, so we keep every unique annotated span. This
    # only matters on the SQuAD dev set, and it means our computed metrics ("start_acc",
    # "end_acc", and "span_acc") aren't quite the same as the official metrics, which look at
    # all of the annotations. This is why we have a separate official SQuAD metric calculation
    # (the "em" and "f1" metrics use the official script).
    token_spans = set(token_spans)
    span_fields = ListField([SpanField(start, end, passage_field)
                             for start, end in token_spans])
else:
    span_fields = ListField([SpanField(-1, -1, passage_field)])
fields['spans'] = span_fields

metadata.update(additional_metadata)
fields['metadata'] = MetadataField(metadata)
return Instance(fields)
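
# --- Illustrative sketch (invented passage and spans): unique answer spans stored as
# SpanFields over the passage, with a (-1, -1) sentinel span when no annotation exists.
# A real reader first maps character offsets in the annotations to token indices.
from allennlp.data.fields import ListField, MetadataField, SpanField, TextField
from allennlp.data.instance import Instance
from allennlp.data.token_indexers import SingleIdTokenIndexer
from allennlp.data.tokenizers import Token

passage = "The tower is 324 metres tall".split()
passage_field = TextField([Token(w) for w in passage], {"tokens": SingleIdTokenIndexer()})

token_spans = [(3, 4), (3, 4), (3, 5)]      # possibly duplicated annotations
if token_spans:
    unique_spans = set(token_spans)
    span_fields = ListField([SpanField(start, end, passage_field)
                             for start, end in unique_spans])
else:
    span_fields = ListField([SpanField(-1, -1, passage_field)])   # sentinel: no answer span

instance = Instance({
    "passage": passage_field,
    "spans": span_fields,
    "metadata": MetadataField({"answer_texts": ["324 metres"]}),
})
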
fields["answer_as_expressions_extra"] = ListField(extra_signs_field)
elif self.exp_search in ['template', 'full']:
expression_indices = []
for expression in valid_expressions:
if not expression:
expression.append(3 * [-1])
expression_indices.append(ArrayField(np.array(expression), padding_value=-1))
if not expression_indices:
expression_indices = \
[ArrayField(np.array([3 * [-1]]), padding_value=-1) for _ in range(len(self.templates))]
fields["answer_as_expressions"] = ListField(expression_indices)
count_fields: List[Field] = [LabelField(count_label, skip_indexing=True) for count_label in valid_counts]
if not count_fields:
count_fields.append(LabelField(-1, skip_indexing=True))
fields["answer_as_counts"] = ListField(count_fields)
fields["num_spans"] = LabelField(num_spans, skip_indexing=True)
fields["metadata"] = MetadataField(metadata)
return Instance(fields)