1use anyhow::Context;
2use serde::{Deserialize, Serialize};
3use std::{
4 collections::HashMap,
5 net::SocketAddr,
6 sync::Arc,
7 time::{Duration, Instant},
8};
9use tokio::sync::{RwLock, RwLockReadGuard};
10use utoipa::ToSchema;
11
12#[derive(Debug, Clone, Copy, ToSchema, Deserialize, Serialize)]
13pub struct NtpOffset {
14 offset_micros: i64,
15}
16
17impl NtpOffset {
18 #[inline]
19 pub fn is_negative(self) -> bool {
20 self.offset_micros.is_negative()
21 }
22
23 #[inline]
24 pub fn abs_duration(self) -> Duration {
25 Duration::from_micros(self.offset_micros.unsigned_abs())
26 }
27}
28
29pub struct Ntp {
30 last_check: RwLock<Instant>,
31 last_result: RwLock<HashMap<SocketAddr, NtpOffset>>,
32}
33
34impl Ntp {
35 pub fn new() -> Arc<Self> {
36 let ntp = Arc::new(Self {
37 last_check: RwLock::new(Instant::now()),
38 last_result: RwLock::new(HashMap::new()),
39 });
40
41 tokio::spawn({
42 let ntp = ntp.clone();
43
44 async move {
45 let result = match check_ntp().await {
46 Ok(result) => result,
47 Err(err) => {
48 tracing::error!("error while checking ntp time: {:?}", err);
49 return;
50 }
51 };
52
53 ntp.update_result(result).await;
54 }
55 });
56
57 ntp
58 }
59
60 async fn update_result(&self, result: HashMap<SocketAddr, NtpOffset>) {
61 *self.last_check.write().await = Instant::now();
62 *self.last_result.write().await = result;
63 }
64
65 pub async fn recheck_ntp(&self) -> Result<(), anyhow::Error> {
66 let result = check_ntp().await?;
67 self.update_result(result).await;
68
69 Ok(())
70 }
71
72 pub async fn get_last_result(&self) -> RwLockReadGuard<'_, HashMap<SocketAddr, NtpOffset>> {
73 self.last_result.read().await
74 }
75}
76
77pub async fn check_ntp() -> Result<HashMap<SocketAddr, NtpOffset>, anyhow::Error> {
78 let socket = tokio::net::UdpSocket::bind("0.0.0.0:0").await?;
79 let socket = sntpc_net_tokio::UdpSocketWrapper::from(socket);
80 let context = sntpc::NtpContext::new(sntpc::StdTimestampGen::default());
81
82 let pool_ntp_addrs = tokio::net::lookup_host(("pool.ntp.org", 123))
83 .await
84 .context("failed to resolve pool.ntp.org")?;
85
86 let get_pool_time = async |addr: SocketAddr| {
87 tokio::time::timeout(
88 Duration::from_secs(2),
89 sntpc::get_time(addr, &socket, context),
90 )
91 .await?
92 .map_err(|err| std::io::Error::other(format!("{:?}", err)))
93 .context("failed to get time from pool.ntp.org")
94 };
95
96 let mut result = HashMap::new();
97
98 for pool_ntp_addr in pool_ntp_addrs {
99 let pool_time = match get_pool_time(pool_ntp_addr).await {
100 Ok(time) => time,
101 Err(err) => {
102 tracing::warn!("failed to get time from {:?}: {:?}", pool_ntp_addr, err);
103 continue;
104 }
105 };
106
107 result.insert(
108 pool_ntp_addr,
109 NtpOffset {
110 offset_micros: pool_time.offset(),
111 },
112 );
113 }
114
115 Ok(result)
116}