package com.realtime.protection.server.command; import com.alibaba.excel.util.ListUtils; import com.baomidou.dynamic.datasource.annotation.DS; import com.baomidou.dynamic.datasource.annotation.DSTransactional; import com.realtime.protection.configuration.entity.task.TaskCommandInfo; import com.realtime.protection.configuration.entity.whitelist.WhiteListObject; import com.realtime.protection.configuration.utils.Counter; import com.realtime.protection.configuration.utils.SqlSessionWrapper; import com.realtime.protection.server.task.status.StateHandler; import com.realtime.protection.server.whitelist.WhiteListMapper; import lombok.extern.slf4j.Slf4j; import org.springframework.stereotype.Service; import java.time.LocalDateTime; import java.time.format.DateTimeFormatter; import java.util.Collections; import java.util.List; import java.util.UUID; import java.util.function.Function; @Service @Slf4j @DS("doris") public class CommandService { private final CommandMapper commandMapper; private final SqlSessionWrapper sqlSessionWrapper; private final Counter counter; private final WhiteListMapper whiteListMapper; private static final int BatchSize = 100; private final StateHandler stateHandler; public CommandService(CommandMapper commandMapper, SqlSessionWrapper sqlSessionWrapper, Counter counter, WhiteListMapper whiteListMapper, StateHandler stateHandler) { this.commandMapper = commandMapper; this.sqlSessionWrapper = sqlSessionWrapper; this.counter = counter; this.whiteListMapper = whiteListMapper; this.stateHandler = stateHandler; } @DSTransactional public String createCommand(TaskCommandInfo commandInfo) { String uuid = commandMapper.queryCommandInfo(commandInfo); if (uuid != null) { return uuid; } commandInfo.setDisplayId( "ZL-" + LocalDateTime.now().format(DateTimeFormatter.ofPattern("yyyyMMdd")) + "-" + String.format("%06d", counter.generateId("command")) ); //指令:白名单检查 List whiteListsHit = commandMapper.whiteListCommandCheck(commandInfo.getFiveTupleWithMask()); if (!whiteListsHit.isEmpty()) { commandInfo.setUUID(UUID.randomUUID().toString()); commandMapper.createCommandInWhiteListHit(commandInfo); commandMapper.createCommandWhiteListConnect(commandInfo.getUUID(), whiteListsHit); //写入历史表 insertCommandHistory(commandInfo.getUUID()); return commandInfo.getUUID(); } commandInfo.setUUID(UUID.randomUUID().toString()); commandMapper.createCommand(commandInfo); //写入历史表 insertCommandHistory(commandInfo.getUUID()); return commandInfo.getUUID(); } public List createCommands(List taskCommandInfos) { List commandUUIDs = ListUtils.newArrayListWithExpectedSize(taskCommandInfos.size()); Function, Boolean>> function = mapper -> list -> { List taskCommandInfoBatch = ListUtils.newArrayListWithExpectedSize(BatchSize); for (TaskCommandInfo info : list) { info.setUUID(UUID.randomUUID().toString()); commandUUIDs.add(info.getUUID()); info.setDisplayId( "ZL-" + LocalDateTime.now().format(DateTimeFormatter.ofPattern("yyyyMMdd")) + "-" + String.format("%06d", counter.generateId("command")) ); taskCommandInfoBatch.add(info); if (taskCommandInfoBatch.size() < BatchSize) { continue; } //因为createCommands只用于静态规则生成command,静态规则已经检查了白名单,所以不检查了 commandMapper.createCommands(taskCommandInfoBatch); insertCommandHistoryBatch(taskCommandInfoBatch); taskCommandInfoBatch.clear(); } if (!taskCommandInfoBatch.isEmpty()) { commandMapper.createCommands(taskCommandInfoBatch); insertCommandHistoryBatch(taskCommandInfoBatch); taskCommandInfoBatch.clear(); } return true; }; sqlSessionWrapper.startBatchSession(CommandMapper.class, function, taskCommandInfos); return commandUUIDs; } public List queryCommandInfos(Long taskId, String sourceIP, String sourcePort, String destinationIP, String destinationPort, Integer page, Integer pageNum) { return commandMapper.queryCommandInfos(taskId, sourceIP, sourcePort, destinationIP, destinationPort, page, pageNum); } public TaskCommandInfo queryCommandInfoByUUID(String uuid) { return commandMapper.queryCommandInfoByUUID(uuid); } public Boolean startCommandsByTaskId(Long taskId) { return commandMapper.startCommandsByTaskId(taskId); } public Boolean stopCommandsByTaskId(Long taskId) { return commandMapper.stopCommandsByTaskId(taskId); } public Boolean removeCommandsByTaskId(Long taskId) { return commandMapper.removeCommandsByTaskId(taskId); } public Boolean setCommandJudged(String commandId, Boolean isJudged) { //设置指令是否已经研判 Boolean success = commandMapper.setCommandJudged(commandId, isJudged); //isJudged为true时,发送指令首次 下发信号 try { List commandUUIDs = Collections.singletonList(commandId); if (isJudged){ stateHandler.sendCommandDistributeSignal(commandUUIDs); } }catch (Exception e) { log.info(String.format("动态任务研判后任务首次指令下发c3出错,任务id: %d,commandUUIDs: %s", queryCommandInfoByUUID(commandId).getTaskId(), commandId)); } return success; } public Integer queryCommandTotalNum(Long taskId, String sourceIP, String sourcePort, String destinationIP, String destinationPort){ return commandMapper.queryCommandTotalNum(taskId, sourceIP, sourcePort, destinationIP, destinationPort); } public void insertCommandHistory(String commandUUID) { commandMapper.updateCommandHistoryExpireTime(commandUUID); commandMapper.insertCommandHistory(commandUUID); } public void insertCommandHistoryBatch(List commandIdList) { List commandIds = ListUtils.newArrayListWithExpectedSize(commandIdList.size()); commandIdList.forEach(item -> commandIds.add(item.getUUID())); commandMapper.updateCommandHistoryExpireTimeBatch(commandIds); commandMapper.insertCommandHistoryBatch(commandIds); } }