arithmetic_typing/ast/
conversion.rs

1//! Logic for converting `*Ast` types into their "main" counterparts.
2
3use core::fmt;
4
5use arithmetic_parser::{
6    Error as ParseError, ErrorKind as ParseErrorKind, InputSpan, NomResult, Spanned,
7};
8use nom::Err as NomErr;
9
10use crate::{
11    alloc::{Box, HashMap, HashSet, String, ToOwned},
12    arith::{CompleteConstraints, Constraint, ConstraintSet},
13    ast::{
14        ConstraintsAst, FunctionAst, ObjectAst, SliceAst, SpannedTypeAst, TupleAst, TupleLenAst,
15        TypeAst, TypeConstraintsAst,
16    },
17    error::{Error, Errors},
18    types::{ParamConstraints, ParamQuantifier},
19    DynConstraints, Function, Object, PrimitiveType, Slice, Tuple, Type, TypeEnvironment,
20    UnknownLen,
21};
22
/// Kinds of errors that can occur when converting `*Ast` types into their "main" counterparts.
///
/// During type inference, errors of this type are wrapped into the [`AstConversion`]
/// variant of typing errors.
///
/// [`AstConversion`]: crate::error::ErrorKind::AstConversion
///
/// # Examples
///
/// ```
/// use arithmetic_parser::grammars::{Parse, F32Grammar};
/// use arithmetic_typing::{
///     ast::AstConversionError, error::ErrorKind, Annotated, TypeEnvironment,
/// };
/// # use assert_matches::assert_matches;
///
/// # fn main() -> anyhow::Result<()> {
/// let code = "bogus_slice: ['T; _] = (1, 2, 3);";
/// let ast = Annotated::<F32Grammar>::parse_statements(code)?;
///
/// let errors = TypeEnvironment::new().process_statements(&ast).unwrap_err();
/// let err = errors.into_iter().next().unwrap();
/// assert_eq!(err.main_location().span(code), "'T");
/// assert_matches!(
///     err.kind(),
///     ErrorKind::AstConversion(AstConversionError::FreeTypeVar(id))
///         if id == "T"
/// );
/// # Ok(())
/// # }
/// ```
#[derive(Debug, Clone)]
#[non_exhaustive]
pub enum AstConversionError {
    /// Embedded param quantifiers.
    EmbeddedQuantifier,
    /// Length param not scoped by a function.
    FreeLengthVar(String),
    /// Type param not scoped by a function.
    FreeTypeVar(String),
    /// Unused length param: a length mentioned in function constraints that does not occur
    /// in the function signature.
    UnusedLength(String),
    /// Unused type param: a type param mentioned in function constraints that does not occur
    /// in the function signature.
    UnusedTypeParam(String),
    /// Unknown type name.
    UnknownType(String),
    /// Unknown constraint.
    UnknownConstraint(String),
    /// Some type (`_`) encountered when parsing a standalone type.
    ///
    /// `_` types are only allowed in the context of a [`TypeEnvironment`]. It is a logical
    /// error to use them when parsing standalone types.
    InvalidSomeType,
    /// Some length (`_`) encountered when parsing a standalone type.
    ///
    /// `_` lengths are only allowed in the context of a [`TypeEnvironment`]. It is a logical
    /// error to use them when parsing standalone types.
    InvalidSomeLength,
    /// Field with the same name is defined multiple times in an object type.
    DuplicateField(String),
    /// Constraint is not object-safe.
    NotObjectSafe(String),
}

impl fmt::Display for AstConversionError {
    /// Writes a human-readable description of the error; variants carrying a name embed it
    /// in backticks.
    fn fmt(&self, formatter: &mut fmt::Formatter<'_>) -> fmt::Result {
        match self {
            Self::EmbeddedQuantifier => {
                formatter.write_str("`for` quantifier for a function that is not top-level")
            }

            Self::FreeLengthVar(name) => write!(
                formatter,
                "Length param `{name}` is not scoped by function definition"
            ),
            Self::FreeTypeVar(name) => write!(
                formatter,
                "Type param `{name}` is not scoped by function definition"
            ),

            Self::UnusedLength(name) => write!(formatter, "Unused length param `{name}`"),
            Self::UnusedTypeParam(name) => write!(formatter, "Unused type param `{name}`"),
            Self::UnknownType(name) => write!(formatter, "Unknown type `{name}`"),
            Self::UnknownConstraint(name) => write!(formatter, "Unknown constraint `{name}`"),

            Self::InvalidSomeType => {
                formatter.write_str("`_` type is disallowed when parsing standalone type")
            }
            Self::InvalidSomeLength => {
                formatter.write_str("`_` length is disallowed when parsing standalone type")
            }

            Self::DuplicateField(name) => {
                write!(formatter, "Duplicate field `{name}` in object type")
            }
            Self::NotObjectSafe(name) => {
                write!(formatter, "Constraint `{name}` is not object-safe")
            }
        }
    }
}

// The error has no underlying source, so the default trait methods suffice.
#[cfg(feature = "std")]
impl std::error::Error for AstConversionError {}
140
141/// Intermediate conversion state.
142#[derive(Debug)]
143pub(crate) struct AstConversionState<'r, 'a, Prim: PrimitiveType> {
144    env: Option<&'r mut TypeEnvironment<Prim>>,
145    known_constraints: ConstraintSet<Prim>,
146    errors: &'r mut Errors<Prim>,
147    len_params: HashMap<&'a str, usize>,
148    type_params: HashMap<&'a str, usize>,
149    is_in_function: bool,
150}
151
152impl<'r, 'a, Prim: PrimitiveType> AstConversionState<'r, 'a, Prim> {
153    pub fn new(env: &'r mut TypeEnvironment<Prim>, errors: &'r mut Errors<Prim>) -> Self {
154        let known_constraints = env.known_constraints.clone();
155        Self {
156            env: Some(env),
157            known_constraints,
158            errors,
159            len_params: HashMap::new(),
160            type_params: HashMap::new(),
161            is_in_function: false,
162        }
163    }
164
165    fn without_env(errors: &'r mut Errors<Prim>) -> Self {
166        Self {
167            env: None,
168            known_constraints: Prim::well_known_constraints(),
169            errors,
170            len_params: HashMap::new(),
171            type_params: HashMap::new(),
172            is_in_function: false,
173        }
174    }
175
176    fn type_param_idx(&mut self, param_name: &'a str) -> usize {
177        let type_param_count = self.type_params.len();
178        *self
179            .type_params
180            .entry(param_name)
181            .or_insert(type_param_count)
182    }
183
184    fn len_param_idx(&mut self, param_name: &'a str) -> usize {
185        let len_param_count = self.len_params.len();
186        *self.len_params.entry(param_name).or_insert(len_param_count)
187    }
188
189    fn new_type(&mut self, span: Option<&SpannedTypeAst<'a>>) -> Type<Prim> {
190        let errors = &mut *self.errors;
191        self.env.as_mut().map_or_else(
192            || {
193                if let Some(span) = span {
194                    let err = AstConversionError::InvalidSomeType;
195                    errors.push(Error::conversion(err, span));
196                }
197                // We don't particularly care about the returned value; the enclosing type
198                // will be discarded anyway.
199                Type::free_var(0)
200            },
201            |env| env.substitutions.new_type_var(),
202        )
203    }
204
205    fn new_len(&mut self, span: Option<&Spanned<'a, TupleLenAst>>) -> UnknownLen {
206        let errors = &mut *self.errors;
207        self.env.as_mut().map_or_else(
208            || {
209                if let Some(span) = span {
210                    let err = AstConversionError::InvalidSomeLength;
211                    errors.push(Error::conversion(err, span));
212                }
213                // We don't particularly care about the returned value; the enclosing type
214                // will be discarded anyway.
215                UnknownLen::free_var(0)
216            },
217            |env| env.substitutions.new_len_var(),
218        )
219    }
220
221    fn resolve_constraint(&self, name: &str) -> Option<(Box<dyn Constraint<Prim>>, bool)> {
222        self.known_constraints
223            .get_by_name(name)
224            .map(|(constraint, is_object_safe)| (constraint.clone_boxed(), is_object_safe))
225    }
226
227    pub(crate) fn convert_type(&mut self, ty: &SpannedTypeAst<'a>) -> Type<Prim> {
228        match &ty.extra {
229            TypeAst::Some => self.new_type(Some(ty)),
230            TypeAst::Any => Type::Any,
231            TypeAst::Dyn(constraints) => Type::Dyn(constraints.convert_dyn(self)),
232            TypeAst::Ident => {
233                let ident = *ty.fragment();
234                if let Ok(prim_type) = Prim::from_str(ident) {
235                    Type::Prim(prim_type)
236                } else {
237                    let err = AstConversionError::UnknownType(ident.to_owned());
238                    self.errors.push(Error::conversion(err, ty));
239                    self.new_type(None)
240                }
241            }
242
243            TypeAst::Param => {
244                let name = &ty.fragment()[1..];
245                if self.is_in_function {
246                    let idx = self.type_param_idx(name);
247                    Type::param(idx)
248                } else {
249                    let err = AstConversionError::FreeTypeVar(name.to_owned());
250                    self.errors.push(Error::conversion(err, ty));
251                    self.new_type(None)
252                }
253            }
254
255            TypeAst::Function(function) => self.convert_fn(function, None),
256            TypeAst::FunctionWithConstraints {
257                function,
258                constraints,
259            } => self.convert_fn(&function.extra, Some(constraints)),
260
261            TypeAst::Tuple(tuple) => tuple.convert(self).into(),
262            TypeAst::Slice(slice) => slice.convert(self).into(),
263            TypeAst::Object(object) => object.convert(self).into(),
264        }
265    }
266
267    fn convert_fn(
268        &mut self,
269        function: &FunctionAst<'a>,
270        constraints: Option<&Spanned<'a, ConstraintsAst<'a>>>,
271    ) -> Type<Prim> {
272        if self.is_in_function {
273            if let Some(constraints) = constraints {
274                let err = AstConversionError::EmbeddedQuantifier;
275                self.errors.push(Error::conversion(err, constraints));
276            }
277            function.convert(self).into()
278        } else {
279            self.is_in_function = true;
280            let mut converted_fn = function.convert(self);
281            let constraints =
282                constraints.map_or_else(ParamConstraints::default, |c| c.extra.convert(self));
283            ParamQuantifier::fill_params(&mut converted_fn, constraints);
284
285            self.is_in_function = false;
286            self.type_params.clear();
287            self.len_params.clear();
288            converted_fn.into()
289        }
290    }
291}
292
293impl<'a> TypeConstraintsAst<'a> {
294    fn convert<Prim: PrimitiveType>(
295        &self,
296        state: &mut AstConversionState<'_, 'a, Prim>,
297    ) -> CompleteConstraints<Prim> {
298        self.do_convert(state, false)
299    }
300
301    fn convert_dyn<Prim: PrimitiveType>(
302        &self,
303        state: &mut AstConversionState<'_, 'a, Prim>,
304    ) -> DynConstraints<Prim> {
305        DynConstraints {
306            inner: self.do_convert(state, true),
307        }
308    }
309
310    fn do_convert<Prim: PrimitiveType>(
311        &self,
312        state: &mut AstConversionState<'_, 'a, Prim>,
313        require_object_safety: bool,
314    ) -> CompleteConstraints<Prim> {
315        let mut constraints = CompleteConstraints::default();
316        if let Some(object) = &self.object {
317            constraints.object = Some(object.convert(state));
318        }
319
320        self.terms.iter().fold(constraints, |mut acc, input| {
321            let input_str = *input.fragment();
322            if let Some((constraint, is_object_safe)) = state.resolve_constraint(input_str) {
323                if require_object_safety && !is_object_safe {
324                    let err = AstConversionError::NotObjectSafe(input_str.to_owned());
325                    state.errors.push(Error::conversion(err, input));
326                } else {
327                    acc.simple.insert_boxed(constraint);
328                }
329            } else {
330                let err = AstConversionError::UnknownConstraint(input_str.to_owned());
331                state.errors.push(Error::conversion(err, input));
332            }
333            acc
334        })
335    }
336}
337
338impl<'a> ConstraintsAst<'a> {
339    fn convert<Prim: PrimitiveType>(
340        &self,
341        state: &mut AstConversionState<'_, 'a, Prim>,
342    ) -> ParamConstraints<Prim> {
343        let mut static_lengths = HashSet::with_capacity(self.static_lengths.len());
344        for dyn_length in &self.static_lengths {
345            let name = *dyn_length.fragment();
346            if let Some(index) = state.len_params.get(name) {
347                static_lengths.insert(*index);
348            } else {
349                let err = AstConversionError::UnusedLength(name.to_owned());
350                state.errors.push(Error::conversion(err, dyn_length));
351            }
352        }
353
354        let mut type_params = HashMap::with_capacity(self.type_params.len());
355        for (param, constraints) in &self.type_params {
356            let name = *param.fragment();
357            if let Some(index) = state.type_params.get(name) {
358                type_params.insert(*index, constraints.convert(state));
359            } else {
360                let err = AstConversionError::UnusedTypeParam(name.to_owned());
361                state.errors.push(Error::conversion(err, param));
362            }
363        }
364
365        ParamConstraints {
366            type_params,
367            static_lengths,
368        }
369    }
370}
371
372impl<'a> TupleAst<'a> {
373    fn convert<Prim: PrimitiveType>(
374        &self,
375        state: &mut AstConversionState<'_, 'a, Prim>,
376    ) -> Tuple<Prim> {
377        let start = self
378            .start
379            .iter()
380            .map(|element| state.convert_type(element))
381            .collect();
382        let middle = self
383            .middle
384            .as_ref()
385            .map(|middle| middle.extra.convert(state));
386        let end = self
387            .end
388            .iter()
389            .map(|element| state.convert_type(element))
390            .collect();
391        Tuple::from_parts(start, middle, end)
392    }
393}
394
395impl<'a> SliceAst<'a> {
396    fn convert<Prim: PrimitiveType>(
397        &self,
398        state: &mut AstConversionState<'_, 'a, Prim>,
399    ) -> Slice<Prim> {
400        let element = state.convert_type(&self.element);
401
402        let converted_length = match &self.length.extra {
403            TupleLenAst::Ident => {
404                let name = *self.length.fragment();
405                if state.is_in_function {
406                    let const_param = state.len_param_idx(name);
407                    UnknownLen::param(const_param)
408                } else {
409                    let err = AstConversionError::FreeLengthVar(name.to_owned());
410                    state.errors.push(Error::conversion(err, &self.length));
411                    state.new_len(None)
412                }
413            }
414            TupleLenAst::Some => state.new_len(Some(&self.length)),
415            TupleLenAst::Dynamic => UnknownLen::Dynamic,
416        };
417
418        Slice::new(element, converted_length)
419    }
420}
421
422impl<'a> ObjectAst<'a> {
423    fn convert<Prim: PrimitiveType>(
424        &self,
425        state: &mut AstConversionState<'_, 'a, Prim>,
426    ) -> Object<Prim> {
427        let mut fields = HashMap::new();
428        for (field_name, ty) in &self.fields {
429            let field_name_str = *field_name.fragment();
430            if fields.contains_key(field_name_str) {
431                let err = AstConversionError::DuplicateField(field_name_str.to_owned());
432                state.errors.push(Error::conversion(err, field_name));
433            } else {
434                fields.insert(field_name_str.to_owned(), state.convert_type(ty));
435            }
436        }
437        Object::from_map(fields)
438    }
439}
440
441impl<'a> FunctionAst<'a> {
442    fn convert<Prim: PrimitiveType>(
443        &self,
444        state: &mut AstConversionState<'_, 'a, Prim>,
445    ) -> Function<Prim> {
446        let args = self.args.extra.convert(state);
447        let return_type = state.convert_type(&self.return_type);
448        Function::new(args, return_type)
449    }
450
451    /// Tries to convert this type into a [`Function`].
452    pub fn try_convert<Prim>(&self) -> Result<Function<Prim>, Errors<Prim>>
453    where
454        Prim: PrimitiveType,
455    {
456        let mut errors = Errors::new();
457        let mut state = AstConversionState::without_env(&mut errors);
458        state.is_in_function = true;
459
460        let output = self.convert(&mut state);
461        if errors.is_empty() {
462            Ok(output)
463        } else {
464            Err(errors)
465        }
466    }
467}
468
469/// Shared parsing code for `TypeAst` and `FunctionAst`.
470fn parse_inner<'a, Ast>(
471    parser: fn(InputSpan<'a>) -> NomResult<'a, Ast>,
472    input: InputSpan<'a>,
473) -> NomResult<'a, Ast> {
474    let (rest, ast) = parser(input)?;
475    if !rest.fragment().is_empty() {
476        let err = ParseErrorKind::Leftovers.with_span(&rest.into());
477        return Err(NomErr::Failure(err));
478    }
479    Ok((rest, ast))
480}
481
482/// Shared `TryFrom<&str>` logic for `TypeAst` and `FunctionAst`.
483fn from_str<'a, Ast>(
484    parser: fn(InputSpan<'a>) -> NomResult<'a, Ast>,
485    def: &'a str,
486) -> Result<Ast, ParseError> {
487    let input = InputSpan::new(def);
488    let (_, ast) = parse_inner(parser, input).map_err(|err| match err {
489        NomErr::Incomplete(_) => ParseErrorKind::Incomplete.with_span(&input.into()),
490        NomErr::Error(e) | NomErr::Failure(e) => e,
491    })?;
492    Ok(ast)
493}
494
495impl<'a> TypeAst<'a> {
496    /// Parses type AST from a string.
497    pub fn try_from(def: &'a str) -> Result<SpannedTypeAst<'a>, ParseError> {
498        from_str(TypeAst::parse, def)
499    }
500}
501
502impl<'a, Prim: PrimitiveType> TryFrom<&SpannedTypeAst<'a>> for Type<Prim> {
503    type Error = Errors<Prim>;
504
505    fn try_from(ast: &SpannedTypeAst<'a>) -> Result<Self, Self::Error> {
506        let mut errors = Errors::new();
507        let mut state = AstConversionState::without_env(&mut errors);
508
509        let output = state.convert_type(ast);
510        if errors.is_empty() {
511            Ok(output)
512        } else {
513            Err(errors)
514        }
515    }
516}
517
518impl<'a> TryFrom<&'a str> for FunctionAst<'a> {
519    type Error = ParseError;
520
521    fn try_from(def: &'a str) -> Result<Self, Self::Error> {
522        from_str(FunctionAst::parse, def)
523    }
524}
525
#[cfg(test)]
mod tests {
    use assert_matches::assert_matches;

    use super::*;
    use crate::{
        alloc::{vec, ToString},
        arith::Num,
    };

    /// Parses `def` as a standalone type and converts it into a `Type`.
    fn parse_type(def: &str) -> anyhow::Result<Type> {
        let ast = TypeAst::try_from(def)?;
        Ok(<Type>::try_from(&ast)?)
    }

    #[test]
    fn converting_raw_fn_type() {
        const DEFINITION: &str = "(['T; N], ('T) -> Bool) -> Bool";

        let (_, ast) = FunctionAst::parse(InputSpan::new(DEFINITION)).unwrap();
        let converted = ast.try_convert::<Num>().unwrap();

        // The conversion must round-trip through `Display`.
        assert_eq!(converted.to_string(), DEFINITION);
    }

    #[test]
    fn converting_fn_type_with_constraint() {
        const DEFINITION: &str = "for<'T: Lin> (['T; N], ('T) -> Bool) -> Bool";

        let (_, ast) = TypeAst::parse(InputSpan::new(DEFINITION)).unwrap();
        let converted = <Type>::try_from(&ast).unwrap();

        assert_eq!(converted.to_string(), DEFINITION);
    }

    #[test]
    fn parsing_basic_types() -> anyhow::Result<()> {
        assert_eq!(parse_type("Num")?, Type::NUM);
        assert_eq!(parse_type("Bool")?, Type::BOOL);

        let expected_tuple = Type::from((Type::NUM, Type::Tuple(vec![Type::BOOL; 2].into())));
        assert_eq!(parse_type("(Num, (Bool, Bool))")?, expected_tuple);

        let slice_type = parse_type("[(Num, Bool)]")?;
        let Type::Tuple(tuple) = &slice_type else {
            panic!("Unexpected type: {slice_type:?}");
        };
        let slice = tuple.as_slice().unwrap();

        assert_eq!(*slice.element(), Type::from((Type::NUM, Type::BOOL)));
        assert_matches!(slice.len().components(), (Some(UnknownLen::Dynamic), 0));
        Ok(())
    }

    #[test]
    fn parsing_functional_type() -> anyhow::Result<()> {
        let function = match parse_type("(['T; N], ('T) -> 'U) -> 'U")? {
            Type::Function(function) => function,
            ty => panic!("Unexpected type: {ty:?}"),
        };

        let params = function.params.as_ref().unwrap();
        assert_eq!(params.len_params.len(), 1);
        assert_eq!(params.type_params.len(), 2);
        assert_eq!(function.return_type, Type::param(1));
        Ok(())
    }

    #[test]
    fn parsing_functional_type_with_varargs() -> anyhow::Result<()> {
        let function = match parse_type("(...[Num; N]) -> Num")? {
            Type::Function(function) => function,
            ty => panic!("Unexpected type: {ty:?}"),
        };

        let params = function.params.as_ref().unwrap();
        assert_eq!(params.len_params.len(), 1);
        assert!(params.type_params.is_empty());

        let args_slice = function.args.as_slice().unwrap();
        assert_eq!(*args_slice.element(), Type::NUM);
        assert_eq!(args_slice.len(), UnknownLen::param(0).into());
        Ok(())
    }

    #[test]
    fn parsing_incomplete_type() {
        const INCOMPLETE_TYPES: &[&str] = &[
            "fn(",
            "fn(['T; ",
            "fn(['T; N], fn(",
            "fn(['T; N], fn('T)",
            "fn(['T; N], fn('T)) -",
            "fn(['T; N], fn('T)) ->",
        ];

        for input in INCOMPLETE_TYPES.iter().copied() {
            // TODO: some of reported errors are difficult to interpret; should clarify.
            TypeAst::try_from(input).unwrap_err();
        }
    }

    #[test]
    fn parsing_type_with_object_constraint() -> anyhow::Result<()> {
        let type_def = "for<'T: { x: Num } + Lin> ('T) -> Bool";
        let function = match parse_type(type_def)? {
            Type::Function(function) => function,
            ty => panic!("Unexpected type: {ty:?}"),
        };

        let type_params = &function.params.as_ref().unwrap().type_params;
        assert_eq!(type_params.len(), 1);
        let (_, constraints) = &type_params[0];
        assert!(constraints.object.is_some());
        assert!(constraints.simple.get_by_name("Lin").is_some());

        assert_eq!(function.to_string(), type_def);
        Ok(())
    }
}