Compare commits

..

7 Commits

Author SHA1 Message Date
Florian Stecker
fced3915ea slightly clarify the explanation 2024-08-29 09:31:44 -04:00
Florian Stecker
829f1e7fa7 add readme 2024-08-28 16:41:44 -04:00
Florian Stecker
74ffaea4c0 add cubic equation solver 2024-08-28 15:19:03 -04:00
Florian Stecker
0b8bdf3da6 cleaned up 2024-08-28 12:31:28 -04:00
Florian Stecker
c110dd6889 all examples work 2024-08-28 12:07:03 -04:00
Florian Stecker
ef2a76869f fixed it! 2024-08-28 11:25:55 -04:00
Florian Stecker
d0265ea340 broken state 2024-08-27 23:24:51 -04:00
8 changed files with 665 additions and 529 deletions

61
README.md Normal file
View File

@@ -0,0 +1,61 @@
# Symbolic solver for polynomials with rational coefficients #
## Program output ##
Using `cargo run --release`. On my laptop, the last example times out on debug settings.
Equation: (x + 50) * 10 - 150 = 100
Parsed: (- (- (* (+ x 50) 10) 150) 100)
Factorized normal form: 10(25 + 1x)
Solutions: { x = -25 }
Equation: (x - 2) * (x + 2) = 0
Parsed: (- (* (- x 2) (+ x 2)) 0)
Factorized normal form: (-2 + 1x)(2 + 1x)
Solutions: { x = 2, x = -2 }
Equation: x ^ 2 = 4
Parsed: (- (^ x 2) 4)
Factorized normal form: (-4 + 0x + 1x^2)
Solutions: { x = 2, x = -2 }
Equation: x ^ 2 - 2 = 0
Parsed: (- (- (^ x 2) 2) 0)
Factorized normal form: (-2 + 0x + 1x^2)
Solutions: { x = 2 ^ (1/2), x = -2 ^ (1/2) }
Equation: x ^ 2 = 2 * x + 15
Parsed: (- (^ x 2) (+ (* 2 x) 15))
Factorized normal form: (-15 + -2x + 1x^2)
Solutions: { x = 5, x = -3 }
Equation: (x ^ 2 - 2 * x - 15) * (x + 5) = 0
Parsed: (- (* (- (- (^ x 2) (* 2 x)) 15) (+ x 5)) 0)
Factorized normal form: (-15 + -2x + 1x^2)(5 + 1x)
Solutions: { x = 5, x = -3, x = -5 }
Equation: x ^ 3 + 3 * x ^ 2 - 25 * x - 75 = 0
Parsed: (- (- (- (+ (^ x 3) (* 3 (^ x 2))) (* 25 x)) 75) 0)
Factorized normal form: (3 + 1x)(-25 + 0x + 1x^2)
Solutions: { x = -3, x = 5, x = -5 }
Equation: x ^ 3 - 91 * x - 90 = 0
Parsed: (- (- (- (^ x 3) (* 91 x)) 90) 0)
Factorized normal form: (-90 + -91x + 0x^2 + 1x^3)
Guessing rational solutions: 10, -9, -1
Verified guessed solutions!
Solutions: { x = 10, x = -9, x = -1 }
## Notes ##
This uses an egraph with rewrite rules representing the usual rules of arithmetic, and constant folding with rational numbers. This alone is enough to solve linear equations, by just starting with the difference of the two sides and minimizing AST size. However, it can't suffice for quadratic equations, as the egraph can't create values "out of this air". For example, to solve `x^2 = 2` the egraph must at least have "seen" the expression `2^(1/2)` once.
So instead, the egraph is used to bring the expression into a standard form: a product of monic polynomials which is factored as much as possible. Since this is difficult to model with a cost function, I wrote a custom extractor (in `normal_form.rs`).
Then we can use the usual quadratic formula to find solutions. The obtained solutions are afterwards simplified by another egraph, with a special "square root of a square integer" rule. This is pretty minimal, a more complete solution would also be able to deal with square roots of rational numbers etc.
Finally, the unfactorized cubic equation (equation 7) gets factored into a linear and a quadratic factor by the egraph. This is kind of "magic" though: the correct terms seem to come up randomly while mutating the egraph. For general cubics we can't expect this to happen, even if they have integer solutions. The additional equation 8 is an example for this.
While there is a general formula for the solutions of a cubic equation, it contains trigonometric functions and/or complex numbers, and we don't want to deal with that here. So we use a different strategy which can handle cubic equations with 3 rational solutions: we first solve the equation numerically using the cubic formula, then find rational expressions for the numerical solutions using continued fractions. Finally, we verify the resulting factorization by using an egraph. This equivalence check required increased node and iteration limits, and takes a few seconds on my laptop.
This has essentially only been tested on the 8 examples above. So I would expect there to still be a few bugs. The algorithm is also spectacularly inefficient, considering that the equality of rational polynomials is almost trivial to compute.

91
src/cubic.rs Normal file
View File

@@ -0,0 +1,91 @@
use std::f64::consts::PI;
use crate::language::Rational;
pub fn approximate_rational_cubic(b: &Rational, c: &Rational, d: &Rational, limit: u64) -> Vec<Rational> {
let numerical_sols = solve_cubic_numerically(
1.0,
(b.num as f64) / (b.denom as f64),
(c.num as f64) / (c.denom as f64),
(d.num as f64) / (d.denom as f64));
numerical_sols.into_iter().map(|x|rational_approx(x, limit)).collect()
}
// assuming leading coefficient is not 0, and also that we don't have a triple root
pub fn solve_cubic_numerically(a: f64, b: f64, c: f64, d: f64) -> Vec<f64> {
assert_ne!(a, 0.0);
let b = b/a;
let c = c/a;
let d = d/a;
let b2 = b * b;
let b3 = b2 * b;
solve_depressed_cubic_numerically(
c - b2 / 3.0,
2.0 * b3 / 27.0 - b * c / 3.0 + d
).into_iter().map(|u|u - b / 3.0 / a).collect()
}
fn solve_depressed_cubic_numerically(p: f64, q: f64) -> Vec<f64> {
let disc = p * p * p / 27.0 + q * q / 4.0;
if disc < 0.0 {
let r = 2.0 * (-p / 3.0).sqrt();
let t = 3.0 * q / 2.0 / p * (- 3.0 / p).sqrt();
let phi = t.acos() / 3.0;
vec![r * (phi).cos(),
r * (phi + 2.0 * PI / 3.0).cos(),
r * (phi + 4.0 * PI / 3.0).cos()]
} else {
// not implemented at the moment
let u = - q / 2.0 + disc.sqrt();
let v = - q / 2.0 - disc.sqrt();
vec![u.powf(1.0/3.0) + v.powf(1.0/3.0)]
}
}
// pretty inefficient implementation of continued fraction approximation
// but doesn't matter here
pub fn rational_approx(x: f64, limit: u64) -> Rational {
if x < 0.0 {
let Rational { num, denom } = rational_approx(-x, limit);
return Rational { num: -num, denom }
}
let mut num = 0;
let mut denom = 0;
for l in 0 .. 10 {
let Some((p,q)) = rational_approx_level(x, l) else { break; };
if q > limit {
break;
}
(num, denom) = (p, q);
}
Rational{
num: num as i64,
denom
}
}
fn rational_approx_level(x: f64, level: usize) -> Option<(u64, u64)> {
// we expect very big numbers or infinity if the previous iteration was exact
if x > 1e9 {
return None;
}
if level == 0 {
Some((x as u64, 1))
} else {
let (p,q) = rational_approx_level(1.0 / (x - x.floor()), level-1)?;
let floorx = x as u64;
Some((q + floorx * p, p))
}
}

View File

@@ -1,5 +1,5 @@
use std::{cmp::Ordering, fmt::{self,Display, Formatter}, str::FromStr, sync::LazyLock}; use std::{cmp::Ordering, fmt::{self,Display, Formatter}, str::FromStr, sync::LazyLock};
use egg::{define_language, merge_option, rewrite as rw, Analysis, DidMerge, Id, Language, Subst, Var}; use egg::{define_language, merge_option, rewrite as rw, Analysis, Applier, DidMerge, Id, Language, PatternAst, Subst, Symbol, Var};
pub type EGraph = egg::EGraph<EquationLanguage, ConstantFold>; pub type EGraph = egg::EGraph<EquationLanguage, ConstantFold>;
pub type Rewrite = egg::Rewrite<EquationLanguage, ConstantFold>; pub type Rewrite = egg::Rewrite<EquationLanguage, ConstantFold>;
@@ -39,9 +39,19 @@ impl Display for Rational {
} }
impl FromStr for Rational { impl FromStr for Rational {
type Err = std::num::ParseIntError; type Err = String;
fn from_str(s: &str) -> Result<Self, Self::Err> { fn from_str(s: &str) -> Result<Self, String> {
Ok(Rational { num: s.parse::<i64>()?, denom: 1 }) let err = || Err(format!("Couldn't parse rational: {}", s));
if let Ok(num) = s.parse::<i64>() {
Ok(Rational { num, denom: 1 })
} else if let Some((snum, sdenom)) = s.split_once('/') {
let Ok(num) = snum.parse::<i64>() else { return err(); };
let Ok(denom) = sdenom.parse::<u64>() else { return err(); };
Ok(Rational { num, denom })
} else {
err()
}
} }
} }
@@ -167,6 +177,46 @@ fn is_nonzero_const(var: &str) -> impl Fn(&mut EGraph, Id, &Subst) -> bool {
} }
} }
fn is_const(var: &str) -> impl Fn(&mut EGraph, Id, &Subst) -> bool {
let var: Var = var.parse().unwrap();
move |egraph, _, subst| {
egraph[subst[var]].data.is_some()
}
}
struct IntegerSqrt {
var: Var,
}
impl Applier<EquationLanguage, ConstantFold> for IntegerSqrt {
fn apply_one(&self,
egraph: &mut EGraph,
matched_id: Id,
subst: &Subst,
_searcher_pattern: Option<&PatternAst<EquationLanguage>>,
_rule_name: Symbol)
-> Vec<Id> {
let var_id = subst[self.var];
if let Some(value) = &egraph[var_id].data {
if value.denom == 1 && value.num >= 0 {
// isqrt is nightly only, so we just do this, adding 0.1 against rounding errors
let sq = (f64::sqrt(value.num as f64) + 0.1) as i64;
if value.num == sq*sq {
// println!("square root of integer {} is {}", value.num, sq);
let sq_id = egraph.add(EquationLanguage::Num(Rational { num: sq, denom: 1 }));
egraph.union(matched_id, sq_id);
return vec![matched_id, sq_id];
}
}
}
vec![]
}
}
pub static RULES: LazyLock<Vec<Rewrite>> = LazyLock::new(||vec![ pub static RULES: LazyLock<Vec<Rewrite>> = LazyLock::new(||vec![
rw!("commute-add"; "(+ ?x ?y)" => "(+ ?y ?x)"), rw!("commute-add"; "(+ ?x ?y)" => "(+ ?y ?x)"),
rw!("commute-mul"; "(* ?x ?y)" => "(* ?y ?x)"), rw!("commute-mul"; "(* ?x ?y)" => "(* ?y ?x)"),
@@ -178,6 +228,8 @@ pub static RULES: LazyLock<Vec<Rewrite>> = LazyLock::new(||vec![
rw!("mul-0"; "(* ?x 0)" => "0"), rw!("mul-0"; "(* ?x 0)" => "0"),
rw!("mul-1"; "(* ?x 1)" => "?x"), rw!("mul-1"; "(* ?x 1)" => "?x"),
rw!("0-sub"; "(- 0 ?x)" => "(- ?x)"),
rw!("add-sub"; "(+ ?x (* (-1) ?x))" => "0"), rw!("add-sub"; "(+ ?x (* (-1) ?x))" => "0"),
// division by zero shouldn't happen unless input is invalid // division by zero shouldn't happen unless input is invalid
rw!("mul-div"; "(* ?x (rec ?x))" => "1" if is_nonzero_const("?y")), rw!("mul-div"; "(* ?x (rec ?x))" => "1" if is_nonzero_const("?y")),
@@ -187,10 +239,6 @@ pub static RULES: LazyLock<Vec<Rewrite>> = LazyLock::new(||vec![
rw!("square"; "(^ ?x 2)" => "(* ?x ?x)"), rw!("square"; "(^ ?x 2)" => "(* ?x ?x)"),
rw!("cube"; "(^ ?x 3)" => "(* ?x (* ?x ?x))"), rw!("cube"; "(^ ?x 3)" => "(* ?x (* ?x ?x))"),
/*
rw!("inv-square"; "(* ?x ?x)" => "(^ ?x 2)"),
rw!("inv-cube"; "(* ?x (* ?x ?x))" => "(^ ?x 3)"),
*/
rw!("sub"; "(- ?x ?y)" => "(+ ?x (* -1 ?y))"), rw!("sub"; "(- ?x ?y)" => "(+ ?x (* -1 ?y))"),
rw!("neg"; "(- ?x)" => "(* -1 ?x)"), rw!("neg"; "(- ?x)" => "(* -1 ?x)"),
@@ -198,6 +246,8 @@ pub static RULES: LazyLock<Vec<Rewrite>> = LazyLock::new(||vec![
rw!("div"; "(/ ?x ?y)" => "(* ?x (rec ?y))" if is_nonzero_const("?y")), rw!("div"; "(/ ?x ?y)" => "(* ?x (rec ?y))" if is_nonzero_const("?y")),
rw!("factor_poly"; "(+ (* x ?x) ?y)" => "(* ?x (+ x (* ?y (rec ?x))))" if is_nonzero_const("?x")), rw!("factor_poly"; "(+ (* x ?x) ?y)" => "(* ?x (+ x (* ?y (rec ?x))))" if is_nonzero_const("?x")),
rw!("integer_sqrt"; "(^ ?x (/ 1 2))" => { IntegerSqrt { var: "?x".parse().unwrap() } } if is_const("?x")),
]); ]);
pub struct PlusTimesCostFn; pub struct PlusTimesCostFn;
@@ -218,125 +268,3 @@ impl egg::CostFunction<EquationLanguage> for PlusTimesCostFn {
enode.fold(op_cost, |sum, i| sum + costs(i)) enode.fold(op_cost, |sum, i| sum + costs(i))
} }
} }
#[derive(Debug,Clone,Copy)]
pub struct PolyStat {
degree: usize,
factors: usize, // non-constant factors
ops: usize,
monomial: bool,
sum_of_monomials: bool,
monic: bool,
factorized: bool, // a product of monic polynomials and at least one constant
}
#[derive(Debug,Clone,Copy)]
pub enum FactorizationCost {
UnwantedOps,
Polynomial(PolyStat)
}
fn score(cost: FactorizationCost) -> usize {
match cost {
FactorizationCost::UnwantedOps => 10000,
FactorizationCost::Polynomial(p) =>
if !p.factorized {
1000
} else {
100 * (9 - p.factors) + p.ops
},
}
}
impl PartialEq for FactorizationCost {
fn eq(&self, other: &Self) -> bool {
score(*self) == score(*other)
}
}
impl PartialOrd for FactorizationCost {
fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
usize::partial_cmp(&score(*self), &score(*other))
}
}
pub struct FactorizationCostFn;
impl egg::CostFunction<EquationLanguage> for FactorizationCostFn {
type Cost = FactorizationCost;
fn cost<C>(&mut self, enode: &EquationLanguage, mut costs: C) -> Self::Cost
where
C: FnMut(Id) -> Self::Cost,
{
match enode {
EquationLanguage::Add([a,b]) => {
match (costs(*a), costs(*b)) {
(FactorizationCost::Polynomial(p1),FactorizationCost::Polynomial(p2)) => {
// we only ever want to add monomials
let result_monic = if p1.degree > p2.degree {
p1.monic
} else if p2.degree > p1.degree {
p2.monic
} else {
false
};
if !p1.sum_of_monomials || !p2.sum_of_monomials {
FactorizationCost::UnwantedOps
} else {
FactorizationCost::Polynomial(PolyStat {
degree: usize::max(p1.degree, p2.degree),
factors: 1,
ops: p1.ops + p2.ops,
monomial: false,
sum_of_monomials: p1.sum_of_monomials && p2.sum_of_monomials,
monic: result_monic,
factorized: result_monic,
})
}
},
_ => FactorizationCost::UnwantedOps
}
},
EquationLanguage::Mul([a,b]) => {
match (costs(*a), costs(*b)) {
(FactorizationCost::Polynomial(p1), FactorizationCost::Polynomial(p2)) => {
FactorizationCost::Polynomial(PolyStat {
degree: p1.degree + p2.degree,
factors: p1.factors + p2.factors,
ops: p1.ops + p2.ops,
monomial: p1.monomial && p2.monomial,
sum_of_monomials: p1.monomial && p2.monomial,
monic: p1.monic && p2.monic,
factorized: (p1.monic && p2.factorized) || (p2.monic && p1.factorized)
})
},
_ => FactorizationCost::UnwantedOps
}
},
EquationLanguage::Num(c) => {
FactorizationCost::Polynomial(PolyStat {
degree: 0,
factors: 0,
ops: 0,
monomial: true,
sum_of_monomials: true,
monic: false,
factorized: true
})
},
EquationLanguage::Unknown => {
FactorizationCost::Polynomial(PolyStat {
degree: 1,
factors: 1,
ops: 0,
monomial: true,
sum_of_monomials: true,
monic: true,
factorized: true
})
},
_ => FactorizationCost::UnwantedOps,
}
}
}

View File

@@ -1,3 +1,5 @@
pub mod language; pub mod language;
pub mod normal_form; pub mod normal_form;
pub mod parse; pub mod parse;
pub mod output;
pub mod cubic;

View File

@@ -1,226 +1,148 @@
use egg::{Extractor, Pattern, RecExpr, Runner}; use egg::{AstSize, Extractor, Id, RecExpr, Runner};
use solveq::language::{RULES, EquationLanguage, PlusTimesCostFn, FactorizationCostFn}; use solveq::cubic::approximate_rational_cubic;
use solveq::normal_form::analyze3; use solveq::language::{EquationLanguage, Rational, RULES, EGraph};
use solveq::normal_form::extract_normal_form;
use solveq::parse::parse_equation; use solveq::parse::parse_equation;
use solveq::output::print_term;
static TEST_EQUATIONS: &[&str] = &[ static TEST_EQUATIONS: &[&str] = &[
"(x + 50) * 10 - 150 - 100", "(x + 50) * 10 - 150 = 100",
"(x - 2) * (x + 2) - 0", "(x - 2) * (x + 2) = 0",
"x ^ 2 - 4", "x ^ 2 = 4",
"x ^ 2 - 2 - 0", "x ^ 2 - 2 = 0",
"x ^ 2 - (2 * x + 15)", "x ^ 2 = 2 * x + 15",
"(x ^ 2 - 2 * x - 15) * (x + 5) - 0", "(x ^ 2 - 2 * x - 15) * (x + 5) = 0",
"x ^ 3 + 3 * x ^ 2 - 25 * x - 75 - 0", "x ^ 3 + 3 * x ^ 2 - 25 * x - 75 = 0",
"x ^ 3 - 91 * x - 90 = 0",
]; ];
fn main() { fn main() {
for eq in TEST_EQUATIONS { for eq in TEST_EQUATIONS {
let start = parse_equation(*eq).unwrap(); println!("Equation: {}", *eq);
// println!("{:?}", &start); let solutions = solve(*eq, true);
// do transformation to left - right = 0
let mut runner = Runner::default() let solutions_str: Vec<String> = solutions
.iter()
.map(|expr| format!("x = {}", print_term(expr)))
.collect();
println!("Solutions: {{ {} }}", solutions_str.join(", "));
println!("");
}
}
pub fn solve(eq: &str, verbose: bool) -> Vec<RecExpr<EquationLanguage>> {
let mut start = parse_equation(eq).unwrap();
let root_id = Id::from(start.as_ref().len()-1);
let EquationLanguage::Equals([left, right]) = start[root_id]
else { panic!("Not an equation without an equals sign!"); };
start[root_id] = EquationLanguage::Sub([left, right]);
if verbose {
println!("Parsed: {}", &start);
}
let runner = Runner::default()
.with_explanations_enabled() .with_explanations_enabled()
.with_expr(&start) .with_expr(&start)
.run(&*RULES); .run(&*RULES);
let extractor = Extractor::new(&runner.egraph, FactorizationCostFn); let Some(factorization) = extract_normal_form(&runner.egraph, runner.roots[0]) else {
panic!("Couldn't factorize polynomial!");
};
let (best_cost, best_expr) = extractor.find_best(runner.roots[0]); if verbose {
println!("Factorized normal form: {}", factorization);
println!("{}", start);
println!("{:?} {:?}", best_cost, <RecExpr<EquationLanguage> as AsRef<[EquationLanguage]>>::as_ref(&best_expr));
println!("");
} }
// let root = runner.roots[0]; let mut solutions: Vec<RecExpr<EquationLanguage>> = Vec::new();
// let egraph = &runner.egraph; for poly in &factorization.polynomials {
// let pattern: Pattern<EquationLanguage> = "(+ (* ?a (* x x)) ?c)".parse().unwrap(); if poly.len() == 2 { // linear factor
// let matches = pattern.search(&egraph); let Rational { num, denom } = &poly[0];
let mut solexpr = RecExpr::default();
solexpr.add(EquationLanguage::Num(Rational {
num: -*num,
denom: *denom
}));
solutions.push(solexpr);
} else if poly.len() == 3 { // quadratic factor
let Rational { num: num0, denom: denom0 } = &poly[0];
let Rational { num: num1, denom: denom1 } = &poly[1];
let sol1 = format!("- ({num1})/(2 * ({denom1})) + ((({num1})/(2 * ({denom1}))) ^ 2 - ({num0}) / ({denom0})) ^ (1/2)");
let sol2 = format!("- ({num1})/(2 * ({denom1})) - ((({num1})/(2 * ({denom1}))) ^ 2 - ({num0}) / ({denom0})) ^ (1/2)");
let expr = parse_equation(&sol1).unwrap();
let runner = Runner::default()
.with_expr(&expr)
.run(&*RULES);
let extractor = Extractor::new(&runner.egraph, AstSize);
let (_, simplified_expr) = extractor.find_best(runner.roots[0]);
solutions.push(simplified_expr);
// println!("{:?}", egraph.classes().count()); let expr = parse_equation(&sol2).unwrap();
let runner = Runner::default()
.with_expr(&expr)
.run(&*RULES);
let extractor = Extractor::new(&runner.egraph, AstSize);
let (_, simplified_expr) = extractor.find_best(runner.roots[0]);
solutions.push(simplified_expr);
} else if poly.len() == 4 {
let guesses = approximate_rational_cubic(
&Rational { num: 0, denom: 1 },
&Rational { num: -91, denom: 1 },
&Rational { num: -90, denom: 1 },
1000
);
// Analyze if guesses.len() < 3 {
// analyze3(egraph, runner.roots[0]);
/*
for class in egraph.classes() {
if monic_nonconst_polynomial(egraph, class.id).is_some() {
let (_, best_expr) = extractor.find_best(class.id);
println!("Monomial: {}", best_expr);
}
}
println!("{:?}", &matches);
*/
// println!("{}", runner.explain_equivalence(&start, &best_expr).get_flat_string());
}
/*
fn power_of_x(egraph: &EGraph, eclass: Id) -> Option<usize> {
for n in &egraph[eclass].nodes {
match *n {
EquationLanguage::Unknown => { return Some(1) },
EquationLanguage::Mul([a,b]) => {
let Some(left) = power_of_x(egraph, a) else { continue };
let Some(right) = power_of_x(egraph, b) else { continue };
return Some(left + right);
},
_ => {}
}
}
None
}
fn monomial(egraph: &EGraph, eclass: Id) -> Option<(usize, Rational)> {
if let Some(deg) = power_of_x(egraph, eclass) {
return Some((deg, RATIONAL_ONE.clone()));
}
for n in &egraph[eclass].nodes {
match *n {
EquationLanguage::Mul([a,b]) => {
let Some(coeff) = egraph[a].data.clone() else { continue };
let Some(deg) = power_of_x(egraph, b) else { continue };
return Some((deg, coeff));
},
_ => {}
}
}
None
}
// this is either a power_of_x, or a sum of this and a monomial
fn monic_nonconst_polynomial(egraph: &EGraph, eclass: Id) -> Option<Vec<Rational>> {
let mut result: Vec<Rational> = Vec::new();
if let Some(deg) = power_of_x(egraph, eclass) {
result.resize(deg - 1, RATIONAL_ZERO);
result.push(RATIONAL_ONE.clone());
return Some(result);
}
for n in &egraph[eclass].nodes {
match *n {
EquationLanguage::Add([a,b]) => {
let Some(mut leading) = monic_nonconst_polynomial(egraph, a)
else { continue };
let Some(addon) = monomial(egraph, b)
else { continue };
if leading.len() <= addon.0 || leading[addon.0] != RATIONAL_ZERO {
continue; continue;
} }
leading[addon.0] = addon.1.clone(); if verbose {
return Some(leading); println!("Guessing rational solutions: {}, {}, {}",
}, guesses[0], guesses[1], guesses[2]);
_ => {},
}
}
None
} }
*/ let start = format!("x ^ 3 + ({}) * x ^ 2 + ({}) * x + ({})",
poly[2], poly[1], poly[0]);
let goal = format!("(x - ({})) * (x - ({})) * (x - ({}))",
guesses[0], guesses[1], guesses[2]);
let start_expr = parse_equation(&start).unwrap();
let goal_expr = parse_equation(&goal).unwrap();
/* let mut egraph: EGraph = Default::default();
fn analyze(egraph: &EGraph, _id: Id) { egraph.add_expr(&start_expr);
let mut types: HashMap<Id, SpecialTerm> = HashMap::new(); egraph.add_expr(&goal_expr);
let mut todo: VecDeque<Id> = VecDeque::new();
// todo.push_back(runner.roots[0]); let runner = Runner::default()
for cls in egraph.classes() { .with_egraph(egraph)
todo.push_back(cls.id); .with_node_limit(100_000)
.with_iter_limit(100)
.run(&*RULES);
let equivs = runner.egraph.equivs(&start_expr, &goal_expr);
if !equivs.is_empty() {
if verbose {
println!("Verified guessed solutions!");
} }
'todo: while todo.len() > 0 { // expressions are equivalent, so the guesses are actually solutions
let id = todo.pop_front().unwrap(); for s in &guesses {
if types.contains_key(&id) { let mut solexpr = RecExpr::default();
continue 'todo; solexpr.add(EquationLanguage::Num(s.clone()));
solutions.push(solexpr);
} }
} else {
if let Some(c) = &egraph[id].data { if verbose {
types.insert(id, SpecialTerm::Constant(c.clone())); println!("Couldn't verify guessed solutions.");
continue 'todo;
} }
'nodes: for n in &egraph[id].nodes {
match *n {
EquationLanguage::Unknown => {
types.insert(id, SpecialTerm::PowerOfX(1));
continue 'todo;
},
EquationLanguage::Mul([a,b]) => {
if !types.contains_key(&a) {
todo.push_back(a);
todo.push_back(id);
continue 'nodes;
} }
if !types.contains_key(&b) {
todo.push_back(b);
todo.push_back(id);
continue 'nodes;
}
match (&types[&a], &types[&b]) {
(SpecialTerm::PowerOfX(dega), SpecialTerm::PowerOfX(degb)) => {
types.insert(id, SpecialTerm::PowerOfX(*dega + *degb));
},
(SpecialTerm::Constant(coeff), SpecialTerm::PowerOfX(deg)) => {
types.insert(id, SpecialTerm::Monomial(*deg, coeff.clone()));
},
_ => { continue 'nodes; },
}
continue 'todo;
},
EquationLanguage::Add([a,b]) => {
if !types.contains_key(&a) {
todo.push_front(a);
todo.push_back(id);
continue 'todo;
}
if !types.contains_key(&b) {
todo.push_front(b);
todo.push_back(id);
continue 'todo;
}
match (&types[&a], &types[&b]) {
(SpecialTerm::MonicNonconstPoly(poly), SpecialTerm::Monomial(deg, coeff)) => {
if poly.len() <= *deg || poly[*deg] != RATIONAL_ZERO {
continue 'nodes;
}
let mut poly = poly.clone();
poly[*deg] = coeff.clone();
types.insert(id, SpecialTerm::MonicNonconstPoly(poly));
},
_ => { continue 'nodes; },
}
continue 'todo;
},
_ => {},
} }
} }
types.insert(id, SpecialTerm::Other); solutions
} }
for (id, ty) in &types {
if !matches!(ty, &SpecialTerm::Other) {
println!("{:?}", &ty);
}
}
}
*/

View File

@@ -1,18 +1,269 @@
use crate::language::{EGraph, EquationLanguage, Rational, RATIONAL_ONE, RATIONAL_ZERO}; use crate::language::{EGraph, EquationLanguage, Rational, RATIONAL_ONE, RATIONAL_ZERO};
use std::collections::HashMap; use std::{collections::HashMap, fmt};
use egg::Id; use egg::Id;
#[derive(Debug,Clone)] #[derive(Debug,Clone)]
pub enum SpecialTerm { pub struct Factorization {
Constant(Rational), pub constant_factor: Rational,
PowerOfX(usize), pub polynomials: Vec<Vec<Rational>>,
Monomial(usize, Rational),
MonicNonconstPoly(Vec<Rational>),
Factorization(Rational, Vec<Vec<Rational>>),
Other,
} }
fn search_for<F, T>(egraph: &EGraph, f: F) -> HashMap<Id, T> impl fmt::Display for Factorization {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
if self.constant_factor != RATIONAL_ONE {
write!(f, "{}", self.constant_factor)?;
}
for poly in &self.polynomials {
write!(f, "(")?;
for (deg, coeff) in poly.iter().enumerate() {
if deg == 0 {
write!(f, "{}", coeff)?;
} else if deg == 1 {
write!(f, " + {}x", coeff)?;
} else {
write!(f, " + {}x^{}", coeff, deg)?;
}
}
write!(f, ")")?;
}
Ok(())
}
}
// warning: this is a property of an eclass, not a particular expression
// so we shouldn't use anything that's not a polynomial invariant
#[derive(Debug,Clone)]
struct PolyStats {
degree: usize,
monomial: bool,
monic: bool,
}
fn gather_poly_stats(egraph: &EGraph) -> HashMap<Id, PolyStats> {
walk_egraph(egraph, |_id, node, stats: &HashMap<Id, PolyStats>| {
let x = |i: &Id| stats.get(&egraph.find(*i));
Some(match node {
EquationLanguage::Unknown => PolyStats {
degree: 1,
monomial: true,
monic: true,
},
EquationLanguage::Num(c) => PolyStats {
degree: 0,
monomial: true,
monic: c == &RATIONAL_ONE,
},
EquationLanguage::Mul([a,b]) => {
// if both aren't monic we can't tell, the leading coefficients could cancel
// but there should be an alternative representative with one of them monic
if !x(a)?.monic && !x(b)?.monic {
return None;
}
PolyStats {
degree: x(a)?.degree + x(b)?.degree,
monomial: x(a)?.monomial && x(b)?.monomial,
monic: x(a)?.monic && x(b)?.monic,
}
},
EquationLanguage::Add([a,b]) => {
// in this case, there should also be a simplified representative which
// has only a single leading term
if x(a)?.degree == x(b)?.degree {
return None;
}
PolyStats {
degree: usize::max(x(a)?.degree, x(b)?.degree),
monomial: false,
monic: if x(a)?.degree > x(b)?.degree { x(a)?.monic } else { x(b)?.monic },
}
},
_ => { return None; }
})
})
}
pub fn extract_normal_form(egraph: &EGraph, eclass: Id) -> Option<Factorization> {
let eclass = egraph.find(eclass); // get the canonical eclass
let stats = gather_poly_stats(egraph);
let Some(factorization) = find_general_factorization(egraph, &stats, eclass) else { return None; };
let mut result = Vec::new();
let mut coeff: Option<Rational> = None;
for factor in factorization {
let extracted = extract_polynomial(egraph, &stats, factor)?;
// println!("Extracted: {:?}", extracted);
if extracted.len() == 1 {
coeff = Some(extracted[0].clone());
} else {
result.push(extracted);
}
}
Some(Factorization {
constant_factor: coeff.unwrap_or_else(||RATIONAL_ONE.clone()),
polynomials: result
})
}
// a polynomial should be either of:
// - a monomial: then we know the degree and we try to parse it as a product of a constant and a monic monomial
// or just a constant
// - a sum of a monomial of highest degree, and a polynomial of lower degree, recursively walk these
fn extract_polynomial(egraph: &EGraph, stats: &HashMap<Id, PolyStats>, id: Id) -> Option<Vec<Rational>> {
let st = &stats[&id];
if st.monomial {
let (deg, coeff) = extract_monomial(egraph, stats, id)?;
let mut result = vec![RATIONAL_ZERO; deg];
result.push(coeff);
return Some(result);
} else {
for node in &egraph[id].nodes {
match node {
EquationLanguage::Add([a,b]) => {
let a = egraph.find(*a);
let b = egraph.find(*b);
let Some(stata) = &stats.get(&a) else { continue };
let Some(statb) = &stats.get(&b) else { continue };
if stata.degree == st.degree && stata.monomial && statb.degree < st.degree {
let (leading_deg, leading_coeff) = extract_monomial(egraph, stats, a)?;
let mut remainder = extract_polynomial(egraph, stats, b)?;
assert!(leading_deg >= remainder.len());
remainder.resize(leading_deg, RATIONAL_ZERO);
remainder.push(leading_coeff);
return Some(remainder);
}
},
_ => {}
}
}
}
None
}
// a monomial is either a power of x, a constant, or a product of a constant and power of x
fn extract_monomial(egraph: &EGraph, stats: &HashMap<Id, PolyStats>, id: Id) -> Option<(usize, Rational)> {
let st = &stats[&id];
assert!(st.monomial);
// monic + monomial = power of x
if st.monic {
return Some((st.degree, RATIONAL_ONE.clone()));
}
for node in &egraph[id].nodes {
match node {
EquationLanguage::Mul([a,b]) => {
let a = egraph.find(*a);
let b = egraph.find(*b);
let Some(statb) = stats.get(&b) else { continue };
// a should be a constant
let Some(coeff) = egraph[a].data.clone() else { continue };
// b should be monic and nonconstant (hence a power of x)
if statb.degree == 0 || !statb.monic {
continue;
}
assert_eq!(st.degree, statb.degree);
return Some((st.degree, coeff));
},
EquationLanguage::Num(c) => { // a constant is also a monomial
return Some((0, c.clone()));
},
_ => {},
}
}
None
}
// like find_monic_factorization, but with the option of having a constant factor
fn find_general_factorization(egraph: &EGraph, stats: &HashMap<Id, PolyStats>, id: Id) -> Option<Vec<Id>> {
let st = stats.get(&id)?;
if st.monic {
return find_monic_factorization(egraph, stats, id);
} else {
for node in &egraph[id].nodes {
match node {
EquationLanguage::Mul([a,b]) => {
let a = egraph.find(*a);
let b = egraph.find(*b);
let Some(stata) = stats.get(&a) else { continue };
let Some(statb) = stats.get(&b) else { continue };
// a is constant, b is monic and nonconstant
if stata.degree == 0 && statb.degree > 0 && statb.monic {
let mut fac = find_monic_factorization(egraph, stats, b)?;
fac.push(a);
return Some(fac);
}
},
_ => {}
}
}
}
None
}
// this assumes `id` to be canonical
fn find_monic_factorization(egraph: &EGraph, stats: &HashMap<Id, PolyStats>, id: Id) -> Option<Vec<Id>> {
// we want the polynomial to be nonconstant and monic
if ! stats.get(&id).is_some_and(|x|x.monic && x.degree > 0) {
return None;
}
// now the whole thing is a monic nonconst poly, so would be a valid factorization,
// but we want to go as deep as possible
// check if it is the product of two nonconstant monic polynomials
for node in &egraph[id].nodes {
match node {
EquationLanguage::Mul([a,b]) => {
let a = egraph.find(*a);
let b = egraph.find(*b);
let Some(stata) = stats.get(&a) else { continue };
let Some(statb) = stats.get(&b) else { continue };
if stata.degree == 0 || statb.degree == 0 || !stata.monic || !statb.monic {
continue;
}
// println!("stats = {:?}, stats a = {:?}, stats b = {:?}", stats[&id], stata, statb);
let Some(mut faca) = find_monic_factorization(egraph, stats, a) else { continue };
let Some(facb) = find_monic_factorization(egraph, stats, b) else { continue };
faca.extend_from_slice(&facb);
return Some(faca);
},
_ => {}
}
}
// at this point we know the current polynomial is monic, but we didn't find a further factorization
// so just return it as a single factor
Some(vec![id])
}
fn walk_egraph<F, T>(egraph: &EGraph, f: F) -> HashMap<Id, T>
where where
F: Fn(Id, &EquationLanguage, &HashMap<Id, T>) -> Option<T> { F: Fn(Id, &EquationLanguage, &HashMap<Id, T>) -> Option<T> {
@@ -22,208 +273,23 @@ where
while modifications > 0 { while modifications > 0 {
modifications = 0; modifications = 0;
for cls in egraph.classes() { 'next_class: for cls in egraph.classes() {
let id = cls.id; let id = cls.id;
if result.contains_key(&id) { if result.contains_key(&id) {
continue; continue 'next_class;
} }
for node in &cls.nodes { for node in &cls.nodes {
if let Some(x) = f(id, node, &result) { if let Some(x) = f(id, node, &result) {
result.insert(id, x); result.insert(id, x);
modifications += 1; modifications += 1;
continue 'next_class;
} }
} }
} }
println!("{} modifications!", modifications); // println!("{} modifications!", modifications);
} }
result result
} }
pub fn analyze3(egraph: &EGraph, eclass: Id) {
let constants = search_for(egraph, |id, _, _|
egraph[id].data.as_ref().map(|c|c.clone())
);
println!("{:?}", constants);
let powers_of_x = search_for(egraph, |_, node, matches| match *node {
EquationLanguage::Unknown => Some(1),
EquationLanguage::Mul([a,b]) => {
if !matches.contains_key(&a) || !matches.contains_key(&b) {
return None;
}
let (dega, degb) = (matches[&a], matches[&b]);
Some(dega + degb)
},
_ => None,
});
println!("{:?}", powers_of_x);
let monomials = search_for(egraph, |id, node, _| {
if let Some(deg) = powers_of_x.get(&id) {
return Some((*deg, RATIONAL_ONE.clone()));
}
if let Some(c) = constants.get(&id) {
return Some((0, c.clone()));
}
match *node {
EquationLanguage::Mul([a,b]) => {
if !constants.contains_key(&a) || !powers_of_x.contains_key(&b) {
return None;
}
let (coeff, deg) = (&constants[&a], powers_of_x[&b]);
Some((deg, coeff.clone()))
},
_ => None,
}
});
println!("{:?}", monomials);
let monic_polynomials = search_for(egraph, |id, node, matches| {
if let Some(deg) = powers_of_x.get(&id) {
let mut poly: Vec<Rational> = Vec::new();
poly.resize(*deg, RATIONAL_ZERO);
poly.push(RATIONAL_ONE.clone());
Some(poly)
} else {
match *node {
EquationLanguage::Add([a,b]) => {
if !matches.contains_key(&a) || !monomials.contains_key(&b) {
return None;
}
let (leading, (deg, coeff)) = (&matches[&a], &monomials[&b]);
if leading.len() <= *deg || leading[*deg] != RATIONAL_ZERO {
return None;
}
let mut poly = leading.clone();
poly[*deg] = coeff.clone();
Some(poly)
},
_ => None,
}
}
});
for p in &monic_polynomials {
println!("{:?}", p);
}
let factorizations: HashMap<Id, (Rational, Vec<Vec<Rational>>)> = search_for(egraph, |id, node, matches| {
if let Some(c) = constants.get(&id) {
return Some((c.clone(), vec![]));
}
if let Some(poly) = monic_polynomials.get(&id) {
return Some((RATIONAL_ONE.clone(), vec![poly.clone()]));
}
match *node {
EquationLanguage::Mul([a,b]) => {
if !matches.contains_key(&a) || !monic_polynomials.contains_key(&b) {
return None;
}
let ((factor, polys), newpoly) = (&matches[&a], &monic_polynomials[&b]);
let mut combined: Vec<Vec<Rational>> = polys.clone();
combined.push(newpoly.clone());
Some((factor.clone(), combined))
},
_ => None,
}
});
/*
for p in &factorizations {
println!("{:?}", p);
}
*/
println!("{:?}", factorizations[&eclass]);
}
pub fn analyze2(egraph: &EGraph) -> HashMap<Id, SpecialTerm> {
let mut types: HashMap<Id, SpecialTerm> = HashMap::new();
let mut modifications: usize = 1;
while modifications > 0 {
modifications = 0;
for cls in egraph.classes() {
let id = cls.id;
if types.contains_key(&id) {
continue;
}
if let Some(c) = &egraph[id].data {
types.insert(id, SpecialTerm::Constant(c.clone()));
modifications += 1;
continue;
}
for node in &cls.nodes {
match *node {
EquationLanguage::Unknown => {
types.insert(id, SpecialTerm::PowerOfX(1));
modifications += 1;
},
EquationLanguage::Mul([a,b]) => {
// as we don't know a and b yet, defer to future iteration
if !types.contains_key(&a) || !types.contains_key(&b) {
continue;
}
match (&types[&a], &types[&b]) {
(SpecialTerm::PowerOfX(dega), SpecialTerm::PowerOfX(degb)) => {
types.insert(id, SpecialTerm::PowerOfX(*dega + *degb));
modifications += 1;
},
(SpecialTerm::Constant(coeff), SpecialTerm::PowerOfX(deg)) => {
types.insert(id, SpecialTerm::Monomial(*deg, coeff.clone()));
modifications += 1;
},
_ => { },
}
},
EquationLanguage::Add([a,b]) => {
// as we don't know a and b yet, defer to future iteration
if !types.contains_key(&a) || !types.contains_key(&b) {
continue;
}
match (&types[&a], &types[&b]) {
(SpecialTerm::MonicNonconstPoly(poly), SpecialTerm::Monomial(deg, coeff)) => {
if poly.len() <= *deg || poly[*deg] != RATIONAL_ZERO {
continue;
}
let mut poly = poly.clone();
poly[*deg] = coeff.clone();
types.insert(id, SpecialTerm::MonicNonconstPoly(poly));
modifications += 1;
},
_ => { },
}
},
_ => {}
}
}
}
println!("{} modifications!", modifications);
}
types
}

66
src/output.rs Normal file
View File

@@ -0,0 +1,66 @@
use egg::{RecExpr, Id};
use crate::language::EquationLanguage;
// there is already a Display implementation generated by define_langauge!
// but we want an alternative string conversion
pub fn print_term(expr: &RecExpr<EquationLanguage>) -> String {
let root_id = Id::from(expr.as_ref().len()-1);
print_term_inner(expr, root_id).0
}
// the second result is the precedence of the top level op: 1 = '+-', 2 = '*/', 3 = '^', 4 = primitive
fn print_term_inner(expr: &RecExpr<EquationLanguage>, id: Id) -> (String, usize) {
match &expr[id] {
EquationLanguage::Num(c) => {
(format!("{}", c), if c.denom == 1 { 4 } else { 2 })
},
EquationLanguage::Neg([a]) => {
(print_unary(expr, *a, "-", 1), 1)
},
EquationLanguage::Add([a,b]) => {
(print_binary(expr, *a, *b, "+", 1), 1)
},
EquationLanguage::Sub([a,b]) => {
(print_binary(expr, *a, *b, "-", 1), 1)
},
EquationLanguage::Mul([a,b]) => {
(print_binary(expr, *a, *b, "*", 2), 2)
},
EquationLanguage::Div([a,b]) => {
(print_binary(expr, *a, *b, "/", 2), 2)
},
EquationLanguage::Power([a,b]) => {
(print_binary(expr, *a, *b, "^", 3), 3)
},
_ => unimplemented!()
}
}
fn print_unary(expr: &RecExpr<EquationLanguage>, a: Id, op: &str, precedence: usize) -> String {
let (astr, aprec) = print_term_inner(expr, a);
if aprec > precedence {
format!("{}{}", op, astr)
} else {
format!("{}({})", op, astr)
}
}
fn print_binary(expr: &RecExpr<EquationLanguage>, a: Id, b: Id, op: &str, precedence: usize) -> String {
let (astr, aprec) = print_term_inner(expr, a);
let (bstr, bprec) = print_term_inner(expr, b);
if aprec > precedence {
if bprec > precedence {
format!("{} {} {}", astr, op, bstr)
} else {
format!("{} {} ({})", astr, op, bstr)
}
} else {
if bprec > precedence {
format!("({}) {} {}", astr, op, bstr)
} else {
format!("({}) {} ({})", astr, op, bstr)
}
}
}

View File

@@ -1,5 +1,5 @@
use std::error::Error; use std::error::Error;
use egg::*; use egg::{RecExpr, Id, FromOp, FromOpError};
use crate::language::EquationLanguage; use crate::language::EquationLanguage;
pub fn parse_equation(input: &str) -> Result<RecExpr<EquationLanguage>, ParseError> { pub fn parse_equation(input: &str) -> Result<RecExpr<EquationLanguage>, ParseError> {
@@ -26,15 +26,15 @@ fn parse_equation_inner(input: &str, expr: &mut RecExpr<EquationLanguage>) -> Re
} }
match c { match c {
'^' if precedence > 3 => { '^' if precedence >= 3 => {
operator_position = Some(i); operator_position = Some(i);
precedence = 3; precedence = 3;
}, },
'*' | '/' if precedence > 2 => { '*' | '/' if precedence >= 2 => {
operator_position = Some(i); operator_position = Some(i);
precedence = 2; precedence = 2;
}, },
'-' | '+' if precedence > 1 => { '-' | '+' if precedence >= 1 => {
operator_position = Some(i); operator_position = Some(i);
precedence = 1; precedence = 1;
}, },