cl_interpret/
pattern.rs

1//! Unification algorithm for cl-ast [Pattern]s and [ConValue]s
2//!
3//! [`variables()`] returns a flat list of symbols that are bound by a given pattern
//! [`substitution()`] unifies a [ConValue] with a pattern, and produces a map from each bound name to its value
5
6use crate::{
7    convalue::ConValue,
8    error::{Error, IResult},
9};
10use cl_ast::{Literal, Pattern, Sym};
11use std::collections::{HashMap, VecDeque};
12
13/// Gets the path variables in the given Pattern
14pub fn variables(pat: &Pattern) -> Vec<&Sym> {
15    fn patvars<'p>(set: &mut Vec<&'p Sym>, pat: &'p Pattern) {
16        match pat {
17            Pattern::Name(name) if name.to_ref() == "_" => {}
18            Pattern::Name(name) => set.push(name),
19            Pattern::Path(_) => {}
20            Pattern::Literal(_) => {}
21            Pattern::Rest(Some(pattern)) => patvars(set, pattern),
22            Pattern::Rest(None) => {}
23            Pattern::Ref(_, pattern) => patvars(set, pattern),
24            Pattern::RangeExc(_, _) => {}
25            Pattern::RangeInc(_, _) => {}
26            Pattern::Tuple(patterns) | Pattern::Array(patterns) => {
27                patterns.iter().for_each(|pat| patvars(set, pat))
28            }
29            Pattern::Struct(_path, items) => {
30                items.iter().for_each(|(name, pat)| match pat {
31                    Some(pat) => patvars(set, pat),
32                    None => set.push(name),
33                });
34            }
35            Pattern::TupleStruct(_path, items) => {
36                items.iter().for_each(|pat| patvars(set, pat));
37            }
38        }
39    }
40    let mut set = Vec::new();
41    patvars(&mut set, pat);
42    set
43}
44
45fn rest_binding<'pat>(
46    sub: &mut HashMap<&'pat Sym, ConValue>,
47    mut patterns: &'pat [Pattern],
48    mut values: VecDeque<ConValue>,
49) -> IResult<Option<(&'pat Pattern, VecDeque<ConValue>)>> {
50    // Bind the head of the list
51    while let [pattern, tail @ ..] = patterns {
52        if matches!(pattern, Pattern::Rest(_)) {
53            break;
54        }
55        let value = values
56            .pop_front()
57            .ok_or_else(|| Error::PatFailed(Box::new(pattern.clone())))?;
58        append_sub(sub, pattern, value)?;
59        patterns = tail;
60    }
61    // Bind the tail of the list
62    while let [head @ .., pattern] = patterns {
63        if matches!(pattern, Pattern::Rest(_)) {
64            break;
65        }
66        let value = values
67            .pop_back()
68            .ok_or_else(|| Error::PatFailed(Box::new(pattern.clone())))?;
69        append_sub(sub, pattern, value)?;
70        patterns = head;
71    }
72    // Bind the ..rest of the list
73    match patterns {
74        [] | [Pattern::Rest(None)] => Ok(None),
75        [Pattern::Rest(Some(pattern))] => Ok(Some((pattern.as_ref(), values))),
76        _ => Err(Error::PatFailed(Box::new(Pattern::Array(patterns.into())))),
77    }
78}
79
80/// Appends a substitution to the provided table
81pub fn append_sub<'pat>(
82    sub: &mut HashMap<&'pat Sym, ConValue>,
83    pat: &'pat Pattern,
84    value: ConValue,
85) -> IResult<()> {
86    match (pat, value) {
87        (Pattern::Literal(Literal::Bool(a)), ConValue::Bool(b)) => {
88            (*a == b).then_some(()).ok_or(Error::NotAssignable())
89        }
90        (Pattern::Literal(Literal::Char(a)), ConValue::Char(b)) => {
91            (*a == b).then_some(()).ok_or(Error::NotAssignable())
92        }
93        (Pattern::Literal(Literal::Float(a)), ConValue::Float(b)) => (f64::from_bits(*a) == b)
94            .then_some(())
95            .ok_or(Error::NotAssignable()),
96        (Pattern::Literal(Literal::Int(a)), ConValue::Int(b)) => {
97            (b == *a as _).then_some(()).ok_or(Error::NotAssignable())
98        }
99        (Pattern::Literal(Literal::String(a)), ConValue::String(b)) => {
100            (*a == *b).then_some(()).ok_or(Error::NotAssignable())
101        }
102        (Pattern::Literal(_), _) => Err(Error::NotAssignable()),
103
104        (Pattern::Rest(Some(pat)), value) => match (pat.as_ref(), value) {
105            (Pattern::Literal(Literal::Int(a)), ConValue::Int(b)) => {
106                (b < *a as _).then_some(()).ok_or(Error::NotAssignable())
107            }
108            (Pattern::Literal(Literal::Char(a)), ConValue::Char(b)) => {
109                (b < *a as _).then_some(()).ok_or(Error::NotAssignable())
110            }
111            (Pattern::Literal(Literal::Bool(a)), ConValue::Bool(b)) => {
112                (!b & *a).then_some(()).ok_or(Error::NotAssignable())
113            }
114            (Pattern::Literal(Literal::Float(a)), ConValue::Float(b)) => {
115                (b < *a as _).then_some(()).ok_or(Error::NotAssignable())
116            }
117            (Pattern::Literal(Literal::String(a)), ConValue::String(b)) => {
118                (&*b < a).then_some(()).ok_or(Error::NotAssignable())
119            }
120            _ => Err(Error::NotAssignable()),
121        },
122
123        (Pattern::Name(name), _) if "_".eq(&**name) => Ok(()),
124        (Pattern::Name(name), value) => {
125            sub.insert(name, value);
126            Ok(())
127        }
128
129        (Pattern::Ref(_, pat), ConValue::Ref(r)) => {
130            todo!("Dereference <{r}> in pattern matching {pat}")
131        }
132
133        (Pattern::RangeExc(head, tail), value) => match (head.as_ref(), tail.as_ref(), value) {
134            (
135                Pattern::Literal(Literal::Int(a)),
136                Pattern::Literal(Literal::Int(c)),
137                ConValue::Int(b),
138            ) => (*a as isize <= b as _ && b < *c as isize)
139                .then_some(())
140                .ok_or(Error::NotAssignable()),
141            (
142                Pattern::Literal(Literal::Char(a)),
143                Pattern::Literal(Literal::Char(c)),
144                ConValue::Char(b),
145            ) => (*a <= b && b < *c)
146                .then_some(())
147                .ok_or(Error::NotAssignable()),
148            (
149                Pattern::Literal(Literal::Float(a)),
150                Pattern::Literal(Literal::Float(c)),
151                ConValue::Float(b),
152            ) => (f64::from_bits(*a) <= b && b < f64::from_bits(*c))
153                .then_some(())
154                .ok_or(Error::NotAssignable()),
155            (
156                Pattern::Literal(Literal::String(a)),
157                Pattern::Literal(Literal::String(c)),
158                ConValue::String(b),
159            ) => (a.as_str() <= b.to_ref() && b.to_ref() < c.as_str())
160                .then_some(())
161                .ok_or(Error::NotAssignable()),
162            _ => Err(Error::NotAssignable()),
163        },
164
165        (Pattern::RangeInc(head, tail), value) => match (head.as_ref(), tail.as_ref(), value) {
166            (
167                Pattern::Literal(Literal::Int(a)),
168                Pattern::Literal(Literal::Int(c)),
169                ConValue::Int(b),
170            ) => (*a as isize <= b && b <= *c as isize)
171                .then_some(())
172                .ok_or(Error::NotAssignable()),
173            (
174                Pattern::Literal(Literal::Char(a)),
175                Pattern::Literal(Literal::Char(c)),
176                ConValue::Char(b),
177            ) => (*a <= b && b <= *c)
178                .then_some(())
179                .ok_or(Error::NotAssignable()),
180            (
181                Pattern::Literal(Literal::Float(a)),
182                Pattern::Literal(Literal::Float(c)),
183                ConValue::Float(b),
184            ) => (f64::from_bits(*a) <= b && b <= f64::from_bits(*c))
185                .then_some(())
186                .ok_or(Error::NotAssignable()),
187            (
188                Pattern::Literal(Literal::String(a)),
189                Pattern::Literal(Literal::String(c)),
190                ConValue::String(b),
191            ) => (a.as_str() <= b.to_ref() && b.to_ref() <= c.as_str())
192                .then_some(())
193                .ok_or(Error::NotAssignable()),
194            _ => Err(Error::NotAssignable()),
195        },
196
197        (Pattern::Array(patterns), ConValue::Array(values)) => {
198            match rest_binding(sub, patterns, values.into_vec().into())? {
199                Some((pattern, values)) => {
200                    append_sub(sub, pattern, ConValue::Array(Vec::from(values).into()))
201                }
202                _ => Ok(()),
203            }
204        }
205
206        (Pattern::Tuple(patterns), ConValue::Empty) if patterns.is_empty() => Ok(()),
207        (Pattern::Tuple(patterns), ConValue::Tuple(values)) => {
208            match rest_binding(sub, patterns, values.into_vec().into())? {
209                Some((pattern, values)) => {
210                    append_sub(sub, pattern, ConValue::Tuple(Vec::from(values).into()))
211                }
212                _ => Ok(()),
213            }
214        }
215
216        (Pattern::TupleStruct(path, patterns), ConValue::TupleStruct(parts)) => {
217            let (name, values) = *parts;
218            if !path.ends_with(name) {
219                Err(Error::TypeError())?
220            }
221            match rest_binding(sub, patterns, values.into_vec().into())? {
222                Some((pattern, values)) => {
223                    append_sub(sub, pattern, ConValue::Tuple(Vec::from(values).into()))
224                }
225                _ => Ok(()),
226            }
227        }
228
229        (Pattern::Struct(path, patterns), ConValue::Struct(parts)) => {
230            let (name, mut values) = *parts;
231            if !path.ends_with(&name) {
232                Err(Error::TypeError())?
233            }
234            for (name, pat) in patterns {
235                let value = values.remove(name).ok_or(Error::TypeError())?;
236                match pat {
237                    Some(pat) => append_sub(sub, pat, value)?,
238                    None => {
239                        sub.insert(name, value);
240                    }
241                }
242            }
243            Ok(())
244        }
245
246        _ => {
247            // eprintln!("Could not match pattern `{pat}` with value `{value}`!");
248            Err(Error::NotAssignable())
249        }
250    }
251}
252
253/// Constructs a substitution from a pattern and a value
254pub fn substitution(pat: &Pattern, value: ConValue) -> IResult<HashMap<&Sym, ConValue>> {
255    let mut sub = HashMap::new();
256    append_sub(&mut sub, pat, value)?;
257    Ok(sub)
258}