diff --git a/config/default.yaml b/config/default.yaml index 794e55c..fed2a82 100644 --- a/config/default.yaml +++ b/config/default.yaml @@ -25,6 +25,8 @@ providers: vertexai: project: your-gcp-project location: us-central1 + claude: + api_key: your-claude-api-key scheduler: sync_batch_job_status_cron: "* * * * *" diff --git a/go.mod b/go.mod index 6899ebc..8235ab2 100644 --- a/go.mod +++ b/go.mod @@ -3,6 +3,7 @@ module github.com/ideagate/aigateway-core go 1.25.8 require ( + github.com/anthropics/anthropic-sdk-go v1.33.0 github.com/redis/go-redis/v9 v9.18.0 github.com/robfig/cron/v3 v3.0.1 github.com/spf13/cast v1.7.1 @@ -49,6 +50,10 @@ require ( github.com/spf13/pflag v1.0.6 // indirect github.com/stretchr/objx v0.5.2 // indirect github.com/subosito/gotenv v1.6.0 // indirect + github.com/tidwall/gjson v1.18.0 // indirect + github.com/tidwall/match v1.1.1 // indirect + github.com/tidwall/pretty v1.2.1 // indirect + github.com/tidwall/sjson v1.2.5 // indirect go.opentelemetry.io/auto/sdk v1.2.1 // indirect go.opentelemetry.io/contrib/instrumentation/net/http/otelhttp v0.54.0 // indirect go.opentelemetry.io/otel v1.39.0 // indirect diff --git a/go.sum b/go.sum index ba58fbd..5167658 100644 --- a/go.sum +++ b/go.sum @@ -6,6 +6,8 @@ cloud.google.com/go/compute/metadata v0.9.0 h1:pDUj4QMoPejqq20dK0Pg2N4yG9zIkYGdB cloud.google.com/go/compute/metadata v0.9.0/go.mod h1:E0bWwX5wTnLPedCKqk3pJmVgCBSM6qQI1yTBdEb3C10= filippo.io/edwards25519 v1.1.0 h1:FNf4tywRC1HmFuKW5xopWpigGjJKiJSV0Cqo0cJWDaA= filippo.io/edwards25519 v1.1.0/go.mod h1:BxyFTGdWcka3PhytdK4V28tE5sGfRvvvRV7EaN4VDT4= +github.com/anthropics/anthropic-sdk-go v1.33.0 h1:YlRqiI+PjULBA8NoeeJmkXtFzTmjZZA7oFvpZ4FU7eU= +github.com/anthropics/anthropic-sdk-go v1.33.0/go.mod h1:dSIO7kSrOI7MA4fE6RRVaw8tyWP7HNQU5/H/KS4cax8= github.com/bsm/ginkgo/v2 v2.12.0 h1:Ny8MWAHyOepLGlLKYmXG4IEkioBysk6GpaRTLC8zwWs= github.com/bsm/ginkgo/v2 v2.12.0/go.mod h1:SwYbGRRDovPVboqFv0tPTcG1sN61LM1Z4ARdbAV9g4c= github.com/bsm/gomega v1.27.10 h1:yeMWxP2pV2fG3FgAODIY8EiRE3dy0aeFYt4l7wh6yKA= @@ -17,6 +19,8 @@ github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= github.com/dgryski/go-rendezvous v0.0.0-20200823014737-9f7001d12a5f h1:lO4WD4F/rVNCu3HqELle0jiPLLBs70cWOduZpkS1E78= github.com/dgryski/go-rendezvous v0.0.0-20200823014737-9f7001d12a5f/go.mod h1:cuUVRXasLTGF7a8hSLbxyZXjz+1KgoB3wDUb6vlszIc= +github.com/dnaeon/go-vcr v1.2.0 h1:zHCHvJYTMh1N7xnV7zf1m1GPBF9Ad0Jk/whtQ1663qI= +github.com/dnaeon/go-vcr v1.2.0/go.mod h1:R4UdLID7HZT3taECzJs4YgbbH6PIGXB6W/sc5OLb6RQ= github.com/felixge/httpsnoop v1.0.4 h1:NFTV2Zj1bL4mc9sqWACXbQFVBBg2W3GPvqp8/ESS2Wg= github.com/felixge/httpsnoop v1.0.4/go.mod h1:m8KPJKqk1gH5J9DgRY2ASl2lWCfGKXixSwevea8zH2U= github.com/frankban/quicktest v1.14.6 h1:7Xjx+VpznH+oBnejlPUj8oUpdxnVs4f8XU8WnHkI4W8= @@ -95,6 +99,16 @@ github.com/stretchr/testify v1.11.1 h1:7s2iGBzp5EwR7/aIZr8ao5+dra3wiQyKjjFuvgVKu github.com/stretchr/testify v1.11.1/go.mod h1:wZwfW3scLgRK+23gO65QZefKpKQRnfz6sD981Nm4B6U= github.com/subosito/gotenv v1.6.0 h1:9NlTDc1FTs4qu0DDq7AEtTPNw6SVm7uBMsUCUjABIf8= github.com/subosito/gotenv v1.6.0/go.mod h1:Dk4QP5c2W3ibzajGcXpNraDfq2IrhjMIvMSWPKKo0FU= +github.com/tidwall/gjson v1.14.2/go.mod h1:/wbyibRr2FHMks5tjHJ5F8dMZh3AcwJEMf5vlfC0lxk= +github.com/tidwall/gjson v1.18.0 h1:FIDeeyB800efLX89e5a8Y0BNH+LOngJyGrIWxG2FKQY= +github.com/tidwall/gjson v1.18.0/go.mod h1:/wbyibRr2FHMks5tjHJ5F8dMZh3AcwJEMf5vlfC0lxk= +github.com/tidwall/match v1.1.1 h1:+Ho715JplO36QYgwN9PGYNhgZvoUSc9X2c80KVTi+GA= +github.com/tidwall/match v1.1.1/go.mod h1:eRSPERbgtNPcGhD8UCthc6PmLEQXEWd3PRB5JTxsfmM= +github.com/tidwall/pretty v1.2.0/go.mod h1:ITEVvHYasfjBbM0u2Pg8T2nJnzm8xPwvNhhsoaGGjNU= +github.com/tidwall/pretty v1.2.1 h1:qjsOFOWWQl+N3RsoF5/ssm1pHmJJwhjlSbZ51I6wMl4= +github.com/tidwall/pretty v1.2.1/go.mod h1:ITEVvHYasfjBbM0u2Pg8T2nJnzm8xPwvNhhsoaGGjNU= +github.com/tidwall/sjson v1.2.5 h1:kLy8mja+1c9jlljvWTlSazM7cKDRfJuR/bOJhcY5NcY= +github.com/tidwall/sjson v1.2.5/go.mod h1:Fvgq9kS/6ociJEDnK0Fk1cpYF4FIW6ZF7LAe+6jwd28= github.com/zeebo/xxh3 v1.0.2 h1:xZmwmqxHZA8AI603jOQ0tMqmBr9lPeFwGg6d+xy9DC0= github.com/zeebo/xxh3 v1.0.2/go.mod h1:5NWz9Sef7zIDm2JHfFlcQvNekmcEl9ekUZQQKCYaDcA= go.opentelemetry.io/auto/sdk v1.2.1 h1:jXsnJ4Lmnqd11kwkBV2LgLoFMZKizbCi5fNZ/ipaZ64= @@ -138,6 +152,8 @@ google.golang.org/protobuf v1.36.10/go.mod h1:HTf+CrKn2C3g5S8VImy6tdcUvCska2kB7j gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c h1:Hei/4ADfdWqJk1ZMxUNpqntNwaWcugrBjAiHlqqRiVk= gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c/go.mod h1:JHkPIbrfpd72SG/EVd6muEfDQjcINNoR0C8j2r3qZ4Q= +gopkg.in/yaml.v2 v2.2.8 h1:obN1ZagJSUGI0Ek/LBmuj4SNLPfIny3KsKFopxRdj10= +gopkg.in/yaml.v2 v2.2.8/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI= gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA= gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= diff --git a/internal/aigateway/providers/claude.go b/internal/aigateway/providers/claude.go new file mode 100644 index 0000000..ab3ea29 --- /dev/null +++ b/internal/aigateway/providers/claude.go @@ -0,0 +1,183 @@ +package providers + +import ( + "context" + "encoding/json" + "errors" + "fmt" + + "github.com/anthropics/anthropic-sdk-go" + "github.com/anthropics/anthropic-sdk-go/option" + "github.com/anthropics/anthropic-sdk-go/packages/param" + aigatewayv1 "github.com/ideagate/aigateway-core/gen/aigateway/v1" + "github.com/ideagate/aigateway-core/internal/aigateway/models" + models2 "github.com/ideagate/aigateway-core/models" +) + +const claudeDefaultMaxTokens = 4096 + +type ClaudeProvider struct { + client anthropic.Client +} + +func newClaudeProvider(apiKey string) (*ClaudeProvider, error) { + if apiKey == "" { + return nil, fmt.Errorf("claude api key is required") + } + + return &ClaudeProvider{ + client: anthropic.NewClient(option.WithAPIKey(apiKey)), + }, nil +} + +func (c *ClaudeProvider) Name() string { + return "claude" +} + +func (c *ClaudeProvider) ChatCompletion(ctx context.Context, request *models2.ChatCompletionRequest) (*models2.ChatCompletionResponse, error) { + params := anthropic.MessageNewParams{ + Model: anthropic.Model(request.Model), + MaxTokens: claudeDefaultMaxTokens, + Messages: []anthropic.MessageParam{ + anthropic.NewUserMessage(anthropic.NewTextBlock(request.Content.Text)), + }, + } + + if request.Temperature != 0 { + params.Temperature = param.NewOpt(float64(request.Temperature)) + } + + if request.SystemInstruction.Text != "" { + params.System = []anthropic.TextBlockParam{ + {Text: request.SystemInstruction.Text}, + } + } + + msg, err := c.client.Messages.New(ctx, params) + if err != nil { + return nil, fmt.Errorf("create message: %w", err) + } + + text := extractTextFromContent(msg.Content) + + return &models2.ChatCompletionResponse{ + Content: models2.Content{ + Text: text, + }, + }, nil +} + +func (c *ClaudeProvider) SubmitBatchJob(ctx context.Context, requests []*aigatewayv1.SubmitBulkChatCompletionsRequest) (*aigatewayv1.SubmitBulkChatCompletionsResponse, error) { + if len(requests) == 0 { + return nil, errors.New("no requests provided") + } + + batchRequests := make([]anthropic.MessageBatchNewParamsRequest, len(requests)) + for i, request := range requests { + reqParams := anthropic.MessageBatchNewParamsRequestParams{ + Model: anthropic.Model(request.GetModel()), + MaxTokens: claudeDefaultMaxTokens, + Messages: []anthropic.MessageParam{ + anthropic.NewUserMessage(anthropic.NewTextBlock(request.GetContent().GetText())), + }, + } + + if request.GetTemperature() != 0 { + reqParams.Temperature = param.NewOpt(float64(request.GetTemperature())) + } + + if request.GetSystemInstruction().GetText() != "" { + reqParams.System = []anthropic.TextBlockParam{ + {Text: request.GetSystemInstruction().GetText()}, + } + } + + batchRequests[i] = anthropic.MessageBatchNewParamsRequest{ + CustomID: fmt.Sprintf("%d", i), + Params: reqParams, + } + } + + batch, err := c.client.Messages.Batches.New(ctx, anthropic.MessageBatchNewParams{ + Requests: batchRequests, + }) + if err != nil { + return nil, fmt.Errorf("failed to create batch job: %w", err) + } + + return &aigatewayv1.SubmitBulkChatCompletionsResponse{ + JobId: batch.ID, + }, nil +} + +// GetBatchJobStatus polls the provider once and returns the current status. +// On completion the full individual responses slice is JSON-serialised into +// ResultsJSON so callers can persist it directly. +func (c *ClaudeProvider) GetBatchJobStatus(ctx context.Context, referenceID string) (*BatchJobStatusResult, error) { + batch, err := c.client.Messages.Batches.Get(ctx, referenceID) + if err != nil { + return nil, fmt.Errorf("get batch job %s: %w", referenceID, err) + } + + result := &BatchJobStatusResult{} + + switch batch.ProcessingStatus { + case anthropic.MessageBatchProcessingStatusEnded: + // Collect all individual results by streaming the results file. + stream := c.client.Messages.Batches.ResultsStreaming(ctx, referenceID) + var responses []anthropic.MessageBatchIndividualResponse + for stream.Next() { + responses = append(responses, stream.Current()) + } + if err := stream.Err(); err != nil { + return nil, fmt.Errorf("stream batch results %s: %w", referenceID, err) + } + + // Determine overall status and aggregate token counts. + // Batches with at least one succeeded request are marked completed, + // matching the GoogleProvider behaviour for partial success. Callers + // can inspect the individual ResultsJSON entries for per-request errors. + succeeded := batch.RequestCounts.Succeeded + if succeeded == 0 { + result.Status = models.BatchJobStatusFailed + } else { + result.Status = models.BatchJobStatusCompleted + } + + for _, r := range responses { + if r.Result.Type == "succeeded" { + usage := r.Result.Message.Usage + result.InputTokenCount += usage.InputTokens + result.OutputTokenCount += usage.OutputTokens + result.TotalTokenCount += usage.InputTokens + usage.OutputTokens + } + } + + resultsJSON, err := json.Marshal(responses) + if err != nil { + return nil, fmt.Errorf("marshal batch results: %w", err) + } + result.ResultsJSON = resultsJSON + + case anthropic.MessageBatchProcessingStatusCanceling: + result.Status = models.BatchJobStatusProcessing + + default: + // in_progress — still in flight. + result.Status = models.BatchJobStatusProcessing + } + + return result, nil +} + +// extractTextFromContent returns the concatenated text from all text-type +// content blocks in the message response. +func extractTextFromContent(content []anthropic.ContentBlockUnion) string { + var text string + for _, block := range content { + if block.Type == "text" { + text += block.Text + } + } + return text +} diff --git a/internal/aigateway/providers/claude_test.go b/internal/aigateway/providers/claude_test.go new file mode 100644 index 0000000..d8075cf --- /dev/null +++ b/internal/aigateway/providers/claude_test.go @@ -0,0 +1,80 @@ +package providers + +import ( + "testing" + + "github.com/anthropics/anthropic-sdk-go" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestNewClaudeProvider(t *testing.T) { + t.Run("empty api key returns error", func(t *testing.T) { + _, err := newClaudeProvider("") + assert.Error(t, err) + assert.Contains(t, err.Error(), "api key is required") + }) + + t.Run("valid api key returns provider", func(t *testing.T) { + p, err := newClaudeProvider("test-api-key") + require.NoError(t, err) + assert.NotNil(t, p) + assert.Equal(t, "claude", p.Name()) + }) +} + +func TestExtractTextFromContent(t *testing.T) { + tests := []struct { + name string + content []anthropic.ContentBlockUnion + wantText string + }{ + { + name: "nil slice", + content: nil, + wantText: "", + }, + { + name: "empty slice", + content: []anthropic.ContentBlockUnion{}, + wantText: "", + }, + { + name: "single text block", + content: []anthropic.ContentBlockUnion{ + {Type: "text", Text: "Hello, world!"}, + }, + wantText: "Hello, world!", + }, + { + name: "multiple text blocks concatenated", + content: []anthropic.ContentBlockUnion{ + {Type: "text", Text: "Hello"}, + {Type: "text", Text: ", world!"}, + }, + wantText: "Hello, world!", + }, + { + name: "non-text blocks skipped", + content: []anthropic.ContentBlockUnion{ + {Type: "thinking", Thinking: "some reasoning"}, + {Type: "text", Text: "Answer"}, + }, + wantText: "Answer", + }, + { + name: "no text blocks returns empty string", + content: []anthropic.ContentBlockUnion{ + {Type: "thinking", Thinking: "some reasoning"}, + }, + wantText: "", + }, + } + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + got := extractTextFromContent(tc.content) + assert.Equal(t, tc.wantText, got) + }) + } +} diff --git a/internal/aigateway/providers/providers.go b/internal/aigateway/providers/providers.go index 4333c4f..d2f833f 100644 --- a/internal/aigateway/providers/providers.go +++ b/internal/aigateway/providers/providers.go @@ -33,3 +33,7 @@ type Provider interface { func New(geminiCfg platformconfig.GeminiConfig) (Provider, error) { return newGoogleProvider(geminiCfg.APIKey) } + +func NewClaude(claudeCfg platformconfig.ClaudeConfig) (Provider, error) { + return newClaudeProvider(claudeCfg.APIKey) +} diff --git a/internal/platform/config/config.go b/internal/platform/config/config.go index eb28860..f03d4de 100644 --- a/internal/platform/config/config.go +++ b/internal/platform/config/config.go @@ -43,6 +43,7 @@ type RedisConfig struct { type ProvidersConfig struct { Gemini GeminiConfig `mapstructure:"gemini"` VertexAI VertexAIConfig `mapstructure:"vertexai"` + Claude ClaudeConfig `mapstructure:"claude"` } type GeminiConfig struct { @@ -54,6 +55,10 @@ type VertexAIConfig struct { Location string `mapstructure:"location"` } +type ClaudeConfig struct { + APIKey string `mapstructure:"api_key"` +} + type SchedulerConfig struct { SyncBatchJobStatusCron string `mapstructure:"sync_batch_job_status_cron"` }