From ef2a76869f77d18412d85a06bfdb6a69d79f0717 Mon Sep 17 00:00:00 2001 From: Florian Stecker Date: Wed, 28 Aug 2024 11:25:55 -0400 Subject: [PATCH] fixed it! --- src/factorization.rs | 11 +- src/language.rs | 135 ++------------------- src/main.rs | 50 +++++--- src/normal_form.rs | 282 ++++++++++++++++++++++++++++++++++++++++++- 4 files changed, 326 insertions(+), 152 deletions(-) diff --git a/src/factorization.rs b/src/factorization.rs index c118654..1c19f84 100644 --- a/src/factorization.rs +++ b/src/factorization.rs @@ -66,13 +66,20 @@ impl egg::CostFunction for FactorizationCostFn { false }; + /* + if *a == Id::from(4) && *b == Id::from(19) { + println!("HERE {:?} {:?}", p1, p2); + } + */ + + if !p1.sum_of_monomials || !p2.sum_of_monomials { FactorizationCost::UnwantedOps } else { FactorizationCost::Polynomial(PolyStat { degree: usize::max(p1.degree, p2.degree), factors: 1, - ops: p1.ops + p2.ops, + ops: p1.ops + p2.ops + 1, monomial: false, sum_of_monomials: p1.sum_of_monomials && p2.sum_of_monomials, monic: result_monic, @@ -89,7 +96,7 @@ impl egg::CostFunction for FactorizationCostFn { FactorizationCost::Polynomial(PolyStat { degree: p1.degree + p2.degree, factors: p1.factors + p2.factors, - ops: p1.ops + p2.ops, + ops: p1.ops + p2.ops + 1, monomial: p1.monomial && p2.monomial, sum_of_monomials: p1.monomial && p2.monomial, monic: p1.monic && p2.monic, diff --git a/src/language.rs b/src/language.rs index 218a5a0..04b8463 100644 --- a/src/language.rs +++ b/src/language.rs @@ -167,6 +167,13 @@ fn is_nonzero_const(var: &str) -> impl Fn(&mut EGraph, Id, &Subst) -> bool { } } +fn is_const(var: &str) -> impl Fn(&mut EGraph, Id, &Subst) -> bool { + let var: Var = var.parse().unwrap(); + move |egraph, _, subst| { + egraph[subst[var]].data.is_some() + } +} + pub static RULES: LazyLock> = LazyLock::new(||vec![ rw!("commute-add"; "(+ ?x ?y)" => "(+ ?y ?x)"), rw!("commute-mul"; "(* ?x ?y)" => "(* ?y ?x)"), @@ -187,10 +194,6 @@ pub static RULES: LazyLock> = LazyLock::new(||vec![ rw!("square"; "(^ ?x 2)" => "(* ?x ?x)"), rw!("cube"; "(^ ?x 3)" => "(* ?x (* ?x ?x))"), - /* - rw!("inv-square"; "(* ?x ?x)" => "(^ ?x 2)"), - rw!("inv-cube"; "(* ?x (* ?x ?x))" => "(^ ?x 3)"), - */ rw!("sub"; "(- ?x ?y)" => "(+ ?x (* -1 ?y))"), rw!("neg"; "(- ?x)" => "(* -1 ?x)"), @@ -198,6 +201,8 @@ pub static RULES: LazyLock> = LazyLock::new(||vec![ rw!("div"; "(/ ?x ?y)" => "(* ?x (rec ?y))" if is_nonzero_const("?y")), rw!("factor_poly"; "(+ (* x ?x) ?y)" => "(* ?x (+ x (* ?y (rec ?x))))" if is_nonzero_const("?x")), + +// rw!("integer_sqrt"; "(^ ?x (1/2))" => {} if is_const("?x")), ]); pub struct PlusTimesCostFn; @@ -218,125 +223,3 @@ impl egg::CostFunction for PlusTimesCostFn { enode.fold(op_cost, |sum, i| sum + costs(i)) } } - -#[derive(Debug,Clone,Copy)] -pub struct PolyStat { - degree: usize, - factors: usize, // non-constant factors - ops: usize, - monomial: bool, - sum_of_monomials: bool, - monic: bool, - factorized: bool, // a product of monic polynomials and at least one constant -} - -#[derive(Debug,Clone,Copy)] -pub enum FactorizationCost { - UnwantedOps, - Polynomial(PolyStat) -} - -fn score(cost: FactorizationCost) -> usize { - match cost { - FactorizationCost::UnwantedOps => 10000, - FactorizationCost::Polynomial(p) => - if !p.factorized { - 1000 - } else { - 100 * (9 - p.factors) + p.ops - }, - } -} - -impl PartialEq for FactorizationCost { - fn eq(&self, other: &Self) -> bool { - score(*self) == score(*other) - } -} - -impl PartialOrd for FactorizationCost { - fn partial_cmp(&self, other: &Self) -> Option { - usize::partial_cmp(&score(*self), &score(*other)) - } -} - -pub struct FactorizationCostFn; -impl egg::CostFunction for FactorizationCostFn { - type Cost = FactorizationCost; - - fn cost(&mut self, enode: &EquationLanguage, mut costs: C) -> Self::Cost - where - C: FnMut(Id) -> Self::Cost, - { - match enode { - EquationLanguage::Add([a,b]) => { - match (costs(*a), costs(*b)) { - (FactorizationCost::Polynomial(p1),FactorizationCost::Polynomial(p2)) => { - // we only ever want to add monomials - let result_monic = if p1.degree > p2.degree { - p1.monic - } else if p2.degree > p1.degree { - p2.monic - } else { - false - }; - - if !p1.sum_of_monomials || !p2.sum_of_monomials { - FactorizationCost::UnwantedOps - } else { - FactorizationCost::Polynomial(PolyStat { - degree: usize::max(p1.degree, p2.degree), - factors: 1, - ops: p1.ops + p2.ops, - monomial: false, - sum_of_monomials: p1.sum_of_monomials && p2.sum_of_monomials, - monic: result_monic, - factorized: result_monic, - }) - } - }, - _ => FactorizationCost::UnwantedOps - } - }, - EquationLanguage::Mul([a,b]) => { - match (costs(*a), costs(*b)) { - (FactorizationCost::Polynomial(p1), FactorizationCost::Polynomial(p2)) => { - FactorizationCost::Polynomial(PolyStat { - degree: p1.degree + p2.degree, - factors: p1.factors + p2.factors, - ops: p1.ops + p2.ops, - monomial: p1.monomial && p2.monomial, - sum_of_monomials: p1.monomial && p2.monomial, - monic: p1.monic && p2.monic, - factorized: (p1.monic && p2.factorized) || (p2.monic && p1.factorized) - }) - }, - _ => FactorizationCost::UnwantedOps - } - }, - EquationLanguage::Num(c) => { - FactorizationCost::Polynomial(PolyStat { - degree: 0, - factors: 0, - ops: 0, - monomial: true, - sum_of_monomials: true, - monic: false, - factorized: true - }) - }, - EquationLanguage::Unknown => { - FactorizationCost::Polynomial(PolyStat { - degree: 1, - factors: 1, - ops: 0, - monomial: true, - sum_of_monomials: true, - monic: true, - factorized: true - }) - }, - _ => FactorizationCost::UnwantedOps, - } - } -} diff --git a/src/main.rs b/src/main.rs index c0c44f2..cb665fb 100644 --- a/src/main.rs +++ b/src/main.rs @@ -1,7 +1,7 @@ -use egg::{AstSize, EGraph, Extractor, Id, Pattern, RecExpr, Runner}; -use solveq::factorization::{extract_factorization, FactorizationCost, FactorizationCostFn}; -use solveq::language::{ConstantFold, EquationLanguage, FactorizationCostFn, PlusTimesCostFn, Rational, RULES}; -use solveq::normal_form::analyze3; +use egg::{AstSize, Extractor, Id, Pattern, PatternAst, RecExpr, Runner, Searcher}; +use solveq::factorization::{FactorizationCost, FactorizationCostFn}; +use solveq::language::{ConstantFold, EquationLanguage, Rational, RULES, EGraph}; +use solveq::normal_form::extract_normal_form; use solveq::parse::parse_equation; use solveq::output::print_term; @@ -17,8 +17,9 @@ static TEST_EQUATIONS: &[&str] = &[ fn main() { - let expr: RecExpr = "(* x (+ x -2))".parse().unwrap(); - println!("{:?}", get_expression_cost(&expr)); + // let expr: RecExpr = "(+ (* x (+ x -2)) -15)".parse().unwrap(); +// let expr: RecExpr = "(* (+ (* x (+ x -2)) -15) (+ x 5))".parse().unwrap(); +// println!("{:?}", get_expression_cost(&expr)); for eq in TEST_EQUATIONS { println!("Equation: {}", *eq); @@ -40,16 +41,23 @@ fn main() { let (best_cost, best_expr) = extractor.find_best(runner.roots[0]); -// println!("{:?} {:?}", best_cost, as AsRef<[EquationLanguage]>>::as_ref(&best_expr)); - - println!("Best expresssion: {} {:?}", best_expr, best_cost); - - let factorization = extract_factorization(&best_expr); - +// println!("Best expresssion: {} {:?}", best_expr, best_cost); // println!("{}", runner.explain_equivalence(&start, &best_expr).get_flat_string()); + let Some(factorization) = extract_normal_form(&runner.egraph, runner.roots[0]) else { + panic!("Couldn't factorize polynomial!"); + }; println!("Factorized normal form: {}", factorization); + /* + get_expression_cost("(* x (+ x -2))", &runner.egraph); + get_expression_cost("-15", &runner.egraph); + get_expression_cost("(+ (* x (+ x -2)) -15)", &runner.egraph); + */ + +// let factorization = extract_factorization(&best_expr); + + let mut solutions: Vec = Vec::new(); for poly in &factorization.polynomials { if poly.len() == 2 { // linear factor @@ -85,10 +93,16 @@ fn main() { } } -fn get_expression_cost(expr: &RecExpr) -> FactorizationCost { - let mut egraph = EGraph::new(ConstantFold::default()); - let id = egraph.add_expr(expr); - let extractor = Extractor::new(&egraph, FactorizationCostFn); - let (cost, _) = extractor.find_best(id); - cost +fn get_expression_cost(expr: &str, egraph: &EGraph) { +// let mut egraph = EGraph::new(ConstantFold::default()); + // let id = egraph.add_expr(expr); + let pattern: Pattern = expr.parse().unwrap(); + let matches = pattern.search(egraph); + for m in matches { + let extractor = Extractor::new(&egraph, FactorizationCostFn); + let (cost, _) = extractor.find_best(m.eclass); + + println!("expr: {}, id: {}, cost: {:?}", expr, m.eclass, cost); + } +// cost } diff --git a/src/normal_form.rs b/src/normal_form.rs index 2d776b5..777f458 100644 --- a/src/normal_form.rs +++ b/src/normal_form.rs @@ -1,6 +1,6 @@ use crate::language::{EGraph, EquationLanguage, Rational, RATIONAL_ONE, RATIONAL_ZERO}; -use std::collections::HashMap; -use egg::Id; +use std::{collections::HashMap, fmt}; +use egg::{AstSize, Extractor, Id}; #[derive(Debug,Clone)] pub enum SpecialTerm { @@ -12,36 +12,281 @@ pub enum SpecialTerm { Other, } -fn search_for(egraph: &EGraph, f: F) -> HashMap +#[derive(Debug,Clone)] +pub struct Factorization { + pub constant_factor: Rational, + pub polynomials: Vec>, +} + +// this is a property of an eclass, not a particular expression +#[derive(Debug,Clone)] +struct PolyStats { + degree: usize, + monomial: bool, + monic: bool, +} + +fn gather_poly_stats(egraph: &EGraph) -> HashMap { + walk_egraph(egraph, |_id, node, stats: &HashMap| { + let x = |i: &Id| stats.get(&egraph.find(*i)); + Some(match node { + EquationLanguage::Unknown => PolyStats { + degree: 1, + monomial: true, + monic: true, + }, + EquationLanguage::Num(c) => PolyStats { + degree: 0, + monomial: true, + monic: c == &RATIONAL_ONE, + }, + EquationLanguage::Mul([a,b]) => { + // if both aren't monic we can't tell, the leading coefficients could cancel + // but there should be an alternative representative with one of them monic + if !x(a)?.monic && !x(b)?.monic { + return None; + } + + PolyStats { + degree: x(a)?.degree + x(b)?.degree, + monomial: x(a)?.monomial && x(b)?.monomial, + monic: x(a)?.monic && x(b)?.monic, + } + }, + EquationLanguage::Add([a,b]) => { + // in this case, there should also be a simplified representative which + // has only a single leading term + if x(a)?.degree == x(b)?.degree { + return None; + } + + PolyStats { + degree: usize::max(x(a)?.degree, x(b)?.degree), + monomial: false, + monic: if x(a)?.degree > x(b)?.degree { x(a)?.monic } else { x(b)?.monic }, + } + }, + _ => { return None; } + }) + }) +} + +pub fn extract_normal_form(egraph: &EGraph, eclass: Id) -> Option { + let eclass = egraph.find(eclass); // get the canonical eclass + + let stats = gather_poly_stats(egraph); + + let Some(factorization) = find_general_factorization(egraph, &stats, eclass) else { return None; }; + + let mut result = Vec::new(); + let mut coeff: Option = None; + + for factor in factorization { + let extracted = extract_polynomial(egraph, &stats, factor)?; +// println!("Extracted: {:?}", extracted); + + if extracted.len() == 1 { + coeff = Some(extracted[0].clone()); + } else { + result.push(extracted); + } + } + + Some(Factorization { + constant_factor: coeff.unwrap_or_else(||RATIONAL_ONE.clone()), + polynomials: result + }) +} + +// a polynomial should be either of: +// - a monomial: then we know the degree and we try to parse it as a product of a constant and a monic monomial +// or just a constant +// - a sum of a monomial of highest degree, and a polynomial of lower degree, recursively walk these +fn extract_polynomial(egraph: &EGraph, stats: &HashMap, id: Id) -> Option> { + let st = &stats[&id]; + + if st.monomial { + let (deg, coeff) = extract_monomial(egraph, stats, id)?; + + let mut result = vec![RATIONAL_ZERO; deg]; + result.push(coeff); + return Some(result); + } else { + for node in &egraph[id].nodes { + match node { + EquationLanguage::Add([a,b]) => { + let a = egraph.find(*a); + let b = egraph.find(*b); + let Some(stata) = &stats.get(&a) else { continue }; + let Some(statb) = &stats.get(&b) else { continue }; + + if stata.degree == st.degree && stata.monomial && statb.degree < st.degree { + let (leading_deg, leading_coeff) = extract_monomial(egraph, stats, a)?; + let mut remainder = extract_polynomial(egraph, stats, b)?; + + assert!(leading_deg >= remainder.len()); + + remainder.resize(leading_deg, RATIONAL_ZERO.clone()); + remainder.push(leading_coeff); + return Some(remainder); + } + }, + _ => {} + } + } + } + + None +} + +// a monomial is either a power of x, a constant, or a product of a constant and power of x +fn extract_monomial(egraph: &EGraph, stats: &HashMap, id: Id) -> Option<(usize, Rational)> { + let extractor = Extractor::new(egraph, AstSize); + let (_, expr) = extractor.find_best(id); +// println!("Extract Monomial: {}", expr); + + let st = &stats[&id]; + + assert!(st.monomial); + + // monic + monomial = power of x + if st.monic { + return Some((st.degree, RATIONAL_ONE.clone())); + } + + for node in &egraph[id].nodes { + match node { + EquationLanguage::Mul([a,b]) => { + let a = egraph.find(*a); + let b = egraph.find(*b); + let Some(statb) = stats.get(&b) else { continue }; + + // a should be a constant + let Some(coeff) = egraph[a].data.clone() else { continue }; + + // b should be monic and nonconstant (hence a power of x) + if statb.degree == 0 || !statb.monic { + continue; + } + + assert_eq!(st.degree, statb.degree); + + return Some((st.degree, coeff)); + }, + EquationLanguage::Num(c) => { // a constant is also a monomial + return Some((0, c.clone())); + }, + _ => {}, + } + } + + None +} + +// like find_monic_factorization, but with the option of having a constant factor +fn find_general_factorization(egraph: &EGraph, stats: &HashMap, id: Id) -> Option> { + let st = stats.get(&id)?; + + if st.monic { + return find_monic_factorization(egraph, stats, id); + } else { + for node in &egraph[id].nodes { + match node { + EquationLanguage::Mul([a,b]) => { + let a = egraph.find(*a); + let b = egraph.find(*b); + let Some(stata) = stats.get(&a) else { continue }; + let Some(statb) = stats.get(&b) else { continue }; + + // a is constant, b is monic and nonconstant + if stata.degree == 0 && statb.degree > 0 && statb.monic { + let mut fac = find_monic_factorization(egraph, stats, b)?; + fac.push(a); + return Some(fac); + } + }, + _ => {} + } + } + } + None +} + +// this assumes `id` to be canonical +fn find_monic_factorization(egraph: &EGraph, stats: &HashMap, id: Id) -> Option> { + // we want the polynomial to be nonconstant and monic + if ! stats.get(&id).is_some_and(|x|x.monic && x.degree > 0) { + return None; + } + + // now the whole thing is a monic nonconst poly, so would be a valid factorization, + // but we want to go as deep as possible +// println!("{:?}", stats[&id]); + + // check if it is the product of two nonconstant monic polynomials + for node in &egraph[id].nodes { + match node { + EquationLanguage::Mul([a,b]) => { + let a = egraph.find(*a); + let b = egraph.find(*b); + let Some(stata) = stats.get(&a) else { continue }; + let Some(statb) = stats.get(&b) else { continue }; + + if stata.degree == 0 || statb.degree == 0 || !stata.monic || !statb.monic { + continue; + } + +// println!("stats = {:?}, stats a = {:?}, stats b = {:?}", stats[&id], stata, statb); + + let Some(mut faca) = find_monic_factorization(egraph, stats, a) else { continue }; + let Some(facb) = find_monic_factorization(egraph, stats, b) else { continue }; + + faca.extend_from_slice(&facb); + return Some(faca); + }, + _ => {} + } + } + + // at this point we know the current polynomial is monic, but we didn't find a further factorization + // so just return it as a single factor + Some(vec![id]) +} + +fn walk_egraph(egraph: &EGraph, f: F) -> HashMap where F: Fn(Id, &EquationLanguage, &HashMap) -> Option { let mut result: HashMap = HashMap::new(); let mut modifications: usize = 1; +// println!("{:?}", egraph[canonical]); + while modifications > 0 { modifications = 0; - for cls in egraph.classes() { + 'next_class: for cls in egraph.classes() { let id = cls.id; if result.contains_key(&id) { - continue; + continue 'next_class; } for node in &cls.nodes { if let Some(x) = f(id, node, &result) { result.insert(id, x); modifications += 1; + continue 'next_class; } } } - println!("{} modifications!", modifications); +// println!("{} modifications!", modifications); } result } +/* pub fn analyze3(egraph: &EGraph, eclass: Id) { let constants = search_for(egraph, |id, _, _| egraph[id].data.as_ref().map(|c|c.clone()) @@ -227,3 +472,28 @@ pub fn analyze2(egraph: &EGraph) -> HashMap { types } +*/ + +impl fmt::Display for Factorization { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + if self.constant_factor != RATIONAL_ONE { + write!(f, "{}", self.constant_factor)?; + } + + for poly in &self.polynomials { + write!(f, "(")?; + for (deg, coeff) in poly.iter().enumerate() { + if deg == 0 { + write!(f, "{}", coeff)?; + } else if deg == 1 { + write!(f, " + {}x", coeff)?; + } else { + write!(f, " + {}x^{}", coeff, deg)?; + } + } + write!(f, ")")?; + } + + Ok(()) + } +}