Skip to content

Commit

Permalink
fix vfiohotplug not liking non-dec addrs
Browse files Browse the repository at this point in the history
  • Loading branch information
ifd3f committed Dec 13, 2024
1 parent 657a845 commit 53ac4c1
Showing 1 changed file with 29 additions and 18 deletions.
47 changes: 29 additions & 18 deletions nix/nixos-modules/astral/vfio/vfiohotplug.py
Original file line number Diff line number Diff line change
Expand Up @@ -111,14 +111,20 @@ def manage_device(action: str, device: Device, config: Config):
<product id="0x{device.product}" />
</source>
</hostdev>"""
manage_device_raw_xml(action, xml, config)
case "pci":
path = find_pci_path(device.vendor, device.product)
xml = f"""
<hostdev mode="subsystem" type="pci" managed="yes">
<source>
<address domain="0x{path.domain}" bus="0x{path.bus}" slot="0x{path.slot}" function="0x{path.function}"/>
</source>
</hostdev>"""
for path in find_pci_paths(device.vendor, device.product):
xml = f"""
<hostdev mode="subsystem" type="pci" managed="yes">
<source>
<address domain="0x{path.domain}" bus="0x{path.bus}" slot="0x{path.slot}" function="0x{path.function}"/>
</source>
</hostdev>"""
manage_device_raw_xml(action, xml, config)


def manage_device_raw_xml(action: str, xml: str, config: Config):
with tempfile.NamedTemporaryFile("w") as f:
logger.debug("Writing XML file %r: %r", f.name, xml)

f.write(xml)
Expand Down Expand Up @@ -176,21 +182,26 @@ def parse_device(d: dict) -> Device:
return Device(dt, vendor, product)


def find_pci_path(vendor: str, product: str) -> PCIPath:
def find_pci_paths(vendor: str, product: str) -> t.List[PCIPath]:
vpid = f"{vendor}:{product}"
logger.debug("querying path of device %s", vpid)
output = (
subprocess.check_output(["lspci", "-d", vpid, "-D"])
.decode()
.strip()
.splitlines()
)
cmd = ["lspci", "-d", vpid, "-D"]
logger.debug("querying path of device %s with command %r", vpid, cmd)
output = subprocess.check_output(cmd).decode().strip()
logger.debug("output: %r", output)

results = [
PCIPath(*re.match(r"(\d+):(\d+):(\d+).(\d+).*", l).groups()) for l in output
PCIPath(*re.match(r"(\w+):(\w+):(\w+).(\w+).*", l).groups())
for l in output.splitlines()
if l.strip()
]
logger.info("associated pci:%s to paths: %r", vpid, results)
return results[0]

if not results:
raise EnvironmentError(
f"Could not find PCI path for {vpid}! lspci output: {output}"
)

logger.debug("associated pci:%s to paths: %r", vpid, results)
return results


if __name__ == "__main__":
Expand Down

0 comments on commit 53ac4c1

Please sign in to comment.