Secure your code as it's written. Use Snyk Code to scan source code in minutes - no build needed - and fix issues immediately.
def test_simple_fusion_searcher(self):
    """RRF-fuse the three CORD-19 (2020-05-01) indexes and compare against the Anserini fusion run.

    Every ``covid_round2`` topic (question + ' ' + query) is searched through a ``SimpleFusionSearcher`` over the
    abstract, full-text and paragraph indexes. The fused run is written to ``runs/fused.txt``, and its
    topic/docid/rank columns must match those of ``runs/anserini.covid-r2.fusion1.txt``.
    """
    def keep_topic_docid_rank(src, dst):
        # Python equivalent of awk '{print $1" "$3" "$4}': whitespace-split fields, missing fields printed as
        # empty strings, one output line per input line. Done in-process rather than via os.system so the test
        # is portable and a missing/failed external tool cannot leave stale output files to be compared.
        # Scores (column 5) are dropped because the two runs print them with different floating point precisions.
        with open(src, encoding='utf-8') as fin, open(dst, 'w', encoding='utf-8') as fout:
            for line in fin:
                cols = line.split()
                fout.write(' '.join(cols[i] if i < len(cols) else '' for i in (0, 2, 3)) + '\n')

    index_dirs = ['indexes/lucene-index-cord19-abstract-2020-05-01/',
                  'indexes/lucene-index-cord19-full-text-2020-05-01/',
                  'indexes/lucene-index-cord19-paragraph-2020-05-01/']
    searcher = SimpleFusionSearcher(index_dirs, method=FusionMethod.RRF)
    runs, topics = [], get_topics('covid_round2')
    for topic in tqdm(sorted(topics.keys())):
        query = topics[topic]['question'] + ' ' + topics[topic]['query']
        hits = searcher.search(query, k=10000, query_generator=None, strip_segment_id=True, remove_dups=True)
        docid_score_pair = [(hit.docid, hit.score) for hit in hits]
        run = TrecRun.from_search_results(docid_score_pair, topic=topic)
        runs.append(run)
    all_topics_run = TrecRun.concat(runs)
    all_topics_run.save_to_txt(output_path='runs/fused.txt', tag='reciprocal_rank_fusion_k=60')

    keep_topic_docid_rank('runs/fused.txt', 'runs/this.txt')
    keep_topic_docid_rank('runs/anserini.covid-r2.fusion1.txt', 'runs/that.txt')
    # shallow=False: the two files are written back to back, so a shallow comparison could report them equal
    # from matching os.stat() signatures (type, size, mtime) without ever reading their contents.
    self.assertTrue(filecmp.cmp('runs/this.txt', 'runs/that.txt', shallow=False))
def test_discard_qrels(self):
    """``TrecRun.discard_qrels`` (in place, ``clone=False``) must reproduce the stored ``remove`` fixture."""
    run = TrecRun('tests/resources/simple_trec_run_filter.txt')
    qrels = Qrels('tools/topics-and-qrels/qrels.covid-round1.txt')
    run.discard_qrels(qrels, clone=False).save_to_txt(output_path=self.output_path)
    # shallow=False forces a content comparison; the default shallow mode may declare two files equal
    # from matching os.stat() signatures (type, size, mtime) alone.
    self.assertTrue(filecmp.cmp('tests/resources/simple_trec_run_remove_verify.txt', self.output_path,
                                shallow=False))
def test_trec_run_read(self):
    """Reading a run file into ``TrecRun`` and saving it back must reproduce the stored fixture."""
    input_path = 'tests/resources/simple_trec_run_read.txt'
    verify_path = 'tests/resources/simple_trec_run_read_verify.txt'
    run = TrecRun(filepath=input_path)
    run.save_to_txt(self.output_path)
    # shallow=False forces a content comparison; the default shallow mode may declare two files equal
    # from matching os.stat() signatures (type, size, mtime) alone.
    self.assertTrue(filecmp.cmp(verify_path, self.output_path, shallow=False))
def test_normalize_scores(self):
    """Rescoring with ``RescoreMethod.NORMALIZE`` must reproduce the stored ``normalize`` fixture."""
    run = TrecRun('tests/resources/simple_trec_run_fusion_1.txt')
    run.rescore(RescoreMethod.NORMALIZE).save_to_txt(self.output_path)
    # shallow=False forces a content comparison; the default shallow mode may declare two files equal
    # from matching os.stat() signatures (type, size, mtime) alone.
    self.assertTrue(filecmp.cmp('tests/resources/simple_trec_run_normalize_verify.txt', self.output_path,
                                shallow=False))
def test_simple_fusion_searcher(self):
    """RRF-fuse the three CORD-19 (2020-05-01) indexes and compare against the Anserini fusion run.

    Every ``covid_round2`` topic (question + ' ' + query) is searched through a ``SimpleFusionSearcher`` over the
    abstract, full-text and paragraph indexes. The fused run is written to ``runs/fused.txt``, and its
    topic/docid/rank columns must match those of ``runs/anserini.covid-r2.fusion1.txt``.
    """
    def keep_topic_docid_rank(src, dst):
        # Python equivalent of awk '{print $1" "$3" "$4}': whitespace-split fields, missing fields printed as
        # empty strings, one output line per input line. Done in-process rather than via os.system so the test
        # is portable and a missing/failed external tool cannot leave stale output files to be compared.
        # Scores (column 5) are dropped because the two runs print them with different floating point precisions.
        with open(src, encoding='utf-8') as fin, open(dst, 'w', encoding='utf-8') as fout:
            for line in fin:
                cols = line.split()
                fout.write(' '.join(cols[i] if i < len(cols) else '' for i in (0, 2, 3)) + '\n')

    index_dirs = ['indexes/lucene-index-cord19-abstract-2020-05-01/',
                  'indexes/lucene-index-cord19-full-text-2020-05-01/',
                  'indexes/lucene-index-cord19-paragraph-2020-05-01/']
    searcher = SimpleFusionSearcher(index_dirs, method=FusionMethod.RRF)
    runs, topics = [], get_topics('covid_round2')
    for topic in tqdm(sorted(topics.keys())):
        query = topics[topic]['question'] + ' ' + topics[topic]['query']
        hits = searcher.search(query, k=10000, query_generator=None, strip_segment_id=True, remove_dups=True)
        docid_score_pair = [(hit.docid, hit.score) for hit in hits]
        run = TrecRun.from_search_results(docid_score_pair, topic=topic)
        runs.append(run)
    all_topics_run = TrecRun.concat(runs)
    all_topics_run.save_to_txt(output_path='runs/fused.txt', tag='reciprocal_rank_fusion_k=60')

    keep_topic_docid_rank('runs/fused.txt', 'runs/this.txt')
    keep_topic_docid_rank('runs/anserini.covid-r2.fusion1.txt', 'runs/that.txt')
    # shallow=False: the two files are written back to back, so a shallow comparison could report them equal
    # from matching os.stat() signatures (type, size, mtime) without ever reading their contents.
    self.assertTrue(filecmp.cmp('runs/this.txt', 'runs/that.txt', shallow=False))
def test_retain_qrels(self):
    """``TrecRun.retain_qrels`` (on a clone, ``clone=True``) must reproduce the stored ``keep`` fixture."""
    run = TrecRun('tests/resources/simple_trec_run_filter.txt')
    qrels = Qrels('tools/topics-and-qrels/qrels.covid-round1.txt')
    run.retain_qrels(qrels, clone=True).save_to_txt(output_path=self.output_path)
    # shallow=False forces a content comparison; the default shallow mode may declare two files equal
    # from matching os.stat() signatures (type, size, mtime) alone.
    self.assertTrue(filecmp.cmp('tests/resources/simple_trec_run_keep_verify.txt', self.output_path,
                                shallow=False))
depth : int
Maximum number of results from each input run to consider. Set to ``None`` by default, which indicates that
the complete list of results is considered.
k : int
Length of final results list. Set to ``None`` by default, which indicates that the union of all input documents
is ranked.
Returns
-------
TrecRun
Output ``TrecRun`` that combines input runs via reciprocal rank fusion.
"""
# TODO: Add option to *not* clone runs, thus making the method destructive, but also more efficient.
rrf_runs = [run.clone().rescore(method=RescoreMethod.RRF, rrf_k=rrf_k) for run in runs]
return TrecRun.merge(rrf_runs, AggregationMethod.SUM, depth=depth, k=k)
are ranked.
Returns
-------
TrecRun
Output ``TrecRun`` that combines input runs via interpolation.
"""
if len(runs) != 2:
raise Exception('Interpolation must be performed on exactly two runs.')
scaled_runs = []
scaled_runs.append(runs[0].clone().rescore(method=RescoreMethod.SCALE, scale=alpha))
scaled_runs.append(runs[1].clone().rescore(method=RescoreMethod.SCALE, scale=(1-alpha)))
return TrecRun.merge(scaled_runs, AggregationMethod.SUM, depth=depth, k=k)
# Command-line entry point: fuse several TREC run files into one run and write it to --output.
parser = argparse.ArgumentParser(description='Perform various ways of fusion given a list of trec run files.')
parser.add_argument('--runs', type=str, nargs='+', default=[], required=True,
                    help='List of run files separated by space.')
parser.add_argument('--output', type=str, required=True,
                    help="Path to resulting fused txt.")
parser.add_argument('--runtag', type=str, default="pyserini.fusion",
                    help="Tag name of fused run.")
parser.add_argument('--method', type=FusionMethod, default=FusionMethod.RRF,
                    help="The fusion method to be used.")
# Spelled ``--rrf.k`` on the command line; stored as ``args.rrf_k`` via ``dest``.
parser.add_argument('--rrf.k', dest='rrf_k', type=int, default=60,
                    help="Parameter k needed for reciprocal rank fusion.")
parser.add_argument('--alpha', type=float, default=0.5, required=False,
                    help='Alpha value used for interpolation.')
parser.add_argument('--depth', type=int, default=1000, required=False,
                    help='Pool depth per topic.')
parser.add_argument('--k', type=int, default=1000, required=False,
                    help='Number of documents to output per topic.')
args = parser.parse_args()

# All input runs are loaded before the fusion method is looked up.
trec_runs = [TrecRun(filepath=path) for path in args.runs]

# Map each supported method to a callable that performs the fusion only when selected.
fuse = {
    FusionMethod.RRF: lambda runs: reciprocal_rank_fusion(runs, rrf_k=args.rrf_k, depth=args.depth, k=args.k),
    FusionMethod.INTERPOLATION: lambda runs: interpolation(runs, alpha=args.alpha, depth=args.depth, k=args.k),
    FusionMethod.AVERAGE: lambda runs: average(runs, depth=args.depth, k=args.k),
}.get(args.method)
if fuse is None:
    raise NotImplementedError(f'Fusion method {args.method} not implemented.')

fused_run = fuse(trec_runs)
fused_run.save_to_txt(args.output, tag=args.runtag)
def search(self, q: Union[str, JQuery], k: int = 10, query_generator: JQueryGenerator = None, strip_segment_id=False, remove_dups=False) -> List[JSimpleSearcherResult]:
    """Search every underlying index and fuse the per-index rankings.

    Parameters
    ----------
    q : Union[str, JQuery]
        Query, forwarded unchanged to each underlying searcher.
    k : int
        Number of hits requested from each underlying searcher. This is also the maximum number of fused hits
        returned, additionally capped at 1000 (the fusion depth used below).
    query_generator : JQueryGenerator
        Forwarded to each underlying searcher.
    strip_segment_id : bool
        Forwarded to each underlying searcher.
    remove_dups : bool
        Forwarded to each underlying searcher.

    Returns
    -------
    List[JSimpleSearcherResult]
        Fused hits, built by ``SimpleFusionSearcher.convert_to_search_result``.

    Raises
    ------
    NotImplementedError
        If ``self.method`` is anything other than ``FusionMethod.RRF``.
    """
    # Validate the fusion method before issuing any (potentially expensive) searches.
    if self.method != FusionMethod.RRF:
        raise NotImplementedError(f'Fusion method {self.method} not implemented.')

    trec_runs, docid_to_search_result = [], {}
    for searcher in self.searchers:
        docid_score_pair = []
        hits = searcher.search(q, k=k, query_generator=query_generator,
                               strip_segment_id=strip_segment_id, remove_dups=remove_dups)
        for hit in hits:
            # Keep the hit so fused docids can be mapped back to full results; when several indexes return
            # the same docid, the hit from the last searcher wins.
            docid_to_search_result[hit.docid] = hit
            docid_score_pair.append((hit.docid, hit.score))
        trec_runs.append(TrecRun.from_search_results(docid_score_pair))

    # RRF settings are unchanged (rrf_k=60, depth=1000), but the fused list is also cut to the caller's k:
    # the union of several per-index top-k lists can exceed k, which previously leaked through.
    fused_run = reciprocal_rank_fusion(trec_runs, rrf_k=60, depth=1000, k=min(k, 1000))
    return SimpleFusionSearcher.convert_to_search_result(fused_run, docid_to_search_result)