diff --git a/yudao-module-ai/yudao-module-ai-biz/src/main/java/cn/iocoder/yudao/module/ai/service/chat/AiChatMessageServiceImpl.java b/yudao-module-ai/yudao-module-ai-biz/src/main/java/cn/iocoder/yudao/module/ai/service/chat/AiChatMessageServiceImpl.java index f310ba69fd..36dcaa999c 100644 --- a/yudao-module-ai/yudao-module-ai-biz/src/main/java/cn/iocoder/yudao/module/ai/service/chat/AiChatMessageServiceImpl.java +++ b/yudao-module-ai/yudao-module-ai-biz/src/main/java/cn/iocoder/yudao/module/ai/service/chat/AiChatMessageServiceImpl.java @@ -4,10 +4,12 @@ import cn.hutool.core.collection.CollUtil; import cn.hutool.core.util.ObjUtil; import cn.hutool.core.util.StrUtil; import cn.iocoder.yudao.framework.ai.core.enums.AiPlatformEnum; +import cn.iocoder.yudao.framework.ai.core.pojo.AiToolContext; import cn.iocoder.yudao.framework.ai.core.util.AiUtils; import cn.iocoder.yudao.framework.common.pojo.CommonResult; import cn.iocoder.yudao.framework.common.pojo.PageResult; import cn.iocoder.yudao.framework.common.util.object.BeanUtils; +import cn.iocoder.yudao.framework.security.core.util.SecurityFrameworkUtils; import cn.iocoder.yudao.framework.tenant.core.util.TenantUtils; import cn.iocoder.yudao.module.ai.controller.admin.chat.vo.message.AiChatMessagePageReqVO; import cn.iocoder.yudao.module.ai.controller.admin.chat.vo.message.AiChatMessageRespVO; @@ -115,7 +117,7 @@ public class AiChatMessageServiceImpl implements AiChatMessageService { knowledgeSegments); // 3.2 创建 chat 需要的 Prompt - Prompt prompt = buildPrompt(conversation, historyMessages, knowledgeSegments, model, sendReqVO); + Prompt prompt = buildPrompt(chatModel, conversation, historyMessages, knowledgeSegments, model, sendReqVO); ChatResponse chatResponse = chatModel.call(prompt); // 3.3 更新响应内容 @@ -164,7 +166,7 @@ public class AiChatMessageServiceImpl implements AiChatMessageService { knowledgeSegments); // 4.2 构建 Prompt,并进行调用 - Prompt prompt = buildPrompt(conversation, historyMessages, knowledgeSegments, model, sendReqVO); + Prompt prompt = buildPrompt(chatModel, conversation, historyMessages, knowledgeSegments, model, sendReqVO); Flux streamResponse = chatModel.stream(prompt); // 4.3 流式返回 @@ -220,7 +222,7 @@ public class AiChatMessageServiceImpl implements AiChatMessageService { return knowledgeSegments; } - private Prompt buildPrompt(AiChatConversationDO conversation, List messages, + private Prompt buildPrompt(StreamingChatModel chatModel, AiChatConversationDO conversation, List messages, List knowledgeSegments, AiModelDO model, AiChatMessageSendReqVO sendReqVO) { List chatMessages = new ArrayList<>(); @@ -247,16 +249,22 @@ public class AiChatMessageServiceImpl implements AiChatMessageService { // 2.1 查询 tool 工具 Set toolNames = null; + Map toolContext = Map.of(); if (conversation.getRoleId() != null) { AiChatRoleDO chatRole = chatRoleService.getChatRole(conversation.getRoleId()); if (chatRole != null && CollUtil.isNotEmpty(chatRole.getToolIds())) { toolNames = convertSet(toolService.getToolList(chatRole.getToolIds()), AiToolDO::getName); + // 2.1.1 构建 Function Calling 的上下文参数 + toolContext = Map.of( + AiToolContext.CONTEXT_KEY, new AiToolContext().setChatModel(chatModel).setUserId(SecurityFrameworkUtils.getLoginUserId()) + .setRoleId(conversation.getRoleId()) + .setConversationId(conversation.getId())); } } // 2.2 构建 ChatOptions 对象 AiPlatformEnum platform = AiPlatformEnum.validatePlatform(model.getPlatform()); ChatOptions chatOptions = AiUtils.buildChatOptions(platform, model.getModel(), - conversation.getTemperature(), conversation.getMaxTokens(), toolNames); + conversation.getTemperature(), conversation.getMaxTokens(), toolNames, toolContext); return new Prompt(chatMessages, chatOptions); } diff --git a/yudao-module-ai/yudao-module-ai-biz/src/main/java/cn/iocoder/yudao/module/ai/service/model/tool/UserIdQueryToolFunction.java b/yudao-module-ai/yudao-module-ai-biz/src/main/java/cn/iocoder/yudao/module/ai/service/model/tool/UserIdQueryToolFunction.java new file mode 100644 index 0000000000..380fdba414 --- /dev/null +++ b/yudao-module-ai/yudao-module-ai-biz/src/main/java/cn/iocoder/yudao/module/ai/service/model/tool/UserIdQueryToolFunction.java @@ -0,0 +1,43 @@ +package cn.iocoder.yudao.module.ai.service.model.tool; + +import cn.iocoder.yudao.framework.ai.core.pojo.AiToolContext; +import com.fasterxml.jackson.annotation.JsonClassDescription; +import lombok.AllArgsConstructor; +import lombok.Data; +import lombok.NoArgsConstructor; +import org.springframework.ai.chat.model.ToolContext; +import org.springframework.stereotype.Component; + +import java.util.function.BiFunction; + +/** + * 工具:用户ID查询(上下文参数Demo) + * + * @author Ren + */ +@Component("userid_query") +public class UserIdQueryToolFunction + implements BiFunction { + + @Data + @JsonClassDescription("用户ID查询") + public static class Request { } + + @Data + @AllArgsConstructor + @NoArgsConstructor + public static class Response { + /** + * 用户ID + */ + private Long UserId; + + } + @Override + public UserIdQueryToolFunction.Response apply(UserIdQueryToolFunction.Request request, ToolContext toolContext) { + // 获取当前登录用户 + AiToolContext context = (AiToolContext) toolContext.getContext().get(AiToolContext.CONTEXT_KEY); + + return new Response(context.getUserId()); + } +} diff --git a/yudao-module-ai/yudao-spring-boot-starter-ai/src/main/java/cn/iocoder/yudao/framework/ai/core/pojo/AiToolContext.java b/yudao-module-ai/yudao-spring-boot-starter-ai/src/main/java/cn/iocoder/yudao/framework/ai/core/pojo/AiToolContext.java new file mode 100644 index 0000000000..a0aa72f466 --- /dev/null +++ b/yudao-module-ai/yudao-spring-boot-starter-ai/src/main/java/cn/iocoder/yudao/framework/ai/core/pojo/AiToolContext.java @@ -0,0 +1,37 @@ +package cn.iocoder.yudao.framework.ai.core.pojo; + +import lombok.AllArgsConstructor; +import lombok.Data; +import lombok.NoArgsConstructor; +import org.springframework.ai.chat.model.ChatModel; +import org.springframework.ai.chat.model.StreamingChatModel; + +/** + * 工具上下文参数 DTO,让AI工具可以处理当前用户的相关信息 + * + */ +@Data +@NoArgsConstructor +@AllArgsConstructor +public class AiToolContext { + public static final String CONTEXT_KEY = "AI_TOOL_CONTEXT"; + /** + * 用户ID + */ + private Long userId; + + /** + * 聊天模型 + */ + private StreamingChatModel chatModel; + + /** + * 关联的聊天角色Id + */ + private Long roleId; + + /** + * 会话Id + */ + private Long conversationId; +} \ No newline at end of file diff --git a/yudao-module-ai/yudao-spring-boot-starter-ai/src/main/java/cn/iocoder/yudao/framework/ai/core/util/AiUtils.java b/yudao-module-ai/yudao-spring-boot-starter-ai/src/main/java/cn/iocoder/yudao/framework/ai/core/util/AiUtils.java index 09370c6363..47577356be 100644 --- a/yudao-module-ai/yudao-spring-boot-starter-ai/src/main/java/cn/iocoder/yudao/framework/ai/core/util/AiUtils.java +++ b/yudao-module-ai/yudao-spring-boot-starter-ai/src/main/java/cn/iocoder/yudao/framework/ai/core/util/AiUtils.java @@ -15,6 +15,7 @@ import org.springframework.ai.qianfan.QianFanChatOptions; import org.springframework.ai.zhipuai.ZhiPuAiChatOptions; import java.util.Collections; +import java.util.Map; import java.util.Set; /** @@ -25,28 +26,28 @@ import java.util.Set; public class AiUtils { public static ChatOptions buildChatOptions(AiPlatformEnum platform, String model, Double temperature, Integer maxTokens) { - return buildChatOptions(platform, model, temperature, maxTokens, null); + return buildChatOptions(platform, model, temperature, maxTokens, null, Map.of()); } public static ChatOptions buildChatOptions(AiPlatformEnum platform, String model, Double temperature, Integer maxTokens, - Set toolNames) { + Set toolNames, Map toolContext) { toolNames = ObjUtil.defaultIfNull(toolNames, Collections.emptySet()); // noinspection EnhancedSwitchMigration switch (platform) { case TONG_YI: return DashScopeChatOptions.builder().withModel(model).withTemperature(temperature).withMaxToken(maxTokens) - .withFunctions(toolNames).build(); + .withFunctions(toolNames).withToolContext(toolContext).build(); case YI_YAN: return QianFanChatOptions.builder().model(model).temperature(temperature).maxTokens(maxTokens).build(); case ZHI_PU: return ZhiPuAiChatOptions.builder().model(model).temperature(temperature).maxTokens(maxTokens) - .functions(toolNames).build(); + .functions(toolNames).toolContext(toolContext).build(); case MINI_MAX: return MiniMaxChatOptions.builder().model(model).temperature(temperature).maxTokens(maxTokens) - .functions(toolNames).build(); + .functions(toolNames).toolContext(toolContext).build(); case MOONSHOT: return MoonshotChatOptions.builder().model(model).temperature(temperature).maxTokens(maxTokens) - .functions(toolNames).build(); + .functions(toolNames).toolContext(toolContext).build(); case OPENAI: case DEEP_SEEK: // 复用 OpenAI 客户端 case DOU_BAO: // 复用 OpenAI 客户端 @@ -55,14 +56,14 @@ public class AiUtils { case SILICON_FLOW: // 复用 OpenAI 客户端 case BAI_CHUAN: // 复用 OpenAI 客户端 return OpenAiChatOptions.builder().model(model).temperature(temperature).maxTokens(maxTokens) - .toolNames(toolNames).build(); + .toolNames(toolNames).toolContext(toolContext).build(); case AZURE_OPENAI: // TODO 芋艿:貌似没 model 字段???! return AzureOpenAiChatOptions.builder().deploymentName(model).temperature(temperature).maxTokens(maxTokens) - .toolNames(toolNames).build(); + .toolNames(toolNames).toolContext(toolContext).build(); case OLLAMA: return OllamaOptions.builder().model(model).temperature(temperature).numPredict(maxTokens) - .toolNames(toolNames).build(); + .toolNames(toolNames).toolContext(toolContext).build(); default: throw new IllegalArgumentException(StrUtil.format("未知平台({})", platform)); }