diff --git a/src/main/java/com/realtime/protection/configuration/entity/defense/object/ProtectObject.java b/src/main/java/com/realtime/protection/configuration/entity/defense/object/ProtectObject.java index ca286bc..d472407 100644 --- a/src/main/java/com/realtime/protection/configuration/entity/defense/object/ProtectObject.java +++ b/src/main/java/com/realtime/protection/configuration/entity/defense/object/ProtectObject.java @@ -29,7 +29,7 @@ public class ProtectObject { private String protectObjectSystemName; @JsonProperty("proobj_ip_address") - @Pattern(regexp = "^(\\d{1,3})\\.(\\d{1,3})\\.(\\d{1,3})\\.(\\d{1,3})$", message = "Invalid IPv4 Address") + @Pattern(regexp = "^(\\d{1,3})\\.(\\d{1,3})\\.(\\d{1,3})\\.(\\d{1,3})$", message = "无效IPv4地址") @ExcelProperty("IP地址") @Schema(description = "防护对象IPv4地址", example = "192.168.0.1") private String protectObjectIPAddress; diff --git a/src/main/java/com/realtime/protection/configuration/entity/task/Command.java b/src/main/java/com/realtime/protection/configuration/entity/task/Command.java deleted file mode 100644 index 27fafb7..0000000 --- a/src/main/java/com/realtime/protection/configuration/entity/task/Command.java +++ /dev/null @@ -1,32 +0,0 @@ -package com.realtime.protection.configuration.entity.task; - -import com.realtime.protection.configuration.utils.enums.ProtocolEnum; -import lombok.Data; - -import java.time.LocalDateTime; - -@Data -public class Command { - private FiveTupleWithMask fiveTupleWithMask; - private Long taskId; - - private String operation; - private LocalDateTime validTime; - private LocalDateTime invalidTime; - - public static Command generateCommand(TaskCommandInfo info, LocalDateTime validTime) { - Command command = new Command(); - - FiveTupleWithMask fiveTupleWithMask = info.getFiveTupleWithMask(); - if (fiveTupleWithMask.getProtocol() != null) - fiveTupleWithMask.setProtocolNum(ProtocolEnum.valueOf(fiveTupleWithMask.getProtocol()).getProtocolNumber()); - - command.setFiveTupleWithMask(fiveTupleWithMask); - command.setTaskId(info.getTaskId()); - command.setOperation(info.getOperation()); - command.setValidTime(validTime); - command.setInvalidTime(info.getEndTime()); - - return command; - } -} diff --git a/src/main/java/com/realtime/protection/configuration/entity/task/FiveTupleWithMask.java b/src/main/java/com/realtime/protection/configuration/entity/task/FiveTupleWithMask.java index 55b563a..33ddd5b 100644 --- a/src/main/java/com/realtime/protection/configuration/entity/task/FiveTupleWithMask.java +++ b/src/main/java/com/realtime/protection/configuration/entity/task/FiveTupleWithMask.java @@ -1,19 +1,58 @@ package com.realtime.protection.configuration.entity.task; +import io.swagger.v3.oas.annotations.media.Schema; +import jakarta.validation.constraints.Max; +import jakarta.validation.constraints.Min; +import jakarta.validation.constraints.Pattern; import lombok.Data; @Data public class FiveTupleWithMask { + @Schema(description = "地址类型(IPv4 or IPv6)", example = "4") private Integer addrType; + + @Schema(description = "源IP", example = "192.168.104.14") + @Pattern(regexp = "^(\\d{1,3})\\.(\\d{1,3})\\.(\\d{1,3})\\.(\\d{1,3})$", message = "源IP:无效IPv4地址") private String sourceIP; + + @Schema(description = "源端口", example = "114") + @Max(value = 65535, message = "源端口不可大于65535") + @Min(value = 1, message = "源端口不可小于1") private String sourcePort; + + @Schema(description = "目的IP", example = "102.165.11.39") + @Pattern(regexp = "^(\\d{1,3})\\.(\\d{1,3})\\.(\\d{1,3})\\.(\\d{1,3})$", message = "目的IP:无效IPv4地址") private String destinationIP; + + @Schema(description = "目的端口", example = "514") + @Max(value = 65535, message = "目的端口不可大于65535") + @Min(value = 1, message = "目的端口不可小于1") private String destinationPort; + + @Schema(description = "协议名称", example = "TCP", accessMode = Schema.AccessMode.WRITE_ONLY) private String protocol; + + @Schema(description = "协议号", example = "6", accessMode = Schema.AccessMode.READ_ONLY) private Integer protocolNum; + @Schema(description = "源IP掩码", example = "255.255.255.0") + @Pattern(regexp = "^(\\d{1,3})\\.(\\d{1,3})\\.(\\d{1,3})\\.(\\d{1,3})$", message = "源IP掩码:无效IPv4地址") private String maskSourceIP; + + @Schema(description = "源端口掩码", example = "0") + @Max(value = 65535, message = "源端口掩码不可大于65535") + @Min(value = 1, message = "源端口掩码不可小于1") private String maskSourcePort; + + @Schema(description = "目的IP掩码", example = "255.255.0.0") + @Pattern(regexp = "^(\\d{1,3})\\.(\\d{1,3})\\.(\\d{1,3})\\.(\\d{1,3})$", message = "目的IP掩码:无效IPv4地址") private String maskDestinationIP; + + @Schema(description = "目的端口掩码", example = "0") + @Max(value = 65535, message = "目的端口掩码不可大于65535") + @Min(value = 1, message = "目的端口掩码不可小于1") private String maskDestinationPort; + + @Schema(description = "协议掩码", example = "0") + private String maskProtocol; } diff --git a/src/main/java/com/realtime/protection/configuration/entity/task/Task.java b/src/main/java/com/realtime/protection/configuration/entity/task/Task.java index 818e7d2..dc31aff 100644 --- a/src/main/java/com/realtime/protection/configuration/entity/task/Task.java +++ b/src/main/java/com/realtime/protection/configuration/entity/task/Task.java @@ -52,24 +52,24 @@ public class Task { private String taskAct; @JsonProperty("task_create_username") - @Schema(hidden = true) + @Schema(description = "任务创建人名称", accessMode = Schema.AccessMode.READ_ONLY) private String taskCreateUsername; @JsonProperty("task_create_depart") - @Schema(hidden = true) + @Schema(description = "任务创建人处室", accessMode = Schema.AccessMode.READ_ONLY) private String taskCreateDepart; @JsonProperty("task_create_userid") - @Schema(hidden = true) - private Long taskCreateUserId; + @Schema(description = "任务创建人ID", accessMode = Schema.AccessMode.READ_ONLY) + private Integer taskCreateUserId; @JsonProperty("static_rule_ids") @Schema(description = "静态规则ID列表,动态和静态至少存在1个规则", example = "[10, 12]") - private List staticRuleIds; + private List staticRuleIds; @JsonProperty("dynamic_rule_ids") @Schema(description = "动态规则ID列表,动态和静态至少存在1个规则", example = "[20, 30]") - private List dynamicRuleIds; + private List dynamicRuleIds; @JsonProperty("task_status") @Schema(description = "任务状态(0为未启动,1为生成中,2为运行中,3为暂停中,4为已停止,5为已结束,6为失败)", accessMode = Schema.AccessMode.READ_ONLY) diff --git a/src/main/java/com/realtime/protection/configuration/entity/task/TaskCommandInfo.java b/src/main/java/com/realtime/protection/configuration/entity/task/TaskCommandInfo.java index d420a94..18f65db 100644 --- a/src/main/java/com/realtime/protection/configuration/entity/task/TaskCommandInfo.java +++ b/src/main/java/com/realtime/protection/configuration/entity/task/TaskCommandInfo.java @@ -1,19 +1,68 @@ package com.realtime.protection.configuration.entity.task; +import io.swagger.v3.oas.annotations.media.Schema; +import jakarta.validation.constraints.NotNull; import lombok.Data; import java.time.LocalDateTime; @Data public class TaskCommandInfo { - private FiveTupleWithMask fiveTupleWithMask; + @Schema(description = "指令UUID", accessMode = Schema.AccessMode.READ_ONLY) + private String UUID; + @Schema(description = "任务ID", accessMode = Schema.AccessMode.READ_ONLY) private Long taskId; + + @Schema(description = "规则ID", hidden = true) private Long ruleId; - // 额外字段 - private String operation; + @Schema(description = "任务创建人名称", accessMode = Schema.AccessMode.READ_ONLY) + private String taskCreateUsername; + + @Schema(description = "任务创建人处室", accessMode = Schema.AccessMode.READ_ONLY) + private String taskCreateDepart; + + @Schema(description = "任务创建人ID", accessMode = Schema.AccessMode.READ_ONLY) + private Integer taskCreateUserId; + + @Schema(description = "任务名称", example = "API测试任务") + @NotNull(message = "任务名称不能为空") + private String taskName; + + @Schema(description = "任务类型", example = "1") + @NotNull(message = "任务类型不能为空") + private Integer taskType; + + @Schema(description = "任务操作", example = "阻断") + @NotNull(message = "任务操作不能为空。") + private String taskAct; + + @Schema(description = "指令下发频率", example = "30") + @NotNull(message = "指令下发频率不能为空。") private Integer frequency; + + @Schema(description = "任务开始时间", example = "2025-10-14T10:23:33") + @NotNull(message = "任务开始时间不能为空。") private LocalDateTime startTime; + + @Schema(description = "任务结束时间", example = "2026-10-22T10:33:22") + @NotNull(message = "指令结束时间不能为空。") private LocalDateTime endTime; + + @Schema(description = "五元组信息") + @NotNull(message = "五元组信息不能为空。") + private FiveTupleWithMask fiveTupleWithMask; + + @Schema(description = "指令下发次数", accessMode = Schema.AccessMode.READ_ONLY) + private Integer commandSentTimes; + + @Schema(description = "指令成功次数", accessMode = Schema.AccessMode.READ_ONLY) + private Integer commandSuccessTimes; + + @Schema(description = "首次下发时间", accessMode = Schema.AccessMode.READ_ONLY) + private LocalDateTime earliestSendTime; + + @Schema(description = "最新下发时间", accessMode = Schema.AccessMode.READ_ONLY) + private LocalDateTime latestSendTime; } diff --git a/src/main/java/com/realtime/protection/configuration/exception/GlobalExceptionHandler.java b/src/main/java/com/realtime/protection/configuration/exception/GlobalExceptionHandler.java index 4ee2540..835ad1a 100644 --- a/src/main/java/com/realtime/protection/configuration/exception/GlobalExceptionHandler.java +++ b/src/main/java/com/realtime/protection/configuration/exception/GlobalExceptionHandler.java @@ -15,6 +15,8 @@ import org.springframework.web.bind.annotation.ExceptionHandler; import org.springframework.web.bind.annotation.RestControllerAdvice; import org.springframework.web.method.annotation.HandlerMethodValidationException; +import java.sql.SQLException; +import java.sql.SQLIntegrityConstraintViolationException; import java.util.stream.Collectors; @RestControllerAdvice @@ -36,7 +38,12 @@ public class GlobalExceptionHandler { @Order(2) - @ExceptionHandler(value = {PersistenceException.class, DuplicateKeyException.class}) + @ExceptionHandler(value = { + PersistenceException.class, + DuplicateKeyException.class, + SQLException.class, + SQLIntegrityConstraintViolationException.class + }) public ResponseResult handleSQLException(Exception e) { log.info("遭遇数据库异常:" + e.getMessage()); return ResponseResult.invalid().setMessage( @@ -89,7 +96,8 @@ public class GlobalExceptionHandler { .setMessage("Doris数据库指令生成遭遇异常:" + e.getMessage()); try { - stateChangeService.changeState(StateEnum.FAILED.getStateNum(), e.taskId); + // 内部修改状态,可以跳过一切状态检查 + stateChangeService.changeState(StateEnum.FAILED.getStateNum(), e.taskId, true); } catch (Exception another) { responseResult.setAnother(ResponseResult.error().setMessage(e.getMessage())); } diff --git a/src/main/java/com/realtime/protection/configuration/utils/status/AuditStatusValidator.java b/src/main/java/com/realtime/protection/configuration/utils/status/AuditStatusValidator.java index 9226059..22f53bb 100644 --- a/src/main/java/com/realtime/protection/configuration/utils/status/AuditStatusValidator.java +++ b/src/main/java/com/realtime/protection/configuration/utils/status/AuditStatusValidator.java @@ -1,5 +1,8 @@ package com.realtime.protection.configuration.utils.status; +import lombok.extern.slf4j.Slf4j; + +@Slf4j public class AuditStatusValidator { private final Integer auditStatusOriginal; @@ -12,8 +15,8 @@ public class AuditStatusValidator { return new AuditStatusValidator(auditStatusOriginal); } - public Boolean checkValidate(Integer auditStatusNow) { - switch (auditStatusNow) { + public Boolean checkValidate(Integer newAuditStatus) { + switch (newAuditStatus) { case 0, 1 -> { return auditStatusOriginal != 2; } @@ -21,6 +24,7 @@ public class AuditStatusValidator { return auditStatusOriginal != 1; } default -> { + log.debug("欲修改的审核状态不正确,需要使用正确的审核状态,当前的审核状态:" + auditStatusOriginal); return false; } } diff --git a/src/main/java/com/realtime/protection/server/command/CommandMapper.java b/src/main/java/com/realtime/protection/server/command/CommandMapper.java index 751cc8c..4221835 100644 --- a/src/main/java/com/realtime/protection/server/command/CommandMapper.java +++ b/src/main/java/com/realtime/protection/server/command/CommandMapper.java @@ -1,6 +1,6 @@ package com.realtime.protection.server.command; -import com.realtime.protection.configuration.entity.task.Command; +import com.realtime.protection.configuration.entity.task.TaskCommandInfo; import org.apache.ibatis.annotations.Mapper; import org.apache.ibatis.annotations.Param; @@ -8,13 +8,15 @@ import java.util.List; @Mapper public interface CommandMapper { - Boolean createCommand(@Param("command") Command command); + Boolean createCommand(@Param("info") TaskCommandInfo taskCommandInfo); - void createCommands(@Param("commands") List commands); + void createCommands(@Param("command_infos") List taskCommandInfos); Boolean stopCommandsByTaskId(@Param("task_id") Long taskId); Boolean removeCommandsByTaskId(@Param("task_id") Long taskId); Boolean startCommandsByTaskId(@Param("task_id") Long taskId); + + List queryCommandInfoByTaskId(@Param("task_id") Long taskId); } diff --git a/src/main/java/com/realtime/protection/server/command/CommandService.java b/src/main/java/com/realtime/protection/server/command/CommandService.java index df52b48..a36f224 100644 --- a/src/main/java/com/realtime/protection/server/command/CommandService.java +++ b/src/main/java/com/realtime/protection/server/command/CommandService.java @@ -2,120 +2,69 @@ package com.realtime.protection.server.command; import com.alibaba.excel.util.ListUtils; import com.baomidou.dynamic.datasource.annotation.DS; -import com.realtime.protection.configuration.entity.task.Command; import com.realtime.protection.configuration.entity.task.TaskCommandInfo; -import com.realtime.protection.configuration.exception.DorisStartException; import com.realtime.protection.configuration.utils.SqlSessionWrapper; -import com.realtime.protection.configuration.utils.enums.StateEnum; -import com.realtime.protection.server.task.TaskService; import lombok.extern.slf4j.Slf4j; -import org.springframework.scheduling.annotation.Async; import org.springframework.stereotype.Service; -import java.time.LocalDateTime; import java.util.List; import java.util.function.Function; @Service @Slf4j +@DS("doris") public class CommandService { private final CommandMapper commandMapper; - private final TaskService taskService; private final SqlSessionWrapper sqlSessionWrapper; private static final int BatchSize = 100; - private final Function> createCommandBatchFunction; - public CommandService(CommandMapper commandMapper, TaskService taskService, SqlSessionWrapper sqlSessionWrapper) { + public CommandService(CommandMapper commandMapper, SqlSessionWrapper sqlSessionWrapper) { this.commandMapper = commandMapper; - this.taskService = taskService; this.sqlSessionWrapper = sqlSessionWrapper; - this.createCommandBatchFunction = mapper -> info -> { - if (info.getFrequency() == null) { - Command command = Command.generateCommand(info, info.getStartTime()); - mapper.createCommand(command); - } + } - List commandBatch = ListUtils.newArrayListWithExpectedSize(BatchSize); - LocalDateTime validTime = info.getStartTime(); + public Boolean createCommand(TaskCommandInfo commandInfo) { + return commandMapper.createCommand(commandInfo); + } - while (validTime.isBefore(info.getEndTime())) { - Command command = Command.generateCommand(info, validTime); - commandBatch.add(command); - - validTime = validTime.plusMinutes(info.getFrequency()); - - if (commandBatch.size() < BatchSize) { + public void createCommands(List taskCommandInfos) { + Function, Boolean>> function = mapper -> list -> { + List taskCommandInfoBatch = ListUtils.newArrayListWithExpectedSize(BatchSize); + for (TaskCommandInfo info : list) { + taskCommandInfoBatch.add(info); + if (taskCommandInfoBatch.size() < BatchSize) { continue; } - mapper.createCommands(commandBatch); - commandBatch.clear(); + + commandMapper.createCommands(taskCommandInfoBatch); + taskCommandInfoBatch.clear(); } - if (!commandBatch.isEmpty()) { - mapper.createCommands(commandBatch); - commandBatch.clear(); + if (!taskCommandInfoBatch.isEmpty()) { + commandMapper.createCommands(taskCommandInfoBatch); + taskCommandInfoBatch.clear(); } - log.debug(String.format("在task(%d)和rule(%d)中构建了全部指令", - info.getTaskId(), info.getRuleId())); - return null; - }; - } - - @Async - @DS("doris") - public void createCommand(TaskCommandInfo commandInfo) throws DorisStartException { - try { - sqlSessionWrapper.startBatchSession(CommandMapper.class, createCommandBatchFunction, commandInfo); - taskService.changeTaskStatus(commandInfo.getTaskId(), StateEnum.RUNNING.getStateNum()); - } catch (Exception e) { - throw new DorisStartException(e, commandInfo.getTaskId()); - } - } - - @Async - @DS("doris") - public void createCommands(List taskCommandInfos) throws DorisStartException { - Function, Void>> function = mapper -> list -> { - if (list == null || list.isEmpty()) { - return null; - } - - for (TaskCommandInfo info : list) { - createCommandBatchFunction.apply(mapper).apply(info); - } - - taskService.changeTaskStatus(list.get(0).getTaskId(), StateEnum.RUNNING.getStateNum()); - return null; + return true; }; - try { - sqlSessionWrapper.startBatchSession(CommandMapper.class, function, taskCommandInfos); - } catch (Exception e) { - TaskCommandInfo info = null; - if (taskCommandInfos != null) { - info = taskCommandInfos.get(0); - } - Long taskId = null; - if (info != null) { - taskId = info.getTaskId(); - } - throw new DorisStartException(e, taskId); - } + sqlSessionWrapper.startBatchSession(CommandMapper.class, function, taskCommandInfos); + + } + + public List queryCommandInfoByTaskId(Long taskId) { + return commandMapper.queryCommandInfoByTaskId(taskId); } - @DS("doris") public Boolean startCommandsByTaskId(Long taskId) { return commandMapper.startCommandsByTaskId(taskId); } - @DS("doris") public Boolean stopCommandsByTaskId(Long taskId) { return commandMapper.stopCommandsByTaskId(taskId); } - @DS("doris") public Boolean removeCommandsByTaskId(Long taskId) { return commandMapper.removeCommandsByTaskId(taskId); } diff --git a/src/main/java/com/realtime/protection/server/defense/object/ProtectObjectService.java b/src/main/java/com/realtime/protection/server/defense/object/ProtectObjectService.java index 2a73b90..11ce685 100644 --- a/src/main/java/com/realtime/protection/server/defense/object/ProtectObjectService.java +++ b/src/main/java/com/realtime/protection/server/defense/object/ProtectObjectService.java @@ -78,7 +78,6 @@ public class ProtectObjectService { return false; } boolean success = true; - Integer result; List protectObjectBatch = ListUtils.newArrayListWithExpectedSize(batchSize); for (Integer protectObjectId : list) { diff --git a/src/main/java/com/realtime/protection/server/rule/dynamicrule/DynamicRuleControllerApi.java b/src/main/java/com/realtime/protection/server/rule/dynamicrule/DynamicRuleControllerApi.java index a92a96c..769140d 100644 --- a/src/main/java/com/realtime/protection/server/rule/dynamicrule/DynamicRuleControllerApi.java +++ b/src/main/java/com/realtime/protection/server/rule/dynamicrule/DynamicRuleControllerApi.java @@ -74,7 +74,7 @@ public interface DynamicRuleControllerApi { @Parameter(name = "ids", description = "动态规则id列表") } ) - public ResponseResult deleteDynamicRuleObjects(@PathVariable List ids); + ResponseResult deleteDynamicRuleObjects(@PathVariable List ids); @Operation( summary = "修改动态规则", @@ -94,7 +94,7 @@ public interface DynamicRuleControllerApi { requestBody = @io.swagger.v3.oas.annotations.parameters.RequestBody( description = "动态规则信息") ) - public ResponseResult updateDynamicRuleObject( + ResponseResult updateDynamicRuleObject( @PathVariable Integer id, @RequestBody @Valid DynamicRuleObject dynamicRuleObject); @@ -114,7 +114,7 @@ public interface DynamicRuleControllerApi { @Parameter(name = "id", description = "动态规则ID", example = "2") } ) - public ResponseResult queryDynamicRuleObjectById(@PathVariable Integer id); + ResponseResult queryDynamicRuleObjectById(@PathVariable Integer id); @Operation( summary = "根据条件查询多个动态规则", @@ -135,7 +135,7 @@ public interface DynamicRuleControllerApi { @Parameter(name = "page_size", description = "每页大小", example = "10") } ) - public ResponseResult queryDynamicRuleObject( + ResponseResult queryDynamicRuleObject( @RequestParam(value = "name", required = false) String dynamicRuleName, @RequestParam(value = "id", required = false) Integer dynamicRuleId, @RequestParam(value = "page", defaultValue = "1") Integer page, diff --git a/src/main/java/com/realtime/protection/server/rule/staticrule/StaticRuleControllerApi.java b/src/main/java/com/realtime/protection/server/rule/staticrule/StaticRuleControllerApi.java index 9679367..057490a 100644 --- a/src/main/java/com/realtime/protection/server/rule/staticrule/StaticRuleControllerApi.java +++ b/src/main/java/com/realtime/protection/server/rule/staticrule/StaticRuleControllerApi.java @@ -183,6 +183,6 @@ public interface StaticRuleControllerApi { @Parameter(name = "auditStatus", description = "要修改为的静态规则审核状态") } ) - public ResponseResult updateStaticRuleAuditStatus(@PathVariable Integer id, @PathVariable Integer auditStatus); + ResponseResult updateStaticRuleAuditStatus(@PathVariable Integer id, @PathVariable Integer auditStatus); } diff --git a/src/main/java/com/realtime/protection/server/task/TaskController.java b/src/main/java/com/realtime/protection/server/task/TaskController.java index 8b8bf15..38da9f2 100644 --- a/src/main/java/com/realtime/protection/server/task/TaskController.java +++ b/src/main/java/com/realtime/protection/server/task/TaskController.java @@ -1,9 +1,11 @@ package com.realtime.protection.server.task; import com.realtime.protection.configuration.entity.task.Task; +import com.realtime.protection.configuration.entity.task.TaskCommandInfo; import com.realtime.protection.configuration.exception.DorisStartException; import com.realtime.protection.configuration.response.ResponseResult; import com.realtime.protection.configuration.utils.EntityUtils; +import com.realtime.protection.server.command.CommandService; import com.realtime.protection.server.task.status.StateChangeService; import jakarta.validation.Valid; import jakarta.validation.constraints.Max; @@ -18,10 +20,12 @@ import java.util.List; public class TaskController implements TaskControllerApi { private final TaskService taskService; + private final CommandService commandService; private final StateChangeService stateChangeService; - public TaskController(TaskService taskService, StateChangeService stateChangeService) { + public TaskController(TaskService taskService, CommandService commandService, StateChangeService stateChangeService) { this.taskService = taskService; + this.commandService = commandService; this.stateChangeService = stateChangeService; } @@ -43,6 +47,24 @@ public class TaskController implements TaskControllerApi { .setData("success", false); } + // API推送Endpoint + @Override + @PostMapping("/api/new") + public ResponseResult newTaskWithAPI(@RequestBody @Valid TaskCommandInfo taskCommandInfo) { + Long taskId = taskService.newTaskUsingCommandInfo(taskCommandInfo); + if (taskId <= 0) { + return ResponseResult.invalid() + .setData("taskId", -1) + .setData("success", false); + } + + commandService.createCommand(taskCommandInfo); + + return ResponseResult.ok() + .setData("taskId", taskId) + .setData("success", true); + } + @Override @GetMapping("/query") public ResponseResult queryTasks(@RequestParam(value = "task_status", required = false) Integer taskStatus, @@ -62,7 +84,7 @@ public class TaskController implements TaskControllerApi { Task task = taskService.queryTask(id); if (task == null) { - return ResponseResult.invalid().setMessage("Task ID is invalid"); + return ResponseResult.invalid().setMessage("无效Task ID,也许该ID对应的任务不存在?"); } return ResponseResult.ok() @@ -103,7 +125,16 @@ public class TaskController implements TaskControllerApi { @PathVariable @NotNull Long taskId) throws DorisStartException { return ResponseResult.ok() .setData("task_id", taskId) - .setData("success", stateChangeService.changeState(stateNum, taskId)) + // 外部修改状态,需要进行状态检查 + .setData("success", stateChangeService.changeState(stateNum, taskId, false)) .setData("status_now", taskService.queryTaskStatus(taskId)); } + + @Override + @GetMapping("/{taskId}/commands") + public ResponseResult queryCommandInfoByTaskId(@PathVariable Long taskId) { + return ResponseResult.ok() + .setData("success", true) + .setData("commands", commandService.queryCommandInfoByTaskId(taskId)); + } } diff --git a/src/main/java/com/realtime/protection/server/task/TaskControllerApi.java b/src/main/java/com/realtime/protection/server/task/TaskControllerApi.java index df37b8e..fe9a308 100644 --- a/src/main/java/com/realtime/protection/server/task/TaskControllerApi.java +++ b/src/main/java/com/realtime/protection/server/task/TaskControllerApi.java @@ -1,6 +1,7 @@ package com.realtime.protection.server.task; import com.realtime.protection.configuration.entity.task.Task; +import com.realtime.protection.configuration.entity.task.TaskCommandInfo; import com.realtime.protection.configuration.exception.DorisStartException; import com.realtime.protection.configuration.response.ResponseResult; import io.swagger.v3.oas.annotations.Operation; @@ -34,6 +35,24 @@ public interface TaskControllerApi { ) ResponseResult newTask(@RequestBody @Valid Task task); + // API推送Endpoint + @PostMapping("/api/new") + @Operation( + summary = "任务推送外部API", + description = "提供给外部的任务推送API", + responses = { + @ApiResponse( + description = "返回外部任务推送结果", + content = @Content( + mediaType = "application/json", + schema = @Schema(implementation = ResponseResult.class) + ) + ) + }, + requestBody = @io.swagger.v3.oas.annotations.parameters.RequestBody(description = "任务推送信息") + ) + ResponseResult newTaskWithAPI(@RequestBody @Valid TaskCommandInfo taskCommandInfo) throws DorisStartException; + @GetMapping("/query") @Operation( summary = "查询任务", @@ -162,4 +181,23 @@ public interface TaskControllerApi { ) ResponseResult changeTaskStatus(@PathVariable @NotNull @Min(0) @Max(6) Integer stateNum, @PathVariable @NotNull Long taskId) throws DorisStartException; + + @GetMapping("/{taskId}/commands") + @Operation( + summary = "获得任务已推送指令的相关数据", + description = "获得任务已推送指令的相关数据,包括最新下发时间、首次下发时间、下发次数、下发成功次数等", + responses = { + @ApiResponse( + description = "返回任务已推送指令的相关数据", + content = @Content( + mediaType = "application/json", + schema = @Schema(implementation = ResponseResult.class) + ) + ) + }, + parameters = { + @Parameter(name = "taskId", description = "任务ID") + } + ) + ResponseResult queryCommandInfoByTaskId(@PathVariable Long taskId); } diff --git a/src/main/java/com/realtime/protection/server/task/TaskMapper.java b/src/main/java/com/realtime/protection/server/task/TaskMapper.java index b01e6a2..1e26df7 100644 --- a/src/main/java/com/realtime/protection/server/task/TaskMapper.java +++ b/src/main/java/com/realtime/protection/server/task/TaskMapper.java @@ -12,10 +12,12 @@ public interface TaskMapper { void newTask(@Param("task") Task task); void newTaskStaticRuleConcat(@Param("task_id") Long taskId, - @Param("rule_ids") List staticRuleIds); + @Param("rule_ids") List staticRuleIds); void newTaskDynamicRuleConcat(@Param("task_id") Long taskId, - @Param("rule_ids") List dynamicRuleIds); + @Param("rule_ids") List dynamicRuleIds); + + void newTaskUsingCommandInfo(@Param("info") TaskCommandInfo taskCommandInfo); List queryTasks(@Param("task_status") Integer taskStatus, @Param("task_type") String task_type, @Param("task_name") String taskName, @Param("task_creator") String taskCreator, @@ -35,9 +37,13 @@ public interface TaskMapper { Boolean changeTaskStatus(@Param("task_id") Long taskId, @Param("state") Integer stateNum); - List getStaticCommands(@Param("task_id") Long taskId); + List getStaticCommandInfos(@Param("task_id") Long taskId); Integer queryTaskAuditStatus(@Param("task_id") Long taskId); Integer queryTaskStatus(@Param("task_id") Long taskId); + + List queryDynamicRuleIdsFromTaskId(@Param("task_id") Long taskId); + + List queryStaticRuleIdsFromTaskId(@Param("task_id") Long taskId); } diff --git a/src/main/java/com/realtime/protection/server/task/TaskService.java b/src/main/java/com/realtime/protection/server/task/TaskService.java index 2de3150..410fa9d 100644 --- a/src/main/java/com/realtime/protection/server/task/TaskService.java +++ b/src/main/java/com/realtime/protection/server/task/TaskService.java @@ -10,6 +10,7 @@ import org.springframework.transaction.annotation.Transactional; import java.util.List; @Service +@DS("mysql") public class TaskService { private final TaskMapper taskMapper; @@ -19,6 +20,10 @@ public class TaskService { @Transactional public Long newTask(Task task) { + task.setTaskCreateUserId(1); + task.setTaskCreateUsername("xxx"); + task.setTaskCreateDepart("xxx"); + taskMapper.newTask(task); if (task.getStaticRuleIds() != null && !task.getStaticRuleIds().isEmpty()) @@ -30,14 +35,33 @@ public class TaskService { return task.getTaskId(); } + @Transactional public List queryTasks(Integer taskStatus, String taskType, String taskName, String taskCreator, Integer page, Integer pageSize) { - return taskMapper.queryTasks(taskStatus, taskType, taskName, taskCreator, page, pageSize); + List tasks = taskMapper.queryTasks(taskStatus, taskType, taskName, taskCreator, page, pageSize); + for (Task task : tasks) { + if (task == null) { + continue; + } + task.setStaticRuleIds(taskMapper.queryStaticRuleIdsFromTaskId(task.getTaskId())); + task.setDynamicRuleIds(taskMapper.queryDynamicRuleIdsFromTaskId(task.getTaskId())); + } + + return tasks; } + @Transactional public Task queryTask(Long id) { - return taskMapper.queryTask(id); + Task task = taskMapper.queryTask(id); + if (task == null) { + return null; + } + + task.setStaticRuleIds(taskMapper.queryStaticRuleIdsFromTaskId(task.getTaskId())); + task.setDynamicRuleIds(taskMapper.queryDynamicRuleIdsFromTaskId(task.getTaskId())); + + return task; } @Transactional @@ -74,13 +98,13 @@ public class TaskService { return taskMapper.deleteTask(taskId); } - @DS("mysql") + public Boolean changeTaskStatus(Long taskId, Integer stateNum) { return taskMapper.changeTaskStatus(taskId, stateNum); } public List getStaticCommandInfos(Long taskId) { - return taskMapper.getStaticCommands(taskId); + return taskMapper.getStaticCommandInfos(taskId); } public Integer queryTaskAuditStatus(Long taskId) { @@ -90,4 +114,9 @@ public class TaskService { public Integer queryTaskStatus(Long taskId) { return taskMapper.queryTaskStatus(taskId); } + + public Long newTaskUsingCommandInfo(TaskCommandInfo taskCommandInfo) { + taskMapper.newTaskUsingCommandInfo(taskCommandInfo); + return taskCommandInfo.getTaskId(); + } } diff --git a/src/main/java/com/realtime/protection/server/task/status/StateChangeService.java b/src/main/java/com/realtime/protection/server/task/status/StateChangeService.java index 56e4185..450e093 100644 --- a/src/main/java/com/realtime/protection/server/task/status/StateChangeService.java +++ b/src/main/java/com/realtime/protection/server/task/status/StateChangeService.java @@ -23,12 +23,7 @@ public class StateChangeService { } @DSTransactional - public Boolean changeState(Integer stateNum, Long taskId) throws DorisStartException { - if (Objects.equals(stateNum, StateEnum.GENERATING.getStateNum()) || - Objects.equals(stateNum, StateEnum.FAILED.getStateNum())) { - throw new IllegalArgumentException("非法任务状态:" + StateEnum.getStateByNum(stateNum)); - } - + public Boolean changeState(Integer stateNum, Long taskId, Boolean inner) throws DorisStartException { Integer originalStateNum = taskService.queryTaskStatus(taskId); if (originalStateNum == null) { throw new IllegalArgumentException("无法找到" + taskId + "的任务状态,也许任务ID不存在?"); @@ -38,8 +33,15 @@ public class StateChangeService { State newState = StateEnum.getStateByNum(stateNum); - if (newState == null) { - return false; + if (!inner && !checkState(originalState, newState)) { + throw new IllegalArgumentException( + String.format("任务状态转换失败,原状态:%s,欲切换状态:%s", + originalState.getClass().getSimpleName(), + newState.getClass().getSimpleName())); + } + + if (Objects.equals(originalState, newState)) { + return true; } if (!originalState.handle(newState, commandService, taskService, taskId)) { @@ -54,4 +56,21 @@ public class StateChangeService { // 这里一定是handle成功的状态,我们再进行task status的修改,如果handle失败,要么返回false,要么抛出异常,不会进入此处 return taskService.changeTaskStatus(taskId, stateNum); } + + private Boolean checkState(State originalState, State newState) { + if (originalState == null || newState == null) { + return false; + } + + // FAILED、FINISHED状态以及GENERATING都只能在程序内部修改,外部接口不能修改 + if (Objects.equals(newState, StateEnum.FAILED.getState()) + || Objects.equals(newState, StateEnum.FINISHED.getState()) + || Objects.equals(newState, StateEnum.GENERATING.getState())) { + return false; + } + + // 在任务状态转换为GENERATING之后,我们需要在外部接口屏蔽掉所有状态 + // 我们需要保证只有任务创建函数才能将GENERATING状态转换为RUNNING状态 + return !Objects.equals(originalState, StateEnum.GENERATING.getState()); + } } diff --git a/src/main/java/com/realtime/protection/server/task/status/states/FailedState.java b/src/main/java/com/realtime/protection/server/task/status/states/FailedState.java index 76ad571..b38a16e 100644 --- a/src/main/java/com/realtime/protection/server/task/status/states/FailedState.java +++ b/src/main/java/com/realtime/protection/server/task/status/states/FailedState.java @@ -13,7 +13,6 @@ public class FailedState extends StateHandler implements State { return switch (StateEnum.getStateEnumByState(newState)) { case RUNNING -> handleStart(taskService, commandService, taskId); case STOP -> handleStop(commandService, taskId); - case FAILED -> true; default -> throw new IllegalStateException("Unexpected value: " + StateEnum.getStateEnumByState(newState)); }; } diff --git a/src/main/java/com/realtime/protection/server/task/status/states/GeneratingState.java b/src/main/java/com/realtime/protection/server/task/status/states/GeneratingState.java index e049636..1b37c84 100644 --- a/src/main/java/com/realtime/protection/server/task/status/states/GeneratingState.java +++ b/src/main/java/com/realtime/protection/server/task/status/states/GeneratingState.java @@ -1,6 +1,5 @@ package com.realtime.protection.server.task.status.states; -import com.realtime.protection.configuration.exception.DorisStartException; import com.realtime.protection.configuration.utils.enums.StateEnum; import com.realtime.protection.configuration.utils.status.State; import com.realtime.protection.server.command.CommandService; @@ -9,9 +8,9 @@ import com.realtime.protection.server.task.status.StateHandler; public class GeneratingState extends StateHandler implements State { @Override - public Boolean handle(State newState, CommandService commandService, TaskService taskService, Long taskId) throws DorisStartException { + public Boolean handle(State newState, CommandService commandService, TaskService taskService, Long taskId) { return switch (StateEnum.getStateEnumByState(newState)) { - case RUNNING, GENERATING -> true; + case RUNNING -> true; case FAILED -> handleFailed(commandService, taskId); default -> throw new IllegalStateException("Unexpected value: " + StateEnum.getStateEnumByState(newState)); }; diff --git a/src/main/java/com/realtime/protection/server/task/status/states/PendingState.java b/src/main/java/com/realtime/protection/server/task/status/states/PendingState.java index 17f4bf3..3735bac 100644 --- a/src/main/java/com/realtime/protection/server/task/status/states/PendingState.java +++ b/src/main/java/com/realtime/protection/server/task/status/states/PendingState.java @@ -11,8 +11,8 @@ public class PendingState extends StateHandler implements State { @Override public Boolean handle(State newState, CommandService commandService, TaskService taskService, Long taskId) throws DorisStartException { return switch (StateEnum.getStateEnumByState(newState)) { - case GENERATING -> handleStart(taskService, commandService, taskId); case FAILED -> handleFailed(commandService, taskId); + case RUNNING -> handleStart(taskService, commandService, taskId); default -> throw new IllegalStateException("Unexpected value: " + StateEnum.getStateEnumByState(newState)); }; } diff --git a/src/main/java/com/realtime/protection/server/task/status/states/RunningState.java b/src/main/java/com/realtime/protection/server/task/status/states/RunningState.java index cbe61db..d184cb2 100644 --- a/src/main/java/com/realtime/protection/server/task/status/states/RunningState.java +++ b/src/main/java/com/realtime/protection/server/task/status/states/RunningState.java @@ -10,7 +10,6 @@ public class RunningState extends StateHandler implements State { @Override public Boolean handle(State newState, CommandService commandService, TaskService taskService, Long taskId) { return switch (StateEnum.getStateEnumByState(newState)) { - case RUNNING, GENERATING -> true; case PAUSED -> handlePause(commandService, taskId); case STOP -> handleStop(commandService, taskId); case FINISHED -> handleFinish(commandService, taskId); diff --git a/src/main/java/com/realtime/protection/server/whitelist/WhiteListService.java b/src/main/java/com/realtime/protection/server/whitelist/WhiteListService.java index 3d939c4..7b68982 100644 --- a/src/main/java/com/realtime/protection/server/whitelist/WhiteListService.java +++ b/src/main/java/com/realtime/protection/server/whitelist/WhiteListService.java @@ -2,7 +2,7 @@ package com.realtime.protection.server.whitelist; import com.alibaba.excel.util.ListUtils; import com.realtime.protection.configuration.entity.rule.staticrule.StaticRuleObject; -import com.realtime.protection.configuration.entity.task.Command; +import com.realtime.protection.configuration.entity.task.TaskCommandInfo; import com.realtime.protection.configuration.entity.whitelist.WhiteListObject; import com.realtime.protection.configuration.utils.SqlSessionWrapper; import com.realtime.protection.configuration.utils.status.AuditStatusValidator; @@ -129,10 +129,10 @@ public class WhiteListService { } - public List whiteListCommandJudge(Command command) { + public List whiteListCommandJudge(TaskCommandInfo taskCommandInfo) { //参数应该是指令,不管动态静态 // 命中的whitelist列表:每一列包含ip port url - return whiteListMapper.whiteListCommandJudge(command.getFiveTupleWithMask()); + return whiteListMapper.whiteListCommandJudge(taskCommandInfo.getFiveTupleWithMask()); } diff --git a/src/main/resources/config/application-dev.yml b/src/main/resources/config/application-dev.yml index cd27f86..d1b128a 100644 --- a/src/main/resources/config/application-dev.yml +++ b/src/main/resources/config/application-dev.yml @@ -48,6 +48,8 @@ task: springdoc: api-docs: - enabled: false + enabled: true + path: /api-docs swagger-ui: - enabled: false \ No newline at end of file + path: /swagger + packages-to-scan: com.realtime.protection.server \ No newline at end of file diff --git a/src/main/resources/mappers/CommandMapper.xml b/src/main/resources/mappers/CommandMapper.xml index 8a7b5bb..fe0b50c 100644 --- a/src/main/resources/mappers/CommandMapper.xml +++ b/src/main/resources/mappers/CommandMapper.xml @@ -3,39 +3,79 @@ PUBLIC "-//mybatis.org//DTD Mapper 3.0//EN" "http://mybatis.org/dtd/mybatis-3-mapper.dtd"> - - insert into t_command(COMMAND_ID, ADDR_TYPE, SRC_IP, SRC_PORT, DST_IP, DST_PORT, PROTOCOL, MASK_SRC_IP, - MASK_SRC_PORT, - MASK_DST_IP, MASK_DST_PORT, IS_VALID, VALID_TIME, INVALID_TIME, IS_SENT, - CREATE_TIME, LAST_UPDATE, IS_DELETED, TASK_ID) - values (UUID(), DEFAULT, - #{command.fiveTupleWithMask.sourceIP}, - #{command.fiveTupleWithMask.sourcePort}, #{command.fiveTupleWithMask.destinationIP}, - #{command.fiveTupleWithMask.destinationPort}, - #{command.fiveTupleWithMask.protocolNum}, #{command.fiveTupleWithMask.maskSourceIP}, - #{command.fiveTupleWithMask.maskSourcePort}, - #{command.fiveTupleWithMask.maskDestinationIP}, #{command.fiveTupleWithMask.maskDestinationPort}, TRUE, - #{command.validTime}, #{command.invalidTime}, FALSE, NOW(), NOW(), FALSE, #{command.taskId}) + + insert into t_command(COMMAND_ID, TASK_ID, TASK_ACT, FREQUENCY, ADDR_TYPE, SRC_IP, SRC_PORT, DST_IP, DST_PORT, + PROTOCOL, + MASK_SRC_IP, MASK_SRC_PORT, MASK_DST_IP, MASK_DST_PORT, MASK_PROTOCOL, VALID_TIME, + INVALID_TIME, IS_VALID, + SEND_TIMES, SUCCESS_TIMES, CREATE_TIME, LAST_UPDATE, IS_DELETED) + values (UUID(), #{info.taskId}, #{info.taskAct}, #{info.frequency}, + #{info.fiveTupleWithMask.addrType}, + #{info.fiveTupleWithMask.sourceIP}, #{info.fiveTupleWithMask.sourcePort}, + #{info.fiveTupleWithMask.destinationIP}, #{info.fiveTupleWithMask.destinationPort}, + #{info.fiveTupleWithMask.protocolNum}, + #{info.fiveTupleWithMask.maskSourceIP}, #{info.fiveTupleWithMask.maskSourcePort}, + #{info.fiveTupleWithMask.maskDestinationIP}, #{info.fiveTupleWithMask.maskDestinationPort}, + #{info.fiveTupleWithMask.maskProtocol}, + #{info.startTime}, #{info.endTime}, TRUE, 0, 0, + NOW(), NOW(), FALSE) - - insert into t_command(COMMAND_ID, ADDR_TYPE, SRC_IP, SRC_PORT, DST_IP, DST_PORT, PROTOCOL, MASK_SRC_IP, - MASK_SRC_PORT, - MASK_DST_IP, MASK_DST_PORT, IS_VALID, VALID_TIME, INVALID_TIME, IS_SENT, - CREATE_TIME, LAST_UPDATE, IS_DELETED, TASK_ID) + + insert into t_command(COMMAND_ID, TASK_ID, TASK_ACT, FREQUENCY, ADDR_TYPE, SRC_IP, SRC_PORT, DST_IP, DST_PORT, + PROTOCOL, + MASK_SRC_IP, MASK_SRC_PORT, MASK_DST_IP, MASK_DST_PORT, MASK_PROTOCOL, VALID_TIME, INVALID_TIME, IS_VALID, + SEND_TIMES, SUCCESS_TIMES, CREATE_TIME, LAST_UPDATE, IS_DELETED) values - - (UUID(), DEFAULT, - #{command.fiveTupleWithMask.sourceIP}, - #{command.fiveTupleWithMask.sourcePort}, #{command.fiveTupleWithMask.destinationIP}, - #{command.fiveTupleWithMask.destinationPort}, - #{command.fiveTupleWithMask.protocolNum}, #{command.fiveTupleWithMask.maskSourceIP}, - #{command.fiveTupleWithMask.maskSourcePort}, - #{command.fiveTupleWithMask.maskDestinationIP}, #{command.fiveTupleWithMask.maskDestinationPort}, TRUE, - #{command.validTime}, #{command.invalidTime}, FALSE, NOW(), NOW(), FALSE, #{command.taskId}) + + (UUID(), #{info.taskId}, #{info.taskAct}, #{info.frequency}, + #{info.fiveTupleWithMask.addrType}, + #{info.fiveTupleWithMask.sourceIP}, #{info.fiveTupleWithMask.sourcePort}, + #{info.fiveTupleWithMask.destinationIP}, #{info.fiveTupleWithMask.destinationPort}, + #{info.fiveTupleWithMask.protocolNum}, + #{info.fiveTupleWithMask.maskSourceIP}, #{info.fiveTupleWithMask.maskSourcePort}, + #{info.fiveTupleWithMask.maskDestinationIP}, #{info.fiveTupleWithMask.maskDestinationPort}, + #{info.fiveTupleWithMask.maskProtocol}, + #{info.startTime}, #{info.endTime}, TRUE, 0, 0, + NOW(), NOW(), FALSE + ) + + + + + + + + + + + + + + + + + + + UPDATE t_command SET IS_VALID = FALSE, @@ -57,5 +97,6 @@ SET IS_DELETED = TRUE, LAST_UPDATE = NOW() WHERE TASK_ID = #{task_id} + AND IS_DELETED = FALSE diff --git a/src/main/resources/mappers/ProtectObjectMapper.xml b/src/main/resources/mappers/ProtectObjectMapper.xml index ae97dd1..a1ba13e 100644 --- a/src/main/resources/mappers/ProtectObjectMapper.xml +++ b/src/main/resources/mappers/ProtectObjectMapper.xml @@ -46,7 +46,10 @@ diff --git a/src/main/resources/mappers/TaskMapper.xml b/src/main/resources/mappers/TaskMapper.xml index f3b8aad..e51c151 100644 --- a/src/main/resources/mappers/TaskMapper.xml +++ b/src/main/resources/mappers/TaskMapper.xml @@ -18,27 +18,25 @@ UPDATE t_static_rule SET static_rule_used_task_id = #{task_id} - - - AND static_rule_id IN - - #{rule_id} - - - + WHERE + + static_rule_id IN + + #{rule_id} + + UPDATE t_dynamic_rule SET dynamic_rule_used_task_id = #{task_id} - - - AND dynamic_rule_id IN - - #{rule_id} - - - + WHERE + + dynamic_rule_id IN + + #{rule_id} + + @@ -55,19 +53,10 @@ - - - - - - - + + + + @@ -152,9 +151,16 @@ - + + + + + + + + @@ -170,15 +176,38 @@ + - + SELECT t_task.task_name, + + t_task.task_id, tsr.static_rule_id, + + t_task.task_create_username, + t_task.task_create_depart, + t_task.task_create_userid, + + t_task.task_type, t_task.task_act, + tsr.static_rule_frequency, + t_task.task_start_time, t_task.task_end_time, + INET_NTOA(tsr.static_rule_sip) as static_rule_sip, tsr.static_rule_sport, INET_NTOA(tsr.static_rule_dip) as static_rule_dip, @@ -188,7 +217,7 @@ tsr.static_rule_msport, INET_NTOA(tsr.static_rule_mdip) as static_rule_mdip, tsr.static_rule_mdport, - tsr.static_rule_frequency + tsr.static_rule_mprotocol FROM t_task LEFT JOIN realtime_protection.t_static_rule tsr on t_task.task_id = tsr.static_rule_used_task_id WHERE task_id = #{task_id} diff --git a/src/test/java/com/realtime/protection/ProtectionApplicationTests.java b/src/test/java/com/realtime/protection/ProtectionApplicationTests.java index 1bc42e9..d1ab67b 100644 --- a/src/test/java/com/realtime/protection/ProtectionApplicationTests.java +++ b/src/test/java/com/realtime/protection/ProtectionApplicationTests.java @@ -1,13 +1,12 @@ package com.realtime.protection; -import org.junit.jupiter.api.Test; +import com.baomidou.dynamic.datasource.annotation.DSTransactional; import org.springframework.boot.test.context.SpringBootTest; +import org.springframework.test.annotation.Rollback; @SpringBootTest -class ProtectionApplicationTests { - - @Test - void contextLoads() { - } +@Rollback +@DSTransactional +public class ProtectionApplicationTests { } diff --git a/src/test/java/com/realtime/protection/server/defense/object/ProtectObjectServiceTest.java b/src/test/java/com/realtime/protection/server/defense/object/ProtectObjectServiceTest.java index a537549..deff95b 100644 --- a/src/test/java/com/realtime/protection/server/defense/object/ProtectObjectServiceTest.java +++ b/src/test/java/com/realtime/protection/server/defense/object/ProtectObjectServiceTest.java @@ -1,5 +1,6 @@ package com.realtime.protection.server.defense.object; +import com.realtime.protection.ProtectionApplicationTests; import com.realtime.protection.configuration.entity.defense.object.ProtectObject; import org.junit.jupiter.api.BeforeEach; import org.junit.jupiter.api.Test; @@ -13,7 +14,7 @@ import java.util.List; import static org.junit.jupiter.api.Assertions.*; @SpringBootTest -class ProtectObjectServiceTest { +class ProtectObjectServiceTest extends ProtectionApplicationTests { private final ProtectObjectService protectObjectService; private ProtectObject protectObject; diff --git a/src/test/java/com/realtime/protection/server/defense/template/TemplateServiceTest.java b/src/test/java/com/realtime/protection/server/defense/template/TemplateServiceTest.java index fe2e97b..4b6436d 100644 --- a/src/test/java/com/realtime/protection/server/defense/template/TemplateServiceTest.java +++ b/src/test/java/com/realtime/protection/server/defense/template/TemplateServiceTest.java @@ -1,5 +1,6 @@ package com.realtime.protection.server.defense.template; +import com.realtime.protection.ProtectionApplicationTests; import com.realtime.protection.configuration.entity.defense.template.ProtectLevel; import com.realtime.protection.configuration.entity.defense.template.Template; import org.junit.jupiter.api.AfterEach; @@ -15,7 +16,7 @@ import java.util.List; import static org.junit.jupiter.api.Assertions.*; @SpringBootTest -class TemplateServiceTest { +class TemplateServiceTest extends ProtectionApplicationTests { private final TemplateService templateService; private Template template; diff --git a/src/test/java/com/realtime/protection/server/rule/dynamic/DynamicRuleServiceTest.java b/src/test/java/com/realtime/protection/server/rule/dynamic/DynamicRuleServiceTest.java index ab18154..c96bcf6 100644 --- a/src/test/java/com/realtime/protection/server/rule/dynamic/DynamicRuleServiceTest.java +++ b/src/test/java/com/realtime/protection/server/rule/dynamic/DynamicRuleServiceTest.java @@ -1,5 +1,6 @@ package com.realtime.protection.server.rule.dynamic; +import com.realtime.protection.ProtectionApplicationTests; import com.realtime.protection.configuration.entity.rule.dynamicrule.DynamicRuleObject; import com.realtime.protection.server.rule.dynamicrule.DynamicRuleService; import org.junit.jupiter.api.Test; @@ -11,7 +12,7 @@ import java.util.List; import static org.junit.jupiter.api.Assertions.assertTrue; @SpringBootTest -public class DynamicRuleServiceTest { +public class DynamicRuleServiceTest extends ProtectionApplicationTests { private final DynamicRuleService dynamicRuleService; @Autowired diff --git a/src/test/java/com/realtime/protection/server/rule/staticrule/StaticRuleServiceTest.java b/src/test/java/com/realtime/protection/server/rule/staticrule/StaticRuleServiceTest.java index fab1ad1..8aecf2c 100644 --- a/src/test/java/com/realtime/protection/server/rule/staticrule/StaticRuleServiceTest.java +++ b/src/test/java/com/realtime/protection/server/rule/staticrule/StaticRuleServiceTest.java @@ -1,5 +1,6 @@ package com.realtime.protection.server.rule.staticrule; +import com.realtime.protection.ProtectionApplicationTests; import com.realtime.protection.configuration.entity.rule.staticrule.StaticRuleObject; import org.junit.jupiter.api.BeforeEach; import org.junit.jupiter.api.Test; @@ -13,7 +14,7 @@ import java.util.List; import static org.junit.jupiter.api.Assertions.assertTrue; @SpringBootTest -public class StaticRuleServiceTest { +public class StaticRuleServiceTest extends ProtectionApplicationTests { private final StaticRuleService staticRuleService; private StaticRuleObject staticRuleTest; diff --git a/src/test/java/com/realtime/protection/server/task/TaskServiceTest.java b/src/test/java/com/realtime/protection/server/task/TaskServiceTest.java index e297ff8..810ebaf 100644 --- a/src/test/java/com/realtime/protection/server/task/TaskServiceTest.java +++ b/src/test/java/com/realtime/protection/server/task/TaskServiceTest.java @@ -1,5 +1,6 @@ package com.realtime.protection.server.task; +import com.realtime.protection.ProtectionApplicationTests; import com.realtime.protection.configuration.entity.task.Task; import com.realtime.protection.configuration.entity.task.TaskCommandInfo; import org.junit.jupiter.api.BeforeEach; @@ -14,7 +15,7 @@ import java.util.List; import static org.junit.jupiter.api.Assertions.*; @SpringBootTest -class TaskServiceTest { +class TaskServiceTest extends ProtectionApplicationTests { private final TaskService taskService; private Task task; @@ -35,9 +36,9 @@ class TaskServiceTest { task.setTaskEndTime(taskEndTime); task.setTaskAct("阻断"); task.setTaskType(1); - task.setStaticRuleIds(List.of(1L, 2L)); + task.setStaticRuleIds(List.of(1, 2)); task.setDynamicRuleIds(List.of()); - task.setTaskCreateUserId(1L); + task.setTaskCreateUserId(1); task.setTaskCreateUsername("xxx"); task.setTaskCreateDepart("xxx"); } @@ -76,7 +77,7 @@ class TaskServiceTest { void testUpdateTasks() { Task originalTask = taskService.queryTask(38L); - originalTask.setStaticRuleIds(List.of(16L, 17L, 18L, 19L)); + originalTask.setStaticRuleIds(List.of(16, 17, 18, 19)); originalTask.setTaskName("修改测试"); assertTrue(taskService.updateTask(originalTask)); diff --git a/src/test/java/com/realtime/protection/server/task/status/CommandServiceTest.java b/src/test/java/com/realtime/protection/server/task/status/CommandServiceTest.java index 5da9529..d13c123 100644 --- a/src/test/java/com/realtime/protection/server/task/status/CommandServiceTest.java +++ b/src/test/java/com/realtime/protection/server/task/status/CommandServiceTest.java @@ -1,5 +1,7 @@ package com.realtime.protection.server.task.status; +import com.alibaba.excel.util.ListUtils; +import com.realtime.protection.ProtectionApplicationTests; import com.realtime.protection.configuration.entity.task.FiveTupleWithMask; import com.realtime.protection.configuration.entity.task.TaskCommandInfo; import com.realtime.protection.server.command.CommandService; @@ -10,13 +12,12 @@ import org.springframework.beans.factory.annotation.Autowired; import org.springframework.boot.test.context.SpringBootTest; import java.time.LocalDateTime; -import java.util.ArrayList; import java.util.List; import static org.junit.jupiter.api.Assertions.assertDoesNotThrow; @SpringBootTest -class CommandServiceTest { +class CommandServiceTest extends ProtectionApplicationTests { private final CommandService commandService; private TaskCommandInfo taskCommandInfo; @@ -37,8 +38,10 @@ class CommandServiceTest { taskCommandInfo.setFrequency(30); taskCommandInfo.setTaskId(30L); taskCommandInfo.setFiveTupleWithMask(fiveTupleWithMask); - taskCommandInfo.setOperation("阻断"); - taskCommandInfo.setEndTime(LocalDateTime.now().plusDays(1)); + taskCommandInfo.setTaskAct("阻断"); + taskCommandInfo.setStartTime(LocalDateTime.now().plusDays(10)); + taskCommandInfo.setEndTime(LocalDateTime.now().plusDays(140)); + taskCommandInfo.setFrequency(30); startTime = System.currentTimeMillis(); } @@ -56,12 +59,13 @@ class CommandServiceTest { @Test void createCommands() { - List taskCommandInfos = new ArrayList<>(); + List taskCommandInfos = ListUtils.newArrayListWithExpectedSize(100); for (int i = 0; i < 100; i++) { int port = i + 1000; - taskCommandInfo = new TaskCommandInfo(); + TaskCommandInfo taskCommandInfo = new TaskCommandInfo(); taskCommandInfo.setFiveTupleWithMask(new FiveTupleWithMask()); - taskCommandInfo.setTaskId(24L); + taskCommandInfo.setTaskId(30L); + taskCommandInfo.setTaskAct("阻断"); taskCommandInfo.getFiveTupleWithMask().setSourcePort(Integer.toString(port)); taskCommandInfo.setStartTime(LocalDateTime.now().plusDays(5)); taskCommandInfo.setEndTime(LocalDateTime.now().plusDays(10)); @@ -70,6 +74,12 @@ class CommandServiceTest { taskCommandInfos.add(taskCommandInfo); } + for (TaskCommandInfo info : taskCommandInfos) { + if (info.getFrequency() == null) { + throw new IllegalArgumentException(); + } + } + assertDoesNotThrow(() -> commandService.createCommands(taskCommandInfos)); } diff --git a/src/test/java/com/realtime/protection/server/whitelist/WhiteListServiceTest.java b/src/test/java/com/realtime/protection/server/whitelist/WhiteListServiceTest.java index 16c6c95..607e1e5 100644 --- a/src/test/java/com/realtime/protection/server/whitelist/WhiteListServiceTest.java +++ b/src/test/java/com/realtime/protection/server/whitelist/WhiteListServiceTest.java @@ -1,7 +1,8 @@ package com.realtime.protection.server.whitelist; -import com.realtime.protection.configuration.entity.task.Command; +import com.realtime.protection.ProtectionApplicationTests; import com.realtime.protection.configuration.entity.task.FiveTupleWithMask; +import com.realtime.protection.configuration.entity.task.TaskCommandInfo; import com.realtime.protection.configuration.entity.whitelist.WhiteListObject; import org.junit.jupiter.api.BeforeEach; import org.junit.jupiter.api.Test; @@ -14,7 +15,7 @@ import java.util.List; import static org.junit.jupiter.api.Assertions.assertTrue; @SpringBootTest -class WhiteListServiceTest { +class WhiteListServiceTest extends ProtectionApplicationTests { private final WhiteListService whiteListService; private WhiteListObject whiteListObject; @@ -78,13 +79,13 @@ class WhiteListServiceTest { @Test void testWhiteListCommandJudge() { FiveTupleWithMask fiveTupleWithMask = new FiveTupleWithMask(); - Command command = new Command(); + TaskCommandInfo taskCommandInfo = new TaskCommandInfo(); fiveTupleWithMask.setDestinationIP("128.1.1.123"); fiveTupleWithMask.setMaskDestinationIP("255.255.255.0"); fiveTupleWithMask.setDestinationPort("80"); - command.setFiveTupleWithMask(fiveTupleWithMask); + taskCommandInfo.setFiveTupleWithMask(fiveTupleWithMask); - List whitelists = whiteListService.whiteListCommandJudge(command); + List whitelists = whiteListService.whiteListCommandJudge(taskCommandInfo); System.out.println(whitelists); }