shared/
ntp.rs

1use anyhow::Context;
2use serde::{Deserialize, Serialize};
3use std::{
4    collections::HashMap,
5    net::SocketAddr,
6    sync::Arc,
7    time::{Duration, Instant},
8};
9use tokio::sync::{RwLock, RwLockReadGuard};
10use utoipa::ToSchema;
11
/// Clock offset measured against a single NTP server, stored in microseconds.
///
/// The value is taken from `sntpc`'s `NtpResult::offset()` when the servers are
/// queried. NOTE(review): the sign convention is sntpc's (presumably server time
/// minus local time) — confirm against the sntpc documentation.
#[derive(Debug, Clone, Copy, ToSchema, Deserialize, Serialize)]
pub struct NtpOffset {
    // Serialized under this exact field name; renaming it changes the API schema.
    offset_micros: i64,
}
16
17impl NtpOffset {
18    #[inline]
19    pub fn is_negative(self) -> bool {
20        self.offset_micros.is_negative()
21    }
22
23    #[inline]
24    pub fn abs_duration(self) -> Duration {
25        Duration::from_micros(self.offset_micros.unsigned_abs())
26    }
27}
28
29pub struct Ntp {
30    last_check: RwLock<Instant>,
31    last_result: RwLock<HashMap<SocketAddr, NtpOffset>>,
32}
33
34impl Ntp {
35    pub fn new() -> Arc<Self> {
36        let ntp = Arc::new(Self {
37            last_check: RwLock::new(Instant::now()),
38            last_result: RwLock::new(HashMap::new()),
39        });
40
41        tokio::spawn({
42            let ntp = ntp.clone();
43
44            async move {
45                let result = match check_ntp().await {
46                    Ok(result) => result,
47                    Err(err) => {
48                        tracing::error!("error while checking ntp time: {:?}", err);
49                        return;
50                    }
51                };
52
53                ntp.update_result(result).await;
54            }
55        });
56
57        ntp
58    }
59
60    async fn update_result(&self, result: HashMap<SocketAddr, NtpOffset>) {
61        *self.last_check.write().await = Instant::now();
62        *self.last_result.write().await = result;
63    }
64
65    pub async fn recheck_ntp(&self) -> Result<(), anyhow::Error> {
66        let result = check_ntp().await?;
67        self.update_result(result).await;
68
69        Ok(())
70    }
71
72    pub async fn get_last_result(&self) -> RwLockReadGuard<'_, HashMap<SocketAddr, NtpOffset>> {
73        self.last_result.read().await
74    }
75}
76
77pub async fn check_ntp() -> Result<HashMap<SocketAddr, NtpOffset>, anyhow::Error> {
78    let socket = tokio::net::UdpSocket::bind("0.0.0.0:0").await?;
79    let socket = sntpc_net_tokio::UdpSocketWrapper::from(socket);
80    let context = sntpc::NtpContext::new(sntpc::StdTimestampGen::default());
81
82    let pool_ntp_addrs = tokio::net::lookup_host(("pool.ntp.org", 123))
83        .await
84        .context("failed to resolve pool.ntp.org")?;
85
86    let get_pool_time = async |addr: SocketAddr| {
87        tokio::time::timeout(
88            Duration::from_secs(2),
89            sntpc::get_time(addr, &socket, context),
90        )
91        .await?
92        .map_err(|err| std::io::Error::other(format!("{:?}", err)))
93        .context("failed to get time from pool.ntp.org")
94    };
95
96    let mut result = HashMap::new();
97
98    for pool_ntp_addr in pool_ntp_addrs {
99        let pool_time = match get_pool_time(pool_ntp_addr).await {
100            Ok(time) => time,
101            Err(err) => {
102                tracing::warn!("failed to get time from {:?}: {:?}", pool_ntp_addr, err);
103                continue;
104            }
105        };
106
107        result.insert(
108            pool_ntp_addr,
109            NtpOffset {
110                offset_micros: pool_time.offset(),
111            },
112        );
113    }
114
115    Ok(result)
116}