// SPDX-License-Identifier: GPL-2.0-only
/*
 * Copyright (C) 2024, SUSE LLC
 *
 * Authors: Enzo Matsumiya <ematsumiya@suse.de>
 *
 * This file implements I/O compression support for SMB2 messages (SMB 3.1.1 only).
 * See compress/ for implementation details of each algorithm.
 *
 * References:
 * MS-SMB2 "3.1.4.4 Compressing the Message"
 * MS-SMB2 "3.1.5.3 Decompressing the Chained Message"
 * MS-XCA - for details of the supported algorithms
 */
#include <linux/slab.h>
#include <linux/kernel.h>
#include <linux/uio.h>
#include <linux/sort.h>

#include "cifsglob.h"
#include "../common/smb2pdu.h"
#include "cifsproto.h"
#include "smb2proto.h"

#include "compress/lz77.h"
#include "compress.h"

/*
 * The heuristic_*() functions below try to determine data compressibility.
 *
 * Derived from fs/btrfs/compression.c, changing coding style, some parameters, and removing
 * unused parts.
 *
 * Read that file for a better and more detailed explanation of the calculations.
 *
 * The heuristics are run on a sample collected from the input (uncompressed) data.
 * The sample is made up of 2K reads at PAGE_SIZE intervals, capped at a total of 4M.
 *
 * Parsing the sample goes from "low-hanging fruit" (fastest checks, likely compressible)
 * to "needs more analysis" (likely incompressible).
 */

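/* Number of occurrences of each byte value in the collected sample. */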
struct bucket {
	unsigned int count;
};

/**
 * has_low_entropy() - Compute Shannon entropy of the sampled data.
 * @bkt: Byte counts of the sample.
 * @slen: Size of the sample.
 *
 * Return: true if the entropy level (the sample's Shannon entropy as a percentage of the
 * maximum of 8 bits per byte) is at or below the 65% threshold.
 *
 * Note:
 * There _is_ an entropy level > 65 (the threshold) that would still indicate a possibility of
 * compression, but compressing, or even further analysing, such data would waste so many
 * resources that it's simply not worth it.
 *
 * Also, Shannon entropy is the last heuristic computed; if we got this far and still ended up
 * uncertain, just stay on the safe side and call the data incompressible.
 */
static bool has_low_entropy(struct bucket *bkt, size_t slen)
{
	const size_t threshold = 65, max_entropy = 8 * ilog2(16);
	size_t i, p, p2, len, sum = 0;

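	/*
	 * ilog2(pow4(x)) == ilog2(x^4) ~= 4 * log2(x), i.e. log2() with two extra bits of
	 * precision.  The factor of 4 cancels out against max_entropy = 8 * ilog2(16) = 32.
	 */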
#define pow4(n) (n * n * n * n)
	len = ilog2(pow4(slen));

	for (i = 0; i < 256 && bkt[i].count > 0; i++) {
		p = bkt[i].count;
		p2 = ilog2(pow4(p));
		sum += p * (len - p2);
	}

	sum /= slen;

	return ((sum * 100 / max_entropy) <= threshold);
}

#define BYTE_DIST_BAD 0
#define BYTE_DIST_GOOD 1
#define BYTE_DIST_MAYBE 2
/**
 * calc_byte_distribution() - Compute byte distribution on the sampled data.
 * @bkt: Byte counts of the sample.
 * @slen: Size of the sample.
 *
 * Return:
 * BYTE_DIST_BAD: A "hard no" for compression -- a computed uniform distribution of
 * the bytes (e.g. random or encrypted data).
 * BYTE_DIST_GOOD: High probability (normal (Gaussian) distribution) of the data being
 * compressible.
 * BYTE_DIST_MAYBE: The computed byte distribution fell between the low and high
 * thresholds; has_low_entropy() should be used for a final decision.
 */
static int calc_byte_distribution(struct bucket *bkt, size_t slen)
{
	const size_t low = 64, high = 200, threshold = slen * 90 / 100;
	size_t sum = 0;
	int i;

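	/*
	 * @bkt is sorted in descending order (see is_compressible()), so the first buckets
	 * hold the most frequent byte values of the sample.
	 */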
	for (i = 0; i < low; i++)
		sum += bkt[i].count;

	if (sum > threshold)
		return BYTE_DIST_GOOD;

	for (; i < high && bkt[i].count > 0; i++) {
		sum += bkt[i].count;
		if (sum > threshold)
			break;
	}

	if (i <= low)
		return BYTE_DIST_GOOD;

	if (i >= high)
		return BYTE_DIST_BAD;

	return BYTE_DIST_MAYBE;
}

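/* A sample using at most 64 distinct byte values is treated as "mostly ASCII" and compressible. */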
static bool is_mostly_ascii(const struct bucket *bkt)
{
	size_t count = 0;
	int i;

	for (i = 0; i < 256; i++)
		if (bkt[i].count > 0)
			/* Too many distinct byte values -- not "mostly ASCII". */
			if (++count > 64)
				return false;

	return true;
}

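/* Check whether the second half of the sample is an exact repeat of the first half. */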
static bool has_repeated_data(const u8 *sample, size_t len)
{
	size_t s = len / 2;

	return (!memcmp(&sample[0], &sample[s], s));
}

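/* sort() comparator: order buckets by descending count. */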
static int cmp_bkt(const void *_a, const void *_b)
{
	const struct bucket *a = _a, *b = _b;

	/* Reverse sort. */
	if (a->count > b->count)
		return -1;

	return 1;
}

/*
 * Collect some 2K samples with 2K gaps between.
 */
static int collect_sample(const struct iov_iter *source, ssize_t max, u8 *sample)
{
	struct iov_iter iter = *source;
	size_t s = 0;

	while (iov_iter_count(&iter) >= SZ_2K) {
		size_t part = umin(umin(iov_iter_count(&iter), SZ_2K), max);
		size_t n;

		n = copy_from_iter(sample + s, part, &iter);
		if (n != part)
			return -EFAULT;

		s += n;
		max -= n;

		if (iov_iter_count(&iter) < PAGE_SIZE - SZ_2K)
			break;

		iov_iter_advance(&iter, SZ_2K);
	}

	return s;
}

/**
 * is_compressible() - Determines if a chunk of data is compressible.
 * @data: Iterator containing uncompressed data.
 *
 * Return: true if @data is compressible, false otherwise.
 *
 * Tests show that this function is quite reliable in predicting data compressibility,
 * matching the success and failure of actual LZ77 compression almost 1:1.
 */
static bool is_compressible(const struct iov_iter *data)
{
	const size_t read_size = SZ_2K, bkt_size = 256, max = SZ_4M;
	struct bucket *bkt = NULL;
	size_t len;
	u8 *sample;
	bool ret = false;
	int i;

	/* Preventive double check -- already checked in should_compress(). */
	len = iov_iter_count(data);
	if (unlikely(len < read_size))
		return ret;

	if (len - read_size > max)
		len = max;

	sample = kvzalloc(len, GFP_KERNEL);
	if (!sample) {
		WARN_ON_ONCE(1);

		return ret;
	}

	/* Sample 2K bytes per page of the uncompressed data. */
	i = collect_sample(data, len, sample);
	if (i <= 0) {
		WARN_ON_ONCE(1);

		goto out;
	}

	len = i;
	ret = true;

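	/* Run the heuristics from cheapest to most expensive, stopping at the first conclusive one. */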
	if (has_repeated_data(sample, len))
		goto out;

	bkt = kcalloc(bkt_size, sizeof(*bkt), GFP_KERNEL);
	if (!bkt) {
		WARN_ON_ONCE(1);
		ret = false;

		goto out;
	}

	for (i = 0; i < len; i++)
		bkt[sample[i]].count++;

	if (is_mostly_ascii(bkt))
		goto out;

	/* Sort in descending order */
	sort(bkt, bkt_size, sizeof(*bkt), cmp_bkt, NULL);

	i = calc_byte_distribution(bkt, len);
	if (i != BYTE_DIST_MAYBE) {
		ret = !!i;

		goto out;
	}

	ret = has_low_entropy(bkt, len);
out:
	kvfree(sample);
	kfree(bkt);

	return ret;
}

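/**
 * should_compress() - Check if a request should be sent compressed.
 * @tcon: Tree connection the request will be sent on.
 * @rq: Request to check.
 *
 * Return: true if compression was negotiated with the server, the share allows it, and the
 * request is either a READ or a WRITE with a payload of at least SMB_COMPRESS_MIN_LEN bytes
 * that looks compressible.  false otherwise.
 */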
bool should_compress(const struct cifs_tcon *tcon, const struct smb_rqst *rq)
{
	const struct smb2_hdr *shdr = rq->rq_iov->iov_base;

	if (unlikely(!tcon || !tcon->ses || !tcon->ses->server))
		return false;

	if (!tcon->ses->server->compression.enabled)
		return false;

	if (!(tcon->share_flags & SMB2_SHAREFLAG_COMPRESS_DATA))
		return false;

	if (shdr->Command == SMB2_WRITE) {
		const struct smb2_write_req *wreq = rq->rq_iov->iov_base;

		if (le32_to_cpu(wreq->Length) < SMB_COMPRESS_MIN_LEN)
			return false;

		return is_compressible(&rq->rq_iter);
	}

	return (shdr->Command == SMB2_READ);
}

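/**
 * smb_compress() - LZ77-compress the payload of a WRITE request and send it.
 * @server: Server the request will be sent to.
 * @rq: Uncompressed request; ->rq_iov holds the WRITE header, ->rq_iter the payload.
 * @send_fn: Function used to transmit the (possibly compressed) request.
 *
 * The payload is copied into a flat buffer and compressed.  On success, a compression transform
 * header is prepended and the compressed request is sent; if the data doesn't compress well
 * enough, the original request is sent unmodified.
 *
 * Return: the value returned by @send_fn, or a negative error code on failure.
 */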
int smb_compress(struct TCP_Server_Info *server, struct smb_rqst *rq, compress_send_fn send_fn)
{
	struct iov_iter iter;
	u32 slen, dlen;
	void *src, *dst = NULL;
	int ret;

	if (!server || !rq || !rq->rq_iov || !rq->rq_iov->iov_base)
		return -EINVAL;

	if (rq->rq_iov->iov_len != sizeof(struct smb2_write_req))
		return -EINVAL;

	slen = iov_iter_count(&rq->rq_iter);
	src = kvzalloc(slen, GFP_KERNEL);
	if (!src) {
		ret = -ENOMEM;
		goto err_free;
	}

	/* Keep the original iter intact. */
	iter = rq->rq_iter;

	if (!copy_from_iter_full(src, slen, &iter)) {
		ret = -EIO;
		goto err_free;
	}

	/*
	 * This is just overprovisioning, as the algorithm will error out if @dst reaches 7/8
	 * of @slen.
	 */
	dlen = slen;
	dst = kvzalloc(dlen, GFP_KERNEL);
	if (!dst) {
		ret = -ENOMEM;
		goto err_free;
	}

	ret = lz77_compress(src, slen, dst, &dlen);
	if (!ret) {
		struct smb2_compression_hdr hdr = { 0 };
		struct smb_rqst comp_rq = { .rq_nvec = 3, };
		struct kvec iov[3];

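		/*
		 * Build the compression transform header: @Offset is the length of the
		 * uncompressed data that follows it (the original WRITE header), with the
		 * compressed payload placed right after that.
		 */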
		hdr.ProtocolId = SMB2_COMPRESSION_TRANSFORM_ID;
		hdr.OriginalCompressedSegmentSize = cpu_to_le32(slen);
		hdr.CompressionAlgorithm = SMB3_COMPRESS_LZ77;
		hdr.Flags = SMB2_COMPRESSION_FLAG_NONE;
		hdr.Offset = cpu_to_le32(rq->rq_iov[0].iov_len);

		iov[0].iov_base = &hdr;
		iov[0].iov_len = sizeof(hdr);
		iov[1] = rq->rq_iov[0];
		iov[2].iov_base = dst;
		iov[2].iov_len = dlen;

		comp_rq.rq_iov = iov;

		ret = send_fn(server, 1, &comp_rq);
	} else if (ret == -EMSGSIZE || dlen >= slen) {
		ret = send_fn(server, 1, rq);
	}
err_free:
	kvfree(dst);
	kvfree(src);

	return ret;
}