arithmetic_eval/values/ops.rs

1//! Operations on `Value`s.
2
3use core::cmp::Ordering;
4
5use arithmetic_parser::{BinaryOp, Location, Op, UnaryOp};
6
7use crate::{
8    Object, Tuple, Value,
9    alloc::Arc,
10    arith::OrdArithmetic,
11    error::{AuxErrorInfo, Error, ErrorKind, TupleLenMismatchContext},
12    exec::ModuleId,
13};
14
/// Operand of a binary operation that an error should be attributed to.
#[derive(Clone, Copy, Debug)]
enum OpSide {
    /// The left-hand side operand.
    Lhs,
    /// The right-hand side operand.
    Rhs,
}
20
21#[derive(Debug)]
22struct BinaryOpError {
23    inner: ErrorKind,
24    side: Option<OpSide>,
25}
26
27impl BinaryOpError {
28    fn new(op: BinaryOp) -> Self {
29        Self {
30            inner: ErrorKind::UnexpectedOperand { op: Op::Binary(op) },
31            side: None,
32        }
33    }
34
35    fn tuple(op: BinaryOp, lhs: usize, rhs: usize) -> Self {
36        Self {
37            inner: ErrorKind::TupleLenMismatch {
38                lhs: lhs.into(),
39                rhs,
40                context: TupleLenMismatchContext::BinaryOp(op),
41            },
42            side: Some(OpSide::Lhs),
43        }
44    }
45
46    fn object<T>(op: BinaryOp, lhs: Object<T>, rhs: Object<T>) -> Self {
47        Self {
48            inner: ErrorKind::FieldsMismatch {
49                lhs_fields: lhs.into_iter().map(|(name, _)| name).collect(),
50                rhs_fields: rhs.into_iter().map(|(name, _)| name).collect(),
51                op,
52            },
53            side: Some(OpSide::Lhs),
54        }
55    }
56
57    fn with_side(mut self, side: OpSide) -> Self {
58        self.side = Some(side);
59        self
60    }
61
62    fn with_error_kind(mut self, error_kind: ErrorKind) -> Self {
63        self.inner = error_kind;
64        self
65    }
66
67    fn span(
68        self,
69        module_id: Arc<dyn ModuleId>,
70        total_span: Location,
71        lhs_span: Location,
72        rhs_span: Location,
73    ) -> Error {
74        let main_span = match self.side {
75            Some(OpSide::Lhs) => lhs_span,
76            Some(OpSide::Rhs) => rhs_span,
77            None => total_span,
78        };
79
80        let aux_info = match &self.inner {
81            ErrorKind::TupleLenMismatch { rhs, .. } => Some(AuxErrorInfo::UnbalancedRhsTuple(*rhs)),
82            ErrorKind::FieldsMismatch { rhs_fields, .. } => {
83                Some(AuxErrorInfo::UnbalancedRhsObject(rhs_fields.clone()))
84            }
85            _ => None,
86        };
87
88        let mut err = Error::new(module_id, &main_span, self.inner);
89        if let Some(aux_info) = aux_info {
90            err = err.with_location(&rhs_span, aux_info);
91        }
92        err
93    }
94}
95
96impl<T: Clone> Value<T> {
97    fn try_binary_op_inner(
98        self,
99        rhs: Self,
100        op: BinaryOp,
101        arithmetic: &dyn OrdArithmetic<T>,
102    ) -> Result<Self, BinaryOpError> {
103        match (self, rhs) {
104            (Self::Prim(this), Self::Prim(other)) => {
105                let op_result = match op {
106                    BinaryOp::Add => arithmetic.add(this, other),
107                    BinaryOp::Sub => arithmetic.sub(this, other),
108                    BinaryOp::Mul => arithmetic.mul(this, other),
109                    BinaryOp::Div => arithmetic.div(this, other),
110                    BinaryOp::Power => arithmetic.pow(this, other),
111                    _ => unreachable!(),
112                };
113                op_result
114                    .map(Self::Prim)
115                    .map_err(|e| BinaryOpError::new(op).with_error_kind(ErrorKind::Arithmetic(e)))
116            }
117
118            (this @ Self::Prim(_), Self::Tuple(other)) => {
119                let output: Result<Tuple<_>, _> = other
120                    .into_iter()
121                    .map(|y| this.clone().try_binary_op_inner(y, op, arithmetic))
122                    .collect();
123                output.map(Self::Tuple)
124            }
125            (Self::Tuple(this), other @ Self::Prim(_)) => {
126                let output: Result<Tuple<_>, _> = this
127                    .into_iter()
128                    .map(|x| x.try_binary_op_inner(other.clone(), op, arithmetic))
129                    .collect();
130                output.map(Self::Tuple)
131            }
132
133            (Self::Tuple(this), Self::Tuple(other)) => {
134                if this.len() == other.len() {
135                    let output: Result<Tuple<_>, _> = this
136                        .into_iter()
137                        .zip(other)
138                        .map(|(x, y)| x.try_binary_op_inner(y, op, arithmetic))
139                        .collect();
140                    output.map(Self::Tuple)
141                } else {
142                    Err(BinaryOpError::tuple(op, this.len(), other.len()))
143                }
144            }
145
146            (this @ Self::Prim(_), Self::Object(other)) => {
147                let output: Result<Object<_>, _> = other
148                    .into_iter()
149                    .map(|(name, y)| {
150                        this.clone()
151                            .try_binary_op_inner(y, op, arithmetic)
152                            .map(|res| (name, res))
153                    })
154                    .collect();
155                output.map(Self::Object)
156            }
157            (Self::Object(this), other @ Self::Prim(_)) => {
158                let output: Result<Object<_>, _> = this
159                    .into_iter()
160                    .map(|(name, x)| {
161                        x.try_binary_op_inner(other.clone(), op, arithmetic)
162                            .map(|res| (name, res))
163                    })
164                    .collect();
165                output.map(Self::Object)
166            }
167
168            (Self::Object(this), Self::Object(mut other)) => {
169                let same_keys = this.len() == other.len()
170                    && this.field_names().all(|key| other.contains_field(key));
171                if same_keys {
172                    let output: Result<Object<_>, _> = this
173                        .into_iter()
174                        .map(|(name, x)| {
175                            let y = other.remove(&name).unwrap();
176                            // ^ `unwrap` safety was checked previously
177                            x.try_binary_op_inner(y, op, arithmetic)
178                                .map(|res| (name, res))
179                        })
180                        .collect();
181                    output.map(Self::Object)
182                } else {
183                    Err(BinaryOpError::object(op, this, other))
184                }
185            }
186
187            (Self::Prim(_) | Self::Tuple(_) | Self::Object(_), _) => {
188                Err(BinaryOpError::new(op).with_side(OpSide::Rhs))
189            }
190            _ => Err(BinaryOpError::new(op).with_side(OpSide::Lhs)),
191        }
192    }
193
194    #[inline]
195    pub(crate) fn try_binary_op(
196        module_id: &Arc<dyn ModuleId>,
197        total_span: Location,
198        lhs: Location<Self>,
199        rhs: Location<Self>,
200        op: BinaryOp,
201        arithmetic: &dyn OrdArithmetic<T>,
202    ) -> Result<Self, Error> {
203        let lhs_span = lhs.with_no_extra();
204        let rhs_span = rhs.with_no_extra();
205        lhs.extra
206            .try_binary_op_inner(rhs.extra, op, arithmetic)
207            .map_err(|e| e.span(module_id.clone(), total_span, lhs_span, rhs_span))
208    }
209}
210
211impl<T> Value<T> {
212    pub(crate) fn try_neg(self, arithmetic: &dyn OrdArithmetic<T>) -> Result<Self, ErrorKind> {
213        match self {
214            Self::Prim(val) => arithmetic
215                .neg(val)
216                .map(Self::Prim)
217                .map_err(ErrorKind::Arithmetic),
218
219            Self::Tuple(tuple) => {
220                let res: Result<Tuple<_>, _> = tuple
221                    .into_iter()
222                    .map(|elem| elem.try_neg(arithmetic))
223                    .collect();
224                res.map(Self::Tuple)
225            }
226
227            Self::Object(object) => {
228                let res: Result<Object<_>, _> = object
229                    .into_iter()
230                    .map(|(name, value)| value.try_neg(arithmetic).map(|mapped| (name, mapped)))
231                    .collect();
232                res.map(Self::Object)
233            }
234
235            _ => Err(ErrorKind::UnexpectedOperand {
236                op: UnaryOp::Neg.into(),
237            }),
238        }
239    }
240
241    pub(crate) fn try_not(self) -> Result<Self, ErrorKind> {
242        match self {
243            Self::Bool(val) => Ok(Self::Bool(!val)),
244            Self::Tuple(tuple) => {
245                let res: Result<Tuple<_>, _> = tuple.into_iter().map(Value::try_not).collect();
246                res.map(Self::Tuple)
247            }
248
249            _ => Err(ErrorKind::UnexpectedOperand {
250                op: UnaryOp::Not.into(),
251            }),
252        }
253    }
254
255    // **NB.** Must match `PartialEq` impl for `Value`!
256    pub(crate) fn eq_by_arithmetic(&self, rhs: &Self, arithmetic: &dyn OrdArithmetic<T>) -> bool {
257        match (self, rhs) {
258            (Self::Prim(this), Self::Prim(other)) => arithmetic.eq(this, other),
259            (Self::Bool(this), Self::Bool(other)) => this == other,
260            (Self::Tuple(this), Self::Tuple(other)) if this.len() == other.len() => this
261                .iter()
262                .zip(other.iter())
263                .all(|(x, y)| x.eq_by_arithmetic(y, arithmetic)),
264            (Self::Object(this), Self::Object(other)) => this.eq_by_arithmetic(other, arithmetic),
265            (Self::Function(this), Self::Function(other)) => this.is_same_function(other),
266            (Self::Ref(this), Self::Ref(other)) => this == other,
267            _ => false,
268        }
269    }
270
271    pub(crate) fn compare(
272        module_id: &Arc<dyn ModuleId>,
273        lhs: &Location<Self>,
274        rhs: &Location<Self>,
275        op: BinaryOp,
276        arithmetic: &dyn OrdArithmetic<T>,
277    ) -> Result<Self, Error> {
278        // We only know how to compare primitive values.
279        let Value::Prim(lhs_value) = &lhs.extra else {
280            return Err(Error::new(module_id.clone(), lhs, ErrorKind::CannotCompare));
281        };
282        let Value::Prim(rhs_value) = &rhs.extra else {
283            return Err(Error::new(module_id.clone(), rhs, ErrorKind::CannotCompare));
284        };
285
286        let maybe_ordering = arithmetic.partial_cmp(lhs_value, rhs_value);
287        let cmp_result = maybe_ordering.is_some_and(|ordering| match op {
288            BinaryOp::Gt => ordering == Ordering::Greater,
289            BinaryOp::Lt => ordering == Ordering::Less,
290            BinaryOp::Ge => ordering != Ordering::Less,
291            BinaryOp::Le => ordering != Ordering::Greater,
292            _ => unreachable!(),
293        });
294        Ok(Value::Bool(cmp_result))
295    }
296
297    pub(crate) fn try_and(
298        module_id: &Arc<dyn ModuleId>,
299        lhs: &Location<Self>,
300        rhs: &Location<Self>,
301    ) -> Result<Self, Error> {
302        match (&lhs.extra, &rhs.extra) {
303            (Value::Bool(this), Value::Bool(other)) => Ok(Value::Bool(*this && *other)),
304            (Value::Bool(_), _) => {
305                let err = ErrorKind::UnexpectedOperand {
306                    op: BinaryOp::And.into(),
307                };
308                Err(Error::new(module_id.clone(), rhs, err))
309            }
310            _ => {
311                let err = ErrorKind::UnexpectedOperand {
312                    op: BinaryOp::And.into(),
313                };
314                Err(Error::new(module_id.clone(), lhs, err))
315            }
316        }
317    }
318
319    pub(crate) fn try_or(
320        module_id: &Arc<dyn ModuleId>,
321        lhs: &Location<Self>,
322        rhs: &Location<Self>,
323    ) -> Result<Self, Error> {
324        match (&lhs.extra, &rhs.extra) {
325            (Value::Bool(this), Value::Bool(other)) => Ok(Value::Bool(*this || *other)),
326            (Value::Bool(_), _) => {
327                let err = ErrorKind::UnexpectedOperand {
328                    op: BinaryOp::Or.into(),
329                };
330                Err(Error::new(module_id.clone(), rhs, err))
331            }
332            _ => {
333                let err = ErrorKind::UnexpectedOperand {
334                    op: BinaryOp::Or.into(),
335                };
336                Err(Error::new(module_id.clone(), lhs, err))
337            }
338        }
339    }
340}
341
342impl<T> Object<T> {
343    fn eq_by_arithmetic(&self, other: &Self, arithmetic: &dyn OrdArithmetic<T>) -> bool {
344        if self.len() == other.len() {
345            for (field_name, this_elem) in self {
346                let Some(that_elem) = other.get(field_name) else {
347                    return false;
348                };
349                if !this_elem.eq_by_arithmetic(that_elem, arithmetic) {
350                    return false;
351                }
352            }
353            true
354        } else {
355            false
356        }
357    }
358}