diff --git a/README.md b/README.md index 95d621b..65ff2b1 100644 --- a/README.md +++ b/README.md @@ -4,8 +4,10 @@ Kernel-level per-tenant admission controller built with Rust and eBPF. Enforces ## Use cases -1. **CPU/network tenant admission**: enforce per-cgroup and optional L4/L7 packet admission policies at the kernel boundary, with durable userspace policy management and Prometheus counters. -2. **LLM inference admission**: run the `examples/inference-admission` controller next to an inference stack to translate token-budget usage, KV-cache pressure, and GPU utilization into Vantage base policies and runtime overrides for inference endpoints. +1. **vLLM inference admission**: run the `examples/inference-admission` controller next to vLLM, scrape `/metrics`, and translate KV-cache pressure, queued/running request counts, and token-budget proxy metrics into Vantage base policies and runtime overrides for inference endpoints. File-backed metrics remain available as a portable fallback for demos and tests. +2. **CPU/network tenant admission**: enforce per-cgroup and optional L4/L7 packet admission policies at the kernel boundary, with durable userspace policy management and Prometheus counters. + +Vantage enforces network admission at the kernel boundary. Inference controllers are userspace adapters that translate model-serving pressure, such as vLLM metrics, file fixtures, and future NVML sources, into Vantage admission policy. Semantic scheduling inside vLLM, CUDA, or the inference runtime itself is out of scope. ## How it works diff --git a/examples/inference-admission/README.md b/examples/inference-admission/README.md index de747c6..f360e1b 100644 --- a/examples/inference-admission/README.md +++ b/examples/inference-admission/README.md @@ -1,34 +1,71 @@ # Inference Admission Example -`vantage-inference-admission` is a userspace controller for LLM inference workloads. 
It reads token-budget, KV-cache, and GPU utilization samples, then writes Vantage policies for one cgroup and one inference HTTP endpoint. +`vantage-inference-admission` is a userspace controller for LLM inference workloads. In vLLM mode it scrapes `/metrics`, converts KV-cache pressure, waiting/running request counts, and token-budget proxy metrics into a pressure sample, then writes Vantage policies for one cgroup and one inference HTTP endpoint. -This example keeps inference semantics outside the eBPF ABI. Vantage still enforces packet admission through its existing `tc` classifier; the example maps inference pressure into base policies and manual runtime overrides. +Vantage enforces network admission at the kernel boundary. Inference controllers are userspace adapters that translate model-serving pressure, such as vLLM metrics, file fixtures, and future NVML sources, into Vantage admission policy. Semantic scheduling inside vLLM, CUDA, or the inference runtime itself is out of scope. -## Run +## Demo -Start `vantage` first, then run the example: +Run the single-command demo from the repository root: + +```shell +examples/inference-admission/demo.sh +``` + +The demo starts a mock vLLM `/metrics` server and a mock Vantage-compatible API server, then runs the controller and prints visible `Normal -> Throttled -> Exhausted` transitions. It does not require a real GPU, real vLLM process, root, or eBPF attachment. + +## vLLM Mode + +Start `vantage`, run vLLM, then run the controller: ```shell cargo run -p inference-admission -- \ --tenant cg:12345 \ --inference-port 8000 \ --inference-http-path /v1/chat/completions \ - --metrics-file-path /tmp/vantage-inference-metrics.json \ - --gpu-util-file-path /tmp/vantage-gpu-util.json + --metrics-source vllm \ + --vllm-metrics-base-url http://127.0.0.1:8000 \ + --vllm-metrics-path /metrics ``` -The controller writes: +The vLLM adapter parses: -- `PUT /policy/cg:{id}` for the normal base policy. 
-- `PUT /runtime-policy/cg:{id}` when GPU, KV-cache, or token budget pressure is high. -- `DELETE /runtime-policy/cg:{id}` when all pressure signals recover below their low watermarks. +- `vllm:gpu_cache_usage_perc` +- `vllm:num_requests_waiting` +- `vllm:num_requests_running` +- `vllm:prompt_tokens_total` plus `vllm:generation_tokens_total` as token-budget proxy metrics -Runtime overrides are written through the public API as manual overrides. Do not use the same tenant/flow selector for another manual override while this example is running. +Token counters are converted into scrape-to-scrape deltas for the controller's current budget window. Exhaustion is only enforced when `--disabled-on-exhaustion` is set. + +## File-Backed Fallback -## Input files +File mode is the default and remains useful for portable tests and demos: + +```shell +cargo run -p inference-admission -- \ + --tenant cg:12345 \ + --inference-port 8000 \ + --inference-http-path /v1/chat/completions \ + --metrics-source file \ + --metrics-file-path /tmp/vantage-inference-metrics.json \ + --gpu-util-file-path /tmp/vantage-gpu-util.json +``` Inference pressure: +```json +{ + "ts_unix_ms": 1710000000000, + "tokens_used_current_minute": 54000, + "token_budget_per_minute": 60000, + "kv_cache_percent": 87.5, + "active_requests": 12, + "queued_requests": 3 +} +``` + +The older byte-based KV-cache fields are still accepted: + ```json { "ts_unix_ms": 1710000000000, @@ -41,7 +78,7 @@ Inference pressure: } ``` -GPU utilization: +GPU utilization fallback: ```json { @@ -50,16 +87,27 @@ GPU utilization: } ``` -Missing input files are treated as empty/no-signal samples. Invalid JSON is treated as a tick failure; the controller retains its previously applied state and retries on the next tick. +Missing input files are treated as empty/no-signal samples. Invalid JSON or invalid vLLM metrics are treated as tick failures; the controller retains its previously applied state and retries on the next tick. 
+ +## Vantage Writes + +The controller writes: + +- `PUT /policy/cg:{id}` for the normal base policy. +- `PUT /runtime-policy/cg:{id}` when GPU, KV-cache, or token budget pressure is high. +- `DELETE /runtime-policy/cg:{id}` when all pressure signals recover below their low watermarks. + +Runtime overrides are written through the public API as manual overrides. Do not use the same tenant/flow selector for another manual override while this example is running. ## Scope -In scope for this example: +In scope: - Single tenant cgroup. - Single TCP inference endpoint. - `POST` HTTP path selectors. -- File-backed metrics inputs. +- vLLM Prometheus metrics input. +- File-backed metrics fallback. - Hysteresis-based normal, throttled, and exhausted modes. Out of scope: diff --git a/examples/inference-admission/demo.sh b/examples/inference-admission/demo.sh new file mode 100755 index 0000000..070bf9a --- /dev/null +++ b/examples/inference-admission/demo.sh @@ -0,0 +1,206 @@ +#!/usr/bin/env bash +set -euo pipefail + +if ! command -v python3 >/dev/null 2>&1; then + echo "python3 is required for the inference admission demo" >&2 + exit 1 +fi + +ROOT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")/../.." 
&& pwd)" +TMP_DIR="$(mktemp -d)" +VLLM_PORT="18080" +VANTAGE_PORT="13000" +CONTROLLER_PID="" +VLLM_PID="" +VANTAGE_PID="" + +cleanup() { + for pid in "$CONTROLLER_PID" "$VLLM_PID" "$VANTAGE_PID"; do + if [[ -n "$pid" ]] && kill -0 "$pid" >/dev/null 2>&1; then + kill "$pid" >/dev/null 2>&1 || true + wait "$pid" >/dev/null 2>&1 || true + fi + done + rm -rf "$TMP_DIR" +} +trap cleanup EXIT INT TERM + +cat >"$TMP_DIR/mock_vllm.py" <<'PY' +import socketserver +import sys +import time + +START = time.monotonic() + +def metrics(): + elapsed = time.monotonic() - START + if elapsed < 20: + phase, gpu, waiting, running, prompt, generation = "Normal", "0.20", 0, 1, 12, 8 + elif elapsed < 40: + phase, gpu, waiting, running, prompt, generation = "Throttled", "0.95", 8, 16, 50, 30 + else: + phase, gpu, waiting, running, prompt, generation = "Exhausted", "0.98", 12, 20, 170, 150 + body = f"""# HELP vllm:gpu_cache_usage_perc GPU KV-cache utilization. +# TYPE vllm:gpu_cache_usage_perc gauge +vllm:gpu_cache_usage_perc {gpu} +# HELP vllm:num_requests_waiting Number of waiting requests. +# TYPE vllm:num_requests_waiting gauge +vllm:num_requests_waiting {waiting} +# HELP vllm:num_requests_running Number of running requests. +# TYPE vllm:num_requests_running gauge +vllm:num_requests_running {running} +# HELP vllm:prompt_tokens_total Prompt tokens in the current synthetic demo window. +# TYPE vllm:prompt_tokens_total counter +vllm:prompt_tokens_total {prompt} +# HELP vllm:generation_tokens_total Generation tokens in the current synthetic demo window. 
+# TYPE vllm:generation_tokens_total counter +vllm:generation_tokens_total {generation} +""" + return phase, body.encode() + +def read_request(sock): + sock.settimeout(0.25) + data = b"" + while b"\r\n\r\n" not in data and len(data) < 8192: + try: + chunk = sock.recv(1024) + except TimeoutError: + break + if not chunk: + break + data += chunk + return data + +def respond(sock, status, body=b"", content_type="text/plain"): + reason = "OK" if status == 200 else "Not Found" + headers = ( + f"HTTP/1.1 {status} {reason}\r\n" + f"Content-Length: {len(body)}\r\n" + f"Content-Type: {content_type}\r\n" + "Connection: close\r\n\r\n" + ).encode() + sock.sendall(headers + body) + +class Handler(socketserver.BaseRequestHandler): + def handle(self): + data = read_request(self.request) + first = data.split(b"\r\n", 1)[0].decode(errors="replace") + path = first.split(" ")[1] if " " in first else "" + if path != "/metrics": + respond(self.request, 404) + return + phase, body = metrics() + print(f"mock-vllm phase={phase}", flush=True) + respond(self.request, 200, body, "text/plain; version=0.0.4") + +class Server(socketserver.ThreadingTCPServer): + allow_reuse_address = True + +Server(("127.0.0.1", int(sys.argv[1])), Handler).serve_forever() +PY + +cat >"$TMP_DIR/mock_vantage.py" <<'PY' +import json +import socketserver +import sys +from urllib.parse import urlparse + +LAST_MODE = None + +def announce(mode): + global LAST_MODE + if mode != LAST_MODE: + print(f"admission transition: {mode}", flush=True) + LAST_MODE = mode + +def read_request(sock): + sock.settimeout(0.25) + data = b"" + while b"\r\n\r\n" not in data and len(data) < 65536: + try: + chunk = sock.recv(4096) + except TimeoutError: + break + if not chunk: + break + data += chunk + headers, _, body = data.partition(b"\r\n\r\n") + length = 0 + for line in headers.decode(errors="replace").splitlines(): + if line.lower().startswith("content-length:"): + length = int(line.split(":", 1)[1].strip()) + while len(body) < length: + 
try: + chunk = sock.recv(length - len(body)) + except TimeoutError: + break + if not chunk: + break + body += chunk + return headers, body + +def respond(sock, status=204): + reason = "No Content" if status == 204 else "Not Found" + sock.sendall( + f"HTTP/1.1 {status} {reason}\r\nContent-Length: 0\r\nConnection: close\r\n\r\n".encode() + ) + +class Handler(socketserver.BaseRequestHandler): + def handle(self): + headers, body = read_request(self.request) + first = headers.split(b"\r\n", 1)[0].decode(errors="replace") + parts = first.split(" ") + method = parts[0] if len(parts) > 0 else "" + raw_path = parts[1] if len(parts) > 1 else "" + path = urlparse(raw_path).path + if method == "PUT" and path.startswith("/policy/cg:"): + respond(self.request, 204) + return + if method == "PUT" and path.startswith("/runtime-policy/cg:"): + payload = json.loads(body.decode() or "{}") + announce("Exhausted" if payload.get("enabled") is False else "Throttled") + respond(self.request, 204) + return + if method == "DELETE" and path.startswith("/runtime-policy/cg:"): + announce("Normal") + respond(self.request, 204) + return + respond(self.request, 404) + +class Server(socketserver.ThreadingTCPServer): + allow_reuse_address = True + +Server(("127.0.0.1", int(sys.argv[1])), Handler).serve_forever() +PY + +python3 "$TMP_DIR/mock_vllm.py" "$VLLM_PORT" & +VLLM_PID="$!" +python3 "$TMP_DIR/mock_vantage.py" "$VANTAGE_PORT" & +VANTAGE_PID="$!" + +sleep 1 + +echo "Starting inference admission controller demo." 
+echo "Expected visible transitions: Normal -> Throttled -> Exhausted" + +cargo build -p inference-admission + +"${ROOT_DIR}/target/debug/vantage-inference-admission" \ + --vantage-base-url "http://127.0.0.1:${VANTAGE_PORT}" \ + --tenant cg:42 \ + --inference-port 8000 \ + --inference-http-path /v1/chat/completions \ + --metrics-source vllm \ + --vllm-metrics-base-url "http://127.0.0.1:${VLLM_PORT}" \ + --vllm-metrics-path /metrics \ + --tick-ms 1000 \ + --token-budget-per-minute 100 \ + --disabled-on-exhaustion & +CONTROLLER_PID="$!" + +sleep 62 +kill "$CONTROLLER_PID" >/dev/null 2>&1 || true +wait "$CONTROLLER_PID" >/dev/null 2>&1 || true +CONTROLLER_PID="" + +echo "Demo complete." diff --git a/examples/inference-admission/fixtures/vllm_metrics_exhausted.prom b/examples/inference-admission/fixtures/vllm_metrics_exhausted.prom new file mode 100644 index 0000000..5849dfc --- /dev/null +++ b/examples/inference-admission/fixtures/vllm_metrics_exhausted.prom @@ -0,0 +1,15 @@ +# HELP vllm:gpu_cache_usage_perc GPU KV-cache utilization. +# TYPE vllm:gpu_cache_usage_perc gauge +vllm:gpu_cache_usage_perc 0.98 +# HELP vllm:num_requests_waiting Number of waiting requests. +# TYPE vllm:num_requests_waiting gauge +vllm:num_requests_waiting 12 +# HELP vllm:num_requests_running Number of running requests. +# TYPE vllm:num_requests_running gauge +vllm:num_requests_running 20 +# HELP vllm:prompt_tokens_total Prompt tokens in the current synthetic demo window. +# TYPE vllm:prompt_tokens_total counter +vllm:prompt_tokens_total 70 +# HELP vllm:generation_tokens_total Generation tokens in the current synthetic demo window. 
+# TYPE vllm:generation_tokens_total counter +vllm:generation_tokens_total 50 diff --git a/examples/inference-admission/fixtures/vllm_metrics_normal.prom b/examples/inference-admission/fixtures/vllm_metrics_normal.prom new file mode 100644 index 0000000..2c42834 --- /dev/null +++ b/examples/inference-admission/fixtures/vllm_metrics_normal.prom @@ -0,0 +1,15 @@ +# HELP vllm:gpu_cache_usage_perc GPU KV-cache utilization. +# TYPE vllm:gpu_cache_usage_perc gauge +vllm:gpu_cache_usage_perc 0.20 +# HELP vllm:num_requests_waiting Number of waiting requests. +# TYPE vllm:num_requests_waiting gauge +vllm:num_requests_waiting 0 +# HELP vllm:num_requests_running Number of running requests. +# TYPE vllm:num_requests_running gauge +vllm:num_requests_running 1 +# HELP vllm:prompt_tokens_total Prompt tokens in the current synthetic demo window. +# TYPE vllm:prompt_tokens_total counter +vllm:prompt_tokens_total 12 +# HELP vllm:generation_tokens_total Generation tokens in the current synthetic demo window. +# TYPE vllm:generation_tokens_total counter +vllm:generation_tokens_total 8 diff --git a/examples/inference-admission/fixtures/vllm_metrics_throttled.prom b/examples/inference-admission/fixtures/vllm_metrics_throttled.prom new file mode 100644 index 0000000..21d0d04 --- /dev/null +++ b/examples/inference-admission/fixtures/vllm_metrics_throttled.prom @@ -0,0 +1,15 @@ +# HELP vllm:gpu_cache_usage_perc GPU KV-cache utilization. +# TYPE vllm:gpu_cache_usage_perc gauge +vllm:gpu_cache_usage_perc 0.95 +# HELP vllm:num_requests_waiting Number of waiting requests. +# TYPE vllm:num_requests_waiting gauge +vllm:num_requests_waiting 8 +# HELP vllm:num_requests_running Number of running requests. +# TYPE vllm:num_requests_running gauge +vllm:num_requests_running 16 +# HELP vllm:prompt_tokens_total Prompt tokens in the current synthetic demo window. 
+# TYPE vllm:prompt_tokens_total counter +vllm:prompt_tokens_total 50 +# HELP vllm:generation_tokens_total Generation tokens in the current synthetic demo window. +# TYPE vllm:generation_tokens_total counter +vllm:generation_tokens_total 30 diff --git a/examples/inference-admission/src/config.rs b/examples/inference-admission/src/config.rs index 42de4e6..323bffc 100644 --- a/examples/inference-admission/src/config.rs +++ b/examples/inference-admission/src/config.rs @@ -1,6 +1,6 @@ use std::path::PathBuf; -use clap::Parser; +use clap::{Parser, ValueEnum}; use thiserror::Error; #[derive(Debug, Clone, Parser)] @@ -80,10 +80,29 @@ pub(crate) struct Cli { throttle_burst_tokens: u64, #[arg(long, env = "VANTAGE_INFERENCE_DISABLED_ON_EXHAUSTION")] disabled_on_exhaustion: bool, + #[arg( + long, + value_enum, + default_value_t = InferenceMetricsSourceMode::File, + env = "VANTAGE_INFERENCE_METRICS_SOURCE" + )] + metrics_source: InferenceMetricsSourceMode, #[arg(long, env = "VANTAGE_INFERENCE_METRICS_FILE_PATH")] metrics_file_path: Option, #[arg(long, env = "VANTAGE_INFERENCE_GPU_UTIL_FILE_PATH")] gpu_util_file_path: Option, + #[arg( + long, + default_value = "http://127.0.0.1:8000", + env = "VANTAGE_INFERENCE_VLLM_METRICS_BASE_URL" + )] + vllm_metrics_base_url: String, + #[arg( + long, + default_value = "/metrics", + env = "VANTAGE_INFERENCE_VLLM_METRICS_PATH" + )] + vllm_metrics_path: String, } #[derive(Debug, Clone, PartialEq, Eq)] @@ -91,6 +110,12 @@ pub(crate) struct TenantSelector { pub(crate) cgroup_id: u64, } +#[derive(Debug, Clone, Copy, PartialEq, Eq, ValueEnum)] +pub(crate) enum InferenceMetricsSourceMode { + File, + Vllm, +} + #[derive(Debug, Clone)] pub(crate) struct Config { pub(crate) vantage_base_url: String, @@ -108,8 +133,11 @@ pub(crate) struct Config { pub(crate) throttle_rate_tokens_per_sec: u64, pub(crate) throttle_burst_tokens: u64, pub(crate) disabled_on_exhaustion: bool, + pub(crate) metrics_source: InferenceMetricsSourceMode, pub(crate) 
metrics_file_path: Option, pub(crate) gpu_util_file_path: Option, + pub(crate) vllm_metrics_base_url: String, + pub(crate) vllm_metrics_path: String, } #[derive(Debug, Error)] @@ -151,8 +179,11 @@ impl Config { throttle_rate_tokens_per_sec: cli.throttle_rate_tokens_per_sec.max(1), throttle_burst_tokens: cli.throttle_burst_tokens.max(1), disabled_on_exhaustion: cli.disabled_on_exhaustion, + metrics_source: cli.metrics_source, metrics_file_path: cli.metrics_file_path, gpu_util_file_path: cli.gpu_util_file_path, + vllm_metrics_base_url: cli.vllm_metrics_base_url.trim_end_matches('/').to_owned(), + vllm_metrics_path: normalize_http_path(&cli.vllm_metrics_path), } } } @@ -230,4 +261,37 @@ mod tests { assert!((config.gpu_low_watermark_percent - 79.0).abs() < f64::EPSILON); assert_eq!(config.token_budget_per_minute, 1); } + + #[test] + fn defaults_to_file_metrics_source() { + let parsed = Config::try_from_iter(["vantage-inference-admission", "--tenant", "7"]); + let Ok(config) = parsed else { + panic!("config should parse"); + }; + assert_eq!( + config.metrics_source, + super::InferenceMetricsSourceMode::File + ); + } + + #[test] + fn parses_vllm_metrics_source_and_normalizes_path() { + let parsed = Config::try_from_iter([ + "vantage-inference-admission", + "--tenant", + "7", + "--metrics-source", + "vllm", + "--vllm-metrics-path", + "metrics", + ]); + let Ok(config) = parsed else { + panic!("config should parse"); + }; + assert_eq!( + config.metrics_source, + super::InferenceMetricsSourceMode::Vllm + ); + assert_eq!(config.vllm_metrics_path, "/metrics"); + } } diff --git a/examples/inference-admission/src/controller.rs b/examples/inference-admission/src/controller.rs index 9945a70..d9732c0 100644 --- a/examples/inference-admission/src/controller.rs +++ b/examples/inference-admission/src/controller.rs @@ -1,4 +1,4 @@ -use std::{fs, io::ErrorKind, path::PathBuf}; +use std::{fs, future::Future, io::ErrorKind, path::PathBuf}; use thiserror::Error; use tracing::info; @@ -9,6 
+9,7 @@ use crate::{ inference::{InferencePressure, InferencePressureSample}, state::LastAppliedState, vantage_client::{AdmissionClient, ClientError, PolicyShape, PolicyTarget}, + vllm::{VllmMetricsError, VllmMetricsSource}, }; const TOKEN_THROTTLE_HIGH_PERCENT: f64 = 90.0; @@ -74,10 +75,20 @@ pub(crate) enum InferenceSourceError { #[source] source: serde_json::Error, }, + #[error(transparent)] + Vllm(#[from] VllmMetricsError), } pub(crate) trait InferenceSource { - fn sample(&self) -> Result; + fn sample( + &self, + ) -> impl Future> + Send; +} + +#[derive(Debug, Clone)] +pub(crate) enum ConfiguredInferenceSource { + File(FileInferenceSource), + Vllm(VllmMetricsSource), } impl FileInferenceSource { @@ -90,7 +101,7 @@ impl FileInferenceSource { } impl InferenceSource for FileInferenceSource { - fn sample(&self) -> Result { + async fn sample(&self) -> Result { let Some(path) = &self.path else { return Ok(InferencePressureSample::empty( self.default_token_budget_per_minute, @@ -119,6 +130,15 @@ impl InferenceSource for FileInferenceSource { } } +impl InferenceSource for ConfiguredInferenceSource { + async fn sample(&self) -> Result { + match self { + Self::File(source) => source.sample().await, + Self::Vllm(source) => source.sample().await, + } + } +} + impl AdmissionController where Client: AdmissionClient, @@ -143,7 +163,7 @@ where pub(crate) async fn tick(&mut self) -> Result<(), ControllerError> { let gpu_sample = self.gpu_source.sample()?; - let inference_sample = self.inference_source.sample()?; + let inference_sample = self.inference_source.sample().await?; let desired = decide_admission( &self.config, self.mode, @@ -282,10 +302,16 @@ mod tests { use crate::{ config::Config, gpu::{GpuError, GpuUtilSample, GpuUtilSource}, + http_client::HttpClientError, inference::InferencePressureSample, vantage_client::{ClientError, PolicyShape, PolicyTarget}, + vllm::{metrics_to_pressure_sample, parse_vllm_metrics}, }; + const VLLM_NORMAL: &str = 
include_str!("../fixtures/vllm_metrics_normal.prom"); + const VLLM_THROTTLED: &str = include_str!("../fixtures/vllm_metrics_throttled.prom"); + const VLLM_EXHAUSTED: &str = include_str!("../fixtures/vllm_metrics_exhausted.prom"); + #[derive(Debug, Clone, PartialEq, Eq)] enum ClientCall { PutBase(PolicyTarget, PolicyShape), @@ -324,7 +350,9 @@ mod tests { .fail_after_calls .is_some_and(|limit| calls.len() >= limit) { - return Err(ClientError::InvalidResponse("fixture failure".to_owned())); + return Err(ClientError::Http(HttpClientError::InvalidResponse( + "fixture failure".to_owned(), + ))); } calls.push(call); } @@ -371,7 +399,7 @@ mod tests { struct FixedInference(InferencePressureSample); impl super::InferenceSource for FixedInference { - fn sample(&self) -> Result { + async fn sample(&self) -> Result { Ok(self.0) } } @@ -397,6 +425,7 @@ mod tests { token_budget_per_minute: 100, kv_cache_used_bytes: kv_used, kv_cache_capacity_bytes: Some(100), + kv_cache_percent: None, active_requests: None, queued_requests: None, } @@ -631,4 +660,66 @@ mod tests { assert_eq!(desired.mode, AdmissionMode::Normal); assert!(desired.runtime_policy.is_none()); } + + #[test] + fn vllm_normal_fixture_decides_normal() { + let config = config(true); + let parsed = parse_vllm_metrics(VLLM_NORMAL); + let Ok(parsed) = parsed else { + panic!("fixture should parse"); + }; + let inference = metrics_to_pressure_sample(parsed, 100, 1); + let desired = decide_admission( + &config, + AdmissionMode::Normal, + None, + inference, + inference.pressure(), + ); + + assert_eq!(desired.mode, AdmissionMode::Normal); + } + + #[test] + fn vllm_throttled_fixture_decides_throttled() { + let config = config(true); + let parsed = parse_vllm_metrics(VLLM_THROTTLED); + let Ok(parsed) = parsed else { + panic!("fixture should parse"); + }; + let inference = metrics_to_pressure_sample(parsed, 100, 1); + let desired = decide_admission( + &config, + AdmissionMode::Normal, + None, + inference, + inference.pressure(), + 
); + + assert_eq!( + desired.mode, + AdmissionMode::Throttled { + reason: ThrottleReason::KvCache + } + ); + } + + #[test] + fn vllm_exhausted_fixture_decides_exhausted() { + let config = config(true); + let parsed = parse_vllm_metrics(VLLM_EXHAUSTED); + let Ok(parsed) = parsed else { + panic!("fixture should parse"); + }; + let inference = metrics_to_pressure_sample(parsed, 100, 1); + let desired = decide_admission( + &config, + AdmissionMode::Normal, + None, + inference, + inference.pressure(), + ); + + assert_eq!(desired.mode, AdmissionMode::Exhausted); + } } diff --git a/examples/inference-admission/src/http_client.rs b/examples/inference-admission/src/http_client.rs new file mode 100644 index 0000000..9c42b4e --- /dev/null +++ b/examples/inference-admission/src/http_client.rs @@ -0,0 +1,311 @@ +use std::{fmt::Write as _, future::Future, str::FromStr, time::Duration}; + +use thiserror::Error; +use tokio::{ + io::{AsyncReadExt as _, AsyncWriteExt as _}, + net::TcpStream, + time::timeout, +}; + +const REQUEST_TIMEOUT: Duration = Duration::from_secs(5); +const MAX_RESPONSE_BYTES: usize = 1 << 20; + +#[derive(Debug, Clone)] +pub(crate) struct HttpClient { + endpoint: HttpEndpoint, +} + +#[derive(Debug, Clone, PartialEq, Eq)] +struct HttpEndpoint { + host: String, + port: u16, +} + +#[derive(Debug, Clone, PartialEq, Eq)] +pub(crate) struct HttpResponse { + pub(crate) status: u16, + pub(crate) body: String, +} + +impl HttpResponse { + pub(crate) const fn status_is_success(&self) -> bool { + self.status >= 200 && self.status < 300 + } +} + +#[derive(Debug, Clone, Copy)] +enum Method { + Get, + Put, + Delete, +} + +impl Method { + const fn as_str(self) -> &'static str { + match self { + Self::Get => "GET", + Self::Put => "PUT", + Self::Delete => "DELETE", + } + } +} + +#[derive(Debug, Error)] +pub(crate) enum HttpClientError { + #[error("{0}")] + InvalidBaseUrl(String), + #[error("HTTP transport failed: {0}")] + Transport(#[from] std::io::Error), + #[error("invalid 
HTTP response: {0}")] + InvalidResponse(String), + #[error("HTTP operation timed out after {seconds}s")] + Timeout { seconds: u64 }, + #[error("HTTP response exceeded {limit} bytes")] + ResponseTooLarge { limit: usize }, +} + +impl HttpClient { + pub(crate) fn new(base_url: &str) -> Result { + Ok(Self { + endpoint: HttpEndpoint::from_str(base_url)?, + }) + } + + pub(crate) async fn get(&self, path: &str) -> Result { + self.request(Method::Get, path, None).await + } + + pub(crate) async fn put_json( + &self, + path: &str, + body: &[u8], + ) -> Result { + self.request(Method::Put, path, Some(body)).await + } + + pub(crate) async fn delete(&self, path: &str) -> Result { + self.request(Method::Delete, path, None).await + } + + async fn request( + &self, + method: Method, + path: &str, + body: Option<&[u8]>, + ) -> Result { + let mut stream = with_timeout(TcpStream::connect(( + self.endpoint.host.as_str(), + self.endpoint.port, + ))) + .await??; + stream.set_nodelay(true)?; + let mut request = String::new(); + let _ = write!( + request, + "{} {} HTTP/1.0\r\nHost: {}\r\nConnection: close\r\n", + method.as_str(), + path, + self.endpoint.host + ); + if let Some(body) = body { + let _ = write!(request, "Content-Length: {}\r\n", body.len()); + request.push_str("Content-Type: application/json\r\n"); + } + request.push_str("\r\n"); + with_timeout(stream.write_all(request.as_bytes())).await??; + if let Some(body) = body { + with_timeout(stream.write_all(body)).await??; + } + with_timeout(stream.flush()).await??; + + let response_bytes = read_response_bounded(&mut stream).await?; + parse_response(&response_bytes) + } +} + +impl FromStr for HttpEndpoint { + type Err = HttpClientError; + + fn from_str(raw: &str) -> Result { + let Some(rest) = raw.strip_prefix("http://") else { + return Err(HttpClientError::InvalidBaseUrl( + "only http:// URLs are supported in this example".to_owned(), + )); + }; + if rest.contains(['/', '?', '#']) { + return Err(HttpClientError::InvalidBaseUrl( + 
"base URL must not include a path, query, or fragment".to_owned(), + )); + } + if rest.is_empty() { + return Err(HttpClientError::InvalidBaseUrl( + "base URL must include a host".to_owned(), + )); + } + let (host, port) = match rest.rsplit_once(':') { + Some((host, raw_port)) if !host.is_empty() => { + let port = raw_port.parse::().map_err(|error| { + HttpClientError::InvalidBaseUrl(format!("invalid port '{raw_port}': {error}")) + })?; + (host.to_owned(), port) + } + Some(_) => { + return Err(HttpClientError::InvalidBaseUrl( + "base URL host must not be empty".to_owned(), + )); + } + None => (rest.to_owned(), 80), + }; + Ok(Self { host, port }) + } +} + +async fn with_timeout(future: impl Future) -> Result { + timeout(REQUEST_TIMEOUT, future) + .await + .map_err(|_| HttpClientError::Timeout { + seconds: REQUEST_TIMEOUT.as_secs(), + }) +} + +async fn read_response_bounded(stream: &mut TcpStream) -> Result, HttpClientError> { + let mut response = Vec::new(); + let mut chunk = [0_u8; 8192]; + let mut expected_len: Option = None; + let mut body_start: Option = None; + let mut header_complete_without_body = false; + loop { + let read = with_timeout(stream.read(&mut chunk)).await??; + if read == 0 { + return Ok(response); + } + if response.len().saturating_add(read) > MAX_RESPONSE_BYTES { + return Err(HttpClientError::ResponseTooLarge { + limit: MAX_RESPONSE_BYTES, + }); + } + response.extend_from_slice(&chunk[..read]); + if body_start.is_none() + && let Some(header_end) = find_header_end(&response) + { + let headers = String::from_utf8_lossy(&response[..header_end]); + expected_len = parse_content_length(&headers)?; + let status = parse_status_code(&headers)?; + header_complete_without_body = expected_len.is_none() && status_has_no_body(status); + body_start = Some(header_end.saturating_add(4)); + } + if header_complete_without_body { + return Ok(response); + } + if let (Some(start), Some(expected)) = (body_start, expected_len) + && response.len().saturating_sub(start) 
>= expected + { + return Ok(response); + } + } +} + +fn parse_status_code(headers: &str) -> Result { + let Some(status_line) = headers.lines().next() else { + return Err(HttpClientError::InvalidResponse( + "missing HTTP status line".to_owned(), + )); + }; + let Some(raw_status) = status_line.split_whitespace().nth(1) else { + return Err(HttpClientError::InvalidResponse( + "missing HTTP status code".to_owned(), + )); + }; + raw_status.parse::().map_err(|error| { + HttpClientError::InvalidResponse(format!( + "invalid HTTP status code '{raw_status}': {error}" + )) + }) +} + +const fn status_has_no_body(status: u16) -> bool { + (status >= 100 && status < 200) || status == 204 || status == 304 +} + +fn find_header_end(response: &[u8]) -> Option { + response.windows(4).position(|window| window == b"\r\n\r\n") +} + +fn parse_content_length(headers: &str) -> Result, HttpClientError> { + for line in headers.lines() { + let Some((name, value)) = line.split_once(':') else { + continue; + }; + if !name.eq_ignore_ascii_case("content-length") { + continue; + } + let trimmed = value.trim(); + let parsed = trimmed.parse::().map_err(|error| { + HttpClientError::InvalidResponse(format!("invalid Content-Length '{trimmed}': {error}")) + })?; + return Ok(Some(parsed)); + } + Ok(None) +} + +fn parse_response(response: &[u8]) -> Result { + let text = String::from_utf8_lossy(response); + let Some((headers, body)) = text.split_once("\r\n\r\n") else { + return Err(HttpClientError::InvalidResponse( + "missing HTTP header terminator".to_owned(), + )); + }; + let Some(status_line) = headers.lines().next() else { + return Err(HttpClientError::InvalidResponse( + "missing HTTP status line".to_owned(), + )); + }; + let mut parts = status_line.split_whitespace(); + let Some(version) = parts.next() else { + return Err(HttpClientError::InvalidResponse( + "missing HTTP version".to_owned(), + )); + }; + if !version.starts_with("HTTP/") { + return Err(HttpClientError::InvalidResponse(format!( + "invalid 
HTTP version '{version}'" + ))); + } + let Some(raw_status) = parts.next() else { + return Err(HttpClientError::InvalidResponse( + "missing HTTP status code".to_owned(), + )); + }; + let status = raw_status.parse::().map_err(|error| { + HttpClientError::InvalidResponse(format!( + "invalid HTTP status code '{raw_status}': {error}" + )) + })?; + Ok(HttpResponse { + status, + body: body.to_owned(), + }) +} + +#[cfg(test)] +mod tests { + use std::str::FromStr as _; + + use super::HttpEndpoint; + + #[test] + fn parses_http_base_url() { + let endpoint = HttpEndpoint::from_str("http://127.0.0.1:3000"); + let Ok(endpoint) = endpoint else { + panic!("endpoint should parse"); + }; + assert_eq!(endpoint.host, "127.0.0.1"); + assert_eq!(endpoint.port, 3000); + } + + #[test] + fn rejects_base_url_path() { + let endpoint = HttpEndpoint::from_str("http://127.0.0.1:3000/api"); + assert!(endpoint.is_err(), "base URL path should be rejected"); + } +} diff --git a/examples/inference-admission/src/inference.rs b/examples/inference-admission/src/inference.rs index 73317ae..0a0d061 100644 --- a/examples/inference-admission/src/inference.rs +++ b/examples/inference-admission/src/inference.rs @@ -1,13 +1,19 @@ use serde::Deserialize; -#[derive(Debug, Clone, Copy, Deserialize, PartialEq, Eq)] +#[derive(Debug, Clone, Copy, Deserialize, PartialEq)] pub(crate) struct InferencePressureSample { pub(crate) ts_unix_ms: u64, pub(crate) tokens_used_current_minute: u64, pub(crate) token_budget_per_minute: u64, + #[serde(default)] pub(crate) kv_cache_used_bytes: Option, + #[serde(default)] pub(crate) kv_cache_capacity_bytes: Option, + #[serde(default)] + pub(crate) kv_cache_percent: Option, + #[serde(default)] pub(crate) active_requests: Option, + #[serde(default)] pub(crate) queued_requests: Option, } @@ -27,6 +33,7 @@ impl InferencePressureSample { token_budget_per_minute, kv_cache_used_bytes: None, kv_cache_capacity_bytes: None, + kv_cache_percent: None, active_requests: None, queued_requests: 
None,
         }
@@ -38,12 +45,14 @@ impl InferencePressureSample {
             self.tokens_used_current_minute.min(token_budget),
             token_budget,
         );
-        let kv_cache_percent = match (self.kv_cache_used_bytes, self.kv_cache_capacity_bytes) {
-            (Some(used), Some(capacity)) if capacity > 0 => {
-                Some(ratio_percent(used.min(capacity), capacity))
+        let kv_cache_percent = self.kv_cache_percent.map(clamp_percent).or_else(|| {
+            match (self.kv_cache_used_bytes, self.kv_cache_capacity_bytes) {
+                (Some(used), Some(capacity)) if capacity > 0 => {
+                    Some(ratio_percent(used.min(capacity), capacity))
+                }
+                _ => None,
             }
-            _ => None,
-        };
+        });

         InferencePressure {
             token_budget_percent,
@@ -54,6 +63,14 @@ impl InferencePressureSample {
     }
 }

+const fn clamp_percent(value: f64) -> f64 {
+    if value.is_finite() {
+        value.clamp(0.0, 100.0)
+    } else {
+        100.0
+    }
+}
+
 fn ratio_percent(numerator: u64, denominator: u64) -> f64 {
     if denominator == 0 {
         return 0.0;
@@ -78,6 +95,7 @@ mod tests {
             token_budget_per_minute: 100,
             kv_cache_used_bytes: Some(750),
             kv_cache_capacity_bytes: Some(1_000),
+            kv_cache_percent: None,
             active_requests: Some(3),
             queued_requests: Some(4),
         };
@@ -97,6 +115,7 @@ mod tests {
             token_budget_per_minute: 100,
             kv_cache_used_bytes: Some(750),
             kv_cache_capacity_bytes: Some(0),
+            kv_cache_percent: None,
             active_requests: None,
             queued_requests: None,
         };
@@ -112,10 +131,37 @@ mod tests {
             token_budget_per_minute: 100,
             kv_cache_used_bytes: None,
             kv_cache_capacity_bytes: None,
+            kv_cache_percent: None,
             active_requests: None,
             queued_requests: None,
         };

         assert!((sample.pressure().token_budget_percent - 100.0).abs() < f64::EPSILON);
     }
+
+    #[test]
+    fn explicit_kv_percent_takes_precedence() {
+        let sample = InferencePressureSample {
+            ts_unix_ms: 1,
+            tokens_used_current_minute: 0,
+            token_budget_per_minute: 100,
+            kv_cache_used_bytes: Some(10),
+            kv_cache_capacity_bytes: Some(100),
+            kv_cache_percent: Some(90.0),
+            active_requests: None,
+            queued_requests: None,
+        };
+
+
assert_eq!(sample.pressure().kv_cache_percent, Some(90.0));
+    }
+
+    #[test]
+    fn old_file_json_without_kv_percent_still_parses() {
+        let raw = r#"{"ts_unix_ms":1,"tokens_used_current_minute":2,"token_budget_per_minute":10}"#;
+        let parsed = serde_json::from_str::<InferencePressureSample>(raw);
+        let Ok(parsed) = parsed else {
+            panic!("old fixture should parse");
+        };
+        assert_eq!(parsed.kv_cache_percent, None);
+    }
 }
diff --git a/examples/inference-admission/src/main.rs b/examples/inference-admission/src/main.rs
index 4c086ab..dfb838c 100644
--- a/examples/inference-admission/src/main.rs
+++ b/examples/inference-admission/src/main.rs
@@ -3,9 +3,11 @@
 pub(crate) mod config;
 pub(crate) mod controller;
 pub(crate) mod gpu;
+pub(crate) mod http_client;
 pub(crate) mod inference;
 pub(crate) mod state;
 pub(crate) mod vantage_client;
+pub(crate) mod vllm;

 use std::time::Duration;

@@ -13,10 +15,11 @@ use tokio::time::{MissedTickBehavior, interval};
 use tracing::{info, warn};

 use crate::{
-    config::Config,
-    controller::{AdmissionController, FileInferenceSource},
+    config::{Config, InferenceMetricsSourceMode},
+    controller::{AdmissionController, ConfiguredInferenceSource, FileInferenceSource},
     gpu::FileGpuUtilSource,
     vantage_client::VantageClient,
+    vllm::VllmMetricsSource,
 };

 #[tokio::main]
@@ -30,10 +33,24 @@ async fn main() -> anyhow::Result<()> {
     let config = Config::from_args();
     let client = VantageClient::new(&config.vantage_base_url)?;
     let gpu_source = FileGpuUtilSource::new(config.gpu_util_file_path.clone());
-    let inference_source = FileInferenceSource::new(
-        config.metrics_file_path.clone(),
-        config.token_budget_per_minute,
-    );
+    let inference_source = match config.metrics_source {
+        InferenceMetricsSourceMode::File => {
+            ConfiguredInferenceSource::File(FileInferenceSource::new(
+                config.metrics_file_path.clone(),
+                config.token_budget_per_minute,
+            ))
+        }
+        InferenceMetricsSourceMode::Vllm => {
+            if config.metrics_file_path.is_some() {
+                warn!("metrics file path is ignored when
metrics source is vllm");
+            }
+            ConfiguredInferenceSource::Vllm(VllmMetricsSource::new(
+                &config.vllm_metrics_base_url,
+                config.vllm_metrics_path.clone(),
+                config.token_budget_per_minute,
+            )?)
+        }
+    };

     let mut controller =
         AdmissionController::new(config.clone(), client, gpu_source, inference_source);
@@ -42,6 +59,7 @@ async fn main() -> anyhow::Result<()> {
         tenant = config.tenant.cgroup_id,
         inference_port = config.inference_port,
         inference_http_path = %config.inference_http_path,
+        metrics_source = ?config.metrics_source,
         tick_ms = config.tick_ms,
         "vantage inference admission controller started"
     );
diff --git a/examples/inference-admission/src/vantage_client.rs b/examples/inference-admission/src/vantage_client.rs
index f1a4773..55be26d 100644
--- a/examples/inference-admission/src/vantage_client.rs
+++ b/examples/inference-admission/src/vantage_client.rs
@@ -1,15 +1,9 @@
-use std::{fmt::Write as _, future::Future, str::FromStr, time::Duration};
+use std::{fmt::Write as _, future::Future};

 use serde::Serialize;
 use thiserror::Error;
-use tokio::{
-    io::{AsyncReadExt as _, AsyncWriteExt as _},
-    net::TcpStream,
-    time::timeout,
-};

-const REQUEST_TIMEOUT: Duration = Duration::from_secs(5);
-const MAX_RESPONSE_BYTES: usize = 1 << 20;
+use crate::http_client::{HttpClient, HttpClientError, HttpResponse};

 #[derive(Debug, Clone, PartialEq, Eq)]
 pub(crate) struct PolicyTarget {
@@ -32,31 +26,17 @@ pub(crate) struct PolicyRequest {

 #[derive(Debug, Clone)]
 pub(crate) struct VantageClient {
-    endpoint: HttpEndpoint,
-}
-
-#[derive(Debug, Clone, PartialEq, Eq)]
-struct HttpEndpoint {
-    host: String,
-    port: u16,
+    http: HttpClient,
 }

 #[derive(Debug, Error)]
 pub(crate) enum ClientError {
-    #[error("{0}")]
-    InvalidBaseUrl(String),
     #[error("failed to serialize request body: {0}")]
     Serialize(#[from] serde_json::Error),
-    #[error("HTTP transport failed: {0}")]
-    Transport(#[from] std::io::Error),
     #[error("vantage returned HTTP {status}: {body}")]
     HttpStatus { status: u16,
body: String },
-    #[error("invalid HTTP response from vantage: {0}")]
-    InvalidResponse(String),
-    #[error("HTTP operation timed out after {seconds}s")]
-    Timeout { seconds: u64 },
-    #[error("HTTP response exceeded {limit} bytes")]
-    ResponseTooLarge { limit: usize },
+    #[error(transparent)]
+    Http(#[from] HttpClientError),
 }

 impl PolicyRequest {
@@ -79,21 +59,6 @@ pub(crate) struct PolicyShape {
     pub(crate) enabled: bool,
 }

-#[derive(Debug, Clone, Copy)]
-enum Method {
-    Put,
-    Delete,
-}
-
-impl Method {
-    const fn as_str(self) -> &'static str {
-        match self {
-            Self::Put => "PUT",
-            Self::Delete => "DELETE",
-        }
-    }
-}
-
 pub(crate) trait AdmissionClient {
     fn put_base_policy(
         &self,
@@ -115,44 +80,9 @@ pub(crate) trait AdmissionClient {
 impl VantageClient {
     pub(crate) fn new(base_url: &str) -> Result<Self, ClientError> {
         Ok(Self {
-            endpoint: HttpEndpoint::from_str(base_url)?,
+            http: HttpClient::new(base_url)?,
         })
     }
-
-    async fn request(
-        &self,
-        method: Method,
-        path: &str,
-        body: Option<&[u8]>,
-    ) -> Result<HttpResponse, ClientError> {
-        let mut stream = with_timeout(TcpStream::connect((
-            self.endpoint.host.as_str(),
-            self.endpoint.port,
-        )))
-        .await??;
-        let body_len = body.map_or(0, <[u8]>::len);
-        let mut request = String::new();
-        let _ = write!(
-            request,
-            "{} {} HTTP/1.1\r\nHost: {}\r\nConnection: close\r\nContent-Length: {}\r\n",
-            method.as_str(),
-            path,
-            self.endpoint.host,
-            body_len
-        );
-        if body.is_some() {
-            request.push_str("Content-Type: application/json\r\n");
-        }
-        request.push_str("\r\n");
-        with_timeout(stream.write_all(request.as_bytes())).await??;
-        if let Some(body) = body {
-            with_timeout(stream.write_all(body)).await??;
-        }
-        with_timeout(stream.shutdown()).await??;
-
-        let response_bytes = read_response_bounded(&mut stream).await?;
-        parse_response(&response_bytes)
-    }
 }

 impl AdmissionClient for VantageClient {
@@ -164,9 +94,7 @@ impl AdmissionClient for VantageClient {
         let mut body = serde_json::to_value(PolicyRequest::with_target(policy, target))?;
         body["http_path"]
= serde_json::Value::String(target.http_path.clone());
         let bytes = serde_json::to_vec(&body)?;
-        let response = self
-            .request(Method::Put, &policy_path(target), Some(&bytes))
-            .await?;
+        let response = self.http.put_json(&policy_path(target), &bytes).await?;
         require_success(response)
     }

@@ -179,7 +107,8 @@ impl AdmissionClient for VantageClient {
         body["http_path"] = serde_json::Value::String(target.http_path.clone());
         let bytes = serde_json::to_vec(&body)?;
         let response = self
-            .request(Method::Put, &runtime_policy_path(target), Some(&bytes))
+            .http
+            .put_json(&runtime_policy_path(target), &bytes)
             .await?;
         require_success(response)
     }
@@ -190,91 +119,13 @@ impl AdmissionClient for VantageClient {
         force: bool,
     ) -> Result<(), ClientError> {
         let response = self
-            .request(
-                Method::Delete,
-                &runtime_policy_delete_path(target, force),
-                None,
-            )
+            .http
+            .delete(&runtime_policy_delete_path(target, force))
             .await?;
         require_success(response)
     }
 }

-impl FromStr for HttpEndpoint {
-    type Err = ClientError;
-
-    fn from_str(raw: &str) -> Result<Self, Self::Err> {
-        let Some(rest) = raw.strip_prefix("http://") else {
-            return Err(ClientError::InvalidBaseUrl(
-                "only http:// vantage URLs are supported in this example".to_owned(),
-            ));
-        };
-        if rest.contains(['/', '?', '#']) {
-            return Err(ClientError::InvalidBaseUrl(
-                "base URL must not include a path, query, or fragment".to_owned(),
-            ));
-        }
-        let authority = rest;
-        if authority.is_empty() {
-            return Err(ClientError::InvalidBaseUrl(
-                "base URL must include a host".to_owned(),
-            ));
-        }
-        let (host, port) = match authority.rsplit_once(':') {
-            Some((host, raw_port)) if !host.is_empty() => {
-                let port = raw_port.parse::<u16>().map_err(|error| {
-                    ClientError::InvalidBaseUrl(format!("invalid port '{raw_port}': {error}"))
-                })?;
-                (host.to_owned(), port)
-            }
-            Some(_) => {
-                return Err(ClientError::InvalidBaseUrl(
-                    "base URL host must not be empty".to_owned(),
-                ));
-            }
-            None => (authority.to_owned(), 80),
-        };
-        Ok(Self { host, port
})
-    }
-}
-
-async fn with_timeout<T>(future: impl Future<Output = T>) -> Result<T, ClientError> {
-    timeout(REQUEST_TIMEOUT, future)
-        .await
-        .map_err(|_| ClientError::Timeout {
-            seconds: REQUEST_TIMEOUT.as_secs(),
-        })
-}
-
-async fn read_response_bounded(stream: &mut TcpStream) -> Result<Vec<u8>, ClientError> {
-    let mut response = Vec::new();
-    let mut chunk = [0_u8; 8192];
-    loop {
-        let read = with_timeout(stream.read(&mut chunk)).await??;
-        if read == 0 {
-            return Ok(response);
-        }
-        if response.len().saturating_add(read) > MAX_RESPONSE_BYTES {
-            return Err(ClientError::ResponseTooLarge {
-                limit: MAX_RESPONSE_BYTES,
-            });
-        }
-        response.extend_from_slice(&chunk[..read]);
-    }
-}
-
-#[derive(Debug, Clone, PartialEq, Eq)]
-struct HttpResponse {
-    status: u16,
-    body: String,
-}
-
-impl HttpResponse {
-    const fn status_is_success(&self) -> bool {
-        self.status >= 200 && self.status < 300
-    }
-}
-
 fn require_success(response: HttpResponse) -> Result<(), ClientError> {
     if response.status_is_success() {
         return Ok(());
@@ -285,43 +136,6 @@ fn require_success(response: HttpResponse) -> Result<(), ClientError> {
     })
 }

-fn parse_response(response: &[u8]) -> Result<HttpResponse, ClientError> {
-    let text = String::from_utf8_lossy(response);
-    let Some((headers, body)) = text.split_once("\r\n\r\n") else {
-        return Err(ClientError::InvalidResponse(
-            "missing HTTP header terminator".to_owned(),
-        ));
-    };
-    let Some(status_line) = headers.lines().next() else {
-        return Err(ClientError::InvalidResponse(
-            "missing HTTP status line".to_owned(),
-        ));
-    };
-    let mut parts = status_line.split_whitespace();
-    let Some(version) = parts.next() else {
-        return Err(ClientError::InvalidResponse(
-            "missing HTTP version".to_owned(),
-        ));
-    };
-    if !version.starts_with("HTTP/") {
-        return Err(ClientError::InvalidResponse(format!(
-            "invalid HTTP version '{version}'"
-        )));
-    }
-    let Some(raw_status) = parts.next() else {
-        return Err(ClientError::InvalidResponse(
-            "missing HTTP status code".to_owned(),
-        ));
-    };
-    let status =
raw_status.parse::<u16>().map_err(|error| {
-        ClientError::InvalidResponse(format!("invalid HTTP status code '{raw_status}': {error}"))
-    })?;
-    Ok(HttpResponse {
-        status,
-        body: body.to_owned(),
-    })
-}
-
 fn policy_path(target: &PolicyTarget) -> String {
     format!("/policy/cg:{}", target.cgroup_id)
 }
@@ -356,9 +170,7 @@ fn percent_encode(raw: &str) -> String {

 #[cfg(test)]
 mod tests {
-    use std::str::FromStr as _;
-
-    use super::{HttpEndpoint, PolicyTarget, percent_encode, runtime_policy_delete_path};
+    use super::{PolicyTarget, percent_encode, runtime_policy_delete_path};

     fn target() -> PolicyTarget {
         PolicyTarget {
@@ -370,22 +182,6 @@ mod tests {
         }
     }

-    #[test]
-    fn parses_http_base_url() {
-        let endpoint = HttpEndpoint::from_str("http://127.0.0.1:3000");
-        let Ok(endpoint) = endpoint else {
-            panic!("endpoint should parse");
-        };
-        assert_eq!(endpoint.host, "127.0.0.1");
-        assert_eq!(endpoint.port, 3000);
-    }
-
-    #[test]
-    fn rejects_base_url_path() {
-        let endpoint = HttpEndpoint::from_str("http://127.0.0.1:3000/api");
-        assert!(endpoint.is_err(), "base URL path should be rejected");
-    }
-
     #[test]
     fn encodes_delete_runtime_policy_query() {
         let path = runtime_policy_delete_path(&target(), false);
diff --git a/examples/inference-admission/src/vllm.rs b/examples/inference-admission/src/vllm.rs
new file mode 100644
index 0000000..c975354
--- /dev/null
+++ b/examples/inference-admission/src/vllm.rs
@@ -0,0 +1,383 @@
+use std::sync::{Arc, Mutex};
+
+use thiserror::Error;
+
+use crate::{
+    controller::{InferenceSource, InferenceSourceError},
+    http_client::{HttpClient, HttpClientError},
+    inference::InferencePressureSample,
+};
+
+#[derive(Debug, Clone)]
+pub(crate) struct VllmMetricsSource {
+    http: HttpClient,
+    metrics_path: String,
+    token_budget_per_minute: u64,
+    token_window: Arc<Mutex<TokenWindow>>,
+}
+
+#[derive(Debug, Clone, Copy, Default, PartialEq)]
+pub(crate) struct ParsedVllmMetrics {
+    pub(crate) gpu_cache_usage_percent: Option<f64>,
+    pub(crate) requests_waiting: Option<u64>,
+
pub(crate) requests_running: Option<u64>,
+    pub(crate) prompt_tokens_total: Option<u64>,
+    pub(crate) generation_tokens_total: Option<u64>,
+}
+
+#[derive(Debug, Clone, Copy, Default, PartialEq, Eq)]
+struct TokenTotals {
+    prompt: u64,
+    generation: u64,
+}
+
+#[derive(Debug, Default)]
+struct TokenWindow {
+    last_totals: Option<TokenTotals>,
+}
+
+#[derive(Debug, Error)]
+pub(crate) enum VllmMetricsError {
+    #[error(transparent)]
+    Http(#[from] HttpClientError),
+    #[error("vLLM metrics endpoint returned HTTP {status}: {body}")]
+    HttpStatus { status: u16, body: String },
+    #[error("failed to parse vLLM metric line {line_number}: {reason}")]
+    ParseLine { line_number: usize, reason: String },
+    #[error("vLLM token counter state lock is poisoned")]
+    StateLockPoisoned,
+}
+
+impl VllmMetricsSource {
+    pub(crate) fn new(
+        metrics_base_url: &str,
+        metrics_path: String,
+        token_budget_per_minute: u64,
+    ) -> Result<Self, VllmMetricsError> {
+        Ok(Self {
+            http: HttpClient::new(metrics_base_url)?,
+            metrics_path,
+            token_budget_per_minute,
+            token_window: Arc::new(Mutex::new(TokenWindow::default())),
+        })
+    }
+}
+
+impl InferenceSource for VllmMetricsSource {
+    async fn sample(&self) -> Result<InferencePressureSample, InferenceSourceError> {
+        let response = self
+            .http
+            .get(&self.metrics_path)
+            .await
+            .map_err(VllmMetricsError::from)?;
+        if !response.status_is_success() {
+            return Err(VllmMetricsError::HttpStatus {
+                status: response.status,
+                body: response.body,
+            }
+            .into());
+        }
+        let parsed = parse_vllm_metrics(&response.body)?;
+        let tokens_used_current_minute = self
+            .token_window
+            .lock()
+            .map_err(|_| VllmMetricsError::StateLockPoisoned)?
+            .tokens_since_last_sample(parsed);
+        Ok(metrics_to_pressure_sample_with_tokens(
+            parsed,
+            tokens_used_current_minute,
+            self.token_budget_per_minute,
+            unix_timestamp_ms(),
+        ))
+    }
+}
+
+pub(crate) fn parse_vllm_metrics(text: &str) -> Result<ParsedVllmMetrics, VllmMetricsError> {
+    let mut parsed = ParsedVllmMetrics::default();
+    for (index, raw_line) in text.lines().enumerate() {
+        let line_number = index.saturating_add(1);
+        let line = raw_line.trim();
+        if line.is_empty() || line.starts_with('#') {
+            continue;
+        }
+        let name = parse_metric_name(line, line_number)?;
+        if !is_relevant_metric(name) {
+            continue;
+        }
+        let value = parse_metric_value(line, line_number)?;
+        match name {
+            "vllm:gpu_cache_usage_perc" => {
+                let percent = normalize_gpu_cache_percent(value);
+                parsed.gpu_cache_usage_percent = Some(
+                    parsed
+                        .gpu_cache_usage_percent
+                        .map_or(percent, |current| current.max(percent)),
+                );
+            }
+            "vllm:num_requests_waiting" => {
+                parsed.requests_waiting = Some(
+                    parsed
+                        .requests_waiting
+                        .unwrap_or(0)
+                        .saturating_add(count(value)),
+                );
+            }
+            "vllm:num_requests_running" => {
+                parsed.requests_running = Some(
+                    parsed
+                        .requests_running
+                        .unwrap_or(0)
+                        .saturating_add(count(value)),
+                );
+            }
+            "vllm:prompt_tokens_total" | "vllm:request_prompt_tokens_total" => {
+                parsed.prompt_tokens_total = Some(
+                    parsed
+                        .prompt_tokens_total
+                        .unwrap_or(0)
+                        .saturating_add(count(value)),
+                );
+            }
+            "vllm:generation_tokens_total" | "vllm:request_generation_tokens_total" => {
+                parsed.generation_tokens_total = Some(
+                    parsed
+                        .generation_tokens_total
+                        .unwrap_or(0)
+                        .saturating_add(count(value)),
+                );
+            }
+            _ => {}
+        }
+    }
+    Ok(parsed)
+}
+
+#[cfg(test)]
+pub(crate) fn metrics_to_pressure_sample(
+    metrics: ParsedVllmMetrics,
+    token_budget_per_minute: u64,
+    now_ms: u64,
+) -> InferencePressureSample {
+    let tokens = metrics
+        .prompt_tokens_total
+        .unwrap_or(0)
+        .saturating_add(metrics.generation_tokens_total.unwrap_or(0));
+    metrics_to_pressure_sample_with_tokens(metrics, tokens,
token_budget_per_minute, now_ms)
+}
+
+fn metrics_to_pressure_sample_with_tokens(
+    metrics: ParsedVllmMetrics,
+    tokens_used_current_minute: u64,
+    token_budget_per_minute: u64,
+    now_ms: u64,
+) -> InferencePressureSample {
+    InferencePressureSample {
+        ts_unix_ms: now_ms,
+        tokens_used_current_minute,
+        token_budget_per_minute: token_budget_per_minute.max(1),
+        kv_cache_used_bytes: None,
+        kv_cache_capacity_bytes: None,
+        kv_cache_percent: metrics.gpu_cache_usage_percent,
+        active_requests: metrics.requests_running,
+        queued_requests: metrics.requests_waiting,
+    }
+}
+
+impl TokenWindow {
+    fn tokens_since_last_sample(&mut self, metrics: ParsedVllmMetrics) -> u64 {
+        let current = TokenTotals {
+            prompt: metrics.prompt_tokens_total.unwrap_or(0),
+            generation: metrics.generation_tokens_total.unwrap_or(0),
+        };
+        let tokens = self.last_totals.map_or(0, |last| {
+            current
+                .prompt
+                .saturating_sub(last.prompt)
+                .saturating_add(current.generation.saturating_sub(last.generation))
+        });
+        self.last_totals = Some(current);
+        tokens
+    }
+}
+
+fn parse_metric_name(line: &str, line_number: usize) -> Result<&str, VllmMetricsError> {
+    let mut fields = line.split_whitespace();
+    let Some(raw_name) = fields.next() else {
+        return Err(parse_error(line_number, "missing metric name"));
+    };
+    Ok(raw_name.split_once('{').map_or(raw_name, |(name, _)| name))
+}
+
+fn is_relevant_metric(name: &str) -> bool {
+    matches!(
+        name,
+        "vllm:gpu_cache_usage_perc"
+            | "vllm:num_requests_waiting"
+            | "vllm:num_requests_running"
+            | "vllm:prompt_tokens_total"
+            | "vllm:request_prompt_tokens_total"
+            | "vllm:generation_tokens_total"
+            | "vllm:request_generation_tokens_total"
+    )
+}
+
+fn parse_metric_value(line: &str, line_number: usize) -> Result<f64, VllmMetricsError> {
+    let mut fields = line.split_whitespace();
+    let _ = fields.next();
+    let Some(raw_value) = fields.next() else {
+        return Err(parse_error(line_number, "missing metric value"));
+    };
+    let value = raw_value.parse::<f64>().map_err(|error| {
+
parse_error(
+            line_number,
+            &format!("invalid value '{raw_value}': {error}"),
+        )
+    })?;
+    if !value.is_finite() {
+        return Err(parse_error(line_number, "metric value must be finite"));
+    }
+    Ok(value)
+}
+
+fn parse_error(line_number: usize, reason: &str) -> VllmMetricsError {
+    VllmMetricsError::ParseLine {
+        line_number,
+        reason: reason.to_owned(),
+    }
+}
+
+fn normalize_gpu_cache_percent(value: f64) -> f64 {
+    let percent = if value <= 1.0 { value * 100.0 } else { value };
+    percent.clamp(0.0, 100.0)
+}
+
+fn count(value: f64) -> u64 {
+    const U64_MAX_AS_F64: f64 = 18_446_744_073_709_551_615.0;
+    if value <= 0.0 {
+        0
+    } else if value >= U64_MAX_AS_F64 {
+        u64::MAX
+    } else {
+        let rounded = format!("{:.0}", value.floor());
+        rounded.parse::<u64>().unwrap_or(u64::MAX)
+    }
+}
+
+fn unix_timestamp_ms() -> u64 {
+    std::time::SystemTime::now()
+        .duration_since(std::time::UNIX_EPOCH)
+        .map_or(0, |duration| {
+            u64::try_from(duration.as_millis()).unwrap_or(u64::MAX)
+        })
+}
+
+#[cfg(test)]
+mod tests {
+    use super::{TokenWindow, metrics_to_pressure_sample, parse_vllm_metrics};
+
+    const NORMAL: &str = include_str!("../fixtures/vllm_metrics_normal.prom");
+    const THROTTLED: &str = include_str!("../fixtures/vllm_metrics_throttled.prom");
+    const EXHAUSTED: &str = include_str!("../fixtures/vllm_metrics_exhausted.prom");
+
+    #[test]
+    fn parses_normal_fixture() {
+        let parsed = parse_vllm_metrics(NORMAL);
+        let Ok(parsed) = parsed else {
+            panic!("fixture should parse");
+        };
+        assert_eq!(parsed.gpu_cache_usage_percent, Some(20.0));
+        assert_eq!(parsed.requests_waiting, Some(0));
+        assert_eq!(parsed.requests_running, Some(1));
+        assert_eq!(parsed.prompt_tokens_total, Some(12));
+        assert_eq!(parsed.generation_tokens_total, Some(8));
+    }
+
+    #[test]
+    fn parses_throttled_fixture_with_ratio_gpu_cache() {
+        let parsed = parse_vllm_metrics(THROTTLED);
+        let Ok(parsed) = parsed else {
+            panic!("fixture should parse");
+        };
+        assert_eq!(parsed.gpu_cache_usage_percent,
Some(95.0));
+        assert_eq!(parsed.requests_waiting, Some(8));
+        assert_eq!(parsed.requests_running, Some(16));
+    }
+
+    #[test]
+    fn converts_exhausted_fixture_to_pressure_sample() {
+        let parsed = parse_vllm_metrics(EXHAUSTED);
+        let Ok(parsed) = parsed else {
+            panic!("fixture should parse");
+        };
+        let sample = metrics_to_pressure_sample(parsed, 100, 123);
+        assert_eq!(sample.ts_unix_ms, 123);
+        assert_eq!(sample.tokens_used_current_minute, 120);
+        assert_eq!(sample.token_budget_per_minute, 100);
+        assert_eq!(sample.kv_cache_percent, Some(98.0));
+        assert_eq!(sample.active_requests, Some(20));
+        assert_eq!(sample.queued_requests, Some(12));
+    }
+
+    #[test]
+    fn sums_labeled_metrics() {
+        let text = "\
+vllm:num_requests_waiting{model=\"a\"} 2
+vllm:num_requests_waiting{model=\"b\"} 3
+vllm:prompt_tokens_total{model=\"a\"} 4
+vllm:prompt_tokens_total{model=\"b\"} 5
+";
+        let parsed = parse_vllm_metrics(text);
+        let Ok(parsed) = parsed else {
+            panic!("metrics should parse");
+        };
+        assert_eq!(parsed.requests_waiting, Some(5));
+        assert_eq!(parsed.prompt_tokens_total, Some(9));
+    }
+
+    #[test]
+    fn ignores_unrelated_invalid_metrics() {
+        let text = "\
+unrelated_metric NaN
+other_exporter_metric_without_value
+vllm:num_requests_running 2
+vllm:prompt_tokens_total 7
+";
+        let parsed = parse_vllm_metrics(text);
+        let Ok(parsed) = parsed else {
+            panic!("unrelated invalid metrics should be ignored");
+        };
+        assert_eq!(parsed.requests_running, Some(2));
+        assert_eq!(parsed.prompt_tokens_total, Some(7));
+    }
+
+    #[test]
+    fn token_window_uses_counter_deltas() {
+        let mut window = TokenWindow::default();
+        let first = parse_vllm_metrics(
+            "\
+vllm:prompt_tokens_total 1000
+vllm:generation_tokens_total 2000
+",
+        )
+        .expect("first sample should parse");
+        assert_eq!(window.tokens_since_last_sample(first), 0);
+
+        let second = parse_vllm_metrics(
+            "\
+vllm:prompt_tokens_total 1010
+vllm:generation_tokens_total 2025
+",
+        )
+        .expect("second sample should parse");
+
assert_eq!(window.tokens_since_last_sample(second), 35);
+
+        let quiet = parse_vllm_metrics(
+            "\
+vllm:prompt_tokens_total 1010
+vllm:generation_tokens_total 2025
+",
+        )
+        .expect("quiet sample should parse");
+        assert_eq!(window.tokens_since_last_sample(quiet), 0);
+    }
+}