1 // SPDX-License-Identifier: GPL-2.0-only
2 /*
3  * Copyright (C) 2021 ARM Limited.
4  */
5 #include <errno.h>
6 #include <stdbool.h>
7 #include <stddef.h>
8 #include <stdio.h>
9 #include <stdlib.h>
10 #include <string.h>
11 #include <unistd.h>
12 #include <sys/auxv.h>
13 #include <sys/prctl.h>
14 #include <sys/ptrace.h>
15 #include <sys/types.h>
16 #include <sys/uio.h>
17 #include <sys/wait.h>
18 #include <asm/sigcontext.h>
19 #include <asm/ptrace.h>
20 
21 #include "../../kselftest.h"
22 
23 /* <linux/elf.h> and <sys/auxv.h> don't like each other, so: */
24 #ifndef NT_ARM_ZA
25 #define NT_ARM_ZA 0x40c
26 #endif
27 
/*
 * Planned number of results: do_parent() reports three results for every
 * vector quantum from SVE_VQ_MIN to SVE_VQ_MAX inclusive.
 */
#define EXPECTED_TESTS (((SVE_VQ_MAX - SVE_VQ_MIN) + 1) * 3)
29 
/*
 * Fill the first @size bytes of @buf with pseudo-random data.
 *
 * Each byte is one random() result converted to char, so the contents are
 * reproducible for a given srandom() seed.  A @size of 0 writes nothing.
 */
static void fill_buf(char *buf, size_t size)
{
	size_t i;

	/* Index with size_t so the bound check is not a mixed-sign compare */
	for (i = 0; i < size; i++)
		buf[i] = random();
}
37 
/*
 * Child side of the test: request tracing by the parent, then stop itself
 * with SIGSTOP so the parent can take over.
 *
 * Returns EXIT_SUCCESS if execution ever continues past the stop.  Failure
 * of either setup step is reported through ksft_exit_fail_msg().
 */
static int do_child(void)
{
	if (ptrace(PTRACE_TRACEME, -1, NULL, NULL))
		ksft_exit_fail_msg("ptrace(PTRACE_TRACEME) failed: %s (%d)\n",
				   strerror(errno), errno);

	if (raise(SIGSTOP))
		ksft_exit_fail_msg("raise(SIGSTOP) failed: %s (%d)\n",
				   strerror(errno), errno);

	return EXIT_SUCCESS;
}
48 
get_za(pid_t pid,void ** buf,size_t * size)49 static struct user_za_header *get_za(pid_t pid, void **buf, size_t *size)
50 {
51 	struct user_za_header *za;
52 	void *p;
53 	size_t sz = sizeof(*za);
54 	struct iovec iov;
55 
56 	while (1) {
57 		if (*size < sz) {
58 			p = realloc(*buf, sz);
59 			if (!p) {
60 				errno = ENOMEM;
61 				goto error;
62 			}
63 
64 			*buf = p;
65 			*size = sz;
66 		}
67 
68 		iov.iov_base = *buf;
69 		iov.iov_len = sz;
70 		if (ptrace(PTRACE_GETREGSET, pid, NT_ARM_ZA, &iov))
71 			goto error;
72 
73 		za = *buf;
74 		if (za->size <= sz)
75 			break;
76 
77 		sz = za->size;
78 	}
79 
80 	return za;
81 
82 error:
83 	return NULL;
84 }
85 
set_za(pid_t pid,const struct user_za_header * za)86 static int set_za(pid_t pid, const struct user_za_header *za)
87 {
88 	struct iovec iov;
89 
90 	iov.iov_base = (void *)za;
91 	iov.iov_len = za->size;
92 	return ptrace(PTRACE_SETREGSET, pid, NT_ARM_ZA, &iov);
93 }
94 
95 /* Validate attempting to set the specfied VL via ptrace */
ptrace_set_get_vl(pid_t child,unsigned int vl,bool * supported)96 static void ptrace_set_get_vl(pid_t child, unsigned int vl, bool *supported)
97 {
98 	struct user_za_header za;
99 	struct user_za_header *new_za = NULL;
100 	size_t new_za_size = 0;
101 	int ret, prctl_vl;
102 
103 	*supported = false;
104 
105 	/* Check if the VL is supported in this process */
106 	prctl_vl = prctl(PR_SME_SET_VL, vl);
107 	if (prctl_vl == -1)
108 		ksft_exit_fail_msg("prctl(PR_SME_SET_VL) failed: %s (%d)\n",
109 				   strerror(errno), errno);
110 
111 	/* If the VL is not supported then a supported VL will be returned */
112 	*supported = (prctl_vl == vl);
113 
114 	/* Set the VL by doing a set with no register payload */
115 	memset(&za, 0, sizeof(za));
116 	za.size = sizeof(za);
117 	za.vl = vl;
118 	ret = set_za(child, &za);
119 	if (ret != 0) {
120 		ksft_test_result_fail("Failed to set VL %u\n", vl);
121 		return;
122 	}
123 
124 	/*
125 	 * Read back the new register state and verify that we have the
126 	 * same VL that we got from prctl() on ourselves.
127 	 */
128 	if (!get_za(child, (void **)&new_za, &new_za_size)) {
129 		ksft_test_result_fail("Failed to read VL %u\n", vl);
130 		return;
131 	}
132 
133 	ksft_test_result(new_za->vl = prctl_vl, "Set VL %u\n", vl);
134 
135 	free(new_za);
136 }
137 
138 /* Validate attempting to set no ZA data and read it back */
ptrace_set_no_data(pid_t child,unsigned int vl)139 static void ptrace_set_no_data(pid_t child, unsigned int vl)
140 {
141 	void *read_buf = NULL;
142 	struct user_za_header write_za;
143 	struct user_za_header *read_za;
144 	size_t read_za_size = 0;
145 	int ret;
146 
147 	/* Set up some data and write it out */
148 	memset(&write_za, 0, sizeof(write_za));
149 	write_za.size = ZA_PT_ZA_OFFSET;
150 	write_za.vl = vl;
151 
152 	ret = set_za(child, &write_za);
153 	if (ret != 0) {
154 		ksft_test_result_fail("Failed to set VL %u no data\n", vl);
155 		return;
156 	}
157 
158 	/* Read the data back */
159 	if (!get_za(child, (void **)&read_buf, &read_za_size)) {
160 		ksft_test_result_fail("Failed to read VL %u no data\n", vl);
161 		return;
162 	}
163 	read_za = read_buf;
164 
165 	/* We might read more data if there's extensions we don't know */
166 	if (read_za->size < write_za.size) {
167 		ksft_test_result_fail("VL %u wrote %d bytes, only read %d\n",
168 				      vl, write_za.size, read_za->size);
169 		goto out_read;
170 	}
171 
172 	ksft_test_result(read_za->size == write_za.size,
173 			 "Disabled ZA for VL %u\n", vl);
174 
175 out_read:
176 	free(read_buf);
177 }
178 
179 /* Validate attempting to set data and read it back */
ptrace_set_get_data(pid_t child,unsigned int vl)180 static void ptrace_set_get_data(pid_t child, unsigned int vl)
181 {
182 	void *write_buf;
183 	void *read_buf = NULL;
184 	struct user_za_header *write_za;
185 	struct user_za_header *read_za;
186 	size_t read_za_size = 0;
187 	unsigned int vq = sve_vq_from_vl(vl);
188 	int ret;
189 	size_t data_size;
190 
191 	data_size = ZA_PT_SIZE(vq);
192 	write_buf = malloc(data_size);
193 	if (!write_buf) {
194 		ksft_test_result_fail("Error allocating %d byte buffer for VL %u\n",
195 				      data_size, vl);
196 		return;
197 	}
198 	write_za = write_buf;
199 
200 	/* Set up some data and write it out */
201 	memset(write_za, 0, data_size);
202 	write_za->size = data_size;
203 	write_za->vl = vl;
204 
205 	fill_buf(write_buf + ZA_PT_ZA_OFFSET, ZA_PT_ZA_SIZE(vq));
206 
207 	ret = set_za(child, write_za);
208 	if (ret != 0) {
209 		ksft_test_result_fail("Failed to set VL %u data\n", vl);
210 		goto out;
211 	}
212 
213 	/* Read the data back */
214 	if (!get_za(child, (void **)&read_buf, &read_za_size)) {
215 		ksft_test_result_fail("Failed to read VL %u data\n", vl);
216 		goto out;
217 	}
218 	read_za = read_buf;
219 
220 	/* We might read more data if there's extensions we don't know */
221 	if (read_za->size < write_za->size) {
222 		ksft_test_result_fail("VL %u wrote %d bytes, only read %d\n",
223 				      vl, write_za->size, read_za->size);
224 		goto out_read;
225 	}
226 
227 	ksft_test_result(memcmp(write_buf + ZA_PT_ZA_OFFSET,
228 				read_buf + ZA_PT_ZA_OFFSET,
229 				ZA_PT_ZA_SIZE(vq)) == 0,
230 			 "Data match for VL %u\n", vl);
231 
232 out_read:
233 	free(read_buf);
234 out:
235 	free(write_buf);
236 }
237 
do_parent(pid_t child)238 static int do_parent(pid_t child)
239 {
240 	int ret = EXIT_FAILURE;
241 	pid_t pid;
242 	int status;
243 	siginfo_t si;
244 	unsigned int vq, vl;
245 	bool vl_supported;
246 
247 	/* Attach to the child */
248 	while (1) {
249 		int sig;
250 
251 		pid = wait(&status);
252 		if (pid == -1) {
253 			perror("wait");
254 			goto error;
255 		}
256 
257 		/*
258 		 * This should never happen but it's hard to flag in
259 		 * the framework.
260 		 */
261 		if (pid != child)
262 			continue;
263 
264 		if (WIFEXITED(status) || WIFSIGNALED(status))
265 			ksft_exit_fail_msg("Child died unexpectedly\n");
266 
267 		if (!WIFSTOPPED(status))
268 			goto error;
269 
270 		sig = WSTOPSIG(status);
271 
272 		if (ptrace(PTRACE_GETSIGINFO, pid, NULL, &si)) {
273 			if (errno == ESRCH)
274 				goto disappeared;
275 
276 			if (errno == EINVAL) {
277 				sig = 0; /* bust group-stop */
278 				goto cont;
279 			}
280 
281 			ksft_test_result_fail("PTRACE_GETSIGINFO: %s\n",
282 					      strerror(errno));
283 			goto error;
284 		}
285 
286 		if (sig == SIGSTOP && si.si_code == SI_TKILL &&
287 		    si.si_pid == pid)
288 			break;
289 
290 	cont:
291 		if (ptrace(PTRACE_CONT, pid, NULL, sig)) {
292 			if (errno == ESRCH)
293 				goto disappeared;
294 
295 			ksft_test_result_fail("PTRACE_CONT: %s\n",
296 					      strerror(errno));
297 			goto error;
298 		}
299 	}
300 
301 	ksft_print_msg("Parent is %d, child is %d\n", getpid(), child);
302 
303 	/* Step through every possible VQ */
304 	for (vq = SVE_VQ_MIN; vq <= SVE_VQ_MAX; vq++) {
305 		vl = sve_vl_from_vq(vq);
306 
307 		/* First, try to set this vector length */
308 		ptrace_set_get_vl(child, vl, &vl_supported);
309 
310 		/* If the VL is supported validate data set/get */
311 		if (vl_supported) {
312 			ptrace_set_no_data(child, vl);
313 			ptrace_set_get_data(child, vl);
314 		} else {
315 			ksft_test_result_skip("Disabled ZA for VL %u\n", vl);
316 			ksft_test_result_skip("Get and set data for VL %u\n",
317 					      vl);
318 		}
319 	}
320 
321 	ret = EXIT_SUCCESS;
322 
323 error:
324 	kill(child, SIGKILL);
325 
326 disappeared:
327 	return ret;
328 }
329 
main(void)330 int main(void)
331 {
332 	int ret = EXIT_SUCCESS;
333 	pid_t child;
334 
335 	srandom(getpid());
336 
337 	ksft_print_header();
338 
339 	if (!(getauxval(AT_HWCAP2) & HWCAP2_SME)) {
340 		ksft_set_plan(1);
341 		ksft_exit_skip("SME not available\n");
342 	}
343 
344 	ksft_set_plan(EXPECTED_TESTS);
345 
346 	child = fork();
347 	if (!child)
348 		return do_child();
349 
350 	if (do_parent(child))
351 		ret = EXIT_FAILURE;
352 
353 	ksft_print_cnts();
354 
355 	return ret;
356 }
357