Skip to content

Commit

Permalink
graph operators
Browse files Browse the repository at this point in the history
  • Loading branch information
0x00b1 committed Oct 4, 2024
1 parent 9943577 commit 4ae17dc
Show file tree
Hide file tree
Showing 9 changed files with 32 additions and 6 deletions.
22 changes: 16 additions & 6 deletions src/beignet/_validate_graph_matrix.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,9 +21,8 @@ def validate_graph_matrix(
nan_null=True,
dtype=torch.float64,
):
"""Routine for validation and conversion of csgraph inputs"""
if not (csr_output or dense_output):
raise ValueError("Internal: dense or csr output must be true")
raise ValueError

accept_fv = [null_value_in]

Expand Down Expand Up @@ -54,20 +53,31 @@ def validate_graph_matrix(
graph = masked_tensor_to_graph_matrix(graph)
else:
if dense_output:
graph = tensor_to_masked_graph_matrix(graph, copy=copy_if_dense, null_value=null_value_in, nan_null=nan_null, infinity_null=infinity_null)
graph = tensor_to_masked_graph_matrix(
graph,
copy=copy_if_dense,
null_value=null_value_in,
nan_null=nan_null,
infinity_null=infinity_null,
)

mask = graph.mask

graph = numpy.asarray(graph.data, dtype=dtype)

graph[mask] = null_value_out
else:
graph = tensor_to_graph_matrix(graph, null_value=null_value_in, infinity_is_null_edge=infinity_null, nan_is_null_edge=nan_null)
graph = tensor_to_graph_matrix(
graph,
null_value=null_value_in,
infinity_is_null_edge=infinity_null,
nan_is_null_edge=nan_null,
)

if graph.ndim != 2:
raise ValueError("compressed-sparse graph must be 2-D")
raise ValueError

if graph.shape[0] != graph.shape[1]:
raise ValueError("compressed-sparse graph must be shape (N, N)")
raise ValueError

return graph
2 changes: 2 additions & 0 deletions tests/beignet/test__graph_matrix_to_masked_tensor.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
def test_graph_matrix_to_masked_tensor():
assert False
2 changes: 2 additions & 0 deletions tests/beignet/test__graph_matrix_to_tensor.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
def test_graph_matrix_to_tensor():
assert False
2 changes: 2 additions & 0 deletions tests/beignet/test__masked_tensor_to_graph_matrix.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
def test_masked_tensor_to_graph_matrix():
assert False
2 changes: 2 additions & 0 deletions tests/beignet/test__predecessor_matrix_to_distance_matrix.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
def test_predecessor_matrix_to_distance_matrix():
assert False
2 changes: 2 additions & 0 deletions tests/beignet/test__reconstruct_path.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
def test_reconstruct_path():
assert False
2 changes: 2 additions & 0 deletions tests/beignet/test__tensor_to_graph_matrix.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
def test_tensor_to_graph_matrix():
assert False
2 changes: 2 additions & 0 deletions tests/beignet/test__tensor_to_masked_graph_matrix.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
def test_tensor_to_masked_graph_matrix():
assert False
2 changes: 2 additions & 0 deletions tests/beignet/test__validate_graph_matrix.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
def test_validate_graph_matrix():
assert False

0 comments on commit 4ae17dc

Please sign in to comment.