Secure your code as it's written. Use Snyk Code to scan source code in minutes - no build needed - and fix issues immediately.
def test_bernoulli_good(self, stanfile='bernoulli.stan'):
    """Sample a Bernoulli model with two chains and inspect the fit.

    Builds a ``CmdStanModel`` from ``stanfile`` (looked up under
    ``DATAFILES_PATH``), draws 100 sampling iterations per chain with a
    fixed seed, and then checks the fit's repr, the recorded method, the
    per-chain output files, and the shapes of the draws, step sizes and
    metric.
    """
    model_path = os.path.join(DATAFILES_PATH, stanfile)
    model = CmdStanModel(stan_file=model_path)
    data_path = os.path.join(DATAFILES_PATH, 'bernoulli.data.json')

    fit = model.sample(
        data=data_path,
        chains=2,
        cores=2,
        seed=12345,
        iter_sampling=100,
    )

    # The repr has to name both the model and the inference method.
    description = repr(fit)
    self.assertIn('CmdStanMCMC: model=bernoulli', description)
    self.assertIn('method=sample', description)

    runset = fit.runset
    self.assertEqual(runset._args.method, Method.SAMPLE)

    # Each chain must have left both its CSV and its stdout file on disk.
    for chain_idx in range(runset.chains):
        chain_outputs = (
            runset.csv_files[chain_idx],
            runset.stdout_files[chain_idx],
        )
        for path in chain_outputs:
            self.assertTrue(os.path.exists(path))

    self.assertEqual(runset.chains, 2)
    self.assertEqual(fit.num_draws, 100)
    self.assertEqual(fit.column_names, tuple(BERNOULLI_COLS))

    # Draws are laid out as (draws, chains, columns).
    draws = fit.sample
    self.assertEqual(draws.shape, (100, 2, len(BERNOULLI_COLS)))

    self.assertEqual(fit.metric_type, 'diag_e')
    self.assertEqual(fit.stepsize.shape, (2,))
    self.assertEqual(fit.metric.shape, (2, 1))
# Test: run generate_quantities for the bernoulli_ppc model, passing a list of
# sample CSV file names as the `mcmc_sample` argument.
def test_gen_quantities_csv_files(self):
stan = os.path.join(DATAFILES_PATH, 'bernoulli_ppc.stan')
model = CmdStanModel(stan_file=stan)
jdata = os.path.join(DATAFILES_PATH, 'bernoulli.data.json')
# synthesize list of filenames:
# <DATAFILES_PATH>/runset-good/bern-1.csv through bern-4.csv
goodfiles_path = os.path.join(DATAFILES_PATH, 'runset-good', 'bern')
csv_files = []
for i in range(4):
csv_files.append('{}-{}.csv'.format(goodfiles_path, i + 1))
bern_gqs = model.generate_quantities(data=jdata, mcmc_sample=csv_files)
# the run must be recorded as a generate_quantities run, and its repr must
# name both the model and the method
self.assertEqual(
bern_gqs.runset._args.method, Method.GENERATE_QUANTITIES
)
self.assertIn('CmdStanGQ: model=bernoulli_ppc', bern_gqs.__repr__())
self.assertIn('method=generate_quantities', bern_gqs.__repr__())
# check results - output files, quantities of interest, draws
self.assertEqual(bern_gqs.runset.chains, 4)
# every chain must report return code 0 and have a CSV file on disk
for i in range(bern_gqs.runset.chains):
self.assertEqual(bern_gqs.runset._retcode(i), 0)
csv_file = bern_gqs.runset.csv_files[i]
self.assertTrue(os.path.exists(csv_file))
# NOTE(review): this list, and whatever follows it in the test, is cut off
# in this excerpt.
column_names = [
'y_rep.1',
'y_rep.2',
'y_rep.3',
'y_rep.4',
'y_rep.5',
# Test: sample the bernoulli model, then pass the resulting fit object (rather
# than CSV file names) as `mcmc_sample` to generate_quantities on bernoulli_ppc.
def test_gen_quanties_mcmc_sample(self):
stan = os.path.join(DATAFILES_PATH, 'bernoulli.stan')
bern_model = CmdStanModel(stan_file=stan)
jdata = os.path.join(DATAFILES_PATH, 'bernoulli.data.json')
# four chains, 100 sampling iterations each, fixed seed
bern_fit = bern_model.sample(
data=jdata, chains=4, cores=2, seed=12345, iter_sampling=100
)
# `stan` is reused here for the second model's path
stan = os.path.join(DATAFILES_PATH, 'bernoulli_ppc.stan')
model = CmdStanModel(stan_file=stan)
bern_gqs = model.generate_quantities(data=jdata, mcmc_sample=bern_fit)
# the run must be recorded as a generate_quantities run, and its repr must
# name both the model and the method
self.assertEqual(
bern_gqs.runset._args.method, Method.GENERATE_QUANTITIES
)
self.assertIn('CmdStanGQ: model=bernoulli_ppc', bern_gqs.__repr__())
self.assertIn('method=generate_quantities', bern_gqs.__repr__())
# check results - output files, quantities of interest, draws
self.assertEqual(bern_gqs.runset.chains, 4)
# every chain must report return code 0 and have a CSV file on disk
for i in range(bern_gqs.runset.chains):
self.assertEqual(bern_gqs.runset._retcode(i), 0)
csv_file = bern_gqs.runset.csv_files[i]
self.assertTrue(os.path.exists(csv_file))
# NOTE(review): this list, and whatever follows it in the test, is cut off
# in this excerpt.
column_names = [
'y_rep.1',
'y_rep.2',
'y_rep.3',
'y_rep.4',
'y_rep.5',
def __init__(self, runset: RunSet) -> None:
    """Wrap a sampler runset.

    :param runset: the runset to wrap; its ``method`` must equal
        ``Method.SAMPLE``.
    :raises ValueError: if ``runset.method`` is anything else.
    """
    if runset.method != Method.SAMPLE:
        message = (
            'Wrong runset method, expecting sample runset, '
            'found method {}'
        ).format(runset.method)
        raise ValueError(message)

    self.runset = runset

    # Mirror the sampler configuration stored on the runset's arguments.
    sampler_args = runset._args.method_args
    self._is_fixed_param = sampler_args.fixed_param
    self._iter_sampling = sampler_args.iter_sampling
    self._iter_warmup = sampler_args.iter_warmup
    self._save_warmup = sampler_args.save_warmup
    self._thin = sampler_args.thin

    # Placeholders; per the original comment these are parsed from the
    # CSV files later on.
    self._draws_sampling = None
    self._draws_warmup = None
    self._column_names = ()
    self._num_params = None  # metric dim(s)
def sample_plus_quantities(self) -> pd.DataFrame:
    """
    Returns the column-wise concatenation of the input drawset
    with generated quantities drawset. If a column is present in
    both the input and the generated quantities, the input column
    is dropped in favor of the recomputed values in the generated
    quantities drawset.

    :return: a DataFrame holding the remaining input columns followed
        by all generated quantities columns.
    :raises ValueError: if the runset method is not
        ``Method.GENERATE_QUANTITIES``.
    """
    if self.runset.method != Method.GENERATE_QUANTITIES:
        raise ValueError('Bad runset method {}.'.format(self.runset.method))
    if self._generated_quantities is None:
        self._assemble_generated_quantities()

    sample_df = self.mcmc_sample
    quantities_df = self.generated_quantities_pd

    # Only labels present in BOTH frames are duplicates. Counting
    # occurrences over the concatenated label lists would also flag a label
    # repeated inside a single frame, which either discards an input column
    # that was never recomputed or asks ``drop`` for a label the input frame
    # does not have (KeyError).
    shared = list(set(sample_df.columns) & set(quantities_df.columns))

    return pd.concat(
        [sample_df.drop(columns=shared), quantities_df],
        axis=1,
    )
# Tail of an initializer whose signature is not visible in this excerpt.
# Store the arguments on the instance as given.
self.model_name = model_name
self.model_exe = model_exe
self.chain_ids = chain_ids
self.data = data
self.seed = seed
self.inits = inits
self.output_dir = output_dir
self.save_diagnostics = save_diagnostics
self.refresh = refresh
self.method_args = method_args
# Derive the inference method from the concrete type of `method_args`.
# NOTE(review): there is no `else` branch, so `self.method` is not assigned
# here when `method_args` is none of these four types.
if isinstance(method_args, SamplerArgs):
self.method = Method.SAMPLE
elif isinstance(method_args, OptimizeArgs):
self.method = Method.OPTIMIZE
elif isinstance(method_args, GenerateQuantitiesArgs):
self.method = Method.GENERATE_QUANTITIES
elif isinstance(method_args, VariationalArgs):
self.method = Method.VARIATIONAL
# Pass the number of chain ids, or None when `chain_ids` is empty or None.
self.method_args.validate(len(chain_ids) if chain_ids else None)
# Fall back to get_logger() when no (truthy) logger was supplied.
self._logger = logger or get_logger()
self.validate()
) -> None:
"""Initialize object."""
# The parameter list of this initializer is not visible in this excerpt.
# Store the arguments on the instance as given.
self.model_name = model_name
self.model_exe = model_exe
self.chain_ids = chain_ids
self.data = data
self.seed = seed
self.inits = inits
self.output_dir = output_dir
self.save_diagnostics = save_diagnostics
self.refresh = refresh
self.method_args = method_args
# Derive the inference method from the concrete type of `method_args`.
# NOTE(review): there is no `else` branch, so `self.method` is not assigned
# here when `method_args` is none of these four types.
if isinstance(method_args, SamplerArgs):
self.method = Method.SAMPLE
elif isinstance(method_args, OptimizeArgs):
self.method = Method.OPTIMIZE
elif isinstance(method_args, GenerateQuantitiesArgs):
self.method = Method.GENERATE_QUANTITIES
elif isinstance(method_args, VariationalArgs):
self.method = Method.VARIATIONAL
# Pass the number of chain ids, or None when `chain_ids` is empty or None.
self.method_args.validate(len(chain_ids) if chain_ids else None)
# Fall back to get_logger() when no (truthy) logger was supplied.
self._logger = logger or get_logger()
self.validate()
def __init__(self, runset: RunSet) -> None:
    """Wrap an optimizer runset.

    :param runset: the runset to wrap; its ``method`` must equal
        ``Method.OPTIMIZE``.
    :raises ValueError: if ``runset.method`` is anything else.
    """
    if runset.method != Method.OPTIMIZE:
        message = (
            'Wrong runset method, expecting optimize runset, '
            'found method {}'
        ).format(runset.method)
        raise ValueError(message)

    self.runset = runset
    # Start empty, then hand the first CSV file of the runset to
    # `_set_mle_attrs`.
    self._column_names = ()
    self._mle = {}
    first_csv = runset.csv_files[0]
    self._set_mle_attrs(first_csv)
# Last two parameters of an initializer; the start of its signature is not
# visible in this excerpt.
refresh: str = None,
logger: logging.Logger = None,
) -> None:
"""Initialize object."""
# Store the arguments on the instance as given.
self.model_name = model_name
self.model_exe = model_exe
self.chain_ids = chain_ids
self.data = data
self.seed = seed
self.inits = inits
self.output_dir = output_dir
self.save_diagnostics = save_diagnostics
self.refresh = refresh
self.method_args = method_args
# Derive the inference method from the concrete type of `method_args`.
# NOTE(review): there is no `else` branch, so `self.method` is not assigned
# here when `method_args` is none of these four types.
if isinstance(method_args, SamplerArgs):
self.method = Method.SAMPLE
elif isinstance(method_args, OptimizeArgs):
self.method = Method.OPTIMIZE
elif isinstance(method_args, GenerateQuantitiesArgs):
self.method = Method.GENERATE_QUANTITIES
elif isinstance(method_args, VariationalArgs):
self.method = Method.VARIATIONAL
# Pass the number of chain ids, or None when `chain_ids` is empty or None.
self.method_args.validate(len(chain_ids) if chain_ids else None)
# Fall back to get_logger() when no (truthy) logger was supplied.
self._logger = logger or get_logger()
self.validate()
# Tail of an initializer whose signature and first statements are not visible
# in this excerpt. Store the arguments on the instance as given.
self.chain_ids = chain_ids
self.data = data
self.seed = seed
self.inits = inits
self.output_dir = output_dir
self.save_diagnostics = save_diagnostics
self.refresh = refresh
self.method_args = method_args
# Derive the inference method from the concrete type of `method_args`.
# NOTE(review): there is no `else` branch, so `self.method` is not assigned
# here when `method_args` is none of these four types.
if isinstance(method_args, SamplerArgs):
self.method = Method.SAMPLE
elif isinstance(method_args, OptimizeArgs):
self.method = Method.OPTIMIZE
elif isinstance(method_args, GenerateQuantitiesArgs):
self.method = Method.GENERATE_QUANTITIES
elif isinstance(method_args, VariationalArgs):
self.method = Method.VARIATIONAL
# Pass the number of chain ids, or None when `chain_ids` is empty or None.
self.method_args.validate(len(chain_ids) if chain_ids else None)
# Fall back to get_logger() when no (truthy) logger was supplied.
self._logger = logger or get_logger()
self.validate()