Secure your code as it's written. Use Snyk Code to scan source code in minutes - no build needed - and fix issues immediately.
# Inline Stan program: Bernoulli likelihood for y with a beta(1,1) prior on
# theta. Kept verbatim; the text is compiled/parsed by Stan, not Python.
CODE = """data {
int N;
int y[N];
}
parameters {
real theta;
}
model {
theta ~ beta(1,1); // uniform prior on interval 0,1
y ~ bernoulli(theta);
}
"""
# Locations of the Bernoulli example model source and its compiled
# executable inside the test data directory. EXTENSION is the platform's
# executable suffix (defined elsewhere in this module -- presumably '.exe'
# on Windows and '' elsewhere; confirm).
BERN_STAN, BERN_EXE = (
    os.path.join(DATAFILES_PATH, basename)
    for basename in ('bernoulli.stan', 'bernoulli' + EXTENSION)
)
class CmdStanModelTest(unittest.TestCase):
# pylint: disable=no-self-use
@pytest.fixture(scope='class', autouse=True)
def do_clean_up(self):
    """Delete stale build artifacts under DATAFILES_PATH before the class runs.

    Walks DATAFILES_PATH recursively and removes every file whose
    extension (compared case-insensitively) is ``.o``, ``.d``, ``.hpp`` or
    ``.exe``, plus files with no extension at all (presumably compiled
    model executables on non-Windows platforms -- confirm).

    Hidden files such as ``.gitignore`` are preserved: ``os.path.splitext``
    treats a leading dot as part of the name and reports an empty
    extension for them, so without the explicit guard below they would be
    deleted along with the suffix-less executables.
    """
    artifact_exts = ('.o', '.d', '.hpp', '.exe', '')
    for root, _, files in os.walk(DATAFILES_PATH):
        for filename in files:
            _, ext = os.path.splitext(filename)
            if not ext and filename.startswith('.'):
                # Dotfile (e.g. .gitignore), not a build product -- keep it.
                continue
            if ext.lower() in artifact_exts:
                os.remove(os.path.join(root, filename))
def show_cmdstan_version(self):
    """Print the value of ``cmdstan_path()`` framed by blank lines.

    This is not collected as a test because it lacks the ``test_`` prefix.
    NOTE(review): the label says "version" but the printed value is the
    CmdStan installation path. Presumably the directory name carries the
    version; confirm or relabel.
    """
    banner = '\n\nCmdStan version: {}\n\n'
    print(banner.format(cmdstan_path()))
def test_model_includes_implicit(self):
    """Build bernoulli_include.stan from scratch and check the executable path.

    Any existing executable is deleted first so CmdStanModel has to compile
    the model again. NOTE(review): judging by the name, the model presumably
    pulls in an include file from its own directory; confirm.
    """
    stan_path = os.path.join(DATAFILES_PATH, 'bernoulli_include.stan')
    exe_path = os.path.join(DATAFILES_PATH, 'bernoulli_include' + EXTENSION)
    # Force a fresh build instead of reusing a stale binary.
    if os.path.exists(exe_path):
        os.remove(exe_path)
    included_model = CmdStanModel(stan_file=stan_path)
    # exe_file is compared using forward slashes, so normalise Windows
    # separators in the expected suffix.
    expected_suffix = exe_path.replace('\\', '/')
    self.assertTrue(included_model.exe_file.endswith(expected_suffix))
def test_windows_short_path_file(self):
    """windows_short_path() removes spaces from a file path but keeps its extension.

    The input path lives in a directory whose name contains a space. The
    test checks that the shortened path differs from the original, contains
    no space, ends in the same ``.csv`` extension, and that its parent
    directory exists.

    Only meaningful on Windows. On other platforms the test is reported as
    skipped; previously a bare ``return`` made it report a pass without
    running. Assertions use ``self.assert*`` so they are not stripped under
    ``python -O`` and match the rest of this class.
    """
    if platform.system() != 'Windows':
        self.skipTest('windows_short_path is only exercised on Windows')

    original_path = os.path.join(_TMPDIR, 'new path', 'my_file.csv')
    original_dir = os.path.dirname(original_path)
    os.makedirs(original_dir, exist_ok=True)
    # Sanity-check the fixture: the path must actually need shortening.
    self.assertTrue(os.path.exists(original_dir))
    self.assertIn(' ', original_path)
    self.assertEqual(os.path.splitext(original_path)[1], '.csv')

    short_path = windows_short_path(original_path)
    self.assertTrue(os.path.exists(os.path.dirname(short_path)))
    self.assertNotEqual(original_path, short_path)
    self.assertNotIn(' ', short_path)
    self.assertEqual(os.path.splitext(short_path)[1], '.csv')
def test_sample_plus_quantities_dedup(self):
    """The combined draws have exactly as many columns as the MCMC sample.

    The model is sampled with 4 chains, and then generate_quantities is run
    on that fit. The width of ``sample_plus_quantities`` must equal the
    width of ``mcmc_sample``, i.e. no columns are duplicated when the two are
    merged. NOTE(review): this presumably relies on bernoulli_ppc's
    generated quantities already appearing in the MCMC output; confirm.
    """
    data_file = os.path.join(DATAFILES_PATH, 'bernoulli.data.json')
    ppc_model = CmdStanModel(
        stan_file=os.path.join(DATAFILES_PATH, 'bernoulli_ppc.stan')
    )
    mcmc_fit = ppc_model.sample(
        data=data_file,
        chains=4,
        cores=2,
        seed=12345,
        iter_sampling=100,
    )
    gq_fit = ppc_model.generate_quantities(data=data_file, mcmc_sample=mcmc_fit)
    combined_width = gq_fit.sample_plus_quantities.shape[1]
    mcmc_width = gq_fit.mcmc_sample.shape[1]
    self.assertEqual(combined_width, mcmc_width)
def test_save_warmup_thin(self):
    """Thinning is applied to saved warmup draws as well as sampling draws.

    Both phases are thinned by 5 across 2 chains. 200 warmup iterations
    yield 40 warmup draws and 100 sampling iterations yield 20 draws. The
    draw arrays are shaped (draws, 2, number of BERNOULLI_COLS).
    """
    n_cols = len(BERNOULLI_COLS)
    bern_model = CmdStanModel(
        stan_file=os.path.join(DATAFILES_PATH, 'bernoulli.stan')
    )
    fit = bern_model.sample(
        data=os.path.join(DATAFILES_PATH, 'bernoulli.data.json'),
        chains=2,
        seed=12345,
        iter_warmup=200,
        iter_sampling=100,
        thin=5,
        save_warmup=True,
    )
    self.assertEqual(fit.column_names, tuple(BERNOULLI_COLS))
    # Warmup: 200 iterations / thin 5.
    self.assertEqual(fit.num_draws_warmup, 40)
    self.assertEqual(fit.warmup.shape, (40, 2, n_cols))
    # Sampling: 100 iterations / thin 5.
    self.assertEqual(fit.num_draws, 20)
    self.assertEqual(fit.sample.shape, (20, 2, n_cols))
def test_bernoulli_good(self, stanfile='bernoulli.stan'):
    """Sample a model and check its repr, run method and per-chain output files.

    Args:
        stanfile: model file name relative to DATAFILES_PATH. The repr check
            expects the model name ``bernoulli``, so other values only make
            sense for models with that name.
    """
    model = CmdStanModel(stan_file=os.path.join(DATAFILES_PATH, stanfile))
    fit = model.sample(
        data=os.path.join(DATAFILES_PATH, 'bernoulli.data.json'),
        chains=2,
        cores=2,
        seed=12345,
        iter_sampling=100,
    )
    self.assertIn('CmdStanMCMC: model=bernoulli', fit.__repr__())
    self.assertIn('method=sample', fit.__repr__())
    self.assertEqual(fit.runset._args.method, Method.SAMPLE)
    # Every chain must have left both a CSV output file and a stdout log.
    for chain_idx in range(fit.runset.chains):
        chain_outputs = (
            fit.runset.csv_files[chain_idx],
            fit.runset.stdout_files[chain_idx],
        )
        for output_path in chain_outputs:
            self.assertTrue(os.path.exists(output_path))
def test_custom_metric(self):
    """Smoke test: sampling with a metric read from a JSON file runs without error.

    No results are inspected. The test passes as long as ``sample``
    does not raise.
    """
    data_dir = DATAFILES_PATH
    bern_model = CmdStanModel(stan_file=os.path.join(data_dir, 'bernoulli.stan'))
    sample_kwargs = dict(
        data=os.path.join(data_dir, 'bernoulli.data.json'),
        chains=2,
        cores=2,
        seed=12345,
        iter_sampling=200,
        metric=os.path.join(data_dir, 'bernoulli.metric.json'),
    )
    bern_model.sample(**sample_kwargs)
def test_read_progress(self):
    """_read_progress parses Stan progress lines silently and resizes the bar.

    A mocked process emits two "Iteration: k / 31000" lines. Afterwards
    nothing must have been logged, and the progress bar's total, which
    starts at 1, must equal the 31000 iterations reported.

    The unused ``result`` local was dropped. The tqdm bar is now closed in
    ``finally`` so a live bar at ``position=1`` does not leak into the
    output of later tests.
    """
    model = CmdStanModel(stan_file=BERN_STAN, compile=False)
    proc_mock = Mock()
    # Mimics Popen.poll(): None twice (still running), then a non-None value.
    proc_mock.poll.side_effect = [None, None, 'finish']
    stan_output1 = 'Iteration: 12100 / 31000 [ 39%] (Warmup)'
    stan_output2 = 'Iteration: 14000 / 31000 [ 45%] (Warmup)'
    # stdout yields bytes, as a real subprocess pipe would.
    proc_mock.stdout.readline.side_effect = [
        stan_output1.encode('utf-8'),
        stan_output2.encode('utf-8'),
    ]
    pbar = tqdm.tqdm(desc='Chain 1 - warmup', position=1, total=1)
    try:
        with LogCapture() as log:
            model._read_progress(proc=proc_mock, pbar=pbar, idx=0)
        self.assertEqual([], log.actual())
        self.assertEqual(31000, pbar.total)
    finally:
        pbar.close()
def test_set_mle_attrs(self):
    """A CmdStanMLE built from a RunSet that was never executed starts empty.

    Its repr names the model (rosenbrock) and the method (optimize). It has
    no column names and an empty estimates dict.
    """
    rosenbrock = CmdStanModel(
        stan_file=os.path.join(DATAFILES_PATH, 'optimize', 'rosenbrock.stan')
    )
    optimize_args = OptimizeArgs(algorithm='Newton')
    cmdstan_args = CmdStanArgs(
        model_name=rosenbrock.name,
        model_exe=rosenbrock.exe_file,
        chain_ids=None,
        data={},
        method_args=optimize_args,
    )
    mle = CmdStanMLE(RunSet(args=cmdstan_args, chains=1))
    self.assertIn('CmdStanMLE: model=rosenbrock', mle.__repr__())
    self.assertIn('method=optimize', mle.__repr__())
    # Nothing has been read yet, so both attributes are still empty.
    self.assertEqual(mle._column_names, ())
    self.assertEqual(mle._mle, {})
exe_file=os.path.join('..', 'bernoulli' + EXTENSION),
)
self.assertEqual(model1.stan_file, dotdot_stan)
self.assertEqual(model1.exe_file, dotdot_exe)
os.remove(dotdot_stan)
os.remove(dotdot_exe)
tilde_stan = os.path.realpath(
os.path.join(os.path.expanduser('~'), 'bernoulli.stan')
)
tilde_exe = os.path.realpath(
os.path.join(os.path.expanduser('~'), 'bernoulli' + EXTENSION)
)
shutil.copyfile(BERN_STAN, tilde_stan)
shutil.copyfile(BERN_EXE, tilde_exe)
model2 = CmdStanModel(
stan_file=os.path.join('~', 'bernoulli.stan'),
exe_file=os.path.join('~', 'bernoulli' + EXTENSION),
)
self.assertEqual(model2.stan_file, tilde_stan)
self.assertEqual(model2.exe_file, tilde_exe)
os.remove(tilde_stan)
os.remove(tilde_exe)