1 #include <common/bitree.h>
2 #include <mm/slab.h>
3 #include <common/errno.h>
4 #include <common/kfifo.h>
5 #include <common/string.h>
6 #include <debug/bug.h>
7 
/*
 * Ordering helpers built on the tree's comparison callback.
 * Only the SIGN of root->cmp(x, y) is used: negative means x < y, zero means
 * equal, positive means x > y (the strcmp()/qsort() comparator contract).
 * Testing the sign instead of the exact values -1/1 keeps comparators that
 * return arbitrary negative/positive values working.
 * Every macro argument is parenthesised so expressions can be passed safely.
 */
#define smaller(root, a, b) ((root)->cmp((a)->value, (b)->value) < 0)
#define equal(root, a, b) ((root)->cmp((a)->value, (b)->value) == 0)
#define greater(root, a, b) ((root)->cmp((a)->value, (b)->value) > 0)
11 
12 /**
13  * @brief 创建二叉搜索树
14  *
15  * @param node 根节点
16  * @param cmp 比较函数
17  * @param release 用来释放结点的value的函数
18  * @return struct bt_root_t* 树根结构体
19  */
bt_create_tree(struct bt_node_t * node,int (* cmp)(void * a,void * b),int (* release)(void * value))20 struct bt_root_t *bt_create_tree(struct bt_node_t *node, int (*cmp)(void *a, void *b), int (*release)(void *value))
21 {
22     if (node == NULL || cmp == NULL)
23         return (void*)-EINVAL;
24 
25     struct bt_root_t *root = (struct bt_root_t *)kmalloc(sizeof(struct bt_root_t), 0);
26     memset((void *)root, 0, sizeof(struct bt_root_t));
27     root->bt_node = node;
28     root->cmp = cmp;
29     root->release = release;
30     root->size = (node == NULL) ? 0 : 1;
31 
32     return root;
33 }
34 
35 /**
36  * @brief 创建结点
37  *
38  * @param left 左子节点
39  * @param right 右子节点
40  * @param value 当前节点的值
41  * @return struct bt_node_t*
42  */
bt_create_node(struct bt_node_t * left,struct bt_node_t * right,struct bt_node_t * parent,void * value)43 struct bt_node_t *bt_create_node(struct bt_node_t *left, struct bt_node_t *right, struct bt_node_t *parent, void *value)
44 {
45     struct bt_node_t *node = (struct bt_node_t *)kmalloc(sizeof(struct bt_node_t), 0);
46     FAIL_ON_TO(node == NULL, nomem);
47     memset((void *)node, 0, sizeof(struct bt_node_t));
48 
49     node->left = left;
50     node->right = right;
51     node->value = value;
52     node->parent = parent;
53 
54     return node;
55 nomem:;
56     return (void*)-ENOMEM;
57 }
58 /**
59  * @brief 插入结点
60  *
61  * @param root 树根结点
62  * @param value 待插入结点的值
63  * @return int 返回码
64  */
bt_insert(struct bt_root_t * root,void * value)65 int bt_insert(struct bt_root_t *root, void *value)
66 {
67     if (root == NULL)
68         return -EINVAL;
69 
70     struct bt_node_t *this_node = root->bt_node;
71     struct bt_node_t *last_node = NULL;
72     struct bt_node_t *insert_node = bt_create_node(NULL, NULL, NULL, value);
73     FAIL_ON_TO((uint64_t)insert_node == (uint64_t)(-ENOMEM), failed);
74 
75     while (this_node != NULL)
76     {
77         last_node = this_node;
78         if (smaller(root, insert_node, this_node))
79             this_node = this_node->left;
80         else
81             this_node = this_node->right;
82     }
83 
84     insert_node->parent = last_node;
85     if (unlikely(last_node == NULL))
86         root->bt_node = insert_node;
87     else
88     {
89         if (smaller(root, insert_node, last_node))
90             last_node->left = insert_node;
91         else
92             last_node->right = insert_node;
93     }
94     ++root->size;
95     return 0;
96 
97 failed:;
98     return -ENOMEM;
99 }
100 
101 /**
102  * @brief 搜索值为value的结点
103  *
104  * @param value 值
105  * @param ret_addr 返回的结点基地址
106  * @return int 错误码
107  */
bt_query(struct bt_root_t * root,void * value,uint64_t * ret_addr)108 int bt_query(struct bt_root_t *root, void *value, uint64_t *ret_addr)
109 {
110     struct bt_node_t *this_node = root->bt_node;
111     struct bt_node_t tmp_node = {0};
112     tmp_node.value = value;
113 
114     // 如果返回地址为0
115     if (ret_addr == NULL)
116         return -EINVAL;
117 
118     while (this_node != NULL && !equal(root, this_node, &tmp_node))
119     {
120         if (smaller(root, &tmp_node, this_node))
121             this_node = this_node->left;
122         else
123             this_node = this_node->right;
124     }
125 
126     if (this_node != NULL && equal(root, this_node, &tmp_node))
127     {
128         *ret_addr = (uint64_t)this_node;
129         return 0;
130     }
131     else
132     {
133         // 找不到则返回-1,且addr设为0
134         *ret_addr = NULL;
135         return -1;
136     }
137 }
138 
bt_get_minimum(struct bt_node_t * this_node)139 static struct bt_node_t *bt_get_minimum(struct bt_node_t *this_node)
140 {
141     while (this_node->left != NULL)
142         this_node = this_node->left;
143     return this_node;
144 }
145 
146 /**
147  * @brief 删除结点
148  *
149  * @param root 树根
150  * @param value 待删除结点的值
151  * @return int 返回码
152  */
bt_delete(struct bt_root_t * root,void * value)153 int bt_delete(struct bt_root_t *root, void *value)
154 {
155     uint64_t tmp_addr;
156     int retval;
157 
158     // 寻找待删除结点
159     retval = bt_query(root, value, &tmp_addr);
160     if (retval != 0 || tmp_addr == NULL)
161         return retval;
162 
163     struct bt_node_t *this_node = (struct bt_node_t *)tmp_addr;
164     struct bt_node_t *to_delete = NULL, *to_delete_son = NULL;
165     if (this_node->left == NULL || this_node->right == NULL)
166         to_delete = this_node;
167     else
168     {
169         to_delete = bt_get_minimum(this_node->right);
170         // 释放要被删除的值,并把下一个结点的值替换上来
171         root->release(this_node->value);
172         this_node->value = to_delete->value;
173     }
174 
175     if (to_delete->left != NULL)
176         to_delete_son = to_delete->left;
177     else
178         to_delete_son = to_delete->right;
179 
180     if (to_delete_son != NULL)
181         to_delete_son->parent = to_delete->parent;
182 
183     if (to_delete->parent == NULL)
184         root->bt_node = to_delete_son;
185     else
186     {
187         if (to_delete->parent->left == to_delete)
188             to_delete->parent->left = to_delete_son;
189         else
190             to_delete->parent->right = to_delete_son;
191     }
192 
193     --root->size;
194     // 释放最终要删除的结点的对象
195     kfree(to_delete);
196 }
197 
198 /**
199  * @brief 释放整个二叉搜索树
200  *
201  * @param root 树的根节点
202  * @return int 错误码
203  */
bt_destroy_tree(struct bt_root_t * root)204 int bt_destroy_tree(struct bt_root_t *root)
205 {
206     // 新建一个kfifo缓冲区,将指向结点的指针存入fifo队列
207     // 注:为了将指针指向的地址存入队列,我们需要对指针取地址
208     struct kfifo_t fifo;
209     kfifo_alloc(&fifo, ((root->size + 1) / 2) * sizeof(struct bt_node_t *), 0);
210     kfifo_in(&fifo, (void *)&(root->bt_node), sizeof(struct bt_node_t *));
211 
212     // bfs
213     while (!kfifo_empty(&fifo))
214     {
215         // 取出队列头部的结点指针
216         struct bt_node_t *nd;
217         int count = kfifo_out(&fifo, &nd, sizeof(uint64_t));
218 
219         // 将子节点加入队列
220         if (nd->left != NULL)
221             kfifo_in(&fifo, (void *)&(nd->left), sizeof(struct bt_node_t *));
222 
223         if (nd->right != NULL)
224             kfifo_in(&fifo, (void *)&(nd->right), sizeof(struct bt_node_t *));
225 
226         // 销毁当前节点
227         root->release(nd->value);
228         kfree(nd);
229     }
230 
231     kfifo_free_alloc(&fifo);
232 
233     return 0;
234 }