1 // SPDX-License-Identifier: GPL-2.0-only
2 /*
3 * Cryptographic API.
4 *
5 * Copyright (c) 2017-present, Facebook, Inc.
6 */
7 #include <linux/crypto.h>
8 #include <linux/init.h>
9 #include <linux/interrupt.h>
10 #include <linux/mm.h>
11 #include <linux/module.h>
12 #include <linux/net.h>
13 #include <linux/vmalloc.h>
14 #include <linux/zstd.h>
15 #include <crypto/internal/acompress.h>
16 #include <crypto/scatterwalk.h>
17
18
/* Compression level fed to zstd_get_params() in zstd_alloc_stream(). */
#define ZSTD_DEF_LEVEL 3
/* log2 of ZSTD_MAX_SIZE. */
#define ZSTD_MAX_WINDOWLOG 18
/*
 * BIT(18) = 256 KiB.  Used as the source-size hint for zstd_get_params()
 * and as the maximum window size for the decompression stream (both when
 * sizing its workspace and when initialising it).
 */
#define ZSTD_MAX_SIZE BIT(ZSTD_MAX_WINDOWLOG)

/*
 * Context owned by one stream of zstd_streams.
 *
 * The compression and decompression contexts are both constructed inside
 * the same trailing workspace, and every operation re-initialises the one
 * it needs before use, so only the most recently initialised of cctx/dctx
 * is meaningful at any time.
 */
struct zstd_ctx {
	zstd_cctx *cctx;	/* from zstd_init_cctx() / zstd_init_cstream() */
	zstd_dctx *dctx;	/* from zstd_init_dctx() / zstd_init_dstream() */
	size_t wksp_size;	/* size in bytes of wksp[] */
	zstd_parameters params;	/* parameters for ZSTD_DEF_LEVEL / ZSTD_MAX_SIZE */
	u8 wksp[] __aligned(8);	/* zstd workspace, allocated together with the struct */
};

/* Held around crypto_acomp_alloc_streams() in zstd_init(). */
static DEFINE_MUTEX(zstd_stream_lock);
32
zstd_alloc_stream(void)33 static void *zstd_alloc_stream(void)
34 {
35 zstd_parameters params;
36 struct zstd_ctx *ctx;
37 size_t wksp_size;
38
39 params = zstd_get_params(ZSTD_DEF_LEVEL, ZSTD_MAX_SIZE);
40
41 wksp_size = max_t(size_t,
42 zstd_cstream_workspace_bound(¶ms.cParams),
43 zstd_dstream_workspace_bound(ZSTD_MAX_SIZE));
44 if (!wksp_size)
45 return ERR_PTR(-EINVAL);
46
47 ctx = kvmalloc(sizeof(*ctx) + wksp_size, GFP_KERNEL);
48 if (!ctx)
49 return ERR_PTR(-ENOMEM);
50
51 ctx->params = params;
52 ctx->wksp_size = wksp_size;
53
54 return ctx;
55 }
56
/*
 * zstd_free_stream - release a context returned by zstd_alloc_stream()
 * @stream: the context; header and workspace were allocated together,
 *          so a single kvfree() releases both.
 */
static void zstd_free_stream(void *stream)
{
	struct zstd_ctx *zctx = stream;

	kvfree(zctx);
}
61
62 static struct crypto_acomp_streams zstd_streams = {
63 .alloc_ctx = zstd_alloc_stream,
64 .free_ctx = zstd_free_stream,
65 };
66
zstd_init(struct crypto_acomp * acomp_tfm)67 static int zstd_init(struct crypto_acomp *acomp_tfm)
68 {
69 int ret = 0;
70
71 mutex_lock(&zstd_stream_lock);
72 ret = crypto_acomp_alloc_streams(&zstd_streams);
73 mutex_unlock(&zstd_stream_lock);
74
75 return ret;
76 }
77
zstd_exit(struct crypto_acomp * acomp_tfm)78 static void zstd_exit(struct crypto_acomp *acomp_tfm)
79 {
80 crypto_acomp_free_streams(&zstd_streams);
81 }
82
zstd_compress_one(struct acomp_req * req,struct zstd_ctx * ctx,const void * src,void * dst,unsigned int * dlen)83 static int zstd_compress_one(struct acomp_req *req, struct zstd_ctx *ctx,
84 const void *src, void *dst, unsigned int *dlen)
85 {
86 size_t out_len;
87
88 ctx->cctx = zstd_init_cctx(ctx->wksp, ctx->wksp_size);
89 if (!ctx->cctx)
90 return -EINVAL;
91
92 out_len = zstd_compress_cctx(ctx->cctx, dst, req->dlen, src, req->slen,
93 &ctx->params);
94 if (zstd_is_error(out_len))
95 return -EINVAL;
96
97 *dlen = out_len;
98
99 return 0;
100 }
101
zstd_compress(struct acomp_req * req)102 static int zstd_compress(struct acomp_req *req)
103 {
104 struct crypto_acomp_stream *s;
105 unsigned int pos, scur, dcur;
106 unsigned int total_out = 0;
107 bool data_available = true;
108 zstd_out_buffer outbuf;
109 struct acomp_walk walk;
110 zstd_in_buffer inbuf;
111 struct zstd_ctx *ctx;
112 size_t pending_bytes;
113 size_t num_bytes;
114 int ret;
115
116 s = crypto_acomp_lock_stream_bh(&zstd_streams);
117 ctx = s->ctx;
118
119 ret = acomp_walk_virt(&walk, req, true);
120 if (ret)
121 goto out;
122
123 ctx->cctx = zstd_init_cstream(&ctx->params, 0, ctx->wksp, ctx->wksp_size);
124 if (!ctx->cctx) {
125 ret = -EINVAL;
126 goto out;
127 }
128
129 do {
130 dcur = acomp_walk_next_dst(&walk);
131 if (!dcur) {
132 ret = -ENOSPC;
133 goto out;
134 }
135
136 outbuf.pos = 0;
137 outbuf.dst = (u8 *)walk.dst.virt.addr;
138 outbuf.size = dcur;
139
140 do {
141 scur = acomp_walk_next_src(&walk);
142 if (dcur == req->dlen && scur == req->slen) {
143 ret = zstd_compress_one(req, ctx, walk.src.virt.addr,
144 walk.dst.virt.addr, &total_out);
145 acomp_walk_done_src(&walk, scur);
146 acomp_walk_done_dst(&walk, dcur);
147 goto out;
148 }
149
150 if (scur) {
151 inbuf.pos = 0;
152 inbuf.src = walk.src.virt.addr;
153 inbuf.size = scur;
154 } else {
155 data_available = false;
156 break;
157 }
158
159 num_bytes = zstd_compress_stream(ctx->cctx, &outbuf, &inbuf);
160 if (ZSTD_isError(num_bytes)) {
161 ret = -EIO;
162 goto out;
163 }
164
165 pending_bytes = zstd_flush_stream(ctx->cctx, &outbuf);
166 if (ZSTD_isError(pending_bytes)) {
167 ret = -EIO;
168 goto out;
169 }
170 acomp_walk_done_src(&walk, inbuf.pos);
171 } while (dcur != outbuf.pos);
172
173 total_out += outbuf.pos;
174 acomp_walk_done_dst(&walk, dcur);
175 } while (data_available);
176
177 pos = outbuf.pos;
178 num_bytes = zstd_end_stream(ctx->cctx, &outbuf);
179 if (ZSTD_isError(num_bytes))
180 ret = -EIO;
181 else
182 total_out += (outbuf.pos - pos);
183
184 out:
185 if (ret)
186 req->dlen = 0;
187 else
188 req->dlen = total_out;
189
190 crypto_acomp_unlock_stream_bh(s);
191
192 return ret;
193 }
194
zstd_decompress_one(struct acomp_req * req,struct zstd_ctx * ctx,const void * src,void * dst,unsigned int * dlen)195 static int zstd_decompress_one(struct acomp_req *req, struct zstd_ctx *ctx,
196 const void *src, void *dst, unsigned int *dlen)
197 {
198 size_t out_len;
199
200 ctx->dctx = zstd_init_dctx(ctx->wksp, ctx->wksp_size);
201 if (!ctx->dctx)
202 return -EINVAL;
203
204 out_len = zstd_decompress_dctx(ctx->dctx, dst, req->dlen, src, req->slen);
205 if (zstd_is_error(out_len))
206 return -EINVAL;
207
208 *dlen = out_len;
209
210 return 0;
211 }
212
zstd_decompress(struct acomp_req * req)213 static int zstd_decompress(struct acomp_req *req)
214 {
215 struct crypto_acomp_stream *s;
216 unsigned int total_out = 0;
217 unsigned int scur, dcur;
218 zstd_out_buffer outbuf;
219 struct acomp_walk walk;
220 zstd_in_buffer inbuf;
221 struct zstd_ctx *ctx;
222 size_t pending_bytes;
223 int ret;
224
225 s = crypto_acomp_lock_stream_bh(&zstd_streams);
226 ctx = s->ctx;
227
228 ret = acomp_walk_virt(&walk, req, true);
229 if (ret)
230 goto out;
231
232 ctx->dctx = zstd_init_dstream(ZSTD_MAX_SIZE, ctx->wksp, ctx->wksp_size);
233 if (!ctx->dctx) {
234 ret = -EINVAL;
235 goto out;
236 }
237
238 do {
239 scur = acomp_walk_next_src(&walk);
240 if (scur) {
241 inbuf.pos = 0;
242 inbuf.size = scur;
243 inbuf.src = walk.src.virt.addr;
244 } else {
245 break;
246 }
247
248 do {
249 dcur = acomp_walk_next_dst(&walk);
250 if (dcur == req->dlen && scur == req->slen) {
251 ret = zstd_decompress_one(req, ctx, walk.src.virt.addr,
252 walk.dst.virt.addr, &total_out);
253 acomp_walk_done_dst(&walk, dcur);
254 acomp_walk_done_src(&walk, scur);
255 goto out;
256 }
257
258 if (!dcur) {
259 ret = -ENOSPC;
260 goto out;
261 }
262
263 outbuf.pos = 0;
264 outbuf.dst = (u8 *)walk.dst.virt.addr;
265 outbuf.size = dcur;
266
267 pending_bytes = zstd_decompress_stream(ctx->dctx, &outbuf, &inbuf);
268 if (ZSTD_isError(pending_bytes)) {
269 ret = -EIO;
270 goto out;
271 }
272
273 total_out += outbuf.pos;
274
275 acomp_walk_done_dst(&walk, outbuf.pos);
276 } while (inbuf.pos != scur);
277
278 acomp_walk_done_src(&walk, scur);
279 } while (ret == 0);
280
281 out:
282 if (ret)
283 req->dlen = 0;
284 else
285 req->dlen = total_out;
286
287 crypto_acomp_unlock_stream_bh(s);
288
289 return ret;
290 }
291
292 static struct acomp_alg zstd_acomp = {
293 .base = {
294 .cra_name = "zstd",
295 .cra_driver_name = "zstd-generic",
296 .cra_flags = CRYPTO_ALG_REQ_VIRT,
297 .cra_module = THIS_MODULE,
298 },
299 .init = zstd_init,
300 .exit = zstd_exit,
301 .compress = zstd_compress,
302 .decompress = zstd_decompress,
303 };
304
zstd_mod_init(void)305 static int __init zstd_mod_init(void)
306 {
307 return crypto_register_acomp(&zstd_acomp);
308 }
309
zstd_mod_fini(void)310 static void __exit zstd_mod_fini(void)
311 {
312 crypto_unregister_acomp(&zstd_acomp);
313 }
314
315 module_init(zstd_mod_init);
316 module_exit(zstd_mod_fini);
317
318 MODULE_LICENSE("GPL");
319 MODULE_DESCRIPTION("Zstd Compression Algorithm");
320 MODULE_ALIAS_CRYPTO("zstd");
321