# AReaL/evaluation/latex2sympy/latex2sympy2.py
#
# (source-viewer metadata: 1226 lines, 40 KiB, Python)

import hashlib
import re

import sympy
from antlr4 import CommonTokenStream, InputStream
from antlr4.error.ErrorListener import ErrorListener
from sympy import apart, expand, expand_trig, factor, matrix_symbols, simplify
from sympy.parsing.sympy_parser import parse_expr
from sympy.printing.str import StrPrinter

try:
    from gen.PSLexer import PSLexer
    from gen.PSListener import PSListener
    from gen.PSParser import PSParser
except Exception:
    from .gen.PSParser import PSParser
    from .gen.PSLexer import PSLexer
    from .gen.PSListener import PSListener
# Value passed as ``real=`` whenever this module creates a sympy.Symbol;
# None leaves the assumption unset. Changed through set_real().
is_real = None
# Fraction command seen in the most recent latex2sympy() input; latex()
# rewrites sympy's ``\frac`` output back into this form.
frac_type = r"\frac"
# Symbol -> value bindings recorded by assignments / ``\in`` definitions in
# convert_relation(); substituted back by latex2latex() and determinants.
variances = {}
# The same bindings keyed by str(symbol).
var = {}
# Replacement values for VARIABLE tokens during the current latex2sympy() call.
VARIABLE_VALUES = {}
def set_real(value):
    """Set the module-wide ``real=`` assumption used for newly created symbols."""
    globals()["is_real"] = value
def set_variances(vars):
    """Replace the global variable bindings.

    ``variances`` becomes the given mapping object itself. ``var`` is rebuilt
    as a fresh dict holding the same values keyed by ``str(key)``.
    """
    global variances, var
    variances = vars
    var = {str(key): vars[key] for key in vars}
def _normalize_latex(text):
    r"""Rewrite ``text`` into the LaTeX dialect the PS grammar parses.

    Side effect: records in the module global ``frac_type`` which fraction
    command the input used. ``\tfrac`` wins over ``\dfrac``, which wins over
    ``\frac``, so that latex() can emit the same style later.
    """
    global frac_type
    # record frac (later checks override earlier ones)
    if text.find(r"\frac") != -1:
        frac_type = r"\frac"
    if text.find(r"\dfrac") != -1:
        frac_type = r"\dfrac"
    if text.find(r"\tfrac") != -1:
        frac_type = r"\tfrac"
    text = text.replace(r"\dfrac", r"\frac")
    text = text.replace(r"\tfrac", r"\frac")
    # Translate Transpose
    text = text.replace(r"\mathrm{T}", "T")
    # Translate Derivative
    text = text.replace(r"\mathrm{d}", "d").replace(r"{\rm d}", "d")
    # Translate Matrix
    text = text.replace(r"\left[\begin{matrix}", r"\begin{bmatrix}").replace(
        r"\end{matrix}\right]", r"\end{bmatrix}"
    )
    # Translate Permutation: (n)_{k} -> (n)! / ((n)-(k))!
    text = re.sub(
        r"\(([a-zA-Z0-9+\-*/\\ ]+?)\)_{([a-zA-Z0-9+\-*/\\ ]+?)}",
        r"\\frac{(\1)!}{((\1)-(\2))!}",
        text,
    )
    # Replace \displaystyle, spacing commands, ~ and $ with a space (in this order)
    for token in (r"\displaystyle", r"\quad", r"\qquad", r"~", r"\,", r"$"):
        text = text.replace(token, " ")
    return text


def latex2sympy(sympy: str, variable_values=None):
    """Parse a LaTeX string into a sympy object.

    Args:
        sympy: the LaTeX source. (The name shadows the ``sympy`` module inside
            this function only; it is kept for keyword-argument compatibility.)
        variable_values: optional mapping used to replace VARIABLE tokens.
            It becomes the global ``VARIABLE_VALUES``. A missing or empty
            mapping resets that global to a fresh empty dict.

    Returns:
        A list of converted relations when the input is a relation list,
        otherwise the single converted relation/expression.

    Raises:
        Exception: raised by MathErrorListener on lexer/parser errors.
    """
    global VARIABLE_VALUES

    text = _normalize_latex(sympy)

    # ``None`` replaces the former mutable ``{}`` default; behaviour is the same.
    VARIABLE_VALUES = variable_values if variable_values else {}

    # setup listener
    matherror = MathErrorListener(text)

    # stream input
    stream = InputStream(text)
    lex = PSLexer(stream)
    lex.removeErrorListeners()
    lex.addErrorListener(matherror)

    tokens = CommonTokenStream(lex)
    parser = PSParser(tokens)

    # remove default console error listener
    parser.removeErrorListeners()
    parser.addErrorListener(matherror)

    # process the input
    math = parser.math()

    # a list of relations -> list of converted items
    if math.relation_list():
        relation_list = math.relation_list().relation_list_content()
        return [convert_relation(item) for item in relation_list.relation()]

    # otherwise a single relation
    return convert_relation(math.relation())
class MathErrorListener(ErrorListener):
    """ANTLR error listener that turns syntax errors into exceptions.

    The raised message contains a short explanation, the source text, and a
    ``~~~^`` marker under the offending column.
    """

    def __init__(self, src):
        # The previous super(ErrorListener, self) skipped ErrorListener's own
        # __init__ in the MRO.
        super().__init__()
        self.src = src

    def syntaxError(self, recog, symbol, line, col, msg, e):
        """Raise an Exception describing the syntax error at ``col``."""
        fmt = "%s\n%s\n%s"
        marker = "~" * col + "^"

        if msg.startswith("missing"):
            err = fmt % (msg, self.src, marker)
        elif msg.startswith("no viable"):
            err = fmt % ("I expected something else here", self.src, marker)
        elif msg.startswith("mismatched"):
            names = PSParser.literalNames
            # ``e`` can be None; the expected set can contain Token.EOF (-1),
            # which must not be used as a (negative) index into ``names``.
            token_types = e.getExpectedTokens() if e is not None else ()
            expected = [names[i] for i in token_types if 0 <= i < len(names)]
            if len(expected) < 10:
                expected = " ".join(expected)
                err = fmt % ("I expected one of these: " + expected, self.src, marker)
            else:
                err = fmt % ("I expected something else here", self.src, marker)
        else:
            err = fmt % ("I don't understand this", self.src, marker)
        raise Exception(err)
def convert_relation(rel):
    """Convert a ``relation`` parse node into a sympy object.

    A node holding a bare expression is converted directly. Otherwise both
    sides are converted and combined by operator token:

    * LT / LTE / GT / GTE / EQUAL / UNEQUAL -> unevaluated sympy relationals.
    * ASSIGNMENT with a Symbol on the left binds it in the global
      ``variances``/``var`` tables and returns the right-hand side. With any
      other left side, ``lh - rh`` is solved for each free symbol and a list
      of ``Eq(symbol, solution)`` is returned (``Eq(lh, rh)`` if no symbols).
    * IN with a right side ``base**(n*m)`` (a two-factor product exponent)
      binds the left side to an ``n x m`` MatrixSymbol and returns the left
      side. Any other IN form raises Exception.

    Returns None when no recognized operator is present.
    """
    if rel.expr():
        return convert_expr(rel.expr())

    lh = convert_relation(rel.relation(0))
    rh = convert_relation(rel.relation(1))
    if rel.LT():
        return sympy.StrictLessThan(lh, rh, evaluate=False)
    elif rel.LTE():
        return sympy.LessThan(lh, rh, evaluate=False)
    elif rel.GT():
        return sympy.StrictGreaterThan(lh, rh, evaluate=False)
    elif rel.GTE():
        return sympy.GreaterThan(lh, rh, evaluate=False)
    elif rel.EQUAL():
        return sympy.Eq(lh, rh, evaluate=False)
    elif rel.ASSIGNMENT():
        # !Use Global variances
        if lh.is_Symbol:
            variances[lh] = rh
            var[str(lh)] = rh
            return rh
        # Not a plain symbol: solve lh - rh = 0 for every free symbol.
        equation = lh - rh
        syms = equation.atoms(sympy.Symbol)
        if len(syms) > 0:
            result = []
            for sym in syms:
                for value in sympy.solve(equation, sym):
                    result.append(sympy.Eq(sym, value, evaluate=False))
            return result
        return sympy.Eq(lh, rh, evaluate=False)
    elif rel.IN():
        # !Use Global variances
        # The old check hasattr(rh.exp, "is_Mul") was true for every sympy
        # object, so e.g. ``x \in R^2`` failed with IndexError on args[1].
        exponent = rh.exp if getattr(rh, "is_Pow", False) else None
        if exponent is not None and exponent.is_Mul and len(exponent.args) == 2:
            n, m = exponent.args
            if n in variances:
                n = variances[n]
            if m in variances:
                m = variances[m]
            rh = sympy.MatrixSymbol(lh, n, m)
            variances[lh] = rh
            var[str(lh)] = rh
        else:
            raise Exception("Don't support this form of definition of matrix symbol.")
        return lh
    elif rel.UNEQUAL():
        return sympy.Ne(lh, rh, evaluate=False)
    return None
def convert_expr(expr):
    """Convert an ``expr`` node via its additive child; None if it has none."""
    additive = expr.additive()
    if not additive:
        return None
    return convert_add(additive)
def _transform_atom_index(atom):
    """Return (letter, zero-based index) of a ``transform_atom`` node."""
    return atom.LETTER_NO_E().getText(), int(atom.NUMBER().getText()) - 1


def _transform_scale_factor(transform_scale):
    """Return the scale factor k of a ``transform_scale`` node (default 1)."""
    if transform_scale.expr():
        return convert_expr(transform_scale.expr())
    if transform_scale.group():
        return convert_expr(transform_scale.group().expr())
    if transform_scale.SUB():
        return -1
    return 1


def convert_elementary_transform(matrix, transform):
    """Apply one elementary row ("r") or column ("c") operation to ``matrix``.

    Supported forms: scaling (``n->kn``), swapping (``n<->m``) and adding a
    multiple of another line (``n->n+km``). Atom numbers are 1-based in the
    source and converted to 0-based indices. Returns ``matrix`` unchanged
    when the node is none of these.

    Raises:
        Exception: if the letters are not "r"/"c" or the two atoms of a swap
            or assignment disagree.
    """
    if transform.transform_scale():
        transform_scale = transform.transform_scale()
        kind, num = _transform_atom_index(transform_scale.transform_atom())
        k = _transform_scale_factor(transform_scale)
        if kind == "r":
            return matrix.elementary_row_op(op="n->kn", row=num, k=k)
        if kind == "c":
            return matrix.elementary_col_op(op="n->kn", col=num, k=k)
        raise Exception("Row and col don't match")

    if transform.transform_swap():
        atoms = transform.transform_swap().transform_atom()
        first_kind, first_num = _transform_atom_index(atoms[0])
        second_kind, second_num = _transform_atom_index(atoms[1])
        if first_kind != second_kind:
            raise Exception("Row and col don't match")
        if first_kind == "r":
            return matrix.elementary_row_op(op="n<->m", row1=first_num, row2=second_num)
        if first_kind == "c":
            return matrix.elementary_col_op(op="n<->m", col1=first_num, col2=second_num)
        raise Exception("Row and col don't match")

    if transform.transform_assignment():
        assignment = transform.transform_assignment()
        transform_scale = assignment.transform_scale()
        k = _transform_scale_factor(transform_scale)
        first_kind, first_num = _transform_atom_index(assignment.transform_atom())
        second_kind, second_num = _transform_atom_index(transform_scale.transform_atom())
        if first_kind != second_kind:
            raise Exception("Row and col don't match")
        if first_kind == "r":
            return matrix.elementary_row_op(
                op="n->n+km", k=k, row1=first_num, row2=second_num
            )
        if first_kind == "c":
            return matrix.elementary_col_op(
                op="n->n+km", k=k, col1=first_num, col2=second_num
            )
        raise Exception("Row and col don't match")

    return matrix
def convert_matrix(matrix):
    """Build a sympy Matrix from a ``matrix`` node, then apply any elementary
    transforms attached through MATRIX_XRIGHTARROW."""
    cells = [[convert_expr(cell) for cell in row.expr()] for row in matrix.matrix_row()]
    result = sympy.Matrix(cells)

    has_arrow = hasattr(matrix, "MATRIX_XRIGHTARROW") and matrix.MATRIX_XRIGHTARROW()
    if not has_arrow:
        return result

    groups = matrix.elementary_transforms()
    if len(groups) == 1:
        ordered = [groups[0]]
    elif len(groups) == 2:
        # Index 1 (described as the top of the arrow) is applied before
        # index 0 (the bottom).
        ordered = [groups[1], groups[0]]
    else:
        ordered = []

    for group in ordered:
        for transform in group.elementary_transform():
            result = convert_elementary_transform(result, transform)
    return result
def add_flat(lh, rh):
    """Return an unevaluated Add of ``lh`` and ``rh``, splicing in the terms
    of either operand that is itself an Add (one level of flattening)."""
    terms = []
    for operand in (lh, rh):
        if getattr(operand, "is_Add", False):
            terms.extend(operand.args)
        else:
            terms.append(operand)
    return sympy.Add(*terms, evaluate=False)
def mat_add_flat(lh, rh):
    """Return an unevaluated MatAdd of ``lh`` and ``rh``.

    The terms of either operand that is itself a MatAdd are spliced in (one
    level of flattening). Every term is passed through ``.doit()``.
    """
    terms = []
    for operand in (lh, rh):
        if getattr(operand, "is_MatAdd", False):
            terms.extend(operand.args)
        else:
            terms.append(operand)
    return sympy.MatAdd(*[term.doit() for term in terms], evaluate=False)
def mul_flat(lh, rh):
    """Return an unevaluated Mul of ``lh`` and ``rh``, splicing in the
    factors of either operand that is itself a Mul (one level of flattening)."""
    factors = []
    for operand in (lh, rh):
        if getattr(operand, "is_Mul", False):
            factors.extend(operand.args)
        else:
            factors.append(operand)
    return sympy.Mul(*factors, evaluate=False)
def mat_mul_flat(lh, rh):
    """Return an unevaluated MatMul of ``lh`` and ``rh``.

    The factors of an operand that is itself a MatMul are spliced in (one
    level of flattening). Each factor that has a ``doit`` method is evaluated
    with it. Plain Python numbers, such as the ``-1`` callers use for
    negation, are passed through unchanged.
    """
    factors = []
    for operand in (lh, rh):
        if getattr(operand, "is_MatMul", False):
            factors.extend(operand.args)
        else:
            factors.append(operand)
    # The flattening path used to call .doit() unconditionally, so
    # mat_mul_flat(-1, <MatMul>) (e.g. from ``A - B C``) raised
    # AttributeError on the int -1.
    return sympy.MatMul(
        *[f.doit() if hasattr(f, "doit") else f for f in factors], evaluate=False
    )
def convert_add(add):
    """Convert an ``additive`` node: ``a + b``, ``a - b`` or a single mp term.

    Matrix operands use MatAdd/MatMul helpers. Scalar subtraction negates a
    number literal directly and otherwise multiplies the right side by -1.
    """
    is_add = add.ADD()
    if not is_add and not add.SUB():
        return convert_mp(add.mp())

    lh = convert_add(add.additive(0))
    rh = convert_add(add.additive(1))
    involves_matrix = lh.is_Matrix or rh.is_Matrix

    if is_add:
        return mat_add_flat(lh, rh) if involves_matrix else add_flat(lh, rh)

    if involves_matrix:
        return mat_add_flat(lh, mat_mul_flat(-1, rh))

    # If we want to force ordering for variables this should be:
    # return Sub(lh, rh, evaluate=False)
    negated = -rh if rh.func.is_Number else mul_flat(-1, rh)
    return add_flat(lh, negated)
def convert_mp(mp):
    """Convert a multiplicative node (``mp`` or ``mp_nofunc``).

    Handles multiplication (MUL / CMD_TIMES / CMD_CDOT), division (DIV /
    CMD_DIV / COLON, built as ``lh * rh**-1``) and modulo (CMD_MOD). Anything
    else falls through to the unary operand.
    """
    if hasattr(mp, "mp"):
        left_node, right_node = mp.mp(0), mp.mp(1)
    else:
        left_node, right_node = mp.mp_nofunc(0), mp.mp_nofunc(1)

    if mp.MUL() or mp.CMD_TIMES() or mp.CMD_CDOT():
        lh, rh = convert_mp(left_node), convert_mp(right_node)
        if lh.is_Matrix or rh.is_Matrix:
            return mat_mul_flat(lh, rh)
        return mul_flat(lh, rh)

    if mp.DIV() or mp.CMD_DIV() or mp.COLON():
        lh, rh = convert_mp(left_node), convert_mp(right_node)
        product = sympy.MatMul if (lh.is_Matrix or rh.is_Matrix) else sympy.Mul
        return product(lh, sympy.Pow(rh, -1, evaluate=False), evaluate=False)

    if mp.CMD_MOD():
        lh, rh = convert_mp(left_node), convert_mp(right_node)
        if rh.is_Matrix:
            raise Exception(
                "Cannot perform modulo operation with a matrix as an operand"
            )
        return sympy.Mod(lh, rh, evaluate=False)

    unary_node = mp.unary() if hasattr(mp, "unary") else mp.unary_nofunc()
    return convert_unary(unary_node)
def convert_unary(unary):
    """Convert a ``unary`` / ``unary_nofunc`` node.

    A leading ``+`` is dropped. A leading ``-`` negates the operand: a
    MatMul by -1 for matrices, a direct negation for number literals, and a
    flat Mul by -1 otherwise. Without a sign, the postfix factors are
    converted (implicitly multiplied) by convert_postfix_list. Returns None if
    the node has no postfix factors.
    """
    if hasattr(unary, "unary"):
        nested_unary = unary.unary()
    else:
        nested_unary = unary.unary_nofunc()

    if hasattr(unary, "postfix_nofunc"):
        # This node exposes one postfix() followed by a list of postfix_nofunc().
        postfix = [unary.postfix()] + unary.postfix_nofunc()
    else:
        postfix = unary.postfix()

    if unary.ADD():
        return convert_unary(nested_unary)

    if unary.SUB():
        operand = convert_unary(nested_unary)
        if operand.is_Matrix:
            # mat_mul_flat has no ``evaluate`` parameter; passing
            # evaluate=False here used to raise TypeError.
            return mat_mul_flat(-1, operand)
        if operand.func.is_Number:
            return -operand
        return mul_flat(-1, operand)

    if postfix:
        return convert_postfix_list(postfix)
    return None
def convert_postfix_list(arr, i=0):
    r"""Convert the implicitly multiplied postfix nodes ``arr[i:]``.

    Expressions and matrices are multiplied with the rest of the list. A
    derivative operator such as ``\frac{d}{dx}`` converts (in convert_frac)
    to the one-element list ``[x]``. When more factors follow it, they become
    the differentiated expression: ``Derivative(rest, x)``. Other tuples,
    lists and dicts (e.g. results of ``\operatorname{...}`` functions) are
    returned as-is.

    Raises:
        Exception: if ``i`` is past the end of ``arr``, or if a non-list
            derivative-like value is last in the list.
    """
    if i >= len(arr):
        raise Exception("Index out of bounds")

    res = convert_postfix(arr[i])
    is_last = i == len(arr) - 1

    if isinstance(res, (sympy.Expr, sympy.Matrix)) or res is sympy.S.EmptySet:
        if is_last:
            return res  # nothing to multiply by
        # multiply by next
        rh = convert_postfix_list(arr, i + 1)
        if res.is_Matrix or rh.is_Matrix:
            return mat_mul_flat(res, rh)
        return mul_flat(res, rh)

    # Derivative operator ``[wrt]`` followed by its operand. Previously lists
    # returned early, which made the derivative branch below unreachable and
    # silently dropped the expression being differentiated.
    if (
        not is_last
        and isinstance(res, list)
        and len(res) == 1
        and isinstance(res[0], sympy.Symbol)
    ):
        expr = convert_postfix_list(arr, i + 1)
        return sympy.Derivative(expr, res[0])

    if isinstance(res, (tuple, list, dict)):
        return res

    # must be derivative
    wrt = res[0]
    if is_last:
        raise Exception("Expected expression for derivative")
    expr = convert_postfix_list(arr, i + 1)
    return sympy.Derivative(expr, wrt)
def do_subs(expr, at):
    """Substitute according to an ``eval_at`` sub/superscript node.

    * Bare expression: the first symbol sympy yields from it is replaced by
      that expression itself. With no symbols, ``expr`` is returned unchanged.
      NOTE(review): replacing a symbol by an expression containing it looks
      odd; confirm that this is the intended semantics.
    * ``lhs = rhs``: occurrences of ``lhs`` are replaced by ``rhs``.

    Returns None when the node has neither form.
    """
    if at.expr():
        at_expr = convert_expr(at.expr())
        symbols = at_expr.atoms(sympy.Symbol)
        if not symbols:
            return expr
        return expr.subs(next(iter(symbols)), at_expr)

    if at.equality():
        equality = at.equality()
        return expr.subs(convert_expr(equality.expr(0)), convert_expr(equality.expr(1)))

    return None
def convert_postfix(postfix):
    """Convert a postfix node: its base ``exp`` followed by postfix operators.

    Operators are applied left to right:

    * BANG: unevaluated factorial. A derivative operator (a list) raises.
    * eval_at: the evaluation bar. With both bounds the result is
      ``sup - sub``; with one bound, just that substitution.
    * transpose: ``.T`` if available, else sympy.transpose(). Values that
      support neither are left unchanged (best effort).
    """
    if hasattr(postfix, "exp"):
        exp_nested = postfix.exp()
    else:
        exp_nested = postfix.exp_nofunc()

    exp = convert_exp(exp_nested)
    for op in postfix.postfix_op():
        if op.BANG():
            if isinstance(exp, list):
                raise Exception("Cannot apply postfix to derivative")
            exp = sympy.factorial(exp, evaluate=False)
        elif op.eval_at():
            ev = op.eval_at()
            at_b = None
            at_a = None
            if ev.eval_at_sup():
                at_b = do_subs(exp, ev.eval_at_sup())
            if ev.eval_at_sub():
                at_a = do_subs(exp, ev.eval_at_sub())
            if at_b is not None and at_a is not None:
                exp = add_flat(at_b, mul_flat(at_a, -1))
            elif at_b is not None:
                exp = at_b
            elif at_a is not None:
                exp = at_a
        elif op.transpose():
            # Narrowed from bare ``except:`` so KeyboardInterrupt/SystemExit
            # are no longer swallowed; the best-effort fallback is unchanged.
            try:
                exp = exp.T
            except Exception:
                try:
                    exp = sympy.transpose(exp)
                except Exception:
                    pass
    return exp
def convert_exp(exp):
    """Convert an ``exp`` / ``exp_nofunc`` node.

    With a nested exp the result is an unevaluated ``base ** exponent``. The
    exponent comes from the node's atom or expr. A derivative operator (a
    list) cannot be a base. Without a nested exp, the ``comp`` child is
    converted.
    """
    nested = exp.exp() if hasattr(exp, "exp") else exp.exp_nofunc()
    if not nested:
        comp_node = exp.comp() if hasattr(exp, "comp") else exp.comp_nofunc()
        return convert_comp(comp_node)

    base = convert_exp(nested)
    if isinstance(base, list):
        raise Exception("Cannot raise derivative to power")
    if exp.atom():
        exponent = convert_atom(exp.atom())
    elif exp.expr():
        exponent = convert_expr(exp.expr())
    return sympy.Pow(base, exponent, evaluate=False)
def convert_comp(comp):
    """Convert a ``comp`` node by dispatching on the first child present.

    Groups, norms, absolute values, floor/ceil groups, atoms, fractions,
    binomials, matrices, determinants (with global bindings substituted) and
    functions are tried in that order. Returns None if none is present.
    """
    handlers = (
        ("group", lambda node: convert_expr(node.expr())),
        ("norm_group", lambda node: convert_expr(node.expr()).norm()),
        ("abs_group", lambda node: sympy.Abs(convert_expr(node.expr()), evaluate=False)),
        ("floor_group", lambda node: handle_floor(convert_expr(node.expr()))),
        ("ceil_group", lambda node: handle_ceil(convert_expr(node.expr()))),
        ("atom", lambda node: convert_atom(node)),
        ("frac", lambda node: convert_frac(node)),
        ("binom", lambda node: convert_binom(node)),
        ("matrix", lambda node: convert_matrix(node)),
        # !Use Global variances
        ("det", lambda node: convert_matrix(node).subs(variances).det()),
        ("func", lambda node: convert_func(node)),
    )
    for getter, handler in handlers:
        node = getattr(comp, getter)()
        if node:
            return handler(node)
    return None
def _to_sympy_number(text):
    """Parse ``text`` as an exact sympy Rational, falling back to sympy.Number."""
    try:
        return sympy.Rational(text)
    except (TypeError, ValueError):
        return sympy.Number(text)


def convert_atom(atom):
    r"""Convert an ``atom`` parse node into a sympy object.

    Handles, in order:

    * letter / Greek / other symbol commands, optionally accented,
      subscripted and superscripted. ``I`` is the imaginary unit. A name
      bound in ``var`` to a shaped value becomes a MatrixSymbol.
    * SYMBOL tokens for infinity, pi and the empty set.
    * plain and E-notation numbers (commas removed, exact when possible).
    * differentials (``d`` + variable name).
    * ``\mathit{...}`` text.
    * VARIABLE placeholders, replaced from ``VARIABLE_VALUES`` when known.
    * percentage numbers (divided by 100).

    Returns None for an atom of none of these kinds.
    """
    if atom.atom_expr():
        atom_expr = atom.atom_expr()

        # find the atom's text
        atom_text = ""
        if atom_expr.LETTER_NO_E():
            atom_text = atom_expr.LETTER_NO_E().getText()
            if atom_text == "I":
                return sympy.I
        elif atom_expr.GREEK_CMD():
            atom_text = atom_expr.GREEK_CMD().getText()[1:].strip()
        elif atom_expr.OTHER_SYMBOL_CMD():
            atom_text = atom_expr.OTHER_SYMBOL_CMD().getText().strip()
        elif atom_expr.accent():
            atom_accent = atom_expr.accent()
            # The accent's first-token text is kept verbatim, giving a name
            # of the form "<accent token>{<base>}".
            name = atom_accent.start.text
            base = atom_accent.base.getText()
            atom_text = name + "{" + base + "}"

        # find atom's subscript, if any
        subscript_text = ""
        if atom_expr.subexpr():
            subexpr = atom_expr.subexpr()
            subscript = None
            if subexpr.expr():  # subscript is expr
                subscript = subexpr.expr().getText().strip()
            elif subexpr.atom():  # subscript is atom
                subscript = subexpr.atom().getText().strip()
            elif subexpr.args():  # subscript is args
                subscript = subexpr.args().getText().strip()
            subscript_inner_text = StrPrinter().doprint(subscript)
            if len(subscript_inner_text) > 1:
                subscript_text = "_{" + subscript_inner_text + "}"
            else:
                subscript_text = "_" + subscript_inner_text

        # construct the symbol using the text and optional subscript
        symbol_name = atom_text + subscript_text
        atom_symbol = sympy.Symbol(symbol_name, real=is_real)

        # for matrix symbol: a name bound in ``var`` to a shaped value is
        # returned as a MatrixSymbol of that shape (best effort).
        # ``var`` is only read here, so no ``global`` declaration is needed.
        matrix_symbol = None
        if symbol_name in var:
            try:
                shape = sympy.shape(var[symbol_name])
                matrix_symbol = sympy.MatrixSymbol(symbol_name, shape[0], shape[1])
                variances[matrix_symbol] = variances[atom_symbol]
            except Exception:
                pass

        # find the atom's superscript, and return as a Pow if found
        if atom_expr.supexpr():
            supexpr = atom_expr.supexpr()
            if supexpr.expr():
                func_pow = convert_expr(supexpr.expr())
            else:
                func_pow = convert_atom(supexpr.atom())
            return sympy.Pow(atom_symbol, func_pow, evaluate=False)

        return atom_symbol if not matrix_symbol else matrix_symbol

    elif atom.SYMBOL():
        s = atom.SYMBOL().getText().replace("\\$", "").replace("\\%", "")
        if s == "\\infty":
            return sympy.oo
        elif s == "\\pi":
            return sympy.pi
        elif s == "\\emptyset":
            return sympy.S.EmptySet
        else:
            raise Exception("Unrecognized symbol")

    elif atom.NUMBER():
        return _to_sympy_number(atom.NUMBER().getText().replace(",", ""))

    elif atom.E_NOTATION():
        return _to_sympy_number(atom.E_NOTATION().getText().replace(",", ""))

    elif atom.DIFFERENTIAL():
        # This local used to be called ``var``. Combined with the function's
        # ``global var`` it replaced the module-level ``var`` dict with a
        # Symbol, breaking every later matrix-symbol lookup.
        diff_var = get_differential_var(atom.DIFFERENTIAL())
        return sympy.Symbol("d" + diff_var.name, real=is_real)

    elif atom.mathit():
        text = rule2text(atom.mathit().mathit_text())
        return sympy.Symbol(text, real=is_real)

    elif atom.VARIABLE():
        text = atom.VARIABLE().getText()
        is_percent = text.endswith("\\%")
        trim_amount = 3 if is_percent else 1
        # Drop the first 10 characters (presumably the "\variable{" prefix)
        # and the last 1, or 3 with a trailing "\%"; TODO confirm vs grammar.
        name = text[10:]
        name = name[0 : len(name) - trim_amount]

        # add hash to distinguish from regular symbols
        name_hash = hashlib.md5(name.encode()).hexdigest()
        symbol_name = name + name_hash

        # replace the variable for already known variable values
        if name in VARIABLE_VALUES:
            value = VARIABLE_VALUES[name]
            if isinstance(value, tuple(sympy.core.all_classes)):
                symbol = value
            else:
                # NOTE(review): parse_expr evaluates its input; only pass
                # trusted VARIABLE_VALUES here.
                symbol = parse_expr(str(value))
        else:
            symbol = sympy.Symbol(symbol_name, real=is_real)

        if is_percent:
            return sympy.Mul(symbol, sympy.Pow(100, -1, evaluate=False), evaluate=False)

        # return the symbol
        return symbol

    elif atom.PERCENT_NUMBER():
        text = atom.PERCENT_NUMBER().getText().replace("\\%", "").replace(",", "")
        number = _to_sympy_number(text)
        return sympy.Rational(number, 100)
def rule2text(ctx):
    """Return the exact source text covered by parse-rule context ``ctx``.

    The span runs from the start offset of its first token to the stop offset
    (inclusive) of its last token.
    """
    first_token, last_token = ctx.start, ctx.stop
    source = first_token.getInputStream()
    return source.getText(first_token.start, last_token.stop)
def convert_frac(frac):
    r"""Convert a ``\frac{upper}{lower}`` node.

    Derivative notation is recognised from the lower part:

    * ``\frac{d}{dx}`` / ``\frac{\partial}{\partial x}`` return the
      one-element list ``[x]`` (a derivative operator applied by the caller).
    * ``\frac{d f}{dx}`` / ``\frac{\partial f}{\partial x}`` return
      ``Derivative(f, x)``. ``f`` is re-parsed with latex2sympy().

    Any other fraction becomes ``upper * lower**-1``, using MatMul when a
    matrix is involved.
    """
    diff_op = False
    partial_op = False
    wrt = None
    lower = frac.lower
    lower_itv = lower.getSourceInterval()
    lower_itv_len = lower_itv[1] - lower_itv[0] + 1

    if lower.start == lower.stop and lower.start.type == PSLexer.DIFFERENTIAL:
        wrt = get_differential_var_str(lower.start.text)
        diff_op = True
    elif (
        lower_itv_len == 2
        and lower.start.type == PSLexer.SYMBOL
        and lower.start.text == "\\partial"
        and lower.stop.type in (PSLexer.LETTER_NO_E, PSLexer.SYMBOL)
    ):
        partial_op = True
        wrt = lower.stop.text
        if lower.stop.type == PSLexer.SYMBOL:
            wrt = wrt[1:]

    if diff_op or partial_op:
        wrt = sympy.Symbol(wrt, real=is_real)
        upper = frac.upper
        single_upper_token = upper.start == upper.stop
        if (
            diff_op
            and single_upper_token
            and upper.start.type == PSLexer.LETTER_NO_E
            and upper.start.text == "d"
        ):
            return [wrt]
        if (
            partial_op
            and single_upper_token
            and upper.start.type == PSLexer.SYMBOL
            and upper.start.text == "\\partial"
        ):
            return [wrt]

        upper_text = rule2text(upper)
        expr_top = None
        # Forward the current VARIABLE_VALUES: latex2sympy() replaces that
        # global on every call, so a bare nested call used to reset it to {}
        # in the middle of parsing the outer expression.
        if diff_op and upper_text.startswith("d"):
            expr_top = latex2sympy(upper_text[1:], VARIABLE_VALUES)
        elif partial_op and upper.start.text == "\\partial":
            expr_top = latex2sympy(upper_text[len("\\partial") :], VARIABLE_VALUES)
        # Explicit None check: sympy zero is falsy, and relationals raise on bool().
        if expr_top is not None:
            return sympy.Derivative(expr_top, wrt)

    expr_top = convert_expr(frac.upper)
    expr_bot = convert_expr(frac.lower)
    inverse = sympy.Pow(expr_bot, -1, evaluate=False)
    if expr_top.is_Matrix or expr_bot.is_Matrix:
        return sympy.MatMul(expr_top, inverse, evaluate=False)
    return sympy.Mul(expr_top, inverse, evaluate=False)
def convert_binom(binom):
    r"""Convert ``\binom{upper}{lower}`` into ``sympy.binomial(upper, lower)``."""
    top, bottom = (convert_expr(part) for part in (binom.upper, binom.lower))
    return sympy.binomial(top, bottom)
def _convert_func_power(func):
    """Return the converted superscript of a function node, or None if absent."""
    supexpr = func.supexpr()
    if not supexpr:
        return None
    if supexpr.expr():
        return convert_expr(supexpr.expr())
    return convert_atom(supexpr.atom())


def _convert_single_arg_func(func):
    r"""Convert a ``func_normal_single_arg`` call.

    Covers inverse trig/hyperbolic names, single-argument
    ``\operatorname{...}`` functions, log/ln, exp, floor, ceil, det, and
    trig functions. A trig function with superscript -1 becomes its inverse
    instead of a reciprocal.
    """
    if func.L_PAREN():  # function called with parenthesis
        arg = convert_func_arg(func.func_single_arg())
    else:
        arg = convert_func_arg(func.func_single_arg_noparens())

    name = func.func_normal_single_arg().start.text[1:]

    # change arc<trig> -> a<trig>
    if name in ["arcsin", "arccos", "arctan", "arccsc", "arcsec", "arccot"]:
        name = "a" + name[3:]
        expr = getattr(sympy.functions, name)(arg, evaluate=False)
    elif name in ["arsinh", "arcosh", "artanh"]:
        name = "a" + name[2:]
        expr = getattr(sympy.functions, name)(arg, evaluate=False)
    elif name in ["arcsinh", "arccosh", "arctanh"]:
        name = "a" + name[3:]
        expr = getattr(sympy.functions, name)(arg, evaluate=False)
    elif name == "operatorname":
        operatorname = func.func_normal_single_arg().func_operator_name.getText()
        if operatorname in ["arsinh", "arcosh", "artanh"]:
            operatorname = "a" + operatorname[2:]
            expr = getattr(sympy.functions, operatorname)(arg, evaluate=False)
        elif operatorname in ["arcsinh", "arccosh", "arctanh"]:
            operatorname = "a" + operatorname[3:]
            expr = getattr(sympy.functions, operatorname)(arg, evaluate=False)
        elif operatorname == "floor":
            expr = handle_floor(arg)
        elif operatorname == "ceil":
            expr = handle_ceil(arg)
        elif operatorname == "eye":
            expr = sympy.eye(arg)
        elif operatorname == "rank":
            expr = sympy.Integer(arg.rank())
        elif operatorname in ["trace", "tr"]:
            expr = arg.trace()
        elif operatorname == "rref":
            expr = arg.rref()[0]
        elif operatorname == "nullspace":
            expr = arg.nullspace()
        elif operatorname == "norm":
            expr = arg.norm()
        elif operatorname == "cols":
            expr = [arg.col(i) for i in range(arg.cols)]
        elif operatorname == "rows":
            expr = [arg.row(i) for i in range(arg.rows)]
        elif operatorname in ["eig", "eigen", "diagonalize"]:
            expr = arg.diagonalize()
        elif operatorname in ["eigenvals", "eigenvalues"]:
            expr = arg.eigenvals()
        elif operatorname in ["eigenvects", "eigenvectors"]:
            expr = arg.eigenvects()
        elif operatorname in ["svd", "SVD"]:
            expr = arg.singular_value_decomposition()
    elif name in ["log", "ln"]:
        if func.subexpr():
            if func.subexpr().atom():
                base = convert_atom(func.subexpr().atom())
            else:
                base = convert_expr(func.subexpr().expr())
        elif name == "log":
            base = 10
        elif name == "ln":
            base = sympy.E
        expr = sympy.log(arg, base, evaluate=False)
    elif name in ["exp", "exponentialE"]:
        expr = sympy.exp(arg)
    elif name == "floor":
        expr = handle_floor(arg)
    elif name == "ceil":
        expr = handle_ceil(arg)
    elif name == "det":
        expr = arg.det()

    func_pow = _convert_func_power(func)
    should_pow = True
    if name in ["sin", "cos", "tan", "csc", "sec", "cot", "sinh", "cosh", "tanh"]:
        # \sin^{-1} etc. denote the inverse function, not 1/sin.
        if func_pow == -1:
            name = "a" + name
            should_pow = False
        expr = getattr(sympy.functions, name)(arg, evaluate=False)

    if func_pow and should_pow:
        expr = sympy.Pow(expr, func_pow, evaluate=False)
    return expr


def _convert_multi_arg_func(func):
    r"""Convert a ``func_normal_multi_arg`` call.

    Covers gcd/lcm, max/min and multi-argument ``\operatorname{...}``
    functions. The raw argument text is split on every comma and each piece
    is parsed with latex2sympy(), so commas nested inside one argument are
    not supported.
    """
    if func.L_PAREN():  # function called with parenthesis
        arg_text = func.func_multi_arg().getText()
    else:
        # The parse node must be turned into text first; calling .split()
        # on the node itself raised AttributeError.
        arg_text = func.func_multi_arg_noparens().getText()
    args = [latex2sympy(arg, VARIABLE_VALUES) for arg in arg_text.split(",")]

    name = func.func_normal_multi_arg().start.text[1:]
    if name == "operatorname":
        operatorname = func.func_normal_multi_arg().func_operator_name.getText()
        if operatorname in ["gcd", "lcm"]:
            expr = handle_gcd_lcm(operatorname, args)
        elif operatorname == "zeros":
            expr = sympy.zeros(*args)
        elif operatorname == "ones":
            expr = sympy.ones(*args)
        elif operatorname == "diag":
            expr = sympy.diag(*args)
        elif operatorname == "hstack":
            expr = sympy.Matrix.hstack(*args)
        elif operatorname == "vstack":
            expr = sympy.Matrix.vstack(*args)
        elif operatorname in ["orth", "ortho", "orthogonal", "orthogonalize"]:
            if len(args) == 1:
                arg = args[0]
                expr = sympy.matrices.GramSchmidt(
                    [arg.col(i) for i in range(arg.cols)], True
                )
            else:
                expr = sympy.matrices.GramSchmidt(args, True)
    elif name in ["gcd", "lcm"]:
        expr = handle_gcd_lcm(name, args)
    elif name in ["max", "min"]:
        name = name[0].upper() + name[1:]
        expr = getattr(sympy.functions, name)(*args, evaluate=False)

    func_pow = _convert_func_power(func)
    if func_pow:
        expr = sympy.Pow(expr, func_pow, evaluate=False)
    return expr


def _convert_user_func(func):
    """Convert an application of an undefined function such as ``f(x, y)``.

    The node text becomes a ``sympy.Function``. A trailing empty argument is
    dropped, and each argument is parsed with latex2sympy(). A superscript,
    when present, raises the application to that power.
    """
    f = sympy.Function(func.atom_expr_no_supexpr().getText())
    args = func.func_common_args().getText().split(",")
    if args[-1] == "":
        args = args[:-1]
    args = [latex2sympy(arg, VARIABLE_VALUES) for arg in args]
    if func.supexpr():
        power = _convert_func_power(func)
        return sympy.Pow(f(*args), power, evaluate=False)
    return f(*args)


def convert_func(func):
    """Convert a ``func`` parse node by dispatching on its kind.

    Kinds: single- and multi-argument named functions, undefined-function
    application, integral, root, sum, product, limit and exponential. Returns
    None if the node matches none of these.
    """
    if func.func_normal_single_arg():
        return _convert_single_arg_func(func)
    elif func.func_normal_multi_arg():
        return _convert_multi_arg_func(func)
    elif func.atom_expr_no_supexpr():
        return _convert_user_func(func)
    elif func.FUNC_INT():
        return handle_integral(func)
    elif func.FUNC_SQRT():
        expr = convert_expr(func.base)
        if func.root:
            r = convert_expr(func.root)
            return sympy.Pow(expr, 1 / r, evaluate=False)
        return sympy.Pow(expr, sympy.S.Half, evaluate=False)
    elif func.FUNC_SUM():
        return handle_sum_or_prod(func, "summation")
    elif func.FUNC_PROD():
        return handle_sum_or_prod(func, "product")
    elif func.FUNC_LIM():
        return handle_limit(func)
    elif func.EXP_E():
        return handle_exp(func)
    return None
def convert_func_arg(arg):
    """Convert a function argument node.

    Nodes that have an ``expr`` accessor are converted as a full expression;
    otherwise the ``mp_nofunc`` child is converted.
    """
    if not hasattr(arg, "expr"):
        return convert_mp(arg.mp_nofunc())
    return convert_expr(arg.expr())
def handle_integral(func):
    r"""Convert an ``\int`` node into a sympy Integral.

    The integration variable comes from an explicit DIFFERENTIAL token if
    present. Otherwise it comes from a symbol named ``d<var>`` inside the
    integrand, which is then replaced by 1. Failing both, ``x`` is assumed.
    With a subscript, bounds ``(var, lower, upper)`` are attached.
    """
    if func.additive():
        integrand = convert_add(func.additive())
    elif func.frac():
        integrand = convert_frac(func.frac())
    else:
        # A sympy One rather than the int 1: the scan below calls
        # integrand.atoms(), which a Python int does not have.
        integrand = sympy.S.One

    int_var = None
    if func.DIFFERENTIAL():
        int_var = get_differential_var(func.DIFFERENTIAL())
    else:
        int_sym = None
        for sym in integrand.atoms(sympy.Symbol):
            s = str(sym)
            if len(s) > 1 and s[0] == "d":
                if s[1] == "\\":
                    int_var = sympy.Symbol(s[2:], real=is_real)
                else:
                    int_var = sympy.Symbol(s[1:], real=is_real)
                int_sym = sym
        if int_var is not None:
            integrand = integrand.subs(int_sym, 1)
        else:
            # Assume dx by default
            int_var = sympy.Symbol("x", real=is_real)

    if func.subexpr():
        if func.subexpr().atom():
            lower = convert_atom(func.subexpr().atom())
        else:
            lower = convert_expr(func.subexpr().expr())
        if func.supexpr().atom():
            upper = convert_atom(func.supexpr().atom())
        else:
            upper = convert_expr(func.supexpr().expr())
        return sympy.Integral(integrand, (int_var, lower, upper))
    return sympy.Integral(integrand, int_var)
def handle_sum_or_prod(func, name):
    r"""Build an unevaluated Sum or Product from a ``\sum`` / ``\prod`` node.

    The subscript must be an equality ``var = start`` and the superscript
    gives ``end``. ``name`` selects "summation" or "product"; any other value
    returns None.
    """
    val = convert_mp(func.mp())
    equality = func.subeq().equality()
    iter_var = convert_expr(equality.expr(0))
    start = convert_expr(equality.expr(1))
    sup = func.supexpr()
    end = convert_expr(sup.expr()) if sup.expr() else convert_atom(sup.atom())
    constructor = {"summation": sympy.Sum, "product": sympy.Product}.get(name)
    if constructor is None:
        return None
    return constructor(val, (iter_var, start, end))
def handle_limit(func):
    r"""Convert a ``\lim`` node into an unevaluated sympy Limit.

    The limit variable comes from a letter, Greek or other symbol command in
    the subscript (default ``x``). The direction is "-" when the subscript
    contains SUB, otherwise "+".
    """
    sub = func.limit_sub()
    if sub.LETTER_NO_E():
        var_name = sub.LETTER_NO_E().getText()
    elif sub.GREEK_CMD():
        var_name = sub.GREEK_CMD().getText()[1:].strip()
    elif sub.OTHER_SYMBOL_CMD():
        var_name = sub.OTHER_SYMBOL_CMD().getText().strip()
    else:
        var_name = "x"
    limit_var = sympy.Symbol(var_name, real=is_real)

    direction = "-" if sub.SUB() else "+"
    approaching = convert_expr(sub.expr())
    content = convert_mp(func.mp())
    return sympy.Limit(content, limit_var, approaching, direction)
def handle_exp(func):
    """Convert an EXP_E node into ``sympy.exp(exponent)``.

    The exponent comes from the superscript, and defaults to 1 when there is
    none.
    """
    sup = func.supexpr()
    if not sup:
        return sympy.exp(1)
    exponent = convert_expr(sup.expr()) if sup.expr() else convert_atom(sup.atom())
    return sympy.exp(exponent)
def handle_gcd_lcm(f, args):
    """Compute gcd() or lcm() of ``args`` and wrap the result in UnevaluatedExpr.

    f: str - name of the sympy function ("gcd" or "lcm")
    args: List[Expr] - function arguments, each passed through nsimplify first
    """
    operation = getattr(sympy, f)
    exact_args = tuple(sympy.nsimplify(arg) for arg in args)
    # gcd() and lcm() don't support evaluate=False
    return sympy.UnevaluatedExpr(operation(exact_args))
def handle_floor(expr):
    """Return the unevaluated floor of ``expr``.

    expr: Expr - sympy expression used as the argument of floor()
    """
    return sympy.floor(expr, evaluate=False)
def handle_ceil(expr):
    """Return the unevaluated ceiling of ``expr``.

    expr: Expr - sympy expression used as the argument of ceiling()
    """
    return sympy.ceiling(expr, evaluate=False)
def get_differential_var(d):
    """Return the Symbol named by differential token ``d`` (its text minus the leading ``d``)."""
    variable_name = get_differential_var_str(d.getText())
    return sympy.Symbol(variable_name, real=is_real)
def get_differential_var_str(text):
    r"""Return the variable name from a differential token's text.

    The first character (the ``d``) is dropped, then any leading spaces, tabs,
    carriage returns and newlines, then one leading backslash of a command
    name. For example ``"dx"`` -> ``"x"`` and ``"d \theta"`` -> ``"theta"``.

    Raises:
        ValueError: if nothing but whitespace follows the first character.
            Previously this case failed with UnboundLocalError.
    """
    name = text[1:].lstrip(" \r\n\t")
    if not name:
        raise ValueError("Differential %r has no variable" % (text,))
    if name[0] == "\\":
        name = name[1:]
    return name
def latex(tex):
    r"""Render a sympy object as LaTeX in this module's output style.

    The following rewrites are applied, in order:

    * fractions use the recorded ``frac_type``;
    * ``\left[\begin{matrix}...`` becomes ``bmatrix``;
    * ``\left`` / ``\right`` are dropped;
    * ``" )"`` is tightened to ``")"``;
    * ``\log`` is rewritten as ``\ln``.

    NOTE(review): the last rewrite also renames logs of other bases.
    """
    rendered = sympy.latex(tex)
    rewrites = (
        (r"\frac", frac_type),
        (r"\dfrac", frac_type),
        (r"\tfrac", frac_type),
        (r"\left[\begin{matrix}", r"\begin{bmatrix}"),
        (r"\end{matrix}\right]", r"\end{bmatrix}"),
        (r"\left", r""),
        (r"\right", r""),
        (r" )", r")"),
        (r"\log", r"\ln"),
    )
    for old, new in rewrites:
        rendered = rendered.replace(old, new)
    return rendered
def latex2latex(tex):
    """Parse ``tex``, evaluate and simplify it, and render it back to LaTeX.

    Before simplification the global ``variances`` bindings are substituted
    and doit() is applied twice. List, tuple and dict results (e.g. relation
    lists or eigen data) are rendered without simplification.
    """
    parsed = latex2sympy(tex)
    if isinstance(parsed, (list, tuple, dict)):
        return latex(parsed)
    evaluated = parsed.subs(variances).doit().doit()
    return latex(simplify(evaluated))
# Module initialisation (runs at import time).
# Set image value: parse "i=I" and "j=I" so the letters i and j are bound to
# the imaginary unit in the global tables. This presumably goes through the
# ASSIGNMENT branch of convert_relation; TODO confirm that the grammar maps
# "=" here to ASSIGNMENT rather than EQUAL.
latex2latex("i=I")
latex2latex("j=I")
# set Identity(i): register \bm{I}_1 ... \bm{I}_9 as i x i identity matrices.
# Each is stored in ``variances`` under both a plain Symbol key and a
# MatrixSymbol key, and in ``var`` under the name string.
# NOTE: the loop leaves ``i``, ``lh``, ``lh_m`` and ``rh`` as module globals.
for i in range(1, 10):
    lh = sympy.Symbol(r"\bm{I}_" + str(i), real=False)
    lh_m = sympy.MatrixSymbol(r"\bm{I}_" + str(i), i, i)
    rh = sympy.Identity(i).as_mutable()
    variances[lh] = rh
    variances[lh_m] = rh
    var[str(lh)] = rh
if __name__ == "__main__":
    # Demo: parse one sample input and print the raw sympy result and the
    # simplified LaTeX round trip. Other sample inputs that have been used:
    #   r'A_1=\begin{bmatrix}1 & 2 & 3 & 4 \\ 5 & 6 & 7 & 8\end{bmatrix}'
    #   r'b_1=\begin{bmatrix}1 \\ 2 \\ 3 \\ 4\end{bmatrix}'
    #   r"(x+2)|_{x=y+1}"
    #   r"\operatorname{zeros}(3)"
    tex = r"\operatorname{rows}(\begin{bmatrix}1 & 2 \\ 3 & 4\end{bmatrix})"
    math = latex2sympy(tex)
    print("latex:", tex)
    print("raw_math:", math)
    print("cal:", latex2latex(tex))