use core::fmt; use std::{cmp::Ordering, fmt::Display}; use crate::language::{EquationLanguage, Rational, RATIONAL_ONE, RATIONAL_ZERO}; use egg::{Id, RecExpr}; #[derive(Debug,Clone,Copy)] pub struct PolyStat { degree: usize, factors: usize, // non-constant factors ops: usize, monomial: bool, sum_of_monomials: bool, monic: bool, factorized: bool, // a product of monic polynomials and at least one constant } #[derive(Debug,Clone,Copy)] pub enum FactorizationCost { UnwantedOps, Polynomial(PolyStat) } fn score(cost: FactorizationCost) -> usize { match cost { FactorizationCost::UnwantedOps => 10000, FactorizationCost::Polynomial(p) => if !p.factorized { 1000 + p.ops } else { 100 * (9 - p.factors) + p.ops }, } } impl PartialEq for FactorizationCost { fn eq(&self, other: &Self) -> bool { score(*self) == score(*other) } } impl PartialOrd for FactorizationCost { fn partial_cmp(&self, other: &Self) -> Option { usize::partial_cmp(&score(*self), &score(*other)) } } pub struct FactorizationCostFn; impl egg::CostFunction for FactorizationCostFn { type Cost = FactorizationCost; fn cost(&mut self, enode: &EquationLanguage, mut costs: C) -> Self::Cost where C: FnMut(Id) -> Self::Cost, { match enode { EquationLanguage::Add([a,b]) => { match (costs(*a), costs(*b)) { (FactorizationCost::Polynomial(p1),FactorizationCost::Polynomial(p2)) => { // we only ever want to add monomials let result_monic = if p1.degree > p2.degree { p1.monic } else if p2.degree > p1.degree { p2.monic } else { false }; /* if *a == Id::from(4) && *b == Id::from(19) { println!("HERE {:?} {:?}", p1, p2); } */ if !p1.sum_of_monomials || !p2.sum_of_monomials { FactorizationCost::UnwantedOps } else { FactorizationCost::Polynomial(PolyStat { degree: usize::max(p1.degree, p2.degree), factors: 1, ops: p1.ops + p2.ops + 1, monomial: false, sum_of_monomials: p1.sum_of_monomials && p2.sum_of_monomials, monic: result_monic, factorized: result_monic, }) } }, _ => FactorizationCost::UnwantedOps } }, EquationLanguage::Mul([a,b]) => { match (costs(*a), costs(*b)) { (FactorizationCost::Polynomial(p1), FactorizationCost::Polynomial(p2)) => { FactorizationCost::Polynomial(PolyStat { degree: p1.degree + p2.degree, factors: p1.factors + p2.factors, ops: p1.ops + p2.ops + 1, monomial: p1.monomial && p2.monomial, sum_of_monomials: p1.monomial && p2.monomial, monic: p1.monic && p2.monic, factorized: (p1.monic && p2.factorized) || (p2.monic && p1.factorized) }) }, _ => FactorizationCost::UnwantedOps } }, EquationLanguage::Num(c) => { FactorizationCost::Polynomial(PolyStat { degree: 0, factors: 0, ops: 0, monomial: true, sum_of_monomials: true, monic: false, factorized: true }) }, EquationLanguage::Unknown => { FactorizationCost::Polynomial(PolyStat { degree: 1, factors: 1, ops: 0, monomial: true, sum_of_monomials: true, monic: true, factorized: true }) }, _ => FactorizationCost::UnwantedOps, } } } #[derive(Debug,Clone)] pub struct Factorization { pub constant_factor: Rational, pub polynomials: Vec>, } impl Display for Factorization { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { if self.constant_factor != RATIONAL_ONE { write!(f, "{}", self.constant_factor)?; } for poly in &self.polynomials { write!(f, "(")?; for (deg, coeff) in poly.iter().enumerate() { if deg == 0 { write!(f, "{}", coeff)?; } else if deg == 1 { write!(f, " + {}x", coeff)?; } else { write!(f, " + {}x^{}", coeff, deg)?; } } write!(f, ")")?; } Ok(()) } } pub fn extract_factorization(expr: &RecExpr) -> Factorization { let root_id: Id = Id::from(expr.as_ref().len()-1); let mut constant_factor: Option = None; let mut factors: Vec> = Vec::new(); let mut todo: Vec = Vec::new(); todo.push(root_id); while todo.len() > 0 { let id = todo.pop().unwrap(); match &expr[id] { EquationLanguage::Mul([a,b]) => { todo.push(*a); todo.push(*b); }, EquationLanguage::Num(x) => { assert!(constant_factor.is_none()); constant_factor = Some(x.clone()); }, _ => { factors.push(extract_polynomial(expr, id)); } } } Factorization { constant_factor: constant_factor.unwrap_or_else(||RATIONAL_ONE.clone()), polynomials: factors } } fn extract_polynomial(expr: &RecExpr, id: Id) -> Vec { let mut result: Vec = Vec::new(); let mut todo: Vec = Vec::new(); todo.push(id); while todo.len() > 0 { let id = todo.pop().unwrap(); match &expr[id] { EquationLanguage::Add([a,b]) => { todo.push(*a); todo.push(*b); }, _ => { let (deg, coeff) = extract_monomial(expr, id); result.resize(result.len().max(deg), RATIONAL_ZERO.clone()); if result.len() <= deg { result.push(coeff); } else { assert!(result[deg] == RATIONAL_ZERO); result[deg] = coeff; } } } } result } fn extract_monomial(expr: &RecExpr, id: Id) -> (usize, Rational) { let mut coeff: Option = None; let mut deg: usize = 0; let mut todo: Vec = Vec::new(); todo.push(id); while todo.len() > 0 { let id = todo.pop().unwrap(); match &expr[id] { EquationLanguage::Unknown => { deg += 1; }, EquationLanguage::Mul([a,b]) => { todo.push(*a); todo.push(*b); }, EquationLanguage::Num(x) => { assert!(coeff.is_none()); coeff = Some(x.clone()); }, _ => { panic!("Not a rational polynomial in normal form!"); } } } (deg, coeff.unwrap_or_else(||RATIONAL_ONE.clone())) }