broken state

This commit is contained in:
Florian Stecker 2024-08-27 23:24:51 -04:00
parent 68b6293028
commit d0265ea340
5 changed files with 384 additions and 199 deletions

249
src/factorization.rs Normal file
View File

@ -0,0 +1,249 @@
use core::fmt;
use std::{cmp::Ordering, fmt::Display};
use crate::language::{EquationLanguage, Rational, RATIONAL_ONE, RATIONAL_ZERO};
use egg::{Id, RecExpr};
#[derive(Debug,Clone,Copy)]
pub struct PolyStat {
degree: usize,
factors: usize, // non-constant factors
ops: usize,
monomial: bool,
sum_of_monomials: bool,
monic: bool,
factorized: bool, // a product of monic polynomials and at least one constant
}
#[derive(Debug,Clone,Copy)]
pub enum FactorizationCost {
UnwantedOps,
Polynomial(PolyStat)
}
fn score(cost: FactorizationCost) -> usize {
match cost {
FactorizationCost::UnwantedOps => 10000,
FactorizationCost::Polynomial(p) =>
if !p.factorized {
1000 + p.ops
} else {
100 * (9 - p.factors) + p.ops
},
}
}
impl PartialEq for FactorizationCost {
fn eq(&self, other: &Self) -> bool {
score(*self) == score(*other)
}
}
impl PartialOrd for FactorizationCost {
fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
usize::partial_cmp(&score(*self), &score(*other))
}
}
pub struct FactorizationCostFn;
impl egg::CostFunction<EquationLanguage> for FactorizationCostFn {
type Cost = FactorizationCost;
fn cost<C>(&mut self, enode: &EquationLanguage, mut costs: C) -> Self::Cost
where
C: FnMut(Id) -> Self::Cost,
{
match enode {
EquationLanguage::Add([a,b]) => {
match (costs(*a), costs(*b)) {
(FactorizationCost::Polynomial(p1),FactorizationCost::Polynomial(p2)) => {
// we only ever want to add monomials
let result_monic = if p1.degree > p2.degree {
p1.monic
} else if p2.degree > p1.degree {
p2.monic
} else {
false
};
if !p1.sum_of_monomials || !p2.sum_of_monomials {
FactorizationCost::UnwantedOps
} else {
FactorizationCost::Polynomial(PolyStat {
degree: usize::max(p1.degree, p2.degree),
factors: 1,
ops: p1.ops + p2.ops,
monomial: false,
sum_of_monomials: p1.sum_of_monomials && p2.sum_of_monomials,
monic: result_monic,
factorized: result_monic,
})
}
},
_ => FactorizationCost::UnwantedOps
}
},
EquationLanguage::Mul([a,b]) => {
match (costs(*a), costs(*b)) {
(FactorizationCost::Polynomial(p1), FactorizationCost::Polynomial(p2)) => {
FactorizationCost::Polynomial(PolyStat {
degree: p1.degree + p2.degree,
factors: p1.factors + p2.factors,
ops: p1.ops + p2.ops,
monomial: p1.monomial && p2.monomial,
sum_of_monomials: p1.monomial && p2.monomial,
monic: p1.monic && p2.monic,
factorized: (p1.monic && p2.factorized) || (p2.monic && p1.factorized)
})
},
_ => FactorizationCost::UnwantedOps
}
},
EquationLanguage::Num(c) => {
FactorizationCost::Polynomial(PolyStat {
degree: 0,
factors: 0,
ops: 0,
monomial: true,
sum_of_monomials: true,
monic: false,
factorized: true
})
},
EquationLanguage::Unknown => {
FactorizationCost::Polynomial(PolyStat {
degree: 1,
factors: 1,
ops: 0,
monomial: true,
sum_of_monomials: true,
monic: true,
factorized: true
})
},
_ => FactorizationCost::UnwantedOps,
}
}
}
#[derive(Debug,Clone)]
pub struct Factorization {
pub constant_factor: Rational,
pub polynomials: Vec<Vec<Rational>>,
}
impl Display for Factorization {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
if self.constant_factor != RATIONAL_ONE {
write!(f, "{}", self.constant_factor)?;
}
for poly in &self.polynomials {
write!(f, "(")?;
for (deg, coeff) in poly.iter().enumerate() {
if deg == 0 {
write!(f, "{}", coeff)?;
} else if deg == 1 {
write!(f, " + {}x", coeff)?;
} else {
write!(f, " + {}x^{}", coeff, deg)?;
}
}
write!(f, ")")?;
}
Ok(())
}
}
pub fn extract_factorization(expr: &RecExpr<EquationLanguage>) -> Factorization {
let root_id: Id = Id::from(expr.as_ref().len()-1);
let mut constant_factor: Option<Rational> = None;
let mut factors: Vec<Vec<Rational>> = Vec::new();
let mut todo: Vec<Id> = Vec::new();
todo.push(root_id);
while todo.len() > 0 {
let id = todo.pop().unwrap();
match &expr[id] {
EquationLanguage::Mul([a,b]) => {
todo.push(*a);
todo.push(*b);
},
EquationLanguage::Num(x) => {
assert!(constant_factor.is_none());
constant_factor = Some(x.clone());
},
_ => {
factors.push(extract_polynomial(expr, id));
}
}
}
Factorization {
constant_factor: constant_factor.unwrap_or_else(||RATIONAL_ONE.clone()),
polynomials: factors
}
}
fn extract_polynomial(expr: &RecExpr<EquationLanguage>, id: Id) -> Vec<Rational> {
let mut result: Vec<Rational> = Vec::new();
let mut todo: Vec<Id> = Vec::new();
todo.push(id);
while todo.len() > 0 {
let id = todo.pop().unwrap();
match &expr[id] {
EquationLanguage::Add([a,b]) => {
todo.push(*a);
todo.push(*b);
},
_ => {
let (deg, coeff) = extract_monomial(expr, id);
result.resize(result.len().max(deg), RATIONAL_ZERO.clone());
if result.len() <= deg {
result.push(coeff);
} else {
assert!(result[deg] == RATIONAL_ZERO);
result[deg] = coeff;
}
}
}
}
result
}
fn extract_monomial(expr: &RecExpr<EquationLanguage>, id: Id) -> (usize, Rational) {
let mut coeff: Option<Rational> = None;
let mut deg: usize = 0;
let mut todo: Vec<Id> = Vec::new();
todo.push(id);
while todo.len() > 0 {
let id = todo.pop().unwrap();
match &expr[id] {
EquationLanguage::Unknown => {
deg += 1;
},
EquationLanguage::Mul([a,b]) => {
todo.push(*a);
todo.push(*b);
},
EquationLanguage::Num(x) => {
assert!(coeff.is_none());
coeff = Some(x.clone());
},
_ => {
panic!("Not a rational polynomial in normal form!");
}
}
}
(deg, coeff.unwrap_or_else(||RATIONAL_ONE.clone()))
}

View File

@ -1,3 +1,5 @@
pub mod language;
pub mod normal_form;
pub mod parse;
pub mod factorization;
pub mod output;

View File

@ -1,24 +1,35 @@
use egg::{Extractor, Pattern, RecExpr, Runner};
use solveq::language::{RULES, EquationLanguage, PlusTimesCostFn, FactorizationCostFn};
use egg::{AstSize, EGraph, Extractor, Id, Pattern, RecExpr, Runner};
use solveq::factorization::{extract_factorization, FactorizationCost, FactorizationCostFn};
use solveq::language::{ConstantFold, EquationLanguage, FactorizationCostFn, PlusTimesCostFn, Rational, RULES};
use solveq::normal_form::analyze3;
use solveq::parse::parse_equation;
use solveq::output::print_term;
static TEST_EQUATIONS: &[&str] = &[
"(x + 50) * 10 - 150 - 100",
"(x - 2) * (x + 2) - 0",
"x ^ 2 - 4",
"x ^ 2 - 2 - 0",
"x ^ 2 - (2 * x + 15)",
"(x ^ 2 - 2 * x - 15) * (x + 5) - 0",
"x ^ 3 + 3 * x ^ 2 - 25 * x - 75 - 0",
"(x + 50) * 10 - 150 = 100",
"(x - 2) * (x + 2) = 0",
"x ^ 2 = 4",
"x ^ 2 - 2 = 0",
"x ^ 2 = 2 * x + 15",
"(x ^ 2 - 2 * x - 15) * (x + 5) = 0",
"x ^ 3 + 3 * x ^ 2 - 25 * x - 75 = 0",
];
fn main() {
for eq in TEST_EQUATIONS {
let start = parse_equation(*eq).unwrap();
// println!("{:?}", &start);
// do transformation to left - right = 0
fn main() {
let expr: RecExpr<EquationLanguage> = "(* x (+ x -2))".parse().unwrap();
println!("{:?}", get_expression_cost(&expr));
for eq in TEST_EQUATIONS {
println!("Equation: {}", *eq);
let mut start = parse_equation(*eq).unwrap();
let root_id = Id::from(start.as_ref().len()-1);
let EquationLanguage::Equals([left, right]) = start[root_id]
else { panic!("Not an equation without an equals sign!"); };
start[root_id] = EquationLanguage::Sub([left, right]);
println!("Parsed: {}", &start);
let mut runner = Runner::default()
.with_explanations_enabled()
@ -29,198 +40,55 @@ fn main() {
let (best_cost, best_expr) = extractor.find_best(runner.roots[0]);
println!("{}", start);
println!("{:?} {:?}", best_cost, <RecExpr<EquationLanguage> as AsRef<[EquationLanguage]>>::as_ref(&best_expr));
// println!("{:?} {:?}", best_cost, <RecExpr<EquationLanguage> as AsRef<[EquationLanguage]>>::as_ref(&best_expr));
println!("Best expresssion: {} {:?}", best_expr, best_cost);
println!("");
}
let factorization = extract_factorization(&best_expr);
// let root = runner.roots[0];
// let egraph = &runner.egraph;
// let pattern: Pattern<EquationLanguage> = "(+ (* ?a (* x x)) ?c)".parse().unwrap();
// let matches = pattern.search(&egraph);
// println!("{}", runner.explain_equivalence(&start, &best_expr).get_flat_string());
println!("Factorized normal form: {}", factorization);
let mut solutions: Vec<String> = Vec::new();
for poly in &factorization.polynomials {
if poly.len() == 2 { // linear factor
let Rational { num, denom } = &poly[0];
solutions.push(format!("x = {}", Rational { num: -*num, denom: *denom }));
} else if poly.len() == 3 { // quadratic factor
let Rational { num: num0, denom: denom0 } = &poly[0];
let Rational { num: num1, denom: denom1 } = &poly[1];
// println!("{:?}", egraph.classes().count());
let sol1 = format!("- ({num1})/(2 * ({denom1})) + ((({num1})/(2 * ({denom1}))) ^ 2 - ({num0}) / ({denom0})) ^ (1/2)");
let sol2 = format!("- ({num1})/(2 * ({denom1})) - ((({num1})/(2 * ({denom1}))) ^ 2 - ({num0}) / ({denom0})) ^ (1/2)");
// Analyze
// analyze3(egraph, runner.roots[0]);
let expr = parse_equation(&sol1).unwrap();
let runner = Runner::default()
.with_expr(&expr)
.run(&*RULES);
let extractor = Extractor::new(&runner.egraph, AstSize);
let (_, simplified_expr) = extractor.find_best(runner.roots[0]);
solutions.push(format!("x = {}", print_term(&simplified_expr)));
/*
for class in egraph.classes() {
if monic_nonconst_polynomial(egraph, class.id).is_some() {
let (_, best_expr) = extractor.find_best(class.id);
println!("Monomial: {}", best_expr);
}
}
println!("{:?}", &matches);
*/
// println!("{}", runner.explain_equivalence(&start, &best_expr).get_flat_string());
}
/*
fn power_of_x(egraph: &EGraph, eclass: Id) -> Option<usize> {
for n in &egraph[eclass].nodes {
match *n {
EquationLanguage::Unknown => { return Some(1) },
EquationLanguage::Mul([a,b]) => {
let Some(left) = power_of_x(egraph, a) else { continue };
let Some(right) = power_of_x(egraph, b) else { continue };
return Some(left + right);
},
_ => {}
}
}
None
}
fn monomial(egraph: &EGraph, eclass: Id) -> Option<(usize, Rational)> {
if let Some(deg) = power_of_x(egraph, eclass) {
return Some((deg, RATIONAL_ONE.clone()));
}
for n in &egraph[eclass].nodes {
match *n {
EquationLanguage::Mul([a,b]) => {
let Some(coeff) = egraph[a].data.clone() else { continue };
let Some(deg) = power_of_x(egraph, b) else { continue };
return Some((deg, coeff));
},
_ => {}
}
}
None
}
// this is either a power_of_x, or a sum of this and a monomial
fn monic_nonconst_polynomial(egraph: &EGraph, eclass: Id) -> Option<Vec<Rational>> {
let mut result: Vec<Rational> = Vec::new();
if let Some(deg) = power_of_x(egraph, eclass) {
result.resize(deg - 1, RATIONAL_ZERO);
result.push(RATIONAL_ONE.clone());
return Some(result);
}
for n in &egraph[eclass].nodes {
match *n {
EquationLanguage::Add([a,b]) => {
let Some(mut leading) = monic_nonconst_polynomial(egraph, a)
else { continue };
let Some(addon) = monomial(egraph, b)
else { continue };
if leading.len() <= addon.0 || leading[addon.0] != RATIONAL_ZERO {
continue;
}
leading[addon.0] = addon.1.clone();
return Some(leading);
},
_ => {},
}
}
None
}
*/
/*
fn analyze(egraph: &EGraph, _id: Id) {
let mut types: HashMap<Id, SpecialTerm> = HashMap::new();
let mut todo: VecDeque<Id> = VecDeque::new();
// todo.push_back(runner.roots[0]);
for cls in egraph.classes() {
todo.push_back(cls.id);
}
'todo: while todo.len() > 0 {
let id = todo.pop_front().unwrap();
if types.contains_key(&id) {
continue 'todo;
}
if let Some(c) = &egraph[id].data {
types.insert(id, SpecialTerm::Constant(c.clone()));
continue 'todo;
}
'nodes: for n in &egraph[id].nodes {
match *n {
EquationLanguage::Unknown => {
types.insert(id, SpecialTerm::PowerOfX(1));
continue 'todo;
},
EquationLanguage::Mul([a,b]) => {
if !types.contains_key(&a) {
todo.push_back(a);
todo.push_back(id);
continue 'nodes;
}
if !types.contains_key(&b) {
todo.push_back(b);
todo.push_back(id);
continue 'nodes;
}
match (&types[&a], &types[&b]) {
(SpecialTerm::PowerOfX(dega), SpecialTerm::PowerOfX(degb)) => {
types.insert(id, SpecialTerm::PowerOfX(*dega + *degb));
},
(SpecialTerm::Constant(coeff), SpecialTerm::PowerOfX(deg)) => {
types.insert(id, SpecialTerm::Monomial(*deg, coeff.clone()));
},
_ => { continue 'nodes; },
}
continue 'todo;
},
EquationLanguage::Add([a,b]) => {
if !types.contains_key(&a) {
todo.push_front(a);
todo.push_back(id);
continue 'todo;
}
if !types.contains_key(&b) {
todo.push_front(b);
todo.push_back(id);
continue 'todo;
}
match (&types[&a], &types[&b]) {
(SpecialTerm::MonicNonconstPoly(poly), SpecialTerm::Monomial(deg, coeff)) => {
if poly.len() <= *deg || poly[*deg] != RATIONAL_ZERO {
continue 'nodes;
}
let mut poly = poly.clone();
poly[*deg] = coeff.clone();
types.insert(id, SpecialTerm::MonicNonconstPoly(poly));
},
_ => { continue 'nodes; },
}
continue 'todo;
},
_ => {},
let expr = parse_equation(&sol2).unwrap();
let runner = Runner::default()
.with_expr(&expr)
.run(&*RULES);
let extractor = Extractor::new(&runner.egraph, AstSize);
let (_, simplified_expr) = extractor.find_best(runner.roots[0]);
solutions.push(format!("x = {}", print_term(&simplified_expr)));
}
}
types.insert(id, SpecialTerm::Other);
}
for (id, ty) in &types {
if !matches!(ty, &SpecialTerm::Other) {
println!("{:?}", &ty);
}
println!("Solutions: {{ {} }}", solutions.join(", "));
println!("");
}
}
*/
fn get_expression_cost(expr: &RecExpr<EquationLanguage>) -> FactorizationCost {
let mut egraph = EGraph::new(ConstantFold::default());
let id = egraph.add_expr(expr);
let extractor = Extractor::new(&egraph, FactorizationCostFn);
let (cost, _) = extractor.find_best(id);
cost
}

66
src/output.rs Normal file
View File

@ -0,0 +1,66 @@
use egg::{RecExpr, Id};
use crate::language::EquationLanguage;
// there is already a Display implementation generated by define_langauge!
// but we want an alternative string conversion
pub fn print_term(expr: &RecExpr<EquationLanguage>) -> String {
let root_id = Id::from(expr.as_ref().len()-1);
print_term_inner(expr, root_id).0
}
// the second result is the precedence of the top level op: 1 = '+-', 2 = '*/', 3 = '^', 4 = primitive
fn print_term_inner(expr: &RecExpr<EquationLanguage>, id: Id) -> (String, usize) {
match &expr[id] {
EquationLanguage::Num(c) => {
(format!("{}", c), if c.denom == 1 { 4 } else { 2 })
},
EquationLanguage::Neg([a]) => {
(print_unary(expr, *a, "-", 1), 1)
},
EquationLanguage::Add([a,b]) => {
(print_binary(expr, *a, *b, "+", 1), 1)
},
EquationLanguage::Sub([a,b]) => {
(print_binary(expr, *a, *b, "-", 1), 1)
},
EquationLanguage::Mul([a,b]) => {
(print_binary(expr, *a, *b, "*", 2), 2)
},
EquationLanguage::Div([a,b]) => {
(print_binary(expr, *a, *b, "/", 2), 2)
},
EquationLanguage::Power([a,b]) => {
(print_binary(expr, *a, *b, "^", 3), 3)
},
_ => unimplemented!()
}
}
fn print_unary(expr: &RecExpr<EquationLanguage>, a: Id, op: &str, precedence: usize) -> String {
let (astr, aprec) = print_term_inner(expr, a);
if aprec > precedence {
format!("{}{}", op, astr)
} else {
format!("{}({})", op, astr)
}
}
fn print_binary(expr: &RecExpr<EquationLanguage>, a: Id, b: Id, op: &str, precedence: usize) -> String {
let (astr, aprec) = print_term_inner(expr, a);
let (bstr, bprec) = print_term_inner(expr, b);
if aprec > precedence {
if bprec > precedence {
format!("{} {} {}", astr, op, bstr)
} else {
format!("{} {} ({})", astr, op, bstr)
}
} else {
if bprec > precedence {
format!("({}) {} {}", astr, op, bstr)
} else {
format!("({}) {} ({})", astr, op, bstr)
}
}
}

View File

@ -26,15 +26,15 @@ fn parse_equation_inner(input: &str, expr: &mut RecExpr<EquationLanguage>) -> Re
}
match c {
'^' if precedence > 3 => {
'^' if precedence >= 3 => {
operator_position = Some(i);
precedence = 3;
},
'*' | '/' if precedence > 2 => {
'*' | '/' if precedence >= 2 => {
operator_position = Some(i);
precedence = 2;
},
'-' | '+' if precedence > 1 => {
'-' | '+' if precedence >= 1 => {
operator_position = Some(i);
precedence = 1;
},