1use chrono::{DateTime, Duration, Utc};
2use serde::{Deserialize, Serialize};
3
4use crate::{Claim, ValidationError};
5
/// Time-related options used when creating and validating [`Claims`].
///
/// `F` is the clock: a function returning the current time. It defaults to a plain function
/// pointer, which is the type produced by the `from_leeway` constructor and the `Default` impl
/// (both available only with the `clock` crate feature). Options with an arbitrary clock can be
/// built via [`Self::new()`].
///
/// The struct is `#[non_exhaustive]`, so code outside this crate must use one of the
/// constructors instead of a struct literal.
#[derive(Debug, Clone, Copy)]
#[non_exhaustive]
pub struct TimeOptions<F = fn() -> DateTime<Utc>> {
    /// Tolerance applied when checking the expiration (`exp`) and maturity (`nbf`) claims.
    pub leeway: Duration,
    /// Function returning the current time; called when setting durations and when validating.
    pub clock_fn: F,
}
34
35impl<F: Fn() -> DateTime<Utc>> TimeOptions<F> {
36 pub fn new(leeway: Duration, clock_fn: F) -> Self {
38 Self { leeway, clock_fn }
39 }
40}
41
42impl TimeOptions {
43 #[cfg(feature = "clock")]
45 #[cfg_attr(docsrs, doc(cfg(feature = "clock")))]
46 pub fn from_leeway(leeway: Duration) -> Self {
47 Self {
48 leeway,
49 clock_fn: Utc::now,
50 }
51 }
52}
53
/// Default options: a 60-second leeway and [`Utc::now()`] as the clock.
#[cfg(feature = "clock")]
impl Default for TimeOptions {
    fn default() -> Self {
        let leeway = Duration::try_seconds(60).expect("60 seconds is a valid `Duration`");
        Self::from_leeway(leeway)
    }
}
63
/// Placeholder used as the custom-claims type when a token carries no custom claims.
///
/// It is a braced struct with no fields, so when flattened into [`Claims`] via
/// `#[serde(flatten)]` it contributes no additional keys.
#[derive(Debug, Clone, Copy, Default, PartialEq, Eq, Hash, Serialize, Deserialize)]
pub struct Empty {}
67
/// Claims set: optional time-related claims plus application-specific `custom` claims.
///
/// Each time-related field is (de)serialized under a short key (`exp`, `nbf`, `iat`) as a
/// number of seconds since the Unix epoch; sub-second precision is dropped when serializing,
/// and floating-point inputs are truncated to whole seconds when deserializing. A missing key
/// deserializes to `None`, and `None` fields are omitted when serializing. The fields of
/// `custom` are flattened into the same map as the time-related claims.
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
#[non_exhaustive]
pub struct Claims<T> {
    /// Expiration time (`exp`). Checked by [`Self::validate_expiration()`].
    #[serde(
        rename = "exp",
        default,
        skip_serializing_if = "Option::is_none",
        with = "self::serde_timestamp"
    )]
    pub expiration: Option<DateTime<Utc>>,

    /// Earliest time the claims are valid (`nbf`). Checked by [`Self::validate_maturity()`].
    #[serde(
        rename = "nbf",
        default,
        skip_serializing_if = "Option::is_none",
        with = "self::serde_timestamp"
    )]
    pub not_before: Option<DateTime<Utc>>,

    /// Issuance time (`iat`). Set by [`Self::set_duration_and_issuance()`]; not validated here.
    #[serde(
        rename = "iat",
        default,
        skip_serializing_if = "Option::is_none",
        with = "self::serde_timestamp"
    )]
    pub issued_at: Option<DateTime<Utc>>,

    /// Application-specific claims, flattened into the top-level claims map.
    #[serde(flatten)]
    pub custom: T,
}
109
110impl Claims<Empty> {
111 pub fn empty() -> Self {
113 Self {
114 expiration: None,
115 not_before: None,
116 issued_at: None,
117 custom: Empty {},
118 }
119 }
120}
121
122impl<T> Claims<T> {
123 pub fn new(custom_claims: T) -> Self {
125 Self {
126 expiration: None,
127 not_before: None,
128 issued_at: None,
129 custom: custom_claims,
130 }
131 }
132
133 #[must_use]
136 pub fn set_duration<F>(self, options: &TimeOptions<F>, duration: Duration) -> Self
137 where
138 F: Fn() -> DateTime<Utc>,
139 {
140 Self {
141 expiration: Some((options.clock_fn)() + duration),
142 ..self
143 }
144 }
145
146 #[must_use]
149 pub fn set_duration_and_issuance<F>(self, options: &TimeOptions<F>, duration: Duration) -> Self
150 where
151 F: Fn() -> DateTime<Utc>,
152 {
153 let issued_at = (options.clock_fn)();
154 Self {
155 expiration: Some(issued_at + duration),
156 issued_at: Some(issued_at),
157 ..self
158 }
159 }
160
161 #[must_use]
163 pub fn set_not_before(self, moment: DateTime<Utc>) -> Self {
164 Self {
165 not_before: Some(moment),
166 ..self
167 }
168 }
169
170 pub fn validate_expiration<F>(&self, options: &TimeOptions<F>) -> Result<&Self, ValidationError>
175 where
176 F: Fn() -> DateTime<Utc>,
177 {
178 self.expiration.map_or(
179 Err(ValidationError::NoClaim(Claim::Expiration)),
180 |expiration| {
181 let expiration_with_leeway = expiration
182 .checked_add_signed(options.leeway)
183 .unwrap_or(DateTime::<Utc>::MAX_UTC);
184 if (options.clock_fn)() > expiration_with_leeway {
185 Err(ValidationError::Expired)
186 } else {
187 Ok(self)
188 }
189 },
190 )
191 }
192
193 pub fn validate_maturity<F>(&self, options: &TimeOptions<F>) -> Result<&Self, ValidationError>
198 where
199 F: Fn() -> DateTime<Utc>,
200 {
201 self.not_before.map_or(
202 Err(ValidationError::NoClaim(Claim::NotBefore)),
203 |not_before| {
204 if (options.clock_fn)() < not_before - options.leeway {
205 Err(ValidationError::NotMature)
206 } else {
207 Ok(self)
208 }
209 },
210 )
211 }
212}
213
214mod serde_timestamp {
215 use core::fmt;
216
217 use chrono::{DateTime, Utc, offset::TimeZone};
218 use serde::{
219 Deserializer, Serializer,
220 de::{Error as DeError, Visitor},
221 };
222
223 struct TimestampVisitor;
224
225 impl Visitor<'_> for TimestampVisitor {
226 type Value = DateTime<Utc>;
227
228 fn expecting(&self, formatter: &mut fmt::Formatter<'_>) -> fmt::Result {
229 formatter.write_str("UTC timestamp")
230 }
231
232 fn visit_i64<E>(self, value: i64) -> Result<Self::Value, E>
233 where
234 E: DeError,
235 {
236 Utc.timestamp_opt(value, 0)
237 .single()
238 .ok_or_else(|| E::custom("UTC timestamp overflow"))
239 }
240
241 fn visit_u64<E>(self, value: u64) -> Result<Self::Value, E>
242 where
243 E: DeError,
244 {
245 let value = i64::try_from(value).map_err(DeError::custom)?;
246 Utc.timestamp_opt(value, 0)
247 .single()
248 .ok_or_else(|| E::custom("UTC timestamp overflow"))
249 }
250
251 #[allow(clippy::cast_possible_truncation)]
252 fn visit_f64<E>(self, value: f64) -> Result<Self::Value, E>
254 where
255 E: DeError,
256 {
257 Utc.timestamp_opt(value as i64, 0)
258 .single()
259 .ok_or_else(|| E::custom("UTC timestamp overflow"))
260 }
261 }
262
263 #[allow(unknown_lints, clippy::ref_option)] pub fn serialize<S: Serializer>(
265 time: &Option<DateTime<Utc>>,
266 serializer: S,
267 ) -> Result<S::Ok, S::Error> {
268 serializer.serialize_i64(time.unwrap().timestamp())
270 }
271
272 pub fn deserialize<'de, D: Deserializer<'de>>(
273 deserializer: D,
274 ) -> Result<Option<DateTime<Utc>>, D::Error> {
275 deserializer.deserialize_i64(TimestampVisitor).map(Some)
276 }
277}
278
// Unit tests for claims serialization and time-based validation. They rely on the real clock
// (`Utc::now`), hence the `clock` feature gate.
#[cfg(all(test, feature = "clock"))]
mod tests {
    use assert_matches::assert_matches;
    use chrono::TimeZone;

    use super::*;

    // Serializing to JSON must succeed both with and without the optional time claims.
    #[test]
    fn empty_claims_can_be_serialized() {
        let mut claims = Claims::empty();
        assert!(serde_json::to_string(&claims).is_ok());
        claims.expiration = Some(Utc::now());
        assert!(serde_json::to_string(&claims).is_ok());
        claims.not_before = Some(Utc::now());
        assert!(serde_json::to_string(&claims).is_ok());
    }

    // Same as above, but for CBOR via `ciborium` (only when that feature is enabled).
    #[test]
    #[cfg(feature = "ciborium")]
    fn empty_claims_can_be_serialized_to_cbor() {
        let mut claims = Claims::empty();
        assert!(ciborium::into_writer(&claims, &mut vec![]).is_ok());
        claims.expiration = Some(Utc::now());
        assert!(ciborium::into_writer(&claims, &mut vec![]).is_ok());
        claims.not_before = Some(Utc::now());
        assert!(ciborium::into_writer(&claims, &mut vec![]).is_ok());
    }

    #[test]
    fn expired_claim() {
        let mut claims = Claims::empty();
        // Default options: 60-second leeway, real clock.
        let time_options = TimeOptions::default();
        assert_matches!(
            claims.validate_expiration(&time_options).unwrap_err(),
            ValidationError::NoClaim(Claim::Expiration)
        );

        // `exp + leeway` overflows here; validation must saturate rather than panic.
        claims.expiration = Some(DateTime::<Utc>::MAX_UTC);
        assert!(claims.validate_expiration(&time_options).is_ok());

        // An hour in the past is well beyond the 60-second default leeway.
        claims.expiration = Some(Utc::now() - Duration::try_hours(1).unwrap());
        assert_matches!(
            claims.validate_expiration(&time_options).unwrap_err(),
            ValidationError::Expired
        );

        // 10 seconds in the past is within the default leeway...
        claims.expiration = Some(Utc::now() - Duration::try_seconds(10).unwrap());
        assert!(claims.validate_expiration(&time_options).is_ok());
        // ...but not within a 5-second leeway.
        assert_matches!(
            claims
                .validate_expiration(&TimeOptions::from_leeway(Duration::try_seconds(5).unwrap()))
                .unwrap_err(),
            ValidationError::Expired
        );
        // A custom clock pinned to exactly `exp`: the current time is not strictly after
        // `exp + leeway`, so validation passes.
        let expiration = claims.expiration.unwrap();
        assert!(
            claims
                .validate_expiration(&TimeOptions::new(
                    Duration::try_seconds(3).unwrap(),
                    move || { expiration }
                ))
                .is_ok()
        );
    }

    #[test]
    fn immature_claim() {
        let mut claims = Claims::empty();
        // Default options: 60-second leeway, real clock.
        let time_options = TimeOptions::default();
        assert_matches!(
            claims.validate_maturity(&time_options).unwrap_err(),
            ValidationError::NoClaim(Claim::NotBefore)
        );

        // An hour in the future is well beyond the 60-second default leeway.
        claims.not_before = Some(Utc::now() + Duration::try_hours(1).unwrap());
        assert_matches!(
            claims.validate_maturity(&time_options).unwrap_err(),
            ValidationError::NotMature
        );

        // 10 seconds in the future is within the default leeway...
        claims.not_before = Some(Utc::now() + Duration::try_seconds(10).unwrap());
        assert!(claims.validate_maturity(&time_options).is_ok());
        // ...but not within a 5-second leeway.
        assert_matches!(
            claims
                .validate_maturity(&TimeOptions::from_leeway(Duration::try_seconds(5).unwrap()))
                .unwrap_err(),
            ValidationError::NotMature
        );
    }

    // Timestamps written in floating-point notation are accepted and truncated to seconds.
    #[test]
    fn float_timestamp() {
        let claims = "{\"exp\": 1.691203462e+9}";
        let claims: Claims<Empty> = serde_json::from_str(claims).unwrap();
        let timestamp = Utc.timestamp_opt(1_691_203_462, 0).single().unwrap();
        assert_eq!(claims.expiration, Some(timestamp));
    }

    // Floats beyond the representable range must produce a deserialization error, not a panic.
    #[test]
    fn float_timestamp_errors() {
        let invalid_claims = ["{\"exp\": 1e20}", "{\"exp\": -1e20}"];
        for claims in invalid_claims {
            let err = serde_json::from_str::<Claims<Empty>>(claims).unwrap_err();
            let err = err.to_string();
            assert!(err.contains("UTC timestamp overflow"), "{err}");
        }
    }
}