From f4dff048d2c0c99aacfc1487b4ecd75ddaea0e95 Mon Sep 17 00:00:00 2001 From: Blue Date: Mon, 25 Nov 2024 16:02:03 -0800 Subject: [PATCH] Improve symlink logic --- distributions/validate-modern.py | 94 +++++++++++++++++++------------- 1 file changed, 57 insertions(+), 37 deletions(-) diff --git a/distributions/validate-modern.py b/distributions/validate-modern.py index 4ddd8aeb..b43eedb1 100644 --- a/distributions/validate-modern.py +++ b/distributions/validate-modern.py @@ -233,36 +233,59 @@ def list_directory(path: str): return units + +def get_tar_file(tar, path: str, follow_symlink=False): + + # Tar members can be formated as /{path}, {path}, or ./{path} + if path.startswith('/'): + paths = [path, '.' + path, path[1:]] + elif path.startswith('./'): + paths = [path, path[1:], path[2:]] + else: + paths = [path, './' + path, '/' + path] + + def follow_if_symlink(info, path: str): + if follow_symlink and info.issym(): + if info.linkpath.startswith('/'): + return get_tar_file(tar, info.linkpath, follow_symlink=True) + else: + return get_tar_file(tar, f'{os.path.dirname(path)}/{info.linkpath}', follow_symlink=True) + else: + return info, path + + # First try accessing the file directly + for e in paths: + try: + return follow_if_symlink(tar.getmember(e), e) + except KeyError: + continue + + if not follow_symlink: + return None, None + + # Then look for symlinks + # The path might be covered by a symlink, check if parent exists and is a symlink + parent_path = os.path.dirname(path) + if parent_path != path: + try: + parent_info, real_parent_path = get_tar_file(tar, parent_path, follow_symlink=True) + if real_parent_path != parent_path: + return get_tar_file(tar, f'{real_parent_path}/{os.path.basename(path)}', follow_symlink=True) + except KeyError: + pass + + return None, None + def read_tar(flavor: str, name: str, file, elf_magic: str): with tarfile.open(fileobj=file) as tar: def validate_mode(path: str, mode, uid, gid, max_size = None, optional = False, follow_symlink = False, magic = None, parse_method = None): - try: - info = tar.getmember(path) - except KeyError: - try: - path = '.' + path - info = tar.getmember(path) - except KeyError: - # The path might be covered by a symlink, check if parent exists and is a symlink - parent_path = os.path.dirname(path) - if parent_path != path: - try: - parent_info = tar.getmember(parent_path) - if parent_info.issym(): - return validate_mode(f'/{parent_info.linkpath}/{os.path.basename(path)}', mode, uid, gid, max_size, optional, True, magic) - except KeyError: - pass - - if not optional: - error(flavor, name, f'File "{path}" not found in tar') - return False - - if follow_symlink and info.issym(): - if info.linkpath.startswith('/'): - return validate_mode(info.linkpath, mode, uid, gid, max_size, optional, True, magic, parse_method) - else: - return validate_mode(f'{os.path.dirname(path)}/{info.linkpath}', mode, uid, gid, max_size, optional, True, magic, parse_method) + info, real_path = get_tar_file(tar, path, follow_symlink) + + if info is None: + if not optional: + error(flavor, name, f'File "{path}" not found in tar') + return False permissions = oct(info.mode) if permissions not in mode: @@ -278,7 +301,7 @@ def validate_mode(path: str, mode, uid, gid, max_size = None, optional = False, error(flavor, name, f'file: "{path}" is too big (info.size), max: {max_size}') if magic is not None or parse_method is not None: - content = tar.extractfile(path) + content = tar.extractfile(real_path) if parse_method is not None: parse_method(content) @@ -293,15 +316,12 @@ def validate_mode(path: str, mode, uid, gid, max_size = None, optional = False, return True def validate_config(path: str, valid_keys: list): - try: - content = tar.extractfile(path) - except KeyError: - try: - content = tar.extractfile('.' + path) - except KeyError: - error(flavor, name, f'File "{file}" not found in tar') - return None + _, path = get_tar_file(tar, path, follow_symlink=True) + if path is None: + error(flavor, name, f'File "{file}" not found in tar') + return None + content = tar.extractfile(path) config = configparser.ConfigParser() config.read_string(content.read().decode()) @@ -332,7 +352,7 @@ def validate_config(path: str, valid_keys: list): defaultUid = int(defaultUid) if shortcut_icon := config.get('shortcut.icon', None): - validate_mode(shortcut_icon, [oct(0o660), oct(0o640)], 0, 0, 1024 * 1024) + validate_mode(shortcut_icon, [oct(0o664), oct(0o644)], 0, 0, 1024 * 1024) if not shortcut_icon.startswith(USR_LIB_WSL): warning(flavor, name, f'value for shortcut.icon is not under {USR_LIB_WSL}: "{shortcut_icon}"') @@ -346,7 +366,7 @@ def validate_config(path: str, valid_keys: list): if validate_mode('/etc/wsl.conf', [oct(0o664), oct(0o644)], 0, 0, optional=True): config = validate_config('/etc/wsl.conf', ['boot.systemd']) if config.get('boot.systemd', False): - validate_mode('/sbin/init', [oct(0o775), oct(0o755)], 0, 0, magic=elf_magic) + validate_mode('/sbin/init', [oct(0o775), oct(0o755)], 0, 0, magic=elf_magic, follow_symlink=True) validate_mode('/etc/passwd', [oct(0o664), oct(0o644)], 0, 0, parse_method = lambda fd: read_passwd(flavor, name, defaultUid, fd)) validate_mode('/etc/shadow', [oct(0o640), oct(0o600)], 0, None)