def verify_round_trip(self, genotypes, exclude_sites):
    self.assertEqual(genotypes.shape[0], exclude_sites.shape[0])
    with tsinfer.SampleData() as sample_data:
        for j in range(genotypes.shape[0]):
            sample_data.add_site(j, genotypes[j])
    exclude_positions = sample_data.sites_position[:][exclude_sites]
    for simplify in [False, True]:
        output_ts = tsinfer.infer(
            sample_data, simplify=simplify, exclude_positions=exclude_positions
        )
        for tree in output_ts.trees():
            for site in tree.sites():
                inf_type = json.loads(site.metadata)["inference_type"]
                if exclude_sites[site.id]:
                    self.assertEqual(inf_type, tsinfer.INFERENCE_FITCH_PARSIMONY)
                else:
                    self.assertEqual(inf_type, tsinfer.INFERENCE_FULL)
                f = np.sum(genotypes[site.id])
                if f == 0:
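# A minimal sketch of how verify_round_trip might be driven. The test method
# name and the genotype/exclusion arrays below are illustrative assumptions,
# not part of the original test suite.
def test_exclude_middle_site(self):
    genotypes = np.array(
        [[1, 1, 0, 0], [0, 1, 1, 0], [0, 0, 1, 1]], dtype=np.int8
    )
    exclude_sites = np.array([False, True, False])
    self.verify_round_trip(genotypes, exclude_sites)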
def test_zero_sequence_length(self):
    # Mangle a sample data file to force a zero sequence length.
    ts = msprime.simulate(10, mutation_rate=2, random_seed=5)
    with tempfile.TemporaryDirectory(prefix="tsinf_format_test") as tempdir:
        filename = os.path.join(tempdir, "samples.tmp")
        with tsinfer.SampleData(path=filename) as sample_data:
            for var in ts.variants():
                sample_data.add_site(var.site.position, var.genotypes)
        # Reopen the underlying zarr store and mangle the stored attribute.
        store = zarr.LMDBStore(filename, subdir=False)
        # "a" opens the existing group read/write without truncating it.
        data = zarr.open(store=store, mode="a")
data.attrs["sequence_length"] = 0
store.close()
sample_data = tsinfer.load(filename)
self.assertEqual(sample_data.sequence_length, 0)
self.assertRaises(ValueError, tsinfer.generate_ancestors, sample_data)
def test_match_ancestors_samples(self):
    with tsinfer.SampleData(sequence_length=2) as sample_data:
        sample_data.add_site(1, genotypes=[0, 1, 1, 0], alleles=["G", "C"])
    ancestor_data = tsinfer.generate_ancestors(sample_data)
    # match_ancestors fails when the sample data is unfinalised.
    unfinalised = tsinfer.SampleData(sequence_length=2)
    unfinalised.add_site(1, genotypes=[0, 1, 1, 0], alleles=["G", "C"])
    self.assertRaises(
        ValueError, tsinfer.match_ancestors, unfinalised, ancestor_data
    )
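# For contrast, a sketch of the same call succeeding once the sample data is
# finalised (hypothetical test method; uses only the standard tsinfer API).
def test_match_ancestors_finalised_samples(self):
    with tsinfer.SampleData(sequence_length=2) as sample_data:
        sample_data.add_site(1, genotypes=[0, 1, 1, 0], alleles=["G", "C"])
    ancestor_data = tsinfer.generate_ancestors(sample_data)
    ancestors_ts = tsinfer.match_ancestors(sample_data, ancestor_data)
    self.assertGreater(ancestors_ts.num_nodes, 0)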
def test_infer(self):
    ts = msprime.simulate(10, mutation_rate=1, random_seed=1)
    self.assertGreater(ts.num_sites, 1)
    samples = tsinfer.SampleData.from_tree_sequence(ts)
    inferred_ts = tsinfer.infer(samples)
    self.validate_ts(inferred_ts)
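# validate_ts is not shown in this excerpt; a plausible sketch (hypothetical
# helper, checking only generic tskit invariants) might look like this:
def validate_ts(self, inferred_ts):
    # Every inferred tree should be fully rooted and sites should be retained.
    self.assertGreater(inferred_ts.num_sites, 0)
    for tree in inferred_ts.trees():
        self.assertEqual(tree.num_roots, 1)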
def test_large_random_data(self):
    n = 100
    m = 30
    G, positions = get_random_data_example(n, m)
    with tsinfer.SampleData(sequence_length=m) as sample_data:
        for genotypes, position in zip(G, positions):
            sample_data.add_site(position, genotypes)
    self.verify(sample_data)
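# get_random_data_example is not defined in this excerpt; a plausible sketch
# (hypothetical helper, assuming binary genotypes and evenly spaced integer
# positions) is:
def get_random_data_example(num_samples, num_sites, seed=42):
    rng = np.random.RandomState(seed)
    # One row per site, one column per sample.
    G = rng.randint(2, size=(num_sites, num_samples)).astype(np.int8)
    positions = np.arange(num_sites)
    return G, positions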
def generate_samples(ts, error_param=0):
    """
    Generate a samples file from a simulated ts, adding genotyping error
    according to error_param: either a numeric error rate or the path to an
    empirically estimated error-matrix CSV file.
    Reject any variants that result in a fixed column.
    """
    assert ts.num_sites != 0
    sd = tsinfer.SampleData(sequence_length=ts.sequence_length)
    try:
        e = float(error_param)
        for v in ts.variants():
            g = v.genotypes if error_param == 0 else make_errors(v.genotypes, e)
            sd.add_site(position=v.site.position, alleles=v.alleles, genotypes=g)
    except ValueError:
        # error_param is not a number, so treat it as the path to an error matrix file.
        error_matrix = pd.read_csv(error_param)
        for v in ts.variants():
            # First record the allele frequency.
            m = v.genotypes.shape[0]
            frequency = np.sum(v.genotypes) / m
            # Find the closest row in the error matrix file.
            closest_row = (error_matrix["freq"] - frequency).abs().argsort()[:1]
            closest_freq = error_matrix.iloc[closest_row]
            g = make_errors_genotype_model(v.genotypes, closest_freq)
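# make_errors is also not defined in this excerpt; a plausible sketch
# (hypothetical) that flips each binary genotype call with probability e.
# Per the docstring above, the surrounding pipeline also rejects variants
# that end up as a fixed column; that check is omitted here for brevity.
def make_errors(genotypes, e):
    flip = np.random.random(genotypes.shape) < e
    return np.where(flip, 1 - genotypes, genotypes)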
def convert(
    vcf_file, pedigree_file, output_file, max_variants=None, show_progress=False
):
    if max_variants is None:
        max_variants = 2**32  # Arbitrary, but > the defined maximum for VCF.
    with tsinfer.SampleData(path=output_file, num_flush_threads=2) as sample_data:
        pop_id_map = add_populations(sample_data)

        vcf = cyvcf2.VCF(vcf_file)
        individual_names = list(vcf.samples)
        vcf.close()

        with open(pedigree_file, "r") as ped_file:
            add_samples(ped_file, pop_id_map, individual_names, sample_data)

        for index, site in enumerate(variants(vcf_file, show_progress)):
            sample_data.add_site(
                position=site.position,
                genotypes=site.genotypes,
                alleles=site.alleles,
                metadata=site.metadata,
            )
            if index == max_variants:
                break
        sample_data.record_provenance(command=sys.argv[0], args=sys.argv[1:])
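# A minimal sketch of a command-line wrapper around convert (hypothetical;
# the argparse flag names are assumptions, not part of the original script).
import argparse


def main():
    parser = argparse.ArgumentParser(
        description="Convert a VCF and pedigree file into a tsinfer samples file."
    )
    parser.add_argument("vcf_file")
    parser.add_argument("pedigree_file")
    parser.add_argument("output_file")
    parser.add_argument("--max-variants", type=int, default=None)
    parser.add_argument("--progress", action="store_true")
    args = parser.parse_args()
    convert(
        args.vcf_file,
        args.pedigree_file,
        args.output_file,
        max_variants=args.max_variants,
        show_progress=args.progress,
    )


if __name__ == "__main__":
    main()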