arithmetic_eval/values/
ops.rs

1//! Operations on `Value`s.
2
3use core::cmp::Ordering;
4
5use arithmetic_parser::{BinaryOp, Location, Op, UnaryOp};
6
7use crate::{
8    alloc::Arc,
9    arith::OrdArithmetic,
10    error::{AuxErrorInfo, Error, ErrorKind, TupleLenMismatchContext},
11    exec::ModuleId,
12    Object, Tuple, Value,
13};
14
/// Operand of a binary operation that an error is attributed to.
#[derive(Clone, Copy, Debug)]
enum OpSide {
    /// The left-hand side operand.
    Lhs,
    /// The right-hand side operand.
    Rhs,
}
20
21#[derive(Debug)]
22struct BinaryOpError {
23    inner: ErrorKind,
24    side: Option<OpSide>,
25}
26
27impl BinaryOpError {
28    fn new(op: BinaryOp) -> Self {
29        Self {
30            inner: ErrorKind::UnexpectedOperand { op: Op::Binary(op) },
31            side: None,
32        }
33    }
34
35    fn tuple(op: BinaryOp, lhs: usize, rhs: usize) -> Self {
36        Self {
37            inner: ErrorKind::TupleLenMismatch {
38                lhs: lhs.into(),
39                rhs,
40                context: TupleLenMismatchContext::BinaryOp(op),
41            },
42            side: Some(OpSide::Lhs),
43        }
44    }
45
46    fn object<T>(op: BinaryOp, lhs: Object<T>, rhs: Object<T>) -> Self {
47        Self {
48            inner: ErrorKind::FieldsMismatch {
49                lhs_fields: lhs.into_iter().map(|(name, _)| name).collect(),
50                rhs_fields: rhs.into_iter().map(|(name, _)| name).collect(),
51                op,
52            },
53            side: Some(OpSide::Lhs),
54        }
55    }
56
57    fn with_side(mut self, side: OpSide) -> Self {
58        self.side = Some(side);
59        self
60    }
61
62    fn with_error_kind(mut self, error_kind: ErrorKind) -> Self {
63        self.inner = error_kind;
64        self
65    }
66
67    fn span(
68        self,
69        module_id: Arc<dyn ModuleId>,
70        total_span: Location,
71        lhs_span: Location,
72        rhs_span: Location,
73    ) -> Error {
74        let main_span = match self.side {
75            Some(OpSide::Lhs) => lhs_span,
76            Some(OpSide::Rhs) => rhs_span,
77            None => total_span,
78        };
79
80        let aux_info = match &self.inner {
81            ErrorKind::TupleLenMismatch { rhs, .. } => Some(AuxErrorInfo::UnbalancedRhsTuple(*rhs)),
82            ErrorKind::FieldsMismatch { rhs_fields, .. } => {
83                Some(AuxErrorInfo::UnbalancedRhsObject(rhs_fields.clone()))
84            }
85            _ => None,
86        };
87
88        let mut err = Error::new(module_id, &main_span, self.inner);
89        if let Some(aux_info) = aux_info {
90            err = err.with_location(&rhs_span, aux_info);
91        }
92        err
93    }
94}
95
96impl<T: Clone> Value<T> {
97    fn try_binary_op_inner(
98        self,
99        rhs: Self,
100        op: BinaryOp,
101        arithmetic: &dyn OrdArithmetic<T>,
102    ) -> Result<Self, BinaryOpError> {
103        match (self, rhs) {
104            (Self::Prim(this), Self::Prim(other)) => {
105                let op_result = match op {
106                    BinaryOp::Add => arithmetic.add(this, other),
107                    BinaryOp::Sub => arithmetic.sub(this, other),
108                    BinaryOp::Mul => arithmetic.mul(this, other),
109                    BinaryOp::Div => arithmetic.div(this, other),
110                    BinaryOp::Power => arithmetic.pow(this, other),
111                    _ => unreachable!(),
112                };
113                op_result
114                    .map(Self::Prim)
115                    .map_err(|e| BinaryOpError::new(op).with_error_kind(ErrorKind::Arithmetic(e)))
116            }
117
118            (this @ Self::Prim(_), Self::Tuple(other)) => {
119                let output: Result<Tuple<_>, _> = other
120                    .into_iter()
121                    .map(|y| this.clone().try_binary_op_inner(y, op, arithmetic))
122                    .collect();
123                output.map(Self::Tuple)
124            }
125            (Self::Tuple(this), other @ Self::Prim(_)) => {
126                let output: Result<Tuple<_>, _> = this
127                    .into_iter()
128                    .map(|x| x.try_binary_op_inner(other.clone(), op, arithmetic))
129                    .collect();
130                output.map(Self::Tuple)
131            }
132
133            (Self::Tuple(this), Self::Tuple(other)) => {
134                if this.len() == other.len() {
135                    let output: Result<Tuple<_>, _> = this
136                        .into_iter()
137                        .zip(other)
138                        .map(|(x, y)| x.try_binary_op_inner(y, op, arithmetic))
139                        .collect();
140                    output.map(Self::Tuple)
141                } else {
142                    Err(BinaryOpError::tuple(op, this.len(), other.len()))
143                }
144            }
145
146            (this @ Self::Prim(_), Self::Object(other)) => {
147                let output: Result<Object<_>, _> = other
148                    .into_iter()
149                    .map(|(name, y)| {
150                        this.clone()
151                            .try_binary_op_inner(y, op, arithmetic)
152                            .map(|res| (name, res))
153                    })
154                    .collect();
155                output.map(Self::Object)
156            }
157            (Self::Object(this), other @ Self::Prim(_)) => {
158                let output: Result<Object<_>, _> = this
159                    .into_iter()
160                    .map(|(name, x)| {
161                        x.try_binary_op_inner(other.clone(), op, arithmetic)
162                            .map(|res| (name, res))
163                    })
164                    .collect();
165                output.map(Self::Object)
166            }
167
168            (Self::Object(this), Self::Object(mut other)) => {
169                let same_keys = this.len() == other.len()
170                    && this.field_names().all(|key| other.contains_field(key));
171                if same_keys {
172                    let output: Result<Object<_>, _> = this
173                        .into_iter()
174                        .map(|(name, x)| {
175                            let y = other.remove(&name).unwrap();
176                            // ^ `unwrap` safety was checked previously
177                            x.try_binary_op_inner(y, op, arithmetic)
178                                .map(|res| (name, res))
179                        })
180                        .collect();
181                    output.map(Self::Object)
182                } else {
183                    Err(BinaryOpError::object(op, this, other))
184                }
185            }
186
187            (Self::Prim(_) | Self::Tuple(_) | Self::Object(_), _) => {
188                Err(BinaryOpError::new(op).with_side(OpSide::Rhs))
189            }
190            _ => Err(BinaryOpError::new(op).with_side(OpSide::Lhs)),
191        }
192    }
193
194    #[inline]
195    pub(crate) fn try_binary_op(
196        module_id: &Arc<dyn ModuleId>,
197        total_span: Location,
198        lhs: Location<Self>,
199        rhs: Location<Self>,
200        op: BinaryOp,
201        arithmetic: &dyn OrdArithmetic<T>,
202    ) -> Result<Self, Error> {
203        let lhs_span = lhs.with_no_extra();
204        let rhs_span = rhs.with_no_extra();
205        lhs.extra
206            .try_binary_op_inner(rhs.extra, op, arithmetic)
207            .map_err(|e| e.span(module_id.clone(), total_span, lhs_span, rhs_span))
208    }
209}
210
211impl<T> Value<T> {
212    pub(crate) fn try_neg(self, arithmetic: &dyn OrdArithmetic<T>) -> Result<Self, ErrorKind> {
213        match self {
214            Self::Prim(val) => arithmetic
215                .neg(val)
216                .map(Self::Prim)
217                .map_err(ErrorKind::Arithmetic),
218
219            Self::Tuple(tuple) => {
220                let res: Result<Tuple<_>, _> = tuple
221                    .into_iter()
222                    .map(|elem| elem.try_neg(arithmetic))
223                    .collect();
224                res.map(Self::Tuple)
225            }
226
227            Self::Object(object) => {
228                let res: Result<Object<_>, _> = object
229                    .into_iter()
230                    .map(|(name, value)| value.try_neg(arithmetic).map(|mapped| (name, mapped)))
231                    .collect();
232                res.map(Self::Object)
233            }
234
235            _ => Err(ErrorKind::UnexpectedOperand {
236                op: UnaryOp::Neg.into(),
237            }),
238        }
239    }
240
241    pub(crate) fn try_not(self) -> Result<Self, ErrorKind> {
242        match self {
243            Self::Bool(val) => Ok(Self::Bool(!val)),
244            Self::Tuple(tuple) => {
245                let res: Result<Tuple<_>, _> = tuple.into_iter().map(Value::try_not).collect();
246                res.map(Self::Tuple)
247            }
248
249            _ => Err(ErrorKind::UnexpectedOperand {
250                op: UnaryOp::Not.into(),
251            }),
252        }
253    }
254
255    // **NB.** Must match `PartialEq` impl for `Value`!
256    pub(crate) fn eq_by_arithmetic(&self, rhs: &Self, arithmetic: &dyn OrdArithmetic<T>) -> bool {
257        match (self, rhs) {
258            (Self::Prim(this), Self::Prim(other)) => arithmetic.eq(this, other),
259            (Self::Bool(this), Self::Bool(other)) => this == other,
260            (Self::Tuple(this), Self::Tuple(other)) => {
261                if this.len() == other.len() {
262                    this.iter()
263                        .zip(other.iter())
264                        .all(|(x, y)| x.eq_by_arithmetic(y, arithmetic))
265                } else {
266                    false
267                }
268            }
269            (Self::Object(this), Self::Object(other)) => this.eq_by_arithmetic(other, arithmetic),
270            (Self::Function(this), Self::Function(other)) => this.is_same_function(other),
271            (Self::Ref(this), Self::Ref(other)) => this == other,
272            _ => false,
273        }
274    }
275
276    pub(crate) fn compare(
277        module_id: &Arc<dyn ModuleId>,
278        lhs: &Location<Self>,
279        rhs: &Location<Self>,
280        op: BinaryOp,
281        arithmetic: &dyn OrdArithmetic<T>,
282    ) -> Result<Self, Error> {
283        // We only know how to compare primitive values.
284        let Value::Prim(lhs_value) = &lhs.extra else {
285            return Err(Error::new(module_id.clone(), lhs, ErrorKind::CannotCompare));
286        };
287        let Value::Prim(rhs_value) = &rhs.extra else {
288            return Err(Error::new(module_id.clone(), rhs, ErrorKind::CannotCompare));
289        };
290
291        let maybe_ordering = arithmetic.partial_cmp(lhs_value, rhs_value);
292        let cmp_result = maybe_ordering.is_some_and(|ordering| match op {
293            BinaryOp::Gt => ordering == Ordering::Greater,
294            BinaryOp::Lt => ordering == Ordering::Less,
295            BinaryOp::Ge => ordering != Ordering::Less,
296            BinaryOp::Le => ordering != Ordering::Greater,
297            _ => unreachable!(),
298        });
299        Ok(Value::Bool(cmp_result))
300    }
301
302    pub(crate) fn try_and(
303        module_id: &Arc<dyn ModuleId>,
304        lhs: &Location<Self>,
305        rhs: &Location<Self>,
306    ) -> Result<Self, Error> {
307        match (&lhs.extra, &rhs.extra) {
308            (Value::Bool(this), Value::Bool(other)) => Ok(Value::Bool(*this && *other)),
309            (Value::Bool(_), _) => {
310                let err = ErrorKind::UnexpectedOperand {
311                    op: BinaryOp::And.into(),
312                };
313                Err(Error::new(module_id.clone(), rhs, err))
314            }
315            _ => {
316                let err = ErrorKind::UnexpectedOperand {
317                    op: BinaryOp::And.into(),
318                };
319                Err(Error::new(module_id.clone(), lhs, err))
320            }
321        }
322    }
323
324    pub(crate) fn try_or(
325        module_id: &Arc<dyn ModuleId>,
326        lhs: &Location<Self>,
327        rhs: &Location<Self>,
328    ) -> Result<Self, Error> {
329        match (&lhs.extra, &rhs.extra) {
330            (Value::Bool(this), Value::Bool(other)) => Ok(Value::Bool(*this || *other)),
331            (Value::Bool(_), _) => {
332                let err = ErrorKind::UnexpectedOperand {
333                    op: BinaryOp::Or.into(),
334                };
335                Err(Error::new(module_id.clone(), rhs, err))
336            }
337            _ => {
338                let err = ErrorKind::UnexpectedOperand {
339                    op: BinaryOp::Or.into(),
340                };
341                Err(Error::new(module_id.clone(), lhs, err))
342            }
343        }
344    }
345}
346
347impl<T> Object<T> {
348    fn eq_by_arithmetic(&self, other: &Self, arithmetic: &dyn OrdArithmetic<T>) -> bool {
349        if self.len() == other.len() {
350            for (field_name, this_elem) in self {
351                let Some(that_elem) = other.get(field_name) else {
352                    return false;
353                };
354                if !this_elem.eq_by_arithmetic(that_elem, arithmetic) {
355                    return false;
356                }
357            }
358            true
359        } else {
360            false
361        }
362    }
363}