diff --git a/config.toml b/config.toml index d72a478..318917e 100644 --- a/config.toml +++ b/config.toml @@ -3,3 +3,4 @@ baudrate = 921600 heartbeat_pin = 34 update_pin = 35 heartbeat_freq = 10 # Hz +socket = "/tmp/scheduler_socket" diff --git a/examples/cli.rs b/examples/cli.rs index 6c6c3f5..fc3fd8b 100644 --- a/examples/cli.rs +++ b/examples/cli.rs @@ -1,5 +1,9 @@ use std::{ - error::Error, io::{Read, Write}, path::Path, process::{Child, ChildStdin, ChildStdout, Stdio}, time::Duration + error::Error, + io::{Read, Write}, + path::Path, + process::{Child, ChildStdin, ChildStdout, Stdio}, + time::Duration, }; use STS1_EDU_Scheduler::communication::{CEPPacket, CommunicationHandle}; @@ -71,7 +75,10 @@ fn write_scheduler_config(path: &str) { const COMMANDS: &[&str] = &["StoreArchive", "ExecuteProgram", "StopProgram", "GetStatus", "ReturnResult", "UpdateTime"]; -fn inquire_and_send_command(edu: &mut impl CommunicationHandle, path: &str) -> Result<(), Box> { +fn inquire_and_send_command( + edu: &mut impl CommunicationHandle, + path: &str, +) -> Result<(), Box> { let mut select = inquire::Select::new("Select command", COMMANDS.to_vec()); if Path::new(&format!("{path}/updatepin")).exists() { select.help_message = Some("Update Pin is high"); @@ -122,18 +129,20 @@ fn inquire_and_send_command(edu: &mut impl CommunicationHandle, path: &str) -> R n => println!("Unknown event {n}"), } } - }, + } "ReturnResult" => { let program_id = inquire::Text::new("Program id:").prompt()?.parse()?; let timestamp = inquire::Text::new("Timestamp:").prompt()?.parse()?; - let result_path = inquire::Text::new("File path for returned result:").with_default("./result.tar").prompt()?; + let result_path = inquire::Text::new("File path for returned result:") + .with_default("./result.tar") + .prompt()?; edu.send_packet(&CEPPacket::Data(return_result(program_id, timestamp)))?; match edu.receive_multi_packet() { Ok(data) => { std::fs::write(result_path, data)?; edu.send_packet(&CEPPacket::Ack)?; println!("Wrote result to file"); - }, + } Err(e) => println!("Received {:?}", e), } } diff --git a/src/command/execution_context.rs b/src/command/execution_context.rs index 1d0bce3..6654d5c 100644 --- a/src/command/execution_context.rs +++ b/src/command/execution_context.rs @@ -1,4 +1,5 @@ use std::{ + str::FromStr, sync::{Arc, Mutex}, thread, }; @@ -100,7 +101,7 @@ impl UpdatePin { } /// Struct used for storing information about a finished student program -#[derive(Clone, Copy, serde::Serialize, serde::Deserialize, PartialEq, Eq)] +#[derive(Clone, Copy, serde::Serialize, serde::Deserialize, PartialEq, Eq, Debug)] pub struct ProgramStatus { pub program_id: u16, pub timestamp: u32, @@ -108,13 +109,13 @@ pub struct ProgramStatus { } /// Struct used for storing information of a result, waiting to be sent -#[derive(Clone, Copy, serde::Serialize, serde::Deserialize, PartialEq, Eq)] +#[derive(Clone, Copy, serde::Serialize, serde::Deserialize, PartialEq, Eq, Debug)] pub struct ResultId { pub program_id: u16, pub timestamp: u32, } -#[derive(serde::Serialize, serde::Deserialize, Clone, Copy, PartialEq, Eq)] +#[derive(serde::Serialize, serde::Deserialize, Clone, Copy, PartialEq, Eq, Debug)] pub enum Event { Status(ProgramStatus), Result(ResultId), @@ -147,3 +148,15 @@ impl From for Vec { v } } + +impl FromStr for Event { + type Err = (); + + fn from_str(s: &str) -> Result { + match s { + "dosimeter/on" => Ok(Event::EnableDosimeter), + "dosimeter/off" => Ok(Event::DisableDosimeter), + _ => Err(()), + } + } +} diff --git a/src/command/get_status.rs b/src/command/get_status.rs index 86303a1..03c3039 100644 --- a/src/command/get_status.rs +++ b/src/command/get_status.rs @@ -24,12 +24,12 @@ pub fn get_status( com.send_packet(&CEPPacket::Data(event.into()))?; l_exec.event_vec.remove(index)?; } else { - let event = *l_exec.event_vec.as_ref().last().unwrap(); // Safe, because we know it is not empty + let event = *l_exec.event_vec.as_ref().first().unwrap(); // Safe, because we know it is not empty com.send_packet(&CEPPacket::Data(event.into()))?; if !matches!(event, Event::Result(_)) { // Results are removed when deleted - l_exec.event_vec.pop()?; + l_exec.event_vec.remove(0)?; } } diff --git a/src/communication/mod.rs b/src/communication/mod.rs index 231cde2..2ef0e51 100644 --- a/src/communication/mod.rs +++ b/src/communication/mod.rs @@ -1,6 +1,8 @@ mod cep; pub use cep::CEPPacket; +pub mod socket; + use std::{ io::{Read, Write}, time::Duration, diff --git a/src/communication/socket.rs b/src/communication/socket.rs new file mode 100644 index 0000000..3373fc3 --- /dev/null +++ b/src/communication/socket.rs @@ -0,0 +1,158 @@ +use std::{ + io::{BufRead, BufReader, Write}, + os::unix::net::{UnixListener, UnixStream}, + path::Path, + str::FromStr, +}; + +pub struct UnixSocketParser { + listener: UnixListener, + connection: Option>, +} + +impl UnixSocketParser { + pub fn new(path: &str) -> std::io::Result { + let _ = std::fs::remove_file(path); + Ok(Self { listener: UnixListener::bind(path)?, connection: None }) + } + + pub fn read_object(&mut self) -> std::io::Result { + if self.connection.is_none() { + let (stream, _) = self.listener.accept()?; + self.connection = Some(BufReader::new(stream)); + } + + let con = self.connection.as_mut().unwrap(); + let mut line = String::new(); + con.read_line(&mut line)?; + + if !line.ends_with('\n') || line.is_empty() { + self.connection.take(); + return Err(std::io::ErrorKind::ConnectionAborted.into()); + } + + if line == Self::SHUTDOWN_STRING { + return Err(std::io::ErrorKind::Other.into()); + } + + T::from_str(line.trim_end()).map_err(|_| std::io::ErrorKind::InvalidData.into()) + } + + const SHUTDOWN_STRING: &'static str = "shutdown\n"; + pub fn _shutdown(path: impl AsRef) -> std::io::Result<()> { + let mut stream = UnixStream::connect(path)?; + stream.write_all(Self::SHUTDOWN_STRING.as_bytes())?; + + Ok(()) + } +} + +#[cfg(test)] +mod tests { + use std::sync::atomic::{AtomicUsize, Ordering}; + + use super::*; + + fn get_unique_tmp_path() -> String { + static COUNTER: AtomicUsize = AtomicUsize::new(0); + let value = COUNTER.fetch_add(1, Ordering::Relaxed); + let path = format!("/tmp/STS1_socket_test_{value}"); + let _ = std::fs::remove_file(&path); + path + } + + #[test] + fn can_shutdown() { + let path = get_unique_tmp_path(); + let mut rx = UnixSocketParser::new(&path).unwrap(); + + UnixSocketParser::_shutdown(&path).unwrap(); + + assert_eq!(std::io::ErrorKind::Other, rx.read_object::().unwrap_err().kind()); + } + + #[test] + fn can_parse_single_value() { + let path = get_unique_tmp_path(); + let mut rx = UnixSocketParser::new(&path).unwrap(); + + let mut stream = UnixStream::connect(&path).unwrap(); + writeln!(stream, "1234").unwrap(); + + assert_eq!(1234, rx.read_object::().unwrap()); + + UnixSocketParser::_shutdown(path).unwrap(); + } + + #[test] + fn can_parse_multiple_values() { + let path = get_unique_tmp_path(); + let mut rx = UnixSocketParser::new(&path).unwrap(); + + let mut stream = UnixStream::connect(&path).unwrap(); + + const REPS: usize = 100; + for i in 0..REPS { + writeln!(stream, "{i}").unwrap(); + } + + for i in 0..REPS { + assert_eq!(i, rx.read_object::().unwrap()); + } + + UnixSocketParser::_shutdown(path).unwrap(); + } + + #[test] + fn can_reconnect_multiple_times() { + let path = get_unique_tmp_path(); + let mut rx = UnixSocketParser::new(&path).unwrap(); + + for i in 0..10 { + { + let mut stream = UnixStream::connect(&path).unwrap(); + writeln!(stream, "{i}").unwrap(); + } + + assert_eq!(i, rx.read_object::().unwrap()); + assert_eq!( + rx.read_object::().unwrap_err().kind(), + std::io::ErrorKind::ConnectionAborted + ); + } + + UnixSocketParser::_shutdown(path).unwrap(); + } + + #[test] + fn can_deal_with_invalid_data() { + let path = get_unique_tmp_path(); + let mut rx = UnixSocketParser::new(&path).unwrap(); + + let mut stream = UnixStream::connect(&path).unwrap(); + writeln!(stream, "invalid").unwrap(); + assert_eq!(std::io::ErrorKind::InvalidData, rx.read_object::().unwrap_err().kind()); + + writeln!(stream, "123").unwrap(); + assert_eq!(123, rx.read_object::().unwrap()); + + UnixSocketParser::_shutdown(path).unwrap(); + } + + #[test] + fn can_reconnect_after_midline_abort() { + let path = get_unique_tmp_path(); + let mut rx = UnixSocketParser::new(&path).unwrap(); + + { + let mut stream = UnixStream::connect(&path).unwrap(); + write!(stream, "1234").unwrap(); + } + + let mut stream = UnixStream::connect(&path).unwrap(); + writeln!(stream, "5647").unwrap(); + + rx.read_object::().unwrap_err(); + assert_eq!(5647, rx.read_object::().unwrap()); + } +} diff --git a/src/main.rs b/src/main.rs index a61bb77..9c94aed 100644 --- a/src/main.rs +++ b/src/main.rs @@ -1,12 +1,20 @@ #![allow(non_snake_case)] +use command::ExecutionContext; +use communication::socket::UnixSocketParser; use core::time; use rppal::gpio::Gpio; use serialport::SerialPort; -use std::thread; +use std::{ + io::ErrorKind, + sync::{Arc, Mutex}, + thread, +}; use STS1_EDU_Scheduler::communication::CommunicationHandle; use simplelog as sl; +use crate::command::Event; + mod command; mod communication; @@ -17,6 +25,7 @@ struct Configuration { heartbeat_pin: u8, update_pin: u8, heartbeat_freq: u64, + socket: String, } fn main() -> ! { @@ -45,6 +54,10 @@ fn main() -> ! { // construct a wrapper for resources that are shared between different commands let mut exec = command::ExecutionContext::new("events".to_string(), config.update_pin).unwrap(); + let socket_rx = communication::socket::UnixSocketParser::new(&config.socket).unwrap(); + let socket_context = exec.clone(); + std::thread::spawn(move || event_socket_loop(socket_context, socket_rx)); + // start a thread that will update the heartbeat pin thread::spawn(move || heartbeat_loop(config.heartbeat_pin, config.heartbeat_freq)); @@ -72,6 +85,22 @@ fn heartbeat_loop(heartbeat_pin: u8, freq: u64) -> ! { } } +fn event_socket_loop(context: Arc>, mut socket: UnixSocketParser) { + loop { + let s = socket.read_object::(); + let event = match s { + Ok(e) => e, + Err(ref e) if e.kind() == ErrorKind::Other => break, + Err(_) => continue, + }; + + log::info!("Received on socket: {event:?}"); + let mut context = context.lock().unwrap(); + context.event_vec.push(event).unwrap(); + context.check_update_pin(); + } +} + /// Tries to create a directory, but only returns an error if the path does not already exists fn create_directory_if_not_exists(path: impl AsRef) -> std::io::Result<()> { match std::fs::create_dir(path) { diff --git a/tests/simulation/mod.rs b/tests/simulation/mod.rs index 9edcf4a..1caed40 100644 --- a/tests/simulation/mod.rs +++ b/tests/simulation/mod.rs @@ -1,6 +1,7 @@ mod command_execution; mod full_run; mod logging; +mod socket; mod timeout; use std::{ @@ -65,14 +66,14 @@ impl CommunicationHandle for SimulationComHandle { fn get_config_str(unique: &str) -> String { format!( " - uart = \"/tmp/ttySTS1-{}\" + uart = \"/tmp/ttySTS1-{unique}\" baudrate = 921600 heartbeat_pin = 34 update_pin = 35 heartbeat_freq = 10 log_path = \"log\" - ", - unique + socket = \"/tmp/STS1_EDU_Scheduler_SIM_{unique}\" + " ) } diff --git a/tests/simulation/socket.rs b/tests/simulation/socket.rs new file mode 100644 index 0000000..df17615 --- /dev/null +++ b/tests/simulation/socket.rs @@ -0,0 +1,39 @@ +use std::io::Write; +use std::os::unix::net::UnixStream; +use std::time::Duration; + +use super::{simulate_get_status, start_scheduler, SimulationComHandle}; + +#[test] +fn dosimeter_events_are_added() { + let (mut com, _socat) = SimulationComHandle::with_socat_proc("dosimeter"); + let _sched = start_scheduler("dosimeter").unwrap(); + std::thread::sleep(Duration::from_millis(200)); + + { + let mut socket = UnixStream::connect("/tmp/STS1_EDU_Scheduler_SIM_dosimeter").unwrap(); + writeln!(socket, "dosimeter/on").unwrap(); + } + + std::thread::sleep(Duration::from_millis(200)); + assert_eq!(simulate_get_status(&mut com).unwrap(), [0x03]); +} + +#[test] +fn multiple_dosimeter_events() { + let (mut com, _socat) = SimulationComHandle::with_socat_proc("dosimeter-multi"); + let _sched = start_scheduler("dosimeter-multi").unwrap(); + std::thread::sleep(Duration::from_millis(200)); + + let mut socket = UnixStream::connect("/tmp/STS1_EDU_Scheduler_SIM_dosimeter-multi").unwrap(); + for _ in 0..10 { + writeln!(socket, "dosimeter/on").unwrap(); + writeln!(socket, "dosimeter/off").unwrap(); + } + + std::thread::sleep(Duration::from_millis(200)); + for _ in 0..10 { + assert_eq!(simulate_get_status(&mut com).unwrap(), [0x03]); + assert_eq!(simulate_get_status(&mut com).unwrap(), [0x04]); + } +}