Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions config/default.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,8 @@ providers:
vertexai:
project: your-gcp-project
location: us-central1
claude:
api_key: your-claude-api-key

scheduler:
sync_batch_job_status_cron: "* * * * *"
Expand Down
5 changes: 5 additions & 0 deletions go.mod
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ module github.com/ideagate/aigateway-core
go 1.25.8

require (
github.com/anthropics/anthropic-sdk-go v1.33.0
github.com/redis/go-redis/v9 v9.18.0
github.com/robfig/cron/v3 v3.0.1
github.com/spf13/cast v1.7.1
Expand Down Expand Up @@ -49,6 +50,10 @@ require (
github.com/spf13/pflag v1.0.6 // indirect
github.com/stretchr/objx v0.5.2 // indirect
github.com/subosito/gotenv v1.6.0 // indirect
github.com/tidwall/gjson v1.18.0 // indirect
github.com/tidwall/match v1.1.1 // indirect
github.com/tidwall/pretty v1.2.1 // indirect
github.com/tidwall/sjson v1.2.5 // indirect
go.opentelemetry.io/auto/sdk v1.2.1 // indirect
go.opentelemetry.io/contrib/instrumentation/net/http/otelhttp v0.54.0 // indirect
go.opentelemetry.io/otel v1.39.0 // indirect
Expand Down
16 changes: 16 additions & 0 deletions go.sum
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,8 @@ cloud.google.com/go/compute/metadata v0.9.0 h1:pDUj4QMoPejqq20dK0Pg2N4yG9zIkYGdB
cloud.google.com/go/compute/metadata v0.9.0/go.mod h1:E0bWwX5wTnLPedCKqk3pJmVgCBSM6qQI1yTBdEb3C10=
filippo.io/edwards25519 v1.1.0 h1:FNf4tywRC1HmFuKW5xopWpigGjJKiJSV0Cqo0cJWDaA=
filippo.io/edwards25519 v1.1.0/go.mod h1:BxyFTGdWcka3PhytdK4V28tE5sGfRvvvRV7EaN4VDT4=
github.com/anthropics/anthropic-sdk-go v1.33.0 h1:YlRqiI+PjULBA8NoeeJmkXtFzTmjZZA7oFvpZ4FU7eU=
github.com/anthropics/anthropic-sdk-go v1.33.0/go.mod h1:dSIO7kSrOI7MA4fE6RRVaw8tyWP7HNQU5/H/KS4cax8=
github.com/bsm/ginkgo/v2 v2.12.0 h1:Ny8MWAHyOepLGlLKYmXG4IEkioBysk6GpaRTLC8zwWs=
github.com/bsm/ginkgo/v2 v2.12.0/go.mod h1:SwYbGRRDovPVboqFv0tPTcG1sN61LM1Z4ARdbAV9g4c=
github.com/bsm/gomega v1.27.10 h1:yeMWxP2pV2fG3FgAODIY8EiRE3dy0aeFYt4l7wh6yKA=
Expand All @@ -17,6 +19,8 @@ github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c
github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
github.com/dgryski/go-rendezvous v0.0.0-20200823014737-9f7001d12a5f h1:lO4WD4F/rVNCu3HqELle0jiPLLBs70cWOduZpkS1E78=
github.com/dgryski/go-rendezvous v0.0.0-20200823014737-9f7001d12a5f/go.mod h1:cuUVRXasLTGF7a8hSLbxyZXjz+1KgoB3wDUb6vlszIc=
github.com/dnaeon/go-vcr v1.2.0 h1:zHCHvJYTMh1N7xnV7zf1m1GPBF9Ad0Jk/whtQ1663qI=
github.com/dnaeon/go-vcr v1.2.0/go.mod h1:R4UdLID7HZT3taECzJs4YgbbH6PIGXB6W/sc5OLb6RQ=
github.com/felixge/httpsnoop v1.0.4 h1:NFTV2Zj1bL4mc9sqWACXbQFVBBg2W3GPvqp8/ESS2Wg=
github.com/felixge/httpsnoop v1.0.4/go.mod h1:m8KPJKqk1gH5J9DgRY2ASl2lWCfGKXixSwevea8zH2U=
github.com/frankban/quicktest v1.14.6 h1:7Xjx+VpznH+oBnejlPUj8oUpdxnVs4f8XU8WnHkI4W8=
Expand Down Expand Up @@ -95,6 +99,16 @@ github.com/stretchr/testify v1.11.1 h1:7s2iGBzp5EwR7/aIZr8ao5+dra3wiQyKjjFuvgVKu
github.com/stretchr/testify v1.11.1/go.mod h1:wZwfW3scLgRK+23gO65QZefKpKQRnfz6sD981Nm4B6U=
github.com/subosito/gotenv v1.6.0 h1:9NlTDc1FTs4qu0DDq7AEtTPNw6SVm7uBMsUCUjABIf8=
github.com/subosito/gotenv v1.6.0/go.mod h1:Dk4QP5c2W3ibzajGcXpNraDfq2IrhjMIvMSWPKKo0FU=
github.com/tidwall/gjson v1.14.2/go.mod h1:/wbyibRr2FHMks5tjHJ5F8dMZh3AcwJEMf5vlfC0lxk=
github.com/tidwall/gjson v1.18.0 h1:FIDeeyB800efLX89e5a8Y0BNH+LOngJyGrIWxG2FKQY=
github.com/tidwall/gjson v1.18.0/go.mod h1:/wbyibRr2FHMks5tjHJ5F8dMZh3AcwJEMf5vlfC0lxk=
github.com/tidwall/match v1.1.1 h1:+Ho715JplO36QYgwN9PGYNhgZvoUSc9X2c80KVTi+GA=
github.com/tidwall/match v1.1.1/go.mod h1:eRSPERbgtNPcGhD8UCthc6PmLEQXEWd3PRB5JTxsfmM=
github.com/tidwall/pretty v1.2.0/go.mod h1:ITEVvHYasfjBbM0u2Pg8T2nJnzm8xPwvNhhsoaGGjNU=
github.com/tidwall/pretty v1.2.1 h1:qjsOFOWWQl+N3RsoF5/ssm1pHmJJwhjlSbZ51I6wMl4=
github.com/tidwall/pretty v1.2.1/go.mod h1:ITEVvHYasfjBbM0u2Pg8T2nJnzm8xPwvNhhsoaGGjNU=
github.com/tidwall/sjson v1.2.5 h1:kLy8mja+1c9jlljvWTlSazM7cKDRfJuR/bOJhcY5NcY=
github.com/tidwall/sjson v1.2.5/go.mod h1:Fvgq9kS/6ociJEDnK0Fk1cpYF4FIW6ZF7LAe+6jwd28=
github.com/zeebo/xxh3 v1.0.2 h1:xZmwmqxHZA8AI603jOQ0tMqmBr9lPeFwGg6d+xy9DC0=
github.com/zeebo/xxh3 v1.0.2/go.mod h1:5NWz9Sef7zIDm2JHfFlcQvNekmcEl9ekUZQQKCYaDcA=
go.opentelemetry.io/auto/sdk v1.2.1 h1:jXsnJ4Lmnqd11kwkBV2LgLoFMZKizbCi5fNZ/ipaZ64=
Expand Down Expand Up @@ -138,6 +152,8 @@ google.golang.org/protobuf v1.36.10/go.mod h1:HTf+CrKn2C3g5S8VImy6tdcUvCska2kB7j
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c h1:Hei/4ADfdWqJk1ZMxUNpqntNwaWcugrBjAiHlqqRiVk=
gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c/go.mod h1:JHkPIbrfpd72SG/EVd6muEfDQjcINNoR0C8j2r3qZ4Q=
gopkg.in/yaml.v2 v2.2.8 h1:obN1ZagJSUGI0Ek/LBmuj4SNLPfIny3KsKFopxRdj10=
gopkg.in/yaml.v2 v2.2.8/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI=
gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=
gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA=
gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=
Expand Down
183 changes: 183 additions & 0 deletions internal/aigateway/providers/claude.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,183 @@
package providers

import (
"context"
"encoding/json"
"errors"
"fmt"

"github.com/anthropics/anthropic-sdk-go"
"github.com/anthropics/anthropic-sdk-go/option"
"github.com/anthropics/anthropic-sdk-go/packages/param"
aigatewayv1 "github.com/ideagate/aigateway-core/gen/aigateway/v1"
"github.com/ideagate/aigateway-core/internal/aigateway/models"
models2 "github.com/ideagate/aigateway-core/models"
)

// claudeDefaultMaxTokens is the MaxTokens value set on every Claude request
// built in this file, for both single chat completions and batch entries.
// No per-request value is read from the incoming request.
const claudeDefaultMaxTokens = 4096

// ClaudeProvider issues chat-completion and batch requests to Claude through
// the anthropic-sdk-go client.
type ClaudeProvider struct {
	// client is the SDK client; newClaudeProvider builds it with the API key.
	client anthropic.Client
}

// newClaudeProvider constructs a ClaudeProvider whose SDK client is configured
// with apiKey. An empty apiKey is rejected with an error and no provider is
// returned.
func newClaudeProvider(apiKey string) (*ClaudeProvider, error) {
	if len(apiKey) == 0 {
		return nil, errors.New("claude api key is required")
	}

	provider := &ClaudeProvider{}
	provider.client = anthropic.NewClient(option.WithAPIKey(apiKey))
	return provider, nil
}

// Name reports the identifier of this provider, which is always "claude".
func (c *ClaudeProvider) Name() string {
	const providerName = "claude"
	return providerName
}

// ChatCompletion sends a single-turn chat request to Claude and returns the
// concatenated text of the reply.
//
// The request text becomes one user message. A non-empty system instruction is
// forwarded as the system prompt. MaxTokens is always claudeDefaultMaxTokens.
// Temperature is forwarded only when it is non-zero, so a zero value leaves the
// parameter unset in the outgoing request.
//
// It returns an error when request is nil or when the API call fails; the API
// error is wrapped so callers can still inspect it with errors.Is / errors.As.
func (c *ClaudeProvider) ChatCompletion(ctx context.Context, request *models2.ChatCompletionRequest) (*models2.ChatCompletionResponse, error) {
	// Guard against a nil request: the field reads below would otherwise panic.
	if request == nil {
		return nil, errors.New("chat completion request is required")
	}

	params := anthropic.MessageNewParams{
		Model:     anthropic.Model(request.Model),
		MaxTokens: claudeDefaultMaxTokens,
		Messages: []anthropic.MessageParam{
			anthropic.NewUserMessage(anthropic.NewTextBlock(request.Content.Text)),
		},
	}

	// Zero is treated as "not set"; only an explicit non-zero value is sent.
	if request.Temperature != 0 {
		params.Temperature = param.NewOpt(float64(request.Temperature))
	}

	if request.SystemInstruction.Text != "" {
		params.System = []anthropic.TextBlockParam{
			{Text: request.SystemInstruction.Text},
		}
	}

	msg, err := c.client.Messages.New(ctx, params)
	if err != nil {
		return nil, fmt.Errorf("create message: %w", err)
	}

	// Only text blocks contribute to the reply; other block types are dropped.
	text := extractTextFromContent(msg.Content)

	return &models2.ChatCompletionResponse{
		Content: models2.Content{
			Text: text,
		},
	}, nil
}

// SubmitBatchJob turns each incoming request into one entry of a Claude message
// batch and submits the batch in a single API call.
//
// Every entry's CustomID is the decimal index of the request within requests.
// Temperature and the system instruction are attached only when they are
// non-zero / non-empty. An empty requests slice is rejected with an error. On
// success the response carries the batch ID as JobId.
func (c *ClaudeProvider) SubmitBatchJob(ctx context.Context, requests []*aigatewayv1.SubmitBulkChatCompletionsRequest) (*aigatewayv1.SubmitBulkChatCompletionsResponse, error) {
	if len(requests) == 0 {
		return nil, errors.New("no requests provided")
	}

	entries := make([]anthropic.MessageBatchNewParamsRequest, 0, len(requests))
	for idx, req := range requests {
		userText := req.GetContent().GetText()
		entry := anthropic.MessageBatchNewParamsRequest{
			CustomID: fmt.Sprintf("%d", idx),
			Params: anthropic.MessageBatchNewParamsRequestParams{
				Model:     anthropic.Model(req.GetModel()),
				MaxTokens: claudeDefaultMaxTokens,
				Messages: []anthropic.MessageParam{
					anthropic.NewUserMessage(anthropic.NewTextBlock(userText)),
				},
			},
		}

		if temperature := req.GetTemperature(); temperature != 0 {
			entry.Params.Temperature = param.NewOpt(float64(temperature))
		}

		if systemText := req.GetSystemInstruction().GetText(); systemText != "" {
			entry.Params.System = []anthropic.TextBlockParam{{Text: systemText}}
		}

		entries = append(entries, entry)
	}

	batchParams := anthropic.MessageBatchNewParams{Requests: entries}
	batch, err := c.client.Messages.Batches.New(ctx, batchParams)
	if err != nil {
		return nil, fmt.Errorf("failed to create batch job: %w", err)
	}

	return &aigatewayv1.SubmitBulkChatCompletionsResponse{JobId: batch.ID}, nil
}

// GetBatchJobStatus polls the provider once for the batch identified by
// referenceID and returns its current status.
//
// While the batch has not ended (including while it is canceling) the status
// is BatchJobStatusProcessing and no results are attached. Once the batch has
// ended, every individual response is read from the results stream,
// token usage of the succeeded responses is summed into the result, and the
// full responses slice is JSON-serialised into ResultsJSON so callers can
// persist it directly.
//
// An error is returned when the batch lookup fails, when reading the results
// stream fails, or when the responses cannot be marshalled.
func (c *ClaudeProvider) GetBatchJobStatus(ctx context.Context, referenceID string) (*BatchJobStatusResult, error) {
	batch, err := c.client.Messages.Batches.Get(ctx, referenceID)
	if err != nil {
		return nil, fmt.Errorf("get batch job %s: %w", referenceID, err)
	}

	result := &BatchJobStatusResult{}

	switch batch.ProcessingStatus {
	case anthropic.MessageBatchProcessingStatusEnded:
		// Collect all individual results by streaming the results file.
		stream := c.client.Messages.Batches.ResultsStreaming(ctx, referenceID)
		// Release the underlying response on every exit path, including the
		// early error returns below. A close error is deliberately ignored:
		// by then the data has either been read or the read error is returned.
		defer stream.Close()

		var responses []anthropic.MessageBatchIndividualResponse
		for stream.Next() {
			responses = append(responses, stream.Current())
		}
		if err := stream.Err(); err != nil {
			return nil, fmt.Errorf("stream batch results %s: %w", referenceID, err)
		}

		// Determine overall status and aggregate token counts.
		// Batches with at least one succeeded request are marked completed,
		// matching the GoogleProvider behaviour for partial success. Callers
		// can inspect the individual ResultsJSON entries for per-request errors.
		if batch.RequestCounts.Succeeded == 0 {
			result.Status = models.BatchJobStatusFailed
		} else {
			result.Status = models.BatchJobStatusCompleted
		}

		// Only succeeded entries carry a message with usage figures.
		for _, r := range responses {
			if r.Result.Type != "succeeded" {
				continue
			}
			usage := r.Result.Message.Usage
			result.InputTokenCount += usage.InputTokens
			result.OutputTokenCount += usage.OutputTokens
			result.TotalTokenCount += usage.InputTokens + usage.OutputTokens
		}

		resultsJSON, err := json.Marshal(responses)
		if err != nil {
			return nil, fmt.Errorf("marshal batch results: %w", err)
		}
		result.ResultsJSON = resultsJSON

	default:
		// in_progress or canceling — the batch has not ended yet, so it is
		// still reported as processing.
		result.Status = models.BatchJobStatusProcessing
	}

	return result, nil
}

// extractTextFromContent joins, in order, the Text of every block whose Type
// is "text". Blocks of any other type are ignored; a nil or empty slice yields
// the empty string.
func extractTextFromContent(content []anthropic.ContentBlockUnion) string {
	joined := ""
	for i := range content {
		if content[i].Type != "text" {
			continue
		}
		joined += content[i].Text
	}
	return joined
}
80 changes: 80 additions & 0 deletions internal/aigateway/providers/claude_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,80 @@
package providers

import (
"testing"

"github.com/anthropics/anthropic-sdk-go"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)

// TestNewClaudeProvider checks that newClaudeProvider rejects an empty API key
// and otherwise returns a provider named "claude".
func TestNewClaudeProvider(t *testing.T) {
	t.Run("empty api key returns error", func(t *testing.T) {
		_, err := newClaudeProvider("")
		// require (not assert): a nil err must stop the subtest here rather
		// than panic on the err.Error() call below.
		require.Error(t, err)
		assert.Contains(t, err.Error(), "api key is required")
	})

	t.Run("valid api key returns provider", func(t *testing.T) {
		p, err := newClaudeProvider("test-api-key")
		require.NoError(t, err)
		// Stop before calling a method on a provider that was not returned.
		require.NotNil(t, p)
		assert.Equal(t, "claude", p.Name())
	})
}

// TestExtractTextFromContent runs extractTextFromContent over a table of
// content-block slices and checks the concatenated text it returns.
func TestExtractTextFromContent(t *testing.T) {
	// Small constructors keep the table rows short.
	textBlock := func(s string) anthropic.ContentBlockUnion {
		return anthropic.ContentBlockUnion{Type: "text", Text: s}
	}
	thinkingBlock := func(s string) anthropic.ContentBlockUnion {
		return anthropic.ContentBlockUnion{Type: "thinking", Thinking: s}
	}

	type scenario struct {
		label  string
		blocks []anthropic.ContentBlockUnion
		want   string
	}

	scenarios := []scenario{
		{label: "nil slice", blocks: nil, want: ""},
		{label: "empty slice", blocks: []anthropic.ContentBlockUnion{}, want: ""},
		{
			label:  "single text block",
			blocks: []anthropic.ContentBlockUnion{textBlock("Hello, world!")},
			want:   "Hello, world!",
		},
		{
			label: "multiple text blocks concatenated",
			blocks: []anthropic.ContentBlockUnion{
				textBlock("Hello"),
				textBlock(", world!"),
			},
			want: "Hello, world!",
		},
		{
			label: "non-text blocks skipped",
			blocks: []anthropic.ContentBlockUnion{
				thinkingBlock("some reasoning"),
				textBlock("Answer"),
			},
			want: "Answer",
		},
		{
			label:  "no text blocks returns empty string",
			blocks: []anthropic.ContentBlockUnion{thinkingBlock("some reasoning")},
			want:   "",
		},
	}

	for _, sc := range scenarios {
		t.Run(sc.label, func(t *testing.T) {
			assert.Equal(t, sc.want, extractTextFromContent(sc.blocks))
		})
	}
}
4 changes: 4 additions & 0 deletions internal/aigateway/providers/providers.go
Original file line number Diff line number Diff line change
Expand Up @@ -33,3 +33,7 @@ type Provider interface {
// New builds the Google-backed Provider from the API key held in geminiCfg,
// returning whatever newGoogleProvider returns.
func New(geminiCfg platformconfig.GeminiConfig) (Provider, error) {
	apiKey := geminiCfg.APIKey
	return newGoogleProvider(apiKey)
}

// NewClaude builds the Claude-backed Provider from the API key in claudeCfg.
//
// It returns a nil Provider together with the error when construction fails
// (for example when the API key is empty).
func NewClaude(claudeCfg platformconfig.ClaudeConfig) (Provider, error) {
	provider, err := newClaudeProvider(claudeCfg.APIKey)
	if err != nil {
		// Return a literal nil: passing the nil *ClaudeProvider through would
		// yield a non-nil Provider interface wrapping a nil pointer.
		return nil, err
	}
	return provider, nil
}
5 changes: 5 additions & 0 deletions internal/platform/config/config.go
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,7 @@ type RedisConfig struct {
// ProvidersConfig groups the settings of each AI provider. Every field is
// decoded from the configuration key named in its mapstructure tag.
type ProvidersConfig struct {
	// Gemini holds the Gemini provider settings.
	Gemini GeminiConfig `mapstructure:"gemini"`
	// VertexAI holds the Vertex AI provider settings.
	VertexAI VertexAIConfig `mapstructure:"vertexai"`
	// Claude holds the Claude provider settings.
	Claude ClaudeConfig `mapstructure:"claude"`
}

type GeminiConfig struct {
Expand All @@ -54,6 +55,10 @@ type VertexAIConfig struct {
Location string `mapstructure:"location"`
}

// ClaudeConfig holds the settings of the Claude provider.
type ClaudeConfig struct {
	// APIKey is the Claude API key, decoded from the "api_key" key.
	APIKey string `mapstructure:"api_key"`
}

// SchedulerConfig holds the scheduler settings.
type SchedulerConfig struct {
	// SyncBatchJobStatusCron is decoded from "sync_batch_job_status_cron".
	// NOTE(review): presumably the cron expression that schedules the batch
	// job status sync (the default config uses "* * * * *") — confirm against
	// the scheduler code.
	SyncBatchJobStatusCron string `mapstructure:"sync_batch_job_status_cron"`
}
Expand Down