jwt_compact/alg/
rsa.rs

1//! RSA-based JWT algorithms: `RS*` and `PS*`.
2
3use core::{fmt, str::FromStr};
4
5use rand_core::{CryptoRng, RngCore};
6use rsa::{
7    BoxedUint, Pkcs1v15Sign, Pss,
8    traits::{PrivateKeyParts, PublicKeyParts},
9};
10pub use rsa::{RsaPrivateKey, RsaPublicKey, errors::Error as RsaError};
11use sha2::{Digest, Sha256, Sha384, Sha512};
12
13use crate::{
14    Algorithm, AlgorithmSignature,
15    alg::{SecretBytes, StrongKey, WeakKeyError},
16    alloc::{Cow, String, ToOwned, Vec},
17    jwk::{JsonWebKey, JwkError, KeyType, RsaPrimeFactor, RsaPrivateParts},
18};
19
/// RSA signature.
///
/// Thin wrapper around the raw signature bytes. The bytes are stored exactly as provided
/// and are only interpreted when the signature is checked against a verifying key.
#[derive(Debug)]
#[cfg_attr(docsrs, doc(cfg(feature = "rsa")))]
pub struct RsaSignature(Vec<u8>);
24
25impl AlgorithmSignature for RsaSignature {
26    fn try_from_slice(bytes: &[u8]) -> anyhow::Result<Self> {
27        Ok(RsaSignature(bytes.to_vec()))
28    }
29
30    fn as_bytes(&self) -> Cow<'_, [u8]> {
31        Cow::Borrowed(&self.0)
32    }
33}
34
/// RSA hash algorithm.
///
/// Selects the SHA-2 function applied to the message before it is signed or verified.
#[derive(Debug, Copy, Clone, Eq, PartialEq)]
enum HashAlg {
    /// SHA-256 (32-byte digest).
    Sha256,
    /// SHA-384 (48-byte digest).
    Sha384,
    /// SHA-512 (64-byte digest).
    Sha512,
}
42
43impl HashAlg {
44    fn digest(self, message: &[u8]) -> HashDigest {
45        match self {
46            Self::Sha256 => HashDigest::Sha256(Sha256::digest(message).into()),
47            Self::Sha384 => HashDigest::Sha384(Sha384::digest(message).into()),
48            Self::Sha512 => HashDigest::Sha512(Sha512::digest(message).into()),
49        }
50    }
51}
52
/// Output of a [`HashAlg`].
///
/// Each variant stores a fixed-size digest whose length matches the corresponding
/// hash function (32, 48 or 64 bytes).
#[derive(Debug)]
enum HashDigest {
    Sha256([u8; 32]),
    Sha384([u8; 48]),
    Sha512([u8; 64]),
}

impl AsRef<[u8]> for HashDigest {
    /// Views the digest as a byte slice regardless of the hash function that produced it.
    fn as_ref(&self) -> &[u8] {
        let digest_bytes: &[u8] = match self {
            HashDigest::Sha256(array) => &array[..],
            HashDigest::Sha384(array) => &array[..],
            HashDigest::Sha512(array) => &array[..],
        };
        digest_bytes
    }
}
70
/// RSA padding algorithm.
#[derive(Debug, Copy, Clone, Eq, PartialEq)]
enum Padding {
    /// PKCS#1 v1.5 padding (used by the `RS*` algorithms).
    Pkcs1v15,
    /// PSS padding (used by the `PS*` algorithms).
    Pss,
}
77
/// Padding scheme instantiated for a concrete hash function.
///
/// Unlike [`Padding`], which only names the scheme, each variant holds the scheme object
/// from the `rsa` crate that is handed to the signing / verification calls.
#[derive(Debug)]
enum PaddingScheme {
    Pkcs1v15(Pkcs1v15Sign),
    Pss(Pss),
}
83
/// Bit length of an RSA key modulus (aka RSA key length).
///
/// Only the three lengths listed below are supported.
#[derive(Debug, Copy, Clone, Eq, PartialEq)]
#[non_exhaustive]
#[cfg_attr(docsrs, doc(cfg(feature = "rsa")))]
pub enum ModulusBits {
    /// 2048 bits. This is the minimum recommended key length as of 2020.
    TwoKibibytes,
    /// 3072 bits.
    ThreeKibibytes,
    /// 4096 bits.
    FourKibibytes,
}

impl ModulusBits {
    /// Converts this length to the numeric value.
    pub fn bits(self) -> usize {
        // Every supported length is a whole multiple of 1024 bits.
        let kibibits: usize = match self {
            ModulusBits::TwoKibibytes => 2,
            ModulusBits::ThreeKibibytes => 3,
            ModulusBits::FourKibibytes => 4,
        };
        kibibits * 1_024
    }

    /// Checks whether `bits` is exactly one of the supported modulus lengths.
    fn is_valid_bits(bits: u32) -> bool {
        [2_048, 3_072, 4_096].contains(&bits)
    }
}
111
112impl TryFrom<usize> for ModulusBits {
113    type Error = ModulusBitsError;
114
115    fn try_from(value: usize) -> Result<Self, Self::Error> {
116        match value {
117            2_048 => Ok(Self::TwoKibibytes),
118            3_072 => Ok(Self::ThreeKibibytes),
119            4_096 => Ok(Self::FourKibibytes),
120            _ => Err(ModulusBitsError(())),
121        }
122    }
123}
124
/// Error type returned when a conversion of an integer into `ModulusBits` fails.
///
/// The error carries no payload; its message lists the supported lengths.
#[derive(Debug)]
#[cfg_attr(docsrs, doc(cfg(feature = "rsa")))]
pub struct ModulusBitsError(());

impl fmt::Display for ModulusBitsError {
    fn fmt(&self, formatter: &mut fmt::Formatter<'_>) -> fmt::Result {
        const MESSAGE: &str = "Unsupported bit length of RSA modulus; only lengths 2048, 3072 and 4096 \
            are supported.";
        formatter.write_str(MESSAGE)
    }
}
138
// The `std::error::Error` impl is only compiled when the `std` feature is enabled.
#[cfg(feature = "std")]
impl std::error::Error for ModulusBitsError {}
141
/// Integrity algorithm using [RSA] digital signatures.
///
/// Depending on the variation, the algorithm employs PKCS#1 v1.5 or PSS padding and
/// one of the hash functions from the SHA-2 family: SHA-256, SHA-384, or SHA-512.
/// See [RFC 7518] for more details. Depending on the chosen parameters,
/// the name of the algorithm is one of `RS256`, `RS384`, `RS512`, `PS256`, `PS384`, `PS512`:
///
/// - `R` / `P` denote the padding scheme: PKCS#1 v1.5 for `R`, PSS for `P`
/// - `256` / `384` / `512` denote the hash function
///
/// The length of RSA keys is not unequivocally specified by the algorithm; nevertheless,
/// it **MUST** be at least 2048 bits as per RFC 7518. To minimize risks of misconfiguration,
/// use [`StrongAlg`](super::StrongAlg) wrapper around `Rsa`:
///
/// ```
/// # use jwt_compact::alg::{StrongAlg, Rsa};
/// const ALG: StrongAlg<Rsa> = StrongAlg(Rsa::rs256());
/// // `ALG` will not support RSA keys with insecure lengths by design!
/// ```
///
/// [RSA]: https://en.wikipedia.org/wiki/RSA_(cryptosystem)
/// [RFC 7518]: https://www.rfc-editor.org/rfc/rfc7518.html
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
#[cfg_attr(docsrs, doc(cfg(feature = "rsa")))]
pub struct Rsa {
    // Hash function applied to the message (the `256` / `384` / `512` part of the name).
    hash_alg: HashAlg,
    // Padding scheme (the `RS` / `PS` part of the name).
    padding_alg: Padding,
}
170
171impl Algorithm for Rsa {
172    type SigningKey = RsaPrivateKey;
173    type VerifyingKey = RsaPublicKey;
174    type Signature = RsaSignature;
175
176    fn name(&self) -> Cow<'static, str> {
177        Cow::Borrowed(self.alg_name())
178    }
179
180    fn sign(&self, signing_key: &Self::SigningKey, message: &[u8]) -> Self::Signature {
181        let digest = self.hash_alg.digest(message);
182        let digest = digest.as_ref();
183        let signing_result = match self.padding_scheme() {
184            PaddingScheme::Pkcs1v15(padding) => signing_key.sign_with_rng(
185                &mut rand_core::UnwrapErr(rand_core::OsRng),
186                padding,
187                digest,
188            ),
189            PaddingScheme::Pss(padding) => signing_key.sign_with_rng(
190                &mut rand_core::UnwrapErr(rand_core::OsRng),
191                padding,
192                digest,
193            ),
194        };
195        RsaSignature(signing_result.expect("Unexpected RSA signature failure"))
196    }
197
198    fn verify_signature(
199        &self,
200        signature: &Self::Signature,
201        verifying_key: &Self::VerifyingKey,
202        message: &[u8],
203    ) -> bool {
204        let digest = self.hash_alg.digest(message);
205        let digest = digest.as_ref();
206        let verify_result = match self.padding_scheme() {
207            PaddingScheme::Pkcs1v15(padding) => verifying_key.verify(padding, digest, &signature.0),
208            PaddingScheme::Pss(padding) => verifying_key.verify(padding, digest, &signature.0),
209        };
210        verify_result.is_ok()
211    }
212}
213
214impl Rsa {
215    const fn new(hash_alg: HashAlg, padding_alg: Padding) -> Self {
216        Rsa {
217            hash_alg,
218            padding_alg,
219        }
220    }
221
222    /// RSA with SHA-256 and PKCS#1 v1.5 padding.
223    pub const fn rs256() -> Rsa {
224        Rsa::new(HashAlg::Sha256, Padding::Pkcs1v15)
225    }
226
227    /// RSA with SHA-384 and PKCS#1 v1.5 padding.
228    pub const fn rs384() -> Rsa {
229        Rsa::new(HashAlg::Sha384, Padding::Pkcs1v15)
230    }
231
232    /// RSA with SHA-512 and PKCS#1 v1.5 padding.
233    pub const fn rs512() -> Rsa {
234        Rsa::new(HashAlg::Sha512, Padding::Pkcs1v15)
235    }
236
237    /// RSA with SHA-256 and PSS padding.
238    pub const fn ps256() -> Rsa {
239        Rsa::new(HashAlg::Sha256, Padding::Pss)
240    }
241
242    /// RSA with SHA-384 and PSS padding.
243    pub const fn ps384() -> Rsa {
244        Rsa::new(HashAlg::Sha384, Padding::Pss)
245    }
246
247    /// RSA with SHA-512 and PSS padding.
248    pub const fn ps512() -> Rsa {
249        Rsa::new(HashAlg::Sha512, Padding::Pss)
250    }
251
252    /// RSA based on the specified algorithm name.
253    ///
254    /// # Panics
255    ///
256    /// - Panics if the name is not one of the six RSA-based JWS algorithms. Prefer using
257    ///   the [`FromStr`] trait if the conversion is potentially fallible.
258    pub fn with_name(name: &str) -> Self {
259        name.parse().unwrap()
260    }
261
262    fn padding_scheme(self) -> PaddingScheme {
263        match self.padding_alg {
264            Padding::Pkcs1v15 => PaddingScheme::Pkcs1v15(match self.hash_alg {
265                HashAlg::Sha256 => Pkcs1v15Sign::new::<Sha256>(),
266                HashAlg::Sha384 => Pkcs1v15Sign::new::<Sha384>(),
267                HashAlg::Sha512 => Pkcs1v15Sign::new::<Sha512>(),
268            }),
269            Padding::Pss => {
270                // The salt length needs to be set to the size of hash function output;
271                // see https://www.rfc-editor.org/rfc/rfc7518.html#section-3.5.
272                PaddingScheme::Pss(match self.hash_alg {
273                    HashAlg::Sha256 => Pss::new_with_salt::<Sha256>(Sha256::output_size()),
274                    HashAlg::Sha384 => Pss::new_with_salt::<Sha384>(Sha384::output_size()),
275                    HashAlg::Sha512 => Pss::new_with_salt::<Sha512>(Sha512::output_size()),
276                })
277            }
278        }
279    }
280
281    fn alg_name(self) -> &'static str {
282        match (self.padding_alg, self.hash_alg) {
283            (Padding::Pkcs1v15, HashAlg::Sha256) => "RS256",
284            (Padding::Pkcs1v15, HashAlg::Sha384) => "RS384",
285            (Padding::Pkcs1v15, HashAlg::Sha512) => "RS512",
286            (Padding::Pss, HashAlg::Sha256) => "PS256",
287            (Padding::Pss, HashAlg::Sha384) => "PS384",
288            (Padding::Pss, HashAlg::Sha512) => "PS512",
289        }
290    }
291
292    /// Generates a new key pair with the specified modulus bit length (aka key length).
293    pub fn generate<R: CryptoRng + RngCore>(
294        rng: &mut R,
295        modulus_bits: ModulusBits,
296    ) -> rsa::errors::Result<(StrongKey<RsaPrivateKey>, StrongKey<RsaPublicKey>)> {
297        let signing_key = RsaPrivateKey::new(rng, modulus_bits.bits())?;
298        let verifying_key = signing_key.to_public_key();
299        Ok((StrongKey(signing_key), StrongKey(verifying_key)))
300    }
301}
302
303impl FromStr for Rsa {
304    type Err = RsaParseError;
305
306    fn from_str(s: &str) -> Result<Self, Self::Err> {
307        Ok(match s {
308            "RS256" => Self::rs256(),
309            "RS384" => Self::rs384(),
310            "RS512" => Self::rs512(),
311            "PS256" => Self::ps256(),
312            "PS384" => Self::ps384(),
313            "PS512" => Self::ps512(),
314            _ => return Err(RsaParseError(s.to_owned())),
315        })
316    }
317}
318
/// Errors that can occur when parsing an [`Rsa`] algorithm from a string.
///
/// Holds the string that failed to parse so that it can be reported back.
#[derive(Debug)]
#[cfg_attr(docsrs, doc(cfg(feature = "rsa")))]
pub struct RsaParseError(String);

impl fmt::Display for RsaParseError {
    fn fmt(&self, formatter: &mut fmt::Formatter<'_>) -> fmt::Result {
        let Self(rejected_name) = self;
        write!(formatter, "Invalid RSA algorithm name: {rejected_name}")
    }
}
329
// The `std::error::Error` impl is only compiled when the `std` feature is enabled.
#[cfg(feature = "std")]
impl std::error::Error for RsaParseError {}
332
333impl StrongKey<RsaPrivateKey> {
334    /// Converts this private key to a public key.
335    pub fn to_public_key(&self) -> StrongKey<RsaPublicKey> {
336        StrongKey(self.0.to_public_key())
337    }
338}
339
340impl TryFrom<RsaPrivateKey> for StrongKey<RsaPrivateKey> {
341    type Error = WeakKeyError<RsaPrivateKey>;
342
343    fn try_from(key: RsaPrivateKey) -> Result<Self, Self::Error> {
344        if ModulusBits::is_valid_bits(key.n().bits()) {
345            Ok(StrongKey(key))
346        } else {
347            Err(WeakKeyError(key))
348        }
349    }
350}
351
352impl TryFrom<RsaPublicKey> for StrongKey<RsaPublicKey> {
353    type Error = WeakKeyError<RsaPublicKey>;
354
355    fn try_from(key: RsaPublicKey) -> Result<Self, Self::Error> {
356        if ModulusBits::is_valid_bits(key.n().bits()) {
357            Ok(StrongKey(key))
358        } else {
359            Err(WeakKeyError(key))
360        }
361    }
362}
363
364impl<'a> From<&'a RsaPublicKey> for JsonWebKey<'a> {
365    fn from(key: &'a RsaPublicKey) -> JsonWebKey<'a> {
366        JsonWebKey::Rsa {
367            modulus: Cow::Owned(key.n().to_be_bytes_trimmed_vartime().into()),
368            public_exponent: Cow::Owned(key.e().to_be_bytes_trimmed_vartime().into()),
369            private_parts: None,
370        }
371    }
372}
373
374#[allow(clippy::cast_possible_truncation)] // not triggered
375fn secret_uint_from_slice(slice: &[u8], precision: u32) -> Result<BoxedUint, JwkError> {
376    debug_assert!(precision <= RsaPublicKey::MAX_SIZE as u32);
377    BoxedUint::from_be_slice(slice, precision).map_err(|err| JwkError::custom(anyhow::anyhow!(err)))
378}
379
380/// The caller must ensure that setting `precision` won't truncate the value.
381fn secret_uint_to_slice(secret: &BoxedUint, precision: u32) -> SecretBytes<'static> {
382    let bytes = secret.to_be_bytes();
383    let precision_bytes = precision.div_ceil(8) as usize;
384    SecretBytes::owned_slice(if bytes.len() > precision_bytes {
385        let first_idx = bytes.len() - precision_bytes;
386        bytes[first_idx..].into()
387    } else {
388        bytes
389    })
390}
391
392fn pub_exponent_from_slice(slice: &[u8]) -> Result<BoxedUint, JwkError> {
393    BoxedUint::from_be_slice(slice, RsaPublicKey::MAX_PUB_EXPONENT.ilog2() + 1)
394        .map_err(|err| JwkError::custom(anyhow::anyhow!(err)))
395}
396
397impl TryFrom<&JsonWebKey<'_>> for RsaPublicKey {
398    type Error = JwkError;
399
400    fn try_from(jwk: &JsonWebKey<'_>) -> Result<Self, Self::Error> {
401        let JsonWebKey::Rsa {
402            modulus,
403            public_exponent,
404            ..
405        } = jwk
406        else {
407            return Err(JwkError::key_type(jwk, KeyType::Rsa));
408        };
409
410        let e = pub_exponent_from_slice(public_exponent)?;
411        let n = BoxedUint::from_be_slice_vartime(modulus);
412        Self::new(n, e).map_err(|err| JwkError::custom(anyhow::anyhow!(err)))
413    }
414}
415
416/// ⚠ **Warning.** Contrary to [RFC 7518], this implementation does not set `dp`, `dq`, and `qi`
417/// fields in the JWK root object, as well as `d` and `t` fields for additional factors
418/// (i.e., in the `oth` array).
419///
420/// [RFC 7518]: https://tools.ietf.org/html/rfc7518#section-6.3.2
421impl<'a> From<&'a RsaPrivateKey> for JsonWebKey<'a> {
422    fn from(key: &'a RsaPrivateKey) -> JsonWebKey<'a> {
423        const MSG: &str = "RsaPrivateKey must have at least 2 prime factors";
424
425        let p = key.primes().first().expect(MSG);
426        let q = key.primes().get(1).expect(MSG);
427        // Truncate secret values to the modulus precision. We know that all secret values don't exceed the modulus,
428        // so this is safe. `d` in particular does have q higher precision for multi-prime RSA keys
429        // (e.g., it may have 2,176-bit precision for a 2,048-bit modulus), which would lead to excessive zero padding
430        // and may lead to unnecessary deserialization errors.
431        let precision = key.n().bits_precision();
432
433        let private_parts = RsaPrivateParts {
434            private_exponent: secret_uint_to_slice(key.d(), precision),
435            prime_factor_p: secret_uint_to_slice(p, precision),
436            prime_factor_q: secret_uint_to_slice(q, precision),
437            p_crt_exponent: None,
438            q_crt_exponent: None,
439            q_crt_coefficient: None,
440            other_prime_factors: key.primes()[2..]
441                .iter()
442                .map(|factor| RsaPrimeFactor {
443                    factor: secret_uint_to_slice(factor, precision),
444                    crt_exponent: None,
445                    crt_coefficient: None,
446                })
447                .collect(),
448        };
449
450        JsonWebKey::Rsa {
451            modulus: Cow::Owned(key.n().to_be_bytes_trimmed_vartime().into()),
452            public_exponent: Cow::Owned(key.e().to_be_bytes_trimmed_vartime().into()),
453            private_parts: Some(private_parts),
454        }
455    }
456}
457
458/// ⚠ **Warning.** Contrary to [RFC 7518] (at least, in spirit), this conversion ignores
459/// `dp`, `dq`, and `qi` fields from JWK, as well as `d` and `t` fields for additional factors.
460///
461/// [RFC 7518]: https://www.rfc-editor.org/rfc/rfc7518.html
462impl TryFrom<&JsonWebKey<'_>> for RsaPrivateKey {
463    type Error = JwkError;
464
465    fn try_from(jwk: &JsonWebKey<'_>) -> Result<Self, Self::Error> {
466        let JsonWebKey::Rsa {
467            modulus,
468            public_exponent,
469            private_parts,
470        } = jwk
471        else {
472            return Err(JwkError::key_type(jwk, KeyType::Rsa));
473        };
474
475        let RsaPrivateParts {
476            private_exponent: d,
477            prime_factor_p,
478            prime_factor_q,
479            other_prime_factors,
480            ..
481        } = private_parts
482            .as_ref()
483            .ok_or_else(|| JwkError::NoField("d".into()))?;
484
485        let e = pub_exponent_from_slice(public_exponent)?;
486        let n = BoxedUint::from_be_slice_vartime(modulus);
487
488        // Round `n` bitness up to the nearest value divisible by 8
489        let precision = n.bits().div_ceil(8) * 8;
490        if precision as usize > RsaPublicKey::MAX_SIZE {
491            return Err(JwkError::Custom(anyhow::anyhow!(
492                "Modulus precision ({got}) exceeds maximum supported value ({max})",
493                got = n.bits(),
494                max = RsaPublicKey::MAX_SIZE
495            )));
496        }
497
498        let d = secret_uint_from_slice(d, precision)?;
499        let mut factors = Vec::with_capacity(2 + other_prime_factors.len());
500        factors.push(secret_uint_from_slice(prime_factor_p, precision)?);
501        factors.push(secret_uint_from_slice(prime_factor_q, precision)?);
502        for other_factor in other_prime_factors {
503            factors.push(secret_uint_from_slice(&other_factor.factor, precision)?);
504        }
505
506        let key = Self::from_components(n, e, d, factors);
507        let key = key.map_err(|err| JwkError::custom(anyhow::anyhow!(err)))?;
508        key.validate()
509            .map_err(|err| JwkError::custom(anyhow::anyhow!(err)))?;
510        Ok(key)
511    }
512}