1 #include <linux/kernel.h>
2 #include <linux/module.h>
3 #include <linux/list_sort.h>
4 #include <linux/slab.h>
5 #include <linux/list.h>
6 
/*
 * log2 of the list length list_sort() handles at full efficiency.  Its
 * table of pending runs has one slot per level plus a sentinel, so
 * sorting 2^MAX_LIST_LENGTH_BITS or more elements triggers the
 * "too long for efficiency" fallback.
 */
#define MAX_LIST_LENGTH_BITS	(20)
8 
9 /*
10  * Returns a list organized in an intermediate format suited
11  * to chaining of merge() calls: null-terminated, no reserved or
12  * sentinel head node, "prev" links not maintained.
13  */
merge(void * priv,int (* cmp)(void * priv,struct list_head * a,struct list_head * b),struct list_head * a,struct list_head * b)14 static struct list_head *merge(void *priv,
15 				int (*cmp)(void *priv, struct list_head *a,
16 					struct list_head *b),
17 				struct list_head *a, struct list_head *b)
18 {
19 	struct list_head head, *tail = &head;
20 
21 	while (a && b) {
22 		/* if equal, take 'a' -- important for sort stability */
23 		if ((*cmp)(priv, a, b) <= 0) {
24 			tail->next = a;
25 			a = a->next;
26 		} else {
27 			tail->next = b;
28 			b = b->next;
29 		}
30 		tail = tail->next;
31 	}
32 	tail->next = a?:b;
33 	return head.next;
34 }
35 
36 /*
37  * Combine final list merge with restoration of standard doubly-linked
38  * list structure.  This approach duplicates code from merge(), but
39  * runs faster than the tidier alternatives of either a separate final
40  * prev-link restoration pass, or maintaining the prev links
41  * throughout.
42  */
merge_and_restore_back_links(void * priv,int (* cmp)(void * priv,struct list_head * a,struct list_head * b),struct list_head * head,struct list_head * a,struct list_head * b)43 static void merge_and_restore_back_links(void *priv,
44 				int (*cmp)(void *priv, struct list_head *a,
45 					struct list_head *b),
46 				struct list_head *head,
47 				struct list_head *a, struct list_head *b)
48 {
49 	struct list_head *tail = head;
50 
51 	while (a && b) {
52 		/* if equal, take 'a' -- important for sort stability */
53 		if ((*cmp)(priv, a, b) <= 0) {
54 			tail->next = a;
55 			a->prev = tail;
56 			a = a->next;
57 		} else {
58 			tail->next = b;
59 			b->prev = tail;
60 			b = b->next;
61 		}
62 		tail = tail->next;
63 	}
64 	tail->next = a ? : b;
65 
66 	do {
67 		/*
68 		 * In worst cases this loop may run many iterations.
69 		 * Continue callbacks to the client even though no
70 		 * element comparison is needed, so the client's cmp()
71 		 * routine can invoke cond_resched() periodically.
72 		 */
73 		(*cmp)(priv, tail->next, tail->next);
74 
75 		tail->next->prev = tail;
76 		tail = tail->next;
77 	} while (tail->next);
78 
79 	tail->next = head;
80 	head->prev = tail;
81 }
82 
83 /**
84  * list_sort - sort a list
85  * @priv: private data, opaque to list_sort(), passed to @cmp
86  * @head: the list to sort
87  * @cmp: the elements comparison function
88  *
89  * This function implements "merge sort", which has O(nlog(n))
90  * complexity.
91  *
92  * The comparison function @cmp must return a negative value if @a
93  * should sort before @b, and a positive value if @a should sort after
94  * @b. If @a and @b are equivalent, and their original relative
95  * ordering is to be preserved, @cmp must return 0.
96  */
list_sort(void * priv,struct list_head * head,int (* cmp)(void * priv,struct list_head * a,struct list_head * b))97 void list_sort(void *priv, struct list_head *head,
98 		int (*cmp)(void *priv, struct list_head *a,
99 			struct list_head *b))
100 {
101 	struct list_head *part[MAX_LIST_LENGTH_BITS+1]; /* sorted partial lists
102 						-- last slot is a sentinel */
103 	int lev;  /* index into part[] */
104 	int max_lev = 0;
105 	struct list_head *list;
106 
107 	if (list_empty(head))
108 		return;
109 
110 	memset(part, 0, sizeof(part));
111 
112 	head->prev->next = NULL;
113 	list = head->next;
114 
115 	while (list) {
116 		struct list_head *cur = list;
117 		list = list->next;
118 		cur->next = NULL;
119 
120 		for (lev = 0; part[lev]; lev++) {
121 			cur = merge(priv, cmp, part[lev], cur);
122 			part[lev] = NULL;
123 		}
124 		if (lev > max_lev) {
125 			if (unlikely(lev >= ARRAY_SIZE(part)-1)) {
126 				printk_once(KERN_DEBUG "list passed to"
127 					" list_sort() too long for"
128 					" efficiency\n");
129 				lev--;
130 			}
131 			max_lev = lev;
132 		}
133 		part[lev] = cur;
134 	}
135 
136 	for (lev = 0; lev < max_lev; lev++)
137 		if (part[lev])
138 			list = merge(priv, cmp, part[lev], list);
139 
140 	merge_and_restore_back_links(priv, cmp, head, part[max_lev], list);
141 }
142 EXPORT_SYMBOL(list_sort);
143 
144 #ifdef CONFIG_TEST_LIST_SORT
145 
146 #include <linux/random.h>
147 
/*
 * The pattern of set bits in the list length determines which cases
 * are hit in list_sort().  512 + 128 + 2 has bits 9, 7 and 1 set, so
 * several pending runs of different sizes are left over when the input
 * is exhausted, and the final fold-and-merge path gets exercised.
 */
#define TEST_LIST_LEN (512+128+2) /* not including head */

/*
 * Sentinel words stored in every test element; check() verifies them to
 * catch stray writes into or around the embedded list_head.
 */
#define TEST_POISON1 0xDEADBEEF
#define TEST_POISON2 0xA324354C

/*
 * Test list element.  The poison words deliberately bracket @list so that
 * corruption adjacent to the links is detected.
 */
struct debug_el {
	unsigned int poison1;	/* must equal TEST_POISON1 */
	struct list_head list;	/* linkage sorted by list_sort() */
	unsigned int poison2;	/* must equal TEST_POISON2 */
	int value;		/* sort key compared by cmp() */
	unsigned serial;	/* creation index into elts[]; used for stability checks */
};

/*
 * Array, containing pointers to all elements in the test list, indexed by
 * serial.  check() uses it to detect "phantom" elements that were never
 * allocated by list_sort_test().
 */
static struct debug_el **elts __initdata;
167 
check(struct debug_el * ela,struct debug_el * elb)168 static int __init check(struct debug_el *ela, struct debug_el *elb)
169 {
170 	if (ela->serial >= TEST_LIST_LEN) {
171 		printk(KERN_ERR "list_sort_test: error: incorrect serial %d\n",
172 				ela->serial);
173 		return -EINVAL;
174 	}
175 	if (elb->serial >= TEST_LIST_LEN) {
176 		printk(KERN_ERR "list_sort_test: error: incorrect serial %d\n",
177 				elb->serial);
178 		return -EINVAL;
179 	}
180 	if (elts[ela->serial] != ela || elts[elb->serial] != elb) {
181 		printk(KERN_ERR "list_sort_test: error: phantom element\n");
182 		return -EINVAL;
183 	}
184 	if (ela->poison1 != TEST_POISON1 || ela->poison2 != TEST_POISON2) {
185 		printk(KERN_ERR "list_sort_test: error: bad poison: %#x/%#x\n",
186 				ela->poison1, ela->poison2);
187 		return -EINVAL;
188 	}
189 	if (elb->poison1 != TEST_POISON1 || elb->poison2 != TEST_POISON2) {
190 		printk(KERN_ERR "list_sort_test: error: bad poison: %#x/%#x\n",
191 				elb->poison1, elb->poison2);
192 		return -EINVAL;
193 	}
194 	return 0;
195 }
196 
cmp(void * priv,struct list_head * a,struct list_head * b)197 static int __init cmp(void *priv, struct list_head *a, struct list_head *b)
198 {
199 	struct debug_el *ela, *elb;
200 
201 	ela = container_of(a, struct debug_el, list);
202 	elb = container_of(b, struct debug_el, list);
203 
204 	check(ela, elb);
205 	return ela->value - elb->value;
206 }
207 
list_sort_test(void)208 static int __init list_sort_test(void)
209 {
210 	int i, count = 1, err = -EINVAL;
211 	struct debug_el *el;
212 	struct list_head *cur, *tmp;
213 	LIST_HEAD(head);
214 
215 	printk(KERN_DEBUG "list_sort_test: start testing list_sort()\n");
216 
217 	elts = kmalloc(sizeof(void *) * TEST_LIST_LEN, GFP_KERNEL);
218 	if (!elts) {
219 		printk(KERN_ERR "list_sort_test: error: cannot allocate "
220 				"memory\n");
221 		goto exit;
222 	}
223 
224 	for (i = 0; i < TEST_LIST_LEN; i++) {
225 		el = kmalloc(sizeof(*el), GFP_KERNEL);
226 		if (!el) {
227 			printk(KERN_ERR "list_sort_test: error: cannot "
228 					"allocate memory\n");
229 			goto exit;
230 		}
231 		 /* force some equivalencies */
232 		el->value = random32() % (TEST_LIST_LEN/3);
233 		el->serial = i;
234 		el->poison1 = TEST_POISON1;
235 		el->poison2 = TEST_POISON2;
236 		elts[i] = el;
237 		list_add_tail(&el->list, &head);
238 	}
239 
240 	list_sort(NULL, &head, cmp);
241 
242 	for (cur = head.next; cur->next != &head; cur = cur->next) {
243 		struct debug_el *el1;
244 		int cmp_result;
245 
246 		if (cur->next->prev != cur) {
247 			printk(KERN_ERR "list_sort_test: error: list is "
248 					"corrupted\n");
249 			goto exit;
250 		}
251 
252 		cmp_result = cmp(NULL, cur, cur->next);
253 		if (cmp_result > 0) {
254 			printk(KERN_ERR "list_sort_test: error: list is not "
255 					"sorted\n");
256 			goto exit;
257 		}
258 
259 		el = container_of(cur, struct debug_el, list);
260 		el1 = container_of(cur->next, struct debug_el, list);
261 		if (cmp_result == 0 && el->serial >= el1->serial) {
262 			printk(KERN_ERR "list_sort_test: error: order of "
263 					"equivalent elements not preserved\n");
264 			goto exit;
265 		}
266 
267 		if (check(el, el1)) {
268 			printk(KERN_ERR "list_sort_test: error: element check "
269 					"failed\n");
270 			goto exit;
271 		}
272 		count++;
273 	}
274 
275 	if (count != TEST_LIST_LEN) {
276 		printk(KERN_ERR "list_sort_test: error: bad list length %d",
277 				count);
278 		goto exit;
279 	}
280 
281 	err = 0;
282 exit:
283 	kfree(elts);
284 	list_for_each_safe(cur, tmp, &head) {
285 		list_del(cur);
286 		kfree(container_of(cur, struct debug_el, list));
287 	}
288 	return err;
289 }
290 module_init(list_sort_test);
291 #endif /* CONFIG_TEST_LIST_SORT */
292