Secure your code as it's written. Use Snyk Code to scan source code in minutes - no build needed - and fix issues immediately.
cmd = args.compose(1, cmd=[])
self.assertIn(
'method=sample algorithm=hmc adapt engaged=1', ' '.join(cmd)
)
args = SamplerArgs(
adapt_init_phase=26, adapt_metric_window=60, adapt_step_size=34,
)
args.validate(chains=4)
cmd = args.compose(1, cmd=[])
self.assertIn('method=sample algorithm=hmc adapt', ' '.join(cmd))
self.assertIn('init_buffer=26', ' '.join(cmd))
self.assertIn('window=60', ' '.join(cmd))
self.assertIn('term_buffer=34', ' '.join(cmd))
args = SamplerArgs()
args.validate(chains=4)
cmd = args.compose(1, cmd=[])
self.assertNotIn('engine=nuts', ' '.join(cmd))
self.assertNotIn('adapt engaged=0', ' '.join(cmd))
def test_validate_good_run(self):
# construct fit using existing sampler output
exe = os.path.join(DATAFILES_PATH, 'bernoulli' + EXTENSION)
jdata = os.path.join(DATAFILES_PATH, 'bernoulli.data.json')
sampler_args = SamplerArgs(
iter_sampling=100, max_treedepth=11, adapt_delta=0.95
)
cmdstan_args = CmdStanArgs(
model_name='bernoulli',
model_exe=exe,
chain_ids=[1, 2, 3, 4],
seed=12345,
data=jdata,
output_dir=DATAFILES_PATH,
method_args=sampler_args,
)
runset = RunSet(args=cmdstan_args, chains=4)
runset._csv_files = [
os.path.join(DATAFILES_PATH, 'runset-good', 'bern-1.csv'),
os.path.join(DATAFILES_PATH, 'runset-good', 'bern-2.csv'),
os.path.join(DATAFILES_PATH, 'runset-good', 'bern-3.csv'),
with self.assertRaises(ValueError):
args.validate(chains=2)
args = SamplerArgs(step_size=[1.0, 1.1])
with self.assertRaises(ValueError):
args.validate(chains=1)
args = SamplerArgs(step_size=[1.0, -1.1])
with self.assertRaises(ValueError):
args.validate(chains=2)
args = SamplerArgs(adapt_delta=1.1)
with self.assertRaises(ValueError):
args.validate(chains=2)
args = SamplerArgs(adapt_delta=-0.1)
with self.assertRaises(ValueError):
args.validate(chains=2)
args = SamplerArgs(iter_warmup=100, fixed_param=True)
with self.assertRaises(ValueError):
args.validate(chains=2)
args = SamplerArgs(save_warmup=True, fixed_param=True)
with self.assertRaises(ValueError):
args.validate(chains=2)
args = SamplerArgs(max_treedepth=12, fixed_param=True)
with self.assertRaises(ValueError):
args.validate(chains=2)
args = SamplerArgs(metric='dense', fixed_param=True)
def test_validate_big_run(self):
    """Build a two-chain fit from the CSV output stored under ``runset-big``.

    The expected column header is ``SAMPLER_STATE`` followed by the 2095
    names ``phi.1`` .. ``phi.2095``.
    """
    # Both chain slots point at the same ``-1`` CSV file.
    # NOTE(review): presumably the fixture directory only ships this one
    # file -- confirm before pointing the second slot at a ``-2`` file.
    big_csv = os.path.join(DATAFILES_PATH, 'runset-big', 'output_icar_nyc-1.csv')
    model_exe = os.path.join(DATAFILES_PATH, 'bernoulli' + EXTENSION)
    chain_ids = [1, 2]
    cmdstan_args = CmdStanArgs(
        model_name='bernoulli',
        model_exe=model_exe,
        chain_ids=chain_ids,
        seed=12345,
        output_dir=DATAFILES_PATH,
        method_args=SamplerArgs(iter_warmup=1500, iter_sampling=1000),
    )
    runset = RunSet(args=cmdstan_args, chains=len(chain_ids))
    runset._csv_files = [big_csv, big_csv]
    fit = CmdStanMCMC(runset)
    # 1-based parameter names: phi.1 .. phi.2095.
    phis = ['phi.{}'.format(idx) for idx in range(1, 2096)]
    column_names = SAMPLER_STATE + phis
def test_diagnose_divergences(self):
exe = os.path.join(DATAFILES_PATH, 'bernoulli' + EXTENSION)
sampler_args = SamplerArgs()
cmdstan_args = CmdStanArgs(
model_name='bernoulli',
model_exe=exe,
chain_ids=[1],
output_dir=DATAFILES_PATH,
method_args=sampler_args,
)
runset = RunSet(args=cmdstan_args, chains=1)
runset._csv_files = [
os.path.join(
DATAFILES_PATH, 'diagnose-good', 'corr_gauss_depth8-1.csv'
)
]
fit = CmdStanMCMC(runset)
# TODO - use cmdstan test files instead
expected = '\n'.join(
args = SamplerArgs(metric='diag_e')
args.validate(chains=4)
cmd = args.compose(1, cmd=[])
self.assertIn(
'method=sample algorithm=hmc metric=diag_e', ' '.join(cmd)
)
args = SamplerArgs(metric='diag')
args.validate(chains=4)
cmd = args.compose(1, cmd=[])
self.assertIn(
'method=sample algorithm=hmc metric=diag_e', ' '.join(cmd)
)
args = SamplerArgs()
args.validate(chains=4)
cmd = args.compose(1, cmd=[])
self.assertNotIn('metric=', ' '.join(cmd))
jmetric = os.path.join(DATAFILES_PATH, 'bernoulli.metric.json')
args = SamplerArgs(metric=jmetric)
args.validate(chains=4)
cmd = args.compose(1, cmd=[])
self.assertIn('metric=diag_e', ' '.join(cmd))
self.assertIn('metric_file=', ' '.join(cmd))
self.assertIn('bernoulli.metric.json', ' '.join(cmd))
jmetric2 = os.path.join(DATAFILES_PATH, 'bernoulli.metric-2.json')
args = SamplerArgs(metric=[jmetric, jmetric2])
args.validate(chains=2)
cmd = args.compose(0, cmd=[])
def test_commands(self):
    """Check that the first and last per-chain command lines carry their chain id."""
    data_file = os.path.join(DATAFILES_PATH, 'bernoulli.data.json')
    model_exe = os.path.join(DATAFILES_PATH, 'bernoulli' + EXTENSION)
    cmdstan_args = CmdStanArgs(
        model_name='bernoulli',
        model_exe=model_exe,
        chain_ids=[1, 2, 3, 4],
        data=data_file,
        method_args=SamplerArgs(),
    )
    runset = RunSet(args=cmdstan_args, chains=4)
    # (index into runset._cmds, id fragment expected in that command)
    for cmd_idx, id_fragment in ((0, 'id=1'), (3, 'id=4')):
        self.assertIn(id_fragment, runset._cmds[cmd_idx])
def test_check_retcodes(self):
    """Check the RunSet repr and that every chain's return code starts at -1.

    Finishes by setting chain 0's return code to 0.
    """
    model_exe = os.path.join(DATAFILES_PATH, 'bernoulli' + EXTENSION)
    data_file = os.path.join(DATAFILES_PATH, 'bernoulli.data.json')
    cmdstan_args = CmdStanArgs(
        model_name='bernoulli',
        model_exe=model_exe,
        chain_ids=[1, 2, 3, 4],
        data=data_file,
        method_args=SamplerArgs(),
    )
    runset = RunSet(args=cmdstan_args, chains=4)

    summary = repr(runset)
    self.assertIn('RunSet: chains=4', summary)
    self.assertIn('method=sample', summary)

    retcodes = runset._retcodes
    self.assertEqual(4, len(retcodes))
    # Every chain should report -1 before any return code is set.
    for chain_idx, _ in enumerate(retcodes):
        self.assertEqual(-1, runset._retcode(chain_idx))
    runset._set_retcode(0, 0)
with self.assertRaises(ValueError):
args.validate(chains=2)
args = SamplerArgs(adapt_metric_window=0.88)
with self.assertRaises(ValueError):
args.validate(chains=2)
args = SamplerArgs(adapt_step_size=0.88)
with self.assertRaises(ValueError):
args.validate(chains=2)
args = SamplerArgs(adapt_init_phase=-1)
with self.assertRaises(ValueError):
args.validate(chains=2)
args = SamplerArgs(adapt_metric_window=-2)
with self.assertRaises(ValueError):
args.validate(chains=2)
args = SamplerArgs(adapt_step_size=-3)
with self.assertRaises(ValueError):
args.validate(chains=2)
args = SamplerArgs(adapt_delta=0.88, fixed_param=True)
with self.assertRaises(ValueError):
args.validate(chains=2)
def test_bad(self):
args = SamplerArgs(iter_warmup=-10)
with self.assertRaises(ValueError):
args.validate(chains=2)
args = SamplerArgs(iter_warmup=10, adapt_engaged=False)
with self.assertRaises(ValueError):
args.validate(chains=2)
args = SamplerArgs(iter_sampling=-10)
with self.assertRaises(ValueError):
args.validate(chains=2)
args = SamplerArgs(thin=-10)
with self.assertRaises(ValueError):
args.validate(chains=2)
args = SamplerArgs(max_treedepth=-10)
with self.assertRaises(ValueError):
args.validate(chains=2)
args = SamplerArgs(step_size=-10)
with self.assertRaises(ValueError):
args.validate(chains=2)
args = SamplerArgs(step_size=[1.0, 1.1])