diff --git a/sriov/common/utils.py b/sriov/common/utils.py index 9bbd2b3..bade4c4 100644 --- a/sriov/common/utils.py +++ b/sriov/common/utils.py @@ -1112,3 +1112,29 @@ def get_nic_model(ssh_obj: ShellHandler, pf: str) -> str: cmd = [f"lshw -C network -businfo | grep {pf}"] outs, _ = execute_and_assert(ssh_obj, cmd, 0) return re.split("\\s{2,}", outs[0][0])[-1].strip() + + +def get_driver_pci(ssh_obj: ShellHandler, pci: str) -> str: + """Get the base driver (i40e or ice) by the PCI address + + Args: + ssh_obj (ShellHandler): ssh connection obj + pci (str): pci address + + Returns: + str: The driver this pci address is bound to + """ + drivers = ["i40e", "ice"] + modprobe_cmds = [] + cmds = [] + for driver in drivers: + modprobe_cmds.append(f"modprobe {driver}") + cmds.append(f"find /sys/bus/pci/drivers/{driver}/ -name {pci}") + outs, _ = execute_and_assert(ssh_obj, modprobe_cmds, 0) + outs, _ = execute_and_assert(ssh_obj, cmds, 0) + if outs[0]: + return drivers[0] + elif outs[1]: + return drivers[1] + else: + assert Exception("Driver not in list: ", drivers) diff --git a/sriov/tests/conftest.py b/sriov/tests/conftest.py index 1ab600f..8a00cad 100644 --- a/sriov/tests/conftest.py +++ b/sriov/tests/conftest.py @@ -17,6 +17,7 @@ destroy_vfs, bind_driver, execute_and_assert, + get_driver_pci, ) @@ -54,15 +55,27 @@ def settings(dut, trafficgen) -> Config: settings = get_settings_obj() pf1_name = settings.config["dut"]["interface"]["pf1"]["name"] settings.config["dut"]["interface"]["pf1"]["pci"] = get_pci_address(dut, pf1_name) + settings.config["dut"]["interface"]["pf1"]["driver"] = get_driver_pci( + dut, settings.config["dut"]["interface"]["pf1"]["pci"] + ) settings.config["dut"]["interface"]["pf2"]["pci"] = get_pci_address( dut, settings.config["dut"]["interface"]["pf2"]["name"] ) + settings.config["dut"]["interface"]["pf2"]["driver"] = get_driver_pci( + dut, settings.config["dut"]["interface"]["pf2"]["pci"] + ) settings.config["trafficgen"]["interface"]["pf1"]["pci"] = get_pci_address( trafficgen, settings.config["trafficgen"]["interface"]["pf1"]["name"] ) + settings.config["trafficgen"]["interface"]["pf1"]["driver"] = get_driver_pci( + trafficgen, settings.config["trafficgen"]["interface"]["pf1"]["pci"] + ) settings.config["trafficgen"]["interface"]["pf2"]["pci"] = get_pci_address( trafficgen, settings.config["trafficgen"]["interface"]["pf2"]["name"] ) + settings.config["trafficgen"]["interface"]["pf2"]["driver"] = get_driver_pci( + trafficgen, settings.config["trafficgen"]["interface"]["pf2"]["pci"] + ) create_vfs(dut, pf1_name, 1) vf1_name = settings.config["dut"]["interface"]["vf1"]["name"] settings.config["dut"]["interface"]["vf1"]["pci"] = get_pci_address(dut, vf1_name) @@ -161,17 +174,17 @@ def _cleanup( ] execute_and_assert(trafficgen, kill_trafficgen, 0) - trafficgen_pfs_pci = [] + trafficgen_pfs_pci = {} if "pf1" in settings.config["trafficgen"]["interface"]: - trafficgen_pfs_pci.append( + trafficgen_pfs_pci[ settings.config["trafficgen"]["interface"]["pf1"]["pci"] - ) + ] = settings.config["trafficgen"]["interface"]["pf1"]["driver"] if "pf2" in settings.config["trafficgen"]["interface"]: - trafficgen_pfs_pci.append( + trafficgen_pfs_pci[ settings.config["trafficgen"]["interface"]["pf2"]["pci"] - ) + ] = settings.config["trafficgen"]["interface"]["pf2"]["driver"] for pf in trafficgen_pfs_pci: - assert bind_driver(trafficgen, pf, "i40e") + assert bind_driver(trafficgen, pf, trafficgen_pfs_pci[pf]) reset_command(dut, testdata)