Secure your code as it's written. Use Snyk Code to scan source code in minutes — no build needed — and fix issues immediately.
from ngboost.distns import LogNormal
from sklearn.datasets import load_boston
from sklearn.model_selection import train_test_split
from sklearn.metrics import mean_squared_error
import numpy as np
def administrative_censoring(y, threshold=30):
    """Simulate administrative censoring of targets at a fixed cutoff.

    Any target above ``threshold`` is only known to exceed the cutoff, so its
    recorded time is clipped to ``threshold`` and it is flagged as censored.

    Args:
        y: array-like of event times.
        threshold: censoring cutoff (30 in the original example).

    Returns:
        Tuple ``(T, E)`` of numpy arrays. ``T = min(y, threshold)`` is the
        recorded time. ``E`` is a boolean event indicator: True where the
        event was observed (``y <= threshold``), False where it was censored.
    """
    y = np.asarray(y)
    return np.minimum(y, threshold), y <= threshold


if __name__ == "__main__":
    # NGBSurvival is used below but was never imported in this file.
    from ngboost import NGBSurvival

    # ``return_X_y`` is keyword-only since scikit-learn 1.0 (passing it
    # positionally raises TypeError). NOTE: ``load_boston`` was removed in
    # scikit-learn 1.2, so this example needs an older scikit-learn.
    X, Y = load_boston(return_X_y=True)
    X_train, X_test, Y_train, Y_test = train_test_split(X, Y, test_size=0.2)

    # Introduce administrative censoring at 30. NGBSurvival.fit expects
    # E[i] == 1 when T[i] is an observed event time and 0 when it is a
    # censoring time, so targets above the cutoff must get E == 0 (the old
    # ``Y_train > 30`` had this inverted).
    T_train, E_train = administrative_censoring(Y_train, 30)
    ngb = NGBSurvival(Dist=LogNormal).fit(X_train, T_train, E_train)
    Y_preds = ngb.predict(X_test)
    Y_dists = ngb.pred_dist(X_test)

    # Test Mean Squared Error (scikit-learn's signature is (y_true, y_pred)).
    test_MSE = mean_squared_error(Y_test, Y_preds)
    print("Test MSE", test_MSE)

    # Test Negative Log Likelihood
    test_NLL = -Y_dists.logpdf(Y_test.flatten()).mean()
    print("Test NLL", test_NLL)
def __init__(
    self,
    Dist=LogNormal,
    Score=LogScore,
    Base=default_tree_learner,
    natural_gradient=True,
    n_estimators=500,
    learning_rate=0.01,
    minibatch_frac=1.0,
    col_sample=1.0,
    verbose=True,
    verbose_eval=100,
    tol=1e-4,
    random_state=None,
):
    """Constructor (presumably of a gradient-boosting model) that validates ``Dist``.

    In the visible body only ``Dist`` is inspected. It must be a subclass of
    ``RegressionDistn``; otherwise an ``AssertionError`` is raised whose
    message names the offending class.

    NOTE(review): the other parameters (``Score``, ``Base``,
    ``natural_gradient``, ``n_estimators``, ``learning_rate``,
    ``minibatch_frac``, ``col_sample``, ``verbose``, ``verbose_eval``,
    ``tol``, ``random_state``) are accepted but never used here. The body
    looks truncated; presumably they are forwarded to a base-class
    ``__init__``. Confirm against the full source.
    """
    # NOTE(review): ``assert`` is stripped under ``python -O``, so this
    # validation silently disappears in optimized runs; consider an explicit
    # ``raise`` if callers do not depend on ``AssertionError``.
    assert issubclass(
        Dist, RegressionDistn
    ), f"{Dist.__name__} is not useable for regression."
# Mean CRPS of a LogNormal with parameters ``p`` evaluated against the
# samples ``rvs``. ``rvs`` is looked up only when the lambda is called, so it
# just needs to exist by then (it is assigned inside the __main__ block).
# NOTE(review): PEP 8 (E731) prefers a ``def`` over an assigned lambda.
lognorm_crps_fn = lambda p: LogNormal(p, temp_scale = 1.0).crps(rvs).mean()
# NOTE(review): neither ``grad`` nor ``logscale_crps_fn`` is defined in this
# view, so this line raises NameError unless they come from elsewhere. Only
# ``lognorm_crps_fn`` is defined just above. Presumably ``grad`` is
# ``jax.grad`` and ``logscale_crps_fn`` is a log-scale (Normal) counterpart;
# confirm which function was intended.
logscale_crps_grad_fn = grad(logscale_crps_fn)
import matplotlib as mpl
import itertools
from ngboost.distns import Normal, LogNormal
from tqdm import tqdm
from matplotlib import pyplot as plt
if __name__ == "__main__":
    # Draw 500 samples as exp() of normal draws (i.e. log-normal samples).
    # NOTE(review): ``random`` is not imported in this view; the
    # ``PRNGKey(seed=...)`` / ``normal(key=..., shape=...)`` calls match the
    # ``jax.random`` API. Confirm the intended import.
    key = random.PRNGKey(seed=123)
    rvs = np.exp(random.normal(key=key, shape=(500,)))

    # Evaluation grids: a log-scale axis over [-3, 3] and a natural-scale axis
    # over (0, 5]. It presumably starts at 1e-8 to avoid evaluating at x = 0.
    logscale_axis = np.linspace(-3, 3, 1000)
    lognorm_axis = np.linspace(1e-8, 5, 1000)
    # CDF of a Normal with parameters [0, 0] on the log-scale grid, and of a
    # LogNormal with parameters [0, 1] on the natural-scale grid.
    # NOTE(review): the mapping of these parameter vectors to location/scale
    # is defined by ngboost's distribution classes, not here. Confirm it.
    logscale_cdf = Normal(np.array([0, 0]), temp_scale = 1.0).cdf(logscale_axis)
    lognorm_cdf = LogNormal(np.array([0, 1]), temp_scale = 1.0).cdf(lognorm_axis)

    plt.figure(figsize = (8, 3))
    # Right panel: log-scale CDF. The area under F left of 0 and the area
    # between F and 1 right of 0 are shaded, with a vertical line at 0.
    # NOTE(review): presumably this visualizes the CRPS integrand for an
    # observation at log x = 0. Confirm.
    plt.subplot(1, 2, 2)
    plt.xlabel("$\\log x$")
    plt.ylabel("$F_\\theta(\\log x)$")
    plt.plot(logscale_axis, logscale_cdf, color = "black")
    plt.fill_between(logscale_axis, logscale_cdf,
                     where = logscale_axis < 0, color = "grey")
    plt.fill_between(logscale_axis, 1, logscale_cdf,
                     where = logscale_axis > 0, color = "grey")
    plt.axvline(0, color = "grey")
    plt.title("Log-scale")
    # Left panel: natural-scale log-normal CDF with a reference line at x = 1
    # (log 1 = 0, matching the right panel's vertical line).
    # NOTE(review): the excerpt ends after the x-label; this panel likely
    # continues past the end of this view.
    plt.subplot(1, 2, 1)
    plt.plot(lognorm_axis, lognorm_cdf, color = "black")
    plt.axvline(1, color = "grey")
    plt.xlabel("$x$")