Secure your code as it's written. Use Snyk Code to scan source code in minutes - no build needed - and fix issues immediately.
def iterator(self):
"""Iterate over models and chains for each variable."""
# Group each model's data. With ``self.combined`` every dataset becomes a
# single pseudo-group keyed 0 and "chain" is passed to ``xarray_var_iter`` as a
# skipped dimension; otherwise each dataset is split with ``groupby("chain")``
# and no dimension is skipped.
if self.combined:
grouped_data = [[(0, datum)] for datum in self.data]
skip_dims = {"chain"}
else:
grouped_data = [datum.groupby("chain") for datum in self.data]
skip_dims = set()
# label -> OrderedDict(model name -> list of ``values``); OrderedDict keeps the
# first-seen order of both labels and model names.
label_dict = OrderedDict()
# NOTE(review): zip() stops at the shorter of ``self.model_names`` and the
# grouped data, so a length mismatch silently drops models.
for name, grouped_datum in zip(self.model_names, grouped_data):
for _, sub_data in grouped_datum:
datum_iter = xarray_var_iter(
sub_data,
var_names=[self.var_name],
skip_dims=skip_dims,
reverse_selections=True,
)
for _, selection, values in datum_iter:
# One label per selection of ``self.var_name``; values from every chain
# group of the same model are appended under that label/model pair.
label = make_label(self.var_name, selection, position="beside")
if label not in label_dict:
label_dict[label] = OrderedDict()
if name not in label_dict[label]:
label_dict[label][name] = []
label_dict[label][name].append(values)
# NOTE(review): the rest of this method appears to be missing from this file.
# The innermost ``for`` below has no body here, and the next line belongs to
# unrelated code. (This method also appears a second time further down.)
y = self.y_start
for label, model_data in label_dict.items():
for model_name, value_list in model_data.items():
# NOTE(review): the ``if`` header for the branch below is missing from this
# file; only its body and the ``else`` survive.
# Row of two bokeh figures. The second reuses the x range of the second figure
# in the first row (``axes[0][1]``), so their x axes are linked.
_axes = [
bkp.figure(**backend_kwargs),
bkp.figure(x_range=axes[0][1].x_range, **backend_kwargs),
]
else:
# Row of two figures with independent ranges.
_axes = [bkp.figure(**backend_kwargs), bkp.figure(**backend_kwargs)]
axes.append(_axes)
# Convert the list of rows into a NumPy array (presumably 2-D, rows x 2 --
# confirm that ``axes`` holds only these two-figure rows).
axes = np.array(axes)
# Containers for (presumably bokeh ColumnDataSource) payloads and variable
# groups; they start empty and the code that fills them is not visible here.
cds_data = {}
cds_var_groups = {}
draw_name = "draw"
# NOTE(review): wrapping the iterator in list() is unnecessary when it is only
# iterated once.
for var_name, selection, value in list(
xarray_var_iter(data, var_names=var_names, combined=True)
):
if selection:
# Column name unique to this selection: the variable name, the marker
# "_ARVIZ_CDS_SELECTION_", then every selection key and value joined by
# "_". Non-string iterable values are expanded element by element.
cds_name = "{}_ARVIZ_CDS_SELECTION_{}".format(
var_name,
"_".join(
str(item)
# ``value`` here is local to the generator expression; it does not
# overwrite the outer loop variable ``value``.
for key, value in selection.items()
for item in (
[key, value]
if (isinstance(value, str) or not isinstance(value, Iterable))
else [key, *value]
)
),
)
else:
# Empty selection: the variable name is used unchanged.
cds_name = var_name
# Replace the coordinate values selected for ``key`` with their positional
# indices in ``observed[key]``, suitable for ``.isel``. ``np.isin(...).ravel()``
# is the exact equivalent of ``np.in1d``, which is deprecated as of NumPy 2.0
# (``in1d`` always flattened its first argument, hence the ``ravel``).
coords[key] = np.where(np.isin(observed[key], coords[key]).ravel())[0]
# Plotters for the observed data restricted to ``coords``. Dimensions in
# ``flatten`` are passed as skipped dims (presumably kept inside each value
# rather than split into separate selections -- confirm).
obs_plotters = filter_plotters_list(
list(
xarray_var_iter(
observed.isel(coords), skip_dims=set(flatten), var_names=var_names, combined=True
)
),
"plot_ppc",
)
length_plotters = len(obs_plotters)
# Take at most ``length_plotters`` posterior-predictive plotters, one per
# observed plotter: zip() with range() stops after that many items.
pp_plotters = [
tup
for _, tup in zip(
range(length_plotters),
xarray_var_iter(
posterior_predictive.isel(coords),
var_names=pp_var_names,
skip_dims=set(flatten_pp),
combined=True,
),
)
]
rows, cols = default_grid(length_plotters)
(figsize, ax_labelsize, _, xt_labelsize, linewidth, markersize) = _scale_fig_size(
figsize, textsize, rows, cols
)
# NOTE(review): this dict(...) call is cut off here; its remaining arguments
# and closing parenthesis are missing from this file.
ppcplot_kwargs = dict(
ax=ax,
length_plotters=length_plotters,
Change the size of the credible interval

.. plot::
    :context: close-figs

    >>> az.plot_posterior(data, var_names=['mu'], credible_interval=.75)
"""
# Convert the input with ``convert_to_dataset`` (selecting ``group``) and
# resolve which variable names to plot.
data = convert_to_dataset(data, group=group)
var_names = _var_names(var_names, data)
if coords is None:
coords = {}
# One plotter per variable/selection after applying ``coords``, passed through
# ``filter_plotters_list`` under the "plot_posterior" name.
plotters = filter_plotters_list(
list(xarray_var_iter(get_coords(data, coords), var_names=var_names, combined=True)),
"plot_posterior",
)
length_plotters = len(plotters)
rows, cols = default_grid(length_plotters)
(figsize, ax_labelsize, titlesize, xt_labelsize, _linewidth, _) = _scale_fig_size(
figsize, textsize, rows, cols
)
# Use the scaled line width only if the caller did not pass "linewidth".
kwargs.setdefault("linewidth", _linewidth)
# NOTE(review): this dict(...) call is cut off here; its remaining arguments
# and closing parenthesis are missing from this file.
posteriorplot_kwargs = dict(
ax=ax,
length_plotters=length_plotters,
rows=rows,
cols=cols,
figsize=figsize,
def iterator(self):
"""Iterate over models and chains for each variable."""
# Group each model's data. With ``self.combined`` every dataset becomes a
# single pseudo-group keyed 0 and "chain" is passed to ``xarray_var_iter`` as a
# skipped dimension; otherwise each dataset is split with ``groupby("chain")``
# and no dimension is skipped.
if self.combined:
grouped_data = [[(0, datum)] for datum in self.data]
skip_dims = {"chain"}
else:
grouped_data = [datum.groupby("chain") for datum in self.data]
skip_dims = set()
# label -> OrderedDict(model name -> list of ``values``); OrderedDict keeps the
# first-seen order of both labels and model names.
label_dict = OrderedDict()
# NOTE(review): zip() stops at the shorter of ``self.model_names`` and the
# grouped data, so a length mismatch silently drops models.
for name, grouped_datum in zip(self.model_names, grouped_data):
for _, sub_data in grouped_datum:
datum_iter = xarray_var_iter(
sub_data,
var_names=[self.var_name],
skip_dims=skip_dims,
reverse_selections=True,
)
for _, selection, values in datum_iter:
# One label per selection of ``self.var_name``; values from every chain
# group of the same model are appended under that label/model pair.
label = make_label(self.var_name, selection, position="beside")
if label not in label_dict:
label_dict[label] = OrderedDict()
if name not in label_dict[label]:
label_dict[label][name] = []
label_dict[label][name].append(values)
# NOTE(review): the rest of this method appears to be missing from this file.
# The innermost ``for`` below has no body here, and the next line belongs to
# unrelated code. (This method also appears earlier in the file.)
y = self.y_start
for label, model_data in label_dict.items():
for model_name, value_list in model_data.items():
# Parsed as ``(len(data.chain) + 1) if combined else len(data.chain)``: one
# colour per chain, plus one extra when ``combined`` (presumably for the
# combined trace -- confirm against how ``colors`` is indexed later).
num_colors = len(data.chain) + 1 if combined else len(data.chain)
# TODO: matplotlib is always required by arviz. Can we get rid of it?
# First ``num_colors`` entries of matplotlib's colour cycle; cycle() repeats the
# palette if it has fewer than ``num_colors`` colours.
colors = [
prop
for _, prop in zip(
range(num_colors), cycle(plt.rcParams["axes.prop_cycle"].by_key()["color"])
)
]
# In compact mode every dimension except "chain" and "draw" is skipped, so
# coordinates are presumably not split into separate plotters.
if compact:
skip_dims = set(data.dims) - {"chain", "draw"}
else:
skip_dims = set()
plotters = list(xarray_var_iter(data, var_names=var_names, combined=True, skip_dims=skip_dims))
# Cap the number of plotted variables at rcParams["plot.max_subplots"];
# a value of None means "no limit".
max_plots = rcParams["plot.max_subplots"]
max_plots = len(plotters) if max_plots is None else max_plots
if len(plotters) > max_plots:
    # UserWarning is the category for runtime/configuration notices;
    # SyntaxWarning is meant for dubious Python syntax and was misleading here.
    warnings.warn(
        "rcParams['plot.max_subplots'] ({max_plots}) is smaller than the number "
        "of variables to plot ({len_plotters}), generating only {max_plots} "
        "plots".format(max_plots=max_plots, len_plotters=len(plotters)),
        UserWarning,
    )
    plotters = plotters[:max_plots]
# Default size: width 12 and 2 units of height per plotted variable.
if figsize is None:
    figsize = (12, len(plotters) * 2)
if trace_kwargs is None:
    trace_kwargs = {}
"""
valid_kinds = ["scatter", "kde", "hexbin"]
if kind not in valid_kinds:
raise ValueError(
("Plot type {} not recognized." "Plot type must be in {}").format(kind, valid_kinds)
)
data = convert_to_dataset(data, group="posterior")
if coords is None:
coords = {}
var_names = _var_names(var_names, data)
plotters = list(xarray_var_iter(get_coords(data, coords), var_names=var_names, combined=True))
if len(plotters) != 2:
raise Exception(
"Number of variables to be plotted must 2 (you supplied {})".format(len(plotters))
)
figsize, ax_labelsize, _, xt_labelsize, linewidth, _ = _scale_fig_size(figsize, textsize)
if joint_kwargs is None:
joint_kwargs = {}
if marginal_kwargs is None:
marginal_kwargs = {}
marginal_kwargs.setdefault("plot_kwargs", {})
marginal_kwargs["plot_kwargs"]["linewidth"] = linewidth