From 6c7d21079f9d5709220bf04e65572b87ee5c2566 Mon Sep 17 00:00:00 2001 From: gadorlhiac Date: Wed, 4 Dec 2024 17:21:40 -0800 Subject: [PATCH] ENH Bring rest of models into pydantic v2 mypy conformance. Just missing validators.py --- lute/io/models/base.py | 5 +--- lute/io/models/sfx_find_peaks.py | 19 +++++++----- lute/io/models/sfx_index.py | 50 +++++++++++++++++--------------- lute/io/models/sfx_merge.py | 32 ++++++++++---------- lute/io/models/sfx_solve.py | 15 ++++++---- lute/io/models/smd.py | 13 ++++++--- lute/io/models/tests.py | 2 +- 7 files changed, 75 insertions(+), 61 deletions(-) diff --git a/lute/io/models/base.py b/lute/io/models/base.py index ac7d185e..e0d32dab 100644 --- a/lute/io/models/base.py +++ b/lute/io/models/base.py @@ -35,9 +35,7 @@ Optional, ClassVar, no_type_check, - Callable, cast, - Literal, ) import pydantic @@ -51,8 +49,7 @@ # Ignore mypy and ruff for now since type checking against pydantic 1.10 from pydantic import model_validator, field_validator # type: ignore from pydantic_core import PydanticUndefined # type: ignore - from pydantic_settings import SettingsConfigDict # type: ignore - from pydantic_settings import BaseSettings + from pydantic_settings import SettingsConfigDict, BaseSettings # type: ignore @no_type_check # This function causes many headaches with mypy... Ignore def Field( diff --git a/lute/io/models/sfx_find_peaks.py b/lute/io/models/sfx_find_peaks.py index 6242bbfd..41a154b9 100644 --- a/lute/io/models/sfx_find_peaks.py +++ b/lute/io/models/sfx_find_peaks.py @@ -48,7 +48,7 @@ class FindPeaksPyAlgosParameters(TaskParameters): Config = FindPeaksPyAlgosParametersConfig if PYDANTIC_V2: - out_file_validator = field_validator("out_file") + out_file_validator: ClassVar = field_validator("out_file") else: out_file_validator = validator("out_file", always=True) @@ -209,17 +209,22 @@ class FindPeaksPsocakeParameters(ThirdPartyParameters): Config = FindPeaksPsocakeParametersConfig if PYDANTIC_V2: - e_validator = field_validator("e") - r_validator = field_validator("r") - set_output_path_validator = field_validator("lute_template_cfg") - sz_parameters_validator = field_validator("sz_parameters") - result_validator = model_validator(mode="after") + e_validator: ClassVar = field_validator("e") + r_validator: ClassVar = field_validator("r") + set_output_path_validator: ClassVar = field_validator("lute_template_cfg") + sz_parameters_validator: ClassVar = field_validator("sz_parameters") + result_validator: ClassVar = model_validator(mode="after") else: e_validator = validator("e", always=True) r_validator = validator("r", always=True) set_output_path_validator = validator("lute_template_cfg", always=True) sz_parameters_validator = validator("sz_parameters", always=True) - result_validator = root_validator(pre=False) + # Strictly only need pre=False for running, but it doesn't match overload + # variants so mypy complains when using pydantic v2. This is functionally + # the same for our purposes + result_validator = root_validator( + pre=False, skip_on_failure=True, allow_reuse=True + ) class SZParameters(BaseModel): compressor: Literal["qoz", "sz3"] = Field( diff --git a/lute/io/models/sfx_index.py b/lute/io/models/sfx_index.py index 349f4425..f510d18a 100644 --- a/lute/io/models/sfx_index.py +++ b/lute/io/models/sfx_index.py @@ -75,8 +75,8 @@ class IndexCrystFELParameters(ThirdPartyParameters): Config = IndexCrystFELParametersConfig if PYDANTIC_V2: - in_file_validator = field_validator("in_file") - out_file_validator = field_validator("out_file") + in_file_validator: ClassVar = field_validator("in_file") + out_file_validator: ClassVar = field_validator("out_file") else: in_file_validator = validator("in_file", always=True) out_file_validator = validator("out_file", always=True) @@ -531,6 +531,15 @@ class ConcatenateStreamFilesParameters(TaskParameters): else: Config = ConcatenateStreamFilesParametersConfig + if PYDANTIC_V2: + in_file_validator: ClassVar = field_validator("in_file") + tag_validator: ClassVar = field_validator("tag") + out_file_validator: ClassVar = field_validator("out_file") + else: + in_file_validator = validator("in_file", always=True) + tag_validator = validator("tag", always=True) + out_file_validator = validator("out_file", always=True) + in_file: str = ( Field( "", @@ -568,15 +577,6 @@ class ConcatenateStreamFilesParameters(TaskParameters): else Field("", description="Path to merged output stream file.", is_result=True) ) - if PYDANTIC_V2: - in_file_validator = field_validator("in_file") - tag_validator = field_validator("tag") - out_file_validator = field_validator("out_file") - else: - in_file_validator = validator("in_file", always=True) - tag_validator = validator("tag", always=True) - out_file_validator = validator("out_file", always=True) - @in_file_validator @classmethod def validate_in_file(cls, in_file: str, values: Dict[str, Any]) -> str: @@ -636,6 +636,17 @@ class IndexCCTBXXFELParameters(ThirdPartyParameters): else: Config = IndexCCTBXXFELParametersConfig + if PYDANTIC_V2: + phil_path_validator: ClassVar = field_validator("phil_file") + phil_template_validator: ClassVar = field_validator("lute_template_cfg") + in_file_validator: ClassVar = field_validator("in_file") + data_spec_validator: ClassVar = field_validator("data_spec") + else: + phil_path_validator = validator("phil_file", always=True) + phil_template_validator = validator("lute_template_cfg", always=True) + in_file_validator = validator("in_file", always=True) + data_spec_validator = validator("data_spec", always=True) + class PhilParameters(BaseModel): """Template parameters for CCTBX phil file.""" @@ -646,8 +657,10 @@ class PhilParameters(BaseModel): Config = PhilParametersConfig if PYDANTIC_V2: - output_output_dir_validator = field_validator("output_output_dir") - output_logging_dir_validator = field_validator("output_logging_dir") + output_output_dir_validator: ClassVar = field_validator("output_output_dir") + output_logging_dir_validator: ClassVar = field_validator( + "output_logging_dir" + ) else: output_output_dir_validator = validator("output_output_dir", always=True) output_logging_dir_validator = validator("output_logging_dir", always=True) @@ -947,17 +960,6 @@ def set_output_log_dir(cls, output: str, values: Dict[str, Any]) -> str: ) ) - if PYDANTIC_V2: - phil_path_validator = field_validator("phil_file") - phil_template_validator = field_validator("lute_template_cfg") - in_file_validator = field_validator("in_file") - data_spec_validator = field_validator("data_spec") - else: - phil_path_validator = validator("phil_file", always=True) - phil_template_validator = validator("lute_template_cfg", always=True) - in_file_validator = validator("in_file", always=True) - data_spec_validator = validator("data_spec", always=True) - @phil_path_validator @classmethod def set_default_phil_path(cls, phil_file: str, values: Dict[str, Any]) -> str: diff --git a/lute/io/models/sfx_merge.py b/lute/io/models/sfx_merge.py index e74b12a4..ba2aaaf8 100644 --- a/lute/io/models/sfx_merge.py +++ b/lute/io/models/sfx_merge.py @@ -72,8 +72,8 @@ class MergePartialatorParameters(ThirdPartyParameters): Config = MergePartialatorParametersConfig if PYDANTIC_V2: - in_file_validator = field_validator("in_file") - out_file_validator = field_validator("out_file") + in_file_validator: ClassVar = field_validator("in_file") + out_file_validator: ClassVar = field_validator("out_file") else: in_file_validator = validator("in_file", always=True) out_file_validator = validator("out_file", always=True) @@ -289,6 +289,13 @@ class MergeCCTBXXFELParameters(ThirdPartyParameters): else: Config = MergeCCTBXXFELParametersConfig + if PYDANTIC_V2: + phil_file_validator: ClassVar = field_validator("phil_file") + phil_template_validator: ClassVar = field_validator("lute_template_cfg") + else: + phil_file_validator = validator("phil_file", always=True) + phil_template_validator = validator("lute_template_cfg", always=True) + class PhilParameters(BaseModel): """Template parameters for CCTBX phil file.""" @@ -461,13 +468,6 @@ class PhilParameters(BaseModel): ) ) - if PYDANTIC_V2: - phil_file_validator = field_validator("phil_file") - phil_template_validator = field_validator("lute_template_cfg") - else: - phil_file_validator = validator("phil_file", always=True) - phil_template_validator = validator("lute_template_cfg", always=True) - @phil_file_validator @classmethod def set_default_phil_path(cls, phil_file: str, values: Dict[str, Any]) -> str: @@ -514,10 +514,10 @@ class CompareHKLParameters(ThirdPartyParameters): Config = CompareHKLParametersConfig if PYDANTIC_V2: - in_files_validator = field_validator("in_files") - cell_file_validator = field_validator("cell_file") - symmetry_validator = field_validator("symmetry") - shell_file_validator = field_validator("shell_file") + in_files_validator: ClassVar = field_validator("in_files") + cell_file_validator: ClassVar = field_validator("cell_file") + symmetry_validator: ClassVar = field_validator("symmetry") + shell_file_validator: ClassVar = field_validator("shell_file") else: in_files_validator = validator("in_files", always=True) cell_file_validator = validator("cell_file", always=True) @@ -716,9 +716,9 @@ class ManipulateHKLParameters(ThirdPartyParameters): Config = ManipulateHKLParametersConfig if PYDANTIC_V2: - in_file_validator = field_validator("in_file") - out_file_validator = field_validator("out_file") - cell_file_validator = field_validator("cell_file") + in_file_validator: ClassVar = field_validator("in_file") + out_file_validator: ClassVar = field_validator("out_file") + cell_file_validator: ClassVar = field_validator("cell_file") else: in_file_validator = validator("in_file", always=True) out_file_validator = validator("out_file", always=True) diff --git a/lute/io/models/sfx_solve.py b/lute/io/models/sfx_solve.py index 9227b3c1..6bd07e66 100644 --- a/lute/io/models/sfx_solve.py +++ b/lute/io/models/sfx_solve.py @@ -59,13 +59,18 @@ class DimpleSolveParameters(ThirdPartyParameters): Config = DimpleSolveParametersConfig if PYDANTIC_V2: - in_file_validator = field_validator("in_file") - out_dir_validator = field_validator("out_dir") - result_validator = model_validator(mode="after") + in_file_validator: ClassVar = field_validator("in_file") + out_dir_validator: ClassVar = field_validator("out_dir") + result_validator: ClassVar = model_validator(mode="after") else: in_file_validator = validator("in_file", always=True) out_dir_validator = validator("out_dir", always=True) - result_validator = root_validator(pre=False) + # Strictly only need pre=False for running, but it doesn't match overload + # variants so mypy complains when using pydantic v2. This is functionally + # the same for our purposes + result_validator = root_validator( + pre=False, skip_on_failure=True, allow_reuse=True + ) executable: str = Field( "/sdf/group/lcls/ds/tools/ccp4-8.0/bin/dimple", @@ -277,7 +282,7 @@ class RunSHELXCParameters(ThirdPartyParameters): """ if PYDANTIC_V2: - in_file_validator = field_validator("in_file") + in_file_validator: ClassVar = field_validator("in_file") else: in_file_validator = validator("in_file", always=True) diff --git a/lute/io/models/smd.py b/lute/io/models/smd.py index 3ce55f22..82b39fc7 100644 --- a/lute/io/models/smd.py +++ b/lute/io/models/smd.py @@ -73,13 +73,18 @@ class SubmitSMDParameters(ThirdPartyParameters): Config = SubmitSMDParametersConfig if PYDANTIC_V2: - producer_validator = field_validator("producer") - producer_template_validator = field_validator("lute_template_cfg") - result_validator = model_validator(mode="after") + producer_validator: ClassVar = field_validator("producer") + producer_template_validator: ClassVar = field_validator("lute_template_cfg") + result_validator: ClassVar = model_validator(mode="after") else: producer_validator = validator("producer", always=True) producer_template_validator = validator("lute_template_cfg", always=True) - result_validator = root_validator(pre=False) + # Strictly only need pre=False for running, but it doesn't match overload + # variants so mypy complains when using pydantic v2. This is functionally + # the same for our purposes + result_validator = root_validator( + pre=False, skip_on_failure=True, allow_reuse=True + ) class ProducerParameters(BaseModel): class ROIParams(BaseModel): diff --git a/lute/io/models/tests.py b/lute/io/models/tests.py index 32562950..a3963ca3 100644 --- a/lute/io/models/tests.py +++ b/lute/io/models/tests.py @@ -126,7 +126,7 @@ class TestReadOutputParameters(TaskParameters): in_file: str = Field("", description="File to read in. (Full path)") if PYDANTIC_V2: - in_file_validator = field_validator("in_file") + in_file_validator: ClassVar = field_validator("in_file") else: in_file_validator = validator("in_file", always=True)