use chrono::{DateTime, Duration, Utc};
use serde::{Deserialize, Serialize};
use crate::{Claim, ValidationError};
/// Time-related options for creating and validating claims.
///
/// Holds the tolerance (`leeway`) applied when comparing the current time with the
/// `exp` / `nbf` claims, and the function (`clock_fn`) used to obtain the current time.
/// Because the struct is `#[non_exhaustive]`, code outside this crate cannot build it with a
/// struct literal and must use [`TimeOptions::new`] (or `from_leeway` / `Default` when the
/// `clock` feature is enabled).
#[derive(Debug, Clone, Copy)]
#[non_exhaustive]
pub struct TimeOptions<F = fn() -> DateTime<Utc>> {
/// Tolerance added to `exp` and subtracted from `nbf` during validation.
pub leeway: Duration,
/// Function returning the current time; called when setting and validating time claims.
pub clock_fn: F,
}
impl<F> TimeOptions<F>
where
    F: Fn() -> DateTime<Utc>,
{
    /// Creates options from an explicit `leeway` and a custom `clock_fn` that reports the
    /// current time (useful e.g. for deterministic tests).
    pub fn new(leeway: Duration, clock_fn: F) -> Self {
        Self { clock_fn, leeway }
    }
}
impl TimeOptions {
    /// Creates options with the given `leeway` whose clock is the system clock
    /// ([`Utc::now`]).
    #[cfg(feature = "clock")]
    #[cfg_attr(docsrs, doc(cfg(feature = "clock")))]
    pub fn from_leeway(leeway: Duration) -> Self {
        let system_clock: fn() -> DateTime<Utc> = Utc::now;
        Self {
            clock_fn: system_clock,
            leeway,
        }
    }
}
/// Default options: the system clock with a leeway of 60 seconds.
#[cfg(feature = "clock")]
impl Default for TimeOptions {
    fn default() -> Self {
        let leeway = Duration::try_seconds(60).unwrap();
        Self::from_leeway(leeway)
    }
}
/// Custom-claims placeholder for tokens that carry no application-specific fields.
///
/// Being a struct with no fields, it contributes nothing when flattened into [`Claims`].
#[derive(Debug, Clone, Copy, Default, PartialEq, Eq, Hash, Serialize, Deserialize)]
pub struct Empty {}
/// Token claims: the optional standard time claims plus an application-defined part `T`.
///
/// Time claims are (de)serialized through `serde_timestamp` as whole seconds since the Unix
/// epoch. They are omitted from the output when `None` and default to `None` when absent
/// from the input. The fields of `custom` are flattened into the same map as the time claims.
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
#[non_exhaustive]
pub struct Claims<T> {
/// Expiration time (`exp` claim); checked by `validate_expiration`.
#[serde(
rename = "exp",
default,
skip_serializing_if = "Option::is_none",
with = "self::serde_timestamp"
)]
pub expiration: Option<DateTime<Utc>>,
/// Earliest time the token may be used (`nbf` claim); checked by `validate_maturity`.
#[serde(
rename = "nbf",
default,
skip_serializing_if = "Option::is_none",
with = "self::serde_timestamp"
)]
pub not_before: Option<DateTime<Utc>>,
/// Issuance time (`iat` claim); set by `set_duration_and_issuance`.
#[serde(
rename = "iat",
default,
skip_serializing_if = "Option::is_none",
with = "self::serde_timestamp"
)]
pub issued_at: Option<DateTime<Utc>>,
/// Application-specific claims, flattened into the top-level claims map.
#[serde(flatten)]
pub custom: T,
}
impl Claims<Empty> {
    /// Creates claims with no time claims (`exp`, `nbf`, `iat` all unset) and an [`Empty`]
    /// custom part.
    pub fn empty() -> Self {
        Self::new(Empty {})
    }
}
impl<T> Claims<T> {
    /// Creates claims with the given custom part and no time claims set.
    pub fn new(custom_claims: T) -> Self {
        Self {
            expiration: None,
            not_before: None,
            issued_at: None,
            custom: custom_claims,
        }
    }

    /// Sets the `exp` claim to the current time (from `options.clock_fn`) plus `duration`.
    /// All other claims are left unchanged.
    ///
    /// # Panics
    ///
    /// Panics if the resulting instant is outside the range of `DateTime<Utc>`
    /// (chrono's `Add<Duration>` behavior).
    #[must_use]
    pub fn set_duration<F>(self, options: &TimeOptions<F>, duration: Duration) -> Self
    where
        F: Fn() -> DateTime<Utc>,
    {
        Self {
            expiration: Some((options.clock_fn)() + duration),
            ..self
        }
    }

    /// Sets `iat` to the current time (from `options.clock_fn`) and `exp` to that same
    /// instant plus `duration`. The clock is read once, so both claims are consistent.
    ///
    /// # Panics
    ///
    /// Panics if `iat + duration` is outside the range of `DateTime<Utc>`.
    #[must_use]
    pub fn set_duration_and_issuance<F>(self, options: &TimeOptions<F>, duration: Duration) -> Self
    where
        F: Fn() -> DateTime<Utc>,
    {
        let issued_at = (options.clock_fn)();
        Self {
            expiration: Some(issued_at + duration),
            issued_at: Some(issued_at),
            ..self
        }
    }

    /// Sets the `nbf` claim to `moment`.
    #[must_use]
    pub fn set_not_before(self, moment: DateTime<Utc>) -> Self {
        Self {
            not_before: Some(moment),
            ..self
        }
    }

    /// Checks that the `exp` claim is present and that the current time does not exceed it
    /// by more than `options.leeway`.
    ///
    /// # Errors
    ///
    /// - [`ValidationError::NoClaim`] with [`Claim::Expiration`] if `exp` is not set.
    /// - [`ValidationError::Expired`] if `now > exp + leeway`.
    pub fn validate_expiration<F>(&self, options: &TimeOptions<F>) -> Result<&Self, ValidationError>
    where
        F: Fn() -> DateTime<Utc>,
    {
        let expiration = self
            .expiration
            .ok_or(ValidationError::NoClaim(Claim::Expiration))?;
        // `exp` may come from an untrusted token; saturate rather than panic when adding the
        // leeway would leave the representable range.
        let expiration_with_leeway = expiration
            .checked_add_signed(options.leeway)
            .unwrap_or(DateTime::<Utc>::MAX_UTC);
        if (options.clock_fn)() > expiration_with_leeway {
            Err(ValidationError::Expired)
        } else {
            Ok(self)
        }
    }

    /// Checks that the `nbf` claim is present and that the current time is no earlier than
    /// `nbf - options.leeway`.
    ///
    /// # Errors
    ///
    /// - [`ValidationError::NoClaim`] with [`Claim::NotBefore`] if `nbf` is not set.
    /// - [`ValidationError::NotMature`] if `now < nbf - leeway`.
    pub fn validate_maturity<F>(&self, options: &TimeOptions<F>) -> Result<&Self, ValidationError>
    where
        F: Fn() -> DateTime<Utc>,
    {
        let not_before = self
            .not_before
            .ok_or(ValidationError::NoClaim(Claim::NotBefore))?;
        // Plain `not_before - leeway` panics on underflow, and `nbf` may come from an
        // untrusted token. Saturate to the earliest representable instant instead,
        // mirroring `validate_expiration`.
        let not_before_with_leeway = not_before
            .checked_sub_signed(options.leeway)
            .unwrap_or(DateTime::<Utc>::MIN_UTC);
        if (options.clock_fn)() < not_before_with_leeway {
            Err(ValidationError::NotMature)
        } else {
            Ok(self)
        }
    }
}
/// (De)serialization of optional timestamps as whole seconds since the Unix epoch, for use
/// with `#[serde(with = "...")]` on `Option<DateTime<Utc>>` fields.
mod serde_timestamp {
    use core::fmt;

    use chrono::{offset::TimeZone, DateTime, Utc};
    use serde::{
        de::{Error as DeError, Visitor},
        Deserializer, Serializer,
    };

    /// Error message for timestamps outside the range representable by `DateTime<Utc>`.
    const OVERFLOW_MESSAGE: &str = "UTC timestamp overflow";

    /// Converts whole seconds since the Unix epoch into a UTC instant, failing with
    /// [`OVERFLOW_MESSAGE`] when the value is out of range.
    fn from_secs<E: DeError>(secs: i64) -> Result<DateTime<Utc>, E> {
        Utc.timestamp_opt(secs, 0)
            .single()
            .ok_or_else(|| E::custom(OVERFLOW_MESSAGE))
    }

    /// Visitor accepting signed, unsigned, or floating-point numbers of seconds.
    struct TimestampVisitor;

    impl<'de> Visitor<'de> for TimestampVisitor {
        type Value = DateTime<Utc>;

        fn expecting(&self, formatter: &mut fmt::Formatter<'_>) -> fmt::Result {
            formatter.write_str("UTC timestamp")
        }

        fn visit_i64<E: DeError>(self, value: i64) -> Result<Self::Value, E> {
            from_secs(value)
        }

        fn visit_u64<E: DeError>(self, value: u64) -> Result<Self::Value, E> {
            // Anything above `i64::MAX` is necessarily out of range; report it the same way
            // as every other overflow instead of leaking the integer-conversion error text.
            let secs = i64::try_from(value).map_err(|_| E::custom(OVERFLOW_MESSAGE))?;
            from_secs(secs)
        }

        // `as i64` truncates the fractional part and saturates out-of-range values to
        // `i64::MIN` / `i64::MAX`, which `from_secs` then rejects as overflow. The one value
        // it silently mangles is NaN (cast to 0, i.e. the epoch), so NaN is rejected first.
        #[allow(clippy::cast_possible_truncation)]
        fn visit_f64<E: DeError>(self, value: f64) -> Result<Self::Value, E> {
            if value.is_nan() {
                return Err(E::custom("UTC timestamp is NaN"));
            }
            from_secs(value as i64)
        }
    }

    /// Serializes the timestamp as an `i64` number of seconds since the Unix epoch.
    ///
    /// `Claims` skips `None` values via `skip_serializing_if`, but `None` is still serialized
    /// faithfully here instead of panicking if this helper is used without that guard.
    pub fn serialize<S: Serializer>(
        time: &Option<DateTime<Utc>>,
        serializer: S,
    ) -> Result<S::Ok, S::Error> {
        match time {
            Some(time) => serializer.serialize_i64(time.timestamp()),
            None => serializer.serialize_none(),
        }
    }

    /// Deserializes a numeric timestamp (seconds since the Unix epoch) into `Some(instant)`.
    /// Missing fields are handled by `#[serde(default)]` on the field, not here.
    pub fn deserialize<'de, D: Deserializer<'de>>(
        deserializer: D,
    ) -> Result<Option<DateTime<Utc>>, D::Error> {
        deserializer.deserialize_i64(TimestampVisitor).map(Some)
    }
}
// Unit tests for claims (de)serialization and time-based validation. They need the `clock`
// feature because they rely on `TimeOptions::default()` / `Utc::now()`.
#[cfg(all(test, feature = "clock"))]
mod tests {
use assert_matches::assert_matches;
use chrono::TimeZone;
use super::*;
// JSON serialization must succeed with and without the optional `exp` / `nbf` claims.
#[test]
fn empty_claims_can_be_serialized() {
let mut claims = Claims::empty();
assert!(serde_json::to_string(&claims).is_ok());
claims.expiration = Some(Utc::now());
assert!(serde_json::to_string(&claims).is_ok());
claims.not_before = Some(Utc::now());
assert!(serde_json::to_string(&claims).is_ok());
}
// Same as above for CBOR; only compiled when the `ciborium` feature is enabled.
#[test]
#[cfg(feature = "ciborium")]
fn empty_claims_can_be_serialized_to_cbor() {
let mut claims = Claims::empty();
assert!(ciborium::into_writer(&claims, &mut vec![]).is_ok());
claims.expiration = Some(Utc::now());
assert!(ciborium::into_writer(&claims, &mut vec![]).is_ok());
claims.not_before = Some(Utc::now());
assert!(ciborium::into_writer(&claims, &mut vec![]).is_ok());
}
// `exp` validation: missing claim, the maximal instant (adding the leeway must not panic),
// a clearly expired token, and tokens just inside / outside the leeway window.
#[test]
fn expired_claim() {
let mut claims = Claims::empty();
let time_options = TimeOptions::default();
assert_matches!(
claims.validate_expiration(&time_options).unwrap_err(),
ValidationError::NoClaim(Claim::Expiration)
);
claims.expiration = Some(DateTime::<Utc>::MAX_UTC);
assert!(claims.validate_expiration(&time_options).is_ok());
claims.expiration = Some(Utc::now() - Duration::try_hours(1).unwrap());
assert_matches!(
claims.validate_expiration(&time_options).unwrap_err(),
ValidationError::Expired
);
// Expired 10 s ago: tolerated by the default 60 s leeway...
claims.expiration = Some(Utc::now() - Duration::try_seconds(10).unwrap());
assert!(claims.validate_expiration(&time_options).is_ok());
// ...but not by a 5 s leeway.
assert_matches!(
claims
.validate_expiration(&TimeOptions::from_leeway(Duration::try_seconds(5).unwrap()))
.unwrap_err(),
ValidationError::Expired
);
// A fixed clock reporting exactly the expiration instant is not considered expired.
let expiration = claims.expiration.unwrap();
assert!(claims
.validate_expiration(&TimeOptions::new(
Duration::try_seconds(3).unwrap(),
move || { expiration }
))
.is_ok());
}
// `nbf` validation: missing claim, a token maturing in an hour, and a token maturing in
// 10 s, which is accepted with the default 60 s leeway but rejected with a 5 s leeway.
#[test]
fn immature_claim() {
let mut claims = Claims::empty();
let time_options = TimeOptions::default();
assert_matches!(
claims.validate_maturity(&time_options).unwrap_err(),
ValidationError::NoClaim(Claim::NotBefore)
);
claims.not_before = Some(Utc::now() + Duration::try_hours(1).unwrap());
assert_matches!(
claims.validate_maturity(&time_options).unwrap_err(),
ValidationError::NotMature
);
claims.not_before = Some(Utc::now() + Duration::try_seconds(10).unwrap());
assert!(claims.validate_maturity(&time_options).is_ok());
assert_matches!(
claims
.validate_maturity(&TimeOptions::from_leeway(Duration::try_seconds(5).unwrap()))
.unwrap_err(),
ValidationError::NotMature
);
}
// A JSON float in exponent notation is accepted as a timestamp in seconds.
#[test]
fn float_timestamp() {
let claims = "{\"exp\": 1.691203462e+9}";
let claims: Claims<Empty> = serde_json::from_str(claims).unwrap();
let timestamp = Utc.timestamp_opt(1_691_203_462, 0).single().unwrap();
assert_eq!(claims.expiration, Some(timestamp));
}
// Floats far outside the representable range (either sign) fail with an overflow error
// instead of being clamped to a valid date.
#[test]
fn float_timestamp_errors() {
let invalid_claims = ["{\"exp\": 1e20}", "{\"exp\": -1e20}"];
for claims in invalid_claims {
let err = serde_json::from_str::<Claims<Empty>>(claims).unwrap_err();
let err = err.to_string();
assert!(err.contains("UTC timestamp overflow"), "{err}");
}
}
}