commit 68b629302800df14807a0ce8313aa78a57afcf48 Author: Florian Stecker Date: Tue Aug 27 19:23:36 2024 -0400 initial version diff --git a/.gitignore b/.gitignore new file mode 100644 index 0000000..96ef6c0 --- /dev/null +++ b/.gitignore @@ -0,0 +1,2 @@ +/target +Cargo.lock diff --git a/Cargo.toml b/Cargo.toml new file mode 100644 index 0000000..965abf9 --- /dev/null +++ b/Cargo.toml @@ -0,0 +1,7 @@ +[package] +name = "solveq" +version = "0.1.0" +edition = "2021" + +[dependencies] +egg = "0.9.5" diff --git a/src/language.rs b/src/language.rs new file mode 100644 index 0000000..218a5a0 --- /dev/null +++ b/src/language.rs @@ -0,0 +1,342 @@ +use std::{cmp::Ordering, fmt::{self,Display, Formatter}, str::FromStr, sync::LazyLock}; +use egg::{define_language, merge_option, rewrite as rw, Analysis, DidMerge, Id, Language, Subst, Var}; + +pub type EGraph = egg::EGraph; +pub type Rewrite = egg::Rewrite; + +define_language! { + pub enum EquationLanguage { + "x" = Unknown, + "+" = Add([Id; 2]), + "-" = Sub([Id; 2]), + "-" = Neg([Id; 1]), + "*" = Mul([Id; 2]), + "/" = Div([Id; 2]), + "^" = Power([Id; 2]), + "=" = Equals([Id; 2]), + "rec" = Reciprocal([Id; 1]), + Num(Rational), + } +} + +#[derive(Debug,Hash,Clone)] +pub struct Rational { + pub num: i64, + pub denom: u64, +} + +pub const RATIONAL_ZERO: Rational = Rational { num: 0, denom: 1 }; +pub const RATIONAL_ONE: Rational = Rational { num: 1, denom: 1 }; + +impl Display for Rational { + fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result { + if self.denom == 1 { + write!(f, "{}", self.num) + } else { + write!(f, "{}/{}", self.num, self.denom) + } + } +} + +impl FromStr for Rational { + type Err = std::num::ParseIntError; + fn from_str(s: &str) -> Result { + Ok(Rational { num: s.parse::()?, denom: 1 }) + } +} + +impl PartialEq for Rational { + fn eq(&self, other: &Rational) -> bool { + (self.num as i128) * (other.denom as i128) == (other.num as i128) * (self.denom as i128) + } +} + +impl Eq for Rational {} + +impl PartialOrd for Rational { + fn partial_cmp(&self, other: &Rational) -> Option { + i128::partial_cmp( + &((self.num as i128) * (other.denom as i128)), + &((other.num as i128) * (other.denom as i128)) + ) + } +} + +impl Ord for Rational { + fn cmp(&self, other: &Rational) -> Ordering { + i128::cmp( + &((self.num as i128) * (other.denom as i128)), + &((other.num as i128) * (other.denom as i128)) + ) + } +} + +impl Rational { + fn simplify(&mut self) { + let mut a = self.num.abs() as u64; + let mut b = self.denom; + + if a > b { + (a, b) = (b, a); + } + + while a > 0 { + (a, b) = (b % a, a); + } + + self.num /= b as i64; + self.denom /= b; + } +} + +// constant folding code essentially comes from egg examples, except using rationals instead of floats +#[derive(Default)] +pub struct ConstantFold; +impl Analysis for ConstantFold { + type Data = Option; + + fn make(egraph: &EGraph, enode: &EquationLanguage) -> Self::Data { + let x = |i: &Id| -> Self::Data { egraph[*i].data.clone() }; + + let mut value = match enode { + EquationLanguage::Num(c) => c.clone(), + EquationLanguage::Add([a,b]) => Rational { + num: x(a)?.num * x(b)?.denom as i64 + x(a)?.denom as i64 * x(b)?.num, + denom: x(a)?.denom * x(b)?.denom + }, + EquationLanguage::Sub([a,b]) => Rational { + num: x(a)?.num * x(b)?.denom as i64 - x(a)?.denom as i64 * x(b)?.num, + denom: x(a)?.denom * x(b)?.denom + }, + EquationLanguage::Mul([a,b]) => Rational { + num: x(a)?.num * x(b)?.num, + denom: x(a)?.denom * x(b)?.denom + }, + EquationLanguage::Div([a,b]) => { + if x(b)?.num == 0 { + return None; + } else if x(b)?.num > 0 { + Rational { + num: x(a)?.num * x(b)?.denom as i64, + denom: x(b)?.num as u64 * x(a)?.denom, + } + } else { + Rational { + num: - x(a)?.num * x(b)?.denom as i64, + denom: (-x(b)?.num) as u64 * x(a)?.denom, + } + } + }, + EquationLanguage::Neg([a]) => Rational { + num: -x(a)?.num, + denom: x(a)?.denom, + }, + EquationLanguage::Reciprocal([a]) => Rational { + num: if x(a)?.num > 0 { x(a)?.denom as i64 } else { - (x(a)?.denom as i64) }, + denom: x(a)?.num.abs() as u64, + }, + _ => return None, + }; + + value.simplify(); + + Some(value) + } + + fn merge(&mut self, to: &mut Self::Data, from: Self::Data) -> DidMerge { + merge_option(to, from, |a, b| { + assert!(a == &b, "Merged non-equal constants"); + DidMerge(false, false) + }) + } + + fn modify(egraph: &mut EGraph, id: Id) { + let data = egraph[id].data.clone(); + if let Some(c) = data { + let added = egraph.add(EquationLanguage::Num(c)); + egraph.union(id, added); + egraph[id].nodes.retain(|n|n.is_leaf()); + } + } +} + +fn is_nonzero_const(var: &str) -> impl Fn(&mut EGraph, Id, &Subst) -> bool { + let var: Var = var.parse().unwrap(); + move |egraph, _, subst| { + egraph[subst[var]].data.as_ref().filter(|x|*x != &RATIONAL_ZERO).is_some() + } +} + +pub static RULES: LazyLock> = LazyLock::new(||vec![ + rw!("commute-add"; "(+ ?x ?y)" => "(+ ?y ?x)"), + rw!("commute-mul"; "(* ?x ?y)" => "(* ?y ?x)"), + + rw!("assoc-add"; "(+ ?x (+ ?y ?z))" => "(+ (+ ?x ?y) ?z)"), + rw!("assoc-mul"; "(* ?x (* ?y ?z))" => "(* (* ?x ?y) ?z)"), + + rw!("add-0"; "(+ ?x 0)" => "?x"), + rw!("mul-0"; "(* ?x 0)" => "0"), + rw!("mul-1"; "(* ?x 1)" => "?x"), + + rw!("add-sub"; "(+ ?x (* (-1) ?x))" => "0"), + // division by zero shouldn't happen unless input is invalid + rw!("mul-div"; "(* ?x (rec ?x))" => "1" if is_nonzero_const("?y")), + + rw!("distribute"; "(* (+ ?x ?y) ?z)" => "(+ (* ?x ?z) (* ?y ?z))"), + rw!("factor"; "(+ (* ?x ?z) (* ?y ?z))" => "(* (+ ?x ?y) ?z)"), + + rw!("square"; "(^ ?x 2)" => "(* ?x ?x)"), + rw!("cube"; "(^ ?x 3)" => "(* ?x (* ?x ?x))"), + /* + rw!("inv-square"; "(* ?x ?x)" => "(^ ?x 2)"), + rw!("inv-cube"; "(* ?x (* ?x ?x))" => "(^ ?x 3)"), + */ + + rw!("sub"; "(- ?x ?y)" => "(+ ?x (* -1 ?y))"), + rw!("neg"; "(- ?x)" => "(* -1 ?x)"), + // division by zero shouldn't happen unless input is invalid + rw!("div"; "(/ ?x ?y)" => "(* ?x (rec ?y))" if is_nonzero_const("?y")), + + rw!("factor_poly"; "(+ (* x ?x) ?y)" => "(* ?x (+ x (* ?y (rec ?x))))" if is_nonzero_const("?x")), +]); + +pub struct PlusTimesCostFn; +impl egg::CostFunction for PlusTimesCostFn { + type Cost = usize; + fn cost(&mut self, enode: &EquationLanguage, mut costs: C) -> usize + where + C: FnMut(Id) -> usize, + { + let op_cost = match enode { + EquationLanguage::Div(_) => 1000, + EquationLanguage::Sub(_) => 1000, + EquationLanguage::Neg(_) => 1000, + EquationLanguage::Reciprocal(_) => 1000, + EquationLanguage::Power(_) => 1000, + _ => 1, + }; + enode.fold(op_cost, |sum, i| sum + costs(i)) + } +} + +#[derive(Debug,Clone,Copy)] +pub struct PolyStat { + degree: usize, + factors: usize, // non-constant factors + ops: usize, + monomial: bool, + sum_of_monomials: bool, + monic: bool, + factorized: bool, // a product of monic polynomials and at least one constant +} + +#[derive(Debug,Clone,Copy)] +pub enum FactorizationCost { + UnwantedOps, + Polynomial(PolyStat) +} + +fn score(cost: FactorizationCost) -> usize { + match cost { + FactorizationCost::UnwantedOps => 10000, + FactorizationCost::Polynomial(p) => + if !p.factorized { + 1000 + } else { + 100 * (9 - p.factors) + p.ops + }, + } +} + +impl PartialEq for FactorizationCost { + fn eq(&self, other: &Self) -> bool { + score(*self) == score(*other) + } +} + +impl PartialOrd for FactorizationCost { + fn partial_cmp(&self, other: &Self) -> Option { + usize::partial_cmp(&score(*self), &score(*other)) + } +} + +pub struct FactorizationCostFn; +impl egg::CostFunction for FactorizationCostFn { + type Cost = FactorizationCost; + + fn cost(&mut self, enode: &EquationLanguage, mut costs: C) -> Self::Cost + where + C: FnMut(Id) -> Self::Cost, + { + match enode { + EquationLanguage::Add([a,b]) => { + match (costs(*a), costs(*b)) { + (FactorizationCost::Polynomial(p1),FactorizationCost::Polynomial(p2)) => { + // we only ever want to add monomials + let result_monic = if p1.degree > p2.degree { + p1.monic + } else if p2.degree > p1.degree { + p2.monic + } else { + false + }; + + if !p1.sum_of_monomials || !p2.sum_of_monomials { + FactorizationCost::UnwantedOps + } else { + FactorizationCost::Polynomial(PolyStat { + degree: usize::max(p1.degree, p2.degree), + factors: 1, + ops: p1.ops + p2.ops, + monomial: false, + sum_of_monomials: p1.sum_of_monomials && p2.sum_of_monomials, + monic: result_monic, + factorized: result_monic, + }) + } + }, + _ => FactorizationCost::UnwantedOps + } + }, + EquationLanguage::Mul([a,b]) => { + match (costs(*a), costs(*b)) { + (FactorizationCost::Polynomial(p1), FactorizationCost::Polynomial(p2)) => { + FactorizationCost::Polynomial(PolyStat { + degree: p1.degree + p2.degree, + factors: p1.factors + p2.factors, + ops: p1.ops + p2.ops, + monomial: p1.monomial && p2.monomial, + sum_of_monomials: p1.monomial && p2.monomial, + monic: p1.monic && p2.monic, + factorized: (p1.monic && p2.factorized) || (p2.monic && p1.factorized) + }) + }, + _ => FactorizationCost::UnwantedOps + } + }, + EquationLanguage::Num(c) => { + FactorizationCost::Polynomial(PolyStat { + degree: 0, + factors: 0, + ops: 0, + monomial: true, + sum_of_monomials: true, + monic: false, + factorized: true + }) + }, + EquationLanguage::Unknown => { + FactorizationCost::Polynomial(PolyStat { + degree: 1, + factors: 1, + ops: 0, + monomial: true, + sum_of_monomials: true, + monic: true, + factorized: true + }) + }, + _ => FactorizationCost::UnwantedOps, + } + } +} diff --git a/src/lib.rs b/src/lib.rs new file mode 100644 index 0000000..bee4b54 --- /dev/null +++ b/src/lib.rs @@ -0,0 +1,3 @@ +pub mod language; +pub mod normal_form; +pub mod parse; diff --git a/src/main.rs b/src/main.rs new file mode 100644 index 0000000..fd78938 --- /dev/null +++ b/src/main.rs @@ -0,0 +1,226 @@ +use egg::{Extractor, Pattern, RecExpr, Runner}; +use solveq::language::{RULES, EquationLanguage, PlusTimesCostFn, FactorizationCostFn}; +use solveq::normal_form::analyze3; +use solveq::parse::parse_equation; + +static TEST_EQUATIONS: &[&str] = &[ + "(x + 50) * 10 - 150 - 100", + "(x - 2) * (x + 2) - 0", + "x ^ 2 - 4", + "x ^ 2 - 2 - 0", + "x ^ 2 - (2 * x + 15)", + "(x ^ 2 - 2 * x - 15) * (x + 5) - 0", + "x ^ 3 + 3 * x ^ 2 - 25 * x - 75 - 0", +]; + +fn main() { + for eq in TEST_EQUATIONS { + let start = parse_equation(*eq).unwrap(); + + // println!("{:?}", &start); + // do transformation to left - right = 0 + + let mut runner = Runner::default() + .with_explanations_enabled() + .with_expr(&start) + .run(&*RULES); + + let extractor = Extractor::new(&runner.egraph, FactorizationCostFn); + + let (best_cost, best_expr) = extractor.find_best(runner.roots[0]); + + println!("{}", start); + println!("{:?} {:?}", best_cost, as AsRef<[EquationLanguage]>>::as_ref(&best_expr)); + + + println!(""); + } + +// let root = runner.roots[0]; +// let egraph = &runner.egraph; +// let pattern: Pattern = "(+ (* ?a (* x x)) ?c)".parse().unwrap(); +// let matches = pattern.search(&egraph); + + + +// println!("{:?}", egraph.classes().count()); + + // Analyze +// analyze3(egraph, runner.roots[0]); + + /* + for class in egraph.classes() { + if monic_nonconst_polynomial(egraph, class.id).is_some() { + let (_, best_expr) = extractor.find_best(class.id); + println!("Monomial: {}", best_expr); + } + } + + println!("{:?}", &matches); + */ + + +// println!("{}", runner.explain_equivalence(&start, &best_expr).get_flat_string()); +} + + +/* +fn power_of_x(egraph: &EGraph, eclass: Id) -> Option { + for n in &egraph[eclass].nodes { + match *n { + EquationLanguage::Unknown => { return Some(1) }, + EquationLanguage::Mul([a,b]) => { + let Some(left) = power_of_x(egraph, a) else { continue }; + let Some(right) = power_of_x(egraph, b) else { continue }; + return Some(left + right); + }, + _ => {} + } + } + None +} + +fn monomial(egraph: &EGraph, eclass: Id) -> Option<(usize, Rational)> { + if let Some(deg) = power_of_x(egraph, eclass) { + return Some((deg, RATIONAL_ONE.clone())); + } + + for n in &egraph[eclass].nodes { + match *n { + EquationLanguage::Mul([a,b]) => { + let Some(coeff) = egraph[a].data.clone() else { continue }; + let Some(deg) = power_of_x(egraph, b) else { continue }; + return Some((deg, coeff)); + }, + _ => {} + } + } + None +} + + +// this is either a power_of_x, or a sum of this and a monomial +fn monic_nonconst_polynomial(egraph: &EGraph, eclass: Id) -> Option> { + let mut result: Vec = Vec::new(); + + if let Some(deg) = power_of_x(egraph, eclass) { + result.resize(deg - 1, RATIONAL_ZERO); + result.push(RATIONAL_ONE.clone()); + return Some(result); + } + + for n in &egraph[eclass].nodes { + match *n { + EquationLanguage::Add([a,b]) => { + let Some(mut leading) = monic_nonconst_polynomial(egraph, a) + else { continue }; + let Some(addon) = monomial(egraph, b) + else { continue }; + + if leading.len() <= addon.0 || leading[addon.0] != RATIONAL_ZERO { + continue; + } + + leading[addon.0] = addon.1.clone(); + return Some(leading); + }, + _ => {}, + } + } + None +} + +*/ + + +/* +fn analyze(egraph: &EGraph, _id: Id) { + let mut types: HashMap = HashMap::new(); + let mut todo: VecDeque = VecDeque::new(); + // todo.push_back(runner.roots[0]); + for cls in egraph.classes() { + todo.push_back(cls.id); + } + + 'todo: while todo.len() > 0 { + let id = todo.pop_front().unwrap(); + if types.contains_key(&id) { + continue 'todo; + } + + if let Some(c) = &egraph[id].data { + types.insert(id, SpecialTerm::Constant(c.clone())); + continue 'todo; + } + + 'nodes: for n in &egraph[id].nodes { + match *n { + EquationLanguage::Unknown => { + types.insert(id, SpecialTerm::PowerOfX(1)); + continue 'todo; + }, + EquationLanguage::Mul([a,b]) => { + if !types.contains_key(&a) { + todo.push_back(a); + todo.push_back(id); + continue 'nodes; + } + + if !types.contains_key(&b) { + todo.push_back(b); + todo.push_back(id); + continue 'nodes; + } + + match (&types[&a], &types[&b]) { + (SpecialTerm::PowerOfX(dega), SpecialTerm::PowerOfX(degb)) => { + types.insert(id, SpecialTerm::PowerOfX(*dega + *degb)); + }, + (SpecialTerm::Constant(coeff), SpecialTerm::PowerOfX(deg)) => { + types.insert(id, SpecialTerm::Monomial(*deg, coeff.clone())); + }, + _ => { continue 'nodes; }, + } + continue 'todo; + }, + EquationLanguage::Add([a,b]) => { + if !types.contains_key(&a) { + todo.push_front(a); + todo.push_back(id); + continue 'todo; + } + + if !types.contains_key(&b) { + todo.push_front(b); + todo.push_back(id); + continue 'todo; + } + + match (&types[&a], &types[&b]) { + (SpecialTerm::MonicNonconstPoly(poly), SpecialTerm::Monomial(deg, coeff)) => { + if poly.len() <= *deg || poly[*deg] != RATIONAL_ZERO { + continue 'nodes; + } + + let mut poly = poly.clone(); + poly[*deg] = coeff.clone(); + types.insert(id, SpecialTerm::MonicNonconstPoly(poly)); + }, + _ => { continue 'nodes; }, + } + continue 'todo; + }, + _ => {}, + } + } + + types.insert(id, SpecialTerm::Other); + } + + for (id, ty) in &types { + if !matches!(ty, &SpecialTerm::Other) { + println!("{:?}", &ty); + } + } +} +*/ diff --git a/src/normal_form.rs b/src/normal_form.rs new file mode 100644 index 0000000..2d776b5 --- /dev/null +++ b/src/normal_form.rs @@ -0,0 +1,229 @@ +use crate::language::{EGraph, EquationLanguage, Rational, RATIONAL_ONE, RATIONAL_ZERO}; +use std::collections::HashMap; +use egg::Id; + +#[derive(Debug,Clone)] +pub enum SpecialTerm { + Constant(Rational), + PowerOfX(usize), + Monomial(usize, Rational), + MonicNonconstPoly(Vec), + Factorization(Rational, Vec>), + Other, +} + +fn search_for(egraph: &EGraph, f: F) -> HashMap +where + F: Fn(Id, &EquationLanguage, &HashMap) -> Option { + + let mut result: HashMap = HashMap::new(); + let mut modifications: usize = 1; + + while modifications > 0 { + modifications = 0; + + for cls in egraph.classes() { + let id = cls.id; + if result.contains_key(&id) { + continue; + } + + for node in &cls.nodes { + if let Some(x) = f(id, node, &result) { + result.insert(id, x); + modifications += 1; + } + } + } + + println!("{} modifications!", modifications); + } + + result +} + +pub fn analyze3(egraph: &EGraph, eclass: Id) { + let constants = search_for(egraph, |id, _, _| + egraph[id].data.as_ref().map(|c|c.clone()) + ); + + println!("{:?}", constants); + + let powers_of_x = search_for(egraph, |_, node, matches| match *node { + EquationLanguage::Unknown => Some(1), + EquationLanguage::Mul([a,b]) => { + if !matches.contains_key(&a) || !matches.contains_key(&b) { + return None; + } + + let (dega, degb) = (matches[&a], matches[&b]); + Some(dega + degb) + }, + _ => None, + }); + + println!("{:?}", powers_of_x); + + let monomials = search_for(egraph, |id, node, _| { + if let Some(deg) = powers_of_x.get(&id) { + return Some((*deg, RATIONAL_ONE.clone())); + } + + if let Some(c) = constants.get(&id) { + return Some((0, c.clone())); + } + + match *node { + EquationLanguage::Mul([a,b]) => { + if !constants.contains_key(&a) || !powers_of_x.contains_key(&b) { + return None; + } + + let (coeff, deg) = (&constants[&a], powers_of_x[&b]); + Some((deg, coeff.clone())) + }, + _ => None, + } + }); + + println!("{:?}", monomials); + + let monic_polynomials = search_for(egraph, |id, node, matches| { + if let Some(deg) = powers_of_x.get(&id) { + let mut poly: Vec = Vec::new(); + poly.resize(*deg, RATIONAL_ZERO); + poly.push(RATIONAL_ONE.clone()); + Some(poly) + } else { + match *node { + EquationLanguage::Add([a,b]) => { + if !matches.contains_key(&a) || !monomials.contains_key(&b) { + return None; + } + + let (leading, (deg, coeff)) = (&matches[&a], &monomials[&b]); + if leading.len() <= *deg || leading[*deg] != RATIONAL_ZERO { + return None; + } + + let mut poly = leading.clone(); + poly[*deg] = coeff.clone(); + Some(poly) + }, + _ => None, + } + } + }); + + for p in &monic_polynomials { + println!("{:?}", p); + } + + let factorizations: HashMap>)> = search_for(egraph, |id, node, matches| { + if let Some(c) = constants.get(&id) { + return Some((c.clone(), vec![])); + } + + if let Some(poly) = monic_polynomials.get(&id) { + return Some((RATIONAL_ONE.clone(), vec![poly.clone()])); + } + + match *node { + EquationLanguage::Mul([a,b]) => { + if !matches.contains_key(&a) || !monic_polynomials.contains_key(&b) { + return None; + } + + let ((factor, polys), newpoly) = (&matches[&a], &monic_polynomials[&b]); + + let mut combined: Vec> = polys.clone(); + combined.push(newpoly.clone()); + Some((factor.clone(), combined)) + }, + _ => None, + } + }); + + /* + for p in &factorizations { + println!("{:?}", p); +} + */ + + println!("{:?}", factorizations[&eclass]); +} + +pub fn analyze2(egraph: &EGraph) -> HashMap { + let mut types: HashMap = HashMap::new(); + + let mut modifications: usize = 1; + + while modifications > 0 { + modifications = 0; + + for cls in egraph.classes() { + let id = cls.id; + if types.contains_key(&id) { + continue; + } + + if let Some(c) = &egraph[id].data { + types.insert(id, SpecialTerm::Constant(c.clone())); + modifications += 1; + continue; + } + + for node in &cls.nodes { + match *node { + EquationLanguage::Unknown => { + types.insert(id, SpecialTerm::PowerOfX(1)); + modifications += 1; + }, + EquationLanguage::Mul([a,b]) => { + // as we don't know a and b yet, defer to future iteration + if !types.contains_key(&a) || !types.contains_key(&b) { + continue; + } + + match (&types[&a], &types[&b]) { + (SpecialTerm::PowerOfX(dega), SpecialTerm::PowerOfX(degb)) => { + types.insert(id, SpecialTerm::PowerOfX(*dega + *degb)); + modifications += 1; + }, + (SpecialTerm::Constant(coeff), SpecialTerm::PowerOfX(deg)) => { + types.insert(id, SpecialTerm::Monomial(*deg, coeff.clone())); + modifications += 1; + }, + _ => { }, + } + }, + EquationLanguage::Add([a,b]) => { + // as we don't know a and b yet, defer to future iteration + if !types.contains_key(&a) || !types.contains_key(&b) { + continue; + } + + match (&types[&a], &types[&b]) { + (SpecialTerm::MonicNonconstPoly(poly), SpecialTerm::Monomial(deg, coeff)) => { + if poly.len() <= *deg || poly[*deg] != RATIONAL_ZERO { + continue; + } + + let mut poly = poly.clone(); + poly[*deg] = coeff.clone(); + types.insert(id, SpecialTerm::MonicNonconstPoly(poly)); + modifications += 1; + }, + _ => { }, + } + }, + _ => {} + } + } + } + + println!("{} modifications!", modifications); + } + + types +} diff --git a/src/parse.rs b/src/parse.rs new file mode 100644 index 0000000..2498e98 --- /dev/null +++ b/src/parse.rs @@ -0,0 +1,103 @@ +use std::error::Error; +use egg::*; +use crate::language::EquationLanguage; + +pub fn parse_equation(input: &str) -> Result, ParseError> { + let mut result: RecExpr = Default::default(); + parse_equation_inner(&input.replace(" ", ""), &mut result)?; + Ok(result) +} + +// this is a very simple parser essentially copied from the technical interview +fn parse_equation_inner(input: &str, expr: &mut RecExpr) -> Result { + let mut level = 0; + let mut precedence = 1000; // 0 = '=', 1 = '+-', 2 = '*/', 3 = '^' + let mut operator_position: Option = None; + + for (i,c) in input.chars().enumerate() { + if c == '(' { + level += 1; + } else if c == ')' { + level -= 1; + } + + if level > 0 { + continue; + } + + match c { + '^' if precedence > 3 => { + operator_position = Some(i); + precedence = 3; + }, + '*' | '/' if precedence > 2 => { + operator_position = Some(i); + precedence = 2; + }, + '-' | '+' if precedence > 1 => { + operator_position = Some(i); + precedence = 1; + }, + '=' => { + operator_position = Some(i); + precedence = 0; + }, + _ => {}, + } + } + + // no top level operator => either primitive item or in parantheses + if let Some(operator_position) = operator_position { + if operator_position == 0 && input.starts_with("-") { + let inner = parse_equation_inner(&input[1 .. input.len()], expr)?; + let id = expr.add(EquationLanguage::from_op("-", vec![inner])?); + return Ok(id); + } + + let left = parse_equation_inner(&input[0 .. operator_position], expr)?; + let right = parse_equation_inner(&input[operator_position+1 .. input.len()], expr)?; + + let id = expr.add(EquationLanguage::from_op( + &input[operator_position .. operator_position + 1], + vec![left, right] + )?); + + Ok(id) + } else { + if input.starts_with("(") && input.ends_with(")") { + // expression in parentheses + parse_equation_inner(&input[1..input.len()-1], expr) + } else { + // standalone integer + if input == "x" { + let id = expr.add(EquationLanguage::Unknown); + Ok(id) + } else { + input.parse::() + .map_err(|_|ParseError(format!("Failed conversion to i64: {}", &input)))?; + let id = expr.add(EquationLanguage::from_op( + input, vec![] + )?); + Ok(id) + } + } + } + +} + +#[derive(Debug)] +pub struct ParseError(String); + +impl Error for ParseError {} + +impl std::fmt::Display for ParseError { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + write!(f, "{}", &self.0) + } +} + +impl From for ParseError { + fn from(value: FromOpError) -> Self { + ParseError(format!("Error parsing {}", &value)) + } +}