cl_ast/desugar/constant_folder.rs

use std::num::Wrapping;

use crate::{
    ast::{ExprKind as Ek, *},
    ast_visitor::{Fold, fold::or_fold_expr_kind},
};
5
/// An AST [`Fold`] pass that evaluates operators whose operands are literals,
/// replacing e.g. `1 + 2` with the literal `3`.
///
/// The folder carries no state, so it is freely copyable and default-constructible.
#[derive(Clone, Copy, Debug, Default)]
pub struct ConstantFolder;
7
8macro bin_rule(
9    match ($kind: ident, $head: expr, $tail: expr) {
10        $(($op:ident, $impl:expr, $($ty:ident -> $rety:ident),*)),*$(,)?
11    }
12) {
13    #[allow(clippy::all)]
14    match ($kind, $head, $tail) {
15    $($((   BinaryKind::$op,
16            Expr { kind: ExprKind::Literal(Literal::$ty(a)), .. },
17            Expr { kind: ExprKind::Literal(Literal::$ty(b)), .. },
18        ) => {
19            ExprKind::Literal(Literal::$rety($impl(a, b)))
20        },)*)*
21        (kind, head, tail) => ExprKind::Binary(Binary {
22            kind,
23            parts: Box::new((head, tail)),
24        }),
25    }
26}
27
28macro un_rule(
29    match ($kind: ident, $tail: expr) {
30        $(($op:ident, $impl:expr, $($ty:ident),*)),*$(,)?
31    }
32) {
33    match ($kind, $tail) {
34    $($((UnaryKind::$op, Expr { kind: ExprKind::Literal(Literal::$ty(v)), .. }) => {
35            ExprKind::Literal(Literal::$ty($impl(v)))
36        },)*)*
37        (kind, tail) => ExprKind::Unary(Unary { kind, tail: Box::new(tail) }),
38    }
39}
40
41impl Fold for ConstantFolder {
42    fn fold_expr_kind(&mut self, kind: Ek) -> Ek {
43        match kind {
44            Ek::Group(Group { expr }) => self.fold_expr_kind(expr.kind),
45            Ek::Binary(Binary { kind, parts }) => {
46                let (head, tail) = *parts;
47                bin_rule! (match (kind, self.fold_expr(head), self.fold_expr(tail)) {
48                    (Lt,     |a, b| a < b,  Bool -> Bool, Int -> Bool),
49                    (LtEq,   |a, b| a <= b, Bool -> Bool, Int -> Bool),
50                    (Equal,  |a, b| a == b, Bool -> Bool, Int -> Bool),
51                    (NotEq,  |a, b| a != b, Bool -> Bool, Int -> Bool),
52                    (GtEq,   |a, b| a >= b, Bool -> Bool, Int -> Bool),
53                    (Gt,     |a, b| a > b,  Bool -> Bool, Int -> Bool),
54                    (BitAnd, |a, b| a & b,  Bool -> Bool, Int -> Int),
55                    (BitOr,  |a, b| a | b,  Bool -> Bool, Int -> Int),
56                    (BitXor, |a, b| a ^ b,  Bool -> Bool, Int -> Int),
57                    (Shl,    |a, b| a << b, Int -> Int),
58                    (Shr,    |a, b| a >> b, Int -> Int),
59                    (Add,    |a, b| a + b,  Int -> Int),
60                    (Sub,    |a, b| a - b,  Int -> Int),
61                    (Mul,    |a, b| a * b,  Int -> Int),
62                    (Div,    |a, b| a / b,  Int -> Int),
63                    (Rem,    |a, b| a % b,  Int -> Int),
64                    // Cursed bit-smuggled float shenanigans
65                    (Lt,     |a, b| (f64::from_bits(a) < f64::from_bits(b)),  Float -> Bool),
66                    (LtEq,   |a, b| (f64::from_bits(a) >= f64::from_bits(b)), Float -> Bool),
67                    (Equal,  |a, b| (f64::from_bits(a) == f64::from_bits(b)), Float -> Bool),
68                    (NotEq,  |a, b| (f64::from_bits(a) != f64::from_bits(b)), Float -> Bool),
69                    (GtEq,   |a, b| (f64::from_bits(a) <= f64::from_bits(b)), Float -> Bool),
70                    (Gt,     |a, b| (f64::from_bits(a) > f64::from_bits(b)),  Float -> Bool),
71                    (Add,    |a, b| (f64::from_bits(a) + f64::from_bits(b)).to_bits(),  Float -> Float),
72                    (Sub,    |a, b| (f64::from_bits(a) - f64::from_bits(b)).to_bits(),  Float -> Float),
73                    (Mul,    |a, b| (f64::from_bits(a) * f64::from_bits(b)).to_bits(),  Float -> Float),
74                    (Div,    |a, b| (f64::from_bits(a) / f64::from_bits(b)).to_bits(),  Float -> Float),
75                    (Rem,    |a, b| (f64::from_bits(a) % f64::from_bits(b)).to_bits(),  Float -> Float),
76                })
77            }
78            Ek::Unary(Unary { kind, tail }) => {
79                un_rule! (match (kind, self.fold_expr(*tail)) {
80                    (Not, std::ops::Not::not, Int, Bool),
81                    (Neg, std::ops::Not::not, Int, Bool),
82                    (Neg, |f| (-f64::from_bits(f)).to_bits(), Float),
83                    (At, std::ops::Not::not, Float), /* Lmao */
84                })
85            }
86            _ => or_fold_expr_kind(self, kind),
87        }
88    }
89}