How to use the cmdstanpy.cmdstan_args.VariationalArgs class in cmdstanpy

To help you get started, we’ve selected a few cmdstanpy examples, based on popular ways it is used in public projects.

Secure your code as it's written. Use Snyk Code to scan source code in minutes - no build needed - and fix issues immediately.

github stan-dev / cmdstanpy / test / test_cmdstan_args.py View on Github external
with self.assertRaises(ValueError):
            args.validate()

        args = VariationalArgs(iter=0)
        with self.assertRaises(ValueError):
            args.validate()

        args = VariationalArgs(iter=1.1)
        with self.assertRaises(ValueError):
            args.validate()

        args = VariationalArgs(grad_samples=0)
        with self.assertRaises(ValueError):
            args.validate()

        args = VariationalArgs(grad_samples=1.1)
        with self.assertRaises(ValueError):
            args.validate()

        args = VariationalArgs(elbo_samples=0)
        with self.assertRaises(ValueError):
            args.validate()

        args = VariationalArgs(elbo_samples=1.1)
        with self.assertRaises(ValueError):
            args.validate()

        args = VariationalArgs(eta=-0.00003)
        with self.assertRaises(ValueError):
            args.validate()

        args = VariationalArgs(adapt_iter=0)
github stan-dev / cmdstanpy / test / test_variational.py View on Github external
def test_set_variational_attrs(self):
        stan = os.path.join(
            DATAFILES_PATH, 'variational', 'eta_should_be_big.stan'
        )
        model = CmdStanModel(stan_file=stan)
        no_data = {}
        args = VariationalArgs(algorithm='meanfield')
        cmdstan_args = CmdStanArgs(
            model_name=model.name,
            model_exe=model.exe_file,
            chain_ids=None,
            data=no_data,
            method_args=args,
        )
        runset = RunSet(args=cmdstan_args, chains=1)
        variational = CmdStanVB(runset)
        self.assertIn(
            'CmdStanVB: model=eta_should_be_big', variational.__repr__()
        )
        self.assertIn('method=variational', variational.__repr__())

        # check CmdStanVB.__init__ state
        self.assertEqual(variational._column_names, ())
github stan-dev / cmdstanpy / test / test_cmdstan_args.py View on Github external
def test_args_variational(self):
        """Check that VariationalArgs validates and composes its command.

        Each case constructs ``VariationalArgs`` with the given keyword
        arguments, validates it for a single chain, composes the command,
        and asserts that ``method=variational`` plus every expected option
        token appear in the space-joined result.
        """
        # (constructor kwargs, option tokens expected in the composed command)
        cases = [
            # Default-constructed args: previously only checked with a
            # vacuous ``assertTrue(True)``; now they must validate and compose.
            ({}, ()),
            ({'output_samples': 1}, ('output_samples=1',)),
            ({'tol_rel_obj': 1}, ('tol_rel_obj=1',)),
        ]
        for kwargs, expected in cases:
            with self.subTest(**kwargs):
                args = VariationalArgs(**kwargs)
                args.validate(chains=1)
                # Join once per case rather than once per assertion.
                cmd = ' '.join(args.compose(idx=0, cmd=[]))
                self.assertIn('method=variational', cmd)
                for token in expected:
                    self.assertIn(token, cmd)
github stan-dev / cmdstanpy / test / test_cmdstan_args.py View on Github external
with self.assertRaises(ValueError):
            args.validate()

        args = VariationalArgs(eta=-0.00003)
        with self.assertRaises(ValueError):
            args.validate()

        args = VariationalArgs(adapt_iter=0)
        with self.assertRaises(ValueError):
            args.validate()

        args = VariationalArgs(adapt_iter=1.1)
        with self.assertRaises(ValueError):
            args.validate()

        args = VariationalArgs(tol_rel_obj=0)
        with self.assertRaises(ValueError):
            args.validate()

        args = VariationalArgs(eval_elbo=0)
        with self.assertRaises(ValueError):
            args.validate()

        args = VariationalArgs(eval_elbo=1.5)
        with self.assertRaises(ValueError):
            args.validate()

        args = VariationalArgs(output_samples=0)
        with self.assertRaises(ValueError):
            args.validate()
github stan-dev / cmdstanpy / test / test_cmdstan_args.py View on Github external
with self.assertRaises(ValueError):
            args.validate()

        args = VariationalArgs(adapt_iter=1.1)
        with self.assertRaises(ValueError):
            args.validate()

        args = VariationalArgs(tol_rel_obj=0)
        with self.assertRaises(ValueError):
            args.validate()

        args = VariationalArgs(eval_elbo=0)
        with self.assertRaises(ValueError):
            args.validate()

        args = VariationalArgs(eval_elbo=1.5)
        with self.assertRaises(ValueError):
            args.validate()

        args = VariationalArgs(output_samples=0)
        with self.assertRaises(ValueError):
            args.validate()
github stan-dev / cmdstanpy / test / test_cmdstan_args.py View on Github external
with self.assertRaises(ValueError):
            args.validate()

        args = VariationalArgs(grad_samples=0)
        with self.assertRaises(ValueError):
            args.validate()

        args = VariationalArgs(grad_samples=1.1)
        with self.assertRaises(ValueError):
            args.validate()

        args = VariationalArgs(elbo_samples=0)
        with self.assertRaises(ValueError):
            args.validate()

        args = VariationalArgs(elbo_samples=1.1)
        with self.assertRaises(ValueError):
            args.validate()

        args = VariationalArgs(eta=-0.00003)
        with self.assertRaises(ValueError):
            args.validate()

        args = VariationalArgs(adapt_iter=0)
        with self.assertRaises(ValueError):
            args.validate()

        args = VariationalArgs(adapt_iter=1.1)
        with self.assertRaises(ValueError):
            args.validate()

        args = VariationalArgs(tol_rel_obj=0)
github stan-dev / cmdstanpy / test / test_cmdstan_args.py View on Github external
def test_args_bad(self):
        args = VariationalArgs(algorithm='no_such_algo')
        with self.assertRaises(ValueError):
            args.validate()

        args = VariationalArgs(iter=0)
        with self.assertRaises(ValueError):
            args.validate()

        args = VariationalArgs(iter=1.1)
        with self.assertRaises(ValueError):
            args.validate()

        args = VariationalArgs(grad_samples=0)
        with self.assertRaises(ValueError):
            args.validate()

        args = VariationalArgs(grad_samples=1.1)
github stan-dev / cmdstanpy / test / test_cmdstan_args.py View on Github external
with self.assertRaises(ValueError):
            args.validate()

        args = VariationalArgs(elbo_samples=1.1)
        with self.assertRaises(ValueError):
            args.validate()

        args = VariationalArgs(eta=-0.00003)
        with self.assertRaises(ValueError):
            args.validate()

        args = VariationalArgs(adapt_iter=0)
        with self.assertRaises(ValueError):
            args.validate()

        args = VariationalArgs(adapt_iter=1.1)
        with self.assertRaises(ValueError):
            args.validate()

        args = VariationalArgs(tol_rel_obj=0)
        with self.assertRaises(ValueError):
            args.validate()

        args = VariationalArgs(eval_elbo=0)
        with self.assertRaises(ValueError):
            args.validate()

        args = VariationalArgs(eval_elbo=1.5)
        with self.assertRaises(ValueError):
            args.validate()

        args = VariationalArgs(output_samples=0)
github stan-dev / cmdstanpy / test / test_cmdstan_args.py View on Github external
def test_args_variational(self):
        """Check that VariationalArgs validates and composes its command.

        Each case constructs ``VariationalArgs`` with the given keyword
        arguments, validates it for a single chain, composes the command,
        and asserts that ``method=variational`` plus every expected option
        token appear in the space-joined result.
        """
        # (constructor kwargs, option tokens expected in the composed command)
        cases = [
            # Default-constructed args: previously only checked with a
            # vacuous ``assertTrue(True)``; now they must validate and compose.
            ({}, ()),
            ({'output_samples': 1}, ('output_samples=1',)),
            ({'tol_rel_obj': 1}, ('tol_rel_obj=1',)),
        ]
        for kwargs, expected in cases:
            with self.subTest(**kwargs):
                args = VariationalArgs(**kwargs)
                args.validate(chains=1)
                # Join once per case rather than once per assertion.
                cmd = ' '.join(args.compose(idx=0, cmd=[]))
                self.assertIn('method=variational', cmd)
                for token in expected:
                    self.assertIn(token, cmd)
github stan-dev / cmdstanpy / cmdstanpy / cmdstan_args.py View on Github external
self.model_exe = model_exe
        self.chain_ids = chain_ids
        self.data = data
        self.seed = seed
        self.inits = inits
        self.output_dir = output_dir
        self.save_diagnostics = save_diagnostics
        self.refresh = refresh
        self.method_args = method_args
        if isinstance(method_args, SamplerArgs):
            self.method = Method.SAMPLE
        elif isinstance(method_args, OptimizeArgs):
            self.method = Method.OPTIMIZE
        elif isinstance(method_args, GenerateQuantitiesArgs):
            self.method = Method.GENERATE_QUANTITIES
        elif isinstance(method_args, VariationalArgs):
            self.method = Method.VARIATIONAL
        self.method_args.validate(len(chain_ids) if chain_ids else None)
        self._logger = logger or get_logger()
        self.validate()