From c6252d4b65bb42a896d3c10259faed5c23682c0e Mon Sep 17 00:00:00 2001 From: Mihai Maruseac Date: Wed, 24 Jul 2024 07:21:20 -0700 Subject: [PATCH] Machinery for signing generic manifests. THIS IS DRAFT, WIP. Will split into separate PRs once it works. But posting publicly to show what the plans are (#224, #248, #240, #111). Signed-off-by: Mihai Maruseac --- model_signing/manifest/manifest.py | 92 +++++++++++++++- model_signing/manifest/manifest_test.py | 100 ++++++++++++++++++ .../serialize_by_file_shard_test.py | 64 +++++------ model_signing/signing/__init__.py | 13 +++ model_signing/signing/signing.py | 80 ++++++++++++++ 5 files changed, 308 insertions(+), 41 deletions(-) create mode 100644 model_signing/signing/__init__.py create mode 100644 model_signing/signing/signing.py diff --git a/model_signing/manifest/manifest.py b/model_signing/manifest/manifest.py index 419f406e..fbc6517b 100644 --- a/model_signing/manifest/manifest.py +++ b/model_signing/manifest/manifest.py @@ -55,15 +55,44 @@ from collections.abc import Iterable import dataclasses import pathlib -from typing import Self +from typing import Iterator, Self +from typing_extensions import override from model_signing.hashing import hashing +@dataclasses.dataclass(frozen=True) +class ResourceDescriptor: + """A description of any content from any `Manifest`. + + We aim this to be similar to in-toto's `ResourceDescriptor`. To support + cases where in-toto cannot be directly used, we make this a dataclass that + can be mapped to in-toto when needed, and used on its own otherwise. + + Not all fields from in-toto are specified at this moment. All fields here + must be present, unlike in-toto, where all are optional. + + See github.com/in-toto/attestation/blob/main/spec/v1/resource_descriptor.md + for the in-toto specification. + + Attributes: + identifier: A string that uniquely identifies this `ResourceDescriptor`. + Corresponds to `name`, `uri`, or `content` in in-toto specification. 
+ digest: One digest for the item. Note that unlike in-toto, we only have + one digest for the item and it is always required. + """ + + identifier: str + digest: hashing.Digest + + class Manifest(metaclass=abc.ABCMeta): """Generic manifest file to represent a model.""" - pass + @abc.abstractmethod + def resource_descriptors(self) -> Iterator[ResourceDescriptor]: + """Yields each resource from the manifest, one by one.""" + pass @dataclasses.dataclass(frozen=True) @@ -72,6 +101,17 @@ class DigestManifest(Manifest): digest: hashing.Digest + @override + def resource_descriptors(self) -> Iterator[ResourceDescriptor]: + """Yields each resource from the manifest, one by one. + + In this case, we have only one descriptor to return. Since model paths + are already encoded in the digest, use "." for the identifier. + Subclasses might record additional fields to have distinguishable human + readable identifiers. + """ + yield ResourceDescriptor(identifier=".", digest=self.digest) + class ItemizedManifest(Manifest): """A detailed manifest, recording integrity of every model component.""" @@ -130,6 +170,37 @@ def __init__(self, items: Iterable[FileManifestItem]): def __eq__(self, other: Self): return self._item_to_digest == other._item_to_digest + @override + def resource_descriptors(self) -> Iterator[ResourceDescriptor]: + """Yields each resource from the manifest, one by one. + + The items are returned in alphabetical order of the path. + """ + for item, digest in sorted(self._item_to_digest.items()): + yield ResourceDescriptor(identifier=str(item), digest=digest) + + +@dataclasses.dataclass(frozen=True, order=True) +class Shard: + """A dataclass to hold information about a file shard. + + Attributes: + path: The path to the file, relative to the model root. + start: The start offset of the shard (included). + end: The end offset of the shard (not included). 
+ """ + + path: pathlib.PurePath + start: int + end: int + + def __str__(self) -> str: + """Converts the item to a canonicalized string representation. + + The format is {path}:{start}:{end}, which should also be easy to decode. + """ + return f"{str(self.path)}:{self.start}:{self.end}" + @dataclasses.dataclass class ShardedFileManifestItem(ManifestItem): @@ -146,7 +217,7 @@ def __init__( path: pathlib.PurePath, start: int, end: int, - digest: hashing.Digest + digest: hashing.Digest, ): """Builds a manifest item pairing a file shard with its digest. @@ -163,9 +234,9 @@ def __init__( self.digest = digest @property - def input_tuple(self) -> tuple[pathlib.PurePath, int, int]: + def input_tuple(self) -> Shard: """Returns the triple that uniquely determines the manifest item.""" - return (self.path, self.start, self.end) + return Shard(self.path, self.start, self.end) class ShardLevelManifest(FileLevelManifest): @@ -178,3 +249,14 @@ def __init__(self, items: Iterable[ShardedFileManifestItem]): efficient updates and retrieval of digests. """ self._item_to_digest = {item.input_tuple: item.digest for item in items} + + @override + def resource_descriptors(self) -> Iterator[ResourceDescriptor]: + """Yields each resource from the manifest, one by one. + + The items are returned in the order given by the `input_tuple` property + of `ShardedFileManifestItem` used to create this instance (the triple of + file name and shard endpoints). + """ + for item, digest in sorted(self._item_to_digest.items()): + yield ResourceDescriptor(identifier=str(item), digest=digest) diff --git a/model_signing/manifest/manifest_test.py b/model_signing/manifest/manifest_test.py index f097c8de..f541cddf 100644 --- a/model_signing/manifest/manifest_test.py +++ b/model_signing/manifest/manifest_test.py @@ -13,11 +13,31 @@ # limitations under the License. 
import pathlib +import pytest from model_signing.hashing import hashing from model_signing.manifest import manifest +class TestDigestManifest: + + def test_manifest_has_just_one_resource_descriptor(self): + digest = hashing.Digest("test", b"test_digest") + manifest_file = manifest.DigestManifest(digest) + + descriptors = list(manifest_file.resource_descriptors()) + + assert len(descriptors) == 1 + + def test_manifest_has_the_correct_resource_descriptor(self): + digest = hashing.Digest("test", b"test_digest") + manifest_file = manifest.DigestManifest(digest) + + for descriptor in manifest_file.resource_descriptors(): + assert descriptor.identifier == "." + assert descriptor.digest == digest + + class TestFileLevelManifest: def test_insert_order_does_not_matter(self): @@ -34,6 +54,39 @@ def test_insert_order_does_not_matter(self): assert manifest1 == manifest2 + @pytest.mark.parametrize("num_items", [1, 3, 5]) + def test_manifest_has_all_resource_descriptors(self, num_items): + items: list[manifest.FileManifestItem] = [] + for i in range(num_items): + path = pathlib.PurePath(f"file{i}") + digest = hashing.Digest("test", b"hash{i}") + item = manifest.FileManifestItem(path=path, digest=digest) + items.append(item) + manifest_file = manifest.FileLevelManifest(items) + + descriptors = list(manifest_file.resource_descriptors()) + + assert len(descriptors) == num_items + + def test_manifest_has_the_correct_resource_descriptors(self): + path1 = pathlib.PurePath(f"file1") + digest1 = hashing.Digest("test", b"hash1") + item1 = manifest.FileManifestItem(path=path1, digest=digest1) + + path2 = pathlib.PurePath(f"file2") + digest2 = hashing.Digest("test", b"hash2") + item2 = manifest.FileManifestItem(path=path2, digest=digest2) + + # Note order is reversed + manifest_file = manifest.FileLevelManifest([item2, item1]) + descriptors = list(manifest_file.resource_descriptors()) + + # But we expect the descriptors to be in order by file + assert descriptors[0].identifier == "file1" + 
assert descriptors[1].identifier == "file2" + assert descriptors[0].digest.digest_value == b"hash1" + assert descriptors[1].digest.digest_value == b"hash2" + class TestShardLevelManifest: @@ -70,3 +123,50 @@ def test_same_path_different_shards_gives_different_manifest(self): manifest2 = manifest.ShardLevelManifest([item]) assert manifest1 != manifest2 + + @pytest.mark.parametrize("num_items", [1, 3, 5]) + def test_manifest_has_all_resource_descriptors(self, num_items): + items: list[manifest.ShardedFileManifestItem] = [] + for i in range(num_items): + path = pathlib.PurePath(f"file") + digest = hashing.Digest("test", b"hash{i}") + item = manifest.ShardedFileManifestItem( + path=path, digest=digest, start=i, end=i + 2 + ) + items.append(item) + manifest_file = manifest.ShardLevelManifest(items) + + descriptors = list(manifest_file.resource_descriptors()) + + assert len(descriptors) == num_items + + def test_manifest_has_the_correct_resource_descriptors(self): + path1 = pathlib.PurePath(f"file1") + digest1 = hashing.Digest("test", b"hash1") + item1 = manifest.ShardedFileManifestItem( + path=path1, digest=digest1, start=0, end=4 + ) + + path2 = pathlib.PurePath(f"file2") + digest2 = hashing.Digest("test", b"hash2") + item2 = manifest.ShardedFileManifestItem( + path=path2, digest=digest2, start=0, end=4 + ) + + # First file, but second shard + digest3 = hashing.Digest("test", b"hash3") + item3 = manifest.ShardedFileManifestItem( + path=path1, digest=digest3, start=4, end=8 + ) + + # Note order is reversed + manifest_file = manifest.ShardLevelManifest([item3, item2, item1]) + descriptors = list(manifest_file.resource_descriptors()) + + # But we expect the descriptors to be in order by file shard + assert descriptors[0].identifier == "file1:0:4" + assert descriptors[1].identifier == "file1:4:8" + assert descriptors[2].identifier == "file2:0:4" + assert descriptors[0].digest.digest_value == b"hash1" + assert descriptors[1].digest.digest_value == b"hash3" + assert 
descriptors[2].digest.digest_value == b"hash2" diff --git a/model_signing/serialization/serialize_by_file_shard_test.py b/model_signing/serialization/serialize_by_file_shard_test.py index 750a9454..8583791c 100644 --- a/model_signing/serialization/serialize_by_file_shard_test.py +++ b/model_signing/serialization/serialize_by_file_shard_test.py @@ -301,18 +301,9 @@ def test_shard_size_changes_digests(self, sample_model_folder): assert manifest1.digest.digest_value != manifest2.digest.digest_value -@dataclasses.dataclass(frozen=True, order=True) -class _Shard: - """A shard of a file from a manifest.""" - - path: str - start: int - end: int - - def _extract_shard_items_from_manifest( manifest: manifest.ShardLevelManifest, -) -> dict[_Shard, str]: +) -> dict[manifest.Shard, str]: """Builds a dictionary representation of the items in a manifest. Every item is mapped to its digest. @@ -320,13 +311,12 @@ def _extract_shard_items_from_manifest( Used in multiple tests to check that we obtained the expected manifest. """ return { - # convert to file path (relative to model) string and endpoints - _Shard(str(shard[0]), shard[1], shard[2]): digest.digest_hex + shard: digest.digest_hex for shard, digest in manifest._item_to_digest.items() } -def _parse_shard_and_digest(line: str) -> tuple[_Shard, str]: +def _parse_shard_and_digest(line: str) -> tuple[manifest.Shard, str]: """Reads a file shard and its digest from a line in the golden file. Args: @@ -336,7 +326,7 @@ def _parse_shard_and_digest(line: str) -> tuple[_Shard, str]: The shard tuple and the digest corresponding to the line that was read. 
""" path, start, end, digest = line.strip().split(":") - shard = _Shard(path, int(start), int(end)) + shard = manifest.Shard(pathlib.PurePosixPath(path), int(start), int(end)) return shard, digest @@ -370,18 +360,16 @@ def test_known_models(self, request, model_fixture_name): serializer = serialize_by_file_shard.ManifestSerializer( self._hasher_factory ) - manifest = serializer.serialize(model) - items = _extract_shard_items_from_manifest(manifest) + manifest_file = serializer.serialize(model) + items = _extract_shard_items_from_manifest(manifest_file) # Compare with golden, or write to golden (approximately "assert") if should_update: with open(golden_path, "w", encoding="utf-8") as f: for shard, digest in sorted(items.items()): - f.write( - f"{shard.path}:{shard.start}:{shard.end}:{digest}\n" - ) + f.write(f"{shard}:{digest}\n") else: - found_items: dict[_Shard, str] = {} + found_items: dict[manifest.Shard, str] = {} with open(golden_path, "r", encoding="utf-8") as f: for line in f: shard, digest = _parse_shard_and_digest(line) @@ -403,18 +391,16 @@ def test_known_models_small_shards(self, request, model_fixture_name): serializer = serialize_by_file_shard.ManifestSerializer( self._hasher_factory_small_shards ) - manifest = serializer.serialize(model) - items = _extract_shard_items_from_manifest(manifest) + manifest_file = serializer.serialize(model) + items = _extract_shard_items_from_manifest(manifest_file) # Compare with golden, or write to golden (approximately "assert") if should_update: with open(golden_path, "w", encoding="utf-8") as f: for shard, digest in sorted(items.items()): - f.write( - f"{shard.path}:{shard.start}:{shard.end}:{digest}\n" - ) + f.write(f"{shard}:{digest}\n") else: - found_items: dict[_Shard, str] = {} + found_items: dict[manifest.Shard, str] = {} with open(golden_path, "r", encoding="utf-8") as f: for line in f: shard, digest = _parse_shard_and_digest(line) @@ -522,9 +508,8 @@ def _check_manifests_match_except_on_renamed_file( 
old_manifest._item_to_digest ) for shard, digest in new_manifest._item_to_digest.items(): - path, start, end = shard - if path.name == new_name: - old_shard = (old_name, start, end) + if shard.path.name == new_name: + old_shard = manifest.Shard(old_name, shard.start, shard.end) assert old_manifest._item_to_digest[old_shard] == digest else: assert old_manifest._item_to_digest[shard] == digest @@ -566,13 +551,14 @@ def _check_manifests_match_except_on_renamed_dir( old_manifest._item_to_digest ) for shard, digest in new_manifest._item_to_digest.items(): - path, start, end = shard - if new_name in path.parts: + if new_name in shard.path.parts: parts = [ old_name if part == new_name else part - for part in path.parts + for part in shard.path.parts ] - old = (pathlib.PurePosixPath(*parts), start, end) + old = manifest.Shard( + pathlib.PurePosixPath(*parts), shard.start, shard.end + ) assert old_manifest._item_to_digest[old] == digest else: assert old_manifest._item_to_digest[shard] == digest @@ -627,10 +613,10 @@ def _check_manifests_match_except_on_entry( old_manifest._item_to_digest ) for shard, digest in new_manifest._item_to_digest.items(): - path, _, _ = shard - if path == expected_mismatch_path: + if shard.path == expected_mismatch_path: # Note that the file size changes - assert old_manifest._item_to_digest[(path, 0, 23)] != digest + item = manifest.Shard(shard.path, 0, 23) + assert old_manifest._item_to_digest[item] != digest else: assert old_manifest._item_to_digest[shard] == digest @@ -668,3 +654,9 @@ def test_max_workers_does_not_change_digest(self, sample_model_folder): assert manifest1 == manifest2 assert manifest1 == manifest3 + + + def test_shard_to_string(self): + """Ensure the shard's `__str__` method behaves as assumed.""" + shard = manifest.Shard(pathlib.PurePosixPath("a"), 0, 42) + assert str(shard) == "a:0:42" diff --git a/model_signing/signing/__init__.py b/model_signing/signing/__init__.py new file mode 100644 index 00000000..0888a055 --- /dev/null 
+++ b/model_signing/signing/__init__.py @@ -0,0 +1,13 @@ +# Copyright 2024 The Sigstore Authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. diff --git a/model_signing/signing/signing.py b/model_signing/signing/signing.py new file mode 100644 index 00000000..b324ec58 --- /dev/null +++ b/model_signing/signing/signing.py @@ -0,0 +1,80 @@ +# Copyright 2024 The Sigstore Authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Machinery for signing and verification of ML models. + +The serialization API produces a manifest representation of the models, which +can be used to implement various verification patterns. However, when signing, +we need to actually represent this manifest in a specific disk format. But, +there are multiple ways to use `manifest.Manifest` objects, so we add a new +`SigningMaterial` class hierarchy to serialize and sign manifests. + +The output of a signing process is a `Signature` instance, backed by a format to +serialize this to disk. 
In OSS, this is usually a Sigstore bundle. + +TODO: expand on this. +""" + +import abc +import pathlib +from typing import Self + +from model_signing.manifest import manifest + + +class SigningMaterial(metaclass=abc.ABCMeta): + """Generic material that we can sign.""" + + @classmethod + @abc.abstractmethod + def from_manifest(cls, manifest: manifest.Manifest) -> Self: + """Converts a manifest to the signing material used for signing.""" + pass + + @abc.abstractmethod + def sign(self) -> "Signature": + """Signs the current SigningMaterial with the provided key/signer. + + TODO: arguments, abstract over signing format, etc. + """ + pass + + +class Signature(metaclass=abc.ABCMeta): + """Generic signature support.""" + + @abc.abstractmethod + def write_signature(self, path: pathlib.Path): + """Writes the signature to disk, to the given path.""" + pass + + @classmethod + @abc.abstractmethod + def read_signature(cls, path: pathlib.Path) -> Self: + """Reads the signature from disk. + + Does not perform any verification, except what is needed to parse the + signature file. Use `verify` to validate the signature. + """ + pass + + @abc.abstractmethod + def verify(self): # TODO: signature + """Verifies the signature. + + If the verification passes, this method returns TODO: what? + + TODO: Document return and raises. + """ + pass