How to use the numpyro.handlers.seed function in numpyro

To help you get started, we’ve selected a few numpyro examples, based on popular ways it is used in public projects.

Secure your code as it's written. Use Snyk Code to scan source code in minutes - no build needed - and fix issues immediately.

github pyro-ppl / numpyro / test / test_handlers.py View on Github external
def test_plate(model):
    """Compare the trace of a seeded model run eagerly against the same model run under ``jit``.

    Both runs are seeded with ``random.PRNGKey(1)``. The eager trace must contain a
    site named ``'z'``, and the value of every ``'sample'`` site in it must match the
    jitted trace's value for that site (checked with ``assert_allclose``).
    """
    seeded = handlers.seed(model, random.PRNGKey(1))
    eager_trace = handlers.trace(seeded).get_trace()

    seeded_jitted = jit(handlers.seed(model, random.PRNGKey(1)))
    jitted_trace = handlers.trace(seeded_jitted).get_trace()

    assert 'z' in eager_trace
    for site_name, eager_site in eager_trace.items():
        # Only sample sites are compared; every other site type is ignored.
        if eager_site['type'] != 'sample':
            continue
        assert_allclose(jitted_trace[site_name]['value'], eager_site['value'])
github pyro-ppl / numpyro / test / test_handlers.py View on Github external
    # Call the seeded sampler once per entry of jnp.arange(100), batched with vmap;
    # each entry is passed to handlers.seed as its rng_key argument.
    ys = vmap(lambda rng_key: handlers.seed(lambda: _sample(), rng_key)())(jnp.arange(100))
    # ys must agree with xs (defined outside this excerpt) within an absolute tolerance of 1e-6.
    assert_allclose(xs, ys, atol=1e-6)
github pyro-ppl / numpyro / test / test_infer_util.py View on Github external
def test_model_with_transformed_distribution():
    """Check ``constrain_fn`` and ``potential_energy`` on a model with a transformed prior.

    The model samples ``'x'`` from a ``HalfNormal`` and ``'y'`` from a ``LogNormal``.
    Expected values are computed by hand from the unconstrained ``params`` using
    bijections to each prior's own support. Actual values come from ``constrain_fn``
    and ``potential_energy``, which are given bijections built from ``x_prior.support``
    and ``y_prior.base_dist.support`` instead. Both the constrained samples and the
    potential energy are asserted to match.
    """
    x_prior = dist.HalfNormal(2)
    y_prior = dist.LogNormal(scale=3.)  # transformed distribution

    def model():
        numpyro.sample('x', x_prior)
        numpyro.sample('y', y_prior)

    # Unconstrained parameter values fed to both the manual and the library computation.
    params = {'x': np.array(-5.), 'y': np.array(7.)}
    model = handlers.seed(model, random.PRNGKey(0))

    # Manual reference: map params onto each prior's support, then sum the negative
    # log-probabilities and subtract the log-abs-det Jacobian of each bijection.
    inv_transforms = {'x': biject_to(x_prior.support), 'y': biject_to(y_prior.support)}
    expected_samples = partial(transform_fn, inv_transforms)(params)
    expected_potential_energy = (
        - x_prior.log_prob(expected_samples['x']) -
        y_prior.log_prob(expected_samples['y']) -
        inv_transforms['x'].log_abs_det_jacobian(params['x'], expected_samples['x']) -
        inv_transforms['y'].log_abs_det_jacobian(params['y'], expected_samples['y'])
    )

    # Library path: for 'y' the bijection targets the support of the base distribution.
    base_inv_transforms = {'x': biject_to(x_prior.support), 'y': biject_to(y_prior.base_dist.support)}
    # NOTE: ``model`` is already seeded above; it is seeded again here with the same key.
    actual_samples = constrain_fn(
        handlers.seed(model, random.PRNGKey(0)), base_inv_transforms, (), {}, params)
    actual_potential_energy = potential_energy(model, base_inv_transforms, (), {}, params)

    assert_allclose(expected_samples['x'], actual_samples['x'])
    assert_allclose(expected_samples['y'], actual_samples['y'])
    # Both energies were computed but never compared; without this check the
    # potential-energy half of the test verified nothing.
    assert_allclose(expected_potential_energy, actual_potential_energy)
github pyro-ppl / numpyro / test / test_handlers.py View on Github external
def fn(rng_key_1, rng_key_2, rng_key_3):
    """Sample sites ``'x'`` and ``'y'`` under nested seed handlers and return them stacked.

    ``'x'`` is drawn inside the handlers for ``rng_key_1`` (outer) and ``rng_key_2``
    (inner). ``'y'`` is drawn one level deeper, inside an additional handler for
    ``rng_key_3``. Both sites use ``dist.Normal(0., 1.)``.
    """
    with handlers.seed(rng_seed=rng_key_1), handlers.seed(rng_seed=rng_key_2):
        x_draw = numpyro.sample('x', dist.Normal(0., 1.))
        with handlers.seed(rng_seed=rng_key_3):
            y_draw = numpyro.sample('y', dist.Normal(0., 1.))
    return jnp.stack([x_draw, y_draw])
github pyro-ppl / numpyro / test / test_handlers.py View on Github external
def test_condition():
    """Check that ``handlers.condition`` fixes a sample site's value and marks it observed.

    The base model returns the sum of sites ``'x'`` (``Delta(0.)``) and ``'y'``
    (``Normal(0., 1.)``). After conditioning ``'y'`` on 2, the traced ``'y'`` site must
    hold 2 and be flagged as observed. Wrapping that conditioned model in a second
    ``condition`` with ``'y'`` set to 3 is asserted to make the model return 3.
    """
    def base_model():
        x_value = numpyro.sample('x', dist.Delta(0.))
        y_value = numpyro.sample('y', dist.Normal(0., 1.))
        return x_value + y_value

    seeded = handlers.seed(base_model, random.PRNGKey(1))
    conditioned = handlers.condition(seeded, {'y': 2.})

    y_site = handlers.trace(conditioned).get_trace()['y']
    assert y_site['value'] == 2.
    assert y_site['is_observed']

    reconditioned = handlers.condition(conditioned, {'y': 3.})
    assert reconditioned() == 3.
github pyro-ppl / numpyro / numpyro / infer / elbo.py View on Github external
def single_particle_elbo(rng_key):
    """Return a one-draw ELBO estimate: model log density minus guide log density.

    ``rng_key`` is split into separate keys for the model and the guide. The guide is
    evaluated first, and the model is then replayed against the guide's trace before
    its log density is computed.

    ``model``, ``guide``, ``args``, ``kwargs`` and ``param_map`` are not defined in this
    block; they are read from the surrounding scope.
    """
    model_key, guide_key = random.split(rng_key)
    seeded_model = seed(model, model_key)

    guide_log_density, guide_trace = log_density(
        seed(guide, guide_key), args, kwargs, param_map)

    replayed_model = replay(seeded_model, guide_trace)
    model_log_density, _ = log_density(replayed_model, args, kwargs, param_map)

    # log p(z) - log q(z)
    return model_log_density - guide_log_density
github arviz-devs / arviz / arviz / data / io_numpyro.py View on Github external
# Key each entry of tree_flatten_samples as "Param:<index>", fetching its value with jax.device_get.
samples = {
                    "Param:{}".format(i): jax.device_get(v)
                    for i, v in enumerate(tree_flatten_samples)
                }
            self._samples = samples
            # Chain and draw counts are read from the posterior object.
            self.nchains, self.ndraws = posterior.num_chains, posterior.num_samples
            self.model = self.posterior.sampler.model
            # model arguments and keyword arguments
            self._args = self.posterior._args  # pylint: disable=protected-access
            self._kwargs = self.posterior._kwargs  # pylint: disable=protected-access
        else:
            # The matching ``if`` lies outside this excerpt; this branch sets both counts to zero.
            self.nchains = self.ndraws = 0

        observations = {}
        if self.model is not None:
            # Trace the model once with a fixed key (PRNGKey(0)) and the stored args/kwargs,
            # then keep the values of the sample sites that are flagged as observed.
            seeded_model = numpyro.handlers.seed(self.model, jax.random.PRNGKey(0))
            trace = numpyro.handlers.trace(seeded_model).get_trace(*self._args, **self._kwargs)
            observations = {
                name: site["value"]
                for name, site in trace.items()
                if site["type"] == "sample" and site["is_observed"]
            }
        # Store None rather than an empty dict when no observed sites were collected.
        self.observations = observations if observations else None
github pyro-ppl / numpyro / numpyro / infer / autoguide.py View on Github external
def sample_posterior(self, rng_key, params, sample_shape=()):
    """
    Draw samples from the learned posterior.

    :param jax.random.PRNGKey rng_key: random key used to draw the samples.
    :param dict params: Current parameters of model and autoguide.
        The parameters can be obtained using :meth:`~numpyro.infer.svi.SVI.get_params`
        method from :class:`~numpyro.infer.svi.SVI`.
    :param tuple sample_shape: batch shape of each latent sample, defaults to ().
    :return: a dict containing samples drawn from this guide.
    :rtype: dict
    """
    # Seed the latent sampler first, then substitute ``params`` into it.
    seeded_sampler = handlers.seed(self._sample_latent, rng_key)
    substituted_sampler = handlers.substitute(seeded_sampler, params)
    latent_sample = substituted_sampler(sample_shape=sample_shape)
    return self._unpack_and_constrain(latent_sample, params)
github pyro-ppl / numpyro / examples / hmm_enum.py View on Github external
present_notes = ((sequences == 1).sum(0).sum(0) > 0)
    # remove notes that are never played (we remove 37/88 notes)
    sequences = sequences[..., present_notes]

    if args.truncate:
        lengths = lengths.clip(0, args.truncate)
        sequences = sequences[:, :args.truncate]

    # All of our models have two plates: "data" and "tones".
    max_plate_nesting = 1 if model is model_0 else 2

    # To help debug our tensor shapes, let's print the shape of each site's
    # distribution, value, and log_prob tensor. Note this information is
    # automatically printed on most errors inside SVI.
    if args.print_shapes:
        model_trace = packed_trace(enum(seed(model, 42), -max_plate_nesting - 1)).get_trace(
            sequences, lengths, args=args)
        for name in model_trace:
            if model_trace[name]['is_observed'] or model_trace[name]['infer'].get('enumerate', None) == 'parallel':
                dim_to_name = model_trace[name]['infer']['dim_to_name']
                logging.info(to_funsor(model_trace[name]['fn'].log_prob(model_trace[name]['value']),
                                       output=funsor.reals(), dim_to_name=dim_to_name).inputs)

    logging.info('Starting inference...')
    rng_key = random.PRNGKey(2)
    start = time.time()
    kernel = {'nuts': NUTS, 'hmc': HMC}[args.kernel](enum(model, -max_plate_nesting - 1))
    mcmc = MCMC(kernel, args.num_warmup, args.num_samples, progress_bar=True)
    mcmc.run(rng_key, sequences, lengths, args=args)
    mcmc.print_summary()
    # samples = mcmc.get_samples()  # TODO do something with this
    logging.info('\nMCMC elapsed time:', time.time() - start)