# Secure your code as it's written. Use Snyk Code to scan source code in minutes — no build needed — and fix issues immediately.
def test_fasta_based_dataset(intervals_file, fasta_file):
    """Smoke-test ``StringSeqIntervalDl`` on the ``intervals_file``/``fasta_file`` fixtures.

    Checks that:

    * a sample from a default loader exposes its sequence under ``"inputs"``
      as a 0-d ``np.ndarray``;
    * ``label_dtype`` ("str", "int", "bool") controls the numpy scalar type of
      the first entry of ``"targets"``;
    * ``load_all()`` returns ``'GT'`` as the first input sequence (presumably
      determined by the fixture data — confirm against the fixtures).

    NOTE(review): another function with this exact name is defined later in
    this module. Python keeps only the last definition, so pytest never
    collects this one. Rename or delete one of the two copies.
    """
    # just test the functionality
    dl = StringSeqIntervalDl(intervals_file, fasta_file)
    ret_val = dl[0]
    assert isinstance(ret_val["inputs"], np.ndarray)
    # The sequence comes back as a scalar (0-d) array, not a 1-d array.
    assert ret_val["inputs"].shape == ()

    # Disabled check, kept for reference: a wrong required_seq_len was meant
    # to make indexing raise.
    # dl = StringSeqIntervalDl(intervals_file, fasta_file, required_seq_len=3)
    # with pytest.raises(Exception):
    #     dl[0]

    dl = StringSeqIntervalDl(intervals_file, fasta_file, label_dtype="str")
    ret_val = dl[0]
    assert isinstance(ret_val['targets'][0], np.str_)

    dl = StringSeqIntervalDl(intervals_file, fasta_file, label_dtype="int")
    ret_val = dl[0]
    assert isinstance(ret_val['targets'][0], np.int_)

    dl = StringSeqIntervalDl(intervals_file, fasta_file, label_dtype="bool")
    ret_val = dl[0]
    assert isinstance(ret_val['targets'][0], np.bool_)

    # Bulk loading goes through the "bool" loader built just above.
    vals = dl.load_all()
    assert vals['inputs'][0] == 'GT'
def test_fasta_based_dataset(intervals_file, fasta_file):
    """Exercise ``StringSeqIntervalDl`` end to end on the test fixtures.

    A default loader must yield a 0-d ``np.ndarray`` under ``"inputs"``.
    A loader built with each ``label_dtype`` must yield a first target of the
    matching numpy scalar type. Bulk loading through the last loader built
    (``label_dtype="bool"``) must return ``'GT'`` as the first input sequence.
    """
    # just test the functionality
    default_sample = StringSeqIntervalDl(intervals_file, fasta_file)[0]
    seq = default_sample["inputs"]
    assert isinstance(seq, np.ndarray)
    assert seq.shape == ()

    # Disabled check, kept for reference: a wrong required_seq_len was meant
    # to make indexing raise.
    # dl = StringSeqIntervalDl(intervals_file, fasta_file, required_seq_len=3)
    # with pytest.raises(Exception):
    #     dl[0]

    # label_dtype -> numpy scalar type expected for the first target.
    expected_target_types = (
        ("str", np.str_),
        ("int", np.int_),
        ("bool", np.bool_),
    )
    for label_dtype, scalar_type in expected_target_types:
        loader = StringSeqIntervalDl(intervals_file, fasta_file, label_dtype=label_dtype)
        first_target = loader[0]['targets'][0]
        assert isinstance(first_target, scalar_type)

    # `loader` is the last one created in the loop, i.e. label_dtype="bool".
    everything = loader.load_all()
    assert everything['inputs'][0] == 'GT'
def __init__(self,
             intervals_file,
             fasta_file,
             num_chr_fasta=False,
             label_dtype=None,
             auto_resize_len=None,
             # max_seq_len=None,
             # use_strand=False,
             alphabet_axis=1,
             dummy_axis=None,
             alphabet="ACGT",
             ignore_targets=False,
             dtype=None):
    """Create the underlying string loader and the one-hot input transform.

    The interval, FASTA and target options are forwarded to
    ``StringSeqIntervalDl``, and the result is stored as ``self.seq_dl``.
    The encoding options (``alphabet``, ``dtype``, ``alphabet_axis``,
    ``dummy_axis``) are forwarded to ``ReorderedOneHot``, and the result is
    stored as ``self.input_transform``.
    """
    # Core dataset: only the options about reading intervals, sequences and
    # targets; the one-hot encoding parameters are intentionally left out.
    string_loader_opts = {
        "num_chr_fasta": num_chr_fasta,
        "label_dtype": label_dtype,
        "auto_resize_len": auto_resize_len,
        # "use_strand": use_strand,
        "ignore_targets": ignore_targets,
    }
    self.seq_dl = StringSeqIntervalDl(intervals_file, fasta_file, **string_loader_opts)

    # Encoding step: only the one-hot related options.
    onehot_opts = {
        "alphabet": alphabet,
        "dtype": dtype,
        "alphabet_axis": alphabet_axis,
        "dummy_axis": dummy_axis,
    }
    self.input_transform = ReorderedOneHot(**onehot_opts)