1use core::ops;
4
5use arithmetic_parser::{grammars::Grammar, Block};
6
7use self::processor::TypeProcessor;
8use crate::{
9 alloc::{HashMap, String, ToOwned},
10 arith::{
11 Constraint, ConstraintSet, MapPrimitiveType, Num, NumArithmetic, ObjectSafeConstraint,
12 Substitutions, TypeArithmetic,
13 },
14 ast::TypeAst,
15 error::Errors,
16 types::{ParamConstraints, ParamQuantifier},
17 visit::VisitMut,
18 Function, PrimitiveType, Type,
19};
20
21mod processor;
22
23#[derive(Debug, Clone)]
52pub struct TypeEnvironment<Prim: PrimitiveType = Num> {
53 pub(crate) substitutions: Substitutions<Prim>,
54 pub(crate) known_constraints: ConstraintSet<Prim>,
55 variables: HashMap<String, Type<Prim>>,
56}
57
58impl<Prim: PrimitiveType> Default for TypeEnvironment<Prim> {
59 fn default() -> Self {
60 Self {
61 variables: HashMap::new(),
62 known_constraints: Prim::well_known_constraints(),
63 substitutions: Substitutions::default(),
64 }
65 }
66}
67
68impl<Prim: PrimitiveType> TypeEnvironment<Prim> {
69 pub fn new() -> Self {
71 Self::default()
72 }
73
74 pub fn get(&self, name: &str) -> Option<&Type<Prim>> {
76 self.variables.get(name)
77 }
78
79 pub fn iter(&self) -> impl Iterator<Item = (&str, &Type<Prim>)> + '_ {
81 self.variables.iter().map(|(name, ty)| (name.as_str(), ty))
82 }
83
84 fn prepare_type(ty: impl Into<Type<Prim>>) -> Type<Prim> {
85 let mut ty = ty.into();
86 assert!(ty.is_concrete(), "Type {ty} is not concrete");
87 TypePreparer.visit_type_mut(&mut ty);
88 ty
89 }
90
91 pub fn insert(&mut self, name: &str, ty: impl Into<Type<Prim>>) -> &mut Self {
98 self.variables
99 .insert(name.to_owned(), Self::prepare_type(ty));
100 self
101 }
102
103 pub fn insert_constraint(&mut self, constraint: impl Constraint<Prim>) -> &mut Self {
109 self.known_constraints.insert(constraint);
110 self
111 }
112
113 pub fn insert_object_safe_constraint(
119 &mut self,
120 constraint: impl ObjectSafeConstraint<Prim>,
121 ) -> &mut Self {
122 self.known_constraints.insert_object_safe(constraint);
123 self
124 }
125
126 pub fn process_statements<'a, T>(
132 &mut self,
133 block: &Block<'a, T>,
134 ) -> Result<Type<Prim>, Errors<Prim>>
135 where
136 T: Grammar<Type<'a> = TypeAst<'a>>,
137 NumArithmetic: MapPrimitiveType<T::Lit, Prim = Prim> + TypeArithmetic<Prim>,
138 {
139 self.process_with_arithmetic(&NumArithmetic::without_comparisons(), block)
140 }
141
142 pub fn process_with_arithmetic<'a, T, A>(
151 &mut self,
152 arithmetic: &A,
153 block: &Block<'a, T>,
154 ) -> Result<Type<Prim>, Errors<Prim>>
155 where
156 T: Grammar<Type<'a> = TypeAst<'a>>,
157 A: MapPrimitiveType<T::Lit, Prim = Prim> + TypeArithmetic<Prim>,
158 {
159 TypeProcessor::new(self, arithmetic).process_statements(block)
160 }
161}
162
163impl<Prim: PrimitiveType> ops::Index<&str> for TypeEnvironment<Prim> {
164 type Output = Type<Prim>;
165
166 fn index(&self, name: &str) -> &Self::Output {
167 self.get(name)
168 .unwrap_or_else(|| panic!("Variable `{name}` is not defined"))
169 }
170}
171
172#[derive(Debug)]
174struct TypePreparer;
175
176impl<Prim: PrimitiveType> VisitMut<Prim> for TypePreparer {
177 fn visit_function_mut(&mut self, function: &mut Function<Prim>) {
178 if function.params.is_none() {
179 ParamQuantifier::fill_params(function, ParamConstraints::default());
180 }
181 }
183}
184
185fn convert_iter<Prim: PrimitiveType, S, Ty, I>(
186 iter: I,
187) -> impl Iterator<Item = (String, Type<Prim>)>
188where
189 I: IntoIterator<Item = (S, Ty)>,
190 S: Into<String>,
191 Ty: Into<Type<Prim>>,
192{
193 iter.into_iter()
194 .map(|(name, ty)| (name.into(), TypeEnvironment::prepare_type(ty)))
195}
196
197impl<Prim: PrimitiveType, S, Ty> FromIterator<(S, Ty)> for TypeEnvironment<Prim>
198where
199 S: Into<String>,
200 Ty: Into<Type<Prim>>,
201{
202 fn from_iter<I: IntoIterator<Item = (S, Ty)>>(iter: I) -> Self {
203 Self {
204 variables: convert_iter(iter).collect(),
205 known_constraints: Prim::well_known_constraints(),
206 substitutions: Substitutions::default(),
207 }
208 }
209}
210
211impl<Prim: PrimitiveType, S, Ty> Extend<(S, Ty)> for TypeEnvironment<Prim>
212where
213 S: Into<String>,
214 Ty: Into<Type<Prim>>,
215{
216 fn extend<I: IntoIterator<Item = (S, Ty)>>(&mut self, iter: I) {
217 self.variables.extend(convert_iter(iter));
218 }
219}
220
221trait FullArithmetic<Val, Prim: PrimitiveType>:
223 MapPrimitiveType<Val, Prim = Prim> + TypeArithmetic<Prim>
224{
225}
226
227impl<Val, Prim: PrimitiveType, T> FullArithmetic<Val, Prim> for T where
228 T: MapPrimitiveType<Val, Prim = Prim> + TypeArithmetic<Prim>
229{
230}