Secure your code as it's written. Use Snyk Code to scan source code in minutes - no build needed - and fix issues immediately.
class Splitter:
    """Fixed two-fold splitter that cuts the rows of ``X`` at a single index.

    ``split`` returns two index pairs built from the ranges ``[0, cut)`` and
    ``[cut, n_rows)``; the second pair swaps the roles of the two ranges.
    Each pair is presumably consumed as ``(train_idx, test_idx)`` following
    the scikit-learn splitter convention when an instance is passed as
    ``n_splits=`` -- confirm against the estimator that receives it.
    """

    def __init__(self, first_half=None):
        """Store an optional explicit cut point.

        Parameters
        ----------
        first_half : int or None, default None
            Row index at which to cut. When ``None`` (the original
            behaviour), ``split`` looks up the name ``first_half`` from the
            enclosing scope at call time, so existing ``Splitter()`` call
            sites behave exactly as before.
        """
        self.first_half = first_half

    def split(self, X, T):
        """Return the two index pairs for the fixed cut.

        Parameters
        ----------
        X : array-like
            Only ``X.shape[0]`` (the number of rows) is used.
        T : object
            Not used by this implementation; kept in the signature
            (presumably the caller invokes ``split(X, T)``).

        Returns
        -------
        list of two tuples of numpy.ndarray
            ``[(arange(0, cut), arange(cut, n)), (arange(cut, n), arange(0, cut))]``
            where ``n = X.shape[0]``.
        """
        n_rows = X.shape[0]
        # Fall back to the enclosing-scope name only when no explicit cut was
        # given; the conditional is evaluated first, so the outer name is not
        # required when ``first_half`` was passed to the constructor.
        cut = first_half if self.first_half is None else self.first_half
        # Fresh arrays per fold (as the original built them) so the two pairs
        # never share index objects.
        return [(np.arange(0, cut), np.arange(cut, n_rows)),
                (np.arange(cut, n_rows), np.arange(0, cut))]
# Build `lr` with separate first_stage() instances for model_y and model_t,
# a Splitter instance supplied as `n_splits`, and both first-stage flags off.
lr = LinearDMLCateEstimator(
    model_y=first_stage(),
    model_t=first_stage(),
    n_splits=Splitter(),
    linear_first_stages=False,
    discrete_treatment=False,
)
# Positional arguments: y, then three column slices of X -- the last d_t
# columns, the first d_x columns, and the columns in between. (Presumably
# treatment / features / controls; confirm against the fit signature.)
lr.fit(
    y,
    X[:, -d_t:],
    X[:, :d_x],
    X[:, d_x:-d_t],
    inference=StatsModelsInference(cov_type=cov_type),
)
# For each value in alpha_list: build a string key from the simulation
# parameters plus alpha, then pass `est` and `lr` to `_append_coverage`
# (defined outside this view; presumably it accumulates results into
# coverage_est / coverage_lr under `key`, which L-below reads back as
# coverage_*[key]['coef_cov'] and ['effect_cov']).
# NOTE(review): the loop and `if` bodies below carry no indentation, so this
# fragment does not parse as written -- TODO restore from upstream source.
for alpha in alpha_list:
key = ("n_{}_n_exp_{}_hetero_{}_d_{}_d_x_"
"{}_p_{}_d_t_{}_cov_type_{}_alpha_{}").format(
n, n_exp, hetero_coef, d, d_x, p, d_t, cov_type, alpha)
_append_coverage(key, coverage_est, est, X_test,
alpha, true_coef, true_effect)
_append_coverage(key, coverage_lr, lr, X_test,
alpha, true_coef, true_effect)
# On the final experiment (it == n_exp - 1): increment n_tests and average
# the stored 'coef_cov' / 'effect_cov' entries for each estimator.
if it == n_exp - 1:
n_tests += 1
mean_coef_cov = np.mean(coverage_est[key]['coef_cov'])
mean_eff_cov = np.mean(coverage_est[key]['effect_cov'])
mean_coef_cov_lr = np.mean(coverage_lr[key]['coef_cov'])
mean_eff_cov_lr = np.mean(coverage_lr[key]['effect_cov'])
# NOTE(review): this print appears truncated/spliced. The format string has
# six placeholders, but only `key` and one np.arange(...) array are passed,
# and the mean_* values computed above are never used. The trailing lines
# look like a splitter's split() return value pasted in here. Presumably the
# intended arguments were a timing value and the four mean_* values -- confirm
# against the upstream source before running.
[print("{}. Time: {:.2f}, Mean Coef Cov: ({:.4f}, {:.4f}), "
"Mean Effect Cov: ({:.4f}, {:.4f})".format(key,
np.arange(first_half_sum, X.shape[0])),
(np.arange(first_half_sum, X.shape[0]),
np.arange(0, first_half_sum))]
# Build `est` with separate first_stage() instances for model_y and model_t
# and a SplitterSum instance (defined outside this view) supplied as
# `n_splits`; then fit it on y_sum and column slices of X_final, passing
# n_sum as sample_weight and var_sum as sample_var.
est = LinearDMLCateEstimator(
    model_y=first_stage(), model_t=first_stage(),
    n_splits=SplitterSum(),
    linear_first_stages=False, discrete_treatment=False,
)
est.fit(
    y_sum,
    # last d_t columns, first d_x columns, and the columns in between
    X_final[:, -d_t:], X_final[:, :d_x], X_final[:, d_x:-d_t],
    sample_weight=n_sum, sample_var=var_sum,
    inference=StatsModelsInference(cov_type=cov_type),
)
class Splitter:
    """Fixed two-fold splitter that cuts the rows of ``X`` at a single index.

    ``split`` returns two index pairs built from the ranges ``[0, cut)`` and
    ``[cut, n_rows)``; the second pair swaps the roles of the two ranges.
    Each pair is presumably consumed as ``(train_idx, test_idx)`` following
    the scikit-learn splitter convention when an instance is passed as
    ``n_splits=`` -- confirm against the estimator that receives it.
    """

    def __init__(self, first_half=None):
        """Store an optional explicit cut point.

        Parameters
        ----------
        first_half : int or None, default None
            Row index at which to cut. When ``None`` (the original
            behaviour), ``split`` looks up the name ``first_half`` from the
            enclosing scope at call time, so existing ``Splitter()`` call
            sites behave exactly as before.
        """
        self.first_half = first_half

    def split(self, X, T):
        """Return the two index pairs for the fixed cut.

        Parameters
        ----------
        X : array-like
            Only ``X.shape[0]`` (the number of rows) is used.
        T : object
            Not used by this implementation; kept in the signature
            (presumably the caller invokes ``split(X, T)``).

        Returns
        -------
        list of two tuples of numpy.ndarray
            ``[(arange(0, cut), arange(cut, n)), (arange(cut, n), arange(0, cut))]``
            where ``n = X.shape[0]``.
        """
        n_rows = X.shape[0]
        # Fall back to the enclosing-scope name only when no explicit cut was
        # given; the conditional is evaluated first, so the outer name is not
        # required when ``first_half`` was passed to the constructor.
        cut = first_half if self.first_half is None else self.first_half
        # Fresh arrays per fold (as the original built them) so the two pairs
        # never share index objects.
        return [(np.arange(0, cut), np.arange(cut, n_rows)),
                (np.arange(cut, n_rows), np.arange(0, cut))]
# Build `lr` with separate first_stage() instances for model_y and model_t,
# a Splitter instance supplied as `n_splits`, and both first-stage flags off.
lr = LinearDMLCateEstimator(
    model_y=first_stage(),
    model_t=first_stage(),
    n_splits=Splitter(),
    linear_first_stages=False,
    discrete_treatment=False,
)
# Positional arguments: y, then three column slices of X -- the last d_t
# columns, the first d_x columns, and the columns in between. (Presumably
# treatment / features / controls; confirm against the fit signature.)
lr.fit(
    y,
    X[:, -d_t:],
    X[:, :d_x],
    X[:, d_x:-d_t],
    inference=StatsModelsInference(cov_type=cov_type),
)