package cn.mesalab.service;

import cn.mesalab.config.ApplicationConfig;
import cn.mesalab.dao.DruidData;
import cn.mesalab.service.algorithm.KalmanFilter;
import cn.mesalab.utils.HbaseUtils;
import cn.mesalab.utils.SeriesUtils;
import com.google.common.collect.Lists;
import com.google.common.util.concurrent.ThreadFactoryBuilder;
import org.apache.commons.math3.stat.StatUtils;
import org.apache.hadoop.hbase.client.Put;
import org.apache.hadoop.hbase.client.Table;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

import java.io.IOException;
import java.util.*;
import java.util.concurrent.*;
import java.util.stream.Collectors;

/**
 * Baseline generation and persistence.
 *
 * <p>Reads per-IP historical time series from Druid, computes a traffic
 * baseline for every (ip, attackType) pair and writes the results to HBase.
 * Batches of IPs are processed in parallel on a fixed-size thread pool.
 *
 * @author yjy
 * @version 1.0
 * @date 2021/7/23 5:38 下午
 */
public class BaselineGeneration {

    private static final Logger LOG = LoggerFactory.getLogger(BaselineGeneration.class);

    private static DruidData druidData;
    private static HbaseUtils hbaseUtils;
    private static Table hbaseTable;

    /** Attack types for which a baseline is generated. */
    private static final List<String> ATTACK_TYPE_LIST = Arrays.asList(
            ApplicationConfig.DRUID_ATTACKTYPE_TCP_SYN_FLOOD,
            ApplicationConfig.DRUID_ATTACKTYPE_ICMP_FLOOD,
            ApplicationConfig.DRUID_ATTACKTYPE_UDP_FLOOD,
            ApplicationConfig.DRUID_ATTACKTYPE_DNS_AMPL
    );

    /** Points per baseline: range-days * 24 hours * points-per-hour. */
    private static final Integer BASELINE_POINT_NUM =
            ApplicationConfig.BASELINE_RANGE_DAYS * 24 * (60 / ApplicationConfig.HISTORICAL_GRAD);

    /**
     * Entry point of the generation run: opens the Druid and HBase handles,
     * generates and writes all baselines, then closes the connections and
     * terminates the JVM.
     */
    public static void perform() {
        long start = System.currentTimeMillis();
        druidData = DruidData.getInstance();
        hbaseUtils = HbaseUtils.getInstance();
        hbaseTable = hbaseUtils.getHbaseTable();
        LOG.info("Druid 成功建立连接");
        try {
            // Generate baselines and write them to HBase (multi-threaded).
            generateBaselinesThread();
            long last = System.currentTimeMillis();
            LOG.warn("运行时间:" + (last - start));
            druidData.closeConn();
            hbaseTable.close();
            LOG.info("Druid 关闭连接");
        } catch (Exception e) {
            // BUG FIX: was e.printStackTrace(); log with the cause so the
            // failure and its stack trace appear in the application log.
            LOG.error("Baseline generation run failed", e);
        }
        System.exit(0);
    }

    /**
     * Multi-threaded baseline generation entry point: partitions the server
     * IP list into batches and submits one task per batch to a fixed pool.
     *
     * @throws InterruptedException if interrupted while awaiting pool shutdown
     */
    private static void generateBaselinesThread() throws InterruptedException {
        int threadNum = Runtime.getRuntime().availableProcessors();
        ThreadFactory namedThreadFactory =
                new ThreadFactoryBuilder().setNameFormat("baseline-demo-%d").build();
        // Fixed-size pool; AbortPolicy rejects work if the 1024-slot queue fills.
        ThreadPoolExecutor executor = new ThreadPoolExecutor(
                threadNum, threadNum,
                0L, TimeUnit.MILLISECONDS,
                new LinkedBlockingQueue<>(1024),
                namedThreadFactory,
                new ThreadPoolExecutor.AbortPolicy());
        // Fetch the full server-side IP list to process.
        ArrayList<String> destinationIps = druidData.getServerIpList();
        LOG.info("共查询到服务端ip " + destinationIps.size() + " 个");
        LOG.info("Baseline batch 大小: " + ApplicationConfig.GENERATE_BATCH_SIZE);
        // Process IPs in fixed-size batches.
        List<List<String>> batchIpLists =
                Lists.partition(destinationIps, ApplicationConfig.GENERATE_BATCH_SIZE);
        for (List<String> batchIps : batchIpLists) {
            if (!batchIps.isEmpty()) {
                executor.execute(() -> generateBaselines(batchIps));
            }
        }
        executor.shutdown();
        // BUG FIX: the awaitTermination result was silently discarded; a
        // timeout now leaves an explicit trace in the log.
        if (!executor.awaitTermination(10L, TimeUnit.HOURS)) {
            LOG.warn("Baseline executor did not terminate within 10 hours; some batches may be incomplete");
        }
    }

    /**
     * Generates and persists baselines for one batch of IPs.
     *
     * @param ipList batch of server IPs to process
     */
    public static void generateBaselines(List<String> ipList) {
        // THREAD-SAFETY FIX: the Druid handle and the batch query result used
        // to live in static fields shared by every worker thread, so
        // concurrent batches overwrote each other's data. Both are now local
        // to this call and passed down explicitly.
        DruidData druid = DruidData.getInstance();
        List<Map<String, Object>> batchData = druid.readFromDruid(ipList);
        List<Put> putList = new ArrayList<>();
        for (String attackType : ATTACK_TYPE_LIST) {
            for (String ip : ipList) {
                int[] ipBaseline = generateSingleIpBaseline(druid, batchData, ip, attackType);
                if (ipBaseline != null) {
                    putList = hbaseUtils.cachedInPut(
                            putList, ip, ipBaseline, attackType, ApplicationConfig.BASELINE_METRIC_TYPE);
                }
            }
        }
        try {
            hbaseTable.put(putList);
            LOG.info("Baseline 线程 " + Thread.currentThread().getId()
                    + " 成功写入Baseline条数共计 " + putList.size());
        } catch (IOException e) {
            // BUG FIX: was e.printStackTrace(); keep the cause in the log.
            LOG.error("HBase put failed for a batch of " + ipList.size() + " ips", e);
        }
        // NOTE(review): if DruidData.getInstance() returns a shared singleton,
        // closing it here while sibling worker threads are still querying is a
        // race — confirm getInstance() hands out per-caller connections.
        druid.closeConn();
    }

    /**
     * Builds the baseline series for a single IP and attack type.
     *
     * @param druid      Druid accessor owned by the calling worker thread
     * @param batchData  pre-fetched Druid rows for the whole IP batch
     *                   (NOTE(review): element type reconstructed as
     *                   {@code Map<String,Object>} — generics were lost in the
     *                   original source; confirm against DruidData)
     * @param ip         the IP to build a baseline for
     * @param attackType attack type dimension
     * @return baseline of length {@code 60/HISTORICAL_GRAD*24*RANGE_DAYS},
     *         or {@code null} when the IP has no data for this attack type
     */
    private static int[] generateSingleIpBaseline(
            DruidData druid, List<Map<String, Object>> batchData, String ip, String attackType) {
        // Query this ip's raw series out of the pre-fetched batch data.
        List<Map<String, Object>> originSeries = druid.getTimeSeriesData(batchData, ip, attackType);
        if (originSeries.isEmpty()) {
            return null;
        }
        // Fill missing time points with 0 so the series is evenly spaced.
        List<Map<String, Object>> completSeries = SeriesUtils.complementSeries(originSeries);
        int[] baselineArr = new int[BASELINE_POINT_NUM];
        List<Integer> series = completSeries.stream()
                .map(i -> Integer.valueOf(i.get(ApplicationConfig.BASELINE_METRIC_TYPE).toString()))
                .collect(Collectors.toList());
        // Branch on how often the ip actually appears in the window.
        if (originSeries.size() / (float) completSeries.size() > ApplicationConfig.BASELINE_HISTORICAL_RATIO) {
            // DEAD-CODE FIX: the original computed StatUtils.percentile, filled
            // the array with it, then immediately overwrote the array with
            // baselineFunction(series) — the percentile fill had no effect and
            // has been removed; net behavior is unchanged.
            // NOTE(review): the original labelled this branch 低频率 ("low
            // frequency"), which contradicts the `>` ratio check — confirm the
            // intended semantics of the two branches.
            baselineArr = baselineFunction(series);
        } else {
            if (SeriesUtils.isPeriod(series)) {
                // Periodic series: forecast a point-by-point baseline.
                baselineArr = baselineFunction(series);
            } else {
                // Aperiodic: flat baseline at a percentile of observed values.
                int ipPercentile = SeriesUtils.percentile(
                        originSeries.stream()
                                .map(i -> Integer.valueOf(i.get(ApplicationConfig.BASELINE_METRIC_TYPE).toString()))
                                .collect(Collectors.toList()),
                        ApplicationConfig.BASELINE_RATIONAL_PERCENTILE);
                Arrays.fill(baselineArr, ipPercentile);
            }
        }
        return baselineArr;
    }

    /**
     * Baseline generation algorithm, selected by configuration.
     *
     * @param timeSeries input series
     * @return forecast series of length {@code BASELINE_POINT_NUM}
     */
    private static int[] baselineFunction(List<Integer> timeSeries) {
        switch (ApplicationConfig.BASELINE_FUNCTION) {
            case "KalmanFilter":
                KalmanFilter kalmanFilter = new KalmanFilter();
                kalmanFilter.forcast(timeSeries, BASELINE_POINT_NUM);
                return kalmanFilter.getForecastSeries().stream()
                        .mapToInt(Integer::valueOf).toArray();
            default:
                // NOTE(review): subList throws IndexOutOfBoundsException when
                // the series is shorter than BASELINE_POINT_NUM — confirm
                // upstream (complementSeries) guarantees the length.
                return timeSeries.subList(0, BASELINE_POINT_NUM).stream()
                        .mapToInt(Integer::valueOf).toArray();
        }
    }

    public static void main(String[] args) {
        perform();
    }
}