// SPDX-License-Identifier: GPL-2.0-only
/*
 * Copyright (c) 2025, Microsoft Corporation.
 *
 * Memory region management for mshv_root module.
 *
 * Authors: Microsoft Linux virtualization team
 */

#include <linux/hmm.h>
#include <linux/hyperv.h>
#include <linux/kref.h>
#include <linux/mm.h>
#include <linux/vmalloc.h>

#include <asm/mshyperv.h>

#include "mshv_root.h"

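/*
 * Number of pages to fault in around a faulting GFN in one go: one PMD's
 * worth of pages, so a single intercept can populate a whole huge-page
 * aligned block (2M with 4K base pages).
 */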
#define MSHV_MAP_FAULT_IN_PAGES				PTRS_PER_PMD

/**
 * mshv_chunk_stride - Compute stride for mapping guest memory
 * @page      : The page to check for huge page backing
 * @gfn       : Guest frame number for the mapping
 * @page_count: Total number of pages in the mapping
 *
 * Determines the appropriate stride (in pages) for mapping guest memory.
 * Uses huge page stride if the backing page is huge and the guest mapping
 * is properly aligned; otherwise falls back to single page stride.
 *
 * Return: Stride in pages, or -EINVAL if page order is unsupported.
 */
static int mshv_chunk_stride(struct page *page,
			     u64 gfn, u64 page_count)
{
	unsigned int page_order;

	/*
	 * Use single page stride by default. For huge page stride, the
	 * page must be compound and point to the head of the compound
	 * page, and both gfn and page_count must be huge-page aligned.
	 */
	if (!PageCompound(page) || !PageHead(page) ||
	    !IS_ALIGNED(gfn, PTRS_PER_PMD) ||
	    !IS_ALIGNED(page_count, PTRS_PER_PMD))
		return 1;

	page_order = folio_order(page_folio(page));
	/* The hypervisor only supports 2M huge pages */
	if (page_order != PMD_ORDER)
		return -EINVAL;

	return 1 << page_order;
}

/**
 * mshv_region_process_chunk - Processes a contiguous chunk of memory pages
 *                             in a region.
 * @region     : Pointer to the memory region structure.
 * @flags      : Flags to pass to the handler.
 * @page_offset: Offset into the region's pages array to start processing.
 * @page_count : Number of pages to process.
 * @handler    : Callback function to handle the chunk.
 *
 * This function scans the region's pages starting from @page_offset,
 * checking for contiguous present pages of the same size (normal or huge).
 * It invokes @handler for the chunk of contiguous pages found. Returns the
 * number of pages handled, or a negative error code if the first page is
 * not present or the handler fails.
 *
 * Note: The @handler callback must be able to handle both normal and huge
 * pages.
 *
 * Return: Number of pages handled, or negative error code.
 */
static long mshv_region_process_chunk(struct mshv_mem_region *region,
				      u32 flags,
				      u64 page_offset, u64 page_count,
				      int (*handler)(struct mshv_mem_region *region,
						     u32 flags,
						     u64 page_offset,
						     u64 page_count,
						     bool huge_page))
{
	u64 gfn = region->start_gfn + page_offset;
	u64 count;
	struct page *page;
	int stride, ret;

	page = region->mreg_pages[page_offset];
	if (!page)
		return -EINVAL;

	stride = mshv_chunk_stride(page, gfn, page_count);
	if (stride < 0)
		return stride;

	/* Start at stride since the first stride is validated */
	for (count = stride; count < page_count; count += stride) {
		page = region->mreg_pages[page_offset + count];

		/* Break if current page is not present */
		if (!page)
			break;

		/* Break if stride size changes */
		if (stride != mshv_chunk_stride(page, gfn + count,
						page_count - count))
			break;
	}

	ret = handler(region, flags, page_offset, count, stride > 1);
	if (ret)
		return ret;

	return count;
}

/**
 * mshv_region_process_range - Processes a range of memory pages in a
 *                             region.
 * @region     : Pointer to the memory region structure.
 * @flags      : Flags to pass to the handler.
 * @page_offset: Offset into the region's pages array to start processing.
 * @page_count : Number of pages to process.
 * @handler    : Callback function to handle each chunk of contiguous
 *               pages.
 *
 * Iterates over the specified range of pages in @region, skipping
 * non-present pages. For each contiguous chunk of present pages, invokes
 * @handler via mshv_region_process_chunk.
 *
 * Note: The @handler callback must be able to handle both normal and huge
 * pages.
 *
 * Return: 0 on success, or a negative error code on failure.
 */
static int mshv_region_process_range(struct mshv_mem_region *region,
				     u32 flags,
				     u64 page_offset, u64 page_count,
				     int (*handler)(struct mshv_mem_region *region,
						    u32 flags,
						    u64 page_offset,
						    u64 page_count,
						    bool huge_page))
{
	long ret;

	if (page_offset + page_count > region->nr_pages)
		return -EINVAL;

	while (page_count) {
		/* Skip non-present pages */
		if (!region->mreg_pages[page_offset]) {
			page_offset++;
			page_count--;
			continue;
		}

		ret = mshv_region_process_chunk(region, flags,
						page_offset,
						page_count,
						handler);
		if (ret < 0)
			return ret;

		page_offset += ret;
		page_count -= ret;
	}

	return 0;
}

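/*
 * Allocate a region covering @nr_pages guest page frames starting at
 * @guest_pfn and backed by userspace memory at @uaddr. The page pointer
 * array is allocated together with the region and left empty; pages are
 * filled in later by pinning or by demand faulting. Returns the region
 * with one reference held, or an ERR_PTR() on allocation failure.
 */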
struct mshv_mem_region *mshv_region_create(u64 guest_pfn, u64 nr_pages,
					   u64 uaddr, u32 flags)
{
	struct mshv_mem_region *region;

	region = vzalloc(sizeof(*region) + sizeof(struct page *) * nr_pages);
	if (!region)
		return ERR_PTR(-ENOMEM);

	region->nr_pages = nr_pages;
	region->start_gfn = guest_pfn;
	region->start_uaddr = uaddr;
	region->hv_map_flags = HV_MAP_GPA_READABLE | HV_MAP_GPA_ADJUSTABLE;
	if (flags & BIT(MSHV_SET_MEM_BIT_WRITABLE))
		region->hv_map_flags |= HV_MAP_GPA_WRITABLE;
	if (flags & BIT(MSHV_SET_MEM_BIT_EXECUTABLE))
		region->hv_map_flags |= HV_MAP_GPA_EXECUTABLE;

	kref_init(&region->mreg_refcount);

	return region;
}

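/*
 * Chunk handler for mshv_region_process_range(): make a run of present
 * pages accessible to the host again via the modify-SPA-host-access
 * hypercall.
 */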
static int mshv_region_chunk_share(struct mshv_mem_region *region,
				   u32 flags,
				   u64 page_offset, u64 page_count,
				   bool huge_page)
{
	if (huge_page)
		flags |= HV_MODIFY_SPA_PAGE_HOST_ACCESS_LARGE_PAGE;

	return hv_call_modify_spa_host_access(region->partition->pt_id,
					      region->mreg_pages + page_offset,
					      page_count,
					      HV_MAP_GPA_READABLE |
					      HV_MAP_GPA_WRITABLE,
					      flags, true);
}

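/*
 * Restore host (shared) access to every present page in the region.
 * Needed for encrypted partitions before the pages can be unpinned.
 */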
int mshv_region_share(struct mshv_mem_region *region)
{
	u32 flags = HV_MODIFY_SPA_PAGE_HOST_ACCESS_MAKE_SHARED;

	return mshv_region_process_range(region, flags,
					 0, region->nr_pages,
					 mshv_region_chunk_share);
}

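/*
 * Chunk handler for mshv_region_process_range(): revoke host access to a
 * run of present pages.
 */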
static int mshv_region_chunk_unshare(struct mshv_mem_region *region,
				     u32 flags,
				     u64 page_offset, u64 page_count,
				     bool huge_page)
{
	if (huge_page)
		flags |= HV_MODIFY_SPA_PAGE_HOST_ACCESS_LARGE_PAGE;

	return hv_call_modify_spa_host_access(region->partition->pt_id,
					      region->mreg_pages + page_offset,
					      page_count, 0,
					      flags, false);
}

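/*
 * Make every present page in the region exclusive to the guest partition,
 * removing host access.
 */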
int mshv_region_unshare(struct mshv_mem_region *region)
{
	u32 flags = HV_MODIFY_SPA_PAGE_HOST_ACCESS_MAKE_EXCLUSIVE;

	return mshv_region_process_range(region, flags,
					 0, region->nr_pages,
					 mshv_region_chunk_unshare);
}

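/*
 * Chunk handler for mshv_region_process_range(): map a run of present
 * pages into the guest physical address space with the given
 * HV_MAP_GPA_* flags.
 */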
static int mshv_region_chunk_remap(struct mshv_mem_region *region,
				   u32 flags,
				   u64 page_offset, u64 page_count,
				   bool huge_page)
{
	if (huge_page)
		flags |= HV_MAP_GPA_LARGE_PAGE;

	return hv_call_map_gpa_pages(region->partition->pt_id,
				     region->start_gfn + page_offset,
				     page_count, flags,
				     region->mreg_pages + page_offset);
}

static int mshv_region_remap_pages(struct mshv_mem_region *region,
				   u32 map_flags,
				   u64 page_offset, u64 page_count)
{
	return mshv_region_process_range(region, map_flags,
					 page_offset, page_count,
					 mshv_region_chunk_remap);
}

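/* Map all present pages of the region using the region's access flags. */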
int mshv_region_map(struct mshv_mem_region *region)
{
	u32 map_flags = region->hv_map_flags;

	return mshv_region_remap_pages(region, map_flags,
				       0, region->nr_pages);
}

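/*
 * Drop the region's references to a range of pages: unpin them if the
 * region is pinned, then clear the corresponding page pointer slots.
 */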
static void mshv_region_invalidate_pages(struct mshv_mem_region *region,
					 u64 page_offset, u64 page_count)
{
	if (region->mreg_type == MSHV_REGION_TYPE_MEM_PINNED)
		unpin_user_pages(region->mreg_pages + page_offset, page_count);

	memset(region->mreg_pages + page_offset, 0,
	       page_count * sizeof(struct page *));
}

void mshv_region_invalidate(struct mshv_mem_region *region)
{
	mshv_region_invalidate_pages(region, 0, region->nr_pages);
}

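/*
 * Pin all pages backing the region with pin_user_pages_fast() and record
 * them in the region's page array. Pinning is done in batches; on any
 * failure the pages pinned so far are released again.
 */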
int mshv_region_pin(struct mshv_mem_region *region)
{
	u64 done_count, nr_pages;
	struct page **pages;
	__u64 userspace_addr;
	int ret;

	for (done_count = 0; done_count < region->nr_pages; done_count += ret) {
		pages = region->mreg_pages + done_count;
		userspace_addr = region->start_uaddr +
				 done_count * HV_HYP_PAGE_SIZE;
		nr_pages = min(region->nr_pages - done_count,
			       MSHV_PIN_PAGES_BATCH_SIZE);

		/*
		 * Pinning assuming 4k pages works for large pages too.
		 * All page structs within the large page are returned.
		 *
		 * Pin requests are batched because pin_user_pages_fast
		 * with the FOLL_LONGTERM flag does a large temporary
		 * allocation of contiguous memory.
		 */
		ret = pin_user_pages_fast(userspace_addr, nr_pages,
					  FOLL_WRITE | FOLL_LONGTERM,
					  pages);
		if (ret != nr_pages)
			goto release_pages;
	}

	return 0;

release_pages:
	if (ret > 0)
		done_count += ret;
	mshv_region_invalidate_pages(region, 0, done_count);
	return ret < 0 ? ret : -ENOMEM;
}

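/*
 * Chunk handler for mshv_region_process_range(): unmap a run of present
 * pages from the guest physical address space.
 */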
static int mshv_region_chunk_unmap(struct mshv_mem_region *region,
				   u32 flags,
				   u64 page_offset, u64 page_count,
				   bool huge_page)
{
	if (huge_page)
		flags |= HV_UNMAP_GPA_LARGE_PAGE;

	return hv_call_unmap_gpa_pages(region->partition->pt_id,
				       region->start_gfn + page_offset,
				       page_count, flags);
}

static int mshv_region_unmap(struct mshv_mem_region *region)
{
	return mshv_region_process_range(region, 0,
					 0, region->nr_pages,
					 mshv_region_chunk_unmap);
}

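/*
 * Final kref release callback: tear the region down. Movable regions drop
 * their MMU interval notifier first; for encrypted partitions host access
 * must be regained before the pages can be unmapped and unpinned safely.
 */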
static void mshv_region_destroy(struct kref *ref)
{
	struct mshv_mem_region *region =
		container_of(ref, struct mshv_mem_region, mreg_refcount);
	struct mshv_partition *partition = region->partition;
	int ret;

	if (region->mreg_type == MSHV_REGION_TYPE_MEM_MOVABLE)
		mshv_region_movable_fini(region);

	if (mshv_partition_encrypted(partition)) {
		ret = mshv_region_share(region);
		if (ret) {
			pt_err(partition,
			       "Failed to regain access to memory, unpinning user pages will fail and crash the host, error: %d\n",
			       ret);
			return;
		}
	}

	mshv_region_unmap(region);

	mshv_region_invalidate(region);

	vfree(region);
}

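/* Drop a reference; the region is destroyed when the last one goes away. */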
void mshv_region_put(struct mshv_mem_region *region)
{
	kref_put(&region->mreg_refcount, mshv_region_destroy);
}

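/* Take a reference unless the region is already being torn down. */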
int mshv_region_get(struct mshv_mem_region *region)
{
	return kref_get_unless_zero(&region->mreg_refcount);
}

/**
 * mshv_region_hmm_fault_and_lock - Handle HMM faults and lock the memory region
 * @region: Pointer to the memory region structure
 * @range: Pointer to the HMM range structure
 *
 * This function performs the following steps:
 * 1. Reads the notifier sequence for the HMM range.
 * 2. Acquires a read lock on the memory map.
 * 3. Handles HMM faults for the specified range.
 * 4. Releases the read lock on the memory map.
 * 5. If successful, locks the memory region mutex.
 * 6. Checks whether the notifier sequence changed during the operation.
 *    If it did, releases the mutex and returns -EBUSY, matching the
 *    hmm_range_fault() return code so the caller retries.
 *
 * Return: 0 on success, a negative error code otherwise.
 */
static int mshv_region_hmm_fault_and_lock(struct mshv_mem_region *region,
					  struct hmm_range *range)
{
	int ret;

	range->notifier_seq = mmu_interval_read_begin(range->notifier);
	mmap_read_lock(region->mreg_mni.mm);
	ret = hmm_range_fault(range);
	mmap_read_unlock(region->mreg_mni.mm);
	if (ret)
		return ret;

	mutex_lock(&region->mreg_mutex);

	if (mmu_interval_read_retry(range->notifier, range->notifier_seq)) {
		mutex_unlock(&region->mreg_mutex);
		cond_resched();
		return -EBUSY;
	}

	return 0;
}

/**
 * mshv_region_range_fault - Handle memory range faults for a given region.
 * @region: Pointer to the memory region structure.
 * @page_offset: Offset of the page within the region.
 * @page_count: Number of pages to handle.
 *
 * This function resolves memory faults for a specified range of pages
 * within a memory region. It uses HMM (Heterogeneous Memory Management)
 * to fault in the required pages and updates the region's page array.
 *
 * Return: 0 on success, negative error code on failure.
 */
static int mshv_region_range_fault(struct mshv_mem_region *region,
				   u64 page_offset, u64 page_count)
{
	struct hmm_range range = {
		.notifier = &region->mreg_mni,
		.default_flags = HMM_PFN_REQ_FAULT | HMM_PFN_REQ_WRITE,
	};
	unsigned long *pfns;
	int ret;
	u64 i;

	pfns = kmalloc_array(page_count, sizeof(*pfns), GFP_KERNEL);
	if (!pfns)
		return -ENOMEM;

	range.hmm_pfns = pfns;
	range.start = region->start_uaddr + page_offset * HV_HYP_PAGE_SIZE;
	range.end = range.start + page_count * HV_HYP_PAGE_SIZE;

	do {
		ret = mshv_region_hmm_fault_and_lock(region, &range);
	} while (ret == -EBUSY);

	if (ret)
		goto out;

	for (i = 0; i < page_count; i++)
		region->mreg_pages[page_offset + i] = hmm_pfn_to_page(pfns[i]);

	ret = mshv_region_remap_pages(region, region->hv_map_flags,
				      page_offset, page_count);

	mutex_unlock(&region->mreg_mutex);
out:
	kfree(pfns);
	return ret;
}

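/*
 * GPA intercept handler for movable regions: fault in and map a block of
 * pages around @gfn. Returns true if the fault was resolved.
 */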
bool mshv_region_handle_gfn_fault(struct mshv_mem_region *region, u64 gfn)
{
	u64 page_offset, page_count;
	int ret;

	/* Align the page offset down to an MSHV_MAP_FAULT_IN_PAGES boundary. */
	page_offset = ALIGN_DOWN(gfn - region->start_gfn,
				 MSHV_MAP_FAULT_IN_PAGES);

	/* Map more pages than requested to reduce the number of faults. */
	page_count = min(region->nr_pages - page_offset,
			 MSHV_MAP_FAULT_IN_PAGES);

	ret = mshv_region_range_fault(region, page_offset, page_count);

	WARN_ONCE(ret,
		  "p%llu: GPA intercept failed: region %#llx-%#llx, gfn %#llx, page_offset %llu, page_count %llu\n",
		  region->partition->pt_id, region->start_uaddr,
		  region->start_uaddr + (region->nr_pages << HV_HYP_PAGE_SHIFT),
		  gfn, page_offset, page_count);

	return !ret;
}

/**
 * mshv_region_interval_invalidate - Invalidate a range of a memory region
 * @mni: Pointer to the mmu_interval_notifier structure
 * @range: Pointer to the mmu_notifier_range structure
 * @cur_seq: Current sequence number for the interval notifier
 *
 * This function invalidates a memory region by remapping its pages with
 * no access permissions. It locks the region's mutex to ensure thread safety
 * and updates the sequence number for the interval notifier. If the range
 * is blockable, it uses a blocking lock; otherwise, it attempts a non-blocking
 * lock and returns false if unsuccessful.
 *
 * NOTE: Failure to invalidate a region is a serious error, as the pages will
 * be considered freed while they are still mapped by the hypervisor.
 * Any attempt to access such pages will likely crash the system.
 *
 * Return: true if the region was successfully invalidated, false otherwise.
 */
static bool mshv_region_interval_invalidate(struct mmu_interval_notifier *mni,
					    const struct mmu_notifier_range *range,
					    unsigned long cur_seq)
{
	struct mshv_mem_region *region = container_of(mni,
						      struct mshv_mem_region,
						      mreg_mni);
	u64 page_offset, page_count;
	unsigned long mstart, mend;
	int ret = -EPERM;

	mstart = max(range->start, region->start_uaddr);
	mend = min(range->end, region->start_uaddr +
		   (region->nr_pages << HV_HYP_PAGE_SHIFT));

	page_offset = HVPFN_DOWN(mstart - region->start_uaddr);
	page_count = HVPFN_DOWN(mend - mstart);

	if (mmu_notifier_range_blockable(range))
		mutex_lock(&region->mreg_mutex);
	else if (!mutex_trylock(&region->mreg_mutex))
		goto out_fail;

	mmu_interval_set_seq(mni, cur_seq);

	ret = mshv_region_remap_pages(region, HV_MAP_GPA_NO_ACCESS,
				      page_offset, page_count);
	if (ret)
		goto out_unlock;

	mshv_region_invalidate_pages(region, page_offset, page_count);

	mutex_unlock(&region->mreg_mutex);

	return true;

out_unlock:
	mutex_unlock(&region->mreg_mutex);
out_fail:
	WARN_ONCE(ret,
		  "Failed to invalidate region %#llx-%#llx (range %#lx-%#lx, event: %u, pages %#llx-%#llx, mm: %#llx): %d\n",
		  region->start_uaddr,
		  region->start_uaddr + (region->nr_pages << HV_HYP_PAGE_SHIFT),
		  range->start, range->end, range->event,
		  page_offset, page_offset + page_count - 1, (u64)range->mm, ret);
	return false;
}

static const struct mmu_interval_notifier_ops mshv_region_mni_ops = {
	.invalidate = mshv_region_interval_invalidate,
};

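/* Stop tracking userspace mapping changes for a movable region. */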
void mshv_region_movable_fini(struct mshv_mem_region *region)
{
	mmu_interval_notifier_remove(&region->mreg_mni);
}

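/*
 * Register an MMU interval notifier over the region's userspace range so
 * stale mappings can be invalidated when the VMAs change, and initialize
 * the mutex that serializes faulting against invalidation.
 */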
bool mshv_region_movable_init(struct mshv_mem_region *region)
{
	int ret;

	ret = mmu_interval_notifier_insert(&region->mreg_mni, current->mm,
					   region->start_uaddr,
					   region->nr_pages << HV_HYP_PAGE_SHIFT,
					   &mshv_region_mni_ops);
	if (ret)
		return false;

	mutex_init(&region->mreg_mutex);

	return true;
}