Secure your code as it's written. Use Snyk Code to scan source code in minutes - no build needed - and fix issues immediately.
def test_term_init(diabetes_data):
    """Verify the default attributes of a Term built directly from one column.

    ``diabetes_data`` is supplied from outside this block (presumably a
    pytest fixture -- confirm); its ``'BMI'`` column is passed as the term's
    data.
    """
    owner = Model(diabetes_data)
    bmi_term = Term(owner, 'BMI', diabetes_data['BMI'])

    # No optional arguments were given, so every check below is against a
    # default chosen by Term itself.
    assert bmi_term.name == 'BMI'
    assert bmi_term.categorical == False
    assert bmi_term.type_ == 'fixed'
    assert bmi_term.levels is not None
    assert bmi_term.data.shape == (442, 1)
def test_term_split(diabetes_data):
    """Check the shape of the 'BMI' term's data when it is split by 'age_grp'.

    The term is added four times with different ``categorical``/``random``
    settings, calling ``reset()`` on the model between additions, and the
    resulting ``data`` is inspected after each addition.
    """
    mdl = Model(diabetes_data)

    # Each entry: (extra keyword arguments for add_term, expected data shape).
    array_cases = [
        # Split a continuous fixed variable
        ({}, (442, 3)),
        # Split a categorical fixed variable
        ({'categorical': True}, (442, 489)),
        # Split a continuous random variable
        ({'categorical': False, 'random': True}, (442, 3)),
    ]
    for extra_kwargs, expected_shape in array_cases:
        mdl.add_term('BMI', split_by='age_grp', **extra_kwargs)
        assert mdl.terms['BMI'].data.shape == expected_shape
        # Reset before the next addition of 'BMI'.
        mdl.reset()

    # Split a categorical random variable: in this case ``data`` is a dict
    # whose entries are checked individually rather than a single array.
    mdl.add_term('BMI', split_by='age_grp', categorical=True, random=True)
    split_data = mdl.terms['BMI'].data
    assert isinstance(split_data, dict)
    assert split_data['age_grp[0]'].shape == (442, 83)
def test_term_init(diabetes_data):
    """Construct a Term for the 'BMI' column and check its default attributes.

    NOTE(review): a function with this same name and body also appears
    earlier in this file. If both live in one module, the later definition
    replaces the earlier one and only a single copy is collected -- confirm
    whether the duplicate is intentional.
    """
    parent_model = Model(diabetes_data)
    bmi_column = diabetes_data['BMI']
    created = Term(parent_model, 'BMI', bmi_column)

    # Test that all defaults are properly initialized
    assert created.name == 'BMI'
    assert created.categorical == False
    assert created.type_ == 'fixed'
    assert created.levels is not None
    assert created.data.shape == (442, 1)
def test_prior_retrieval():
    """Exercise PriorFactory.get() with the ``dist``, ``family`` and ``term`` keywords.

    The factory is built from ``data/sample_priors.json`` located next to
    this test module.
    """
    priors_path = join(dirname(__file__), 'data', 'sample_priors.json')
    factory = PriorFactory(priors_path)

    # Lookup by distribution name yields a Prior.
    asiago = factory.get(dist='asiago')
    assert asiago.name == 'Asiago'
    assert isinstance(asiago, Prior)
    assert asiago.args['hardness'] == 10
    # 'holes' must be missing from the args: the subscript itself is what is
    # expected to raise KeyError, before the comparison is ever evaluated.
    with pytest.raises(KeyError):
        assert asiago.args['holes'] == 4

    # Lookup by family name yields a Family whose prior carries a nested
    # Prior under the 'backup' argument.
    hard_family = factory.get(family='hard')
    assert isinstance(hard_family, Family)
    assert hard_family.link == 'grate'
    fallback = hard_family.prior.args['backup']
    assert isinstance(fallback, Prior)
    assert fallback.args['flavor'] == 10000

    # Lookup by term name.
    swiss = factory.get(term='yellow')
    assert swiss.name == 'Swiss'
def test_prior_class():
    """Check that Prior stores its keyword arguments and that update() adds new ones."""
    whiz = Prior('CheeseWhiz', holes=0, taste=-10)

    # Keyword arguments given at construction are exposed through ``args``.
    assert whiz.name == 'CheeseWhiz'
    assert isinstance(whiz.args, dict)
    assert whiz.args['taste'] == -10

    # Only the newly introduced key is verified after the update; the
    # overwritten 'taste' value is not asserted here.
    whiz.update(taste=-100, return_to_store=1)
    assert whiz.args['return_to_store'] == 1
def test_add_term_to_model(base_model):
    """Add several terms to a model and check they are registered as Term objects.

    ``base_model`` is supplied from outside this function (requested by
    argument name).
    """
    # A term added with default options.
    base_model.add_term('BMI')
    bmi_entry = base_model.terms['BMI']
    assert isinstance(bmi_entry, Term)

    # Added only so that it is present on the model; nothing is asserted
    # about this term.
    base_model.add_term('age_grp', random=False, categorical=True)

    # Test that arguments are passed appropriately onto Term initializer
    base_model.add_term(
        'BP',
        random=True,
        split_by='age_grp',
        categorical=True,
    )
    bp_entry = base_model.terms['BP']
    assert isinstance(bp_entry, Term)
def test_model_init_and_intercept(diabetes_data):
    """Check Model construction with and without ``intercept=True``."""
    # With an intercept requested, 'Intercept' is the one and only term.
    with_intercept = Model(diabetes_data, intercept=True)
    assert hasattr(with_intercept, 'data')
    assert 'Intercept' in with_intercept.terms
    assert len(with_intercept.terms) == 1
    assert with_intercept.y is None
    assert hasattr(with_intercept, 'backend')

    # Without the keyword, the model starts with no terms at all.
    plain = Model(diabetes_data)
    assert 'Intercept' not in plain.terms
    assert not plain.terms
def test_model_init_and_intercept(diabetes_data):
    """Build a Model twice and compare the initial state with/without an intercept.

    NOTE(review): a function with this same name and body also appears
    earlier in this file. If both live in one module, the later definition
    replaces the earlier one -- confirm whether the duplicate is intentional.
    """
    intercept_model = Model(diabetes_data, intercept=True)
    assert hasattr(intercept_model, 'data')
    # Exactly one term exists and it is the intercept.
    assert 'Intercept' in intercept_model.terms
    assert len(intercept_model.terms) == 1
    # No response variable has been set at this point.
    assert intercept_model.y is None
    assert hasattr(intercept_model, 'backend')

    # Default construction: the terms container is empty.
    default_model = Model(diabetes_data)
    assert 'Intercept' not in default_model.terms
    assert not default_model.terms
@pytest.fixture
def base_model(diabetes_data):
    """Provide a Model built from ``diabetes_data`` with default options.

    Registered as a pytest fixture because ``test_add_term_to_model`` in this
    file requests ``base_model`` by argument name; without the decorator
    pytest cannot resolve that argument. The default (per-test) fixture scope
    is used so that every requesting test receives its own Model instance,
    since the requesting test adds terms to it.
    """
    return Model(diabetes_data)
def test_prior_factory_init_from_config():
    """Check that PriorFactory can be built from a config path or from parsed JSON.

    Both construction routes use ``data/sample_priors.json`` located next to
    this test module, and both must expose ``dists``, ``terms`` and
    ``families`` as dict attributes.
    """
    config_file = join(dirname(__file__), 'data', 'sample_priors.json')

    # Route 1: pass the path to the JSON file.
    pf = PriorFactory(config_file)
    for d in ['dists', 'terms', 'families']:
        assert hasattr(pf, d)
        assert isinstance(getattr(pf, d), dict)

    # Route 2: pass the already-parsed JSON content. The context manager
    # guarantees the file handle is closed even if parsing fails, instead of
    # leaving it to the garbage collector.
    with open(config_file, 'r') as config_fh:
        config_dict = json.load(config_fh)
    pf = PriorFactory(config_dict)
    for d in ['dists', 'terms', 'families']:
        assert hasattr(pf, d)
        assert isinstance(getattr(pf, d), dict)

    # Spot-check one known entry in each section (checked on the factory
    # built from the parsed content).
    assert 'feta' in pf.dists
    assert 'hard' in pf.families
    assert 'yellow' in pf.terms