1 // SPDX-License-Identifier: GPL-2.0 2 #include "comm.h" 3 #include <errno.h> 4 #include <string.h> 5 #include <internal/rc_check.h> 6 #include <linux/refcount.h> 7 #include <linux/zalloc.h> 8 #include <tools/libc_compat.h> // reallocarray 9 10 #include "rwsem.h" 11 12 DECLARE_RC_STRUCT(comm_str) { 13 refcount_t refcnt; 14 char str[]; 15 }; 16 17 static struct comm_strs { 18 struct rw_semaphore lock; 19 struct comm_str **strs; 20 int num_strs; 21 int capacity; 22 } _comm_strs; 23 24 static void comm_strs__remove_if_last(struct comm_str *cs); 25 26 static void comm_strs__init(void) 27 { 28 init_rwsem(&_comm_strs.lock); 29 _comm_strs.capacity = 16; 30 _comm_strs.num_strs = 0; 31 _comm_strs.strs = calloc(16, sizeof(*_comm_strs.strs)); 32 } 33 34 static struct comm_strs *comm_strs__get(void) 35 { 36 static pthread_once_t comm_strs_type_once = PTHREAD_ONCE_INIT; 37 38 pthread_once(&comm_strs_type_once, comm_strs__init); 39 40 return &_comm_strs; 41 } 42 43 static refcount_t *comm_str__refcnt(struct comm_str *cs) 44 { 45 return &RC_CHK_ACCESS(cs)->refcnt; 46 } 47 48 static const char *comm_str__str(const struct comm_str *cs) 49 { 50 return &RC_CHK_ACCESS(cs)->str[0]; 51 } 52 53 static struct comm_str *comm_str__get(struct comm_str *cs) 54 { 55 struct comm_str *result; 56 57 if (RC_CHK_GET(result, cs)) 58 refcount_inc_not_zero(comm_str__refcnt(cs)); 59 60 return result; 61 } 62 63 static void comm_str__put(struct comm_str *cs) 64 { 65 if (!cs) 66 return; 67 68 if (refcount_dec_and_test(comm_str__refcnt(cs))) { 69 RC_CHK_FREE(cs); 70 } else { 71 if (refcount_read(comm_str__refcnt(cs)) == 1) 72 comm_strs__remove_if_last(cs); 73 74 RC_CHK_PUT(cs); 75 } 76 } 77 78 static struct comm_str *comm_str__new(const char *str) 79 { 80 struct comm_str *result = NULL; 81 RC_STRUCT(comm_str) *cs; 82 83 cs = malloc(sizeof(*cs) + strlen(str) + 1); 84 if (ADD_RC_CHK(result, cs)) { 85 refcount_set(comm_str__refcnt(result), 1); 86 strcpy(&cs->str[0], str); 87 } 88 return result; 89 } 90 91 static int comm_str__search(const void *_key, const void *_member) 92 { 93 const char *key = _key; 94 const struct comm_str *member = *(const struct comm_str * const *)_member; 95 96 return strcmp(key, comm_str__str(member)); 97 } 98 99 static void comm_strs__remove_if_last(struct comm_str *cs) 100 { 101 struct comm_strs *comm_strs = comm_strs__get(); 102 103 down_write(&comm_strs->lock); 104 /* 105 * Are there only references from the array, if so remove the array 106 * reference under the write lock so that we don't race with findnew. 107 */ 108 if (refcount_read(comm_str__refcnt(cs)) == 1) { 109 struct comm_str **entry; 110 111 entry = bsearch(comm_str__str(cs), comm_strs->strs, comm_strs->num_strs, 112 sizeof(struct comm_str *), comm_str__search); 113 comm_str__put(*entry); 114 for (int i = entry - comm_strs->strs; i < comm_strs->num_strs - 1; i++) 115 comm_strs->strs[i] = comm_strs->strs[i + 1]; 116 comm_strs->num_strs--; 117 } 118 up_write(&comm_strs->lock); 119 } 120 121 static struct comm_str *__comm_strs__find(struct comm_strs *comm_strs, const char *str) 122 { 123 struct comm_str **result; 124 125 result = bsearch(str, comm_strs->strs, comm_strs->num_strs, sizeof(struct comm_str *), 126 comm_str__search); 127 128 if (!result) 129 return NULL; 130 131 return comm_str__get(*result); 132 } 133 134 static struct comm_str *comm_strs__findnew(const char *str) 135 { 136 struct comm_strs *comm_strs = comm_strs__get(); 137 struct comm_str *result; 138 139 if (!comm_strs) 140 return NULL; 141 142 down_read(&comm_strs->lock); 143 result = __comm_strs__find(comm_strs, str); 144 up_read(&comm_strs->lock); 145 if (result) 146 return result; 147 148 down_write(&comm_strs->lock); 149 result = __comm_strs__find(comm_strs, str); 150 if (!result) { 151 if (comm_strs->num_strs == comm_strs->capacity) { 152 struct comm_str **tmp; 153 154 tmp = reallocarray(comm_strs->strs, 155 comm_strs->capacity + 16, 156 sizeof(*comm_strs->strs)); 157 if (!tmp) { 158 up_write(&comm_strs->lock); 159 return NULL; 160 } 161 comm_strs->strs = tmp; 162 comm_strs->capacity += 16; 163 } 164 result = comm_str__new(str); 165 if (result) { 166 int low = 0, high = comm_strs->num_strs - 1; 167 int insert = comm_strs->num_strs; /* Default to inserting at the end. */ 168 169 while (low <= high) { 170 int mid = low + (high - low) / 2; 171 int cmp = strcmp(comm_str__str(comm_strs->strs[mid]), str); 172 173 if (cmp < 0) { 174 low = mid + 1; 175 } else { 176 high = mid - 1; 177 insert = mid; 178 } 179 } 180 memmove(&comm_strs->strs[insert + 1], &comm_strs->strs[insert], 181 (comm_strs->num_strs - insert) * sizeof(struct comm_str *)); 182 comm_strs->num_strs++; 183 comm_strs->strs[insert] = result; 184 } 185 } 186 up_write(&comm_strs->lock); 187 return comm_str__get(result); 188 } 189 190 struct comm *comm__new(const char *str, u64 timestamp, bool exec) 191 { 192 struct comm *comm = zalloc(sizeof(*comm)); 193 194 if (!comm) 195 return NULL; 196 197 comm->start = timestamp; 198 comm->exec = exec; 199 200 comm->comm_str = comm_strs__findnew(str); 201 if (!comm->comm_str) { 202 free(comm); 203 return NULL; 204 } 205 206 return comm; 207 } 208 209 int comm__override(struct comm *comm, const char *str, u64 timestamp, bool exec) 210 { 211 struct comm_str *new, *old = comm->comm_str; 212 213 new = comm_strs__findnew(str); 214 if (!new) 215 return -ENOMEM; 216 217 comm_str__put(old); 218 comm->comm_str = new; 219 comm->start = timestamp; 220 if (exec) 221 comm->exec = true; 222 223 return 0; 224 } 225 226 void comm__free(struct comm *comm) 227 { 228 comm_str__put(comm->comm_str); 229 free(comm); 230 } 231 232 const char *comm__str(const struct comm *comm) 233 { 234 return comm_str__str(comm->comm_str); 235 } 236