Secure your code as it's written. Use Snyk Code to scan source code in minutes - no build needed - and fix issues immediately.
CmdStanArgs(
model_name='bernoulli',
model_exe='bernoulli.exe',
chain_ids=[1, 2, 3, 4],
output_dir=fname,
method_args=sampler_args,
)
if os.path.exists(fname):
os.remove(fname)
# TODO: read-only dir test for Windows - set ACLs, not mode
if platform.system() == 'Darwin' or platform.system() == 'Linux':
with self.assertRaises(ValueError):
read_only = os.path.join(_TMPDIR, 'read_only')
os.mkdir(read_only, mode=0o444)
CmdStanArgs(
model_name='bernoulli',
model_exe='bernoulli.exe',
chain_ids=[1, 2, 3, 4],
output_dir=read_only,
method_args=sampler_args,
)
jinits = os.path.join(DATAFILES_PATH, 'bernoulli.init.json')
sampler_args = SamplerArgs()
with self.assertRaises(ValueError):
CmdStanArgs(
model_name='bernoulli',
model_exe=exe,
chain_ids=None,
seed=[1, 2, 3],
data=jdata,
inits=jinits,
method_args=sampler_args,
)
with self.assertRaises(ValueError):
CmdStanArgs(
model_name='bernoulli',
model_exe=exe,
chain_ids=None,
data=jdata,
inits=[jinits],
method_args=sampler_args,
)
def test_get_err_msgs(self):
    """Error messages gathered from failed chains include the exception text.

    Builds a 3-chain RunSet for the logistic model, records return code 70
    for every chain, and points each chain's stdout at a canned
    ``chain-<n>-missing-data-stdout.txt`` file. The joined result of
    ``_get_err_msgs()`` must then mention ``Exception``.
    """
    num_chains = 3
    logistic_exe = os.path.join(DATAFILES_PATH, 'logistic' + EXTENSION)
    logistic_data = os.path.join(DATAFILES_PATH, 'logistic.data.R')
    cmdstan_args = CmdStanArgs(
        model_name='logistic',
        model_exe=logistic_exe,
        chain_ids=[n + 1 for n in range(num_chains)],
        data=logistic_data,
        method_args=SamplerArgs(),
    )
    runset = RunSet(args=cmdstan_args, chains=num_chains)
    for idx in range(num_chains):
        # Simulate a failed chain whose stdout holds the CmdStan error output.
        runset._set_retcode(idx, 70)
        canned_stdout = 'chain-' + str(idx + 1) + '-missing-data-stdout.txt'
        runset._stdout_files[idx] = os.path.join(DATAFILES_PATH, canned_stdout)
    collected = runset._get_err_msgs()
    self.assertIn('Exception', '\n\t'.join(collected))
def test_output_filenames(self):
    """CSV output file names carry the model name and the chain number.

    Checks that the first CSV path contains ``bernoulli-`` and ``-1-``, and
    that the fourth contains ``-4-``.
    """
    bern_exe = os.path.join(DATAFILES_PATH, 'bernoulli' + EXTENSION)
    bern_data = os.path.join(DATAFILES_PATH, 'bernoulli.data.json')
    runset = RunSet(
        args=CmdStanArgs(
            model_name='bernoulli',
            model_exe=bern_exe,
            chain_ids=[1, 2, 3, 4],
            data=bern_data,
            method_args=SamplerArgs(),
        ),
        chains=4,
    )
    csv_files = runset._csv_files
    self.assertIn('bernoulli-', csv_files[0])
    self.assertIn('-1-', csv_files[0])
    self.assertIn('-4-', csv_files[3])
def test_validate_good_run(self):
    """A RunSet built over known-good sampler output reports success.

    Constructs a 4-chain RunSet for the bernoulli model and points it at the
    pre-generated CSV files under ``runset-good``. It then records a return
    code of 0 for every chain and asserts that ``_check_retcodes()`` reports
    success. Previously the test only checked the chain count, which merely
    echoes the constructor argument.
    """
    # construct fit using existing sampler output
    exe = os.path.join(DATAFILES_PATH, 'bernoulli' + EXTENSION)
    jdata = os.path.join(DATAFILES_PATH, 'bernoulli.data.json')
    sampler_args = SamplerArgs(
        iter_sampling=100, max_treedepth=11, adapt_delta=0.95
    )
    cmdstan_args = CmdStanArgs(
        model_name='bernoulli',
        model_exe=exe,
        chain_ids=[1, 2, 3, 4],
        seed=12345,
        data=jdata,
        output_dir=DATAFILES_PATH,
        method_args=sampler_args,
    )
    runset = RunSet(args=cmdstan_args, chains=4)
    runset._csv_files = [
        os.path.join(DATAFILES_PATH, 'runset-good', 'bern-1.csv'),
        os.path.join(DATAFILES_PATH, 'runset-good', 'bern-2.csv'),
        os.path.join(DATAFILES_PATH, 'runset-good', 'bern-3.csv'),
        os.path.join(DATAFILES_PATH, 'runset-good', 'bern-4.csv'),
    ]
    self.assertEqual(4, runset.chains)
    # Simulate every chain exiting cleanly, then confirm the RunSet's
    # success check agrees; this is what actually "validates" the run.
    for chain_idx in range(runset.chains):
        runset._set_retcode(chain_idx, 0)
    self.assertTrue(runset._check_retcodes())
def test_args_bad(self):
sampler_args = SamplerArgs(iter_warmup=10, iter_sampling=20)
with self.assertRaisesRegex(
Exception, 'missing 2 required positional arguments'
):
CmdStanArgs(model_name='bernoulli', model_exe='bernoulli.exe')
with self.assertRaisesRegex(
ValueError, 'no such file no/such/path/to.file'
):
CmdStanArgs(
model_name='bernoulli',
model_exe='bernoulli.exe',
chain_ids=[1, 2, 3, 4],
data='no/such/path/to.file',
method_args=sampler_args,
)
with self.assertRaisesRegex(ValueError, 'invalid chain_id'):
CmdStanArgs(
model_name='bernoulli',
model_exe='bernoulli.exe',
chain_ids=[1, 2, 3, -4],
method_args=sampler_args,
)
with self.assertRaisesRegex(
self.assertIn('method=sample algorithm=hmc', ' '.join(cmd))
cmdstan_args = CmdStanArgs(
model_name='bernoulli',
model_exe=exe,
chain_ids=[7, 11, 18, 29],
data=jdata,
method_args=sampler_args,
)
cmd = cmdstan_args.compose_command(idx=0, csv_file='bern-output-1.csv')
self.assertIn('id=7 random seed=', ' '.join(cmd))
dirname = 'tmp' + str(time())
if os.path.exists(dirname):
os.rmdir(dirname)
CmdStanArgs(
model_name='bernoulli',
model_exe='bernoulli.exe',
chain_ids=[1, 2, 3, 4],
output_dir=dirname,
method_args=sampler_args,
)
self.assertTrue(os.path.exists(dirname))
os.rmdir(dirname)
sample_csv_files = mcmc_sample.runset.csv_files
sample_drawset = mcmc_sample.get_drawset()
chains = mcmc_sample.chains
elif isinstance(mcmc_sample, list):
sample_csv_files = mcmc_sample
else:
raise ValueError(
'MCMC sample must be either CmdStanMCMC object'
' or list of paths to sample csv_files.'
)
try:
chains = len(sample_csv_files)
if sample_drawset is None: # assemble sample from csv files
sampler_args = SamplerArgs()
args = CmdStanArgs(
self._name,
self._exe_file,
chain_ids=[x + 1 for x in range(chains)],
method_args=sampler_args,
)
runset = RunSet(args=args, chains=chains)
runset._csv_files = sample_csv_files
sample_fit = CmdStanMCMC(runset)
sample_fit._validate_csv_files()
sample_drawset = sample_fit.get_drawset()
except ValueError as e:
raise ValueError(
'Invalid mcmc_sample, error:\n\t{}\n\t'
' while processing files\n\t{}'.format(
repr(e), '\n\t'.join(sample_csv_files)
)
iter_warmup=iter_warmup,
iter_sampling=iter_sampling,
save_warmup=save_warmup,
thin=thin,
max_treedepth=max_treedepth,
metric=metric,
step_size=step_size,
adapt_engaged=adapt_engaged,
adapt_delta=adapt_delta,
adapt_init_phase=adapt_init_phase,
adapt_metric_window=adapt_metric_window,
adapt_step_size=adapt_step_size,
fixed_param=fixed_param,
)
with MaybeDictToFilePath(data, inits) as (_data, _inits):
args = CmdStanArgs(
self._name,
self._exe_file,
chain_ids=chain_ids,
data=_data,
seed=seed,
inits=_inits,
output_dir=output_dir,
save_diagnostics=save_diagnostics,
method_args=sampler_args,
refresh=refresh,
logger=self._logger,
)
runset = RunSet(args=args, chains=chains)
pbar = None
all_pbars = []
where `` is set with `csv_basename`.
:param algorithm: Algorithm to use. One of: "BFGS", "LBFGS", "Newton"
:param init_alpha: Line search step size for first iteration
:param iter: Total number of iterations
:return: CmdStanMLE object
"""
optimize_args = OptimizeArgs(
algorithm=algorithm, init_alpha=init_alpha, iter=iter
)
with MaybeDictToFilePath(data, inits) as (_data, _inits):
args = CmdStanArgs(
self._name,
self._exe_file,
chain_ids=None,
data=_data,
seed=seed,
inits=_inits,
output_dir=output_dir,
save_diagnostics=save_diagnostics,
method_args=optimize_args,
)
dummy_chain_id = 0
runset = RunSet(args=args, chains=1)
self._run_cmdstan(runset, dummy_chain_id)
if not runset._check_retcodes():