1. 删除Command类,Doris数据库改用TaskCommandInfo类作为实体类

2. 取消FailedState和GeneratingState的使用
3. 修改部分bug
This commit is contained in:
EnderByEndera
2024-01-15 20:40:55 +08:00
parent ee10a17aea
commit 6cfe4bf5d3
34 changed files with 482 additions and 247 deletions

View File

@@ -29,7 +29,7 @@ public class ProtectObject {
private String protectObjectSystemName;
@JsonProperty("proobj_ip_address")
@Pattern(regexp = "^(\\d{1,3})\\.(\\d{1,3})\\.(\\d{1,3})\\.(\\d{1,3})$", message = "Invalid IPv4 Address")
@Pattern(regexp = "^(\\d{1,3})\\.(\\d{1,3})\\.(\\d{1,3})\\.(\\d{1,3})$", message = "无效IPv4地址")
@ExcelProperty("IP地址")
@Schema(description = "防护对象IPv4地址", example = "192.168.0.1")
private String protectObjectIPAddress;

View File

@@ -1,32 +0,0 @@
package com.realtime.protection.configuration.entity.task;
import com.realtime.protection.configuration.utils.enums.ProtocolEnum;
import lombok.Data;
import java.time.LocalDateTime;
@Data
public class Command {
private FiveTupleWithMask fiveTupleWithMask;
private Long taskId;
private String operation;
private LocalDateTime validTime;
private LocalDateTime invalidTime;
public static Command generateCommand(TaskCommandInfo info, LocalDateTime validTime) {
Command command = new Command();
FiveTupleWithMask fiveTupleWithMask = info.getFiveTupleWithMask();
if (fiveTupleWithMask.getProtocol() != null)
fiveTupleWithMask.setProtocolNum(ProtocolEnum.valueOf(fiveTupleWithMask.getProtocol()).getProtocolNumber());
command.setFiveTupleWithMask(fiveTupleWithMask);
command.setTaskId(info.getTaskId());
command.setOperation(info.getOperation());
command.setValidTime(validTime);
command.setInvalidTime(info.getEndTime());
return command;
}
}

View File

@@ -1,19 +1,58 @@
package com.realtime.protection.configuration.entity.task;
import io.swagger.v3.oas.annotations.media.Schema;
import jakarta.validation.constraints.Max;
import jakarta.validation.constraints.Min;
import jakarta.validation.constraints.Pattern;
import lombok.Data;
@Data
public class FiveTupleWithMask {
@Schema(description = "地址类型(IPv4 or IPv6)", example = "4")
private Integer addrType;
@Schema(description = "源IP", example = "192.168.104.14")
@Pattern(regexp = "^(\\d{1,3})\\.(\\d{1,3})\\.(\\d{1,3})\\.(\\d{1,3})$", message = "源IP无效IPv4地址")
private String sourceIP;
@Schema(description = "源端口", example = "114")
@Max(value = 65535, message = "源端口不可大于65535")
@Min(value = 1, message = "源端口不可小于1")
private String sourcePort;
@Schema(description = "目的IP", example = "102.165.11.39")
@Pattern(regexp = "^(\\d{1,3})\\.(\\d{1,3})\\.(\\d{1,3})\\.(\\d{1,3})$", message = "目的IP无效IPv4地址")
private String destinationIP;
@Schema(description = "目的端口", example = "514")
@Max(value = 65535, message = "目的端口不可大于65535")
@Min(value = 1, message = "目的端口不可小于1")
private String destinationPort;
@Schema(description = "协议名称", example = "TCP", accessMode = Schema.AccessMode.WRITE_ONLY)
private String protocol;
@Schema(description = "协议号", example = "6", accessMode = Schema.AccessMode.READ_ONLY)
private Integer protocolNum;
@Schema(description = "源IP掩码", example = "255.255.255.0")
@Pattern(regexp = "^(\\d{1,3})\\.(\\d{1,3})\\.(\\d{1,3})\\.(\\d{1,3})$", message = "源IP掩码无效IPv4地址")
private String maskSourceIP;
@Schema(description = "源端口掩码", example = "0")
@Max(value = 65535, message = "源端口掩码不可大于65535")
@Min(value = 1, message = "源端口掩码不可小于1")
private String maskSourcePort;
@Schema(description = "目的IP掩码", example = "255.255.0.0")
@Pattern(regexp = "^(\\d{1,3})\\.(\\d{1,3})\\.(\\d{1,3})\\.(\\d{1,3})$", message = "目的IP掩码无效IPv4地址")
private String maskDestinationIP;
@Schema(description = "目的端口掩码", example = "0")
@Max(value = 65535, message = "目的端口掩码不可大于65535")
@Min(value = 1, message = "目的端口掩码不可小于1")
private String maskDestinationPort;
@Schema(description = "协议掩码", example = "0")
private String maskProtocol;
}

View File

@@ -52,24 +52,24 @@ public class Task {
private String taskAct;
@JsonProperty("task_create_username")
@Schema(hidden = true)
@Schema(description = "任务创建人名称", accessMode = Schema.AccessMode.READ_ONLY)
private String taskCreateUsername;
@JsonProperty("task_create_depart")
@Schema(hidden = true)
@Schema(description = "任务创建人处室", accessMode = Schema.AccessMode.READ_ONLY)
private String taskCreateDepart;
@JsonProperty("task_create_userid")
@Schema(hidden = true)
private Long taskCreateUserId;
@Schema(description = "任务创建人ID", accessMode = Schema.AccessMode.READ_ONLY)
private Integer taskCreateUserId;
@JsonProperty("static_rule_ids")
@Schema(description = "静态规则ID列表动态和静态至少存在1个规则", example = "[10, 12]")
private List<Long> staticRuleIds;
private List<Integer> staticRuleIds;
@JsonProperty("dynamic_rule_ids")
@Schema(description = "动态规则ID列表动态和静态至少存在1个规则", example = "[20, 30]")
private List<Long> dynamicRuleIds;
private List<Integer> dynamicRuleIds;
@JsonProperty("task_status")
@Schema(description = "任务状态0为未启动1为生成中2为运行中3为暂停中4为已停止5为已结束6为失败", accessMode = Schema.AccessMode.READ_ONLY)

View File

@@ -1,19 +1,68 @@
package com.realtime.protection.configuration.entity.task;
import io.swagger.v3.oas.annotations.media.Schema;
import jakarta.validation.constraints.NotNull;
import lombok.Data;
import java.time.LocalDateTime;
@Data
public class TaskCommandInfo {
private FiveTupleWithMask fiveTupleWithMask;
@Schema(description = "指令UUID", accessMode = Schema.AccessMode.READ_ONLY)
private String UUID;
@Schema(description = "任务ID", accessMode = Schema.AccessMode.READ_ONLY)
private Long taskId;
@Schema(description = "规则ID", hidden = true)
private Long ruleId;
// 额外字段
private String operation;
@Schema(description = "任务创建人名称", accessMode = Schema.AccessMode.READ_ONLY)
private String taskCreateUsername;
@Schema(description = "任务创建人处室", accessMode = Schema.AccessMode.READ_ONLY)
private String taskCreateDepart;
@Schema(description = "任务创建人ID", accessMode = Schema.AccessMode.READ_ONLY)
private Integer taskCreateUserId;
@Schema(description = "任务名称", example = "API测试任务")
@NotNull(message = "任务名称不能为空")
private String taskName;
@Schema(description = "任务类型", example = "1")
@NotNull(message = "任务类型不能为空")
private Integer taskType;
@Schema(description = "任务操作", example = "阻断")
@NotNull(message = "任务操作不能为空。")
private String taskAct;
@Schema(description = "指令下发频率", example = "30")
@NotNull(message = "指令下发频率不能为空。")
private Integer frequency;
@Schema(description = "任务开始时间", example = "2025-10-14T10:23:33")
@NotNull(message = "任务开始时间不能为空。")
private LocalDateTime startTime;
@Schema(description = "任务结束时间", example = "2026-10-22T10:33:22")
@NotNull(message = "指令结束时间不能为空。")
private LocalDateTime endTime;
@Schema(description = "五元组信息")
@NotNull(message = "五元组信息不能为空。")
private FiveTupleWithMask fiveTupleWithMask;
@Schema(description = "指令下发次数", accessMode = Schema.AccessMode.READ_ONLY)
private Integer commandSentTimes;
@Schema(description = "指令成功次数", accessMode = Schema.AccessMode.READ_ONLY)
private Integer commandSuccessTimes;
@Schema(description = "首次下发时间", accessMode = Schema.AccessMode.READ_ONLY)
private LocalDateTime earliestSendTime;
@Schema(description = "最新下发时间", accessMode = Schema.AccessMode.READ_ONLY)
private LocalDateTime latestSendTime;
}

View File

@@ -15,6 +15,8 @@ import org.springframework.web.bind.annotation.ExceptionHandler;
import org.springframework.web.bind.annotation.RestControllerAdvice;
import org.springframework.web.method.annotation.HandlerMethodValidationException;
import java.sql.SQLException;
import java.sql.SQLIntegrityConstraintViolationException;
import java.util.stream.Collectors;
@RestControllerAdvice
@@ -36,7 +38,12 @@ public class GlobalExceptionHandler {
@Order(2)
@ExceptionHandler(value = {PersistenceException.class, DuplicateKeyException.class})
@ExceptionHandler(value = {
PersistenceException.class,
DuplicateKeyException.class,
SQLException.class,
SQLIntegrityConstraintViolationException.class
})
public ResponseResult handleSQLException(Exception e) {
log.info("遭遇数据库异常:" + e.getMessage());
return ResponseResult.invalid().setMessage(
@@ -89,7 +96,8 @@ public class GlobalExceptionHandler {
.setMessage("Doris数据库指令生成遭遇异常" + e.getMessage());
try {
stateChangeService.changeState(StateEnum.FAILED.getStateNum(), e.taskId);
// 内部修改状态,可以跳过一切状态检查
stateChangeService.changeState(StateEnum.FAILED.getStateNum(), e.taskId, true);
} catch (Exception another) {
responseResult.setAnother(ResponseResult.error().setMessage(e.getMessage()));
}

View File

@@ -1,5 +1,8 @@
package com.realtime.protection.configuration.utils.status;
import lombok.extern.slf4j.Slf4j;
@Slf4j
public class AuditStatusValidator {
private final Integer auditStatusOriginal;
@@ -12,8 +15,8 @@ public class AuditStatusValidator {
return new AuditStatusValidator(auditStatusOriginal);
}
public Boolean checkValidate(Integer auditStatusNow) {
switch (auditStatusNow) {
public Boolean checkValidate(Integer newAuditStatus) {
switch (newAuditStatus) {
case 0, 1 -> {
return auditStatusOriginal != 2;
}
@@ -21,6 +24,7 @@ public class AuditStatusValidator {
return auditStatusOriginal != 1;
}
default -> {
log.debug("欲修改的审核状态不正确,需要使用正确的审核状态,当前的审核状态:" + auditStatusOriginal);
return false;
}
}

View File

@@ -1,6 +1,6 @@
package com.realtime.protection.server.command;
import com.realtime.protection.configuration.entity.task.Command;
import com.realtime.protection.configuration.entity.task.TaskCommandInfo;
import org.apache.ibatis.annotations.Mapper;
import org.apache.ibatis.annotations.Param;
@@ -8,13 +8,15 @@ import java.util.List;
@Mapper
public interface CommandMapper {
Boolean createCommand(@Param("command") Command command);
Boolean createCommand(@Param("info") TaskCommandInfo taskCommandInfo);
void createCommands(@Param("commands") List<Command> commands);
void createCommands(@Param("command_infos") List<TaskCommandInfo> taskCommandInfos);
Boolean stopCommandsByTaskId(@Param("task_id") Long taskId);
Boolean removeCommandsByTaskId(@Param("task_id") Long taskId);
Boolean startCommandsByTaskId(@Param("task_id") Long taskId);
List<TaskCommandInfo> queryCommandInfoByTaskId(@Param("task_id") Long taskId);
}

View File

@@ -2,120 +2,69 @@ package com.realtime.protection.server.command;
import com.alibaba.excel.util.ListUtils;
import com.baomidou.dynamic.datasource.annotation.DS;
import com.realtime.protection.configuration.entity.task.Command;
import com.realtime.protection.configuration.entity.task.TaskCommandInfo;
import com.realtime.protection.configuration.exception.DorisStartException;
import com.realtime.protection.configuration.utils.SqlSessionWrapper;
import com.realtime.protection.configuration.utils.enums.StateEnum;
import com.realtime.protection.server.task.TaskService;
import lombok.extern.slf4j.Slf4j;
import org.springframework.scheduling.annotation.Async;
import org.springframework.stereotype.Service;
import java.time.LocalDateTime;
import java.util.List;
import java.util.function.Function;
@Service
@Slf4j
@DS("doris")
public class CommandService {
private final CommandMapper commandMapper;
private final TaskService taskService;
private final SqlSessionWrapper sqlSessionWrapper;
private static final int BatchSize = 100;
private final Function<CommandMapper, Function<TaskCommandInfo, Void>> createCommandBatchFunction;
public CommandService(CommandMapper commandMapper, TaskService taskService, SqlSessionWrapper sqlSessionWrapper) {
public CommandService(CommandMapper commandMapper, SqlSessionWrapper sqlSessionWrapper) {
this.commandMapper = commandMapper;
this.taskService = taskService;
this.sqlSessionWrapper = sqlSessionWrapper;
this.createCommandBatchFunction = mapper -> info -> {
if (info.getFrequency() == null) {
Command command = Command.generateCommand(info, info.getStartTime());
mapper.createCommand(command);
}
}
List<Command> commandBatch = ListUtils.newArrayListWithExpectedSize(BatchSize);
LocalDateTime validTime = info.getStartTime();
public Boolean createCommand(TaskCommandInfo commandInfo) {
return commandMapper.createCommand(commandInfo);
}
while (validTime.isBefore(info.getEndTime())) {
Command command = Command.generateCommand(info, validTime);
commandBatch.add(command);
validTime = validTime.plusMinutes(info.getFrequency());
if (commandBatch.size() < BatchSize) {
public void createCommands(List<TaskCommandInfo> taskCommandInfos) {
Function<CommandMapper, Function<List<TaskCommandInfo>, Boolean>> function = mapper -> list -> {
List<TaskCommandInfo> taskCommandInfoBatch = ListUtils.newArrayListWithExpectedSize(BatchSize);
for (TaskCommandInfo info : list) {
taskCommandInfoBatch.add(info);
if (taskCommandInfoBatch.size() < BatchSize) {
continue;
}
mapper.createCommands(commandBatch);
commandBatch.clear();
commandMapper.createCommands(taskCommandInfoBatch);
taskCommandInfoBatch.clear();
}
if (!commandBatch.isEmpty()) {
mapper.createCommands(commandBatch);
commandBatch.clear();
if (!taskCommandInfoBatch.isEmpty()) {
commandMapper.createCommands(taskCommandInfoBatch);
taskCommandInfoBatch.clear();
}
log.debug(String.format("在task(%d)和rule(%d)中构建了全部指令",
info.getTaskId(), info.getRuleId()));
return null;
};
}
@Async
@DS("doris")
public void createCommand(TaskCommandInfo commandInfo) throws DorisStartException {
try {
sqlSessionWrapper.startBatchSession(CommandMapper.class, createCommandBatchFunction, commandInfo);
taskService.changeTaskStatus(commandInfo.getTaskId(), StateEnum.RUNNING.getStateNum());
} catch (Exception e) {
throw new DorisStartException(e, commandInfo.getTaskId());
}
}
@Async
@DS("doris")
public void createCommands(List<TaskCommandInfo> taskCommandInfos) throws DorisStartException {
Function<CommandMapper, Function<List<TaskCommandInfo>, Void>> function = mapper -> list -> {
if (list == null || list.isEmpty()) {
return null;
}
for (TaskCommandInfo info : list) {
createCommandBatchFunction.apply(mapper).apply(info);
}
taskService.changeTaskStatus(list.get(0).getTaskId(), StateEnum.RUNNING.getStateNum());
return null;
return true;
};
try {
sqlSessionWrapper.startBatchSession(CommandMapper.class, function, taskCommandInfos);
} catch (Exception e) {
TaskCommandInfo info = null;
if (taskCommandInfos != null) {
info = taskCommandInfos.get(0);
}
Long taskId = null;
if (info != null) {
taskId = info.getTaskId();
}
throw new DorisStartException(e, taskId);
}
sqlSessionWrapper.startBatchSession(CommandMapper.class, function, taskCommandInfos);
}
public List<TaskCommandInfo> queryCommandInfoByTaskId(Long taskId) {
return commandMapper.queryCommandInfoByTaskId(taskId);
}
@DS("doris")
public Boolean startCommandsByTaskId(Long taskId) {
return commandMapper.startCommandsByTaskId(taskId);
}
@DS("doris")
public Boolean stopCommandsByTaskId(Long taskId) {
return commandMapper.stopCommandsByTaskId(taskId);
}
@DS("doris")
public Boolean removeCommandsByTaskId(Long taskId) {
return commandMapper.removeCommandsByTaskId(taskId);
}

View File

@@ -78,7 +78,6 @@ public class ProtectObjectService {
return false;
}
boolean success = true;
Integer result;
List<Integer> protectObjectBatch = ListUtils.newArrayListWithExpectedSize(batchSize);
for (Integer protectObjectId : list) {

View File

@@ -74,7 +74,7 @@ public interface DynamicRuleControllerApi {
@Parameter(name = "ids", description = "动态规则id列表")
}
)
public ResponseResult deleteDynamicRuleObjects(@PathVariable List<Integer> ids);
ResponseResult deleteDynamicRuleObjects(@PathVariable List<Integer> ids);
@Operation(
summary = "修改动态规则",
@@ -94,7 +94,7 @@ public interface DynamicRuleControllerApi {
requestBody = @io.swagger.v3.oas.annotations.parameters.RequestBody(
description = "动态规则信息")
)
public ResponseResult updateDynamicRuleObject(
ResponseResult updateDynamicRuleObject(
@PathVariable Integer id,
@RequestBody @Valid DynamicRuleObject dynamicRuleObject);
@@ -114,7 +114,7 @@ public interface DynamicRuleControllerApi {
@Parameter(name = "id", description = "动态规则ID", example = "2")
}
)
public ResponseResult queryDynamicRuleObjectById(@PathVariable Integer id);
ResponseResult queryDynamicRuleObjectById(@PathVariable Integer id);
@Operation(
summary = "根据条件查询多个动态规则",
@@ -135,7 +135,7 @@ public interface DynamicRuleControllerApi {
@Parameter(name = "page_size", description = "每页大小", example = "10")
}
)
public ResponseResult queryDynamicRuleObject(
ResponseResult queryDynamicRuleObject(
@RequestParam(value = "name", required = false) String dynamicRuleName,
@RequestParam(value = "id", required = false) Integer dynamicRuleId,
@RequestParam(value = "page", defaultValue = "1") Integer page,

View File

@@ -183,6 +183,6 @@ public interface StaticRuleControllerApi {
@Parameter(name = "auditStatus", description = "要修改为的静态规则审核状态")
}
)
public ResponseResult updateStaticRuleAuditStatus(@PathVariable Integer id, @PathVariable Integer auditStatus);
ResponseResult updateStaticRuleAuditStatus(@PathVariable Integer id, @PathVariable Integer auditStatus);
}

View File

@@ -1,9 +1,11 @@
package com.realtime.protection.server.task;
import com.realtime.protection.configuration.entity.task.Task;
import com.realtime.protection.configuration.entity.task.TaskCommandInfo;
import com.realtime.protection.configuration.exception.DorisStartException;
import com.realtime.protection.configuration.response.ResponseResult;
import com.realtime.protection.configuration.utils.EntityUtils;
import com.realtime.protection.server.command.CommandService;
import com.realtime.protection.server.task.status.StateChangeService;
import jakarta.validation.Valid;
import jakarta.validation.constraints.Max;
@@ -18,10 +20,12 @@ import java.util.List;
public class TaskController implements TaskControllerApi {
private final TaskService taskService;
private final CommandService commandService;
private final StateChangeService stateChangeService;
public TaskController(TaskService taskService, StateChangeService stateChangeService) {
public TaskController(TaskService taskService, CommandService commandService, StateChangeService stateChangeService) {
this.taskService = taskService;
this.commandService = commandService;
this.stateChangeService = stateChangeService;
}
@@ -43,6 +47,24 @@ public class TaskController implements TaskControllerApi {
.setData("success", false);
}
// API推送Endpoint
@Override
@PostMapping("/api/new")
public ResponseResult newTaskWithAPI(@RequestBody @Valid TaskCommandInfo taskCommandInfo) {
Long taskId = taskService.newTaskUsingCommandInfo(taskCommandInfo);
if (taskId <= 0) {
return ResponseResult.invalid()
.setData("taskId", -1)
.setData("success", false);
}
commandService.createCommand(taskCommandInfo);
return ResponseResult.ok()
.setData("taskId", taskId)
.setData("success", true);
}
@Override
@GetMapping("/query")
public ResponseResult queryTasks(@RequestParam(value = "task_status", required = false) Integer taskStatus,
@@ -62,7 +84,7 @@ public class TaskController implements TaskControllerApi {
Task task = taskService.queryTask(id);
if (task == null) {
return ResponseResult.invalid().setMessage("Task ID is invalid");
return ResponseResult.invalid().setMessage("无效Task ID也许该ID对应的任务不存在");
}
return ResponseResult.ok()
@@ -103,7 +125,16 @@ public class TaskController implements TaskControllerApi {
@PathVariable @NotNull Long taskId) throws DorisStartException {
return ResponseResult.ok()
.setData("task_id", taskId)
.setData("success", stateChangeService.changeState(stateNum, taskId))
// 外部修改状态,需要进行状态检查
.setData("success", stateChangeService.changeState(stateNum, taskId, false))
.setData("status_now", taskService.queryTaskStatus(taskId));
}
@Override
@GetMapping("/{taskId}/commands")
public ResponseResult queryCommandInfoByTaskId(@PathVariable Long taskId) {
return ResponseResult.ok()
.setData("success", true)
.setData("commands", commandService.queryCommandInfoByTaskId(taskId));
}
}

View File

@@ -1,6 +1,7 @@
package com.realtime.protection.server.task;
import com.realtime.protection.configuration.entity.task.Task;
import com.realtime.protection.configuration.entity.task.TaskCommandInfo;
import com.realtime.protection.configuration.exception.DorisStartException;
import com.realtime.protection.configuration.response.ResponseResult;
import io.swagger.v3.oas.annotations.Operation;
@@ -34,6 +35,24 @@ public interface TaskControllerApi {
)
ResponseResult newTask(@RequestBody @Valid Task task);
// API推送Endpoint
@PostMapping("/api/new")
@Operation(
summary = "任务推送外部API",
description = "提供给外部的任务推送API",
responses = {
@ApiResponse(
description = "返回外部任务推送结果",
content = @Content(
mediaType = "application/json",
schema = @Schema(implementation = ResponseResult.class)
)
)
},
requestBody = @io.swagger.v3.oas.annotations.parameters.RequestBody(description = "任务推送信息")
)
ResponseResult newTaskWithAPI(@RequestBody @Valid TaskCommandInfo taskCommandInfo) throws DorisStartException;
@GetMapping("/query")
@Operation(
summary = "查询任务",
@@ -162,4 +181,23 @@ public interface TaskControllerApi {
)
ResponseResult changeTaskStatus(@PathVariable @NotNull @Min(0) @Max(6) Integer stateNum,
@PathVariable @NotNull Long taskId) throws DorisStartException;
@GetMapping("/{taskId}/commands")
@Operation(
summary = "获得任务已推送指令的相关数据",
description = "获得任务已推送指令的相关数据,包括最新下发时间、首次下发时间、下发次数、下发成功次数等",
responses = {
@ApiResponse(
description = "返回任务已推送指令的相关数据",
content = @Content(
mediaType = "application/json",
schema = @Schema(implementation = ResponseResult.class)
)
)
},
parameters = {
@Parameter(name = "taskId", description = "任务ID")
}
)
ResponseResult queryCommandInfoByTaskId(@PathVariable Long taskId);
}

View File

@@ -12,10 +12,12 @@ public interface TaskMapper {
void newTask(@Param("task") Task task);
void newTaskStaticRuleConcat(@Param("task_id") Long taskId,
@Param("rule_ids") List<Long> staticRuleIds);
@Param("rule_ids") List<Integer> staticRuleIds);
void newTaskDynamicRuleConcat(@Param("task_id") Long taskId,
@Param("rule_ids") List<Long> dynamicRuleIds);
@Param("rule_ids") List<Integer> dynamicRuleIds);
void newTaskUsingCommandInfo(@Param("info") TaskCommandInfo taskCommandInfo);
List<Task> queryTasks(@Param("task_status") Integer taskStatus, @Param("task_type") String task_type,
@Param("task_name") String taskName, @Param("task_creator") String taskCreator,
@@ -35,9 +37,13 @@ public interface TaskMapper {
Boolean changeTaskStatus(@Param("task_id") Long taskId, @Param("state") Integer stateNum);
List<TaskCommandInfo> getStaticCommands(@Param("task_id") Long taskId);
List<TaskCommandInfo> getStaticCommandInfos(@Param("task_id") Long taskId);
Integer queryTaskAuditStatus(@Param("task_id") Long taskId);
Integer queryTaskStatus(@Param("task_id") Long taskId);
List<Integer> queryDynamicRuleIdsFromTaskId(@Param("task_id") Long taskId);
List<Integer> queryStaticRuleIdsFromTaskId(@Param("task_id") Long taskId);
}

View File

@@ -10,6 +10,7 @@ import org.springframework.transaction.annotation.Transactional;
import java.util.List;
@Service
@DS("mysql")
public class TaskService {
private final TaskMapper taskMapper;
@@ -19,6 +20,10 @@ public class TaskService {
@Transactional
public Long newTask(Task task) {
task.setTaskCreateUserId(1);
task.setTaskCreateUsername("xxx");
task.setTaskCreateDepart("xxx");
taskMapper.newTask(task);
if (task.getStaticRuleIds() != null && !task.getStaticRuleIds().isEmpty())
@@ -30,14 +35,33 @@ public class TaskService {
return task.getTaskId();
}
@Transactional
public List<Task> queryTasks(Integer taskStatus,
String taskType, String taskName, String taskCreator,
Integer page, Integer pageSize) {
return taskMapper.queryTasks(taskStatus, taskType, taskName, taskCreator, page, pageSize);
List<Task> tasks = taskMapper.queryTasks(taskStatus, taskType, taskName, taskCreator, page, pageSize);
for (Task task : tasks) {
if (task == null) {
continue;
}
task.setStaticRuleIds(taskMapper.queryStaticRuleIdsFromTaskId(task.getTaskId()));
task.setDynamicRuleIds(taskMapper.queryDynamicRuleIdsFromTaskId(task.getTaskId()));
}
return tasks;
}
@Transactional
public Task queryTask(Long id) {
return taskMapper.queryTask(id);
Task task = taskMapper.queryTask(id);
if (task == null) {
return null;
}
task.setStaticRuleIds(taskMapper.queryStaticRuleIdsFromTaskId(task.getTaskId()));
task.setDynamicRuleIds(taskMapper.queryDynamicRuleIdsFromTaskId(task.getTaskId()));
return task;
}
@Transactional
@@ -74,13 +98,13 @@ public class TaskService {
return taskMapper.deleteTask(taskId);
}
@DS("mysql")
public Boolean changeTaskStatus(Long taskId, Integer stateNum) {
return taskMapper.changeTaskStatus(taskId, stateNum);
}
public List<TaskCommandInfo> getStaticCommandInfos(Long taskId) {
return taskMapper.getStaticCommands(taskId);
return taskMapper.getStaticCommandInfos(taskId);
}
public Integer queryTaskAuditStatus(Long taskId) {
@@ -90,4 +114,9 @@ public class TaskService {
public Integer queryTaskStatus(Long taskId) {
return taskMapper.queryTaskStatus(taskId);
}
public Long newTaskUsingCommandInfo(TaskCommandInfo taskCommandInfo) {
taskMapper.newTaskUsingCommandInfo(taskCommandInfo);
return taskCommandInfo.getTaskId();
}
}

View File

@@ -23,12 +23,7 @@ public class StateChangeService {
}
@DSTransactional
public Boolean changeState(Integer stateNum, Long taskId) throws DorisStartException {
if (Objects.equals(stateNum, StateEnum.GENERATING.getStateNum()) ||
Objects.equals(stateNum, StateEnum.FAILED.getStateNum())) {
throw new IllegalArgumentException("非法任务状态:" + StateEnum.getStateByNum(stateNum));
}
public Boolean changeState(Integer stateNum, Long taskId, Boolean inner) throws DorisStartException {
Integer originalStateNum = taskService.queryTaskStatus(taskId);
if (originalStateNum == null) {
throw new IllegalArgumentException("无法找到" + taskId + "的任务状态也许任务ID不存在?");
@@ -38,8 +33,15 @@ public class StateChangeService {
State newState = StateEnum.getStateByNum(stateNum);
if (newState == null) {
return false;
if (!inner && !checkState(originalState, newState)) {
throw new IllegalArgumentException(
String.format("任务状态转换失败,原状态:%s欲切换状态%s",
originalState.getClass().getSimpleName(),
newState.getClass().getSimpleName()));
}
if (Objects.equals(originalState, newState)) {
return true;
}
if (!originalState.handle(newState, commandService, taskService, taskId)) {
@@ -54,4 +56,21 @@ public class StateChangeService {
// 这里一定是handle成功的状态我们再进行task status的修改如果handle失败要么返回false要么抛出异常不会进入此处
return taskService.changeTaskStatus(taskId, stateNum);
}
private Boolean checkState(State originalState, State newState) {
if (originalState == null || newState == null) {
return false;
}
// FAILED、FINISHED状态以及GENERATING都只能在程序内部修改外部接口不能修改
if (Objects.equals(newState, StateEnum.FAILED.getState())
|| Objects.equals(newState, StateEnum.FINISHED.getState())
|| Objects.equals(newState, StateEnum.GENERATING.getState())) {
return false;
}
// 在任务状态转换为GENERATING之后我们需要在外部接口屏蔽掉所有状态
// 我们需要保证只有任务创建函数才能将GENERATING状态转换为RUNNING状态
return !Objects.equals(originalState, StateEnum.GENERATING.getState());
}
}

View File

@@ -13,7 +13,6 @@ public class FailedState extends StateHandler implements State {
return switch (StateEnum.getStateEnumByState(newState)) {
case RUNNING -> handleStart(taskService, commandService, taskId);
case STOP -> handleStop(commandService, taskId);
case FAILED -> true;
default -> throw new IllegalStateException("Unexpected value: " + StateEnum.getStateEnumByState(newState));
};
}

View File

@@ -1,6 +1,5 @@
package com.realtime.protection.server.task.status.states;
import com.realtime.protection.configuration.exception.DorisStartException;
import com.realtime.protection.configuration.utils.enums.StateEnum;
import com.realtime.protection.configuration.utils.status.State;
import com.realtime.protection.server.command.CommandService;
@@ -9,9 +8,9 @@ import com.realtime.protection.server.task.status.StateHandler;
public class GeneratingState extends StateHandler implements State {
@Override
public Boolean handle(State newState, CommandService commandService, TaskService taskService, Long taskId) throws DorisStartException {
public Boolean handle(State newState, CommandService commandService, TaskService taskService, Long taskId) {
return switch (StateEnum.getStateEnumByState(newState)) {
case RUNNING, GENERATING -> true;
case RUNNING -> true;
case FAILED -> handleFailed(commandService, taskId);
default -> throw new IllegalStateException("Unexpected value: " + StateEnum.getStateEnumByState(newState));
};

View File

@@ -11,8 +11,8 @@ public class PendingState extends StateHandler implements State {
@Override
public Boolean handle(State newState, CommandService commandService, TaskService taskService, Long taskId) throws DorisStartException {
return switch (StateEnum.getStateEnumByState(newState)) {
case GENERATING -> handleStart(taskService, commandService, taskId);
case FAILED -> handleFailed(commandService, taskId);
case RUNNING -> handleStart(taskService, commandService, taskId);
default -> throw new IllegalStateException("Unexpected value: " + StateEnum.getStateEnumByState(newState));
};
}

View File

@@ -10,7 +10,6 @@ public class RunningState extends StateHandler implements State {
@Override
public Boolean handle(State newState, CommandService commandService, TaskService taskService, Long taskId) {
return switch (StateEnum.getStateEnumByState(newState)) {
case RUNNING, GENERATING -> true;
case PAUSED -> handlePause(commandService, taskId);
case STOP -> handleStop(commandService, taskId);
case FINISHED -> handleFinish(commandService, taskId);

View File

@@ -2,7 +2,7 @@ package com.realtime.protection.server.whitelist;
import com.alibaba.excel.util.ListUtils;
import com.realtime.protection.configuration.entity.rule.staticrule.StaticRuleObject;
import com.realtime.protection.configuration.entity.task.Command;
import com.realtime.protection.configuration.entity.task.TaskCommandInfo;
import com.realtime.protection.configuration.entity.whitelist.WhiteListObject;
import com.realtime.protection.configuration.utils.SqlSessionWrapper;
import com.realtime.protection.configuration.utils.status.AuditStatusValidator;
@@ -129,10 +129,10 @@ public class WhiteListService {
}
public List<WhiteListObject> whiteListCommandJudge(Command command) {
public List<WhiteListObject> whiteListCommandJudge(TaskCommandInfo taskCommandInfo) {
//参数应该是指令,不管动态静态
// 命中的whitelist列表每一列包含ip port url
return whiteListMapper.whiteListCommandJudge(command.getFiveTupleWithMask());
return whiteListMapper.whiteListCommandJudge(taskCommandInfo.getFiveTupleWithMask());
}

View File

@@ -48,6 +48,8 @@ task:
springdoc:
api-docs:
enabled: false
enabled: true
path: /api-docs
swagger-ui:
enabled: false
path: /swagger
packages-to-scan: com.realtime.protection.server

View File

@@ -3,39 +3,79 @@
PUBLIC "-//mybatis.org//DTD Mapper 3.0//EN"
"http://mybatis.org/dtd/mybatis-3-mapper.dtd">
<mapper namespace="com.realtime.protection.server.command.CommandMapper">
<insert id="createCommand" parameterType="com.realtime.protection.configuration.entity.task.Command">
insert into t_command(COMMAND_ID, ADDR_TYPE, SRC_IP, SRC_PORT, DST_IP, DST_PORT, PROTOCOL, MASK_SRC_IP,
MASK_SRC_PORT,
MASK_DST_IP, MASK_DST_PORT, IS_VALID, VALID_TIME, INVALID_TIME, IS_SENT,
CREATE_TIME, LAST_UPDATE, IS_DELETED, TASK_ID)
values (UUID(), DEFAULT,
#{command.fiveTupleWithMask.sourceIP},
#{command.fiveTupleWithMask.sourcePort}, #{command.fiveTupleWithMask.destinationIP},
#{command.fiveTupleWithMask.destinationPort},
#{command.fiveTupleWithMask.protocolNum}, #{command.fiveTupleWithMask.maskSourceIP},
#{command.fiveTupleWithMask.maskSourcePort},
#{command.fiveTupleWithMask.maskDestinationIP}, #{command.fiveTupleWithMask.maskDestinationPort}, TRUE,
#{command.validTime}, #{command.invalidTime}, FALSE, NOW(), NOW(), FALSE, #{command.taskId})
<insert id="createCommand" parameterType="com.realtime.protection.configuration.entity.task.TaskCommandInfo">
insert into t_command(COMMAND_ID, TASK_ID, TASK_ACT, FREQUENCY, ADDR_TYPE, SRC_IP, SRC_PORT, DST_IP, DST_PORT,
PROTOCOL,
MASK_SRC_IP, MASK_SRC_PORT, MASK_DST_IP, MASK_DST_PORT, MASK_PROTOCOL, VALID_TIME,
INVALID_TIME, IS_VALID,
SEND_TIMES, SUCCESS_TIMES, CREATE_TIME, LAST_UPDATE, IS_DELETED)
values (UUID(), #{info.taskId}, #{info.taskAct}, #{info.frequency},
#{info.fiveTupleWithMask.addrType},
#{info.fiveTupleWithMask.sourceIP}, #{info.fiveTupleWithMask.sourcePort},
#{info.fiveTupleWithMask.destinationIP}, #{info.fiveTupleWithMask.destinationPort},
#{info.fiveTupleWithMask.protocolNum},
#{info.fiveTupleWithMask.maskSourceIP}, #{info.fiveTupleWithMask.maskSourcePort},
#{info.fiveTupleWithMask.maskDestinationIP}, #{info.fiveTupleWithMask.maskDestinationPort},
#{info.fiveTupleWithMask.maskProtocol},
#{info.startTime}, #{info.endTime}, TRUE, 0, 0,
NOW(), NOW(), FALSE)
</insert>
<insert id="createCommands" parameterType="com.realtime.protection.configuration.entity.task.Command">
insert into t_command(COMMAND_ID, ADDR_TYPE, SRC_IP, SRC_PORT, DST_IP, DST_PORT, PROTOCOL, MASK_SRC_IP,
MASK_SRC_PORT,
MASK_DST_IP, MASK_DST_PORT, IS_VALID, VALID_TIME, INVALID_TIME, IS_SENT,
CREATE_TIME, LAST_UPDATE, IS_DELETED, TASK_ID)
<insert id="createCommands" parameterType="com.realtime.protection.configuration.entity.task.TaskCommandInfo">
insert into t_command(COMMAND_ID, TASK_ID, TASK_ACT, FREQUENCY, ADDR_TYPE, SRC_IP, SRC_PORT, DST_IP, DST_PORT,
PROTOCOL,
MASK_SRC_IP, MASK_SRC_PORT, MASK_DST_IP, MASK_DST_PORT, MASK_PROTOCOL, VALID_TIME, INVALID_TIME, IS_VALID,
SEND_TIMES, SUCCESS_TIMES, CREATE_TIME, LAST_UPDATE, IS_DELETED)
values
<foreach collection="commands" item="command" separator=",">
(UUID(), DEFAULT,
#{command.fiveTupleWithMask.sourceIP},
#{command.fiveTupleWithMask.sourcePort}, #{command.fiveTupleWithMask.destinationIP},
#{command.fiveTupleWithMask.destinationPort},
#{command.fiveTupleWithMask.protocolNum}, #{command.fiveTupleWithMask.maskSourceIP},
#{command.fiveTupleWithMask.maskSourcePort},
#{command.fiveTupleWithMask.maskDestinationIP}, #{command.fiveTupleWithMask.maskDestinationPort}, TRUE,
#{command.validTime}, #{command.invalidTime}, FALSE, NOW(), NOW(), FALSE, #{command.taskId})
<foreach collection="command_infos" item="info" separator=",">
(UUID(), #{info.taskId}, #{info.taskAct}, #{info.frequency},
#{info.fiveTupleWithMask.addrType},
#{info.fiveTupleWithMask.sourceIP}, #{info.fiveTupleWithMask.sourcePort},
#{info.fiveTupleWithMask.destinationIP}, #{info.fiveTupleWithMask.destinationPort},
#{info.fiveTupleWithMask.protocolNum},
#{info.fiveTupleWithMask.maskSourceIP}, #{info.fiveTupleWithMask.maskSourcePort},
#{info.fiveTupleWithMask.maskDestinationIP}, #{info.fiveTupleWithMask.maskDestinationPort},
#{info.fiveTupleWithMask.maskProtocol},
#{info.startTime}, #{info.endTime}, TRUE, 0, 0,
NOW(), NOW(), FALSE
)
</foreach>
</insert>
<resultMap id="commandStatMap" type="com.realtime.protection.configuration.entity.task.TaskCommandInfo">
<id column="COMMAND_ID" property="UUID"/>
<result column="TASK_ACT" property="taskAct"/>
<result column="SEND_TIMES" property="commandSentTimes"/>
<result column="SUCCESS_TIMES" property="commandSuccessTimes"/>
<result column="FIRST_SEND_TIME" property="earliestSendTime"/>
<result column="LAST_SEND_TIME" property="latestSendTime"/>
<association property="fiveTupleWithMask">
<result column="SRC_IP" property="sourceIP"/>
<result column="SRC_PORT" property="sourcePort"/>
<result column="DST_IP" property="destinationIP"/>
<result column="DST_PORT" property="destinationPort"/>
<result column="PROTOCOL" property="protocolNum"/>
</association>
</resultMap>
<select id="queryCommandInfoByTaskId" resultMap="commandStatMap">
SELECT COMMAND_ID,
TASK_ACT,
SEND_TIMES,
SUCCESS_TIMES,
FIRST_SEND_TIME,
LASt_SEND_TIME,
SRC_IP,
SRC_PORT,
DST_IP,
DST_PORT,
PROTOCOL
FROM t_command
WHERE TASK_ID = #{task_id}
AND IS_DELETED = FALSE
</select>
<update id="stopCommandsByTaskId">
UPDATE t_command
SET IS_VALID = FALSE,
@@ -57,5 +97,6 @@
SET IS_DELETED = TRUE,
LAST_UPDATE = NOW()
WHERE TASK_ID = #{task_id}
AND IS_DELETED = FALSE
</update>
</mapper>

View File

@@ -46,7 +46,10 @@
</resultMap>
<select id="queryProtectObjects" resultMap="protectObjectMap">
SELECT * FROM t_protect_object
SELECT
protect_object_id, protect_object_name, protect_object_system_name, INET_NTOA(protect_object_ip),
protect_object_port, protect_object_url, protect_object_protocol, protect_object_audit_status
FROM t_protect_object
<where>
<if test="proobj_name != null">protect_object_name LIKE CONCAT('%', #{proobj_name}, '%')</if>
<if test="proobj_id != null">protect_object_id = #{proobj_id}</if>
@@ -55,7 +58,14 @@
</select>
<select id="queryProtectObject" resultMap="protectObjectMap">
SELECT *
SELECT protect_object_id,
protect_object_name,
protect_object_system_name,
INET_NTOA(protect_object_ip),
protect_object_port,
protect_object_url,
protect_object_protocol,
protect_object_audit_status
FROM t_protect_object
WHERE protect_object_id = #{proobj_id}
</select>

View File

@@ -18,27 +18,25 @@
<update id="newTaskStaticRuleConcat">
UPDATE t_static_rule
SET static_rule_used_task_id = #{task_id}
<where>
<if test="rule_ids != null and rule_ids.size() > 0">
AND static_rule_id IN
<foreach collection="rule_ids" item="rule_id" open="(" close=")" separator=",">
#{rule_id}
</foreach>
</if>
</where>
WHERE
<if test="rule_ids != null and rule_ids.size() > 0">
static_rule_id IN
<foreach collection="rule_ids" item="rule_id" open="(" close=")" separator=",">
#{rule_id}
</foreach>
</if>
</update>
<update id="newTaskDynamicRuleConcat">
UPDATE t_dynamic_rule
SET dynamic_rule_used_task_id = #{task_id}
<where>
<if test="rule_ids != null and rule_ids.size() > 0">
AND dynamic_rule_id IN
<foreach collection="rule_ids" item="rule_id" open="(" close=")" separator=",">
#{rule_id}
</foreach>
</if>
</where>
WHERE
<if test="rule_ids != null and rule_ids.size() > 0">
dynamic_rule_id IN
<foreach collection="rule_ids" item="rule_id" open="(" close=")" separator=",">
#{rule_id}
</foreach>
</if>
</update>
<resultMap id="taskMap" type="com.realtime.protection.configuration.entity.task.Task">
@@ -55,19 +53,10 @@
<result column="task_create_username" property="taskCreateUsername"/>
<result column="task_create_depart" property="taskCreateDepart"/>
<collection property="staticRuleIds" ofType="java.lang.Integer">
<id column="static_rule_id"/>
</collection>
<collection property="dynamicRuleIds" ofType="java.lang.Integer">
<id column="dynamic_rule_id"/>
</collection>
</resultMap>
<select id="queryTasks" resultMap="taskMap">
SELECT * FROM t_task
LEFT JOIN realtime_protection.t_static_rule tsr on t_task.task_id = tsr.static_rule_used_task_id
LEFT JOIN realtime_protection.t_dynamic_rule tdr on t_task.task_id = tdr.dynamic_rule_used_task_id
<where>
<if test="task_status != null">
AND task_status = #{task_status}
@@ -79,17 +68,27 @@
AND task_name LIKE CONCAT('%', #{task_name}, '%')
</if>
<if test="task_creator != null">
AND task_create_username = #{task_creator}
AND task_create_username LIKE CONCAT('%', #{task_creator}, '%')
</if>
</where>
LIMIT ${(page - 1) * page_size}, #{page_size}
</select>
<select id="queryStaticRuleIdsFromTaskId" resultType="java.lang.Integer">
SELECT static_rule_id
FROM t_static_rule
WHERE static_rule_used_task_id = #{task_id}
</select>
<select id="queryDynamicRuleIdsFromTaskId" resultType="java.lang.Integer">
SELECT dynamic_rule_id
FROM t_dynamic_rule
WHERE dynamic_rule_used_task_id = #{task_id}
</select>
<select id="queryTask" resultMap="taskMap">
SELECT *
FROM t_task
LEFT JOIN realtime_protection.t_static_rule tsr on t_task.task_id = tsr.static_rule_used_task_id
LEFT JOIN realtime_protection.t_dynamic_rule tdr on t_task.task_id = tdr.dynamic_rule_used_task_id
WHERE t_task.task_id = #{task_id}
</select>
@@ -152,9 +151,16 @@
</delete>
<resultMap id="staticCommandMap" type="com.realtime.protection.configuration.entity.task.TaskCommandInfo">
<result column="task_act" property="operation"/>
<result column="task_name" property="taskName"/>
<result column="task_create_username" property="taskCreateUsername"/>
<result column="task_create_depart" property="taskCreateDepart"/>
<result column="task_create_userid" property="taskCreateUserId"/>
<result column="task_id" property="taskId"/>
<result column="static_rule_id" property="ruleId"/>
<result column="task_act" property="taskAct"/>
<result column="task_type" property="taskType"/>
<result column="static_rule_frequency" property="frequency"/>
<result column="task_start_time" property="startTime"/>
<result column="task_end_time" property="endTime"/>
@@ -170,15 +176,38 @@
<result column="static_rule_msport" property="maskSourcePort"/>
<result column="static_rule_mdip" property="maskDestinationIP"/>
<result column="static_rule_mdport" property="maskDestinationPort"/>
<result column="static_rule_mprotocol" property="maskProtocol"/>
</association>
</resultMap>
<select id="getStaticCommands" resultMap="staticCommandMap">
SELECT t_task.task_id,
<insert id="newTaskUsingCommandInfo" useGeneratedKeys="true" keyProperty="taskId"
parameterType="com.realtime.protection.configuration.entity.task.TaskCommandInfo">
INSERT INTO t_task(task_name, task_start_time, task_end_time, task_create_time, task_modify_time, task_type,
task_act, task_create_username, task_create_depart, task_create_userid)
VALUE
(
#{info.taskName}, #{info.startTime}, #{info.endTime}, NOW(), NOW(), #{info.taskType},
#{info.taskAct}, #{info.taskCreateUsername}, #{info.taskCreateDepart}, #{info.taskCreateUserId}
);
</insert>
<select id="getStaticCommandInfos" resultMap="staticCommandMap">
SELECT t_task.task_name,
t_task.task_id,
tsr.static_rule_id,
t_task.task_create_username,
t_task.task_create_depart,
t_task.task_create_userid,
t_task.task_type,
t_task.task_act,
tsr.static_rule_frequency,
t_task.task_start_time,
t_task.task_end_time,
INET_NTOA(tsr.static_rule_sip) as static_rule_sip,
tsr.static_rule_sport,
INET_NTOA(tsr.static_rule_dip) as static_rule_dip,
@@ -188,7 +217,7 @@
tsr.static_rule_msport,
INET_NTOA(tsr.static_rule_mdip) as static_rule_mdip,
tsr.static_rule_mdport,
tsr.static_rule_frequency
tsr.static_rule_mprotocol
FROM t_task
LEFT JOIN realtime_protection.t_static_rule tsr on t_task.task_id = tsr.static_rule_used_task_id
WHERE task_id = #{task_id}

View File

@@ -1,13 +1,12 @@
package com.realtime.protection;
import org.junit.jupiter.api.Test;
import com.baomidou.dynamic.datasource.annotation.DSTransactional;
import org.springframework.boot.test.context.SpringBootTest;
import org.springframework.test.annotation.Rollback;
@SpringBootTest
class ProtectionApplicationTests {
@Test
void contextLoads() {
}
@Rollback
@DSTransactional
public class ProtectionApplicationTests {
}

View File

@@ -1,5 +1,6 @@
package com.realtime.protection.server.defense.object;
import com.realtime.protection.ProtectionApplicationTests;
import com.realtime.protection.configuration.entity.defense.object.ProtectObject;
import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.api.Test;
@@ -13,7 +14,7 @@ import java.util.List;
import static org.junit.jupiter.api.Assertions.*;
@SpringBootTest
class ProtectObjectServiceTest {
class ProtectObjectServiceTest extends ProtectionApplicationTests {
private final ProtectObjectService protectObjectService;
private ProtectObject protectObject;

View File

@@ -1,5 +1,6 @@
package com.realtime.protection.server.defense.template;
import com.realtime.protection.ProtectionApplicationTests;
import com.realtime.protection.configuration.entity.defense.template.ProtectLevel;
import com.realtime.protection.configuration.entity.defense.template.Template;
import org.junit.jupiter.api.AfterEach;
@@ -15,7 +16,7 @@ import java.util.List;
import static org.junit.jupiter.api.Assertions.*;
@SpringBootTest
class TemplateServiceTest {
class TemplateServiceTest extends ProtectionApplicationTests {
private final TemplateService templateService;
private Template template;

View File

@@ -1,5 +1,6 @@
package com.realtime.protection.server.rule.dynamic;
import com.realtime.protection.ProtectionApplicationTests;
import com.realtime.protection.configuration.entity.rule.dynamicrule.DynamicRuleObject;
import com.realtime.protection.server.rule.dynamicrule.DynamicRuleService;
import org.junit.jupiter.api.Test;
@@ -11,7 +12,7 @@ import java.util.List;
import static org.junit.jupiter.api.Assertions.assertTrue;
@SpringBootTest
public class DynamicRuleServiceTest {
public class DynamicRuleServiceTest extends ProtectionApplicationTests {
private final DynamicRuleService dynamicRuleService;
@Autowired

View File

@@ -1,5 +1,6 @@
package com.realtime.protection.server.rule.staticrule;
import com.realtime.protection.ProtectionApplicationTests;
import com.realtime.protection.configuration.entity.rule.staticrule.StaticRuleObject;
import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.api.Test;
@@ -13,7 +14,7 @@ import java.util.List;
import static org.junit.jupiter.api.Assertions.assertTrue;
@SpringBootTest
public class StaticRuleServiceTest {
public class StaticRuleServiceTest extends ProtectionApplicationTests {
private final StaticRuleService staticRuleService;
private StaticRuleObject staticRuleTest;

View File

@@ -1,5 +1,6 @@
package com.realtime.protection.server.task;
import com.realtime.protection.ProtectionApplicationTests;
import com.realtime.protection.configuration.entity.task.Task;
import com.realtime.protection.configuration.entity.task.TaskCommandInfo;
import org.junit.jupiter.api.BeforeEach;
@@ -14,7 +15,7 @@ import java.util.List;
import static org.junit.jupiter.api.Assertions.*;
@SpringBootTest
class TaskServiceTest {
class TaskServiceTest extends ProtectionApplicationTests {
private final TaskService taskService;
private Task task;
@@ -35,9 +36,9 @@ class TaskServiceTest {
task.setTaskEndTime(taskEndTime);
task.setTaskAct("阻断");
task.setTaskType(1);
task.setStaticRuleIds(List.of(1L, 2L));
task.setStaticRuleIds(List.of(1, 2));
task.setDynamicRuleIds(List.of());
task.setTaskCreateUserId(1L);
task.setTaskCreateUserId(1);
task.setTaskCreateUsername("xxx");
task.setTaskCreateDepart("xxx");
}
@@ -76,7 +77,7 @@ class TaskServiceTest {
void testUpdateTasks() {
Task originalTask = taskService.queryTask(38L);
originalTask.setStaticRuleIds(List.of(16L, 17L, 18L, 19L));
originalTask.setStaticRuleIds(List.of(16, 17, 18, 19));
originalTask.setTaskName("修改测试");
assertTrue(taskService.updateTask(originalTask));

View File

@@ -1,5 +1,7 @@
package com.realtime.protection.server.task.status;
import com.alibaba.excel.util.ListUtils;
import com.realtime.protection.ProtectionApplicationTests;
import com.realtime.protection.configuration.entity.task.FiveTupleWithMask;
import com.realtime.protection.configuration.entity.task.TaskCommandInfo;
import com.realtime.protection.server.command.CommandService;
@@ -10,13 +12,12 @@ import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.boot.test.context.SpringBootTest;
import java.time.LocalDateTime;
import java.util.ArrayList;
import java.util.List;
import static org.junit.jupiter.api.Assertions.assertDoesNotThrow;
@SpringBootTest
class CommandServiceTest {
class CommandServiceTest extends ProtectionApplicationTests {
private final CommandService commandService;
private TaskCommandInfo taskCommandInfo;
@@ -37,8 +38,10 @@ class CommandServiceTest {
taskCommandInfo.setFrequency(30);
taskCommandInfo.setTaskId(30L);
taskCommandInfo.setFiveTupleWithMask(fiveTupleWithMask);
taskCommandInfo.setOperation("阻断");
taskCommandInfo.setEndTime(LocalDateTime.now().plusDays(1));
taskCommandInfo.setTaskAct("阻断");
taskCommandInfo.setStartTime(LocalDateTime.now().plusDays(10));
taskCommandInfo.setEndTime(LocalDateTime.now().plusDays(140));
taskCommandInfo.setFrequency(30);
startTime = System.currentTimeMillis();
}
@@ -56,12 +59,13 @@ class CommandServiceTest {
@Test
void createCommands() {
List<TaskCommandInfo> taskCommandInfos = new ArrayList<>();
List<TaskCommandInfo> taskCommandInfos = ListUtils.newArrayListWithExpectedSize(100);
for (int i = 0; i < 100; i++) {
int port = i + 1000;
taskCommandInfo = new TaskCommandInfo();
TaskCommandInfo taskCommandInfo = new TaskCommandInfo();
taskCommandInfo.setFiveTupleWithMask(new FiveTupleWithMask());
taskCommandInfo.setTaskId(24L);
taskCommandInfo.setTaskId(30L);
taskCommandInfo.setTaskAct("阻断");
taskCommandInfo.getFiveTupleWithMask().setSourcePort(Integer.toString(port));
taskCommandInfo.setStartTime(LocalDateTime.now().plusDays(5));
taskCommandInfo.setEndTime(LocalDateTime.now().plusDays(10));
@@ -70,6 +74,12 @@ class CommandServiceTest {
taskCommandInfos.add(taskCommandInfo);
}
for (TaskCommandInfo info : taskCommandInfos) {
if (info.getFrequency() == null) {
throw new IllegalArgumentException();
}
}
assertDoesNotThrow(() -> commandService.createCommands(taskCommandInfos));
}

View File

@@ -1,7 +1,8 @@
package com.realtime.protection.server.whitelist;
import com.realtime.protection.configuration.entity.task.Command;
import com.realtime.protection.ProtectionApplicationTests;
import com.realtime.protection.configuration.entity.task.FiveTupleWithMask;
import com.realtime.protection.configuration.entity.task.TaskCommandInfo;
import com.realtime.protection.configuration.entity.whitelist.WhiteListObject;
import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.api.Test;
@@ -14,7 +15,7 @@ import java.util.List;
import static org.junit.jupiter.api.Assertions.assertTrue;
@SpringBootTest
class WhiteListServiceTest {
class WhiteListServiceTest extends ProtectionApplicationTests {
private final WhiteListService whiteListService;
private WhiteListObject whiteListObject;
@@ -78,13 +79,13 @@ class WhiteListServiceTest {
@Test
void testWhiteListCommandJudge() {
FiveTupleWithMask fiveTupleWithMask = new FiveTupleWithMask();
Command command = new Command();
TaskCommandInfo taskCommandInfo = new TaskCommandInfo();
fiveTupleWithMask.setDestinationIP("128.1.1.123");
fiveTupleWithMask.setMaskDestinationIP("255.255.255.0");
fiveTupleWithMask.setDestinationPort("80");
command.setFiveTupleWithMask(fiveTupleWithMask);
taskCommandInfo.setFiveTupleWithMask(fiveTupleWithMask);
List<WhiteListObject> whitelists = whiteListService.whiteListCommandJudge(command);
List<WhiteListObject> whitelists = whiteListService.whiteListCommandJudge(taskCommandInfo);
System.out.println(whitelists);
}