1 /* SPDX-License-Identifier: LGPL-2.1-or-later */
2
3 #include <linux/genetlink.h>
4
5 #include "sd-netlink.h"
6
7 #include "alloc-util.h"
8 #include "netlink-genl.h"
9 #include "netlink-internal.h"
10 #include "netlink-types.h"
11
/* One generic netlink family known to a connection. Dynamically allocated entries are registered in the
 * connection's genl_family_by_name map and, when the kernel assigned them an id, also in
 * genl_family_by_id. */
typedef struct GenericNetlinkFamily {
        /* Owning connection. Only set after the entry has been inserted into the connection's hashmaps,
         * so that genl_family_free() knows whether it has to unregister the entry. */
        sd_netlink *genl;

        /* Attribute layout of messages of this family. NULL for the static nlctrl entry, whose type
         * system is looked up by name instead. */
        const NLTypeSystem *type_system;

        uint16_t id; /* a.k.a nlmsg_type; 0 marks a family the kernel does not support (negative cache) */
        char *name;
        uint32_t version;                 /* from CTRL_ATTR_VERSION, written into genlmsghdr.version */
        uint32_t additional_header_size;  /* from CTRL_ATTR_HDRSIZE, family header after genlmsghdr */
        Hashmap *multicast_group_by_name; /* multicast group name -> group id (stored as pointer) */
} GenericNetlinkFamily;
23
24 static const GenericNetlinkFamily nlctrl_static = {
25 .id = GENL_ID_CTRL,
26 .name = (char*) CTRL_GENL_NAME,
27 .version = 0x01,
28 };
29
/* Releases a family. If the entry was registered with a connection (f->genl set), it is first removed
 * from the connection's by-id and by-name maps, so no dangling pointer is left behind there.
 * Returns NULL, for use in assignments. */
static GenericNetlinkFamily *genl_family_free(GenericNetlinkFamily *f) {
        if (!f)
                return NULL;

        sd_netlink *owner = f->genl;

        /* Unregister while f->name is still valid; it is the key of the by-name map. Entries with id 0
         * were never put into the by-id map. */
        if (owner && f->id > 0)
                hashmap_remove(owner->genl_family_by_id, UINT_TO_PTR(f->id));
        if (owner && f->name)
                hashmap_remove(owner->genl_family_by_name, f->name);

        hashmap_free(f->multicast_group_by_name);
        free(f->name);

        return mfree(f);
}
46
/* Defines genl_family_freep(), used with _cleanup_() to release a not-yet-registered family on
 * early-return paths. */
DEFINE_TRIVIAL_CLEANUP_FUNC(GenericNetlinkFamily*, genl_family_free);
48
/* Drops every cached generic netlink family of the connection and resets both lookup maps to NULL. */
void genl_clear_family(sd_netlink *nl) {
        assert(nl);

        /* Most entries live in both maps. Freeing them through the by-name map also unregisters them
         * from the by-id map (genl_family_free() does that), so the second pass only sees leftovers. */
        nl->genl_family_by_name =
                hashmap_free_with_destructor(nl->genl_family_by_name, genl_family_free);

        nl->genl_family_by_id =
                hashmap_free_with_destructor(nl->genl_family_by_id, genl_family_free);
}
55
/* Records in the by-name map that the kernel does not support the family 'family_name'. The entry
 * keeps id 0, which genl_family_get_by_name() treats as "unsupported", so the name is not resolved
 * again. Returns 0 on success or a negative errno. */
static int genl_family_new_unsupported(
                sd_netlink *nl,
                const char *family_name,
                const NLTypeSystem *type_system) {

        _cleanup_(genl_family_freep) GenericNetlinkFamily *family = NULL;
        int r;

        assert(nl);
        assert(family_name);
        assert(type_system);

        family = new(GenericNetlinkFamily, 1);
        if (!family)
                return -ENOMEM;

        /* id stays 0 on purpose: that is the "not supported by the kernel" marker. */
        *family = (GenericNetlinkFamily) {
                .type_system = type_system,
                .name = strdup(family_name),
        };
        if (!family->name)
                return -ENOMEM;

        r = hashmap_ensure_put(&nl->genl_family_by_name, &string_hash_ops, family->name, family);
        if (r < 0)
                return r;

        /* Tie the entry to the connection only once it is registered; from now on the map owns it. */
        family->genl = nl;
        TAKE_PTR(family);
        return 0;
}
91
/* Reads the entries of the CTRL_ATTR_MCAST_GROUPS array of a controller reply into
 * f->multicast_group_by_name. The caller must already have entered CTRL_ATTR_MCAST_GROUPS and is
 * responsible for leaving it. Groups advertised with id 0 are skipped with a debug message. */
static int genl_family_read_multicast_groups(GenericNetlinkFamily *f, sd_netlink_message *message) {
        int r;

        assert(f);
        assert(message);

        /* Array entries are numbered from 1; the first missing index terminates the array. */
        for (unsigned idx = 1; idx <= UINT16_MAX; idx++) {
                _cleanup_free_ char *group_name = NULL;
                uint32_t group_id;

                r = sd_netlink_message_enter_array(message, idx);
                if (r == -ENODATA)
                        break;
                if (r < 0)
                        return r;

                r = sd_netlink_message_read_u32(message, CTRL_ATTR_MCAST_GRP_ID, &group_id);
                if (r < 0)
                        return r;

                r = sd_netlink_message_read_string_strdup(message, CTRL_ATTR_MCAST_GRP_NAME, &group_name);
                if (r < 0)
                        return r;

                r = sd_netlink_message_exit_container(message);
                if (r < 0)
                        return r;

                if (group_id == 0) {
                        log_debug("sd-netlink: received multicast group '%s' for generic netlink family '%s' with id == 0, ignoring",
                                  group_name, f->name);
                        continue;
                }

                /* The map frees its keys (string_hash_ops_free), so it takes over group_name on success. */
                r = hashmap_ensure_put(&f->multicast_group_by_name, &string_hash_ops_free, group_name, UINT32_TO_PTR(group_id));
                if (r < 0)
                        return r;

                TAKE_PTR(group_name);
        }

        return 0;
}

/* Builds a family entry from a CTRL_CMD_NEWFAMILY reply ('message') of the controller.
 * The entry is registered in the connection's by-id and by-name maps and returned in *ret.
 * Returns -EINVAL if the reply is not a NEWFAMILY message from nlctrl, or if it describes a family
 * other than 'expected_family_name'. Other negative errno values come from the message readers and
 * the hashmaps. */
static int genl_family_new(
                sd_netlink *nl,
                const char *expected_family_name,
                const NLTypeSystem *type_system,
                sd_netlink_message *message,
                const GenericNetlinkFamily **ret) {

        _cleanup_(genl_family_freep) GenericNetlinkFamily *family = NULL;
        const char *reply_family;
        uint8_t command;
        int r;

        assert(nl);
        assert(expected_family_name);
        assert(type_system);
        assert(message);
        assert(ret);

        family = new(GenericNetlinkFamily, 1);
        if (!family)
                return -ENOMEM;

        *family = (GenericNetlinkFamily) {
                .type_system = type_system,
        };

        /* The reply must be a CTRL_CMD_NEWFAMILY message sent by the controller itself. */
        r = sd_genl_message_get_family_name(nl, message, &reply_family);
        if (r < 0)
                return r;
        if (!streq(reply_family, CTRL_GENL_NAME))
                return -EINVAL;

        r = sd_genl_message_get_command(nl, message, &command);
        if (r < 0)
                return r;
        if (command != CTRL_CMD_NEWFAMILY)
                return -EINVAL;

        /* Mandatory attributes describing the family. */
        r = sd_netlink_message_read_u16(message, CTRL_ATTR_FAMILY_ID, &family->id);
        if (r < 0)
                return r;

        r = sd_netlink_message_read_string_strdup(message, CTRL_ATTR_FAMILY_NAME, &family->name);
        if (r < 0)
                return r;
        if (!streq(family->name, expected_family_name))
                return -EINVAL;

        r = sd_netlink_message_read_u32(message, CTRL_ATTR_VERSION, &family->version);
        if (r < 0)
                return r;

        r = sd_netlink_message_read_u32(message, CTRL_ATTR_HDRSIZE, &family->additional_header_size);
        if (r < 0)
                return r;

        /* Multicast groups are optional: failing to enter the container simply means there are none. */
        if (sd_netlink_message_enter_container(message, CTRL_ATTR_MCAST_GROUPS) >= 0) {
                r = genl_family_read_multicast_groups(family, message);
                if (r < 0)
                        return r;

                r = sd_netlink_message_exit_container(message);
                if (r < 0)
                        return r;
        }

        /* Register in both lookup maps, undoing the first insertion if the second one fails. */
        r = hashmap_ensure_put(&nl->genl_family_by_id, NULL, UINT_TO_PTR(family->id), family);
        if (r < 0)
                return r;

        r = hashmap_ensure_put(&nl->genl_family_by_name, &string_hash_ops, family->name, family);
        if (r < 0) {
                hashmap_remove(nl->genl_family_by_id, UINT_TO_PTR(family->id));
                return r;
        }

        /* Only now tie the entry to the connection, so genl_family_free() unregisters it from here on. */
        family->genl = nl;
        *ret = TAKE_PTR(family);
        return 0;
}
207
genl_family_get_type_system(const GenericNetlinkFamily * family)208 static const NLTypeSystem *genl_family_get_type_system(const GenericNetlinkFamily *family) {
209 assert(family);
210
211 if (family->type_system)
212 return family->type_system;
213
214 return genl_get_type_system_by_name(family->name);
215 }
216
/* Allocates a new generic netlink message of 'family' carrying command 'cmd'. The family's
 * nlmsg_type and version are filled in, and room is reserved for the genl header plus the family's
 * own header. Returns -EOPNOTSUPP if no type system is known for the family. */
static int genl_message_new(
                sd_netlink *nl,
                const GenericNetlinkFamily *family,
                uint8_t cmd,
                sd_netlink_message **ret) {

        _cleanup_(sd_netlink_message_unrefp) sd_netlink_message *msg = NULL;
        const NLTypeSystem *ts;
        struct genlmsghdr *genl_hdr;
        size_t header_size;
        int r;

        assert(nl);
        assert(nl->protocol == NETLINK_GENERIC);
        assert(family);
        assert(ret);

        ts = genl_family_get_type_system(family);
        if (!ts)
                return -EOPNOTSUPP;

        /* genlmsghdr is always present; some families append a fixed header of their own. */
        header_size = sizeof(struct genlmsghdr) + family->additional_header_size;

        r = message_new_full(nl, family->id, ts, header_size, &msg);
        if (r < 0)
                return r;

        genl_hdr = NLMSG_DATA(msg->hdr);
        *genl_hdr = (struct genlmsghdr) {
                .cmd = cmd,
                .version = family->version,
        };

        *ret = TAKE_PTR(msg);
        return 0;
}
249
/* Asks the kernel, through the controller family 'ctrl', for the family called 'name', and caches the
 * result on the connection.
 *
 * Returns 0 and the new entry in *ret on success.
 * Returns -EOPNOTSUPP if no type system is known for 'name', or if the kernel does not know the family.
 * In the latter case a negative cache entry is stored, so the name is not queried again.
 * Other errors of the netlink call are returned as-is and are NOT cached, because they may be
 * transient (timeouts, memory pressure, interrupted calls). */
static int genl_family_get_by_name_internal(
                sd_netlink *nl,
                const GenericNetlinkFamily *ctrl,
                const char *name,
                const GenericNetlinkFamily **ret) {

        _cleanup_(sd_netlink_message_unrefp) sd_netlink_message *req = NULL, *reply = NULL;
        const NLTypeSystem *type_system;
        int r;

        assert(nl);
        assert(nl->protocol == NETLINK_GENERIC);
        assert(ctrl);
        assert(name);
        assert(ret);

        type_system = genl_get_type_system_by_name(name);
        if (!type_system)
                return -EOPNOTSUPP;

        r = genl_message_new(nl, ctrl, CTRL_CMD_GETFAMILY, &req);
        if (r < 0)
                return r;

        r = sd_netlink_message_append_string(req, CTRL_ATTR_FAMILY_NAME, name);
        if (r < 0)
                return r;

        r = sd_netlink_call(nl, req, 0, &reply);
        if (r == -ENOENT) {
                /* The kernel's controller answers CTRL_CMD_GETFAMILY with -ENOENT for a family it does
                 * not know. Remember that, so we do not ask again; failing to record it is harmless. */
                (void) genl_family_new_unsupported(nl, name, type_system);
                return -EOPNOTSUPP;
        }
        if (r < 0)
                /* Possibly transient: report it, but do not mark the family unsupported forever. */
                return r;

        return genl_family_new(nl, name, type_system, reply, ret);
}
285
/* Looks up the family 'name', first in the connection's cache and otherwise by querying the kernel
 * through the controller family (which is resolved, and cached, first if needed).
 * Returns -EOPNOTSUPP for families known to be unsupported (cached with id 0). */
static int genl_family_get_by_name(sd_netlink *nl, const char *name, const GenericNetlinkFamily **ret) {
        const GenericNetlinkFamily *f, *ctrl;
        int r;

        assert(nl);
        assert(nl->protocol == NETLINK_GENERIC);
        assert(name);
        assert(ret);

        f = hashmap_get(nl->genl_family_by_name, name);
        if (f) {
                if (f->id == 0) /* kernel does not support the family. */
                        return -EOPNOTSUPP;

                *ret = f;
                return 0;
        }

        /* The controller itself is resolved using the built-in static description. */
        if (streq(name, CTRL_GENL_NAME))
                return genl_family_get_by_name_internal(nl, &nlctrl_static, CTRL_GENL_NAME, ret);

        /* Resolve the controller through this very function, so that a cached controller entry is
         * honored including its negative (id == 0) state, instead of using it to send a request with
         * nlmsg_type 0. This recurses at most once: for CTRL_GENL_NAME one of the branches above
         * returns. */
        r = genl_family_get_by_name(nl, CTRL_GENL_NAME, &ctrl);
        if (r < 0)
                return r;

        return genl_family_get_by_name_internal(nl, ctrl, name, ret);
}
316
/* Finds the family with nlmsg_type 'id' among the families cached on the connection. The controller
 * (GENL_ID_CTRL) is always known, via the static entry, even before it has been cached.
 * Returns -ENOENT for any other unknown id. */
static int genl_family_get_by_id(sd_netlink *nl, uint16_t id, const GenericNetlinkFamily **ret) {
        const GenericNetlinkFamily *family;

        assert(nl);
        assert(nl->protocol == NETLINK_GENERIC);
        assert(ret);

        family = hashmap_get(nl->genl_family_by_id, UINT_TO_PTR(id));
        if (!family && id == GENL_ID_CTRL)
                family = &nlctrl_static;
        if (!family)
                return -ENOENT;

        *ret = family;
        return 0;
}
337
/* For the family with nlmsg_type 'id', optionally returns its type system and the total size of the
 * headers preceding the attributes (genlmsghdr plus the family's own header). Either output may be
 * NULL. Returns -ENOENT for unknown ids, and -EOPNOTSUPP if a type system was requested but none is
 * known. */
int genl_get_type_system_and_header_size(
                sd_netlink *nl,
                uint16_t id,
                const NLTypeSystem **ret_type_system,
                size_t *ret_header_size) {

        const GenericNetlinkFamily *family;
        int r;

        assert(nl);
        assert(nl->protocol == NETLINK_GENERIC);

        r = genl_family_get_by_id(nl, id, &family);
        if (r < 0)
                return r;

        /* Resolve the type system first: if that fails, neither output is written. */
        if (ret_type_system) {
                const NLTypeSystem *ts = genl_family_get_type_system(family);

                if (!ts)
                        return -EOPNOTSUPP;
                *ret_type_system = ts;
        }

        if (ret_header_size)
                *ret_header_size = sizeof(struct genlmsghdr) + family->additional_header_size;

        return 0;
}
367
/* Creates a new generic netlink message for the family called 'family_name' with command 'cmd'.
 * The family is resolved (and cached) on first use. */
int sd_genl_message_new(sd_netlink *nl, const char *family_name, uint8_t cmd, sd_netlink_message **ret) {
        const GenericNetlinkFamily *f;
        int r;

        assert_return(nl, -EINVAL);
        assert_return(nl->protocol == NETLINK_GENERIC, -EINVAL);
        assert_return(family_name, -EINVAL);
        assert_return(ret, -EINVAL);

        r = genl_family_get_by_name(nl, family_name, &f);

        return r < 0 ? r : genl_message_new(nl, f, cmd, ret);
}
383
/* Returns in *ret the name of the family that message 'm' belongs to, looked up from its nlmsg_type
 * among the families known to 'nl'. The string is owned by the family entry. */
int sd_genl_message_get_family_name(sd_netlink *nl, sd_netlink_message *m, const char **ret) {
        const GenericNetlinkFamily *f;
        uint16_t type;
        int r;

        assert_return(nl, -EINVAL);
        assert_return(nl->protocol == NETLINK_GENERIC, -EINVAL);
        assert_return(m, -EINVAL);
        assert_return(ret, -EINVAL);

        r = sd_netlink_message_get_type(m, &type);
        if (r >= 0)
                r = genl_family_get_by_id(nl, type, &f);
        if (r < 0)
                return r;

        *ret = f->name;
        return 0;
}
405
/* Returns in *ret the genl command (genlmsghdr.cmd) of message 'm'. Returns -EBADMSG if the message is
 * shorter than the headers its family requires. */
int sd_genl_message_get_command(sd_netlink *nl, sd_netlink_message *m, uint8_t *ret) {
        uint16_t type;
        size_t hdr_size;
        int r;

        assert_return(nl, -EINVAL);
        assert_return(nl->protocol == NETLINK_GENERIC, -EINVAL);
        assert_return(m, -EINVAL);
        assert_return(m->protocol == NETLINK_GENERIC, -EINVAL);
        assert_return(m->hdr, -EINVAL);
        assert_return(ret, -EINVAL);

        r = sd_netlink_message_get_type(m, &type);
        if (r < 0)
                return r;

        r = genl_get_type_system_and_header_size(nl, type, NULL, &hdr_size);
        if (r < 0)
                return r;

        /* Do not read past the end of a truncated message. */
        if (m->hdr->nlmsg_len < NLMSG_LENGTH(hdr_size))
                return -EBADMSG;

        *ret = ((const struct genlmsghdr *) NLMSG_DATA(m->hdr))->cmd;
        return 0;
}
435
genl_family_get_multicast_group_id_by_name(const GenericNetlinkFamily * f,const char * name,uint32_t * ret)436 static int genl_family_get_multicast_group_id_by_name(const GenericNetlinkFamily *f, const char *name, uint32_t *ret) {
437 void *p;
438
439 assert(f);
440 assert(name);
441
442 p = hashmap_get(f->multicast_group_by_name, name);
443 if (!p)
444 return -ENOENT;
445
446 if (ret)
447 *ret = PTR_TO_UINT32(p);
448 return 0;
449 }
450
/* Installs 'callback' for messages that family 'family_name' sends to its multicast group
 * 'multicast_group_name'. If command == 0, every command sent to that group triggers the callback;
 * otherwise only messages with that command do. */
int sd_genl_add_match(
                sd_netlink *nl,
                sd_netlink_slot **ret_slot,
                const char *family_name,
                const char *multicast_group_name,
                uint8_t command,
                sd_netlink_message_handler_t callback,
                sd_netlink_destroy_t destroy_callback,
                void *userdata,
                const char *description) {

        const GenericNetlinkFamily *family;
        uint32_t group_id;
        int r;

        assert_return(nl, -EINVAL);
        assert_return(nl->protocol == NETLINK_GENERIC, -EINVAL);
        assert_return(callback, -EINVAL);
        assert_return(family_name, -EINVAL);
        assert_return(multicast_group_name, -EINVAL);

        r = genl_family_get_by_name(nl, family_name, &family);
        if (r < 0)
                return r;

        r = genl_family_get_multicast_group_id_by_name(family, multicast_group_name, &group_id);
        if (r < 0)
                return r;

        /* Exactly one group, filtered by the family's nlmsg_type and the requested command. */
        return netlink_add_match_internal(nl, ret_slot, &group_id, 1, family->id, command,
                                          callback, destroy_callback, userdata, description);
}
485
/* Opens an sd_netlink connection for the NETLINK_GENERIC protocol by delegating to
 * netlink_open_family(); presumably the new connection is stored in *ret (confirm in
 * netlink_open_family()). */
int sd_genl_socket_open(sd_netlink **ret) {
        return netlink_open_family(ret, NETLINK_GENERIC);
}
489