Compare commits

..

5 Commits

Author SHA1 Message Date
Florian Stecker
fced3915ea slightly clarify the explanation 2024-08-29 09:31:44 -04:00
Florian Stecker
829f1e7fa7 add readme 2024-08-28 16:41:44 -04:00
Florian Stecker
74ffaea4c0 add cubic equation solver 2024-08-28 15:19:03 -04:00
Florian Stecker
0b8bdf3da6 cleaned up 2024-08-28 12:31:28 -04:00
Florian Stecker
c110dd6889 all examples work 2024-08-28 12:07:03 -04:00
8 changed files with 353 additions and 576 deletions

61
README.md Normal file
View File

@@ -0,0 +1,61 @@
# Symbolic solver for polynomials with rational coefficients #
## Program output ##
Using `cargo run --release`. On my laptop, the last example times out on debug settings.
Equation: (x + 50) * 10 - 150 = 100
Parsed: (- (- (* (+ x 50) 10) 150) 100)
Factorized normal form: 10(25 + 1x)
Solutions: { x = -25 }
Equation: (x - 2) * (x + 2) = 0
Parsed: (- (* (- x 2) (+ x 2)) 0)
Factorized normal form: (-2 + 1x)(2 + 1x)
Solutions: { x = 2, x = -2 }
Equation: x ^ 2 = 4
Parsed: (- (^ x 2) 4)
Factorized normal form: (-4 + 0x + 1x^2)
Solutions: { x = 2, x = -2 }
Equation: x ^ 2 - 2 = 0
Parsed: (- (- (^ x 2) 2) 0)
Factorized normal form: (-2 + 0x + 1x^2)
Solutions: { x = 2 ^ (1/2), x = -2 ^ (1/2) }
Equation: x ^ 2 = 2 * x + 15
Parsed: (- (^ x 2) (+ (* 2 x) 15))
Factorized normal form: (-15 + -2x + 1x^2)
Solutions: { x = 5, x = -3 }
Equation: (x ^ 2 - 2 * x - 15) * (x + 5) = 0
Parsed: (- (* (- (- (^ x 2) (* 2 x)) 15) (+ x 5)) 0)
Factorized normal form: (-15 + -2x + 1x^2)(5 + 1x)
Solutions: { x = 5, x = -3, x = -5 }
Equation: x ^ 3 + 3 * x ^ 2 - 25 * x - 75 = 0
Parsed: (- (- (- (+ (^ x 3) (* 3 (^ x 2))) (* 25 x)) 75) 0)
Factorized normal form: (3 + 1x)(-25 + 0x + 1x^2)
Solutions: { x = -3, x = 5, x = -5 }
Equation: x ^ 3 - 91 * x - 90 = 0
Parsed: (- (- (- (^ x 3) (* 91 x)) 90) 0)
Factorized normal form: (-90 + -91x + 0x^2 + 1x^3)
Guessing rational solutions: 10, -9, -1
Verified guessed solutions!
Solutions: { x = 10, x = -9, x = -1 }
## Notes ##
This uses an egraph with rewrite rules representing the usual rules of arithmetic, and constant folding with rational numbers. This alone is enough to solve linear equations, by just starting with the difference of the two sides and minimizing AST size. However, it can't suffice for quadratic equations, as the egraph can't create values "out of this air". For example, to solve `x^2 = 2` the egraph must at least have "seen" the expression `2^(1/2)` once.
So instead, the egraph is used to bring the expression into a standard form: a product of monic polynomials which is factored as much as possible. Since this is difficult to model with a cost function, I wrote a custom extractor (in `normal_form.rs`).
Then we can use the usual quadratic formula to find solutions. The obtained solutions are afterwards simplified by another egraph, with a special "square root of a square integer" rule. This is pretty minimal, a more complete solution would also be able to deal with square roots of rational numbers etc.
Finally, the unfactorized cubic equation (equation 7) gets factored into a linear and a quadratic factor by the egraph. This is kind of "magic" though: the correct terms seem to come up randomly while mutating the egraph. For general cubics we can't expect this to happen, even if they have integer solutions. The additional equation 8 is an example for this.
While there is a general formula for the solutions of a cubic equation, it contains trigonometric functions and/or complex numbers, and we don't want to deal with that here. So we use a different strategy which can handle cubic equations with 3 rational solutions: we first solve the equation numerically using the cubic formula, then find rational expressions for the numerical solutions using continued fractions. Finally, we verify the resulting factorization by using an egraph. This equivalence check required increased node and iteration limits, and takes a few seconds on my laptop.
This has essentially only been tested on the 8 examples above. So I would expect there to still be a few bugs. The algorithm is also spectacularly inefficient, considering that the equality of rational polynomials is almost trivial to compute.

91
src/cubic.rs Normal file
View File

@@ -0,0 +1,91 @@
use std::f64::consts::PI;
use crate::language::Rational;
pub fn approximate_rational_cubic(b: &Rational, c: &Rational, d: &Rational, limit: u64) -> Vec<Rational> {
let numerical_sols = solve_cubic_numerically(
1.0,
(b.num as f64) / (b.denom as f64),
(c.num as f64) / (c.denom as f64),
(d.num as f64) / (d.denom as f64));
numerical_sols.into_iter().map(|x|rational_approx(x, limit)).collect()
}
// assuming leading coefficient is not 0, and also that we don't have a triple root
pub fn solve_cubic_numerically(a: f64, b: f64, c: f64, d: f64) -> Vec<f64> {
assert_ne!(a, 0.0);
let b = b/a;
let c = c/a;
let d = d/a;
let b2 = b * b;
let b3 = b2 * b;
solve_depressed_cubic_numerically(
c - b2 / 3.0,
2.0 * b3 / 27.0 - b * c / 3.0 + d
).into_iter().map(|u|u - b / 3.0 / a).collect()
}
fn solve_depressed_cubic_numerically(p: f64, q: f64) -> Vec<f64> {
let disc = p * p * p / 27.0 + q * q / 4.0;
if disc < 0.0 {
let r = 2.0 * (-p / 3.0).sqrt();
let t = 3.0 * q / 2.0 / p * (- 3.0 / p).sqrt();
let phi = t.acos() / 3.0;
vec![r * (phi).cos(),
r * (phi + 2.0 * PI / 3.0).cos(),
r * (phi + 4.0 * PI / 3.0).cos()]
} else {
// not implemented at the moment
let u = - q / 2.0 + disc.sqrt();
let v = - q / 2.0 - disc.sqrt();
vec![u.powf(1.0/3.0) + v.powf(1.0/3.0)]
}
}
// pretty inefficient implementation of continued fraction approximation
// but doesn't matter here
pub fn rational_approx(x: f64, limit: u64) -> Rational {
if x < 0.0 {
let Rational { num, denom } = rational_approx(-x, limit);
return Rational { num: -num, denom }
}
let mut num = 0;
let mut denom = 0;
for l in 0 .. 10 {
let Some((p,q)) = rational_approx_level(x, l) else { break; };
if q > limit {
break;
}
(num, denom) = (p, q);
}
Rational{
num: num as i64,
denom
}
}
fn rational_approx_level(x: f64, level: usize) -> Option<(u64, u64)> {
// we expect very big numbers or infinity if the previous iteration was exact
if x > 1e9 {
return None;
}
if level == 0 {
Some((x as u64, 1))
} else {
let (p,q) = rational_approx_level(1.0 / (x - x.floor()), level-1)?;
let floorx = x as u64;
Some((q + floorx * p, p))
}
}

View File

@@ -1,256 +0,0 @@
use core::fmt;
use std::{cmp::Ordering, fmt::Display};
use crate::language::{EquationLanguage, Rational, RATIONAL_ONE, RATIONAL_ZERO};
use egg::{Id, RecExpr};
#[derive(Debug,Clone,Copy)]
pub struct PolyStat {
degree: usize,
factors: usize, // non-constant factors
ops: usize,
monomial: bool,
sum_of_monomials: bool,
monic: bool,
factorized: bool, // a product of monic polynomials and at least one constant
}
#[derive(Debug,Clone,Copy)]
pub enum FactorizationCost {
UnwantedOps,
Polynomial(PolyStat)
}
fn score(cost: FactorizationCost) -> usize {
match cost {
FactorizationCost::UnwantedOps => 10000,
FactorizationCost::Polynomial(p) =>
if !p.factorized {
1000 + p.ops
} else {
100 * (9 - p.factors) + p.ops
},
}
}
impl PartialEq for FactorizationCost {
fn eq(&self, other: &Self) -> bool {
score(*self) == score(*other)
}
}
impl PartialOrd for FactorizationCost {
fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
usize::partial_cmp(&score(*self), &score(*other))
}
}
pub struct FactorizationCostFn;
impl egg::CostFunction<EquationLanguage> for FactorizationCostFn {
type Cost = FactorizationCost;
fn cost<C>(&mut self, enode: &EquationLanguage, mut costs: C) -> Self::Cost
where
C: FnMut(Id) -> Self::Cost,
{
match enode {
EquationLanguage::Add([a,b]) => {
match (costs(*a), costs(*b)) {
(FactorizationCost::Polynomial(p1),FactorizationCost::Polynomial(p2)) => {
// we only ever want to add monomials
let result_monic = if p1.degree > p2.degree {
p1.monic
} else if p2.degree > p1.degree {
p2.monic
} else {
false
};
/*
if *a == Id::from(4) && *b == Id::from(19) {
println!("HERE {:?} {:?}", p1, p2);
}
*/
if !p1.sum_of_monomials || !p2.sum_of_monomials {
FactorizationCost::UnwantedOps
} else {
FactorizationCost::Polynomial(PolyStat {
degree: usize::max(p1.degree, p2.degree),
factors: 1,
ops: p1.ops + p2.ops + 1,
monomial: false,
sum_of_monomials: p1.sum_of_monomials && p2.sum_of_monomials,
monic: result_monic,
factorized: result_monic,
})
}
},
_ => FactorizationCost::UnwantedOps
}
},
EquationLanguage::Mul([a,b]) => {
match (costs(*a), costs(*b)) {
(FactorizationCost::Polynomial(p1), FactorizationCost::Polynomial(p2)) => {
FactorizationCost::Polynomial(PolyStat {
degree: p1.degree + p2.degree,
factors: p1.factors + p2.factors,
ops: p1.ops + p2.ops + 1,
monomial: p1.monomial && p2.monomial,
sum_of_monomials: p1.monomial && p2.monomial,
monic: p1.monic && p2.monic,
factorized: (p1.monic && p2.factorized) || (p2.monic && p1.factorized)
})
},
_ => FactorizationCost::UnwantedOps
}
},
EquationLanguage::Num(c) => {
FactorizationCost::Polynomial(PolyStat {
degree: 0,
factors: 0,
ops: 0,
monomial: true,
sum_of_monomials: true,
monic: false,
factorized: true
})
},
EquationLanguage::Unknown => {
FactorizationCost::Polynomial(PolyStat {
degree: 1,
factors: 1,
ops: 0,
monomial: true,
sum_of_monomials: true,
monic: true,
factorized: true
})
},
_ => FactorizationCost::UnwantedOps,
}
}
}
#[derive(Debug,Clone)]
pub struct Factorization {
pub constant_factor: Rational,
pub polynomials: Vec<Vec<Rational>>,
}
impl Display for Factorization {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
if self.constant_factor != RATIONAL_ONE {
write!(f, "{}", self.constant_factor)?;
}
for poly in &self.polynomials {
write!(f, "(")?;
for (deg, coeff) in poly.iter().enumerate() {
if deg == 0 {
write!(f, "{}", coeff)?;
} else if deg == 1 {
write!(f, " + {}x", coeff)?;
} else {
write!(f, " + {}x^{}", coeff, deg)?;
}
}
write!(f, ")")?;
}
Ok(())
}
}
pub fn extract_factorization(expr: &RecExpr<EquationLanguage>) -> Factorization {
let root_id: Id = Id::from(expr.as_ref().len()-1);
let mut constant_factor: Option<Rational> = None;
let mut factors: Vec<Vec<Rational>> = Vec::new();
let mut todo: Vec<Id> = Vec::new();
todo.push(root_id);
while todo.len() > 0 {
let id = todo.pop().unwrap();
match &expr[id] {
EquationLanguage::Mul([a,b]) => {
todo.push(*a);
todo.push(*b);
},
EquationLanguage::Num(x) => {
assert!(constant_factor.is_none());
constant_factor = Some(x.clone());
},
_ => {
factors.push(extract_polynomial(expr, id));
}
}
}
Factorization {
constant_factor: constant_factor.unwrap_or_else(||RATIONAL_ONE.clone()),
polynomials: factors
}
}
fn extract_polynomial(expr: &RecExpr<EquationLanguage>, id: Id) -> Vec<Rational> {
let mut result: Vec<Rational> = Vec::new();
let mut todo: Vec<Id> = Vec::new();
todo.push(id);
while todo.len() > 0 {
let id = todo.pop().unwrap();
match &expr[id] {
EquationLanguage::Add([a,b]) => {
todo.push(*a);
todo.push(*b);
},
_ => {
let (deg, coeff) = extract_monomial(expr, id);
result.resize(result.len().max(deg), RATIONAL_ZERO.clone());
if result.len() <= deg {
result.push(coeff);
} else {
assert!(result[deg] == RATIONAL_ZERO);
result[deg] = coeff;
}
}
}
}
result
}
fn extract_monomial(expr: &RecExpr<EquationLanguage>, id: Id) -> (usize, Rational) {
let mut coeff: Option<Rational> = None;
let mut deg: usize = 0;
let mut todo: Vec<Id> = Vec::new();
todo.push(id);
while todo.len() > 0 {
let id = todo.pop().unwrap();
match &expr[id] {
EquationLanguage::Unknown => {
deg += 1;
},
EquationLanguage::Mul([a,b]) => {
todo.push(*a);
todo.push(*b);
},
EquationLanguage::Num(x) => {
assert!(coeff.is_none());
coeff = Some(x.clone());
},
_ => {
panic!("Not a rational polynomial in normal form!");
}
}
}
(deg, coeff.unwrap_or_else(||RATIONAL_ONE.clone()))
}

View File

@@ -1,5 +1,5 @@
use std::{cmp::Ordering, fmt::{self,Display, Formatter}, str::FromStr, sync::LazyLock};
use egg::{define_language, merge_option, rewrite as rw, Analysis, DidMerge, Id, Language, Subst, Var};
use egg::{define_language, merge_option, rewrite as rw, Analysis, Applier, DidMerge, Id, Language, PatternAst, Subst, Symbol, Var};
pub type EGraph = egg::EGraph<EquationLanguage, ConstantFold>;
pub type Rewrite = egg::Rewrite<EquationLanguage, ConstantFold>;
@@ -39,9 +39,19 @@ impl Display for Rational {
}
impl FromStr for Rational {
type Err = std::num::ParseIntError;
fn from_str(s: &str) -> Result<Self, Self::Err> {
Ok(Rational { num: s.parse::<i64>()?, denom: 1 })
type Err = String;
fn from_str(s: &str) -> Result<Self, String> {
let err = || Err(format!("Couldn't parse rational: {}", s));
if let Ok(num) = s.parse::<i64>() {
Ok(Rational { num, denom: 1 })
} else if let Some((snum, sdenom)) = s.split_once('/') {
let Ok(num) = snum.parse::<i64>() else { return err(); };
let Ok(denom) = sdenom.parse::<u64>() else { return err(); };
Ok(Rational { num, denom })
} else {
err()
}
}
}
@@ -174,6 +184,39 @@ fn is_const(var: &str) -> impl Fn(&mut EGraph, Id, &Subst) -> bool {
}
}
struct IntegerSqrt {
var: Var,
}
impl Applier<EquationLanguage, ConstantFold> for IntegerSqrt {
fn apply_one(&self,
egraph: &mut EGraph,
matched_id: Id,
subst: &Subst,
_searcher_pattern: Option<&PatternAst<EquationLanguage>>,
_rule_name: Symbol)
-> Vec<Id> {
let var_id = subst[self.var];
if let Some(value) = &egraph[var_id].data {
if value.denom == 1 && value.num >= 0 {
// isqrt is nightly only, so we just do this, adding 0.1 against rounding errors
let sq = (f64::sqrt(value.num as f64) + 0.1) as i64;
if value.num == sq*sq {
// println!("square root of integer {} is {}", value.num, sq);
let sq_id = egraph.add(EquationLanguage::Num(Rational { num: sq, denom: 1 }));
egraph.union(matched_id, sq_id);
return vec![matched_id, sq_id];
}
}
}
vec![]
}
}
pub static RULES: LazyLock<Vec<Rewrite>> = LazyLock::new(||vec![
rw!("commute-add"; "(+ ?x ?y)" => "(+ ?y ?x)"),
rw!("commute-mul"; "(* ?x ?y)" => "(* ?y ?x)"),
@@ -185,6 +228,8 @@ pub static RULES: LazyLock<Vec<Rewrite>> = LazyLock::new(||vec![
rw!("mul-0"; "(* ?x 0)" => "0"),
rw!("mul-1"; "(* ?x 1)" => "?x"),
rw!("0-sub"; "(- 0 ?x)" => "(- ?x)"),
rw!("add-sub"; "(+ ?x (* (-1) ?x))" => "0"),
// division by zero shouldn't happen unless input is invalid
rw!("mul-div"; "(* ?x (rec ?x))" => "1" if is_nonzero_const("?y")),
@@ -202,7 +247,7 @@ pub static RULES: LazyLock<Vec<Rewrite>> = LazyLock::new(||vec![
rw!("factor_poly"; "(+ (* x ?x) ?y)" => "(* ?x (+ x (* ?y (rec ?x))))" if is_nonzero_const("?x")),
// rw!("integer_sqrt"; "(^ ?x (1/2))" => {} if is_const("?x")),
rw!("integer_sqrt"; "(^ ?x (/ 1 2))" => { IntegerSqrt { var: "?x".parse().unwrap() } } if is_const("?x")),
]);
pub struct PlusTimesCostFn;

View File

@@ -1,5 +1,5 @@
pub mod language;
pub mod normal_form;
pub mod parse;
pub mod factorization;
pub mod output;
pub mod cubic;

View File

@@ -1,6 +1,6 @@
use egg::{AstSize, Extractor, Id, Pattern, PatternAst, RecExpr, Runner, Searcher};
use solveq::factorization::{FactorizationCost, FactorizationCostFn};
use solveq::language::{ConstantFold, EquationLanguage, Rational, RULES, EGraph};
use egg::{AstSize, Extractor, Id, RecExpr, Runner};
use solveq::cubic::approximate_rational_cubic;
use solveq::language::{EquationLanguage, Rational, RULES, EGraph};
use solveq::normal_form::extract_normal_form;
use solveq::parse::parse_equation;
use solveq::output::print_term;
@@ -13,96 +13,136 @@ static TEST_EQUATIONS: &[&str] = &[
"x ^ 2 = 2 * x + 15",
"(x ^ 2 - 2 * x - 15) * (x + 5) = 0",
"x ^ 3 + 3 * x ^ 2 - 25 * x - 75 = 0",
"x ^ 3 - 91 * x - 90 = 0",
];
fn main() {
// let expr: RecExpr<EquationLanguage> = "(+ (* x (+ x -2)) -15)".parse().unwrap();
// let expr: RecExpr<EquationLanguage> = "(* (+ (* x (+ x -2)) -15) (+ x 5))".parse().unwrap();
// println!("{:?}", get_expression_cost(&expr));
for eq in TEST_EQUATIONS {
println!("Equation: {}", *eq);
let mut start = parse_equation(*eq).unwrap();
let root_id = Id::from(start.as_ref().len()-1);
let EquationLanguage::Equals([left, right]) = start[root_id]
else { panic!("Not an equation without an equals sign!"); };
start[root_id] = EquationLanguage::Sub([left, right]);
let solutions = solve(*eq, true);
println!("Parsed: {}", &start);
let solutions_str: Vec<String> = solutions
.iter()
.map(|expr| format!("x = {}", print_term(expr)))
.collect();
println!("Solutions: {{ {} }}", solutions_str.join(", "));
let mut runner = Runner::default()
.with_explanations_enabled()
.with_expr(&start)
.run(&*RULES);
let extractor = Extractor::new(&runner.egraph, FactorizationCostFn);
let (best_cost, best_expr) = extractor.find_best(runner.roots[0]);
// println!("Best expresssion: {} {:?}", best_expr, best_cost);
// println!("{}", runner.explain_equivalence(&start, &best_expr).get_flat_string());
let Some(factorization) = extract_normal_form(&runner.egraph, runner.roots[0]) else {
panic!("Couldn't factorize polynomial!");
};
println!("Factorized normal form: {}", factorization);
/*
get_expression_cost("(* x (+ x -2))", &runner.egraph);
get_expression_cost("-15", &runner.egraph);
get_expression_cost("(+ (* x (+ x -2)) -15)", &runner.egraph);
*/
// let factorization = extract_factorization(&best_expr);
let mut solutions: Vec<String> = Vec::new();
for poly in &factorization.polynomials {
if poly.len() == 2 { // linear factor
let Rational { num, denom } = &poly[0];
solutions.push(format!("x = {}", Rational { num: -*num, denom: *denom }));
} else if poly.len() == 3 { // quadratic factor
let Rational { num: num0, denom: denom0 } = &poly[0];
let Rational { num: num1, denom: denom1 } = &poly[1];
let sol1 = format!("- ({num1})/(2 * ({denom1})) + ((({num1})/(2 * ({denom1}))) ^ 2 - ({num0}) / ({denom0})) ^ (1/2)");
let sol2 = format!("- ({num1})/(2 * ({denom1})) - ((({num1})/(2 * ({denom1}))) ^ 2 - ({num0}) / ({denom0})) ^ (1/2)");
let expr = parse_equation(&sol1).unwrap();
let runner = Runner::default()
.with_expr(&expr)
.run(&*RULES);
let extractor = Extractor::new(&runner.egraph, AstSize);
let (_, simplified_expr) = extractor.find_best(runner.roots[0]);
solutions.push(format!("x = {}", print_term(&simplified_expr)));
let expr = parse_equation(&sol2).unwrap();
let runner = Runner::default()
.with_expr(&expr)
.run(&*RULES);
let extractor = Extractor::new(&runner.egraph, AstSize);
let (_, simplified_expr) = extractor.find_best(runner.roots[0]);
solutions.push(format!("x = {}", print_term(&simplified_expr)));
}
}
println!("Solutions: {{ {} }}", solutions.join(", "));
println!("");
}
}
fn get_expression_cost(expr: &str, egraph: &EGraph) {
// let mut egraph = EGraph::new(ConstantFold::default());
// let id = egraph.add_expr(expr);
let pattern: Pattern<EquationLanguage> = expr.parse().unwrap();
let matches = pattern.search(egraph);
for m in matches {
let extractor = Extractor::new(&egraph, FactorizationCostFn);
let (cost, _) = extractor.find_best(m.eclass);
pub fn solve(eq: &str, verbose: bool) -> Vec<RecExpr<EquationLanguage>> {
let mut start = parse_equation(eq).unwrap();
let root_id = Id::from(start.as_ref().len()-1);
let EquationLanguage::Equals([left, right]) = start[root_id]
else { panic!("Not an equation without an equals sign!"); };
start[root_id] = EquationLanguage::Sub([left, right]);
println!("expr: {}, id: {}, cost: {:?}", expr, m.eclass, cost);
if verbose {
println!("Parsed: {}", &start);
}
// cost
let runner = Runner::default()
.with_explanations_enabled()
.with_expr(&start)
.run(&*RULES);
let Some(factorization) = extract_normal_form(&runner.egraph, runner.roots[0]) else {
panic!("Couldn't factorize polynomial!");
};
if verbose {
println!("Factorized normal form: {}", factorization);
}
let mut solutions: Vec<RecExpr<EquationLanguage>> = Vec::new();
for poly in &factorization.polynomials {
if poly.len() == 2 { // linear factor
let Rational { num, denom } = &poly[0];
let mut solexpr = RecExpr::default();
solexpr.add(EquationLanguage::Num(Rational {
num: -*num,
denom: *denom
}));
solutions.push(solexpr);
} else if poly.len() == 3 { // quadratic factor
let Rational { num: num0, denom: denom0 } = &poly[0];
let Rational { num: num1, denom: denom1 } = &poly[1];
let sol1 = format!("- ({num1})/(2 * ({denom1})) + ((({num1})/(2 * ({denom1}))) ^ 2 - ({num0}) / ({denom0})) ^ (1/2)");
let sol2 = format!("- ({num1})/(2 * ({denom1})) - ((({num1})/(2 * ({denom1}))) ^ 2 - ({num0}) / ({denom0})) ^ (1/2)");
let expr = parse_equation(&sol1).unwrap();
let runner = Runner::default()
.with_expr(&expr)
.run(&*RULES);
let extractor = Extractor::new(&runner.egraph, AstSize);
let (_, simplified_expr) = extractor.find_best(runner.roots[0]);
solutions.push(simplified_expr);
let expr = parse_equation(&sol2).unwrap();
let runner = Runner::default()
.with_expr(&expr)
.run(&*RULES);
let extractor = Extractor::new(&runner.egraph, AstSize);
let (_, simplified_expr) = extractor.find_best(runner.roots[0]);
solutions.push(simplified_expr);
} else if poly.len() == 4 {
let guesses = approximate_rational_cubic(
&Rational { num: 0, denom: 1 },
&Rational { num: -91, denom: 1 },
&Rational { num: -90, denom: 1 },
1000
);
if guesses.len() < 3 {
continue;
}
if verbose {
println!("Guessing rational solutions: {}, {}, {}",
guesses[0], guesses[1], guesses[2]);
}
let start = format!("x ^ 3 + ({}) * x ^ 2 + ({}) * x + ({})",
poly[2], poly[1], poly[0]);
let goal = format!("(x - ({})) * (x - ({})) * (x - ({}))",
guesses[0], guesses[1], guesses[2]);
let start_expr = parse_equation(&start).unwrap();
let goal_expr = parse_equation(&goal).unwrap();
let mut egraph: EGraph = Default::default();
egraph.add_expr(&start_expr);
egraph.add_expr(&goal_expr);
let runner = Runner::default()
.with_egraph(egraph)
.with_node_limit(100_000)
.with_iter_limit(100)
.run(&*RULES);
let equivs = runner.egraph.equivs(&start_expr, &goal_expr);
if !equivs.is_empty() {
if verbose {
println!("Verified guessed solutions!");
}
// expressions are equivalent, so the guesses are actually solutions
for s in &guesses {
let mut solexpr = RecExpr::default();
solexpr.add(EquationLanguage::Num(s.clone()));
solutions.push(solexpr);
}
} else {
if verbose {
println!("Couldn't verify guessed solutions.");
}
}
}
}
solutions
}

View File

@@ -1,16 +1,6 @@
use crate::language::{EGraph, EquationLanguage, Rational, RATIONAL_ONE, RATIONAL_ZERO};
use std::{collections::HashMap, fmt};
use egg::{AstSize, Extractor, Id};
#[derive(Debug,Clone)]
pub enum SpecialTerm {
Constant(Rational),
PowerOfX(usize),
Monomial(usize, Rational),
MonicNonconstPoly(Vec<Rational>),
Factorization(Rational, Vec<Vec<Rational>>),
Other,
}
use egg::Id;
#[derive(Debug,Clone)]
pub struct Factorization {
@@ -18,7 +8,32 @@ pub struct Factorization {
pub polynomials: Vec<Vec<Rational>>,
}
// this is a property of an eclass, not a particular expression
impl fmt::Display for Factorization {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
if self.constant_factor != RATIONAL_ONE {
write!(f, "{}", self.constant_factor)?;
}
for poly in &self.polynomials {
write!(f, "(")?;
for (deg, coeff) in poly.iter().enumerate() {
if deg == 0 {
write!(f, "{}", coeff)?;
} else if deg == 1 {
write!(f, " + {}x", coeff)?;
} else {
write!(f, " + {}x^{}", coeff, deg)?;
}
}
write!(f, ")")?;
}
Ok(())
}
}
// warning: this is a property of an eclass, not a particular expression
// so we shouldn't use anything that's not a polynomial invariant
#[derive(Debug,Clone)]
struct PolyStats {
degree: usize,
@@ -126,7 +141,7 @@ fn extract_polynomial(egraph: &EGraph, stats: &HashMap<Id, PolyStats>, id: Id) -
assert!(leading_deg >= remainder.len());
remainder.resize(leading_deg, RATIONAL_ZERO.clone());
remainder.resize(leading_deg, RATIONAL_ZERO);
remainder.push(leading_coeff);
return Some(remainder);
}
@@ -141,10 +156,6 @@ fn extract_polynomial(egraph: &EGraph, stats: &HashMap<Id, PolyStats>, id: Id) -
// a monomial is either a power of x, a constant, or a product of a constant and power of x
fn extract_monomial(egraph: &EGraph, stats: &HashMap<Id, PolyStats>, id: Id) -> Option<(usize, Rational)> {
let extractor = Extractor::new(egraph, AstSize);
let (_, expr) = extractor.find_best(id);
// println!("Extract Monomial: {}", expr);
let st = &stats[&id];
assert!(st.monomial);
@@ -221,7 +232,6 @@ fn find_monic_factorization(egraph: &EGraph, stats: &HashMap<Id, PolyStats>, id:
// now the whole thing is a monic nonconst poly, so would be a valid factorization,
// but we want to go as deep as possible
// println!("{:?}", stats[&id]);
// check if it is the product of two nonconstant monic polynomials
for node in &egraph[id].nodes {
@@ -260,8 +270,6 @@ where
let mut result: HashMap<Id, T> = HashMap::new();
let mut modifications: usize = 1;
// println!("{:?}", egraph[canonical]);
while modifications > 0 {
modifications = 0;
@@ -285,215 +293,3 @@ where
result
}
/*
pub fn analyze3(egraph: &EGraph, eclass: Id) {
let constants = search_for(egraph, |id, _, _|
egraph[id].data.as_ref().map(|c|c.clone())
);
println!("{:?}", constants);
let powers_of_x = search_for(egraph, |_, node, matches| match *node {
EquationLanguage::Unknown => Some(1),
EquationLanguage::Mul([a,b]) => {
if !matches.contains_key(&a) || !matches.contains_key(&b) {
return None;
}
let (dega, degb) = (matches[&a], matches[&b]);
Some(dega + degb)
},
_ => None,
});
println!("{:?}", powers_of_x);
let monomials = search_for(egraph, |id, node, _| {
if let Some(deg) = powers_of_x.get(&id) {
return Some((*deg, RATIONAL_ONE.clone()));
}
if let Some(c) = constants.get(&id) {
return Some((0, c.clone()));
}
match *node {
EquationLanguage::Mul([a,b]) => {
if !constants.contains_key(&a) || !powers_of_x.contains_key(&b) {
return None;
}
let (coeff, deg) = (&constants[&a], powers_of_x[&b]);
Some((deg, coeff.clone()))
},
_ => None,
}
});
println!("{:?}", monomials);
let monic_polynomials = search_for(egraph, |id, node, matches| {
if let Some(deg) = powers_of_x.get(&id) {
let mut poly: Vec<Rational> = Vec::new();
poly.resize(*deg, RATIONAL_ZERO);
poly.push(RATIONAL_ONE.clone());
Some(poly)
} else {
match *node {
EquationLanguage::Add([a,b]) => {
if !matches.contains_key(&a) || !monomials.contains_key(&b) {
return None;
}
let (leading, (deg, coeff)) = (&matches[&a], &monomials[&b]);
if leading.len() <= *deg || leading[*deg] != RATIONAL_ZERO {
return None;
}
let mut poly = leading.clone();
poly[*deg] = coeff.clone();
Some(poly)
},
_ => None,
}
}
});
for p in &monic_polynomials {
println!("{:?}", p);
}
let factorizations: HashMap<Id, (Rational, Vec<Vec<Rational>>)> = search_for(egraph, |id, node, matches| {
if let Some(c) = constants.get(&id) {
return Some((c.clone(), vec![]));
}
if let Some(poly) = monic_polynomials.get(&id) {
return Some((RATIONAL_ONE.clone(), vec![poly.clone()]));
}
match *node {
EquationLanguage::Mul([a,b]) => {
if !matches.contains_key(&a) || !monic_polynomials.contains_key(&b) {
return None;
}
let ((factor, polys), newpoly) = (&matches[&a], &monic_polynomials[&b]);
let mut combined: Vec<Vec<Rational>> = polys.clone();
combined.push(newpoly.clone());
Some((factor.clone(), combined))
},
_ => None,
}
});
/*
for p in &factorizations {
println!("{:?}", p);
}
*/
println!("{:?}", factorizations[&eclass]);
}
pub fn analyze2(egraph: &EGraph) -> HashMap<Id, SpecialTerm> {
let mut types: HashMap<Id, SpecialTerm> = HashMap::new();
let mut modifications: usize = 1;
while modifications > 0 {
modifications = 0;
for cls in egraph.classes() {
let id = cls.id;
if types.contains_key(&id) {
continue;
}
if let Some(c) = &egraph[id].data {
types.insert(id, SpecialTerm::Constant(c.clone()));
modifications += 1;
continue;
}
for node in &cls.nodes {
match *node {
EquationLanguage::Unknown => {
types.insert(id, SpecialTerm::PowerOfX(1));
modifications += 1;
},
EquationLanguage::Mul([a,b]) => {
// as we don't know a and b yet, defer to future iteration
if !types.contains_key(&a) || !types.contains_key(&b) {
continue;
}
match (&types[&a], &types[&b]) {
(SpecialTerm::PowerOfX(dega), SpecialTerm::PowerOfX(degb)) => {
types.insert(id, SpecialTerm::PowerOfX(*dega + *degb));
modifications += 1;
},
(SpecialTerm::Constant(coeff), SpecialTerm::PowerOfX(deg)) => {
types.insert(id, SpecialTerm::Monomial(*deg, coeff.clone()));
modifications += 1;
},
_ => { },
}
},
EquationLanguage::Add([a,b]) => {
// as we don't know a and b yet, defer to future iteration
if !types.contains_key(&a) || !types.contains_key(&b) {
continue;
}
match (&types[&a], &types[&b]) {
(SpecialTerm::MonicNonconstPoly(poly), SpecialTerm::Monomial(deg, coeff)) => {
if poly.len() <= *deg || poly[*deg] != RATIONAL_ZERO {
continue;
}
let mut poly = poly.clone();
poly[*deg] = coeff.clone();
types.insert(id, SpecialTerm::MonicNonconstPoly(poly));
modifications += 1;
},
_ => { },
}
},
_ => {}
}
}
}
println!("{} modifications!", modifications);
}
types
}
*/
impl fmt::Display for Factorization {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
if self.constant_factor != RATIONAL_ONE {
write!(f, "{}", self.constant_factor)?;
}
for poly in &self.polynomials {
write!(f, "(")?;
for (deg, coeff) in poly.iter().enumerate() {
if deg == 0 {
write!(f, "{}", coeff)?;
} else if deg == 1 {
write!(f, " + {}x", coeff)?;
} else {
write!(f, " + {}x^{}", coeff, deg)?;
}
}
write!(f, ")")?;
}
Ok(())
}
}

View File

@@ -1,5 +1,5 @@
use std::error::Error;
use egg::*;
use egg::{RecExpr, Id, FromOp, FromOpError};
use crate::language::EquationLanguage;
pub fn parse_equation(input: &str) -> Result<RecExpr<EquationLanguage>, ParseError> {