271 lines
7.3 KiB
Rust
271 lines
7.3 KiB
Rust
use std::{cmp::Ordering, fmt::{self,Display, Formatter}, num::ParseIntError, str::FromStr, sync::LazyLock};
|
|
use egg::{define_language, merge_option, rewrite as rw, Analysis, Applier, DidMerge, Id, Language, PatternAst, Subst, Symbol, SymbolLang, Var};
|
|
|
|
/// E-graph over [`EquationLanguage`] terms with the [`ConstantFold`] analysis.
pub type EGraph = egg::EGraph<EquationLanguage, ConstantFold>;
|
|
/// Rewrite rule over [`EquationLanguage`] terms with the [`ConstantFold`] analysis.
pub type Rewrite = egg::Rewrite<EquationLanguage, ConstantFold>;
|
|
|
|
// Syntax of equations in the single unknown `x`, written as s-expressions,
// e.g. `(= (+ (* 2 x) 1) 5)`.
//
// `-` is declared twice on purpose: with two children it is `Sub`, with one
// child it is `Neg` (the rules use both `(- ?x ?y)` and `(- ?x)`).
define_language! {
    pub enum EquationLanguage {
        // The unknown being solved for.
        "x" = Unknown,

        // Binary arithmetic.
        "+" = Add([Id; 2]),
        "-" = Sub([Id; 2]),
        // Unary negation: same symbol as `Sub`, but a single child.
        "-" = Neg([Id; 1]),
        "*" = Mul([Id; 2]),
        "/" = Div([Id; 2]),
        // Exponentiation: (^ base exponent).
        "^" = Power([Id; 2]),

        // An equation: (= lhs rhs).
        "=" = Equals([Id; 2]),

        // Multiplicative inverse: (rec a) is 1/a.
        "rec" = Reciprocal([Id; 1]),

        // Exact rational literal such as `3` or `-1/2`.
        Num(Rational),
    }
}
|
|
|
|
/// An exact rational number `num / denom`, with the sign carried by `num`.
///
/// Values need not be in lowest terms (`2/4` is a valid value equal to `1/2`),
/// so equality, ordering and hashing are defined on the numeric value rather
/// than on the raw fields. A zero `denom` does not denote a number and should
/// not be constructed.
#[derive(Debug, Clone)]
pub struct Rational {
    /// Signed numerator.
    pub num: i64,
    /// Unsigned denominator; should be nonzero.
    pub denom: u64,
}

// A derived `Hash` hashes the raw fields, so `1/2` and `2/4` (equal under the
// value-based `PartialEq`) would hash differently, breaking the
// `a == b => hash(a) == hash(b)` contract that hash-based containers rely on,
// including those keyed by `EquationLanguage`, which embeds `Rational` in `Num`.
// Hash the fraction in lowest terms instead.
impl std::hash::Hash for Rational {
    fn hash<H: std::hash::Hasher>(&self, state: &mut H) {
        let (num, denom) = if self.denom == 0 {
            // Not a number: cross-multiplication makes every `n/0` compare
            // equal to every `m/0`, so they must all share one hash.
            (0_i128, 0_u64)
        } else {
            // gcd(|num|, denom) by Euclid; nonzero because denom != 0.
            let (mut a, mut b) = (self.num.unsigned_abs(), self.denom);
            while b != 0 {
                (a, b) = (b, a % b);
            }
            // Divide in i128: the gcd can be 2^63 when num == i64::MIN.
            (i128::from(self.num) / i128::from(a), self.denom / a)
        };
        num.hash(state);
        denom.hash(state);
    }
}
|
|
|
|
pub const RATIONAL_ZERO: Rational = Rational { num: 0, denom: 1 };
|
|
pub const RATIONAL_ONE: Rational = Rational { num: 1, denom: 1 };
|
|
|
|
impl Display for Rational {
|
|
fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
|
|
if self.denom == 1 {
|
|
write!(f, "{}", self.num)
|
|
} else {
|
|
write!(f, "{}/{}", self.num, self.denom)
|
|
}
|
|
}
|
|
}
|
|
|
|
impl FromStr for Rational {
|
|
type Err = String;
|
|
fn from_str(s: &str) -> Result<Self, String> {
|
|
let err = || Err(format!("Couldn't parse rational: {}", s));
|
|
|
|
if let Ok(num) = s.parse::<i64>() {
|
|
Ok(Rational { num, denom: 1 })
|
|
} else if let Some((snum, sdenom)) = s.split_once('/') {
|
|
let Ok(num) = snum.parse::<i64>() else { return err(); };
|
|
let Ok(denom) = sdenom.parse::<u64>() else { return err(); };
|
|
Ok(Rational { num, denom })
|
|
} else {
|
|
err()
|
|
}
|
|
}
|
|
}
|
|
|
|
impl PartialEq for Rational {
    /// Value equality by cross-multiplication: `a/b == c/d` iff `a*d == c*b`,
    /// so `1/2 == 2/4`. Both products are formed in `i128`, which holds any
    /// `i64 * u64` exactly.
    fn eq(&self, other: &Rational) -> bool {
        let lhs = i128::from(self.num) * i128::from(other.denom);
        let rhs = i128::from(other.num) * i128::from(self.denom);
        lhs == rhs
    }
}
|
|
|
|
// `Eq` holds only for nonzero denominators: the cross-multiplying `PartialEq`
// makes every `n/0` equal to every `m/0`, and `0/0` equal to every value,
// which is not transitive.
impl Eq for Rational {}
|
|
|
|
impl PartialOrd for Rational {
|
|
fn partial_cmp(&self, other: &Rational) -> Option<Ordering> {
|
|
i128::partial_cmp(
|
|
&((self.num as i128) * (other.denom as i128)),
|
|
&((other.num as i128) * (other.denom as i128))
|
|
)
|
|
}
|
|
}
|
|
|
|
impl Ord for Rational {
|
|
fn cmp(&self, other: &Rational) -> Ordering {
|
|
i128::cmp(
|
|
&((self.num as i128) * (other.denom as i128)),
|
|
&((other.num as i128) * (other.denom as i128))
|
|
)
|
|
}
|
|
}
|
|
|
|
impl Rational {
|
|
fn simplify(&mut self) {
|
|
let mut a = self.num.abs() as u64;
|
|
let mut b = self.denom;
|
|
|
|
if a > b {
|
|
(a, b) = (b, a);
|
|
}
|
|
|
|
while a > 0 {
|
|
(a, b) = (b % a, a);
|
|
}
|
|
|
|
self.num /= b as i64;
|
|
self.denom /= b;
|
|
}
|
|
}
|
|
|
|
// constant folding code essentially comes from egg examples, except using rationals instead of floats
|
|
#[derive(Default)]
|
|
pub struct ConstantFold;
|
|
impl Analysis<EquationLanguage> for ConstantFold {
|
|
type Data = Option<Rational>;
|
|
|
|
fn make(egraph: &EGraph, enode: &EquationLanguage) -> Self::Data {
|
|
let x = |i: &Id| -> Self::Data { egraph[*i].data.clone() };
|
|
|
|
let mut value = match enode {
|
|
EquationLanguage::Num(c) => c.clone(),
|
|
EquationLanguage::Add([a,b]) => Rational {
|
|
num: x(a)?.num * x(b)?.denom as i64 + x(a)?.denom as i64 * x(b)?.num,
|
|
denom: x(a)?.denom * x(b)?.denom
|
|
},
|
|
EquationLanguage::Sub([a,b]) => Rational {
|
|
num: x(a)?.num * x(b)?.denom as i64 - x(a)?.denom as i64 * x(b)?.num,
|
|
denom: x(a)?.denom * x(b)?.denom
|
|
},
|
|
EquationLanguage::Mul([a,b]) => Rational {
|
|
num: x(a)?.num * x(b)?.num,
|
|
denom: x(a)?.denom * x(b)?.denom
|
|
},
|
|
EquationLanguage::Div([a,b]) => {
|
|
if x(b)?.num == 0 {
|
|
return None;
|
|
} else if x(b)?.num > 0 {
|
|
Rational {
|
|
num: x(a)?.num * x(b)?.denom as i64,
|
|
denom: x(b)?.num as u64 * x(a)?.denom,
|
|
}
|
|
} else {
|
|
Rational {
|
|
num: - x(a)?.num * x(b)?.denom as i64,
|
|
denom: (-x(b)?.num) as u64 * x(a)?.denom,
|
|
}
|
|
}
|
|
},
|
|
EquationLanguage::Neg([a]) => Rational {
|
|
num: -x(a)?.num,
|
|
denom: x(a)?.denom,
|
|
},
|
|
EquationLanguage::Reciprocal([a]) => Rational {
|
|
num: if x(a)?.num > 0 { x(a)?.denom as i64 } else { - (x(a)?.denom as i64) },
|
|
denom: x(a)?.num.abs() as u64,
|
|
},
|
|
_ => return None,
|
|
};
|
|
|
|
value.simplify();
|
|
|
|
Some(value)
|
|
}
|
|
|
|
fn merge(&mut self, to: &mut Self::Data, from: Self::Data) -> DidMerge {
|
|
merge_option(to, from, |a, b| {
|
|
assert!(a == &b, "Merged non-equal constants");
|
|
DidMerge(false, false)
|
|
})
|
|
}
|
|
|
|
fn modify(egraph: &mut EGraph, id: Id) {
|
|
let data = egraph[id].data.clone();
|
|
if let Some(c) = data {
|
|
let added = egraph.add(EquationLanguage::Num(c));
|
|
egraph.union(id, added);
|
|
egraph[id].nodes.retain(|n|n.is_leaf());
|
|
}
|
|
}
|
|
}
|
|
|
|
fn is_nonzero_const(var: &str) -> impl Fn(&mut EGraph, Id, &Subst) -> bool {
|
|
let var: Var = var.parse().unwrap();
|
|
move |egraph, _, subst| {
|
|
egraph[subst[var]].data.as_ref().filter(|x|*x != &RATIONAL_ZERO).is_some()
|
|
}
|
|
}
|
|
|
|
fn is_const(var: &str) -> impl Fn(&mut EGraph, Id, &Subst) -> bool {
|
|
let var: Var = var.parse().unwrap();
|
|
move |egraph, _, subst| {
|
|
egraph[subst[var]].data.is_some()
|
|
}
|
|
}
|
|
|
|
struct IntegerSqrt {
|
|
var: Var,
|
|
}
|
|
|
|
impl Applier<EquationLanguage, ConstantFold> for IntegerSqrt {
|
|
fn apply_one(&self,
|
|
egraph: &mut EGraph,
|
|
matched_id: Id,
|
|
subst: &Subst,
|
|
searcher_pattern: Option<&PatternAst<EquationLanguage>>,
|
|
rule_name: Symbol)
|
|
-> Vec<Id> {
|
|
let var_id = subst[self.var];
|
|
|
|
if let Some(value) = &egraph[var_id].data {
|
|
if value.denom == 1 && value.num >= 0 {
|
|
// isqrt is nightly only, so we just do this, adding 0.1 against rounding errors
|
|
let sq = (f64::sqrt(value.num as f64) + 0.1) as i64;
|
|
if value.num == sq*sq {
|
|
// println!("square root of integer {} is {}", value.num, sq);
|
|
|
|
let sq_id = egraph.add(EquationLanguage::Num(Rational { num: sq, denom: 1 }));
|
|
egraph.union(matched_id, sq_id);
|
|
|
|
return vec![matched_id, sq_id];
|
|
}
|
|
}
|
|
}
|
|
|
|
vec![]
|
|
}
|
|
}
|
|
|
|
pub static RULES: LazyLock<Vec<Rewrite>> = LazyLock::new(||vec![
|
|
rw!("commute-add"; "(+ ?x ?y)" => "(+ ?y ?x)"),
|
|
rw!("commute-mul"; "(* ?x ?y)" => "(* ?y ?x)"),
|
|
|
|
rw!("assoc-add"; "(+ ?x (+ ?y ?z))" => "(+ (+ ?x ?y) ?z)"),
|
|
rw!("assoc-mul"; "(* ?x (* ?y ?z))" => "(* (* ?x ?y) ?z)"),
|
|
|
|
rw!("add-0"; "(+ ?x 0)" => "?x"),
|
|
rw!("mul-0"; "(* ?x 0)" => "0"),
|
|
rw!("mul-1"; "(* ?x 1)" => "?x"),
|
|
|
|
rw!("0-sub"; "(- 0 ?x)" => "(- ?x)"),
|
|
|
|
rw!("add-sub"; "(+ ?x (* (-1) ?x))" => "0"),
|
|
// division by zero shouldn't happen unless input is invalid
|
|
rw!("mul-div"; "(* ?x (rec ?x))" => "1" if is_nonzero_const("?y")),
|
|
|
|
rw!("distribute"; "(* (+ ?x ?y) ?z)" => "(+ (* ?x ?z) (* ?y ?z))"),
|
|
rw!("factor"; "(+ (* ?x ?z) (* ?y ?z))" => "(* (+ ?x ?y) ?z)"),
|
|
|
|
rw!("square"; "(^ ?x 2)" => "(* ?x ?x)"),
|
|
rw!("cube"; "(^ ?x 3)" => "(* ?x (* ?x ?x))"),
|
|
|
|
rw!("sub"; "(- ?x ?y)" => "(+ ?x (* -1 ?y))"),
|
|
rw!("neg"; "(- ?x)" => "(* -1 ?x)"),
|
|
// division by zero shouldn't happen unless input is invalid
|
|
rw!("div"; "(/ ?x ?y)" => "(* ?x (rec ?y))" if is_nonzero_const("?y")),
|
|
|
|
rw!("factor_poly"; "(+ (* x ?x) ?y)" => "(* ?x (+ x (* ?y (rec ?x))))" if is_nonzero_const("?x")),
|
|
|
|
rw!("integer_sqrt"; "(^ ?x (/ 1 2))" => { IntegerSqrt { var: "?x".parse().unwrap() } } if is_const("?x")),
|
|
]);
|
|
|
|
pub struct PlusTimesCostFn;
|
|
impl egg::CostFunction<EquationLanguage> for PlusTimesCostFn {
|
|
type Cost = usize;
|
|
fn cost<C>(&mut self, enode: &EquationLanguage, mut costs: C) -> usize
|
|
where
|
|
C: FnMut(Id) -> usize,
|
|
{
|
|
let op_cost = match enode {
|
|
EquationLanguage::Div(_) => 1000,
|
|
EquationLanguage::Sub(_) => 1000,
|
|
EquationLanguage::Neg(_) => 1000,
|
|
EquationLanguage::Reciprocal(_) => 1000,
|
|
EquationLanguage::Power(_) => 1000,
|
|
_ => 1,
|
|
};
|
|
enode.fold(op_cost, |sum, i| sum + costs(i))
|
|
}
|
|
}
|