Secure your code as it's written. Use Snyk Code to scan source code in minutes - no build needed - and fix issues immediately.
def __call__(self, shape, name=None, memo=None):
    """Build the model variable described by this object, honouring ``memo``.

    Args:
        shape: Shape specification. It is first passed through
            ``self._prepare`` and ``self._get_shape``; afterwards its
            ``'default'`` entry is the shape handed to the distribution.
        name: Name for the new variable. ``self.auto()`` supplies one
            when it is ``None``.
        memo: Cache looked up under ``id(self) ^ hash(tag)``; it is
            normalised by ``self._prepare`` before use.

    Returns:
        The cached entry when the key is already present in ``memo``.
    """
    memo, shape = self._prepare(memo, shape)
    if name is None:
        name = self.auto()
    shape, tag = self._get_shape(shape)

    # Short-circuit when this (object, tag) pair has been seen before.
    memo_key = id(self) ^ hash(tag)
    if memo_key in memo:
        return memo[memo_key]

    model = pm.modelcontext(None)
    dist_args = self._call_args(self.args, name, shape, memo)
    dist_kwargs = self._call_kwargs(self.kwargs, name, shape, memo)
    default_shape = shape['default']
    dist_kwargs.update(shape=default_shape)

    distribution = self.distcls.dist(
        *dist_args,
        dtype=theano.config.floatX,
        **dist_kwargs
    )
    val = model.Var(name, distribution)

    # Pick the source of the initial test value, then cast it once.
    if self.testval is None:
        initial = get_default_testval()(default_shape)
    elif isinstance(self.testval, str) and self.testval == 'random':
        initial = val.random(size=default_shape)
    else:
        initial = self.testval(default_shape)
    val.tag.test_value = initial.astype(val.dtype)
    # NOTE(review): the visible code neither stores ``val`` in ``memo`` nor
    # returns it -- confirm whether the tail of this method is missing.
regular_variance=1e-3,
**kwargs,
):
"""Get a PyMC3 NUTS step tuned for a given burn-in trace
Args:
trace: The ``MultiTrace`` output from a previous run of
``pymc3.sample``.
regular_window: The weight (in units of number of steps) to use
when regularizing the mass matrix estimate.
regular_variance: The amplitude of the regularization for the mass
matrix. This will be added to the diagonal of the covariance
matrix with weight given by ``regular_window``.
"""
model = pm.modelcontext(model)
# If not given, use the trivial metric
if trace is None or model.ndim == 1:
potential = quad.QuadPotentialDiag(np.ones(model.ndim))
else:
# Loop over samples and convert to the relevant parameter space;
# I'm sure that there's an easier way to do this, but I don't know
# how to make something work in general...
N = len(trace) * trace.nchains
samples = np.empty((N, model.ndim))
i = 0
for chain in trace._straces.values():
for p in chain:
samples[i] = model.bijection.map(p)
i += 1
def get_args_for_theano_function(point=None, model=None):
    """Return the values of ``point`` ordered like ``model.vars``.

    Args:
        point: Mapping indexed by variable name. ``model.test_point`` is
            used when it is ``None``.
        model: Model to read from; resolved through ``pm.modelcontext``.

    Returns:
        list: ``point[v.name]`` for every ``v`` in ``model.vars``.
    """
    resolved_model = pm.modelcontext(model)
    values = resolved_model.test_point if point is None else point
    return [values[variable.name] for variable in resolved_model.vars]
def __init__(self, vars=None, S=None, proposal_dist=None, scaling=1.,
             tune=True, tune_interval=100, model=None, mode=None, **kwargs):
    """Resolve the sampled variables and set up ``self.proposal_dist``.

    Args:
        vars: Variables to sample; ``model.vars`` when ``None``. The
            result is passed through ``pm.inputvars``.
        S: Proposal scale. Defaults to a vector of ones whose length is
            the summed ``dsize`` of the resolved variables.
        proposal_dist: Callable applied to ``S`` to build the proposal.
            When ``None`` the proposal is chosen from ``S.ndim``.
        model: Model to use; resolved through ``pm.modelcontext``.

    Raises:
        ValueError: If no ``proposal_dist`` is given and ``S`` is neither
            one- nor two-dimensional.
    """
    # NOTE(review): scaling, tune, tune_interval, mode and kwargs are not
    # read by the visible code -- confirm whether the method is truncated.
    model = pm.modelcontext(model)
    vars = pm.inputvars(model.vars if vars is None else vars)

    if S is None:
        S = np.ones(sum(v.dsize for v in vars))

    # Choose the proposal factory first, then apply it to S exactly once.
    factory = proposal_dist
    if factory is None:
        rank = S.ndim
        if rank == 1:
            factory = NormalProposal
        elif rank == 2:
            factory = MultivariateNormalProposal
        else:
            raise ValueError("Invalid rank for variance: %s" % rank)
    self.proposal_dist = factory(S)
def __init__(self, vars=None, S=None, proposal_dist=None, lamb=None, scaling=0.001,
             tune=None, tune_interval=100, model=None, mode=None, **kwargs):
    """Resolve the sampled variables, the proposal and the scaling.

    Args:
        vars: Variables to sample; ``model.cont_vars`` when ``None``. The
            result is passed through ``pm.inputvars``.
        S: Proposal scale. Defaults to ``np.ones(model.ndim)``.
        proposal_dist: Callable applied to ``S`` to build the proposal;
            ``UniformProposal`` is used when it is ``None``.
        lamb: Defaults to ``2.38 / sqrt(2 * model.ndim)`` when ``None``.
        scaling: Stored on ``self.scaling`` as a float64 array with at
            least one dimension.
        model: Model to use; resolved through ``pm.modelcontext``.
    """
    model = pm.modelcontext(model)
    vars = pm.inputvars(model.cont_vars if vars is None else vars)

    # NOTE(review): the default S is sized from model.ndim rather than from
    # the selected vars -- confirm this is intended when vars is a subset.
    if S is None:
        S = np.ones(model.ndim)

    make_proposal = UniformProposal if proposal_dist is None else proposal_dist
    self.proposal_dist = make_proposal(S)

    self.scaling = np.atleast_1d(scaling).astype('d')

    # NOTE(review): lamb is computed here but not used in the visible code;
    # presumably the remainder of the method consumes it -- confirm.
    if lamb is None:
        lamb = 2.38 / np.sqrt(2 * model.ndim)
def get_theano_function_for_var(var, model=None, **kwargs):
    """Compile a Theano function that evaluates ``var`` from ``model.vars``.

    Args:
        var: Expression passed to ``theano.function`` as its second
            argument.
        model: Model whose ``vars`` are passed as the first argument;
            resolved through ``pm.modelcontext``.
        **kwargs: Forwarded to ``theano.function``. ``on_unused_input``
            is set to ``"ignore"`` unless the caller supplies it.

    Returns:
        The result of ``theano.function``.
    """
    resolved_model = pm.modelcontext(model)
    # Keep a caller-supplied policy; otherwise ignore unused inputs.
    kwargs.setdefault("on_unused_input", "ignore")
    return theano.function(resolved_model.vars, var, **kwargs)
Parameters
----------
P_min : `~astropy.units.Quantity` [time]
P_max : `~astropy.units.Quantity` [time]
s : `~pm.model.TensorVariable`, `~astropy.units.Quantity` [speed]
model : `pymc3.Model`
This is either required, or this function must be called within a pymc3
model context.
"""
import theano.tensor as tt
import pymc3 as pm
from exoplanet.distributions import Angle
import exoplanet.units as xu
from .distributions import UniformLog, Kipping13Global
model = pm.modelcontext(model)
if pars is None:
pars = dict()
if s is None:
s = 0 * u.m/u.s
if isinstance(s, pm.model.TensorVariable):
pars['s'] = pars.get('s', s)
else:
if not hasattr(s, 'unit') or not s.unit.is_equivalent(u.km/u.s):
raise u.UnitsError("Invalid unit for s: must be equivalent to km/s")
# dictionary of parameters to return
out_pars = dict()