1 // SPDX-License-Identifier: GPL-2.0
2 /* Copyright (c) 2020, Tessares SA. */
3 /* Copyright (c) 2022, SUSE. */
4
5 #include <linux/const.h>
6 #include <netinet/in.h>
7 #include <test_progs.h>
8 #include "cgroup_helpers.h"
9 #include "network_helpers.h"
10 #include "mptcp_sock.skel.h"
11 #include "mptcpify.skel.h"
12
/* Name of the network namespace created and removed by the helpers below. */
#define NS_TEST "mptcp_ns"

/*
 * Fallback definitions: each constant is only defined here when the headers
 * included above have not already provided it.
 */
#if !defined(IPPROTO_MPTCP)
#define IPPROTO_MPTCP 262
#endif

/* getsockopt() level and option name used to query MPTCP information. */
#if !defined(SOL_MPTCP)
#define SOL_MPTCP 284
#endif
#if !defined(MPTCP_INFO)
#define MPTCP_INFO 1
#endif

/* Bits tested in the mptcpi_flags member of the MPTCP_INFO result. */
#if !defined(MPTCP_INFO_FLAG_FALLBACK)
#define MPTCP_INFO_FLAG_FALLBACK _BITUL(0)
#endif
#if !defined(MPTCP_INFO_FLAG_REMOTE_KEY_RECEIVED)
#define MPTCP_INFO_FLAG_REMOTE_KEY_RECEIVED _BITUL(1)
#endif

/* Size of the buffers holding a congestion-control algorithm name. */
#if !defined(TCP_CA_NAME_MAX)
#define TCP_CA_NAME_MAX 16
#endif
35
/*
 * Local buffer layout for getsockopt(SOL_MPTCP, MPTCP_INFO).
 *
 * NOTE(review): presumably mirrors the kernel's struct mptcp_info; member
 * order and widths must therefore stay exactly as they are - confirm against
 * the UAPI header before touching them.
 */
struct __mptcp_info {
	__u8	mptcpi_subflows,
		mptcpi_add_addr_signal,
		mptcpi_add_addr_accepted,
		mptcpi_subflows_max,
		mptcpi_add_addr_signal_max,
		mptcpi_add_addr_accepted_max;
	__u32	mptcpi_flags,	/* tested against MPTCP_INFO_FLAG_* bits */
		mptcpi_token;
	__u64	mptcpi_write_seq,
		mptcpi_snd_una,
		mptcpi_rcv_nxt;
	__u8	mptcpi_local_addr_used,
		mptcpi_local_addr_max,
		mptcpi_csum_enabled;
	__u32	mptcpi_retransmits;
	__u64	mptcpi_bytes_retrans,
		mptcpi_bytes_sent,
		mptcpi_bytes_received,
		mptcpi_bytes_acked;
};
57
58 struct mptcp_storage {
59 __u32 invoked;
60 __u32 is_mptcp;
61 struct sock *sk;
62 __u32 token;
63 struct sock *first;
64 char ca_name[TCP_CA_NAME_MAX];
65 };
66
create_netns(void)67 static struct nstoken *create_netns(void)
68 {
69 SYS(fail, "ip netns add %s", NS_TEST);
70 SYS(fail, "ip -net %s link set dev lo up", NS_TEST);
71
72 return open_netns(NS_TEST);
73 fail:
74 return NULL;
75 }
76
cleanup_netns(struct nstoken * nstoken)77 static void cleanup_netns(struct nstoken *nstoken)
78 {
79 if (nstoken)
80 close_netns(nstoken);
81
82 SYS_NOFAIL("ip netns del %s &> /dev/null", NS_TEST);
83 }
84
verify_tsk(int map_fd,int client_fd)85 static int verify_tsk(int map_fd, int client_fd)
86 {
87 int err, cfd = client_fd;
88 struct mptcp_storage val;
89
90 err = bpf_map_lookup_elem(map_fd, &cfd, &val);
91 if (!ASSERT_OK(err, "bpf_map_lookup_elem"))
92 return err;
93
94 if (!ASSERT_EQ(val.invoked, 1, "unexpected invoked count"))
95 err++;
96
97 if (!ASSERT_EQ(val.is_mptcp, 0, "unexpected is_mptcp"))
98 err++;
99
100 return err;
101 }
102
get_msk_ca_name(char ca_name[])103 static void get_msk_ca_name(char ca_name[])
104 {
105 size_t len;
106 int fd;
107
108 fd = open("/proc/sys/net/ipv4/tcp_congestion_control", O_RDONLY);
109 if (!ASSERT_GE(fd, 0, "failed to open tcp_congestion_control"))
110 return;
111
112 len = read(fd, ca_name, TCP_CA_NAME_MAX);
113 if (!ASSERT_GT(len, 0, "failed to read ca_name"))
114 goto err;
115
116 if (len > 0 && ca_name[len - 1] == '\n')
117 ca_name[len - 1] = '\0';
118
119 err:
120 close(fd);
121 }
122
verify_msk(int map_fd,int client_fd,__u32 token)123 static int verify_msk(int map_fd, int client_fd, __u32 token)
124 {
125 char ca_name[TCP_CA_NAME_MAX];
126 int err, cfd = client_fd;
127 struct mptcp_storage val;
128
129 if (!ASSERT_GT(token, 0, "invalid token"))
130 return -1;
131
132 get_msk_ca_name(ca_name);
133
134 err = bpf_map_lookup_elem(map_fd, &cfd, &val);
135 if (!ASSERT_OK(err, "bpf_map_lookup_elem"))
136 return err;
137
138 if (!ASSERT_EQ(val.invoked, 1, "unexpected invoked count"))
139 err++;
140
141 if (!ASSERT_EQ(val.is_mptcp, 1, "unexpected is_mptcp"))
142 err++;
143
144 if (!ASSERT_EQ(val.token, token, "unexpected token"))
145 err++;
146
147 if (!ASSERT_EQ(val.first, val.sk, "unexpected first"))
148 err++;
149
150 if (!ASSERT_STRNEQ(val.ca_name, ca_name, TCP_CA_NAME_MAX, "unexpected ca_name"))
151 err++;
152
153 return err;
154 }
155
run_test(int cgroup_fd,int server_fd,bool is_mptcp)156 static int run_test(int cgroup_fd, int server_fd, bool is_mptcp)
157 {
158 int client_fd, prog_fd, map_fd, err;
159 struct mptcp_sock *sock_skel;
160
161 sock_skel = mptcp_sock__open_and_load();
162 if (!ASSERT_OK_PTR(sock_skel, "skel_open_load"))
163 return libbpf_get_error(sock_skel);
164
165 err = mptcp_sock__attach(sock_skel);
166 if (!ASSERT_OK(err, "skel_attach"))
167 goto out;
168
169 prog_fd = bpf_program__fd(sock_skel->progs._sockops);
170 map_fd = bpf_map__fd(sock_skel->maps.socket_storage_map);
171 err = bpf_prog_attach(prog_fd, cgroup_fd, BPF_CGROUP_SOCK_OPS, 0);
172 if (!ASSERT_OK(err, "bpf_prog_attach"))
173 goto out;
174
175 client_fd = connect_to_fd(server_fd, 0);
176 if (!ASSERT_GE(client_fd, 0, "connect to fd")) {
177 err = -EIO;
178 goto out;
179 }
180
181 err += is_mptcp ? verify_msk(map_fd, client_fd, sock_skel->bss->token) :
182 verify_tsk(map_fd, client_fd);
183
184 close(client_fd);
185
186 out:
187 mptcp_sock__destroy(sock_skel);
188 return err;
189 }
190
test_base(void)191 static void test_base(void)
192 {
193 struct nstoken *nstoken = NULL;
194 int server_fd, cgroup_fd;
195
196 cgroup_fd = test__join_cgroup("/mptcp");
197 if (!ASSERT_GE(cgroup_fd, 0, "test__join_cgroup"))
198 return;
199
200 nstoken = create_netns();
201 if (!ASSERT_OK_PTR(nstoken, "create_netns"))
202 goto fail;
203
204 /* without MPTCP */
205 server_fd = start_server(AF_INET, SOCK_STREAM, NULL, 0, 0);
206 if (!ASSERT_GE(server_fd, 0, "start_server"))
207 goto with_mptcp;
208
209 ASSERT_OK(run_test(cgroup_fd, server_fd, false), "run_test tcp");
210
211 close(server_fd);
212
213 with_mptcp:
214 /* with MPTCP */
215 server_fd = start_mptcp_server(AF_INET, NULL, 0, 0);
216 if (!ASSERT_GE(server_fd, 0, "start_mptcp_server"))
217 goto fail;
218
219 ASSERT_OK(run_test(cgroup_fd, server_fd, true), "run_test mptcp");
220
221 close(server_fd);
222
223 fail:
224 cleanup_netns(nstoken);
225 close(cgroup_fd);
226 }
227
/*
 * Write a single 0x55 byte to @fd and assert that exactly one byte was
 * written.
 */
static void send_byte(int fd)
{
	char payload = 0x55;
	ssize_t written;

	written = write(fd, &payload, sizeof(payload));
	ASSERT_EQ(written, 1, "send single byte");
}
234
verify_mptcpify(int server_fd,int client_fd)235 static int verify_mptcpify(int server_fd, int client_fd)
236 {
237 struct __mptcp_info info;
238 socklen_t optlen;
239 int protocol;
240 int err = 0;
241
242 optlen = sizeof(protocol);
243 if (!ASSERT_OK(getsockopt(server_fd, SOL_SOCKET, SO_PROTOCOL, &protocol, &optlen),
244 "getsockopt(SOL_PROTOCOL)"))
245 return -1;
246
247 if (!ASSERT_EQ(protocol, IPPROTO_MPTCP, "protocol isn't MPTCP"))
248 err++;
249
250 optlen = sizeof(info);
251 if (!ASSERT_OK(getsockopt(client_fd, SOL_MPTCP, MPTCP_INFO, &info, &optlen),
252 "getsockopt(MPTCP_INFO)"))
253 return -1;
254
255 if (!ASSERT_GE(info.mptcpi_flags, 0, "unexpected mptcpi_flags"))
256 err++;
257 if (!ASSERT_FALSE(info.mptcpi_flags & MPTCP_INFO_FLAG_FALLBACK,
258 "MPTCP fallback"))
259 err++;
260 if (!ASSERT_TRUE(info.mptcpi_flags & MPTCP_INFO_FLAG_REMOTE_KEY_RECEIVED,
261 "no remote key received"))
262 err++;
263
264 return err;
265 }
266
/*
 * Load and attach the mptcpify skeleton, start a server and connect a client
 * using the plain TCP helpers, send one byte and check the result with
 * verify_mptcpify().
 *
 * @cgroup_fd: not used by this function.
 *
 * Returns 0 on success, a negative error when setup fails, or the result of
 * verify_mptcpify().
 */
static int run_mptcpify(int cgroup_fd)
{
	struct mptcpify *skel;
	int listen_fd, conn_fd;
	int ret = 0;

	skel = mptcpify__open_and_load();
	if (!ASSERT_OK_PTR(skel, "skel_open_load"))
		return libbpf_get_error(skel);

	ret = mptcpify__attach(skel);
	if (!ASSERT_OK(ret, "skel_attach"))
		goto destroy_skel;

	/* without MPTCP */
	listen_fd = start_server(AF_INET, SOCK_STREAM, NULL, 0, 0);
	if (!ASSERT_GE(listen_fd, 0, "start_server")) {
		ret = -EIO;
		goto destroy_skel;
	}

	conn_fd = connect_to_fd(listen_fd, 0);
	if (ASSERT_GE(conn_fd, 0, "connect to fd")) {
		send_byte(conn_fd);
		ret = verify_mptcpify(listen_fd, conn_fd);
		close(conn_fd);
	} else {
		ret = -EIO;
	}

	close(listen_fd);

destroy_skel:
	mptcpify__destroy(skel);
	return ret;
}
304
/*
 * "mptcpify" subtest: join the "/mptcpify" cgroup, set up the test namespace
 * and run run_mptcpify(). The namespace and cgroup fd are released whether or
 * not the run happened.
 */
static void test_mptcpify(void)
{
	struct nstoken *ns = NULL;
	int cgrp;

	cgrp = test__join_cgroup("/mptcpify");
	if (!ASSERT_GE(cgrp, 0, "test__join_cgroup"))
		return;

	ns = create_netns();
	if (ASSERT_OK_PTR(ns, "create_netns"))
		ASSERT_OK(run_mptcpify(cgrp), "run_mptcpify");

	cleanup_netns(ns);
	close(cgrp);
}
324
test_mptcp(void)325 void test_mptcp(void)
326 {
327 if (test__start_subtest("base"))
328 test_base();
329 if (test__start_subtest("mptcpify"))
330 test_mptcpify();
331 }
332