Secure your code as it's written. Use Snyk Code to scan source code in minutes - no build needed - and fix issues immediately.
# NOTE(review): this excerpt starts mid-function and its indentation has been
# lost, so the nesting of the statements below cannot be read from the text.
# `model`, `inputs`, `example_indices`, `features`, `args`, `all_results` and
# `output_dir` are all defined outside the visible lines.
# Call the model with the prepared keyword arguments.
outputs = model(**inputs)
# Walk example_indices; i is the position used to index into each output.
for i, example_index in enumerate(example_indices):
# example_index.item() gives the index used to look up the feature.
eval_feature = features[example_index.item()]
unique_id = int(eval_feature.unique_id)
if args['model_type'] in ['xlnet', 'xlm']:
# XLNet uses a more complex post-processing procedure
# For these model types, outputs[0] through outputs[4] are all recorded.
result = RawResultExtended(unique_id=unique_id,
start_top_log_probs=to_list(outputs[0][i]),
start_top_index=to_list(outputs[1][i]),
end_top_log_probs=to_list(outputs[2][i]),
end_top_index=to_list(outputs[3][i]),
cls_logits=to_list(outputs[4][i]))
else:
# Every other model type records only outputs[0] and outputs[1].
result = RawResult(unique_id=unique_id,
start_logits=to_list(outputs[0][i]),
end_logits=to_list(outputs[1][i]))
all_results.append(result)
# Fixed tag inserted into the three output file names built below.
prefix = 'test'
# Create the output directory when it is missing.
# NOTE(review): os.mkdir does not create missing parent directories.
if not os.path.isdir(output_dir):
os.mkdir(output_dir)
# Paths for the predictions, n-best predictions and null-odds JSON files.
output_prediction_file = os.path.join(output_dir, "predictions_{}.json".format(prefix))
output_nbest_file = os.path.join(output_dir, "nbest_predictions_{}.json".format(prefix))
output_null_log_odds_file = os.path.join(output_dir, "null_odds_{}.json".format(prefix))
if args['model_type'] in ['xlnet', 'xlm']:
# XLNet uses a more complex post-processing procedure
all_predictions, all_nbest_json, scores_diff_json = write_predictions_extended(examples, features, all_results, args['n_best_size'],
args['max_answer_length'], output_prediction_file,
# NOTE(review): this excerpt starts mid-function and its indentation has been
# lost, so the nesting of the statements below cannot be read from the text.
# `model`, `inputs`, `example_indices`, `features`, `examples`, `args`,
# `all_results`, `n_best_size` and `tokenizer` are defined outside the visible
# lines.
# Call the model with the prepared keyword arguments.
outputs = model(**inputs)
# Walk example_indices; i is the position used to index into each output.
for i, example_index in enumerate(example_indices):
# example_index.item() gives the index used to look up the feature.
eval_feature = features[example_index.item()]
unique_id = int(eval_feature.unique_id)
if args['model_type'] in ['xlnet', 'xlm']:
# XLNet uses a more complex post-processing procedure
# For these model types, outputs[0] through outputs[4] are all recorded.
result = RawResultExtended(unique_id=unique_id,
start_top_log_probs=to_list(outputs[0][i]),
start_top_index=to_list(outputs[1][i]),
end_top_log_probs=to_list(outputs[2][i]),
end_top_index=to_list(outputs[3][i]),
cls_logits=to_list(outputs[4][i]))
else:
# Every other model type records only outputs[0] and outputs[1].
result = RawResult(unique_id=unique_id,
start_logits=to_list(outputs[0][i]),
end_logits=to_list(outputs[1][i]))
all_results.append(result)
# Pick the prediction helper that matches the model type; the extended helper
# additionally receives start_n_top / end_n_top from the model config, the
# tokenizer and the null-score threshold from args.
if args['model_type'] in ['xlnet', 'xlm']:
answers = get_best_predictions_extended(examples, features, all_results, n_best_size,
args['max_answer_length'], model.config.start_n_top, model.config.end_n_top, True, tokenizer, args['null_score_diff_threshold'])
else:
# NOTE(review): the four trailing positional booleans are opaque from here;
# confirm their meaning against get_best_predictions' signature.
answers = get_best_predictions(examples, features, all_results, n_best_size, args['max_answer_length'], False, False, True, False)
# Hand back whatever the selected helper produced.
return answers