Secure your code as it's written. Use Snyk Code to scan source code in minutes - no build needed - and fix issues immediately.
ast.BinVectorNumExpr(
ast.SubroutineExpr(
ast.IfExpr(
ast.CompExpr(
ast.FeatureRef(0),
ast.NumVal(2.5),
ast.CompOpType.LTE),
ast.VectorVal([
ast.NumVal(1.0),
ast.NumVal(0.0)]),
ast.VectorVal([
ast.NumVal(0.0),
ast.NumVal(1.0)]))),
ast.NumVal(0.5),
ast.BinNumOpType.MUL),
ast.BinNumOpType.ADD)
assert utils.cmp_exprs(actual, expected)
assembler = assemblers.XGBoostModelAssembler(estimator)
actual = assembler.assemble()
exponent = ast.ExpExpr(
ast.SubroutineExpr(
ast.BinNumExpr(
ast.NumVal(0.5),
ast.NumVal(0.0),
ast.BinNumOpType.ADD)),
to_reuse=True)
exponent_sum = ast.BinNumExpr(
ast.BinNumExpr(exponent, exponent, ast.BinNumOpType.ADD),
exponent,
ast.BinNumOpType.ADD,
to_reuse=True)
softmax = ast.BinNumExpr(exponent, exponent_sum, ast.BinNumOpType.DIV)
expected = ast.VectorVal([softmax] * 3)
assert utils.cmp_exprs(actual, expected)
ast.SubroutineExpr(
ast.NumVal(1.0)),
ast.NumVal(0.5),
ast.BinNumOpType.MUL),
ast.BinNumExpr(
ast.SubroutineExpr(
ast.IfExpr(
ast.CompExpr(
ast.FeatureRef(0),
ast.NumVal(1.5),
ast.CompOpType.LTE),
ast.NumVal(1.0),
ast.NumVal(2.0))),
ast.NumVal(0.5),
ast.BinNumOpType.MUL),
ast.BinNumOpType.ADD)
assert utils.cmp_exprs(actual, expected)
def test_nested_condition():
left = ast.BinNumExpr(
ast.IfExpr(
ast.CompExpr(ast.NumVal(1),
ast.NumVal(1),
ast.CompOpType.EQ),
ast.NumVal(1),
ast.NumVal(2)),
ast.NumVal(2),
ast.BinNumOpType.ADD)
bool_test = ast.CompExpr(ast.NumVal(1), left, ast.CompOpType.EQ)
expr_nested = ast.IfExpr(bool_test, ast.FeatureRef(2), ast.NumVal(2))
expr = ast.IfExpr(bool_test, expr_nested, ast.NumVal(2))
expected_code = """
func score(input []float64) float64 {
var var0 float64
var var1 float64
if (1) == (1) {
var1 = 1
} else {
var1 = 2
}
def kernel_ast(sup_vec_value):
    """Build the expected AST for one support vector's kernel term.

    The resulting expression is
    ``Subroutine(((gamma * (sup_vec_value * x[0])) + 0.0) ** degree)``,
    where ``gamma`` and ``degree`` are read from ``estimator``. That name is
    not defined in this function; presumably it is the fitted model from the
    enclosing test scope (the 0.0 additive term looks like a zero ``coef0``;
    TODO confirm against the test's estimator settings).
    """
    gamma_node = ast.NumVal(estimator.gamma)
    # Product of the support-vector coordinate and the single input feature.
    dot_node = ast.BinNumExpr(
        ast.NumVal(sup_vec_value),
        ast.FeatureRef(0),
        ast.BinNumOpType.MUL)
    scaled_node = ast.BinNumExpr(gamma_node, dot_node, ast.BinNumOpType.MUL)
    shifted_node = ast.BinNumExpr(
        scaled_node, ast.NumVal(0.0), ast.BinNumOpType.ADD)
    degree_node = ast.NumVal(estimator.degree)
    return ast.SubroutineExpr(ast.PowExpr(shifted_node, degree_node))
import os
from m2cgen import ast
from m2cgen.interpreters import utils, mixins
from m2cgen.interpreters.c.code_generator import CCodeGenerator
from m2cgen.interpreters.interpreter import ToCodeInterpreter
class CInterpreter(ToCodeInterpreter,
mixins.LinearAlgebraMixin):
supported_bin_vector_ops = {
ast.BinNumOpType.ADD: "add_vectors",
}
supported_bin_vector_num_ops = {
ast.BinNumOpType.MUL: "mul_vector_number",
}
exponent_function_name = "exp"
power_function_name = "pow"
tanh_function_name = "tanh"
def __init__(self, indent=4, *args, **kwargs):
    """Initialise the C interpreter with a fresh ``CCodeGenerator``.

    :param indent: indentation width forwarded to the code generator.
    Any further positional and keyword arguments are passed unchanged to
    the base interpreter's initializer, after the code generator.
    """
    super(CInterpreter, self).__init__(
        CCodeGenerator(indent=indent), *args, **kwargs)
def interpret(self, expr):
self._cg.reset_state()
import os
from m2cgen import ast
from m2cgen.interpreters import mixins
from m2cgen.interpreters import utils
from m2cgen.interpreters.interpreter import ToCodeInterpreter
from m2cgen.interpreters.java.code_generator import JavaCodeGenerator
class JavaInterpreter(ToCodeInterpreter,
mixins.LinearAlgebraMixin,
mixins.SubroutinesAsFunctionsMixin):
supported_bin_vector_ops = {
ast.BinNumOpType.ADD: "addVectors",
}
supported_bin_vector_num_ops = {
ast.BinNumOpType.MUL: "mulVectorNumber",
}
exponent_function_name = "Math.exp"
power_function_name = "Math.pow"
tanh_function_name = "Math.tanh"
def __init__(self, package_name=None, class_name="Model", indent=4,
             *args, **kwargs):
    """Store the Java-specific output settings on the instance.

    :param package_name: stored as ``self.package_name``; defaults to None
        (presumably the Java package of the generated code; TODO confirm).
    :param class_name: stored as ``self.class_name``; defaults to "Model"
        (presumably the name of the generated Java class).
    :param indent: indentation width, stored as ``self.indent``.
    """
    # NOTE(review): unlike CInterpreter and JavascriptInterpreter, the body
    # visible here never calls super().__init__(...), and *args/**kwargs are
    # never forwarded. Confirm the base ToCodeInterpreter initializer is
    # invoked (the call may live in lines cut from this view); otherwise
    # whatever state it sets up will be missing.
    self.package_name = package_name
    self.class_name = class_name
    # Kept on the instance rather than only passed to a code generator;
    # presumably reused later when code generators are created; TODO confirm.
    self.indent = indent
def assemble(self):
    """Assemble the ensemble as the mean of its per-tree expressions.

    Each fitted tree in ``self.model.estimators_`` is assembled with
    ``TreeModelAssembler``, wrapped in a ``SubroutineExpr`` and multiplied
    by ``1 / number_of_fitted_trees``. The scaled expressions are then
    summed.

    :return: the summed AST expression.
    """
    trees = self.model.estimators_
    # Weight by the number of trees actually fitted, not by the
    # ``n_estimators`` hyper-parameter. The two diverge if ``n_estimators``
    # is changed after fitting (e.g. via ``set_params`` or warm start), and
    # scikit-learn's own forest ``predict`` averages over
    # ``len(self.estimators_)``.
    coef = 1.0 / len(trees)

    def assemble_tree_expr(tree):
        # Scale each tree's output so that the final sum is the mean.
        tree_ast = TreeModelAssembler(tree).assemble()
        return utils.apply_bin_op(
            ast.SubroutineExpr(tree_ast),
            ast.NumVal(coef),
            ast.BinNumOpType.MUL)

    assembled_trees = [assemble_tree_expr(t) for t in trees]
    return utils.apply_op_to_expressions(
        ast.BinNumOpType.ADD, *assembled_trees)
def _assemble_single_output(self, trees, base_score=0):
    """Combine ``base_score`` and the per-tree ASTs into one subroutine.

    :param trees: the trees to assemble. When ``self._tree_limit`` is
        truthy, only the first ``self._tree_limit`` trees are used; a falsy
        limit (e.g. 0 or None) keeps them all.
    :param base_score: constant added to the sum of the tree expressions.
    :return: ``self._final_transform`` applied to the sum, wrapped in a
        ``SubroutineExpr``.
    """
    if self._tree_limit:
        trees = trees[:self._tree_limit]

    tree_exprs = [self._assemble_tree(tree) for tree in trees]
    leaf_counts = [self._count_leaves(tree) for tree in trees]

    # When the total leaf count is large, the trees are split into several
    # subroutines to avoid Java limitations:
    # https://github.com/BayesWitnesses/m2cgen/issues/103.
    if sum(leaf_counts) > self._leaves_cutoff_threshold:
        summands = self._split_into_subroutines(tree_exprs, leaf_counts)
    else:
        summands = tree_exprs

    total_expr = utils.apply_op_to_expressions(
        ast.BinNumOpType.ADD,
        ast.NumVal(base_score),
        *summands)
    return ast.SubroutineExpr(self._final_transform(total_expr))
import os
from m2cgen import ast
from m2cgen.interpreters import mixins
from m2cgen.interpreters import utils
from m2cgen.interpreters.interpreter import ToCodeInterpreter
from m2cgen.interpreters.javascript.code_generator \
import JavascriptCodeGenerator
class JavascriptInterpreter(ToCodeInterpreter,
mixins.LinearAlgebraMixin):
supported_bin_vector_ops = {
ast.BinNumOpType.ADD: "addVectors",
}
supported_bin_vector_num_ops = {
ast.BinNumOpType.MUL: "mulVectorNumber",
}
exponent_function_name = "Math.exp"
power_function_name = "Math.pow"
tanh_function_name = "Math.tanh"
def __init__(self, indent=4,
             *args, **kwargs):
    """Initialise the JavaScript interpreter.

    :param indent: indentation width. It is stored as ``self.indent`` and
        also used to build the ``JavascriptCodeGenerator`` handed to the
        base interpreter.
    Extra positional and keyword arguments are forwarded unchanged to the
    base initializer, after the code generator.
    """
    self.indent = indent
    super(JavascriptInterpreter, self).__init__(
        JavascriptCodeGenerator(indent=indent), *args, **kwargs)