diff --git a/src/tsg_rule.cpp b/src/tsg_rule.cpp index db1c67d..cbe4876 100644 --- a/src/tsg_rule.cpp +++ b/src/tsg_rule.cpp @@ -11,33 +11,33 @@ enum kni_scan_table{ SCAN_TABLE_MAX }; -struct kni_protocol_identify_result{ - enum tsg_protocol protocol; - char domain[TSG_DOMAIN_MAX]; - int domain_len; -}; - Maat_feather_t g_kni_maat_feather; const char *g_kni_scan_table_name[SCAN_TABLE_MAX]; int g_kni_scan_tableid[SCAN_TABLE_MAX] = {0}; -static void protocol_identify(char *buff, int buff_len, struct kni_protocol_identify_result *result){ - result->protocol = TSG_PROTOCOL_UNKNOWN; +static void protocol_identify(char *buff, int buff_len, struct _identify_info *result){ + result->protocol = -1; //TODO: http: get from http protocol plugin - + /* + if(is_http){ + result->protocol = PROTO_HTTP; + return; + } + */ //ssl enum chello_parse_result chello_status = CHELLO_PARSE_INVALID_FORMAT; struct ssl_chello *chello = NULL; chello = ssl_chello_parse((const unsigned char*)buff, buff_len, &chello_status); if(chello_status == CHELLO_PARSE_SUCCESS){ - result->protocol = TSG_PROTOCOL_SSL; + result->protocol = PROTO_SSL; if(chello->sni == NULL){ result->domain_len = 0; } else{ - strncpy(result->domain, chello->sni, strnlen(chello->sni, sizeof(result->domain) - 1)); - result->domain_len = strlen(result->domain); + result->domain_len = strnlen(chello->sni, sizeof(result->domain) - 1); + strncpy(result->domain, chello->sni, result->domain_len); } + result->domain[result->domain_len] = '\0'; } ssl_chello_free(chello); return; @@ -60,25 +60,22 @@ int tsg_shared_table_init(const char *conffile, Maat_feather_t maat_feather, voi return 0; } -//return -1 if failed, return 0 on success -int tsg_scan_shared_policy(Maat_feather_t maat_feather, void *pkt, int pkt_len, Maat_rule_t *result, int result_num, enum tsg_protocol *protocol, char *domain, int *domain_len, - scan_status_t *mid, void *logger, int thread_seq){ - struct kni_protocol_identify_result protocol_identify_res; - memset(&protocol_identify_res, 0, sizeof(protocol_identify_res)); - protocol_identify((char*)pkt, pkt_len, &protocol_identify_res); - if(protocol_identify_res.protocol == TSG_PROTOCOL_UNKNOWN){ + +//return value: -1: failed, 0: not hit, >0: hit count +int tsg_scan_shared_policy(Maat_feather_t maat_feather, void *pkt, int pkt_len, Maat_rule_t *result, int result_num, + struct _identify_info *identify_info, scan_status_t *mid, void *logger, int thread_seq){ + memset(identify_info, 0, sizeof(*identify_info)); + protocol_identify((char*)pkt, pkt_len, identify_info); + if(identify_info->protocol != TSG_PROTOCOL_SSL && identify_info->protocol != TSG_PROTOCOL_HTTP){ return -1; } - *protocol = protocol_identify_res.protocol; - *domain_len = protocol_identify_res.domain_len; - strncpy(domain, protocol_identify_res.domain, *domain_len); int tableid; - if(protocol_identify_res.protocol == TSG_PROTOCOL_SSL){ + if(identify_info->protocol == TSG_PROTOCOL_SSL){ tableid = g_kni_scan_tableid[TSG_FIELD_SSL_SNI]; } - if(protocol_identify_res.protocol == TSG_PROTOCOL_HTTP){ + else{ tableid = g_kni_scan_tableid[TSG_FIELD_HTTP_HOST]; } - return Maat_full_scan_string(g_kni_maat_feather, tableid, CHARSET_UTF8, domain, *domain_len, + return Maat_full_scan_string(g_kni_maat_feather, tableid, CHARSET_UTF8, identify_info->domain, identify_info->domain_len, result, NULL, result_num, mid, thread_seq); } \ No newline at end of file