Secure your code as it's written. Use Snyk Code to scan source code in minutes - no build needed - and fix issues immediately.
A MultiTrace instance
"""
files = glob(os.path.join(name, 'chain-*.csv'))
if len(files) == 0:
raise ValueError('No files present in directory {}'.format(name))
straces = []
for f in files:
chain = int(os.path.splitext(f)[0].rsplit('-', 1)[1])
model_vars_in_chain = _parse_chain_vars(f, model)
strace = Text(name, model=model, vars=model_vars_in_chain)
strace.chain = chain
strace.filename = f
straces.append(strace)
return base.MultiTrace(straces)
step,
start,
parallelize,
tune=tune,
model=model,
random_seed=random_seed,
progressbar=progressbar,
)
if progressbar:
sampling = progress_bar(sampling, total=draws, display=progressbar)
latest_traces = None
for it, traces in enumerate(sampling):
latest_traces = traces
return MultiTrace(latest_traces)
'Keyword argument varnames renamed to var_names, and will be removed in pymc3 3.8',
DeprecationWarning
)
def rscore(x, num_samples):
    """Return the Gelman-Rubin potential scale reduction factor for ``x``.

    Parameters
    ----------
    x : array-like
        Sampled values with chains along axis 0 and draws along axis 1.
        Any further axes are carried through to the result.
    num_samples : int
        Number of draws per chain used to weight the variance terms.

    Returns
    -------
    numpy scalar or ndarray
        ``sqrt(Vhat / W)``, with the first two axes of ``x`` reduced away.
    """
    # Between-chain variance: spread of the per-chain means, scaled by
    # the number of draws.
    chain_means = np.mean(x, axis=1)
    B = num_samples * np.var(chain_means, axis=0, ddof=1)

    # Within-chain variance: average of the per-chain sample variances.
    W = np.mean(np.var(x, axis=1, ddof=1), axis=0)

    # Estimate of the marginal posterior variance.
    Vhat = W * (num_samples - 1) / num_samples + B / num_samples

    return np.sqrt(Vhat / W)
if not isinstance(mtrace, MultiTrace):
# Return rscore for passed arrays
return rscore(np.array(mtrace), mtrace.shape[1])
if mtrace.nchains < 2:
raise ValueError(
'Gelman-Rubin diagnostic requires multiple chains '
'of the same length.')
if var_names is None:
var_names = get_default_varnames(mtrace.varnames, include_transformed=include_transformed)
Rhat = {}
for var in var_names:
x = np.array(mtrace.get_values(var, combine=False))
num_samples = x.shape[1]
----------
name : str
Path to HDF5 arrays file
model : Model
If None, the model is taken from the `with` context.
Returns
-------
A MultiTrace instance
"""
straces = []
for chain in HDF5(name, model=model).chains:
trace = HDF5(name, model=model)
trace.chain = chain
straces.append(trace)
return base.MultiTrace(straces)
def _choose_backend(trace, chain, shortcuts=None, **kwds):
    """Resolve the ``trace`` argument into a backend object for one chain.

    Parameters
    ----------
    trace : BaseTrace, MultiTrace, None, or a key of ``shortcuts``
        - ``BaseTrace``: returned unchanged.
        - ``MultiTrace``: its ``_straces[chain]`` entry is returned.
        - ``None``: a new ``NDArray(**kwds)`` is returned.
        - anything else: looked up in ``shortcuts``.
    chain : int
        Chain key used to index ``trace._straces`` for a ``MultiTrace``.
    shortcuts : dict, optional
        Maps a shortcut key to a dict with ``"backend"`` (a callable) and
        ``"name"`` entries. Defaults to ``pm.backends._shortcuts``.
    **kwds
        Forwarded to whichever backend constructor is called.

    Returns
    -------
    The selected or newly constructed backend object.

    Raises
    ------
    ValueError
        If the shortcut lookup raises ``KeyError``.
    """
    if isinstance(trace, BaseTrace):
        return trace
    if isinstance(trace, MultiTrace):
        return trace._straces[chain]
    if trace is None:
        return NDArray(**kwds)

    if shortcuts is None:
        shortcuts = pm.backends._shortcuts

    # Only the lookup is guarded: errors raised while *constructing* the
    # backend must propagate instead of being mistaken for a bad `trace`.
    try:
        entry = shortcuts[trace]
        backend = entry["backend"]
        name = entry["name"]
    except TypeError:
        # The lookup itself failed on type grounds (presumably `trace` is an
        # unhashable collection of variables -- confirm against callers), so
        # hand it to NDArray as `vars`.
        return NDArray(vars=trace, **kwds)
    except KeyError:
        raise ValueError("Argument `trace` is invalid.")

    return backend(name, **kwds)
model=None,
**kwargs
):
skip_first = kwargs.get("skip_first", 0)
sampling = _iter_sample(draws, step, start, trace, chain, tune, model, random_seed)
_pbar_data = None
_pbar_data = {"chain": chain, "divergences": 0}
_desc = "Sampling chain {chain:d}, {divergences:,d} divergences"
sampling = progress_bar(sampling, total=draws, display=progressbar)
sampling.comment = _desc.format(**_pbar_data)
try:
strace = None
for it, (strace, diverging) in enumerate(sampling):
if it >= skip_first:
trace = MultiTrace([strace])
if diverging and _pbar_data is not None:
_pbar_data["divergences"] += 1
sampling.comment = _desc.format(**_pbar_data)
except KeyboardInterrupt:
pass
return strace
directory : str
Path to a pymc3 serialized trace
model : pm.Model (optional)
Model used to create the trace. Can also be inferred from context
Returns
-------
pm.Multitrace that was saved in the directory
"""
straces = []
for subdir in glob.glob(os.path.join(directory, '*')):
if os.path.isdir(subdir):
straces.append(SerializeNDArray(subdir).load(model))
if not straces:
raise TraceDirectoryError("%s is not a PyMC3 saved chain directory." % directory)
return base.MultiTrace(straces)
while step.beta < 1.:
print('Beta: ' + str(step.beta), ' Stage: ' + str(step.stage))
if step.stage == 0:
# Initial stage
print('Sample initial stage: ...')
stage_path = homepath + '/stage_' + str(step.stage)
trace = Text(stage_path, model=model)
initial = _iter_initial(step, chain=chain, trace=trace)
progress = progress_bar(step.n_chains)
try:
for i, strace in enumerate(initial):
if progressbar:
progress.update(i)
except KeyboardInterrupt:
strace.close()
mtrace = MultiTrace([strace])
step.population, step.array_population, step.likelihoods = \
step.select_end_points(mtrace)
step.beta, step.old_beta, step.weights = step.calc_beta()
step.covariance = step.calc_covariance()
step.res_indx = step.resample()
step.stage += 1
del(strace, mtrace, trace)
else:
if progressbar and njobs > 1:
progressbar = False
# Metropolis sampling intermediate stages
stage_path = homepath + '/stage_' + str(step.stage)
step.proposal_dist = MvNPd(step.covariance)
sample_args = {
'draws': n_steps,