cl_interpret/function/
collect_upvars.rs

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
//! Collects the "upvars" (free variables captured from the enclosing environment) of a function at the point of its creation, enabling variable capture.
use crate::{convalue::ConValue, env::Environment};
use cl_ast::{ast_visitor::visit::*, Function, Let, Param, Path, PathPart, Pattern, Sym};
use std::collections::{HashMap, HashSet};

/// Collects the upvars of `f`, resolving each captured name against `env`.
///
/// Parameters of `f` and names bound inside its body are never captured.
pub fn collect_upvars(f: &Function, env: &Environment) -> super::Upvars {
    let collector = CollectUpvars::new(env);
    collector.get_upvars(f)
}

/// Walks a [`Function`] and records which names it reads from its environment.
///
/// Names bound inside the function (parameters, `let` patterns, `for` loop
/// variables) are tracked in a blacklist so they are never captured.
#[derive(Clone, Debug)]
pub struct CollectUpvars<'env> {
    /// The environment the function is being created in; upvar values are
    /// looked up here via `get_local`.
    env: &'env Environment,
    /// Captured names and their values at collection time (inserted as `Some`).
    upvars: HashMap<Sym, Option<ConValue>>,
    /// Names bound in the current scope, which shadow environment locals.
    blacklist: HashSet<Sym>,
}

impl<'env> CollectUpvars<'env> {
    /// Creates an empty collector that resolves upvars against `env`.
    pub fn new(env: &'env Environment) -> Self {
        let upvars = HashMap::new();
        let blacklist = HashSet::new();
        Self { env, upvars, blacklist }
    }

    /// Consumes the collector, walks `f`, and returns every captured upvar.
    pub fn get_upvars(mut self, f: &cl_ast::Function) -> HashMap<Sym, Option<ConValue>> {
        self.visit_function(f);
        let Self { upvars, .. } = self;
        upvars
    }

    /// Captures `name` from the environment.
    ///
    /// Nothing happens if `name` is bound in the current scope, was already
    /// captured, or is not a local of the environment.
    pub fn add_upvar(&mut self, name: &Sym) {
        let shadowed = self.blacklist.contains(name);
        let captured = self.upvars.contains_key(name);
        if shadowed || captured {
            return;
        }
        if let Ok(value) = self.env.get_local(*name) {
            self.upvars.insert(*name, Some(value));
        }
    }

    /// Marks `name` as bound in the current scope, so it is never captured.
    pub fn bind_name(&mut self, name: &Sym) {
        self.blacklist.insert(*name);
    }
}

impl<'a> Visit<'a> for CollectUpvars<'_> {
    /// Visits a block in its own scope: names bound by its statements are
    /// forgotten once the block ends.
    fn visit_block(&mut self, b: &'a cl_ast::Block) {
        let blacklist = self.blacklist.clone();

        // visit the block
        let cl_ast::Block { stmts } = b;
        stmts.iter().for_each(|s| self.visit_stmt(s));

        // restore the blacklist
        self.blacklist = blacklist;
    }

    /// Visits a `let` statement.
    ///
    /// The type and initializer are visited *before* the pattern binds its
    /// names, so `let x = x + 1` still captures the outer `x`.
    fn visit_let(&mut self, l: &'a cl_ast::Let) {
        let Let { mutable, name, ty, init } = l;
        self.visit_mutability(mutable);
        if let Some(ty) = ty {
            self.visit_ty(ty);
        }
        // visit the initializer, which may use the bound name
        if let Some(init) = init {
            self.visit_expr(init)
        }
        // a bound name can never be an upvar
        self.visit_pattern(name);
    }

    /// Visits a function, binding its parameters for the duration of its body.
    ///
    /// The parameter bindings are scoped to the function. After the body is
    /// visited, the blacklist is restored, so the parameters of a *nested*
    /// function do not hide same-named locals used by later statements of the
    /// enclosing function.
    fn visit_function(&mut self, f: &'a cl_ast::Function) {
        let Function { name: _, sign: _, bind, body } = f;
        let blacklist = self.blacklist.clone();
        // parameters can never be upvars
        for Param { mutability: _, name } in bind {
            self.bind_name(name);
        }
        if let Some(body) = body {
            self.visit_expr(body);
        }
        // parameters go out of scope when the function ends
        self.blacklist = blacklist;
    }

    /// Visits a `for` loop.
    ///
    /// The loop variable is bound only inside the loop body (`pass`). The
    /// iterated expression and the `else` branch are visited outside that scope,
    /// and the binding is dropped once the body has been visited.
    fn visit_for(&mut self, f: &'a cl_ast::For) {
        let cl_ast::For { bind, cond, pass, fail } = f;
        self.visit_expr(cond);
        self.visit_else(fail);

        // scope the loop variable to the body so it does not leak past the loop
        let blacklist = self.blacklist.clone();
        self.bind_name(bind);
        self.visit_block(pass);
        self.blacklist = blacklist;
    }

    /// Captures a bare identifier (a relative, single-segment path).
    ///
    /// Absolute and multi-segment paths are not captured.
    fn visit_path(&mut self, p: &'a cl_ast::Path) {
        // TODO: path resolution in environments
        let Path { absolute: false, parts } = p else {
            return;
        };
        let [PathPart::Ident(name)] = parts.as_slice() else {
            return;
        };
        self.add_upvar(name);
    }

    /// Visits a struct-literal field. `name: init` visits `init`, while the
    /// shorthand `name` reads `name` from the environment.
    fn visit_fielder(&mut self, f: &'a cl_ast::Fielder) {
        let cl_ast::Fielder { name, init } = f;
        if let Some(init) = init {
            self.visit_expr(init);
        } else {
            self.add_upvar(name); // fielder without init grabs from env
        }
    }

    /// Binds the names introduced by a pattern, so they are never captured.
    fn visit_pattern(&mut self, p: &'a cl_ast::Pattern) {
        match p {
            Pattern::Path(path) => {
                // only a single identifier introduces a binding
                if let [PathPart::Ident(name)] = path.parts.as_slice() {
                    self.bind_name(name)
                }
            }
            Pattern::Literal(literal) => self.visit_literal(literal),
            Pattern::Ref(mutability, pattern) => {
                self.visit_mutability(mutability);
                self.visit_pattern(pattern);
            }
            Pattern::Tuple(patterns) => {
                patterns.iter().for_each(|p| self.visit_pattern(p));
            }
            Pattern::Array(patterns) => {
                patterns.iter().for_each(|p| self.visit_pattern(p));
            }
            Pattern::Struct(path, items) => {
                self.visit_path(path);
                // NOTE(review): a shorthand field (`Foo { x }`, i.e. `bind == None`)
                // binds `x`, but it is not blacklisted here, so a same-named
                // environment local may still be captured. Confirm the type of
                // the item name and call `bind_name` for it.
                items.iter().for_each(|(_name, bind)| {
                    bind.as_ref().inspect(|bind| {
                        self.visit_pattern(bind);
                    });
                });
            }
        }
    }
}