feat: add ToolContext support for AI tools

Ren 2025-04-12 23:52:24 +08:00
parent 911f5f8bf3
commit dda2b56bbf
4 changed files with 102 additions and 13 deletions

AiChatMessageServiceImpl.java

@@ -4,10 +4,12 @@ import cn.hutool.core.collection.CollUtil;
 import cn.hutool.core.util.ObjUtil;
 import cn.hutool.core.util.StrUtil;
 import cn.iocoder.yudao.framework.ai.core.enums.AiPlatformEnum;
+import cn.iocoder.yudao.framework.ai.core.pojo.AiToolContext;
 import cn.iocoder.yudao.framework.ai.core.util.AiUtils;
 import cn.iocoder.yudao.framework.common.pojo.CommonResult;
 import cn.iocoder.yudao.framework.common.pojo.PageResult;
 import cn.iocoder.yudao.framework.common.util.object.BeanUtils;
+import cn.iocoder.yudao.framework.security.core.util.SecurityFrameworkUtils;
 import cn.iocoder.yudao.framework.tenant.core.util.TenantUtils;
 import cn.iocoder.yudao.module.ai.controller.admin.chat.vo.message.AiChatMessagePageReqVO;
 import cn.iocoder.yudao.module.ai.controller.admin.chat.vo.message.AiChatMessageRespVO;
@@ -115,7 +117,7 @@ public class AiChatMessageServiceImpl implements AiChatMessageService {
                 knowledgeSegments);
         // 3.2 Build the Prompt needed for the chat
-        Prompt prompt = buildPrompt(conversation, historyMessages, knowledgeSegments, model, sendReqVO);
+        Prompt prompt = buildPrompt(chatModel, conversation, historyMessages, knowledgeSegments, model, sendReqVO);
         ChatResponse chatResponse = chatModel.call(prompt);

         // 3.3 Update the response content
@@ -164,7 +166,7 @@ public class AiChatMessageServiceImpl implements AiChatMessageService {
                 knowledgeSegments);
         // 4.2 Build the Prompt and invoke the model
-        Prompt prompt = buildPrompt(conversation, historyMessages, knowledgeSegments, model, sendReqVO);
+        Prompt prompt = buildPrompt(chatModel, conversation, historyMessages, knowledgeSegments, model, sendReqVO);
         Flux<ChatResponse> streamResponse = chatModel.stream(prompt);

         // 4.3 Stream the response back
@@ -220,7 +222,7 @@ public class AiChatMessageServiceImpl implements AiChatMessageService {
         return knowledgeSegments;
     }

-    private Prompt buildPrompt(AiChatConversationDO conversation, List<AiChatMessageDO> messages,
+    private Prompt buildPrompt(StreamingChatModel chatModel, AiChatConversationDO conversation, List<AiChatMessageDO> messages,
                                List<AiKnowledgeSegmentSearchRespBO> knowledgeSegments,
                                AiModelDO model, AiChatMessageSendReqVO sendReqVO) {
         List<Message> chatMessages = new ArrayList<>();
@@ -247,16 +249,22 @@ public class AiChatMessageServiceImpl implements AiChatMessageService {
         // 2.1 Query the tools
         Set<String> toolNames = null;
+        Map<String, Object> toolContext = Map.of();
         if (conversation.getRoleId() != null) {
             AiChatRoleDO chatRole = chatRoleService.getChatRole(conversation.getRoleId());
             if (chatRole != null && CollUtil.isNotEmpty(chatRole.getToolIds())) {
                 toolNames = convertSet(toolService.getToolList(chatRole.getToolIds()), AiToolDO::getName);
+                // 2.1.1 Build the context parameters for Function Calling
+                toolContext = Map.of(
+                        AiToolContext.CONTEXT_KEY, new AiToolContext().setChatModel(chatModel).setUserId(SecurityFrameworkUtils.getLoginUserId())
+                                .setRoleId(conversation.getRoleId())
+                                .setConversationId(conversation.getId()));
             }
         }
         // 2.2 Build the ChatOptions object
         AiPlatformEnum platform = AiPlatformEnum.validatePlatform(model.getPlatform());
         ChatOptions chatOptions = AiUtils.buildChatOptions(platform, model.getModel(),
-                conversation.getTemperature(), conversation.getMaxTokens(), toolNames);
+                conversation.getTemperature(), conversation.getMaxTokens(), toolNames, toolContext);
         return new Prompt(chatMessages, chatOptions);
     }

UserIdQueryToolFunction.java (new file)

@@ -0,0 +1,43 @@
package cn.iocoder.yudao.module.ai.service.model.tool;

import cn.iocoder.yudao.framework.ai.core.pojo.AiToolContext;
import com.fasterxml.jackson.annotation.JsonClassDescription;
import lombok.AllArgsConstructor;
import lombok.Data;
import lombok.NoArgsConstructor;
import org.springframework.ai.chat.model.ToolContext;
import org.springframework.stereotype.Component;

import java.util.function.BiFunction;

/**
 * Tool: user ID query (demo for the ToolContext parameters)
 *
 * @author Ren
 */
@Component("userid_query")
public class UserIdQueryToolFunction
        implements BiFunction<UserIdQueryToolFunction.Request, ToolContext, UserIdQueryToolFunction.Response> {

    @Data
    @JsonClassDescription("用户ID查询")
    public static class Request {
    }

    @Data
    @AllArgsConstructor
    @NoArgsConstructor
    public static class Response {

        /**
         * User ID
         */
        private Long userId;

    }

    @Override
    public UserIdQueryToolFunction.Response apply(UserIdQueryToolFunction.Request request, ToolContext toolContext) {
        // Read the currently logged-in user from the tool context
        AiToolContext context = (AiToolContext) toolContext.getContext().get(AiToolContext.CONTEXT_KEY);
        return new Response(context.getUserId());
    }

}
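For reference, a minimal sketch (not part of this commit) of how the demo tool sees its context at call time: it builds the same CONTEXT_KEY entry that AiChatMessageServiceImpl puts into the tool context and invokes the BiFunction directly, standing in for the invocation Spring AI performs during function calling. The class name UserIdQueryToolFunctionSketch and the concrete IDs are made-up illustration values.

package cn.iocoder.yudao.module.ai.service.model.tool;

import java.util.Map;

import org.springframework.ai.chat.model.ToolContext;

import cn.iocoder.yudao.framework.ai.core.pojo.AiToolContext;

public class UserIdQueryToolFunctionSketch {

    public static void main(String[] args) {
        // Populate the entry that buildPrompt(...) normally creates (illustrative IDs);
        // chatModel is left unset because this demo tool only reads the user ID
        AiToolContext context = new AiToolContext();
        context.setUserId(1L);
        context.setRoleId(10L);
        context.setConversationId(100L);
        ToolContext toolContext = new ToolContext(Map.of(AiToolContext.CONTEXT_KEY, context));

        // Apply the tool the way the framework would during function calling
        UserIdQueryToolFunction tool = new UserIdQueryToolFunction();
        UserIdQueryToolFunction.Response response =
                tool.apply(new UserIdQueryToolFunction.Request(), toolContext);
        System.out.println(response.getUserId()); // prints 1
    }
}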

AiToolContext.java (new file)

@@ -0,0 +1,37 @@
package cn.iocoder.yudao.framework.ai.core.pojo;

import lombok.AllArgsConstructor;
import lombok.Data;
import lombok.NoArgsConstructor;
import org.springframework.ai.chat.model.ChatModel;
import org.springframework.ai.chat.model.StreamingChatModel;

/**
 * Tool context parameter DTO, so that AI tools can access information about the current user
 *
 */
@Data
@NoArgsConstructor
@AllArgsConstructor
public class AiToolContext {

    public static final String CONTEXT_KEY = "AI_TOOL_CONTEXT";

    /**
     * User ID
     */
    private Long userId;

    /**
     * Chat model
     */
    private StreamingChatModel chatModel;

    /**
     * ID of the associated chat role
     */
    private Long roleId;

    /**
     * Conversation ID
     */
    private Long conversationId;

}

AiUtils.java

@@ -15,6 +15,7 @@ import org.springframework.ai.qianfan.QianFanChatOptions;
 import org.springframework.ai.zhipuai.ZhiPuAiChatOptions;

 import java.util.Collections;
+import java.util.Map;
 import java.util.Set;

 /**
@@ -25,28 +26,28 @@ import java.util.Set;
 public class AiUtils {

     public static ChatOptions buildChatOptions(AiPlatformEnum platform, String model, Double temperature, Integer maxTokens) {
-        return buildChatOptions(platform, model, temperature, maxTokens, null);
+        return buildChatOptions(platform, model, temperature, maxTokens, null, Map.of());
     }

     public static ChatOptions buildChatOptions(AiPlatformEnum platform, String model, Double temperature, Integer maxTokens,
-                                               Set<String> toolNames) {
+                                               Set<String> toolNames, Map<String, Object> toolContext) {
         toolNames = ObjUtil.defaultIfNull(toolNames, Collections.emptySet());
         // noinspection EnhancedSwitchMigration
         switch (platform) {
             case TONG_YI:
                 return DashScopeChatOptions.builder().withModel(model).withTemperature(temperature).withMaxToken(maxTokens)
-                        .withFunctions(toolNames).build();
+                        .withFunctions(toolNames).withToolContext(toolContext).build();
             case YI_YAN:
                 return QianFanChatOptions.builder().model(model).temperature(temperature).maxTokens(maxTokens).build();
             case ZHI_PU:
                 return ZhiPuAiChatOptions.builder().model(model).temperature(temperature).maxTokens(maxTokens)
-                        .functions(toolNames).build();
+                        .functions(toolNames).toolContext(toolContext).build();
             case MINI_MAX:
                 return MiniMaxChatOptions.builder().model(model).temperature(temperature).maxTokens(maxTokens)
-                        .functions(toolNames).build();
+                        .functions(toolNames).toolContext(toolContext).build();
             case MOONSHOT:
                 return MoonshotChatOptions.builder().model(model).temperature(temperature).maxTokens(maxTokens)
-                        .functions(toolNames).build();
+                        .functions(toolNames).toolContext(toolContext).build();
             case OPENAI:
             case DEEP_SEEK: // reuses the OpenAI client
             case DOU_BAO: // reuses the OpenAI client
@@ -55,14 +56,14 @@ public class AiUtils {
             case SILICON_FLOW: // reuses the OpenAI client
             case BAI_CHUAN: // reuses the OpenAI client
                 return OpenAiChatOptions.builder().model(model).temperature(temperature).maxTokens(maxTokens)
-                        .toolNames(toolNames).build();
+                        .toolNames(toolNames).toolContext(toolContext).build();
             case AZURE_OPENAI:
                 // TODO 芋艿: it seems there is no model field
                 return AzureOpenAiChatOptions.builder().deploymentName(model).temperature(temperature).maxTokens(maxTokens)
-                        .toolNames(toolNames).build();
+                        .toolNames(toolNames).toolContext(toolContext).build();
             case OLLAMA:
                 return OllamaOptions.builder().model(model).temperature(temperature).numPredict(maxTokens)
-                        .toolNames(toolNames).build();
+                        .toolNames(toolNames).toolContext(toolContext).build();
             default:
                 throw new IllegalArgumentException(StrUtil.format("未知平台({})", platform));