Secure your code as it's written. Use Snyk Code to scan source code in minutes—no build needed—and fix issues immediately.
@requires("prior")
def prior_to_xarray(self):
"""Convert prior samples to xarray."""
# filter prior_predictive
prior_predictive = self.prior_predictive
columns = self.prior[0].columns
if prior_predictive is None or (
isinstance(prior_predictive, str) and prior_predictive.lower().endswith(".csv")
):
prior_predictive = []
elif isinstance(prior_predictive, str):
prior_predictive = [col for col in columns if prior_predictive == col.split(".")[0]]
else:
prior_predictive = [
col
for col in columns
if any(item == col.split(".")[0] for item in prior_predictive)
@requires("posterior_predictive")
def posterior_predictive_to_xarray(self):
"""Convert posterior_predictive samples to xarray."""
posterior_predictive = self.posterior_predictive
columns = self.posterior.column_names
if isinstance(posterior_predictive, str):
posterior_predictive = [posterior_predictive]
valid_cols = [col for col in columns if col.split(".")[0] in set(posterior_predictive)]
data = _unpack_frame(self.posterior.sample, columns, valid_cols)
return dict_to_dataset(data, library=self.cmdstanpy, coords=self.coords, dims=self.dims)
@requires("sample_stats")
def sample_stats_to_xarray(self):
"""Extract sample_stats from fit."""
dtypes = {"divergent__": bool, "n_leapfrog__": np.int64, "treedepth__": np.int64}
# copy dims and coords
dims = deepcopy(self.dims) if self.dims is not None else {}
coords = deepcopy(self.coords) if self.coords is not None else {}
sampler_params = self.sample_stats
log_likelihood = self.log_likelihood
if isinstance(log_likelihood, str):
log_likelihood_cols = [
col for col in self.posterior[0].columns if log_likelihood == col.split(".")[0]
]
log_likelihood_vals = [item[log_likelihood_cols] for item in self.posterior]
@requires("posterior")
def posterior_to_xarray(self):
"""Extract posterior samples from output csv."""
columns = self.posterior[0].columns
# filter posterior_predictive and log_likelihood
posterior_predictive = self.posterior_predictive
if posterior_predictive is None or (
isinstance(posterior_predictive, str) and posterior_predictive.lower().endswith(".csv")
):
posterior_predictive = []
elif isinstance(posterior_predictive, str):
posterior_predictive = [
col for col in columns if posterior_predictive == col.split(".")[0]
]
else:
posterior_predictive = [
@requires("prior_predictive")
def prior_predictive_to_xarray(self):
"""Convert prior_predictive samples to xarray."""
prior = self.prior
prior_model = self.prior_model
prior_predictive = self.prior_predictive
data = get_draws_stan3(prior, model=prior_model, variables=prior_predictive)
return dict_to_dataset(data, library=self.stan, coords=self.coords, dims=self.dims)
@requires("posterior_predictive")
def posterior_predictive_to_xarray(self):
"""Convert posterior_predictive samples to xarray."""
posterior_predictive = self.posterior_predictive
columns = self.posterior[0].columns
if (
isinstance(posterior_predictive, (tuple, list))
and posterior_predictive[0].endswith(".csv")
) or (isinstance(posterior_predictive, str) and posterior_predictive.endswith(".csv")):
if isinstance(posterior_predictive, str):
posterior_predictive = [posterior_predictive]
chain_data = []
for path in posterior_predictive:
parsed_output = _read_output(path)
for sample, *_ in parsed_output:
chain_data.append(sample)
data = _unpack_dataframes(chain_data)
@requires("trace")
@requires("model")
def constant_data_to_xarray(self):
"""Convert constant data to xarray."""
model_vars = self.pymc3.util.get_default_varnames( # pylint: disable=no-member
self.trace.varnames, include_transformed=True
)
if self.observations is not None:
model_vars.extend(
[obs.name for obs in self.observations.values() if hasattr(obs, "name")]
)
model_vars.extend(self.observations.keys())
constant_data_vars = {
name: var for name, var in self.model.named_vars.items() if name not in model_vars
}
if not constant_data_vars:
return None
@requires("posterior")
def posterior_to_xarray(self):
"""Convert posterior samples to xarray."""
data = self.posterior
if not isinstance(data, dict):
raise TypeError("DictConverter.posterior is not a dictionary")
if "log_likelihood" in data:
warnings.warn(
"log_likelihood found in posterior."
" For stats functions log_likelihood needs to be in sample_stats.",
SyntaxWarning,
)
return dict_to_dataset(data, library=None, coords=self.coords, dims=self.dims)
@requires("trace")
def sample_stats_to_xarray(self):
"""Extract sample_stats from PyMC3 trace."""
rename_key = {"model_logp": "lp"}
data = {}
for stat in self.trace.stat_names:
name = rename_key.get(stat, stat)
data[name] = np.array(self.trace.get_sampler_stats(stat, combine=False))
log_likelihood, dims = self._extract_log_likelihood()
if log_likelihood is not None:
data["log_likelihood"] = log_likelihood
dims = {"log_likelihood": dims}
else:
dims = None
return dict_to_dataset(data, library=self.pymc3, dims=dims, coords=self.coords)
@requires("prior")
def sample_stats_prior_to_xarray(self):
"""Extract sample_stats_prior from prior."""
prior = self.prior
prior_model = self.prior_model
data = get_sample_stats_stan3(prior, model=prior_model)
return dict_to_dataset(data, library=self.stan, coords=self.coords, dims=self.dims)