use std::{cmp::Ordering, fmt::{self,Display, Formatter}, num::ParseIntError, str::FromStr, sync::LazyLock}; use egg::{define_language, merge_option, rewrite as rw, Analysis, Applier, DidMerge, Id, Language, PatternAst, Subst, Symbol, SymbolLang, Var}; pub type EGraph = egg::EGraph; pub type Rewrite = egg::Rewrite; define_language! { pub enum EquationLanguage { "x" = Unknown, "+" = Add([Id; 2]), "-" = Sub([Id; 2]), "-" = Neg([Id; 1]), "*" = Mul([Id; 2]), "/" = Div([Id; 2]), "^" = Power([Id; 2]), "=" = Equals([Id; 2]), "rec" = Reciprocal([Id; 1]), Num(Rational), } } #[derive(Debug,Hash,Clone)] pub struct Rational { pub num: i64, pub denom: u64, } pub const RATIONAL_ZERO: Rational = Rational { num: 0, denom: 1 }; pub const RATIONAL_ONE: Rational = Rational { num: 1, denom: 1 }; impl Display for Rational { fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result { if self.denom == 1 { write!(f, "{}", self.num) } else { write!(f, "{}/{}", self.num, self.denom) } } } impl FromStr for Rational { type Err = String; fn from_str(s: &str) -> Result { let err = || Err(format!("Couldn't parse rational: {}", s)); if let Ok(num) = s.parse::() { Ok(Rational { num, denom: 1 }) } else if let Some((snum, sdenom)) = s.split_once('/') { let Ok(num) = snum.parse::() else { return err(); }; let Ok(denom) = sdenom.parse::() else { return err(); }; Ok(Rational { num, denom }) } else { err() } } } impl PartialEq for Rational { fn eq(&self, other: &Rational) -> bool { (self.num as i128) * (other.denom as i128) == (other.num as i128) * (self.denom as i128) } } impl Eq for Rational {} impl PartialOrd for Rational { fn partial_cmp(&self, other: &Rational) -> Option { i128::partial_cmp( &((self.num as i128) * (other.denom as i128)), &((other.num as i128) * (other.denom as i128)) ) } } impl Ord for Rational { fn cmp(&self, other: &Rational) -> Ordering { i128::cmp( &((self.num as i128) * (other.denom as i128)), &((other.num as i128) * (other.denom as i128)) ) } } impl Rational { fn simplify(&mut self) { let mut a = self.num.abs() as u64; let mut b = self.denom; if a > b { (a, b) = (b, a); } while a > 0 { (a, b) = (b % a, a); } self.num /= b as i64; self.denom /= b; } } // constant folding code essentially comes from egg examples, except using rationals instead of floats #[derive(Default)] pub struct ConstantFold; impl Analysis for ConstantFold { type Data = Option; fn make(egraph: &EGraph, enode: &EquationLanguage) -> Self::Data { let x = |i: &Id| -> Self::Data { egraph[*i].data.clone() }; let mut value = match enode { EquationLanguage::Num(c) => c.clone(), EquationLanguage::Add([a,b]) => Rational { num: x(a)?.num * x(b)?.denom as i64 + x(a)?.denom as i64 * x(b)?.num, denom: x(a)?.denom * x(b)?.denom }, EquationLanguage::Sub([a,b]) => Rational { num: x(a)?.num * x(b)?.denom as i64 - x(a)?.denom as i64 * x(b)?.num, denom: x(a)?.denom * x(b)?.denom }, EquationLanguage::Mul([a,b]) => Rational { num: x(a)?.num * x(b)?.num, denom: x(a)?.denom * x(b)?.denom }, EquationLanguage::Div([a,b]) => { if x(b)?.num == 0 { return None; } else if x(b)?.num > 0 { Rational { num: x(a)?.num * x(b)?.denom as i64, denom: x(b)?.num as u64 * x(a)?.denom, } } else { Rational { num: - x(a)?.num * x(b)?.denom as i64, denom: (-x(b)?.num) as u64 * x(a)?.denom, } } }, EquationLanguage::Neg([a]) => Rational { num: -x(a)?.num, denom: x(a)?.denom, }, EquationLanguage::Reciprocal([a]) => Rational { num: if x(a)?.num > 0 { x(a)?.denom as i64 } else { - (x(a)?.denom as i64) }, denom: x(a)?.num.abs() as u64, }, _ => return None, }; value.simplify(); Some(value) } fn merge(&mut self, to: &mut Self::Data, from: Self::Data) -> DidMerge { merge_option(to, from, |a, b| { assert!(a == &b, "Merged non-equal constants"); DidMerge(false, false) }) } fn modify(egraph: &mut EGraph, id: Id) { let data = egraph[id].data.clone(); if let Some(c) = data { let added = egraph.add(EquationLanguage::Num(c)); egraph.union(id, added); egraph[id].nodes.retain(|n|n.is_leaf()); } } } fn is_nonzero_const(var: &str) -> impl Fn(&mut EGraph, Id, &Subst) -> bool { let var: Var = var.parse().unwrap(); move |egraph, _, subst| { egraph[subst[var]].data.as_ref().filter(|x|*x != &RATIONAL_ZERO).is_some() } } fn is_const(var: &str) -> impl Fn(&mut EGraph, Id, &Subst) -> bool { let var: Var = var.parse().unwrap(); move |egraph, _, subst| { egraph[subst[var]].data.is_some() } } struct IntegerSqrt { var: Var, } impl Applier for IntegerSqrt { fn apply_one(&self, egraph: &mut EGraph, matched_id: Id, subst: &Subst, searcher_pattern: Option<&PatternAst>, rule_name: Symbol) -> Vec { let var_id = subst[self.var]; if let Some(value) = &egraph[var_id].data { if value.denom == 1 && value.num >= 0 { // isqrt is nightly only, so we just do this, adding 0.1 against rounding errors let sq = (f64::sqrt(value.num as f64) + 0.1) as i64; if value.num == sq*sq { // println!("square root of integer {} is {}", value.num, sq); let sq_id = egraph.add(EquationLanguage::Num(Rational { num: sq, denom: 1 })); egraph.union(matched_id, sq_id); return vec![matched_id, sq_id]; } } } vec![] } } pub static RULES: LazyLock> = LazyLock::new(||vec![ rw!("commute-add"; "(+ ?x ?y)" => "(+ ?y ?x)"), rw!("commute-mul"; "(* ?x ?y)" => "(* ?y ?x)"), rw!("assoc-add"; "(+ ?x (+ ?y ?z))" => "(+ (+ ?x ?y) ?z)"), rw!("assoc-mul"; "(* ?x (* ?y ?z))" => "(* (* ?x ?y) ?z)"), rw!("add-0"; "(+ ?x 0)" => "?x"), rw!("mul-0"; "(* ?x 0)" => "0"), rw!("mul-1"; "(* ?x 1)" => "?x"), rw!("0-sub"; "(- 0 ?x)" => "(- ?x)"), rw!("add-sub"; "(+ ?x (* (-1) ?x))" => "0"), // division by zero shouldn't happen unless input is invalid rw!("mul-div"; "(* ?x (rec ?x))" => "1" if is_nonzero_const("?y")), rw!("distribute"; "(* (+ ?x ?y) ?z)" => "(+ (* ?x ?z) (* ?y ?z))"), rw!("factor"; "(+ (* ?x ?z) (* ?y ?z))" => "(* (+ ?x ?y) ?z)"), rw!("square"; "(^ ?x 2)" => "(* ?x ?x)"), rw!("cube"; "(^ ?x 3)" => "(* ?x (* ?x ?x))"), rw!("sub"; "(- ?x ?y)" => "(+ ?x (* -1 ?y))"), rw!("neg"; "(- ?x)" => "(* -1 ?x)"), // division by zero shouldn't happen unless input is invalid rw!("div"; "(/ ?x ?y)" => "(* ?x (rec ?y))" if is_nonzero_const("?y")), rw!("factor_poly"; "(+ (* x ?x) ?y)" => "(* ?x (+ x (* ?y (rec ?x))))" if is_nonzero_const("?x")), rw!("integer_sqrt"; "(^ ?x (/ 1 2))" => { IntegerSqrt { var: "?x".parse().unwrap() } } if is_const("?x")), ]); pub struct PlusTimesCostFn; impl egg::CostFunction for PlusTimesCostFn { type Cost = usize; fn cost(&mut self, enode: &EquationLanguage, mut costs: C) -> usize where C: FnMut(Id) -> usize, { let op_cost = match enode { EquationLanguage::Div(_) => 1000, EquationLanguage::Sub(_) => 1000, EquationLanguage::Neg(_) => 1000, EquationLanguage::Reciprocal(_) => 1000, EquationLanguage::Power(_) => 1000, _ => 1, }; enode.fold(op_cost, |sum, i| sum + costs(i)) } }