Secure your code as it's written. Use Snyk Code to scan source code in minutes - no build needed - and fix issues immediately.
# Report which dataset is being benchmarked, the shape of its feature matrix,
# and the censored fraction, computed as mean(1 - E).
# NOTE(review): reading mean(1 - E) as "share of censored rows" assumes E is a
# 0/1 event indicator; confirm where E is built.
print("== Dataset=%s X.shape=%s Censorship=%.4f" % (
    args.dataset,
    str(X.shape),
    np.mean(1 - E),
))
for itr in range(args.reps):
X_train, X_test, Y_train, Y_test, E_train, E_test = train_test_split(
X, Y, E, test_size=0.2
)
X_train, X_val, Y_train, Y_val, E_train, E_val = train_test_split(
X_train, Y_train, E_train, test_size=0.2
)
ngb = NGBSurvival(
Dist=eval(args.distn),
n_estimators=args.n_est,
learning_rate=args.lr,
natural_gradient=args.natural,
verbose=args.verbose,
minibatch_frac=1.0,
Base=base_name_to_learner[args.base],
Score=eval(args.score),
)
train_losses = ngb.fit(X_train, Y_train, E_train)
forecast = ngb.pred_dist(X_test)
train_forecast = ngb.pred_dist(X_train)
print(
"NGB score: %.4f (val), %.4f (train)"
% (
# Synthetic survival simulation. It draws log event times Y and log censoring
# times T from the same linear signal, fits an Exponential NGBSurvival model on
# the observed time exp(min(Y, T)), and compares predictions with the truth.
args = argparser.parse_args()
# m samples, n features.
m, n = 1000, 5
# Standard-normal features scaled by 1/sqrt(n), so that X @ ones(n) has
# variance n * (1/n) = 1.
X = np.random.randn(m, n) / np.sqrt(n)
# Y: log event time = sum of features + 0.5 * standard-normal noise.
# It is on the log scale because the model is trained on np.exp(...) of it below.
Y = X @ np.ones((n,)) + 0.5 * np.random.randn(*(m,))
# T: log censoring time, from the same linear signal plus independent noise,
# shifted by args.eps. A larger eps makes T > Y more likely, which raises the
# event rate.
T = X @ np.ones((n,)) + 0.5 * np.random.randn(*(m,)) + args.eps
# E = 1 when the event precedes censoring (T > Y), i.e. the event is observed.
E = (T > Y).astype(int)
print(X.shape, Y.shape, E.shape)
# Fraction of rows whose event was observed (E == 1).
print(f"Event rate: {np.mean(E):.2f}")
# 80/20 train/test split. T is split alongside Y so that the observed time
# min(Y, T) can be formed for the training rows.
# NOTE(review): no random_state is passed and no seed is visible in this
# fragment, so the data and split differ between runs; confirm whether a seed
# is set elsewhere.
X_tr, X_te, Y_tr, Y_te, T_tr, T_te, E_tr, E_te = train_test_split(
X, Y, T, E, test_size=0.2
)
# Survival model: Exponential distribution, MLE score, linear base learner.
ngb = NGBSurvival(
Dist=Exponential,
n_estimators=args.n_estimators,
learning_rate=args.lr,
# NOTE(review): the natural gradient is hard-coded on rather than read from
# args; confirm this is intended.
natural_gradient=True,
Base=default_linear_learner,
Score=MLE,
verbose=True,
# NOTE(review): verbose_eval looks like a print interval (in iterations).
# True acts as 1 when treated as an int; confirm against the NGBoost docs.
verbose_eval=True,
)
# Train on the observed time exp(min(Y, T)), with E_tr marking which rows
# had their event observed.
# NOTE(review): despite its name, train_losses holds whatever fit() returns,
# which may be the model itself rather than a loss history; verify.
train_losses = ngb.fit(X_tr, np.exp(np.minimum(Y_tr, T_tr)), E_tr)
# Predicted distributions for the held-out rows.
preds = ngb.pred_dist(X_te)
# R^2 on the log scale: true log event times vs. log of the predicted mean.
print(f"R2: {r2_score(Y_te, np.log(preds.mean()))}")
# Overlay histograms of predicted mean times and true event times exp(Y_te).
# Only values inside range=(0, 10) are binned.
plt.hist(preds.mean(), range=(0, 10), bins=30, alpha=0.5, label="Pred")
plt.hist(np.exp(Y_te), range=(0, 10), bins=30, alpha=0.5, label="True")