Secure your code as it's written. Use Snyk Code to scan source code in minutes - no build needed - and fix issues immediately.
# NOTE(review): this run of statements looks like the body of a test whose
# `def` line is not visible here; `spec`, `source`, `tf` and `np` come from
# outside this span.

# Build the model from `spec`, then form the data by combining the source's
# bin data with the model's auxiliary data.
pdf = pyhf.Model(spec)
data = source['bindata']['data'] + pdf.config.auxdata

# One instance per tensor backend. The NumPy backend is listed first because
# the ratio computed at the end divides every result by element 0.
backends = [
    pyhf.tensor.numpy_backend(),
    pyhf.tensor.tensorflow_backend(session=tf.compat.v1.Session()),
    pyhf.tensor.pytorch_backend(),
]

test_statistic = []
for backend in backends:
    if backend.name == 'tensorflow':
        # Reset the default graph and give the backend a fresh session.
        # The compat.v1 entry point is used for both calls so they stay
        # consistent; the bare `tf.reset_default_graph` alias is not
        # available in TensorFlow 2.x.
        tf.compat.v1.reset_default_graph()
        backend.session = tf.compat.v1.Session()
    pyhf.set_backend(backend)

    # With return_test_statistics=True the result carries extra entries;
    # keep the first item of the last returned element.
    q_mu = pyhf.infer.hypotest(
        1.0,
        data,
        pdf,
        pdf.config.suggested_init(),
        pdf.config.suggested_bounds(),
        return_test_statistics=True,
    )[-1][0]
    test_statistic.append(pyhf.tensorlib.tolist(q_mu))

# compare to NumPy/SciPy
test_statistic = np.array(test_statistic)
# Ratio of every backend's value to the first (NumPy) entry ...
numpy_ratio = np.divide(test_statistic, test_statistic[0])
# ... and the absolute deviation of that ratio from 1.
numpy_ratio_delta_unity = np.absolute(np.subtract(numpy_ratio, 1))
def test_set_backend(backend):
    """Set the backend again using the name of the currently active one.

    The active tensor backend is read from ``pyhf.get_backend()`` and its
    ``name`` is handed back to ``pyhf.set_backend``.

    ``backend`` is not used in the body; presumably it is a pytest fixture
    that selects the active backend -- confirm against the conftest.
    """
    active_tensorlib, _optimizer = pyhf.get_backend()
    pyhf.set_backend(active_tensorlib.name)
def test_custom_backend_name_supported():
    """Expect ``AttributeError`` for a custom object named ``"pytorch"``.

    The stand-in class below is unrelated to pyhf's own backends; it only
    carries a ``name`` attribute equal to ``"pytorch"``.
    """

    class _ImpostorBackend(object):
        # Accepts and ignores arbitrary keyword arguments.
        def __init__(self, **_options):
            self.name = "pytorch"

    impostor = _ImpostorBackend()
    with pytest.raises(AttributeError):
        pyhf.set_backend(impostor)
def test_custom_backend_name_notsupported():
    """Install a custom object named ``"notsupported"`` as the backend.

    Before the call the active tensorlib's name must differ from the custom
    object's name; after ``pyhf.set_backend`` the two names must match.
    """

    class _UnlistedBackend(object):
        # Accepts and ignores arbitrary keyword arguments.
        def __init__(self, **_options):
            self.name = "notsupported"

    stand_in = _UnlistedBackend()

    # Precondition: the custom object is not already the active backend.
    assert stand_in.name != pyhf.tensorlib.name

    pyhf.set_backend(stand_in)

    # Postcondition: the active backend now reports the custom name.
    assert stand_in.name == pyhf.tensorlib.name
def test_tensorflow_tolist_nosession():
    """Exercise ``tolist`` on a TensorFlow backend built without a session.

    A plain Python list must pass through ``tolist`` unchanged, while
    converting a value produced by ``astensor`` is expected to raise
    ``RuntimeError``.
    """
    sessionless_backend = pyhf.tensor.tensorflow_backend()
    pyhf.set_backend(sessionless_backend)
    tensorlib = pyhf.tensorlib

    # test_list_to_list does not cover this case: here we specifically need
    # to know that a plain list is fine when no session was given.
    plain_values = [1, 2, 3, 4]
    assert tensorlib.tolist(plain_values) == [1, 2, 3, 4]

    # A tensor, on the other hand, should not convert.
    with pytest.raises(RuntimeError):
        assert tensorlib.tolist(tensorlib.astensor([1, 2, 3, 4])) == [1, 2, 3, 4]
def test_supported_backends():
    """Expect ``pyhf.set_backend`` to reject the unknown name ``"fail"``.

    The call is wrapped in ``pytest.raises`` so the test states its
    expectation instead of letting the exception escape as a test error.

    NOTE(review): another ``test_supported_backends`` taking a
    ``backend_name`` argument appears later in this file and expects
    ``pyhf.exceptions.InvalidBackend`` for invalid names; the same exception
    type is assumed here -- confirm. If both definitions live in one module
    the later one shadows this one.
    """
    with pytest.raises(pyhf.exceptions.InvalidBackend):
        pyhf.set_backend("fail")
def test_set_backend_by_string(backend_name):
    """Select a backend by its string name and check the resulting type.

    After ``pyhf.set_backend(backend_name)`` the active tensorlib must be an
    instance of the ``pyhf.tensor`` attribute named
    ``<lower-cased backend_name>_backend``.

    ``backend_name`` is presumably supplied by a parametrize decorator that
    is not visible here -- confirm.
    """
    pyhf.set_backend(backend_name)

    expected_attr = "{0:s}_backend".format(backend_name.lower())
    expected_cls = getattr(pyhf.tensor, expected_attr)
    assert isinstance(pyhf.tensorlib, expected_cls)
def test_set_backend_by_bytestring(backend_name):
    """Select a backend by a bytes name and check the resulting type.

    ``backend_name`` is passed to ``pyhf.set_backend`` as given, then decoded
    as UTF-8 to build the ``<name>_backend`` attribute looked up on
    ``pyhf.tensor``. The active tensorlib must be an instance of it.

    ``backend_name`` is presumably supplied by a parametrize decorator that
    is not visible here -- confirm.
    """
    pyhf.set_backend(backend_name)

    decoded_name = backend_name.decode("utf-8")
    expected_cls = getattr(pyhf.tensor, "{0:s}_backend".format(decoded_name))
    assert isinstance(pyhf.tensorlib, expected_cls)
def test_supported_backends(backend_name):
    """Expect ``InvalidBackend`` when ``backend_name`` is passed to pyhf.

    ``backend_name`` is presumably supplied by a parametrize decorator that
    is not visible here -- confirm.
    """
    expected_error = pyhf.exceptions.InvalidBackend
    with pytest.raises(expected_error):
        pyhf.set_backend(backend_name)