cl_interpret/function/
collect_upvars.rs

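//! Collects the "upvars" of a function at the point of its creation: the names it references
//! that resolve in the surrounding environment rather than being bound inside the function
//! itself. Recording them allows those variables to be captured.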
2use crate::env::{Environment, Place};
3use cl_ast::{
4    Function, Let, Path, PathPart, Pattern, Sym,
5    ast_visitor::{visit::*, walk::Walk},
6};
7use std::collections::{HashMap, HashSet};
8
9pub fn collect_upvars(f: &Function, env: &Environment) -> super::Upvars {
10    CollectUpvars::new(env).visit(f).finish_copied()
11}
12
13#[derive(Clone, Debug)]
14pub struct CollectUpvars<'env> {
15    env: &'env Environment,
16    upvars: HashMap<Sym, Place>,
17    blacklist: HashSet<Sym>,
18}
19
20impl<'env> CollectUpvars<'env> {
21    pub fn new(env: &'env Environment) -> Self {
22        Self { upvars: HashMap::new(), blacklist: HashSet::new(), env }
23    }
24
25    pub fn finish(&mut self) -> HashMap<Sym, Place> {
26        std::mem::take(&mut self.upvars)
27    }
28
29    pub fn finish_copied(&mut self) -> super::Upvars {
30        let Self { env, upvars, blacklist: _ } = self;
31        std::mem::take(upvars)
32            .into_iter()
33            .map(|(k, v)| (k, env.get_id(v).cloned()))
34            .collect()
35    }
36
37    pub fn add_upvar(&mut self, name: &Sym) {
38        let Self { env, upvars, blacklist } = self;
39        if blacklist.contains(name) || upvars.contains_key(name) {
40            return;
41        }
42        if let Ok(place) = env.id_of(*name) {
43            upvars.insert(*name, place);
44        }
45    }
46
47    pub fn bind_name(&mut self, name: &Sym) {
48        self.blacklist.insert(*name);
49    }
50}
51
52impl<'a> Visit<'a> for CollectUpvars<'_> {
53    fn visit_block(&mut self, b: &'a cl_ast::Block) {
54        let blacklist = self.blacklist.clone();
55
56        // visit the block
57        b.children(self);
58
59        // restore the blacklist
60        self.blacklist = blacklist;
61    }
62
63    fn visit_let(&mut self, l: &'a cl_ast::Let) {
64        let Let { mutable, name, ty, init } = l;
65        self.visit_mutability(mutable);
66
67        ty.visit_in(self);
68        // visit the initializer, which may use the bound name
69        init.visit_in(self);
70        // a bound name can never be an upvar
71        self.visit_pattern(name);
72    }
73
74    fn visit_function(&mut self, f: &'a cl_ast::Function) {
75        let Function { name: _, gens: _, sign: _, bind, body } = f;
76        // parameters can never be upvars
77        bind.visit_in(self);
78        body.visit_in(self);
79    }
80
81    fn visit_for(&mut self, f: &'a cl_ast::For) {
82        let cl_ast::For { bind, cond, pass, fail } = f;
83        self.visit_expr(cond);
84        self.visit_else(fail);
85        self.visit_pattern(bind);
86        self.visit_block(pass);
87    }
88
89    fn visit_path(&mut self, p: &'a cl_ast::Path) {
90        // TODO: path resolution in environments
91        let Path { absolute: false, parts } = p else {
92            return;
93        };
94        let [PathPart::Ident(name)] = parts.as_slice() else {
95            return;
96        };
97        self.add_upvar(name);
98    }
99
100    fn visit_fielder(&mut self, f: &'a cl_ast::Fielder) {
101        let cl_ast::Fielder { name, init } = f;
102        if let Some(init) = init {
103            self.visit_expr(init);
104        } else {
105            self.add_upvar(name); // fielder without init grabs from env
106        }
107    }
108
109    fn visit_pattern(&mut self, p: &'a cl_ast::Pattern) {
110        match p {
111            Pattern::Name(name) => {
112                self.bind_name(name);
113            }
114            Pattern::RangeExc(_, _) | Pattern::RangeInc(_, _) => {}
115            _ => p.children(self),
116        }
117    }
118}