feat:【AI 大模型】增加 AI ToolContext 上下文
This commit is contained in:
parent
3c87f953ee
commit
e11ee654ef
|
@ -4,12 +4,10 @@ import cn.hutool.core.collection.CollUtil;
|
||||||
import cn.hutool.core.util.ObjUtil;
|
import cn.hutool.core.util.ObjUtil;
|
||||||
import cn.hutool.core.util.StrUtil;
|
import cn.hutool.core.util.StrUtil;
|
||||||
import cn.iocoder.yudao.framework.ai.core.enums.AiPlatformEnum;
|
import cn.iocoder.yudao.framework.ai.core.enums.AiPlatformEnum;
|
||||||
import cn.iocoder.yudao.framework.ai.core.pojo.AiToolContext;
|
|
||||||
import cn.iocoder.yudao.framework.ai.core.util.AiUtils;
|
import cn.iocoder.yudao.framework.ai.core.util.AiUtils;
|
||||||
import cn.iocoder.yudao.framework.common.pojo.CommonResult;
|
import cn.iocoder.yudao.framework.common.pojo.CommonResult;
|
||||||
import cn.iocoder.yudao.framework.common.pojo.PageResult;
|
import cn.iocoder.yudao.framework.common.pojo.PageResult;
|
||||||
import cn.iocoder.yudao.framework.common.util.object.BeanUtils;
|
import cn.iocoder.yudao.framework.common.util.object.BeanUtils;
|
||||||
import cn.iocoder.yudao.framework.security.core.util.SecurityFrameworkUtils;
|
|
||||||
import cn.iocoder.yudao.framework.tenant.core.util.TenantUtils;
|
import cn.iocoder.yudao.framework.tenant.core.util.TenantUtils;
|
||||||
import cn.iocoder.yudao.module.ai.controller.admin.chat.vo.message.AiChatMessagePageReqVO;
|
import cn.iocoder.yudao.module.ai.controller.admin.chat.vo.message.AiChatMessagePageReqVO;
|
||||||
import cn.iocoder.yudao.module.ai.controller.admin.chat.vo.message.AiChatMessageRespVO;
|
import cn.iocoder.yudao.module.ai.controller.admin.chat.vo.message.AiChatMessageRespVO;
|
||||||
|
@ -103,8 +101,7 @@ public class AiChatMessageServiceImpl implements AiChatMessageService {
|
||||||
ChatModel chatModel = modalService.getChatModel(model.getId());
|
ChatModel chatModel = modalService.getChatModel(model.getId());
|
||||||
|
|
||||||
// 2. 知识库找回
|
// 2. 知识库找回
|
||||||
List<AiKnowledgeSegmentSearchRespBO> knowledgeSegments = recallKnowledgeSegment(sendReqVO.getContent(),
|
List<AiKnowledgeSegmentSearchRespBO> knowledgeSegments = recallKnowledgeSegment(sendReqVO.getContent(), conversation);
|
||||||
conversation);
|
|
||||||
|
|
||||||
// 3. 插入 user 发送消息
|
// 3. 插入 user 发送消息
|
||||||
AiChatMessageDO userMessage = createChatMessage(conversation.getId(), null, model,
|
AiChatMessageDO userMessage = createChatMessage(conversation.getId(), null, model,
|
||||||
|
@ -117,7 +114,7 @@ public class AiChatMessageServiceImpl implements AiChatMessageService {
|
||||||
knowledgeSegments);
|
knowledgeSegments);
|
||||||
|
|
||||||
// 3.2 创建 chat 需要的 Prompt
|
// 3.2 创建 chat 需要的 Prompt
|
||||||
Prompt prompt = buildPrompt(chatModel, conversation, historyMessages, knowledgeSegments, model, sendReqVO);
|
Prompt prompt = buildPrompt(conversation, historyMessages, knowledgeSegments, model, sendReqVO);
|
||||||
ChatResponse chatResponse = chatModel.call(prompt);
|
ChatResponse chatResponse = chatModel.call(prompt);
|
||||||
|
|
||||||
// 3.3 更新响应内容
|
// 3.3 更新响应内容
|
||||||
|
@ -166,7 +163,7 @@ public class AiChatMessageServiceImpl implements AiChatMessageService {
|
||||||
knowledgeSegments);
|
knowledgeSegments);
|
||||||
|
|
||||||
// 4.2 构建 Prompt,并进行调用
|
// 4.2 构建 Prompt,并进行调用
|
||||||
Prompt prompt = buildPrompt(chatModel, conversation, historyMessages, knowledgeSegments, model, sendReqVO);
|
Prompt prompt = buildPrompt(conversation, historyMessages, knowledgeSegments, model, sendReqVO);
|
||||||
Flux<ChatResponse> streamResponse = chatModel.stream(prompt);
|
Flux<ChatResponse> streamResponse = chatModel.stream(prompt);
|
||||||
|
|
||||||
// 4.3 流式返回
|
// 4.3 流式返回
|
||||||
|
@ -222,9 +219,9 @@ public class AiChatMessageServiceImpl implements AiChatMessageService {
|
||||||
return knowledgeSegments;
|
return knowledgeSegments;
|
||||||
}
|
}
|
||||||
|
|
||||||
private Prompt buildPrompt(StreamingChatModel chatModel, AiChatConversationDO conversation, List<AiChatMessageDO> messages,
|
private Prompt buildPrompt(AiChatConversationDO conversation, List<AiChatMessageDO> messages,
|
||||||
List<AiKnowledgeSegmentSearchRespBO> knowledgeSegments,
|
List<AiKnowledgeSegmentSearchRespBO> knowledgeSegments,
|
||||||
AiModelDO model, AiChatMessageSendReqVO sendReqVO) {
|
AiModelDO model, AiChatMessageSendReqVO sendReqVO) {
|
||||||
List<Message> chatMessages = new ArrayList<>();
|
List<Message> chatMessages = new ArrayList<>();
|
||||||
// 1.1 System Context 角色设定
|
// 1.1 System Context 角色设定
|
||||||
if (StrUtil.isNotBlank(conversation.getSystemMessage())) {
|
if (StrUtil.isNotBlank(conversation.getSystemMessage())) {
|
||||||
|
@ -254,11 +251,7 @@ public class AiChatMessageServiceImpl implements AiChatMessageService {
|
||||||
AiChatRoleDO chatRole = chatRoleService.getChatRole(conversation.getRoleId());
|
AiChatRoleDO chatRole = chatRoleService.getChatRole(conversation.getRoleId());
|
||||||
if (chatRole != null && CollUtil.isNotEmpty(chatRole.getToolIds())) {
|
if (chatRole != null && CollUtil.isNotEmpty(chatRole.getToolIds())) {
|
||||||
toolNames = convertSet(toolService.getToolList(chatRole.getToolIds()), AiToolDO::getName);
|
toolNames = convertSet(toolService.getToolList(chatRole.getToolIds()), AiToolDO::getName);
|
||||||
// 2.1.1 构建 Function Calling 的上下文参数
|
toolContext = AiUtils.buildCommonToolContext();
|
||||||
toolContext = Map.of(
|
|
||||||
AiToolContext.CONTEXT_KEY, new AiToolContext().setChatModel(chatModel).setUserId(SecurityFrameworkUtils.getLoginUserId())
|
|
||||||
.setRoleId(conversation.getRoleId())
|
|
||||||
.setConversationId(conversation.getId()));
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
// 2.2 构建 ChatOptions 对象
|
// 2.2 构建 ChatOptions 对象
|
||||||
|
|
|
@ -29,6 +29,13 @@ public interface AiKnowledgeService {
|
||||||
*/
|
*/
|
||||||
void updateKnowledge(AiKnowledgeSaveReqVO updateReqVO);
|
void updateKnowledge(AiKnowledgeSaveReqVO updateReqVO);
|
||||||
|
|
||||||
|
/**
|
||||||
|
* 删除知识库
|
||||||
|
*
|
||||||
|
* @param id 知识库编号
|
||||||
|
*/
|
||||||
|
void deleteKnowledge(Long id);
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* 获得知识库
|
* 获得知识库
|
||||||
*
|
*
|
||||||
|
|
|
@ -1,13 +1,11 @@
|
||||||
package cn.iocoder.yudao.module.ai.service.knowledge;
|
package cn.iocoder.yudao.module.ai.service.knowledge;
|
||||||
|
|
||||||
import cn.hutool.core.util.ObjUtil;
|
import cn.hutool.core.util.ObjUtil;
|
||||||
import cn.iocoder.yudao.framework.common.enums.CommonStatusEnum;
|
|
||||||
import cn.iocoder.yudao.framework.common.pojo.PageResult;
|
import cn.iocoder.yudao.framework.common.pojo.PageResult;
|
||||||
import cn.iocoder.yudao.framework.common.util.object.BeanUtils;
|
import cn.iocoder.yudao.framework.common.util.object.BeanUtils;
|
||||||
import cn.iocoder.yudao.module.ai.controller.admin.knowledge.vo.knowledge.AiKnowledgePageReqVO;
|
import cn.iocoder.yudao.module.ai.controller.admin.knowledge.vo.knowledge.AiKnowledgePageReqVO;
|
||||||
import cn.iocoder.yudao.module.ai.controller.admin.knowledge.vo.knowledge.AiKnowledgeSaveReqVO;
|
import cn.iocoder.yudao.module.ai.controller.admin.knowledge.vo.knowledge.AiKnowledgeSaveReqVO;
|
||||||
import cn.iocoder.yudao.module.ai.dal.dataobject.knowledge.AiKnowledgeDO;
|
import cn.iocoder.yudao.module.ai.dal.dataobject.knowledge.AiKnowledgeDO;
|
||||||
import cn.iocoder.yudao.module.ai.dal.dataobject.knowledge.AiKnowledgeDocumentDO;
|
|
||||||
import cn.iocoder.yudao.module.ai.dal.dataobject.model.AiModelDO;
|
import cn.iocoder.yudao.module.ai.dal.dataobject.model.AiModelDO;
|
||||||
import cn.iocoder.yudao.module.ai.dal.mysql.knowledge.AiKnowledgeMapper;
|
import cn.iocoder.yudao.module.ai.dal.mysql.knowledge.AiKnowledgeMapper;
|
||||||
import cn.iocoder.yudao.module.ai.service.model.AiModelService;
|
import cn.iocoder.yudao.module.ai.service.model.AiModelService;
|
||||||
|
@ -67,6 +65,11 @@ public class AiKnowledgeServiceImpl implements AiKnowledgeService {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public void deleteKnowledge(Long id) {
|
||||||
|
|
||||||
|
}
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
public AiKnowledgeDO getKnowledge(Long id) {
|
public AiKnowledgeDO getKnowledge(Long id) {
|
||||||
return knowledgeMapper.selectById(id);
|
return knowledgeMapper.selectById(id);
|
||||||
|
|
|
@ -1,43 +0,0 @@
|
||||||
package cn.iocoder.yudao.module.ai.service.model.tool;
|
|
||||||
|
|
||||||
import cn.iocoder.yudao.framework.ai.core.pojo.AiToolContext;
|
|
||||||
import com.fasterxml.jackson.annotation.JsonClassDescription;
|
|
||||||
import lombok.AllArgsConstructor;
|
|
||||||
import lombok.Data;
|
|
||||||
import lombok.NoArgsConstructor;
|
|
||||||
import org.springframework.ai.chat.model.ToolContext;
|
|
||||||
import org.springframework.stereotype.Component;
|
|
||||||
|
|
||||||
import java.util.function.BiFunction;
|
|
||||||
|
|
||||||
/**
|
|
||||||
* 工具:用户ID查询(上下文参数Demo)
|
|
||||||
*
|
|
||||||
* @author Ren
|
|
||||||
*/
|
|
||||||
@Component("userid_query")
|
|
||||||
public class UserIdQueryToolFunction
|
|
||||||
implements BiFunction<UserIdQueryToolFunction.Request, ToolContext, UserIdQueryToolFunction.Response> {
|
|
||||||
|
|
||||||
@Data
|
|
||||||
@JsonClassDescription("用户ID查询")
|
|
||||||
public static class Request { }
|
|
||||||
|
|
||||||
@Data
|
|
||||||
@AllArgsConstructor
|
|
||||||
@NoArgsConstructor
|
|
||||||
public static class Response {
|
|
||||||
/**
|
|
||||||
* 用户ID
|
|
||||||
*/
|
|
||||||
private Long UserId;
|
|
||||||
|
|
||||||
}
|
|
||||||
@Override
|
|
||||||
public UserIdQueryToolFunction.Response apply(UserIdQueryToolFunction.Request request, ToolContext toolContext) {
|
|
||||||
// 获取当前登录用户
|
|
||||||
AiToolContext context = (AiToolContext) toolContext.getContext().get(AiToolContext.CONTEXT_KEY);
|
|
||||||
|
|
||||||
return new Response(context.getUserId());
|
|
||||||
}
|
|
||||||
}
|
|
|
@ -0,0 +1,75 @@
|
||||||
|
package cn.iocoder.yudao.module.ai.service.model.tool;
|
||||||
|
|
||||||
|
import cn.iocoder.yudao.framework.ai.core.util.AiUtils;
|
||||||
|
import cn.iocoder.yudao.framework.common.util.object.BeanUtils;
|
||||||
|
import cn.iocoder.yudao.framework.security.core.LoginUser;
|
||||||
|
import cn.iocoder.yudao.framework.tenant.core.util.TenantUtils;
|
||||||
|
import cn.iocoder.yudao.module.system.api.user.AdminUserApi;
|
||||||
|
import cn.iocoder.yudao.module.system.api.user.dto.AdminUserRespDTO;
|
||||||
|
import com.fasterxml.jackson.annotation.JsonClassDescription;
|
||||||
|
import jakarta.annotation.Resource;
|
||||||
|
import lombok.AllArgsConstructor;
|
||||||
|
import lombok.Data;
|
||||||
|
import lombok.NoArgsConstructor;
|
||||||
|
import org.springframework.ai.chat.model.ToolContext;
|
||||||
|
import org.springframework.stereotype.Component;
|
||||||
|
|
||||||
|
import java.util.function.BiFunction;
|
||||||
|
|
||||||
|
/**
|
||||||
|
* 工具:当前用户信息查询
|
||||||
|
*
|
||||||
|
* 同时,也是展示 ToolContext 上下文的使用
|
||||||
|
*
|
||||||
|
* @author Ren
|
||||||
|
*/
|
||||||
|
@Component("user_profile_query")
|
||||||
|
public class UserProfileQueryToolFunction
|
||||||
|
implements BiFunction<UserProfileQueryToolFunction.Request, ToolContext, UserProfileQueryToolFunction.Response> {
|
||||||
|
|
||||||
|
@Resource
|
||||||
|
private AdminUserApi adminUserApi;
|
||||||
|
|
||||||
|
@Data
|
||||||
|
@JsonClassDescription("当前用户信息查询")
|
||||||
|
public static class Request { }
|
||||||
|
|
||||||
|
@Data
|
||||||
|
@AllArgsConstructor
|
||||||
|
@NoArgsConstructor
|
||||||
|
public static class Response {
|
||||||
|
|
||||||
|
/**
|
||||||
|
* 用户ID
|
||||||
|
*/
|
||||||
|
private Long id;
|
||||||
|
/**
|
||||||
|
* 用户昵称
|
||||||
|
*/
|
||||||
|
private String nickname;
|
||||||
|
|
||||||
|
/**
|
||||||
|
* 手机号码
|
||||||
|
*/
|
||||||
|
private String mobile;
|
||||||
|
/**
|
||||||
|
* 用户头像
|
||||||
|
*/
|
||||||
|
private String avatar;
|
||||||
|
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public UserProfileQueryToolFunction.Response apply(UserProfileQueryToolFunction.Request request, ToolContext toolContext) {
|
||||||
|
LoginUser loginUser = (LoginUser) toolContext.getContext().get(AiUtils.TOOL_CONTEXT_LOGIN_USER);
|
||||||
|
Long tenantId = (Long) toolContext.getContext().get(AiUtils.TOOL_CONTEXT_TENANT_ID);
|
||||||
|
if (loginUser == null | tenantId == null) {
|
||||||
|
return null;
|
||||||
|
}
|
||||||
|
return TenantUtils.execute(tenantId, () -> {
|
||||||
|
AdminUserRespDTO user = adminUserApi.getUser(loginUser.getId());
|
||||||
|
return BeanUtils.toBean(user, Response.class);
|
||||||
|
});
|
||||||
|
}
|
||||||
|
|
||||||
|
}
|
|
@ -24,6 +24,18 @@
|
||||||
<artifactId>yudao-common</artifactId>
|
<artifactId>yudao-common</artifactId>
|
||||||
</dependency>
|
</dependency>
|
||||||
|
|
||||||
|
<!-- 业务组件 -->
|
||||||
|
<dependency>
|
||||||
|
<groupId>cn.iocoder.boot</groupId>
|
||||||
|
<artifactId>yudao-spring-boot-starter-biz-tenant</artifactId>
|
||||||
|
</dependency>
|
||||||
|
|
||||||
|
<!-- Web 相关 -->
|
||||||
|
<dependency>
|
||||||
|
<groupId>cn.iocoder.boot</groupId>
|
||||||
|
<artifactId>yudao-spring-boot-starter-security</artifactId>
|
||||||
|
</dependency>
|
||||||
|
|
||||||
<!-- Spring AI Model 模型接入 -->
|
<!-- Spring AI Model 模型接入 -->
|
||||||
<dependency>
|
<dependency>
|
||||||
<groupId>org.springframework.ai</groupId>
|
<groupId>org.springframework.ai</groupId>
|
||||||
|
|
|
@ -1,37 +0,0 @@
|
||||||
package cn.iocoder.yudao.framework.ai.core.pojo;
|
|
||||||
|
|
||||||
import lombok.AllArgsConstructor;
|
|
||||||
import lombok.Data;
|
|
||||||
import lombok.NoArgsConstructor;
|
|
||||||
import org.springframework.ai.chat.model.ChatModel;
|
|
||||||
import org.springframework.ai.chat.model.StreamingChatModel;
|
|
||||||
|
|
||||||
/**
|
|
||||||
* 工具上下文参数 DTO,让AI工具可以处理当前用户的相关信息
|
|
||||||
*
|
|
||||||
*/
|
|
||||||
@Data
|
|
||||||
@NoArgsConstructor
|
|
||||||
@AllArgsConstructor
|
|
||||||
public class AiToolContext {
|
|
||||||
public static final String CONTEXT_KEY = "AI_TOOL_CONTEXT";
|
|
||||||
/**
|
|
||||||
* 用户ID
|
|
||||||
*/
|
|
||||||
private Long userId;
|
|
||||||
|
|
||||||
/**
|
|
||||||
* 聊天模型
|
|
||||||
*/
|
|
||||||
private StreamingChatModel chatModel;
|
|
||||||
|
|
||||||
/**
|
|
||||||
* 关联的聊天角色Id
|
|
||||||
*/
|
|
||||||
private Long roleId;
|
|
||||||
|
|
||||||
/**
|
|
||||||
* 会话Id
|
|
||||||
*/
|
|
||||||
private Long conversationId;
|
|
||||||
}
|
|
|
@ -3,6 +3,8 @@ package cn.iocoder.yudao.framework.ai.core.util;
|
||||||
import cn.hutool.core.util.ObjUtil;
|
import cn.hutool.core.util.ObjUtil;
|
||||||
import cn.hutool.core.util.StrUtil;
|
import cn.hutool.core.util.StrUtil;
|
||||||
import cn.iocoder.yudao.framework.ai.core.enums.AiPlatformEnum;
|
import cn.iocoder.yudao.framework.ai.core.enums.AiPlatformEnum;
|
||||||
|
import cn.iocoder.yudao.framework.security.core.util.SecurityFrameworkUtils;
|
||||||
|
import cn.iocoder.yudao.framework.tenant.core.context.TenantContextHolder;
|
||||||
import com.alibaba.cloud.ai.dashscope.chat.DashScopeChatOptions;
|
import com.alibaba.cloud.ai.dashscope.chat.DashScopeChatOptions;
|
||||||
import org.springframework.ai.azure.openai.AzureOpenAiChatOptions;
|
import org.springframework.ai.azure.openai.AzureOpenAiChatOptions;
|
||||||
import org.springframework.ai.chat.messages.*;
|
import org.springframework.ai.chat.messages.*;
|
||||||
|
@ -15,6 +17,7 @@ import org.springframework.ai.qianfan.QianFanChatOptions;
|
||||||
import org.springframework.ai.zhipuai.ZhiPuAiChatOptions;
|
import org.springframework.ai.zhipuai.ZhiPuAiChatOptions;
|
||||||
|
|
||||||
import java.util.Collections;
|
import java.util.Collections;
|
||||||
|
import java.util.HashMap;
|
||||||
import java.util.Map;
|
import java.util.Map;
|
||||||
import java.util.Set;
|
import java.util.Set;
|
||||||
|
|
||||||
|
@ -25,8 +28,11 @@ import java.util.Set;
|
||||||
*/
|
*/
|
||||||
public class AiUtils {
|
public class AiUtils {
|
||||||
|
|
||||||
|
public static final String TOOL_CONTEXT_LOGIN_USER = "LOGIN_USER";
|
||||||
|
public static final String TOOL_CONTEXT_TENANT_ID = "TENANT_ID";
|
||||||
|
|
||||||
public static ChatOptions buildChatOptions(AiPlatformEnum platform, String model, Double temperature, Integer maxTokens) {
|
public static ChatOptions buildChatOptions(AiPlatformEnum platform, String model, Double temperature, Integer maxTokens) {
|
||||||
return buildChatOptions(platform, model, temperature, maxTokens, null, Map.of());
|
return buildChatOptions(platform, model, temperature, maxTokens, null, null);
|
||||||
}
|
}
|
||||||
|
|
||||||
public static ChatOptions buildChatOptions(AiPlatformEnum platform, String model, Double temperature, Integer maxTokens,
|
public static ChatOptions buildChatOptions(AiPlatformEnum platform, String model, Double temperature, Integer maxTokens,
|
||||||
|
@ -85,4 +91,11 @@ public class AiUtils {
|
||||||
throw new IllegalArgumentException(StrUtil.format("未知消息类型({})", type));
|
throw new IllegalArgumentException(StrUtil.format("未知消息类型({})", type));
|
||||||
}
|
}
|
||||||
|
|
||||||
|
public static Map<String, Object> buildCommonToolContext() {
|
||||||
|
Map<String, Object> context = new HashMap<>();
|
||||||
|
context.put(TOOL_CONTEXT_LOGIN_USER, SecurityFrameworkUtils.getLoginUser());
|
||||||
|
context.put(TOOL_CONTEXT_TENANT_ID, TenantContextHolder.getTenantId());
|
||||||
|
return context;
|
||||||
|
}
|
||||||
|
|
||||||
}
|
}
|
Loading…
Reference in New Issue