// SPDX-License-Identifier: GPL-2.0-or-later
/*
 * Cryptographic API.
 *
 * Cipher operations.
 *
 * Copyright (c) 2002 James Morris <jmorris@intercode.com.au>
 *               2002 Adam J. Richter <adam@yggdrasil.com>
 *               2004 Jean-Luc Cooke <jlcooke@certainkey.com>
 */

#include <crypto/scatterwalk.h>
#include <linux/crypto.h>
#include <linux/errno.h>
#include <linux/kernel.h>
#include <linux/mm.h>
#include <linux/module.h>
#include <linux/scatterlist.h>
#include <linux/slab.h>

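/*
 * Per-step walk flags: SLOW means the step runs through an allocated bounce
 * buffer, COPY means the source is staged in walk->page before processing,
 * DIFF means source and destination map to different pages, and SLEEP means
 * allocations made during the walk may use GFP_KERNEL.
 */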
enum {
	SKCIPHER_WALK_SLOW = 1 << 0,
	SKCIPHER_WALK_COPY = 1 << 1,
	SKCIPHER_WALK_DIFF = 1 << 2,
	SKCIPHER_WALK_SLEEP = 1 << 3,
};

static inline gfp_t skcipher_walk_gfp(struct skcipher_walk *walk)
{
	return walk->flags & SKCIPHER_WALK_SLEEP ? GFP_KERNEL : GFP_ATOMIC;
}

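/*
 * Advance the walk position by @nbytes without copying any data, moving to
 * later scatterlist entries as needed.
 */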
void scatterwalk_skip(struct scatter_walk *walk, unsigned int nbytes)
{
	struct scatterlist *sg = walk->sg;

	nbytes += walk->offset - sg->offset;

	while (nbytes > sg->length) {
		nbytes -= sg->length;
		sg = sg_next(sg);
	}
	walk->sg = sg;
	walk->offset = sg->offset + nbytes;
}
EXPORT_SYMBOL_GPL(scatterwalk_skip);

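/*
 * Copy @nbytes from the walk's current position into @buf, mapping each
 * chunk in turn and advancing the walk past the copied data.
 */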
inline void memcpy_from_scatterwalk(void *buf, struct scatter_walk *walk,
				    unsigned int nbytes)
{
	do {
		unsigned int to_copy;

		to_copy = scatterwalk_next(walk, nbytes);
		memcpy(buf, walk->addr, to_copy);
		scatterwalk_done_src(walk, to_copy);
		buf += to_copy;
		nbytes -= to_copy;
	} while (nbytes);
}
EXPORT_SYMBOL_GPL(memcpy_from_scatterwalk);

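/*
 * Copy @nbytes from @buf to the walk's current position, finishing each
 * destination chunk and advancing the walk past the written data.
 */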
inline void memcpy_to_scatterwalk(struct scatter_walk *walk, const void *buf,
				  unsigned int nbytes)
{
	do {
		unsigned int to_copy;

		to_copy = scatterwalk_next(walk, nbytes);
		memcpy(walk->addr, buf, to_copy);
		scatterwalk_done_dst(walk, to_copy);
		buf += to_copy;
		nbytes -= to_copy;
	} while (nbytes);
}
EXPORT_SYMBOL_GPL(memcpy_to_scatterwalk);

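/*
 * Copy @nbytes out of @sg, starting @start bytes into it, into the linear
 * buffer @buf.
 */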
void memcpy_from_sglist(void *buf, struct scatterlist *sg,
			unsigned int start, unsigned int nbytes)
{
	struct scatter_walk walk;

	if (unlikely(nbytes == 0)) /* in case sg == NULL */
		return;

	scatterwalk_start_at_pos(&walk, sg, start);
	memcpy_from_scatterwalk(buf, &walk, nbytes);
}
EXPORT_SYMBOL_GPL(memcpy_from_sglist);

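/*
 * Copy @nbytes from the linear buffer @buf into @sg, starting @start bytes
 * into it.
 */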
void memcpy_to_sglist(struct scatterlist *sg, unsigned int start,
		      const void *buf, unsigned int nbytes)
{
	struct scatter_walk walk;

	if (unlikely(nbytes == 0)) /* in case sg == NULL */
		return;

	scatterwalk_start_at_pos(&walk, sg, start);
	memcpy_to_scatterwalk(&walk, buf, nbytes);
}
EXPORT_SYMBOL_GPL(memcpy_to_sglist);

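/*
 * Copy @nbytes from @src to @dst.  An skcipher walk is used so that
 * overlapping or identical pages are handled correctly: chunks whose source
 * and destination already share an address are left alone.
 */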
void memcpy_sglist(struct scatterlist *dst, struct scatterlist *src,
		   unsigned int nbytes)
{
	struct skcipher_walk walk = {};

	if (unlikely(nbytes == 0)) /* in case sg == NULL */
		return;

	walk.total = nbytes;

	scatterwalk_start(&walk.in, src);
	scatterwalk_start(&walk.out, dst);

	skcipher_walk_first(&walk, true);
	do {
		if (walk.src.virt.addr != walk.dst.virt.addr)
			memcpy(walk.dst.virt.addr, walk.src.virt.addr,
			       walk.nbytes);
		skcipher_walk_done(&walk, 0);
	} while (walk.nbytes);
}
EXPORT_SYMBOL_GPL(memcpy_sglist);

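/*
 * Fast-forward @src by @len bytes.  If @len lands on an entry boundary the
 * original list is returned; otherwise the two-entry array @dst is set up to
 * describe the remainder of the partial entry chained to the rest of the
 * list.
 */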
struct scatterlist *scatterwalk_ffwd(struct scatterlist dst[2],
				     struct scatterlist *src,
				     unsigned int len)
{
	for (;;) {
		if (!len)
			return src;

		if (src->length > len)
			break;

		len -= src->length;
		src = sg_next(src);
	}

	sg_init_table(dst, 2);
	sg_set_page(dst, sg_page(src), src->length - len, src->offset + len);
	scatterwalk_crypto_chain(dst, sg_next(src), 2);

	return dst;
}
EXPORT_SYMBOL_GPL(scatterwalk_ffwd);

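/*
 * Slow path: the next @bsize bytes are not contiguous in the source or
 * destination, so gather them into an aligned bounce buffer and point both
 * sides of the walk at it.  skcipher_walk_done() copies the result back out.
 */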
static int skcipher_next_slow(struct skcipher_walk *walk, unsigned int bsize)
{
	unsigned alignmask = walk->alignmask;
	unsigned n;
	void *buffer;

	if (!walk->buffer)
		walk->buffer = walk->page;
	buffer = walk->buffer;
	if (!buffer) {
		/* Min size for a buffer of bsize bytes aligned to alignmask */
		n = bsize + (alignmask & ~(crypto_tfm_ctx_alignment() - 1));

		buffer = kzalloc(n, skcipher_walk_gfp(walk));
		if (!buffer)
			return skcipher_walk_done(walk, -ENOMEM);
		walk->buffer = buffer;
	}

	buffer = PTR_ALIGN(buffer, alignmask + 1);
	memcpy_from_scatterwalk(buffer, &walk->in, bsize);
	walk->out.__addr = buffer;
	walk->in.__addr = walk->out.addr;

	walk->nbytes = bsize;
	walk->flags |= SKCIPHER_WALK_SLOW;

	return 0;
}

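/*
 * Copy path: the data is contiguous but misaligned, so stage the source in
 * the preallocated walk->page and process it in place there.  The copy back
 * to the real destination happens in skcipher_walk_done().
 */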
static int skcipher_next_copy(struct skcipher_walk *walk)
{
	void *tmp = walk->page;

	scatterwalk_map(&walk->in);
	memcpy(tmp, walk->in.addr, walk->nbytes);
	scatterwalk_unmap(&walk->in);
	/*
	 * walk->in is advanced later when the number of bytes actually
	 * processed (which might be less than walk->nbytes) is known.
	 */

	walk->in.__addr = tmp;
	walk->out.__addr = tmp;
	return 0;
}

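/*
 * Fast path: map the destination directly and reuse its mapping for the
 * source when both refer to the same place; otherwise map the source
 * separately and note that the two differ.
 */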
static int skcipher_next_fast(struct skcipher_walk *walk)
{
	unsigned long diff;

	diff = offset_in_page(walk->in.offset) -
	       offset_in_page(walk->out.offset);
	diff |= (u8 *)(sg_page(walk->in.sg) + (walk->in.offset >> PAGE_SHIFT)) -
		(u8 *)(sg_page(walk->out.sg) + (walk->out.offset >> PAGE_SHIFT));

	scatterwalk_map(&walk->out);
	walk->in.__addr = walk->out.__addr;

	if (diff) {
		walk->flags |= SKCIPHER_WALK_DIFF;
		scatterwalk_map(&walk->in);
	}

	return 0;
}

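/*
 * Set up the next step of the walk: clamp the step to what is contiguous in
 * both scatterlists, then pick the fast, copy, or slow strategy for mapping
 * it.
 */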
static int skcipher_walk_next(struct skcipher_walk *walk)
{
	unsigned int bsize;
	unsigned int n;

	n = walk->total;
	bsize = min(walk->stride, max(n, walk->blocksize));
	n = scatterwalk_clamp(&walk->in, n);
	n = scatterwalk_clamp(&walk->out, n);

	if (unlikely(n < bsize)) {
		if (unlikely(walk->total < walk->blocksize))
			return skcipher_walk_done(walk, -EINVAL);

slow_path:
		return skcipher_next_slow(walk, bsize);
	}
	walk->nbytes = n;

	if (unlikely((walk->in.offset | walk->out.offset) & walk->alignmask)) {
		if (!walk->page) {
			gfp_t gfp = skcipher_walk_gfp(walk);

			walk->page = (void *)__get_free_page(gfp);
			if (!walk->page)
				goto slow_path;
		}
		walk->flags |= SKCIPHER_WALK_COPY;
		return skcipher_next_copy(walk);
	}

	return skcipher_next_fast(walk);
}

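/*
 * The caller's IV is not sufficiently aligned, so duplicate it into an
 * aligned buffer that also reserves one aligned stride of space ahead of
 * the IV for later reuse as walk->buffer.
 */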
static int skcipher_copy_iv(struct skcipher_walk *walk)
{
	unsigned alignmask = walk->alignmask;
	unsigned ivsize = walk->ivsize;
	unsigned aligned_stride = ALIGN(walk->stride, alignmask + 1);
	unsigned size;
	u8 *iv;

	/* Min size for a buffer of stride + ivsize, aligned to alignmask */
	size = aligned_stride + ivsize +
	       (alignmask & ~(crypto_tfm_ctx_alignment() - 1));

	walk->buffer = kmalloc(size, skcipher_walk_gfp(walk));
	if (!walk->buffer)
		return -ENOMEM;

	iv = PTR_ALIGN(walk->buffer, alignmask + 1) + aligned_stride;

	walk->iv = memcpy(iv, walk->iv, walk->ivsize);
	return 0;
}

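/*
 * Begin a walk: record whether it may sleep, realign the IV if necessary,
 * and map the first step.  Must not be called from hardirq context.
 */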
int skcipher_walk_first(struct skcipher_walk *walk, bool atomic)
{
	if (WARN_ON_ONCE(in_hardirq()))
		return -EDEADLK;

	walk->flags = atomic ? 0 : SKCIPHER_WALK_SLEEP;

	walk->buffer = NULL;
	if (unlikely(((unsigned long)walk->iv & walk->alignmask))) {
		int err = skcipher_copy_iv(walk);

		if (err)
			return err;
	}

	walk->page = NULL;

	return skcipher_walk_next(walk);
}
EXPORT_SYMBOL_GPL(skcipher_walk_first);

/**
 * skcipher_walk_done() - finish one step of a skcipher_walk
 * @walk: the skcipher_walk
 * @res: number of bytes *not* processed (>= 0) from walk->nbytes,
 *	 or a -errno value to terminate the walk due to an error
 *
 * This function cleans up after one step of walking through the source and
 * destination scatterlists, and advances to the next step if applicable.
 * walk->nbytes is set to the number of bytes available in the next step,
 * walk->total is set to the new total number of bytes remaining, and
 * walk->{src,dst}.virt.addr is set to the next pair of data pointers.  If
 * there is no more data, or if an error occurred (i.e. -errno return), then
 * walk->nbytes and walk->total are set to 0 and all resources owned by the
 * skcipher_walk are freed.
 *
 * Return: 0 or a -errno value.  If @res was a -errno value then it will be
 *	   returned, but other errors may occur too.
 */
int skcipher_walk_done(struct skcipher_walk *walk, int res)
{
	unsigned int n = walk->nbytes; /* num bytes processed this step */
	unsigned int total = 0; /* new total remaining */

	if (!n)
		goto finish;

	if (likely(res >= 0)) {
		n -= res; /* subtract num bytes *not* processed */
		total = walk->total - n;
	}

	if (likely(!(walk->flags & (SKCIPHER_WALK_SLOW |
				    SKCIPHER_WALK_COPY |
				    SKCIPHER_WALK_DIFF)))) {
		scatterwalk_advance(&walk->in, n);
	} else if (walk->flags & SKCIPHER_WALK_DIFF) {
		scatterwalk_done_src(&walk->in, n);
	} else if (walk->flags & SKCIPHER_WALK_COPY) {
		scatterwalk_advance(&walk->in, n);
		scatterwalk_map(&walk->out);
		memcpy(walk->out.addr, walk->page, n);
	} else { /* SKCIPHER_WALK_SLOW */
		if (res > 0) {
			/*
			 * Didn't process all bytes.  Either the algorithm is
			 * broken, or this was the last step and it turned out
			 * the message wasn't evenly divisible into blocks but
			 * the algorithm requires it.
			 */
			res = -EINVAL;
			total = 0;
		} else
			memcpy_to_scatterwalk(&walk->out, walk->out.addr, n);
		goto dst_done;
	}

	scatterwalk_done_dst(&walk->out, n);
dst_done:

	if (res > 0)
		res = 0;

	walk->total = total;
	walk->nbytes = 0;

	if (total) {
		if (walk->flags & SKCIPHER_WALK_SLEEP)
			cond_resched();
		walk->flags &= ~(SKCIPHER_WALK_SLOW | SKCIPHER_WALK_COPY |
				 SKCIPHER_WALK_DIFF);
		return skcipher_walk_next(walk);
	}

finish:
	/* Short-circuit for the common/fast path. */
	if (!((unsigned long)walk->buffer | (unsigned long)walk->page))
		goto out;

	if (walk->iv != walk->oiv)
		memcpy(walk->oiv, walk->iv, walk->ivsize);
	if (walk->buffer != walk->page)
		kfree(walk->buffer);
	if (walk->page)
		free_page((unsigned long)walk->page);

out:
	return res;
}
EXPORT_SYMBOL_GPL(skcipher_walk_done);

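/*
 * Illustrative sketch of how a cipher implementation typically drives the
 * walk API above.  The function my_cipher_crypt() and its crypt_blocks()
 * helper are hypothetical, shown only to make the calling convention
 * concrete; the walk fields and the skcipher_walk_first() /
 * skcipher_walk_done() contract used here are as documented above.
 *
 *	static int my_cipher_crypt(struct skcipher_walk *walk)
 *	{
 *		int err = skcipher_walk_first(walk, false);
 *
 *		while (walk->nbytes) {
 *			unsigned int n = walk->nbytes;
 *
 *			// Process only whole blocks in this step.
 *			n -= n % walk->blocksize;
 *			crypt_blocks(walk->dst.virt.addr,
 *				     walk->src.virt.addr, n);
 *
 *			// Report the leftover byte count; the walk maps the
 *			// next chunk, or cleans up when the data runs out.
 *			err = skcipher_walk_done(walk, walk->nbytes - n);
 *		}
 *		return err;
 *	}
 */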