From 1d7dc6b7a281a4af34ed19bf00e805305135ca11 Mon Sep 17 00:00:00 2001 From: weilycoder Date: Wed, 27 Nov 2024 07:34:42 +0800 Subject: [PATCH] Fix forest generation --- cyaron/graph.py | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/cyaron/graph.py b/cyaron/graph.py index 565ac41..3775b47 100644 --- a/cyaron/graph.py +++ b/cyaron/graph.py @@ -545,9 +545,10 @@ def forest(point_count, tree_count, **kwargs): """ if tree_count <= 0 or tree_count > point_count: raise ValueError("tree_count must be between 1 and point_count") - tree = list(Graph.tree(point_count, **kwargs).iterate_edges()) - result = Graph(point_count, 0) - need_add = random.sample(tree, len(tree) - tree_count + 1) + tree = Graph.tree(point_count, **kwargs) + tree_edges = list(tree.iterate_edges()) + result = Graph(point_count, tree.directed) + need_add = random.sample(tree_edges, len(tree_edges) - tree_count + 1) for edge in need_add: result.add_edge(edge.start, edge.end, weight=edge.weight) return result