【功能新增】AI:增加 AI 对话,与 tool 的打通

This commit is contained in:
YunaiV 2025-03-14 21:11:28 +08:00
parent e869aaee83
commit 1e8845ce6d
10 changed files with 147 additions and 33 deletions

View File

@ -1,5 +1,6 @@
package cn.iocoder.yudao.module.ai.controller.admin.model; package cn.iocoder.yudao.module.ai.controller.admin.model;
import cn.iocoder.yudao.framework.common.enums.CommonStatusEnum;
import cn.iocoder.yudao.framework.common.pojo.CommonResult; import cn.iocoder.yudao.framework.common.pojo.CommonResult;
import cn.iocoder.yudao.framework.common.pojo.PageResult; import cn.iocoder.yudao.framework.common.pojo.PageResult;
import cn.iocoder.yudao.framework.common.util.object.BeanUtils; import cn.iocoder.yudao.framework.common.util.object.BeanUtils;
@ -17,8 +18,10 @@ import org.springframework.security.access.prepost.PreAuthorize;
import org.springframework.validation.annotation.Validated; import org.springframework.validation.annotation.Validated;
import org.springframework.web.bind.annotation.*; import org.springframework.web.bind.annotation.*;
import static cn.iocoder.yudao.framework.common.pojo.CommonResult.success; import java.util.List;
import static cn.iocoder.yudao.framework.common.pojo.CommonResult.success;
import static cn.iocoder.yudao.framework.common.util.collection.CollectionUtils.convertList;
@Tag(name = "管理后台 - AI 工具") @Tag(name = "管理后台 - AI 工具")
@RestController @RestController
@ -30,14 +33,14 @@ public class AiToolController {
private AiToolService toolService; private AiToolService toolService;
@PostMapping("/create") @PostMapping("/create")
@Operation(summary = "创建AI 工具") @Operation(summary = "创建工具")
@PreAuthorize("@ss.hasPermission('ai:tool:create')") @PreAuthorize("@ss.hasPermission('ai:tool:create')")
public CommonResult<Long> createTool(@Valid @RequestBody AiToolSaveReqVO createReqVO) { public CommonResult<Long> createTool(@Valid @RequestBody AiToolSaveReqVO createReqVO) {
return success(toolService.createTool(createReqVO)); return success(toolService.createTool(createReqVO));
} }
@PutMapping("/update") @PutMapping("/update")
@Operation(summary = "更新AI 工具") @Operation(summary = "更新工具")
@PreAuthorize("@ss.hasPermission('ai:tool:update')") @PreAuthorize("@ss.hasPermission('ai:tool:update')")
public CommonResult<Boolean> updateTool(@Valid @RequestBody AiToolSaveReqVO updateReqVO) { public CommonResult<Boolean> updateTool(@Valid @RequestBody AiToolSaveReqVO updateReqVO) {
toolService.updateTool(updateReqVO); toolService.updateTool(updateReqVO);
@ -45,7 +48,7 @@ public class AiToolController {
} }
@DeleteMapping("/delete") @DeleteMapping("/delete")
@Operation(summary = "删除AI 工具") @Operation(summary = "删除工具")
@Parameter(name = "id", description = "编号", required = true) @Parameter(name = "id", description = "编号", required = true)
@PreAuthorize("@ss.hasPermission('ai:tool:delete')") @PreAuthorize("@ss.hasPermission('ai:tool:delete')")
public CommonResult<Boolean> deleteTool(@RequestParam("id") Long id) { public CommonResult<Boolean> deleteTool(@RequestParam("id") Long id) {
@ -54,7 +57,7 @@ public class AiToolController {
} }
@GetMapping("/get") @GetMapping("/get")
@Operation(summary = "获得AI 工具") @Operation(summary = "获得工具")
@Parameter(name = "id", description = "编号", required = true, example = "1024") @Parameter(name = "id", description = "编号", required = true, example = "1024")
@PreAuthorize("@ss.hasPermission('ai:tool:query')") @PreAuthorize("@ss.hasPermission('ai:tool:query')")
public CommonResult<AiToolRespVO> getTool(@RequestParam("id") Long id) { public CommonResult<AiToolRespVO> getTool(@RequestParam("id") Long id) {
@ -63,11 +66,19 @@ public class AiToolController {
} }
@GetMapping("/page") @GetMapping("/page")
@Operation(summary = "获得AI 工具分页") @Operation(summary = "获得工具分页")
@PreAuthorize("@ss.hasPermission('ai:tool:query')") @PreAuthorize("@ss.hasPermission('ai:tool:query')")
public CommonResult<PageResult<AiToolRespVO>> getToolPage(@Valid AiToolPageReqVO pageReqVO) { public CommonResult<PageResult<AiToolRespVO>> getToolPage(@Valid AiToolPageReqVO pageReqVO) {
PageResult<AiToolDO> pageResult = toolService.getToolPage(pageReqVO); PageResult<AiToolDO> pageResult = toolService.getToolPage(pageReqVO);
return success(BeanUtils.toBean(pageResult, AiToolRespVO.class)); return success(BeanUtils.toBean(pageResult, AiToolRespVO.class));
} }
@GetMapping("/simple-list")
@Operation(summary = "获得工具列表")
public CommonResult<List<AiToolRespVO>> getToolSimpleList() {
List<AiToolDO> list = toolService.getToolListByStatus(CommonStatusEnum.ENABLE.getStatus());
return success(convertList(list, tool -> new AiToolRespVO()
.setId(tool.getId()).setName(tool.getName())));
}
} }

View File

@ -49,6 +49,9 @@ public class AiChatRoleRespVO implements VO {
@Schema(description = "引用的知识库编号列表", example = "1,2,3") @Schema(description = "引用的知识库编号列表", example = "1,2,3")
private List<Long> knowledgeIds; private List<Long> knowledgeIds;
@Schema(description = "引用的工具编号列表", example = "1,2,3")
private List<Long> toolIds;
@Schema(description = "是否公开", requiredMode = Schema.RequiredMode.REQUIRED, example = "1") @Schema(description = "是否公开", requiredMode = Schema.RequiredMode.REQUIRED, example = "1")
private Boolean publicStatus; private Boolean publicStatus;

View File

@ -34,4 +34,7 @@ public class AiChatRoleSaveMyReqVO {
@Schema(description = "引用的知识库编号列表", example = "1,2,3") @Schema(description = "引用的知识库编号列表", example = "1,2,3")
private List<Long> knowledgeIds; private List<Long> knowledgeIds;
@Schema(description = "引用的工具编号列表", example = "1,2,3")
private List<Long> toolIds;
} }

View File

@ -47,6 +47,9 @@ public class AiChatRoleSaveReqVO {
@Schema(description = "引用的知识库编号列表", example = "1,2,3") @Schema(description = "引用的知识库编号列表", example = "1,2,3")
private List<Long> knowledgeIds; private List<Long> knowledgeIds;
@Schema(description = "引用的工具编号列表", example = "1,2,3")
private List<Long> toolIds;
@Schema(description = "是否公开", requiredMode = Schema.RequiredMode.REQUIRED, example = "1") @Schema(description = "是否公开", requiredMode = Schema.RequiredMode.REQUIRED, example = "1")
@NotNull(message = "是否公开不能为空") @NotNull(message = "是否公开不能为空")
private Boolean publicStatus; private Boolean publicStatus;

View File

@ -74,6 +74,13 @@ public class AiChatRoleDO extends BaseDO {
*/ */
@TableField(typeHandler = LongListTypeHandler.class) @TableField(typeHandler = LongListTypeHandler.class)
private List<Long> knowledgeIds; private List<Long> knowledgeIds;
/**
* 引用的工具编号列表
*
* 关联 {@link AiToolDO#getId()} 字段
*/
@TableField(typeHandler = LongListTypeHandler.class)
private List<Long> toolIds;
/** /**
* 是否公开 * 是否公开

View File

@ -7,6 +7,8 @@ import cn.iocoder.yudao.module.ai.controller.admin.model.vo.tool.AiToolPageReqVO
import cn.iocoder.yudao.module.ai.dal.dataobject.model.AiToolDO; import cn.iocoder.yudao.module.ai.dal.dataobject.model.AiToolDO;
import org.apache.ibatis.annotations.Mapper; import org.apache.ibatis.annotations.Mapper;
import java.util.List;
/** /**
* AI 工具 Mapper * AI 工具 Mapper
* *
@ -24,4 +26,10 @@ public interface AiToolMapper extends BaseMapperX<AiToolDO> {
.orderByDesc(AiToolDO::getId)); .orderByDesc(AiToolDO::getId));
} }
default List<AiToolDO> selectListByStatus(Integer status) {
return selectList(new LambdaQueryWrapperX<AiToolDO>()
.eq(AiToolDO::getStatus, status)
.orderByDesc(AiToolDO::getId));
}
} }

View File

@ -7,7 +7,6 @@ import cn.iocoder.yudao.framework.ai.core.enums.AiPlatformEnum;
import cn.iocoder.yudao.framework.ai.core.util.AiUtils; import cn.iocoder.yudao.framework.ai.core.util.AiUtils;
import cn.iocoder.yudao.framework.common.pojo.CommonResult; import cn.iocoder.yudao.framework.common.pojo.CommonResult;
import cn.iocoder.yudao.framework.common.pojo.PageResult; import cn.iocoder.yudao.framework.common.pojo.PageResult;
import cn.iocoder.yudao.framework.common.util.collection.SetUtils;
import cn.iocoder.yudao.framework.common.util.object.BeanUtils; import cn.iocoder.yudao.framework.common.util.object.BeanUtils;
import cn.iocoder.yudao.framework.tenant.core.util.TenantUtils; import cn.iocoder.yudao.framework.tenant.core.util.TenantUtils;
import cn.iocoder.yudao.module.ai.controller.admin.chat.vo.message.AiChatMessagePageReqVO; import cn.iocoder.yudao.module.ai.controller.admin.chat.vo.message.AiChatMessagePageReqVO;
@ -19,6 +18,7 @@ import cn.iocoder.yudao.module.ai.dal.dataobject.chat.AiChatMessageDO;
import cn.iocoder.yudao.module.ai.dal.dataobject.knowledge.AiKnowledgeDocumentDO; import cn.iocoder.yudao.module.ai.dal.dataobject.knowledge.AiKnowledgeDocumentDO;
import cn.iocoder.yudao.module.ai.dal.dataobject.model.AiChatRoleDO; import cn.iocoder.yudao.module.ai.dal.dataobject.model.AiChatRoleDO;
import cn.iocoder.yudao.module.ai.dal.dataobject.model.AiModelDO; import cn.iocoder.yudao.module.ai.dal.dataobject.model.AiModelDO;
import cn.iocoder.yudao.module.ai.dal.dataobject.model.AiToolDO;
import cn.iocoder.yudao.module.ai.dal.mysql.chat.AiChatMessageMapper; import cn.iocoder.yudao.module.ai.dal.mysql.chat.AiChatMessageMapper;
import cn.iocoder.yudao.module.ai.enums.ErrorCodeConstants; import cn.iocoder.yudao.module.ai.enums.ErrorCodeConstants;
import cn.iocoder.yudao.module.ai.service.knowledge.AiKnowledgeDocumentService; import cn.iocoder.yudao.module.ai.service.knowledge.AiKnowledgeDocumentService;
@ -27,6 +27,7 @@ import cn.iocoder.yudao.module.ai.service.knowledge.bo.AiKnowledgeSegmentSearchR
import cn.iocoder.yudao.module.ai.service.knowledge.bo.AiKnowledgeSegmentSearchRespBO; import cn.iocoder.yudao.module.ai.service.knowledge.bo.AiKnowledgeSegmentSearchRespBO;
import cn.iocoder.yudao.module.ai.service.model.AiChatRoleService; import cn.iocoder.yudao.module.ai.service.model.AiChatRoleService;
import cn.iocoder.yudao.module.ai.service.model.AiModelService; import cn.iocoder.yudao.module.ai.service.model.AiModelService;
import cn.iocoder.yudao.module.ai.service.model.AiToolService;
import jakarta.annotation.Resource; import jakarta.annotation.Resource;
import lombok.extern.slf4j.Slf4j; import lombok.extern.slf4j.Slf4j;
import org.springframework.ai.chat.messages.Message; import org.springframework.ai.chat.messages.Message;
@ -50,6 +51,7 @@ import static cn.iocoder.yudao.framework.common.exception.util.ServiceExceptionU
import static cn.iocoder.yudao.framework.common.pojo.CommonResult.error; import static cn.iocoder.yudao.framework.common.pojo.CommonResult.error;
import static cn.iocoder.yudao.framework.common.pojo.CommonResult.success; import static cn.iocoder.yudao.framework.common.pojo.CommonResult.success;
import static cn.iocoder.yudao.framework.common.util.collection.CollectionUtils.convertList; import static cn.iocoder.yudao.framework.common.util.collection.CollectionUtils.convertList;
import static cn.iocoder.yudao.framework.common.util.collection.CollectionUtils.convertSet;
import static cn.iocoder.yudao.module.ai.enums.ErrorCodeConstants.CHAT_CONVERSATION_NOT_EXISTS; import static cn.iocoder.yudao.module.ai.enums.ErrorCodeConstants.CHAT_CONVERSATION_NOT_EXISTS;
import static cn.iocoder.yudao.module.ai.enums.ErrorCodeConstants.CHAT_MESSAGE_NOT_EXIST; import static cn.iocoder.yudao.module.ai.enums.ErrorCodeConstants.CHAT_MESSAGE_NOT_EXIST;
@ -82,6 +84,8 @@ public class AiChatMessageServiceImpl implements AiChatMessageService {
private AiKnowledgeSegmentService knowledgeSegmentService; private AiKnowledgeSegmentService knowledgeSegmentService;
@Resource @Resource
private AiKnowledgeDocumentService knowledgeDocumentService; private AiKnowledgeDocumentService knowledgeDocumentService;
@Resource
private AiToolService toolService;
@Transactional(rollbackFor = Exception.class) @Transactional(rollbackFor = Exception.class)
public AiChatMessageSendRespVO sendMessage(AiChatMessageSendReqVO sendReqVO, Long userId) { public AiChatMessageSendRespVO sendMessage(AiChatMessageSendReqVO sendReqVO, Long userId) {
@ -118,11 +122,13 @@ public class AiChatMessageServiceImpl implements AiChatMessageService {
String newContent = chatResponse.getResult().getOutput().getText(); String newContent = chatResponse.getResult().getOutput().getText();
chatMessageMapper.updateById(new AiChatMessageDO().setId(assistantMessage.getId()).setContent(newContent)); chatMessageMapper.updateById(new AiChatMessageDO().setId(assistantMessage.getId()).setContent(newContent));
// 3.4 响应结果 // 3.4 响应结果
List<AiChatMessageRespVO.KnowledgeSegment> segments = BeanUtils.toBean(knowledgeSegments, AiChatMessageRespVO.KnowledgeSegment.class, List<AiChatMessageRespVO.KnowledgeSegment> segments = BeanUtils.toBean(knowledgeSegments,
AiChatMessageRespVO.KnowledgeSegment.class,
segment -> { segment -> {
AiKnowledgeDocumentDO document = knowledgeDocumentService.getKnowledgeDocument(segment.getDocumentId()); AiKnowledgeDocumentDO document = knowledgeDocumentService
segment.setDocumentName(document != null ? document.getName() : null); .getKnowledgeDocument(segment.getDocumentId());
}); segment.setDocumentName(document != null ? document.getName() : null);
});
return new AiChatMessageSendRespVO() return new AiChatMessageSendRespVO()
.setSend(BeanUtils.toBean(userMessage, AiChatMessageSendRespVO.Message.class)) .setSend(BeanUtils.toBean(userMessage, AiChatMessageSendRespVO.Message.class))
.setReceive(BeanUtils.toBean(assistantMessage, AiChatMessageSendRespVO.Message.class) .setReceive(BeanUtils.toBean(assistantMessage, AiChatMessageSendRespVO.Message.class)
@ -130,7 +136,8 @@ public class AiChatMessageServiceImpl implements AiChatMessageService {
} }
@Override @Override
public Flux<CommonResult<AiChatMessageSendRespVO>> sendChatMessageStream(AiChatMessageSendReqVO sendReqVO, Long userId) { public Flux<CommonResult<AiChatMessageSendRespVO>> sendChatMessageStream(AiChatMessageSendReqVO sendReqVO,
Long userId) {
// 1.1 校验对话存在 // 1.1 校验对话存在
AiChatConversationDO conversation = chatConversationService AiChatConversationDO conversation = chatConversationService
.validateChatConversationExists(sendReqVO.getConversationId()); .validateChatConversationExists(sendReqVO.getConversationId());
@ -143,7 +150,8 @@ public class AiChatMessageServiceImpl implements AiChatMessageService {
StreamingChatModel chatModel = modalService.getChatModel(model.getId()); StreamingChatModel chatModel = modalService.getChatModel(model.getId());
// 2. 知识库找回 // 2. 知识库找回
List<AiKnowledgeSegmentSearchRespBO> knowledgeSegments = recallKnowledgeSegment(sendReqVO.getContent(), conversation); List<AiKnowledgeSegmentSearchRespBO> knowledgeSegments = recallKnowledgeSegment(sendReqVO.getContent(),
conversation);
// 3. 插入 user 发送消息 // 3. 插入 user 发送消息
AiChatMessageDO userMessage = createChatMessage(conversation.getId(), null, model, AiChatMessageDO userMessage = createChatMessage(conversation.getId(), null, model,
@ -167,7 +175,8 @@ public class AiChatMessageServiceImpl implements AiChatMessageService {
if (StrUtil.isEmpty(contentBuffer)) { if (StrUtil.isEmpty(contentBuffer)) {
segments = BeanUtils.toBean(knowledgeSegments, AiChatMessageRespVO.KnowledgeSegment.class, segments = BeanUtils.toBean(knowledgeSegments, AiChatMessageRespVO.KnowledgeSegment.class,
segment -> TenantUtils.executeIgnore(() -> { segment -> TenantUtils.executeIgnore(() -> {
AiKnowledgeDocumentDO document = knowledgeDocumentService.getKnowledgeDocument(segment.getDocumentId()); AiKnowledgeDocumentDO document = knowledgeDocumentService
.getKnowledgeDocument(segment.getDocumentId());
segment.setDocumentName(document != null ? document.getName() : null); segment.setDocumentName(document != null ? document.getName() : null);
})); }));
} }
@ -192,7 +201,7 @@ public class AiChatMessageServiceImpl implements AiChatMessageService {
} }
private List<AiKnowledgeSegmentSearchRespBO> recallKnowledgeSegment(String content, private List<AiKnowledgeSegmentSearchRespBO> recallKnowledgeSegment(String content,
AiChatConversationDO conversation) { AiChatConversationDO conversation) {
// 1. 查询聊天角色 // 1. 查询聊天角色
if (conversation == null || conversation.getRoleId() == null) { if (conversation == null || conversation.getRoleId() == null) {
return Collections.emptyList(); return Collections.emptyList();
@ -212,8 +221,8 @@ public class AiChatMessageServiceImpl implements AiChatMessageService {
} }
private Prompt buildPrompt(AiChatConversationDO conversation, List<AiChatMessageDO> messages, private Prompt buildPrompt(AiChatConversationDO conversation, List<AiChatMessageDO> messages,
List<AiKnowledgeSegmentSearchRespBO> knowledgeSegments, List<AiKnowledgeSegmentSearchRespBO> knowledgeSegments,
AiModelDO model, AiChatMessageSendReqVO sendReqVO) { AiModelDO model, AiChatMessageSendReqVO sendReqVO) {
List<Message> chatMessages = new ArrayList<>(); List<Message> chatMessages = new ArrayList<>();
// 1.1 System Context 角色设定 // 1.1 System Context 角色设定
if (StrUtil.isNotBlank(conversation.getSystemMessage())) { if (StrUtil.isNotBlank(conversation.getSystemMessage())) {
@ -236,11 +245,18 @@ public class AiChatMessageServiceImpl implements AiChatMessageService {
chatMessages.add(new UserMessage(String.format(KNOWLEDGE_USER_MESSAGE_TEMPLATE, reference))); chatMessages.add(new UserMessage(String.format(KNOWLEDGE_USER_MESSAGE_TEMPLATE, reference)));
} }
// 2. 构建 ChatOptions 对象 // 2.1 查询 tool 工具
Set<String> toolNames = null;
if (conversation.getRoleId() != null) {
AiChatRoleDO chatRole = chatRoleService.getChatRole(conversation.getRoleId());
if (chatRole != null && CollUtil.isNotEmpty(chatRole.getToolIds())) {
toolNames = convertSet(toolService.getToolList(chatRole.getToolIds()), AiToolDO::getName);
}
}
// 2.2 构建 ChatOptions 对象
AiPlatformEnum platform = AiPlatformEnum.validatePlatform(model.getPlatform()); AiPlatformEnum platform = AiPlatformEnum.validatePlatform(model.getPlatform());
ChatOptions chatOptions = AiUtils.buildChatOptions(platform, model.getModel(), ChatOptions chatOptions = AiUtils.buildChatOptions(platform, model.getModel(),
conversation.getTemperature(), conversation.getMaxTokens(), conversation.getTemperature(), conversation.getMaxTokens(), toolNames);
SetUtils.asSet("directory_list", "weather_query"));
return new Prompt(chatMessages, chatOptions); return new Prompt(chatMessages, chatOptions);
} }
@ -255,8 +271,8 @@ public class AiChatMessageServiceImpl implements AiChatMessageService {
* @return 消息上下文 * @return 消息上下文
*/ */
private List<AiChatMessageDO> filterContextMessages(List<AiChatMessageDO> messages, private List<AiChatMessageDO> filterContextMessages(List<AiChatMessageDO> messages,
AiChatConversationDO conversation, AiChatConversationDO conversation,
AiChatMessageSendReqVO sendReqVO) { AiChatMessageSendReqVO sendReqVO) {
if (conversation.getMaxContexts() == null || ObjUtil.notEqual(sendReqVO.getUseContext(), Boolean.TRUE)) { if (conversation.getMaxContexts() == null || ObjUtil.notEqual(sendReqVO.getUseContext(), Boolean.TRUE)) {
return Collections.emptyList(); return Collections.emptyList();
} }
@ -285,9 +301,9 @@ public class AiChatMessageServiceImpl implements AiChatMessageService {
} }
private AiChatMessageDO createChatMessage(Long conversationId, Long replyId, private AiChatMessageDO createChatMessage(Long conversationId, Long replyId,
AiModelDO model, Long userId, Long roleId, AiModelDO model, Long userId, Long roleId,
MessageType messageType, String content, Boolean useContext, MessageType messageType, String content, Boolean useContext,
List<AiKnowledgeSegmentSearchRespBO> knowledgeSegments) { List<AiKnowledgeSegmentSearchRespBO> knowledgeSegments) {
AiChatMessageDO message = new AiChatMessageDO().setConversationId(conversationId).setReplyId(replyId) AiChatMessageDO message = new AiChatMessageDO().setConversationId(conversationId).setReplyId(replyId)
.setModel(model.getModel()).setModelId(model.getId()).setUserId(userId).setRoleId(roleId) .setModel(model.getModel()).setModelId(model.getId()).setUserId(userId).setRoleId(roleId)
.setType(messageType.getValue()).setContent(content).setUseContext(useContext) .setType(messageType.getValue()).setContent(content).setUseContext(useContext)

View File

@ -39,11 +39,15 @@ public class AiChatRoleServiceImpl implements AiChatRoleService {
@Resource @Resource
private AiKnowledgeService knowledgeService; private AiKnowledgeService knowledgeService;
@Resource
private AiToolService toolService;
@Override @Override
public Long createChatRole(AiChatRoleSaveReqVO createReqVO) { public Long createChatRole(AiChatRoleSaveReqVO createReqVO) {
// 校验文档 // 校验文档
validateDocuments(createReqVO.getKnowledgeIds()); validateDocuments(createReqVO.getKnowledgeIds());
// 校验工具
validateTools(createReqVO.getToolIds());
// 保存角色 // 保存角色
AiChatRoleDO chatRole = BeanUtils.toBean(createReqVO, AiChatRoleDO.class); AiChatRoleDO chatRole = BeanUtils.toBean(createReqVO, AiChatRoleDO.class);
@ -55,6 +59,8 @@ public class AiChatRoleServiceImpl implements AiChatRoleService {
public Long createChatRoleMy(AiChatRoleSaveMyReqVO createReqVO, Long userId) { public Long createChatRoleMy(AiChatRoleSaveMyReqVO createReqVO, Long userId) {
// 校验文档 // 校验文档
validateDocuments(createReqVO.getKnowledgeIds()); validateDocuments(createReqVO.getKnowledgeIds());
// 校验工具
validateTools(createReqVO.getToolIds());
// 保存角色 // 保存角色
AiChatRoleDO chatRole = BeanUtils.toBean(createReqVO, AiChatRoleDO.class).setUserId(userId) AiChatRoleDO chatRole = BeanUtils.toBean(createReqVO, AiChatRoleDO.class).setUserId(userId)
@ -69,6 +75,8 @@ public class AiChatRoleServiceImpl implements AiChatRoleService {
validateChatRoleExists(updateReqVO.getId()); validateChatRoleExists(updateReqVO.getId());
// 校验文档 // 校验文档
validateDocuments(updateReqVO.getKnowledgeIds()); validateDocuments(updateReqVO.getKnowledgeIds());
// 校验工具
validateTools(updateReqVO.getToolIds());
// 更新角色 // 更新角色
AiChatRoleDO updateObj = BeanUtils.toBean(updateReqVO, AiChatRoleDO.class); AiChatRoleDO updateObj = BeanUtils.toBean(updateReqVO, AiChatRoleDO.class);
@ -84,6 +92,8 @@ public class AiChatRoleServiceImpl implements AiChatRoleService {
} }
// 校验文档 // 校验文档
validateDocuments(updateReqVO.getKnowledgeIds()); validateDocuments(updateReqVO.getKnowledgeIds());
// 校验工具
validateTools(updateReqVO.getToolIds());
// 更新 // 更新
AiChatRoleDO updateObj = BeanUtils.toBean(updateReqVO, AiChatRoleDO.class); AiChatRoleDO updateObj = BeanUtils.toBean(updateReqVO, AiChatRoleDO.class);
@ -103,6 +113,19 @@ public class AiChatRoleServiceImpl implements AiChatRoleService {
knowledgeIds.forEach(knowledgeService::validateKnowledgeExists); knowledgeIds.forEach(knowledgeService::validateKnowledgeExists);
} }
/**
* 校验工具是否存在
*
* @param toolIds 工具编号列表
*/
private void validateTools(List<Long> toolIds) {
if (CollUtil.isEmpty(toolIds)) {
return;
}
// 遍历校验每个工具是否存在
toolIds.forEach(toolService::validateToolExists);
}
@Override @Override
public void deleteChatRole(Long id) { public void deleteChatRole(Long id) {
// 校验存在 // 校验存在

View File

@ -6,6 +6,9 @@ import cn.iocoder.yudao.module.ai.controller.admin.model.vo.tool.AiToolSaveReqVO
import cn.iocoder.yudao.module.ai.dal.dataobject.model.AiToolDO; import cn.iocoder.yudao.module.ai.dal.dataobject.model.AiToolDO;
import jakarta.validation.Valid; import jakarta.validation.Valid;
import java.util.Collection;
import java.util.List;
/** /**
* AI 工具 Service 接口 * AI 工具 Service 接口
* *
@ -14,7 +17,7 @@ import jakarta.validation.Valid;
public interface AiToolService { public interface AiToolService {
/** /**
* 创建AI 工具 * 创建工具
* *
* @param createReqVO 创建信息 * @param createReqVO 创建信息
* @return 编号 * @return 编号
@ -22,33 +25,56 @@ public interface AiToolService {
Long createTool(@Valid AiToolSaveReqVO createReqVO); Long createTool(@Valid AiToolSaveReqVO createReqVO);
/** /**
* 更新AI 工具 * 更新工具
* *
* @param updateReqVO 更新信息 * @param updateReqVO 更新信息
*/ */
void updateTool(@Valid AiToolSaveReqVO updateReqVO); void updateTool(@Valid AiToolSaveReqVO updateReqVO);
/** /**
* 删除AI 工具 * 删除工具
* *
* @param id 编号 * @param id 编号
*/ */
void deleteTool(Long id); void deleteTool(Long id);
/** /**
* 获得AI 工具 * 校验工具是否存在
* *
* @param id 编号 * @param id 编号
* @return AI 工具 */
void validateToolExists(Long id);
/**
* 获得工具
*
* @param id 编号
* @return 工具
*/ */
AiToolDO getTool(Long id); AiToolDO getTool(Long id);
/** /**
* 获得AI 工具分页 * 获得工具列表
*
* @param ids 编号列表
* @return 工具列表
*/
List<AiToolDO> getToolList(Collection<Long> ids);
/**
* 获得工具分页
* *
* @param pageReqVO 分页查询 * @param pageReqVO 分页查询
* @return AI 工具分页 * @return 工具分页
*/ */
PageResult<AiToolDO> getToolPage(AiToolPageReqVO pageReqVO); PageResult<AiToolDO> getToolPage(AiToolPageReqVO pageReqVO);
/**
* 获得工具列表
*
* @param status 状态
* @return 工具列表
*/
List<AiToolDO> getToolListByStatus(Integer status);
} }

View File

@ -12,6 +12,9 @@ import org.springframework.beans.factory.NoSuchBeanDefinitionException;
import org.springframework.stereotype.Service; import org.springframework.stereotype.Service;
import org.springframework.validation.annotation.Validated; import org.springframework.validation.annotation.Validated;
import java.util.Collection;
import java.util.List;
import static cn.iocoder.yudao.framework.common.exception.util.ServiceExceptionUtil.exception; import static cn.iocoder.yudao.framework.common.exception.util.ServiceExceptionUtil.exception;
import static cn.iocoder.yudao.module.ai.enums.ErrorCodeConstants.TOOL_NAME_NOT_EXISTS; import static cn.iocoder.yudao.module.ai.enums.ErrorCodeConstants.TOOL_NAME_NOT_EXISTS;
import static cn.iocoder.yudao.module.ai.enums.ErrorCodeConstants.TOOL_NOT_EXISTS; import static cn.iocoder.yudao.module.ai.enums.ErrorCodeConstants.TOOL_NOT_EXISTS;
@ -59,7 +62,8 @@ public class AiToolServiceImpl implements AiToolService {
toolMapper.deleteById(id); toolMapper.deleteById(id);
} }
private void validateToolExists(Long id) { @Override
public void validateToolExists(Long id) {
if (toolMapper.selectById(id) == null) { if (toolMapper.selectById(id) == null) {
throw exception(TOOL_NOT_EXISTS); throw exception(TOOL_NOT_EXISTS);
} }
@ -78,9 +82,19 @@ public class AiToolServiceImpl implements AiToolService {
return toolMapper.selectById(id); return toolMapper.selectById(id);
} }
@Override
public List<AiToolDO> getToolList(Collection<Long> ids) {
return toolMapper.selectBatchIds(ids);
}
@Override @Override
public PageResult<AiToolDO> getToolPage(AiToolPageReqVO pageReqVO) { public PageResult<AiToolDO> getToolPage(AiToolPageReqVO pageReqVO) {
return toolMapper.selectPage(pageReqVO); return toolMapper.selectPage(pageReqVO);
} }
@Override
public List<AiToolDO> getToolListByStatus(Integer status) {
return toolMapper.selectListByStatus(status);
}
} }