# Secure your code as it's written. Use Snyk Code to scan source code in minutes - no build needed - and fix issues immediately.
def _ess_mad(ary, relative=False):
    """Calculate split-ESS for the median absolute deviation.

    Each draw is replaced by an indicator of whether its absolute distance
    from the overall median is at most the median of those distances. The
    indicator array is split with ``_split_chains``, z-scaled with
    ``_z_scale`` and passed to ``_ess``. ``np.nan`` is returned when
    ``_not_valid`` rejects the input (``min_draws=4, min_chains=1``).
    """
    draws = np.asarray(ary)
    if _not_valid(draws, shape_kwargs=dict(min_draws=4, min_chains=1)):
        return np.nan
    deviations = abs(draws - np.median(draws))
    within_mad = deviations <= np.median(deviations)
    z_split = _z_scale(_split_chains(within_mad))
    return _ess(z_split, relative=relative)
def _rhat(ary):
    """Compute the Gelman-Rubin R-hat for a 2d array (no splitting is done here).

    Parameters
    ----------
    ary : array_like
        Two-dimensional array of draws. Rows are treated as chains and columns
        as draws within a chain (per-chain statistics reduce over axis 1).

    Returns
    -------
    float
        ``sqrt((B / W + n - 1) / n)``, where ``n`` is the number of draws per
        chain, ``B`` is ``n`` times the variance of the chain means and ``W``
        is the mean of the per-chain variances. Returns ``np.nan`` when
        ``_not_valid`` rejects the input. There is no guard for ``W == 0``, so
        NumPy division semantics (inf/nan) apply in that case.
    """
    ary = np.asarray(ary, dtype=float)
    if _not_valid(ary, check_shape=False):
        return np.nan
    _, num_samples = ary.shape

    # Per-chain means and unbiased (ddof=1) per-chain variances.
    # NOTE(review): `_numba_var` presumably selects `svar` vs `np.var`
    # depending on whether numba is enabled; confirm against its definition.
    chain_mean = np.mean(ary, axis=1)
    chain_var = _numba_var(svar, np.var, ary, axis=1, ddof=1)
    # Between-chain variance B, scaled by the number of draws per chain.
    between_chain_variance = num_samples * _numba_var(
        svar, np.var, chain_mean, axis=None, ddof=1
    )
    # Within-chain variance W.
    within_chain_variance = np.mean(chain_var)
    # Same as sqrt(((n - 1) / n * W + B / n) / W), rearranged into one
    # division by W.
    return np.sqrt(
        (between_chain_variance / within_chain_variance + num_samples - 1) / num_samples
    )
# NOTE(review): the statements below appear to be the body of a multi-chain
# summary function whose `def` line, leading statements and final `return`
# are not present here (its indentation was also lost). The body reads `ary`
# as its input and uses `ess_mean_value` for mcse_mean, but nothing visible
# assigns it. Presumably an `ess_mean_value = _ess_mean(ary)` line was dropped
# just above. TODO confirm against upstream.
# ess sd: smaller of the split-ESS of the draws and of the squared draws
ess_sd_value = _ess_sd(ary)
# ess bulk: ESS of the split, z-scaled draws (z_split is reused for R-hat below)
z_split = _z_scale(_split_chains(ary))
ess_bulk_value = _ess(z_split)
# ess tail: ESS of the indicators "draw <= 5% quantile" and
# "draw <= 95% quantile" (split, not z-scaled); the tail ESS is the smaller one
quantile05, quantile95 = _quantile(ary, [0.05, 0.95])
iquantile05 = ary <= quantile05
quantile05_ess = _ess(_split_chains(iquantile05))
iquantile95 = ary <= quantile95
quantile95_ess = _ess(_split_chains(iquantile95))
ess_tail_value = min(quantile05_ess, quantile95_ess)
# R-hat is NaN when _not_valid rejects the input with min_chains=2, i.e. a
# stricter chain requirement than the ESS helpers (min_chains=1)
if _not_valid(ary, shape_kwargs=dict(min_draws=4, min_chains=2)):
rhat_value = np.nan
else:
# r_hat: the larger of the bulk R-hat (split, z-scaled draws) and the tail
# R-hat (draws folded around the median, then split and z-scaled)
rhat_bulk = _rhat(z_split)
ary_folded = np.abs(ary - np.median(ary))
rhat_tail = _rhat(_z_scale(_split_chains(ary_folded)))
rhat_value = max(rhat_bulk, rhat_tail)
# mcse_mean: pooled sample sd (ddof=1) divided by sqrt(ESS of the mean)
sd = np.std(ary, ddof=1)
mcse_mean_value = sd / np.sqrt(ess_mean_value)
# mcse_sd: sd scaled by sqrt(e * (1 - 1/ess_sd)^(ess_sd - 1) - 1)
fac_mcse_sd = np.sqrt(np.exp(1) * (1 - 1 / ess_sd_value) ** (ess_sd_value - 1) - 1)
mcse_sd_value = sd * fac_mcse_sd
def _ess_sd(ary, relative=False):
    """Return the effective sample size relevant to the standard deviation.

    After splitting with ``_split_chains``, this is the smaller of the ESS of
    the draws and the ESS of the squared draws. ``np.nan`` is returned when
    ``_not_valid`` rejects the input (``min_draws=4, min_chains=1``).
    """
    draws = np.asarray(ary)
    if _not_valid(draws, shape_kwargs=dict(min_draws=4, min_chains=1)):
        return np.nan
    split = _split_chains(draws)
    ess_draws = _ess(split, relative=relative)
    ess_squared = _ess(split**2, relative=relative)
    return min(ess_draws, ess_squared)
def _ess_z_scale(ary, relative=False):
    """Calculate split-ESS of the z-scaled draws.

    The draws are split with ``_split_chains``, transformed with ``_z_scale``
    and passed to ``_ess``. ``np.nan`` is returned when ``_not_valid`` rejects
    the input (``min_draws=4, min_chains=1``).
    """
    draws = np.asarray(ary)
    if _not_valid(draws, shape_kwargs=dict(min_draws=4, min_chains=1)):
        return np.nan
    z_split = _z_scale(_split_chains(draws))
    return _ess(z_split, relative=relative)
def _ess_tail(ary, prob=None, relative=False):
    """Compute the effective sample size for the tails.

    Parameters
    ----------
    ary : array_like
        Draws to diagnose (validated with ``_not_valid`` using
        ``min_draws=4, min_chains=1``).
    prob : float or pair of floats, optional
        Tail probabilities. ``None`` means ``(0.05, 0.95)``. A scalar ``p``
        (Python or NumPy, including a 0-d array) means ``(p, 1 - p)``. Any
        other value must be a length-2 sequence or array
        ``(prob_low, prob_high)``.
    relative : bool, optional
        Forwarded to ``_ess_quantile``.

    Returns
    -------
    float
        ``min(qess(prob_low), qess(prob_high))``, where ``qess`` is
        ``_ess_quantile``. Returns ``np.nan`` when ``_not_valid`` rejects
        ``ary``.

    Raises
    ------
    ValueError
        If ``prob`` is neither a scalar nor a pair of probabilities.
    """
    if prob is None:
        prob = (0.05, 0.95)
    elif np.ndim(prob) == 0:
        # Test dimensionality rather than isinstance(prob, Sequence): a NumPy
        # array is not a collections.abc.Sequence, so a length-2 array would
        # otherwise be wrapped as (array, 1 - array).
        prob = (prob, 1 - prob)

    ary = np.asarray(ary)
    if _not_valid(ary, shape_kwargs=dict(min_draws=4, min_chains=1)):
        return np.nan

    if len(prob) != 2:
        raise ValueError("prob must be a scalar or a pair of (low, high) probabilities")
    prob_low, prob_high = prob
    quantile_low_ess = _ess_quantile(ary, prob_low, relative=relative)
    quantile_high_ess = _ess_quantile(ary, prob_high, relative=relative)
    return min(quantile_low_ess, quantile_high_ess)
def _ess_local(ary, prob, relative=False):
    """Compute the ESS of the indicator that a draw lies within two quantiles.

    ``prob`` must hold exactly two probabilities. Their quantiles (from
    ``_quantile``) bound a closed interval, and the ESS is taken of the split
    indicator "lower <= draw <= upper". The input is validated first
    (``min_draws=4, min_chains=1``), so an invalid ``ary`` yields ``np.nan``
    before ``prob`` is checked.

    Raises
    ------
    TypeError
        If ``prob`` is None.
    ValueError
        If ``prob`` does not have length 2.
    """
    draws = np.asarray(ary)
    if _not_valid(draws, shape_kwargs=dict(min_draws=4, min_chains=1)):
        return np.nan

    if prob is None:
        raise TypeError("Prob not defined.")
    if len(prob) != 2:
        raise ValueError("Prob argument in ess local must be upper and lower bound")

    bounds = _quantile(draws, prob)
    lower, upper = bounds[0], bounds[1]
    inside = (lower <= draws) & (draws <= upper)
    return _ess(_split_chains(inside), relative=relative)
def _ess_folded(ary, relative=False):
    """Calculate split-ESS for folded draws.

    The draws are split with ``_split_chains``, transformed with ``_z_fold``
    and passed to ``_ess``. ``np.nan`` is returned when ``_not_valid`` rejects
    the input (``min_draws=4, min_chains=1``).
    """
    draws = np.asarray(ary)
    if _not_valid(draws, shape_kwargs=dict(min_draws=4, min_chains=1)):
        return np.nan
    folded = _z_fold(_split_chains(draws))
    return _ess(folded, relative=relative)
def _mcse_mean(ary):
    """Compute the Monte Carlo standard error of the mean.

    The error is ``sd / sqrt(ess)``. Here ``ess`` comes from ``_ess_mean`` and
    ``sd`` is the sample standard deviation (``ddof=1``) of all draws pooled
    together. ``np.nan`` is returned when ``_not_valid`` rejects the input
    (``min_draws=4, min_chains=1``).
    """
    numba_enabled = Numba.numba_flag
    draws = np.asarray(ary)
    if _not_valid(draws, shape_kwargs=dict(min_draws=4, min_chains=1)):
        return np.nan
    ess = _ess_mean(draws)
    if not numba_enabled:
        sd = np.std(draws, ddof=1)
    else:
        # NOTE(review): `_sqrt(..., np.zeros(1))` looks like it writes into a
        # length-1 output buffer, so `sd` may be a 1-element array rather than
        # a scalar on this path. Confirm against `_sqrt`.
        sd = _sqrt(svar(np.ravel(draws), ddof=1), np.zeros(1))
    return sd / np.sqrt(ess)
def _rhat_z_scale(ary):
    """Compute R-hat on split, z-scaled chains.

    The draws are split with ``_split_chains``, transformed with ``_z_scale``
    and passed to ``_rhat``. ``np.nan`` is returned when ``_not_valid`` rejects
    the input. Validation uses ``min_draws=4, min_chains=2``, a stricter chain
    requirement than the ESS helpers.
    """
    draws = np.asarray(ary)
    if not _not_valid(draws, shape_kwargs=dict(min_draws=4, min_chains=2)):
        z_split = _z_scale(_split_chains(draws))
        return _rhat(z_split)
    return np.nan