diff --git a/fs/userfaultfd.c b/fs/userfaultfd.c
index 1c44bf75f9160cd509f684f0727c8a36be085527..175de70e3adfdd68a88ef6c11e16e68c8f296a17 100644
--- a/fs/userfaultfd.c
+++ b/fs/userfaultfd.c
@@ -1601,6 +1601,10 @@ static int userfaultfd_unregister(struct userfaultfd_ctx *ctx,
 			wake_userfault(vma->vm_userfaultfd_ctx.ctx, &range);
 		}
 
+		/* Reset ptes for the whole vma range if wr-protected */
+		if (userfaultfd_wp(vma))
+			uffd_wp_range(mm, vma, start, vma_end - start, false);
+
 		new_flags = vma->vm_flags & ~__VM_UFFD_FLAGS;
 		prev = vma_merge(mm, prev, start, vma_end, new_flags,
 				 vma->anon_vma, vma->vm_file, vma->vm_pgoff,
diff --git a/include/linux/userfaultfd_k.h b/include/linux/userfaultfd_k.h
index 732b522bacb7e5c15a8d6e00f9f4d4e8c6daa4ea..e1b8a915e9e9fb1346fc2ffdc354afcdbf6627d3 100644
--- a/include/linux/userfaultfd_k.h
+++ b/include/linux/userfaultfd_k.h
@@ -73,6 +73,8 @@ extern ssize_t mcopy_continue(struct mm_struct *dst_mm, unsigned long dst_start,
 extern int mwriteprotect_range(struct mm_struct *dst_mm,
 			       unsigned long start, unsigned long len,
 			       bool enable_wp, atomic_t *mmap_changing);
+extern void uffd_wp_range(struct mm_struct *dst_mm, struct vm_area_struct *vma,
+			  unsigned long start, unsigned long len, bool enable_wp);
 
 /* mm helpers */
 static inline bool is_mergeable_vm_userfaultfd_ctx(struct vm_area_struct *vma,
diff --git a/mm/userfaultfd.c b/mm/userfaultfd.c
index 07d3befc80e4134dd6b54be127197ffaad5c10e2..7327b2573f7c2f83c0475109fe7b6e6479b270a5 100644
--- a/mm/userfaultfd.c
+++ b/mm/userfaultfd.c
@@ -703,14 +703,29 @@ ssize_t mcopy_continue(struct mm_struct *dst_mm, unsigned long start,
 			      mmap_changing, 0);
 }
 
+void uffd_wp_range(struct mm_struct *dst_mm, struct vm_area_struct *dst_vma,
+		   unsigned long start, unsigned long len, bool enable_wp)
+{
+	struct mmu_gather tlb;
+	pgprot_t newprot;
+
+	if (enable_wp)
+		newprot = vm_get_page_prot(dst_vma->vm_flags & ~(VM_WRITE));
+	else
+		newprot = vm_get_page_prot(dst_vma->vm_flags);
+
+	tlb_gather_mmu(&tlb, dst_mm);
+	change_protection(&tlb, dst_vma, start, start + len, newprot,
+			  enable_wp ? MM_CP_UFFD_WP : MM_CP_UFFD_WP_RESOLVE);
+	tlb_finish_mmu(&tlb);
+}
+
 int mwriteprotect_range(struct mm_struct *dst_mm, unsigned long start,
 			unsigned long len, bool enable_wp,
 			atomic_t *mmap_changing)
 {
 	struct vm_area_struct *dst_vma;
 	unsigned long page_mask;
-	struct mmu_gather tlb;
-	pgprot_t newprot;
 	int err;
 
 	/*
@@ -750,15 +765,7 @@ int mwriteprotect_range(struct mm_struct *dst_mm, unsigned long start,
 			goto out_unlock;
 	}
 
-	if (enable_wp)
-		newprot = vm_get_page_prot(dst_vma->vm_flags & ~(VM_WRITE));
-	else
-		newprot = vm_get_page_prot(dst_vma->vm_flags);
-
-	tlb_gather_mmu(&tlb, dst_mm);
-	change_protection(&tlb, dst_vma, start, start + len, newprot,
-			  enable_wp ? MM_CP_UFFD_WP : MM_CP_UFFD_WP_RESOLVE);
-	tlb_finish_mmu(&tlb);
+	uffd_wp_range(dst_mm, dst_vma, start, len, enable_wp);
 
 	err = 0;
 out_unlock: