1use core::fmt;
4
5use arithmetic_parser::{
6 Error as ParseError, ErrorKind as ParseErrorKind, InputSpan, NomResult, Spanned,
7};
8use nom::Err as NomErr;
9
10use crate::{
11 alloc::{Box, HashMap, HashSet, String, ToOwned},
12 arith::{CompleteConstraints, Constraint, ConstraintSet},
13 ast::{
14 ConstraintsAst, FunctionAst, ObjectAst, SliceAst, SpannedTypeAst, TupleAst, TupleLenAst,
15 TypeAst, TypeConstraintsAst,
16 },
17 error::{Error, Errors},
18 types::{ParamConstraints, ParamQuantifier},
19 DynConstraints, Function, Object, PrimitiveType, Slice, Tuple, Type, TypeEnvironment,
20 UnknownLen,
21};
22
/// Errors that can be reported while converting a type AST into a type.
///
/// The human-readable description of each variant is produced by the `Display`
/// implementation of this type.
#[non_exhaustive]
#[derive(Debug, Clone)]
pub enum AstConversionError {
    /// A `for` quantifier is attached to a function that is not top-level.
    EmbeddedQuantifier,
    /// The named length param is not scoped by a function definition.
    FreeLengthVar(String),
    /// The named type param is not scoped by a function definition.
    FreeTypeVar(String),
    /// The named length param is mentioned in constraints, but is not used.
    UnusedLength(String),
    /// The named type param is mentioned in constraints, but is not used.
    UnusedTypeParam(String),
    /// The named type is not known.
    UnknownType(String),
    /// The named constraint is not known.
    UnknownConstraint(String),
    /// The `_` type was used where it is disallowed (when parsing a standalone type).
    InvalidSomeType,
    /// The `_` length was used where it is disallowed (when parsing a standalone type).
    InvalidSomeLength,
    /// The named field is defined more than once in an object type.
    DuplicateField(String),
    /// The named constraint is not object-safe.
    NotObjectSafe(String),
}
86
87impl fmt::Display for AstConversionError {
88 fn fmt(&self, formatter: &mut fmt::Formatter<'_>) -> fmt::Result {
89 match self {
90 Self::EmbeddedQuantifier => {
91 formatter.write_str("`for` quantifier for a function that is not top-level")
92 }
93
94 Self::FreeLengthVar(name) => {
95 write!(
96 formatter,
97 "Length param `{name}` is not scoped by function definition"
98 )
99 }
100 Self::FreeTypeVar(name) => {
101 write!(
102 formatter,
103 "Type param `{name}` is not scoped by function definition"
104 )
105 }
106
107 Self::UnusedLength(name) => {
108 write!(formatter, "Unused length param `{name}`")
109 }
110 Self::UnusedTypeParam(name) => {
111 write!(formatter, "Unused type param `{name}`")
112 }
113 Self::UnknownType(name) => {
114 write!(formatter, "Unknown type `{name}`")
115 }
116 Self::UnknownConstraint(name) => {
117 write!(formatter, "Unknown constraint `{name}`")
118 }
119
120 Self::InvalidSomeType => {
121 formatter.write_str("`_` type is disallowed when parsing standalone type")
122 }
123 Self::InvalidSomeLength => {
124 formatter.write_str("`_` length is disallowed when parsing standalone type")
125 }
126
127 Self::DuplicateField(name) => {
128 write!(formatter, "Duplicate field `{name}` in object type")
129 }
130
131 Self::NotObjectSafe(name) => {
132 write!(formatter, "Constraint `{name}` is not object-safe")
133 }
134 }
135 }
136}
137
138#[cfg(feature = "std")]
139impl std::error::Error for AstConversionError {}
140
141#[derive(Debug)]
143pub(crate) struct AstConversionState<'r, 'a, Prim: PrimitiveType> {
144 env: Option<&'r mut TypeEnvironment<Prim>>,
145 known_constraints: ConstraintSet<Prim>,
146 errors: &'r mut Errors<Prim>,
147 len_params: HashMap<&'a str, usize>,
148 type_params: HashMap<&'a str, usize>,
149 is_in_function: bool,
150}
151
152impl<'r, 'a, Prim: PrimitiveType> AstConversionState<'r, 'a, Prim> {
153 pub fn new(env: &'r mut TypeEnvironment<Prim>, errors: &'r mut Errors<Prim>) -> Self {
154 let known_constraints = env.known_constraints.clone();
155 Self {
156 env: Some(env),
157 known_constraints,
158 errors,
159 len_params: HashMap::new(),
160 type_params: HashMap::new(),
161 is_in_function: false,
162 }
163 }
164
165 fn without_env(errors: &'r mut Errors<Prim>) -> Self {
166 Self {
167 env: None,
168 known_constraints: Prim::well_known_constraints(),
169 errors,
170 len_params: HashMap::new(),
171 type_params: HashMap::new(),
172 is_in_function: false,
173 }
174 }
175
176 fn type_param_idx(&mut self, param_name: &'a str) -> usize {
177 let type_param_count = self.type_params.len();
178 *self
179 .type_params
180 .entry(param_name)
181 .or_insert(type_param_count)
182 }
183
184 fn len_param_idx(&mut self, param_name: &'a str) -> usize {
185 let len_param_count = self.len_params.len();
186 *self.len_params.entry(param_name).or_insert(len_param_count)
187 }
188
189 fn new_type(&mut self, span: Option<&SpannedTypeAst<'a>>) -> Type<Prim> {
190 let errors = &mut *self.errors;
191 self.env.as_mut().map_or_else(
192 || {
193 if let Some(span) = span {
194 let err = AstConversionError::InvalidSomeType;
195 errors.push(Error::conversion(err, span));
196 }
197 Type::free_var(0)
200 },
201 |env| env.substitutions.new_type_var(),
202 )
203 }
204
205 fn new_len(&mut self, span: Option<&Spanned<'a, TupleLenAst>>) -> UnknownLen {
206 let errors = &mut *self.errors;
207 self.env.as_mut().map_or_else(
208 || {
209 if let Some(span) = span {
210 let err = AstConversionError::InvalidSomeLength;
211 errors.push(Error::conversion(err, span));
212 }
213 UnknownLen::free_var(0)
216 },
217 |env| env.substitutions.new_len_var(),
218 )
219 }
220
221 fn resolve_constraint(&self, name: &str) -> Option<(Box<dyn Constraint<Prim>>, bool)> {
222 self.known_constraints
223 .get_by_name(name)
224 .map(|(constraint, is_object_safe)| (constraint.clone_boxed(), is_object_safe))
225 }
226
227 pub(crate) fn convert_type(&mut self, ty: &SpannedTypeAst<'a>) -> Type<Prim> {
228 match &ty.extra {
229 TypeAst::Some => self.new_type(Some(ty)),
230 TypeAst::Any => Type::Any,
231 TypeAst::Dyn(constraints) => Type::Dyn(constraints.convert_dyn(self)),
232 TypeAst::Ident => {
233 let ident = *ty.fragment();
234 if let Ok(prim_type) = Prim::from_str(ident) {
235 Type::Prim(prim_type)
236 } else {
237 let err = AstConversionError::UnknownType(ident.to_owned());
238 self.errors.push(Error::conversion(err, ty));
239 self.new_type(None)
240 }
241 }
242
243 TypeAst::Param => {
244 let name = &ty.fragment()[1..];
245 if self.is_in_function {
246 let idx = self.type_param_idx(name);
247 Type::param(idx)
248 } else {
249 let err = AstConversionError::FreeTypeVar(name.to_owned());
250 self.errors.push(Error::conversion(err, ty));
251 self.new_type(None)
252 }
253 }
254
255 TypeAst::Function(function) => self.convert_fn(function, None),
256 TypeAst::FunctionWithConstraints {
257 function,
258 constraints,
259 } => self.convert_fn(&function.extra, Some(constraints)),
260
261 TypeAst::Tuple(tuple) => tuple.convert(self).into(),
262 TypeAst::Slice(slice) => slice.convert(self).into(),
263 TypeAst::Object(object) => object.convert(self).into(),
264 }
265 }
266
267 fn convert_fn(
268 &mut self,
269 function: &FunctionAst<'a>,
270 constraints: Option<&Spanned<'a, ConstraintsAst<'a>>>,
271 ) -> Type<Prim> {
272 if self.is_in_function {
273 if let Some(constraints) = constraints {
274 let err = AstConversionError::EmbeddedQuantifier;
275 self.errors.push(Error::conversion(err, constraints));
276 }
277 function.convert(self).into()
278 } else {
279 self.is_in_function = true;
280 let mut converted_fn = function.convert(self);
281 let constraints =
282 constraints.map_or_else(ParamConstraints::default, |c| c.extra.convert(self));
283 ParamQuantifier::fill_params(&mut converted_fn, constraints);
284
285 self.is_in_function = false;
286 self.type_params.clear();
287 self.len_params.clear();
288 converted_fn.into()
289 }
290 }
291}
292
293impl<'a> TypeConstraintsAst<'a> {
294 fn convert<Prim: PrimitiveType>(
295 &self,
296 state: &mut AstConversionState<'_, 'a, Prim>,
297 ) -> CompleteConstraints<Prim> {
298 self.do_convert(state, false)
299 }
300
301 fn convert_dyn<Prim: PrimitiveType>(
302 &self,
303 state: &mut AstConversionState<'_, 'a, Prim>,
304 ) -> DynConstraints<Prim> {
305 DynConstraints {
306 inner: self.do_convert(state, true),
307 }
308 }
309
310 fn do_convert<Prim: PrimitiveType>(
311 &self,
312 state: &mut AstConversionState<'_, 'a, Prim>,
313 require_object_safety: bool,
314 ) -> CompleteConstraints<Prim> {
315 let mut constraints = CompleteConstraints::default();
316 if let Some(object) = &self.object {
317 constraints.object = Some(object.convert(state));
318 }
319
320 self.terms.iter().fold(constraints, |mut acc, input| {
321 let input_str = *input.fragment();
322 if let Some((constraint, is_object_safe)) = state.resolve_constraint(input_str) {
323 if require_object_safety && !is_object_safe {
324 let err = AstConversionError::NotObjectSafe(input_str.to_owned());
325 state.errors.push(Error::conversion(err, input));
326 } else {
327 acc.simple.insert_boxed(constraint);
328 }
329 } else {
330 let err = AstConversionError::UnknownConstraint(input_str.to_owned());
331 state.errors.push(Error::conversion(err, input));
332 }
333 acc
334 })
335 }
336}
337
338impl<'a> ConstraintsAst<'a> {
339 fn convert<Prim: PrimitiveType>(
340 &self,
341 state: &mut AstConversionState<'_, 'a, Prim>,
342 ) -> ParamConstraints<Prim> {
343 let mut static_lengths = HashSet::with_capacity(self.static_lengths.len());
344 for dyn_length in &self.static_lengths {
345 let name = *dyn_length.fragment();
346 if let Some(index) = state.len_params.get(name) {
347 static_lengths.insert(*index);
348 } else {
349 let err = AstConversionError::UnusedLength(name.to_owned());
350 state.errors.push(Error::conversion(err, dyn_length));
351 }
352 }
353
354 let mut type_params = HashMap::with_capacity(self.type_params.len());
355 for (param, constraints) in &self.type_params {
356 let name = *param.fragment();
357 if let Some(index) = state.type_params.get(name) {
358 type_params.insert(*index, constraints.convert(state));
359 } else {
360 let err = AstConversionError::UnusedTypeParam(name.to_owned());
361 state.errors.push(Error::conversion(err, param));
362 }
363 }
364
365 ParamConstraints {
366 type_params,
367 static_lengths,
368 }
369 }
370}
371
372impl<'a> TupleAst<'a> {
373 fn convert<Prim: PrimitiveType>(
374 &self,
375 state: &mut AstConversionState<'_, 'a, Prim>,
376 ) -> Tuple<Prim> {
377 let start = self
378 .start
379 .iter()
380 .map(|element| state.convert_type(element))
381 .collect();
382 let middle = self
383 .middle
384 .as_ref()
385 .map(|middle| middle.extra.convert(state));
386 let end = self
387 .end
388 .iter()
389 .map(|element| state.convert_type(element))
390 .collect();
391 Tuple::from_parts(start, middle, end)
392 }
393}
394
395impl<'a> SliceAst<'a> {
396 fn convert<Prim: PrimitiveType>(
397 &self,
398 state: &mut AstConversionState<'_, 'a, Prim>,
399 ) -> Slice<Prim> {
400 let element = state.convert_type(&self.element);
401
402 let converted_length = match &self.length.extra {
403 TupleLenAst::Ident => {
404 let name = *self.length.fragment();
405 if state.is_in_function {
406 let const_param = state.len_param_idx(name);
407 UnknownLen::param(const_param)
408 } else {
409 let err = AstConversionError::FreeLengthVar(name.to_owned());
410 state.errors.push(Error::conversion(err, &self.length));
411 state.new_len(None)
412 }
413 }
414 TupleLenAst::Some => state.new_len(Some(&self.length)),
415 TupleLenAst::Dynamic => UnknownLen::Dynamic,
416 };
417
418 Slice::new(element, converted_length)
419 }
420}
421
422impl<'a> ObjectAst<'a> {
423 fn convert<Prim: PrimitiveType>(
424 &self,
425 state: &mut AstConversionState<'_, 'a, Prim>,
426 ) -> Object<Prim> {
427 let mut fields = HashMap::new();
428 for (field_name, ty) in &self.fields {
429 let field_name_str = *field_name.fragment();
430 if fields.contains_key(field_name_str) {
431 let err = AstConversionError::DuplicateField(field_name_str.to_owned());
432 state.errors.push(Error::conversion(err, field_name));
433 } else {
434 fields.insert(field_name_str.to_owned(), state.convert_type(ty));
435 }
436 }
437 Object::from_map(fields)
438 }
439}
440
441impl<'a> FunctionAst<'a> {
442 fn convert<Prim: PrimitiveType>(
443 &self,
444 state: &mut AstConversionState<'_, 'a, Prim>,
445 ) -> Function<Prim> {
446 let args = self.args.extra.convert(state);
447 let return_type = state.convert_type(&self.return_type);
448 Function::new(args, return_type)
449 }
450
451 pub fn try_convert<Prim>(&self) -> Result<Function<Prim>, Errors<Prim>>
453 where
454 Prim: PrimitiveType,
455 {
456 let mut errors = Errors::new();
457 let mut state = AstConversionState::without_env(&mut errors);
458 state.is_in_function = true;
459
460 let output = self.convert(&mut state);
461 if errors.is_empty() {
462 Ok(output)
463 } else {
464 Err(errors)
465 }
466 }
467}
468
469fn parse_inner<'a, Ast>(
471 parser: fn(InputSpan<'a>) -> NomResult<'a, Ast>,
472 input: InputSpan<'a>,
473) -> NomResult<'a, Ast> {
474 let (rest, ast) = parser(input)?;
475 if !rest.fragment().is_empty() {
476 let err = ParseErrorKind::Leftovers.with_span(&rest.into());
477 return Err(NomErr::Failure(err));
478 }
479 Ok((rest, ast))
480}
481
482fn from_str<'a, Ast>(
484 parser: fn(InputSpan<'a>) -> NomResult<'a, Ast>,
485 def: &'a str,
486) -> Result<Ast, ParseError> {
487 let input = InputSpan::new(def);
488 let (_, ast) = parse_inner(parser, input).map_err(|err| match err {
489 NomErr::Incomplete(_) => ParseErrorKind::Incomplete.with_span(&input.into()),
490 NomErr::Error(e) | NomErr::Failure(e) => e,
491 })?;
492 Ok(ast)
493}
494
495impl<'a> TypeAst<'a> {
496 pub fn try_from(def: &'a str) -> Result<SpannedTypeAst<'a>, ParseError> {
498 from_str(TypeAst::parse, def)
499 }
500}
501
502impl<'a, Prim: PrimitiveType> TryFrom<&SpannedTypeAst<'a>> for Type<Prim> {
503 type Error = Errors<Prim>;
504
505 fn try_from(ast: &SpannedTypeAst<'a>) -> Result<Self, Self::Error> {
506 let mut errors = Errors::new();
507 let mut state = AstConversionState::without_env(&mut errors);
508
509 let output = state.convert_type(ast);
510 if errors.is_empty() {
511 Ok(output)
512 } else {
513 Err(errors)
514 }
515 }
516}
517
518impl<'a> TryFrom<&'a str> for FunctionAst<'a> {
519 type Error = ParseError;
520
521 fn try_from(def: &'a str) -> Result<Self, Self::Error> {
522 from_str(FunctionAst::parse, def)
523 }
524}
525
#[cfg(test)]
mod tests {
    use assert_matches::assert_matches;

    use super::*;
    use crate::{
        alloc::{vec, ToString},
        arith::Num,
    };

    /// Parses `def` into a type AST and converts the AST into a type.
    fn parse_type(def: &str) -> anyhow::Result<Type> {
        let ast = TypeAst::try_from(def)?;
        Ok(<Type>::try_from(&ast)?)
    }

    #[test]
    fn converting_raw_fn_type() {
        let input = InputSpan::new("(['T; N], ('T) -> Bool) -> Bool");
        let (_, fn_ast) = FunctionAst::parse(input).unwrap();
        let fn_type = fn_ast.try_convert::<Num>().unwrap();

        // The converted function must be displayed exactly as it was written.
        assert_eq!(fn_type.to_string(), *input.fragment());
    }

    #[test]
    fn converting_fn_type_with_constraint() {
        let input = InputSpan::new("for<'T: Lin> (['T; N], ('T) -> Bool) -> Bool");
        let (_, ast) = TypeAst::parse(input).unwrap();
        let fn_type = <Type>::try_from(&ast).unwrap();

        assert_eq!(fn_type.to_string(), *input.fragment());
    }

    #[test]
    fn parsing_basic_types() -> anyhow::Result<()> {
        assert_eq!(parse_type("Num")?, Type::NUM);
        assert_eq!(parse_type("Bool")?, Type::BOOL);

        let tuple_type = parse_type("(Num, (Bool, Bool))")?;
        assert_eq!(
            tuple_type,
            Type::from((Type::NUM, Type::Tuple(vec![Type::BOOL; 2].into()),))
        );

        let slice_type = parse_type("[(Num, Bool)]")?;
        let Type::Tuple(tuple) = &slice_type else {
            panic!("Unexpected type: {slice_type:?}");
        };
        let slice = tuple.as_slice().unwrap();

        assert_eq!(*slice.element(), Type::from((Type::NUM, Type::BOOL)));
        assert_matches!(slice.len().components(), (Some(UnknownLen::Dynamic), 0));
        Ok(())
    }

    #[test]
    fn parsing_functional_type() -> anyhow::Result<()> {
        let ty = parse_type("(['T; N], ('T) -> 'U) -> 'U")?;
        let Type::Function(function) = ty else {
            panic!("Unexpected type: {ty:?}");
        };

        let params = function.params.as_ref().unwrap();
        assert_eq!(params.len_params.len(), 1);
        assert_eq!(params.type_params.len(), 2);
        assert_eq!(function.return_type, Type::param(1));
        Ok(())
    }

    #[test]
    fn parsing_functional_type_with_varargs() -> anyhow::Result<()> {
        let ty = parse_type("(...[Num; N]) -> Num")?;
        let Type::Function(function) = ty else {
            panic!("Unexpected type: {ty:?}");
        };

        let params = function.params.as_ref().unwrap();
        assert_eq!(params.len_params.len(), 1);
        assert!(params.type_params.is_empty());

        let args_slice = function.args.as_slice().unwrap();
        assert_eq!(*args_slice.element(), Type::NUM);
        assert_eq!(args_slice.len(), UnknownLen::param(0).into());
        Ok(())
    }

    #[test]
    fn parsing_incomplete_type() {
        const INCOMPLETE_TYPES: &[&str] = &[
            "fn(",
            "fn(['T; ",
            "fn(['T; N], fn(",
            "fn(['T; N], fn('T)",
            "fn(['T; N], fn('T)) -",
            "fn(['T; N], fn('T)) ->",
        ];

        // Only the fact of a parsing failure is checked, not the kind of the error.
        for input in INCOMPLETE_TYPES.iter().copied() {
            TypeAst::try_from(input).unwrap_err();
        }
    }

    #[test]
    fn parsing_type_with_object_constraint() -> anyhow::Result<()> {
        let type_def = "for<'T: { x: Num } + Lin> ('T) -> Bool";
        let ty = parse_type(type_def)?;
        let Type::Function(function) = ty else {
            panic!("Unexpected type: {ty:?}");
        };

        let type_params = &function.params.as_ref().unwrap().type_params;
        assert_eq!(type_params.len(), 1);
        let (_, constraints) = &type_params[0];
        assert!(constraints.object.is_some());
        assert!(constraints.simple.get_by_name("Lin").is_some());

        assert_eq!(function.to_string(), type_def);
        Ok(())
    }
}