1 // SPDX-License-Identifier: GPL-2.0 OR Linux-OpenIB
2 /*
3  * Copyright (c) 2022-2023 Fujitsu Ltd. All rights reserved.
4  */
5 
6 #include <linux/hmm.h>
7 
8 #include <rdma/ib_umem_odp.h>
9 
10 #include "rxe.h"
11 
rxe_ib_invalidate_range(struct mmu_interval_notifier * mni,const struct mmu_notifier_range * range,unsigned long cur_seq)12 static bool rxe_ib_invalidate_range(struct mmu_interval_notifier *mni,
13 				    const struct mmu_notifier_range *range,
14 				    unsigned long cur_seq)
15 {
16 	struct ib_umem_odp *umem_odp =
17 		container_of(mni, struct ib_umem_odp, notifier);
18 	unsigned long start, end;
19 
20 	if (!mmu_notifier_range_blockable(range))
21 		return false;
22 
23 	mutex_lock(&umem_odp->umem_mutex);
24 	mmu_interval_set_seq(mni, cur_seq);
25 
26 	start = max_t(u64, ib_umem_start(umem_odp), range->start);
27 	end = min_t(u64, ib_umem_end(umem_odp), range->end);
28 
29 	/* update umem_odp->dma_list */
30 	ib_umem_odp_unmap_dma_pages(umem_odp, start, end);
31 
32 	mutex_unlock(&umem_odp->umem_mutex);
33 	return true;
34 }
35 
36 const struct mmu_interval_notifier_ops rxe_mn_ops = {
37 	.invalidate = rxe_ib_invalidate_range,
38 };
39 
40 #define RXE_PAGEFAULT_DEFAULT 0
41 #define RXE_PAGEFAULT_RDONLY BIT(0)
42 #define RXE_PAGEFAULT_SNAPSHOT BIT(1)
rxe_odp_do_pagefault_and_lock(struct rxe_mr * mr,u64 user_va,int bcnt,u32 flags)43 static int rxe_odp_do_pagefault_and_lock(struct rxe_mr *mr, u64 user_va, int bcnt, u32 flags)
44 {
45 	struct ib_umem_odp *umem_odp = to_ib_umem_odp(mr->umem);
46 	bool fault = !(flags & RXE_PAGEFAULT_SNAPSHOT);
47 	u64 access_mask;
48 	int np;
49 
50 	access_mask = ODP_READ_ALLOWED_BIT;
51 	if (umem_odp->umem.writable && !(flags & RXE_PAGEFAULT_RDONLY))
52 		access_mask |= ODP_WRITE_ALLOWED_BIT;
53 
54 	/*
55 	 * ib_umem_odp_map_dma_and_lock() locks umem_mutex on success.
56 	 * Callers must release the lock later to let invalidation handler
57 	 * do its work again.
58 	 */
59 	np = ib_umem_odp_map_dma_and_lock(umem_odp, user_va, bcnt,
60 					  access_mask, fault);
61 	return np;
62 }
63 
rxe_odp_init_pages(struct rxe_mr * mr)64 static int rxe_odp_init_pages(struct rxe_mr *mr)
65 {
66 	struct ib_umem_odp *umem_odp = to_ib_umem_odp(mr->umem);
67 	int ret;
68 
69 	ret = rxe_odp_do_pagefault_and_lock(mr, mr->umem->address,
70 					    mr->umem->length,
71 					    RXE_PAGEFAULT_SNAPSHOT);
72 
73 	if (ret >= 0)
74 		mutex_unlock(&umem_odp->umem_mutex);
75 
76 	return ret >= 0 ? 0 : ret;
77 }
78 
rxe_odp_mr_init_user(struct rxe_dev * rxe,u64 start,u64 length,u64 iova,int access_flags,struct rxe_mr * mr)79 int rxe_odp_mr_init_user(struct rxe_dev *rxe, u64 start, u64 length,
80 			 u64 iova, int access_flags, struct rxe_mr *mr)
81 {
82 	struct ib_umem_odp *umem_odp;
83 	int err;
84 
85 	if (!IS_ENABLED(CONFIG_INFINIBAND_ON_DEMAND_PAGING))
86 		return -EOPNOTSUPP;
87 
88 	rxe_mr_init(access_flags, mr);
89 
90 	if (!start && length == U64_MAX) {
91 		if (iova != 0)
92 			return -EINVAL;
93 		if (!(rxe->attr.odp_caps.general_caps & IB_ODP_SUPPORT_IMPLICIT))
94 			return -EINVAL;
95 
96 		/* Never reach here, for implicit ODP is not implemented. */
97 	}
98 
99 	umem_odp = ib_umem_odp_get(&rxe->ib_dev, start, length, access_flags,
100 				   &rxe_mn_ops);
101 	if (IS_ERR(umem_odp)) {
102 		rxe_dbg_mr(mr, "Unable to create umem_odp err = %d\n",
103 			   (int)PTR_ERR(umem_odp));
104 		return PTR_ERR(umem_odp);
105 	}
106 
107 	umem_odp->private = mr;
108 
109 	mr->umem = &umem_odp->umem;
110 	mr->access = access_flags;
111 	mr->ibmr.length = length;
112 	mr->ibmr.iova = iova;
113 	mr->page_offset = ib_umem_offset(&umem_odp->umem);
114 
115 	err = rxe_odp_init_pages(mr);
116 	if (err) {
117 		ib_umem_odp_release(umem_odp);
118 		return err;
119 	}
120 
121 	mr->state = RXE_MR_STATE_VALID;
122 	mr->ibmr.type = IB_MR_TYPE_USER;
123 
124 	return err;
125 }
126 
rxe_check_pagefault(struct ib_umem_odp * umem_odp,u64 iova,int length,u32 perm)127 static inline bool rxe_check_pagefault(struct ib_umem_odp *umem_odp,
128 				       u64 iova, int length, u32 perm)
129 {
130 	bool need_fault = false;
131 	u64 addr;
132 	int idx;
133 
134 	addr = iova & (~(BIT(umem_odp->page_shift) - 1));
135 
136 	/* Skim through all pages that are to be accessed. */
137 	while (addr < iova + length) {
138 		idx = (addr - ib_umem_start(umem_odp)) >> umem_odp->page_shift;
139 
140 		if (!(umem_odp->dma_list[idx] & perm)) {
141 			need_fault = true;
142 			break;
143 		}
144 
145 		addr += BIT(umem_odp->page_shift);
146 	}
147 	return need_fault;
148 }
149 
rxe_odp_map_range_and_lock(struct rxe_mr * mr,u64 iova,int length,u32 flags)150 static int rxe_odp_map_range_and_lock(struct rxe_mr *mr, u64 iova, int length, u32 flags)
151 {
152 	struct ib_umem_odp *umem_odp = to_ib_umem_odp(mr->umem);
153 	bool need_fault;
154 	u64 perm;
155 	int err;
156 
157 	if (unlikely(length < 1))
158 		return -EINVAL;
159 
160 	perm = ODP_READ_ALLOWED_BIT;
161 	if (!(flags & RXE_PAGEFAULT_RDONLY))
162 		perm |= ODP_WRITE_ALLOWED_BIT;
163 
164 	mutex_lock(&umem_odp->umem_mutex);
165 
166 	need_fault = rxe_check_pagefault(umem_odp, iova, length, perm);
167 	if (need_fault) {
168 		mutex_unlock(&umem_odp->umem_mutex);
169 
170 		/* umem_mutex is locked on success. */
171 		err = rxe_odp_do_pagefault_and_lock(mr, iova, length,
172 						    flags);
173 		if (err < 0)
174 			return err;
175 
176 		need_fault = rxe_check_pagefault(umem_odp, iova, length, perm);
177 		if (need_fault)
178 			return -EFAULT;
179 	}
180 
181 	return 0;
182 }
183 
__rxe_odp_mr_copy(struct rxe_mr * mr,u64 iova,void * addr,int length,enum rxe_mr_copy_dir dir)184 static int __rxe_odp_mr_copy(struct rxe_mr *mr, u64 iova, void *addr,
185 			     int length, enum rxe_mr_copy_dir dir)
186 {
187 	struct ib_umem_odp *umem_odp = to_ib_umem_odp(mr->umem);
188 	struct page *page;
189 	int idx, bytes;
190 	size_t offset;
191 	u8 *user_va;
192 
193 	idx = (iova - ib_umem_start(umem_odp)) >> umem_odp->page_shift;
194 	offset = iova & (BIT(umem_odp->page_shift) - 1);
195 
196 	while (length > 0) {
197 		u8 *src, *dest;
198 
199 		page = hmm_pfn_to_page(umem_odp->pfn_list[idx]);
200 		user_va = kmap_local_page(page);
201 		if (!user_va)
202 			return -EFAULT;
203 
204 		src = (dir == RXE_TO_MR_OBJ) ? addr : user_va;
205 		dest = (dir == RXE_TO_MR_OBJ) ? user_va : addr;
206 
207 		bytes = BIT(umem_odp->page_shift) - offset;
208 		if (bytes > length)
209 			bytes = length;
210 
211 		memcpy(dest, src, bytes);
212 		kunmap_local(user_va);
213 
214 		length  -= bytes;
215 		idx++;
216 		offset = 0;
217 	}
218 
219 	return 0;
220 }
221 
rxe_odp_mr_copy(struct rxe_mr * mr,u64 iova,void * addr,int length,enum rxe_mr_copy_dir dir)222 int rxe_odp_mr_copy(struct rxe_mr *mr, u64 iova, void *addr, int length,
223 		    enum rxe_mr_copy_dir dir)
224 {
225 	struct ib_umem_odp *umem_odp = to_ib_umem_odp(mr->umem);
226 	u32 flags = RXE_PAGEFAULT_DEFAULT;
227 	int err;
228 
229 	if (length == 0)
230 		return 0;
231 
232 	if (unlikely(!mr->umem->is_odp))
233 		return -EOPNOTSUPP;
234 
235 	switch (dir) {
236 	case RXE_TO_MR_OBJ:
237 		break;
238 
239 	case RXE_FROM_MR_OBJ:
240 		flags |= RXE_PAGEFAULT_RDONLY;
241 		break;
242 
243 	default:
244 		return -EINVAL;
245 	}
246 
247 	err = rxe_odp_map_range_and_lock(mr, iova, length, flags);
248 	if (err)
249 		return err;
250 
251 	err =  __rxe_odp_mr_copy(mr, iova, addr, length, dir);
252 
253 	mutex_unlock(&umem_odp->umem_mutex);
254 
255 	return err;
256 }
257 
rxe_odp_do_atomic_op(struct rxe_mr * mr,u64 iova,int opcode,u64 compare,u64 swap_add,u64 * orig_val)258 static int rxe_odp_do_atomic_op(struct rxe_mr *mr, u64 iova, int opcode,
259 				u64 compare, u64 swap_add, u64 *orig_val)
260 {
261 	struct ib_umem_odp *umem_odp = to_ib_umem_odp(mr->umem);
262 	unsigned int page_offset;
263 	struct page *page;
264 	unsigned int idx;
265 	u64 value;
266 	u64 *va;
267 	int err;
268 
269 	if (unlikely(mr->state != RXE_MR_STATE_VALID)) {
270 		rxe_dbg_mr(mr, "mr not in valid state\n");
271 		return RESPST_ERR_RKEY_VIOLATION;
272 	}
273 
274 	err = mr_check_range(mr, iova, sizeof(value));
275 	if (err) {
276 		rxe_dbg_mr(mr, "iova out of range\n");
277 		return RESPST_ERR_RKEY_VIOLATION;
278 	}
279 
280 	idx = (iova - ib_umem_start(umem_odp)) >> umem_odp->page_shift;
281 	page_offset = iova & (BIT(umem_odp->page_shift) - 1);
282 	page = hmm_pfn_to_page(umem_odp->pfn_list[idx]);
283 	if (!page)
284 		return RESPST_ERR_RKEY_VIOLATION;
285 
286 	if (unlikely(page_offset & 0x7)) {
287 		rxe_dbg_mr(mr, "iova not aligned\n");
288 		return RESPST_ERR_MISALIGNED_ATOMIC;
289 	}
290 
291 	va = kmap_local_page(page);
292 
293 	spin_lock_bh(&atomic_ops_lock);
294 	value = *orig_val = va[page_offset >> 3];
295 
296 	if (opcode == IB_OPCODE_RC_COMPARE_SWAP) {
297 		if (value == compare)
298 			va[page_offset >> 3] = swap_add;
299 	} else {
300 		value += swap_add;
301 		va[page_offset >> 3] = value;
302 	}
303 	spin_unlock_bh(&atomic_ops_lock);
304 
305 	kunmap_local(va);
306 
307 	return 0;
308 }
309 
rxe_odp_atomic_op(struct rxe_mr * mr,u64 iova,int opcode,u64 compare,u64 swap_add,u64 * orig_val)310 int rxe_odp_atomic_op(struct rxe_mr *mr, u64 iova, int opcode,
311 			 u64 compare, u64 swap_add, u64 *orig_val)
312 {
313 	struct ib_umem_odp *umem_odp = to_ib_umem_odp(mr->umem);
314 	int err;
315 
316 	err = rxe_odp_map_range_and_lock(mr, iova, sizeof(char),
317 					 RXE_PAGEFAULT_DEFAULT);
318 	if (err < 0)
319 		return RESPST_ERR_RKEY_VIOLATION;
320 
321 	err = rxe_odp_do_atomic_op(mr, iova, opcode, compare, swap_add,
322 				   orig_val);
323 	mutex_unlock(&umem_odp->umem_mutex);
324 
325 	return err;
326 }
327