Secure your code as it's written. Use Snyk Code to scan source code in minutes—no build needed—and fix issues immediately.
@cached(name="prior_distribution_memo")
def prior_distribution(self):
    """Zero-mean prior with a diagonal covariance of ones.

    The mean is ``torch.zeros_like(self.variational_distribution.mean)``, so
    it shares that tensor's shape, dtype and device. The covariance is a
    ``DiagLazyTensor`` built from a matching tensor of ones. The ``cached``
    decorator (name ``"prior_distribution_memo"``) presumably memoizes the
    result.
    """
    prior_mean = torch.zeros_like(self.variational_distribution.mean)
    unit_diagonal = torch.ones_like(prior_mean)
    return MultivariateNormal(prior_mean, DiagLazyTensor(unit_diagonal))
@cached(name="cholesky_factor")
def _cholesky_factor(self, induc_induc_covar):
    """Cholesky factor of ``induc_induc_covar``, computed in double precision.

    The (possibly lazy) covariance is densified with ``delazify`` and cast to
    float64 before ``psd_safe_cholesky`` is applied. Presumably the factor is
    therefore float64 regardless of the input dtype, and callers cast it back
    if needed.
    """
    dense_covar = delazify(induc_induc_covar).double()
    return psd_safe_cholesky(dense_covar)
@cached(name="covar_cache")
def covar_cache(self):
# Builds the pieces of a cached covariance term (named by the decorator).
# NOTE(review): this snippet is truncated in this view. The body continues
# past `dtype = ...` and no return statement is visible, so the final
# cached value cannot be described from here.
# Get inverse root
train_train_covar = self.train_prior_dist.lazy_covariance_matrix
# The training covariance exposes left interpolation indices/values and a
# base_lazy_tensor. Presumably it is an interpolated lazy tensor; confirm.
train_interp_indices = train_train_covar.left_interp_indices
train_interp_values = train_train_covar.left_interp_values
# Get probe vectors for inverse root
# The number of probe vectors comes from the fast_pred_var setting.
num_probe_vectors = settings.fast_pred_var.num_probe_vectors()
# The last-dimension size of the base lazy tensor is the pool of indices
# to sample from.
num_inducing = train_train_covar.base_lazy_tensor.size(-1)
# Random permutation of 0..num_inducing-1 drawn from torch's global RNG,
# cast with type_as to the tensor type of the interpolation indices.
vector_indices = torch.randperm(num_inducing).type_as(train_interp_indices)
# The first num_probe_vectors permuted indices are probes and the next
# num_probe_vectors are test indices, so the two sets are disjoint.
# NOTE(review): if 2 * num_probe_vectors > num_inducing, slicing silently
# returns fewer indices than requested. Confirm that cannot happen.
probe_vector_indices = vector_indices[:num_probe_vectors]
test_vector_indices = vector_indices[num_probe_vectors : 2 * num_probe_vectors]
# unsqueeze(1) turns each 1-D index list into a (k, 1) column.
probe_interp_indices = probe_vector_indices.unsqueeze(1)
probe_test_interp_indices = test_vector_indices.unsqueeze(1)
# Presumably used in the truncated remainder of the body.
dtype = train_train_covar.dtype
@cached(name="prior_distribution_memo")
def prior_distribution(self):
    """Prior obtained by evaluating the model at the inducing points.

    Calls ``self.model(self.inducing_points)``, then rewraps the output's
    mean and its lazy covariance in a ``MultivariateNormal``. Jitter is
    added to the covariance with ``add_jitter()``, presumably for numerical
    stability.
    """
    model_output = self.model(self.inducing_points)
    prior_mean = model_output.mean
    prior_covar = model_output.lazy_covariance_matrix.add_jitter()
    return MultivariateNormal(prior_mean, prior_covar)
@cached(name="kernel_eval")
def evaluate_kernel(self):
"""
Evaluate the wrapped kernel on ``self.x1`` and ``self.x2``.

NB: This is a meta LazyTensor, in the sense that evaluate can return
a LazyTensor if the kernel being evaluated does so.

The kernel is called with ``diag=False``, the stored
``last_dim_is_batch`` flag and any extra ``self.params``. The call runs
inside ``settings.lazily_evaluate_kernels(False)``. ``kernel.active_dims``
is set to None for the duration of the call and restored afterwards.

NOTE(review): this snippet is truncated after ``if settings.debug.on():``.
The debug-check body and the return statement are not visible here.
"""
x1 = self.x1
x2 = self.x2
with settings.lazily_evaluate_kernels(False):
temp_active_dims = self.kernel.active_dims
# NOTE(review): the reason for clearing active_dims is not visible here.
# Presumably x1/x2 are already restricted to the active dims; confirm.
self.kernel.active_dims = None
res = self.kernel(x1, x2, diag=False, last_dim_is_batch=self.last_dim_is_batch, **self.params)
# NOTE(review): if the kernel call above raises, this restore never runs,
# leaving the kernel with active_dims=None. A try/finally would prevent that.
self.kernel.active_dims = temp_active_dims
# Check the size of the output
if settings.debug.on():
@cached(name="cholesky_factor")
def _cholesky_factor(self, induc_induc_covar):
    """Cholesky factor of ``induc_induc_covar`` after densifying it.

    May be needed when CG (presumably conjugate gradients) is not used.
    No dtype cast is applied: ``psd_safe_cholesky`` receives whatever
    ``delazify`` returns.
    """
    dense_covar = delazify(induc_induc_covar)
    return psd_safe_cholesky(dense_covar)
@cached(name="mean_cache")
def mean_cache(self):
# Computes a cached term for the predictive mean (named by the decorator).
# NOTE(review): this snippet is truncated in this view. No return statement
# is visible after `mean_cache` is assigned below.
train_train_covar = self.train_prior_dist.lazy_covariance_matrix
train_interp_indices = train_train_covar.left_interp_indices
train_interp_values = train_train_covar.left_interp_values
# Apply the likelihood to the training prior. The code names the resulting
# covariance train_train_covar_with_noise.
mvn = self.likelihood(self.train_prior_dist, self.train_inputs)
train_mean, train_train_covar_with_noise = mvn.mean, mvn.lazy_covariance_matrix
# Labels minus the prior mean, with a trailing dimension added by unsqueeze(-1).
mean_diff = (self.train_labels - train_mean).unsqueeze(-1)
# inv_matmul against the noisy covariance (presumably solves K_noisy x = mean_diff).
train_train_covar_inv_labels = train_train_covar_with_noise.inv_matmul(mean_diff)
# New root factor
# left_t_interp presumably maps the solve back onto the base lazy tensor's
# index space (size base_size), which is then multiplied by base_lazy_tensor.
base_size = train_train_covar.base_lazy_tensor.size(-1)
mean_cache = train_train_covar.base_lazy_tensor.matmul(
left_t_interp(train_interp_indices, train_interp_values, train_train_covar_inv_labels, base_size)
)
@cached(name="size")
def _size(self):
    """Overall shape of the combined tensor built from ``self.lazy_tensors``.

    The row count is the product of every component's ``size(-2)``. The
    column count is the product of every component's ``size(-1)``. Leading
    batch dimensions are taken from the first component's ``batch_shape``.
    """
    components = self.lazy_tensors
    num_rows = _prod(component.size(-2) for component in components)
    num_cols = _prod(component.size(-1) for component in components)
    batch_dims = tuple(components[0].batch_shape)
    return torch.Size(batch_dims + (num_rows, num_cols))
@cached
def evaluate(self):
    """Materialize the diagonal lazy tensor as a dense tensor.

    A 0-dim ``_diag`` is returned unchanged. Otherwise the last dimension of
    ``_diag`` holds the diagonal entries, with any leading dimensions treated
    as batch dimensions. It is expanded into dense ``(..., n, n)`` matrices
    that have zeros everywhere off the diagonal.
    """
    if self._diag.dim() == 0:
        return self._diag
    # torch.diag_embed writes exact zeros off the diagonal. Multiplying
    # ``_diag.unsqueeze(-1)`` by an identity matrix instead computes d * 0 for
    # every off-diagonal entry. That is nan when d is inf or nan, which would
    # poison the entire row. Gradients w.r.t. ``_diag`` are unchanged.
    # NOTE(review): diag_embed takes dtype/device from ``_diag`` rather than
    # ``self.dtype``/``self.device``; presumably these are identical here.
    return torch.diag_embed(self._diag)
@cached(name="mean_diff_inv_quad_memo")
def mean_diff_inv_quad(self):
    """Sum of squared prior-mean entries along the last dimension.

    Returns ``(m * m).sum(-1)`` where ``m`` is
    ``self.prior_distribution.mean``.
    """
    m = self.prior_distribution.mean
    squared = m.mul(m)
    return squared.sum(dim=-1)