首次提交本地代码
This commit is contained in:
169
main.py
Normal file
169
main.py
Normal file
@@ -0,0 +1,169 @@
|
||||
# -*- coding:utf-8 -*-
|
||||
import configparser
|
||||
from daemon import Daemon
|
||||
import sys
|
||||
from torch.cuda import is_available
|
||||
from loguru import logger
|
||||
from datetime import datetime
|
||||
from utils import create_all_dataloader, train_and_test_model
|
||||
import method
|
||||
import gc
|
||||
import traceback
|
||||
import pandas
|
||||
import os
|
||||
|
||||
|
||||
class Config(configparser.ConfigParser):
|
||||
def __init__(self, defaults=None):
|
||||
configparser.ConfigParser.__init__(self, defaults=defaults)
|
||||
|
||||
def optionxform(self, optionstr):
|
||||
return optionstr
|
||||
|
||||
def as_dict(self):
|
||||
d = dict(self._sections)
|
||||
for k in d:
|
||||
d[k] = dict(d[k])
|
||||
return d
|
||||
|
||||
|
||||
class Main(Daemon):
|
||||
def __init__(self, pidfile):
|
||||
super(Main, self).__init__(pidfile=pidfile)
|
||||
current = datetime.now()
|
||||
self.start_time = current.strftime("%Y-%m-%d_%H-%M-%S")
|
||||
|
||||
if len(sys.argv) == 1 or sys.argv[1] == "start":
|
||||
self.run()
|
||||
elif sys.argv[1] == "stop":
|
||||
self.stop()
|
||||
elif sys.argv[1] == "daemon":
|
||||
self.start()
|
||||
else:
|
||||
print("Input format error. Please input: python3 main.py start|stop|daemon")
|
||||
sys.exit(0)
|
||||
|
||||
def run(self):
|
||||
# 读取配置文件参数
|
||||
cf = Config()
|
||||
cf.read("./config.ini", encoding='utf8')
|
||||
cf_dict = cf.as_dict()
|
||||
# 读取数据集名称
|
||||
dataset_names = cf.get("Dataset", "name")
|
||||
datasets = dataset_names.split(",")
|
||||
datasets = [name.strip() for name in datasets]
|
||||
# 读取模型名称
|
||||
model_names = cf.get("Method", "name")
|
||||
models = model_names.split(",")
|
||||
models = [name.strip() for name in models]
|
||||
# 读取预处理方法
|
||||
preprocess_name = cf.get("Preprocess", "name")
|
||||
# 读取评估方法
|
||||
evaluation_names = cf.get("Evaluation", "name")
|
||||
evaluations = evaluation_names.split(",")
|
||||
evaluations = [name.strip() for name in evaluations]
|
||||
# 读取模型参数文件路径
|
||||
model_path = cf_dict["ModelPath"]
|
||||
# 读取训练参数
|
||||
train = cf.getboolean("BaseParameters", "train")
|
||||
epochs = cf.getint("BaseParameters", "epochs")
|
||||
batch_size = cf.getint("BaseParameters", "batch_size")
|
||||
learning_rate = cf.getfloat("BaseParameters", "learning_rate")
|
||||
device = cf.get("BaseParameters", "device")
|
||||
if device == "auto":
|
||||
device = 'cuda:0' if is_available() else 'cpu'
|
||||
# 读取自定义参数
|
||||
customs = cf_dict["CustomParameters"]
|
||||
|
||||
# 建立本次实验记录的路径
|
||||
os.makedirs(f"./records", exist_ok=True)
|
||||
os.makedirs(f"./records/{self.start_time}", exist_ok=True)
|
||||
os.makedirs(f"./records/{self.start_time}/detection_result", exist_ok=True)
|
||||
os.makedirs(f"./records/{self.start_time}/model", exist_ok=True)
|
||||
|
||||
# 初始化日志
|
||||
logger.add(f"./records/{self.start_time}/log",
|
||||
level='DEBUG',
|
||||
format='{time:YYYY-MM-DD HH:mm:ss} - {level} - {file} - {line} - {message}',
|
||||
rotation="100 MB")
|
||||
|
||||
# 核心程序
|
||||
self.core(models, datasets, preprocess_name, evaluations, model_path, train, epochs, batch_size, learning_rate,
|
||||
device, customs)
|
||||
logger.info(f"实验结束,关闭进程")
|
||||
|
||||
def core(self, models: [str], datasets: [str], preprocess_name: str, evaluations: [str], model_path: {}, train: bool,
|
||||
epochs: int, batch_size: int, learning_rate: float, device: str, customs: {}):
|
||||
"""
|
||||
初始化数据集与模型,并开始训练与测试
|
||||
:param models: 训练的模型名称,可包含多个
|
||||
:param datasets: 使用的数据集名称,可包含多个
|
||||
:param preprocess_name: 预处理方法名称
|
||||
:param evaluations: 评估方法名称,可包含多个
|
||||
:param model_path: 需要加载模型参数的路径,可包含多个
|
||||
:param train: 是否训练,如果为False,则仅测试模型
|
||||
:param epochs: 总训练轮数
|
||||
:param batch_size: batch的尺寸
|
||||
:param learning_rate: 学习率
|
||||
:param device: 设备
|
||||
:param customs: 自定义参数
|
||||
"""
|
||||
logger.info(f"加载数据集")
|
||||
try:
|
||||
# 初始化所有数据集
|
||||
all_dataloader = create_all_dataloader(datasets=datasets, input_size=int(customs["input_size"]),
|
||||
output_size=int(customs["output_size"]), step=int(customs["step"]),
|
||||
batch_size=batch_size, time_index=customs["time_index"] == "true",
|
||||
del_column_name=customs["del_column_name"] == "true",
|
||||
preprocess_name=preprocess_name)
|
||||
except RuntimeError:
|
||||
logger.error(traceback.format_exc())
|
||||
return
|
||||
|
||||
# 开始训练与测试
|
||||
|
||||
for model_name in models:
|
||||
try:
|
||||
logger.info(f"------------华丽丽的分界线:{model_name} 实验开始------------")
|
||||
for i in range(len(all_dataloader)):
|
||||
dataloader = all_dataloader[i]
|
||||
all_score = {}
|
||||
for sub_dataloader in dataloader:
|
||||
dataset_name = sub_dataloader["dataset_name"]
|
||||
normal_dataloader = sub_dataloader["normal"]
|
||||
attack_dataloader = sub_dataloader["attack"]
|
||||
logger.info(f"初始化模型 {model_name}")
|
||||
model = eval(f"method.{model_name}.Model")(customs=customs, dataloader=normal_dataloader)
|
||||
model = model.to(device)
|
||||
logger.info(f"模型初始化完成")
|
||||
pth_name = model_path[f"{model_name}_{dataset_name}"] if f"{model_name}_{dataset_name}" \
|
||||
in model_path else None
|
||||
best_score, best_detection = train_and_test_model(start_time=self.start_time, epochs=epochs,
|
||||
normal_dataloader=normal_dataloader,
|
||||
attack_dataloader=attack_dataloader,
|
||||
model=model, evaluations=evaluations,
|
||||
device=device, lr=learning_rate,
|
||||
model_path=pth_name, train=train)
|
||||
# 保存最佳检测结果的标签为csv文件
|
||||
best_detection = pandas.DataFrame(best_detection)
|
||||
best_detection.to_csv(f"./records/{self.start_time}/detection_result/{model_name}_{dataset_name}.csv", index=False)
|
||||
for evaluation_name in evaluations:
|
||||
if evaluation_name not in all_score:
|
||||
all_score[evaluation_name] = []
|
||||
all_score[evaluation_name].append(best_score[evaluation_name])
|
||||
gc.collect()
|
||||
logger.info(f"------------------------")
|
||||
logger.info(f"{model_name} / {datasets[i]} 实验完毕")
|
||||
for evaluation_name in all_score:
|
||||
logger.info(f"{evaluation_name}: {'{:.3f}'.format(sum(all_score[evaluation_name]) / len(all_score[evaluation_name]))}")
|
||||
logger.info(f"------------------------")
|
||||
logger.info(f"------------华丽丽的分界线:{model_name} 实验结束------------")
|
||||
except RuntimeError:
|
||||
logger.error(traceback.format_exc())
|
||||
return
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
pidpath = "/tmp/command_detection.pid"
|
||||
app = Main(pidpath)
|
||||
|
||||
Reference in New Issue
Block a user