This repository has been archived on 2025-09-14. You can view files and clone it, but cannot push or open issues or pull requests.
Files
dengzeyi-sequenceshield/代码/sequenceShield/cicdos2017/script/parse_file.py
2022-11-21 12:08:58 +08:00

150 lines
4.8 KiB
Python

import pandas as pd
import os
from sklearn import tree
from sklearn.metrics import classification_report
from sklearn.metrics import roc_curve, auc, roc_auc_score
from sklearn.ensemble import RandomForestClassifier
from sklearn.ensemble import GradientBoostingClassifier
from sklearn.naive_bayes import BernoulliNB
from sklearn.naive_bayes import GaussianNB
from sklearn.neighbors import KNeighborsClassifier as kNN
from sklearn.metrics import confusion_matrix
from pandas.core.frame import DataFrame
import datetime
import numpy as np
import pandas as pd
import joblib
from sklearn.model_selection import train_test_split
# Input packet export (one row per packet) and output flow-feature table.
wireshark_csv_path = "cicdos2017/dataset/cicdos2017_from_wireshark.csv"
meta_csv_path = "cicdos2017/dataset/meta.csv"

# Flows currently being built, keyed by source/destination address pair.
flow = {}
# Length of one finished flow record: the first three entries count the
# packet types [high, low, normal]; the remaining 40 are per-packet
# direction-tagged protocol tokens.
feature_latitude = 3 + 40
# Finished flow records collected by pre_process().
feature_ls = []

# Attack schedule from attacks.txt, one [label, target IP, start second]
# entry per attack (filled by get_attack_info()).
attack_info_path = "cicdos2017/dataset/attacks.txt"
attack_info = []
def get_attack_info(file_path, records=None):
    """Parse the attack schedule file, appending one entry per attack line.

    Each non-blank line is split on whitespace: field 0 is the attack label,
    field 2 the target IP address, and field 5 the start time in minutes,
    which is converted to seconds.

    Args:
        file_path: Path of the schedule file (e.g. attacks.txt).
        records: List that ``[label, dst_ip, start_seconds]`` entries are
            appended to. Defaults to the module-level ``attack_info`` list.

    Returns:
        The list that was appended to.

    Raises:
        IndexError: If a non-blank line has fewer than six fields.
        ValueError: If field 5 is not an integer.
    """
    if records is None:
        records = attack_info
    with open(file_path) as f:
        for line in f:
            fields = line.split()
            # Lines read from a file keep their "\n", so a blank line is never
            # "" -- skip whitespace-only lines instead of crashing on fields[0].
            if not fields:
                continue
            label = fields[0]
            dst_ip = fields[2]
            start = int(fields[5]) * 60  # minutes -> seconds
            records.append([label, dst_ip, start])
    return records
# Attack labels counted as low-rate DoS; any other attack label counts as high-rate.
_LOW_RATE_ATTACKS = frozenset(
    {"slowread", "slowheaders", "slowbody2", "rudy", "slowloris"}
)


def get_label(srcIP, dstIP, time, num_ls, attacks=None) -> str:
    """Classify one packet as "high", "low" or "normal" and count it.

    Schedule entry ``i`` is active from its start time up to (but not
    including) the start time of entry ``i + 1``; the last entry stays active
    indefinitely. A packet takes the label of an active entry whose target IP
    equals its source or destination address; if several match, the last
    one wins.
    NOTE(review): the window logic assumes the schedule is sorted by start
    time -- confirm against attacks.txt.

    Args:
        srcIP: Packet source address.
        dstIP: Packet destination address.
        time: Packet timestamp, compared against the schedule's start seconds.
        num_ls: Mutable ``[high, low, normal]`` counters; the slot of the
            returned category is incremented in place.
        attacks: Schedule of ``[label, target_ip, start_seconds]`` entries.
            Defaults to the module-level ``attack_info`` list.

    Returns:
        "low" for a low-rate attack label, "normal" when no entry matched,
        otherwise "high".
    """
    if attacks is None:
        attacks = attack_info
    # Each window closes where the next entry starts; the last one is open-ended.
    window_ends = [entry[2] for entry in attacks[1:]] + [None]
    label = "normal"
    for (attack_label, target_ip, start), end in zip(attacks, window_ends):
        in_window = time >= start and (end is None or time < end)
        if in_window and (srcIP == target_ip or dstIP == target_ip):
            label = attack_label

    if label in _LOW_RATE_ATTACKS:
        num_ls[1] += 1
        return "low"
    if label == "normal":
        num_ls[2] += 1
        return "normal"
    num_ls[0] += 1  # every other attack label is high-rate
    return "high"
# Index of each packet category in a flow's leading [high, low, normal] counters.
_LABEL_SLOT = {"high": 0, "low": 1, "normal": 2}


def _record_packet(flows, finished, key, token, label, max_len):
    """Append one packet's token to the flow under ``key`` and bump its counter.

    A flow is ``[high, low, normal, token, token, ...]``. If the existing flow
    already holds ``max_len`` entries it is moved to ``finished`` first and a
    fresh flow is started with this packet (recorded exactly once).
    """
    current = flows.get(key)
    if current is not None and len(current) >= max_len:
        finished.append(flows.pop(key))
        current = None
    if current is None:
        current = flows[key] = [0, 0, 0]
    current.append(token)
    # Labels other than "high"/"low" fall into the "normal" slot.
    current[_LABEL_SLOT.get(label, 2)] += 1


def pre_process(wireshark_csv_path, meta_csv_path):
    """Group labelled packets into bidirectional flows and write them to CSV.

    Packets are read from ``wireshark_csv_path`` (columns Source, Destination,
    Time, Protocol) and labelled with ``get_label``. Each packet is appended
    to the flow of its address pair in the module-level ``flow`` dict. The
    first packet of a pair fixes the flow's direction: packets in that
    direction add ``"-" + protocol``, packets in the opposite direction
    ``"+" + protocol``. Once a flow holds ``feature_latitude`` entries, the
    next packet moves it to the module-level ``feature_ls`` and starts a new
    flow. Flows still incomplete at the end are not written; only their number
    is printed. All finished records are written to ``meta_csv_path``.
    """
    # Packet counters [high, low, normal], incremented by get_label.
    packet_num_ls = [0, 0, 0]
    df = pd.read_csv(wireshark_csv_path)
    # Zipping the columns avoids building a Series per row as iterrows() does.
    packets = zip(df["Source"], df["Destination"], df["Time"], df["Protocol"])
    for index, (srcIP, dstIP, time, protocol) in enumerate(packets):
        if index % 20000 == 0:
            print(index)  # progress
        label = get_label(srcIP, dstIP, time, packet_num_ls)
        # Tuple keys: plain concatenation collides, e.g. "1.1.1.1" + "11.1.1.1"
        # and "1.1.1.11" + "1.1.1.1" give the same string.
        forward_key = (srcIP, dstIP)
        reverse_key = (dstIP, srcIP)
        if forward_key not in flow and reverse_key in flow:
            # Reply direction of a flow first seen as dstIP -> srcIP.
            _record_packet(flow, feature_ls, reverse_key, "+" + protocol,
                           label, feature_latitude)
        else:
            # New flow, or same direction as the flow's first packet.
            _record_packet(flow, feature_ls, forward_key, "-" + protocol,
                           label, feature_latitude)
    # Packet type statistics [high, low, normal].
    print("数据包type统计:[high,low,normal]")
    print(packet_num_ls)
    print("--------------------")
    # Number of incomplete flows left in the dict (not written out).
    print("字典中剩余个数:%d" % (len(flow)))
    total_data = pd.DataFrame(data=feature_ls)
    total_data.to_csv(meta_csv_path, index=False,
                      encoding="utf-8", sep=',', mode='w', header=True)
    return
def main():
    """Run the pipeline: load the attack schedule, then build the flow features."""
    get_attack_info(attack_info_path)
    pre_process(wireshark_csv_path, meta_csv_path)


# NOTE(review): this runs on import as well as when executed as a script;
# an `if __name__ == "__main__":` guard would prevent the import-time run.
main()