cleaned up
This commit is contained in:
parent
c110dd6889
commit
0b8bdf3da6
@ -1,256 +0,0 @@
|
|||||||
use core::fmt;
|
|
||||||
use std::{cmp::Ordering, fmt::Display};
|
|
||||||
use crate::language::{EquationLanguage, Rational, RATIONAL_ONE, RATIONAL_ZERO};
|
|
||||||
use egg::{Id, RecExpr};
|
|
||||||
|
|
||||||
|
|
||||||
#[derive(Debug,Clone,Copy)]
|
|
||||||
pub struct PolyStat {
|
|
||||||
degree: usize,
|
|
||||||
factors: usize, // non-constant factors
|
|
||||||
ops: usize,
|
|
||||||
monomial: bool,
|
|
||||||
sum_of_monomials: bool,
|
|
||||||
monic: bool,
|
|
||||||
factorized: bool, // a product of monic polynomials and at least one constant
|
|
||||||
}
|
|
||||||
|
|
||||||
#[derive(Debug,Clone,Copy)]
|
|
||||||
pub enum FactorizationCost {
|
|
||||||
UnwantedOps,
|
|
||||||
Polynomial(PolyStat)
|
|
||||||
}
|
|
||||||
|
|
||||||
fn score(cost: FactorizationCost) -> usize {
|
|
||||||
match cost {
|
|
||||||
FactorizationCost::UnwantedOps => 10000,
|
|
||||||
FactorizationCost::Polynomial(p) =>
|
|
||||||
if !p.factorized {
|
|
||||||
1000 + p.ops
|
|
||||||
} else {
|
|
||||||
100 * (9 - p.factors) + p.ops
|
|
||||||
},
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
impl PartialEq for FactorizationCost {
|
|
||||||
fn eq(&self, other: &Self) -> bool {
|
|
||||||
score(*self) == score(*other)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
impl PartialOrd for FactorizationCost {
|
|
||||||
fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
|
|
||||||
usize::partial_cmp(&score(*self), &score(*other))
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
pub struct FactorizationCostFn;
|
|
||||||
impl egg::CostFunction<EquationLanguage> for FactorizationCostFn {
|
|
||||||
type Cost = FactorizationCost;
|
|
||||||
|
|
||||||
fn cost<C>(&mut self, enode: &EquationLanguage, mut costs: C) -> Self::Cost
|
|
||||||
where
|
|
||||||
C: FnMut(Id) -> Self::Cost,
|
|
||||||
{
|
|
||||||
match enode {
|
|
||||||
EquationLanguage::Add([a,b]) => {
|
|
||||||
match (costs(*a), costs(*b)) {
|
|
||||||
(FactorizationCost::Polynomial(p1),FactorizationCost::Polynomial(p2)) => {
|
|
||||||
// we only ever want to add monomials
|
|
||||||
let result_monic = if p1.degree > p2.degree {
|
|
||||||
p1.monic
|
|
||||||
} else if p2.degree > p1.degree {
|
|
||||||
p2.monic
|
|
||||||
} else {
|
|
||||||
false
|
|
||||||
};
|
|
||||||
|
|
||||||
/*
|
|
||||||
if *a == Id::from(4) && *b == Id::from(19) {
|
|
||||||
println!("HERE {:?} {:?}", p1, p2);
|
|
||||||
}
|
|
||||||
*/
|
|
||||||
|
|
||||||
|
|
||||||
if !p1.sum_of_monomials || !p2.sum_of_monomials {
|
|
||||||
FactorizationCost::UnwantedOps
|
|
||||||
} else {
|
|
||||||
FactorizationCost::Polynomial(PolyStat {
|
|
||||||
degree: usize::max(p1.degree, p2.degree),
|
|
||||||
factors: 1,
|
|
||||||
ops: p1.ops + p2.ops + 1,
|
|
||||||
monomial: false,
|
|
||||||
sum_of_monomials: p1.sum_of_monomials && p2.sum_of_monomials,
|
|
||||||
monic: result_monic,
|
|
||||||
factorized: result_monic,
|
|
||||||
})
|
|
||||||
}
|
|
||||||
},
|
|
||||||
_ => FactorizationCost::UnwantedOps
|
|
||||||
}
|
|
||||||
},
|
|
||||||
EquationLanguage::Mul([a,b]) => {
|
|
||||||
match (costs(*a), costs(*b)) {
|
|
||||||
(FactorizationCost::Polynomial(p1), FactorizationCost::Polynomial(p2)) => {
|
|
||||||
FactorizationCost::Polynomial(PolyStat {
|
|
||||||
degree: p1.degree + p2.degree,
|
|
||||||
factors: p1.factors + p2.factors,
|
|
||||||
ops: p1.ops + p2.ops + 1,
|
|
||||||
monomial: p1.monomial && p2.monomial,
|
|
||||||
sum_of_monomials: p1.monomial && p2.monomial,
|
|
||||||
monic: p1.monic && p2.monic,
|
|
||||||
factorized: (p1.monic && p2.factorized) || (p2.monic && p1.factorized)
|
|
||||||
})
|
|
||||||
},
|
|
||||||
_ => FactorizationCost::UnwantedOps
|
|
||||||
}
|
|
||||||
},
|
|
||||||
EquationLanguage::Num(c) => {
|
|
||||||
FactorizationCost::Polynomial(PolyStat {
|
|
||||||
degree: 0,
|
|
||||||
factors: 0,
|
|
||||||
ops: 0,
|
|
||||||
monomial: true,
|
|
||||||
sum_of_monomials: true,
|
|
||||||
monic: false,
|
|
||||||
factorized: true
|
|
||||||
})
|
|
||||||
},
|
|
||||||
EquationLanguage::Unknown => {
|
|
||||||
FactorizationCost::Polynomial(PolyStat {
|
|
||||||
degree: 1,
|
|
||||||
factors: 1,
|
|
||||||
ops: 0,
|
|
||||||
monomial: true,
|
|
||||||
sum_of_monomials: true,
|
|
||||||
monic: true,
|
|
||||||
factorized: true
|
|
||||||
})
|
|
||||||
},
|
|
||||||
_ => FactorizationCost::UnwantedOps,
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
#[derive(Debug,Clone)]
|
|
||||||
pub struct Factorization {
|
|
||||||
pub constant_factor: Rational,
|
|
||||||
pub polynomials: Vec<Vec<Rational>>,
|
|
||||||
}
|
|
||||||
|
|
||||||
impl Display for Factorization {
|
|
||||||
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
|
|
||||||
if self.constant_factor != RATIONAL_ONE {
|
|
||||||
write!(f, "{}", self.constant_factor)?;
|
|
||||||
}
|
|
||||||
|
|
||||||
for poly in &self.polynomials {
|
|
||||||
write!(f, "(")?;
|
|
||||||
for (deg, coeff) in poly.iter().enumerate() {
|
|
||||||
if deg == 0 {
|
|
||||||
write!(f, "{}", coeff)?;
|
|
||||||
} else if deg == 1 {
|
|
||||||
write!(f, " + {}x", coeff)?;
|
|
||||||
} else {
|
|
||||||
write!(f, " + {}x^{}", coeff, deg)?;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
write!(f, ")")?;
|
|
||||||
}
|
|
||||||
|
|
||||||
Ok(())
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
pub fn extract_factorization(expr: &RecExpr<EquationLanguage>) -> Factorization {
|
|
||||||
let root_id: Id = Id::from(expr.as_ref().len()-1);
|
|
||||||
|
|
||||||
let mut constant_factor: Option<Rational> = None;
|
|
||||||
let mut factors: Vec<Vec<Rational>> = Vec::new();
|
|
||||||
let mut todo: Vec<Id> = Vec::new();
|
|
||||||
todo.push(root_id);
|
|
||||||
|
|
||||||
while todo.len() > 0 {
|
|
||||||
let id = todo.pop().unwrap();
|
|
||||||
|
|
||||||
match &expr[id] {
|
|
||||||
EquationLanguage::Mul([a,b]) => {
|
|
||||||
todo.push(*a);
|
|
||||||
todo.push(*b);
|
|
||||||
},
|
|
||||||
EquationLanguage::Num(x) => {
|
|
||||||
assert!(constant_factor.is_none());
|
|
||||||
constant_factor = Some(x.clone());
|
|
||||||
},
|
|
||||||
_ => {
|
|
||||||
factors.push(extract_polynomial(expr, id));
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
Factorization {
|
|
||||||
constant_factor: constant_factor.unwrap_or_else(||RATIONAL_ONE.clone()),
|
|
||||||
polynomials: factors
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
fn extract_polynomial(expr: &RecExpr<EquationLanguage>, id: Id) -> Vec<Rational> {
|
|
||||||
let mut result: Vec<Rational> = Vec::new();
|
|
||||||
let mut todo: Vec<Id> = Vec::new();
|
|
||||||
todo.push(id);
|
|
||||||
|
|
||||||
while todo.len() > 0 {
|
|
||||||
let id = todo.pop().unwrap();
|
|
||||||
|
|
||||||
match &expr[id] {
|
|
||||||
EquationLanguage::Add([a,b]) => {
|
|
||||||
todo.push(*a);
|
|
||||||
todo.push(*b);
|
|
||||||
},
|
|
||||||
_ => {
|
|
||||||
let (deg, coeff) = extract_monomial(expr, id);
|
|
||||||
result.resize(result.len().max(deg), RATIONAL_ZERO.clone());
|
|
||||||
|
|
||||||
if result.len() <= deg {
|
|
||||||
result.push(coeff);
|
|
||||||
} else {
|
|
||||||
assert!(result[deg] == RATIONAL_ZERO);
|
|
||||||
result[deg] = coeff;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
result
|
|
||||||
}
|
|
||||||
|
|
||||||
fn extract_monomial(expr: &RecExpr<EquationLanguage>, id: Id) -> (usize, Rational) {
|
|
||||||
let mut coeff: Option<Rational> = None;
|
|
||||||
let mut deg: usize = 0;
|
|
||||||
let mut todo: Vec<Id> = Vec::new();
|
|
||||||
todo.push(id);
|
|
||||||
|
|
||||||
while todo.len() > 0 {
|
|
||||||
let id = todo.pop().unwrap();
|
|
||||||
|
|
||||||
match &expr[id] {
|
|
||||||
EquationLanguage::Unknown => {
|
|
||||||
deg += 1;
|
|
||||||
},
|
|
||||||
EquationLanguage::Mul([a,b]) => {
|
|
||||||
todo.push(*a);
|
|
||||||
todo.push(*b);
|
|
||||||
},
|
|
||||||
EquationLanguage::Num(x) => {
|
|
||||||
assert!(coeff.is_none());
|
|
||||||
coeff = Some(x.clone());
|
|
||||||
},
|
|
||||||
_ => {
|
|
||||||
panic!("Not a rational polynomial in normal form!");
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
(deg, coeff.unwrap_or_else(||RATIONAL_ONE.clone()))
|
|
||||||
}
|
|
@ -1,5 +1,5 @@
|
|||||||
use std::{cmp::Ordering, fmt::{self,Display, Formatter}, num::ParseIntError, str::FromStr, sync::LazyLock};
|
use std::{cmp::Ordering, fmt::{self,Display, Formatter}, str::FromStr, sync::LazyLock};
|
||||||
use egg::{define_language, merge_option, rewrite as rw, Analysis, Applier, DidMerge, Id, Language, PatternAst, Subst, Symbol, SymbolLang, Var};
|
use egg::{define_language, merge_option, rewrite as rw, Analysis, Applier, DidMerge, Id, Language, PatternAst, Subst, Symbol, Var};
|
||||||
|
|
||||||
pub type EGraph = egg::EGraph<EquationLanguage, ConstantFold>;
|
pub type EGraph = egg::EGraph<EquationLanguage, ConstantFold>;
|
||||||
pub type Rewrite = egg::Rewrite<EquationLanguage, ConstantFold>;
|
pub type Rewrite = egg::Rewrite<EquationLanguage, ConstantFold>;
|
||||||
@ -193,8 +193,8 @@ impl Applier<EquationLanguage, ConstantFold> for IntegerSqrt {
|
|||||||
egraph: &mut EGraph,
|
egraph: &mut EGraph,
|
||||||
matched_id: Id,
|
matched_id: Id,
|
||||||
subst: &Subst,
|
subst: &Subst,
|
||||||
searcher_pattern: Option<&PatternAst<EquationLanguage>>,
|
_searcher_pattern: Option<&PatternAst<EquationLanguage>>,
|
||||||
rule_name: Symbol)
|
_rule_name: Symbol)
|
||||||
-> Vec<Id> {
|
-> Vec<Id> {
|
||||||
let var_id = subst[self.var];
|
let var_id = subst[self.var];
|
||||||
|
|
||||||
|
@ -1,5 +1,4 @@
|
|||||||
pub mod language;
|
pub mod language;
|
||||||
pub mod normal_form;
|
pub mod normal_form;
|
||||||
pub mod parse;
|
pub mod parse;
|
||||||
pub mod factorization;
|
|
||||||
pub mod output;
|
pub mod output;
|
||||||
|
79
src/main.rs
79
src/main.rs
@ -1,6 +1,5 @@
|
|||||||
use egg::{AstSize, Extractor, Id, Pattern, PatternAst, RecExpr, Runner, Searcher};
|
use egg::{AstSize, Extractor, Id, RecExpr, Runner};
|
||||||
use solveq::factorization::{FactorizationCost, FactorizationCostFn};
|
use solveq::language::{EquationLanguage, Rational, RULES};
|
||||||
use solveq::language::{ConstantFold, EquationLanguage, Rational, RULES, EGraph};
|
|
||||||
use solveq::normal_form::extract_normal_form;
|
use solveq::normal_form::extract_normal_form;
|
||||||
use solveq::parse::parse_equation;
|
use solveq::parse::parse_equation;
|
||||||
use solveq::output::print_term;
|
use solveq::output::print_term;
|
||||||
@ -15,54 +14,56 @@ static TEST_EQUATIONS: &[&str] = &[
|
|||||||
"x ^ 3 + 3 * x ^ 2 - 25 * x - 75 = 0",
|
"x ^ 3 + 3 * x ^ 2 - 25 * x - 75 = 0",
|
||||||
];
|
];
|
||||||
|
|
||||||
|
|
||||||
fn main() {
|
fn main() {
|
||||||
// let expr: RecExpr<EquationLanguage> = "(+ (* x (+ x -2)) -15)".parse().unwrap();
|
|
||||||
// let expr: RecExpr<EquationLanguage> = "(* (+ (* x (+ x -2)) -15) (+ x 5))".parse().unwrap();
|
|
||||||
// println!("{:?}", get_expression_cost(&expr));
|
|
||||||
|
|
||||||
for eq in TEST_EQUATIONS {
|
for eq in TEST_EQUATIONS {
|
||||||
println!("Equation: {}", *eq);
|
println!("Equation: {}", *eq);
|
||||||
|
|
||||||
let mut start = parse_equation(*eq).unwrap();
|
let solutions = solve(*eq, true);
|
||||||
|
|
||||||
|
let solutions_str: Vec<String> = solutions
|
||||||
|
.iter()
|
||||||
|
.map(|expr| format!("x = {}", print_term(expr)))
|
||||||
|
.collect();
|
||||||
|
println!("Solutions: {{ {} }}", solutions_str.join(", "));
|
||||||
|
|
||||||
|
println!("");
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn solve(eq: &str, verbose: bool) -> Vec<RecExpr<EquationLanguage>> {
|
||||||
|
let mut start = parse_equation(eq).unwrap();
|
||||||
let root_id = Id::from(start.as_ref().len()-1);
|
let root_id = Id::from(start.as_ref().len()-1);
|
||||||
let EquationLanguage::Equals([left, right]) = start[root_id]
|
let EquationLanguage::Equals([left, right]) = start[root_id]
|
||||||
else { panic!("Not an equation without an equals sign!"); };
|
else { panic!("Not an equation without an equals sign!"); };
|
||||||
start[root_id] = EquationLanguage::Sub([left, right]);
|
start[root_id] = EquationLanguage::Sub([left, right]);
|
||||||
|
|
||||||
|
if verbose {
|
||||||
println!("Parsed: {}", &start);
|
println!("Parsed: {}", &start);
|
||||||
|
}
|
||||||
|
|
||||||
let mut runner = Runner::default()
|
let runner = Runner::default()
|
||||||
.with_explanations_enabled()
|
.with_explanations_enabled()
|
||||||
.with_expr(&start)
|
.with_expr(&start)
|
||||||
.run(&*RULES);
|
.run(&*RULES);
|
||||||
|
|
||||||
let extractor = Extractor::new(&runner.egraph, FactorizationCostFn);
|
|
||||||
|
|
||||||
let (best_cost, best_expr) = extractor.find_best(runner.roots[0]);
|
|
||||||
|
|
||||||
// println!("Best expresssion: {} {:?}", best_expr, best_cost);
|
|
||||||
// println!("{}", runner.explain_equivalence(&start, &best_expr).get_flat_string());
|
|
||||||
|
|
||||||
let Some(factorization) = extract_normal_form(&runner.egraph, runner.roots[0]) else {
|
let Some(factorization) = extract_normal_form(&runner.egraph, runner.roots[0]) else {
|
||||||
panic!("Couldn't factorize polynomial!");
|
panic!("Couldn't factorize polynomial!");
|
||||||
};
|
};
|
||||||
|
|
||||||
|
if verbose {
|
||||||
println!("Factorized normal form: {}", factorization);
|
println!("Factorized normal form: {}", factorization);
|
||||||
|
}
|
||||||
|
|
||||||
/*
|
let mut solutions: Vec<RecExpr<EquationLanguage>> = Vec::new();
|
||||||
get_expression_cost("(* x (+ x -2))", &runner.egraph);
|
|
||||||
get_expression_cost("-15", &runner.egraph);
|
|
||||||
get_expression_cost("(+ (* x (+ x -2)) -15)", &runner.egraph);
|
|
||||||
*/
|
|
||||||
|
|
||||||
// let factorization = extract_factorization(&best_expr);
|
|
||||||
|
|
||||||
|
|
||||||
let mut solutions: Vec<String> = Vec::new();
|
|
||||||
for poly in &factorization.polynomials {
|
for poly in &factorization.polynomials {
|
||||||
if poly.len() == 2 { // linear factor
|
if poly.len() == 2 { // linear factor
|
||||||
let Rational { num, denom } = &poly[0];
|
let Rational { num, denom } = &poly[0];
|
||||||
solutions.push(format!("x = {}", Rational { num: -*num, denom: *denom }));
|
let mut solexpr = RecExpr::default();
|
||||||
|
solexpr.add(EquationLanguage::Num(Rational {
|
||||||
|
num: -*num,
|
||||||
|
denom: *denom
|
||||||
|
}));
|
||||||
|
solutions.push(solexpr);
|
||||||
} else if poly.len() == 3 { // quadratic factor
|
} else if poly.len() == 3 { // quadratic factor
|
||||||
let Rational { num: num0, denom: denom0 } = &poly[0];
|
let Rational { num: num0, denom: denom0 } = &poly[0];
|
||||||
let Rational { num: num1, denom: denom1 } = &poly[1];
|
let Rational { num: num1, denom: denom1 } = &poly[1];
|
||||||
@ -76,7 +77,7 @@ fn main() {
|
|||||||
.run(&*RULES);
|
.run(&*RULES);
|
||||||
let extractor = Extractor::new(&runner.egraph, AstSize);
|
let extractor = Extractor::new(&runner.egraph, AstSize);
|
||||||
let (_, simplified_expr) = extractor.find_best(runner.roots[0]);
|
let (_, simplified_expr) = extractor.find_best(runner.roots[0]);
|
||||||
solutions.push(format!("x = {}", print_term(&simplified_expr)));
|
solutions.push(simplified_expr);
|
||||||
|
|
||||||
let expr = parse_equation(&sol2).unwrap();
|
let expr = parse_equation(&sol2).unwrap();
|
||||||
let runner = Runner::default()
|
let runner = Runner::default()
|
||||||
@ -84,25 +85,9 @@ fn main() {
|
|||||||
.run(&*RULES);
|
.run(&*RULES);
|
||||||
let extractor = Extractor::new(&runner.egraph, AstSize);
|
let extractor = Extractor::new(&runner.egraph, AstSize);
|
||||||
let (_, simplified_expr) = extractor.find_best(runner.roots[0]);
|
let (_, simplified_expr) = extractor.find_best(runner.roots[0]);
|
||||||
solutions.push(format!("x = {}", print_term(&simplified_expr)));
|
solutions.push(simplified_expr);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
println!("Solutions: {{ {} }}", solutions.join(", "));
|
solutions
|
||||||
println!("");
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
fn get_expression_cost(expr: &str, egraph: &EGraph) {
|
|
||||||
// let mut egraph = EGraph::new(ConstantFold::default());
|
|
||||||
// let id = egraph.add_expr(expr);
|
|
||||||
let pattern: Pattern<EquationLanguage> = expr.parse().unwrap();
|
|
||||||
let matches = pattern.search(egraph);
|
|
||||||
for m in matches {
|
|
||||||
let extractor = Extractor::new(&egraph, FactorizationCostFn);
|
|
||||||
let (cost, _) = extractor.find_best(m.eclass);
|
|
||||||
|
|
||||||
println!("expr: {}, id: {}, cost: {:?}", expr, m.eclass, cost);
|
|
||||||
}
|
|
||||||
// cost
|
|
||||||
}
|
}
|
||||||
|
@ -1,16 +1,6 @@
|
|||||||
use crate::language::{EGraph, EquationLanguage, Rational, RATIONAL_ONE, RATIONAL_ZERO};
|
use crate::language::{EGraph, EquationLanguage, Rational, RATIONAL_ONE, RATIONAL_ZERO};
|
||||||
use std::{collections::HashMap, fmt};
|
use std::{collections::HashMap, fmt};
|
||||||
use egg::{AstSize, Extractor, Id};
|
use egg::Id;
|
||||||
|
|
||||||
#[derive(Debug,Clone)]
|
|
||||||
pub enum SpecialTerm {
|
|
||||||
Constant(Rational),
|
|
||||||
PowerOfX(usize),
|
|
||||||
Monomial(usize, Rational),
|
|
||||||
MonicNonconstPoly(Vec<Rational>),
|
|
||||||
Factorization(Rational, Vec<Vec<Rational>>),
|
|
||||||
Other,
|
|
||||||
}
|
|
||||||
|
|
||||||
#[derive(Debug,Clone)]
|
#[derive(Debug,Clone)]
|
||||||
pub struct Factorization {
|
pub struct Factorization {
|
||||||
@ -18,7 +8,32 @@ pub struct Factorization {
|
|||||||
pub polynomials: Vec<Vec<Rational>>,
|
pub polynomials: Vec<Vec<Rational>>,
|
||||||
}
|
}
|
||||||
|
|
||||||
// this is a property of an eclass, not a particular expression
|
impl fmt::Display for Factorization {
|
||||||
|
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
|
||||||
|
if self.constant_factor != RATIONAL_ONE {
|
||||||
|
write!(f, "{}", self.constant_factor)?;
|
||||||
|
}
|
||||||
|
|
||||||
|
for poly in &self.polynomials {
|
||||||
|
write!(f, "(")?;
|
||||||
|
for (deg, coeff) in poly.iter().enumerate() {
|
||||||
|
if deg == 0 {
|
||||||
|
write!(f, "{}", coeff)?;
|
||||||
|
} else if deg == 1 {
|
||||||
|
write!(f, " + {}x", coeff)?;
|
||||||
|
} else {
|
||||||
|
write!(f, " + {}x^{}", coeff, deg)?;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
write!(f, ")")?;
|
||||||
|
}
|
||||||
|
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// warning: this is a property of an eclass, not a particular expression
|
||||||
|
// so we shouldn't use anything that's not a polynomial invariant
|
||||||
#[derive(Debug,Clone)]
|
#[derive(Debug,Clone)]
|
||||||
struct PolyStats {
|
struct PolyStats {
|
||||||
degree: usize,
|
degree: usize,
|
||||||
@ -141,10 +156,6 @@ fn extract_polynomial(egraph: &EGraph, stats: &HashMap<Id, PolyStats>, id: Id) -
|
|||||||
|
|
||||||
// a monomial is either a power of x, a constant, or a product of a constant and power of x
|
// a monomial is either a power of x, a constant, or a product of a constant and power of x
|
||||||
fn extract_monomial(egraph: &EGraph, stats: &HashMap<Id, PolyStats>, id: Id) -> Option<(usize, Rational)> {
|
fn extract_monomial(egraph: &EGraph, stats: &HashMap<Id, PolyStats>, id: Id) -> Option<(usize, Rational)> {
|
||||||
let extractor = Extractor::new(egraph, AstSize);
|
|
||||||
let (_, expr) = extractor.find_best(id);
|
|
||||||
// println!("Extract Monomial: {}", expr);
|
|
||||||
|
|
||||||
let st = &stats[&id];
|
let st = &stats[&id];
|
||||||
|
|
||||||
assert!(st.monomial);
|
assert!(st.monomial);
|
||||||
@ -221,7 +232,6 @@ fn find_monic_factorization(egraph: &EGraph, stats: &HashMap<Id, PolyStats>, id:
|
|||||||
|
|
||||||
// now the whole thing is a monic nonconst poly, so would be a valid factorization,
|
// now the whole thing is a monic nonconst poly, so would be a valid factorization,
|
||||||
// but we want to go as deep as possible
|
// but we want to go as deep as possible
|
||||||
// println!("{:?}", stats[&id]);
|
|
||||||
|
|
||||||
// check if it is the product of two nonconstant monic polynomials
|
// check if it is the product of two nonconstant monic polynomials
|
||||||
for node in &egraph[id].nodes {
|
for node in &egraph[id].nodes {
|
||||||
@ -260,8 +270,6 @@ where
|
|||||||
let mut result: HashMap<Id, T> = HashMap::new();
|
let mut result: HashMap<Id, T> = HashMap::new();
|
||||||
let mut modifications: usize = 1;
|
let mut modifications: usize = 1;
|
||||||
|
|
||||||
// println!("{:?}", egraph[canonical]);
|
|
||||||
|
|
||||||
while modifications > 0 {
|
while modifications > 0 {
|
||||||
modifications = 0;
|
modifications = 0;
|
||||||
|
|
||||||
@ -285,215 +293,3 @@ where
|
|||||||
|
|
||||||
result
|
result
|
||||||
}
|
}
|
||||||
|
|
||||||
/*
|
|
||||||
pub fn analyze3(egraph: &EGraph, eclass: Id) {
|
|
||||||
let constants = search_for(egraph, |id, _, _|
|
|
||||||
egraph[id].data.as_ref().map(|c|c.clone())
|
|
||||||
);
|
|
||||||
|
|
||||||
println!("{:?}", constants);
|
|
||||||
|
|
||||||
let powers_of_x = search_for(egraph, |_, node, matches| match *node {
|
|
||||||
EquationLanguage::Unknown => Some(1),
|
|
||||||
EquationLanguage::Mul([a,b]) => {
|
|
||||||
if !matches.contains_key(&a) || !matches.contains_key(&b) {
|
|
||||||
return None;
|
|
||||||
}
|
|
||||||
|
|
||||||
let (dega, degb) = (matches[&a], matches[&b]);
|
|
||||||
Some(dega + degb)
|
|
||||||
},
|
|
||||||
_ => None,
|
|
||||||
});
|
|
||||||
|
|
||||||
println!("{:?}", powers_of_x);
|
|
||||||
|
|
||||||
let monomials = search_for(egraph, |id, node, _| {
|
|
||||||
if let Some(deg) = powers_of_x.get(&id) {
|
|
||||||
return Some((*deg, RATIONAL_ONE.clone()));
|
|
||||||
}
|
|
||||||
|
|
||||||
if let Some(c) = constants.get(&id) {
|
|
||||||
return Some((0, c.clone()));
|
|
||||||
}
|
|
||||||
|
|
||||||
match *node {
|
|
||||||
EquationLanguage::Mul([a,b]) => {
|
|
||||||
if !constants.contains_key(&a) || !powers_of_x.contains_key(&b) {
|
|
||||||
return None;
|
|
||||||
}
|
|
||||||
|
|
||||||
let (coeff, deg) = (&constants[&a], powers_of_x[&b]);
|
|
||||||
Some((deg, coeff.clone()))
|
|
||||||
},
|
|
||||||
_ => None,
|
|
||||||
}
|
|
||||||
});
|
|
||||||
|
|
||||||
println!("{:?}", monomials);
|
|
||||||
|
|
||||||
let monic_polynomials = search_for(egraph, |id, node, matches| {
|
|
||||||
if let Some(deg) = powers_of_x.get(&id) {
|
|
||||||
let mut poly: Vec<Rational> = Vec::new();
|
|
||||||
poly.resize(*deg, RATIONAL_ZERO);
|
|
||||||
poly.push(RATIONAL_ONE.clone());
|
|
||||||
Some(poly)
|
|
||||||
} else {
|
|
||||||
match *node {
|
|
||||||
EquationLanguage::Add([a,b]) => {
|
|
||||||
if !matches.contains_key(&a) || !monomials.contains_key(&b) {
|
|
||||||
return None;
|
|
||||||
}
|
|
||||||
|
|
||||||
let (leading, (deg, coeff)) = (&matches[&a], &monomials[&b]);
|
|
||||||
if leading.len() <= *deg || leading[*deg] != RATIONAL_ZERO {
|
|
||||||
return None;
|
|
||||||
}
|
|
||||||
|
|
||||||
let mut poly = leading.clone();
|
|
||||||
poly[*deg] = coeff.clone();
|
|
||||||
Some(poly)
|
|
||||||
},
|
|
||||||
_ => None,
|
|
||||||
}
|
|
||||||
}
|
|
||||||
});
|
|
||||||
|
|
||||||
for p in &monic_polynomials {
|
|
||||||
println!("{:?}", p);
|
|
||||||
}
|
|
||||||
|
|
||||||
let factorizations: HashMap<Id, (Rational, Vec<Vec<Rational>>)> = search_for(egraph, |id, node, matches| {
|
|
||||||
if let Some(c) = constants.get(&id) {
|
|
||||||
return Some((c.clone(), vec![]));
|
|
||||||
}
|
|
||||||
|
|
||||||
if let Some(poly) = monic_polynomials.get(&id) {
|
|
||||||
return Some((RATIONAL_ONE.clone(), vec![poly.clone()]));
|
|
||||||
}
|
|
||||||
|
|
||||||
match *node {
|
|
||||||
EquationLanguage::Mul([a,b]) => {
|
|
||||||
if !matches.contains_key(&a) || !monic_polynomials.contains_key(&b) {
|
|
||||||
return None;
|
|
||||||
}
|
|
||||||
|
|
||||||
let ((factor, polys), newpoly) = (&matches[&a], &monic_polynomials[&b]);
|
|
||||||
|
|
||||||
let mut combined: Vec<Vec<Rational>> = polys.clone();
|
|
||||||
combined.push(newpoly.clone());
|
|
||||||
Some((factor.clone(), combined))
|
|
||||||
},
|
|
||||||
_ => None,
|
|
||||||
}
|
|
||||||
});
|
|
||||||
|
|
||||||
/*
|
|
||||||
for p in &factorizations {
|
|
||||||
println!("{:?}", p);
|
|
||||||
}
|
|
||||||
*/
|
|
||||||
|
|
||||||
println!("{:?}", factorizations[&eclass]);
|
|
||||||
}
|
|
||||||
|
|
||||||
pub fn analyze2(egraph: &EGraph) -> HashMap<Id, SpecialTerm> {
|
|
||||||
let mut types: HashMap<Id, SpecialTerm> = HashMap::new();
|
|
||||||
|
|
||||||
let mut modifications: usize = 1;
|
|
||||||
|
|
||||||
while modifications > 0 {
|
|
||||||
modifications = 0;
|
|
||||||
|
|
||||||
for cls in egraph.classes() {
|
|
||||||
let id = cls.id;
|
|
||||||
if types.contains_key(&id) {
|
|
||||||
continue;
|
|
||||||
}
|
|
||||||
|
|
||||||
if let Some(c) = &egraph[id].data {
|
|
||||||
types.insert(id, SpecialTerm::Constant(c.clone()));
|
|
||||||
modifications += 1;
|
|
||||||
continue;
|
|
||||||
}
|
|
||||||
|
|
||||||
for node in &cls.nodes {
|
|
||||||
match *node {
|
|
||||||
EquationLanguage::Unknown => {
|
|
||||||
types.insert(id, SpecialTerm::PowerOfX(1));
|
|
||||||
modifications += 1;
|
|
||||||
},
|
|
||||||
EquationLanguage::Mul([a,b]) => {
|
|
||||||
// as we don't know a and b yet, defer to future iteration
|
|
||||||
if !types.contains_key(&a) || !types.contains_key(&b) {
|
|
||||||
continue;
|
|
||||||
}
|
|
||||||
|
|
||||||
match (&types[&a], &types[&b]) {
|
|
||||||
(SpecialTerm::PowerOfX(dega), SpecialTerm::PowerOfX(degb)) => {
|
|
||||||
types.insert(id, SpecialTerm::PowerOfX(*dega + *degb));
|
|
||||||
modifications += 1;
|
|
||||||
},
|
|
||||||
(SpecialTerm::Constant(coeff), SpecialTerm::PowerOfX(deg)) => {
|
|
||||||
types.insert(id, SpecialTerm::Monomial(*deg, coeff.clone()));
|
|
||||||
modifications += 1;
|
|
||||||
},
|
|
||||||
_ => { },
|
|
||||||
}
|
|
||||||
},
|
|
||||||
EquationLanguage::Add([a,b]) => {
|
|
||||||
// as we don't know a and b yet, defer to future iteration
|
|
||||||
if !types.contains_key(&a) || !types.contains_key(&b) {
|
|
||||||
continue;
|
|
||||||
}
|
|
||||||
|
|
||||||
match (&types[&a], &types[&b]) {
|
|
||||||
(SpecialTerm::MonicNonconstPoly(poly), SpecialTerm::Monomial(deg, coeff)) => {
|
|
||||||
if poly.len() <= *deg || poly[*deg] != RATIONAL_ZERO {
|
|
||||||
continue;
|
|
||||||
}
|
|
||||||
|
|
||||||
let mut poly = poly.clone();
|
|
||||||
poly[*deg] = coeff.clone();
|
|
||||||
types.insert(id, SpecialTerm::MonicNonconstPoly(poly));
|
|
||||||
modifications += 1;
|
|
||||||
},
|
|
||||||
_ => { },
|
|
||||||
}
|
|
||||||
},
|
|
||||||
_ => {}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
println!("{} modifications!", modifications);
|
|
||||||
}
|
|
||||||
|
|
||||||
types
|
|
||||||
}
|
|
||||||
*/
|
|
||||||
|
|
||||||
impl fmt::Display for Factorization {
|
|
||||||
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
|
|
||||||
if self.constant_factor != RATIONAL_ONE {
|
|
||||||
write!(f, "{}", self.constant_factor)?;
|
|
||||||
}
|
|
||||||
|
|
||||||
for poly in &self.polynomials {
|
|
||||||
write!(f, "(")?;
|
|
||||||
for (deg, coeff) in poly.iter().enumerate() {
|
|
||||||
if deg == 0 {
|
|
||||||
write!(f, "{}", coeff)?;
|
|
||||||
} else if deg == 1 {
|
|
||||||
write!(f, " + {}x", coeff)?;
|
|
||||||
} else {
|
|
||||||
write!(f, " + {}x^{}", coeff, deg)?;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
write!(f, ")")?;
|
|
||||||
}
|
|
||||||
|
|
||||||
Ok(())
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
@ -1,5 +1,5 @@
|
|||||||
use std::error::Error;
|
use std::error::Error;
|
||||||
use egg::*;
|
use egg::{RecExpr, Id, FromOp, FromOpError};
|
||||||
use crate::language::EquationLanguage;
|
use crate::language::EquationLanguage;
|
||||||
|
|
||||||
pub fn parse_equation(input: &str) -> Result<RecExpr<EquationLanguage>, ParseError> {
|
pub fn parse_equation(input: &str) -> Result<RecExpr<EquationLanguage>, ParseError> {
|
||||||
|
Loading…
Reference in New Issue
Block a user