add cubic equation solver

This commit is contained in:
Florian Stecker 2024-08-28 15:19:03 -04:00
parent 0b8bdf3da6
commit 74ffaea4c0
3 changed files with 139 additions and 1 deletions

82
src/cubic.rs Normal file
View File

@ -0,0 +1,82 @@
use std::f64::consts::PI;
use crate::language::Rational;
pub fn approximate_rational_cubic(b: &Rational, c: &Rational, d: &Rational, limit: u64) -> Vec<Rational> {
let numerical_sols = solve_cubic_numerically(
1.0,
(b.num as f64) / (b.denom as f64),
(c.num as f64) / (c.denom as f64),
(d.num as f64) / (d.denom as f64));
numerical_sols.into_iter().map(|x|rational_approx(x, limit)).collect()
}
// assuming leading coefficient is not 0
pub fn solve_cubic_numerically(a: f64, b: f64, c: f64, d: f64) -> Vec<f64> {
assert_ne!(a, 0.0);
let b = b/a;
let c = c/a;
let d = d/a;
let b2 = b * b;
let b3 = b2 * b;
solve_depressed_cubic_numerically(
c - b2 / 3.0,
2.0 * b3 / 27.0 - b * c / 3.0 + d
).into_iter().map(|u|u - b / 3.0 / a).collect()
}
fn solve_depressed_cubic_numerically(p: f64, q: f64) -> Vec<f64> {
let disc = 4.0 * p * p * p + 27.0 * q * q;
if disc < 0.0 {
let r = 2.0 * (-p / 3.0).sqrt();
let t = 3.0 * q / 2.0 / p * (- 3.0 / p).sqrt();
let phi = t.acos() / 3.0;
vec![r * (phi).cos(),
r * (phi + 2.0 * PI / 3.0).cos(),
r * (phi + 4.0 * PI / 3.0).cos()]
} else {
// not implemented at the moment
vec![]
}
}
pub fn rational_approx(x: f64, limit: u64) -> Rational {
if x < 0.0 {
let Rational { num, denom } = rational_approx(-x, limit);
return Rational { num: -num, denom }
}
let mut num = 0;
let mut denom = 0;
for l in 0 .. 10 {
let (p,q) = rational_approx_level(x, l);
if q > limit {
break;
}
(num, denom) = (p, q);
}
Rational{
num: num as i64,
denom: denom
}
}
fn rational_approx_level(x: f64, level: usize) -> (u64, u64) {
if level == 0 {
(x as u64, 1)
} else {
let (p,q) = rational_approx_level(1.0 / (x - x.floor()), level-1);
let floorx = x as u64;
(q + floorx * p, p)
}
}

View File

@ -2,3 +2,4 @@ pub mod language;
pub mod normal_form;
pub mod parse;
pub mod output;
pub mod cubic;

View File

@ -1,5 +1,6 @@
use egg::{AstSize, Extractor, Id, RecExpr, Runner};
use solveq::language::{EquationLanguage, Rational, RULES};
use solveq::cubic::approximate_rational_cubic;
use solveq::language::{EquationLanguage, Rational, RULES, EGraph};
use solveq::normal_form::extract_normal_form;
use solveq::parse::parse_equation;
use solveq::output::print_term;
@ -12,6 +13,7 @@ static TEST_EQUATIONS: &[&str] = &[
"x ^ 2 = 2 * x + 15",
"(x ^ 2 - 2 * x - 15) * (x + 5) = 0",
"x ^ 3 + 3 * x ^ 2 - 25 * x - 75 = 0",
"x ^ 3 - 91 * x - 90 = 0",
];
fn main() {
@ -86,6 +88,59 @@ pub fn solve(eq: &str, verbose: bool) -> Vec<RecExpr<EquationLanguage>> {
let extractor = Extractor::new(&runner.egraph, AstSize);
let (_, simplified_expr) = extractor.find_best(runner.roots[0]);
solutions.push(simplified_expr);
} else if poly.len() == 4 {
let guesses = approximate_rational_cubic(
&Rational { num: 0, denom: 1 },
&Rational { num: -91, denom: 1 },
&Rational { num: -90, denom: 1 },
1000
);
if guesses.len() < 3 {
continue;
}
if verbose {
println!("Guessing rational solutions: {}, {}, {}",
guesses[0], guesses[1], guesses[2]);
}
let start = format!("x ^ 3 + ({}) * x ^ 2 + ({}) * x + ({})",
poly[2], poly[1], poly[0]);
let goal = format!("(x - ({})) * (x - ({})) * (x - ({}))",
guesses[0], guesses[1], guesses[2]);
let start_expr = parse_equation(&start).unwrap();
let goal_expr = parse_equation(&goal).unwrap();
let mut egraph: EGraph = Default::default();
egraph.add_expr(&start_expr);
egraph.add_expr(&goal_expr);
let runner = Runner::default()
.with_egraph(egraph)
.with_node_limit(100_000)
.with_iter_limit(100)
.run(&*RULES);
let equivs = runner.egraph.equivs(&start_expr, &goal_expr);
if !equivs.is_empty() {
if verbose {
println!("Verified guessed solutions!");
}
// expressions are equivalent, so the guesses are actually solutions
for s in &guesses {
let mut solexpr = RecExpr::default();
solexpr.add(EquationLanguage::Num(s.clone()));
solutions.push(solexpr);
}
} else {
if verbose {
println!("Couldn't verify guessed solutions.");
}
}
}
}