Skip to content

Commit

Permalink
Merge branch 'main' into pin_for_redshift_connector_204_fix
Browse files Browse the repository at this point in the history
  • Loading branch information
colin-rogers-dbt authored Jun 13, 2024
2 parents f7be9c3 + 0c3f514 commit 070021f
Show file tree
Hide file tree
Showing 7 changed files with 45 additions and 20 deletions.
6 changes: 6 additions & 0 deletions .changes/unreleased/Under the Hood-20240331-103115.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
kind: Under the Hood
body: Lazy load agate
time: 2024-03-31T10:31:15.65006-04:00
custom:
Author: dwreeves
Issue: "745"
13 changes: 9 additions & 4 deletions dbt/adapters/redshift/connections.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,12 @@
import re
from multiprocessing import Lock
from contextlib import contextmanager
from typing import Any, Callable, Dict, Tuple, Union, Optional, List
from typing import Any, Callable, Dict, Tuple, Union, Optional, List, TYPE_CHECKING
from dataclasses import dataclass, field

import agate
import sqlparse
import redshift_connector
from dbt.adapters.exceptions import FailedToConnectError
from dbt_common.clients import agate_helper
from redshift_connector.utils.oids import get_datatype_name

from dbt.adapters.sql import SQLConnectionManager
Expand All @@ -19,6 +17,11 @@
from dbt_common.helper_types import Port
from dbt_common.exceptions import DbtRuntimeError, CompilationError, DbtDatabaseError

if TYPE_CHECKING:
# Indirectly imported via agate_helper, which is lazy loaded further downfile.
# Used by mypy for earlier type hints.
import agate


class SSLConfigError(CompilationError):
def __init__(self, exc: ValidationError):
Expand Down Expand Up @@ -393,13 +396,15 @@ def execute(
auto_begin: bool = False,
fetch: bool = False,
limit: Optional[int] = None,
) -> Tuple[AdapterResponse, agate.Table]:
) -> Tuple[AdapterResponse, "agate.Table"]:
sql = self._add_query_comment(sql)
_, cursor = self.add_query(sql, auto_begin)
response = self.get_response(cursor)
if fetch:
table = self.get_result_from_cursor(cursor, limit)
else:
from dbt_common.clients import agate_helper

table = agate_helper.empty_table()
return response, table

Expand Down
9 changes: 6 additions & 3 deletions dbt/adapters/redshift/impl.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import os
from dataclasses import dataclass
from dbt_common.contracts.constraints import ConstraintType
from typing import Optional, Set, Any, Dict, Type
from typing import Optional, Set, Any, Dict, Type, TYPE_CHECKING
from collections import namedtuple
from dbt.adapters.base import PythonJobHelper
from dbt.adapters.base.impl import AdapterConfig, ConstraintSupport
Expand All @@ -28,6 +28,9 @@

GET_RELATIONS_MACRO_NAME = "redshift__get_relations"

if TYPE_CHECKING:
import agate


@dataclass
class RedshiftConfig(AdapterConfig):
Expand Down Expand Up @@ -85,7 +88,7 @@ def drop_relation(self, relation):
return super().drop_relation(relation)

@classmethod
def convert_text_type(cls, agate_table, col_idx):
def convert_text_type(cls, agate_table: "agate.Table", col_idx):
column = agate_table.columns[col_idx]
# `lens` must be a list, so this can't be a generator expression,
# because max() raises ane exception if its argument has no members.
Expand All @@ -94,7 +97,7 @@ def convert_text_type(cls, agate_table, col_idx):
return "varchar({})".format(max_len)

@classmethod
def convert_time_type(cls, agate_table, col_idx):
def convert_time_type(cls, agate_table: "agate.Table", col_idx):
return "varchar(24)"

@available
Expand Down
11 changes: 8 additions & 3 deletions dbt/adapters/redshift/relation_configs/base.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
from dataclasses import dataclass
from typing import Optional, Dict
from typing import Optional, Dict, TYPE_CHECKING

import agate
from dbt.adapters.base.relation import Policy
from dbt.adapters.contracts.relation import ComponentName, RelationConfig
from dbt.adapters.relation_configs import (
Expand All @@ -15,6 +14,10 @@
RedshiftQuotePolicy,
)

if TYPE_CHECKING:
# Imported downfile for specific row gathering function.
import agate


@dataclass(frozen=True, eq=True, unsafe_hash=True)
class RedshiftRelationConfigBase(RelationConfigBase):
Expand Down Expand Up @@ -63,8 +66,10 @@ def _render_part(cls, component: ComponentName, value: Optional[str]) -> Optiona
return None

@classmethod
def _get_first_row(cls, results: agate.Table) -> agate.Row:
def _get_first_row(cls, results: "agate.Table") -> "agate.Row":
try:
return results.rows[0]
except IndexError:
import agate

return agate.Row(values=set())
8 changes: 5 additions & 3 deletions dbt/adapters/redshift/relation_configs/dist.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,7 @@
from dataclasses import dataclass
from dbt.adapters.contracts.relation import RelationConfig
from typing import Optional, Set, Dict
from typing import Optional, Set, Dict, TYPE_CHECKING

import agate
from dbt.adapters.relation_configs import (
RelationConfigChange,
RelationConfigChangeAction,
Expand All @@ -15,6 +14,9 @@

from dbt.adapters.redshift.relation_configs.base import RedshiftRelationConfigBase

if TYPE_CHECKING:
import agate


class RedshiftDistStyle(StrEnum):
auto = "auto"
Expand Down Expand Up @@ -108,7 +110,7 @@ def parse_relation_config(cls, relation_config: RelationConfig) -> dict:
return config

@classmethod
def parse_relation_results(cls, relation_results_entry: agate.Row) -> Dict:
def parse_relation_results(cls, relation_results_entry: "agate.Row") -> Dict:
"""
Translate agate objects from the database into a standard dictionary.
Expand Down
10 changes: 6 additions & 4 deletions dbt/adapters/redshift/relation_configs/materialized_view.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
from dataclasses import dataclass, field
from typing import Optional, Set, Dict, Any
from typing import Optional, Set, Dict, Any, TYPE_CHECKING

import agate
from dbt.adapters.relation_configs import (
RelationResults,
RelationConfigChange,
Expand All @@ -25,6 +24,9 @@
)
from dbt.adapters.redshift.utility import evaluate_bool

if TYPE_CHECKING:
import agate


@dataclass(frozen=True, eq=True, unsafe_hash=True)
class RedshiftMaterializedViewConfig(RedshiftRelationConfigBase, RelationConfigValidationMixin):
Expand Down Expand Up @@ -173,10 +175,10 @@ def parse_relation_results(cls, relation_results: RelationResults) -> Dict:
Returns: a standard dictionary describing this `RedshiftMaterializedViewConfig` instance
"""
materialized_view: agate.Row = cls._get_first_row(
materialized_view: "agate.Row" = cls._get_first_row(
relation_results.get("materialized_view")
)
query: agate.Row = cls._get_first_row(relation_results.get("query"))
query: "agate.Row" = cls._get_first_row(relation_results.get("query"))

config_dict = {
"mv_name": materialized_view.get("table"),
Expand Down
8 changes: 5 additions & 3 deletions dbt/adapters/redshift/relation_configs/sort.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,7 @@
from dataclasses import dataclass
from dbt.adapters.contracts.relation import RelationConfig
from typing import Optional, FrozenSet, Set, Dict, Any
from typing import Optional, FrozenSet, Set, Dict, Any, TYPE_CHECKING

import agate
from dbt.adapters.relation_configs import (
RelationConfigChange,
RelationConfigChangeAction,
Expand All @@ -15,6 +14,9 @@

from dbt.adapters.redshift.relation_configs.base import RedshiftRelationConfigBase

if TYPE_CHECKING:
import agate


class RedshiftSortStyle(StrEnum):
auto = "auto"
Expand Down Expand Up @@ -136,7 +138,7 @@ def parse_relation_config(cls, relation_config: RelationConfig) -> Dict[str, Any
return config_dict

@classmethod
def parse_relation_results(cls, relation_results_entry: agate.Row) -> dict:
def parse_relation_results(cls, relation_results_entry: "agate.Row") -> dict:
"""
Translate agate objects from the database into a standard dictionary.
Expand Down

0 comments on commit 070021f

Please sign in to comment.