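# These snippets assume the cmdstanpy test context: CmdStanArgs, SamplerArgs,
# GenerateQuantitiesArgs, RunSet, CmdStanMCMC, and MaybeDictToFilePath come
# from cmdstanpy, and DATAFILES_PATH, EXTENSION, BERNOULLI_COLS, SAMPLER_STATE
# are test-module constants.
#
# Snippet: check that CmdStanMCMC.diagnose() reports transitions that hit the
# maximum treedepth, using a pre-generated CSV from diagnose-good/.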
exe = os.path.join(DATAFILES_PATH, 'bernoulli' + EXTENSION)
sampler_args = SamplerArgs()
cmdstan_args = CmdStanArgs(
model_name='bernoulli',
model_exe=exe,
chain_ids=[1],
output_dir=DATAFILES_PATH,
method_args=sampler_args,
)
runset = RunSet(args=cmdstan_args, chains=1)
runset._csv_files = [
os.path.join(
DATAFILES_PATH, 'diagnose-good', 'corr_gauss_depth8-1.csv'
)
]
fit = CmdStanMCMC(runset)
# TODO - use cmdstan test files instead
expected = '\n'.join(
[
'Checking sampler transitions treedepth.',
'424 of 1000 (42%) transitions hit the maximum '
'treedepth limit of 8, or 2^8 leapfrog steps.',
'Trajectories that are prematurely terminated '
'due to this limit will result in slow exploration.',
'For optimal performance, increase this limit.',
]
)
self.assertIn(expected, fit.diagnose().replace('\r\n', '\n'))
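# Snippet: build a CmdStanMCMC from an existing lotka-volterra CSV and check
# the per-variable shapes returned by stan_variables().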
exe = os.path.join(DATAFILES_PATH, 'lotka-volterra' + EXTENSION)
jdata = os.path.join(DATAFILES_PATH, 'lotka-volterra.data.json')
sampler_args = SamplerArgs(iter_sampling=20)
cmdstan_args = CmdStanArgs(
model_name='lotka-volterra',
model_exe=exe,
chain_ids=[1],
seed=12345,
data=jdata,
output_dir=DATAFILES_PATH,
method_args=sampler_args,
)
runset = RunSet(args=cmdstan_args, chains=1)
runset._csv_files = [os.path.join(DATAFILES_PATH, 'lotka-volterra.csv')]
runset._set_retcode(0, 0)
fit = CmdStanMCMC(runset)
self.assertEqual(20, fit.num_draws)
self.assertEqual(8, len(fit._stan_var_dims))
self.assertTrue('z' in fit._stan_var_dims)
self.assertEqual(fit._stan_var_dims['z'], [20, 2])
variables = fit.stan_variables()
self.assertEqual(len(variables), len(fit._stan_var_dims))
self.assertTrue('z' in variables)
self.assertEqual(variables['z'].shape, (20, 20, 2))
self.assertTrue('theta' in variables)
self.assertEqual(variables['theta'].shape, (20, 4))
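# Snippet: validate a good 4-chain bernoulli run: return codes, column names,
# drawset shape, and summary(). The setup below is a reconstruction (the
# snippet's own configuration was truncated); iter_sampling=100 is an
# assumption consistent with the num_draws check further down.
exe = os.path.join(DATAFILES_PATH, 'bernoulli' + EXTENSION)
sampler_args = SamplerArgs(iter_sampling=100)
cmdstan_args = CmdStanArgs(
    model_name='bernoulli',
    model_exe=exe,
    chain_ids=[1, 2, 3, 4],
    output_dir=DATAFILES_PATH,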
method_args=sampler_args,
)
runset = RunSet(args=cmdstan_args, chains=4)
runset._csv_files = [
os.path.join(DATAFILES_PATH, 'runset-good', 'bern-1.csv'),
os.path.join(DATAFILES_PATH, 'runset-good', 'bern-2.csv'),
os.path.join(DATAFILES_PATH, 'runset-good', 'bern-3.csv'),
os.path.join(DATAFILES_PATH, 'runset-good', 'bern-4.csv'),
]
self.assertEqual(4, runset.chains)
retcodes = runset._retcodes
for i in range(len(retcodes)):
runset._set_retcode(i, 0)
self.assertTrue(runset._check_retcodes())
fit = CmdStanMCMC(runset)
self.assertEqual(100, fit.num_draws)
self.assertEqual(len(BERNOULLI_COLS), len(fit.column_names))
self.assertEqual('lp__', fit.column_names[0])
drawset = fit.get_drawset()
self.assertEqual(
drawset.shape,
(fit.runset.chains * fit.num_draws, len(fit.column_names)),
)
_ = fit.summary()
self.assertTrue(True)  # reaching here means summary() raised no exception
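# Snippet: CmdStanMCMC validation rejects runs whose per-chain CSV files are
# inconsistent (wrong number of draws, wrong columns), reusing the 4-chain
# runset from the previous snippet.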
# mismatch - draws (csv files for chains 1-3 are assumed to follow the same
# bad-draws-bern-N naming as chain 4)
runset._csv_files = [
    os.path.join(DATAFILES_PATH, 'runset-bad', 'bad-draws-bern-1.csv'),
    os.path.join(DATAFILES_PATH, 'runset-bad', 'bad-draws-bern-2.csv'),
    os.path.join(DATAFILES_PATH, 'runset-bad', 'bad-draws-bern-3.csv'),
    os.path.join(DATAFILES_PATH, 'runset-bad', 'bad-draws-bern-4.csv'),
]
with self.assertRaisesRegex(ValueError, 'draws'):
    CmdStanMCMC(runset)
# mismatch - column headers, draws
runset._csv_files = [
os.path.join(DATAFILES_PATH, 'runset-bad', 'bad-cols-bern-1.csv'),
os.path.join(DATAFILES_PATH, 'runset-bad', 'bad-cols-bern-2.csv'),
os.path.join(DATAFILES_PATH, 'runset-bad', 'bad-cols-bern-3.csv'),
os.path.join(DATAFILES_PATH, 'runset-bad', 'bad-cols-bern-4.csv'),
]
with self.assertRaisesRegex(
ValueError, 'bad draw, expecting 9 items, found 8'
):
CmdStanMCMC(runset)
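# Snippet: same lotka-volterra fit as above, but accessing one variable at a
# time via stan_variable(name=...).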
exe = os.path.join(DATAFILES_PATH, 'lotka-volterra' + EXTENSION)
jdata = os.path.join(DATAFILES_PATH, 'lotka-volterra.data.json')
sampler_args = SamplerArgs(iter_sampling=20)
cmdstan_args = CmdStanArgs(
model_name='lotka-volterra',
model_exe=exe,
chain_ids=[1],
seed=12345,
data=jdata,
output_dir=DATAFILES_PATH,
method_args=sampler_args,
)
runset = RunSet(args=cmdstan_args, chains=1)
runset._csv_files = [os.path.join(DATAFILES_PATH, 'lotka-volterra.csv')]
runset._set_retcode(0, 0)
fit = CmdStanMCMC(runset)
self.assertEqual(20, fit.num_draws)
self.assertEqual(8, len(fit._stan_var_dims))
self.assertTrue('z' in fit._stan_var_dims)
self.assertEqual(fit._stan_var_dims['z'], [20, 2])
z = fit.stan_variable(name='z')
self.assertEqual(z.shape, (20, 20, 2))
theta = fit.stan_variable(name='theta')
self.assertEqual(theta.shape, (20, 4))
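# Snippet: validate a large run (2095-element parameter vector phi) and
# exercise get_drawset() filtering by parameter name.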
exe = os.path.join(DATAFILES_PATH, 'bernoulli' + EXTENSION)
sampler_args = SamplerArgs(iter_warmup=1500, iter_sampling=1000)
cmdstan_args = CmdStanArgs(
model_name='bernoulli',
model_exe=exe,
chain_ids=[1, 2],
seed=12345,
output_dir=DATAFILES_PATH,
method_args=sampler_args,
)
runset = RunSet(args=cmdstan_args, chains=2)
runset._csv_files = [
os.path.join(DATAFILES_PATH, 'runset-big', 'output_icar_nyc-1.csv'),
os.path.join(DATAFILES_PATH, 'runset-big', 'output_icar_nyc-1.csv'),
]
fit = CmdStanMCMC(runset)
phis = ['phi.{}'.format(x + 1) for x in range(2095)]
column_names = SAMPLER_STATE + phis
self.assertEqual(fit.num_draws, 1000)
self.assertEqual(fit.column_names, tuple(column_names))
self.assertEqual(fit.metric_type, 'diag_e')
self.assertEqual(fit.stepsize.shape, (2,))
self.assertEqual(fit.metric.shape, (2, 2095))
self.assertEqual((1000, 2, 2102), fit.sample.shape)
phis = fit.get_drawset(params=['phi'])
self.assertEqual((2000, 2095), phis.shape)
phi1 = fit.get_drawset(params=['phi.1'])
self.assertEqual((2000, 1), phi1.shape)
mo_phis = fit.get_drawset(params=['phi.1', 'phi.10', 'phi.100'])
self.assertEqual((2000, 3), mo_phis.shape)
phi2095 = fit.get_drawset(params=['phi.2095'])
self.assertEqual((2000, 1), phi2095.shape)
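# The remaining fragments are library code from the model class rather than
# tests. This one assembles a CmdStanMCMC from existing sample CSV files so
# that standalone generated quantities can be run against the sample. The
# guard below is reconstructed; the first line of its message is an assumption.
if not isinstance(mcmc_sample, (CmdStanMCMC, list)):
    raise ValueError(
        'mcmc_sample must be a CmdStanMCMC object'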
' or list of paths to sample csv_files.'
)
try:
chains = len(sample_csv_files)
if sample_drawset is None: # assemble sample from csv files
sampler_args = SamplerArgs()
args = CmdStanArgs(
self._name,
self._exe_file,
chain_ids=[x + 1 for x in range(chains)],
method_args=sampler_args,
)
runset = RunSet(args=args, chains=chains)
runset._csv_files = sample_csv_files
sample_fit = CmdStanMCMC(runset)
sample_fit._validate_csv_files()
sample_drawset = sample_fit.get_drawset()
except ValueError as e:
raise ValueError(
'Invalid mcmc_sample, error:\n\t{}\n\t'
' while processing files\n\t{}'.format(
repr(e), '\n\t'.join(sample_csv_files)
)
)
generate_quantities_args = GenerateQuantitiesArgs(
csv_files=sample_csv_files
)
generate_quantities_args.validate(chains)
with MaybeDictToFilePath(data, None) as (_data, _inits):
args = CmdStanArgs(
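        self._name,
        self._exe_file,
        chain_ids=[x + 1 for x in range(chains)],
        data=_data,
        method_args=generate_quantities_args,
    )  # remaining arguments reconstructed to follow the CmdStanArgs pattern above
# Library fragment: the error-checking tail of the model's sample() method,
# which verifies per-chain return codes, surfaces console error messages, and
# returns a validated CmdStanMCMC.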
# re-enable logger for console
self._logger.propagate = True
err_msg = 'Error during sampling.\n'
if not runset._check_retcodes():
for i in range(chains):
if runset._retcode(i) != 0:
err_msg = '{}chain {} returned error code {}\n'.format(
err_msg, i + 1, runset._retcode(i)
)
console_errs = runset._get_err_msgs()
if len(console_errs) > 0:
err_msg = '{}{}'.format(err_msg, ''.join(console_errs))
raise RuntimeError(err_msg)
mcmc = CmdStanMCMC(runset, fixed_param)
mcmc._validate_csv_files()
return mcmc