arithmetic_typing/env/
mod.rs

1//! `TypeEnvironment` and related types.
2
3use core::ops;
4
5use arithmetic_parser::{grammars::Grammar, Block};
6
7use self::processor::TypeProcessor;
8use crate::{
9    alloc::{HashMap, String, ToOwned},
10    arith::{
11        Constraint, ConstraintSet, MapPrimitiveType, Num, NumArithmetic, ObjectSafeConstraint,
12        Substitutions, TypeArithmetic,
13    },
14    ast::TypeAst,
15    error::Errors,
16    types::{ParamConstraints, ParamQuantifier},
17    visit::VisitMut,
18    Function, PrimitiveType, Type,
19};
20
21mod processor;
22
23/// Environment containing type information on named variables.
24///
25/// # Examples
26///
27/// See [the crate docs](index.html#examples) for examples of usage.
28///
29/// # Concrete and partially specified types
30///
31/// The environment retains full info on the types even if the type is not
32/// [concrete](Type::is_concrete()). Non-concrete types are tied to an environment.
33/// An environment will panic on inserting a non-concrete type via [`Self::insert()`]
34/// or other methods.
35///
36/// ```
37/// # use arithmetic_parser::grammars::{F32Grammar, Parse};
38/// # use arithmetic_typing::{defs::Prelude, Annotated, TypeEnvironment};
39/// # type Parser = Annotated<F32Grammar>;
40/// # fn main() -> anyhow::Result<()> {
41/// // An easy way to get a non-concrete type is to involve `any`.
42/// let code = "(x, ...) = (1, 2, 3) as any;";
43/// let code = Parser::parse_statements(code)?;
44///
45/// let mut env: TypeEnvironment = Prelude::iter().collect();
46/// env.process_statements(&code)?;
47/// assert!(!env["x"].is_concrete());
48/// # Ok(())
49/// # }
50/// ```
51#[derive(Debug, Clone)]
52pub struct TypeEnvironment<Prim: PrimitiveType = Num> {
53    pub(crate) substitutions: Substitutions<Prim>,
54    pub(crate) known_constraints: ConstraintSet<Prim>,
55    variables: HashMap<String, Type<Prim>>,
56}
57
58impl<Prim: PrimitiveType> Default for TypeEnvironment<Prim> {
59    fn default() -> Self {
60        Self {
61            variables: HashMap::new(),
62            known_constraints: Prim::well_known_constraints(),
63            substitutions: Substitutions::default(),
64        }
65    }
66}
67
68impl<Prim: PrimitiveType> TypeEnvironment<Prim> {
69    /// Creates an empty environment.
70    pub fn new() -> Self {
71        Self::default()
72    }
73
74    /// Gets type of the specified variable.
75    pub fn get(&self, name: &str) -> Option<&Type<Prim>> {
76        self.variables.get(name)
77    }
78
79    /// Iterates over variables contained in this env.
80    pub fn iter(&self) -> impl Iterator<Item = (&str, &Type<Prim>)> + '_ {
81        self.variables.iter().map(|(name, ty)| (name.as_str(), ty))
82    }
83
84    fn prepare_type(ty: impl Into<Type<Prim>>) -> Type<Prim> {
85        let mut ty = ty.into();
86        assert!(ty.is_concrete(), "Type {ty} is not concrete");
87        TypePreparer.visit_type_mut(&mut ty);
88        ty
89    }
90
91    /// Sets type of a variable.
92    ///
93    /// # Panics
94    ///
95    /// - Will panic if `ty` is not [concrete](Type::is_concrete()). Non-concrete
96    ///   types are tied to the environment; inserting them into an env is a logical error.
97    pub fn insert(&mut self, name: &str, ty: impl Into<Type<Prim>>) -> &mut Self {
98        self.variables
99            .insert(name.to_owned(), Self::prepare_type(ty));
100        self
101    }
102
103    /// Inserts a [`Constraint`] into the environment so that it can be used when parsing
104    /// type annotations.
105    ///
106    /// Adding a constraint is not mandatory for it to be usable during type inference;
107    /// this method only influences whether the constraint is recognized during type parsing.
108    pub fn insert_constraint(&mut self, constraint: impl Constraint<Prim>) -> &mut Self {
109        self.known_constraints.insert(constraint);
110        self
111    }
112
113    /// Inserts an [`ObjectSafeConstraint`] into the environment so that it can be used
114    /// when parsing type annotations.
115    ///
116    /// Other than more strict type requirements, this method is identical to
117    /// [`Self::insert_constraint`].
118    pub fn insert_object_safe_constraint(
119        &mut self,
120        constraint: impl ObjectSafeConstraint<Prim>,
121    ) -> &mut Self {
122        self.known_constraints.insert_object_safe(constraint);
123        self
124    }
125
126    /// Processes statements with the default type arithmetic. After processing, the environment
127    /// will contain type info about newly declared vars.
128    ///
129    /// This method is a shortcut for calling `process_with_arithmetic` with
130    /// [`NumArithmetic::without_comparisons()`].
131    pub fn process_statements<'a, T>(
132        &mut self,
133        block: &Block<'a, T>,
134    ) -> Result<Type<Prim>, Errors<Prim>>
135    where
136        T: Grammar<Type<'a> = TypeAst<'a>>,
137        NumArithmetic: MapPrimitiveType<T::Lit, Prim = Prim> + TypeArithmetic<Prim>,
138    {
139        self.process_with_arithmetic(&NumArithmetic::without_comparisons(), block)
140    }
141
142    /// Processes statements with a given `arithmetic`. After processing, the environment
143    /// will contain type info about newly declared vars.
144    ///
145    /// # Errors
146    ///
147    /// Even if there are any type errors, all statements in the `block` will be executed
148    /// to completion and all errors will be reported. However, the environment will **not**
149    /// include any vars beyond the first failing statement.
150    pub fn process_with_arithmetic<'a, T, A>(
151        &mut self,
152        arithmetic: &A,
153        block: &Block<'a, T>,
154    ) -> Result<Type<Prim>, Errors<Prim>>
155    where
156        T: Grammar<Type<'a> = TypeAst<'a>>,
157        A: MapPrimitiveType<T::Lit, Prim = Prim> + TypeArithmetic<Prim>,
158    {
159        TypeProcessor::new(self, arithmetic).process_statements(block)
160    }
161}
162
163impl<Prim: PrimitiveType> ops::Index<&str> for TypeEnvironment<Prim> {
164    type Output = Type<Prim>;
165
166    fn index(&self, name: &str) -> &Self::Output {
167        self.get(name)
168            .unwrap_or_else(|| panic!("Variable `{name}` is not defined"))
169    }
170}
171
172/// Fills in parameters in all encountered top-level functions within a type.
173#[derive(Debug)]
174struct TypePreparer;
175
176impl<Prim: PrimitiveType> VisitMut<Prim> for TypePreparer {
177    fn visit_function_mut(&mut self, function: &mut Function<Prim>) {
178        if function.params.is_none() {
179            ParamQuantifier::fill_params(function, ParamConstraints::default());
180        }
181        // We intentionally do not recurse into functions; this is done within `ParamQuantifier`.
182    }
183}
184
185fn convert_iter<Prim: PrimitiveType, S, Ty, I>(
186    iter: I,
187) -> impl Iterator<Item = (String, Type<Prim>)>
188where
189    I: IntoIterator<Item = (S, Ty)>,
190    S: Into<String>,
191    Ty: Into<Type<Prim>>,
192{
193    iter.into_iter()
194        .map(|(name, ty)| (name.into(), TypeEnvironment::prepare_type(ty)))
195}
196
197impl<Prim: PrimitiveType, S, Ty> FromIterator<(S, Ty)> for TypeEnvironment<Prim>
198where
199    S: Into<String>,
200    Ty: Into<Type<Prim>>,
201{
202    fn from_iter<I: IntoIterator<Item = (S, Ty)>>(iter: I) -> Self {
203        Self {
204            variables: convert_iter(iter).collect(),
205            known_constraints: Prim::well_known_constraints(),
206            substitutions: Substitutions::default(),
207        }
208    }
209}
210
211impl<Prim: PrimitiveType, S, Ty> Extend<(S, Ty)> for TypeEnvironment<Prim>
212where
213    S: Into<String>,
214    Ty: Into<Type<Prim>>,
215{
216    fn extend<I: IntoIterator<Item = (S, Ty)>>(&mut self, iter: I) {
217        self.variables.extend(convert_iter(iter));
218    }
219}
220
221// Helper trait to wrap type mapper and arithmetic.
222trait FullArithmetic<Val, Prim: PrimitiveType>:
223    MapPrimitiveType<Val, Prim = Prim> + TypeArithmetic<Prim>
224{
225}
226
227impl<Val, Prim: PrimitiveType, T> FullArithmetic<Val, Prim> for T where
228    T: MapPrimitiveType<Val, Prim = Prim> + TypeArithmetic<Prim>
229{
230}