Pass the last user prompt to CreateFromProviderAsync calls to enable RAG support

This commit is contained in:
Thorsten Sommer 2025-02-10 18:47:00 +01:00
parent 6a2c3b637d
commit 28b0deb318
Signed by: tsommer
GPG Key ID: 371BBA77A02C0108
7 changed files with 38 additions and 24 deletions

View File

@ -13,6 +13,8 @@ public abstract class AgentBase(ILogger<AgentBase> logger, SettingsManager setti
protected ThreadSafeRandom RNG { get; init; } = rng;
protected ILogger<AgentBase> Logger { get; init; } = logger;
protected IContent? lastUserPrompt;
/// <summary>
/// Represents the type or category of this agent.
@ -63,15 +65,17 @@ public abstract class AgentBase(ILogger<AgentBase> logger, SettingsManager setti
protected DateTimeOffset AddUserRequest(ChatThread thread, string request)
{
var time = DateTimeOffset.Now;
this.lastUserPrompt = new ContentText
{
Text = request,
};
thread.Blocks.Add(new ContentBlock
{
Time = time,
ContentType = ContentType.TEXT,
Role = ChatRole.USER,
Content = new ContentText
{
Text = request,
},
Content = this.lastUserPrompt,
});
return time;
@ -103,6 +107,6 @@ public abstract class AgentBase(ILogger<AgentBase> logger, SettingsManager setti
// Use the selected provider to get the AI response.
// By awaiting this line, we wait for the entire
// content to be streamed.
await aiText.CreateFromProviderAsync(providerSettings.CreateProvider(this.Logger), this.SettingsManager, providerSettings.Model, thread);
await aiText.CreateFromProviderAsync(providerSettings.CreateProvider(this.Logger), this.SettingsManager, providerSettings.Model, this.lastUserPrompt, thread);
}
}

View File

@ -97,6 +97,7 @@ public abstract partial class AssistantBase : ComponentBase, IMessageBusReceiver
protected bool inputIsValid;
protected Profile currentProfile = Profile.NO_PROFILE;
protected ChatThread? chatThread;
protected IContent? lastUserPrompt;
private readonly Timer formChangeTimer = new(TimeSpan.FromSeconds(1.6));
@ -242,16 +243,18 @@ public abstract partial class AssistantBase : ComponentBase, IMessageBusReceiver
protected DateTimeOffset AddUserRequest(string request, bool hideContentFromUser = false)
{
var time = DateTimeOffset.Now;
this.lastUserPrompt = new ContentText
{
Text = request,
};
this.chatThread!.Blocks.Add(new ContentBlock
{
Time = time,
ContentType = ContentType.TEXT,
HideFromUser = hideContentFromUser,
Role = ChatRole.USER,
Content = new ContentText
{
Text = request,
},
Content = this.lastUserPrompt,
});
return time;
@ -287,7 +290,7 @@ public abstract partial class AssistantBase : ComponentBase, IMessageBusReceiver
// Use the selected provider to get the AI response.
// By awaiting this line, we wait for the entire
// content to be streamed.
await aiText.CreateFromProviderAsync(this.providerSettings.CreateProvider(this.Logger), this.SettingsManager, this.providerSettings.Model, this.chatThread);
await aiText.CreateFromProviderAsync(this.providerSettings.CreateProvider(this.Logger), this.SettingsManager, this.providerSettings.Model, this.lastUserPrompt, this.chatThread);
this.isProcessing = false;
this.StateHasChanged();

View File

@ -29,7 +29,7 @@ public sealed class ContentImage : IContent
public Func<Task> StreamingEvent { get; set; } = () => Task.CompletedTask;
/// <inheritdoc />
public Task CreateFromProviderAsync(IProvider provider, SettingsManager settings, Model chatModel, ChatThread chatChatThread, CancellationToken token = default)
public Task CreateFromProviderAsync(IProvider provider, SettingsManager settings, Model chatModel, IContent? lastPrompt, ChatThread? chatChatThread, CancellationToken token = default)
{
throw new NotImplementedException();
}

View File

@ -35,7 +35,7 @@ public sealed class ContentText : IContent
public Func<Task> StreamingEvent { get; set; } = () => Task.CompletedTask;
/// <inheritdoc />
public async Task CreateFromProviderAsync(IProvider provider, SettingsManager settings, Model chatModel, ChatThread? chatThread, CancellationToken token = default)
public async Task CreateFromProviderAsync(IProvider provider, SettingsManager settings, Model chatModel, IContent? lastPrompt, ChatThread? chatThread, CancellationToken token = default)
{
if(chatThread is null)
return;

View File

@ -42,5 +42,5 @@ public interface IContent
/// <summary>
/// Uses the provider to create the content.
/// </summary>
public Task CreateFromProviderAsync(IProvider provider, SettingsManager settings, Model chatModel, ChatThread chatChatThread, CancellationToken token = default);
public Task CreateFromProviderAsync(IProvider provider, SettingsManager settings, Model chatModel, IContent? lastPrompt, ChatThread? chatChatThread, CancellationToken token = default);
}

View File

@ -295,8 +295,14 @@ public partial class ChatComponent : MSGComponentBase, IAsyncDisposable
}
var time = DateTimeOffset.Now;
IContent? lastUserPrompt;
if (!reuseLastUserPrompt)
{
lastUserPrompt = new ContentText
{
Text = this.userInput,
};
//
// Add the user message to the thread:
//
@ -305,10 +311,7 @@ public partial class ChatComponent : MSGComponentBase, IAsyncDisposable
Time = time,
ContentType = ContentType.TEXT,
Role = ChatRole.USER,
Content = new ContentText
{
Text = this.userInput,
},
Content = lastUserPrompt,
});
// Save the chat:
@ -319,6 +322,8 @@ public partial class ChatComponent : MSGComponentBase, IAsyncDisposable
this.StateHasChanged();
}
}
else
lastUserPrompt = this.ChatThread.Blocks.Last(x => x.Role is ChatRole.USER).Content;
//
// Add the AI response to the thread:
@ -360,7 +365,7 @@ public partial class ChatComponent : MSGComponentBase, IAsyncDisposable
// Use the selected provider to get the AI response.
// By awaiting this line, we wait for the entire
// content to be streamed.
await aiText.CreateFromProviderAsync(this.Provider.CreateProvider(this.Logger), this.SettingsManager, this.Provider.Model, this.ChatThread, this.cancellationTokenSource.Token);
await aiText.CreateFromProviderAsync(this.Provider.CreateProvider(this.Logger), this.SettingsManager, this.Provider.Model, lastUserPrompt, this.ChatThread, this.cancellationTokenSource.Token);
}
this.cancellationTokenSource = null;

View File

@ -106,17 +106,19 @@ public partial class Writer : MSGComponentBase, IAsyncDisposable
};
var time = DateTimeOffset.Now;
var lastUserPrompt = new ContentText
{
// We use the maximum 160 characters from the end of the text:
Text = this.userInput.Length > 160 ? this.userInput[^160..] : this.userInput,
};
this.chatThread.Blocks.Clear();
this.chatThread.Blocks.Add(new ContentBlock
{
Time = time,
ContentType = ContentType.TEXT,
Role = ChatRole.USER,
Content = new ContentText
{
// We use the maximum 160 characters from the end of the text:
Text = this.userInput.Length > 160 ? this.userInput[^160..] : this.userInput,
},
Content = lastUserPrompt,
});
var aiText = new ContentText
@ -137,7 +139,7 @@ public partial class Writer : MSGComponentBase, IAsyncDisposable
this.isStreaming = true;
this.StateHasChanged();
await aiText.CreateFromProviderAsync(this.providerSettings.CreateProvider(this.Logger), this.SettingsManager, this.providerSettings.Model, this.chatThread);
await aiText.CreateFromProviderAsync(this.providerSettings.CreateProvider(this.Logger), this.SettingsManager, this.providerSettings.Model, lastUserPrompt, this.chatThread);
this.suggestion = aiText.Text;
this.isStreaming = false;