Secure your code as it's written. Use Snyk Code to scan source code in minutes - no build needed - and fix issues immediately.
def test_single_condition():
    """Check that a tree fitted on two samples assembles to one ``IfExpr``.

    A ``DecisionTreeRegressor`` is fitted on two one-feature samples, passed
    through ``TreeModelAssembler`` and the resulting AST is compared with the
    hand-written expected expression via ``utils.cmp_exprs``.
    """
    estimator = tree.DecisionTreeRegressor()
    # Two training samples: x=1 -> y=1 and x=2 -> y=2.
    estimator.fit([[1], [2]], [1, 2])

    assembler = assemblers.TreeModelAssembler(estimator)
    actual = assembler.assemble()

    # Expected: if feature 0 <= 1.5 then 1.0 else 2.0.
    expected = ast.IfExpr(
        ast.CompExpr(
            ast.FeatureRef(0),
            ast.NumVal(1.5),
            ast.CompOpType.LTE),
        ast.NumVal(1.0),
        ast.NumVal(2.0))

    assert utils.cmp_exprs(actual, expected)
def test_two_features():
    """Check the AST assembled for a two-coefficient linear model.

    The estimator is not fitted; ``coef_`` and ``intercept_`` are assigned
    directly so the expected expression can be written out by hand.
    """
    estimator = linear_model.LinearRegression()
    estimator.coef_ = [1, 2]
    estimator.intercept_ = 3

    assembler = assemblers.LinearModelAssembler(estimator)
    actual = assembler.assemble()

    # Expected: (3 + feature0 * 1) + feature1 * 2.
    expected = ast.BinNumExpr(
        ast.BinNumExpr(
            ast.NumVal(3),
            ast.BinNumExpr(
                ast.FeatureRef(0),
                ast.NumVal(1),
                ast.BinNumOpType.MUL),
            ast.BinNumOpType.ADD),
        ast.BinNumExpr(
            ast.FeatureRef(1),
            ast.NumVal(2),
            ast.BinNumOpType.MUL),
        ast.BinNumOpType.ADD)

    assert utils.cmp_exprs(actual, expected)
# Excerpt: the `def` line and setup for these statements are not visible
# here; `base_score` is not assigned anywhere in the visible source, and
# `actual` is presumably the assembled model AST -- TODO confirm.
#
# Expected AST, wrapped in a SubroutineExpr:
#   (base_score + (if feature 12 >= 9.72500038 then 1.67318344 else 2.92757893))
#   + (if feature 5 >= 6.94099998 then 3.3400948 else 1.72118247)
expected = ast.SubroutineExpr(
ast.BinNumExpr(
ast.BinNumExpr(
ast.NumVal(base_score),
ast.IfExpr(
ast.CompExpr(
ast.FeatureRef(12),
ast.NumVal(9.72500038),
ast.CompOpType.GTE),
ast.NumVal(1.67318344),
ast.NumVal(2.92757893)),
ast.BinNumOpType.ADD),
ast.IfExpr(
ast.CompExpr(
ast.FeatureRef(5),
ast.NumVal(6.94099998),
ast.CompOpType.GTE),
ast.NumVal(3.3400948),
ast.NumVal(1.72118247)),
ast.BinNumOpType.ADD))
assert utils.cmp_exprs(actual, expected)
def test_if_expr():
    """Check the Python code generated for a single ``IfExpr``.

    The expression compares the constant 1 with feature 0 for equality and
    yields 2 when they are equal, 3 otherwise.
    """
    expr = ast.IfExpr(
        ast.CompExpr(ast.NumVal(1), ast.FeatureRef(0), ast.CompOpType.EQ),
        ast.NumVal(2),
        ast.NumVal(3))

    interpreter = interpreters.PythonInterpreter()

    # The expected snippet must be indented to be valid Python, so the
    # indentation of the function body is restored here.
    # NOTE(review): a 4-space indent is assumed -- confirm against the
    # interpreter's actual output.
    expected_code = """
def score(input):
    if (1) == (input[0]):
        var0 = 2
    else:
        var0 = 3
    return var0
"""

    utils.assert_code_equal(interpreter.interpret(expr), expected_code)
# Truncated excerpt: the outermost BinNumExpr opened on the next line is not
# closed within the visible lines, so its operator is not shown (presumably a
# division, given the name `sigmoid` -- TODO confirm).
#
# Visible structure:
#   1 <op?> (1 + exp(0 - S))
# where S is a SubroutineExpr summing, starting from 0, two SubroutineExpr
# wrapped IfExpr nodes:
#   if feature 23 > 868.2000000000002 then 0.2762557140263451
#       else 0.6399134166614473
#   if feature 27 > 0.14205000000000004 then -0.2139321843285849
#       else 0.1151466338793227
sigmoid = ast.BinNumExpr(
ast.NumVal(1),
ast.BinNumExpr(
ast.NumVal(1),
ast.ExpExpr(
ast.BinNumExpr(
ast.NumVal(0),
ast.SubroutineExpr(
ast.BinNumExpr(
ast.BinNumExpr(
ast.NumVal(0),
ast.SubroutineExpr(
ast.IfExpr(
ast.CompExpr(
ast.FeatureRef(23),
ast.NumVal(868.2000000000002),
ast.CompOpType.GT),
ast.NumVal(0.2762557140263451),
ast.NumVal(0.6399134166614473))),
ast.BinNumOpType.ADD),
ast.SubroutineExpr(
ast.IfExpr(
ast.CompExpr(
ast.FeatureRef(27),
ast.NumVal(0.14205000000000004),
ast.CompOpType.GT),
ast.NumVal(-0.2139321843285849),
ast.NumVal(0.1151466338793227))),
ast.BinNumOpType.ADD)),
ast.BinNumOpType.SUB)),
ast.BinNumOpType.ADD),
def test_bin_num_expr():
    """Check the Visual Basic code generated for nested ``BinNumExpr`` nodes.

    The expression is ``(feature0 / -2) * 2``; the interpreter output is
    compared with the expected module source via ``utils.assert_code_equal``.
    """
    expr = ast.BinNumExpr(
        ast.BinNumExpr(
            ast.FeatureRef(0), ast.NumVal(-2), ast.BinNumOpType.DIV),
        ast.NumVal(2),
        ast.BinNumOpType.MUL)

    # NOTE(review): the expected snippet below is flush-left; if the
    # interpreter indents the function body, restore that indentation here --
    # confirm against the actual output.
    expected_code = """
Module Model
Function score(ByRef input_vector() As Double) As Double
score = ((input_vector(0)) / (-2)) * (2)
End Function
End Module
"""

    interpreter = VisualBasicInterpreter()
    utils.assert_code_equal(interpreter.interpret(expr), expected_code)
# Excerpt: the `def` line and setup for these statements are not visible
# here; `actual` is presumably the assembled model AST -- TODO confirm.
#
# Expected AST, wrapped in a SubroutineExpr:
#   (0 + (if feature 5 > 6.8455 then 24.007392728914056
#         else 22.35695742616179))
#   + (if feature 12 > 9.63 then -0.4903836928981587
#      else 0.7222498915097475)
expected = ast.SubroutineExpr(
ast.BinNumExpr(
ast.BinNumExpr(
ast.NumVal(0),
ast.IfExpr(
ast.CompExpr(
ast.FeatureRef(5),
ast.NumVal(6.8455),
ast.CompOpType.GT),
ast.NumVal(24.007392728914056),
ast.NumVal(22.35695742616179)),
ast.BinNumOpType.ADD),
ast.IfExpr(
ast.CompExpr(
ast.FeatureRef(12),
ast.NumVal(9.63),
ast.CompOpType.GT),
ast.NumVal(-0.4903836928981587),
ast.NumVal(0.7222498915097475)),
ast.BinNumOpType.ADD))
assert utils.cmp_exprs(actual, expected)
# Truncated excerpt: this is the tail of a larger expected-AST expression
# whose opening lines (and the enclosing `def`) are not visible here.
#
# Visible structure: a comparison `feature 0 <= 1.5` choosing between the
# vectors [0.0, 1.0] and [1.0, 0.0], scaled by 0.5 (MUL); then a
# BinVectorNumExpr scaling by 0.5 a SubroutineExpr-wrapped IfExpr
# (`feature 0 <= 2.5` choosing [1.0, 0.0] else [0.0, 1.0]); the two terms
# are combined with ADD.
ast.FeatureRef(0),
ast.NumVal(1.5),
ast.CompOpType.LTE),
ast.VectorVal([
ast.NumVal(0.0),
ast.NumVal(1.0)]),
ast.VectorVal([
ast.NumVal(1.0),
ast.NumVal(0.0)]))),
ast.NumVal(0.5),
ast.BinNumOpType.MUL),
ast.BinVectorNumExpr(
ast.SubroutineExpr(
ast.IfExpr(
ast.CompExpr(
ast.FeatureRef(0),
ast.NumVal(2.5),
ast.CompOpType.LTE),
ast.VectorVal([
ast.NumVal(1.0),
ast.NumVal(0.0)]),
ast.VectorVal([
ast.NumVal(0.0),
ast.NumVal(1.0)]))),
ast.NumVal(0.5),
ast.BinNumOpType.MUL),
ast.BinNumOpType.ADD)
assert utils.cmp_exprs(actual, expected)
# Truncated excerpt: the call that opens the outer expression (closed by the
# final `ADD)` below) and the enclosing `def` are not visible here.
#
# Visible structure: two BinNumExpr terms combined with ADD, each a
# SubroutineExpr-wrapped IfExpr scaled by 0.5 (MUL):
#   (if feature 0 <= 1.5 then 1.0 else 2.0) * 0.5
#   (if feature 0 <= 2.5 then 2.0 else 3.0) * 0.5
ast.BinNumExpr(
ast.SubroutineExpr(
ast.IfExpr(
ast.CompExpr(
ast.FeatureRef(0),
ast.NumVal(1.5),
ast.CompOpType.LTE),
ast.NumVal(1.0),
ast.NumVal(2.0))),
ast.NumVal(0.5),
ast.BinNumOpType.MUL),
ast.BinNumExpr(
ast.SubroutineExpr(
ast.IfExpr(
ast.CompExpr(
ast.FeatureRef(0),
ast.NumVal(2.5),
ast.CompOpType.LTE),
ast.NumVal(2.0),
ast.NumVal(3.0))),
ast.NumVal(0.5),
ast.BinNumOpType.MUL),
ast.BinNumOpType.ADD)
assert utils.cmp_exprs(actual, expected)