Secure your code as it's written. Use Snyk Code to scan source code in minutes - no build needed - and fix issues immediately.
with self.assertRaises(ValueError):
args.validate()
args = VariationalArgs(iter=0)
with self.assertRaises(ValueError):
args.validate()
args = VariationalArgs(iter=1.1)
with self.assertRaises(ValueError):
args.validate()
args = VariationalArgs(grad_samples=0)
with self.assertRaises(ValueError):
args.validate()
args = VariationalArgs(grad_samples=1.1)
with self.assertRaises(ValueError):
args.validate()
args = VariationalArgs(elbo_samples=0)
with self.assertRaises(ValueError):
args.validate()
args = VariationalArgs(elbo_samples=1.1)
with self.assertRaises(ValueError):
args.validate()
args = VariationalArgs(eta=-0.00003)
with self.assertRaises(ValueError):
args.validate()
args = VariationalArgs(adapt_iter=0)
def test_set_variational_attrs(self):
    """Check the repr and initial column names of a newly built CmdStanVB.

    The RunSet is constructed from the composed arguments and handed
    straight to CmdStanVB; this test does not invoke any run step on it.
    """
    stan = os.path.join(
        DATAFILES_PATH, 'variational', 'eta_should_be_big.stan'
    )
    model = CmdStanModel(stan_file=stan)
    no_data = {}
    args = VariationalArgs(algorithm='meanfield')
    cmdstan_args = CmdStanArgs(
        model_name=model.name,
        model_exe=model.exe_file,
        chain_ids=None,
        data=no_data,
        method_args=args,
    )
    runset = RunSet(args=cmdstan_args, chains=1)
    variational = CmdStanVB(runset)

    # Go through the repr() builtin instead of calling the dunder method
    # directly, and evaluate it once for both substring checks.
    description = repr(variational)
    self.assertIn('CmdStanVB: model=eta_should_be_big', description)
    self.assertIn('method=variational', description)

    # check CmdStanVB.__init__ state
    self.assertEqual(variational._column_names, ())
def test_args_variational(self):
    """Check that valid VariationalArgs validate and compose.

    Each setting is validated for a single chain, composed into a command
    list, and the space-joined command is searched for the method name and
    the setting itself.
    """
    # Construction with defaults is the only thing exercised here; the
    # assertion that follows is a placeholder that always passes.
    args = VariationalArgs()
    self.assertTrue(True)

    expected_settings = (
        ({'output_samples': 1}, 'output_samples=1'),
        ({'tol_rel_obj': 1}, 'tol_rel_obj=1'),
    )
    for kwargs, expected in expected_settings:
        args = VariationalArgs(**kwargs)
        args.validate(chains=1)
        cmd = args.compose(idx=0, cmd=[])
        command_line = ' '.join(cmd)
        self.assertIn('method=variational', command_line)
        self.assertIn(expected, command_line)
with self.assertRaises(ValueError):
args.validate()
args = VariationalArgs(eta=-0.00003)
with self.assertRaises(ValueError):
args.validate()
args = VariationalArgs(adapt_iter=0)
with self.assertRaises(ValueError):
args.validate()
args = VariationalArgs(adapt_iter=1.1)
with self.assertRaises(ValueError):
args.validate()
args = VariationalArgs(tol_rel_obj=0)
with self.assertRaises(ValueError):
args.validate()
args = VariationalArgs(eval_elbo=0)
with self.assertRaises(ValueError):
args.validate()
args = VariationalArgs(eval_elbo=1.5)
with self.assertRaises(ValueError):
args.validate()
args = VariationalArgs(output_samples=0)
with self.assertRaises(ValueError):
args.validate()
with self.assertRaises(ValueError):
args.validate()
args = VariationalArgs(adapt_iter=1.1)
with self.assertRaises(ValueError):
args.validate()
args = VariationalArgs(tol_rel_obj=0)
with self.assertRaises(ValueError):
args.validate()
args = VariationalArgs(eval_elbo=0)
with self.assertRaises(ValueError):
args.validate()
args = VariationalArgs(eval_elbo=1.5)
with self.assertRaises(ValueError):
args.validate()
args = VariationalArgs(output_samples=0)
with self.assertRaises(ValueError):
args.validate()
with self.assertRaises(ValueError):
args.validate()
args = VariationalArgs(grad_samples=0)
with self.assertRaises(ValueError):
args.validate()
args = VariationalArgs(grad_samples=1.1)
with self.assertRaises(ValueError):
args.validate()
args = VariationalArgs(elbo_samples=0)
with self.assertRaises(ValueError):
args.validate()
args = VariationalArgs(elbo_samples=1.1)
with self.assertRaises(ValueError):
args.validate()
args = VariationalArgs(eta=-0.00003)
with self.assertRaises(ValueError):
args.validate()
args = VariationalArgs(adapt_iter=0)
with self.assertRaises(ValueError):
args.validate()
args = VariationalArgs(adapt_iter=1.1)
with self.assertRaises(ValueError):
args.validate()
args = VariationalArgs(tol_rel_obj=0)
def test_args_bad(self):
    """Check that validate() rejects each invalid VariationalArgs setting.

    Every entry sets a single keyword to a value for which validate() is
    expected to raise ValueError.
    """
    bad_settings = (
        {'algorithm': 'no_such_algo'},
        {'iter': 0},
        {'iter': 1.1},
        {'grad_samples': 0},
        {'grad_samples': 1.1},
        {'elbo_samples': 0},
        {'elbo_samples': 1.1},
        {'eta': -0.00003},
        {'adapt_iter': 0},
        {'adapt_iter': 1.1},
        {'tol_rel_obj': 0},
        {'eval_elbo': 0},
        {'eval_elbo': 1.5},
        # Previously constructed but never validated, so it checked nothing.
        {'output_samples': 0},
    )
    for kwargs in bad_settings:
        # subTest names the offending setting on failure and lets the
        # remaining settings still be checked.
        with self.subTest(**kwargs):
            args = VariationalArgs(**kwargs)
            with self.assertRaises(ValueError):
                args.validate()
def test_args_variational(self):
    """Check that valid VariationalArgs validate and compose.

    Each setting is validated for a single chain, composed into a command
    list, and the space-joined command is searched for the method name and
    the setting itself.
    """
    # Construction with defaults is the only thing exercised here; the
    # assertion that follows is a placeholder that always passes.
    args = VariationalArgs()
    self.assertTrue(True)

    expected_settings = (
        ({'output_samples': 1}, 'output_samples=1'),
        ({'tol_rel_obj': 1}, 'tol_rel_obj=1'),
    )
    for kwargs, expected in expected_settings:
        args = VariationalArgs(**kwargs)
        args.validate(chains=1)
        cmd = args.compose(idx=0, cmd=[])
        command_line = ' '.join(cmd)
        self.assertIn('method=variational', command_line)
        self.assertIn(expected, command_line)
self.model_exe = model_exe
self.chain_ids = chain_ids
self.data = data
self.seed = seed
self.inits = inits
self.output_dir = output_dir
self.save_diagnostics = save_diagnostics
self.refresh = refresh
self.method_args = method_args
if isinstance(method_args, SamplerArgs):
self.method = Method.SAMPLE
elif isinstance(method_args, OptimizeArgs):
self.method = Method.OPTIMIZE
elif isinstance(method_args, GenerateQuantitiesArgs):
self.method = Method.GENERATE_QUANTITIES
elif isinstance(method_args, VariationalArgs):
self.method = Method.VARIATIONAL
self.method_args.validate(len(chain_ids) if chain_ids else None)
self._logger = logger or get_logger()
self.validate()