Skip to content

Commit

Permalink
Update dll utils to avoid segementation faults
Browse files Browse the repository at this point in the history
  • Loading branch information
schmoelder committed Dec 10, 2024
1 parent 5c6951a commit a56847d
Showing 1 changed file with 148 additions and 102 deletions.
250 changes: 148 additions & 102 deletions cadet/cadet_dll_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@ def null(*args: Any) -> None:

# %% Single entries


def param_provider_get_double(
reader: Any,
name: ctypes.c_char_p,
Expand Down Expand Up @@ -41,15 +42,22 @@ def param_provider_get_double(

if n in c:
o = c[n]

# Safely handle scalars and lists
try:
float_val = float(o)
except TypeError:
float_val = float(o[0])
if isinstance(o, (list, np.ndarray)):
float_val = float(o[0]) # Use the first element for arrays/lists
else:
float_val = float(o)
except (TypeError, ValueError, IndexError) as e:
log_print(f"Error converting {n} to double: {e}")
return -1

val[0] = ctypes.c_double(float_val)
log_print(f"GET scalar [double] {n}: {float(val[0])}")
log_print(f"GET scalar [double] {n}: {float_val}")
return 0

log_print(f"Parameter {n} not found.")
return -1


Expand Down Expand Up @@ -80,15 +88,22 @@ def param_provider_get_int(

if n in c:
o = c[n]

# Safely handle scalars and lists
try:
int_val = int(o)
except TypeError:
int_val = int(o[0])
if isinstance(o, (list, np.ndarray)):
int_val = np.int32(o[0]) # Use the first element for arrays/lists
else:
int_val = np.int32(o)
except (TypeError, ValueError, IndexError) as e:
log_print(f"Error converting {n} to int: {e}")
return -1

val[0] = ctypes.c_int(int_val)
log_print(f"GET scalar [int] {n}: {int(val[0])}")
log_print(f"GET scalar [int] {n}: {int_val}")
return 0

log_print(f"Parameter {n} not found.")
return -1


Expand Down Expand Up @@ -119,15 +134,22 @@ def param_provider_get_bool(

if n in c:
o = c[n]

# Safely handle scalars and lists
try:
int_val = int(o)
except TypeError:
int_val = int(o[0])
if isinstance(o, (list, np.ndarray)):
bool_val = bool(o[0]) # Use the first element for arrays/lists
else:
bool_val = bool(o)
except (TypeError, ValueError, IndexError) as e:
log_print(f"Error converting {n} to bool: {e}")
return -1

val[0] = ctypes.c_uint8(int_val)
val[0] = ctypes.c_uint8(bool_val)
log_print(f"GET scalar [bool] {n}: {bool(val[0])}")
return 0

log_print(f"Parameter {n} not found.")
return -1


Expand Down Expand Up @@ -159,22 +181,32 @@ def param_provider_get_string(
if n in c:
o = c[n]

if hasattr(o, 'encode'):
bytes_val = o.encode('utf-8')
elif hasattr(o, 'decode'):
bytes_val = o
elif hasattr(o[0], 'encode'):
bytes_val = o[0].encode('utf-8')
elif hasattr(o[0], 'decode'):
bytes_val = o[0]
# Safely handle string conversions
try:
if isinstance(o, str):
bytes_val = o.encode('utf-8')
elif isinstance(o, bytes):
bytes_val = o
elif isinstance(o, (list, np.ndarray)) and isinstance(o[0], str):
bytes_val = o[0].encode('utf-8') # Use the first element
elif isinstance(o, (list, np.ndarray)) and isinstance(o[0], bytes):
bytes_val = o[0]
else:
log_print(f"Error: Unsupported type for parameter {n}.")
return -1
except (TypeError, ValueError, IndexError) as e:
log_print(f"Error converting {n} to string: {e}")
return -1

# Store in reader's buffer
reader.buffer = bytes_val
val[0] = ctypes.cast(reader.buffer, ctypes.c_char_p)
log_print(f"GET scalar [string] {n}: {reader.buffer.decode('utf-8')}")
return 0

log_print(f"Parameter {n} not found.")
return -1


# %% Arrays

def param_provider_get_double_array(
Expand Down Expand Up @@ -207,16 +239,25 @@ def param_provider_get_double_array(

if n in c:
o = c[n]
if isinstance(o, list):
o = np.ascontiguousarray(o)

# Ensure object is a properly aligned numpy array
if isinstance(o, list): # Convert lists to numpy arrays
o = np.array(o, dtype=np.double)
c[n] = o # Update the reader's storage

# Validate the array
if not isinstance(o, np.ndarray) or o.dtype != np.double or not o.flags.c_contiguous:
log_print(f"Error: Parameter {n} is not a contiguous double array.")
return -1

# Provide array data to the caller
n_elem[0] = ctypes.c_int(o.size)
val[0] = o.ctypes.data_as(ctypes.POINTER(ctypes.c_double))
val[0] = np.ctypeslib.as_ctypes(o)

log_print(f"GET array [double] {n}: {o}")
return 0

log_print(f"Parameter {n} not found.")
return -1


Expand Down Expand Up @@ -250,9 +291,11 @@ def param_provider_get_int_array(

if n in c:
o = c[n]
if isinstance(o, list):
o = np.ascontiguousarray(o)
if not isinstance(o, np.ndarray) or o.dtype != int or not o.flags.c_contiguous:
if isinstance(o, list): # Convert lists to numpy arrays
o = np.array(o, dtype=np.int32)
c[n] = o # Update the reader's storage
if not isinstance(o, np.ndarray) or o.dtype != np.int32 or not o.flags.c_contiguous:
log_print(f"Error: Parameter {n} is not a contiguous int array.")
return -1

n_elem[0] = ctypes.c_int(o.size)
Expand Down Expand Up @@ -295,15 +338,29 @@ def param_provider_get_double_array_item(
if n in c:
o = c[n]

# Ensure the object is a numpy array
if isinstance(o, list):
o = np.array(o, dtype=np.double, copy=True)
c[n] = o # Update the reader's storage for consistency
elif isinstance(o, np.ndarray):
o = np.ascontiguousarray(o, dtype=np.double) # Ensure it is contiguous and has correct dtype

if not (isinstance(o, np.ndarray) and o.dtype == np.double and o.flags.c_contiguous):
log_print(f"Error: Parameter {n} is not a valid double array.")
return -1

# Safely retrieve the indexed item
try:
float_val = float(o)
except TypeError:
float_val = float(o[index])
except (IndexError, ValueError, TypeError) as e:
log_print(f"Error accessing index {index} for parameter {n}: {e}")
return -1

val[0] = ctypes.c_double(float_val)
log_print(f"GET array [double] ({index}) {n}: {val[0]}")
log_print(f"GET array [double] ({index}) {n}: {float_val}")
return 0

log_print(f"Parameter {n} not found.")
return -1


Expand All @@ -315,38 +372,36 @@ def param_provider_get_int_array_item(
) -> int:
"""
Retrieve an item from an integer array in the reader based on the provided name and index.
Parameters
----------
reader : Any
The reader object containing the current data scope.
name : ctypes.c_char_p
The name of the parameter to retrieve.
index : int
The index of the array item to retrieve.
val : ctypes.POINTER(ctypes.c_int)
A pointer to store the retrieved integer value.
Returns
-------
int
0 if the value was found and retrieved successfully, -1 otherwise.
"""
n = name.decode('utf-8')
c = reader.current()

if n in c:
o = c[n]

# Ensure the object is a numpy array
if isinstance(o, list):
o = np.array(o, dtype=np.int32, copy=True)
c[n] = o # Update the reader's storage for consistency
elif isinstance(o, np.ndarray):
o = np.ascontiguousarray(o, dtype=np.int32)

if not (isinstance(o, np.ndarray) and o.dtype == np.int32 and o.flags.c_contiguous):
log_print(f"Error: Parameter {n} is not a valid int32 array.")
return -1

# Retrieve the indexed item
try:
int_val = int(o)
except TypeError:
int_val = int(o[index])
except (IndexError, ValueError, TypeError) as e:
log_print(f"Error accessing index {index} for parameter {n}: {e}")
return -1

val[0] = ctypes.c_int(int_val)
log_print(f"GET array [int] ({index}) {n}: {val[0]}")
log_print(f"GET array [int] ({index}) {n}: {int_val}")
return 0

log_print(f"Parameter {n} not found.")
return -1


Expand All @@ -358,38 +413,36 @@ def param_provider_get_bool_array_item(
) -> int:
"""
Retrieve an item from a boolean array in the reader based on the provided name and index.
Parameters
----------
reader : Any
The reader object containing the current data scope.
name : ctypes.c_char_p
The name of the parameter to retrieve.
index : int
The index of the array item to retrieve.
val : ctypes.POINTER(ctypes.c_uint8)
A pointer to store the retrieved boolean value.
Returns
-------
int
0 if the value was found and retrieved successfully, -1 otherwise.
"""
n = name.decode('utf-8')
c = reader.current()

if n in c:
o = c[n]

# Ensure the object is a numpy array
if isinstance(o, list):
o = np.array(o, dtype=np.bool_, copy=True)
c[n] = o # Update the reader's storage for consistency
elif isinstance(o, np.ndarray):
o = np.ascontiguousarray(o, dtype=np.bool_)

if not (isinstance(o, np.ndarray) and o.dtype == np.bool_ and o.flags.c_contiguous):
log_print(f"Error: Parameter {n} is not a valid bool array.")
return -1

# Retrieve the indexed item
try:
int_val = int(o)
except TypeError:
int_val = int(o[index])
bool_val = bool(o[index])
except (IndexError, ValueError, TypeError) as e:
log_print(f"Error accessing index {index} for parameter {n}: {e}")
return -1

val[0] = ctypes.c_uint8(int_val)
val[0] = ctypes.c_uint8(bool_val)
log_print(f"GET array [bool] ({index}) {n}: {bool(val[0])}")
return 0

log_print(f"Parameter {n} not found.")
return -1


Expand All @@ -401,45 +454,38 @@ def param_provider_get_string_array_item(
) -> int:
"""
Retrieve an item from a string array in the reader based on the provided name and index.
"""
n = name.decode('utf-8')
c = reader.current()

Parameters
----------
reader : Any
The reader object containing the current data scope.
name : ctypes.c_char_p
The name of the parameter to retrieve.
index : int
The index of the array item to retrieve.
val : ctypes.POINTER(ctypes.c_char_p)
A pointer to store the retrieved string value.
if n in c:
o = c[n]

Returns
-------
int
0 if the value was found and retrieved successfully, -1 otherwise.
"""
name_str = name.decode('utf-8')
current_reader = reader.current()

if name_str in current_reader:
str_value = current_reader[name_str]
if isinstance(str_value, bytes):
bytes_val = str_value
elif isinstance(str_value, str):
bytes_val = str_value.encode('utf-8')
elif isinstance(str_value, np.ndarray):
bytes_val = str_value[index]
else:
raise TypeError(
"Unexpected type for str_value. "
"Must be of type bytes, str, or np.ndarray."
)
# Ensure the object is a numpy array
if isinstance(o, list):
o = np.array(o, dtype=np.str_, copy=True)
c[n] = o # Update the reader's storage for consistency
elif isinstance(o, np.ndarray):
o = np.ascontiguousarray(o, dtype=np.str_)

reader.buffer = bytes_val
if not (isinstance(o, np.ndarray) and o.dtype.kind == 'U' and o.flags.c_contiguous):
log_print(f"Error: Parameter {n} is not a valid string array.")
return -1

# Retrieve the indexed item
try:
string_val = o[index]
except (IndexError, ValueError, TypeError) as e:
log_print(f"Error accessing index {index} for parameter {n}: {e}")
return -1

# Encode to UTF-8 and store in reader.buffer
reader.buffer = string_val.encode('utf-8')
val[0] = ctypes.cast(reader.buffer, ctypes.c_char_p)
log_print(f"GET array [string] ({index}) {name_str}: {reader.buffer.decode('utf-8')}")
log_print(f"GET array [string] ({index}) {n}: {reader.buffer.decode('utf-8')}")
return 0

log_print(f"Parameter {n} not found.")
return -1


Expand Down

0 comments on commit a56847d

Please sign in to comment.