Compare commits
2 Commits
68b6293028
...
ef2a76869f
Author | SHA1 | Date | |
---|---|---|---|
|
ef2a76869f | ||
|
d0265ea340 |
256
src/factorization.rs
Normal file
256
src/factorization.rs
Normal file
@ -0,0 +1,256 @@
|
|||||||
|
use core::fmt;
|
||||||
|
use std::{cmp::Ordering, fmt::Display};
|
||||||
|
use crate::language::{EquationLanguage, Rational, RATIONAL_ONE, RATIONAL_ZERO};
|
||||||
|
use egg::{Id, RecExpr};
|
||||||
|
|
||||||
|
|
||||||
|
#[derive(Debug,Clone,Copy)]
|
||||||
|
pub struct PolyStat {
|
||||||
|
degree: usize,
|
||||||
|
factors: usize, // non-constant factors
|
||||||
|
ops: usize,
|
||||||
|
monomial: bool,
|
||||||
|
sum_of_monomials: bool,
|
||||||
|
monic: bool,
|
||||||
|
factorized: bool, // a product of monic polynomials and at least one constant
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Debug,Clone,Copy)]
|
||||||
|
pub enum FactorizationCost {
|
||||||
|
UnwantedOps,
|
||||||
|
Polynomial(PolyStat)
|
||||||
|
}
|
||||||
|
|
||||||
|
fn score(cost: FactorizationCost) -> usize {
|
||||||
|
match cost {
|
||||||
|
FactorizationCost::UnwantedOps => 10000,
|
||||||
|
FactorizationCost::Polynomial(p) =>
|
||||||
|
if !p.factorized {
|
||||||
|
1000 + p.ops
|
||||||
|
} else {
|
||||||
|
100 * (9 - p.factors) + p.ops
|
||||||
|
},
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl PartialEq for FactorizationCost {
|
||||||
|
fn eq(&self, other: &Self) -> bool {
|
||||||
|
score(*self) == score(*other)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl PartialOrd for FactorizationCost {
|
||||||
|
fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
|
||||||
|
usize::partial_cmp(&score(*self), &score(*other))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
pub struct FactorizationCostFn;
|
||||||
|
impl egg::CostFunction<EquationLanguage> for FactorizationCostFn {
|
||||||
|
type Cost = FactorizationCost;
|
||||||
|
|
||||||
|
fn cost<C>(&mut self, enode: &EquationLanguage, mut costs: C) -> Self::Cost
|
||||||
|
where
|
||||||
|
C: FnMut(Id) -> Self::Cost,
|
||||||
|
{
|
||||||
|
match enode {
|
||||||
|
EquationLanguage::Add([a,b]) => {
|
||||||
|
match (costs(*a), costs(*b)) {
|
||||||
|
(FactorizationCost::Polynomial(p1),FactorizationCost::Polynomial(p2)) => {
|
||||||
|
// we only ever want to add monomials
|
||||||
|
let result_monic = if p1.degree > p2.degree {
|
||||||
|
p1.monic
|
||||||
|
} else if p2.degree > p1.degree {
|
||||||
|
p2.monic
|
||||||
|
} else {
|
||||||
|
false
|
||||||
|
};
|
||||||
|
|
||||||
|
/*
|
||||||
|
if *a == Id::from(4) && *b == Id::from(19) {
|
||||||
|
println!("HERE {:?} {:?}", p1, p2);
|
||||||
|
}
|
||||||
|
*/
|
||||||
|
|
||||||
|
|
||||||
|
if !p1.sum_of_monomials || !p2.sum_of_monomials {
|
||||||
|
FactorizationCost::UnwantedOps
|
||||||
|
} else {
|
||||||
|
FactorizationCost::Polynomial(PolyStat {
|
||||||
|
degree: usize::max(p1.degree, p2.degree),
|
||||||
|
factors: 1,
|
||||||
|
ops: p1.ops + p2.ops + 1,
|
||||||
|
monomial: false,
|
||||||
|
sum_of_monomials: p1.sum_of_monomials && p2.sum_of_monomials,
|
||||||
|
monic: result_monic,
|
||||||
|
factorized: result_monic,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
},
|
||||||
|
_ => FactorizationCost::UnwantedOps
|
||||||
|
}
|
||||||
|
},
|
||||||
|
EquationLanguage::Mul([a,b]) => {
|
||||||
|
match (costs(*a), costs(*b)) {
|
||||||
|
(FactorizationCost::Polynomial(p1), FactorizationCost::Polynomial(p2)) => {
|
||||||
|
FactorizationCost::Polynomial(PolyStat {
|
||||||
|
degree: p1.degree + p2.degree,
|
||||||
|
factors: p1.factors + p2.factors,
|
||||||
|
ops: p1.ops + p2.ops + 1,
|
||||||
|
monomial: p1.monomial && p2.monomial,
|
||||||
|
sum_of_monomials: p1.monomial && p2.monomial,
|
||||||
|
monic: p1.monic && p2.monic,
|
||||||
|
factorized: (p1.monic && p2.factorized) || (p2.monic && p1.factorized)
|
||||||
|
})
|
||||||
|
},
|
||||||
|
_ => FactorizationCost::UnwantedOps
|
||||||
|
}
|
||||||
|
},
|
||||||
|
EquationLanguage::Num(c) => {
|
||||||
|
FactorizationCost::Polynomial(PolyStat {
|
||||||
|
degree: 0,
|
||||||
|
factors: 0,
|
||||||
|
ops: 0,
|
||||||
|
monomial: true,
|
||||||
|
sum_of_monomials: true,
|
||||||
|
monic: false,
|
||||||
|
factorized: true
|
||||||
|
})
|
||||||
|
},
|
||||||
|
EquationLanguage::Unknown => {
|
||||||
|
FactorizationCost::Polynomial(PolyStat {
|
||||||
|
degree: 1,
|
||||||
|
factors: 1,
|
||||||
|
ops: 0,
|
||||||
|
monomial: true,
|
||||||
|
sum_of_monomials: true,
|
||||||
|
monic: true,
|
||||||
|
factorized: true
|
||||||
|
})
|
||||||
|
},
|
||||||
|
_ => FactorizationCost::UnwantedOps,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Debug,Clone)]
|
||||||
|
pub struct Factorization {
|
||||||
|
pub constant_factor: Rational,
|
||||||
|
pub polynomials: Vec<Vec<Rational>>,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl Display for Factorization {
|
||||||
|
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
|
||||||
|
if self.constant_factor != RATIONAL_ONE {
|
||||||
|
write!(f, "{}", self.constant_factor)?;
|
||||||
|
}
|
||||||
|
|
||||||
|
for poly in &self.polynomials {
|
||||||
|
write!(f, "(")?;
|
||||||
|
for (deg, coeff) in poly.iter().enumerate() {
|
||||||
|
if deg == 0 {
|
||||||
|
write!(f, "{}", coeff)?;
|
||||||
|
} else if deg == 1 {
|
||||||
|
write!(f, " + {}x", coeff)?;
|
||||||
|
} else {
|
||||||
|
write!(f, " + {}x^{}", coeff, deg)?;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
write!(f, ")")?;
|
||||||
|
}
|
||||||
|
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn extract_factorization(expr: &RecExpr<EquationLanguage>) -> Factorization {
|
||||||
|
let root_id: Id = Id::from(expr.as_ref().len()-1);
|
||||||
|
|
||||||
|
let mut constant_factor: Option<Rational> = None;
|
||||||
|
let mut factors: Vec<Vec<Rational>> = Vec::new();
|
||||||
|
let mut todo: Vec<Id> = Vec::new();
|
||||||
|
todo.push(root_id);
|
||||||
|
|
||||||
|
while todo.len() > 0 {
|
||||||
|
let id = todo.pop().unwrap();
|
||||||
|
|
||||||
|
match &expr[id] {
|
||||||
|
EquationLanguage::Mul([a,b]) => {
|
||||||
|
todo.push(*a);
|
||||||
|
todo.push(*b);
|
||||||
|
},
|
||||||
|
EquationLanguage::Num(x) => {
|
||||||
|
assert!(constant_factor.is_none());
|
||||||
|
constant_factor = Some(x.clone());
|
||||||
|
},
|
||||||
|
_ => {
|
||||||
|
factors.push(extract_polynomial(expr, id));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
Factorization {
|
||||||
|
constant_factor: constant_factor.unwrap_or_else(||RATIONAL_ONE.clone()),
|
||||||
|
polynomials: factors
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
fn extract_polynomial(expr: &RecExpr<EquationLanguage>, id: Id) -> Vec<Rational> {
|
||||||
|
let mut result: Vec<Rational> = Vec::new();
|
||||||
|
let mut todo: Vec<Id> = Vec::new();
|
||||||
|
todo.push(id);
|
||||||
|
|
||||||
|
while todo.len() > 0 {
|
||||||
|
let id = todo.pop().unwrap();
|
||||||
|
|
||||||
|
match &expr[id] {
|
||||||
|
EquationLanguage::Add([a,b]) => {
|
||||||
|
todo.push(*a);
|
||||||
|
todo.push(*b);
|
||||||
|
},
|
||||||
|
_ => {
|
||||||
|
let (deg, coeff) = extract_monomial(expr, id);
|
||||||
|
result.resize(result.len().max(deg), RATIONAL_ZERO.clone());
|
||||||
|
|
||||||
|
if result.len() <= deg {
|
||||||
|
result.push(coeff);
|
||||||
|
} else {
|
||||||
|
assert!(result[deg] == RATIONAL_ZERO);
|
||||||
|
result[deg] = coeff;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
result
|
||||||
|
}
|
||||||
|
|
||||||
|
fn extract_monomial(expr: &RecExpr<EquationLanguage>, id: Id) -> (usize, Rational) {
|
||||||
|
let mut coeff: Option<Rational> = None;
|
||||||
|
let mut deg: usize = 0;
|
||||||
|
let mut todo: Vec<Id> = Vec::new();
|
||||||
|
todo.push(id);
|
||||||
|
|
||||||
|
while todo.len() > 0 {
|
||||||
|
let id = todo.pop().unwrap();
|
||||||
|
|
||||||
|
match &expr[id] {
|
||||||
|
EquationLanguage::Unknown => {
|
||||||
|
deg += 1;
|
||||||
|
},
|
||||||
|
EquationLanguage::Mul([a,b]) => {
|
||||||
|
todo.push(*a);
|
||||||
|
todo.push(*b);
|
||||||
|
},
|
||||||
|
EquationLanguage::Num(x) => {
|
||||||
|
assert!(coeff.is_none());
|
||||||
|
coeff = Some(x.clone());
|
||||||
|
},
|
||||||
|
_ => {
|
||||||
|
panic!("Not a rational polynomial in normal form!");
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
(deg, coeff.unwrap_or_else(||RATIONAL_ONE.clone()))
|
||||||
|
}
|
135
src/language.rs
135
src/language.rs
@ -167,6 +167,13 @@ fn is_nonzero_const(var: &str) -> impl Fn(&mut EGraph, Id, &Subst) -> bool {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
fn is_const(var: &str) -> impl Fn(&mut EGraph, Id, &Subst) -> bool {
|
||||||
|
let var: Var = var.parse().unwrap();
|
||||||
|
move |egraph, _, subst| {
|
||||||
|
egraph[subst[var]].data.is_some()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
pub static RULES: LazyLock<Vec<Rewrite>> = LazyLock::new(||vec![
|
pub static RULES: LazyLock<Vec<Rewrite>> = LazyLock::new(||vec![
|
||||||
rw!("commute-add"; "(+ ?x ?y)" => "(+ ?y ?x)"),
|
rw!("commute-add"; "(+ ?x ?y)" => "(+ ?y ?x)"),
|
||||||
rw!("commute-mul"; "(* ?x ?y)" => "(* ?y ?x)"),
|
rw!("commute-mul"; "(* ?x ?y)" => "(* ?y ?x)"),
|
||||||
@ -187,10 +194,6 @@ pub static RULES: LazyLock<Vec<Rewrite>> = LazyLock::new(||vec![
|
|||||||
|
|
||||||
rw!("square"; "(^ ?x 2)" => "(* ?x ?x)"),
|
rw!("square"; "(^ ?x 2)" => "(* ?x ?x)"),
|
||||||
rw!("cube"; "(^ ?x 3)" => "(* ?x (* ?x ?x))"),
|
rw!("cube"; "(^ ?x 3)" => "(* ?x (* ?x ?x))"),
|
||||||
/*
|
|
||||||
rw!("inv-square"; "(* ?x ?x)" => "(^ ?x 2)"),
|
|
||||||
rw!("inv-cube"; "(* ?x (* ?x ?x))" => "(^ ?x 3)"),
|
|
||||||
*/
|
|
||||||
|
|
||||||
rw!("sub"; "(- ?x ?y)" => "(+ ?x (* -1 ?y))"),
|
rw!("sub"; "(- ?x ?y)" => "(+ ?x (* -1 ?y))"),
|
||||||
rw!("neg"; "(- ?x)" => "(* -1 ?x)"),
|
rw!("neg"; "(- ?x)" => "(* -1 ?x)"),
|
||||||
@ -198,6 +201,8 @@ pub static RULES: LazyLock<Vec<Rewrite>> = LazyLock::new(||vec![
|
|||||||
rw!("div"; "(/ ?x ?y)" => "(* ?x (rec ?y))" if is_nonzero_const("?y")),
|
rw!("div"; "(/ ?x ?y)" => "(* ?x (rec ?y))" if is_nonzero_const("?y")),
|
||||||
|
|
||||||
rw!("factor_poly"; "(+ (* x ?x) ?y)" => "(* ?x (+ x (* ?y (rec ?x))))" if is_nonzero_const("?x")),
|
rw!("factor_poly"; "(+ (* x ?x) ?y)" => "(* ?x (+ x (* ?y (rec ?x))))" if is_nonzero_const("?x")),
|
||||||
|
|
||||||
|
// rw!("integer_sqrt"; "(^ ?x (1/2))" => {} if is_const("?x")),
|
||||||
]);
|
]);
|
||||||
|
|
||||||
pub struct PlusTimesCostFn;
|
pub struct PlusTimesCostFn;
|
||||||
@ -218,125 +223,3 @@ impl egg::CostFunction<EquationLanguage> for PlusTimesCostFn {
|
|||||||
enode.fold(op_cost, |sum, i| sum + costs(i))
|
enode.fold(op_cost, |sum, i| sum + costs(i))
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
#[derive(Debug,Clone,Copy)]
|
|
||||||
pub struct PolyStat {
|
|
||||||
degree: usize,
|
|
||||||
factors: usize, // non-constant factors
|
|
||||||
ops: usize,
|
|
||||||
monomial: bool,
|
|
||||||
sum_of_monomials: bool,
|
|
||||||
monic: bool,
|
|
||||||
factorized: bool, // a product of monic polynomials and at least one constant
|
|
||||||
}
|
|
||||||
|
|
||||||
#[derive(Debug,Clone,Copy)]
|
|
||||||
pub enum FactorizationCost {
|
|
||||||
UnwantedOps,
|
|
||||||
Polynomial(PolyStat)
|
|
||||||
}
|
|
||||||
|
|
||||||
fn score(cost: FactorizationCost) -> usize {
|
|
||||||
match cost {
|
|
||||||
FactorizationCost::UnwantedOps => 10000,
|
|
||||||
FactorizationCost::Polynomial(p) =>
|
|
||||||
if !p.factorized {
|
|
||||||
1000
|
|
||||||
} else {
|
|
||||||
100 * (9 - p.factors) + p.ops
|
|
||||||
},
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
impl PartialEq for FactorizationCost {
|
|
||||||
fn eq(&self, other: &Self) -> bool {
|
|
||||||
score(*self) == score(*other)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
impl PartialOrd for FactorizationCost {
|
|
||||||
fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
|
|
||||||
usize::partial_cmp(&score(*self), &score(*other))
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
pub struct FactorizationCostFn;
|
|
||||||
impl egg::CostFunction<EquationLanguage> for FactorizationCostFn {
|
|
||||||
type Cost = FactorizationCost;
|
|
||||||
|
|
||||||
fn cost<C>(&mut self, enode: &EquationLanguage, mut costs: C) -> Self::Cost
|
|
||||||
where
|
|
||||||
C: FnMut(Id) -> Self::Cost,
|
|
||||||
{
|
|
||||||
match enode {
|
|
||||||
EquationLanguage::Add([a,b]) => {
|
|
||||||
match (costs(*a), costs(*b)) {
|
|
||||||
(FactorizationCost::Polynomial(p1),FactorizationCost::Polynomial(p2)) => {
|
|
||||||
// we only ever want to add monomials
|
|
||||||
let result_monic = if p1.degree > p2.degree {
|
|
||||||
p1.monic
|
|
||||||
} else if p2.degree > p1.degree {
|
|
||||||
p2.monic
|
|
||||||
} else {
|
|
||||||
false
|
|
||||||
};
|
|
||||||
|
|
||||||
if !p1.sum_of_monomials || !p2.sum_of_monomials {
|
|
||||||
FactorizationCost::UnwantedOps
|
|
||||||
} else {
|
|
||||||
FactorizationCost::Polynomial(PolyStat {
|
|
||||||
degree: usize::max(p1.degree, p2.degree),
|
|
||||||
factors: 1,
|
|
||||||
ops: p1.ops + p2.ops,
|
|
||||||
monomial: false,
|
|
||||||
sum_of_monomials: p1.sum_of_monomials && p2.sum_of_monomials,
|
|
||||||
monic: result_monic,
|
|
||||||
factorized: result_monic,
|
|
||||||
})
|
|
||||||
}
|
|
||||||
},
|
|
||||||
_ => FactorizationCost::UnwantedOps
|
|
||||||
}
|
|
||||||
},
|
|
||||||
EquationLanguage::Mul([a,b]) => {
|
|
||||||
match (costs(*a), costs(*b)) {
|
|
||||||
(FactorizationCost::Polynomial(p1), FactorizationCost::Polynomial(p2)) => {
|
|
||||||
FactorizationCost::Polynomial(PolyStat {
|
|
||||||
degree: p1.degree + p2.degree,
|
|
||||||
factors: p1.factors + p2.factors,
|
|
||||||
ops: p1.ops + p2.ops,
|
|
||||||
monomial: p1.monomial && p2.monomial,
|
|
||||||
sum_of_monomials: p1.monomial && p2.monomial,
|
|
||||||
monic: p1.monic && p2.monic,
|
|
||||||
factorized: (p1.monic && p2.factorized) || (p2.monic && p1.factorized)
|
|
||||||
})
|
|
||||||
},
|
|
||||||
_ => FactorizationCost::UnwantedOps
|
|
||||||
}
|
|
||||||
},
|
|
||||||
EquationLanguage::Num(c) => {
|
|
||||||
FactorizationCost::Polynomial(PolyStat {
|
|
||||||
degree: 0,
|
|
||||||
factors: 0,
|
|
||||||
ops: 0,
|
|
||||||
monomial: true,
|
|
||||||
sum_of_monomials: true,
|
|
||||||
monic: false,
|
|
||||||
factorized: true
|
|
||||||
})
|
|
||||||
},
|
|
||||||
EquationLanguage::Unknown => {
|
|
||||||
FactorizationCost::Polynomial(PolyStat {
|
|
||||||
degree: 1,
|
|
||||||
factors: 1,
|
|
||||||
ops: 0,
|
|
||||||
monomial: true,
|
|
||||||
sum_of_monomials: true,
|
|
||||||
monic: true,
|
|
||||||
factorized: true
|
|
||||||
})
|
|
||||||
},
|
|
||||||
_ => FactorizationCost::UnwantedOps,
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
@ -1,3 +1,5 @@
|
|||||||
pub mod language;
|
pub mod language;
|
||||||
pub mod normal_form;
|
pub mod normal_form;
|
||||||
pub mod parse;
|
pub mod parse;
|
||||||
|
pub mod factorization;
|
||||||
|
pub mod output;
|
||||||
|
284
src/main.rs
284
src/main.rs
@ -1,24 +1,36 @@
|
|||||||
use egg::{Extractor, Pattern, RecExpr, Runner};
|
use egg::{AstSize, Extractor, Id, Pattern, PatternAst, RecExpr, Runner, Searcher};
|
||||||
use solveq::language::{RULES, EquationLanguage, PlusTimesCostFn, FactorizationCostFn};
|
use solveq::factorization::{FactorizationCost, FactorizationCostFn};
|
||||||
use solveq::normal_form::analyze3;
|
use solveq::language::{ConstantFold, EquationLanguage, Rational, RULES, EGraph};
|
||||||
|
use solveq::normal_form::extract_normal_form;
|
||||||
use solveq::parse::parse_equation;
|
use solveq::parse::parse_equation;
|
||||||
|
use solveq::output::print_term;
|
||||||
|
|
||||||
static TEST_EQUATIONS: &[&str] = &[
|
static TEST_EQUATIONS: &[&str] = &[
|
||||||
"(x + 50) * 10 - 150 - 100",
|
"(x + 50) * 10 - 150 = 100",
|
||||||
"(x - 2) * (x + 2) - 0",
|
"(x - 2) * (x + 2) = 0",
|
||||||
"x ^ 2 - 4",
|
"x ^ 2 = 4",
|
||||||
"x ^ 2 - 2 - 0",
|
"x ^ 2 - 2 = 0",
|
||||||
"x ^ 2 - (2 * x + 15)",
|
"x ^ 2 = 2 * x + 15",
|
||||||
"(x ^ 2 - 2 * x - 15) * (x + 5) - 0",
|
"(x ^ 2 - 2 * x - 15) * (x + 5) = 0",
|
||||||
"x ^ 3 + 3 * x ^ 2 - 25 * x - 75 - 0",
|
"x ^ 3 + 3 * x ^ 2 - 25 * x - 75 = 0",
|
||||||
];
|
];
|
||||||
|
|
||||||
fn main() {
|
|
||||||
for eq in TEST_EQUATIONS {
|
|
||||||
let start = parse_equation(*eq).unwrap();
|
|
||||||
|
|
||||||
// println!("{:?}", &start);
|
fn main() {
|
||||||
// do transformation to left - right = 0
|
// let expr: RecExpr<EquationLanguage> = "(+ (* x (+ x -2)) -15)".parse().unwrap();
|
||||||
|
// let expr: RecExpr<EquationLanguage> = "(* (+ (* x (+ x -2)) -15) (+ x 5))".parse().unwrap();
|
||||||
|
// println!("{:?}", get_expression_cost(&expr));
|
||||||
|
|
||||||
|
for eq in TEST_EQUATIONS {
|
||||||
|
println!("Equation: {}", *eq);
|
||||||
|
|
||||||
|
let mut start = parse_equation(*eq).unwrap();
|
||||||
|
let root_id = Id::from(start.as_ref().len()-1);
|
||||||
|
let EquationLanguage::Equals([left, right]) = start[root_id]
|
||||||
|
else { panic!("Not an equation without an equals sign!"); };
|
||||||
|
start[root_id] = EquationLanguage::Sub([left, right]);
|
||||||
|
|
||||||
|
println!("Parsed: {}", &start);
|
||||||
|
|
||||||
let mut runner = Runner::default()
|
let mut runner = Runner::default()
|
||||||
.with_explanations_enabled()
|
.with_explanations_enabled()
|
||||||
@ -29,198 +41,68 @@ fn main() {
|
|||||||
|
|
||||||
let (best_cost, best_expr) = extractor.find_best(runner.roots[0]);
|
let (best_cost, best_expr) = extractor.find_best(runner.roots[0]);
|
||||||
|
|
||||||
println!("{}", start);
|
// println!("Best expresssion: {} {:?}", best_expr, best_cost);
|
||||||
println!("{:?} {:?}", best_cost, <RecExpr<EquationLanguage> as AsRef<[EquationLanguage]>>::as_ref(&best_expr));
|
// println!("{}", runner.explain_equivalence(&start, &best_expr).get_flat_string());
|
||||||
|
|
||||||
|
let Some(factorization) = extract_normal_form(&runner.egraph, runner.roots[0]) else {
|
||||||
|
panic!("Couldn't factorize polynomial!");
|
||||||
|
};
|
||||||
|
println!("Factorized normal form: {}", factorization);
|
||||||
|
|
||||||
|
/*
|
||||||
|
get_expression_cost("(* x (+ x -2))", &runner.egraph);
|
||||||
|
get_expression_cost("-15", &runner.egraph);
|
||||||
|
get_expression_cost("(+ (* x (+ x -2)) -15)", &runner.egraph);
|
||||||
|
*/
|
||||||
|
|
||||||
|
// let factorization = extract_factorization(&best_expr);
|
||||||
|
|
||||||
|
|
||||||
|
let mut solutions: Vec<String> = Vec::new();
|
||||||
|
for poly in &factorization.polynomials {
|
||||||
|
if poly.len() == 2 { // linear factor
|
||||||
|
let Rational { num, denom } = &poly[0];
|
||||||
|
solutions.push(format!("x = {}", Rational { num: -*num, denom: *denom }));
|
||||||
|
} else if poly.len() == 3 { // quadratic factor
|
||||||
|
let Rational { num: num0, denom: denom0 } = &poly[0];
|
||||||
|
let Rational { num: num1, denom: denom1 } = &poly[1];
|
||||||
|
|
||||||
|
let sol1 = format!("- ({num1})/(2 * ({denom1})) + ((({num1})/(2 * ({denom1}))) ^ 2 - ({num0}) / ({denom0})) ^ (1/2)");
|
||||||
|
let sol2 = format!("- ({num1})/(2 * ({denom1})) - ((({num1})/(2 * ({denom1}))) ^ 2 - ({num0}) / ({denom0})) ^ (1/2)");
|
||||||
|
|
||||||
|
let expr = parse_equation(&sol1).unwrap();
|
||||||
|
let runner = Runner::default()
|
||||||
|
.with_expr(&expr)
|
||||||
|
.run(&*RULES);
|
||||||
|
let extractor = Extractor::new(&runner.egraph, AstSize);
|
||||||
|
let (_, simplified_expr) = extractor.find_best(runner.roots[0]);
|
||||||
|
solutions.push(format!("x = {}", print_term(&simplified_expr)));
|
||||||
|
|
||||||
|
let expr = parse_equation(&sol2).unwrap();
|
||||||
|
let runner = Runner::default()
|
||||||
|
.with_expr(&expr)
|
||||||
|
.run(&*RULES);
|
||||||
|
let extractor = Extractor::new(&runner.egraph, AstSize);
|
||||||
|
let (_, simplified_expr) = extractor.find_best(runner.roots[0]);
|
||||||
|
solutions.push(format!("x = {}", print_term(&simplified_expr)));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
println!("Solutions: {{ {} }}", solutions.join(", "));
|
||||||
println!("");
|
println!("");
|
||||||
}
|
}
|
||||||
|
|
||||||
// let root = runner.roots[0];
|
|
||||||
// let egraph = &runner.egraph;
|
|
||||||
// let pattern: Pattern<EquationLanguage> = "(+ (* ?a (* x x)) ?c)".parse().unwrap();
|
|
||||||
// let matches = pattern.search(&egraph);
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
// println!("{:?}", egraph.classes().count());
|
|
||||||
|
|
||||||
// Analyze
|
|
||||||
// analyze3(egraph, runner.roots[0]);
|
|
||||||
|
|
||||||
/*
|
|
||||||
for class in egraph.classes() {
|
|
||||||
if monic_nonconst_polynomial(egraph, class.id).is_some() {
|
|
||||||
let (_, best_expr) = extractor.find_best(class.id);
|
|
||||||
println!("Monomial: {}", best_expr);
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
println!("{:?}", &matches);
|
fn get_expression_cost(expr: &str, egraph: &EGraph) {
|
||||||
*/
|
// let mut egraph = EGraph::new(ConstantFold::default());
|
||||||
|
// let id = egraph.add_expr(expr);
|
||||||
|
let pattern: Pattern<EquationLanguage> = expr.parse().unwrap();
|
||||||
|
let matches = pattern.search(egraph);
|
||||||
|
for m in matches {
|
||||||
|
let extractor = Extractor::new(&egraph, FactorizationCostFn);
|
||||||
|
let (cost, _) = extractor.find_best(m.eclass);
|
||||||
|
|
||||||
|
println!("expr: {}, id: {}, cost: {:?}", expr, m.eclass, cost);
|
||||||
// println!("{}", runner.explain_equivalence(&start, &best_expr).get_flat_string());
|
|
||||||
}
|
}
|
||||||
|
// cost
|
||||||
|
|
||||||
/*
|
|
||||||
fn power_of_x(egraph: &EGraph, eclass: Id) -> Option<usize> {
|
|
||||||
for n in &egraph[eclass].nodes {
|
|
||||||
match *n {
|
|
||||||
EquationLanguage::Unknown => { return Some(1) },
|
|
||||||
EquationLanguage::Mul([a,b]) => {
|
|
||||||
let Some(left) = power_of_x(egraph, a) else { continue };
|
|
||||||
let Some(right) = power_of_x(egraph, b) else { continue };
|
|
||||||
return Some(left + right);
|
|
||||||
},
|
|
||||||
_ => {}
|
|
||||||
}
|
}
|
||||||
}
|
|
||||||
None
|
|
||||||
}
|
|
||||||
|
|
||||||
fn monomial(egraph: &EGraph, eclass: Id) -> Option<(usize, Rational)> {
|
|
||||||
if let Some(deg) = power_of_x(egraph, eclass) {
|
|
||||||
return Some((deg, RATIONAL_ONE.clone()));
|
|
||||||
}
|
|
||||||
|
|
||||||
for n in &egraph[eclass].nodes {
|
|
||||||
match *n {
|
|
||||||
EquationLanguage::Mul([a,b]) => {
|
|
||||||
let Some(coeff) = egraph[a].data.clone() else { continue };
|
|
||||||
let Some(deg) = power_of_x(egraph, b) else { continue };
|
|
||||||
return Some((deg, coeff));
|
|
||||||
},
|
|
||||||
_ => {}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
None
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
// this is either a power_of_x, or a sum of this and a monomial
|
|
||||||
fn monic_nonconst_polynomial(egraph: &EGraph, eclass: Id) -> Option<Vec<Rational>> {
|
|
||||||
let mut result: Vec<Rational> = Vec::new();
|
|
||||||
|
|
||||||
if let Some(deg) = power_of_x(egraph, eclass) {
|
|
||||||
result.resize(deg - 1, RATIONAL_ZERO);
|
|
||||||
result.push(RATIONAL_ONE.clone());
|
|
||||||
return Some(result);
|
|
||||||
}
|
|
||||||
|
|
||||||
for n in &egraph[eclass].nodes {
|
|
||||||
match *n {
|
|
||||||
EquationLanguage::Add([a,b]) => {
|
|
||||||
let Some(mut leading) = monic_nonconst_polynomial(egraph, a)
|
|
||||||
else { continue };
|
|
||||||
let Some(addon) = monomial(egraph, b)
|
|
||||||
else { continue };
|
|
||||||
|
|
||||||
if leading.len() <= addon.0 || leading[addon.0] != RATIONAL_ZERO {
|
|
||||||
continue;
|
|
||||||
}
|
|
||||||
|
|
||||||
leading[addon.0] = addon.1.clone();
|
|
||||||
return Some(leading);
|
|
||||||
},
|
|
||||||
_ => {},
|
|
||||||
}
|
|
||||||
}
|
|
||||||
None
|
|
||||||
}
|
|
||||||
|
|
||||||
*/
|
|
||||||
|
|
||||||
|
|
||||||
/*
|
|
||||||
fn analyze(egraph: &EGraph, _id: Id) {
|
|
||||||
let mut types: HashMap<Id, SpecialTerm> = HashMap::new();
|
|
||||||
let mut todo: VecDeque<Id> = VecDeque::new();
|
|
||||||
// todo.push_back(runner.roots[0]);
|
|
||||||
for cls in egraph.classes() {
|
|
||||||
todo.push_back(cls.id);
|
|
||||||
}
|
|
||||||
|
|
||||||
'todo: while todo.len() > 0 {
|
|
||||||
let id = todo.pop_front().unwrap();
|
|
||||||
if types.contains_key(&id) {
|
|
||||||
continue 'todo;
|
|
||||||
}
|
|
||||||
|
|
||||||
if let Some(c) = &egraph[id].data {
|
|
||||||
types.insert(id, SpecialTerm::Constant(c.clone()));
|
|
||||||
continue 'todo;
|
|
||||||
}
|
|
||||||
|
|
||||||
'nodes: for n in &egraph[id].nodes {
|
|
||||||
match *n {
|
|
||||||
EquationLanguage::Unknown => {
|
|
||||||
types.insert(id, SpecialTerm::PowerOfX(1));
|
|
||||||
continue 'todo;
|
|
||||||
},
|
|
||||||
EquationLanguage::Mul([a,b]) => {
|
|
||||||
if !types.contains_key(&a) {
|
|
||||||
todo.push_back(a);
|
|
||||||
todo.push_back(id);
|
|
||||||
continue 'nodes;
|
|
||||||
}
|
|
||||||
|
|
||||||
if !types.contains_key(&b) {
|
|
||||||
todo.push_back(b);
|
|
||||||
todo.push_back(id);
|
|
||||||
continue 'nodes;
|
|
||||||
}
|
|
||||||
|
|
||||||
match (&types[&a], &types[&b]) {
|
|
||||||
(SpecialTerm::PowerOfX(dega), SpecialTerm::PowerOfX(degb)) => {
|
|
||||||
types.insert(id, SpecialTerm::PowerOfX(*dega + *degb));
|
|
||||||
},
|
|
||||||
(SpecialTerm::Constant(coeff), SpecialTerm::PowerOfX(deg)) => {
|
|
||||||
types.insert(id, SpecialTerm::Monomial(*deg, coeff.clone()));
|
|
||||||
},
|
|
||||||
_ => { continue 'nodes; },
|
|
||||||
}
|
|
||||||
continue 'todo;
|
|
||||||
},
|
|
||||||
EquationLanguage::Add([a,b]) => {
|
|
||||||
if !types.contains_key(&a) {
|
|
||||||
todo.push_front(a);
|
|
||||||
todo.push_back(id);
|
|
||||||
continue 'todo;
|
|
||||||
}
|
|
||||||
|
|
||||||
if !types.contains_key(&b) {
|
|
||||||
todo.push_front(b);
|
|
||||||
todo.push_back(id);
|
|
||||||
continue 'todo;
|
|
||||||
}
|
|
||||||
|
|
||||||
match (&types[&a], &types[&b]) {
|
|
||||||
(SpecialTerm::MonicNonconstPoly(poly), SpecialTerm::Monomial(deg, coeff)) => {
|
|
||||||
if poly.len() <= *deg || poly[*deg] != RATIONAL_ZERO {
|
|
||||||
continue 'nodes;
|
|
||||||
}
|
|
||||||
|
|
||||||
let mut poly = poly.clone();
|
|
||||||
poly[*deg] = coeff.clone();
|
|
||||||
types.insert(id, SpecialTerm::MonicNonconstPoly(poly));
|
|
||||||
},
|
|
||||||
_ => { continue 'nodes; },
|
|
||||||
}
|
|
||||||
continue 'todo;
|
|
||||||
},
|
|
||||||
_ => {},
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
types.insert(id, SpecialTerm::Other);
|
|
||||||
}
|
|
||||||
|
|
||||||
for (id, ty) in &types {
|
|
||||||
if !matches!(ty, &SpecialTerm::Other) {
|
|
||||||
println!("{:?}", &ty);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
*/
|
|
||||||
|
@ -1,6 +1,6 @@
|
|||||||
use crate::language::{EGraph, EquationLanguage, Rational, RATIONAL_ONE, RATIONAL_ZERO};
|
use crate::language::{EGraph, EquationLanguage, Rational, RATIONAL_ONE, RATIONAL_ZERO};
|
||||||
use std::collections::HashMap;
|
use std::{collections::HashMap, fmt};
|
||||||
use egg::Id;
|
use egg::{AstSize, Extractor, Id};
|
||||||
|
|
||||||
#[derive(Debug,Clone)]
|
#[derive(Debug,Clone)]
|
||||||
pub enum SpecialTerm {
|
pub enum SpecialTerm {
|
||||||
@ -12,36 +12,281 @@ pub enum SpecialTerm {
|
|||||||
Other,
|
Other,
|
||||||
}
|
}
|
||||||
|
|
||||||
fn search_for<F, T>(egraph: &EGraph, f: F) -> HashMap<Id, T>
|
#[derive(Debug,Clone)]
|
||||||
|
pub struct Factorization {
|
||||||
|
pub constant_factor: Rational,
|
||||||
|
pub polynomials: Vec<Vec<Rational>>,
|
||||||
|
}
|
||||||
|
|
||||||
|
// this is a property of an eclass, not a particular expression
|
||||||
|
#[derive(Debug,Clone)]
|
||||||
|
struct PolyStats {
|
||||||
|
degree: usize,
|
||||||
|
monomial: bool,
|
||||||
|
monic: bool,
|
||||||
|
}
|
||||||
|
|
||||||
|
fn gather_poly_stats(egraph: &EGraph) -> HashMap<Id, PolyStats> {
|
||||||
|
walk_egraph(egraph, |_id, node, stats: &HashMap<Id, PolyStats>| {
|
||||||
|
let x = |i: &Id| stats.get(&egraph.find(*i));
|
||||||
|
Some(match node {
|
||||||
|
EquationLanguage::Unknown => PolyStats {
|
||||||
|
degree: 1,
|
||||||
|
monomial: true,
|
||||||
|
monic: true,
|
||||||
|
},
|
||||||
|
EquationLanguage::Num(c) => PolyStats {
|
||||||
|
degree: 0,
|
||||||
|
monomial: true,
|
||||||
|
monic: c == &RATIONAL_ONE,
|
||||||
|
},
|
||||||
|
EquationLanguage::Mul([a,b]) => {
|
||||||
|
// if both aren't monic we can't tell, the leading coefficients could cancel
|
||||||
|
// but there should be an alternative representative with one of them monic
|
||||||
|
if !x(a)?.monic && !x(b)?.monic {
|
||||||
|
return None;
|
||||||
|
}
|
||||||
|
|
||||||
|
PolyStats {
|
||||||
|
degree: x(a)?.degree + x(b)?.degree,
|
||||||
|
monomial: x(a)?.monomial && x(b)?.monomial,
|
||||||
|
monic: x(a)?.monic && x(b)?.monic,
|
||||||
|
}
|
||||||
|
},
|
||||||
|
EquationLanguage::Add([a,b]) => {
|
||||||
|
// in this case, there should also be a simplified representative which
|
||||||
|
// has only a single leading term
|
||||||
|
if x(a)?.degree == x(b)?.degree {
|
||||||
|
return None;
|
||||||
|
}
|
||||||
|
|
||||||
|
PolyStats {
|
||||||
|
degree: usize::max(x(a)?.degree, x(b)?.degree),
|
||||||
|
monomial: false,
|
||||||
|
monic: if x(a)?.degree > x(b)?.degree { x(a)?.monic } else { x(b)?.monic },
|
||||||
|
}
|
||||||
|
},
|
||||||
|
_ => { return None; }
|
||||||
|
})
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn extract_normal_form(egraph: &EGraph, eclass: Id) -> Option<Factorization> {
|
||||||
|
let eclass = egraph.find(eclass); // get the canonical eclass
|
||||||
|
|
||||||
|
let stats = gather_poly_stats(egraph);
|
||||||
|
|
||||||
|
let Some(factorization) = find_general_factorization(egraph, &stats, eclass) else { return None; };
|
||||||
|
|
||||||
|
let mut result = Vec::new();
|
||||||
|
let mut coeff: Option<Rational> = None;
|
||||||
|
|
||||||
|
for factor in factorization {
|
||||||
|
let extracted = extract_polynomial(egraph, &stats, factor)?;
|
||||||
|
// println!("Extracted: {:?}", extracted);
|
||||||
|
|
||||||
|
if extracted.len() == 1 {
|
||||||
|
coeff = Some(extracted[0].clone());
|
||||||
|
} else {
|
||||||
|
result.push(extracted);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
Some(Factorization {
|
||||||
|
constant_factor: coeff.unwrap_or_else(||RATIONAL_ONE.clone()),
|
||||||
|
polynomials: result
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
// a polynomial should be either of:
|
||||||
|
// - a monomial: then we know the degree and we try to parse it as a product of a constant and a monic monomial
|
||||||
|
// or just a constant
|
||||||
|
// - a sum of a monomial of highest degree, and a polynomial of lower degree, recursively walk these
|
||||||
|
fn extract_polynomial(egraph: &EGraph, stats: &HashMap<Id, PolyStats>, id: Id) -> Option<Vec<Rational>> {
|
||||||
|
let st = &stats[&id];
|
||||||
|
|
||||||
|
if st.monomial {
|
||||||
|
let (deg, coeff) = extract_monomial(egraph, stats, id)?;
|
||||||
|
|
||||||
|
let mut result = vec![RATIONAL_ZERO; deg];
|
||||||
|
result.push(coeff);
|
||||||
|
return Some(result);
|
||||||
|
} else {
|
||||||
|
for node in &egraph[id].nodes {
|
||||||
|
match node {
|
||||||
|
EquationLanguage::Add([a,b]) => {
|
||||||
|
let a = egraph.find(*a);
|
||||||
|
let b = egraph.find(*b);
|
||||||
|
let Some(stata) = &stats.get(&a) else { continue };
|
||||||
|
let Some(statb) = &stats.get(&b) else { continue };
|
||||||
|
|
||||||
|
if stata.degree == st.degree && stata.monomial && statb.degree < st.degree {
|
||||||
|
let (leading_deg, leading_coeff) = extract_monomial(egraph, stats, a)?;
|
||||||
|
let mut remainder = extract_polynomial(egraph, stats, b)?;
|
||||||
|
|
||||||
|
assert!(leading_deg >= remainder.len());
|
||||||
|
|
||||||
|
remainder.resize(leading_deg, RATIONAL_ZERO.clone());
|
||||||
|
remainder.push(leading_coeff);
|
||||||
|
return Some(remainder);
|
||||||
|
}
|
||||||
|
},
|
||||||
|
_ => {}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
None
|
||||||
|
}
|
||||||
|
|
||||||
|
// a monomial is either a power of x, a constant, or a product of a constant and power of x
|
||||||
|
fn extract_monomial(egraph: &EGraph, stats: &HashMap<Id, PolyStats>, id: Id) -> Option<(usize, Rational)> {
|
||||||
|
let extractor = Extractor::new(egraph, AstSize);
|
||||||
|
let (_, expr) = extractor.find_best(id);
|
||||||
|
// println!("Extract Monomial: {}", expr);
|
||||||
|
|
||||||
|
let st = &stats[&id];
|
||||||
|
|
||||||
|
assert!(st.monomial);
|
||||||
|
|
||||||
|
// monic + monomial = power of x
|
||||||
|
if st.monic {
|
||||||
|
return Some((st.degree, RATIONAL_ONE.clone()));
|
||||||
|
}
|
||||||
|
|
||||||
|
for node in &egraph[id].nodes {
|
||||||
|
match node {
|
||||||
|
EquationLanguage::Mul([a,b]) => {
|
||||||
|
let a = egraph.find(*a);
|
||||||
|
let b = egraph.find(*b);
|
||||||
|
let Some(statb) = stats.get(&b) else { continue };
|
||||||
|
|
||||||
|
// a should be a constant
|
||||||
|
let Some(coeff) = egraph[a].data.clone() else { continue };
|
||||||
|
|
||||||
|
// b should be monic and nonconstant (hence a power of x)
|
||||||
|
if statb.degree == 0 || !statb.monic {
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
|
||||||
|
assert_eq!(st.degree, statb.degree);
|
||||||
|
|
||||||
|
return Some((st.degree, coeff));
|
||||||
|
},
|
||||||
|
EquationLanguage::Num(c) => { // a constant is also a monomial
|
||||||
|
return Some((0, c.clone()));
|
||||||
|
},
|
||||||
|
_ => {},
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
None
|
||||||
|
}
|
||||||
|
|
||||||
|
// like find_monic_factorization, but with the option of having a constant factor
|
||||||
|
fn find_general_factorization(egraph: &EGraph, stats: &HashMap<Id, PolyStats>, id: Id) -> Option<Vec<Id>> {
|
||||||
|
let st = stats.get(&id)?;
|
||||||
|
|
||||||
|
if st.monic {
|
||||||
|
return find_monic_factorization(egraph, stats, id);
|
||||||
|
} else {
|
||||||
|
for node in &egraph[id].nodes {
|
||||||
|
match node {
|
||||||
|
EquationLanguage::Mul([a,b]) => {
|
||||||
|
let a = egraph.find(*a);
|
||||||
|
let b = egraph.find(*b);
|
||||||
|
let Some(stata) = stats.get(&a) else { continue };
|
||||||
|
let Some(statb) = stats.get(&b) else { continue };
|
||||||
|
|
||||||
|
// a is constant, b is monic and nonconstant
|
||||||
|
if stata.degree == 0 && statb.degree > 0 && statb.monic {
|
||||||
|
let mut fac = find_monic_factorization(egraph, stats, b)?;
|
||||||
|
fac.push(a);
|
||||||
|
return Some(fac);
|
||||||
|
}
|
||||||
|
},
|
||||||
|
_ => {}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
None
|
||||||
|
}
|
||||||
|
|
||||||
|
// this assumes `id` to be canonical
|
||||||
|
fn find_monic_factorization(egraph: &EGraph, stats: &HashMap<Id, PolyStats>, id: Id) -> Option<Vec<Id>> {
|
||||||
|
// we want the polynomial to be nonconstant and monic
|
||||||
|
if ! stats.get(&id).is_some_and(|x|x.monic && x.degree > 0) {
|
||||||
|
return None;
|
||||||
|
}
|
||||||
|
|
||||||
|
// now the whole thing is a monic nonconst poly, so would be a valid factorization,
|
||||||
|
// but we want to go as deep as possible
|
||||||
|
// println!("{:?}", stats[&id]);
|
||||||
|
|
||||||
|
// check if it is the product of two nonconstant monic polynomials
|
||||||
|
for node in &egraph[id].nodes {
|
||||||
|
match node {
|
||||||
|
EquationLanguage::Mul([a,b]) => {
|
||||||
|
let a = egraph.find(*a);
|
||||||
|
let b = egraph.find(*b);
|
||||||
|
let Some(stata) = stats.get(&a) else { continue };
|
||||||
|
let Some(statb) = stats.get(&b) else { continue };
|
||||||
|
|
||||||
|
if stata.degree == 0 || statb.degree == 0 || !stata.monic || !statb.monic {
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
|
||||||
|
// println!("stats = {:?}, stats a = {:?}, stats b = {:?}", stats[&id], stata, statb);
|
||||||
|
|
||||||
|
let Some(mut faca) = find_monic_factorization(egraph, stats, a) else { continue };
|
||||||
|
let Some(facb) = find_monic_factorization(egraph, stats, b) else { continue };
|
||||||
|
|
||||||
|
faca.extend_from_slice(&facb);
|
||||||
|
return Some(faca);
|
||||||
|
},
|
||||||
|
_ => {}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// at this point we know the current polynomial is monic, but we didn't find a further factorization
|
||||||
|
// so just return it as a single factor
|
||||||
|
Some(vec![id])
|
||||||
|
}
|
||||||
|
|
||||||
|
fn walk_egraph<F, T>(egraph: &EGraph, f: F) -> HashMap<Id, T>
|
||||||
where
|
where
|
||||||
F: Fn(Id, &EquationLanguage, &HashMap<Id, T>) -> Option<T> {
|
F: Fn(Id, &EquationLanguage, &HashMap<Id, T>) -> Option<T> {
|
||||||
|
|
||||||
let mut result: HashMap<Id, T> = HashMap::new();
|
let mut result: HashMap<Id, T> = HashMap::new();
|
||||||
let mut modifications: usize = 1;
|
let mut modifications: usize = 1;
|
||||||
|
|
||||||
|
// println!("{:?}", egraph[canonical]);
|
||||||
|
|
||||||
while modifications > 0 {
|
while modifications > 0 {
|
||||||
modifications = 0;
|
modifications = 0;
|
||||||
|
|
||||||
for cls in egraph.classes() {
|
'next_class: for cls in egraph.classes() {
|
||||||
let id = cls.id;
|
let id = cls.id;
|
||||||
if result.contains_key(&id) {
|
if result.contains_key(&id) {
|
||||||
continue;
|
continue 'next_class;
|
||||||
}
|
}
|
||||||
|
|
||||||
for node in &cls.nodes {
|
for node in &cls.nodes {
|
||||||
if let Some(x) = f(id, node, &result) {
|
if let Some(x) = f(id, node, &result) {
|
||||||
result.insert(id, x);
|
result.insert(id, x);
|
||||||
modifications += 1;
|
modifications += 1;
|
||||||
|
continue 'next_class;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
println!("{} modifications!", modifications);
|
// println!("{} modifications!", modifications);
|
||||||
}
|
}
|
||||||
|
|
||||||
result
|
result
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/*
|
||||||
pub fn analyze3(egraph: &EGraph, eclass: Id) {
|
pub fn analyze3(egraph: &EGraph, eclass: Id) {
|
||||||
let constants = search_for(egraph, |id, _, _|
|
let constants = search_for(egraph, |id, _, _|
|
||||||
egraph[id].data.as_ref().map(|c|c.clone())
|
egraph[id].data.as_ref().map(|c|c.clone())
|
||||||
@ -227,3 +472,28 @@ pub fn analyze2(egraph: &EGraph) -> HashMap<Id, SpecialTerm> {
|
|||||||
|
|
||||||
types
|
types
|
||||||
}
|
}
|
||||||
|
*/
|
||||||
|
|
||||||
|
impl fmt::Display for Factorization {
|
||||||
|
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
|
||||||
|
if self.constant_factor != RATIONAL_ONE {
|
||||||
|
write!(f, "{}", self.constant_factor)?;
|
||||||
|
}
|
||||||
|
|
||||||
|
for poly in &self.polynomials {
|
||||||
|
write!(f, "(")?;
|
||||||
|
for (deg, coeff) in poly.iter().enumerate() {
|
||||||
|
if deg == 0 {
|
||||||
|
write!(f, "{}", coeff)?;
|
||||||
|
} else if deg == 1 {
|
||||||
|
write!(f, " + {}x", coeff)?;
|
||||||
|
} else {
|
||||||
|
write!(f, " + {}x^{}", coeff, deg)?;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
write!(f, ")")?;
|
||||||
|
}
|
||||||
|
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
}
|
||||||
|
66
src/output.rs
Normal file
66
src/output.rs
Normal file
@ -0,0 +1,66 @@
|
|||||||
|
use egg::{RecExpr, Id};
|
||||||
|
use crate::language::EquationLanguage;
|
||||||
|
|
||||||
|
// there is already a Display implementation generated by define_langauge!
|
||||||
|
// but we want an alternative string conversion
|
||||||
|
pub fn print_term(expr: &RecExpr<EquationLanguage>) -> String {
|
||||||
|
let root_id = Id::from(expr.as_ref().len()-1);
|
||||||
|
print_term_inner(expr, root_id).0
|
||||||
|
}
|
||||||
|
|
||||||
|
// the second result is the precedence of the top level op: 1 = '+-', 2 = '*/', 3 = '^', 4 = primitive
|
||||||
|
fn print_term_inner(expr: &RecExpr<EquationLanguage>, id: Id) -> (String, usize) {
|
||||||
|
match &expr[id] {
|
||||||
|
EquationLanguage::Num(c) => {
|
||||||
|
(format!("{}", c), if c.denom == 1 { 4 } else { 2 })
|
||||||
|
},
|
||||||
|
EquationLanguage::Neg([a]) => {
|
||||||
|
(print_unary(expr, *a, "-", 1), 1)
|
||||||
|
},
|
||||||
|
EquationLanguage::Add([a,b]) => {
|
||||||
|
(print_binary(expr, *a, *b, "+", 1), 1)
|
||||||
|
},
|
||||||
|
EquationLanguage::Sub([a,b]) => {
|
||||||
|
(print_binary(expr, *a, *b, "-", 1), 1)
|
||||||
|
},
|
||||||
|
EquationLanguage::Mul([a,b]) => {
|
||||||
|
(print_binary(expr, *a, *b, "*", 2), 2)
|
||||||
|
},
|
||||||
|
EquationLanguage::Div([a,b]) => {
|
||||||
|
(print_binary(expr, *a, *b, "/", 2), 2)
|
||||||
|
},
|
||||||
|
EquationLanguage::Power([a,b]) => {
|
||||||
|
(print_binary(expr, *a, *b, "^", 3), 3)
|
||||||
|
},
|
||||||
|
_ => unimplemented!()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
fn print_unary(expr: &RecExpr<EquationLanguage>, a: Id, op: &str, precedence: usize) -> String {
|
||||||
|
let (astr, aprec) = print_term_inner(expr, a);
|
||||||
|
|
||||||
|
if aprec > precedence {
|
||||||
|
format!("{}{}", op, astr)
|
||||||
|
} else {
|
||||||
|
format!("{}({})", op, astr)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
fn print_binary(expr: &RecExpr<EquationLanguage>, a: Id, b: Id, op: &str, precedence: usize) -> String {
|
||||||
|
let (astr, aprec) = print_term_inner(expr, a);
|
||||||
|
let (bstr, bprec) = print_term_inner(expr, b);
|
||||||
|
|
||||||
|
if aprec > precedence {
|
||||||
|
if bprec > precedence {
|
||||||
|
format!("{} {} {}", astr, op, bstr)
|
||||||
|
} else {
|
||||||
|
format!("{} {} ({})", astr, op, bstr)
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
if bprec > precedence {
|
||||||
|
format!("({}) {} {}", astr, op, bstr)
|
||||||
|
} else {
|
||||||
|
format!("({}) {} ({})", astr, op, bstr)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
@ -26,15 +26,15 @@ fn parse_equation_inner(input: &str, expr: &mut RecExpr<EquationLanguage>) -> Re
|
|||||||
}
|
}
|
||||||
|
|
||||||
match c {
|
match c {
|
||||||
'^' if precedence > 3 => {
|
'^' if precedence >= 3 => {
|
||||||
operator_position = Some(i);
|
operator_position = Some(i);
|
||||||
precedence = 3;
|
precedence = 3;
|
||||||
},
|
},
|
||||||
'*' | '/' if precedence > 2 => {
|
'*' | '/' if precedence >= 2 => {
|
||||||
operator_position = Some(i);
|
operator_position = Some(i);
|
||||||
precedence = 2;
|
precedence = 2;
|
||||||
},
|
},
|
||||||
'-' | '+' if precedence > 1 => {
|
'-' | '+' if precedence >= 1 => {
|
||||||
operator_position = Some(i);
|
operator_position = Some(i);
|
||||||
precedence = 1;
|
precedence = 1;
|
||||||
},
|
},
|
||||||
|
Loading…
Reference in New Issue
Block a user