#include #include #include #include #include #include #define smaller(root, a, b) (root->cmp((a)->value, (b)->value) == -1) #define equal(root, a, b) (root->cmp((a)->value, (b)->value) == 0) #define greater(root, a, b) (root->cmp((a)->value, (b)->value) == 1) /** * @brief 创建二叉搜索树 * * @param node 根节点 * @param cmp 比较函数 * @param release 用来释放结点的value的函数 * @return struct bt_root_t* 树根结构体 */ struct bt_root_t *bt_create_tree(struct bt_node_t *node, int (*cmp)(void *a, void *b), int (*release)(void *value)) { if (node == NULL || cmp == NULL) return (void*)-EINVAL; struct bt_root_t *root = (struct bt_root_t *)kmalloc(sizeof(struct bt_root_t), 0); memset((void *)root, 0, sizeof(struct bt_root_t)); root->bt_node = node; root->cmp = cmp; root->release = release; root->size = (node == NULL) ? 0 : 1; return root; } /** * @brief 创建结点 * * @param left 左子节点 * @param right 右子节点 * @param value 当前节点的值 * @return struct bt_node_t* */ struct bt_node_t *bt_create_node(struct bt_node_t *left, struct bt_node_t *right, struct bt_node_t *parent, void *value) { struct bt_node_t *node = (struct bt_node_t *)kmalloc(sizeof(struct bt_node_t), 0); FAIL_ON_TO(node == NULL, nomem); memset((void *)node, 0, sizeof(struct bt_node_t)); node->left = left; node->right = right; node->value = value; node->parent = parent; return node; nomem:; return (void*)-ENOMEM; } /** * @brief 插入结点 * * @param root 树根结点 * @param value 待插入结点的值 * @return int 返回码 */ int bt_insert(struct bt_root_t *root, void *value) { if (root == NULL) return -EINVAL; struct bt_node_t *this_node = root->bt_node; struct bt_node_t *last_node = NULL; struct bt_node_t *insert_node = bt_create_node(NULL, NULL, NULL, value); FAIL_ON_TO((uint64_t)insert_node == (uint64_t)(-ENOMEM), failed); while (this_node != NULL) { last_node = this_node; if (smaller(root, insert_node, this_node)) this_node = this_node->left; else this_node = this_node->right; } insert_node->parent = last_node; if (unlikely(last_node == NULL)) root->bt_node = insert_node; else { if (smaller(root, insert_node, last_node)) last_node->left = insert_node; else last_node->right = insert_node; } ++root->size; return 0; failed:; return -ENOMEM; } /** * @brief 搜索值为value的结点 * * @param value 值 * @param ret_addr 返回的结点基地址 * @return int 错误码 */ int bt_query(struct bt_root_t *root, void *value, uint64_t *ret_addr) { struct bt_node_t *this_node = root->bt_node; struct bt_node_t tmp_node = {0}; tmp_node.value = value; // 如果返回地址为0 if (ret_addr == NULL) return -EINVAL; while (this_node != NULL && !equal(root, this_node, &tmp_node)) { if (smaller(root, &tmp_node, this_node)) this_node = this_node->left; else this_node = this_node->right; } if (this_node != NULL && equal(root, this_node, &tmp_node)) { *ret_addr = (uint64_t)this_node; return 0; } else { // 找不到则返回-1,且addr设为0 *ret_addr = NULL; return -1; } } static struct bt_node_t *bt_get_minimum(struct bt_node_t *this_node) { while (this_node->left != NULL) this_node = this_node->left; return this_node; } /** * @brief 删除结点 * * @param root 树根 * @param value 待删除结点的值 * @return int 返回码 */ int bt_delete(struct bt_root_t *root, void *value) { uint64_t tmp_addr; int retval; // 寻找待删除结点 retval = bt_query(root, value, &tmp_addr); if (retval != 0 || tmp_addr == NULL) return retval; struct bt_node_t *this_node = (struct bt_node_t *)tmp_addr; struct bt_node_t *to_delete = NULL, *to_delete_son = NULL; if (this_node->left == NULL || this_node->right == NULL) to_delete = this_node; else { to_delete = bt_get_minimum(this_node->right); // 释放要被删除的值,并把下一个结点的值替换上来 root->release(this_node->value); this_node->value = to_delete->value; } if (to_delete->left != NULL) to_delete_son = to_delete->left; else to_delete_son = to_delete->right; if (to_delete_son != NULL) to_delete_son->parent = to_delete->parent; if (to_delete->parent == NULL) root->bt_node = to_delete_son; else { if (to_delete->parent->left == to_delete) to_delete->parent->left = to_delete_son; else to_delete->parent->right = to_delete_son; } --root->size; // 释放最终要删除的结点的对象 kfree(to_delete); } /** * @brief 释放整个二叉搜索树 * * @param root 树的根节点 * @return int 错误码 */ int bt_destroy_tree(struct bt_root_t *root) { // 新建一个kfifo缓冲区,将指向结点的指针存入fifo队列 // 注:为了将指针指向的地址存入队列,我们需要对指针取地址 struct kfifo_t fifo; kfifo_alloc(&fifo, ((root->size + 1) / 2) * sizeof(struct bt_node_t *), 0); kfifo_in(&fifo, (void *)&(root->bt_node), sizeof(struct bt_node_t *)); // bfs while (!kfifo_empty(&fifo)) { // 取出队列头部的结点指针 struct bt_node_t *nd; int count = kfifo_out(&fifo, &nd, sizeof(uint64_t)); // 将子节点加入队列 if (nd->left != NULL) kfifo_in(&fifo, (void *)&(nd->left), sizeof(struct bt_node_t *)); if (nd->right != NULL) kfifo_in(&fifo, (void *)&(nd->right), sizeof(struct bt_node_t *)); // 销毁当前节点 root->release(nd->value); kfree(nd); } kfifo_free_alloc(&fifo); return 0; }