def _get_combination(combination: str, tensors: List[torch.Tensor]) -> torch.Tensor:
    if combination.isdigit():
        index = int(combination) - 1
        return tensors[index]
    else:
        if len(combination) != 3:
            raise ConfigurationError("Invalid combination: " + combination)
        first_tensor = _get_combination(combination[0], tensors)
        second_tensor = _get_combination(combination[2], tensors)
        operation = combination[1]
        if operation == '*':
            return first_tensor * second_tensor
        elif operation == '/':
            return first_tensor / second_tensor
        elif operation == '+':
            return first_tensor + second_tensor
        elif operation == '-':
            return first_tensor - second_tensor
        else:
            raise ConfigurationError("Invalid operation: " + operation)
        instances = []
        # open data file and read lines
        with open(file_path, 'r') as ontm_file:
            logger.info("Reading ontology matching instances from jsonl dataset at: %s", file_path)
            for line in tqdm.tqdm(ontm_file):
                training_pair = json.loads(line)
                s_ent = training_pair['source_ent']
                t_ent = training_pair['target_ent']
                label = training_pair['label']

                # convert entry to instance and append to instances
                instances.append(self.text_to_instance(s_ent, t_ent, label))

        if not instances:
            raise ConfigurationError("No instances were read from the given filepath {}. "
                                     "Is the path correct?".format(file_path))
        return Dataset(instances)
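# --- Example input line (hypothetical values, for illustration only) ---------
# Each line of the JSONL file read above is expected to carry the keys
# 'source_ent', 'target_ent', and 'label'; the entity URIs and label below
# are made up.
import json

example_line = json.dumps({
    "source_ent": "http://example.org/onto_a#Heart",
    "target_ent": "http://example.org/onto_b#CardiacStructure",
    "label": "match",
})
record = json.loads(example_line)
assert {"source_ent", "target_ent", "label"} <= record.keys()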
        loss : torch.FloatTensor, optional
            A scalar loss to be optimised.
        """
        embedded_text_input = self.embedding_dropout(self.text_field_embedder(tokens))
        text_mask = util.get_text_field_mask(tokens)
        embedded_verb_indicator = self.binary_feature_embedding(verb_indicator.long())
        # Concatenate the verb feature onto the embedded text. This now
        # has shape (batch_size, sequence_length, embedding_dim + binary_feature_dim).
        embedded_text_with_verb_indicator = torch.cat(
            [embedded_text_input, embedded_verb_indicator], -1)
        embedding_dim_with_binary_feature = embedded_text_with_verb_indicator.size()[2]
        if self.stacked_encoder.get_input_dim() != embedding_dim_with_binary_feature:
            raise ConfigurationError("The SRL model uses a binary verb indicator feature, which "
                                     "makes the embedding dimension larger than the value "
                                     "specified. Therefore, the 'input_dim' of the stacked_encoder "
                                     "must be equal to total_embedding_dim + binary_feature_dim.")
        encoded_text = self.stacked_encoder(
            embedded_text_with_verb_indicator, text_mask)

        batch_size, num_spans = tags.size()
        assert num_spans % self.max_span_width == 0
        tags = tags.view(batch_size, -1, self.max_span_width)
        span_starts = F.relu(span_starts.float()).long().view(batch_size, -1)
        span_ends = F.relu(span_ends.float()).long().view(batch_size, -1)
        target_index = F.relu(target_index.float()).long().view(batch_size)
        # shape (batch_size, sequence_length * max_span_width, embedding_dim)
        span_embeddings = span_srl_util.compute_span_representations(self.max_span_width,
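# --- Standalone shape sketch (illustrative dimensions, not model code) -------
# The concatenation above appends a verb-indicator embedding to each token
# embedding, so the encoder input size must be embedding_dim + binary_feature_dim.
import torch

batch_size, sequence_length, embedding_dim, binary_feature_dim = 2, 5, 10, 4
embedded_text = torch.randn(batch_size, sequence_length, embedding_dim)
verb_feature = torch.randn(batch_size, sequence_length, binary_feature_dim)
combined = torch.cat([embedded_text, verb_feature], -1)
assert combined.size(2) == embedding_dim + binary_feature_dim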
        tokens : ``List[str]``, required.
            The original string tokens in the sentence.
        arc_tags : torch.LongTensor, optional (default = None)
            A torch tensor representing the sequence of integer indices denoting the parent of every
            word in the dependency parse. Has shape ``(batch_size, sequence_length, sequence_length)``.

        Returns
        -------
        An output dictionary.
        """
        embedded_text_input = self.text_field_embedder(tokens)
        if pos_tags is not None and self._pos_tag_embedding is not None:
            embedded_pos_tags = self._pos_tag_embedding(pos_tags)
            embedded_text_input = torch.cat([embedded_text_input, embedded_pos_tags], -1)
        elif self._pos_tag_embedding is not None:
            raise ConfigurationError("Model uses a POS embedding, but no POS tags were passed.")

        mask = get_text_field_mask(tokens)
        embedded_text_input = self._input_dropout(embedded_text_input)
        encoded_text = self.encoder(embedded_text_input, mask)

        float_mask = mask.float()
        encoded_text = self._dropout(encoded_text)

        # shape (batch_size, sequence_length, arc_representation_dim)
        head_arc_representation = self._dropout(self.head_arc_feedforward(encoded_text))
        child_arc_representation = self._dropout(self.child_arc_feedforward(encoded_text))

        # shape (batch_size, sequence_length, tag_representation_dim)
        head_tag_representation = self._dropout(self.head_tag_feedforward(encoded_text))
        child_tag_representation = self._dropout(self.child_tag_feedforward(encoded_text))
        # shape (batch_size, sequence_length, sequence_length)
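# --- Standalone sketch of the (batch, seq, seq) arc-score shape --------------
# A bilinear form over head/child arc representations is one common way to get
# the (batch_size, sequence_length, sequence_length) scores noted above; the
# weight tensor W here is illustrative, not the model's actual parameter.
import torch

batch_size, sequence_length, arc_representation_dim = 2, 6, 8
head = torch.randn(batch_size, sequence_length, arc_representation_dim)
child = torch.randn(batch_size, sequence_length, arc_representation_dim)
W = torch.randn(arc_representation_dim, arc_representation_dim)
attended_arcs = torch.einsum("bid,de,bje->bij", head, W, child)
assert attended_arcs.shape == (batch_size, sequence_length, sequence_length)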
u"are ignoring them, using instead the model parameters in the archive.")
vocabulary_params = params.pop(u'vocabulary', {})
if vocabulary_params.get(u'directory_path', None):
logger.warning(u"You passed `directory_path` in parameters for the vocabulary in "
u"your configuration file, but it will be ignored. ")
all_datasets = datasets_from_params(params)
vocab = model.vocab
if extend_vocab:
datasets_for_vocab_creation = set(params.pop(u"datasets_for_vocab_creation", all_datasets))
for dataset in datasets_for_vocab_creation:
if dataset not in all_datasets:
raise ConfigurationError("invalid 'dataset_for_vocab_creation' {dataset}")
logger.info(u"Extending model vocabulary using %s data.", u", ".join(datasets_for_vocab_creation))
vocab.extend_from_instances(vocabulary_params,
(instance for key, dataset in list(all_datasets.items())
for instance in dataset
if key in datasets_for_vocab_creation))
vocab.save_to_files(os.path.join(serialization_dir, u"vocabulary"))
iterator = DataIterator.from_params(params.pop(u"iterator"))
iterator.index_with(model.vocab)
train_data = all_datasets[u'train']
validation_data = all_datasets.get(u'validation')
test_data = all_datasets.get(u'test')
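# --- Toy illustration of the vocabulary-extension filter above ---------------
# The generator passed to extend_from_instances flattens instances from only
# the datasets named in datasets_for_vocab_creation (toy strings stand in for
# real Instances here).
toy_datasets = {"train": ["i1", "i2"], "validation": ["i3"], "test": ["i4"]}
toy_selection = {"train", "validation"}
selected = [instance
            for key, dataset in list(toy_datasets.items())
            for instance in dataset
            if key in toy_selection]
assert selected == ["i1", "i2", "i3"]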
    def _get_instance_data(self):
        if self._dataset_reader is None:
            raise ConfigurationError(u"To generate instances directly, pass a DatasetReader.")
        else:
            yield self._dataset_reader.read(self._input_file)
        use_prelinked_entities: bool = True,
        use_untyped_entities: bool = True,
        token_indexers: Dict[str, TokenIndexer] = None,
        cross_validation_split_to_exclude: int = None,
        keep_if_unparseable: bool = True,
        lazy: bool = False,
    ) -> None:
        super().__init__(lazy)
        self._token_indexers = token_indexers or {"tokens": SingleIdTokenIndexer()}
        self._use_all_sql = use_all_sql
        self._remove_unneeded_aliases = remove_unneeded_aliases
        self._use_prelinked_entities = use_prelinked_entities
        self._keep_if_unparsable = keep_if_unparseable

        if not self._use_prelinked_entities:
            raise ConfigurationError(
                "The grammar based text2sql dataset reader "
                "currently requires the use of entity pre-linking."
            )

        self._cross_validation_split_to_exclude = str(cross_validation_split_to_exclude)

        if database_file is not None:
            database_file = cached_path(database_file)
            connection = sqlite3.connect(database_file)
            self._cursor = connection.cursor()
        else:
            self._cursor = None

        self._schema_path = schema_path
        self._world = Text2SqlWorld(
            schema_path,
        self.tag_projection_layer = TimeDistributed(
            Linear(self.text_field_embedder.get_output_dim(), self.num_tags)
        )

        # If constrain_crf_decoding and calculate_span_f1 are not provided
        # (i.e., they're None), set them to True if label_encoding is provided
        # and False if it isn't.
        if constrain_crf_decoding is None:
            constrain_crf_decoding = label_encoding is not None
        if calculate_span_f1 is None:
            calculate_span_f1 = label_encoding is not None

        self.label_encoding = label_encoding
        if constrain_crf_decoding:
            if not label_encoding:
                raise ConfigurationError("constrain_crf_decoding is True, but "
                                         "no label_encoding was specified.")
            labels = self.vocab.get_index_to_token_vocabulary(label_namespace)
            constraints = allowed_transitions(label_encoding, labels)
        else:
            constraints = None

        self.include_start_end_transitions = include_start_end_transitions
        self.crf = ConditionalRandomField(
            self.num_tags, constraints,
            include_start_end_transitions=include_start_end_transitions
        )

        self.metrics = {
            "accuracy": CategoricalAccuracy(),
            "accuracy3": CategoricalAccuracy(top_k=3)
        }
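# --- Hedged sketch of the CRF constraint setup above -------------------------
# Assuming allennlp's allowed_transitions(constraint_type, labels) helper,
# where `labels` maps indices to tag strings as returned by
# vocab.get_index_to_token_vocabulary; the label set below is made up.
from allennlp.modules.conditional_random_field import allowed_transitions

example_labels = {0: "B-PER", 1: "I-PER", 2: "O"}
example_constraints = allowed_transitions("BIO", example_labels)  # allowed (from_index, to_index) pairs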
        for key in flat_params.keys() - flat_loaded.keys():
            logger.error(f"Key '{key}' found in training configuration but not in the serialization "
                         f"directory we're recovering from.")
            fail = True
        for key in flat_loaded.keys() - flat_params.keys():
            logger.error(f"Key '{key}' found in the serialization directory we're recovering from "
                         f"but not in the training config.")
            fail = True
        for key in flat_params.keys():
            if flat_params.get(key, None) != flat_loaded.get(key, None):
logger.error(f"Value for '{key}' in training configuration does not match that the value in "
f"the serialization directory we're recovering from: "
f"{flat_params[key]} != {flat_loaded[key]}")
                fail = True
        if fail:
            raise ConfigurationError("Training configuration does not match the configuration we're "
                                     "recovering from.")
    else:
        if recover:
            raise ConfigurationError(f"--recover specified but serialization_dir ({serialization_dir}) "
                                     "does not exist. There is nothing to recover from.")
        os.makedirs(serialization_dir, exist_ok=True)
        for key in set(flat_loaded.keys()) - set(flat_params.keys()):
            logger.error(u"Key '{}' found in the serialization directory we're recovering from "
                         u"but not in the training config.".format(key))
            fail = True
        for key in list(flat_params.keys()):
            if flat_params.get(key, None) != flat_loaded.get(key, None):
                logger.error(u"Value for '{}' in training configuration does not match the value in "
                             u"the serialization directory we're recovering from: "
                             u"{} != {}".format(key, flat_params[key], flat_loaded[key]))
                fail = True
        if fail:
            raise ConfigurationError(u"Training configuration does not match the configuration we're "
                                     u"recovering from.")
    else:
        if recover:
            raise ConfigurationError(u"--recover specified but serialization_dir ({}) "
                                     u"does not exist. There is nothing to recover from.".format(serialization_dir))
        os.makedirs(serialization_dir, exist_ok=True)
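# --- Toy illustration of the recovery checks above ----------------------------
# The loops compare flattened parameter dicts from the training config and the
# serialization directory; the contents below are made up purely to show what
# each check catches.
toy_flat_params = {"model.type": "crf_tagger", "trainer.num_epochs": 10}
toy_flat_loaded = {"model.type": "crf_tagger", "trainer.num_epochs": 5, "trainer.cuda_device": 0}

only_in_params = toy_flat_params.keys() - toy_flat_loaded.keys()   # set(): nothing missing from the archive
only_in_loaded = toy_flat_loaded.keys() - toy_flat_params.keys()   # {'trainer.cuda_device'}
mismatched = {key for key in toy_flat_params
              if toy_flat_params.get(key) != toy_flat_loaded.get(key)}  # {'trainer.num_epochs'}
assert only_in_loaded == {"trainer.cuda_device"} and mismatched == {"trainer.num_epochs"}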