1 // SPDX-License-Identifier: GPL-2.0
2 #include <linux/init.h>
3 #include <linux/static_call.h>
4 #include <linux/bug.h>
5 #include <linux/smp.h>
6 #include <linux/sort.h>
7 #include <linux/slab.h>
8 #include <linux/module.h>
9 #include <linux/cpu.h>
10 #include <linux/processor.h>
11 #include <asm/sections.h>
12
13 extern struct static_call_site __start_static_call_sites[],
14 __stop_static_call_sites[];
15 extern struct static_call_tramp_key __start_static_call_tramp_key[],
16 __stop_static_call_tramp_key[];
17
/*
 * Initialisation state of the static call core:
 *   0  - static_call_init() has not run yet (call sites are not patched,
 *        only trampolines are);
 *   1  - static_call_init() has completed;
 *   >1 - static_call_force_reinit() requested that the next
 *        static_call_init() process the vmlinux call sites again.
 */
static int static_call_initialized;
19
20 /*
21 * Must be called before early_initcall() to be effective.
22 */
static_call_force_reinit(void)23 void static_call_force_reinit(void)
24 {
25 if (WARN_ON_ONCE(!static_call_initialized))
26 return;
27
28 static_call_initialized++;
29 }
30
31 /* mutex to protect key modules/sites */
32 static DEFINE_MUTEX(static_call_mutex);
33
static_call_lock(void)34 static void static_call_lock(void)
35 {
36 mutex_lock(&static_call_mutex);
37 }
38
static_call_unlock(void)39 static void static_call_unlock(void)
40 {
41 mutex_unlock(&static_call_mutex);
42 }
43
static_call_addr(struct static_call_site * site)44 static inline void *static_call_addr(struct static_call_site *site)
45 {
46 return (void *)((long)site->addr + (long)&site->addr);
47 }
48
__static_call_key(const struct static_call_site * site)49 static inline unsigned long __static_call_key(const struct static_call_site *site)
50 {
51 return (long)site->key + (long)&site->key;
52 }
53
static_call_key(const struct static_call_site * site)54 static inline struct static_call_key *static_call_key(const struct static_call_site *site)
55 {
56 return (void *)(__static_call_key(site) & ~STATIC_CALL_SITE_FLAGS);
57 }
58
59 /* These assume the key is word-aligned. */
static_call_is_init(struct static_call_site * site)60 static inline bool static_call_is_init(struct static_call_site *site)
61 {
62 return __static_call_key(site) & STATIC_CALL_SITE_INIT;
63 }
64
static_call_is_tail(struct static_call_site * site)65 static inline bool static_call_is_tail(struct static_call_site *site)
66 {
67 return __static_call_key(site) & STATIC_CALL_SITE_TAIL;
68 }
69
static_call_set_init(struct static_call_site * site)70 static inline void static_call_set_init(struct static_call_site *site)
71 {
72 site->key = (__static_call_key(site) | STATIC_CALL_SITE_INIT) -
73 (long)&site->key;
74 }
75
/*
 * sort() comparator: order call sites by the address of their key so that
 * all sites belonging to one key end up in a single contiguous run.
 */
static int static_call_site_cmp(const void *_a, const void *_b)
{
	const struct static_call_key *lhs = static_call_key(_a);
	const struct static_call_key *rhs = static_call_key(_b);

	if (lhs == rhs)
		return 0;

	return lhs < rhs ? -1 : 1;
}
91
static_call_site_swap(void * _a,void * _b,int size)92 static void static_call_site_swap(void *_a, void *_b, int size)
93 {
94 long delta = (unsigned long)_a - (unsigned long)_b;
95 struct static_call_site *a = _a;
96 struct static_call_site *b = _b;
97 struct static_call_site tmp = *a;
98
99 a->addr = b->addr - delta;
100 a->key = b->key - delta;
101
102 b->addr = tmp.addr + delta;
103 b->key = tmp.key + delta;
104 }
105
static_call_sort_entries(struct static_call_site * start,struct static_call_site * stop)106 static inline void static_call_sort_entries(struct static_call_site *start,
107 struct static_call_site *stop)
108 {
109 sort(start, stop - start, sizeof(struct static_call_site),
110 static_call_site_cmp, static_call_site_swap);
111 }
112
static_call_key_has_mods(struct static_call_key * key)113 static inline bool static_call_key_has_mods(struct static_call_key *key)
114 {
115 return !(key->type & 1);
116 }
117
static_call_key_next(struct static_call_key * key)118 static inline struct static_call_mod *static_call_key_next(struct static_call_key *key)
119 {
120 if (!static_call_key_has_mods(key))
121 return NULL;
122
123 return key->mods;
124 }
125
static_call_key_sites(struct static_call_key * key)126 static inline struct static_call_site *static_call_key_sites(struct static_call_key *key)
127 {
128 if (static_call_key_has_mods(key))
129 return NULL;
130
131 return (struct static_call_site *)(key->type & ~1);
132 }
133
/*
 * __static_call_update - retarget a static call key to a new function.
 * @key:   static call key to update
 * @tramp: trampoline belonging to @key
 * @func:  new call target
 *
 * Takes cpus_read_lock() and static_call_mutex for the whole update. Does
 * nothing if @key already targets @func. Otherwise it records @func in
 * @key, patches @tramp, and - once static_call_init() has run - patches
 * every inline call site registered for @key: the vmlinux sites and each
 * module's sites reachable from the key->mods list.
 *
 * Sites flagged STATIC_CALL_SITE_INIT are skipped once the system is
 * running (vmlinux) or once the owning module has left
 * MODULE_STATE_COMING.
 */
void __static_call_update(struct static_call_key *key, void *tramp, void *func)
{
	struct static_call_site *site, *stop;
	struct static_call_mod *site_mod, first;

	cpus_read_lock();
	static_call_lock();

	if (key->func == func)
		goto done;

	key->func = func;

	/* Patch the trampoline itself; no call-site address is passed. */
	arch_static_call_transform(NULL, tramp, func, false);

	/*
	 * If uninitialized, we'll not update the callsites, but they still
	 * point to the trampoline and we just patched that.
	 */
	if (WARN_ON_ONCE(!static_call_initialized))
		goto done;

	/*
	 * On-stack head entry so one loop handles both key encodings: when
	 * the key holds a direct vmlinux sites pointer, first.sites is that
	 * pointer and first.next is NULL; when it holds a module list,
	 * first.sites is NULL and first.next is the list.
	 */
	first = (struct static_call_mod){
		.next = static_call_key_next(key),
		.mod = NULL,
		.sites = static_call_key_sites(key),
	};

	for (site_mod = &first; site_mod; site_mod = site_mod->next) {
		bool init = system_state < SYSTEM_RUNNING;
		struct module *mod = site_mod->mod;

		if (!site_mod->sites) {
			/*
			 * This can happen if the static call key is defined in
			 * a module which doesn't use it.
			 *
			 * It also happens in the has_mods case, where the
			 * 'first' entry has no sites associated with it.
			 */
			continue;
		}

		/* Bound the scan by the table the sites belong to. */
		stop = __stop_static_call_sites;

		if (mod) {
#ifdef CONFIG_MODULES
			stop = mod->static_call_sites +
			       mod->num_static_call_sites;
			init = mod->state == MODULE_STATE_COMING;
#endif
		}

		/*
		 * Tables are sorted by key (__static_call_init()), so @key's
		 * sites form one run starting at site_mod->sites.
		 */
		for (site = site_mod->sites;
		     site < stop && static_call_key(site) == key; site++) {
			void *site_addr = static_call_addr(site);

			/* Init-text sites are skipped once init is over. */
			if (!init && static_call_is_init(site))
				continue;

			if (!kernel_text_address((unsigned long)site_addr)) {
				/*
				 * This skips patching built-in __exit, which
				 * is part of init_section_contains() but is
				 * not part of kernel_text_address().
				 *
				 * Skipping built-in __exit is fine since it
				 * will never be executed.
				 */
				WARN_ONCE(!static_call_is_init(site),
					  "can't patch static call site at %pS",
					  site_addr);
				continue;
			}

			arch_static_call_transform(site_addr, NULL, func,
						   static_call_is_tail(site));
		}
	}

done:
	static_call_unlock();
	cpus_read_unlock();
}
EXPORT_SYMBOL_GPL(__static_call_update);
219
/*
 * __static_call_init - sort, register and patch a table of call sites.
 * @mod:   owning module, or NULL for the vmlinux table
 * @start: first call site of the table
 * @stop:  one past the last call site
 *
 * Sorts the table by key and flags sites that lie in init text with
 * STATIC_CALL_SITE_INIT. It then links each key to its run of sites and
 * patches every site to call the key's current ->func.
 *
 * For vmlinux no memory is allocated: the sites pointer is stored in the
 * key itself, tagged with bit 0. For a module, one struct static_call_mod
 * is allocated per key and pushed onto key->mods. If the key still carries
 * a direct vmlinux sites pointer, that pointer is first moved into its own
 * list entry.
 *
 * Return: 0 on success, -ENOMEM if an allocation fails. Because keys are
 * processed in sorted order and the function returns immediately on
 * failure, keys after the failing one are left untouched.
 */
static int __static_call_init(struct module *mod,
			      struct static_call_site *start,
			      struct static_call_site *stop)
{
	struct static_call_site *site;
	struct static_call_key *key, *prev_key = NULL;
	struct static_call_mod *site_mod;

	if (start == stop)
		return 0;

	static_call_sort_entries(start, stop);

	for (site = start; site < stop; site++) {
		void *site_addr = static_call_addr(site);

		/* Mark init-text sites so later updates can skip them. */
		if ((mod && within_module_init((unsigned long)site_addr, mod)) ||
		    (!mod && init_section_contains(site_addr, 1)))
			static_call_set_init(site);

		/* First site of a new key's run: register the run with the key. */
		key = static_call_key(site);
		if (key != prev_key) {
			prev_key = key;

			/*
			 * For vmlinux (!mod) avoid the allocation by storing
			 * the sites pointer in the key itself. Also see
			 * __static_call_update()'s @first.
			 *
			 * This allows architectures (eg. x86) to call
			 * static_call_init() before memory allocation works.
			 */
			if (!mod) {
				key->sites = site;
				key->type |= 1;
				goto do_transform;
			}

			site_mod = kzalloc(sizeof(*site_mod), GFP_KERNEL);
			if (!site_mod)
				return -ENOMEM;

			/*
			 * When the key has a direct sites pointer, extract
			 * that into an explicit struct static_call_mod, so we
			 * can have a list of modules.
			 */
			if (static_call_key_sites(key)) {
				site_mod->mod = NULL;
				site_mod->next = NULL;
				site_mod->sites = static_call_key_sites(key);

				key->mods = site_mod;

				site_mod = kzalloc(sizeof(*site_mod), GFP_KERNEL);
				if (!site_mod)
					return -ENOMEM;
			}

			/* Push this module's run onto the front of key->mods. */
			site_mod->mod = mod;
			site_mod->sites = site;
			site_mod->next = static_call_key_next(key);
			key->mods = site_mod;
		}

do_transform:
		/* Point the site at the key's current target. */
		arch_static_call_transform(site_addr, NULL, key->func,
				static_call_is_tail(site));
	}

	return 0;
}
292
addr_conflict(struct static_call_site * site,void * start,void * end)293 static int addr_conflict(struct static_call_site *site, void *start, void *end)
294 {
295 unsigned long addr = (unsigned long)static_call_addr(site);
296
297 if (addr <= (unsigned long)end &&
298 addr + CALL_INSN_SIZE > (unsigned long)start)
299 return 1;
300
301 return 0;
302 }
303
__static_call_text_reserved(struct static_call_site * iter_start,struct static_call_site * iter_stop,void * start,void * end,bool init)304 static int __static_call_text_reserved(struct static_call_site *iter_start,
305 struct static_call_site *iter_stop,
306 void *start, void *end, bool init)
307 {
308 struct static_call_site *iter = iter_start;
309
310 while (iter < iter_stop) {
311 if (init || !static_call_is_init(iter)) {
312 if (addr_conflict(iter, start, end))
313 return 1;
314 }
315 iter++;
316 }
317
318 return 0;
319 }
320
321 #ifdef CONFIG_MODULES
322
__static_call_mod_text_reserved(void * start,void * end)323 static int __static_call_mod_text_reserved(void *start, void *end)
324 {
325 struct module *mod;
326 int ret;
327
328 preempt_disable();
329 mod = __module_text_address((unsigned long)start);
330 WARN_ON_ONCE(__module_text_address((unsigned long)end) != mod);
331 if (!try_module_get(mod))
332 mod = NULL;
333 preempt_enable();
334
335 if (!mod)
336 return 0;
337
338 ret = __static_call_text_reserved(mod->static_call_sites,
339 mod->static_call_sites + mod->num_static_call_sites,
340 start, end, mod->state == MODULE_STATE_COMING);
341
342 module_put(mod);
343
344 return ret;
345 }
346
tramp_key_lookup(unsigned long addr)347 static unsigned long tramp_key_lookup(unsigned long addr)
348 {
349 struct static_call_tramp_key *start = __start_static_call_tramp_key;
350 struct static_call_tramp_key *stop = __stop_static_call_tramp_key;
351 struct static_call_tramp_key *tramp_key;
352
353 for (tramp_key = start; tramp_key != stop; tramp_key++) {
354 unsigned long tramp;
355
356 tramp = (long)tramp_key->tramp + (long)&tramp_key->tramp;
357 if (tramp == addr)
358 return (long)tramp_key->key + (long)&tramp_key->key;
359 }
360
361 return 0;
362 }
363
static_call_add_module(struct module * mod)364 static int static_call_add_module(struct module *mod)
365 {
366 struct static_call_site *start = mod->static_call_sites;
367 struct static_call_site *stop = start + mod->num_static_call_sites;
368 struct static_call_site *site;
369
370 for (site = start; site != stop; site++) {
371 unsigned long s_key = __static_call_key(site);
372 unsigned long addr = s_key & ~STATIC_CALL_SITE_FLAGS;
373 unsigned long key;
374
375 /*
376 * Is the key is exported, 'addr' points to the key, which
377 * means modules are allowed to call static_call_update() on
378 * it.
379 *
380 * Otherwise, the key isn't exported, and 'addr' points to the
381 * trampoline so we need to lookup the key.
382 *
383 * We go through this dance to prevent crazy modules from
384 * abusing sensitive static calls.
385 */
386 if (!kernel_text_address(addr))
387 continue;
388
389 key = tramp_key_lookup(addr);
390 if (!key) {
391 pr_warn("Failed to fixup __raw_static_call() usage at: %ps\n",
392 static_call_addr(site));
393 return -EINVAL;
394 }
395
396 key |= s_key & STATIC_CALL_SITE_FLAGS;
397 site->key = key - (long)&site->key;
398 }
399
400 return __static_call_init(mod, start, stop);
401 }
402
static_call_del_module(struct module * mod)403 static void static_call_del_module(struct module *mod)
404 {
405 struct static_call_site *start = mod->static_call_sites;
406 struct static_call_site *stop = mod->static_call_sites +
407 mod->num_static_call_sites;
408 struct static_call_key *key, *prev_key = NULL;
409 struct static_call_mod *site_mod, **prev;
410 struct static_call_site *site;
411
412 for (site = start; site < stop; site++) {
413 key = static_call_key(site);
414 if (key == prev_key)
415 continue;
416
417 prev_key = key;
418
419 for (prev = &key->mods, site_mod = key->mods;
420 site_mod && site_mod->mod != mod;
421 prev = &site_mod->next, site_mod = site_mod->next)
422 ;
423
424 if (!site_mod)
425 continue;
426
427 *prev = site_mod->next;
428 kfree(site_mod);
429 }
430 }
431
static_call_module_notify(struct notifier_block * nb,unsigned long val,void * data)432 static int static_call_module_notify(struct notifier_block *nb,
433 unsigned long val, void *data)
434 {
435 struct module *mod = data;
436 int ret = 0;
437
438 cpus_read_lock();
439 static_call_lock();
440
441 switch (val) {
442 case MODULE_STATE_COMING:
443 ret = static_call_add_module(mod);
444 if (ret) {
445 WARN(1, "Failed to allocate memory for static calls");
446 static_call_del_module(mod);
447 }
448 break;
449 case MODULE_STATE_GOING:
450 static_call_del_module(mod);
451 break;
452 }
453
454 static_call_unlock();
455 cpus_read_unlock();
456
457 return notifier_from_errno(ret);
458 }
459
460 static struct notifier_block static_call_module_nb = {
461 .notifier_call = static_call_module_notify,
462 };
463
464 #else
465
/*
 * !CONFIG_MODULES stub: with module support disabled there is no module
 * text to check, so nothing is ever reserved.
 */
static inline int __static_call_mod_text_reserved(void *start, void *end)
{
	return 0;
}
470
471 #endif /* CONFIG_MODULES */
472
static_call_text_reserved(void * start,void * end)473 int static_call_text_reserved(void *start, void *end)
474 {
475 bool init = system_state < SYSTEM_RUNNING;
476 int ret = __static_call_text_reserved(__start_static_call_sites,
477 __stop_static_call_sites, start, end, init);
478
479 if (ret)
480 return ret;
481
482 return __static_call_mod_text_reserved(start, end);
483 }
484
/*
 * static_call_init - set up and patch the vmlinux static call sites.
 *
 * Registered as an early_initcall. Architectures may also call it earlier,
 * which is why the vmlinux path in __static_call_init() avoids allocation.
 * Once it has completed, further calls return immediately unless
 * static_call_force_reinit() bumped static_call_initialized, in which case
 * the vmlinux sites are processed again. The module notifier is registered
 * only on the very first pass. Any error from __static_call_init() is fatal
 * (BUG()).
 *
 * Return: always 0.
 */
int __init static_call_init(void)
{
	int ret;

	/* See static_call_force_reinit(). */
	if (static_call_initialized == 1)
		return 0;

	cpus_read_lock();
	static_call_lock();
	ret = __static_call_init(NULL, __start_static_call_sites,
				 __stop_static_call_sites);
	static_call_unlock();
	cpus_read_unlock();

	if (ret) {
		pr_err("Failed to allocate memory for static_call!\n");
		BUG();
	}

#ifdef CONFIG_MODULES
	/* Only the first pass (state 0) registers the notifier; re-inits skip it. */
	if (!static_call_initialized)
		register_module_notifier(&static_call_module_nb);
#endif

	static_call_initialized = 1;
	return 0;
}
early_initcall(static_call_init);
514
515 #ifdef CONFIG_STATIC_CALL_SELFTEST
516
/* Selftest target: returns @x plus one. */
static int func_a(int x)
{
	return 1 + x;
}
521
/* Selftest target: returns @x plus two. */
static int func_b(int x)
{
	return 2 + x;
}
526
DEFINE_STATIC_CALL(sc_selftest, func_a);

/*
 * Selftest steps, applied in order: when .func is non-NULL the static call
 * is first retargeted to it; then sc_selftest(.val) must return .expect.
 */
static struct static_call_data {
	int (*func)(int);
	int val;
	int expect;
} static_call_data[] __initdata = {
	{ .func = NULL,   .val = 2, .expect = 3 },	/* initial target: func_a */
	{ .func = func_b, .val = 2, .expect = 4 },
	{ .func = func_a, .val = 2, .expect = 3 },
};
538
test_static_call_init(void)539 static int __init test_static_call_init(void)
540 {
541 int i;
542
543 for (i = 0; i < ARRAY_SIZE(static_call_data); i++ ) {
544 struct static_call_data *scd = &static_call_data[i];
545
546 if (scd->func)
547 static_call_update(sc_selftest, scd->func);
548
549 WARN_ON(static_call(sc_selftest)(scd->val) != scd->expect);
550 }
551
552 return 0;
553 }
554 early_initcall(test_static_call_init);
555
556 #endif /* CONFIG_STATIC_CALL_SELFTEST */
557