From ca8bbeb5b438f15a4d5b709943a0af99fff63bf6 Mon Sep 17 00:00:00 2001 From: Peter Clemente III Date: Fri, 29 May 2026 13:12:19 -0400 Subject: [PATCH] Polish: add unit tests, CI, and formatting - Apply cargo fmt across all source files (no logic changes) - Add 18 unit tests covering pure logic: endpoint URL/body construction, Ollama and llama.cpp stream-line parsing, timing parsing (incl. ms->ns conversion and invalid-input defaults), and mean/std statistics - Add GitHub Actions CI workflow (fmt check + clippy -D warnings + build + test) - Document development/test workflow in README Co-Authored-By: Claude Opus 4.8 (1M context) --- .github/workflows/ci.yml | 42 ++++++++++ README.md | 11 +++ src/bench.rs | 22 ++--- src/detect.rs | 5 +- src/endpoints.rs | 158 +++++++++++++++++++++++++++++++++++- src/main.rs | 11 ++- src/metrics.rs | 46 ++++++++--- src/output.rs | 168 +++++++++++++++++++++++++++++---------- 8 files changed, 395 insertions(+), 68 deletions(-) create mode 100644 .github/workflows/ci.yml diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml new file mode 100644 index 0000000..25fc396 --- /dev/null +++ b/.github/workflows/ci.yml @@ -0,0 +1,42 @@ +name: CI + +on: + push: + branches: [master] + pull_request: + +env: + CARGO_TERM_COLOR: always + +jobs: + build-test-lint: + runs-on: ubuntu-latest + steps: + - uses: actions/checkout@v4 + + - name: Install stable Rust toolchain + uses: dtolnay/rust-toolchain@stable + with: + components: rustfmt, clippy + + - name: Cache cargo registry and build + uses: actions/cache@v4 + with: + path: | + ~/.cargo/registry + ~/.cargo/git + target + key: ${{ runner.os }}-cargo-${{ hashFiles('Cargo.lock') }} + restore-keys: ${{ runner.os }}-cargo- + + - name: Check formatting + run: cargo fmt --all --check + + - name: Clippy (deny warnings) + run: cargo clippy --all-targets -- -D warnings + + - name: Build + run: cargo build --verbose + + - name: Test + run: cargo test --verbose diff --git a/README.md b/README.md index 
30319be..ae8be8c 100644 --- a/README.md +++ b/README.md @@ -43,4 +43,15 @@ inference-bench throughput --url http://localhost:8080 # custom endpoint | `--url` | | Override endpoint URL | | `--prompt` | | Custom prompt text | +## Development + +```bash +cargo test # run the unit test suite +cargo clippy # lint +cargo fmt # format +``` + +CI (build + test + clippy + fmt) runs on pushes to `master` and on every pull +request via GitHub Actions (`.github/workflows/ci.yml`). + Built with Rust + comfy-table. diff --git a/src/bench.rs b/src/bench.rs index 2318ed9..2c2e8ac 100644 --- a/src/bench.rs +++ b/src/bench.rs @@ -73,9 +73,7 @@ fn stream_generate(ep: &Endpoint, model: &str, prompt: &str, num_predict: u32) - EndpointKind::Ollama | EndpointKind::FastFlowLM => { endpoints::parse_ollama_stream_line(&line) } - EndpointKind::LlamaCpp => { - endpoints::parse_llamacpp_stream_line(&line) - } + EndpointKind::LlamaCpp => endpoints::parse_llamacpp_stream_line(&line), }; if let Some((token, done)) = parsed { @@ -93,7 +91,11 @@ fn stream_generate(ep: &Endpoint, model: &str, prompt: &str, num_predict: u32) - let total_time = start.elapsed().as_secs_f64(); let ttft = first_token_time.unwrap_or(total_time * 1000.0); - let tps = if total_time > 0.0 { token_count as f64 / total_time } else { 0.0 }; + let tps = if total_time > 0.0 { + token_count as f64 / total_time + } else { + 0.0 + }; (ttft, tps, token_count) } @@ -108,7 +110,8 @@ pub fn run_ttft(endpoints: &[Endpoint], config: &BenchConfig) -> Vec Vec target_len { - match long_prompt.char_indices().nth(target_len.min(long_prompt.chars().count())) { + match long_prompt + .char_indices() + .nth(target_len.min(long_prompt.chars().count())) + { Some((byte_idx, _)) => long_prompt[..byte_idx].to_string(), None => long_prompt, } @@ -231,9 +237,7 @@ pub fn run_compare(endpoints: &[Endpoint], config: &BenchConfig) -> Vec { endpoints::parse_ollama_timings(&body_str) } - EndpointKind::LlamaCpp => { - endpoints::parse_llamacpp_timings(&body_str) - } +
EndpointKind::LlamaCpp => endpoints::parse_llamacpp_timings(&body_str), }; match (timings.prompt_eval_count, timings.prompt_eval_duration_ns) { (Some(count), Some(dur)) if dur > 0 => { diff --git a/src/detect.rs b/src/detect.rs index 335143f..3941929 100644 --- a/src/detect.rs +++ b/src/detect.rs @@ -42,7 +42,10 @@ fn probe(ep: &Endpoint) -> bool { } }; - match ureq::get(&url).timeout(std::time::Duration::from_secs(2)).call() { + match ureq::get(&url) + .timeout(std::time::Duration::from_secs(2)) + .call() + { Ok(resp) => resp.status() == 200, Err(_) => false, } diff --git a/src/endpoints.rs b/src/endpoints.rs index 5ef31fc..619d8c6 100644 --- a/src/endpoints.rs +++ b/src/endpoints.rs @@ -28,7 +28,13 @@ impl Endpoint { } /// Build the request body for a generation request. - pub fn build_body(&self, model: &str, prompt: &str, num_predict: u32, stream: bool) -> serde_json::Value { + pub fn build_body( + &self, + model: &str, + prompt: &str, + num_predict: u32, + stream: bool, + ) -> serde_json::Value { match self.kind { EndpointKind::Ollama | EndpointKind::FastFlowLM => { serde_json::json!({ @@ -56,7 +62,11 @@ impl Endpoint { pub fn parse_ollama_stream_line(line: &str) -> Option<(String, bool)> { let v: serde_json::Value = serde_json::from_str(line).ok()?; let done = v.get("done").and_then(|d| d.as_bool()).unwrap_or(false); - let token = v.get("response").and_then(|r| r.as_str()).unwrap_or("").to_string(); + let token = v + .get("response") + .and_then(|r| r.as_str()) + .unwrap_or("") + .to_string(); Some((token, done)) } @@ -69,7 +79,11 @@ pub fn parse_llamacpp_stream_line(line: &str) -> Option<(String, bool)> { } let v: serde_json::Value = serde_json::from_str(data).ok()?; let stop = v.get("stop").and_then(|s| s.as_bool()).unwrap_or(false); - let token = v.get("content").and_then(|c| c.as_str()).unwrap_or("").to_string(); + let token = v + .get("content") + .and_then(|c| c.as_str()) + .unwrap_or("") + .to_string(); Some((token, stop)) } @@ -122,3 +136,141 @@ pub 
fn parse_llamacpp_timings(body: &str) -> ApiTimings { prompt_eval_count: prompt_n, } } + +#[cfg(test)] +mod tests { + use super::*; + + fn ep(kind: EndpointKind) -> Endpoint { + Endpoint { + name: "test".to_string(), + base_url: "http://localhost:9999".to_string(), + kind, + } + } + + #[test] + fn generate_url_per_kind() { + assert_eq!( + ep(EndpointKind::Ollama).generate_url(), + "http://localhost:9999/api/generate" + ); + assert_eq!( + ep(EndpointKind::FastFlowLM).generate_url(), + "http://localhost:9999/api/generate" + ); + assert_eq!( + ep(EndpointKind::LlamaCpp).generate_url(), + "http://localhost:9999/completion" + ); + } + + #[test] + fn build_body_ollama_shape() { + let body = ep(EndpointKind::Ollama).build_body("qwen3:4b", "hi", 42, true); + assert_eq!(body["model"], "qwen3:4b"); + assert_eq!(body["prompt"], "hi"); + assert_eq!(body["stream"], true); + assert_eq!(body["options"]["num_predict"], 42); + } + + #[test] + fn build_body_llamacpp_shape() { + let body = ep(EndpointKind::LlamaCpp).build_body("ignored", "hi", 42, false); + assert_eq!(body["prompt"], "hi"); + assert_eq!(body["n_predict"], 42); + assert_eq!(body["stream"], false); + // llama.cpp body carries no model field + assert!(body.get("model").is_none()); + } + + #[test] + fn parse_ollama_stream_token() { + let line = r#"{"response":"Hello","done":false}"#; + assert_eq!( + parse_ollama_stream_line(line), + Some(("Hello".to_string(), false)) + ); + } + + #[test] + fn parse_ollama_stream_done() { + let line = r#"{"response":"","done":true}"#; + assert_eq!(parse_ollama_stream_line(line), Some((String::new(), true))); + } + + #[test] + fn parse_ollama_stream_invalid_json() { + assert_eq!(parse_ollama_stream_line("not json"), None); + } + + #[test] + fn parse_llamacpp_stream_token() { + let line = r#"data: {"content":"Hi","stop":false}"#; + assert_eq!( + parse_llamacpp_stream_line(line), + Some(("Hi".to_string(), false)) + ); + } + + #[test] + fn parse_llamacpp_stream_done_marker() { + assert_eq!( + 
parse_llamacpp_stream_line("data: [DONE]"), + Some((String::new(), true)) + ); + } + + #[test] + fn parse_llamacpp_stream_stop_flag() { + let line = r#"data: {"content":"","stop":true}"#; + assert_eq!( + parse_llamacpp_stream_line(line), + Some((String::new(), true)) + ); + } + + #[test] + fn parse_llamacpp_stream_without_prefix_is_none() { + // Lines that are not SSE `data:` frames are ignored. + assert_eq!(parse_llamacpp_stream_line("event: ping"), None); + } + + #[test] + fn parse_ollama_timings_extracts_fields() { + let body = r#"{"total_duration":1000,"prompt_eval_duration":400, + "eval_duration":600,"eval_count":12,"prompt_eval_count":34}"#; + let t = parse_ollama_timings(body); + assert_eq!(t.total_duration_ns, Some(1000)); + assert_eq!(t.prompt_eval_duration_ns, Some(400)); + assert_eq!(t.eval_duration_ns, Some(600)); + assert_eq!(t.eval_count, Some(12)); + assert_eq!(t.prompt_eval_count, Some(34)); + } + + #[test] + fn parse_ollama_timings_invalid_is_default() { + let t = parse_ollama_timings("garbage"); + assert_eq!(t.total_duration_ns, None); + assert_eq!(t.prompt_eval_count, None); + } + + #[test] + fn parse_llamacpp_timings_converts_ms_to_ns() { + let body = r#"{"timings":{"prompt_ms":10.0,"predicted_ms":20.0, + "predicted_n":5,"prompt_n":7}}"#; + let t = parse_llamacpp_timings(body); + assert_eq!(t.prompt_eval_duration_ns, Some(10_000_000)); + assert_eq!(t.eval_duration_ns, Some(20_000_000)); + assert_eq!(t.total_duration_ns, Some(30_000_000)); + assert_eq!(t.eval_count, Some(5)); + assert_eq!(t.prompt_eval_count, Some(7)); + } + + #[test] + fn parse_llamacpp_timings_missing_block_is_default() { + let t = parse_llamacpp_timings(r#"{"content":"hi"}"#); + assert_eq!(t.prompt_eval_duration_ns, None); + assert_eq!(t.total_duration_ns, None); + } +} diff --git a/src/main.rs b/src/main.rs index ef757f5..cbae654 100644 --- a/src/main.rs +++ b/src/main.rs @@ -7,7 +7,11 @@ mod output; use clap::{Parser, Subcommand}; #[derive(Parser)] -#[command(name = 
"inference-bench", version, about = "LLM inference benchmark CLI")] +#[command( + name = "inference-bench", + version, + about = "LLM inference benchmark CLI" +)] struct Cli { #[command(subcommand)] command: Commands, @@ -64,7 +68,10 @@ fn main() { let config = bench::BenchConfig { model: cli.model.clone(), - prompt: cli.prompt.clone().unwrap_or_else(|| "Explain the theory of relativity in simple terms.".to_string()), + prompt: cli + .prompt + .clone() + .unwrap_or_else(|| "Explain the theory of relativity in simple terms.".to_string()), tokens: cli.tokens, repeat: cli.repeat, warmup: cli.warmup, diff --git a/src/metrics.rs b/src/metrics.rs index b7e754b..c6dc0c7 100644 --- a/src/metrics.rs +++ b/src/metrics.rs @@ -54,15 +54,12 @@ fn read_gpu_mem_used(mem_type: &str) -> u64 { if !name.starts_with("card") || name.contains('-') { continue; } - let total_path = format!( - "{}/{}/device/mem_info_{}_total", - drm_dir, name, mem_type - ); - let used_path = format!( - "{}/{}/device/mem_info_{}_used", - drm_dir, name, mem_type - ); - if let (Ok(total_s), Ok(used_s)) = (fs::read_to_string(&total_path), fs::read_to_string(&used_path)) { + let total_path = format!("{}/{}/device/mem_info_{}_total", drm_dir, name, mem_type); + let used_path = format!("{}/{}/device/mem_info_{}_used", drm_dir, name, mem_type); + if let (Ok(total_s), Ok(used_s)) = ( + fs::read_to_string(&total_path), + fs::read_to_string(&used_path), + ) { let total: u64 = total_s.trim().parse().unwrap_or(0); let used: u64 = used_s.trim().parse().unwrap_or(0); if total > 0 { @@ -86,3 +83,34 @@ pub fn mean_std(values: &[f64]) -> (f64, f64) { let variance = values.iter().map(|v| (v - mean).powi(2)).sum::() / (n - 1.0); (mean, variance.sqrt()) } + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn mean_std_empty() { + assert_eq!(mean_std(&[]), (0.0, 0.0)); + } + + #[test] + fn mean_std_single() { + // A single sample has no spread; std is defined as 0. 
+ assert_eq!(mean_std(&[5.0]), (5.0, 0.0)); + } + + #[test] + fn mean_std_known_values() { + // Sample std (n-1 denominator) of [2,4,4,4,5,5,7,9] is 2.13809... + let (mean, std) = mean_std(&[2.0, 4.0, 4.0, 4.0, 5.0, 5.0, 7.0, 9.0]); + assert!((mean - 5.0).abs() < 1e-9); + assert!((std - 2.138_089_935_299_395).abs() < 1e-9); + } + + #[test] + fn mean_std_identical_values_zero_std() { + let (mean, std) = mean_std(&[3.0, 3.0, 3.0]); + assert!((mean - 3.0).abs() < 1e-9); + assert!(std.abs() < 1e-9); + } +} diff --git a/src/output.rs b/src/output.rs index 3efd505..9e25c36 100644 --- a/src/output.rs +++ b/src/output.rs @@ -1,5 +1,5 @@ use crate::bench::{BenchResult, CompareResult, ContextResult}; -use comfy_table::{Table, ContentArrangement, presets::UTF8_FULL, Cell, Color}; +use comfy_table::{presets::UTF8_FULL, Cell, Color, ContentArrangement, Table}; #[derive(Debug, Clone, Copy)] pub enum OutputFormat { @@ -12,22 +12,31 @@ pub fn render(results: &[BenchResult], title: &str, format: OutputFormat) { match format { OutputFormat::Table => render_table(results, title), OutputFormat::Json => { - let data: Vec<_> = results.iter().map(|r| { - serde_json::json!({ - "endpoint": r.endpoint, - "metric": r.metric_name, - "mean": r.mean, - "std": r.std, - "unit": r.unit, - "values": r.values, + let data: Vec<_> = results + .iter() + .map(|r| { + serde_json::json!({ + "endpoint": r.endpoint, + "metric": r.metric_name, + "mean": r.mean, + "std": r.std, + "unit": r.unit, + "values": r.values, + }) }) - }).collect(); - println!("{}", serde_json::to_string_pretty(&data).unwrap_or_default()); + .collect(); + println!( + "{}", + serde_json::to_string_pretty(&data).unwrap_or_default() + ); } OutputFormat::Csv => { println!("endpoint,metric,mean,std,unit"); for r in results { - println!("{},{},{:.2},{:.2},{}", r.endpoint, r.metric_name, r.mean, r.std, r.unit); + println!( + "{},{},{:.2},{:.2},{}", + r.endpoint, r.metric_name, r.mean, r.std, r.unit + ); } } } @@ -38,15 +47,36 @@ fn 
render_table(results: &[BenchResult], title: &str) { table.load_preset(UTF8_FULL); table.set_content_arrangement(ContentArrangement::Dynamic); table.set_header(vec![ - Cell::new("Endpoint").fg(Color::Rgb { r: 0, g: 255, b: 200 }), - Cell::new(format!("{} (mean)", title)).fg(Color::Rgb { r: 0, g: 255, b: 200 }), - Cell::new("Std Dev").fg(Color::Rgb { r: 0, g: 255, b: 200 }), - Cell::new("Runs").fg(Color::Rgb { r: 0, g: 255, b: 200 }), - Cell::new("Mem Δ (sys)").fg(Color::Rgb { r: 0, g: 255, b: 200 }), + Cell::new("Endpoint").fg(Color::Rgb { + r: 0, + g: 255, + b: 200, + }), + Cell::new(format!("{} (mean)", title)).fg(Color::Rgb { + r: 0, + g: 255, + b: 200, + }), + Cell::new("Std Dev").fg(Color::Rgb { + r: 0, + g: 255, + b: 200, + }), + Cell::new("Runs").fg(Color::Rgb { + r: 0, + g: 255, + b: 200, + }), + Cell::new("Mem Δ (sys)").fg(Color::Rgb { + r: 0, + g: 255, + b: 200, + }), ]); for r in results { - let mem_delta = r.mem_after.system_available_mib as i64 - r.mem_before.system_available_mib as i64; + let mem_delta = + r.mem_after.system_available_mib as i64 - r.mem_before.system_available_mib as i64; table.add_row(vec![ Cell::new(&r.endpoint), Cell::new(format!("{:.1} {}", r.mean, r.unit)), @@ -65,10 +95,26 @@ pub fn render_context(results: &[ContextResult], format: OutputFormat) { table.load_preset(UTF8_FULL); table.set_content_arrangement(ContentArrangement::Dynamic); table.set_header(vec![ - Cell::new("Endpoint").fg(Color::Rgb { r: 0, g: 255, b: 200 }), - Cell::new("Context").fg(Color::Rgb { r: 0, g: 255, b: 200 }), - Cell::new("TTFT (ms)").fg(Color::Rgb { r: 0, g: 255, b: 200 }), - Cell::new("TPS").fg(Color::Rgb { r: 0, g: 255, b: 200 }), + Cell::new("Endpoint").fg(Color::Rgb { + r: 0, + g: 255, + b: 200, + }), + Cell::new("Context").fg(Color::Rgb { + r: 0, + g: 255, + b: 200, + }), + Cell::new("TTFT (ms)").fg(Color::Rgb { + r: 0, + g: 255, + b: 200, + }), + Cell::new("TPS").fg(Color::Rgb { + r: 0, + g: 255, + b: 200, + }), ]); for r in results { 
table.add_row(vec![ @@ -81,20 +127,29 @@ pub fn render_context(results: &[ContextResult], format: OutputFormat) { println!("{}", table); } OutputFormat::Json => { - let data: Vec<_> = results.iter().map(|r| { - serde_json::json!({ - "endpoint": r.endpoint, - "context_len": r.context_len, - "ttft_ms": r.ttft_ms, - "tps": r.tps, + let data: Vec<_> = results + .iter() + .map(|r| { + serde_json::json!({ + "endpoint": r.endpoint, + "context_len": r.context_len, + "ttft_ms": r.ttft_ms, + "tps": r.tps, + }) }) - }).collect(); - println!("{}", serde_json::to_string_pretty(&data).unwrap_or_default()); + .collect(); + println!( + "{}", + serde_json::to_string_pretty(&data).unwrap_or_default() + ); } OutputFormat::Csv => { println!("endpoint,context_len,ttft_ms,tps"); for r in results { - println!("{},{},{:.0},{:.1}", r.endpoint, r.context_len, r.ttft_ms, r.tps); + println!( + "{},{},{:.0},{:.1}", + r.endpoint, r.context_len, r.ttft_ms, r.tps + ); } } } @@ -107,10 +162,26 @@ pub fn render_compare(results: &[CompareResult], format: OutputFormat) { table.load_preset(UTF8_FULL); table.set_content_arrangement(ContentArrangement::Dynamic); table.set_header(vec![ - Cell::new("Endpoint").fg(Color::Rgb { r: 0, g: 255, b: 200 }), - Cell::new("TTFT (ms)").fg(Color::Rgb { r: 0, g: 255, b: 200 }), - Cell::new("Gen TPS").fg(Color::Rgb { r: 0, g: 255, b: 200 }), - Cell::new("Prompt TPS").fg(Color::Rgb { r: 0, g: 255, b: 200 }), + Cell::new("Endpoint").fg(Color::Rgb { + r: 0, + g: 255, + b: 200, + }), + Cell::new("TTFT (ms)").fg(Color::Rgb { + r: 0, + g: 255, + b: 200, + }), + Cell::new("Gen TPS").fg(Color::Rgb { + r: 0, + g: 255, + b: 200, + }), + Cell::new("Prompt TPS").fg(Color::Rgb { + r: 0, + g: 255, + b: 200, + }), ]); for r in results { let prompt_str = if r.prompt_eval_tps > 0.0 { @@ -128,20 +199,29 @@ pub fn render_compare(results: &[CompareResult], format: OutputFormat) { println!("{}", table); } OutputFormat::Json => { - let data: Vec<_> = results.iter().map(|r| { - 
serde_json::json!({ - "endpoint": r.endpoint, - "ttft_ms": r.ttft_ms, - "tps": r.tps, - "prompt_eval_tps": r.prompt_eval_tps, + let data: Vec<_> = results + .iter() + .map(|r| { + serde_json::json!({ + "endpoint": r.endpoint, + "ttft_ms": r.ttft_ms, + "tps": r.tps, + "prompt_eval_tps": r.prompt_eval_tps, + }) }) - }).collect(); - println!("{}", serde_json::to_string_pretty(&data).unwrap_or_default()); + .collect(); + println!( + "{}", + serde_json::to_string_pretty(&data).unwrap_or_default() + ); } OutputFormat::Csv => { println!("endpoint,ttft_ms,tps,prompt_eval_tps"); for r in results { - println!("{},{:.0},{:.1},{:.1}", r.endpoint, r.ttft_ms, r.tps, r.prompt_eval_tps); + println!( + "{},{:.0},{:.1},{:.1}", + r.endpoint, r.ttft_ms, r.tps, r.prompt_eval_tps + ); } } }