1 /* SPDX-License-Identifier: GPL-2.0-or-later */
2 /*
3 * SM4 Cipher Algorithm, AES-NI/AVX optimized.
4 * as specified in
5 * https://tools.ietf.org/id/draft-ribose-cfrg-sm4-10.html
6 *
7 * Copyright (c) 2021, Alibaba Group.
8 * Copyright (c) 2021 Tianjia Zhang <tianjia.zhang@linux.alibaba.com>
9 */
10
11 #include <linux/module.h>
12 #include <linux/crypto.h>
13 #include <linux/kernel.h>
14 #include <asm/simd.h>
15 #include <crypto/internal/simd.h>
16 #include <crypto/internal/skcipher.h>
17 #include <crypto/sm4.h>
18 #include "sm4-avx.h"
19
20 #define SM4_CRYPT8_BLOCK_SIZE (SM4_BLOCK_SIZE * 8)
21
22 asmlinkage void sm4_aesni_avx_crypt4(const u32 *rk, u8 *dst,
23 const u8 *src, int nblocks);
24 asmlinkage void sm4_aesni_avx_crypt8(const u32 *rk, u8 *dst,
25 const u8 *src, int nblocks);
26 asmlinkage void sm4_aesni_avx_ctr_enc_blk8(const u32 *rk, u8 *dst,
27 const u8 *src, u8 *iv);
28 asmlinkage void sm4_aesni_avx_cbc_dec_blk8(const u32 *rk, u8 *dst,
29 const u8 *src, u8 *iv);
30 asmlinkage void sm4_aesni_avx_cfb_dec_blk8(const u32 *rk, u8 *dst,
31 const u8 *src, u8 *iv);
32
/*
 * sm4_skcipher_setkey() - expand @key into the tfm's struct sm4_ctx.
 * @tfm:     transform whose context receives the round keys.
 * @key:     raw key bytes.
 * @key_len: length of @key in bytes.
 *
 * Return: whatever sm4_expandkey() returns (0 or a negative errno).
 */
static int sm4_skcipher_setkey(struct crypto_skcipher *tfm, const u8 *key,
			       unsigned int key_len)
{
	return sm4_expandkey(crypto_skcipher_ctx(tfm), key, key_len);
}
40
ecb_do_crypt(struct skcipher_request * req,const u32 * rkey)41 static int ecb_do_crypt(struct skcipher_request *req, const u32 *rkey)
42 {
43 struct skcipher_walk walk;
44 unsigned int nbytes;
45 int err;
46
47 err = skcipher_walk_virt(&walk, req, false);
48
49 while ((nbytes = walk.nbytes) > 0) {
50 const u8 *src = walk.src.virt.addr;
51 u8 *dst = walk.dst.virt.addr;
52
53 kernel_fpu_begin();
54 while (nbytes >= SM4_CRYPT8_BLOCK_SIZE) {
55 sm4_aesni_avx_crypt8(rkey, dst, src, 8);
56 dst += SM4_CRYPT8_BLOCK_SIZE;
57 src += SM4_CRYPT8_BLOCK_SIZE;
58 nbytes -= SM4_CRYPT8_BLOCK_SIZE;
59 }
60 while (nbytes >= SM4_BLOCK_SIZE) {
61 unsigned int nblocks = min(nbytes >> 4, 4u);
62 sm4_aesni_avx_crypt4(rkey, dst, src, nblocks);
63 dst += nblocks * SM4_BLOCK_SIZE;
64 src += nblocks * SM4_BLOCK_SIZE;
65 nbytes -= nblocks * SM4_BLOCK_SIZE;
66 }
67 kernel_fpu_end();
68
69 err = skcipher_walk_done(&walk, nbytes);
70 }
71
72 return err;
73 }
74
sm4_avx_ecb_encrypt(struct skcipher_request * req)75 int sm4_avx_ecb_encrypt(struct skcipher_request *req)
76 {
77 struct crypto_skcipher *tfm = crypto_skcipher_reqtfm(req);
78 struct sm4_ctx *ctx = crypto_skcipher_ctx(tfm);
79
80 return ecb_do_crypt(req, ctx->rkey_enc);
81 }
82 EXPORT_SYMBOL_GPL(sm4_avx_ecb_encrypt);
83
sm4_avx_ecb_decrypt(struct skcipher_request * req)84 int sm4_avx_ecb_decrypt(struct skcipher_request *req)
85 {
86 struct crypto_skcipher *tfm = crypto_skcipher_reqtfm(req);
87 struct sm4_ctx *ctx = crypto_skcipher_ctx(tfm);
88
89 return ecb_do_crypt(req, ctx->rkey_dec);
90 }
91 EXPORT_SYMBOL_GPL(sm4_avx_ecb_decrypt);
92
sm4_cbc_encrypt(struct skcipher_request * req)93 int sm4_cbc_encrypt(struct skcipher_request *req)
94 {
95 struct crypto_skcipher *tfm = crypto_skcipher_reqtfm(req);
96 struct sm4_ctx *ctx = crypto_skcipher_ctx(tfm);
97 struct skcipher_walk walk;
98 unsigned int nbytes;
99 int err;
100
101 err = skcipher_walk_virt(&walk, req, false);
102
103 while ((nbytes = walk.nbytes) > 0) {
104 const u8 *iv = walk.iv;
105 const u8 *src = walk.src.virt.addr;
106 u8 *dst = walk.dst.virt.addr;
107
108 while (nbytes >= SM4_BLOCK_SIZE) {
109 crypto_xor_cpy(dst, src, iv, SM4_BLOCK_SIZE);
110 sm4_crypt_block(ctx->rkey_enc, dst, dst);
111 iv = dst;
112 src += SM4_BLOCK_SIZE;
113 dst += SM4_BLOCK_SIZE;
114 nbytes -= SM4_BLOCK_SIZE;
115 }
116 if (iv != walk.iv)
117 memcpy(walk.iv, iv, SM4_BLOCK_SIZE);
118
119 err = skcipher_walk_done(&walk, nbytes);
120 }
121
122 return err;
123 }
124 EXPORT_SYMBOL_GPL(sm4_cbc_encrypt);
125
/*
 * sm4_avx_cbc_decrypt() - CBC-decrypt a request.
 * @req:   skcipher request; the chaining value lives in walk.iv.
 * @bsize: number of bytes consumed by one call of @func.
 * @func:  assembly helper that CBC-decrypts one @bsize batch with
 *         ctx->rkey_dec and walk.iv. This function does not touch walk.iv
 *         after the call, so @func presumably advances it itself.
 *
 * Whole @bsize batches are handed to @func. Remaining whole blocks are
 * decrypted up to eight at a time with sm4_aesni_avx_crypt8() and then
 * un-chained in C (P[i] = D(C[i]) ^ C[i-1]). A trailing partial block is
 * returned to the walk unprocessed.
 *
 * Return: 0 or a negative errno from the skcipher walk.
 */
int sm4_avx_cbc_decrypt(struct skcipher_request *req,
			unsigned int bsize, sm4_crypt_func func)
{
	struct crypto_skcipher *tfm = crypto_skcipher_reqtfm(req);
	struct sm4_ctx *ctx = crypto_skcipher_ctx(tfm);
	struct skcipher_walk walk;
	unsigned int nbytes;
	int err;

	err = skcipher_walk_virt(&walk, req, false);

	while ((nbytes = walk.nbytes) > 0) {
		const u8 *src = walk.src.virt.addr;
		u8 *dst = walk.dst.virt.addr;

		kernel_fpu_begin();

		while (nbytes >= bsize) {
			func(ctx->rkey_dec, dst, src, walk.iv);
			dst += bsize;
			src += bsize;
			nbytes -= bsize;
		}

		while (nbytes >= SM4_BLOCK_SIZE) {
			/* Raw block decryptions D(C[i]), before un-chaining. */
			u8 dec[SM4_BLOCK_SIZE * 8];
			u8 next_iv[SM4_BLOCK_SIZE];
			unsigned int nblocks = min(nbytes >> 4, 8u);
			unsigned int len = nblocks * SM4_BLOCK_SIZE;
			unsigned int i;

			sm4_aesni_avx_crypt8(ctx->rkey_dec, dec, src, nblocks);

			/*
			 * The last ciphertext block is the next chaining value.
			 * Save it before an in-place (dst == src) operation
			 * overwrites it.
			 */
			memcpy(next_iv, src + len - SM4_BLOCK_SIZE,
			       SM4_BLOCK_SIZE);

			/*
			 * Work from the last block down to the second so that,
			 * when dst == src, C[i-1] is still intact when it is
			 * read. Offsets are taken from the start of the chunk,
			 * so no pointer ever leaves the walk buffer.
			 */
			for (i = nblocks - 1; i > 0; i--)
				crypto_xor_cpy(dst + i * SM4_BLOCK_SIZE,
					       src + (i - 1) * SM4_BLOCK_SIZE,
					       &dec[i * SM4_BLOCK_SIZE],
					       SM4_BLOCK_SIZE);

			/* P[0] = D(C[0]) ^ IV. */
			crypto_xor_cpy(dst, walk.iv, dec, SM4_BLOCK_SIZE);
			memcpy(walk.iv, next_iv, SM4_BLOCK_SIZE);

			dst += len;
			src += len;
			nbytes -= len;
		}

		kernel_fpu_end();
		err = skcipher_walk_done(&walk, nbytes);
	}

	return err;
}
EXPORT_SYMBOL_GPL(sm4_avx_cbc_decrypt);
184
cbc_decrypt(struct skcipher_request * req)185 static int cbc_decrypt(struct skcipher_request *req)
186 {
187 return sm4_avx_cbc_decrypt(req, SM4_CRYPT8_BLOCK_SIZE,
188 sm4_aesni_avx_cbc_dec_blk8);
189 }
190
sm4_cfb_encrypt(struct skcipher_request * req)191 int sm4_cfb_encrypt(struct skcipher_request *req)
192 {
193 struct crypto_skcipher *tfm = crypto_skcipher_reqtfm(req);
194 struct sm4_ctx *ctx = crypto_skcipher_ctx(tfm);
195 struct skcipher_walk walk;
196 unsigned int nbytes;
197 int err;
198
199 err = skcipher_walk_virt(&walk, req, false);
200
201 while ((nbytes = walk.nbytes) > 0) {
202 u8 keystream[SM4_BLOCK_SIZE];
203 const u8 *iv = walk.iv;
204 const u8 *src = walk.src.virt.addr;
205 u8 *dst = walk.dst.virt.addr;
206
207 while (nbytes >= SM4_BLOCK_SIZE) {
208 sm4_crypt_block(ctx->rkey_enc, keystream, iv);
209 crypto_xor_cpy(dst, src, keystream, SM4_BLOCK_SIZE);
210 iv = dst;
211 src += SM4_BLOCK_SIZE;
212 dst += SM4_BLOCK_SIZE;
213 nbytes -= SM4_BLOCK_SIZE;
214 }
215 if (iv != walk.iv)
216 memcpy(walk.iv, iv, SM4_BLOCK_SIZE);
217
218 /* tail */
219 if (walk.nbytes == walk.total && nbytes > 0) {
220 sm4_crypt_block(ctx->rkey_enc, keystream, walk.iv);
221 crypto_xor_cpy(dst, src, keystream, nbytes);
222 nbytes = 0;
223 }
224
225 err = skcipher_walk_done(&walk, nbytes);
226 }
227
228 return err;
229 }
230 EXPORT_SYMBOL_GPL(sm4_cfb_encrypt);
231
/*
 * sm4_avx_cfb_decrypt() - CFB-decrypt a request.
 * @req:   skcipher request; the feedback value lives in walk.iv.
 * @bsize: number of bytes consumed by one call of @func.
 * @func:  assembly helper that CFB-decrypts one @bsize batch with
 *         ctx->rkey_enc and walk.iv (presumably advancing walk.iv itself;
 *         this function does not touch it after the call).
 *
 * The keystream blocks are E(IV), E(C[0]), ..., so they can be computed in
 * parallel. Whole @bsize batches go to @func. Remaining whole blocks are
 * handled up to eight at a time via sm4_aesni_avx_crypt8(). On the final
 * walk step a trailing partial block is XORed with E(walk.iv).
 *
 * Return: 0 or a negative errno from the skcipher walk.
 */
int sm4_avx_cfb_decrypt(struct skcipher_request *req,
			unsigned int bsize, sm4_crypt_func func)
{
	struct crypto_skcipher *tfm = crypto_skcipher_reqtfm(req);
	struct sm4_ctx *ctx = crypto_skcipher_ctx(tfm);
	struct skcipher_walk walk;
	int err;

	err = skcipher_walk_virt(&walk, req, false);

	while (walk.nbytes) {
		const u8 *in = walk.src.virt.addr;
		u8 *out = walk.dst.virt.addr;
		unsigned int left = walk.nbytes;

		kernel_fpu_begin();

		for (; left >= bsize; left -= bsize) {
			func(ctx->rkey_enc, out, in, walk.iv);
			in += bsize;
			out += bsize;
		}

		while (left >= SM4_BLOCK_SIZE) {
			u8 ks[SM4_BLOCK_SIZE * 8];
			unsigned int nblocks = min(left >> 4, 8u);
			unsigned int len = nblocks * SM4_BLOCK_SIZE;

			/*
			 * Build the cipher inputs IV, C[0] .. C[n-2]. Update the
			 * IV to C[n-1] before the XOR below can overwrite the
			 * ciphertext in place.
			 */
			memcpy(ks, walk.iv, SM4_BLOCK_SIZE);
			if (nblocks > 1)
				memcpy(ks + SM4_BLOCK_SIZE, in,
				       len - SM4_BLOCK_SIZE);
			memcpy(walk.iv, in + len - SM4_BLOCK_SIZE,
			       SM4_BLOCK_SIZE);

			sm4_aesni_avx_crypt8(ctx->rkey_enc, ks, ks, nblocks);

			crypto_xor_cpy(out, in, ks, len);
			in += len;
			out += len;
			left -= len;
		}

		kernel_fpu_end();

		if (walk.nbytes == walk.total && left > 0) {
			u8 ks[SM4_BLOCK_SIZE];

			sm4_crypt_block(ctx->rkey_enc, ks, walk.iv);
			crypto_xor_cpy(out, in, ks, left);
			left = 0;
		}

		err = skcipher_walk_done(&walk, left);
	}

	return err;
}
EXPORT_SYMBOL_GPL(sm4_avx_cfb_decrypt);
294
cfb_decrypt(struct skcipher_request * req)295 static int cfb_decrypt(struct skcipher_request *req)
296 {
297 return sm4_avx_cfb_decrypt(req, SM4_CRYPT8_BLOCK_SIZE,
298 sm4_aesni_avx_cfb_dec_blk8);
299 }
300
/*
 * sm4_avx_ctr_crypt() - CTR-mode encrypt/decrypt a request.
 * @req:   skcipher request; walk.iv holds the big-endian counter block that
 *         crypto_inc() advances.
 * @bsize: number of bytes consumed by one call of @func.
 * @func:  assembly helper that processes one @bsize batch with
 *         ctx->rkey_enc and walk.iv (presumably advancing the counter
 *         itself; this function does not touch it after the call).
 *
 * Whole @bsize batches go to @func. Remaining whole blocks are handled up to
 * eight at a time: counters are expanded in C, encrypted with
 * sm4_aesni_avx_crypt8() and XORed into the data. On the final walk step a
 * trailing partial block consumes one more counter.
 *
 * Return: 0 or a negative errno from the skcipher walk.
 */
int sm4_avx_ctr_crypt(struct skcipher_request *req,
		      unsigned int bsize, sm4_crypt_func func)
{
	struct crypto_skcipher *tfm = crypto_skcipher_reqtfm(req);
	struct sm4_ctx *ctx = crypto_skcipher_ctx(tfm);
	struct skcipher_walk walk;
	unsigned int nbytes;
	int err;

	err = skcipher_walk_virt(&walk, req, false);

	while ((nbytes = walk.nbytes) > 0) {
		const u8 *src = walk.src.virt.addr;
		u8 *dst = walk.dst.virt.addr;

		kernel_fpu_begin();

		while (nbytes >= bsize) {
			func(ctx->rkey_enc, dst, src, walk.iv);
			dst += bsize;
			src += bsize;
			nbytes -= bsize;
		}

		while (nbytes >= SM4_BLOCK_SIZE) {
			u8 keystream[SM4_BLOCK_SIZE * 8];
			unsigned int nblocks = min(nbytes >> 4, 8u);
			unsigned int i;	/* unsigned: compared with nblocks */

			/* Lay out nblocks consecutive counter values. */
			for (i = 0; i < nblocks; i++) {
				memcpy(&keystream[i * SM4_BLOCK_SIZE],
				       walk.iv, SM4_BLOCK_SIZE);
				crypto_inc(walk.iv, SM4_BLOCK_SIZE);
			}
			sm4_aesni_avx_crypt8(ctx->rkey_enc, keystream,
					     keystream, nblocks);

			crypto_xor_cpy(dst, src, keystream,
				       nblocks * SM4_BLOCK_SIZE);
			dst += nblocks * SM4_BLOCK_SIZE;
			src += nblocks * SM4_BLOCK_SIZE;
			nbytes -= nblocks * SM4_BLOCK_SIZE;
		}

		kernel_fpu_end();

		/* Final partial block, only on the last walk step. */
		if (walk.nbytes == walk.total && nbytes > 0) {
			u8 keystream[SM4_BLOCK_SIZE];

			memcpy(keystream, walk.iv, SM4_BLOCK_SIZE);
			crypto_inc(walk.iv, SM4_BLOCK_SIZE);

			sm4_crypt_block(ctx->rkey_enc, keystream, keystream);

			crypto_xor_cpy(dst, src, keystream, nbytes);
			nbytes = 0;
		}

		err = skcipher_walk_done(&walk, nbytes);
	}

	return err;
}
EXPORT_SYMBOL_GPL(sm4_avx_ctr_crypt);
368
ctr_crypt(struct skcipher_request * req)369 static int ctr_crypt(struct skcipher_request *req)
370 {
371 return sm4_avx_ctr_crypt(req, SM4_CRYPT8_BLOCK_SIZE,
372 sm4_aesni_avx_ctr_enc_blk8);
373 }
374
375 static struct skcipher_alg sm4_aesni_avx_skciphers[] = {
376 {
377 .base = {
378 .cra_name = "__ecb(sm4)",
379 .cra_driver_name = "__ecb-sm4-aesni-avx",
380 .cra_priority = 400,
381 .cra_flags = CRYPTO_ALG_INTERNAL,
382 .cra_blocksize = SM4_BLOCK_SIZE,
383 .cra_ctxsize = sizeof(struct sm4_ctx),
384 .cra_module = THIS_MODULE,
385 },
386 .min_keysize = SM4_KEY_SIZE,
387 .max_keysize = SM4_KEY_SIZE,
388 .walksize = 8 * SM4_BLOCK_SIZE,
389 .setkey = sm4_skcipher_setkey,
390 .encrypt = sm4_avx_ecb_encrypt,
391 .decrypt = sm4_avx_ecb_decrypt,
392 }, {
393 .base = {
394 .cra_name = "__cbc(sm4)",
395 .cra_driver_name = "__cbc-sm4-aesni-avx",
396 .cra_priority = 400,
397 .cra_flags = CRYPTO_ALG_INTERNAL,
398 .cra_blocksize = SM4_BLOCK_SIZE,
399 .cra_ctxsize = sizeof(struct sm4_ctx),
400 .cra_module = THIS_MODULE,
401 },
402 .min_keysize = SM4_KEY_SIZE,
403 .max_keysize = SM4_KEY_SIZE,
404 .ivsize = SM4_BLOCK_SIZE,
405 .walksize = 8 * SM4_BLOCK_SIZE,
406 .setkey = sm4_skcipher_setkey,
407 .encrypt = sm4_cbc_encrypt,
408 .decrypt = cbc_decrypt,
409 }, {
410 .base = {
411 .cra_name = "__cfb(sm4)",
412 .cra_driver_name = "__cfb-sm4-aesni-avx",
413 .cra_priority = 400,
414 .cra_flags = CRYPTO_ALG_INTERNAL,
415 .cra_blocksize = 1,
416 .cra_ctxsize = sizeof(struct sm4_ctx),
417 .cra_module = THIS_MODULE,
418 },
419 .min_keysize = SM4_KEY_SIZE,
420 .max_keysize = SM4_KEY_SIZE,
421 .ivsize = SM4_BLOCK_SIZE,
422 .chunksize = SM4_BLOCK_SIZE,
423 .walksize = 8 * SM4_BLOCK_SIZE,
424 .setkey = sm4_skcipher_setkey,
425 .encrypt = sm4_cfb_encrypt,
426 .decrypt = cfb_decrypt,
427 }, {
428 .base = {
429 .cra_name = "__ctr(sm4)",
430 .cra_driver_name = "__ctr-sm4-aesni-avx",
431 .cra_priority = 400,
432 .cra_flags = CRYPTO_ALG_INTERNAL,
433 .cra_blocksize = 1,
434 .cra_ctxsize = sizeof(struct sm4_ctx),
435 .cra_module = THIS_MODULE,
436 },
437 .min_keysize = SM4_KEY_SIZE,
438 .max_keysize = SM4_KEY_SIZE,
439 .ivsize = SM4_BLOCK_SIZE,
440 .chunksize = SM4_BLOCK_SIZE,
441 .walksize = 8 * SM4_BLOCK_SIZE,
442 .setkey = sm4_skcipher_setkey,
443 .encrypt = ctr_crypt,
444 .decrypt = ctr_crypt,
445 }
446 };
447
448 static struct simd_skcipher_alg *
449 simd_sm4_aesni_avx_skciphers[ARRAY_SIZE(sm4_aesni_avx_skciphers)];
450
sm4_init(void)451 static int __init sm4_init(void)
452 {
453 const char *feature_name;
454
455 if (!boot_cpu_has(X86_FEATURE_AVX) ||
456 !boot_cpu_has(X86_FEATURE_AES) ||
457 !boot_cpu_has(X86_FEATURE_OSXSAVE)) {
458 pr_info("AVX or AES-NI instructions are not detected.\n");
459 return -ENODEV;
460 }
461
462 if (!cpu_has_xfeatures(XFEATURE_MASK_SSE | XFEATURE_MASK_YMM,
463 &feature_name)) {
464 pr_info("CPU feature '%s' is not supported.\n", feature_name);
465 return -ENODEV;
466 }
467
468 return simd_register_skciphers_compat(sm4_aesni_avx_skciphers,
469 ARRAY_SIZE(sm4_aesni_avx_skciphers),
470 simd_sm4_aesni_avx_skciphers);
471 }
472
sm4_exit(void)473 static void __exit sm4_exit(void)
474 {
475 simd_unregister_skciphers(sm4_aesni_avx_skciphers,
476 ARRAY_SIZE(sm4_aesni_avx_skciphers),
477 simd_sm4_aesni_avx_skciphers);
478 }
479
/* Load/unload hooks. */
module_init(sm4_init);
module_exit(sm4_exit);

/* Module metadata; the aliases allow autoloading by algorithm/driver name. */
MODULE_LICENSE("GPL v2");
MODULE_AUTHOR("Tianjia Zhang <tianjia.zhang@linux.alibaba.com>");
MODULE_DESCRIPTION("SM4 Cipher Algorithm, AES-NI/AVX optimized");
MODULE_ALIAS_CRYPTO("sm4");
MODULE_ALIAS_CRYPTO("sm4-aesni-avx");
488