diff --git a/src/main/java/com/realtime/protection/ProtectionApplication.java b/src/main/java/com/realtime/protection/ProtectionApplication.java index d25ea83..76df96e 100644 --- a/src/main/java/com/realtime/protection/ProtectionApplication.java +++ b/src/main/java/com/realtime/protection/ProtectionApplication.java @@ -2,17 +2,15 @@ package com.realtime.protection; import org.springframework.boot.autoconfigure.SpringBootApplication; import org.springframework.boot.builder.SpringApplicationBuilder; -import org.springframework.scheduling.annotation.EnableAsync; @SpringBootApplication -@EnableAsync public class ProtectionApplication { public static void main(String[] args) { SpringApplicationBuilder builder = new SpringApplicationBuilder(ProtectionApplication.class); // 在实际环境中应该修改为prod - builder.application().setAdditionalProfiles("dev"); + builder.application().setAdditionalProfiles("test"); builder.run(args); } diff --git a/src/main/java/com/realtime/protection/configuration/entity/defense/object/ProtectObject.java b/src/main/java/com/realtime/protection/configuration/entity/defense/object/ProtectObject.java index 1605220..5dd102a 100644 --- a/src/main/java/com/realtime/protection/configuration/entity/defense/object/ProtectObject.java +++ b/src/main/java/com/realtime/protection/configuration/entity/defense/object/ProtectObject.java @@ -3,6 +3,7 @@ package com.realtime.protection.configuration.entity.defense.object; import com.alibaba.excel.annotation.ExcelIgnore; import com.alibaba.excel.annotation.ExcelProperty; import com.fasterxml.jackson.annotation.JsonProperty; +import io.swagger.v3.oas.annotations.media.Schema; import jakarta.validation.constraints.Max; import jakarta.validation.constraints.Min; import jakarta.validation.constraints.NotNull; @@ -13,20 +14,24 @@ import lombok.Data; public class ProtectObject { @JsonProperty("proobj_id") @ExcelIgnore + @Schema(description = "防护对象ID", accessMode = Schema.AccessMode.READ_ONLY) private Integer protectObjectId; @JsonProperty("proobj_name") @NotNull(message = "proobj_name should not be empty.") @ExcelProperty("名称") + @Schema(description = "防护对象名称", example = "静态对象测试") private String protectObjectName; @JsonProperty("proobj_system_name") @ExcelProperty("操作系统名称") + @Schema(description = "防护对象操作系统名称", example = "xxx操作系统") private String protectObjectSystemName; @JsonProperty("proobj_ip_address") @Pattern(regexp = "^(\\d{1,3})\\.(\\d{1,3})\\.(\\d{1,3})\\.(\\d{1,3})$", message = "Invalid IPv4 Address") @ExcelProperty("IP地址") + @Schema(description = "防护对象IPv4地址", example = "192.168.0.1") private String protectObjectIPAddress; @JsonProperty("proobj_port") @@ -34,31 +39,38 @@ public class ProtectObject { @Max(value = 65535, message = "port should not be more than 65535") @Min(value = 1, message = "port should not be less than 1") @ExcelProperty("端口") + @Schema(description = "防护对象端口", maximum = "65535", minimum = "1", example = "8080") private Integer protectObjectPort; @JsonProperty("proobj_url") @NotNull(message = "proobj_url should not be empty.") @ExcelProperty("URL") + @Schema(description = "防护对象URL", example = "alice.bob.com") private String protectObjectURL; @JsonProperty("proobj_protocol") @NotNull(message = "proobj_protocol should not be empty.") @ExcelProperty("协议") + @Schema(description = "防护对象网络协议(目前仅可以填写TCP或UDP)", example = "TCP") private String protectObjectProtocol; @JsonProperty("proobj_audit_status") @ExcelIgnore + @Schema(description = "防护对象审核状态(0为未审核,1为已退回,2为审核通过)", example = "2") private Integer protectObjectAuditStatus; @JsonProperty("proobj_create_username") @ExcelIgnore + @Schema(description = "防护对象创建人", example = "xxx", accessMode = Schema.AccessMode.READ_ONLY) private String protectObjectCreateUsername; @JsonProperty("proobj_create_depart") @ExcelIgnore + @Schema(description = "防护对象创建人处室", example = "xxx", accessMode = Schema.AccessMode.READ_ONLY) private String protectObjectCreateDepart; @JsonProperty("proobj_create_userid") @ExcelIgnore + @Schema(description = "防护对象创建人ID", example = "0", accessMode = Schema.AccessMode.READ_ONLY) private Integer protectObjectCreateUserId; } diff --git a/src/main/java/com/realtime/protection/configuration/entity/defense/template/ProtectLevel.java b/src/main/java/com/realtime/protection/configuration/entity/defense/template/ProtectLevel.java index 8d9fd09..a34f019 100644 --- a/src/main/java/com/realtime/protection/configuration/entity/defense/template/ProtectLevel.java +++ b/src/main/java/com/realtime/protection/configuration/entity/defense/template/ProtectLevel.java @@ -1,22 +1,31 @@ package com.realtime.protection.configuration.entity.defense.template; +import io.swagger.v3.oas.annotations.media.Schema; import lombok.Data; @Data public class ProtectLevel { + @Schema(description = "防护等级ID", accessMode = Schema.AccessMode.READ_ONLY) private Integer protectLevelId; + @Schema(description = "该防护等级是否需要提取防护对象IP地址字段") private Boolean hasProtectObjectIP = false; + @Schema(description = "该防护等级是否需要提取防护对象端口字段") private Boolean hasProtectObjectPort = false; + @Schema(description = "该防护等级是否需要提取对端IP地址字段") private Boolean hasPeerIP = false; + @Schema(description = "该防护等级是否需要提取对端端口字段") private Boolean hasPeerPort = false; + @Schema(description = "该防护等级是否需要提取网络协议字段") private Boolean hasProtocol = false; + @Schema(description = "该防护等级是否需要提取URL字段") private Boolean hasURL = false; + @Schema(description = "该防护等级是否需要提取DNS") private Boolean hasDNS = false; } diff --git a/src/main/java/com/realtime/protection/configuration/entity/defense/template/Template.java b/src/main/java/com/realtime/protection/configuration/entity/defense/template/Template.java index 327c12f..ce69687 100644 --- a/src/main/java/com/realtime/protection/configuration/entity/defense/template/Template.java +++ b/src/main/java/com/realtime/protection/configuration/entity/defense/template/Template.java @@ -1,49 +1,58 @@ package com.realtime.protection.configuration.entity.defense.template; import com.fasterxml.jackson.annotation.JsonProperty; +import io.swagger.v3.oas.annotations.media.Schema; import jakarta.validation.constraints.NotNull; import lombok.Data; @Data public class Template { @JsonProperty("template_id") + @Schema(description = "防御策略模板ID", example = "2", accessMode = Schema.AccessMode.READ_ONLY) private Integer templateId; @JsonProperty("template_name") @NotNull(message = "template name should not be empty.") + @Schema(description = "防御策略模板名称", example = "自定义模板") private String templateName; - @JsonProperty("template_running_tasks") - private Integer templateRunningTasks; - - @JsonProperty("template_used") - private Integer templateUsedTimes; - @JsonProperty("source_system") @NotNull(message = "source_system should not be empty. ") + @Schema(description = "防御策略模板数据来源系统", example = "BW系统") private String sourceSystem; @JsonProperty("protect_level_low") @NotNull(message = "protect_level_low should not be empty. ") + @Schema(description = "防御策略模板日常态字段提取选项") private ProtectLevel protectLevelLow; @JsonProperty("protect_level_medium") @NotNull(message = "protect_level_medium should not be empty. ") + @Schema(description = "防御策略模板应急态字段提取选项") private ProtectLevel protectLevelMedium; @JsonProperty("protect_level_high") @NotNull(message = "protect_level_high should not be empty. ") + @Schema(description = "防御策略模板紧急态字段提取选项") private ProtectLevel protectLevelHigh; @JsonProperty("template_used_times") + @Schema(description = "防御策略模板使用次数", example = "20", accessMode = Schema.AccessMode.READ_ONLY) private Integer usedTimes; @JsonProperty("running_tasks") + @Schema(description = "防御策略模板已运行的任务数量", example = "30", accessMode = Schema.AccessMode.READ_ONLY) private Integer runningTasks; + @JsonProperty("create_user_id") + @Schema(description = "防御策略模板创建人ID", example = "1", accessMode = Schema.AccessMode.READ_ONLY) private Integer createUserId; + @JsonProperty("create_user_name") + @Schema(description = "防御策略模板创建人名称", example = "xxx", accessMode = Schema.AccessMode.READ_ONLY) private String createUsername; + @JsonProperty("create_user_depart") + @Schema(description = "防御策略模板创建人处室", example = "xxx", accessMode = Schema.AccessMode.READ_ONLY) private String createDepart; } diff --git a/src/main/java/com/realtime/protection/configuration/entity/task/Task.java b/src/main/java/com/realtime/protection/configuration/entity/task/Task.java index edcffe0..5fca3c0 100644 --- a/src/main/java/com/realtime/protection/configuration/entity/task/Task.java +++ b/src/main/java/com/realtime/protection/configuration/entity/task/Task.java @@ -1,6 +1,8 @@ package com.realtime.protection.configuration.entity.task; import com.fasterxml.jackson.annotation.JsonProperty; +import io.swagger.v3.oas.annotations.Parameter; +import io.swagger.v3.oas.annotations.media.Schema; import jakarta.validation.constraints.Future; import jakarta.validation.constraints.NotNull; import lombok.Data; @@ -9,56 +11,72 @@ import java.time.LocalDateTime; import java.util.List; @Data +@Schema(description = "一个任务对象包含的所有信息") public class Task { @JsonProperty("task_id") + @Schema(description = "任务ID", accessMode = Schema.AccessMode.READ_ONLY) private Long taskId; @JsonProperty("task_name") @NotNull(message = "task_name should not be empty. ") + @Schema(description = "任务名称", example = "静态任务") private String taskName; @JsonProperty("task_start_time") @NotNull(message = "task_start_time should not be empty. ") @Future(message = "task_start_time should be a future time") + @Schema(description = "任务开始时间,必须晚于当前时间", example = "2024-10-23T00:00:00") private LocalDateTime taskStartTime; @JsonProperty("task_end_time") @NotNull(message = "task_end_time should not be empty. ") @Future(message = "task_end_time should be a future time. ") + @Schema(description = "任务结束时间,必须晚于开始时间", example = "2024-10-24T00:00:00") private LocalDateTime taskEndTime; @JsonProperty("task_create_time") + @Schema(hidden = true) private LocalDateTime taskCreateTime; @JsonProperty("task_modify_time") + @Schema(hidden = true) private LocalDateTime taskModifyTime; @JsonProperty("task_type") @NotNull(message = "task_type should not be empty. ") + @Schema(description = "任务类型,1为静态任务,2为实时任务,3为研判后任务", example = "1") private Integer taskType; @JsonProperty("task_act") @NotNull(message = "task_act should not be empty. ") + @Schema(description = "任务行为,目前只能为【阻断】", example = "阻断") private String taskAct; @JsonProperty("task_create_username") + @Schema(hidden = true) private String taskCreateUsername; @JsonProperty("task_create_depart") + @Schema(hidden = true) private String taskCreateDepart; @JsonProperty("task_create_userid") + @Schema(hidden = true) private Long taskCreateUserId; @JsonProperty("static_rule_ids") + @Schema(description = "静态规则ID列表,动态和静态至少存在1个规则", example = "[10, 12]") private List staticRuleIds; @JsonProperty("dynamic_rule_ids") + @Schema(description = "动态规则ID列表,动态和静态至少存在1个规则", example = "[20, 30]") private List dynamicRuleIds; @JsonProperty("task_status") + @Schema(description = "任务状态(0为未启动,1为生成中,2为运行中,3为暂停中,4为已停止,5为已结束,6为失败)", accessMode = Schema.AccessMode.READ_ONLY) private Integer taskStatus; @JsonProperty("task_audit_status") + @Schema(description = "任务审核状态(0为未审核,1为已退回,2为已通过)", accessMode = Schema.AccessMode.READ_ONLY) private Integer taskAuditStatus; } diff --git a/src/main/java/com/realtime/protection/configuration/exception/GlobalExceptionHandler.java b/src/main/java/com/realtime/protection/configuration/exception/GlobalExceptionHandler.java index f123d4f..d669a38 100644 --- a/src/main/java/com/realtime/protection/configuration/exception/GlobalExceptionHandler.java +++ b/src/main/java/com/realtime/protection/configuration/exception/GlobalExceptionHandler.java @@ -5,6 +5,7 @@ import cn.dev33.satoken.exception.SaTokenException; import com.realtime.protection.configuration.response.ResponseResult; import com.realtime.protection.configuration.utils.enums.StateEnum; import com.realtime.protection.server.task.status.StateChangeService; +import lombok.extern.slf4j.Slf4j; import org.apache.ibatis.exceptions.PersistenceException; import org.springframework.context.support.DefaultMessageSourceResolvable; import org.springframework.core.annotation.Order; @@ -16,6 +17,7 @@ import org.springframework.web.method.annotation.HandlerMethodValidationExceptio import java.util.stream.Collectors; @RestControllerAdvice +@Slf4j public class GlobalExceptionHandler { private final StateChangeService stateChangeService; @@ -27,6 +29,7 @@ public class GlobalExceptionHandler { @Order(3) @ExceptionHandler(value = Exception.class) public ResponseResult handleGlobalException(Exception e) { + log.error("meets global exception: " + e.getMessage()); return ResponseResult.error().setMessage(e.getMessage()); } @@ -34,6 +37,7 @@ public class GlobalExceptionHandler { @Order(2) @ExceptionHandler(value = PersistenceException.class) public ResponseResult handleSQLException(PersistenceException e) { + log.error("meets database exception: " + e.getMessage()); return ResponseResult.invalid().setMessage( "please check the integrity of the data. check if the json data exists in the database"); } @@ -41,6 +45,7 @@ public class GlobalExceptionHandler { @Order(2) @ExceptionHandler(value = MethodArgumentNotValidException.class) public ResponseResult handleBindException(MethodArgumentNotValidException e) { + log.debug("meets data bind exception: " + e.getMessage()); return ResponseResult.invalid().setMessage( e.getBindingResult().getAllErrors().stream() .map(DefaultMessageSourceResolvable::getDefaultMessage).collect(Collectors.joining()) @@ -54,12 +59,14 @@ public class GlobalExceptionHandler { IllegalStateException.class }) public ResponseResult handleHandlerMethodValidationException(Exception e) { + log.debug("meets illegal argument exception: " + e.getMessage()); return ResponseResult.invalid().setMessage(e.getMessage()); } @Order(2) @ExceptionHandler(value = NotLoginException.class) public ResponseResult handleNotLoginException(NotLoginException e) { + log.debug("meets not login exception, login type: " + e.getLoginType()); return new ResponseResult( 401, e.getMessage() @@ -69,12 +76,14 @@ public class GlobalExceptionHandler { @Order(2) @ExceptionHandler(value = SaTokenException.class) public ResponseResult handleSaTokenException(SaTokenException e) { + log.debug("sa-token meets exception: " + e.getMessage()); return ResponseResult.unAuthorized().setMessage(e.getMessage()); } @Order(2) @ExceptionHandler(value = DorisStartException.class) public ResponseResult handleDorisStartException(DorisStartException e) { + log.warn("doris database meets exception: " + e.getMessage()); ResponseResult responseResult = ResponseResult.error() .setMessage("Doris command creation meets error: " + e.getMessage()); @@ -84,6 +93,7 @@ public class GlobalExceptionHandler { responseResult.setAnother(ResponseResult.error().setMessage(e.getMessage())); } + log.error(responseResult.getMessage()); return responseResult; } } diff --git a/src/main/java/com/realtime/protection/configuration/response/ResponseResult.java b/src/main/java/com/realtime/protection/configuration/response/ResponseResult.java index 07bc679..d710830 100644 --- a/src/main/java/com/realtime/protection/configuration/response/ResponseResult.java +++ b/src/main/java/com/realtime/protection/configuration/response/ResponseResult.java @@ -1,5 +1,6 @@ package com.realtime.protection.configuration.response; +import io.swagger.v3.oas.annotations.media.Schema; import lombok.Data; import java.io.Serializable; @@ -7,11 +8,19 @@ import java.util.LinkedHashMap; import java.util.Map; @Data +@Schema(name = "通用返回对象", description = "用于所有接口返回的通用返回对象") public class ResponseResult implements Serializable { + @Schema(description = "状态码") private int code; + + @Schema(description = "返回信息") private String message; + + @Schema(description = "封装数据") private Map data; + + @Schema(description = "返回对象链接的另外一个返回对象") private ResponseResult another; public ResponseResult(int code, String message, LinkedHashMap data) { diff --git a/src/main/java/com/realtime/protection/configuration/satoken/role/Admin.java b/src/main/java/com/realtime/protection/configuration/satoken/role/User.java similarity index 52% rename from src/main/java/com/realtime/protection/configuration/satoken/role/Admin.java rename to src/main/java/com/realtime/protection/configuration/satoken/role/User.java index c6c39a2..67eeb0f 100644 --- a/src/main/java/com/realtime/protection/configuration/satoken/role/Admin.java +++ b/src/main/java/com/realtime/protection/configuration/satoken/role/User.java @@ -1,5 +1,6 @@ package com.realtime.protection.configuration.satoken.role; -public enum Admin implements Role { - ADMIN +public enum User implements Role { + ADMIN, + NORMAL } diff --git a/src/main/java/com/realtime/protection/configuration/swagger/SwaggerConfiguration.java b/src/main/java/com/realtime/protection/configuration/swagger/SwaggerConfiguration.java new file mode 100644 index 0000000..6414074 --- /dev/null +++ b/src/main/java/com/realtime/protection/configuration/swagger/SwaggerConfiguration.java @@ -0,0 +1,19 @@ +package com.realtime.protection.configuration.swagger; + +import io.swagger.v3.oas.annotations.OpenAPIDefinition; +import io.swagger.v3.oas.annotations.info.Contact; +import io.swagger.v3.oas.annotations.info.Info; +import org.springframework.context.annotation.Bean; +import org.springframework.context.annotation.Configuration; + +@OpenAPIDefinition( + info = @Info( + title = "实时方宇项目", + version = "1.0", + description = "实时方宇项目——前端接口文档", + contact = @Contact(name = "陈松岳", email = "chensongyue@iie.ac.cn") + ) +) +public class SwaggerConfiguration { + +} diff --git a/src/main/java/com/realtime/protection/configuration/threadpool/OverrideDefaultThreadPoolConfig.java b/src/main/java/com/realtime/protection/configuration/threadpool/OverrideDefaultThreadPoolConfig.java new file mode 100644 index 0000000..9f8e5b1 --- /dev/null +++ b/src/main/java/com/realtime/protection/configuration/threadpool/OverrideDefaultThreadPoolConfig.java @@ -0,0 +1,58 @@ +package com.realtime.protection.configuration.threadpool; + +import com.realtime.protection.configuration.exception.DorisStartException; +import com.realtime.protection.configuration.exception.GlobalExceptionHandler; +import lombok.extern.slf4j.Slf4j; +import org.springframework.aop.interceptor.AsyncUncaughtExceptionHandler; +import org.springframework.scheduling.annotation.AsyncConfigurer; +import org.springframework.scheduling.annotation.EnableAsync; +import org.springframework.scheduling.concurrent.ThreadPoolTaskExecutor; +import org.springframework.stereotype.Component; + +import java.util.concurrent.Executor; +import java.util.concurrent.ThreadPoolExecutor; + +@Slf4j +@Component +@EnableAsync +public class OverrideDefaultThreadPoolConfig implements AsyncConfigurer { + private final TaskThreadPoolConfig taskThreadPoolConfig; + private final GlobalExceptionHandler globalExceptionHandler; + + public OverrideDefaultThreadPoolConfig(TaskThreadPoolConfig taskThreadPoolConfig, + GlobalExceptionHandler globalExceptionHandler) { + this.taskThreadPoolConfig = taskThreadPoolConfig; + this.globalExceptionHandler = globalExceptionHandler; + } + + @Override + public Executor getAsyncExecutor() { + ThreadPoolTaskExecutor threadPoolTaskExecutor = new ThreadPoolTaskExecutor(); + + threadPoolTaskExecutor.setCorePoolSize(taskThreadPoolConfig.getCorePoolSize()); + threadPoolTaskExecutor.setMaxPoolSize(taskThreadPoolConfig.getMaxPoolSize()); + threadPoolTaskExecutor.setKeepAliveSeconds(taskThreadPoolConfig.getKeepAliveSeconds()); + threadPoolTaskExecutor.setQueueCapacity(taskThreadPoolConfig.getQueueCapacity()); + + threadPoolTaskExecutor.setThreadNamePrefix("ThreadPool-"); + threadPoolTaskExecutor.setRejectedExecutionHandler(new ThreadPoolExecutor.AbortPolicy()); + + threadPoolTaskExecutor.initialize(); + + return threadPoolTaskExecutor; + } + + @Override + public AsyncUncaughtExceptionHandler getAsyncUncaughtExceptionHandler() { + return (ex, method, params) -> { + log.debug(method.getName() + " meets error: " + ex.getMessage()); + + if (ex instanceof DorisStartException) { + globalExceptionHandler.handleDorisStartException((DorisStartException) ex); + return; + } + + globalExceptionHandler.handleGlobalException((Exception) ex); + }; + } +} diff --git a/src/main/java/com/realtime/protection/configuration/threadpool/TaskThreadPoolConfig.java b/src/main/java/com/realtime/protection/configuration/threadpool/TaskThreadPoolConfig.java new file mode 100644 index 0000000..045b473 --- /dev/null +++ b/src/main/java/com/realtime/protection/configuration/threadpool/TaskThreadPoolConfig.java @@ -0,0 +1,15 @@ +package com.realtime.protection.configuration.threadpool; + +import lombok.Data; +import org.springframework.boot.context.properties.ConfigurationProperties; +import org.springframework.stereotype.Component; + +@Data +@Component +@ConfigurationProperties(prefix = "task.pool") +public class TaskThreadPoolConfig { + private int corePoolSize; + private int maxPoolSize; + private int keepAliveSeconds; + private int queueCapacity; +} diff --git a/src/main/java/com/realtime/protection/configuration/utils/enums/StateEnum.java b/src/main/java/com/realtime/protection/configuration/utils/enums/StateEnum.java index 6668121..941281a 100644 --- a/src/main/java/com/realtime/protection/configuration/utils/enums/StateEnum.java +++ b/src/main/java/com/realtime/protection/configuration/utils/enums/StateEnum.java @@ -11,11 +11,12 @@ import java.util.Map; public enum StateEnum { // 仅需修改此处即可将任务状态以及对应的State和Num进行对应 PENDING(0, new PendingState()), - RUNNING(1, new RunningState()), - PAUSED(2, new PauseState()), - STOP(3, new StopState()), - FINISHED(4, new FinishedState()), - FAILED(5, new FailedState()); + GENERATING(1, new GeneratingState()), + RUNNING(2, new RunningState()), + PAUSED(3, new PauseState()), + STOP(4, new StopState()), + FINISHED(5, new FinishedState()), + FAILED(6, new FailedState()); // ---------------------------------------------- private final State state; diff --git a/src/main/java/com/realtime/protection/server/command/CommandMapper.java b/src/main/java/com/realtime/protection/server/command/CommandMapper.java index 94d1747..581b5db 100644 --- a/src/main/java/com/realtime/protection/server/command/CommandMapper.java +++ b/src/main/java/com/realtime/protection/server/command/CommandMapper.java @@ -7,7 +7,6 @@ import org.apache.ibatis.annotations.Param; import java.util.List; -@DS("doris") @Mapper public interface CommandMapper { Boolean createCommand(@Param("command") Command command); diff --git a/src/main/java/com/realtime/protection/server/command/CommandService.java b/src/main/java/com/realtime/protection/server/command/CommandService.java index ce19da3..98b510e 100644 --- a/src/main/java/com/realtime/protection/server/command/CommandService.java +++ b/src/main/java/com/realtime/protection/server/command/CommandService.java @@ -1,10 +1,14 @@ package com.realtime.protection.server.command; import com.alibaba.excel.util.ListUtils; +import com.baomidou.dynamic.datasource.annotation.DS; import com.realtime.protection.configuration.entity.task.Command; import com.realtime.protection.configuration.entity.task.TaskCommandInfo; import com.realtime.protection.configuration.exception.DorisStartException; import com.realtime.protection.configuration.utils.SqlSessionWrapper; +import com.realtime.protection.configuration.utils.enums.StateEnum; +import com.realtime.protection.server.task.TaskMapper; +import com.realtime.protection.server.task.TaskService; import lombok.extern.slf4j.Slf4j; import org.springframework.scheduling.annotation.Async; import org.springframework.stereotype.Service; @@ -18,17 +22,19 @@ import java.util.function.Function; public class CommandService { private final CommandMapper commandMapper; + private final TaskService taskService; private final SqlSessionWrapper sqlSessionWrapper; - private static final int BatchSize = 1000; + private static final int BatchSize = 100; private final Function> createCommandBatchFunction; - public CommandService(CommandMapper commandMapper, SqlSessionWrapper sqlSessionWrapper) { + public CommandService(CommandMapper commandMapper, TaskService taskService, SqlSessionWrapper sqlSessionWrapper) { this.commandMapper = commandMapper; + this.taskService = taskService; this.sqlSessionWrapper = sqlSessionWrapper; this.createCommandBatchFunction = mapper -> info -> { if (info.getFrequency() == null) { Command command = Command.generateCommand(info, info.getStartTime()); - commandMapper.createCommand(command); + mapper.createCommand(command); } List commandBatch = ListUtils.newArrayListWithExpectedSize(BatchSize); @@ -43,12 +49,12 @@ public class CommandService { if (commandBatch.size() < BatchSize) { continue; } - commandMapper.createCommands(commandBatch); + mapper.createCommands(commandBatch); commandBatch.clear(); } if (!commandBatch.isEmpty()) { - commandMapper.createCommands(commandBatch); + mapper.createCommands(commandBatch); commandBatch.clear(); } @@ -59,15 +65,17 @@ public class CommandService { } @Async + @DS("doris") public void createCommand(TaskCommandInfo commandInfo) throws DorisStartException { try { sqlSessionWrapper.startBatchSession(CommandMapper.class, createCommandBatchFunction, commandInfo); } catch (Exception e) { - throw new DorisStartException(e); + throw new DorisStartException(e, commandInfo.getTaskId()); } } @Async + @DS("doris") public void createCommands(List taskCommandInfos) throws DorisStartException { Function, Void>> function = mapper -> list -> { if (list == null || list.isEmpty()) { @@ -77,24 +85,34 @@ public class CommandService { for (TaskCommandInfo info : list) { createCommandBatchFunction.apply(mapper).apply(info); } + + taskService.changeTaskStatus(list.get(0).getTaskId(), StateEnum.RUNNING.getStateNum()); return null; }; try { sqlSessionWrapper.startBatchSession(CommandMapper.class, function, taskCommandInfos); } catch (Exception e) { - throw new DorisStartException(e); + TaskCommandInfo info = taskCommandInfos.get(0); + Long taskId = null; + if (info != null) { + taskId = info.getTaskId(); + } + throw new DorisStartException(e, taskId); } } + @DS("doris") public Boolean startCommandsByTaskId(Long taskId) { return commandMapper.startCommandsByTaskId(taskId); } + @DS("doris") public Boolean stopCommandsByTaskId(Long taskId) { return commandMapper.stopCommandsByTaskId(taskId); } + @DS("doris") public Boolean removeCommandsByTaskId(Long taskId) { return commandMapper.removeCommandsByTaskId(taskId); } diff --git a/src/main/java/com/realtime/protection/server/defense/object/ProtectObjectController.java b/src/main/java/com/realtime/protection/server/defense/object/ProtectObjectController.java index 1d6dee4..64e92e4 100644 --- a/src/main/java/com/realtime/protection/server/defense/object/ProtectObjectController.java +++ b/src/main/java/com/realtime/protection/server/defense/object/ProtectObjectController.java @@ -18,7 +18,7 @@ import java.util.List; @RestController @RequestMapping("/proobj") -public class ProtectObjectController { +public class ProtectObjectController implements ProtectObjectControllerApi { private final ProtectObjectService protectObjectService; @@ -26,6 +26,7 @@ public class ProtectObjectController { this.protectObjectService = protectObjectService; } + @Override @PostMapping("/new") public ResponseResult newProtectObject(@RequestBody @Valid ProtectObject protectObject) { Integer protectObjectId = protectObjectService.newProtectObject(protectObject); @@ -42,15 +43,17 @@ public class ProtectObjectController { .setData("success", true); } + @Override @PostMapping("/upload") public ResponseResult uploadFile( - @NotNull(message = "uploadFile cannot be null") MultipartFile uploadFile + @NotNull(message = "uploadFile cannot be null. ") MultipartFile uploadFile ) throws IOException { EasyExcel.read(uploadFile.getInputStream(), ProtectObject.class, new ProjectObjectDataListener(protectObjectService)).sheet().doRead(); return ResponseResult.ok(); } + @Override @GetMapping("/download") public void downloadTemplate(HttpServletResponse response) throws IOException { response.setContentType("application/vnd.openxmlformats-officedocument.spreadsheetml.sheet"); @@ -64,6 +67,7 @@ public class ProtectObjectController { .doWrite(List.of()); } + @Override @GetMapping("/query") public ResponseResult queryProtectObjects(@RequestParam(value = "proobj_name", required = false) String protectObjectName, @@ -77,6 +81,7 @@ public class ProtectObjectController { } + @Override @GetMapping("/{protectObjectId}/query") public ResponseResult queryProtectObject(@PathVariable Integer protectObjectId) throws IllegalAccessException { ProtectObject protectObject = protectObjectService.queryProtectObject(protectObjectId); @@ -84,6 +89,7 @@ public class ProtectObjectController { .setDataMap(EntityUtils.entityToMap(protectObject)); } + @Override @PostMapping("/{protectObjectId}/update") public ResponseResult updateProtectObject(@PathVariable Integer protectObjectId, @RequestBody @Valid ProtectObject protectObject) { @@ -93,6 +99,7 @@ public class ProtectObjectController { .setData("success", protectObjectService.updateProtectObject(protectObject)); } + @Override @DeleteMapping("/{protectObjectId}/delete") public ResponseResult deleteProtectObject(@PathVariable Integer protectObjectId) { return ResponseResult.ok() @@ -100,14 +107,16 @@ public class ProtectObjectController { .setData("success", protectObjectService.deleteProtectObject(protectObjectId)); } - @PostMapping("/delete") - public ResponseResult deleteProtectObject(@RequestBody List protectObjectIds) { + @Override + @DeleteMapping("/delete/{protectObjectIds}") + public ResponseResult deleteProtectObject(@PathVariable List protectObjectIds) { return ResponseResult.ok() .setData("proobj_ids", protectObjectIds) .setData("success", protectObjectService.deleteProtectObjects(protectObjectIds)); } - @PostMapping("/{protectObjectId}/audit/{auditStatus}") + @Override + @GetMapping("/{protectObjectId}/audit/{auditStatus}") public ResponseResult changeProtectObjectAuditStatus(@PathVariable Integer protectObjectId, @PathVariable Integer auditStatus) { return ResponseResult.ok() diff --git a/src/main/java/com/realtime/protection/server/defense/object/ProtectObjectControllerApi.java b/src/main/java/com/realtime/protection/server/defense/object/ProtectObjectControllerApi.java new file mode 100644 index 0000000..6447900 --- /dev/null +++ b/src/main/java/com/realtime/protection/server/defense/object/ProtectObjectControllerApi.java @@ -0,0 +1,115 @@ +package com.realtime.protection.server.defense.object; + +import com.realtime.protection.configuration.entity.defense.object.ProtectObject; +import com.realtime.protection.configuration.response.ResponseResult; +import io.swagger.v3.oas.annotations.Operation; +import io.swagger.v3.oas.annotations.Parameter; +import io.swagger.v3.oas.annotations.media.Content; +import io.swagger.v3.oas.annotations.media.Schema; +import io.swagger.v3.oas.annotations.responses.ApiResponse; +import io.swagger.v3.oas.annotations.tags.Tag; +import jakarta.servlet.http.HttpServletResponse; +import jakarta.validation.Valid; +import jakarta.validation.constraints.Min; +import jakarta.validation.constraints.NotNull; +import org.apache.coyote.Response; +import org.springframework.web.bind.annotation.*; +import org.springframework.web.multipart.MultipartFile; + +import java.io.IOException; +import java.util.List; + +@Tag(name = "防护对象API", description = "防护对象模块所有接口") +public interface ProtectObjectControllerApi { + @PostMapping("/new") + @Operation( + summary = "新建防护对象", + description = "新建一个防护对象", + responses = { + @ApiResponse( + description = "返回新建对象结果", + content = @Content( + mediaType = "application/json", + schema = @Schema(implementation = ResponseResult.class) + ) + ) + }, + requestBody = @io.swagger.v3.oas.annotations.parameters.RequestBody(description = "防护对象信息") + ) + ResponseResult newProtectObject(@RequestBody @Valid ProtectObject protectObject); + + @PostMapping("/upload") + @Operation( + summary = "批量上传防护对象", + description = "使用模板文件上传并新建多个防护对象", + responses = { + @ApiResponse( + description = "返回批量上传新建对象结果", + content = @Content( + mediaType = "application/json", + schema = @Schema(implementation = ResponseResult.class) + ) + ) + }, + requestBody = @io.swagger.v3.oas.annotations.parameters.RequestBody(description = "上传文件") + ) + ResponseResult uploadFile( + @NotNull(message = "uploadFile cannot be null. ") MultipartFile uploadFile + ) throws IOException; + + @GetMapping("/download") + @Operation( + summary = "下载模板文件", + description = "下载防护对象上传模板文件", + responses = { + @ApiResponse( + description = "返回防护对象模板文件", + content = @Content( + mediaType = "application/vnd.openxmlformats-officedocument.spreadsheetml.sheet" + ) + ) + } + ) + void downloadTemplate(HttpServletResponse response) throws IOException; + + @GetMapping("/query") + @Operation( + summary = "根据条件查询多个防护对象", + description = "根据查询条件和页码等,查询多个对象并以列表返回", + responses = { + @ApiResponse( + description = "返回多个防护对象", + content = @Content( + mediaType = "application/json", + schema = @Schema(implementation = ResponseResult.class) + ) + ) + }, + parameters = { + @Parameter() + } + ) + ResponseResult queryProtectObjects(@RequestParam(value = "proobj_name", required = false) + String protectObjectName, + @RequestParam(value = "proobj_id", required = false) @Min(1) + Integer protectObjectId, + @RequestParam("page") @Min(1) Integer page, + @RequestParam("page_size") @Min(1) Integer pageSize); + + @GetMapping("/{protectObjectId}/query") + ResponseResult queryProtectObject(@PathVariable Integer protectObjectId) throws IllegalAccessException; + + @PostMapping("/{protectObjectId}/update") + ResponseResult updateProtectObject(@PathVariable Integer protectObjectId, + @RequestBody @Valid ProtectObject protectObject); + + @DeleteMapping("/{protectObjectId}/delete") + ResponseResult deleteProtectObject(@PathVariable Integer protectObjectId); + + @DeleteMapping("/delete/{protectObjectIds}") + ResponseResult deleteProtectObject(@PathVariable List protectObjectIds); + + @GetMapping("/{protectObjectId}/audit/{auditStatus}") + ResponseResult changeProtectObjectAuditStatus(@PathVariable Integer protectObjectId, + @PathVariable Integer auditStatus); +} diff --git a/src/main/java/com/realtime/protection/server/defense/template/TemplateController.java b/src/main/java/com/realtime/protection/server/defense/template/TemplateController.java index 3efaff1..64d0487 100644 --- a/src/main/java/com/realtime/protection/server/defense/template/TemplateController.java +++ b/src/main/java/com/realtime/protection/server/defense/template/TemplateController.java @@ -10,7 +10,7 @@ import org.springframework.web.bind.annotation.*; import java.util.List; @RestController -@RequestMapping("/deftac") +@RequestMapping("/template") public class TemplateController { private final TemplateService templateService; @@ -21,15 +21,8 @@ public class TemplateController { @PostMapping("/new") public ResponseResult newTemplate(@RequestBody @Valid Template template) { - Integer templateId; - try { - templateId = templateService.newTemplate(template); - } catch (IllegalArgumentException e) { - return new ResponseResult(400, "Illegal Argument in template_elements or default_op") - .setData("template_id", null) - .setData("success", false); - } + Integer templateId = templateService.newTemplate(template); if (templateId > 0) { return ResponseResult.ok() diff --git a/src/main/java/com/realtime/protection/server/task/TaskController.java b/src/main/java/com/realtime/protection/server/task/TaskController.java index ab83c13..7dd4458 100644 --- a/src/main/java/com/realtime/protection/server/task/TaskController.java +++ b/src/main/java/com/realtime/protection/server/task/TaskController.java @@ -15,7 +15,7 @@ import java.util.List; @RestController @RequestMapping("/task") -public class TaskController { +public class TaskController implements TaskControllerApi { private final TaskService taskService; private final StateChangeService stateChangeService; @@ -25,6 +25,7 @@ public class TaskController { this.stateChangeService = stateChangeService; } + @Override @PostMapping("/new") public ResponseResult newTask(@RequestBody @Valid Task task) { Long taskId = taskService.newTask(task); @@ -42,6 +43,7 @@ public class TaskController { .setData("success", false); } + @Override @GetMapping("/query") public ResponseResult queryTasks(@RequestParam(value = "task_status", required = false) Integer taskStatus, @RequestParam(value = "task_type", required = false) String taskType, @@ -54,6 +56,7 @@ public class TaskController { .setData("task_list", tasks); } + @Override @GetMapping("/{id}/query") public ResponseResult queryTask(@PathVariable @Min(1) Long id) throws IllegalAccessException { Task task = taskService.queryTask(id); @@ -66,13 +69,16 @@ public class TaskController { .setDataMap(EntityUtils.entityToMap(task)); } - @PostMapping("/update") - public ResponseResult updateTask(@RequestBody @Valid Task task) { + @Override + @PostMapping("/{taskId}/update") + public ResponseResult updateTask(@PathVariable Long taskId, @RequestBody @Valid Task task) { + task.setTaskId(taskId); return ResponseResult.ok() .setData("task_id", task.getTaskId()) .setData("success", taskService.updateTask(task)); } + @Override @GetMapping("/{taskId}/audit/{auditStatus}") public ResponseResult changeTaskAuditStatus(@PathVariable @NotNull @Max(10) Integer auditStatus, @PathVariable @NotNull @Min(1) Long taskId) { @@ -83,13 +89,15 @@ public class TaskController { .setData("audit_status", taskService.queryTaskAuditStatus(taskId)); } - @GetMapping("/{taskId}/delete") + @Override + @DeleteMapping("/{taskId}/delete") public ResponseResult deleteTask(@PathVariable @NotNull @Min(1) Long taskId) { return ResponseResult.ok() .setData("task_id", taskId) .setData("success", taskService.deleteTask(taskId)); } + @Override @GetMapping("/{taskId}/running/{stateNum}") public ResponseResult changeTaskStatus(@PathVariable @NotNull Integer stateNum, @PathVariable @NotNull Long taskId) throws DorisStartException { diff --git a/src/main/java/com/realtime/protection/server/task/TaskControllerApi.java b/src/main/java/com/realtime/protection/server/task/TaskControllerApi.java new file mode 100644 index 0000000..6d90a40 --- /dev/null +++ b/src/main/java/com/realtime/protection/server/task/TaskControllerApi.java @@ -0,0 +1,162 @@ +package com.realtime.protection.server.task; + +import com.realtime.protection.configuration.entity.task.Task; +import com.realtime.protection.configuration.exception.DorisStartException; +import com.realtime.protection.configuration.response.ResponseResult; +import io.swagger.v3.oas.annotations.Operation; +import io.swagger.v3.oas.annotations.Parameter; +import io.swagger.v3.oas.annotations.media.Content; +import io.swagger.v3.oas.annotations.media.Schema; +import io.swagger.v3.oas.annotations.responses.ApiResponse; +import io.swagger.v3.oas.annotations.tags.Tag; +import jakarta.validation.Valid; +import jakarta.validation.constraints.Max; +import jakarta.validation.constraints.Min; +import jakarta.validation.constraints.NotNull; +import org.springframework.web.bind.annotation.*; + +@Tag(name = "任务控制器API", description = "任务管理模块相关的所有接口") +public interface TaskControllerApi { + @PostMapping("/new") + @Operation( + summary = "添加任务", + description = "根据任务信息添加任务并返回任务添加结果", + responses = { + @ApiResponse( + description = "返回任务添加结果信息", + content = @Content( + mediaType = "application/json", + schema = @Schema(implementation = ResponseResult.class) + ) + ) + }, + requestBody = @io.swagger.v3.oas.annotations.parameters.RequestBody(description = "任务信息") + ) + ResponseResult newTask(@RequestBody @Valid Task task); + + @GetMapping("/query") + @Operation( + summary = "查询任务", + description = "按页和搜索内容查询任务相关信息", + responses = { + @ApiResponse( + description = "返回查询到的所有任务", + content = @Content( + mediaType = "application/json", + schema = @Schema(implementation = ResponseResult.class) + ) + ) + }, + parameters = { + @Parameter(name = "task_status", description = "任务状态(0为未启动,1为生成中,2为运行中,3为暂停中,4为已停止,5为已结束,6为失败)"), + @Parameter(name = "task_type", description = "任务类型(1为静态,2为实时,3为研判后)"), + @Parameter(name = "task_name", description = "任务名称"), + @Parameter(name = "task_creator", description = "任务创建人"), + @Parameter(name = "page", description = "页码", example = "1"), + @Parameter(name = "page_size", description = "每页查询个数", example = "10") + } + ) + ResponseResult queryTasks(@RequestParam(value = "task_status", required = false) Integer taskStatus, + @RequestParam(value = "task_type", required = false) String taskType, + @RequestParam(value = "task_name", required = false) String taskName, + @RequestParam(value = "task_creator", required = false) String taskCreator, + @RequestParam("page") @Min(1) Integer page, + @RequestParam("page_size") @Min(1) Integer pageSize); + + @GetMapping("/{id}/query") + @Operation( + summary = "查询单个任务", + description = "根据任务ID查询单个任务的所有详细信息", + responses = { + @ApiResponse( + description = "返回查询到的单个任务", + content = @Content( + mediaType = "application/json", + schema = @Schema(implementation = ResponseResult.class) + ) + ) + }, + parameters = {@Parameter(name = "id", description = "任务ID", example = "38")} + ) + ResponseResult queryTask(@PathVariable @Min(1) Long id) throws IllegalAccessException; + + @PostMapping("/{taskId}/update") + @Operation( + summary = "更新任务", + description = "根据任务信息更新任务并返回更新结果", + responses = { + @ApiResponse( + description = "返回任务更新结果信息", + content = @Content( + mediaType = "application/json", + schema = @Schema(implementation = ResponseResult.class) + ) + ) + }, + requestBody = @io.swagger.v3.oas.annotations.parameters.RequestBody( + description = "任务信息,必须包含任务原有的或者添加/删除部分后的static_rule_ids和dynamic_rule_ids" + ) + ) + ResponseResult updateTask(@PathVariable Long taskId, @RequestBody @Valid Task task); + + @GetMapping("/{taskId}/audit/{auditStatus}") + @Operation( + summary = "任务审核状态修改", + description = "修改ID对应的任务的审核状态", + responses = { + @ApiResponse( + description = "返回任务审核状态修改的信息", + content = @Content( + mediaType = "application/json", + schema = @Schema(implementation = ResponseResult.class) + ) + ) + }, + parameters = { + @Parameter(name = "taskId", description = "任务ID", example = "38"), + @Parameter(name = "auditStatus", description = "任务欲修改的审核状态(0为未审核,1为已退回,2为审核通过)", example = "2") + } + ) + ResponseResult changeTaskAuditStatus(@PathVariable @NotNull @Max(10) Integer auditStatus, + @PathVariable @NotNull @Min(1) Long taskId); + + @DeleteMapping("/{taskId}/delete") + @Operation( + summary = "删除单个任务", + description = "根据任务ID删除对应任务", + responses = { + @ApiResponse( + description = "返回任务删除结果信息", + content = @Content( + mediaType = "application/json", + schema = @Schema(implementation = ResponseResult.class) + ) + ) + }, + parameters = { + @Parameter(name = "taskId", description = "任务ID") + } + ) + ResponseResult deleteTask(@PathVariable @NotNull @Min(1) Long taskId); + + @GetMapping("/{taskId}/running/{stateNum}") + @Operation( + summary = "修改任务运行状态", + description = "修改ID对应的任务的运行状态", + responses = { + @ApiResponse( + description = "返回任务运行状态修改结果", + content = @Content( + mediaType = "application/json", + schema = @Schema(implementation = ResponseResult.class) + ) + ) + }, + parameters = { + @Parameter(name = "taskId", description = "任务ID"), + @Parameter(name = "stateNum", description = "任务状态编号任务状态(0为未启动,1为生成中,2为运行中,3为暂停中,4为已停止,5为已结束,6为失败)") + } + ) + ResponseResult changeTaskStatus(@PathVariable @NotNull Integer stateNum, + @PathVariable @NotNull Long taskId) throws DorisStartException; +} diff --git a/src/main/java/com/realtime/protection/server/task/TaskService.java b/src/main/java/com/realtime/protection/server/task/TaskService.java index 23ffb28..5a029d0 100644 --- a/src/main/java/com/realtime/protection/server/task/TaskService.java +++ b/src/main/java/com/realtime/protection/server/task/TaskService.java @@ -1,5 +1,6 @@ package com.realtime.protection.server.task; +import com.baomidou.dynamic.datasource.annotation.DS; import com.realtime.protection.configuration.entity.task.Task; import com.realtime.protection.configuration.entity.task.TaskCommandInfo; import com.realtime.protection.configuration.utils.status.AuditStatusValidator; @@ -20,8 +21,11 @@ public class TaskService { public Long newTask(Task task) { taskMapper.newTask(task); - taskMapper.newTaskStaticRuleConcat(task.getTaskId(), task.getStaticRuleIds()); - taskMapper.newTaskDynamicRuleConcat(task.getTaskId(), task.getDynamicRuleIds()); + if (task.getStaticRuleIds() != null && !task.getStaticRuleIds().isEmpty()) + taskMapper.newTaskStaticRuleConcat(task.getTaskId(), task.getStaticRuleIds()); + + if (task.getDynamicRuleIds() != null && !task.getDynamicRuleIds().isEmpty()) + taskMapper.newTaskDynamicRuleConcat(task.getTaskId(), task.getDynamicRuleIds()); return task.getTaskId(); } @@ -70,6 +74,7 @@ public class TaskService { return taskMapper.deleteTask(taskId); } + @DS("mysql") public Boolean changeTaskStatus(Long taskId, Integer stateNum) { return taskMapper.changeTaskStatus(taskId, stateNum); } diff --git a/src/main/java/com/realtime/protection/server/task/status/StateChangeService.java b/src/main/java/com/realtime/protection/server/task/status/StateChangeService.java index 080c2d5..26f8de3 100644 --- a/src/main/java/com/realtime/protection/server/task/status/StateChangeService.java +++ b/src/main/java/com/realtime/protection/server/task/status/StateChangeService.java @@ -31,6 +31,10 @@ public class StateChangeService { State newState = StateEnum.getStateByNum(stateNum); + if (newState == null) { + return false; + } + if (!originalState.handle(newState, commandService, taskService, taskId)) { return false; } diff --git a/src/main/java/com/realtime/protection/server/task/status/StateHandler.java b/src/main/java/com/realtime/protection/server/task/status/StateHandler.java index 866b31c..d20101e 100644 --- a/src/main/java/com/realtime/protection/server/task/status/StateHandler.java +++ b/src/main/java/com/realtime/protection/server/task/status/StateHandler.java @@ -77,13 +77,7 @@ public class StateHandler { throw new IllegalArgumentException("static rules are empty, need to choose at least one static rule"); } - try { - commandService.createCommands(staticTaskCommandInfos); - } catch (DorisStartException e) { - e.taskId = taskId; - throw e; - } - + commandService.createCommands(staticTaskCommandInfos); return true; } } diff --git a/src/main/java/com/realtime/protection/server/task/status/states/GeneratingState.java b/src/main/java/com/realtime/protection/server/task/status/states/GeneratingState.java new file mode 100644 index 0000000..c9e059e --- /dev/null +++ b/src/main/java/com/realtime/protection/server/task/status/states/GeneratingState.java @@ -0,0 +1,19 @@ +package com.realtime.protection.server.task.status.states; + +import com.realtime.protection.configuration.exception.DorisStartException; +import com.realtime.protection.configuration.utils.enums.StateEnum; +import com.realtime.protection.configuration.utils.status.State; +import com.realtime.protection.server.command.CommandService; +import com.realtime.protection.server.task.TaskService; +import com.realtime.protection.server.task.status.StateHandler; + +public class GeneratingState extends StateHandler implements State { + @Override + public Boolean handle(State newState, CommandService commandService, TaskService taskService, Long taskId) throws DorisStartException { + return switch(StateEnum.getStateEnumByState(newState)) { + case RUNNING, GENERATING -> true; + case FAILED -> handleFailed(commandService, taskId); + default -> throw new IllegalStateException("Unexpected value: " + StateEnum.getStateEnumByState(newState)); + }; + } +} diff --git a/src/main/java/com/realtime/protection/server/task/status/states/PendingState.java b/src/main/java/com/realtime/protection/server/task/status/states/PendingState.java index 7454e0c..17f4bf3 100644 --- a/src/main/java/com/realtime/protection/server/task/status/states/PendingState.java +++ b/src/main/java/com/realtime/protection/server/task/status/states/PendingState.java @@ -11,7 +11,7 @@ public class PendingState extends StateHandler implements State { @Override public Boolean handle(State newState, CommandService commandService, TaskService taskService, Long taskId) throws DorisStartException { return switch (StateEnum.getStateEnumByState(newState)) { - case RUNNING -> handleStart(taskService, commandService, taskId); + case GENERATING -> handleStart(taskService, commandService, taskId); case FAILED -> handleFailed(commandService, taskId); default -> throw new IllegalStateException("Unexpected value: " + StateEnum.getStateEnumByState(newState)); }; diff --git a/src/main/java/com/realtime/protection/server/task/status/states/RunningState.java b/src/main/java/com/realtime/protection/server/task/status/states/RunningState.java index 25eeeb1..85514c2 100644 --- a/src/main/java/com/realtime/protection/server/task/status/states/RunningState.java +++ b/src/main/java/com/realtime/protection/server/task/status/states/RunningState.java @@ -10,7 +10,7 @@ public class RunningState extends StateHandler implements State { @Override public Boolean handle(State newState, CommandService commandService, TaskService taskService, Long taskId) { return switch(StateEnum.getStateEnumByState(newState)) { - case RUNNING -> true; + case RUNNING, GENERATING -> true; case PAUSED -> handlePause(commandService, taskId); case STOP -> handleStop(commandService, taskId); case FINISHED -> handleFinish(commandService, taskId); diff --git a/src/main/resources/config/application-dev.yml b/src/main/resources/config/application-dev.yml index f69eca1..c4d6923 100644 --- a/src/main/resources/config/application-dev.yml +++ b/src/main/resources/config/application-dev.yml @@ -33,5 +33,13 @@ spring: jackson: default-property-inclusion: non_null + mybatis: - mapper-locations: classpath:mappers/* \ No newline at end of file + mapper-locations: classpath:mappers/* + +task: + pool: + core-pool-size: 1 + max-pool-size: 1 + queue-capacity: 1 + keep-alive-seconds: 120 \ No newline at end of file diff --git a/src/main/resources/config/application-prod.yml b/src/main/resources/config/application-prod.yml index facbd23..279bb6a 100644 --- a/src/main/resources/config/application-prod.yml +++ b/src/main/resources/config/application-prod.yml @@ -1,5 +1,5 @@ server: - port: 80 + port: 8081 logging: level: @@ -34,4 +34,11 @@ spring: default-property-inclusion: non_null mybatis: - mapper-locations: classpath:mappers/* \ No newline at end of file + mapper-locations: classpath:mappers/* + +task: + pool: + core-pool-size: 20 + max-pool-size: 100 + queue-capacity: 100 + keep-alive-seconds: 60 \ No newline at end of file diff --git a/src/main/resources/config/application-test.yml b/src/main/resources/config/application-test.yml new file mode 100644 index 0000000..a65f7b7 --- /dev/null +++ b/src/main/resources/config/application-test.yml @@ -0,0 +1,53 @@ +server: + port: 8081 + +logging: + level: + com.realtime.protection: info + + +spring: + datasource: + dynamic: + datasource: + mysql: + driver-class-name: com.mysql.cj.jdbc.Driver + username: root + password: aiihhbfcsy123!@# + url: jdbc:mysql://192.168.107.89:3306/realtime_protection + hikari: + is-auto-commit: false + doris: + driver-class-name: com.mysql.cj.jdbc.Driver + username: root + url: jdbc:mysql://10.26.22.133:9030/command + hikari: + is-auto-commit: false + aop: + enabled: true + primary: mysql + strict: true + grace-destroy: true + mvc: + servlet: + path: /api/v1 + jackson: + default-property-inclusion: non_null + +mybatis: + mapper-locations: classpath:mappers/* + +task: + pool: + core-pool-size: 20 + max-pool-size: 100 + queue-capacity: 100 + keep-alive-seconds: 60 + +springdoc: + api-docs: + enabled: true + path: /api-docs + swagger-ui: + path: /swagger + packages-to-scan: com.realtime.protection.server \ No newline at end of file diff --git a/src/main/resources/mappers/CommandMapper.xml b/src/main/resources/mappers/CommandMapper.xml index 2ec63e2..5557247 100644 --- a/src/main/resources/mappers/CommandMapper.xml +++ b/src/main/resources/mappers/CommandMapper.xml @@ -36,23 +36,19 @@ UPDATE t_command - SET IS_VALID = FALSE + SET IS_VALID = FALSE, LAST_UPDATE = NOW() WHERE TASK_ID = #{task_id} AND IS_DELETED = FALSE UPDATE t_command - SET IS_VALID = TRUE + SET IS_VALID = TRUE, LAST_UPDATE = NOW() WHERE TASK_ID = #{task_id} AND IS_DELETED = FALSE UPDATE t_command - SET IS_DELETED = TRUE + SET IS_DELETED = TRUE, LAST_UPDATE = NOW() WHERE TASK_ID = #{task_id} - - \ No newline at end of file diff --git a/src/test/java/com/realtime/protection/server/defense/template/TemplateServiceTest.java b/src/test/java/com/realtime/protection/server/defense/template/TemplateServiceTest.java index 52a00b4..ebd8e50 100644 --- a/src/test/java/com/realtime/protection/server/defense/template/TemplateServiceTest.java +++ b/src/test/java/com/realtime/protection/server/defense/template/TemplateServiceTest.java @@ -67,8 +67,8 @@ class TemplateServiceTest { assertEquals(5, templates.size()); for (Template template : templates) { assertTrue(template.getTemplateId() > 0); - assertNotNull(template.getTemplateRunningTasks()); - assertNotNull(template.getTemplateUsedTimes()); + assertNotNull(template.getUsedTimes()); + assertNotNull(template.getRunningTasks()); } }