diff --git a/drivers/platform/mpam/mpam_devices.c b/drivers/platform/mpam/mpam_devices.c
index 433a3e34b92bdfca6669a4460101112a803206f8..31edaeee86ee70189302c553d115833278765230 100644
--- a/drivers/platform/mpam/mpam_devices.c
+++ b/drivers/platform/mpam/mpam_devices.c
@@ -935,8 +935,6 @@ static void __ris_msmon_read(void *arg)
 	struct msmon_mbwu_state *mbwu_state;
 	u32 mon_sel, ctl_val, flt_val, cur_ctl, cur_flt;
 
-	lockdep_assert_held(&msc->lock);
-
 	spin_lock_irqsave(&msc->mon_sel_lock, flags);
 	mon_sel = FIELD_PREP(MSMON_CFG_MON_SEL_MON_SEL, ctx->mon) |
 		  FIELD_PREP(MSMON_CFG_MON_SEL_RIS, ris->ris_idx);
@@ -1015,9 +1013,11 @@ static void __ris_msmon_read(void *arg)
 	*(m->val) += now;
 }
 
+/* This is also called in atomic context via the restrl_pmu driver */
 static int _msmon_read(struct mpam_component *comp, struct mon_read *arg)
 {
-	int err, idx;
+	int err, idx, cpu;
+	struct cpumask *mask;
 	struct mpam_msc *msc;
 	struct mpam_msc_ris *ris;
 
@@ -1026,10 +1026,23 @@ static int _msmon_read(struct mpam_component *comp, struct mon_read *arg)
 		arg->ris = ris;
 
 		msc = ris->msc;
-		mutex_lock(&msc->lock);
-		err = smp_call_function_any(&msc->accessibility,
-					    __ris_msmon_read, arg, true);
-		mutex_unlock(&msc->lock);
+		mask = &msc->accessibility;
+
+		/*
+		 * Fail the access if we need to cross call to reach this MSC
+		 * and irqs are masked. The PMU driver calls this with irqs
+		 * masked, but it also specifies where the callback should run.
+		 */
+		err = -EIO;
+		cpu = get_cpu();
+		if (cpumask_test_cpu(cpu, mask)) {
+			__ris_msmon_read(arg);
+			err = 0;
+		} else if (!irqs_disabled())
+			err = smp_call_function_any(mask, __ris_msmon_read,
+						    arg, true);
+		put_cpu();
+
 		if (!err && arg->err)
 			err = arg->err;
 		if (err)
@@ -1048,8 +1061,6 @@ int mpam_msmon_read(struct mpam_component *comp, struct mon_cfg *ctx,
 	u64 wait_jiffies = 0;
 	struct mpam_props *cprops = &comp->class->props;
 
-	might_sleep();
-
 	if (!mpam_is_enabled())
 		return -EIO;
 
@@ -1063,7 +1074,7 @@ int mpam_msmon_read(struct mpam_component *comp, struct mon_cfg *ctx,
 	*val = 0;
 
 	err = _msmon_read(comp, &arg);
-	if (err == -EBUSY)
+	if (err == -EBUSY && !irqs_disabled())
 		wait_jiffies = usecs_to_jiffies(comp->class->nrdy_usec);
 
 	while (wait_jiffies)