1 /* SPDX-License-Identifier: GPL-2.0+ OR BSD-3-Clause */
2 /* ******************************************************************
3  * bitstream
4  * Part of FSE library
5  * Copyright (c) Meta Platforms, Inc. and affiliates.
6  *
7  * You can contact the author at :
8  * - Source repository : https://github.com/Cyan4973/FiniteStateEntropy
9  *
10  * This source code is licensed under both the BSD-style license (found in the
11  * LICENSE file in the root directory of this source tree) and the GPLv2 (found
12  * in the COPYING file in the root directory of this source tree).
13  * You may select, at your option, one of the above-listed licenses.
14 ****************************************************************** */
15 #ifndef BITSTREAM_H_MODULE
16 #define BITSTREAM_H_MODULE
17 
18 /*
19 *  This API consists of small unitary functions, which must be inlined for best performance.
20 *  Since link-time-optimization is not available for all compilers,
21 *  these functions are defined into a .h to be included.
22 */
23 
24 /*-****************************************
25 *  Dependencies
26 ******************************************/
27 #include "mem.h"            /* unaligned access routines */
28 #include "compiler.h"       /* UNLIKELY() */
29 #include "debug.h"          /* assert(), DEBUGLOG(), RAWLOG() */
30 #include "error_private.h"  /* error codes and messages */
31 #include "bits.h"           /* ZSTD_highbit32 */
32 
33 /*=========================================
34 *  Target specific
35 =========================================*/
36 
/* Minimum number of bits expected in the local register after a full reload.
 * The values equal the register width minus 7 (32 - 7 = 25, 64 - 7 = 57), which matches
 * the guarantee documented for BIT_reloadDStream() when it returns BIT_DStream_unfinished. */
#define STREAM_ACCUMULATOR_MIN_32  25
#define STREAM_ACCUMULATOR_MIN_64  57
/* Picks the 32-bit or 64-bit value depending on MEM_32bits(), as a U32. */
#define STREAM_ACCUMULATOR_MIN    ((U32)(MEM_32bits() ? STREAM_ACCUMULATOR_MIN_32 : STREAM_ACCUMULATOR_MIN_64))
40 
41 
42 /*-******************************************
43 *  bitStream encoding API (write forward)
44 ********************************************/
/* Type of the local bit register used by both the encoder and the decoder.
 * It is size_t, hence 64-bits wide on 64-bits systems and 32-bits wide on 32-bits systems. */
typedef size_t BitContainerType;
46 /* bitStream can mix input from multiple sources.
47  * A critical property of these streams is that they encode and decode in **reverse** direction.
48  * So the first bit sequence you add will be the last to be read, like a LIFO stack.
49  */
/* Encoder state : a local bit register plus a write cursor into the destination buffer. */
typedef struct {
    BitContainerType bitContainer;  /* local register : bits added but not yet flushed to memory */
    unsigned bitPos;                /* nb of bits currently stored in bitContainer */
    char*  startPtr;                /* beginning of destination buffer */
    char*  ptr;                     /* current write position; advanced by the flush functions */
    char*  endPtr;                  /* startPtr + dstCapacity - sizeof(bitContainer) : BIT_flushBits() clamps ptr to it */
} BIT_CStream_t;
57 
/* Encoder entry points (defined further down in this header). */
/* Binds bitC to dstBuffer; @return 0, or an error code if dstCapacity is too small. */
MEM_STATIC size_t BIT_initCStream(BIT_CStream_t* bitC, void* dstBuffer, size_t dstCapacity);
/* Appends the nbBits lowest bits of value to the local register (no memory write). */
MEM_STATIC void   BIT_addBits(BIT_CStream_t* bitC, BitContainerType value, unsigned nbBits);
/* Writes the local register to memory; never advances past the end of the buffer. */
MEM_STATIC void   BIT_flushBits(BIT_CStream_t* bitC);
/* Adds the end mark and flushes; @return size of the stream in bytes, or 0 if it did not fit. */
MEM_STATIC size_t BIT_closeCStream(BIT_CStream_t* bitC);
62 
63 /* Start with initCStream, providing the size of buffer to write into.
64 *  bitStream will never write outside of this buffer.
*  `dstCapacity` must be > sizeof(bitC->bitContainer), otherwise @return will be an error code.
66 *
67 *  bits are first added to a local register.
68 *  Local register is BitContainerType, 64-bits on 64-bits systems, or 32-bits on 32-bits systems.
69 *  Writing data into memory is an explicit operation, performed by the flushBits function.
70 *  Hence keep track how many bits are potentially stored into local register to avoid register overflow.
71 *  After a flushBits, a maximum of 7 bits might still be stored into local register.
72 *
73 *  Avoid storing elements of more than 24 bits if you want compatibility with 32-bits bitstream readers.
74 *
75 *  Last operation is to close the bitStream.
76 *  The function returns the final size of CStream in bytes.
77 *  If data couldn't fit into `dstBuffer`, it will return a 0 ( == not storable)
78 */
79 
80 
81 /*-********************************************
82 *  bitStream decoding API (read backward)
83 **********************************************/
/* Decoder state : a local bit register plus a read cursor moving backward through the source buffer. */
typedef struct {
    BitContainerType bitContainer;  /* local register : chunk of the bitstream most recently loaded from ptr */
    unsigned bitsConsumed;          /* nb of bits of bitContainer already consumed, counted from its most significant bit */
    const char* ptr;                /* address bitContainer was loaded from; decreases toward start on reload */
    const char* start;              /* beginning of source buffer */
    const char* limitPtr;           /* start + sizeof(bitContainer) : while ptr >= limitPtr, a reload needs no bound adjustment */
} BIT_DStream_t;
91 
/* Outcome of a reload attempt, as returned by BIT_reloadDStream() and BIT_reloadDStreamFast(). */
typedef enum { BIT_DStream_unfinished = 0,  /* fully refilled */
               BIT_DStream_endOfBuffer = 1, /* still some bits left in bitstream */
               BIT_DStream_completed = 2,   /* bitstream entirely consumed, bit-exact */
               BIT_DStream_overflow = 3     /* user requested more bits than present in bitstream */
    } BIT_DStream_status;  /* result of BIT_reloadDStream() */
97 
/* Decoder entry points (defined further down in this header). */
/* Binds bitD to srcBuffer and loads its last bytes; @return srcSize, or an error code. */
MEM_STATIC size_t   BIT_initDStream(BIT_DStream_t* bitD, const void* srcBuffer, size_t srcSize);
/* Extracts and consumes the next nbBits bits from the local register. */
MEM_STATIC BitContainerType BIT_readBits(BIT_DStream_t* bitD, unsigned nbBits);
/* Refills the local register from memory; @return the resulting stream status. */
MEM_STATIC BIT_DStream_status BIT_reloadDStream(BIT_DStream_t* bitD);
/* @return 1 when every bit of the stream has been consumed, exactly. */
MEM_STATIC unsigned BIT_endOfDStream(const BIT_DStream_t* bitD);
102 
103 
104 /* Start by invoking BIT_initDStream().
105 *  A chunk of the bitStream is then stored into a local register.
106 *  Local register size is 64-bits on 64-bits systems, 32-bits on 32-bits systems (BitContainerType).
107 *  You can then retrieve bitFields stored into the local register, **in reverse order**.
108 *  Local register is explicitly reloaded from memory by the BIT_reloadDStream() method.
*  A reload guarantees a minimum of ((8*sizeof(bitD->bitContainer))-7) bits when its result is BIT_DStream_unfinished.
110 *  Otherwise, it can be less than that, so proceed accordingly.
111 *  Checking if DStream has reached its end can be performed with BIT_endOfDStream().
112 */
113 
114 
115 /*-****************************************
116 *  unsafe API
117 ******************************************/
118 MEM_STATIC void BIT_addBitsFast(BIT_CStream_t* bitC, BitContainerType value, unsigned nbBits);
119 /* faster, but works only if value is "clean", meaning all high bits above nbBits are 0 */
120 
121 MEM_STATIC void BIT_flushBitsFast(BIT_CStream_t* bitC);
122 /* unsafe version; does not check buffer overflow */
123 
124 MEM_STATIC size_t BIT_readBitsFast(BIT_DStream_t* bitD, unsigned nbBits);
125 /* faster, but works only if nbBits >= 1 */
126 
127 /*=====    Local Constants   =====*/
/* BIT_mask[n] has its n lowest bits set, i.e. (1u << n) - 1, for n in [0, 31]. */
static const unsigned BIT_mask[] = {
    0x00000000, 0x00000001, 0x00000003, 0x00000007,
    0x0000000F, 0x0000001F, 0x0000003F, 0x0000007F,
    0x000000FF, 0x000001FF, 0x000003FF, 0x000007FF,
    0x00000FFF, 0x00001FFF, 0x00003FFF, 0x00007FFF,
    0x0000FFFF, 0x0001FFFF, 0x0003FFFF, 0x0007FFFF,
    0x000FFFFF, 0x001FFFFF, 0x003FFFFF, 0x007FFFFF,
    0x00FFFFFF, 0x01FFFFFF, 0x03FFFFFF, 0x07FFFFFF,
    0x0FFFFFFF, 0x1FFFFFFF, 0x3FFFFFFF, 0x7FFFFFFF
}; /* up to 31 bits */
/* Number of entries in BIT_mask. */
#define BIT_MASK_SIZE (sizeof(BIT_mask) / sizeof(BIT_mask[0]))
136 
137 /*-**************************************************************
138 *  bitStream encoding
139 ****************************************************************/
140 /*! BIT_initCStream() :
141  *  `dstCapacity` must be > sizeof(size_t)
142  *  @return : 0 if success,
143  *            otherwise an error code (can be tested using ERR_isError()) */
144 MEM_STATIC size_t BIT_initCStream(BIT_CStream_t* bitC,
145                                   void* startPtr, size_t dstCapacity)
146 {
147     bitC->bitContainer = 0;
148     bitC->bitPos = 0;
149     bitC->startPtr = (char*)startPtr;
150     bitC->ptr = bitC->startPtr;
151     bitC->endPtr = bitC->startPtr + dstCapacity - sizeof(bitC->bitContainer);
152     if (dstCapacity <= sizeof(bitC->bitContainer)) return ERROR(dstSize_tooSmall);
153     return 0;
154 }
155 
156 FORCE_INLINE_TEMPLATE BitContainerType BIT_getLowerBits(BitContainerType bitContainer, U32 const nbBits)
157 {
158     assert(nbBits < BIT_MASK_SIZE);
159     return bitContainer & BIT_mask[nbBits];
160 }
161 
162 /*! BIT_addBits() :
163  *  can add up to 31 bits into `bitC`.
164  *  Note : does not check for register overflow ! */
165 MEM_STATIC void BIT_addBits(BIT_CStream_t* bitC,
166                             BitContainerType value, unsigned nbBits)
167 {
168     DEBUG_STATIC_ASSERT(BIT_MASK_SIZE == 32);
169     assert(nbBits < BIT_MASK_SIZE);
170     assert(nbBits + bitC->bitPos < sizeof(bitC->bitContainer) * 8);
171     bitC->bitContainer |= BIT_getLowerBits(value, nbBits) << bitC->bitPos;
172     bitC->bitPos += nbBits;
173 }
174 
175 /*! BIT_addBitsFast() :
176  *  works only if `value` is _clean_,
177  *  meaning all high bits above nbBits are 0 */
178 MEM_STATIC void BIT_addBitsFast(BIT_CStream_t* bitC,
179                                 BitContainerType value, unsigned nbBits)
180 {
181     assert((value>>nbBits) == 0);
182     assert(nbBits + bitC->bitPos < sizeof(bitC->bitContainer) * 8);
183     bitC->bitContainer |= value << bitC->bitPos;
184     bitC->bitPos += nbBits;
185 }
186 
187 /*! BIT_flushBitsFast() :
188  *  assumption : bitContainer has not overflowed
189  *  unsafe version; does not check buffer overflow */
190 MEM_STATIC void BIT_flushBitsFast(BIT_CStream_t* bitC)
191 {
192     size_t const nbBytes = bitC->bitPos >> 3;
193     assert(bitC->bitPos < sizeof(bitC->bitContainer) * 8);
194     assert(bitC->ptr <= bitC->endPtr);
195     MEM_writeLEST(bitC->ptr, bitC->bitContainer);
196     bitC->ptr += nbBytes;
197     bitC->bitPos &= 7;
198     bitC->bitContainer >>= nbBytes*8;
199 }
200 
201 /*! BIT_flushBits() :
202  *  assumption : bitContainer has not overflowed
203  *  safe version; check for buffer overflow, and prevents it.
204  *  note : does not signal buffer overflow.
205  *  overflow will be revealed later on using BIT_closeCStream() */
206 MEM_STATIC void BIT_flushBits(BIT_CStream_t* bitC)
207 {
208     size_t const nbBytes = bitC->bitPos >> 3;
209     assert(bitC->bitPos < sizeof(bitC->bitContainer) * 8);
210     assert(bitC->ptr <= bitC->endPtr);
211     MEM_writeLEST(bitC->ptr, bitC->bitContainer);
212     bitC->ptr += nbBytes;
213     if (bitC->ptr > bitC->endPtr) bitC->ptr = bitC->endPtr;
214     bitC->bitPos &= 7;
215     bitC->bitContainer >>= nbBytes*8;
216 }
217 
218 /*! BIT_closeCStream() :
219  *  @return : size of CStream, in bytes,
220  *            or 0 if it could not fit into dstBuffer */
221 MEM_STATIC size_t BIT_closeCStream(BIT_CStream_t* bitC)
222 {
223     BIT_addBitsFast(bitC, 1, 1);   /* endMark */
224     BIT_flushBits(bitC);
225     if (bitC->ptr >= bitC->endPtr) return 0; /* overflow detected */
226     return (size_t)(bitC->ptr - bitC->startPtr) + (bitC->bitPos > 0);
227 }
228 
229 
230 /*-********************************************************
231 *  bitStream decoding
232 **********************************************************/
233 /*! BIT_initDStream() :
234  *  Initialize a BIT_DStream_t.
235  * `bitD` : a pointer to an already allocated BIT_DStream_t structure.
236  * `srcSize` must be the *exact* size of the bitStream, in bytes.
237  * @return : size of stream (== srcSize), or an errorCode if a problem is detected
238  */
239 MEM_STATIC size_t BIT_initDStream(BIT_DStream_t* bitD, const void* srcBuffer, size_t srcSize)
240 {
241     if (srcSize < 1) { ZSTD_memset(bitD, 0, sizeof(*bitD)); return ERROR(srcSize_wrong); }
242 
243     bitD->start = (const char*)srcBuffer;
244     bitD->limitPtr = bitD->start + sizeof(bitD->bitContainer);
245 
246     if (srcSize >=  sizeof(bitD->bitContainer)) {  /* normal case */
247         bitD->ptr   = (const char*)srcBuffer + srcSize - sizeof(bitD->bitContainer);
248         bitD->bitContainer = MEM_readLEST(bitD->ptr);
249         { BYTE const lastByte = ((const BYTE*)srcBuffer)[srcSize-1];
250           bitD->bitsConsumed = lastByte ? 8 - ZSTD_highbit32(lastByte) : 0;  /* ensures bitsConsumed is always set */
251           if (lastByte == 0) return ERROR(GENERIC); /* endMark not present */ }
252     } else {
253         bitD->ptr   = bitD->start;
254         bitD->bitContainer = *(const BYTE*)(bitD->start);
255         switch(srcSize)
256         {
257         case 7: bitD->bitContainer += (BitContainerType)(((const BYTE*)(srcBuffer))[6]) << (sizeof(bitD->bitContainer)*8 - 16);
258                 ZSTD_FALLTHROUGH;
259 
260         case 6: bitD->bitContainer += (BitContainerType)(((const BYTE*)(srcBuffer))[5]) << (sizeof(bitD->bitContainer)*8 - 24);
261                 ZSTD_FALLTHROUGH;
262 
263         case 5: bitD->bitContainer += (BitContainerType)(((const BYTE*)(srcBuffer))[4]) << (sizeof(bitD->bitContainer)*8 - 32);
264                 ZSTD_FALLTHROUGH;
265 
266         case 4: bitD->bitContainer += (BitContainerType)(((const BYTE*)(srcBuffer))[3]) << 24;
267                 ZSTD_FALLTHROUGH;
268 
269         case 3: bitD->bitContainer += (BitContainerType)(((const BYTE*)(srcBuffer))[2]) << 16;
270                 ZSTD_FALLTHROUGH;
271 
272         case 2: bitD->bitContainer += (BitContainerType)(((const BYTE*)(srcBuffer))[1]) <<  8;
273                 ZSTD_FALLTHROUGH;
274 
275         default: break;
276         }
277         {   BYTE const lastByte = ((const BYTE*)srcBuffer)[srcSize-1];
278             bitD->bitsConsumed = lastByte ? 8 - ZSTD_highbit32(lastByte) : 0;
279             if (lastByte == 0) return ERROR(corruption_detected);  /* endMark not present */
280         }
281         bitD->bitsConsumed += (U32)(sizeof(bitD->bitContainer) - srcSize)*8;
282     }
283 
284     return srcSize;
285 }
286 
287 FORCE_INLINE_TEMPLATE BitContainerType BIT_getUpperBits(BitContainerType bitContainer, U32 const start)
288 {
289     return bitContainer >> start;
290 }
291 
292 FORCE_INLINE_TEMPLATE BitContainerType BIT_getMiddleBits(BitContainerType bitContainer, U32 const start, U32 const nbBits)
293 {
294     U32 const regMask = sizeof(bitContainer)*8 - 1;
295     /* if start > regMask, bitstream is corrupted, and result is undefined */
296     assert(nbBits < BIT_MASK_SIZE);
297     /* x86 transform & ((1 << nbBits) - 1) to bzhi instruction, it is better
298      * than accessing memory. When bmi2 instruction is not present, we consider
299      * such cpus old (pre-Haswell, 2013) and their performance is not of that
300      * importance.
301      */
302 #if defined(__x86_64__) || defined(_M_X64)
303     return (bitContainer >> (start & regMask)) & ((((U64)1) << nbBits) - 1);
304 #else
305     return (bitContainer >> (start & regMask)) & BIT_mask[nbBits];
306 #endif
307 }
308 
309 /*! BIT_lookBits() :
310  *  Provides next n bits from local register.
311  *  local register is not modified.
312  *  On 32-bits, maxNbBits==24.
313  *  On 64-bits, maxNbBits==56.
314  * @return : value extracted */
315 FORCE_INLINE_TEMPLATE BitContainerType BIT_lookBits(const BIT_DStream_t*  bitD, U32 nbBits)
316 {
317     /* arbitrate between double-shift and shift+mask */
318 #if 1
319     /* if bitD->bitsConsumed + nbBits > sizeof(bitD->bitContainer)*8,
320      * bitstream is likely corrupted, and result is undefined */
321     return BIT_getMiddleBits(bitD->bitContainer, (sizeof(bitD->bitContainer)*8) - bitD->bitsConsumed - nbBits, nbBits);
322 #else
323     /* this code path is slower on my os-x laptop */
324     U32 const regMask = sizeof(bitD->bitContainer)*8 - 1;
325     return ((bitD->bitContainer << (bitD->bitsConsumed & regMask)) >> 1) >> ((regMask-nbBits) & regMask);
326 #endif
327 }
328 
329 /*! BIT_lookBitsFast() :
330  *  unsafe version; only works if nbBits >= 1 */
331 MEM_STATIC BitContainerType BIT_lookBitsFast(const BIT_DStream_t* bitD, U32 nbBits)
332 {
333     U32 const regMask = sizeof(bitD->bitContainer)*8 - 1;
334     assert(nbBits >= 1);
335     return (bitD->bitContainer << (bitD->bitsConsumed & regMask)) >> (((regMask+1)-nbBits) & regMask);
336 }
337 
338 FORCE_INLINE_TEMPLATE void BIT_skipBits(BIT_DStream_t* bitD, U32 nbBits)
339 {
340     bitD->bitsConsumed += nbBits;
341 }
342 
343 /*! BIT_readBits() :
344  *  Read (consume) next n bits from local register and update.
345  *  Pay attention to not read more than nbBits contained into local register.
346  * @return : extracted value. */
347 FORCE_INLINE_TEMPLATE BitContainerType BIT_readBits(BIT_DStream_t* bitD, unsigned nbBits)
348 {
349     BitContainerType const value = BIT_lookBits(bitD, nbBits);
350     BIT_skipBits(bitD, nbBits);
351     return value;
352 }
353 
354 /*! BIT_readBitsFast() :
355  *  unsafe version; only works if nbBits >= 1 */
356 MEM_STATIC BitContainerType BIT_readBitsFast(BIT_DStream_t* bitD, unsigned nbBits)
357 {
358     BitContainerType const value = BIT_lookBitsFast(bitD, nbBits);
359     assert(nbBits >= 1);
360     BIT_skipBits(bitD, nbBits);
361     return value;
362 }
363 
364 /*! BIT_reloadDStream_internal() :
365  *  Simple variant of BIT_reloadDStream(), with two conditions:
366  *  1. bitstream is valid : bitsConsumed <= sizeof(bitD->bitContainer)*8
367  *  2. look window is valid after shifted down : bitD->ptr >= bitD->start
368  */
369 MEM_STATIC BIT_DStream_status BIT_reloadDStream_internal(BIT_DStream_t* bitD)
370 {
371     assert(bitD->bitsConsumed <= sizeof(bitD->bitContainer)*8);
372     bitD->ptr -= bitD->bitsConsumed >> 3;
373     assert(bitD->ptr >= bitD->start);
374     bitD->bitsConsumed &= 7;
375     bitD->bitContainer = MEM_readLEST(bitD->ptr);
376     return BIT_DStream_unfinished;
377 }
378 
379 /*! BIT_reloadDStreamFast() :
380  *  Similar to BIT_reloadDStream(), but with two differences:
381  *  1. bitsConsumed <= sizeof(bitD->bitContainer)*8 must hold!
382  *  2. Returns BIT_DStream_overflow when bitD->ptr < bitD->limitPtr, at this
383  *     point you must use BIT_reloadDStream() to reload.
384  */
385 MEM_STATIC BIT_DStream_status BIT_reloadDStreamFast(BIT_DStream_t* bitD)
386 {
387     if (UNLIKELY(bitD->ptr < bitD->limitPtr))
388         return BIT_DStream_overflow;
389     return BIT_reloadDStream_internal(bitD);
390 }
391 
392 /*! BIT_reloadDStream() :
393  *  Refill `bitD` from buffer previously set in BIT_initDStream() .
394  *  This function is safe, it guarantees it will not never beyond src buffer.
395  * @return : status of `BIT_DStream_t` internal register.
396  *           when status == BIT_DStream_unfinished, internal register is filled with at least 25 or 57 bits */
397 FORCE_INLINE_TEMPLATE BIT_DStream_status BIT_reloadDStream(BIT_DStream_t* bitD)
398 {
399     /* note : once in overflow mode, a bitstream remains in this mode until it's reset */
400     if (UNLIKELY(bitD->bitsConsumed > (sizeof(bitD->bitContainer)*8))) {
401         static const BitContainerType zeroFilled = 0;
402         bitD->ptr = (const char*)&zeroFilled; /* aliasing is allowed for char */
403         /* overflow detected, erroneous scenario or end of stream: no update */
404         return BIT_DStream_overflow;
405     }
406 
407     assert(bitD->ptr >= bitD->start);
408 
409     if (bitD->ptr >= bitD->limitPtr) {
410         return BIT_reloadDStream_internal(bitD);
411     }
412     if (bitD->ptr == bitD->start) {
413         /* reached end of bitStream => no update */
414         if (bitD->bitsConsumed < sizeof(bitD->bitContainer)*8) return BIT_DStream_endOfBuffer;
415         return BIT_DStream_completed;
416     }
417     /* start < ptr < limitPtr => cautious update */
418     {   U32 nbBytes = bitD->bitsConsumed >> 3;
419         BIT_DStream_status result = BIT_DStream_unfinished;
420         if (bitD->ptr - nbBytes < bitD->start) {
421             nbBytes = (U32)(bitD->ptr - bitD->start);  /* ptr > start */
422             result = BIT_DStream_endOfBuffer;
423         }
424         bitD->ptr -= nbBytes;
425         bitD->bitsConsumed -= nbBytes*8;
426         bitD->bitContainer = MEM_readLEST(bitD->ptr);   /* reminder : srcSize > sizeof(bitD->bitContainer), otherwise bitD->ptr == bitD->start */
427         return result;
428     }
429 }
430 
431 /*! BIT_endOfDStream() :
432  * @return : 1 if DStream has _exactly_ reached its end (all bits consumed).
433  */
434 MEM_STATIC unsigned BIT_endOfDStream(const BIT_DStream_t* DStream)
435 {
436     return ((DStream->ptr == DStream->start) && (DStream->bitsConsumed == sizeof(DStream->bitContainer)*8));
437 }
438 
439 #endif /* BITSTREAM_H_MODULE */
440