all examples work
This commit is contained in:
parent
ef2a76869f
commit
c110dd6889
@ -1,5 +1,5 @@
|
|||||||
use std::{cmp::Ordering, fmt::{self,Display, Formatter}, str::FromStr, sync::LazyLock};
|
use std::{cmp::Ordering, fmt::{self,Display, Formatter}, num::ParseIntError, str::FromStr, sync::LazyLock};
|
||||||
use egg::{define_language, merge_option, rewrite as rw, Analysis, DidMerge, Id, Language, Subst, Var};
|
use egg::{define_language, merge_option, rewrite as rw, Analysis, Applier, DidMerge, Id, Language, PatternAst, Subst, Symbol, SymbolLang, Var};
|
||||||
|
|
||||||
pub type EGraph = egg::EGraph<EquationLanguage, ConstantFold>;
|
pub type EGraph = egg::EGraph<EquationLanguage, ConstantFold>;
|
||||||
pub type Rewrite = egg::Rewrite<EquationLanguage, ConstantFold>;
|
pub type Rewrite = egg::Rewrite<EquationLanguage, ConstantFold>;
|
||||||
@ -39,9 +39,19 @@ impl Display for Rational {
|
|||||||
}
|
}
|
||||||
|
|
||||||
impl FromStr for Rational {
|
impl FromStr for Rational {
|
||||||
type Err = std::num::ParseIntError;
|
type Err = String;
|
||||||
fn from_str(s: &str) -> Result<Self, Self::Err> {
|
fn from_str(s: &str) -> Result<Self, String> {
|
||||||
Ok(Rational { num: s.parse::<i64>()?, denom: 1 })
|
let err = || Err(format!("Couldn't parse rational: {}", s));
|
||||||
|
|
||||||
|
if let Ok(num) = s.parse::<i64>() {
|
||||||
|
Ok(Rational { num, denom: 1 })
|
||||||
|
} else if let Some((snum, sdenom)) = s.split_once('/') {
|
||||||
|
let Ok(num) = snum.parse::<i64>() else { return err(); };
|
||||||
|
let Ok(denom) = sdenom.parse::<u64>() else { return err(); };
|
||||||
|
Ok(Rational { num, denom })
|
||||||
|
} else {
|
||||||
|
err()
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -174,6 +184,39 @@ fn is_const(var: &str) -> impl Fn(&mut EGraph, Id, &Subst) -> bool {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
struct IntegerSqrt {
|
||||||
|
var: Var,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl Applier<EquationLanguage, ConstantFold> for IntegerSqrt {
|
||||||
|
fn apply_one(&self,
|
||||||
|
egraph: &mut EGraph,
|
||||||
|
matched_id: Id,
|
||||||
|
subst: &Subst,
|
||||||
|
searcher_pattern: Option<&PatternAst<EquationLanguage>>,
|
||||||
|
rule_name: Symbol)
|
||||||
|
-> Vec<Id> {
|
||||||
|
let var_id = subst[self.var];
|
||||||
|
|
||||||
|
if let Some(value) = &egraph[var_id].data {
|
||||||
|
if value.denom == 1 && value.num >= 0 {
|
||||||
|
// isqrt is nightly only, so we just do this, adding 0.1 against rounding errors
|
||||||
|
let sq = (f64::sqrt(value.num as f64) + 0.1) as i64;
|
||||||
|
if value.num == sq*sq {
|
||||||
|
// println!("square root of integer {} is {}", value.num, sq);
|
||||||
|
|
||||||
|
let sq_id = egraph.add(EquationLanguage::Num(Rational { num: sq, denom: 1 }));
|
||||||
|
egraph.union(matched_id, sq_id);
|
||||||
|
|
||||||
|
return vec![matched_id, sq_id];
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
vec![]
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
pub static RULES: LazyLock<Vec<Rewrite>> = LazyLock::new(||vec![
|
pub static RULES: LazyLock<Vec<Rewrite>> = LazyLock::new(||vec![
|
||||||
rw!("commute-add"; "(+ ?x ?y)" => "(+ ?y ?x)"),
|
rw!("commute-add"; "(+ ?x ?y)" => "(+ ?y ?x)"),
|
||||||
rw!("commute-mul"; "(* ?x ?y)" => "(* ?y ?x)"),
|
rw!("commute-mul"; "(* ?x ?y)" => "(* ?y ?x)"),
|
||||||
@ -185,6 +228,8 @@ pub static RULES: LazyLock<Vec<Rewrite>> = LazyLock::new(||vec![
|
|||||||
rw!("mul-0"; "(* ?x 0)" => "0"),
|
rw!("mul-0"; "(* ?x 0)" => "0"),
|
||||||
rw!("mul-1"; "(* ?x 1)" => "?x"),
|
rw!("mul-1"; "(* ?x 1)" => "?x"),
|
||||||
|
|
||||||
|
rw!("0-sub"; "(- 0 ?x)" => "(- ?x)"),
|
||||||
|
|
||||||
rw!("add-sub"; "(+ ?x (* (-1) ?x))" => "0"),
|
rw!("add-sub"; "(+ ?x (* (-1) ?x))" => "0"),
|
||||||
// division by zero shouldn't happen unless input is invalid
|
// division by zero shouldn't happen unless input is invalid
|
||||||
rw!("mul-div"; "(* ?x (rec ?x))" => "1" if is_nonzero_const("?y")),
|
rw!("mul-div"; "(* ?x (rec ?x))" => "1" if is_nonzero_const("?y")),
|
||||||
@ -202,7 +247,7 @@ pub static RULES: LazyLock<Vec<Rewrite>> = LazyLock::new(||vec![
|
|||||||
|
|
||||||
rw!("factor_poly"; "(+ (* x ?x) ?y)" => "(* ?x (+ x (* ?y (rec ?x))))" if is_nonzero_const("?x")),
|
rw!("factor_poly"; "(+ (* x ?x) ?y)" => "(* ?x (+ x (* ?y (rec ?x))))" if is_nonzero_const("?x")),
|
||||||
|
|
||||||
// rw!("integer_sqrt"; "(^ ?x (1/2))" => {} if is_const("?x")),
|
rw!("integer_sqrt"; "(^ ?x (/ 1 2))" => { IntegerSqrt { var: "?x".parse().unwrap() } } if is_const("?x")),
|
||||||
]);
|
]);
|
||||||
|
|
||||||
pub struct PlusTimesCostFn;
|
pub struct PlusTimesCostFn;
|
||||||
|
Loading…
Reference in New Issue
Block a user