1use core::{fmt, ops};
4
5use crate::{
6 alloc::{HashMap, HashSet, String, ToOwned, Vec},
7 arith::Substitutions,
8 error::{ErrorKind, OpErrors},
9 DynConstraints, PrimitiveType, Type,
10};
11
12#[derive(Debug, Clone, PartialEq)]
77pub struct Object<Prim: PrimitiveType> {
78 fields: HashMap<String, Type<Prim>>,
79}
80
81impl<Prim: PrimitiveType> Default for Object<Prim> {
82 fn default() -> Self {
83 Self {
84 fields: HashMap::new(),
85 }
86 }
87}
88
89impl<Prim, S, V> FromIterator<(S, V)> for Object<Prim>
90where
91 Prim: PrimitiveType,
92 S: Into<String>,
93 V: Into<Type<Prim>>,
94{
95 fn from_iter<T: IntoIterator<Item = (S, V)>>(iter: T) -> Self {
96 Self {
97 fields: iter
98 .into_iter()
99 .map(|(name, ty)| (name.into(), ty.into()))
100 .collect(),
101 }
102 }
103}
104
105impl<Prim, S, V, const N: usize> From<[(S, V); N]> for Object<Prim>
106where
107 Prim: PrimitiveType,
108 S: Into<String>,
109 V: Into<Type<Prim>>,
110{
111 fn from(entries: [(S, V); N]) -> Self {
112 Self {
113 fields: entries
114 .into_iter()
115 .map(|(name, ty)| (name.into(), ty.into()))
116 .collect(),
117 }
118 }
119}
120
121impl<Prim: PrimitiveType> fmt::Display for Object<Prim> {
122 fn fmt(&self, formatter: &mut fmt::Formatter<'_>) -> fmt::Result {
123 let mut sorted_fields: Vec<_> = self.fields.iter().collect();
124 sorted_fields.sort_unstable_by_key(|(name, _)| *name);
125
126 formatter.write_str("{")?;
127 for (i, (name, ty)) in sorted_fields.into_iter().enumerate() {
128 write!(formatter, " {name}: {ty}")?;
129 if i + 1 < self.fields.len() {
130 formatter.write_str(",")?;
131 }
132 }
133 formatter.write_str(" }")
134 }
135}
136
137impl<Prim: PrimitiveType> Object<Prim> {
138 pub fn new() -> Self {
140 Self::default()
141 }
142
143 pub(crate) fn from_map(fields: HashMap<String, Type<Prim>>) -> Self {
144 Self { fields }
145 }
146
147 pub fn get(&self, name: &str) -> Option<&Type<Prim>> {
149 self.fields.get(name)
150 }
151
152 pub fn iter(&self) -> impl Iterator<Item = (&str, &Type<Prim>)> + '_ {
154 self.fields.iter().map(|(name, ty)| (name.as_str(), ty))
155 }
156
157 pub fn field_names(&self) -> impl Iterator<Item = &str> + '_ {
159 self.fields.keys().map(String::as_str)
160 }
161
162 pub fn into_dyn(self) -> Type<Prim> {
164 Type::Dyn(DynConstraints::from(self))
165 }
166
167 pub(crate) fn iter_mut(&mut self) -> impl Iterator<Item = (&str, &mut Type<Prim>)> + '_ {
168 self.fields.iter_mut().map(|(name, ty)| (name.as_str(), ty))
169 }
170
171 pub(crate) fn is_concrete(&self) -> bool {
172 self.fields.values().all(Type::is_concrete)
173 }
174
175 pub(crate) fn extend_from(
176 &mut self,
177 other: Self,
178 substitutions: &mut Substitutions<Prim>,
179 mut errors: OpErrors<'_, Prim>,
180 ) {
181 for (field_name, ty) in other.fields {
182 if let Some(this_field) = self.fields.get(&field_name) {
183 substitutions.unify(this_field, &ty, errors.join_path(field_name.as_str()));
184 } else {
185 self.fields.insert(field_name, ty);
186 }
187 }
188 }
189
190 pub(crate) fn apply_as_constraint(
191 &self,
192 ty: &Type<Prim>,
193 substitutions: &mut Substitutions<Prim>,
194 mut errors: OpErrors<'_, Prim>,
195 ) {
196 let resolved_ty = if let Type::Var(var) = ty {
197 debug_assert!(var.is_free());
198 substitutions.insert_obj_constraint(var.index(), self, errors.by_ref());
199 substitutions.fast_resolve(ty)
200 } else {
201 ty
202 };
203
204 match resolved_ty {
205 Type::Object(rhs) => {
206 self.constraint_object(&rhs.clone(), substitutions, errors);
207 }
208 Type::Dyn(constraints) => {
209 if let Some(object) = constraints.inner.object.clone() {
210 self.constraint_object(&object, substitutions, errors);
211 } else {
212 errors.push(ErrorKind::CannotAccessFields);
213 }
214 }
215 Type::Any | Type::Var(_) => { }
216 _ => errors.push(ErrorKind::CannotAccessFields),
217 }
218 }
219
220 fn constraint_object(
222 &self,
223 rhs: &Object<Prim>,
224 substitutions: &mut Substitutions<Prim>,
225 mut errors: OpErrors<'_, Prim>,
226 ) {
227 let mut missing_fields = HashSet::new();
228 for (field_name, lhs_ty) in self.iter() {
229 if let Some(rhs_ty) = rhs.get(field_name) {
230 substitutions.unify(lhs_ty, rhs_ty, errors.join_path(field_name));
231 } else {
232 missing_fields.insert(field_name.to_owned());
233 }
234 }
235
236 if !missing_fields.is_empty() {
237 errors.push(ErrorKind::MissingFields {
238 fields: missing_fields,
239 available_fields: rhs.field_names().map(String::from).collect(),
240 });
241 }
242 }
243}
244
245impl<Prim: PrimitiveType> ops::Index<&str> for Object<Prim> {
246 type Output = Type<Prim>;
247
248 fn index(&self, field_name: &str) -> &Self::Output {
249 self.get(field_name).unwrap_or_else(|| {
250 panic!("Object type does not contain field `{field_name}`");
251 })
252 }
253}
254
#[cfg(test)]
mod tests {
    use assert_matches::assert_matches;

    use super::*;
    use crate::arith::Num;

    /// Returns the only error recorded in `errors`.
    ///
    /// Fails the test, printing every recorded error, if there is not exactly one.
    fn get_err(errors: OpErrors<'_, Num>) -> ErrorKind<Num> {
        let errors = errors.into_vec();
        assert_eq!(errors.len(), 1, "{errors:?}");
        errors.into_iter().next().unwrap()
    }

    /// Exercises `Object::constraint_object` against several right-hand objects.
    ///
    /// All steps share one `Substitutions` instance, so each step that needs a type
    /// variable uses a fresh index, and the steps must run in this order.
    #[test]
    fn placing_obj_constraint() {
        let constraint: Object<Num> = Object::from([("x", Type::NUM)]);
        let mut substitutions = Substitutions::default();

        // An object trivially satisfies a constraint built from itself.
        let mut errors = OpErrors::new();
        constraint.constraint_object(&constraint, &mut substitutions, errors.by_ref());
        assert!(errors.into_vec().is_empty());

        // A free variable in the constrained field gets unified with the field's type.
        let with_var = Object::from([("x", Type::free_var(0))]);
        let mut errors = OpErrors::new();
        constraint.constraint_object(&with_var, &mut substitutions, errors.by_ref());
        assert!(errors.into_vec().is_empty());
        assert_eq!(*substitutions.fast_resolve(&Type::free_var(0)), Type::NUM);

        // Fields beyond those required by the constraint are allowed.
        let with_extra_field = Object::from([("x", Type::free_var(1)), ("y", Type::BOOL)]);
        let mut errors = OpErrors::new();
        constraint.constraint_object(&with_extra_field, &mut substitutions, errors.by_ref());
        assert!(errors.into_vec().is_empty());
        assert_eq!(*substitutions.fast_resolve(&Type::free_var(1)), Type::NUM);

        // A missing required field is reported together with the available fields.
        let without_x = Object::from([("y", Type::free_var(2))]);
        let mut errors = OpErrors::new();
        constraint.constraint_object(&without_x, &mut substitutions, errors.by_ref());
        assert_matches!(
            get_err(errors),
            ErrorKind::MissingFields { fields, available_fields }
                if fields.len() == 1 && fields.contains("x") &&
                    available_fields.len() == 1 && available_fields.contains("y")
        );

        // A field of an incompatible type yields a type mismatch.
        let mismatched_x = Object::from([("x", Type::BOOL)]);
        let mut errors = OpErrors::new();
        constraint.constraint_object(&mismatched_x, &mut substitutions, errors.by_ref());
        assert_matches!(
            get_err(errors),
            ErrorKind::TypeMismatch(lhs, rhs) if *lhs == Type::NUM && *rhs == Type::BOOL
        );
    }
}