Secure your code as it's written. Use Snyk Code to scan source code in minutes - no build needed - and fix issues immediately.
def get_augmented_ancestors_tree_sequence(self, sample_indexes):
"""
Return the ancestors tree sequence augmented with samples as extra ancestors.
"""
logger.debug("Building augmented ancestors tree sequence")
tsb = self.tree_sequence_builder
tables = self.ancestors_ts_tables.copy()
flags, times = tsb.dump_nodes()
s = 0
num_pc_ancestors = 0
for j in range(len(tables.nodes), len(flags)):
if times[j] == 0.0:
# This is an augmented ancestor node.
tables.nodes.add_row(
flags=constants.NODE_IS_SAMPLE_ANCESTOR,
time=times[j],
metadata=self.encode_metadata(
{"sample_data_id": int(sample_indexes[s])}
),
)
s += 1
else:
# This is a path compressed node
tables.nodes.add_row(flags=flags[j], time=times[j])
assert is_pc_ancestor(flags[j])
num_pc_ancestors += 1
assert s == len(sample_indexes)
assert len(tables.nodes) == len(flags)
# Increment the time for all nodes so the augmented samples are no longer
# at timepoint 0.
def update_site_metadata(current_metadata, inference_type):
    """
    Return a new metadata dictionary tagged with the given inference type.

    The result starts with an ``"inference_type"`` entry holding
    ``inference_type`` and is followed by every entry of
    ``current_metadata``. ``current_metadata`` itself is not modified.
    Because its entries are merged last, an ``"inference_type"`` key that is
    already present in ``current_metadata`` takes precedence over the
    ``inference_type`` argument.

    ``inference_type`` must be either ``constants.INFERENCE_FULL`` or
    ``constants.INFERENCE_FITCH_PARSIMONY``; this is checked with an
    assertion.
    """
    permitted_types = {
        constants.INFERENCE_FULL,
        constants.INFERENCE_FITCH_PARSIMONY,
    }
    assert inference_type in permitted_types
    tagged = {"inference_type": inference_type}
    # Existing entries are unpacked last so they win on key collisions.
    return {**tagged, **current_metadata}
def __init__(
self,
sample_data,
inference_site_position,
num_threads=1,
path_compression=True,
recombination_rate=None,
mismatch_rate=None,
precision=None,
extended_checks=False,
engine=constants.C_ENGINE,
progress_monitor=None,
):
self.sample_data = sample_data
self.num_threads = num_threads
self.path_compression = path_compression
self.num_samples = self.sample_data.num_samples
self.num_sites = len(inference_site_position)
self.progress_monitor = _get_progress_monitor(progress_monitor)
self.match_progress = None # Allocated by subclass
self.extended_checks = extended_checks
# Map of site index to tree sequence position. Bracketing
# values of 0 and L are used for simplicity.
self.position_map = np.hstack(
[inference_site_position, [sample_data.sequence_length]]
)
self.position_map[0] = 0
def infer(
sample_data,
*,
num_threads=0,
path_compression=True,
simplify=True,
recombination_rate=None,
mismatch_rate=None,
precision=None,
exclude_positions=None,
engine=constants.C_ENGINE,
progress_monitor=None,
):
"""
infer(sample_data, *, num_threads=0, path_compression=True, simplify=True,\
exclude_positions=None)
Runs the full :ref:`inference pipeline ` on the specified
:class:`SampleData` instance and returns the inferred
:class:`tskit.TreeSequence`.
:param SampleData sample_data: The input :class:`SampleData` instance
representing the observed data that we wish to make inferences from.
:param int num_threads: The number of worker threads to use in parallelised
sections of the algorithm. If <= 0, do not spawn any threads and
use simpler sequential algorithms (default).
:param bool path_compression: Whether to merge edges that share identical
engine=constants.C_ENGINE,
progress_monitor=None,
):
self.sample_data = sample_data
self.ancestor_data = ancestor_data
self.progress_monitor = progress_monitor
self.max_sites = sample_data.num_sites
self.num_sites = 0
self.num_samples = sample_data.num_samples
self.num_threads = num_threads
if engine == constants.C_ENGINE:
logger.debug("Using C AncestorBuilder implementation")
self.ancestor_builder = _tsinfer.AncestorBuilder(
self.num_samples, self.max_sites
)
elif engine == constants.PY_ENGINE:
logger.debug("Using Python AncestorBuilder implementation")
self.ancestor_builder = algorithm.AncestorBuilder(
self.num_samples, self.max_sites
)
else:
raise ValueError("Unknown engine:{}".format(engine))
logger.info("Starting addition of {} sites".format(self.max_sites))
progress = self.progress_monitor.get("ga_add_sites", self.max_sites)
inference_site_id = []
for variant in self.sample_data.variants():
# If there's missing data the last allele is None
num_alleles = len(variant.alleles) - int(variant.alleles[-1] is None)
counts = allele_counts(variant.genotypes)
use_site = False
if variant.site.position not in exclude_positions:
if num_alleles == 2:
if counts.derived > 1 and counts.derived < counts.known:
use_site = True
if use_site:
time = variant.site.time
if time == constants.TIME_UNSPECIFIED:
# Non-variable sites have no obvious freq-as-time values
assert counts.known != counts.derived
assert counts.known != counts.ancestral
assert counts.known > 0
# Time = freq of *all* derived alleles. Note that if n_alleles > 2 this
# may not be sensible: https://github.com/tskit-dev/tsinfer/issues/228
time = counts.derived / counts.known
self.ancestor_builder.add_site(time, variant.genotypes)
inference_site_id.append(variant.site.id)
self.num_sites += 1
progress.update()
progress.close()
self.ancestor_data.set_inference_sites(inference_site_id)
logger.info("Finished adding sites")
raise ValueError("Non-missing values for genotypes must be < num alleles")
if position < 0:
raise ValueError("Site position must be > 0")
if self.sequence_length > 0 and position >= self.sequence_length:
raise ValueError("Site position must be less than the sequence length")
if position <= self._last_position:
raise ValueError(
"Site positions must be unique and added in increasing order"
)
if inference is not None:
raise ValueError(
"Inference sites no longer be stored in the sample data file. "
"Please use the exclude_sites option to generate_ancestors."
)
if time is None:
time = constants.TIME_UNSPECIFIED
site_id = self._sites_writer.add(
position=position,
genotypes=genotypes,
metadata=self._check_metadata(metadata),
alleles=alleles,
time=time,
)
self._last_position = position
return site_id
def is_pc_ancestor(flags):
    """
    Test whether ``flags`` has the path compression ancestor bit
    (``constants.NODE_IS_PC_ANCESTOR``) set.

    The flags value is masked with the constant and the result is compared
    against zero, so the return value is truthy exactly when the masked
    value is non-zero.
    """
    pc_bits = flags & constants.NODE_IS_PC_ANCESTOR
    return pc_bits != 0