# -*- coding:utf-8 -*-
import configparser
import gc
import os
import sys
import traceback
from datetime import datetime

import pandas
from loguru import logger
from torch.cuda import is_available

import method
from daemon import Daemon
from utils import create_all_dataloader, train_and_test_model


class Config(configparser.ConfigParser):
    def __init__(self, defaults=None):
        configparser.ConfigParser.__init__(self, defaults=defaults)

    def optionxform(self, optionstr):
        # Keep option names case-sensitive; the default lower-casing would break
        # lookups such as the "{model_name}_{dataset_name}" keys in [ModelPath].
        return optionstr

    def as_dict(self):
        # Return the parsed configuration as a plain nested dict: {section: {option: value}}.
        d = dict(self._sections)
        for k in d:
            d[k] = dict(d[k])
        return d


class Main(Daemon):
    def __init__(self, pidfile):
        super(Main, self).__init__(pidfile=pidfile)
        current = datetime.now()
        # Timestamp used to name this run's record directory and log file.
        self.start_time = current.strftime("%Y-%m-%d_%H-%M-%S")
        if len(sys.argv) == 1 or sys.argv[1] == "start":
            self.run()
        elif sys.argv[1] == "stop":
            self.stop()
        elif sys.argv[1] == "daemon":
            self.start()
        else:
            print("Input format error. Please input: python3 main.py start|stop|daemon")
            sys.exit(1)
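
    # ``run`` reads ./config.ini. The layout below is reconstructed from the sections and
    # keys accessed in this file; the concrete values are illustrative placeholders only,
    # not the project's actual defaults:
    #
    #   [Dataset]
    #   name = SWaT, WADI          ; comma-separated dataset names (placeholders)
    #   [Method]
    #   name = LSTM, AE            ; comma-separated module names under ``method`` (placeholders)
    #   [Preprocess]
    #   name = minmax              ; placeholder
    #   [Evaluation]
    #   name = F1, AUC             ; placeholders
    #   [ModelPath]
    #   LSTM_SWaT = ./records/xxx/model/xxx.pth   ; optional "<model>_<dataset>" -> checkpoint
    #   [BaseParameters]
    #   train = true
    #   epochs = 100
    #   batch_size = 64
    #   learning_rate = 0.001
    #   device = auto              ; "auto" picks cuda:0 when available, otherwise cpu
    #   [CustomParameters]
    #   input_size = 51
    #   output_size = 51
    #   step = 1
    #   time_index = true
    #   del_column_name = true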
    def run(self):
        # Read the configuration file.
        cf = Config()
        cf.read("./config.ini", encoding='utf8')
        cf_dict = cf.as_dict()
        # Dataset names.
        dataset_names = cf.get("Dataset", "name")
        datasets = dataset_names.split(",")
        datasets = [name.strip() for name in datasets]
        # Model names.
        model_names = cf.get("Method", "name")
        models = model_names.split(",")
        models = [name.strip() for name in models]
        # Preprocessing method.
        preprocess_name = cf.get("Preprocess", "name")
        # Evaluation metrics.
        evaluation_names = cf.get("Evaluation", "name")
        evaluations = evaluation_names.split(",")
        evaluations = [name.strip() for name in evaluations]
        # Paths of saved model parameters to load.
        model_path = cf_dict["ModelPath"]
        # Training parameters.
        train = cf.getboolean("BaseParameters", "train")
        epochs = cf.getint("BaseParameters", "epochs")
        batch_size = cf.getint("BaseParameters", "batch_size")
        learning_rate = cf.getfloat("BaseParameters", "learning_rate")
        device = cf.get("BaseParameters", "device")
        if device == "auto":
            device = 'cuda:0' if is_available() else 'cpu'
        # Custom parameters.
        customs = cf_dict["CustomParameters"]
        # Create the record directories for this run.
        os.makedirs("./records", exist_ok=True)
        os.makedirs(f"./records/{self.start_time}", exist_ok=True)
        os.makedirs(f"./records/{self.start_time}/detection_result", exist_ok=True)
        os.makedirs(f"./records/{self.start_time}/model", exist_ok=True)
        # Initialize logging.
        logger.add(f"./records/{self.start_time}/log",
                   level='DEBUG',
                   format='{time:YYYY-MM-DD HH:mm:ss} - {level} - {file} - {line} - {message}',
                   rotation="100 MB")
        # Core routine.
        self.core(models, datasets, preprocess_name, evaluations, model_path, train, epochs, batch_size,
                  learning_rate, device, customs)
        logger.info("Experiment finished, shutting down the process")
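
    # Per-run artefacts are written under ./records/<start_time>/:
    #   log                - loguru log file for this run
    #   detection_result/  - one "<model>_<dataset>.csv" of best detection labels per experiment
    #   model/             - created here; presumably used by train_and_test_model for checkpoints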

    def core(self, models: list, datasets: list, preprocess_name: str, evaluations: list, model_path: dict,
             train: bool, epochs: int, batch_size: int, learning_rate: float, device: str, customs: dict):
        """
        Initialize the datasets and models, then run training and testing.

        :param models: names of the models to train (one or more)
        :param datasets: names of the datasets to use (one or more)
        :param preprocess_name: name of the preprocessing method
        :param evaluations: names of the evaluation metrics (one or more)
        :param model_path: checkpoint paths to load, keyed by "<model>_<dataset>" (zero or more)
        :param train: whether to train; if False, the models are only tested
        :param epochs: total number of training epochs
        :param batch_size: batch size
        :param learning_rate: learning rate
        :param device: device to run on
        :param customs: custom parameters
        """
        logger.info("Loading datasets")
        try:
            # Build the dataloaders for all datasets.
            all_dataloader = create_all_dataloader(datasets=datasets, input_size=int(customs["input_size"]),
                                                   output_size=int(customs["output_size"]), step=int(customs["step"]),
                                                   batch_size=batch_size, time_index=customs["time_index"] == "true",
                                                   del_column_name=customs["del_column_name"] == "true",
                                                   preprocess_name=preprocess_name)
        except RuntimeError:
            logger.error(traceback.format_exc())
            return
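        # Assumed shape of ``all_dataloader`` (inferred from the loops below, not guaranteed
        # by this file): one entry per name in ``datasets``, each a list of dicts with keys
        # "dataset_name", "normal" and "attack", the last two being the dataloaders for the
        # normal (training) data and the attack (test) data.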
        # Start training and testing.
        for model_name in models:
            try:
                logger.info(f"------------ Divider: {model_name} experiment started ------------")
                for i in range(len(all_dataloader)):
                    dataloader = all_dataloader[i]
                    all_score = {}
                    for sub_dataloader in dataloader:
                        dataset_name = sub_dataloader["dataset_name"]
                        normal_dataloader = sub_dataloader["normal"]
                        attack_dataloader = sub_dataloader["attack"]
                        logger.info(f"Initializing model {model_name}")
                        # Look up the model class in the corresponding module of the ``method``
                        # package (equivalent to the original eval-based lookup, without eval).
                        model = getattr(method, model_name).Model(customs=customs, dataloader=normal_dataloader)
                        model = model.to(device)
                        logger.info("Model initialization finished")
                        pth_name = model_path.get(f"{model_name}_{dataset_name}")
                        best_score, best_detection = train_and_test_model(start_time=self.start_time, epochs=epochs,
                                                                          normal_dataloader=normal_dataloader,
                                                                          attack_dataloader=attack_dataloader,
                                                                          model=model, evaluations=evaluations,
                                                                          device=device, lr=learning_rate,
                                                                          model_path=pth_name, train=train)
                        # Save the labels of the best detection result as a CSV file.
                        best_detection = pandas.DataFrame(best_detection)
                        best_detection.to_csv(f"./records/{self.start_time}/detection_result/"
                                              f"{model_name}_{dataset_name}.csv", index=False)
                        for evaluation_name in evaluations:
                            if evaluation_name not in all_score:
                                all_score[evaluation_name] = []
                            all_score[evaluation_name].append(best_score[evaluation_name])
                        gc.collect()
                    logger.info("------------------------")
                    logger.info(f"{model_name} / {datasets[i]} experiment finished")
                    for evaluation_name in all_score:
                        logger.info(f"{evaluation_name}: "
                                    f"{sum(all_score[evaluation_name]) / len(all_score[evaluation_name]):.3f}")
                    logger.info("------------------------")
                logger.info(f"------------ Divider: {model_name} experiment finished ------------")
            except RuntimeError:
                logger.error(traceback.format_exc())
                return
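
# A minimal sketch of the interface that ``core`` assumes for the per-model modules in the
# ``method`` package (inferred from the calls above, not taken from this repository): each
# module exposes a ``Model`` class that is a torch.nn.Module accepting ``customs`` and
# ``dataloader`` keyword arguments, roughly:
#
#     import torch
#
#     class Model(torch.nn.Module):
#         def __init__(self, customs: dict, dataloader=None):
#             super().__init__()
#             self.input_size = int(customs["input_size"])  # hypothetical use of a custom parameter
#
# ``train_and_test_model`` in ``utils`` then trains/evaluates it and returns
# (best_score, best_detection), where ``best_score`` maps each evaluation name to a score.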


if __name__ == '__main__':
    pidpath = "/tmp/command_detection.pid"
    app = Main(pidpath)
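
# Typical invocation: "python3 main.py start" (or no argument) runs the experiment directly in
# this process, "python3 main.py daemon" launches it through the Daemon base class, and
# "python3 main.py stop" stops a running daemon.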