xref: /linux/fs/smb/common/smbdirect/smbdirect_connection.c (revision 3cd8b194bf3428dfa53120fee47e827a7c495815)
1 // SPDX-License-Identifier: GPL-2.0-or-later
2 /*
3  *   Copyright (C) 2017, Microsoft Corporation.
4  *   Copyright (c) 2025, Stefan Metzmacher
5  */
6 
7 #include "smbdirect_internal.h"
8 #include <linux/folio_queue.h>
9 
10 struct smbdirect_map_sges {
11 	struct ib_sge *sge;
12 	size_t num_sge;
13 	size_t max_sge;
14 	struct ib_device *device;
15 	u32 local_dma_lkey;
16 	enum dma_data_direction direction;
17 };
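
/*
 * A hedged sketch of how this mapping state is used (see the send
 * path below): the caller seeds .sge/.max_sge/.num_sge with any
 * already mapped header SGE, then smbdirect_map_sges_from_iter()
 * appends one DMA-mapped ib_sge per extracted fragment and bumps
 * .num_sge, e.g.:
 *
 *	struct smbdirect_map_sges extract = {
 *		.num_sge	= msg->num_sge,
 *		.max_sge	= ARRAY_SIZE(msg->sge),
 *		.sge		= msg->sge,
 *		.device		= sc->ib.dev,
 *		.local_dma_lkey	= sc->ib.pd->local_dma_lkey,
 *		.direction	= DMA_TO_DEVICE,
 *	};
 */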
18 
19 static ssize_t smbdirect_map_sges_from_iter(struct iov_iter *iter, size_t len,
20 					    struct smbdirect_map_sges *state);
21 
22 static void smbdirect_connection_recv_io_refill_work(struct work_struct *work);
23 static void smbdirect_connection_send_immediate_work(struct work_struct *work);
24 
25 static void smbdirect_connection_qp_event_handler(struct ib_event *event, void *context)
26 {
27 	struct smbdirect_socket *sc = context;
28 
29 	smbdirect_log_rdma_event(sc, SMBDIRECT_LOG_ERR,
30 		"%s on device %.*s socket %p (cm_id=%p) status %s first_error %1pe\n",
31 		ib_event_msg(event->event),
32 		IB_DEVICE_NAME_MAX,
33 		event->device->name,
34 		sc, sc->rdma.cm_id,
35 		smbdirect_socket_status_string(sc->status),
36 		SMBDIRECT_DEBUG_ERR_PTR(sc->first_error));
37 
38 	switch (event->event) {
39 	case IB_EVENT_CQ_ERR:
40 	case IB_EVENT_QP_FATAL:
41 		smbdirect_socket_schedule_cleanup(sc, -ECONNABORTED);
42 		break;
43 
44 	default:
45 		break;
46 	}
47 }
48 
49 static int smbdirect_connection_rdma_event_handler(struct rdma_cm_id *id,
50 						   struct rdma_cm_event *event)
51 {
52 	struct smbdirect_socket *sc = id->context;
53 	int ret = -ECONNRESET;
54 
55 	if (event->event == RDMA_CM_EVENT_DEVICE_REMOVAL)
56 		ret = -ENETDOWN;
57 	if (IS_ERR(SMBDIRECT_DEBUG_ERR_PTR(event->status)))
58 		ret = event->status;
59 
60 	/*
61 	 * cma_cm_event_handler() has
62 	 * lockdep_assert_held(&id_priv->handler_mutex);
63 	 *
64 	 * Mutexes are not allowed in interrupts,
65 	 * and we rely on not being in an interrupt here.
66 	 */
67 	WARN_ON_ONCE(in_interrupt());
68 
69 	if (event->event != sc->rdma.expected_event) {
70 		smbdirect_log_rdma_event(sc, SMBDIRECT_LOG_ERR,
71 			"%s (first_error=%1pe, expected=%s) => event=%s status=%d => ret=%1pe\n",
72 			smbdirect_socket_status_string(sc->status),
73 			SMBDIRECT_DEBUG_ERR_PTR(sc->first_error),
74 			rdma_event_msg(sc->rdma.expected_event),
75 			rdma_event_msg(event->event),
76 			event->status,
77 			SMBDIRECT_DEBUG_ERR_PTR(ret));
78 
79 		/*
80 		 * If we get RDMA_CM_EVENT_DEVICE_REMOVAL,
81 		 * we should change to SMBDIRECT_SOCKET_DISCONNECTED,
82 		 * so that rdma_disconnect() is avoided later via
83 		 * smbdirect_socket_schedule_cleanup[_status]() =>
84 		 * smbdirect_socket_cleanup_work().
85 		 *
86 		 * Otherwise we'd set SMBDIRECT_SOCKET_DISCONNECTING,
87 		 * but never get RDMA_CM_EVENT_DISCONNECTED and
88 		 * never reach SMBDIRECT_SOCKET_DISCONNECTED.
89 		 */
90 		if (event->event == RDMA_CM_EVENT_DEVICE_REMOVAL)
91 			smbdirect_socket_schedule_cleanup_status(sc,
92 								 SMBDIRECT_LOG_ERR,
93 								 ret,
94 								 SMBDIRECT_SOCKET_DISCONNECTED);
95 		else
96 			smbdirect_socket_schedule_cleanup(sc, ret);
97 		if (sc->ib.qp)
98 			ib_drain_qp(sc->ib.qp);
99 		return 0;
100 	}
101 
102 	smbdirect_log_rdma_event(sc, SMBDIRECT_LOG_INFO,
103 		"%s (first_error=%1pe) event=%s\n",
104 		smbdirect_socket_status_string(sc->status),
105 		SMBDIRECT_DEBUG_ERR_PTR(sc->first_error),
106 		rdma_event_msg(event->event));
107 
108 	switch (event->event) {
109 	case RDMA_CM_EVENT_DISCONNECTED:
110 		/*
111 		 * We need to change to SMBDIRECT_SOCKET_DISCONNECTED,
112 		 * so that rdma_disconnect() is avoided later via
113 		 * smbdirect_socket_schedule_cleanup_status() =>
114 		 * smbdirect_socket_cleanup_work().
115 		 *
116 		 * Otherwise we'd set SMBDIRECT_SOCKET_DISCONNECTING,
117 		 * but never get RDMA_CM_EVENT_DISCONNECTED and
118 		 * never reach SMBDIRECT_SOCKET_DISCONNECTED.
119 		 *
120 		 * This is also a normal disconnect so
121 		 * SMBDIRECT_LOG_INFO should be good enough
122 		 * and avoids spamming the default logs.
123 		 */
124 		smbdirect_socket_schedule_cleanup_status(sc,
125 							 SMBDIRECT_LOG_INFO,
126 							 ret,
127 							 SMBDIRECT_SOCKET_DISCONNECTED);
128 		if (sc->ib.qp)
129 			ib_drain_qp(sc->ib.qp);
130 		return 0;
131 
132 	default:
133 		break;
134 	}
135 
136 	/*
137 	 * This is an internal error, should be handled above via
138 	 * event->event != sc->rdma.expected_event already.
139 	 */
140 	WARN_ON_ONCE(sc->rdma.expected_event != RDMA_CM_EVENT_DISCONNECTED);
141 	smbdirect_socket_schedule_cleanup(sc, -ECONNABORTED);
142 	return 0;
143 }
144 
145 void smbdirect_connection_rdma_established(struct smbdirect_socket *sc)
146 {
147 	smbdirect_log_rdma_event(sc, SMBDIRECT_LOG_INFO,
148 		"rdma established: device: %.*s local: %pISpsfc remote: %pISpsfc\n",
149 		IB_DEVICE_NAME_MAX,
150 		sc->ib.dev->name,
151 		&sc->rdma.cm_id->route.addr.src_addr,
152 		&sc->rdma.cm_id->route.addr.dst_addr);
153 
154 	sc->rdma.cm_id->event_handler = smbdirect_connection_rdma_event_handler;
155 	sc->rdma.expected_event = RDMA_CM_EVENT_DISCONNECTED;
156 }
157 
158 void smbdirect_connection_negotiation_done(struct smbdirect_socket *sc)
159 {
160 	if (unlikely(sc->first_error))
161 		return;
162 
163 	if (sc->status == SMBDIRECT_SOCKET_CONNECTED)
164 		/*
165 		 * This is the accept case where
166 		 * smbdirect_socket_accept() already sets
167 		 * SMBDIRECT_SOCKET_CONNECTED
168 		 */
169 		goto done;
170 
171 	if (sc->status != SMBDIRECT_SOCKET_NEGOTIATE_RUNNING) {
172 		/*
173 		 * Something went wrong...
174 		 */
175 		smbdirect_log_rdma_event(sc, SMBDIRECT_LOG_ERR,
176 			"status=%s first_error=%1pe local: %pISpsfc remote: %pISpsfc\n",
177 			smbdirect_socket_status_string(sc->status),
178 			SMBDIRECT_DEBUG_ERR_PTR(sc->first_error),
179 			&sc->rdma.cm_id->route.addr.src_addr,
180 			&sc->rdma.cm_id->route.addr.dst_addr);
181 		return;
182 	}
183 
184 	/*
185 	 * We are done, so we can wake up the waiter.
186 	 */
187 	WARN_ONCE(sc->status == SMBDIRECT_SOCKET_CONNECTED,
188 		  "status=%s first_error=%1pe",
189 		  smbdirect_socket_status_string(sc->status),
190 		  SMBDIRECT_DEBUG_ERR_PTR(sc->first_error));
191 	sc->status = SMBDIRECT_SOCKET_CONNECTED;
192 
193 	/*
194 	 * We need to setup the refill and send immediate work
195 	 * in order to get a working connection.
196 	 */
197 done:
198 	INIT_WORK(&sc->recv_io.posted.refill_work, smbdirect_connection_recv_io_refill_work);
199 	INIT_WORK(&sc->idle.immediate_work, smbdirect_connection_send_immediate_work);
200 
201 	smbdirect_log_rdma_event(sc, SMBDIRECT_LOG_INFO,
202 		"negotiated: local: %pISpsfc remote: %pISpsfc\n",
203 		&sc->rdma.cm_id->route.addr.src_addr,
204 		&sc->rdma.cm_id->route.addr.dst_addr);
205 
206 	wake_up(&sc->status_wait);
207 }
208 
209 static u32 smbdirect_rdma_rw_send_wrs(struct ib_device *dev,
210 				      const struct ib_qp_init_attr *attr)
211 {
212 	/*
213 	 * This could be split out of rdma_rw_init_qp()
214 	 * and be a helper function next to rdma_rw_mr_factor()
215 	 *
216 	 * We can't check unlikely(rdma_rw_force_mr) here,
217 	 * but that is most likely 0 anyway.
218 	 */
219 	u32 factor;
220 
221 	WARN_ON_ONCE(attr->port_num == 0);
222 
223 	/*
224 	 * Each context needs at least one RDMA READ or WRITE WR.
225 	 *
226 	 * For some hardware we might need more; eventually we should ask the
227 	 * HCA driver for a multiplier here.
228 	 */
229 	factor = 1;
230 
231 	/*
232 	 * If the device needs MRs to perform RDMA READ or WRITE operations,
233 	 * we'll need two additional MRs for the registrations and the
234 	 * invalidation.
235 	 */
236 	if (rdma_protocol_iwarp(dev, attr->port_num) || dev->attrs.max_sgl_rd)
237 		factor += 2;	/* inv + reg */
238 
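	/*
	 * Worked example (hypothetical values): on an iWarp device
	 * with attr->cap.max_rdma_ctxs = 16 this returns
	 * (1 + 2) * 16 = 48 additional send WRs.
	 */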
239 	return factor * attr->cap.max_rdma_ctxs;
240 }
241 
242 int smbdirect_connection_create_qp(struct smbdirect_socket *sc)
243 {
244 	const struct smbdirect_socket_parameters *sp = &sc->parameters;
245 	struct ib_qp_init_attr qp_attr;
246 	struct ib_qp_cap qp_cap;
247 	u32 rdma_send_wr;
248 	u32 max_send_wr;
249 	int ret;
250 
251 	/*
252 	 * Note that {rdma,ib}_create_qp() will call
253 	 * rdma_rw_init_qp() if max_rdma_ctxs is not 0.
254 	 * It will adjust max_send_wr to the required
255 	 * number of additional WRs for the RDMA RW operations.
256 	 * It will cap max_send_wr to the device limit.
257 	 *
258 	 * We allocate sp->responder_resources * 2 MRs
259 	 * and each MR needs WRs for REG and INV, so
260 	 * we use '* 4'.
261 	 *
262 	 * +1 for ib_drain_qp()
263 	 */
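	/*
	 * E.g. with hypothetical values send_credit_target = 255 and
	 * responder_resources = 32 this requests
	 * max_send_wr = 255 + 32 * 4 + 1 = 384, before rdma_rw_init_qp()
	 * adds the RDMA RW WRs estimated below.
	 */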
264 	memset(&qp_cap, 0, sizeof(qp_cap));
265 	qp_cap.max_send_wr = sp->send_credit_target + sp->responder_resources * 4 + 1;
266 	qp_cap.max_recv_wr = sp->recv_credit_max + 1;
267 	qp_cap.max_send_sge = SMBDIRECT_SEND_IO_MAX_SGE;
268 	qp_cap.max_recv_sge = SMBDIRECT_RECV_IO_MAX_SGE;
269 	qp_cap.max_inline_data = 0;
270 	qp_cap.max_rdma_ctxs = sc->rw_io.credits.max;
271 
272 	/*
273 	 * Find out the number of max_send_wr
274 	 * after rdma_rw_init_qp() adjusted it.
275 	 *
276 	 * We only do it on a temporary variable,
277 	 * as rdma_create_qp() will trigger
278 	 * rdma_rw_init_qp() again.
279 	 */
280 	memset(&qp_attr, 0, sizeof(qp_attr));
281 	qp_attr.cap = qp_cap;
282 	qp_attr.port_num = sc->rdma.cm_id->port_num;
283 	rdma_send_wr = smbdirect_rdma_rw_send_wrs(sc->ib.dev, &qp_attr);
284 	max_send_wr = qp_cap.max_send_wr + rdma_send_wr;
285 
286 	if (qp_cap.max_send_wr > sc->ib.dev->attrs.max_cqe ||
287 	    qp_cap.max_send_wr > sc->ib.dev->attrs.max_qp_wr) {
288 		pr_err("Possible CQE overrun: max_send_wr %d\n",
289 		       qp_cap.max_send_wr);
290 		pr_err("device %.*s reporting max_cqe %d max_qp_wr %d\n",
291 		       IB_DEVICE_NAME_MAX,
292 		       sc->ib.dev->name,
293 		       sc->ib.dev->attrs.max_cqe,
294 		       sc->ib.dev->attrs.max_qp_wr);
295 		pr_err("consider lowering send_credit_target = %d\n",
296 		       sp->send_credit_target);
297 		return -EINVAL;
298 	}
299 
300 	if (qp_cap.max_rdma_ctxs &&
301 	    (max_send_wr >= sc->ib.dev->attrs.max_cqe ||
302 	     max_send_wr >= sc->ib.dev->attrs.max_qp_wr)) {
303 		pr_err("Possible CQE overrun: rdma_send_wr %d + max_send_wr %d = %d\n",
304 		       rdma_send_wr, qp_cap.max_send_wr, max_send_wr);
305 		pr_err("device %.*s reporting max_cqe %d max_qp_wr %d\n",
306 		       IB_DEVICE_NAME_MAX,
307 		       sc->ib.dev->name,
308 		       sc->ib.dev->attrs.max_cqe,
309 		       sc->ib.dev->attrs.max_qp_wr);
310 		pr_err("consider lowering send_credit_target = %d, max_rdma_ctxs = %d\n",
311 		       sp->send_credit_target, qp_cap.max_rdma_ctxs);
312 		return -EINVAL;
313 	}
314 
315 	if (qp_cap.max_recv_wr > sc->ib.dev->attrs.max_cqe ||
316 	    qp_cap.max_recv_wr > sc->ib.dev->attrs.max_qp_wr) {
317 		pr_err("Possible CQE overrun: max_recv_wr %d\n",
318 		       qp_cap.max_recv_wr);
319 		pr_err("device %.*s reporting max_cqe %d max_qp_wr %d\n",
320 		       IB_DEVICE_NAME_MAX,
321 		       sc->ib.dev->name,
322 		       sc->ib.dev->attrs.max_cqe,
323 		       sc->ib.dev->attrs.max_qp_wr);
324 		pr_err("consider lowering receive_credit_max = %d\n",
325 		       sp->recv_credit_max);
326 		return -EINVAL;
327 	}
328 
329 	if (qp_cap.max_send_sge > sc->ib.dev->attrs.max_send_sge ||
330 	    qp_cap.max_recv_sge > sc->ib.dev->attrs.max_recv_sge) {
331 		pr_err("device %.*s max_send_sge/max_recv_sge = %d/%d too small\n",
332 		       IB_DEVICE_NAME_MAX,
333 		       sc->ib.dev->name,
334 		       sc->ib.dev->attrs.max_send_sge,
335 		       sc->ib.dev->attrs.max_recv_sge);
336 		return -EINVAL;
337 	}
338 
339 	sc->ib.pd = ib_alloc_pd(sc->ib.dev, 0);
340 	if (IS_ERR(sc->ib.pd)) {
341 		pr_err("Can't create RDMA PD: %1pe\n", sc->ib.pd);
342 		ret = PTR_ERR(sc->ib.pd);
343 		sc->ib.pd = NULL;
344 		return ret;
345 	}
346 
347 	sc->ib.send_cq = ib_alloc_cq_any(sc->ib.dev, sc,
348 					 max_send_wr,
349 					 sc->ib.poll_ctx);
350 	if (IS_ERR(sc->ib.send_cq)) {
351 		pr_err("Can't create RDMA send CQ: %1pe\n", sc->ib.send_cq);
352 		ret = PTR_ERR(sc->ib.send_cq);
353 		sc->ib.send_cq = NULL;
354 		goto err;
355 	}
356 
357 	sc->ib.recv_cq = ib_alloc_cq_any(sc->ib.dev, sc,
358 					 qp_cap.max_recv_wr,
359 					 sc->ib.poll_ctx);
360 	if (IS_ERR(sc->ib.recv_cq)) {
361 		pr_err("Can't create RDMA recv CQ: %1pe\n", sc->ib.recv_cq);
362 		ret = PTR_ERR(sc->ib.recv_cq);
363 		sc->ib.recv_cq = NULL;
364 		goto err;
365 	}
366 
367 	/*
368 	 * We reset completely here!
369 	 * As the above use was just temporary
370 	 * to calc max_send_wr and rdma_send_wr.
371 	 *
372 	 * rdma_create_qp() will trigger rdma_rw_init_qp()
373 	 * again if max_rdma_ctxs is not 0.
374 	 */
375 	memset(&qp_attr, 0, sizeof(qp_attr));
376 	qp_attr.event_handler = smbdirect_connection_qp_event_handler;
377 	qp_attr.qp_context = sc;
378 	qp_attr.cap = qp_cap;
379 	qp_attr.sq_sig_type = IB_SIGNAL_REQ_WR;
380 	qp_attr.qp_type = IB_QPT_RC;
381 	qp_attr.send_cq = sc->ib.send_cq;
382 	qp_attr.recv_cq = sc->ib.recv_cq;
383 	qp_attr.port_num = ~0;
384 
385 	ret = rdma_create_qp(sc->rdma.cm_id, sc->ib.pd, &qp_attr);
386 	if (ret) {
387 		pr_err("Can't create RDMA QP: %1pe\n",
388 		       SMBDIRECT_DEBUG_ERR_PTR(ret));
389 		goto err;
390 	}
391 	sc->ib.qp = sc->rdma.cm_id->qp;
392 
393 	return 0;
394 err:
395 	smbdirect_connection_destroy_qp(sc);
396 	return ret;
397 }
398 
399 void smbdirect_connection_destroy_qp(struct smbdirect_socket *sc)
400 {
401 	if (sc->ib.qp) {
402 		ib_drain_qp(sc->ib.qp);
403 		sc->ib.qp = NULL;
404 		rdma_destroy_qp(sc->rdma.cm_id);
405 	}
406 	if (sc->ib.recv_cq) {
407 		ib_destroy_cq(sc->ib.recv_cq);
408 		sc->ib.recv_cq = NULL;
409 	}
410 	if (sc->ib.send_cq) {
411 		ib_destroy_cq(sc->ib.send_cq);
412 		sc->ib.send_cq = NULL;
413 	}
414 	if (sc->ib.pd) {
415 		ib_dealloc_pd(sc->ib.pd);
416 		sc->ib.pd = NULL;
417 	}
418 }
419 
420 int smbdirect_connection_create_mem_pools(struct smbdirect_socket *sc)
421 {
422 	const struct smbdirect_socket_parameters *sp = &sc->parameters;
423 	char name[80];
424 	size_t i;
425 
426 	/*
427 	 * We use sizeof(struct smbdirect_negotiate_resp) for the
428 	 * payload size as it is larger than
429 	 * sizeof(struct smbdirect_data_transfer).
430 	 *
431 	 * This will fit client and server usage for now.
432 	 */
433 	snprintf(name, sizeof(name), "smbdirect_send_io_cache_%p", sc);
434 	struct kmem_cache_args send_io_args = {
435 		.align		= __alignof__(struct smbdirect_send_io),
436 	};
437 	sc->send_io.mem.cache = kmem_cache_create(name,
438 						  sizeof(struct smbdirect_send_io) +
439 						  sizeof(struct smbdirect_negotiate_resp),
440 						  &send_io_args,
441 						  SLAB_HWCACHE_ALIGN);
442 	if (!sc->send_io.mem.cache)
443 		goto err;
444 
445 	sc->send_io.mem.pool = mempool_create_slab_pool(sp->send_credit_target,
446 							sc->send_io.mem.cache);
447 	if (!sc->send_io.mem.pool)
448 		goto err;
449 
450 	/*
451 	 * A payload size of sp->max_recv_size should fit
452 	 * any message.
453 	 *
454 	 * For smbdirect_data_transfer messages the whole
455 	 * buffer might be exposed to userspace
456 	 * (currently on the client side...)
457 	 * The documentation says data_offset = 0 would be
458 	 * strange but valid.
459 	 */
460 	snprintf(name, sizeof(name), "smbdirect_recv_io_cache_%p", sc);
461 	struct kmem_cache_args recv_io_args = {
462 		.align		= __alignof__(struct smbdirect_recv_io),
463 		.useroffset	= sizeof(struct smbdirect_recv_io),
464 		.usersize	= sp->max_recv_size,
465 	};
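	/*
	 * useroffset/usersize restrict hardened usercopy to the payload
	 * area behind the smbdirect_recv_io header, so the header itself
	 * is never exposed to copy_to_user()/copy_from_user().
	 */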
466 	sc->recv_io.mem.cache = kmem_cache_create(name,
467 						  sizeof(struct smbdirect_recv_io) +
468 						  sp->max_recv_size,
469 						  &recv_io_args,
470 						  SLAB_HWCACHE_ALIGN);
471 	if (!sc->recv_io.mem.cache)
472 		goto err;
473 
474 	sc->recv_io.mem.pool = mempool_create_slab_pool(sp->recv_credit_max,
475 							sc->recv_io.mem.cache);
476 	if (!sc->recv_io.mem.pool)
477 		goto err;
478 
479 	for (i = 0; i < sp->recv_credit_max; i++) {
480 		struct smbdirect_recv_io *recv_io;
481 
482 		recv_io = mempool_alloc(sc->recv_io.mem.pool,
483 					sc->recv_io.mem.gfp_mask);
484 		if (!recv_io)
485 			goto err;
486 		recv_io->socket = sc;
487 		recv_io->sge.length = 0;
488 		list_add_tail(&recv_io->list, &sc->recv_io.free.list);
489 	}
490 
491 	return 0;
492 err:
493 	smbdirect_connection_destroy_mem_pools(sc);
494 	return -ENOMEM;
495 }
496 
497 void smbdirect_connection_destroy_mem_pools(struct smbdirect_socket *sc)
498 {
499 	struct smbdirect_recv_io *recv_io, *next_io;
500 
501 	list_for_each_entry_safe(recv_io, next_io, &sc->recv_io.free.list, list) {
502 		list_del(&recv_io->list);
503 		mempool_free(recv_io, sc->recv_io.mem.pool);
504 	}
505 
506 	/*
507 	 * Note mempool_destroy() and kmem_cache_destroy()
508 	 * work fine with a NULL pointer
509 	 */
510 
511 	mempool_destroy(sc->recv_io.mem.pool);
512 	sc->recv_io.mem.pool = NULL;
513 
514 	kmem_cache_destroy(sc->recv_io.mem.cache);
515 	sc->recv_io.mem.cache = NULL;
516 
517 	mempool_destroy(sc->send_io.mem.pool);
518 	sc->send_io.mem.pool = NULL;
519 
520 	kmem_cache_destroy(sc->send_io.mem.cache);
521 	sc->send_io.mem.cache = NULL;
522 }
523 
524 struct smbdirect_send_io *smbdirect_connection_alloc_send_io(struct smbdirect_socket *sc)
525 {
526 	struct smbdirect_send_io *msg;
527 
528 	msg = mempool_alloc(sc->send_io.mem.pool, sc->send_io.mem.gfp_mask);
529 	if (!msg)
530 		return ERR_PTR(-ENOMEM);
531 	msg->socket = sc;
532 	INIT_LIST_HEAD(&msg->sibling_list);
533 	msg->num_sge = 0;
534 
535 	return msg;
536 }
537 
538 void smbdirect_connection_free_send_io(struct smbdirect_send_io *msg)
539 {
540 	struct smbdirect_socket *sc = msg->socket;
541 	size_t i;
542 
543 	/*
544 	 * The list needs to be empty!
545 	 * The caller should take care of it.
546 	 */
547 	WARN_ON_ONCE(!list_empty(&msg->sibling_list));
548 
549 	/*
550 	 * Note we call ib_dma_unmap_page(), even if some sges are mapped using
551 	 * ib_dma_map_single().
552 	 *
553 	 * The difference between _single() and _page() only matters for the
554 	 * ib_dma_map_*() case.
555 	 *
556 	 * For the ib_dma_unmap_*() case it does not matter as both take the
557 	 * dma_addr_t and dma_unmap_single_attrs() is just an alias to
558 	 * dma_unmap_page_attrs().
559 	 */
560 	for (i = 0; i < msg->num_sge; i++)
561 		ib_dma_unmap_page(sc->ib.dev,
562 				  msg->sge[i].addr,
563 				  msg->sge[i].length,
564 				  DMA_TO_DEVICE);
565 
566 	mempool_free(msg, sc->send_io.mem.pool);
567 }
568 
569 struct smbdirect_recv_io *smbdirect_connection_get_recv_io(struct smbdirect_socket *sc)
570 {
571 	struct smbdirect_recv_io *msg = NULL;
572 	unsigned long flags;
573 
574 	spin_lock_irqsave(&sc->recv_io.free.lock, flags);
575 	if (likely(!sc->first_error))
576 		msg = list_first_entry_or_null(&sc->recv_io.free.list,
577 					       struct smbdirect_recv_io,
578 					       list);
579 	if (likely(msg)) {
580 		list_del(&msg->list);
581 		sc->statistics.get_receive_buffer++;
582 	}
583 	spin_unlock_irqrestore(&sc->recv_io.free.lock, flags);
584 
585 	return msg;
586 }
587 
588 void smbdirect_connection_put_recv_io(struct smbdirect_recv_io *msg)
589 {
590 	struct smbdirect_socket *sc = msg->socket;
591 	unsigned long flags;
592 
593 	if (likely(msg->sge.length != 0)) {
594 		ib_dma_unmap_single(sc->ib.dev,
595 				    msg->sge.addr,
596 				    msg->sge.length,
597 				    DMA_FROM_DEVICE);
598 		msg->sge.length = 0;
599 	}
600 
601 	spin_lock_irqsave(&sc->recv_io.free.lock, flags);
602 	list_add_tail(&msg->list, &sc->recv_io.free.list);
603 	sc->statistics.put_receive_buffer++;
604 	spin_unlock_irqrestore(&sc->recv_io.free.lock, flags);
605 
606 	queue_work(sc->workqueues.refill, &sc->recv_io.posted.refill_work);
607 }
608 
609 void smbdirect_connection_reassembly_append_recv_io(struct smbdirect_socket *sc,
610 						    struct smbdirect_recv_io *msg,
611 						    u32 data_length)
612 {
613 	unsigned long flags;
614 
615 	spin_lock_irqsave(&sc->recv_io.reassembly.lock, flags);
616 	list_add_tail(&msg->list, &sc->recv_io.reassembly.list);
617 	sc->recv_io.reassembly.queue_length++;
618 	/*
619 	 * Make sure reassembly_data_length is updated after list and
620 	 * reassembly_queue_length are updated. On the dequeue side
621 	 * reassembly_data_length is checked without a lock to determine
622 	 * if reassembly_queue_length and list are up to date.
623 	 */
624 	virt_wmb();
625 	sc->recv_io.reassembly.data_length += data_length;
626 	spin_unlock_irqrestore(&sc->recv_io.reassembly.lock, flags);
627 	sc->statistics.enqueue_reassembly_queue++;
628 }
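
/*
 * A sketch of the expected dequeue side (not part of this file):
 * the reader samples sc->recv_io.reassembly.data_length without the
 * lock, issues virt_rmb() and only then walks the reassembly list,
 * pairing with the virt_wmb() in the append path above.
 */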
629 
630 struct smbdirect_recv_io *
631 smbdirect_connection_reassembly_first_recv_io(struct smbdirect_socket *sc)
632 {
633 	struct smbdirect_recv_io *msg;
634 
635 	msg = list_first_entry_or_null(&sc->recv_io.reassembly.list,
636 				       struct smbdirect_recv_io,
637 				       list);
638 
639 	return msg;
640 }
641 
642 void smbdirect_connection_negotiate_rdma_resources(struct smbdirect_socket *sc,
643 						   u8 peer_initiator_depth,
644 						   u8 peer_responder_resources,
645 						   const struct rdma_conn_param *param)
646 {
647 	struct smbdirect_socket_parameters *sp = &sc->parameters;
648 
649 	if (rdma_protocol_iwarp(sc->ib.dev, sc->rdma.cm_id->port_num) &&
650 	    param->private_data_len == 8) {
651 		/*
652 		 * Legacy clients with only iWarp MPA v1 support
653 		 * need a private blob in order to negotiate
654 		 * the IRD/ORD values.
655 		 */
656 		const __be32 *ird_ord_hdr = param->private_data;
657 		u32 ird32 = be32_to_cpu(ird_ord_hdr[0]);
658 		u32 ord32 = be32_to_cpu(ird_ord_hdr[1]);
659 
660 		/*
661 		 * cifs.ko sends the legacy IRD/ORD negotiation
662 		 * even if iWarp MPA v2 was used.
663 		 *
664 		 * Here we check that the values match and only
665 		 * mark the client as legacy if they don't match.
666 		 */
667 		if ((u32)param->initiator_depth != ird32 ||
668 		    (u32)param->responder_resources != ord32) {
669 			/*
670 			 * There are broken clients (old cifs.ko)
671 			 * using little endian and also
672 			 * struct rdma_conn_param only uses u8
673 			 * for initiator_depth and responder_resources,
674 			 * so we truncate the value to U8_MAX.
675 			 *
676 			 * smb_direct_accept_client() will then
677 			 * do the real negotiation in order to
678 			 * select the minimum between client and
679 			 * server.
680 			 */
681 			ird32 = min_t(u32, ird32, U8_MAX);
682 			ord32 = min_t(u32, ord32, U8_MAX);
683 
684 			sc->rdma.legacy_iwarp = true;
685 			peer_initiator_depth = (u8)ird32;
686 			peer_responder_resources = (u8)ord32;
687 		}
688 	}
689 
690 	/*
691 	 * Negotiate the values by using the minimum
692 	 * between client and server if the client provided
693 	 * non-zero values.
694 	 */
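	/*
	 * E.g. (hypothetical values): sp->initiator_depth = 32 and
	 * peer_initiator_depth = 16 negotiate down to 16, while a
	 * peer value of 0 leaves our configured value untouched.
	 */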
695 	if (peer_initiator_depth != 0)
696 		sp->initiator_depth = min_t(u8, sp->initiator_depth,
697 					    peer_initiator_depth);
698 	if (peer_responder_resources != 0)
699 		sp->responder_resources = min_t(u8, sp->responder_resources,
700 						peer_responder_resources);
701 }
702 
703 bool smbdirect_connection_is_connected(struct smbdirect_socket *sc)
704 {
705 	if (unlikely(!sc || sc->first_error || sc->status != SMBDIRECT_SOCKET_CONNECTED))
706 		return false;
707 	return true;
708 }
709 __SMBDIRECT_EXPORT_SYMBOL__(smbdirect_connection_is_connected);
710 
711 int smbdirect_connection_wait_for_connected(struct smbdirect_socket *sc)
712 {
713 	const struct smbdirect_socket_parameters *sp = &sc->parameters;
714 	union {
715 		struct sockaddr sa;
716 		struct sockaddr_storage ss;
717 	} src_addr, dst_addr;
718 	const struct sockaddr *src = NULL;
719 	const struct sockaddr *dst = NULL;
720 	char _devname[IB_DEVICE_NAME_MAX] = { 0, };
721 	const char *devname = NULL;
722 	int ret;
723 
724 	if (sc->rdma.cm_id) {
725 		src_addr.ss = sc->rdma.cm_id->route.addr.src_addr;
726 		if (src_addr.sa.sa_family != AF_UNSPEC)
727 			src = &src_addr.sa;
728 		dst_addr.ss = sc->rdma.cm_id->route.addr.dst_addr;
729 		if (dst_addr.sa.sa_family != AF_UNSPEC)
730 			dst = &dst_addr.sa;
731 
732 		if (sc->ib.dev) {
733 			memcpy(_devname, sc->ib.dev->name, IB_DEVICE_NAME_MAX);
734 			devname = _devname;
735 		}
736 	}
737 
738 	smbdirect_log_rdma_event(sc, SMBDIRECT_LOG_INFO,
739 		"waiting for connection: device: %.*s local: %pISpsfc remote: %pISpsfc\n",
740 		IB_DEVICE_NAME_MAX, devname, src, dst);
741 
742 	ret = wait_event_interruptible_timeout(sc->status_wait,
743 					       sc->status == SMBDIRECT_SOCKET_CONNECTED ||
744 					       sc->first_error,
745 					       msecs_to_jiffies(sp->negotiate_timeout_msec));
746 	if (sc->rdma.cm_id) {
747 		/*
748 		 * src, dst and devname may have been updated in the meantime.
749 		 */
750 		src_addr.ss = sc->rdma.cm_id->route.addr.src_addr;
751 		if (src_addr.sa.sa_family != AF_UNSPEC)
752 			src = &src_addr.sa;
753 		dst_addr.ss = sc->rdma.cm_id->route.addr.dst_addr;
754 		if (dst_addr.sa.sa_family != AF_UNSPEC)
755 			dst = &dst_addr.sa;
756 
757 		if (sc->ib.dev) {
758 			memcpy(_devname, sc->ib.dev->name, IB_DEVICE_NAME_MAX);
759 			devname = _devname;
760 		}
761 	}
762 	if (ret == 0)
763 		ret = -ETIMEDOUT;
764 	if (ret < 0)
765 		smbdirect_socket_schedule_cleanup(sc, ret);
766 	if (sc->first_error) {
767 		int lvl = SMBDIRECT_LOG_ERR;
768 
769 		ret = sc->first_error;
770 		if (ret == -ENODEV)
771 			lvl = SMBDIRECT_LOG_INFO;
772 
773 		smbdirect_log_rdma_event(sc, lvl,
774 			"connection failed %1pe device: %.*s local: %pISpsfc remote: %pISpsfc\n",
775 			SMBDIRECT_DEBUG_ERR_PTR(ret),
776 			IB_DEVICE_NAME_MAX, devname, src, dst);
777 		return ret;
778 	}
779 
780 	return 0;
781 }
782 __SMBDIRECT_EXPORT_SYMBOL__(smbdirect_connection_wait_for_connected);
783 
784 void smbdirect_connection_idle_timer_work(struct work_struct *work)
785 {
786 	struct smbdirect_socket *sc =
787 		container_of(work, struct smbdirect_socket, idle.timer_work.work);
788 	const struct smbdirect_socket_parameters *sp = &sc->parameters;
789 
790 	if (sc->idle.keepalive != SMBDIRECT_KEEPALIVE_NONE) {
791 		smbdirect_log_keep_alive(sc, SMBDIRECT_LOG_ERR,
792 			"%s => timeout sc->idle.keepalive=%s\n",
793 			smbdirect_socket_status_string(sc->status),
794 			sc->idle.keepalive == SMBDIRECT_KEEPALIVE_SENT ?
795 			"SENT" : "PENDING");
796 		smbdirect_socket_schedule_cleanup(sc, -ETIMEDOUT);
797 		return;
798 	}
799 
800 	if (sc->status != SMBDIRECT_SOCKET_CONNECTED)
801 		return;
802 
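	/*
	 * Keepalive state machine as implemented in this file:
	 * NONE -(interval expired)-> PENDING -(message sent)-> SENT,
	 * any receive resets to NONE; the timer expiring in PENDING
	 * or SENT tears the connection down above.
	 */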
803 	/*
804 	 * Now use the keepalive timeout (instead of keepalive interval)
805 	 * in order to wait for a response
806 	 */
807 	sc->idle.keepalive = SMBDIRECT_KEEPALIVE_PENDING;
808 	mod_delayed_work(sc->workqueues.idle, &sc->idle.timer_work,
809 			 msecs_to_jiffies(sp->keepalive_timeout_msec));
810 	smbdirect_log_keep_alive(sc, SMBDIRECT_LOG_INFO,
811 		"schedule send of empty idle message\n");
812 	queue_work(sc->workqueues.immediate, &sc->idle.immediate_work);
813 }
814 
815 u16 smbdirect_connection_grant_recv_credits(struct smbdirect_socket *sc)
816 {
817 	int missing;
818 	int available;
819 	int new_credits;
820 
821 	if (atomic_read(&sc->recv_io.credits.count) >= sc->recv_io.credits.target)
822 		return 0;
823 
824 	missing = (int)sc->recv_io.credits.target - atomic_read(&sc->recv_io.credits.count);
825 	available = atomic_xchg(&sc->recv_io.credits.available, 0);
826 	new_credits = min3((int)U16_MAX, missing, available);
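	/*
	 * E.g. (hypothetical values): target = 255 and count = 100
	 * give missing = 155; with available = 50 we grant
	 * new_credits = min3(65535, 155, 50) = 50.
	 */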
827 	if (new_credits <= 0) {
828 		/*
829 		 * If credits are available but not granted,
830 		 * we need to re-add them.
831 		 */
832 		if (available)
833 			atomic_add(available, &sc->recv_io.credits.available);
834 		return 0;
835 	}
836 
837 	if (new_credits < available) {
838 		/*
839 		 * Re-add the remaining available credits.
840 		 */
841 		available -= new_credits;
842 		atomic_add(available, &sc->recv_io.credits.available);
843 	}
844 
845 	/*
846 	 * Remember we granted the credits
847 	 */
848 	atomic_add(new_credits, &sc->recv_io.credits.count);
849 	return new_credits;
850 }
851 
852 static bool smbdirect_connection_request_keep_alive(struct smbdirect_socket *sc)
853 {
854 	const struct smbdirect_socket_parameters *sp = &sc->parameters;
855 
856 	if (sc->idle.keepalive == SMBDIRECT_KEEPALIVE_PENDING) {
857 		sc->idle.keepalive = SMBDIRECT_KEEPALIVE_SENT;
858 		/*
859 		 * Now use the keepalive timeout (instead of keepalive interval)
860 		 * in order to wait for a response
861 		 */
862 		mod_delayed_work(sc->workqueues.idle, &sc->idle.timer_work,
863 				 msecs_to_jiffies(sp->keepalive_timeout_msec));
864 		return true;
865 	}
866 
867 	return false;
868 }
869 
870 int smbdirect_connection_post_send_wr(struct smbdirect_socket *sc,
871 				      struct ib_send_wr *wr)
872 {
873 	int ret;
874 
875 	if (unlikely(sc->first_error))
876 		return sc->first_error;
877 
878 	atomic_inc(&sc->send_io.pending.count);
879 	ret = ib_post_send(sc->ib.qp, wr, NULL);
880 	if (ret) {
881 		atomic_dec(&sc->send_io.pending.count);
882 		smbdirect_log_rdma_send(sc, SMBDIRECT_LOG_ERR,
883 			"ib_post_send() failed %1pe\n",
884 			SMBDIRECT_DEBUG_ERR_PTR(ret));
885 		smbdirect_socket_schedule_cleanup(sc, ret);
886 	}
887 
888 	return ret;
889 }
890 
891 static void smbdirect_connection_send_batch_init(struct smbdirect_send_batch *batch,
892 						 bool need_invalidate_rkey,
893 						 unsigned int remote_key)
894 {
895 	INIT_LIST_HEAD(&batch->msg_list);
896 	batch->wr_cnt = 0;
897 	batch->need_invalidate_rkey = need_invalidate_rkey;
898 	batch->remote_key = remote_key;
899 	batch->credit = 0;
900 }
901 
902 int smbdirect_connection_send_batch_flush(struct smbdirect_socket *sc,
903 					  struct smbdirect_send_batch *batch,
904 					  bool is_last)
905 {
906 	struct smbdirect_send_io *first, *last;
907 	int ret = 0;
908 
909 	if (list_empty(&batch->msg_list))
910 		goto release_credit;
911 
912 	first = list_first_entry(&batch->msg_list,
913 				 struct smbdirect_send_io,
914 				 sibling_list);
915 	last = list_last_entry(&batch->msg_list,
916 			       struct smbdirect_send_io,
917 			       sibling_list);
918 
919 	if (batch->need_invalidate_rkey) {
920 		first->wr.opcode = IB_WR_SEND_WITH_INV;
921 		first->wr.ex.invalidate_rkey = batch->remote_key;
922 		batch->need_invalidate_rkey = false;
923 		batch->remote_key = 0;
924 	}
925 
926 	last->wr.send_flags = IB_SEND_SIGNALED;
927 	last->wr.wr_cqe = &last->cqe;
928 
929 	/*
930 	 * Remove last from batch->msg_list
931 	 * and splice the rest of batch->msg_list
932 	 * to last->sibling_list.
933 	 *
934 	 * batch->msg_list is a valid empty list
935 	 * at the end.
936 	 */
937 	list_del_init(&last->sibling_list);
938 	list_splice_tail_init(&batch->msg_list, &last->sibling_list);
939 	batch->wr_cnt = 0;
940 
941 	ret = smbdirect_connection_post_send_wr(sc, &first->wr);
942 	if (ret) {
943 		struct smbdirect_send_io *sibling, *next;
944 
945 		list_for_each_entry_safe(sibling, next, &last->sibling_list, sibling_list) {
946 			list_del_init(&sibling->sibling_list);
947 			smbdirect_connection_free_send_io(sibling);
948 		}
949 		smbdirect_connection_free_send_io(last);
950 	}
951 
952 release_credit:
953 	if (is_last && !ret && batch->credit) {
954 		atomic_add(batch->credit, &sc->send_io.bcredits.count);
955 		batch->credit = 0;
956 		wake_up(&sc->send_io.bcredits.wait_queue);
957 	}
958 
959 	return ret;
960 }
961 __SMBDIRECT_EXPORT_SYMBOL__(smbdirect_connection_send_batch_flush);
962 
963 struct smbdirect_send_batch *
964 smbdirect_init_send_batch_storage(struct smbdirect_send_batch_storage *storage,
965 				  bool need_invalidate_rkey,
966 				  unsigned int remote_key)
967 {
968 	struct smbdirect_send_batch *batch = (struct smbdirect_send_batch *)storage;
969 
970 	memset(storage, 0, sizeof(*storage));
971 	BUILD_BUG_ON(sizeof(*batch) > sizeof(*storage));
972 
973 	smbdirect_connection_send_batch_init(batch,
974 					     need_invalidate_rkey,
975 					     remote_key);
976 
977 	return batch;
978 }
979 __SMBDIRECT_EXPORT_SYMBOL__(smbdirect_init_send_batch_storage);
980 
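/*
 * Three credit types gate each send, in the order taken below:
 * - bcredits: one per batch, serializing whole send batches
 * - lcredits: local send queue slots, given back on send completion
 * - credits:  SMBDirect send credits granted by the peer
 */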
981 static int smbdirect_connection_wait_for_send_bcredit(struct smbdirect_socket *sc,
982 						      struct smbdirect_send_batch *batch)
983 {
984 	int ret;
985 
986 	if (batch->credit)
987 		return 0;
988 
989 	ret = smbdirect_socket_wait_for_credits(sc,
990 						 SMBDIRECT_SOCKET_CONNECTED,
991 						 -ENOTCONN,
992 						 &sc->send_io.bcredits.wait_queue,
993 						 &sc->send_io.bcredits.count,
994 						 1);
995 	if (ret)
996 		return ret;
997 
998 	batch->credit = 1;
999 	return 0;
1000 }
1001 
1002 static int smbdirect_connection_wait_for_send_lcredit(struct smbdirect_socket *sc,
1003 						      struct smbdirect_send_batch *batch)
1004 {
1005 	if (batch && atomic_read(&sc->send_io.lcredits.count) <= 1) {
1006 		int ret;
1007 
1008 		ret = smbdirect_connection_send_batch_flush(sc, batch, false);
1009 		if (ret)
1010 			return ret;
1011 	}
1012 
1013 	return smbdirect_socket_wait_for_credits(sc,
1014 						 SMBDIRECT_SOCKET_CONNECTED,
1015 						 -ENOTCONN,
1016 						 &sc->send_io.lcredits.wait_queue,
1017 						 &sc->send_io.lcredits.count,
1018 						 1);
1019 }
1020 
1021 static int smbdirect_connection_wait_for_send_credits(struct smbdirect_socket *sc,
1022 						      struct smbdirect_send_batch *batch)
1023 {
1024 	if (batch && (batch->wr_cnt >= 16 || atomic_read(&sc->send_io.credits.count) <= 1)) {
1025 		int ret;
1026 
1027 		ret = smbdirect_connection_send_batch_flush(sc, batch, false);
1028 		if (ret)
1029 			return ret;
1030 	}
1031 
1032 	return smbdirect_socket_wait_for_credits(sc,
1033 						 SMBDIRECT_SOCKET_CONNECTED,
1034 						 -ENOTCONN,
1035 						 &sc->send_io.credits.wait_queue,
1036 						 &sc->send_io.credits.count,
1037 						 1);
1038 }
1039 
1040 static void smbdirect_connection_send_io_done(struct ib_cq *cq, struct ib_wc *wc);
1041 
1042 static int smbdirect_connection_post_send_io(struct smbdirect_socket *sc,
1043 					     struct smbdirect_send_batch *batch,
1044 					     struct smbdirect_send_io *msg)
1045 {
1046 	int i;
1047 
1048 	for (i = 0; i < msg->num_sge; i++)
1049 		ib_dma_sync_single_for_device(sc->ib.dev,
1050 					      msg->sge[i].addr, msg->sge[i].length,
1051 					      DMA_TO_DEVICE);
1052 
1053 	msg->cqe.done = smbdirect_connection_send_io_done;
1054 	msg->wr.wr_cqe = &msg->cqe;
1055 	msg->wr.opcode = IB_WR_SEND;
1056 	msg->wr.sg_list = &msg->sge[0];
1057 	msg->wr.num_sge = msg->num_sge;
1058 	msg->wr.next = NULL;
1059 
1060 	if (batch) {
1061 		msg->wr.send_flags = 0;
1062 		if (!list_empty(&batch->msg_list)) {
1063 			struct smbdirect_send_io *last;
1064 
1065 			last = list_last_entry(&batch->msg_list,
1066 					       struct smbdirect_send_io,
1067 					       sibling_list);
1068 			last->wr.next = &msg->wr;
1069 		}
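		/*
		 * The batch thus forms a single ib_post_send() chain
		 * (first->wr.next -> ... -> msg->wr), posted later by
		 * smbdirect_connection_send_batch_flush().
		 */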
1070 		list_add_tail(&msg->sibling_list, &batch->msg_list);
1071 		batch->wr_cnt++;
1072 		return 0;
1073 	}
1074 
1075 	msg->wr.send_flags = IB_SEND_SIGNALED;
1076 	return smbdirect_connection_post_send_wr(sc, &msg->wr);
1077 }
1078 
1079 int smbdirect_connection_send_single_iter(struct smbdirect_socket *sc,
1080 					  struct smbdirect_send_batch *batch,
1081 					  struct iov_iter *iter,
1082 					  unsigned int flags,
1083 					  u32 remaining_data_length)
1084 {
1085 	const struct smbdirect_socket_parameters *sp = &sc->parameters;
1086 	struct smbdirect_send_batch _batch;
1087 	struct smbdirect_send_io *msg;
1088 	struct smbdirect_data_transfer *packet;
1089 	size_t header_length;
1090 	u16 new_credits = 0;
1091 	u32 data_length = 0;
1092 	int ret;
1093 
1094 	if (WARN_ON_ONCE(flags))
1095 		return -EINVAL; /* no flags support for now */
1096 
1097 	if (iter) {
1098 		if (WARN_ON_ONCE(iov_iter_rw(iter) != ITER_SOURCE))
1099 			return -EINVAL; /* It's a bug in upper layer to get there */
1100 
1101 		header_length = sizeof(struct smbdirect_data_transfer);
1102 		if (WARN_ON_ONCE(remaining_data_length == 0 ||
1103 				 iov_iter_count(iter) > remaining_data_length))
1104 			return -EINVAL;
1105 	} else {
1106 		/* If this is a packet without payload, don't send padding */
1107 		header_length = offsetof(struct smbdirect_data_transfer, padding);
1108 		if (WARN_ON_ONCE(remaining_data_length))
1109 			return -EINVAL;
1110 	}
1111 
1112 	if (sc->status != SMBDIRECT_SOCKET_CONNECTED) {
1113 		smbdirect_log_write(sc, SMBDIRECT_LOG_ERR,
1114 			"status=%s first_error=%1pe => %1pe\n",
1115 			smbdirect_socket_status_string(sc->status),
1116 			SMBDIRECT_DEBUG_ERR_PTR(sc->first_error),
1117 			SMBDIRECT_DEBUG_ERR_PTR(-ENOTCONN));
1118 		return -ENOTCONN;
1119 	}
1120 
1121 	if (!batch) {
1122 		smbdirect_connection_send_batch_init(&_batch, false, 0);
1123 		batch = &_batch;
1124 	}
1125 
1126 	ret = smbdirect_connection_wait_for_send_bcredit(sc, batch);
1127 	if (ret)
1128 		goto bcredit_failed;
1129 
1130 	ret = smbdirect_connection_wait_for_send_lcredit(sc, batch);
1131 	if (ret)
1132 		goto lcredit_failed;
1133 
1134 	ret = smbdirect_connection_wait_for_send_credits(sc, batch);
1135 	if (ret)
1136 		goto credit_failed;
1137 
1138 	new_credits = smbdirect_connection_grant_recv_credits(sc);
1139 	if (new_credits == 0 &&
1140 	    atomic_read(&sc->send_io.credits.count) == 0 &&
1141 	    atomic_read(&sc->recv_io.credits.count) == 0) {
1142 		/*
1143 		 * queue the refill work in order to
1144 		 * get some new recv credits we can grant to
1145 		 * the peer.
1146 		 */
1147 		queue_work(sc->workqueues.refill, &sc->recv_io.posted.refill_work);
1148 
1149 		/*
1150 		 * Wait until either the refill work or the peer
1151 		 * has granted new credits.
1152 		 */
1153 		ret = wait_event_interruptible(sc->send_io.credits.wait_queue,
1154 					       atomic_read(&sc->send_io.credits.count) >= 1 ||
1155 					       atomic_read(&sc->recv_io.credits.available) >= 1 ||
1156 					       sc->status != SMBDIRECT_SOCKET_CONNECTED);
1157 		if (sc->status != SMBDIRECT_SOCKET_CONNECTED)
1158 			ret = -ENOTCONN;
1159 		if (ret < 0)
1160 			goto credit_failed;
1161 
1162 		new_credits = smbdirect_connection_grant_recv_credits(sc);
1163 	}
1164 
1165 	msg = smbdirect_connection_alloc_send_io(sc);
1166 	if (IS_ERR(msg)) {
1167 		ret = PTR_ERR(msg);
1168 		goto alloc_failed;
1169 	}
1170 
1171 	/* Map the packet to DMA */
1172 	msg->sge[0].addr = ib_dma_map_single(sc->ib.dev,
1173 					     msg->packet,
1174 					     header_length,
1175 					     DMA_TO_DEVICE);
1176 	ret = ib_dma_mapping_error(sc->ib.dev, msg->sge[0].addr);
1177 	if (ret)
1178 		goto err;
1179 
1180 	msg->sge[0].length = header_length;
1181 	msg->sge[0].lkey = sc->ib.pd->local_dma_lkey;
1182 	msg->num_sge = 1;
1183 
1184 	if (iter) {
1185 		struct smbdirect_map_sges extract = {
1186 			.num_sge	= msg->num_sge,
1187 			.max_sge	= ARRAY_SIZE(msg->sge),
1188 			.sge		= msg->sge,
1189 			.device		= sc->ib.dev,
1190 			.local_dma_lkey	= sc->ib.pd->local_dma_lkey,
1191 			.direction	= DMA_TO_DEVICE,
1192 		};
1193 		size_t payload_len = umin(iov_iter_count(iter),
1194 					  sp->max_send_size - sizeof(*packet));
1195 
1196 		ret = smbdirect_map_sges_from_iter(iter, payload_len, &extract);
1197 		if (ret < 0)
1198 			goto err;
1199 		data_length = ret;
1200 		remaining_data_length -= data_length;
1201 		msg->num_sge = extract.num_sge;
1202 	}
1203 
1204 	/* Fill in the packet header */
1205 	packet = (struct smbdirect_data_transfer *)msg->packet;
1206 	packet->credits_requested = cpu_to_le16(sp->send_credit_target);
1207 	packet->credits_granted = cpu_to_le16(new_credits);
1208 
1209 	packet->flags = 0;
1210 	if (smbdirect_connection_request_keep_alive(sc))
1211 		packet->flags |= cpu_to_le16(SMBDIRECT_FLAG_RESPONSE_REQUESTED);
1212 
1213 	packet->reserved = 0;
1214 	if (!data_length)
1215 		packet->data_offset = 0;
1216 	else
1217 		packet->data_offset = cpu_to_le32(24);
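	/*
	 * 24 matches sizeof(struct smbdirect_data_transfer): the
	 * payload starts right after the (8-byte aligned) header
	 * whenever data is present.
	 */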
1218 	packet->data_length = cpu_to_le32(data_length);
1219 	packet->remaining_data_length = cpu_to_le32(remaining_data_length);
1220 	packet->padding = 0;
1221 
1222 	smbdirect_log_outgoing(sc, SMBDIRECT_LOG_INFO,
1223 		"DataOut: %s=%u, %s=%u, %s=0x%x, %s=%u, %s=%u, %s=%u\n",
1224 		"CreditsRequested",
1225 		le16_to_cpu(packet->credits_requested),
1226 		"CreditsGranted",
1227 		le16_to_cpu(packet->credits_granted),
1228 		"Flags",
1229 		le16_to_cpu(packet->flags),
1230 		"RemainingDataLength",
1231 		le32_to_cpu(packet->remaining_data_length),
1232 		"DataOffset",
1233 		le32_to_cpu(packet->data_offset),
1234 		"DataLength",
1235 		le32_to_cpu(packet->data_length));
1236 
1237 	ret = smbdirect_connection_post_send_io(sc, batch, msg);
1238 	if (ret)
1239 		goto err;
1240 
1241 	/*
1242 	 * From here msg is owned by the batch
1243 	 * and we should not free it explicitly.
1244 	 */
1245 
1246 	if (batch == &_batch) {
1247 		ret = smbdirect_connection_send_batch_flush(sc, batch, true);
1248 		if (ret)
1249 			goto flush_failed;
1250 	}
1251 
1252 	return data_length;
1253 err:
1254 	smbdirect_connection_free_send_io(msg);
1255 flush_failed:
1256 alloc_failed:
1257 	atomic_inc(&sc->send_io.credits.count);
1258 credit_failed:
1259 	atomic_inc(&sc->send_io.lcredits.count);
1260 lcredit_failed:
1261 	atomic_add(batch->credit, &sc->send_io.bcredits.count);
1262 	batch->credit = 0;
1263 bcredit_failed:
1264 	return ret;
1265 }
1266 __SMBDIRECT_EXPORT_SYMBOL__(smbdirect_connection_send_single_iter);
1267 
1268 int smbdirect_connection_send_wait_zero_pending(struct smbdirect_socket *sc)
1269 {
1270 	/*
1271 	 * As an optimization, we don't wait for individual I/O to finish
1272 	 * before sending the next one.
1273 	 * Send them all and wait for the pending send count to reach 0,
1274 	 * which means all the I/Os have been sent and we are good to return.
1275 	 */
1276 
1277 	wait_event(sc->send_io.pending.zero_wait_queue,
1278 		   atomic_read(&sc->send_io.pending.count) == 0 ||
1279 		   sc->status != SMBDIRECT_SOCKET_CONNECTED);
1280 	if (sc->status != SMBDIRECT_SOCKET_CONNECTED) {
1281 		smbdirect_log_write(sc, SMBDIRECT_LOG_ERR,
1282 			"status=%s first_error=%1pe => %1pe\n",
1283 			smbdirect_socket_status_string(sc->status),
1284 			SMBDIRECT_DEBUG_ERR_PTR(sc->first_error),
1285 			SMBDIRECT_DEBUG_ERR_PTR(-ENOTCONN));
1286 		return -ENOTCONN;
1287 	}
1288 
1289 	return 0;
1290 }
1291 __SMBDIRECT_EXPORT_SYMBOL__(smbdirect_connection_send_wait_zero_pending);
1292 
1293 int smbdirect_connection_send_iter(struct smbdirect_socket *sc,
1294 				   struct iov_iter *iter,
1295 				   unsigned int flags,
1296 				   bool need_invalidate,
1297 				   unsigned int remote_key)
1298 {
1299 	const struct smbdirect_socket_parameters *sp = &sc->parameters;
1300 	struct smbdirect_send_batch batch;
1301 	int total_count = iov_iter_count(iter);
1302 	int ret;
1303 	int error = 0;
1304 	__be32 hdr;
1305 
1306 	if (WARN_ONCE(flags, "unexpected flags=0x%x\n", flags))
1307 		return -EINVAL; /* no flags support for now */
1308 
1309 	if (WARN_ON_ONCE(iov_iter_rw(iter) != ITER_SOURCE))
1310 		return -EINVAL; /* It's a bug in upper layer to get there */
1311 
1312 	if (sc->status != SMBDIRECT_SOCKET_CONNECTED) {
1313 		smbdirect_log_write(sc, SMBDIRECT_LOG_INFO,
1314 			"status=%s first_error=%1pe => %1pe\n",
1315 			smbdirect_socket_status_string(sc->status),
1316 			SMBDIRECT_DEBUG_ERR_PTR(sc->first_error),
1317 			SMBDIRECT_DEBUG_ERR_PTR(-ENOTCONN));
1318 		return -ENOTCONN;
1319 	}
1320 
1321 	/*
1322 	 * For now we expect the iter to have the full
1323 	 * message, including a 4 byte length header.
1324 	 */
1325 	if (iov_iter_count(iter) <= 4)
1326 		return -EINVAL;
1327 	if (!copy_from_iter_full(&hdr, sizeof(hdr), iter))
1328 		return -EFAULT;
1329 	if (iov_iter_count(iter) != be32_to_cpu(hdr))
1330 		return -EINVAL;
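
	/*
	 * E.g. (hypothetical): a 100 byte SMB2 message arrives as a
	 * 4 byte big-endian length prefix 0x00000064 followed by the
	 * 100 byte payload; the prefix was consumed above and only the
	 * payload is transmitted below.
	 */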
1331 
1332 	/*
1333 	 * The size must fit into the negotiated
1334 	 * fragmented send size.
1335 	 */
1336 	if (iov_iter_count(iter) > sp->max_fragmented_send_size)
1337 		return -EMSGSIZE;
1338 
1339 	smbdirect_log_write(sc, SMBDIRECT_LOG_INFO,
1340 		"Sending (RDMA): length=%zu\n",
1341 		iov_iter_count(iter));
1342 
1343 	smbdirect_connection_send_batch_init(&batch, need_invalidate, remote_key);
1344 	while (iov_iter_count(iter)) {
1345 		ret = smbdirect_connection_send_single_iter(sc,
1346 							    &batch,
1347 							    iter,
1348 							    flags,
1349 							    iov_iter_count(iter));
1350 		if (unlikely(ret < 0)) {
1351 			error = ret;
1352 			break;
1353 		}
1354 	}
1355 
1356 	ret = smbdirect_connection_send_batch_flush(sc, &batch, true);
1357 	if (unlikely(ret && !error))
1358 		error = ret;
1359 
1360 	/*
1361 	 * As an optimization, we don't wait for individual I/O to finish
1362 	 * before sending the next one.
1363 	 * Send them all and wait for the pending send count to reach 0,
1364 	 * which means all the I/Os have been sent and we are good to return.
1365 	 */
1366 
1367 	ret = smbdirect_connection_send_wait_zero_pending(sc);
1368 	if (unlikely(ret && !error))
1369 		error = ret;
1370 
1371 	if (unlikely(error))
1372 		return error;
1373 
1374 	return total_count;
1375 }
1376 __SMBDIRECT_EXPORT_SYMBOL__(smbdirect_connection_send_iter);
1377 
1378 static void smbdirect_connection_send_io_done(struct ib_cq *cq, struct ib_wc *wc)
1379 {
1380 	struct smbdirect_send_io *msg =
1381 		container_of(wc->wr_cqe, struct smbdirect_send_io, cqe);
1382 	struct smbdirect_socket *sc = msg->socket;
1383 	struct smbdirect_send_io *sibling, *next;
1384 	int lcredits = 0;
1385 
1386 	smbdirect_log_rdma_send(sc, SMBDIRECT_LOG_INFO,
1387 		"smbdirect_send_io completed. status='%s (%d)', opcode=%d\n",
1388 		ib_wc_status_msg(wc->status), wc->status, wc->opcode);
1389 
1390 	if (unlikely(!(msg->wr.send_flags & IB_SEND_SIGNALED))) {
1391 		/*
1392 		 * This happens when smbdirect_send_io is a sibling
1393 		 * This happens when smbdirect_send_io is a sibling
1394 		 * before the final message; it is signaled on
1395 		 * error anyway, so we need to skip
1396 		 * smbdirect_connection_free_send_io() here,
1397 		 * otherwise it would destroy the memory
1398 		 * of the siblings too, which would cause
1399 		 * use-after-free problems for the others
1400 		 */
1401 		if (wc->status != IB_WC_SUCCESS)
1402 			goto skip_free;
1403 
1404 		/*
1405 		 * This should not happen!
1406 		 * But we better just close the
1407 		 * connection...
1408 		 */
1409 		smbdirect_log_rdma_send(sc, SMBDIRECT_LOG_ERR,
1410 			"unexpected send completion wc->status=%s (%d) wc->opcode=%d\n",
1411 			ib_wc_status_msg(wc->status), wc->status, wc->opcode);
1412 		smbdirect_socket_schedule_cleanup(sc, -ECONNABORTED);
1413 		return;
1414 	}
1415 
1416 	/*
1417 	 * Free possible siblings and then the main send_io
1418 	 */
1419 	list_for_each_entry_safe(sibling, next, &msg->sibling_list, sibling_list) {
1420 		list_del_init(&sibling->sibling_list);
1421 		smbdirect_connection_free_send_io(sibling);
1422 		lcredits += 1;
1423 	}
1424 	/* Note this frees wc->wr_cqe, but not wc */
1425 	smbdirect_connection_free_send_io(msg);
1426 	lcredits += 1;
1427 
1428 	if (unlikely(wc->status != IB_WC_SUCCESS || WARN_ON_ONCE(wc->opcode != IB_WC_SEND))) {
1429 skip_free:
1430 		if (wc->status != IB_WC_WR_FLUSH_ERR)
1431 			smbdirect_log_rdma_send(sc, SMBDIRECT_LOG_ERR,
1432 				"wc->status=%s (%d) wc->opcode=%d\n",
1433 				ib_wc_status_msg(wc->status), wc->status, wc->opcode);
1434 		smbdirect_socket_schedule_cleanup(sc, -ECONNABORTED);
1435 		return;
1436 	}
1437 
1438 	atomic_add(lcredits, &sc->send_io.lcredits.count);
1439 	wake_up(&sc->send_io.lcredits.wait_queue);
1440 
1441 	if (atomic_dec_and_test(&sc->send_io.pending.count))
1442 		wake_up(&sc->send_io.pending.zero_wait_queue);
1443 }
1444 
1445 static void smbdirect_connection_send_immediate_work(struct work_struct *work)
1446 {
1447 	struct smbdirect_socket *sc =
1448 		container_of(work, struct smbdirect_socket, idle.immediate_work);
1449 	int ret;
1450 
1451 	if (sc->status != SMBDIRECT_SOCKET_CONNECTED)
1452 		return;
1453 
1454 	smbdirect_log_keep_alive(sc, SMBDIRECT_LOG_INFO,
1455 		"send an empty message\n");
1456 	sc->statistics.send_empty++;
1457 	ret = smbdirect_connection_send_single_iter(sc, NULL, NULL, 0, 0);
1458 	if (ret < 0) {
1459 		smbdirect_log_write(sc, SMBDIRECT_LOG_ERR,
1460 			"smbdirect_connection_send_single_iter ret=%1pe\n",
1461 			SMBDIRECT_DEBUG_ERR_PTR(ret));
1462 		smbdirect_socket_schedule_cleanup(sc, ret);
1463 	}
1464 }
1465 
1466 int smbdirect_connection_post_recv_io(struct smbdirect_recv_io *msg)
1467 {
1468 	struct smbdirect_socket *sc = msg->socket;
1469 	const struct smbdirect_socket_parameters *sp = &sc->parameters;
1470 	struct ib_recv_wr recv_wr = {
1471 		.wr_cqe = &msg->cqe,
1472 		.sg_list = &msg->sge,
1473 		.num_sge = 1,
1474 	};
1475 	int ret;
1476 
1477 	if (unlikely(sc->first_error))
1478 		return sc->first_error;
1479 
1480 	msg->sge.addr = ib_dma_map_single(sc->ib.dev,
1481 					  msg->packet,
1482 					  sp->max_recv_size,
1483 					  DMA_FROM_DEVICE);
1484 	ret = ib_dma_mapping_error(sc->ib.dev, msg->sge.addr);
1485 	if (ret)
1486 		return ret;
1487 
1488 	msg->sge.length = sp->max_recv_size;
1489 	msg->sge.lkey = sc->ib.pd->local_dma_lkey;
1490 
1491 	ret = ib_post_recv(sc->ib.qp, &recv_wr, NULL);
1492 	if (ret) {
1493 		smbdirect_log_rdma_recv(sc, SMBDIRECT_LOG_ERR,
1494 			"ib_post_recv failed ret=%d (%1pe)\n",
1495 			ret, SMBDIRECT_DEBUG_ERR_PTR(ret));
1496 		ib_dma_unmap_single(sc->ib.dev,
1497 				    msg->sge.addr,
1498 				    msg->sge.length,
1499 				    DMA_FROM_DEVICE);
1500 		msg->sge.length = 0;
1501 		smbdirect_socket_schedule_cleanup(sc, ret);
1502 	}
1503 
1504 	return ret;
1505 }
1506 
1507 void smbdirect_connection_recv_io_done(struct ib_cq *cq, struct ib_wc *wc)
1508 {
1509 	struct smbdirect_recv_io *recv_io =
1510 		container_of(wc->wr_cqe, struct smbdirect_recv_io, cqe);
1511 	struct smbdirect_socket *sc = recv_io->socket;
1512 	const struct smbdirect_socket_parameters *sp = &sc->parameters;
1513 	struct smbdirect_data_transfer *data_transfer;
1514 	int current_recv_credits;
1515 	u16 old_recv_credit_target;
1516 	u16 credits_requested;
1517 	u16 credits_granted;
1518 	u16 flags;
1519 	u32 data_offset;
1520 	u32 data_length;
1521 	u32 remaining_data_length;
1522 
1523 	if (unlikely(wc->status != IB_WC_SUCCESS || WARN_ON_ONCE(wc->opcode != IB_WC_RECV))) {
1524 		if (wc->status != IB_WC_WR_FLUSH_ERR)
1525 			smbdirect_log_rdma_recv(sc, SMBDIRECT_LOG_ERR,
1526 				"wc->status=%s (%d) wc->opcode=%d\n",
1527 				ib_wc_status_msg(wc->status), wc->status, wc->opcode);
1528 		goto error;
1529 	}
1530 
1531 	smbdirect_log_rdma_recv(sc, SMBDIRECT_LOG_INFO,
1532 		"recv_io=0x%p type=%d wc status=%s wc opcode %d byte_len=%d pkey_index=%u\n",
1533 		recv_io, sc->recv_io.expected,
1534 		ib_wc_status_msg(wc->status), wc->opcode,
1535 		wc->byte_len, wc->pkey_index);
1536 
1537 	/*
1538 	 * Reset timer to the keepalive interval in
1539 	 * order to trigger our next keepalive message.
1540 	 */
1541 	sc->idle.keepalive = SMBDIRECT_KEEPALIVE_NONE;
1542 	mod_delayed_work(sc->workqueues.idle, &sc->idle.timer_work,
1543 			 msecs_to_jiffies(sp->keepalive_interval_msec));
1544 
1545 	ib_dma_sync_single_for_cpu(sc->ib.dev,
1546 				   recv_io->sge.addr,
1547 				   recv_io->sge.length,
1548 				   DMA_FROM_DEVICE);
1549 
1550 	if (unlikely(wc->byte_len <
1551 	    offsetof(struct smbdirect_data_transfer, padding))) {
1552 		smbdirect_log_rdma_event(sc, SMBDIRECT_LOG_ERR,
1553 			"wc->byte_len=%u < %zu\n",
1554 			wc->byte_len,
1555 			offsetof(struct smbdirect_data_transfer, padding));
1556 		goto error;
1557 	}
1558 
1559 	data_transfer = (struct smbdirect_data_transfer *)recv_io->packet;
1560 	credits_requested = le16_to_cpu(data_transfer->credits_requested);
1561 	credits_granted = le16_to_cpu(data_transfer->credits_granted);
1562 	flags = le16_to_cpu(data_transfer->flags);
1563 	remaining_data_length = le32_to_cpu(data_transfer->remaining_data_length);
1564 	data_offset = le32_to_cpu(data_transfer->data_offset);
1565 	data_length = le32_to_cpu(data_transfer->data_length);
1566 
1567 	smbdirect_log_incoming(sc, SMBDIRECT_LOG_INFO,
1568 		"DataIn: %s=%u, %s=%u, %s=0x%x, %s=%u, %s=%u, %s=%u\n",
1569 		"CreditsRequested",
1570 		credits_requested,
1571 		"CreditsGranted",
1572 		credits_granted,
1573 		"Flags",
1574 		flags,
1575 		"RemainingDataLength",
1576 		remaining_data_length,
1577 		"DataOffset",
1578 		data_offset,
1579 		"DataLength",
1580 		data_length);
1581 
1582 	if (unlikely(credits_requested == 0)) {
1583 		smbdirect_log_rdma_event(sc, SMBDIRECT_LOG_ERR,
1584 			"invalid: credits_requested == 0\n");
1585 		goto error;
1586 	}
1587 
1588 	if (unlikely(data_offset % 8 != 0)) {
1589 		smbdirect_log_rdma_event(sc, SMBDIRECT_LOG_ERR,
1590 			"invalid: data_offset=%u (0x%x) not aligned to 8\n",
1591 			data_offset, data_offset);
1592 		goto error;
1593 	}
1594 
1595 	if (unlikely(wc->byte_len < data_offset ||
1596 	    (u64)wc->byte_len < (u64)data_offset + data_length)) {
1597 		smbdirect_log_rdma_event(sc, SMBDIRECT_LOG_ERR,
1598 			"wc->byte_len=%u < data_offset=%u + data_length=%u\n",
1599 			wc->byte_len, data_offset, data_length);
1600 		goto error;
1601 	}
1602 
1603 	if (unlikely(remaining_data_length > sp->max_fragmented_recv_size ||
1604 	    data_length > sp->max_fragmented_recv_size ||
1605 	    (u64)remaining_data_length + (u64)data_length > (u64)sp->max_fragmented_recv_size)) {
1606 		smbdirect_log_rdma_event(sc, SMBDIRECT_LOG_ERR,
1607 			"remaining_data_length=%u + data_length=%u > max_fragmented=%u\n",
1608 			remaining_data_length, data_length, sp->max_fragmented_recv_size);
1609 		goto error;
1610 	}
1611 
	if (data_length) {
		if (sc->recv_io.reassembly.full_packet_received)
			recv_io->first_segment = true;

		if (remaining_data_length)
			sc->recv_io.reassembly.full_packet_received = false;
		else
			sc->recv_io.reassembly.full_packet_received = true;
	}

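	/* This completion consumes one posted receive and one receive credit. */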
	atomic_dec(&sc->recv_io.posted.count);
	current_recv_credits = atomic_dec_return(&sc->recv_io.credits.count);

	/*
	 * We take the value from the peer, which is checked above to be
	 * greater than 0, but we limit it to the max value we support
	 * in order to keep the main logic simpler.
	 */
	old_recv_credit_target = sc->recv_io.credits.target;
	sc->recv_io.credits.target = credits_requested;
	sc->recv_io.credits.target = min_t(u16, sc->recv_io.credits.target,
					   sp->recv_credit_max);
	if (credits_granted) {
		atomic_add(credits_granted, &sc->send_io.credits.count);
		/*
		 * We have new send credits granted from the remote peer.
		 * If any sender is waiting for credits, unblock it.
		 */
		wake_up(&sc->send_io.credits.wait_queue);
	}

	/* Send an immediate response right away if requested */
	if (flags & SMBDIRECT_FLAG_RESPONSE_REQUESTED) {
		smbdirect_log_keep_alive(sc, SMBDIRECT_LOG_INFO,
			"schedule send of immediate response\n");
		queue_work(sc->workqueues.immediate, &sc->idle.immediate_work);
	}

	/*
	 * If this is a packet with a data payload, place the data in
	 * the reassembly queue and wake up the reading thread.
	 */
	if (data_length) {
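		/*
		 * Refill posted receives early: when fewer than a
		 * quarter of the target credits remain, or when the
		 * peer raised its requested target, kick the refill
		 * worker.
		 */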
		if (current_recv_credits <= (sc->recv_io.credits.target / 4) ||
		    sc->recv_io.credits.target > old_recv_credit_target)
			queue_work(sc->workqueues.refill, &sc->recv_io.posted.refill_work);

		smbdirect_connection_reassembly_append_recv_io(sc, recv_io, data_length);
		wake_up(&sc->recv_io.reassembly.wait_queue);
	} else {
		smbdirect_connection_put_recv_io(recv_io);
	}

	return;

error:
	/*
	 * Make sure smbdirect_connection_put_recv_io() does not
	 * start recv_io.posted.refill_work.
	 */
	disable_work(&sc->recv_io.posted.refill_work);
	smbdirect_connection_put_recv_io(recv_io);
	smbdirect_socket_schedule_cleanup(sc, -ECONNABORTED);
}

int smbdirect_connection_recv_io_refill(struct smbdirect_socket *sc)
{
	int missing;
	int posted = 0;

	if (unlikely(sc->first_error))
		return sc->first_error;

	/*
	 * Find out how many smbdirect_recv_io buffers we should post.
	 *
	 * Note that sc->recv_io.credits.target is the value
	 * from the peer and it can in theory change over time,
	 * but it is forced to be at least 1 and at most
	 * sp->recv_credit_max.
	 *
	 * So it can happen that 'missing' is lower than 0,
	 * which means the peer has recently lowered its desired
	 * target, while we already granted a higher number of credits.
	 *
	 * Note that 'posted' is the number of smbdirect_recv_io buffers
	 * posted within this function, while sc->recv_io.posted.count
	 * is the overall number of posted smbdirect_recv_io buffers.
	 *
	 * We try to post as many buffers as are missing, but
	 * this is limited if a lot of smbdirect_recv_io buffers
	 * are still in the sc->recv_io.reassembly.list instead of
	 * the sc->recv_io.free.list.
	 */
	missing = (int)sc->recv_io.credits.target - atomic_read(&sc->recv_io.posted.count);
	while (posted < missing) {
		struct smbdirect_recv_io *recv_io;
		int ret;

		/*
		 * It's ok if smbdirect_connection_get_recv_io()
		 * returns NULL; it means smbdirect_recv_io structures
		 * are still in the reassembly.list.
		 */
		recv_io = smbdirect_connection_get_recv_io(sc);
		if (!recv_io)
			break;

		recv_io->first_segment = false;

		ret = smbdirect_connection_post_recv_io(recv_io);
		if (ret) {
			smbdirect_log_rdma_recv(sc, SMBDIRECT_LOG_ERR,
				"smbdirect_connection_post_recv_io failed rc=%d (%1pe)\n",
				ret, SMBDIRECT_DEBUG_ERR_PTR(ret));
			smbdirect_connection_put_recv_io(recv_io);
			return ret;
		}

		atomic_inc(&sc->recv_io.posted.count);
		posted += 1;
	}

	/* If nothing was posted we're done */
	if (posted == 0)
		return 0;

	atomic_add(posted, &sc->recv_io.credits.available);

	/*
	 * If the last sender is blocked waiting for
	 * credits it can grant, we need to wake it up.
	 */
	if (atomic_read(&sc->send_io.bcredits.count) == 0 &&
	    atomic_read(&sc->send_io.credits.count) == 0)
		wake_up(&sc->send_io.credits.wait_queue);

	/*
	 * If we posted at least one smbdirect_recv_io buffer,
	 * we need to inform the peer about it and grant
	 * additional credits.
	 *
	 * However there is one case where we don't want to
	 * do that.
	 *
	 * If only a single credit was missing before
	 * reaching the requested target, we should not
	 * post an immediate send, as that would cause
	 * an endless ping-pong once a keepalive exchange
	 * is started.
	 *
	 * However if sc->recv_io.credits.target is only 1,
	 * the peer has no credit left and we need to
	 * grant the credit anyway.
	 */
	if (missing == 1 && sc->recv_io.credits.target != 1)
		return 0;

	return posted;
}

static void smbdirect_connection_recv_io_refill_work(struct work_struct *work)
{
	struct smbdirect_socket *sc =
		container_of(work, struct smbdirect_socket, recv_io.posted.refill_work);
	int posted;

	posted = smbdirect_connection_recv_io_refill(sc);
	if (unlikely(posted < 0)) {
		smbdirect_socket_schedule_cleanup(sc, posted);
		return;
	}
	if (posted > 0) {
		smbdirect_log_keep_alive(sc, SMBDIRECT_LOG_INFO,
			"schedule send of an empty message\n");
		queue_work(sc->workqueues.immediate, &sc->idle.immediate_work);
	}
}

int smbdirect_connection_recvmsg(struct smbdirect_socket *sc,
				 struct msghdr *msg,
				 unsigned int flags)
{
	struct smbdirect_recv_io *response;
	struct smbdirect_data_transfer *data_transfer;
	size_t size = iov_iter_count(&msg->msg_iter);
	int to_copy, to_read, data_read, offset;
	u32 data_length, remaining_data_length, data_offset;
	int ret;

	if (WARN_ONCE(flags, "unexpected flags=0x%x\n", flags))
		return -EINVAL; /* no flags support for now */

	if (WARN_ON_ONCE(iov_iter_rw(&msg->msg_iter) != ITER_DEST))
		return -EINVAL; /* It's a bug in the upper layer to get here */

again:
	if (sc->status != SMBDIRECT_SOCKET_CONNECTED) {
		smbdirect_log_read(sc, SMBDIRECT_LOG_INFO,
			"status=%s first_error=%1pe => %1pe\n",
			smbdirect_socket_status_string(sc->status),
			SMBDIRECT_DEBUG_ERR_PTR(sc->first_error),
			SMBDIRECT_DEBUG_ERR_PTR(-ENOTCONN));
		return -ENOTCONN;
	}

	/*
	 * No need to hold the reassembly queue lock all the time as we are
	 * the only one reading from the front of the queue. The transport
	 * may add more entries to the back of the queue at the same time
	 */
	smbdirect_log_read(sc, SMBDIRECT_LOG_INFO,
		"size=%zd sc->recv_io.reassembly.data_length=%d\n",
		size, sc->recv_io.reassembly.data_length);
	if (sc->recv_io.reassembly.data_length >= size) {
		int queue_length;
		int queue_removed = 0;
		unsigned long flags;

		/*
		 * Need to make sure reassembly.data_length is read before
		 * reading reassembly.queue_length and calling
		 * smbdirect_connection_reassembly_first_recv_io. This call is
		 * lock-free as we never read the end of the queue, which is
		 * being updated in softirq context as more data is received.
		 */
		virt_rmb();
		queue_length = sc->recv_io.reassembly.queue_length;
		data_read = 0;
		to_read = size;
		offset = sc->recv_io.reassembly.first_entry_offset;
		while (data_read < size) {
			response = smbdirect_connection_reassembly_first_recv_io(sc);
			data_transfer = (void *)response->packet;
			data_length = le32_to_cpu(data_transfer->data_length);
			remaining_data_length =
				le32_to_cpu(
					data_transfer->remaining_data_length);
			data_offset = le32_to_cpu(data_transfer->data_offset);

			/*
			 * The upper layer expects the RFC1002 length at the
			 * beginning of the payload. Return it to indicate the
			 * total length of the packet. This minimizes the
			 * change to the upper layer packet processing logic.
			 * This will eventually be removed when an
			 * intermediate transport layer is added.
			 */
			if (response->first_segment && size == 4) {
				unsigned int rfc1002_len =
					data_length + remaining_data_length;
				__be32 rfc1002_hdr = cpu_to_be32(rfc1002_len);

				if (copy_to_iter(&rfc1002_hdr, sizeof(rfc1002_hdr),
						 &msg->msg_iter) != sizeof(rfc1002_hdr))
					return -EFAULT;
				data_read = 4;
				response->first_segment = false;
				smbdirect_log_read(sc, SMBDIRECT_LOG_INFO,
					"returning rfc1002 length %d\n",
					rfc1002_len);
				goto read_rfc1002_done;
			}

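			/*
			 * Copy what is left of the current buffer, limited
			 * by what the caller still wants; 'offset' is how
			 * far into this buffer a previous read stopped.
			 */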
			to_copy = min_t(int, data_length - offset, to_read);
			if (copy_to_iter((u8 *)data_transfer + data_offset + offset,
					 to_copy, &msg->msg_iter) != to_copy)
				return -EFAULT;

			/* move on to the next buffer? */
			if (to_copy == data_length - offset) {
				queue_length--;
				/*
				 * No need to lock if we are not at the
				 * end of the queue
				 */
				if (queue_length) {
					list_del(&response->list);
				} else {
					spin_lock_irqsave(
						&sc->recv_io.reassembly.lock, flags);
					list_del(&response->list);
					spin_unlock_irqrestore(
						&sc->recv_io.reassembly.lock, flags);
				}
				queue_removed++;
				sc->statistics.dequeue_reassembly_queue++;
				smbdirect_connection_put_recv_io(response);
				offset = 0;
				smbdirect_log_read(sc, SMBDIRECT_LOG_INFO,
					"smbdirect_connection_put_recv_io offset=0\n");
			} else {
				offset += to_copy;
			}

			to_read -= to_copy;
			data_read += to_copy;

			smbdirect_log_read(sc, SMBDIRECT_LOG_INFO,
				 "memcpy %d bytes len-ofs=%u => todo=%u done=%u ofs=%u\n",
				 to_copy, data_length - offset,
				 to_read, data_read, offset);
		}

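		/*
		 * The receive completion path updates these counters from
		 * softirq context, so adjust them under the reassembly lock.
		 */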
		spin_lock_irqsave(&sc->recv_io.reassembly.lock, flags);
		sc->recv_io.reassembly.data_length -= data_read;
		sc->recv_io.reassembly.queue_length -= queue_removed;
		spin_unlock_irqrestore(&sc->recv_io.reassembly.lock, flags);

		sc->recv_io.reassembly.first_entry_offset = offset;
		smbdirect_log_read(sc, SMBDIRECT_LOG_INFO,
			 "returning data_read=%d reassembly_length=%d first_ofs=%u\n",
			 data_read, sc->recv_io.reassembly.data_length,
			 sc->recv_io.reassembly.first_entry_offset);
read_rfc1002_done:
		return data_read;
	}

	smbdirect_log_read(sc, SMBDIRECT_LOG_INFO,
		"wait_event on more data\n");
	ret = wait_event_interruptible(sc->recv_io.reassembly.wait_queue,
				       sc->recv_io.reassembly.data_length >= size ||
				       sc->status != SMBDIRECT_SOCKET_CONNECTED);
	/* Don't return any data if interrupted */
	if (ret)
		return ret;

	goto again;
}
__SMBDIRECT_EXPORT_SYMBOL__(smbdirect_connection_recvmsg);

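/*
 * Usage sketch (hypothetical caller, not part of this file): an upper
 * layer typically reads the 4-byte RFC1002 length first and then the
 * message payload. Given a connected socket 'sc', that could look like:
 *
 *	struct msghdr msg = {};
 *	struct kvec iov;
 *	__be32 len_hdr;
 *	int ret;
 *
 *	iov.iov_base = &len_hdr;
 *	iov.iov_len = sizeof(len_hdr);
 *	iov_iter_kvec(&msg.msg_iter, ITER_DEST, &iov, 1, sizeof(len_hdr));
 *	ret = smbdirect_connection_recvmsg(sc, &msg, 0);
 */
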
static bool smbdirect_map_sges_single_page(struct smbdirect_map_sges *state,
					   struct page *page, size_t off, size_t len)
{
	struct ib_sge *sge;
	u64 addr;

	if (state->num_sge >= state->max_sge)
		return false;

	addr = ib_dma_map_page(state->device, page,
			       off, len, state->direction);
	if (ib_dma_mapping_error(state->device, addr))
		return false;

	sge = &state->sge[state->num_sge++];
	sge->addr   = addr;
	sge->length = len;
	sge->lkey   = state->local_dma_lkey;

	return true;
}

/*
 * Extract page fragments from a BVEC-class iterator and add them to an ib_sge
 * list.  The pages are not pinned.
 */
static ssize_t smbdirect_map_sges_from_bvec(struct iov_iter *iter,
					    struct smbdirect_map_sges *state,
					    ssize_t maxsize)
{
	const struct bio_vec *bv = iter->bvec;
	unsigned long start = iter->iov_offset;
	unsigned int i;
	ssize_t ret = 0;

	for (i = 0; i < iter->nr_segs; i++) {
		size_t off, len;
		bool ok;

		len = bv[i].bv_len;
		if (start >= len) {
			start -= len;
			continue;
		}

		len = min_t(size_t, maxsize, len - start);
		off = bv[i].bv_offset + start;

		ok = smbdirect_map_sges_single_page(state,
						    bv[i].bv_page,
						    off,
						    len);
		if (!ok)
			return -EIO;

		ret += len;
		maxsize -= len;
		if (state->num_sge >= state->max_sge || maxsize <= 0)
			break;
		start = 0;
	}

	if (ret > 0)
		iov_iter_advance(iter, ret);
	return ret;
}

/*
 * Extract fragments from a KVEC-class iterator and add them to an ib_sge list.
 * This can deal with vmalloc'd buffers as well as kmalloc'd or static buffers.
 * The pages are not pinned.
 */
static ssize_t smbdirect_map_sges_from_kvec(struct iov_iter *iter,
					    struct smbdirect_map_sges *state,
					    ssize_t maxsize)
{
	const struct kvec *kv = iter->kvec;
	unsigned long start = iter->iov_offset;
	unsigned int i;
	ssize_t ret = 0;

	for (i = 0; i < iter->nr_segs; i++) {
		struct page *page;
		unsigned long kaddr;
		size_t off, len, seg;

		len = kv[i].iov_len;
		if (start >= len) {
			start -= len;
			continue;
		}

		kaddr = (unsigned long)kv[i].iov_base + start;
		off = kaddr & ~PAGE_MASK;
		len = min_t(size_t, maxsize, len - start);
		kaddr &= PAGE_MASK;

		maxsize -= len;
		do {
			bool ok;

			seg = min_t(size_t, len, PAGE_SIZE - off);

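			/*
			 * vmalloc and module space is not covered by the
			 * kernel linear mapping, so translate those
			 * addresses via the page tables; everything else
			 * can use the linear map directly.
			 */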
			if (is_vmalloc_or_module_addr((void *)kaddr))
				page = vmalloc_to_page((void *)kaddr);
			else
				page = virt_to_page((void *)kaddr);

			ok = smbdirect_map_sges_single_page(state, page, off, seg);
			if (!ok)
				return -EIO;

			ret += seg;
			len -= seg;
			kaddr += PAGE_SIZE;
			off = 0;
		} while (len > 0 && state->num_sge < state->max_sge);

		if (state->num_sge >= state->max_sge || maxsize <= 0)
			break;
		start = 0;
	}

	if (ret > 0)
		iov_iter_advance(iter, ret);
	return ret;
}

/*
 * Extract folio fragments from a FOLIOQ-class iterator and add them to an
 * ib_sge list.  The folios are not pinned.
 */
static ssize_t smbdirect_map_sges_from_folioq(struct iov_iter *iter,
					      struct smbdirect_map_sges *state,
					      ssize_t maxsize)
{
	const struct folio_queue *folioq = iter->folioq;
	unsigned int slot = iter->folioq_slot;
	ssize_t ret = 0;
	size_t offset = iter->iov_offset;

	if (WARN_ON_ONCE(!folioq))
		return -EIO;

	if (slot >= folioq_nr_slots(folioq)) {
		folioq = folioq->next;
		if (WARN_ON_ONCE(!folioq))
			return -EIO;
		slot = 0;
	}

	do {
		struct folio *folio = folioq_folio(folioq, slot);
		size_t fsize = folioq_folio_size(folioq, slot);

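		/*
		 * Note that for large folios 'offset' may exceed
		 * PAGE_SIZE; the DMA mapping is still done relative
		 * to the folio's first page.
		 */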
		if (offset < fsize) {
			size_t part = umin(maxsize, fsize - offset);
			bool ok;

			ok = smbdirect_map_sges_single_page(state,
							    folio_page(folio, 0),
							    offset,
							    part);
			if (!ok)
				return -EIO;

			offset += part;
			ret += part;
			maxsize -= part;
		}

		if (offset >= fsize) {
			offset = 0;
			slot++;
			if (slot >= folioq_nr_slots(folioq)) {
				if (!folioq->next) {
					WARN_ON_ONCE(ret < iter->count);
					break;
				}
				folioq = folioq->next;
				slot = 0;
			}
		}
	} while (state->num_sge < state->max_sge && maxsize > 0);

	iter->folioq = folioq;
	iter->folioq_slot = slot;
	iter->iov_offset = offset;
	iter->count -= ret;
	return ret;
}

/*
 * Extract page fragments from up to the given amount of the source iterator
 * and build up an ib_sge list that refers to all of those bits.  The ib_sge list
 * is appended to, up to the maximum number of elements set in the parameter
 * block.
 *
 * The extracted page fragments are not pinned or ref'd in any way; if an
 * IOVEC/UBUF-type iterator is to be used, it should be converted to a
 * BVEC-type iterator and the pages pinned, ref'd or otherwise held in some
 * way.
 */
static ssize_t smbdirect_map_sges_from_iter(struct iov_iter *iter, size_t len,
					    struct smbdirect_map_sges *state)
{
	ssize_t ret;
	size_t before = state->num_sge;

	if (WARN_ON_ONCE(iov_iter_rw(iter) != ITER_SOURCE))
		return -EIO;

	switch (iov_iter_type(iter)) {
	case ITER_BVEC:
		ret = smbdirect_map_sges_from_bvec(iter, state, len);
		break;
	case ITER_KVEC:
		ret = smbdirect_map_sges_from_kvec(iter, state, len);
		break;
	case ITER_FOLIOQ:
		ret = smbdirect_map_sges_from_folioq(iter, state, len);
		break;
	default:
		WARN_ONCE(1, "iov_iter_type[%u]\n", iov_iter_type(iter));
		return -EIO;
	}

	if (ret < 0) {
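		/*
		 * Unwind the DMA mappings created by this call; entries
		 * below 'before' belong to the caller and are kept.
		 */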
		while (state->num_sge > before) {
			struct ib_sge *sge = &state->sge[--state->num_sge];

			ib_dma_unmap_page(state->device,
					  sge->addr,
					  sge->length,
					  state->direction);
		}
	}

	return ret;
}