Skip to content

Commit

Permalink
Remove warning on grogu use
Browse files Browse the repository at this point in the history
  • Loading branch information
Alexander Neuwirth committed Nov 5, 2024
1 parent 0abfb4c commit d2b96b7
Show file tree
Hide file tree
Showing 21 changed files with 322 additions and 204 deletions.
2 changes: 1 addition & 1 deletion .pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ repos:
- id: trailing-whitespace

- repo: https://github.com/astral-sh/ruff-pre-commit
rev: "v0.6.8"
rev: "v0.7.2"
hooks:
- id: ruff
args: ["--fix", "--show-fixes"]
Expand Down
12 changes: 10 additions & 2 deletions src/babyyoda/analysisobject.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,15 @@
class UHIAnalysisObject:
def key(self):
def key(self) -> str:
return self.path()

def setAnnotationsDict(self, d: dict):
def path(self) -> str:
err = "UHIAnalysisObject.path() must be implemented by subclass"
raise NotImplementedError(err)

def setAnnotationsDict(self, d: dict[str, str]) -> None:
for k, v in d.items():
self.setAnnotation(k, v)

def setAnnotation(self, key: str, value: str) -> None:
err = "UHIAnalysisObject.setAnnotation() must be implemented by subclass"
raise NotImplementedError(err)
39 changes: 31 additions & 8 deletions src/babyyoda/counter.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,10 @@
import contextlib
from typing import Any, Optional

from babyyoda.analysisobject import UHIAnalysisObject


def set_bin0d(target, source):
def set_bin0d(target: Any, source: Any) -> None:
if hasattr(target, "set"):
target.set(
source.numEntries(),
Expand All @@ -15,24 +16,46 @@ def set_bin0d(target, source):
raise NotImplementedError(err)


def Counter(*args, **kwargs):
def Counter(*args, **kwargs) -> "UHICounter":
"""
Automatically select the correct version of the Histo1D class
"""
try:
from babyyoda import yoda

return yoda.Counter(*args, **kwargs)
except ImportError:
import babyyoda.grogu as yoda
return yoda.Counter(*args, **kwargs)
from babyyoda import grogu

return grogu.Counter(*args, **kwargs)


# TODO make this implementation independent (no V2 or V3...)
class UHICounter(UHIAnalysisObject):
######
# Minimum required functions
######

def sumW(self) -> float:
raise NotImplementedError

def sumW2(self) -> float:
raise NotImplementedError

def numEntries(self) -> float:
raise NotImplementedError

def annotationsDict(self) -> dict[str, Optional[str]]:
raise NotImplementedError

def clone(self) -> "UHICounter":
raise NotImplementedError

######
# BACKENDS
######

def to_grogu_v2(self):
def to_grogu_v2(self) -> Any:
from babyyoda.grogu.counter_v2 import GROGU_COUNTER_V2

return GROGU_COUNTER_V2(
Expand All @@ -47,7 +70,7 @@ def to_grogu_v2(self):
],
)

def to_grogu_v3(self):
def to_grogu_v3(self) -> Any:
from babyyoda.grogu.counter_v3 import GROGU_COUNTER_V3

return GROGU_COUNTER_V3(
Expand All @@ -62,11 +85,11 @@ def to_grogu_v3(self):
],
)

def to_yoda_v3(self):
def to_yoda_v3(self) -> Any:
err = "Not implemented yet"
raise NotImplementedError(err)

def to_string(self):
def to_string(self) -> str:
# Now we need to map YODA to grogu and then call to_string
# TODO do we want to hardcode v3 here?
return self.to_grogu_v3().to_string()
Expand Down
20 changes: 10 additions & 10 deletions src/babyyoda/grogu/__init__.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,11 @@
from typing import Any

from babyyoda.grogu.counter_v2 import Counter_v2
from babyyoda.grogu.counter_v3 import Counter_v3
from babyyoda.grogu.counter_v3 import GROGU_COUNTER_V3, Counter_v3
from babyyoda.grogu.histo1d_v2 import Histo1D_v2
from babyyoda.grogu.histo1d_v3 import Histo1D_v3
from babyyoda.grogu.histo1d_v3 import GROGU_HISTO1D_V3, Histo1D_v3
from babyyoda.grogu.histo2d_v2 import Histo2D_v2
from babyyoda.grogu.histo2d_v3 import Histo2D_v3
from babyyoda.grogu.histo2d_v3 import GROGU_HISTO2D_V3, Histo2D_v3

from .read import read
from .write import write
Expand All @@ -20,21 +22,19 @@
]


def Counter(*args, **kwargs):
def Counter(*args: Any, **kwargs: Any) -> GROGU_COUNTER_V3:
return Counter_v3(*args, **kwargs)


def Histo1D(*args, **kwargs):
def Histo1D(*args: Any, **kwargs: Any) -> GROGU_HISTO1D_V3:
return Histo1D_v3(*args, **kwargs)


def Histo2D(
*args,
title=None,
**kwargs,
):
*args: Any,
**kwargs: Any,
) -> GROGU_HISTO2D_V3:
return Histo2D_v3(
*args,
title=title,
**kwargs,
)
36 changes: 18 additions & 18 deletions src/babyyoda/grogu/analysis_object.py
Original file line number Diff line number Diff line change
@@ -1,62 +1,62 @@
import re
from dataclasses import dataclass, field
from typing import Optional


@dataclass
class GROGU_ANALYSIS_OBJECT:
d_annotations: dict = field(default_factory=dict)
# TODO add anotations
d_annotations: dict[str, Optional[str]] = field(default_factory=dict)
d_key: str = ""

def __post_init__(self):
def __post_init__(self) -> None:
if "Path" not in self.d_annotations:
self.d_annotations["Path"] = "/"
if "Title" not in self.d_annotations:
self.d_annotations["Title"] = ""

############################################
# YODA compatibilty code
# YODA compatibility code
############################################

def key(self):
def key(self) -> str:
return self.d_key

def name(self):
def name(self) -> str:
return self.path().split("/")[-1]

def path(self):
def path(self) -> str:
p = self.annotation("Path")
return p if p else "/"

def title(self):
def title(self) -> Optional[str]:
return self.annotation("Title")

def type(self):
def type(self) -> Optional[str]:
return self.annotation("Type")

def annotations(self):
return self.d_annotations.keys()
def annotations(self) -> list[str]:
return list(self.d_annotations.keys())

def annotation(self, k: str, default=None) -> str:
def annotation(self, k: str, default: Optional[str] = None) -> Optional[str]:
return self.d_annotations.get(k, default)

def setAnnotation(self, key: str, value: str):
def setAnnotation(self, key: str, value: str) -> None:
self.d_annotations[key] = value

def clearAnnotations(self):
def clearAnnotations(self) -> None:
self.d_annotations = {}

def hasAnnotation(self, key: str) -> bool:
return key in self.d_annotations

def annotationsDict(self):
def annotationsDict(self) -> dict[str, Optional[str]]:
return self.d_annotations

@classmethod
def from_string(cls, file_content: str) -> "GROGU_ANALYSIS_OBJECT":
lines = file_content.strip().splitlines()
# Extract metadata (path, title)
annotations = {"Path": "/"}
annotations: dict[str, Optional[str]] = {"Path": "/"}
pattern = re.compile(r"(\S+): (.+)")
for line in lines:
pattern_match = pattern.match(line)
Expand All @@ -69,10 +69,10 @@ def from_string(cls, file_content: str) -> "GROGU_ANALYSIS_OBJECT":

return cls(
d_annotations=annotations,
d_key=annotations.get("Path", ""),
d_key=annotations.get("Path", "") or "",
)

def to_string(self):
def to_string(self) -> str:
ret = ""
for k, v in self.d_annotations.items():
val = v
Expand Down
6 changes: 3 additions & 3 deletions src/babyyoda/grogu/counter_v2.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ class Bin:
d_numentries: float = 0.0

########################################################
# YODA compatibilty code
# YODA compatibility code
########################################################

def clone(self):
Expand Down Expand Up @@ -119,13 +119,13 @@ def from_string(cls, string: str) -> "GROGU_COUNTER_V2.Bin":

d_bins: list[Bin] = field(default_factory=list)

def __post_init__(self):
def __post_init__(self) -> None:
GROGU_ANALYSIS_OBJECT.__post_init__(self)
self.setAnnotation("Type", "Counter")
assert len(self.d_bins) == 1

############################################
# YODA compatibilty code
# YODA compatibility code
############################################

def sumW(self):
Expand Down
32 changes: 16 additions & 16 deletions src/babyyoda/grogu/counter_v3.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,12 @@
import re
from dataclasses import dataclass, field
from typing import Union
from typing import Optional, Union

from babyyoda.counter import UHICounter
from babyyoda.grogu.analysis_object import GROGU_ANALYSIS_OBJECT


def Counter_v3(title=None, **kwargs):
def Counter_v3(title: Optional[str] = None, **kwargs) -> "GROGU_COUNTER_V3":
return GROGU_COUNTER_V3(
d_bins=[GROGU_COUNTER_V3.Bin()],
d_annotations={"Title": title} if title else {},
Expand All @@ -23,7 +23,7 @@ class Bin:
d_numentries: float = 0.0

########################################################
# YODA compatibilty code
# YODA compatibility code
########################################################

def clone(self):
Expand Down Expand Up @@ -60,13 +60,13 @@ def set(
self.d_sumw2 = sumW2[0]
self.d_numentries = numEntries

def sumW(self):
def sumW(self) -> float:
return self.d_sumw

def sumW2(self):
def sumW2(self) -> float:
return self.d_sumw2

def variance(self):
def variance(self) -> float:
if self.d_sumw**2 - self.d_sumw2 == 0:
return 0
return abs(
Expand All @@ -75,19 +75,19 @@ def variance(self):
)
# return self.d_sumw2/self.d_numentries - (self.d_sumw/self.d_numentries)**2

def errW(self):
def errW(self) -> float:
return self.d_sumw2**0.5

def stdDev(self):
def stdDev(self) -> float:
return self.variance() ** 0.5

def effNumEntries(self):
def effNumEntries(self) -> float:
return self.sumW() ** 2 / self.sumW2()

def stdErr(self):
def stdErr(self) -> float:
return self.stdDev() / self.effNumEntries() ** 0.5

def numEntries(self):
def numEntries(self) -> float:
return self.d_numentries

def __eq__(self, other):
Expand Down Expand Up @@ -119,22 +119,22 @@ def from_string(cls, string: str) -> "GROGU_COUNTER_V3.Bin":

d_bins: list[Bin] = field(default_factory=list)

def __post_init__(self):
def __post_init__(self) -> None:
GROGU_ANALYSIS_OBJECT.__post_init__(self)
self.setAnnotation("Type", "Counter")
assert len(self.d_bins) == 1

############################################
# YODA compatibilty code
# YODA compatibility code
############################################

def sumW(self):
def sumW(self) -> float:
return self.d_bins[0].sumW()

def sumW2(self):
def sumW2(self) -> float:
return self.d_bins[0].sumW2()

def numEntries(self):
def numEntries(self) -> float:
return self.d_bins[0].numEntries()

def clone(self):
Expand Down
Loading

0 comments on commit d2b96b7

Please sign in to comment.