1 /* SPDX-License-Identifier: LGPL-2.1-or-later */
2 
3 #include <linux/genetlink.h>
4 
5 #include "sd-netlink.h"
6 
7 #include "alloc-util.h"
8 #include "netlink-genl.h"
9 #include "netlink-internal.h"
10 #include "netlink-types.h"
11 
/* Cached description of one generic netlink family, as referenced from the
 * genl_family_by_id and genl_family_by_name hashmaps of a netlink object. */
typedef struct GenericNetlinkFamily {
        sd_netlink *genl; /* netlink object whose hashmaps reference this entry; NULL until registered */

        const NLTypeSystem *type_system; /* if NULL, the type system is looked up by family name on use */

        uint16_t id; /* a.k.a nlmsg_type; zero marks a family the kernel does not support */
        char *name; /* family name, also used as the key in genl_family_by_name */
        uint32_t version; /* copied into genlmsghdr.version of newly created messages */
        uint32_t additional_header_size; /* extra header bytes added on top of sizeof(struct genlmsghdr) */
        Hashmap *multicast_group_by_name; /* multicast group name -> group id (stored as pointer value) */
} GenericNetlinkFamily;
23 
/* Built-in description of the generic netlink controller family. It is used to send the
 * CTRL_CMD_GETFAMILY request before any family has been resolved, and as the fallback
 * when GENL_ID_CTRL is looked up by id but is not cached. Its type_system is left unset,
 * hence the type system is resolved by name when needed. */
static const GenericNetlinkFamily nlctrl_static = {
        .id = GENL_ID_CTRL,
        .name = (char*) CTRL_GENL_NAME,
        .version = 0x01,
};
29 
genl_family_free(GenericNetlinkFamily * f)30 static GenericNetlinkFamily *genl_family_free(GenericNetlinkFamily *f) {
31         if (!f)
32                 return NULL;
33 
34         if (f->genl) {
35                 if (f->id > 0)
36                         hashmap_remove(f->genl->genl_family_by_id, UINT_TO_PTR(f->id));
37                 if (f->name)
38                         hashmap_remove(f->genl->genl_family_by_name, f->name);
39         }
40 
41         free(f->name);
42         hashmap_free(f->multicast_group_by_name);
43 
44         return mfree(f);
45 }
46 
47 DEFINE_TRIVIAL_CLEANUP_FUNC(GenericNetlinkFamily*, genl_family_free);
48 
/* Drops all cached generic netlink families of 'nl' and releases both lookup hashmaps,
 * storing the value returned by hashmap_free_with_destructor() back into each field.
 *
 * genl_family_free() is used as the destructor. It also unregisters the family from the
 * hashmaps of its owner (from genl_family_by_id only when the id is non-zero), so an entry
 * destroyed through one map is taken out of the other map as well. */
void genl_clear_family(sd_netlink *nl) {
        assert(nl);

        nl->genl_family_by_name = hashmap_free_with_destructor(nl->genl_family_by_name, genl_family_free);
        nl->genl_family_by_id = hashmap_free_with_destructor(nl->genl_family_by_id, genl_family_free);
}
55 
/* Records that a family is not available. A placeholder entry with id zero is stored in
 * the by-name map of 'nl', so that later lookups of the same name can be answered from
 * the cache instead of asking again.
 *
 * Returns 0 on success, -ENOMEM on allocation failure, or the negative error returned by
 * hashmap_ensure_put(). */
static int genl_family_new_unsupported(
                sd_netlink *nl,
                const char *family_name,
                const NLTypeSystem *type_system) {

        _cleanup_(genl_family_freep) GenericNetlinkFamily *entry = NULL;
        int r;

        assert(nl);
        assert(family_name);
        assert(type_system);

        entry = new(GenericNetlinkFamily, 1);
        if (!entry)
                return -ENOMEM;

        /* All fields not listed here, including the id, start out as zero. */
        *entry = (GenericNetlinkFamily) {
                .type_system = type_system,
                .name = strdup(family_name),
        };
        if (!entry->name)
                return -ENOMEM;

        r = hashmap_ensure_put(&nl->genl_family_by_name, &string_hash_ops, entry->name, entry);
        if (r < 0)
                return r;

        /* The entry is referenced by the hashmap now: remember the owner and disarm the
         * cleanup handler. */
        entry->genl = nl;
        TAKE_PTR(entry);
        return 0;
}
91 
/* Builds a GenericNetlinkFamily from a CTRL_CMD_NEWFAMILY reply and registers it in the
 * by-id and by-name hashmaps of 'nl'.
 *
 *   nl                    generic netlink object that will reference the new entry
 *   expected_family_name  name that was requested; the reply must carry the same name
 *   type_system           type system to associate with the family
 *   message               reply received for the CTRL_CMD_GETFAMILY request
 *   ret                   on success, set to the newly registered entry
 *
 * Returns 0 on success. Returns -EINVAL if the message does not belong to the controller
 * family, is not a CTRL_CMD_NEWFAMILY message, carries an unexpected family name, or
 * carries a zero family id. Returns -ENOMEM on allocation failure. Otherwise returns the
 * negative error of the failing attribute read or hashmap insertion. On failure nothing
 * is left registered and the partially built entry is freed. */
static int genl_family_new(
                sd_netlink *nl,
                const char *expected_family_name,
                const NLTypeSystem *type_system,
                sd_netlink_message *message,
                const GenericNetlinkFamily **ret) {

        _cleanup_(genl_family_freep) GenericNetlinkFamily *f = NULL;
        const char *family_name;
        uint8_t cmd;
        int r;

        assert(nl);
        assert(expected_family_name);
        assert(type_system);
        assert(message);
        assert(ret);

        f = new(GenericNetlinkFamily, 1);
        if (!f)
                return -ENOMEM;

        *f = (GenericNetlinkFamily) {
                .type_system = type_system,
        };

        /* The reply itself must be a message of the controller family ... */
        r = sd_genl_message_get_family_name(nl, message, &family_name);
        if (r < 0)
                return r;

        if (!streq(family_name, CTRL_GENL_NAME))
                return -EINVAL;

        /* ... and must announce a family. */
        r = sd_genl_message_get_command(nl, message, &cmd);
        if (r < 0)
                return r;

        if (cmd != CTRL_CMD_NEWFAMILY)
                return -EINVAL;

        r = sd_netlink_message_read_u16(message, CTRL_ATTR_FAMILY_ID, &f->id);
        if (r < 0)
                return r;

        /* Zero is reserved in the cache as the marker for "not supported by the kernel", and
         * genl_family_free() does not unregister zero ids from genl_family_by_id. Accepting it
         * here would leave a dangling entry in that map once the family is freed, so refuse
         * such a reply. */
        if (f->id == 0)
                return -EINVAL;

        r = sd_netlink_message_read_string_strdup(message, CTRL_ATTR_FAMILY_NAME, &f->name);
        if (r < 0)
                return r;

        if (!streq(f->name, expected_family_name))
                return -EINVAL;

        r = sd_netlink_message_read_u32(message, CTRL_ATTR_VERSION, &f->version);
        if (r < 0)
                return r;

        r = sd_netlink_message_read_u32(message, CTRL_ATTR_HDRSIZE, &f->additional_header_size);
        if (r < 0)
                return r;

        /* The multicast group list is optional: if the container cannot be entered, the
         * family simply has no groups recorded. */
        r = sd_netlink_message_enter_container(message, CTRL_ATTR_MCAST_GROUPS);
        if (r >= 0) {
                for (uint16_t i = 0; i < UINT16_MAX; i++) {
                        _cleanup_free_ char *group_name = NULL;
                        uint32_t group_id;

                        /* Array elements are numbered starting from 1; -ENODATA marks the end. */
                        r = sd_netlink_message_enter_array(message, i + 1);
                        if (r == -ENODATA)
                                break;
                        if (r < 0)
                                return r;

                        r = sd_netlink_message_read_u32(message, CTRL_ATTR_MCAST_GRP_ID, &group_id);
                        if (r < 0)
                                return r;

                        r = sd_netlink_message_read_string_strdup(message, CTRL_ATTR_MCAST_GRP_NAME, &group_name);
                        if (r < 0)
                                return r;

                        r = sd_netlink_message_exit_container(message);
                        if (r < 0)
                                return r;

                        if (group_id == 0) {
                                log_debug("sd-netlink: received multicast group '%s' for generic netlink family '%s' with id == 0, ignoring",
                                          group_name, f->name);
                                continue;
                        }

                        r = hashmap_ensure_put(&f->multicast_group_by_name, &string_hash_ops_free, group_name, UINT32_TO_PTR(group_id));
                        if (r < 0)
                                return r;

                        /* The hashmap owns the key string now. */
                        TAKE_PTR(group_name);
                }

                r = sd_netlink_message_exit_container(message);
                if (r < 0)
                        return r;
        }

        r = hashmap_ensure_put(&nl->genl_family_by_id, NULL, UINT_TO_PTR(f->id), f);
        if (r < 0)
                return r;

        r = hashmap_ensure_put(&nl->genl_family_by_name, &string_hash_ops, f->name, f);
        if (r < 0) {
                /* f->genl is not set yet, so genl_family_free() will not undo the first
                 * insertion for us. */
                hashmap_remove(nl->genl_family_by_id, UINT_TO_PTR(f->id));
                return r;
        }

        f->genl = nl;
        *ret = TAKE_PTR(f);
        return 0;
}
207 
genl_family_get_type_system(const GenericNetlinkFamily * family)208 static const NLTypeSystem *genl_family_get_type_system(const GenericNetlinkFamily *family) {
209         assert(family);
210 
211         if (family->type_system)
212                 return family->type_system;
213 
214         return genl_get_type_system_by_name(family->name);
215 }
216 
/* Allocates a new message for the given family and fills in its generic netlink header
 * with the requested command and the family version.
 *
 * Returns 0 and stores the message in *ret on success, -EOPNOTSUPP if no type system is
 * available for the family, or the negative error of message_new_full(). */
static int genl_message_new(
                sd_netlink *nl,
                const GenericNetlinkFamily *family,
                uint8_t cmd,
                sd_netlink_message **ret) {

        _cleanup_(sd_netlink_message_unrefp) sd_netlink_message *msg = NULL;
        const NLTypeSystem *ts;
        struct genlmsghdr *genl_hdr;
        size_t header_size;
        int r;

        assert(nl);
        assert(nl->protocol == NETLINK_GENERIC);
        assert(family);
        assert(ret);

        ts = genl_family_get_type_system(family);
        if (!ts)
                return -EOPNOTSUPP;

        /* Generic netlink header plus whatever extra header the family declares. */
        header_size = sizeof(struct genlmsghdr) + family->additional_header_size;

        r = message_new_full(nl, family->id, ts, header_size, &msg);
        if (r < 0)
                return r;

        genl_hdr = NLMSG_DATA(msg->hdr);
        *genl_hdr = (struct genlmsghdr) {
                .cmd = cmd,
                .version = family->version,
        };

        *ret = TAKE_PTR(msg);
        return 0;
}
249 
/* Resolves a family by sending a CTRL_CMD_GETFAMILY request, built with the controller
 * description 'ctrl', and caching the answer.
 *
 * Returns the result of genl_family_new() when a reply was received. Returns -EOPNOTSUPP
 * when no type system is known for 'name', or when the request fails; in the latter case
 * a placeholder entry is cached on a best-effort basis. Errors while building the request
 * are returned as they are. */
static int genl_family_get_by_name_internal(
                sd_netlink *nl,
                const GenericNetlinkFamily *ctrl,
                const char *name,
                const GenericNetlinkFamily **ret) {

        _cleanup_(sd_netlink_message_unrefp) sd_netlink_message *req = NULL, *reply = NULL;
        const NLTypeSystem *ts;
        int r;

        assert(nl);
        assert(nl->protocol == NETLINK_GENERIC);
        assert(ctrl);
        assert(name);
        assert(ret);

        /* Without a type system we could not make use of the family anyway. */
        ts = genl_get_type_system_by_name(name);
        if (!ts)
                return -EOPNOTSUPP;

        r = genl_message_new(nl, ctrl, CTRL_CMD_GETFAMILY, &req);
        if (r < 0)
                return r;

        r = sd_netlink_message_append_string(req, CTRL_ATTR_FAMILY_NAME, name);
        if (r < 0)
                return r;

        r = sd_netlink_call(nl, req, 0, &reply);
        if (r < 0) {
                /* Any failure of the call is treated as "family not supported". Failing to
                 * cache that fact is not fatal, hence the result is ignored. */
                (void) genl_family_new_unsupported(nl, name, ts);
                return -EOPNOTSUPP;
        }

        return genl_family_new(nl, name, ts, reply, ret);
}
285 
/* Looks up a family by name, consulting the cache first and asking the controller family
 * otherwise.
 *
 * Returns 0 and stores the entry in *ret on success, -EOPNOTSUPP if the cache holds a
 * placeholder with id zero for that name, or the negative error of the resolution. */
static int genl_family_get_by_name(sd_netlink *nl, const char *name, const GenericNetlinkFamily **ret) {
        const GenericNetlinkFamily *cached, *ctrl = &nlctrl_static;
        int r;

        assert(nl);
        assert(nl->protocol == NETLINK_GENERIC);
        assert(name);
        assert(ret);

        cached = hashmap_get(nl->genl_family_by_name, name);
        if (cached) {
                /* A zero id marks a family the kernel does not support. */
                if (cached->id == 0)
                        return -EOPNOTSUPP;

                *ret = cached;
                return 0;
        }

        /* The controller family itself is always resolved through the built-in description.
         * For every other family prefer the resolved controller entry, resolving it first
         * if it is not cached yet. */
        if (!streq(name, CTRL_GENL_NAME)) {
                ctrl = hashmap_get(nl->genl_family_by_name, CTRL_GENL_NAME);
                if (!ctrl) {
                        r = genl_family_get_by_name_internal(nl, &nlctrl_static, CTRL_GENL_NAME, &ctrl);
                        if (r < 0)
                                return r;
                }
        }

        return genl_family_get_by_name_internal(nl, ctrl, name, ret);
}
316 
/* Looks up a cached family by id. GENL_ID_CTRL falls back to the built-in controller
 * description when it is not cached. No request is sent.
 *
 * Returns 0 and stores the entry in *ret on success, -ENOENT if the id is unknown. */
static int genl_family_get_by_id(sd_netlink *nl, uint16_t id, const GenericNetlinkFamily **ret) {
        const GenericNetlinkFamily *found;

        assert(nl);
        assert(nl->protocol == NETLINK_GENERIC);
        assert(ret);

        found = hashmap_get(nl->genl_family_by_id, UINT_TO_PTR(id));
        if (!found) {
                if (id != GENL_ID_CTRL)
                        return -ENOENT;

                found = &nlctrl_static;
        }

        *ret = found;
        return 0;
}
337 
/* Reports the type system and/or the total header size of the family with the given id.
 * Both output arguments are optional.
 *
 * Returns 0 on success, the negative error of the by-id lookup, or -EOPNOTSUPP if the type
 * system was requested but none is available; in that case neither output is written. */
int genl_get_type_system_and_header_size(
                sd_netlink *nl,
                uint16_t id,
                const NLTypeSystem **ret_type_system,
                size_t *ret_header_size) {

        const GenericNetlinkFamily *family;
        int r;

        assert(nl);
        assert(nl->protocol == NETLINK_GENERIC);

        r = genl_family_get_by_id(nl, id, &family);
        if (r < 0)
                return r;

        if (ret_type_system) {
                const NLTypeSystem *ts = genl_family_get_type_system(family);

                if (!ts)
                        return -EOPNOTSUPP;

                *ret_type_system = ts;
        }

        if (ret_header_size)
                *ret_header_size = sizeof(struct genlmsghdr) + family->additional_header_size;

        return 0;
}
367 
/* Creates a new generic netlink message for the named family and command.
 *
 * Returns -EINVAL on invalid arguments, the negative error of the family lookup, or
 * otherwise the result of genl_message_new(). */
int sd_genl_message_new(sd_netlink *nl, const char *family_name, uint8_t cmd, sd_netlink_message **ret) {
        const GenericNetlinkFamily *resolved;
        int r;

        assert_return(nl, -EINVAL);
        assert_return(nl->protocol == NETLINK_GENERIC, -EINVAL);
        assert_return(family_name, -EINVAL);
        assert_return(ret, -EINVAL);

        r = genl_family_get_by_name(nl, family_name, &resolved);
        if (r >= 0)
                r = genl_message_new(nl, resolved, cmd, ret);

        return r;
}
383 
/* Determines the generic netlink family name of a message from its message type.
 *
 *   nl   generic netlink object whose family cache is consulted
 *   m    message to inspect; must be a generic netlink message
 *   ret  on success, set to the family name; the string belongs to the family entry (or
 *        to the built-in controller description) and must not be freed by the caller
 *
 * Returns 0 on success, -EINVAL on invalid arguments, or the negative error of reading
 * the message type or of the by-id family lookup. */
int sd_genl_message_get_family_name(sd_netlink *nl, sd_netlink_message *m, const char **ret) {
        const GenericNetlinkFamily *family;
        uint16_t nlmsg_type;
        int r;

        assert_return(nl, -EINVAL);
        assert_return(nl->protocol == NETLINK_GENERIC, -EINVAL);
        assert_return(m, -EINVAL);
        /* Message types are only meaningful within their protocol. Without this check a
         * non-generic message whose type happens to equal GENL_ID_CTRL would be reported as
         * the controller family. sd_genl_message_get_command() applies the same check. */
        assert_return(m->protocol == NETLINK_GENERIC, -EINVAL);
        assert_return(ret, -EINVAL);

        r = sd_netlink_message_get_type(m, &nlmsg_type);
        if (r < 0)
                return r;

        r = genl_family_get_by_id(nl, nlmsg_type, &family);
        if (r < 0)
                return r;

        *ret = family->name;
        return 0;
}
405 
/* Reads the command byte from the generic netlink header of a message.
 *
 * Returns 0 and stores the command in *ret on success, -EINVAL on invalid arguments,
 * -EBADMSG if the message is shorter than the header expected for its family, or the
 * negative error of reading the message type or of the family lookup. */
int sd_genl_message_get_command(sd_netlink *nl, sd_netlink_message *m, uint8_t *ret) {
        const struct genlmsghdr *genl_hdr;
        size_t header_size;
        uint16_t type;
        int r;

        assert_return(nl, -EINVAL);
        assert_return(nl->protocol == NETLINK_GENERIC, -EINVAL);
        assert_return(m, -EINVAL);
        assert_return(m->protocol == NETLINK_GENERIC, -EINVAL);
        assert_return(m->hdr, -EINVAL);
        assert_return(ret, -EINVAL);

        r = sd_netlink_message_get_type(m, &type);
        if (r < 0)
                return r;

        /* Only the header size is needed here, not the type system. */
        r = genl_get_type_system_and_header_size(nl, type, NULL, &header_size);
        if (r < 0)
                return r;

        /* Make sure the header is fully contained in the message before touching it. */
        if (m->hdr->nlmsg_len < NLMSG_LENGTH(header_size))
                return -EBADMSG;

        genl_hdr = NLMSG_DATA(m->hdr);
        *ret = genl_hdr->cmd;
        return 0;
}
435 
genl_family_get_multicast_group_id_by_name(const GenericNetlinkFamily * f,const char * name,uint32_t * ret)436 static int genl_family_get_multicast_group_id_by_name(const GenericNetlinkFamily *f, const char *name, uint32_t *ret) {
437         void *p;
438 
439         assert(f);
440         assert(name);
441 
442         p = hashmap_get(f->multicast_group_by_name, name);
443         if (!p)
444                 return -ENOENT;
445 
446         if (ret)
447                 *ret = PTR_TO_UINT32(p);
448         return 0;
449 }
450 
/* Installs a callback for messages of one multicast group of a generic netlink family.
 * A command of zero matches every command of the group.
 *
 * Returns -EINVAL on invalid arguments, the negative error of the family or multicast
 * group lookup, or otherwise the result of netlink_add_match_internal(). */
int sd_genl_add_match(
                sd_netlink *nl,
                sd_netlink_slot **ret_slot,
                const char *family_name,
                const char *multicast_group_name,
                uint8_t command,
                sd_netlink_message_handler_t callback,
                sd_netlink_destroy_t destroy_callback,
                void *userdata,
                const char *description) {

        const GenericNetlinkFamily *family;
        uint32_t group_id;
        int r;

        assert_return(nl, -EINVAL);
        assert_return(nl->protocol == NETLINK_GENERIC, -EINVAL);
        assert_return(callback, -EINVAL);
        assert_return(family_name, -EINVAL);
        assert_return(multicast_group_name, -EINVAL);

        r = genl_family_get_by_name(nl, family_name, &family);
        if (r < 0)
                return r;

        r = genl_family_get_multicast_group_id_by_name(family, multicast_group_name, &group_id);
        if (r < 0)
                return r;

        /* Exactly one multicast group is passed. */
        return netlink_add_match_internal(
                        nl, ret_slot,
                        &group_id, 1,
                        family->id, command,
                        callback, destroy_callback, userdata, description);
}
485 
/* Opens a netlink object for the NETLINK_GENERIC protocol. This simply forwards 'ret' to
 * netlink_open_family() and returns its result. */
int sd_genl_socket_open(sd_netlink **ret) {
        return netlink_open_family(ret, NETLINK_GENERIC);
}
489