1 // SPDX-License-Identifier: GPL-2.0-only
2 /*
3  * Copyright (C) 2021 ARM Limited.
4  */
5 
6 #include <errno.h>
7 #include <stdbool.h>
8 #include <stddef.h>
9 #include <stdio.h>
10 #include <stdlib.h>
11 #include <string.h>
12 #include <unistd.h>
13 #include <sys/auxv.h>
14 #include <sys/prctl.h>
15 #include <asm/hwcap.h>
16 #include <asm/sigcontext.h>
17 #include <asm/unistd.h>
18 
19 #include "../../kselftest.h"
20 
21 #include "syscall-abi.h"
22 
/* Number of vector quanta in the inclusive range SVE_VQ_MIN..SVE_VQ_MAX. */
#define NUM_VL (SVE_VQ_MAX - SVE_VQ_MIN + 1)

/*
 * SME vector length handed to do_syscall() by the tests that do not
 * exercise SME themselves.  It is zero-initialised and only becomes
 * non-zero when sme_count_vls() detects SME.
 */
static int default_sme_vl;

/*
 * Defined outside this file.  It performs the syscall using the register
 * images prepared in the *_in globals and records the results in the
 * *_out globals.
 * NOTE(review): this description is inferred from how the setup/check
 * helpers use those buffers; confirm it against the implementation.
 */
extern void do_syscall(int sve_vl, int sme_vl);
28 
/*
 * Fill @size bytes at @buf with pseudo-random data from random().
 *
 * The buffer is filled one 32-bit word per random() call.  random() only
 * yields values in 0..2^31-1, so bit 31 of each word is always clear.
 *
 * Each word is copied in with memcpy() rather than stored through a
 * uint32_t pointer.  The buffers passed in here are declared as uint8_t or
 * uint64_t arrays, so a uint32_t store could be misaligned and would
 * violate the effective-type (strict aliasing) rules.
 *
 * If @size is not a multiple of 4, the trailing bytes are filled from the
 * leading bytes of one extra random() result rather than left untouched.
 */
static void fill_random(void *buf, size_t size)
{
	uint8_t *bytes = buf;
	size_t off = 0;

	for (; size - off >= sizeof(uint32_t); off += sizeof(uint32_t)) {
		uint32_t word = random();

		memcpy(bytes + off, &word, sizeof(word));
	}

	if (off < size) {
		uint32_t word = random();

		memcpy(bytes + off, &word, size - off);
	}
}
38 
/*
 * Syscalls exercised by the test.  Every register configuration is run
 * once per entry, in case different syscalls expose different behaviour.
 */
static struct syscall_cfg {
	int syscall_nr;
	const char *name;
} syscalls[] = {
	{ .syscall_nr = __NR_getpid,      .name = "getpid()" },
	{ .syscall_nr = __NR_sched_yield, .name = "sched_yield()" },
};
50 
51 #define NUM_GPR 31
52 uint64_t gpr_in[NUM_GPR];
53 uint64_t gpr_out[NUM_GPR];
54 
setup_gpr(struct syscall_cfg * cfg,int sve_vl,int sme_vl,uint64_t svcr)55 static void setup_gpr(struct syscall_cfg *cfg, int sve_vl, int sme_vl,
56 		      uint64_t svcr)
57 {
58 	fill_random(gpr_in, sizeof(gpr_in));
59 	gpr_in[8] = cfg->syscall_nr;
60 	memset(gpr_out, 0, sizeof(gpr_out));
61 }
62 
check_gpr(struct syscall_cfg * cfg,int sve_vl,int sme_vl,uint64_t svcr)63 static int check_gpr(struct syscall_cfg *cfg, int sve_vl, int sme_vl, uint64_t svcr)
64 {
65 	int errors = 0;
66 	int i;
67 
68 	/*
69 	 * GPR x0-x7 may be clobbered, and all others should be preserved.
70 	 */
71 	for (i = 9; i < ARRAY_SIZE(gpr_in); i++) {
72 		if (gpr_in[i] != gpr_out[i]) {
73 			ksft_print_msg("%s SVE VL %d mismatch in GPR %d: %llx != %llx\n",
74 				       cfg->name, sve_vl, i,
75 				       gpr_in[i], gpr_out[i]);
76 			errors++;
77 		}
78 	}
79 
80 	return errors;
81 }
82 
#define NUM_FPR 32

/*
 * FP/SIMD register images: two 64-bit words per register.  check_fpr()
 * reports them as Qn/0 and Qn/1.
 */
uint64_t fpr_in[NUM_FPR * 2];
uint64_t fpr_out[NUM_FPR * 2];

/*
 * Clear the FP/SIMD output image and randomise the input image.  None of
 * the parameters are consulted.
 */
static void setup_fpr(struct syscall_cfg *cfg, int sve_vl, int sme_vl,
		      uint64_t svcr)
{
	memset(fpr_out, 0, sizeof(fpr_out));
	fill_random(fpr_in, sizeof(fpr_in));
}
93 
check_fpr(struct syscall_cfg * cfg,int sve_vl,int sme_vl,uint64_t svcr)94 static int check_fpr(struct syscall_cfg *cfg, int sve_vl, int sme_vl,
95 		     uint64_t svcr)
96 {
97 	int errors = 0;
98 	int i;
99 
100 	if (!sve_vl) {
101 		for (i = 0; i < ARRAY_SIZE(fpr_in); i++) {
102 			if (fpr_in[i] != fpr_out[i]) {
103 				ksft_print_msg("%s Q%d/%d mismatch %llx != %llx\n",
104 					       cfg->name,
105 					       i / 2, i % 2,
106 					       fpr_in[i], fpr_out[i]);
107 				errors++;
108 			}
109 		}
110 	}
111 
112 	return errors;
113 }
114 
115 #define SVE_Z_SHARED_BYTES (128 / 8)
116 
117 static uint8_t z_zero[__SVE_ZREG_SIZE(SVE_VQ_MAX)];
118 uint8_t z_in[SVE_NUM_ZREGS * __SVE_ZREG_SIZE(SVE_VQ_MAX)];
119 uint8_t z_out[SVE_NUM_ZREGS * __SVE_ZREG_SIZE(SVE_VQ_MAX)];
120 
setup_z(struct syscall_cfg * cfg,int sve_vl,int sme_vl,uint64_t svcr)121 static void setup_z(struct syscall_cfg *cfg, int sve_vl, int sme_vl,
122 		    uint64_t svcr)
123 {
124 	fill_random(z_in, sizeof(z_in));
125 	fill_random(z_out, sizeof(z_out));
126 }
127 
check_z(struct syscall_cfg * cfg,int sve_vl,int sme_vl,uint64_t svcr)128 static int check_z(struct syscall_cfg *cfg, int sve_vl, int sme_vl,
129 		   uint64_t svcr)
130 {
131 	size_t reg_size = sve_vl;
132 	int errors = 0;
133 	int i;
134 
135 	if (!sve_vl)
136 		return 0;
137 
138 	for (i = 0; i < SVE_NUM_ZREGS; i++) {
139 		uint8_t *in = &z_in[reg_size * i];
140 		uint8_t *out = &z_out[reg_size * i];
141 
142 		if (svcr & SVCR_SM_MASK) {
143 			/*
144 			 * In streaming mode the whole register should
145 			 * be cleared by the transition out of
146 			 * streaming mode.
147 			 */
148 			if (memcmp(z_zero, out, reg_size) != 0) {
149 				ksft_print_msg("%s SVE VL %d Z%d non-zero\n",
150 					       cfg->name, sve_vl, i);
151 				errors++;
152 			}
153 		} else {
154 			/*
155 			 * For standard SVE the low 128 bits should be
156 			 * preserved and any additional bits cleared.
157 			 */
158 			if (memcmp(in, out, SVE_Z_SHARED_BYTES) != 0) {
159 				ksft_print_msg("%s SVE VL %d Z%d low 128 bits changed\n",
160 					       cfg->name, sve_vl, i);
161 				errors++;
162 			}
163 
164 			if (reg_size > SVE_Z_SHARED_BYTES &&
165 			    (memcmp(z_zero, out + SVE_Z_SHARED_BYTES,
166 				    reg_size - SVE_Z_SHARED_BYTES) != 0)) {
167 				ksft_print_msg("%s SVE VL %d Z%d high bits non-zero\n",
168 					       cfg->name, sve_vl, i);
169 				errors++;
170 			}
171 		}
172 	}
173 
174 	return errors;
175 }
176 
177 uint8_t p_in[SVE_NUM_PREGS * __SVE_PREG_SIZE(SVE_VQ_MAX)];
178 uint8_t p_out[SVE_NUM_PREGS * __SVE_PREG_SIZE(SVE_VQ_MAX)];
179 
setup_p(struct syscall_cfg * cfg,int sve_vl,int sme_vl,uint64_t svcr)180 static void setup_p(struct syscall_cfg *cfg, int sve_vl, int sme_vl,
181 		    uint64_t svcr)
182 {
183 	fill_random(p_in, sizeof(p_in));
184 	fill_random(p_out, sizeof(p_out));
185 }
186 
check_p(struct syscall_cfg * cfg,int sve_vl,int sme_vl,uint64_t svcr)187 static int check_p(struct syscall_cfg *cfg, int sve_vl, int sme_vl,
188 		   uint64_t svcr)
189 {
190 	size_t reg_size = sve_vq_from_vl(sve_vl) * 2; /* 1 bit per VL byte */
191 
192 	int errors = 0;
193 	int i;
194 
195 	if (!sve_vl)
196 		return 0;
197 
198 	/* After a syscall the P registers should be zeroed */
199 	for (i = 0; i < SVE_NUM_PREGS * reg_size; i++)
200 		if (p_out[i])
201 			errors++;
202 	if (errors)
203 		ksft_print_msg("%s SVE VL %d predicate registers non-zero\n",
204 			       cfg->name, sve_vl);
205 
206 	return errors;
207 }
208 
209 uint8_t ffr_in[__SVE_PREG_SIZE(SVE_VQ_MAX)];
210 uint8_t ffr_out[__SVE_PREG_SIZE(SVE_VQ_MAX)];
211 
setup_ffr(struct syscall_cfg * cfg,int sve_vl,int sme_vl,uint64_t svcr)212 static void setup_ffr(struct syscall_cfg *cfg, int sve_vl, int sme_vl,
213 		      uint64_t svcr)
214 {
215 	/*
216 	 * If we are in streaming mode and do not have FA64 then FFR
217 	 * is unavailable.
218 	 */
219 	if ((svcr & SVCR_SM_MASK) &&
220 	    !(getauxval(AT_HWCAP2) & HWCAP2_SME_FA64)) {
221 		memset(&ffr_in, 0, sizeof(ffr_in));
222 		return;
223 	}
224 
225 	/*
226 	 * It is only valid to set a contiguous set of bits starting
227 	 * at 0.  For now since we're expecting this to be cleared by
228 	 * a syscall just set all bits.
229 	 */
230 	memset(ffr_in, 0xff, sizeof(ffr_in));
231 	fill_random(ffr_out, sizeof(ffr_out));
232 }
233 
check_ffr(struct syscall_cfg * cfg,int sve_vl,int sme_vl,uint64_t svcr)234 static int check_ffr(struct syscall_cfg *cfg, int sve_vl, int sme_vl,
235 		     uint64_t svcr)
236 {
237 	size_t reg_size = sve_vq_from_vl(sve_vl) * 2;  /* 1 bit per VL byte */
238 	int errors = 0;
239 	int i;
240 
241 	if (!sve_vl)
242 		return 0;
243 
244 	if ((svcr & SVCR_SM_MASK) &&
245 	    !(getauxval(AT_HWCAP2) & HWCAP2_SME_FA64))
246 		return 0;
247 
248 	/* After a syscall FFR should be zeroed */
249 	for (i = 0; i < reg_size; i++)
250 		if (ffr_out[i])
251 			errors++;
252 	if (errors)
253 		ksft_print_msg("%s SVE VL %d FFR non-zero\n",
254 			       cfg->name, sve_vl);
255 
256 	return errors;
257 }
258 
/* SVCR value requested for the syscall, and the value seen afterwards. */
uint64_t svcr_in;
uint64_t svcr_out;

/* Record the SVCR value this test iteration asks for; nothing else is used. */
static void setup_svcr(struct syscall_cfg *cfg, int sve_vl, int sme_vl,
		    uint64_t svcr)
{
	const uint64_t requested = svcr;

	svcr_in = requested;
}
266 
check_svcr(struct syscall_cfg * cfg,int sve_vl,int sme_vl,uint64_t svcr)267 static int check_svcr(struct syscall_cfg *cfg, int sve_vl, int sme_vl,
268 		      uint64_t svcr)
269 {
270 	int errors = 0;
271 
272 	if (svcr_out & SVCR_SM_MASK) {
273 		ksft_print_msg("%s Still in SM, SVCR %llx\n",
274 			       cfg->name, svcr_out);
275 		errors++;
276 	}
277 
278 	if ((svcr_in & SVCR_ZA_MASK) != (svcr_out & SVCR_ZA_MASK)) {
279 		ksft_print_msg("%s PSTATE.ZA changed, SVCR %llx != %llx\n",
280 			       cfg->name, svcr_in, svcr_out);
281 		errors++;
282 	}
283 
284 	return errors;
285 }
286 
287 uint8_t za_in[SVE_NUM_PREGS * __SVE_ZREG_SIZE(SVE_VQ_MAX)];
288 uint8_t za_out[SVE_NUM_PREGS * __SVE_ZREG_SIZE(SVE_VQ_MAX)];
289 
setup_za(struct syscall_cfg * cfg,int sve_vl,int sme_vl,uint64_t svcr)290 static void setup_za(struct syscall_cfg *cfg, int sve_vl, int sme_vl,
291 		     uint64_t svcr)
292 {
293 	fill_random(za_in, sizeof(za_in));
294 	memset(za_out, 0, sizeof(za_out));
295 }
296 
check_za(struct syscall_cfg * cfg,int sve_vl,int sme_vl,uint64_t svcr)297 static int check_za(struct syscall_cfg *cfg, int sve_vl, int sme_vl,
298 		    uint64_t svcr)
299 {
300 	size_t reg_size = sme_vl * sme_vl;
301 	int errors = 0;
302 
303 	if (!(svcr & SVCR_ZA_MASK))
304 		return 0;
305 
306 	if (memcmp(za_in, za_out, reg_size) != 0) {
307 		ksft_print_msg("SME VL %d ZA does not match\n", sme_vl);
308 		errors++;
309 	}
310 
311 	return errors;
312 }
313 
314 typedef void (*setup_fn)(struct syscall_cfg *cfg, int sve_vl, int sme_vl,
315 			 uint64_t svcr);
316 typedef int (*check_fn)(struct syscall_cfg *cfg, int sve_vl, int sme_vl,
317 			uint64_t svcr);
318 
319 /*
320  * Each set of registers has a setup function which is called before
321  * the syscall to fill values in a global variable for loading by the
322  * test code and a check function which validates that the results are
323  * as expected.  Vector lengths are passed everywhere, a vector length
324  * of 0 should be treated as do not test.
325  */
326 static struct {
327 	setup_fn setup;
328 	check_fn check;
329 } regset[] = {
330 	{ setup_gpr, check_gpr },
331 	{ setup_fpr, check_fpr },
332 	{ setup_z, check_z },
333 	{ setup_p, check_p },
334 	{ setup_ffr, check_ffr },
335 	{ setup_svcr, check_svcr },
336 	{ setup_za, check_za },
337 };
338 
do_test(struct syscall_cfg * cfg,int sve_vl,int sme_vl,uint64_t svcr)339 static bool do_test(struct syscall_cfg *cfg, int sve_vl, int sme_vl,
340 		    uint64_t svcr)
341 {
342 	int errors = 0;
343 	int i;
344 
345 	for (i = 0; i < ARRAY_SIZE(regset); i++)
346 		regset[i].setup(cfg, sve_vl, sme_vl, svcr);
347 
348 	do_syscall(sve_vl, sme_vl);
349 
350 	for (i = 0; i < ARRAY_SIZE(regset); i++)
351 		errors += regset[i].check(cfg, sve_vl, sme_vl, svcr);
352 
353 	return errors == 0;
354 }
355 
test_one_syscall(struct syscall_cfg * cfg)356 static void test_one_syscall(struct syscall_cfg *cfg)
357 {
358 	int sve_vq, sve_vl;
359 	int sme_vq, sme_vl;
360 
361 	/* FPSIMD only case */
362 	ksft_test_result(do_test(cfg, 0, default_sme_vl, 0),
363 			 "%s FPSIMD\n", cfg->name);
364 
365 	if (!(getauxval(AT_HWCAP) & HWCAP_SVE))
366 		return;
367 
368 	for (sve_vq = SVE_VQ_MAX; sve_vq > 0; --sve_vq) {
369 		sve_vl = prctl(PR_SVE_SET_VL, sve_vq * 16);
370 		if (sve_vl == -1)
371 			ksft_exit_fail_msg("PR_SVE_SET_VL failed: %s (%d)\n",
372 					   strerror(errno), errno);
373 
374 		sve_vl &= PR_SVE_VL_LEN_MASK;
375 
376 		if (sve_vq != sve_vq_from_vl(sve_vl))
377 			sve_vq = sve_vq_from_vl(sve_vl);
378 
379 		ksft_test_result(do_test(cfg, sve_vl, default_sme_vl, 0),
380 				 "%s SVE VL %d\n", cfg->name, sve_vl);
381 
382 		if (!(getauxval(AT_HWCAP2) & HWCAP2_SME))
383 			continue;
384 
385 		for (sme_vq = SVE_VQ_MAX; sme_vq > 0; --sme_vq) {
386 			sme_vl = prctl(PR_SME_SET_VL, sme_vq * 16);
387 			if (sme_vl == -1)
388 				ksft_exit_fail_msg("PR_SME_SET_VL failed: %s (%d)\n",
389 						   strerror(errno), errno);
390 
391 			sme_vl &= PR_SME_VL_LEN_MASK;
392 
393 			if (sme_vq != sve_vq_from_vl(sme_vl))
394 				sme_vq = sve_vq_from_vl(sme_vl);
395 
396 			ksft_test_result(do_test(cfg, sve_vl, sme_vl,
397 						 SVCR_ZA_MASK | SVCR_SM_MASK),
398 					 "%s SVE VL %d/SME VL %d SM+ZA\n",
399 					 cfg->name, sve_vl, sme_vl);
400 			ksft_test_result(do_test(cfg, sve_vl, sme_vl,
401 						 SVCR_SM_MASK),
402 					 "%s SVE VL %d/SME VL %d SM\n",
403 					 cfg->name, sve_vl, sme_vl);
404 			ksft_test_result(do_test(cfg, sve_vl, sme_vl,
405 						 SVCR_ZA_MASK),
406 					 "%s SVE VL %d/SME VL %d ZA\n",
407 					 cfg->name, sve_vl, sme_vl);
408 		}
409 	}
410 }
411 
sve_count_vls(void)412 int sve_count_vls(void)
413 {
414 	unsigned int vq;
415 	int vl_count = 0;
416 	int vl;
417 
418 	if (!(getauxval(AT_HWCAP) & HWCAP_SVE))
419 		return 0;
420 
421 	/*
422 	 * Enumerate up to SVE_VQ_MAX vector lengths
423 	 */
424 	for (vq = SVE_VQ_MAX; vq > 0; --vq) {
425 		vl = prctl(PR_SVE_SET_VL, vq * 16);
426 		if (vl == -1)
427 			ksft_exit_fail_msg("PR_SVE_SET_VL failed: %s (%d)\n",
428 					   strerror(errno), errno);
429 
430 		vl &= PR_SVE_VL_LEN_MASK;
431 
432 		if (vq != sve_vq_from_vl(vl))
433 			vq = sve_vq_from_vl(vl);
434 
435 		vl_count++;
436 	}
437 
438 	return vl_count;
439 }
440 
sme_count_vls(void)441 int sme_count_vls(void)
442 {
443 	unsigned int vq;
444 	int vl_count = 0;
445 	int vl;
446 
447 	if (!(getauxval(AT_HWCAP2) & HWCAP2_SME))
448 		return 0;
449 
450 	/* Ensure we configure a SME VL, used to flag if SVCR is set */
451 	default_sme_vl = 16;
452 
453 	/*
454 	 * Enumerate up to SVE_VQ_MAX vector lengths
455 	 */
456 	for (vq = SVE_VQ_MAX; vq > 0; --vq) {
457 		vl = prctl(PR_SME_SET_VL, vq * 16);
458 		if (vl == -1)
459 			ksft_exit_fail_msg("PR_SME_SET_VL failed: %s (%d)\n",
460 					   strerror(errno), errno);
461 
462 		vl &= PR_SME_VL_LEN_MASK;
463 
464 		if (vq != sve_vq_from_vl(vl))
465 			vq = sve_vq_from_vl(vl);
466 
467 		vl_count++;
468 	}
469 
470 	return vl_count;
471 }
472 
main(void)473 int main(void)
474 {
475 	int i;
476 	int tests = 1;  /* FPSIMD */
477 
478 	srandom(getpid());
479 
480 	ksft_print_header();
481 	tests += sve_count_vls();
482 	tests += (sve_count_vls() * sme_count_vls()) * 3;
483 	ksft_set_plan(ARRAY_SIZE(syscalls) * tests);
484 
485 	if (getauxval(AT_HWCAP2) & HWCAP2_SME_FA64)
486 		ksft_print_msg("SME with FA64\n");
487 	else if (getauxval(AT_HWCAP2) & HWCAP2_SME)
488 		ksft_print_msg("SME without FA64\n");
489 
490 	for (i = 0; i < ARRAY_SIZE(syscalls); i++)
491 		test_one_syscall(&syscalls[i]);
492 
493 	ksft_print_cnts();
494 
495 	return 0;
496 }
497