Skip to main content

shared/
env.rs

1use anyhow::Context;
2use axum::{extract::ConnectInfo, http::HeaderMap};
3use colored::Colorize;
4use dotenvy::dotenv;
5use std::sync::{Arc, atomic::AtomicBool};
6use tracing_subscriber::fmt::writer::MakeWriterExt;
7
/// How the application connects to Redis, chosen by the `REDIS_MODE` environment variable.
#[derive(Clone)]
pub enum RedisMode {
    /// A single Redis endpoint (`REDIS_MODE=redis`, the default).
    Redis {
        /// Value of `REDIS_URL` with surrounding quotes stripped; `None` when it is unset.
        redis_url: Option<String>,
    },
    /// Redis Sentinel (`REDIS_MODE=sentinel`).
    Sentinel {
        /// Value of `REDIS_SENTINEL_CLUSTER`.
        cluster_name: String,
        /// Entries of the comma-separated `REDIS_SENTINELS` list.
        redis_sentinels: Vec<String>,
    },
}

19impl std::fmt::Display for RedisMode {
20    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
21        match self {
22            RedisMode::Redis { .. } => write!(f, "Redis"),
23            RedisMode::Sentinel { .. } => write!(f, "Sentinel"),
24        }
25    }
26}
27
28pub struct EnvGuard(
29    pub Option<tracing_appender::non_blocking::WorkerGuard>,
30    pub tracing_appender::non_blocking::WorkerGuard,
31);
32
33pub struct Env {
34    pub redis_mode: RedisMode,
35
36    pub sentry_url: Option<String>,
37    pub database_migrate: bool,
38    pub database_url: String,
39    pub database_url_primary: Option<String>,
40
41    pub bind: String,
42    pub port: u16,
43
44    pub aio_base_wings_configuration: Option<String>,
45
46    pub app_primary: bool,
47    pub app_debug: AtomicBool,
48    pub app_enable_wings_proxy: bool,
49    pub app_use_decryption_cache: bool,
50    pub app_use_internal_cache: bool,
51    pub app_trusted_proxies: Vec<cidr::IpCidr>,
52    pub app_log_directory: Option<String>,
53    pub app_encryption_key: String,
54    pub server_name: Option<String>,
55}
56
57impl Env {
58    pub fn parse() -> Result<(Arc<Self>, EnvGuard), anyhow::Error> {
59        dotenv().ok();
60
61        let env = Self {
62            redis_mode: match std::env::var("REDIS_MODE")
63                .unwrap_or("redis".to_string())
64                .trim_matches('"')
65            {
66                "redis" => RedisMode::Redis {
67                    redis_url: std::env::var("REDIS_URL")
68                        .ok()
69                        .map(|s| s.trim_matches('"').to_string()),
70                },
71                "sentinel" => RedisMode::Sentinel {
72                    cluster_name: std::env::var("REDIS_SENTINEL_CLUSTER")
73                        .context("REDIS_SENTINEL_CLUSTER is required")?
74                        .trim_matches('"')
75                        .to_string(),
76                    redis_sentinels: std::env::var("REDIS_SENTINELS")
77                        .context("REDIS_SENTINELS is required")?
78                        .trim_matches('"')
79                        .split(',')
80                        .map(|s| s.to_string())
81                        .collect(),
82                },
83                _ => {
84                    return Err(anyhow::anyhow!(
85                        "Invalid REDIS_MODE. Expected 'redis' or 'sentinel'."
86                    ));
87                }
88            },
89
90            sentry_url: std::env::var("SENTRY_URL")
91                .ok()
92                .map(|s| s.trim_matches('"').to_string()),
93            database_migrate: std::env::var("DATABASE_MIGRATE")
94                .unwrap_or("false".to_string())
95                .trim_matches('"')
96                .parse()
97                .unwrap(),
98            database_url: std::env::var("DATABASE_URL")
99                .context("DATABASE_URL is required")?
100                .trim_matches('"')
101                .to_string(),
102            database_url_primary: std::env::var("DATABASE_URL_PRIMARY")
103                .ok()
104                .map(|s| s.trim_matches('"').to_string()),
105
106            bind: std::env::var("BIND")
107                .unwrap_or("0.0.0.0".to_string())
108                .trim_matches('"')
109                .to_string(),
110            port: std::env::var("PORT")
111                .unwrap_or("6969".to_string())
112                .parse()
113                .context("Invalid PORT value")?,
114
115            aio_base_wings_configuration: std::env::var("AIO_BASE_WINGS_CONFIGURATION")
116                .ok()
117                .map(|s| s.trim_matches('"').to_string()),
118
119            app_primary: std::env::var("APP_PRIMARY")
120                .unwrap_or("true".to_string())
121                .trim_matches('"')
122                .parse()
123                .context("Invalid APP_PRIMARY value")?,
124            app_debug: AtomicBool::new(
125                std::env::var("APP_DEBUG")
126                    .unwrap_or("false".to_string())
127                    .trim_matches('"')
128                    .parse()
129                    .context("Invalid APP_DEBUG value")?,
130            ),
131            app_enable_wings_proxy: std::env::var("APP_ENABLE_WINGS_PROXY")
132                .unwrap_or("false".to_string())
133                .trim_matches('"')
134                .parse()
135                .context("Invalid APP_ENABLE_WINGS_PROXY value")?,
136            app_use_decryption_cache: std::env::var("APP_USE_DECRYPTION_CACHE")
137                .unwrap_or("false".to_string())
138                .trim_matches('"')
139                .parse()
140                .context("Invalid APP_USE_DECRYPTION_CACHE value")?,
141            app_use_internal_cache: std::env::var("APP_USE_INTERNAL_CACHE")
142                .unwrap_or("true".to_string())
143                .trim_matches('"')
144                .parse()
145                .context("Invalid APP_USE_INTERNAL_CACHE value")?,
146            app_trusted_proxies: std::env::var("APP_TRUSTED_PROXIES")
147                .unwrap_or("".to_string())
148                .trim_matches('"')
149                .split(',')
150                .filter_map(|s| if s.is_empty() { None } else { s.parse().ok() })
151                .collect(),
152            app_log_directory: std::env::var("APP_LOG_DIRECTORY")
153                .ok()
154                .map(|s| s.trim_matches('"').to_string()),
155            app_encryption_key: std::env::var("APP_ENCRYPTION_KEY")
156                .expect("APP_ENCRYPTION_KEY is required")
157                .trim_matches('"')
158                .to_string(),
159            server_name: std::env::var("SERVER_NAME")
160                .ok()
161                .map(|s| s.trim_matches('"').to_string()),
162        };
163
164        if env.app_encryption_key.to_lowercase() == "changeme" {
165            println!(
166                "{}", "You are using the default APP_ENCRYPTION_KEY. This is unsupported, please modify your .env or your docker compose file.".red()
167            );
168            std::process::exit(1);
169        }
170
171        let (stdout_writer, stdout_guard) = tracing_appender::non_blocking(std::io::stdout());
172
173        let (appender, guard) = if let Some(app_log_directory) = &env.app_log_directory {
174            if !std::path::Path::new(app_log_directory).exists() {
175                std::fs::create_dir_all(app_log_directory)
176                    .context("failed to create log directory")?;
177            }
178
179            let latest_log_path = std::path::Path::new(&app_log_directory).join("panel.log");
180            let latest_file = std::fs::OpenOptions::new()
181                .create(true)
182                .append(true)
183                .open(&latest_log_path)
184                .context("failed to open latest log file")?;
185
186            let rolling_appender = tracing_appender::rolling::Builder::new()
187                .filename_prefix("panel")
188                .filename_suffix("log")
189                .max_log_files(30)
190                .rotation(tracing_appender::rolling::Rotation::DAILY)
191                .build(app_log_directory)
192                .context("failed to create rolling log file appender")?;
193
194            let (appender, guard) = tracing_appender::non_blocking::NonBlockingBuilder::default()
195                .buffered_lines_limit(50)
196                .finish(latest_file.and(rolling_appender));
197
198            (Some(appender), Some(guard))
199        } else {
200            (None, None)
201        };
202
203        if let Some(file_appender) = appender {
204            tracing::subscriber::set_global_default(
205                tracing_subscriber::fmt()
206                    .with_timer(tracing_subscriber::fmt::time::ChronoLocal::rfc_3339())
207                    .with_writer(stdout_writer.and(file_appender))
208                    .with_target(false)
209                    .with_level(true)
210                    .with_file(true)
211                    .with_line_number(true)
212                    .with_max_level(if env.is_debug() {
213                        tracing::Level::DEBUG
214                    } else {
215                        tracing::Level::INFO
216                    })
217                    .finish(),
218            )?;
219        } else {
220            tracing::subscriber::set_global_default(
221                tracing_subscriber::fmt()
222                    .with_timer(tracing_subscriber::fmt::time::ChronoLocal::new(
223                        "%Y-%m-%d %H:%M:%S %z".to_string(),
224                    ))
225                    .with_writer(stdout_writer)
226                    .with_target(false)
227                    .with_level(true)
228                    .with_file(true)
229                    .with_line_number(true)
230                    .with_max_level(if env.is_debug() {
231                        tracing::Level::DEBUG
232                    } else {
233                        tracing::Level::INFO
234                    })
235                    .finish(),
236            )?;
237        }
238
239        Ok((Arc::new(env), EnvGuard(guard, stdout_guard)))
240    }
241
242    #[inline]
243    pub fn find_ip(
244        &self,
245        headers: &HeaderMap,
246        connect_info: ConnectInfo<std::net::SocketAddr>,
247    ) -> std::net::IpAddr {
248        for cidr in &self.app_trusted_proxies {
249            if cidr.contains(&connect_info.ip()) {
250                if let Some(forwarded) = headers.get("X-Forwarded-For")
251                    && let Ok(forwarded) = forwarded.to_str()
252                    && let Some(ip) = forwarded.split(',').next()
253                {
254                    return ip.parse().unwrap_or_else(|_| connect_info.ip());
255                }
256
257                if let Some(forwarded) = headers.get("X-Real-IP")
258                    && let Ok(forwarded) = forwarded.to_str()
259                {
260                    return forwarded.parse().unwrap_or_else(|_| connect_info.ip());
261                }
262            }
263        }
264
265        connect_info.ip()
266    }
267
268    #[inline]
269    pub fn is_debug(&self) -> bool {
270        self.app_debug.load(std::sync::atomic::Ordering::Relaxed)
271    }
272}