From 24d9aeaf9b850ed951c02de1166d5b12799dc7e6 Mon Sep 17 00:00:00 2001 From: hijos <1019378968@qq.com> Date: Sun, 21 Jun 2026 01:10:13 +0800 Subject: [PATCH 1/2] feat: support Vertex AI publisher URLs for Gemini --- internal/channel/gemini_channel.go | 127 ++++++++++++++++++-- internal/channel/gemini_channel_test.go | 150 ++++++++++++++++++++++++ 2 files changed, 269 insertions(+), 8 deletions(-) create mode 100644 internal/channel/gemini_channel_test.go diff --git a/internal/channel/gemini_channel.go b/internal/channel/gemini_channel.go index fa6b2f0d7..cc29c14a1 100644 --- a/internal/channel/gemini_channel.go +++ b/internal/channel/gemini_channel.go @@ -36,6 +36,27 @@ func newGeminiChannel(f *Factory, group *models.Group) (ChannelProxy, error) { }, nil } +// BuildUpstreamURL constructs the target URL for Gemini requests. +func (ch *GeminiChannel) BuildUpstreamURL(originalURL *url.URL, groupName string) (string, error) { + base := ch.getUpstreamURL() + if base == nil { + return "", fmt.Errorf("no upstream URL configured for channel %s", ch.Name) + } + + finalURL := *base + requestPath := trimProxyGroupPrefix(originalURL.Path, groupName) + + if publisherBasePath, ok := vertexPublisherBasePath(base); ok { + finalURL.Path = buildVertexPublisherPath(publisherBasePath, requestPath) + } else { + finalURL.Path = joinURLPath(base.Path, requestPath) + } + + finalURL.RawQuery = originalURL.RawQuery + + return finalURL.String(), nil +} + // ModifyRequest adds the API key as a query parameter for Gemini requests. func (ch *GeminiChannel) ModifyRequest(req *http.Request, apiKey *models.APIKey, group *models.Group) { if strings.Contains(req.URL.Path, "v1beta/openai") { @@ -98,17 +119,12 @@ func (ch *GeminiChannel) ExtractModel(c *gin.Context, bodyBytes []byte) string { // ValidateKey checks if the given API key is valid by making a generateContent request. func (ch *GeminiChannel) ValidateKey(ctx context.Context, apiKey *models.APIKey, group *models.Group) (bool, error) { - upstreamURL := ch.getUpstreamURL() - if upstreamURL == nil { - return false, fmt.Errorf("no upstream URL configured for channel %s", ch.Name) - } - - // Safely join the path segments - reqURL, err := url.JoinPath(upstreamURL.String(), "v1beta", "models", ch.TestModel+":generateContent") + reqURL, err := ch.BuildUpstreamURL(&url.URL{ + Path: "/proxy/" + group.Name + "/v1beta/models/" + ch.TestModel + ":generateContent", + }, group.Name) if err != nil { return false, fmt.Errorf("failed to create gemini validation path: %w", err) } - reqURL += "?key=" + apiKey.KeyValue payload := gin.H{ "contents": []gin.H{ @@ -130,6 +146,7 @@ func (ch *GeminiChannel) ValidateKey(ctx context.Context, apiKey *models.APIKey, return false, fmt.Errorf("failed to create validation request: %w", err) } req.Header.Set("Content-Type", "application/json") + ch.ModifyRequest(req, apiKey, group) // Apply custom header rules if available if len(group.HeaderRuleList) > 0 { @@ -343,3 +360,97 @@ func isFirstPage(req *http.Request) bool { pageToken := req.URL.Query().Get("pageToken") return pageToken == "" } + +func trimProxyGroupPrefix(requestPath, groupName string) string { + proxyPrefix := "/proxy/" + groupName + return strings.TrimPrefix(requestPath, proxyPrefix) +} + +func vertexPublisherBasePath(base *url.URL) (string, bool) { + basePath := normalizeBasePath(base.Path) + if basePath == "/v1/publishers/google" || basePath == "/v1beta1/publishers/google" { + return basePath, true + } + + if !strings.EqualFold(base.Hostname(), "aiplatform.googleapis.com") { + return "", false + } + + switch basePath { + case "/", "/v1": + return "/v1/publishers/google", true + case "/v1beta1": + return "/v1beta1/publishers/google", true + default: + return "", false + } +} + +func buildVertexPublisherPath(basePath, requestPath string) string { + if hasPathPrefix(requestPath, basePath) { + return ensureLeadingSlash(requestPath) + } + + if modelPath, ok := geminiNativeModelPath(requestPath); ok { + return joinURLPath(basePath, modelPath) + } + + return joinURLPath(basePath, requestPath) +} + +func geminiNativeModelPath(requestPath string) (string, bool) { + parts := strings.Split(strings.TrimLeft(requestPath, "/"), "/") + if len(parts) == 0 { + return "", false + } + + if parts[0] == "models" { + return strings.Join(parts, "/"), true + } + + if len(parts) >= 2 && isGeminiNativeVersion(parts[0]) && parts[1] == "models" { + return strings.Join(parts[1:], "/"), true + } + + return "", false +} + +func isGeminiNativeVersion(segment string) bool { + return segment == "v1" || segment == "v1beta" || segment == "v1beta1" +} + +func joinURLPath(basePath, requestPath string) string { + basePath = ensureLeadingSlash(basePath) + requestPath = strings.TrimLeft(requestPath, "/") + if requestPath == "" { + return basePath + } + if basePath == "/" { + return "/" + requestPath + } + return strings.TrimRight(basePath, "/") + "/" + requestPath +} + +func hasPathPrefix(pathValue, prefix string) bool { + pathValue = ensureLeadingSlash(pathValue) + prefix = strings.TrimRight(ensureLeadingSlash(prefix), "/") + return pathValue == prefix || strings.HasPrefix(pathValue, prefix+"/") +} + +func ensureLeadingSlash(pathValue string) string { + if pathValue == "" { + return "/" + } + if strings.HasPrefix(pathValue, "/") { + return pathValue + } + return "/" + pathValue +} + +func normalizeBasePath(pathValue string) string { + normalized := strings.TrimRight(ensureLeadingSlash(pathValue), "/") + if normalized == "" { + return "/" + } + return normalized +} diff --git a/internal/channel/gemini_channel_test.go b/internal/channel/gemini_channel_test.go new file mode 100644 index 000000000..09a560844 --- /dev/null +++ b/internal/channel/gemini_channel_test.go @@ -0,0 +1,150 @@ +package channel + +import ( + "context" + "gpt-load/internal/models" + "net/http" + "net/http/httptest" + "net/url" + "testing" +) + +func TestGeminiBuildUpstreamURLPreservesDeveloperAPIPath(t *testing.T) { + ch := newTestGeminiChannel(t, "https://generativelanguage.googleapis.com") + originalURL := mustParseURL(t, "http://localhost:3001/proxy/gemini/v1beta/models/gemini-2.5-pro:generateContent?key=proxy-key") + + got, err := ch.BuildUpstreamURL(originalURL, "gemini") + if err != nil { + t.Fatalf("BuildUpstreamURL returned error: %v", err) + } + + want := "https://generativelanguage.googleapis.com/v1beta/models/gemini-2.5-pro:generateContent?key=proxy-key" + if got != want { + t.Fatalf("BuildUpstreamURL() = %q, want %q", got, want) + } +} + +func TestGeminiBuildUpstreamURLConvertsNativePathForVertexPublisherBase(t *testing.T) { + ch := newTestGeminiChannel(t, "https://aiplatform.googleapis.com/v1/publishers/google") + originalURL := mustParseURL(t, "http://localhost:3001/proxy/gemini/v1beta/models/gemini-2.5-pro:streamGenerateContent?alt=sse&key=proxy-key") + + got, err := ch.BuildUpstreamURL(originalURL, "gemini") + if err != nil { + t.Fatalf("BuildUpstreamURL returned error: %v", err) + } + + want := "https://aiplatform.googleapis.com/v1/publishers/google/models/gemini-2.5-pro:streamGenerateContent?alt=sse&key=proxy-key" + if got != want { + t.Fatalf("BuildUpstreamURL() = %q, want %q", got, want) + } +} + +func TestGeminiBuildUpstreamURLConvertsNativePathForBareAiplatformBase(t *testing.T) { + ch := newTestGeminiChannel(t, "https://aiplatform.googleapis.com/") + originalURL := mustParseURL(t, "http://localhost:3001/proxy/gemini/v1beta/models/gemini-2.5-pro:generateContent?key=proxy-key") + + got, err := ch.BuildUpstreamURL(originalURL, "gemini") + if err != nil { + t.Fatalf("BuildUpstreamURL returned error: %v", err) + } + + want := "https://aiplatform.googleapis.com/v1/publishers/google/models/gemini-2.5-pro:generateContent?key=proxy-key" + if got != want { + t.Fatalf("BuildUpstreamURL() = %q, want %q", got, want) + } +} + +func TestGeminiBuildUpstreamURLDoesNotDuplicateVertexPublisherPath(t *testing.T) { + ch := newTestGeminiChannel(t, "https://aiplatform.googleapis.com/v1/publishers/google") + originalURL := mustParseURL(t, "http://localhost:3001/proxy/gemini/v1/publishers/google/models/gemini-2.5-flash:generateContent") + + got, err := ch.BuildUpstreamURL(originalURL, "gemini") + if err != nil { + t.Fatalf("BuildUpstreamURL returned error: %v", err) + } + + want := "https://aiplatform.googleapis.com/v1/publishers/google/models/gemini-2.5-flash:generateContent" + if got != want { + t.Fatalf("BuildUpstreamURL() = %q, want %q", got, want) + } +} + +func TestGeminiApplyModelRedirectWorksWithVertexPublisherPath(t *testing.T) { + ch := newTestGeminiChannel(t, "https://aiplatform.googleapis.com/v1/publishers/google") + req := &http.Request{ + URL: mustParseURL(t, "https://aiplatform.googleapis.com/v1/publishers/google/models/source-model:generateContent"), + } + group := &models.Group{ + Name: "gemini", + ModelRedirectMap: map[string]string{"source-model": "target-model"}, + } + body := []byte(`{"contents":[]}`) + + gotBody, err := ch.ApplyModelRedirect(req, body, group) + if err != nil { + t.Fatalf("ApplyModelRedirect returned error: %v", err) + } + if string(gotBody) != string(body) { + t.Fatalf("ApplyModelRedirect body = %s, want %s", gotBody, body) + } + + wantPath := "/v1/publishers/google/models/target-model:generateContent" + if req.URL.Path != wantPath { + t.Fatalf("redirected path = %q, want %q", req.URL.Path, wantPath) + } +} + +func TestGeminiValidateKeyUsesVertexPublisherPath(t *testing.T) { + var gotPath string + var gotKey string + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + gotPath = r.URL.Path + gotKey = r.URL.Query().Get("key") + w.WriteHeader(http.StatusOK) + })) + t.Cleanup(server.Close) + + ch := newTestGeminiChannel(t, server.URL+"/v1/publishers/google") + ch.HTTPClient = server.Client() + ch.TestModel = "gemini-test" + + ok, err := ch.ValidateKey(context.Background(), &models.APIKey{KeyValue: "secret-key"}, &models.Group{Name: "gemini"}) + if err != nil { + t.Fatalf("ValidateKey returned error: %v", err) + } + if !ok { + t.Fatal("ValidateKey returned false, want true") + } + + wantPath := "/v1/publishers/google/models/gemini-test:generateContent" + if gotPath != wantPath { + t.Fatalf("validation path = %q, want %q", gotPath, wantPath) + } + if gotKey != "secret-key" { + t.Fatalf("validation key = %q, want %q", gotKey, "secret-key") + } +} + +func newTestGeminiChannel(t *testing.T, upstream string) *GeminiChannel { + t.Helper() + + upstreamURL := mustParseURL(t, upstream) + return &GeminiChannel{ + BaseChannel: &BaseChannel{ + Name: "gemini", + Upstreams: []UpstreamInfo{{URL: upstreamURL, Weight: 1}}, + HTTPClient: http.DefaultClient, + TestModel: "gemini-2.0-flash-lite", + }, + } +} + +func mustParseURL(t *testing.T, rawURL string) *url.URL { + t.Helper() + + parsed, err := url.Parse(rawURL) + if err != nil { + t.Fatalf("failed to parse url %q: %v", rawURL, err) + } + return parsed +} From fdd2919e43041bd5a7a48c9239ae0e118d6a54de Mon Sep 17 00:00:00 2001 From: hijos <1019378968@qq.com> Date: Sun, 21 Jun 2026 01:27:29 +0800 Subject: [PATCH 2/2] fix: avoid duplicate Gemini upstream base paths --- internal/channel/gemini_channel.go | 2 ++ internal/channel/gemini_channel_test.go | 15 +++++++++++++++ 2 files changed, 17 insertions(+) diff --git a/internal/channel/gemini_channel.go b/internal/channel/gemini_channel.go index cc29c14a1..d988c0308 100644 --- a/internal/channel/gemini_channel.go +++ b/internal/channel/gemini_channel.go @@ -48,6 +48,8 @@ func (ch *GeminiChannel) BuildUpstreamURL(originalURL *url.URL, groupName string if publisherBasePath, ok := vertexPublisherBasePath(base); ok { finalURL.Path = buildVertexPublisherPath(publisherBasePath, requestPath) + } else if hasPathPrefix(requestPath, base.Path) { + finalURL.Path = ensureLeadingSlash(requestPath) } else { finalURL.Path = joinURLPath(base.Path, requestPath) } diff --git a/internal/channel/gemini_channel_test.go b/internal/channel/gemini_channel_test.go index 09a560844..a160e08ef 100644 --- a/internal/channel/gemini_channel_test.go +++ b/internal/channel/gemini_channel_test.go @@ -24,6 +24,21 @@ func TestGeminiBuildUpstreamURLPreservesDeveloperAPIPath(t *testing.T) { } } +func TestGeminiBuildUpstreamURLDoesNotDuplicateDeveloperAPIBasePath(t *testing.T) { + ch := newTestGeminiChannel(t, "https://generativelanguage.googleapis.com/v1beta") + originalURL := mustParseURL(t, "http://localhost:3001/proxy/gemini/v1beta/models/gemini-2.5-pro:generateContent?key=proxy-key") + + got, err := ch.BuildUpstreamURL(originalURL, "gemini") + if err != nil { + t.Fatalf("BuildUpstreamURL returned error: %v", err) + } + + want := "https://generativelanguage.googleapis.com/v1beta/models/gemini-2.5-pro:generateContent?key=proxy-key" + if got != want { + t.Fatalf("BuildUpstreamURL() = %q, want %q", got, want) + } +} + func TestGeminiBuildUpstreamURLConvertsNativePathForVertexPublisherBase(t *testing.T) { ch := newTestGeminiChannel(t, "https://aiplatform.googleapis.com/v1/publishers/google") originalURL := mustParseURL(t, "http://localhost:3001/proxy/gemini/v1beta/models/gemini-2.5-pro:streamGenerateContent?alt=sse&key=proxy-key")