Merge remote-tracking branch 'origin/master'
# Conflicts: # src/main/java/com/realtime/protection/server/rule/dynamicrule/DynamicRuleControllerApi.java # src/main/java/com/realtime/protection/server/rule/staticrule/StaticRuleControllerApi.java
This commit is contained in:
@@ -29,7 +29,7 @@ public class ProtectObject {
|
||||
private String protectObjectSystemName;
|
||||
|
||||
@JsonProperty("proobj_ip_address")
|
||||
@Pattern(regexp = "^(\\d{1,3})\\.(\\d{1,3})\\.(\\d{1,3})\\.(\\d{1,3})$", message = "Invalid IPv4 Address")
|
||||
@Pattern(regexp = "^(\\d{1,3})\\.(\\d{1,3})\\.(\\d{1,3})\\.(\\d{1,3})$", message = "无效IPv4地址")
|
||||
@ExcelProperty("IP地址")
|
||||
@Schema(description = "防护对象IPv4地址", example = "192.168.0.1")
|
||||
private String protectObjectIPAddress;
|
||||
|
||||
@@ -1,32 +0,0 @@
|
||||
package com.realtime.protection.configuration.entity.task;
|
||||
|
||||
import com.realtime.protection.configuration.utils.enums.ProtocolEnum;
|
||||
import lombok.Data;
|
||||
|
||||
import java.time.LocalDateTime;
|
||||
|
||||
@Data
|
||||
public class Command {
|
||||
private FiveTupleWithMask fiveTupleWithMask;
|
||||
private Long taskId;
|
||||
|
||||
private String operation;
|
||||
private LocalDateTime validTime;
|
||||
private LocalDateTime invalidTime;
|
||||
|
||||
public static Command generateCommand(TaskCommandInfo info, LocalDateTime validTime) {
|
||||
Command command = new Command();
|
||||
|
||||
FiveTupleWithMask fiveTupleWithMask = info.getFiveTupleWithMask();
|
||||
if (fiveTupleWithMask.getProtocol() != null)
|
||||
fiveTupleWithMask.setProtocolNum(ProtocolEnum.valueOf(fiveTupleWithMask.getProtocol()).getProtocolNumber());
|
||||
|
||||
command.setFiveTupleWithMask(fiveTupleWithMask);
|
||||
command.setTaskId(info.getTaskId());
|
||||
command.setOperation(info.getOperation());
|
||||
command.setValidTime(validTime);
|
||||
command.setInvalidTime(info.getEndTime());
|
||||
|
||||
return command;
|
||||
}
|
||||
}
|
||||
@@ -1,19 +1,58 @@
|
||||
package com.realtime.protection.configuration.entity.task;
|
||||
|
||||
import io.swagger.v3.oas.annotations.media.Schema;
|
||||
import jakarta.validation.constraints.Max;
|
||||
import jakarta.validation.constraints.Min;
|
||||
import jakarta.validation.constraints.Pattern;
|
||||
import lombok.Data;
|
||||
|
||||
@Data
|
||||
public class FiveTupleWithMask {
|
||||
@Schema(description = "地址类型(IPv4 or IPv6)", example = "4")
|
||||
private Integer addrType;
|
||||
|
||||
@Schema(description = "源IP", example = "192.168.104.14")
|
||||
@Pattern(regexp = "^(\\d{1,3})\\.(\\d{1,3})\\.(\\d{1,3})\\.(\\d{1,3})$", message = "源IP:无效IPv4地址")
|
||||
private String sourceIP;
|
||||
|
||||
@Schema(description = "源端口", example = "114")
|
||||
@Max(value = 65535, message = "源端口不可大于65535")
|
||||
@Min(value = 1, message = "源端口不可小于1")
|
||||
private String sourcePort;
|
||||
|
||||
@Schema(description = "目的IP", example = "102.165.11.39")
|
||||
@Pattern(regexp = "^(\\d{1,3})\\.(\\d{1,3})\\.(\\d{1,3})\\.(\\d{1,3})$", message = "目的IP:无效IPv4地址")
|
||||
private String destinationIP;
|
||||
|
||||
@Schema(description = "目的端口", example = "514")
|
||||
@Max(value = 65535, message = "目的端口不可大于65535")
|
||||
@Min(value = 1, message = "目的端口不可小于1")
|
||||
private String destinationPort;
|
||||
|
||||
@Schema(description = "协议名称", example = "TCP", accessMode = Schema.AccessMode.WRITE_ONLY)
|
||||
private String protocol;
|
||||
|
||||
@Schema(description = "协议号", example = "6", accessMode = Schema.AccessMode.READ_ONLY)
|
||||
private Integer protocolNum;
|
||||
|
||||
@Schema(description = "源IP掩码", example = "255.255.255.0")
|
||||
@Pattern(regexp = "^(\\d{1,3})\\.(\\d{1,3})\\.(\\d{1,3})\\.(\\d{1,3})$", message = "源IP掩码:无效IPv4地址")
|
||||
private String maskSourceIP;
|
||||
|
||||
@Schema(description = "源端口掩码", example = "0")
|
||||
@Max(value = 65535, message = "源端口掩码不可大于65535")
|
||||
@Min(value = 1, message = "源端口掩码不可小于1")
|
||||
private String maskSourcePort;
|
||||
|
||||
@Schema(description = "目的IP掩码", example = "255.255.0.0")
|
||||
@Pattern(regexp = "^(\\d{1,3})\\.(\\d{1,3})\\.(\\d{1,3})\\.(\\d{1,3})$", message = "目的IP掩码:无效IPv4地址")
|
||||
private String maskDestinationIP;
|
||||
|
||||
@Schema(description = "目的端口掩码", example = "0")
|
||||
@Max(value = 65535, message = "目的端口掩码不可大于65535")
|
||||
@Min(value = 1, message = "目的端口掩码不可小于1")
|
||||
private String maskDestinationPort;
|
||||
|
||||
@Schema(description = "协议掩码", example = "0")
|
||||
private String maskProtocol;
|
||||
}
|
||||
|
||||
@@ -52,24 +52,24 @@ public class Task {
|
||||
private String taskAct;
|
||||
|
||||
@JsonProperty("task_create_username")
|
||||
@Schema(hidden = true)
|
||||
@Schema(description = "任务创建人名称", accessMode = Schema.AccessMode.READ_ONLY)
|
||||
private String taskCreateUsername;
|
||||
|
||||
@JsonProperty("task_create_depart")
|
||||
@Schema(hidden = true)
|
||||
@Schema(description = "任务创建人处室", accessMode = Schema.AccessMode.READ_ONLY)
|
||||
private String taskCreateDepart;
|
||||
|
||||
@JsonProperty("task_create_userid")
|
||||
@Schema(hidden = true)
|
||||
private Long taskCreateUserId;
|
||||
@Schema(description = "任务创建人ID", accessMode = Schema.AccessMode.READ_ONLY)
|
||||
private Integer taskCreateUserId;
|
||||
|
||||
@JsonProperty("static_rule_ids")
|
||||
@Schema(description = "静态规则ID列表,动态和静态至少存在1个规则", example = "[10, 12]")
|
||||
private List<Long> staticRuleIds;
|
||||
private List<Integer> staticRuleIds;
|
||||
|
||||
@JsonProperty("dynamic_rule_ids")
|
||||
@Schema(description = "动态规则ID列表,动态和静态至少存在1个规则", example = "[20, 30]")
|
||||
private List<Long> dynamicRuleIds;
|
||||
private List<Integer> dynamicRuleIds;
|
||||
|
||||
@JsonProperty("task_status")
|
||||
@Schema(description = "任务状态(0为未启动,1为生成中,2为运行中,3为暂停中,4为已停止,5为已结束,6为失败)", accessMode = Schema.AccessMode.READ_ONLY)
|
||||
|
||||
@@ -1,19 +1,68 @@
|
||||
package com.realtime.protection.configuration.entity.task;
|
||||
|
||||
import io.swagger.v3.oas.annotations.media.Schema;
|
||||
import jakarta.validation.constraints.NotNull;
|
||||
import lombok.Data;
|
||||
|
||||
import java.time.LocalDateTime;
|
||||
|
||||
@Data
|
||||
public class TaskCommandInfo {
|
||||
private FiveTupleWithMask fiveTupleWithMask;
|
||||
@Schema(description = "指令UUID", accessMode = Schema.AccessMode.READ_ONLY)
|
||||
private String UUID;
|
||||
|
||||
@Schema(description = "任务ID", accessMode = Schema.AccessMode.READ_ONLY)
|
||||
private Long taskId;
|
||||
|
||||
@Schema(description = "规则ID", hidden = true)
|
||||
private Long ruleId;
|
||||
|
||||
// 额外字段
|
||||
private String operation;
|
||||
@Schema(description = "任务创建人名称", accessMode = Schema.AccessMode.READ_ONLY)
|
||||
private String taskCreateUsername;
|
||||
|
||||
@Schema(description = "任务创建人处室", accessMode = Schema.AccessMode.READ_ONLY)
|
||||
private String taskCreateDepart;
|
||||
|
||||
@Schema(description = "任务创建人ID", accessMode = Schema.AccessMode.READ_ONLY)
|
||||
private Integer taskCreateUserId;
|
||||
|
||||
@Schema(description = "任务名称", example = "API测试任务")
|
||||
@NotNull(message = "任务名称不能为空")
|
||||
private String taskName;
|
||||
|
||||
@Schema(description = "任务类型", example = "1")
|
||||
@NotNull(message = "任务类型不能为空")
|
||||
private Integer taskType;
|
||||
|
||||
@Schema(description = "任务操作", example = "阻断")
|
||||
@NotNull(message = "任务操作不能为空。")
|
||||
private String taskAct;
|
||||
|
||||
@Schema(description = "指令下发频率", example = "30")
|
||||
@NotNull(message = "指令下发频率不能为空。")
|
||||
private Integer frequency;
|
||||
|
||||
@Schema(description = "任务开始时间", example = "2025-10-14T10:23:33")
|
||||
@NotNull(message = "任务开始时间不能为空。")
|
||||
private LocalDateTime startTime;
|
||||
|
||||
@Schema(description = "任务结束时间", example = "2026-10-22T10:33:22")
|
||||
@NotNull(message = "指令结束时间不能为空。")
|
||||
private LocalDateTime endTime;
|
||||
|
||||
@Schema(description = "五元组信息")
|
||||
@NotNull(message = "五元组信息不能为空。")
|
||||
private FiveTupleWithMask fiveTupleWithMask;
|
||||
|
||||
@Schema(description = "指令下发次数", accessMode = Schema.AccessMode.READ_ONLY)
|
||||
private Integer commandSentTimes;
|
||||
|
||||
@Schema(description = "指令成功次数", accessMode = Schema.AccessMode.READ_ONLY)
|
||||
private Integer commandSuccessTimes;
|
||||
|
||||
@Schema(description = "首次下发时间", accessMode = Schema.AccessMode.READ_ONLY)
|
||||
private LocalDateTime earliestSendTime;
|
||||
|
||||
@Schema(description = "最新下发时间", accessMode = Schema.AccessMode.READ_ONLY)
|
||||
private LocalDateTime latestSendTime;
|
||||
}
|
||||
|
||||
@@ -15,6 +15,8 @@ import org.springframework.web.bind.annotation.ExceptionHandler;
|
||||
import org.springframework.web.bind.annotation.RestControllerAdvice;
|
||||
import org.springframework.web.method.annotation.HandlerMethodValidationException;
|
||||
|
||||
import java.sql.SQLException;
|
||||
import java.sql.SQLIntegrityConstraintViolationException;
|
||||
import java.util.stream.Collectors;
|
||||
|
||||
@RestControllerAdvice
|
||||
@@ -36,7 +38,12 @@ public class GlobalExceptionHandler {
|
||||
|
||||
|
||||
@Order(2)
|
||||
@ExceptionHandler(value = {PersistenceException.class, DuplicateKeyException.class})
|
||||
@ExceptionHandler(value = {
|
||||
PersistenceException.class,
|
||||
DuplicateKeyException.class,
|
||||
SQLException.class,
|
||||
SQLIntegrityConstraintViolationException.class
|
||||
})
|
||||
public ResponseResult handleSQLException(Exception e) {
|
||||
log.info("遭遇数据库异常:" + e.getMessage());
|
||||
return ResponseResult.invalid().setMessage(
|
||||
@@ -89,7 +96,8 @@ public class GlobalExceptionHandler {
|
||||
.setMessage("Doris数据库指令生成遭遇异常:" + e.getMessage());
|
||||
|
||||
try {
|
||||
stateChangeService.changeState(StateEnum.FAILED.getStateNum(), e.taskId);
|
||||
// 内部修改状态,可以跳过一切状态检查
|
||||
stateChangeService.changeState(StateEnum.FAILED.getStateNum(), e.taskId, true);
|
||||
} catch (Exception another) {
|
||||
responseResult.setAnother(ResponseResult.error().setMessage(e.getMessage()));
|
||||
}
|
||||
|
||||
@@ -1,5 +1,8 @@
|
||||
package com.realtime.protection.configuration.utils.status;
|
||||
|
||||
import lombok.extern.slf4j.Slf4j;
|
||||
|
||||
@Slf4j
|
||||
public class AuditStatusValidator {
|
||||
|
||||
private final Integer auditStatusOriginal;
|
||||
@@ -12,8 +15,8 @@ public class AuditStatusValidator {
|
||||
return new AuditStatusValidator(auditStatusOriginal);
|
||||
}
|
||||
|
||||
public Boolean checkValidate(Integer auditStatusNow) {
|
||||
switch (auditStatusNow) {
|
||||
public Boolean checkValidate(Integer newAuditStatus) {
|
||||
switch (newAuditStatus) {
|
||||
case 0, 1 -> {
|
||||
return auditStatusOriginal != 2;
|
||||
}
|
||||
@@ -21,6 +24,7 @@ public class AuditStatusValidator {
|
||||
return auditStatusOriginal != 1;
|
||||
}
|
||||
default -> {
|
||||
log.debug("欲修改的审核状态不正确,需要使用正确的审核状态,当前的审核状态:" + auditStatusOriginal);
|
||||
return false;
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
package com.realtime.protection.server.command;
|
||||
|
||||
import com.realtime.protection.configuration.entity.task.Command;
|
||||
import com.realtime.protection.configuration.entity.task.TaskCommandInfo;
|
||||
import org.apache.ibatis.annotations.Mapper;
|
||||
import org.apache.ibatis.annotations.Param;
|
||||
|
||||
@@ -8,13 +8,15 @@ import java.util.List;
|
||||
|
||||
@Mapper
|
||||
public interface CommandMapper {
|
||||
Boolean createCommand(@Param("command") Command command);
|
||||
Boolean createCommand(@Param("info") TaskCommandInfo taskCommandInfo);
|
||||
|
||||
void createCommands(@Param("commands") List<Command> commands);
|
||||
void createCommands(@Param("command_infos") List<TaskCommandInfo> taskCommandInfos);
|
||||
|
||||
Boolean stopCommandsByTaskId(@Param("task_id") Long taskId);
|
||||
|
||||
Boolean removeCommandsByTaskId(@Param("task_id") Long taskId);
|
||||
|
||||
Boolean startCommandsByTaskId(@Param("task_id") Long taskId);
|
||||
|
||||
List<TaskCommandInfo> queryCommandInfoByTaskId(@Param("task_id") Long taskId);
|
||||
}
|
||||
|
||||
@@ -2,120 +2,69 @@ package com.realtime.protection.server.command;
|
||||
|
||||
import com.alibaba.excel.util.ListUtils;
|
||||
import com.baomidou.dynamic.datasource.annotation.DS;
|
||||
import com.realtime.protection.configuration.entity.task.Command;
|
||||
import com.realtime.protection.configuration.entity.task.TaskCommandInfo;
|
||||
import com.realtime.protection.configuration.exception.DorisStartException;
|
||||
import com.realtime.protection.configuration.utils.SqlSessionWrapper;
|
||||
import com.realtime.protection.configuration.utils.enums.StateEnum;
|
||||
import com.realtime.protection.server.task.TaskService;
|
||||
import lombok.extern.slf4j.Slf4j;
|
||||
import org.springframework.scheduling.annotation.Async;
|
||||
import org.springframework.stereotype.Service;
|
||||
|
||||
import java.time.LocalDateTime;
|
||||
import java.util.List;
|
||||
import java.util.function.Function;
|
||||
|
||||
@Service
|
||||
@Slf4j
|
||||
@DS("doris")
|
||||
public class CommandService {
|
||||
|
||||
private final CommandMapper commandMapper;
|
||||
private final TaskService taskService;
|
||||
private final SqlSessionWrapper sqlSessionWrapper;
|
||||
private static final int BatchSize = 100;
|
||||
private final Function<CommandMapper, Function<TaskCommandInfo, Void>> createCommandBatchFunction;
|
||||
|
||||
public CommandService(CommandMapper commandMapper, TaskService taskService, SqlSessionWrapper sqlSessionWrapper) {
|
||||
public CommandService(CommandMapper commandMapper, SqlSessionWrapper sqlSessionWrapper) {
|
||||
this.commandMapper = commandMapper;
|
||||
this.taskService = taskService;
|
||||
this.sqlSessionWrapper = sqlSessionWrapper;
|
||||
this.createCommandBatchFunction = mapper -> info -> {
|
||||
if (info.getFrequency() == null) {
|
||||
Command command = Command.generateCommand(info, info.getStartTime());
|
||||
mapper.createCommand(command);
|
||||
}
|
||||
}
|
||||
|
||||
List<Command> commandBatch = ListUtils.newArrayListWithExpectedSize(BatchSize);
|
||||
LocalDateTime validTime = info.getStartTime();
|
||||
public Boolean createCommand(TaskCommandInfo commandInfo) {
|
||||
return commandMapper.createCommand(commandInfo);
|
||||
}
|
||||
|
||||
while (validTime.isBefore(info.getEndTime())) {
|
||||
Command command = Command.generateCommand(info, validTime);
|
||||
commandBatch.add(command);
|
||||
|
||||
validTime = validTime.plusMinutes(info.getFrequency());
|
||||
|
||||
if (commandBatch.size() < BatchSize) {
|
||||
public void createCommands(List<TaskCommandInfo> taskCommandInfos) {
|
||||
Function<CommandMapper, Function<List<TaskCommandInfo>, Boolean>> function = mapper -> list -> {
|
||||
List<TaskCommandInfo> taskCommandInfoBatch = ListUtils.newArrayListWithExpectedSize(BatchSize);
|
||||
for (TaskCommandInfo info : list) {
|
||||
taskCommandInfoBatch.add(info);
|
||||
if (taskCommandInfoBatch.size() < BatchSize) {
|
||||
continue;
|
||||
}
|
||||
mapper.createCommands(commandBatch);
|
||||
commandBatch.clear();
|
||||
|
||||
commandMapper.createCommands(taskCommandInfoBatch);
|
||||
taskCommandInfoBatch.clear();
|
||||
}
|
||||
|
||||
if (!commandBatch.isEmpty()) {
|
||||
mapper.createCommands(commandBatch);
|
||||
commandBatch.clear();
|
||||
if (!taskCommandInfoBatch.isEmpty()) {
|
||||
commandMapper.createCommands(taskCommandInfoBatch);
|
||||
taskCommandInfoBatch.clear();
|
||||
}
|
||||
|
||||
log.debug(String.format("在task(%d)和rule(%d)中构建了全部指令",
|
||||
info.getTaskId(), info.getRuleId()));
|
||||
return null;
|
||||
};
|
||||
}
|
||||
|
||||
@Async
|
||||
@DS("doris")
|
||||
public void createCommand(TaskCommandInfo commandInfo) throws DorisStartException {
|
||||
try {
|
||||
sqlSessionWrapper.startBatchSession(CommandMapper.class, createCommandBatchFunction, commandInfo);
|
||||
taskService.changeTaskStatus(commandInfo.getTaskId(), StateEnum.RUNNING.getStateNum());
|
||||
} catch (Exception e) {
|
||||
throw new DorisStartException(e, commandInfo.getTaskId());
|
||||
}
|
||||
}
|
||||
|
||||
@Async
|
||||
@DS("doris")
|
||||
public void createCommands(List<TaskCommandInfo> taskCommandInfos) throws DorisStartException {
|
||||
Function<CommandMapper, Function<List<TaskCommandInfo>, Void>> function = mapper -> list -> {
|
||||
if (list == null || list.isEmpty()) {
|
||||
return null;
|
||||
}
|
||||
|
||||
for (TaskCommandInfo info : list) {
|
||||
createCommandBatchFunction.apply(mapper).apply(info);
|
||||
}
|
||||
|
||||
taskService.changeTaskStatus(list.get(0).getTaskId(), StateEnum.RUNNING.getStateNum());
|
||||
return null;
|
||||
return true;
|
||||
};
|
||||
|
||||
try {
|
||||
sqlSessionWrapper.startBatchSession(CommandMapper.class, function, taskCommandInfos);
|
||||
} catch (Exception e) {
|
||||
TaskCommandInfo info = null;
|
||||
if (taskCommandInfos != null) {
|
||||
info = taskCommandInfos.get(0);
|
||||
}
|
||||
Long taskId = null;
|
||||
if (info != null) {
|
||||
taskId = info.getTaskId();
|
||||
}
|
||||
throw new DorisStartException(e, taskId);
|
||||
}
|
||||
sqlSessionWrapper.startBatchSession(CommandMapper.class, function, taskCommandInfos);
|
||||
|
||||
}
|
||||
|
||||
public List<TaskCommandInfo> queryCommandInfoByTaskId(Long taskId) {
|
||||
return commandMapper.queryCommandInfoByTaskId(taskId);
|
||||
}
|
||||
|
||||
@DS("doris")
|
||||
public Boolean startCommandsByTaskId(Long taskId) {
|
||||
return commandMapper.startCommandsByTaskId(taskId);
|
||||
}
|
||||
|
||||
@DS("doris")
|
||||
public Boolean stopCommandsByTaskId(Long taskId) {
|
||||
return commandMapper.stopCommandsByTaskId(taskId);
|
||||
}
|
||||
|
||||
@DS("doris")
|
||||
public Boolean removeCommandsByTaskId(Long taskId) {
|
||||
return commandMapper.removeCommandsByTaskId(taskId);
|
||||
}
|
||||
|
||||
@@ -78,7 +78,6 @@ public class ProtectObjectService {
|
||||
return false;
|
||||
}
|
||||
boolean success = true;
|
||||
Integer result;
|
||||
|
||||
List<Integer> protectObjectBatch = ListUtils.newArrayListWithExpectedSize(batchSize);
|
||||
for (Integer protectObjectId : list) {
|
||||
|
||||
@@ -1,9 +1,11 @@
|
||||
package com.realtime.protection.server.task;
|
||||
|
||||
import com.realtime.protection.configuration.entity.task.Task;
|
||||
import com.realtime.protection.configuration.entity.task.TaskCommandInfo;
|
||||
import com.realtime.protection.configuration.exception.DorisStartException;
|
||||
import com.realtime.protection.configuration.response.ResponseResult;
|
||||
import com.realtime.protection.configuration.utils.EntityUtils;
|
||||
import com.realtime.protection.server.command.CommandService;
|
||||
import com.realtime.protection.server.task.status.StateChangeService;
|
||||
import jakarta.validation.Valid;
|
||||
import jakarta.validation.constraints.Max;
|
||||
@@ -18,10 +20,12 @@ import java.util.List;
|
||||
public class TaskController implements TaskControllerApi {
|
||||
|
||||
private final TaskService taskService;
|
||||
private final CommandService commandService;
|
||||
private final StateChangeService stateChangeService;
|
||||
|
||||
public TaskController(TaskService taskService, StateChangeService stateChangeService) {
|
||||
public TaskController(TaskService taskService, CommandService commandService, StateChangeService stateChangeService) {
|
||||
this.taskService = taskService;
|
||||
this.commandService = commandService;
|
||||
this.stateChangeService = stateChangeService;
|
||||
}
|
||||
|
||||
@@ -43,6 +47,24 @@ public class TaskController implements TaskControllerApi {
|
||||
.setData("success", false);
|
||||
}
|
||||
|
||||
// API推送Endpoint
|
||||
@Override
|
||||
@PostMapping("/api/new")
|
||||
public ResponseResult newTaskWithAPI(@RequestBody @Valid TaskCommandInfo taskCommandInfo) {
|
||||
Long taskId = taskService.newTaskUsingCommandInfo(taskCommandInfo);
|
||||
if (taskId <= 0) {
|
||||
return ResponseResult.invalid()
|
||||
.setData("taskId", -1)
|
||||
.setData("success", false);
|
||||
}
|
||||
|
||||
commandService.createCommand(taskCommandInfo);
|
||||
|
||||
return ResponseResult.ok()
|
||||
.setData("taskId", taskId)
|
||||
.setData("success", true);
|
||||
}
|
||||
|
||||
@Override
|
||||
@GetMapping("/query")
|
||||
public ResponseResult queryTasks(@RequestParam(value = "task_status", required = false) Integer taskStatus,
|
||||
@@ -62,7 +84,7 @@ public class TaskController implements TaskControllerApi {
|
||||
Task task = taskService.queryTask(id);
|
||||
|
||||
if (task == null) {
|
||||
return ResponseResult.invalid().setMessage("Task ID is invalid");
|
||||
return ResponseResult.invalid().setMessage("无效Task ID,也许该ID对应的任务不存在?");
|
||||
}
|
||||
|
||||
return ResponseResult.ok()
|
||||
@@ -103,7 +125,16 @@ public class TaskController implements TaskControllerApi {
|
||||
@PathVariable @NotNull Long taskId) throws DorisStartException {
|
||||
return ResponseResult.ok()
|
||||
.setData("task_id", taskId)
|
||||
.setData("success", stateChangeService.changeState(stateNum, taskId))
|
||||
// 外部修改状态,需要进行状态检查
|
||||
.setData("success", stateChangeService.changeState(stateNum, taskId, false))
|
||||
.setData("status_now", taskService.queryTaskStatus(taskId));
|
||||
}
|
||||
|
||||
@Override
|
||||
@GetMapping("/{taskId}/commands")
|
||||
public ResponseResult queryCommandInfoByTaskId(@PathVariable Long taskId) {
|
||||
return ResponseResult.ok()
|
||||
.setData("success", true)
|
||||
.setData("commands", commandService.queryCommandInfoByTaskId(taskId));
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
package com.realtime.protection.server.task;
|
||||
|
||||
import com.realtime.protection.configuration.entity.task.Task;
|
||||
import com.realtime.protection.configuration.entity.task.TaskCommandInfo;
|
||||
import com.realtime.protection.configuration.exception.DorisStartException;
|
||||
import com.realtime.protection.configuration.response.ResponseResult;
|
||||
import io.swagger.v3.oas.annotations.Operation;
|
||||
@@ -34,6 +35,24 @@ public interface TaskControllerApi {
|
||||
)
|
||||
ResponseResult newTask(@RequestBody @Valid Task task);
|
||||
|
||||
// API推送Endpoint
|
||||
@PostMapping("/api/new")
|
||||
@Operation(
|
||||
summary = "任务推送外部API",
|
||||
description = "提供给外部的任务推送API",
|
||||
responses = {
|
||||
@ApiResponse(
|
||||
description = "返回外部任务推送结果",
|
||||
content = @Content(
|
||||
mediaType = "application/json",
|
||||
schema = @Schema(implementation = ResponseResult.class)
|
||||
)
|
||||
)
|
||||
},
|
||||
requestBody = @io.swagger.v3.oas.annotations.parameters.RequestBody(description = "任务推送信息")
|
||||
)
|
||||
ResponseResult newTaskWithAPI(@RequestBody @Valid TaskCommandInfo taskCommandInfo) throws DorisStartException;
|
||||
|
||||
@GetMapping("/query")
|
||||
@Operation(
|
||||
summary = "查询任务",
|
||||
@@ -162,4 +181,23 @@ public interface TaskControllerApi {
|
||||
)
|
||||
ResponseResult changeTaskStatus(@PathVariable @NotNull @Min(0) @Max(6) Integer stateNum,
|
||||
@PathVariable @NotNull Long taskId) throws DorisStartException;
|
||||
|
||||
@GetMapping("/{taskId}/commands")
|
||||
@Operation(
|
||||
summary = "获得任务已推送指令的相关数据",
|
||||
description = "获得任务已推送指令的相关数据,包括最新下发时间、首次下发时间、下发次数、下发成功次数等",
|
||||
responses = {
|
||||
@ApiResponse(
|
||||
description = "返回任务已推送指令的相关数据",
|
||||
content = @Content(
|
||||
mediaType = "application/json",
|
||||
schema = @Schema(implementation = ResponseResult.class)
|
||||
)
|
||||
)
|
||||
},
|
||||
parameters = {
|
||||
@Parameter(name = "taskId", description = "任务ID")
|
||||
}
|
||||
)
|
||||
ResponseResult queryCommandInfoByTaskId(@PathVariable Long taskId);
|
||||
}
|
||||
|
||||
@@ -12,10 +12,12 @@ public interface TaskMapper {
|
||||
void newTask(@Param("task") Task task);
|
||||
|
||||
void newTaskStaticRuleConcat(@Param("task_id") Long taskId,
|
||||
@Param("rule_ids") List<Long> staticRuleIds);
|
||||
@Param("rule_ids") List<Integer> staticRuleIds);
|
||||
|
||||
void newTaskDynamicRuleConcat(@Param("task_id") Long taskId,
|
||||
@Param("rule_ids") List<Long> dynamicRuleIds);
|
||||
@Param("rule_ids") List<Integer> dynamicRuleIds);
|
||||
|
||||
void newTaskUsingCommandInfo(@Param("info") TaskCommandInfo taskCommandInfo);
|
||||
|
||||
List<Task> queryTasks(@Param("task_status") Integer taskStatus, @Param("task_type") String task_type,
|
||||
@Param("task_name") String taskName, @Param("task_creator") String taskCreator,
|
||||
@@ -35,9 +37,13 @@ public interface TaskMapper {
|
||||
|
||||
Boolean changeTaskStatus(@Param("task_id") Long taskId, @Param("state") Integer stateNum);
|
||||
|
||||
List<TaskCommandInfo> getStaticCommands(@Param("task_id") Long taskId);
|
||||
List<TaskCommandInfo> getStaticCommandInfos(@Param("task_id") Long taskId);
|
||||
|
||||
Integer queryTaskAuditStatus(@Param("task_id") Long taskId);
|
||||
|
||||
Integer queryTaskStatus(@Param("task_id") Long taskId);
|
||||
|
||||
List<Integer> queryDynamicRuleIdsFromTaskId(@Param("task_id") Long taskId);
|
||||
|
||||
List<Integer> queryStaticRuleIdsFromTaskId(@Param("task_id") Long taskId);
|
||||
}
|
||||
|
||||
@@ -10,6 +10,7 @@ import org.springframework.transaction.annotation.Transactional;
|
||||
import java.util.List;
|
||||
|
||||
@Service
|
||||
@DS("mysql")
|
||||
public class TaskService {
|
||||
private final TaskMapper taskMapper;
|
||||
|
||||
@@ -19,6 +20,10 @@ public class TaskService {
|
||||
|
||||
@Transactional
|
||||
public Long newTask(Task task) {
|
||||
task.setTaskCreateUserId(1);
|
||||
task.setTaskCreateUsername("xxx");
|
||||
task.setTaskCreateDepart("xxx");
|
||||
|
||||
taskMapper.newTask(task);
|
||||
|
||||
if (task.getStaticRuleIds() != null && !task.getStaticRuleIds().isEmpty())
|
||||
@@ -30,14 +35,33 @@ public class TaskService {
|
||||
return task.getTaskId();
|
||||
}
|
||||
|
||||
@Transactional
|
||||
public List<Task> queryTasks(Integer taskStatus,
|
||||
String taskType, String taskName, String taskCreator,
|
||||
Integer page, Integer pageSize) {
|
||||
return taskMapper.queryTasks(taskStatus, taskType, taskName, taskCreator, page, pageSize);
|
||||
List<Task> tasks = taskMapper.queryTasks(taskStatus, taskType, taskName, taskCreator, page, pageSize);
|
||||
for (Task task : tasks) {
|
||||
if (task == null) {
|
||||
continue;
|
||||
}
|
||||
task.setStaticRuleIds(taskMapper.queryStaticRuleIdsFromTaskId(task.getTaskId()));
|
||||
task.setDynamicRuleIds(taskMapper.queryDynamicRuleIdsFromTaskId(task.getTaskId()));
|
||||
}
|
||||
|
||||
return tasks;
|
||||
}
|
||||
|
||||
@Transactional
|
||||
public Task queryTask(Long id) {
|
||||
return taskMapper.queryTask(id);
|
||||
Task task = taskMapper.queryTask(id);
|
||||
if (task == null) {
|
||||
return null;
|
||||
}
|
||||
|
||||
task.setStaticRuleIds(taskMapper.queryStaticRuleIdsFromTaskId(task.getTaskId()));
|
||||
task.setDynamicRuleIds(taskMapper.queryDynamicRuleIdsFromTaskId(task.getTaskId()));
|
||||
|
||||
return task;
|
||||
}
|
||||
|
||||
@Transactional
|
||||
@@ -74,13 +98,13 @@ public class TaskService {
|
||||
return taskMapper.deleteTask(taskId);
|
||||
}
|
||||
|
||||
@DS("mysql")
|
||||
|
||||
public Boolean changeTaskStatus(Long taskId, Integer stateNum) {
|
||||
return taskMapper.changeTaskStatus(taskId, stateNum);
|
||||
}
|
||||
|
||||
public List<TaskCommandInfo> getStaticCommandInfos(Long taskId) {
|
||||
return taskMapper.getStaticCommands(taskId);
|
||||
return taskMapper.getStaticCommandInfos(taskId);
|
||||
}
|
||||
|
||||
public Integer queryTaskAuditStatus(Long taskId) {
|
||||
@@ -90,4 +114,9 @@ public class TaskService {
|
||||
public Integer queryTaskStatus(Long taskId) {
|
||||
return taskMapper.queryTaskStatus(taskId);
|
||||
}
|
||||
|
||||
public Long newTaskUsingCommandInfo(TaskCommandInfo taskCommandInfo) {
|
||||
taskMapper.newTaskUsingCommandInfo(taskCommandInfo);
|
||||
return taskCommandInfo.getTaskId();
|
||||
}
|
||||
}
|
||||
|
||||
@@ -23,12 +23,7 @@ public class StateChangeService {
|
||||
}
|
||||
|
||||
@DSTransactional
|
||||
public Boolean changeState(Integer stateNum, Long taskId) throws DorisStartException {
|
||||
if (Objects.equals(stateNum, StateEnum.GENERATING.getStateNum()) ||
|
||||
Objects.equals(stateNum, StateEnum.FAILED.getStateNum())) {
|
||||
throw new IllegalArgumentException("非法任务状态:" + StateEnum.getStateByNum(stateNum));
|
||||
}
|
||||
|
||||
public Boolean changeState(Integer stateNum, Long taskId, Boolean inner) throws DorisStartException {
|
||||
Integer originalStateNum = taskService.queryTaskStatus(taskId);
|
||||
if (originalStateNum == null) {
|
||||
throw new IllegalArgumentException("无法找到" + taskId + "的任务状态,也许任务ID不存在?");
|
||||
@@ -38,8 +33,15 @@ public class StateChangeService {
|
||||
|
||||
State newState = StateEnum.getStateByNum(stateNum);
|
||||
|
||||
if (newState == null) {
|
||||
return false;
|
||||
if (!inner && !checkState(originalState, newState)) {
|
||||
throw new IllegalArgumentException(
|
||||
String.format("任务状态转换失败,原状态:%s,欲切换状态:%s",
|
||||
originalState.getClass().getSimpleName(),
|
||||
newState.getClass().getSimpleName()));
|
||||
}
|
||||
|
||||
if (Objects.equals(originalState, newState)) {
|
||||
return true;
|
||||
}
|
||||
|
||||
if (!originalState.handle(newState, commandService, taskService, taskId)) {
|
||||
@@ -54,4 +56,21 @@ public class StateChangeService {
|
||||
// 这里一定是handle成功的状态,我们再进行task status的修改,如果handle失败,要么返回false,要么抛出异常,不会进入此处
|
||||
return taskService.changeTaskStatus(taskId, stateNum);
|
||||
}
|
||||
|
||||
private Boolean checkState(State originalState, State newState) {
|
||||
if (originalState == null || newState == null) {
|
||||
return false;
|
||||
}
|
||||
|
||||
// FAILED、FINISHED状态以及GENERATING都只能在程序内部修改,外部接口不能修改
|
||||
if (Objects.equals(newState, StateEnum.FAILED.getState())
|
||||
|| Objects.equals(newState, StateEnum.FINISHED.getState())
|
||||
|| Objects.equals(newState, StateEnum.GENERATING.getState())) {
|
||||
return false;
|
||||
}
|
||||
|
||||
// 在任务状态转换为GENERATING之后,我们需要在外部接口屏蔽掉所有状态
|
||||
// 我们需要保证只有任务创建函数才能将GENERATING状态转换为RUNNING状态
|
||||
return !Objects.equals(originalState, StateEnum.GENERATING.getState());
|
||||
}
|
||||
}
|
||||
|
||||
@@ -13,7 +13,6 @@ public class FailedState extends StateHandler implements State {
|
||||
return switch (StateEnum.getStateEnumByState(newState)) {
|
||||
case RUNNING -> handleStart(taskService, commandService, taskId);
|
||||
case STOP -> handleStop(commandService, taskId);
|
||||
case FAILED -> true;
|
||||
default -> throw new IllegalStateException("Unexpected value: " + StateEnum.getStateEnumByState(newState));
|
||||
};
|
||||
}
|
||||
|
||||
@@ -1,6 +1,5 @@
|
||||
package com.realtime.protection.server.task.status.states;
|
||||
|
||||
import com.realtime.protection.configuration.exception.DorisStartException;
|
||||
import com.realtime.protection.configuration.utils.enums.StateEnum;
|
||||
import com.realtime.protection.configuration.utils.status.State;
|
||||
import com.realtime.protection.server.command.CommandService;
|
||||
@@ -9,9 +8,9 @@ import com.realtime.protection.server.task.status.StateHandler;
|
||||
|
||||
public class GeneratingState extends StateHandler implements State {
|
||||
@Override
|
||||
public Boolean handle(State newState, CommandService commandService, TaskService taskService, Long taskId) throws DorisStartException {
|
||||
public Boolean handle(State newState, CommandService commandService, TaskService taskService, Long taskId) {
|
||||
return switch (StateEnum.getStateEnumByState(newState)) {
|
||||
case RUNNING, GENERATING -> true;
|
||||
case RUNNING -> true;
|
||||
case FAILED -> handleFailed(commandService, taskId);
|
||||
default -> throw new IllegalStateException("Unexpected value: " + StateEnum.getStateEnumByState(newState));
|
||||
};
|
||||
|
||||
@@ -11,8 +11,8 @@ public class PendingState extends StateHandler implements State {
|
||||
@Override
|
||||
public Boolean handle(State newState, CommandService commandService, TaskService taskService, Long taskId) throws DorisStartException {
|
||||
return switch (StateEnum.getStateEnumByState(newState)) {
|
||||
case GENERATING -> handleStart(taskService, commandService, taskId);
|
||||
case FAILED -> handleFailed(commandService, taskId);
|
||||
case RUNNING -> handleStart(taskService, commandService, taskId);
|
||||
default -> throw new IllegalStateException("Unexpected value: " + StateEnum.getStateEnumByState(newState));
|
||||
};
|
||||
}
|
||||
|
||||
@@ -10,7 +10,6 @@ public class RunningState extends StateHandler implements State {
|
||||
@Override
|
||||
public Boolean handle(State newState, CommandService commandService, TaskService taskService, Long taskId) {
|
||||
return switch (StateEnum.getStateEnumByState(newState)) {
|
||||
case RUNNING, GENERATING -> true;
|
||||
case PAUSED -> handlePause(commandService, taskId);
|
||||
case STOP -> handleStop(commandService, taskId);
|
||||
case FINISHED -> handleFinish(commandService, taskId);
|
||||
|
||||
@@ -2,7 +2,7 @@ package com.realtime.protection.server.whitelist;
|
||||
|
||||
import com.alibaba.excel.util.ListUtils;
|
||||
import com.realtime.protection.configuration.entity.rule.staticrule.StaticRuleObject;
|
||||
import com.realtime.protection.configuration.entity.task.Command;
|
||||
import com.realtime.protection.configuration.entity.task.TaskCommandInfo;
|
||||
import com.realtime.protection.configuration.entity.whitelist.WhiteListObject;
|
||||
import com.realtime.protection.configuration.utils.SqlSessionWrapper;
|
||||
import com.realtime.protection.configuration.utils.status.AuditStatusValidator;
|
||||
@@ -129,10 +129,10 @@ public class WhiteListService {
|
||||
|
||||
}
|
||||
|
||||
public List<WhiteListObject> whiteListCommandJudge(Command command) {
|
||||
public List<WhiteListObject> whiteListCommandJudge(TaskCommandInfo taskCommandInfo) {
|
||||
//参数应该是指令,不管动态静态
|
||||
// 命中的whitelist列表:每一列包含ip port url
|
||||
return whiteListMapper.whiteListCommandJudge(command.getFiveTupleWithMask());
|
||||
return whiteListMapper.whiteListCommandJudge(taskCommandInfo.getFiveTupleWithMask());
|
||||
|
||||
}
|
||||
|
||||
|
||||
Reference in New Issue
Block a user