Skip to content

Commit

Permalink
Add comments and fix dev version comparison
Browse files Browse the repository at this point in the history
  • Loading branch information
austinschneider committed Jan 24, 2024
1 parent ab24bea commit b2e193b
Showing 1 changed file with 72 additions and 2 deletions.
74 changes: 72 additions & 2 deletions python/_util.py
Original file line number Diff line number Diff line change
Expand Up @@ -222,6 +222,7 @@ def has_module(module_name):
)

def decompose_version(version):
# Break the version string into its components
matches = _version_regex.match(version)
if matches is None:
return dict()
Expand All @@ -230,63 +231,93 @@ def decompose_version(version):


def normalize_version(version):
    """Return a canonical spelling of a version string.

    The string is split into named components by ``decompose_version``
    (epoch, release, pre, post, dev, local) and re-assembled so that
    equivalent spellings produce the same text: numeric parts lose their
    leading zeros, pre-release labels collapse to ``a`` / ``b`` / ``rc``,
    and the local segment is lower-cased and dot-separated.

    Args:
        version: The version string to normalize.

    Returns:
        The normalized version string.

    Raises:
        KeyError: If ``decompose_version`` could not match ``version``; it
            returns an empty dict in that case, so the component lookups
            below fail.
    """
    # Every accepted spelling of a pre-release label and its canonical form.
    pre_labels = {
        "a": "a",
        "alpha": "a",
        "b": "b",
        "beta": "b",
        "c": "rc",
        "rc": "rc",
        "pre": "rc",
        "preview": "rc",
    }

    d = decompose_version(version)
    parts = []

    # Epoch, e.g. "01!" -> "1!"
    if d["epoch"] is not None:
        parts.append(str(int(d["epoch"])) + "!")

    # Release segment: int() drops the leading zeros of each number
    if d["release"] is not None:
        parts.append(
            ".".join(str(int(s)) for s in d["release"].strip(". ").split("."))
        )

    # Pre-release segment
    if d["pre"] is not None:
        if d["pre_l"] is not None:
            # Compare case-insensitively so that e.g. "RC" is not silently
            # dropped; labels outside the table contribute nothing.
            parts.append(pre_labels.get(d["pre_l"].lower(), ""))
        if d["pre_n"] is not None:
            parts.append(str(int(d["pre_n"])))

    # Post-release segment; the number may sit in either of two groups
    if d["post"] is not None:
        parts.append(".post")
        if d["post_n1"] is not None:
            parts.append(str(int(d["post_n1"])))
        elif d["post_n2"] is not None:
            parts.append(str(int(d["post_n2"])))

    # Dev-release segment
    if d["dev"] is not None:
        parts.append(".dev")
        if d["dev_n"] is not None:
            parts.append(str(int(d["dev_n"])))

    # Local segment: a mix of alphanumeric and purely numeric pieces
    if d["local"] is not None:
        # Treat "-" and "_" as alternative separators before splitting.
        # This also keeps int() from reading "1_0" as the literal 10.
        local = d["local"].lower().replace("-", ".").replace("_", ".")
        segments = []
        for s in local.split("."):
            try:
                # Numeric pieces lose their leading zeros
                segments.append(str(int(s)))
            except ValueError:
                # Non-numeric pieces are kept verbatim
                segments.append(s)
        parts.append("+" + ".".join(segments))

    return "".join(parts)


def tokenize_version(version):
# Tokenize the version string
d = decompose_version(normalize_version(version))
tokens = []
tokens = dict()

if d["epoch"] is not None:
# Add epoch if present
tokens["epoch"] = int(d["epoch"])
else:
# Default epoch is 0
tokens["epoch"] = 0

# Add release segment
if d["release"] is not None:
# Remove leading zeros from each segment
tokens["release"] = tuple(
[int(s) for s in d["release"].strip(". ").lower().split(".")]
)

# Add pre-release segment
if d["pre"] is not None:
# Pre-release segment starts with 0 to indicate that it is present
# Because a pre-release version comes before a release version
pre_token = [0]

if d["pre_l"] is not None:
if d["pre_l"] in ["a", "alpha"]:
pre_token.append(1)
Expand All @@ -302,9 +333,14 @@ def tokenize_version(version):
else:
tokens["pre"] = tuple(pre_token)
else:
# Pre-release segment starts with 1 if not present to indicate that it is not present
# Because a release version comes after a pre-release version
tokens["pre"] = (1,)

# Add post-release segment
if d["post"] is not None:
# If post-release segment is present, it starts with 1
# Because a post-release version comes after its corresponding (pre-)release version
post_token = [1]
if d["post_n1"] is not None:
post_token.append(int(d["post_n1"]))
Expand All @@ -314,19 +350,29 @@ def tokenize_version(version):
post_token.append(0)
tokens["post"] = tuple(post_token)
else:
# If post-release segment is not present, it starts with 0
# Because a post-release version comes after its corresponding (pre-)release version
tokens["post"] = (0,)

# Add dev-release segment
if d["dev"] is not None:
dev_token = [1]
# If dev-release segment is present, it starts with 0
# Because a dev-release version comes before its corresponding (pre-,post-)release version
dev_token = [0]
if d["dev_n"] is not None:
dev_token.append(int(d["dev_n"]))
else:
dev_token.append(0)
tokens["dev"] = tuple(dev_token)
else:
tokens["dev"] = (0,)
# If dev-release segment is not present, it starts with 1
# Because a dev-release version comes before its corresponding (pre-,post-)release version
tokens["dev"] = (1,)

# Add local segment
if d["local"] is not None:
# If local segment is present, it starts with 1
# Because a local version comes after its corresponding (pre-,post-,dev-)release version
local_token = [1]
for s in d["local"].lower().split("."):
try:
Expand All @@ -335,8 +381,11 @@ def tokenize_version(version):
local_token.append((0, s))
tokens["local"] = tuple(local_token)
else:
# If local segment is not present, it starts with 0
# Because a local version comes after its corresponding (pre-,post-,dev-)release version
tokens["local"] = (0,)

# Return the tokenized version
token_list = [
tokens["epoch"],
tokens["release"],
Expand All @@ -350,23 +399,30 @@ def tokenize_version(version):


def _get_model_path(model_name, prefix=None, suffix=None, is_file=True, must_exist=True):
# Get the path to the model file
_model_regex = re.compile(
r"^\s*" + _MODEL_PATTERN + ("" if suffix is None else r"(?:" + suffix + r")?") + r"\s*$",
re.VERBOSE | re.IGNORECASE,
)
if suffix is None:
suffix = ""
# Get the path to the resources directory
resources_dir = resource_package_dir()
base_dir = resources_dir

# Add prefix if present
if prefix is not None:
base_dir = os.path.join(base_dir, prefix)

# Get the model name and version
d = _model_regex.match(model_name)
if d is None:
raise ValueError("Invalid model name: {}".format(model_name))
d = d.groupdict()
model_name = d["model_name"]
version = d["version"]

# Search for the model folder in the resources directory
model_names = [
f for f in os.listdir(base_dir) if not os.path.isfile(os.path.join(base_dir, f))
]
Expand All @@ -375,22 +431,27 @@ def _get_model_path(model_name, prefix=None, suffix=None, is_file=True, must_exi
folder_exists = False

if len(model_names) == 0 and must_exist:
# Whoops, we didn't find the model folder!
raise ValueError(
"No model folders found for {}\nSearched in ".format(model_name, base_dir)
)
elif len(model_names) == 0 and not must_exist:
# Let's use the provided model name as the folder name
model_name = model_name
elif len(model_names) == 1:
# We found the model folder!
folder_exists = True
model_name = model_names[0]
else:
# Multiple model folders found, we cannot decide which one to use
raise ValueError(
"Multiple directories found for {}\nSearched in ".format(
model_name, base_dir
)
)

if folder_exists:
# Search for the model file in the model folder
model_files = [
f
for f in os.listdir(os.path.join(base_dir, model_name))
Expand All @@ -399,6 +460,7 @@ def _get_model_path(model_name, prefix=None, suffix=None, is_file=True, must_exi
else:
model_files = []

# From the found model files, extract the model versions
model_versions = []
for f in model_files:
d = _model_regex.match(f)
Expand All @@ -418,6 +480,7 @@ def _get_model_path(model_name, prefix=None, suffix=None, is_file=True, must_exi
)
))

# Raise an error if no model file is found and we require it to exist
if len(model_versions) == 0 and must_exist:
raise ValueError(
"No model found for {}\nSearched in ".format(
Expand All @@ -426,16 +489,20 @@ def _get_model_path(model_name, prefix=None, suffix=None, is_file=True, must_exi
)

if version is None and must_exist:
# If no version is provided, use the latest version
version_idx, version = max(
enumerate(model_versions), key=lambda x: tokenize_version(x[1])
)
model_file_name = model_files[version_idx]
elif version is None and not must_exist:
# If no version is provided and we don't require it to exist, default to v1
version = "v1"
model_file_name = "{}-v{}{}".format(model_name, version, suffix)
else:
# A version is provided
version = normalize_version(version)
if must_exist:
# If the version must exist, raise an error if it doesn't
if version not in model_versions:
raise ValueError(
"No model found for {}-{}\nSearched in ".format(
Expand All @@ -445,10 +512,13 @@ def _get_model_path(model_name, prefix=None, suffix=None, is_file=True, must_exi
version_idx = model_versions.index(version)
model_file_name = model_files[version_idx]
else:
# The version doesn't have to exist
if version in model_versions:
# If the version exists, use it
version_idx = model_versions.index(version)
model_file_name = model_files[version_idx]
else:
# Otherwise use the provided version
model_file_name = "{}-v{}{}".format(model_name, version, suffix)

return os.path.join(base_dir, model_name, model_file_name)
Expand Down

0 comments on commit b2e193b

Please sign in to comment.