arithmetic_typing/types/object.rs

1//! Object types.
2
3use core::{fmt, ops};
4
5use crate::{
6    alloc::{HashMap, HashSet, String, ToOwned, Vec},
7    arith::Substitutions,
8    error::{ErrorKind, OpErrors},
9    DynConstraints, PrimitiveType, Type,
10};
11
12/// Object type: a collection of named fields with heterogeneous types.
13///
14/// # Notation
15///
16/// Object types are denoted using a brace notation such as `{ x: Num, y: [(Num, 'T)] }`.
17/// Here, `x` and `y` are field names, and `Num` / `[(Num, 'T)]` are types of the corresponding
18/// object fields.
19///
20/// # As constraint
21///
22/// Object types are *exact*; their extensions cannot be unified with the original types.
23/// For example, if a function argument is `{ x: Num, y: Num }`,
24/// the function cannot be called with an arg of type `{ x: Num, y: Num, z: Num }`:
25///
26/// ```
27/// # use arithmetic_parser::grammars::{Parse, F32Grammar};
28/// # use arithmetic_typing::{error::ErrorKind, Annotated, TypeEnvironment};
29/// # use assert_matches::assert_matches;
30/// # fn main() -> anyhow::Result<()> {
31/// let code = "
32///     sum_coords = |pt: { x: Num, y: Num }| pt.x + pt.y;
33///     sum_coords(#{ x: 3, y: 4 }); // OK
34///     sum_coords(#{ x: 3, y: 4, z: 5 }); // fails
35/// ";
36/// let ast = Annotated::<F32Grammar>::parse_statements(code)?;
37/// let err = TypeEnvironment::new().process_statements(&ast).unwrap_err();
38/// # assert_eq!(err.len(), 1);
39/// let err = err.iter().next().unwrap();
40/// assert_matches!(err.kind(), ErrorKind::FieldsMismatch { .. });
41/// # Ok(())
42/// # }
43/// ```
44///
45/// To bridge this gap, objects can be used as a constraint on types, similarly to [`Constraint`]s.
46/// As a constraint, an object specifies *necessary* fields, which can be arbitrarily extended.
47///
48/// The type inference algorithm uses object constraints, not concrete object types whenever
49/// possible:
50///
51/// ```
52/// # use arithmetic_parser::grammars::{Parse, F32Grammar};
53/// # use arithmetic_typing::{error::ErrorKind, Annotated, TypeEnvironment};
54/// # use assert_matches::assert_matches;
55/// # fn main() -> anyhow::Result<()> {
56/// let code = "
57///     sum_coords = |pt| pt.x + pt.y;
58///     sum_coords(#{ x: 3, y: 4 }); // OK
59///     sum_coords(#{ x: 3, y: 4, z: 5 }); // also OK
60/// ";
61/// let ast = Annotated::<F32Grammar>::parse_statements(code)?;
62/// let mut env = TypeEnvironment::new();
63/// env.process_statements(&ast)?;
64/// assert_eq!(
65///     env["sum_coords"].to_string(),
66///     "for<'T: { x: 'U, y: 'U }, 'U: Ops> ('T) -> 'U"
67/// );
68/// # Ok(())
69/// # }
70/// ```
71///
72/// Note that the object constraint in this case refers to another type param, which is
73/// constrained on its own!
74///
75/// [`Constraint`]: crate::arith::Constraint
76#[derive(Debug, Clone, PartialEq)]
77pub struct Object<Prim: PrimitiveType> {
78    fields: HashMap<String, Type<Prim>>,
79}
80
81impl<Prim: PrimitiveType> Default for Object<Prim> {
82    fn default() -> Self {
83        Self {
84            fields: HashMap::new(),
85        }
86    }
87}
88
89impl<Prim, S, V> FromIterator<(S, V)> for Object<Prim>
90where
91    Prim: PrimitiveType,
92    S: Into<String>,
93    V: Into<Type<Prim>>,
94{
95    fn from_iter<T: IntoIterator<Item = (S, V)>>(iter: T) -> Self {
96        Self {
97            fields: iter
98                .into_iter()
99                .map(|(name, ty)| (name.into(), ty.into()))
100                .collect(),
101        }
102    }
103}
104
105impl<Prim, S, V, const N: usize> From<[(S, V); N]> for Object<Prim>
106where
107    Prim: PrimitiveType,
108    S: Into<String>,
109    V: Into<Type<Prim>>,
110{
111    fn from(entries: [(S, V); N]) -> Self {
112        Self {
113            fields: entries
114                .into_iter()
115                .map(|(name, ty)| (name.into(), ty.into()))
116                .collect(),
117        }
118    }
119}
120
121impl<Prim: PrimitiveType> fmt::Display for Object<Prim> {
122    fn fmt(&self, formatter: &mut fmt::Formatter<'_>) -> fmt::Result {
123        let mut sorted_fields: Vec<_> = self.fields.iter().collect();
124        sorted_fields.sort_unstable_by_key(|(name, _)| *name);
125
126        formatter.write_str("{")?;
127        for (i, (name, ty)) in sorted_fields.into_iter().enumerate() {
128            write!(formatter, " {name}: {ty}")?;
129            if i + 1 < self.fields.len() {
130                formatter.write_str(",")?;
131            }
132        }
133        formatter.write_str(" }")
134    }
135}
136
137impl<Prim: PrimitiveType> Object<Prim> {
138    /// Creates an empty object.
139    pub fn new() -> Self {
140        Self::default()
141    }
142
143    pub(crate) fn from_map(fields: HashMap<String, Type<Prim>>) -> Self {
144        Self { fields }
145    }
146
147    /// Returns type of a field with the specified `name`.
148    pub fn get(&self, name: &str) -> Option<&Type<Prim>> {
149        self.fields.get(name)
150    }
151
152    /// Iterates over fields in this object.
153    pub fn iter(&self) -> impl Iterator<Item = (&str, &Type<Prim>)> + '_ {
154        self.fields.iter().map(|(name, ty)| (name.as_str(), ty))
155    }
156
157    /// Iterates over field names in this object.
158    pub fn field_names(&self) -> impl Iterator<Item = &str> + '_ {
159        self.fields.keys().map(String::as_str)
160    }
161
162    /// Converts this object into a corresponding dynamic constraint.
163    pub fn into_dyn(self) -> Type<Prim> {
164        Type::Dyn(DynConstraints::from(self))
165    }
166
167    pub(crate) fn iter_mut(&mut self) -> impl Iterator<Item = (&str, &mut Type<Prim>)> + '_ {
168        self.fields.iter_mut().map(|(name, ty)| (name.as_str(), ty))
169    }
170
171    pub(crate) fn is_concrete(&self) -> bool {
172        self.fields.values().all(Type::is_concrete)
173    }
174
175    pub(crate) fn extend_from(
176        &mut self,
177        other: Self,
178        substitutions: &mut Substitutions<Prim>,
179        mut errors: OpErrors<'_, Prim>,
180    ) {
181        for (field_name, ty) in other.fields {
182            if let Some(this_field) = self.fields.get(&field_name) {
183                substitutions.unify(this_field, &ty, errors.join_path(field_name.as_str()));
184            } else {
185                self.fields.insert(field_name, ty);
186            }
187        }
188    }
189
190    pub(crate) fn apply_as_constraint(
191        &self,
192        ty: &Type<Prim>,
193        substitutions: &mut Substitutions<Prim>,
194        mut errors: OpErrors<'_, Prim>,
195    ) {
196        let resolved_ty = if let Type::Var(var) = ty {
197            debug_assert!(var.is_free());
198            substitutions.insert_obj_constraint(var.index(), self, errors.by_ref());
199            substitutions.fast_resolve(ty)
200        } else {
201            ty
202        };
203
204        match resolved_ty {
205            Type::Object(rhs) => {
206                self.constraint_object(&rhs.clone(), substitutions, errors);
207            }
208            Type::Dyn(constraints) => {
209                if let Some(object) = constraints.inner.object.clone() {
210                    self.constraint_object(&object, substitutions, errors);
211                } else {
212                    errors.push(ErrorKind::CannotAccessFields);
213                }
214            }
215            Type::Any | Type::Var(_) => { /* OK */ }
216            _ => errors.push(ErrorKind::CannotAccessFields),
217        }
218    }
219
220    /// Places an object constraint encoded in `lhs` on a (concrete) object in `rhs`.
221    fn constraint_object(
222        &self,
223        rhs: &Object<Prim>,
224        substitutions: &mut Substitutions<Prim>,
225        mut errors: OpErrors<'_, Prim>,
226    ) {
227        let mut missing_fields = HashSet::new();
228        for (field_name, lhs_ty) in self.iter() {
229            if let Some(rhs_ty) = rhs.get(field_name) {
230                substitutions.unify(lhs_ty, rhs_ty, errors.join_path(field_name));
231            } else {
232                missing_fields.insert(field_name.to_owned());
233            }
234        }
235
236        if !missing_fields.is_empty() {
237            errors.push(ErrorKind::MissingFields {
238                fields: missing_fields,
239                available_fields: rhs.field_names().map(String::from).collect(),
240            });
241        }
242    }
243}
244
245impl<Prim: PrimitiveType> ops::Index<&str> for Object<Prim> {
246    type Output = Type<Prim>;
247
248    fn index(&self, field_name: &str) -> &Self::Output {
249        self.get(field_name).unwrap_or_else(|| {
250            panic!("Object type does not contain field `{field_name}`");
251        })
252    }
253}
254
#[cfg(test)]
mod tests {
    use assert_matches::assert_matches;

    use super::*;
    use crate::arith::Num;

    /// Asserts that exactly one error was collected and returns it.
    fn get_err(errors: OpErrors<'_, Num>) -> ErrorKind<Num> {
        let mut errors = errors.into_vec();
        assert_eq!(errors.len(), 1, "{errors:?}");
        errors.pop().unwrap()
    }

    #[test]
    fn placing_obj_constraint() {
        let lhs: Object<Num> = Object::from([("x", Type::NUM)]);
        let mut substitutions = Substitutions::default();
        let mut errors = OpErrors::new();
        lhs.constraint_object(&lhs, &mut substitutions, errors.by_ref());
        assert!(errors.into_vec().is_empty());

        let var_rhs = Object::from([("x", Type::free_var(0))]);
        let mut errors = OpErrors::new();
        lhs.constraint_object(&var_rhs, &mut substitutions, errors.by_ref());
        assert!(errors.into_vec().is_empty());
        assert_eq!(*substitutions.fast_resolve(&Type::free_var(0)), Type::NUM);

        // Extra fields in RHS are fine.
        let extra_rhs = Object::from([("x", Type::free_var(1)), ("y", Type::BOOL)]);
        let mut errors = OpErrors::new();
        lhs.constraint_object(&extra_rhs, &mut substitutions, errors.by_ref());
        assert!(errors.into_vec().is_empty());
        assert_eq!(*substitutions.fast_resolve(&Type::free_var(1)), Type::NUM);

        let missing_field_rhs = Object::from([("y", Type::free_var(2))]);
        let mut errors = OpErrors::new();
        lhs.constraint_object(&missing_field_rhs, &mut substitutions, errors.by_ref());
        assert_matches!(
            get_err(errors),
            ErrorKind::MissingFields { fields, available_fields }
                if fields.len() == 1 && fields.contains("x") &&
                available_fields.len() == 1 && available_fields.contains("y")
        );

        let incompatible_field_rhs = Object::from([("x", Type::BOOL)]);
        let mut errors = OpErrors::new();
        lhs.constraint_object(&incompatible_field_rhs, &mut substitutions, errors.by_ref());
        assert_matches!(
            get_err(errors),
            ErrorKind::TypeMismatch(lhs, rhs) if *lhs == Type::NUM && *rhs == Type::BOOL
        );
    }

    #[test]
    fn extending_object() {
        let mut lhs: Object<Num> = Object::from([("x", Type::NUM)]);
        let rhs = Object::from([("x", Type::free_var(0)), ("y", Type::BOOL)]);
        let mut substitutions = Substitutions::default();
        let mut errors = OpErrors::new();
        lhs.extend_from(rhs, &mut substitutions, errors.by_ref());
        assert!(errors.into_vec().is_empty());

        // Shared fields keep the original type and get unified; new fields are added.
        assert_eq!(lhs.get("x"), Some(&Type::NUM));
        assert_eq!(lhs.get("y"), Some(&Type::BOOL));
        assert_eq!(*substitutions.fast_resolve(&Type::free_var(0)), Type::NUM);
    }

    #[test]
    fn object_display() {
        assert_eq!(Object::<Num>::new().to_string(), "{ }");

        // Fields are displayed sorted by name regardless of insertion order.
        let obj: Object<Num> = Object::from([("y", Type::NUM), ("x", Type::NUM)]);
        assert_eq!(obj.to_string(), "{ x: Num, y: Num }");
    }
}