1use core::fmt;
37
38use serde::{Deserialize, Deserializer, Serialize, Serializer};
39use sha2::digest::{Digest, Output};
40
41use crate::{
42 alg::SecretBytes,
43 alloc::{Cow, String, ToString, Vec},
44};
45
/// Type of a [`JsonWebKey`], i.e. the value of its `kty` field.
#[non_exhaustive]
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum KeyType {
    /// RSA key (`kty` is `RSA`).
    Rsa,
    /// Elliptic-curve key (`kty` is `EC`).
    EllipticCurve,
    /// Symmetric key (`kty` is `oct`).
    Symmetric,
    /// Key pair (`kty` is `OKP`).
    KeyPair,
}

impl fmt::Display for KeyType {
    /// Writes the `kty` identifier corresponding to this key type.
    fn fmt(&self, formatter: &mut fmt::Formatter<'_>) -> fmt::Result {
        let kty = match self {
            Self::Rsa => "RSA",
            Self::EllipticCurve => "EC",
            Self::Symmetric => "oct",
            Self::KeyPair => "OKP",
        };
        formatter.write_str(kty)
    }
}
71
72#[derive(Debug)]
75#[non_exhaustive]
76pub enum JwkError {
77 NoField(String),
79 UnexpectedKeyType {
81 expected: KeyType,
83 actual: KeyType,
85 },
86 UnexpectedValue {
88 field: String,
90 expected: String,
92 actual: String,
94 },
95 UnexpectedLen {
97 field: String,
99 expected: usize,
101 actual: usize,
103 },
104 MismatchedKeys,
106 Custom(anyhow::Error),
108}
109
110impl fmt::Display for JwkError {
111 fn fmt(&self, formatter: &mut fmt::Formatter<'_>) -> fmt::Result {
112 match self {
113 Self::UnexpectedKeyType { expected, actual } => {
114 write!(
115 formatter,
116 "unexpected key type: {actual} (expected {expected})"
117 )
118 }
119 Self::NoField(field) => write!(formatter, "field `{field}` is absent from JWK"),
120 Self::UnexpectedValue {
121 field,
122 expected,
123 actual,
124 } => {
125 write!(
126 formatter,
127 "field `{field}` has unexpected value (expected: {expected}, got: {actual})"
128 )
129 }
130 Self::UnexpectedLen {
131 field,
132 expected,
133 actual,
134 } => {
135 write!(
136 formatter,
137 "field `{field}` has unexpected length (expected: {expected}, got: {actual})"
138 )
139 }
140 Self::MismatchedKeys => {
141 formatter.write_str("private and public keys encoded in JWK do not match")
142 }
143 Self::Custom(err) => fmt::Display::fmt(err, formatter),
144 }
145 }
146}
147
/// Exposes the error wrapped by [`JwkError::Custom`] as the source; all other variants
/// have no source.
///
/// NOTE(review): `Display` for `Custom` already prints the wrapped error's message, so
/// error-chain reporters may show that message twice — confirm whether this is intended.
#[cfg(feature = "std")]
impl std::error::Error for JwkError {
    fn source(&self) -> Option<&(dyn std::error::Error + 'static)> {
        if let Self::Custom(err) = self {
            Some(err.as_ref())
        } else {
            None
        }
    }
}
157
158impl JwkError {
159 pub fn custom(err: impl Into<anyhow::Error>) -> Self {
161 Self::Custom(err.into())
162 }
163
164 pub(crate) fn key_type(jwk: &JsonWebKey<'_>, expected: KeyType) -> Self {
165 let actual = jwk.key_type();
166 debug_assert_ne!(actual, expected);
167 Self::UnexpectedKeyType { actual, expected }
168 }
169}
170
171impl Serialize for SecretBytes<'_> {
172 fn serialize<S: Serializer>(&self, serializer: S) -> Result<S::Ok, S::Error> {
173 base64url::serialize(self.as_ref(), serializer)
174 }
175}
176
177impl<'de> Deserialize<'de> for SecretBytes<'_> {
178 fn deserialize<D: Deserializer<'de>>(deserializer: D) -> Result<Self, D::Error> {
179 base64url::deserialize(deserializer).map(SecretBytes::new)
180 }
181}
182
183#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
223#[serde(tag = "kty")]
224#[non_exhaustive]
225pub enum JsonWebKey<'a> {
226 #[serde(rename = "RSA")]
228 Rsa {
229 #[serde(rename = "n", with = "base64url")]
231 modulus: Cow<'a, [u8]>,
232 #[serde(rename = "e", with = "base64url")]
234 public_exponent: Cow<'a, [u8]>,
235 #[serde(flatten)]
237 private_parts: Option<RsaPrivateParts<'a>>,
238 },
239 #[serde(rename = "EC")]
241 EllipticCurve {
242 #[serde(rename = "crv")]
244 curve: Cow<'a, str>,
245 #[serde(with = "base64url")]
247 x: Cow<'a, [u8]>,
248 #[serde(with = "base64url")]
250 y: Cow<'a, [u8]>,
251 #[serde(rename = "d", default, skip_serializing_if = "Option::is_none")]
253 secret: Option<SecretBytes<'a>>,
254 },
255 #[serde(rename = "oct")]
257 Symmetric {
258 #[serde(rename = "k")]
260 secret: SecretBytes<'a>,
261 },
262 #[serde(rename = "OKP")]
264 KeyPair {
265 #[serde(rename = "crv")]
267 curve: Cow<'a, str>,
268 #[serde(with = "base64url")]
271 x: Cow<'a, [u8]>,
272 #[serde(rename = "d", default, skip_serializing_if = "Option::is_none")]
274 secret: Option<SecretBytes<'a>>,
275 },
276}
277
278impl JsonWebKey<'_> {
279 pub fn key_type(&self) -> KeyType {
281 match self {
282 Self::Rsa { .. } => KeyType::Rsa,
283 Self::EllipticCurve { .. } => KeyType::EllipticCurve,
284 Self::Symmetric { .. } => KeyType::Symmetric,
285 Self::KeyPair { .. } => KeyType::KeyPair,
286 }
287 }
288
289 pub fn is_signing_key(&self) -> bool {
291 match self {
292 Self::Rsa { private_parts, .. } => private_parts.is_some(),
293 Self::EllipticCurve { secret, .. } | Self::KeyPair { secret, .. } => secret.is_some(),
294 Self::Symmetric { .. } => true,
295 }
296 }
297
298 #[must_use]
300 pub fn to_verifying_key(&self) -> Self {
301 match self {
302 Self::Rsa {
303 modulus,
304 public_exponent,
305 ..
306 } => Self::Rsa {
307 modulus: modulus.clone(),
308 public_exponent: public_exponent.clone(),
309 private_parts: None,
310 },
311
312 Self::EllipticCurve { curve, x, y, .. } => Self::EllipticCurve {
313 curve: curve.clone(),
314 x: x.clone(),
315 y: y.clone(),
316 secret: None,
317 },
318
319 Self::Symmetric { secret } => Self::Symmetric {
320 secret: secret.clone(),
321 },
322
323 Self::KeyPair { curve, x, .. } => Self::KeyPair {
324 curve: curve.clone(),
325 x: x.clone(),
326 secret: None,
327 },
328 }
329 }
330
331 pub fn thumbprint<D: Digest>(&self) -> Output<D> {
336 let hashed_key = if self.is_signing_key() {
337 Cow::Owned(self.to_verifying_key())
338 } else {
339 Cow::Borrowed(self)
340 };
341 D::digest(hashed_key.to_string().as_bytes())
342 }
343}
344
345impl fmt::Display for JsonWebKey<'_> {
346 fn fmt(&self, formatter: &mut fmt::Formatter<'_>) -> fmt::Result {
348 let json_value = serde_json::to_value(self).expect("Cannot convert JsonWebKey to JSON");
349 let json_value = json_value.as_object().unwrap();
350 let mut json_entries: Vec<_> = json_value.iter().collect();
353 json_entries.sort_unstable_by(|(x, _), (y, _)| x.cmp(y));
354
355 formatter.write_str("{")?;
356 let field_count = json_entries.len();
357 for (i, (name, value)) in json_entries.into_iter().enumerate() {
358 write!(formatter, "\"{name}\":{value}")?;
359 if i + 1 < field_count {
360 formatter.write_str(",")?;
361 }
362 }
363 formatter.write_str("}")
364 }
365}
366
367#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
375pub struct RsaPrivateParts<'a> {
376 #[serde(rename = "d")]
378 pub private_exponent: SecretBytes<'a>,
379 #[serde(rename = "p")]
381 pub prime_factor_p: SecretBytes<'a>,
382 #[serde(rename = "q")]
384 pub prime_factor_q: SecretBytes<'a>,
385 #[serde(rename = "dp", default, skip_serializing_if = "Option::is_none")]
387 pub p_crt_exponent: Option<SecretBytes<'a>>,
388 #[serde(rename = "dq", default, skip_serializing_if = "Option::is_none")]
390 pub q_crt_exponent: Option<SecretBytes<'a>>,
391 #[serde(rename = "qi", default, skip_serializing_if = "Option::is_none")]
393 pub q_crt_coefficient: Option<SecretBytes<'a>>,
394 #[serde(rename = "oth", default, skip_serializing_if = "Vec::is_empty")]
396 pub other_prime_factors: Vec<RsaPrimeFactor<'a>>,
397}
398
399#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
407pub struct RsaPrimeFactor<'a> {
408 #[serde(rename = "r")]
410 pub factor: SecretBytes<'a>,
411 #[serde(rename = "d", default, skip_serializing_if = "Option::is_none")]
413 pub crt_exponent: Option<SecretBytes<'a>>,
414 #[serde(rename = "t", default, skip_serializing_if = "Option::is_none")]
416 pub crt_coefficient: Option<SecretBytes<'a>>,
417}
418
#[cfg(any(
    feature = "es256k",
    feature = "k256",
    feature = "p256",
    feature = "exonum-crypto",
    feature = "ed25519-dalek",
    feature = "ed25519-compact"
))]
mod helpers {
    //! Validation helpers for JWK conversions; compiled only when at least one of the
    //! crypto backend features listed on this module is enabled.

    use super::{JsonWebKey, JwkError};
    use crate::{Algorithm, alg::SigningKey, alloc::ToOwned};

    impl JsonWebKey<'_> {
        /// Checks that the curve name `curve` (the `crv` field) equals `expected`.
        ///
        /// # Errors
        ///
        /// Returns [`JwkError::UnexpectedValue`] for the `crv` field on a mismatch.
        pub(crate) fn ensure_curve(curve: &str, expected: &str) -> Result<(), JwkError> {
            if curve != expected {
                return Err(JwkError::UnexpectedValue {
                    field: "crv".to_owned(),
                    expected: expected.to_owned(),
                    actual: curve.to_owned(),
                });
            }
            Ok(())
        }

        /// Checks that the byte buffer stored in `field` is exactly `expected_len` bytes long.
        ///
        /// # Errors
        ///
        /// Returns [`JwkError::UnexpectedLen`] if the length differs.
        pub(crate) fn ensure_len(
            field: &str,
            bytes: &[u8],
            expected_len: usize,
        ) -> Result<(), JwkError> {
            let actual = bytes.len();
            if actual != expected_len {
                return Err(JwkError::UnexpectedLen {
                    field: field.to_owned(),
                    expected: expected_len,
                    actual,
                });
            }
            Ok(())
        }

        /// Checks that `signing_key` corresponds to the public key encoded in this JWK and
        /// returns the signing key back on success.
        ///
        /// # Errors
        ///
        /// Propagates errors from converting this JWK into `Alg::VerifyingKey`, and returns
        /// [`JwkError::MismatchedKeys`] if the two verifying keys differ.
        pub(crate) fn ensure_key_match<Alg, K>(&self, signing_key: K) -> Result<K, JwkError>
        where
            Alg: Algorithm<SigningKey = K>,
            K: SigningKey<Alg>,
            Alg::VerifyingKey: for<'jwk> TryFrom<&'jwk Self, Error = JwkError> + PartialEq,
        {
            let key_from_jwk = <Alg::VerifyingKey>::try_from(self)?;
            let key_from_signer = signing_key.to_verifying_key();
            if key_from_jwk == key_from_signer {
                Ok(signing_key)
            } else {
                Err(JwkError::MismatchedKeys)
            }
        }
    }
}
477
478mod base64url {
479 use core::fmt;
480
481 use base64ct::{Base64UrlUnpadded, Encoding};
482 use serde::{
483 Deserializer, Serializer,
484 de::{Error as DeError, Unexpected, Visitor},
485 };
486
487 use crate::alloc::{Cow, Vec};
488
489 pub fn serialize<S>(value: &[u8], serializer: S) -> Result<S::Ok, S::Error>
490 where
491 S: Serializer,
492 {
493 if serializer.is_human_readable() {
494 serializer.serialize_str(&Base64UrlUnpadded::encode_string(value))
495 } else {
496 serializer.serialize_bytes(value)
497 }
498 }
499
500 pub fn deserialize<'de, D>(deserializer: D) -> Result<Cow<'static, [u8]>, D::Error>
501 where
502 D: Deserializer<'de>,
503 {
504 struct Base64Visitor;
505
506 impl Visitor<'_> for Base64Visitor {
507 type Value = Vec<u8>;
508
509 fn expecting(&self, formatter: &mut fmt::Formatter<'_>) -> fmt::Result {
510 formatter.write_str("base64url-encoded data")
511 }
512
513 fn visit_str<E: DeError>(self, value: &str) -> Result<Self::Value, E> {
514 Base64UrlUnpadded::decode_vec(value)
515 .map_err(|_| E::invalid_value(Unexpected::Str(value), &self))
516 }
517
518 fn visit_bytes<E: DeError>(self, value: &[u8]) -> Result<Self::Value, E> {
519 Ok(value.to_vec())
520 }
521
522 fn visit_byte_buf<E: DeError>(self, value: Vec<u8>) -> Result<Self::Value, E> {
523 Ok(value)
524 }
525 }
526
527 struct BytesVisitor;
528
529 impl Visitor<'_> for BytesVisitor {
530 type Value = Vec<u8>;
531
532 fn expecting(&self, formatter: &mut fmt::Formatter<'_>) -> fmt::Result {
533 formatter.write_str("byte buffer")
534 }
535
536 fn visit_bytes<E: DeError>(self, value: &[u8]) -> Result<Self::Value, E> {
537 Ok(value.to_vec())
538 }
539
540 fn visit_byte_buf<E: DeError>(self, value: Vec<u8>) -> Result<Self::Value, E> {
541 Ok(value)
542 }
543 }
544
545 let maybe_bytes = if deserializer.is_human_readable() {
546 deserializer.deserialize_str(Base64Visitor)
547 } else {
548 deserializer.deserialize_bytes(BytesVisitor)
549 };
550 maybe_bytes.map(Cow::Owned)
551 }
552}
553
#[cfg(test)]
mod tests {
    use assert_matches::assert_matches;

    use super::*;
    use crate::alg::Hs256Key;

    /// Public `OKP` key on `Ed25519` with a dummy `x` value.
    fn sample_public_okp() -> JsonWebKey<'static> {
        JsonWebKey::KeyPair {
            curve: Cow::Borrowed("Ed25519"),
            x: Cow::Borrowed(b"test"),
            secret: None,
        }
    }

    #[test]
    fn serializing_jwk() {
        let original = sample_public_okp();

        let serialized = serde_json::to_value(&original).unwrap();
        let expected = serde_json::json!({ "crv": "Ed25519", "kty": "OKP", "x": "dGVzdA" });
        assert_eq!(serialized, expected);

        let roundtripped: JsonWebKey<'_> = serde_json::from_value(serialized).unwrap();
        assert_eq!(roundtripped, original);
    }

    #[test]
    fn jwk_deserialization_errors() {
        // Parses `json` as a JWK, expecting failure, and returns the error message.
        let parse_error = |json: &str| {
            serde_json::from_str::<JsonWebKey<'_>>(json)
                .unwrap_err()
                .to_string()
        };

        let no_kty = parse_error(r#"{"crv":"Ed25519"}"#);
        assert!(no_kty.contains("missing field `kty`"), "{no_kty}");

        let bad_base64 = parse_error(r#"{"crv":"Ed25519","kty":"OKP","x":"??"}"#);
        assert!(
            bad_base64.contains("invalid value: string \"??\""),
            "{bad_base64}"
        );
        assert!(bad_base64.contains("base64url-encoded data"), "{bad_base64}");
    }

    #[test]
    fn extra_jwk_fields() {
        #[derive(Debug, Serialize, Deserialize)]
        struct ExtendedJsonWebKey<'a, T> {
            #[serde(flatten)]
            base: JsonWebKey<'a>,
            #[serde(flatten)]
            extra: T,
        }

        #[derive(Debug, Deserialize)]
        struct Extra {
            #[serde(rename = "kid")]
            key_id: String,
            #[serde(rename = "use")]
            key_use: KeyUse,
        }

        #[derive(Debug, Deserialize, PartialEq)]
        enum KeyUse {
            #[serde(rename = "sig")]
            Signature,
            #[serde(rename = "enc")]
            Encryption,
        }

        let input = r#"
            { "kty": "oct", "kid": "my-unique-key", "k": "dGVzdA", "use": "sig" }
        "#;
        let parsed: ExtendedJsonWebKey<'_, Extra> = serde_json::from_str(input).unwrap();

        assert_matches!(&parsed.base, JsonWebKey::Symmetric { secret } if secret.as_ref() == b"test");
        assert_eq!(parsed.extra.key_id, "my-unique-key");
        assert_eq!(parsed.extra.key_use, KeyUse::Signature);

        let hs256_key = Hs256Key::try_from(&parsed.base).unwrap();
        let exported = JsonWebKey::from(&hs256_key);
        assert_matches!(
            exported,
            JsonWebKey::Symmetric { secret } if secret.as_ref() == b"test"
        );
    }

    #[test]
    #[cfg(feature = "ciborium")]
    fn jwk_with_cbor() {
        fn contains_subslice(haystack: &[u8], needle: &[u8]) -> bool {
            haystack.windows(needle.len()).any(|window| window == needle)
        }

        let original = JsonWebKey::KeyPair {
            curve: Cow::Borrowed("Ed25519"),
            x: Cow::Borrowed(b"public"),
            secret: Some(SecretBytes::borrowed(b"private")),
        };
        let mut encoded: Vec<u8> = vec![];
        ciborium::into_writer(&original, &mut encoded).unwrap();

        // Binary formats store key material as raw bytes rather than base64url strings.
        assert!(contains_subslice(&encoded, b"public"));
        assert!(contains_subslice(&encoded, b"private"));

        let decoded: JsonWebKey<'_> = ciborium::from_reader(&encoded[..]).unwrap();
        assert_eq!(decoded, original);
    }
}