elastic_elgamal/group/ristretto.rs

1use core::convert::TryInto;
2
3use elliptic_curve::rand_core::{CryptoRng, RngCore};
4
5use crate::{
6    curve25519::{
7        constants::{RISTRETTO_BASEPOINT_POINT, RISTRETTO_BASEPOINT_TABLE},
8        ristretto::{CompressedRistretto, RistrettoPoint},
9        scalar::Scalar,
10        traits::{Identity, IsIdentity, MultiscalarMul, VartimeMultiscalarMul},
11    },
12    group::{ElementOps, Group, RandomBytesProvider, ScalarOps},
13};
14
/// Prime-order group obtained by applying the [Ristretto](https://ristretto.group/) encoding
/// to Curve25519 (also known as ristretto255).
///
/// This is a zero-sized marker type: all operations live in its trait implementations.
#[derive(Clone, Copy, Debug, Hash, PartialEq, Eq)]
#[cfg_attr(docsrs, doc(cfg(any(feature = "curve25519-dalek", feature = "curve25519-dalek-ng"))))]
pub struct Ristretto(());
22
23impl ScalarOps for Ristretto {
24    type Scalar = Scalar;
25
26    const SCALAR_SIZE: usize = 32;
27
28    fn generate_scalar<R: CryptoRng + RngCore>(rng: &mut R) -> Self::Scalar {
29        let mut scalar_bytes = [0_u8; 64];
30        rng.fill_bytes(&mut scalar_bytes[..]);
31        Scalar::from_bytes_mod_order_wide(&scalar_bytes)
32    }
33
34    fn scalar_from_random_bytes(source: RandomBytesProvider<'_>) -> Self::Scalar {
35        let mut scalar_bytes = [0_u8; 64];
36        source.fill_bytes(&mut scalar_bytes);
37        Scalar::from_bytes_mod_order_wide(&scalar_bytes)
38    }
39
40    fn invert_scalar(scalar: Self::Scalar) -> Self::Scalar {
41        scalar.invert()
42    }
43
44    #[cfg(feature = "curve25519-dalek")]
45    fn invert_scalars(scalars: &mut [Self::Scalar]) {
46        Scalar::invert_batch_alloc(scalars);
47    }
48
49    #[cfg(feature = "curve25519-dalek-ng")]
50    fn invert_scalars(scalars: &mut [Self::Scalar]) {
51        Scalar::batch_invert(scalars);
52    }
53
54    fn serialize_scalar(scalar: &Self::Scalar, buffer: &mut [u8]) {
55        buffer.copy_from_slice(&scalar.to_bytes());
56    }
57
58    #[cfg(feature = "curve25519-dalek")]
59    fn deserialize_scalar(buffer: &[u8]) -> Option<Self::Scalar> {
60        let bytes: &[u8; 32] = buffer.try_into().expect("input has incorrect byte size");
61        Scalar::from_canonical_bytes(*bytes).into()
62    }
63
64    #[cfg(feature = "curve25519-dalek-ng")]
65    fn deserialize_scalar(buffer: &[u8]) -> Option<Self::Scalar> {
66        let bytes: &[u8; 32] = buffer.try_into().expect("input has incorrect byte size");
67        Scalar::from_canonical_bytes(*bytes)
68    }
69}
70
71impl ElementOps for Ristretto {
72    type Element = RistrettoPoint;
73
74    const ELEMENT_SIZE: usize = 32;
75
76    fn identity() -> Self::Element {
77        RistrettoPoint::identity()
78    }
79
80    fn is_identity(element: &Self::Element) -> bool {
81        element.is_identity()
82    }
83
84    fn generator() -> Self::Element {
85        RISTRETTO_BASEPOINT_POINT
86    }
87
88    fn serialize_element(element: &Self::Element, buffer: &mut [u8]) {
89        buffer.copy_from_slice(&element.compress().to_bytes());
90    }
91
92    #[cfg(feature = "curve25519-dalek")]
93    fn deserialize_element(buffer: &[u8]) -> Option<Self::Element> {
94        CompressedRistretto::from_slice(buffer).ok()?.decompress()
95    }
96
97    #[cfg(feature = "curve25519-dalek-ng")]
98    fn deserialize_element(buffer: &[u8]) -> Option<Self::Element> {
99        CompressedRistretto::from_slice(buffer).decompress()
100    }
101}
102
103impl Group for Ristretto {
104    #[cfg(feature = "curve25519-dalek")]
105    fn mul_generator(k: &Scalar) -> Self::Element {
106        k * RISTRETTO_BASEPOINT_TABLE
107    }
108
109    #[cfg(feature = "curve25519-dalek-ng")]
110    fn mul_generator(k: &Scalar) -> Self::Element {
111        k * &RISTRETTO_BASEPOINT_TABLE
112    }
113
114    fn vartime_mul_generator(k: &Scalar) -> Self::Element {
115        #[cfg(feature = "curve25519-dalek")]
116        let zero = Scalar::ZERO;
117        #[cfg(feature = "curve25519-dalek-ng")]
118        let zero = Scalar::zero();
119
120        RistrettoPoint::vartime_double_scalar_mul_basepoint(&zero, &RistrettoPoint::identity(), k)
121    }
122
123    fn multi_mul<'a, I, J>(scalars: I, elements: J) -> Self::Element
124    where
125        I: IntoIterator<Item = &'a Self::Scalar>,
126        J: IntoIterator<Item = Self::Element>,
127    {
128        RistrettoPoint::multiscalar_mul(scalars, elements)
129    }
130
131    fn vartime_double_mul_generator(
132        k: &Scalar,
133        k_element: Self::Element,
134        r: &Scalar,
135    ) -> Self::Element {
136        RistrettoPoint::vartime_double_scalar_mul_basepoint(k, &k_element, r)
137    }
138
139    fn vartime_multi_mul<'a, I, J>(scalars: I, elements: J) -> Self::Element
140    where
141        I: IntoIterator<Item = &'a Self::Scalar>,
142        J: IntoIterator<Item = Self::Element>,
143    {
144        RistrettoPoint::vartime_multiscalar_mul(scalars, elements)
145    }
146}
147
#[cfg(test)]
mod tests {
    use super::*;
    use crate::{
        DiscreteLogTable,
        app::{ChoiceParams, EncryptedChoice},
        group::Curve25519Subgroup,
    };

    type SecretKey = crate::SecretKey<Ristretto>;
    type Keypair = crate::Keypair<Ristretto>;

    /// Decrypting an encrypted scalar to an element yields `scalar * G`.
    #[test]
    fn encrypt_and_decrypt() {
        let mut rng = rand::rng();
        let keypair = Keypair::generate(&mut rng);
        let plaintext = Ristretto::generate_scalar(&mut rng);

        let ciphertext = keypair.public().encrypt(plaintext, &mut rng);
        let expected = Ristretto::vartime_mul_generator(&plaintext);
        assert_eq!(keypair.secret().decrypt_to_element(ciphertext), expected);
    }

    /// A single-choice encryption over 5 options decrypts to 1 only at the chosen index.
    #[test]
    fn encrypt_choice() {
        let mut rng = rand::rng();
        let (public_key, secret_key) = Keypair::generate(&mut rng).into_tuple();
        let params = ChoiceParams::single(public_key, 5);

        let encrypted = EncryptedChoice::single(&params, 3, &mut rng);
        let verified = encrypted.verify(&params).unwrap();

        let table = DiscreteLogTable::new(0..=1);
        for (index, &ciphertext) in verified.iter().enumerate() {
            let expected = u64::from(index == 3);
            assert_eq!(secret_key.decrypt(ciphertext, &table).unwrap(), expected);
        }
    }

    /// The same secret scalar produces different public keys in Ristretto and the
    /// Curve25519 prime-order subgroup.
    #[test]
    fn edwards_and_ristretto_public_keys_differ() {
        type SubgroupSecretKey = crate::SecretKey<Curve25519Subgroup>;
        type SubgroupKeypair = crate::Keypair<Curve25519Subgroup>;

        let mut rng = rand::rng();
        for _ in 0..1_000 {
            let ristretto_secret = SecretKey::generate(&mut rng);
            // Copy the scalar out first so the secret key can be moved instead of cloned.
            let scalar = *ristretto_secret.expose_scalar();
            let ristretto_keypair = Keypair::from(ristretto_secret);
            let edwards_keypair = SubgroupKeypair::from(SubgroupSecretKey::new(scalar));

            assert_ne!(
                ristretto_keypair.public().as_bytes(),
                edwards_keypair.public().as_bytes()
            );
        }
    }
}