Secure your code as it's written. Use Snyk Code to scan source code in minutes - no build needed - and fix issues immediately.
def diagnose_no_problems(self):
    """Check that ``diagnose`` reports no problems for a bernoulli fit.

    Samples the bernoulli example model (4 chains on 2 cores, 200 sampling
    iterations, seed 12345) and asserts that ``diagnose`` prints exactly
    the line 'No problems detected.' to stdout.

    NOTE(review): this method has no ``test_`` prefix, so unittest's default
    loader will not collect it; confirm whether that is intentional.
    """
    from contextlib import redirect_stdout

    stan = os.path.join(datafiles_path, 'bernoulli.stan')
    exe = os.path.join(datafiles_path, 'bernoulli')
    # Compile only when the executable is missing; assumes compile_model
    # writes it to ``exe`` -- TODO confirm.
    if not os.path.exists(exe):
        compile_model(stan)
    model = Model(stan, exe_file=exe)
    jdata = os.path.join(datafiles_path, 'bernoulli.data.json')
    post_sample = sample(
        model, chains=4, cores=2, seed=12345, sampling_iters=200, data=jdata
    )
    captured_output = io.StringIO()
    # redirect_stdout restores the previous sys.stdout even if diagnose()
    # raises, and restores the stream that was active before (e.g. a test
    # runner's capture) rather than sys.__stdout__.
    with redirect_stdout(captured_output):
        diagnose(post_sample)
    self.assertEqual(captured_output.getvalue(), 'No problems detected.\n')
def test_missing_input(self):
    """Sampling the bernoulli model without supplying its data must fail.

    ``sample`` is expected to raise an exception whose message matches
    'Error during sampling'.
    """
    stan = os.path.join(datafiles_path, 'bernoulli.stan')
    output = os.path.join(TMPDIR, 'test4-bernoulli-output')
    # Compile outside the assertRaisesRegex block so that only the
    # sample() call is expected to raise.
    model = compile_model(stan)
    with self.assertRaisesRegex(Exception, 'Error during sampling'):
        sample(model, csv_output_file=output)
def test_bernoulli_1(self):
    """Sample with non-default tuning and check each chain's output files.

    Uses an explicit ``csv_output_file`` together with ``max_treedepth=11``
    and ``adapt_delta=0.95``. Afterwards, every chain's CSV file and the
    ``.txt`` file sharing its base name must exist on disk.
    """
    stan = os.path.join(datafiles_path, 'bernoulli.stan')
    exe = os.path.join(datafiles_path, 'bernoulli')
    # Compile only when the executable is missing; assumes compile_model
    # writes it to ``exe`` -- TODO confirm.
    if not os.path.exists(exe):
        compile_model(stan)
    model = Model(stan, exe_file=exe)
    jdata = os.path.join(datafiles_path, 'bernoulli.data.json')
    # NOTE(review): unlike most tests here, output is written under
    # datafiles_path rather than TMPDIR.
    output = os.path.join(datafiles_path, 'test1-bernoulli-output')
    post_sample = sample(
        model,
        chains=4,
        cores=2,
        seed=12345,
        sampling_iters=100,
        data=jdata,
        csv_output_file=output,
        max_treedepth=11,
        adapt_delta=0.95,
    )
    # Verify the per-chain outputs actually exist. As in test_bernoulli_data,
    # a .txt file is expected next to each CSV when csv_output_file is given.
    for i in range(post_sample.chains):
        csv_file = post_sample.csv_files[i]
        txt_file = os.path.splitext(csv_file)[0] + '.txt'
        self.assertTrue(os.path.exists(csv_file))
        self.assertTrue(os.path.exists(txt_file))
def test_bernoulli_data(self):
    """Sample the bernoulli model from an in-memory data dictionary.

    Every chain must leave behind its CSV file and the ``.txt`` file that
    shares its base name.
    """
    observations = [0, 1, 0, 0, 0, 0, 0, 0, 0, 1]
    data_dict = {'N': 10, 'y': observations}
    stan = os.path.join(datafiles_path, 'bernoulli.stan')
    output = os.path.join(TMPDIR, 'test3-bernoulli-output')
    post_sample = sample(
        compile_model(stan), data=data_dict, csv_output_file=output
    )
    for chain in range(post_sample.chains):
        csv_path = post_sample.csv_files[chain]
        stem, _ext = os.path.splitext(csv_path)
        txt_path = stem + '.txt'
        self.assertTrue(os.path.exists(csv_path))
        self.assertTrue(os.path.exists(txt_path))
def test_postsample_good(self):
    """Check shape, summary and diagnostics of a default-settings posterior.

    With only the model and data supplied, the posterior is expected to have
    4 chains of 1000 draws over 8 columns (seven ``__``-suffixed sampler
    columns plus ``theta``), a 2x9 summary, and a diagnose() report of
    'No problems detected.'.
    """
    from contextlib import redirect_stdout

    column_names = (
        'lp__',
        'accept_stat__',
        'stepsize__',
        'treedepth__',
        'n_leapfrog__',
        'divergent__',
        'energy__',
        'theta',
    )
    stan = os.path.join(datafiles_path, 'bernoulli.stan')
    exe = os.path.join(datafiles_path, 'bernoulli')
    # Compile only when the executable is missing; assumes compile_model
    # writes it to ``exe`` -- TODO confirm.
    if not os.path.exists(exe):
        compile_model(stan)
    model = Model(stan, exe_file=exe)
    jdata = os.path.join(datafiles_path, 'bernoulli.data.json')
    # NOTE(review): other tests pass the data path as ``data=``; confirm that
    # ``data_file=`` is the keyword accepted by this version of sample().
    post_sample = sample(model, data_file=jdata)
    self.assertEqual(post_sample.chains, 4)
    self.assertEqual(post_sample.draws, 1000)
    self.assertEqual(post_sample.column_names, column_names)
    self.assertEqual(post_sample.sample.shape, (1000, 4, 8))
    df = post_sample.summary()
    # assertEqual (not assertTrue(a == b)) reports the actual shape on failure.
    self.assertEqual(df.shape, (2, 9))
    captured_output = io.StringIO()
    # redirect_stdout restores the previous sys.stdout even if diagnose()
    # raises, rather than forcing it back to sys.__stdout__.
    with redirect_stdout(captured_output):
        post_sample.diagnose()
    self.assertEqual(captured_output.getvalue(), 'No problems detected.\n')
def test_bernoulli_rdata(self):
    """Sample the bernoulli model using data from an R dump file.

    Every chain must leave behind its CSV file and the ``.txt`` file that
    shares its base name.
    """
    rdata = os.path.join(datafiles_path, 'bernoulli.data.R')
    stan = os.path.join(datafiles_path, 'bernoulli.stan')
    # Use a basename distinct from test_bernoulli_data's
    # 'test3-bernoulli-output'; otherwise files left by that test could
    # satisfy the existence checks below.
    output = os.path.join(TMPDIR, 'test3-bernoulli-rdata-output')
    model = compile_model(stan)
    post_sample = sample(model, data=rdata, csv_output_file=output)
    for i in range(post_sample.chains):
        csv_file = post_sample.csv_files[i]
        txt_file = os.path.splitext(csv_file)[0] + '.txt'
        self.assertTrue(os.path.exists(csv_file))
        self.assertTrue(os.path.exists(txt_file))
def test_bad(self):
    """Compiling ``bbad.stan`` from TMPDIR must raise an exception.

    NOTE(review): nothing in this method writes ``bbad.stan``. If setUp does
    not create it either, this exercises a missing-file error rather than an
    invalid program; confirm which case is intended.
    """
    stan = os.path.join(TMPDIR, 'bbad.stan')
    with self.assertRaises(Exception):
        compile_model(stan)
def test_bernoulli_2(self):
    """Sample without ``csv_output_file`` and check each chain's CSV exists.

    Because no output path is given, sample() chooses where the CSV files go
    (presumably temporary files -- TODO confirm). Only the CSV files are
    checked here.
    """
    stan = os.path.join(datafiles_path, 'bernoulli.stan')
    exe = os.path.join(datafiles_path, 'bernoulli')
    # Compile only when the executable is missing; assumes compile_model
    # writes it to ``exe`` -- TODO confirm.
    if not os.path.exists(exe):
        compile_model(stan)
    model = Model(stan, exe_file=exe)
    jdata = os.path.join(datafiles_path, 'bernoulli.data.json')
    post_sample = sample(
        model,
        chains=4,
        cores=2,
        seed=12345,
        sampling_iters=100,
        data=jdata,
        max_treedepth=11,
        adapt_delta=0.95,
    )
    for i in range(post_sample.chains):
        self.assertTrue(os.path.exists(post_sample.csv_files[i]))
def test_bernoulli(self):
    """Check that the assembled draw set has one row per draw per chain.

    After ``assemble_sample()``, ``get_drawset`` must return an object whose
    ``shape`` is ``(chains * draws, len(column_names))``.
    """
    stan = os.path.join(datafiles_path, 'bernoulli.stan')
    exe = os.path.join(datafiles_path, 'bernoulli')
    # Compile only when the executable is missing; assumes compile_model
    # writes it to ``exe`` -- TODO confirm.
    if not os.path.exists(exe):
        compile_model(stan)
    model = Model(stan, exe_file=exe)
    jdata = os.path.join(datafiles_path, 'bernoulli.data.json')
    post_sample = sample(
        model, chains=4, cores=2, seed=12345, sampling_iters=200, data=jdata
    )
    post_sample.assemble_sample()
    df = get_drawset(post_sample)
    self.assertEqual(
        df.shape,
        (
            post_sample.chains * post_sample.draws,
            len(post_sample.column_names),
        ),
    )
def test_include(self):
    """Compile ``bernoulli_include.stan`` with the test data dir on the include path.

    The resulting model must report the original ``.stan`` path and an
    executable path ending with the expected executable path.
    """
    stan = os.path.join(datafiles_path, 'bernoulli_include.stan')
    exe = os.path.join(datafiles_path, 'bernoulli_include')
    test_dir = os.path.dirname(os.path.abspath(__file__))
    include_paths = [os.path.join(test_dir, 'data')]
    # Delete any existing executable first, presumably so compile_model has
    # to rebuild it rather than reuse a stale one.
    if os.path.exists(exe):
        os.remove(exe)
    model = compile_model(stan, include_paths=include_paths)
    self.assertEqual(stan, model.stan_file)
    self.assertTrue(model.exe_file.endswith(exe))