Secure your code as it's written. Use Snyk Code to scan source code in minutes - no build needed - and fix issues immediately.
def _evaluate_symbols(self, expression, variables, parameters):
    """Numerically evaluate a symbolic expression.

    Every free symbol of ``expression`` is resolved through
    ``self.symbol_dict`` to the layer (``child``) that registered it and the
    name it was registered under. Its value is then read from ``variables``
    (for variables of that layer) or ``parameters`` (for its parameters);
    both are indexed by ``(child.label, name)``.

    Args:
        expression: symbolic expression to evaluate.
        variables: container indexable by ``(label, name)`` holding the
            values of layer variables.
        parameters: container indexable by ``(label, name)`` holding the
            values of layer parameters.

    Returns:
        Whatever ``evalf`` returns for the function built from
        ``expression``, evaluated at the looked-up input values.

    Raises:
        ValueError: if a registered symbol is neither a variable nor a
            parameter of its owning layer, so no input value can be supplied.
    """
    symbols = symvar(expression)
    f = Function(symbols, [expression])
    f.init()
    f_in = []
    for sym in symbols:
        child, name = self.symbol_dict[sym.name()]
        if name in child._variables:
            f_in.append(variables[child.label, name])
        elif name in child._parameters:
            f_in.append(parameters[child.label, name])
        else:
            # Silently skipping this symbol would leave f_in shorter than
            # `symbols`, so the inputs would no longer line up one-to-one
            # with the function's arguments.
            raise ValueError(
                'Symbol %s of layer %s is neither a variable nor a parameter'
                % (sym.name(), child.label))
    return evalf(f, f_in)
# NOTE(review): fragment -- the enclosing function header is not visible.
# The recursive call below suggests it is `get_derivative(expr)`; `expr`,
# `states`, `alg_states`, `der_states`, `ca` and `Variable` come from the
# missing enclosing scope. It appears to return the derivative of `expr`
# (presumably w.r.t. time, given the `der(...)` symbol naming) -- confirm.
if expr.is_constant():
# A constant expression has zero derivative.
return 0.0
elif expr.is_symbolic():
if expr.name() in states:
# Already a differentiated state: reuse its existing derivative symbol.
return der_states[expr.name()].symbol
elif expr.name() in alg_states:
# This algebraic state must now become a differentiated state.
states[expr.name()] = alg_states.pop(expr.name())
der_sym = ca.MX.sym('der({})'.format(expr.name()))
der_states[expr.name()] = Variable(der_sym, float)
return der_sym
else:
# Any other symbol (neither a state nor an algebraic state) is treated
# as having zero derivative -- presumably parameters/inputs; confirm.
return 0.0
else:
# Differentiate using CasADi and chain rule
# d(expr) = J(expr, deps) * d(deps). Dependencies whose Jacobian column is
# structurally zero in row 0 are not recursed into; they get a zero block.
# NOTE(review): only row 0 of the Jacobian sparsity is inspected, so this
# presumably assumes `expr` is scalar -- confirm.
orig_deps = ca.symvar(expr)
deps = ca.vertcat(*orig_deps)
J = ca.Function('J', [deps], [ca.jacobian(expr, deps)])
J_sparsity = J.sparsity_out(0)
der_deps = [get_derivative(dep) if J_sparsity.has_nz(0, j) else ca.DM.zeros(dep.size()) for j, dep in enumerate(orig_deps)]
return ca.mtimes(J(deps), ca.vertcat(*der_deps))
# NOTE(review): fragment -- the enclosing method header and the code that
# builds the constraint expression `g` are not visible here.
return False, None, None
# Build `jac` from the selected columns (`ind`) of dg/dvar for every local
# copy in self.q_i and every neighbour copy in self.q_ij, remembering each
# symbolic variable in the list `sym`.
sym, jac = [], []
for child, q_i in self.q_i.items():
for name, ind in q_i.items():
var = self.distr_problem.father.get_variables(child, name, spline=False, symbolic=True, substitute=False)
jj = jacobian(g, var)
jac = horzcat(jac, jj[:, ind])
sym.append(var)
for nghb in self.q_ij.keys():
for child, q_ij in self.q_ij[nghb].items():
for name, ind in q_ij.items():
var = self.distr_problem.father.get_variables(child, name, spline=False, symbolic=True, substitute=False)
jj = jacobian(g, var)
jac = horzcat(jac, jj[:, ind])
sym.append(var)
# Give up unless the Jacobian depends only on global parameters (presumably:
# the constraints are affine in the collected variables).
# NOTE(review): this loop rebinds `sym`. When symvar(jac) is non-empty and no
# early return happens, `sym` ends up as the last symbol found here, so the
# `for s in sym` loop below does not iterate the variable list collected
# above. Likely a bug -- use a different loop variable name.
for sym in symvar(jac):
if sym not in self.par_global.values():
return False, None, None
par = struct_symMX(self.par_global_struct)
# Zeroing the variables in -g gives the constant term, so that g = A*q - b
# when g is affine in the variables.
A, b = jac, -g
for s in sym:
A = substitute(A, s, np.zeros(s.shape))
b = substitute(b, s, np.zeros(s.shape))
dep_b = [s.name() for s in symvar(b)]
# NOTE(review): dep_A is built from symvar(b), not symvar(A); parameters that
# appear only in A are therefore never replaced below. Probably meant symvar(A).
dep_A = [s.name() for s in symvar(b)]
# Replace global parameter symbols by entries of the structured symbol `par`
# so that A and b become functions of `par`.
for name, sym in self.par_global.items():
if sym.name() in dep_b:
b = substitute(b, sym, par[name])
if sym.name() in dep_A:
A = substitute(A, sym, par[name])
A = Function('A', [par], [A]).expand()
b = Function('b', [par], [b]).expand()
# NOTE(review): fragment -- starts inside a loop over neighbours; `nghb`,
# `g`, `jac` and `sym` are bound in code not visible here.
for child, q_ij in self.q_ij[nghb].items():
for name, ind in q_ij.items():
var = self.distr_problem.father.get_variables(child, name, spline=False, symbolic=True, substitute=False)
jj = jacobian(g, var)
jac = horzcat(jac, jj[:, ind])
sym.append(var)
# Give up unless the Jacobian depends only on the local parameters par_i.
# NOTE(review): this loop rebinds `sym`. When symvar(jac) is non-empty and no
# early return happens, `sym` ends up as the last symbol found here, so the
# `for s in sym` loop below does not iterate the collected variable list.
# Likely a bug -- use a different loop variable name.
for sym in symvar(jac):
if sym not in self.par_i.values():
return False, None, None
par = struct_symMX(self.par_struct)
# Zeroing the variables in -g gives the constant term, so that g = A*q - b
# when g is affine in the variables.
A, b = jac, -g
for s in sym:
A = substitute(A, s, np.zeros(s.shape))
b = substitute(b, s, np.zeros(s.shape))
dep_b = [s.name() for s in symvar(b)]
# NOTE(review): dep_A is built from symvar(b), not symvar(A); parameters that
# appear only in A are therefore never replaced below. Probably meant symvar(A).
dep_A = [s.name() for s in symvar(b)]
# Replace local parameter symbols by entries of the structured symbol `par`.
for name, sym in self.par_i.items():
if sym.name() in dep_b:
b = substitute(b, sym, par[name])
if sym.name() in dep_A:
A = substitute(A, sym, par[name])
# Return A and b as expanded functions of `par`.
A = Function('A', [par], [A]).expand()
b = Function('b', [par], [b]).expand()
return True, A, b
def _expand_simplify_mx(equations):
"""Simplify MX equations by expanding them to SX and mapping back to MX.

Returns [] for empty input.

NOTE(review): this definition is truncated in this view -- the rest of the
inner `_sx_to_mx` helper and the function's final return are not visible.
"""
# Sometimes MX expressions can end up horribly nested and complicated,
# which makes further simplifications like alias detection difficult.
# We can simplify the equations by expanding to SX, and then
# rebuilding them with MX symbols again.
if not equations:
return []
symbols_mx = ca.symvar(ca.veccat(*equations))
# Only scalar MX symbols can be mapped one-to-one onto SX symbols.
# NOTE(review): `assert` is stripped under `python -O`; the message says
# "SX" although the check is on the MX symbols.
assert all(x.shape == (1, 1) for x in symbols_mx), "Vector/Matrix SX symbols cannot be mapped"
# Expand the MX function to SX operations and call it on fresh SX symbols
# carrying the same names and shapes, giving SX versions of the equations.
f_mx = ca.Function('tmp_mx', symbols_mx, equations).expand()
symbols_sx = [ca.SX.sym(x.name(), *x.shape) for x in symbols_mx]
sx_equations = f_mx.call(symbols_sx)
# Lookup from each SX symbol back to the MX symbol it stands for.
sx_to_mx_map = {s: m for s, m in zip(symbols_sx, symbols_mx)}
def _sx_to_mx(sx_expr):
# Non-scalar input: convert element by element into a nested list of rows.
if not sx_expr.is_scalar():
rows = []
for i in range(sx_expr.size1()):
cols = []
for j in range(sx_expr.size2()):
cols.append(_sx_to_mx(sx_expr[i, j]))
rows.append(cols)
def get_dependency(expression):
    """Return, per free symbol of *expression*, the sorted ``find()`` indices
    of the column sum (``sum1``) of the output's Jacobian sparsity with
    respect to that symbol.

    Presumably these are the elements of the symbol that the expression
    structurally depends on -- confirm against CasADi's ``sparsity_jac``
    layout.

    Returns:
        dict mapping each symbol from ``symvar(expression)`` to a sorted list
        of indices.
    """
    inputs = symvar(expression)
    func = Function('f', inputs, [expression])
    return {
        symbol: sorted(sum1(func.sparsity_jac(position, 0)).find())
        for position, symbol in enumerate(inputs)
    }
def add_to_dict(self, symbol, name):
    """Register every free symbol of *symbol* as belonging to this layer.

    Each symbol is stored in ``self.symbol_dict`` under its name, with the
    value ``[self, name]``.

    Args:
        symbol: symbolic expression whose free symbols are registered.
        name: name under which this layer knows the expression.

    Raises:
        ValueError: if a symbol with the same name is already registered.
    """
    for sym in symvar(symbol):
        key = sym.name()
        # symbol_dict is keyed by symbol *name*. Testing the symbol object
        # itself against those string keys could never match, so duplicates
        # used to be silently overwritten instead of rejected.
        if key in self.symbol_dict:
            raise ValueError('Symbol already added for %s' % self.label)
        self.symbol_dict[key] = [self, name]
def _substitute_symbols(self, expr, variables, parameters):
    """Replace the registered symbols of *expr* by their values.

    Plain ``int``/``float`` inputs are returned unchanged. Otherwise each
    free symbol is resolved through ``self.symbol_dict`` to its owning layer
    and registered name. It is then substituted by
    ``variables[label, name]`` if it is one of that layer's variables, or by
    ``parameters[label, name]`` if it is one of its parameters. Any other
    symbol is left in place.

    A symbol missing from ``self.symbol_dict`` makes that lookup fail
    (``KeyError`` if it is a plain dict).
    """
    if isinstance(expr, (int, float)):
        return expr
    for symbol in symvar(expr):
        owner, key = self.symbol_dict[symbol.name()]
        if key in owner._variables:
            source = variables
        elif key in owner._parameters:
            source = parameters
        else:
            continue  # neither a variable nor a parameter: keep the symbol
        expr = substitute(expr, symbol, source[owner.label, key])
    return expr