1use core::cmp::Ordering;
4
5use arithmetic_parser::{BinaryOp, Location, Op, UnaryOp};
6
7use crate::{
8 alloc::Arc,
9 arith::OrdArithmetic,
10 error::{AuxErrorInfo, Error, ErrorKind, TupleLenMismatchContext},
11 exec::ModuleId,
12 Object, Tuple, Value,
13};
14
/// Identifies which operand of a binary operation an error is attributed to.
#[derive(Clone, Copy, Debug)]
enum OpSide {
    /// The left-hand operand.
    Lhs,
    /// The right-hand operand.
    Rhs,
}
20
21#[derive(Debug)]
22struct BinaryOpError {
23 inner: ErrorKind,
24 side: Option<OpSide>,
25}
26
27impl BinaryOpError {
28 fn new(op: BinaryOp) -> Self {
29 Self {
30 inner: ErrorKind::UnexpectedOperand { op: Op::Binary(op) },
31 side: None,
32 }
33 }
34
35 fn tuple(op: BinaryOp, lhs: usize, rhs: usize) -> Self {
36 Self {
37 inner: ErrorKind::TupleLenMismatch {
38 lhs: lhs.into(),
39 rhs,
40 context: TupleLenMismatchContext::BinaryOp(op),
41 },
42 side: Some(OpSide::Lhs),
43 }
44 }
45
46 fn object<T>(op: BinaryOp, lhs: Object<T>, rhs: Object<T>) -> Self {
47 Self {
48 inner: ErrorKind::FieldsMismatch {
49 lhs_fields: lhs.into_iter().map(|(name, _)| name).collect(),
50 rhs_fields: rhs.into_iter().map(|(name, _)| name).collect(),
51 op,
52 },
53 side: Some(OpSide::Lhs),
54 }
55 }
56
57 fn with_side(mut self, side: OpSide) -> Self {
58 self.side = Some(side);
59 self
60 }
61
62 fn with_error_kind(mut self, error_kind: ErrorKind) -> Self {
63 self.inner = error_kind;
64 self
65 }
66
67 fn span(
68 self,
69 module_id: Arc<dyn ModuleId>,
70 total_span: Location,
71 lhs_span: Location,
72 rhs_span: Location,
73 ) -> Error {
74 let main_span = match self.side {
75 Some(OpSide::Lhs) => lhs_span,
76 Some(OpSide::Rhs) => rhs_span,
77 None => total_span,
78 };
79
80 let aux_info = match &self.inner {
81 ErrorKind::TupleLenMismatch { rhs, .. } => Some(AuxErrorInfo::UnbalancedRhsTuple(*rhs)),
82 ErrorKind::FieldsMismatch { rhs_fields, .. } => {
83 Some(AuxErrorInfo::UnbalancedRhsObject(rhs_fields.clone()))
84 }
85 _ => None,
86 };
87
88 let mut err = Error::new(module_id, &main_span, self.inner);
89 if let Some(aux_info) = aux_info {
90 err = err.with_location(&rhs_span, aux_info);
91 }
92 err
93 }
94}
95
96impl<T: Clone> Value<T> {
97 fn try_binary_op_inner(
98 self,
99 rhs: Self,
100 op: BinaryOp,
101 arithmetic: &dyn OrdArithmetic<T>,
102 ) -> Result<Self, BinaryOpError> {
103 match (self, rhs) {
104 (Self::Prim(this), Self::Prim(other)) => {
105 let op_result = match op {
106 BinaryOp::Add => arithmetic.add(this, other),
107 BinaryOp::Sub => arithmetic.sub(this, other),
108 BinaryOp::Mul => arithmetic.mul(this, other),
109 BinaryOp::Div => arithmetic.div(this, other),
110 BinaryOp::Power => arithmetic.pow(this, other),
111 _ => unreachable!(),
112 };
113 op_result
114 .map(Self::Prim)
115 .map_err(|e| BinaryOpError::new(op).with_error_kind(ErrorKind::Arithmetic(e)))
116 }
117
118 (this @ Self::Prim(_), Self::Tuple(other)) => {
119 let output: Result<Tuple<_>, _> = other
120 .into_iter()
121 .map(|y| this.clone().try_binary_op_inner(y, op, arithmetic))
122 .collect();
123 output.map(Self::Tuple)
124 }
125 (Self::Tuple(this), other @ Self::Prim(_)) => {
126 let output: Result<Tuple<_>, _> = this
127 .into_iter()
128 .map(|x| x.try_binary_op_inner(other.clone(), op, arithmetic))
129 .collect();
130 output.map(Self::Tuple)
131 }
132
133 (Self::Tuple(this), Self::Tuple(other)) => {
134 if this.len() == other.len() {
135 let output: Result<Tuple<_>, _> = this
136 .into_iter()
137 .zip(other)
138 .map(|(x, y)| x.try_binary_op_inner(y, op, arithmetic))
139 .collect();
140 output.map(Self::Tuple)
141 } else {
142 Err(BinaryOpError::tuple(op, this.len(), other.len()))
143 }
144 }
145
146 (this @ Self::Prim(_), Self::Object(other)) => {
147 let output: Result<Object<_>, _> = other
148 .into_iter()
149 .map(|(name, y)| {
150 this.clone()
151 .try_binary_op_inner(y, op, arithmetic)
152 .map(|res| (name, res))
153 })
154 .collect();
155 output.map(Self::Object)
156 }
157 (Self::Object(this), other @ Self::Prim(_)) => {
158 let output: Result<Object<_>, _> = this
159 .into_iter()
160 .map(|(name, x)| {
161 x.try_binary_op_inner(other.clone(), op, arithmetic)
162 .map(|res| (name, res))
163 })
164 .collect();
165 output.map(Self::Object)
166 }
167
168 (Self::Object(this), Self::Object(mut other)) => {
169 let same_keys = this.len() == other.len()
170 && this.field_names().all(|key| other.contains_field(key));
171 if same_keys {
172 let output: Result<Object<_>, _> = this
173 .into_iter()
174 .map(|(name, x)| {
175 let y = other.remove(&name).unwrap();
176 x.try_binary_op_inner(y, op, arithmetic)
178 .map(|res| (name, res))
179 })
180 .collect();
181 output.map(Self::Object)
182 } else {
183 Err(BinaryOpError::object(op, this, other))
184 }
185 }
186
187 (Self::Prim(_) | Self::Tuple(_) | Self::Object(_), _) => {
188 Err(BinaryOpError::new(op).with_side(OpSide::Rhs))
189 }
190 _ => Err(BinaryOpError::new(op).with_side(OpSide::Lhs)),
191 }
192 }
193
194 #[inline]
195 pub(crate) fn try_binary_op(
196 module_id: &Arc<dyn ModuleId>,
197 total_span: Location,
198 lhs: Location<Self>,
199 rhs: Location<Self>,
200 op: BinaryOp,
201 arithmetic: &dyn OrdArithmetic<T>,
202 ) -> Result<Self, Error> {
203 let lhs_span = lhs.with_no_extra();
204 let rhs_span = rhs.with_no_extra();
205 lhs.extra
206 .try_binary_op_inner(rhs.extra, op, arithmetic)
207 .map_err(|e| e.span(module_id.clone(), total_span, lhs_span, rhs_span))
208 }
209}
210
211impl<T> Value<T> {
212 pub(crate) fn try_neg(self, arithmetic: &dyn OrdArithmetic<T>) -> Result<Self, ErrorKind> {
213 match self {
214 Self::Prim(val) => arithmetic
215 .neg(val)
216 .map(Self::Prim)
217 .map_err(ErrorKind::Arithmetic),
218
219 Self::Tuple(tuple) => {
220 let res: Result<Tuple<_>, _> = tuple
221 .into_iter()
222 .map(|elem| elem.try_neg(arithmetic))
223 .collect();
224 res.map(Self::Tuple)
225 }
226
227 Self::Object(object) => {
228 let res: Result<Object<_>, _> = object
229 .into_iter()
230 .map(|(name, value)| value.try_neg(arithmetic).map(|mapped| (name, mapped)))
231 .collect();
232 res.map(Self::Object)
233 }
234
235 _ => Err(ErrorKind::UnexpectedOperand {
236 op: UnaryOp::Neg.into(),
237 }),
238 }
239 }
240
241 pub(crate) fn try_not(self) -> Result<Self, ErrorKind> {
242 match self {
243 Self::Bool(val) => Ok(Self::Bool(!val)),
244 Self::Tuple(tuple) => {
245 let res: Result<Tuple<_>, _> = tuple.into_iter().map(Value::try_not).collect();
246 res.map(Self::Tuple)
247 }
248
249 _ => Err(ErrorKind::UnexpectedOperand {
250 op: UnaryOp::Not.into(),
251 }),
252 }
253 }
254
255 pub(crate) fn eq_by_arithmetic(&self, rhs: &Self, arithmetic: &dyn OrdArithmetic<T>) -> bool {
257 match (self, rhs) {
258 (Self::Prim(this), Self::Prim(other)) => arithmetic.eq(this, other),
259 (Self::Bool(this), Self::Bool(other)) => this == other,
260 (Self::Tuple(this), Self::Tuple(other)) => {
261 if this.len() == other.len() {
262 this.iter()
263 .zip(other.iter())
264 .all(|(x, y)| x.eq_by_arithmetic(y, arithmetic))
265 } else {
266 false
267 }
268 }
269 (Self::Object(this), Self::Object(other)) => this.eq_by_arithmetic(other, arithmetic),
270 (Self::Function(this), Self::Function(other)) => this.is_same_function(other),
271 (Self::Ref(this), Self::Ref(other)) => this == other,
272 _ => false,
273 }
274 }
275
276 pub(crate) fn compare(
277 module_id: &Arc<dyn ModuleId>,
278 lhs: &Location<Self>,
279 rhs: &Location<Self>,
280 op: BinaryOp,
281 arithmetic: &dyn OrdArithmetic<T>,
282 ) -> Result<Self, Error> {
283 let Value::Prim(lhs_value) = &lhs.extra else {
285 return Err(Error::new(module_id.clone(), lhs, ErrorKind::CannotCompare));
286 };
287 let Value::Prim(rhs_value) = &rhs.extra else {
288 return Err(Error::new(module_id.clone(), rhs, ErrorKind::CannotCompare));
289 };
290
291 let maybe_ordering = arithmetic.partial_cmp(lhs_value, rhs_value);
292 let cmp_result = maybe_ordering.is_some_and(|ordering| match op {
293 BinaryOp::Gt => ordering == Ordering::Greater,
294 BinaryOp::Lt => ordering == Ordering::Less,
295 BinaryOp::Ge => ordering != Ordering::Less,
296 BinaryOp::Le => ordering != Ordering::Greater,
297 _ => unreachable!(),
298 });
299 Ok(Value::Bool(cmp_result))
300 }
301
302 pub(crate) fn try_and(
303 module_id: &Arc<dyn ModuleId>,
304 lhs: &Location<Self>,
305 rhs: &Location<Self>,
306 ) -> Result<Self, Error> {
307 match (&lhs.extra, &rhs.extra) {
308 (Value::Bool(this), Value::Bool(other)) => Ok(Value::Bool(*this && *other)),
309 (Value::Bool(_), _) => {
310 let err = ErrorKind::UnexpectedOperand {
311 op: BinaryOp::And.into(),
312 };
313 Err(Error::new(module_id.clone(), rhs, err))
314 }
315 _ => {
316 let err = ErrorKind::UnexpectedOperand {
317 op: BinaryOp::And.into(),
318 };
319 Err(Error::new(module_id.clone(), lhs, err))
320 }
321 }
322 }
323
324 pub(crate) fn try_or(
325 module_id: &Arc<dyn ModuleId>,
326 lhs: &Location<Self>,
327 rhs: &Location<Self>,
328 ) -> Result<Self, Error> {
329 match (&lhs.extra, &rhs.extra) {
330 (Value::Bool(this), Value::Bool(other)) => Ok(Value::Bool(*this || *other)),
331 (Value::Bool(_), _) => {
332 let err = ErrorKind::UnexpectedOperand {
333 op: BinaryOp::Or.into(),
334 };
335 Err(Error::new(module_id.clone(), rhs, err))
336 }
337 _ => {
338 let err = ErrorKind::UnexpectedOperand {
339 op: BinaryOp::Or.into(),
340 };
341 Err(Error::new(module_id.clone(), lhs, err))
342 }
343 }
344 }
345}
346
347impl<T> Object<T> {
348 fn eq_by_arithmetic(&self, other: &Self, arithmetic: &dyn OrdArithmetic<T>) -> bool {
349 if self.len() == other.len() {
350 for (field_name, this_elem) in self {
351 let Some(that_elem) = other.get(field_name) else {
352 return false;
353 };
354 if !this_elem.eq_by_arithmetic(that_elem, arithmetic) {
355 return false;
356 }
357 }
358 true
359 } else {
360 false
361 }
362 }
363}