1. 修改setDataMap函数为addDataMap以防止swagger将DataMap视为一种属性

2. 当任务未通过审核时,现在会立刻报错而不是返回false
This commit is contained in:
EnderByEndera
2024-01-19 15:09:23 +08:00
parent 1d317eb10f
commit 449c320261
22 changed files with 176 additions and 75 deletions

View File

@@ -5,7 +5,7 @@ plugins {
}
group = 'com.realtime'
version = '0.0.1-SNAPSHOT'
version = '0.0.2-SNAPSHOT'
java {
sourceCompatibility = '17'
@@ -43,6 +43,7 @@ dependencies {
implementation 'org.springdoc:springdoc-openapi-starter-webmvc-ui:2.3.0'
implementation 'com.alibaba:easyexcel:3.3.3'
implementation 'com.baomidou:dynamic-datasource-spring-boot3-starter:4.3.0'
implementation 'com.github.xiaoymin:knife4j-openapi3-jakarta-spring-boot-starter:4.4.0'
}
tasks.named('test') {

31
qodana.yaml Normal file
View File

@@ -0,0 +1,31 @@
#-------------------------------------------------------------------------------#
# Qodana analysis is configured by qodana.yaml file #
# https://www.jetbrains.com/help/qodana/qodana-yaml.html #
#-------------------------------------------------------------------------------#
version: "1.0"
#Specify inspection profile for code analysis
profile:
name: qodana.starter
#Enable inspections
#include:
# - name: <SomeEnabledInspectionId>
#Disable inspections
#exclude:
# - name: <SomeDisabledInspectionId>
# paths:
# - <path/where/not/run/inspection>
projectJDK: 17 #(Applied in CI/CD pipeline)
#Execute shell command before Qodana execution (Applied in CI/CD pipeline)
#bootstrap: sh ./prepare-qodana.sh
#Install IDE plugins before Qodana execution (Applied in CI/CD pipeline)
#plugins:
# - id: <plugin.id> #(plugin id can be found at https://plugins.jetbrains.com)
#Specify Qodana linter for analysis (Applied in CI/CD pipeline)
linter: jetbrains/qodana-jvm:latest

View File

@@ -2,7 +2,6 @@ package com.realtime.protection;
import org.springframework.boot.SpringApplication;
import org.springframework.boot.autoconfigure.SpringBootApplication;
import org.springframework.scheduling.annotation.EnableScheduling;
@SpringBootApplication
public class ProtectionApplication {

View File

@@ -1,18 +0,0 @@
package com.realtime.protection.configuration.cors;
import org.springframework.context.annotation.Configuration;
import org.springframework.web.servlet.config.annotation.CorsRegistry;
import org.springframework.web.servlet.config.annotation.WebMvcConfigurer;
@Configuration
public class CorsFilter implements WebMvcConfigurer {
@Override
public void addCorsMappings(CorsRegistry corsRegistry) {
corsRegistry.addMapping("/**")
.allowedOrigins("http://localhost:8000")
.allowCredentials(true)
.allowedMethods("GET", "POST", "DELETE", "PUT")
.allowedHeaders("*")
.exposedHeaders("*");
}
}

View File

@@ -0,0 +1,17 @@
package com.realtime.protection.configuration.response;
import io.swagger.v3.oas.annotations.media.Schema;
import lombok.Data;
import lombok.EqualsAndHashCode;
import java.util.HashMap;
@EqualsAndHashCode(callSuper = true)
@Data
@Schema(description = "xxx")
public class ResponseData extends HashMap<String, Object> {
public Object put(String key, Object value, String description) {
return super.put(key, value);
}
}

View File

@@ -18,12 +18,12 @@ public class ResponseResult implements Serializable {
private String message;
@Schema(description = "封装数据")
private Map<String, Object> data;
private ResponseData data;
@Schema(description = "返回对象链接的另外一个返回对象")
private ResponseResult another;
public ResponseResult(int code, String message, LinkedHashMap<String, Object> data) {
public ResponseResult(int code, String message, ResponseData data) {
this.code = code;
this.message = message;
this.data = data;
@@ -31,13 +31,13 @@ public class ResponseResult implements Serializable {
public ResponseResult(int code) {
this.code = code;
this.data = new LinkedHashMap<>();
this.data = new ResponseData();
}
public ResponseResult(int code, String message) {
this.code = code;
this.message = message;
this.data = new LinkedHashMap<>();
this.data = new ResponseData();
}
public static ResponseResult ok() {
@@ -83,8 +83,8 @@ public class ResponseResult implements Serializable {
return this;
}
public ResponseResult setDataMap(Map<String, Object> data) {
this.data = data;
public ResponseResult addDataMap(Map<String, Object> data) {
this.data = (ResponseData) data;
return this;
}
}

View File

@@ -23,4 +23,6 @@ public interface CommandMapper {
Boolean setCommandInvalid(@Param("command_id") String commandId);
List<TaskCommandInfo> queryCommandInfoByTaskId(@Param("task_id") Long taskId);
TaskCommandInfo queryCommandInfoByUUID(@Param("uuid") String uuid);
}

View File

@@ -6,8 +6,10 @@ import com.realtime.protection.configuration.entity.task.TaskCommandInfo;
import com.realtime.protection.configuration.utils.SqlSessionWrapper;
import lombok.extern.slf4j.Slf4j;
import org.springframework.stereotype.Service;
import org.springframework.transaction.annotation.Transactional;
import java.util.List;
import java.util.UUID;
import java.util.function.Function;
@Service
@@ -24,14 +26,19 @@ public class CommandService {
this.sqlSessionWrapper = sqlSessionWrapper;
}
public Boolean createCommand(TaskCommandInfo commandInfo) {
return commandMapper.createCommand(commandInfo);
@Transactional
public String createCommand(TaskCommandInfo commandInfo) {
commandInfo.setUUID(UUID.randomUUID().toString());
commandMapper.createCommand(commandInfo);
return commandInfo.getUUID();
}
public void createCommands(List<TaskCommandInfo> taskCommandInfos) {
Function<CommandMapper, Function<List<TaskCommandInfo>, Boolean>> function = mapper -> list -> {
List<TaskCommandInfo> taskCommandInfoBatch = ListUtils.newArrayListWithExpectedSize(BatchSize);
for (TaskCommandInfo info : list) {
info.setUUID(UUID.randomUUID().toString());
taskCommandInfoBatch.add(info);
if (taskCommandInfoBatch.size() < BatchSize) {
continue;
@@ -56,6 +63,10 @@ public class CommandService {
return commandMapper.queryCommandInfoByTaskId(taskId);
}
public TaskCommandInfo queryCommandInfoByUUID(String uuid) {
return commandMapper.queryCommandInfoByUUID(uuid);
}
public Boolean startCommandsByTaskId(Long taskId) {
return commandMapper.startCommandsByTaskId(taskId);
}

View File

@@ -83,14 +83,13 @@ public class ProtectObjectController implements ProtectObjectControllerApi {
@Override
@GetMapping("/{protectObjectId}/query")
public ResponseResult queryProtectObject(@PathVariable Integer protectObjectId) throws IllegalAccessException {
public ResponseResult queryProtectObject(@PathVariable Integer protectObjectId) {
ProtectObject protectObject = protectObjectService.queryProtectObject(protectObjectId);
if (protectObject == null) {
return ResponseResult.invalid()
.setMessage("无效的防护对象ID也许该ID指定的防护对象不存在");
}
return ResponseResult.ok()
.setDataMap(EntityUtils.entityToMap(protectObject));
return ResponseResult.ok().setData("protect_object", protectObject);
}
@Override
@@ -124,7 +123,7 @@ public class ProtectObjectController implements ProtectObjectControllerApi {
public ResponseResult changeProtectObjectAuditStatus(@PathVariable Integer protectObjectId,
@PathVariable Integer auditStatus) {
return ResponseResult.ok()
.setDataMap(protectObjectService.changeProtectObjectAuditStatus(protectObjectId, auditStatus))
.addDataMap(protectObjectService.changeProtectObjectAuditStatus(protectObjectId, auditStatus))
.setData("proobj_id", protectObjectId);
}
}

View File

@@ -55,8 +55,7 @@ public class TemplateController implements TemplateControllerApi {
return ResponseResult.invalid()
.setMessage("无效的策略模板ID也许该模板不存在");
}
return ResponseResult.ok()
.setDataMap(EntityUtils.entityToMap(template));
return ResponseResult.ok().setData("template", template);
}
@Override

View File

@@ -152,7 +152,7 @@ public class StaticRuleController implements StaticRuleControllerApi {
.setData("success", false);
}
return ResponseResult.ok()
.setDataMap(staticRuleService.updateAuditStatus(id, auditStatus))
.addDataMap(staticRuleService.updateAuditStatus(id, auditStatus))
.setData("staticRule_id", id);
}

View File

@@ -87,8 +87,7 @@ public class TaskController implements TaskControllerApi {
return ResponseResult.invalid().setMessage("无效Task ID也许该ID对应的任务不存在");
}
return ResponseResult.ok()
.setDataMap(EntityUtils.entityToMap(task));
return ResponseResult.ok().setData("task", task);
}
@Override

View File

@@ -4,6 +4,7 @@ import com.realtime.protection.configuration.entity.task.Task;
import com.realtime.protection.configuration.entity.task.TaskCommandInfo;
import com.realtime.protection.configuration.exception.DorisStartException;
import com.realtime.protection.configuration.response.ResponseResult;
import io.netty.channel.ChannelHandler;
import io.swagger.v3.oas.annotations.Operation;
import io.swagger.v3.oas.annotations.Parameter;
import io.swagger.v3.oas.annotations.media.Content;
@@ -16,6 +17,10 @@ import jakarta.validation.constraints.Min;
import jakarta.validation.constraints.NotNull;
import org.springframework.web.bind.annotation.*;
import java.util.Map;
import static com.fasterxml.jackson.databind.type.LogicalType.Map;
@Tag(name = "任务控制器API", description = "任务管理模块相关的所有接口")
public interface TaskControllerApi {
@PostMapping("/new")
@@ -28,7 +33,8 @@ public interface TaskControllerApi {
content = @Content(
mediaType = "application/json",
schema = @Schema(implementation = ResponseResult.class)
)
),
responseCode = "200"
)
},
requestBody = @io.swagger.v3.oas.annotations.parameters.RequestBody(description = "任务信息")

View File

@@ -35,9 +35,9 @@ public class StateHandler {
throw new IllegalArgumentException("无效的task_id因为task_audit_status为空");
}
// 如果审核状态不为已通过审核,则无效
// 如果审核状态不为已通过审核,则报错
if (taskAuditStatus != AuditStatus.AUDITED.getAuditStatus()) {
return false;
throw new IllegalArgumentException("无效的task_id因为未通过审核");
}
return switch (TaskTypeEnum.getTaskTypeByNum(task.getTaskType())) {

View File

@@ -14,7 +14,8 @@ public class PendingState extends StateHandler implements State {
case FAILED -> handleFailed(commandService, taskId);
case RUNNING -> handleStart(taskService, commandService, taskId);
case FINISHED -> handleFinish(commandService, taskId);
default -> throw new IllegalStateException("Unexpected value: " + StateEnum.getStateEnumByState(newState));
default -> throw new IllegalStateException(taskId + " meets unexpected value: "
+ StateEnum.getStateEnumByState(newState));
};
}
}

View File

@@ -165,7 +165,7 @@ public class WhiteListController implements WhiteListControllerApi {
// }
return ResponseResult.ok()
.setDataMap(whiteListService.updateWhiteListObjectAuditStatus(id, auditStatus))
.addDataMap(whiteListService.updateWhiteListObjectAuditStatus(id, auditStatus))
.setData("whiteobj_id", id);
}

View File

@@ -1,5 +1,7 @@
server:
port: 8080
port: 8081
servlet:
context-path: /api/v1
logging:
level:
@@ -29,9 +31,6 @@ spring:
primary: mysql
strict: true
grace-destroy: true
mvc:
servlet:
path: /api/v1
jackson:
default-property-inclusion: non_null
@@ -49,7 +48,12 @@ task:
springdoc:
api-docs:
enabled: true
path: /api-docs
path: /v3/api-docs
swagger-ui:
path: /swagger
packages-to-scan: com.realtime.protection.server
path: /swagger-ui.html
packages-to-scan: com.realtime.protection.server
knife4j:
enable: true
setting:
language: zh_cn

View File

@@ -1,5 +1,7 @@
server:
port: 80
port: 8081
servlet:
context-path: /api/v1
logging:
level:
@@ -29,9 +31,6 @@ spring:
primary: mysql
strict: true
grace-destroy: true
mvc:
servlet:
path: /api/v1
jackson:
default-property-inclusion: non_null
@@ -41,8 +40,8 @@ mybatis:
task:
pool:
core-pool-size: 20
max-pool-size: 100
queue-capacity: 100
max-pool-size: 400
queue-capacity: 400
keep-alive-seconds: 120
springdoc:

View File

@@ -1,5 +1,7 @@
server:
port: 8081
servlet:
context-path: /api/v1
logging:
level:
@@ -29,9 +31,6 @@ spring:
primary: mysql
strict: false
grace-destroy: true
mvc:
servlet:
path: /api/v1
jackson:
default-property-inclusion: non_null
@@ -40,9 +39,9 @@ mybatis:
task:
pool:
core-pool-size: 20
core-pool-size: 50
max-pool-size: 100
queue-capacity: 100
queue-capacity: 50
keep-alive-seconds: 120
springdoc:

View File

@@ -9,8 +9,8 @@
MASK_SRC_IP, MASK_SRC_PORT, MASK_DST_IP, MASK_DST_PORT, MASK_PROTOCOL, VALID_TIME,
INVALID_TIME, IS_VALID,
SEND_TIMES, SUCCESS_TIMES, CREATE_TIME, LAST_UPDATE, IS_DELETED)
values (UUID(), #{info.taskId}, #{info.taskAct}, #{info.frequency},
#{info.fiveTupleWithMask.addrType},
values (#{info.UUID}, #{info.taskId}, #{info.taskAct}, #{info.frequency},
DEFAULT,
#{info.fiveTupleWithMask.sourceIP}, #{info.fiveTupleWithMask.sourcePort},
#{info.fiveTupleWithMask.destinationIP}, #{info.fiveTupleWithMask.destinationPort},
#{info.fiveTupleWithMask.protocolNum},
@@ -28,8 +28,8 @@
SEND_TIMES, SUCCESS_TIMES, CREATE_TIME, LAST_UPDATE, IS_DELETED)
values
<foreach collection="command_infos" item="info" separator=",">
(UUID(), #{info.taskId}, #{info.taskAct}, #{info.frequency},
#{info.fiveTupleWithMask.addrType},
(#{info.UUID}, #{info.taskId}, #{info.taskAct}, #{info.frequency},
DEFAULT,
#{info.fiveTupleWithMask.sourceIP}, #{info.fiveTupleWithMask.sourcePort},
#{info.fiveTupleWithMask.destinationIP}, #{info.fiveTupleWithMask.destinationPort},
#{info.fiveTupleWithMask.protocolNum},
@@ -59,6 +59,23 @@
</association>
</resultMap>
<select id="queryCommandInfoByUUID" resultMap="commandStatMap">
SELECT COMMAND_ID,
TASK_ACT,
SEND_TIMES,
SUCCESS_TIMES,
FIRST_SEND_TIME,
LASt_SEND_TIME,
SRC_IP,
SRC_PORT,
DST_IP,
DST_PORT,
PROTOCOL
FROM t_command
WHERE COMMAND_ID = #{uuid}
AND IS_DELETED = FALSE
</select>
<select id="queryCommandInfoByTaskId" resultMap="commandStatMap">
SELECT COMMAND_ID,
TASK_ACT,

View File

@@ -1,8 +1,15 @@
package com.realtime.protection.server.task;
import com.realtime.protection.ProtectionApplicationTests;
import com.realtime.protection.configuration.entity.rule.dynamicrule.DynamicRuleObject;
import com.realtime.protection.configuration.entity.rule.staticrule.StaticRuleObject;
import com.realtime.protection.configuration.entity.task.Task;
import com.realtime.protection.configuration.entity.task.TaskCommandInfo;
import com.realtime.protection.configuration.exception.DorisStartException;
import com.realtime.protection.configuration.utils.enums.StateEnum;
import com.realtime.protection.server.rule.dynamicrule.DynamicRuleService;
import com.realtime.protection.server.rule.staticrule.StaticRuleService;
import com.realtime.protection.server.task.status.StateChangeService;
import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.api.Test;
import org.springframework.beans.factory.annotation.Autowired;
@@ -10,6 +17,7 @@ import org.springframework.boot.test.context.SpringBootTest;
import org.springframework.dao.DataIntegrityViolationException;
import java.time.LocalDateTime;
import java.util.ArrayList;
import java.util.List;
import static org.junit.jupiter.api.Assertions.*;
@@ -17,11 +25,17 @@ import static org.junit.jupiter.api.Assertions.*;
@SpringBootTest
class TaskServiceTest extends ProtectionApplicationTests {
private final TaskService taskService;
private final StaticRuleService staticRuleService;
private final DynamicRuleService dynamicRuleService;
private final StateChangeService stateChangeService;
private Task task;
@Autowired
TaskServiceTest(TaskService taskService) {
TaskServiceTest(TaskService taskService, StaticRuleService staticRuleService, DynamicRuleService dynamicRuleService, StateChangeService stateChangeService) {
this.taskService = taskService;
this.staticRuleService = staticRuleService;
this.dynamicRuleService = dynamicRuleService;
this.stateChangeService = stateChangeService;
}
@BeforeEach
@@ -29,15 +43,14 @@ class TaskServiceTest extends ProtectionApplicationTests {
this.task = new Task();
task.setTaskName("静态测试");
LocalDateTime taskStartTime = LocalDateTime.now().plusDays(1);
LocalDateTime taskEndTime = LocalDateTime.now().plusDays(5);
LocalDateTime taskStartTime = LocalDateTime.now().plusMinutes(1);
LocalDateTime taskEndTime = LocalDateTime.now().plusYears(5);
task.setTaskStartTime(taskStartTime);
task.setTaskEndTime(taskEndTime);
task.setTaskAct("阻断");
task.setTaskType(1);
task.setStaticRuleIds(List.of(1, 2));
task.setDynamicRuleIds(List.of());
task.setTaskCreateUserId(1);
task.setTaskCreateUsername("xxx");
task.setTaskCreateDepart("xxx");
@@ -45,11 +58,22 @@ class TaskServiceTest extends ProtectionApplicationTests {
@Test
void testNewTaskSuccess() {
for (int i = 0; i < 100; i++) {
LocalDateTime taskStartTime = LocalDateTime.now().plusDays(i);
LocalDateTime taskEndTime = LocalDateTime.now().plusDays(i + 10);
task.setTaskStartTime(taskStartTime);
task.setTaskEndTime(taskEndTime);
for (int i = 1; i < 1000; i++) {
List<StaticRuleObject> staticRuleObjects = staticRuleService.queryStaticRule(
null, null, null, null, i, 2);
List<Integer> staticRuleIds = new ArrayList<>();
staticRuleObjects.forEach(staticRuleObject ->
staticRuleIds.add(staticRuleObject.getStaticRuleId()));
task.setStaticRuleIds(staticRuleIds);
List<DynamicRuleObject> dynamicRuleObjects = dynamicRuleService.queryDynamicRuleObject(
null, null, null, null, i, 2
);
List<Integer> dynamicRuleIds = new ArrayList<>();
dynamicRuleObjects.forEach(dynamicRuleObject ->
dynamicRuleIds.add(dynamicRuleObject.getDynamicRuleId()));
task.setDynamicRuleIds(dynamicRuleIds);
assertDoesNotThrow(() -> {
Long taskId = taskService.newTask(task);
assertTrue(taskId > 0);

View File

@@ -13,8 +13,9 @@ import org.springframework.boot.test.context.SpringBootTest;
import java.time.LocalDateTime;
import java.util.List;
import java.util.UUID;
import static org.junit.jupiter.api.Assertions.assertDoesNotThrow;
import static org.junit.jupiter.api.Assertions.*;
@SpringBootTest
class CommandServiceTest extends ProtectionApplicationTests {
@@ -55,6 +56,7 @@ class CommandServiceTest extends ProtectionApplicationTests {
@Test
void createCommand() {
assertDoesNotThrow(() -> commandService.createCommand(taskCommandInfo));
assertNotNull(taskCommandInfo.getUUID());
}
@Test
@@ -81,6 +83,15 @@ class CommandServiceTest extends ProtectionApplicationTests {
}
assertDoesNotThrow(() -> commandService.createCommands(taskCommandInfos));
}
@Test
void queryCommandByUUID() {
List<TaskCommandInfo> taskCommandInfos = commandService.queryCommandInfoByTaskId(30L);
assertTrue(taskCommandInfos != null && !taskCommandInfos.isEmpty());
for (TaskCommandInfo taskCommandInfo : taskCommandInfos) {
assertNotNull(commandService.queryCommandInfoByUUID(taskCommandInfo.getUUID()));
}
}
}