diff --git a/projects/adapter/src/dbt/adapters/fal_experimental/adapter.py b/projects/adapter/src/dbt/adapters/fal_experimental/adapter.py index af03ad19f..4b97d16c7 100644 --- a/projects/adapter/src/dbt/adapters/fal_experimental/adapter.py +++ b/projects/adapter/src/dbt/adapters/fal_experimental/adapter.py @@ -19,6 +19,7 @@ from dbt.parser.manifest import MacroManifest, Manifest from .adapter_support import ( + FalCloudContext, prepare_for_adapter, read_relation_as_df, reconstruct_adapter, @@ -28,7 +29,7 @@ from .utils import extra_path, get_fal_scripts_path, retrieve_symbol -def run_with_adapter(code: str, adapter: BaseAdapter, config: RuntimeConfig) -> Any: +def run_with_adapter(code: str, adapter: BaseAdapter, config: RuntimeConfig, is_cloud: bool = False) -> Any: # main symbol is defined during dbt-fal's compilation # and acts as an entrypoint for us to run the model. fal_scripts_path = str(get_fal_scripts_path(config)) @@ -37,6 +38,7 @@ def run_with_adapter(code: str, adapter: BaseAdapter, config: RuntimeConfig) -> return main( read_df=prepare_for_adapter(adapter, read_relation_as_df), write_df=prepare_for_adapter(adapter, write_df_to_relation), + fal_context=FalCloudContext() if is_cloud else None ) @@ -46,6 +48,7 @@ def _isolated_runner( manifest: Manifest, macro_manifest: MacroManifest, local_packages: Optional[bytes] = None, + is_cloud: bool = False ) -> Any: # This function can be run in an entirely separate # process or an environment, so we need to reconstruct @@ -60,7 +63,7 @@ def _isolated_runner( zip_file = zipfile.ZipFile(io.BytesIO(local_packages)) zip_file.extractall(fal_scripts_path) - return run_with_adapter(code, adapter, config) + return run_with_adapter(code, adapter, config, is_cloud) def run_in_environment_with_adapter( @@ -102,7 +105,8 @@ def run_in_environment_with_adapter( config, manifest, macro_manifest, - local_packages=compressed_local_packages + local_packages=compressed_local_packages, + is_cloud=is_remote ) if environment.kind == "local": diff --git a/projects/adapter/src/dbt/adapters/fal_experimental/adapter_support.py b/projects/adapter/src/dbt/adapters/fal_experimental/adapter_support.py index 41d7e07e8..140a7b202 100644 --- a/projects/adapter/src/dbt/adapters/fal_experimental/adapter_support.py +++ b/projects/adapter/src/dbt/adapters/fal_experimental/adapter_support.py @@ -1,3 +1,4 @@ +from dataclasses import dataclass import functools from time import sleep from typing import Any @@ -208,3 +209,24 @@ def reload_adapter_cache(adapter: BaseAdapter, manifest: Manifest) -> None: def new_connection(adapter: BaseAdapter, connection_name: str) -> Connection: with adapter.connection_named(connection_name): yield adapter.connections.get_thread_connection() + + +class FalCloudWriter(object): + def __init__(self, path: str, options: str): + self.path = path + self.options = options + + def __enter__(self): + self.file = open(self.path, self.options) + return self.file + + def __exit__(self, *args): + self.file.close() + + + +@dataclass +class FalCloudContext: + @property + def store_open(self): + return FalCloudWriter