Skip to content

Commit

Permalink
Fix mypy warnings for ManifestLoader.load() (#8443)
Browse files Browse the repository at this point in the history
  • Loading branch information
gshank authored Aug 17, 2023
1 parent d088d44 commit c30b691
Show file tree
Hide file tree
Showing 3 changed files with 28 additions and 10 deletions.
6 changes: 6 additions & 0 deletions .changes/unreleased/Under the Hood-20230817-134548.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
kind: Under the Hood
body: Fix mypy warnings for ManifestLoader.load()
time: 2023-08-17T13:45:48.937252-04:00
custom:
Author: gshank
Issue: "8401"
12 changes: 7 additions & 5 deletions core/dbt/parser/manifest.py
Original file line number Diff line number Diff line change
Expand Up @@ -79,6 +79,7 @@
load_source_file,
FileDiff,
ReadFilesFromDiff,
ReadFiles,
)
from dbt.parser.partial import PartialParsing, special_override_macros
from dbt.contracts.graph.manifest import (
Expand Down Expand Up @@ -259,7 +260,7 @@ def __init__(
# We need to know if we're actually partially parsing. It could
# have been enabled, but not happening because of some issue.
self.partially_parsing = False
self.partial_parser = None
self.partial_parser: Optional[PartialParsing] = None

# This is a saved manifest from a previous run that's used for partial parsing
self.saved_manifest: Optional[Manifest] = self.read_manifest_for_partial_parse()
Expand Down Expand Up @@ -331,7 +332,7 @@ def get_full_manifest(
return manifest

# This is where the main action happens
def load(self):
def load(self) -> Manifest:
start_read_files = time.perf_counter()

# This updates the "files" dictionary in self.manifest, and creates
Expand All @@ -340,6 +341,7 @@ def load(self):
# of parsers to lists of file strings. The file strings are
# used to get the SourceFiles from the manifest files.
saved_files = self.saved_manifest.files if self.saved_manifest else {}
file_reader: Optional[ReadFiles] = None
if self.file_diff:
# We're getting files from a file diff
file_reader = ReadFilesFromDiff(
Expand Down Expand Up @@ -403,7 +405,7 @@ def load(self):
}

# get file info for local logs
parse_file_type = None
parse_file_type: str = ""
file_id = self.partial_parser.processing_file
if file_id:
source_file = None
Expand Down Expand Up @@ -484,7 +486,7 @@ def load(self):
self.manifest.rebuild_disabled_lookup()

# Load yaml files
parser_types = [SchemaParser]
parser_types = [SchemaParser] # type: ignore
for project in self.all_projects.values():
if project.project_name not in project_parser_files:
continue
Expand Down Expand Up @@ -1062,7 +1064,7 @@ def track_project_load(self):

# Takes references in 'refs' array of nodes and exposures, finds the target
# node, and updates 'depends_on.nodes' with the unique id
def process_refs(self, current_project: str, dependencies: Optional[Dict[str, Project]]):
def process_refs(self, current_project: str, dependencies: Optional[Mapping[str, Project]]):
for node in self.manifest.nodes.values():
if node.created_at < self.started_at:
continue
Expand Down
20 changes: 15 additions & 5 deletions core/dbt/parser/read_files.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,9 +16,10 @@
from dbt.parser.schemas import yaml_from_file, schema_file_keys
from dbt.exceptions import ParsingError
from dbt.parser.search import filesystem_search
from typing import Optional, Dict, List, Mapping
from typing import Optional, Dict, List, Mapping, MutableMapping
from dbt.events.types import InputFileDiffError
from dbt.events.functions import fire_event
from typing import Protocol


@dataclass
Expand Down Expand Up @@ -173,12 +174,21 @@ def generate_dbt_ignore_spec(project_root):
return ignore_spec


# Protocol for the ReadFiles... classes
class ReadFiles(Protocol):
files: MutableMapping[str, AnySourceFile]
project_parser_files: Dict

def read_files(self):
pass


@dataclass
class ReadFilesFromFileSystem:
all_projects: Mapping[str, Project]
files: Dict[str, AnySourceFile] = field(default_factory=dict)
files: MutableMapping[str, AnySourceFile] = field(default_factory=dict)
# saved_files is only used to compare schema files
saved_files: Dict[str, AnySourceFile] = field(default_factory=dict)
saved_files: MutableMapping[str, AnySourceFile] = field(default_factory=dict)
# project_parser_files = {
# "my_project": {
# "ModelParser": ["my_project://models/my_model.sql"]
Expand Down Expand Up @@ -212,10 +222,10 @@ class ReadFilesFromDiff:
root_project_name: str
all_projects: Mapping[str, Project]
file_diff: FileDiff
files: Dict[str, AnySourceFile] = field(default_factory=dict)
files: MutableMapping[str, AnySourceFile] = field(default_factory=dict)
# saved_files is used to construct a fresh copy of files, without
# additional information from parsing
saved_files: Dict[str, AnySourceFile] = field(default_factory=dict)
saved_files: MutableMapping[str, AnySourceFile] = field(default_factory=dict)
project_parser_files: Dict = field(default_factory=dict)
project_file_types: Dict = field(default_factory=dict)
local_package_dirs: Optional[List[str]] = None
Expand Down

0 comments on commit c30b691

Please sign in to comment.