diff --git a/drivers/iommu/riscv/iommu-pci.c b/drivers/iommu/riscv/iommu-pci.c index c7a89143014c5..d82d2b00904c7 100644 --- a/drivers/iommu/riscv/iommu-pci.c +++ b/drivers/iommu/riscv/iommu-pci.c @@ -101,6 +101,13 @@ static void riscv_iommu_pci_remove(struct pci_dev *pdev) riscv_iommu_remove(iommu); } +static void riscv_iommu_pci_shutdown(struct pci_dev *pdev) +{ + struct riscv_iommu_device *iommu = dev_get_drvdata(&pdev->dev); + + riscv_iommu_disable(iommu); +} + static const struct pci_device_id riscv_iommu_pci_tbl[] = { {PCI_VDEVICE(REDHAT, PCI_DEVICE_ID_REDHAT_RISCV_IOMMU), 0}, {PCI_VDEVICE(RIVOS, PCI_DEVICE_ID_RIVOS_RISCV_IOMMU_GA), 0}, @@ -112,6 +119,7 @@ static struct pci_driver riscv_iommu_pci_driver = { .id_table = riscv_iommu_pci_tbl, .probe = riscv_iommu_pci_probe, .remove = riscv_iommu_pci_remove, + .shutdown = riscv_iommu_pci_shutdown, .driver = { .suppress_bind_attrs = true, }, diff --git a/drivers/iommu/riscv/iommu-platform.c b/drivers/iommu/riscv/iommu-platform.c index da336863f152f..c30b7271b6153 100644 --- a/drivers/iommu/riscv/iommu-platform.c +++ b/drivers/iommu/riscv/iommu-platform.c @@ -11,18 +11,43 @@ */ #include +#include +#include #include #include #include "iommu-bits.h" #include "iommu.h" +static void riscv_iommu_write_msi_msg(struct msi_desc *desc, struct msi_msg *msg) +{ + struct device *dev = msi_desc_to_dev(desc); + struct riscv_iommu_device *iommu = dev_get_drvdata(dev); + u16 idx = desc->msi_index; + u64 addr; + + addr = ((u64)msg->address_hi << 32) | msg->address_lo; + + if (addr != (addr & RISCV_IOMMU_MSI_CFG_TBL_ADDR)) { + dev_err_once(dev, + "uh oh, the IOMMU can't send MSIs to 0x%llx, sending to 0x%llx instead\n", + addr, addr & RISCV_IOMMU_MSI_CFG_TBL_ADDR); + } + + addr &= RISCV_IOMMU_MSI_CFG_TBL_ADDR; + + riscv_iommu_writeq(iommu, RISCV_IOMMU_REG_MSI_CFG_TBL_ADDR(idx), addr); + riscv_iommu_writel(iommu, RISCV_IOMMU_REG_MSI_CFG_TBL_DATA(idx), msg->data); + riscv_iommu_writel(iommu, RISCV_IOMMU_REG_MSI_CFG_TBL_CTRL(idx), 0); +} + static int riscv_iommu_platform_probe(struct platform_device *pdev) { + enum riscv_iommu_igs_settings igs; struct device *dev = &pdev->dev; struct riscv_iommu_device *iommu = NULL; struct resource *res = NULL; - int vec; + int vec, ret; iommu = devm_kzalloc(dev, sizeof(*iommu), GFP_KERNEL); if (!iommu) @@ -40,16 +65,6 @@ static int riscv_iommu_platform_probe(struct platform_device *pdev) iommu->caps = riscv_iommu_readq(iommu, RISCV_IOMMU_REG_CAPABILITIES); iommu->fctl = riscv_iommu_readl(iommu, RISCV_IOMMU_REG_FCTL); - /* For now we only support WSI */ - switch (FIELD_GET(RISCV_IOMMU_CAPABILITIES_IGS, iommu->caps)) { - case RISCV_IOMMU_CAPABILITIES_IGS_WSI: - case RISCV_IOMMU_CAPABILITIES_IGS_BOTH: - break; - default: - return dev_err_probe(dev, -ENODEV, - "unable to use wire-signaled interrupts\n"); - } - iommu->irqs_count = platform_irq_count(pdev); if (iommu->irqs_count <= 0) return dev_err_probe(dev, -ENODEV, @@ -57,13 +72,58 @@ static int riscv_iommu_platform_probe(struct platform_device *pdev) if (iommu->irqs_count > RISCV_IOMMU_INTR_COUNT) iommu->irqs_count = RISCV_IOMMU_INTR_COUNT; - for (vec = 0; vec < iommu->irqs_count; vec++) - iommu->irqs[vec] = platform_get_irq(pdev, vec); + igs = FIELD_GET(RISCV_IOMMU_CAPABILITIES_IGS, iommu->caps); + switch (igs) { + case RISCV_IOMMU_CAPABILITIES_IGS_BOTH: + case RISCV_IOMMU_CAPABILITIES_IGS_MSI: + if (is_of_node(dev->fwnode)) + of_msi_configure(dev, to_of_node(dev->fwnode)); + + if (!dev_get_msi_domain(dev)) { + dev_warn(dev, "failed to find an MSI domain\n"); + goto msi_fail; + } + + ret = platform_device_msi_init_and_alloc_irqs(dev, iommu->irqs_count, + riscv_iommu_write_msi_msg); + if (ret) { + dev_warn(dev, "failed to allocate MSIs\n"); + goto msi_fail; + } + + for (vec = 0; vec < iommu->irqs_count; vec++) + iommu->irqs[vec] = msi_get_virq(dev, vec); + + /* Enable message-signaled interrupts, fctl.WSI */ + if (iommu->fctl & RISCV_IOMMU_FCTL_WSI) { + iommu->fctl ^= RISCV_IOMMU_FCTL_WSI; + riscv_iommu_writel(iommu, RISCV_IOMMU_REG_FCTL, iommu->fctl); + } + + dev_info(dev, "using MSIs\n"); + break; + +msi_fail: + if (igs != RISCV_IOMMU_CAPABILITIES_IGS_BOTH) { + return dev_err_probe(dev, -ENODEV, + "unable to use wire-signaled interrupts\n"); + } - /* Enable wire-signaled interrupts, fctl.WSI */ - if (!(iommu->fctl & RISCV_IOMMU_FCTL_WSI)) { - iommu->fctl |= RISCV_IOMMU_FCTL_WSI; - riscv_iommu_writel(iommu, RISCV_IOMMU_REG_FCTL, iommu->fctl); + fallthrough; + + case RISCV_IOMMU_CAPABILITIES_IGS_WSI: + for (vec = 0; vec < iommu->irqs_count; vec++) + iommu->irqs[vec] = platform_get_irq(pdev, vec); + + /* Enable wire-signaled interrupts, fctl.WSI */ + if (!(iommu->fctl & RISCV_IOMMU_FCTL_WSI)) { + iommu->fctl |= RISCV_IOMMU_FCTL_WSI; + riscv_iommu_writel(iommu, RISCV_IOMMU_REG_FCTL, iommu->fctl); + } + dev_info(dev, "using wire-signaled interrupts\n"); + break; + default: + return dev_err_probe(dev, -ENODEV, "invalid IGS\n"); } return riscv_iommu_init(iommu); @@ -71,7 +131,18 @@ static int riscv_iommu_platform_probe(struct platform_device *pdev) static void riscv_iommu_platform_remove(struct platform_device *pdev) { - riscv_iommu_remove(dev_get_drvdata(&pdev->dev)); + struct riscv_iommu_device *iommu = dev_get_drvdata(&pdev->dev); + bool msi = !(iommu->fctl & RISCV_IOMMU_FCTL_WSI); + + riscv_iommu_remove(iommu); + + if (msi) + platform_device_msi_free_irqs_all(&pdev->dev); +}; + +static void riscv_iommu_platform_shutdown(struct platform_device *pdev) +{ + riscv_iommu_disable(dev_get_drvdata(&pdev->dev)); }; static const struct of_device_id riscv_iommu_of_match[] = { @@ -82,6 +153,7 @@ static const struct of_device_id riscv_iommu_of_match[] = { static struct platform_driver riscv_iommu_platform_driver = { .probe = riscv_iommu_platform_probe, .remove_new = riscv_iommu_platform_remove, + .shutdown = riscv_iommu_platform_shutdown, .driver = { .name = "riscv,iommu", .of_match_table = riscv_iommu_of_match, diff --git a/drivers/iommu/riscv/iommu.c b/drivers/iommu/riscv/iommu.c index 8a05def774bdb..d1922371a4672 100644 --- a/drivers/iommu/riscv/iommu.c +++ b/drivers/iommu/riscv/iommu.c @@ -240,6 +240,12 @@ static int riscv_iommu_queue_enable(struct riscv_iommu_device *iommu, return rc; } + /* Empty queue before enabling it */ + if (queue->qid == RISCV_IOMMU_INTR_CQ) + riscv_iommu_writel(queue->iommu, Q_TAIL(queue), 0); + else + riscv_iommu_writel(queue->iommu, Q_HEAD(queue), 0); + /* * Enable queue with interrupts, clear any memory fault if any. * Wait for the hardware to acknowledge request and activate queue @@ -645,9 +651,11 @@ static struct riscv_iommu_dc *riscv_iommu_get_dc(struct riscv_iommu_device *iomm * This is best effort IOMMU translation shutdown flow. * Disable IOMMU without waiting for hardware response. */ -static void riscv_iommu_disable(struct riscv_iommu_device *iommu) +void riscv_iommu_disable(struct riscv_iommu_device *iommu) { - riscv_iommu_writeq(iommu, RISCV_IOMMU_REG_DDTP, 0); + riscv_iommu_writeq(iommu, RISCV_IOMMU_REG_DDTP, + FIELD_PREP(RISCV_IOMMU_DDTP_IOMMU_MODE, + RISCV_IOMMU_DDTP_IOMMU_MODE_BARE)); riscv_iommu_writel(iommu, RISCV_IOMMU_REG_CQCSR, 0); riscv_iommu_writel(iommu, RISCV_IOMMU_REG_FQCSR, 0); riscv_iommu_writel(iommu, RISCV_IOMMU_REG_PQCSR, 0); @@ -1270,11 +1278,11 @@ static phys_addr_t riscv_iommu_iova_to_phys(struct iommu_domain *iommu_domain, dma_addr_t iova) { struct riscv_iommu_domain *domain = iommu_domain_to_riscv(iommu_domain); - unsigned long pte_size; + size_t pte_size; unsigned long *ptr; ptr = riscv_iommu_pte_fetch(domain, iova, &pte_size); - if (_io_pte_none(*ptr) || !_io_pte_present(*ptr)) + if (!ptr) return 0; return pfn_to_phys(__page_val_to_pfn(*ptr)) | (iova & (pte_size - 1)); diff --git a/drivers/iommu/riscv/iommu.h b/drivers/iommu/riscv/iommu.h index b1c4664542b48..46df79dd54957 100644 --- a/drivers/iommu/riscv/iommu.h +++ b/drivers/iommu/riscv/iommu.h @@ -64,6 +64,7 @@ struct riscv_iommu_device { int riscv_iommu_init(struct riscv_iommu_device *iommu); void riscv_iommu_remove(struct riscv_iommu_device *iommu); +void riscv_iommu_disable(struct riscv_iommu_device *iommu); #define riscv_iommu_readl(iommu, addr) \ readl_relaxed((iommu)->reg + (addr))