feat:AI工具新增ToolContext
This commit is contained in:
parent
911f5f8bf3
commit
dda2b56bbf
|
@ -4,10 +4,12 @@ import cn.hutool.core.collection.CollUtil;
|
||||||
import cn.hutool.core.util.ObjUtil;
|
import cn.hutool.core.util.ObjUtil;
|
||||||
import cn.hutool.core.util.StrUtil;
|
import cn.hutool.core.util.StrUtil;
|
||||||
import cn.iocoder.yudao.framework.ai.core.enums.AiPlatformEnum;
|
import cn.iocoder.yudao.framework.ai.core.enums.AiPlatformEnum;
|
||||||
|
import cn.iocoder.yudao.framework.ai.core.pojo.AiToolContext;
|
||||||
import cn.iocoder.yudao.framework.ai.core.util.AiUtils;
|
import cn.iocoder.yudao.framework.ai.core.util.AiUtils;
|
||||||
import cn.iocoder.yudao.framework.common.pojo.CommonResult;
|
import cn.iocoder.yudao.framework.common.pojo.CommonResult;
|
||||||
import cn.iocoder.yudao.framework.common.pojo.PageResult;
|
import cn.iocoder.yudao.framework.common.pojo.PageResult;
|
||||||
import cn.iocoder.yudao.framework.common.util.object.BeanUtils;
|
import cn.iocoder.yudao.framework.common.util.object.BeanUtils;
|
||||||
|
import cn.iocoder.yudao.framework.security.core.util.SecurityFrameworkUtils;
|
||||||
import cn.iocoder.yudao.framework.tenant.core.util.TenantUtils;
|
import cn.iocoder.yudao.framework.tenant.core.util.TenantUtils;
|
||||||
import cn.iocoder.yudao.module.ai.controller.admin.chat.vo.message.AiChatMessagePageReqVO;
|
import cn.iocoder.yudao.module.ai.controller.admin.chat.vo.message.AiChatMessagePageReqVO;
|
||||||
import cn.iocoder.yudao.module.ai.controller.admin.chat.vo.message.AiChatMessageRespVO;
|
import cn.iocoder.yudao.module.ai.controller.admin.chat.vo.message.AiChatMessageRespVO;
|
||||||
|
@ -115,7 +117,7 @@ public class AiChatMessageServiceImpl implements AiChatMessageService {
|
||||||
knowledgeSegments);
|
knowledgeSegments);
|
||||||
|
|
||||||
// 3.2 创建 chat 需要的 Prompt
|
// 3.2 创建 chat 需要的 Prompt
|
||||||
Prompt prompt = buildPrompt(conversation, historyMessages, knowledgeSegments, model, sendReqVO);
|
Prompt prompt = buildPrompt(chatModel, conversation, historyMessages, knowledgeSegments, model, sendReqVO);
|
||||||
ChatResponse chatResponse = chatModel.call(prompt);
|
ChatResponse chatResponse = chatModel.call(prompt);
|
||||||
|
|
||||||
// 3.3 更新响应内容
|
// 3.3 更新响应内容
|
||||||
|
@ -164,7 +166,7 @@ public class AiChatMessageServiceImpl implements AiChatMessageService {
|
||||||
knowledgeSegments);
|
knowledgeSegments);
|
||||||
|
|
||||||
// 4.2 构建 Prompt,并进行调用
|
// 4.2 构建 Prompt,并进行调用
|
||||||
Prompt prompt = buildPrompt(conversation, historyMessages, knowledgeSegments, model, sendReqVO);
|
Prompt prompt = buildPrompt(chatModel, conversation, historyMessages, knowledgeSegments, model, sendReqVO);
|
||||||
Flux<ChatResponse> streamResponse = chatModel.stream(prompt);
|
Flux<ChatResponse> streamResponse = chatModel.stream(prompt);
|
||||||
|
|
||||||
// 4.3 流式返回
|
// 4.3 流式返回
|
||||||
|
@ -220,7 +222,7 @@ public class AiChatMessageServiceImpl implements AiChatMessageService {
|
||||||
return knowledgeSegments;
|
return knowledgeSegments;
|
||||||
}
|
}
|
||||||
|
|
||||||
private Prompt buildPrompt(AiChatConversationDO conversation, List<AiChatMessageDO> messages,
|
private Prompt buildPrompt(StreamingChatModel chatModel, AiChatConversationDO conversation, List<AiChatMessageDO> messages,
|
||||||
List<AiKnowledgeSegmentSearchRespBO> knowledgeSegments,
|
List<AiKnowledgeSegmentSearchRespBO> knowledgeSegments,
|
||||||
AiModelDO model, AiChatMessageSendReqVO sendReqVO) {
|
AiModelDO model, AiChatMessageSendReqVO sendReqVO) {
|
||||||
List<Message> chatMessages = new ArrayList<>();
|
List<Message> chatMessages = new ArrayList<>();
|
||||||
|
@ -247,16 +249,22 @@ public class AiChatMessageServiceImpl implements AiChatMessageService {
|
||||||
|
|
||||||
// 2.1 查询 tool 工具
|
// 2.1 查询 tool 工具
|
||||||
Set<String> toolNames = null;
|
Set<String> toolNames = null;
|
||||||
|
Map<String,Object> toolContext = Map.of();
|
||||||
if (conversation.getRoleId() != null) {
|
if (conversation.getRoleId() != null) {
|
||||||
AiChatRoleDO chatRole = chatRoleService.getChatRole(conversation.getRoleId());
|
AiChatRoleDO chatRole = chatRoleService.getChatRole(conversation.getRoleId());
|
||||||
if (chatRole != null && CollUtil.isNotEmpty(chatRole.getToolIds())) {
|
if (chatRole != null && CollUtil.isNotEmpty(chatRole.getToolIds())) {
|
||||||
toolNames = convertSet(toolService.getToolList(chatRole.getToolIds()), AiToolDO::getName);
|
toolNames = convertSet(toolService.getToolList(chatRole.getToolIds()), AiToolDO::getName);
|
||||||
|
// 2.1.1 构建 Function Calling 的上下文参数
|
||||||
|
toolContext = Map.of(
|
||||||
|
AiToolContext.CONTEXT_KEY, new AiToolContext().setChatModel(chatModel).setUserId(SecurityFrameworkUtils.getLoginUserId())
|
||||||
|
.setRoleId(conversation.getRoleId())
|
||||||
|
.setConversationId(conversation.getId()));
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
// 2.2 构建 ChatOptions 对象
|
// 2.2 构建 ChatOptions 对象
|
||||||
AiPlatformEnum platform = AiPlatformEnum.validatePlatform(model.getPlatform());
|
AiPlatformEnum platform = AiPlatformEnum.validatePlatform(model.getPlatform());
|
||||||
ChatOptions chatOptions = AiUtils.buildChatOptions(platform, model.getModel(),
|
ChatOptions chatOptions = AiUtils.buildChatOptions(platform, model.getModel(),
|
||||||
conversation.getTemperature(), conversation.getMaxTokens(), toolNames);
|
conversation.getTemperature(), conversation.getMaxTokens(), toolNames, toolContext);
|
||||||
return new Prompt(chatMessages, chatOptions);
|
return new Prompt(chatMessages, chatOptions);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -0,0 +1,43 @@
|
||||||
|
package cn.iocoder.yudao.module.ai.service.model.tool;
|
||||||
|
|
||||||
|
import cn.iocoder.yudao.framework.ai.core.pojo.AiToolContext;
|
||||||
|
import com.fasterxml.jackson.annotation.JsonClassDescription;
|
||||||
|
import lombok.AllArgsConstructor;
|
||||||
|
import lombok.Data;
|
||||||
|
import lombok.NoArgsConstructor;
|
||||||
|
import org.springframework.ai.chat.model.ToolContext;
|
||||||
|
import org.springframework.stereotype.Component;
|
||||||
|
|
||||||
|
import java.util.function.BiFunction;
|
||||||
|
|
||||||
|
/**
|
||||||
|
* 工具:用户ID查询(上下文参数Demo)
|
||||||
|
*
|
||||||
|
* @author Ren
|
||||||
|
*/
|
||||||
|
@Component("userid_query")
|
||||||
|
public class UserIdQueryToolFunction
|
||||||
|
implements BiFunction<UserIdQueryToolFunction.Request, ToolContext, UserIdQueryToolFunction.Response> {
|
||||||
|
|
||||||
|
@Data
|
||||||
|
@JsonClassDescription("用户ID查询")
|
||||||
|
public static class Request { }
|
||||||
|
|
||||||
|
@Data
|
||||||
|
@AllArgsConstructor
|
||||||
|
@NoArgsConstructor
|
||||||
|
public static class Response {
|
||||||
|
/**
|
||||||
|
* 用户ID
|
||||||
|
*/
|
||||||
|
private Long UserId;
|
||||||
|
|
||||||
|
}
|
||||||
|
@Override
|
||||||
|
public UserIdQueryToolFunction.Response apply(UserIdQueryToolFunction.Request request, ToolContext toolContext) {
|
||||||
|
// 获取当前登录用户
|
||||||
|
AiToolContext context = (AiToolContext) toolContext.getContext().get(AiToolContext.CONTEXT_KEY);
|
||||||
|
|
||||||
|
return new Response(context.getUserId());
|
||||||
|
}
|
||||||
|
}
|
|
@ -0,0 +1,37 @@
|
||||||
|
package cn.iocoder.yudao.framework.ai.core.pojo;
|
||||||
|
|
||||||
|
import lombok.AllArgsConstructor;
|
||||||
|
import lombok.Data;
|
||||||
|
import lombok.NoArgsConstructor;
|
||||||
|
import org.springframework.ai.chat.model.ChatModel;
|
||||||
|
import org.springframework.ai.chat.model.StreamingChatModel;
|
||||||
|
|
||||||
|
/**
|
||||||
|
* 工具上下文参数 DTO,让AI工具可以处理当前用户的相关信息
|
||||||
|
*
|
||||||
|
*/
|
||||||
|
@Data
|
||||||
|
@NoArgsConstructor
|
||||||
|
@AllArgsConstructor
|
||||||
|
public class AiToolContext {
|
||||||
|
public static final String CONTEXT_KEY = "AI_TOOL_CONTEXT";
|
||||||
|
/**
|
||||||
|
* 用户ID
|
||||||
|
*/
|
||||||
|
private Long userId;
|
||||||
|
|
||||||
|
/**
|
||||||
|
* 聊天模型
|
||||||
|
*/
|
||||||
|
private StreamingChatModel chatModel;
|
||||||
|
|
||||||
|
/**
|
||||||
|
* 关联的聊天角色Id
|
||||||
|
*/
|
||||||
|
private Long roleId;
|
||||||
|
|
||||||
|
/**
|
||||||
|
* 会话Id
|
||||||
|
*/
|
||||||
|
private Long conversationId;
|
||||||
|
}
|
|
@ -15,6 +15,7 @@ import org.springframework.ai.qianfan.QianFanChatOptions;
|
||||||
import org.springframework.ai.zhipuai.ZhiPuAiChatOptions;
|
import org.springframework.ai.zhipuai.ZhiPuAiChatOptions;
|
||||||
|
|
||||||
import java.util.Collections;
|
import java.util.Collections;
|
||||||
|
import java.util.Map;
|
||||||
import java.util.Set;
|
import java.util.Set;
|
||||||
|
|
||||||
/**
|
/**
|
||||||
|
@ -25,28 +26,28 @@ import java.util.Set;
|
||||||
public class AiUtils {
|
public class AiUtils {
|
||||||
|
|
||||||
public static ChatOptions buildChatOptions(AiPlatformEnum platform, String model, Double temperature, Integer maxTokens) {
|
public static ChatOptions buildChatOptions(AiPlatformEnum platform, String model, Double temperature, Integer maxTokens) {
|
||||||
return buildChatOptions(platform, model, temperature, maxTokens, null);
|
return buildChatOptions(platform, model, temperature, maxTokens, null, Map.of());
|
||||||
}
|
}
|
||||||
|
|
||||||
public static ChatOptions buildChatOptions(AiPlatformEnum platform, String model, Double temperature, Integer maxTokens,
|
public static ChatOptions buildChatOptions(AiPlatformEnum platform, String model, Double temperature, Integer maxTokens,
|
||||||
Set<String> toolNames) {
|
Set<String> toolNames, Map<String, Object> toolContext) {
|
||||||
toolNames = ObjUtil.defaultIfNull(toolNames, Collections.emptySet());
|
toolNames = ObjUtil.defaultIfNull(toolNames, Collections.emptySet());
|
||||||
// noinspection EnhancedSwitchMigration
|
// noinspection EnhancedSwitchMigration
|
||||||
switch (platform) {
|
switch (platform) {
|
||||||
case TONG_YI:
|
case TONG_YI:
|
||||||
return DashScopeChatOptions.builder().withModel(model).withTemperature(temperature).withMaxToken(maxTokens)
|
return DashScopeChatOptions.builder().withModel(model).withTemperature(temperature).withMaxToken(maxTokens)
|
||||||
.withFunctions(toolNames).build();
|
.withFunctions(toolNames).withToolContext(toolContext).build();
|
||||||
case YI_YAN:
|
case YI_YAN:
|
||||||
return QianFanChatOptions.builder().model(model).temperature(temperature).maxTokens(maxTokens).build();
|
return QianFanChatOptions.builder().model(model).temperature(temperature).maxTokens(maxTokens).build();
|
||||||
case ZHI_PU:
|
case ZHI_PU:
|
||||||
return ZhiPuAiChatOptions.builder().model(model).temperature(temperature).maxTokens(maxTokens)
|
return ZhiPuAiChatOptions.builder().model(model).temperature(temperature).maxTokens(maxTokens)
|
||||||
.functions(toolNames).build();
|
.functions(toolNames).toolContext(toolContext).build();
|
||||||
case MINI_MAX:
|
case MINI_MAX:
|
||||||
return MiniMaxChatOptions.builder().model(model).temperature(temperature).maxTokens(maxTokens)
|
return MiniMaxChatOptions.builder().model(model).temperature(temperature).maxTokens(maxTokens)
|
||||||
.functions(toolNames).build();
|
.functions(toolNames).toolContext(toolContext).build();
|
||||||
case MOONSHOT:
|
case MOONSHOT:
|
||||||
return MoonshotChatOptions.builder().model(model).temperature(temperature).maxTokens(maxTokens)
|
return MoonshotChatOptions.builder().model(model).temperature(temperature).maxTokens(maxTokens)
|
||||||
.functions(toolNames).build();
|
.functions(toolNames).toolContext(toolContext).build();
|
||||||
case OPENAI:
|
case OPENAI:
|
||||||
case DEEP_SEEK: // 复用 OpenAI 客户端
|
case DEEP_SEEK: // 复用 OpenAI 客户端
|
||||||
case DOU_BAO: // 复用 OpenAI 客户端
|
case DOU_BAO: // 复用 OpenAI 客户端
|
||||||
|
@ -55,14 +56,14 @@ public class AiUtils {
|
||||||
case SILICON_FLOW: // 复用 OpenAI 客户端
|
case SILICON_FLOW: // 复用 OpenAI 客户端
|
||||||
case BAI_CHUAN: // 复用 OpenAI 客户端
|
case BAI_CHUAN: // 复用 OpenAI 客户端
|
||||||
return OpenAiChatOptions.builder().model(model).temperature(temperature).maxTokens(maxTokens)
|
return OpenAiChatOptions.builder().model(model).temperature(temperature).maxTokens(maxTokens)
|
||||||
.toolNames(toolNames).build();
|
.toolNames(toolNames).toolContext(toolContext).build();
|
||||||
case AZURE_OPENAI:
|
case AZURE_OPENAI:
|
||||||
// TODO 芋艿:貌似没 model 字段???!
|
// TODO 芋艿:貌似没 model 字段???!
|
||||||
return AzureOpenAiChatOptions.builder().deploymentName(model).temperature(temperature).maxTokens(maxTokens)
|
return AzureOpenAiChatOptions.builder().deploymentName(model).temperature(temperature).maxTokens(maxTokens)
|
||||||
.toolNames(toolNames).build();
|
.toolNames(toolNames).toolContext(toolContext).build();
|
||||||
case OLLAMA:
|
case OLLAMA:
|
||||||
return OllamaOptions.builder().model(model).temperature(temperature).numPredict(maxTokens)
|
return OllamaOptions.builder().model(model).temperature(temperature).numPredict(maxTokens)
|
||||||
.toolNames(toolNames).build();
|
.toolNames(toolNames).toolContext(toolContext).build();
|
||||||
default:
|
default:
|
||||||
throw new IllegalArgumentException(StrUtil.format("未知平台({})", platform));
|
throw new IllegalArgumentException(StrUtil.format("未知平台({})", platform));
|
||||||
}
|
}
|
||||||
|
|
Loading…
Reference in New Issue