jwt_compact/alg/
rsa.rs

1//! RSA-based JWT algorithms: `RS*` and `PS*`.
2
3use core::{fmt, str::FromStr};
4
5use rand_core::CryptoRng;
6use rsa::{
7    BoxedUint, Pkcs1v15Sign, Pss,
8    traits::{PrivateKeyParts, PublicKeyParts},
9};
10pub use rsa::{RsaPrivateKey, RsaPublicKey, errors::Error as RsaError};
11use sha2::{Digest, Sha256, Sha384, Sha512};
12
13use crate::{
14    Algorithm, AlgorithmSignature,
15    alg::{SecretBytes, StrongKey, WeakKeyError},
16    alloc::{Cow, String, ToOwned, Vec},
17    jwk::{JsonWebKey, JwkError, KeyType, RsaPrimeFactor, RsaPrivateParts},
18};
19
/// RSA signature.
///
/// A thin wrapper around the raw signature bytes; the bytes are stored as-is,
/// without any validation of their length or contents.
#[derive(Debug)]
#[cfg_attr(docsrs, doc(cfg(feature = "rsa")))]
pub struct RsaSignature(Vec<u8>);
24
25impl AlgorithmSignature for RsaSignature {
26    fn try_from_slice(bytes: &[u8]) -> anyhow::Result<Self> {
27        Ok(RsaSignature(bytes.to_vec()))
28    }
29
30    fn as_bytes(&self) -> Cow<'_, [u8]> {
31        Cow::Borrowed(&self.0)
32    }
33}
34
/// RSA hash algorithm.
///
/// Selects which SHA-2 function is used to hash the message before signing / verification.
#[derive(Debug, Copy, Clone, Eq, PartialEq)]
enum HashAlg {
    // SHA-256; used by the `RS256` and `PS256` algorithms.
    Sha256,
    // SHA-384; used by the `RS384` and `PS384` algorithms.
    Sha384,
    // SHA-512; used by the `RS512` and `PS512` algorithms.
    Sha512,
}
42
43impl HashAlg {
44    fn digest(self, message: &[u8]) -> HashDigest {
45        match self {
46            Self::Sha256 => HashDigest::Sha256(Sha256::digest(message).into()),
47            Self::Sha384 => HashDigest::Sha384(Sha384::digest(message).into()),
48            Self::Sha512 => HashDigest::Sha512(Sha512::digest(message).into()),
49        }
50    }
51}
52
/// Output of a [`HashAlg`].
///
/// Each variant stores the digest in an array whose length equals the output size
/// of the corresponding hash function (32, 48, or 64 bytes).
#[derive(Debug)]
enum HashDigest {
    Sha256([u8; 32]),
    Sha384([u8; 48]),
    Sha512([u8; 64]),
}

impl AsRef<[u8]> for HashDigest {
    /// Views the digest as a byte slice regardless of the underlying hash function.
    fn as_ref(&self) -> &[u8] {
        match self {
            Self::Sha256(digest) => digest.as_slice(),
            Self::Sha384(digest) => digest.as_slice(),
            Self::Sha512(digest) => digest.as_slice(),
        }
    }
}
70
/// RSA padding algorithm.
///
/// Together with [`HashAlg`], this determines the name of the JWS algorithm.
#[derive(Debug, Copy, Clone, Eq, PartialEq)]
enum Padding {
    // PKCS#1 v1.5 padding (the `RS*` family of algorithms).
    Pkcs1v15,
    // PSS padding (the `PS*` family of algorithms).
    Pss,
}
77
/// Fully configured padding scheme that is handed to the `rsa` crate for signing
/// and verification.
///
/// PSS is represented by a separate variant per hash function because `Pss` is generic
/// over the digest type, whereas `Pkcs1v15Sign` is a single non-generic type.
#[derive(Debug)]
enum PaddingScheme {
    Pkcs1v15(Pkcs1v15Sign),
    Pss256(Pss<Sha256>),
    Pss384(Pss<Sha384>),
    Pss512(Pss<Sha512>),
}
85
/// Bit length of an RSA key modulus (aka RSA key length).
#[derive(Debug, Copy, Clone, Eq, PartialEq)]
#[non_exhaustive]
#[cfg_attr(docsrs, doc(cfg(feature = "rsa")))]
pub enum ModulusBits {
    /// 2048 bits. This is the minimum recommended key length as of 2020.
    TwoKibibytes,
    /// 3072 bits.
    ThreeKibibytes,
    /// 4096 bits.
    FourKibibytes,
}

impl ModulusBits {
    /// Converts this length to the numeric value (the number of bits in the modulus).
    pub fn bits(self) -> usize {
        // Every supported length is a whole multiple of 1,024 bits.
        let kibibits = match self {
            Self::TwoKibibytes => 2,
            Self::ThreeKibibytes => 3,
            Self::FourKibibytes => 4,
        };
        kibibits * 1_024
    }

    /// Checks whether `bits` is exactly one of the supported modulus lengths.
    fn is_valid_bits(bits: u32) -> bool {
        [2_048, 3_072, 4_096].contains(&bits)
    }
}
113
114impl TryFrom<usize> for ModulusBits {
115    type Error = ModulusBitsError;
116
117    fn try_from(value: usize) -> Result<Self, Self::Error> {
118        match value {
119            2_048 => Ok(Self::TwoKibibytes),
120            3_072 => Ok(Self::ThreeKibibytes),
121            4_096 => Ok(Self::FourKibibytes),
122            _ => Err(ModulusBitsError(())),
123        }
124    }
125}
126
/// Error type returned when a conversion of an integer into `ModulusBits` fails.
#[derive(Debug)]
#[cfg_attr(docsrs, doc(cfg(feature = "rsa")))]
pub struct ModulusBitsError(());

impl fmt::Display for ModulusBitsError {
    fn fmt(&self, formatter: &mut fmt::Formatter<'_>) -> fmt::Result {
        // The error carries no data, so the message is a fixed string.
        const MESSAGE: &str = concat!(
            "Unsupported bit length of RSA modulus; only lengths 2048, 3072 and 4096 ",
            "are supported.",
        );
        formatter.write_str(MESSAGE)
    }
}

#[cfg(feature = "std")]
impl std::error::Error for ModulusBitsError {}
143
/// Integrity algorithm using [RSA] digital signatures.
///
/// Depending on the variation, the algorithm employs PKCS#1 v1.5 or PSS padding and
/// one of the hash functions from the SHA-2 family: SHA-256, SHA-384, or SHA-512.
/// See [RFC 7518] for more details. Depending on the chosen parameters,
/// the name of the algorithm is one of `RS256`, `RS384`, `RS512`, `PS256`, `PS384`, `PS512`:
///
/// - `R` / `P` denote the padding scheme: PKCS#1 v1.5 for `R`, PSS for `P`
/// - `256` / `384` / `512` denote the hash function
///
/// The length of RSA keys is not unequivocally specified by the algorithm; nevertheless,
/// it **MUST** be at least 2048 bits as per RFC 7518. To minimize risks of misconfiguration,
/// use [`StrongAlg`](super::StrongAlg) wrapper around `Rsa`:
///
/// ```
/// # use jwt_compact::alg::{StrongAlg, Rsa};
/// const ALG: StrongAlg<Rsa> = StrongAlg(Rsa::rs256());
/// // `ALG` will not support RSA keys with insecure lengths by design!
/// ```
///
/// [RSA]: https://en.wikipedia.org/wiki/RSA_(cryptosystem)
/// [RFC 7518]: https://www.rfc-editor.org/rfc/rfc7518.html
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
#[cfg_attr(docsrs, doc(cfg(feature = "rsa")))]
pub struct Rsa {
    // Hash function applied to the message before signing / verification.
    hash_alg: HashAlg,
    // Padding applied to the message digest (PKCS#1 v1.5 or PSS).
    padding_alg: Padding,
}
172
173impl Algorithm for Rsa {
174    type SigningKey = RsaPrivateKey;
175    type VerifyingKey = RsaPublicKey;
176    type Signature = RsaSignature;
177
178    fn name(&self) -> Cow<'static, str> {
179        Cow::Borrowed(self.alg_name())
180    }
181
182    fn sign(&self, signing_key: &Self::SigningKey, message: &[u8]) -> Self::Signature {
183        let digest = self.hash_alg.digest(message);
184        let digest = digest.as_ref();
185        let signing_result = match self.padding_scheme() {
186            PaddingScheme::Pkcs1v15(padding) => signing_key.sign_with_rng(
187                &mut rand_core::UnwrapErr(getrandom::SysRng),
188                padding,
189                digest,
190            ),
191            PaddingScheme::Pss256(padding) => signing_key.sign_with_rng(
192                &mut rand_core::UnwrapErr(getrandom::SysRng),
193                padding,
194                digest,
195            ),
196            PaddingScheme::Pss384(padding) => signing_key.sign_with_rng(
197                &mut rand_core::UnwrapErr(getrandom::SysRng),
198                padding,
199                digest,
200            ),
201            PaddingScheme::Pss512(padding) => signing_key.sign_with_rng(
202                &mut rand_core::UnwrapErr(getrandom::SysRng),
203                padding,
204                digest,
205            ),
206        };
207        RsaSignature(signing_result.expect("Unexpected RSA signature failure"))
208    }
209
210    fn verify_signature(
211        &self,
212        signature: &Self::Signature,
213        verifying_key: &Self::VerifyingKey,
214        message: &[u8],
215    ) -> bool {
216        let digest = self.hash_alg.digest(message);
217        let digest = digest.as_ref();
218        let verify_result = match self.padding_scheme() {
219            PaddingScheme::Pkcs1v15(padding) => verifying_key.verify(padding, digest, &signature.0),
220            PaddingScheme::Pss256(padding) => verifying_key.verify(padding, digest, &signature.0),
221            PaddingScheme::Pss384(padding) => verifying_key.verify(padding, digest, &signature.0),
222            PaddingScheme::Pss512(padding) => verifying_key.verify(padding, digest, &signature.0),
223        };
224        verify_result.is_ok()
225    }
226}
227
228impl Rsa {
229    const fn new(hash_alg: HashAlg, padding_alg: Padding) -> Self {
230        Rsa {
231            hash_alg,
232            padding_alg,
233        }
234    }
235
236    /// RSA with SHA-256 and PKCS#1 v1.5 padding.
237    pub const fn rs256() -> Rsa {
238        Rsa::new(HashAlg::Sha256, Padding::Pkcs1v15)
239    }
240
241    /// RSA with SHA-384 and PKCS#1 v1.5 padding.
242    pub const fn rs384() -> Rsa {
243        Rsa::new(HashAlg::Sha384, Padding::Pkcs1v15)
244    }
245
246    /// RSA with SHA-512 and PKCS#1 v1.5 padding.
247    pub const fn rs512() -> Rsa {
248        Rsa::new(HashAlg::Sha512, Padding::Pkcs1v15)
249    }
250
251    /// RSA with SHA-256 and PSS padding.
252    pub const fn ps256() -> Rsa {
253        Rsa::new(HashAlg::Sha256, Padding::Pss)
254    }
255
256    /// RSA with SHA-384 and PSS padding.
257    pub const fn ps384() -> Rsa {
258        Rsa::new(HashAlg::Sha384, Padding::Pss)
259    }
260
261    /// RSA with SHA-512 and PSS padding.
262    pub const fn ps512() -> Rsa {
263        Rsa::new(HashAlg::Sha512, Padding::Pss)
264    }
265
266    /// RSA based on the specified algorithm name.
267    ///
268    /// # Panics
269    ///
270    /// - Panics if the name is not one of the six RSA-based JWS algorithms. Prefer using
271    ///   the [`FromStr`] trait if the conversion is potentially fallible.
272    pub fn with_name(name: &str) -> Self {
273        name.parse().unwrap()
274    }
275
276    fn padding_scheme(self) -> PaddingScheme {
277        match self.padding_alg {
278            Padding::Pkcs1v15 => PaddingScheme::Pkcs1v15(match self.hash_alg {
279                HashAlg::Sha256 => Pkcs1v15Sign::new::<Sha256>(),
280                HashAlg::Sha384 => Pkcs1v15Sign::new::<Sha384>(),
281                HashAlg::Sha512 => Pkcs1v15Sign::new::<Sha512>(),
282            }),
283            Padding::Pss => {
284                // The salt length needs to be set to the size of hash function output;
285                // see https://www.rfc-editor.org/rfc/rfc7518.html#section-3.5.
286                match self.hash_alg {
287                    HashAlg::Sha256 => {
288                        PaddingScheme::Pss256(Pss::new_with_salt(Sha256::output_size()))
289                    }
290                    HashAlg::Sha384 => {
291                        PaddingScheme::Pss384(Pss::new_with_salt(Sha384::output_size()))
292                    }
293                    HashAlg::Sha512 => {
294                        PaddingScheme::Pss512(Pss::new_with_salt(Sha512::output_size()))
295                    }
296                }
297            }
298        }
299    }
300
301    fn alg_name(self) -> &'static str {
302        match (self.padding_alg, self.hash_alg) {
303            (Padding::Pkcs1v15, HashAlg::Sha256) => "RS256",
304            (Padding::Pkcs1v15, HashAlg::Sha384) => "RS384",
305            (Padding::Pkcs1v15, HashAlg::Sha512) => "RS512",
306            (Padding::Pss, HashAlg::Sha256) => "PS256",
307            (Padding::Pss, HashAlg::Sha384) => "PS384",
308            (Padding::Pss, HashAlg::Sha512) => "PS512",
309        }
310    }
311
312    /// Generates a new key pair with the specified modulus bit length (aka key length).
313    pub fn generate<R: CryptoRng>(
314        rng: &mut R,
315        modulus_bits: ModulusBits,
316    ) -> rsa::errors::Result<(StrongKey<RsaPrivateKey>, StrongKey<RsaPublicKey>)> {
317        let signing_key = RsaPrivateKey::new(rng, modulus_bits.bits())?;
318        let verifying_key = signing_key.to_public_key();
319        Ok((StrongKey(signing_key), StrongKey(verifying_key)))
320    }
321}
322
323impl FromStr for Rsa {
324    type Err = RsaParseError;
325
326    fn from_str(s: &str) -> Result<Self, Self::Err> {
327        Ok(match s {
328            "RS256" => Self::rs256(),
329            "RS384" => Self::rs384(),
330            "RS512" => Self::rs512(),
331            "PS256" => Self::ps256(),
332            "PS384" => Self::ps384(),
333            "PS512" => Self::ps512(),
334            _ => return Err(RsaParseError(s.to_owned())),
335        })
336    }
337}
338
/// Errors that can occur when parsing an [`Rsa`] algorithm from a string.
///
/// The wrapped string is the name that failed to parse.
#[derive(Debug)]
#[cfg_attr(docsrs, doc(cfg(feature = "rsa")))]
pub struct RsaParseError(String);

impl fmt::Display for RsaParseError {
    fn fmt(&self, formatter: &mut fmt::Formatter<'_>) -> fmt::Result {
        let Self(name) = self;
        write!(formatter, "Invalid RSA algorithm name: {name}")
    }
}

#[cfg(feature = "std")]
impl std::error::Error for RsaParseError {}
352
353impl StrongKey<RsaPrivateKey> {
354    /// Converts this private key to a public key.
355    pub fn to_public_key(&self) -> StrongKey<RsaPublicKey> {
356        StrongKey(self.0.to_public_key())
357    }
358}
359
360impl TryFrom<RsaPrivateKey> for StrongKey<RsaPrivateKey> {
361    type Error = WeakKeyError<RsaPrivateKey>;
362
363    fn try_from(key: RsaPrivateKey) -> Result<Self, Self::Error> {
364        if ModulusBits::is_valid_bits(key.n().bits()) {
365            Ok(StrongKey(key))
366        } else {
367            Err(WeakKeyError(key))
368        }
369    }
370}
371
372impl TryFrom<RsaPublicKey> for StrongKey<RsaPublicKey> {
373    type Error = WeakKeyError<RsaPublicKey>;
374
375    fn try_from(key: RsaPublicKey) -> Result<Self, Self::Error> {
376        if ModulusBits::is_valid_bits(key.n().bits()) {
377            Ok(StrongKey(key))
378        } else {
379            Err(WeakKeyError(key))
380        }
381    }
382}
383
384impl<'a> From<&'a RsaPublicKey> for JsonWebKey<'a> {
385    fn from(key: &'a RsaPublicKey) -> JsonWebKey<'a> {
386        JsonWebKey::Rsa {
387            modulus: Cow::Owned(key.n().to_be_bytes_trimmed_vartime().into()),
388            public_exponent: Cow::Owned(key.e().to_be_bytes_trimmed_vartime().into()),
389            private_parts: None,
390        }
391    }
392}
393
394#[allow(clippy::cast_possible_truncation)] // not triggered
395fn secret_uint_from_slice(slice: &[u8], precision: u32) -> Result<BoxedUint, JwkError> {
396    debug_assert!(precision <= RsaPublicKey::MAX_SIZE as u32);
397    BoxedUint::from_be_slice(slice, precision).map_err(|err| JwkError::custom(anyhow::anyhow!(err)))
398}
399
400/// The caller must ensure that setting `precision` won't truncate the value.
401fn secret_uint_to_slice(secret: &BoxedUint, precision: u32) -> SecretBytes<'static> {
402    let bytes = secret.to_be_bytes();
403    let precision_bytes = precision.div_ceil(8) as usize;
404    SecretBytes::owned_slice(if bytes.len() > precision_bytes {
405        let first_idx = bytes.len() - precision_bytes;
406        bytes[first_idx..].into()
407    } else {
408        bytes
409    })
410}
411
412fn pub_exponent_from_slice(slice: &[u8]) -> Result<BoxedUint, JwkError> {
413    BoxedUint::from_be_slice(slice, RsaPublicKey::MAX_PUB_EXPONENT.ilog2() + 1)
414        .map_err(|err| JwkError::custom(anyhow::anyhow!(err)))
415}
416
417impl TryFrom<&JsonWebKey<'_>> for RsaPublicKey {
418    type Error = JwkError;
419
420    fn try_from(jwk: &JsonWebKey<'_>) -> Result<Self, Self::Error> {
421        let JsonWebKey::Rsa {
422            modulus,
423            public_exponent,
424            ..
425        } = jwk
426        else {
427            return Err(JwkError::key_type(jwk, KeyType::Rsa));
428        };
429
430        let e = pub_exponent_from_slice(public_exponent)?;
431        let n = BoxedUint::from_be_slice_vartime(modulus);
432        Self::new(n, e).map_err(|err| JwkError::custom(anyhow::anyhow!(err)))
433    }
434}
435
436/// ⚠ **Warning.** Contrary to [RFC 7518], this implementation does not set `dp`, `dq`, and `qi`
437/// fields in the JWK root object, as well as `d` and `t` fields for additional factors
438/// (i.e., in the `oth` array).
439///
440/// [RFC 7518]: https://tools.ietf.org/html/rfc7518#section-6.3.2
441impl<'a> From<&'a RsaPrivateKey> for JsonWebKey<'a> {
442    fn from(key: &'a RsaPrivateKey) -> JsonWebKey<'a> {
443        const MSG: &str = "RsaPrivateKey must have at least 2 prime factors";
444
445        let p = key.primes().first().expect(MSG);
446        let q = key.primes().get(1).expect(MSG);
447        // Truncate secret values to the modulus precision. We know that all secret values don't exceed the modulus,
448        // so this is safe. `d` in particular does have q higher precision for multi-prime RSA keys
449        // (e.g., it may have 2,176-bit precision for a 2,048-bit modulus), which would lead to excessive zero padding
450        // and may lead to unnecessary deserialization errors.
451        let precision = key.n().bits_precision();
452
453        let private_parts = RsaPrivateParts {
454            private_exponent: secret_uint_to_slice(key.d(), precision),
455            prime_factor_p: secret_uint_to_slice(p, precision),
456            prime_factor_q: secret_uint_to_slice(q, precision),
457            p_crt_exponent: None,
458            q_crt_exponent: None,
459            q_crt_coefficient: None,
460            other_prime_factors: key.primes()[2..]
461                .iter()
462                .map(|factor| RsaPrimeFactor {
463                    factor: secret_uint_to_slice(factor, precision),
464                    crt_exponent: None,
465                    crt_coefficient: None,
466                })
467                .collect(),
468        };
469
470        JsonWebKey::Rsa {
471            modulus: Cow::Owned(key.n().to_be_bytes_trimmed_vartime().into()),
472            public_exponent: Cow::Owned(key.e().to_be_bytes_trimmed_vartime().into()),
473            private_parts: Some(private_parts),
474        }
475    }
476}
477
478/// ⚠ **Warning.** Contrary to [RFC 7518] (at least, in spirit), this conversion ignores
479/// `dp`, `dq`, and `qi` fields from JWK, as well as `d` and `t` fields for additional factors.
480///
481/// [RFC 7518]: https://www.rfc-editor.org/rfc/rfc7518.html
482impl TryFrom<&JsonWebKey<'_>> for RsaPrivateKey {
483    type Error = JwkError;
484
485    fn try_from(jwk: &JsonWebKey<'_>) -> Result<Self, Self::Error> {
486        let JsonWebKey::Rsa {
487            modulus,
488            public_exponent,
489            private_parts,
490        } = jwk
491        else {
492            return Err(JwkError::key_type(jwk, KeyType::Rsa));
493        };
494
495        let RsaPrivateParts {
496            private_exponent: d,
497            prime_factor_p,
498            prime_factor_q,
499            other_prime_factors,
500            ..
501        } = private_parts
502            .as_ref()
503            .ok_or_else(|| JwkError::NoField("d".into()))?;
504
505        let e = pub_exponent_from_slice(public_exponent)?;
506        let n = BoxedUint::from_be_slice_vartime(modulus);
507
508        // Round `n` bitness up to the nearest value divisible by 8
509        let precision = n.bits().div_ceil(8) * 8;
510        if precision as usize > RsaPublicKey::MAX_SIZE {
511            return Err(JwkError::Custom(anyhow::anyhow!(
512                "Modulus precision ({got}) exceeds maximum supported value ({max})",
513                got = n.bits(),
514                max = RsaPublicKey::MAX_SIZE
515            )));
516        }
517
518        let d = secret_uint_from_slice(d, precision)?;
519        let mut factors = Vec::with_capacity(2 + other_prime_factors.len());
520        factors.push(secret_uint_from_slice(prime_factor_p, precision)?);
521        factors.push(secret_uint_from_slice(prime_factor_q, precision)?);
522        for other_factor in other_prime_factors {
523            factors.push(secret_uint_from_slice(&other_factor.factor, precision)?);
524        }
525
526        let key = Self::from_components(n, e, d, factors);
527        let key = key.map_err(|err| JwkError::custom(anyhow::anyhow!(err)))?;
528        key.validate()
529            .map_err(|err| JwkError::custom(anyhow::anyhow!(err)))?;
530        Ok(key)
531    }
532}