Secure your code as it's written. Use Snyk Code to scan source code in minutes - no build needed - and fix issues immediately.
key, values pairs. Values are formatted to shape = (nchain, ndraws, *shape)
"""
col_groups = defaultdict(list)
columns = dfs[0].columns
for col in columns:
key, *loc = col.split(".")
loc = tuple(int(i) - 1 for i in loc)
col_groups[key].append((col, loc))
chains = len(dfs)
draws = len(dfs[0])
sample = {}
for key, cols_locs in col_groups.items():
ndim = np.array([loc for _, loc in cols_locs]).max(0) + 1
dtype = dfs[0][cols_locs[0][0]].dtype
sample[key] = utils.full((chains, draws, *ndim), 0, dtype=dtype)
for col, loc in cols_locs:
for chain_id, df in enumerate(dfs):
draw = df[col].values
if loc == ():
sample[key][chain_id, :] = draw
else:
axis1_all = range(sample[key].shape[1])
slicer = (chain_id, axis1_all, *loc)
sample[key][slicer] = draw
return sample
def observed_data_to_xarray(self):
    """Convert observed data to xarray.

    Each entry of ``self.observed_data`` is passed through ``utils.one_de``
    and wrapped in an ``xr.DataArray``. Dimension names and coordinates are
    produced by ``generate_dims_coords`` from ``self.dims`` (optional) and
    ``self.coords``.

    Returns
    -------
    xr.Dataset
        One data variable per key of ``self.observed_data``, with attributes
        from ``make_attrs(library=self.cmdstanpy)``.
    """
    dims = {} if self.dims is None else self.dims
    observed_data = {}
    for key, vals in self.observed_data.items():
        vals = utils.one_de(vals)
        val_dims, coords = generate_dims_coords(
            vals.shape, key, dims=dims.get(key), coords=self.coords
        )
        # Attach only the coordinates of this variable's own dimensions.
        # NOTE(review): the mapping returned above may still carry entries
        # from ``self.coords`` for other variables' dims (generate_dims_coords
        # is not visible here); xarray raises ValueError for coords whose
        # dimensions are not a subset of the DataArray's dims.
        coords = {dim: coords[dim] for dim in val_dims if dim in coords}
        observed_data[key] = xr.DataArray(vals, dims=val_dims, coords=coords)
    return xr.Dataset(data_vars=observed_data, attrs=make_attrs(library=self.cmdstanpy))
observed_data = {}
if isinstance(self.observed, self.tf.Tensor):
with self.tf.Session() as sess:
vals = sess.run(self.observed, feed_dict=self.feed_dict)
else:
vals = self.observed
if self.dims is None:
dims = {}
else:
dims = self.dims
name = "obs"
val_dims = dims.get(name)
vals = utils.one_de(vals)
val_dims, coords = generate_dims_coords(vals.shape, name, dims=val_dims, coords=self.coords)
# coords = {key: xr.IndexVariable((key,), data=coords[key]) for key in val_dims}
observed_data[name] = xr.DataArray(vals, dims=val_dims, coords=coords)
return xr.Dataset(data_vars=observed_data, attrs=make_attrs(library=self.tfp))
def observed_data_to_xarray(self):
    """Build an ``xr.Dataset`` from ``self.observations``.

    Every observation is normalized with ``utils.one_de`` and becomes one
    ``xr.DataArray``. Its dimension names come from ``self.dims`` (when not
    None) combined with ``generate_dims_coords``. Only the coordinates of
    those dimensions are attached, each wrapped in an ``xr.IndexVariable``.

    Returns
    -------
    xr.Dataset
        One data variable per observation name, with attributes from
        ``make_attrs(library=self.pyro)``.
    """
    dims = {} if self.dims is None else self.dims
    arrays = {}
    for obs_name, obs_vals in self.observations.items():
        obs_vals = utils.one_de(obs_vals)
        obs_dims, obs_coords = generate_dims_coords(
            obs_vals.shape, obs_name, dims=dims.get(obs_name), coords=self.coords
        )
        # Restrict the coordinates to the dimensions of this observation only.
        dim_coords = {
            dim: xr.IndexVariable((dim,), data=obs_coords[dim]) for dim in obs_dims
        }
        arrays[obs_name] = xr.DataArray(obs_vals, dims=obs_dims, coords=dim_coords)
    return xr.Dataset(data_vars=arrays, attrs=make_attrs(library=self.pyro))