1 // SPDX-License-Identifier: GPL-2.0-only
2 /*
3 * Copyright (C) 2021 ARM Limited.
4 */
5
6 #include <errno.h>
7 #include <stdbool.h>
8 #include <stddef.h>
9 #include <stdio.h>
10 #include <stdlib.h>
11 #include <string.h>
12 #include <unistd.h>
13 #include <sys/auxv.h>
14 #include <sys/prctl.h>
15 #include <asm/hwcap.h>
16 #include <asm/sigcontext.h>
17 #include <asm/unistd.h>
18
19 #include "../../kselftest.h"
20
21 #include "syscall-abi.h"
22
/* Number of possible vector quadword counts between the architectural limits */
#define NUM_VL ((SVE_VQ_MAX - SVE_VQ_MIN) + 1)

/*
 * SME VL used for the non-SME test cases; left 0 when SME is absent and
 * set in sme_count_vls() when SME is detected.
 */
static int default_sme_vl;

/*
 * Loads the *_in buffers into the registers, executes the syscall under
 * test and stores the register state into the *_out buffers.
 * NOTE(review): presumably implemented in an accompanying assembly file —
 * confirm against the test Makefile.
 */
extern void do_syscall(int sve_vl, int sme_vl);
28
/*
 * Fill @buf with pseudo-random data in 32 bit chunks.  Any trailing
 * bytes beyond a multiple of sizeof(uint32_t) are left untouched.
 */
static void fill_random(void *buf, size_t size)
{
	size_t i;	/* size_t avoids a signed/unsigned comparison with size */
	uint32_t *lbuf = buf;

	/* random() returns a 32 bit number regardless of the size of long */
	for (i = 0; i < size / sizeof(uint32_t); i++)
		lbuf[i] = random();
}
38
39 /*
40 * We also repeat the test for several syscalls to try to expose different
41 * behaviour.
42 */
static struct syscall_cfg {
	int syscall_nr;		/* syscall number, placed in x8 by setup_gpr() */
	const char *name;	/* human readable name used in test output */
} syscalls[] = {
	{ __NR_getpid, "getpid()" },
	{ __NR_sched_yield, "sched_yield()" },
};
50
#define NUM_GPR 31	/* x0-x30; SP/XZR is not tracked */
uint64_t gpr_in[NUM_GPR];	/* GPR values loaded before the syscall */
uint64_t gpr_out[NUM_GPR];	/* GPR values read back after the syscall */
54
setup_gpr(struct syscall_cfg * cfg,int sve_vl,int sme_vl,uint64_t svcr)55 static void setup_gpr(struct syscall_cfg *cfg, int sve_vl, int sme_vl,
56 uint64_t svcr)
57 {
58 fill_random(gpr_in, sizeof(gpr_in));
59 gpr_in[8] = cfg->syscall_nr;
60 memset(gpr_out, 0, sizeof(gpr_out));
61 }
62
check_gpr(struct syscall_cfg * cfg,int sve_vl,int sme_vl,uint64_t svcr)63 static int check_gpr(struct syscall_cfg *cfg, int sve_vl, int sme_vl, uint64_t svcr)
64 {
65 int errors = 0;
66 int i;
67
68 /*
69 * GPR x0-x7 may be clobbered, and all others should be preserved.
70 */
71 for (i = 9; i < ARRAY_SIZE(gpr_in); i++) {
72 if (gpr_in[i] != gpr_out[i]) {
73 ksft_print_msg("%s SVE VL %d mismatch in GPR %d: %llx != %llx\n",
74 cfg->name, sve_vl, i,
75 gpr_in[i], gpr_out[i]);
76 errors++;
77 }
78 }
79
80 return errors;
81 }
82
#define NUM_FPR 32	/* V0-V31, each 128 bits stored as two 64 bit words */
uint64_t fpr_in[NUM_FPR * 2];	/* FP/SIMD values loaded before the syscall */
uint64_t fpr_out[NUM_FPR * 2];	/* FP/SIMD values read back afterwards */
86
setup_fpr(struct syscall_cfg * cfg,int sve_vl,int sme_vl,uint64_t svcr)87 static void setup_fpr(struct syscall_cfg *cfg, int sve_vl, int sme_vl,
88 uint64_t svcr)
89 {
90 fill_random(fpr_in, sizeof(fpr_in));
91 memset(fpr_out, 0, sizeof(fpr_out));
92 }
93
check_fpr(struct syscall_cfg * cfg,int sve_vl,int sme_vl,uint64_t svcr)94 static int check_fpr(struct syscall_cfg *cfg, int sve_vl, int sme_vl,
95 uint64_t svcr)
96 {
97 int errors = 0;
98 int i;
99
100 if (!sve_vl) {
101 for (i = 0; i < ARRAY_SIZE(fpr_in); i++) {
102 if (fpr_in[i] != fpr_out[i]) {
103 ksft_print_msg("%s Q%d/%d mismatch %llx != %llx\n",
104 cfg->name,
105 i / 2, i % 2,
106 fpr_in[i], fpr_out[i]);
107 errors++;
108 }
109 }
110 }
111
112 return errors;
113 }
114
/* The low 128 bits of each Z register are shared with the V registers */
#define SVE_Z_SHARED_BYTES (128 / 8)

/* All-zeros reference buffer for "register was cleared" comparisons */
static uint8_t z_zero[__SVE_ZREG_SIZE(SVE_VQ_MAX)];
uint8_t z_in[SVE_NUM_ZREGS * __SVE_ZREG_SIZE(SVE_VQ_MAX)];
uint8_t z_out[SVE_NUM_ZREGS * __SVE_ZREG_SIZE(SVE_VQ_MAX)];
120
/*
 * Random Z register input.  The output buffer is randomised too so that
 * stale zeros cannot make the "cleared by the syscall" checks pass by
 * accident.
 */
static void setup_z(struct syscall_cfg *cfg, int sve_vl, int sme_vl,
		    uint64_t svcr)
{
	fill_random(z_in, sizeof(z_in));
	fill_random(z_out, sizeof(z_out));
}
127
/*
 * Verify the Z register state after the syscall.  In streaming mode the
 * registers should be fully zeroed; otherwise the shared low 128 bits
 * should be preserved and any high bits zeroed.  Returns the number of
 * failed checks.
 */
static int check_z(struct syscall_cfg *cfg, int sve_vl, int sme_vl,
		   uint64_t svcr)
{
	size_t reg_size = sve_vl;	/* each Z register is VL bytes */
	int errors = 0;
	int i;

	/* A VL of 0 means SVE is not being tested */
	if (!sve_vl)
		return 0;

	for (i = 0; i < SVE_NUM_ZREGS; i++) {
		uint8_t *in = &z_in[reg_size * i];
		uint8_t *out = &z_out[reg_size * i];

		if (svcr & SVCR_SM_MASK) {
			/*
			 * In streaming mode the whole register should
			 * be cleared by the transition out of
			 * streaming mode.
			 */
			if (memcmp(z_zero, out, reg_size) != 0) {
				ksft_print_msg("%s SVE VL %d Z%d non-zero\n",
					       cfg->name, sve_vl, i);
				errors++;
			}
		} else {
			/*
			 * For standard SVE the low 128 bits should be
			 * preserved and any additional bits cleared.
			 */
			if (memcmp(in, out, SVE_Z_SHARED_BYTES) != 0) {
				ksft_print_msg("%s SVE VL %d Z%d low 128 bits changed\n",
					       cfg->name, sve_vl, i);
				errors++;
			}

			if (reg_size > SVE_Z_SHARED_BYTES &&
			    (memcmp(z_zero, out + SVE_Z_SHARED_BYTES,
				    reg_size - SVE_Z_SHARED_BYTES) != 0)) {
				ksft_print_msg("%s SVE VL %d Z%d high bits non-zero\n",
					       cfg->name, sve_vl, i);
				errors++;
			}
		}
	}

	return errors;
}
176
/* SVE predicate registers P0-P15, one bit per byte of vector length */
uint8_t p_in[SVE_NUM_PREGS * __SVE_PREG_SIZE(SVE_VQ_MAX)];
uint8_t p_out[SVE_NUM_PREGS * __SVE_PREG_SIZE(SVE_VQ_MAX)];
179
/*
 * Random predicate input.  The output buffer is randomised as well so
 * the "zeroed by the syscall" check in check_p() cannot pass on stale
 * data.
 */
static void setup_p(struct syscall_cfg *cfg, int sve_vl, int sme_vl,
		    uint64_t svcr)
{
	fill_random(p_in, sizeof(p_in));
	fill_random(p_out, sizeof(p_out));
}
186
check_p(struct syscall_cfg * cfg,int sve_vl,int sme_vl,uint64_t svcr)187 static int check_p(struct syscall_cfg *cfg, int sve_vl, int sme_vl,
188 uint64_t svcr)
189 {
190 size_t reg_size = sve_vq_from_vl(sve_vl) * 2; /* 1 bit per VL byte */
191
192 int errors = 0;
193 int i;
194
195 if (!sve_vl)
196 return 0;
197
198 /* After a syscall the P registers should be zeroed */
199 for (i = 0; i < SVE_NUM_PREGS * reg_size; i++)
200 if (p_out[i])
201 errors++;
202 if (errors)
203 ksft_print_msg("%s SVE VL %d predicate registers non-zero\n",
204 cfg->name, sve_vl);
205
206 return errors;
207 }
208
/* First fault register, same layout as a predicate register */
uint8_t ffr_in[__SVE_PREG_SIZE(SVE_VQ_MAX)];
uint8_t ffr_out[__SVE_PREG_SIZE(SVE_VQ_MAX)];
211
setup_ffr(struct syscall_cfg * cfg,int sve_vl,int sme_vl,uint64_t svcr)212 static void setup_ffr(struct syscall_cfg *cfg, int sve_vl, int sme_vl,
213 uint64_t svcr)
214 {
215 /*
216 * If we are in streaming mode and do not have FA64 then FFR
217 * is unavailable.
218 */
219 if ((svcr & SVCR_SM_MASK) &&
220 !(getauxval(AT_HWCAP2) & HWCAP2_SME_FA64)) {
221 memset(&ffr_in, 0, sizeof(ffr_in));
222 return;
223 }
224
225 /*
226 * It is only valid to set a contiguous set of bits starting
227 * at 0. For now since we're expecting this to be cleared by
228 * a syscall just set all bits.
229 */
230 memset(ffr_in, 0xff, sizeof(ffr_in));
231 fill_random(ffr_out, sizeof(ffr_out));
232 }
233
check_ffr(struct syscall_cfg * cfg,int sve_vl,int sme_vl,uint64_t svcr)234 static int check_ffr(struct syscall_cfg *cfg, int sve_vl, int sme_vl,
235 uint64_t svcr)
236 {
237 size_t reg_size = sve_vq_from_vl(sve_vl) * 2; /* 1 bit per VL byte */
238 int errors = 0;
239 int i;
240
241 if (!sve_vl)
242 return 0;
243
244 if ((svcr & SVCR_SM_MASK) &&
245 !(getauxval(AT_HWCAP2) & HWCAP2_SME_FA64))
246 return 0;
247
248 /* After a syscall FFR should be zeroed */
249 for (i = 0; i < reg_size; i++)
250 if (ffr_out[i])
251 errors++;
252 if (errors)
253 ksft_print_msg("%s SVE VL %d FFR non-zero\n",
254 cfg->name, sve_vl);
255
256 return errors;
257 }
258
/* SVCR value loaded before the syscall and the value read back after it */
uint64_t svcr_in, svcr_out;
260
/* Record the SVCR value the test will run with so check_svcr() can compare */
static void setup_svcr(struct syscall_cfg *cfg, int sve_vl, int sme_vl,
		       uint64_t svcr)
{
	svcr_in = svcr;
}
266
check_svcr(struct syscall_cfg * cfg,int sve_vl,int sme_vl,uint64_t svcr)267 static int check_svcr(struct syscall_cfg *cfg, int sve_vl, int sme_vl,
268 uint64_t svcr)
269 {
270 int errors = 0;
271
272 if (svcr_out & SVCR_SM_MASK) {
273 ksft_print_msg("%s Still in SM, SVCR %llx\n",
274 cfg->name, svcr_out);
275 errors++;
276 }
277
278 if ((svcr_in & SVCR_ZA_MASK) != (svcr_out & SVCR_ZA_MASK)) {
279 ksft_print_msg("%s PSTATE.ZA changed, SVCR %llx != %llx\n",
280 cfg->name, svcr_in, svcr_out);
281 errors++;
282 }
283
284 return errors;
285 }
286
/*
 * ZA storage, a VL x VL byte matrix.  NOTE(review): sized via
 * SVE_NUM_PREGS * max Z register size, which covers ZA for the
 * architected SME VLs — confirm against the maximum supported SME VL.
 */
uint8_t za_in[SVE_NUM_PREGS * __SVE_ZREG_SIZE(SVE_VQ_MAX)];
uint8_t za_out[SVE_NUM_PREGS * __SVE_ZREG_SIZE(SVE_VQ_MAX)];
289
setup_za(struct syscall_cfg * cfg,int sve_vl,int sme_vl,uint64_t svcr)290 static void setup_za(struct syscall_cfg *cfg, int sve_vl, int sme_vl,
291 uint64_t svcr)
292 {
293 fill_random(za_in, sizeof(za_in));
294 memset(za_out, 0, sizeof(za_out));
295 }
296
check_za(struct syscall_cfg * cfg,int sve_vl,int sme_vl,uint64_t svcr)297 static int check_za(struct syscall_cfg *cfg, int sve_vl, int sme_vl,
298 uint64_t svcr)
299 {
300 size_t reg_size = sme_vl * sme_vl;
301 int errors = 0;
302
303 if (!(svcr & SVCR_ZA_MASK))
304 return 0;
305
306 if (memcmp(za_in, za_out, reg_size) != 0) {
307 ksft_print_msg("SME VL %d ZA does not match\n", sme_vl);
308 errors++;
309 }
310
311 return errors;
312 }
313
/* Prepare a register set's buffers before the syscall */
typedef void (*setup_fn)(struct syscall_cfg *cfg, int sve_vl, int sme_vl,
			 uint64_t svcr);
/* Validate a register set after the syscall; returns an error count */
typedef int (*check_fn)(struct syscall_cfg *cfg, int sve_vl, int sme_vl,
			uint64_t svcr);

/*
 * Each set of registers has a setup function which is called before
 * the syscall to fill values in a global variable for loading by the
 * test code and a check function which validates that the results are
 * as expected.  Vector lengths are passed everywhere, a vector length
 * of 0 should be treated as do not test.
 */
static struct {
	setup_fn setup;
	check_fn check;
} regset[] = {
	{ setup_gpr, check_gpr },
	{ setup_fpr, check_fpr },
	{ setup_z, check_z },
	{ setup_p, check_p },
	{ setup_ffr, check_ffr },
	{ setup_svcr, check_svcr },
	{ setup_za, check_za },
};
338
do_test(struct syscall_cfg * cfg,int sve_vl,int sme_vl,uint64_t svcr)339 static bool do_test(struct syscall_cfg *cfg, int sve_vl, int sme_vl,
340 uint64_t svcr)
341 {
342 int errors = 0;
343 int i;
344
345 for (i = 0; i < ARRAY_SIZE(regset); i++)
346 regset[i].setup(cfg, sve_vl, sme_vl, svcr);
347
348 do_syscall(sve_vl, sme_vl);
349
350 for (i = 0; i < ARRAY_SIZE(regset); i++)
351 errors += regset[i].check(cfg, sve_vl, sme_vl, svcr);
352
353 return errors == 0;
354 }
355
/*
 * Run the full set of register preservation tests for one syscall:
 * FPSIMD only, then each supported SVE VL, and for each SVE VL each
 * supported SME VL in every combination of PSTATE.SM/ZA.
 */
static void test_one_syscall(struct syscall_cfg *cfg)
{
	int sve_vq, sve_vl;
	int sme_vq, sme_vl;

	/* FPSIMD only case */
	ksft_test_result(do_test(cfg, 0, default_sme_vl, 0),
			 "%s FPSIMD\n", cfg->name);

	if (!(getauxval(AT_HWCAP) & HWCAP_SVE))
		return;

	for (sve_vq = SVE_VQ_MAX; sve_vq > 0; --sve_vq) {
		sve_vl = prctl(PR_SVE_SET_VL, sve_vq * 16);
		if (sve_vl == -1)
			ksft_exit_fail_msg("PR_SVE_SET_VL failed: %s (%d)\n",
					   strerror(errno), errno);

		sve_vl &= PR_SVE_VL_LEN_MASK;

		/* Skip ahead to the VL the kernel actually selected */
		if (sve_vq != sve_vq_from_vl(sve_vl))
			sve_vq = sve_vq_from_vl(sve_vl);

		ksft_test_result(do_test(cfg, sve_vl, default_sme_vl, 0),
				 "%s SVE VL %d\n", cfg->name, sve_vl);

		if (!(getauxval(AT_HWCAP2) & HWCAP2_SME))
			continue;

		for (sme_vq = SVE_VQ_MAX; sme_vq > 0; --sme_vq) {
			sme_vl = prctl(PR_SME_SET_VL, sme_vq * 16);
			if (sme_vl == -1)
				ksft_exit_fail_msg("PR_SME_SET_VL failed: %s (%d)\n",
						   strerror(errno), errno);

			sme_vl &= PR_SME_VL_LEN_MASK;

			/* Skip ahead to the VL the kernel actually selected */
			if (sme_vq != sve_vq_from_vl(sme_vl))
				sme_vq = sve_vq_from_vl(sme_vl);

			/* Exercise all three combinations of SM and ZA */
			ksft_test_result(do_test(cfg, sve_vl, sme_vl,
						 SVCR_ZA_MASK | SVCR_SM_MASK),
					 "%s SVE VL %d/SME VL %d SM+ZA\n",
					 cfg->name, sve_vl, sme_vl);
			ksft_test_result(do_test(cfg, sve_vl, sme_vl,
						 SVCR_SM_MASK),
					 "%s SVE VL %d/SME VL %d SM\n",
					 cfg->name, sve_vl, sme_vl);
			ksft_test_result(do_test(cfg, sve_vl, sme_vl,
						 SVCR_ZA_MASK),
					 "%s SVE VL %d/SME VL %d ZA\n",
					 cfg->name, sve_vl, sme_vl);
		}
	}
}
411
sve_count_vls(void)412 int sve_count_vls(void)
413 {
414 unsigned int vq;
415 int vl_count = 0;
416 int vl;
417
418 if (!(getauxval(AT_HWCAP) & HWCAP_SVE))
419 return 0;
420
421 /*
422 * Enumerate up to SVE_VQ_MAX vector lengths
423 */
424 for (vq = SVE_VQ_MAX; vq > 0; --vq) {
425 vl = prctl(PR_SVE_SET_VL, vq * 16);
426 if (vl == -1)
427 ksft_exit_fail_msg("PR_SVE_SET_VL failed: %s (%d)\n",
428 strerror(errno), errno);
429
430 vl &= PR_SVE_VL_LEN_MASK;
431
432 if (vq != sve_vq_from_vl(vl))
433 vq = sve_vq_from_vl(vl);
434
435 vl_count++;
436 }
437
438 return vl_count;
439 }
440
sme_count_vls(void)441 int sme_count_vls(void)
442 {
443 unsigned int vq;
444 int vl_count = 0;
445 int vl;
446
447 if (!(getauxval(AT_HWCAP2) & HWCAP2_SME))
448 return 0;
449
450 /* Ensure we configure a SME VL, used to flag if SVCR is set */
451 default_sme_vl = 16;
452
453 /*
454 * Enumerate up to SVE_VQ_MAX vector lengths
455 */
456 for (vq = SVE_VQ_MAX; vq > 0; --vq) {
457 vl = prctl(PR_SME_SET_VL, vq * 16);
458 if (vl == -1)
459 ksft_exit_fail_msg("PR_SME_SET_VL failed: %s (%d)\n",
460 strerror(errno), errno);
461
462 vl &= PR_SME_VL_LEN_MASK;
463
464 if (vq != sve_vq_from_vl(vl))
465 vq = sve_vq_from_vl(vl);
466
467 vl_count++;
468 }
469
470 return vl_count;
471 }
472
main(void)473 int main(void)
474 {
475 int i;
476 int tests = 1; /* FPSIMD */
477
478 srandom(getpid());
479
480 ksft_print_header();
481 tests += sve_count_vls();
482 tests += (sve_count_vls() * sme_count_vls()) * 3;
483 ksft_set_plan(ARRAY_SIZE(syscalls) * tests);
484
485 if (getauxval(AT_HWCAP2) & HWCAP2_SME_FA64)
486 ksft_print_msg("SME with FA64\n");
487 else if (getauxval(AT_HWCAP2) & HWCAP2_SME)
488 ksft_print_msg("SME without FA64\n");
489
490 for (i = 0; i < ARRAY_SIZE(syscalls); i++)
491 test_one_syscall(&syscalls[i]);
492
493 ksft_print_cnts();
494
495 return 0;
496 }
497