From 0b8bdf3da65d1c0269a5273f8a33dd5d3c578388 Mon Sep 17 00:00:00 2001 From: Florian Stecker Date: Wed, 28 Aug 2024 12:31:28 -0400 Subject: [PATCH] cleaned up --- src/factorization.rs | 256 ------------------------------------------ src/language.rs | 8 +- src/lib.rs | 1 - src/main.rs | 147 +++++++++++------------- src/normal_form.rs | 258 +++++-------------------------------------- src/parse.rs | 2 +- 6 files changed, 98 insertions(+), 574 deletions(-) delete mode 100644 src/factorization.rs diff --git a/src/factorization.rs b/src/factorization.rs deleted file mode 100644 index 1c19f84..0000000 --- a/src/factorization.rs +++ /dev/null @@ -1,256 +0,0 @@ -use core::fmt; -use std::{cmp::Ordering, fmt::Display}; -use crate::language::{EquationLanguage, Rational, RATIONAL_ONE, RATIONAL_ZERO}; -use egg::{Id, RecExpr}; - - -#[derive(Debug,Clone,Copy)] -pub struct PolyStat { - degree: usize, - factors: usize, // non-constant factors - ops: usize, - monomial: bool, - sum_of_monomials: bool, - monic: bool, - factorized: bool, // a product of monic polynomials and at least one constant -} - -#[derive(Debug,Clone,Copy)] -pub enum FactorizationCost { - UnwantedOps, - Polynomial(PolyStat) -} - -fn score(cost: FactorizationCost) -> usize { - match cost { - FactorizationCost::UnwantedOps => 10000, - FactorizationCost::Polynomial(p) => - if !p.factorized { - 1000 + p.ops - } else { - 100 * (9 - p.factors) + p.ops - }, - } -} - -impl PartialEq for FactorizationCost { - fn eq(&self, other: &Self) -> bool { - score(*self) == score(*other) - } -} - -impl PartialOrd for FactorizationCost { - fn partial_cmp(&self, other: &Self) -> Option { - usize::partial_cmp(&score(*self), &score(*other)) - } -} - -pub struct FactorizationCostFn; -impl egg::CostFunction for FactorizationCostFn { - type Cost = FactorizationCost; - - fn cost(&mut self, enode: &EquationLanguage, mut costs: C) -> Self::Cost - where - C: FnMut(Id) -> Self::Cost, - { - match enode { - EquationLanguage::Add([a,b]) => { - match (costs(*a), costs(*b)) { - (FactorizationCost::Polynomial(p1),FactorizationCost::Polynomial(p2)) => { - // we only ever want to add monomials - let result_monic = if p1.degree > p2.degree { - p1.monic - } else if p2.degree > p1.degree { - p2.monic - } else { - false - }; - - /* - if *a == Id::from(4) && *b == Id::from(19) { - println!("HERE {:?} {:?}", p1, p2); - } - */ - - - if !p1.sum_of_monomials || !p2.sum_of_monomials { - FactorizationCost::UnwantedOps - } else { - FactorizationCost::Polynomial(PolyStat { - degree: usize::max(p1.degree, p2.degree), - factors: 1, - ops: p1.ops + p2.ops + 1, - monomial: false, - sum_of_monomials: p1.sum_of_monomials && p2.sum_of_monomials, - monic: result_monic, - factorized: result_monic, - }) - } - }, - _ => FactorizationCost::UnwantedOps - } - }, - EquationLanguage::Mul([a,b]) => { - match (costs(*a), costs(*b)) { - (FactorizationCost::Polynomial(p1), FactorizationCost::Polynomial(p2)) => { - FactorizationCost::Polynomial(PolyStat { - degree: p1.degree + p2.degree, - factors: p1.factors + p2.factors, - ops: p1.ops + p2.ops + 1, - monomial: p1.monomial && p2.monomial, - sum_of_monomials: p1.monomial && p2.monomial, - monic: p1.monic && p2.monic, - factorized: (p1.monic && p2.factorized) || (p2.monic && p1.factorized) - }) - }, - _ => FactorizationCost::UnwantedOps - } - }, - EquationLanguage::Num(c) => { - FactorizationCost::Polynomial(PolyStat { - degree: 0, - factors: 0, - ops: 0, - monomial: true, - sum_of_monomials: true, - monic: false, - factorized: true - }) - }, - EquationLanguage::Unknown => { - FactorizationCost::Polynomial(PolyStat { - degree: 1, - factors: 1, - ops: 0, - monomial: true, - sum_of_monomials: true, - monic: true, - factorized: true - }) - }, - _ => FactorizationCost::UnwantedOps, - } - } -} - -#[derive(Debug,Clone)] -pub struct Factorization { - pub constant_factor: Rational, - pub polynomials: Vec>, -} - -impl Display for Factorization { - fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { - if self.constant_factor != RATIONAL_ONE { - write!(f, "{}", self.constant_factor)?; - } - - for poly in &self.polynomials { - write!(f, "(")?; - for (deg, coeff) in poly.iter().enumerate() { - if deg == 0 { - write!(f, "{}", coeff)?; - } else if deg == 1 { - write!(f, " + {}x", coeff)?; - } else { - write!(f, " + {}x^{}", coeff, deg)?; - } - } - write!(f, ")")?; - } - - Ok(()) - } -} - -pub fn extract_factorization(expr: &RecExpr) -> Factorization { - let root_id: Id = Id::from(expr.as_ref().len()-1); - - let mut constant_factor: Option = None; - let mut factors: Vec> = Vec::new(); - let mut todo: Vec = Vec::new(); - todo.push(root_id); - - while todo.len() > 0 { - let id = todo.pop().unwrap(); - - match &expr[id] { - EquationLanguage::Mul([a,b]) => { - todo.push(*a); - todo.push(*b); - }, - EquationLanguage::Num(x) => { - assert!(constant_factor.is_none()); - constant_factor = Some(x.clone()); - }, - _ => { - factors.push(extract_polynomial(expr, id)); - } - } - } - - Factorization { - constant_factor: constant_factor.unwrap_or_else(||RATIONAL_ONE.clone()), - polynomials: factors - } -} - -fn extract_polynomial(expr: &RecExpr, id: Id) -> Vec { - let mut result: Vec = Vec::new(); - let mut todo: Vec = Vec::new(); - todo.push(id); - - while todo.len() > 0 { - let id = todo.pop().unwrap(); - - match &expr[id] { - EquationLanguage::Add([a,b]) => { - todo.push(*a); - todo.push(*b); - }, - _ => { - let (deg, coeff) = extract_monomial(expr, id); - result.resize(result.len().max(deg), RATIONAL_ZERO.clone()); - - if result.len() <= deg { - result.push(coeff); - } else { - assert!(result[deg] == RATIONAL_ZERO); - result[deg] = coeff; - } - } - } - } - - result -} - -fn extract_monomial(expr: &RecExpr, id: Id) -> (usize, Rational) { - let mut coeff: Option = None; - let mut deg: usize = 0; - let mut todo: Vec = Vec::new(); - todo.push(id); - - while todo.len() > 0 { - let id = todo.pop().unwrap(); - - match &expr[id] { - EquationLanguage::Unknown => { - deg += 1; - }, - EquationLanguage::Mul([a,b]) => { - todo.push(*a); - todo.push(*b); - }, - EquationLanguage::Num(x) => { - assert!(coeff.is_none()); - coeff = Some(x.clone()); - }, - _ => { - panic!("Not a rational polynomial in normal form!"); - } - } - } - - (deg, coeff.unwrap_or_else(||RATIONAL_ONE.clone())) -} diff --git a/src/language.rs b/src/language.rs index c3e430a..b3918ff 100644 --- a/src/language.rs +++ b/src/language.rs @@ -1,5 +1,5 @@ -use std::{cmp::Ordering, fmt::{self,Display, Formatter}, num::ParseIntError, str::FromStr, sync::LazyLock}; -use egg::{define_language, merge_option, rewrite as rw, Analysis, Applier, DidMerge, Id, Language, PatternAst, Subst, Symbol, SymbolLang, Var}; +use std::{cmp::Ordering, fmt::{self,Display, Formatter}, str::FromStr, sync::LazyLock}; +use egg::{define_language, merge_option, rewrite as rw, Analysis, Applier, DidMerge, Id, Language, PatternAst, Subst, Symbol, Var}; pub type EGraph = egg::EGraph; pub type Rewrite = egg::Rewrite; @@ -193,8 +193,8 @@ impl Applier for IntegerSqrt { egraph: &mut EGraph, matched_id: Id, subst: &Subst, - searcher_pattern: Option<&PatternAst>, - rule_name: Symbol) + _searcher_pattern: Option<&PatternAst>, + _rule_name: Symbol) -> Vec { let var_id = subst[self.var]; diff --git a/src/lib.rs b/src/lib.rs index 809f05c..9c8d9de 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -1,5 +1,4 @@ pub mod language; pub mod normal_form; pub mod parse; -pub mod factorization; pub mod output; diff --git a/src/main.rs b/src/main.rs index cb665fb..2045bc1 100644 --- a/src/main.rs +++ b/src/main.rs @@ -1,6 +1,5 @@ -use egg::{AstSize, Extractor, Id, Pattern, PatternAst, RecExpr, Runner, Searcher}; -use solveq::factorization::{FactorizationCost, FactorizationCostFn}; -use solveq::language::{ConstantFold, EquationLanguage, Rational, RULES, EGraph}; +use egg::{AstSize, Extractor, Id, RecExpr, Runner}; +use solveq::language::{EquationLanguage, Rational, RULES}; use solveq::normal_form::extract_normal_form; use solveq::parse::parse_equation; use solveq::output::print_term; @@ -15,94 +14,80 @@ static TEST_EQUATIONS: &[&str] = &[ "x ^ 3 + 3 * x ^ 2 - 25 * x - 75 = 0", ]; - fn main() { - // let expr: RecExpr = "(+ (* x (+ x -2)) -15)".parse().unwrap(); -// let expr: RecExpr = "(* (+ (* x (+ x -2)) -15) (+ x 5))".parse().unwrap(); -// println!("{:?}", get_expression_cost(&expr)); - for eq in TEST_EQUATIONS { println!("Equation: {}", *eq); - let mut start = parse_equation(*eq).unwrap(); - let root_id = Id::from(start.as_ref().len()-1); - let EquationLanguage::Equals([left, right]) = start[root_id] - else { panic!("Not an equation without an equals sign!"); }; - start[root_id] = EquationLanguage::Sub([left, right]); + let solutions = solve(*eq, true); - println!("Parsed: {}", &start); + let solutions_str: Vec = solutions + .iter() + .map(|expr| format!("x = {}", print_term(expr))) + .collect(); + println!("Solutions: {{ {} }}", solutions_str.join(", ")); - let mut runner = Runner::default() - .with_explanations_enabled() - .with_expr(&start) - .run(&*RULES); - - let extractor = Extractor::new(&runner.egraph, FactorizationCostFn); - - let (best_cost, best_expr) = extractor.find_best(runner.roots[0]); - -// println!("Best expresssion: {} {:?}", best_expr, best_cost); -// println!("{}", runner.explain_equivalence(&start, &best_expr).get_flat_string()); - - let Some(factorization) = extract_normal_form(&runner.egraph, runner.roots[0]) else { - panic!("Couldn't factorize polynomial!"); - }; - println!("Factorized normal form: {}", factorization); - - /* - get_expression_cost("(* x (+ x -2))", &runner.egraph); - get_expression_cost("-15", &runner.egraph); - get_expression_cost("(+ (* x (+ x -2)) -15)", &runner.egraph); - */ - -// let factorization = extract_factorization(&best_expr); - - - let mut solutions: Vec = Vec::new(); - for poly in &factorization.polynomials { - if poly.len() == 2 { // linear factor - let Rational { num, denom } = &poly[0]; - solutions.push(format!("x = {}", Rational { num: -*num, denom: *denom })); - } else if poly.len() == 3 { // quadratic factor - let Rational { num: num0, denom: denom0 } = &poly[0]; - let Rational { num: num1, denom: denom1 } = &poly[1]; - - let sol1 = format!("- ({num1})/(2 * ({denom1})) + ((({num1})/(2 * ({denom1}))) ^ 2 - ({num0}) / ({denom0})) ^ (1/2)"); - let sol2 = format!("- ({num1})/(2 * ({denom1})) - ((({num1})/(2 * ({denom1}))) ^ 2 - ({num0}) / ({denom0})) ^ (1/2)"); - - let expr = parse_equation(&sol1).unwrap(); - let runner = Runner::default() - .with_expr(&expr) - .run(&*RULES); - let extractor = Extractor::new(&runner.egraph, AstSize); - let (_, simplified_expr) = extractor.find_best(runner.roots[0]); - solutions.push(format!("x = {}", print_term(&simplified_expr))); - - let expr = parse_equation(&sol2).unwrap(); - let runner = Runner::default() - .with_expr(&expr) - .run(&*RULES); - let extractor = Extractor::new(&runner.egraph, AstSize); - let (_, simplified_expr) = extractor.find_best(runner.roots[0]); - solutions.push(format!("x = {}", print_term(&simplified_expr))); - } - } - - println!("Solutions: {{ {} }}", solutions.join(", ")); println!(""); } } -fn get_expression_cost(expr: &str, egraph: &EGraph) { -// let mut egraph = EGraph::new(ConstantFold::default()); - // let id = egraph.add_expr(expr); - let pattern: Pattern = expr.parse().unwrap(); - let matches = pattern.search(egraph); - for m in matches { - let extractor = Extractor::new(&egraph, FactorizationCostFn); - let (cost, _) = extractor.find_best(m.eclass); +pub fn solve(eq: &str, verbose: bool) -> Vec> { + let mut start = parse_equation(eq).unwrap(); + let root_id = Id::from(start.as_ref().len()-1); + let EquationLanguage::Equals([left, right]) = start[root_id] + else { panic!("Not an equation without an equals sign!"); }; + start[root_id] = EquationLanguage::Sub([left, right]); - println!("expr: {}, id: {}, cost: {:?}", expr, m.eclass, cost); + if verbose { + println!("Parsed: {}", &start); } -// cost + + let runner = Runner::default() + .with_explanations_enabled() + .with_expr(&start) + .run(&*RULES); + + let Some(factorization) = extract_normal_form(&runner.egraph, runner.roots[0]) else { + panic!("Couldn't factorize polynomial!"); + }; + + if verbose { + println!("Factorized normal form: {}", factorization); + } + + let mut solutions: Vec> = Vec::new(); + for poly in &factorization.polynomials { + if poly.len() == 2 { // linear factor + let Rational { num, denom } = &poly[0]; + let mut solexpr = RecExpr::default(); + solexpr.add(EquationLanguage::Num(Rational { + num: -*num, + denom: *denom + })); + solutions.push(solexpr); + } else if poly.len() == 3 { // quadratic factor + let Rational { num: num0, denom: denom0 } = &poly[0]; + let Rational { num: num1, denom: denom1 } = &poly[1]; + + let sol1 = format!("- ({num1})/(2 * ({denom1})) + ((({num1})/(2 * ({denom1}))) ^ 2 - ({num0}) / ({denom0})) ^ (1/2)"); + let sol2 = format!("- ({num1})/(2 * ({denom1})) - ((({num1})/(2 * ({denom1}))) ^ 2 - ({num0}) / ({denom0})) ^ (1/2)"); + + let expr = parse_equation(&sol1).unwrap(); + let runner = Runner::default() + .with_expr(&expr) + .run(&*RULES); + let extractor = Extractor::new(&runner.egraph, AstSize); + let (_, simplified_expr) = extractor.find_best(runner.roots[0]); + solutions.push(simplified_expr); + + let expr = parse_equation(&sol2).unwrap(); + let runner = Runner::default() + .with_expr(&expr) + .run(&*RULES); + let extractor = Extractor::new(&runner.egraph, AstSize); + let (_, simplified_expr) = extractor.find_best(runner.roots[0]); + solutions.push(simplified_expr); + } + } + + solutions } diff --git a/src/normal_form.rs b/src/normal_form.rs index 777f458..8ca3624 100644 --- a/src/normal_form.rs +++ b/src/normal_form.rs @@ -1,16 +1,6 @@ use crate::language::{EGraph, EquationLanguage, Rational, RATIONAL_ONE, RATIONAL_ZERO}; use std::{collections::HashMap, fmt}; -use egg::{AstSize, Extractor, Id}; - -#[derive(Debug,Clone)] -pub enum SpecialTerm { - Constant(Rational), - PowerOfX(usize), - Monomial(usize, Rational), - MonicNonconstPoly(Vec), - Factorization(Rational, Vec>), - Other, -} +use egg::Id; #[derive(Debug,Clone)] pub struct Factorization { @@ -18,7 +8,32 @@ pub struct Factorization { pub polynomials: Vec>, } -// this is a property of an eclass, not a particular expression +impl fmt::Display for Factorization { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + if self.constant_factor != RATIONAL_ONE { + write!(f, "{}", self.constant_factor)?; + } + + for poly in &self.polynomials { + write!(f, "(")?; + for (deg, coeff) in poly.iter().enumerate() { + if deg == 0 { + write!(f, "{}", coeff)?; + } else if deg == 1 { + write!(f, " + {}x", coeff)?; + } else { + write!(f, " + {}x^{}", coeff, deg)?; + } + } + write!(f, ")")?; + } + + Ok(()) + } +} + +// warning: this is a property of an eclass, not a particular expression +// so we shouldn't use anything that's not a polynomial invariant #[derive(Debug,Clone)] struct PolyStats { degree: usize, @@ -141,10 +156,6 @@ fn extract_polynomial(egraph: &EGraph, stats: &HashMap, id: Id) - // a monomial is either a power of x, a constant, or a product of a constant and power of x fn extract_monomial(egraph: &EGraph, stats: &HashMap, id: Id) -> Option<(usize, Rational)> { - let extractor = Extractor::new(egraph, AstSize); - let (_, expr) = extractor.find_best(id); -// println!("Extract Monomial: {}", expr); - let st = &stats[&id]; assert!(st.monomial); @@ -221,7 +232,6 @@ fn find_monic_factorization(egraph: &EGraph, stats: &HashMap, id: // now the whole thing is a monic nonconst poly, so would be a valid factorization, // but we want to go as deep as possible -// println!("{:?}", stats[&id]); // check if it is the product of two nonconstant monic polynomials for node in &egraph[id].nodes { @@ -260,8 +270,6 @@ where let mut result: HashMap = HashMap::new(); let mut modifications: usize = 1; -// println!("{:?}", egraph[canonical]); - while modifications > 0 { modifications = 0; @@ -285,215 +293,3 @@ where result } - -/* -pub fn analyze3(egraph: &EGraph, eclass: Id) { - let constants = search_for(egraph, |id, _, _| - egraph[id].data.as_ref().map(|c|c.clone()) - ); - - println!("{:?}", constants); - - let powers_of_x = search_for(egraph, |_, node, matches| match *node { - EquationLanguage::Unknown => Some(1), - EquationLanguage::Mul([a,b]) => { - if !matches.contains_key(&a) || !matches.contains_key(&b) { - return None; - } - - let (dega, degb) = (matches[&a], matches[&b]); - Some(dega + degb) - }, - _ => None, - }); - - println!("{:?}", powers_of_x); - - let monomials = search_for(egraph, |id, node, _| { - if let Some(deg) = powers_of_x.get(&id) { - return Some((*deg, RATIONAL_ONE.clone())); - } - - if let Some(c) = constants.get(&id) { - return Some((0, c.clone())); - } - - match *node { - EquationLanguage::Mul([a,b]) => { - if !constants.contains_key(&a) || !powers_of_x.contains_key(&b) { - return None; - } - - let (coeff, deg) = (&constants[&a], powers_of_x[&b]); - Some((deg, coeff.clone())) - }, - _ => None, - } - }); - - println!("{:?}", monomials); - - let monic_polynomials = search_for(egraph, |id, node, matches| { - if let Some(deg) = powers_of_x.get(&id) { - let mut poly: Vec = Vec::new(); - poly.resize(*deg, RATIONAL_ZERO); - poly.push(RATIONAL_ONE.clone()); - Some(poly) - } else { - match *node { - EquationLanguage::Add([a,b]) => { - if !matches.contains_key(&a) || !monomials.contains_key(&b) { - return None; - } - - let (leading, (deg, coeff)) = (&matches[&a], &monomials[&b]); - if leading.len() <= *deg || leading[*deg] != RATIONAL_ZERO { - return None; - } - - let mut poly = leading.clone(); - poly[*deg] = coeff.clone(); - Some(poly) - }, - _ => None, - } - } - }); - - for p in &monic_polynomials { - println!("{:?}", p); - } - - let factorizations: HashMap>)> = search_for(egraph, |id, node, matches| { - if let Some(c) = constants.get(&id) { - return Some((c.clone(), vec![])); - } - - if let Some(poly) = monic_polynomials.get(&id) { - return Some((RATIONAL_ONE.clone(), vec![poly.clone()])); - } - - match *node { - EquationLanguage::Mul([a,b]) => { - if !matches.contains_key(&a) || !monic_polynomials.contains_key(&b) { - return None; - } - - let ((factor, polys), newpoly) = (&matches[&a], &monic_polynomials[&b]); - - let mut combined: Vec> = polys.clone(); - combined.push(newpoly.clone()); - Some((factor.clone(), combined)) - }, - _ => None, - } - }); - - /* - for p in &factorizations { - println!("{:?}", p); -} - */ - - println!("{:?}", factorizations[&eclass]); -} - -pub fn analyze2(egraph: &EGraph) -> HashMap { - let mut types: HashMap = HashMap::new(); - - let mut modifications: usize = 1; - - while modifications > 0 { - modifications = 0; - - for cls in egraph.classes() { - let id = cls.id; - if types.contains_key(&id) { - continue; - } - - if let Some(c) = &egraph[id].data { - types.insert(id, SpecialTerm::Constant(c.clone())); - modifications += 1; - continue; - } - - for node in &cls.nodes { - match *node { - EquationLanguage::Unknown => { - types.insert(id, SpecialTerm::PowerOfX(1)); - modifications += 1; - }, - EquationLanguage::Mul([a,b]) => { - // as we don't know a and b yet, defer to future iteration - if !types.contains_key(&a) || !types.contains_key(&b) { - continue; - } - - match (&types[&a], &types[&b]) { - (SpecialTerm::PowerOfX(dega), SpecialTerm::PowerOfX(degb)) => { - types.insert(id, SpecialTerm::PowerOfX(*dega + *degb)); - modifications += 1; - }, - (SpecialTerm::Constant(coeff), SpecialTerm::PowerOfX(deg)) => { - types.insert(id, SpecialTerm::Monomial(*deg, coeff.clone())); - modifications += 1; - }, - _ => { }, - } - }, - EquationLanguage::Add([a,b]) => { - // as we don't know a and b yet, defer to future iteration - if !types.contains_key(&a) || !types.contains_key(&b) { - continue; - } - - match (&types[&a], &types[&b]) { - (SpecialTerm::MonicNonconstPoly(poly), SpecialTerm::Monomial(deg, coeff)) => { - if poly.len() <= *deg || poly[*deg] != RATIONAL_ZERO { - continue; - } - - let mut poly = poly.clone(); - poly[*deg] = coeff.clone(); - types.insert(id, SpecialTerm::MonicNonconstPoly(poly)); - modifications += 1; - }, - _ => { }, - } - }, - _ => {} - } - } - } - - println!("{} modifications!", modifications); - } - - types -} -*/ - -impl fmt::Display for Factorization { - fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { - if self.constant_factor != RATIONAL_ONE { - write!(f, "{}", self.constant_factor)?; - } - - for poly in &self.polynomials { - write!(f, "(")?; - for (deg, coeff) in poly.iter().enumerate() { - if deg == 0 { - write!(f, "{}", coeff)?; - } else if deg == 1 { - write!(f, " + {}x", coeff)?; - } else { - write!(f, " + {}x^{}", coeff, deg)?; - } - } - write!(f, ")")?; - } - - Ok(()) - } -} diff --git a/src/parse.rs b/src/parse.rs index 50e2817..c47ddc0 100644 --- a/src/parse.rs +++ b/src/parse.rs @@ -1,5 +1,5 @@ use std::error::Error; -use egg::*; +use egg::{RecExpr, Id, FromOp, FromOpError}; use crate::language::EquationLanguage; pub fn parse_equation(input: &str) -> Result, ParseError> {