diff --git a/api/python/quilt3/api.py b/api/python/quilt3/api.py
index 561a2a8d35a..da1b18899b9 100644
--- a/api/python/quilt3/api.py
+++ b/api/python/quilt3/api.py
@@ -21,7 +21,7 @@
 )
 
 
-def copy(src, dest):
+def copy(src, dest, put_options=None):
     """
     Copies ``src`` object from QUILT to ``dest``.
 
@@ -29,10 +29,11 @@ def copy(src, dest):
     or local file paths (starting with ``file:///``).
 
     Parameters:
         src (str): a path to retrieve
         dest (str): a path to write to
+        put_options (dict): optional arguments to pass to the PutObject operation
     """
-    copy_file(PhysicalKey.from_url(fix_url(src)), PhysicalKey.from_url(fix_url(dest)))
+    copy_file(PhysicalKey.from_url(fix_url(src)), PhysicalKey.from_url(fix_url(dest)), put_options=put_options)
 
 
 @ApiTelemetry("api.delete_package")
diff --git a/api/python/quilt3/backends/base.py b/api/python/quilt3/backends/base.py
index ff8a065deec..dfc5548d055 100644
--- a/api/python/quilt3/backends/base.py
+++ b/api/python/quilt3/backends/base.py
@@ -87,7 +87,8 @@ def delete_package_version(self, pkg_name: str, top_hash: str):
         pass
 
     @abc.abstractmethod
-    def push_manifest(self, pkg_name: str, top_hash: str, manifest_data: bytes):
+    def push_manifest(self, pkg_name: str, top_hash: str, manifest_data: bytes,
+                      put_options: dict = None):
         pass
 
     @abc.abstractmethod
@@ -134,14 +135,14 @@ def manifests_package_dir(self, pkg_name: str) -> PhysicalKey:
 
     def manifest_pk(self, pkg_name: str, top_hash: str) -> PhysicalKey:
         return self.root.join(f'packages/{top_hash}')
 
-    def push_manifest(self, pkg_name: str, top_hash: str, manifest_data: bytes):
+    def push_manifest(self, pkg_name: str, top_hash: str, manifest_data: bytes, put_options: dict = None):
         """returns: timestamp to support catalog drag-and-drop => browse"""
-        put_bytes(manifest_data, self.manifest_pk(pkg_name, top_hash))
+        put_bytes(manifest_data, self.manifest_pk(pkg_name, top_hash), put_options=put_options)
         hash_bytes = top_hash.encode()
         # TODO: use a float to string formatter instead of double casting
         timestamp_str = str(int(time.time()))
-        put_bytes(hash_bytes, self.pointer_pk(pkg_name, timestamp_str))
-        put_bytes(hash_bytes, self.pointer_latest_pk(pkg_name))
+        put_bytes(hash_bytes, self.pointer_pk(pkg_name, timestamp_str), put_options=put_options)
+        put_bytes(hash_bytes, self.pointer_latest_pk(pkg_name), put_options=put_options)
         return timestamp_str
 
     @staticmethod
@@ -246,9 +247,9 @@ def list_package_versions(self, pkg_name: str):
         for dt, top_hash in self.list_package_versions_with_timestamps(pkg_name):
             yield str(int(dt.timestamp())), top_hash
 
-    def push_manifest(self, pkg_name: str, top_hash: str, manifest_data: bytes):
-        put_bytes(manifest_data, self.manifest_pk(pkg_name, top_hash))
-        put_bytes(top_hash.encode(), self.pointer_latest_pk(pkg_name))
+    def push_manifest(self, pkg_name: str, top_hash: str, manifest_data: bytes, put_options: dict = None):
+        put_bytes(manifest_data, self.manifest_pk(pkg_name, top_hash), put_options=put_options)
+        put_bytes(top_hash.encode(), self.pointer_latest_pk(pkg_name), put_options=put_options)
 
     @staticmethod
     def _top_hash_from_path(path: str) -> str:
diff --git a/api/python/quilt3/bucket.py b/api/python/quilt3/bucket.py
index e0f58d69d2a..fab661701f9 100644
--- a/api/python/quilt3/bucket.py
+++ b/api/python/quilt3/bucket.py
@@ -55,13 +55,14 @@ def search(self, query: T.Union[str, dict], limit: int = 10) -> T.List[dict]:
         """
         return search_api(query, index=f"{self._pk.bucket},{self._pk.bucket}_packages", limit=limit)["hits"]["hits"]
 
-    def put_file(self, key, path):
+    def put_file(self, key, path, put_options=None):
        """
         Stores file at path to key in bucket.
 
         Args:
             key(str): key in bucket to store file at
             path(str): string representing local path to file
+            put_options(dict): optional arguments to pass to the PutObject operation
 
         Returns:
             None
@@ -71,15 +72,16 @@ def put_file(self, key, path):
             * if copy fails
         """
         dest = self._pk.join(key)
-        copy_file(PhysicalKey.from_url(fix_url(path)), dest)
+        copy_file(PhysicalKey.from_url(fix_url(path)), dest, put_options=put_options)
 
-    def put_dir(self, key, directory):
+    def put_dir(self, key, directory, put_options=None):
         """
         Stores all files in the `directory` under the prefix `key`.
 
         Args:
             key(str): prefix to store files under in bucket
             directory(str): path to directory to grab files from
+            put_options(dict): optional arguments to pass to the PutObject operation
 
         Returns:
             None
@@ -97,7 +99,7 @@ def put_dir(self, key, directory):
 
         src = PhysicalKey.from_path(str(src_path) + '/')
         dest = self._pk.join(key)
-        copy_file(src, dest)
+        copy_file(src, dest, put_options=put_options)
 
     def keys(self):
         """
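Note: with the Bucket changes above, callers can thread arbitrary PutObject parameters through the high-level upload API. A minimal sketch of the intended usage (the bucket name and KMS key id are hypothetical; ServerSideEncryption and SSEKMSKeyId are standard S3 PutObject parameters):

    import quilt3

    bucket = quilt3.Bucket("s3://my-bucket")  # hypothetical bucket
    bucket.put_file(
        key="data/report.csv",
        path="./report.csv",
        put_options={
            "ServerSideEncryption": "aws:kms",
            "SSEKMSKeyId": "1234abcd-12ab-34cd-56ef-1234567890ab",  # hypothetical key id
        },
    )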
diff --git a/api/python/quilt3/data_transfer.py b/api/python/quilt3/data_transfer.py
index 7ae97ef5bf4..d5ca59596c7 100644
--- a/api/python/quilt3/data_transfer.py
+++ b/api/python/quilt3/data_transfer.py
@@ -53,6 +53,18 @@
 logger = logging.getLogger(__name__)
 
 
+def add_put_options_safely(params: dict, put_options: Optional[dict]):
+    """
+    Add put options to the params dictionary safely.
+    Raises ValueError instead of overwriting a key that is already present in params.
+    """
+    if put_options:
+        for key, value in put_options.items():
+            if key in params:
+                raise ValueError(f"Cannot override key `{key}` using put_options: {put_options}.")
+            params[key] = value
+
+
 class S3Api(Enum):
     GET_OBJECT = "GET_OBJECT"
     HEAD_OBJECT = "HEAD_OBJECT"
@@ -301,27 +313,31 @@ def _copy_local_file(ctx: WorkerContext, size: int, src_path: str, dest_path: str):
     ctx.done(PhysicalKey.from_path(dest_path), None)
 
 
-def _upload_file(ctx: WorkerContext, size: int, src_path: str, dest_bucket: str, dest_key: str):
+def _upload_file(ctx: WorkerContext, size: int, src_path: str, dest_bucket: str, dest_key: str, put_options=None):
     s3_client = ctx.s3_client_provider.standard_client
 
     if not is_mpu(size):
         with ReadFileChunk.from_filename(src_path, 0, size, [ctx.progress]) as fd:
-            resp = s3_client.put_object(
+            s3_params = dict(
                 Body=fd,
                 Bucket=dest_bucket,
                 Key=dest_key,
                 ChecksumAlgorithm='SHA256',
             )
+            add_put_options_safely(s3_params, put_options)
+            resp = s3_client.put_object(**s3_params)
 
             version_id = resp.get('VersionId')  # Absent in unversioned buckets.
             checksum = _simple_s3_to_quilt_checksum(resp['ChecksumSHA256'])
             ctx.done(PhysicalKey(dest_bucket, dest_key, version_id), checksum)
     else:
-        resp = s3_client.create_multipart_upload(
+        s3_create_params = dict(
             Bucket=dest_bucket,
             Key=dest_key,
             ChecksumAlgorithm='SHA256',
         )
+        add_put_options_safely(s3_create_params, put_options)
+        resp = s3_client.create_multipart_upload(**s3_create_params)
 
         upload_id = resp['UploadId']
 
         chunksize = get_checksum_chunksize(size)
@@ -336,7 +352,7 @@ def upload_part(i, start, end):
             nonlocal remaining
             part_id = i + 1
             with ReadFileChunk.from_filename(src_path, start, end-start, [ctx.progress]) as fd:
-                part = s3_client.upload_part(
+                s3_upload_params = dict(
                     Body=fd,
                     Bucket=dest_bucket,
                     Key=dest_key,
@@ -344,6 +360,8 @@ def upload_part(i, start, end):
                     PartNumber=part_id,
                     ChecksumAlgorithm='SHA256',
                 )
+                add_put_options_safely(s3_upload_params, put_options)
+                part = s3_client.upload_part(**s3_upload_params)
             with lock:
                 parts[i] = dict(
                     PartNumber=part_id,
@@ -354,9 +372,11 @@ def upload_part(i, start, end):
                 done = remaining == 0
 
             if done:
-                resp = s3_client.complete_multipart_upload(
+                s3_complete_params = dict(
                     Bucket=dest_bucket,
                     Key=dest_key,
                     UploadId=upload_id,
                     MultipartUpload={'Parts': parts},
                 )
+                add_put_options_safely(s3_complete_params, put_options)
+                resp = s3_client.complete_multipart_upload(**s3_complete_params)
@@ -468,21 +493,20 @@ def _copy_remote_file(ctx: WorkerContext, size: int, src_bucket: str, src_key: str,
             Key=dest_key,
             ChecksumAlgorithm='SHA256',
         )
-
-        if extra_args:
-            params.update(extra_args)
-
+        add_put_options_safely(params, extra_args)
         resp = s3_client.copy_object(**params)
         ctx.progress(size)
 
         version_id = resp.get('VersionId')  # Absent in unversioned buckets.
         checksum = _simple_s3_to_quilt_checksum(resp['CopyObjectResult']['ChecksumSHA256'])
         ctx.done(PhysicalKey(dest_bucket, dest_key, version_id), checksum)
     else:
-        resp = s3_client.create_multipart_upload(
+        s3_create_params = dict(
             Bucket=dest_bucket,
             Key=dest_key,
             ChecksumAlgorithm='SHA256',
         )
+        add_put_options_safely(s3_create_params, extra_args)
+        resp = s3_client.create_multipart_upload(**s3_create_params)
 
         upload_id = resp['UploadId']
 
         chunksize = get_checksum_chunksize(size)
@@ -496,7 +520,7 @@ def _copy_remote_file(ctx: WorkerContext, size: int, src_bucket: str, src_key: str,
         def upload_part(i, start, end):
             nonlocal remaining
             part_id = i + 1
-            part = s3_client.upload_part_copy(
+            s3_upload_params = dict(
                 CopySource=src_params,
                 CopySourceRange=f'bytes={start}-{end-1}',
                 Bucket=dest_bucket,
@@ -504,6 +528,8 @@ def upload_part(i, start, end):
                 UploadId=upload_id,
                 PartNumber=part_id,
             )
+            add_put_options_safely(s3_upload_params, extra_args)
+            part = s3_client.upload_part_copy(**s3_upload_params)
             with lock:
                 parts[i] = dict(
                     PartNumber=part_id,
@@ -516,12 +542,14 @@ def upload_part(i, start, end):
             ctx.progress(end - start)
 
             if done:
-                resp = s3_client.complete_multipart_upload(
+                s3_complete_params = dict(
                     Bucket=dest_bucket,
                     Key=dest_key,
                     UploadId=upload_id,
                     MultipartUpload={'Parts': parts},
                 )
+                add_put_options_safely(s3_complete_params, extra_args)
+                resp = s3_client.complete_multipart_upload(**s3_complete_params)
 
                 version_id = resp.get('VersionId')  # Absent in unversioned buckets.
                 checksum, _ = resp['ChecksumSHA256'].split('-', 1)
                 ctx.done(PhysicalKey(dest_bucket, dest_key, version_id), checksum)
@@ -580,7 +608,8 @@ def _reuse_remote_file(ctx: WorkerContext, size: int, src_path: str, dest_bucket: str, dest_path: str):
     return None
 
 
-def _upload_or_reuse_file(ctx: WorkerContext, size: int, src_path: str, dest_bucket: str, dest_path: str):
+def _upload_or_reuse_file(ctx: WorkerContext, size: int, src_path: str, dest_bucket: str,
+                          dest_path: str, put_options=None):
     result = _reuse_remote_file(ctx, size, src_path, dest_bucket, dest_path)
     if result is not None:
         dest_version_id, checksum = result
@@ -588,7 +617,7 @@ def _upload_or_reuse_file(ctx: WorkerContext, size: int, src_path: str, dest_bucket: str, dest_path: str):
         ctx.done(PhysicalKey(dest_bucket, dest_path, dest_version_id), checksum)
         return  # Optimization succeeded.
     # If the optimization didn't happen, do the normal upload.
-    _upload_file(ctx, size, src_path, dest_bucket, dest_path)
+    _upload_file(ctx, size, src_path, dest_bucket, dest_path, put_options)
 
 
 def _copy_file_list_last_retry(retry_state):
@@ -602,7 +631,8 @@ def _copy_file_list_last_retry(retry_state):
     wait=wait_exponential(multiplier=1, min=1, max=10),
     retry=retry_if_not_result(all),
     retry_error_callback=_copy_file_list_last_retry)
-def _copy_file_list_internal(file_list, results, message, callback, exceptions_to_ignore=(ClientError,)):
+def _copy_file_list_internal(file_list, results, message, callback,
+                             exceptions_to_ignore=(ClientError,), put_options=None):
     """
     Takes a list of tuples (src, dest, size) and copies the data in parallel.
     `results` is the list where results will be stored.
@@ -668,13 +698,13 @@ def done_callback(value, checksum):
             else:
                 if dest.version_id:
                     raise ValueError("Cannot set VersionId on destination")
-                _upload_or_reuse_file(ctx, size, src.path, dest.bucket, dest.path)
+                _upload_or_reuse_file(ctx, size, src.path, dest.bucket, dest.path, put_options)
         else:
             if dest.is_local():
                 _download_file(ctx, size, src.bucket, src.path, src.version_id, dest.path)
             else:
                 _copy_remote_file(ctx, size, src.bucket, src.path, src.version_id,
-                                  dest.bucket, dest.path)
+                                  dest.bucket, dest.path, extra_args=put_options)
 
     try:
         for idx, (args, result) in enumerate(zip(file_list, results)):
@@ -855,7 +885,7 @@ def delete_url(src: PhysicalKey):
         s3_client.delete_object(Bucket=src.bucket, Key=src.path)
 
 
-def copy_file_list(file_list, message=None, callback=None):
+def copy_file_list(file_list, message=None, callback=None, put_options=None):
     """
     Takes a list of tuples (src, dest, size) and copies them in parallel.
     URLs must be regular files, not directories.
@@ -865,10 +895,10 @@ def copy_file_list(file_list, message=None, callback=None):
         if _looks_like_dir(src) or _looks_like_dir(dest):
             raise ValueError("Directories are not allowed")
 
-    return _copy_file_list_internal(file_list, [None] * len(file_list), message, callback)
+    return _copy_file_list_internal(file_list, [None] * len(file_list), message, callback, put_options=put_options)
 
 
-def copy_file(src: PhysicalKey, dest: PhysicalKey, size=None, message=None, callback=None):
+def copy_file(src: PhysicalKey, dest: PhysicalKey, size=None, message=None, callback=None, put_options=None):
     """
     Copies a single file or directory.
     If src is a file, dest can be a file or a directory.
@@ -900,10 +930,10 @@ def sanity_check(rel_path):
             src = PhysicalKey(src.bucket, src.path, version_id)
         url_list.append((src, dest, size))
 
-    _copy_file_list_internal(url_list, [None] * len(url_list), message, callback)
+    _copy_file_list_internal(url_list, [None] * len(url_list), message, callback, put_options=put_options)
 
 
-def put_bytes(data: bytes, dest: PhysicalKey):
+def put_bytes(data: bytes, dest: PhysicalKey, put_options=None):
     if _looks_like_dir(dest):
         raise ValueError("Invalid path: %r" % dest.path)
 
@@ -915,11 +945,9 @@ def put_bytes(data: bytes, dest: PhysicalKey):
         if dest.version_id is not None:
             raise ValueError("Cannot set VersionId on destination")
         s3_client = S3ClientProvider().standard_client
-        s3_client.put_object(
-            Bucket=dest.bucket,
-            Key=dest.path,
-            Body=data,
-        )
+        s3_params = dict(Bucket=dest.bucket, Key=dest.path, Body=data)
+        add_put_options_safely(s3_params, put_options)
+        s3_client.put_object(**s3_params)
 
 
 def _local_get_bytes(pk: PhysicalKey):
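Note: the guard in add_put_options_safely is what keeps put_options from silently clobbering parameters the transfer code already sets (Body, Bucket, Key, ChecksumAlgorithm, and so on). A quick illustration of the intended semantics (dictionary values are hypothetical):

    from quilt3.data_transfer import add_put_options_safely

    params = {'Bucket': 'my-bucket', 'Key': 'some/key'}
    add_put_options_safely(params, {'Metadata': {'owner': 'data-team'}})
    # params now also carries 'Metadata'

    # Attempting to override a key the caller already set fails loudly:
    add_put_options_safely(params, {'Bucket': 'other-bucket'})
    # ValueError: Cannot override key `Bucket` using put_options: ...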
diff --git a/api/python/quilt3/packages.py b/api/python/quilt3/packages.py
index 09f1391b60c..c6ab5ff0c5b 100644
--- a/api/python/quilt3/packages.py
+++ b/api/python/quilt3/packages.py
@@ -351,13 +351,14 @@ def deserialize(self, func=None, **format_opts):
 
         return formats[0].deserialize(data, self._meta, pkey_ext, **format_opts)
 
-    def fetch(self, dest=None):
+    def fetch(self, dest=None, put_options=None):
         """
         Gets objects from entry and saves them to dest.
 
         Args:
-            dest: where to put the files
+            dest: URL for where to put the files.
                 Defaults to the entry name
+            put_options: optional arguments to pass to the PutObject operation
 
         Returns:
             None
@@ -368,7 +369,7 @@ def fetch(self, dest=None):
         else:
             dest = PhysicalKey.from_url(fix_url(dest))
 
-        copy_file(self.physical_key, dest)
+        copy_file(self.physical_key, dest, put_options=put_options)
 
         # return a package reroot package physical keys after the copy operation succeeds
         # see GH#388 for context
@@ -1053,7 +1054,7 @@ def _validate_with_workflow(self, *, registry, workflow, name, message):
 
     @ApiTelemetry("package.build")
     @_fix_docstring(workflow=_WORKFLOW_PARAM_DOCSTRING)
-    def build(self, name, registry=None, message=None, *, workflow=...):
+    def build(self, name, registry=None, message=None, *, workflow=..., put_options=None):
         """
         Serializes this package to a registry.
 
@@ -1063,15 +1064,16 @@ def build(self, name, registry=None, message=None, *, workflow=...):
             defaults to local registry
             message: the commit message of the package
             %(workflow)s
+            put_options: optional arguments to pass to the PutObject operation
 
         Returns:
             The top hash as a string.
         """
         registry = get_package_registry(registry)
         self._validate_with_workflow(registry=registry, workflow=workflow, name=name, message=message)
-        return self._build(name=name, registry=registry, message=message)
+        return self._build(name=name, registry=registry, message=message, put_options=put_options)
 
-    def _build(self, name, registry, message):
+    def _build(self, name, registry, message, put_options=None):
         validate_package_name(name)
         registry = get_package_registry(registry)
 
@@ -1079,13 +1081,13 @@ def _build(self, name, registry, message):
         self._calculate_missing_hashes()
 
         top_hash = self.top_hash
-        self._push_manifest(name, registry, top_hash)
+        self._push_manifest(name, registry, top_hash, put_options=put_options)
         return top_hash
 
-    def _push_manifest(self, name, registry, top_hash):
+    def _push_manifest(self, name, registry, top_hash, put_options=None):
         manifest = io.BytesIO()
         self._dump(manifest)
-        registry.push_manifest(name, top_hash, manifest.getvalue())
+        registry.push_manifest(name, top_hash, manifest.getvalue(), put_options=put_options)
 
     @ApiTelemetry("package.dump")
     def dump(self, writable_file):
@@ -1357,7 +1359,7 @@ def _get_top_hash_parts(cls, meta, entries):
     @_fix_docstring(workflow=_WORKFLOW_PARAM_DOCSTRING)
     def push(
             self, name, registry=None, dest=None, message=None, selector_fn=None, *,
-            workflow=..., force: bool = False, dedupe: bool = False
+            workflow=..., force: bool = False, dedupe: bool = False, put_options=None
     ):
         """
         Copies objects to path, then creates a new package that points to those objects.
@@ -1400,19 +1402,20 @@ def push(
             %(workflow)s
             force: skip the top hash check and overwrite any existing package
             dedupe: don't push if the top hash matches the existing package top hash; return the current package
+            put_options: optional arguments to pass to the PutObject operation
 
         Returns:
             A new package that points to the copied objects.
         """
         return self._push(
             name, registry, dest, message, selector_fn, workflow=workflow,
-            print_info=True, force=force, dedupe=dedupe
+            print_info=True, force=force, dedupe=dedupe, put_options=put_options
         )
 
     def _push(
             self, name, registry=None, dest=None, message=None, selector_fn=None, *,
             workflow, print_info, force: bool, dedupe: bool,
-            copy_file_list_fn: T.Optional[CopyFileListFn] = None,
+            copy_file_list_fn: T.Optional[CopyFileListFn] = None, put_options=None,
     ):
         if selector_fn is None:
             def selector_fn(*args):
@@ -1533,7 +1536,7 @@ def check_hash_conficts(latest_hash):
                 entries.append((logical_key, entry))
                 file_list.append((physical_key, new_physical_key, entry.size))
 
-        results = copy_file_list_fn(file_list, message="Copying objects")
+        results = copy_file_list_fn(file_list, message="Copying objects", put_options=put_options)
 
         for (logical_key, entry), (versioned_key, checksum) in zip(entries, results):
             # Create a new package entry pointing to the new remote key.
@@ -1580,7 +1583,7 @@ def physical_key_is_temp_file(pk):
         latest_hash = get_latest_hash()
         check_hash_conficts(latest_hash)
 
-        pkg._push_manifest(name, registry, top_hash)
+        pkg._push_manifest(name, registry, top_hash, put_options=put_options)
 
         if print_info:
             shorthash = registry.shorten_top_hash(name, top_hash)
diff --git a/api/python/tests/integration/test_packages.py b/api/python/tests/integration/test_packages.py
index 1a3e34664ad..e2e236ebbed 100644
--- a/api/python/tests/integration/test_packages.py
+++ b/api/python/tests/integration/test_packages.py
@@ -46,7 +46,7 @@
 LOCAL_REGISTRY = Path('local_registry')  # Set by QuiltTestCase
 
 
-def _mock_copy_file_list(file_list, callback=None, message=None):
+def _mock_copy_file_list(file_list, callback=None, message=None, **kwargs):
     return [(key, None) for _, key, _ in file_list]
 
 
@@ -452,7 +452,8 @@ def test_fetch_default_dest(tmpdir):
             filepath = os.path.join(os.path.dirname(__file__), 'data', 'foo.txt')
             copy_mock.assert_called_once_with(
                 PhysicalKey.from_path(filepath),
-                PhysicalKey.from_path('foo.txt')
+                PhysicalKey.from_path('foo.txt'),
+                put_options=None,
             )
 
     @patch('quilt3.workflows.validate', mock.MagicMock(return_value=None))
@@ -1259,6 +1260,7 @@ def test_commit_message_on_push(self, mocked_workflow_validate):
                 'Quilt/test_pkg_name',
                 registry,
                 mock.sentinel.top_hash,
+                put_options=None,
             )
             mocked_workflow_validate.assert_called_once_with(
                 registry=registry,
@@ -1925,7 +1927,7 @@ def test_push_dest_fn(self):
             pkg.push(pkg_name, registry='s3://test-bucket', dest=dest_fn, force=True)
 
         dest_fn.assert_called_once_with(lk, pkg[lk])
-        push_manifest_mock.assert_called_once_with(pkg_name, mock.sentinel.top_hash, ANY)
+        push_manifest_mock.assert_called_once_with(pkg_name, mock.sentinel.top_hash, ANY, put_options=None)
         assert Package.load(
             BytesIO(push_manifest_mock.call_args[0][2])
         )[lk].physical_key == PhysicalKey(dest_bucket, dest_key, version)
@@ -1951,7 +1953,7 @@ def test_push_selector_fn_false(self):
 
         selector_fn.assert_called_once_with(lk, pkg[lk])
         calculate_checksum_mock.assert_called_once_with([PhysicalKey(src_bucket, src_key, src_version)], [0])
-        push_manifest_mock.assert_called_once_with(pkg_name, mock.sentinel.top_hash, ANY)
+        push_manifest_mock.assert_called_once_with(pkg_name, mock.sentinel.top_hash, ANY, put_options=None)
         assert Package.load(
             BytesIO(push_manifest_mock.call_args[0][2])
         )[lk].physical_key == PhysicalKey(src_bucket, src_key, src_version)
@@ -1998,7 +2000,7 @@ def test_push_selector_fn_true(self):
 
         selector_fn.assert_called_once_with(lk, pkg[lk])
         calculate_checksum_mock.assert_called_once_with([], [])
-        push_manifest_mock.assert_called_once_with(pkg_name, mock.sentinel.top_hash, ANY)
+        push_manifest_mock.assert_called_once_with(pkg_name, mock.sentinel.top_hash, ANY, put_options=None)
         assert Package.load(
             BytesIO(push_manifest_mock.call_args[0][2])
         )[lk].physical_key == PhysicalKey(dst_bucket, dst_key, dst_version)
@@ -2046,10 +2048,15 @@ class PackageTestV2(PackageTest):
     def local_manifest_timestamp_fixer(self, timestamp):
         wrapped = self.LocalPackageRegistryDefault.push_manifest
 
-        def wrapper(pkg_registry, pkg_name, top_hash, manifest_data):
-            wrapped(pkg_registry, pkg_name, top_hash, manifest_data)
-            os.utime(pkg_registry._manifest_parent_pk(pkg_name, top_hash).path, (timestamp, timestamp))
-        return patch.object(self.LocalPackageRegistryDefault, 'push_manifest', wrapper)
+        def wrapper(pkg_registry, pkg_name, top_hash, manifest_data, put_options=None):
+            wrapped(pkg_registry, pkg_name, top_hash, manifest_data, put_options=put_options)
+            os.utime(
+                pkg_registry._manifest_parent_pk(pkg_name, top_hash).path,
+                (timestamp, timestamp)
+            )
+        return patch.object(
+            self.LocalPackageRegistryDefault, 'push_manifest', wrapper
+        )
 
     def _test_list_remote_packages_setup_stubber(self, pkg_registry, *, pkg_names):
         self.s3_stubber.add_response(
diff --git a/api/python/tests/test_bucket.py b/api/python/tests/test_bucket.py
index e4cfedc29c3..0ce1d1483cc 100644
--- a/api/python/tests/test_bucket.py
+++ b/api/python/tests/test_bucket.py
@@ -172,30 +172,39 @@ def test_bucket_select(self):
 
     def test_bucket_put_file(self):
         with patch("quilt3.bucket.copy_file") as copy_mock:
+            opts = {'SSECustomerKey': 'FakeKey'}
             bucket = Bucket('s3://test-bucket')
-            bucket.put_file(key='README.md', path='./README')  # put local file to bucket
+            # put local file to bucket
+            bucket.put_file(key='README.md', path='./README', put_options=opts)
             copy_mock.assert_called_once_with(
-                PhysicalKey.from_path('README'), PhysicalKey.from_url('s3://test-bucket/README.md'))
+                PhysicalKey.from_path('README'),
+                PhysicalKey.from_url('s3://test-bucket/README.md'),
+                put_options=opts,
+            )
 
     def test_bucket_put_dir(self):
         path = pathlib.Path(__file__).parent / 'data'
         bucket = Bucket('s3://test-bucket')
+        opts = {'SSECustomerKey': 'FakeKey'}
 
         with patch("quilt3.bucket.copy_file") as copy_mock:
-            bucket.put_dir('test', path)
+            bucket.put_dir('test', path, opts)
             copy_mock.assert_called_once_with(
-                PhysicalKey.from_path(str(path) + '/'), PhysicalKey.from_url('s3://test-bucket/test/'))
+                PhysicalKey.from_path(str(path) + '/'),
+                PhysicalKey.from_url('s3://test-bucket/test/'), put_options=opts)
 
         with patch("quilt3.bucket.copy_file") as copy_mock:
-            bucket.put_dir('test/', path)
+            bucket.put_dir('test/', path, opts)
             copy_mock.assert_called_once_with(
-                PhysicalKey.from_path(str(path) + '/'), PhysicalKey.from_url('s3://test-bucket/test/'))
+                PhysicalKey.from_path(str(path) + '/'),
+                PhysicalKey.from_url('s3://test-bucket/test/'), put_options=opts)
 
         with patch("quilt3.bucket.copy_file") as copy_mock:
-            bucket.put_dir('', path)
+            bucket.put_dir('', path, opts)
             copy_mock.assert_called_once_with(
-                PhysicalKey.from_path(str(path) + '/'), PhysicalKey.from_url('s3://test-bucket/'))
+                PhysicalKey.from_path(str(path) + '/'),
+                PhysicalKey.from_url('s3://test-bucket/'), put_options=opts)
 
     def test_remote_delete(self):
         self.s3_stubber.add_response(
diff --git a/api/python/tests/test_data_transfer.py b/api/python/tests/test_data_transfer.py
index 05a71597a98..27677e4ece0 100644
--- a/api/python/tests/test_data_transfer.py
+++ b/api/python/tests/test_data_transfer.py
@@ -29,6 +29,27 @@
 
 
 class DataTransferTest(QuiltTestCase):
+    def test_add_put_options_safely(self):
+        OPTIONS_TEMPLATE = {'SSECustomerKey': '123456789'}
+
+        # Test that the function adds the options
+        options_empty = {}
+        data_transfer.add_put_options_safely(options_empty, OPTIONS_TEMPLATE)
+        self.assertEqual(options_empty, OPTIONS_TEMPLATE)
+
+        # Test that the function works when passed None
+        options_unchanged = OPTIONS_TEMPLATE.copy()
+        data_transfer.add_put_options_safely(options_unchanged, None)
+        self.assertEqual(options_unchanged, OPTIONS_TEMPLATE)
+
+        # Test that the function raises an error instead of overriding an existing key
+        options_original = OPTIONS_TEMPLATE.copy()
+        options_modified = {'SSECustomerKey': '987654321'}
+        with pytest.raises(ValueError,
+                           match="Cannot override key `SSECustomerKey` using put_options:"
+                                 " {'SSECustomerKey': '987654321'}."):
+            data_transfer.add_put_options_safely(options_original, options_modified)
+
     def test_select(self):
         # Note: The boto3 Stubber doesn't work properly with s3_client.select_object_content().
         # The return value expects a dict where an iterable is in the actual results.
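Note: the stubber-based tests above work because botocore's Stubber matches expected_params exactly, so spreading the put_options dict into expected_params is enough to assert that the options reach put_object. New cases can reuse the same pattern; for example (values hypothetical):

    expected = {
        'Body': b'data',
        'Bucket': 'test-bucket',
        'Key': 'some/key',
        'ServerSideEncryption': 'aws:kms',  # came from put_options
    }
    self.s3_stubber.add_response('put_object', service_response={}, expected_params=expected)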
diff --git a/docs/api-reference/Bucket.md b/docs/api-reference/Bucket.md
index 7570078cc3a..8aeab93cb27 100644
--- a/docs/api-reference/Bucket.md
+++ b/docs/api-reference/Bucket.md
@@ -33,7 +33,7 @@ __Returns__
 
 search results
 
-## Bucket.put\_file(self, key, path) {#Bucket.put\_file}
+## Bucket.put\_file(self, key, path, put\_options=None) {#Bucket.put\_file}
 
 Stores file at path to key in bucket.
 
@@ -41,6 +41,7 @@ __Arguments__
 
 * __key(str)__: key in bucket to store file at
 * __path(str)__: string representing local path to file
+* __put_options(dict)__: optional arguments to pass to the PutObject operation
 
 __Returns__
 
@@ -52,7 +53,7 @@ __Raises__
 
 * if copy fails
 
-## Bucket.put\_dir(self, key, directory) {#Bucket.put\_dir}
+## Bucket.put\_dir(self, key, directory, put\_options=None) {#Bucket.put\_dir}
 
 Stores all files in the `directory` under the prefix `key`.
 
@@ -60,6 +61,7 @@ __Arguments__
 
 * __key(str)__: prefix to store files under in bucket
 * __directory(str)__: path to directory to grab files from
+* __put_options(dict)__: optional arguments to pass to the PutObject operation
 
 __Returns__
diff --git a/docs/api-reference/Package.md b/docs/api-reference/Package.md
index 526ea577b67..843714db535 100644
--- a/docs/api-reference/Package.md
+++ b/docs/api-reference/Package.md
@@ -193,7 +193,7 @@ no such entry exists.
 
 Sets user metadata on this Package.
 
-## Package.build(self, name, registry=None, message=None, \*, workflow=Ellipsis) {#Package.build}
+## Package.build(self, name, registry=None, message=None, \*, workflow=Ellipsis, put\_options=None) {#Package.build}
 
 Serializes this package to a registry.
 
@@ -207,6 +207,7 @@ __Arguments__
     If not specified, the default workflow will be used.
 * __For details see__: https://docs.quiltdata.com/advanced-usage/workflows
+* __put_options__: optional arguments to pass to the PutObject operation
 
 __Returns__
 
@@ -272,7 +273,7 @@ __Raises__
 
 * `KeyError`: when logical_key is not present to be deleted
 
-## Package.push(self, name, registry=None, dest=None, message=None, selector\_fn=None, \*, workflow=Ellipsis, force: bool = False, dedupe: bool = False) {#Package.push}
+## Package.push(self, name, registry=None, dest=None, message=None, selector\_fn=None, \*, workflow=Ellipsis, force: bool = False, dedupe: bool = False, put\_options=None) {#Package.push}
 
 Copies objects to path, then creates a new package that points to those objects.
 Copies each object in this package to path according to logical key structure,
@@ -318,6 +319,7 @@ __Arguments__
 * __force__: skip the top hash check and overwrite any existing package
 * __dedupe__: don't push if the top hash matches the existing package top hash;
     return the current package
+* __put_options__: optional arguments to pass to the PutObject operation
 
 __Returns__
 
@@ -507,14 +509,15 @@ hash verification fail when deserialization metadata is not present
 
 
-## PackageEntry.fetch(self, dest=None) {#PackageEntry.fetch}
+## PackageEntry.fetch(self, dest=None, put\_options=None) {#PackageEntry.fetch}
 
 Gets objects from entry and saves them to dest.
 
 __Arguments__
 
-* __dest__: where to put the files
+* __dest__: URL for where to put the files.
     Defaults to the entry name
+* __put_options__: optional arguments to pass to the PutObject operation
 
 __Returns__
diff --git a/lambdas/pkgpush/tests/test_index.py b/lambdas/pkgpush/tests/test_index.py
index 7f375999408..58b93295089 100644
--- a/lambdas/pkgpush/tests/test_index.py
+++ b/lambdas/pkgpush/tests/test_index.py
@@ -587,9 +587,11 @@ def make_request(self, params, **kwargs):
         )
 
     @contextlib.contextmanager
-    def _mock_package_build(self, entries, *, message=..., expected_workflow=...):
+    def _mock_package_build(self, entries, *, message=..., expected_workflow=..., put_options=None):
         if message is ...:
             message = self.dst_commit_message
+        if put_options is None:
+            put_options = {}
 
         # Use a test package to verify manifest entries
         test_pkg = Package()
@@ -614,28 +616,31 @@ def _mock_package_build(self, entries, *, message=..., expected_workflow=...):
         self.s3_stubber.add_response(
             'put_object',
             service_response={},
             expected_params={
                 'Body': manifest.read(),
                 'Bucket': self.dst_bucket,
                 'Key': f'.quilt/packages/{test_pkg.top_hash}',
+                **put_options,
             },
         )
         self.s3_stubber.add_response(
             'put_object',
             service_response={},
             expected_params={
                 'Body': str.encode(test_pkg.top_hash),
                 'Bucket': self.dst_bucket,
                 'Key': f'.quilt/named_packages/{self.dst_pkg_name}/{str(int(self.mock_timestamp))}',
+                **put_options,
             },
         )
         self.s3_stubber.add_response(
             'put_object',
             service_response={},
             expected_params={
                 'Body': str.encode(test_pkg.top_hash),
                 'Bucket': self.dst_bucket,
                 'Key': f'.quilt/named_packages/{self.dst_pkg_name}/latest',
+                **put_options,
             },
         )
         with mock.patch('quilt3.workflows.validate', return_value=mocked_workflow_data) as workflow_validate_mock:
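Note: example usage of the new top-level API, end to end (the paths and bucket are hypothetical; StorageClass is a standard PutObject parameter):

    import quilt3

    quilt3.copy(
        "./local-file.bin",
        "s3://my-bucket/remote-file.bin",
        put_options={"StorageClass": "STANDARD_IA"},
    )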