Skip to content

Commit

Permalink
ENH Bring rest of models into pydantic v2 mypy conformance. Just miss…
Browse files Browse the repository at this point in the history
…ing validators.py
  • Loading branch information
gadorlhiac committed Dec 5, 2024
1 parent 8575f3c commit 6c7d210
Show file tree
Hide file tree
Showing 7 changed files with 75 additions and 61 deletions.
5 changes: 1 addition & 4 deletions lute/io/models/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,9 +35,7 @@
Optional,
ClassVar,
no_type_check,
Callable,
cast,
Literal,
)

import pydantic
Expand All @@ -51,8 +49,7 @@
# Ignore mypy and ruff for now since type checking against pydantic 1.10
from pydantic import model_validator, field_validator # type: ignore
from pydantic_core import PydanticUndefined # type: ignore
from pydantic_settings import SettingsConfigDict # type: ignore
from pydantic_settings import BaseSettings
from pydantic_settings import SettingsConfigDict, BaseSettings # type: ignore

@no_type_check # This function causes many headaches with mypy... Ignore
def Field(
Expand Down
19 changes: 12 additions & 7 deletions lute/io/models/sfx_find_peaks.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,7 @@ class FindPeaksPyAlgosParameters(TaskParameters):
Config = FindPeaksPyAlgosParametersConfig

if PYDANTIC_V2:
out_file_validator = field_validator("out_file")
out_file_validator: ClassVar = field_validator("out_file")
else:
out_file_validator = validator("out_file", always=True)

Expand Down Expand Up @@ -209,17 +209,22 @@ class FindPeaksPsocakeParameters(ThirdPartyParameters):
Config = FindPeaksPsocakeParametersConfig

if PYDANTIC_V2:
e_validator = field_validator("e")
r_validator = field_validator("r")
set_output_path_validator = field_validator("lute_template_cfg")
sz_parameters_validator = field_validator("sz_parameters")
result_validator = model_validator(mode="after")
e_validator: ClassVar = field_validator("e")
r_validator: ClassVar = field_validator("r")
set_output_path_validator: ClassVar = field_validator("lute_template_cfg")
sz_parameters_validator: ClassVar = field_validator("sz_parameters")
result_validator: ClassVar = model_validator(mode="after")
else:
e_validator = validator("e", always=True)
r_validator = validator("r", always=True)
set_output_path_validator = validator("lute_template_cfg", always=True)
sz_parameters_validator = validator("sz_parameters", always=True)
result_validator = root_validator(pre=False)
# Strictly only need pre=False for running, but it doesn't match overload
# variants so mypy complains when using pydantic v2. This is functionally
# the same for our purposes
result_validator = root_validator(
pre=False, skip_on_failure=True, allow_reuse=True
)

class SZParameters(BaseModel):
compressor: Literal["qoz", "sz3"] = Field(
Expand Down
50 changes: 26 additions & 24 deletions lute/io/models/sfx_index.py
Original file line number Diff line number Diff line change
Expand Up @@ -75,8 +75,8 @@ class IndexCrystFELParameters(ThirdPartyParameters):
Config = IndexCrystFELParametersConfig

if PYDANTIC_V2:
in_file_validator = field_validator("in_file")
out_file_validator = field_validator("out_file")
in_file_validator: ClassVar = field_validator("in_file")
out_file_validator: ClassVar = field_validator("out_file")
else:
in_file_validator = validator("in_file", always=True)
out_file_validator = validator("out_file", always=True)
Expand Down Expand Up @@ -531,6 +531,15 @@ class ConcatenateStreamFilesParameters(TaskParameters):
else:
Config = ConcatenateStreamFilesParametersConfig

if PYDANTIC_V2:
in_file_validator: ClassVar = field_validator("in_file")
tag_validator: ClassVar = field_validator("tag")
out_file_validator: ClassVar = field_validator("out_file")
else:
in_file_validator = validator("in_file", always=True)
tag_validator = validator("tag", always=True)
out_file_validator = validator("out_file", always=True)

in_file: str = (
Field(
"",
Expand Down Expand Up @@ -568,15 +577,6 @@ class ConcatenateStreamFilesParameters(TaskParameters):
else Field("", description="Path to merged output stream file.", is_result=True)
)

if PYDANTIC_V2:
in_file_validator = field_validator("in_file")
tag_validator = field_validator("tag")
out_file_validator = field_validator("out_file")
else:
in_file_validator = validator("in_file", always=True)
tag_validator = validator("tag", always=True)
out_file_validator = validator("out_file", always=True)

@in_file_validator
@classmethod
def validate_in_file(cls, in_file: str, values: Dict[str, Any]) -> str:
Expand Down Expand Up @@ -636,6 +636,17 @@ class IndexCCTBXXFELParameters(ThirdPartyParameters):
else:
Config = IndexCCTBXXFELParametersConfig

if PYDANTIC_V2:
phil_path_validator: ClassVar = field_validator("phil_file")
phil_template_validator: ClassVar = field_validator("lute_template_cfg")
in_file_validator: ClassVar = field_validator("in_file")
data_spec_validator: ClassVar = field_validator("data_spec")
else:
phil_path_validator = validator("phil_file", always=True)
phil_template_validator = validator("lute_template_cfg", always=True)
in_file_validator = validator("in_file", always=True)
data_spec_validator = validator("data_spec", always=True)

class PhilParameters(BaseModel):
"""Template parameters for CCTBX phil file."""

Expand All @@ -646,8 +657,10 @@ class PhilParameters(BaseModel):
Config = PhilParametersConfig

if PYDANTIC_V2:
output_output_dir_validator = field_validator("output_output_dir")
output_logging_dir_validator = field_validator("output_logging_dir")
output_output_dir_validator: ClassVar = field_validator("output_output_dir")
output_logging_dir_validator: ClassVar = field_validator(
"output_logging_dir"
)
else:
output_output_dir_validator = validator("output_output_dir", always=True)
output_logging_dir_validator = validator("output_logging_dir", always=True)
Expand Down Expand Up @@ -947,17 +960,6 @@ def set_output_log_dir(cls, output: str, values: Dict[str, Any]) -> str:
)
)

if PYDANTIC_V2:
phil_path_validator = field_validator("phil_file")
phil_template_validator = field_validator("lute_template_cfg")
in_file_validator = field_validator("in_file")
data_spec_validator = field_validator("data_spec")
else:
phil_path_validator = validator("phil_file", always=True)
phil_template_validator = validator("lute_template_cfg", always=True)
in_file_validator = validator("in_file", always=True)
data_spec_validator = validator("data_spec", always=True)

@phil_path_validator
@classmethod
def set_default_phil_path(cls, phil_file: str, values: Dict[str, Any]) -> str:
Expand Down
32 changes: 16 additions & 16 deletions lute/io/models/sfx_merge.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,8 +72,8 @@ class MergePartialatorParameters(ThirdPartyParameters):
Config = MergePartialatorParametersConfig

if PYDANTIC_V2:
in_file_validator = field_validator("in_file")
out_file_validator = field_validator("out_file")
in_file_validator: ClassVar = field_validator("in_file")
out_file_validator: ClassVar = field_validator("out_file")
else:
in_file_validator = validator("in_file", always=True)
out_file_validator = validator("out_file", always=True)
Expand Down Expand Up @@ -289,6 +289,13 @@ class MergeCCTBXXFELParameters(ThirdPartyParameters):
else:
Config = MergeCCTBXXFELParametersConfig

if PYDANTIC_V2:
phil_file_validator: ClassVar = field_validator("phil_file")
phil_template_validator: ClassVar = field_validator("lute_template_cfg")
else:
phil_file_validator = validator("phil_file", always=True)
phil_template_validator = validator("lute_template_cfg", always=True)

class PhilParameters(BaseModel):
"""Template parameters for CCTBX phil file."""

Expand Down Expand Up @@ -461,13 +468,6 @@ class PhilParameters(BaseModel):
)
)

if PYDANTIC_V2:
phil_file_validator = field_validator("phil_file")
phil_template_validator = field_validator("lute_template_cfg")
else:
phil_file_validator = validator("phil_file", always=True)
phil_template_validator = validator("lute_template_cfg", always=True)

@phil_file_validator
@classmethod
def set_default_phil_path(cls, phil_file: str, values: Dict[str, Any]) -> str:
Expand Down Expand Up @@ -514,10 +514,10 @@ class CompareHKLParameters(ThirdPartyParameters):
Config = CompareHKLParametersConfig

if PYDANTIC_V2:
in_files_validator = field_validator("in_files")
cell_file_validator = field_validator("cell_file")
symmetry_validator = field_validator("symmetry")
shell_file_validator = field_validator("shell_file")
in_files_validator: ClassVar = field_validator("in_files")
cell_file_validator: ClassVar = field_validator("cell_file")
symmetry_validator: ClassVar = field_validator("symmetry")
shell_file_validator: ClassVar = field_validator("shell_file")
else:
in_files_validator = validator("in_files", always=True)
cell_file_validator = validator("cell_file", always=True)
Expand Down Expand Up @@ -716,9 +716,9 @@ class ManipulateHKLParameters(ThirdPartyParameters):
Config = ManipulateHKLParametersConfig

if PYDANTIC_V2:
in_file_validator = field_validator("in_file")
out_file_validator = field_validator("out_file")
cell_file_validator = field_validator("cell_file")
in_file_validator: ClassVar = field_validator("in_file")
out_file_validator: ClassVar = field_validator("out_file")
cell_file_validator: ClassVar = field_validator("cell_file")
else:
in_file_validator = validator("in_file", always=True)
out_file_validator = validator("out_file", always=True)
Expand Down
15 changes: 10 additions & 5 deletions lute/io/models/sfx_solve.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,13 +59,18 @@ class DimpleSolveParameters(ThirdPartyParameters):
Config = DimpleSolveParametersConfig

if PYDANTIC_V2:
in_file_validator = field_validator("in_file")
out_dir_validator = field_validator("out_dir")
result_validator = model_validator(mode="after")
in_file_validator: ClassVar = field_validator("in_file")
out_dir_validator: ClassVar = field_validator("out_dir")
result_validator: ClassVar = model_validator(mode="after")
else:
in_file_validator = validator("in_file", always=True)
out_dir_validator = validator("out_dir", always=True)
result_validator = root_validator(pre=False)
# Strictly only need pre=False for running, but it doesn't match overload
# variants so mypy complains when using pydantic v2. This is functionally
# the same for our purposes
result_validator = root_validator(
pre=False, skip_on_failure=True, allow_reuse=True
)

executable: str = Field(
"/sdf/group/lcls/ds/tools/ccp4-8.0/bin/dimple",
Expand Down Expand Up @@ -277,7 +282,7 @@ class RunSHELXCParameters(ThirdPartyParameters):
"""

if PYDANTIC_V2:
in_file_validator = field_validator("in_file")
in_file_validator: ClassVar = field_validator("in_file")
else:
in_file_validator = validator("in_file", always=True)

Expand Down
13 changes: 9 additions & 4 deletions lute/io/models/smd.py
Original file line number Diff line number Diff line change
Expand Up @@ -73,13 +73,18 @@ class SubmitSMDParameters(ThirdPartyParameters):
Config = SubmitSMDParametersConfig

if PYDANTIC_V2:
producer_validator = field_validator("producer")
producer_template_validator = field_validator("lute_template_cfg")
result_validator = model_validator(mode="after")
producer_validator: ClassVar = field_validator("producer")
producer_template_validator: ClassVar = field_validator("lute_template_cfg")
result_validator: ClassVar = model_validator(mode="after")
else:
producer_validator = validator("producer", always=True)
producer_template_validator = validator("lute_template_cfg", always=True)
result_validator = root_validator(pre=False)
# Strictly only need pre=False for running, but it doesn't match overload
# variants so mypy complains when using pydantic v2. This is functionally
# the same for our purposes
result_validator = root_validator(
pre=False, skip_on_failure=True, allow_reuse=True
)

class ProducerParameters(BaseModel):
class ROIParams(BaseModel):
Expand Down
2 changes: 1 addition & 1 deletion lute/io/models/tests.py
Original file line number Diff line number Diff line change
Expand Up @@ -126,7 +126,7 @@ class TestReadOutputParameters(TaskParameters):
in_file: str = Field("", description="File to read in. (Full path)")

if PYDANTIC_V2:
in_file_validator = field_validator("in_file")
in_file_validator: ClassVar = field_validator("in_file")
else:
in_file_validator = validator("in_file", always=True)

Expand Down

0 comments on commit 6c7d210

Please sign in to comment.