Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

ENH: Use Pydantic models to validate metadata in aggregation.py and remove drop_nones #750

Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion examples/s/d/nn/_project/aggregate_surfaces.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
from pathlib import Path

import fmu.dataio
import fmu.dataio._utils
import numpy as np
import xtgeo
import yaml
Expand Down Expand Up @@ -54,7 +55,7 @@ def main():
operations = ["mean", "min", "max", "std"]

# This is the ID we assign to this set of aggregations
aggregation_id = "something_very_unique" # IRL this will usually be a uuid
aggregation_id = str(fmu.dataio._utils.uuid_from_string("something_very_unique"))

# Initialize an AggregatedData object for this set of aggregations
exp = fmu.dataio.AggregatedData(
Expand Down
27 changes: 0 additions & 27 deletions src/fmu/dataio/_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,6 @@
import json
import os
import uuid
from copy import deepcopy
from pathlib import Path
from tempfile import NamedTemporaryFile
from typing import Any, Final, Literal
Expand Down Expand Up @@ -371,32 +370,6 @@ def read_named_envvar(envvar: str) -> str | None:
return os.environ.get(envvar, None)


def filter_validate_metadata(metadata_in: dict) -> dict:
"""Validate metadatadict at topmost_level and strip away any alien keys."""

valids = [
"$schema",
"version",
"source",
"tracklog",
"class",
"fmu",
"file",
"data",
"display",
"access",
"masterdata",
]

metadata = deepcopy(metadata_in)

for key in metadata_in:
if key not in valids:
del metadata[key]

return metadata


def generate_description(desc: str | list | None = None) -> list | None:
"""Parse desciption input (generic)."""
if not desc:
Expand Down
57 changes: 33 additions & 24 deletions src/fmu/dataio/aggregation.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,9 @@
from . import _utils, dataio, types
from ._logging import null_logger
from ._metadata import generate_meta_tracklog
from ._model import internal
from ._model.enums import FMUContext
from .exceptions import InvalidMetadataError
from .providers.objectdata._provider import objectdata_provider_factory

logger: Final = null_logger(__name__)
Expand Down Expand Up @@ -62,7 +64,7 @@ class AggregatedData:
tagname: str = ""
verbosity: str = "DEPRECATED" # keep for while

_metadata: dict = field(default_factory=dict, init=False)
_metadata: internal.DataClassMeta = field(init=False)
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Had to change the type of _metadata to internal.DataClassMeta here in order to pass the mypy checks. Not sure if that is what we wanted, or if there is a better way to do it?

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is what we wanted and the right way to do it. We might not want to use the internal class -- but that situation is a bit confusing and can addressed outside of this pr.

_metafile: Path = field(default_factory=Path, init=False)

def __post_init__(self) -> None:
Expand Down Expand Up @@ -194,7 +196,7 @@ def _construct_filename(self, template: dict) -> tuple[Path, Path | None]:

return relname, absname

def _generate_aggrd_metadata(
def _set_metadata(
self,
obj: types.Inferrable,
real_ids: list[int],
Expand All @@ -216,8 +218,8 @@ def _generate_aggrd_metadata(
if not self.operation:
raise ValueError("The 'operation' key has no value")

# use first as template but filter away invalid entries first:
template = _utils.filter_validate_metadata(self.source_metadata[0])
# use first as template
template = copy.deepcopy(self.source_metadata[0])
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

As far as I could see from the doc, Pydantic should default to "Ignore" on unknown fields so it should be okay to remove this filtering method. Is this correct?

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes, plus we expect that incoming metadata is already valid and will be further validated once we put it into the Pydantic model


relpath, abspath = self._construct_filename(template)

Expand Down Expand Up @@ -251,25 +253,32 @@ def _generate_aggrd_metadata(

objdata = objectdata_provider_factory(obj=obj, dataio=etemp)

template["tracklog"] = [generate_meta_tracklog()[0].model_dump(mode="json")]
template["tracklog"] = [generate_meta_tracklog()[0]]
template["file"] = {
"relative_path": str(relpath),
"absolute_path": str(abspath) if abspath else None,
"checksum_md5": (
None
if not compute_md5
else _utils.compute_md5_using_temp_file(obj, objdata.extension)
),
}
if compute_md5:
template["file"]["checksum_md5"] = _utils.compute_md5_using_temp_file(
obj, objdata.extension
)

# data section
if self.name:
template["data"]["name"] = self.name
if self.tagname:
template["data"]["tagname"] = self.tagname
if bbox := objdata.get_bbox():
template["data"]["bbox"] = bbox.model_dump(mode="json", exclude_none=True)
template["data"]["bbox"] = bbox

self._metadata = template
try:
self._metadata = internal.DataClassMeta.model_validate(template)
except ValidationError as err:
raise InvalidMetadataError(
f"The existing metadata for the aggregated data is invalid. "
f"Detailed information: \n{str(err)}"
) from err

# ==================================================================================
# Public methods:
Expand All @@ -296,13 +305,20 @@ def generate_metadata(
a temporary export of the data, and may be time consuming for large
data.

skip_null: If True (default), None values in putput will be skipped
skip_null: This input parameter has been deprecated. If set to False,
a deprecation warning will be raised.
**kwargs: See AggregatedData() arguments; initial will be overridden by
Comment on lines +309 to 310
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We would eventually like to get away from this pattern too, where a class instance has its state updated so it can be reused. **kwargs is present in a few other places as well.

settings here.
"""
logger.info("Generate metadata for class")
self._update_settings(kwargs)

if not skip_null:
warnings.warn(
"The input parameter 'skip_null' has been deprecated. "
"Setting this to False will not have any effect."
)

# get input realization numbers:
real_ids = []
uuids = []
Expand All @@ -317,13 +333,10 @@ def generate_metadata(
uuids.append(xuuid)

# first config file as template
self._generate_aggrd_metadata(obj, real_ids, uuids, compute_md5)
if skip_null:
self._metadata = _utils.drop_nones(self._metadata)
self._set_metadata(obj, real_ids, uuids, compute_md5)
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think eventually we should move to call _set_metadata() in __post_init__() but that can be taken in a separate PR.


return copy.deepcopy(self._metadata)
return self._metadata.model_dump(mode="json", exclude_none=True, by_alias=True)

# alias method
def generate_aggregation_metadata(
self,
obj: types.Inferrable,
Expand All @@ -348,7 +361,7 @@ def export(self, obj: types.Inferrable, **kwargs: object) -> str:
"""
self._update_settings(kwargs)

metadata = self.generate_metadata(obj, compute_md5=False)
metadata = self.generate_metadata(obj, compute_md5=True)

abspath = metadata["file"].get("absolute_path", None)

Expand All @@ -362,15 +375,11 @@ def export(self, obj: types.Inferrable, **kwargs: object) -> str:
outfile.parent.mkdir(parents=True, exist_ok=True)
metafile = outfile.parent / ("." + str(outfile.name) + ".yml")

logger.info("Export to file and compute MD5 sum")
# inject the computed md5 checksum in metadata
metadata["file"]["checksum_md5"] = _utils.export_file_compute_checksum_md5(
obj, outfile
)
logger.info("Export to file and export metadata file.")
_utils.export_file(obj, outfile)

_utils.export_metadata_file(metafile, metadata, savefmt=self.meta_format)
logger.info("Actual file is: %s", outfile)
logger.info("Metadata file is: %s", metafile)

self._metadata = metadata
return str(outfile)
57 changes: 38 additions & 19 deletions tests/test_units/test_aggregated_surfaces.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,15 +20,17 @@ def test_regsurf_aggregated(fmurun_w_casemetadata, aggr_surfs_mean):
aggr_mean, metas = aggr_surfs_mean # xtgeo_object, list-of-metadata-dicts
logger.info("Aggr. mean is %s", aggr_mean.values.mean())

aggregation_uuid = str(utils.uuid_from_string("1234"))

aggdata = dataio.AggregatedData(
source_metadata=metas,
operation="mean",
name="myaggrd",
aggregation_id="1234",
aggregation_id=aggregation_uuid,
)
newmeta = aggdata.generate_metadata(aggr_mean)
logger.debug("New metadata:\n%s", utils.prettyprint_dict(newmeta))
assert newmeta["fmu"]["aggregation"]["id"] == "1234"
assert newmeta["fmu"]["aggregation"]["id"] == aggregation_uuid
assert newmeta["fmu"]["context"]["stage"] == "iteration"


Expand All @@ -46,15 +48,17 @@ def test_regsurf_aggregated_content_seismic(
aggr_mean, metas = aggr_sesimic_surfs_mean # xtgeo_object, list-of-metadata-dicts
logger.info("Aggr. mean is %s", aggr_mean.values.mean())

aggregation_uuid = str(utils.uuid_from_string("1234"))

aggdata = dataio.AggregatedData(
source_metadata=metas,
operation="mean",
name="myaggrd",
aggregation_id="1234",
aggregation_id=aggregation_uuid,
)
newmeta = aggdata.generate_metadata(aggr_mean)
logger.debug("New metadata:\n%s", utils.prettyprint_dict(newmeta))
assert newmeta["fmu"]["aggregation"]["id"] == "1234"
assert newmeta["fmu"]["aggregation"]["id"] == aggregation_uuid
assert newmeta["fmu"]["context"]["stage"] == "iteration"


Expand All @@ -71,20 +75,25 @@ def test_regsurf_aggregated_export(fmurun_w_casemetadata, aggr_surfs_mean):
aggr_mean, metas = aggr_surfs_mean # xtgeo_object, list-of-metadata-dicts
logger.info("Aggr. mean is %s", aggr_mean.values.mean())

aggregation_uuid = str(utils.uuid_from_string("1234"))

aggdata = dataio.AggregatedData(
source_metadata=metas,
operation="mean",
name="myaggrd",
tagname="mean",
aggregation_id="1234",
aggregation_id=aggregation_uuid,
)

mypath = aggdata.export(aggr_mean)

logger.info("Relative path: %s", aggdata._metadata["file"]["relative_path"])
logger.info("Absolute path: %s", aggdata._metadata["file"]["absolute_path"])
logger.info("Relative path: %s", aggdata._metadata.file.relative_path)
logger.info("Absolute path: %s", aggdata._metadata.file.absolute_path)
logger.debug(
"Final metadata after export:\n%s", utils.prettyprint_dict(aggdata._metadata)
"Final metadata after export:\n%s",
utils.prettyprint_dict(
aggdata._metadata.model_dump(mode="json", exclude_none=True, by_alias=True)
),
)

assert "iter-0/share/results/maps/myaggrd--mean.gri" in mypath
Expand All @@ -99,12 +108,14 @@ def test_regsurf_aggregated_alt_keys(fmurun_w_casemetadata, aggr_surfs_mean):
aggr_mean, metas = aggr_surfs_mean # xtgeo_object, list-of-metadata-dicts
logger.info("Aggr. mean is %s", aggr_mean.values.mean())

aggregation_uuid = str(utils.uuid_from_string("1234"))

meta1 = dataio.AggregatedData(
source_metadata=metas,
operation="mean",
name="myaggrd",
tagname="mean",
aggregation_id="1234",
aggregation_id=aggregation_uuid,
).generate_metadata(aggr_mean)

# alternative
Expand All @@ -114,7 +125,7 @@ def test_regsurf_aggregated_alt_keys(fmurun_w_casemetadata, aggr_surfs_mean):
operation="mean",
name="myaggrd",
tagname="mean",
aggregation_id="1234",
aggregation_id=aggregation_uuid,
)

# alternative with export
Expand All @@ -125,9 +136,9 @@ def test_regsurf_aggregated_alt_keys(fmurun_w_casemetadata, aggr_surfs_mean):
operation="mean",
name="myaggrd",
tagname="mean",
aggregation_id="1234",
aggregation_id=aggregation_uuid,
)
meta3 = aggdata3._metadata
meta3 = aggdata3._metadata.model_dump(mode="json", exclude_none=True, by_alias=True)

del meta1["tracklog"]
del meta2["tracklog"]
Expand All @@ -152,13 +163,15 @@ def test_regsurf_aggr_export_give_casepath(fmurun_w_casemetadata, aggr_surfs_mea
aggr_mean, metas = aggr_surfs_mean # xtgeo_object, list-of-metadata-dicts
logger.info("Aggr. mean is %s", aggr_mean.values.mean())

aggregation_uuid = str(utils.uuid_from_string("1234abcd"))

aggdata = dataio.AggregatedData(
source_metadata=metas,
casepath=casepath,
operation="mean",
name="myaggrd",
tagname="mean",
aggregation_id="1234abcd",
aggregation_id=aggregation_uuid,
)

mypath = aggdata.export(aggr_mean)
Expand Down Expand Up @@ -210,12 +223,14 @@ def test_regsurf_aggr_export_abspath_none(fmurun_w_casemetadata, aggr_surfs_mean
# manipulate first metadata record so mimic abspath is None
metas[0]["file"]["absolute_path"] = None

aggregation_uuid = str(utils.uuid_from_string("1234abcd"))

aggdata = dataio.AggregatedData(
source_metadata=metas,
operation="mean",
name="myaggrd",
tagname="mean",
aggregation_id="1234abcd",
aggregation_id=aggregation_uuid,
)

newmeta = aggdata.generate_metadata(aggr_mean)
Expand Down Expand Up @@ -257,18 +272,20 @@ def test_regsurf_aggregated_aggregation_id(fmurun_w_casemetadata, aggr_surfs_mea
assert "id" in newmeta["fmu"]["aggregation"]
assert newmeta["fmu"]["aggregation"]["id"] != "1234" # shall be uuid

aggregation_uuid = str(utils.uuid_from_string("1234"))

# let aggregation_id argument be used as aggregation_id
aggdata = dataio.AggregatedData(
source_metadata=metas,
operation="mean",
name="myaggrd2",
aggregation_id="1234",
aggregation_id=aggregation_uuid,
)
newmeta = aggdata.generate_metadata(aggr_mean)
logger.debug("New metadata:\n%s", utils.prettyprint_dict(newmeta))
assert newmeta["fmu"]["aggregation"]["id"] == "1234"
assert newmeta["fmu"]["aggregation"]["id"] == aggregation_uuid

# Raise when given aggregation_id is not a string 1
# Raise when given aggregation_id is not a uuid string
with pytest.raises(ValueError):
aggdata = dataio.AggregatedData(
source_metadata=metas,
Expand All @@ -278,7 +295,7 @@ def test_regsurf_aggregated_aggregation_id(fmurun_w_casemetadata, aggr_surfs_mea
)
newmeta = aggdata.generate_metadata(aggr_mean)

# Raise when given aggregation_id is not a string 2
# Raise when given aggregation_id is not a uuid string 2
with pytest.raises(ValueError):
aggdata = dataio.AggregatedData(
source_metadata=metas,
Expand Down Expand Up @@ -351,11 +368,13 @@ def test_regsurf_aggregated_diffdata(fmurun_w_casemetadata, rmsglobalconfig, reg
aggregated = surfs.statistics()
logger.info("Aggr. mean is %s", aggregated["mean"].values.mean()) # shall be 1238.5

aggregation_uuid = str(utils.uuid_from_string("789politipoliti"))

aggdata = dataio.AggregatedData(
source_metadata=metas,
operation="mean",
name="myaggrd",
aggregation_id="789politipoliti",
aggregation_id=aggregation_uuid,
)
newmeta = aggdata.generate_metadata(aggregated["mean"])
logger.info("New metadata:\n%s", utils.prettyprint_dict(newmeta))