// SPDX-License-Identifier: GPL-2.0 OR Linux-OpenIB
/*
 * Copyright (c) 2022-2023 Fujitsu Ltd. All rights reserved.
 */

#include <linux/hmm.h>
#include <linux/libnvdimm.h>

#include <rdma/ib_umem_odp.h>

#include "rxe.h"

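/*
 * MMU interval notifier callback. When part of the address space backing an
 * ODP MR is invalidated, bump the notifier sequence and unmap the DMA pages
 * that overlap the invalidated range.
 */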
static bool rxe_ib_invalidate_range(struct mmu_interval_notifier *mni,
				    const struct mmu_notifier_range *range,
				    unsigned long cur_seq)
{
	struct ib_umem_odp *umem_odp =
		container_of(mni, struct ib_umem_odp, notifier);
	unsigned long start, end;

	if (!mmu_notifier_range_blockable(range))
		return false;

	mutex_lock(&umem_odp->umem_mutex);
	mmu_interval_set_seq(mni, cur_seq);

	start = max_t(u64, ib_umem_start(umem_odp), range->start);
	end = min_t(u64, ib_umem_end(umem_odp), range->end);

	/* update umem_odp->map.pfn_list */
	ib_umem_odp_unmap_dma_pages(umem_odp, start, end);

	mutex_unlock(&umem_odp->umem_mutex);
	return true;
}

const struct mmu_interval_notifier_ops rxe_mn_ops = {
	.invalidate = rxe_ib_invalidate_range,
};

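/*
 * Page-fault flags: RDONLY requests the pages without write access even if
 * the umem is writable; SNAPSHOT only picks up pages that are already
 * present instead of faulting missing ones in.
 */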
#define RXE_PAGEFAULT_DEFAULT 0
#define RXE_PAGEFAULT_RDONLY BIT(0)
#define RXE_PAGEFAULT_SNAPSHOT BIT(1)
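/*
 * Fault in (or snapshot) the pages backing [user_va, user_va + bcnt) and,
 * on success, return with umem_mutex held so the mapping stays stable for
 * the caller.
 */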
static int rxe_odp_do_pagefault_and_lock(struct rxe_mr *mr, u64 user_va, int bcnt, u32 flags)
{
	struct ib_umem_odp *umem_odp = to_ib_umem_odp(mr->umem);
	bool fault = !(flags & RXE_PAGEFAULT_SNAPSHOT);
	u64 access_mask = 0;
	int np;

	if (umem_odp->umem.writable && !(flags & RXE_PAGEFAULT_RDONLY))
		access_mask |= HMM_PFN_WRITE;

	/*
	 * ib_umem_odp_map_dma_and_lock() locks umem_mutex on success.
	 * Callers must release the lock later to let invalidation handler
	 * do its work again.
	 */
	np = ib_umem_odp_map_dma_and_lock(umem_odp, user_va, bcnt,
					  access_mask, fault);
	return np;
}

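/* Take an initial snapshot of the page mappings for a new ODP MR. */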
static int rxe_odp_init_pages(struct rxe_mr *mr)
{
	struct ib_umem_odp *umem_odp = to_ib_umem_odp(mr->umem);
	int ret;

	ret = rxe_odp_do_pagefault_and_lock(mr, mr->umem->address,
					    mr->umem->length,
					    RXE_PAGEFAULT_SNAPSHOT);

	if (ret >= 0)
		mutex_unlock(&umem_odp->umem_mutex);

	return ret >= 0 ? 0 : ret;
}

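/*
 * Register a user ODP MR: create the ODP umem with the MMU interval
 * notifier attached and snapshot its initial page mappings.
 */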
int rxe_odp_mr_init_user(struct rxe_dev *rxe, u64 start, u64 length,
			 u64 iova, int access_flags, struct rxe_mr *mr)
{
	struct ib_umem_odp *umem_odp;
	int err;

	if (!IS_ENABLED(CONFIG_INFINIBAND_ON_DEMAND_PAGING))
		return -EOPNOTSUPP;

	rxe_mr_init(access_flags, mr);

	if (!start && length == U64_MAX) {
		if (iova != 0)
			return -EINVAL;
		if (!(rxe->attr.odp_caps.general_caps & IB_ODP_SUPPORT_IMPLICIT))
			return -EINVAL;

		/* Never reached, since implicit ODP is not implemented. */
	}

	umem_odp = ib_umem_odp_get(&rxe->ib_dev, start, length, access_flags,
				   &rxe_mn_ops);
	if (IS_ERR(umem_odp)) {
		rxe_dbg_mr(mr, "Unable to create umem_odp err = %d\n",
			   (int)PTR_ERR(umem_odp));
		return PTR_ERR(umem_odp);
	}

	umem_odp->private = mr;

	mr->umem = &umem_odp->umem;
	mr->access = access_flags;
	mr->ibmr.length = length;
	mr->ibmr.iova = iova;

	err = rxe_odp_init_pages(mr);
	if (err) {
		ib_umem_odp_release(umem_odp);
		return err;
	}

	mr->state = RXE_MR_STATE_VALID;
	mr->ibmr.type = IB_MR_TYPE_USER;

	return err;
}

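/*
 * Check whether any page covering [iova, iova + length) still has to be
 * faulted in. Must be called with umem_mutex held.
 */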
static inline bool rxe_check_pagefault(struct ib_umem_odp *umem_odp, u64 iova,
				       int length)
{
	bool need_fault = false;
	u64 addr;
	int idx;

	addr = iova & (~(BIT(umem_odp->page_shift) - 1));

	/* Skim through all pages that are to be accessed. */
	while (addr < iova + length) {
		idx = (addr - ib_umem_start(umem_odp)) >> umem_odp->page_shift;

		if (!(umem_odp->map.pfn_list[idx] & HMM_PFN_VALID)) {
			need_fault = true;
			break;
		}

		addr += BIT(umem_odp->page_shift);
	}
	return need_fault;
}

static unsigned long rxe_odp_iova_to_index(struct ib_umem_odp *umem_odp, u64 iova)
{
	return (iova - ib_umem_start(umem_odp)) >> umem_odp->page_shift;
}

static unsigned long rxe_odp_iova_to_page_offset(struct ib_umem_odp *umem_odp, u64 iova)
{
	return iova & (BIT(umem_odp->page_shift) - 1);
}

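/*
 * Make sure the range [iova, iova + length) is present, faulting it in if
 * needed, and return with umem_mutex held.
 */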
static int rxe_odp_map_range_and_lock(struct rxe_mr *mr, u64 iova, int length, u32 flags)
{
	struct ib_umem_odp *umem_odp = to_ib_umem_odp(mr->umem);
	bool need_fault;
	int err;

	if (unlikely(length < 1))
		return -EINVAL;

	mutex_lock(&umem_odp->umem_mutex);

	need_fault = rxe_check_pagefault(umem_odp, iova, length);
	if (need_fault) {
		mutex_unlock(&umem_odp->umem_mutex);

		/* umem_mutex is locked on success. */
		err = rxe_odp_do_pagefault_and_lock(mr, iova, length,
						    flags);
		if (err < 0)
			return err;

		need_fault = rxe_check_pagefault(umem_odp, iova, length);
		if (need_fault) {
			mutex_unlock(&umem_odp->umem_mutex);
			return -EFAULT;
		}
	}

	return 0;
}

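/*
 * Copy between a packet payload buffer and the MR, page by page. Called
 * with umem_mutex held so the pfn_list cannot change under us.
 */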
static int __rxe_odp_mr_copy(struct rxe_mr *mr, u64 iova, void *addr,
			     int length, enum rxe_mr_copy_dir dir)
{
	struct ib_umem_odp *umem_odp = to_ib_umem_odp(mr->umem);
	struct page *page;
	int idx, bytes;
	size_t offset;
	u8 *user_va;

	idx = rxe_odp_iova_to_index(umem_odp, iova);
	offset = rxe_odp_iova_to_page_offset(umem_odp, iova);

	while (length > 0) {
		u8 *src, *dest;

		page = hmm_pfn_to_page(umem_odp->map.pfn_list[idx]);
		user_va = kmap_local_page(page);

		/* Start at the in-page offset; it is non-zero only on the first page */
		src = (dir == RXE_TO_MR_OBJ) ? addr : user_va + offset;
		dest = (dir == RXE_TO_MR_OBJ) ? user_va + offset : addr;

		bytes = BIT(umem_odp->page_shift) - offset;
		if (bytes > length)
			bytes = length;

		memcpy(dest, src, bytes);
		kunmap_local(user_va);

		addr += bytes;
		length -= bytes;
		idx++;
		offset = 0;
	}

	return 0;
}

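/* Copy payload data to or from an ODP MR, faulting in pages as needed. */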
int rxe_odp_mr_copy(struct rxe_mr *mr, u64 iova, void *addr, int length,
		    enum rxe_mr_copy_dir dir)
{
	struct ib_umem_odp *umem_odp = to_ib_umem_odp(mr->umem);
	u32 flags = RXE_PAGEFAULT_DEFAULT;
	int err;

	if (length == 0)
		return 0;

	if (unlikely(!mr->umem->is_odp))
		return -EOPNOTSUPP;

	switch (dir) {
	case RXE_TO_MR_OBJ:
		break;

	case RXE_FROM_MR_OBJ:
		flags |= RXE_PAGEFAULT_RDONLY;
		break;

	default:
		return -EINVAL;
	}

	err = rxe_odp_map_range_and_lock(mr, iova, length, flags);
	if (err)
		return err;

	err = __rxe_odp_mr_copy(mr, iova, addr, length, dir);

	mutex_unlock(&umem_odp->umem_mutex);

	return err;
}

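/*
 * Perform an 8-byte atomic operation (compare & swap or fetch & add) on the
 * MR. Called with umem_mutex held and the target page present.
 */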
static enum resp_states rxe_odp_do_atomic_op(struct rxe_mr *mr, u64 iova,
					     int opcode, u64 compare,
					     u64 swap_add, u64 *orig_val)
{
	struct ib_umem_odp *umem_odp = to_ib_umem_odp(mr->umem);
	unsigned int page_offset;
	struct page *page;
	unsigned int idx;
	u64 value;
	u64 *va;
	int err;

	if (unlikely(mr->state != RXE_MR_STATE_VALID)) {
		rxe_dbg_mr(mr, "mr not in valid state\n");
		return RESPST_ERR_RKEY_VIOLATION;
	}

	err = mr_check_range(mr, iova, sizeof(value));
	if (err) {
		rxe_dbg_mr(mr, "iova out of range\n");
		return RESPST_ERR_RKEY_VIOLATION;
	}

	page_offset = rxe_odp_iova_to_page_offset(umem_odp, iova);
	if (unlikely(page_offset & 0x7)) {
		rxe_dbg_mr(mr, "iova not aligned\n");
		return RESPST_ERR_MISALIGNED_ATOMIC;
	}

	idx = rxe_odp_iova_to_index(umem_odp, iova);
	page = hmm_pfn_to_page(umem_odp->map.pfn_list[idx]);

	va = kmap_local_page(page);

	spin_lock_bh(&atomic_ops_lock);
	value = *orig_val = va[page_offset >> 3];

	if (opcode == IB_OPCODE_RC_COMPARE_SWAP) {
		if (value == compare)
			va[page_offset >> 3] = swap_add;
	} else {
		value += swap_add;
		va[page_offset >> 3] = value;
	}
	spin_unlock_bh(&atomic_ops_lock);

	kunmap_local(va);

	return RESPST_NONE;
}

enum resp_states rxe_odp_atomic_op(struct rxe_mr *mr, u64 iova, int opcode,
				   u64 compare, u64 swap_add, u64 *orig_val)
{
	struct ib_umem_odp *umem_odp = to_ib_umem_odp(mr->umem);
	int err;

	err = rxe_odp_map_range_and_lock(mr, iova, sizeof(char),
					 RXE_PAGEFAULT_DEFAULT);
	if (err < 0)
		return RESPST_ERR_RKEY_VIOLATION;

	err = rxe_odp_do_atomic_op(mr, iova, opcode, compare, swap_add,
				   orig_val);
	mutex_unlock(&umem_odp->umem_mutex);

	return err;
}

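/*
 * Write back CPU caches for an iova range of a persistent-memory backed ODP
 * MR so the flushed data becomes durable.
 */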
int rxe_odp_flush_pmem_iova(struct rxe_mr *mr, u64 iova,
			    unsigned int length)
{
	struct ib_umem_odp *umem_odp = to_ib_umem_odp(mr->umem);
	unsigned int page_offset;
	unsigned long index;
	struct page *page;
	unsigned int bytes;
	int err;
	u8 *va;

	err = rxe_odp_map_range_and_lock(mr, iova, length,
					 RXE_PAGEFAULT_DEFAULT);
	if (err)
		return err;

	while (length > 0) {
		index = rxe_odp_iova_to_index(umem_odp, iova);
		page_offset = rxe_odp_iova_to_page_offset(umem_odp, iova);

		page = hmm_pfn_to_page(umem_odp->map.pfn_list[index]);

		bytes = min_t(unsigned int, length,
			      mr_page_size(mr) - page_offset);

		va = kmap_local_page(page);
		arch_wb_cache_pmem(va + page_offset, bytes);
		kunmap_local(va);

		length -= bytes;
		iova += bytes;
	}

	mutex_unlock(&umem_odp->umem_mutex);

	return 0;
}

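/* Handle an 8-byte RDMA ATOMIC WRITE to an ODP MR. */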
enum resp_states rxe_odp_do_atomic_write(struct rxe_mr *mr, u64 iova, u64 value)
{
	struct ib_umem_odp *umem_odp = to_ib_umem_odp(mr->umem);
	unsigned int page_offset;
	unsigned long index;
	struct page *page;
	int err;
	u64 *va;

	/* See IBA oA19-28 */
	err = mr_check_range(mr, iova, sizeof(value));
	if (unlikely(err)) {
		rxe_dbg_mr(mr, "iova out of range\n");
		return RESPST_ERR_RKEY_VIOLATION;
	}

	err = rxe_odp_map_range_and_lock(mr, iova, sizeof(value),
					 RXE_PAGEFAULT_DEFAULT);
	if (err)
		return RESPST_ERR_RKEY_VIOLATION;

	page_offset = rxe_odp_iova_to_page_offset(umem_odp, iova);
	/* See IBA A19.4.2 */
	if (unlikely(page_offset & 0x7)) {
		mutex_unlock(&umem_odp->umem_mutex);
		rxe_dbg_mr(mr, "misaligned address\n");
		return RESPST_ERR_MISALIGNED_ATOMIC;
	}

	index = rxe_odp_iova_to_index(umem_odp, iova);
	page = hmm_pfn_to_page(umem_odp->map.pfn_list[index]);

	va = kmap_local_page(page);
	/* Do atomic write after all prior operations have completed */
	smp_store_release(&va[page_offset >> 3], value);
	kunmap_local(va);

	mutex_unlock(&umem_odp->umem_mutex);

	return RESPST_NONE;
}

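/*
 * Deferred prefetch request: one entry per SGE, each holding an MR
 * reference that the work handler releases.
 */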
struct prefetch_mr_work {
	struct work_struct work;
	u32 pf_flags;
	u32 num_sge;
	struct {
		u64 io_virt;
		struct rxe_mr *mr;
		size_t length;
	} frags[];
};

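/* Work handler for asynchronous (non-flush) prefetch requests. */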
static void rxe_ib_prefetch_mr_work(struct work_struct *w)
{
	struct prefetch_mr_work *work =
		container_of(w, struct prefetch_mr_work, work);
	int ret;
	u32 i;

	/*
	 * We rely on IB/core queuing this work only when
	 * num_sge != 0.
	 */
	WARN_ON(!work->num_sge);
	for (i = 0; i < work->num_sge; ++i) {
		struct ib_umem_odp *umem_odp;

		ret = rxe_odp_do_pagefault_and_lock(work->frags[i].mr,
						    work->frags[i].io_virt,
						    work->frags[i].length,
						    work->pf_flags);
		if (ret < 0) {
			rxe_dbg_mr(work->frags[i].mr,
				   "failed to prefetch the mr\n");
			goto deref;
		}

		umem_odp = to_ib_umem_odp(work->frags[i].mr->umem);
		mutex_unlock(&umem_odp->umem_mutex);

deref:
		rxe_put(work->frags[i].mr);
	}

	kvfree(work);
}

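/* Synchronously prefetch the pages described by an SGE list. */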
static int rxe_ib_prefetch_sg_list(struct ib_pd *ibpd,
				   enum ib_uverbs_advise_mr_advice advice,
				   u32 pf_flags, struct ib_sge *sg_list,
				   u32 num_sge)
{
	struct rxe_pd *pd = container_of(ibpd, struct rxe_pd, ibpd);
	int ret = 0;
	u32 i;

	for (i = 0; i < num_sge; ++i) {
		struct rxe_mr *mr;
		struct ib_umem_odp *umem_odp;

		mr = lookup_mr(pd, IB_ACCESS_LOCAL_WRITE,
			       sg_list[i].lkey, RXE_LOOKUP_LOCAL);

		if (!mr) {
			rxe_dbg_pd(pd, "mr with lkey %x not found\n",
				   sg_list[i].lkey);
			return -EINVAL;
		}

		if (advice == IB_UVERBS_ADVISE_MR_ADVICE_PREFETCH_WRITE &&
		    !mr->umem->writable) {
			rxe_dbg_mr(mr, "missing write permission\n");
			rxe_put(mr);
			return -EPERM;
		}

		ret = rxe_odp_do_pagefault_and_lock(
			mr, sg_list[i].addr, sg_list[i].length, pf_flags);
		if (ret < 0) {
			rxe_dbg_mr(mr, "failed to prefetch the mr\n");
			rxe_put(mr);
			return ret;
		}

		umem_odp = to_ib_umem_odp(mr->umem);
		mutex_unlock(&umem_odp->umem_mutex);

		rxe_put(mr);
	}

	return 0;
}

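/*
 * Handle IB_UVERBS_ADVISE_MR_ADVICE_PREFETCH* advice: prefetch synchronously
 * when IB_UVERBS_ADVISE_MR_FLAG_FLUSH is set, otherwise queue a best-effort
 * work item.
 */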
static int rxe_ib_advise_mr_prefetch(struct ib_pd *ibpd,
				     enum ib_uverbs_advise_mr_advice advice,
				     u32 flags, struct ib_sge *sg_list,
				     u32 num_sge)
{
	struct rxe_pd *pd = container_of(ibpd, struct rxe_pd, ibpd);
	u32 pf_flags = RXE_PAGEFAULT_DEFAULT;
	struct prefetch_mr_work *work;
	struct rxe_mr *mr;
	u32 i;

	if (advice == IB_UVERBS_ADVISE_MR_ADVICE_PREFETCH)
		pf_flags |= RXE_PAGEFAULT_RDONLY;

	if (advice == IB_UVERBS_ADVISE_MR_ADVICE_PREFETCH_NO_FAULT)
		pf_flags |= RXE_PAGEFAULT_SNAPSHOT;

	/* Synchronous call */
	if (flags & IB_UVERBS_ADVISE_MR_FLAG_FLUSH)
		return rxe_ib_prefetch_sg_list(ibpd, advice, pf_flags, sg_list,
					       num_sge);

	/* Asynchronous call is "best-effort" and allowed to fail */
	work = kvzalloc(struct_size(work, frags, num_sge), GFP_KERNEL);
	if (!work)
		return -ENOMEM;

	INIT_WORK(&work->work, rxe_ib_prefetch_mr_work);
	work->pf_flags = pf_flags;
	work->num_sge = num_sge;

	for (i = 0; i < num_sge; ++i) {
		/* Takes a reference, which will be released in the queued work */
		mr = lookup_mr(pd, IB_ACCESS_LOCAL_WRITE,
			       sg_list[i].lkey, RXE_LOOKUP_LOCAL);
		if (!mr) {
			mr = ERR_PTR(-EINVAL);
			goto err;
		}

		work->frags[i].io_virt = sg_list[i].addr;
		work->frags[i].length = sg_list[i].length;
		work->frags[i].mr = mr;
	}

	queue_work(system_unbound_wq, &work->work);

	return 0;

err:
	/* Roll back the references taken for the invalid request */
	while (i > 0) {
		i--;
		rxe_put(work->frags[i].mr);
	}

	kvfree(work);

	return PTR_ERR(mr);
}

int rxe_ib_advise_mr(struct ib_pd *ibpd,
		     enum ib_uverbs_advise_mr_advice advice,
		     u32 flags,
		     struct ib_sge *sg_list,
		     u32 num_sge,
		     struct uverbs_attr_bundle *attrs)
{
	if (advice != IB_UVERBS_ADVISE_MR_ADVICE_PREFETCH &&
	    advice != IB_UVERBS_ADVISE_MR_ADVICE_PREFETCH_WRITE &&
	    advice != IB_UVERBS_ADVISE_MR_ADVICE_PREFETCH_NO_FAULT)
		return -EOPNOTSUPP;

	return rxe_ib_advise_mr_prefetch(ibpd, advice, flags,
					 sg_list, num_sge);
}