From c110dd6889b6eda30eca114072e17caf98bf2453 Mon Sep 17 00:00:00 2001 From: Florian Stecker Date: Wed, 28 Aug 2024 12:07:03 -0400 Subject: [PATCH] all examples work --- src/language.rs | 57 +++++++++++++++++++++++++++++++++++++++++++------ 1 file changed, 51 insertions(+), 6 deletions(-) diff --git a/src/language.rs b/src/language.rs index 04b8463..c3e430a 100644 --- a/src/language.rs +++ b/src/language.rs @@ -1,5 +1,5 @@ -use std::{cmp::Ordering, fmt::{self,Display, Formatter}, str::FromStr, sync::LazyLock}; -use egg::{define_language, merge_option, rewrite as rw, Analysis, DidMerge, Id, Language, Subst, Var}; +use std::{cmp::Ordering, fmt::{self,Display, Formatter}, num::ParseIntError, str::FromStr, sync::LazyLock}; +use egg::{define_language, merge_option, rewrite as rw, Analysis, Applier, DidMerge, Id, Language, PatternAst, Subst, Symbol, SymbolLang, Var}; pub type EGraph = egg::EGraph; pub type Rewrite = egg::Rewrite; @@ -39,9 +39,19 @@ impl Display for Rational { } impl FromStr for Rational { - type Err = std::num::ParseIntError; - fn from_str(s: &str) -> Result { - Ok(Rational { num: s.parse::()?, denom: 1 }) + type Err = String; + fn from_str(s: &str) -> Result { + let err = || Err(format!("Couldn't parse rational: {}", s)); + + if let Ok(num) = s.parse::() { + Ok(Rational { num, denom: 1 }) + } else if let Some((snum, sdenom)) = s.split_once('/') { + let Ok(num) = snum.parse::() else { return err(); }; + let Ok(denom) = sdenom.parse::() else { return err(); }; + Ok(Rational { num, denom }) + } else { + err() + } } } @@ -174,6 +184,39 @@ fn is_const(var: &str) -> impl Fn(&mut EGraph, Id, &Subst) -> bool { } } +struct IntegerSqrt { + var: Var, +} + +impl Applier for IntegerSqrt { + fn apply_one(&self, + egraph: &mut EGraph, + matched_id: Id, + subst: &Subst, + searcher_pattern: Option<&PatternAst>, + rule_name: Symbol) + -> Vec { + let var_id = subst[self.var]; + + if let Some(value) = &egraph[var_id].data { + if value.denom == 1 && value.num >= 0 { + // isqrt is nightly only, so we just do this, adding 0.1 against rounding errors + let sq = (f64::sqrt(value.num as f64) + 0.1) as i64; + if value.num == sq*sq { +// println!("square root of integer {} is {}", value.num, sq); + + let sq_id = egraph.add(EquationLanguage::Num(Rational { num: sq, denom: 1 })); + egraph.union(matched_id, sq_id); + + return vec![matched_id, sq_id]; + } + } + } + + vec![] + } +} + pub static RULES: LazyLock> = LazyLock::new(||vec![ rw!("commute-add"; "(+ ?x ?y)" => "(+ ?y ?x)"), rw!("commute-mul"; "(* ?x ?y)" => "(* ?y ?x)"), @@ -185,6 +228,8 @@ pub static RULES: LazyLock> = LazyLock::new(||vec![ rw!("mul-0"; "(* ?x 0)" => "0"), rw!("mul-1"; "(* ?x 1)" => "?x"), + rw!("0-sub"; "(- 0 ?x)" => "(- ?x)"), + rw!("add-sub"; "(+ ?x (* (-1) ?x))" => "0"), // division by zero shouldn't happen unless input is invalid rw!("mul-div"; "(* ?x (rec ?x))" => "1" if is_nonzero_const("?y")), @@ -202,7 +247,7 @@ pub static RULES: LazyLock> = LazyLock::new(||vec![ rw!("factor_poly"; "(+ (* x ?x) ?y)" => "(* ?x (+ x (* ?y (rec ?x))))" if is_nonzero_const("?x")), -// rw!("integer_sqrt"; "(^ ?x (1/2))" => {} if is_const("?x")), + rw!("integer_sqrt"; "(^ ?x (/ 1 2))" => { IntegerSqrt { var: "?x".parse().unwrap() } } if is_const("?x")), ]); pub struct PlusTimesCostFn;