diff --git a/src/searching/src/dijkstra.cairo b/src/searching/src/dijkstra.cairo index e01840ce..5604e413 100644 --- a/src/searching/src/dijkstra.cairo +++ b/src/searching/src/dijkstra.cairo @@ -3,15 +3,30 @@ use core::nullable::{FromNullableResult, match_nullable}; #[derive(Copy, Drop)] pub struct Node { - pub source: u32, - pub dest: u32, - pub weight: u128 + source: u32, + dest: u32, + weight: u128 +} + +#[generate_trait] +pub impl NodeGetters of NodeGettersTrait { + fn weight(self: @Node) -> @u128 { + self.weight + } + + fn dest(self: @Node) -> @u32 { + self.dest + } + + fn source(self: @Node) -> @u32 { + self.source + } } /// Graph representation. pub struct Graph { - pub nodes: Array, - pub adj_nodes: Felt252Dict, + pub(crate) nodes: Array, + adj_nodes: Felt252Dict, } /// Graph trait. @@ -22,6 +37,8 @@ pub trait GraphTrait { fn add_edge(ref self: Graph>>, source: u32, dest: u32, weight: u128); /// return shortest path from s fn shortest_path(ref self: Graph>>, source: u32) -> Felt252Dict; + /// return shortest path from s + fn adj_nodes(ref self: Graph>>, source: felt252) -> Nullable>; } impl DestructGraph, +Felt252DictValue> of Destruct> { @@ -66,6 +83,11 @@ impl GraphImpl of GraphTrait { fn shortest_path(ref self: Graph>>, source: u32) -> Felt252Dict { dijkstra(ref self, source) } + + + fn adj_nodes(ref self: Graph>>, source: felt252) -> Nullable> { + self.adj_nodes.get(source) + } } pub fn dijkstra(ref self: Graph>>, source: u32) -> Felt252Dict { diff --git a/src/searching/src/tests/dijkstra_test.cairo b/src/searching/src/tests/dijkstra_test.cairo index a06af61c..3c62d0ce 100644 --- a/src/searching/src/tests/dijkstra_test.cairo +++ b/src/searching/src/tests/dijkstra_test.cairo @@ -1,4 +1,4 @@ -use alexandria_searching::dijkstra::{Graph, Node, GraphTrait}; +use alexandria_searching::dijkstra::{Graph, Node, GraphTrait, NodeGetters}; use core::nullable::{FromNullableResult, match_nullable}; @@ -17,7 +17,7 @@ fn add_edge() { GraphTrait::add_edge(ref graph, 2, 3, 3); assert_eq!(graph.nodes.len(), 6, "wrong node number"); - let val = graph.adj_nodes.get(source.into()); + let val = graph.adj_nodes(source.into()); let span = match match_nullable(val) { FromNullableResult::Null => { panic!("No value found") }, @@ -27,14 +27,14 @@ fn add_edge() { assert_eq!(span.len(), 4, "wrong nb of adj edge for node 0"); let new_node = *span.get(1).unwrap().unbox(); - assert_eq!(new_node.dest, dest + 1, "Wrong dest in adj edge"); - assert_eq!(new_node.weight, weight + 1, "Wrong weight in adj edge"); + assert_eq!(*new_node.dest(), dest + 1, "Wrong dest in adj edge"); + assert_eq!(*new_node.weight(), weight + 1, "Wrong weight in adj edge"); let new_node = *span.get(3).unwrap().unbox(); - assert_eq!(new_node.dest, dest + 3, "Wrong dest in adj edge"); - assert_eq!(new_node.weight, weight + 3, "Wrong weight in adj edge"); + assert_eq!(*new_node.dest(), dest + 3, "Wrong dest in adj edge"); + assert_eq!(*new_node.weight(), weight + 3, "Wrong weight in adj edge"); - let val = graph.adj_nodes.get(2.into()); + let val = graph.adj_nodes(2.into()); let span = match match_nullable(val) { FromNullableResult::Null => { panic!("No value found") },