1 // SPDX-License-Identifier: GPL-2.0
2 /* Copyright (c) 2021-2022, NVIDIA CORPORATION & AFFILIATES.
3  *
4  * The iopt_pages is the center of the storage and motion of PFNs. Each
5  * iopt_pages represents a logical linear array of full PFNs. The array is 0
6  * based and has npages in it. Accessors use 'index' to refer to the entry in
7  * this logical array, regardless of its storage location.
8  *
9  * PFNs are stored in a tiered scheme:
10  *  1) iopt_pages::pinned_pfns xarray
11  *  2) An iommu_domain
12  *  3) The origin of the PFNs, i.e. the userspace pointer
13  *
14  * PFNs have to be copied between all combinations of tiers, depending on the
15  * configuration.
16  *
17  * When a PFN is taken out of the userspace pointer it is pinned exactly once.
18  * The storage locations of the PFN's index are tracked in the two interval
19  * trees. If no interval includes the index then it is not pinned.
20  *
21  * If access_itree includes the PFN's index then an in-kernel access has
22  * requested the page. The PFN is stored in the xarray so other requestors can
23  * continue to find it.
24  *
25  * If the domains_itree includes the PFN's index then an iommu_domain is storing
26  * the PFN and it can be read back using iommu_iova_to_phys(). To avoid
27  * duplicating storage the xarray is not used if only iommu_domains are using
28  * the PFN's index.
29  *
30  * As a general principle this is designed so that destroy never fails. This
31  * means removing an iommu_domain or releasing an in-kernel access will not fail
32  * due to insufficient memory. In practice this means some cases have to hold
33  * PFNs in the xarray even though they are also being stored in an iommu_domain.
34  *
35  * While the iopt_pages can use an iommu_domain as storage, it does not have an
36  * IOVA itself. Instead the iopt_area represents a range of IOVA and uses the
37  * iopt_pages as the PFN provider. Multiple iopt_areas can share the iopt_pages
38  * and reference their own slice of the PFN array, with sub page granularity.
39  *
40  * In this file the term 'last' indicates an inclusive and closed interval, eg
41  * [0,0] refers to a single PFN. 'end' means an open range, eg [0,0) refers to
42  * no PFNs.
43  *
44  * Be cautious of overflow. An IOVA can go all the way up to U64_MAX, so
45  * last_iova + 1 can overflow. An iopt_pages index will always be much less than
46  * ULONG_MAX so last_index + 1 cannot overflow.
47  */
48 #include <linux/overflow.h>
49 #include <linux/slab.h>
50 #include <linux/iommu.h>
51 #include <linux/sched/mm.h>
52 #include <linux/highmem.h>
53 #include <linux/kthread.h>
54 #include <linux/iommufd.h>
55 
56 #include "io_pagetable.h"
57 #include "double_span.h"
58 
#ifdef CONFIG_IOMMUFD_TEST
/* Under the selftest the cap comes from iommufd_test_memory_limit */
#define TEMP_MEMORY_LIMIT iommufd_test_memory_limit
#else
/* Upper bound, in bytes, for the opportunistic temp_kmalloc() allocations */
#define TEMP_MEMORY_LIMIT 65536
#endif
/* Presumably sizes the stack backup buffers given to batch_init_backup() */
#define BATCH_BACKUP_SIZE 32
65 
66 /*
67  * More memory makes pin_user_pages() and the batching more efficient, but as
68  * this is only a performance optimization don't try too hard to get it. A 64k
69  * allocation can hold about 26M of 4k pages and 13G of 2M pages in an
70  * pfn_batch. Various destroy paths cannot fail and provide a small amount of
71  * stack memory as a backup contingency. If backup_len is given this cannot
72  * fail.
73  */
temp_kmalloc(size_t * size,void * backup,size_t backup_len)74 static void *temp_kmalloc(size_t *size, void *backup, size_t backup_len)
75 {
76 	void *res;
77 
78 	if (WARN_ON(*size == 0))
79 		return NULL;
80 
81 	if (*size < backup_len)
82 		return backup;
83 
84 	if (!backup && iommufd_should_fail())
85 		return NULL;
86 
87 	*size = min_t(size_t, *size, TEMP_MEMORY_LIMIT);
88 	res = kmalloc(*size, GFP_KERNEL | __GFP_NOWARN | __GFP_NORETRY);
89 	if (res)
90 		return res;
91 	*size = PAGE_SIZE;
92 	if (backup_len) {
93 		res = kmalloc(*size, GFP_KERNEL | __GFP_NOWARN | __GFP_NORETRY);
94 		if (res)
95 			return res;
96 		*size = backup_len;
97 		return backup;
98 	}
99 	return kmalloc(*size, GFP_KERNEL);
100 }
101 
interval_tree_double_span_iter_update(struct interval_tree_double_span_iter * iter)102 void interval_tree_double_span_iter_update(
103 	struct interval_tree_double_span_iter *iter)
104 {
105 	unsigned long last_hole = ULONG_MAX;
106 	unsigned int i;
107 
108 	for (i = 0; i != ARRAY_SIZE(iter->spans); i++) {
109 		if (interval_tree_span_iter_done(&iter->spans[i])) {
110 			iter->is_used = -1;
111 			return;
112 		}
113 
114 		if (iter->spans[i].is_hole) {
115 			last_hole = min(last_hole, iter->spans[i].last_hole);
116 			continue;
117 		}
118 
119 		iter->is_used = i + 1;
120 		iter->start_used = iter->spans[i].start_used;
121 		iter->last_used = min(iter->spans[i].last_used, last_hole);
122 		return;
123 	}
124 
125 	iter->is_used = 0;
126 	iter->start_hole = iter->spans[0].start_hole;
127 	iter->last_hole =
128 		min(iter->spans[0].last_hole, iter->spans[1].last_hole);
129 }
130 
interval_tree_double_span_iter_first(struct interval_tree_double_span_iter * iter,struct rb_root_cached * itree1,struct rb_root_cached * itree2,unsigned long first_index,unsigned long last_index)131 void interval_tree_double_span_iter_first(
132 	struct interval_tree_double_span_iter *iter,
133 	struct rb_root_cached *itree1, struct rb_root_cached *itree2,
134 	unsigned long first_index, unsigned long last_index)
135 {
136 	unsigned int i;
137 
138 	iter->itrees[0] = itree1;
139 	iter->itrees[1] = itree2;
140 	for (i = 0; i != ARRAY_SIZE(iter->spans); i++)
141 		interval_tree_span_iter_first(&iter->spans[i], iter->itrees[i],
142 					      first_index, last_index);
143 	interval_tree_double_span_iter_update(iter);
144 }
145 
interval_tree_double_span_iter_next(struct interval_tree_double_span_iter * iter)146 void interval_tree_double_span_iter_next(
147 	struct interval_tree_double_span_iter *iter)
148 {
149 	unsigned int i;
150 
151 	if (iter->is_used == -1 ||
152 	    iter->last_hole == iter->spans[0].last_index) {
153 		iter->is_used = -1;
154 		return;
155 	}
156 
157 	for (i = 0; i != ARRAY_SIZE(iter->spans); i++)
158 		interval_tree_span_iter_advance(
159 			&iter->spans[i], iter->itrees[i], iter->last_hole + 1);
160 	interval_tree_double_span_iter_update(iter);
161 }
162 
/*
 * Add @npages to pages->npinned. An overflow, or a count above pages->npages,
 * is only reported (WARN_ON) when CONFIG_IOMMUFD_TEST is enabled.
 */
static void iopt_pages_add_npinned(struct iopt_pages *pages, size_t npages)
{
	int overflowed = check_add_overflow(pages->npinned, npages,
					    &pages->npinned);

	if (IS_ENABLED(CONFIG_IOMMUFD_TEST))
		WARN_ON(overflowed || pages->npinned > pages->npages);
}
171 
/*
 * Subtract @npages from pages->npinned. An underflow, or a count above
 * pages->npages, is only reported (WARN_ON) when CONFIG_IOMMUFD_TEST is
 * enabled.
 */
static void iopt_pages_sub_npinned(struct iopt_pages *pages, size_t npages)
{
	int underflowed = check_sub_overflow(pages->npinned, npages,
					     &pages->npinned);

	if (IS_ENABLED(CONFIG_IOMMUFD_TEST))
		WARN_ON(underflowed || pages->npinned > pages->npages);
}
180 
/*
 * Unpin the pages in @page_list that back indexes [start_index, last_index]
 * and remove them from pages->npinned.
 */
static void iopt_pages_err_unpin(struct iopt_pages *pages,
				 unsigned long start_index,
				 unsigned long last_index,
				 struct page **page_list)
{
	const unsigned long count = last_index - start_index + 1;

	unpin_user_pages(page_list, count);
	iopt_pages_sub_npinned(pages, count);
}
191 
192 /*
193  * index is the number of PAGE_SIZE units from the start of the area's
194  * iopt_pages. If the iova is sub page-size then the area has an iova that
195  * covers a portion of the first and last pages in the range.
196  */
iopt_area_index_to_iova(struct iopt_area * area,unsigned long index)197 static unsigned long iopt_area_index_to_iova(struct iopt_area *area,
198 					     unsigned long index)
199 {
200 	if (IS_ENABLED(CONFIG_IOMMUFD_TEST))
201 		WARN_ON(index < iopt_area_index(area) ||
202 			index > iopt_area_last_index(area));
203 	index -= iopt_area_index(area);
204 	if (index == 0)
205 		return iopt_area_iova(area);
206 	return iopt_area_iova(area) - area->page_offset + index * PAGE_SIZE;
207 }
208 
iopt_area_index_to_iova_last(struct iopt_area * area,unsigned long index)209 static unsigned long iopt_area_index_to_iova_last(struct iopt_area *area,
210 						  unsigned long index)
211 {
212 	if (IS_ENABLED(CONFIG_IOMMUFD_TEST))
213 		WARN_ON(index < iopt_area_index(area) ||
214 			index > iopt_area_last_index(area));
215 	if (index == iopt_area_last_index(area))
216 		return iopt_area_last_iova(area);
217 	return iopt_area_iova(area) - area->page_offset +
218 	       (index - iopt_area_index(area) + 1) * PAGE_SIZE - 1;
219 }
220 
/*
 * Unmap [iova, iova + size) from @domain on a path that cannot fail.
 *
 * It is a logic error in this code or a driver bug if the IOMMU unmaps
 * something other than exactly as requested. This implies that the iommu
 * driver may not fail unmap for reasons beyond bad arguments. In particular,
 * the iommu driver may not do a memory allocation on the unmap path.
 */
static void iommu_unmap_nofail(struct iommu_domain *domain, unsigned long iova,
			       size_t size)
{
	size_t unmapped = iommu_unmap(domain, iova, size);

	WARN_ON(unmapped != size);
}
236 
/* Unmap from @domain the IOVA range of @area covering [start_index, last_index] */
static void iopt_area_unmap_domain_range(struct iopt_area *area,
					 struct iommu_domain *domain,
					 unsigned long start_index,
					 unsigned long last_index)
{
	unsigned long first_iova = iopt_area_index_to_iova(area, start_index);
	unsigned long final_iova =
		iopt_area_index_to_iova_last(area, last_index);

	iommu_unmap_nofail(domain, first_iova, final_iova - first_iova + 1);
}
248 
iopt_pages_find_domain_area(struct iopt_pages * pages,unsigned long index)249 static struct iopt_area *iopt_pages_find_domain_area(struct iopt_pages *pages,
250 						     unsigned long index)
251 {
252 	struct interval_tree_node *node;
253 
254 	node = interval_tree_iter_first(&pages->domains_itree, index, index);
255 	if (!node)
256 		return NULL;
257 	return container_of(node, struct iopt_area, pages_node);
258 }
259 
260 /*
261  * A simple datastructure to hold a vector of PFNs, optimized for contiguous
262  * PFNs. This is used as a temporary holding memory for shuttling pfns from one
263  * place to another. Generally everything is made more efficient if operations
264  * work on the largest possible grouping of pfns. eg fewer lock/unlock cycles,
265  * better cache locality, etc
266  */
267 struct pfn_batch {
268 	unsigned long *pfns;
269 	u32 *npfns;
270 	unsigned int array_size;
271 	unsigned int end;
272 	unsigned int total_pfns;
273 };
274 
batch_clear(struct pfn_batch * batch)275 static void batch_clear(struct pfn_batch *batch)
276 {
277 	batch->total_pfns = 0;
278 	batch->end = 0;
279 	batch->pfns[0] = 0;
280 	batch->npfns[0] = 0;
281 }
282 
283 /*
284  * Carry means we carry a portion of the final hugepage over to the front of the
285  * batch
286  */
batch_clear_carry(struct pfn_batch * batch,unsigned int keep_pfns)287 static void batch_clear_carry(struct pfn_batch *batch, unsigned int keep_pfns)
288 {
289 	if (!keep_pfns)
290 		return batch_clear(batch);
291 
292 	if (IS_ENABLED(CONFIG_IOMMUFD_TEST))
293 		WARN_ON(!batch->end ||
294 			batch->npfns[batch->end - 1] < keep_pfns);
295 
296 	batch->total_pfns = keep_pfns;
297 	batch->pfns[0] = batch->pfns[batch->end - 1] +
298 			 (batch->npfns[batch->end - 1] - keep_pfns);
299 	batch->npfns[0] = keep_pfns;
300 	batch->end = 1;
301 }
302 
batch_skip_carry(struct pfn_batch * batch,unsigned int skip_pfns)303 static void batch_skip_carry(struct pfn_batch *batch, unsigned int skip_pfns)
304 {
305 	if (!batch->total_pfns)
306 		return;
307 	if (IS_ENABLED(CONFIG_IOMMUFD_TEST))
308 		WARN_ON(batch->total_pfns != batch->npfns[0]);
309 	skip_pfns = min(batch->total_pfns, skip_pfns);
310 	batch->pfns[0] += skip_pfns;
311 	batch->npfns[0] -= skip_pfns;
312 	batch->total_pfns -= skip_pfns;
313 }
314 
/*
 * Allocate storage for at most @max_pages runs (possibly fewer) using
 * temp_kmalloc() with @backup/@backup_len. pfns[] and npfns[] share the one
 * allocation, with npfns[] placed right after the array_size pfns[] entries.
 *
 * Returns 0 on success or -ENOMEM. With CONFIG_IOMMUFD_TEST it returns
 * -EINVAL if the memory obtained cannot hold a single entry.
 */
static int __batch_init(struct pfn_batch *batch, size_t max_pages, void *backup,
			size_t backup_len)
{
	const size_t elmsz = sizeof(*batch->pfns) + sizeof(*batch->npfns);
	size_t bytes = max_pages * elmsz;

	batch->pfns = temp_kmalloc(&bytes, backup, backup_len);
	if (!batch->pfns)
		return -ENOMEM;

	if (IS_ENABLED(CONFIG_IOMMUFD_TEST) && WARN_ON(bytes < elmsz))
		return -EINVAL;

	batch->array_size = bytes / elmsz;
	batch->npfns = (u32 *)(batch->pfns + batch->array_size);
	batch_clear(batch);
	return 0;
}
331 
/* Initialize a batch without backup memory; the allocation may fail */
static int batch_init(struct pfn_batch *batch, size_t max_pages)
{
	void *no_backup = NULL;

	return __batch_init(batch, max_pages, no_backup, 0);
}
336 
/*
 * Initialize a batch for a path that must not fail, falling back to @backup.
 * The __batch_init() result is ignored; per the temp_kmalloc() comment the
 * allocation cannot fail when backup_len is non-zero.
 */
static void batch_init_backup(struct pfn_batch *batch, size_t max_pages,
			      void *backup, size_t backup_len)
{
	(void)__batch_init(batch, max_pages, backup, backup_len);
}
342 
batch_destroy(struct pfn_batch * batch,void * backup)343 static void batch_destroy(struct pfn_batch *batch, void *backup)
344 {
345 	if (batch->pfns != backup)
346 		kfree(batch->pfns);
347 }
348 
349 /* true if the pfn was added, false otherwise */
batch_add_pfn(struct pfn_batch * batch,unsigned long pfn)350 static bool batch_add_pfn(struct pfn_batch *batch, unsigned long pfn)
351 {
352 	const unsigned int MAX_NPFNS = type_max(typeof(*batch->npfns));
353 
354 	if (batch->end &&
355 	    pfn == batch->pfns[batch->end - 1] + batch->npfns[batch->end - 1] &&
356 	    batch->npfns[batch->end - 1] != MAX_NPFNS) {
357 		batch->npfns[batch->end - 1]++;
358 		batch->total_pfns++;
359 		return true;
360 	}
361 	if (batch->end == batch->array_size)
362 		return false;
363 	batch->total_pfns++;
364 	batch->pfns[batch->end] = pfn;
365 	batch->npfns[batch->end] = 1;
366 	batch->end++;
367 	return true;
368 }
369 
370 /*
371  * Fill the batch with pfns from the domain. When the batch is full, or it
372  * reaches last_index, the function will return. The caller should use
373  * batch->total_pfns to determine the starting point for the next iteration.
374  */
batch_from_domain(struct pfn_batch * batch,struct iommu_domain * domain,struct iopt_area * area,unsigned long start_index,unsigned long last_index)375 static void batch_from_domain(struct pfn_batch *batch,
376 			      struct iommu_domain *domain,
377 			      struct iopt_area *area, unsigned long start_index,
378 			      unsigned long last_index)
379 {
380 	unsigned int page_offset = 0;
381 	unsigned long iova;
382 	phys_addr_t phys;
383 
384 	iova = iopt_area_index_to_iova(area, start_index);
385 	if (start_index == iopt_area_index(area))
386 		page_offset = area->page_offset;
387 	while (start_index <= last_index) {
388 		/*
389 		 * This is pretty slow, it would be nice to get the page size
390 		 * back from the driver, or have the driver directly fill the
391 		 * batch.
392 		 */
393 		phys = iommu_iova_to_phys(domain, iova) - page_offset;
394 		if (!batch_add_pfn(batch, PHYS_PFN(phys)))
395 			return;
396 		iova += PAGE_SIZE - page_offset;
397 		page_offset = 0;
398 		start_index++;
399 	}
400 }
401 
raw_pages_from_domain(struct iommu_domain * domain,struct iopt_area * area,unsigned long start_index,unsigned long last_index,struct page ** out_pages)402 static struct page **raw_pages_from_domain(struct iommu_domain *domain,
403 					   struct iopt_area *area,
404 					   unsigned long start_index,
405 					   unsigned long last_index,
406 					   struct page **out_pages)
407 {
408 	unsigned int page_offset = 0;
409 	unsigned long iova;
410 	phys_addr_t phys;
411 
412 	iova = iopt_area_index_to_iova(area, start_index);
413 	if (start_index == iopt_area_index(area))
414 		page_offset = area->page_offset;
415 	while (start_index <= last_index) {
416 		phys = iommu_iova_to_phys(domain, iova) - page_offset;
417 		*(out_pages++) = pfn_to_page(PHYS_PFN(phys));
418 		iova += PAGE_SIZE - page_offset;
419 		page_offset = 0;
420 		start_index++;
421 	}
422 	return out_pages;
423 }
424 
425 /* Continues reading a domain until we reach a discontinuity in the pfns. */
batch_from_domain_continue(struct pfn_batch * batch,struct iommu_domain * domain,struct iopt_area * area,unsigned long start_index,unsigned long last_index)426 static void batch_from_domain_continue(struct pfn_batch *batch,
427 				       struct iommu_domain *domain,
428 				       struct iopt_area *area,
429 				       unsigned long start_index,
430 				       unsigned long last_index)
431 {
432 	unsigned int array_size = batch->array_size;
433 
434 	batch->array_size = batch->end;
435 	batch_from_domain(batch, domain, area, start_index, last_index);
436 	batch->array_size = array_size;
437 }
438 
439 /*
440  * This is part of the VFIO compatibility support for VFIO_TYPE1_IOMMU. That
441  * mode permits splitting a mapped area up, and then one of the splits is
442  * unmapped. Doing this normally would cause us to violate our invariant of
443  * pairing map/unmap. Thus, to support old VFIO compatibility disable support
444  * for batching consecutive PFNs. All PFNs mapped into the iommu are done in
445  * PAGE_SIZE units, not larger or smaller.
446  */
batch_iommu_map_small(struct iommu_domain * domain,unsigned long iova,phys_addr_t paddr,size_t size,int prot)447 static int batch_iommu_map_small(struct iommu_domain *domain,
448 				 unsigned long iova, phys_addr_t paddr,
449 				 size_t size, int prot)
450 {
451 	unsigned long start_iova = iova;
452 	int rc;
453 
454 	if (IS_ENABLED(CONFIG_IOMMUFD_TEST))
455 		WARN_ON(paddr % PAGE_SIZE || iova % PAGE_SIZE ||
456 			size % PAGE_SIZE);
457 
458 	while (size) {
459 		rc = iommu_map(domain, iova, paddr, PAGE_SIZE, prot,
460 			       GFP_KERNEL_ACCOUNT);
461 		if (rc)
462 			goto err_unmap;
463 		iova += PAGE_SIZE;
464 		paddr += PAGE_SIZE;
465 		size -= PAGE_SIZE;
466 	}
467 	return 0;
468 
469 err_unmap:
470 	if (start_iova != iova)
471 		iommu_unmap_nofail(domain, start_iova, iova - start_iova);
472 	return rc;
473 }
474 
batch_to_domain(struct pfn_batch * batch,struct iommu_domain * domain,struct iopt_area * area,unsigned long start_index)475 static int batch_to_domain(struct pfn_batch *batch, struct iommu_domain *domain,
476 			   struct iopt_area *area, unsigned long start_index)
477 {
478 	bool disable_large_pages = area->iopt->disable_large_pages;
479 	unsigned long last_iova = iopt_area_last_iova(area);
480 	unsigned int page_offset = 0;
481 	unsigned long start_iova;
482 	unsigned long next_iova;
483 	unsigned int cur = 0;
484 	unsigned long iova;
485 	int rc;
486 
487 	/* The first index might be a partial page */
488 	if (start_index == iopt_area_index(area))
489 		page_offset = area->page_offset;
490 	next_iova = iova = start_iova =
491 		iopt_area_index_to_iova(area, start_index);
492 	while (cur < batch->end) {
493 		next_iova = min(last_iova + 1,
494 				next_iova + batch->npfns[cur] * PAGE_SIZE -
495 					page_offset);
496 		if (disable_large_pages)
497 			rc = batch_iommu_map_small(
498 				domain, iova,
499 				PFN_PHYS(batch->pfns[cur]) + page_offset,
500 				next_iova - iova, area->iommu_prot);
501 		else
502 			rc = iommu_map(domain, iova,
503 				       PFN_PHYS(batch->pfns[cur]) + page_offset,
504 				       next_iova - iova, area->iommu_prot,
505 				       GFP_KERNEL_ACCOUNT);
506 		if (rc)
507 			goto err_unmap;
508 		iova = next_iova;
509 		page_offset = 0;
510 		cur++;
511 	}
512 	return 0;
513 err_unmap:
514 	if (start_iova != iova)
515 		iommu_unmap_nofail(domain, start_iova, iova - start_iova);
516 	return rc;
517 }
518 
batch_from_xarray(struct pfn_batch * batch,struct xarray * xa,unsigned long start_index,unsigned long last_index)519 static void batch_from_xarray(struct pfn_batch *batch, struct xarray *xa,
520 			      unsigned long start_index,
521 			      unsigned long last_index)
522 {
523 	XA_STATE(xas, xa, start_index);
524 	void *entry;
525 
526 	rcu_read_lock();
527 	while (true) {
528 		entry = xas_next(&xas);
529 		if (xas_retry(&xas, entry))
530 			continue;
531 		WARN_ON(!xa_is_value(entry));
532 		if (!batch_add_pfn(batch, xa_to_value(entry)) ||
533 		    start_index == last_index)
534 			break;
535 		start_index++;
536 	}
537 	rcu_read_unlock();
538 }
539 
batch_from_xarray_clear(struct pfn_batch * batch,struct xarray * xa,unsigned long start_index,unsigned long last_index)540 static void batch_from_xarray_clear(struct pfn_batch *batch, struct xarray *xa,
541 				    unsigned long start_index,
542 				    unsigned long last_index)
543 {
544 	XA_STATE(xas, xa, start_index);
545 	void *entry;
546 
547 	xas_lock(&xas);
548 	while (true) {
549 		entry = xas_next(&xas);
550 		if (xas_retry(&xas, entry))
551 			continue;
552 		WARN_ON(!xa_is_value(entry));
553 		if (!batch_add_pfn(batch, xa_to_value(entry)))
554 			break;
555 		xas_store(&xas, NULL);
556 		if (start_index == last_index)
557 			break;
558 		start_index++;
559 	}
560 	xas_unlock(&xas);
561 }
562 
clear_xarray(struct xarray * xa,unsigned long start_index,unsigned long last_index)563 static void clear_xarray(struct xarray *xa, unsigned long start_index,
564 			 unsigned long last_index)
565 {
566 	XA_STATE(xas, xa, start_index);
567 	void *entry;
568 
569 	xas_lock(&xas);
570 	xas_for_each(&xas, entry, last_index)
571 		xas_store(&xas, NULL);
572 	xas_unlock(&xas);
573 }
574 
pages_to_xarray(struct xarray * xa,unsigned long start_index,unsigned long last_index,struct page ** pages)575 static int pages_to_xarray(struct xarray *xa, unsigned long start_index,
576 			   unsigned long last_index, struct page **pages)
577 {
578 	struct page **end_pages = pages + (last_index - start_index) + 1;
579 	struct page **half_pages = pages + (end_pages - pages) / 2;
580 	XA_STATE(xas, xa, start_index);
581 
582 	do {
583 		void *old;
584 
585 		xas_lock(&xas);
586 		while (pages != end_pages) {
587 			/* xarray does not participate in fault injection */
588 			if (pages == half_pages && iommufd_should_fail()) {
589 				xas_set_err(&xas, -EINVAL);
590 				xas_unlock(&xas);
591 				/* aka xas_destroy() */
592 				xas_nomem(&xas, GFP_KERNEL);
593 				goto err_clear;
594 			}
595 
596 			old = xas_store(&xas, xa_mk_value(page_to_pfn(*pages)));
597 			if (xas_error(&xas))
598 				break;
599 			WARN_ON(old);
600 			pages++;
601 			xas_next(&xas);
602 		}
603 		xas_unlock(&xas);
604 	} while (xas_nomem(&xas, GFP_KERNEL));
605 
606 err_clear:
607 	if (xas_error(&xas)) {
608 		if (xas.xa_index != start_index)
609 			clear_xarray(xa, start_index, xas.xa_index - 1);
610 		return xas_error(&xas);
611 	}
612 	return 0;
613 }
614 
/* Append the PFN of each of the @npages @pages until the batch is full */
static void batch_from_pages(struct pfn_batch *batch, struct page **pages,
			     size_t npages)
{
	size_t i;

	for (i = 0; i != npages; i++) {
		if (!batch_add_pfn(batch, page_to_pfn(pages[i])))
			return;
	}
}
624 
/*
 * Unpin @npages PFNs of the batch, starting @first_page_off PFNs into it.
 * The ranges are passed to unpin_user_page_range_dirty_lock() with
 * pages->writable as the make_dirty argument, and the count is removed from
 * pages->npinned.
 */
static void batch_unpin(struct pfn_batch *batch, struct iopt_pages *pages,
			unsigned int first_page_off, size_t npages)
{
	unsigned int run = 0;
	size_t chunk;

	/* Skip whole runs until reaching the run holding the first PFN */
	while (first_page_off && batch->npfns[run] <= first_page_off) {
		first_page_off -= batch->npfns[run];
		run++;
	}

	for (; npages; npages -= chunk, run++) {
		chunk = min_t(size_t, npages,
			      batch->npfns[run] - first_page_off);
		unpin_user_page_range_dirty_lock(
			pfn_to_page(batch->pfns[run] + first_page_off), chunk,
			pages->writable);
		iopt_pages_sub_npinned(pages, chunk);
		first_page_off = 0;
	}
}
650 
/*
 * Copy @length bytes between @data and @page at @offset. With
 * IOMMUFD_ACCESS_RW_WRITE set, @data is written into the page and the page is
 * marked dirty; otherwise the page contents are read into @data.
 */
static void copy_data_page(struct page *page, void *data, unsigned long offset,
			   size_t length, unsigned int flags)
{
	void *va = kmap_local_page(page);

	if (!(flags & IOMMUFD_ACCESS_RW_WRITE)) {
		memcpy(data, va + offset, length);
	} else {
		memcpy(va + offset, data, length);
		set_page_dirty_lock(page);
	}
	kunmap_local(va);
}
665 
batch_rw(struct pfn_batch * batch,void * data,unsigned long offset,unsigned long length,unsigned int flags)666 static unsigned long batch_rw(struct pfn_batch *batch, void *data,
667 			      unsigned long offset, unsigned long length,
668 			      unsigned int flags)
669 {
670 	unsigned long copied = 0;
671 	unsigned int npage = 0;
672 	unsigned int cur = 0;
673 
674 	while (cur < batch->end) {
675 		unsigned long bytes = min(length, PAGE_SIZE - offset);
676 
677 		copy_data_page(pfn_to_page(batch->pfns[cur] + npage), data,
678 			       offset, bytes, flags);
679 		offset = 0;
680 		length -= bytes;
681 		data += bytes;
682 		copied += bytes;
683 		npage++;
684 		if (npage == batch->npfns[cur]) {
685 			npage = 0;
686 			cur++;
687 		}
688 		if (!length)
689 			break;
690 	}
691 	return copied;
692 }
693 
694 /* pfn_reader_user is just the pin_user_pages() path */
695 struct pfn_reader_user {
696 	struct page **upages;
697 	size_t upages_len;
698 	unsigned long upages_start;
699 	unsigned long upages_end;
700 	unsigned int gup_flags;
701 	/*
702 	 * 1 means mmget() and mmap_read_lock(), 0 means only mmget(), -1 is
703 	 * neither
704 	 */
705 	int locked;
706 };
707 
pfn_reader_user_init(struct pfn_reader_user * user,struct iopt_pages * pages)708 static void pfn_reader_user_init(struct pfn_reader_user *user,
709 				 struct iopt_pages *pages)
710 {
711 	user->upages = NULL;
712 	user->upages_start = 0;
713 	user->upages_end = 0;
714 	user->locked = -1;
715 
716 	user->gup_flags = FOLL_LONGTERM;
717 	if (pages->writable)
718 		user->gup_flags |= FOLL_WRITE;
719 }
720 
pfn_reader_user_destroy(struct pfn_reader_user * user,struct iopt_pages * pages)721 static void pfn_reader_user_destroy(struct pfn_reader_user *user,
722 				    struct iopt_pages *pages)
723 {
724 	if (user->locked != -1) {
725 		if (user->locked)
726 			mmap_read_unlock(pages->source_mm);
727 		if (pages->source_mm != current->mm)
728 			mmput(pages->source_mm);
729 		user->locked = -1;
730 	}
731 
732 	kfree(user->upages);
733 	user->upages = NULL;
734 }
735 
pfn_reader_user_pin(struct pfn_reader_user * user,struct iopt_pages * pages,unsigned long start_index,unsigned long last_index)736 static int pfn_reader_user_pin(struct pfn_reader_user *user,
737 			       struct iopt_pages *pages,
738 			       unsigned long start_index,
739 			       unsigned long last_index)
740 {
741 	bool remote_mm = pages->source_mm != current->mm;
742 	unsigned long npages;
743 	uintptr_t uptr;
744 	long rc;
745 
746 	if (IS_ENABLED(CONFIG_IOMMUFD_TEST) &&
747 	    WARN_ON(last_index < start_index))
748 		return -EINVAL;
749 
750 	if (!user->upages) {
751 		/* All undone in pfn_reader_destroy() */
752 		user->upages_len =
753 			(last_index - start_index + 1) * sizeof(*user->upages);
754 		user->upages = temp_kmalloc(&user->upages_len, NULL, 0);
755 		if (!user->upages)
756 			return -ENOMEM;
757 	}
758 
759 	if (user->locked == -1) {
760 		/*
761 		 * The majority of usages will run the map task within the mm
762 		 * providing the pages, so we can optimize into
763 		 * get_user_pages_fast()
764 		 */
765 		if (remote_mm) {
766 			if (!mmget_not_zero(pages->source_mm))
767 				return -EFAULT;
768 		}
769 		user->locked = 0;
770 	}
771 
772 	npages = min_t(unsigned long, last_index - start_index + 1,
773 		       user->upages_len / sizeof(*user->upages));
774 
775 
776 	if (iommufd_should_fail())
777 		return -EFAULT;
778 
779 	uptr = (uintptr_t)(pages->uptr + start_index * PAGE_SIZE);
780 	if (!remote_mm)
781 		rc = pin_user_pages_fast(uptr, npages, user->gup_flags,
782 					 user->upages);
783 	else {
784 		if (!user->locked) {
785 			mmap_read_lock(pages->source_mm);
786 			user->locked = 1;
787 		}
788 		rc = pin_user_pages_remote(pages->source_mm, uptr, npages,
789 					   user->gup_flags, user->upages,
790 					   &user->locked);
791 	}
792 	if (rc <= 0) {
793 		if (WARN_ON(!rc))
794 			return -EFAULT;
795 		return rc;
796 	}
797 	iopt_pages_add_npinned(pages, rc);
798 	user->upages_start = start_index;
799 	user->upages_end = start_index + rc;
800 	return 0;
801 }
802 
803 /* This is the "modern" and faster accounting method used by io_uring */
incr_user_locked_vm(struct iopt_pages * pages,unsigned long npages)804 static int incr_user_locked_vm(struct iopt_pages *pages, unsigned long npages)
805 {
806 	unsigned long lock_limit;
807 	unsigned long cur_pages;
808 	unsigned long new_pages;
809 
810 	lock_limit = task_rlimit(pages->source_task, RLIMIT_MEMLOCK) >>
811 		     PAGE_SHIFT;
812 	do {
813 		cur_pages = atomic_long_read(&pages->source_user->locked_vm);
814 		new_pages = cur_pages + npages;
815 		if (new_pages > lock_limit)
816 			return -ENOMEM;
817 	} while (atomic_long_cmpxchg(&pages->source_user->locked_vm, cur_pages,
818 				     new_pages) != cur_pages);
819 	return 0;
820 }
821 
decr_user_locked_vm(struct iopt_pages * pages,unsigned long npages)822 static void decr_user_locked_vm(struct iopt_pages *pages, unsigned long npages)
823 {
824 	if (WARN_ON(atomic_long_read(&pages->source_user->locked_vm) < npages))
825 		return;
826 	atomic_long_sub(npages, &pages->source_user->locked_vm);
827 }
828 
829 /* This is the accounting method used for compatibility with VFIO */
update_mm_locked_vm(struct iopt_pages * pages,unsigned long npages,bool inc,struct pfn_reader_user * user)830 static int update_mm_locked_vm(struct iopt_pages *pages, unsigned long npages,
831 			       bool inc, struct pfn_reader_user *user)
832 {
833 	bool do_put = false;
834 	int rc;
835 
836 	if (user && user->locked) {
837 		mmap_read_unlock(pages->source_mm);
838 		user->locked = 0;
839 		/* If we had the lock then we also have a get */
840 	} else if ((!user || !user->upages) &&
841 		   pages->source_mm != current->mm) {
842 		if (!mmget_not_zero(pages->source_mm))
843 			return -EINVAL;
844 		do_put = true;
845 	}
846 
847 	mmap_write_lock(pages->source_mm);
848 	rc = __account_locked_vm(pages->source_mm, npages, inc,
849 				 pages->source_task, false);
850 	mmap_write_unlock(pages->source_mm);
851 
852 	if (do_put)
853 		mmput(pages->source_mm);
854 	return rc;
855 }
856 
/*
 * Charge (@inc true) or uncharge @npages according to pages->account_mode.
 * On success it also records npinned as last_npinned and adjusts
 * source_mm->pinned_vm. Returns 0 or the accounting error, in which case
 * nothing else is changed.
 */
static int do_update_pinned(struct iopt_pages *pages, unsigned long npages,
			    bool inc, struct pfn_reader_user *user)
{
	int rc = 0;

	switch (pages->account_mode) {
	case IOPT_PAGES_ACCOUNT_NONE:
		break;
	case IOPT_PAGES_ACCOUNT_USER:
		if (!inc) {
			decr_user_locked_vm(pages, npages);
			break;
		}
		rc = incr_user_locked_vm(pages, npages);
		break;
	case IOPT_PAGES_ACCOUNT_MM:
		rc = update_mm_locked_vm(pages, npages, inc, user);
		break;
	}
	if (rc)
		return rc;

	pages->last_npinned = pages->npinned;
	if (!inc)
		atomic64_sub(npages, &pages->source_mm->pinned_vm);
	else
		atomic64_add(npages, &pages->source_mm->pinned_vm);
	return 0;
}
885 
update_unpinned(struct iopt_pages * pages)886 static void update_unpinned(struct iopt_pages *pages)
887 {
888 	if (WARN_ON(pages->npinned > pages->last_npinned))
889 		return;
890 	if (pages->npinned == pages->last_npinned)
891 		return;
892 	do_update_pinned(pages, pages->last_npinned - pages->npinned, false,
893 			 NULL);
894 }
895 
896 /*
897  * Changes in the number of pages pinned is done after the pages have been read
898  * and processed. If the user lacked the limit then the error unwind will unpin
899  * everything that was just pinned. This is because it is expensive to calculate
900  * how many pages we have already pinned within a range to generate an accurate
901  * prediction in advance of doing the work to actually pin them.
902  */
pfn_reader_user_update_pinned(struct pfn_reader_user * user,struct iopt_pages * pages)903 static int pfn_reader_user_update_pinned(struct pfn_reader_user *user,
904 					 struct iopt_pages *pages)
905 {
906 	unsigned long npages;
907 	bool inc;
908 
909 	lockdep_assert_held(&pages->mutex);
910 
911 	if (pages->npinned == pages->last_npinned)
912 		return 0;
913 
914 	if (pages->npinned < pages->last_npinned) {
915 		npages = pages->last_npinned - pages->npinned;
916 		inc = false;
917 	} else {
918 		if (iommufd_should_fail())
919 			return -ENOMEM;
920 		npages = pages->npinned - pages->last_npinned;
921 		inc = true;
922 	}
923 	return do_update_pinned(pages, npages, inc, user);
924 }
925 
/*
 * PFNs are stored in three places, in order of preference:
 * - The iopt_pages xarray. This is only populated if there is an
 *   iopt_pages_access
 * - The iommu_domain under an area
 * - The original PFN source, i.e. pages->source_mm
 *
 * This iterator reads the PFNs, loading each one from the most preferred of
 * the above locations that currently holds it.
 *
 * pages->mutex must be held while the reader is in use.
 */
struct pfn_reader {
	/* The pages whose PFNs are being read */
	struct iopt_pages *pages;
	/* Walks access_itree and domains_itree to choose the storage tier */
	struct interval_tree_double_span_iter span;
	/* PFNs currently loaded, covering [batch_start_index, batch_end_index) */
	struct pfn_batch batch;
	unsigned long batch_start_index;
	unsigned long batch_end_index;
	/* Inclusive last index the reader will return */
	unsigned long last_index;

	/* State used to pin holes from the source mm */
	struct pfn_reader_user user;
};
946 
pfn_reader_update_pinned(struct pfn_reader * pfns)947 static int pfn_reader_update_pinned(struct pfn_reader *pfns)
948 {
949 	return pfn_reader_user_update_pinned(&pfns->user, pfns->pages);
950 }
951 
952 /*
953  * The batch can contain a mixture of pages that are still in use and pages that
954  * need to be unpinned. Unpin only pages that are not held anywhere else.
955  */
pfn_reader_unpin(struct pfn_reader * pfns)956 static void pfn_reader_unpin(struct pfn_reader *pfns)
957 {
958 	unsigned long last = pfns->batch_end_index - 1;
959 	unsigned long start = pfns->batch_start_index;
960 	struct interval_tree_double_span_iter span;
961 	struct iopt_pages *pages = pfns->pages;
962 
963 	lockdep_assert_held(&pages->mutex);
964 
965 	interval_tree_for_each_double_span(&span, &pages->access_itree,
966 					   &pages->domains_itree, start, last) {
967 		if (span.is_used)
968 			continue;
969 
970 		batch_unpin(&pfns->batch, pages, span.start_hole - start,
971 			    span.last_hole - span.start_hole + 1);
972 	}
973 }
974 
975 /* Process a single span to load it from the proper storage */
pfn_reader_fill_span(struct pfn_reader * pfns)976 static int pfn_reader_fill_span(struct pfn_reader *pfns)
977 {
978 	struct interval_tree_double_span_iter *span = &pfns->span;
979 	unsigned long start_index = pfns->batch_end_index;
980 	struct iopt_area *area;
981 	int rc;
982 
983 	if (IS_ENABLED(CONFIG_IOMMUFD_TEST) &&
984 	    WARN_ON(span->last_used < start_index))
985 		return -EINVAL;
986 
987 	if (span->is_used == 1) {
988 		batch_from_xarray(&pfns->batch, &pfns->pages->pinned_pfns,
989 				  start_index, span->last_used);
990 		return 0;
991 	}
992 
993 	if (span->is_used == 2) {
994 		/*
995 		 * Pull as many pages from the first domain we find in the
996 		 * target span. If it is too small then we will be called again
997 		 * and we'll find another area.
998 		 */
999 		area = iopt_pages_find_domain_area(pfns->pages, start_index);
1000 		if (WARN_ON(!area))
1001 			return -EINVAL;
1002 
1003 		/* The storage_domain cannot change without the pages mutex */
1004 		batch_from_domain(
1005 			&pfns->batch, area->storage_domain, area, start_index,
1006 			min(iopt_area_last_index(area), span->last_used));
1007 		return 0;
1008 	}
1009 
1010 	if (start_index >= pfns->user.upages_end) {
1011 		rc = pfn_reader_user_pin(&pfns->user, pfns->pages, start_index,
1012 					 span->last_hole);
1013 		if (rc)
1014 			return rc;
1015 	}
1016 
1017 	batch_from_pages(&pfns->batch,
1018 			 pfns->user.upages +
1019 				 (start_index - pfns->user.upages_start),
1020 			 pfns->user.upages_end - start_index);
1021 	return 0;
1022 }
1023 
pfn_reader_done(struct pfn_reader * pfns)1024 static bool pfn_reader_done(struct pfn_reader *pfns)
1025 {
1026 	return pfns->batch_start_index == pfns->last_index + 1;
1027 }
1028 
/*
 * Advance to the next batch. The current batch is cleared and refilled from
 * the old batch_end_index, pulling from successive spans until the batch
 * stops growing or last_index is reached. On return the batch covers
 * [batch_start_index, batch_end_index).
 */
static int pfn_reader_next(struct pfn_reader *pfns)
{
	int rc;

	batch_clear(&pfns->batch);
	pfns->batch_start_index = pfns->batch_end_index;

	while (pfns->batch_end_index != pfns->last_index + 1) {
		/* Size before this fill, used below to detect a full batch */
		unsigned int npfns = pfns->batch.total_pfns;

		if (IS_ENABLED(CONFIG_IOMMUFD_TEST) &&
		    WARN_ON(interval_tree_double_span_iter_done(&pfns->span)))
			return -EINVAL;

		rc = pfn_reader_fill_span(pfns);
		if (rc)
			return rc;

		if (WARN_ON(!pfns->batch.total_pfns))
			return -EINVAL;

		pfns->batch_end_index =
			pfns->batch_start_index + pfns->batch.total_pfns;
		/* The current span is fully consumed, move to the next one */
		if (pfns->batch_end_index == pfns->span.last_used + 1)
			interval_tree_double_span_iter_next(&pfns->span);

		/* Batch is full: this fill could not add any more PFNs */
		if (npfns == pfns->batch.total_pfns)
			return 0;
	}
	return 0;
}
1061 
pfn_reader_init(struct pfn_reader * pfns,struct iopt_pages * pages,unsigned long start_index,unsigned long last_index)1062 static int pfn_reader_init(struct pfn_reader *pfns, struct iopt_pages *pages,
1063 			   unsigned long start_index, unsigned long last_index)
1064 {
1065 	int rc;
1066 
1067 	lockdep_assert_held(&pages->mutex);
1068 
1069 	pfns->pages = pages;
1070 	pfns->batch_start_index = start_index;
1071 	pfns->batch_end_index = start_index;
1072 	pfns->last_index = last_index;
1073 	pfn_reader_user_init(&pfns->user, pages);
1074 	rc = batch_init(&pfns->batch, last_index - start_index + 1);
1075 	if (rc)
1076 		return rc;
1077 	interval_tree_double_span_iter_first(&pfns->span, &pages->access_itree,
1078 					     &pages->domains_itree, start_index,
1079 					     last_index);
1080 	return 0;
1081 }
1082 
/*
 * There are many assertions regarding the state of pages->npinned vs
 * pages->last_npinned. For instance, unmapping a domain must only decrement
 * npinned, and pfn_reader_destroy() must be called only after all the pins
 * are updated. This is fine for success flows, but error flows sometimes need
 * to release the pins held inside the pfn_reader before going on to complete
 * unmapping and releasing pins held in domains.
 *
 * Calling this more than once is harmless. Both steps move their start marker
 * up to batch_end_index, so a second call finds nothing left to release.
 */
static void pfn_reader_release_pins(struct pfn_reader *pfns)
{
	struct iopt_pages *pages = pfns->pages;

	if (pfns->user.upages_end > pfns->batch_end_index) {
		size_t npages = pfns->user.upages_end - pfns->batch_end_index;

		/* Any pages not transferred to the batch are just unpinned */
		unpin_user_pages(pfns->user.upages + (pfns->batch_end_index -
						      pfns->user.upages_start),
				 npages);
		iopt_pages_sub_npinned(pages, npages);
		pfns->user.upages_end = pfns->batch_end_index;
	}
	if (pfns->batch_start_index != pfns->batch_end_index) {
		/* Unpin batch PFNs that no access or domain still holds */
		pfn_reader_unpin(pfns);
		pfns->batch_start_index = pfns->batch_end_index;
	}
}
1110 
pfn_reader_destroy(struct pfn_reader * pfns)1111 static void pfn_reader_destroy(struct pfn_reader *pfns)
1112 {
1113 	struct iopt_pages *pages = pfns->pages;
1114 
1115 	pfn_reader_release_pins(pfns);
1116 	pfn_reader_user_destroy(&pfns->user, pfns->pages);
1117 	batch_destroy(&pfns->batch, NULL);
1118 	WARN_ON(pages->last_npinned != pages->npinned);
1119 }
1120 
pfn_reader_first(struct pfn_reader * pfns,struct iopt_pages * pages,unsigned long start_index,unsigned long last_index)1121 static int pfn_reader_first(struct pfn_reader *pfns, struct iopt_pages *pages,
1122 			    unsigned long start_index, unsigned long last_index)
1123 {
1124 	int rc;
1125 
1126 	if (IS_ENABLED(CONFIG_IOMMUFD_TEST) &&
1127 	    WARN_ON(last_index < start_index))
1128 		return -EINVAL;
1129 
1130 	rc = pfn_reader_init(pfns, pages, start_index, last_index);
1131 	if (rc)
1132 		return rc;
1133 	rc = pfn_reader_next(pfns);
1134 	if (rc) {
1135 		pfn_reader_destroy(pfns);
1136 		return rc;
1137 	}
1138 	return 0;
1139 }
1140 
/*
 * Allocate an iopt_pages for the user VA range starting at @uptr of @length
 * bytes, widened out to whole pages. References to the current mm, the
 * thread group leader and the current user are taken; iopt_release_pages()
 * drops them. Returns an ERR_PTR on failure.
 */
struct iopt_pages *iopt_alloc_pages(void __user *uptr, unsigned long length,
				    bool writable)
{
	struct iopt_pages *pages;
	unsigned long end;

	/*
	 * Lengths are handed to the iommu API as size_t. This bound also keeps
	 * the DIV_ROUND_UP() below from overflowing.
	 */
	if (length == 0 || length > SIZE_MAX - PAGE_SIZE)
		return ERR_PTR(-EINVAL);

	if (check_add_overflow((unsigned long)uptr, length, &end))
		return ERR_PTR(-EOVERFLOW);

	pages = kzalloc(sizeof(*pages), GFP_KERNEL_ACCOUNT);
	if (!pages)
		return ERR_PTR(-ENOMEM);

	kref_init(&pages->kref);
	xa_init_flags(&pages->pinned_pfns, XA_FLAGS_ACCOUNT);
	mutex_init(&pages->mutex);
	pages->access_itree = RB_ROOT_CACHED;
	pages->domains_itree = RB_ROOT_CACHED;

	pages->source_mm = current->mm;
	mmgrab(pages->source_mm);
	pages->source_task = current->group_leader;
	get_task_struct(current->group_leader);
	pages->source_user = get_uid(current_user());

	pages->uptr = (void __user *)ALIGN_DOWN((uintptr_t)uptr, PAGE_SIZE);
	pages->npages = DIV_ROUND_UP(length + (uptr - pages->uptr), PAGE_SIZE);
	pages->writable = writable;
	pages->account_mode = capable(CAP_IPC_LOCK) ? IOPT_PAGES_ACCOUNT_NONE :
						      IOPT_PAGES_ACCOUNT_USER;
	return pages;
}
1180 
iopt_release_pages(struct kref * kref)1181 void iopt_release_pages(struct kref *kref)
1182 {
1183 	struct iopt_pages *pages = container_of(kref, struct iopt_pages, kref);
1184 
1185 	WARN_ON(!RB_EMPTY_ROOT(&pages->access_itree.rb_root));
1186 	WARN_ON(!RB_EMPTY_ROOT(&pages->domains_itree.rb_root));
1187 	WARN_ON(pages->npinned);
1188 	WARN_ON(!xa_empty(&pages->pinned_pfns));
1189 	mmdrop(pages->source_mm);
1190 	mutex_destroy(&pages->mutex);
1191 	put_task_struct(pages->source_task);
1192 	free_uid(pages->source_user);
1193 	kfree(pages);
1194 }
1195 
/*
 * Unmap @domain over the hole [start_index, last_index] of @area and unpin
 * the PFNs in it, one batch at a time.
 *
 * *unmapped_end_index records how far @domain has already been unmapped and
 * is advanced here. Unmaps may only cut between non-contiguous PFNs, so the
 * read (and the unmap) can run past last_index, up to real_last_index. PFNs
 * read beyond last_index stay in @batch as a carry for the caller's next span.
 */
static void
iopt_area_unpin_domain(struct pfn_batch *batch, struct iopt_area *area,
		       struct iopt_pages *pages, struct iommu_domain *domain,
		       unsigned long start_index, unsigned long last_index,
		       unsigned long *unmapped_end_index,
		       unsigned long real_last_index)
{
	while (start_index <= last_index) {
		unsigned long batch_last_index;

		if (*unmapped_end_index <= last_index) {
			/*
			 * Part of this hole is still mapped. Read its PFNs from
			 * the domain after any carry already in the batch.
			 */
			unsigned long start =
				max(start_index, *unmapped_end_index);

			if (IS_ENABLED(CONFIG_IOMMUFD_TEST) &&
			    batch->total_pfns)
				WARN_ON(*unmapped_end_index -
						batch->total_pfns !=
					start_index);
			batch_from_domain(batch, domain, area, start,
					  last_index);
			batch_last_index = start_index + batch->total_pfns - 1;
		} else {
			/* Already unmapped; the carry covers the whole hole */
			batch_last_index = last_index;
		}

		if (IS_ENABLED(CONFIG_IOMMUFD_TEST))
			WARN_ON(batch_last_index > real_last_index);

		/*
		 * unmaps must always 'cut' at a place where the pfns are not
		 * contiguous to pair with the maps that always install
		 * contiguous pages. Thus, if we have to stop unpinning in the
		 * middle of the domains we need to keep reading pfns until we
		 * find a cut point to do the unmap. The pfns we read are
		 * carried over and either skipped or integrated into the next
		 * batch.
		 */
		if (batch_last_index == last_index &&
		    last_index != real_last_index)
			batch_from_domain_continue(batch, domain, area,
						   last_index + 1,
						   real_last_index);

		if (*unmapped_end_index <= batch_last_index) {
			iopt_area_unmap_domain_range(
				area, domain, *unmapped_end_index,
				start_index + batch->total_pfns - 1);
			*unmapped_end_index = start_index + batch->total_pfns;
		}

		/* unpin must follow unmap */
		batch_unpin(batch, pages, 0,
			    batch_last_index - start_index + 1);
		start_index = batch_last_index + 1;

		/* Keep unmapped-but-not-unpinned PFNs as the carry */
		batch_clear_carry(batch,
				  *unmapped_end_index - batch_last_index - 1);
	}
}
1256 
/*
 * Unmap indexes [iopt_area_index(area), last_index] of @area from @domain.
 * Every PFN in that range that is not used by another domain or by an access
 * is also unpinned, and the pin accounting is then updated.
 */
static void __iopt_area_unfill_domain(struct iopt_area *area,
				      struct iopt_pages *pages,
				      struct iommu_domain *domain,
				      unsigned long last_index)
{
	struct interval_tree_double_span_iter span;
	unsigned long start_index = iopt_area_index(area);
	unsigned long unmapped_end_index = start_index;
	u64 backup[BATCH_BACKUP_SIZE];
	struct pfn_batch batch;

	lockdep_assert_held(&pages->mutex);

	/*
	 * For security we must not unpin something that is still DMA mapped,
	 * so this must unmap any IOVA before we go ahead and unpin the pages.
	 * This creates a complexity where we need to skip over unpinning pages
	 * held in the xarray, but continue to unmap from the domain.
	 *
	 * The domain unmap cannot stop in the middle of a contiguous range of
	 * PFNs. To solve this problem the unpinning step will read ahead to the
	 * end of any contiguous span, unmap that whole span, and then only
	 * unpin the leading part that does not have any accesses. The residual
	 * PFNs that were unmapped but not unpinned are called a "carry" in the
	 * batch as they are moved to the front of the PFN list and continue on
	 * to the next iteration(s).
	 */
	batch_init_backup(&batch, last_index + 1, backup, sizeof(backup));
	interval_tree_for_each_double_span(&span, &pages->domains_itree,
					   &pages->access_itree, start_index,
					   last_index) {
		if (span.is_used) {
			/* Still used elsewhere: skip carried PFNs, no unpin */
			batch_skip_carry(&batch,
					 span.last_used - span.start_used + 1);
			continue;
		}
		iopt_area_unpin_domain(&batch, area, pages, domain,
				       span.start_hole, span.last_hole,
				       &unmapped_end_index, last_index);
	}
	/*
	 * If the range ends in an access then we do the residual unmap without
	 * any unpins.
	 */
	if (unmapped_end_index != last_index + 1)
		iopt_area_unmap_domain_range(area, domain, unmapped_end_index,
					     last_index);
	WARN_ON(batch.total_pfns);
	batch_destroy(&batch, backup);
	/* Release the accounting for the pins dropped above */
	update_unpinned(pages);
}
1308 
/*
 * Unmap and unpin the prefix [iopt_area_index(area), end_index) of @area
 * that was filled into @domain. An empty prefix is a no-op.
 */
static void iopt_area_unfill_partial_domain(struct iopt_area *area,
					    struct iopt_pages *pages,
					    struct iommu_domain *domain,
					    unsigned long end_index)
{
	if (end_index == iopt_area_index(area))
		return;
	__iopt_area_unfill_domain(area, pages, domain, end_index - 1);
}
1317 
/**
 * iopt_area_unmap_domain() - Unmap without unpinning PFNs in a domain
 * @area: The IOVA range to unmap
 * @domain: The domain to unmap
 *
 * Removes the whole IOVA range of @area from @domain and leaves the PFNs
 * pinned. The caller must know that unpinning is not required, usually
 * because there are other domains in the iopt.
 */
void iopt_area_unmap_domain(struct iopt_area *area, struct iommu_domain *domain)
{
	unsigned long iova = iopt_area_iova(area);
	size_t length = iopt_area_length(area);

	iommu_unmap_nofail(domain, iova, length);
}
1331 
/**
 * iopt_area_unfill_domain() - Unmap and unpin PFNs in a domain
 * @area: IOVA area to use
 * @pages: page supplier for the area (area->pages is NULL)
 * @domain: Domain to unmap from
 *
 * Covers the whole area, from iopt_area_index() through
 * iopt_area_last_index(). The domain should be removed from the domains_itree
 * before calling. The domain is always unmapped, but PFNs that are still used
 * by accesses stay pinned.
 */
void iopt_area_unfill_domain(struct iopt_area *area, struct iopt_pages *pages,
			     struct iommu_domain *domain)
{
	unsigned long last_index = iopt_area_last_index(area);

	__iopt_area_unfill_domain(area, pages, domain, last_index);
}
1348 
/**
 * iopt_area_fill_domain() - Map PFNs from the area into a domain
 * @area: IOVA area to use
 * @domain: Domain to load PFNs into
 *
 * Read the pfns from the area's underlying iopt_pages and map them into the
 * given domain. Called when attaching a new domain to an io_pagetable.
 *
 * The caller must hold area->pages->mutex. On failure, the part of @domain
 * that this call mapped is unmapped again and the reader's pins are dropped.
 *
 * Return: 0 on success, negative errno on failure.
 */
int iopt_area_fill_domain(struct iopt_area *area, struct iommu_domain *domain)
{
	unsigned long done_end_index;
	struct pfn_reader pfns;
	int rc;

	lockdep_assert_held(&area->pages->mutex);

	rc = pfn_reader_first(&pfns, area->pages, iopt_area_index(area),
			      iopt_area_last_index(area));
	if (rc)
		return rc;

	while (!pfn_reader_done(&pfns)) {
		/*
		 * done_end_index is the exclusive end of the prefix of the
		 * area known to be mapped into @domain.
		 */
		done_end_index = pfns.batch_start_index;
		rc = batch_to_domain(&pfns.batch, domain, area,
				     pfns.batch_start_index);
		if (rc)
			goto out_unmap;
		done_end_index = pfns.batch_end_index;

		rc = pfn_reader_next(&pfns);
		if (rc)
			goto out_unmap;
	}

	/* Account the new pins; exceeding the limit unwinds the whole fill */
	rc = pfn_reader_update_pinned(&pfns);
	if (rc)
		goto out_unmap;
	goto out_destroy;

out_unmap:
	/* Drop the reader's own pins before unmapping/unpinning the domain */
	pfn_reader_release_pins(&pfns);
	iopt_area_unfill_partial_domain(area, area->pages, domain,
					done_end_index);
out_destroy:
	pfn_reader_destroy(&pfns);
	return rc;
}
1396 
/**
 * iopt_area_fill_domains() - Install PFNs into the area's domains
 * @area: The area to act on
 * @pages: The pages associated with the area (area->pages is NULL)
 *
 * Called during area creation. The area is freshly created and not inserted in
 * the domains_itree yet. PFNs are read and loaded into every domain held in the
 * area's io_pagetable and the area is installed in the domains_itree.
 *
 * On failure all domains are left unchanged.
 *
 * The caller must hold area->iopt->domains_rwsem; pages->mutex is taken here.
 * On success the domain stored at index 0 becomes the area's storage_domain.
 *
 * Return: 0 on success (including when the io_pagetable has no domains),
 * negative errno on failure.
 */
int iopt_area_fill_domains(struct iopt_area *area, struct iopt_pages *pages)
{
	unsigned long done_first_end_index;
	unsigned long done_all_end_index;
	struct iommu_domain *domain;
	unsigned long unmap_index;
	struct pfn_reader pfns;
	unsigned long index;
	int rc;

	lockdep_assert_held(&area->iopt->domains_rwsem);

	if (xa_empty(&area->iopt->domains))
		return 0;

	mutex_lock(&pages->mutex);
	rc = pfn_reader_first(&pfns, pages, iopt_area_index(area),
			      iopt_area_last_index(area));
	if (rc)
		goto out_unlock;

	while (!pfn_reader_done(&pfns)) {
		/*
		 * done_first_end_index: exclusive end of what the domains
		 * before a failing one have mapped, including this batch.
		 * done_all_end_index: exclusive end of what every domain has
		 * mapped.
		 */
		done_first_end_index = pfns.batch_end_index;
		done_all_end_index = pfns.batch_start_index;
		xa_for_each(&area->iopt->domains, index, domain) {
			rc = batch_to_domain(&pfns.batch, domain, area,
					     pfns.batch_start_index);
			if (rc)
				goto out_unmap;
		}
		done_all_end_index = done_first_end_index;

		rc = pfn_reader_next(&pfns);
		if (rc)
			goto out_unmap;
	}
	rc = pfn_reader_update_pinned(&pfns);
	if (rc)
		goto out_unmap;

	area->storage_domain = xa_load(&area->iopt->domains, 0);
	interval_tree_insert(&area->pages_node, &pages->domains_itree);
	goto out_destroy;

out_unmap:
	pfn_reader_release_pins(&pfns);
	xa_for_each(&area->iopt->domains, unmap_index, domain) {
		unsigned long end_index;

		/* Domains before the one at @index also got the current batch */
		if (unmap_index < index)
			end_index = done_first_end_index;
		else
			end_index = done_all_end_index;

		/*
		 * The area is not yet part of the domains_itree so we have to
		 * manage the unpinning specially. The last domain does the
		 * unpin, every other domain is just unmapped.
		 */
		if (unmap_index != area->iopt->next_domain_id - 1) {
			if (end_index != iopt_area_index(area))
				iopt_area_unmap_domain_range(
					area, domain, iopt_area_index(area),
					end_index - 1);
		} else {
			iopt_area_unfill_partial_domain(area, pages, domain,
							end_index);
		}
	}
out_destroy:
	pfn_reader_destroy(&pfns);
out_unlock:
	mutex_unlock(&pages->mutex);
	return rc;
}
1483 
1484 /**
1485  * iopt_area_unfill_domains() - unmap PFNs from the area's domains
1486  * @area: The area to act on
1487  * @pages: The pages associated with the area (area->pages is NULL)
1488  *
1489  * Called during area destruction. This unmaps the iova's covered by all the
1490  * area's domains and releases the PFNs.
1491  */
iopt_area_unfill_domains(struct iopt_area * area,struct iopt_pages * pages)1492 void iopt_area_unfill_domains(struct iopt_area *area, struct iopt_pages *pages)
1493 {
1494 	struct io_pagetable *iopt = area->iopt;
1495 	struct iommu_domain *domain;
1496 	unsigned long index;
1497 
1498 	lockdep_assert_held(&iopt->domains_rwsem);
1499 
1500 	mutex_lock(&pages->mutex);
1501 	if (!area->storage_domain)
1502 		goto out_unlock;
1503 
1504 	xa_for_each(&iopt->domains, index, domain)
1505 		if (domain != area->storage_domain)
1506 			iopt_area_unmap_domain_range(
1507 				area, domain, iopt_area_index(area),
1508 				iopt_area_last_index(area));
1509 
1510 	if (IS_ENABLED(CONFIG_IOMMUFD_TEST))
1511 		WARN_ON(RB_EMPTY_NODE(&area->pages_node.rb));
1512 	interval_tree_remove(&area->pages_node, &pages->domains_itree);
1513 	iopt_area_unfill_domain(area, pages, area->storage_domain);
1514 	area->storage_domain = NULL;
1515 out_unlock:
1516 	mutex_unlock(&pages->mutex);
1517 }
1518 
iopt_pages_unpin_xarray(struct pfn_batch * batch,struct iopt_pages * pages,unsigned long start_index,unsigned long end_index)1519 static void iopt_pages_unpin_xarray(struct pfn_batch *batch,
1520 				    struct iopt_pages *pages,
1521 				    unsigned long start_index,
1522 				    unsigned long end_index)
1523 {
1524 	while (start_index <= end_index) {
1525 		batch_from_xarray_clear(batch, &pages->pinned_pfns, start_index,
1526 					end_index);
1527 		batch_unpin(batch, pages, 0, batch->total_pfns);
1528 		start_index += batch->total_pfns;
1529 		batch_clear(batch);
1530 	}
1531 }
1532 
/**
 * iopt_pages_unfill_xarray() - Update the xarray after removing an access
 * @pages: The pages to act on
 * @start_index: Starting PFN index
 * @last_index: Last PFN index
 *
 * Called when an iopt_pages_access is removed; removes its pages from the
 * xarray. The access should already be removed from the access_itree.
 *
 * Indexes no longer used by any access or domain are taken out of the xarray
 * and unpinned. Indexes still covered by a domain only have their xarray
 * entries cleared. Indexes still covered by another access are left alone.
 * The caller must hold pages->mutex.
 */
void iopt_pages_unfill_xarray(struct iopt_pages *pages,
			      unsigned long start_index,
			      unsigned long last_index)
{
	struct interval_tree_double_span_iter span;
	u64 backup[BATCH_BACKUP_SIZE];
	struct pfn_batch batch;
	bool batch_inited = false;

	lockdep_assert_held(&pages->mutex);

	interval_tree_for_each_double_span(&span, &pages->access_itree,
					   &pages->domains_itree, start_index,
					   last_index) {
		if (!span.is_used) {
			/* Only set up the batch once a hole needs unpinning */
			if (!batch_inited) {
				batch_init_backup(&batch,
						  last_index - start_index + 1,
						  backup, sizeof(backup));
				batch_inited = true;
			}
			iopt_pages_unpin_xarray(&batch, pages, span.start_hole,
						span.last_hole);
		} else if (span.is_used == 2) {
			/* Covered by a domain, which keeps the PFNs pinned */
			clear_xarray(&pages->pinned_pfns, span.start_used,
				     span.last_used);
		}
		/* Otherwise covered by an existing access */
	}
	if (batch_inited)
		batch_destroy(&batch, backup);
	update_unpinned(pages);
}
1576 
1577 /**
1578  * iopt_pages_fill_from_xarray() - Fast path for reading PFNs
1579  * @pages: The pages to act on
1580  * @start_index: The first page index in the range
1581  * @last_index: The last page index in the range
1582  * @out_pages: The output array to return the pages
1583  *
1584  * This can be called if the caller is holding a refcount on an
1585  * iopt_pages_access that is known to have already been filled. It quickly reads
1586  * the pages directly from the xarray.
1587  *
1588  * This is part of the SW iommu interface to read pages for in-kernel use.
1589  */
iopt_pages_fill_from_xarray(struct iopt_pages * pages,unsigned long start_index,unsigned long last_index,struct page ** out_pages)1590 void iopt_pages_fill_from_xarray(struct iopt_pages *pages,
1591 				 unsigned long start_index,
1592 				 unsigned long last_index,
1593 				 struct page **out_pages)
1594 {
1595 	XA_STATE(xas, &pages->pinned_pfns, start_index);
1596 	void *entry;
1597 
1598 	rcu_read_lock();
1599 	while (start_index <= last_index) {
1600 		entry = xas_next(&xas);
1601 		if (xas_retry(&xas, entry))
1602 			continue;
1603 		WARN_ON(!xa_is_value(entry));
1604 		*(out_pages++) = pfn_to_page(xa_to_value(entry));
1605 		start_index++;
1606 	}
1607 	rcu_read_unlock();
1608 }
1609 
iopt_pages_fill_from_domain(struct iopt_pages * pages,unsigned long start_index,unsigned long last_index,struct page ** out_pages)1610 static int iopt_pages_fill_from_domain(struct iopt_pages *pages,
1611 				       unsigned long start_index,
1612 				       unsigned long last_index,
1613 				       struct page **out_pages)
1614 {
1615 	while (start_index != last_index + 1) {
1616 		unsigned long domain_last;
1617 		struct iopt_area *area;
1618 
1619 		area = iopt_pages_find_domain_area(pages, start_index);
1620 		if (WARN_ON(!area))
1621 			return -EINVAL;
1622 
1623 		domain_last = min(iopt_area_last_index(area), last_index);
1624 		out_pages = raw_pages_from_domain(area->storage_domain, area,
1625 						  start_index, domain_last,
1626 						  out_pages);
1627 		start_index = domain_last + 1;
1628 	}
1629 	return 0;
1630 }
1631 
iopt_pages_fill_from_mm(struct iopt_pages * pages,struct pfn_reader_user * user,unsigned long start_index,unsigned long last_index,struct page ** out_pages)1632 static int iopt_pages_fill_from_mm(struct iopt_pages *pages,
1633 				   struct pfn_reader_user *user,
1634 				   unsigned long start_index,
1635 				   unsigned long last_index,
1636 				   struct page **out_pages)
1637 {
1638 	unsigned long cur_index = start_index;
1639 	int rc;
1640 
1641 	while (cur_index != last_index + 1) {
1642 		user->upages = out_pages + (cur_index - start_index);
1643 		rc = pfn_reader_user_pin(user, pages, cur_index, last_index);
1644 		if (rc)
1645 			goto out_unpin;
1646 		cur_index = user->upages_end;
1647 	}
1648 	return 0;
1649 
1650 out_unpin:
1651 	if (start_index != cur_index)
1652 		iopt_pages_err_unpin(pages, start_index, cur_index - 1,
1653 				     out_pages);
1654 	return rc;
1655 }
1656 
1657 /**
1658  * iopt_pages_fill_xarray() - Read PFNs
1659  * @pages: The pages to act on
1660  * @start_index: The first page index in the range
1661  * @last_index: The last page index in the range
1662  * @out_pages: The output array to return the pages, may be NULL
1663  *
1664  * This populates the xarray and returns the pages in out_pages. As the slow
1665  * path this is able to copy pages from other storage tiers into the xarray.
1666  *
1667  * On failure the xarray is left unchanged.
1668  *
1669  * This is part of the SW iommu interface to read pages for in-kernel use.
1670  */
iopt_pages_fill_xarray(struct iopt_pages * pages,unsigned long start_index,unsigned long last_index,struct page ** out_pages)1671 int iopt_pages_fill_xarray(struct iopt_pages *pages, unsigned long start_index,
1672 			   unsigned long last_index, struct page **out_pages)
1673 {
1674 	struct interval_tree_double_span_iter span;
1675 	unsigned long xa_end = start_index;
1676 	struct pfn_reader_user user;
1677 	int rc;
1678 
1679 	lockdep_assert_held(&pages->mutex);
1680 
1681 	pfn_reader_user_init(&user, pages);
1682 	user.upages_len = (last_index - start_index + 1) * sizeof(*out_pages);
1683 	interval_tree_for_each_double_span(&span, &pages->access_itree,
1684 					   &pages->domains_itree, start_index,
1685 					   last_index) {
1686 		struct page **cur_pages;
1687 
1688 		if (span.is_used == 1) {
1689 			cur_pages = out_pages + (span.start_used - start_index);
1690 			iopt_pages_fill_from_xarray(pages, span.start_used,
1691 						    span.last_used, cur_pages);
1692 			continue;
1693 		}
1694 
1695 		if (span.is_used == 2) {
1696 			cur_pages = out_pages + (span.start_used - start_index);
1697 			iopt_pages_fill_from_domain(pages, span.start_used,
1698 						    span.last_used, cur_pages);
1699 			rc = pages_to_xarray(&pages->pinned_pfns,
1700 					     span.start_used, span.last_used,
1701 					     cur_pages);
1702 			if (rc)
1703 				goto out_clean_xa;
1704 			xa_end = span.last_used + 1;
1705 			continue;
1706 		}
1707 
1708 		/* hole */
1709 		cur_pages = out_pages + (span.start_hole - start_index);
1710 		rc = iopt_pages_fill_from_mm(pages, &user, span.start_hole,
1711 					     span.last_hole, cur_pages);
1712 		if (rc)
1713 			goto out_clean_xa;
1714 		rc = pages_to_xarray(&pages->pinned_pfns, span.start_hole,
1715 				     span.last_hole, cur_pages);
1716 		if (rc) {
1717 			iopt_pages_err_unpin(pages, span.start_hole,
1718 					     span.last_hole, cur_pages);
1719 			goto out_clean_xa;
1720 		}
1721 		xa_end = span.last_hole + 1;
1722 	}
1723 	rc = pfn_reader_user_update_pinned(&user, pages);
1724 	if (rc)
1725 		goto out_clean_xa;
1726 	user.upages = NULL;
1727 	pfn_reader_user_destroy(&user, pages);
1728 	return 0;
1729 
1730 out_clean_xa:
1731 	if (start_index != xa_end)
1732 		iopt_pages_unfill_xarray(pages, start_index, xa_end - 1);
1733 	user.upages = NULL;
1734 	pfn_reader_user_destroy(&user, pages);
1735 	return rc;
1736 }
1737 
1738 /*
1739  * This uses the pfn_reader instead of taking a shortcut by using the mm. It can
1740  * do every scenario and is fully consistent with what an iommu_domain would
1741  * see.
1742  */
iopt_pages_rw_slow(struct iopt_pages * pages,unsigned long start_index,unsigned long last_index,unsigned long offset,void * data,unsigned long length,unsigned int flags)1743 static int iopt_pages_rw_slow(struct iopt_pages *pages,
1744 			      unsigned long start_index,
1745 			      unsigned long last_index, unsigned long offset,
1746 			      void *data, unsigned long length,
1747 			      unsigned int flags)
1748 {
1749 	struct pfn_reader pfns;
1750 	int rc;
1751 
1752 	mutex_lock(&pages->mutex);
1753 
1754 	rc = pfn_reader_first(&pfns, pages, start_index, last_index);
1755 	if (rc)
1756 		goto out_unlock;
1757 
1758 	while (!pfn_reader_done(&pfns)) {
1759 		unsigned long done;
1760 
1761 		done = batch_rw(&pfns.batch, data, offset, length, flags);
1762 		data += done;
1763 		length -= done;
1764 		offset = 0;
1765 		pfn_reader_unpin(&pfns);
1766 
1767 		rc = pfn_reader_next(&pfns);
1768 		if (rc)
1769 			goto out_destroy;
1770 	}
1771 	if (WARN_ON(length != 0))
1772 		rc = -EINVAL;
1773 out_destroy:
1774 	pfn_reader_destroy(&pfns);
1775 out_unlock:
1776 	mutex_unlock(&pages->mutex);
1777 	return rc;
1778 }
1779 
1780 /*
1781  * A medium speed path that still allows DMA inconsistencies, but doesn't do any
1782  * memory allocations or interval tree searches.
1783  */
iopt_pages_rw_page(struct iopt_pages * pages,unsigned long index,unsigned long offset,void * data,unsigned long length,unsigned int flags)1784 static int iopt_pages_rw_page(struct iopt_pages *pages, unsigned long index,
1785 			      unsigned long offset, void *data,
1786 			      unsigned long length, unsigned int flags)
1787 {
1788 	struct page *page = NULL;
1789 	int rc;
1790 
1791 	if (!mmget_not_zero(pages->source_mm))
1792 		return iopt_pages_rw_slow(pages, index, index, offset, data,
1793 					  length, flags);
1794 
1795 	if (iommufd_should_fail()) {
1796 		rc = -EINVAL;
1797 		goto out_mmput;
1798 	}
1799 
1800 	mmap_read_lock(pages->source_mm);
1801 	rc = pin_user_pages_remote(
1802 		pages->source_mm, (uintptr_t)(pages->uptr + index * PAGE_SIZE),
1803 		1, (flags & IOMMUFD_ACCESS_RW_WRITE) ? FOLL_WRITE : 0, &page,
1804 		NULL);
1805 	mmap_read_unlock(pages->source_mm);
1806 	if (rc != 1) {
1807 		if (WARN_ON(rc >= 0))
1808 			rc = -EINVAL;
1809 		goto out_mmput;
1810 	}
1811 	copy_data_page(page, data, offset, length, flags);
1812 	unpin_user_page(page);
1813 	rc = 0;
1814 
1815 out_mmput:
1816 	mmput(pages->source_mm);
1817 	return rc;
1818 }
1819 
1820 /**
1821  * iopt_pages_rw_access - Copy to/from a linear slice of the pages
1822  * @pages: pages to act on
1823  * @start_byte: First byte of pages to copy to/from
1824  * @data: Kernel buffer to get/put the data
1825  * @length: Number of bytes to copy
1826  * @flags: IOMMUFD_ACCESS_RW_* flags
1827  *
1828  * This will find each page in the range, kmap it and then memcpy to/from
1829  * the given kernel buffer.
1830  */
iopt_pages_rw_access(struct iopt_pages * pages,unsigned long start_byte,void * data,unsigned long length,unsigned int flags)1831 int iopt_pages_rw_access(struct iopt_pages *pages, unsigned long start_byte,
1832 			 void *data, unsigned long length, unsigned int flags)
1833 {
1834 	unsigned long start_index = start_byte / PAGE_SIZE;
1835 	unsigned long last_index = (start_byte + length - 1) / PAGE_SIZE;
1836 	bool change_mm = current->mm != pages->source_mm;
1837 	int rc = 0;
1838 
1839 	if (IS_ENABLED(CONFIG_IOMMUFD_TEST) &&
1840 	    (flags & __IOMMUFD_ACCESS_RW_SLOW_PATH))
1841 		change_mm = true;
1842 
1843 	if ((flags & IOMMUFD_ACCESS_RW_WRITE) && !pages->writable)
1844 		return -EPERM;
1845 
1846 	if (!(flags & IOMMUFD_ACCESS_RW_KTHREAD) && change_mm) {
1847 		if (start_index == last_index)
1848 			return iopt_pages_rw_page(pages, start_index,
1849 						  start_byte % PAGE_SIZE, data,
1850 						  length, flags);
1851 		return iopt_pages_rw_slow(pages, start_index, last_index,
1852 					  start_byte % PAGE_SIZE, data, length,
1853 					  flags);
1854 	}
1855 
1856 	/*
1857 	 * Try to copy using copy_to_user(). We do this as a fast path and
1858 	 * ignore any pinning inconsistencies, unlike a real DMA path.
1859 	 */
1860 	if (change_mm) {
1861 		if (!mmget_not_zero(pages->source_mm))
1862 			return iopt_pages_rw_slow(pages, start_index,
1863 						  last_index,
1864 						  start_byte % PAGE_SIZE, data,
1865 						  length, flags);
1866 		kthread_use_mm(pages->source_mm);
1867 	}
1868 
1869 	if (flags & IOMMUFD_ACCESS_RW_WRITE) {
1870 		if (copy_to_user(pages->uptr + start_byte, data, length))
1871 			rc = -EFAULT;
1872 	} else {
1873 		if (copy_from_user(data, pages->uptr + start_byte, length))
1874 			rc = -EFAULT;
1875 	}
1876 
1877 	if (change_mm) {
1878 		kthread_unuse_mm(pages->source_mm);
1879 		mmput(pages->source_mm);
1880 	}
1881 
1882 	return rc;
1883 }
1884 
1885 static struct iopt_pages_access *
iopt_pages_get_exact_access(struct iopt_pages * pages,unsigned long index,unsigned long last)1886 iopt_pages_get_exact_access(struct iopt_pages *pages, unsigned long index,
1887 			    unsigned long last)
1888 {
1889 	struct interval_tree_node *node;
1890 
1891 	lockdep_assert_held(&pages->mutex);
1892 
1893 	/* There can be overlapping ranges in this interval tree */
1894 	for (node = interval_tree_iter_first(&pages->access_itree, index, last);
1895 	     node; node = interval_tree_iter_next(node, index, last))
1896 		if (node->start == index && node->last == last)
1897 			return container_of(node, struct iopt_pages_access,
1898 					    node);
1899 	return NULL;
1900 }
1901 
1902 /**
1903  * iopt_area_add_access() - Record an in-knerel access for PFNs
1904  * @area: The source of PFNs
1905  * @start_index: First page index
1906  * @last_index: Inclusive last page index
1907  * @out_pages: Output list of struct page's representing the PFNs
1908  * @flags: IOMMUFD_ACCESS_RW_* flags
1909  *
1910  * Record that an in-kernel access will be accessing the pages, ensure they are
1911  * pinned, and return the PFNs as a simple list of 'struct page *'.
1912  *
1913  * This should be undone through a matching call to iopt_area_remove_access()
1914  */
iopt_area_add_access(struct iopt_area * area,unsigned long start_index,unsigned long last_index,struct page ** out_pages,unsigned int flags)1915 int iopt_area_add_access(struct iopt_area *area, unsigned long start_index,
1916 			  unsigned long last_index, struct page **out_pages,
1917 			  unsigned int flags)
1918 {
1919 	struct iopt_pages *pages = area->pages;
1920 	struct iopt_pages_access *access;
1921 	int rc;
1922 
1923 	if ((flags & IOMMUFD_ACCESS_RW_WRITE) && !pages->writable)
1924 		return -EPERM;
1925 
1926 	mutex_lock(&pages->mutex);
1927 	access = iopt_pages_get_exact_access(pages, start_index, last_index);
1928 	if (access) {
1929 		area->num_accesses++;
1930 		access->users++;
1931 		iopt_pages_fill_from_xarray(pages, start_index, last_index,
1932 					    out_pages);
1933 		mutex_unlock(&pages->mutex);
1934 		return 0;
1935 	}
1936 
1937 	access = kzalloc(sizeof(*access), GFP_KERNEL_ACCOUNT);
1938 	if (!access) {
1939 		rc = -ENOMEM;
1940 		goto err_unlock;
1941 	}
1942 
1943 	rc = iopt_pages_fill_xarray(pages, start_index, last_index, out_pages);
1944 	if (rc)
1945 		goto err_free;
1946 
1947 	access->node.start = start_index;
1948 	access->node.last = last_index;
1949 	access->users = 1;
1950 	area->num_accesses++;
1951 	interval_tree_insert(&access->node, &pages->access_itree);
1952 	mutex_unlock(&pages->mutex);
1953 	return 0;
1954 
1955 err_free:
1956 	kfree(access);
1957 err_unlock:
1958 	mutex_unlock(&pages->mutex);
1959 	return rc;
1960 }
1961 
1962 /**
1963  * iopt_area_remove_access() - Release an in-kernel access for PFNs
1964  * @area: The source of PFNs
1965  * @start_index: First page index
1966  * @last_index: Inclusive last page index
1967  *
1968  * Undo iopt_area_add_access() and unpin the pages if necessary. The caller
1969  * must stop using the PFNs before calling this.
1970  */
iopt_area_remove_access(struct iopt_area * area,unsigned long start_index,unsigned long last_index)1971 void iopt_area_remove_access(struct iopt_area *area, unsigned long start_index,
1972 			     unsigned long last_index)
1973 {
1974 	struct iopt_pages *pages = area->pages;
1975 	struct iopt_pages_access *access;
1976 
1977 	mutex_lock(&pages->mutex);
1978 	access = iopt_pages_get_exact_access(pages, start_index, last_index);
1979 	if (WARN_ON(!access))
1980 		goto out_unlock;
1981 
1982 	WARN_ON(area->num_accesses == 0 || access->users == 0);
1983 	area->num_accesses--;
1984 	access->users--;
1985 	if (access->users)
1986 		goto out_unlock;
1987 
1988 	interval_tree_remove(&access->node, &pages->access_itree);
1989 	iopt_pages_unfill_xarray(pages, start_index, last_index);
1990 	kfree(access);
1991 out_unlock:
1992 	mutex_unlock(&pages->mutex);
1993 }
1994