xref: /src/tools/tools/so_splice/proxy.c (revision a0069f16fe235068fac8fac8d11aee4c5dac6e8c)
1 /*
2  * SPDX-License-Identifier: BSD-2-Clause
3  *
4  * Copyright (c) 2024 Klara, Inc.
5  */
6 
7 /*
8  * A simple TCP proxy.  Listens on a local address until a connection appears,
9  * then opens a TCP connection to the target address and shuttles data between
10  * the two until one side closes its connection.
11  *
12  * For example:
13  *
14  *   $ proxy -l 127.0.0.1:8080 www.example.com:80
15  *
16  * The -m flag selects the mode of the transfer.  Specify "-m copy" to enable
17  * copying through userspace, and "-m splice" to use SO_SPLICE.
18  *
19  * The -L flag enables a loopback mode, wherein all data is additionally proxied
20  * through a loopback TCP connection.  This exists mostly to help test a
21  * specific use-case in custom proxy software where we would like to use
22  * SO_SPLICE.
23  */
24 
25 #include <sys/types.h>
26 #include <sys/event.h>
27 #include <sys/socket.h>
28 #include <sys/wait.h>
29 
30 #include <netinet/in.h>
31 #include <netinet/tcp.h>
32 #include <arpa/inet.h>
33 
34 #include <assert.h>
35 #include <err.h>
36 #include <errno.h>
37 #include <netdb.h>
38 #include <signal.h>
39 #include <stdbool.h>
40 #include <stdio.h>
41 #include <stdlib.h>
42 #include <string.h>
43 #include <unistd.h>
44 
/* How data is moved between the two sides of a proxied connection. */
enum proxy_mode {
	PROXY_MODE_COPY,	/* read(2)/write(2) through a userspace buffer */
	PROXY_MODE_SPLICE	/* SO_SPLICE on the two sockets */
};

/* Run-time configuration, filled in once from the command line. */
struct proxy_softc {
	struct sockaddr_storage	lss;	/* address to listen on */
	struct sockaddr_storage	tss;	/* address of the proxy target */
	size_t bufsz;			/* copy buffer size, in bytes */
	enum proxy_mode mode;		/* selected transfer mode */
	bool loopback;			/* add a loopback TCP hop */
};
52 
/*
 * Print the command-line synopsis to stderr and exit with status 1.
 * This function does not return.
 */
static void
usage(void)
{
	static const char synopsis[] =
"usage: proxy [-m copy|splice] [-s <buf-size>] [-L] -l <listen addr> <target addr>\n";

	(void)fputs(synopsis, stderr);
	exit(1);
}
60 
/*
 * Shuttle data between the sockets cs and ts through a userspace buffer of
 * sc->bufsz bytes.
 *
 * Both sockets are registered for EVFILT_READ with the peer descriptor
 * stashed in udata, so a single event tells us where to read from and where
 * to write to.  The function returns once either side reaches EOF or the
 * connection is reset, whether that is noticed while reading or while
 * writing.  Any other failure is fatal.
 */
static void
proxy_copy(struct proxy_softc *sc, int cs, int ts)
{
	struct kevent kev[2];
	uint8_t *buf;
	int kq;

	kq = kqueue();
	if (kq == -1)
		err(1, "kqueue");

	EV_SET(&kev[0], cs, EVFILT_READ, EV_ADD, 0, 0, (void *)(uintptr_t)ts);
	EV_SET(&kev[1], ts, EVFILT_READ, EV_ADD, 0, 0, (void *)(uintptr_t)cs);
	if (kevent(kq, kev, 2, NULL, 0, NULL) == -1)
		err(1, "kevent");

	buf = malloc(sc->bufsz);
	if (buf == NULL)
		err(1, "malloc");

	for (;;) {
		uint8_t *data;
		ssize_t n, resid;
		int rs, ws;

		if (kevent(kq, NULL, 0, kev, 2, NULL) == -1) {
			if (errno == EINTR)
				continue;
			err(1, "kevent");
		}

		/*
		 * Only the first returned event is serviced per iteration.
		 * The filters were added without EV_CLEAR, so a second ready
		 * socket is simply reported again by the next kevent() call.
		 */
		rs = (int)kev[0].ident;
		ws = (int)(uintptr_t)kev[0].udata;

		n = read(rs, buf, sc->bufsz);
		if (n == -1) {
			/* Treat an interrupted read like the other calls. */
			if (errno == EINTR)
				continue;
			if (errno == ECONNRESET)
				break;
			err(1, "read");
		}
		if (n == 0)
			break;

		/* Push out everything we read; write(2) may be partial. */
		data = buf;
		resid = n;
		while (resid > 0) {
			n = write(ws, data, resid);
			if (n == -1) {
				if (errno == EINTR)
					continue;
				/*
				 * The peer is gone, so there is nobody left
				 * to proxy for: stop the whole transfer
				 * rather than just this write.
				 */
				if (errno == ECONNRESET || errno == EPIPE)
					goto out;
				err(1, "write");
			}
			assert(n > 0);
			data += n;
			resid -= n;
		}
	}

out:
	free(buf);
	close(kq);
}
124 
125 static void
splice(int s1,int s2)126 splice(int s1, int s2)
127 {
128 	struct splice sp;
129 
130 	memset(&sp, 0, sizeof(sp));
131 	sp.sp_fd = s2;
132 	if (setsockopt(s1, SOL_SOCKET, SO_SPLICE, &sp, sizeof(sp)) == -1)
133 		err(1, "setsockopt");
134 }
135 
/*
 * Proxy between cs and ts by splicing the sockets to each other, then sleep
 * until one of them reports a read event.  The softc is unused in this mode.
 */
static void
proxy_splice(struct proxy_softc *sc __unused, int cs, int ts)
{
	struct kevent kev[2];
	int kq, nev;

	/* One splice per direction. */
	splice(cs, ts);
	splice(ts, cs);

	kq = kqueue();
	if (kq == -1)
		err(1, "kqueue");

	/*
	 * Block until the connection is terminated.  Registration and the
	 * wait share one kevent() call, with the same array serving as both
	 * the change list and the event list; an interrupted call is simply
	 * reissued.
	 */
	EV_SET(&kev[0], cs, EVFILT_READ, EV_ADD, 0, 0, NULL);
	EV_SET(&kev[1], ts, EVFILT_READ, EV_ADD, 0, 0, NULL);
	for (;;) {
		nev = kevent(kq, kev, 2, kev, 2, NULL);
		if (nev > 0)
			break;
		if (nev == -1 && errno != EINTR)
			err(1, "kevent");
	}

	close(kq);
}
160 
/*
 * Enable TCP_NODELAY on socket s.  Exits on failure.
 */
static void
nodelay(int s)
{
	int on = 1;

	if (setsockopt(s, IPPROTO_TCP, TCP_NODELAY, &on, sizeof(on)) == -1)
		err(1, "setsockopt");
}
168 
169 /*
170  * Like socketpair(2), but for TCP sockets on the  loopback address.
171  */
172 static void
tcp_socketpair(int out[2],int af)173 tcp_socketpair(int out[2], int af)
174 {
175 	struct sockaddr_in sin;
176 	struct sockaddr_in6 sin6;
177 	struct sockaddr *sa;
178 	int sd[2];
179 
180 	sd[0] = socket(af, SOCK_STREAM, 0);
181 	if (sd[0] == -1)
182 		err(1, "socket");
183 	sd[1] = socket(af, SOCK_STREAM, 0);
184 	if (sd[1] == -1)
185 		err(1, "socket");
186 
187 	nodelay(sd[0]);
188 	nodelay(sd[1]);
189 
190 	if (af == AF_INET) {
191 		memset(&sin, 0, sizeof(sin));
192 		sin.sin_family = AF_INET;
193 		sin.sin_addr.s_addr = htonl(INADDR_LOOPBACK);
194 		sin.sin_port = 0;
195 		sin.sin_len = sizeof(sin);
196 		sa = (struct sockaddr *)&sin;
197 	} else if (af == AF_INET6) {
198 		memset(&sin6, 0, sizeof(sin6));
199 		sin6.sin6_family = AF_INET6;
200 		sin6.sin6_addr = in6addr_loopback;
201 		sin6.sin6_port = 0;
202 		sin6.sin6_len = sizeof(sin6);
203 		sa = (struct sockaddr *)&sin6;
204 	} else {
205 		errx(1, "unsupported address family %d", af);
206 	}
207 
208 	if (bind(sd[0], sa, sa->sa_len) == -1)
209 		err(1, "bind");
210 	if (listen(sd[0], 1) == -1)
211 		err(1, "listen");
212 
213 	if (getsockname(sd[0], sa, &(socklen_t){sa->sa_len}) == -1)
214 		err(1, "getsockname");
215 	if (connect(sd[1], sa, sa->sa_len) == -1)
216 		err(1, "connect");
217 
218 	out[0] = sd[1];
219 	out[1] = accept(sd[0], NULL, NULL);
220 	if (out[1] == -1)
221 		err(1, "accept");
222 	close(sd[0]);
223 }
224 
225 /*
226  * Proxy data between two connected TCP sockets.  Returns the PID of the process
227  * forked off to handle the data transfer.
228  */
229 static pid_t
proxy(struct proxy_softc * sc,int s1,int s2)230 proxy(struct proxy_softc *sc, int s1, int s2)
231 {
232 	pid_t child;
233 
234 	child = fork();
235 	if (child == -1)
236 		err(1, "fork");
237 	if (child != 0) {
238 		close(s1);
239 		close(s2);
240 		return (child);
241 	}
242 
243 	if (sc->mode == PROXY_MODE_COPY)
244 		proxy_copy(sc, s1, s2);
245 	else
246 		proxy_splice(sc, s1, s2);
247 	_exit(0);
248 }
249 
/*
 * Ask the kqueue to report when the child process pid exits.
 */
static void
watch_child(int kq, pid_t pid)
{
	struct kevent kev;

	EV_SET(&kev, pid, EVFILT_PROC, EV_ADD, NOTE_EXIT, 0, NULL);
	if (kevent(kq, &kev, 1, NULL, 0, NULL) == -1)
		err(1, "kevent");
}

/*
 * Accept one client from the listening socket lsd, connect to the target,
 * and fork the child process(es) that move the data.  In loopback mode the
 * path is client <-> loopback pair <-> target, served by two children;
 * otherwise a single child connects client and target directly.
 */
static void
handle_connection(struct proxy_softc *sc, int kq, int lsd)
{
	int cs, ts;

	cs = accept(lsd, NULL, NULL);
	if (cs == -1)
		err(1, "accept");
	nodelay(cs);

	ts = socket(sc->tss.ss_family, SOCK_STREAM, 0);
	if (ts == -1)
		err(1, "socket");
	nodelay(ts);
	if (connect(ts, (struct sockaddr *)&sc->tss, sc->tss.ss_len) == -1)
		err(1, "connect");

	if (!sc->loopback) {
		watch_child(kq, proxy(sc, cs, ts));
		return;
	}

	int ls[2];

	tcp_socketpair(ls, sc->tss.ss_family);
	watch_child(kq, proxy(sc, ls[0], ts));
	watch_child(kq, proxy(sc, cs, ls[1]));
}

/*
 * Process an EVFILT_PROC exit notification: a non-zero exit status is fatal
 * for the whole proxy, death by signal is only reported, and in either
 * non-fatal case the child is reaped.
 */
static void
handle_exit(const struct kevent *kev)
{
	pid_t pid;
	int status;

	pid = kev->ident;
	status = (int)kev->data;
	if (WIFEXITED(status)) {
		if (WEXITSTATUS(status) != 0) {
			errx(1, "child exited with status %d",
			    WEXITSTATUS(status));
		}
	} else if (WIFSIGNALED(status)) {
		warnx("child %d terminated by signal %d",
		    (pid_t)kev->ident, WTERMSIG(status));
	}
	if (waitpid(pid, NULL, 0) == -1)
		err(1, "waitpid");
}

/*
 * The proxy event loop accepts connections and forks off child processes to
 * handle them.  We also handle events generated when child processes exit
 * (triggered by one side closing its connection).
 *
 * The listening socket is created from sc->lss with SO_REUSEADDR set.  This
 * function never returns.
 */
static void
eventloop(struct proxy_softc *sc)
{
	struct kevent kev;
	int kq, lsd, on;

	lsd = socket(sc->lss.ss_family, SOCK_STREAM, 0);
	if (lsd == -1)
		err(1, "socket");
	on = 1;
	if (setsockopt(lsd, SOL_SOCKET, SO_REUSEADDR, &on, sizeof(on)) == -1)
		err(1, "setsockopt");
	if (bind(lsd, (struct sockaddr *)&sc->lss, sc->lss.ss_len) == -1)
		err(1, "bind");
	if (listen(lsd, 5) == -1)
		err(1, "listen");

	kq = kqueue();
	if (kq == -1)
		err(1, "kqueue");
	EV_SET(&kev, lsd, EVFILT_READ, EV_ADD, 0, 0, NULL);
	if (kevent(kq, &kev, 1, NULL, 0, NULL) == -1)
		err(1, "kevent");

	for (;;) {
		if (kevent(kq, NULL, 0, &kev, 1, NULL) == -1) {
			if (errno == EINTR)
				continue;
			err(1, "kevent");
		}

		if (kev.filter == EVFILT_READ) {
			/* The listener is the only socket we registered. */
			if ((int)kev.ident != lsd)
				errx(1, "unexpected event ident %d",
				    (int)kev.ident);
			handle_connection(sc, kq, lsd);
		} else if (kev.filter == EVFILT_PROC) {
			handle_exit(&kev);
		}
	}
}
353 
/*
 * Resolve an "<addr>:<port>" string to a TCP socket address and store it in
 * *ss.
 *
 * The string is split at the last colon so that IPv6 literals, which contain
 * colons themselves, can be given either bare ("::1:8080") or in the usual
 * bracketed form ("[::1]:8080").  The first result with protocol TCP is
 * used.  *ss is zeroed before the address is copied in, so the bytes beyond
 * the address are well defined.  Any parse or lookup failure is fatal.
 */
static void
addrinfo(struct sockaddr_storage *ss, const char *addr)
{
	struct addrinfo hints, *res, *res0;
	char *buf, *host, *port;
	size_t hostlen;
	int error;

	buf = strdup(addr);
	if (buf == NULL)
		err(1, "strdup");
	port = strrchr(buf, ':');
	if (port == NULL)
		errx(1, "invalid address '%s', should be <addr>:<port>", buf);
	*port++ = '\0';

	/* Strip the brackets from "[v6-literal]". */
	host = buf;
	hostlen = strlen(host);
	if (hostlen >= 2 && host[0] == '[' && host[hostlen - 1] == ']') {
		host[hostlen - 1] = '\0';
		host++;
	}

	memset(&hints, 0, sizeof(hints));
	hints.ai_socktype = SOCK_STREAM;
	error = getaddrinfo(host, port, &hints, &res0);
	if (error != 0)
		errx(1, "%s", gai_strerror(error));
	for (res = res0; res != NULL; res = res->ai_next) {
		if (res->ai_protocol == IPPROTO_TCP)
			break;
	}
	if (res == NULL)
		errx(1, "no TCP address found for '%s'", host);

	memset(ss, 0, sizeof(*ss));
	memcpy(ss, res->ai_addr, res->ai_addrlen);

	free(buf);
	freeaddrinfo(res0);
}
385 
386 static void
proxy_init(struct proxy_softc * sc,const char * laddr,const char * taddr,size_t bufsz,enum proxy_mode mode,bool loopback)387 proxy_init(struct proxy_softc *sc, const char *laddr, const char *taddr,
388     size_t bufsz, enum proxy_mode mode, bool loopback)
389 {
390 	addrinfo(&sc->lss, laddr);
391 	addrinfo(&sc->tss, taddr);
392 
393 	sc->bufsz = bufsz;
394 	sc->mode = mode;
395 	sc->loopback = loopback;
396 }
397 
398 int
main(int argc,char ** argv)399 main(int argc, char **argv)
400 {
401 	struct proxy_softc sc;
402 	char *laddr, *taddr;
403 	size_t bufsz;
404 	enum proxy_mode mode;
405 	int ch;
406 	bool loopback;
407 
408 	(void)signal(SIGPIPE, SIG_IGN);
409 
410 	loopback = false;
411 	mode = PROXY_MODE_COPY;
412 	bufsz = 2 * 1024 * 1024ul;
413 	laddr = taddr = NULL;
414 	while ((ch = getopt(argc, argv, "Ll:m:s:")) != -1) {
415 		switch (ch) {
416 		case 'l':
417 			laddr = optarg;
418 			break;
419 		case 'L':
420 			loopback = true;
421 			break;
422 		case 'm':
423 			if (strcmp(optarg, "copy") == 0)
424 				mode = PROXY_MODE_COPY;
425 			else if (strcmp(optarg, "splice") == 0)
426 				mode = PROXY_MODE_SPLICE;
427 			else
428 				usage();
429 			break;
430 		case 's':
431 			bufsz = atoi(optarg);
432 			break;
433 		default:
434 			usage();
435 		}
436 	}
437 	argc -= optind;
438 	argv += optind;
439 
440 	if (laddr == NULL || argc != 1)
441 		usage();
442 	taddr = argv[0];
443 
444 	/* Marshal command-line parameters into a neat structure. */
445 	proxy_init(&sc, laddr, taddr, bufsz, mode, loopback);
446 
447 	/* Start handling connections. */
448 	eventloop(&sc);
449 
450 	return (0);
451 }
452