How to use the squad.squad_document_utils.RawRankResult function in squad

To help you get started, we've selected a few squad examples based on popular ways the function is used in public projects.
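Judging by how it is constructed in the examples below, RawRankResult is a simple record pairing a feature's unique_id with the ranker's score for it. A minimal sketch of an equivalent definition (the actual one lives in squad.squad_document_utils; the field values here are illustrative):

import collections

# Lightweight record: one ranker score per feature, keyed by unique_id.
# Sketched as a namedtuple, matching how the examples below construct it.
RawRankResult = collections.namedtuple("RawRankResult",
                                       ["unique_id", "rank_logit"])

# Hypothetical values, for illustration only.
result = RawRankResult(unique_id=1000000001, rank_logit=2.73)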


Example from huminghao16/RE3QA: bert/run_squad_document_full_e2e.py (view on GitHub)
# Excerpt; relies on os, pickle and torch, plus RawRankResult (and,
# presumably, eval_ranking) from squad.squad_document_utils.
def evaluate_rank(args, model, device, eval_examples, eval_features, eval_dataloader, logger, type, n_para,
                  force_answer=False, write_pred=False, verbose_logging=False):
    all_results = []
    for input_ids, input_mask, segment_ids, example_indices in eval_dataloader:
        if len(all_results) % 5000 == 0 and verbose_logging:
            logger.info("Processing example: %d" % (len(all_results)))
        # Move the batch onto the evaluation device (CPU or GPU).
        input_ids = input_ids.to(device)
        input_mask = input_mask.to(device)
        segment_ids = segment_ids.to(device)
        # Run the ranker head without tracking gradients.
        with torch.no_grad():
            batch_rank_logits = model('rank', input_mask, input_ids=input_ids, token_type_ids=segment_ids)
        # Collect one RawRankResult per feature, keyed by its unique_id;
        # the second logit is kept as the feature's rank score.
        for i, example_index in enumerate(example_indices):
            rank_logits = batch_rank_logits[i].detach().cpu().tolist()
            eval_feature = eval_features[example_index.item()]
            unique_id = int(eval_feature.unique_id)
            all_results.append(RawRankResult(unique_id=unique_id, rank_logit=float(rank_logits[1])))
    # Score the collected results against the gold examples.
    metrics, rank_predictions = eval_ranking(force_answer, args.n_best_size_rank, eval_examples, eval_features, all_results)

    if write_pred:
        # File name encodes the split, paragraph count and n-best size.
        rank_pred_file = "{}_{}paras_{}best.pkl".format(type, n_para, args.n_best_size_rank)
        rank_pred_path = os.path.join(args.output_dir, rank_pred_file)
        with open(rank_pred_path, 'wb') as f:
            pickle.dump(rank_predictions, f)
        # During distillation the predictions become training targets;
        # otherwise they are stored as plain rank predictions.
        if type == 'distill':
            args.rank_train_file = rank_pred_file
        else:
            args.rank_pred_file = rank_pred_file
    return metrics
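The collected RawRankResult records are consumed by eval_ranking, whose internals are not shown here. The core selection step presumably amounts to sorting features by rank_logit and keeping the best few per example. A minimal sketch under that assumption; select_top_features is a hypothetical helper, and the example_index attribute on each feature is assumed to link it back to its source example:

from collections import defaultdict

def select_top_features(eval_features, all_results, n_best_size_rank):
    # Map each feature's unique_id to its ranker score.
    logit_by_id = {r.unique_id: r.rank_logit for r in all_results}
    # Group features by the example they were generated from.
    per_example = defaultdict(list)
    for feature in eval_features:
        per_example[feature.example_index].append(feature)
    # Keep the n_best_size_rank highest-scoring features per example.
    top = {}
    for example_index, features in per_example.items():
        features.sort(key=lambda f: logit_by_id[int(f.unique_id)], reverse=True)
        top[example_index] = features[:n_best_size_rank]
    return top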
Example from huminghao16/RE3QA: bert/run_triviaqa_wiki_full_e2e.py (view on GitHub)
The evaluate_rank function in this file is identical, line for line, to the one shown above; the same rank-collect-serialize pattern is reused for the TriviaQA pipeline.
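Because the rank predictions are written with pickle, reading them back for a later distillation or prediction pass is symmetric. A minimal sketch, assuming the same args namespace whose fields are set by evaluate_rank above:

import os
import pickle

# Reload the serialized n-best rank predictions written by evaluate_rank.
rank_pred_path = os.path.join(args.output_dir, args.rank_pred_file)
with open(rank_pred_path, 'rb') as f:
    rank_predictions = pickle.load(f)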