Secure your code as it's written. Use Snyk Code to scan source code in minutes - no build needed - and fix issues immediately.
# NOTE(review): this is an excerpt from inside a method; the `def` line, the
# enclosing per-group loop, and the bindings of `group`, `model_c`, `X`, `y`,
# `treatment`, `verbose`, `return_components`, `yhat_cs` and `yhat_ts` are
# outside this view. Indentation was lost in the excerpt.
# Fetch the model stored under this group in `self.models_t`
# (presumably the treatment-arm model -- confirm against the fit code).
model_t = self.models_t[group]
# Predict on every row of X with both models and store the results per group.
yhat_cs[group] = model_c.predict(X)
yhat_ts[group] = model_t.predict(X)
# Diagnostics run only when outcomes and treatment labels are both supplied
# and `verbose` is truthy.
if (y is not None) and (treatment is not None) and verbose:
# Keep only rows labelled with this group or with the control name.
mask = (treatment == group) | (treatment == self.control_name)
treatment_filt = treatment[mask]
y_filt = y[mask]
# w is 1 for rows of this group and 0 for the remaining (control) rows.
w = (treatment_filt == group).astype(int)
# Assemble one prediction per kept row: rows with w == 0 take the value
# from yhat_cs, rows with w == 1 take the value from yhat_ts.
yhat = np.zeros_like(y_filt, dtype=float)
yhat[w == 0] = yhat_cs[group][mask][w == 0]
yhat[w == 1] = yhat_ts[group][mask][w == 1]
logger.info('Error metrics for group {}'.format(group))
# `regression_metrics` is defined elsewhere; it receives the observed
# outcomes, the assembled predictions and the 0/1 group indicator.
regression_metrics(y_filt, yhat, w)
# One row per row of X and one column per entry of self.t_groups.
te = np.zeros((X.shape[0], self.t_groups.shape[0]))
for i, group in enumerate(self.t_groups):
# Column i is the difference between the two predictions for that group.
te[:, i] = yhat_ts[group] - yhat_cs[group]
# Return only the difference matrix, or additionally both prediction dicts
# when `return_components` is truthy.
if not return_components:
return te
else:
return te, yhat_cs, yhat_ts
# NOTE(review): excerpt from inside a method; the `def` line, the enclosing
# per-group loop, and the bindings of `X_new`, `model`, `group`, `yhat_cs`,
# `yhat_ts`, `X`, `y`, `treatment`, `verbose` and `return_components` are
# outside this view. Indentation was lost in the excerpt.
# set the treatment column to one (the treatment group)
X_new[:, 0] = 1
# Predict with column 0 forced to 1 and store the result for this group.
yhat_ts[group] = model.predict(X_new)
# Diagnostics run only when outcomes and treatment labels are both supplied
# and `verbose` is truthy.
if (y is not None) and (treatment is not None) and verbose:
# Keep only rows labelled with this group or with the control name.
mask = (treatment == group) | (treatment == self.control_name)
treatment_filt = treatment[mask]
# w is 1 for rows of this group and 0 for the remaining (control) rows.
w = (treatment_filt == group).astype(int)
y_filt = y[mask]
# Assemble one prediction per kept row: rows with w == 0 take the value
# from yhat_cs, rows with w == 1 take the value from yhat_ts.
yhat = np.zeros_like(y_filt, dtype=float)
yhat[w == 0] = yhat_cs[group][mask][w == 0]
yhat[w == 1] = yhat_ts[group][mask][w == 1]
logger.info('Error metrics for group {}'.format(group))
# `regression_metrics` is defined elsewhere; it receives the observed
# outcomes, the assembled predictions and the 0/1 group indicator.
regression_metrics(y_filt, yhat, w)
# One row per row of X and one column per entry of self.t_groups.
te = np.zeros((X.shape[0], self.t_groups.shape[0]))
for i, group in enumerate(self.t_groups):
# Column i is the difference between the two predictions for that group.
te[:, i] = yhat_ts[group] - yhat_cs[group]
# Return only the difference matrix, or additionally both prediction dicts
# when `return_components` is truthy.
if not return_components:
return te
else:
return te, yhat_cs, yhat_ts
# NOTE(review): both branches above already return, so this line looks like
# the tail of a different excerpt rather than part of the same function --
# confirm against the full file.
return te
# NOTE(review): excerpt from inside a method; the `def` line, the enclosing
# loop that binds `i` and `group`, and the bindings of `p`, `dhat_cs`,
# `dhat_ts`, `te`, `X`, `y`, `treatment`, `verbose` and `return_components`
# are outside this view. Indentation was lost in the excerpt.
# Blend the two per-group estimates: dhat_cs is weighted by p[group] and
# dhat_ts by (1 - p[group]); the result is shaped as a single column.
# (p presumably holds per-row propensity scores -- confirm with the caller.)
_te = (p[group] * dhat_cs[group] + (1 - p[group]) * dhat_ts[group]).reshape(-1, 1)
# Flatten back to 1-D and store it as column i of the result matrix.
te[:, i] = np.ravel(_te)
# Diagnostics run only when outcomes and treatment labels are both supplied
# and `verbose` is truthy.
if (y is not None) and (treatment is not None) and verbose:
# Keep only rows labelled with this group or with the control name.
mask = (treatment == group) | (treatment == self.control_name)
treatment_filt = treatment[mask]
X_filt = X[mask]
y_filt = y[mask]
# w is 1 for rows of this group and 0 for the remaining (control) rows.
w = (treatment_filt == group).astype(int)
# Assemble one prediction per kept row: rows with w == 0 are predicted by
# the group's model in self.models_mu_c, rows with w == 1 by the group's
# model in self.models_mu_t.
yhat = np.zeros_like(y_filt, dtype=float)
yhat[w == 0] = self.models_mu_c[group].predict(X_filt[w == 0])
yhat[w == 1] = self.models_mu_t[group].predict(X_filt[w == 1])
logger.info('Error metrics for group {}'.format(group))
# `regression_metrics` is defined elsewhere; it receives the observed
# outcomes, the assembled predictions and the 0/1 group indicator.
regression_metrics(y_filt, yhat, w)
# Return only the blended matrix, or additionally both dhat dicts when
# `return_components` is truthy.
if not return_components:
return te
else:
return te, dhat_cs, dhat_ts