Skip to content

Commit

Permalink
[Core] Fix wheel creation when multiple versions are installed (#3866)
Browse files Browse the repository at this point in the history
* Fix wheel creation

* Fix wheel creation

* Refactor _get_latest_wheel_and_remove_all_others

* lint
  • Loading branch information
romilbhardwaj authored Aug 25, 2024
1 parent 4e7478b commit 6067e2b
Showing 1 changed file with 29 additions and 18 deletions.
47 changes: 29 additions & 18 deletions sky/backends/wheel_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,29 +39,30 @@
f'{version.parse(sky.__version__)}-*.whl')


def _get_latest_wheel_and_remove_all_others() -> pathlib.Path:
wheel_name = (f'**/{_WHEEL_PATTERN}')
def _remove_stale_wheels(latest_wheel_dir: pathlib.Path) -> None:
"""Remove all wheels except the latest one."""
for f in WHEEL_DIR.iterdir():
if f != latest_wheel_dir:
if f.is_dir() and not f.is_symlink():
shutil.rmtree(f, ignore_errors=True)


def _get_latest_wheel() -> pathlib.Path:
wheel_name = f'**/{_WHEEL_PATTERN}'
try:
latest_wheel = max(WHEEL_DIR.glob(wheel_name), key=os.path.getctime)
except ValueError:
raise FileNotFoundError(
'Could not find built SkyPilot wheels with glob pattern '
f'{wheel_name} under {WHEEL_DIR!r}') from None

latest_wheel_dir_name = latest_wheel.parent
# Cleanup older wheels.
for f in WHEEL_DIR.iterdir():
if f != latest_wheel_dir_name:
if f.is_dir() and not f.is_symlink():
shutil.rmtree(f, ignore_errors=True)
return latest_wheel


def _build_sky_wheel():
"""Build a wheel for SkyPilot."""
with tempfile.TemporaryDirectory() as tmp_dir:
def _build_sky_wheel() -> pathlib.Path:
"""Build a wheel for SkyPilot and return the path to the wheel."""
with tempfile.TemporaryDirectory() as tmp_dir_str:
# prepare files
tmp_dir = pathlib.Path(tmp_dir)
tmp_dir = pathlib.Path(tmp_dir_str)
sky_tmp_dir = tmp_dir / 'sky'
sky_tmp_dir.mkdir()
for item in SKY_PACKAGE_PATH.iterdir():
Expand Down Expand Up @@ -129,6 +130,7 @@ def _build_sky_wheel():
wheel_dir = WHEEL_DIR / hash_of_latest_wheel
wheel_dir.mkdir(parents=True, exist_ok=True)
shutil.move(str(wheel_path), wheel_dir)
return wheel_dir / wheel_path.name


def build_sky_wheel() -> Tuple[pathlib.Path, str]:
Expand Down Expand Up @@ -161,13 +163,22 @@ def _get_latest_modification_time(path: pathlib.Path) -> float:
last_modification_time = _get_latest_modification_time(SKY_PACKAGE_PATH)
last_wheel_modification_time = _get_latest_modification_time(WHEEL_DIR)

# only build wheels if the wheel is outdated
if last_wheel_modification_time < last_modification_time:
# Only build wheels if the wheel is outdated or wheel does not exist
# for the requested version.
if (last_wheel_modification_time < last_modification_time) or not any(
WHEEL_DIR.glob(f'**/{_WHEEL_PATTERN}')):
if not WHEEL_DIR.exists():
WHEEL_DIR.mkdir(parents=True, exist_ok=True)
_build_sky_wheel()

latest_wheel = _get_latest_wheel_and_remove_all_others()
latest_wheel = _build_sky_wheel()
else:
latest_wheel = _get_latest_wheel()

# We remove all wheels except the latest one for garbage collection.
# Otherwise stale wheels will accumulate over time.
# TODO(romilb): If the user switches versions every alternate launch,
# the wheel will be rebuilt every time. At the risk of adding
# complexity, we can consider TTL caching wheels by version here.
_remove_stale_wheels(latest_wheel.parent)

wheel_hash = latest_wheel.parent.name

Expand Down

0 comments on commit 6067e2b

Please sign in to comment.