1 // SPDX-License-Identifier: GPL-2.0-or-later
2 /*
3 * x64 SIMD accelerated ChaCha and XChaCha stream ciphers,
4 * including ChaCha20 (RFC7539)
5 *
6 * Copyright (C) 2015 Martin Willi
7 */
8
#include <crypto/algapi.h>
#include <crypto/internal/chacha.h>
#include <crypto/internal/simd.h>
#include <crypto/internal/skcipher.h>
#include <linux/kernel.h>
#include <linux/module.h>
#include <linux/sizes.h>
#include <linux/string.h>
#include <asm/simd.h>
17
/*
 * Assembly routines.  The *_xor variants take the ChaCha state, output and
 * input buffers, a byte count and the round count; the "Nblock" in the name
 * is the number of ChaCha blocks one call is built for.  chacha_dosimd()
 * passes the remaining byte count as @len and advances state[12] itself
 * after each call.
 */

/* SSSE3: 1- and 4-block ChaCha, plus the HChaCha subkey routine. */
asmlinkage void chacha_block_xor_ssse3(u32 *state, u8 *dst, const u8 *src,
				       unsigned int len, int nrounds);
asmlinkage void chacha_4block_xor_ssse3(u32 *state, u8 *dst, const u8 *src,
					unsigned int len, int nrounds);
asmlinkage void hchacha_block_ssse3(const u32 *state, u32 *out, int nrounds);

/* AVX2: 2-, 4- and 8-block ChaCha. */
asmlinkage void chacha_2block_xor_avx2(u32 *state, u8 *dst, const u8 *src,
				       unsigned int len, int nrounds);
asmlinkage void chacha_4block_xor_avx2(u32 *state, u8 *dst, const u8 *src,
				       unsigned int len, int nrounds);
asmlinkage void chacha_8block_xor_avx2(u32 *state, u8 *dst, const u8 *src,
				       unsigned int len, int nrounds);

/* AVX-512VL: 2-, 4- and 8-block ChaCha. */
asmlinkage void chacha_2block_xor_avx512vl(u32 *state, u8 *dst, const u8 *src,
					   unsigned int len, int nrounds);
asmlinkage void chacha_4block_xor_avx512vl(u32 *state, u8 *dst, const u8 *src,
					   unsigned int len, int nrounds);
asmlinkage void chacha_8block_xor_avx512vl(u32 *state, u8 *dst, const u8 *src,
					   unsigned int len, int nrounds);

/*
 * CPU-feature switches.  They default to false, are enabled only from
 * chacha_simd_mod_init(), and are marked __ro_after_init.
 *   chacha_use_simd     - SSSE3 present; gates every SIMD path in this file
 *   chacha_use_avx2     - AVX2 with OS-enabled SSE+YMM state
 *   chacha_use_avx512vl - AVX-512VL + AVX-512BW (only set when AVX2 is set)
 */
static __ro_after_init DEFINE_STATIC_KEY_FALSE(chacha_use_simd);
static __ro_after_init DEFINE_STATIC_KEY_FALSE(chacha_use_avx2);
static __ro_after_init DEFINE_STATIC_KEY_FALSE(chacha_use_avx512vl);
41
chacha_advance(unsigned int len,unsigned int maxblocks)42 static unsigned int chacha_advance(unsigned int len, unsigned int maxblocks)
43 {
44 len = min(len, maxblocks * CHACHA_BLOCK_SIZE);
45 return round_up(len, CHACHA_BLOCK_SIZE) / CHACHA_BLOCK_SIZE;
46 }
47
/*
 * chacha_dosimd - XOR ChaCha keystream over @bytes of @src into @dst
 * @state:   ChaCha state; state[12] is advanced by the blocks consumed
 * @dst:     output buffer
 * @src:     input buffer
 * @bytes:   number of bytes to process
 * @nrounds: number of ChaCha rounds
 *
 * Uses the widest enabled implementation: AVX-512VL, then AVX2, then SSSE3.
 * Whole 8-block (or, for SSSE3, 4-block) strides are consumed in a loop.
 * The remaining tail goes to the narrowest routine whose width covers it,
 * and state[12] is advanced by chacha_advance() for that tail.
 *
 * Both callers in this file bracket the call with kernel_fpu_begin() and
 * kernel_fpu_end().
 */
static void chacha_dosimd(u32 *state, u8 *dst, const u8 *src,
			  unsigned int bytes, int nrounds)
{
	if (IS_ENABLED(CONFIG_AS_AVX512) &&
	    static_branch_likely(&chacha_use_avx512vl)) {
		/*
		 * The routine is given the full remaining length.  The loop
		 * assumes it handles exactly 8 blocks per call -- confirm
		 * against the assembly.
		 */
		while (bytes >= CHACHA_BLOCK_SIZE * 8) {
			chacha_8block_xor_avx512vl(state, dst, src, bytes,
						   nrounds);
			bytes -= CHACHA_BLOCK_SIZE * 8;
			src += CHACHA_BLOCK_SIZE * 8;
			dst += CHACHA_BLOCK_SIZE * 8;
			state[12] += 8;
		}
		if (bytes > CHACHA_BLOCK_SIZE * 4) {
			chacha_8block_xor_avx512vl(state, dst, src, bytes,
						   nrounds);
			state[12] += chacha_advance(bytes, 8);
			return;
		}
		if (bytes > CHACHA_BLOCK_SIZE * 2) {
			chacha_4block_xor_avx512vl(state, dst, src, bytes,
						   nrounds);
			state[12] += chacha_advance(bytes, 4);
			return;
		}
		/* Any non-empty tail of up to two blocks. */
		if (bytes) {
			chacha_2block_xor_avx512vl(state, dst, src, bytes,
						   nrounds);
			state[12] += chacha_advance(bytes, 2);
			return;
		}
		/* bytes == 0: the paths below do nothing. */
	}

	if (static_branch_likely(&chacha_use_avx2)) {
		while (bytes >= CHACHA_BLOCK_SIZE * 8) {
			chacha_8block_xor_avx2(state, dst, src, bytes, nrounds);
			bytes -= CHACHA_BLOCK_SIZE * 8;
			src += CHACHA_BLOCK_SIZE * 8;
			dst += CHACHA_BLOCK_SIZE * 8;
			state[12] += 8;
		}
		if (bytes > CHACHA_BLOCK_SIZE * 4) {
			chacha_8block_xor_avx2(state, dst, src, bytes, nrounds);
			state[12] += chacha_advance(bytes, 8);
			return;
		}
		if (bytes > CHACHA_BLOCK_SIZE * 2) {
			chacha_4block_xor_avx2(state, dst, src, bytes, nrounds);
			state[12] += chacha_advance(bytes, 4);
			return;
		}
		if (bytes > CHACHA_BLOCK_SIZE) {
			chacha_2block_xor_avx2(state, dst, src, bytes, nrounds);
			state[12] += chacha_advance(bytes, 2);
			return;
		}
		/* A tail of at most one block falls through to SSSE3. */
	}

	/* SSSE3 baseline; also finishes the AVX2 single-block tail. */
	while (bytes >= CHACHA_BLOCK_SIZE * 4) {
		chacha_4block_xor_ssse3(state, dst, src, bytes, nrounds);
		bytes -= CHACHA_BLOCK_SIZE * 4;
		src += CHACHA_BLOCK_SIZE * 4;
		dst += CHACHA_BLOCK_SIZE * 4;
		state[12] += 4;
	}
	if (bytes > CHACHA_BLOCK_SIZE) {
		chacha_4block_xor_ssse3(state, dst, src, bytes, nrounds);
		state[12] += chacha_advance(bytes, 4);
		return;
	}
	if (bytes) {
		chacha_block_xor_ssse3(state, dst, src, bytes, nrounds);
		state[12]++;
	}
}
123
/*
 * hchacha_block_arch - HChaCha core, SSSE3-accelerated when possible
 * @state:   initialised ChaCha state
 * @stream:  output words
 * @nrounds: number of rounds
 *
 * Uses the SSSE3 routine inside an FPU section when SIMD was enabled at init
 * and crypto_simd_usable() allows it.  Otherwise it defers to
 * hchacha_block_generic().
 */
void hchacha_block_arch(const u32 *state, u32 *stream, int nrounds)
{
	if (static_branch_likely(&chacha_use_simd) && crypto_simd_usable()) {
		kernel_fpu_begin();
		hchacha_block_ssse3(state, stream, nrounds);
		kernel_fpu_end();
		return;
	}

	hchacha_block_generic(state, stream, nrounds);
}
EXPORT_SYMBOL(hchacha_block_arch);
135
/*
 * chacha_crypt_arch - ChaCha en/decryption for library users
 * @state:   ChaCha state; its block counter is advanced as data is consumed
 * @dst:     output buffer
 * @src:     input buffer
 * @bytes:   number of bytes to process
 * @nrounds: number of rounds
 *
 * The SIMD path is taken only when it was enabled at init,
 * crypto_simd_usable() is true, and the input is longer than one block.
 * Data is then processed in pieces of at most SZ_4K bytes, each in its own
 * kernel_fpu_begin()/kernel_fpu_end() section.  (Presumably this bounds how
 * long the FPU section is held -- confirm.)  Everything else goes to
 * chacha_crypt_generic().
 */
void chacha_crypt_arch(u32 *state, u8 *dst, const u8 *src, unsigned int bytes,
		       int nrounds)
{
	if (static_branch_likely(&chacha_use_simd) && crypto_simd_usable() &&
	    bytes > CHACHA_BLOCK_SIZE) {
		/* bytes is non-zero here, so the body runs at least once. */
		while (bytes) {
			unsigned int chunk = min_t(unsigned int, bytes, SZ_4K);

			kernel_fpu_begin();
			chacha_dosimd(state, dst, src, chunk, nrounds);
			kernel_fpu_end();

			src += chunk;
			dst += chunk;
			bytes -= chunk;
		}
		return;
	}

	chacha_crypt_generic(state, dst, src, bytes, nrounds);
}
EXPORT_SYMBOL(chacha_crypt_arch);
156
chacha_simd_stream_xor(struct skcipher_request * req,const struct chacha_ctx * ctx,const u8 * iv)157 static int chacha_simd_stream_xor(struct skcipher_request *req,
158 const struct chacha_ctx *ctx, const u8 *iv)
159 {
160 u32 state[CHACHA_STATE_WORDS] __aligned(8);
161 struct skcipher_walk walk;
162 int err;
163
164 err = skcipher_walk_virt(&walk, req, false);
165
166 chacha_init(state, ctx->key, iv);
167
168 while (walk.nbytes > 0) {
169 unsigned int nbytes = walk.nbytes;
170
171 if (nbytes < walk.total)
172 nbytes = round_down(nbytes, walk.stride);
173
174 if (!static_branch_likely(&chacha_use_simd) ||
175 !crypto_simd_usable()) {
176 chacha_crypt_generic(state, walk.dst.virt.addr,
177 walk.src.virt.addr, nbytes,
178 ctx->nrounds);
179 } else {
180 kernel_fpu_begin();
181 chacha_dosimd(state, walk.dst.virt.addr,
182 walk.src.virt.addr, nbytes,
183 ctx->nrounds);
184 kernel_fpu_end();
185 }
186 err = skcipher_walk_done(&walk, walk.nbytes - nbytes);
187 }
188
189 return err;
190 }
191
chacha_simd(struct skcipher_request * req)192 static int chacha_simd(struct skcipher_request *req)
193 {
194 struct crypto_skcipher *tfm = crypto_skcipher_reqtfm(req);
195 struct chacha_ctx *ctx = crypto_skcipher_ctx(tfm);
196
197 return chacha_simd_stream_xor(req, ctx, req->iv);
198 }
199
xchacha_simd(struct skcipher_request * req)200 static int xchacha_simd(struct skcipher_request *req)
201 {
202 struct crypto_skcipher *tfm = crypto_skcipher_reqtfm(req);
203 struct chacha_ctx *ctx = crypto_skcipher_ctx(tfm);
204 u32 state[CHACHA_STATE_WORDS] __aligned(8);
205 struct chacha_ctx subctx;
206 u8 real_iv[16];
207
208 chacha_init(state, ctx->key, req->iv);
209
210 if (req->cryptlen > CHACHA_BLOCK_SIZE && crypto_simd_usable()) {
211 kernel_fpu_begin();
212 hchacha_block_ssse3(state, subctx.key, ctx->nrounds);
213 kernel_fpu_end();
214 } else {
215 hchacha_block_generic(state, subctx.key, ctx->nrounds);
216 }
217 subctx.nrounds = ctx->nrounds;
218
219 memcpy(&real_iv[0], req->iv + 24, 8);
220 memcpy(&real_iv[8], req->iv + 16, 8);
221 return chacha_simd_stream_xor(req, &subctx, real_iv);
222 }
223
224 static struct skcipher_alg algs[] = {
225 {
226 .base.cra_name = "chacha20",
227 .base.cra_driver_name = "chacha20-simd",
228 .base.cra_priority = 300,
229 .base.cra_blocksize = 1,
230 .base.cra_ctxsize = sizeof(struct chacha_ctx),
231 .base.cra_module = THIS_MODULE,
232
233 .min_keysize = CHACHA_KEY_SIZE,
234 .max_keysize = CHACHA_KEY_SIZE,
235 .ivsize = CHACHA_IV_SIZE,
236 .chunksize = CHACHA_BLOCK_SIZE,
237 .setkey = chacha20_setkey,
238 .encrypt = chacha_simd,
239 .decrypt = chacha_simd,
240 }, {
241 .base.cra_name = "xchacha20",
242 .base.cra_driver_name = "xchacha20-simd",
243 .base.cra_priority = 300,
244 .base.cra_blocksize = 1,
245 .base.cra_ctxsize = sizeof(struct chacha_ctx),
246 .base.cra_module = THIS_MODULE,
247
248 .min_keysize = CHACHA_KEY_SIZE,
249 .max_keysize = CHACHA_KEY_SIZE,
250 .ivsize = XCHACHA_IV_SIZE,
251 .chunksize = CHACHA_BLOCK_SIZE,
252 .setkey = chacha20_setkey,
253 .encrypt = xchacha_simd,
254 .decrypt = xchacha_simd,
255 }, {
256 .base.cra_name = "xchacha12",
257 .base.cra_driver_name = "xchacha12-simd",
258 .base.cra_priority = 300,
259 .base.cra_blocksize = 1,
260 .base.cra_ctxsize = sizeof(struct chacha_ctx),
261 .base.cra_module = THIS_MODULE,
262
263 .min_keysize = CHACHA_KEY_SIZE,
264 .max_keysize = CHACHA_KEY_SIZE,
265 .ivsize = XCHACHA_IV_SIZE,
266 .chunksize = CHACHA_BLOCK_SIZE,
267 .setkey = chacha12_setkey,
268 .encrypt = xchacha_simd,
269 .decrypt = xchacha_simd,
270 },
271 };
272
chacha_simd_mod_init(void)273 static int __init chacha_simd_mod_init(void)
274 {
275 if (!boot_cpu_has(X86_FEATURE_SSSE3))
276 return 0;
277
278 static_branch_enable(&chacha_use_simd);
279
280 if (boot_cpu_has(X86_FEATURE_AVX) &&
281 boot_cpu_has(X86_FEATURE_AVX2) &&
282 cpu_has_xfeatures(XFEATURE_MASK_SSE | XFEATURE_MASK_YMM, NULL)) {
283 static_branch_enable(&chacha_use_avx2);
284
285 if (IS_ENABLED(CONFIG_AS_AVX512) &&
286 boot_cpu_has(X86_FEATURE_AVX512VL) &&
287 boot_cpu_has(X86_FEATURE_AVX512BW)) /* kmovq */
288 static_branch_enable(&chacha_use_avx512vl);
289 }
290 return IS_REACHABLE(CONFIG_CRYPTO_SKCIPHER) ?
291 crypto_register_skciphers(algs, ARRAY_SIZE(algs)) : 0;
292 }
293
chacha_simd_mod_fini(void)294 static void __exit chacha_simd_mod_fini(void)
295 {
296 if (IS_REACHABLE(CONFIG_CRYPTO_SKCIPHER) && boot_cpu_has(X86_FEATURE_SSSE3))
297 crypto_unregister_skciphers(algs, ARRAY_SIZE(algs));
298 }
299
module_init(chacha_simd_mod_init);
module_exit(chacha_simd_mod_fini);

MODULE_LICENSE("GPL");
MODULE_AUTHOR("Martin Willi <martin@strongswan.org>");
MODULE_DESCRIPTION("ChaCha and XChaCha stream ciphers (x64 SIMD accelerated)");
/*
 * Aliases for each algorithm name and its "-simd" driver name, matching
 * the cra_name/cra_driver_name pairs in algs[].
 */
MODULE_ALIAS_CRYPTO("chacha20");
MODULE_ALIAS_CRYPTO("chacha20-simd");
MODULE_ALIAS_CRYPTO("xchacha20");
MODULE_ALIAS_CRYPTO("xchacha20-simd");
MODULE_ALIAS_CRYPTO("xchacha12");
MODULE_ALIAS_CRYPTO("xchacha12-simd");
312