#include "monitor_mem.h"
#include <linux/highmem.h>  // kmap, kunmap
#include <linux/list.h>     // list_add_tail, list_for_each_entry_safe
#include <linux/mm.h>       // for FOLL_FORCE, get_user_pages_remote, put_page
#include <linux/pid.h>      // pid pid_task
#include <linux/printk.h>   // printk, printk_ratelimited
#include <linux/rcupdate.h> // rcu_read_lock
#include <linux/sched/mm.h> // get_task_mm, mmput
#include <linux/slab.h>     /* for kmalloc */

/// @brief Pin the user page of process @p pid that contains @p addr and map
///        it into kernel space.
///
/// The page is obtained with get_user_pages_remote() using FOLL_FORCE and no
/// FOLL_WRITE (a read-only lookup), mapped with kmap(), and recorded as a node
/// tagged with @p pid on watch_local_memory_list. The node, the page reference
/// and the mapping stay alive until free_page_list(pid) or
/// free_all_page_list() releases them.
///
/// Only the single page containing @p addr is mapped.
/// NOTE(review): a value that starts near the end of that page and is wider
/// than the remaining bytes would be read past the mapping; callers should
/// ensure (addr & ~PAGE_MASK) + len <= PAGE_SIZE — confirm at call sites.
/// NOTE(review): watch_local_memory_list is modified here without any lock;
/// concurrent callers must be serialized elsewhere — confirm.
///
/// @param pid: process id, looked up with find_vpid() (current pid namespace)
/// @param addr: user space virtual address inside process @p pid
/// @return kernel virtual address corresponding to @p addr, or NULL on failure
void *convert_user_space_ptr(pid_t pid, unsigned long addr)
{
    unsigned long aligned_addr = addr & PAGE_MASK; // page start, for GUP
    unsigned long offset = addr & ~PAGE_MASK;      // addr's offset in the page
    struct task_struct *task;
    struct mm_struct *mm = NULL;
    watch_local_memory *node;
    long ret;

    printk(KERN_INFO "%s\n", __func__);

    // Zeroed so page/kaddr stay NULL until they are actually set up.
    node = kzalloc(sizeof(*node), GFP_KERNEL);
    if (!node)
        return NULL;
    node->task_id = pid;

    // The task pointer is only guaranteed valid inside the RCU read-side
    // section, so the mm reference is taken there. After rcu_read_unlock()
    // task is only tested against NULL, never dereferenced.
    rcu_read_lock();
    task = pid_task(find_vpid(pid), PIDTYPE_PID);
    if (task)
        mm = get_task_mm(task);
    rcu_read_unlock();

    if (!task)
    {
        printk(KERN_ERR "Cannot find task for PID %d\n", pid);
        goto err_free_node;
    }
    if (!mm)
    {
        printk(KERN_ERR "Cannot get memory descriptor\n");
        goto err_free_node;
    }

    mmap_read_lock(mm);
    ret = get_user_pages_remote(mm, aligned_addr, 1, FOLL_FORCE, &node->page,
                                NULL, NULL);
    mmap_read_unlock(mm);
    // Drop the reference taken by get_task_mm(); a pinned page holds its own
    // reference, so the mm is no longer needed on either path.
    mmput(mm);

    if (ret != 1)
    {
        printk(KERN_ERR "Cannot get user page\n");
        goto err_free_node;
    }

    node->kaddr = kmap(node->page);
    list_add_tail(&node->entry, &watch_local_memory_list);

    return (void *)((unsigned long)node->kaddr + offset);

err_free_node:
    kfree(node);
    return NULL;
}

/// @brief Undo convert_user_space_ptr() for one node: unmap it, drop the page
///        reference, unlink it from its list and free it.
/// @param node: node currently linked on watch_local_memory_list
static void release_watch_node(watch_local_memory *node)
{
    if (node->page)
    {
        // kunmap() takes the struct page, not the address kmap() returned.
        if (node->kaddr)
            kunmap(node->page);
        put_page(node->page);
    }
    list_del(&node->entry);
    kfree(node);
}

/// @brief Release every node on watch_local_memory_list created for @p task_id.
/// @param task_id: pid that was passed to convert_user_space_ptr()
void free_page_list(pid_t task_id)
{
    watch_local_memory *node, *next;

    // _safe iteration: release_watch_node() unlinks and frees the current node.
    list_for_each_entry_safe(node, next, &watch_local_memory_list, entry)
    {
        if (node->task_id == task_id)
            release_watch_node(node);
    }
}

/// @brief Release every node on watch_local_memory_list, whatever its task.
/// @param void
void free_all_page_list(void)
{
    watch_local_memory *node, *next;

    // _safe iteration: release_watch_node() unlinks and frees the current node.
    list_for_each_entry_safe(node, next, &watch_local_memory_list, entry)
        release_watch_node(node);
}

// for read_and_compare
typedef unsigned char (*compare_func)(void *, long long);

// Each compare_* helper below reads one integer of the named width and
// signedness at ptr and returns 1 if it is strictly greater than threshold,
// otherwise 0.

unsigned char compare_1_byte_signed(void *ptr, long long threshold)
{
    // signed char, not plain char: plain char is unsigned on some
    // architectures (e.g. ARM), which would make this an unsigned compare.
    return *(const signed char *)ptr > threshold;
}

unsigned char compare_1_byte_unsigned(void *ptr, long long threshold)
{
    return *(const unsigned char *)ptr > threshold;
}

unsigned char compare_2_byte_signed(void *ptr, long long threshold)
{
    return *(const short int *)ptr > threshold;
}

unsigned char compare_2_byte_unsigned(void *ptr, long long threshold)
{
    return *(const unsigned short int *)ptr > threshold;
}

unsigned char compare_4_byte_signed(void *ptr, long long threshold)
{
    return *(const int *)ptr > threshold;
}

unsigned char compare_4_byte_unsigned(void *ptr, long long threshold)
{
    // unsigned int is promoted to long long, so negative thresholds work.
    return *(const unsigned int *)ptr > threshold;
}

unsigned char compare_8_byte_signed(void *ptr, long long threshold)
{
    return *(const long long *)ptr > threshold;
}

unsigned char compare_8_byte_unsigned(void *ptr, long long threshold)
{
    // A direct comparison would convert a negative threshold to a huge
    // unsigned value. Every unsigned value exceeds a negative threshold,
    // matching the 1/2/4-byte unsigned variants above.
    if (threshold < 0)
        return 1;
    return *(const unsigned long long *)ptr > (unsigned long long)threshold;
}

// list of compare functions, indexed through func_indices
static const compare_func compare_funcs[8] = {
    compare_1_byte_signed,   compare_2_byte_signed,
    compare_4_byte_signed,   compare_8_byte_signed,
    compare_1_byte_unsigned, compare_2_byte_unsigned,
    compare_4_byte_unsigned, compare_8_byte_unsigned};

// func_indices[is_unsigned][len] -> slot in compare_funcs. Only lengths
// 1, 2, 4 and 8 are meaningful; read_and_compare() rejects all others.
static const int func_indices[2][9] = {
    {[1] = 0, [2] = 1, [4] = 2, [8] = 3},
    {[1] = 4, [2] = 5, [4] = 6, [8] = 7}};

/// @brief Read the integer at @p ptr and compare it with @p threshold.
/// @param ptr: address of the value to read
/// @param len: width of the value in bytes; must be 1, 2, 4 or 8
/// @param above_threshold: non-zero -> report value > threshold;
///                         zero     -> report value <= threshold
/// @param is_unsigned: non-zero to interpret the value as unsigned
/// @param threshold: value to compare against
/// @return 1 if the requested condition holds, otherwise 0 (also 0 for an
///         invalid @p len)
unsigned char read_and_compare(void *ptr, int len, unsigned char above_threshold,
                               unsigned char is_unsigned, long long threshold)
{
    unsigned char greater;

    // Without this check, len outside [0, 8] indexes past func_indices and
    // lengths such as 3 silently fall back to a 1-byte signed read.
    if (len != 1 && len != 2 && len != 4 && len != 8)
    {
        printk_ratelimited(KERN_ERR "Invalid length %d\n", len);
        return 0;
    }

    // !! maps any non-zero flag to row 1 so the table index stays in range.
    greater = compare_funcs[func_indices[!!is_unsigned][len]](ptr, threshold);

    return above_threshold ? greater : !greater;
}