Secure your code as it's written. Use Snyk Code to scan source code in minutes — no build needed — and fix issues immediately.
def test_if_expr():
    """GoInterpreter must render IfExpr(1 == FeatureRef(0), 2, 3) as expected_code."""
    condition = ast.CompExpr(
        ast.NumVal(1), ast.FeatureRef(0), ast.CompOpType.EQ)
    expr = ast.IfExpr(condition, ast.NumVal(2), ast.NumVal(3))

    expected_code = """
func score(input []float64) float64 {
var var0 float64
if (1) == (input[0]) {
var0 = 2
} else {
var0 = 3
}
return var0
}"""

    go_interpreter = interpreters.GoInterpreter()
    utils.assert_code_equal(go_interpreter.interpret(expr), expected_code)
# NOTE(review): this test is truncated in the visible source. The
# expected_code string opened below is never closed, and the Java method and
# class bodies are missing their closing braces. No assertion compares
# interpreter.interpret(expr) against expected_code. Presumably the full test
# ends with utils.assert_code_equal(interpreter.interpret(expr),
# expected_code); confirm against the complete file.
# Intended to check how JavaInterpreter renders IfExpr(1 == FeatureRef(0), 2, 3).
def test_if_expr():
expr = ast.IfExpr(
ast.CompExpr(ast.NumVal(1), ast.FeatureRef(0), ast.CompOpType.EQ),
ast.NumVal(2),
ast.NumVal(3))
# Backend under test.
interpreter = interpreters.JavaInterpreter()
expected_code = """
public class Model {
public static double score(double[] input) {
double var0;
if ((1) == (input[0])) {
var0 = 2;
} else {
var0 = 3;
}
return var0;
def test_dependable_condition():
    """Check PythonInterpreter output when an IfExpr's condition depends on another IfExpr.

    The condition is ``(IfExpr(1 == 1, 1, 2) + 2) >= (1 / 2)``. The expected
    source shows the inner IfExpr computed first into ``var1``, which the outer
    ``if`` then reads before assigning ``var0``.
    """
    left = ast.BinNumExpr(
        ast.IfExpr(
            ast.CompExpr(ast.NumVal(1),
                         ast.NumVal(1),
                         ast.CompOpType.EQ),
            ast.NumVal(1),
            ast.NumVal(2)),
        ast.NumVal(2),
        ast.BinNumOpType.ADD)
    right = ast.BinNumExpr(ast.NumVal(1), ast.NumVal(2), ast.BinNumOpType.DIV)
    bool_test = ast.CompExpr(left, right, ast.CompOpType.GTE)
    expr = ast.IfExpr(bool_test, ast.NumVal(1), ast.FeatureRef(0))

    # Generated Python is only valid with indented blocks, so the expected
    # source must carry indentation.
    # NOTE(review): 4-space indentation assumed; confirm against PythonInterpreter.
    expected_code = """
def score(input):
    if (1) == (1):
        var1 = 1
    else:
        var1 = 2
    if ((var1) + (2)) >= ((1) / (2)):
        var0 = 1
    else:
        var0 = input[0]
    return var0
"""

    # Without this comparison the test would build expected_code and assert nothing.
    interpreter = interpreters.PythonInterpreter()
    utils.assert_code_equal(interpreter.interpret(expr), expected_code)
def test_multi_output():
    """Check PythonInterpreter output for an IfExpr whose branches are vectors.

    Each ``VectorVal`` branch is expected to appear as a Python list literal
    assigned to ``var0``.
    """
    expr = ast.SubroutineExpr(
        ast.IfExpr(
            ast.CompExpr(
                ast.NumVal(1),
                ast.NumVal(1),
                ast.CompOpType.EQ),
            ast.VectorVal([ast.NumVal(1), ast.NumVal(2)]),
            ast.VectorVal([ast.NumVal(3), ast.NumVal(4)])))

    # Generated Python is only valid with indented blocks, so the expected
    # source must carry indentation.
    # NOTE(review): 4-space indentation assumed; confirm against PythonInterpreter.
    expected_code = """
def score(input):
    if (1) == (1):
        var0 = [1, 2]
    else:
        var0 = [3, 4]
    return var0
"""

    interpreter = interpreters.PythonInterpreter()
    utils.assert_code_equal(interpreter.interpret(expr), expected_code)
# NOTE(review): this test is truncated in the visible source. The
# expected_code string opened below is never closed (the generated VB stops
# after the first `If ... Then`), and no interpreter call or assertion is
# visible. Confirm against the complete file.
# NOTE(review): Visual Basic's equality operator is `=`, not `==`. Verify that
# the expected `If (1) == (1) Then` matches what the VB interpreter should emit.
# Builds IfExpr((IfExpr(1 == 1, 1, 2) + 2) >= (1 / 2), 1, FeatureRef(0)).
# The expected code declares var0 and var1, so the nested IfExpr in the
# condition appears to be computed into a separate variable first.
def test_dependable_condition():
left = ast.BinNumExpr(
ast.IfExpr(
ast.CompExpr(ast.NumVal(1),
ast.NumVal(1),
ast.CompOpType.EQ),
ast.NumVal(1),
ast.NumVal(2)),
ast.NumVal(2),
ast.BinNumOpType.ADD)
right = ast.BinNumExpr(ast.NumVal(1), ast.NumVal(2), ast.BinNumOpType.DIV)
bool_test = ast.CompExpr(left, right, ast.CompOpType.GTE)
expr = ast.IfExpr(bool_test, ast.NumVal(1), ast.FeatureRef(0))
expected_code = """
Module Model
Function score(ByRef input_vector() As Double) As Double
Dim var0 As Double
Dim var1 As Double
If (1) == (1) Then
# NOTE(review): this statement is a fragment lifted from inside a test whose
# enclosing function is not visible here. Presumably it is the expected model
# AST for a two-estimator boosted binary classifier; confirm.
# Structure: sigmoid = 1 / (1 + exp(0 - margin)), i.e. the logistic function
# of margin, where margin = (0 + tree_0) + tree_1 and each tree_i is a
# single-split IfExpr on one input feature. The outer division node is
# flagged to_reuse=True.
sigmoid = ast.BinNumExpr(
ast.NumVal(1),
# Denominator: 1 + exp(0 - margin).
ast.BinNumExpr(
ast.NumVal(1),
ast.ExpExpr(
ast.BinNumExpr(
ast.NumVal(0),
# margin: sum of the two per-tree leaf values, wrapped in a subroutine.
ast.SubroutineExpr(
ast.BinNumExpr(
ast.BinNumExpr(
ast.NumVal(0),
# tree_0: feature 23 > 868.2... selects between two leaf values.
ast.IfExpr(
ast.CompExpr(
ast.FeatureRef(23),
ast.NumVal(868.2000000000002),
ast.CompOpType.GT),
ast.NumVal(0.2762557140263451),
ast.NumVal(0.6399134166614473)),
ast.BinNumOpType.ADD),
# tree_1: feature 27 > 0.14205... selects between two leaf values.
ast.IfExpr(
ast.CompExpr(
ast.FeatureRef(27),
ast.NumVal(0.14205000000000004),
ast.CompOpType.GT),
ast.NumVal(-0.2139321843285849),
ast.NumVal(0.1151466338793227)),
ast.BinNumOpType.ADD)),
ast.BinNumOpType.SUB)),
ast.BinNumOpType.ADD),
ast.BinNumOpType.DIV,
to_reuse=True)
# NOTE(review): this test is truncated in the visible source. The
# expected_code string opened below is never closed (the generated C stops
# after `double var2;` inside the outer if), and no interpreter call or
# assertion is visible. Confirm against the complete file.
# Builds a nested IfExpr in which the inner and outer levels share the same
# bool_test object:
#     IfExpr(c, IfExpr(c, FeatureRef(2), 2), 2)
#     with c = (1 == IfExpr(1 == 1, 1, 2) + 2)
def test_nested_condition():
left = ast.BinNumExpr(
ast.IfExpr(
ast.CompExpr(ast.NumVal(1),
ast.NumVal(1),
ast.CompOpType.EQ),
ast.NumVal(1),
ast.NumVal(2)),
ast.NumVal(2),
ast.BinNumOpType.ADD)
bool_test = ast.CompExpr(ast.NumVal(1), left, ast.CompOpType.EQ)
# The same bool_test instance is reused at both nesting levels.
expr_nested = ast.IfExpr(bool_test, ast.FeatureRef(2), ast.NumVal(2))
expr = ast.IfExpr(bool_test, expr_nested, ast.NumVal(2))
expected_code = """
double score(double * input) {
double var0;
double var1;
if ((1) == (1)) {
var1 = 1;
} else {
var1 = 2;
}
if ((1) == ((var1) + (2))) {
double var2;
def _assemble_tree(self, tree):
    """Recursively convert one tree-node dict into an AST expression.

    A node carrying a ``"leaf"`` key becomes ``ast.NumVal``. Otherwise the node
    is a split on ``tree["split"]`` at ``tree["split_condition"]``. The split
    name is mapped through ``self._feature_name_to_idx`` and used as-is when it
    is not in that mapping. Children are looked up by the ``"yes"``/``"no"``
    ids and built with ``self._assemble_child_tree``.
    """
    if "leaf" in tree:
        return ast.NumVal(tree["leaf"])

    threshold = ast.NumVal(tree["split_condition"])
    split_name = tree["split"]
    feature_ref = ast.FeatureRef(
        self._feature_name_to_idx.get(split_name, split_name))

    # Any comparison against NaN (a missing value) is false, so the child
    # named in "missing" has to sit in the else-branch of the IfExpr.
    # When missing goes to "no", keep "<" with yes/no; otherwise use the
    # complementary ">=" and swap the branches.
    if tree["missing"] == tree["no"]:
        comp_op, true_id, false_id = ast.CompOpType.LT, tree["yes"], tree["no"]
    else:
        comp_op, true_id, false_id = ast.CompOpType.GTE, tree["no"], tree["yes"]

    condition = ast.CompExpr(feature_ref, threshold, comp_op)
    true_branch = self._assemble_child_tree(tree, true_id)
    false_branch = self._assemble_child_tree(tree, false_id)
    return ast.IfExpr(condition, true_branch, false_branch)
def _assemble_tree(self, tree):
    """Recursively convert one tree-node dict into an AST expression.

    Leaf nodes (those with a ``"leaf_value"`` key) become ``ast.NumVal``.
    Split nodes become an ``ast.IfExpr`` that compares
    ``FeatureRef(tree["split_feature"])`` against ``tree["threshold"]``.
    ``tree["left_child"]`` and ``tree["right_child"]`` are converted
    recursively.

    Raises:
        AssertionError: if ``tree["decision_type"]`` does not parse to the
            ``<=`` comparison, the only split type this converter handles.
    """
    if "leaf_value" in tree:
        return ast.NumVal(tree["leaf_value"])

    threshold = ast.NumVal(tree["threshold"])
    feature_ref = ast.FeatureRef(tree["split_feature"])

    op = ast.CompOpType.from_str_op(tree["decision_type"])
    # Raise explicitly instead of using ``assert`` so the check still runs
    # under ``python -O``. AssertionError is kept so callers or tests that
    # expect it continue to work.
    if op != ast.CompOpType.LTE:
        raise AssertionError("Unexpected comparison op")

    # Make sure that if "default_left" is true, the left tree branch ends up
    # in the "else" branch of the ast.IfExpr: flip "<=" into ">" and swap the
    # children.
    if tree["default_left"]:
        op = ast.CompOpType.GT
        true_child = tree["right_child"]
        false_child = tree["left_child"]
    else:
        true_child = tree["left_child"]
        false_child = tree["right_child"]

    return ast.IfExpr(
        ast.CompExpr(feature_ref, threshold, op),
        self._assemble_tree(true_child),
        self._assemble_tree(false_child))
def lte(l, r):
    """Return an AST comparison node for ``l <= r``."""
    less_or_equal = ast.CompOpType.LTE
    return ast.CompExpr(l, r, less_or_equal)