mirror of
https://github.com/MindWorkAI/AI-Studio.git
synced 2025-03-12 14:09:07 +00:00
Improved the augmentation & generation of RAG (#314)
This commit is contained in:
parent
df01ce188e
commit
6a4a7dc0d6
@ -293,7 +293,7 @@ public abstract partial class AssistantBase : ComponentBase, IMessageBusReceiver
|
|||||||
// Use the selected provider to get the AI response.
|
// Use the selected provider to get the AI response.
|
||||||
// By awaiting this line, we wait for the entire
|
// By awaiting this line, we wait for the entire
|
||||||
// content to be streamed.
|
// content to be streamed.
|
||||||
await aiText.CreateFromProviderAsync(this.providerSettings.CreateProvider(this.Logger), this.providerSettings.Model, this.lastUserPrompt, this.chatThread);
|
this.chatThread = await aiText.CreateFromProviderAsync(this.providerSettings.CreateProvider(this.Logger), this.providerSettings.Model, this.lastUserPrompt, this.chatThread);
|
||||||
|
|
||||||
this.isProcessing = false;
|
this.isProcessing = false;
|
||||||
this.StateHasChanged();
|
this.StateHasChanged();
|
||||||
|
@ -12,7 +12,6 @@ public enum ChatRole
|
|||||||
USER,
|
USER,
|
||||||
AI,
|
AI,
|
||||||
AGENT,
|
AGENT,
|
||||||
RAG,
|
|
||||||
}
|
}
|
||||||
|
|
||||||
/// <summary>
|
/// <summary>
|
||||||
|
@ -40,6 +40,11 @@ public sealed record ChatThread
|
|||||||
/// </summary>
|
/// </summary>
|
||||||
public IReadOnlyList<DataSourceAgentSelected> AISelectedDataSources { get; set; } = [];
|
public IReadOnlyList<DataSourceAgentSelected> AISelectedDataSources { get; set; } = [];
|
||||||
|
|
||||||
|
/// <summary>
|
||||||
|
/// The augmented data for this chat thread. Will be inserted into the system prompt.
|
||||||
|
/// </summary>
|
||||||
|
public string AugmentedData { get; set; } = string.Empty;
|
||||||
|
|
||||||
/// <summary>
|
/// <summary>
|
||||||
/// The name of the chat thread. Usually generated by an AI model or manually edited by the user.
|
/// The name of the chat thread. Usually generated by an AI model or manually edited by the user.
|
||||||
/// </summary>
|
/// </summary>
|
||||||
@ -74,31 +79,48 @@ public sealed record ChatThread
|
|||||||
/// <returns>The prepared system prompt.</returns>
|
/// <returns>The prepared system prompt.</returns>
|
||||||
public string PrepareSystemPrompt(SettingsManager settingsManager, ChatThread chatThread, ILogger logger)
|
public string PrepareSystemPrompt(SettingsManager settingsManager, ChatThread chatThread, ILogger logger)
|
||||||
{
|
{
|
||||||
|
var isAugmentedDataAvailable = !string.IsNullOrWhiteSpace(chatThread.AugmentedData);
|
||||||
|
var systemPromptWithAugmentedData = isAugmentedDataAvailable switch
|
||||||
|
{
|
||||||
|
true => $"""
|
||||||
|
{chatThread.SystemPrompt}
|
||||||
|
|
||||||
|
{chatThread.AugmentedData}
|
||||||
|
""",
|
||||||
|
|
||||||
|
false => chatThread.SystemPrompt,
|
||||||
|
};
|
||||||
|
|
||||||
|
if(isAugmentedDataAvailable)
|
||||||
|
logger.LogInformation("Augmented data is available for the chat thread.");
|
||||||
|
else
|
||||||
|
logger.LogInformation("No augmented data is available for the chat thread.");
|
||||||
|
|
||||||
//
|
//
|
||||||
// Prepare the system prompt:
|
// Prepare the system prompt:
|
||||||
//
|
//
|
||||||
string systemPromptText;
|
string systemPromptText;
|
||||||
var logMessage = $"Using no profile for chat thread '{chatThread.Name}'.";
|
var logMessage = $"Using no profile for chat thread '{chatThread.Name}'.";
|
||||||
if (string.IsNullOrWhiteSpace(chatThread.SelectedProfile))
|
if (string.IsNullOrWhiteSpace(chatThread.SelectedProfile))
|
||||||
systemPromptText = chatThread.SystemPrompt;
|
systemPromptText = systemPromptWithAugmentedData;
|
||||||
else
|
else
|
||||||
{
|
{
|
||||||
if(!Guid.TryParse(chatThread.SelectedProfile, out var profileId))
|
if(!Guid.TryParse(chatThread.SelectedProfile, out var profileId))
|
||||||
systemPromptText = chatThread.SystemPrompt;
|
systemPromptText = systemPromptWithAugmentedData;
|
||||||
else
|
else
|
||||||
{
|
{
|
||||||
if(chatThread.SelectedProfile == Profile.NO_PROFILE.Id || profileId == Guid.Empty)
|
if(chatThread.SelectedProfile == Profile.NO_PROFILE.Id || profileId == Guid.Empty)
|
||||||
systemPromptText = chatThread.SystemPrompt;
|
systemPromptText = systemPromptWithAugmentedData;
|
||||||
else
|
else
|
||||||
{
|
{
|
||||||
var profile = settingsManager.ConfigurationData.Profiles.FirstOrDefault(x => x.Id == chatThread.SelectedProfile);
|
var profile = settingsManager.ConfigurationData.Profiles.FirstOrDefault(x => x.Id == chatThread.SelectedProfile);
|
||||||
if(profile == default)
|
if(profile == default)
|
||||||
systemPromptText = chatThread.SystemPrompt;
|
systemPromptText = systemPromptWithAugmentedData;
|
||||||
else
|
else
|
||||||
{
|
{
|
||||||
logMessage = $"Using profile '{profile.Name}' for chat thread '{chatThread.Name}'.";
|
logMessage = $"Using profile '{profile.Name}' for chat thread '{chatThread.Name}'.";
|
||||||
systemPromptText = $"""
|
systemPromptText = $"""
|
||||||
{chatThread.SystemPrompt}
|
{systemPromptWithAugmentedData}
|
||||||
|
|
||||||
{profile.ToSystemPrompt()}
|
{profile.ToSystemPrompt()}
|
||||||
""";
|
""";
|
||||||
|
@ -28,7 +28,7 @@ public sealed class ContentImage : IContent, IImageSource
|
|||||||
public Func<Task> StreamingEvent { get; set; } = () => Task.CompletedTask;
|
public Func<Task> StreamingEvent { get; set; } = () => Task.CompletedTask;
|
||||||
|
|
||||||
/// <inheritdoc />
|
/// <inheritdoc />
|
||||||
public Task CreateFromProviderAsync(IProvider provider, Model chatModel, IContent? lastPrompt, ChatThread? chatChatThread, CancellationToken token = default)
|
public Task<ChatThread> CreateFromProviderAsync(IProvider provider, Model chatModel, IContent? lastPrompt, ChatThread? chatChatThread, CancellationToken token = default)
|
||||||
{
|
{
|
||||||
throw new NotImplementedException();
|
throw new NotImplementedException();
|
||||||
}
|
}
|
||||||
|
@ -36,10 +36,10 @@ public sealed class ContentText : IContent
|
|||||||
public Func<Task> StreamingEvent { get; set; } = () => Task.CompletedTask;
|
public Func<Task> StreamingEvent { get; set; } = () => Task.CompletedTask;
|
||||||
|
|
||||||
/// <inheritdoc />
|
/// <inheritdoc />
|
||||||
public async Task CreateFromProviderAsync(IProvider provider, Model chatModel, IContent? lastPrompt, ChatThread? chatThread, CancellationToken token = default)
|
public async Task<ChatThread> CreateFromProviderAsync(IProvider provider, Model chatModel, IContent? lastPrompt, ChatThread? chatThread, CancellationToken token = default)
|
||||||
{
|
{
|
||||||
if(chatThread is null)
|
if(chatThread is null)
|
||||||
return;
|
return new();
|
||||||
|
|
||||||
// Call the RAG process. Right now, we only have one RAG process:
|
// Call the RAG process. Right now, we only have one RAG process:
|
||||||
if (lastPrompt is not null)
|
if (lastPrompt is not null)
|
||||||
@ -115,6 +115,7 @@ public sealed class ContentText : IContent
|
|||||||
|
|
||||||
// Inform the UI that the streaming is done:
|
// Inform the UI that the streaming is done:
|
||||||
await this.StreamingDone();
|
await this.StreamingDone();
|
||||||
|
return chatThread;
|
||||||
}
|
}
|
||||||
|
|
||||||
#endregion
|
#endregion
|
||||||
|
@ -41,7 +41,7 @@ public interface IContent
|
|||||||
/// <summary>
|
/// <summary>
|
||||||
/// Uses the provider to create the content.
|
/// Uses the provider to create the content.
|
||||||
/// </summary>
|
/// </summary>
|
||||||
public Task CreateFromProviderAsync(IProvider provider, Model chatModel, IContent? lastPrompt, ChatThread? chatChatThread, CancellationToken token = default);
|
public Task<ChatThread> CreateFromProviderAsync(IProvider provider, Model chatModel, IContent? lastPrompt, ChatThread? chatChatThread, CancellationToken token = default);
|
||||||
|
|
||||||
/// <summary>
|
/// <summary>
|
||||||
/// Returns the corresponding ERI content type.
|
/// Returns the corresponding ERI content type.
|
||||||
|
@ -475,7 +475,7 @@ public partial class ChatComponent : MSGComponentBase, IAsyncDisposable
|
|||||||
// Use the selected provider to get the AI response.
|
// Use the selected provider to get the AI response.
|
||||||
// By awaiting this line, we wait for the entire
|
// By awaiting this line, we wait for the entire
|
||||||
// content to be streamed.
|
// content to be streamed.
|
||||||
await aiText.CreateFromProviderAsync(this.Provider.CreateProvider(this.Logger), this.Provider.Model, lastUserPrompt, this.ChatThread, this.cancellationTokenSource.Token);
|
this.ChatThread = await aiText.CreateFromProviderAsync(this.Provider.CreateProvider(this.Logger), this.Provider.Model, lastUserPrompt, this.ChatThread, this.cancellationTokenSource.Token);
|
||||||
}
|
}
|
||||||
|
|
||||||
this.cancellationTokenSource = null;
|
this.cancellationTokenSource = null;
|
||||||
|
@ -139,7 +139,7 @@ public partial class Writer : MSGComponentBase, IAsyncDisposable
|
|||||||
this.isStreaming = true;
|
this.isStreaming = true;
|
||||||
this.StateHasChanged();
|
this.StateHasChanged();
|
||||||
|
|
||||||
await aiText.CreateFromProviderAsync(this.providerSettings.CreateProvider(this.Logger), this.providerSettings.Model, lastUserPrompt, this.chatThread);
|
this.chatThread = await aiText.CreateFromProviderAsync(this.providerSettings.CreateProvider(this.Logger), this.providerSettings.Model, lastUserPrompt, this.chatThread);
|
||||||
this.suggestion = aiText.Text;
|
this.suggestion = aiText.Text;
|
||||||
|
|
||||||
this.isStreaming = false;
|
this.isStreaming = false;
|
||||||
|
@ -38,7 +38,6 @@ public sealed class ProviderAnthropic(ILogger logger) : BaseProvider("https://ap
|
|||||||
ChatRole.USER => "user",
|
ChatRole.USER => "user",
|
||||||
ChatRole.AI => "assistant",
|
ChatRole.AI => "assistant",
|
||||||
ChatRole.AGENT => "assistant",
|
ChatRole.AGENT => "assistant",
|
||||||
ChatRole.RAG => "assistant",
|
|
||||||
|
|
||||||
_ => "user",
|
_ => "user",
|
||||||
},
|
},
|
||||||
|
@ -49,7 +49,6 @@ public sealed class ProviderDeepSeek(ILogger logger) : BaseProvider("https://api
|
|||||||
ChatRole.USER => "user",
|
ChatRole.USER => "user",
|
||||||
ChatRole.AI => "assistant",
|
ChatRole.AI => "assistant",
|
||||||
ChatRole.AGENT => "assistant",
|
ChatRole.AGENT => "assistant",
|
||||||
ChatRole.RAG => "assistant",
|
|
||||||
ChatRole.SYSTEM => "system",
|
ChatRole.SYSTEM => "system",
|
||||||
|
|
||||||
_ => "user",
|
_ => "user",
|
||||||
|
@ -49,7 +49,6 @@ public class ProviderFireworks(ILogger logger) : BaseProvider("https://api.firew
|
|||||||
ChatRole.AI => "assistant",
|
ChatRole.AI => "assistant",
|
||||||
ChatRole.AGENT => "assistant",
|
ChatRole.AGENT => "assistant",
|
||||||
ChatRole.SYSTEM => "system",
|
ChatRole.SYSTEM => "system",
|
||||||
ChatRole.RAG => "assistant",
|
|
||||||
|
|
||||||
_ => "user",
|
_ => "user",
|
||||||
},
|
},
|
||||||
|
@ -49,7 +49,6 @@ public sealed class ProviderGWDG(ILogger logger) : BaseProvider("https://chat-ai
|
|||||||
ChatRole.USER => "user",
|
ChatRole.USER => "user",
|
||||||
ChatRole.AI => "assistant",
|
ChatRole.AI => "assistant",
|
||||||
ChatRole.AGENT => "assistant",
|
ChatRole.AGENT => "assistant",
|
||||||
ChatRole.RAG => "assistant",
|
|
||||||
ChatRole.SYSTEM => "system",
|
ChatRole.SYSTEM => "system",
|
||||||
|
|
||||||
_ => "user",
|
_ => "user",
|
||||||
|
@ -50,7 +50,6 @@ public class ProviderGoogle(ILogger logger) : BaseProvider("https://generativela
|
|||||||
ChatRole.AI => "assistant",
|
ChatRole.AI => "assistant",
|
||||||
ChatRole.AGENT => "assistant",
|
ChatRole.AGENT => "assistant",
|
||||||
ChatRole.SYSTEM => "system",
|
ChatRole.SYSTEM => "system",
|
||||||
ChatRole.RAG => "assistant",
|
|
||||||
|
|
||||||
_ => "user",
|
_ => "user",
|
||||||
},
|
},
|
||||||
|
@ -50,7 +50,6 @@ public class ProviderGroq(ILogger logger) : BaseProvider("https://api.groq.com/o
|
|||||||
ChatRole.AI => "assistant",
|
ChatRole.AI => "assistant",
|
||||||
ChatRole.AGENT => "assistant",
|
ChatRole.AGENT => "assistant",
|
||||||
ChatRole.SYSTEM => "system",
|
ChatRole.SYSTEM => "system",
|
||||||
ChatRole.RAG => "assistant",
|
|
||||||
|
|
||||||
_ => "user",
|
_ => "user",
|
||||||
},
|
},
|
||||||
|
@ -49,7 +49,6 @@ public sealed class ProviderHelmholtz(ILogger logger) : BaseProvider("https://ap
|
|||||||
ChatRole.USER => "user",
|
ChatRole.USER => "user",
|
||||||
ChatRole.AI => "assistant",
|
ChatRole.AI => "assistant",
|
||||||
ChatRole.AGENT => "assistant",
|
ChatRole.AGENT => "assistant",
|
||||||
ChatRole.RAG => "assistant",
|
|
||||||
ChatRole.SYSTEM => "system",
|
ChatRole.SYSTEM => "system",
|
||||||
|
|
||||||
_ => "user",
|
_ => "user",
|
||||||
|
@ -48,7 +48,6 @@ public sealed class ProviderMistral(ILogger logger) : BaseProvider("https://api.
|
|||||||
ChatRole.AI => "assistant",
|
ChatRole.AI => "assistant",
|
||||||
ChatRole.AGENT => "assistant",
|
ChatRole.AGENT => "assistant",
|
||||||
ChatRole.SYSTEM => "system",
|
ChatRole.SYSTEM => "system",
|
||||||
ChatRole.RAG => "assistant",
|
|
||||||
|
|
||||||
_ => "user",
|
_ => "user",
|
||||||
},
|
},
|
||||||
|
@ -76,7 +76,6 @@ public sealed class ProviderOpenAI(ILogger logger) : BaseProvider("https://api.o
|
|||||||
ChatRole.USER => "user",
|
ChatRole.USER => "user",
|
||||||
ChatRole.AI => "assistant",
|
ChatRole.AI => "assistant",
|
||||||
ChatRole.AGENT => "assistant",
|
ChatRole.AGENT => "assistant",
|
||||||
ChatRole.RAG => "assistant",
|
|
||||||
ChatRole.SYSTEM => systemPromptRole,
|
ChatRole.SYSTEM => systemPromptRole,
|
||||||
|
|
||||||
_ => "user",
|
_ => "user",
|
||||||
|
@ -46,7 +46,6 @@ public sealed class ProviderSelfHosted(ILogger logger, Host host, string hostnam
|
|||||||
ChatRole.AI => "assistant",
|
ChatRole.AI => "assistant",
|
||||||
ChatRole.AGENT => "assistant",
|
ChatRole.AGENT => "assistant",
|
||||||
ChatRole.SYSTEM => "system",
|
ChatRole.SYSTEM => "system",
|
||||||
ChatRole.RAG => "assistant",
|
|
||||||
|
|
||||||
_ => "user",
|
_ => "user",
|
||||||
},
|
},
|
||||||
|
@ -50,7 +50,6 @@ public sealed class ProviderX(ILogger logger) : BaseProvider("https://api.x.ai/v
|
|||||||
ChatRole.AI => "assistant",
|
ChatRole.AI => "assistant",
|
||||||
ChatRole.AGENT => "assistant",
|
ChatRole.AGENT => "assistant",
|
||||||
ChatRole.SYSTEM => "system",
|
ChatRole.SYSTEM => "system",
|
||||||
ChatRole.RAG => "assistant",
|
|
||||||
|
|
||||||
_ => "user",
|
_ => "user",
|
||||||
},
|
},
|
||||||
|
@ -97,7 +97,7 @@ public readonly record struct DataSourceERI_V1 : IERIDataSource
|
|||||||
Links = eriContext.Links,
|
Links = eriContext.Links,
|
||||||
Category = eriContext.Type.ToRetrievalContentCategory(),
|
Category = eriContext.Type.ToRetrievalContentCategory(),
|
||||||
MatchedText = eriContext.MatchedContent,
|
MatchedText = eriContext.MatchedContent,
|
||||||
DataSourceName = eriContext.Name,
|
DataSourceName = this.Name,
|
||||||
SurroundingContent = eriContext.SurroundingContent,
|
SurroundingContent = eriContext.SurroundingContent,
|
||||||
});
|
});
|
||||||
break;
|
break;
|
||||||
@ -111,7 +111,7 @@ public readonly record struct DataSourceERI_V1 : IERIDataSource
|
|||||||
Source = eriContext.MatchedContent,
|
Source = eriContext.MatchedContent,
|
||||||
Category = eriContext.Type.ToRetrievalContentCategory(),
|
Category = eriContext.Type.ToRetrievalContentCategory(),
|
||||||
SourceType = ContentImageSource.BASE64,
|
SourceType = ContentImageSource.BASE64,
|
||||||
DataSourceName = eriContext.Name,
|
DataSourceName = this.Name,
|
||||||
});
|
});
|
||||||
break;
|
break;
|
||||||
|
|
||||||
|
@ -66,22 +66,8 @@ public sealed class AugmentationOne : IAugmentationProcess
|
|||||||
// Let's convert all retrieval contexts to Markdown:
|
// Let's convert all retrieval contexts to Markdown:
|
||||||
await retrievalContexts.AsMarkdown(sb, token);
|
await retrievalContexts.AsMarkdown(sb, token);
|
||||||
|
|
||||||
//
|
// Add the augmented data to the chat thread:
|
||||||
// Append the entire augmentation to the chat thread,
|
chatThread.AugmentedData = sb.ToString();
|
||||||
// just before the user prompt:
|
|
||||||
//
|
|
||||||
chatThread.Blocks.Insert(chatThread.Blocks.Count - 1, new()
|
|
||||||
{
|
|
||||||
Role = ChatRole.RAG,
|
|
||||||
Time = DateTimeOffset.UtcNow,
|
|
||||||
ContentType = ContentType.TEXT,
|
|
||||||
HideFromUser = true,
|
|
||||||
Content = new ContentText
|
|
||||||
{
|
|
||||||
Text = sb.ToString(),
|
|
||||||
}
|
|
||||||
});
|
|
||||||
|
|
||||||
return chatThread;
|
return chatThread;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -38,6 +38,30 @@ public sealed class AISrcSelWithRetCtxVal : IRagProcess
|
|||||||
// makes sense to proceed with the RAG process:
|
// makes sense to proceed with the RAG process:
|
||||||
var proceedWithRAG = true;
|
var proceedWithRAG = true;
|
||||||
|
|
||||||
|
//
|
||||||
|
// We read the last block in the chat thread. We need to re-arrange
|
||||||
|
// the order of blocks later, after the augmentation process takes
|
||||||
|
// place:
|
||||||
|
//
|
||||||
|
if(chatThread.Blocks.Count == 0)
|
||||||
|
{
|
||||||
|
logger.LogError("The chat thread is empty. Skipping the RAG process.");
|
||||||
|
return chatThread;
|
||||||
|
}
|
||||||
|
|
||||||
|
if (chatThread.Blocks.Last().Role != ChatRole.AI)
|
||||||
|
{
|
||||||
|
logger.LogError("The last block in the chat thread is not the AI block. There is something wrong with the chat thread. Skipping the RAG process.");
|
||||||
|
return chatThread;
|
||||||
|
}
|
||||||
|
|
||||||
|
//
|
||||||
|
// At this point in time, the chat thread contains already the
|
||||||
|
// last block, which is the waiting AI block. We need to remove
|
||||||
|
// this block before we call some parts of the RAG process:
|
||||||
|
//
|
||||||
|
var chatThreadWithoutWaitingAIBlock = chatThread with { Blocks = chatThread.Blocks[..^1] };
|
||||||
|
|
||||||
//
|
//
|
||||||
// When the user wants to bind data sources to the chat, we
|
// When the user wants to bind data sources to the chat, we
|
||||||
// have to check if the data sources are available for the
|
// have to check if the data sources are available for the
|
||||||
@ -84,7 +108,7 @@ public sealed class AISrcSelWithRetCtxVal : IRagProcess
|
|||||||
//
|
//
|
||||||
var retrievalTasks = new List<Task<IReadOnlyList<IRetrievalContext>>>(selectedDataSources.Count);
|
var retrievalTasks = new List<Task<IReadOnlyList<IRetrievalContext>>>(selectedDataSources.Count);
|
||||||
foreach (var dataSource in selectedDataSources)
|
foreach (var dataSource in selectedDataSources)
|
||||||
retrievalTasks.Add(dataSource.RetrieveDataAsync(lastPrompt, chatThread, token));
|
retrievalTasks.Add(dataSource.RetrieveDataAsync(lastPrompt, chatThreadWithoutWaitingAIBlock, token));
|
||||||
|
|
||||||
//
|
//
|
||||||
// Wait for all retrieval tasks to finish:
|
// Wait for all retrieval tasks to finish:
|
||||||
|
@ -4,3 +4,7 @@
|
|||||||
- Improved the ERI client to raise an error when the server responds with additional JSON data that is not expected.
|
- Improved the ERI client to raise an error when the server responds with additional JSON data that is not expected.
|
||||||
- Improved the error handling in the ERI data source info dialog in cases where servers respond with an invalid message.
|
- Improved the error handling in the ERI data source info dialog in cases where servers respond with an invalid message.
|
||||||
- Improved the error handling for the entire RAG process.
|
- Improved the error handling for the entire RAG process.
|
||||||
|
- Improved chat thread persistence after modifications through the RAG process.
|
||||||
|
- Improved the augmentation and generation part of RAG by passing the augmented data into the system prompt.
|
||||||
|
- Fixed the chat thread we use for the data retrieval by removing the last block, which is meant to be for the final AI answer.
|
||||||
|
- Fixed the data source name for ERI data sources when performing data retrieval.
|
Loading…
Reference in New Issue
Block a user