Skip to content

Commit

Permalink
fixes
Browse files Browse the repository at this point in the history
  • Loading branch information
grzanka committed Mar 22, 2024
1 parent 51428b6 commit bce1299
Show file tree
Hide file tree
Showing 3 changed files with 48 additions and 16 deletions.
18 changes: 17 additions & 1 deletion pymchelper/averaging.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
- WeightedStatsAggregator: for calculating weighted mean and variance
- ConcatenatingAggregator: for concatenating data
- SumAggregator: for calculating sum instead of variance
- NoAggregator: for cases when no aggregation is required
All aggregators have `data` and `error` property, which can be used to obtain the result of the aggregation.
The `data` property returns the result of the aggregation: mean, sum or concatenated array.
Expand All @@ -28,6 +29,7 @@
"""

from dataclasses import dataclass, field
import logging
from typing import Union, Optional
import numpy as np
from numpy.typing import ArrayLike
Expand Down Expand Up @@ -143,8 +145,22 @@ class SumAggregator(Aggregator):

def update(self, value: Union[float, ArrayLike]):
# first value added
if np.isnan(self.total):
if np.isnan(self.data):
self.data = value
# subsequent values added
else:
self.data += value


@dataclass
class NoAggregator(Aggregator):
"""
Class for cases when no aggregation is required.
Sets the data to the first value and does not update it.
"""

def update(self, value: Union[float, ArrayLike]):
# set value only on first update
if np.isnan(self.data):
logging.debug("Setting data to %s", value)
self.data = value
35 changes: 25 additions & 10 deletions pymchelper/input_output.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
from enum import IntEnum
import logging
import os
from collections import defaultdict
Expand All @@ -7,7 +8,7 @@

import numpy as np

from pymchelper.averaging import Aggregator, SumAggregator, WeightedStatsAggregator, ConcatenatingAggregator
from pymchelper.averaging import Aggregator, SumAggregator, WeightedStatsAggregator, ConcatenatingAggregator, NoAggregator
from pymchelper.estimator import ErrorEstimate, Estimator, average_with_nan
from pymchelper.readers.topas import TopasReaderFactory
from pymchelper.readers.fluka import FlukaReader, FlukaReaderFactory
Expand All @@ -18,6 +19,26 @@
logger = logging.getLogger(__name__)


class AggregationType(IntEnum):
"""
Enum for different types of aggregation.
"""
NoAggregation = 0
Sum = 1
AveragingCumulative = 2
AveragingPerPrimary = 3
Concatenation = 4


_aggregator_mapping: dict[AggregationType, Aggregator] = {
AggregationType.NoAggregation: NoAggregator,
AggregationType.Sum: SumAggregator,
AggregationType.AveragingCumulative: WeightedStatsAggregator,
AggregationType.AveragingPerPrimary: WeightedStatsAggregator,
AggregationType.Concatenation: ConcatenatingAggregator
}


def guess_reader(filename):
"""
Guess a reader based on file contents or extensions.
Expand Down Expand Up @@ -67,14 +88,6 @@ def fromfile(filename: str) -> Optional[Estimator]:
return estimator


aggregator_type: dict[int, Aggregator] = {
1: SumAggregator,
2: WeightedStatsAggregator,
3: WeightedStatsAggregator,
4: ConcatenatingAggregator
}


def fromfilelist(input_file_list, error: ErrorEstimate = ErrorEstimate.stderr, nan: bool = True) -> Optional[Estimator]:
"""
Reads all files from a given list, and returns a list of averaged estimators.
Expand Down Expand Up @@ -103,7 +116,8 @@ def fromfilelist(input_file_list, error: ErrorEstimate = ErrorEstimate.stderr, n
for page in result.pages:
# got a page with "concatenate normalisation"
current_page_normalisation = getattr(page, 'page_normalized', 2)
aggregator = aggregator_type[current_page_normalisation]()
aggregator = _aggregator_mapping.get(current_page_normalisation, WeightedStatsAggregator)()
logger.info("Selected aggregator %s for page %s", aggregator, page.name)
aggregator.update(page.data_raw)
page_aggregators.append(aggregator)

Expand All @@ -113,6 +127,7 @@ def fromfilelist(input_file_list, error: ErrorEstimate = ErrorEstimate.stderr, n
aggregator.update(current_page.data_raw)

for page, aggregator in zip(result.pages, page_aggregators):
logger.info("Extracting data from aggregator %s for page %s", aggregator, page.name)
page.data_raw = aggregator.data
page.error_raw = aggregator.error()

Expand Down
11 changes: 6 additions & 5 deletions pymchelper/writers/inspector.py
Original file line number Diff line number Diff line change
@@ -1,22 +1,24 @@
import logging
from pymchelper.estimator import Estimator

logger = logging.getLogger(__name__)


class Inspector:

def __init__(self, filename, options):
logger.debug("Initialising Inspector writer")
self.options = options

def write(self, estimator):
def write(self, estimator: Estimator):
"""Print all keys and values from estimator structure
they include also a metadata read from binary output file
"""
for name, value in sorted(estimator.__dict__.items()):
# skip non-metadata fields
if name not in {'data', 'data_raw', 'error', 'error_raw', 'counter', 'pages'}:
line = "{:24s}: '{:s}'".format(str(name), str(value))
line = f"{name:24s}: {value}"
print(line)
# print some data-related statistics
print(75 * "*")
Expand All @@ -26,10 +28,9 @@ def write(self, estimator):
for name, value in sorted(page.__dict__.items()):
# skip non-metadata fields
if name not in {'data', 'data_raw', 'error', 'error_raw'}:
line = "\t{:24s}: '{:s}'".format(str(name), str(value))
line = f"\t{name:24s}: {value}"
print(line)
print("Data min: {:g}, max: {:g}, mean: {:g}".format(
page.data_raw.min(), page.data_raw.max(), page.data_raw.mean()))
print(f"Data min: {page.data_raw.min():g}, max: {page.data_raw.max():g}, mean: {page.data_raw.mean():g}")
print(75 * "-")

if self.options.details:
Expand Down

0 comments on commit bce1299

Please sign in to comment.