1 // SPDX-License-Identifier: GPL-2.0
2
#include <errno.h>
#include <limits.h>
#include <stdbool.h>
#include <stdio.h>
#include <stdlib.h>
#include <string.h>
#include <sys/poll.h>
#include <sys/socket.h>
#include <sys/time.h>
#include <netinet/in.h>
#include <unistd.h>
10
11 #include <ynl.h>
12
13 #include "psp-user.h"
14
/*
 * dbg() - print a "DEBUG: "-prefixed message to stderr in verbose mode.
 * The first argument must be a string literal, it is pasted right after the
 * prefix. The expansion reads opts->verbose, so a pointer named "opts" with
 * a "verbose" member must be in scope at every call site.
 */
#define dbg(...)							\
	do {								\
		if (opts->verbose)					\
			fprintf(stderr, "DEBUG: " __VA_ARGS__);		\
	} while (0)
20
/*
 * Set by the "exit" control command; run_responder() checks it before
 * accepting each new control connection.
 */
static bool should_quit;

/* Command line options, filled in by parse_cmd_opts(). */
struct opts {
	int port;	/* TCP port to listen on (-p), required */
	int ifindex;	/* only use the PSP device on this ifindex (-i);
			 * 0 = use the only PSP device present
			 */
	bool verbose;	/* enable dbg() output (-v) */
};

/*
 * How the next accepted data connection is configured. Set by the
 * "conn psp" / "conn clr" control commands, reset to NONE after each accept.
 */
enum accept_cfg {
	ACCEPT_CFG_NONE = 0,
	ACCEPT_CFG_CLEAR,
	ACCEPT_CFG_PSP,
};

/*
 * PSP versions used by conn_setup_psp() for the next PSP connection.
 * run_session() fills it by copying the two raw bytes that follow the
 * "conn psp" command, so the member order (tx first, then rx) is part of
 * the control protocol.
 */
static struct {
	unsigned char tx;
	unsigned char rx;
} psp_vers;
39
conn_setup_psp(struct ynl_sock * ys,struct opts * opts,int data_sock)40 static int conn_setup_psp(struct ynl_sock *ys, struct opts *opts, int data_sock)
41 {
42 struct psp_rx_assoc_rsp *rsp;
43 struct psp_rx_assoc_req *req;
44 struct psp_tx_assoc_rsp *tsp;
45 struct psp_tx_assoc_req *teq;
46 char info[300];
47 int key_len;
48 ssize_t sz;
49 __u32 spi;
50
51 dbg("create PSP connection\n");
52
53 // Rx assoc alloc
54 req = psp_rx_assoc_req_alloc();
55
56 psp_rx_assoc_req_set_sock_fd(req, data_sock);
57 psp_rx_assoc_req_set_version(req, psp_vers.rx);
58
59 rsp = psp_rx_assoc(ys, req);
60 psp_rx_assoc_req_free(req);
61
62 if (!rsp) {
63 perror("ERROR: failed to Rx assoc");
64 return -1;
65 }
66
67 // SPI exchange
68 key_len = rsp->rx_key._len.key;
69 memcpy(info, &rsp->rx_key.spi, sizeof(spi));
70 memcpy(&info[sizeof(spi)], rsp->rx_key.key, key_len);
71 sz = sizeof(spi) + key_len;
72
73 send(data_sock, info, sz, MSG_WAITALL);
74 psp_rx_assoc_rsp_free(rsp);
75
76 sz = recv(data_sock, info, sz, MSG_WAITALL);
77 if (sz < 0) {
78 perror("ERROR: failed to read PSP key from sock");
79 return -1;
80 }
81 memcpy(&spi, info, sizeof(spi));
82
83 // Setup Tx assoc
84 teq = psp_tx_assoc_req_alloc();
85
86 psp_tx_assoc_req_set_sock_fd(teq, data_sock);
87 psp_tx_assoc_req_set_version(teq, psp_vers.tx);
88 psp_tx_assoc_req_set_tx_key_spi(teq, spi);
89 psp_tx_assoc_req_set_tx_key_key(teq, &info[sizeof(spi)], key_len);
90
91 tsp = psp_tx_assoc(ys, teq);
92 psp_tx_assoc_req_free(teq);
93 if (!tsp) {
94 perror("ERROR: failed to Tx assoc");
95 return -1;
96 }
97 psp_tx_assoc_rsp_free(tsp);
98
99 return 0;
100 }
101
/* Reply "ack", including its NUL terminator (4 bytes), on @sock. */
static void send_ack(int sock)
{
	static const char reply[] = "ack";

	send(sock, reply, sizeof(reply), MSG_WAITALL);
}
106
/* Reply "err", including its NUL terminator (4 bytes), on @sock. */
static void send_err(int sock)
{
	static const char reply[] = "err";

	send(sock, reply, sizeof(reply), MSG_WAITALL);
}
111
/*
 * send_str() - send @value as decimal text on @sock.
 * The terminating NUL is sent along with the digits.
 */
static void send_str(int sock, int value)
{
	char text[128];
	int len = snprintf(text, sizeof(text), "%d", value);

	send(sock, text, len + 1, MSG_WAITALL);
}
120
121 static void
run_session(struct ynl_sock * ys,struct opts * opts,int server_sock,int comm_sock)122 run_session(struct ynl_sock *ys, struct opts *opts,
123 int server_sock, int comm_sock)
124 {
125 enum accept_cfg accept_cfg = ACCEPT_CFG_NONE;
126 struct pollfd pfds[3];
127 size_t data_read = 0;
128 int data_sock = -1;
129
130 while (true) {
131 bool race_close = false;
132 int nfds;
133
134 memset(pfds, 0, sizeof(pfds));
135
136 pfds[0].fd = server_sock;
137 pfds[0].events = POLLIN;
138
139 pfds[1].fd = comm_sock;
140 pfds[1].events = POLLIN;
141
142 nfds = 2;
143 if (data_sock >= 0) {
144 pfds[2].fd = data_sock;
145 pfds[2].events = POLLIN;
146 nfds++;
147 }
148
149 dbg(" ...\n");
150 if (poll(pfds, nfds, -1) < 0) {
151 perror("poll");
152 break;
153 }
154
155 /* data sock */
156 if (pfds[2].revents & POLLIN) {
157 char buf[8192];
158 ssize_t n;
159
160 n = recv(data_sock, buf, sizeof(buf), 0);
161 if (n <= 0) {
162 if (n < 0)
163 perror("data read");
164 close(data_sock);
165 data_sock = -1;
166 dbg("data sock closed\n");
167 } else {
168 data_read += n;
169 dbg("data read %zd\n", data_read);
170 }
171 }
172
173 /* comm sock */
174 if (pfds[1].revents & POLLIN) {
175 static char buf[4096];
176 static ssize_t off;
177 bool consumed;
178 ssize_t n;
179
180 n = recv(comm_sock, &buf[off], sizeof(buf) - off, 0);
181 if (n <= 0) {
182 if (n < 0)
183 perror("comm read");
184 return;
185 }
186
187 off += n;
188 n = off;
189
190 #define __consume(sz) \
191 ({ \
192 if (n == (sz)) { \
193 off = 0; \
194 } else { \
195 off -= (sz); \
196 memmove(buf, &buf[(sz)], off); \
197 } \
198 })
199
200 #define cmd(_name) \
201 ({ \
202 ssize_t sz = sizeof(_name); \
203 bool match = n >= sz && !memcmp(buf, _name, sz); \
204 \
205 if (match) { \
206 dbg("command: " _name "\n"); \
207 __consume(sz); \
208 } \
209 consumed |= match; \
210 match; \
211 })
212
213 do {
214 consumed = false;
215
216 if (cmd("read len"))
217 send_str(comm_sock, data_read);
218
219 if (cmd("data echo")) {
220 if (data_sock >= 0)
221 send(data_sock, "echo", 5,
222 MSG_WAITALL);
223 else
224 fprintf(stderr, "WARN: echo but no data sock\n");
225 send_ack(comm_sock);
226 }
227 if (cmd("data close")) {
228 if (data_sock >= 0) {
229 close(data_sock);
230 data_sock = -1;
231 send_ack(comm_sock);
232 } else {
233 race_close = true;
234 }
235 }
236 if (cmd("conn psp")) {
237 if (accept_cfg != ACCEPT_CFG_NONE)
238 fprintf(stderr, "WARN: old conn config still set!\n");
239 accept_cfg = ACCEPT_CFG_PSP;
240 send_ack(comm_sock);
241 /* next two bytes are versions */
242 if (off >= 2) {
243 memcpy(&psp_vers, buf, 2);
244 __consume(2);
245 } else {
246 fprintf(stderr, "WARN: short conn psp command!\n");
247 }
248 }
249 if (cmd("conn clr")) {
250 if (accept_cfg != ACCEPT_CFG_NONE)
251 fprintf(stderr, "WARN: old conn config still set!\n");
252 accept_cfg = ACCEPT_CFG_CLEAR;
253 send_ack(comm_sock);
254 }
255 if (cmd("exit"))
256 should_quit = true;
257 #undef cmd
258
259 if (!consumed) {
260 fprintf(stderr, "WARN: unknown cmd: [%zd] %s\n",
261 off, buf);
262 }
263 } while (consumed && off);
264 }
265
266 /* server sock */
267 if (pfds[0].revents & POLLIN) {
268 if (data_sock >= 0) {
269 fprintf(stderr, "WARN: new data sock but old one still here\n");
270 close(data_sock);
271 data_sock = -1;
272 }
273 data_sock = accept(server_sock, NULL, NULL);
274 if (data_sock < 0) {
275 perror("accept");
276 continue;
277 }
278 data_read = 0;
279
280 if (accept_cfg == ACCEPT_CFG_CLEAR) {
281 dbg("new data sock: clear\n");
282 /* nothing to do */
283 } else if (accept_cfg == ACCEPT_CFG_PSP) {
284 dbg("new data sock: psp\n");
285 conn_setup_psp(ys, opts, data_sock);
286 } else {
287 fprintf(stderr, "WARN: new data sock but no config\n");
288 }
289 accept_cfg = ACCEPT_CFG_NONE;
290 }
291
292 if (race_close) {
293 if (data_sock >= 0) {
294 /* indeed, ordering problem, handle the close */
295 close(data_sock);
296 data_sock = -1;
297 send_ack(comm_sock);
298 } else {
299 fprintf(stderr, "WARN: close but no data sock\n");
300 send_err(comm_sock);
301 }
302 }
303 }
304 dbg("session ending\n");
305 }
306
spawn_server(struct opts * opts)307 static int spawn_server(struct opts *opts)
308 {
309 struct sockaddr_in6 addr;
310 int fd;
311
312 fd = socket(AF_INET6, SOCK_STREAM, 0);
313 if (fd < 0) {
314 perror("can't open socket");
315 return -1;
316 }
317
318 memset(&addr, 0, sizeof(addr));
319
320 addr.sin6_family = AF_INET6;
321 addr.sin6_addr = in6addr_any;
322 addr.sin6_port = htons(opts->port);
323
324 if (bind(fd, (struct sockaddr *)&addr, sizeof(addr))) {
325 perror("can't bind socket");
326 return -1;
327 }
328
329 if (listen(fd, 5)) {
330 perror("can't listen");
331 return -1;
332 }
333
334 return fd;
335 }
336
run_responder(struct ynl_sock * ys,struct opts * opts)337 static int run_responder(struct ynl_sock *ys, struct opts *opts)
338 {
339 int server_sock, comm;
340
341 server_sock = spawn_server(opts);
342 if (server_sock < 0)
343 return 4;
344
345 while (!should_quit) {
346 comm = accept(server_sock, NULL, NULL);
347 if (comm < 0) {
348 perror("accept failed");
349 } else {
350 run_session(ys, opts, server_sock, comm);
351 close(comm);
352 }
353 }
354
355 return 0;
356 }
357
/*
 * usage() - print the usage line and terminate the process.
 * @name: program name shown in the usage line
 * @miss: name of a missing required argument, or NULL
 *
 * Writes to stderr and exits with EXIT_FAILURE; never returns.
 */
static void usage(const char *name, const char *miss)
{
	if (miss != NULL)
		fprintf(stderr, "Missing argument: %s\n", miss);
	fprintf(stderr, "Usage: %s -p port [-v] [-i ifindex]\n", name);

	exit(EXIT_FAILURE);
}
366
parse_cmd_opts(int argc,char ** argv,struct opts * opts)367 static void parse_cmd_opts(int argc, char **argv, struct opts *opts)
368 {
369 int opt;
370
371 while ((opt = getopt(argc, argv, "vp:i:")) != -1) {
372 switch (opt) {
373 case 'v':
374 opts->verbose = 1;
375 break;
376 case 'p':
377 opts->port = atoi(optarg);
378 break;
379 case 'i':
380 opts->ifindex = atoi(optarg);
381 break;
382 default:
383 usage(argv[0], NULL);
384 }
385 }
386 }
387
/*
 * psp_dev_set_ena() - set the enabled PSP versions of a device.
 * @ys:       YNL socket for the PSP family
 * @dev_id:   PSP device id, as reported by the device dump
 * @versions: new value of the enabled-versions attribute
 *
 * Return: 0 on success, 10 on failure (allocation or netlink request).
 */
static int psp_dev_set_ena(struct ynl_sock *ys, __u32 dev_id, __u32 versions)
{
	struct psp_dev_set_req *sreq;
	struct psp_dev_set_rsp *srsp;

	fprintf(stderr, "Set PSP enable on device %u to 0x%x\n",
		dev_id, versions);

	sreq = psp_dev_set_req_alloc();
	if (!sreq) {
		fprintf(stderr, "ERROR: failed to alloc dev set request\n");
		return 10;
	}

	psp_dev_set_req_set_id(sreq, dev_id);
	psp_dev_set_req_set_psp_versions_ena(sreq, versions);

	srsp = psp_dev_set(ys, sreq);
	psp_dev_set_req_free(sreq);
	if (!srsp)
		return 10;

	psp_dev_set_rsp_free(srsp);
	return 0;
}
409
main(int argc,char ** argv)410 int main(int argc, char **argv)
411 {
412 struct psp_dev_get_list *dev_list;
413 __u32 ver_ena, ver_cap;
414 struct opts opts = {};
415 struct ynl_error yerr;
416 struct ynl_sock *ys;
417 int devid = -1;
418 int ret;
419
420 parse_cmd_opts(argc, argv, &opts);
421 if (!opts.port)
422 usage(argv[0], "port"); // exits
423
424 ys = ynl_sock_create(&ynl_psp_family, &yerr);
425 if (!ys) {
426 fprintf(stderr, "YNL: %s\n", yerr.msg);
427 return 1;
428 }
429
430 dev_list = psp_dev_get_dump(ys);
431 if (ynl_dump_empty(dev_list) && ys->err.code)
432 goto err_close;
433
434 ynl_dump_foreach(dev_list, d) {
435 if (opts.ifindex) {
436 if (d->ifindex != opts.ifindex)
437 continue;
438 devid = d->id;
439 ver_ena = d->psp_versions_ena;
440 ver_cap = d->psp_versions_cap;
441 break;
442 } else if (devid < 0) {
443 devid = d->id;
444 ver_ena = d->psp_versions_ena;
445 ver_cap = d->psp_versions_cap;
446 } else {
447 fprintf(stderr, "Multiple PSP devices found\n");
448 goto err_close_silent;
449 }
450 }
451 psp_dev_get_list_free(dev_list);
452
453 if (opts.ifindex && devid < 0)
454 fprintf(stderr,
455 "WARN: PSP device with ifindex %d requested on cmdline, not found\n",
456 opts.ifindex);
457
458 if (devid >= 0 && ver_ena != ver_cap) {
459 ret = psp_dev_set_ena(ys, devid, ver_cap);
460 if (ret)
461 goto err_close;
462 }
463
464 ret = run_responder(ys, &opts);
465
466 if (devid >= 0 && ver_ena != ver_cap &&
467 psp_dev_set_ena(ys, devid, ver_ena))
468 fprintf(stderr, "WARN: failed to set the PSP versions back\n");
469
470 ynl_sock_destroy(ys);
471
472 return ret;
473
474 err_close:
475 fprintf(stderr, "YNL: %s\n", ys->err.msg);
476 err_close_silent:
477 ynl_sock_destroy(ys);
478 return 2;
479 }
480