Improved the augmentation & generation of RAG (#314)

This commit is contained in:
Thorsten Sommer 2025-03-08 13:56:38 +01:00 committed by GitHub
parent df01ce188e
commit 6a4a7dc0d6
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
23 changed files with 69 additions and 44 deletions

View File

@ -293,7 +293,7 @@ public abstract partial class AssistantBase : ComponentBase, IMessageBusReceiver
// Use the selected provider to get the AI response. // Use the selected provider to get the AI response.
// By awaiting this line, we wait for the entire // By awaiting this line, we wait for the entire
// content to be streamed. // content to be streamed.
await aiText.CreateFromProviderAsync(this.providerSettings.CreateProvider(this.Logger), this.providerSettings.Model, this.lastUserPrompt, this.chatThread); this.chatThread = await aiText.CreateFromProviderAsync(this.providerSettings.CreateProvider(this.Logger), this.providerSettings.Model, this.lastUserPrompt, this.chatThread);
this.isProcessing = false; this.isProcessing = false;
this.StateHasChanged(); this.StateHasChanged();

View File

@ -12,7 +12,6 @@ public enum ChatRole
USER, USER,
AI, AI,
AGENT, AGENT,
RAG,
} }
/// <summary> /// <summary>

View File

@ -40,6 +40,11 @@ public sealed record ChatThread
/// </summary> /// </summary>
public IReadOnlyList<DataSourceAgentSelected> AISelectedDataSources { get; set; } = []; public IReadOnlyList<DataSourceAgentSelected> AISelectedDataSources { get; set; } = [];
/// <summary>
/// The augmented data for this chat thread. Will be inserted into the system prompt.
/// </summary>
public string AugmentedData { get; set; } = string.Empty;
/// <summary> /// <summary>
/// The name of the chat thread. Usually generated by an AI model or manually edited by the user. /// The name of the chat thread. Usually generated by an AI model or manually edited by the user.
/// </summary> /// </summary>
@ -74,31 +79,48 @@ public sealed record ChatThread
/// <returns>The prepared system prompt.</returns> /// <returns>The prepared system prompt.</returns>
public string PrepareSystemPrompt(SettingsManager settingsManager, ChatThread chatThread, ILogger logger) public string PrepareSystemPrompt(SettingsManager settingsManager, ChatThread chatThread, ILogger logger)
{ {
var isAugmentedDataAvailable = !string.IsNullOrWhiteSpace(chatThread.AugmentedData);
var systemPromptWithAugmentedData = isAugmentedDataAvailable switch
{
true => $"""
{chatThread.SystemPrompt}
{chatThread.AugmentedData}
""",
false => chatThread.SystemPrompt,
};
if(isAugmentedDataAvailable)
logger.LogInformation("Augmented data is available for the chat thread.");
else
logger.LogInformation("No augmented data is available for the chat thread.");
// //
// Prepare the system prompt: // Prepare the system prompt:
// //
string systemPromptText; string systemPromptText;
var logMessage = $"Using no profile for chat thread '{chatThread.Name}'."; var logMessage = $"Using no profile for chat thread '{chatThread.Name}'.";
if (string.IsNullOrWhiteSpace(chatThread.SelectedProfile)) if (string.IsNullOrWhiteSpace(chatThread.SelectedProfile))
systemPromptText = chatThread.SystemPrompt; systemPromptText = systemPromptWithAugmentedData;
else else
{ {
if(!Guid.TryParse(chatThread.SelectedProfile, out var profileId)) if(!Guid.TryParse(chatThread.SelectedProfile, out var profileId))
systemPromptText = chatThread.SystemPrompt; systemPromptText = systemPromptWithAugmentedData;
else else
{ {
if(chatThread.SelectedProfile == Profile.NO_PROFILE.Id || profileId == Guid.Empty) if(chatThread.SelectedProfile == Profile.NO_PROFILE.Id || profileId == Guid.Empty)
systemPromptText = chatThread.SystemPrompt; systemPromptText = systemPromptWithAugmentedData;
else else
{ {
var profile = settingsManager.ConfigurationData.Profiles.FirstOrDefault(x => x.Id == chatThread.SelectedProfile); var profile = settingsManager.ConfigurationData.Profiles.FirstOrDefault(x => x.Id == chatThread.SelectedProfile);
if(profile == default) if(profile == default)
systemPromptText = chatThread.SystemPrompt; systemPromptText = systemPromptWithAugmentedData;
else else
{ {
logMessage = $"Using profile '{profile.Name}' for chat thread '{chatThread.Name}'."; logMessage = $"Using profile '{profile.Name}' for chat thread '{chatThread.Name}'.";
systemPromptText = $""" systemPromptText = $"""
{chatThread.SystemPrompt} {systemPromptWithAugmentedData}
{profile.ToSystemPrompt()} {profile.ToSystemPrompt()}
"""; """;

View File

@ -28,7 +28,7 @@ public sealed class ContentImage : IContent, IImageSource
public Func<Task> StreamingEvent { get; set; } = () => Task.CompletedTask; public Func<Task> StreamingEvent { get; set; } = () => Task.CompletedTask;
/// <inheritdoc /> /// <inheritdoc />
public Task CreateFromProviderAsync(IProvider provider, Model chatModel, IContent? lastPrompt, ChatThread? chatChatThread, CancellationToken token = default) public Task<ChatThread> CreateFromProviderAsync(IProvider provider, Model chatModel, IContent? lastPrompt, ChatThread? chatChatThread, CancellationToken token = default)
{ {
throw new NotImplementedException(); throw new NotImplementedException();
} }

View File

@ -36,10 +36,10 @@ public sealed class ContentText : IContent
public Func<Task> StreamingEvent { get; set; } = () => Task.CompletedTask; public Func<Task> StreamingEvent { get; set; } = () => Task.CompletedTask;
/// <inheritdoc /> /// <inheritdoc />
public async Task CreateFromProviderAsync(IProvider provider, Model chatModel, IContent? lastPrompt, ChatThread? chatThread, CancellationToken token = default) public async Task<ChatThread> CreateFromProviderAsync(IProvider provider, Model chatModel, IContent? lastPrompt, ChatThread? chatThread, CancellationToken token = default)
{ {
if(chatThread is null) if(chatThread is null)
return; return new();
// Call the RAG process. Right now, we only have one RAG process: // Call the RAG process. Right now, we only have one RAG process:
if (lastPrompt is not null) if (lastPrompt is not null)
@ -115,6 +115,7 @@ public sealed class ContentText : IContent
// Inform the UI that the streaming is done: // Inform the UI that the streaming is done:
await this.StreamingDone(); await this.StreamingDone();
return chatThread;
} }
#endregion #endregion

View File

@ -41,7 +41,7 @@ public interface IContent
/// <summary> /// <summary>
/// Uses the provider to create the content. /// Uses the provider to create the content.
/// </summary> /// </summary>
public Task CreateFromProviderAsync(IProvider provider, Model chatModel, IContent? lastPrompt, ChatThread? chatChatThread, CancellationToken token = default); public Task<ChatThread> CreateFromProviderAsync(IProvider provider, Model chatModel, IContent? lastPrompt, ChatThread? chatChatThread, CancellationToken token = default);
/// <summary> /// <summary>
/// Returns the corresponding ERI content type. /// Returns the corresponding ERI content type.

View File

@ -475,7 +475,7 @@ public partial class ChatComponent : MSGComponentBase, IAsyncDisposable
// Use the selected provider to get the AI response. // Use the selected provider to get the AI response.
// By awaiting this line, we wait for the entire // By awaiting this line, we wait for the entire
// content to be streamed. // content to be streamed.
await aiText.CreateFromProviderAsync(this.Provider.CreateProvider(this.Logger), this.Provider.Model, lastUserPrompt, this.ChatThread, this.cancellationTokenSource.Token); this.ChatThread = await aiText.CreateFromProviderAsync(this.Provider.CreateProvider(this.Logger), this.Provider.Model, lastUserPrompt, this.ChatThread, this.cancellationTokenSource.Token);
} }
this.cancellationTokenSource = null; this.cancellationTokenSource = null;

View File

@ -139,7 +139,7 @@ public partial class Writer : MSGComponentBase, IAsyncDisposable
this.isStreaming = true; this.isStreaming = true;
this.StateHasChanged(); this.StateHasChanged();
await aiText.CreateFromProviderAsync(this.providerSettings.CreateProvider(this.Logger), this.providerSettings.Model, lastUserPrompt, this.chatThread); this.chatThread = await aiText.CreateFromProviderAsync(this.providerSettings.CreateProvider(this.Logger), this.providerSettings.Model, lastUserPrompt, this.chatThread);
this.suggestion = aiText.Text; this.suggestion = aiText.Text;
this.isStreaming = false; this.isStreaming = false;

View File

@ -38,7 +38,6 @@ public sealed class ProviderAnthropic(ILogger logger) : BaseProvider("https://ap
ChatRole.USER => "user", ChatRole.USER => "user",
ChatRole.AI => "assistant", ChatRole.AI => "assistant",
ChatRole.AGENT => "assistant", ChatRole.AGENT => "assistant",
ChatRole.RAG => "assistant",
_ => "user", _ => "user",
}, },

View File

@ -49,7 +49,6 @@ public sealed class ProviderDeepSeek(ILogger logger) : BaseProvider("https://api
ChatRole.USER => "user", ChatRole.USER => "user",
ChatRole.AI => "assistant", ChatRole.AI => "assistant",
ChatRole.AGENT => "assistant", ChatRole.AGENT => "assistant",
ChatRole.RAG => "assistant",
ChatRole.SYSTEM => "system", ChatRole.SYSTEM => "system",
_ => "user", _ => "user",

View File

@ -49,7 +49,6 @@ public class ProviderFireworks(ILogger logger) : BaseProvider("https://api.firew
ChatRole.AI => "assistant", ChatRole.AI => "assistant",
ChatRole.AGENT => "assistant", ChatRole.AGENT => "assistant",
ChatRole.SYSTEM => "system", ChatRole.SYSTEM => "system",
ChatRole.RAG => "assistant",
_ => "user", _ => "user",
}, },

View File

@ -49,7 +49,6 @@ public sealed class ProviderGWDG(ILogger logger) : BaseProvider("https://chat-ai
ChatRole.USER => "user", ChatRole.USER => "user",
ChatRole.AI => "assistant", ChatRole.AI => "assistant",
ChatRole.AGENT => "assistant", ChatRole.AGENT => "assistant",
ChatRole.RAG => "assistant",
ChatRole.SYSTEM => "system", ChatRole.SYSTEM => "system",
_ => "user", _ => "user",

View File

@ -50,7 +50,6 @@ public class ProviderGoogle(ILogger logger) : BaseProvider("https://generativela
ChatRole.AI => "assistant", ChatRole.AI => "assistant",
ChatRole.AGENT => "assistant", ChatRole.AGENT => "assistant",
ChatRole.SYSTEM => "system", ChatRole.SYSTEM => "system",
ChatRole.RAG => "assistant",
_ => "user", _ => "user",
}, },

View File

@ -50,7 +50,6 @@ public class ProviderGroq(ILogger logger) : BaseProvider("https://api.groq.com/o
ChatRole.AI => "assistant", ChatRole.AI => "assistant",
ChatRole.AGENT => "assistant", ChatRole.AGENT => "assistant",
ChatRole.SYSTEM => "system", ChatRole.SYSTEM => "system",
ChatRole.RAG => "assistant",
_ => "user", _ => "user",
}, },

View File

@ -49,7 +49,6 @@ public sealed class ProviderHelmholtz(ILogger logger) : BaseProvider("https://ap
ChatRole.USER => "user", ChatRole.USER => "user",
ChatRole.AI => "assistant", ChatRole.AI => "assistant",
ChatRole.AGENT => "assistant", ChatRole.AGENT => "assistant",
ChatRole.RAG => "assistant",
ChatRole.SYSTEM => "system", ChatRole.SYSTEM => "system",
_ => "user", _ => "user",

View File

@ -48,7 +48,6 @@ public sealed class ProviderMistral(ILogger logger) : BaseProvider("https://api.
ChatRole.AI => "assistant", ChatRole.AI => "assistant",
ChatRole.AGENT => "assistant", ChatRole.AGENT => "assistant",
ChatRole.SYSTEM => "system", ChatRole.SYSTEM => "system",
ChatRole.RAG => "assistant",
_ => "user", _ => "user",
}, },

View File

@ -76,7 +76,6 @@ public sealed class ProviderOpenAI(ILogger logger) : BaseProvider("https://api.o
ChatRole.USER => "user", ChatRole.USER => "user",
ChatRole.AI => "assistant", ChatRole.AI => "assistant",
ChatRole.AGENT => "assistant", ChatRole.AGENT => "assistant",
ChatRole.RAG => "assistant",
ChatRole.SYSTEM => systemPromptRole, ChatRole.SYSTEM => systemPromptRole,
_ => "user", _ => "user",

View File

@ -46,7 +46,6 @@ public sealed class ProviderSelfHosted(ILogger logger, Host host, string hostnam
ChatRole.AI => "assistant", ChatRole.AI => "assistant",
ChatRole.AGENT => "assistant", ChatRole.AGENT => "assistant",
ChatRole.SYSTEM => "system", ChatRole.SYSTEM => "system",
ChatRole.RAG => "assistant",
_ => "user", _ => "user",
}, },

View File

@ -50,7 +50,6 @@ public sealed class ProviderX(ILogger logger) : BaseProvider("https://api.x.ai/v
ChatRole.AI => "assistant", ChatRole.AI => "assistant",
ChatRole.AGENT => "assistant", ChatRole.AGENT => "assistant",
ChatRole.SYSTEM => "system", ChatRole.SYSTEM => "system",
ChatRole.RAG => "assistant",
_ => "user", _ => "user",
}, },

View File

@ -97,7 +97,7 @@ public readonly record struct DataSourceERI_V1 : IERIDataSource
Links = eriContext.Links, Links = eriContext.Links,
Category = eriContext.Type.ToRetrievalContentCategory(), Category = eriContext.Type.ToRetrievalContentCategory(),
MatchedText = eriContext.MatchedContent, MatchedText = eriContext.MatchedContent,
DataSourceName = eriContext.Name, DataSourceName = this.Name,
SurroundingContent = eriContext.SurroundingContent, SurroundingContent = eriContext.SurroundingContent,
}); });
break; break;
@ -111,7 +111,7 @@ public readonly record struct DataSourceERI_V1 : IERIDataSource
Source = eriContext.MatchedContent, Source = eriContext.MatchedContent,
Category = eriContext.Type.ToRetrievalContentCategory(), Category = eriContext.Type.ToRetrievalContentCategory(),
SourceType = ContentImageSource.BASE64, SourceType = ContentImageSource.BASE64,
DataSourceName = eriContext.Name, DataSourceName = this.Name,
}); });
break; break;

View File

@ -66,22 +66,8 @@ public sealed class AugmentationOne : IAugmentationProcess
// Let's convert all retrieval contexts to Markdown: // Let's convert all retrieval contexts to Markdown:
await retrievalContexts.AsMarkdown(sb, token); await retrievalContexts.AsMarkdown(sb, token);
// // Add the augmented data to the chat thread:
// Append the entire augmentation to the chat thread, chatThread.AugmentedData = sb.ToString();
// just before the user prompt:
//
chatThread.Blocks.Insert(chatThread.Blocks.Count - 1, new()
{
Role = ChatRole.RAG,
Time = DateTimeOffset.UtcNow,
ContentType = ContentType.TEXT,
HideFromUser = true,
Content = new ContentText
{
Text = sb.ToString(),
}
});
return chatThread; return chatThread;
} }

View File

@ -38,6 +38,30 @@ public sealed class AISrcSelWithRetCtxVal : IRagProcess
// makes sense to proceed with the RAG process: // makes sense to proceed with the RAG process:
var proceedWithRAG = true; var proceedWithRAG = true;
//
// We read the last block in the chat thread. We need to re-arrange
// the order of blocks later, after the augmentation process takes
// place:
//
if(chatThread.Blocks.Count == 0)
{
logger.LogError("The chat thread is empty. Skipping the RAG process.");
return chatThread;
}
if (chatThread.Blocks.Last().Role != ChatRole.AI)
{
logger.LogError("The last block in the chat thread is not the AI block. There is something wrong with the chat thread. Skipping the RAG process.");
return chatThread;
}
//
// At this point, the chat thread already contains the last
// block, which is the waiting AI block. We need to remove
// this block before we call some parts of the RAG process:
//
var chatThreadWithoutWaitingAIBlock = chatThread with { Blocks = chatThread.Blocks[..^1] };
// //
// When the user wants to bind data sources to the chat, we // When the user wants to bind data sources to the chat, we
// have to check if the data sources are available for the // have to check if the data sources are available for the
@ -84,7 +108,7 @@ public sealed class AISrcSelWithRetCtxVal : IRagProcess
// //
var retrievalTasks = new List<Task<IReadOnlyList<IRetrievalContext>>>(selectedDataSources.Count); var retrievalTasks = new List<Task<IReadOnlyList<IRetrievalContext>>>(selectedDataSources.Count);
foreach (var dataSource in selectedDataSources) foreach (var dataSource in selectedDataSources)
retrievalTasks.Add(dataSource.RetrieveDataAsync(lastPrompt, chatThread, token)); retrievalTasks.Add(dataSource.RetrieveDataAsync(lastPrompt, chatThreadWithoutWaitingAIBlock, token));
// //
// Wait for all retrieval tasks to finish: // Wait for all retrieval tasks to finish:

View File

@ -4,3 +4,7 @@
- Improved the ERI client to raise an error when the server responds with additional JSON data that is not expected. - Improved the ERI client to raise an error when the server responds with additional JSON data that is not expected.
- Improved the error handling in the ERI data source info dialog in cases where servers respond with an invalid message. - Improved the error handling in the ERI data source info dialog in cases where servers respond with an invalid message.
- Improved the error handling for the entire RAG process. - Improved the error handling for the entire RAG process.
- Improved chat thread persistence after modifications through the RAG process.
- Improved the augmentation and generation part of RAG by passing the augmented data into the system prompt.
- Fixed the chat thread used for data retrieval by removing its last block, which is reserved for the final AI answer.
- Fixed the data source name for ERI data sources when performing data retrieval.