diff --git a/README.md b/README.md index 81c17a60..b2f8ce3e 100644 --- a/README.md +++ b/README.md @@ -401,6 +401,10 @@ val outputFormat = OutputFormat.JsonSchema(schema) ### Claude Tool Calling +#### Custom Tools + +Define your own tools that Claude calls and your application executes: + ```scala import sttp.ai.claude.models.{Tool, ToolInputSchema, PropertySchema} @@ -411,7 +415,7 @@ val weatherTool = Tool( `type` = "object", properties = Map( "location" -> PropertySchema(`type` = "string", description = Some("City name")), - "unit" -> PropertySchema(`type` = "string", enum = Some(List("celsius", "fahrenheit"))) + "unit" -> PropertySchema(`type` = "string", `enum` = Some(List("celsius", "fahrenheit"))) ), required = Some(List("location")) ) @@ -425,6 +429,42 @@ val request = MessageRequest.withTools( ) ``` +#### Predefined Tools + +Currently supported: + +- **`Tool.WebSearch`** (`web_search_20250305`) + +```scala +import sttp.ai.claude.models.{ContentBlock, Message, Tool} +import sttp.ai.claude.requests.MessageRequest + +val request = MessageRequest.withTools( + model = "claude-sonnet-4-5-20250514", + messages = List(Message.user(List(ContentBlock.text("What was the most recent SpaceX launch?")))), + maxTokens = 1024, + tools = List(Tool.WebSearch.default) +) + +val response = client.createMessage(request) + +response.content.foreach { + case t: ContentBlock.TextContent => println(t.text) + case s: ContentBlock.ServerToolUseContent => + println(s"Searched for: ${s.input.get("query").map(_.str).getOrElse("")}") + case r: ContentBlock.WebSearchToolResultContent => + r.content match { + case ContentBlock.WebSearchToolResult.Results(items) => + items.foreach(it => println(s"- ${it.title} — ${it.url}")) + case ContentBlock.WebSearchToolResult.Error(code) => + println(s"Web search failed: $code") + } + case _ => () +} +``` + +Both custom and predefined tools can be passed in the same `tools` list. + ### Claude Streaming #### Using fs2 (cats-effect) diff --git a/claude/src/main/scala/sttp/ai/claude/models/ContentBlock.scala b/claude/src/main/scala/sttp/ai/claude/models/ContentBlock.scala index 40b21158..5176ad3d 100644 --- a/claude/src/main/scala/sttp/ai/claude/models/ContentBlock.scala +++ b/claude/src/main/scala/sttp/ai/claude/models/ContentBlock.scala @@ -1,5 +1,6 @@ package sttp.ai.claude.models +import sttp.ai.core.json.SnakePickle import sttp.ai.core.json.SnakePickle.{macroRW, ReadWriter} import ujson.Value import upickle.implicits.key @@ -52,6 +53,67 @@ object ContentBlock { val `type`: String = "document" } + @key("server_tool_use") + case class ServerToolUseContent( + id: String, + name: String, + input: Map[String, Value] + ) extends ContentBlock { + val `type`: String = "server_tool_use" + } + + @key("web_search_tool_result") + case class WebSearchToolResultContent( + toolUseId: String, + content: WebSearchToolResult, + caller: Option[Value] = None + ) extends ContentBlock { + val `type`: String = "web_search_tool_result" + } + + case class WebSearchResult( + url: String, + title: String, + pageAge: Option[String] = None, + encryptedContent: Option[String] = None + ) { + val `type`: String = "web_search_result" + } + + object WebSearchResult { + implicit val rw: ReadWriter[WebSearchResult] = macroRW + } + + sealed trait WebSearchToolResult + + object WebSearchToolResult { + case class Results(items: List[WebSearchResult]) extends WebSearchToolResult + + case class Error(errorCode: String) extends WebSearchToolResult + + private val ErrorTypeValue = "web_search_tool_result_error" + + implicit val rw: ReadWriter[WebSearchToolResult] = SnakePickle + .readwriter[Value] + .bimap[WebSearchToolResult]( + { + case Results(items) => SnakePickle.writeJs(items) + case Error(code) => + ujson.Obj( + "type" -> ujson.Str(ErrorTypeValue), + "error_code" -> ujson.Str(code) + ) + }, + { + case arr: ujson.Arr => Results(SnakePickle.read[List[WebSearchResult]](arr)) + case obj: ujson.Obj if obj.value.get("type").contains(ujson.Str(ErrorTypeValue)) => + Error(obj("error_code").str) + case other => + throw new IllegalArgumentException(s"Unrecognised web_search_tool_result content: $other") + } + ) + } + sealed trait DocumentSource { def `type`: String } @@ -120,6 +182,8 @@ object ContentBlock { implicit val toolUseContentRW: ReadWriter[ToolUseContent] = macroRW implicit val toolResultContentRW: ReadWriter[ToolResultContent] = macroRW implicit val documentContentRW: ReadWriter[DocumentContent] = macroRW + implicit val serverToolUseContentRW: ReadWriter[ServerToolUseContent] = macroRW + implicit val webSearchToolResultContentRW: ReadWriter[WebSearchToolResultContent] = macroRW implicit val rw: ReadWriter[ContentBlock] = ReadWriter.merge( textContentRW, @@ -127,6 +191,8 @@ object ContentBlock { imageContentRW, toolUseContentRW, toolResultContentRW, - documentContentRW + documentContentRW, + serverToolUseContentRW, + webSearchToolResultContentRW ) } diff --git a/claude/src/main/scala/sttp/ai/claude/models/Tool.scala b/claude/src/main/scala/sttp/ai/claude/models/Tool.scala index 663e22c1..f3ac14f0 100644 --- a/claude/src/main/scala/sttp/ai/claude/models/Tool.scala +++ b/claude/src/main/scala/sttp/ai/claude/models/Tool.scala @@ -1,12 +1,10 @@ package sttp.ai.claude.models +import sttp.ai.core.json.SnakePickle import sttp.ai.core.json.SnakePickle.{macroRW, ReadWriter} +import ujson.Value -case class Tool( - name: String, - description: String, - inputSchema: ToolInputSchema -) +sealed trait Tool case class ToolInputSchema( `type`: String, @@ -58,6 +56,93 @@ object ToolInputSchema { implicit val rw: ReadWriter[ToolInputSchema] = macroRW } +@upickle.implicits.serializeDefaults(true) +case class UserLocation( + `type`: String = UserLocation.ApproximateType, + city: Option[String] = None, + region: Option[String] = None, + country: Option[String] = None, + timezone: Option[String] = None +) + +object UserLocation { + val ApproximateType = "approximate" + + def approximate( + city: Option[String] = None, + region: Option[String] = None, + country: Option[String] = None, + timezone: Option[String] = None + ): UserLocation = UserLocation(ApproximateType, city, region, country, timezone) + + implicit val rw: ReadWriter[UserLocation] = macroRW +} + object Tool { - implicit val rw: ReadWriter[Tool] = macroRW + case class Custom( + name: String, + description: String, + inputSchema: ToolInputSchema + ) extends Tool + + @upickle.implicits.key("web_search_20250305") + case class WebSearch( + maxUses: Option[Int] = None, + allowedDomains: Option[List[String]] = None, + blockedDomains: Option[List[String]] = None, + userLocation: Option[UserLocation] = None + ) extends Tool + + object WebSearch { + final val ToolType = "web_search_20250305" + final val ToolName = "web_search" + + val default: WebSearch = WebSearch() + } + + def apply(name: String, description: String, inputSchema: ToolInputSchema): Custom = + Custom(name, description, inputSchema) + + // manual rw so custom JSON has no `type` field, matching Anthropic documented format + private val customRW: ReadWriter[Custom] = SnakePickle + .readwriter[Value] + .bimap[Custom]( + c => + ujson.Obj( + "name" -> ujson.Str(c.name), + "description" -> ujson.Str(c.description), + "input_schema" -> SnakePickle.writeJs(c.inputSchema) + ), + json => + Custom( + name = json("name").str, + description = json("description").str, + inputSchema = SnakePickle.read[ToolInputSchema](json("input_schema")) + ) + ) + + private val webSearchRW: ReadWriter[WebSearch] = macroRW + + private def withName(json: Value, toolName: String): Value = { + val obj = scala.collection.mutable.LinkedHashMap[String, Value]() + json.obj.foreach { case (k, v) => + obj.update(k, v) + if (k == SnakePickle.tagName) obj.update("name", ujson.Str(toolName)) + } + ujson.Obj.from(obj) + } + + implicit val toolRW: ReadWriter[Tool] = SnakePickle + .readwriter[Value] + .bimap[Tool]( + { + case c: Custom => SnakePickle.writeJs(c)(customRW) + case ws: WebSearch => withName(SnakePickle.writeJs(ws)(webSearchRW), WebSearch.ToolName) + }, + json => + json.obj.get(SnakePickle.tagName).map(_.str) match { + case Some(WebSearch.ToolType) => SnakePickle.read[WebSearch](json)(webSearchRW) + case _ => SnakePickle.read[Custom](json)(customRW) + } + ) } diff --git a/claude/src/test/scala/sttp/ai/claude/integration/ClaudeIntegrationSpec.scala b/claude/src/test/scala/sttp/ai/claude/integration/ClaudeIntegrationSpec.scala index 69e0fa70..d18607bf 100644 --- a/claude/src/test/scala/sttp/ai/claude/integration/ClaudeIntegrationSpec.scala +++ b/claude/src/test/scala/sttp/ai/claude/integration/ClaudeIntegrationSpec.scala @@ -256,6 +256,35 @@ class ClaudeIntegrationSpec extends AnyFlatSpec with Matchers with BeforeAndAfte () } + it should "handle web search predefined tool successfully" in + withClient { client => + // given + val request = MessageRequest.withTools( + model = testModel, + messages = List(Message.user("What was the most recent SpaceX launch? Use web search to find out.")), + maxTokens = 1024, + tools = List(Tool.WebSearch(maxUses = Some(1))) + ) + + // when + val response = client.createMessage(request) + + // then + response should not be null + response.role shouldBe "assistant" + response.content should not be empty + + val serverToolUse = response.content.collectFirst { case s: ContentBlock.ServerToolUseContent => s } + serverToolUse should be(defined) + serverToolUse.get.name shouldBe "web_search" + + val toolResult = response.content.collectFirst { case r: ContentBlock.WebSearchToolResultContent => r } + toolResult should be(defined) + toolResult.get.toolUseId shouldBe serverToolUse.get.id + toolResult.get.content shouldBe a[ContentBlock.WebSearchToolResult.Results] + () + } + "Claude Error Handling" should "throw AuthenticationException for invalid API key" in { // given val invalidConfig = ClaudeConfig( diff --git a/claude/src/test/scala/sttp/ai/claude/unit/models/ToolSpec.scala b/claude/src/test/scala/sttp/ai/claude/unit/models/ToolSpec.scala new file mode 100644 index 00000000..dec69d9b --- /dev/null +++ b/claude/src/test/scala/sttp/ai/claude/unit/models/ToolSpec.scala @@ -0,0 +1,96 @@ +package sttp.ai.claude.unit.models + +import org.scalatest.flatspec.AnyFlatSpec +import org.scalatest.matchers.should.Matchers +import sttp.ai.claude.models._ +import sttp.ai.core.json.SnakePickle._ + +class ToolSpec extends AnyFlatSpec with Matchers { + + "Tool.Custom" should "serialize without a type discriminator" in { + val tool = Tool.Custom( + name = "get_weather", + description = "Get weather for a city", + inputSchema = ToolInputSchema.forObject( + properties = Map("city" -> PropertySchema.string("The city name")), + required = Some(List("city")) + ) + ) + + val json = ujson.read(write[Tool](tool)) + + json.obj.contains("type") shouldBe false + json("name").str shouldBe "get_weather" + json("description").str shouldBe "Get weather for a city" + json("input_schema")("type").str shouldBe "object" + } + + it should "round-trip" in { + val tool: Tool = Tool.Custom( + name = "get_weather", + description = "desc", + inputSchema = ToolInputSchema.forObject( + properties = Map("city" -> PropertySchema.string("city")), + required = Some(List("city")) + ) + ) + read[Tool](write(tool)) shouldBe tool + } + + "Tool.WebSearch" should "serialize with type and name discriminators" in { + val tool = Tool.WebSearch( + maxUses = Some(5), + allowedDomains = Some(List("example.com")), + userLocation = Some(UserLocation.approximate(city = Some("San Francisco"), country = Some("US"))) + ) + + val json = ujson.read(write[Tool](tool)) + + json("type").str shouldBe "web_search_20250305" + json("name").str shouldBe "web_search" + json("max_uses").num shouldBe 5 + json("allowed_domains").arr.map(_.str).toList shouldBe List("example.com") + json("user_location")("type").str shouldBe "approximate" + json("user_location")("city").str shouldBe "San Francisco" + json("user_location")("country").str shouldBe "US" + } + + it should "omit unset fields" in { + val tool: Tool = Tool.WebSearch() + val json = ujson.read(write[Tool](tool)) + + json("type").str shouldBe "web_search_20250305" + json("name").str shouldBe "web_search" + json.obj.contains("max_uses") shouldBe false + json.obj.contains("allowed_domains") shouldBe false + json.obj.contains("blocked_domains") shouldBe false + json.obj.contains("user_location") shouldBe false + } + + it should "round-trip" in { + val tool: Tool = Tool.WebSearch( + maxUses = Some(3), + blockedDomains = Some(List("bad.example")), + userLocation = Some(UserLocation.approximate(timezone = Some("America/Los_Angeles"))) + ) + read[Tool](write(tool)) shouldBe tool + } + + "Tool list" should "mix custom and predefined tools in a single array" in { + val tools: List[Tool] = List( + Tool.Custom( + name = "get_weather", + description = "weather", + inputSchema = ToolInputSchema.forObject(Map("city" -> PropertySchema.string("city"))) + ), + Tool.WebSearch(maxUses = Some(5)) + ) + + val arr = ujson.read(write(tools)).arr.toList + + arr.head.obj.contains("type") shouldBe false + arr(1)("type").str shouldBe "web_search_20250305" + + read[List[Tool]](write(tools)) shouldBe tools + } +} diff --git a/claude/src/test/scala/sttp/ai/claude/unit/responses/WebSearchResponseSpec.scala b/claude/src/test/scala/sttp/ai/claude/unit/responses/WebSearchResponseSpec.scala new file mode 100644 index 00000000..c29861ac --- /dev/null +++ b/claude/src/test/scala/sttp/ai/claude/unit/responses/WebSearchResponseSpec.scala @@ -0,0 +1,165 @@ +package sttp.ai.claude.unit.responses + +import org.scalatest.flatspec.AnyFlatSpec +import org.scalatest.matchers.should.Matchers +import sttp.ai.claude.models.{Citation, ContentBlock} +import sttp.ai.claude.responses.MessageResponse +import sttp.ai.core.json.SnakePickle._ + +class WebSearchResponseSpec extends AnyFlatSpec with Matchers { + + private val successResponseJson = + """{ + | "id": "msg_01ABC", + | "type": "message", + | "role": "assistant", + | "model": "claude-haiku-4-5-20251001", + | "content": [ + | { + | "type": "text", + | "text": "I'll search for that." + | }, + | { + | "type": "server_tool_use", + | "id": "srvtoolu_01XYZ", + | "name": "web_search", + | "input": { "query": "claude shannon birth date" } + | }, + | { + | "type": "web_search_tool_result", + | "tool_use_id": "srvtoolu_01XYZ", + | "content": [ + | { + | "type": "web_search_result", + | "url": "https://en.wikipedia.org/wiki/Claude_Shannon", + | "title": "Claude Shannon - Wikipedia", + | "encrypted_content": "AAA", + | "page_age": "April 30, 2025" + | } + | ], + | "caller": { "type": "direct" } + | }, + | { + | "type": "text", + | "text": "Claude Shannon was born on April 30, 1916.", + | "citations": [ + | { + | "type": "web_search_result_location", + | "url": "https://en.wikipedia.org/wiki/Claude_Shannon", + | "title": "Claude Shannon - Wikipedia", + | "encrypted_index": "BBB", + | "cited_text": "Claude Elwood Shannon (April 30, 1916 ..." + | } + | ] + | } + | ], + | "stop_reason": "end_turn", + | "stop_sequence": null, + | "usage": { + | "input_tokens": 100, + | "output_tokens": 50, + | "server_tool_use": { "web_search_requests": 1 } + | } + |}""".stripMargin + + private val errorResponseJson = + """{ + | "id": "msg_02DEF", + | "type": "message", + | "role": "assistant", + | "model": "claude-haiku-4-5-20251001", + | "content": [ + | { + | "type": "web_search_tool_result", + | "tool_use_id": "srvtoolu_02DEF", + | "content": { + | "type": "web_search_tool_result_error", + | "error_code": "max_uses_exceeded" + | } + | } + | ], + | "stop_reason": "end_turn", + | "stop_sequence": null, + | "usage": { "input_tokens": 10, "output_tokens": 5 } + |}""".stripMargin + + "MessageResponse with web_search content" should "deserialize server_tool_use blocks" in { + val response = read[MessageResponse](successResponseJson) + + val serverToolUse = response.content.collectFirst { case s: ContentBlock.ServerToolUseContent => s } + serverToolUse should be(defined) + serverToolUse.get.id shouldBe "srvtoolu_01XYZ" + serverToolUse.get.name shouldBe "web_search" + serverToolUse.get.input("query").str shouldBe "claude shannon birth date" + } + + it should "deserialize web_search_tool_result with results array" in { + val response = read[MessageResponse](successResponseJson) + + val toolResult = response.content.collectFirst { case r: ContentBlock.WebSearchToolResultContent => r } + toolResult should be(defined) + toolResult.get.toolUseId shouldBe "srvtoolu_01XYZ" + + val results = toolResult.get.content match { + case ContentBlock.WebSearchToolResult.Results(items) => items + case other => fail(s"Expected Results, got $other") + } + results should have size 1 + results.head.url shouldBe "https://en.wikipedia.org/wiki/Claude_Shannon" + results.head.title shouldBe "Claude Shannon - Wikipedia" + results.head.pageAge shouldBe Some("April 30, 2025") + results.head.encryptedContent shouldBe Some("AAA") + } + + it should "preserve the undocumented caller field" in { + val response = read[MessageResponse](successResponseJson) + val toolResult = response.content.collectFirst { case r: ContentBlock.WebSearchToolResultContent => r }.get + + toolResult.caller should be(defined) + toolResult.caller.get("type").str shouldBe "direct" + } + + it should "deserialize web_search_result_location citations on text blocks" in { + val response = read[MessageResponse](successResponseJson) + + val finalText = response.content.collect { case t: ContentBlock.TextContent => t }.last + finalText.citations should be(defined) + finalText.citations.get should have size 1 + finalText.citations.get.head shouldBe a[Citation.WebSearchResultLocation] + val cite = finalText.citations.get.head.asInstanceOf[Citation.WebSearchResultLocation] + cite.url shouldBe "https://en.wikipedia.org/wiki/Claude_Shannon" + cite.encryptedIndex shouldBe "BBB" + } + + it should "deserialize web_search_tool_result error variant" in { + val response = read[MessageResponse](errorResponseJson) + + val toolResult = response.content.collectFirst { case r: ContentBlock.WebSearchToolResultContent => r } + toolResult should be(defined) + + toolResult.get.content match { + case ContentBlock.WebSearchToolResult.Error(code) => code shouldBe "max_uses_exceeded" + case other => fail(s"Expected Error, got $other") + } + } + + "WebSearchToolResult content RW" should "round-trip Results variant" in { + val original: ContentBlock.WebSearchToolResult = + ContentBlock.WebSearchToolResult.Results( + List( + ContentBlock.WebSearchResult( + url = "https://example.com", + title = "Example", + pageAge = Some("yesterday"), + encryptedContent = Some("X") + ) + ) + ) + read[ContentBlock.WebSearchToolResult](write(original)) shouldBe original + } + + it should "round-trip Error variant" in { + val original: ContentBlock.WebSearchToolResult = ContentBlock.WebSearchToolResult.Error("too_many_requests") + read[ContentBlock.WebSearchToolResult](write(original)) shouldBe original + } +}