Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
22 commits
Select commit Hold shift + click to select a range
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
150 changes: 125 additions & 25 deletions src/cycod/ChatClient/FunctionCallingChat.cs
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
using Microsoft.Extensions.AI;
using System.Threading;

public class FunctionCallingChat : IAsyncDisposable
{
Expand Down Expand Up @@ -60,15 +61,19 @@ public async Task<string> CompleteChatStreamingAsync(
Action<IList<ChatMessage>>? messageCallback = null,
Action<ChatResponseUpdate>? streamingCallback = null,
Func<string, string?, bool>? approveFunctionCall = null,
Action<string, string, object?>? functionCallCallback = null)
Action<string, string, object?>? functionCallCallback = null,
CancellationToken cancellationToken = default,
Func<string>? getDisplayBuffer = null)
{
return await CompleteChatStreamingAsync(
userPrompt,
new List<string>(),
messageCallback,
streamingCallback,
approveFunctionCall,
functionCallCallback);
functionCallCallback,
cancellationToken,
getDisplayBuffer);
}

public async Task<string> CompleteChatStreamingAsync(
Expand All @@ -77,51 +82,114 @@ public async Task<string> CompleteChatStreamingAsync(
Action<IList<ChatMessage>>? messageCallback = null,
Action<ChatResponseUpdate>? streamingCallback = null,
Func<string, string?, bool>? approveFunctionCall = null,
Action<string, string, object?>? functionCallCallback = null)
Action<string, string, object?>? functionCallCallback = null,
CancellationToken cancellationToken = default,
Func<string>? getDisplayBuffer = null)
{
var message = CreateUserMessageWithImages(userPrompt, imageFiles);

Conversation.Messages.Add(message);
messageCallback?.Invoke(Conversation.Messages);

var contentToReturn = string.Empty;
while (true)
{
var responseContent = string.Empty;
await foreach (var update in _chatClient.GetStreamingResponseAsync(Conversation.Messages, _options))

// Surround streaming with try/catch to handle user interruption via OperationCancelledException
try
{
_functionCallDetector.CheckForFunctionCall(update);
await foreach (var update in _chatClient.GetStreamingResponseAsync(Conversation.Messages, _options, cancellationToken))
{
// Check for cancellation before processing each update
cancellationToken.ThrowIfCancellationRequested();

_functionCallDetector.CheckForFunctionCall(update);

var content = string.Join("", update.Contents
.Where(c => c is TextContent)
.Cast<TextContent>()
.Select(c => c.Text)
.ToList());

var content = string.Join("", update.Contents
.Where(c => c is TextContent)
.Cast<TextContent>()
.Select(c => c.Text)
.ToList());
if (update.FinishReason == ChatFinishReason.ContentFilter)
{
content = $"{content}\nWARNING: Content filtered!";
}

responseContent += content;
contentToReturn += content;

streamingCallback?.Invoke(update);
}
}
catch (OperationCanceledException) when (cancellationToken.IsCancellationRequested)
{
// User interrupted - trim content to match what was actually displayed
var displayBuffer = getDisplayBuffer?.Invoke() ?? "";

if (update.FinishReason == ChatFinishReason.ContentFilter)
// If we have both the display buffer and the content to return, trim the content to match what was displayed
// This avoids showing content that was generated but not actually seen by the user due to the slight delay in cancellation
if (!string.IsNullOrEmpty(displayBuffer) && !string.IsNullOrEmpty(contentToReturn))
{
content = $"{content}\nWARNING: Content filtered!";
var trimmedContent = TrimContentToDisplayBuffer(contentToReturn, displayBuffer);
if (!string.IsNullOrEmpty(trimmedContent))
{
Conversation.Messages.Add(new ChatMessage(ChatRole.Assistant, trimmedContent));
messageCallback?.Invoke(Conversation.Messages);
}
}
throw;
}

responseContent += content;
contentToReturn += content;
// Surround assistant response handling with try/catch to handle user interruption via ChatCommand.UserWantsControlException
try
{
if (TryCallFunctions(responseContent, approveFunctionCall, functionCallCallback, messageCallback))
{
_functionCallDetector.Clear();
continue;
}

streamingCallback?.Invoke(update);
Conversation.Messages.Add(new ChatMessage(ChatRole.Assistant, responseContent));
messageCallback?.Invoke(Conversation.Messages);
}

if (TryCallFunctions(responseContent, approveFunctionCall, functionCallCallback, messageCallback))
catch (ChatCommand.UserWantsControlException)
{
// User cancelled function call - exit streaming entirely and return control to user
_functionCallDetector.Clear();
continue;
return ""; // Empty response - will show blank Assistant line
}

Conversation.Messages.Add(new ChatMessage(ChatRole.Assistant, responseContent));
messageCallback?.Invoke(Conversation.Messages);

return contentToReturn;
}
}

/// <summary>
/// Trims the full content to match what was actually displayed in the display buffer.
/// This is used to ensure that if the user cancels the response, we only show the content that was actually seen by the user,
/// and not any additional content that may have been generated but not displayed due to cancellation.
/// </summary>
/// <param name="fullContent">Content generated by the AI so far (usually more than what has been displayed)</param>
/// <param name="displayBuffer">The tail end of the content that was actually displayed to the user</param>
/// <returns>All content that was shown on screen to the user</returns>
private string TrimContentToDisplayBuffer(string fullContent, string displayBuffer)
{
if (string.IsNullOrEmpty(displayBuffer) || string.IsNullOrEmpty(fullContent))
return "";

// Find where the display buffer content appears in the full content
var displayBufferIndex = fullContent.LastIndexOf(displayBuffer);
if (displayBufferIndex >= 0)
{
// Trim to end where display buffer ends
return fullContent.Substring(0, displayBufferIndex + displayBuffer.Length);
}

// If we can't find the display buffer in the content, return the display buffer
// This handles edge cases where the content might have been modified
return displayBuffer;
}

private bool TryCallFunctions(string responseContent, Func<string, string?, bool>? approveFunctionCall, Action<string, string, object?>? functionCallCallback, Action<IList<ChatMessage>>? messageCallback)
{
var noFunctionsToCall = !_functionCallDetector.HasFunctionCalls();
Expand All @@ -139,7 +207,30 @@ private bool TryCallFunctions(string responseContent, Func<string, string?, bool
Conversation.Messages.Add(new ChatMessage(ChatRole.Assistant, assistantContent));
messageCallback?.Invoke(Conversation.Messages);

var functionCallResults = CallFunctions(readyToCallFunctionCalls, approveFunctionCall, functionCallCallback);
List<AIContent> functionCallResults;
try
{
functionCallResults = CallFunctions(readyToCallFunctionCalls, approveFunctionCall, functionCallCallback);
}
// If the user cancels during function call approval, we need to handle that gracefully by
// using the same logic as denying function calls
catch (ChatCommand.UserWantsControlException)
{
var functionResultContents = new List<AIContent>();

foreach (var functionCall in readyToCallFunctionCalls)
{
var cancelResult = DontCallFunction(functionCall, functionCallCallback);
functionResultContents.Add(new FunctionResultContent(functionCall.CallId, cancelResult));
}

Conversation.Messages.Add(new ChatMessage(ChatRole.Tool, functionResultContents));
messageCallback?.Invoke(Conversation.Messages);

// Re-throw so the main loop (CompleteChatStreamingAsync()) knows
// to stop processing and return control to the user
throw;
}

var attachToToolMessage = functionCallResults
.Where(c => c is FunctionResultContent)
Expand Down Expand Up @@ -174,7 +265,16 @@ private List<AIContent> CallFunctions(List<FunctionCallDetector.ReadyToCallFunct
var functionResultContents = new List<AIContent>();
foreach (var functionCall in readyToCallFunctionCalls)
{
var approved = approveFunctionCall?.Invoke(functionCall.Name, functionCall.Arguments) ?? true;
bool approved;
try
{
approved = approveFunctionCall?.Invoke(functionCall.Name, functionCall.Arguments) ?? true;
}
catch (ChatCommand.UserWantsControlException)
{
// Re-throw so TryCallFunctions can handle it
throw;
}

var functionResult = approved
? CallFunction(functionCall, functionCallCallback)
Expand Down Expand Up @@ -215,7 +315,7 @@ private object DontCallFunction(FunctionCallDetector.ReadyToCallFunctionCall fun
functionCallCallback?.Invoke(functionCall.Name, functionCall.Arguments, null);

ConsoleHelpers.WriteDebugLine($"Function call not approved: {functionCall.Name} with arguments: {functionCall.Arguments}");
var functionResult = "User did not approve function call";
var functionResult = ChatCommand.CallDeniedMessage;

functionCallCallback?.Invoke(functionCall.Name, functionCall.Arguments, functionResult);

Expand Down
Loading
Loading