Secure your code as it's written. Use Snyk Code to scan source code in minutes - no build needed - and fix issues immediately.
double var0;
double var1;
if ((1) == (1)) {
var1 = 1;
} else {
var1 = 2;
}
if (((var1) + (2)) >= ((1) / (2))) {
var0 = 1;
} else {
var0 = input[0];
}
return var0;
}"""
interpreter = interpreters.CInterpreter()
utils.assert_code_equal(interpreter.interpret(expr), expected_code)
def test_pow_expr():
    """PowExpr must be emitted as a call to C's ``pow`` with its header."""
    expr = ast.PowExpr(ast.NumVal(2.0), ast.NumVal(3.0))

    interpreter = interpreters.CInterpreter()

    # pow() is declared in <math.h>; a bare "#include" is not valid C.
    expected_code = """
#include <math.h>
double score(double * input) {
    return pow(2.0, 3.0);
}"""

    utils.assert_code_equal(interpreter.interpret(expr), expected_code)
ast.CompOpType.EQ),
ast.VectorVal([ast.NumVal(1), ast.NumVal(2)]),
ast.VectorVal([ast.NumVal(3), ast.NumVal(4)])))
expected_code = """
#include
void score(double * input, double * output) {
double var0[2];
if ((1) == (1)) {
memcpy(var0, (double[]){1, 2}, 2 * sizeof(double));
} else {
memcpy(var0, (double[]){3, 4}, 2 * sizeof(double));
}
memcpy(output, var0, 2 * sizeof(double));
}"""
interpreter = interpreters.CInterpreter()
utils.assert_code_equal(interpreter.interpret(expr), expected_code)
def test_exp_expr():
    """ExpExpr must be emitted as a call to C's ``exp`` with its header."""
    expr = ast.ExpExpr(ast.NumVal(1.0))

    interpreter = interpreters.CInterpreter()

    # exp() is declared in <math.h>; a bare "#include" is not valid C.
    expected_code = """
#include <math.h>
double score(double * input) {
    return exp(1.0);
}"""

    utils.assert_code_equal(interpreter.interpret(expr), expected_code)
def test_bin_vector_expr():
    """Vector + vector must go through ``add_vectors`` and be copied to output."""
    expr = ast.BinVectorExpr(
        ast.VectorVal([ast.NumVal(1), ast.NumVal(2)]),
        ast.VectorVal([ast.NumVal(3), ast.NumVal(4)]),
        ast.BinNumOpType.ADD)

    interpreter = interpreters.CInterpreter()

    # memcpy() is declared in <string.h>; a bare "#include" is not valid C.
    expected_code = """
#include <string.h>
void add_vectors(double *v1, double *v2, int size, double *result) {
    for(int i = 0; i < size; ++i)
        result[i] = v1[i] + v2[i];
}
void mul_vector_number(double *v1, double num, int size, double *result) {
    for(int i = 0; i < size; ++i)
        result[i] = v1[i] * num;
}
void score(double * input, double * output) {
    double var0[2];
    add_vectors((double[]){1, 2}, (double[]){3, 4}, 2, var0);
    memcpy(output, var0, 2 * sizeof(double));
}"""

    # Without this comparison the test builds its fixtures but can never fail.
    utils.assert_code_equal(interpreter.interpret(expr), expected_code)
def test_bin_vector_num_expr():
    """Vector * number must go through ``mul_vector_number`` and be copied out."""
    expr = ast.BinVectorNumExpr(
        ast.VectorVal([ast.NumVal(1), ast.NumVal(2)]),
        ast.NumVal(1),
        ast.BinNumOpType.MUL)

    interpreter = interpreters.CInterpreter()

    # memcpy() is declared in <string.h>; a bare "#include" is not valid C.
    expected_code = """
#include <string.h>
void add_vectors(double *v1, double *v2, int size, double *result) {
    for(int i = 0; i < size; ++i)
        result[i] = v1[i] + v2[i];
}
void mul_vector_number(double *v1, double num, int size, double *result) {
    for(int i = 0; i < size; ++i)
        result[i] = v1[i] * num;
}
void score(double * input, double * output) {
    double var0[2];
    mul_vector_number((double[]){1, 2}, 1, 2, var0);
    memcpy(output, var0, 2 * sizeof(double));
}"""

    # Without this comparison the test builds its fixtures but can never fail.
    utils.assert_code_equal(interpreter.interpret(expr), expected_code)
def test_if_expr():
    """IfExpr must lower to an if/else that assigns a temporary, then returns it."""
    expr = ast.IfExpr(
        ast.CompExpr(ast.NumVal(1), ast.FeatureRef(0), ast.CompOpType.EQ),
        ast.NumVal(2),
        ast.NumVal(3))

    interpreter = interpreters.CInterpreter()

    # Nested statements are indented one level per scope (4 spaces, assumed to
    # be the CInterpreter default -- export_to_c also defaults to indent=4).
    expected_code = """
double score(double * input) {
    double var0;
    if ((1) == (input[0])) {
        var0 = 2;
    } else {
        var0 = 3;
    }
    return var0;
}"""

    utils.assert_code_equal(interpreter.interpret(expr), expected_code)
def test_bin_num_expr():
    """Nested BinNumExpr nodes must be fully parenthesised in the output."""
    expr = ast.BinNumExpr(
        ast.BinNumExpr(
            ast.FeatureRef(0), ast.NumVal(-2), ast.BinNumOpType.DIV),
        ast.NumVal(2),
        ast.BinNumOpType.MUL)

    interpreter = interpreters.CInterpreter()

    expected_code = """
double score(double * input) {
    return ((input[0]) / (-2)) * (2);
}"""

    utils.assert_code_equal(interpreter.interpret(expr), expected_code)
def test_tanh_expr():
    """TanhExpr must be emitted as a call to C's ``tanh`` with its header."""
    expr = ast.TanhExpr(ast.NumVal(2.0))

    interpreter = interpreters.CInterpreter()

    # tanh() is declared in <math.h>; a bare "#include" is not valid C.
    expected_code = """
#include <math.h>
double score(double * input) {
    return tanh(2.0);
}"""

    utils.assert_code_equal(interpreter.interpret(expr), expected_code)
def export_to_c(model, indent=4):
    """
    Transpile ``model`` into C source code.

    Parameters
    ----------
    model : object
        The model object to translate into C.
    indent : int, optional
        Number of spaces per indentation level in the emitted code.

    Returns
    -------
    code : string
        The generated C code.
    """
    return _export(model, interpreters.CInterpreter(indent=indent))