1use core::{fmt, str::FromStr};
5
6use arithmetic_parser::{BinaryOp, UnaryOp};
7use num_traits::NumOps;
8
9pub(crate) use self::constraints::CompleteConstraints;
10pub use self::{
11 constraints::{
12 Constraint, ConstraintSet, LinearType, Linearity, ObjectSafeConstraint, Ops,
13 StructConstraint,
14 },
15 substitutions::Substitutions,
16};
17use crate::{
18 error::{ErrorKind, ErrorPathFragment, OpErrors},
19 PrimitiveType, Type,
20};
21
22mod constraints;
23mod substitutions;
24
/// Maps literal values of type `Val` to primitive types.
pub trait MapPrimitiveType<Val> {
    /// Primitive type that literals are mapped to.
    type Prim: PrimitiveType;

    /// Returns the primitive type of the literal `lit`.
    fn type_of_literal(&self, lit: &Val) -> Self::Prim;
}
36
/// Rules for typing unary and binary operations over types with primitive type `Prim`.
pub trait TypeArithmetic<Prim: PrimitiveType> {
    /// Processes the unary operation described by `context` and returns the type of its output.
    ///
    /// `substitutions` is the current type-substitution state and may be mutated
    /// (implementations in this module unify types and create new type vars through it);
    /// problems are reported by pushing into `errors`.
    fn process_unary_op(
        &self,
        substitutions: &mut Substitutions<Prim>,
        context: &UnaryOpContext<Prim>,
        errors: OpErrors<'_, Prim>,
    ) -> Type<Prim>;

    /// Processes the binary operation described by `context` and returns the type of its output.
    ///
    /// `substitutions` may be mutated in the same way as for [`Self::process_unary_op()`];
    /// problems are reported via `errors`.
    fn process_binary_op(
        &self,
        substitutions: &mut Substitutions<Prim>,
        context: &BinaryOpContext<Prim>,
        errors: OpErrors<'_, Prim>,
    ) -> Type<Prim>;
}
60
/// Context passed to [`TypeArithmetic::process_unary_op()`].
#[derive(Debug, Clone)]
pub struct UnaryOpContext<Prim: PrimitiveType> {
    /// The unary operation being typed.
    pub op: UnaryOp,
    /// Type of the operation's argument.
    pub arg: Type<Prim>,
}
71
/// Context passed to [`TypeArithmetic::process_binary_op()`].
#[derive(Debug, Clone)]
pub struct BinaryOpContext<Prim: PrimitiveType> {
    /// The binary operation being typed.
    pub op: BinaryOp,
    /// Type of the left-hand side operand.
    pub lhs: Type<Prim>,
    /// Type of the right-hand side operand.
    pub rhs: Type<Prim>,
}
84
/// Primitive type that has a Boolean member.
///
/// In this module, the Boolean type is the output type of `!`, `&&`, `||`, `==`, `!=`
/// and order comparisons.
pub trait WithBoolean: PrimitiveType {
    /// Primitive value representing the Boolean type.
    const BOOL: Self;
}
90
/// Arithmetic supporting only Boolean-related operations.
///
/// Supported operations are `!x` and `x && y` / `x || y` (all operands must be Boolean),
/// plus `x == y` / `x != y` (operands must have the same type). Every such operation
/// outputs a Boolean; all other operations are reported as unsupported.
#[derive(Debug, Clone, Copy, Default)]
pub struct BoolArithmetic;
95
96impl<Prim: WithBoolean> TypeArithmetic<Prim> for BoolArithmetic {
97 fn process_unary_op(
102 &self,
103 substitutions: &mut Substitutions<Prim>,
104 context: &UnaryOpContext<Prim>,
105 mut errors: OpErrors<'_, Prim>,
106 ) -> Type<Prim> {
107 let op = context.op;
108 if op == UnaryOp::Not {
109 substitutions.unify(&Type::BOOL, &context.arg, errors);
110 Type::BOOL
111 } else {
112 let err = ErrorKind::unsupported(op);
113 errors.push(err);
114 substitutions.new_type_var()
115 }
116 }
117
118 fn process_binary_op(
125 &self,
126 substitutions: &mut Substitutions<Prim>,
127 context: &BinaryOpContext<Prim>,
128 mut errors: OpErrors<'_, Prim>,
129 ) -> Type<Prim> {
130 match context.op {
131 BinaryOp::Eq | BinaryOp::NotEq => {
132 substitutions.unify(&context.lhs, &context.rhs, errors);
133 Type::BOOL
134 }
135
136 BinaryOp::And | BinaryOp::Or => {
137 substitutions.unify(
138 &Type::BOOL,
139 &context.lhs,
140 errors.join_path(ErrorPathFragment::Lhs),
141 );
142 substitutions.unify(
143 &Type::BOOL,
144 &context.rhs,
145 errors.join_path(ErrorPathFragment::Rhs),
146 );
147 Type::BOOL
148 }
149
150 _ => {
151 errors.push(ErrorKind::unsupported(context.op));
152 substitutions.new_type_var()
153 }
154 }
155 }
156}
157
/// Constraints applied to operands by [`NumArithmetic::unify_binary_op()`].
#[derive(Debug, Clone)]
pub struct OpConstraintSettings<'a, Prim: PrimitiveType> {
    /// Constraint applied to both operands when one operand resolves to a primitive type
    /// and the other resolves to a non-primitive type.
    pub lin: &'a dyn Constraint<Prim>,
    /// Constraint applied to each operand in all other cases; the operands are unified
    /// only if both satisfy it.
    pub ops: &'a dyn Constraint<Prim>,
}

// `Copy` is implemented by hand rather than derived: `#[derive(Copy)]` would add a
// `Prim: Copy` bound, even though the struct only stores references.
impl<Prim: PrimitiveType> Copy for OpConstraintSettings<'_, Prim> {}
168
/// Primitive types used by [`NumArithmetic`]: numbers and Booleans.
///
/// Displayed (and parsed) as `Num` and `Bool` respectively.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum Num {
    /// Numeric type; every literal is typed as this variant by `NumArithmetic`.
    Num,
    /// Boolean type.
    Bool,
}

impl fmt::Display for Num {
    fn fmt(&self, formatter: &mut fmt::Formatter<'_>) -> fmt::Result {
        let name = match self {
            Self::Num => "Num",
            Self::Bool => "Bool",
        };
        formatter.write_str(name)
    }
}
186
187impl FromStr for Num {
188 type Err = anyhow::Error;
189
190 fn from_str(s: &str) -> Result<Self, Self::Err> {
191 match s {
192 "Num" => Ok(Self::Num),
193 "Bool" => Ok(Self::Bool),
194 _ => Err(anyhow::anyhow!("Expected `Num` or `Bool`")),
195 }
196 }
197}
198
199impl PrimitiveType for Num {
200 fn well_known_constraints() -> ConstraintSet<Self> {
201 let mut constraints = ConstraintSet::default();
202 constraints.insert_object_safe(Linearity);
203 constraints.insert(Ops);
204 constraints
205 }
206}
207
208impl WithBoolean for Num {
209 const BOOL: Self = Self::Bool;
210}
211
212impl LinearType for Num {
214 fn is_linear(&self) -> bool {
215 matches!(self, Self::Num)
216 }
217}
218
/// Arithmetic on [`Num`] primitives.
///
/// - Every literal is typed as [`Num::Num`].
/// - Arithmetic ops (`+`, `-`, `*`, `/`, `^`) are typed via [`Self::unify_binary_op()`].
/// - Boolean ops (`!`, `&&`, `||`, `==`, `!=`) are delegated to [`BoolArithmetic`].
/// - Order comparisons (`>`, `<`, `>=`, `<=`) on `Num` operands are supported only for
///   instances created with [`Self::with_comparisons()`].
///
/// The type holds a single flag, so it is cheap to copy and compare.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct NumArithmetic {
    /// Whether order comparisons are supported.
    comparisons_enabled: bool,
}
242
243impl NumArithmetic {
244 pub const fn without_comparisons() -> Self {
246 Self {
247 comparisons_enabled: false,
248 }
249 }
250
251 pub const fn with_comparisons() -> Self {
253 Self {
254 comparisons_enabled: true,
255 }
256 }
257
258 pub fn unify_binary_op<Prim: PrimitiveType>(
267 substitutions: &mut Substitutions<Prim>,
268 context: &BinaryOpContext<Prim>,
269 mut errors: OpErrors<'_, Prim>,
270 settings: OpConstraintSettings<'_, Prim>,
271 ) -> Type<Prim> {
272 let lhs_ty = &context.lhs;
273 let rhs_ty = &context.rhs;
274 let resolved_lhs_ty = substitutions.fast_resolve(lhs_ty);
275 let resolved_rhs_ty = substitutions.fast_resolve(rhs_ty);
276
277 match (
278 resolved_lhs_ty.is_primitive(),
279 resolved_rhs_ty.is_primitive(),
280 ) {
281 (Some(true), Some(false)) => {
282 let resolved_rhs_ty = resolved_rhs_ty.clone();
283 settings
284 .lin
285 .visitor(substitutions, errors.join_path(ErrorPathFragment::Lhs))
286 .visit_type(lhs_ty);
287 settings
288 .lin
289 .visitor(substitutions, errors.join_path(ErrorPathFragment::Rhs))
290 .visit_type(rhs_ty);
291 resolved_rhs_ty
292 }
293 (Some(false), Some(true)) => {
294 let resolved_lhs_ty = resolved_lhs_ty.clone();
295 settings
296 .lin
297 .visitor(substitutions, errors.join_path(ErrorPathFragment::Lhs))
298 .visit_type(lhs_ty);
299 settings
300 .lin
301 .visitor(substitutions, errors.join_path(ErrorPathFragment::Rhs))
302 .visit_type(rhs_ty);
303 resolved_lhs_ty
304 }
305 _ => {
306 let lhs_is_valid = errors.join_path(ErrorPathFragment::Lhs).check(|errors| {
307 settings
308 .ops
309 .visitor(substitutions, errors)
310 .visit_type(lhs_ty);
311 });
312 let rhs_is_valid = errors.join_path(ErrorPathFragment::Rhs).check(|errors| {
313 settings
314 .ops
315 .visitor(substitutions, errors)
316 .visit_type(rhs_ty);
317 });
318
319 if lhs_is_valid && rhs_is_valid {
320 substitutions.unify(lhs_ty, rhs_ty, errors);
321 }
322 if lhs_is_valid {
323 lhs_ty.clone()
324 } else {
325 rhs_ty.clone()
326 }
327 }
328 }
329 }
330
331 pub fn process_unary_op<Prim: WithBoolean>(
336 substitutions: &mut Substitutions<Prim>,
337 context: &UnaryOpContext<Prim>,
338 mut errors: OpErrors<'_, Prim>,
339 constraints: &impl Constraint<Prim>,
340 ) -> Type<Prim> {
341 match context.op {
342 UnaryOp::Not => BoolArithmetic.process_unary_op(substitutions, context, errors),
343 UnaryOp::Neg => {
344 constraints
345 .visitor(substitutions, errors)
346 .visit_type(&context.arg);
347 context.arg.clone()
348 }
349 _ => {
350 errors.push(ErrorKind::unsupported(context.op));
351 substitutions.new_type_var()
352 }
353 }
354 }
355
356 pub fn process_binary_op<Prim: WithBoolean>(
367 substitutions: &mut Substitutions<Prim>,
368 context: &BinaryOpContext<Prim>,
369 mut errors: OpErrors<'_, Prim>,
370 comparable_type: Option<Prim>,
371 settings: OpConstraintSettings<'_, Prim>,
372 ) -> Type<Prim> {
373 match context.op {
374 BinaryOp::And | BinaryOp::Or | BinaryOp::Eq | BinaryOp::NotEq => {
375 BoolArithmetic.process_binary_op(substitutions, context, errors)
376 }
377
378 BinaryOp::Add | BinaryOp::Sub | BinaryOp::Mul | BinaryOp::Div | BinaryOp::Power => {
379 Self::unify_binary_op(substitutions, context, errors, settings)
380 }
381
382 BinaryOp::Ge | BinaryOp::Le | BinaryOp::Lt | BinaryOp::Gt => {
383 if let Some(ty) = comparable_type {
384 let ty = Type::Prim(ty);
385 substitutions.unify(
386 &ty,
387 &context.lhs,
388 errors.join_path(ErrorPathFragment::Lhs),
389 );
390 substitutions.unify(
391 &ty,
392 &context.rhs,
393 errors.join_path(ErrorPathFragment::Rhs),
394 );
395 } else {
396 let err = ErrorKind::unsupported(context.op);
397 errors.push(err);
398 }
399 Type::BOOL
400 }
401
402 _ => {
403 errors.push(ErrorKind::unsupported(context.op));
404 substitutions.new_type_var()
405 }
406 }
407 }
408}
409
410impl<Val> MapPrimitiveType<Val> for NumArithmetic
411where
412 Val: Clone + NumOps + PartialEq,
413{
414 type Prim = Num;
415
416 fn type_of_literal(&self, _: &Val) -> Self::Prim {
417 Num::Num
418 }
419}
420
421impl TypeArithmetic<Num> for NumArithmetic {
422 fn process_unary_op(
423 &self,
424 substitutions: &mut Substitutions<Num>,
425 context: &UnaryOpContext<Num>,
426 errors: OpErrors<'_, Num>,
427 ) -> Type<Num> {
428 Self::process_unary_op(substitutions, context, errors, &Linearity)
429 }
430
431 fn process_binary_op(
432 &self,
433 substitutions: &mut Substitutions<Num>,
434 context: &BinaryOpContext<Num>,
435 errors: OpErrors<'_, Num>,
436 ) -> Type<Num> {
437 const OP_SETTINGS: OpConstraintSettings<'static, Num> = OpConstraintSettings {
438 lin: &Linearity,
439 ops: &Ops,
440 };
441
442 let comparable_type = if self.comparisons_enabled {
443 Some(Num::Num)
444 } else {
445 None
446 };
447
448 Self::process_binary_op(substitutions, context, errors, comparable_type, OP_SETTINGS)
449 }
450}