Skip to content

Commit

Permalink
Merge pull request #67 from SANDAG/BayDAG_estimation_skim_mem
Browse files Browse the repository at this point in the history
Load skims into shared memory to be accessed by later models
  • Loading branch information
bhargavasana authored Apr 12, 2024
2 parents 2c11873 + 31c3c35 commit f3b9ee8
Show file tree
Hide file tree
Showing 2 changed files with 33 additions and 14 deletions.
3 changes: 3 additions & 0 deletions activitysim/core/mem.py
Original file line number Diff line number Diff line change
Expand Up @@ -296,6 +296,9 @@ def shared_memory_size(data_buffers=None):

shared_size += Dataset.shm.preload_shared_memory_size(data_buffer[11:])
continue
if isinstance(data_buffer, multiprocessing.shared_memory.SharedMemory):
shared_size += data_buffer.size
continue
try:
obj = data_buffer.get_obj()
except Exception:
Expand Down
44 changes: 30 additions & 14 deletions activitysim/core/skim_dict_factory.py
Original file line number Diff line number Diff line change
Expand Up @@ -410,16 +410,23 @@ def allocate_skim_buffer(self, skim_info, shared=False):
)

if shared:
if dtype_name == "float64":
typecode = "d"
elif dtype_name == "float32":
typecode = "f"
else:
raise RuntimeError(
"allocate_skim_buffer unrecognized dtype %s" % dtype_name
)

buffer = multiprocessing.RawArray(typecode, buffer_size)
# if dtype_name == "float64":
# typecode = "d"
# elif dtype_name == "float32":
# typecode = "f"
# else:
# raise RuntimeError(
# "allocate_skim_buffer unrecognized dtype %s" % dtype_name
# )

# buffer = multiprocessing.RawArray(typecode, buffer_size)
shared_mem_name = f"skim_shared_memory__{skim_info.skim_tag}"
try:
buffer = multiprocessing.shared_memory.SharedMemory(name=shared_mem_name)
logger.info(f"skim buffer already allocated in shared memory: {shared_mem_name}, size: {buffer.size}")
except FileNotFoundError:
buffer = multiprocessing.shared_memory.SharedMemory(create=True, size=csz, name=shared_mem_name)
logger.info(f"allocating skim buffer in shared memory: {shared_mem_name}, size: {buffer.size}")
else:
buffer = np.zeros(buffer_size, dtype=dtype)

Expand All @@ -440,10 +447,16 @@ def _skim_data_from_buffer(self, skim_info, skim_buffer):
"""

dtype = np.dtype(skim_info.dtype_name)
assert len(skim_buffer) == util.iprod(skim_info.skim_data_shape)
skim_data = np.frombuffer(skim_buffer, dtype=dtype).reshape(
skim_info.skim_data_shape
)
if isinstance(skim_buffer, multiprocessing.shared_memory.SharedMemory):
assert skim_buffer.size >= util.iprod(skim_info.skim_data_shape) * dtype.itemsize
skim_data = np.frombuffer(skim_buffer.buf, dtype=dtype, count=util.iprod(skim_info.skim_data_shape)).reshape(
skim_info.skim_data_shape
)
else:
assert len(skim_buffer) == util.iprod(skim_info.skim_data_shape)
skim_data = np.frombuffer(skim_buffer, dtype=dtype).reshape(
skim_info.skim_data_shape
)
return skim_data

def load_skims_to_buffer(self, skim_info, skim_buffer):
Expand All @@ -462,6 +475,9 @@ def load_skims_to_buffer(self, skim_info, skim_buffer):
skim_data = self._skim_data_from_buffer(skim_info, skim_buffer)
assert skim_data.shape == skim_info.skim_data_shape

if isinstance(skim_buffer, multiprocessing.shared_memory.SharedMemory) and skim_data.any():
return

if read_cache:
# returns None if cache file not found
cache_data = self._open_existing_readonly_memmap_skim_cache(skim_info)
Expand Down

0 comments on commit f3b9ee8

Please sign in to comment.