257 lines
5.9 KiB
Rust
257 lines
5.9 KiB
Rust
use core::fmt;
|
|
use std::{cmp::Ordering, fmt::Display};
|
|
use crate::language::{EquationLanguage, Rational, RATIONAL_ONE, RATIONAL_ZERO};
|
|
use egg::{Id, RecExpr};
|
|
|
|
|
|
/// Structural summary of a polynomial sub-expression, computed bottom-up by
/// `FactorizationCostFn` and collapsed into a single number by `score`.
#[derive(Debug, Clone, Copy)]
pub struct PolyStat {
    /// Degree in the unknown: sums take the larger operand degree, products add them.
    degree: usize,
    /// Number of non-constant factors: 0 for a constant, 1 for the unknown or for
    /// any sum, and additive under multiplication.
    factors: usize,
    /// Number of `Add`/`Mul` nodes in the expression tree.
    ops: usize,
    /// Whether the expression is a product of constants and occurrences of the unknown.
    monomial: bool,
    /// Whether the expression is a monomial or a sum of monomials.
    sum_of_monomials: bool,
    /// Whether the expression is tracked as monic. Constants are never marked
    /// monic (not even `1`).
    monic: bool,
    /// Whether the expression is in factorized form: a constant, or a product of
    /// monic polynomials optionally multiplied by a single constant.
    factorized: bool,
}
|
|
|
|
#[derive(Debug,Clone,Copy)]
|
|
pub enum FactorizationCost {
|
|
UnwantedOps,
|
|
Polynomial(PolyStat)
|
|
}
|
|
|
|
fn score(cost: FactorizationCost) -> usize {
|
|
match cost {
|
|
FactorizationCost::UnwantedOps => 10000,
|
|
FactorizationCost::Polynomial(p) =>
|
|
if !p.factorized {
|
|
1000 + p.ops
|
|
} else {
|
|
100 * (9 - p.factors) + p.ops
|
|
},
|
|
}
|
|
}
|
|
|
|
impl PartialEq for FactorizationCost {
|
|
fn eq(&self, other: &Self) -> bool {
|
|
score(*self) == score(*other)
|
|
}
|
|
}
|
|
|
|
impl PartialOrd for FactorizationCost {
|
|
fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
|
|
usize::partial_cmp(&score(*self), &score(*other))
|
|
}
|
|
}
|
|
|
|
/// `egg` cost function over `EquationLanguage` that rates each e-node with a
/// `FactorizationCost`, steering extraction toward factorized polynomials.
pub struct FactorizationCostFn;
|
|
impl egg::CostFunction<EquationLanguage> for FactorizationCostFn {
|
|
type Cost = FactorizationCost;
|
|
|
|
fn cost<C>(&mut self, enode: &EquationLanguage, mut costs: C) -> Self::Cost
|
|
where
|
|
C: FnMut(Id) -> Self::Cost,
|
|
{
|
|
match enode {
|
|
EquationLanguage::Add([a,b]) => {
|
|
match (costs(*a), costs(*b)) {
|
|
(FactorizationCost::Polynomial(p1),FactorizationCost::Polynomial(p2)) => {
|
|
// we only ever want to add monomials
|
|
let result_monic = if p1.degree > p2.degree {
|
|
p1.monic
|
|
} else if p2.degree > p1.degree {
|
|
p2.monic
|
|
} else {
|
|
false
|
|
};
|
|
|
|
/*
|
|
if *a == Id::from(4) && *b == Id::from(19) {
|
|
println!("HERE {:?} {:?}", p1, p2);
|
|
}
|
|
*/
|
|
|
|
|
|
if !p1.sum_of_monomials || !p2.sum_of_monomials {
|
|
FactorizationCost::UnwantedOps
|
|
} else {
|
|
FactorizationCost::Polynomial(PolyStat {
|
|
degree: usize::max(p1.degree, p2.degree),
|
|
factors: 1,
|
|
ops: p1.ops + p2.ops + 1,
|
|
monomial: false,
|
|
sum_of_monomials: p1.sum_of_monomials && p2.sum_of_monomials,
|
|
monic: result_monic,
|
|
factorized: result_monic,
|
|
})
|
|
}
|
|
},
|
|
_ => FactorizationCost::UnwantedOps
|
|
}
|
|
},
|
|
EquationLanguage::Mul([a,b]) => {
|
|
match (costs(*a), costs(*b)) {
|
|
(FactorizationCost::Polynomial(p1), FactorizationCost::Polynomial(p2)) => {
|
|
FactorizationCost::Polynomial(PolyStat {
|
|
degree: p1.degree + p2.degree,
|
|
factors: p1.factors + p2.factors,
|
|
ops: p1.ops + p2.ops + 1,
|
|
monomial: p1.monomial && p2.monomial,
|
|
sum_of_monomials: p1.monomial && p2.monomial,
|
|
monic: p1.monic && p2.monic,
|
|
factorized: (p1.monic && p2.factorized) || (p2.monic && p1.factorized)
|
|
})
|
|
},
|
|
_ => FactorizationCost::UnwantedOps
|
|
}
|
|
},
|
|
EquationLanguage::Num(c) => {
|
|
FactorizationCost::Polynomial(PolyStat {
|
|
degree: 0,
|
|
factors: 0,
|
|
ops: 0,
|
|
monomial: true,
|
|
sum_of_monomials: true,
|
|
monic: false,
|
|
factorized: true
|
|
})
|
|
},
|
|
EquationLanguage::Unknown => {
|
|
FactorizationCost::Polynomial(PolyStat {
|
|
degree: 1,
|
|
factors: 1,
|
|
ops: 0,
|
|
monomial: true,
|
|
sum_of_monomials: true,
|
|
monic: true,
|
|
factorized: true
|
|
})
|
|
},
|
|
_ => FactorizationCost::UnwantedOps,
|
|
}
|
|
}
|
|
}
|
|
|
|
#[derive(Debug,Clone)]
|
|
pub struct Factorization {
|
|
pub constant_factor: Rational,
|
|
pub polynomials: Vec<Vec<Rational>>,
|
|
}
|
|
|
|
impl Display for Factorization {
|
|
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
|
|
if self.constant_factor != RATIONAL_ONE {
|
|
write!(f, "{}", self.constant_factor)?;
|
|
}
|
|
|
|
for poly in &self.polynomials {
|
|
write!(f, "(")?;
|
|
for (deg, coeff) in poly.iter().enumerate() {
|
|
if deg == 0 {
|
|
write!(f, "{}", coeff)?;
|
|
} else if deg == 1 {
|
|
write!(f, " + {}x", coeff)?;
|
|
} else {
|
|
write!(f, " + {}x^{}", coeff, deg)?;
|
|
}
|
|
}
|
|
write!(f, ")")?;
|
|
}
|
|
|
|
Ok(())
|
|
}
|
|
}
|
|
|
|
pub fn extract_factorization(expr: &RecExpr<EquationLanguage>) -> Factorization {
|
|
let root_id: Id = Id::from(expr.as_ref().len()-1);
|
|
|
|
let mut constant_factor: Option<Rational> = None;
|
|
let mut factors: Vec<Vec<Rational>> = Vec::new();
|
|
let mut todo: Vec<Id> = Vec::new();
|
|
todo.push(root_id);
|
|
|
|
while todo.len() > 0 {
|
|
let id = todo.pop().unwrap();
|
|
|
|
match &expr[id] {
|
|
EquationLanguage::Mul([a,b]) => {
|
|
todo.push(*a);
|
|
todo.push(*b);
|
|
},
|
|
EquationLanguage::Num(x) => {
|
|
assert!(constant_factor.is_none());
|
|
constant_factor = Some(x.clone());
|
|
},
|
|
_ => {
|
|
factors.push(extract_polynomial(expr, id));
|
|
}
|
|
}
|
|
}
|
|
|
|
Factorization {
|
|
constant_factor: constant_factor.unwrap_or_else(||RATIONAL_ONE.clone()),
|
|
polynomials: factors
|
|
}
|
|
}
|
|
|
|
fn extract_polynomial(expr: &RecExpr<EquationLanguage>, id: Id) -> Vec<Rational> {
|
|
let mut result: Vec<Rational> = Vec::new();
|
|
let mut todo: Vec<Id> = Vec::new();
|
|
todo.push(id);
|
|
|
|
while todo.len() > 0 {
|
|
let id = todo.pop().unwrap();
|
|
|
|
match &expr[id] {
|
|
EquationLanguage::Add([a,b]) => {
|
|
todo.push(*a);
|
|
todo.push(*b);
|
|
},
|
|
_ => {
|
|
let (deg, coeff) = extract_monomial(expr, id);
|
|
result.resize(result.len().max(deg), RATIONAL_ZERO.clone());
|
|
|
|
if result.len() <= deg {
|
|
result.push(coeff);
|
|
} else {
|
|
assert!(result[deg] == RATIONAL_ZERO);
|
|
result[deg] = coeff;
|
|
}
|
|
}
|
|
}
|
|
}
|
|
|
|
result
|
|
}
|
|
|
|
fn extract_monomial(expr: &RecExpr<EquationLanguage>, id: Id) -> (usize, Rational) {
|
|
let mut coeff: Option<Rational> = None;
|
|
let mut deg: usize = 0;
|
|
let mut todo: Vec<Id> = Vec::new();
|
|
todo.push(id);
|
|
|
|
while todo.len() > 0 {
|
|
let id = todo.pop().unwrap();
|
|
|
|
match &expr[id] {
|
|
EquationLanguage::Unknown => {
|
|
deg += 1;
|
|
},
|
|
EquationLanguage::Mul([a,b]) => {
|
|
todo.push(*a);
|
|
todo.push(*b);
|
|
},
|
|
EquationLanguage::Num(x) => {
|
|
assert!(coeff.is_none());
|
|
coeff = Some(x.clone());
|
|
},
|
|
_ => {
|
|
panic!("Not a rational polynomial in normal form!");
|
|
}
|
|
}
|
|
}
|
|
|
|
(deg, coeff.unwrap_or_else(||RATIONAL_ONE.clone()))
|
|
}
|