Secure your code as it's written. Use Snyk Code to scan source code in minutes - no build needed - and fix issues immediately.
token_type_ids=segment_ids)
# Pair every example in this evaluation batch with its feature and a
# RawReadResult holding the reader's start/end logits (as Python lists)
# and the rank logit looked up from eval_rank_logits.
batch_features, batch_results = [], []
for j, example_index in enumerate(example_indices):
    # Convert the index once via .item() and reuse it for both lookups.
    feature_idx = example_index.item()
    start_logits = batch_start_logits[j].detach().cpu().tolist()
    end_logits = batch_end_logits[j].detach().cpu().tolist()
    eval_feature = eval_features[feature_idx]
    eval_rank_logit = eval_rank_logits[feature_idx]
    unique_id = int(eval_feature.unique_id)
    batch_features.append(eval_feature)
    batch_results.append(RawReadResult(unique_id=unique_id,
                                       start_logits=start_logits,
                                       end_logits=end_logits,
                                       rank_logit=eval_rank_logit))
# Obtain candidate span start/end indices for the batch from annotate_candidates
# (its other two return values are unused on this path).
# NOTE(review): the positional `False` is presumably an is-training flag
# (the training call passes `True` in the same position). Confirm against
# annotate_candidates' signature and prefer passing it by keyword.
span_starts, span_ends, _, _ = annotate_candidates(eval_examples, batch_features, batch_results,
                                                   args.filter_type, False, args.n_best_size_read,
                                                   args.max_answer_length, args.do_lower_case,
                                                   args.verbose_logging, logger)
# Create the int64 span-index tensors directly on the target device.
span_starts = torch.tensor(span_starts, dtype=torch.long, device=device)
span_ends = torch.tensor(span_ends, dtype=torch.long, device=device)
sequence_output = sequence_output.to(device)
# Run the model's 'rerank_inference' mode on the selected spans with
# gradient tracking disabled.
with torch.no_grad():
    batch_rerank_logits = model('rerank_inference', input_mask, span_starts=span_starts,
                                span_ends=span_ends, sequence_input=sequence_output)
# Per-example conversion of the batch's reader and reranker outputs to
# host-side values: start/end logits become Python lists, and rerank logits
# become a NumPy array.
# NOTE(review): this loop body appears truncated where this snippet ends. As
# visible, start_logits/end_logits repeat the conversion already done in the
# feature-collection loop above, and rerank_logits is not yet consumed. Confirm
# against the full source.
for j, example_index in enumerate(example_indices):
start_logits = batch_start_logits[j].detach().cpu().tolist()
end_logits = batch_end_logits[j].detach().cpu().tolist()
rerank_logits = batch_rerank_logits[j].detach().cpu().numpy()
# Unpack one training batch and run the model in 'read_inference' mode to
# obtain start/end logits (its third output is discarded).
input_ids, input_mask, segment_ids, start_positions, end_positions, example_indices = read_batch
# NOTE(review): within this snippet these logits are only consumed through
# .detach(). If that holds in the full source, wrapping this call in
# torch.no_grad() would avoid building an unused autograd graph.
batch_start_logits, batch_end_logits, _ = model('read_inference', input_mask,
                                                input_ids=input_ids, token_type_ids=segment_ids)
# Pair every training feature in the batch with a RawReadResult holding its
# start/end logits as Python lists. rank_logit is fixed to 0. on this path.
batch_features, batch_results = [], []
for j, example_index in enumerate(example_indices):
    start_logits = batch_start_logits[j].detach().cpu().tolist()
    end_logits = batch_end_logits[j].detach().cpu().tolist()
    train_feature = train_read_features[example_index.item()]
    unique_id = int(train_feature.unique_id)
    batch_features.append(train_feature)
    batch_results.append(RawReadResult(unique_id=unique_id,
                                       start_logits=start_logits,
                                       end_logits=end_logits,
                                       rank_logit=0.))
# Obtain candidate span start/end indices plus hard and soft labels for the
# training batch. The positional flag here is `True`, whereas the evaluation
# call passes `False` in the same position.
span_starts, span_ends, hard_labels, soft_labels = annotate_candidates(
    train_examples, batch_features, batch_results, args.filter_type,
    True, args.n_best_size_read, args.max_answer_length,
    args.do_lower_case, args.verbose_logging, logger,
)
# Build each index/label list as an int64 tensor directly on the target device.
# NOTE(review): soft_labels is cast to torch.long like the others. If
# annotate_candidates returns fractional soft scores, this truncates them.
# Confirm the intended dtype.
span_starts = torch.tensor(span_starts, dtype=torch.long, device=device)
span_ends = torch.tensor(span_ends, dtype=torch.long, device=device)
hard_labels = torch.tensor(hard_labels, dtype=torch.long, device=device)
soft_labels = torch.tensor(soft_labels, dtype=torch.long, device=device)
read_rerank_loss = model('read_rerank_train', input_mask, input_ids=input_ids, token_type_ids=segment_ids,
start_positions=start_positions, end_positions=end_positions, span_starts=span_starts,