use std::cmp::Ordering;

/// Structural summary of a term, built bottom-up by the factorization cost
/// function and ranked by [`score`].
#[derive(Debug, Clone, Copy)]
pub struct PolyStat {
    degree: usize,
    factors: usize, // non-constant factors
    ops: usize,
    monomial: bool,
    sum_of_monomials: bool,
    monic: bool,
    factorized: bool, // a product of monic polynomials and at least one constant
}

/// Extraction cost of a term: either it contains an operator the
/// factorization does not want at all, or it is a polynomial described by a
/// [`PolyStat`].
///
/// Equality and ordering are defined purely through [`score`], so two
/// different `PolyStat`s with the same score compare as equal.
#[derive(Debug, Clone, Copy)]
pub enum FactorizationCost {
    UnwantedOps,
    Polynomial(PolyStat),
}

/// Maps a cost to a single number; lower is better.
///
/// * `UnwantedOps` scores 10000.
/// * A polynomial that is not factorized scores `1000 + ops`.
/// * A factorized polynomial scores `100 * (9 - factors) + ops`, so more
///   non-constant factors are preferred.
///
/// The factor bonus saturates at nine factors. A plain `9 - factors` is an
/// unsigned subtraction, which panics in debug builds and wraps in release
/// builds as soon as a product has ten or more factors.
fn score(cost: FactorizationCost) -> usize {
    match cost {
        FactorizationCost::UnwantedOps => 10000,
        FactorizationCost::Polynomial(p) if !p.factorized => 1000 + p.ops,
        FactorizationCost::Polynomial(p) => 100 * 9usize.saturating_sub(p.factors) + p.ops,
    }
}

impl PartialEq for FactorizationCost {
    fn eq(&self, other: &Self) -> bool {
        score(*self) == score(*other)
    }
}

impl PartialOrd for FactorizationCost {
    fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
        // Scores are plain integers, so the comparison is always defined.
        Some(score(*self).cmp(&score(*other)))
    }
}
monomial: false, + sum_of_monomials: p1.sum_of_monomials && p2.sum_of_monomials, + monic: result_monic, + factorized: result_monic, + }) + } + }, + _ => FactorizationCost::UnwantedOps + } + }, + EquationLanguage::Mul([a,b]) => { + match (costs(*a), costs(*b)) { + (FactorizationCost::Polynomial(p1), FactorizationCost::Polynomial(p2)) => { + FactorizationCost::Polynomial(PolyStat { + degree: p1.degree + p2.degree, + factors: p1.factors + p2.factors, + ops: p1.ops + p2.ops, + monomial: p1.monomial && p2.monomial, + sum_of_monomials: p1.monomial && p2.monomial, + monic: p1.monic && p2.monic, + factorized: (p1.monic && p2.factorized) || (p2.monic && p1.factorized) + }) + }, + _ => FactorizationCost::UnwantedOps + } + }, + EquationLanguage::Num(c) => { + FactorizationCost::Polynomial(PolyStat { + degree: 0, + factors: 0, + ops: 0, + monomial: true, + sum_of_monomials: true, + monic: false, + factorized: true + }) + }, + EquationLanguage::Unknown => { + FactorizationCost::Polynomial(PolyStat { + degree: 1, + factors: 1, + ops: 0, + monomial: true, + sum_of_monomials: true, + monic: true, + factorized: true + }) + }, + _ => FactorizationCost::UnwantedOps, + } + } +} + +#[derive(Debug,Clone)] +pub struct Factorization { + pub constant_factor: Rational, + pub polynomials: Vec>, +} + +impl Display for Factorization { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + if self.constant_factor != RATIONAL_ONE { + write!(f, "{}", self.constant_factor)?; + } + + for poly in &self.polynomials { + write!(f, "(")?; + for (deg, coeff) in poly.iter().enumerate() { + if deg == 0 { + write!(f, "{}", coeff)?; + } else if deg == 1 { + write!(f, " + {}x", coeff)?; + } else { + write!(f, " + {}x^{}", coeff, deg)?; + } + } + write!(f, ")")?; + } + + Ok(()) + } +} + +pub fn extract_factorization(expr: &RecExpr) -> Factorization { + let root_id: Id = Id::from(expr.as_ref().len()-1); + + let mut constant_factor: Option = None; + let mut factors: Vec> = Vec::new(); + let mut todo: Vec 
= Vec::new(); + todo.push(root_id); + + while todo.len() > 0 { + let id = todo.pop().unwrap(); + + match &expr[id] { + EquationLanguage::Mul([a,b]) => { + todo.push(*a); + todo.push(*b); + }, + EquationLanguage::Num(x) => { + assert!(constant_factor.is_none()); + constant_factor = Some(x.clone()); + }, + _ => { + factors.push(extract_polynomial(expr, id)); + } + } + } + + Factorization { + constant_factor: constant_factor.unwrap_or_else(||RATIONAL_ONE.clone()), + polynomials: factors + } +} + +fn extract_polynomial(expr: &RecExpr, id: Id) -> Vec { + let mut result: Vec = Vec::new(); + let mut todo: Vec = Vec::new(); + todo.push(id); + + while todo.len() > 0 { + let id = todo.pop().unwrap(); + + match &expr[id] { + EquationLanguage::Add([a,b]) => { + todo.push(*a); + todo.push(*b); + }, + _ => { + let (deg, coeff) = extract_monomial(expr, id); + result.resize(result.len().max(deg), RATIONAL_ZERO.clone()); + + if result.len() <= deg { + result.push(coeff); + } else { + assert!(result[deg] == RATIONAL_ZERO); + result[deg] = coeff; + } + } + } + } + + result +} + +fn extract_monomial(expr: &RecExpr, id: Id) -> (usize, Rational) { + let mut coeff: Option = None; + let mut deg: usize = 0; + let mut todo: Vec = Vec::new(); + todo.push(id); + + while todo.len() > 0 { + let id = todo.pop().unwrap(); + + match &expr[id] { + EquationLanguage::Unknown => { + deg += 1; + }, + EquationLanguage::Mul([a,b]) => { + todo.push(*a); + todo.push(*b); + }, + EquationLanguage::Num(x) => { + assert!(coeff.is_none()); + coeff = Some(x.clone()); + }, + _ => { + panic!("Not a rational polynomial in normal form!"); + } + } + } + + (deg, coeff.unwrap_or_else(||RATIONAL_ONE.clone())) +} diff --git a/src/lib.rs b/src/lib.rs index bee4b54..809f05c 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -1,3 +1,5 @@ pub mod language; pub mod normal_form; pub mod parse; +pub mod factorization; +pub mod output; diff --git a/src/main.rs b/src/main.rs index fd78938..c0c44f2 100644 --- a/src/main.rs +++ 
b/src/main.rs @@ -1,24 +1,35 @@ -use egg::{Extractor, Pattern, RecExpr, Runner}; -use solveq::language::{RULES, EquationLanguage, PlusTimesCostFn, FactorizationCostFn}; +use egg::{AstSize, EGraph, Extractor, Id, Pattern, RecExpr, Runner}; +use solveq::factorization::{extract_factorization, FactorizationCost, FactorizationCostFn}; +use solveq::language::{ConstantFold, EquationLanguage, FactorizationCostFn, PlusTimesCostFn, Rational, RULES}; use solveq::normal_form::analyze3; use solveq::parse::parse_equation; +use solveq::output::print_term; static TEST_EQUATIONS: &[&str] = &[ - "(x + 50) * 10 - 150 - 100", - "(x - 2) * (x + 2) - 0", - "x ^ 2 - 4", - "x ^ 2 - 2 - 0", - "x ^ 2 - (2 * x + 15)", - "(x ^ 2 - 2 * x - 15) * (x + 5) - 0", - "x ^ 3 + 3 * x ^ 2 - 25 * x - 75 - 0", + "(x + 50) * 10 - 150 = 100", + "(x - 2) * (x + 2) = 0", + "x ^ 2 = 4", + "x ^ 2 - 2 = 0", + "x ^ 2 = 2 * x + 15", + "(x ^ 2 - 2 * x - 15) * (x + 5) = 0", + "x ^ 3 + 3 * x ^ 2 - 25 * x - 75 = 0", ]; -fn main() { - for eq in TEST_EQUATIONS { - let start = parse_equation(*eq).unwrap(); - // println!("{:?}", &start); - // do transformation to left - right = 0 +fn main() { + let expr: RecExpr = "(* x (+ x -2))".parse().unwrap(); + println!("{:?}", get_expression_cost(&expr)); + + for eq in TEST_EQUATIONS { + println!("Equation: {}", *eq); + + let mut start = parse_equation(*eq).unwrap(); + let root_id = Id::from(start.as_ref().len()-1); + let EquationLanguage::Equals([left, right]) = start[root_id] + else { panic!("Not an equation without an equals sign!"); }; + start[root_id] = EquationLanguage::Sub([left, right]); + + println!("Parsed: {}", &start); let mut runner = Runner::default() .with_explanations_enabled() @@ -29,198 +40,55 @@ fn main() { let (best_cost, best_expr) = extractor.find_best(runner.roots[0]); - println!("{}", start); - println!("{:?} {:?}", best_cost, as AsRef<[EquationLanguage]>>::as_ref(&best_expr)); +// println!("{:?} {:?}", best_cost, as 
AsRef<[EquationLanguage]>>::as_ref(&best_expr)); + println!("Best expresssion: {} {:?}", best_expr, best_cost); - println!(""); - } + let factorization = extract_factorization(&best_expr); -// let root = runner.roots[0]; -// let egraph = &runner.egraph; -// let pattern: Pattern = "(+ (* ?a (* x x)) ?c)".parse().unwrap(); -// let matches = pattern.search(&egraph); +// println!("{}", runner.explain_equivalence(&start, &best_expr).get_flat_string()); + println!("Factorized normal form: {}", factorization); + let mut solutions: Vec = Vec::new(); + for poly in &factorization.polynomials { + if poly.len() == 2 { // linear factor + let Rational { num, denom } = &poly[0]; + solutions.push(format!("x = {}", Rational { num: -*num, denom: *denom })); + } else if poly.len() == 3 { // quadratic factor + let Rational { num: num0, denom: denom0 } = &poly[0]; + let Rational { num: num1, denom: denom1 } = &poly[1]; -// println!("{:?}", egraph.classes().count()); + let sol1 = format!("- ({num1})/(2 * ({denom1})) + ((({num1})/(2 * ({denom1}))) ^ 2 - ({num0}) / ({denom0})) ^ (1/2)"); + let sol2 = format!("- ({num1})/(2 * ({denom1})) - ((({num1})/(2 * ({denom1}))) ^ 2 - ({num0}) / ({denom0})) ^ (1/2)"); - // Analyze -// analyze3(egraph, runner.roots[0]); + let expr = parse_equation(&sol1).unwrap(); + let runner = Runner::default() + .with_expr(&expr) + .run(&*RULES); + let extractor = Extractor::new(&runner.egraph, AstSize); + let (_, simplified_expr) = extractor.find_best(runner.roots[0]); + solutions.push(format!("x = {}", print_term(&simplified_expr))); - /* - for class in egraph.classes() { - if monic_nonconst_polynomial(egraph, class.id).is_some() { - let (_, best_expr) = extractor.find_best(class.id); - println!("Monomial: {}", best_expr); - } - } - - println!("{:?}", &matches); - */ - - -// println!("{}", runner.explain_equivalence(&start, &best_expr).get_flat_string()); -} - - -/* -fn power_of_x(egraph: &EGraph, eclass: Id) -> Option { - for n in &egraph[eclass].nodes { - match 
*n { - EquationLanguage::Unknown => { return Some(1) }, - EquationLanguage::Mul([a,b]) => { - let Some(left) = power_of_x(egraph, a) else { continue }; - let Some(right) = power_of_x(egraph, b) else { continue }; - return Some(left + right); - }, - _ => {} - } - } - None -} - -fn monomial(egraph: &EGraph, eclass: Id) -> Option<(usize, Rational)> { - if let Some(deg) = power_of_x(egraph, eclass) { - return Some((deg, RATIONAL_ONE.clone())); - } - - for n in &egraph[eclass].nodes { - match *n { - EquationLanguage::Mul([a,b]) => { - let Some(coeff) = egraph[a].data.clone() else { continue }; - let Some(deg) = power_of_x(egraph, b) else { continue }; - return Some((deg, coeff)); - }, - _ => {} - } - } - None -} - - -// this is either a power_of_x, or a sum of this and a monomial -fn monic_nonconst_polynomial(egraph: &EGraph, eclass: Id) -> Option> { - let mut result: Vec = Vec::new(); - - if let Some(deg) = power_of_x(egraph, eclass) { - result.resize(deg - 1, RATIONAL_ZERO); - result.push(RATIONAL_ONE.clone()); - return Some(result); - } - - for n in &egraph[eclass].nodes { - match *n { - EquationLanguage::Add([a,b]) => { - let Some(mut leading) = monic_nonconst_polynomial(egraph, a) - else { continue }; - let Some(addon) = monomial(egraph, b) - else { continue }; - - if leading.len() <= addon.0 || leading[addon.0] != RATIONAL_ZERO { - continue; - } - - leading[addon.0] = addon.1.clone(); - return Some(leading); - }, - _ => {}, - } - } - None -} - -*/ - - -/* -fn analyze(egraph: &EGraph, _id: Id) { - let mut types: HashMap = HashMap::new(); - let mut todo: VecDeque = VecDeque::new(); - // todo.push_back(runner.roots[0]); - for cls in egraph.classes() { - todo.push_back(cls.id); - } - - 'todo: while todo.len() > 0 { - let id = todo.pop_front().unwrap(); - if types.contains_key(&id) { - continue 'todo; - } - - if let Some(c) = &egraph[id].data { - types.insert(id, SpecialTerm::Constant(c.clone())); - continue 'todo; - } - - 'nodes: for n in &egraph[id].nodes { - match 
*n { - EquationLanguage::Unknown => { - types.insert(id, SpecialTerm::PowerOfX(1)); - continue 'todo; - }, - EquationLanguage::Mul([a,b]) => { - if !types.contains_key(&a) { - todo.push_back(a); - todo.push_back(id); - continue 'nodes; - } - - if !types.contains_key(&b) { - todo.push_back(b); - todo.push_back(id); - continue 'nodes; - } - - match (&types[&a], &types[&b]) { - (SpecialTerm::PowerOfX(dega), SpecialTerm::PowerOfX(degb)) => { - types.insert(id, SpecialTerm::PowerOfX(*dega + *degb)); - }, - (SpecialTerm::Constant(coeff), SpecialTerm::PowerOfX(deg)) => { - types.insert(id, SpecialTerm::Monomial(*deg, coeff.clone())); - }, - _ => { continue 'nodes; }, - } - continue 'todo; - }, - EquationLanguage::Add([a,b]) => { - if !types.contains_key(&a) { - todo.push_front(a); - todo.push_back(id); - continue 'todo; - } - - if !types.contains_key(&b) { - todo.push_front(b); - todo.push_back(id); - continue 'todo; - } - - match (&types[&a], &types[&b]) { - (SpecialTerm::MonicNonconstPoly(poly), SpecialTerm::Monomial(deg, coeff)) => { - if poly.len() <= *deg || poly[*deg] != RATIONAL_ZERO { - continue 'nodes; - } - - let mut poly = poly.clone(); - poly[*deg] = coeff.clone(); - types.insert(id, SpecialTerm::MonicNonconstPoly(poly)); - }, - _ => { continue 'nodes; }, - } - continue 'todo; - }, - _ => {}, + let expr = parse_equation(&sol2).unwrap(); + let runner = Runner::default() + .with_expr(&expr) + .run(&*RULES); + let extractor = Extractor::new(&runner.egraph, AstSize); + let (_, simplified_expr) = extractor.find_best(runner.roots[0]); + solutions.push(format!("x = {}", print_term(&simplified_expr))); } } - types.insert(id, SpecialTerm::Other); - } - - for (id, ty) in &types { - if !matches!(ty, &SpecialTerm::Other) { - println!("{:?}", &ty); - } + println!("Solutions: {{ {} }}", solutions.join(", ")); + println!(""); } } -*/ + +fn get_expression_cost(expr: &RecExpr) -> FactorizationCost { + let mut egraph = EGraph::new(ConstantFold::default()); + let id = 
egraph.add_expr(expr); + let extractor = Extractor::new(&egraph, FactorizationCostFn); + let (cost, _) = extractor.find_best(id); + cost +} diff --git a/src/output.rs b/src/output.rs new file mode 100644 index 0000000..fae24ba --- /dev/null +++ b/src/output.rs @@ -0,0 +1,66 @@ +use egg::{RecExpr, Id}; +use crate::language::EquationLanguage; + +// there is already a Display implementation generated by define_langauge! +// but we want an alternative string conversion +pub fn print_term(expr: &RecExpr) -> String { + let root_id = Id::from(expr.as_ref().len()-1); + print_term_inner(expr, root_id).0 +} + +// the second result is the precedence of the top level op: 1 = '+-', 2 = '*/', 3 = '^', 4 = primitive +fn print_term_inner(expr: &RecExpr, id: Id) -> (String, usize) { + match &expr[id] { + EquationLanguage::Num(c) => { + (format!("{}", c), if c.denom == 1 { 4 } else { 2 }) + }, + EquationLanguage::Neg([a]) => { + (print_unary(expr, *a, "-", 1), 1) + }, + EquationLanguage::Add([a,b]) => { + (print_binary(expr, *a, *b, "+", 1), 1) + }, + EquationLanguage::Sub([a,b]) => { + (print_binary(expr, *a, *b, "-", 1), 1) + }, + EquationLanguage::Mul([a,b]) => { + (print_binary(expr, *a, *b, "*", 2), 2) + }, + EquationLanguage::Div([a,b]) => { + (print_binary(expr, *a, *b, "/", 2), 2) + }, + EquationLanguage::Power([a,b]) => { + (print_binary(expr, *a, *b, "^", 3), 3) + }, + _ => unimplemented!() + } +} + +fn print_unary(expr: &RecExpr, a: Id, op: &str, precedence: usize) -> String { + let (astr, aprec) = print_term_inner(expr, a); + + if aprec > precedence { + format!("{}{}", op, astr) + } else { + format!("{}({})", op, astr) + } +} + +fn print_binary(expr: &RecExpr, a: Id, b: Id, op: &str, precedence: usize) -> String { + let (astr, aprec) = print_term_inner(expr, a); + let (bstr, bprec) = print_term_inner(expr, b); + + if aprec > precedence { + if bprec > precedence { + format!("{} {} {}", astr, op, bstr) + } else { + format!("{} {} ({})", astr, op, bstr) + } + } else { + if 
bprec > precedence { + format!("({}) {} {}", astr, op, bstr) + } else { + format!("({}) {} ({})", astr, op, bstr) + } + } +} diff --git a/src/parse.rs b/src/parse.rs index 2498e98..50e2817 100644 --- a/src/parse.rs +++ b/src/parse.rs @@ -26,15 +26,15 @@ fn parse_equation_inner(input: &str, expr: &mut RecExpr) -> Re } match c { - '^' if precedence > 3 => { + '^' if precedence >= 3 => { operator_position = Some(i); precedence = 3; }, - '*' | '/' if precedence > 2 => { + '*' | '/' if precedence >= 2 => { operator_position = Some(i); precedence = 2; }, - '-' | '+' if precedence > 1 => { + '-' | '+' if precedence >= 1 => { operator_position = Some(i); precedence = 1; },