# Secure your code as it's written. Use Snyk Code to scan source code in minutes - no build needed - and fix issues immediately.
def test_multivariate_gradient_descent():
    """Fit with solver='gd' on X_rm_lstat_std / y_std and compare weights.

    The learned ``w_`` must match ``expect_rm_lstat_std`` to 3 decimals.
    A fixed ``random_seed`` keeps the run reproducible.
    """
    regressor = LinearRegression(
        solver='gd',
        eta=0.001,
        epochs=500,
        random_seed=0,
    )
    regressor.fit(X_rm_lstat_std, y_std)
    learned = regressor.w_
    assert_almost_equal(learned, expect_rm_lstat_std, decimal=3)
def test_univariate_stochastic_gradient_descent():
    """Fit with solver='sgd' on X_rm_std / y_std and compare weights.

    The learned ``w_`` must match ``expect_rm_std`` to 2 decimals. This is
    looser than the 3 decimals used by the 'gd' / normal-equation tests,
    presumably because SGD converges less tightly (TODO confirm).
    """
    regressor = LinearRegression(
        solver='sgd',
        eta=0.0001,
        epochs=100,
        random_seed=0,
    )
    regressor.fit(X_rm_std, y_std)
    learned = regressor.w_
    assert_almost_equal(learned, expect_rm_std, decimal=2)
def test_univariate_normal_equation_std():
    """Fit with solver='normal_equation' on X_rm_std / y_std and compare weights.

    The learned ``w_`` must match ``expect_rm_std`` to 3 decimals.
    """
    closed_form = LinearRegression(solver='normal_equation')
    closed_form.fit(X_rm_std, y_std)
    learned = closed_form.w_
    assert_almost_equal(learned, expect_rm_std, decimal=3)
def test_multivariate_normal_equation():
    """Fit with solver='normal_equation' on X_rm_lstat / y and compare weights.

    The learned ``w_`` must match ``expect_rm_lstat`` to 3 decimals.
    """
    closed_form = LinearRegression(solver='normal_equation')
    closed_form.fit(X_rm_lstat, y)
    learned = closed_form.w_
    assert_almost_equal(learned, expect_rm_lstat, decimal=3)
def test_multivariate_stochastic_gradient_descent():
    """Fit with solver='sgd' on X_rm_lstat_std / y_std and compare weights.

    The learned ``w_`` must match ``expect_rm_lstat_std`` to 2 decimals.
    A fixed ``random_seed`` keeps the run reproducible.
    """
    regressor = LinearRegression(
        solver='sgd',
        eta=0.0001,
        epochs=500,
        random_seed=0,
    )
    regressor.fit(X_rm_lstat_std, y_std)
    learned = regressor.w_
    assert_almost_equal(learned, expect_rm_lstat_std, decimal=2)
def test_univariate_gradient_descent():
    """Fit with solver='gd' on X_rm_std / y_std and compare weights.

    The learned ``w_`` must match ``expect_rm_std`` to 3 decimals.
    A fixed ``random_seed`` keeps the run reproducible.
    """
    regressor = LinearRegression(
        solver='gd',
        eta=0.001,
        epochs=500,
        random_seed=0,
    )
    regressor.fit(X_rm_std, y_std)
    learned = regressor.w_
    assert_almost_equal(learned, expect_rm_std, decimal=3)
def test_univariate_normal_equation():
    """Fit with solver='normal_equation' on X_rm / y and compare weights.

    The learned ``w_`` must match ``expect_rm`` to 3 decimals.
    """
    closed_form = LinearRegression(solver='normal_equation')
    closed_form.fit(X_rm, y)
    learned = closed_form.w_
    assert_almost_equal(learned, expect_rm, decimal=3)