How to use the ngboost.distns.LogNormal distribution class in ngboost

To help you get started, we’ve selected a few ngboost examples, based on popular ways it is used in public projects.

Secure your code as it's written. Use Snyk Code to scan source code in minutes — no build needed — and fix issues immediately.

github stanfordmlgroup / ngboost / examples / survival.py — View on GitHub
from ngboost.distns import LogNormal
from sklearn.datasets import load_boston
from sklearn.model_selection import train_test_split
from sklearn.metrics import mean_squared_error
import numpy as np

def administrative_censoring(Y, threshold=30):
    """Right-censor outcomes at a fixed administrative cut-off.

    Args:
        Y: array-like of outcome values.
        threshold: censoring cut-off; outcomes above it are censored.

    Returns:
        A tuple ``(T, E)``:

        - ``T`` is ``min(Y, threshold)`` elementwise, i.e. the event or
          censoring time.
        - ``E`` is a boolean event indicator. It is True where the outcome was
          actually observed (``Y <= threshold``) and False where it was
          censored at ``threshold``.
    """
    Y = np.asarray(Y)
    return np.minimum(Y, threshold), Y <= threshold


if __name__ == "__main__":
    # NGBSurvival is used below but is not imported at the top of this example.
    from ngboost import NGBSurvival

    # return_X_y is keyword-only in recent scikit-learn releases.
    # NOTE: load_boston was removed in scikit-learn 1.2, so this example needs
    # an older release (or a different dataset with a re-chosen threshold).
    X, Y = load_boston(return_X_y=True)
    X_train, X_test, Y_train, Y_test = train_test_split(X, Y, test_size=0.2)

    # Introduce administrative censoring at 30.
    # T_train holds event-or-censoring times. E_train is the event indicator
    # (True = event observed, False = censored), which is what
    # NGBSurvival.fit expects as its third argument.
    T_train, E_train = administrative_censoring(Y_train, 30)

    ngb = NGBSurvival(Dist=LogNormal).fit(X_train, T_train, E_train)
    Y_preds = ngb.predict(X_test)
    Y_dists = ngb.pred_dist(X_test)

    # test Mean Squared Error (scikit-learn's argument order is y_true, y_pred)
    test_MSE = mean_squared_error(Y_test, Y_preds)
    print("Test MSE", test_MSE)

    # test Negative Log Likelihood of the held-out targets (never censored)
    test_NLL = -Y_dists.logpdf(Y_test.flatten()).mean()
    print("Test NLL", test_NLL)
github stanfordmlgroup / ngboost / ngboost / api.py — View on GitHub
def __init__(
        self,
        Dist=LogNormal,
        Score=LogScore,
        Base=default_tree_learner,
        natural_gradient=True,
        n_estimators=500,
        learning_rate=0.01,
        minibatch_frac=1.0,
        col_sample=1.0,
        verbose=True,
        verbose_eval=100,
        tol=1e-4,
        random_state=None,
    ):
        """Construct the estimator and check that ``Dist`` supports regression.

        Args:
            Dist: distribution class. It must be a subclass of
                ``RegressionDistn``. Default ``LogNormal``.
            Score: presumably the scoring-rule class used for training.
                Default ``LogScore``; confirm against the rest of the class.
            Base: presumably the base learner. Default
                ``default_tree_learner``.
            natural_gradient: default True.
            n_estimators: default 500.
            learning_rate: default 0.01.
            minibatch_frac: default 1.0.
            col_sample: default 1.0.
            verbose: default True.
            verbose_eval: default 100.
            tol: default 1e-4.
            random_state: default None.

        The meaning of the boosting hyperparameters (natural_gradient through
        random_state) is inferred from their names only. How they are used is
        not shown here; TODO confirm.

        Raises:
            AssertionError: if ``Dist`` is not a subclass of ``RegressionDistn``.
        """

        # Reject distributions that cannot be used for regression.
        # NOTE(review): this check is an `assert`, so it disappears under
        # `python -O`; an explicit raise would keep the validation in place.
        assert issubclass(
            Dist, RegressionDistn
        ), f"{Dist.__name__} is not useable for regression."
github stanfordmlgroup / ngboost / examples / visualizations / vis_crps_logscale.py — View on GitHub
    def lognorm_crps_fn(p):
        """Return the mean CRPS of ``LogNormal(p)`` over the samples ``rvs``.

        ``p`` is the LogNormal parameter array. ``rvs`` is read from the
        enclosing scope at call time.
        """
        return LogNormal(p, temp_scale=1.0).crps(rvs).mean()

    # Gradient of the log-scale CRPS function with respect to its parameters
    # (`grad` is presumably jax.grad; confirm against the file's imports).
    logscale_crps_grad_fn = grad(logscale_crps_fn)
github stanfordmlgroup / ngboost / examples / visualizations / vis_crps_logscale.py — View on GitHub
import matplotlib as mpl
import itertools
from ngboost.distns import Normal, LogNormal
from tqdm import tqdm
from matplotlib import pyplot as plt


if __name__ == "__main__":

    # Fixed seed so the sampled points are reproducible.
    # NOTE(review): `random` and `np` are presumably jax.random / jax.numpy
    # (PRNGKey-style API); confirm against the file's imports.
    key = random.PRNGKey(seed=123)
    # 500 exponentiated normal draws, i.e. log-normally distributed samples
    # (assumes random.normal draws standard normals; confirm).
    rvs = np.exp(random.normal(key=key, shape=(500,)))

    # Evaluation grids:
    # - [-3, 3] on the log scale;
    # - [1e-8, 5] on the original scale. It starts at 1e-8 rather than 0,
    #   presumably to stay on the log-normal's support x > 0.
    logscale_axis = np.linspace(-3, 3, 1000)
    lognorm_axis = np.linspace(1e-8, 5, 1000)
    # CDF of a Normal on the log-scale grid and of a LogNormal on the
    # original-scale grid.
    # NOTE(review): Normal gets params [0, 0] but LogNormal gets [0, 1]. If the
    # second entry is a log scale, the two panels show different spreads.
    # Confirm this is intended.
    logscale_cdf = Normal(np.array([0, 0]), temp_scale = 1.0).cdf(logscale_axis)
    lognorm_cdf = LogNormal(np.array([0, 1]), temp_scale = 1.0).cdf(lognorm_axis)

    plt.figure(figsize = (8, 3))
    # Right panel: the log-scale CDF, with the gap between the CDF and a unit
    # step at 0 shaded:
    # - between 0 and the CDF where log x < 0;
    # - between the CDF and 1 where log x > 0.
    # Given the file name, this presumably illustrates the CRPS area for an
    # observation at log x = 0.
    plt.subplot(1, 2, 2)
    plt.xlabel("$\\log x$")
    plt.ylabel("$F_\\theta(\\log x)$")
    plt.plot(logscale_axis, logscale_cdf, color = "black")
    plt.fill_between(logscale_axis, logscale_cdf,
                     where = logscale_axis < 0, color = "grey")
    plt.fill_between(logscale_axis, 1, logscale_cdf,
                     where = logscale_axis > 0, color = "grey")
    plt.axvline(0, color = "grey")
    plt.title("Log-scale")
    # Left panel: the log-normal CDF on the original scale. The vertical line
    # at x = 1 corresponds to log x = 0 in the right panel.
    plt.subplot(1, 2, 1)
    plt.plot(lognorm_axis, lognorm_cdf, color = "black")
    plt.axvline(1, color = "grey")
    plt.xlabel("$x$")