1 // SPDX-License-Identifier: GPL-2.0-or-later 2 3 #include <linux/kernel.h> 4 #include <linux/slab.h> 5 #include <linux/vmalloc.h> 6 #include <linux/zstd.h> 7 8 #include "backend_zstd.h" 9 10 struct zstd_ctx { 11 zstd_cctx *cctx; 12 zstd_dctx *dctx; 13 void *cctx_mem; 14 void *dctx_mem; 15 }; 16 17 struct zstd_params { 18 zstd_custom_mem custom_mem; 19 zstd_cdict *cdict; 20 zstd_ddict *ddict; 21 zstd_parameters cprm; 22 }; 23 24 /* 25 * For C/D dictionaries we need to provide zstd with zstd_custom_mem, 26 * which zstd uses internally to allocate/free memory when needed. 27 */ 28 static void *zstd_custom_alloc(void *opaque, size_t size) 29 { 30 return kvzalloc(size, GFP_NOIO | __GFP_NOWARN); 31 } 32 33 static void zstd_custom_free(void *opaque, void *address) 34 { 35 kvfree(address); 36 } 37 38 static void zstd_release_params(struct zcomp_params *params) 39 { 40 struct zstd_params *zp = params->drv_data; 41 42 params->drv_data = NULL; 43 if (!zp) 44 return; 45 46 zstd_free_cdict(zp->cdict); 47 zstd_free_ddict(zp->ddict); 48 kfree(zp); 49 } 50 51 static int zstd_setup_params(struct zcomp_params *params) 52 { 53 zstd_compression_parameters prm; 54 struct zstd_params *zp; 55 56 zp = kzalloc(sizeof(*zp), GFP_KERNEL); 57 if (!zp) 58 return -ENOMEM; 59 60 params->drv_data = zp; 61 if (params->level == ZCOMP_PARAM_NOT_SET) 62 params->level = zstd_default_clevel(); 63 64 zp->cprm = zstd_get_params(params->level, PAGE_SIZE); 65 66 zp->custom_mem.customAlloc = zstd_custom_alloc; 67 zp->custom_mem.customFree = zstd_custom_free; 68 69 prm = zstd_get_cparams(params->level, PAGE_SIZE, 70 params->dict_sz); 71 72 zp->cdict = zstd_create_cdict_byreference(params->dict, 73 params->dict_sz, 74 prm, 75 zp->custom_mem); 76 if (!zp->cdict) 77 goto error; 78 79 zp->ddict = zstd_create_ddict_byreference(params->dict, 80 params->dict_sz, 81 zp->custom_mem); 82 if (!zp->ddict) 83 goto error; 84 85 return 0; 86 87 error: 88 zstd_release_params(params); 89 return -EINVAL; 90 } 91 92 static void zstd_destroy(struct zcomp_ctx *ctx) 93 { 94 struct zstd_ctx *zctx = ctx->context; 95 96 if (!zctx) 97 return; 98 99 /* 100 * If ->cctx_mem and ->dctx_mem were allocated then we didn't use 101 * C/D dictionary and ->cctx / ->dctx were "embedded" into these 102 * buffers. 103 * 104 * If otherwise then we need to explicitly release ->cctx / ->dctx. 105 */ 106 if (zctx->cctx_mem) 107 vfree(zctx->cctx_mem); 108 else 109 zstd_free_cctx(zctx->cctx); 110 111 if (zctx->dctx_mem) 112 vfree(zctx->dctx_mem); 113 else 114 zstd_free_dctx(zctx->dctx); 115 116 kfree(zctx); 117 } 118 119 static int zstd_create(struct zcomp_params *params, struct zcomp_ctx *ctx) 120 { 121 struct zstd_ctx *zctx; 122 zstd_parameters prm; 123 size_t sz; 124 125 zctx = kzalloc(sizeof(*zctx), GFP_KERNEL); 126 if (!zctx) 127 return -ENOMEM; 128 129 ctx->context = zctx; 130 if (params->dict_sz == 0) { 131 prm = zstd_get_params(params->level, PAGE_SIZE); 132 sz = zstd_cctx_workspace_bound(&prm.cParams); 133 zctx->cctx_mem = vzalloc(sz); 134 if (!zctx->cctx_mem) 135 goto error; 136 137 zctx->cctx = zstd_init_cctx(zctx->cctx_mem, sz); 138 if (!zctx->cctx) 139 goto error; 140 141 sz = zstd_dctx_workspace_bound(); 142 zctx->dctx_mem = vzalloc(sz); 143 if (!zctx->dctx_mem) 144 goto error; 145 146 zctx->dctx = zstd_init_dctx(zctx->dctx_mem, sz); 147 if (!zctx->dctx) 148 goto error; 149 } else { 150 struct zstd_params *zp = params->drv_data; 151 152 zctx->cctx = zstd_create_cctx_advanced(zp->custom_mem); 153 if (!zctx->cctx) 154 goto error; 155 156 zctx->dctx = zstd_create_dctx_advanced(zp->custom_mem); 157 if (!zctx->dctx) 158 goto error; 159 } 160 161 return 0; 162 163 error: 164 zstd_release_params(params); 165 zstd_destroy(ctx); 166 return -EINVAL; 167 } 168 169 static int zstd_compress(struct zcomp_params *params, struct zcomp_ctx *ctx, 170 struct zcomp_req *req) 171 { 172 struct zstd_params *zp = params->drv_data; 173 struct zstd_ctx *zctx = ctx->context; 174 size_t ret; 175 176 if (params->dict_sz == 0) 177 ret = zstd_compress_cctx(zctx->cctx, req->dst, req->dst_len, 178 req->src, req->src_len, &zp->cprm); 179 else 180 ret = zstd_compress_using_cdict(zctx->cctx, req->dst, 181 req->dst_len, req->src, 182 req->src_len, 183 zp->cdict); 184 if (zstd_is_error(ret)) 185 return -EINVAL; 186 req->dst_len = ret; 187 return 0; 188 } 189 190 static int zstd_decompress(struct zcomp_params *params, struct zcomp_ctx *ctx, 191 struct zcomp_req *req) 192 { 193 struct zstd_params *zp = params->drv_data; 194 struct zstd_ctx *zctx = ctx->context; 195 size_t ret; 196 197 if (params->dict_sz == 0) 198 ret = zstd_decompress_dctx(zctx->dctx, req->dst, req->dst_len, 199 req->src, req->src_len); 200 else 201 ret = zstd_decompress_using_ddict(zctx->dctx, req->dst, 202 req->dst_len, req->src, 203 req->src_len, zp->ddict); 204 if (zstd_is_error(ret)) 205 return -EINVAL; 206 return 0; 207 } 208 209 const struct zcomp_ops backend_zstd = { 210 .compress = zstd_compress, 211 .decompress = zstd_decompress, 212 .create_ctx = zstd_create, 213 .destroy_ctx = zstd_destroy, 214 .setup_params = zstd_setup_params, 215 .release_params = zstd_release_params, 216 .name = "zstd", 217 }; 218