Secure your code as it's written. Use Snyk Code to scan source code in minutes—no build needed—and fix issues immediately.
with utils.tmp_dir() as tmp_dirpath:
filename = os.path.join(tmp_dirpath, "tmp.file")
estimator.save_model(filename)
estimator = xgboost.XGBRegressor(base_score=base_score)
estimator.load_model(filename)
assembler = assemblers.XGBoostModelAssembler(estimator)
actual = assembler.assemble()
expected = ast.SubroutineExpr(
ast.BinNumExpr(
ast.BinNumExpr(
ast.NumVal(base_score),
ast.IfExpr(
ast.CompExpr(
ast.FeatureRef(12),
ast.NumVal(9.72500038),
ast.CompOpType.GTE),
ast.NumVal(1.67318344),
ast.NumVal(2.92757893)),
ast.BinNumOpType.ADD),
ast.IfExpr(
ast.CompExpr(
ast.FeatureRef(5),
ast.NumVal(6.94099998),
ast.CompOpType.GTE),
ast.NumVal(3.3400948),
ast.NumVal(1.72118247)),
ast.BinNumOpType.ADD))
assert utils.cmp_exprs(actual, expected)
def test_nested_condition():
left = ast.BinNumExpr(
ast.IfExpr(
ast.CompExpr(ast.NumVal(1),
ast.NumVal(1),
ast.CompOpType.EQ),
ast.NumVal(1),
ast.NumVal(2)),
ast.NumVal(2),
ast.BinNumOpType.ADD)
bool_test = ast.CompExpr(ast.NumVal(1), left, ast.CompOpType.EQ)
expr_nested = ast.IfExpr(bool_test, ast.FeatureRef(2), ast.NumVal(2))
expr = ast.IfExpr(bool_test, expr_nested, ast.NumVal(2))
expected_code = """
Module Model
Function score(ByRef input_vector() As Double) As Double
actual = assembler.assemble()
sigmoid = ast.BinNumExpr(
ast.NumVal(1),
ast.BinNumExpr(
ast.NumVal(1),
ast.ExpExpr(
ast.BinNumExpr(
ast.NumVal(0),
ast.SubroutineExpr(
ast.BinNumExpr(
ast.BinNumExpr(
ast.NumVal(-0.0),
ast.SubroutineExpr(
ast.IfExpr(
ast.CompExpr(
ast.FeatureRef(20),
ast.NumVal(16.7950001),
ast.CompOpType.GTE),
ast.NumVal(-0.17062147),
ast.NumVal(0.1638484))),
ast.BinNumOpType.ADD),
ast.SubroutineExpr(
ast.IfExpr(
ast.CompExpr(
ast.FeatureRef(27),
ast.NumVal(0.142349988),
ast.CompOpType.GTE),
ast.NumVal(-0.16087772),
ast.NumVal(0.149866998))),
ast.BinNumOpType.ADD)),
ast.BinNumOpType.SUB)),
actual = assembler.assemble()
expected = ast.SubroutineExpr(
ast.BinNumExpr(
ast.BinNumExpr(
ast.NumVal(0),
ast.IfExpr(
ast.CompExpr(
ast.FeatureRef(5),
ast.NumVal(6.8455),
ast.CompOpType.GT),
ast.NumVal(24.007392728914056),
ast.NumVal(22.35695742616179)),
ast.BinNumOpType.ADD),
ast.IfExpr(
ast.CompExpr(
ast.FeatureRef(12),
ast.NumVal(9.63),
ast.CompOpType.GT),
ast.NumVal(-0.4903836928981587),
ast.NumVal(0.7222498915097475)),
ast.BinNumOpType.ADD))
assert utils.cmp_exprs(actual, expected)
def test_dependable_condition():
left = ast.BinNumExpr(
ast.IfExpr(
ast.CompExpr(ast.NumVal(1),
ast.NumVal(1),
ast.CompOpType.EQ),
ast.NumVal(1),
ast.NumVal(2)),
ast.NumVal(2),
ast.BinNumOpType.ADD)
right = ast.BinNumExpr(ast.NumVal(1), ast.NumVal(2), ast.BinNumOpType.DIV)
bool_test = ast.CompExpr(left, right, ast.CompOpType.GTE)
expr = ast.IfExpr(bool_test, ast.NumVal(1), ast.FeatureRef(0))
expected_code = """
Module Model
Function score(ByRef input_vector() As Double) As Double
Dim var0 As Double
Dim var1 As Double
If (1) == (1) Then
var1 = 1
Else
var1 = 2
End If
If ((var1) + (2)) >= ((1) / (2)) Then
var0 = 1
Else
def test_if_expr():
    """Check how the C interpreter renders a single IfExpr.

    The conditional ``1 == input[0] ? 2 : 3`` is expected to become an
    ``if``/``else`` statement that assigns ``var0`` in each branch, and the
    generated ``score`` function returns ``var0``.
    """
    equals_first_feature = ast.CompExpr(
        ast.NumVal(1), ast.FeatureRef(0), ast.CompOpType.EQ)
    expr = ast.IfExpr(equals_first_feature, ast.NumVal(2), ast.NumVal(3))

    # Kept flush-left on purpose: indenting these lines would change the
    # contents of the string literal.
    expected_code = """
double score(double * input) {
double var0;
if ((1) == (input[0])) {
var0 = 2;
} else {
var0 = 3;
}
return var0;
}"""

    c_interpreter = interpreters.CInterpreter()
    generated_code = c_interpreter.interpret(expr)
    utils.assert_code_equal(generated_code, expected_code)
def test_nested_condition():
left = ast.BinNumExpr(
ast.IfExpr(
ast.CompExpr(ast.NumVal(1),
ast.NumVal(1),
ast.CompOpType.EQ),
ast.NumVal(1),
ast.NumVal(2)),
ast.NumVal(2),
ast.BinNumOpType.ADD)
bool_test = ast.CompExpr(ast.NumVal(1), left, ast.CompOpType.EQ)
expr_nested = ast.IfExpr(bool_test, ast.FeatureRef(2), ast.NumVal(2))
expr = ast.IfExpr(bool_test, expr_nested, ast.NumVal(2))
expected_code = """
Module Model
Function score(ByRef input_vector() As Double) As Double
Dim var0 As Double
Dim var1 As Double
If (1) == (1) Then
var1 = 1
Else
var1 = 2
End If
If (1) == ((var1) + (2)) Then
def test_multi_output():
expr = ast.SubroutineExpr(
ast.IfExpr(
ast.CompExpr(
ast.NumVal(1),
ast.NumVal(1),
ast.CompOpType.EQ),
ast.VectorVal([ast.NumVal(1), ast.NumVal(2)]),
ast.VectorVal([ast.NumVal(3), ast.NumVal(4)])))
expected_code = """
Module Model
Function score(ByRef input_vector() As Double) As Double()
Dim var0() As Double
If (1) == (1) Then
Dim var1(1) As Double
var1(0) = 1
var1(1) = 2
var0 = var1
Else