cl_interpret/
pattern.rs

//! Unification algorithm for cl-ast [Pattern]s and [ConValue]s.
//!
//! - [`variables()`] returns a flat list of the symbols bound by a given pattern.
//! - [`substitution()`] unifies a [ConValue] with a pattern and produces a map from each bound name to its value.
5
6use crate::{
7    convalue::ConValue,
8    env::Environment,
9    error::{Error, IResult},
10};
11use cl_ast::{Literal, Pattern, Sym};
12use std::collections::{HashMap, VecDeque};
13
14/// Gets the path variables in the given Pattern
15pub fn variables(pat: &Pattern) -> Vec<&Sym> {
16    fn patvars<'p>(set: &mut Vec<&'p Sym>, pat: &'p Pattern) {
17        match pat {
18            Pattern::Name(name) if name.to_ref() == "_" => {}
19            Pattern::Name(name) => set.push(name),
20            Pattern::Path(_) => {}
21            Pattern::Literal(_) => {}
22            Pattern::Rest(Some(pattern)) => patvars(set, pattern),
23            Pattern::Rest(None) => {}
24            Pattern::Ref(_, pattern) => patvars(set, pattern),
25            Pattern::RangeExc(_, _) => {}
26            Pattern::RangeInc(_, _) => {}
27            Pattern::Tuple(patterns) | Pattern::Array(patterns) => {
28                patterns.iter().for_each(|pat| patvars(set, pat))
29            }
30            Pattern::Struct(_path, items) => {
31                items.iter().for_each(|(name, pat)| match pat {
32                    Some(pat) => patvars(set, pat),
33                    None => set.push(name),
34                });
35            }
36            Pattern::TupleStruct(_path, items) => {
37                items.iter().for_each(|pat| patvars(set, pat));
38            }
39        }
40    }
41    let mut set = Vec::new();
42    patvars(&mut set, pat);
43    set
44}
45
46fn rest_binding<'pat>(
47    env: &Environment,
48    sub: &mut HashMap<Sym, ConValue>,
49    mut patterns: &'pat [Pattern],
50    mut values: VecDeque<ConValue>,
51) -> IResult<Option<(&'pat Pattern, VecDeque<ConValue>)>> {
52    // Bind the head of the list
53    while let [pattern, tail @ ..] = patterns {
54        if matches!(pattern, Pattern::Rest(_)) {
55            break;
56        }
57        let value = values
58            .pop_front()
59            .ok_or_else(|| Error::PatFailed(Box::new(pattern.clone())))?;
60        append_sub(env, sub, pattern, value)?;
61        patterns = tail;
62    }
63    // Bind the tail of the list
64    while let [head @ .., pattern] = patterns {
65        if matches!(pattern, Pattern::Rest(_)) {
66            break;
67        }
68        let value = values
69            .pop_back()
70            .ok_or_else(|| Error::PatFailed(Box::new(pattern.clone())))?;
71        append_sub(env, sub, pattern, value)?;
72        patterns = head;
73    }
74    // Bind the ..rest of the list
75    match patterns {
76        [] | [Pattern::Rest(None)] => Ok(None),
77        [Pattern::Rest(Some(pattern))] => Ok(Some((pattern.as_ref(), values))),
78        _ => Err(Error::PatFailed(Box::new(Pattern::Array(patterns.into())))),
79    }
80}
81
82fn rest_binding_ref<'pat>(
83    env: &Environment,
84    sub: &mut HashMap<Sym, ConValue>,
85    mut patterns: &'pat [Pattern],
86    mut head: usize,
87    mut tail: usize,
88) -> IResult<Option<(&'pat Pattern, usize, usize)>> {
89    // Bind the head of the list
90    while let [pattern, pat_tail @ ..] = patterns {
91        if matches!(pattern, Pattern::Rest(_)) {
92            break;
93        }
94        if head >= tail {
95            return Err(Error::PatFailed(Box::new(pattern.clone())));
96        }
97
98        append_sub(env, sub, pattern, ConValue::Ref(head))?;
99        head += 1;
100        patterns = pat_tail;
101    }
102    // Bind the tail of the list
103    while let [pat_head @ .., pattern] = patterns {
104        if matches!(pattern, Pattern::Rest(_)) {
105            break;
106        }
107        if head >= tail {
108            return Err(Error::PatFailed(Box::new(pattern.clone())));
109        };
110        append_sub(env, sub, pattern, ConValue::Ref(tail))?;
111        tail -= 1;
112        patterns = pat_head;
113    }
114    // Bind the ..rest of the list
115    match (patterns, tail - head) {
116        ([], 0) | ([Pattern::Rest(None)], _) => Ok(None),
117        ([Pattern::Rest(Some(pattern))], _) => Ok(Some((pattern.as_ref(), head, tail))),
118        _ => Err(Error::PatFailed(Box::new(Pattern::Array(patterns.into())))),
119    }
120}
121
122/// Appends a substitution to the provided table
123pub fn append_sub(
124    env: &Environment,
125    sub: &mut HashMap<Sym, ConValue>,
126    pat: &Pattern,
127    value: ConValue,
128) -> IResult<()> {
129    match (pat, value) {
130        (Pattern::Literal(Literal::Bool(a)), ConValue::Bool(b)) => {
131            (*a == b).then_some(()).ok_or(Error::NotAssignable())
132        }
133        (Pattern::Literal(Literal::Char(a)), ConValue::Char(b)) => {
134            (*a == b).then_some(()).ok_or(Error::NotAssignable())
135        }
136        (Pattern::Literal(Literal::Float(a)), ConValue::Float(b)) => (f64::from_bits(*a) == b)
137            .then_some(())
138            .ok_or(Error::NotAssignable()),
139        (Pattern::Literal(Literal::Int(a)), ConValue::Int(b)) => {
140            (b == *a as _).then_some(()).ok_or(Error::NotAssignable())
141        }
142        (Pattern::Literal(Literal::String(a)), ConValue::Str(b)) => {
143            (*a == *b).then_some(()).ok_or(Error::NotAssignable())
144        }
145        (Pattern::Literal(Literal::String(a)), ConValue::String(b)) => {
146            (*a == *b).then_some(()).ok_or(Error::NotAssignable())
147        }
148        (Pattern::Literal(_), _) => Err(Error::NotAssignable()),
149
150        (Pattern::Rest(Some(pat)), value) => match (pat.as_ref(), value) {
151            (Pattern::Literal(Literal::Int(a)), ConValue::Int(b)) => {
152                (b < *a as _).then_some(()).ok_or(Error::NotAssignable())
153            }
154            (Pattern::Literal(Literal::Char(a)), ConValue::Char(b)) => {
155                (b < *a as _).then_some(()).ok_or(Error::NotAssignable())
156            }
157            (Pattern::Literal(Literal::Bool(a)), ConValue::Bool(b)) => {
158                (!b & *a).then_some(()).ok_or(Error::NotAssignable())
159            }
160            (Pattern::Literal(Literal::Float(a)), ConValue::Float(b)) => {
161                (b < *a as _).then_some(()).ok_or(Error::NotAssignable())
162            }
163            (Pattern::Literal(Literal::String(a)), ConValue::Str(b)) => {
164                (&*b < a).then_some(()).ok_or(Error::NotAssignable())
165            }
166            (Pattern::Literal(Literal::String(a)), ConValue::String(b)) => {
167                (&*b < a).then_some(()).ok_or(Error::NotAssignable())
168            }
169            _ => Err(Error::NotAssignable()),
170        },
171
172        (Pattern::Name(name), _) if "_".eq(&**name) => Ok(()),
173        (Pattern::Name(name), value) => {
174            sub.insert(*name, value);
175            Ok(())
176        }
177
178        (Pattern::Ref(_, pat), ConValue::Ref(r)) => match env.get_id(r) {
179            Some(value) => append_sub(env, sub, pat, value.clone()),
180            None => Err(Error::PatFailed(pat.clone())),
181        },
182
183        (Pattern::Ref(_, pat), ConValue::Slice(head, len)) => {
184            let mut values = Vec::with_capacity(len);
185            for idx in head..(head + len) {
186                values.push(env.get_id(idx).cloned().ok_or(Error::StackOverflow(idx))?);
187            }
188            append_sub(env, sub, pat, ConValue::Array(values.into_boxed_slice()))
189        }
190
191        (Pattern::RangeExc(head, tail), value) => match (head.as_ref(), tail.as_ref(), value) {
192            (
193                Pattern::Literal(Literal::Int(a)),
194                Pattern::Literal(Literal::Int(c)),
195                ConValue::Int(b),
196            ) => (*a as isize <= b as _ && b < *c as isize)
197                .then_some(())
198                .ok_or(Error::NotAssignable()),
199            (
200                Pattern::Literal(Literal::Char(a)),
201                Pattern::Literal(Literal::Char(c)),
202                ConValue::Char(b),
203            ) => (*a <= b && b < *c)
204                .then_some(())
205                .ok_or(Error::NotAssignable()),
206            (
207                Pattern::Literal(Literal::Float(a)),
208                Pattern::Literal(Literal::Float(c)),
209                ConValue::Float(b),
210            ) => (f64::from_bits(*a) <= b && b < f64::from_bits(*c))
211                .then_some(())
212                .ok_or(Error::NotAssignable()),
213            (
214                Pattern::Literal(Literal::String(a)),
215                Pattern::Literal(Literal::String(c)),
216                ConValue::Str(b),
217            ) => (a.as_str() <= b.to_ref() && b.to_ref() < c.as_str())
218                .then_some(())
219                .ok_or(Error::NotAssignable()),
220            (
221                Pattern::Literal(Literal::String(a)),
222                Pattern::Literal(Literal::String(c)),
223                ConValue::String(b),
224            ) => (a.as_str() <= b.as_str() && b.as_str() < c.as_str())
225                .then_some(())
226                .ok_or(Error::NotAssignable()),
227            _ => Err(Error::NotAssignable()),
228        },
229
230        (Pattern::RangeInc(head, tail), value) => match (head.as_ref(), tail.as_ref(), value) {
231            (
232                Pattern::Literal(Literal::Int(a)),
233                Pattern::Literal(Literal::Int(c)),
234                ConValue::Int(b),
235            ) => (*a as isize <= b && b <= *c as isize)
236                .then_some(())
237                .ok_or(Error::NotAssignable()),
238            (
239                Pattern::Literal(Literal::Char(a)),
240                Pattern::Literal(Literal::Char(c)),
241                ConValue::Char(b),
242            ) => (*a <= b && b <= *c)
243                .then_some(())
244                .ok_or(Error::NotAssignable()),
245            (
246                Pattern::Literal(Literal::Float(a)),
247                Pattern::Literal(Literal::Float(c)),
248                ConValue::Float(b),
249            ) => (f64::from_bits(*a) <= b && b <= f64::from_bits(*c))
250                .then_some(())
251                .ok_or(Error::NotAssignable()),
252            (
253                Pattern::Literal(Literal::String(a)),
254                Pattern::Literal(Literal::String(c)),
255                ConValue::Str(b),
256            ) => (a.as_str() <= b.to_ref() && b.to_ref() <= c.as_str())
257                .then_some(())
258                .ok_or(Error::NotAssignable()),
259            (
260                Pattern::Literal(Literal::String(a)),
261                Pattern::Literal(Literal::String(c)),
262                ConValue::String(b),
263            ) => (a.as_str() <= b.as_str() && b.as_str() <= c.as_str())
264                .then_some(())
265                .ok_or(Error::NotAssignable()),
266            _ => Err(Error::NotAssignable()),
267        },
268
269        (Pattern::Array(patterns), ConValue::Array(values)) => {
270            match rest_binding(env, sub, patterns, values.into_vec().into())? {
271                Some((pattern, values)) => {
272                    append_sub(env, sub, pattern, ConValue::Array(Vec::from(values).into()))
273                }
274                _ => Ok(()),
275            }
276        }
277
278        (Pattern::Array(patterns), ConValue::Slice(head, len)) => {
279            match rest_binding_ref(env, sub, patterns, head, head + len)? {
280                Some((pat, head, tail)) => {
281                    append_sub(env, sub, pat, ConValue::Slice(head, tail - head))
282                }
283                None => Ok(()),
284            }
285        }
286
287        (Pattern::Tuple(patterns), ConValue::Empty) if patterns.is_empty() => Ok(()),
288        (Pattern::Tuple(patterns), ConValue::Tuple(values)) => {
289            match rest_binding(env, sub, patterns, values.into_vec().into())? {
290                Some((pattern, values)) => {
291                    append_sub(env, sub, pattern, ConValue::Tuple(Vec::from(values).into()))
292                }
293                _ => Ok(()),
294            }
295        }
296
297        (Pattern::TupleStruct(path, patterns), ConValue::TupleStruct(id, values)) => {
298            let tid = path
299                .as_sym()
300                .ok_or_else(|| Error::PatFailed(pat.clone().into()))?;
301            if id != tid {
302                return Err(Error::PatFailed(pat.clone().into()));
303            }
304            match rest_binding(env, sub, patterns, values.into_vec().into())? {
305                Some((pattern, values)) => {
306                    append_sub(env, sub, pattern, ConValue::Tuple(Vec::from(values).into()))
307                }
308                _ => Ok(()),
309            }
310        }
311
312        (Pattern::Struct(path, patterns), ConValue::Struct(id, mut values)) => {
313            let tid = path
314                .as_sym()
315                .ok_or_else(|| Error::PatFailed(pat.clone().into()))?;
316            if id != tid {
317                return Err(Error::PatFailed(pat.clone().into()));
318            }
319            for (name, pat) in patterns {
320                let value = values.remove(name).ok_or(Error::TypeError())?;
321                match pat {
322                    Some(pat) => append_sub(env, sub, pat, value)?,
323                    None => {
324                        sub.insert(*name, value);
325                    }
326                }
327            }
328            Ok(())
329        }
330
331        _ => {
332            // eprintln!("Could not match pattern `{pat}` with value `{value}`!");
333            Err(Error::NotAssignable())
334        }
335    }
336}
337
338/// Constructs a substitution from a pattern and a value
339pub fn substitution(
340    env: &Environment,
341    pat: &Pattern,
342    value: ConValue,
343) -> IResult<HashMap<Sym, ConValue>> {
344    let mut sub = HashMap::new();
345    append_sub(env, &mut sub, pat, value)?;
346    Ok(sub)
347}