Compare commits
7 Commits
68b6293028
...
main
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
fced3915ea | ||
|
|
829f1e7fa7 | ||
|
|
74ffaea4c0 | ||
|
|
0b8bdf3da6 | ||
|
|
c110dd6889 | ||
|
|
ef2a76869f | ||
|
|
d0265ea340 |
61
README.md
Normal file
61
README.md
Normal file
@@ -0,0 +1,61 @@
|
|||||||
|
# Symbolic solver for polynomials with rational coefficients #
|
||||||
|
|
||||||
|
## Program output ##
|
||||||
|
|
||||||
|
Using `cargo run --release`. On my laptop, the last example times out on debug settings.
|
||||||
|
|
||||||
|
Equation: (x + 50) * 10 - 150 = 100
|
||||||
|
Parsed: (- (- (* (+ x 50) 10) 150) 100)
|
||||||
|
Factorized normal form: 10(25 + 1x)
|
||||||
|
Solutions: { x = -25 }
|
||||||
|
|
||||||
|
Equation: (x - 2) * (x + 2) = 0
|
||||||
|
Parsed: (- (* (- x 2) (+ x 2)) 0)
|
||||||
|
Factorized normal form: (-2 + 1x)(2 + 1x)
|
||||||
|
Solutions: { x = 2, x = -2 }
|
||||||
|
|
||||||
|
Equation: x ^ 2 = 4
|
||||||
|
Parsed: (- (^ x 2) 4)
|
||||||
|
Factorized normal form: (-4 + 0x + 1x^2)
|
||||||
|
Solutions: { x = 2, x = -2 }
|
||||||
|
|
||||||
|
Equation: x ^ 2 - 2 = 0
|
||||||
|
Parsed: (- (- (^ x 2) 2) 0)
|
||||||
|
Factorized normal form: (-2 + 0x + 1x^2)
|
||||||
|
Solutions: { x = 2 ^ (1/2), x = -2 ^ (1/2) }
|
||||||
|
|
||||||
|
Equation: x ^ 2 = 2 * x + 15
|
||||||
|
Parsed: (- (^ x 2) (+ (* 2 x) 15))
|
||||||
|
Factorized normal form: (-15 + -2x + 1x^2)
|
||||||
|
Solutions: { x = 5, x = -3 }
|
||||||
|
|
||||||
|
Equation: (x ^ 2 - 2 * x - 15) * (x + 5) = 0
|
||||||
|
Parsed: (- (* (- (- (^ x 2) (* 2 x)) 15) (+ x 5)) 0)
|
||||||
|
Factorized normal form: (-15 + -2x + 1x^2)(5 + 1x)
|
||||||
|
Solutions: { x = 5, x = -3, x = -5 }
|
||||||
|
|
||||||
|
Equation: x ^ 3 + 3 * x ^ 2 - 25 * x - 75 = 0
|
||||||
|
Parsed: (- (- (- (+ (^ x 3) (* 3 (^ x 2))) (* 25 x)) 75) 0)
|
||||||
|
Factorized normal form: (3 + 1x)(-25 + 0x + 1x^2)
|
||||||
|
Solutions: { x = -3, x = 5, x = -5 }
|
||||||
|
|
||||||
|
Equation: x ^ 3 - 91 * x - 90 = 0
|
||||||
|
Parsed: (- (- (- (^ x 3) (* 91 x)) 90) 0)
|
||||||
|
Factorized normal form: (-90 + -91x + 0x^2 + 1x^3)
|
||||||
|
Guessing rational solutions: 10, -9, -1
|
||||||
|
Verified guessed solutions!
|
||||||
|
Solutions: { x = 10, x = -9, x = -1 }
|
||||||
|
|
||||||
|
## Notes ##
|
||||||
|
|
||||||
|
This uses an egraph with rewrite rules representing the usual rules of arithmetic, and constant folding with rational numbers. This alone is enough to solve linear equations, by just starting with the difference of the two sides and minimizing AST size. However, it can't suffice for quadratic equations, as the egraph can't create values "out of this air". For example, to solve `x^2 = 2` the egraph must at least have "seen" the expression `2^(1/2)` once.
|
||||||
|
|
||||||
|
So instead, the egraph is used to bring the expression into a standard form: a product of monic polynomials which is factored as much as possible. Since this is difficult to model with a cost function, I wrote a custom extractor (in `normal_form.rs`).
|
||||||
|
|
||||||
|
Then we can use the usual quadratic formula to find solutions. The obtained solutions are afterwards simplified by another egraph, with a special "square root of a square integer" rule. This is pretty minimal, a more complete solution would also be able to deal with square roots of rational numbers etc.
|
||||||
|
|
||||||
|
Finally, the unfactorized cubic equation (equation 7) gets factored into a linear and a quadratic factor by the egraph. This is kind of "magic" though: the correct terms seem to come up randomly while mutating the egraph. For general cubics we can't expect this to happen, even if they have integer solutions. The additional equation 8 is an example for this.
|
||||||
|
|
||||||
|
While there is a general formula for the solutions of a cubic equation, it contains trigonometric functions and/or complex numbers, and we don't want to deal with that here. So we use a different strategy which can handle cubic equations with 3 rational solutions: we first solve the equation numerically using the cubic formula, then find rational expressions for the numerical solutions using continued fractions. Finally, we verify the resulting factorization by using an egraph. This equivalence check required increased node and iteration limits, and takes a few seconds on my laptop.
|
||||||
|
|
||||||
|
This has essentially only been tested on the 8 examples above. So I would expect there to still be a few bugs. The algorithm is also spectacularly inefficient, considering that the equality of rational polynomials is almost trivial to compute.
|
||||||
91
src/cubic.rs
Normal file
91
src/cubic.rs
Normal file
@@ -0,0 +1,91 @@
|
|||||||
|
use std::f64::consts::PI;
|
||||||
|
|
||||||
|
use crate::language::Rational;
|
||||||
|
|
||||||
|
pub fn approximate_rational_cubic(b: &Rational, c: &Rational, d: &Rational, limit: u64) -> Vec<Rational> {
|
||||||
|
let numerical_sols = solve_cubic_numerically(
|
||||||
|
1.0,
|
||||||
|
(b.num as f64) / (b.denom as f64),
|
||||||
|
(c.num as f64) / (c.denom as f64),
|
||||||
|
(d.num as f64) / (d.denom as f64));
|
||||||
|
|
||||||
|
numerical_sols.into_iter().map(|x|rational_approx(x, limit)).collect()
|
||||||
|
}
|
||||||
|
|
||||||
|
// assuming leading coefficient is not 0, and also that we don't have a triple root
|
||||||
|
pub fn solve_cubic_numerically(a: f64, b: f64, c: f64, d: f64) -> Vec<f64> {
|
||||||
|
assert_ne!(a, 0.0);
|
||||||
|
|
||||||
|
let b = b/a;
|
||||||
|
let c = c/a;
|
||||||
|
let d = d/a;
|
||||||
|
|
||||||
|
let b2 = b * b;
|
||||||
|
let b3 = b2 * b;
|
||||||
|
|
||||||
|
solve_depressed_cubic_numerically(
|
||||||
|
c - b2 / 3.0,
|
||||||
|
2.0 * b3 / 27.0 - b * c / 3.0 + d
|
||||||
|
).into_iter().map(|u|u - b / 3.0 / a).collect()
|
||||||
|
}
|
||||||
|
|
||||||
|
fn solve_depressed_cubic_numerically(p: f64, q: f64) -> Vec<f64> {
|
||||||
|
let disc = p * p * p / 27.0 + q * q / 4.0;
|
||||||
|
|
||||||
|
if disc < 0.0 {
|
||||||
|
let r = 2.0 * (-p / 3.0).sqrt();
|
||||||
|
let t = 3.0 * q / 2.0 / p * (- 3.0 / p).sqrt();
|
||||||
|
let phi = t.acos() / 3.0;
|
||||||
|
|
||||||
|
vec![r * (phi).cos(),
|
||||||
|
r * (phi + 2.0 * PI / 3.0).cos(),
|
||||||
|
r * (phi + 4.0 * PI / 3.0).cos()]
|
||||||
|
} else {
|
||||||
|
// not implemented at the moment
|
||||||
|
let u = - q / 2.0 + disc.sqrt();
|
||||||
|
let v = - q / 2.0 - disc.sqrt();
|
||||||
|
vec![u.powf(1.0/3.0) + v.powf(1.0/3.0)]
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// pretty inefficient implementation of continued fraction approximation
|
||||||
|
// but doesn't matter here
|
||||||
|
pub fn rational_approx(x: f64, limit: u64) -> Rational {
|
||||||
|
if x < 0.0 {
|
||||||
|
let Rational { num, denom } = rational_approx(-x, limit);
|
||||||
|
return Rational { num: -num, denom }
|
||||||
|
}
|
||||||
|
|
||||||
|
let mut num = 0;
|
||||||
|
let mut denom = 0;
|
||||||
|
for l in 0 .. 10 {
|
||||||
|
let Some((p,q)) = rational_approx_level(x, l) else { break; };
|
||||||
|
|
||||||
|
if q > limit {
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
|
||||||
|
(num, denom) = (p, q);
|
||||||
|
}
|
||||||
|
|
||||||
|
Rational{
|
||||||
|
num: num as i64,
|
||||||
|
denom
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
fn rational_approx_level(x: f64, level: usize) -> Option<(u64, u64)> {
|
||||||
|
// we expect very big numbers or infinity if the previous iteration was exact
|
||||||
|
if x > 1e9 {
|
||||||
|
return None;
|
||||||
|
}
|
||||||
|
|
||||||
|
if level == 0 {
|
||||||
|
Some((x as u64, 1))
|
||||||
|
} else {
|
||||||
|
let (p,q) = rational_approx_level(1.0 / (x - x.floor()), level-1)?;
|
||||||
|
let floorx = x as u64;
|
||||||
|
|
||||||
|
Some((q + floorx * p, p))
|
||||||
|
}
|
||||||
|
}
|
||||||
188
src/language.rs
188
src/language.rs
@@ -1,5 +1,5 @@
|
|||||||
use std::{cmp::Ordering, fmt::{self,Display, Formatter}, str::FromStr, sync::LazyLock};
|
use std::{cmp::Ordering, fmt::{self,Display, Formatter}, str::FromStr, sync::LazyLock};
|
||||||
use egg::{define_language, merge_option, rewrite as rw, Analysis, DidMerge, Id, Language, Subst, Var};
|
use egg::{define_language, merge_option, rewrite as rw, Analysis, Applier, DidMerge, Id, Language, PatternAst, Subst, Symbol, Var};
|
||||||
|
|
||||||
pub type EGraph = egg::EGraph<EquationLanguage, ConstantFold>;
|
pub type EGraph = egg::EGraph<EquationLanguage, ConstantFold>;
|
||||||
pub type Rewrite = egg::Rewrite<EquationLanguage, ConstantFold>;
|
pub type Rewrite = egg::Rewrite<EquationLanguage, ConstantFold>;
|
||||||
@@ -39,9 +39,19 @@ impl Display for Rational {
|
|||||||
}
|
}
|
||||||
|
|
||||||
impl FromStr for Rational {
|
impl FromStr for Rational {
|
||||||
type Err = std::num::ParseIntError;
|
type Err = String;
|
||||||
fn from_str(s: &str) -> Result<Self, Self::Err> {
|
fn from_str(s: &str) -> Result<Self, String> {
|
||||||
Ok(Rational { num: s.parse::<i64>()?, denom: 1 })
|
let err = || Err(format!("Couldn't parse rational: {}", s));
|
||||||
|
|
||||||
|
if let Ok(num) = s.parse::<i64>() {
|
||||||
|
Ok(Rational { num, denom: 1 })
|
||||||
|
} else if let Some((snum, sdenom)) = s.split_once('/') {
|
||||||
|
let Ok(num) = snum.parse::<i64>() else { return err(); };
|
||||||
|
let Ok(denom) = sdenom.parse::<u64>() else { return err(); };
|
||||||
|
Ok(Rational { num, denom })
|
||||||
|
} else {
|
||||||
|
err()
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -167,6 +177,46 @@ fn is_nonzero_const(var: &str) -> impl Fn(&mut EGraph, Id, &Subst) -> bool {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
fn is_const(var: &str) -> impl Fn(&mut EGraph, Id, &Subst) -> bool {
|
||||||
|
let var: Var = var.parse().unwrap();
|
||||||
|
move |egraph, _, subst| {
|
||||||
|
egraph[subst[var]].data.is_some()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
struct IntegerSqrt {
|
||||||
|
var: Var,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl Applier<EquationLanguage, ConstantFold> for IntegerSqrt {
|
||||||
|
fn apply_one(&self,
|
||||||
|
egraph: &mut EGraph,
|
||||||
|
matched_id: Id,
|
||||||
|
subst: &Subst,
|
||||||
|
_searcher_pattern: Option<&PatternAst<EquationLanguage>>,
|
||||||
|
_rule_name: Symbol)
|
||||||
|
-> Vec<Id> {
|
||||||
|
let var_id = subst[self.var];
|
||||||
|
|
||||||
|
if let Some(value) = &egraph[var_id].data {
|
||||||
|
if value.denom == 1 && value.num >= 0 {
|
||||||
|
// isqrt is nightly only, so we just do this, adding 0.1 against rounding errors
|
||||||
|
let sq = (f64::sqrt(value.num as f64) + 0.1) as i64;
|
||||||
|
if value.num == sq*sq {
|
||||||
|
// println!("square root of integer {} is {}", value.num, sq);
|
||||||
|
|
||||||
|
let sq_id = egraph.add(EquationLanguage::Num(Rational { num: sq, denom: 1 }));
|
||||||
|
egraph.union(matched_id, sq_id);
|
||||||
|
|
||||||
|
return vec![matched_id, sq_id];
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
vec![]
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
pub static RULES: LazyLock<Vec<Rewrite>> = LazyLock::new(||vec![
|
pub static RULES: LazyLock<Vec<Rewrite>> = LazyLock::new(||vec![
|
||||||
rw!("commute-add"; "(+ ?x ?y)" => "(+ ?y ?x)"),
|
rw!("commute-add"; "(+ ?x ?y)" => "(+ ?y ?x)"),
|
||||||
rw!("commute-mul"; "(* ?x ?y)" => "(* ?y ?x)"),
|
rw!("commute-mul"; "(* ?x ?y)" => "(* ?y ?x)"),
|
||||||
@@ -178,6 +228,8 @@ pub static RULES: LazyLock<Vec<Rewrite>> = LazyLock::new(||vec![
|
|||||||
rw!("mul-0"; "(* ?x 0)" => "0"),
|
rw!("mul-0"; "(* ?x 0)" => "0"),
|
||||||
rw!("mul-1"; "(* ?x 1)" => "?x"),
|
rw!("mul-1"; "(* ?x 1)" => "?x"),
|
||||||
|
|
||||||
|
rw!("0-sub"; "(- 0 ?x)" => "(- ?x)"),
|
||||||
|
|
||||||
rw!("add-sub"; "(+ ?x (* (-1) ?x))" => "0"),
|
rw!("add-sub"; "(+ ?x (* (-1) ?x))" => "0"),
|
||||||
// division by zero shouldn't happen unless input is invalid
|
// division by zero shouldn't happen unless input is invalid
|
||||||
rw!("mul-div"; "(* ?x (rec ?x))" => "1" if is_nonzero_const("?y")),
|
rw!("mul-div"; "(* ?x (rec ?x))" => "1" if is_nonzero_const("?y")),
|
||||||
@@ -187,10 +239,6 @@ pub static RULES: LazyLock<Vec<Rewrite>> = LazyLock::new(||vec![
|
|||||||
|
|
||||||
rw!("square"; "(^ ?x 2)" => "(* ?x ?x)"),
|
rw!("square"; "(^ ?x 2)" => "(* ?x ?x)"),
|
||||||
rw!("cube"; "(^ ?x 3)" => "(* ?x (* ?x ?x))"),
|
rw!("cube"; "(^ ?x 3)" => "(* ?x (* ?x ?x))"),
|
||||||
/*
|
|
||||||
rw!("inv-square"; "(* ?x ?x)" => "(^ ?x 2)"),
|
|
||||||
rw!("inv-cube"; "(* ?x (* ?x ?x))" => "(^ ?x 3)"),
|
|
||||||
*/
|
|
||||||
|
|
||||||
rw!("sub"; "(- ?x ?y)" => "(+ ?x (* -1 ?y))"),
|
rw!("sub"; "(- ?x ?y)" => "(+ ?x (* -1 ?y))"),
|
||||||
rw!("neg"; "(- ?x)" => "(* -1 ?x)"),
|
rw!("neg"; "(- ?x)" => "(* -1 ?x)"),
|
||||||
@@ -198,6 +246,8 @@ pub static RULES: LazyLock<Vec<Rewrite>> = LazyLock::new(||vec![
|
|||||||
rw!("div"; "(/ ?x ?y)" => "(* ?x (rec ?y))" if is_nonzero_const("?y")),
|
rw!("div"; "(/ ?x ?y)" => "(* ?x (rec ?y))" if is_nonzero_const("?y")),
|
||||||
|
|
||||||
rw!("factor_poly"; "(+ (* x ?x) ?y)" => "(* ?x (+ x (* ?y (rec ?x))))" if is_nonzero_const("?x")),
|
rw!("factor_poly"; "(+ (* x ?x) ?y)" => "(* ?x (+ x (* ?y (rec ?x))))" if is_nonzero_const("?x")),
|
||||||
|
|
||||||
|
rw!("integer_sqrt"; "(^ ?x (/ 1 2))" => { IntegerSqrt { var: "?x".parse().unwrap() } } if is_const("?x")),
|
||||||
]);
|
]);
|
||||||
|
|
||||||
pub struct PlusTimesCostFn;
|
pub struct PlusTimesCostFn;
|
||||||
@@ -218,125 +268,3 @@ impl egg::CostFunction<EquationLanguage> for PlusTimesCostFn {
|
|||||||
enode.fold(op_cost, |sum, i| sum + costs(i))
|
enode.fold(op_cost, |sum, i| sum + costs(i))
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
#[derive(Debug,Clone,Copy)]
|
|
||||||
pub struct PolyStat {
|
|
||||||
degree: usize,
|
|
||||||
factors: usize, // non-constant factors
|
|
||||||
ops: usize,
|
|
||||||
monomial: bool,
|
|
||||||
sum_of_monomials: bool,
|
|
||||||
monic: bool,
|
|
||||||
factorized: bool, // a product of monic polynomials and at least one constant
|
|
||||||
}
|
|
||||||
|
|
||||||
#[derive(Debug,Clone,Copy)]
|
|
||||||
pub enum FactorizationCost {
|
|
||||||
UnwantedOps,
|
|
||||||
Polynomial(PolyStat)
|
|
||||||
}
|
|
||||||
|
|
||||||
fn score(cost: FactorizationCost) -> usize {
|
|
||||||
match cost {
|
|
||||||
FactorizationCost::UnwantedOps => 10000,
|
|
||||||
FactorizationCost::Polynomial(p) =>
|
|
||||||
if !p.factorized {
|
|
||||||
1000
|
|
||||||
} else {
|
|
||||||
100 * (9 - p.factors) + p.ops
|
|
||||||
},
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
impl PartialEq for FactorizationCost {
|
|
||||||
fn eq(&self, other: &Self) -> bool {
|
|
||||||
score(*self) == score(*other)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
impl PartialOrd for FactorizationCost {
|
|
||||||
fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
|
|
||||||
usize::partial_cmp(&score(*self), &score(*other))
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
pub struct FactorizationCostFn;
|
|
||||||
impl egg::CostFunction<EquationLanguage> for FactorizationCostFn {
|
|
||||||
type Cost = FactorizationCost;
|
|
||||||
|
|
||||||
fn cost<C>(&mut self, enode: &EquationLanguage, mut costs: C) -> Self::Cost
|
|
||||||
where
|
|
||||||
C: FnMut(Id) -> Self::Cost,
|
|
||||||
{
|
|
||||||
match enode {
|
|
||||||
EquationLanguage::Add([a,b]) => {
|
|
||||||
match (costs(*a), costs(*b)) {
|
|
||||||
(FactorizationCost::Polynomial(p1),FactorizationCost::Polynomial(p2)) => {
|
|
||||||
// we only ever want to add monomials
|
|
||||||
let result_monic = if p1.degree > p2.degree {
|
|
||||||
p1.monic
|
|
||||||
} else if p2.degree > p1.degree {
|
|
||||||
p2.monic
|
|
||||||
} else {
|
|
||||||
false
|
|
||||||
};
|
|
||||||
|
|
||||||
if !p1.sum_of_monomials || !p2.sum_of_monomials {
|
|
||||||
FactorizationCost::UnwantedOps
|
|
||||||
} else {
|
|
||||||
FactorizationCost::Polynomial(PolyStat {
|
|
||||||
degree: usize::max(p1.degree, p2.degree),
|
|
||||||
factors: 1,
|
|
||||||
ops: p1.ops + p2.ops,
|
|
||||||
monomial: false,
|
|
||||||
sum_of_monomials: p1.sum_of_monomials && p2.sum_of_monomials,
|
|
||||||
monic: result_monic,
|
|
||||||
factorized: result_monic,
|
|
||||||
})
|
|
||||||
}
|
|
||||||
},
|
|
||||||
_ => FactorizationCost::UnwantedOps
|
|
||||||
}
|
|
||||||
},
|
|
||||||
EquationLanguage::Mul([a,b]) => {
|
|
||||||
match (costs(*a), costs(*b)) {
|
|
||||||
(FactorizationCost::Polynomial(p1), FactorizationCost::Polynomial(p2)) => {
|
|
||||||
FactorizationCost::Polynomial(PolyStat {
|
|
||||||
degree: p1.degree + p2.degree,
|
|
||||||
factors: p1.factors + p2.factors,
|
|
||||||
ops: p1.ops + p2.ops,
|
|
||||||
monomial: p1.monomial && p2.monomial,
|
|
||||||
sum_of_monomials: p1.monomial && p2.monomial,
|
|
||||||
monic: p1.monic && p2.monic,
|
|
||||||
factorized: (p1.monic && p2.factorized) || (p2.monic && p1.factorized)
|
|
||||||
})
|
|
||||||
},
|
|
||||||
_ => FactorizationCost::UnwantedOps
|
|
||||||
}
|
|
||||||
},
|
|
||||||
EquationLanguage::Num(c) => {
|
|
||||||
FactorizationCost::Polynomial(PolyStat {
|
|
||||||
degree: 0,
|
|
||||||
factors: 0,
|
|
||||||
ops: 0,
|
|
||||||
monomial: true,
|
|
||||||
sum_of_monomials: true,
|
|
||||||
monic: false,
|
|
||||||
factorized: true
|
|
||||||
})
|
|
||||||
},
|
|
||||||
EquationLanguage::Unknown => {
|
|
||||||
FactorizationCost::Polynomial(PolyStat {
|
|
||||||
degree: 1,
|
|
||||||
factors: 1,
|
|
||||||
ops: 0,
|
|
||||||
monomial: true,
|
|
||||||
sum_of_monomials: true,
|
|
||||||
monic: true,
|
|
||||||
factorized: true
|
|
||||||
})
|
|
||||||
},
|
|
||||||
_ => FactorizationCost::UnwantedOps,
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|||||||
@@ -1,3 +1,5 @@
|
|||||||
pub mod language;
|
pub mod language;
|
||||||
pub mod normal_form;
|
pub mod normal_form;
|
||||||
pub mod parse;
|
pub mod parse;
|
||||||
|
pub mod output;
|
||||||
|
pub mod cubic;
|
||||||
|
|||||||
322
src/main.rs
322
src/main.rs
@@ -1,226 +1,148 @@
|
|||||||
use egg::{Extractor, Pattern, RecExpr, Runner};
|
use egg::{AstSize, Extractor, Id, RecExpr, Runner};
|
||||||
use solveq::language::{RULES, EquationLanguage, PlusTimesCostFn, FactorizationCostFn};
|
use solveq::cubic::approximate_rational_cubic;
|
||||||
use solveq::normal_form::analyze3;
|
use solveq::language::{EquationLanguage, Rational, RULES, EGraph};
|
||||||
|
use solveq::normal_form::extract_normal_form;
|
||||||
use solveq::parse::parse_equation;
|
use solveq::parse::parse_equation;
|
||||||
|
use solveq::output::print_term;
|
||||||
|
|
||||||
static TEST_EQUATIONS: &[&str] = &[
|
static TEST_EQUATIONS: &[&str] = &[
|
||||||
"(x + 50) * 10 - 150 - 100",
|
"(x + 50) * 10 - 150 = 100",
|
||||||
"(x - 2) * (x + 2) - 0",
|
"(x - 2) * (x + 2) = 0",
|
||||||
"x ^ 2 - 4",
|
"x ^ 2 = 4",
|
||||||
"x ^ 2 - 2 - 0",
|
"x ^ 2 - 2 = 0",
|
||||||
"x ^ 2 - (2 * x + 15)",
|
"x ^ 2 = 2 * x + 15",
|
||||||
"(x ^ 2 - 2 * x - 15) * (x + 5) - 0",
|
"(x ^ 2 - 2 * x - 15) * (x + 5) = 0",
|
||||||
"x ^ 3 + 3 * x ^ 2 - 25 * x - 75 - 0",
|
"x ^ 3 + 3 * x ^ 2 - 25 * x - 75 = 0",
|
||||||
|
"x ^ 3 - 91 * x - 90 = 0",
|
||||||
];
|
];
|
||||||
|
|
||||||
fn main() {
|
fn main() {
|
||||||
for eq in TEST_EQUATIONS {
|
for eq in TEST_EQUATIONS {
|
||||||
let start = parse_equation(*eq).unwrap();
|
println!("Equation: {}", *eq);
|
||||||
|
|
||||||
// println!("{:?}", &start);
|
let solutions = solve(*eq, true);
|
||||||
// do transformation to left - right = 0
|
|
||||||
|
|
||||||
let mut runner = Runner::default()
|
let solutions_str: Vec<String> = solutions
|
||||||
|
.iter()
|
||||||
|
.map(|expr| format!("x = {}", print_term(expr)))
|
||||||
|
.collect();
|
||||||
|
println!("Solutions: {{ {} }}", solutions_str.join(", "));
|
||||||
|
|
||||||
|
println!("");
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn solve(eq: &str, verbose: bool) -> Vec<RecExpr<EquationLanguage>> {
|
||||||
|
let mut start = parse_equation(eq).unwrap();
|
||||||
|
let root_id = Id::from(start.as_ref().len()-1);
|
||||||
|
let EquationLanguage::Equals([left, right]) = start[root_id]
|
||||||
|
else { panic!("Not an equation without an equals sign!"); };
|
||||||
|
start[root_id] = EquationLanguage::Sub([left, right]);
|
||||||
|
|
||||||
|
if verbose {
|
||||||
|
println!("Parsed: {}", &start);
|
||||||
|
}
|
||||||
|
|
||||||
|
let runner = Runner::default()
|
||||||
.with_explanations_enabled()
|
.with_explanations_enabled()
|
||||||
.with_expr(&start)
|
.with_expr(&start)
|
||||||
.run(&*RULES);
|
.run(&*RULES);
|
||||||
|
|
||||||
let extractor = Extractor::new(&runner.egraph, FactorizationCostFn);
|
let Some(factorization) = extract_normal_form(&runner.egraph, runner.roots[0]) else {
|
||||||
|
panic!("Couldn't factorize polynomial!");
|
||||||
|
};
|
||||||
|
|
||||||
let (best_cost, best_expr) = extractor.find_best(runner.roots[0]);
|
if verbose {
|
||||||
|
println!("Factorized normal form: {}", factorization);
|
||||||
println!("{}", start);
|
|
||||||
println!("{:?} {:?}", best_cost, <RecExpr<EquationLanguage> as AsRef<[EquationLanguage]>>::as_ref(&best_expr));
|
|
||||||
|
|
||||||
|
|
||||||
println!("");
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// let root = runner.roots[0];
|
let mut solutions: Vec<RecExpr<EquationLanguage>> = Vec::new();
|
||||||
// let egraph = &runner.egraph;
|
for poly in &factorization.polynomials {
|
||||||
// let pattern: Pattern<EquationLanguage> = "(+ (* ?a (* x x)) ?c)".parse().unwrap();
|
if poly.len() == 2 { // linear factor
|
||||||
// let matches = pattern.search(&egraph);
|
let Rational { num, denom } = &poly[0];
|
||||||
|
let mut solexpr = RecExpr::default();
|
||||||
|
solexpr.add(EquationLanguage::Num(Rational {
|
||||||
|
num: -*num,
|
||||||
|
denom: *denom
|
||||||
|
}));
|
||||||
|
solutions.push(solexpr);
|
||||||
|
} else if poly.len() == 3 { // quadratic factor
|
||||||
|
let Rational { num: num0, denom: denom0 } = &poly[0];
|
||||||
|
let Rational { num: num1, denom: denom1 } = &poly[1];
|
||||||
|
|
||||||
|
let sol1 = format!("- ({num1})/(2 * ({denom1})) + ((({num1})/(2 * ({denom1}))) ^ 2 - ({num0}) / ({denom0})) ^ (1/2)");
|
||||||
|
let sol2 = format!("- ({num1})/(2 * ({denom1})) - ((({num1})/(2 * ({denom1}))) ^ 2 - ({num0}) / ({denom0})) ^ (1/2)");
|
||||||
|
|
||||||
|
let expr = parse_equation(&sol1).unwrap();
|
||||||
|
let runner = Runner::default()
|
||||||
|
.with_expr(&expr)
|
||||||
|
.run(&*RULES);
|
||||||
|
let extractor = Extractor::new(&runner.egraph, AstSize);
|
||||||
|
let (_, simplified_expr) = extractor.find_best(runner.roots[0]);
|
||||||
|
solutions.push(simplified_expr);
|
||||||
|
|
||||||
// println!("{:?}", egraph.classes().count());
|
let expr = parse_equation(&sol2).unwrap();
|
||||||
|
let runner = Runner::default()
|
||||||
|
.with_expr(&expr)
|
||||||
|
.run(&*RULES);
|
||||||
|
let extractor = Extractor::new(&runner.egraph, AstSize);
|
||||||
|
let (_, simplified_expr) = extractor.find_best(runner.roots[0]);
|
||||||
|
solutions.push(simplified_expr);
|
||||||
|
} else if poly.len() == 4 {
|
||||||
|
let guesses = approximate_rational_cubic(
|
||||||
|
&Rational { num: 0, denom: 1 },
|
||||||
|
&Rational { num: -91, denom: 1 },
|
||||||
|
&Rational { num: -90, denom: 1 },
|
||||||
|
1000
|
||||||
|
);
|
||||||
|
|
||||||
// Analyze
|
if guesses.len() < 3 {
|
||||||
// analyze3(egraph, runner.roots[0]);
|
|
||||||
|
|
||||||
/*
|
|
||||||
for class in egraph.classes() {
|
|
||||||
if monic_nonconst_polynomial(egraph, class.id).is_some() {
|
|
||||||
let (_, best_expr) = extractor.find_best(class.id);
|
|
||||||
println!("Monomial: {}", best_expr);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
println!("{:?}", &matches);
|
|
||||||
*/
|
|
||||||
|
|
||||||
|
|
||||||
// println!("{}", runner.explain_equivalence(&start, &best_expr).get_flat_string());
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
/*
|
|
||||||
fn power_of_x(egraph: &EGraph, eclass: Id) -> Option<usize> {
|
|
||||||
for n in &egraph[eclass].nodes {
|
|
||||||
match *n {
|
|
||||||
EquationLanguage::Unknown => { return Some(1) },
|
|
||||||
EquationLanguage::Mul([a,b]) => {
|
|
||||||
let Some(left) = power_of_x(egraph, a) else { continue };
|
|
||||||
let Some(right) = power_of_x(egraph, b) else { continue };
|
|
||||||
return Some(left + right);
|
|
||||||
},
|
|
||||||
_ => {}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
None
|
|
||||||
}
|
|
||||||
|
|
||||||
fn monomial(egraph: &EGraph, eclass: Id) -> Option<(usize, Rational)> {
|
|
||||||
if let Some(deg) = power_of_x(egraph, eclass) {
|
|
||||||
return Some((deg, RATIONAL_ONE.clone()));
|
|
||||||
}
|
|
||||||
|
|
||||||
for n in &egraph[eclass].nodes {
|
|
||||||
match *n {
|
|
||||||
EquationLanguage::Mul([a,b]) => {
|
|
||||||
let Some(coeff) = egraph[a].data.clone() else { continue };
|
|
||||||
let Some(deg) = power_of_x(egraph, b) else { continue };
|
|
||||||
return Some((deg, coeff));
|
|
||||||
},
|
|
||||||
_ => {}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
None
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
// this is either a power_of_x, or a sum of this and a monomial
|
|
||||||
fn monic_nonconst_polynomial(egraph: &EGraph, eclass: Id) -> Option<Vec<Rational>> {
|
|
||||||
let mut result: Vec<Rational> = Vec::new();
|
|
||||||
|
|
||||||
if let Some(deg) = power_of_x(egraph, eclass) {
|
|
||||||
result.resize(deg - 1, RATIONAL_ZERO);
|
|
||||||
result.push(RATIONAL_ONE.clone());
|
|
||||||
return Some(result);
|
|
||||||
}
|
|
||||||
|
|
||||||
for n in &egraph[eclass].nodes {
|
|
||||||
match *n {
|
|
||||||
EquationLanguage::Add([a,b]) => {
|
|
||||||
let Some(mut leading) = monic_nonconst_polynomial(egraph, a)
|
|
||||||
else { continue };
|
|
||||||
let Some(addon) = monomial(egraph, b)
|
|
||||||
else { continue };
|
|
||||||
|
|
||||||
if leading.len() <= addon.0 || leading[addon.0] != RATIONAL_ZERO {
|
|
||||||
continue;
|
continue;
|
||||||
}
|
}
|
||||||
|
|
||||||
leading[addon.0] = addon.1.clone();
|
if verbose {
|
||||||
return Some(leading);
|
println!("Guessing rational solutions: {}, {}, {}",
|
||||||
},
|
guesses[0], guesses[1], guesses[2]);
|
||||||
_ => {},
|
}
|
||||||
|
|
||||||
|
let start = format!("x ^ 3 + ({}) * x ^ 2 + ({}) * x + ({})",
|
||||||
|
poly[2], poly[1], poly[0]);
|
||||||
|
let goal = format!("(x - ({})) * (x - ({})) * (x - ({}))",
|
||||||
|
guesses[0], guesses[1], guesses[2]);
|
||||||
|
|
||||||
|
let start_expr = parse_equation(&start).unwrap();
|
||||||
|
let goal_expr = parse_equation(&goal).unwrap();
|
||||||
|
|
||||||
|
let mut egraph: EGraph = Default::default();
|
||||||
|
egraph.add_expr(&start_expr);
|
||||||
|
egraph.add_expr(&goal_expr);
|
||||||
|
|
||||||
|
let runner = Runner::default()
|
||||||
|
.with_egraph(egraph)
|
||||||
|
.with_node_limit(100_000)
|
||||||
|
.with_iter_limit(100)
|
||||||
|
.run(&*RULES);
|
||||||
|
|
||||||
|
let equivs = runner.egraph.equivs(&start_expr, &goal_expr);
|
||||||
|
|
||||||
|
if !equivs.is_empty() {
|
||||||
|
if verbose {
|
||||||
|
println!("Verified guessed solutions!");
|
||||||
|
}
|
||||||
|
|
||||||
|
// expressions are equivalent, so the guesses are actually solutions
|
||||||
|
for s in &guesses {
|
||||||
|
let mut solexpr = RecExpr::default();
|
||||||
|
solexpr.add(EquationLanguage::Num(s.clone()));
|
||||||
|
solutions.push(solexpr);
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
if verbose {
|
||||||
|
println!("Couldn't verify guessed solutions.");
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
None
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
solutions
|
||||||
}
|
}
|
||||||
|
|
||||||
*/
|
|
||||||
|
|
||||||
|
|
||||||
/*
|
|
||||||
fn analyze(egraph: &EGraph, _id: Id) {
|
|
||||||
let mut types: HashMap<Id, SpecialTerm> = HashMap::new();
|
|
||||||
let mut todo: VecDeque<Id> = VecDeque::new();
|
|
||||||
// todo.push_back(runner.roots[0]);
|
|
||||||
for cls in egraph.classes() {
|
|
||||||
todo.push_back(cls.id);
|
|
||||||
}
|
|
||||||
|
|
||||||
'todo: while todo.len() > 0 {
|
|
||||||
let id = todo.pop_front().unwrap();
|
|
||||||
if types.contains_key(&id) {
|
|
||||||
continue 'todo;
|
|
||||||
}
|
|
||||||
|
|
||||||
if let Some(c) = &egraph[id].data {
|
|
||||||
types.insert(id, SpecialTerm::Constant(c.clone()));
|
|
||||||
continue 'todo;
|
|
||||||
}
|
|
||||||
|
|
||||||
'nodes: for n in &egraph[id].nodes {
|
|
||||||
match *n {
|
|
||||||
EquationLanguage::Unknown => {
|
|
||||||
types.insert(id, SpecialTerm::PowerOfX(1));
|
|
||||||
continue 'todo;
|
|
||||||
},
|
|
||||||
EquationLanguage::Mul([a,b]) => {
|
|
||||||
if !types.contains_key(&a) {
|
|
||||||
todo.push_back(a);
|
|
||||||
todo.push_back(id);
|
|
||||||
continue 'nodes;
|
|
||||||
}
|
|
||||||
|
|
||||||
if !types.contains_key(&b) {
|
|
||||||
todo.push_back(b);
|
|
||||||
todo.push_back(id);
|
|
||||||
continue 'nodes;
|
|
||||||
}
|
|
||||||
|
|
||||||
match (&types[&a], &types[&b]) {
|
|
||||||
(SpecialTerm::PowerOfX(dega), SpecialTerm::PowerOfX(degb)) => {
|
|
||||||
types.insert(id, SpecialTerm::PowerOfX(*dega + *degb));
|
|
||||||
},
|
|
||||||
(SpecialTerm::Constant(coeff), SpecialTerm::PowerOfX(deg)) => {
|
|
||||||
types.insert(id, SpecialTerm::Monomial(*deg, coeff.clone()));
|
|
||||||
},
|
|
||||||
_ => { continue 'nodes; },
|
|
||||||
}
|
|
||||||
continue 'todo;
|
|
||||||
},
|
|
||||||
EquationLanguage::Add([a,b]) => {
|
|
||||||
if !types.contains_key(&a) {
|
|
||||||
todo.push_front(a);
|
|
||||||
todo.push_back(id);
|
|
||||||
continue 'todo;
|
|
||||||
}
|
|
||||||
|
|
||||||
if !types.contains_key(&b) {
|
|
||||||
todo.push_front(b);
|
|
||||||
todo.push_back(id);
|
|
||||||
continue 'todo;
|
|
||||||
}
|
|
||||||
|
|
||||||
match (&types[&a], &types[&b]) {
|
|
||||||
(SpecialTerm::MonicNonconstPoly(poly), SpecialTerm::Monomial(deg, coeff)) => {
|
|
||||||
if poly.len() <= *deg || poly[*deg] != RATIONAL_ZERO {
|
|
||||||
continue 'nodes;
|
|
||||||
}
|
|
||||||
|
|
||||||
let mut poly = poly.clone();
|
|
||||||
poly[*deg] = coeff.clone();
|
|
||||||
types.insert(id, SpecialTerm::MonicNonconstPoly(poly));
|
|
||||||
},
|
|
||||||
_ => { continue 'nodes; },
|
|
||||||
}
|
|
||||||
continue 'todo;
|
|
||||||
},
|
|
||||||
_ => {},
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
types.insert(id, SpecialTerm::Other);
|
|
||||||
}
|
|
||||||
|
|
||||||
for (id, ty) in &types {
|
|
||||||
if !matches!(ty, &SpecialTerm::Other) {
|
|
||||||
println!("{:?}", &ty);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
*/
|
|
||||||
|
|||||||
@@ -1,18 +1,269 @@
|
|||||||
use crate::language::{EGraph, EquationLanguage, Rational, RATIONAL_ONE, RATIONAL_ZERO};
|
use crate::language::{EGraph, EquationLanguage, Rational, RATIONAL_ONE, RATIONAL_ZERO};
|
||||||
use std::collections::HashMap;
|
use std::{collections::HashMap, fmt};
|
||||||
use egg::Id;
|
use egg::Id;
|
||||||
|
|
||||||
#[derive(Debug,Clone)]
|
#[derive(Debug,Clone)]
|
||||||
pub enum SpecialTerm {
|
pub struct Factorization {
|
||||||
Constant(Rational),
|
pub constant_factor: Rational,
|
||||||
PowerOfX(usize),
|
pub polynomials: Vec<Vec<Rational>>,
|
||||||
Monomial(usize, Rational),
|
|
||||||
MonicNonconstPoly(Vec<Rational>),
|
|
||||||
Factorization(Rational, Vec<Vec<Rational>>),
|
|
||||||
Other,
|
|
||||||
}
|
}
|
||||||
|
|
||||||
fn search_for<F, T>(egraph: &EGraph, f: F) -> HashMap<Id, T>
|
impl fmt::Display for Factorization {
|
||||||
|
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
|
||||||
|
if self.constant_factor != RATIONAL_ONE {
|
||||||
|
write!(f, "{}", self.constant_factor)?;
|
||||||
|
}
|
||||||
|
|
||||||
|
for poly in &self.polynomials {
|
||||||
|
write!(f, "(")?;
|
||||||
|
for (deg, coeff) in poly.iter().enumerate() {
|
||||||
|
if deg == 0 {
|
||||||
|
write!(f, "{}", coeff)?;
|
||||||
|
} else if deg == 1 {
|
||||||
|
write!(f, " + {}x", coeff)?;
|
||||||
|
} else {
|
||||||
|
write!(f, " + {}x^{}", coeff, deg)?;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
write!(f, ")")?;
|
||||||
|
}
|
||||||
|
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// warning: this is a property of an eclass, not a particular expression
|
||||||
|
// so we shouldn't use anything that's not a polynomial invariant
|
||||||
|
#[derive(Debug,Clone)]
|
||||||
|
struct PolyStats {
|
||||||
|
degree: usize,
|
||||||
|
monomial: bool,
|
||||||
|
monic: bool,
|
||||||
|
}
|
||||||
|
|
||||||
|
fn gather_poly_stats(egraph: &EGraph) -> HashMap<Id, PolyStats> {
|
||||||
|
walk_egraph(egraph, |_id, node, stats: &HashMap<Id, PolyStats>| {
|
||||||
|
let x = |i: &Id| stats.get(&egraph.find(*i));
|
||||||
|
Some(match node {
|
||||||
|
EquationLanguage::Unknown => PolyStats {
|
||||||
|
degree: 1,
|
||||||
|
monomial: true,
|
||||||
|
monic: true,
|
||||||
|
},
|
||||||
|
EquationLanguage::Num(c) => PolyStats {
|
||||||
|
degree: 0,
|
||||||
|
monomial: true,
|
||||||
|
monic: c == &RATIONAL_ONE,
|
||||||
|
},
|
||||||
|
EquationLanguage::Mul([a,b]) => {
|
||||||
|
// if both aren't monic we can't tell, the leading coefficients could cancel
|
||||||
|
// but there should be an alternative representative with one of them monic
|
||||||
|
if !x(a)?.monic && !x(b)?.monic {
|
||||||
|
return None;
|
||||||
|
}
|
||||||
|
|
||||||
|
PolyStats {
|
||||||
|
degree: x(a)?.degree + x(b)?.degree,
|
||||||
|
monomial: x(a)?.monomial && x(b)?.monomial,
|
||||||
|
monic: x(a)?.monic && x(b)?.monic,
|
||||||
|
}
|
||||||
|
},
|
||||||
|
EquationLanguage::Add([a,b]) => {
|
||||||
|
// in this case, there should also be a simplified representative which
|
||||||
|
// has only a single leading term
|
||||||
|
if x(a)?.degree == x(b)?.degree {
|
||||||
|
return None;
|
||||||
|
}
|
||||||
|
|
||||||
|
PolyStats {
|
||||||
|
degree: usize::max(x(a)?.degree, x(b)?.degree),
|
||||||
|
monomial: false,
|
||||||
|
monic: if x(a)?.degree > x(b)?.degree { x(a)?.monic } else { x(b)?.monic },
|
||||||
|
}
|
||||||
|
},
|
||||||
|
_ => { return None; }
|
||||||
|
})
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn extract_normal_form(egraph: &EGraph, eclass: Id) -> Option<Factorization> {
|
||||||
|
let eclass = egraph.find(eclass); // get the canonical eclass
|
||||||
|
|
||||||
|
let stats = gather_poly_stats(egraph);
|
||||||
|
|
||||||
|
let Some(factorization) = find_general_factorization(egraph, &stats, eclass) else { return None; };
|
||||||
|
|
||||||
|
let mut result = Vec::new();
|
||||||
|
let mut coeff: Option<Rational> = None;
|
||||||
|
|
||||||
|
for factor in factorization {
|
||||||
|
let extracted = extract_polynomial(egraph, &stats, factor)?;
|
||||||
|
// println!("Extracted: {:?}", extracted);
|
||||||
|
|
||||||
|
if extracted.len() == 1 {
|
||||||
|
coeff = Some(extracted[0].clone());
|
||||||
|
} else {
|
||||||
|
result.push(extracted);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
Some(Factorization {
|
||||||
|
constant_factor: coeff.unwrap_or_else(||RATIONAL_ONE.clone()),
|
||||||
|
polynomials: result
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
// a polynomial should be either of:
|
||||||
|
// - a monomial: then we know the degree and we try to parse it as a product of a constant and a monic monomial
|
||||||
|
// or just a constant
|
||||||
|
// - a sum of a monomial of highest degree, and a polynomial of lower degree, recursively walk these
|
||||||
|
fn extract_polynomial(egraph: &EGraph, stats: &HashMap<Id, PolyStats>, id: Id) -> Option<Vec<Rational>> {
|
||||||
|
let st = &stats[&id];
|
||||||
|
|
||||||
|
if st.monomial {
|
||||||
|
let (deg, coeff) = extract_monomial(egraph, stats, id)?;
|
||||||
|
|
||||||
|
let mut result = vec![RATIONAL_ZERO; deg];
|
||||||
|
result.push(coeff);
|
||||||
|
return Some(result);
|
||||||
|
} else {
|
||||||
|
for node in &egraph[id].nodes {
|
||||||
|
match node {
|
||||||
|
EquationLanguage::Add([a,b]) => {
|
||||||
|
let a = egraph.find(*a);
|
||||||
|
let b = egraph.find(*b);
|
||||||
|
let Some(stata) = &stats.get(&a) else { continue };
|
||||||
|
let Some(statb) = &stats.get(&b) else { continue };
|
||||||
|
|
||||||
|
if stata.degree == st.degree && stata.monomial && statb.degree < st.degree {
|
||||||
|
let (leading_deg, leading_coeff) = extract_monomial(egraph, stats, a)?;
|
||||||
|
let mut remainder = extract_polynomial(egraph, stats, b)?;
|
||||||
|
|
||||||
|
assert!(leading_deg >= remainder.len());
|
||||||
|
|
||||||
|
remainder.resize(leading_deg, RATIONAL_ZERO);
|
||||||
|
remainder.push(leading_coeff);
|
||||||
|
return Some(remainder);
|
||||||
|
}
|
||||||
|
},
|
||||||
|
_ => {}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
None
|
||||||
|
}
|
||||||
|
|
||||||
|
// a monomial is either a power of x, a constant, or a product of a constant and power of x
|
||||||
|
fn extract_monomial(egraph: &EGraph, stats: &HashMap<Id, PolyStats>, id: Id) -> Option<(usize, Rational)> {
|
||||||
|
let st = &stats[&id];
|
||||||
|
|
||||||
|
assert!(st.monomial);
|
||||||
|
|
||||||
|
// monic + monomial = power of x
|
||||||
|
if st.monic {
|
||||||
|
return Some((st.degree, RATIONAL_ONE.clone()));
|
||||||
|
}
|
||||||
|
|
||||||
|
for node in &egraph[id].nodes {
|
||||||
|
match node {
|
||||||
|
EquationLanguage::Mul([a,b]) => {
|
||||||
|
let a = egraph.find(*a);
|
||||||
|
let b = egraph.find(*b);
|
||||||
|
let Some(statb) = stats.get(&b) else { continue };
|
||||||
|
|
||||||
|
// a should be a constant
|
||||||
|
let Some(coeff) = egraph[a].data.clone() else { continue };
|
||||||
|
|
||||||
|
// b should be monic and nonconstant (hence a power of x)
|
||||||
|
if statb.degree == 0 || !statb.monic {
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
|
||||||
|
assert_eq!(st.degree, statb.degree);
|
||||||
|
|
||||||
|
return Some((st.degree, coeff));
|
||||||
|
},
|
||||||
|
EquationLanguage::Num(c) => { // a constant is also a monomial
|
||||||
|
return Some((0, c.clone()));
|
||||||
|
},
|
||||||
|
_ => {},
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
None
|
||||||
|
}
|
||||||
|
|
||||||
|
// like find_monic_factorization, but with the option of having a constant factor
|
||||||
|
fn find_general_factorization(egraph: &EGraph, stats: &HashMap<Id, PolyStats>, id: Id) -> Option<Vec<Id>> {
|
||||||
|
let st = stats.get(&id)?;
|
||||||
|
|
||||||
|
if st.monic {
|
||||||
|
return find_monic_factorization(egraph, stats, id);
|
||||||
|
} else {
|
||||||
|
for node in &egraph[id].nodes {
|
||||||
|
match node {
|
||||||
|
EquationLanguage::Mul([a,b]) => {
|
||||||
|
let a = egraph.find(*a);
|
||||||
|
let b = egraph.find(*b);
|
||||||
|
let Some(stata) = stats.get(&a) else { continue };
|
||||||
|
let Some(statb) = stats.get(&b) else { continue };
|
||||||
|
|
||||||
|
// a is constant, b is monic and nonconstant
|
||||||
|
if stata.degree == 0 && statb.degree > 0 && statb.monic {
|
||||||
|
let mut fac = find_monic_factorization(egraph, stats, b)?;
|
||||||
|
fac.push(a);
|
||||||
|
return Some(fac);
|
||||||
|
}
|
||||||
|
},
|
||||||
|
_ => {}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
None
|
||||||
|
}
|
||||||
|
|
||||||
|
// this assumes `id` to be canonical
|
||||||
|
fn find_monic_factorization(egraph: &EGraph, stats: &HashMap<Id, PolyStats>, id: Id) -> Option<Vec<Id>> {
|
||||||
|
// we want the polynomial to be nonconstant and monic
|
||||||
|
if ! stats.get(&id).is_some_and(|x|x.monic && x.degree > 0) {
|
||||||
|
return None;
|
||||||
|
}
|
||||||
|
|
||||||
|
// now the whole thing is a monic nonconst poly, so would be a valid factorization,
|
||||||
|
// but we want to go as deep as possible
|
||||||
|
|
||||||
|
// check if it is the product of two nonconstant monic polynomials
|
||||||
|
for node in &egraph[id].nodes {
|
||||||
|
match node {
|
||||||
|
EquationLanguage::Mul([a,b]) => {
|
||||||
|
let a = egraph.find(*a);
|
||||||
|
let b = egraph.find(*b);
|
||||||
|
let Some(stata) = stats.get(&a) else { continue };
|
||||||
|
let Some(statb) = stats.get(&b) else { continue };
|
||||||
|
|
||||||
|
if stata.degree == 0 || statb.degree == 0 || !stata.monic || !statb.monic {
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
|
||||||
|
// println!("stats = {:?}, stats a = {:?}, stats b = {:?}", stats[&id], stata, statb);
|
||||||
|
|
||||||
|
let Some(mut faca) = find_monic_factorization(egraph, stats, a) else { continue };
|
||||||
|
let Some(facb) = find_monic_factorization(egraph, stats, b) else { continue };
|
||||||
|
|
||||||
|
faca.extend_from_slice(&facb);
|
||||||
|
return Some(faca);
|
||||||
|
},
|
||||||
|
_ => {}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// at this point we know the current polynomial is monic, but we didn't find a further factorization
|
||||||
|
// so just return it as a single factor
|
||||||
|
Some(vec![id])
|
||||||
|
}
|
||||||
|
|
||||||
|
fn walk_egraph<F, T>(egraph: &EGraph, f: F) -> HashMap<Id, T>
|
||||||
where
|
where
|
||||||
F: Fn(Id, &EquationLanguage, &HashMap<Id, T>) -> Option<T> {
|
F: Fn(Id, &EquationLanguage, &HashMap<Id, T>) -> Option<T> {
|
||||||
|
|
||||||
@@ -22,208 +273,23 @@ where
|
|||||||
while modifications > 0 {
|
while modifications > 0 {
|
||||||
modifications = 0;
|
modifications = 0;
|
||||||
|
|
||||||
for cls in egraph.classes() {
|
'next_class: for cls in egraph.classes() {
|
||||||
let id = cls.id;
|
let id = cls.id;
|
||||||
if result.contains_key(&id) {
|
if result.contains_key(&id) {
|
||||||
continue;
|
continue 'next_class;
|
||||||
}
|
}
|
||||||
|
|
||||||
for node in &cls.nodes {
|
for node in &cls.nodes {
|
||||||
if let Some(x) = f(id, node, &result) {
|
if let Some(x) = f(id, node, &result) {
|
||||||
result.insert(id, x);
|
result.insert(id, x);
|
||||||
modifications += 1;
|
modifications += 1;
|
||||||
|
continue 'next_class;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
println!("{} modifications!", modifications);
|
// println!("{} modifications!", modifications);
|
||||||
}
|
}
|
||||||
|
|
||||||
result
|
result
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn analyze3(egraph: &EGraph, eclass: Id) {
|
|
||||||
let constants = search_for(egraph, |id, _, _|
|
|
||||||
egraph[id].data.as_ref().map(|c|c.clone())
|
|
||||||
);
|
|
||||||
|
|
||||||
println!("{:?}", constants);
|
|
||||||
|
|
||||||
let powers_of_x = search_for(egraph, |_, node, matches| match *node {
|
|
||||||
EquationLanguage::Unknown => Some(1),
|
|
||||||
EquationLanguage::Mul([a,b]) => {
|
|
||||||
if !matches.contains_key(&a) || !matches.contains_key(&b) {
|
|
||||||
return None;
|
|
||||||
}
|
|
||||||
|
|
||||||
let (dega, degb) = (matches[&a], matches[&b]);
|
|
||||||
Some(dega + degb)
|
|
||||||
},
|
|
||||||
_ => None,
|
|
||||||
});
|
|
||||||
|
|
||||||
println!("{:?}", powers_of_x);
|
|
||||||
|
|
||||||
let monomials = search_for(egraph, |id, node, _| {
|
|
||||||
if let Some(deg) = powers_of_x.get(&id) {
|
|
||||||
return Some((*deg, RATIONAL_ONE.clone()));
|
|
||||||
}
|
|
||||||
|
|
||||||
if let Some(c) = constants.get(&id) {
|
|
||||||
return Some((0, c.clone()));
|
|
||||||
}
|
|
||||||
|
|
||||||
match *node {
|
|
||||||
EquationLanguage::Mul([a,b]) => {
|
|
||||||
if !constants.contains_key(&a) || !powers_of_x.contains_key(&b) {
|
|
||||||
return None;
|
|
||||||
}
|
|
||||||
|
|
||||||
let (coeff, deg) = (&constants[&a], powers_of_x[&b]);
|
|
||||||
Some((deg, coeff.clone()))
|
|
||||||
},
|
|
||||||
_ => None,
|
|
||||||
}
|
|
||||||
});
|
|
||||||
|
|
||||||
println!("{:?}", monomials);
|
|
||||||
|
|
||||||
let monic_polynomials = search_for(egraph, |id, node, matches| {
|
|
||||||
if let Some(deg) = powers_of_x.get(&id) {
|
|
||||||
let mut poly: Vec<Rational> = Vec::new();
|
|
||||||
poly.resize(*deg, RATIONAL_ZERO);
|
|
||||||
poly.push(RATIONAL_ONE.clone());
|
|
||||||
Some(poly)
|
|
||||||
} else {
|
|
||||||
match *node {
|
|
||||||
EquationLanguage::Add([a,b]) => {
|
|
||||||
if !matches.contains_key(&a) || !monomials.contains_key(&b) {
|
|
||||||
return None;
|
|
||||||
}
|
|
||||||
|
|
||||||
let (leading, (deg, coeff)) = (&matches[&a], &monomials[&b]);
|
|
||||||
if leading.len() <= *deg || leading[*deg] != RATIONAL_ZERO {
|
|
||||||
return None;
|
|
||||||
}
|
|
||||||
|
|
||||||
let mut poly = leading.clone();
|
|
||||||
poly[*deg] = coeff.clone();
|
|
||||||
Some(poly)
|
|
||||||
},
|
|
||||||
_ => None,
|
|
||||||
}
|
|
||||||
}
|
|
||||||
});
|
|
||||||
|
|
||||||
for p in &monic_polynomials {
|
|
||||||
println!("{:?}", p);
|
|
||||||
}
|
|
||||||
|
|
||||||
let factorizations: HashMap<Id, (Rational, Vec<Vec<Rational>>)> = search_for(egraph, |id, node, matches| {
|
|
||||||
if let Some(c) = constants.get(&id) {
|
|
||||||
return Some((c.clone(), vec![]));
|
|
||||||
}
|
|
||||||
|
|
||||||
if let Some(poly) = monic_polynomials.get(&id) {
|
|
||||||
return Some((RATIONAL_ONE.clone(), vec![poly.clone()]));
|
|
||||||
}
|
|
||||||
|
|
||||||
match *node {
|
|
||||||
EquationLanguage::Mul([a,b]) => {
|
|
||||||
if !matches.contains_key(&a) || !monic_polynomials.contains_key(&b) {
|
|
||||||
return None;
|
|
||||||
}
|
|
||||||
|
|
||||||
let ((factor, polys), newpoly) = (&matches[&a], &monic_polynomials[&b]);
|
|
||||||
|
|
||||||
let mut combined: Vec<Vec<Rational>> = polys.clone();
|
|
||||||
combined.push(newpoly.clone());
|
|
||||||
Some((factor.clone(), combined))
|
|
||||||
},
|
|
||||||
_ => None,
|
|
||||||
}
|
|
||||||
});
|
|
||||||
|
|
||||||
/*
|
|
||||||
for p in &factorizations {
|
|
||||||
println!("{:?}", p);
|
|
||||||
}
|
|
||||||
*/
|
|
||||||
|
|
||||||
println!("{:?}", factorizations[&eclass]);
|
|
||||||
}
|
|
||||||
|
|
||||||
pub fn analyze2(egraph: &EGraph) -> HashMap<Id, SpecialTerm> {
|
|
||||||
let mut types: HashMap<Id, SpecialTerm> = HashMap::new();
|
|
||||||
|
|
||||||
let mut modifications: usize = 1;
|
|
||||||
|
|
||||||
while modifications > 0 {
|
|
||||||
modifications = 0;
|
|
||||||
|
|
||||||
for cls in egraph.classes() {
|
|
||||||
let id = cls.id;
|
|
||||||
if types.contains_key(&id) {
|
|
||||||
continue;
|
|
||||||
}
|
|
||||||
|
|
||||||
if let Some(c) = &egraph[id].data {
|
|
||||||
types.insert(id, SpecialTerm::Constant(c.clone()));
|
|
||||||
modifications += 1;
|
|
||||||
continue;
|
|
||||||
}
|
|
||||||
|
|
||||||
for node in &cls.nodes {
|
|
||||||
match *node {
|
|
||||||
EquationLanguage::Unknown => {
|
|
||||||
types.insert(id, SpecialTerm::PowerOfX(1));
|
|
||||||
modifications += 1;
|
|
||||||
},
|
|
||||||
EquationLanguage::Mul([a,b]) => {
|
|
||||||
// as we don't know a and b yet, defer to future iteration
|
|
||||||
if !types.contains_key(&a) || !types.contains_key(&b) {
|
|
||||||
continue;
|
|
||||||
}
|
|
||||||
|
|
||||||
match (&types[&a], &types[&b]) {
|
|
||||||
(SpecialTerm::PowerOfX(dega), SpecialTerm::PowerOfX(degb)) => {
|
|
||||||
types.insert(id, SpecialTerm::PowerOfX(*dega + *degb));
|
|
||||||
modifications += 1;
|
|
||||||
},
|
|
||||||
(SpecialTerm::Constant(coeff), SpecialTerm::PowerOfX(deg)) => {
|
|
||||||
types.insert(id, SpecialTerm::Monomial(*deg, coeff.clone()));
|
|
||||||
modifications += 1;
|
|
||||||
},
|
|
||||||
_ => { },
|
|
||||||
}
|
|
||||||
},
|
|
||||||
EquationLanguage::Add([a,b]) => {
|
|
||||||
// as we don't know a and b yet, defer to future iteration
|
|
||||||
if !types.contains_key(&a) || !types.contains_key(&b) {
|
|
||||||
continue;
|
|
||||||
}
|
|
||||||
|
|
||||||
match (&types[&a], &types[&b]) {
|
|
||||||
(SpecialTerm::MonicNonconstPoly(poly), SpecialTerm::Monomial(deg, coeff)) => {
|
|
||||||
if poly.len() <= *deg || poly[*deg] != RATIONAL_ZERO {
|
|
||||||
continue;
|
|
||||||
}
|
|
||||||
|
|
||||||
let mut poly = poly.clone();
|
|
||||||
poly[*deg] = coeff.clone();
|
|
||||||
types.insert(id, SpecialTerm::MonicNonconstPoly(poly));
|
|
||||||
modifications += 1;
|
|
||||||
},
|
|
||||||
_ => { },
|
|
||||||
}
|
|
||||||
},
|
|
||||||
_ => {}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
println!("{} modifications!", modifications);
|
|
||||||
}
|
|
||||||
|
|
||||||
types
|
|
||||||
}
|
|
||||||
|
|||||||
66
src/output.rs
Normal file
66
src/output.rs
Normal file
@@ -0,0 +1,66 @@
|
|||||||
|
use egg::{RecExpr, Id};
|
||||||
|
use crate::language::EquationLanguage;
|
||||||
|
|
||||||
|
// there is already a Display implementation generated by define_langauge!
|
||||||
|
// but we want an alternative string conversion
|
||||||
|
pub fn print_term(expr: &RecExpr<EquationLanguage>) -> String {
|
||||||
|
let root_id = Id::from(expr.as_ref().len()-1);
|
||||||
|
print_term_inner(expr, root_id).0
|
||||||
|
}
|
||||||
|
|
||||||
|
// the second result is the precedence of the top level op: 1 = '+-', 2 = '*/', 3 = '^', 4 = primitive
|
||||||
|
fn print_term_inner(expr: &RecExpr<EquationLanguage>, id: Id) -> (String, usize) {
|
||||||
|
match &expr[id] {
|
||||||
|
EquationLanguage::Num(c) => {
|
||||||
|
(format!("{}", c), if c.denom == 1 { 4 } else { 2 })
|
||||||
|
},
|
||||||
|
EquationLanguage::Neg([a]) => {
|
||||||
|
(print_unary(expr, *a, "-", 1), 1)
|
||||||
|
},
|
||||||
|
EquationLanguage::Add([a,b]) => {
|
||||||
|
(print_binary(expr, *a, *b, "+", 1), 1)
|
||||||
|
},
|
||||||
|
EquationLanguage::Sub([a,b]) => {
|
||||||
|
(print_binary(expr, *a, *b, "-", 1), 1)
|
||||||
|
},
|
||||||
|
EquationLanguage::Mul([a,b]) => {
|
||||||
|
(print_binary(expr, *a, *b, "*", 2), 2)
|
||||||
|
},
|
||||||
|
EquationLanguage::Div([a,b]) => {
|
||||||
|
(print_binary(expr, *a, *b, "/", 2), 2)
|
||||||
|
},
|
||||||
|
EquationLanguage::Power([a,b]) => {
|
||||||
|
(print_binary(expr, *a, *b, "^", 3), 3)
|
||||||
|
},
|
||||||
|
_ => unimplemented!()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
fn print_unary(expr: &RecExpr<EquationLanguage>, a: Id, op: &str, precedence: usize) -> String {
|
||||||
|
let (astr, aprec) = print_term_inner(expr, a);
|
||||||
|
|
||||||
|
if aprec > precedence {
|
||||||
|
format!("{}{}", op, astr)
|
||||||
|
} else {
|
||||||
|
format!("{}({})", op, astr)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
fn print_binary(expr: &RecExpr<EquationLanguage>, a: Id, b: Id, op: &str, precedence: usize) -> String {
|
||||||
|
let (astr, aprec) = print_term_inner(expr, a);
|
||||||
|
let (bstr, bprec) = print_term_inner(expr, b);
|
||||||
|
|
||||||
|
if aprec > precedence {
|
||||||
|
if bprec > precedence {
|
||||||
|
format!("{} {} {}", astr, op, bstr)
|
||||||
|
} else {
|
||||||
|
format!("{} {} ({})", astr, op, bstr)
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
if bprec > precedence {
|
||||||
|
format!("({}) {} {}", astr, op, bstr)
|
||||||
|
} else {
|
||||||
|
format!("({}) {} ({})", astr, op, bstr)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -1,5 +1,5 @@
|
|||||||
use std::error::Error;
|
use std::error::Error;
|
||||||
use egg::*;
|
use egg::{RecExpr, Id, FromOp, FromOpError};
|
||||||
use crate::language::EquationLanguage;
|
use crate::language::EquationLanguage;
|
||||||
|
|
||||||
pub fn parse_equation(input: &str) -> Result<RecExpr<EquationLanguage>, ParseError> {
|
pub fn parse_equation(input: &str) -> Result<RecExpr<EquationLanguage>, ParseError> {
|
||||||
@@ -26,15 +26,15 @@ fn parse_equation_inner(input: &str, expr: &mut RecExpr<EquationLanguage>) -> Re
|
|||||||
}
|
}
|
||||||
|
|
||||||
match c {
|
match c {
|
||||||
'^' if precedence > 3 => {
|
'^' if precedence >= 3 => {
|
||||||
operator_position = Some(i);
|
operator_position = Some(i);
|
||||||
precedence = 3;
|
precedence = 3;
|
||||||
},
|
},
|
||||||
'*' | '/' if precedence > 2 => {
|
'*' | '/' if precedence >= 2 => {
|
||||||
operator_position = Some(i);
|
operator_position = Some(i);
|
||||||
precedence = 2;
|
precedence = 2;
|
||||||
},
|
},
|
||||||
'-' | '+' if precedence > 1 => {
|
'-' | '+' if precedence >= 1 => {
|
||||||
operator_position = Some(i);
|
operator_position = Some(i);
|
||||||
precedence = 1;
|
precedence = 1;
|
||||||
},
|
},
|
||||||
|
|||||||
Reference in New Issue
Block a user