Skip to content

Commit

Permalink
move finalize from run() to __call__() to match other Stages (#71)
Browse files Browse the repository at this point in the history
* move finalize from run() to __call__() to match other Stages

* fix lenght -> length
  • Loading branch information
eacharles authored Nov 1, 2023
1 parent 025a2e2 commit 578d872
Show file tree
Hide file tree
Showing 2 changed files with 13 additions and 13 deletions.
20 changes: 10 additions & 10 deletions src/rail/core/data.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@ def __init__(self, tag, data=None, path=None, creator=None):
self.fileObj = None
self.groups = None
self.partial = False
self.lenght = None
self.length = None

def open(self, **kwargs):
"""Open and return the associated file
Expand Down Expand Up @@ -90,14 +90,14 @@ def write(self, **kwargs):
def _write(cls, data, path, **kwargs):
raise NotImplementedError("DataHandle._write") #pragma: no cover

def initialize_write(self, data_lenght, **kwargs):
def initialize_write(self, data_length, **kwargs):
"""Initialize file to be written by chunks"""
if self.path is None: #pragma: no cover
raise ValueError("TableHandle.write() called but path has not been specified")
self.groups, self.fileObj = self._initialize_write(self.data, os.path.expandvars(self.path), data_lenght, **kwargs)
self.groups, self.fileObj = self._initialize_write(self.data, os.path.expandvars(self.path), data_length, **kwargs)

@classmethod
def _initialize_write(cls, data, path, data_lenght, **kwargs):
def _initialize_write(cls, data, path, data_length, **kwargs):
raise NotImplementedError("DataHandle._initialize_write") #pragma: no cover

def write_chunk(self, start, end, **kwargs):
Expand Down Expand Up @@ -233,18 +233,18 @@ class Hdf5Handle(TableHandle): # pragma: no cover
suffix = 'hdf5'

@classmethod
def _initialize_write(cls, data, path, data_lenght, **kwargs):
initial_dict = cls._get_allocation_kwds(data, data_lenght)
def _initialize_write(cls, data, path, data_length, **kwargs):
initial_dict = cls._get_allocation_kwds(data, data_length)
comm = kwargs.get('communicator', None)
group, fout = tables_io.io.initializeHdf5WriteSingle(path, groupname=None, comm=comm, **initial_dict)
return group, fout

@classmethod
def _get_allocation_kwds(cls, data, data_lenght):
def _get_allocation_kwds(cls, data, data_length):
keywords = {}
for key, array in data.items():
shape = list(array.shape)
shape[0] = data_lenght
shape[0] = data_length
keywords[key] = (shape, array.dtype)
return keywords

Expand Down Expand Up @@ -293,9 +293,9 @@ def _write(cls, data, path, **kwargs):
return data.write_to(path)

@classmethod
def _initialize_write(cls, data, path, data_lenght, **kwargs):
def _initialize_write(cls, data, path, data_length, **kwargs):
comm = kwargs.get('communicator', None)
return data.initializeHdf5Write(path, data_lenght, comm)
return data.initializeHdf5Write(path, data_length, comm)

@classmethod
def _write_chunk(cls, data, fileObj, groups, start, end, **kwargs):
Expand Down
6 changes: 3 additions & 3 deletions src/rail/core/util_stages.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,6 @@ def run(self):
if self.config.inplace: #pragma: no cover
out_data = data
self.add_data('output', out_data)
self.finalize()

def __repr__(self): # pragma: no cover
printMsg = "Stage that applies remaps the following column names in a pandas DataFrame:\n"
Expand All @@ -57,6 +56,7 @@ def __call__(self, data):
"""
self.set_data('input', data)
self.run()
self.finalize()
return self.get_handle('output')


Expand Down Expand Up @@ -84,7 +84,6 @@ def run(self):
data = self.get_data('input', allow_missing=True)
out_data = data.iloc[self.config.start:self.config.stop]
self.add_data('output', out_data)
self.finalize()

def __repr__(self): # pragma: no cover
printMsg = "Stage that applies remaps the following column names in a pandas DataFrame:\n"
Expand All @@ -106,6 +105,7 @@ def __call__(self, data):
"""
self.set_data('input', data)
self.run()
self.finalize()
return self.get_handle('output')


Expand All @@ -129,7 +129,6 @@ def run(self):
out_fmt = tables_io.types.TABULAR_FORMAT_NAMES[self.config.output_format]
out_data = tables_io.convert(data, out_fmt)
self.add_data('output', out_data)
self.finalize()

def __call__(self, data):
"""Return a converted table
Expand All @@ -146,4 +145,5 @@ def __call__(self, data):
"""
self.set_data('input', data)
self.run()
self.finalize()
return self.get_handle('output')

0 comments on commit 578d872

Please sign in to comment.