diff --git a/cyaron/graph.py b/cyaron/graph.py index 565ac41..3775b47 100644 --- a/cyaron/graph.py +++ b/cyaron/graph.py @@ -545,9 +545,10 @@ def forest(point_count, tree_count, **kwargs): """ if tree_count <= 0 or tree_count > point_count: raise ValueError("tree_count must be between 1 and point_count") - tree = list(Graph.tree(point_count, **kwargs).iterate_edges()) - result = Graph(point_count, 0) - need_add = random.sample(tree, len(tree) - tree_count + 1) + tree = Graph.tree(point_count, **kwargs) + tree_edges = list(tree.iterate_edges()) + result = Graph(point_count, tree.directed) + need_add = random.sample(tree_edges, len(tree_edges) - tree_count + 1) for edge in need_add: result.add_edge(edge.start, edge.end, weight=edge.weight) return result