fixed it!

This commit is contained in:
Florian Stecker 2024-08-28 11:25:55 -04:00
parent d0265ea340
commit ef2a76869f
4 changed files with 326 additions and 152 deletions

View File

@ -66,13 +66,20 @@ impl egg::CostFunction<EquationLanguage> for FactorizationCostFn {
false false
}; };
/*
if *a == Id::from(4) && *b == Id::from(19) {
println!("HERE {:?} {:?}", p1, p2);
}
*/
if !p1.sum_of_monomials || !p2.sum_of_monomials { if !p1.sum_of_monomials || !p2.sum_of_monomials {
FactorizationCost::UnwantedOps FactorizationCost::UnwantedOps
} else { } else {
FactorizationCost::Polynomial(PolyStat { FactorizationCost::Polynomial(PolyStat {
degree: usize::max(p1.degree, p2.degree), degree: usize::max(p1.degree, p2.degree),
factors: 1, factors: 1,
ops: p1.ops + p2.ops, ops: p1.ops + p2.ops + 1,
monomial: false, monomial: false,
sum_of_monomials: p1.sum_of_monomials && p2.sum_of_monomials, sum_of_monomials: p1.sum_of_monomials && p2.sum_of_monomials,
monic: result_monic, monic: result_monic,
@ -89,7 +96,7 @@ impl egg::CostFunction<EquationLanguage> for FactorizationCostFn {
FactorizationCost::Polynomial(PolyStat { FactorizationCost::Polynomial(PolyStat {
degree: p1.degree + p2.degree, degree: p1.degree + p2.degree,
factors: p1.factors + p2.factors, factors: p1.factors + p2.factors,
ops: p1.ops + p2.ops, ops: p1.ops + p2.ops + 1,
monomial: p1.monomial && p2.monomial, monomial: p1.monomial && p2.monomial,
sum_of_monomials: p1.monomial && p2.monomial, sum_of_monomials: p1.monomial && p2.monomial,
monic: p1.monic && p2.monic, monic: p1.monic && p2.monic,

View File

@ -167,6 +167,13 @@ fn is_nonzero_const(var: &str) -> impl Fn(&mut EGraph, Id, &Subst) -> bool {
} }
} }
fn is_const(var: &str) -> impl Fn(&mut EGraph, Id, &Subst) -> bool {
let var: Var = var.parse().unwrap();
move |egraph, _, subst| {
egraph[subst[var]].data.is_some()
}
}
pub static RULES: LazyLock<Vec<Rewrite>> = LazyLock::new(||vec![ pub static RULES: LazyLock<Vec<Rewrite>> = LazyLock::new(||vec![
rw!("commute-add"; "(+ ?x ?y)" => "(+ ?y ?x)"), rw!("commute-add"; "(+ ?x ?y)" => "(+ ?y ?x)"),
rw!("commute-mul"; "(* ?x ?y)" => "(* ?y ?x)"), rw!("commute-mul"; "(* ?x ?y)" => "(* ?y ?x)"),
@ -187,10 +194,6 @@ pub static RULES: LazyLock<Vec<Rewrite>> = LazyLock::new(||vec![
rw!("square"; "(^ ?x 2)" => "(* ?x ?x)"), rw!("square"; "(^ ?x 2)" => "(* ?x ?x)"),
rw!("cube"; "(^ ?x 3)" => "(* ?x (* ?x ?x))"), rw!("cube"; "(^ ?x 3)" => "(* ?x (* ?x ?x))"),
/*
rw!("inv-square"; "(* ?x ?x)" => "(^ ?x 2)"),
rw!("inv-cube"; "(* ?x (* ?x ?x))" => "(^ ?x 3)"),
*/
rw!("sub"; "(- ?x ?y)" => "(+ ?x (* -1 ?y))"), rw!("sub"; "(- ?x ?y)" => "(+ ?x (* -1 ?y))"),
rw!("neg"; "(- ?x)" => "(* -1 ?x)"), rw!("neg"; "(- ?x)" => "(* -1 ?x)"),
@ -198,6 +201,8 @@ pub static RULES: LazyLock<Vec<Rewrite>> = LazyLock::new(||vec![
rw!("div"; "(/ ?x ?y)" => "(* ?x (rec ?y))" if is_nonzero_const("?y")), rw!("div"; "(/ ?x ?y)" => "(* ?x (rec ?y))" if is_nonzero_const("?y")),
rw!("factor_poly"; "(+ (* x ?x) ?y)" => "(* ?x (+ x (* ?y (rec ?x))))" if is_nonzero_const("?x")), rw!("factor_poly"; "(+ (* x ?x) ?y)" => "(* ?x (+ x (* ?y (rec ?x))))" if is_nonzero_const("?x")),
// rw!("integer_sqrt"; "(^ ?x (1/2))" => {} if is_const("?x")),
]); ]);
pub struct PlusTimesCostFn; pub struct PlusTimesCostFn;
@ -218,125 +223,3 @@ impl egg::CostFunction<EquationLanguage> for PlusTimesCostFn {
enode.fold(op_cost, |sum, i| sum + costs(i)) enode.fold(op_cost, |sum, i| sum + costs(i))
} }
} }
#[derive(Debug,Clone,Copy)]
pub struct PolyStat {
degree: usize,
factors: usize, // non-constant factors
ops: usize,
monomial: bool,
sum_of_monomials: bool,
monic: bool,
factorized: bool, // a product of monic polynomials and at least one constant
}
#[derive(Debug,Clone,Copy)]
pub enum FactorizationCost {
UnwantedOps,
Polynomial(PolyStat)
}
fn score(cost: FactorizationCost) -> usize {
match cost {
FactorizationCost::UnwantedOps => 10000,
FactorizationCost::Polynomial(p) =>
if !p.factorized {
1000
} else {
100 * (9 - p.factors) + p.ops
},
}
}
impl PartialEq for FactorizationCost {
fn eq(&self, other: &Self) -> bool {
score(*self) == score(*other)
}
}
impl PartialOrd for FactorizationCost {
fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
usize::partial_cmp(&score(*self), &score(*other))
}
}
pub struct FactorizationCostFn;
impl egg::CostFunction<EquationLanguage> for FactorizationCostFn {
type Cost = FactorizationCost;
fn cost<C>(&mut self, enode: &EquationLanguage, mut costs: C) -> Self::Cost
where
C: FnMut(Id) -> Self::Cost,
{
match enode {
EquationLanguage::Add([a,b]) => {
match (costs(*a), costs(*b)) {
(FactorizationCost::Polynomial(p1),FactorizationCost::Polynomial(p2)) => {
// we only ever want to add monomials
let result_monic = if p1.degree > p2.degree {
p1.monic
} else if p2.degree > p1.degree {
p2.monic
} else {
false
};
if !p1.sum_of_monomials || !p2.sum_of_monomials {
FactorizationCost::UnwantedOps
} else {
FactorizationCost::Polynomial(PolyStat {
degree: usize::max(p1.degree, p2.degree),
factors: 1,
ops: p1.ops + p2.ops,
monomial: false,
sum_of_monomials: p1.sum_of_monomials && p2.sum_of_monomials,
monic: result_monic,
factorized: result_monic,
})
}
},
_ => FactorizationCost::UnwantedOps
}
},
EquationLanguage::Mul([a,b]) => {
match (costs(*a), costs(*b)) {
(FactorizationCost::Polynomial(p1), FactorizationCost::Polynomial(p2)) => {
FactorizationCost::Polynomial(PolyStat {
degree: p1.degree + p2.degree,
factors: p1.factors + p2.factors,
ops: p1.ops + p2.ops,
monomial: p1.monomial && p2.monomial,
sum_of_monomials: p1.monomial && p2.monomial,
monic: p1.monic && p2.monic,
factorized: (p1.monic && p2.factorized) || (p2.monic && p1.factorized)
})
},
_ => FactorizationCost::UnwantedOps
}
},
EquationLanguage::Num(c) => {
FactorizationCost::Polynomial(PolyStat {
degree: 0,
factors: 0,
ops: 0,
monomial: true,
sum_of_monomials: true,
monic: false,
factorized: true
})
},
EquationLanguage::Unknown => {
FactorizationCost::Polynomial(PolyStat {
degree: 1,
factors: 1,
ops: 0,
monomial: true,
sum_of_monomials: true,
monic: true,
factorized: true
})
},
_ => FactorizationCost::UnwantedOps,
}
}
}

View File

@ -1,7 +1,7 @@
use egg::{AstSize, EGraph, Extractor, Id, Pattern, RecExpr, Runner}; use egg::{AstSize, Extractor, Id, Pattern, PatternAst, RecExpr, Runner, Searcher};
use solveq::factorization::{extract_factorization, FactorizationCost, FactorizationCostFn}; use solveq::factorization::{FactorizationCost, FactorizationCostFn};
use solveq::language::{ConstantFold, EquationLanguage, FactorizationCostFn, PlusTimesCostFn, Rational, RULES}; use solveq::language::{ConstantFold, EquationLanguage, Rational, RULES, EGraph};
use solveq::normal_form::analyze3; use solveq::normal_form::extract_normal_form;
use solveq::parse::parse_equation; use solveq::parse::parse_equation;
use solveq::output::print_term; use solveq::output::print_term;
@ -17,8 +17,9 @@ static TEST_EQUATIONS: &[&str] = &[
fn main() { fn main() {
let expr: RecExpr<EquationLanguage> = "(* x (+ x -2))".parse().unwrap(); // let expr: RecExpr<EquationLanguage> = "(+ (* x (+ x -2)) -15)".parse().unwrap();
println!("{:?}", get_expression_cost(&expr)); // let expr: RecExpr<EquationLanguage> = "(* (+ (* x (+ x -2)) -15) (+ x 5))".parse().unwrap();
// println!("{:?}", get_expression_cost(&expr));
for eq in TEST_EQUATIONS { for eq in TEST_EQUATIONS {
println!("Equation: {}", *eq); println!("Equation: {}", *eq);
@ -40,16 +41,23 @@ fn main() {
let (best_cost, best_expr) = extractor.find_best(runner.roots[0]); let (best_cost, best_expr) = extractor.find_best(runner.roots[0]);
// println!("{:?} {:?}", best_cost, <RecExpr<EquationLanguage> as AsRef<[EquationLanguage]>>::as_ref(&best_expr)); // println!("Best expresssion: {} {:?}", best_expr, best_cost);
println!("Best expresssion: {} {:?}", best_expr, best_cost);
let factorization = extract_factorization(&best_expr);
// println!("{}", runner.explain_equivalence(&start, &best_expr).get_flat_string()); // println!("{}", runner.explain_equivalence(&start, &best_expr).get_flat_string());
let Some(factorization) = extract_normal_form(&runner.egraph, runner.roots[0]) else {
panic!("Couldn't factorize polynomial!");
};
println!("Factorized normal form: {}", factorization); println!("Factorized normal form: {}", factorization);
/*
get_expression_cost("(* x (+ x -2))", &runner.egraph);
get_expression_cost("-15", &runner.egraph);
get_expression_cost("(+ (* x (+ x -2)) -15)", &runner.egraph);
*/
// let factorization = extract_factorization(&best_expr);
let mut solutions: Vec<String> = Vec::new(); let mut solutions: Vec<String> = Vec::new();
for poly in &factorization.polynomials { for poly in &factorization.polynomials {
if poly.len() == 2 { // linear factor if poly.len() == 2 { // linear factor
@ -85,10 +93,16 @@ fn main() {
} }
} }
fn get_expression_cost(expr: &RecExpr<EquationLanguage>) -> FactorizationCost { fn get_expression_cost(expr: &str, egraph: &EGraph) {
let mut egraph = EGraph::new(ConstantFold::default()); // let mut egraph = EGraph::new(ConstantFold::default());
let id = egraph.add_expr(expr); // let id = egraph.add_expr(expr);
let extractor = Extractor::new(&egraph, FactorizationCostFn); let pattern: Pattern<EquationLanguage> = expr.parse().unwrap();
let (cost, _) = extractor.find_best(id); let matches = pattern.search(egraph);
cost for m in matches {
let extractor = Extractor::new(&egraph, FactorizationCostFn);
let (cost, _) = extractor.find_best(m.eclass);
println!("expr: {}, id: {}, cost: {:?}", expr, m.eclass, cost);
}
// cost
} }

View File

@ -1,6 +1,6 @@
use crate::language::{EGraph, EquationLanguage, Rational, RATIONAL_ONE, RATIONAL_ZERO}; use crate::language::{EGraph, EquationLanguage, Rational, RATIONAL_ONE, RATIONAL_ZERO};
use std::collections::HashMap; use std::{collections::HashMap, fmt};
use egg::Id; use egg::{AstSize, Extractor, Id};
#[derive(Debug,Clone)] #[derive(Debug,Clone)]
pub enum SpecialTerm { pub enum SpecialTerm {
@ -12,36 +12,281 @@ pub enum SpecialTerm {
Other, Other,
} }
fn search_for<F, T>(egraph: &EGraph, f: F) -> HashMap<Id, T> #[derive(Debug,Clone)]
pub struct Factorization {
pub constant_factor: Rational,
pub polynomials: Vec<Vec<Rational>>,
}
// this is a property of an eclass, not a particular expression
#[derive(Debug,Clone)]
struct PolyStats {
degree: usize,
monomial: bool,
monic: bool,
}
fn gather_poly_stats(egraph: &EGraph) -> HashMap<Id, PolyStats> {
walk_egraph(egraph, |_id, node, stats: &HashMap<Id, PolyStats>| {
let x = |i: &Id| stats.get(&egraph.find(*i));
Some(match node {
EquationLanguage::Unknown => PolyStats {
degree: 1,
monomial: true,
monic: true,
},
EquationLanguage::Num(c) => PolyStats {
degree: 0,
monomial: true,
monic: c == &RATIONAL_ONE,
},
EquationLanguage::Mul([a,b]) => {
// if both aren't monic we can't tell, the leading coefficients could cancel
// but there should be an alternative representative with one of them monic
if !x(a)?.monic && !x(b)?.monic {
return None;
}
PolyStats {
degree: x(a)?.degree + x(b)?.degree,
monomial: x(a)?.monomial && x(b)?.monomial,
monic: x(a)?.monic && x(b)?.monic,
}
},
EquationLanguage::Add([a,b]) => {
// in this case, there should also be a simplified representative which
// has only a single leading term
if x(a)?.degree == x(b)?.degree {
return None;
}
PolyStats {
degree: usize::max(x(a)?.degree, x(b)?.degree),
monomial: false,
monic: if x(a)?.degree > x(b)?.degree { x(a)?.monic } else { x(b)?.monic },
}
},
_ => { return None; }
})
})
}
pub fn extract_normal_form(egraph: &EGraph, eclass: Id) -> Option<Factorization> {
let eclass = egraph.find(eclass); // get the canonical eclass
let stats = gather_poly_stats(egraph);
let Some(factorization) = find_general_factorization(egraph, &stats, eclass) else { return None; };
let mut result = Vec::new();
let mut coeff: Option<Rational> = None;
for factor in factorization {
let extracted = extract_polynomial(egraph, &stats, factor)?;
// println!("Extracted: {:?}", extracted);
if extracted.len() == 1 {
coeff = Some(extracted[0].clone());
} else {
result.push(extracted);
}
}
Some(Factorization {
constant_factor: coeff.unwrap_or_else(||RATIONAL_ONE.clone()),
polynomials: result
})
}
// a polynomial should be either of:
// - a monomial: then we know the degree and we try to parse it as a product of a constant and a monic monomial
// or just a constant
// - a sum of a monomial of highest degree, and a polynomial of lower degree, recursively walk these
fn extract_polynomial(egraph: &EGraph, stats: &HashMap<Id, PolyStats>, id: Id) -> Option<Vec<Rational>> {
let st = &stats[&id];
if st.monomial {
let (deg, coeff) = extract_monomial(egraph, stats, id)?;
let mut result = vec![RATIONAL_ZERO; deg];
result.push(coeff);
return Some(result);
} else {
for node in &egraph[id].nodes {
match node {
EquationLanguage::Add([a,b]) => {
let a = egraph.find(*a);
let b = egraph.find(*b);
let Some(stata) = &stats.get(&a) else { continue };
let Some(statb) = &stats.get(&b) else { continue };
if stata.degree == st.degree && stata.monomial && statb.degree < st.degree {
let (leading_deg, leading_coeff) = extract_monomial(egraph, stats, a)?;
let mut remainder = extract_polynomial(egraph, stats, b)?;
assert!(leading_deg >= remainder.len());
remainder.resize(leading_deg, RATIONAL_ZERO.clone());
remainder.push(leading_coeff);
return Some(remainder);
}
},
_ => {}
}
}
}
None
}
// a monomial is either a power of x, a constant, or a product of a constant and power of x
fn extract_monomial(egraph: &EGraph, stats: &HashMap<Id, PolyStats>, id: Id) -> Option<(usize, Rational)> {
let extractor = Extractor::new(egraph, AstSize);
let (_, expr) = extractor.find_best(id);
// println!("Extract Monomial: {}", expr);
let st = &stats[&id];
assert!(st.monomial);
// monic + monomial = power of x
if st.monic {
return Some((st.degree, RATIONAL_ONE.clone()));
}
for node in &egraph[id].nodes {
match node {
EquationLanguage::Mul([a,b]) => {
let a = egraph.find(*a);
let b = egraph.find(*b);
let Some(statb) = stats.get(&b) else { continue };
// a should be a constant
let Some(coeff) = egraph[a].data.clone() else { continue };
// b should be monic and nonconstant (hence a power of x)
if statb.degree == 0 || !statb.monic {
continue;
}
assert_eq!(st.degree, statb.degree);
return Some((st.degree, coeff));
},
EquationLanguage::Num(c) => { // a constant is also a monomial
return Some((0, c.clone()));
},
_ => {},
}
}
None
}
// like find_monic_factorization, but with the option of having a constant factor
fn find_general_factorization(egraph: &EGraph, stats: &HashMap<Id, PolyStats>, id: Id) -> Option<Vec<Id>> {
let st = stats.get(&id)?;
if st.monic {
return find_monic_factorization(egraph, stats, id);
} else {
for node in &egraph[id].nodes {
match node {
EquationLanguage::Mul([a,b]) => {
let a = egraph.find(*a);
let b = egraph.find(*b);
let Some(stata) = stats.get(&a) else { continue };
let Some(statb) = stats.get(&b) else { continue };
// a is constant, b is monic and nonconstant
if stata.degree == 0 && statb.degree > 0 && statb.monic {
let mut fac = find_monic_factorization(egraph, stats, b)?;
fac.push(a);
return Some(fac);
}
},
_ => {}
}
}
}
None
}
// this assumes `id` to be canonical
fn find_monic_factorization(egraph: &EGraph, stats: &HashMap<Id, PolyStats>, id: Id) -> Option<Vec<Id>> {
// we want the polynomial to be nonconstant and monic
if ! stats.get(&id).is_some_and(|x|x.monic && x.degree > 0) {
return None;
}
// now the whole thing is a monic nonconst poly, so would be a valid factorization,
// but we want to go as deep as possible
// println!("{:?}", stats[&id]);
// check if it is the product of two nonconstant monic polynomials
for node in &egraph[id].nodes {
match node {
EquationLanguage::Mul([a,b]) => {
let a = egraph.find(*a);
let b = egraph.find(*b);
let Some(stata) = stats.get(&a) else { continue };
let Some(statb) = stats.get(&b) else { continue };
if stata.degree == 0 || statb.degree == 0 || !stata.monic || !statb.monic {
continue;
}
// println!("stats = {:?}, stats a = {:?}, stats b = {:?}", stats[&id], stata, statb);
let Some(mut faca) = find_monic_factorization(egraph, stats, a) else { continue };
let Some(facb) = find_monic_factorization(egraph, stats, b) else { continue };
faca.extend_from_slice(&facb);
return Some(faca);
},
_ => {}
}
}
// at this point we know the current polynomial is monic, but we didn't find a further factorization
// so just return it as a single factor
Some(vec![id])
}
fn walk_egraph<F, T>(egraph: &EGraph, f: F) -> HashMap<Id, T>
where where
F: Fn(Id, &EquationLanguage, &HashMap<Id, T>) -> Option<T> { F: Fn(Id, &EquationLanguage, &HashMap<Id, T>) -> Option<T> {
let mut result: HashMap<Id, T> = HashMap::new(); let mut result: HashMap<Id, T> = HashMap::new();
let mut modifications: usize = 1; let mut modifications: usize = 1;
// println!("{:?}", egraph[canonical]);
while modifications > 0 { while modifications > 0 {
modifications = 0; modifications = 0;
for cls in egraph.classes() { 'next_class: for cls in egraph.classes() {
let id = cls.id; let id = cls.id;
if result.contains_key(&id) { if result.contains_key(&id) {
continue; continue 'next_class;
} }
for node in &cls.nodes { for node in &cls.nodes {
if let Some(x) = f(id, node, &result) { if let Some(x) = f(id, node, &result) {
result.insert(id, x); result.insert(id, x);
modifications += 1; modifications += 1;
continue 'next_class;
} }
} }
} }
println!("{} modifications!", modifications); // println!("{} modifications!", modifications);
} }
result result
} }
/*
pub fn analyze3(egraph: &EGraph, eclass: Id) { pub fn analyze3(egraph: &EGraph, eclass: Id) {
let constants = search_for(egraph, |id, _, _| let constants = search_for(egraph, |id, _, _|
egraph[id].data.as_ref().map(|c|c.clone()) egraph[id].data.as_ref().map(|c|c.clone())
@ -227,3 +472,28 @@ pub fn analyze2(egraph: &EGraph) -> HashMap<Id, SpecialTerm> {
types types
} }
*/
impl fmt::Display for Factorization {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
if self.constant_factor != RATIONAL_ONE {
write!(f, "{}", self.constant_factor)?;
}
for poly in &self.polynomials {
write!(f, "(")?;
for (deg, coeff) in poly.iter().enumerate() {
if deg == 0 {
write!(f, "{}", coeff)?;
} else if deg == 1 {
write!(f, " + {}x", coeff)?;
} else {
write!(f, " + {}x^{}", coeff, deg)?;
}
}
write!(f, ")")?;
}
Ok(())
}
}