Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
42 changes: 42 additions & 0 deletions .github/workflows/ci.yml
Original file line number Diff line number Diff line change
@@ -0,0 +1,42 @@
# Continuous integration for the crate: formatting, lint, build, and tests.
name: CI

# Triggered by pushes to master and by every pull request.
on:
  push:
    branches: [master]
  pull_request:

env:
  # Ask cargo to always emit coloured output.
  CARGO_TERM_COLOR: always

jobs:
  build-test-lint:
    runs-on: ubuntu-latest
    steps:
      - uses: actions/checkout@v4

      # rustfmt and clippy are needed by the formatting and lint steps below.
      - name: Install stable Rust toolchain
        uses: dtolnay/rust-toolchain@stable
        with:
          components: rustfmt, clippy

      # The cache key is derived from the runner OS and the Cargo.lock hash;
      # restore-keys supplies an OS-only prefix as the fallback key.
      - name: Cache cargo registry and build
        uses: actions/cache@v4
        with:
          path: |
            ~/.cargo/registry
            ~/.cargo/git
            target
          key: ${{ runner.os }}-cargo-${{ hashFiles('Cargo.lock') }}
          restore-keys: ${{ runner.os }}-cargo-

      - name: Check formatting
        run: cargo fmt --all --check

      # `-D warnings` turns every clippy warning into a failing error.
      - name: Clippy (deny warnings)
        run: cargo clippy --all-targets -- -D warnings

      - name: Build
        run: cargo build --verbose

      - name: Test
        run: cargo test --verbose
11 changes: 11 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -43,4 +43,15 @@ inference-bench throughput --url http://localhost:8080 # custom endpoint
| `--url` | | Override endpoint URL |
| `--prompt` | | Custom prompt text |

## Development

```bash
cargo test # run the unit test suite
cargo clippy # lint
cargo fmt # format
```

CI (build, test, clippy, and fmt checks) runs on every push to `master` and on
every pull request via GitHub Actions (`.github/workflows/ci.yml`).

Built with Rust + comfy-table.
22 changes: 13 additions & 9 deletions src/bench.rs
Original file line number Diff line number Diff line change
Expand Up @@ -73,9 +73,7 @@ fn stream_generate(ep: &Endpoint, model: &str, prompt: &str, num_predict: u32) -
EndpointKind::Ollama | EndpointKind::FastFlowLM => {
endpoints::parse_ollama_stream_line(&line)
}
EndpointKind::LlamaCpp => {
endpoints::parse_llamacpp_stream_line(&line)
}
EndpointKind::LlamaCpp => endpoints::parse_llamacpp_stream_line(&line),
};

if let Some((token, done)) = parsed {
Expand All @@ -93,7 +91,11 @@ fn stream_generate(ep: &Endpoint, model: &str, prompt: &str, num_predict: u32) -

let total_time = start.elapsed().as_secs_f64();
let ttft = first_token_time.unwrap_or(total_time * 1000.0);
let tps = if total_time > 0.0 { token_count as f64 / total_time } else { 0.0 };
let tps = if total_time > 0.0 {
token_count as f64 / total_time
} else {
0.0
};

(ttft, tps, token_count)
}
Expand All @@ -108,7 +110,8 @@ pub fn run_ttft(endpoints: &[Endpoint], config: &BenchConfig) -> Vec<BenchResult
let mem_before = MemorySnapshot::capture();
let mut values = Vec::new();
for i in 0..config.repeat {
let (ttft, _, count) = stream_generate(ep, &config.model, &config.prompt, config.tokens);
let (ttft, _, count) =
stream_generate(ep, &config.model, &config.prompt, config.tokens);
if count == 0 {
eprintln!(" Run {}: FAILED (no tokens)", i + 1);
continue;
Expand Down Expand Up @@ -179,7 +182,10 @@ pub fn run_context(endpoints: &[Endpoint], config: &BenchConfig) -> Vec<ContextR
let long_prompt = config.prompt.repeat((ctx_len / 10).max(1) as usize);
let target_len = ctx_len as usize * 4;
let prompt = if long_prompt.len() > target_len {
match long_prompt.char_indices().nth(target_len.min(long_prompt.chars().count())) {
match long_prompt
.char_indices()
.nth(target_len.min(long_prompt.chars().count()))
{
Some((byte_idx, _)) => long_prompt[..byte_idx].to_string(),
None => long_prompt,
}
Expand Down Expand Up @@ -231,9 +237,7 @@ pub fn run_compare(endpoints: &[Endpoint], config: &BenchConfig) -> Vec<CompareR
EndpointKind::Ollama | EndpointKind::FastFlowLM => {
endpoints::parse_ollama_timings(&body_str)
}
EndpointKind::LlamaCpp => {
endpoints::parse_llamacpp_timings(&body_str)
}
EndpointKind::LlamaCpp => endpoints::parse_llamacpp_timings(&body_str),
};
match (timings.prompt_eval_count, timings.prompt_eval_duration_ns) {
(Some(count), Some(dur)) if dur > 0 => {
Expand Down
5 changes: 4 additions & 1 deletion src/detect.rs
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,10 @@ fn probe(ep: &Endpoint) -> bool {
}
};

match ureq::get(&url).timeout(std::time::Duration::from_secs(2)).call() {
match ureq::get(&url)
.timeout(std::time::Duration::from_secs(2))
.call()
{
Ok(resp) => resp.status() == 200,
Err(_) => false,
}
Expand Down
158 changes: 155 additions & 3 deletions src/endpoints.rs
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,13 @@ impl Endpoint {
}

/// Build the request body for a generation request.
pub fn build_body(&self, model: &str, prompt: &str, num_predict: u32, stream: bool) -> serde_json::Value {
pub fn build_body(
&self,
model: &str,
prompt: &str,
num_predict: u32,
stream: bool,
) -> serde_json::Value {
match self.kind {
EndpointKind::Ollama | EndpointKind::FastFlowLM => {
serde_json::json!({
Expand Down Expand Up @@ -56,7 +62,11 @@ impl Endpoint {
/// Parse a single line of an Ollama-style streaming response.
///
/// Returns `None` when `line` is not valid JSON. Otherwise returns a
/// `(token, done)` pair where:
///
/// * `token` is the string value of the `response` field, or an empty string
///   when the field is missing or is not a string;
/// * `done` is the boolean value of the `done` field, or `false` when the
///   field is missing or is not a boolean.
pub fn parse_ollama_stream_line(line: &str) -> Option<(String, bool)> {
    let v: serde_json::Value = serde_json::from_str(line).ok()?;
    let done = v.get("done").and_then(|d| d.as_bool()).unwrap_or(false);
    // Extract the token text exactly once; a missing field yields "".
    let token = v
        .get("response")
        .and_then(|r| r.as_str())
        .unwrap_or("")
        .to_string();
    Some((token, done))
}

Expand All @@ -69,7 +79,11 @@ pub fn parse_llamacpp_stream_line(line: &str) -> Option<(String, bool)> {
}
let v: serde_json::Value = serde_json::from_str(data).ok()?;
let stop = v.get("stop").and_then(|s| s.as_bool()).unwrap_or(false);
let token = v.get("content").and_then(|c| c.as_str()).unwrap_or("").to_string();
let token = v
.get("content")
.and_then(|c| c.as_str())
.unwrap_or("")
.to_string();
Some((token, stop))
}

Expand Down Expand Up @@ -122,3 +136,141 @@ pub fn parse_llamacpp_timings(body: &str) -> ApiTimings {
prompt_eval_count: prompt_n,
}
}

#[cfg(test)]
mod tests {
    use super::*;

    /// Construct a throwaway endpoint of the requested kind for the tests below.
    fn make_endpoint(kind: EndpointKind) -> Endpoint {
        let name = String::from("test");
        let base_url = String::from("http://localhost:9999");
        Endpoint {
            name,
            base_url,
            kind,
        }
    }

    #[test]
    fn generate_url_per_kind() {
        // Ollama and FastFlowLM are expected to share one generate route.
        let ollama = make_endpoint(EndpointKind::Ollama);
        assert_eq!(ollama.generate_url(), "http://localhost:9999/api/generate");

        let fastflow = make_endpoint(EndpointKind::FastFlowLM);
        assert_eq!(fastflow.generate_url(), "http://localhost:9999/api/generate");

        // llama.cpp is expected to use its own completion route.
        let llamacpp = make_endpoint(EndpointKind::LlamaCpp);
        assert_eq!(llamacpp.generate_url(), "http://localhost:9999/completion");
    }

    #[test]
    fn build_body_ollama_shape() {
        let endpoint = make_endpoint(EndpointKind::Ollama);
        let request = endpoint.build_body("qwen3:4b", "hi", 42, true);

        assert_eq!(request["model"], "qwen3:4b");
        assert_eq!(request["prompt"], "hi");
        assert_eq!(request["stream"], true);
        // The token budget is nested under `options` for this body shape.
        assert_eq!(request["options"]["num_predict"], 42);
    }

    #[test]
    fn build_body_llamacpp_shape() {
        let endpoint = make_endpoint(EndpointKind::LlamaCpp);
        let request = endpoint.build_body("ignored", "hi", 42, false);

        assert_eq!(request["prompt"], "hi");
        assert_eq!(request["n_predict"], 42);
        assert_eq!(request["stream"], false);
        // The model name must not appear anywhere in the llama.cpp body.
        assert!(request.get("model").is_none());
    }

    #[test]
    fn parse_ollama_stream_token() {
        let parsed = parse_ollama_stream_line(r#"{"response":"Hello","done":false}"#);
        assert_eq!(parsed, Some((String::from("Hello"), false)));
    }

    #[test]
    fn parse_ollama_stream_done() {
        let parsed = parse_ollama_stream_line(r#"{"response":"","done":true}"#);
        assert_eq!(parsed, Some((String::new(), true)));
    }

    #[test]
    fn parse_ollama_stream_invalid_json() {
        let parsed = parse_ollama_stream_line("not json");
        assert_eq!(parsed, None);
    }

    #[test]
    fn parse_llamacpp_stream_token() {
        let parsed = parse_llamacpp_stream_line(r#"data: {"content":"Hi","stop":false}"#);
        assert_eq!(parsed, Some((String::from("Hi"), false)));
    }

    #[test]
    fn parse_llamacpp_stream_done_marker() {
        let parsed = parse_llamacpp_stream_line("data: [DONE]");
        assert_eq!(parsed, Some((String::new(), true)));
    }

    #[test]
    fn parse_llamacpp_stream_stop_flag() {
        let parsed = parse_llamacpp_stream_line(r#"data: {"content":"","stop":true}"#);
        assert_eq!(parsed, Some((String::new(), true)));
    }

    #[test]
    fn parse_llamacpp_stream_without_prefix_is_none() {
        // A line that is not an SSE `data:` frame should produce no result.
        let parsed = parse_llamacpp_stream_line("event: ping");
        assert_eq!(parsed, None);
    }

    #[test]
    fn parse_ollama_timings_extracts_fields() {
        let payload = r#"{"total_duration":1000,"prompt_eval_duration":400,
"eval_duration":600,"eval_count":12,"prompt_eval_count":34}"#;
        let timings = parse_ollama_timings(payload);

        assert_eq!(timings.total_duration_ns, Some(1000));
        assert_eq!(timings.prompt_eval_duration_ns, Some(400));
        assert_eq!(timings.eval_duration_ns, Some(600));
        assert_eq!(timings.eval_count, Some(12));
        assert_eq!(timings.prompt_eval_count, Some(34));
    }

    #[test]
    fn parse_ollama_timings_invalid_is_default() {
        let timings = parse_ollama_timings("garbage");

        assert_eq!(timings.total_duration_ns, None);
        assert_eq!(timings.prompt_eval_count, None);
    }

    #[test]
    fn parse_llamacpp_timings_converts_ms_to_ns() {
        let payload = r#"{"timings":{"prompt_ms":10.0,"predicted_ms":20.0,
"predicted_n":5,"prompt_n":7}}"#;
        let timings = parse_llamacpp_timings(payload);

        // Millisecond inputs are expected back as nanoseconds; the expected
        // total equals the prompt and predicted durations added together.
        assert_eq!(timings.prompt_eval_duration_ns, Some(10_000_000));
        assert_eq!(timings.eval_duration_ns, Some(20_000_000));
        assert_eq!(timings.total_duration_ns, Some(30_000_000));
        assert_eq!(timings.eval_count, Some(5));
        assert_eq!(timings.prompt_eval_count, Some(7));
    }

    #[test]
    fn parse_llamacpp_timings_missing_block_is_default() {
        // No `timings` object in the body.
        let timings = parse_llamacpp_timings(r#"{"content":"hi"}"#);

        assert_eq!(timings.prompt_eval_duration_ns, None);
        assert_eq!(timings.total_duration_ns, None);
    }
}
11 changes: 9 additions & 2 deletions src/main.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,11 @@ mod output;
use clap::{Parser, Subcommand};

#[derive(Parser)]
#[command(name = "inference-bench", version, about = "LLM inference benchmark CLI")]
#[command(
name = "inference-bench",
version,
about = "LLM inference benchmark CLI"
)]
struct Cli {
#[command(subcommand)]
command: Commands,
Expand Down Expand Up @@ -64,7 +68,10 @@ fn main() {

let config = bench::BenchConfig {
model: cli.model.clone(),
prompt: cli.prompt.clone().unwrap_or_else(|| "Explain the theory of relativity in simple terms.".to_string()),
prompt: cli
.prompt
.clone()
.unwrap_or_else(|| "Explain the theory of relativity in simple terms.".to_string()),
tokens: cli.tokens,
repeat: cli.repeat,
warmup: cli.warmup,
Expand Down
46 changes: 37 additions & 9 deletions src/metrics.rs
Original file line number Diff line number Diff line change
Expand Up @@ -54,15 +54,12 @@ fn read_gpu_mem_used(mem_type: &str) -> u64 {
if !name.starts_with("card") || name.contains('-') {
continue;
}
let total_path = format!(
"{}/{}/device/mem_info_{}_total",
drm_dir, name, mem_type
);
let used_path = format!(
"{}/{}/device/mem_info_{}_used",
drm_dir, name, mem_type
);
if let (Ok(total_s), Ok(used_s)) = (fs::read_to_string(&total_path), fs::read_to_string(&used_path)) {
let total_path = format!("{}/{}/device/mem_info_{}_total", drm_dir, name, mem_type);
let used_path = format!("{}/{}/device/mem_info_{}_used", drm_dir, name, mem_type);
if let (Ok(total_s), Ok(used_s)) = (
fs::read_to_string(&total_path),
fs::read_to_string(&used_path),
) {
let total: u64 = total_s.trim().parse().unwrap_or(0);
let used: u64 = used_s.trim().parse().unwrap_or(0);
if total > 0 {
Expand All @@ -86,3 +83,34 @@ pub fn mean_std(values: &[f64]) -> (f64, f64) {
let variance = values.iter().map(|v| (v - mean).powi(2)).sum::<f64>() / (n - 1.0);
(mean, variance.sqrt())
}

#[cfg(test)]
mod tests {
    use super::*;

    /// Absolute tolerance used for the approximate float comparisons below.
    const TOLERANCE: f64 = 1e-9;

    #[test]
    fn mean_std_empty() {
        // With no samples both statistics are expected to be exactly zero.
        let (mean, std) = mean_std(&[]);
        assert_eq!(mean, 0.0);
        assert_eq!(std, 0.0);
    }

    #[test]
    fn mean_std_single() {
        // One sample: the mean is that sample and the spread is exactly zero.
        let (mean, std) = mean_std(&[5.0]);
        assert_eq!(mean, 5.0);
        assert_eq!(std, 0.0);
    }

    #[test]
    fn mean_std_known_values() {
        // Sample std (n-1 denominator) of this data set is sqrt(32/7) = 2.13809...
        let samples: [f64; 8] = [2.0, 4.0, 4.0, 4.0, 5.0, 5.0, 7.0, 9.0];
        let (mean, std) = mean_std(&samples);

        let mean_error = (mean - 5.0).abs();
        let std_error = (std - 2.138_089_935_299_395).abs();
        assert!(mean_error < TOLERANCE);
        assert!(std_error < TOLERANCE);
    }

    #[test]
    fn mean_std_identical_values_zero_std() {
        // Identical samples have no spread.
        let samples: [f64; 3] = [3.0, 3.0, 3.0];
        let (mean, std) = mean_std(&samples);

        let mean_error = (mean - 3.0).abs();
        assert!(mean_error < TOLERANCE);
        assert!(std.abs() < TOLERANCE);
    }
}
Loading
Loading