mirror of https://github.com/inclusionAI/AReaL
1226 lines
40 KiB
Python
1226 lines
40 KiB
Python
import re
|
|
|
|
import sympy
|
|
from antlr4 import CommonTokenStream, InputStream
|
|
from antlr4.error.ErrorListener import ErrorListener
|
|
from sympy import apart, expand, expand_trig, factor, matrix_symbols, simplify
|
|
|
|
try:
|
|
from gen.PSLexer import PSLexer
|
|
from gen.PSListener import PSListener
|
|
from gen.PSParser import PSParser
|
|
except Exception:
|
|
from .gen.PSParser import PSParser
|
|
from .gen.PSLexer import PSLexer
|
|
from .gen.PSListener import PSListener
|
|
|
|
import hashlib
|
|
|
|
from sympy.parsing.sympy_parser import parse_expr
|
|
from sympy.printing.str import StrPrinter
|
|
|
|
# Value forwarded as the ``real=`` assumption to every sympy.Symbol the
# converters below create; changed through set_real().
is_real = None

# Fraction command detected in the most recent latex2sympy() input
# (one of \frac, \dfrac, \tfrac).
frac_type = r"\frac"

# Known bindings. ``variances`` is keyed by the sympy object itself,
# ``var`` by its string form; both are filled by set_variances() and by
# assignment / "in" relations during conversion.
variances = {}
var = {}

# Values for VARIABLE tokens in the parse currently running; replaced on
# every latex2sympy() call.
VARIABLE_VALUES = {}
|
|
|
|
|
|
def set_real(value):
    """Set the module-wide ``is_real`` flag.

    The flag is passed as ``real=`` to every ``sympy.Symbol`` created by
    the converters in this module.
    """
    global is_real
    is_real = value
|
|
|
|
|
|
def set_variances(vars):
    """Replace the module-level variable bindings.

    ``variances`` becomes ``vars`` itself (same object, not a copy) and
    ``var`` is rebuilt as a new dict with the same values keyed by
    ``str(key)``.
    """
    global variances, var
    variances = vars
    var = {str(key): vars[key] for key in vars}
|
|
|
|
|
|
def latex2sympy(sympy: str, variable_values={}):
    """Parse a LaTeX string and convert it to sympy objects.

    Note that the first parameter is named ``sympy`` but holds the LaTeX
    source text; inside this function the name therefore refers to that
    string, not to the sympy module.

    Args:
        sympy: LaTeX source to convert.
        variable_values: values for VARIABLE tokens. Stored in the global
            ``VARIABLE_VALUES`` when non-empty; otherwise the global is
            reset to a fresh empty dict.

    Returns:
        A list with one converted item per relation when the input parses
        as a relation list, otherwise the single converted relation.

    Raises:
        Exception: raised by ``MathErrorListener`` on a lexer or parser
            syntax error.
    """
    global frac_type, VARIABLE_VALUES

    # Record which fraction command the input uses; a later match in this
    # order overrides an earlier one.
    for frac_command in (r"\frac", r"\dfrac", r"\tfrac"):
        if frac_command in sympy:
            frac_type = frac_command

    # Literal normalisations, applied in order: fraction variants,
    # transpose, derivative "d", bracketed matrix environments.
    for old, new in (
        (r"\dfrac", r"\frac"),
        (r"\tfrac", r"\frac"),
        (r"\mathrm{T}", "T"),
        (r"\mathrm{d}", "d"),
        (r"{\rm d}", "d"),
        (r"\left[\begin{matrix}", r"\begin{bmatrix}"),
        (r"\end{matrix}\right]", r"\end{bmatrix}"),
    ):
        sympy = sympy.replace(old, new)

    # Permutation notation (n)_{k} becomes n! / (n - k)!
    sympy = re.sub(
        r"\(([a-zA-Z0-9+\-*/\\ ]+?)\)_{([a-zA-Z0-9+\-*/\\ ]+?)}",
        r"\\frac{(\1)!}{((\1)-(\2))!}",
        sympy,
    )

    # Layout-only commands and dollar signs are replaced by a space.
    for filler in (r"\displaystyle", r"\quad", r"\qquad", "~", r"\,", "$"):
        sympy = sympy.replace(filler, " ")

    VARIABLE_VALUES = variable_values if len(variable_values) > 0 else {}

    # One listener serves both lexer and parser; the default console
    # listeners are removed first.
    matherror = MathErrorListener(sympy)

    lex = PSLexer(InputStream(sympy))
    lex.removeErrorListeners()
    lex.addErrorListener(matherror)

    parser = PSParser(CommonTokenStream(lex))
    parser.removeErrorListeners()
    parser.addErrorListener(matherror)

    math = parser.math()

    if math.relation_list():
        items = math.relation_list().relation_list_content().relation()
        return [convert_relation(item) for item in items]

    return convert_relation(math.relation())
|
|
|
|
|
|
class MathErrorListener(ErrorListener):
    """ANTLR error listener that turns a syntax error into an exception.

    The raised message has three lines: a description, the source text
    and a ``~~~^`` marker pointing at the offending column.
    """

    def __init__(self, src):
        """Remember the source text so it can be echoed in error messages."""
        # The original passed ErrorListener as the first argument to
        # super(), which starts the MRO lookup *after* ErrorListener and so
        # skips the parent initializer.
        super().__init__()
        self.src = src

    def syntaxError(self, recog, symbol, line, col, msg, e):
        """Raise ``Exception`` describing the syntax error.

        The description is chosen from the prefix of ANTLR's ``msg``:
        "missing" messages are passed through, "mismatched" messages list
        the expected literal tokens when there are fewer than ten, and
        everything else gets a generic text.
        """
        fmt = "%s\n%s\n%s"
        marker = "~" * col + "^"

        if msg.startswith("missing"):
            err = fmt % (msg, self.src, marker)
        elif msg.startswith("no viable"):
            err = fmt % ("I expected something else here", self.src, marker)
        elif msg.startswith("mismatched"):
            names = PSParser.literalNames
            # Only token types that have a literal name can be shown.
            expected = [names[i] for i in e.getExpectedTokens() if i < len(names)]
            if len(expected) < 10:
                expected = " ".join(expected)
                err = fmt % ("I expected one of these: " + expected, self.src, marker)
            else:
                err = fmt % ("I expected something else here", self.src, marker)
        else:
            err = fmt % ("I don't understand this", self.src, marker)
        raise Exception(err)
|
|
|
|
|
|
def convert_relation(rel):
    """Convert a ``relation`` parse node.

    A node that is a plain expression is delegated to ``convert_expr``.
    Otherwise both sides are converted recursively and combined according
    to the operator token present:

    * comparison / equality tokens give the unevaluated sympy relation;
    * ASSIGNMENT with a Symbol on the left records the value in the global
      ``variances`` / ``var`` tables and returns the right-hand side;
      with any other left-hand side, ``lh - rh`` is solved for every
      symbol it contains and a list of ``Eq`` objects is returned (or a
      single ``Eq`` when there are no symbols);
    * IN registers the left side as a ``MatrixSymbol`` in the global
      tables and returns the left side.

    Returns ``None`` when no operator token matches.
    """
    if rel.expr():
        return convert_expr(rel.expr())

    lh = convert_relation(rel.relation(0))
    rh = convert_relation(rel.relation(1))

    comparisons = (
        (rel.LT, sympy.StrictLessThan),
        (rel.LTE, sympy.LessThan),
        (rel.GT, sympy.StrictGreaterThan),
        (rel.GTE, sympy.GreaterThan),
        (rel.EQUAL, sympy.Eq),
    )
    for token, relation in comparisons:
        if token():
            return relation(lh, rh, evaluate=False)

    if rel.ASSIGNMENT():
        # Writes to the global binding tables.
        if lh.is_Symbol:
            variances[lh] = rh
            var[str(lh)] = rh
            return rh

        equation = lh - rh
        unknowns = equation.atoms(sympy.Symbol)
        if len(unknowns) == 0:
            return sympy.Eq(lh, rh, evaluate=False)
        return [
            sympy.Eq(unknown, value, evaluate=False)
            for unknown in unknowns
            for value in sympy.solve(equation, unknown)
        ]

    if rel.IN():
        # Writes to the global binding tables.
        # NOTE(review): the last check only tests that ``is_Mul`` exists on
        # the exponent, not that it is true - confirm whether intended.
        if not (hasattr(rh, "is_Pow") and rh.is_Pow and hasattr(rh.exp, "is_Mul")):
            raise Exception("Don't support this form of definition of matrix symbol.")
        n = rh.exp.args[0]
        m = rh.exp.args[1]
        # Dimensions already bound to a value are replaced by that value.
        if n in variances:
            n = variances[n]
        if m in variances:
            m = variances[m]
        rh = sympy.MatrixSymbol(lh, n, m)
        variances[lh] = rh
        var[str(lh)] = rh
        return lh

    if rel.UNEQUAL():
        return sympy.Ne(lh, rh, evaluate=False)
|
|
|
|
|
|
def convert_expr(expr):
    """Convert an ``expr`` parse node via its ``additive`` child.

    Returns ``None`` when the node has no ``additive`` child.
    """
    additive = expr.additive()
    if additive:
        return convert_add(additive)
|
|
|
|
|
|
def convert_elementary_transform(matrix, transform):
    """Apply one elementary row/column transform to ``matrix``.

    Three parse shapes are handled:

    * ``transform_scale``: multiply one row/column by ``k`` (``n->kn``);
    * ``transform_swap``: exchange two rows/columns (``n<->m``);
    * ``transform_assignment``: add ``k`` times one row/column to another
      (``n->n+km``).

    The letter of each transform atom selects rows (``"r"``) or columns
    (``"c"``); the number in the atom is 1-based and converted to a 0-based
    index. Returns the transformed matrix, or ``matrix`` unchanged when
    ``transform`` has none of the three shapes.

    Raises:
        Exception: when the letter is neither ``"r"`` nor ``"c"``, or when
            the two atoms of a swap/assignment use different letters.
    """

    def scale_factor(scale):
        # Explicit expression, parenthesised group, bare minus sign, or 1.
        if scale.expr():
            return convert_expr(scale.expr())
        if scale.group():
            return convert_expr(scale.group().expr())
        return -1 if scale.SUB() else 1

    def zero_based(atom):
        return int(atom.NUMBER().getText()) - 1

    def letter(atom):
        return atom.LETTER_NO_E().getText()

    if transform.transform_scale():
        scale = transform.transform_scale()
        atom = scale.transform_atom()
        num = zero_based(atom)
        k = scale_factor(scale)
        if letter(atom) == "r":
            matrix = matrix.elementary_row_op(op="n->kn", row=num, k=k)
        elif letter(atom) == "c":
            matrix = matrix.elementary_col_op(op="n->kn", col=num, k=k)
        else:
            raise Exception("Row and col don's match")

    elif transform.transform_swap():
        atoms = transform.transform_swap().transform_atom()
        first_atom = atoms[0]
        second_atom = atoms[1]
        first_num = zero_based(first_atom)
        second_num = zero_based(second_atom)
        if letter(first_atom) != letter(second_atom):
            raise Exception("Row and col don's match")
        elif letter(first_atom) == "r":
            matrix = matrix.elementary_row_op(
                op="n<->m", row1=first_num, row2=second_num
            )
        elif letter(first_atom) == "c":
            matrix = matrix.elementary_col_op(
                op="n<->m", col1=first_num, col2=second_num
            )
        else:
            raise Exception("Row and col don's match")

    elif transform.transform_assignment():
        assignment = transform.transform_assignment()
        first_atom = assignment.transform_atom()
        scale = assignment.transform_scale()
        second_atom = scale.transform_atom()
        k = scale_factor(scale)
        first_num = zero_based(first_atom)
        second_num = zero_based(second_atom)
        if letter(first_atom) != letter(second_atom):
            raise Exception("Row and col don's match")
        elif letter(first_atom) == "r":
            matrix = matrix.elementary_row_op(
                op="n->n+km", k=k, row1=first_num, row2=second_num
            )
        elif letter(first_atom) == "c":
            matrix = matrix.elementary_col_op(
                op="n->n+km", k=k, col1=first_num, col2=second_num
            )
        else:
            raise Exception("Row and col don's match")

    return matrix
|
|
|
|
|
|
def convert_matrix(matrix):
    """Build a ``sympy.Matrix`` from a ``matrix`` parse node.

    Every ``expr`` of every ``matrix_row`` is converted with
    ``convert_expr``. When the node carries a ``MATRIX_XRIGHTARROW`` token,
    the attached elementary transforms are applied afterwards.
    """
    entries = [
        [convert_expr(cell) for cell in row.expr()]
        for row in matrix.matrix_row()
    ]
    mat = sympy.Matrix(entries)

    if hasattr(matrix, "MATRIX_XRIGHTARROW") and matrix.MATRIX_XRIGHTARROW():
        transforms_list = matrix.elementary_transforms()
        if len(transforms_list) in (1, 2):
            # With two groups, index 1 (described in the original comments
            # as the top of \xrightarrow) is applied before index 0.
            for group in reversed(transforms_list):
                for transform in group.elementary_transform():
                    mat = convert_elementary_transform(mat, transform)

    return mat
|
|
|
|
|
|
def add_flat(lh, rh):
    """Return the unevaluated sum ``lh + rh`` without nesting ``Add`` nodes.

    An operand that is itself an ``Add`` contributes its arguments; any
    other operand (including plain Python numbers, which have no ``is_Add``
    attribute) contributes itself.
    """
    terms = []
    for operand in (lh, rh):
        if getattr(operand, "is_Add", False):
            terms.extend(operand.args)
        else:
            terms.append(operand)
    return sympy.Add(*terms, evaluate=False)
|
|
|
|
|
|
def mat_add_flat(lh, rh):
    """Return the unevaluated matrix sum of ``lh`` and ``rh``, flattened.

    An operand that is a ``MatAdd`` contributes its arguments instead of
    itself. ``.doit()`` is called on every resulting term before they are
    combined with ``sympy.MatAdd(..., evaluate=False)``.
    """
    terms = []
    for operand in (lh, rh):
        if getattr(operand, "is_MatAdd", False):
            terms.extend(operand.args)
        else:
            terms.append(operand)
    return sympy.MatAdd(*[term.doit() for term in terms], evaluate=False)
|
|
|
|
|
|
def mul_flat(lh, rh):
    """Return the unevaluated product ``lh * rh`` without nesting ``Mul`` nodes.

    An operand that is itself a ``Mul`` contributes its arguments; any
    other operand (including plain Python numbers, which have no ``is_Mul``
    attribute) contributes itself.
    """
    factors = []
    for operand in (lh, rh):
        if getattr(operand, "is_Mul", False):
            factors.extend(operand.args)
        else:
            factors.append(operand)
    return sympy.Mul(*factors, evaluate=False)
|
|
|
|
|
|
def mat_mul_flat(lh, rh):
    """Return the unevaluated matrix product of ``lh`` and ``rh``, flattened.

    An operand that is a ``MatMul`` contributes its arguments instead of
    itself. ``.doit()`` is applied to every factor that has such a method;
    factors without one (callers in this module pass the plain int ``-1``)
    are used as they are.

    The original only made that ``hasattr`` check when neither operand was
    a ``MatMul``. In the flattening path it called ``.doit()``
    unconditionally, so ``mat_mul_flat(-1, <MatMul>)`` raised
    ``AttributeError``.
    """
    factors = []
    for operand in (lh, rh):
        if getattr(operand, "is_MatMul", False):
            factors.extend(operand.args)
        else:
            factors.append(operand)
    evaluated = [
        factor.doit() if hasattr(factor, "doit") else factor
        for factor in factors
    ]
    return sympy.MatMul(*evaluated, evaluate=False)
|
|
|
|
|
|
def convert_add(add):
    """Convert an ``additive`` parse node.

    ADD and SUB nodes convert both children and combine them with the
    matrix or scalar flattening helper, depending on whether either side
    reports ``is_Matrix``. Any other node is passed to ``convert_mp``.
    """
    is_addition = bool(add.ADD())
    if not is_addition and not add.SUB():
        return convert_mp(add.mp())

    lh = convert_add(add.additive(0))
    rh = convert_add(add.additive(1))
    involves_matrix = lh.is_Matrix or rh.is_Matrix

    if is_addition:
        if involves_matrix:
            return mat_add_flat(lh, rh)
        return add_flat(lh, rh)

    # Subtraction is expressed as addition of the negated right side.
    if involves_matrix:
        return mat_add_flat(lh, mat_mul_flat(-1, rh))

    # A literal number is negated directly; anything else becomes an
    # unevaluated (-1) * rh.
    if rh.func.is_Number:
        negated = -rh
    else:
        negated = mul_flat(-1, rh)
    return add_flat(lh, negated)
|
|
|
|
|
|
def convert_mp(mp):
    """Convert an ``mp`` / ``mp_nofunc`` parse node (multiplicative level).

    Handles multiplication (``MUL``, ``CMD_TIMES``, ``CMD_CDOT``), division
    (``DIV``, ``CMD_DIV``, ``COLON``; expressed as multiplication by the
    power -1) and ``CMD_MOD``. A node without any of these tokens is passed
    on to ``convert_unary``.

    Raises:
        Exception: when the right operand of a modulo is a matrix.
    """
    # The two grammar variants expose their children under different names.
    child = mp.mp if hasattr(mp, "mp") else mp.mp_nofunc
    mp_left = child(0)
    mp_right = child(1)

    if mp.MUL() or mp.CMD_TIMES() or mp.CMD_CDOT():
        lh = convert_mp(mp_left)
        rh = convert_mp(mp_right)
        flatten = mat_mul_flat if (lh.is_Matrix or rh.is_Matrix) else mul_flat
        return flatten(lh, rh)

    if mp.DIV() or mp.CMD_DIV() or mp.COLON():
        lh = convert_mp(mp_left)
        rh = convert_mp(mp_right)
        product = sympy.MatMul if (lh.is_Matrix or rh.is_Matrix) else sympy.Mul
        return product(lh, sympy.Pow(rh, -1, evaluate=False), evaluate=False)

    if mp.CMD_MOD():
        lh = convert_mp(mp_left)
        rh = convert_mp(mp_right)
        if rh.is_Matrix:
            raise Exception(
                "Cannot perform modulo operation with a matrix as an operand"
            )
        return sympy.Mod(lh, rh, evaluate=False)

    unary = mp.unary() if hasattr(mp, "unary") else mp.unary_nofunc()
    return convert_unary(unary)
|
|
|
|
|
|
def convert_unary(unary):
    """Convert a ``unary`` / ``unary_nofunc`` parse node.

    A leading ``+`` is dropped. A leading ``-`` negates the converted
    operand: matrices through ``mat_mul_flat(-1, ...)``, numbers directly,
    everything else as an unevaluated ``(-1) * operand``. Without a sign,
    the node's postfix list is converted with ``convert_postfix_list``.
    Returns ``None`` when there is neither a sign nor a postfix list.

    The original passed ``evaluate=False`` to ``mat_mul_flat``, which
    accepts only two positional arguments, so negating a matrix raised
    ``TypeError``. ``mat_mul_flat`` already builds an unevaluated product,
    so the keyword is simply dropped.
    """
    if hasattr(unary, "unary"):
        nested_unary = unary.unary()
    else:
        nested_unary = unary.unary_nofunc()

    # In the *_nofunc variant the postfix list is split into a first
    # ``postfix`` followed by ``postfix_nofunc`` items.
    if hasattr(unary, "postfix_nofunc"):
        postfix = [unary.postfix()] + unary.postfix_nofunc()
    else:
        postfix = unary.postfix()

    if unary.ADD():
        return convert_unary(nested_unary)

    if unary.SUB():
        operand = convert_unary(nested_unary)
        if operand.is_Matrix:
            return mat_mul_flat(-1, operand)
        if operand.func.is_Number:
            return -operand
        return mul_flat(-1, operand)

    if postfix:
        return convert_postfix_list(postfix)
|
|
|
|
|
|
def convert_postfix_list(arr, i=0):
    """Convert ``arr[i:]`` (postfix parse nodes) into one expression.

    Adjacent items are treated as implicit multiplication: the item at
    ``i`` is converted and multiplied with the recursive result for
    ``i + 1``. A tuple, list or dict result is returned as is, without
    looking at the remaining items. Any other non-expression result is
    treated as a derivative operator whose first element is the variable
    to differentiate the remaining items by.

    Raises:
        Exception: when ``i`` is past the end of ``arr``, or when a
            derivative operator is the last item.
    """
    if i >= len(arr):
        raise Exception("Index out of bounds")

    res = convert_postfix(arr[i])
    is_last = i == len(arr) - 1

    if isinstance(res, (sympy.Expr, sympy.Matrix)) or res is sympy.S.EmptySet:
        if is_last:
            return res  # nothing to multiply by
        rh = convert_postfix_list(arr, i + 1)
        if res.is_Matrix or rh.is_Matrix:
            return mat_mul_flat(res, rh)
        return mul_flat(res, rh)

    # NOTE(review): lists are returned here, before the derivative branch
    # below - confirm how list-shaped derivative markers are meant to be
    # handled.
    if isinstance(res, (tuple, list, dict)):
        return res

    # must be derivative
    wrt = res[0]
    if is_last:
        raise Exception("Expected expression for derivative")
    return sympy.Derivative(convert_postfix_list(arr, i + 1), wrt)
|
|
|
|
|
|
def do_subs(expr, at):
    """Substitute into ``expr`` as described by an eval-at bound node.

    * If ``at`` holds an expression: with no symbols in it, ``expr`` is
      returned unchanged; otherwise one symbol taken from the converted
      expression is replaced by the whole converted expression.
    * If ``at`` holds an equality ``lhs = rhs``: ``lhs`` is replaced by
      ``rhs`` in ``expr``.

    Returns ``None`` when ``at`` holds neither.
    """
    if at.expr():
        at_expr = convert_expr(at.expr())
        symbols = at_expr.atoms(sympy.Symbol)
        if not symbols:
            return expr
        return expr.subs(next(iter(symbols)), at_expr)

    if at.equality():
        equality = at.equality()
        target = convert_expr(equality.expr(0))
        replacement = convert_expr(equality.expr(1))
        return expr.subs(target, replacement)
|
|
|
|
|
|
def convert_postfix(postfix):
    """Convert a ``postfix`` parse node: a base with trailing operators.

    The base is converted with ``convert_exp``; then each postfix operator
    is applied in order:

    * ``!`` wraps the value in an unevaluated factorial;
    * an eval-at bar substitutes the upper and/or lower bound via
      ``do_subs`` (upper minus lower when both are present);
    * a transpose uses ``.T`` and falls back to ``sympy.transpose``; if
      both fail the value is left unchanged.

    Raises:
        Exception: when ``!`` is applied to a list (derivative marker).
    """
    if hasattr(postfix, "exp"):
        exp_nested = postfix.exp()
    else:
        exp_nested = postfix.exp_nofunc()

    exp = convert_exp(exp_nested)
    for op in postfix.postfix_op():
        if op.BANG():
            if isinstance(exp, list):
                raise Exception("Cannot apply postfix to derivative")
            exp = sympy.factorial(exp, evaluate=False)
        elif op.eval_at():
            ev = op.eval_at()
            at_b = None
            at_a = None
            if ev.eval_at_sup():
                at_b = do_subs(exp, ev.eval_at_sup())
            if ev.eval_at_sub():
                at_a = do_subs(exp, ev.eval_at_sub())
            if at_b is not None and at_a is not None:
                exp = add_flat(at_b, mul_flat(at_a, -1))
            elif at_b is not None:
                exp = at_b
            elif at_a is not None:
                exp = at_a
        elif op.transpose():
            # Best effort, kept from the original. The handlers were bare
            # ``except:`` and are narrowed to Exception so that
            # KeyboardInterrupt / SystemExit are no longer swallowed.
            try:
                exp = exp.T
            except Exception:
                try:
                    exp = sympy.transpose(exp)
                except Exception:
                    pass

    return exp
|
|
|
|
|
|
def convert_exp(exp):
    """Convert an ``exp`` / ``exp_nofunc`` parse node (power level).

    A node with a nested ``exp`` child is a power: the child is the base
    and the node's ``atom`` or ``expr`` is the exponent. A node without
    one is passed on to ``convert_comp``.

    Raises:
        Exception: when the base converts to a list (derivative marker).
    """
    exp_nested = exp.exp() if hasattr(exp, "exp") else exp.exp_nofunc()

    if not exp_nested:
        comp = exp.comp() if hasattr(exp, "comp") else exp.comp_nofunc()
        return convert_comp(comp)

    base = convert_exp(exp_nested)
    if isinstance(base, list):
        raise Exception("Cannot raise derivative to power")

    if exp.atom():
        exponent = convert_atom(exp.atom())
    elif exp.expr():
        exponent = convert_expr(exp.expr())
    return sympy.Pow(base, exponent, evaluate=False)
|
|
|
|
|
|
def convert_comp(comp):
    """Convert a ``comp`` / ``comp_nofunc`` parse node.

    Dispatches on the first child kind present, in this order: group,
    norm group, abs group, floor group, ceil group, atom, frac, binom,
    matrix, det, func. Returns ``None`` when none is present.
    """
    group = comp.group()
    if group:
        return convert_expr(group.expr())

    norm_group = comp.norm_group()
    if norm_group:
        return convert_expr(norm_group.expr()).norm()

    abs_group = comp.abs_group()
    if abs_group:
        return sympy.Abs(convert_expr(abs_group.expr()), evaluate=False)

    floor_group = comp.floor_group()
    if floor_group:
        return handle_floor(convert_expr(floor_group.expr()))

    ceil_group = comp.ceil_group()
    if ceil_group:
        return handle_ceil(convert_expr(ceil_group.expr()))

    atom = comp.atom()
    if atom:
        return convert_atom(atom)

    frac = comp.frac()
    if frac:
        return convert_frac(frac)

    binom = comp.binom()
    if binom:
        return convert_binom(binom)

    matrix = comp.matrix()
    if matrix:
        return convert_matrix(matrix)

    det = comp.det()
    if det:
        # Reads the global ``variances``: known bindings are substituted
        # before the determinant is taken.
        return convert_matrix(det).subs(variances).det()

    func = comp.func()
    if func:
        return convert_func(func)
|
|
|
|
|
|
def convert_atom(atom):
    """Convert an ``atom`` parse node.

    Handled kinds, in order:

    * ``atom_expr``: a letter, Greek/other symbol command or accented
      name with optional subscript and superscript. ``I`` is the imaginary
      unit. A name present in the global ``var`` table is returned as a
      ``MatrixSymbol`` shaped like its stored value when that succeeds; a
      superscript yields an unevaluated power of the plain symbol.
    * ``SYMBOL``: ``\\infty``, ``\\pi`` or ``\\emptyset``.
    * ``NUMBER`` / ``E_NOTATION``: a ``Rational`` with ``Number`` as the
      fallback; thousands separators (commas) are removed first.
    * ``DIFFERENTIAL``: a symbol named ``"d" + <variable name>``.
    * ``mathit``: a symbol named by the source text of the group.
    * ``VARIABLE``: the value from ``VARIABLE_VALUES`` when known,
      otherwise a symbol whose name is suffixed with an MD5 digest; divided
      by 100 when the token ends in ``\\%``.
    * ``PERCENT_NUMBER``: the number divided by 100.

    Returns ``None`` for any other node.

    Raises:
        Exception: for an unrecognised ``SYMBOL``.

    Fix over the original: the function declared ``global var`` and the
    DIFFERENTIAL branch then assigned its result to ``var``. That replaced
    the module-level ``var`` dict with a Symbol and broke every later
    lookup in it. The differential now uses a local name and the global
    is only read.
    """
    if atom.atom_expr():
        atom_expr = atom.atom_expr()

        # find the atom's text
        atom_text = ""
        if atom_expr.LETTER_NO_E():
            atom_text = atom_expr.LETTER_NO_E().getText()
            if atom_text == "I":
                return sympy.I
        elif atom_expr.GREEK_CMD():
            atom_text = atom_expr.GREEK_CMD().getText()[1:].strip()
        elif atom_expr.OTHER_SYMBOL_CMD():
            atom_text = atom_expr.OTHER_SYMBOL_CMD().getText().strip()
        elif atom_expr.accent():
            atom_accent = atom_expr.accent()
            # The accent command text is used verbatim (no normalisation
            # of equivalent accents), followed by the braced base.
            name = atom_accent.start.text
            base = atom_accent.base.getText()
            atom_text = name + "{" + base + "}"

        # find atom's subscript, if any
        subscript_text = ""
        if atom_expr.subexpr():
            subexpr = atom_expr.subexpr()
            subscript = None
            if subexpr.expr():  # subscript is expr
                subscript = subexpr.expr().getText().strip()
            elif subexpr.atom():  # subscript is atom
                subscript = subexpr.atom().getText().strip()
            elif subexpr.args():  # subscript is args
                subscript = subexpr.args().getText().strip()
            subscript_inner_text = StrPrinter().doprint(subscript)
            # Braces are only added for multi-character subscripts.
            if len(subscript_inner_text) > 1:
                subscript_text = "_{" + subscript_inner_text + "}"
            else:
                subscript_text = "_" + subscript_inner_text

        # construct the symbol using the text and optional subscript
        symbol_text = atom_text + subscript_text
        atom_symbol = sympy.Symbol(symbol_text, real=is_real)

        # for matrix symbol
        matrix_symbol = None
        if symbol_text in var:
            # Best effort, kept from the original: any failure (e.g. the
            # stored value has no shape) leaves the plain symbol in use.
            # Narrowed from a bare ``except:``.
            try:
                rh = var[symbol_text]
                shape = sympy.shape(rh)
                matrix_symbol = sympy.MatrixSymbol(symbol_text, shape[0], shape[1])
                variances[matrix_symbol] = variances[atom_symbol]
            except Exception:
                pass

        # find the atom's superscript, and return as a Pow if found
        if atom_expr.supexpr():
            supexpr = atom_expr.supexpr()
            if supexpr.expr():
                func_pow = convert_expr(supexpr.expr())
            else:
                func_pow = convert_atom(supexpr.atom())
            return sympy.Pow(atom_symbol, func_pow, evaluate=False)

        return atom_symbol if not matrix_symbol else matrix_symbol

    elif atom.SYMBOL():
        s = atom.SYMBOL().getText().replace("\\$", "").replace("\\%", "")
        if s == "\\infty":
            return sympy.oo
        elif s == "\\pi":
            return sympy.pi
        elif s == "\\emptyset":
            return sympy.S.EmptySet
        else:
            raise Exception("Unrecognized symbol")

    elif atom.NUMBER():
        s = atom.NUMBER().getText().replace(",", "")
        try:
            return sympy.Rational(s)
        except (TypeError, ValueError):
            return sympy.Number(s)

    elif atom.E_NOTATION():
        s = atom.E_NOTATION().getText().replace(",", "")
        try:
            return sympy.Rational(s)
        except (TypeError, ValueError):
            return sympy.Number(s)

    elif atom.DIFFERENTIAL():
        # Local name on purpose: see the docstring.
        diff_var = get_differential_var(atom.DIFFERENTIAL())
        return sympy.Symbol("d" + diff_var.name, real=is_real)

    elif atom.mathit():
        text = rule2text(atom.mathit().mathit_text())
        return sympy.Symbol(text, real=is_real)

    elif atom.VARIABLE():
        text = atom.VARIABLE().getText()
        is_percent = text.endswith("\\%")
        trim_amount = 3 if is_percent else 1
        # Drop the 10-character prefix and the trailing delimiter (plus
        # the 2-character "\%" when present).
        name = text[10:]
        name = name[0 : len(name) - trim_amount]

        # add hash to distinguish from regular symbols
        digest = hashlib.md5(name.encode()).hexdigest()
        symbol_name = name + digest

        # replace the variable for already known variable values
        if name in VARIABLE_VALUES:
            # if a sympy class
            if isinstance(VARIABLE_VALUES[name], tuple(sympy.core.all_classes)):
                symbol = VARIABLE_VALUES[name]
            # if NOT a sympy class
            else:
                symbol = parse_expr(str(VARIABLE_VALUES[name]))
        else:
            symbol = sympy.Symbol(symbol_name, real=is_real)

        if is_percent:
            return sympy.Mul(symbol, sympy.Pow(100, -1, evaluate=False), evaluate=False)

        # return the symbol
        return symbol

    elif atom.PERCENT_NUMBER():
        text = atom.PERCENT_NUMBER().getText().replace("\\%", "").replace(",", "")
        try:
            number = sympy.Rational(text)
        except (TypeError, ValueError):
            number = sympy.Number(text)
        return sympy.Rational(number, 100)
|
|
|
|
|
|
def rule2text(ctx):
    """Return the source text covered by the parse node ``ctx``.

    The text runs from the start index of the node's first token to the
    stop index of its last token, read from the first token's input stream.
    """
    first_token = ctx.start
    last_token = ctx.stop
    source = first_token.getInputStream()
    return source.getText(first_token.start, last_token.stop)
|
|
|
|
|
|
def convert_frac(frac):
    """Convert a ``frac`` parse node.

    Three outcomes are possible:

    * a bare differential operator (``d`` over a differential, or
      ``\\partial`` over ``\\partial <var>``) returns ``[wrt]``, a
      one-element list holding the differentiation variable;
    * an operator applied to an expression in the numerator returns
      ``sympy.Derivative(expr, wrt)``;
    * anything else is an ordinary fraction, returned as an unevaluated
      product of the numerator and the denominator to the power -1
      (``MatMul`` when either side is a matrix).

    Fix over the original: the numerator of a derivative is re-parsed with
    a nested ``latex2sympy`` call. That call was made without variable
    values, which reset the global ``VARIABLE_VALUES`` to an empty dict
    for the rest of the outer parse. The current values are now passed
    through, as ``convert_func`` does for its nested calls.
    """
    diff_op = False
    partial_op = False
    lower = frac.lower
    upper = frac.upper
    lower_itv = lower.getSourceInterval()
    lower_itv_len = lower_itv[1] - lower_itv[0] + 1

    if lower.start == lower.stop and lower.start.type == PSLexer.DIFFERENTIAL:
        # Denominator is a single differential token.
        wrt = get_differential_var_str(lower.start.text)
        diff_op = True
    elif (
        lower_itv_len == 2
        and lower.start.type == PSLexer.SYMBOL
        and lower.start.text == "\\partial"
        and (
            lower.stop.type == PSLexer.LETTER_NO_E
            or lower.stop.type == PSLexer.SYMBOL
        )
    ):
        # Denominator is "\partial" followed by exactly one token.
        partial_op = True
        wrt = lower.stop.text
        if lower.stop.type == PSLexer.SYMBOL:
            wrt = wrt[1:]  # drop the leading backslash

    if diff_op or partial_op:
        wrt = sympy.Symbol(wrt, real=is_real)
        if (
            diff_op
            and upper.start == upper.stop
            and upper.start.type == PSLexer.LETTER_NO_E
            and upper.start.text == "d"
        ):
            return [wrt]
        elif (
            partial_op
            and upper.start == upper.stop
            and upper.start.type == PSLexer.SYMBOL
            and upper.start.text == "\\partial"
        ):
            return [wrt]

        upper_text = rule2text(upper)

        expr_top = None
        if diff_op and upper_text.startswith("d"):
            expr_top = latex2sympy(upper_text[1:], VARIABLE_VALUES)
        elif partial_op and upper.start.text == "\\partial":
            expr_top = latex2sympy(upper_text[len("\\partial") :], VARIABLE_VALUES)
        if expr_top:
            return sympy.Derivative(expr_top, wrt)

    # Ordinary fraction.
    expr_top = convert_expr(upper)
    expr_bot = convert_expr(lower)
    if expr_top.is_Matrix or expr_bot.is_Matrix:
        return sympy.MatMul(
            expr_top, sympy.Pow(expr_bot, -1, evaluate=False), evaluate=False
        )
    else:
        return sympy.Mul(
            expr_top, sympy.Pow(expr_bot, -1, evaluate=False), evaluate=False
        )
|
|
|
|
|
|
def convert_binom(binom):
    """Convert a ``binom`` parse node to ``sympy.binomial(upper, lower)``."""
    upper = convert_expr(binom.upper)
    lower = convert_expr(binom.lower)
    return sympy.binomial(upper, lower)
|
|
|
|
|
|
def convert_func(func):
    """Convert a ``func`` parse node.

    Handled shapes, in order:

    * single-argument named functions (trigonometric and hyperbolic
      functions and their inverses, ``log``/``ln``, ``exp``, ``floor``,
      ``ceil``, ``det`` and ``\\operatorname{...}`` matrix helpers), with
      an optional superscript applied as a power. ``^{-1}`` on a
      trigonometric or hyperbolic function selects its inverse function
      instead;
    * multi-argument named functions (``gcd``, ``lcm``, ``max``, ``min``
      and ``\\operatorname{...}`` matrix builders). Each comma-separated
      argument is re-parsed with ``latex2sympy``;
    * user-defined functions such as ``f(x, y)``;
    * integrals, square/n-th roots, sums, products, limits and ``e^...``,
      delegated to their handlers.

    Returns ``None`` when the node matches none of these.

    Fix over the original: for a multi-argument function written without
    parentheses, ``.split(",")`` was called directly on the parse node
    instead of on its text, which raised ``AttributeError``.
    """
    if func.func_normal_single_arg():
        if func.L_PAREN():  # function called with parenthesis
            arg = convert_func_arg(func.func_single_arg())
        else:
            arg = convert_func_arg(func.func_single_arg_noparens())

        # Command name without the leading backslash.
        name = func.func_normal_single_arg().start.text[1:]

        # change arc<trig> -> a<trig>
        if name in ["arcsin", "arccos", "arctan", "arccsc", "arcsec", "arccot"]:
            name = "a" + name[3:]
            expr = getattr(sympy.functions, name)(arg, evaluate=False)
        elif name in ["arsinh", "arcosh", "artanh"]:
            name = "a" + name[2:]
            expr = getattr(sympy.functions, name)(arg, evaluate=False)
        elif name in ["arcsinh", "arccosh", "arctanh"]:
            name = "a" + name[3:]
            expr = getattr(sympy.functions, name)(arg, evaluate=False)
        elif name == "operatorname":
            operatorname = func.func_normal_single_arg().func_operator_name.getText()

            if operatorname in ["arsinh", "arcosh", "artanh"]:
                operatorname = "a" + operatorname[2:]
                expr = getattr(sympy.functions, operatorname)(arg, evaluate=False)
            elif operatorname in ["arcsinh", "arccosh", "arctanh"]:
                operatorname = "a" + operatorname[3:]
                expr = getattr(sympy.functions, operatorname)(arg, evaluate=False)
            elif operatorname == "floor":
                expr = handle_floor(arg)
            elif operatorname == "ceil":
                expr = handle_ceil(arg)
            elif operatorname == "eye":
                expr = sympy.eye(arg)
            elif operatorname == "rank":
                expr = sympy.Integer(arg.rank())
            elif operatorname in ["trace", "tr"]:
                expr = arg.trace()
            elif operatorname == "rref":
                expr = arg.rref()[0]
            elif operatorname == "nullspace":
                expr = arg.nullspace()
            elif operatorname == "norm":
                expr = arg.norm()
            elif operatorname == "cols":
                expr = [arg.col(i) for i in range(arg.cols)]
            elif operatorname == "rows":
                expr = [arg.row(i) for i in range(arg.rows)]
            elif operatorname in ["eig", "eigen", "diagonalize"]:
                expr = arg.diagonalize()
            elif operatorname in ["eigenvals", "eigenvalues"]:
                expr = arg.eigenvals()
            elif operatorname in ["eigenvects", "eigenvectors"]:
                expr = arg.eigenvects()
            elif operatorname in ["svd", "SVD"]:
                expr = arg.singular_value_decomposition()
        elif name in ["log", "ln"]:
            # An explicit subscript gives the base; otherwise log is
            # base 10 and ln is base e.
            if func.subexpr():
                if func.subexpr().atom():
                    base = convert_atom(func.subexpr().atom())
                else:
                    base = convert_expr(func.subexpr().expr())
            elif name == "log":
                base = 10
            elif name == "ln":
                base = sympy.E
            expr = sympy.log(arg, base, evaluate=False)
        elif name in ["exp", "exponentialE"]:
            expr = sympy.exp(arg)
        elif name == "floor":
            expr = handle_floor(arg)
        elif name == "ceil":
            expr = handle_ceil(arg)
        elif name == "det":
            expr = arg.det()

        func_pow = None
        should_pow = True
        if func.supexpr():
            if func.supexpr().expr():
                func_pow = convert_expr(func.supexpr().expr())
            else:
                func_pow = convert_atom(func.supexpr().atom())

        if name in ["sin", "cos", "tan", "csc", "sec", "cot", "sinh", "cosh", "tanh"]:
            # f^{-1} selects the inverse function rather than a power.
            if func_pow == -1:
                name = "a" + name
                should_pow = False
            expr = getattr(sympy.functions, name)(arg, evaluate=False)

        if func_pow and should_pow:
            expr = sympy.Pow(expr, func_pow, evaluate=False)

        return expr

    elif func.func_normal_multi_arg():
        if func.L_PAREN():  # function called with parenthesis
            args = func.func_multi_arg().getText().split(",")
        else:
            # getText() was missing here in the original.
            args = func.func_multi_arg_noparens().getText().split(",")

        # Every argument is re-parsed from its source text.
        args = [latex2sympy(arg, VARIABLE_VALUES) for arg in args]
        name = func.func_normal_multi_arg().start.text[1:]

        if name == "operatorname":
            operatorname = func.func_normal_multi_arg().func_operator_name.getText()
            if operatorname in ["gcd", "lcm"]:
                expr = handle_gcd_lcm(operatorname, args)
            elif operatorname == "zeros":
                expr = sympy.zeros(*args)
            elif operatorname == "ones":
                expr = sympy.ones(*args)
            elif operatorname == "diag":
                expr = sympy.diag(*args)
            elif operatorname == "hstack":
                expr = sympy.Matrix.hstack(*args)
            elif operatorname == "vstack":
                expr = sympy.Matrix.vstack(*args)
            elif operatorname in ["orth", "ortho", "orthogonal", "orthogonalize"]:
                if len(args) == 1:
                    # A single matrix argument is split into its columns.
                    arg = args[0]
                    expr = sympy.matrices.GramSchmidt(
                        [arg.col(i) for i in range(arg.cols)], True
                    )
                else:
                    expr = sympy.matrices.GramSchmidt(args, True)
        elif name in ["gcd", "lcm"]:
            expr = handle_gcd_lcm(name, args)
        elif name in ["max", "min"]:
            # sympy names these Max / Min.
            name = name[0].upper() + name[1:]
            expr = getattr(sympy.functions, name)(*args, evaluate=False)

        func_pow = None
        if func.supexpr():
            if func.supexpr().expr():
                func_pow = convert_expr(func.supexpr().expr())
            else:
                func_pow = convert_atom(func.supexpr().atom())

        if func_pow:
            expr = sympy.Pow(expr, func_pow, evaluate=False)

        return expr

    elif func.atom_expr_no_supexpr():
        # define a function
        f = sympy.Function(func.atom_expr_no_supexpr().getText())
        # args
        args = func.func_common_args().getText().split(",")
        if args[-1] == "":
            args = args[:-1]
        args = [latex2sympy(arg, VARIABLE_VALUES) for arg in args]
        # supexpr
        if func.supexpr():
            if func.supexpr().expr():
                expr = convert_expr(func.supexpr().expr())
            else:
                expr = convert_atom(func.supexpr().atom())
            return sympy.Pow(f(*args), expr, evaluate=False)
        else:
            return f(*args)

    elif func.FUNC_INT():
        return handle_integral(func)

    elif func.FUNC_SQRT():
        expr = convert_expr(func.base)
        if func.root:
            r = convert_expr(func.root)
            return sympy.Pow(expr, 1 / r, evaluate=False)
        else:
            return sympy.Pow(expr, sympy.S.Half, evaluate=False)

    elif func.FUNC_SUM():
        return handle_sum_or_prod(func, "summation")

    elif func.FUNC_PROD():
        return handle_sum_or_prod(func, "product")

    elif func.FUNC_LIM():
        return handle_limit(func)

    elif func.EXP_E():
        return handle_exp(func)
|
|
|
|
|
|
def convert_func_arg(arg):
    """Convert a function-argument parse node.

    Nodes exposing ``expr`` are converted from that expression; the others
    are converted from their ``mp_nofunc`` child.
    """
    if not hasattr(arg, "expr"):
        return convert_mp(arg.mp_nofunc())
    return convert_expr(arg.expr())
|
|
|
|
|
|
def handle_integral(func):
    """Convert an integral ``func`` node to ``sympy.Integral``.

    The integrand comes from the node's ``additive`` or ``frac`` child and
    is ``1`` when it has neither. The integration variable comes from the
    node's DIFFERENTIAL token when present. Otherwise the integrand is
    searched for a symbol whose name starts with ``d``: the last such
    symbol found is replaced by 1 in the integrand and the rest of its
    name is the variable. ``x`` is the fallback.

    With a subscript the result is a definite integral using the subscript
    and superscript as bounds; otherwise it is indefinite.
    """
    if func.additive():
        integrand = convert_add(func.additive())
    elif func.frac():
        integrand = convert_frac(func.frac())
    else:
        integrand = 1

    if func.DIFFERENTIAL():
        int_var = get_differential_var(func.DIFFERENTIAL())
    else:
        int_var = None
        for sym in integrand.atoms(sympy.Symbol):
            text = str(sym)
            if len(text) > 1 and text[0] == "d":
                # Skip the "d", and a backslash directly after it if any.
                var_name = text[2:] if text[1] == "\\" else text[1:]
                int_var = sympy.Symbol(var_name, real=is_real)
                int_sym = sym
        if int_var:
            integrand = integrand.subs(int_sym, 1)
        else:
            # Assume dx by default
            int_var = sympy.Symbol("x", real=is_real)

    if not func.subexpr():
        return sympy.Integral(integrand, int_var)

    sub = func.subexpr()
    lower = convert_atom(sub.atom()) if sub.atom() else convert_expr(sub.expr())
    sup = func.supexpr()
    upper = convert_atom(sup.atom()) if sup.atom() else convert_expr(sup.expr())
    return sympy.Integral(integrand, (int_var, lower, upper))
|
|
|
|
|
|
def handle_sum_or_prod(func, name):
    """Build an unevaluated ``sympy.Sum`` or ``sympy.Product``.

    func: parse node carrying the summand/factor (``mp``), the ``var = start``
        subscript (``subeq``) and the upper bound (``supexpr``)
    name: str - "summation" selects ``Sum``, "product" selects ``Product``;
        any other value yields ``None``
    """
    val = convert_mp(func.mp())

    equality = func.subeq().equality()
    iter_var = convert_expr(equality.expr(0))
    start = convert_expr(func.subeq().equality().expr(1))

    upper_ctx = func.supexpr()
    end = (
        convert_expr(func.supexpr().expr())  # ^{expr}
        if upper_ctx.expr()
        else convert_atom(func.supexpr().atom())  # ^atom
    )

    limits = (iter_var, start, end)
    if name == "summation":
        return sympy.Sum(val, limits)
    if name == "product":
        return sympy.Product(val, limits)
    return None
|
|
|
|
|
|
def handle_limit(func):
    """Build an unevaluated ``sympy.Limit`` from a limit parse node.

    The limit variable is taken from the subscript: a plain letter, a Greek
    command (leading backslash removed), another symbol command (text kept
    as-is apart from surrounding whitespace), or ``x`` when none is present.
    A SUB token in the subscript selects direction "-"; otherwise the
    direction is "+".
    """
    sub = func.limit_sub()

    # Work out the variable's name first, then build the symbol once.
    if sub.LETTER_NO_E():
        var_name = sub.LETTER_NO_E().getText()
    elif sub.GREEK_CMD():
        var_name = sub.GREEK_CMD().getText()[1:].strip()
    elif sub.OTHER_SYMBOL_CMD():
        var_name = sub.OTHER_SYMBOL_CMD().getText().strip()
    else:
        var_name = "x"
    limit_var = sympy.Symbol(var_name, real=is_real)

    direction = "-" if sub.SUB() else "+"

    approaching = convert_expr(sub.expr())
    content = convert_mp(func.mp())

    return sympy.Limit(content, limit_var, approaching, direction)
|
|
|
|
|
|
def handle_exp(func):
    """Return ``sympy.exp`` of the node's superscript.

    A missing superscript means a bare ``e``, i.e. ``exp(1)``.
    """
    sup = func.supexpr()
    if not sup:
        return sympy.exp(1)

    braced = func.supexpr().expr()
    if braced:  # ^{expr}
        exponent = convert_expr(braced)
    else:  # ^atom
        exponent = convert_atom(func.supexpr().atom())
    return sympy.exp(exponent)
|
|
|
|
|
|
def handle_gcd_lcm(f, args):
    """Compute gcd() or lcm() of the arguments, wrapped as UnevaluatedExpr.

    f: str - name of the sympy function to look up ("gcd" or "lcm")
    args: List[Expr] - list of function arguments
    """
    # Each argument goes through nsimplify before the call.
    operands = tuple(sympy.nsimplify(operand) for operand in args)

    # gcd() and lcm() don't support evaluate=False, so the computed result is
    # wrapped instead; the operands are passed as one sequence.
    operation = getattr(sympy, f)
    return sympy.UnevaluatedExpr(operation(operands))
|
|
|
|
|
|
def handle_floor(expr):
    """Wrap ``expr`` in an unevaluated sympy ``floor``.

    expr: Expr - sympy expression as an argument to floor()
    """
    floor_fn = sympy.functions.floor
    return floor_fn(expr, evaluate=False)
|
|
|
|
|
|
def handle_ceil(expr):
    """Wrap ``expr`` in an unevaluated sympy ``ceiling``.

    expr: Expr - sympy expression as an argument to ceil()
    """
    ceiling_fn = sympy.functions.ceiling
    return ceiling_fn(expr, evaluate=False)
|
|
|
|
|
|
def get_differential_var(d):
    """Return the sympy symbol named by a differential token such as ``dx``.

    d: token/node whose ``getText()`` gives the differential's source text
    """
    var_name = get_differential_var_str(d.getText())
    return sympy.Symbol(var_name, real=is_real)
|
|
|
|
|
|
def get_differential_var_str(text):
    """Extract the variable name from differential text such as ``"dx"``.

    The first character (the ``d``) is dropped, then any whitespace between
    it and the variable, then one leading backslash if the variable is a
    LaTeX command (``"d \\theta"`` -> ``"theta"``).

    text: str - source text of the differential
    Returns: str - the variable name, or "" when no variable follows the
        leading character (previously this raised UnboundLocalError because
        the start index was only assigned inside the scan loop).
    """
    # Same whitespace set the original character-by-character scan skipped.
    name = text[1:].lstrip(" \r\n\t")
    if name.startswith("\\"):
        name = name[1:]
    return name
|
|
|
|
|
|
def latex(tex):
    """Render a sympy object as LaTeX in this module's house style.

    Starting from ``sympy.latex(tex)``, the output is post-processed, in
    order: every fraction macro is rewritten to the module-level
    ``frac_type``; bracketed ``matrix`` environments become ``bmatrix``;
    ``\\left`` / ``\\right`` are removed; a space before ``)`` is removed;
    and ``\\log`` becomes ``\\ln``.
    """
    rendered = sympy.latex(tex)

    # Fraction macros first, all mapped onto the currently selected style.
    for frac_macro in (r"\frac", r"\dfrac", r"\tfrac"):
        rendered = rendered.replace(frac_macro, frac_type, -1)

    # The order matters: the matrix rewrites must see "\left[" / "\right]"
    # before the generic "\left" / "\right" removal runs.
    rewrites = (
        (r"\left[\begin{matrix}", r"\begin{bmatrix}"),
        (r"\end{matrix}\right]", r"\end{bmatrix}"),
        (r"\left", r""),
        (r"\right", r""),
        (r" )", r")"),
        (r"\log", r"\ln"),
    )
    for old, new in rewrites:
        rendered = rendered.replace(old, new, -1)
    return rendered
|
|
|
|
|
|
def latex2latex(tex):
    """Parse LaTeX, evaluate it, and render the result back to LaTeX.

    A list, tuple or dict returned by ``latex2sympy`` is rendered directly.
    Anything else has the module-level ``variances`` substituted in, is
    evaluated with ``doit()`` twice, simplified, and then rendered.
    """
    parsed = latex2sympy(tex)

    # Containers are rendered as they are, without substitution or evaluation.
    if isinstance(parsed, (list, tuple, dict)):
        return latex(parsed)

    evaluated = parsed.subs(variances).doit().doit()
    return latex(simplify(evaluated))
|
|
|
|
|
|
# Set image value
# NOTE(review): presumably "image" means the imaginary unit, and evaluating
# "i=I" / "j=I" makes the parser record i and j as sympy's I in `variances`.
# The assignment handling is outside this section - confirm.
latex2latex("i=I")
latex2latex("j=I")
# set Identity(i)
# Registers the names \bm{I}_1 ... \bm{I}_9 at import time.  Each name is
# entered in `variances` twice - once as a plain Symbol (real=False) and once
# as an i-by-i MatrixSymbol - and both map to the same mutable i-by-i identity
# matrix.  `var` gets the same matrix keyed by the Symbol's string form.
# The loop variable and the lh / lh_m / rh names remain module-level globals
# after the loop finishes.
for i in range(1, 10):
    lh = sympy.Symbol(r"\bm{I}_" + str(i), real=False)
    lh_m = sympy.MatrixSymbol(r"\bm{I}_" + str(i), i, i)
    rh = sympy.Identity(i).as_mutable()
    variances[lh] = rh
    variances[lh_m] = rh
    var[str(lh)] = rh
|
|
|
|
if __name__ == "__main__":
    # Manual smoke test: parse one expression, show the raw sympy object,
    # then show the fully evaluated LaTeX.
    #
    # Other inputs and probes kept around for debugging by hand:
    #   latex2latex(r'A_1=\begin{bmatrix}1 & 2 & 3 & 4 \\ 5 & 6 & 7 & 8\end{bmatrix}')
    #   latex2latex(r'b_1=\begin{bmatrix}1 \\ 2 \\ 3 \\ 4\end{bmatrix}')
    #   tex = r"(x+2)|_{x=y+1}"
    #   tex = r"\operatorname{zeros}(3)"
    #   parsed = parsed.subs(variances)
    #   print("latex2latex:", latex2latex(tex))
    #   print("var:", variances)
    #   print("math:", latex(parsed.doit()))
    #   print("math_type:", type(parsed.doit()))
    #   print("shape:", (parsed.doit()).shape)
    tex = r"\operatorname{rows}(\begin{bmatrix}1 & 2 \\ 3 & 4\end{bmatrix})"
    parsed = latex2sympy(tex)

    print("latex:", tex)
    print("raw_math:", parsed)
    print("cal:", latex2latex(tex))
|