170 lines
7.8 KiB
Python
170 lines
7.8 KiB
Python
# -*- coding:utf-8 -*-
|
||
import configparser
|
||
from daemon import Daemon
|
||
import sys
|
||
from torch.cuda import is_available
|
||
from loguru import logger
|
||
from datetime import datetime
|
||
from utils import create_all_dataloader, train_and_test_model
|
||
import method
|
||
import gc
|
||
import traceback
|
||
import pandas
|
||
import os
|
||
|
||
|
||
class Config(configparser.ConfigParser):
    """``ConfigParser`` variant that keeps option names exactly as written.

    The stock parser lower-cases option names through ``optionxform``; this
    subclass returns them untouched, so ``as_dict`` yields the original keys.
    """

    def __init__(self, defaults=None):
        """
        :param defaults: optional mapping forwarded to ``ConfigParser`` as the
            DEFAULT section
        """
        super().__init__(defaults=defaults)

    def optionxform(self, optionstr):
        """Return *optionstr* unchanged (disables the default lower-casing)."""
        return optionstr

    def as_dict(self):
        """Return every parsed section as ``{section: {option: value}}``.

        Values are the raw, uninterpolated strings held by the parser, and
        DEFAULT entries are not merged in. Each inner dict is a fresh copy, so
        mutating the result does not modify the parser.
        """
        return {section: dict(options) for section, options in self._sections.items()}
|
||
|
||
|
||
class Main(Daemon):
    """Entry point that runs the experiments configured in ``./config.ini``.

    Creating an instance immediately dispatches on the first command-line
    argument: ``start`` (or no argument) runs in the foreground via
    :meth:`run`, while ``stop`` and ``daemon`` call the ``Daemon`` base
    class's ``stop()`` and ``start()`` methods.
    """

    def __init__(self, pidfile):
        """
        :param pidfile: pid-file path passed through to the ``Daemon`` base class
        """
        super(Main, self).__init__(pidfile=pidfile)
        # Timestamp of this run; it names the ./records/<start_time> directory.
        self.start_time = datetime.now().strftime("%Y-%m-%d_%H-%M-%S")

        command = sys.argv[1] if len(sys.argv) > 1 else "start"
        if command == "start":
            self.run()
        elif command == "stop":
            self.stop()
        elif command == "daemon":
            self.start()
        else:
            # A usage error must not look like success to the calling shell:
            # report it on stderr and exit with a non-zero status.
            print("Input format error. Please input: python3 main.py start|stop|daemon", file=sys.stderr)
            sys.exit(2)

    def run(self):
        """Load the configuration, prepare the record directory and logger, then call :meth:`core`.

        Sections read from config.ini: ``Dataset``, ``Method``, ``Preprocess``,
        ``Evaluation`` (each via a ``name`` option; comma-separated where
        several values are allowed), ``BaseParameters``, ``ModelPath`` and
        ``CustomParameters``.

        NOTE(review): every path here is relative to the current working
        directory. If ``Daemon`` changes directory when daemonizing (the classic
        daemon recipe calls ``os.chdir("/")``), ``daemon`` mode will look for
        ./config.ini and ./records elsewhere -- confirm against daemon.py.

        :raises FileNotFoundError: if ./config.ini cannot be read
        """
        cf = Config()
        # ConfigParser.read() silently skips files it cannot open; fail here with a
        # clear message instead of a NoSectionError on the first cf.get() below.
        if not cf.read("./config.ini", encoding='utf8'):
            raise FileNotFoundError("cannot read configuration file ./config.ini")
        cf_dict = cf.as_dict()

        datasets = self._split_names(cf.get("Dataset", "name"))
        models = self._split_names(cf.get("Method", "name"))
        preprocess_name = cf.get("Preprocess", "name")
        evaluations = self._split_names(cf.get("Evaluation", "name"))
        # Maps "<model>_<dataset>" to a model parameter file to load (see core()).
        model_path = cf_dict["ModelPath"]

        train = cf.getboolean("BaseParameters", "train")
        epochs = cf.getint("BaseParameters", "epochs")
        batch_size = cf.getint("BaseParameters", "batch_size")
        learning_rate = cf.getfloat("BaseParameters", "learning_rate")
        device = cf.get("BaseParameters", "device")
        if device == "auto":
            device = 'cuda:0' if is_available() else 'cpu'
        # Free-form parameters, kept as the raw strings from config.ini.
        customs = cf_dict["CustomParameters"]

        # makedirs also creates ./records and ./records/<start_time> as parents.
        record_dir = f"./records/{self.start_time}"
        for sub_dir in ("detection_result", "model"):
            os.makedirs(f"{record_dir}/{sub_dir}", exist_ok=True)

        logger.add(f"{record_dir}/log",
                   level='DEBUG',
                   format='{time:YYYY-MM-DD HH:mm:ss} - {level} - {file} - {line} - {message}',
                   rotation="100 MB")

        self.core(models, datasets, preprocess_name, evaluations, model_path, train, epochs, batch_size,
                  learning_rate, device, customs)
        logger.info("实验结束,关闭进程")

    def core(self, models: list, datasets: list, preprocess_name: str, evaluations: list, model_path: dict,
             train: bool, epochs: int, batch_size: int, learning_rate: float, device: str, customs: dict):
        """Build the dataloaders, then train and/or test every model on every dataset.

        A RuntimeError while building the dataloaders or while running any
        model is logged and ends the whole run; remaining models are skipped.

        :param models: model names; each must resolve to ``method.<name>.Model``
        :param datasets: dataset names; ``datasets[i]`` labels the summary of the
            i-th dataloader group returned by ``create_all_dataloader``
        :param preprocess_name: name of the preprocessing method
        :param evaluations: evaluation method names; each must be a key of the
            score mapping returned by ``train_and_test_model``
        :param model_path: maps ``"<model>_<dataset>"`` to a parameter file to
            load; pairs without an entry get ``None``
        :param train: whether to train; if False the models are only tested
        :param epochs: total number of training epochs
        :param batch_size: batch size
        :param learning_rate: learning rate
        :param device: device string such as ``'cpu'`` or ``'cuda:0'``
        :param customs: custom parameters; ``input_size``, ``output_size``,
            ``step``, ``time_index`` and ``del_column_name`` are read here, and
            the whole mapping is passed to every model
        """
        logger.info("加载数据集")
        try:
            # Only the exact lowercase string "true" enables the two boolean flags.
            all_dataloader = create_all_dataloader(datasets=datasets, input_size=int(customs["input_size"]),
                                                   output_size=int(customs["output_size"]),
                                                   step=int(customs["step"]),
                                                   batch_size=batch_size,
                                                   time_index=customs["time_index"] == "true",
                                                   del_column_name=customs["del_column_name"] == "true",
                                                   preprocess_name=preprocess_name)
        except RuntimeError:
            logger.error(traceback.format_exc())
            return

        for model_name in models:
            try:
                logger.info(f"------------华丽丽的分界线:{model_name} 实验开始------------")
                for i, dataloader in enumerate(all_dataloader):
                    all_score = {}
                    for sub_dataloader in dataloader:
                        best_score = self._train_and_test_one(model_name, sub_dataloader, model_path, evaluations,
                                                              train, epochs, learning_rate, device, customs)
                        for evaluation_name in evaluations:
                            all_score.setdefault(evaluation_name, []).append(best_score[evaluation_name])
                        # The model built by _train_and_test_one is no longer referenced from
                        # here, so this pass can reclaim it before the next model is created.
                        gc.collect()
                    self._log_dataset_summary(model_name, datasets[i], all_score)
                logger.info(f"------------华丽丽的分界线:{model_name} 实验结束------------")
            except RuntimeError:
                logger.error(traceback.format_exc())
                return

    def _train_and_test_one(self, model_name, sub_dataloader, model_path, evaluations, train, epochs,
                            learning_rate, device, customs):
        """Build one model, train/test it on one sub-dataset and save its best detection labels as CSV.

        Running this in its own frame means the model and the detection table
        are dropped when it returns (unless something else keeps a reference),
        instead of staying alive while the next model is constructed.

        :param sub_dataloader: mapping with ``"dataset_name"``, ``"normal"`` and
            ``"attack"`` entries
        :return: the score mapping returned by ``train_and_test_model``
        """
        dataset_name = sub_dataloader["dataset_name"]
        normal_dataloader = sub_dataloader["normal"]
        attack_dataloader = sub_dataloader["attack"]

        logger.info(f"初始化模型 {model_name}")
        model = self._load_model_class(model_name)(customs=customs, dataloader=normal_dataloader)
        model = model.to(device)
        logger.info("模型初始化完成")

        best_score, best_detection = train_and_test_model(start_time=self.start_time, epochs=epochs,
                                                          normal_dataloader=normal_dataloader,
                                                          attack_dataloader=attack_dataloader,
                                                          model=model, evaluations=evaluations,
                                                          device=device, lr=learning_rate,
                                                          model_path=model_path.get(f"{model_name}_{dataset_name}"),
                                                          train=train)

        # Save the labels of the best detection result as CSV.
        pandas.DataFrame(best_detection).to_csv(
            f"./records/{self.start_time}/detection_result/{model_name}_{dataset_name}.csv", index=False)
        return best_score

    @staticmethod
    def _load_model_class(model_name):
        """Return ``method.<model_name>.Model``.

        The name comes from config.ini, so the dotted path is walked with
        ``getattr`` rather than passed to ``eval``; this resolves the same
        attribute without evaluating arbitrary expressions.
        """
        target = method
        for attr in f"{model_name}.Model".split("."):
            target = getattr(target, attr)
        return target

    @staticmethod
    def _log_dataset_summary(model_name, dataset_name, all_score):
        """Log the mean of each evaluation score collected for one model/dataset group."""
        logger.info("------------------------")
        logger.info(f"{model_name} / {dataset_name} 实验完毕")
        for evaluation_name, scores in all_score.items():
            logger.info(f"{evaluation_name}: {sum(scores) / len(scores):.3f}")
        logger.info("------------------------")

    @staticmethod
    def _split_names(value):
        """Split a comma-separated config value into whitespace-stripped names."""
        return [name.strip() for name in value.split(",")]
|
||
|
||
|
||
if __name__ == '__main__':
    # Constructing Main immediately runs the command given on the command line
    # (start / stop / daemon).
    # NOTE(review): a fixed pid-file name under /tmp can be pre-created by another
    # local user; consider a per-user runtime directory.
    app = Main(pidfile="/tmp/command_detection.pid")
|
||
|