import json
import os

import torch

# Move the candidate span boundaries and the encoder output onto the same
# device as the model before the reranking pass.
span_starts = span_starts.to(device)
span_ends = span_ends.to(device)
sequence_output = sequence_output.to(device)

# Score every candidate span with the reranker; gradients are not needed
# at inference time.
with torch.no_grad():
    batch_rerank_logits = model('rerank_inference', input_mask,
                                span_starts=span_starts, span_ends=span_ends,
                                sequence_input=sequence_output)

# Gather the per-example logits and span indexes on the CPU and pair them
# with the corresponding evaluation feature (batch_start_logits and
# batch_end_logits come from the earlier reading pass, not shown here).
for j, example_index in enumerate(example_indices):
    start_logits = batch_start_logits[j].detach().cpu().tolist()
    end_logits = batch_end_logits[j].detach().cpu().tolist()
    rerank_logits = batch_rerank_logits[j].detach().cpu().numpy()
    start_indexes = span_starts[j].detach().cpu().tolist()
    end_indexes = span_ends[j].detach().cpu().tolist()
    eval_feature = eval_features[example_index.item()]
    eval_rank_logit = eval_rank_logits[example_index.item()]
    unique_id = int(eval_feature.unique_id)
    all_results.append(RawFinalResult(unique_id=unique_id,
                                      start_logits=start_logits,
                                      end_logits=end_logits,
                                      rank_logit=eval_rank_logit,
                                      rerank_logits=rerank_logits,
                                      start_indexes=start_indexes,
                                      end_indexes=end_indexes))

# Combine the ranking and reranking scores into final answer predictions.
all_predictions, all_nbest_json = write_rerank_predictions(
    eval_examples, eval_features, all_results, args.length_heuristic,
    args.pred_rank_weight, args.pred_rerank_weight, args.ablate_type,
    args.n_best_size_read, args.max_answer_length, args.do_lower_case,
    args.verbose_logging, logger)

# Optionally dump the predictions and the n-best lists as JSON. Here `type`
# is presumably the evaluation split name (e.g. "dev" or "test"); note that
# it shadows the built-in `type`.
if write_pred:
    output_prediction_file = os.path.join(args.output_dir, "{}_predictions.json".format(type))
    output_nbest_file = os.path.join(args.output_dir, "{}_nbest_predictions.json".format(type))
    with open(output_prediction_file, "w") as writer:
        writer.write(json.dumps(all_predictions, indent=4) + "\n")
    with open(output_nbest_file, "w") as writer:
        writer.write(json.dumps(all_nbest_json, indent=4) + "\n")
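
# For context, a minimal sketch of the RawFinalResult container consumed
# above. In BERT-style QA code such result records are typically declared
# as a collections.namedtuple; the field names below simply mirror the
# keyword arguments used in this snippet, so treat this as an assumption
# rather than the repo's actual definition.
import collections

RawFinalResult = collections.namedtuple(
    "RawFinalResult",
    ["unique_id", "start_logits", "end_logits",
     "rank_logit", "rerank_logits", "start_indexes", "end_indexes"])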