Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
119 changes: 103 additions & 16 deletions crates/service/src/apikey/apikey_models.rs
Original file line number Diff line number Diff line change
Expand Up @@ -447,7 +447,12 @@ pub(crate) fn delete_managed_model_source_mapping(
let source_id = normalize_required("sourceId", source_id)?;
let upstream_model = normalize_required("upstreamModel", upstream_model)?;
storage
.delete_model_source_mapping_with_unlink_preference(&id, &source_kind, &source_id, &upstream_model)
.delete_model_source_mapping_with_unlink_preference(
&id,
&source_kind,
&source_id,
&upstream_model,
)
.map_err(|err| format!("delete model mapping failed: {err}"))
}

Expand Down Expand Up @@ -495,7 +500,8 @@ pub(crate) fn bootstrap_account_pool_model_routes(

pub(crate) fn bootstrap_aggregate_api_model_routes(storage: &Storage) -> Result<(), String> {
let active_source_ids = active_aggregate_api_source_ids(storage)?;
prune_stale_aggregate_api_source_routes(storage, &active_source_ids)?;
let existing_source_ids = existing_aggregate_api_source_ids(storage)?;
prune_deleted_aggregate_api_source_routes(storage, &existing_source_ids)?;
for source_id in active_source_ids {
auto_associate_source_models(
storage,
Expand Down Expand Up @@ -643,6 +649,9 @@ where
.collect::<HashSet<_>>();
match requested_source_id.as_deref() {
Some(source_id) if !active_source_ids.contains(source_id) => {
if apis.iter().any(|api| api.id == source_id) {
return Err(format!("aggregate api `{source_id}` is disabled"));
}
let stale_upstream_models = stale_source_upstream_models(
storage,
ROUTING_SOURCE_KIND_AGGREGATE_API,
Expand All @@ -658,13 +667,16 @@ where
.delete_model_source_routes_for_source(ROUTING_SOURCE_KIND_AGGREGATE_API, source_id)
.map_err(|err| format!("delete stale aggregate api source routes failed: {err}"))?;
cleanup_orphan_auto_catalog_models(storage, &stale_upstream_models)?;
if apis.iter().any(|api| api.id == source_id) {
return Err(format!("aggregate api `{source_id}` is disabled"));
}
return Err(format!("aggregate api `{source_id}` not found"));
}
Some(_) => {}
None => prune_stale_aggregate_api_source_routes(storage, &active_source_ids)?,
None => {
let existing_source_ids = apis
.iter()
.map(|api| api.id.clone())
.collect::<HashSet<_>>();
prune_deleted_aggregate_api_source_routes(storage, &existing_source_ids)?;
}
}
let mut synced_any = false;
let mut last_error: Option<String> = None;
Expand Down Expand Up @@ -733,9 +745,16 @@ fn active_aggregate_api_source_ids(storage: &Storage) -> Result<HashSet<String>,
})
}

fn prune_stale_aggregate_api_source_routes(
fn existing_aggregate_api_source_ids(storage: &Storage) -> Result<HashSet<String>, String> {
storage
.list_aggregate_apis()
.map_err(|err| format!("list aggregate apis failed: {err}"))
.map(|apis| apis.into_iter().map(|api| api.id).collect::<HashSet<_>>())
}

fn prune_deleted_aggregate_api_source_routes(
storage: &Storage,
active_source_ids: &HashSet<String>,
existing_source_ids: &HashSet<String>,
) -> Result<(), String> {
let mut known_source_ids = storage
.list_model_source_models(Some(ROUTING_SOURCE_KIND_AGGREGATE_API), None)
Expand All @@ -752,7 +771,7 @@ fn prune_stale_aggregate_api_source_routes(
known_source_ids.insert(mapping.source_id);
}
for source_id in known_source_ids {
if active_source_ids.contains(source_id.as_str()) {
if existing_source_ids.contains(source_id.as_str()) {
continue;
}
let stale_upstream_models = stale_source_upstream_models(
Expand Down Expand Up @@ -947,7 +966,10 @@ fn auto_associate_source_models(
if existing_source_platform_mappings.contains(source_model.upstream_model.as_str()) {
continue;
}
let enabled = match prefs.get(source_model.upstream_model.as_str()).map(String::as_str) {
let enabled = match prefs
.get(source_model.upstream_model.as_str())
.map(String::as_str)
{
Some("unlinked") => continue,
Some(v) => v != "disabled",
None => true,
Expand All @@ -974,9 +996,7 @@ fn auto_associate_source_models(
if let Err(err) =
ensure_model_price_rules_for_aggregate_api(storage, source_id, &source_models)
{
log::warn!(
"aggregate API {source_id}: 自动创建模型价格规则失败: {err}"
);
log::warn!("aggregate API {source_id}: 自动创建模型价格规则失败: {err}");
}
}

Expand Down Expand Up @@ -3281,7 +3301,7 @@ mod tests {
}

#[test]
fn aggregate_bootstrap_prunes_stale_source_routes() {
fn aggregate_bootstrap_preserves_disabled_source_routes() {
let storage = Storage::open_in_memory().expect("open storage");
storage.init().expect("init storage");
insert_test_aggregate_api(&storage, "agg-stale", "disabled");
Expand All @@ -3295,6 +3315,75 @@ mod tests {
)
.expect("seed aggregate source model");
let now = now_ts();
storage
.upsert_model_source_mapping(&ModelSourceMapping {
id: "mapping-aggregate-stale".to_string(),
platform_model_slug: "vendor-stale".to_string(),
source_kind: ROUTING_SOURCE_KIND_AGGREGATE_API.to_string(),
source_id: "agg-stale".to_string(),
upstream_model: "vendor-stale".to_string(),
enabled: true,
priority: 0,
weight: 1,
billing_model_slug: None,
created_at: now,
updated_at: now,
})
.expect("seed stale mapping");
storage
.upsert_model_source_mapping_preference(
ROUTING_SOURCE_KIND_AGGREGATE_API,
"agg-stale",
"vendor-stale",
"unlinked",
)
.expect("seed preference");

bootstrap_aggregate_api_model_routes(&storage).expect("bootstrap aggregate routes");

assert_eq!(
storage
.list_model_source_models(
Some(ROUTING_SOURCE_KIND_AGGREGATE_API),
Some("agg-stale")
)
.expect("list source models")
.len(),
1
);
assert_eq!(
storage
.list_model_source_mappings(Some("vendor-stale"))
.expect("list mappings")
.len(),
1
);
assert_eq!(
storage
.list_model_source_mapping_preferences(
ROUTING_SOURCE_KIND_AGGREGATE_API,
"agg-stale",
)
.expect("list preferences")
.len(),
1
);
}

#[test]
fn aggregate_bootstrap_prunes_deleted_source_routes() {
let storage = Storage::open_in_memory().expect("open storage");
storage.init().expect("init storage");
seed_platform_catalog(&storage, &["vendor-stale"]);
storage
.upsert_discovered_model_source_models(
ROUTING_SOURCE_KIND_AGGREGATE_API,
"agg-stale",
&["vendor-stale".to_string()],
"synced",
)
.expect("seed aggregate source model");
let now = now_ts();
storage
.upsert_model_source_mapping(&ModelSourceMapping {
id: "mapping-aggregate-stale".to_string(),
Expand Down Expand Up @@ -3327,7 +3416,6 @@ mod tests {
fn bootstrap_aggregate_routes_cleans_orphan_auto_catalog_model() {
let storage = Storage::open_in_memory().expect("open storage");
storage.init().expect("init storage");
insert_test_aggregate_api(&storage, "agg-orphan", "disabled");
let now = now_ts();
storage
.upsert_model_source_model(&ModelSourceModel {
Expand Down Expand Up @@ -3384,7 +3472,6 @@ mod tests {
fn bootstrap_aggregate_routes_keeps_unrelated_remote_catalog_model() {
let storage = Storage::open_in_memory().expect("open storage");
storage.init().expect("init storage");
insert_test_aggregate_api(&storage, "agg-orphan", "disabled");
let now = now_ts();
storage
.upsert_model_source_model(&ModelSourceModel {
Expand Down