1 // SPDX-License-Identifier: GPL-2.0-or-later
2 
3 #include <linux/kernel.h>
4 #include <linux/slab.h>
5 #include <linux/vmalloc.h>
6 #include <linux/zstd.h>
7 
8 #include "backend_zstd.h"
9 
10 struct zstd_ctx {
11 	zstd_cctx *cctx;
12 	zstd_dctx *dctx;
13 	void *cctx_mem;
14 	void *dctx_mem;
15 };
16 
17 struct zstd_params {
18 	zstd_custom_mem custom_mem;
19 	zstd_cdict *cdict;
20 	zstd_ddict *ddict;
21 	zstd_parameters cprm;
22 };
23 
24 /*
25  * For C/D dictionaries we need to provide zstd with zstd_custom_mem,
26  * which zstd uses internally to allocate/free memory when needed.
27  */
zstd_custom_alloc(void * opaque,size_t size)28 static void *zstd_custom_alloc(void *opaque, size_t size)
29 {
30 	return kvzalloc(size, GFP_NOIO | __GFP_NOWARN);
31 }
32 
/*
 * customFree callback: release memory previously obtained from
 * zstd_custom_alloc().  @opaque is not used.
 */
static void zstd_custom_free(void *opaque, void *address)
{
	/* kvfree() is the counterpart of the kvzalloc() used for allocation. */
	kvfree(address);
}
37 
zstd_release_params(struct zcomp_params * params)38 static void zstd_release_params(struct zcomp_params *params)
39 {
40 	struct zstd_params *zp = params->drv_data;
41 
42 	params->drv_data = NULL;
43 	if (!zp)
44 		return;
45 
46 	zstd_free_cdict(zp->cdict);
47 	zstd_free_ddict(zp->ddict);
48 	kfree(zp);
49 }
50 
zstd_setup_params(struct zcomp_params * params)51 static int zstd_setup_params(struct zcomp_params *params)
52 {
53 	zstd_compression_parameters prm;
54 	struct zstd_params *zp;
55 
56 	zp = kzalloc(sizeof(*zp), GFP_KERNEL);
57 	if (!zp)
58 		return -ENOMEM;
59 
60 	params->drv_data = zp;
61 	if (params->level == ZCOMP_PARAM_NO_LEVEL)
62 		params->level = zstd_default_clevel();
63 
64 	zp->cprm = zstd_get_params(params->level, PAGE_SIZE);
65 
66 	zp->custom_mem.customAlloc = zstd_custom_alloc;
67 	zp->custom_mem.customFree = zstd_custom_free;
68 
69 	prm = zstd_get_cparams(params->level, PAGE_SIZE,
70 			       params->dict_sz);
71 
72 	zp->cdict = zstd_create_cdict_byreference(params->dict,
73 						  params->dict_sz,
74 						  prm,
75 						  zp->custom_mem);
76 	if (!zp->cdict)
77 		goto error;
78 
79 	zp->ddict = zstd_create_ddict_byreference(params->dict,
80 						  params->dict_sz,
81 						  zp->custom_mem);
82 	if (!zp->ddict)
83 		goto error;
84 
85 	return 0;
86 
87 error:
88 	zstd_release_params(params);
89 	return -EINVAL;
90 }
91 
zstd_destroy(struct zcomp_ctx * ctx)92 static void zstd_destroy(struct zcomp_ctx *ctx)
93 {
94 	struct zstd_ctx *zctx = ctx->context;
95 
96 	if (!zctx)
97 		return;
98 
99 	/*
100 	 * If ->cctx_mem and ->dctx_mem were allocated then we didn't use
101 	 * C/D dictionary and ->cctx / ->dctx were "embedded" into these
102 	 * buffers.
103 	 *
104 	 * If otherwise then we need to explicitly release ->cctx / ->dctx.
105 	 */
106 	if (zctx->cctx_mem)
107 		vfree(zctx->cctx_mem);
108 	else
109 		zstd_free_cctx(zctx->cctx);
110 
111 	if (zctx->dctx_mem)
112 		vfree(zctx->dctx_mem);
113 	else
114 		zstd_free_dctx(zctx->dctx);
115 
116 	kfree(zctx);
117 }
118 
zstd_create(struct zcomp_params * params,struct zcomp_ctx * ctx)119 static int zstd_create(struct zcomp_params *params, struct zcomp_ctx *ctx)
120 {
121 	struct zstd_ctx *zctx;
122 	zstd_parameters prm;
123 	size_t sz;
124 
125 	zctx = kzalloc(sizeof(*zctx), GFP_KERNEL);
126 	if (!zctx)
127 		return -ENOMEM;
128 
129 	ctx->context = zctx;
130 	if (params->dict_sz == 0) {
131 		prm = zstd_get_params(params->level, PAGE_SIZE);
132 		sz = zstd_cctx_workspace_bound(&prm.cParams);
133 		zctx->cctx_mem = vzalloc(sz);
134 		if (!zctx->cctx_mem)
135 			goto error;
136 
137 		zctx->cctx = zstd_init_cctx(zctx->cctx_mem, sz);
138 		if (!zctx->cctx)
139 			goto error;
140 
141 		sz = zstd_dctx_workspace_bound();
142 		zctx->dctx_mem = vzalloc(sz);
143 		if (!zctx->dctx_mem)
144 			goto error;
145 
146 		zctx->dctx = zstd_init_dctx(zctx->dctx_mem, sz);
147 		if (!zctx->dctx)
148 			goto error;
149 	} else {
150 		struct zstd_params *zp = params->drv_data;
151 
152 		zctx->cctx = zstd_create_cctx_advanced(zp->custom_mem);
153 		if (!zctx->cctx)
154 			goto error;
155 
156 		zctx->dctx = zstd_create_dctx_advanced(zp->custom_mem);
157 		if (!zctx->dctx)
158 			goto error;
159 	}
160 
161 	return 0;
162 
163 error:
164 	zstd_release_params(params);
165 	zstd_destroy(ctx);
166 	return -EINVAL;
167 }
168 
zstd_compress(struct zcomp_params * params,struct zcomp_ctx * ctx,struct zcomp_req * req)169 static int zstd_compress(struct zcomp_params *params, struct zcomp_ctx *ctx,
170 			 struct zcomp_req *req)
171 {
172 	struct zstd_params *zp = params->drv_data;
173 	struct zstd_ctx *zctx = ctx->context;
174 	size_t ret;
175 
176 	if (params->dict_sz == 0)
177 		ret = zstd_compress_cctx(zctx->cctx, req->dst, req->dst_len,
178 					 req->src, req->src_len, &zp->cprm);
179 	else
180 		ret = zstd_compress_using_cdict(zctx->cctx, req->dst,
181 						req->dst_len, req->src,
182 						req->src_len,
183 						zp->cdict);
184 	if (zstd_is_error(ret))
185 		return -EINVAL;
186 	req->dst_len = ret;
187 	return 0;
188 }
189 
zstd_decompress(struct zcomp_params * params,struct zcomp_ctx * ctx,struct zcomp_req * req)190 static int zstd_decompress(struct zcomp_params *params, struct zcomp_ctx *ctx,
191 			   struct zcomp_req *req)
192 {
193 	struct zstd_params *zp = params->drv_data;
194 	struct zstd_ctx *zctx = ctx->context;
195 	size_t ret;
196 
197 	if (params->dict_sz == 0)
198 		ret = zstd_decompress_dctx(zctx->dctx, req->dst, req->dst_len,
199 					   req->src, req->src_len);
200 	else
201 		ret = zstd_decompress_using_ddict(zctx->dctx, req->dst,
202 						  req->dst_len, req->src,
203 						  req->src_len, zp->ddict);
204 	if (zstd_is_error(ret))
205 		return -EINVAL;
206 	return 0;
207 }
208 
209 const struct zcomp_ops backend_zstd = {
210 	.compress	= zstd_compress,
211 	.decompress	= zstd_decompress,
212 	.create_ctx	= zstd_create,
213 	.destroy_ctx	= zstd_destroy,
214 	.setup_params	= zstd_setup_params,
215 	.release_params	= zstd_release_params,
216 	.name		= "zstd",
217 };
218