xref: /src/sys/contrib/zlib/infback.c (revision 7aa1dba6b00ccfb7d66627badc8a7aaa06b02946)
1 /* infback.c -- inflate using a call-back interface
2  * Copyright (C) 1995-2026 Mark Adler
3  * For conditions of distribution and use, see copyright notice in zlib.h
4  */
5 
6 /*
7    This code is largely copied from inflate.c.  Normally either infback.o or
8    inflate.o would be linked into an application--not both.  The interface
9    with inffast.c is retained so that optimized assembler-coded versions of
10    inflate_fast() can be used with either inflate.c or infback.c.
11  */
12 
13 #include "zutil.h"
14 #include "inftrees.h"
15 #include "inflate.h"
16 #include "inffast.h"
17 
18 /*
19    strm provides memory allocation functions in zalloc and zfree, or
20    Z_NULL to use the library memory allocation functions.
21 
22    windowBits is in the range 8..15, and window is a user-supplied
23    window and output buffer that is 2**windowBits bytes.
24  */
inflateBackInit_(z_streamp strm,int windowBits,unsigned char FAR * window,const char * version,int stream_size)25 int ZEXPORT inflateBackInit_(z_streamp strm, int windowBits,
26                              unsigned char FAR *window, const char *version,
27                              int stream_size) {
28     struct inflate_state FAR *state;
29 
30     if (version == Z_NULL || version[0] != ZLIB_VERSION[0] ||
31         stream_size != (int)(sizeof(z_stream)))
32         return Z_VERSION_ERROR;
33     if (strm == Z_NULL || window == Z_NULL ||
34         windowBits < 8 || windowBits > 15)
35         return Z_STREAM_ERROR;
36     strm->msg = Z_NULL;                 /* in case we return an error */
37     if (strm->zalloc == (alloc_func)0) {
38 #if defined(Z_SOLO) && !defined(_KERNEL)
39         return Z_STREAM_ERROR;
40 #else
41         strm->zalloc = zcalloc;
42         strm->opaque = (voidpf)0;
43 #endif
44     }
45     if (strm->zfree == (free_func)0)
46 #if defined(Z_SOLO) && !defined(_KERNEL)
47         return Z_STREAM_ERROR;
48 #else
49         strm->zfree = zcfree;
50 #endif
51     state = (struct inflate_state FAR *)ZALLOC(strm, 1,
52                                                sizeof(struct inflate_state));
53     if (state == Z_NULL) return Z_MEM_ERROR;
54     Tracev((stderr, "inflate: allocated\n"));
55     strm->state = (struct internal_state FAR *)state;
56     state->dmax = 32768U;
57     state->wbits = (uInt)windowBits;
58     state->wsize = 1U << windowBits;
59     state->window = window;
60     state->wnext = 0;
61     state->whave = 0;
62     state->sane = 1;
63     return Z_OK;
64 }
65 
/* Macros for inflateBack(): */

/*
   Every macro below refers directly to local variables of inflateBack():
   strm, state, next, have, put, left, hold and bits.  PULL() and ROOM()
   also use ret, in, in_desc, out and out_desc, and jump to the inf_leave
   label on failure.  These macros are therefore only usable inside
   inflateBack().
 */

/* Load returned state from inflate_fast(): copy the stream position and the
   bit buffer back into inflateBack()'s local variables. */
#define LOAD() \
    do { \
        put = strm->next_out; \
        left = strm->avail_out; \
        next = strm->next_in; \
        have = strm->avail_in; \
        hold = state->hold; \
        bits = state->bits; \
    } while (0)

/* Set state from registers for inflate_fast(): the inverse of LOAD(), storing
   the local input/output position and bit buffer into strm and state. */
#define RESTORE() \
    do { \
        strm->next_out = put; \
        strm->avail_out = left; \
        strm->next_in = next; \
        strm->avail_in = have; \
        state->hold = hold; \
        state->bits = bits; \
    } while (0)

/* Clear the input bit accumulator */
#define INITBITS() \
    do { \
        hold = 0; \
        bits = 0; \
    } while (0)

/* Assure that some input is available.  If input is requested, but denied,
   then return a Z_BUF_ERROR from inflateBack().  next is set to Z_NULL on
   that failure so the caller can tell an in() failure from an out() one.
   in() is only called once the current input (have bytes) is used up. */
#define PULL() \
    do { \
        if (have == 0) { \
            have = in(in_desc, &next); \
            if (have == 0) { \
                next = Z_NULL; \
                ret = Z_BUF_ERROR; \
                goto inf_leave; \
            } \
        } \
    } while (0)

/* Get a byte of input into the bit accumulator, or return from inflateBack()
   with an error if there is no input available.  The new byte is placed
   above the bits already held, so bits are consumed least-significant
   first. */
#define PULLBYTE() \
    do { \
        PULL(); \
        have--; \
        hold += (unsigned long)(*next++) << bits; \
        bits += 8; \
    } while (0)

/* Assure that there are at least n bits in the bit accumulator.  If there is
   not enough available input to do that, then return from inflateBack() with
   an error.  Input is pulled one byte at a time, so no more bytes are read
   than needed to reach n bits. */
#define NEEDBITS(n) \
    do { \
        while (bits < (unsigned)(n)) \
            PULLBYTE(); \
    } while (0)

/* Return the low n bits of the bit accumulator (n < 16) without removing
   them; the caller is responsible for having enough bits (NEEDBITS). */
#define BITS(n) \
    ((unsigned)hold & ((1U << (n)) - 1))

/* Remove n bits from the bit accumulator */
#define DROPBITS(n) \
    do { \
        hold >>= (n); \
        bits -= (unsigned)(n); \
    } while (0)

/* Remove zero to seven bits as needed to go to a byte boundary */
#define BYTEBITS() \
    do { \
        hold >>= bits & 7; \
        bits -= bits & 7; \
    } while (0)

/* Assure that some output space is available, by writing out the window
   if it's full.  If the write fails, return from inflateBack() with a
   Z_BUF_ERROR.  When the window is full (left == 0), put and left are reset
   to cover the whole window and whave is set to the full window size before
   out() is handed the entire window.  The window is not cleared, so its
   bytes remain in place for later back-references. */
#define ROOM() \
    do { \
        if (left == 0) { \
            put = state->window; \
            left = state->wsize; \
            state->whave = left; \
            if (out(out_desc, put, left)) { \
                ret = Z_BUF_ERROR; \
                goto inf_leave; \
            } \
        } \
    } while (0)
163 
/*
   Decompress one raw deflate stream using call-back functions for input
   and output.  Decoding starts directly at the first block header; no zlib
   or gzip header or trailer is processed here.

   strm provides the memory allocation functions and window buffer on input,
   and provides information on the unused input on return.  For Z_DATA_ERROR
   returns, strm will also provide an error message.

   in() and out() are the call-back input and output functions.  When
   inflateBack() needs more input, it calls in().  When inflateBack() has
   filled the window with output, or when it completes with data in the
   window, it calls out() to write out the data.  The application must not
   change the provided input until in() is called again or inflateBack()
   returns.  The application must not change the window/output buffer until
   inflateBack() returns.

   in() and out() are called with a descriptor parameter provided in the
   inflateBack() call.  This parameter can be a structure that provides the
   information required to do the read or write, as well as accumulated
   information on the input and output such as totals and check values.

   in() stores a pointer to its input through its second argument and
   returns the number of bytes available there; it should return zero on
   end of input or failure.  out() should return non-zero on failure.  If
   either in() or out() fails, then inflateBack() returns Z_BUF_ERROR.
   strm->next_in can be checked for Z_NULL to see whether it was in() or
   out() that caused the error.  Otherwise, inflateBack() returns
   Z_STREAM_END on success or Z_DATA_ERROR for a deflate format error.
   inflateBack() can also return Z_STREAM_ERROR if the input parameters are
   not correct, i.e. strm is Z_NULL or the state was not initialized.  The
   state is allocated by inflateBackInit_(), so inflateBack() itself never
   returns Z_MEM_ERROR.

   Any input already supplied in strm->next_in/avail_in is consumed before
   in() is first called.  On return, strm->next_in and strm->avail_in
   describe the input that was provided but not consumed.
 */
int ZEXPORT inflateBack(z_streamp strm, in_func in, void FAR *in_desc,
                        out_func out, void FAR *out_desc) {
    struct inflate_state FAR *state;
    z_const unsigned char FAR *next;    /* next input */
    unsigned char FAR *put;     /* next output */
    unsigned have, left;        /* available input and output */
    unsigned long hold;         /* bit buffer */
    unsigned bits;              /* bits in bit buffer */
    unsigned copy;              /* number of stored or match bytes to copy */
    unsigned char FAR *from;    /* where to copy match bytes from */
    code here;                  /* current decoding table entry */
    code last;                  /* parent table entry */
    unsigned len;               /* length to copy for repeats, bits to drop */
    int ret;                    /* return code */
    static const unsigned short order[19] = /* permutation of code lengths */
        {16, 17, 18, 0, 8, 7, 9, 6, 10, 5, 11, 4, 12, 3, 13, 2, 14, 1, 15};

    /* Check that the strm exists and that the state was initialized */
    if (strm == Z_NULL || strm->state == Z_NULL)
        return Z_STREAM_ERROR;
    state = (struct inflate_state FAR *)strm->state;

    /* Reset the state: every call decodes a new stream into an empty window */
    strm->msg = Z_NULL;
    state->mode = TYPE;
    state->last = 0;
    state->whave = 0;
    next = strm->next_in;
    /* avail_in is ignored when no initial input buffer was supplied */
    have = next != Z_NULL ? strm->avail_in : 0;
    hold = 0;
    bits = 0;
    put = state->window;
    left = state->wsize;

    /* Inflate until end of block marked as last */
    for (;;)
        switch (state->mode) {
        case TYPE:
            /* determine and dispatch block type */
            if (state->last) {
                /* the last block is finished: drop any partial-byte bits */
                BYTEBITS();
                state->mode = DONE;
                break;
            }
            /* block header: 1-bit "last block" flag, then 2-bit type */
            NEEDBITS(3);
            state->last = BITS(1);
            DROPBITS(1);
            switch (BITS(2)) {
            case 0:                             /* stored block */
                Tracev((stderr, "inflate:     stored block%s\n",
                        state->last ? " (last)" : ""));
                state->mode = STORED;
                break;
            case 1:                             /* fixed block */
                inflate_fixed(state);
                Tracev((stderr, "inflate:     fixed codes block%s\n",
                        state->last ? " (last)" : ""));
                state->mode = LEN;              /* decode codes */
                break;
            case 2:                             /* dynamic block */
                Tracev((stderr, "inflate:     dynamic codes block%s\n",
                        state->last ? " (last)" : ""));
                state->mode = TABLE;
                break;
            default:                            /* type 3 is reserved */
                strm->msg = (z_const char *)"invalid block type";
                state->mode = BAD;
            }
            DROPBITS(2);
            break;

        case STORED:
            /* get and verify stored block length */
            BYTEBITS();                         /* go to byte boundary */
            /* 16-bit length followed by its one's complement */
            NEEDBITS(32);
            if ((hold & 0xffff) != ((hold >> 16) ^ 0xffff)) {
                strm->msg = (z_const char *)"invalid stored block lengths";
                state->mode = BAD;
                break;
            }
            state->length = (unsigned)hold & 0xffff;
            Tracev((stderr, "inflate:       stored length %u\n",
                    state->length));
            INITBITS();

            /* copy stored block from input to output, limited each pass by
               both the available input and the room left in the window */
            while (state->length != 0) {
                copy = state->length;
                PULL();
                ROOM();
                if (copy > have) copy = have;
                if (copy > left) copy = left;
                zmemcpy(put, next, copy);
                have -= copy;
                next += copy;
                left -= copy;
                put += copy;
                state->length -= copy;
            }
            Tracev((stderr, "inflate:       stored end\n"));
            state->mode = TYPE;
            break;

        case TABLE:
            /* get dynamic table entries descriptor */
            NEEDBITS(14);
            state->nlen = BITS(5) + 257;
            DROPBITS(5);
            state->ndist = BITS(5) + 1;
            DROPBITS(5);
            state->ncode = BITS(4) + 4;
            DROPBITS(4);
#ifndef PKZIP_BUG_WORKAROUND
            if (state->nlen > 286 || state->ndist > 30) {
                strm->msg = (z_const char *)
                    "too many length or distance symbols";
                state->mode = BAD;
                break;
            }
#endif
            Tracev((stderr, "inflate:       table sizes ok\n"));

            /* get code length code lengths (not a typo) */
            state->have = 0;
            while (state->have < state->ncode) {
                NEEDBITS(3);
                state->lens[order[state->have++]] = (unsigned short)BITS(3);
                DROPBITS(3);
            }
            /* lengths not transmitted are zero (unused symbols) */
            while (state->have < 19)
                state->lens[order[state->have++]] = 0;
            state->next = state->codes;
            state->lencode = (code const FAR *)(state->next);
            state->lenbits = 7;
            ret = inflate_table(CODES, state->lens, 19, &(state->next),
                                &(state->lenbits), state->work);
            if (ret) {
                strm->msg = (z_const char *)"invalid code lengths set";
                state->mode = BAD;
                break;
            }
            Tracev((stderr, "inflate:       code lengths ok\n"));

            /* get length and distance code code lengths */
            state->have = 0;
            while (state->have < state->nlen + state->ndist) {
                /* pull bytes until the table entry's code fits in bits */
                for (;;) {
                    here = state->lencode[BITS(state->lenbits)];
                    if ((unsigned)(here.bits) <= bits) break;
                    PULLBYTE();
                }
                if (here.val < 16) {
                    /* 0..15: a literal code length */
                    DROPBITS(here.bits);
                    state->lens[state->have++] = here.val;
                }
                else {
                    if (here.val == 16) {
                        /* repeat the previous length 3..6 times */
                        NEEDBITS(here.bits + 2);
                        DROPBITS(here.bits);
                        if (state->have == 0) {
                            strm->msg = (z_const char *)
                                "invalid bit length repeat";
                            state->mode = BAD;
                            break;
                        }
                        len = (unsigned)(state->lens[state->have - 1]);
                        copy = 3 + BITS(2);
                        DROPBITS(2);
                    }
                    else if (here.val == 17) {
                        /* repeat a zero length 3..10 times */
                        NEEDBITS(here.bits + 3);
                        DROPBITS(here.bits);
                        len = 0;
                        copy = 3 + BITS(3);
                        DROPBITS(3);
                    }
                    else {
                        /* 18: repeat a zero length 11..138 times */
                        NEEDBITS(here.bits + 7);
                        DROPBITS(here.bits);
                        len = 0;
                        copy = 11 + BITS(7);
                        DROPBITS(7);
                    }
                    /* a repeat may not run past the declared symbol count */
                    if (state->have + copy > state->nlen + state->ndist) {
                        strm->msg = (z_const char *)
                            "invalid bit length repeat";
                        state->mode = BAD;
                        break;
                    }
                    while (copy--)
                        state->lens[state->have++] = (unsigned short)len;
                }
            }

            /* handle error breaks in while */
            if (state->mode == BAD) break;

            /* check for end-of-block code (better have one) */
            if (state->lens[256] == 0) {
                strm->msg = (z_const char *)
                    "invalid code -- missing end-of-block";
                state->mode = BAD;
                break;
            }

            /* build code tables -- note: do not change the lenbits or distbits
               values here (9 and 6) without reading the comments in inftrees.h
               concerning the ENOUGH constants, which depend on those values */
            state->next = state->codes;
            state->lencode = (code const FAR *)(state->next);
            state->lenbits = 9;
            ret = inflate_table(LENS, state->lens, state->nlen, &(state->next),
                                &(state->lenbits), state->work);
            if (ret) {
                strm->msg = (z_const char *)"invalid literal/lengths set";
                state->mode = BAD;
                break;
            }
            /* the distance table is built right after the length table */
            state->distcode = (code const FAR *)(state->next);
            state->distbits = 6;
            ret = inflate_table(DISTS, state->lens + state->nlen, state->ndist,
                            &(state->next), &(state->distbits), state->work);
            if (ret) {
                strm->msg = (z_const char *)"invalid distances set";
                state->mode = BAD;
                break;
            }
            Tracev((stderr, "inflate:       codes ok\n"));
            state->mode = LEN;
                /* fallthrough */

        case LEN:
            /* use inflate_fast() if we have enough input and output */
            if (have >= 6 && left >= 258) {
                RESTORE();
                /* until the window has wrapped, only the bytes written so
                   far are valid history */
                if (state->whave < state->wsize)
                    state->whave = state->wsize - left;
                inflate_fast(strm, state->wsize);
                LOAD();
                break;
            }

            /* get a literal, length, or end-of-block code */
            for (;;) {
                here = state->lencode[BITS(state->lenbits)];
                if ((unsigned)(here.bits) <= bits) break;
                PULLBYTE();
            }
            /* an entry with a non-zero op and no high bits points to a
               second-level table; look the code up there */
            if (here.op && (here.op & 0xf0) == 0) {
                last = here;
                for (;;) {
                    here = state->lencode[last.val +
                            (BITS(last.bits + last.op) >> last.bits)];
                    if ((unsigned)(last.bits + here.bits) <= bits) break;
                    PULLBYTE();
                }
                DROPBITS(last.bits);
            }
            DROPBITS(here.bits);
            state->length = (unsigned)here.val;

            /* process literal */
            if (here.op == 0) {
                Tracevv((stderr, here.val >= 0x20 && here.val < 0x7f ?
                        "inflate:         literal '%c'\n" :
                        "inflate:         literal 0x%02x\n", here.val));
                ROOM();
                *put++ = (unsigned char)(state->length);
                left--;
                state->mode = LEN;
                break;
            }

            /* process end of block */
            if (here.op & 32) {
                Tracevv((stderr, "inflate:         end of block\n"));
                state->mode = TYPE;
                break;
            }

            /* invalid code */
            if (here.op & 64) {
                strm->msg = (z_const char *)"invalid literal/length code";
                state->mode = BAD;
                break;
            }

            /* length code -- get extra bits, if any */
            state->extra = (unsigned)(here.op) & 15;
            if (state->extra != 0) {
                NEEDBITS(state->extra);
                state->length += BITS(state->extra);
                DROPBITS(state->extra);
            }
            Tracevv((stderr, "inflate:         length %u\n", state->length));

            /* get distance code */
            for (;;) {
                here = state->distcode[BITS(state->distbits)];
                if ((unsigned)(here.bits) <= bits) break;
                PULLBYTE();
            }
            /* second-level distance table lookup, as for lengths above */
            if ((here.op & 0xf0) == 0) {
                last = here;
                for (;;) {
                    here = state->distcode[last.val +
                            (BITS(last.bits + last.op) >> last.bits)];
                    if ((unsigned)(last.bits + here.bits) <= bits) break;
                    PULLBYTE();
                }
                DROPBITS(last.bits);
            }
            DROPBITS(here.bits);
            if (here.op & 64) {
                strm->msg = (z_const char *)"invalid distance code";
                state->mode = BAD;
                break;
            }
            state->offset = (unsigned)here.val;

            /* get distance extra bits, if any */
            state->extra = (unsigned)(here.op) & 15;
            if (state->extra != 0) {
                NEEDBITS(state->extra);
                state->offset += BITS(state->extra);
                DROPBITS(state->extra);
            }
            /* available history is the full window once it has been
               flushed at least once, else only the wsize - left bytes
               written so far */
            if (state->offset > state->wsize - (state->whave < state->wsize ?
                                                left : 0)) {
                strm->msg = (z_const char *)"invalid distance too far back";
                state->mode = BAD;
                break;
            }
            Tracevv((stderr, "inflate:         distance %u\n", state->offset));

            /* copy match from window to output */
            do {
                ROOM();
                /* the window is used circularly: put - window bytes have
                   been written this pass (wsize - left).  If the offset
                   reaches back further than that, the source lies in the
                   previous pass, at put + (wsize - offset), and can run
                   only to the end of the window. */
                copy = state->wsize - state->offset;
                if (copy < left) {
                    from = put + copy;
                    copy = left - copy;
                }
                else {
                    from = put - state->offset;
                    copy = left;
                }
                if (copy > state->length) copy = state->length;
                state->length -= copy;
                left -= copy;
                /* byte-by-byte copy: source and destination may overlap
                   when offset < length */
                do {
                    *put++ = *from++;
                } while (--copy);
            } while (state->length != 0);
            break;

        case DONE:
            /* inflate stream terminated properly */
            ret = Z_STREAM_END;
            goto inf_leave;

        case BAD:
            ret = Z_DATA_ERROR;
            goto inf_leave;

        default:
            /* can't happen, but makes compilers happy */
            ret = Z_STREAM_ERROR;
            goto inf_leave;
        }

    /* Write leftover output and return unused input.  A failing final out()
       only overrides a successful result; an earlier error code is kept. */
  inf_leave:
    if (left < state->wsize) {
        if (out(out_desc, state->window, state->wsize - left) &&
            ret == Z_STREAM_END)
            ret = Z_BUF_ERROR;
    }
    strm->next_in = next;
    strm->avail_in = have;
    return ret;
}
573 
inflateBackEnd(z_streamp strm)574 int ZEXPORT inflateBackEnd(z_streamp strm) {
575     if (strm == Z_NULL || strm->state == Z_NULL || strm->zfree == (free_func)0)
576         return Z_STREAM_ERROR;
577     ZFREE(strm, strm->state);
578     strm->state = Z_NULL;
579     Tracev((stderr, "inflate: end\n"));
580     return Z_OK;
581 }
582