Secure your code as it's written. Use Snyk Code to scan source code in minutes - no build needed - and fix issues immediately.
def posterior_predictive_to_xarray(self):
    """Convert posterior_predictive samples to xarray.

    Every entry of ``self.posterior_predictive`` is converted with
    ``detach().cpu().numpy()`` (presumably torch tensors). It is then arranged
    so that its leading axes are ``(chain, draw)``:

    * Arrays whose first two dimensions already equal
      ``(self.nchains, self.ndraws)`` are kept unchanged.
    * Arrays whose first dimension equals ``self.nchains * self.ndraws`` are
      reshaped to ``(self.nchains, self.ndraws, *remaining_dims)``.
    * Any other array goes through ``utils.expand_dims`` and a warning is
      logged. This includes 0-d arrays and 1-D arrays that match neither
      pattern.

    Returns
    -------
    The result of ``dict_to_dataset`` for the converted arrays, built with
    ``library=self.pyro`` and this converter's ``coords`` and ``dims``.
    """
    nchains, ndraws = self.nchains, self.ndraws
    data = {}
    for k, ary in self.posterior_predictive.items():
        ary = ary.detach().cpu().numpy()
        shape = ary.shape
        # Compare shape prefixes instead of indexing. Indexing ``shape[1]``
        # raises IndexError for a 1-D array whose length equals ``nchains``
        # (and ``shape[0]`` does for a 0-d array), which would skip the
        # warning fallback below.
        if shape[:2] == (nchains, ndraws):
            data[k] = ary
        elif shape[:1] == (nchains * ndraws,):
            # The leading axis holds nchains * ndraws samples. Split it
            # (C order, i.e. chain-major) into separate chain and draw axes.
            data[k] = ary.reshape((nchains, ndraws, *shape[1:]))
        else:
            data[k] = utils.expand_dims(ary)
            _log.warning(
                "posterior predictive shape not compatible with number of chains and draws. "
                "This can mean that some draws or even whole chains are not represented."
            )
    return dict_to_dataset(data, library=self.pyro, coords=self.coords, dims=self.dims)
def handle_chain_location(self, ary):
    """Return ``ary`` with its chain axis placed first.

    If ``self.chain_dim`` is ``None``, the samples are treated as a single
    chain with no axis of its own, and one is added with ``utils.expand_dims``.
    Otherwise axis 0 and axis ``self.chain_dim`` are exchanged.
    """
    chain_axis = self.chain_dim
    if chain_axis is not None:
        # NOTE(review): swapaxes exchanges the two axes rather than moving the
        # chain axis to the front. Unless chain_axis refers to axis 0 or 1, the
        # former axis 0 lands at position chain_axis instead of position 1.
        # Confirm callers only use chain_dim 0 or 1, or switch to a
        # moveaxis-style operation.
        return ary.swapaxes(0, chain_axis)
    return utils.expand_dims(ary)
def posterior_predictive_to_xarray(self):
    """Convert posterior_predictive samples to xarray.

    Every array in ``self.posterior_predictive`` is arranged so that its
    leading axes are ``(chain, draw)``:

    * Arrays whose first two dimensions already equal
      ``(self.nchains, self.ndraws)`` are kept unchanged.
    * Arrays whose first dimension equals ``self.nchains * self.ndraws`` are
      reshaped to ``(self.nchains, self.ndraws, *remaining_dims)``.
    * Any other array goes through ``utils.expand_dims`` and a warning is
      logged. This includes 0-d arrays and 1-D arrays that match neither
      pattern.

    Returns
    -------
    The result of ``dict_to_dataset`` for the converted arrays, built with
    ``library=self.numpyro`` and this converter's ``coords`` and ``dims``.
    """
    nchains, ndraws = self.nchains, self.ndraws
    data = {}
    for k, ary in self.posterior_predictive.items():
        shape = ary.shape
        # Compare shape prefixes instead of indexing. Indexing ``shape[1]``
        # raises IndexError for a 1-D array whose length equals ``nchains``
        # (and ``shape[0]`` does for a 0-d array), which would skip the
        # warning fallback below.
        if shape[:2] == (nchains, ndraws):
            data[k] = ary
        elif shape[:1] == (nchains * ndraws,):
            # The leading axis holds nchains * ndraws samples. Split it
            # (C order, i.e. chain-major) into separate chain and draw axes.
            data[k] = ary.reshape((nchains, ndraws, *shape[1:]))
        else:
            data[k] = utils.expand_dims(ary)
            _log.warning(
                "posterior predictive shape not compatible with number of chains and draws. "
                "This can mean that some draws or even whole chains are not represented."
            )
    return dict_to_dataset(data, library=self.numpyro, coords=self.coords, dims=self.dims)
# NOTE(review): This snippet is headless. The enclosing `def` line and the
# condition that guards the early return below are missing, and the body has
# lost its indentation, so it is not valid Python as shown. TODO: restore it
# from the original source.
# Early exit: report both prior groups as absent.
return {"prior": None, "prior_predictive": None}
# Split the prior sample names into two groups. When posterior samples exist,
# the names drawn by the posterior are the "prior" variables. Every other
# prior name is treated as a "prior_predictive" variable.
if self.posterior is not None:
prior_vars = list(self.posterior.get_samples().keys())
# NOTE(review): prior_vars is a list, so each `in` test below is a linear scan.
prior_predictive_vars = [key for key in self.prior.keys() if key not in prior_vars]
else:
# With no posterior to compare against, every prior sample is reported
# under "prior" and no prior_predictive group is built.
prior_vars = self.prior.keys()
prior_predictive_vars = None
# Map each group name to a dataset, or to None when the group has no
# variable list.
priors_dict = {}
for group, var_names in zip(
("prior", "prior_predictive"), (prior_vars, prior_predictive_vars)
):
priors_dict[group] = (
None
if var_names is None
# Each sample is converted to NumPy via detach().cpu().numpy()
# (presumably a torch tensor) and then passed through utils.expand_dims,
# presumably to add a leading single-chain axis. TODO: confirm.
# NOTE(review): self.prior[k] is indexed for every posterior variable
# name, so a posterior variable missing from self.prior would fail
# here (KeyError if self.prior is a dict).
else dict_to_dataset(
{k: utils.expand_dims(self.prior[k].detach().cpu().numpy()) for k in var_names},
library=self.pyro,
coords=self.coords,
dims=self.dims,
)
)
return priors_dict
# NOTE(review): This snippet is headless. The enclosing `def` line and the
# condition that guards the early return below are missing, and the body has
# lost its indentation, so it is not valid Python as shown. TODO: restore it
# from the original source.
# Early exit: report both prior groups as absent.
return {"prior": None, "prior_predictive": None}
# Split the prior sample names into two groups. When a posterior exists, the
# names in self._samples are the "prior" variables. Every other prior name
# is treated as a "prior_predictive" variable.
if self.posterior is not None:
prior_vars = list(self._samples.keys())
# NOTE(review): prior_vars is a list, so each `in` test below is a linear scan.
prior_predictive_vars = [key for key in self.prior.keys() if key not in prior_vars]
else:
# With no posterior to compare against, every prior sample is reported
# under "prior" and no prior_predictive group is built.
prior_vars = self.prior.keys()
prior_predictive_vars = None
# Map each group name to a dataset, or to None when the group has no
# variable list.
priors_dict = {}
for group, var_names in zip(
("prior", "prior_predictive"), (prior_vars, prior_predictive_vars)
):
priors_dict[group] = (
None
if var_names is None
# Each sample is passed through utils.expand_dims, presumably to add
# a leading single-chain axis. TODO: confirm.
# NOTE(review): self.prior[k] is indexed for every name in
# self._samples, so a name missing from self.prior would fail here
# (KeyError if self.prior is a dict).
else dict_to_dataset(
{k: utils.expand_dims(self.prior[k]) for k in var_names},
library=self.numpyro,
coords=self.coords,
dims=self.dims,
)
)
return priors_dict