// SPDX-License-Identifier: GPL-2.0-only
/*
 * Copyright (c) 2025, Microsoft Corporation.
 *
 * Memory region management for mshv_root module.
 *
 * Authors: Microsoft Linux virtualization team
 */

#include <linux/hmm.h>
#include <linux/hyperv.h>
#include <linux/kref.h>
#include <linux/mm.h>
#include <linux/vmalloc.h>

#include <asm/mshyperv.h>

#include "mshv_root.h"

#define MSHV_MAP_FAULT_IN_PAGES	PTRS_PER_PMD

22 /**
23 * mshv_chunk_stride - Compute stride for mapping guest memory
24 * @page : The page to check for huge page backing
25 * @gfn : Guest frame number for the mapping
26 * @page_count: Total number of pages in the mapping
27 *
28 * Determines the appropriate stride (in pages) for mapping guest memory.
29 * Uses huge page stride if the backing page is huge and the guest mapping
30 * is properly aligned; otherwise falls back to single page stride.
31 *
32 * Return: Stride in pages, or -EINVAL if page order is unsupported.
33 */
mshv_chunk_stride(struct page * page,u64 gfn,u64 page_count)34 static int mshv_chunk_stride(struct page *page,
35 u64 gfn, u64 page_count)
36 {
37 unsigned int page_order;
38
39 /*
40 * Use single page stride by default. For huge page stride, the
41 * page must be compound and point to the head of the compound
42 * page, and both gfn and page_count must be huge-page aligned.
43 */
44 if (!PageCompound(page) || !PageHead(page) ||
45 !IS_ALIGNED(gfn, PTRS_PER_PMD) ||
46 !IS_ALIGNED(page_count, PTRS_PER_PMD))
47 return 1;
48
49 page_order = folio_order(page_folio(page));
50 /* The hypervisor only supports 2M huge page */
51 if (page_order != PMD_ORDER)
52 return -EINVAL;
53
54 return 1 << page_order;
55 }
56
/**
 * mshv_region_process_chunk - Processes a contiguous chunk of memory pages
 * in a region.
 * @region : Pointer to the memory region structure.
 * @flags : Flags to pass to the handler.
 * @page_offset: Offset into the region's pages array to start processing.
 * @page_count : Number of pages to process.
 * @handler : Callback function to handle the chunk.
 *
 * This function scans the region's pages starting from @page_offset,
 * checking for contiguous present pages of the same size (normal or huge).
 * It invokes @handler for the chunk of contiguous pages found. Returns the
 * number of pages handled, or a negative error code if the first page is
 * not present or the handler fails.
 *
 * Note: The @handler callback must be able to handle both normal and huge
 * pages.
 *
 * Return: Number of pages handled, or negative error code.
 */
static long mshv_region_process_chunk(struct mshv_mem_region *region,
				      u32 flags,
				      u64 page_offset, u64 page_count,
				      int (*handler)(struct mshv_mem_region *region,
						     u32 flags,
						     u64 page_offset,
						     u64 page_count,
						     bool huge_page))
{
	u64 gfn = region->start_gfn + page_offset;
	u64 count;
	struct page *page;
	int stride, ret;

	/* The first page determines the stride for the whole chunk */
	page = region->mreg_pages[page_offset];
	if (!page)
		return -EINVAL;

	stride = mshv_chunk_stride(page, gfn, page_count);
	if (stride < 0)
		return stride;

	/* Start at stride since the first stride is validated */
	for (count = stride; count < page_count; count += stride) {
		page = region->mreg_pages[page_offset + count];

		/* Break if current page is not present */
		if (!page)
			break;

		/*
		 * Break if stride size changes. A negative (-EINVAL)
		 * stride also differs from the positive one and simply
		 * ends the chunk here rather than failing the call.
		 */
		if (stride != mshv_chunk_stride(page, gfn + count,
						page_count - count))
			break;
	}

	/* stride > 1 means the whole chunk is backed by huge pages */
	ret = handler(region, flags, page_offset, count, stride > 1);
	if (ret)
		return ret;

	return count;
}

120 /**
121 * mshv_region_process_range - Processes a range of memory pages in a
122 * region.
123 * @region : Pointer to the memory region structure.
124 * @flags : Flags to pass to the handler.
125 * @page_offset: Offset into the region's pages array to start processing.
126 * @page_count : Number of pages to process.
127 * @handler : Callback function to handle each chunk of contiguous
128 * pages.
129 *
130 * Iterates over the specified range of pages in @region, skipping
131 * non-present pages. For each contiguous chunk of present pages, invokes
132 * @handler via mshv_region_process_chunk.
133 *
134 * Note: The @handler callback must be able to handle both normal and huge
135 * pages.
136 *
137 * Returns 0 on success, or a negative error code on failure.
138 */
mshv_region_process_range(struct mshv_mem_region * region,u32 flags,u64 page_offset,u64 page_count,int (* handler)(struct mshv_mem_region * region,u32 flags,u64 page_offset,u64 page_count,bool huge_page))139 static int mshv_region_process_range(struct mshv_mem_region *region,
140 u32 flags,
141 u64 page_offset, u64 page_count,
142 int (*handler)(struct mshv_mem_region *region,
143 u32 flags,
144 u64 page_offset,
145 u64 page_count,
146 bool huge_page))
147 {
148 long ret;
149
150 if (page_offset + page_count > region->nr_pages)
151 return -EINVAL;
152
153 while (page_count) {
154 /* Skip non-present pages */
155 if (!region->mreg_pages[page_offset]) {
156 page_offset++;
157 page_count--;
158 continue;
159 }
160
161 ret = mshv_region_process_chunk(region, flags,
162 page_offset,
163 page_count,
164 handler);
165 if (ret < 0)
166 return ret;
167
168 page_offset += ret;
169 page_count -= ret;
170 }
171
172 return 0;
173 }
174
/**
 * mshv_region_create - Allocate and initialize a guest memory region
 * @guest_pfn: First guest frame number covered by the region
 * @nr_pages : Number of 4k pages in the region
 * @uaddr    : Userspace virtual address backing the region
 * @flags    : MSHV_SET_MEM_BIT_* flags selecting access rights
 *
 * Allocates the region descriptor together with its trailing page
 * pointer array and derives the Hyper-V mapping flags. The region is
 * returned with a single reference held (released via mshv_region_put()).
 *
 * Return: Pointer to the new region, or ERR_PTR() on failure.
 */
struct mshv_mem_region *mshv_region_create(u64 guest_pfn, u64 nr_pages,
					   u64 uaddr, u32 flags)
{
	struct mshv_mem_region *region;

	/*
	 * Reject sizes whose page-array computation would wrap; a
	 * wrapped size would under-allocate the trailing array and
	 * corrupt memory on the first mreg_pages[] store.
	 */
	if (nr_pages > (SIZE_MAX - sizeof(*region)) / sizeof(struct page *))
		return ERR_PTR(-EINVAL);

	region = vzalloc(sizeof(*region) + sizeof(struct page *) * nr_pages);
	if (!region)
		return ERR_PTR(-ENOMEM);

	region->nr_pages = nr_pages;
	region->start_gfn = guest_pfn;
	region->start_uaddr = uaddr;
	/* Readable by default; ADJUSTABLE permits later access changes */
	region->hv_map_flags = HV_MAP_GPA_READABLE | HV_MAP_GPA_ADJUSTABLE;
	if (flags & BIT(MSHV_SET_MEM_BIT_WRITABLE))
		region->hv_map_flags |= HV_MAP_GPA_WRITABLE;
	if (flags & BIT(MSHV_SET_MEM_BIT_EXECUTABLE))
		region->hv_map_flags |= HV_MAP_GPA_EXECUTABLE;

	kref_init(&region->mreg_refcount);

	return region;
}

/* Per-chunk callback: grant the host shared R/W access to the pages */
static int mshv_region_chunk_share(struct mshv_mem_region *region,
				   u32 flags,
				   u64 page_offset, u64 page_count,
				   bool huge_page)
{
	const u32 host_access = HV_MAP_GPA_READABLE | HV_MAP_GPA_WRITABLE;

	if (huge_page)
		flags |= HV_MODIFY_SPA_PAGE_HOST_ACCESS_LARGE_PAGE;

	return hv_call_modify_spa_host_access(region->partition->pt_id,
					      &region->mreg_pages[page_offset],
					      page_count, host_access,
					      flags, true);
}

mshv_region_share(struct mshv_mem_region * region)214 int mshv_region_share(struct mshv_mem_region *region)
215 {
216 u32 flags = HV_MODIFY_SPA_PAGE_HOST_ACCESS_MAKE_SHARED;
217
218 return mshv_region_process_range(region, flags,
219 0, region->nr_pages,
220 mshv_region_chunk_share);
221 }
222
/* Per-chunk callback: revoke host access (access mask 0) for the pages */
static int mshv_region_chunk_unshare(struct mshv_mem_region *region,
				     u32 flags,
				     u64 page_offset, u64 page_count,
				     bool huge_page)
{
	u32 call_flags = flags;

	if (huge_page)
		call_flags |= HV_MODIFY_SPA_PAGE_HOST_ACCESS_LARGE_PAGE;

	return hv_call_modify_spa_host_access(region->partition->pt_id,
					      &region->mreg_pages[page_offset],
					      page_count, 0,
					      call_flags, false);
}

mshv_region_unshare(struct mshv_mem_region * region)237 int mshv_region_unshare(struct mshv_mem_region *region)
238 {
239 u32 flags = HV_MODIFY_SPA_PAGE_HOST_ACCESS_MAKE_EXCLUSIVE;
240
241 return mshv_region_process_range(region, flags,
242 0, region->nr_pages,
243 mshv_region_chunk_unshare);
244 }
245
/* Per-chunk callback: (re)establish the GPA mapping for the pages */
static int mshv_region_chunk_remap(struct mshv_mem_region *region,
				   u32 flags,
				   u64 page_offset, u64 page_count,
				   bool huge_page)
{
	u32 map_flags = huge_page ? (flags | HV_MAP_GPA_LARGE_PAGE) : flags;

	return hv_call_map_gpa_pages(region->partition->pt_id,
				     region->start_gfn + page_offset,
				     page_count, map_flags,
				     &region->mreg_pages[page_offset]);
}

/* Remap a page range into the guest via the per-chunk remap handler */
static int mshv_region_remap_pages(struct mshv_mem_region *region,
				   u32 map_flags,
				   u64 page_offset, u64 page_count)
{
	return mshv_region_process_range(region, map_flags,
					 page_offset, page_count,
					 mshv_region_chunk_remap);
}

mshv_region_map(struct mshv_mem_region * region)269 int mshv_region_map(struct mshv_mem_region *region)
270 {
271 u32 map_flags = region->hv_map_flags;
272
273 return mshv_region_remap_pages(region, map_flags,
274 0, region->nr_pages);
275 }
276
/* Drop tracked pages in [page_offset, page_offset + page_count) */
static void mshv_region_invalidate_pages(struct mshv_mem_region *region,
					 u64 page_offset, u64 page_count)
{
	struct page **pages = &region->mreg_pages[page_offset];

	/* Pinned regions hold a pin per page; release those first */
	if (region->mreg_type == MSHV_REGION_TYPE_MEM_PINNED)
		unpin_user_pages(pages, page_count);

	memset(pages, 0, page_count * sizeof(*pages));
}

/* Drop (and, for pinned regions, unpin) every page tracked by the region */
void mshv_region_invalidate(struct mshv_mem_region *region)
{
	mshv_region_invalidate_pages(region, 0, region->nr_pages);
}

/**
 * mshv_region_pin - Pin all userspace pages backing a region
 * @region: The memory region whose backing pages are pinned
 *
 * Pins the region's userspace pages in batches and stores the resulting
 * page pointers in @region->mreg_pages. On any failure, every page
 * pinned so far is released again before returning.
 *
 * Return: 0 on success, negative error code on failure (-ENOMEM when
 *         fewer pages than requested could be pinned).
 */
int mshv_region_pin(struct mshv_mem_region *region)
{
	u64 done_count, nr_pages;
	struct page **pages;
	__u64 userspace_addr;
	int ret;

	for (done_count = 0; done_count < region->nr_pages; done_count += ret) {
		pages = region->mreg_pages + done_count;
		userspace_addr = region->start_uaddr +
				 done_count * HV_HYP_PAGE_SIZE;
		nr_pages = min(region->nr_pages - done_count,
			       MSHV_PIN_PAGES_BATCH_SIZE);

		/*
		 * Pinning assuming 4k pages works for large pages too.
		 * All page structs within the large page are returned.
		 *
		 * Pin requests are batched because pin_user_pages_fast
		 * with the FOLL_LONGTERM flag does a large temporary
		 * allocation of contiguous memory.
		 */
		ret = pin_user_pages_fast(userspace_addr, nr_pages,
					  FOLL_WRITE | FOLL_LONGTERM,
					  pages);
		if (ret != nr_pages)
			goto release_pages;
	}

	return 0;

release_pages:
	/* A short pin (ret > 0) still pinned pages that must be released */
	if (ret > 0)
		done_count += ret;
	mshv_region_invalidate_pages(region, 0, done_count);
	return ret < 0 ? ret : -ENOMEM;
}

/* Per-chunk callback: remove the GPA mapping for the pages */
static int mshv_region_chunk_unmap(struct mshv_mem_region *region,
				   u32 flags,
				   u64 page_offset, u64 page_count,
				   bool huge_page)
{
	u32 unmap_flags = huge_page ? (flags | HV_UNMAP_GPA_LARGE_PAGE) : flags;

	return hv_call_unmap_gpa_pages(region->partition->pt_id,
				       region->start_gfn + page_offset,
				       page_count, unmap_flags);
}

mshv_region_unmap(struct mshv_mem_region * region)343 static int mshv_region_unmap(struct mshv_mem_region *region)
344 {
345 return mshv_region_process_range(region, 0,
346 0, region->nr_pages,
347 mshv_region_chunk_unmap);
348 }
349
/* kref release callback: tear down and free a region on the last put */
static void mshv_region_destroy(struct kref *ref)
{
	struct mshv_mem_region *region =
		container_of(ref, struct mshv_mem_region, mreg_refcount);
	struct mshv_partition *partition = region->partition;
	int ret;

	/* Stop MMU notifier callbacks before dismantling the region */
	if (region->mreg_type == MSHV_REGION_TYPE_MEM_MOVABLE)
		mshv_region_movable_fini(region);

	/*
	 * For encrypted partitions the host must regain access to the
	 * pages before they are unpinned. On failure the region is
	 * intentionally leaked, as unpinning inaccessible pages would
	 * crash the host.
	 */
	if (mshv_partition_encrypted(partition)) {
		ret = mshv_region_share(region);
		if (ret) {
			pt_err(partition,
			       "Failed to regain access to memory, unpinning user pages will fail and crash the host error: %d\n",
			       ret);
			return;
		}
	}

	mshv_region_unmap(region);

	mshv_region_invalidate(region);

	vfree(region);
}

/* Drop a reference; the last put frees the region via mshv_region_destroy() */
void mshv_region_put(struct mshv_mem_region *region)
{
	kref_put(&region->mreg_refcount, mshv_region_destroy);
}

/*
 * Take a reference unless the count has already dropped to zero.
 * Returns nonzero on success, 0 if the region is already being destroyed.
 */
int mshv_region_get(struct mshv_mem_region *region)
{
	return kref_get_unless_zero(&region->mreg_refcount);
}

/**
 * mshv_region_hmm_fault_and_lock - Handle HMM faults and lock the memory region
 * @region: Pointer to the memory region structure
 * @range: Pointer to the HMM range structure
 *
 * This function performs the following steps:
 * 1. Reads the notifier sequence for the HMM range.
 * 2. Acquires a read lock on the memory map.
 * 3. Handles HMM faults for the specified range.
 * 4. Releases the read lock on the memory map.
 * 5. If successful, locks the memory region mutex.
 * 6. Verifies if the notifier sequence has changed during the operation.
 * If it has, releases the mutex and returns -EBUSY to match with
 * hmm_range_fault() return code for repeating.
 *
 * Return: 0 on success, a negative error code otherwise.
 */
static int mshv_region_hmm_fault_and_lock(struct mshv_mem_region *region,
					  struct hmm_range *range)
{
	int ret;

	/* Snapshot the notifier sequence before faulting the range in */
	range->notifier_seq = mmu_interval_read_begin(range->notifier);
	mmap_read_lock(region->mreg_mni.mm);
	ret = hmm_range_fault(range);
	mmap_read_unlock(region->mreg_mni.mm);
	if (ret)
		return ret;

	mutex_lock(&region->mreg_mutex);

	/* An invalidation raced with the fault: drop the lock and retry */
	if (mmu_interval_read_retry(range->notifier, range->notifier_seq)) {
		mutex_unlock(&region->mreg_mutex);
		/* Give the concurrent invalidation a chance to finish */
		cond_resched();
		return -EBUSY;
	}

	/* Success: mreg_mutex is held; the caller must release it */
	return 0;
}

/**
 * mshv_region_range_fault - Handle memory range faults for a given region.
 * @region: Pointer to the memory region structure.
 * @page_offset: Offset of the page within the region.
 * @page_count: Number of pages to handle.
 *
 * This function resolves memory faults for a specified range of pages
 * within a memory region. It uses HMM (Heterogeneous Memory Management)
 * to fault in the required pages and updates the region's page array.
 *
 * Return: 0 on success, negative error code on failure.
 */
static int mshv_region_range_fault(struct mshv_mem_region *region,
				   u64 page_offset, u64 page_count)
{
	struct hmm_range range = {
		.notifier = &region->mreg_mni,
		/* Fault pages in writable; the guest may write them */
		.default_flags = HMM_PFN_REQ_FAULT | HMM_PFN_REQ_WRITE,
	};
	unsigned long *pfns;
	int ret;
	u64 i;

	pfns = kmalloc_array(page_count, sizeof(*pfns), GFP_KERNEL);
	if (!pfns)
		return -ENOMEM;

	range.hmm_pfns = pfns;
	range.start = region->start_uaddr + page_offset * HV_HYP_PAGE_SIZE;
	range.end = range.start + page_count * HV_HYP_PAGE_SIZE;

	/* -EBUSY means an invalidation raced with the fault; try again */
	do {
		ret = mshv_region_hmm_fault_and_lock(region, &range);
	} while (ret == -EBUSY);

	if (ret)
		goto out;

	/* mreg_mutex is held here, keeping the pfns stable until remapped */
	for (i = 0; i < page_count; i++)
		region->mreg_pages[page_offset + i] = hmm_pfn_to_page(pfns[i]);

	ret = mshv_region_remap_pages(region, region->hv_map_flags,
				      page_offset, page_count);

	mutex_unlock(&region->mreg_mutex);
out:
	kfree(pfns);
	return ret;
}

mshv_region_handle_gfn_fault(struct mshv_mem_region * region,u64 gfn)477 bool mshv_region_handle_gfn_fault(struct mshv_mem_region *region, u64 gfn)
478 {
479 u64 page_offset, page_count;
480 int ret;
481
482 /* Align the page offset to the nearest MSHV_MAP_FAULT_IN_PAGES. */
483 page_offset = ALIGN_DOWN(gfn - region->start_gfn,
484 MSHV_MAP_FAULT_IN_PAGES);
485
486 /* Map more pages than requested to reduce the number of faults. */
487 page_count = min(region->nr_pages - page_offset,
488 MSHV_MAP_FAULT_IN_PAGES);
489
490 ret = mshv_region_range_fault(region, page_offset, page_count);
491
492 WARN_ONCE(ret,
493 "p%llu: GPA intercept failed: region %#llx-%#llx, gfn %#llx, page_offset %llu, page_count %llu\n",
494 region->partition->pt_id, region->start_uaddr,
495 region->start_uaddr + (region->nr_pages << HV_HYP_PAGE_SHIFT),
496 gfn, page_offset, page_count);
497
498 return !ret;
499 }
500
/**
 * mshv_region_interval_invalidate - Invalidate a range of memory region
 * @mni: Pointer to the mmu_interval_notifier structure
 * @range: Pointer to the mmu_notifier_range structure
 * @cur_seq: Current sequence number for the interval notifier
 *
 * This function invalidates a memory region by remapping its pages with
 * no access permissions. It locks the region's mutex to ensure thread safety
 * and updates the sequence number for the interval notifier. If the range
 * is blockable, it uses a blocking lock; otherwise, it attempts a non-blocking
 * lock and returns false if unsuccessful.
 *
 * NOTE: Failure to invalidate a region is a serious error, as the pages will
 * be considered freed while they are still mapped by the hypervisor.
 * Any attempt to access such pages will likely crash the system.
 *
 * Return: true if the region was successfully invalidated, false otherwise.
 */
static bool mshv_region_interval_invalidate(struct mmu_interval_notifier *mni,
					    const struct mmu_notifier_range *range,
					    unsigned long cur_seq)
{
	struct mshv_mem_region *region = container_of(mni,
						      struct mshv_mem_region,
						      mreg_mni);
	u64 page_offset, page_count;
	unsigned long mstart, mend;
	int ret = -EPERM;

	/* Clamp the notifier range to the part overlapping this region */
	mstart = max(range->start, region->start_uaddr);
	mend = min(range->end, region->start_uaddr +
		   (region->nr_pages << HV_HYP_PAGE_SHIFT));

	page_offset = HVPFN_DOWN(mstart - region->start_uaddr);
	page_count = HVPFN_DOWN(mend - mstart);

	if (mmu_notifier_range_blockable(range))
		mutex_lock(&region->mreg_mutex);
	else if (!mutex_trylock(&region->mreg_mutex))
		goto out_fail;

	/* Bump the sequence so in-flight faulters retry (see read_retry) */
	mmu_interval_set_seq(mni, cur_seq);

	/* Revoke guest access before the pages are dropped below */
	ret = mshv_region_remap_pages(region, HV_MAP_GPA_NO_ACCESS,
				      page_offset, page_count);
	if (ret)
		goto out_unlock;

	mshv_region_invalidate_pages(region, page_offset, page_count);

	mutex_unlock(&region->mreg_mutex);

	return true;

out_unlock:
	mutex_unlock(&region->mreg_mutex);
out_fail:
	WARN_ONCE(ret,
		  "Failed to invalidate region %#llx-%#llx (range %#lx-%#lx, event: %u, pages %#llx-%#llx, mm: %#llx): %d\n",
		  region->start_uaddr,
		  region->start_uaddr + (region->nr_pages << HV_HYP_PAGE_SHIFT),
		  range->start, range->end, range->event,
		  page_offset, page_offset + page_count - 1, (u64)range->mm, ret);
	return false;
}

/* Interval-notifier ops for movable regions: revoke mappings on MM changes */
static const struct mmu_interval_notifier_ops mshv_region_mni_ops = {
	.invalidate = mshv_region_interval_invalidate,
};

/* Unregister the movable region's MMU interval notifier */
void mshv_region_movable_fini(struct mshv_mem_region *region)
{
	mmu_interval_notifier_remove(&region->mreg_mni);
}

mshv_region_movable_init(struct mshv_mem_region * region)576 bool mshv_region_movable_init(struct mshv_mem_region *region)
577 {
578 int ret;
579
580 ret = mmu_interval_notifier_insert(®ion->mreg_mni, current->mm,
581 region->start_uaddr,
582 region->nr_pages << HV_HYP_PAGE_SHIFT,
583 &mshv_region_mni_ops);
584 if (ret)
585 return false;
586
587 mutex_init(®ion->mreg_mutex);
588
589 return true;
590 }
591