1 // SPDX-License-Identifier: GPL-2.0-only
2 /*
3  * Cryptographic API.
4  *
5  * Copyright (c) 2017-present, Facebook, Inc.
6  */
7 #include <linux/crypto.h>
8 #include <linux/init.h>
9 #include <linux/interrupt.h>
10 #include <linux/mm.h>
11 #include <linux/module.h>
12 #include <linux/net.h>
13 #include <linux/vmalloc.h>
14 #include <linux/zstd.h>
15 #include <crypto/internal/scompress.h>
16 
17 
18 #define ZSTD_DEF_LEVEL	3
19 
20 struct zstd_ctx {
21 	zstd_cctx *cctx;
22 	zstd_dctx *dctx;
23 	void *cwksp;
24 	void *dwksp;
25 };
26 
zstd_params(void)27 static zstd_parameters zstd_params(void)
28 {
29 	return zstd_get_params(ZSTD_DEF_LEVEL, 0);
30 }
31 
zstd_comp_init(struct zstd_ctx * ctx)32 static int zstd_comp_init(struct zstd_ctx *ctx)
33 {
34 	int ret = 0;
35 	const zstd_parameters params = zstd_params();
36 	const size_t wksp_size = zstd_cctx_workspace_bound(&params.cParams);
37 
38 	ctx->cwksp = vzalloc(wksp_size);
39 	if (!ctx->cwksp) {
40 		ret = -ENOMEM;
41 		goto out;
42 	}
43 
44 	ctx->cctx = zstd_init_cctx(ctx->cwksp, wksp_size);
45 	if (!ctx->cctx) {
46 		ret = -EINVAL;
47 		goto out_free;
48 	}
49 out:
50 	return ret;
51 out_free:
52 	vfree(ctx->cwksp);
53 	goto out;
54 }
55 
zstd_decomp_init(struct zstd_ctx * ctx)56 static int zstd_decomp_init(struct zstd_ctx *ctx)
57 {
58 	int ret = 0;
59 	const size_t wksp_size = zstd_dctx_workspace_bound();
60 
61 	ctx->dwksp = vzalloc(wksp_size);
62 	if (!ctx->dwksp) {
63 		ret = -ENOMEM;
64 		goto out;
65 	}
66 
67 	ctx->dctx = zstd_init_dctx(ctx->dwksp, wksp_size);
68 	if (!ctx->dctx) {
69 		ret = -EINVAL;
70 		goto out_free;
71 	}
72 out:
73 	return ret;
74 out_free:
75 	vfree(ctx->dwksp);
76 	goto out;
77 }
78 
zstd_comp_exit(struct zstd_ctx * ctx)79 static void zstd_comp_exit(struct zstd_ctx *ctx)
80 {
81 	vfree(ctx->cwksp);
82 	ctx->cwksp = NULL;
83 	ctx->cctx = NULL;
84 }
85 
zstd_decomp_exit(struct zstd_ctx * ctx)86 static void zstd_decomp_exit(struct zstd_ctx *ctx)
87 {
88 	vfree(ctx->dwksp);
89 	ctx->dwksp = NULL;
90 	ctx->dctx = NULL;
91 }
92 
__zstd_init(void * ctx)93 static int __zstd_init(void *ctx)
94 {
95 	int ret;
96 
97 	ret = zstd_comp_init(ctx);
98 	if (ret)
99 		return ret;
100 	ret = zstd_decomp_init(ctx);
101 	if (ret)
102 		zstd_comp_exit(ctx);
103 	return ret;
104 }
105 
zstd_alloc_ctx(void)106 static void *zstd_alloc_ctx(void)
107 {
108 	int ret;
109 	struct zstd_ctx *ctx;
110 
111 	ctx = kzalloc(sizeof(*ctx), GFP_KERNEL);
112 	if (!ctx)
113 		return ERR_PTR(-ENOMEM);
114 
115 	ret = __zstd_init(ctx);
116 	if (ret) {
117 		kfree(ctx);
118 		return ERR_PTR(ret);
119 	}
120 
121 	return ctx;
122 }
123 
__zstd_exit(void * ctx)124 static void __zstd_exit(void *ctx)
125 {
126 	zstd_comp_exit(ctx);
127 	zstd_decomp_exit(ctx);
128 }
129 
zstd_free_ctx(void * ctx)130 static void zstd_free_ctx(void *ctx)
131 {
132 	__zstd_exit(ctx);
133 	kfree_sensitive(ctx);
134 }
135 
__zstd_compress(const u8 * src,unsigned int slen,u8 * dst,unsigned int * dlen,void * ctx)136 static int __zstd_compress(const u8 *src, unsigned int slen,
137 			   u8 *dst, unsigned int *dlen, void *ctx)
138 {
139 	size_t out_len;
140 	struct zstd_ctx *zctx = ctx;
141 	const zstd_parameters params = zstd_params();
142 
143 	out_len = zstd_compress_cctx(zctx->cctx, dst, *dlen, src, slen, &params);
144 	if (zstd_is_error(out_len))
145 		return -EINVAL;
146 	*dlen = out_len;
147 	return 0;
148 }
149 
zstd_scompress(struct crypto_scomp * tfm,const u8 * src,unsigned int slen,u8 * dst,unsigned int * dlen,void * ctx)150 static int zstd_scompress(struct crypto_scomp *tfm, const u8 *src,
151 			  unsigned int slen, u8 *dst, unsigned int *dlen,
152 			  void *ctx)
153 {
154 	return __zstd_compress(src, slen, dst, dlen, ctx);
155 }
156 
__zstd_decompress(const u8 * src,unsigned int slen,u8 * dst,unsigned int * dlen,void * ctx)157 static int __zstd_decompress(const u8 *src, unsigned int slen,
158 			     u8 *dst, unsigned int *dlen, void *ctx)
159 {
160 	size_t out_len;
161 	struct zstd_ctx *zctx = ctx;
162 
163 	out_len = zstd_decompress_dctx(zctx->dctx, dst, *dlen, src, slen);
164 	if (zstd_is_error(out_len))
165 		return -EINVAL;
166 	*dlen = out_len;
167 	return 0;
168 }
169 
zstd_sdecompress(struct crypto_scomp * tfm,const u8 * src,unsigned int slen,u8 * dst,unsigned int * dlen,void * ctx)170 static int zstd_sdecompress(struct crypto_scomp *tfm, const u8 *src,
171 			    unsigned int slen, u8 *dst, unsigned int *dlen,
172 			    void *ctx)
173 {
174 	return __zstd_decompress(src, slen, dst, dlen, ctx);
175 }
176 
177 static struct scomp_alg scomp = {
178 	.alloc_ctx		= zstd_alloc_ctx,
179 	.free_ctx		= zstd_free_ctx,
180 	.compress		= zstd_scompress,
181 	.decompress		= zstd_sdecompress,
182 	.base			= {
183 		.cra_name	= "zstd",
184 		.cra_driver_name = "zstd-scomp",
185 		.cra_module	 = THIS_MODULE,
186 	}
187 };
188 
zstd_mod_init(void)189 static int __init zstd_mod_init(void)
190 {
191 	return crypto_register_scomp(&scomp);
192 }
193 
zstd_mod_fini(void)194 static void __exit zstd_mod_fini(void)
195 {
196 	crypto_unregister_scomp(&scomp);
197 }
198 
199 subsys_initcall(zstd_mod_init);
200 module_exit(zstd_mod_fini);
201 
202 MODULE_LICENSE("GPL");
203 MODULE_DESCRIPTION("Zstd Compression Algorithm");
204 MODULE_ALIAS_CRYPTO("zstd");
205