shared/
env.rs

1use anyhow::Context;
2use axum::{extract::ConnectInfo, http::HeaderMap};
3use colored::Colorize;
4use dotenvy::dotenv;
5use std::sync::{Arc, atomic::AtomicBool};
6use tracing_subscriber::fmt::writer::MakeWriterExt;
7
/// How the panel connects to Redis, selected by the `REDIS_MODE` variable
/// (`redis` or `sentinel`).
#[derive(Clone)]
pub enum RedisMode {
    /// A single Redis endpoint (`REDIS_MODE=redis`).
    Redis {
        /// Value of `REDIS_URL`, if set.
        redis_url: Option<String>,
    },
    /// Redis Sentinel (`REDIS_MODE=sentinel`).
    Sentinel {
        /// Value of `REDIS_SENTINEL_CLUSTER`.
        cluster_name: String,
        /// Entries of the comma-separated `REDIS_SENTINELS` list.
        redis_sentinels: Vec<String>,
    },
}

/// Prints only the mode name (`Redis` or `Sentinel`); connection details such
/// as URLs or sentinel addresses are not included.
impl std::fmt::Display for RedisMode {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // `pad` (unlike `write!`) honours width/alignment flags such as `{:<10}`.
        f.pad(match self {
            RedisMode::Redis { .. } => "Redis",
            RedisMode::Sentinel { .. } => "Sentinel",
        })
    }
}
27
28pub struct EnvGuard(
29    pub Option<tracing_appender::non_blocking::WorkerGuard>,
30    pub tracing_appender::non_blocking::WorkerGuard,
31);
32
33pub struct Env {
34    pub redis_mode: RedisMode,
35
36    pub sentry_url: Option<String>,
37    pub database_migrate: bool,
38    pub database_url: String,
39    pub database_url_primary: Option<String>,
40
41    pub bind: String,
42    pub port: u16,
43
44    pub app_primary: bool,
45    pub app_debug: AtomicBool,
46    pub app_use_decryption_cache: bool,
47    pub app_use_internal_cache: bool,
48    pub app_trusted_proxies: Vec<cidr::IpCidr>,
49    pub app_log_directory: Option<String>,
50    pub app_encryption_key: String,
51    pub server_name: Option<String>,
52}
53
54impl Env {
55    pub fn parse() -> Result<(Arc<Self>, EnvGuard), anyhow::Error> {
56        dotenv().ok();
57
58        let env = Self {
59            redis_mode: match std::env::var("REDIS_MODE")
60                .unwrap_or("redis".to_string())
61                .trim_matches('"')
62            {
63                "redis" => RedisMode::Redis {
64                    redis_url: std::env::var("REDIS_URL")
65                        .ok()
66                        .map(|s| s.trim_matches('"').to_string()),
67                },
68                "sentinel" => RedisMode::Sentinel {
69                    cluster_name: std::env::var("REDIS_SENTINEL_CLUSTER")
70                        .context("REDIS_SENTINEL_CLUSTER is required")?
71                        .trim_matches('"')
72                        .to_string(),
73                    redis_sentinels: std::env::var("REDIS_SENTINELS")
74                        .context("REDIS_SENTINELS is required")?
75                        .trim_matches('"')
76                        .split(',')
77                        .map(|s| s.to_string())
78                        .collect(),
79                },
80                _ => {
81                    return Err(anyhow::anyhow!(
82                        "Invalid REDIS_MODE. Expected 'redis' or 'sentinel'."
83                    ));
84                }
85            },
86
87            sentry_url: std::env::var("SENTRY_URL")
88                .ok()
89                .map(|s| s.trim_matches('"').to_string()),
90            database_migrate: std::env::var("DATABASE_MIGRATE")
91                .unwrap_or("false".to_string())
92                .trim_matches('"')
93                .parse()
94                .unwrap(),
95            database_url: std::env::var("DATABASE_URL")
96                .context("DATABASE_URL is required")?
97                .trim_matches('"')
98                .to_string(),
99            database_url_primary: std::env::var("DATABASE_URL_PRIMARY")
100                .ok()
101                .map(|s| s.trim_matches('"').to_string()),
102
103            bind: std::env::var("BIND")
104                .unwrap_or("0.0.0.0".to_string())
105                .trim_matches('"')
106                .to_string(),
107            port: std::env::var("PORT")
108                .unwrap_or("6969".to_string())
109                .parse()
110                .context("Invalid PORT value")?,
111
112            app_primary: std::env::var("APP_PRIMARY")
113                .unwrap_or("true".to_string())
114                .trim_matches('"')
115                .parse()
116                .context("Invalid APP_DEBUG value")?,
117            app_debug: AtomicBool::new(
118                std::env::var("APP_DEBUG")
119                    .unwrap_or("false".to_string())
120                    .trim_matches('"')
121                    .parse()
122                    .context("Invalid APP_DEBUG value")?,
123            ),
124            app_use_decryption_cache: std::env::var("APP_USE_DECRYPTION_CACHE")
125                .unwrap_or("false".to_string())
126                .trim_matches('"')
127                .parse()
128                .context("Invalid APP_USE_DECRYPTION_CACHE value")?,
129            app_use_internal_cache: std::env::var("APP_USE_INTERNAL_CACHE")
130                .unwrap_or("true".to_string())
131                .trim_matches('"')
132                .parse()
133                .context("Invalid APP_USE_INTERNAL_CACHE value")?,
134            app_trusted_proxies: std::env::var("APP_TRUSTED_PROXIES")
135                .unwrap_or("".to_string())
136                .trim_matches('"')
137                .split(',')
138                .filter_map(|s| if s.is_empty() { None } else { s.parse().ok() })
139                .collect(),
140            app_log_directory: std::env::var("APP_LOG_DIRECTORY")
141                .ok()
142                .map(|s| s.trim_matches('"').to_string()),
143            app_encryption_key: std::env::var("APP_ENCRYPTION_KEY")
144                .expect("APP_ENCRYPTION_KEY is required")
145                .trim_matches('"')
146                .to_string(),
147            server_name: std::env::var("SERVER_NAME")
148                .ok()
149                .map(|s| s.trim_matches('"').to_string()),
150        };
151
152        if env.app_encryption_key.to_lowercase() == "changeme" {
153            println!(
154                "{}", "You are using the default APP_ENCRYPTION_KEY. This is unsupported, please modify your .env or your docker compose file.".red()
155            );
156            std::process::exit(1);
157        }
158
159        let (stdout_writer, stdout_guard) = tracing_appender::non_blocking(std::io::stdout());
160
161        let (appender, guard) = if let Some(app_log_directory) = &env.app_log_directory {
162            if !std::path::Path::new(app_log_directory).exists() {
163                std::fs::create_dir_all(app_log_directory)
164                    .context("failed to create log directory")?;
165            }
166
167            let latest_log_path = std::path::Path::new(&app_log_directory).join("panel.log");
168            let latest_file = std::fs::OpenOptions::new()
169                .create(true)
170                .append(true)
171                .open(&latest_log_path)
172                .context("failed to open latest log file")?;
173
174            let rolling_appender = tracing_appender::rolling::Builder::new()
175                .filename_prefix("panel")
176                .filename_suffix("log")
177                .max_log_files(30)
178                .rotation(tracing_appender::rolling::Rotation::DAILY)
179                .build(app_log_directory)
180                .context("failed to create rolling log file appender")?;
181
182            let (appender, guard) = tracing_appender::non_blocking::NonBlockingBuilder::default()
183                .buffered_lines_limit(50)
184                .finish(latest_file.and(rolling_appender));
185
186            (Some(appender), Some(guard))
187        } else {
188            (None, None)
189        };
190
191        if let Some(file_appender) = appender {
192            tracing::subscriber::set_global_default(
193                tracing_subscriber::fmt()
194                    .with_timer(tracing_subscriber::fmt::time::ChronoLocal::rfc_3339())
195                    .with_writer(stdout_writer.and(file_appender))
196                    .with_target(false)
197                    .with_level(true)
198                    .with_file(true)
199                    .with_line_number(true)
200                    .with_max_level(if env.is_debug() {
201                        tracing::Level::DEBUG
202                    } else {
203                        tracing::Level::INFO
204                    })
205                    .finish(),
206            )?;
207        } else {
208            tracing::subscriber::set_global_default(
209                tracing_subscriber::fmt()
210                    .with_timer(tracing_subscriber::fmt::time::ChronoLocal::new(
211                        "%Y-%m-%d %H:%M:%S %z".to_string(),
212                    ))
213                    .with_writer(stdout_writer)
214                    .with_target(false)
215                    .with_level(true)
216                    .with_file(true)
217                    .with_line_number(true)
218                    .with_max_level(if env.is_debug() {
219                        tracing::Level::DEBUG
220                    } else {
221                        tracing::Level::INFO
222                    })
223                    .finish(),
224            )?;
225        }
226
227        Ok((Arc::new(env), EnvGuard(guard, stdout_guard)))
228    }
229
230    #[inline]
231    pub fn find_ip(
232        &self,
233        headers: &HeaderMap,
234        connect_info: ConnectInfo<std::net::SocketAddr>,
235    ) -> std::net::IpAddr {
236        for cidr in &self.app_trusted_proxies {
237            if cidr.contains(&connect_info.ip()) {
238                if let Some(forwarded) = headers.get("X-Forwarded-For")
239                    && let Ok(forwarded) = forwarded.to_str()
240                    && let Some(ip) = forwarded.split(',').next()
241                {
242                    return ip.parse().unwrap_or_else(|_| connect_info.ip());
243                }
244
245                if let Some(forwarded) = headers.get("X-Real-IP")
246                    && let Ok(forwarded) = forwarded.to_str()
247                {
248                    return forwarded.parse().unwrap_or_else(|_| connect_info.ip());
249                }
250            }
251        }
252
253        connect_info.ip()
254    }
255
256    #[inline]
257    pub fn is_debug(&self) -> bool {
258        self.app_debug.load(std::sync::atomic::Ordering::Relaxed)
259    }
260}