How to use the cmdstanpy.cmdstan_args.SamplerArgs class in cmdstanpy

To help you get started, we’ve selected a few cmdstanpy examples, based on popular ways it is used in public projects.

Secure your code as it's written. Use Snyk Code to scan source code in minutes—no build needed—and fix issues immediately.

GitHub: stan-dev / cmdstanpy / test / test_cmdstan_args.py (View on GitHub)
cmd = args.compose(1, cmd=[])
        self.assertIn(
            'method=sample algorithm=hmc adapt engaged=1', ' '.join(cmd)
        )

        args = SamplerArgs(
            adapt_init_phase=26, adapt_metric_window=60, adapt_step_size=34,
        )
        args.validate(chains=4)
        cmd = args.compose(1, cmd=[])
        self.assertIn('method=sample algorithm=hmc adapt', ' '.join(cmd))
        self.assertIn('init_buffer=26', ' '.join(cmd))
        self.assertIn('window=60', ' '.join(cmd))
        self.assertIn('term_buffer=34', ' '.join(cmd))

        args = SamplerArgs()
        args.validate(chains=4)
        cmd = args.compose(1, cmd=[])
        self.assertNotIn('engine=nuts', ' '.join(cmd))
        self.assertNotIn('adapt engaged=0', ' '.join(cmd))
GitHub: stan-dev / cmdstanpy / test / test_sample.py (View on GitHub)
def test_validate_good_run(self):
        # construct fit using existing sampler output
        exe = os.path.join(DATAFILES_PATH, 'bernoulli' + EXTENSION)
        jdata = os.path.join(DATAFILES_PATH, 'bernoulli.data.json')
        sampler_args = SamplerArgs(
            iter_sampling=100, max_treedepth=11, adapt_delta=0.95
        )
        cmdstan_args = CmdStanArgs(
            model_name='bernoulli',
            model_exe=exe,
            chain_ids=[1, 2, 3, 4],
            seed=12345,
            data=jdata,
            output_dir=DATAFILES_PATH,
            method_args=sampler_args,
        )
        runset = RunSet(args=cmdstan_args, chains=4)
        runset._csv_files = [
            os.path.join(DATAFILES_PATH, 'runset-good', 'bern-1.csv'),
            os.path.join(DATAFILES_PATH, 'runset-good', 'bern-2.csv'),
            os.path.join(DATAFILES_PATH, 'runset-good', 'bern-3.csv'),
GitHub: stan-dev / cmdstanpy / test / test_cmdstan_args.py (View on GitHub)
with self.assertRaises(ValueError):
            args.validate(chains=2)

        args = SamplerArgs(step_size=[1.0, 1.1])
        with self.assertRaises(ValueError):
            args.validate(chains=1)

        args = SamplerArgs(step_size=[1.0, -1.1])
        with self.assertRaises(ValueError):
            args.validate(chains=2)

        args = SamplerArgs(adapt_delta=1.1)
        with self.assertRaises(ValueError):
            args.validate(chains=2)

        args = SamplerArgs(adapt_delta=-0.1)
        with self.assertRaises(ValueError):
            args.validate(chains=2)

        args = SamplerArgs(iter_warmup=100, fixed_param=True)
        with self.assertRaises(ValueError):
            args.validate(chains=2)

        args = SamplerArgs(save_warmup=True, fixed_param=True)
        with self.assertRaises(ValueError):
            args.validate(chains=2)

        args = SamplerArgs(max_treedepth=12, fixed_param=True)
        with self.assertRaises(ValueError):
            args.validate(chains=2)

        args = SamplerArgs(metric='dense', fixed_param=True)
GitHub: stan-dev / cmdstanpy / test / test_sample.py (View on GitHub)
def test_validate_big_run(self):
        exe = os.path.join(DATAFILES_PATH, 'bernoulli' + EXTENSION)
        sampler_args = SamplerArgs(iter_warmup=1500, iter_sampling=1000)
        cmdstan_args = CmdStanArgs(
            model_name='bernoulli',
            model_exe=exe,
            chain_ids=[1, 2],
            seed=12345,
            output_dir=DATAFILES_PATH,
            method_args=sampler_args,
        )
        runset = RunSet(args=cmdstan_args, chains=2)
        runset._csv_files = [
            os.path.join(DATAFILES_PATH, 'runset-big', 'output_icar_nyc-1.csv'),
            os.path.join(DATAFILES_PATH, 'runset-big', 'output_icar_nyc-1.csv'),
        ]
        fit = CmdStanMCMC(runset)
        phis = ['phi.{}'.format(str(x + 1)) for x in range(2095)]
        column_names = SAMPLER_STATE + phis
GitHub: stan-dev / cmdstanpy / test / test_sample.py (View on GitHub)
def test_diagnose_divergences(self):
        exe = os.path.join(DATAFILES_PATH, 'bernoulli' + EXTENSION)
        sampler_args = SamplerArgs()
        cmdstan_args = CmdStanArgs(
            model_name='bernoulli',
            model_exe=exe,
            chain_ids=[1],
            output_dir=DATAFILES_PATH,
            method_args=sampler_args,
        )
        runset = RunSet(args=cmdstan_args, chains=1)
        runset._csv_files = [
            os.path.join(
                DATAFILES_PATH, 'diagnose-good', 'corr_gauss_depth8-1.csv'
            )
        ]
        fit = CmdStanMCMC(runset)
        # TODO - use cmdstan test files instead
        expected = '\n'.join(
GitHub: stan-dev / cmdstanpy / test / test_cmdstan_args.py (View on GitHub)
args = SamplerArgs(metric='diag_e')
        args.validate(chains=4)
        cmd = args.compose(1, cmd=[])
        self.assertIn(
            'method=sample algorithm=hmc metric=diag_e', ' '.join(cmd)
        )

        args = SamplerArgs(metric='diag')
        args.validate(chains=4)
        cmd = args.compose(1, cmd=[])
        self.assertIn(
            'method=sample algorithm=hmc metric=diag_e', ' '.join(cmd)
        )

        args = SamplerArgs()
        args.validate(chains=4)
        cmd = args.compose(1, cmd=[])
        self.assertNotIn('metric=', ' '.join(cmd))

        jmetric = os.path.join(DATAFILES_PATH, 'bernoulli.metric.json')
        args = SamplerArgs(metric=jmetric)
        args.validate(chains=4)
        cmd = args.compose(1, cmd=[])
        self.assertIn('metric=diag_e', ' '.join(cmd))
        self.assertIn('metric_file=', ' '.join(cmd))
        self.assertIn('bernoulli.metric.json', ' '.join(cmd))

        jmetric2 = os.path.join(DATAFILES_PATH, 'bernoulli.metric-2.json')
        args = SamplerArgs(metric=[jmetric, jmetric2])
        args.validate(chains=2)
        cmd = args.compose(0, cmd=[])
GitHub: stan-dev / cmdstanpy / test / test_runset.py (View on GitHub)
def test_commands(self):
        """Check that RunSet composes one command per chain with the chain id.

        Builds a four-chain sampling RunSet for the bernoulli example and
        verifies that the first and last composed commands carry
        ``id=1`` and ``id=4`` respectively.
        """
        model_exe = os.path.join(DATAFILES_PATH, 'bernoulli' + EXTENSION)
        data_file = os.path.join(DATAFILES_PATH, 'bernoulli.data.json')
        run_args = CmdStanArgs(
            model_name='bernoulli',
            model_exe=model_exe,
            chain_ids=[1, 2, 3, 4],
            data=data_file,
            method_args=SamplerArgs(),
        )
        runset = RunSet(args=run_args, chains=4)
        # (command index, expected chain-id token) pairs, checked in order.
        for cmd_index, id_token in ((0, 'id=1'), (3, 'id=4')):
            self.assertIn(id_token, runset._cmds[cmd_index])
GitHub: stan-dev / cmdstanpy / test / test_runset.py (View on GitHub)
def test_check_retcodes(self):
        """Check RunSet's repr and its per-chain return-code bookkeeping.

        Builds a four-chain sampling RunSet for the bernoulli example.
        It verifies that the repr mentions the chain count and method,
        and that every chain's retcode reads -1 before any is set.
        It then records a retcode of 0 for the first chain.
        """
        model_exe = os.path.join(DATAFILES_PATH, 'bernoulli' + EXTENSION)
        data_file = os.path.join(DATAFILES_PATH, 'bernoulli.data.json')
        run_args = CmdStanArgs(
            model_name='bernoulli',
            model_exe=model_exe,
            chain_ids=[1, 2, 3, 4],
            data=data_file,
            method_args=SamplerArgs(),
        )
        runset = RunSet(args=run_args, chains=4)
        for fragment in ('RunSet: chains=4', 'method=sample'):
            self.assertIn(fragment, runset.__repr__())

        retcodes = runset._retcodes
        chain_count = len(retcodes)
        self.assertEqual(4, chain_count)
        # Before any retcode has been recorded, each chain reports -1.
        for chain_idx in range(chain_count):
            self.assertEqual(-1, runset._retcode(chain_idx))
        runset._set_retcode(0, 0)
GitHub: stan-dev / cmdstanpy / test / test_cmdstan_args.py (View on GitHub)
with self.assertRaises(ValueError):
            args.validate(chains=2)

        args = SamplerArgs(adapt_metric_window=0.88)
        with self.assertRaises(ValueError):
            args.validate(chains=2)

        args = SamplerArgs(adapt_step_size=0.88)
        with self.assertRaises(ValueError):
            args.validate(chains=2)

        args = SamplerArgs(adapt_init_phase=-1)
        with self.assertRaises(ValueError):
            args.validate(chains=2)

        args = SamplerArgs(adapt_metric_window=-2)
        with self.assertRaises(ValueError):
            args.validate(chains=2)

        args = SamplerArgs(adapt_step_size=-3)
        with self.assertRaises(ValueError):
            args.validate(chains=2)

        args = SamplerArgs(adapt_delta=0.88, fixed_param=True)
        with self.assertRaises(ValueError):
            args.validate(chains=2)
GitHub: stan-dev / cmdstanpy / test / test_cmdstan_args.py (View on GitHub)
def test_bad(self):
        args = SamplerArgs(iter_warmup=-10)
        with self.assertRaises(ValueError):
            args.validate(chains=2)

        args = SamplerArgs(iter_warmup=10, adapt_engaged=False)
        with self.assertRaises(ValueError):
            args.validate(chains=2)

        args = SamplerArgs(iter_sampling=-10)
        with self.assertRaises(ValueError):
            args.validate(chains=2)

        args = SamplerArgs(thin=-10)
        with self.assertRaises(ValueError):
            args.validate(chains=2)

        args = SamplerArgs(max_treedepth=-10)
        with self.assertRaises(ValueError):
            args.validate(chains=2)

        args = SamplerArgs(step_size=-10)
        with self.assertRaises(ValueError):
            args.validate(chains=2)

        args = SamplerArgs(step_size=[1.0, 1.1])