Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
59 changes: 59 additions & 0 deletions pkg/embedding/probe.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,59 @@
package embedding

import (
	"bytes"
	"context"
	"encoding/json"
	"net/http"
	"strings"
	"time"
)

// probeTimeout bounds the v2 capability probe so a slow or unreachable proxy
// falls back to v1 quickly rather than stalling startup. ProbeV2 applies it as
// the http.Client timeout, so it covers the whole exchange (connecting, sending
// the request and reading the response body), independently of any deadline
// already carried by the caller's context.
const probeTimeout = 15 * time.Second

// ProbeV2 reports whether the proxy exposes the versioned /v2/embedding route
// and, if so, the embedding model it currently advertises.
//
// It POSTs a check request with an empty hash list to
// {proxyURL}/v2/embedding/check. The request is bounded by both ctx and
// probeTimeout. Trailing slashes on proxyURL are ignored, so a base URL such as
// "http://host/" targets the same endpoint as "http://host". When tokenFn is
// non-nil and returns a non-empty token, the token is sent as a Bearer
// Authorization header.
//
// A 200 response whose body decodes as JSON means v2 is available; the model
// field of that body is returned (it is empty if the proxy did not set it).
// Any non-200 status, transport error, or undecodable body yields ("", false),
// signalling the caller to fall back to the legacy /embed routes.
func ProbeV2(ctx context.Context, proxyURL string, tokenFn func() string) (string, bool) {
	// Hashes is a non-nil empty slice so it is encoded as [] rather than null.
	body, err := json.Marshal(embedCheckRequest{Hashes: []string{}})
	if err != nil {
		return "", false
	}

	// Normalize the base URL. Without this, "http://host/" would yield
	// "http://host//v2/embedding/check", which a proxy may answer with a 404 or
	// a redirect, making a v2-capable proxy look like a v1-only one.
	endpoint := strings.TrimRight(proxyURL, "/") + "/v2/embedding/check"

	req, err := http.NewRequestWithContext(ctx, http.MethodPost, endpoint, bytes.NewReader(body))
	if err != nil {
		return "", false
	}

	req.Header.Set("Content-Type", "application/json")

	// Authentication is optional: skip the header when no token is available.
	if tokenFn != nil {
		if token := tokenFn(); token != "" {
			req.Header.Set("Authorization", "Bearer "+token)
		}
	}

	client := &http.Client{Timeout: probeTimeout}

	resp, err := client.Do(req)
	if err != nil {
		return "", false
	}
	// The close error is irrelevant to the probe result, so it is discarded.
	defer func() { _ = resp.Body.Close() }()

	if resp.StatusCode != http.StatusOK {
		return "", false
	}

	var checkResp embedCheckResponse
	if err := json.NewDecoder(resp.Body).Decode(&checkResp); err != nil {
		return "", false
	}

	return checkResp.Model, true
}
96 changes: 72 additions & 24 deletions pkg/embedding/remote.go
Original file line number Diff line number Diff line change
Expand Up @@ -22,18 +22,21 @@ const (
maxBatchSize = 500
)

// embedCheckRequest is the request payload for the proxy /embed/check endpoint.
// embedCheckRequest is the request payload for the proxy embed-check endpoint.
type embedCheckRequest struct {
Model string `json:"model"`
Hashes []string `json:"hashes"`
}

// embedCheckResponse is the response from /embed/check.
// embedCheckResponse is the response from the embed-check endpoint. Model is
// populated by the v2 route (and empty on v1) so callers can observe which
// embedding model the proxy is currently serving.
type embedCheckResponse struct {
Model string `json:"model"`
Cached []embedResult `json:"cached"`
}

// embedRequest is the request payload for the proxy /embed endpoint.
// embedRequest is the request payload for the proxy embed endpoint.
type embedRequest struct {
Items []embedItem `json:"items"`
}
Expand All @@ -44,10 +47,13 @@ type embedItem struct {
Text string `json:"text"`
}

// embedResponse is the response payload from the proxy /embed endpoint.
// embedResponse is the response payload from the proxy embed endpoint. Both v1
// and v2 advertise the serving model; v2 additionally reports its fixed output
// dimensionality.
type embedResponse struct {
Results []embedResult `json:"results"`
Model string `json:"model"`
Results []embedResult `json:"results"`
Model string `json:"model"`
Dimensions int `json:"dimensions"`
}

// embedResult is a single embedding result.
Expand All @@ -56,8 +62,11 @@ type embedResult struct {
Vector []float32 `json:"vector"`
}

// RemoteEmbedder implements Embedder by calling the proxy's /embed endpoint.
// RemoteEmbedder implements Embedder by calling the proxy's embed endpoint.
// An optional local cache avoids round-trips to the proxy on warm restarts.
// When v2 is set it targets the versioned /v2/embedding routes (fp32 at a fixed
// dimensionality, model advertised per response); otherwise it uses the legacy
// /embed routes.
type RemoteEmbedder struct {
log logrus.FieldLogger
proxyURL string
Expand All @@ -66,26 +75,14 @@ type RemoteEmbedder struct {
invalidateFn func()
localCache cache.Cache
model string
v2 bool
progressFn func(completed, total int)
}

// OnProgress registers a callback invoked during EmbedBatch with the number of
// documents embedded so far and the total in the batch. It enables
// document-level progress reporting for index builds.
func (e *RemoteEmbedder) OnProgress(fn func(completed, total int)) {
e.progressFn = fn
}

func (e *RemoteEmbedder) reportProgress(completed, total int) {
if e.progressFn != nil {
e.progressFn(completed, total)
}
}

// Compile-time interface check.
var _ Embedder = (*RemoteEmbedder)(nil)

// NewRemote creates a new RemoteEmbedder that calls the proxy's /embed endpoint.
// NewRemote creates a RemoteEmbedder that calls the proxy's legacy /embed routes.
// tokenFn is called on each request to get the current auth token, and
// invalidateFn drops the cached token so a 401/403 can be retried with a fresh
// one (it may be nil to disable the retry).
Expand All @@ -98,6 +95,20 @@ func NewRemote(
invalidateFn func(),
localCache cache.Cache,
model string,
) *RemoteEmbedder {
return NewRemoteWithEndpoint(log, proxyURL, tokenFn, invalidateFn, localCache, model, false)
}

// NewRemoteWithEndpoint creates a RemoteEmbedder targeting either the v2
// (/v2/embedding) routes when v2 is true, or the legacy /embed routes otherwise.
func NewRemoteWithEndpoint(
log logrus.FieldLogger,
proxyURL string,
tokenFn func() string,
invalidateFn func(),
localCache cache.Cache,
model string,
v2 bool,
) *RemoteEmbedder {
return &RemoteEmbedder{
log: log.WithField("component", "remote-embedder"),
Expand All @@ -107,9 +118,24 @@ func NewRemote(
invalidateFn: invalidateFn,
localCache: localCache,
model: model,
v2: v2,
}
}

// Model returns the embedding model this embedder is keyed to, i.e. the model
// string it was constructed with (the same value that prefixes its local cache
// keys). Index builds use it to tag the embedding space they were built in, so
// a later change to the proxy's served model can be detected and trigger a
// re-index.
func (e *RemoteEmbedder) Model() string {
return e.model
}

// OnProgress registers a callback invoked during EmbedBatch with the number of
// documents embedded so far and the total in the batch. It enables
// document-level progress reporting for index builds.
//
// The callback is stored as-is and replaces any previously registered one;
// passing nil turns progress reporting off again.
func (e *RemoteEmbedder) OnProgress(fn func(completed, total int)) {
e.progressFn = fn
}

// Embed returns the L2-normalized embedding vector for a single text string.
func (e *RemoteEmbedder) Embed(text string) ([]float32, error) {
vectors, err := e.EmbedBatch([]string{text})
Expand Down Expand Up @@ -328,6 +354,28 @@ func (e *RemoteEmbedder) Close() error {
return nil
}

// reportProgress forwards the given progress counts to the registered progress
// callback. It does nothing when no callback has been registered.
func (e *RemoteEmbedder) reportProgress(completed, total int) {
	notify := e.progressFn
	if notify == nil {
		return
	}

	notify(completed, total)
}

// embedPath returns the proxy route used for embed requests: the versioned
// "/v2/embedding" route when the embedder is in v2 mode, and the legacy
// "/embed" route otherwise.
func (e *RemoteEmbedder) embedPath() string {
	route := "/embed"
	if e.v2 {
		route = "/v2/embedding"
	}

	return route
}

// checkPath returns the proxy route used for cache-check requests: the
// versioned "/v2/embedding/check" route when the embedder is in v2 mode, and
// the legacy "/embed/check" route otherwise.
func (e *RemoteEmbedder) checkPath() string {
	route := "/embed/check"
	if e.v2 {
		route = "/v2/embedding/check"
	}

	return route
}

// localCacheKey returns the local-cache key for textHash. The key has the form
// "{model}:{textHash}", so entries are namespaced by the model this embedder is
// keyed to.
func (e *RemoteEmbedder) localCacheKey(textHash string) string {
return e.model + ":" + textHash
}
Expand All @@ -345,7 +393,7 @@ func (e *RemoteEmbedder) queueLocalCache(toCache map[string][]byte, textHash str
toCache[e.localCacheKey(textHash)] = data
}

// embedDirect sends all items to /embed without checking the cache first.
// embedDirect sends all items to the embed route without checking the cache first.
func (e *RemoteEmbedder) embedDirect(
texts []string,
hashes []string,
Expand Down Expand Up @@ -418,7 +466,7 @@ func (e *RemoteEmbedder) checkCached(hashes []string) ([]embedResult, error) {
return nil, fmt.Errorf("marshaling check request: %w", err)
}

resp, err := e.doWithAuthRetry("/embed/check", reqBody)
resp, err := e.doWithAuthRetry(e.checkPath(), reqBody)
if err != nil {
return nil, fmt.Errorf("calling embed check: %w", err)
}
Expand All @@ -444,7 +492,7 @@ func (e *RemoteEmbedder) callEmbed(items []embedItem) (*embedResponse, error) {
return nil, fmt.Errorf("marshaling embed request: %w", err)
}

resp, err := e.doWithAuthRetry("/embed", reqBody)
resp, err := e.doWithAuthRetry(e.embedPath(), reqBody)
if err != nil {
return nil, fmt.Errorf("calling proxy embed: %w", err)
}
Expand Down
69 changes: 62 additions & 7 deletions pkg/proxy/embedding.go
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,22 @@ type EmbedResult struct {
Vector []float32 `json:"vector"`
}

// EmbedV2Response is the response payload from the /v2/embedding endpoint. It
// advertises the model and dimensions so clients can detect a model change and
// re-index. Vectors are fp32.
type EmbedV2Response struct {
// Model is the embedding model advertised with this response.
Model string `json:"model"`
// Dimensions is the advertised dimensionality of the returned vectors.
Dimensions int `json:"dimensions"`
// Results holds the embedding results.
Results []EmbedResult `json:"results"`
}

// EmbedV2CheckResponse is the response from /v2/embedding/check. It advertises
// the model so clients can detect a model change.
type EmbedV2CheckResponse struct {
// Model is the embedding model advertised with this response.
Model string `json:"model"`
// Cached holds the results reported as cached by the check.
Cached []EmbedResult `json:"cached"`
}

// EmbeddingService handles embedding requests using a remote API with caching.
type EmbeddingService struct {
log logrus.FieldLogger
Expand All @@ -68,15 +84,31 @@ type EmbeddingService struct {
apiURL string
client *http.Client
costPerToken float64
// dimensions, when > 0, requests a fixed output dimensionality from the
// embedding API (Matryoshka truncation). 0 leaves it unset (native dims).
dimensions int
}

// NewEmbeddingService creates a new EmbeddingService.
// If costPerToken is 0, the service fetches pricing from the API's /models endpoint.
// NewEmbeddingService creates a new EmbeddingService with native output
// dimensionality. If costPerToken is 0, the service fetches pricing from the
// API's /models endpoint.
func NewEmbeddingService(
log logrus.FieldLogger,
c cache.Cache,
apiKey, model, apiURL string,
costPerToken float64,
) *EmbeddingService {
return NewEmbeddingServiceWithDimensions(log, c, apiKey, model, apiURL, costPerToken, 0)
}

// NewEmbeddingServiceWithDimensions creates a new EmbeddingService that requests
// a fixed output dimensionality from the embedding API when dimensions > 0.
func NewEmbeddingServiceWithDimensions(
log logrus.FieldLogger,
c cache.Cache,
apiKey, model, apiURL string,
costPerToken float64,
dimensions int,
) *EmbeddingService {
svcLog := log.WithField("component", "embedding-service")
normalizedURL := strings.TrimRight(apiURL, "/")
Expand All @@ -103,6 +135,7 @@ func NewEmbeddingService(
apiURL: normalizedURL,
client: httpClient,
costPerToken: costPerToken,
dimensions: dimensions,
}
}

Expand All @@ -111,6 +144,24 @@ func (s *EmbeddingService) Model() string {
return s.model
}

// Dimensions returns the configured output dimensionality, or 0 for native.
// The value is the one the service was constructed with; when it is > 0 it is
// also sent to the embedding API and folded into the service's cache keys.
func (s *EmbeddingService) Dimensions() int {
return s.dimensions
}

// cacheKeyPrefix builds the namespace that is prepended to an item hash to form
// a cache key for this service.
//
// With a fixed output dimensionality (dimensions > 0) the prefix is
// "{model}:{dimensions}:", which keeps vectors of different sizes produced by
// the same model id from colliding. Otherwise the prefix is the plain
// "{model}:", so keys are identical to those built as model + ":" + hash and
// previously cached vectors remain addressable.
func (s *EmbeddingService) cacheKeyPrefix() string {
	prefix := s.model + ":"
	if s.dimensions <= 0 {
		return prefix
	}

	return prefix + strconv.Itoa(s.dimensions) + ":"
}

// Embed computes embeddings for the given items, using the cache where possible.
// Uncached items are sent to the upstream API in sub-batches of maxEmbedBatchSize.
func (s *EmbeddingService) Embed(ctx context.Context, items []EmbedItem) (*EmbedResponse, error) {
Expand All @@ -124,11 +175,11 @@ func (s *EmbeddingService) Embed(ctx context.Context, items []EmbedItem) (*Embed

s.log.WithField("items", len(items)).Info("Embed request received")

// Build cache keys: {model}:{hash}.
// Build cache keys: {model}:{hash} (or {model}:{dims}:{hash} when dims set).
cacheKeys := make([]string, len(items))

for i, item := range items {
cacheKeys[i] = s.model + ":" + item.Hash
cacheKeys[i] = s.cacheKeyPrefix() + item.Hash
}

// Check cache for existing vectors.
Expand Down Expand Up @@ -266,7 +317,7 @@ func (s *EmbeddingService) CheckCached(ctx context.Context, hashes []string) ([]

cacheKeys := make([]string, len(hashes))
for i, h := range hashes {
cacheKeys[i] = s.model + ":" + h
cacheKeys[i] = s.cacheKeyPrefix() + h
}

cached, err := s.cache.GetMulti(ctx, cacheKeys)
Expand Down Expand Up @@ -312,6 +363,9 @@ func (s *EmbeddingService) Close() error {
type openRouterRequest struct {
// Model is the embedding model id the request is made for.
Model string `json:"model"`
// Input is the batch of texts to embed.
Input []string `json:"input"`
// Dimensions requests a fixed output size (Matryoshka). Omitted when 0 so
// callers using native dimensionality send a byte-identical request.
Dimensions int `json:"dimensions,omitempty"`
}

// openRouterResponse is the response body from the OpenRouter embeddings API.
Expand All @@ -334,8 +388,9 @@ type openRouterEmbedding struct {

func (s *EmbeddingService) callEmbeddingAPI(ctx context.Context, texts []string) ([][]float32, *openRouterUsage, error) {
reqBody := openRouterRequest{
Model: s.model,
Input: texts,
Model: s.model,
Input: texts,
Dimensions: s.dimensions,
}

body, err := json.Marshal(reqBody)
Expand Down
Loading
Loading