diff --git a/CHANGELOG.md b/CHANGELOG.md index ca07952..7a02830 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -7,36 +7,28 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 ## [Unreleased] +## [0.1.1] - 2026-02-09 + +### Fixed +- Fixed infinite redirect loop when Akkoma has `force_ssl: [rewrite_on: [:x_forwarded_proto]]` enabled + ### Added -- Initial release of Akkoma Media Proxy -- Caching reverse proxy for Akkoma/Pleroma media -- Automatic image format conversion (AVIF, WebP) -- Content negotiation based on Accept headers -- Path filtering for `/media` and `/proxy` endpoints -- TOML-based configuration with sensible defaults -- Environment variable configuration support -- Docker support with multi-platform builds -- GitHub Actions CI/CD pipeline -- Health check endpoint (`/health`) -- Metrics endpoint (`/metrics`) -- Comprehensive documentation -- Example configuration files -- Docker Compose example - -### Features -- High-performance async I/O with Tokio -- Intelligent caching with TTL and size limits -- Image quality and dimension controls -- Configurable Via header -- Connection pooling for upstream requests -- CORS support -- Gzip/Brotli compression -- Security hardening (path restrictions, timeouts) +- **Secure X-Forwarded headers support**: Opt-in forwarding of `X-Forwarded-Proto`, `X-Forwarded-For`, and `X-Forwarded-Host` headers with trusted proxy validation +- Configuration options `forward_headers_enabled` and `trusted_proxies` for controlling header forwarding behavior +- IP address and CIDR range matching for trusted proxy verification +- Automatic header derivation from actual connection for untrusted sources +- Comprehensive test suite for header forwarding and trusted proxy functionality + +### Security +- X-Forwarded headers are now only honored from explicitly trusted proxy sources +- Prevents header spoofing attacks by validating client IP against configured trusted proxies +- Headers from untrusted sources 
are ignored or overwritten with actual connection information ## [0.1.0] - 2024-12-06 ### Added - Initial implementation -[Unreleased]: https://github.com/BlockG-ws/fantastic-computing-machine/compare/v0.1.0...HEAD -[0.1.0]: https://github.com/BlockG-ws/fantastic-computing-machine/releases/tag/v0.1.0 +[Unreleased]: https://github.com/BlockG-ws/akkoproxy/compare/v0.1.1...HEAD +[0.1.1]: https://github.com/BlockG-ws/akkoproxy/compare/v0.1.0...v0.1.1 +[0.1.0]: https://github.com/BlockG-ws/akkoproxy/releases/tag/v0.1.0 diff --git a/Cargo.lock b/Cargo.lock index 86ec1cb..06c0412 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -19,7 +19,7 @@ dependencies = [ [[package]] name = "akkoproxy" -version = "0.1.0" +version = "0.1.1" dependencies = [ "anyhow", "axum", @@ -30,6 +30,7 @@ dependencies = [ "hyper", "hyper-util", "image", + "ipnetwork", "libavif", "moka", "reqwest", @@ -1243,6 +1244,15 @@ version = "2.11.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "469fb0b9cefa57e3ef31275ee7cacb78f2fdca44e4765491884a2b119d4eb130" +[[package]] +name = "ipnetwork" +version = "0.20.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "bf466541e9d546596ee94f9f69590f89473455f88372423e0008fc1a7daf100e" +dependencies = [ + "serde", +] + [[package]] name = "iri-string" version = "0.7.9" diff --git a/Cargo.toml b/Cargo.toml index e6542df..0f5be70 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -1,6 +1,6 @@ [package] name = "akkoproxy" -version = "0.1.0" +version = "0.1.1" edition = "2021" authors = ["Akkoproxy Contributors"] description = "A fast caching and optimization media proxy for Akkoma/Pleroma" @@ -40,6 +40,7 @@ futures = "0.3" http-body-util = "0.1" url = "2.5" clap = { version = "4.5", features = ["derive"] } +ipnetwork = "0.20" [profile.release] opt-level = 3 diff --git a/README.md b/README.md index 1f34ce9..6d8c9f6 100644 --- a/README.md +++ b/README.md @@ -5,6 +5,7 @@ A fast caching and optimization media proxy for 
Akkoma/Pleroma, built in Rust. ## Features - **Caching Reverse Proxy**: Caches media and proxy requests to reduce load on upstream servers +- **Secure X-Forwarded Headers**: Opt-in forwarding of `X-Forwarded-Proto`, `X-Forwarded-For`, and `X-Forwarded-Host` headers with trusted proxy validation, ensuring compatibility with Akkoma's `force_ssl` configuration - **Header Preservation**: Preserves all upstream headers by default, including redirects (302) with Location headers - **Image Format Conversion**: Automatically converts images to modern formats (AVIF, WebP) based on client `Accept` headers - **Path Filtering**: Only handles `/media` and `/proxy` endpoints for security @@ -93,9 +94,51 @@ timeout = 30 # Request timeout in seconds ```toml [server] bind = "0.0.0.0:3000" # Bind address -via_header = "akkoma-media-proxy/0.1.0" # Via header value +via_header = "akkoma-media-proxy/0.1.1" # Via header value preserve_upstream_headers = true # Preserve all headers from upstream (default: true) behind_cloudflare_free = false # Enable Cloudflare Free plan compatibility (default: false) + +# X-Forwarded headers configuration (for SSL/TLS detection) +forward_headers_enabled = false # Enable X-Forwarded-* header forwarding (default: false) +trusted_proxies = ["192.168.1.1", "10.0.0.0/8"] # Trusted proxy IPs/CIDRs (default: empty) +``` + +#### X-Forwarded Headers and Trusted Proxies + +**Important for Akkoma's `force_ssl` configuration:** + +When Akkoma has `force_ssl: [rewrite_on: [:x_forwarded_proto]]` enabled, it relies on the `X-Forwarded-Proto` header to detect HTTPS connections. To prevent infinite redirect loops, you need to configure header forwarding: + +1. **Enable header forwarding**: Set `forward_headers_enabled = true` +2. 
**Configure trusted proxies**: List the IP addresses or CIDR ranges of your reverse proxy/load balancer in `trusted_proxies` + +**Security considerations:** +- Only enable `forward_headers_enabled` if you're behind a reverse proxy (nginx, Cloudflare, etc.) +- You **must** configure `trusted_proxies` with your proxy IPs - with an empty list no incoming `X-Forwarded-*` header is honored, so `X-Forwarded-Proto` never reaches Akkoma +- Only requests from trusted IPs will have their `X-Forwarded-*` headers honored +- Requests from untrusted sources get `X-Forwarded-For` set to the connecting IP; their `X-Forwarded-Proto` and `X-Forwarded-Host` are dropped + +**Example configurations:** + +```toml +# Behind nginx on the same host (pick ONE of these [server] examples) +[server] +forward_headers_enabled = true +trusted_proxies = ["127.0.0.1", "::1"] + +# Behind Cloudflare (use Cloudflare's IP ranges) +[server] +forward_headers_enabled = true +trusted_proxies = [ + "173.245.48.0/20", + "103.21.244.0/22", + # ... other Cloudflare ranges +] + +# Behind a local reverse proxy +[server] +forward_headers_enabled = true +trusted_proxies = ["192.168.1.1", "10.0.0.0/8"] ``` #### Cloudflare Free Plan Compatibility @@ -153,6 +196,9 @@ max_dimension = 4096 # Maximum image dimension 1. **Request Filtering**: Only `/media` and `/proxy` paths are allowed 2. **Cache Check**: Looks for cached response with the requested format 3. **Upstream Fetch**: If not cached, fetches from upstream server + - **Secure header forwarding**: When enabled, forwards `X-Forwarded-Proto`, `X-Forwarded-For`, and `X-Forwarded-Host` headers only from trusted proxy sources + - **Untrusted protection**: Headers from untrusted sources are ignored or derived from the actual connection + - This ensures proper SSL/TLS detection when Akkoma has `force_ssl` enabled while preventing header spoofing 4. **Header Preservation**: All upstream headers (including Location for redirects) are preserved by default 5. 
**Image Conversion**: For images, converts to the best format based on `Accept` header: - Prefers AVIF if `image/avif` is accepted diff --git a/config.example.toml b/config.example.toml index a192c93..b36d92f 100644 --- a/config.example.toml +++ b/config.example.toml @@ -13,7 +13,7 @@ timeout = 30 bind = "0.0.0.0:3000" # Custom Via header value (default: akkoproxy/{version}) -via_header = "akkoproxy/0.1.0" +via_header = "akkoproxy/0.1.1" # Preserve all headers from upstream when responding (default: true) preserve_upstream_headers = true @@ -27,6 +27,21 @@ preserve_upstream_headers = true # - Then: Add query parameter "format=avif" behind_cloudflare_free = false +# Enable forwarding of X-Forwarded-* headers to upstream (default: false) +# When enabled, X-Forwarded-Proto, X-Forwarded-For, and X-Forwarded-Host headers +# will be forwarded to the upstream server based on the trusted_proxies configuration. +# This is essential for proper SSL/TLS detection when Akkoma has force_ssl enabled. +forward_headers_enabled = false + +# List of trusted proxy IP addresses or CIDR ranges (default: empty) +# Only requests from these IPs will have their X-Forwarded-* headers honored. +# If empty and forward_headers_enabled is true, no incoming X-Forwarded-* headers will be honored (secure default). +# For untrusted sources, X-Forwarded-For will be set to the actual client IP. 
+# Examples: +# trusted_proxies = ["192.168.1.1", "10.0.0.0/8", "172.16.0.0/12"] +# trusted_proxies = ["127.0.0.1", "::1"] # localhost only +trusted_proxies = [] + [cache] # Maximum number of cached items (default: 10000) max_capacity = 10000 diff --git a/src/cache.rs b/src/cache.rs index c40790a..97ce869 100644 --- a/src/cache.rs +++ b/src/cache.rs @@ -39,20 +39,20 @@ impl ResponseCache { .time_to_live(ttl) .initial_capacity(100) .build(); - + Self { cache } } - + /// Get a cached response pub async fn get(&self, key: &CacheKey) -> Option> { self.cache.get(key).await } - + /// Store a response in the cache pub async fn put(&self, key: CacheKey, response: CachedResponse) { self.cache.insert(key, Arc::new(response)).await; } - + /// Get cache statistics pub fn stats(&self) -> CacheStats { CacheStats { @@ -77,16 +77,16 @@ mod tests { #[tokio::test] async fn test_cache_put_and_get() { let cache = ResponseCache::new(100, Duration::from_secs(60), 1024 * 1024); - + let key = CacheKey::new("/media/test.jpg".to_string(), "avif".to_string()); let response = CachedResponse { data: Bytes::from("test data"), content_type: "image/avif".to_string(), upstream_headers: None, }; - + cache.put(key.clone(), response.clone()).await; - + let cached = cache.get(&key).await; assert!(cached.is_some()); assert_eq!(cached.unwrap().content_type, "image/avif"); @@ -95,67 +95,64 @@ mod tests { #[tokio::test] async fn test_cache_miss() { let cache = ResponseCache::new(100, Duration::from_secs(60), 1024 * 1024); - + let key = CacheKey::new("/media/nonexistent.jpg".to_string(), "webp".to_string()); let cached = cache.get(&key).await; - + assert!(cached.is_none()); } - + #[tokio::test] async fn test_cache_with_upstream_headers() { let cache = ResponseCache::new(100, Duration::from_secs(60), 1024 * 1024); - + // Create headers to cache let mut headers = HeaderMap::new(); headers.insert( HeaderName::from_static("x-custom-header"), HeaderValue::from_static("test-value"), ); - + let key = 
CacheKey::new("/media/test.jpg".to_string(), "avif".to_string()); let response = CachedResponse { data: Bytes::from("test data"), content_type: "image/avif".to_string(), upstream_headers: Some(headers.clone()), }; - + cache.put(key.clone(), response.clone()).await; - + let cached = cache.get(&key).await; assert!(cached.is_some()); let cached = cached.unwrap(); assert_eq!(cached.content_type, "image/avif"); assert!(cached.upstream_headers.is_some()); - + let cached_headers = cached.upstream_headers.as_ref().unwrap(); - assert_eq!( - cached_headers.get("x-custom-header").unwrap(), - "test-value" - ); + assert_eq!(cached_headers.get("x-custom-header").unwrap(), "test-value"); } - + #[tokio::test] async fn test_cache_ttl() { // Create cache with 1 second TTL let cache = ResponseCache::new(100, Duration::from_secs(1), 1024 * 1024); - + let key = CacheKey::new("/media/test.jpg".to_string(), "avif".to_string()); let response = CachedResponse { data: Bytes::from("test data"), content_type: "image/avif".to_string(), upstream_headers: None, }; - + cache.put(key.clone(), response.clone()).await; - + // Should be in cache immediately let cached = cache.get(&key).await; assert!(cached.is_some()); - + // Wait for TTL to expire tokio::time::sleep(Duration::from_secs(2)).await; - + // Should be gone after TTL let cached = cache.get(&key).await; assert!(cached.is_none()); diff --git a/src/config.rs b/src/config.rs index 3796fd3..3b46c2b 100644 --- a/src/config.rs +++ b/src/config.rs @@ -1,8 +1,8 @@ +use anyhow::{Context, Result}; use serde::{Deserialize, Serialize}; use std::fs; use std::net::SocketAddr; use std::path::Path; -use anyhow::{Context, Result}; /// Application configuration #[derive(Debug, Clone, Deserialize, Serialize)] @@ -10,14 +10,14 @@ pub struct Config { /// Server configuration #[serde(default)] pub server: ServerConfig, - + /// Upstream configuration pub upstream: UpstreamConfig, - + /// Cache configuration #[serde(default)] pub cache: CacheConfig, - + /// Image 
processing configuration #[serde(default)] pub image: ImageConfig, @@ -28,28 +28,39 @@ pub struct ServerConfig { /// Address to bind to #[serde(default = "default_bind_address")] pub bind: SocketAddr, - + /// Custom Via header value #[serde(default = "default_via_header")] pub via_header: String, - + /// Preserve all headers from upstream #[serde(default = "default_true")] pub preserve_upstream_headers: bool, - + /// Enable Cloudflare Free plan compatibility mode /// When enabled, the proxy will look for a 'format' query parameter /// and use it to determine output format (avif/webp), then strip it /// from the upstream request #[serde(default)] pub behind_cloudflare_free: bool, + + /// Enable forwarding of X-Forwarded-* headers to upstream + /// When disabled, X-Forwarded-* headers from clients are ignored + #[serde(default)] + pub forward_headers_enabled: bool, + + /// List of trusted proxy IP addresses or CIDR ranges + /// Only requests from these IPs will have their X-Forwarded-* headers honored + /// If empty, no headers will be forwarded (secure default) + #[serde(default)] + pub trusted_proxies: Vec, } #[derive(Debug, Clone, Deserialize, Serialize)] pub struct UpstreamConfig { /// Upstream server URL (e.g., "https://akkoma.example.com") pub url: String, - + /// Timeout for upstream requests in seconds #[serde(default = "default_timeout")] pub timeout: u64, @@ -60,11 +71,11 @@ pub struct CacheConfig { /// Maximum number of cached items #[serde(default = "default_max_capacity")] pub max_capacity: u64, - + /// Time to live for cached items in seconds #[serde(default = "default_ttl")] pub ttl: u64, - + /// Maximum size of a cached item in bytes #[serde(default = "default_max_item_size")] pub max_item_size: u64, @@ -75,15 +86,15 @@ pub struct ImageConfig { /// Enable AVIF conversion #[serde(default = "default_true")] pub enable_avif: bool, - + /// Enable WebP conversion #[serde(default = "default_true")] pub enable_webp: bool, - + /// JPEG quality for conversions 
/// Serde default for `ServerConfig::bind`: every IPv4 interface on port 3000
/// (`0.0.0.0:3000`).
fn default_bind_address() -> SocketAddr {
    // Assemble the address from its parts; the value is a compile-time
    // constant, so there is nothing to parse and nothing that can fail.
    SocketAddr::from(([0, 0, 0, 0], 3000))
}
anyhow::bail!("Image quality must be between 1 and 100"); } - + Ok(()) } } diff --git a/src/image.rs b/src/image.rs index b314c4c..cc4433e 100644 --- a/src/image.rs +++ b/src/image.rs @@ -30,30 +30,25 @@ impl ImageConverter { enable_webp, } } - + /// Convert image to the requested format - pub fn convert(&self, data: &Bytes, target_format: OutputFormat) -> Result<(Bytes, &'static str)> { + pub fn convert( + &self, + data: &Bytes, + target_format: OutputFormat, + ) -> Result<(Bytes, &'static str)> { // Try to detect and decode the image - let img = image::load_from_memory(data) - .context("Failed to decode image")?; - + let img = image::load_from_memory(data).context("Failed to decode image")?; + // Check dimensions and resize if necessary let img = self.resize_if_needed(img); - + // Convert to target format let (converted, mime_type) = match target_format { - OutputFormat::Avif if self.enable_avif => { - (self.to_avif(&img)?, "image/avif") - } - OutputFormat::WebP if self.enable_webp => { - (self.to_webp(&img)?, "image/webp") - } - OutputFormat::Jpeg => { - (self.to_jpeg(&img)?, "image/jpeg") - } - OutputFormat::Png => { - (self.to_png(&img)?, "image/png") - } + OutputFormat::Avif if self.enable_avif => (self.to_avif(&img)?, "image/avif"), + OutputFormat::WebP if self.enable_webp => (self.to_webp(&img)?, "image/webp"), + OutputFormat::Jpeg => (self.to_jpeg(&img)?, "image/jpeg"), + OutputFormat::Png => (self.to_png(&img)?, "image/png"), OutputFormat::Original => { // Return original data return Ok((data.clone(), "application/octet-stream")); @@ -63,30 +58,30 @@ impl ImageConverter { (self.to_jpeg(&img)?, "image/jpeg") } }; - + Ok((converted, mime_type)) } - + /// Resize image if it exceeds maximum dimensions fn resize_if_needed(&self, img: DynamicImage) -> DynamicImage { let (width, height) = img.dimensions(); - + if width > self.max_dimension || height > self.max_dimension { let scale = if width > height { self.max_dimension as f32 / width as f32 } else { 
self.max_dimension as f32 / height as f32 }; - + let new_width = (width as f32 * scale) as u32; let new_height = (height as f32 * scale) as u32; - + img.resize(new_width, new_height, image::imageops::FilterType::Lanczos3) } else { img } } - + /// Convert image to AVIF format fn to_avif(&self, img: &DynamicImage) -> Result { let mut buffer = Vec::new(); @@ -95,52 +90,52 @@ impl ImageConverter { 10, // Speed (1-10, 10 is fastest) self.quality, ); - + img.write_with_encoder(encoder) .context("Failed to encode AVIF")?; - + Ok(Bytes::from(buffer)) } - + /// Convert image to WebP format fn to_webp(&self, img: &DynamicImage) -> Result { let mut buffer = Vec::new(); let encoder = image::codecs::webp::WebPEncoder::new_lossless(&mut buffer); - + img.write_with_encoder(encoder) .context("Failed to encode WebP")?; - + Ok(Bytes::from(buffer)) } - + /// Convert image to JPEG format fn to_jpeg(&self, img: &DynamicImage) -> Result { let mut buffer = Vec::new(); let mut cursor = Cursor::new(&mut buffer); - + img.write_to(&mut cursor, ImageFormat::Jpeg) .context("Failed to encode JPEG")?; - + Ok(Bytes::from(buffer)) } - + /// Convert image to PNG format fn to_png(&self, img: &DynamicImage) -> Result { let mut buffer = Vec::new(); let mut cursor = Cursor::new(&mut buffer); - + img.write_to(&mut cursor, ImageFormat::Png) .context("Failed to encode PNG")?; - + Ok(Bytes::from(buffer)) } } /// Parse Accept header to determine preferred image format -/// +/// /// # Example /// For the Accept header: `image/avif,image/webp,image/png,image/svg+xml,image/*;q=0.8,*/*;q=0.5` -/// +/// /// With default config (enable_avif=true, enable_webp=true): /// - image/avif has quality 1.0 (default when not specified) /// - image/webp has quality 1.0 (default when not specified) @@ -148,17 +143,17 @@ impl ImageConverter { /// - image/svg+xml is not supported (ignored) /// - image/* has quality 0.8 /// - */* has quality 0.5 -/// +/// /// Result: Returns Avif (first format with highest quality 1.0) pub fn 
parse_accept_header(accept: &str, enable_avif: bool, enable_webp: bool) -> OutputFormat { // Parse media types and their quality values let mut formats: Vec<(OutputFormat, f32)> = Vec::new(); - + for part in accept.split(',') { let part = part.trim(); let mut segments = part.split(';'); let media_type = segments.next().unwrap_or("").trim(); - + // Extract quality value (default to 1.0) let quality = segments .find_map(|s| { @@ -166,7 +161,7 @@ pub fn parse_accept_header(accept: &str, enable_avif: bool, enable_webp: bool) - s.strip_prefix("q=")?.parse::().ok() }) .unwrap_or(1.0); - + // Map media type to output format let format = match media_type { "image/avif" if enable_avif => Some(OutputFormat::Avif), @@ -177,18 +172,19 @@ pub fn parse_accept_header(accept: &str, enable_avif: bool, enable_webp: bool) - "*/*" => Some(OutputFormat::Original), _ => None, }; - + if let Some(fmt) = format { formats.push((fmt, quality)); } } - + // Sort by quality (descending) // Use unwrap_or to handle NaN values gracefully (treat as equal) formats.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal)); - + // Return the highest quality format, or Original if none found - formats.first() + formats + .first() .map(|(fmt, _)| *fmt) .unwrap_or(OutputFormat::Original) } @@ -250,11 +246,26 @@ mod tests { #[test] fn test_format_from_content_type() { - assert_eq!(format_from_content_type("image/avif"), Some(OutputFormat::Avif)); - assert_eq!(format_from_content_type("image/webp"), Some(OutputFormat::WebP)); - assert_eq!(format_from_content_type("image/jpeg"), Some(OutputFormat::Jpeg)); - assert_eq!(format_from_content_type("image/jpg"), Some(OutputFormat::Jpeg)); - assert_eq!(format_from_content_type("image/png"), Some(OutputFormat::Png)); + assert_eq!( + format_from_content_type("image/avif"), + Some(OutputFormat::Avif) + ); + assert_eq!( + format_from_content_type("image/webp"), + Some(OutputFormat::WebP) + ); + assert_eq!( + format_from_content_type("image/jpeg"), + 
Some(OutputFormat::Jpeg) + ); + assert_eq!( + format_from_content_type("image/jpg"), + Some(OutputFormat::Jpeg) + ); + assert_eq!( + format_from_content_type("image/png"), + Some(OutputFormat::Png) + ); assert_eq!(format_from_content_type("image/gif"), None); assert_eq!(format_from_content_type("text/plain"), None); } @@ -264,11 +275,11 @@ mod tests { // Same format satisfies assert!(format_satisfies(OutputFormat::Avif, OutputFormat::Avif)); assert!(format_satisfies(OutputFormat::WebP, OutputFormat::WebP)); - + // Original always satisfied assert!(format_satisfies(OutputFormat::Avif, OutputFormat::Original)); assert!(format_satisfies(OutputFormat::Jpeg, OutputFormat::Original)); - + // Different formats don't satisfy assert!(!format_satisfies(OutputFormat::Jpeg, OutputFormat::Avif)); assert!(!format_satisfies(OutputFormat::Png, OutputFormat::WebP)); diff --git a/src/main.rs b/src/main.rs index 8ceedea..00fc975 100644 --- a/src/main.rs +++ b/src/main.rs @@ -4,10 +4,7 @@ mod image; mod proxy; use anyhow::{Context, Result}; -use axum::{ - routing::get, - Router, -}; +use axum::{routing::get, Router}; use clap::Parser; use std::net::SocketAddr; use std::path::PathBuf; @@ -74,14 +71,17 @@ async fn main() -> Result<()> { // Load configuration let config = load_config(&cli)?; - + info!("Configuration loaded:"); info!(" Bind address: {}", config.server.bind); info!(" Upstream URL: {}", config.upstream.url); info!(" Cache max capacity: {}", config.cache.max_capacity); info!(" AVIF conversion: {}", config.image.enable_avif); info!(" WebP conversion: {}", config.image.enable_webp); - info!(" Preserve upstream headers: {}", config.server.preserve_upstream_headers); + info!( + " Preserve upstream headers: {}", + config.server.preserve_upstream_headers + ); // Create application state let state = AppState::new(config.clone()); @@ -92,7 +92,8 @@ async fn main() -> Result<()> { .route("/metrics", get(metrics_handler)) .fallback(proxy_handler) .layer(TraceLayer::new_for_http()) - 
.with_state(state); + .with_state(state) + .into_make_service_with_connect_info::(); // Start server let listener = tokio::net::TcpListener::bind(&config.server.bind) @@ -100,10 +101,8 @@ async fn main() -> Result<()> { .with_context(|| format!("Failed to bind to {}", config.server.bind))?; info!("Server listening on {}", config.server.bind); - - axum::serve(listener, app) - .await - .context("Server error")?; + + axum::serve(listener, app).await.context("Server error")?; Ok(()) } @@ -129,10 +128,13 @@ fn load_config(cli: &Cli) -> Result { // Priority 2 (medium): Apply command-line options if let Some(upstream_url) = &cli.upstream { - info!("Overriding upstream URL from command line: {}", upstream_url); + info!( + "Overriding upstream URL from command line: {}", + upstream_url + ); config.upstream.url = upstream_url.clone(); } - + if let Some(bind) = cli.bind { config.server.bind = bind; } @@ -158,14 +160,14 @@ fn load_config(cli: &Cli) -> Result { info!("Overriding upstream URL from environment: {}", upstream_url); config.upstream.url = upstream_url; } - + if let Ok(bind_str) = std::env::var("BIND_ADDRESS") { if let Ok(bind) = bind_str.parse() { info!("Overriding bind address from environment: {}", bind); config.server.bind = bind; } } - + if let Ok(preserve) = std::env::var("PRESERVE_HEADERS") { if let Ok(value) = preserve.parse::() { info!("Overriding preserve_headers from environment: {}", value); diff --git a/src/proxy.rs b/src/proxy.rs index 39ce89e..ef9e95e 100644 --- a/src/proxy.rs +++ b/src/proxy.rs @@ -1,13 +1,18 @@ use crate::cache::{CacheKey, CachedResponse, ResponseCache}; use crate::config::Config; -use crate::image::{is_image_content_type, parse_accept_header, format_from_content_type, format_satisfies, ImageConverter, OutputFormat}; +use crate::image::{ + format_from_content_type, format_satisfies, is_image_content_type, parse_accept_header, + ImageConverter, OutputFormat, +}; use axum::{ body::Body, - extract::{Request, State}, + 
extract::{ConnectInfo, Request, State}, http::{header, HeaderMap, StatusCode, Uri}, response::{IntoResponse, Response}, }; use bytes::Bytes; +use ipnetwork::IpNetwork; +use std::net::{IpAddr, SocketAddr}; use std::sync::Arc; use std::time::Duration; use tracing::{debug, error, info, warn}; @@ -40,7 +45,7 @@ fn build_vary_header(upstream_vary: Option<&str>) -> String { let has_accept = upstream_value .split(',') .any(|v| v.trim().eq_ignore_ascii_case("accept")); - + if has_accept { // If upstream already has "Accept", just use upstream value upstream_value.to_string() @@ -54,6 +59,80 @@ fn build_vary_header(upstream_vary: Option<&str>) -> String { } } +/// Check if an IP address is in the trusted proxy list +fn is_trusted_proxy(ip: &IpAddr, trusted_proxies: &[String]) -> bool { + if trusted_proxies.is_empty() { + return false; + } + + for proxy in trusted_proxies { + // Try parsing as a CIDR range + if let Ok(network) = proxy.parse::() { + if network.contains(*ip) { + return true; + } + } + // Try parsing as a single IP address + else if let Ok(proxy_ip) = proxy.parse::() { + if proxy_ip == *ip { + return true; + } + } + } + + false +} + +/// Apply X-Forwarded headers to the upstream request builder +/// Only honors headers from trusted proxies, otherwise derives from actual connection +fn apply_forwarded_headers( + request_builder: reqwest::RequestBuilder, + headers: &HeaderMap, + client_ip: IpAddr, + config: &Config, +) -> reqwest::RequestBuilder { + let mut builder = request_builder; + + // Check if header forwarding is enabled + if !config.server.forward_headers_enabled { + debug!("X-Forwarded header forwarding is disabled"); + return builder; + } + + // Check if the client IP is trusted + let is_trusted = is_trusted_proxy(&client_ip, &config.server.trusted_proxies); + + if is_trusted { + debug!("Client IP {} is trusted, forwarding X-Forwarded-* headers", client_ip); + + // Forward X-Forwarded-Proto if present + if let Some(forwarded_proto) = 
headers.get("x-forwarded-proto") { + builder = builder.header("X-Forwarded-Proto", forwarded_proto); + } + + // Forward X-Forwarded-For if present + if let Some(forwarded_for) = headers.get("x-forwarded-for") { + builder = builder.header("X-Forwarded-For", forwarded_for); + } + + // Forward X-Forwarded-Host if present + if let Some(forwarded_host) = headers.get("x-forwarded-host") { + builder = builder.header("X-Forwarded-Host", forwarded_host); + } + } else { + debug!("Client IP {} is not trusted, setting X-Forwarded-For from actual connection", client_ip); + + // Set X-Forwarded-For to the actual client IP only + // Do not preserve untrusted X-Forwarded-For headers as they could be spoofed + builder = builder.header("X-Forwarded-For", client_ip.to_string()); + + // Do not forward X-Forwarded-Proto or X-Forwarded-Host from untrusted sources + // Let the upstream derive these from the actual connection if needed + } + + builder +} + /// Application state shared across handlers #[derive(Clone)] pub struct AppState { @@ -65,17 +144,21 @@ pub struct AppState { impl AppState { pub fn new(config: Config) -> Self { - debug!("Initializing AppState with config: bind={}, upstream={}", - config.server.bind, config.upstream.url); - + debug!( + "Initializing AppState with config: bind={}, upstream={}", + config.server.bind, config.upstream.url + ); + let cache = ResponseCache::new( config.cache.max_capacity, Duration::from_secs(config.cache.ttl), config.cache.max_item_size, ); - debug!("Cache initialized: max_capacity={}, ttl={}s, max_item_size={} bytes", - config.cache.max_capacity, config.cache.ttl, config.cache.max_item_size); - + debug!( + "Cache initialized: max_capacity={}, ttl={}s, max_item_size={} bytes", + config.cache.max_capacity, config.cache.ttl, config.cache.max_item_size + ); + let client = reqwest::Client::builder() .timeout(Duration::from_secs(config.upstream.timeout)) .user_agent(format!("akkoproxy/{}", env!("CARGO_PKG_VERSION"))) @@ -84,19 +167,26 @@ impl 
AppState { .redirect(reqwest::redirect::Policy::none()) .build() .expect("Failed to create HTTP client"); - debug!("HTTP client configured: timeout={}s, user_agent=akkoproxy/{}, redirect_policy=none", - config.upstream.timeout, env!("CARGO_PKG_VERSION")); - + debug!( + "HTTP client configured: timeout={}s, user_agent=akkoproxy/{}, redirect_policy=none", + config.upstream.timeout, + env!("CARGO_PKG_VERSION") + ); + let image_converter = Arc::new(ImageConverter::new( config.image.quality, config.image.max_dimension, config.image.enable_avif, config.image.enable_webp, )); - debug!("Image converter initialized: quality={}, max_dimension={}, avif={}, webp={}", - config.image.quality, config.image.max_dimension, - config.image.enable_avif, config.image.enable_webp); - + debug!( + "Image converter initialized: quality={}, max_dimension={}, avif={}, webp={}", + config.image.quality, + config.image.max_dimension, + config.image.enable_avif, + config.image.enable_webp + ); + Self { config: Arc::new(config), cache, @@ -108,6 +198,7 @@ impl AppState { /// Main proxy handler pub async fn proxy_handler( + ConnectInfo(addr): ConnectInfo, State(state): State, uri: Uri, headers: HeaderMap, @@ -115,9 +206,9 @@ pub async fn proxy_handler( ) -> Result { let path = uri.path(); let query = uri.query().unwrap_or(""); - - debug!("Proxying request: {} {}", path, query); - + + debug!("Proxying request: {} {} from {}", path, query, addr); + // Handle root path with redirect if path == "/" { return Ok(Response::builder() @@ -126,27 +217,28 @@ pub async fn proxy_handler( .body(Body::empty()) .expect("Failed to build root redirect response")); } - + // Only handle /media and /proxy paths if !path.starts_with("/media") && !path.starts_with("/proxy") { warn!("Path not allowed: {}", path); return Err(ProxyError::PathNotAllowed); } - + // Parse query parameters if behind_cloudflare_free is enabled - let (format_from_query, upstream_query) = if state.config.server.behind_cloudflare_free && 
!query.is_empty() { - parse_query_for_format(query) - } else { - (None, query.to_string()) - }; - + let (format_from_query, upstream_query) = + if state.config.server.behind_cloudflare_free && !query.is_empty() { + parse_query_for_format(query) + } else { + (None, query.to_string()) + }; + // Build upstream URL (without format query if it was present) let upstream_url = if upstream_query.is_empty() { format!("{}{}", state.config.upstream.url, path) } else { format!("{}{}?{}", state.config.upstream.url, path, upstream_query) }; - + // Determine desired format let desired_format = if let Some(fmt) = format_from_query { // Use format from query parameter if available @@ -157,63 +249,81 @@ pub async fn proxy_handler( .get(header::ACCEPT) .and_then(|v| v.to_str().ok()) .unwrap_or("*/*"); - + parse_accept_header( accept, state.config.image.enable_avif, state.config.image.enable_webp, ) }; - + // Generate cache key let cache_key = CacheKey::new( - format!("{}{}", path, if query.is_empty() { String::new() } else { format!("?{}", query) }), + format!( + "{}{}", + path, + if query.is_empty() { + String::new() + } else { + format!("?{}", query) + } + ), format!("{:?}", desired_format), ); - + // Check cache first if let Some(cached) = state.cache.get(&cache_key).await { debug!("Cache hit for {}", path); return Ok(build_response( - cached.data.clone(), - &cached.content_type, - &state.config.server.via_header, + cached.data.clone(), + &cached.content_type, + &state.config.server.via_header, cached.upstream_headers.as_ref(), true, // is_cache_hit )); } + + debug!( + "Cache miss for {}, fetching from upstream: {}", + path, upstream_url + ); + + // Build request with forwarded headers + let request_builder = state.client.get(&upstream_url); - debug!("Cache miss for {}, fetching from upstream: {}", path, upstream_url); - + // Apply X-Forwarded-* headers based on configuration and trust policy + let request_builder = apply_forwarded_headers( + request_builder, + &headers, + 
addr.ip(), + &state.config, + ); + // Fetch from upstream - let response = state.client - .get(&upstream_url) - .send() - .await - .map_err(|e| { - error!("Failed to fetch from upstream: {}", e); - ProxyError::UpstreamError(e) - })?; - + let response = request_builder.send().await.map_err(|e| { + error!("Failed to fetch from upstream: {}", e); + ProxyError::UpstreamError(e) + })?; + let status = response.status(); - + // Handle non-success responses (redirects, errors, etc.) // For non-2xx responses, preserve and forward the response with its status code if !status.is_success() { debug!("Upstream returned non-success status: {}", status); - + // Preserve upstream headers let upstream_headers = if state.config.server.preserve_upstream_headers { Some(response.headers().clone()) } else { None }; - + let body_bytes = response.bytes().await.map_err(|e| { error!("Failed to read response body: {}", e); ProxyError::UpstreamError(e) })?; - + // Build response with the actual status code from upstream return Ok(build_response_with_status( body_bytes, @@ -222,26 +332,26 @@ pub async fn proxy_handler( upstream_headers.as_ref(), )); } - + // Preserve upstream headers if configured (for success responses) let upstream_headers = if state.config.server.preserve_upstream_headers { Some(response.headers().clone()) } else { None }; - + let content_type = response .headers() .get(header::CONTENT_TYPE) .and_then(|v| v.to_str().ok()) .unwrap_or("application/octet-stream") .to_string(); - + let body_bytes = response.bytes().await.map_err(|e| { error!("Failed to read response body: {}", e); ProxyError::UpstreamError(e) })?; - + // Check if this is an image and conversion is requested // Skip conversion if upstream format already satisfies the desired format let upstream_format = format_from_content_type(&content_type); @@ -252,13 +362,17 @@ pub async fn proxy_handler( body_bytes.len(), state.config.cache.max_item_size as usize, ); - + let (final_data, final_content_type) = if 
needs_conversion { debug!("Converting image to {:?}", desired_format); - + match state.image_converter.convert(&body_bytes, desired_format) { Ok((converted, mime_type)) => { - info!("Successfully converted image: {} bytes -> {} bytes", body_bytes.len(), converted.len()); + info!( + "Successfully converted image: {} bytes -> {} bytes", + body_bytes.len(), + converted.len() + ); (converted, mime_type.to_string()) } Err(e) => { @@ -268,15 +382,21 @@ pub async fn proxy_handler( } } else { if is_image_content_type(&content_type) && upstream_format.is_some() { - debug!("Skipping conversion: upstream format {:?} already satisfies desired format {:?}", - upstream_format, desired_format); + debug!( + "Skipping conversion: upstream format {:?} already satisfies desired format {:?}", + upstream_format, desired_format + ); } else { - debug!("Not converting: is_image={}, format={:?}, size={}", - is_image_content_type(&content_type), desired_format, body_bytes.len()); + debug!( + "Not converting: is_image={}, format={:?}, size={}", + is_image_content_type(&content_type), + desired_format, + body_bytes.len() + ); } (body_bytes, content_type) }; - + // Cache the response if final_data.len() <= state.config.cache.max_item_size as usize { let cached_response = CachedResponse { @@ -289,11 +409,11 @@ pub async fn proxy_handler( } else { debug!("Response too large to cache: {} bytes", final_data.len()); } - + Ok(build_response( - final_data, - &final_content_type, - &state.config.server.via_header, + final_data, + &final_content_type, + &state.config.server.via_header, upstream_headers.as_ref(), false, // is_cache_hit )) @@ -301,17 +421,17 @@ pub async fn proxy_handler( /// Parse query string to extract format parameter and return modified query /// Returns (format_option, remaining_query_string) -/// +/// /// This parser is intentionally simple and only handles basic ASCII format values /// ("avif", "webp") with case-insensitive matching. 
It handles '+' as space /// (common in query strings) but does not perform full URL decoding. -/// +/// /// Cloudflare Transform Rules generate clean query parameters like "format=avif" /// so complex URL decoding is not necessary for this use case. fn parse_query_for_format(query: &str) -> (Option, String) { let mut format_value = None; let mut remaining_params = Vec::new(); - + for param in query.split('&') { if let Some((key, value)) = param.split_once('=') { if key == "format" { @@ -332,7 +452,7 @@ fn parse_query_for_format(query: &str) -> (Option, String) { remaining_params.push(param); } } - + (format_value, remaining_params.join("&")) } @@ -348,42 +468,41 @@ fn should_convert_image( if !is_image_content_type(content_type) { return false; } - + // Must not be requesting original format if desired_format == OutputFormat::Original { return false; } - + // Must be within size limits if content_size > max_size { return false; } - + // Skip conversion if upstream format already satisfies desired format !matches!(upstream_format, Some(fmt) if format_satisfies(fmt, desired_format)) } /// Build HTTP response with appropriate headers fn build_response( - data: Bytes, - content_type: &str, + data: Bytes, + content_type: &str, via_header: &str, upstream_headers: Option<&HeaderMap>, is_cache_hit: bool, ) -> Response { - let mut builder = Response::builder() - .status(StatusCode::OK); - + let mut builder = Response::builder().status(StatusCode::OK); + // Check if upstream has CORS header let upstream_has_cors = upstream_headers .map(|h| h.contains_key(header::ACCESS_CONTROL_ALLOW_ORIGIN)) .unwrap_or(false); - + // Check if upstream has Vary header let upstream_vary = upstream_headers .and_then(|h| h.get(header::VARY)) .and_then(|v| v.to_str().ok()); - + // Add upstream headers if configured if let Some(headers) = upstream_headers { for (key, value) in headers.iter() { @@ -394,23 +513,23 @@ fn build_response( } } } - + // Always set/override these headers builder = builder 
.header(header::CONTENT_TYPE, content_type) .header(header::VIA, via_header) .header(header::CACHE_CONTROL, "public, max-age=31536000, immutable") .header("X-Cache-Status", if is_cache_hit { "HIT" } else { "MISS" }); - + // Always add Vary header with Accept // If upstream has Vary header, prepend "Accept" to it builder = builder.header(header::VARY, build_vary_header(upstream_vary)); - + // Only set CORS header if upstream didn't provide one if !upstream_has_cors { builder = builder.header(header::ACCESS_CONTROL_ALLOW_ORIGIN, "*"); } - + builder .body(Body::from(data)) .expect("Failed to build response") @@ -423,19 +542,18 @@ fn build_response_with_status( via_header: &str, upstream_headers: Option<&HeaderMap>, ) -> Response { - let mut builder = Response::builder() - .status(status); - + let mut builder = Response::builder().status(status); + // Check if upstream has CORS header let upstream_has_cors = upstream_headers .map(|h| h.contains_key(header::ACCESS_CONTROL_ALLOW_ORIGIN)) .unwrap_or(false); - + // Check if upstream has Vary header let upstream_vary = upstream_headers .and_then(|h| h.get(header::VARY)) .and_then(|v| v.to_str().ok()); - + // Add upstream headers if configured if let Some(headers) = upstream_headers { for (key, value) in headers.iter() { @@ -446,19 +564,19 @@ fn build_response_with_status( } } } - + // Always add Via header builder = builder.header(header::VIA, via_header); - + // Always add Vary header with Accept // If upstream has Vary header, prepend "Accept" to it builder = builder.header(header::VARY, build_vary_header(upstream_vary)); - + // Only set CORS header if upstream didn't provide one if !upstream_has_cors { builder = builder.header(header::ACCESS_CONTROL_ALLOW_ORIGIN, "*"); } - + builder .body(Body::from(data)) .expect("Failed to build response with status") @@ -474,10 +592,9 @@ pub async fn metrics_handler(State(state): State) -> impl IntoResponse let stats = state.cache.stats(); let body = format!( "# Cache 
Statistics\ncache_entries {}\ncache_size_bytes {}\n", - stats.entry_count, - stats.weighted_size + stats.entry_count, stats.weighted_size ); - + ( StatusCode::OK, [(header::CONTENT_TYPE, "text/plain; version=0.0.4")], @@ -495,14 +612,12 @@ pub enum ProxyError { impl IntoResponse for ProxyError { fn into_response(self) -> Response { let (status, message) = match self { - ProxyError::PathNotAllowed => { - (StatusCode::FORBIDDEN, "Path not allowed".to_string()) - } + ProxyError::PathNotAllowed => (StatusCode::FORBIDDEN, "Path not allowed".to_string()), ProxyError::UpstreamError(e) => { (StatusCode::BAD_GATEWAY, format!("Upstream error: {}", e)) } }; - + (status, message).into_response() } } @@ -519,9 +634,15 @@ mod tests { upstream_headers.insert(header::CONTENT_TYPE, HeaderValue::from_static("image/jpeg")); upstream_headers.insert(header::VIA, HeaderValue::from_static("upstream-proxy")); upstream_headers.insert(header::CACHE_CONTROL, HeaderValue::from_static("no-cache")); - upstream_headers.insert(HeaderName::from_static("x-cache-status"), HeaderValue::from_static("upstream-hit")); - upstream_headers.insert(HeaderName::from_static("x-custom-header"), HeaderValue::from_static("custom-value")); - + upstream_headers.insert( + HeaderName::from_static("x-cache-status"), + HeaderValue::from_static("upstream-hit"), + ); + upstream_headers.insert( + HeaderName::from_static("x-custom-header"), + HeaderValue::from_static("custom-value"), + ); + // Build response with different content-type let response = build_response( Bytes::from("test data"), @@ -530,41 +651,62 @@ mod tests { Some(&upstream_headers), true, ); - + let headers = response.headers(); - + // Content-Type should only have the proxy's value (image/avif), not upstream's (image/jpeg) let content_types: Vec<_> = headers.get_all(header::CONTENT_TYPE).iter().collect(); - assert_eq!(content_types.len(), 1, "Content-Type should not be duplicated"); + assert_eq!( + content_types.len(), + 1, + "Content-Type should not be 
duplicated" + ); assert_eq!(content_types[0], "image/avif"); - + // Via should only have the proxy's value let via_values: Vec<_> = headers.get_all(header::VIA).iter().collect(); assert_eq!(via_values.len(), 1, "Via should not be duplicated"); assert_eq!(via_values[0], "akkoproxy/1.0"); - + // Cache-Control should only have the proxy's value let cache_control_values: Vec<_> = headers.get_all(header::CACHE_CONTROL).iter().collect(); - assert_eq!(cache_control_values.len(), 1, "Cache-Control should not be duplicated"); - assert_eq!(cache_control_values[0], "public, max-age=31536000, immutable"); - + assert_eq!( + cache_control_values.len(), + 1, + "Cache-Control should not be duplicated" + ); + assert_eq!( + cache_control_values[0], + "public, max-age=31536000, immutable" + ); + // X-Cache-Status should only have the proxy's value let x_cache_status_values: Vec<_> = headers.get_all("x-cache-status").iter().collect(); - assert_eq!(x_cache_status_values.len(), 1, "X-Cache-Status should not be duplicated"); + assert_eq!( + x_cache_status_values.len(), + 1, + "X-Cache-Status should not be duplicated" + ); assert_eq!(x_cache_status_values[0], "HIT"); - + // Custom header should be preserved assert_eq!(headers.get("x-custom-header").unwrap(), "custom-value"); } - + #[test] fn test_build_response_with_status_no_duplicate_headers() { // Create upstream headers let mut upstream_headers = HeaderMap::new(); upstream_headers.insert(header::VIA, HeaderValue::from_static("upstream-proxy")); - upstream_headers.insert(header::ACCESS_CONTROL_ALLOW_ORIGIN, HeaderValue::from_static("https://example.com")); - upstream_headers.insert(HeaderName::from_static("x-custom-header"), HeaderValue::from_static("custom-value")); - + upstream_headers.insert( + header::ACCESS_CONTROL_ALLOW_ORIGIN, + HeaderValue::from_static("https://example.com"), + ); + upstream_headers.insert( + HeaderName::from_static("x-custom-header"), + HeaderValue::from_static("custom-value"), + ); + // Build response let 
response = build_response_with_status( Bytes::from("redirect"), @@ -572,76 +714,86 @@ mod tests { "akkoproxy/1.0", Some(&upstream_headers), ); - + let headers = response.headers(); - + // Via should only have the proxy's value let via_values: Vec<_> = headers.get_all(header::VIA).iter().collect(); assert_eq!(via_values.len(), 1, "Via should not be duplicated"); assert_eq!(via_values[0], "akkoproxy/1.0"); - + // Access-Control-Allow-Origin should have upstream's value (not replaced) - let acao_values: Vec<_> = headers.get_all(header::ACCESS_CONTROL_ALLOW_ORIGIN).iter().collect(); - assert_eq!(acao_values.len(), 1, "Access-Control-Allow-Origin should not be duplicated"); + let acao_values: Vec<_> = headers + .get_all(header::ACCESS_CONTROL_ALLOW_ORIGIN) + .iter() + .collect(); + assert_eq!( + acao_values.len(), + 1, + "Access-Control-Allow-Origin should not be duplicated" + ); assert_eq!(acao_values[0], "https://example.com"); - + // Custom header should be preserved assert_eq!(headers.get("x-custom-header").unwrap(), "custom-value"); } - + #[test] fn test_parse_query_for_format() { // Test format=avif let (format, remaining) = parse_query_for_format("format=avif&other=value"); assert_eq!(format, Some(OutputFormat::Avif)); assert_eq!(remaining, "other=value"); - + // Test format=webp let (format, remaining) = parse_query_for_format("format=webp"); assert_eq!(format, Some(OutputFormat::WebP)); assert_eq!(remaining, ""); - + // Test no format parameter let (format, remaining) = parse_query_for_format("other=value&another=test"); assert_eq!(format, None); assert_eq!(remaining, "other=value&another=test"); - + // Test format with unknown value let (format, remaining) = parse_query_for_format("format=jpeg&other=value"); assert_eq!(format, None); assert_eq!(remaining, "other=value"); - + // Test format in middle let (format, remaining) = parse_query_for_format("a=1&format=avif&b=2"); assert_eq!(format, Some(OutputFormat::Avif)); assert_eq!(remaining, "a=1&b=2"); - + // 
Test format with URL encoding (spaces as +) let (format, remaining) = parse_query_for_format("format=webp+test&other=value"); assert_eq!(format, None); // Should not match due to extra text assert_eq!(remaining, "other=value"); - + // Test format with case insensitivity let (format, remaining) = parse_query_for_format("format=AVIF"); assert_eq!(format, Some(OutputFormat::Avif)); assert_eq!(remaining, ""); - + let (format, remaining) = parse_query_for_format("format=WebP"); assert_eq!(format, Some(OutputFormat::WebP)); assert_eq!(remaining, ""); - + // Test format with whitespace (+ is space in query strings) let (format, remaining) = parse_query_for_format("format=+avif+&other=value"); assert_eq!(format, Some(OutputFormat::Avif)); assert_eq!(remaining, "other=value"); } - + #[test] fn test_cors_header_follows_upstream() { // Test when upstream provides CORS header let mut upstream_headers = HeaderMap::new(); - upstream_headers.insert(header::ACCESS_CONTROL_ALLOW_ORIGIN, HeaderValue::from_static("https://example.com")); - + upstream_headers.insert( + header::ACCESS_CONTROL_ALLOW_ORIGIN, + HeaderValue::from_static("https://example.com"), + ); + let response = build_response( Bytes::from("test"), "text/plain", @@ -649,10 +801,16 @@ mod tests { Some(&upstream_headers), false, ); - + // Should use upstream CORS value - assert_eq!(response.headers().get(header::ACCESS_CONTROL_ALLOW_ORIGIN).unwrap(), "https://example.com"); - + assert_eq!( + response + .headers() + .get(header::ACCESS_CONTROL_ALLOW_ORIGIN) + .unwrap(), + "https://example.com" + ); + // Test when upstream doesn't provide CORS header let response = build_response( Bytes::from("test"), @@ -661,11 +819,17 @@ mod tests { None, false, ); - + // Should use default "*" - assert_eq!(response.headers().get(header::ACCESS_CONTROL_ALLOW_ORIGIN).unwrap(), "*"); + assert_eq!( + response + .headers() + .get(header::ACCESS_CONTROL_ALLOW_ORIGIN) + .unwrap(), + "*" + ); } - + #[test] fn test_vary_header_always_present() { 
// Test with no upstream headers - should have Vary: Accept @@ -676,9 +840,9 @@ mod tests { None, false, ); - + assert_eq!(response.headers().get(header::VARY).unwrap(), "Accept"); - + // Test with upstream headers (no Vary) - should still have Vary: Accept let upstream_headers = HeaderMap::new(); let response = build_response( @@ -688,16 +852,16 @@ mod tests { Some(&upstream_headers), false, ); - + assert_eq!(response.headers().get(header::VARY).unwrap(), "Accept"); } - + #[test] fn test_vary_header_prepends_accept_to_upstream() { // Test when upstream has Vary header without Accept let mut upstream_headers = HeaderMap::new(); upstream_headers.insert(header::VARY, HeaderValue::from_static("Origin, User-Agent")); - + let response = build_response( Bytes::from("test"), "text/plain", @@ -705,13 +869,16 @@ mod tests { Some(&upstream_headers), false, ); - - assert_eq!(response.headers().get(header::VARY).unwrap(), "Accept, Origin, User-Agent"); - + + assert_eq!( + response.headers().get(header::VARY).unwrap(), + "Accept, Origin, User-Agent" + ); + // Test when upstream has Vary header with Accept already present let mut upstream_headers = HeaderMap::new(); upstream_headers.insert(header::VARY, HeaderValue::from_static("Accept, Origin")); - + let response = build_response( Bytes::from("test"), "text/plain", @@ -719,14 +886,17 @@ mod tests { Some(&upstream_headers), false, ); - + // Should not duplicate Accept - assert_eq!(response.headers().get(header::VARY).unwrap(), "Accept, Origin"); - + assert_eq!( + response.headers().get(header::VARY).unwrap(), + "Accept, Origin" + ); + // Test when upstream has Vary header with accept in different case let mut upstream_headers = HeaderMap::new(); upstream_headers.insert(header::VARY, HeaderValue::from_static("ACCEPT, Origin")); - + let response = build_response( Bytes::from("test"), "text/plain", @@ -734,11 +904,14 @@ mod tests { Some(&upstream_headers), false, ); - + // Should recognize case-insensitive match and not duplicate - 
assert_eq!(response.headers().get(header::VARY).unwrap(), "ACCEPT, Origin"); + assert_eq!( + response.headers().get(header::VARY).unwrap(), + "ACCEPT, Origin" + ); } - + #[test] fn test_vary_header_in_response_with_status() { // Test without upstream Vary header @@ -748,20 +921,171 @@ mod tests { "akkoproxy/1.0", None, ); - + assert_eq!(response.headers().get(header::VARY).unwrap(), "Accept"); - + // Test with upstream Vary header let mut upstream_headers = HeaderMap::new(); upstream_headers.insert(header::VARY, HeaderValue::from_static("Origin")); - + let response = build_response_with_status( Bytes::from("test"), StatusCode::MOVED_PERMANENTLY, "akkoproxy/1.0", Some(&upstream_headers), ); + + assert_eq!( + response.headers().get(header::VARY).unwrap(), + "Accept, Origin" + ); + } + + #[test] + fn test_is_trusted_proxy() { + // Test with empty trusted proxies list + let trusted_proxies: Vec = vec![]; + let ip = "192.168.1.1".parse::().unwrap(); + assert!(!is_trusted_proxy(&ip, &trusted_proxies)); + + // Test with single IP match + let trusted_proxies = vec!["192.168.1.1".to_string()]; + let ip = "192.168.1.1".parse::().unwrap(); + assert!(is_trusted_proxy(&ip, &trusted_proxies)); + + // Test with single IP no match + let ip = "192.168.1.2".parse::().unwrap(); + assert!(!is_trusted_proxy(&ip, &trusted_proxies)); + + // Test with CIDR range match + let trusted_proxies = vec!["192.168.1.0/24".to_string()]; + let ip = "192.168.1.100".parse::().unwrap(); + assert!(is_trusted_proxy(&ip, &trusted_proxies)); + + // Test with CIDR range no match + let ip = "192.168.2.100".parse::().unwrap(); + assert!(!is_trusted_proxy(&ip, &trusted_proxies)); + + // Test with multiple entries + let trusted_proxies = vec![ + "10.0.0.0/8".to_string(), + "172.16.0.0/12".to_string(), + "192.168.1.1".to_string(), + ]; + assert!(is_trusted_proxy(&"10.5.5.5".parse::().unwrap(), &trusted_proxies)); + assert!(is_trusted_proxy(&"172.20.1.1".parse::().unwrap(), &trusted_proxies)); + 
assert!(is_trusted_proxy(&"192.168.1.1".parse::().unwrap(), &trusted_proxies)); + assert!(!is_trusted_proxy(&"8.8.8.8".parse::().unwrap(), &trusted_proxies)); + + // Test with IPv6 + let trusted_proxies = vec!["::1".to_string(), "fe80::/10".to_string()]; + assert!(is_trusted_proxy(&"::1".parse::().unwrap(), &trusted_proxies)); + assert!(is_trusted_proxy(&"fe80::1".parse::().unwrap(), &trusted_proxies)); + assert!(!is_trusted_proxy(&"2001:db8::1".parse::().unwrap(), &trusted_proxies)); + } + + #[test] + fn test_apply_forwarded_headers_disabled() { + use reqwest::Client; + + let config = Config::with_upstream("https://example.com".to_string()); + assert!(!config.server.forward_headers_enabled); + + let client = Client::new(); + let builder = client.get("https://example.com/test"); + + let mut headers = HeaderMap::new(); + headers.insert("x-forwarded-proto", HeaderValue::from_static("https")); + headers.insert("x-forwarded-for", HeaderValue::from_static("1.2.3.4")); + headers.insert("x-forwarded-host", HeaderValue::from_static("example.org")); + + let client_ip = "10.0.0.1".parse::().unwrap(); + + let result_builder = apply_forwarded_headers(builder, &headers, client_ip, &config); + + // Headers should not be forwarded when disabled + let request = result_builder.build().unwrap(); + assert!(request.headers().get("x-forwarded-proto").is_none()); + assert!(request.headers().get("x-forwarded-for").is_none()); + assert!(request.headers().get("x-forwarded-host").is_none()); + } + + #[test] + fn test_apply_forwarded_headers_from_trusted_proxy() { + use reqwest::Client; + + let mut config = Config::with_upstream("https://example.com".to_string()); + config.server.forward_headers_enabled = true; + config.server.trusted_proxies = vec!["10.0.0.0/8".to_string()]; + + let client = Client::new(); + let builder = client.get("https://example.com/test"); + + let mut headers = HeaderMap::new(); + headers.insert("x-forwarded-proto", HeaderValue::from_static("https")); + 
headers.insert("x-forwarded-for", HeaderValue::from_static("1.2.3.4")); + headers.insert("x-forwarded-host", HeaderValue::from_static("example.org")); + + let client_ip = "10.0.0.1".parse::().unwrap(); // Trusted + + let result_builder = apply_forwarded_headers(builder, &headers, client_ip, &config); + + // Headers should be forwarded from trusted proxy + let request = result_builder.build().unwrap(); + assert_eq!(request.headers().get("x-forwarded-proto").unwrap(), "https"); + assert_eq!(request.headers().get("x-forwarded-for").unwrap(), "1.2.3.4"); + assert_eq!(request.headers().get("x-forwarded-host").unwrap(), "example.org"); + } + + #[test] + fn test_apply_forwarded_headers_from_untrusted_proxy() { + use reqwest::Client; + + let mut config = Config::with_upstream("https://example.com".to_string()); + config.server.forward_headers_enabled = true; + config.server.trusted_proxies = vec!["10.0.0.0/8".to_string()]; + + let client = Client::new(); + let builder = client.get("https://example.com/test"); + + let mut headers = HeaderMap::new(); + headers.insert("x-forwarded-proto", HeaderValue::from_static("https")); + headers.insert("x-forwarded-for", HeaderValue::from_static("1.2.3.4")); + headers.insert("x-forwarded-host", HeaderValue::from_static("example.org")); + + let client_ip = "8.8.8.8".parse::().unwrap(); // Untrusted + + let result_builder = apply_forwarded_headers(builder, &headers, client_ip, &config); + + // Proto and Host should not be forwarded from untrusted source + let request = result_builder.build().unwrap(); + assert!(request.headers().get("x-forwarded-proto").is_none()); + assert!(request.headers().get("x-forwarded-host").is_none()); + + // X-Forwarded-For should be set to actual client IP only (not appending untrusted value) + let xff = request.headers().get("x-forwarded-for").unwrap().to_str().unwrap(); + assert_eq!(xff, "8.8.8.8"); + } + + #[test] + fn test_apply_forwarded_headers_no_existing_xff() { + use reqwest::Client; + + let mut config 
= Config::with_upstream("https://example.com".to_string()); + config.server.forward_headers_enabled = true; + config.server.trusted_proxies = vec!["10.0.0.0/8".to_string()]; + + let client = Client::new(); + let builder = client.get("https://example.com/test"); + + let headers = HeaderMap::new(); // No X-Forwarded headers + + let client_ip = "8.8.8.8".parse::().unwrap(); // Untrusted + + let result_builder = apply_forwarded_headers(builder, &headers, client_ip, &config); - assert_eq!(response.headers().get(header::VARY).unwrap(), "Accept, Origin"); + // X-Forwarded-For should be set to actual client IP + let request = result_builder.build().unwrap(); + assert_eq!(request.headers().get("x-forwarded-for").unwrap(), "8.8.8.8"); } }