use crate::language::{EGraph, EquationLanguage, Rational, RATIONAL_ONE, RATIONAL_ZERO}; use std::{collections::HashMap, fmt}; use egg::{AstSize, Extractor, Id}; #[derive(Debug,Clone)] pub enum SpecialTerm { Constant(Rational), PowerOfX(usize), Monomial(usize, Rational), MonicNonconstPoly(Vec), Factorization(Rational, Vec>), Other, } #[derive(Debug,Clone)] pub struct Factorization { pub constant_factor: Rational, pub polynomials: Vec>, } // this is a property of an eclass, not a particular expression #[derive(Debug,Clone)] struct PolyStats { degree: usize, monomial: bool, monic: bool, } fn gather_poly_stats(egraph: &EGraph) -> HashMap { walk_egraph(egraph, |_id, node, stats: &HashMap| { let x = |i: &Id| stats.get(&egraph.find(*i)); Some(match node { EquationLanguage::Unknown => PolyStats { degree: 1, monomial: true, monic: true, }, EquationLanguage::Num(c) => PolyStats { degree: 0, monomial: true, monic: c == &RATIONAL_ONE, }, EquationLanguage::Mul([a,b]) => { // if both aren't monic we can't tell, the leading coefficients could cancel // but there should be an alternative representative with one of them monic if !x(a)?.monic && !x(b)?.monic { return None; } PolyStats { degree: x(a)?.degree + x(b)?.degree, monomial: x(a)?.monomial && x(b)?.monomial, monic: x(a)?.monic && x(b)?.monic, } }, EquationLanguage::Add([a,b]) => { // in this case, there should also be a simplified representative which // has only a single leading term if x(a)?.degree == x(b)?.degree { return None; } PolyStats { degree: usize::max(x(a)?.degree, x(b)?.degree), monomial: false, monic: if x(a)?.degree > x(b)?.degree { x(a)?.monic } else { x(b)?.monic }, } }, _ => { return None; } }) }) } pub fn extract_normal_form(egraph: &EGraph, eclass: Id) -> Option { let eclass = egraph.find(eclass); // get the canonical eclass let stats = gather_poly_stats(egraph); let Some(factorization) = find_general_factorization(egraph, &stats, eclass) else { return None; }; let mut result = Vec::new(); let mut coeff: Option = None; for factor in factorization { let extracted = extract_polynomial(egraph, &stats, factor)?; // println!("Extracted: {:?}", extracted); if extracted.len() == 1 { coeff = Some(extracted[0].clone()); } else { result.push(extracted); } } Some(Factorization { constant_factor: coeff.unwrap_or_else(||RATIONAL_ONE.clone()), polynomials: result }) } // a polynomial should be either of: // - a monomial: then we know the degree and we try to parse it as a product of a constant and a monic monomial // or just a constant // - a sum of a monomial of highest degree, and a polynomial of lower degree, recursively walk these fn extract_polynomial(egraph: &EGraph, stats: &HashMap, id: Id) -> Option> { let st = &stats[&id]; if st.monomial { let (deg, coeff) = extract_monomial(egraph, stats, id)?; let mut result = vec![RATIONAL_ZERO; deg]; result.push(coeff); return Some(result); } else { for node in &egraph[id].nodes { match node { EquationLanguage::Add([a,b]) => { let a = egraph.find(*a); let b = egraph.find(*b); let Some(stata) = &stats.get(&a) else { continue }; let Some(statb) = &stats.get(&b) else { continue }; if stata.degree == st.degree && stata.monomial && statb.degree < st.degree { let (leading_deg, leading_coeff) = extract_monomial(egraph, stats, a)?; let mut remainder = extract_polynomial(egraph, stats, b)?; assert!(leading_deg >= remainder.len()); remainder.resize(leading_deg, RATIONAL_ZERO.clone()); remainder.push(leading_coeff); return Some(remainder); } }, _ => {} } } } None } // a monomial is either a power of x, a constant, or a product of a constant and power of x fn extract_monomial(egraph: &EGraph, stats: &HashMap, id: Id) -> Option<(usize, Rational)> { let extractor = Extractor::new(egraph, AstSize); let (_, expr) = extractor.find_best(id); // println!("Extract Monomial: {}", expr); let st = &stats[&id]; assert!(st.monomial); // monic + monomial = power of x if st.monic { return Some((st.degree, RATIONAL_ONE.clone())); } for node in &egraph[id].nodes { match node { EquationLanguage::Mul([a,b]) => { let a = egraph.find(*a); let b = egraph.find(*b); let Some(statb) = stats.get(&b) else { continue }; // a should be a constant let Some(coeff) = egraph[a].data.clone() else { continue }; // b should be monic and nonconstant (hence a power of x) if statb.degree == 0 || !statb.monic { continue; } assert_eq!(st.degree, statb.degree); return Some((st.degree, coeff)); }, EquationLanguage::Num(c) => { // a constant is also a monomial return Some((0, c.clone())); }, _ => {}, } } None } // like find_monic_factorization, but with the option of having a constant factor fn find_general_factorization(egraph: &EGraph, stats: &HashMap, id: Id) -> Option> { let st = stats.get(&id)?; if st.monic { return find_monic_factorization(egraph, stats, id); } else { for node in &egraph[id].nodes { match node { EquationLanguage::Mul([a,b]) => { let a = egraph.find(*a); let b = egraph.find(*b); let Some(stata) = stats.get(&a) else { continue }; let Some(statb) = stats.get(&b) else { continue }; // a is constant, b is monic and nonconstant if stata.degree == 0 && statb.degree > 0 && statb.monic { let mut fac = find_monic_factorization(egraph, stats, b)?; fac.push(a); return Some(fac); } }, _ => {} } } } None } // this assumes `id` to be canonical fn find_monic_factorization(egraph: &EGraph, stats: &HashMap, id: Id) -> Option> { // we want the polynomial to be nonconstant and monic if ! stats.get(&id).is_some_and(|x|x.monic && x.degree > 0) { return None; } // now the whole thing is a monic nonconst poly, so would be a valid factorization, // but we want to go as deep as possible // println!("{:?}", stats[&id]); // check if it is the product of two nonconstant monic polynomials for node in &egraph[id].nodes { match node { EquationLanguage::Mul([a,b]) => { let a = egraph.find(*a); let b = egraph.find(*b); let Some(stata) = stats.get(&a) else { continue }; let Some(statb) = stats.get(&b) else { continue }; if stata.degree == 0 || statb.degree == 0 || !stata.monic || !statb.monic { continue; } // println!("stats = {:?}, stats a = {:?}, stats b = {:?}", stats[&id], stata, statb); let Some(mut faca) = find_monic_factorization(egraph, stats, a) else { continue }; let Some(facb) = find_monic_factorization(egraph, stats, b) else { continue }; faca.extend_from_slice(&facb); return Some(faca); }, _ => {} } } // at this point we know the current polynomial is monic, but we didn't find a further factorization // so just return it as a single factor Some(vec![id]) } fn walk_egraph(egraph: &EGraph, f: F) -> HashMap where F: Fn(Id, &EquationLanguage, &HashMap) -> Option { let mut result: HashMap = HashMap::new(); let mut modifications: usize = 1; // println!("{:?}", egraph[canonical]); while modifications > 0 { modifications = 0; 'next_class: for cls in egraph.classes() { let id = cls.id; if result.contains_key(&id) { continue 'next_class; } for node in &cls.nodes { if let Some(x) = f(id, node, &result) { result.insert(id, x); modifications += 1; continue 'next_class; } } } // println!("{} modifications!", modifications); } result } /* pub fn analyze3(egraph: &EGraph, eclass: Id) { let constants = search_for(egraph, |id, _, _| egraph[id].data.as_ref().map(|c|c.clone()) ); println!("{:?}", constants); let powers_of_x = search_for(egraph, |_, node, matches| match *node { EquationLanguage::Unknown => Some(1), EquationLanguage::Mul([a,b]) => { if !matches.contains_key(&a) || !matches.contains_key(&b) { return None; } let (dega, degb) = (matches[&a], matches[&b]); Some(dega + degb) }, _ => None, }); println!("{:?}", powers_of_x); let monomials = search_for(egraph, |id, node, _| { if let Some(deg) = powers_of_x.get(&id) { return Some((*deg, RATIONAL_ONE.clone())); } if let Some(c) = constants.get(&id) { return Some((0, c.clone())); } match *node { EquationLanguage::Mul([a,b]) => { if !constants.contains_key(&a) || !powers_of_x.contains_key(&b) { return None; } let (coeff, deg) = (&constants[&a], powers_of_x[&b]); Some((deg, coeff.clone())) }, _ => None, } }); println!("{:?}", monomials); let monic_polynomials = search_for(egraph, |id, node, matches| { if let Some(deg) = powers_of_x.get(&id) { let mut poly: Vec = Vec::new(); poly.resize(*deg, RATIONAL_ZERO); poly.push(RATIONAL_ONE.clone()); Some(poly) } else { match *node { EquationLanguage::Add([a,b]) => { if !matches.contains_key(&a) || !monomials.contains_key(&b) { return None; } let (leading, (deg, coeff)) = (&matches[&a], &monomials[&b]); if leading.len() <= *deg || leading[*deg] != RATIONAL_ZERO { return None; } let mut poly = leading.clone(); poly[*deg] = coeff.clone(); Some(poly) }, _ => None, } } }); for p in &monic_polynomials { println!("{:?}", p); } let factorizations: HashMap>)> = search_for(egraph, |id, node, matches| { if let Some(c) = constants.get(&id) { return Some((c.clone(), vec![])); } if let Some(poly) = monic_polynomials.get(&id) { return Some((RATIONAL_ONE.clone(), vec![poly.clone()])); } match *node { EquationLanguage::Mul([a,b]) => { if !matches.contains_key(&a) || !monic_polynomials.contains_key(&b) { return None; } let ((factor, polys), newpoly) = (&matches[&a], &monic_polynomials[&b]); let mut combined: Vec> = polys.clone(); combined.push(newpoly.clone()); Some((factor.clone(), combined)) }, _ => None, } }); /* for p in &factorizations { println!("{:?}", p); } */ println!("{:?}", factorizations[&eclass]); } pub fn analyze2(egraph: &EGraph) -> HashMap { let mut types: HashMap = HashMap::new(); let mut modifications: usize = 1; while modifications > 0 { modifications = 0; for cls in egraph.classes() { let id = cls.id; if types.contains_key(&id) { continue; } if let Some(c) = &egraph[id].data { types.insert(id, SpecialTerm::Constant(c.clone())); modifications += 1; continue; } for node in &cls.nodes { match *node { EquationLanguage::Unknown => { types.insert(id, SpecialTerm::PowerOfX(1)); modifications += 1; }, EquationLanguage::Mul([a,b]) => { // as we don't know a and b yet, defer to future iteration if !types.contains_key(&a) || !types.contains_key(&b) { continue; } match (&types[&a], &types[&b]) { (SpecialTerm::PowerOfX(dega), SpecialTerm::PowerOfX(degb)) => { types.insert(id, SpecialTerm::PowerOfX(*dega + *degb)); modifications += 1; }, (SpecialTerm::Constant(coeff), SpecialTerm::PowerOfX(deg)) => { types.insert(id, SpecialTerm::Monomial(*deg, coeff.clone())); modifications += 1; }, _ => { }, } }, EquationLanguage::Add([a,b]) => { // as we don't know a and b yet, defer to future iteration if !types.contains_key(&a) || !types.contains_key(&b) { continue; } match (&types[&a], &types[&b]) { (SpecialTerm::MonicNonconstPoly(poly), SpecialTerm::Monomial(deg, coeff)) => { if poly.len() <= *deg || poly[*deg] != RATIONAL_ZERO { continue; } let mut poly = poly.clone(); poly[*deg] = coeff.clone(); types.insert(id, SpecialTerm::MonicNonconstPoly(poly)); modifications += 1; }, _ => { }, } }, _ => {} } } } println!("{} modifications!", modifications); } types } */ impl fmt::Display for Factorization { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { if self.constant_factor != RATIONAL_ONE { write!(f, "{}", self.constant_factor)?; } for poly in &self.polynomials { write!(f, "(")?; for (deg, coeff) in poly.iter().enumerate() { if deg == 0 { write!(f, "{}", coeff)?; } else if deg == 1 { write!(f, " + {}x", coeff)?; } else { write!(f, " + {}x^{}", coeff, deg)?; } } write!(f, ")")?; } Ok(()) } }