# Secure your code as it's written. Use Snyk Code to scan source code in minutes - no build needed - and fix issues immediately.
def test_validate_big_run(self):
    """Validate a fit with many parameters built from existing CSV output."""
    exe_path = os.path.join(DATAFILES_PATH, 'bernoulli' + EXTENSION)
    method_args = SamplerArgs(iter_warmup=1500, iter_sampling=1000)
    args = CmdStanArgs(
        model_name='bernoulli',
        model_exe=exe_path,
        chain_ids=[1, 2],
        seed=12345,
        output_dir=DATAFILES_PATH,
        method_args=method_args,
    )
    runset = RunSet(args=args, chains=2)
    # the same pre-generated CSV stands in for both chains
    big_csv = os.path.join(
        DATAFILES_PATH, 'runset-big', 'output_icar_nyc-1.csv'
    )
    runset._csv_files = [big_csv, big_csv]
    fit = CmdStanMCMC(runset)
    phi_names = ['phi.{}'.format(str(idx + 1)) for idx in range(2095)]
    expected_cols = SAMPLER_STATE + phi_names
    self.assertEqual(fit.num_draws, 1000)
    self.assertEqual(fit.column_names, tuple(expected_cols))
    self.assertEqual(fit.metric_type, 'diag_e')
    self.assertEqual(fit.stepsize.shape, (2,))
    self.assertEqual(fit.metric.shape, (2, 2095))
    self.assertEqual((1000, 2, 2102), fit.sample.shape)
    # drawset flattens chains: 2 chains x 1000 draws = 2000 rows
    all_phis = fit.get_drawset(params=['phi'])
    self.assertEqual((2000, 2095), all_phis.shape)
    single_phi = fit.get_drawset(params=['phi.1'])
def test_variable_lv(self):
    """stan_variable() reshapes container variables to their declared dims."""
    # pylint: disable=C0103
    # construct fit using existing sampler output
    exe_path = os.path.join(DATAFILES_PATH, 'lotka-volterra' + EXTENSION)
    data_json = os.path.join(DATAFILES_PATH, 'lotka-volterra.data.json')
    args = CmdStanArgs(
        model_name='lotka-volterra',
        model_exe=exe_path,
        chain_ids=[1],
        seed=12345,
        data=data_json,
        output_dir=DATAFILES_PATH,
        method_args=SamplerArgs(iter_sampling=20),
    )
    runset = RunSet(args=args, chains=1)
    runset._csv_files = [os.path.join(DATAFILES_PATH, 'lotka-volterra.csv')]
    runset._set_retcode(0, 0)
    fit = CmdStanMCMC(runset)
    self.assertEqual(20, fit.num_draws)
    self.assertEqual(8, len(fit._stan_var_dims))
    self.assertIn('z', fit._stan_var_dims)
    self.assertEqual(fit._stan_var_dims['z'], [20, 2])
    # leading axis is draws; remaining axes follow the Stan declaration
    z = fit.stan_variable(name='z')
    self.assertEqual(z.shape, (20, 20, 2))
    theta = fit.stan_variable(name='theta')
    self.assertEqual(theta.shape, (20, 4))
# construct fit using existing sampler output
# NOTE(review): this statement run has no `def` line in view — it appears
# to be the body of another test (bernoulli model, 4 chains) whose header
# was lost from this chunk; confirm against the full file.
exe = os.path.join(DATAFILES_PATH, 'bernoulli' + EXTENSION)
jdata = os.path.join(DATAFILES_PATH, 'bernoulli.data.json')
sampler_args = SamplerArgs(
iter_sampling=100, max_treedepth=11, adapt_delta=0.95
)
cmdstan_args = CmdStanArgs(
model_name='bernoulli',
model_exe=exe,
chain_ids=[1, 2, 3, 4],
seed=12345,
data=jdata,
output_dir=DATAFILES_PATH,
method_args=sampler_args,
)
runset = RunSet(args=cmdstan_args, chains=4)
# point the runset at pre-generated per-chain CSV files
runset._csv_files = [
os.path.join(DATAFILES_PATH, 'runset-good', 'bern-1.csv'),
os.path.join(DATAFILES_PATH, 'runset-good', 'bern-2.csv'),
os.path.join(DATAFILES_PATH, 'runset-good', 'bern-3.csv'),
os.path.join(DATAFILES_PATH, 'runset-good', 'bern-4.csv'),
]
self.assertEqual(4, runset.chains)
# mark every chain as having exited successfully (retcode 0)
retcodes = runset._retcodes
for i in range(len(retcodes)):
runset._set_retcode(i, 0)
self.assertTrue(runset._check_retcodes())
fit = CmdStanMCMC(runset)
self.assertEqual(100, fit.num_draws)
self.assertEqual(len(BERNOULLI_COLS), len(fit.column_names))
self.assertEqual('lp__', fit.column_names[0])
def test_validate_bad_run(self):
# Exercise error detection when chains exit cleanly but their stderr
# transcripts contain error output.
exe = os.path.join(DATAFILES_PATH, 'bernoulli' + EXTENSION)
jdata = os.path.join(DATAFILES_PATH, 'bernoulli.data.json')
sampler_args = SamplerArgs(max_treedepth=11, adapt_delta=0.95)
# some chains had errors
cmdstan_args = CmdStanArgs(
model_name='bernoulli',
model_exe=exe,
chain_ids=[1, 2, 3, 4],
seed=12345,
data=jdata,
output_dir=DATAFILES_PATH,
method_args=sampler_args,
)
runset = RunSet(args=cmdstan_args, chains=4)
# all retcodes are 0, so failure must be detected from the transcripts
for i in range(4):
runset._set_retcode(i, 0)
self.assertTrue(runset._check_retcodes())
# errors reported
runset._stderr_files = [
os.path.join(
DATAFILES_PATH, 'runset-bad', 'bad-transcript-bern-1.txt'
),
os.path.join(
DATAFILES_PATH, 'runset-bad', 'bad-transcript-bern-2.txt'
),
os.path.join(
DATAFILES_PATH, 'runset-bad', 'bad-transcript-bern-3.txt'
),
os.path.join(
# NOTE(review): truncated here in this chunk — the fourth path, the
# closing bracket, and the rest of the test are missing; confirm
# against the full file before editing.
def test_get_err_msgs(self):
    """Error messages are recovered from per-chain stdout transcripts."""
    exe_path = os.path.join(DATAFILES_PATH, 'logistic' + EXTENSION)
    data_r = os.path.join(DATAFILES_PATH, 'logistic.data.R')
    args = CmdStanArgs(
        model_name='logistic',
        model_exe=exe_path,
        chain_ids=[1, 2, 3],
        data=data_r,
        method_args=SamplerArgs(),
    )
    runset = RunSet(args=args, chains=3)
    for idx in range(3):
        # retcode 70 marks the chain as failed
        runset._set_retcode(idx, 70)
        transcript = 'chain-' + str(idx + 1) + '-missing-data-stdout.txt'
        runset._stdout_files[idx] = os.path.join(DATAFILES_PATH, transcript)
    joined_errs = '\n\t'.join(runset._get_err_msgs())
    self.assertIn('Exception', joined_errs)
def test_variables(self):
    """stan_variables() returns all model variables keyed by name.

    The fit is assembled from existing sampler CSV output rather than a
    live run; each variable is reshaped to (draws, *declared dims).
    """
    # construct fit using existing sampler output
    exe = os.path.join(DATAFILES_PATH, 'lotka-volterra' + EXTENSION)
    jdata = os.path.join(DATAFILES_PATH, 'lotka-volterra.data.json')
    sampler_args = SamplerArgs(iter_sampling=20)
    cmdstan_args = CmdStanArgs(
        model_name='lotka-volterra',
        model_exe=exe,
        chain_ids=[1],
        seed=12345,
        data=jdata,
        output_dir=DATAFILES_PATH,
        method_args=sampler_args,
    )
    runset = RunSet(args=cmdstan_args, chains=1)
    runset._csv_files = [os.path.join(DATAFILES_PATH, 'lotka-volterra.csv')]
    runset._set_retcode(0, 0)
    fit = CmdStanMCMC(runset)
    self.assertEqual(20, fit.num_draws)
    self.assertEqual(8, len(fit._stan_var_dims))
    self.assertTrue('z' in fit._stan_var_dims)
    self.assertEqual(fit._stan_var_dims['z'], [20, 2])
    # renamed from 'vars', which shadowed the builtin of the same name
    variables = fit.stan_variables()
    self.assertEqual(len(variables), len(fit._stan_var_dims))
    self.assertTrue('z' in variables)
    self.assertEqual(variables['z'].shape, (20, 20, 2))
    self.assertTrue('theta' in variables)
    self.assertEqual(variables['theta'].shape, (20, 4))
def test_set_mle_attrs(self):
    """CmdStanMLE attrs start empty and are populated from the MLE CSV."""
    stan_file = os.path.join(DATAFILES_PATH, 'optimize', 'rosenbrock.stan')
    model = CmdStanModel(stan_file=stan_file)
    empty_data = {}
    optimize_args = OptimizeArgs(algorithm='Newton')
    args = CmdStanArgs(
        model_name=model.name,
        model_exe=model.exe_file,
        chain_ids=None,
        data=empty_data,
        method_args=optimize_args,
    )
    runset = RunSet(args=args, chains=1)
    mle = CmdStanMLE(runset)
    rep = repr(mle)
    self.assertIn('CmdStanMLE: model=rosenbrock', rep)
    self.assertIn('method=optimize', rep)
    # attrs are empty until the output csv is parsed
    self.assertEqual(mle._column_names, ())
    self.assertEqual(mle._mle, {})
    mle_csv = os.path.join(DATAFILES_PATH, 'optimize', 'rosenbrock_mle.csv')
    mle._set_mle_attrs(mle_csv)
    self.assertEqual(mle.column_names, ('lp__', 'x', 'y'))
    # Rosenbrock optimum is at (1, 1)
    self.assertAlmostEqual(mle.optimized_params_dict['x'], 1, places=3)
    self.assertAlmostEqual(mle.optimized_params_dict['y'], 1, places=3)
def test_commands(self):
    """RunSet builds one command line per chain, each tagged with its id."""
    exe_path = os.path.join(DATAFILES_PATH, 'bernoulli' + EXTENSION)
    data_json = os.path.join(DATAFILES_PATH, 'bernoulli.data.json')
    args = CmdStanArgs(
        model_name='bernoulli',
        model_exe=exe_path,
        chain_ids=[1, 2, 3, 4],
        data=data_json,
        method_args=SamplerArgs(),
    )
    runset = RunSet(args=args, chains=4)
    self.assertIn('id=1', runset._cmds[0])
    self.assertIn('id=4', runset._cmds[3])
)
# NOTE(review): fragment — this span is cut at both ends. It looks like the
# interior of a model sampling method (presumably CmdStanModel.sample);
# confirm against the full file before editing.
with MaybeDictToFilePath(data, inits) as (_data, _inits):
args = CmdStanArgs(
self._name,
self._exe_file,
chain_ids=chain_ids,
data=_data,
seed=seed,
inits=_inits,
output_dir=output_dir,
save_diagnostics=save_diagnostics,
method_args=sampler_args,
refresh=refresh,
logger=self._logger,
)
runset = RunSet(args=args, chains=chains)
pbar = None
all_pbars = []
# run chains concurrently, up to `cores` at a time
with ThreadPoolExecutor(max_workers=cores) as executor:
for i in range(chains):
if show_progress:
# show_progress='notebook' selects the Jupyter-style bar
if (
isinstance(show_progress, str)
and show_progress.lower() == 'notebook'
):
try:
tqdm_pbar = tqdm.tqdm_notebook
except ImportError:
msg = (
'Cannot import tqdm.tqdm_notebook.\n'
'Functionality is only supported on the '
# NOTE(review): truncated here — the rest of the message and the
# except-handler body are missing from this chunk.
raise ValueError(
'MCMC sample must be either CmdStanMCMC object'
' or list of paths to sample csv_files.'
)
# NOTE(review): fragment — cut at both ends; appears to be the interior of
# a generate-quantities method that rebuilds a fit from csv files when no
# drawset was supplied. Confirm against the full file before editing.
try:
chains = len(sample_csv_files)
if sample_drawset is None: # assemble sample from csv files
sampler_args = SamplerArgs()
args = CmdStanArgs(
self._name,
self._exe_file,
chain_ids=[x + 1 for x in range(chains)],
method_args=sampler_args,
)
runset = RunSet(args=args, chains=chains)
runset._csv_files = sample_csv_files
sample_fit = CmdStanMCMC(runset)
# raises ValueError on malformed csv; caught and re-raised below
sample_fit._validate_csv_files()
sample_drawset = sample_fit.get_drawset()
except ValueError as e:
raise ValueError(
'Invalid mcmc_sample, error:\n\t{}\n\t'
' while processing files\n\t{}'.format(
repr(e), '\n\t'.join(sample_csv_files)
)
)
generate_quantities_args = GenerateQuantitiesArgs(
csv_files=sample_csv_files
)
generate_quantities_args.validate(chains)