use core::{fmt, marker::PhantomData};

use crate::{
    alloc::{Box, HashMap, String, ToString, Vec},
    arith::Substitutions,
    error::{ErrorKind, OpErrors},
    visit::{self, Visit},
    Function, Object, PrimitiveType, Slice, Tuple, Type, TypeVar,
};
12
13pub trait Constraint<Prim: PrimitiveType>: fmt::Display + Send + Sync + 'static {
36 fn visitor<'r>(
44 &self,
45 substitutions: &'r mut Substitutions<Prim>,
46 errors: OpErrors<'r, Prim>,
47 ) -> Box<dyn Visit<Prim> + 'r>;
48
49 fn clone_boxed(&self) -> Box<dyn Constraint<Prim>>;
53}
54
55impl<Prim: PrimitiveType> fmt::Debug for dyn Constraint<Prim> {
56 fn fmt(&self, formatter: &mut fmt::Formatter<'_>) -> fmt::Result {
57 formatter
58 .debug_tuple("dyn Constraint")
59 .field(&self.to_string())
60 .finish()
61 }
62}
63
64impl<Prim: PrimitiveType> Clone for Box<dyn Constraint<Prim>> {
65 fn clone(&self) -> Self {
66 self.clone_boxed()
67 }
68}
69
70pub trait ObjectSafeConstraint<Prim: PrimitiveType>: Constraint<Prim> {}
79
80#[derive(Debug)]
135pub struct StructConstraint<Prim, C, F> {
136 constraint: C,
137 predicate: F,
138 deny_dyn_slices: bool,
139 _prim: PhantomData<Prim>,
140}
141
142impl<Prim, C, F> StructConstraint<Prim, C, F>
143where
144 Prim: PrimitiveType,
145 C: Constraint<Prim> + Clone,
146 F: Fn(&Prim) -> bool + 'static,
147{
148 pub fn new(constraint: C, predicate: F) -> Self {
151 Self {
152 constraint,
153 predicate,
154 deny_dyn_slices: false,
155 _prim: PhantomData,
156 }
157 }
158
159 #[must_use]
161 pub fn deny_dyn_slices(mut self) -> Self {
162 self.deny_dyn_slices = true;
163 self
164 }
165
166 pub fn visitor<'r>(
168 self,
169 substitutions: &'r mut Substitutions<Prim>,
170 errors: OpErrors<'r, Prim>,
171 ) -> Box<dyn Visit<Prim> + 'r> {
172 Box::new(StructConstraintVisitor {
173 inner: self,
174 substitutions,
175 errors,
176 })
177 }
178}
179
180#[derive(Debug)]
181struct StructConstraintVisitor<'r, Prim: PrimitiveType, C, F> {
182 inner: StructConstraint<Prim, C, F>,
183 substitutions: &'r mut Substitutions<Prim>,
184 errors: OpErrors<'r, Prim>,
185}
186
187impl<Prim, C, F> Visit<Prim> for StructConstraintVisitor<'_, Prim, C, F>
188where
189 Prim: PrimitiveType,
190 C: Constraint<Prim> + Clone,
191 F: Fn(&Prim) -> bool + 'static,
192{
193 fn visit_type(&mut self, ty: &Type<Prim>) {
194 match ty {
195 Type::Dyn(constraints) => {
196 if !constraints.inner.simple.contains(&self.inner.constraint) {
197 self.errors.push(ErrorKind::failed_constraint(
198 ty.clone(),
199 self.inner.constraint.clone(),
200 ));
201 }
202 }
203 _ => visit::visit_type(self, ty),
204 }
205 }
206
207 fn visit_var(&mut self, var: TypeVar) {
208 debug_assert!(var.is_free());
209 self.substitutions.insert_constraint(
210 var.index(),
211 &self.inner.constraint,
212 self.errors.by_ref(),
213 );
214
215 let resolved = self.substitutions.fast_resolve(&Type::Var(var)).clone();
216 if let Type::Var(_) = resolved {
217 } else {
219 visit::visit_type(self, &resolved);
220 }
221 }
222
223 fn visit_primitive(&mut self, primitive: &Prim) {
224 if !(self.inner.predicate)(primitive) {
225 self.errors.push(ErrorKind::failed_constraint(
226 Type::Prim(primitive.clone()),
227 self.inner.constraint.clone(),
228 ));
229 }
230 }
231
232 fn visit_tuple(&mut self, tuple: &Tuple<Prim>) {
233 if self.inner.deny_dyn_slices {
234 let middle_len = tuple.parts().1.map(Slice::len);
235 if let Some(middle_len) = middle_len {
236 if let Err(err) = self.substitutions.apply_static_len(middle_len) {
237 self.errors.push(err);
238 }
239 }
240 }
241
242 for (i, element) in tuple.element_types() {
243 self.errors.push_path_fragment(i);
244 self.visit_type(element);
245 self.errors.pop_path_fragment();
246 }
247 }
248
249 fn visit_object(&mut self, obj: &Object<Prim>) {
250 for (name, element) in obj.iter() {
251 self.errors.push_path_fragment(name);
252 self.visit_type(element);
253 self.errors.pop_path_fragment();
254 }
255 }
256
257 fn visit_function(&mut self, function: &Function<Prim>) {
258 self.errors.push(ErrorKind::failed_constraint(
259 function.clone().into(),
260 self.inner.constraint.clone(),
261 ));
262 }
263}
264
/// Linearity constraint, displayed as `Lin`.
///
/// Its [`Constraint`] implementation checks primitive types with
/// [`LinearType::is_linear()`].
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub struct Linearity;

impl fmt::Display for Linearity {
    fn fmt(&self, formatter: &mut fmt::Formatter<'_>) -> fmt::Result {
        write!(formatter, "Lin")
    }
}
278
279impl<Prim: LinearType> Constraint<Prim> for Linearity {
280 fn visitor<'r>(
281 &self,
282 substitutions: &'r mut Substitutions<Prim>,
283 errors: OpErrors<'r, Prim>,
284 ) -> Box<dyn Visit<Prim> + 'r> {
285 StructConstraint::new(*self, LinearType::is_linear).visitor(substitutions, errors)
286 }
287
288 fn clone_boxed(&self) -> Box<dyn Constraint<Prim>> {
289 Box::new(*self)
290 }
291}
292
293impl<Prim: LinearType> ObjectSafeConstraint<Prim> for Linearity {}
294
295pub trait LinearType: PrimitiveType {
298 fn is_linear(&self) -> bool;
300}
301
/// Constraint displayed as `Ops`.
///
/// Its [`Constraint`] implementation checks primitive types with
/// [`LinearType::is_linear()`] and additionally requires tuple middles to have
/// a static length.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub struct Ops;

impl fmt::Display for Ops {
    fn fmt(&self, formatter: &mut fmt::Formatter<'_>) -> fmt::Result {
        write!(formatter, "Ops")
    }
}
314
315impl<Prim: LinearType> Constraint<Prim> for Ops {
316 fn visitor<'r>(
317 &self,
318 substitutions: &'r mut Substitutions<Prim>,
319 errors: OpErrors<'r, Prim>,
320 ) -> Box<dyn Visit<Prim> + 'r> {
321 StructConstraint::new(*self, LinearType::is_linear)
322 .deny_dyn_slices()
323 .visitor(substitutions, errors)
324 }
325
326 fn clone_boxed(&self) -> Box<dyn Constraint<Prim>> {
327 Box::new(*self)
328 }
329}
330
331#[derive(Debug, Clone)]
336pub struct ConstraintSet<Prim: PrimitiveType> {
337 inner: HashMap<String, (Box<dyn Constraint<Prim>>, bool)>,
338}
339
340impl<Prim: PrimitiveType> Default for ConstraintSet<Prim> {
341 fn default() -> Self {
342 Self::new()
343 }
344}
345
346impl<Prim: PrimitiveType> PartialEq for ConstraintSet<Prim> {
347 fn eq(&self, other: &Self) -> bool {
348 if self.inner.len() == other.inner.len() {
349 self.inner.keys().all(|key| other.inner.contains_key(key))
350 } else {
351 false
352 }
353 }
354}
355
356impl<Prim: PrimitiveType> fmt::Display for ConstraintSet<Prim> {
357 fn fmt(&self, formatter: &mut fmt::Formatter<'_>) -> fmt::Result {
358 let len = self.inner.len();
359 for (i, (constraint, _)) in self.inner.values().enumerate() {
360 fmt::Display::fmt(constraint, formatter)?;
361 if i + 1 < len {
362 formatter.write_str(" + ")?;
363 }
364 }
365 Ok(())
366 }
367}
368
369impl<Prim: PrimitiveType> ConstraintSet<Prim> {
370 pub fn new() -> Self {
372 Self {
373 inner: HashMap::new(),
374 }
375 }
376
377 pub fn just(constraint: impl Constraint<Prim>) -> Self {
379 let mut this = Self::new();
380 this.insert(constraint);
381 this
382 }
383
384 pub fn is_empty(&self) -> bool {
386 self.inner.is_empty()
387 }
388
389 fn contains(&self, constraint: &impl Constraint<Prim>) -> bool {
390 self.inner.contains_key(&constraint.to_string())
391 }
392
393 pub fn insert(&mut self, constraint: impl Constraint<Prim>) {
395 self.inner
396 .insert(constraint.to_string(), (Box::new(constraint), false));
397 }
398
399 pub fn insert_object_safe(&mut self, constraint: impl ObjectSafeConstraint<Prim>) {
401 self.inner
402 .insert(constraint.to_string(), (Box::new(constraint), true));
403 }
404
405 pub(crate) fn insert_boxed(&mut self, constraint: Box<dyn Constraint<Prim>>) {
407 self.inner
408 .insert(constraint.to_string(), (constraint, false));
409 }
410
411 pub(crate) fn get_by_name(&self, name: &str) -> Option<(&dyn Constraint<Prim>, bool)> {
413 self.inner
414 .get(name)
415 .map(|(constraint, is_object_safe)| (constraint.as_ref(), *is_object_safe))
416 }
417
418 pub(crate) fn apply_all(
420 &self,
421 ty: &Type<Prim>,
422 substitutions: &mut Substitutions<Prim>,
423 mut errors: OpErrors<'_, Prim>,
424 ) {
425 for (constraint, _) in self.inner.values() {
426 constraint
427 .visitor(substitutions, errors.by_ref())
428 .visit_type(ty);
429 }
430 }
431
432 pub(crate) fn apply_all_to_object(
434 &self,
435 object: &Object<Prim>,
436 substitutions: &mut Substitutions<Prim>,
437 mut errors: OpErrors<'_, Prim>,
438 ) {
439 for (constraint, _) in self.inner.values() {
440 constraint
441 .visitor(substitutions, errors.by_ref())
442 .visit_object(object);
443 }
444 }
445}
446
447#[derive(Debug, Clone, PartialEq)]
449pub(crate) struct CompleteConstraints<Prim: PrimitiveType> {
450 pub simple: ConstraintSet<Prim>,
451 pub object: Option<Object<Prim>>,
453}
454
455impl<Prim: PrimitiveType> Default for CompleteConstraints<Prim> {
456 fn default() -> Self {
457 Self {
458 simple: ConstraintSet::new(),
459 object: None,
460 }
461 }
462}
463
464impl<Prim: PrimitiveType> From<ConstraintSet<Prim>> for CompleteConstraints<Prim> {
465 fn from(constraints: ConstraintSet<Prim>) -> Self {
466 Self {
467 simple: constraints,
468 object: None,
469 }
470 }
471}
472
473impl<Prim: PrimitiveType> From<Object<Prim>> for CompleteConstraints<Prim> {
474 fn from(object: Object<Prim>) -> Self {
475 Self {
476 simple: ConstraintSet::default(),
477 object: Some(object),
478 }
479 }
480}
481
482impl<Prim: PrimitiveType> fmt::Display for CompleteConstraints<Prim> {
483 fn fmt(&self, formatter: &mut fmt::Formatter<'_>) -> fmt::Result {
484 match (&self.object, self.simple.is_empty()) {
485 (Some(object), false) => write!(formatter, "{object} + {}", self.simple),
486 (Some(object), true) => fmt::Display::fmt(object, formatter),
487 (None, _) => fmt::Display::fmt(&self.simple, formatter),
488 }
489 }
490}
491
492impl<Prim: PrimitiveType> CompleteConstraints<Prim> {
493 pub fn is_empty(&self) -> bool {
495 self.object.is_none() && self.simple.is_empty()
496 }
497
498 pub fn insert(
500 &mut self,
501 constraint: impl Constraint<Prim>,
502 substitutions: &mut Substitutions<Prim>,
503 errors: OpErrors<'_, Prim>,
504 ) {
505 self.simple.insert(constraint);
506 self.check_object_consistency(substitutions, errors);
507 }
508
509 pub fn apply_all(
511 &self,
512 ty: &Type<Prim>,
513 substitutions: &mut Substitutions<Prim>,
514 mut errors: OpErrors<'_, Prim>,
515 ) {
516 self.simple.apply_all(ty, substitutions, errors.by_ref());
517 if let Some(lhs) = &self.object {
518 lhs.apply_as_constraint(ty, substitutions, errors);
519 }
520 }
521
522 pub fn map_object(self, map: impl FnOnce(&mut Object<Prim>)) -> Self {
524 Self {
525 simple: self.simple,
526 object: self.object.map(|mut object| {
527 map(&mut object);
528 object
529 }),
530 }
531 }
532
533 pub fn insert_obj_constraint(
535 &mut self,
536 object: Object<Prim>,
537 substitutions: &mut Substitutions<Prim>,
538 mut errors: OpErrors<'_, Prim>,
539 ) {
540 if let Some(existing_object) = &mut self.object {
541 existing_object.extend_from(object, substitutions, errors.by_ref());
542 } else {
543 self.object = Some(object);
544 }
545 self.check_object_consistency(substitutions, errors);
546 }
547
548 fn check_object_consistency(
549 &self,
550 substitutions: &mut Substitutions<Prim>,
551 errors: OpErrors<'_, Prim>,
552 ) {
553 if let Some(object) = &self.object {
554 self.simple
555 .apply_all_to_object(object, substitutions, errors);
556 }
557 }
558}