Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 7 additions & 6 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -169,12 +169,13 @@ tests:

### LLM Provider Support

| Provider | Config | API Key |
|----------|--------|---------|
| Claude | `provider: claude` | `ANTHROPIC_API_KEY` |
| OpenAI | `provider: openai` | `OPENAI_API_KEY` |
| Ollama (local) | `provider: local` | Not required |
| Custom | `baseUrl: http://...` | Optional |
| Provider | Config | API Key | Default Model |
|----------|--------|---------|---------------|
| Claude | `provider: claude` | `ANTHROPIC_API_KEY` | `claude-sonnet-4-20250514` |
| OpenAI | `provider: openai` | `OPENAI_API_KEY` | `gpt-4o-mini` |
| Gemini | `provider: gemini` | `GOOGLE_API_KEY` or `GEMINI_API_KEY` | `gemini-2.0-flash` |
| Ollama (local) | `provider: local` | Not required | `llama3` |
| Custom | `baseUrl: http://...` | Optional | — |

## Deployment

Expand Down
33 changes: 30 additions & 3 deletions cli/src/main/scala/com/dataweaver/cli/ai/LLMClient.scala
Original file line number Diff line number Diff line change
Expand Up @@ -27,9 +27,10 @@ object LLMClient {
}

config.provider match {
case "claude" => callClaude(prompt, config)
case "openai" => callOpenAI(prompt, config)
case other => Left(s"Unknown AI provider '$other'. Supported: claude, openai")
case "claude" => callClaude(prompt, config)
case "openai" => callOpenAI(prompt, config)
case "gemini" | "vertex-ai" => callGemini(prompt, config)
case other => Left(s"Unknown AI provider '$other'. Supported: claude, openai, gemini")
}
}

Expand Down Expand Up @@ -68,6 +69,23 @@ object LLMClient {
executeRequest(request).flatMap(extractOpenAIResponse)
}

private def callGemini(prompt: String, config: WeaverConfig.AIConfig): Either[String, String] = {
val body = s"""{
"contents": [{"parts": [{"text": ${escapeJson(prompt)}}]}],
"generationConfig": {"maxOutputTokens": 4096}
}"""

val url = s"https://generativelanguage.googleapis.com/v1beta/models/${config.model}:generateContent?key=${config.apiKey}"

val request = HttpRequest.newBuilder()
.uri(URI.create(url))
.header("Content-Type", "application/json")
.POST(HttpRequest.BodyPublishers.ofString(body))
.build()

executeRequest(request).flatMap(extractGeminiResponse)
}

private def executeRequest(request: HttpRequest): Either[String, String] = {
Try(httpClient.send(request, HttpResponse.BodyHandlers.ofString())) match {
case Success(response) if response.statusCode() >= 200 && response.statusCode() < 300 =>
Expand Down Expand Up @@ -98,6 +116,15 @@ object LLMClient {
}
}

/** Extract text from Gemini API response: candidates[0].content.parts[0].text */
private def extractGeminiResponse(json: String): Either[String, String] = {
val textPattern = """"text"\s*:\s*"((?:[^"\\]|\\.)*)"""".r
textPattern.findFirstMatchIn(json) match {
case Some(m) => Right(unescapeJson(m.group(1)))
case None => Left(s"Cannot parse Gemini response: ${json.take(500)}")
}
}

private def escapeJson(s: String): String = {
"\"" + s.replace("\\", "\\\\").replace("\"", "\\\"").replace("\n", "\\n").replace("\r", "\\r").replace("\t", "\\t") + "\""
}
Expand Down
7 changes: 4 additions & 3 deletions cli/src/main/scala/com/dataweaver/cli/ai/WeaverConfig.scala
Original file line number Diff line number Diff line change
Expand Up @@ -71,9 +71,10 @@ object WeaverConfig {
case _ =>
// Fallback to standard env vars
provider match {
case "claude" => sys.env.getOrElse("ANTHROPIC_API_KEY", "")
case "openai" => sys.env.getOrElse("OPENAI_API_KEY", "")
case _ => ""
case "claude" => sys.env.getOrElse("ANTHROPIC_API_KEY", "")
case "openai" => sys.env.getOrElse("OPENAI_API_KEY", "")
case "gemini" | "vertex-ai" => sys.env.getOrElse("GOOGLE_API_KEY", sys.env.getOrElse("GEMINI_API_KEY", ""))
case _ => ""
}
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -133,16 +133,23 @@ class LLMTransformPlugin extends TransformPlugin {
Map("x-api-key" -> apiKey, "anthropic-version" -> "2023-06-01", "Content-Type" -> "application/json"),
s"""{"model":"$model","max_tokens":4096,"messages":[{"role":"user","content":"$escapedPrompt"}]}"""
)
case "gemini" | "vertex-ai" =>
// Google AI Studio (default) or Vertex AI via baseUrl
val geminiUrl = s"https://generativelanguage.googleapis.com/v1beta/models/$model:generateContent?key=$apiKey"
(
geminiUrl,
Map("Content-Type" -> "application/json"),
s"""{"contents":[{"parts":[{"text":"$escapedPrompt"}]}],"generationConfig":{"maxOutputTokens":4096}}"""
)
case "local" =>
// Ollama-compatible API (OpenAI-compatible endpoint)
(
"http://localhost:11434/v1/chat/completions",
Map("Content-Type" -> "application/json"),
s"""{"model":"$model","messages":[{"role":"user","content":"$escapedPrompt"}]}"""
)
case _ => // openai and any other OpenAI-compatible provider
val apiUrl = if (provider == "openai") "https://api.openai.com/v1/chat/completions"
else s"https://api.$provider.com/v1/chat/completions" // extensible
else s"https://api.$provider.com/v1/chat/completions"
(
apiUrl,
Map("Authorization" -> s"Bearer $apiKey", "Content-Type" -> "application/json"),
Expand All @@ -160,15 +167,28 @@ class LLMTransformPlugin extends TransformPlugin {
throw new RuntimeException(s"LLM API error (${response.statusCode()}): ${response.body().take(500)}")
}

// Extract text from response
val textPattern = """"text"\s*:\s*"((?:[^"\\]|\\.)*)"""".r
val contentPattern = """"content"\s*:\s*"((?:[^"\\]|\\.)*)"""".r
extractResponseText(response.body(), provider)
}

/** Extract generated text from LLM API response based on provider format. */
private def extractResponseText(responseBody: String, provider: String): String = {
val result = provider match {
case "claude" =>
val pattern = """"text"\s*:\s*"((?:[^"\\]|\\.)*)"""".r
pattern.findFirstMatchIn(responseBody).map(_.group(1))

case "gemini" | "vertex-ai" =>
// Gemini response: candidates[0].content.parts[0].text
val pattern = """"text"\s*:\s*"((?:[^"\\]|\\.)*)"""".r
pattern.findFirstMatchIn(responseBody).map(_.group(1))

val pattern = if (provider == "claude") textPattern else contentPattern
pattern.findFirstMatchIn(response.body()) match {
case Some(m) => m.group(1).replace("\\n", "\n").replace("\\\"", "\"")
case None => throw new RuntimeException(s"Cannot parse LLM response")
case _ => // openai, local, and compatible
val pattern = """"content"\s*:\s*"((?:[^"\\]|\\.)*)"""".r
pattern.findFirstMatchIn(responseBody).map(_.group(1))
}

result.map(_.replace("\\n", "\n").replace("\\\"", "\""))
.getOrElse(throw new RuntimeException(s"Cannot parse LLM response"))
}

private def parseOutputSchema(schemaStr: String): List[(String, DataType)] = {
Expand Down Expand Up @@ -226,10 +246,11 @@ class LLMTransformPlugin extends TransformPlugin {
private def resolveApiKey(config: Map[String, String], provider: String): String = {
config.getOrElse("apiKey", {
provider match {
case "claude" => sys.env.getOrElse("ANTHROPIC_API_KEY", "")
case "openai" => sys.env.getOrElse("OPENAI_API_KEY", "")
case "local" => "" // Ollama doesn't need an API key
case _ => ""
case "claude" => sys.env.getOrElse("ANTHROPIC_API_KEY", "")
case "openai" => sys.env.getOrElse("OPENAI_API_KEY", "")
case "gemini" | "vertex-ai" => sys.env.getOrElse("GOOGLE_API_KEY", sys.env.getOrElse("GEMINI_API_KEY", ""))
case "local" => ""
case _ => ""
}
})
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -102,6 +102,12 @@ JSON array:"""
("https://api.anthropic.com/v1/messages",
Map("x-api-key" -> apiKey, "anthropic-version" -> "2023-06-01", "Content-Type" -> "application/json"),
s"""{"model":"$model","max_tokens":4096,"messages":[{"role":"user","content":"$escapedPrompt"}]}""")
case "gemini" | "vertex-ai" =>
val geminiUrl = baseUrl.getOrElse(
s"https://generativelanguage.googleapis.com/v1beta/models/$model:generateContent?key=$apiKey")
(geminiUrl,
Map("Content-Type" -> "application/json"),
s"""{"contents":[{"parts":[{"text":"$escapedPrompt"}]}],"generationConfig":{"maxOutputTokens":4096}}""")
case _ => // openai-compatible
val apiUrl = baseUrl.getOrElse("https://api.openai.com/v1/chat/completions")
(apiUrl,
Expand Down Expand Up @@ -135,16 +141,18 @@ JSON array:"""
}

private def defaultModel(provider: String): String = provider match {
case "claude" => "claude-sonnet-4-20250514"
case "openai" => "gpt-4o-mini"
case "local" => "llama3"
case _ => "gpt-4o-mini"
case "claude" => "claude-sonnet-4-20250514"
case "openai" => "gpt-4o-mini"
case "gemini" | "vertex-ai" => "gemini-2.0-flash"
case "local" => "llama3"
case _ => "gpt-4o-mini"
}

private def resolveApiKey(provider: String): String = provider match {
case "claude" => sys.env.getOrElse("ANTHROPIC_API_KEY", "")
case "openai" => sys.env.getOrElse("OPENAI_API_KEY", "")
case "local" => "" // no key needed
case _ => ""
case "claude" => sys.env.getOrElse("ANTHROPIC_API_KEY", "")
case "openai" => sys.env.getOrElse("OPENAI_API_KEY", "")
case "gemini" | "vertex-ai" => sys.env.getOrElse("GOOGLE_API_KEY", sys.env.getOrElse("GEMINI_API_KEY", ""))
case "local" => ""
case _ => ""
}
}
Loading