Compare commits
5 Commits
ef2a76869f
...
main
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
fced3915ea | ||
|
|
829f1e7fa7 | ||
|
|
74ffaea4c0 | ||
|
|
0b8bdf3da6 | ||
|
|
c110dd6889 |
61
README.md
Normal file
61
README.md
Normal file
@@ -0,0 +1,61 @@
|
||||
# Symbolic solver for polynomials with rational coefficients #
|
||||
|
||||
## Program output ##
|
||||
|
||||
Using `cargo run --release`. On my laptop, the last example times out on debug settings.
|
||||
|
||||
Equation: (x + 50) * 10 - 150 = 100
|
||||
Parsed: (- (- (* (+ x 50) 10) 150) 100)
|
||||
Factorized normal form: 10(25 + 1x)
|
||||
Solutions: { x = -25 }
|
||||
|
||||
Equation: (x - 2) * (x + 2) = 0
|
||||
Parsed: (- (* (- x 2) (+ x 2)) 0)
|
||||
Factorized normal form: (-2 + 1x)(2 + 1x)
|
||||
Solutions: { x = 2, x = -2 }
|
||||
|
||||
Equation: x ^ 2 = 4
|
||||
Parsed: (- (^ x 2) 4)
|
||||
Factorized normal form: (-4 + 0x + 1x^2)
|
||||
Solutions: { x = 2, x = -2 }
|
||||
|
||||
Equation: x ^ 2 - 2 = 0
|
||||
Parsed: (- (- (^ x 2) 2) 0)
|
||||
Factorized normal form: (-2 + 0x + 1x^2)
|
||||
Solutions: { x = 2 ^ (1/2), x = -2 ^ (1/2) }
|
||||
|
||||
Equation: x ^ 2 = 2 * x + 15
|
||||
Parsed: (- (^ x 2) (+ (* 2 x) 15))
|
||||
Factorized normal form: (-15 + -2x + 1x^2)
|
||||
Solutions: { x = 5, x = -3 }
|
||||
|
||||
Equation: (x ^ 2 - 2 * x - 15) * (x + 5) = 0
|
||||
Parsed: (- (* (- (- (^ x 2) (* 2 x)) 15) (+ x 5)) 0)
|
||||
Factorized normal form: (-15 + -2x + 1x^2)(5 + 1x)
|
||||
Solutions: { x = 5, x = -3, x = -5 }
|
||||
|
||||
Equation: x ^ 3 + 3 * x ^ 2 - 25 * x - 75 = 0
|
||||
Parsed: (- (- (- (+ (^ x 3) (* 3 (^ x 2))) (* 25 x)) 75) 0)
|
||||
Factorized normal form: (3 + 1x)(-25 + 0x + 1x^2)
|
||||
Solutions: { x = -3, x = 5, x = -5 }
|
||||
|
||||
Equation: x ^ 3 - 91 * x - 90 = 0
|
||||
Parsed: (- (- (- (^ x 3) (* 91 x)) 90) 0)
|
||||
Factorized normal form: (-90 + -91x + 0x^2 + 1x^3)
|
||||
Guessing rational solutions: 10, -9, -1
|
||||
Verified guessed solutions!
|
||||
Solutions: { x = 10, x = -9, x = -1 }
|
||||
|
||||
## Notes ##
|
||||
|
||||
This uses an egraph with rewrite rules representing the usual rules of arithmetic, and constant folding with rational numbers. This alone is enough to solve linear equations, by just starting with the difference of the two sides and minimizing AST size. However, it can't suffice for quadratic equations, as the egraph can't create values "out of this air". For example, to solve `x^2 = 2` the egraph must at least have "seen" the expression `2^(1/2)` once.
|
||||
|
||||
So instead, the egraph is used to bring the expression into a standard form: a product of monic polynomials which is factored as much as possible. Since this is difficult to model with a cost function, I wrote a custom extractor (in `normal_form.rs`).
|
||||
|
||||
Then we can use the usual quadratic formula to find solutions. The obtained solutions are afterwards simplified by another egraph, with a special "square root of a square integer" rule. This is pretty minimal, a more complete solution would also be able to deal with square roots of rational numbers etc.
|
||||
|
||||
Finally, the unfactorized cubic equation (equation 7) gets factored into a linear and a quadratic factor by the egraph. This is kind of "magic" though: the correct terms seem to come up randomly while mutating the egraph. For general cubics we can't expect this to happen, even if they have integer solutions. The additional equation 8 is an example for this.
|
||||
|
||||
While there is a general formula for the solutions of a cubic equation, it contains trigonometric functions and/or complex numbers, and we don't want to deal with that here. So we use a different strategy which can handle cubic equations with 3 rational solutions: we first solve the equation numerically using the cubic formula, then find rational expressions for the numerical solutions using continued fractions. Finally, we verify the resulting factorization by using an egraph. This equivalence check required increased node and iteration limits, and takes a few seconds on my laptop.
|
||||
|
||||
This has essentially only been tested on the 8 examples above. So I would expect there to still be a few bugs. The algorithm is also spectacularly inefficient, considering that the equality of rational polynomials is almost trivial to compute.
|
||||
91
src/cubic.rs
Normal file
91
src/cubic.rs
Normal file
@@ -0,0 +1,91 @@
|
||||
use std::f64::consts::PI;
|
||||
|
||||
use crate::language::Rational;
|
||||
|
||||
pub fn approximate_rational_cubic(b: &Rational, c: &Rational, d: &Rational, limit: u64) -> Vec<Rational> {
|
||||
let numerical_sols = solve_cubic_numerically(
|
||||
1.0,
|
||||
(b.num as f64) / (b.denom as f64),
|
||||
(c.num as f64) / (c.denom as f64),
|
||||
(d.num as f64) / (d.denom as f64));
|
||||
|
||||
numerical_sols.into_iter().map(|x|rational_approx(x, limit)).collect()
|
||||
}
|
||||
|
||||
// assuming leading coefficient is not 0, and also that we don't have a triple root
|
||||
pub fn solve_cubic_numerically(a: f64, b: f64, c: f64, d: f64) -> Vec<f64> {
|
||||
assert_ne!(a, 0.0);
|
||||
|
||||
let b = b/a;
|
||||
let c = c/a;
|
||||
let d = d/a;
|
||||
|
||||
let b2 = b * b;
|
||||
let b3 = b2 * b;
|
||||
|
||||
solve_depressed_cubic_numerically(
|
||||
c - b2 / 3.0,
|
||||
2.0 * b3 / 27.0 - b * c / 3.0 + d
|
||||
).into_iter().map(|u|u - b / 3.0 / a).collect()
|
||||
}
|
||||
|
||||
fn solve_depressed_cubic_numerically(p: f64, q: f64) -> Vec<f64> {
|
||||
let disc = p * p * p / 27.0 + q * q / 4.0;
|
||||
|
||||
if disc < 0.0 {
|
||||
let r = 2.0 * (-p / 3.0).sqrt();
|
||||
let t = 3.0 * q / 2.0 / p * (- 3.0 / p).sqrt();
|
||||
let phi = t.acos() / 3.0;
|
||||
|
||||
vec![r * (phi).cos(),
|
||||
r * (phi + 2.0 * PI / 3.0).cos(),
|
||||
r * (phi + 4.0 * PI / 3.0).cos()]
|
||||
} else {
|
||||
// not implemented at the moment
|
||||
let u = - q / 2.0 + disc.sqrt();
|
||||
let v = - q / 2.0 - disc.sqrt();
|
||||
vec![u.powf(1.0/3.0) + v.powf(1.0/3.0)]
|
||||
}
|
||||
}
|
||||
|
||||
// pretty inefficient implementation of continued fraction approximation
|
||||
// but doesn't matter here
|
||||
pub fn rational_approx(x: f64, limit: u64) -> Rational {
|
||||
if x < 0.0 {
|
||||
let Rational { num, denom } = rational_approx(-x, limit);
|
||||
return Rational { num: -num, denom }
|
||||
}
|
||||
|
||||
let mut num = 0;
|
||||
let mut denom = 0;
|
||||
for l in 0 .. 10 {
|
||||
let Some((p,q)) = rational_approx_level(x, l) else { break; };
|
||||
|
||||
if q > limit {
|
||||
break;
|
||||
}
|
||||
|
||||
(num, denom) = (p, q);
|
||||
}
|
||||
|
||||
Rational{
|
||||
num: num as i64,
|
||||
denom
|
||||
}
|
||||
}
|
||||
|
||||
fn rational_approx_level(x: f64, level: usize) -> Option<(u64, u64)> {
|
||||
// we expect very big numbers or infinity if the previous iteration was exact
|
||||
if x > 1e9 {
|
||||
return None;
|
||||
}
|
||||
|
||||
if level == 0 {
|
||||
Some((x as u64, 1))
|
||||
} else {
|
||||
let (p,q) = rational_approx_level(1.0 / (x - x.floor()), level-1)?;
|
||||
let floorx = x as u64;
|
||||
|
||||
Some((q + floorx * p, p))
|
||||
}
|
||||
}
|
||||
@@ -1,256 +0,0 @@
|
||||
use core::fmt;
|
||||
use std::{cmp::Ordering, fmt::Display};
|
||||
use crate::language::{EquationLanguage, Rational, RATIONAL_ONE, RATIONAL_ZERO};
|
||||
use egg::{Id, RecExpr};
|
||||
|
||||
|
||||
#[derive(Debug,Clone,Copy)]
|
||||
pub struct PolyStat {
|
||||
degree: usize,
|
||||
factors: usize, // non-constant factors
|
||||
ops: usize,
|
||||
monomial: bool,
|
||||
sum_of_monomials: bool,
|
||||
monic: bool,
|
||||
factorized: bool, // a product of monic polynomials and at least one constant
|
||||
}
|
||||
|
||||
#[derive(Debug,Clone,Copy)]
|
||||
pub enum FactorizationCost {
|
||||
UnwantedOps,
|
||||
Polynomial(PolyStat)
|
||||
}
|
||||
|
||||
fn score(cost: FactorizationCost) -> usize {
|
||||
match cost {
|
||||
FactorizationCost::UnwantedOps => 10000,
|
||||
FactorizationCost::Polynomial(p) =>
|
||||
if !p.factorized {
|
||||
1000 + p.ops
|
||||
} else {
|
||||
100 * (9 - p.factors) + p.ops
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
impl PartialEq for FactorizationCost {
|
||||
fn eq(&self, other: &Self) -> bool {
|
||||
score(*self) == score(*other)
|
||||
}
|
||||
}
|
||||
|
||||
impl PartialOrd for FactorizationCost {
|
||||
fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
|
||||
usize::partial_cmp(&score(*self), &score(*other))
|
||||
}
|
||||
}
|
||||
|
||||
pub struct FactorizationCostFn;
|
||||
impl egg::CostFunction<EquationLanguage> for FactorizationCostFn {
|
||||
type Cost = FactorizationCost;
|
||||
|
||||
fn cost<C>(&mut self, enode: &EquationLanguage, mut costs: C) -> Self::Cost
|
||||
where
|
||||
C: FnMut(Id) -> Self::Cost,
|
||||
{
|
||||
match enode {
|
||||
EquationLanguage::Add([a,b]) => {
|
||||
match (costs(*a), costs(*b)) {
|
||||
(FactorizationCost::Polynomial(p1),FactorizationCost::Polynomial(p2)) => {
|
||||
// we only ever want to add monomials
|
||||
let result_monic = if p1.degree > p2.degree {
|
||||
p1.monic
|
||||
} else if p2.degree > p1.degree {
|
||||
p2.monic
|
||||
} else {
|
||||
false
|
||||
};
|
||||
|
||||
/*
|
||||
if *a == Id::from(4) && *b == Id::from(19) {
|
||||
println!("HERE {:?} {:?}", p1, p2);
|
||||
}
|
||||
*/
|
||||
|
||||
|
||||
if !p1.sum_of_monomials || !p2.sum_of_monomials {
|
||||
FactorizationCost::UnwantedOps
|
||||
} else {
|
||||
FactorizationCost::Polynomial(PolyStat {
|
||||
degree: usize::max(p1.degree, p2.degree),
|
||||
factors: 1,
|
||||
ops: p1.ops + p2.ops + 1,
|
||||
monomial: false,
|
||||
sum_of_monomials: p1.sum_of_monomials && p2.sum_of_monomials,
|
||||
monic: result_monic,
|
||||
factorized: result_monic,
|
||||
})
|
||||
}
|
||||
},
|
||||
_ => FactorizationCost::UnwantedOps
|
||||
}
|
||||
},
|
||||
EquationLanguage::Mul([a,b]) => {
|
||||
match (costs(*a), costs(*b)) {
|
||||
(FactorizationCost::Polynomial(p1), FactorizationCost::Polynomial(p2)) => {
|
||||
FactorizationCost::Polynomial(PolyStat {
|
||||
degree: p1.degree + p2.degree,
|
||||
factors: p1.factors + p2.factors,
|
||||
ops: p1.ops + p2.ops + 1,
|
||||
monomial: p1.monomial && p2.monomial,
|
||||
sum_of_monomials: p1.monomial && p2.monomial,
|
||||
monic: p1.monic && p2.monic,
|
||||
factorized: (p1.monic && p2.factorized) || (p2.monic && p1.factorized)
|
||||
})
|
||||
},
|
||||
_ => FactorizationCost::UnwantedOps
|
||||
}
|
||||
},
|
||||
EquationLanguage::Num(c) => {
|
||||
FactorizationCost::Polynomial(PolyStat {
|
||||
degree: 0,
|
||||
factors: 0,
|
||||
ops: 0,
|
||||
monomial: true,
|
||||
sum_of_monomials: true,
|
||||
monic: false,
|
||||
factorized: true
|
||||
})
|
||||
},
|
||||
EquationLanguage::Unknown => {
|
||||
FactorizationCost::Polynomial(PolyStat {
|
||||
degree: 1,
|
||||
factors: 1,
|
||||
ops: 0,
|
||||
monomial: true,
|
||||
sum_of_monomials: true,
|
||||
monic: true,
|
||||
factorized: true
|
||||
})
|
||||
},
|
||||
_ => FactorizationCost::UnwantedOps,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug,Clone)]
|
||||
pub struct Factorization {
|
||||
pub constant_factor: Rational,
|
||||
pub polynomials: Vec<Vec<Rational>>,
|
||||
}
|
||||
|
||||
impl Display for Factorization {
|
||||
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
|
||||
if self.constant_factor != RATIONAL_ONE {
|
||||
write!(f, "{}", self.constant_factor)?;
|
||||
}
|
||||
|
||||
for poly in &self.polynomials {
|
||||
write!(f, "(")?;
|
||||
for (deg, coeff) in poly.iter().enumerate() {
|
||||
if deg == 0 {
|
||||
write!(f, "{}", coeff)?;
|
||||
} else if deg == 1 {
|
||||
write!(f, " + {}x", coeff)?;
|
||||
} else {
|
||||
write!(f, " + {}x^{}", coeff, deg)?;
|
||||
}
|
||||
}
|
||||
write!(f, ")")?;
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
||||
pub fn extract_factorization(expr: &RecExpr<EquationLanguage>) -> Factorization {
|
||||
let root_id: Id = Id::from(expr.as_ref().len()-1);
|
||||
|
||||
let mut constant_factor: Option<Rational> = None;
|
||||
let mut factors: Vec<Vec<Rational>> = Vec::new();
|
||||
let mut todo: Vec<Id> = Vec::new();
|
||||
todo.push(root_id);
|
||||
|
||||
while todo.len() > 0 {
|
||||
let id = todo.pop().unwrap();
|
||||
|
||||
match &expr[id] {
|
||||
EquationLanguage::Mul([a,b]) => {
|
||||
todo.push(*a);
|
||||
todo.push(*b);
|
||||
},
|
||||
EquationLanguage::Num(x) => {
|
||||
assert!(constant_factor.is_none());
|
||||
constant_factor = Some(x.clone());
|
||||
},
|
||||
_ => {
|
||||
factors.push(extract_polynomial(expr, id));
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
Factorization {
|
||||
constant_factor: constant_factor.unwrap_or_else(||RATIONAL_ONE.clone()),
|
||||
polynomials: factors
|
||||
}
|
||||
}
|
||||
|
||||
fn extract_polynomial(expr: &RecExpr<EquationLanguage>, id: Id) -> Vec<Rational> {
|
||||
let mut result: Vec<Rational> = Vec::new();
|
||||
let mut todo: Vec<Id> = Vec::new();
|
||||
todo.push(id);
|
||||
|
||||
while todo.len() > 0 {
|
||||
let id = todo.pop().unwrap();
|
||||
|
||||
match &expr[id] {
|
||||
EquationLanguage::Add([a,b]) => {
|
||||
todo.push(*a);
|
||||
todo.push(*b);
|
||||
},
|
||||
_ => {
|
||||
let (deg, coeff) = extract_monomial(expr, id);
|
||||
result.resize(result.len().max(deg), RATIONAL_ZERO.clone());
|
||||
|
||||
if result.len() <= deg {
|
||||
result.push(coeff);
|
||||
} else {
|
||||
assert!(result[deg] == RATIONAL_ZERO);
|
||||
result[deg] = coeff;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
result
|
||||
}
|
||||
|
||||
fn extract_monomial(expr: &RecExpr<EquationLanguage>, id: Id) -> (usize, Rational) {
|
||||
let mut coeff: Option<Rational> = None;
|
||||
let mut deg: usize = 0;
|
||||
let mut todo: Vec<Id> = Vec::new();
|
||||
todo.push(id);
|
||||
|
||||
while todo.len() > 0 {
|
||||
let id = todo.pop().unwrap();
|
||||
|
||||
match &expr[id] {
|
||||
EquationLanguage::Unknown => {
|
||||
deg += 1;
|
||||
},
|
||||
EquationLanguage::Mul([a,b]) => {
|
||||
todo.push(*a);
|
||||
todo.push(*b);
|
||||
},
|
||||
EquationLanguage::Num(x) => {
|
||||
assert!(coeff.is_none());
|
||||
coeff = Some(x.clone());
|
||||
},
|
||||
_ => {
|
||||
panic!("Not a rational polynomial in normal form!");
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
(deg, coeff.unwrap_or_else(||RATIONAL_ONE.clone()))
|
||||
}
|
||||
@@ -1,5 +1,5 @@
|
||||
use std::{cmp::Ordering, fmt::{self,Display, Formatter}, str::FromStr, sync::LazyLock};
|
||||
use egg::{define_language, merge_option, rewrite as rw, Analysis, DidMerge, Id, Language, Subst, Var};
|
||||
use egg::{define_language, merge_option, rewrite as rw, Analysis, Applier, DidMerge, Id, Language, PatternAst, Subst, Symbol, Var};
|
||||
|
||||
pub type EGraph = egg::EGraph<EquationLanguage, ConstantFold>;
|
||||
pub type Rewrite = egg::Rewrite<EquationLanguage, ConstantFold>;
|
||||
@@ -39,9 +39,19 @@ impl Display for Rational {
|
||||
}
|
||||
|
||||
impl FromStr for Rational {
|
||||
type Err = std::num::ParseIntError;
|
||||
fn from_str(s: &str) -> Result<Self, Self::Err> {
|
||||
Ok(Rational { num: s.parse::<i64>()?, denom: 1 })
|
||||
type Err = String;
|
||||
fn from_str(s: &str) -> Result<Self, String> {
|
||||
let err = || Err(format!("Couldn't parse rational: {}", s));
|
||||
|
||||
if let Ok(num) = s.parse::<i64>() {
|
||||
Ok(Rational { num, denom: 1 })
|
||||
} else if let Some((snum, sdenom)) = s.split_once('/') {
|
||||
let Ok(num) = snum.parse::<i64>() else { return err(); };
|
||||
let Ok(denom) = sdenom.parse::<u64>() else { return err(); };
|
||||
Ok(Rational { num, denom })
|
||||
} else {
|
||||
err()
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -174,6 +184,39 @@ fn is_const(var: &str) -> impl Fn(&mut EGraph, Id, &Subst) -> bool {
|
||||
}
|
||||
}
|
||||
|
||||
struct IntegerSqrt {
|
||||
var: Var,
|
||||
}
|
||||
|
||||
impl Applier<EquationLanguage, ConstantFold> for IntegerSqrt {
|
||||
fn apply_one(&self,
|
||||
egraph: &mut EGraph,
|
||||
matched_id: Id,
|
||||
subst: &Subst,
|
||||
_searcher_pattern: Option<&PatternAst<EquationLanguage>>,
|
||||
_rule_name: Symbol)
|
||||
-> Vec<Id> {
|
||||
let var_id = subst[self.var];
|
||||
|
||||
if let Some(value) = &egraph[var_id].data {
|
||||
if value.denom == 1 && value.num >= 0 {
|
||||
// isqrt is nightly only, so we just do this, adding 0.1 against rounding errors
|
||||
let sq = (f64::sqrt(value.num as f64) + 0.1) as i64;
|
||||
if value.num == sq*sq {
|
||||
// println!("square root of integer {} is {}", value.num, sq);
|
||||
|
||||
let sq_id = egraph.add(EquationLanguage::Num(Rational { num: sq, denom: 1 }));
|
||||
egraph.union(matched_id, sq_id);
|
||||
|
||||
return vec![matched_id, sq_id];
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
vec![]
|
||||
}
|
||||
}
|
||||
|
||||
pub static RULES: LazyLock<Vec<Rewrite>> = LazyLock::new(||vec![
|
||||
rw!("commute-add"; "(+ ?x ?y)" => "(+ ?y ?x)"),
|
||||
rw!("commute-mul"; "(* ?x ?y)" => "(* ?y ?x)"),
|
||||
@@ -185,6 +228,8 @@ pub static RULES: LazyLock<Vec<Rewrite>> = LazyLock::new(||vec![
|
||||
rw!("mul-0"; "(* ?x 0)" => "0"),
|
||||
rw!("mul-1"; "(* ?x 1)" => "?x"),
|
||||
|
||||
rw!("0-sub"; "(- 0 ?x)" => "(- ?x)"),
|
||||
|
||||
rw!("add-sub"; "(+ ?x (* (-1) ?x))" => "0"),
|
||||
// division by zero shouldn't happen unless input is invalid
|
||||
rw!("mul-div"; "(* ?x (rec ?x))" => "1" if is_nonzero_const("?y")),
|
||||
@@ -202,7 +247,7 @@ pub static RULES: LazyLock<Vec<Rewrite>> = LazyLock::new(||vec![
|
||||
|
||||
rw!("factor_poly"; "(+ (* x ?x) ?y)" => "(* ?x (+ x (* ?y (rec ?x))))" if is_nonzero_const("?x")),
|
||||
|
||||
// rw!("integer_sqrt"; "(^ ?x (1/2))" => {} if is_const("?x")),
|
||||
rw!("integer_sqrt"; "(^ ?x (/ 1 2))" => { IntegerSqrt { var: "?x".parse().unwrap() } } if is_const("?x")),
|
||||
]);
|
||||
|
||||
pub struct PlusTimesCostFn;
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
pub mod language;
|
||||
pub mod normal_form;
|
||||
pub mod parse;
|
||||
pub mod factorization;
|
||||
pub mod output;
|
||||
pub mod cubic;
|
||||
|
||||
202
src/main.rs
202
src/main.rs
@@ -1,6 +1,6 @@
|
||||
use egg::{AstSize, Extractor, Id, Pattern, PatternAst, RecExpr, Runner, Searcher};
|
||||
use solveq::factorization::{FactorizationCost, FactorizationCostFn};
|
||||
use solveq::language::{ConstantFold, EquationLanguage, Rational, RULES, EGraph};
|
||||
use egg::{AstSize, Extractor, Id, RecExpr, Runner};
|
||||
use solveq::cubic::approximate_rational_cubic;
|
||||
use solveq::language::{EquationLanguage, Rational, RULES, EGraph};
|
||||
use solveq::normal_form::extract_normal_form;
|
||||
use solveq::parse::parse_equation;
|
||||
use solveq::output::print_term;
|
||||
@@ -13,96 +13,136 @@ static TEST_EQUATIONS: &[&str] = &[
|
||||
"x ^ 2 = 2 * x + 15",
|
||||
"(x ^ 2 - 2 * x - 15) * (x + 5) = 0",
|
||||
"x ^ 3 + 3 * x ^ 2 - 25 * x - 75 = 0",
|
||||
"x ^ 3 - 91 * x - 90 = 0",
|
||||
];
|
||||
|
||||
|
||||
fn main() {
|
||||
// let expr: RecExpr<EquationLanguage> = "(+ (* x (+ x -2)) -15)".parse().unwrap();
|
||||
// let expr: RecExpr<EquationLanguage> = "(* (+ (* x (+ x -2)) -15) (+ x 5))".parse().unwrap();
|
||||
// println!("{:?}", get_expression_cost(&expr));
|
||||
|
||||
for eq in TEST_EQUATIONS {
|
||||
println!("Equation: {}", *eq);
|
||||
|
||||
let mut start = parse_equation(*eq).unwrap();
|
||||
let root_id = Id::from(start.as_ref().len()-1);
|
||||
let EquationLanguage::Equals([left, right]) = start[root_id]
|
||||
else { panic!("Not an equation without an equals sign!"); };
|
||||
start[root_id] = EquationLanguage::Sub([left, right]);
|
||||
let solutions = solve(*eq, true);
|
||||
|
||||
println!("Parsed: {}", &start);
|
||||
let solutions_str: Vec<String> = solutions
|
||||
.iter()
|
||||
.map(|expr| format!("x = {}", print_term(expr)))
|
||||
.collect();
|
||||
println!("Solutions: {{ {} }}", solutions_str.join(", "));
|
||||
|
||||
let mut runner = Runner::default()
|
||||
.with_explanations_enabled()
|
||||
.with_expr(&start)
|
||||
.run(&*RULES);
|
||||
|
||||
let extractor = Extractor::new(&runner.egraph, FactorizationCostFn);
|
||||
|
||||
let (best_cost, best_expr) = extractor.find_best(runner.roots[0]);
|
||||
|
||||
// println!("Best expresssion: {} {:?}", best_expr, best_cost);
|
||||
// println!("{}", runner.explain_equivalence(&start, &best_expr).get_flat_string());
|
||||
|
||||
let Some(factorization) = extract_normal_form(&runner.egraph, runner.roots[0]) else {
|
||||
panic!("Couldn't factorize polynomial!");
|
||||
};
|
||||
println!("Factorized normal form: {}", factorization);
|
||||
|
||||
/*
|
||||
get_expression_cost("(* x (+ x -2))", &runner.egraph);
|
||||
get_expression_cost("-15", &runner.egraph);
|
||||
get_expression_cost("(+ (* x (+ x -2)) -15)", &runner.egraph);
|
||||
*/
|
||||
|
||||
// let factorization = extract_factorization(&best_expr);
|
||||
|
||||
|
||||
let mut solutions: Vec<String> = Vec::new();
|
||||
for poly in &factorization.polynomials {
|
||||
if poly.len() == 2 { // linear factor
|
||||
let Rational { num, denom } = &poly[0];
|
||||
solutions.push(format!("x = {}", Rational { num: -*num, denom: *denom }));
|
||||
} else if poly.len() == 3 { // quadratic factor
|
||||
let Rational { num: num0, denom: denom0 } = &poly[0];
|
||||
let Rational { num: num1, denom: denom1 } = &poly[1];
|
||||
|
||||
let sol1 = format!("- ({num1})/(2 * ({denom1})) + ((({num1})/(2 * ({denom1}))) ^ 2 - ({num0}) / ({denom0})) ^ (1/2)");
|
||||
let sol2 = format!("- ({num1})/(2 * ({denom1})) - ((({num1})/(2 * ({denom1}))) ^ 2 - ({num0}) / ({denom0})) ^ (1/2)");
|
||||
|
||||
let expr = parse_equation(&sol1).unwrap();
|
||||
let runner = Runner::default()
|
||||
.with_expr(&expr)
|
||||
.run(&*RULES);
|
||||
let extractor = Extractor::new(&runner.egraph, AstSize);
|
||||
let (_, simplified_expr) = extractor.find_best(runner.roots[0]);
|
||||
solutions.push(format!("x = {}", print_term(&simplified_expr)));
|
||||
|
||||
let expr = parse_equation(&sol2).unwrap();
|
||||
let runner = Runner::default()
|
||||
.with_expr(&expr)
|
||||
.run(&*RULES);
|
||||
let extractor = Extractor::new(&runner.egraph, AstSize);
|
||||
let (_, simplified_expr) = extractor.find_best(runner.roots[0]);
|
||||
solutions.push(format!("x = {}", print_term(&simplified_expr)));
|
||||
}
|
||||
}
|
||||
|
||||
println!("Solutions: {{ {} }}", solutions.join(", "));
|
||||
println!("");
|
||||
}
|
||||
}
|
||||
|
||||
fn get_expression_cost(expr: &str, egraph: &EGraph) {
|
||||
// let mut egraph = EGraph::new(ConstantFold::default());
|
||||
// let id = egraph.add_expr(expr);
|
||||
let pattern: Pattern<EquationLanguage> = expr.parse().unwrap();
|
||||
let matches = pattern.search(egraph);
|
||||
for m in matches {
|
||||
let extractor = Extractor::new(&egraph, FactorizationCostFn);
|
||||
let (cost, _) = extractor.find_best(m.eclass);
|
||||
pub fn solve(eq: &str, verbose: bool) -> Vec<RecExpr<EquationLanguage>> {
|
||||
let mut start = parse_equation(eq).unwrap();
|
||||
let root_id = Id::from(start.as_ref().len()-1);
|
||||
let EquationLanguage::Equals([left, right]) = start[root_id]
|
||||
else { panic!("Not an equation without an equals sign!"); };
|
||||
start[root_id] = EquationLanguage::Sub([left, right]);
|
||||
|
||||
println!("expr: {}, id: {}, cost: {:?}", expr, m.eclass, cost);
|
||||
if verbose {
|
||||
println!("Parsed: {}", &start);
|
||||
}
|
||||
// cost
|
||||
|
||||
let runner = Runner::default()
|
||||
.with_explanations_enabled()
|
||||
.with_expr(&start)
|
||||
.run(&*RULES);
|
||||
|
||||
let Some(factorization) = extract_normal_form(&runner.egraph, runner.roots[0]) else {
|
||||
panic!("Couldn't factorize polynomial!");
|
||||
};
|
||||
|
||||
if verbose {
|
||||
println!("Factorized normal form: {}", factorization);
|
||||
}
|
||||
|
||||
let mut solutions: Vec<RecExpr<EquationLanguage>> = Vec::new();
|
||||
for poly in &factorization.polynomials {
|
||||
if poly.len() == 2 { // linear factor
|
||||
let Rational { num, denom } = &poly[0];
|
||||
let mut solexpr = RecExpr::default();
|
||||
solexpr.add(EquationLanguage::Num(Rational {
|
||||
num: -*num,
|
||||
denom: *denom
|
||||
}));
|
||||
solutions.push(solexpr);
|
||||
} else if poly.len() == 3 { // quadratic factor
|
||||
let Rational { num: num0, denom: denom0 } = &poly[0];
|
||||
let Rational { num: num1, denom: denom1 } = &poly[1];
|
||||
|
||||
let sol1 = format!("- ({num1})/(2 * ({denom1})) + ((({num1})/(2 * ({denom1}))) ^ 2 - ({num0}) / ({denom0})) ^ (1/2)");
|
||||
let sol2 = format!("- ({num1})/(2 * ({denom1})) - ((({num1})/(2 * ({denom1}))) ^ 2 - ({num0}) / ({denom0})) ^ (1/2)");
|
||||
|
||||
let expr = parse_equation(&sol1).unwrap();
|
||||
let runner = Runner::default()
|
||||
.with_expr(&expr)
|
||||
.run(&*RULES);
|
||||
let extractor = Extractor::new(&runner.egraph, AstSize);
|
||||
let (_, simplified_expr) = extractor.find_best(runner.roots[0]);
|
||||
solutions.push(simplified_expr);
|
||||
|
||||
let expr = parse_equation(&sol2).unwrap();
|
||||
let runner = Runner::default()
|
||||
.with_expr(&expr)
|
||||
.run(&*RULES);
|
||||
let extractor = Extractor::new(&runner.egraph, AstSize);
|
||||
let (_, simplified_expr) = extractor.find_best(runner.roots[0]);
|
||||
solutions.push(simplified_expr);
|
||||
} else if poly.len() == 4 {
|
||||
let guesses = approximate_rational_cubic(
|
||||
&Rational { num: 0, denom: 1 },
|
||||
&Rational { num: -91, denom: 1 },
|
||||
&Rational { num: -90, denom: 1 },
|
||||
1000
|
||||
);
|
||||
|
||||
if guesses.len() < 3 {
|
||||
continue;
|
||||
}
|
||||
|
||||
if verbose {
|
||||
println!("Guessing rational solutions: {}, {}, {}",
|
||||
guesses[0], guesses[1], guesses[2]);
|
||||
}
|
||||
|
||||
let start = format!("x ^ 3 + ({}) * x ^ 2 + ({}) * x + ({})",
|
||||
poly[2], poly[1], poly[0]);
|
||||
let goal = format!("(x - ({})) * (x - ({})) * (x - ({}))",
|
||||
guesses[0], guesses[1], guesses[2]);
|
||||
|
||||
let start_expr = parse_equation(&start).unwrap();
|
||||
let goal_expr = parse_equation(&goal).unwrap();
|
||||
|
||||
let mut egraph: EGraph = Default::default();
|
||||
egraph.add_expr(&start_expr);
|
||||
egraph.add_expr(&goal_expr);
|
||||
|
||||
let runner = Runner::default()
|
||||
.with_egraph(egraph)
|
||||
.with_node_limit(100_000)
|
||||
.with_iter_limit(100)
|
||||
.run(&*RULES);
|
||||
|
||||
let equivs = runner.egraph.equivs(&start_expr, &goal_expr);
|
||||
|
||||
if !equivs.is_empty() {
|
||||
if verbose {
|
||||
println!("Verified guessed solutions!");
|
||||
}
|
||||
|
||||
// expressions are equivalent, so the guesses are actually solutions
|
||||
for s in &guesses {
|
||||
let mut solexpr = RecExpr::default();
|
||||
solexpr.add(EquationLanguage::Num(s.clone()));
|
||||
solutions.push(solexpr);
|
||||
}
|
||||
} else {
|
||||
if verbose {
|
||||
println!("Couldn't verify guessed solutions.");
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
solutions
|
||||
}
|
||||
|
||||
@@ -1,16 +1,6 @@
|
||||
use crate::language::{EGraph, EquationLanguage, Rational, RATIONAL_ONE, RATIONAL_ZERO};
|
||||
use std::{collections::HashMap, fmt};
|
||||
use egg::{AstSize, Extractor, Id};
|
||||
|
||||
#[derive(Debug,Clone)]
|
||||
pub enum SpecialTerm {
|
||||
Constant(Rational),
|
||||
PowerOfX(usize),
|
||||
Monomial(usize, Rational),
|
||||
MonicNonconstPoly(Vec<Rational>),
|
||||
Factorization(Rational, Vec<Vec<Rational>>),
|
||||
Other,
|
||||
}
|
||||
use egg::Id;
|
||||
|
||||
#[derive(Debug,Clone)]
|
||||
pub struct Factorization {
|
||||
@@ -18,7 +8,32 @@ pub struct Factorization {
|
||||
pub polynomials: Vec<Vec<Rational>>,
|
||||
}
|
||||
|
||||
// this is a property of an eclass, not a particular expression
|
||||
impl fmt::Display for Factorization {
|
||||
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
|
||||
if self.constant_factor != RATIONAL_ONE {
|
||||
write!(f, "{}", self.constant_factor)?;
|
||||
}
|
||||
|
||||
for poly in &self.polynomials {
|
||||
write!(f, "(")?;
|
||||
for (deg, coeff) in poly.iter().enumerate() {
|
||||
if deg == 0 {
|
||||
write!(f, "{}", coeff)?;
|
||||
} else if deg == 1 {
|
||||
write!(f, " + {}x", coeff)?;
|
||||
} else {
|
||||
write!(f, " + {}x^{}", coeff, deg)?;
|
||||
}
|
||||
}
|
||||
write!(f, ")")?;
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
||||
// warning: this is a property of an eclass, not a particular expression
|
||||
// so we shouldn't use anything that's not a polynomial invariant
|
||||
#[derive(Debug,Clone)]
|
||||
struct PolyStats {
|
||||
degree: usize,
|
||||
@@ -126,7 +141,7 @@ fn extract_polynomial(egraph: &EGraph, stats: &HashMap<Id, PolyStats>, id: Id) -
|
||||
|
||||
assert!(leading_deg >= remainder.len());
|
||||
|
||||
remainder.resize(leading_deg, RATIONAL_ZERO.clone());
|
||||
remainder.resize(leading_deg, RATIONAL_ZERO);
|
||||
remainder.push(leading_coeff);
|
||||
return Some(remainder);
|
||||
}
|
||||
@@ -141,10 +156,6 @@ fn extract_polynomial(egraph: &EGraph, stats: &HashMap<Id, PolyStats>, id: Id) -
|
||||
|
||||
// a monomial is either a power of x, a constant, or a product of a constant and power of x
|
||||
fn extract_monomial(egraph: &EGraph, stats: &HashMap<Id, PolyStats>, id: Id) -> Option<(usize, Rational)> {
|
||||
let extractor = Extractor::new(egraph, AstSize);
|
||||
let (_, expr) = extractor.find_best(id);
|
||||
// println!("Extract Monomial: {}", expr);
|
||||
|
||||
let st = &stats[&id];
|
||||
|
||||
assert!(st.monomial);
|
||||
@@ -221,7 +232,6 @@ fn find_monic_factorization(egraph: &EGraph, stats: &HashMap<Id, PolyStats>, id:
|
||||
|
||||
// now the whole thing is a monic nonconst poly, so would be a valid factorization,
|
||||
// but we want to go as deep as possible
|
||||
// println!("{:?}", stats[&id]);
|
||||
|
||||
// check if it is the product of two nonconstant monic polynomials
|
||||
for node in &egraph[id].nodes {
|
||||
@@ -260,8 +270,6 @@ where
|
||||
let mut result: HashMap<Id, T> = HashMap::new();
|
||||
let mut modifications: usize = 1;
|
||||
|
||||
// println!("{:?}", egraph[canonical]);
|
||||
|
||||
while modifications > 0 {
|
||||
modifications = 0;
|
||||
|
||||
@@ -285,215 +293,3 @@ where
|
||||
|
||||
result
|
||||
}
|
||||
|
||||
/*
|
||||
pub fn analyze3(egraph: &EGraph, eclass: Id) {
|
||||
let constants = search_for(egraph, |id, _, _|
|
||||
egraph[id].data.as_ref().map(|c|c.clone())
|
||||
);
|
||||
|
||||
println!("{:?}", constants);
|
||||
|
||||
let powers_of_x = search_for(egraph, |_, node, matches| match *node {
|
||||
EquationLanguage::Unknown => Some(1),
|
||||
EquationLanguage::Mul([a,b]) => {
|
||||
if !matches.contains_key(&a) || !matches.contains_key(&b) {
|
||||
return None;
|
||||
}
|
||||
|
||||
let (dega, degb) = (matches[&a], matches[&b]);
|
||||
Some(dega + degb)
|
||||
},
|
||||
_ => None,
|
||||
});
|
||||
|
||||
println!("{:?}", powers_of_x);
|
||||
|
||||
let monomials = search_for(egraph, |id, node, _| {
|
||||
if let Some(deg) = powers_of_x.get(&id) {
|
||||
return Some((*deg, RATIONAL_ONE.clone()));
|
||||
}
|
||||
|
||||
if let Some(c) = constants.get(&id) {
|
||||
return Some((0, c.clone()));
|
||||
}
|
||||
|
||||
match *node {
|
||||
EquationLanguage::Mul([a,b]) => {
|
||||
if !constants.contains_key(&a) || !powers_of_x.contains_key(&b) {
|
||||
return None;
|
||||
}
|
||||
|
||||
let (coeff, deg) = (&constants[&a], powers_of_x[&b]);
|
||||
Some((deg, coeff.clone()))
|
||||
},
|
||||
_ => None,
|
||||
}
|
||||
});
|
||||
|
||||
println!("{:?}", monomials);
|
||||
|
||||
let monic_polynomials = search_for(egraph, |id, node, matches| {
|
||||
if let Some(deg) = powers_of_x.get(&id) {
|
||||
let mut poly: Vec<Rational> = Vec::new();
|
||||
poly.resize(*deg, RATIONAL_ZERO);
|
||||
poly.push(RATIONAL_ONE.clone());
|
||||
Some(poly)
|
||||
} else {
|
||||
match *node {
|
||||
EquationLanguage::Add([a,b]) => {
|
||||
if !matches.contains_key(&a) || !monomials.contains_key(&b) {
|
||||
return None;
|
||||
}
|
||||
|
||||
let (leading, (deg, coeff)) = (&matches[&a], &monomials[&b]);
|
||||
if leading.len() <= *deg || leading[*deg] != RATIONAL_ZERO {
|
||||
return None;
|
||||
}
|
||||
|
||||
let mut poly = leading.clone();
|
||||
poly[*deg] = coeff.clone();
|
||||
Some(poly)
|
||||
},
|
||||
_ => None,
|
||||
}
|
||||
}
|
||||
});
|
||||
|
||||
for p in &monic_polynomials {
|
||||
println!("{:?}", p);
|
||||
}
|
||||
|
||||
let factorizations: HashMap<Id, (Rational, Vec<Vec<Rational>>)> = search_for(egraph, |id, node, matches| {
|
||||
if let Some(c) = constants.get(&id) {
|
||||
return Some((c.clone(), vec![]));
|
||||
}
|
||||
|
||||
if let Some(poly) = monic_polynomials.get(&id) {
|
||||
return Some((RATIONAL_ONE.clone(), vec![poly.clone()]));
|
||||
}
|
||||
|
||||
match *node {
|
||||
EquationLanguage::Mul([a,b]) => {
|
||||
if !matches.contains_key(&a) || !monic_polynomials.contains_key(&b) {
|
||||
return None;
|
||||
}
|
||||
|
||||
let ((factor, polys), newpoly) = (&matches[&a], &monic_polynomials[&b]);
|
||||
|
||||
let mut combined: Vec<Vec<Rational>> = polys.clone();
|
||||
combined.push(newpoly.clone());
|
||||
Some((factor.clone(), combined))
|
||||
},
|
||||
_ => None,
|
||||
}
|
||||
});
|
||||
|
||||
/*
|
||||
for p in &factorizations {
|
||||
println!("{:?}", p);
|
||||
}
|
||||
*/
|
||||
|
||||
println!("{:?}", factorizations[&eclass]);
|
||||
}
|
||||
|
||||
pub fn analyze2(egraph: &EGraph) -> HashMap<Id, SpecialTerm> {
|
||||
let mut types: HashMap<Id, SpecialTerm> = HashMap::new();
|
||||
|
||||
let mut modifications: usize = 1;
|
||||
|
||||
while modifications > 0 {
|
||||
modifications = 0;
|
||||
|
||||
for cls in egraph.classes() {
|
||||
let id = cls.id;
|
||||
if types.contains_key(&id) {
|
||||
continue;
|
||||
}
|
||||
|
||||
if let Some(c) = &egraph[id].data {
|
||||
types.insert(id, SpecialTerm::Constant(c.clone()));
|
||||
modifications += 1;
|
||||
continue;
|
||||
}
|
||||
|
||||
for node in &cls.nodes {
|
||||
match *node {
|
||||
EquationLanguage::Unknown => {
|
||||
types.insert(id, SpecialTerm::PowerOfX(1));
|
||||
modifications += 1;
|
||||
},
|
||||
EquationLanguage::Mul([a,b]) => {
|
||||
// as we don't know a and b yet, defer to future iteration
|
||||
if !types.contains_key(&a) || !types.contains_key(&b) {
|
||||
continue;
|
||||
}
|
||||
|
||||
match (&types[&a], &types[&b]) {
|
||||
(SpecialTerm::PowerOfX(dega), SpecialTerm::PowerOfX(degb)) => {
|
||||
types.insert(id, SpecialTerm::PowerOfX(*dega + *degb));
|
||||
modifications += 1;
|
||||
},
|
||||
(SpecialTerm::Constant(coeff), SpecialTerm::PowerOfX(deg)) => {
|
||||
types.insert(id, SpecialTerm::Monomial(*deg, coeff.clone()));
|
||||
modifications += 1;
|
||||
},
|
||||
_ => { },
|
||||
}
|
||||
},
|
||||
EquationLanguage::Add([a,b]) => {
|
||||
// as we don't know a and b yet, defer to future iteration
|
||||
if !types.contains_key(&a) || !types.contains_key(&b) {
|
||||
continue;
|
||||
}
|
||||
|
||||
match (&types[&a], &types[&b]) {
|
||||
(SpecialTerm::MonicNonconstPoly(poly), SpecialTerm::Monomial(deg, coeff)) => {
|
||||
if poly.len() <= *deg || poly[*deg] != RATIONAL_ZERO {
|
||||
continue;
|
||||
}
|
||||
|
||||
let mut poly = poly.clone();
|
||||
poly[*deg] = coeff.clone();
|
||||
types.insert(id, SpecialTerm::MonicNonconstPoly(poly));
|
||||
modifications += 1;
|
||||
},
|
||||
_ => { },
|
||||
}
|
||||
},
|
||||
_ => {}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
println!("{} modifications!", modifications);
|
||||
}
|
||||
|
||||
types
|
||||
}
|
||||
*/
|
||||
|
||||
impl fmt::Display for Factorization {
|
||||
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
|
||||
if self.constant_factor != RATIONAL_ONE {
|
||||
write!(f, "{}", self.constant_factor)?;
|
||||
}
|
||||
|
||||
for poly in &self.polynomials {
|
||||
write!(f, "(")?;
|
||||
for (deg, coeff) in poly.iter().enumerate() {
|
||||
if deg == 0 {
|
||||
write!(f, "{}", coeff)?;
|
||||
} else if deg == 1 {
|
||||
write!(f, " + {}x", coeff)?;
|
||||
} else {
|
||||
write!(f, " + {}x^{}", coeff, deg)?;
|
||||
}
|
||||
}
|
||||
write!(f, ")")?;
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
use std::error::Error;
|
||||
use egg::*;
|
||||
use egg::{RecExpr, Id, FromOp, FromOpError};
|
||||
use crate::language::EquationLanguage;
|
||||
|
||||
pub fn parse_equation(input: &str) -> Result<RecExpr<EquationLanguage>, ParseError> {
|
||||
|
||||
Reference in New Issue
Block a user