solveq/src/normal_form.rs
Florian Stecker ef2a76869f fixed it!
2024-08-28 11:25:55 -04:00

500 lines
13 KiB
Rust

use crate::language::{EGraph, EquationLanguage, Rational, RATIONAL_ONE, RATIONAL_ZERO};
use std::{collections::HashMap, fmt};
use egg::{AstSize, Extractor, Id};
/// Classification of an e-class by the kind of term it represents.
///
/// NOTE(review): in the visible code this is only referenced from the
/// commented-out `analyze2`/`analyze3` experiments — confirm it is still needed.
#[derive(Clone, Debug)]
pub enum SpecialTerm {
    /// A constant value.
    Constant(Rational),
    /// A power of the unknown; the payload is the exponent.
    PowerOfX(usize),
    /// A constant times a power of the unknown, stored as `(exponent, coefficient)`.
    Monomial(usize, Rational),
    /// A monic nonconstant polynomial as a coefficient vector indexed by degree.
    MonicNonconstPoly(Vec<Rational>),
    /// A constant factor together with a list of polynomial factors.
    Factorization(Rational, Vec<Vec<Rational>>),
    /// Anything not covered by the variants above.
    Other,
}
/// A polynomial written as a constant times a product of polynomial factors.
///
/// Every polynomial is a coefficient vector indexed by degree: entry `i` is
/// the coefficient of `x^i`.
#[derive(Clone, Debug)]
pub struct Factorization {
    /// Overall scalar multiplier.
    pub constant_factor: Rational,
    /// The polynomial factors, each as a degree-indexed coefficient vector.
    pub polynomials: Vec<Vec<Rational>>,
}
/// Polynomial shape information about an e-class.
///
/// This describes the e-class as a whole, not any particular expression in it.
#[derive(Clone, Debug)]
struct PolyStats {
    /// Degree in the unknown.
    degree: usize,
    /// Whether the class is a single term (a constant times a power of x).
    monomial: bool,
    /// Whether the leading coefficient is one.
    monic: bool,
}
/// Computes [`PolyStats`] for every e-class recognisable as a polynomial in the unknown.
///
/// Classes are filled in by fixed-point iteration (`walk_egraph`). A class gets
/// the stats of the first of its e-nodes for which a rule below applies:
/// - `Unknown`: degree 1, monomial, monic.
/// - `Num(c)`: degree 0, monomial, monic iff `c` is one.
/// - `Mul(a, b)`: requires stats for both children and at least one monic child.
/// - `Add(a, b)`: requires stats for both children with different degrees.
///
/// Classes for which no rule ever applies are absent from the returned map.
fn gather_poly_stats(egraph: &EGraph) -> HashMap<Id, PolyStats> {
    walk_egraph(egraph, |_id, node, known: &HashMap<Id, PolyStats>| {
        // children may be non-canonical ids, so canonicalise before lookup
        let lookup = |child: &Id| known.get(&egraph.find(*child));
        match node {
            EquationLanguage::Unknown => Some(PolyStats {
                degree: 1,
                monomial: true,
                monic: true,
            }),
            EquationLanguage::Num(c) => Some(PolyStats {
                degree: 0,
                monomial: true,
                monic: c == &RATIONAL_ONE,
            }),
            EquationLanguage::Mul([a, b]) => {
                let (left, right) = (lookup(a)?, lookup(b)?);
                // With neither factor monic the product's leading coefficient is
                // unknown here; another representative of this class should have
                // a monic factor instead.
                if !(left.monic || right.monic) {
                    return None;
                }
                Some(PolyStats {
                    degree: left.degree + right.degree,
                    monomial: left.monomial && right.monomial,
                    monic: left.monic && right.monic,
                })
            }
            EquationLanguage::Add([a, b]) => {
                let (left, right) = (lookup(a)?, lookup(b)?);
                // Equal degrees could cancel; a simplified representative with a
                // single leading term is expected to exist in the class.
                if left.degree == right.degree {
                    return None;
                }
                let leading = if left.degree > right.degree { left } else { right };
                Some(PolyStats {
                    degree: leading.degree,
                    monomial: false,
                    monic: leading.monic,
                })
            }
            _ => None,
        }
    })
}
/// Computes a normal form of `eclass` as a constant times polynomial factors.
///
/// Polynomial stats are gathered for the whole e-graph, a factorization into
/// monic factors (optionally times a constant) is searched for, and each factor
/// is decoded into a degree-indexed coefficient vector. A factor that decodes to
/// a single coefficient becomes the constant factor; if there is none, the
/// constant factor is one.
///
/// Returns `None` if no factorization is found or some factor cannot be decoded.
pub fn extract_normal_form(egraph: &EGraph, eclass: Id) -> Option<Factorization> {
    let root = egraph.find(eclass); // work with the canonical e-class
    let stats = gather_poly_stats(egraph);
    let factor_ids = find_general_factorization(egraph, &stats, root)?;

    let mut constant: Option<Rational> = None;
    let mut polynomials = Vec::new();
    for factor_id in factor_ids {
        let coeffs = extract_polynomial(egraph, &stats, factor_id)?;
        match coeffs.len() {
            // a lone coefficient is the constant factor
            1 => constant = coeffs.into_iter().next(),
            _ => polynomials.push(coeffs),
        }
    }

    Some(Factorization {
        constant_factor: constant.unwrap_or_else(|| RATIONAL_ONE.clone()),
        polynomials,
    })
}
/// Extracts the coefficient vector of the polynomial represented by e-class `id`.
///
/// The result is indexed by degree: entry `i` is the coefficient of `x^i`.
/// A polynomial is recognised as either
/// - a monomial: decoded by `extract_monomial` as a constant times a power of x
///   (or just a constant), or
/// - a sum of a monomial of the class's full degree and a polynomial of lower
///   degree; the lower-degree part is decoded recursively (the degree strictly
///   decreases, so the recursion terminates).
///
/// Every `Add` e-node of the class is tried in turn; a candidate that cannot be
/// decoded is skipped rather than aborting the search.
///
/// Returns `None` if `id` has no stats or no e-node yields a consistent decomposition.
fn extract_polynomial(egraph: &EGraph, stats: &HashMap<Id, PolyStats>, id: Id) -> Option<Vec<Rational>> {
    let st = stats.get(&id)?;
    if st.monomial {
        let (deg, coeff) = extract_monomial(egraph, stats, id)?;
        let mut result = vec![RATIONAL_ZERO; deg];
        result.push(coeff);
        return Some(result);
    }
    for node in &egraph[id].nodes {
        let EquationLanguage::Add([a, b]) = node else { continue };
        let a = egraph.find(*a);
        let b = egraph.find(*b);
        let (Some(stata), Some(statb)) = (stats.get(&a), stats.get(&b)) else { continue };
        // `a` must be the leading monomial, `b` the strictly lower-degree rest.
        if !(stata.degree == st.degree && stata.monomial && statb.degree < st.degree) {
            continue;
        }
        // Failing on this split must not end the search: another Add node of the
        // same e-class may still decompose cleanly.
        let Some((leading_deg, leading_coeff)) = extract_monomial(egraph, stats, a) else { continue };
        let Some(mut remainder) = extract_polynomial(egraph, stats, b) else { continue };
        // The remainder must fit strictly below the leading term; otherwise the
        // split is inconsistent (resize would silently drop coefficients).
        if remainder.len() > leading_deg {
            continue;
        }
        remainder.resize(leading_deg, RATIONAL_ZERO.clone());
        remainder.push(leading_coeff);
        return Some(remainder);
    }
    None
}
/// Decodes a monomial e-class as `(degree, coefficient)`.
///
/// A monomial is a power of x, a constant, or a constant times a power of x:
/// - a monic monomial class is a power of x (coefficient one);
/// - otherwise an e-node `Mul(a, b)` is accepted when `a`'s class carries
///   constant data and `b` is monic and nonconstant;
/// - or an e-node `Num(c)` is accepted as the constant `c` of degree 0.
///
/// Returns `None` if no e-node of the class has one of these shapes.
///
/// # Panics
/// Panics if `id` has no entry in `stats`, if its stats are not marked as a
/// monomial, or if an accepted `Mul` node's degree disagrees with the class stats.
fn extract_monomial(egraph: &EGraph, stats: &HashMap<Id, PolyStats>, id: Id) -> Option<(usize, Rational)> {
    let st = &stats[&id];
    assert!(st.monomial);
    // monic + monomial = power of x
    if st.monic {
        return Some((st.degree, RATIONAL_ONE.clone()));
    }
    for node in &egraph[id].nodes {
        match node {
            EquationLanguage::Mul([a, b]) => {
                let a = egraph.find(*a);
                let b = egraph.find(*b);
                let Some(statb) = stats.get(&b) else { continue };
                // a should be a constant, i.e. its e-class data is `Some(value)`
                let Some(coeff) = egraph[a].data.clone() else { continue };
                // b should be monic and nonconstant (hence a power of x)
                if statb.degree == 0 || !statb.monic {
                    continue;
                }
                assert_eq!(st.degree, statb.degree);
                return Some((st.degree, coeff));
            }
            // a constant is also a monomial
            EquationLanguage::Num(c) => return Some((0, c.clone())),
            _ => {}
        }
    }
    None
}
/// Like `find_monic_factorization`, but the polynomial may carry a constant factor.
///
/// A monic class is delegated to `find_monic_factorization` directly. Otherwise
/// an e-node `Mul(a, b)` is looked for where `a` has degree 0 and `b` is monic
/// and nonconstant. The factor ids of `b` are returned with `a`'s canonical id
/// appended last.
///
/// Returns `None` if `id` has no stats or no such decomposition exists.
fn find_general_factorization(egraph: &EGraph, stats: &HashMap<Id, PolyStats>, id: Id) -> Option<Vec<Id>> {
    let st = stats.get(&id)?;
    if st.monic {
        return find_monic_factorization(egraph, stats, id);
    }
    for node in &egraph[id].nodes {
        if let EquationLanguage::Mul([a, b]) = node {
            let (a, b) = (egraph.find(*a), egraph.find(*b));
            let (Some(sa), Some(sb)) = (stats.get(&a), stats.get(&b)) else { continue };
            let constant_times_monic = sa.degree == 0 && sb.degree > 0 && sb.monic;
            if !constant_times_monic {
                continue;
            }
            let mut factors = find_monic_factorization(egraph, stats, b)?;
            factors.push(a);
            return Some(factors);
        }
    }
    None
}
/// Splits a monic nonconstant polynomial into monic nonconstant factors, as deeply as possible.
///
/// Assumes `id` is canonical. Returns `None` unless `id`'s stats say it is monic
/// with positive degree. Otherwise the first e-node `Mul(a, b)` whose children
/// are both monic nonconstant and both factor recursively is used. If none
/// works, the class itself is the single factor.
fn find_monic_factorization(egraph: &EGraph, stats: &HashMap<Id, PolyStats>, id: Id) -> Option<Vec<Id>> {
    let is_monic_nonconst = |s: &PolyStats| s.monic && s.degree > 0;
    match stats.get(&id) {
        Some(s) if is_monic_nonconst(s) => {}
        _ => return None,
    }

    // The class already is a valid single factor; try to refine it into a
    // product of two monic nonconstant polynomials.
    let split = egraph[id].nodes.iter().find_map(|node| {
        let EquationLanguage::Mul([a, b]) = node else { return None };
        let (a, b) = (egraph.find(*a), egraph.find(*b));
        if !stats.get(&a).is_some_and(is_monic_nonconst) || !stats.get(&b).is_some_and(is_monic_nonconst) {
            return None;
        }
        let mut left = find_monic_factorization(egraph, stats, a)?;
        let right = find_monic_factorization(egraph, stats, b)?;
        left.extend(right);
        Some(left)
    });

    // no further split found: the polynomial is its own (monic) factor
    Some(split.unwrap_or_else(|| vec![id]))
}
/// Assigns values to e-classes by fixed-point iteration.
///
/// All classes are swept repeatedly until a sweep assigns nothing new. For each
/// class without a value yet, `f(class_id, node, &partial_map)` is called on
/// its e-nodes in order, and the first `Some` becomes the class's value. Once
/// assigned, a class is never revisited.
fn walk_egraph<F, T>(egraph: &EGraph, f: F) -> HashMap<Id, T>
where
    F: Fn(Id, &EquationLanguage, &HashMap<Id, T>) -> Option<T>,
{
    let mut found: HashMap<Id, T> = HashMap::new();
    loop {
        let before = found.len();
        for class in egraph.classes() {
            if found.contains_key(&class.id) {
                continue;
            }
            let value = class.nodes.iter().find_map(|node| f(class.id, node, &found));
            if let Some(v) = value {
                found.insert(class.id, v);
            }
        }
        // a sweep that added nothing means the fixed point has been reached
        if found.len() == before {
            break;
        }
    }
    found
}
/*
pub fn analyze3(egraph: &EGraph, eclass: Id) {
let constants = search_for(egraph, |id, _, _|
egraph[id].data.as_ref().map(|c|c.clone())
);
println!("{:?}", constants);
let powers_of_x = search_for(egraph, |_, node, matches| match *node {
EquationLanguage::Unknown => Some(1),
EquationLanguage::Mul([a,b]) => {
if !matches.contains_key(&a) || !matches.contains_key(&b) {
return None;
}
let (dega, degb) = (matches[&a], matches[&b]);
Some(dega + degb)
},
_ => None,
});
println!("{:?}", powers_of_x);
let monomials = search_for(egraph, |id, node, _| {
if let Some(deg) = powers_of_x.get(&id) {
return Some((*deg, RATIONAL_ONE.clone()));
}
if let Some(c) = constants.get(&id) {
return Some((0, c.clone()));
}
match *node {
EquationLanguage::Mul([a,b]) => {
if !constants.contains_key(&a) || !powers_of_x.contains_key(&b) {
return None;
}
let (coeff, deg) = (&constants[&a], powers_of_x[&b]);
Some((deg, coeff.clone()))
},
_ => None,
}
});
println!("{:?}", monomials);
let monic_polynomials = search_for(egraph, |id, node, matches| {
if let Some(deg) = powers_of_x.get(&id) {
let mut poly: Vec<Rational> = Vec::new();
poly.resize(*deg, RATIONAL_ZERO);
poly.push(RATIONAL_ONE.clone());
Some(poly)
} else {
match *node {
EquationLanguage::Add([a,b]) => {
if !matches.contains_key(&a) || !monomials.contains_key(&b) {
return None;
}
let (leading, (deg, coeff)) = (&matches[&a], &monomials[&b]);
if leading.len() <= *deg || leading[*deg] != RATIONAL_ZERO {
return None;
}
let mut poly = leading.clone();
poly[*deg] = coeff.clone();
Some(poly)
},
_ => None,
}
}
});
for p in &monic_polynomials {
println!("{:?}", p);
}
let factorizations: HashMap<Id, (Rational, Vec<Vec<Rational>>)> = search_for(egraph, |id, node, matches| {
if let Some(c) = constants.get(&id) {
return Some((c.clone(), vec![]));
}
if let Some(poly) = monic_polynomials.get(&id) {
return Some((RATIONAL_ONE.clone(), vec![poly.clone()]));
}
match *node {
EquationLanguage::Mul([a,b]) => {
if !matches.contains_key(&a) || !monic_polynomials.contains_key(&b) {
return None;
}
let ((factor, polys), newpoly) = (&matches[&a], &monic_polynomials[&b]);
let mut combined: Vec<Vec<Rational>> = polys.clone();
combined.push(newpoly.clone());
Some((factor.clone(), combined))
},
_ => None,
}
});
/*
for p in &factorizations {
println!("{:?}", p);
}
*/
println!("{:?}", factorizations[&eclass]);
}
pub fn analyze2(egraph: &EGraph) -> HashMap<Id, SpecialTerm> {
let mut types: HashMap<Id, SpecialTerm> = HashMap::new();
let mut modifications: usize = 1;
while modifications > 0 {
modifications = 0;
for cls in egraph.classes() {
let id = cls.id;
if types.contains_key(&id) {
continue;
}
if let Some(c) = &egraph[id].data {
types.insert(id, SpecialTerm::Constant(c.clone()));
modifications += 1;
continue;
}
for node in &cls.nodes {
match *node {
EquationLanguage::Unknown => {
types.insert(id, SpecialTerm::PowerOfX(1));
modifications += 1;
},
EquationLanguage::Mul([a,b]) => {
// as we don't know a and b yet, defer to future iteration
if !types.contains_key(&a) || !types.contains_key(&b) {
continue;
}
match (&types[&a], &types[&b]) {
(SpecialTerm::PowerOfX(dega), SpecialTerm::PowerOfX(degb)) => {
types.insert(id, SpecialTerm::PowerOfX(*dega + *degb));
modifications += 1;
},
(SpecialTerm::Constant(coeff), SpecialTerm::PowerOfX(deg)) => {
types.insert(id, SpecialTerm::Monomial(*deg, coeff.clone()));
modifications += 1;
},
_ => { },
}
},
EquationLanguage::Add([a,b]) => {
// as we don't know a and b yet, defer to future iteration
if !types.contains_key(&a) || !types.contains_key(&b) {
continue;
}
match (&types[&a], &types[&b]) {
(SpecialTerm::MonicNonconstPoly(poly), SpecialTerm::Monomial(deg, coeff)) => {
if poly.len() <= *deg || poly[*deg] != RATIONAL_ZERO {
continue;
}
let mut poly = poly.clone();
poly[*deg] = coeff.clone();
types.insert(id, SpecialTerm::MonicNonconstPoly(poly));
modifications += 1;
},
_ => { },
}
},
_ => {}
}
}
}
println!("{} modifications!", modifications);
}
types
}
*/
impl fmt::Display for Factorization {
    /// Writes the constant factor followed by each polynomial in parentheses,
    /// with terms in ascending degree: `c(a0 + a1x + a2x^2)(...)`.
    ///
    /// The constant factor is omitted when it equals one, except when there
    /// are no polynomial factors; then it is printed so the output is never empty.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        // Dropping a unit constant is only safe if something else follows;
        // otherwise the factorization of `1` would render as the empty string.
        if self.constant_factor != RATIONAL_ONE || self.polynomials.is_empty() {
            write!(f, "{}", self.constant_factor)?;
        }
        for poly in &self.polynomials {
            write!(f, "(")?;
            for (deg, coeff) in poly.iter().enumerate() {
                match deg {
                    0 => write!(f, "{}", coeff)?,
                    1 => write!(f, " + {}x", coeff)?,
                    _ => write!(f, " + {}x^{}", coeff, deg)?,
                }
            }
            write!(f, ")")?;
        }
        Ok(())
    }
}