diff --git a/include/lua_plugin_manage.h b/include/lua_plugin_manage.h index 99f1686..12bb1b8 100644 --- a/include/lua_plugin_manage.h +++ b/include/lua_plugin_manage.h @@ -1,6 +1,7 @@ #pragma once #include "stellar/stellar.h" +#include struct lua_config_spec { diff --git a/src/lua_binding_function.c b/src/lua_binding_function.c index 81be039..c308c41 100644 --- a/src/lua_binding_function.c +++ b/src/lua_binding_function.c @@ -575,4 +575,97 @@ err: lua_settop(L, 0); lua_pushboolean(L, 0); return 1; +} + +static const char lua_base64_encode_table[] = { + 'A', 'B', 'C', 'D', 'E', 'F', 'G', 'H', 'I', 'J', + 'K', 'L', 'M', 'N', 'O', 'P', 'Q', 'R', 'S', 'T', + 'U', 'V', 'W', 'X', 'Y', 'Z', 'a', 'b', 'c', 'd', + 'e', 'f', 'g', 'h', 'i', 'j', 'k', 'l', 'm', 'n', + 'o', 'p', 'q', 'r', 's', 't', 'u', 'v', 'w', 'x', + 'y', 'z', '0', '1', '2', '3', '4', '5', '6', '7', + '8', '9', '+', '/'}; + +static int lua_base64_encode(const char *indata, int inlen, char *outdata, int *outlen) +{ + if (indata == NULL || inlen <= 0) + { + return -1; + } + int i, j; + unsigned char num = inlen % 3; + if (outdata != NULL) + { + // 编码,3个字节一组,若数据总长度不是3的倍数,则跳过最后的 num 个字节数据 + for (i = 0, j = 0; i < inlen - num; i += 3, j += 4) + { + outdata[j] = lua_base64_encode_table[(unsigned char)indata[i] >> 2]; + outdata[j + 1] = lua_base64_encode_table[(((unsigned char)indata[i] & 0x03) << 4) | ((unsigned char)indata[i + 1] >> 4)]; + outdata[j + 2] = lua_base64_encode_table[(((unsigned char)indata[i + 1] & 0x0f) << 2) | ((unsigned char)indata[i + 2] >> 6)]; + outdata[j + 3] = lua_base64_encode_table[(unsigned char)indata[i + 2] & 0x3f]; + } + // 继续处理最后的 num 个字节的数据 + if (num == 1) + { // 余数为1,需补齐两个字节'=' + outdata[j] = lua_base64_encode_table[(unsigned char)indata[inlen - 1] >> 2]; + outdata[j + 1] = lua_base64_encode_table[((unsigned char)indata[inlen - 1] & 0x03) << 4]; + outdata[j + 2] = '='; + outdata[j + 3] = '='; + } + else if (num == 2) + { // 余数为2,需补齐一个字节'=' + outdata[j] = lua_base64_encode_table[(unsigned char)indata[inlen - 2] >> 2]; + outdata[j + 1] = lua_base64_encode_table[(((unsigned char)indata[inlen - 2] & 0x03) << 4) | ((unsigned char)indata[inlen - 1] >> 4)]; + outdata[j + 2] = lua_base64_encode_table[((unsigned char)indata[inlen - 1] & 0x0f) << 2]; + outdata[j + 3] = '='; + } + } + if (outlen != NULL) + { + *outlen = (inlen + (num == 0 ? 0 : 3 - num)) * 4 / 3; + } + + return 0; +} + +int lua_session_get0_current_payload(struct lua_state *state) +{ + lua_State *L = (lua_State *)state; + if (lua_gettop(L) != 1) + goto err; + + struct session *sess = (struct session *)lua_topointer(L, -1); + lua_settop(L, 0); + + size_t payload_len = 0; + const char *payload = session_get0_current_payload(sess, &payload_len); + +#if 0 + lua_pushlightuserdata(L, (void *)payload); + lua_pushinteger(L, (lua_Integer)payload_len); + return 2; +err: + lua_settop(L, 0); + lua_pushlightuserdata(L, NULL); + lua_pushboolean(L, 0); + return 2; +#else + char *payload_base64 = (char *)calloc(2, payload_len); + int payload_base64_len = 0; + if (!lua_base64_encode(payload, payload_len, payload_base64, &payload_base64_len)) + { + lua_pushstring(L, payload_base64); + lua_pushinteger(L, (lua_Integer)payload_len); + if (payload_base64) + free(payload_base64); + return 2; + } + if (payload_base64) + free(payload_base64); +err: + lua_settop(L, 0); + lua_pushstring(L, NULL); + lua_pushboolean(L, 0); + return 2; +#endif } \ No newline at end of file diff --git a/src/lua_binding_function.h b/src/lua_binding_function.h index b031c08..ba515cc 100644 --- a/src/lua_binding_function.h +++ b/src/lua_binding_function.h @@ -39,6 +39,9 @@ int lua_session_mq_topic_is_active(struct lua_state *state); int lua_session_mq_ignore_message(struct lua_state *state); int lua_session_mq_unignore_message(struct lua_state *state); +/* session相关其他函数 */ +int lua_session_get0_current_payload(struct lua_state *state); + /* ***** ***** ***** ***** ***** ***** */ /* 此部分为注册至C中的lua通用函数, 实现在lua_plugin_cfunc.c中 */ void *lpm_ctx_new_func(struct session *sess, void *plugin_env); diff --git a/src/lua_plugin_manage.c b/src/lua_plugin_manage.c index d8c7adb..6633d2a 100644 --- a/src/lua_plugin_manage.c +++ b/src/lua_plugin_manage.c @@ -35,6 +35,8 @@ struct lua_bind_function_spec lua_bind_function[] = { {lua_session_mq_publish_message, "publish_message", "session_mq"}, {lua_session_mq_ignore_message, "ignore_message", "session_mq"}, {lua_session_mq_unignore_message, "unignore_message", "session_mq"}, + + {lua_session_get0_current_payload, "get_payload", "session"}, {NULL, NULL, NULL}, }; @@ -632,17 +634,17 @@ struct lua_on_message_fn *hash_on_msg_fn_insert(struct lua_on_message_fn msg_fn_ { int hash_key = calc_on_message_func_hash_key(topic_id, plugin_id); struct lua_on_message_fn *insert_positon = &msg_fn_hashlist[hash_key]; - while ((insert_positon->on_use - HASH_MAX_NUM) > 0) + while ((insert_positon->hash_on_use - HASH_MAX_NUM) > 0) { - if (insert_positon->on_use % HASH_MAX_NUM == (HASH_MAX_NUM - 1)) + if (insert_positon->hash_on_use % HASH_MAX_NUM == (HASH_MAX_NUM - 1)) insert_positon = &msg_fn_hashlist[0]; else insert_positon++; /* 没有空位置了 */ - if (insert_positon->on_use % HASH_MAX_NUM == hash_key) + if (insert_positon->hash_on_use % HASH_MAX_NUM == hash_key) return NULL; } - insert_positon->on_use += (HASH_MAX_NUM + 1); + insert_positon->hash_on_use += (HASH_MAX_NUM + 1); insert_positon->topic_id = topic_id; insert_positon->plugin_id = plugin_id; return insert_positon; @@ -652,7 +654,7 @@ struct lua_on_message_fn *hash_find_on_msg_fn(struct lua_on_message_fn msg_fn_ha { int hash_key = calc_on_message_func_hash_key(topic_id, plugin_id); struct lua_on_message_fn *find_position = &msg_fn_hashlist[hash_key]; - if ((find_position->on_use - HASH_MAX_NUM) < 0) + if ((find_position->hash_on_use - HASH_MAX_NUM) < 0) return NULL; while (find_position) { @@ -660,11 +662,11 @@ struct lua_on_message_fn *hash_find_on_msg_fn(struct lua_on_message_fn msg_fn_ha { return find_position; } - if (find_position->on_use % HASH_MAX_NUM == (HASH_MAX_NUM - 1)) + if (find_position->hash_on_use % HASH_MAX_NUM == (HASH_MAX_NUM - 1)) find_position = &msg_fn_hashlist[0]; else find_position++; - if ((find_position->on_use % HASH_MAX_NUM == hash_key) || (find_position->on_use - HASH_MAX_NUM) < 0) + if ((find_position->hash_on_use % HASH_MAX_NUM == hash_key) || (find_position->hash_on_use - HASH_MAX_NUM) < 0) break; } return NULL; @@ -831,10 +833,12 @@ struct lua_plugin_manage *lua_plugin_manage_init( for (unsigned on_message_index = 0; on_message_index < HASH_MAX_NUM; on_message_index++) { - new_plugin_manage->on_session_message_hashlist[on_message_index].on_use = on_message_index; - new_plugin_manage->on_packet_message_hashlist[on_message_index].on_use = on_message_index; + new_plugin_manage->on_session_message_hashlist[on_message_index].hash_on_use = on_message_index; + new_plugin_manage->on_packet_message_hashlist[on_message_index].hash_on_use = on_message_index; } + if (specific_num == 0) + return new_plugin_manage; new_plugin_manage->load_script_array = (struct lua_load_script *)calloc(specific_num, sizeof(struct lua_load_script)); new_plugin_manage->load_script_num = specific_num; for (unsigned spec_index = 0; spec_index < specific_num; spec_index++) diff --git a/src/lua_plugin_manage_internal.h b/src/lua_plugin_manage_internal.h index 18ac5b6..b977a04 100644 --- a/src/lua_plugin_manage_internal.h +++ b/src/lua_plugin_manage_internal.h @@ -178,7 +178,7 @@ struct lua_message_free_arg struct lua_on_message_fn { - int on_use; + int hash_on_use; int topic_id; int plugin_id; int lua_on_msg_fn_ref_id;