Commit 4d51751: style
alexbarghi-nv committed Sep 22, 2023
1 parent 72bebc2
Showing 4 changed files with 37 additions and 36 deletions.
@@ -295,7 +295,7 @@ def _mg_call_plc_uniform_neighbor_sample(
renumber=renumber,
use_legacy_names=use_legacy_names,
compression=compression,
- include_hop_column=include_hop_column
+ include_hop_column=include_hop_column,
)
if with_edge_properties
else create_empty_df(indices_t, weight_t)
@@ -522,10 +522,8 @@ def uniform_neighbor_sample(
Contains the batch offsets for the renumber maps
"""

- if compression not in ['COO', 'CSR', 'CSC', 'DCSR', 'DCSC']:
-     raise ValueError(
-         "compression must be one of COO, CSR, CSC, DCSR, or DCSC"
-     )
+ if compression not in ["COO", "CSR", "CSC", "DCSR", "DCSC"]:
+     raise ValueError("compression must be one of COO, CSR, CSC, DCSR, or DCSC")

if with_edge_properties:
warning_msg = (
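Note: the reflowed guard above only changes formatting; the validation logic is unchanged. A minimal standalone sketch of the same membership check, using an illustrative helper name that is not part of the cugraph API:

    # Illustrative sketch only; validate_compression is a hypothetical helper,
    # not a cugraph function. It mirrors the guard in uniform_neighbor_sample.
    VALID_COMPRESSION = ("COO", "CSR", "CSC", "DCSR", "DCSC")

    def validate_compression(compression: str) -> str:
        if compression not in VALID_COMPRESSION:
            raise ValueError("compression must be one of COO, CSR, CSC, DCSR, or DCSC")
        return compression

    validate_compression("CSR")     # passes
    # validate_compression("CSRC")  # would raise ValueError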
6 changes: 2 additions & 4 deletions python/cugraph/cugraph/sampling/uniform_neighbor_sample.py
@@ -236,10 +236,8 @@ def uniform_neighbor_sample(
major_col_name = "majors"
minor_col_name = "minors"

- if compression not in ['COO', 'CSR', 'CSC', 'DCSR', 'DCSC']:
-     raise ValueError(
-         "compression must be one of COO, CSR, CSC, DCSR, or DCSC"
-     )
+ if compression not in ["COO", "CSR", "CSC", "DCSR", "DCSC"]:
+     raise ValueError("compression must be one of COO, CSR, CSC, DCSR, or DCSC")

if (
(compression != "COO")
24 changes: 14 additions & 10 deletions python/cugraph/cugraph/tests/sampling/test_bulk_sampler.py
@@ -304,7 +304,7 @@ def test_bulk_sampler_csr(scratch_dir):
el = email_Eu_core.get_edgelist()

G = cugraph.Graph(directed=True)
- G.from_cudf_edgelist(el, source='src', destination='dst')
+ G.from_cudf_edgelist(el, source="src", destination="dst")

samples_path = os.path.join(scratch_dir, "test_bulk_sampler_csr")
create_directory_with_overwrite(samples_path)
@@ -318,21 +318,25 @@ def test_bulk_sampler_csr(scratch_dir):
batches_per_partition=7,
renumber=True,
use_legacy_names=False,
- compression='CSR',
+ compression="CSR",
compress_per_hop=False,
- prior_sources_behavior='exclude',
- include_hop_column=False
+ prior_sources_behavior="exclude",
+ include_hop_column=False,
)

seeds = G.select_random_vertices(62, 1000)
- batch_ids = cudf.Series(cupy.repeat(cupy.arange(int(1000/7)+1,dtype='int32'), 7)[:1000]).sort_values()
+ batch_ids = cudf.Series(
+     cupy.repeat(cupy.arange(int(1000 / 7) + 1, dtype="int32"), 7)[:1000]
+ ).sort_values()

- batch_df = cudf.DataFrame({
-     'seed': seeds,
-     'batch': batch_ids,
- })
+ batch_df = cudf.DataFrame(
+     {
+         "seed": seeds,
+         "batch": batch_ids,
+     }
+ )

- bs.add_batches(batch_df, start_col_name='seed', batch_col_name='batch')
+ bs.add_batches(batch_df, start_col_name="seed", batch_col_name="batch")
bs.flush()

assert len(os.listdir(samples_path)) == 21
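Note: the reformatted test above changes only quoting and line wrapping; the arithmetic is unchanged. The 1000 seeds are grouped into batches of 7, and with batches_per_partition=7 the assert expects 21 partition files. A minimal NumPy sketch of that bookkeeping (illustrative only; the test itself uses CuPy and cudf):

    # Illustrative sketch of the batch/partition arithmetic exercised by the test.
    import math
    import numpy as np

    n_seeds = 1000
    seeds_per_batch = 7
    batches_per_partition = 7

    # Same construction as the test: batch ids 0..142, each repeated 7 times,
    # truncated to the 1000 seeds.
    batch_ids = np.repeat(
        np.arange(n_seeds // seeds_per_batch + 1, dtype="int32"), seeds_per_batch
    )[:n_seeds]

    n_batches = int(batch_ids.max()) + 1                         # 143 batches
    n_partitions = math.ceil(n_batches / batches_per_partition)  # 21 output files
    assert n_partitions == 21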
35 changes: 18 additions & 17 deletions python/cugraph/cugraph/tests/sampling/test_bulk_sampler_mg.py
@@ -251,12 +251,12 @@ def test_bulk_sampler_empty_batches(dask_client, scratch_dir):

@pytest.mark.mg
@pytest.mark.parametrize("mg_input", [True, False])
- def test_bulk_sampler_csr(dask_client,scratch_dir,mg_input):
+ def test_bulk_sampler_csr(dask_client, scratch_dir, mg_input):
nworkers = len(dask_client.scheduler_info()["workers"])
- el = dask_cudf.from_cudf(email_Eu_core.get_edgelist(), npartitions=nworkers*2)
+ el = dask_cudf.from_cudf(email_Eu_core.get_edgelist(), npartitions=nworkers * 2)

G = cugraph.Graph(directed=True)
- G.from_dask_cudf_edgelist(el, source='src', destination='dst')
+ G.from_dask_cudf_edgelist(el, source="src", destination="dst")

samples_path = os.path.join(scratch_dir, "mg_test_bulk_sampler_csr")
create_directory_with_overwrite(samples_path)
@@ -270,29 +270,30 @@ def test_bulk_sampler_csr(dask_client,scratch_dir,mg_input):
batches_per_partition=7,
renumber=True,
use_legacy_names=False,
- compression='CSR',
+ compression="CSR",
compress_per_hop=False,
- prior_sources_behavior='exclude',
- include_hop_column=False
+ prior_sources_behavior="exclude",
+ include_hop_column=False,
)

seeds = G.select_random_vertices(62, 1000)
- batch_ids = cudf.Series(cupy.repeat(cupy.arange(int(1000/7)+1,dtype='int32'), 7)[:1000]).sort_values()
+ batch_ids = cudf.Series(
+     cupy.repeat(cupy.arange(int(1000 / 7) + 1, dtype="int32"), 7)[:1000]
+ ).sort_values()

- batch_df = cudf.DataFrame({
-     'seed': seeds.compute().values,
-     'batch': batch_ids,
- })
+ batch_df = cudf.DataFrame(
+     {
+         "seed": seeds.compute().values,
+         "batch": batch_ids,
+     }
+ )

if mg_input:
-     batch_df = dask_cudf.from_cudf(
-         batch_df,
-         npartitions=2
-     )
+     batch_df = dask_cudf.from_cudf(batch_df, npartitions=2)

- bs.add_batches(batch_df, start_col_name='seed', batch_col_name='batch')
+ bs.add_batches(batch_df, start_col_name="seed", batch_col_name="batch")
bs.flush()

assert len(os.listdir(samples_path)) == 21

- shutil.rmtree(samples_path)
+ shutil.rmtree(samples_path)
