1 /* SPDX-License-Identifier: LGPL-2.1-or-later */
2 
3 #include <linux/netlink.h>
4 
5 #include "netlink-genl.h"
6 #include "netlink-internal.h"
7 #include "netlink-types-internal.h"
8 
/* Attribute table with a single, unpopulated slot. The type system built from it is referenced by
 * basic_types[] below as the nested type system of NLMSG_DONE. */
static const NLType empty_types[1] = {
        /* fake array to avoid .types==NULL, which denotes invalid type-systems */
};

/* NOTE(review): presumably expands to the definition of empty_type_system wrapping empty_types[] —
 * confirm against the macro definition. */
DEFINE_TYPE_SYSTEM(empty);
14 
/* Attribute table indexed by NLMSGERR_ATTR_*: the message attribute is typed as a string and the
 * offset attribute as a 32-bit unsigned integer. The type system built from it is referenced by
 * basic_types[] below as the nested type system of NLMSG_ERROR. */
static const NLType error_types[] = {
        [NLMSGERR_ATTR_MSG]  = { .type = NETLINK_TYPE_STRING },
        [NLMSGERR_ATTR_OFFS] = { .type = NETLINK_TYPE_U32 },
};

/* NOTE(review): presumably expands to the definition of error_type_system wrapping error_types[] —
 * confirm against the macro definition. */
DEFINE_TYPE_SYSTEM(error);
21 
/* Table indexed by netlink message type, covering the two message types that
 * type_system_root_get_type_system_and_header_size() resolves without looking at the socket protocol:
 * NLMSG_DONE and NLMSG_ERROR. Both are nested types; the .size of the NLMSG_ERROR entry is what that
 * function reports as the header size for error messages. */
static const NLType basic_types[] = {
        [NLMSG_DONE]  = { .type = NETLINK_TYPE_NESTED, .type_system = &empty_type_system },
        [NLMSG_ERROR] = { .type = NETLINK_TYPE_NESTED, .type_system = &error_type_system, .size = sizeof(struct nlmsgerr) },
};

/* NOTE(review): presumably expands to the definition of basic_type_system wrapping basic_types[] —
 * confirm against the macro definition. */
DEFINE_TYPE_SYSTEM(basic);
28 
type_get_type(const NLType * type)29 uint16_t type_get_type(const NLType *type) {
30         assert(type);
31         return type->type;
32 }
33 
type_get_size(const NLType * type)34 size_t type_get_size(const NLType *type) {
35         assert(type);
36         return type->size;
37 }
38 
type_get_type_system(const NLType * nl_type)39 const NLTypeSystem *type_get_type_system(const NLType *nl_type) {
40         assert(nl_type);
41         assert(nl_type->type == NETLINK_TYPE_NESTED);
42         assert(nl_type->type_system);
43         return nl_type->type_system;
44 }
45 
type_get_type_system_union(const NLType * nl_type)46 const NLTypeSystemUnion *type_get_type_system_union(const NLType *nl_type) {
47         assert(nl_type);
48         assert(nl_type->type == NETLINK_TYPE_UNION);
49         assert(nl_type->type_system_union);
50         return nl_type->type_system_union;
51 }
52 
/* Resolves the top-level type system and header size for a netlink message type.
 *
 * NLMSG_DONE and NLMSG_ERROR are looked up in the protocol-independent basic table. Every other
 * message type is dispatched on the protocol of the socket: NETLINK_ROUTE and NETLINK_NETFILTER use
 * their respective lookup helpers, while NETLINK_GENERIC is delegated wholesale to
 * genl_get_type_system_and_header_size(), whose result is returned as is.
 *
 * ret_type_system and ret_header_size are only written on the non-generic path when non-NULL.
 *
 * Returns 0 on success, or -EOPNOTSUPP if the protocol is not handled, the message type is unknown,
 * or the resolved type is not of kind NETLINK_TYPE_NESTED. */
int type_system_root_get_type_system_and_header_size(
                sd_netlink *nl,
                uint16_t type,
                const NLTypeSystem **ret_type_system,
                size_t *ret_header_size) {

        const NLType *entry;

        assert(nl);

        if (type == NLMSG_DONE || type == NLMSG_ERROR)
                /* Control messages are common to all protocols. */
                entry = type_system_get_type(&basic_type_system, type);
        else if (nl->protocol == NETLINK_ROUTE)
                entry = rtnl_get_type(type);
        else if (nl->protocol == NETLINK_NETFILTER)
                entry = nfnl_get_type(type);
        else if (nl->protocol == NETLINK_GENERIC)
                return genl_get_type_system_and_header_size(nl, type, ret_type_system, ret_header_size);
        else
                return -EOPNOTSUPP;

        /* Unknown message types and non-nested top-level types are both unsupported. */
        if (!entry || type_get_type(entry) != NETLINK_TYPE_NESTED)
                return -EOPNOTSUPP;

        if (ret_type_system)
                *ret_type_system = type_get_type_system(entry);
        if (ret_header_size)
                *ret_header_size = type_get_size(entry);

        return 0;
}
90 
type_system_get_type(const NLTypeSystem * type_system,uint16_t type)91 const NLType *type_system_get_type(const NLTypeSystem *type_system, uint16_t type) {
92         const NLType *nl_type;
93 
94         assert(type_system);
95         assert(type_system->types);
96 
97         if (type >= type_system->count)
98                 return NULL;
99 
100         nl_type = &type_system->types[type];
101 
102         if (nl_type->type == NETLINK_TYPE_UNSPEC)
103                 return NULL;
104 
105         return nl_type;
106 }
107 
type_system_get_type_system(const NLTypeSystem * type_system,uint16_t type)108 const NLTypeSystem *type_system_get_type_system(const NLTypeSystem *type_system, uint16_t type) {
109         const NLType *nl_type;
110 
111         nl_type = type_system_get_type(type_system, type);
112         if (!nl_type)
113                 return NULL;
114 
115         return type_get_type_system(nl_type);
116 }
117 
type_system_get_type_system_union(const NLTypeSystem * type_system,uint16_t type)118 const NLTypeSystemUnion *type_system_get_type_system_union(const NLTypeSystem *type_system, uint16_t type) {
119         const NLType *nl_type;
120 
121         nl_type = type_system_get_type(type_system, type);
122         if (!nl_type)
123                 return NULL;
124 
125         return type_get_type_system_union(nl_type);
126 }
127 
type_system_union_get_match_type(const NLTypeSystemUnion * type_system_union)128 NLMatchType type_system_union_get_match_type(const NLTypeSystemUnion *type_system_union) {
129         assert(type_system_union);
130         return type_system_union->match_type;
131 }
132 
type_system_union_get_match_attribute(const NLTypeSystemUnion * type_system_union)133 uint16_t type_system_union_get_match_attribute(const NLTypeSystemUnion *type_system_union) {
134         assert(type_system_union);
135         assert(type_system_union->match_type == NL_MATCH_SIBLING);
136         return type_system_union->match_attribute;
137 }
138 
type_system_union_get_type_system_by_string(const NLTypeSystemUnion * type_system_union,const char * key)139 const NLTypeSystem *type_system_union_get_type_system_by_string(const NLTypeSystemUnion *type_system_union, const char *key) {
140         assert(type_system_union);
141         assert(type_system_union->elements);
142         assert(type_system_union->match_type == NL_MATCH_SIBLING);
143         assert(key);
144 
145         for (size_t i = 0; i < type_system_union->count; i++)
146                 if (streq(type_system_union->elements[i].name, key))
147                         return &type_system_union->elements[i].type_system;
148 
149         return NULL;
150 }
151 
type_system_union_get_type_system_by_protocol(const NLTypeSystemUnion * type_system_union,uint16_t protocol)152 const NLTypeSystem *type_system_union_get_type_system_by_protocol(const NLTypeSystemUnion *type_system_union, uint16_t protocol) {
153         assert(type_system_union);
154         assert(type_system_union->elements);
155         assert(type_system_union->match_type == NL_MATCH_PROTOCOL);
156 
157         for (size_t i = 0; i < type_system_union->count; i++)
158                 if (type_system_union->elements[i].protocol == protocol)
159                         return &type_system_union->elements[i].type_system;
160 
161         return NULL;
162 }
163