// SPDX-License-Identifier: GPL-2.0-or-later
/*
 *   Copyright (C) 2017, Microsoft Corporation.
 *   Copyright (c) 2025, Stefan Metzmacher
 */

#include "smbdirect_internal.h"

bool smbdirect_frwr_is_supported(const struct ib_device_attr *attrs)
{
	/*
	 * Check whether the device supports FRWR (Fast Registration
	 * Work Requests).  This implementation requires FRWR for RDMA
	 * read/write, so both the capability flag and a non-zero fast
	 * registration page list length must be present.
	 *
	 * Returns true if FRWR is supported.
	 */
	return (attrs->device_cap_flags & IB_DEVICE_MEM_MGT_EXTENSIONS) &&
	       attrs->max_fast_reg_page_list_len != 0;
}
__SMBDIRECT_EXPORT_SYMBOL__(smbdirect_frwr_is_supported);

static void smbdirect_socket_cleanup_work(struct work_struct *work);

/*
 * Placeholder RDMA CM event handler installed when the socket is created.
 *
 * It is supposed to be replaced with a real handler before any
 * connection work starts, so being invoked at all indicates a bug:
 * the event is logged loudly, a WARN is triggered and the cm_id
 * reference is dropped.
 */
static int smbdirect_socket_rdma_event_handler(struct rdma_cm_id *id,
					       struct rdma_cm_event *event)
{
	struct smbdirect_socket *sc = id->context;
	int ret = -ESTALE;

	/*
	 * This should be replaced before any real work
	 * starts! So it should never be called!
	 */

	/* Report a device removal as a network outage */
	if (event->event == RDMA_CM_EVENT_DEVICE_REMOVAL)
		ret = -ENETDOWN;
	/* Prefer the error status carried by the event itself, if any */
	if (IS_ERR(SMBDIRECT_DEBUG_ERR_PTR(event->status)))
		ret = event->status;
	pr_err("%s (first_error=%1pe, expected=%s) => event=%s status=%d => ret=%1pe\n",
		smbdirect_socket_status_string(sc->status),
		SMBDIRECT_DEBUG_ERR_PTR(sc->first_error),
		rdma_event_msg(sc->rdma.expected_event),
		rdma_event_msg(event->event),
		event->status,
		SMBDIRECT_DEBUG_ERR_PTR(ret));
	WARN_ONCE(1, "%s should not be called!\n", __func__);
	/*
	 * NOTE(review): 'ret' is only used for the log message above;
	 * the function unconditionally returns -ESTALE — presumably so
	 * the RDMA CM core destroys the id (which is why cm_id is
	 * cleared here).  Confirm the -ESTALE vs 'ret' split is
	 * intentional.
	 */
	sc->rdma.cm_id = NULL;
	return -ESTALE;
}

/*
 * Initialize a caller-provided socket for active (connecting) use:
 * reset the socket state, create a fresh rdma_cm_id bound to the
 * placeholder event handler and arm the cleanup work item.
 *
 * Returns 0 on success or a negative errno; on failure the cm_id is
 * not kept and sc->rdma.cm_id stays untouched.
 */
int smbdirect_socket_init_new(struct net *net, struct smbdirect_socket *sc)
{
	struct rdma_cm_id *cm_id;
	int ret;

	smbdirect_socket_init(sc);

	cm_id = rdma_create_id(net,
			       smbdirect_socket_rdma_event_handler,
			       sc,
			       RDMA_PS_TCP,
			       IB_QPT_RC);
	if (IS_ERR(cm_id)) {
		pr_err("%s: rdma_create_id() failed %1pe\n", __func__, cm_id);
		return PTR_ERR(cm_id);
	}

	/* Restrict the id to the bound address family only */
	ret = rdma_set_afonly(cm_id, 1);
	if (ret) {
		pr_err("%s: rdma_set_afonly() failed %1pe\n",
		       __func__, SMBDIRECT_DEBUG_ERR_PTR(ret));
		rdma_destroy_id(cm_id);
		return ret;
	}

	sc->rdma.cm_id = cm_id;

	INIT_WORK(&sc->disconnect_work, smbdirect_socket_cleanup_work);

	return 0;
}

/*
 * Allocate and initialize a new kernel-owned smbdirect socket for
 * active (connecting) use.  On success *_sc holds the new socket with
 * its 'destroy' refcount initialized to 1; on failure *_sc is left
 * untouched and a negative errno is returned.
 */
int smbdirect_socket_create_kern(struct net *net, struct smbdirect_socket **_sc)
{
	struct smbdirect_socket *sc;
	int ret;

	sc = kzalloc_obj(*sc);
	if (!sc)
		return -ENOMEM;

	ret = smbdirect_socket_init_new(net, sc);
	if (ret) {
		kfree(sc);
		return ret;
	}

	kref_init(&sc->refs.destroy);

	*_sc = sc;
	return 0;
}
__SMBDIRECT_EXPORT_SYMBOL__(smbdirect_socket_create_kern);

int smbdirect_socket_init_accepting(struct rdma_cm_id *id, struct smbdirect_socket *sc)
{
	smbdirect_socket_init(sc);

	sc->rdma.cm_id = id;
	sc->rdma.cm_id->context = sc;
	sc->rdma.cm_id->event_handler = smbdirect_socket_rdma_event_handler;

	sc->ib.dev = sc->rdma.cm_id->device;

	INIT_WORK(&sc->disconnect_work, smbdirect_socket_cleanup_work);

	return 0;
}

/*
 * Allocate and initialize a new smbdirect socket around an incoming
 * rdma_cm_id.  On success *_sc holds the new socket with its 'destroy'
 * refcount initialized to 1; on failure *_sc is left untouched and a
 * negative errno is returned.
 */
int smbdirect_socket_create_accepting(struct rdma_cm_id *id, struct smbdirect_socket **_sc)
{
	struct smbdirect_socket *sc;
	int ret;

	sc = kzalloc_obj(*sc);
	if (!sc)
		return -ENOMEM;

	ret = smbdirect_socket_init_accepting(id, sc);
	if (ret) {
		kfree(sc);
		return ret;
	}

	kref_init(&sc->refs.destroy);

	*_sc = sc;
	return 0;
}
__SMBDIRECT_EXPORT_SYMBOL__(smbdirect_socket_create_accepting);

/*
 * Validate and store the caller's connection parameters.
 *
 * Only allowed while the socket is still in the CREATED state (before
 * connect or accept).  Returns -EINVAL on any invalid combination,
 * 0 on success.  A copy of *sp is kept; the caller's structure is not
 * referenced afterwards.
 */
int smbdirect_socket_set_initial_parameters(struct smbdirect_socket *sc,
					    const struct smbdirect_socket_parameters *sp)
{
	/*
	 * This is only allowed before connect or accept
	 */
	WARN_ONCE(sc->status != SMBDIRECT_SOCKET_CREATED,
		  "status=%s first_error=%1pe",
		  smbdirect_socket_status_string(sc->status),
		  SMBDIRECT_DEBUG_ERR_PTR(sc->first_error));
	if (sc->status != SMBDIRECT_SOCKET_CREATED)
		return -EINVAL;

	/* Reject unknown flags */
	if (sp->flags & ~SMBDIRECT_FLAG_PORT_RANGE_MASK)
		return -EINVAL;

	/* Both values are exchanged as single bytes on the wire */
	if (sp->initiator_depth > U8_MAX || sp->responder_resources > U8_MAX)
		return -EINVAL;

	/* The ONLY_IB and ONLY_IW restrictions are mutually exclusive */
	if ((sp->flags & SMBDIRECT_FLAG_PORT_RANGE_ONLY_IB) &&
	    (sp->flags & SMBDIRECT_FLAG_PORT_RANGE_ONLY_IW))
		return -EINVAL;

	if (sp->flags & SMBDIRECT_FLAG_PORT_RANGE_ONLY_IB)
		rdma_restrict_node_type(sc->rdma.cm_id, RDMA_NODE_IB_CA);
	else if (sp->flags & SMBDIRECT_FLAG_PORT_RANGE_ONLY_IW)
		rdma_restrict_node_type(sc->rdma.cm_id, RDMA_NODE_RNIC);

	/*
	 * Make a copy of the callers parameters
	 * from here we only work on the copy
	 *
	 * TODO: do we want consistency checking?
	 */
	sc->parameters = *sp;

	return 0;
}
__SMBDIRECT_EXPORT_SYMBOL__(smbdirect_socket_set_initial_parameters);

/*
 * Return a read-only view of the socket's current parameters.
 * The pointer refers to the socket's internal copy and stays valid
 * for the lifetime of the socket.
 */
const struct smbdirect_socket_parameters *
smbdirect_socket_get_current_parameters(struct smbdirect_socket *sc)
{
	return &sc->parameters;
}
__SMBDIRECT_EXPORT_SYMBOL__(smbdirect_socket_get_current_parameters);

/*
 * Configure kernel-side settings: the IB completion-queue poll context
 * and the gfp mask used by the send/recv/rw memory pools.
 *
 * Only allowed while the socket is still in the CREATED state (before
 * connect or accept); returns -EINVAL otherwise, 0 on success.
 */
int smbdirect_socket_set_kernel_settings(struct smbdirect_socket *sc,
					 enum ib_poll_context poll_ctx,
					 gfp_t gfp_mask)
{
	/*
	 * This is only allowed before connect or accept
	 */
	WARN_ONCE(sc->status != SMBDIRECT_SOCKET_CREATED,
		  "status=%s first_error=%1pe",
		  smbdirect_socket_status_string(sc->status),
		  SMBDIRECT_DEBUG_ERR_PTR(sc->first_error));
	if (sc->status != SMBDIRECT_SOCKET_CREATED)
		return -EINVAL;

	/* All three I/O memory pools share the same allocation mask */
	sc->send_io.mem.gfp_mask = gfp_mask;
	sc->recv_io.mem.gfp_mask = gfp_mask;
	sc->rw_io.mem.gfp_mask = gfp_mask;

	sc->ib.poll_ctx = poll_ctx;

	return 0;
}
__SMBDIRECT_EXPORT_SYMBOL__(smbdirect_socket_set_kernel_settings);

/*
 * Install the caller's logging hooks.
 *
 * @private_ptr: opaque cookie passed back to both callbacks.
 * @needed:      fast check whether a message of the given level/class
 *               should be produced at all.
 * @vaprintf:    actually emits the formatted message.
 *
 * NOTE(review): the pointers are published without any locking or
 * barriers — assumes this is only called before the socket is used
 * concurrently; confirm with callers.
 */
void smbdirect_socket_set_logging(struct smbdirect_socket *sc,
				  void *private_ptr,
				  bool (*needed)(struct smbdirect_socket *sc,
						 void *private_ptr,
						 unsigned int lvl,
						 unsigned int cls),
				  void (*vaprintf)(struct smbdirect_socket *sc,
						   const char *func,
						   unsigned int line,
						   void *private_ptr,
						   unsigned int lvl,
						   unsigned int cls,
						   struct va_format *vaf))
{
	sc->logging.private_ptr = private_ptr;
	sc->logging.needed = needed;
	sc->logging.vaprintf = vaprintf;
}
__SMBDIRECT_EXPORT_SYMBOL__(smbdirect_socket_set_logging);

static void smbdirect_socket_wake_up_all(struct smbdirect_socket *sc)
{
	/*
	 * Wake up all waiters in all wait queues
	 * in order to notice the broken connection.
	 */
	wake_up_all(&sc->status_wait);
	wake_up_all(&sc->listen.wait_queue);
	wake_up_all(&sc->send_io.bcredits.wait_queue);
	wake_up_all(&sc->send_io.lcredits.wait_queue);
	wake_up_all(&sc->send_io.credits.wait_queue);
	wake_up_all(&sc->send_io.pending.zero_wait_queue);
	wake_up_all(&sc->recv_io.reassembly.wait_queue);
	wake_up_all(&sc->rw_io.credits.wait_queue);
	wake_up_all(&sc->mr_io.ready.wait_queue);
}

/*
 * Record the first error on the socket, transition its status to the
 * matching failure/disconnect state and queue the asynchronous cleanup
 * work.  Safe to call multiple times; only the first caller logs and
 * records the error.  Typically invoked via the
 * smbdirect_socket_schedule_cleanup*() macros, which supply
 * macro_name/func/line for diagnostics.
 *
 * @lvl:          log level for the (first) log message.
 * @error:        errno to record; 0 is mapped to -ECONNABORTED.
 * @force_status: optional status to force (see below), may be NULL.
 */
void __smbdirect_socket_schedule_cleanup(struct smbdirect_socket *sc,
					 const char *macro_name,
					 unsigned int lvl,
					 const char *func,
					 unsigned int line,
					 int error,
					 enum smbdirect_socket_status *force_status)
{
	struct smbdirect_socket *psc, *tsc;
	unsigned long flags;
	bool was_first = false;

	/* Only the very first error is logged and recorded */
	if (!sc->first_error) {
		___smbdirect_log_generic(sc, func, line,
			lvl,
			SMBDIRECT_LOG_RDMA_EVENT,
			"%s(%1pe%s%s) called from %s in line=%u status=%s\n",
			macro_name,
			SMBDIRECT_DEBUG_ERR_PTR(error),
			force_status ? ", " : "",
			force_status ? smbdirect_socket_status_string(*force_status) : "",
			func, line,
			smbdirect_socket_status_string(sc->status));
		if (error)
			sc->first_error = error;
		else
			sc->first_error = -ECONNABORTED;
		was_first = true;
	}

	/*
	 * make sure other work (than disconnect_work)
	 * is not queued again but here we don't block and avoid
	 * disable[_delayed]_work_sync()
	 */
	disable_work(&sc->connect.work);
	disable_work(&sc->recv_io.posted.refill_work);
	disable_work(&sc->idle.immediate_work);
	sc->idle.keepalive = SMBDIRECT_KEEPALIVE_NONE;
	disable_delayed_work(&sc->idle.timer_work);

	/*
	 * In case we were a listener we need to
	 * disconnect all pending and ready sockets
	 *
	 * First we move ready sockets to pending again.
	 */
	spin_lock_irqsave(&sc->listen.lock, flags);
	list_splice_init(&sc->listen.ready, &sc->listen.pending);
	list_for_each_entry_safe(psc, tsc, &sc->listen.pending, accept.list)
		smbdirect_socket_schedule_cleanup(psc, sc->first_error);
	spin_unlock_irqrestore(&sc->listen.lock, flags);

	/*
	 * Map the current state to the failure state of the phase the
	 * socket was in when the error happened.
	 */
	switch (sc->status) {
	case SMBDIRECT_SOCKET_RESOLVE_ADDR_FAILED:
	case SMBDIRECT_SOCKET_RESOLVE_ROUTE_FAILED:
	case SMBDIRECT_SOCKET_RDMA_CONNECT_FAILED:
	case SMBDIRECT_SOCKET_NEGOTIATE_FAILED:
	case SMBDIRECT_SOCKET_ERROR:
	case SMBDIRECT_SOCKET_DISCONNECTING:
	case SMBDIRECT_SOCKET_DISCONNECTED:
	case SMBDIRECT_SOCKET_DESTROYED:
		/*
		 * Keep the current error status
		 */
		break;

	case SMBDIRECT_SOCKET_RESOLVE_ADDR_NEEDED:
	case SMBDIRECT_SOCKET_RESOLVE_ADDR_RUNNING:
		sc->status = SMBDIRECT_SOCKET_RESOLVE_ADDR_FAILED;
		break;

	case SMBDIRECT_SOCKET_RESOLVE_ROUTE_NEEDED:
	case SMBDIRECT_SOCKET_RESOLVE_ROUTE_RUNNING:
		sc->status = SMBDIRECT_SOCKET_RESOLVE_ROUTE_FAILED;
		break;

	case SMBDIRECT_SOCKET_RDMA_CONNECT_NEEDED:
	case SMBDIRECT_SOCKET_RDMA_CONNECT_RUNNING:
		sc->status = SMBDIRECT_SOCKET_RDMA_CONNECT_FAILED;
		break;

	case SMBDIRECT_SOCKET_NEGOTIATE_NEEDED:
	case SMBDIRECT_SOCKET_NEGOTIATE_RUNNING:
		sc->status = SMBDIRECT_SOCKET_NEGOTIATE_FAILED;
		break;

	case SMBDIRECT_SOCKET_CREATED:
	case SMBDIRECT_SOCKET_LISTENING:
		sc->status = SMBDIRECT_SOCKET_DISCONNECTED;
		break;

	case SMBDIRECT_SOCKET_CONNECTED:
		sc->status = SMBDIRECT_SOCKET_ERROR;
		break;
	}

	/*
	 * A forced status wins if this call recorded the first error or
	 * if it is further along than the mapped status.
	 * NOTE(review): the '>' comparison relies on the numeric
	 * ordering of enum smbdirect_socket_status — confirm the enum
	 * is ordered by connection progress.
	 */
	if (force_status && (was_first || *force_status > sc->status))
		sc->status = *force_status;

	/*
	 * Wake up all waiters in all wait queues
	 * in order to notice the broken connection.
	 */
	smbdirect_socket_wake_up_all(sc);

	queue_work(sc->workqueues.cleanup, &sc->disconnect_work);
}

/*
 * Work-queue handler (sc->disconnect_work) performing the asynchronous
 * part of a disconnect: disable all other work items, propagate the
 * error to any pending accepted sockets, issue rdma_disconnect() if
 * the connection got far enough, and wake all waiters.
 *
 * Runs in process context (queued by
 * __smbdirect_socket_schedule_cleanup()).
 */
static void smbdirect_socket_cleanup_work(struct work_struct *work)
{
	struct smbdirect_socket *sc =
		container_of(work, struct smbdirect_socket, disconnect_work);
	struct smbdirect_socket *psc, *tsc;
	unsigned long flags;

	/*
	 * This should never be called in an interrupt!
	 */
	WARN_ON_ONCE(in_interrupt());

	/* The scheduler should have recorded an error already */
	if (!sc->first_error) {
		smbdirect_log_rdma_event(sc, SMBDIRECT_LOG_ERR,
			"%s called with first_error==0\n",
			smbdirect_socket_status_string(sc->status));

		sc->first_error = -ECONNABORTED;
	}

	/*
	 * make sure this and other work is not queued again
	 * but here we don't block and avoid
	 * disable[_delayed]_work_sync()
	 */
	disable_work(&sc->disconnect_work);
	disable_work(&sc->connect.work);
	disable_work(&sc->recv_io.posted.refill_work);
	disable_work(&sc->idle.immediate_work);
	sc->idle.keepalive = SMBDIRECT_KEEPALIVE_NONE;
	disable_delayed_work(&sc->idle.timer_work);

	/*
	 * In case we were a listener we need to
	 * disconnect all pending and ready sockets
	 *
	 * First we move ready sockets to pending again.
	 */
	spin_lock_irqsave(&sc->listen.lock, flags);
	list_splice_init(&sc->listen.ready, &sc->listen.pending);
	list_for_each_entry_safe(psc, tsc, &sc->listen.pending, accept.list)
		smbdirect_socket_schedule_cleanup(psc, sc->first_error);
	spin_unlock_irqrestore(&sc->listen.lock, flags);

	switch (sc->status) {
	case SMBDIRECT_SOCKET_NEGOTIATE_NEEDED:
	case SMBDIRECT_SOCKET_NEGOTIATE_RUNNING:
	case SMBDIRECT_SOCKET_NEGOTIATE_FAILED:
	case SMBDIRECT_SOCKET_CONNECTED:
	case SMBDIRECT_SOCKET_ERROR:
		sc->status = SMBDIRECT_SOCKET_DISCONNECTING;
		/*
		 * Make sure we hold the callback lock
		 * in order to coordinate with the
		 * rdma_event handlers, typically
		 * smbdirect_connection_rdma_event_handler(),
		 * and smbdirect_socket_destroy().
		 *
		 * So that the order of ib_drain_qp()
		 * and rdma_disconnect() is controlled
		 * by the mutex.
		 */
		rdma_lock_handler(sc->rdma.cm_id);
		rdma_disconnect(sc->rdma.cm_id);
		rdma_unlock_handler(sc->rdma.cm_id);
		break;

	case SMBDIRECT_SOCKET_CREATED:
	case SMBDIRECT_SOCKET_LISTENING:
	case SMBDIRECT_SOCKET_RESOLVE_ADDR_NEEDED:
	case SMBDIRECT_SOCKET_RESOLVE_ADDR_RUNNING:
	case SMBDIRECT_SOCKET_RESOLVE_ADDR_FAILED:
	case SMBDIRECT_SOCKET_RESOLVE_ROUTE_NEEDED:
	case SMBDIRECT_SOCKET_RESOLVE_ROUTE_RUNNING:
	case SMBDIRECT_SOCKET_RESOLVE_ROUTE_FAILED:
	case SMBDIRECT_SOCKET_RDMA_CONNECT_NEEDED:
	case SMBDIRECT_SOCKET_RDMA_CONNECT_RUNNING:
	case SMBDIRECT_SOCKET_RDMA_CONNECT_FAILED:
		/*
		 * rdma_{accept,connect}() never reached
		 * RDMA_CM_EVENT_ESTABLISHED
		 */
		sc->status = SMBDIRECT_SOCKET_DISCONNECTED;
		break;

	case SMBDIRECT_SOCKET_DISCONNECTING:
	case SMBDIRECT_SOCKET_DISCONNECTED:
	case SMBDIRECT_SOCKET_DESTROYED:
		break;
	}

	/*
	 * Wake up all waiters in all wait queues
	 * in order to notice the broken connection.
	 */
	smbdirect_socket_wake_up_all(sc);
}

/*
 * Final teardown of a socket: drain the QP, release all pending
 * accepted sockets, flush the reassembly queue, destroy MRs, QP,
 * cm_id and memory pools, and move the status to DESTROYED.
 *
 * Expects the socket to be DISCONNECTED already (warns otherwise) and
 * is a no-op when called on an already DESTROYED socket.  Must run in
 * process context.
 */
static void smbdirect_socket_destroy(struct smbdirect_socket *sc)
{
	struct smbdirect_socket *psc, *tsc;
	size_t psockets;
	struct smbdirect_recv_io *recv_io;
	struct smbdirect_recv_io *recv_tmp;
	LIST_HEAD(all_list);
	unsigned long flags;

	smbdirect_log_rdma_event(sc, SMBDIRECT_LOG_INFO,
		"status=%s first_error=%1pe",
		smbdirect_socket_status_string(sc->status),
		SMBDIRECT_DEBUG_ERR_PTR(sc->first_error));

	/*
	 * This should never be called in an interrupt!
	 */
	WARN_ON_ONCE(in_interrupt());

	if (sc->status == SMBDIRECT_SOCKET_DESTROYED)
		return;

	WARN_ONCE(sc->status != SMBDIRECT_SOCKET_DISCONNECTED,
		  "status=%s first_error=%1pe",
		  smbdirect_socket_status_string(sc->status),
		  SMBDIRECT_DEBUG_ERR_PTR(sc->first_error));

	/*
	 * The listener should clear this before we reach this
	 */
	WARN_ONCE(sc->accept.listener,
		  "status=%s first_error=%1pe",
		  smbdirect_socket_status_string(sc->status),
		  SMBDIRECT_DEBUG_ERR_PTR(sc->first_error));

	/*
	 * Wake up all waiters in all wait queues
	 * in order to notice the broken connection.
	 *
	 * Most likely this was already called via
	 * smbdirect_socket_cleanup_work(), but call it again...
	 */
	smbdirect_socket_wake_up_all(sc);

	/* From here on we may block: fully disable all work items */
	disable_work_sync(&sc->disconnect_work);
	disable_work_sync(&sc->connect.work);
	disable_work_sync(&sc->recv_io.posted.refill_work);
	disable_work_sync(&sc->idle.immediate_work);
	disable_delayed_work_sync(&sc->idle.timer_work);

	/*
	 * Hold the callback lock so the drain/destroy sequence can't
	 * race with rdma_cm event handlers (see the matching comment in
	 * smbdirect_socket_cleanup_work()).
	 */
	if (sc->rdma.cm_id)
		rdma_lock_handler(sc->rdma.cm_id);

	if (sc->ib.qp) {
		smbdirect_log_rdma_event(sc, SMBDIRECT_LOG_INFO,
			"drain qp\n");
		ib_drain_qp(sc->ib.qp);
	}

	/*
	 * In case we were a listener we need to
	 * disconnect all pending and ready sockets
	 *
	 * We move ready sockets to pending again.
	 */
	spin_lock_irqsave(&sc->listen.lock, flags);
	list_splice_tail_init(&sc->listen.ready, &all_list);
	list_splice_tail_init(&sc->listen.pending, &all_list);
	spin_unlock_irqrestore(&sc->listen.lock, flags);
	psockets = list_count_nodes(&all_list);
	if (sc->listen.backlog != -1) /* was a listener */
		smbdirect_log_rdma_event(sc, SMBDIRECT_LOG_INFO,
			"release %zu pending sockets\n", psockets);
	list_for_each_entry_safe(psc, tsc, &all_list, accept.list) {
		list_del_init(&psc->accept.list);
		psc->accept.listener = NULL;
		smbdirect_socket_release(psc);
	}
	if (sc->listen.backlog != -1) /* was a listener */
		smbdirect_log_rdma_event(sc, SMBDIRECT_LOG_INFO,
			"released %zu pending sockets\n", psockets);
	/* all_list is reused below for the reassembly queue */
	INIT_LIST_HEAD(&all_list);

	/* It's not possible for upper layer to get to reassembly */
	if (sc->listen.backlog == -1) /* was not a listener */
		smbdirect_log_rdma_event(sc, SMBDIRECT_LOG_INFO,
			"drain the reassembly queue\n");
	spin_lock_irqsave(&sc->recv_io.reassembly.lock, flags);
	list_splice_tail_init(&sc->recv_io.reassembly.list, &all_list);
	spin_unlock_irqrestore(&sc->recv_io.reassembly.lock, flags);
	list_for_each_entry_safe(recv_io, recv_tmp, &all_list, list)
		smbdirect_connection_put_recv_io(recv_io);
	sc->recv_io.reassembly.data_length = 0;

	if (sc->listen.backlog == -1) /* was not a listener */
		smbdirect_log_rdma_event(sc, SMBDIRECT_LOG_INFO,
			"freeing mr list\n");
	smbdirect_connection_destroy_mr_list(sc);

	if (sc->listen.backlog == -1) /* was not a listener */
		smbdirect_log_rdma_event(sc, SMBDIRECT_LOG_INFO,
			"destroying qp\n");
	smbdirect_connection_destroy_qp(sc);
	if (sc->rdma.cm_id) {
		/* Drop the handler lock before destroying the id */
		rdma_unlock_handler(sc->rdma.cm_id);
		smbdirect_log_rdma_event(sc, SMBDIRECT_LOG_INFO,
			"destroying cm_id\n");
		rdma_destroy_id(sc->rdma.cm_id);
		sc->rdma.cm_id = NULL;
	}

	if (sc->listen.backlog == -1) /* was not a listener */
		smbdirect_log_rdma_event(sc, SMBDIRECT_LOG_INFO,
			"destroying mem pools\n");
	smbdirect_connection_destroy_mem_pools(sc);

	sc->status = SMBDIRECT_SOCKET_DESTROYED;

	smbdirect_log_rdma_event(sc, SMBDIRECT_LOG_INFO,
		"rdma session destroyed\n");
}

/*
 * Synchronously disconnect and destroy a socket: record a shutdown
 * error if none is set, run/flush the cleanup work, wait until the
 * socket reaches DISCONNECTED and then perform the final teardown via
 * smbdirect_socket_destroy().  Must run in process context.
 */
void smbdirect_socket_destroy_sync(struct smbdirect_socket *sc)
{
	smbdirect_log_rdma_event(sc, SMBDIRECT_LOG_INFO,
		"status=%s first_error=%1pe",
		smbdirect_socket_status_string(sc->status),
		SMBDIRECT_DEBUG_ERR_PTR(sc->first_error));

	/*
	 * This should never be called in an interrupt!
	 */
	WARN_ON_ONCE(in_interrupt());

	/*
	 * First we try to disable the work
	 * without disable_work_sync() in a
	 * non blocking way, if it's already
	 * running it will be handled by
	 * disable_work_sync() below.
	 *
	 * Here we just want to make sure queue_work() in
	 * smbdirect_socket_schedule_cleanup_lvl()
	 * is a no-op.
	 */
	disable_work(&sc->disconnect_work);

	if (!sc->first_error)
		/*
		 * SMBDIRECT_LOG_INFO is enough here
		 * as this is the typical case where
		 * we terminate the connection ourselves.
		 */
		smbdirect_socket_schedule_cleanup_lvl(sc,
						      SMBDIRECT_LOG_INFO,
						      -ESHUTDOWN);

	smbdirect_log_rdma_event(sc, SMBDIRECT_LOG_INFO,
		"cancelling and disable disconnect_work\n");
	disable_work_sync(&sc->disconnect_work);

	smbdirect_log_rdma_event(sc, SMBDIRECT_LOG_INFO,
		"destroying rdma session\n");
	/*
	 * The work was disabled above, so run the cleanup directly if
	 * the socket has not started disconnecting yet.
	 * NOTE(review): the '<' comparisons rely on the numeric
	 * ordering of enum smbdirect_socket_status — confirm the enum
	 * is ordered by connection progress.
	 */
	if (sc->status < SMBDIRECT_SOCKET_DISCONNECTING)
		smbdirect_socket_cleanup_work(&sc->disconnect_work);
	if (sc->status < SMBDIRECT_SOCKET_DISCONNECTED) {
		smbdirect_log_rdma_event(sc, SMBDIRECT_LOG_INFO,
			"wait for transport being disconnected\n");
		wait_event(sc->status_wait, sc->status == SMBDIRECT_SOCKET_DISCONNECTED);
		smbdirect_log_rdma_event(sc, SMBDIRECT_LOG_INFO,
			"waited for transport being disconnected\n");
	}

	/*
	 * Once we reached SMBDIRECT_SOCKET_DISCONNECTED,
	 * we should call smbdirect_socket_destroy()
	 */
	smbdirect_socket_destroy(sc);
	smbdirect_log_rdma_event(sc, SMBDIRECT_LOG_INFO,
		"status=%s first_error=%1pe",
		smbdirect_socket_status_string(sc->status),
		SMBDIRECT_DEBUG_ERR_PTR(sc->first_error));
}

/*
 * Bind the socket's rdma_cm_id to a local address.
 *
 * Only valid on a freshly created socket (before connect or accept);
 * returns -EINVAL otherwise, or the result of rdma_bind_addr().
 */
int smbdirect_socket_bind(struct smbdirect_socket *sc, struct sockaddr *addr)
{
	if (sc->status != SMBDIRECT_SOCKET_CREATED)
		return -EINVAL;

	return rdma_bind_addr(sc->rdma.cm_id, addr);
}
__SMBDIRECT_EXPORT_SYMBOL__(smbdirect_socket_bind);

/*
 * Initiate an orderly shutdown: records -ESHUTDOWN (logged at INFO
 * level) and schedules the asynchronous cleanup work.
 */
void smbdirect_socket_shutdown(struct smbdirect_socket *sc)
{
	smbdirect_socket_schedule_cleanup_lvl(sc, SMBDIRECT_LOG_INFO, -ESHUTDOWN);
}
__SMBDIRECT_EXPORT_SYMBOL__(smbdirect_socket_shutdown);

/*
 * kref release callback for refs.disconnect: runs when the last
 * disconnect reference is dropped.  Does not free the socket; that is
 * handled by the separate refs.destroy refcount.
 */
static void smbdirect_socket_release_disconnect(struct kref *kref)
{
	struct smbdirect_socket *sc =
		container_of(kref, struct smbdirect_socket, refs.disconnect);

	/*
	 * For now do a sync disconnect/destroy
	 */
	smbdirect_socket_destroy_sync(sc);
}

/*
 * kref release callback for refs.destroy: runs when the last destroy
 * reference is dropped and frees the socket memory.
 */
static void smbdirect_socket_release_destroy(struct kref *kref)
{
	struct smbdirect_socket *sc =
		container_of(kref, struct smbdirect_socket, refs.destroy);

	/*
	 * Do a sync disconnect/destroy...
	 * hopefully a no-op, as it should be already
	 * in DESTROYED state, before we free the memory.
	 */
	smbdirect_socket_destroy_sync(sc);
	kfree(sc);
}

/*
 * Release a socket: drop the (single) disconnect reference, which
 * triggers the synchronous disconnect/destroy, then drop a destroy
 * reference, which frees the memory unless the socket is embedded in
 * a larger object.
 */
void smbdirect_socket_release(struct smbdirect_socket *sc)
{
	/*
	 * We expect only 1 disconnect reference
	 * and if it is already 0, it's a use after free!
	 */
	WARN_ON_ONCE(kref_read(&sc->refs.disconnect) != 1);
	/* kref_put() returning 0 here would mean the put didn't release */
	WARN_ON(!kref_put(&sc->refs.disconnect, smbdirect_socket_release_disconnect));

	/*
	 * This may not trigger smbdirect_socket_release_destroy(),
	 * if struct smbdirect_socket is embedded in another structure
	 * indicated by REFCOUNT_MAX.
	 */
	kref_put(&sc->refs.destroy, smbdirect_socket_release_destroy);
}
__SMBDIRECT_EXPORT_SYMBOL__(smbdirect_socket_release);

/*
 * Atomically take 'needed' credits from *total_credits, sleeping
 * (interruptibly) on *waitq until enough credits are available.
 *
 * Returns 0 once the credits are taken, 'unexpected_errno' if the
 * socket left 'expected_status' while waiting (checked first, so a
 * status change wins over a pending signal), or a negative errno from
 * wait_event_interruptible() when interrupted by a signal.
 */
int smbdirect_socket_wait_for_credits(struct smbdirect_socket *sc,
				      enum smbdirect_socket_status expected_status,
				      int unexpected_errno,
				      wait_queue_head_t *waitq,
				      atomic_t *total_credits,
				      int needed)
{
	int rc;

	if (WARN_ON_ONCE(needed < 0))
		return -EINVAL;

	for (;;) {
		/* Optimistically grab the credits ... */
		if (atomic_sub_return(needed, total_credits) >= 0)
			return 0;

		/* ... and give them back if we overdrew the pool */
		atomic_add(needed, total_credits);

		rc = wait_event_interruptible(*waitq,
					      atomic_read(total_credits) >= needed ||
					      sc->status != expected_status);

		/* A broken/changed connection takes precedence over signals */
		if (sc->status != expected_status)
			return unexpected_errno;
		if (rc < 0)
			return rc;
	}
}
