Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
129 changes: 121 additions & 8 deletions internal/channel/gemini_channel.go
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,29 @@ func newGeminiChannel(f *Factory, group *models.Group) (ChannelProxy, error) {
}, nil
}

// BuildUpstreamURL constructs the target URL for Gemini requests.
func (ch *GeminiChannel) BuildUpstreamURL(originalURL *url.URL, groupName string) (string, error) {
base := ch.getUpstreamURL()
if base == nil {
return "", fmt.Errorf("no upstream URL configured for channel %s", ch.Name)
}

finalURL := *base
requestPath := trimProxyGroupPrefix(originalURL.Path, groupName)

if publisherBasePath, ok := vertexPublisherBasePath(base); ok {
finalURL.Path = buildVertexPublisherPath(publisherBasePath, requestPath)
} else if hasPathPrefix(requestPath, base.Path) {
finalURL.Path = ensureLeadingSlash(requestPath)
} else {
finalURL.Path = joinURLPath(base.Path, requestPath)
Comment thread
coderabbitai[bot] marked this conversation as resolved.
}

finalURL.RawQuery = originalURL.RawQuery

return finalURL.String(), nil
}

// ModifyRequest adds the API key as a query parameter for Gemini requests.
func (ch *GeminiChannel) ModifyRequest(req *http.Request, apiKey *models.APIKey, group *models.Group) {
if strings.Contains(req.URL.Path, "v1beta/openai") {
Expand Down Expand Up @@ -98,17 +121,12 @@ func (ch *GeminiChannel) ExtractModel(c *gin.Context, bodyBytes []byte) string {

// ValidateKey checks if the given API key is valid by making a generateContent request.
func (ch *GeminiChannel) ValidateKey(ctx context.Context, apiKey *models.APIKey, group *models.Group) (bool, error) {
upstreamURL := ch.getUpstreamURL()
if upstreamURL == nil {
return false, fmt.Errorf("no upstream URL configured for channel %s", ch.Name)
}

// Safely join the path segments
reqURL, err := url.JoinPath(upstreamURL.String(), "v1beta", "models", ch.TestModel+":generateContent")
reqURL, err := ch.BuildUpstreamURL(&url.URL{
Path: "/proxy/" + group.Name + "/v1beta/models/" + ch.TestModel + ":generateContent",
}, group.Name)
if err != nil {
return false, fmt.Errorf("failed to create gemini validation path: %w", err)
}
reqURL += "?key=" + apiKey.KeyValue

payload := gin.H{
"contents": []gin.H{
Expand All @@ -130,6 +148,7 @@ func (ch *GeminiChannel) ValidateKey(ctx context.Context, apiKey *models.APIKey,
return false, fmt.Errorf("failed to create validation request: %w", err)
}
req.Header.Set("Content-Type", "application/json")
ch.ModifyRequest(req, apiKey, group)

// Apply custom header rules if available
if len(group.HeaderRuleList) > 0 {
Expand Down Expand Up @@ -343,3 +362,97 @@ func isFirstPage(req *http.Request) bool {
pageToken := req.URL.Query().Get("pageToken")
return pageToken == ""
}

func trimProxyGroupPrefix(requestPath, groupName string) string {
proxyPrefix := "/proxy/" + groupName
return strings.TrimPrefix(requestPath, proxyPrefix)
}

func vertexPublisherBasePath(base *url.URL) (string, bool) {
basePath := normalizeBasePath(base.Path)
if basePath == "/v1/publishers/google" || basePath == "/v1beta1/publishers/google" {
return basePath, true
}

if !strings.EqualFold(base.Hostname(), "aiplatform.googleapis.com") {
return "", false
}

switch basePath {
case "/", "/v1":
return "/v1/publishers/google", true
case "/v1beta1":
return "/v1beta1/publishers/google", true
default:
return "", false
}
}

func buildVertexPublisherPath(basePath, requestPath string) string {
if hasPathPrefix(requestPath, basePath) {
return ensureLeadingSlash(requestPath)
}

if modelPath, ok := geminiNativeModelPath(requestPath); ok {
return joinURLPath(basePath, modelPath)
}

return joinURLPath(basePath, requestPath)
}

func geminiNativeModelPath(requestPath string) (string, bool) {
parts := strings.Split(strings.TrimLeft(requestPath, "/"), "/")
if len(parts) == 0 {
return "", false
}

if parts[0] == "models" {
return strings.Join(parts, "/"), true
}

if len(parts) >= 2 && isGeminiNativeVersion(parts[0]) && parts[1] == "models" {
return strings.Join(parts[1:], "/"), true
}

return "", false
}

func isGeminiNativeVersion(segment string) bool {
return segment == "v1" || segment == "v1beta" || segment == "v1beta1"
}

func joinURLPath(basePath, requestPath string) string {
basePath = ensureLeadingSlash(basePath)
requestPath = strings.TrimLeft(requestPath, "/")
if requestPath == "" {
return basePath
}
if basePath == "/" {
return "/" + requestPath
}
return strings.TrimRight(basePath, "/") + "/" + requestPath
}

func hasPathPrefix(pathValue, prefix string) bool {
pathValue = ensureLeadingSlash(pathValue)
prefix = strings.TrimRight(ensureLeadingSlash(prefix), "/")
return pathValue == prefix || strings.HasPrefix(pathValue, prefix+"/")
}

func ensureLeadingSlash(pathValue string) string {
if pathValue == "" {
return "/"
}
if strings.HasPrefix(pathValue, "/") {
return pathValue
}
return "/" + pathValue
}

func normalizeBasePath(pathValue string) string {
normalized := strings.TrimRight(ensureLeadingSlash(pathValue), "/")
if normalized == "" {
return "/"
}
return normalized
}
165 changes: 165 additions & 0 deletions internal/channel/gemini_channel_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,165 @@
package channel

import (
"context"
"gpt-load/internal/models"
"net/http"
"net/http/httptest"
"net/url"
"testing"
)

func TestGeminiBuildUpstreamURLPreservesDeveloperAPIPath(t *testing.T) {
ch := newTestGeminiChannel(t, "https://generativelanguage.googleapis.com")
originalURL := mustParseURL(t, "http://localhost:3001/proxy/gemini/v1beta/models/gemini-2.5-pro:generateContent?key=proxy-key")

got, err := ch.BuildUpstreamURL(originalURL, "gemini")
if err != nil {
t.Fatalf("BuildUpstreamURL returned error: %v", err)
}

want := "https://generativelanguage.googleapis.com/v1beta/models/gemini-2.5-pro:generateContent?key=proxy-key"
if got != want {
t.Fatalf("BuildUpstreamURL() = %q, want %q", got, want)
}
}

func TestGeminiBuildUpstreamURLDoesNotDuplicateDeveloperAPIBasePath(t *testing.T) {
ch := newTestGeminiChannel(t, "https://generativelanguage.googleapis.com/v1beta")
originalURL := mustParseURL(t, "http://localhost:3001/proxy/gemini/v1beta/models/gemini-2.5-pro:generateContent?key=proxy-key")

got, err := ch.BuildUpstreamURL(originalURL, "gemini")
if err != nil {
t.Fatalf("BuildUpstreamURL returned error: %v", err)
}

want := "https://generativelanguage.googleapis.com/v1beta/models/gemini-2.5-pro:generateContent?key=proxy-key"
if got != want {
t.Fatalf("BuildUpstreamURL() = %q, want %q", got, want)
}
}

func TestGeminiBuildUpstreamURLConvertsNativePathForVertexPublisherBase(t *testing.T) {
ch := newTestGeminiChannel(t, "https://aiplatform.googleapis.com/v1/publishers/google")
originalURL := mustParseURL(t, "http://localhost:3001/proxy/gemini/v1beta/models/gemini-2.5-pro:streamGenerateContent?alt=sse&key=proxy-key")

got, err := ch.BuildUpstreamURL(originalURL, "gemini")
if err != nil {
t.Fatalf("BuildUpstreamURL returned error: %v", err)
}

want := "https://aiplatform.googleapis.com/v1/publishers/google/models/gemini-2.5-pro:streamGenerateContent?alt=sse&key=proxy-key"
if got != want {
t.Fatalf("BuildUpstreamURL() = %q, want %q", got, want)
}
}

func TestGeminiBuildUpstreamURLConvertsNativePathForBareAiplatformBase(t *testing.T) {
ch := newTestGeminiChannel(t, "https://aiplatform.googleapis.com/")
originalURL := mustParseURL(t, "http://localhost:3001/proxy/gemini/v1beta/models/gemini-2.5-pro:generateContent?key=proxy-key")

got, err := ch.BuildUpstreamURL(originalURL, "gemini")
if err != nil {
t.Fatalf("BuildUpstreamURL returned error: %v", err)
}

want := "https://aiplatform.googleapis.com/v1/publishers/google/models/gemini-2.5-pro:generateContent?key=proxy-key"
if got != want {
t.Fatalf("BuildUpstreamURL() = %q, want %q", got, want)
}
}

func TestGeminiBuildUpstreamURLDoesNotDuplicateVertexPublisherPath(t *testing.T) {
ch := newTestGeminiChannel(t, "https://aiplatform.googleapis.com/v1/publishers/google")
originalURL := mustParseURL(t, "http://localhost:3001/proxy/gemini/v1/publishers/google/models/gemini-2.5-flash:generateContent")

got, err := ch.BuildUpstreamURL(originalURL, "gemini")
if err != nil {
t.Fatalf("BuildUpstreamURL returned error: %v", err)
}

want := "https://aiplatform.googleapis.com/v1/publishers/google/models/gemini-2.5-flash:generateContent"
if got != want {
t.Fatalf("BuildUpstreamURL() = %q, want %q", got, want)
}
}

func TestGeminiApplyModelRedirectWorksWithVertexPublisherPath(t *testing.T) {
ch := newTestGeminiChannel(t, "https://aiplatform.googleapis.com/v1/publishers/google")
req := &http.Request{
URL: mustParseURL(t, "https://aiplatform.googleapis.com/v1/publishers/google/models/source-model:generateContent"),
}
group := &models.Group{
Name: "gemini",
ModelRedirectMap: map[string]string{"source-model": "target-model"},
}
body := []byte(`{"contents":[]}`)

gotBody, err := ch.ApplyModelRedirect(req, body, group)
if err != nil {
t.Fatalf("ApplyModelRedirect returned error: %v", err)
}
if string(gotBody) != string(body) {
t.Fatalf("ApplyModelRedirect body = %s, want %s", gotBody, body)
}

wantPath := "/v1/publishers/google/models/target-model:generateContent"
if req.URL.Path != wantPath {
t.Fatalf("redirected path = %q, want %q", req.URL.Path, wantPath)
}
}

func TestGeminiValidateKeyUsesVertexPublisherPath(t *testing.T) {
var gotPath string
var gotKey string
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
gotPath = r.URL.Path
gotKey = r.URL.Query().Get("key")
w.WriteHeader(http.StatusOK)
}))
t.Cleanup(server.Close)

ch := newTestGeminiChannel(t, server.URL+"/v1/publishers/google")
ch.HTTPClient = server.Client()
ch.TestModel = "gemini-test"

ok, err := ch.ValidateKey(context.Background(), &models.APIKey{KeyValue: "secret-key"}, &models.Group{Name: "gemini"})
if err != nil {
t.Fatalf("ValidateKey returned error: %v", err)
}
if !ok {
t.Fatal("ValidateKey returned false, want true")
}

wantPath := "/v1/publishers/google/models/gemini-test:generateContent"
if gotPath != wantPath {
t.Fatalf("validation path = %q, want %q", gotPath, wantPath)
}
if gotKey != "secret-key" {
t.Fatalf("validation key = %q, want %q", gotKey, "secret-key")
}
}

func newTestGeminiChannel(t *testing.T, upstream string) *GeminiChannel {
t.Helper()

upstreamURL := mustParseURL(t, upstream)
return &GeminiChannel{
BaseChannel: &BaseChannel{
Name: "gemini",
Upstreams: []UpstreamInfo{{URL: upstreamURL, Weight: 1}},
HTTPClient: http.DefaultClient,
TestModel: "gemini-2.0-flash-lite",
},
}
}

func mustParseURL(t *testing.T, rawURL string) *url.URL {
t.Helper()

parsed, err := url.Parse(rawURL)
if err != nil {
t.Fatalf("failed to parse url %q: %v", rawURL, err)
}
return parsed
}