1use core::{fmt, str::FromStr};
4
5use rand_core::CryptoRng;
6use rsa::{
7 BoxedUint, Pkcs1v15Sign, Pss,
8 traits::{PrivateKeyParts, PublicKeyParts},
9};
10pub use rsa::{RsaPrivateKey, RsaPublicKey, errors::Error as RsaError};
11use sha2::{Digest, Sha256, Sha384, Sha512};
12
13use crate::{
14 Algorithm, AlgorithmSignature,
15 alg::{SecretBytes, StrongKey, WeakKeyError},
16 alloc::{Cow, String, ToOwned, Vec},
17 jwk::{JsonWebKey, JwkError, KeyType, RsaPrimeFactor, RsaPrivateParts},
18};
19
/// Signature produced by an [`Rsa`] algorithm, stored as raw big-endian signature bytes.
#[cfg_attr(docsrs, doc(cfg(feature = "rsa")))]
#[derive(Debug)]
pub struct RsaSignature(Vec<u8>);
24
25impl AlgorithmSignature for RsaSignature {
26 fn try_from_slice(bytes: &[u8]) -> anyhow::Result<Self> {
27 Ok(RsaSignature(bytes.to_vec()))
28 }
29
30 fn as_bytes(&self) -> Cow<'_, [u8]> {
31 Cow::Borrowed(&self.0)
32 }
33}
34
/// Hash function applied to the message before it is signed or verified.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
enum HashAlg {
    /// SHA-256.
    Sha256,
    /// SHA-384.
    Sha384,
    /// SHA-512.
    Sha512,
}
42
43impl HashAlg {
44 fn digest(self, message: &[u8]) -> HashDigest {
45 match self {
46 Self::Sha256 => HashDigest::Sha256(Sha256::digest(message).into()),
47 Self::Sha384 => HashDigest::Sha384(Sha384::digest(message).into()),
48 Self::Sha512 => HashDigest::Sha512(Sha512::digest(message).into()),
49 }
50 }
51}
52
/// Output of [`HashAlg::digest`]; each variant holds a digest of that function's output size.
#[derive(Debug)]
enum HashDigest {
    /// 32-byte SHA-256 digest.
    Sha256([u8; 32]),
    /// 48-byte SHA-384 digest.
    Sha384([u8; 48]),
    /// 64-byte SHA-512 digest.
    Sha512([u8; 64]),
}
60
61impl AsRef<[u8]> for HashDigest {
62 fn as_ref(&self) -> &[u8] {
63 match self {
64 Self::Sha256(bytes) => bytes,
65 Self::Sha384(bytes) => bytes,
66 Self::Sha512(bytes) => bytes,
67 }
68 }
69}
70
/// Signature padding family; paired with a [`HashAlg`] to select a concrete scheme.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
enum Padding {
    /// RSASSA-PKCS1-v1_5 padding (`RS*` algorithms).
    Pkcs1v15,
    /// RSASSA-PSS padding (`PS*` algorithms).
    Pss,
}
77
78#[derive(Debug)]
79enum PaddingScheme {
80 Pkcs1v15(Pkcs1v15Sign),
81 Pss256(Pss<Sha256>),
82 Pss384(Pss<Sha384>),
83 Pss512(Pss<Sha512>),
84}
85
/// Bit length of an RSA key modulus.
///
/// Only the lengths represented by the variants (2048, 3072 and 4096 bits) are supported.
#[derive(Debug, Copy, Clone, Eq, PartialEq)]
#[non_exhaustive]
#[cfg_attr(docsrs, doc(cfg(feature = "rsa")))]
pub enum ModulusBits {
    /// 2048-bit modulus.
    TwoKibibytes,
    /// 3072-bit modulus.
    ThreeKibibytes,
    /// 4096-bit modulus.
    FourKibibytes,
}

impl ModulusBits {
    /// Every supported modulus length. Parsing and validation go through this table and
    /// [`Self::bits()`], so the numeric lengths are spelled out in one place only.
    const ALL: [Self; 3] = [Self::TwoKibibytes, Self::ThreeKibibytes, Self::FourKibibytes];

    /// Returns the number of bits in the modulus.
    pub const fn bits(self) -> usize {
        match self {
            Self::TwoKibibytes => 2_048,
            Self::ThreeKibibytes => 3_072,
            Self::FourKibibytes => 4_096,
        }
    }

    /// Returns `true` if `bits` equals one of the supported modulus lengths.
    fn is_valid_bits(bits: u32) -> bool {
        // Delegates to `TryFrom<usize>` instead of repeating the list of lengths.
        usize::try_from(bits).is_ok_and(|bits| Self::try_from(bits).is_ok())
    }
}

impl TryFrom<usize> for ModulusBits {
    type Error = ModulusBitsError;

    /// Maps a bit length to the matching variant.
    ///
    /// # Errors
    ///
    /// Returns [`ModulusBitsError`] if `value` is not a supported length.
    fn try_from(value: usize) -> Result<Self, Self::Error> {
        Self::ALL
            .iter()
            .copied()
            .find(|modulus| modulus.bits() == value)
            .ok_or(ModulusBitsError(()))
    }
}

/// Error converting an unsupported bit length into [`ModulusBits`].
#[derive(Debug)]
#[cfg_attr(docsrs, doc(cfg(feature = "rsa")))]
pub struct ModulusBitsError(());

impl fmt::Display for ModulusBitsError {
    fn fmt(&self, formatter: &mut fmt::Formatter<'_>) -> fmt::Result {
        formatter.write_str(
            "Unsupported bit length of RSA modulus; only lengths 2048, 3072 and 4096 \
             are supported.",
        )
    }
}

#[cfg(feature = "std")]
impl std::error::Error for ModulusBitsError {}
143
144#[derive(Debug, Clone, Copy, PartialEq, Eq)]
167#[cfg_attr(docsrs, doc(cfg(feature = "rsa")))]
168pub struct Rsa {
169 hash_alg: HashAlg,
170 padding_alg: Padding,
171}
172
173impl Algorithm for Rsa {
174 type SigningKey = RsaPrivateKey;
175 type VerifyingKey = RsaPublicKey;
176 type Signature = RsaSignature;
177
178 fn name(&self) -> Cow<'static, str> {
179 Cow::Borrowed(self.alg_name())
180 }
181
182 fn sign(&self, signing_key: &Self::SigningKey, message: &[u8]) -> Self::Signature {
183 let digest = self.hash_alg.digest(message);
184 let digest = digest.as_ref();
185 let signing_result = match self.padding_scheme() {
186 PaddingScheme::Pkcs1v15(padding) => signing_key.sign_with_rng(
187 &mut rand_core::UnwrapErr(getrandom::SysRng),
188 padding,
189 digest,
190 ),
191 PaddingScheme::Pss256(padding) => signing_key.sign_with_rng(
192 &mut rand_core::UnwrapErr(getrandom::SysRng),
193 padding,
194 digest,
195 ),
196 PaddingScheme::Pss384(padding) => signing_key.sign_with_rng(
197 &mut rand_core::UnwrapErr(getrandom::SysRng),
198 padding,
199 digest,
200 ),
201 PaddingScheme::Pss512(padding) => signing_key.sign_with_rng(
202 &mut rand_core::UnwrapErr(getrandom::SysRng),
203 padding,
204 digest,
205 ),
206 };
207 RsaSignature(signing_result.expect("Unexpected RSA signature failure"))
208 }
209
210 fn verify_signature(
211 &self,
212 signature: &Self::Signature,
213 verifying_key: &Self::VerifyingKey,
214 message: &[u8],
215 ) -> bool {
216 let digest = self.hash_alg.digest(message);
217 let digest = digest.as_ref();
218 let verify_result = match self.padding_scheme() {
219 PaddingScheme::Pkcs1v15(padding) => verifying_key.verify(padding, digest, &signature.0),
220 PaddingScheme::Pss256(padding) => verifying_key.verify(padding, digest, &signature.0),
221 PaddingScheme::Pss384(padding) => verifying_key.verify(padding, digest, &signature.0),
222 PaddingScheme::Pss512(padding) => verifying_key.verify(padding, digest, &signature.0),
223 };
224 verify_result.is_ok()
225 }
226}
227
228impl Rsa {
229 const fn new(hash_alg: HashAlg, padding_alg: Padding) -> Self {
230 Rsa {
231 hash_alg,
232 padding_alg,
233 }
234 }
235
236 pub const fn rs256() -> Rsa {
238 Rsa::new(HashAlg::Sha256, Padding::Pkcs1v15)
239 }
240
241 pub const fn rs384() -> Rsa {
243 Rsa::new(HashAlg::Sha384, Padding::Pkcs1v15)
244 }
245
246 pub const fn rs512() -> Rsa {
248 Rsa::new(HashAlg::Sha512, Padding::Pkcs1v15)
249 }
250
251 pub const fn ps256() -> Rsa {
253 Rsa::new(HashAlg::Sha256, Padding::Pss)
254 }
255
256 pub const fn ps384() -> Rsa {
258 Rsa::new(HashAlg::Sha384, Padding::Pss)
259 }
260
261 pub const fn ps512() -> Rsa {
263 Rsa::new(HashAlg::Sha512, Padding::Pss)
264 }
265
266 pub fn with_name(name: &str) -> Self {
273 name.parse().unwrap()
274 }
275
276 fn padding_scheme(self) -> PaddingScheme {
277 match self.padding_alg {
278 Padding::Pkcs1v15 => PaddingScheme::Pkcs1v15(match self.hash_alg {
279 HashAlg::Sha256 => Pkcs1v15Sign::new::<Sha256>(),
280 HashAlg::Sha384 => Pkcs1v15Sign::new::<Sha384>(),
281 HashAlg::Sha512 => Pkcs1v15Sign::new::<Sha512>(),
282 }),
283 Padding::Pss => {
284 match self.hash_alg {
287 HashAlg::Sha256 => {
288 PaddingScheme::Pss256(Pss::new_with_salt(Sha256::output_size()))
289 }
290 HashAlg::Sha384 => {
291 PaddingScheme::Pss384(Pss::new_with_salt(Sha384::output_size()))
292 }
293 HashAlg::Sha512 => {
294 PaddingScheme::Pss512(Pss::new_with_salt(Sha512::output_size()))
295 }
296 }
297 }
298 }
299 }
300
301 fn alg_name(self) -> &'static str {
302 match (self.padding_alg, self.hash_alg) {
303 (Padding::Pkcs1v15, HashAlg::Sha256) => "RS256",
304 (Padding::Pkcs1v15, HashAlg::Sha384) => "RS384",
305 (Padding::Pkcs1v15, HashAlg::Sha512) => "RS512",
306 (Padding::Pss, HashAlg::Sha256) => "PS256",
307 (Padding::Pss, HashAlg::Sha384) => "PS384",
308 (Padding::Pss, HashAlg::Sha512) => "PS512",
309 }
310 }
311
312 pub fn generate<R: CryptoRng>(
314 rng: &mut R,
315 modulus_bits: ModulusBits,
316 ) -> rsa::errors::Result<(StrongKey<RsaPrivateKey>, StrongKey<RsaPublicKey>)> {
317 let signing_key = RsaPrivateKey::new(rng, modulus_bits.bits())?;
318 let verifying_key = signing_key.to_public_key();
319 Ok((StrongKey(signing_key), StrongKey(verifying_key)))
320 }
321}
322
323impl FromStr for Rsa {
324 type Err = RsaParseError;
325
326 fn from_str(s: &str) -> Result<Self, Self::Err> {
327 Ok(match s {
328 "RS256" => Self::rs256(),
329 "RS384" => Self::rs384(),
330 "RS512" => Self::rs512(),
331 "PS256" => Self::ps256(),
332 "PS384" => Self::ps384(),
333 "PS512" => Self::ps512(),
334 _ => return Err(RsaParseError(s.to_owned())),
335 })
336 }
337}
338
/// Error returned when a string is not a recognized RSA algorithm name.
///
/// Holds the rejected input.
#[cfg_attr(docsrs, doc(cfg(feature = "rsa")))]
#[derive(Debug)]
pub struct RsaParseError(String);

impl fmt::Display for RsaParseError {
    fn fmt(&self, formatter: &mut fmt::Formatter<'_>) -> fmt::Result {
        let Self(name) = self;
        write!(formatter, "Invalid RSA algorithm name: {name}")
    }
}

#[cfg(feature = "std")]
impl std::error::Error for RsaParseError {}
352
353impl StrongKey<RsaPrivateKey> {
354 pub fn to_public_key(&self) -> StrongKey<RsaPublicKey> {
356 StrongKey(self.0.to_public_key())
357 }
358}
359
360impl TryFrom<RsaPrivateKey> for StrongKey<RsaPrivateKey> {
361 type Error = WeakKeyError<RsaPrivateKey>;
362
363 fn try_from(key: RsaPrivateKey) -> Result<Self, Self::Error> {
364 if ModulusBits::is_valid_bits(key.n().bits()) {
365 Ok(StrongKey(key))
366 } else {
367 Err(WeakKeyError(key))
368 }
369 }
370}
371
372impl TryFrom<RsaPublicKey> for StrongKey<RsaPublicKey> {
373 type Error = WeakKeyError<RsaPublicKey>;
374
375 fn try_from(key: RsaPublicKey) -> Result<Self, Self::Error> {
376 if ModulusBits::is_valid_bits(key.n().bits()) {
377 Ok(StrongKey(key))
378 } else {
379 Err(WeakKeyError(key))
380 }
381 }
382}
383
384impl<'a> From<&'a RsaPublicKey> for JsonWebKey<'a> {
385 fn from(key: &'a RsaPublicKey) -> JsonWebKey<'a> {
386 JsonWebKey::Rsa {
387 modulus: Cow::Owned(key.n().to_be_bytes_trimmed_vartime().into()),
388 public_exponent: Cow::Owned(key.e().to_be_bytes_trimmed_vartime().into()),
389 private_parts: None,
390 }
391 }
392}
393
394#[allow(clippy::cast_possible_truncation)] fn secret_uint_from_slice(slice: &[u8], precision: u32) -> Result<BoxedUint, JwkError> {
396 debug_assert!(precision <= RsaPublicKey::MAX_SIZE as u32);
397 BoxedUint::from_be_slice(slice, precision).map_err(|err| JwkError::custom(anyhow::anyhow!(err)))
398}
399
400fn secret_uint_to_slice(secret: &BoxedUint, precision: u32) -> SecretBytes<'static> {
402 let bytes = secret.to_be_bytes();
403 let precision_bytes = precision.div_ceil(8) as usize;
404 SecretBytes::owned_slice(if bytes.len() > precision_bytes {
405 let first_idx = bytes.len() - precision_bytes;
406 bytes[first_idx..].into()
407 } else {
408 bytes
409 })
410}
411
412fn pub_exponent_from_slice(slice: &[u8]) -> Result<BoxedUint, JwkError> {
413 BoxedUint::from_be_slice(slice, RsaPublicKey::MAX_PUB_EXPONENT.ilog2() + 1)
414 .map_err(|err| JwkError::custom(anyhow::anyhow!(err)))
415}
416
417impl TryFrom<&JsonWebKey<'_>> for RsaPublicKey {
418 type Error = JwkError;
419
420 fn try_from(jwk: &JsonWebKey<'_>) -> Result<Self, Self::Error> {
421 let JsonWebKey::Rsa {
422 modulus,
423 public_exponent,
424 ..
425 } = jwk
426 else {
427 return Err(JwkError::key_type(jwk, KeyType::Rsa));
428 };
429
430 let e = pub_exponent_from_slice(public_exponent)?;
431 let n = BoxedUint::from_be_slice_vartime(modulus);
432 Self::new(n, e).map_err(|err| JwkError::custom(anyhow::anyhow!(err)))
433 }
434}
435
436impl<'a> From<&'a RsaPrivateKey> for JsonWebKey<'a> {
442 fn from(key: &'a RsaPrivateKey) -> JsonWebKey<'a> {
443 const MSG: &str = "RsaPrivateKey must have at least 2 prime factors";
444
445 let p = key.primes().first().expect(MSG);
446 let q = key.primes().get(1).expect(MSG);
447 let precision = key.n().bits_precision();
452
453 let private_parts = RsaPrivateParts {
454 private_exponent: secret_uint_to_slice(key.d(), precision),
455 prime_factor_p: secret_uint_to_slice(p, precision),
456 prime_factor_q: secret_uint_to_slice(q, precision),
457 p_crt_exponent: None,
458 q_crt_exponent: None,
459 q_crt_coefficient: None,
460 other_prime_factors: key.primes()[2..]
461 .iter()
462 .map(|factor| RsaPrimeFactor {
463 factor: secret_uint_to_slice(factor, precision),
464 crt_exponent: None,
465 crt_coefficient: None,
466 })
467 .collect(),
468 };
469
470 JsonWebKey::Rsa {
471 modulus: Cow::Owned(key.n().to_be_bytes_trimmed_vartime().into()),
472 public_exponent: Cow::Owned(key.e().to_be_bytes_trimmed_vartime().into()),
473 private_parts: Some(private_parts),
474 }
475 }
476}
477
478impl TryFrom<&JsonWebKey<'_>> for RsaPrivateKey {
483 type Error = JwkError;
484
485 fn try_from(jwk: &JsonWebKey<'_>) -> Result<Self, Self::Error> {
486 let JsonWebKey::Rsa {
487 modulus,
488 public_exponent,
489 private_parts,
490 } = jwk
491 else {
492 return Err(JwkError::key_type(jwk, KeyType::Rsa));
493 };
494
495 let RsaPrivateParts {
496 private_exponent: d,
497 prime_factor_p,
498 prime_factor_q,
499 other_prime_factors,
500 ..
501 } = private_parts
502 .as_ref()
503 .ok_or_else(|| JwkError::NoField("d".into()))?;
504
505 let e = pub_exponent_from_slice(public_exponent)?;
506 let n = BoxedUint::from_be_slice_vartime(modulus);
507
508 let precision = n.bits().div_ceil(8) * 8;
510 if precision as usize > RsaPublicKey::MAX_SIZE {
511 return Err(JwkError::Custom(anyhow::anyhow!(
512 "Modulus precision ({got}) exceeds maximum supported value ({max})",
513 got = n.bits(),
514 max = RsaPublicKey::MAX_SIZE
515 )));
516 }
517
518 let d = secret_uint_from_slice(d, precision)?;
519 let mut factors = Vec::with_capacity(2 + other_prime_factors.len());
520 factors.push(secret_uint_from_slice(prime_factor_p, precision)?);
521 factors.push(secret_uint_from_slice(prime_factor_q, precision)?);
522 for other_factor in other_prime_factors {
523 factors.push(secret_uint_from_slice(&other_factor.factor, precision)?);
524 }
525
526 let key = Self::from_components(n, e, d, factors);
527 let key = key.map_err(|err| JwkError::custom(anyhow::anyhow!(err)))?;
528 key.validate()
529 .map_err(|err| JwkError::custom(anyhow::anyhow!(err)))?;
530 Ok(key)
531 }
532}