// SPDX-License-Identifier: GPL-2.0-only
/*
 * Copyright (c) 2021-2022, NVIDIA CORPORATION & AFFILIATES
 */
#include <linux/iommu.h>
#include <uapi/linux/iommufd.h>

#include "../iommu-priv.h"
#include "iommufd_private.h"

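/*
 * Teardown shared by all hwpt types: free the iommu_domain, if one was
 * allocated, and drop the reference taken on an attached fault object.
 */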
static void __iommufd_hwpt_destroy(struct iommufd_hw_pagetable *hwpt)
{
	if (hwpt->domain)
		iommu_domain_free(hwpt->domain);

	if (hwpt->fault)
		refcount_dec(&hwpt->fault->common.obj.users);
}

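/*
 * Destroy a PAGING hwpt: unlink it from its IOAS's hwpt_list and io page
 * table (if it was ever linked), run the common teardown, and drop the
 * reference it holds on the IOAS.
 */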
void iommufd_hwpt_paging_destroy(struct iommufd_object *obj)
{
	struct iommufd_hwpt_paging *hwpt_paging =
		container_of(obj, struct iommufd_hwpt_paging, common.obj);

	if (!list_empty(&hwpt_paging->hwpt_item)) {
		mutex_lock(&hwpt_paging->ioas->mutex);
		list_del(&hwpt_paging->hwpt_item);
		mutex_unlock(&hwpt_paging->ioas->mutex);

		iopt_table_remove_domain(&hwpt_paging->ioas->iopt,
					 hwpt_paging->common.domain);
	}

	__iommufd_hwpt_destroy(&hwpt_paging->common);
	refcount_dec(&hwpt_paging->ioas->obj.users);
}

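/*
 * Abort a partially constructed PAGING hwpt. The caller still holds the
 * ioas->mutex, so the hwpt can be unlinked directly before taking the
 * normal destroy path.
 */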
void iommufd_hwpt_paging_abort(struct iommufd_object *obj)
{
	struct iommufd_hwpt_paging *hwpt_paging =
		container_of(obj, struct iommufd_hwpt_paging, common.obj);

	/* The ioas->mutex must be held until finalize is called. */
	lockdep_assert_held(&hwpt_paging->ioas->mutex);

	if (!list_empty(&hwpt_paging->hwpt_item)) {
		list_del_init(&hwpt_paging->hwpt_item);
		iopt_table_remove_domain(&hwpt_paging->ioas->iopt,
					 hwpt_paging->common.domain);
	}
	iommufd_hwpt_paging_destroy(obj);
}

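/*
 * A NESTED hwpt holds a reference on either the vIOMMU it was allocated
 * through or its parent PAGING hwpt; drop whichever one was taken.
 */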
void iommufd_hwpt_nested_destroy(struct iommufd_object *obj)
{
	struct iommufd_hwpt_nested *hwpt_nested =
		container_of(obj, struct iommufd_hwpt_nested, common.obj);

	__iommufd_hwpt_destroy(&hwpt_nested->common);
	if (hwpt_nested->viommu)
		refcount_dec(&hwpt_nested->viommu->obj.users);
	else
		refcount_dec(&hwpt_nested->parent->common.obj.users);
}

void iommufd_hwpt_nested_abort(struct iommufd_object *obj)
{
	iommufd_hwpt_nested_destroy(obj);
}

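/*
 * Latch the domain into cache-coherency-enforcing mode. Once set this is
 * never cleared for the lifetime of the hwpt; failure means the driver
 * cannot enforce coherency on this domain.
 */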
static int
iommufd_hwpt_paging_enforce_cc(struct iommufd_hwpt_paging *hwpt_paging)
{
	struct iommu_domain *paging_domain = hwpt_paging->common.domain;

	if (hwpt_paging->enforce_cache_coherency)
		return 0;

	if (paging_domain->ops->enforce_cache_coherency)
		hwpt_paging->enforce_cache_coherency =
			paging_domain->ops->enforce_cache_coherency(
				paging_domain);
	if (!hwpt_paging->enforce_cache_coherency)
		return -EINVAL;
	return 0;
}

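/*
 * A minimal sketch of the expected call sequence around
 * iommufd_hwpt_paging_alloc() (mirroring how iommufd_hwpt_alloc() below
 * uses it); error handling elided:
 *
 *	mutex_lock(&ioas->mutex);
 *	hwpt_paging = iommufd_hwpt_paging_alloc(ictx, ioas, idev, pasid,
 *						flags, false, user_data);
 *	...
 *	iommufd_object_finalize(ictx, &hwpt_paging->common.obj);
 *	mutex_unlock(&ioas->mutex);
 */
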
/**
 * iommufd_hwpt_paging_alloc() - Get a PAGING iommu_domain for a device
 * @ictx: iommufd context
 * @ioas: IOAS to associate the domain with
 * @idev: Device to get an iommu_domain for
 * @pasid: PASID to get an iommu_domain for
 * @flags: Flags from userspace
 * @immediate_attach: True if idev should be attached to the hwpt
 * @user_data: The user-provided driver-specific data describing the domain
 *             to create
 *
 * Allocate a new iommu_domain and return it as a hw_pagetable. The HWPT
 * will be linked to the given ioas and upon return the underlying iommu_domain
 * is fully populated.
 *
 * The caller must hold the ioas->mutex until after
 * iommufd_object_abort_and_destroy() or iommufd_object_finalize() is called on
 * the returned hwpt.
 */
struct iommufd_hwpt_paging *
iommufd_hwpt_paging_alloc(struct iommufd_ctx *ictx, struct iommufd_ioas *ioas,
			  struct iommufd_device *idev, ioasid_t pasid,
			  u32 flags, bool immediate_attach,
			  const struct iommu_user_data *user_data)
{
	const u32 valid_flags = IOMMU_HWPT_ALLOC_NEST_PARENT |
				IOMMU_HWPT_ALLOC_DIRTY_TRACKING |
				IOMMU_HWPT_FAULT_ID_VALID |
				IOMMU_HWPT_ALLOC_PASID;
	const struct iommu_ops *ops = dev_iommu_ops(idev->dev);
	struct iommufd_hwpt_paging *hwpt_paging;
	struct iommufd_hw_pagetable *hwpt;
	int rc;

	lockdep_assert_held(&ioas->mutex);

	if ((flags || user_data) && !ops->domain_alloc_paging_flags)
		return ERR_PTR(-EOPNOTSUPP);
	if (flags & ~valid_flags)
		return ERR_PTR(-EOPNOTSUPP);
	if ((flags & IOMMU_HWPT_ALLOC_DIRTY_TRACKING) &&
	    !device_iommu_capable(idev->dev, IOMMU_CAP_DIRTY_TRACKING))
		return ERR_PTR(-EOPNOTSUPP);
	if ((flags & IOMMU_HWPT_FAULT_ID_VALID) &&
	    (flags & IOMMU_HWPT_ALLOC_NEST_PARENT))
		return ERR_PTR(-EOPNOTSUPP);

	hwpt_paging = __iommufd_object_alloc(
		ictx, hwpt_paging, IOMMUFD_OBJ_HWPT_PAGING, common.obj);
	if (IS_ERR(hwpt_paging))
		return ERR_CAST(hwpt_paging);
	hwpt = &hwpt_paging->common;
	hwpt->pasid_compat = flags & IOMMU_HWPT_ALLOC_PASID;

	INIT_LIST_HEAD(&hwpt_paging->hwpt_item);
	/* Pairs with iommufd_hwpt_paging_destroy() */
	refcount_inc(&ioas->obj.users);
	hwpt_paging->ioas = ioas;
	hwpt_paging->nest_parent = flags & IOMMU_HWPT_ALLOC_NEST_PARENT;

	if (ops->domain_alloc_paging_flags) {
		hwpt->domain = ops->domain_alloc_paging_flags(idev->dev,
				flags & ~IOMMU_HWPT_FAULT_ID_VALID, user_data);
		if (IS_ERR(hwpt->domain)) {
			rc = PTR_ERR(hwpt->domain);
			hwpt->domain = NULL;
			goto out_abort;
		}
		hwpt->domain->owner = ops;
	} else {
		hwpt->domain = iommu_paging_domain_alloc(idev->dev);
		if (IS_ERR(hwpt->domain)) {
			rc = PTR_ERR(hwpt->domain);
			hwpt->domain = NULL;
			goto out_abort;
		}
	}
	hwpt->domain->iommufd_hwpt = hwpt;
	hwpt->domain->cookie_type = IOMMU_COOKIE_IOMMUFD;

	/*
	 * Set the coherency mode before we do iopt_table_add_domain() as some
	 * iommus have a per-PTE bit that controls it and need to decide before
	 * doing any maps. It is an iommu driver bug to report
	 * IOMMU_CAP_ENFORCE_CACHE_COHERENCY but fail enforce_cache_coherency on
	 * a new domain.
	 *
	 * The cache coherency mode must be configured here and cannot change
	 * later. Note that a HWPT (non-CC) created for a device (non-CC) can
	 * later be reused by another device (either non-CC or CC). However, a
	 * HWPT (CC) created for a device (CC) can only be reused by other
	 * devices (CC), never by a device (non-CC); in that case user space
	 * would need to allocate a separate HWPT (non-CC).
	 */
	if (idev->enforce_cache_coherency) {
		rc = iommufd_hwpt_paging_enforce_cc(hwpt_paging);
		if (WARN_ON(rc))
			goto out_abort;
	}

	/*
	 * immediate_attach exists only to accommodate iommu drivers that cannot
	 * directly allocate a domain. These drivers do not finish creating the
	 * domain until attach is completed. Thus we must have this call
	 * sequence. Once those drivers are fixed this should be removed.
	 */
	if (immediate_attach) {
		rc = iommufd_hw_pagetable_attach(hwpt, idev, pasid);
		if (rc)
			goto out_abort;
	}

	rc = iopt_table_add_domain(&ioas->iopt, hwpt->domain);
	if (rc)
		goto out_detach;
	list_add_tail(&hwpt_paging->hwpt_item, &ioas->hwpt_list);
	return hwpt_paging;

out_detach:
	if (immediate_attach)
		iommufd_hw_pagetable_detach(idev, pasid);
out_abort:
	iommufd_object_abort_and_destroy(ictx, &hwpt->obj);
	return ERR_PTR(rc);
}

/**
 * iommufd_hwpt_nested_alloc() - Get a NESTED iommu_domain for a device
 * @ictx: iommufd context
 * @parent: Parent PAGING-type hwpt to associate the domain with
 * @idev: Device to get an iommu_domain for
 * @flags: Flags from userspace
 * @user_data: user_data pointer. Must be valid
 *
 * Allocate a new iommu_domain (must be IOMMU_DOMAIN_NESTED) and return it as
 * a NESTED hw_pagetable. The given parent PAGING-type hwpt must be capable of
 * being a parent.
 */
static struct iommufd_hwpt_nested *
iommufd_hwpt_nested_alloc(struct iommufd_ctx *ictx,
			  struct iommufd_hwpt_paging *parent,
			  struct iommufd_device *idev, u32 flags,
			  const struct iommu_user_data *user_data)
{
	const struct iommu_ops *ops = dev_iommu_ops(idev->dev);
	struct iommufd_hwpt_nested *hwpt_nested;
	struct iommufd_hw_pagetable *hwpt;
	int rc;

	if ((flags & ~(IOMMU_HWPT_FAULT_ID_VALID | IOMMU_HWPT_ALLOC_PASID)) ||
	    !user_data->len || !ops->domain_alloc_nested)
		return ERR_PTR(-EOPNOTSUPP);
	if (parent->auto_domain || !parent->nest_parent ||
	    parent->common.domain->owner != ops)
		return ERR_PTR(-EINVAL);

	hwpt_nested = __iommufd_object_alloc(
		ictx, hwpt_nested, IOMMUFD_OBJ_HWPT_NESTED, common.obj);
	if (IS_ERR(hwpt_nested))
		return ERR_CAST(hwpt_nested);
	hwpt = &hwpt_nested->common;
	hwpt->pasid_compat = flags & IOMMU_HWPT_ALLOC_PASID;

	refcount_inc(&parent->common.obj.users);
	hwpt_nested->parent = parent;

	hwpt->domain = ops->domain_alloc_nested(
		idev->dev, parent->common.domain,
		flags & ~IOMMU_HWPT_FAULT_ID_VALID, user_data);
	if (IS_ERR(hwpt->domain)) {
		rc = PTR_ERR(hwpt->domain);
		hwpt->domain = NULL;
		goto out_abort;
	}
	hwpt->domain->owner = ops;
	hwpt->domain->iommufd_hwpt = hwpt;
	hwpt->domain->cookie_type = IOMMU_COOKIE_IOMMUFD;

	if (WARN_ON_ONCE(hwpt->domain->type != IOMMU_DOMAIN_NESTED)) {
		rc = -EINVAL;
		goto out_abort;
	}
	return hwpt_nested;

out_abort:
	iommufd_object_abort_and_destroy(ictx, &hwpt->obj);
	return ERR_PTR(rc);
}

/**
 * iommufd_viommu_alloc_hwpt_nested() - Get a hwpt_nested for a vIOMMU
 * @viommu: vIOMMU object to associate the hwpt_nested/domain with
 * @flags: Flags from userspace
 * @user_data: user_data pointer. Must be valid
 *
 * Allocate a new IOMMU_DOMAIN_NESTED for a vIOMMU and return it as a NESTED
 * hw_pagetable.
 */
static struct iommufd_hwpt_nested *
iommufd_viommu_alloc_hwpt_nested(struct iommufd_viommu *viommu, u32 flags,
				 const struct iommu_user_data *user_data)
{
	struct iommufd_hwpt_nested *hwpt_nested;
	struct iommufd_hw_pagetable *hwpt;
	int rc;

	if (flags & ~(IOMMU_HWPT_FAULT_ID_VALID | IOMMU_HWPT_ALLOC_PASID))
		return ERR_PTR(-EOPNOTSUPP);
	if (!user_data->len)
		return ERR_PTR(-EOPNOTSUPP);
	if (!viommu->ops || !viommu->ops->alloc_domain_nested)
		return ERR_PTR(-EOPNOTSUPP);

	hwpt_nested = __iommufd_object_alloc(
		viommu->ictx, hwpt_nested, IOMMUFD_OBJ_HWPT_NESTED, common.obj);
	if (IS_ERR(hwpt_nested))
		return ERR_CAST(hwpt_nested);
	hwpt = &hwpt_nested->common;
	hwpt->pasid_compat = flags & IOMMU_HWPT_ALLOC_PASID;

	hwpt_nested->viommu = viommu;
	refcount_inc(&viommu->obj.users);
	hwpt_nested->parent = viommu->hwpt;

	hwpt->domain =
		viommu->ops->alloc_domain_nested(viommu,
				flags & ~IOMMU_HWPT_FAULT_ID_VALID,
				user_data);
	if (IS_ERR(hwpt->domain)) {
		rc = PTR_ERR(hwpt->domain);
		hwpt->domain = NULL;
		goto out_abort;
	}
	hwpt->domain->iommufd_hwpt = hwpt;
	hwpt->domain->owner = viommu->iommu_dev->ops;
	hwpt->domain->cookie_type = IOMMU_COOKIE_IOMMUFD;

	if (WARN_ON_ONCE(hwpt->domain->type != IOMMU_DOMAIN_NESTED)) {
		rc = -EINVAL;
		goto out_abort;
	}
	return hwpt_nested;

out_abort:
	iommufd_object_abort_and_destroy(viommu->ictx, &hwpt->obj);
	return ERR_PTR(rc);
}

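/*
 * Handler for IOMMU_HWPT_ALLOC. A rough userspace sketch (struct field
 * names come from uapi/linux/iommufd.h and are shown here only for
 * illustration) allocating a PAGING hwpt over an existing IOAS:
 *
 *	struct iommu_hwpt_alloc cmd = {
 *		.size = sizeof(cmd),
 *		.dev_id = dev_id,
 *		.pt_id = ioas_id,
 *		.data_type = IOMMU_HWPT_DATA_NONE,
 *	};
 *
 *	rc = ioctl(iommufd, IOMMU_HWPT_ALLOC, &cmd);
 *	if (!rc)
 *		hwpt_id = cmd.out_hwpt_id;
 *
 * Passing a HWPT_PAGING or vIOMMU id as pt_id instead allocates a NESTED
 * hwpt, in which case data_type/data_len/data_uptr must describe the
 * driver-specific page table data.
 */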
int iommufd_hwpt_alloc(struct iommufd_ucmd *ucmd)
{
	struct iommu_hwpt_alloc *cmd = ucmd->cmd;
	const struct iommu_user_data user_data = {
		.type = cmd->data_type,
		.uptr = u64_to_user_ptr(cmd->data_uptr),
		.len = cmd->data_len,
	};
	struct iommufd_hw_pagetable *hwpt;
	struct iommufd_ioas *ioas = NULL;
	struct iommufd_object *pt_obj;
	struct iommufd_device *idev;
	int rc;

	if (cmd->__reserved)
		return -EOPNOTSUPP;
	if ((cmd->data_type == IOMMU_HWPT_DATA_NONE && cmd->data_len) ||
	    (cmd->data_type != IOMMU_HWPT_DATA_NONE && !cmd->data_len))
		return -EINVAL;

	idev = iommufd_get_device(ucmd, cmd->dev_id);
	if (IS_ERR(idev))
		return PTR_ERR(idev);

	pt_obj = iommufd_get_object(ucmd->ictx, cmd->pt_id, IOMMUFD_OBJ_ANY);
	if (IS_ERR(pt_obj)) {
		rc = -EINVAL;
		goto out_put_idev;
	}

	if (pt_obj->type == IOMMUFD_OBJ_IOAS) {
		struct iommufd_hwpt_paging *hwpt_paging;

		ioas = container_of(pt_obj, struct iommufd_ioas, obj);
		mutex_lock(&ioas->mutex);
		hwpt_paging = iommufd_hwpt_paging_alloc(
			ucmd->ictx, ioas, idev, IOMMU_NO_PASID, cmd->flags,
			false, user_data.len ? &user_data : NULL);
		if (IS_ERR(hwpt_paging)) {
			rc = PTR_ERR(hwpt_paging);
			goto out_unlock;
		}
		hwpt = &hwpt_paging->common;
	} else if (pt_obj->type == IOMMUFD_OBJ_HWPT_PAGING) {
		struct iommufd_hwpt_nested *hwpt_nested;

		hwpt_nested = iommufd_hwpt_nested_alloc(
			ucmd->ictx,
			container_of(pt_obj, struct iommufd_hwpt_paging,
				     common.obj),
			idev, cmd->flags, &user_data);
		if (IS_ERR(hwpt_nested)) {
			rc = PTR_ERR(hwpt_nested);
			goto out_unlock;
		}
		hwpt = &hwpt_nested->common;
	} else if (pt_obj->type == IOMMUFD_OBJ_VIOMMU) {
		struct iommufd_hwpt_nested *hwpt_nested;
		struct iommufd_viommu *viommu;

		viommu = container_of(pt_obj, struct iommufd_viommu, obj);
		if (viommu->iommu_dev != __iommu_get_iommu_dev(idev->dev)) {
			rc = -EINVAL;
			goto out_unlock;
		}
		hwpt_nested = iommufd_viommu_alloc_hwpt_nested(
			viommu, cmd->flags, &user_data);
		if (IS_ERR(hwpt_nested)) {
			rc = PTR_ERR(hwpt_nested);
			goto out_unlock;
		}
		hwpt = &hwpt_nested->common;
	} else {
		rc = -EINVAL;
		goto out_put_pt;
	}

	if (cmd->flags & IOMMU_HWPT_FAULT_ID_VALID) {
		struct iommufd_fault *fault;

		fault = iommufd_get_fault(ucmd, cmd->fault_id);
		if (IS_ERR(fault)) {
			rc = PTR_ERR(fault);
			goto out_hwpt;
		}
		hwpt->fault = fault;
		hwpt->domain->iopf_handler = iommufd_fault_iopf_handler;
		refcount_inc(&fault->common.obj.users);
		iommufd_put_object(ucmd->ictx, &fault->common.obj);
	}

	cmd->out_hwpt_id = hwpt->obj.id;
	rc = iommufd_ucmd_respond(ucmd, sizeof(*cmd));
	if (rc)
		goto out_hwpt;
	iommufd_object_finalize(ucmd->ictx, &hwpt->obj);
	goto out_unlock;

out_hwpt:
	iommufd_object_abort_and_destroy(ucmd->ictx, &hwpt->obj);
out_unlock:
	if (ioas)
		mutex_unlock(&ioas->mutex);
out_put_pt:
	iommufd_put_object(ucmd->ictx, pt_obj);
out_put_idev:
	iommufd_put_object(ucmd->ictx, &idev->obj);
	return rc;
}

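/*
 * Handler for the set-dirty-tracking ioctl (struct
 * iommu_hwpt_set_dirty_tracking): toggle dirty tracking on the domain
 * backing a PAGING hwpt according to IOMMU_HWPT_DIRTY_TRACKING_ENABLE.
 */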
int iommufd_hwpt_set_dirty_tracking(struct iommufd_ucmd *ucmd)
{
	struct iommu_hwpt_set_dirty_tracking *cmd = ucmd->cmd;
	struct iommufd_hwpt_paging *hwpt_paging;
	struct iommufd_ioas *ioas;
	int rc = -EOPNOTSUPP;
	bool enable;

	if (cmd->flags & ~IOMMU_HWPT_DIRTY_TRACKING_ENABLE)
		return rc;

	hwpt_paging = iommufd_get_hwpt_paging(ucmd, cmd->hwpt_id);
	if (IS_ERR(hwpt_paging))
		return PTR_ERR(hwpt_paging);

	ioas = hwpt_paging->ioas;
	enable = cmd->flags & IOMMU_HWPT_DIRTY_TRACKING_ENABLE;

	rc = iopt_set_dirty_tracking(&ioas->iopt, hwpt_paging->common.domain,
				     enable);

	iommufd_put_object(ucmd->ictx, &hwpt_paging->common.obj);
	return rc;
}

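/*
 * Handler for the get-dirty-bitmap ioctl (struct
 * iommu_hwpt_get_dirty_bitmap): read, and unless
 * IOMMU_HWPT_GET_DIRTY_BITMAP_NO_CLEAR is set also clear, the dirty bits
 * for the IOVA range described by the command.
 */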
int iommufd_hwpt_get_dirty_bitmap(struct iommufd_ucmd *ucmd)
{
	struct iommu_hwpt_get_dirty_bitmap *cmd = ucmd->cmd;
	struct iommufd_hwpt_paging *hwpt_paging;
	struct iommufd_ioas *ioas;
	int rc = -EOPNOTSUPP;

	if ((cmd->flags & ~(IOMMU_HWPT_GET_DIRTY_BITMAP_NO_CLEAR)) ||
	    cmd->__reserved)
		return -EOPNOTSUPP;

	hwpt_paging = iommufd_get_hwpt_paging(ucmd, cmd->hwpt_id);
	if (IS_ERR(hwpt_paging))
		return PTR_ERR(hwpt_paging);

	ioas = hwpt_paging->ioas;
	rc = iopt_read_and_clear_dirty_data(
		&ioas->iopt, hwpt_paging->common.domain, cmd->flags, cmd);

	iommufd_put_object(ucmd->ictx, &hwpt_paging->common.obj);
	return rc;
}

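/*
 * Handler for the cache-invalidate ioctl (struct iommu_hwpt_invalidate):
 * forward the user-provided array of invalidation requests to either the
 * NESTED hwpt's domain or the vIOMMU. On return cmd->entry_num reports how
 * many entries were processed, even when an error occurred part way.
 */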
int iommufd_hwpt_invalidate(struct iommufd_ucmd *ucmd)
{
	struct iommu_hwpt_invalidate *cmd = ucmd->cmd;
	struct iommu_user_data_array data_array = {
		.type = cmd->data_type,
		.uptr = u64_to_user_ptr(cmd->data_uptr),
		.entry_len = cmd->entry_len,
		.entry_num = cmd->entry_num,
	};
	struct iommufd_object *pt_obj;
	u32 done_num = 0;
	int rc;

	if (cmd->__reserved) {
		rc = -EOPNOTSUPP;
		goto out;
	}

	if (cmd->entry_num && (!cmd->data_uptr || !cmd->entry_len)) {
		rc = -EINVAL;
		goto out;
	}

	pt_obj = iommufd_get_object(ucmd->ictx, cmd->hwpt_id, IOMMUFD_OBJ_ANY);
	if (IS_ERR(pt_obj)) {
		rc = PTR_ERR(pt_obj);
		goto out;
	}
	if (pt_obj->type == IOMMUFD_OBJ_HWPT_NESTED) {
		struct iommufd_hw_pagetable *hwpt =
			container_of(pt_obj, struct iommufd_hw_pagetable, obj);

		if (!hwpt->domain->ops ||
		    !hwpt->domain->ops->cache_invalidate_user) {
			rc = -EOPNOTSUPP;
			goto out_put_pt;
		}
		rc = hwpt->domain->ops->cache_invalidate_user(hwpt->domain,
							      &data_array);
	} else if (pt_obj->type == IOMMUFD_OBJ_VIOMMU) {
		struct iommufd_viommu *viommu =
			container_of(pt_obj, struct iommufd_viommu, obj);

		if (!viommu->ops || !viommu->ops->cache_invalidate) {
			rc = -EOPNOTSUPP;
			goto out_put_pt;
		}
		rc = viommu->ops->cache_invalidate(viommu, &data_array);
	} else {
		rc = -EINVAL;
		goto out_put_pt;
	}

	done_num = data_array.entry_num;

out_put_pt:
	iommufd_put_object(ucmd->ictx, pt_obj);
out:
	cmd->entry_num = done_num;
	if (iommufd_ucmd_respond(ucmd, sizeof(*cmd)))
		return -EFAULT;
	return rc;
}