1use core::{fmt, str::FromStr};
4
5use rand_core::{CryptoRng, RngCore};
6use rsa::{
7 BoxedUint, Pkcs1v15Sign, Pss,
8 traits::{PrivateKeyParts, PublicKeyParts},
9};
10pub use rsa::{RsaPrivateKey, RsaPublicKey, errors::Error as RsaError};
11use sha2::{Digest, Sha256, Sha384, Sha512};
12
13use crate::{
14 Algorithm, AlgorithmSignature,
15 alg::{SecretBytes, StrongKey, WeakKeyError},
16 alloc::{Cow, String, ToOwned, Vec},
17 jwk::{JsonWebKey, JwkError, KeyType, RsaPrimeFactor, RsaPrivateParts},
18};
19
20#[derive(Debug)]
22#[cfg_attr(docsrs, doc(cfg(feature = "rsa")))]
23pub struct RsaSignature(Vec<u8>);
24
25impl AlgorithmSignature for RsaSignature {
26 fn try_from_slice(bytes: &[u8]) -> anyhow::Result<Self> {
27 Ok(RsaSignature(bytes.to_vec()))
28 }
29
30 fn as_bytes(&self) -> Cow<'_, [u8]> {
31 Cow::Borrowed(&self.0)
32 }
33}
34
35#[derive(Debug, Copy, Clone, Eq, PartialEq)]
37enum HashAlg {
38 Sha256,
39 Sha384,
40 Sha512,
41}
42
43impl HashAlg {
44 fn digest(self, message: &[u8]) -> HashDigest {
45 match self {
46 Self::Sha256 => HashDigest::Sha256(Sha256::digest(message).into()),
47 Self::Sha384 => HashDigest::Sha384(Sha384::digest(message).into()),
48 Self::Sha512 => HashDigest::Sha512(Sha512::digest(message).into()),
49 }
50 }
51}
52
/// Digest produced by a `HashAlg`. Each variant holds a byte array of that hash's output size.
#[derive(Debug)]
enum HashDigest {
    /// 32-byte SHA-256 output.
    Sha256([u8; 32]),
    /// 48-byte SHA-384 output.
    Sha384([u8; 48]),
    /// 64-byte SHA-512 output.
    Sha512([u8; 64]),
}

impl AsRef<[u8]> for HashDigest {
    /// Borrows the digest bytes regardless of which hash produced them.
    fn as_ref(&self) -> &[u8] {
        let digest_bytes: &[u8] = match self {
            HashDigest::Sha256(array) => array.as_slice(),
            HashDigest::Sha384(array) => array.as_slice(),
            HashDigest::Sha512(array) => array.as_slice(),
        };
        digest_bytes
    }
}
70
/// Signature padding family: the `RS` vs. `PS` prefix of the JWS algorithm name.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
enum Padding {
    /// PKCS#1 v1.5 padding (`RS*` algorithms).
    Pkcs1v15,
    /// PSS padding (`PS*` algorithms).
    Pss,
}
77
78#[derive(Debug)]
79enum PaddingScheme {
80 Pkcs1v15(Pkcs1v15Sign),
81 Pss(Pss),
82}
83
/// Supported bit length of an RSA key modulus (aka RSA key length).
#[derive(Debug, Copy, Clone, Eq, PartialEq)]
#[non_exhaustive]
#[cfg_attr(docsrs, doc(cfg(feature = "rsa")))]
pub enum ModulusBits {
    /// 2048-bit modulus.
    TwoKibibytes,
    /// 3072-bit modulus.
    ThreeKibibytes,
    /// 4096-bit modulus.
    FourKibibytes,
}

impl ModulusBits {
    /// Returns the modulus length in bits.
    pub fn bits(self) -> usize {
        match self {
            Self::TwoKibibytes => 2_048,
            Self::ThreeKibibytes => 3_072,
            Self::FourKibibytes => 4_096,
        }
    }

    /// Checks whether `bits` is one of the supported modulus lengths.
    ///
    /// Delegates to the `TryFrom<usize>` impl so the set of supported lengths is defined in
    /// exactly one place. The `u32 -> usize` conversion is checked rather than cast, so an
    /// out-of-range value can never be truncated into a valid length.
    fn is_valid_bits(bits: u32) -> bool {
        usize::try_from(bits).is_ok_and(|bits| Self::try_from(bits).is_ok())
    }
}

impl TryFrom<usize> for ModulusBits {
    type Error = ModulusBitsError;

    /// Maps a bit length to the matching variant.
    ///
    /// # Errors
    ///
    /// Returns [`ModulusBitsError`] for any length other than 2048, 3072 or 4096.
    fn try_from(value: usize) -> Result<Self, Self::Error> {
        match value {
            2_048 => Ok(Self::TwoKibibytes),
            3_072 => Ok(Self::ThreeKibibytes),
            4_096 => Ok(Self::FourKibibytes),
            _ => Err(ModulusBitsError(())),
        }
    }
}

/// Error returned when converting an unsupported bit length into [`ModulusBits`].
#[derive(Debug)]
#[cfg_attr(docsrs, doc(cfg(feature = "rsa")))]
pub struct ModulusBitsError(());

impl fmt::Display for ModulusBitsError {
    fn fmt(&self, formatter: &mut fmt::Formatter<'_>) -> fmt::Result {
        formatter.write_str(
            "Unsupported bit length of RSA modulus; only lengths 2048, 3072 and 4096 \
             are supported.",
        )
    }
}

#[cfg(feature = "std")]
impl std::error::Error for ModulusBitsError {}
141
142#[derive(Debug, Clone, Copy, PartialEq, Eq)]
165#[cfg_attr(docsrs, doc(cfg(feature = "rsa")))]
166pub struct Rsa {
167 hash_alg: HashAlg,
168 padding_alg: Padding,
169}
170
171impl Algorithm for Rsa {
172 type SigningKey = RsaPrivateKey;
173 type VerifyingKey = RsaPublicKey;
174 type Signature = RsaSignature;
175
176 fn name(&self) -> Cow<'static, str> {
177 Cow::Borrowed(self.alg_name())
178 }
179
180 fn sign(&self, signing_key: &Self::SigningKey, message: &[u8]) -> Self::Signature {
181 let digest = self.hash_alg.digest(message);
182 let digest = digest.as_ref();
183 let signing_result = match self.padding_scheme() {
184 PaddingScheme::Pkcs1v15(padding) => signing_key.sign_with_rng(
185 &mut rand_core::UnwrapErr(rand_core::OsRng),
186 padding,
187 digest,
188 ),
189 PaddingScheme::Pss(padding) => signing_key.sign_with_rng(
190 &mut rand_core::UnwrapErr(rand_core::OsRng),
191 padding,
192 digest,
193 ),
194 };
195 RsaSignature(signing_result.expect("Unexpected RSA signature failure"))
196 }
197
198 fn verify_signature(
199 &self,
200 signature: &Self::Signature,
201 verifying_key: &Self::VerifyingKey,
202 message: &[u8],
203 ) -> bool {
204 let digest = self.hash_alg.digest(message);
205 let digest = digest.as_ref();
206 let verify_result = match self.padding_scheme() {
207 PaddingScheme::Pkcs1v15(padding) => verifying_key.verify(padding, digest, &signature.0),
208 PaddingScheme::Pss(padding) => verifying_key.verify(padding, digest, &signature.0),
209 };
210 verify_result.is_ok()
211 }
212}
213
214impl Rsa {
215 const fn new(hash_alg: HashAlg, padding_alg: Padding) -> Self {
216 Rsa {
217 hash_alg,
218 padding_alg,
219 }
220 }
221
222 pub const fn rs256() -> Rsa {
224 Rsa::new(HashAlg::Sha256, Padding::Pkcs1v15)
225 }
226
227 pub const fn rs384() -> Rsa {
229 Rsa::new(HashAlg::Sha384, Padding::Pkcs1v15)
230 }
231
232 pub const fn rs512() -> Rsa {
234 Rsa::new(HashAlg::Sha512, Padding::Pkcs1v15)
235 }
236
237 pub const fn ps256() -> Rsa {
239 Rsa::new(HashAlg::Sha256, Padding::Pss)
240 }
241
242 pub const fn ps384() -> Rsa {
244 Rsa::new(HashAlg::Sha384, Padding::Pss)
245 }
246
247 pub const fn ps512() -> Rsa {
249 Rsa::new(HashAlg::Sha512, Padding::Pss)
250 }
251
252 pub fn with_name(name: &str) -> Self {
259 name.parse().unwrap()
260 }
261
262 fn padding_scheme(self) -> PaddingScheme {
263 match self.padding_alg {
264 Padding::Pkcs1v15 => PaddingScheme::Pkcs1v15(match self.hash_alg {
265 HashAlg::Sha256 => Pkcs1v15Sign::new::<Sha256>(),
266 HashAlg::Sha384 => Pkcs1v15Sign::new::<Sha384>(),
267 HashAlg::Sha512 => Pkcs1v15Sign::new::<Sha512>(),
268 }),
269 Padding::Pss => {
270 PaddingScheme::Pss(match self.hash_alg {
273 HashAlg::Sha256 => Pss::new_with_salt::<Sha256>(Sha256::output_size()),
274 HashAlg::Sha384 => Pss::new_with_salt::<Sha384>(Sha384::output_size()),
275 HashAlg::Sha512 => Pss::new_with_salt::<Sha512>(Sha512::output_size()),
276 })
277 }
278 }
279 }
280
281 fn alg_name(self) -> &'static str {
282 match (self.padding_alg, self.hash_alg) {
283 (Padding::Pkcs1v15, HashAlg::Sha256) => "RS256",
284 (Padding::Pkcs1v15, HashAlg::Sha384) => "RS384",
285 (Padding::Pkcs1v15, HashAlg::Sha512) => "RS512",
286 (Padding::Pss, HashAlg::Sha256) => "PS256",
287 (Padding::Pss, HashAlg::Sha384) => "PS384",
288 (Padding::Pss, HashAlg::Sha512) => "PS512",
289 }
290 }
291
292 pub fn generate<R: CryptoRng + RngCore>(
294 rng: &mut R,
295 modulus_bits: ModulusBits,
296 ) -> rsa::errors::Result<(StrongKey<RsaPrivateKey>, StrongKey<RsaPublicKey>)> {
297 let signing_key = RsaPrivateKey::new(rng, modulus_bits.bits())?;
298 let verifying_key = signing_key.to_public_key();
299 Ok((StrongKey(signing_key), StrongKey(verifying_key)))
300 }
301}
302
303impl FromStr for Rsa {
304 type Err = RsaParseError;
305
306 fn from_str(s: &str) -> Result<Self, Self::Err> {
307 Ok(match s {
308 "RS256" => Self::rs256(),
309 "RS384" => Self::rs384(),
310 "RS512" => Self::rs512(),
311 "PS256" => Self::ps256(),
312 "PS384" => Self::ps384(),
313 "PS512" => Self::ps512(),
314 _ => return Err(RsaParseError(s.to_owned())),
315 })
316 }
317}
318
/// Error returned when parsing an unknown RSA algorithm name. It carries the rejected name.
#[derive(Debug)]
#[cfg_attr(docsrs, doc(cfg(feature = "rsa")))]
pub struct RsaParseError(String);

impl fmt::Display for RsaParseError {
    fn fmt(&self, formatter: &mut fmt::Formatter<'_>) -> fmt::Result {
        let rejected = &self.0;
        write!(formatter, "Invalid RSA algorithm name: {rejected}")
    }
}

#[cfg(feature = "std")]
impl std::error::Error for RsaParseError {}
332
333impl StrongKey<RsaPrivateKey> {
334 pub fn to_public_key(&self) -> StrongKey<RsaPublicKey> {
336 StrongKey(self.0.to_public_key())
337 }
338}
339
340impl TryFrom<RsaPrivateKey> for StrongKey<RsaPrivateKey> {
341 type Error = WeakKeyError<RsaPrivateKey>;
342
343 fn try_from(key: RsaPrivateKey) -> Result<Self, Self::Error> {
344 if ModulusBits::is_valid_bits(key.n().bits()) {
345 Ok(StrongKey(key))
346 } else {
347 Err(WeakKeyError(key))
348 }
349 }
350}
351
352impl TryFrom<RsaPublicKey> for StrongKey<RsaPublicKey> {
353 type Error = WeakKeyError<RsaPublicKey>;
354
355 fn try_from(key: RsaPublicKey) -> Result<Self, Self::Error> {
356 if ModulusBits::is_valid_bits(key.n().bits()) {
357 Ok(StrongKey(key))
358 } else {
359 Err(WeakKeyError(key))
360 }
361 }
362}
363
364impl<'a> From<&'a RsaPublicKey> for JsonWebKey<'a> {
365 fn from(key: &'a RsaPublicKey) -> JsonWebKey<'a> {
366 JsonWebKey::Rsa {
367 modulus: Cow::Owned(key.n().to_be_bytes_trimmed_vartime().into()),
368 public_exponent: Cow::Owned(key.e().to_be_bytes_trimmed_vartime().into()),
369 private_parts: None,
370 }
371 }
372}
373
374#[allow(clippy::cast_possible_truncation)] fn secret_uint_from_slice(slice: &[u8], precision: u32) -> Result<BoxedUint, JwkError> {
376 debug_assert!(precision <= RsaPublicKey::MAX_SIZE as u32);
377 BoxedUint::from_be_slice(slice, precision).map_err(|err| JwkError::custom(anyhow::anyhow!(err)))
378}
379
380fn secret_uint_to_slice(secret: &BoxedUint, precision: u32) -> SecretBytes<'static> {
382 let bytes = secret.to_be_bytes();
383 let precision_bytes = precision.div_ceil(8) as usize;
384 SecretBytes::owned_slice(if bytes.len() > precision_bytes {
385 let first_idx = bytes.len() - precision_bytes;
386 bytes[first_idx..].into()
387 } else {
388 bytes
389 })
390}
391
392fn pub_exponent_from_slice(slice: &[u8]) -> Result<BoxedUint, JwkError> {
393 BoxedUint::from_be_slice(slice, RsaPublicKey::MAX_PUB_EXPONENT.ilog2() + 1)
394 .map_err(|err| JwkError::custom(anyhow::anyhow!(err)))
395}
396
397impl TryFrom<&JsonWebKey<'_>> for RsaPublicKey {
398 type Error = JwkError;
399
400 fn try_from(jwk: &JsonWebKey<'_>) -> Result<Self, Self::Error> {
401 let JsonWebKey::Rsa {
402 modulus,
403 public_exponent,
404 ..
405 } = jwk
406 else {
407 return Err(JwkError::key_type(jwk, KeyType::Rsa));
408 };
409
410 let e = pub_exponent_from_slice(public_exponent)?;
411 let n = BoxedUint::from_be_slice_vartime(modulus);
412 Self::new(n, e).map_err(|err| JwkError::custom(anyhow::anyhow!(err)))
413 }
414}
415
416impl<'a> From<&'a RsaPrivateKey> for JsonWebKey<'a> {
422 fn from(key: &'a RsaPrivateKey) -> JsonWebKey<'a> {
423 const MSG: &str = "RsaPrivateKey must have at least 2 prime factors";
424
425 let p = key.primes().first().expect(MSG);
426 let q = key.primes().get(1).expect(MSG);
427 let precision = key.n().bits_precision();
432
433 let private_parts = RsaPrivateParts {
434 private_exponent: secret_uint_to_slice(key.d(), precision),
435 prime_factor_p: secret_uint_to_slice(p, precision),
436 prime_factor_q: secret_uint_to_slice(q, precision),
437 p_crt_exponent: None,
438 q_crt_exponent: None,
439 q_crt_coefficient: None,
440 other_prime_factors: key.primes()[2..]
441 .iter()
442 .map(|factor| RsaPrimeFactor {
443 factor: secret_uint_to_slice(factor, precision),
444 crt_exponent: None,
445 crt_coefficient: None,
446 })
447 .collect(),
448 };
449
450 JsonWebKey::Rsa {
451 modulus: Cow::Owned(key.n().to_be_bytes_trimmed_vartime().into()),
452 public_exponent: Cow::Owned(key.e().to_be_bytes_trimmed_vartime().into()),
453 private_parts: Some(private_parts),
454 }
455 }
456}
457
458impl TryFrom<&JsonWebKey<'_>> for RsaPrivateKey {
463 type Error = JwkError;
464
465 fn try_from(jwk: &JsonWebKey<'_>) -> Result<Self, Self::Error> {
466 let JsonWebKey::Rsa {
467 modulus,
468 public_exponent,
469 private_parts,
470 } = jwk
471 else {
472 return Err(JwkError::key_type(jwk, KeyType::Rsa));
473 };
474
475 let RsaPrivateParts {
476 private_exponent: d,
477 prime_factor_p,
478 prime_factor_q,
479 other_prime_factors,
480 ..
481 } = private_parts
482 .as_ref()
483 .ok_or_else(|| JwkError::NoField("d".into()))?;
484
485 let e = pub_exponent_from_slice(public_exponent)?;
486 let n = BoxedUint::from_be_slice_vartime(modulus);
487
488 let precision = n.bits().div_ceil(8) * 8;
490 if precision as usize > RsaPublicKey::MAX_SIZE {
491 return Err(JwkError::Custom(anyhow::anyhow!(
492 "Modulus precision ({got}) exceeds maximum supported value ({max})",
493 got = n.bits(),
494 max = RsaPublicKey::MAX_SIZE
495 )));
496 }
497
498 let d = secret_uint_from_slice(d, precision)?;
499 let mut factors = Vec::with_capacity(2 + other_prime_factors.len());
500 factors.push(secret_uint_from_slice(prime_factor_p, precision)?);
501 factors.push(secret_uint_from_slice(prime_factor_q, precision)?);
502 for other_factor in other_prime_factors {
503 factors.push(secret_uint_from_slice(&other_factor.factor, precision)?);
504 }
505
506 let key = Self::from_components(n, e, d, factors);
507 let key = key.map_err(|err| JwkError::custom(anyhow::anyhow!(err)))?;
508 key.validate()
509 .map_err(|err| JwkError::custom(anyhow::anyhow!(err)))?;
510 Ok(key)
511 }
512}