1 // SPDX-License-Identifier: GPL-2.0
2 /* Author: Dmitry Safonov <dima@arista.com> */
3 #include <inttypes.h>
4 #include "aolib.h"
5
/*
 * Loopback address of TEST_FAMILY: filled in by __setup_lo_intf() via
 * inet_pton() and then used by tcp_self_connect() as both the bind/connect
 * address and the TCP-AO key peer address.
 */
static union tcp_addr local_addr;
7
__setup_lo_intf(const char * lo_intf,const char * addr_str,uint8_t prefix)8 static void __setup_lo_intf(const char *lo_intf,
9 const char *addr_str, uint8_t prefix)
10 {
11 if (inet_pton(TEST_FAMILY, addr_str, &local_addr) != 1)
12 test_error("Can't convert local ip address");
13
14 if (ip_addr_add(lo_intf, TEST_FAMILY, local_addr, prefix))
15 test_error("Failed to add %s ip address", lo_intf);
16
17 if (link_set_up(lo_intf))
18 test_error("Failed to bring %s up", lo_intf);
19 }
20
setup_lo_intf(const char * lo_intf)21 static void setup_lo_intf(const char *lo_intf)
22 {
23 #ifdef IPV6_TEST
24 __setup_lo_intf(lo_intf, "::1", 128);
25 #else
26 __setup_lo_intf(lo_intf, "127.0.0.1", 8);
27 #endif
28 }
29
tcp_self_connect(const char * tst,unsigned int port,bool different_keyids,bool check_restore)30 static void tcp_self_connect(const char *tst, unsigned int port,
31 bool different_keyids, bool check_restore)
32 {
33 uint64_t before_challenge_ack, after_challenge_ack;
34 uint64_t before_syn_challenge, after_syn_challenge;
35 struct tcp_ao_counters before_ao, after_ao;
36 uint64_t before_aogood, after_aogood;
37 struct netstat *ns_before, *ns_after;
38 const size_t nr_packets = 20;
39 struct tcp_ao_repair ao_img;
40 struct tcp_sock_state img;
41 sockaddr_af addr;
42 int sk;
43
44 tcp_addr_to_sockaddr_in(&addr, &local_addr, htons(port));
45
46 sk = socket(test_family, SOCK_STREAM, IPPROTO_TCP);
47 if (sk < 0)
48 test_error("socket()");
49
50 if (different_keyids) {
51 if (test_add_key(sk, DEFAULT_TEST_PASSWORD, local_addr, -1, 5, 7))
52 test_error("setsockopt(TCP_AO_ADD_KEY)");
53 if (test_add_key(sk, DEFAULT_TEST_PASSWORD, local_addr, -1, 7, 5))
54 test_error("setsockopt(TCP_AO_ADD_KEY)");
55 } else {
56 if (test_add_key(sk, DEFAULT_TEST_PASSWORD, local_addr, -1, 100, 100))
57 test_error("setsockopt(TCP_AO_ADD_KEY)");
58 }
59
60 if (bind(sk, (struct sockaddr *)&addr, sizeof(addr)) < 0)
61 test_error("bind()");
62
63 ns_before = netstat_read();
64 before_aogood = netstat_get(ns_before, "TCPAOGood", NULL);
65 before_challenge_ack = netstat_get(ns_before, "TCPChallengeACK", NULL);
66 before_syn_challenge = netstat_get(ns_before, "TCPSYNChallenge", NULL);
67 if (test_get_tcp_ao_counters(sk, &before_ao))
68 test_error("test_get_tcp_ao_counters()");
69
70 if (__test_connect_socket(sk, "lo", (struct sockaddr *)&addr,
71 sizeof(addr), TEST_TIMEOUT_SEC) < 0) {
72 ns_after = netstat_read();
73 netstat_print_diff(ns_before, ns_after);
74 test_error("failed to connect()");
75 }
76
77 if (test_client_verify(sk, 100, nr_packets, TEST_TIMEOUT_SEC)) {
78 test_fail("%s: tcp connection verify failed", tst);
79 close(sk);
80 return;
81 }
82
83 ns_after = netstat_read();
84 after_aogood = netstat_get(ns_after, "TCPAOGood", NULL);
85 after_challenge_ack = netstat_get(ns_after, "TCPChallengeACK", NULL);
86 after_syn_challenge = netstat_get(ns_after, "TCPSYNChallenge", NULL);
87 if (test_get_tcp_ao_counters(sk, &after_ao))
88 test_error("test_get_tcp_ao_counters()");
89 if (!check_restore) {
90 /* to debug: netstat_print_diff(ns_before, ns_after); */
91 netstat_free(ns_before);
92 }
93 netstat_free(ns_after);
94
95 if (after_aogood <= before_aogood) {
96 test_fail("%s: TCPAOGood counter mismatch: %zu <= %zu",
97 tst, after_aogood, before_aogood);
98 close(sk);
99 return;
100 }
101 if (after_challenge_ack <= before_challenge_ack ||
102 after_syn_challenge <= before_syn_challenge) {
103 /*
104 * It's also meant to test simultaneous open, so check
105 * these counters as well.
106 */
107 test_fail("%s: Didn't challenge SYN or ACK: %zu <= %zu OR %zu <= %zu",
108 tst, after_challenge_ack, before_challenge_ack,
109 after_syn_challenge, before_syn_challenge);
110 close(sk);
111 return;
112 }
113
114 if (test_tcp_ao_counters_cmp(tst, &before_ao, &after_ao, TEST_CNT_GOOD)) {
115 close(sk);
116 return;
117 }
118
119 if (!check_restore) {
120 test_ok("%s: connect TCPAOGood %" PRIu64 " => %" PRIu64,
121 tst, before_aogood, after_aogood);
122 close(sk);
123 return;
124 }
125
126 test_enable_repair(sk);
127 test_sock_checkpoint(sk, &img, &addr);
128 #ifdef IPV6_TEST
129 addr.sin6_port = htons(port + 1);
130 #else
131 addr.sin_port = htons(port + 1);
132 #endif
133 test_ao_checkpoint(sk, &ao_img);
134 test_kill_sk(sk);
135
136 sk = socket(test_family, SOCK_STREAM, IPPROTO_TCP);
137 if (sk < 0)
138 test_error("socket()");
139
140 test_enable_repair(sk);
141 __test_sock_restore(sk, "lo", &img, &addr, &addr, sizeof(addr));
142 if (different_keyids) {
143 if (test_add_repaired_key(sk, DEFAULT_TEST_PASSWORD, 0,
144 local_addr, -1, 7, 5))
145 test_error("setsockopt(TCP_AO_ADD_KEY)");
146 if (test_add_repaired_key(sk, DEFAULT_TEST_PASSWORD, 0,
147 local_addr, -1, 5, 7))
148 test_error("setsockopt(TCP_AO_ADD_KEY)");
149 } else {
150 if (test_add_repaired_key(sk, DEFAULT_TEST_PASSWORD, 0,
151 local_addr, -1, 100, 100))
152 test_error("setsockopt(TCP_AO_ADD_KEY)");
153 }
154 test_ao_restore(sk, &ao_img);
155 test_disable_repair(sk);
156 test_sock_state_free(&img);
157 if (test_client_verify(sk, 100, nr_packets, TEST_TIMEOUT_SEC)) {
158 test_fail("%s: tcp connection verify failed", tst);
159 close(sk);
160 return;
161 }
162 ns_after = netstat_read();
163 after_aogood = netstat_get(ns_after, "TCPAOGood", NULL);
164 /* to debug: netstat_print_diff(ns_before, ns_after); */
165 netstat_free(ns_before);
166 netstat_free(ns_after);
167 close(sk);
168 if (after_aogood <= before_aogood) {
169 test_fail("%s: TCPAOGood counter mismatch: %zu <= %zu",
170 tst, after_aogood, before_aogood);
171 return;
172 }
173 test_ok("%s: connect TCPAOGood %" PRIu64 " => %" PRIu64,
174 tst, before_aogood, after_aogood);
175 }
176
client_fn(void * arg)177 static void *client_fn(void *arg)
178 {
179 unsigned int port = test_server_port;
180
181 setup_lo_intf("lo");
182
183 tcp_self_connect("self-connect(same keyids)", port++, false, false);
184 tcp_self_connect("self-connect(different keyids)", port++, true, false);
185 tcp_self_connect("self-connect(restore)", port, false, true);
186 port += 2;
187 tcp_self_connect("self-connect(restore, different keyids)", port, true, true);
188 port += 2;
189
190 return NULL;
191 }
192
main(int argc,char * argv[])193 int main(int argc, char *argv[])
194 {
195 test_init(4, client_fn, NULL);
196 return 0;
197 }
198