xref: /src/crypto/openssl/ssl/quic/quic_srtm.c (revision f25b8c9fb4f58cf61adb47d7570abe7caa6d385d)
1 /*
2  * Copyright 2023-2024 The OpenSSL Project Authors. All Rights Reserved.
3  *
4  * Licensed under the Apache License 2.0 (the "License").  You may not use
5  * this file except in compliance with the License.  You can obtain a copy
6  * in the file LICENSE in the source distribution or at
7  * https://www.openssl.org/source/license.html
8  */
9 
10 #include "internal/quic_srtm.h"
11 #include "internal/common.h"
12 #include <openssl/lhash.h>
13 #include <openssl/core_names.h>
14 #include <openssl/rand.h>
15 
16 /*
17  * QUIC Stateless Reset Token Manager
18  * ==================================
19  */
/* A blinded SRT is exactly one AES block (128 bits). */
#define BLINDED_SRT_LEN 16

/* Item type is declared here so the lhash helpers can be generated for it. */
typedef struct srtm_item_st SRTM_ITEM;

DEFINE_LHASH_OF_EX(SRTM_ITEM);
25 
/*
 * The SRTM is implemented using two LHASH instances: one maps opaque pointers
 * to an item structure, and the other maps an SRT-derived value (the blinded
 * SRT) to an item structure. Multiple items with different seq_num values
 * under a given opaque, and duplicate SRTs, are handled using sorted
 * singly-linked lists.
 *
 * The O(n) insert and lookup performance is tolerated on the basis that the
 * total number of entries for a given opaque (total number of extant CIDs for a
 * connection) should be quite small, and the QUIC protocol allows us to place a
 * hard limit on this via the active_connection_id_limit TPARAM. Thus there is
 * no risk of a large number of SRTs needing to be registered under a given
 * opaque.
 *
 * It is expected that one SRTM will exist per QUIC_PORT and track all SRTs
 * across all connections for that QUIC_PORT.
 */
42 struct srtm_item_st {
43     SRTM_ITEM *next_by_srt_blinded; /* SORT BY opaque  DESC */
44     SRTM_ITEM *next_by_seq_num; /* SORT BY seq_num DESC */
45     void *opaque; /* \__ unique identity for item */
46     uint64_t seq_num; /* /                            */
47     QUIC_STATELESS_RESET_TOKEN srt;
48     unsigned char srt_blinded[BLINDED_SRT_LEN]; /* H(srt) */
49 
50 #ifdef FUZZING_BUILD_MODE_UNSAFE_FOR_PRODUCTION
51     uint32_t debug_token;
52 #endif
53 };
54 
55 struct quic_srtm_st {
56     /* Crypto context used to calculate blinded SRTs H(srt). */
57     EVP_CIPHER_CTX *blind_ctx; /* kept with key */
58 
59     LHASH_OF(SRTM_ITEM) *items_fwd; /* (opaque)  -> SRTM_ITEM */
60     LHASH_OF(SRTM_ITEM) *items_rev; /* (H(srt))  -> SRTM_ITEM */
61 
62     /*
63      * Monotonically transitions to 1 in event of allocation failure. The only
64      * valid operation on such an object is to free it.
65      */
66     unsigned int alloc_failed : 1;
67 };
68 
items_fwd_hash(const SRTM_ITEM * item)69 static unsigned long items_fwd_hash(const SRTM_ITEM *item)
70 {
71     return (unsigned long)(uintptr_t)item->opaque;
72 }
73 
items_fwd_cmp(const SRTM_ITEM * a,const SRTM_ITEM * b)74 static int items_fwd_cmp(const SRTM_ITEM *a, const SRTM_ITEM *b)
75 {
76     return a->opaque != b->opaque;
77 }
78 
items_rev_hash(const SRTM_ITEM * item)79 static unsigned long items_rev_hash(const SRTM_ITEM *item)
80 {
81     /*
82      * srt_blinded has already been through a crypto-grade hash function, so we
83      * can just use bits from that.
84      */
85     unsigned long l;
86 
87     memcpy(&l, item->srt_blinded, sizeof(l));
88     return l;
89 }
90 
items_rev_cmp(const SRTM_ITEM * a,const SRTM_ITEM * b)91 static int items_rev_cmp(const SRTM_ITEM *a, const SRTM_ITEM *b)
92 {
93     /*
94      * We don't need to use CRYPTO_memcmp here as the relationship of
95      * srt_blinded to srt is already cryptographically obfuscated.
96      */
97     return memcmp(a->srt_blinded, b->srt_blinded, sizeof(a->srt_blinded));
98 }
99 
/*
 * Inspects the error state of lh after an lhash operation. On error, latches
 * srtm->alloc_failed and returns 0; otherwise returns 1.
 */
static int srtm_check_lh(QUIC_SRTM *srtm, LHASH_OF(SRTM_ITEM) *lh)
{
    int failed = lh_SRTM_ITEM_error(lh) != 0;

    if (failed)
        srtm->alloc_failed = 1;

    return !failed;
}
109 
ossl_quic_srtm_new(OSSL_LIB_CTX * libctx,const char * propq)110 QUIC_SRTM *ossl_quic_srtm_new(OSSL_LIB_CTX *libctx, const char *propq)
111 {
112     QUIC_SRTM *srtm = NULL;
113     unsigned char key[16];
114     EVP_CIPHER *ecb = NULL;
115 
116     if (RAND_priv_bytes_ex(libctx, key, sizeof(key), sizeof(key) * 8) != 1)
117         goto err;
118 
119     if ((srtm = OPENSSL_zalloc(sizeof(*srtm))) == NULL)
120         return NULL;
121 
122     /* Use AES-128-ECB as a permutation over 128-bit SRTs. */
123     if ((ecb = EVP_CIPHER_fetch(libctx, "AES-128-ECB", propq)) == NULL)
124         goto err;
125 
126     if ((srtm->blind_ctx = EVP_CIPHER_CTX_new()) == NULL)
127         goto err;
128 
129     if (!EVP_EncryptInit_ex2(srtm->blind_ctx, ecb, key, NULL, NULL))
130         goto err;
131 
132     EVP_CIPHER_free(ecb);
133     ecb = NULL;
134 
135     /* Create mappings. */
136     if ((srtm->items_fwd = lh_SRTM_ITEM_new(items_fwd_hash, items_fwd_cmp)) == NULL
137         || (srtm->items_rev = lh_SRTM_ITEM_new(items_rev_hash, items_rev_cmp)) == NULL)
138         goto err;
139 
140     return srtm;
141 
142 err:
143     /*
144      * No cleansing of key needed as blinding exists only for side channel
145      * mitigation.
146      */
147     ossl_quic_srtm_free(srtm);
148     EVP_CIPHER_free(ecb);
149     return NULL;
150 }
151 
/*
 * lhash doall callback for items_fwd: frees every item on the per-opaque list
 * headed by ihead (the followers first, then the head itself).
 */
static void srtm_free_each(SRTM_ITEM *ihead)
{
    SRTM_ITEM *victim = ihead->next_by_seq_num;

    while (victim != NULL) {
        SRTM_ITEM *following = victim->next_by_seq_num;

        OPENSSL_free(victim);
        victim = following;
    }

    OPENSSL_free(ihead);
}
163 
/*
 * Releases the SRTM and every item it owns. Items are owned via items_fwd
 * only, so items_rev is freed without visiting its entries. NULL is accepted.
 */
void ossl_quic_srtm_free(QUIC_SRTM *srtm)
{
    if (srtm != NULL) {
        lh_SRTM_ITEM_free(srtm->items_rev);

        if (srtm->items_fwd != NULL) {
            lh_SRTM_ITEM_doall(srtm->items_fwd, srtm_free_each);
            lh_SRTM_ITEM_free(srtm->items_fwd);
        }

        EVP_CIPHER_CTX_free(srtm->blind_ctx);
        OPENSSL_free(srtm);
    }
}
178 
179 /*
180  * Find a SRTM_ITEM by (opaque, seq_num). Returns NULL if no match.
181  * If head is non-NULL, writes the head of the relevant opaque list to *head if
182  * there is one.
183  * If prev is non-NULL, writes the previous node to *prev or NULL if it is
184  * the first item.
185  */
srtm_find(QUIC_SRTM * srtm,void * opaque,uint64_t seq_num,SRTM_ITEM ** head_p,SRTM_ITEM ** prev_p)186 static SRTM_ITEM *srtm_find(QUIC_SRTM *srtm, void *opaque, uint64_t seq_num,
187     SRTM_ITEM **head_p, SRTM_ITEM **prev_p)
188 {
189     SRTM_ITEM key, *item = NULL, *prev = NULL;
190 
191     key.opaque = opaque;
192 
193     item = lh_SRTM_ITEM_retrieve(srtm->items_fwd, &key);
194     if (head_p != NULL)
195         *head_p = item;
196 
197     for (; item != NULL; prev = item, item = item->next_by_seq_num)
198         if (item->seq_num == seq_num) {
199             break;
200         } else if (item->seq_num < seq_num) {
201             /*
202              * List is sorted in descending order so there can't be any match
203              * after this.
204              */
205             item = NULL;
206             break;
207         }
208 
209     if (prev_p != NULL)
210         *prev_p = prev;
211 
212     return item;
213 }
214 
215 /*
216  * Inserts a SRTM_ITEM into the singly-linked by-sequence-number linked list.
217  * The new head pointer is written to *new_head (which may or may not be
218  * unchanged).
219  */
sorted_insert_seq_num(SRTM_ITEM * head,SRTM_ITEM * item,SRTM_ITEM ** new_head)220 static void sorted_insert_seq_num(SRTM_ITEM *head, SRTM_ITEM *item, SRTM_ITEM **new_head)
221 {
222     uint64_t seq_num = item->seq_num;
223     SRTM_ITEM *cur = head, **fixup = new_head;
224 
225     *new_head = head;
226 
227     while (cur != NULL && cur->seq_num > seq_num) {
228         fixup = &cur->next_by_seq_num;
229         cur = cur->next_by_seq_num;
230     }
231 
232     item->next_by_seq_num = *fixup;
233     *fixup = item;
234 }
235 
236 /*
237  * Inserts a SRTM_ITEM into the singly-linked by-SRT list.
238  * The new head pointer is written to *new_head (which may or may not be
239  * unchanged).
240  */
sorted_insert_srt(SRTM_ITEM * head,SRTM_ITEM * item,SRTM_ITEM ** new_head)241 static void sorted_insert_srt(SRTM_ITEM *head, SRTM_ITEM *item, SRTM_ITEM **new_head)
242 {
243     uintptr_t opaque = (uintptr_t)item->opaque;
244     SRTM_ITEM *cur = head, **fixup = new_head;
245 
246     *new_head = head;
247 
248     while (cur != NULL && (uintptr_t)cur->opaque > opaque) {
249         fixup = &cur->next_by_srt_blinded;
250         cur = cur->next_by_srt_blinded;
251     }
252 
253     item->next_by_srt_blinded = *fixup;
254     *fixup = item;
255 }
256 
257 /*
258  * Computes the blinded SRT value used for internal lookup for side channel
259  * mitigation purposes. We compute this once as a cached value when an SRTM_ITEM
260  * is formed.
261  */
srtm_compute_blinded(QUIC_SRTM * srtm,SRTM_ITEM * item,const QUIC_STATELESS_RESET_TOKEN * token)262 static int srtm_compute_blinded(QUIC_SRTM *srtm, SRTM_ITEM *item,
263     const QUIC_STATELESS_RESET_TOKEN *token)
264 {
265     int outl = 0;
266 
267     /*
268      * We use AES-128-ECB as a permutation using a random key to facilitate
269      * blinding for side-channel purposes. Encrypt the token as a single AES
270      * block.
271      */
272     if (!EVP_EncryptUpdate(srtm->blind_ctx, item->srt_blinded, &outl,
273             (const unsigned char *)token, sizeof(*token)))
274         return 0;
275 
276     if (!ossl_assert(outl == sizeof(*token)))
277         return 0;
278 
279     return 1;
280 }
281 
/*
 * Registers token under the identity (opaque, seq_num). Fails (returns 0) if
 * that identity already exists, on allocation/crypto failure, or if the SRTM
 * has already suffered an allocation failure. Returns 1 on success.
 */
int ossl_quic_srtm_add(QUIC_SRTM *srtm, void *opaque, uint64_t seq_num,
    const QUIC_STATELESS_RESET_TOKEN *token)
{
    SRTM_ITEM *fresh, *fwd_head = NULL, *rev_head, *updated_head;

    if (srtm->alloc_failed)
        return 0;

    /* Duplicate (opaque, seq_num) pairs are rejected; also fetches the head. */
    if (srtm_find(srtm, opaque, seq_num, &fwd_head, NULL) != NULL)
        return 0;

    fresh = OPENSSL_zalloc(sizeof(*fresh));
    if (fresh == NULL)
        return 0;

    fresh->opaque = opaque;
    fresh->seq_num = seq_num;
    fresh->srt = *token;
    if (!srtm_compute_blinded(srtm, fresh, &fresh->srt))
        goto free_fresh;

    /* Forward mapping: link into the opaque's list, re-point lhash if needed. */
    if (fwd_head != NULL)
        sorted_insert_seq_num(fwd_head, fresh, &updated_head);
    else
        updated_head = fresh;

    if (updated_head != fwd_head) {
        lh_SRTM_ITEM_insert(srtm->items_fwd, updated_head);
        if (!srtm_check_lh(srtm, srtm->items_fwd))
            goto free_fresh;
    }

    /* Reverse mapping: same procedure keyed by the blinded SRT. */
    rev_head = lh_SRTM_ITEM_retrieve(srtm->items_rev, fresh);
    if (rev_head != NULL)
        sorted_insert_srt(rev_head, fresh, &updated_head);
    else
        updated_head = fresh;

    if (updated_head != rev_head) {
        lh_SRTM_ITEM_insert(srtm->items_rev, updated_head);
        if (!srtm_check_lh(srtm, srtm->items_rev))
            /*
             * The item is already reachable from the forward mapping; undoing
             * that would need another insert which could also fail. It is
             * released when the whole SRTM is freed.
             */
            return 0;
    }

    return 1;

free_fresh:
    OPENSSL_free(fresh);
    return 0;
}
349 
350 /* Remove item from reverse mapping. */
srtm_remove_from_rev(QUIC_SRTM * srtm,SRTM_ITEM * item)351 static int srtm_remove_from_rev(QUIC_SRTM *srtm, SRTM_ITEM *item)
352 {
353     SRTM_ITEM *rh_item;
354 
355     rh_item = lh_SRTM_ITEM_retrieve(srtm->items_rev, item);
356     assert(rh_item != NULL);
357     if (rh_item == item) {
358         /*
359          * Change lhash to point to item after this one, or remove the entry if
360          * this is the last one.
361          */
362         if (item->next_by_srt_blinded != NULL) {
363             lh_SRTM_ITEM_insert(srtm->items_rev, item->next_by_srt_blinded);
364             if (!srtm_check_lh(srtm, srtm->items_rev))
365                 return 0;
366         } else {
367             lh_SRTM_ITEM_delete(srtm->items_rev, item);
368         }
369     } else {
370         /* Find our entry in the SRT list */
371         for (; rh_item->next_by_srt_blinded != item;
372             rh_item = rh_item->next_by_srt_blinded)
373             ;
374         rh_item->next_by_srt_blinded = item->next_by_srt_blinded;
375     }
376 
377     return 1;
378 }
379 
/*
 * Removes and frees the item identified by (opaque, seq_num).
 * Returns 1 on success, 0 if no such item exists, if the SRTM has already
 * failed, or if an lhash update fails (which latches alloc_failed).
 */
int ossl_quic_srtm_remove(QUIC_SRTM *srtm, void *opaque, uint64_t seq_num)
{
    SRTM_ITEM *item, *prev = NULL;

    if (srtm->alloc_failed)
        return 0;

    if ((item = srtm_find(srtm, opaque, seq_num, NULL, &prev)) == NULL)
        /* No match */
        return 0;

    /*
     * Unlink from the reverse mapping first. Items are owned (and freed by
     * ossl_quic_srtm_free) via the forward mapping only. If the reverse
     * unlink fails, the item therefore stays owned and is not leaked. The
     * opposite order would leave a failed item unreachable from the
     * forward mapping and leak it.
     */
    if (!srtm_remove_from_rev(srtm, item))
        return 0;

    /* Remove from forward mapping. */
    if (prev != NULL) {
        prev->next_by_seq_num = item->next_by_seq_num;
    } else if (item->next_by_seq_num == NULL) {
        /* Last item for this opaque: drop the lhash entry entirely. */
        lh_SRTM_ITEM_delete(srtm->items_fwd, item);
    } else {
        /* Item was the list head: promote its successor in the lhash. */
        lh_SRTM_ITEM_insert(srtm->items_fwd, item->next_by_seq_num);
        if (!srtm_check_lh(srtm, srtm->items_fwd))
            /* item is still the forward head and is freed with the SRTM. */
            return 0;
    }

    OPENSSL_free(item);
    return 1;
}
415 
/*
 * Removes and frees every item registered under opaque.
 * Returns 1 on success (including when opaque has no items). Returns 0 if
 * the SRTM has already failed or an lhash update fails; in the latter case
 * alloc_failed is latched. Any items not yet freed remain owned by the
 * forward mapping and are released by ossl_quic_srtm_free().
 */
int ossl_quic_srtm_cull(QUIC_SRTM *srtm, void *opaque)
{
    SRTM_ITEM key, *item, *inext, *ihead;

    key.opaque = opaque;

    if (srtm->alloc_failed)
        return 0;

    if ((ihead = lh_SRTM_ITEM_retrieve(srtm->items_fwd, &key)) == NULL)
        return 1; /* nothing removed is a success condition */

    /*
     * Free the non-head items one at a time. Each node is unlinked from the
     * head's chain as it is freed, so the forward list never references
     * freed memory. Do not free an item whose reverse unlink failed: it is
     * still referenced by items_rev.
     */
    for (item = ihead->next_by_seq_num; item != NULL; item = inext) {
        inext = item->next_by_seq_num;
        if (!srtm_remove_from_rev(srtm, item))
            return 0;
        ihead->next_by_seq_num = inext;
        OPENSSL_free(item);
    }

    /*
     * Unlink the head from the reverse mapping before the forward one. On
     * failure it is then still owned by items_fwd rather than leaked.
     */
    if (!srtm_remove_from_rev(srtm, ihead))
        return 0;

    lh_SRTM_ITEM_delete(srtm->items_fwd, ihead);
    OPENSSL_free(ihead);
    return 1;
}
441 
/*
 * Finds the idx-th (0-based) item registered with the given token, in the
 * order of the per-blinded-SRT list. On success, stores its opaque and
 * seq_num through the non-NULL output pointers and returns 1. Returns 0 if
 * fewer than idx + 1 items match, if blinding fails, or if the SRTM has
 * already failed.
 */
int ossl_quic_srtm_lookup(QUIC_SRTM *srtm,
    const QUIC_STATELESS_RESET_TOKEN *token,
    size_t idx,
    void **opaque, uint64_t *seq_num)
{
    SRTM_ITEM key, *item;

    if (srtm->alloc_failed || !srtm_compute_blinded(srtm, &key, token))
        return 0;

    /* Skip idx entries along the chain of items sharing this blinded SRT. */
    for (item = lh_SRTM_ITEM_retrieve(srtm->items_rev, &key);
        item != NULL && idx != 0; --idx)
        item = item->next_by_srt_blinded;

    if (item == NULL)
        return 0;

    if (opaque != NULL)
        *opaque = item->opaque;
    if (seq_num != NULL)
        *seq_num = item->seq_num;

    return 1;
}
468 
#ifdef FUZZING_BUILD_MODE_UNSAFE_FOR_PRODUCTION

/* Source of a distinct stamp for each ossl_quic_srtm_check() pass. */
static uint32_t token_next = 0x5eadbeef;
/* Number of items visited by the walk currently in progress. */
static size_t tokens_seen;

/* Argument block shared by check_mark() and check_count(). */
struct check_args {
    uint32_t token; /* stamp to write (check_mark) or expect (check_count) */
    int mode;       /* 0: mark via seq_num lists; 1: mark via SRT lists */
};

/*
 * lhash callback: walks one list starting at item, stamping each node with
 * the pass token and counting it in tokens_seen. Mode 0 follows
 * next_by_seq_num and asserts descending seq_num among entries with the same
 * opaque. Mode 1 follows next_by_srt_blinded. Both modes reject adjacent
 * duplicate (opaque, seq_num) pairs.
 */
static void check_mark(SRTM_ITEM *item, void *arg)
{
    struct check_args *arg_ = arg;
    uint32_t token = arg_->token;
    uint64_t prev_seq_num = 0;
    void *prev_opaque = NULL;
    int have_prev;

    assert(item != NULL);

    for (have_prev = 0; item != NULL; have_prev = 1) {
        if (have_prev) {
            assert(!(item->opaque == prev_opaque && item->seq_num == prev_seq_num));
            if (!arg_->mode)
                assert(item->opaque != prev_opaque || item->seq_num < prev_seq_num);
        }

        ++tokens_seen;
        item->debug_token = token;
        prev_opaque = item->opaque;
        prev_seq_num = item->seq_num;

        item = arg_->mode ? item->next_by_srt_blinded : item->next_by_seq_num;
    }
}

/*
 * lhash callback: walks the list opposite to the one check_mark() used in the
 * same mode, counting nodes and asserting each already carries the token.
 */
static void check_count(SRTM_ITEM *item, void *arg)
{
    struct check_args *arg_ = arg;
    uint32_t token = arg_->token;

    assert(item != NULL);

    for (; item != NULL;
        item = arg_->mode ? item->next_by_seq_num : item->next_by_srt_blinded) {
        ++tokens_seen;
        assert(item->debug_token == token);
    }
}

#endif
528 
/*
 * Debug consistency check, active only in fuzzing builds. It verifies that
 * the forward and reverse mappings reference exactly the same set of items
 * and that each list's ordering and uniqueness invariants hold. In other
 * builds the body is empty.
 */
void ossl_quic_srtm_check(const QUIC_SRTM *srtm)
{
#ifdef FUZZING_BUILD_MODE_UNSAFE_FOR_PRODUCTION
    struct check_args args = { 0 };
    size_t tokens_expected, tokens_expected_old;

    args.token = token_next++;

    assert(srtm != NULL);
    assert(srtm->blind_ctx != NULL);
    assert(srtm->items_fwd != NULL);
    assert(srtm->items_rev != NULL);

    /*
     * Pass 1: stamp every item reachable through the per-opaque lists, then
     * confirm the per-SRT lists reach the same number of stamped items.
     */
    tokens_seen = 0;
    lh_SRTM_ITEM_doall_arg(srtm->items_fwd, check_mark, &args);
    tokens_expected = tokens_seen;

    tokens_seen = 0;
    lh_SRTM_ITEM_doall_arg(srtm->items_rev, check_count, &args);
    assert(tokens_seen == tokens_expected);
    tokens_expected_old = tokens_expected;

    /* Pass 2: the mirror-image check with a fresh stamp. */
    args.token = token_next++;
    args.mode = 1;

    tokens_seen = 0;
    lh_SRTM_ITEM_doall_arg(srtm->items_rev, check_mark, &args);
    tokens_expected = tokens_seen;

    tokens_seen = 0;
    lh_SRTM_ITEM_doall_arg(srtm->items_fwd, check_count, &args);
    assert(tokens_seen == tokens_expected);
    assert(tokens_seen == tokens_expected_old);
#endif
}
568