# Pre-processing script for the CICDoS2017 dataset: converts a per-packet
# Wireshark CSV export into per-flow feature rows written to meta.csv.
import pandas as pd
|
|
import os
|
|
from sklearn import tree
|
|
from sklearn.metrics import classification_report
|
|
from sklearn.metrics import roc_curve, auc, roc_auc_score
|
|
from sklearn.ensemble import RandomForestClassifier
|
|
from sklearn.ensemble import GradientBoostingClassifier
|
|
from sklearn.naive_bayes import BernoulliNB
|
|
from sklearn.naive_bayes import GaussianNB
|
|
from sklearn.neighbors import KNeighborsClassifier as kNN
|
|
from sklearn.metrics import confusion_matrix
|
|
from pandas.core.frame import DataFrame
|
|
import datetime
|
|
import numpy as np
|
|
import pandas as pd
|
|
import joblib
|
|
from sklearn.model_selection import train_test_split
|
|
|
|
|
|
# Input: per-packet CSV exported from Wireshark; output: per-flow feature rows.
wireshark_csv_path = "cicdos2017/dataset/cicdos2017_from_wireshark.csv"
meta_csv_path = "cicdos2017/dataset/meta.csv"

# In-progress flows keyed by srcIP+dstIP (or the reverse pair); value layout:
# [high_count, low_count, normal_count, tagged_proto_1, tagged_proto_2, ...]
flow = {}

# Feature vector length: the first three elements record the packet-type
# counts [high, low, normal]; the remaining 40 slots hold the
# direction-tagged protocol string of each packet in the flow.
feature_latitude = 40 + 3

# Completed flow feature rows, eventually dumped to meta_csv_path.
feature_ls = []

# attacks.txt: parsed into [label, dstIP, start_time_in_seconds] triples.
attack_info = []
attack_info_path = "cicdos2017/dataset/attacks.txt"
|
|
|
|
|
|
def get_attack_info(file_path, sink=None):
    """Parse the attack schedule file into [label, dstIP, start_time] rows.

    Each non-blank line is whitespace-separated: field 0 is the attack
    label, field 2 the victim IP, and field 5 the attack start time in
    minutes (converted to seconds here). Parsing stops at the first blank
    line, matching the original intent.

    Args:
        file_path: path to the attacks.txt description file.
        sink: optional list the parsed rows are appended to; defaults to the
            module-level ``attack_info`` list (backward compatible).
    """
    if sink is None:
        sink = attack_info
    with open(file_path) as f:
        for line in f:
            # Bug fix: lines read from a file keep their trailing newline,
            # so the original `line == ""` test never fired and a blank line
            # crashed on `ls[0]`. Splitting first makes the check reliable.
            fields = line.split()
            if not fields:
                break
            label = fields[0]
            dst_ip = fields[2]
            start_s = int(fields[5]) * 60  # minutes -> seconds
            sink.append([label, dst_ip, start_s])
|
|
|
|
|
|
def get_label(srcIP, dstIP, time, num_ls, info=None) -> str:
    """Classify one packet as "high", "low" or "normal" DoS traffic.

    An attack entry [label, ip, start] is active from its start time up to
    (but excluding) the next entry's start time; the final entry stays
    active indefinitely. A packet matches an entry when either endpoint IP
    equals the attack IP and the timestamp falls inside the active window;
    the last matching entry wins (same as the original loop, which never
    broke early).

    Args:
        srcIP: packet source IP string.
        dstIP: packet destination IP string.
        time: packet timestamp in seconds.
        num_ls: 3-element counter [high, low, normal], mutated in place.
        info: optional attack table; defaults to the module-level
            ``attack_info`` list (backward compatible).

    Returns:
        One of "high", "low" or "normal".
    """
    if info is None:
        info = attack_info
    label = "normal"
    last = len(info) - 1
    for i, (atk_label, atk_ip, start) in enumerate(info):
        if srcIP != atk_ip and dstIP != atk_ip:
            continue
        # Unifies the original duplicated branches: the last entry's window
        # is open-ended, every other window ends at the next entry's start.
        end = float("inf") if i == last else info[i + 1][2]
        if start <= time < end:
            label = atk_label
    # Slow-rate DoS tools form the "low" class; every other attack label
    # collapses to "high"; unmatched packets stay "normal".
    if label in ("slowread", "slowheaders", "slowbody2", "rudy", "slowloris"):
        num_ls[1] += 1
        return "low"
    if label == "normal":
        num_ls[2] += 1
        return "normal"
    num_ls[0] += 1
    return "high"
|
|
|
|
|
|
def pre_process(wireshark_csv_path,meta_csv_path):
    """Aggregate labeled packets into per-flow feature rows and write meta.csv.

    Reads the Wireshark per-packet CSV, labels each packet via get_label(),
    and groups packets into bidirectional flows keyed by the concatenated
    IP pair. A flow row is [high_count, low_count, normal_count,
    tagged_proto, ...]; protocol tags carry "-" for the forward direction
    (same orientation as the flow's first packet) and "+" for the reverse.
    When a flow reaches feature_latitude entries it is archived to the
    module-level feature_ls and a fresh flow is started under the same key.
    Finally feature_ls is written to meta_csv_path as CSV.

    NOTE(review): indentation was reconstructed from a whitespace-mangled
    source; the recreate-vs-append branches are inferred from the `=1`
    initialization (which only makes sense as the `if` arm of an
    if/else) — confirm against the original file.
    """
    # packet type [high,low,normal]
    packet_num_ls = [0, 0, 0]
    df = pd.read_csv(wireshark_csv_path)
    for index, row in df.iterrows():
        if index % 20000 == 0:
            print(index)  # progress indicator for long runs
        srcIP = row["Source"]
        dstIP = row["Destination"]
        time = row["Time"]  # s
        protocol = row["Protocol"]
        length = row["Length"]  # NOTE(review): read but never used below
        # get label
        label = get_label(srcIP, dstIP, time, packet_num_ls)
        # get packet type

        # key: flows are bidirectional, so a packet belongs either to the
        # forward key (srcIP+dstIP) or the reverse key (dstIP+srcIP)
        key = srcIP+dstIP
        reverse_key = dstIP+srcIP

        if not flow.__contains__(key) and not flow.__contains__(reverse_key):
            # create a new flow; "-" tags the forward direction
            flow[key] = [0,0,0, "-"+protocol]
            if label =="high":
                flow[key][0]=1
            elif label == "low":
                flow[key][1]=1
            else:
                flow[key][2]=1
        elif flow.__contains__(key):
            # packet joins the existing forward-keyed flow
            if len(flow[key]) >= feature_latitude:
                # flow reached full feature size: archive it and start a
                # fresh one under the same key, seeded with this packet
                feature_ls.append(flow.pop(key))
                flow[key] = [0,0,0, "-"+protocol]
                if label =="high":
                    flow[key][0]=1
                elif label == "low":
                    flow[key][1]=1
                else:
                    flow[key][2]=1
            else:
                flow[key].append("-"+protocol)
                # update type counts
                if label =="high":
                    flow[key][0]+=1
                elif label == "low":
                    flow[key][1]+=1
                else:
                    flow[key][2]+=1
        else:
            # reverse key: packet travels opposite to the flow's first
            # packet; "+" tags the reverse direction
            if len(flow[reverse_key]) >= feature_latitude:
                feature_ls.append(flow.pop(reverse_key))
                flow[reverse_key] = [0,0,0, "+"+protocol]
                if label =="high":
                    flow[reverse_key][0]=1
                elif label == "low":
                    flow[reverse_key][1]=1
                else:
                    flow[reverse_key][2]=1
            else:
                flow[reverse_key].append("+"+protocol)
                # update type counts
                if label =="high":
                    flow[reverse_key][0]+=1
                elif label == "low":
                    flow[reverse_key][1]+=1
                else:
                    flow[reverse_key][2]+=1

    print("数据包type统计:[high,low,normal]")
    print(packet_num_ls)
    print("--------------------")
    print("字典中剩余个数:%d" % (len(flow)))
    # write to csv file
    total_data = pd.DataFrame(data=feature_ls)
    # print(total_data)
    total_data.to_csv(meta_csv_path, index=False,
                      encoding="utf-8", sep=',', mode='w', header=True)
    return
|
|
|
|
|
|
# Script entry: parse the attack schedule first (get_label depends on the
# populated attack_info list), then build and write the per-flow meta.csv.
get_attack_info(attack_info_path)
pre_process(wireshark_csv_path,meta_csv_path)
|