diff --git a/Cargo.lock b/Cargo.lock index 1d4bb65a9..cbbd8794f 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -228,6 +228,12 @@ version = "0.4.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "95505c38b4572b2d910cecb0281560f54b440a19336cbbcb27bf6ce6adc6f5a8" +[[package]] +name = "heck" +version = "0.5.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "2304e00983f87ffb38b55b444b5e3b60a884b5d30c0fca7d82fe33449bbe55ea" + [[package]] name = "indexmap" version = "2.4.0" @@ -278,6 +284,7 @@ dependencies = [ "rstest", "serde", "serial_test", + "strum", "thiserror", ] @@ -447,7 +454,7 @@ version = "0.21.2" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "08260721f32db5e1a5beae69a55553f56b99bd0e1c3e6e0a5e8851a9d0f5a85c" dependencies = [ - "heck", + "heck 0.4.1", "proc-macro2", "pyo3-build-config", "quote", @@ -560,6 +567,12 @@ dependencies = [ "semver", ] +[[package]] +name = "rustversion" +version = "1.0.17" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "955d28af4278de8121b7ebeb796b6a45735dc01436d898801014aced2773a3d6" + [[package]] name = "scc" version = "2.1.16" @@ -647,6 +660,28 @@ version = "1.13.2" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "3c5e1a9a646d36c3599cd173a41282daf47c44583ad367b8e6837255952e5c67" +[[package]] +name = "strum" +version = "0.26.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "8fec0f0aef304996cf250b31b5a10dee7980c85da9d759361292b8bca5a18f06" +dependencies = [ + "strum_macros", +] + +[[package]] +name = "strum_macros" +version = "0.26.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "4c6bee85a5a24955dc440386795aa378cd9cf82acd5f764469152d2270e581be" +dependencies = [ + "heck 0.5.0", + "proc-macro2", + "quote", + "rustversion", + "syn", +] + [[package]] name = "syn" version = "2.0.75" diff --git a/librapidflux/Cargo.toml b/librapidflux/Cargo.toml index 40ea7c3b1..58fd1cde2 100644 --- a/librapidflux/Cargo.toml +++ b/librapidflux/Cargo.toml @@ -11,6 +11,7 @@ indexmap = { workspace = true, features = ["serde"] } lazy_static = { workspace = true } owo-colors = "4.0.0" serde = { workspace = true } +strum = { version = "0.26.3", features = ["derive"] } thiserror = "1.0.61" [dev-dependencies] diff --git a/librapidflux/src/ty.rs b/librapidflux/src/ty.rs index 00a4262c9..177dfca1f 100644 --- a/librapidflux/src/ty.rs +++ b/librapidflux/src/ty.rs @@ -3,31 +3,59 @@ use std::{collections::HashSet, fmt::Display, path::PathBuf, str::FromStr}; use indexmap::IndexMap; use lazy_static::lazy_static; use serde::{Deserialize, Serialize}; +use strum::EnumDiscriminants; use crate::{ consts, create_id, - diagnostics::{FilePosition, Location}, + diagnostics::{Annotation, ErrorEntry, FilePosition, Location, RapidFluxError, Severity}, identifier::ID, }; #[must_use] -#[derive(Clone, PartialEq, Serialize, Deserialize, Debug)] +#[derive(EnumDiscriminants, Clone, PartialEq, Serialize, Deserialize, Debug)] +#[strum_discriminants(derive(strum::Display))] pub enum Ty { + #[strum_discriminants(strum(to_string = "undefined type"))] Undefined, + #[strum_discriminants(strum(to_string = "any type"))] Any, + #[strum_discriminants(strum(to_string = "enumeration type"))] Enumeration(Enumeration), + #[strum_discriminants(strum(to_string = "integer type"))] + AnyInteger, + #[strum_discriminants(strum(to_string = "type universal integer"))] UniversalInteger(UniversalInteger), + #[strum_discriminants(strum(to_string = "integer type"))] Integer(Integer), + #[strum_discriminants(strum(to_string = "composite type"))] + Composite, + #[strum_discriminants(strum(to_string = "aggregate"))] Aggregate(Aggregate), + #[strum_discriminants(strum(to_string = "sequence type"))] Sequence(Sequence), + #[strum_discriminants(strum(to_string = "compound type"))] + Compound, + #[strum_discriminants(strum(to_string = "structure type"))] Structure(Structure), + #[strum_discriminants(strum(to_string = "message type"))] Message(Message), + #[strum_discriminants(strum(to_string = "channel"))] Channel(Channel), } impl Ty { + /// Check compatibility to another type. + /// + /// # Panics + /// + /// Will panic if `self` or `other` is an instance of `AnyInteger`, `Composite` or `Compound`. pub fn is_compatible(&self, other: &Ty) -> bool { match (&self, other) { + (Ty::AnyInteger | Ty::Composite | Ty::Compound, _) + | (_, Ty::AnyInteger | Ty::Composite | Ty::Compound) => { + panic!("unexpected type instance") + } + (Ty::Undefined, _) | (_, Ty::Undefined) => false, (Ty::Any, _) @@ -36,6 +64,7 @@ impl Ty { Ty::UniversalInteger(..) | Ty::Integer(..), Ty::UniversalInteger(..) | Ty::Integer(..), ) => true, + (Ty::Enumeration(enumeration), Ty::Enumeration(other)) => enumeration.id == other.id, (Ty::Aggregate(aggregate), Ty::Aggregate(other)) => { aggregate.element.is_compatible(&other.element) @@ -75,8 +104,17 @@ impl Ty { } } + /// Determine common type. + /// + /// # Panics + /// + /// Will panic if `self` or `other` is an instance of `AnyInteger`, `Composite` or `Compound`. pub fn common_type(&self, other: &Ty) -> Ty { match (self, other) { + (Ty::AnyInteger | Ty::Composite | Ty::Compound, _) + | (_, Ty::AnyInteger | Ty::Composite | Ty::Compound) => { + panic!("unexpected type instance") + } (Ty::Undefined, _) | (_, Ty::Undefined) => Ty::Undefined, (_, Ty::Any) => self.to_owned(), (Ty::Any, _) => other.to_owned(), @@ -138,6 +176,31 @@ impl Ty { } } +impl TyDiscriminants { + pub fn is_instance(&self, other: &TyDiscriminants) -> bool { + match (&self, other) { + (TyDiscriminants::Undefined, TyDiscriminants::Any) + | (TyDiscriminants::Any, TyDiscriminants::Undefined) => false, + + (_, TyDiscriminants::Any) + | (TyDiscriminants::Any, _) + | ( + TyDiscriminants::UniversalInteger | TyDiscriminants::Integer, + TyDiscriminants::AnyInteger, + ) + | ( + TyDiscriminants::Aggregate | TyDiscriminants::Sequence, + TyDiscriminants::Composite, + ) + | (TyDiscriminants::Structure | TyDiscriminants::Message, TyDiscriminants::Compound) => { + true + } + + _ => self == other, + } + } +} + impl Display for Ty { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { write!( @@ -157,16 +220,131 @@ impl Display for Ty { Ty::Message(message) => message.to_string(), Ty::Channel(channel) => channel.to_string(), Ty::Any => { - "any type".to_string() + TyDiscriminants::Any.to_string() + } + Ty::AnyInteger => { + TyDiscriminants::AnyInteger.to_string() + } + Ty::Composite => { + TyDiscriminants::Composite.to_string() + } + Ty::Compound => { + TyDiscriminants::Compound.to_string() } Ty::Undefined => { - "undefined type".to_string() + TyDiscriminants::Undefined.to_string() } } ) } } +pub fn common_type(types: &[Ty]) -> Ty { + types.iter().fold(Ty::Any, |r, t| r.common_type(t)) +} + +/// Check if the given type is compatible to the expected types. +/// +/// # Panics +/// +/// Will panic if `expected` is empty or location is `None`. +pub fn check_type( + actual: &Ty, + expected: &[Ty], + location: Option<&Location>, + description: &str, +) -> RapidFluxError { + assert!(!expected.is_empty()); + + if *actual == Ty::Undefined { + return undefined_type(location, description); + } + let mut error = RapidFluxError::default(); + + if !expected.contains(&Ty::Undefined) && !expected.iter().any(|e| actual.is_compatible(e)) { + let desc = expected + .iter() + .map(std::string::ToString::to_string) + .collect::>() + .join(" or "); + error.push(ErrorEntry::new( + format!("expected {desc}"), + Severity::Error, + location.cloned(), + vec![Annotation::new( + Some(format!("found {actual}")), + Severity::Error, + location.cloned().unwrap(), + )], + false, + )); + } + + error +} + +/// Check if the given type is an instance of the expected types. +/// +/// # Panics +/// +/// Will panic if `expected` is empty or if an unexpected type was found and location is `None`. +pub fn check_type_instance( + actual: &Ty, + expected: &[TyDiscriminants], + location: Option<&Location>, + description: &str, + additional_annotations: &[Annotation], +) -> RapidFluxError { + assert!(!expected.is_empty()); + + if *actual == Ty::Undefined { + return undefined_type(location, description); + } + let mut error = RapidFluxError::default(); + + if !expected + .iter() + .any(|e| TyDiscriminants::from(actual).is_instance(e)) + { + let desc = expected + .iter() + .map(std::string::ToString::to_string) + .collect::>() + .join(" or "); + let mut annotations = vec![]; + annotations.push(Annotation::new( + Some(format!("found {actual}")), + Severity::Error, + location.cloned().unwrap(), + )); + annotations.extend(additional_annotations.iter().cloned()); + error.push(ErrorEntry::new( + format!("expected {desc}"), + Severity::Error, + location.cloned(), + annotations, + false, + )); + } + + error +} + +fn undefined_type(location: Option<&Location>, description: &str) -> RapidFluxError { + let description = if description.is_empty() { + String::new() + } else { + format!(" {description}") + }; + RapidFluxError::from(vec![ErrorEntry::new( + format!("undefined{description}"), + Severity::Error, + location.cloned(), + vec![], + true, + )]) +} + #[derive(Clone, Serialize, Deserialize, Debug)] pub struct Enumeration { pub id: ID, @@ -519,6 +697,7 @@ mod tests { use std::collections::HashSet; use indexmap::IndexMap; + use indoc::indoc; use lazy_static::lazy_static; use pretty_assertions::assert_eq; use rstest::rstest; @@ -526,9 +705,10 @@ mod tests { use crate::create_id; use super::{ - Aggregate, Bounds, Channel, Enumeration, Integer, Message, Refinement, Sequence, Structure, - Ty, UniversalInteger, BASE_INTEGER, BIT_LENGTH_BOUNDS, ID, LENGTH_BOUNDS, OPAQUE, - UNIVERSAL_INTEGER, + check_type, check_type_instance, common_type, Aggregate, Bounds, Channel, Enumeration, + FilePosition, Integer, Location, Message, Refinement, Sequence, Structure, Ty, + TyDiscriminants, UniversalInteger, BASE_INTEGER, BIT_LENGTH_BOUNDS, ID, LENGTH_BOUNDS, + OPAQUE, UNIVERSAL_INTEGER, }; lazy_static! { @@ -660,10 +840,13 @@ mod tests { #[case::undefined(&Ty::Undefined, "undefined type")] #[case::any(&Ty::Any, "any type")] #[case::enumeration(&ENUM_A, "enumeration type \"A\"")] + #[case::any_integer(&Ty::AnyInteger, "integer type")] #[case::universal_integer(&UNIV_INT_1_3, "type universal integer (1 .. 3)")] #[case::integer(&INT_A, "integer type \"A\" (1 .. 5)")] + #[case::composite(&Ty::Composite, "composite type")] #[case::aggregate(&AGG_INT_A, "aggregate with element integer type \"A\" (1 .. 5)")] #[case::sequence(&SEQ_A, "sequence type \"A\" with element integer type \"A\" (1 .. 5)")] + #[case::compound(&Ty::Compound, "compound type")] #[case::structure(&STRUCT_A, "structure type \"A\"")] #[case::message(&MSG_A, "message type \"A\"")] #[case::channel(&CHAN, "channel")] @@ -739,6 +922,30 @@ mod tests { assert_eq!(other.is_compatible_strong(ty), is_compatible_strong); } + #[rstest] + #[case(&Ty::AnyInteger, &Ty::Any)] + #[case(&Ty::Composite, &Ty::Any)] + #[case(&Ty::Compound, &Ty::Any)] + #[case(&Ty::Any, &Ty::AnyInteger)] + #[case(&Ty::Any, &Ty::Composite)] + #[case(&Ty::Any, &Ty::Compound)] + #[should_panic(expected = "unexpected type instance")] + fn test_ty_is_compatible_panic(#[case] ty: &Ty, #[case] other: &Ty) { + ty.is_compatible(other); + } + + #[rstest] + #[case(&Ty::AnyInteger, &Ty::Any)] + #[case(&Ty::Composite, &Ty::Any)] + #[case(&Ty::Compound, &Ty::Any)] + #[case(&Ty::Any, &Ty::AnyInteger)] + #[case(&Ty::Any, &Ty::Composite)] + #[case(&Ty::Any, &Ty::Compound)] + #[should_panic(expected = "unexpected type instance")] + fn test_ty_common_type_panic(#[case] ty: &Ty, #[case] other: &Ty) { + let _ = ty.common_type(other); + } + #[rstest] #[case::undefined(&Ty::Undefined)] #[case::any(&Ty::Any)] @@ -756,6 +963,111 @@ mod tests { assert_eq!(*ty, deserialized_ty); } + #[rstest] + #[case(&TyDiscriminants::Undefined, &TyDiscriminants::Any, false)] + #[case(&TyDiscriminants::Any, &TyDiscriminants::Undefined, false)] + #[case(&TyDiscriminants::Integer, &TyDiscriminants::Enumeration, false)] + #[case(&TyDiscriminants::Enumeration, &TyDiscriminants::Enumeration, true)] + #[case(&TyDiscriminants::UniversalInteger, &TyDiscriminants::AnyInteger, true)] + #[case(&TyDiscriminants::Integer, &TyDiscriminants::AnyInteger, true)] + #[case(&TyDiscriminants::Aggregate, &TyDiscriminants::Composite, true)] + #[case(&TyDiscriminants::Sequence, &TyDiscriminants::Composite, true)] + #[case(&TyDiscriminants::Structure, &TyDiscriminants::Compound, true)] + #[case(&TyDiscriminants::Message, &TyDiscriminants::Compound, true)] + fn test_ty_discriminants_is_instance( + #[case] ty: &TyDiscriminants, + #[case] other: &TyDiscriminants, + #[case] expected: bool, + ) { + assert_eq!(ty.is_instance(other), expected); + } + + #[rstest] + #[case::empty(&[], &Ty::Any)] + #[case::undefined(&[&CHAN_R, &*MSG_A], &Ty::Undefined)] + #[case::integer(&[&*INT_A, &*UNIV_INT_1_3], &BASE_INTEGER)] + #[case::aggregate(&[&*AGG_INT_A, &*AGG_INT_B], &AGG_BASE_INT)] + fn test_common_type(#[case] types: &[&Ty], #[case] expected: &Ty) { + assert_eq!( + common_type(&types.iter().copied().cloned().collect::>()), + *expected + ); + } + + #[rstest] + #[case::valid(&Ty::Any, &[&Ty::Any], "")] + #[case::expected_undefined(&*INT_A, &[&Ty::Undefined], "")] + #[case::actual_undefined( + &Ty::Undefined, + &[&*INT_A], + indoc! {r" + :1:1: error: undefined foo + "}, + )] + #[case::invalid( + &ENUM_A, + &[&*INT_A], + indoc! {r#" + :1:1: error: expected integer type "A" (1 .. 5) + :1:1: error: found enumeration type "A" + "#}, + )] + fn test_check_type(#[case] actual: &Ty, #[case] types: &[&Ty], #[case] expected: &str) { + assert_eq!( + check_type( + actual, + &types.iter().copied().cloned().collect::>(), + Some(&Location { + start: FilePosition::new(1, 1), + end: None, + source: None, + }), + "foo" + ) + .to_string(), + *expected.trim() + ); + } + + #[rstest] + #[case::valid(&Ty::Any, &[TyDiscriminants::Any], "")] + #[case::actual_undefined( + &Ty::Undefined, + &[TyDiscriminants::AnyInteger], + indoc! {r" + :1:1: error: undefined + "}, + )] + #[case::invalid( + &ENUM_A, + &[TyDiscriminants::AnyInteger], + indoc! {r#" + :1:1: error: expected integer type + :1:1: error: found enumeration type "A" + "#}, + )] + fn test_check_type_instance( + #[case] actual: &Ty, + #[case] types: &[TyDiscriminants], + #[case] expected: &str, + ) { + assert_eq!( + check_type_instance( + actual, + types, + Some(&Location { + start: FilePosition::new(1, 1), + end: None, + source: None, + }), + "", + &[] + ) + .to_string(), + *expected.trim() + ); + } + #[rstest] #[case(i128::MIN, i128::MIN)] #[case(i128::MIN, i128::MAX)] diff --git a/rapidflux/src/diagnostics/errors.rs b/rapidflux/src/diagnostics/errors.rs index d936a4298..bd3c2fc2c 100644 --- a/rapidflux/src/diagnostics/errors.rs +++ b/rapidflux/src/diagnostics/errors.rs @@ -145,7 +145,7 @@ impl From for lib::Severity { #[pyclass(module = "rflx.rapidflux")] #[derive(Clone, Debug, Serialize, Deserialize)] -pub struct Annotation(lib::Annotation); +pub struct Annotation(pub lib::Annotation); #[pymethods] impl Annotation { @@ -319,7 +319,7 @@ impl ErrorEntry { #[pyclass(module = "rflx.rapidflux", extends = PyException, subclass)] #[derive(Clone, Serialize, Deserialize, Debug)] #[pyo3(name = "RecordFluxError")] -pub struct RapidFluxError(lib::RapidFluxError); +pub struct RapidFluxError(pub lib::RapidFluxError); #[pymethods] impl RapidFluxError { diff --git a/rapidflux/src/ty.rs b/rapidflux/src/ty.rs index dd4b6819b..662d3f813 100644 --- a/rapidflux/src/ty.rs +++ b/rapidflux/src/ty.rs @@ -6,14 +6,15 @@ use pyo3::{ exceptions::PyTypeError, prelude::*, sync::GILOnceCell, - types::{PyBool, PyBytes, PyInt, PyNotImplemented, PySet, PyTuple}, + type_object::PyTypeInfo, + types::{PyBool, PyBytes, PyInt, PyList, PyNotImplemented, PySet, PyTuple, PyType}, }; use serde::{Deserialize, Serialize}; use librapidflux::ty as lib; use crate::{ - diagnostics::Location, + diagnostics::{Annotation, Location, RapidFluxError}, identifier::{to_id, ID}, impl_states, register_submodule_declarations, }; @@ -23,6 +24,13 @@ pub struct Builtins; #[pymethods] impl Builtins { + #[classattr] + #[pyo3(name = "UNDEFINED")] + fn undefined(py: Python<'_>) -> &PyObject { + static UNDEFINED: GILOnceCell = GILOnceCell::new(); + UNDEFINED.get_or_init(py, || to_py(&lib::Ty::Undefined, py)) + } + #[classattr] #[pyo3(name = "BOOLEAN")] fn boolean(py: Python<'_>) -> &PyObject { @@ -1043,6 +1051,100 @@ impl Bounds { } } +#[pyfunction] +fn common_type(types: &Bound<'_, PyList>, py: Python<'_>) -> PyObject { + to_py( + &lib::common_type(&types.iter().map(|t| to_ty(&t)).collect::>()), + py, + ) +} + +#[pyfunction] +#[pyo3(signature = (actual, expected, location = None, description = ""))] +#[allow(clippy::needless_pass_by_value)] +fn check_type( + actual: &Bound<'_, PyAny>, + expected: &Bound<'_, PyAny>, + location: Option<&Location>, + description: &str, +) -> RapidFluxError { + let expected = if let Ok(tuple) = expected.extract::>>() { + tuple + } else if let Ok(ty) = expected.extract::>() { + Vec::from([ty]) + } else { + panic!("unexpected argument type for expected: {expected}") + }; + RapidFluxError(lib::check_type( + &to_ty(actual), + &expected.iter().map(|e| to_ty(e)).collect::>(), + location.map(|l| &l.0), + description, + )) +} + +#[pyfunction] +#[pyo3(signature = (actual, expected, location = None, description = "", additional_annotations = None))] +fn check_type_instance( + actual: &Bound<'_, PyAny>, + expected: &Bound<'_, PyAny>, + location: Option<&Location>, + description: &str, + additional_annotations: Option>, + py: Python<'_>, +) -> PyResult { + let mut exp = vec![]; + let expected = if let Ok(tuple) = expected.extract::>>() { + tuple + } else if let Ok(ty) = expected.extract::>() { + Vec::from([ty]) + } else { + panic!("unexpected argument type for expected: {expected}") + }; + for e in expected { + if e.eq(Undefined::type_object_bound(py))? { + exp.push(lib::TyDiscriminants::Undefined); + } else if e.eq(Any::type_object_bound(py))? { + exp.push(lib::TyDiscriminants::Any); + } else if e.eq(Enumeration::type_object_bound(py))? { + exp.push(lib::TyDiscriminants::Enumeration); + } else if e.eq(AnyInteger::type_object_bound(py))? { + exp.push(lib::TyDiscriminants::AnyInteger); + } else if e.eq(UniversalInteger::type_object_bound(py))? { + exp.push(lib::TyDiscriminants::UniversalInteger); + } else if e.eq(Integer::type_object_bound(py))? { + exp.push(lib::TyDiscriminants::Integer); + } else if e.eq(Composite::type_object_bound(py))? { + exp.push(lib::TyDiscriminants::Composite); + } else if e.eq(Aggregate::type_object_bound(py))? { + exp.push(lib::TyDiscriminants::Aggregate); + } else if e.eq(Sequence::type_object_bound(py))? { + exp.push(lib::TyDiscriminants::Sequence); + } else if e.eq(Compound::type_object_bound(py))? { + exp.push(lib::TyDiscriminants::Compound); + } else if e.eq(Structure::type_object_bound(py))? { + exp.push(lib::TyDiscriminants::Structure); + } else if e.eq(Message::type_object_bound(py))? { + exp.push(lib::TyDiscriminants::Message); + } else if e.eq(Channel::type_object_bound(py))? { + exp.push(lib::TyDiscriminants::Channel); + } else { + panic!("unexpected type {e}") + } + } + Ok(RapidFluxError(lib::check_type_instance( + &to_ty(actual), + &exp, + location.map(|l| &l.0), + description, + &additional_annotations + .unwrap_or_default() + .into_iter() + .map(|a| a.0) + .collect::>(), + ))) +} + fn to_ty(obj: &Bound<'_, PyAny>) -> lib::Ty { if obj.extract::>().is_ok() { lib::Ty::Undefined @@ -1079,6 +1181,7 @@ fn to_py(obj: &lib::Ty, py: Python<'_>) -> PyObject { ) .unwrap() .into_py(py), + lib::Ty::AnyInteger => Py::new(py, AnyInteger::new()).unwrap().into_py(py), lib::Ty::UniversalInteger(universal_integer) => Py::new( py, AnyInteger::new().add_subclass(UniversalInteger(universal_integer.clone())), @@ -1090,6 +1193,7 @@ fn to_py(obj: &lib::Ty, py: Python<'_>) -> PyObject { .unwrap() .into_py(py) } + lib::Ty::Composite => Py::new(py, Composite::new()).unwrap().into_py(py), lib::Ty::Aggregate(aggregate) => Py::new( py, Composite::new().add_subclass(Aggregate(aggregate.clone())), @@ -1102,6 +1206,7 @@ fn to_py(obj: &lib::Ty, py: Python<'_>) -> PyObject { ) .unwrap() .into_py(py), + lib::Ty::Compound => Py::new(py, Compound::new()).unwrap().into_py(py), lib::Ty::Structure(structure) => Py::new( py, Compound::new().add_subclass(Structure(structure.clone())), @@ -1162,5 +1267,5 @@ register_submodule_declarations!( Message, Channel, ], - [] + [common_type, check_type, check_type_instance] ); diff --git a/rflx/ada.py b/rflx/ada.py index 1f13a59f9..c1ec2ddf8 100644 --- a/rflx/ada.py +++ b/rflx/ada.py @@ -11,7 +11,7 @@ from typing_extensions import Self -from rflx import expr, typing_ as rty +from rflx import expr, ty from rflx.common import Base, file_name, indent, indent_next, unique from rflx.identifier import ID, StrID @@ -611,7 +611,7 @@ def _representation(self) -> str: def rflx_expr(self) -> expr.Call: assert not self.named_arguments - return expr.Call(self.identifier, rty.UNDEFINED, [a.rflx_expr() for a in self.arguments]) + return expr.Call(self.identifier, ty.UNDEFINED, [a.rflx_expr() for a in self.arguments]) class Slice(Name): diff --git a/rflx/expr.py b/rflx/expr.py index 22e2c2c33..9d0cadb7a 100644 --- a/rflx/expr.py +++ b/rflx/expr.py @@ -13,11 +13,11 @@ from sys import intern from typing import TYPE_CHECKING, Final -from rflx import const, typing_ as rty +from rflx import const, ty from rflx.common import Base, indent, indent_next, unique from rflx.error import are_all_locations_present from rflx.identifier import ID, StrID -from rflx.rapidflux import Annotation, ErrorEntry, Location, RecordFluxError, Severity, ty +from rflx.rapidflux import Annotation, ErrorEntry, Location, RecordFluxError, Severity if TYPE_CHECKING: from _typeshed import SupportsAllComparisons @@ -41,7 +41,7 @@ class Expr(Base): def __init__( self, - type_: rty.Type = rty.UNDEFINED, + type_: ty.Type = ty.UNDEFINED, location: Location | None = None, ): self.type_ = type_ @@ -98,11 +98,11 @@ def _check_type_subexpr(self) -> RecordFluxError: """Initialize and check the types of sub-expressions.""" raise NotImplementedError - def check_type(self, expected: rty.Type | tuple[rty.Type, ...]) -> RecordFluxError: + def check_type(self, expected: ty.Type | tuple[ty.Type, ...]) -> RecordFluxError: """Initialize and check the types of the expression and all sub-expressions.""" error = self._check_type_subexpr() error.extend( - rty.check_type( + ty.check_type( self.type_, expected, self.location, @@ -113,12 +113,12 @@ def check_type(self, expected: rty.Type | tuple[rty.Type, ...]) -> RecordFluxErr def check_type_instance( self, - expected: type[rty.Type] | tuple[type[rty.Type], ...], + expected: type[ty.Type] | tuple[type[ty.Type], ...], ) -> RecordFluxError: """Initialize and check the types of the expression and all sub-expressions.""" error = self._check_type_subexpr() error.extend( - rty.check_type_instance( + ty.check_type_instance( self.type_, expected, self.location, @@ -158,14 +158,14 @@ def parenthesized(self, expr: Expr) -> str: class Not(Expr): def __init__(self, expr: Expr, location: Location | None = None) -> None: - super().__init__(rty.BOOLEAN, location) + super().__init__(ty.BOOLEAN, location) self.expr = expr def _update_str(self) -> None: self._str = intern(f"not {self.parenthesized(self.expr)}") def _check_type_subexpr(self) -> RecordFluxError: - return self.expr.check_type(rty.BOOLEAN) + return self.expr.check_type(ty.BOOLEAN) def __neg__(self) -> Expr: return self.expr @@ -233,7 +233,7 @@ def __init__( self, left: Expr, right: Expr, - type_: rty.Type = rty.UNDEFINED, + type_: ty.Type = ty.UNDEFINED, location: Location | None = None, ) -> None: super().__init__(type_, location) @@ -303,7 +303,7 @@ def symbol(self) -> str: class AssExpr(Expr): def __init__(self, *terms: Expr, location: Location | None = None) -> None: - super().__init__(rty.UNDEFINED, location=location) + super().__init__(ty.UNDEFINED, location=location) self.terms = list(terms) def __repr__(self) -> str: @@ -468,7 +468,7 @@ def symbol(self) -> str: class BoolAssExpr(AssExpr): def __init__(self, *terms: Expr, location: Location | None = None) -> None: super().__init__(*terms, location=location) - self.type_ = rty.BOOLEAN + self.type_ = ty.BOOLEAN def _update_str(self) -> None: if not self.terms: @@ -490,7 +490,7 @@ def _update_str(self) -> None: def _check_type_subexpr(self) -> RecordFluxError: error = RecordFluxError() for t in self.terms: - error.extend(t.check_type(rty.BOOLEAN).entries) + error.extend(t.check_type(ty.BOOLEAN).entries) return error @abstractmethod @@ -570,10 +570,10 @@ def symbol(self) -> str: class Number(Expr): - type_: rty.UniversalInteger + type_: ty.UniversalInteger def __init__(self, value: int, base: int = 0, location: Location | None = None) -> None: - super().__init__(rty.UniversalInteger(ty.Bounds(value, value)), location) + super().__init__(ty.UniversalInteger(ty.Bounds(value, value)), location) self.value = value self.base = base @@ -692,7 +692,7 @@ def _update_str(self) -> None: self._str = intern(f"-{self.parenthesized(self.expr)}") def _check_type_subexpr(self) -> RecordFluxError: - return self.expr.check_type_instance(rty.AnyInteger) + return self.expr.check_type_instance(ty.AnyInteger) def __neg__(self) -> Expr: return self.expr @@ -735,13 +735,13 @@ def simplified(self) -> Expr: class MathAssExpr(AssExpr): def __init__(self, *terms: Expr, location: Location | None = None) -> None: super().__init__(*terms, location=location) - common_type = rty.common_type([t.type_ for t in terms]) - self.type_ = common_type if common_type != rty.UNDEFINED else rty.BASE_INTEGER + common_type = ty.common_type([t.type_ for t in terms]) + self.type_ = common_type if common_type != ty.UNDEFINED else ty.BASE_INTEGER def _check_type_subexpr(self) -> RecordFluxError: error = RecordFluxError() for t in self.terms: - error.extend(t.check_type_instance(rty.AnyInteger).entries) + error.extend(t.check_type_instance(ty.AnyInteger).entries) return error @@ -816,14 +816,14 @@ def symbol(self) -> str: class MathBinExpr(BinExpr): def __init__(self, left: Expr, right: Expr, location: Location | None = None) -> None: - super().__init__(left, right, rty.common_type([left.type_, right.type_]), location) + super().__init__(left, right, ty.common_type([left.type_, right.type_]), location) def _check_type_subexpr(self) -> RecordFluxError: error = RecordFluxError() for e in [self.left, self.right]: - error.extend(e.check_type_instance(rty.AnyInteger).entries) + error.extend(e.check_type_instance(ty.AnyInteger).entries) - self.type_ = rty.common_type([self.left.type_, self.right.type_]) + self.type_ = ty.common_type([self.left.type_, self.right.type_]) return error @@ -936,7 +936,7 @@ class Name(Expr): def __init__( self, immutable: bool = False, - type_: rty.Type = rty.UNDEFINED, + type_: ty.Type = ty.UNDEFINED, location: Location | None = None, ) -> None: super().__init__(type_, location) @@ -973,7 +973,7 @@ class TypeName(Name): def __init__( self, identifier: StrID, - type_: rty.Type = rty.UNDEFINED, + type_: ty.Type = ty.UNDEFINED, location: Location | None = None, ) -> None: self.identifier = ID(identifier) @@ -1001,7 +1001,7 @@ class Literal(Name): def __init__( self, identifier: StrID, - type_: rty.Type = rty.UNDEFINED, + type_: ty.Type = ty.UNDEFINED, location: Location | None = None, ) -> None: self.identifier = ID(identifier) @@ -1035,7 +1035,7 @@ def variables(self) -> list[Variable]: def copy( self, identifier: StrID | None = None, - type_: rty.Type | None = None, + type_: ty.Type | None = None, location: Location | None = None, ) -> Literal: return self.__class__( @@ -1050,7 +1050,7 @@ def __init__( self, identifier: StrID, immutable: bool = False, - type_: rty.Type = rty.UNDEFINED, + type_: ty.Type = ty.UNDEFINED, location: Location | None = None, ) -> None: self.identifier = ID(identifier) @@ -1086,7 +1086,7 @@ def copy( self, identifier: StrID | None = None, immutable: bool | None = None, - type_: rty.Type | None = None, + type_: ty.Type | None = None, location: Location | None = None, ) -> Variable: return self.__class__( @@ -1099,12 +1099,12 @@ def copy( TRUE = Literal( "True", - type_=rty.BOOLEAN, + type_=ty.BOOLEAN, location=Location((1, 1), Path(str(const.BUILTINS_PACKAGE)), (1, 1)), ) FALSE = Literal( "False", - type_=rty.BOOLEAN, + type_=ty.BOOLEAN, location=Location((1, 1), Path(str(const.BUILTINS_PACKAGE)), (1, 1)), ) @@ -1160,46 +1160,46 @@ def variables(self) -> list[Variable]: class Size(Attribute): def __init__(self, prefix: StrID | Expr) -> None: super().__init__(prefix) - self.type_ = rty.UNIVERSAL_INTEGER + self.type_ = ty.UNIVERSAL_INTEGER def _check_type_subexpr(self) -> RecordFluxError: - return self.prefix.check_type_instance(rty.Any) + return self.prefix.check_type_instance(ty.Any) class Length(Attribute): def __init__(self, prefix: StrID | Expr) -> None: super().__init__(prefix) - self.type_ = rty.UNIVERSAL_INTEGER + self.type_ = ty.UNIVERSAL_INTEGER def _check_type_subexpr(self) -> RecordFluxError: - return self.prefix.check_type_instance(rty.Any) + return self.prefix.check_type_instance(ty.Any) class First(Attribute): def __init__(self, prefix: StrID | Expr) -> None: super().__init__(prefix) - self.type_ = rty.UNIVERSAL_INTEGER + self.type_ = ty.UNIVERSAL_INTEGER def _check_type_subexpr(self) -> RecordFluxError: - return self.prefix.check_type_instance(rty.Any) + return self.prefix.check_type_instance(ty.Any) class Last(Attribute): def __init__(self, prefix: StrID | Expr) -> None: super().__init__(prefix) - self.type_ = rty.UNIVERSAL_INTEGER + self.type_ = ty.UNIVERSAL_INTEGER def _check_type_subexpr(self) -> RecordFluxError: - return self.prefix.check_type_instance(rty.Any) + return self.prefix.check_type_instance(ty.Any) class ValidChecksum(Attribute): def __init__(self, prefix: StrID | Expr) -> None: super().__init__(prefix) - self.type_ = rty.BOOLEAN + self.type_ = ty.BOOLEAN def _check_type_subexpr(self) -> RecordFluxError: - return self.prefix.check_type_instance(rty.Any) + return self.prefix.check_type_instance(ty.Any) @property def representation(self) -> str: @@ -1209,22 +1209,22 @@ def representation(self) -> str: class Valid(Attribute): def __init__(self, prefix: StrID | Expr) -> None: super().__init__(prefix) - self.type_ = rty.BOOLEAN + self.type_ = ty.BOOLEAN def _check_type_subexpr(self) -> RecordFluxError: return self.prefix.check_type_instance( - (rty.Sequence, rty.Message) if isinstance(self.prefix, Variable) else rty.Any, + (ty.Sequence, ty.Message) if isinstance(self.prefix, Variable) else (ty.Any,), ) class Present(Attribute): def __init__(self, prefix: StrID | Expr) -> None: super().__init__(prefix) - self.type_ = rty.BOOLEAN + self.type_ = ty.BOOLEAN def _check_type_subexpr(self) -> RecordFluxError: if isinstance(self.prefix, Selected): - error = self.prefix.prefix.check_type_instance(rty.Message) + error = self.prefix.prefix.check_type_instance(ty.Message) else: error = RecordFluxError( [ @@ -1241,29 +1241,29 @@ def _check_type_subexpr(self) -> RecordFluxError: class HasData(Attribute): def __init__(self, prefix: StrID | Expr) -> None: super().__init__(prefix) - self.type_ = rty.BOOLEAN + self.type_ = ty.BOOLEAN @property def symbol(self) -> str: return "Has_Data" def _check_type_subexpr(self) -> RecordFluxError: - return self.prefix.check_type_instance(rty.Message) + return self.prefix.check_type_instance(ty.Message) class Head(Attribute): def __init__( self, prefix: StrID | Expr, - type_: rty.Type = rty.UNDEFINED, + type_: ty.Type = ty.UNDEFINED, ): super().__init__(prefix) self.type_ = type_ def _check_type_subexpr(self) -> RecordFluxError: - error = self.prefix.check_type_instance(rty.Composite) + error = self.prefix.check_type_instance(ty.Composite) self.type_ = ( - self.prefix.type_.element if isinstance(self.prefix.type_, rty.Composite) else rty.Any() + self.prefix.type_.element if isinstance(self.prefix.type_, ty.Composite) else ty.Any() ) if not isinstance(self.prefix, (Variable, Selected, Comprehension)): error.push( @@ -1279,10 +1279,10 @@ def _check_type_subexpr(self) -> RecordFluxError: class Opaque(Attribute): def __init__(self, prefix: StrID | Expr) -> None: super().__init__(prefix) - self.type_ = rty.OPAQUE + self.type_ = ty.OPAQUE def _check_type_subexpr(self) -> RecordFluxError: - return self.prefix.check_type_instance((rty.Sequence, rty.Message)) + return self.prefix.check_type_instance((ty.Sequence, ty.Message)) class Constrained(Attribute): @@ -1356,7 +1356,7 @@ def __init__( prefix: Expr, selector: StrID, immutable: bool = False, - type_: rty.Type = rty.UNDEFINED, + type_: ty.Type = ty.UNDEFINED, location: Location | None = None, ) -> None: self.prefix = prefix @@ -1374,7 +1374,7 @@ def findall(self, match: Callable[[Expr], bool]) -> Sequence[Expr]: def _check_type_subexpr(self) -> RecordFluxError: error = RecordFluxError() - if isinstance(self.prefix.type_, rty.Message): + if isinstance(self.prefix.type_, ty.Message): if self.selector in self.prefix.type_.types: self.type_ = self.prefix.type_.types[self.selector] else: @@ -1392,10 +1392,10 @@ def _check_type_subexpr(self) -> RecordFluxError: ), ], ) - self.type_ = rty.Any() + self.type_ = ty.Any() else: - self.type_ = rty.Any() - error.extend(self.prefix.check_type_instance(rty.Message).entries) + self.type_ = ty.Any() + error.extend(self.prefix.check_type_instance(ty.Message).entries) return error @property @@ -1427,7 +1427,7 @@ def copy( prefix: Expr | None = None, selector: StrID | None = None, immutable: bool | None = None, - type_: rty.Type | None = None, + type_: ty.Type | None = None, location: Location | None = None, ) -> Selected: return self.__class__( @@ -1443,10 +1443,10 @@ class Call(Name): def __init__( # noqa: PLR0913 self, identifier: StrID, - type_: rty.Type, + type_: ty.Type, args: Sequence[Expr] | None = None, immutable: bool = False, - argument_types: Sequence[rty.Type] | None = None, + argument_types: Sequence[ty.Type] | None = None, location: Location | None = None, ) -> None: self.identifier = ID(identifier) @@ -1461,9 +1461,9 @@ def _check_type_subexpr(self) -> RecordFluxError: error = RecordFluxError() for a, t in itertools.zip_longest(self.args, self.argument_types[: len(self.args)]): - error.extend(a.check_type(t if t is not None else rty.Any()).entries) + error.extend(a.check_type(t if t is not None else ty.Any()).entries) - if self.type_ != rty.UNDEFINED: + if self.type_ != ty.UNDEFINED: if len(self.args) < len(self.argument_types): error.push( ErrorEntry( @@ -1558,7 +1558,7 @@ def _check_type_subexpr(self) -> RecordFluxError: class Aggregate(Expr): def __init__(self, *elements: Expr, location: Location | None = None) -> None: - super().__init__(rty.Aggregate(rty.common_type([e.type_ for e in elements])), location) + super().__init__(ty.Aggregate(ty.common_type([e.type_ for e in elements])), location) self.elements = list(elements) def __eq__(self, other: object) -> bool: @@ -1575,7 +1575,7 @@ def _update_str(self) -> None: def _check_type_subexpr(self) -> RecordFluxError: error = RecordFluxError() for e in self.elements: - error.extend(e.check_type_instance(rty.Any).entries) + error.extend(e.check_type_instance(ty.Any).entries) return error def __neg__(self) -> Expr: @@ -1663,7 +1663,7 @@ def simplified(self) -> Expr: class Relation(BinExpr): def __init__(self, left: Expr, right: Expr, location: Location | None = None) -> None: - super().__init__(left, right, rty.BOOLEAN, location) + super().__init__(left, right, ty.BOOLEAN, location) @abstractmethod def __neg__(self) -> Expr: @@ -1735,7 +1735,7 @@ def __neg__(self) -> Expr: def _check_type_subexpr(self) -> RecordFluxError: error = RecordFluxError() for e in [self.left, self.right]: - error.extend(e.check_type_instance(rty.AnyInteger).entries) + error.extend(e.check_type_instance(ty.AnyInteger).entries) return error @property @@ -1753,7 +1753,7 @@ def __neg__(self) -> Expr: def _check_type_subexpr(self) -> RecordFluxError: error = RecordFluxError() for e in [self.left, self.right]: - error.extend(e.check_type_instance(rty.AnyInteger).entries) + error.extend(e.check_type_instance(ty.AnyInteger).entries) return error @property @@ -1769,7 +1769,7 @@ def __neg__(self) -> Expr: return NotEqual(self.left, self.right) def _check_type_subexpr(self) -> RecordFluxError: - error = self.left.check_type_instance(rty.Any) + error = self.left.check_type_instance(ty.Any) error.extend(self.right.check_type(self.left.type_).entries) return error @@ -1788,7 +1788,7 @@ def __neg__(self) -> Expr: def _check_type_subexpr(self) -> RecordFluxError: error = RecordFluxError() for e in [self.left, self.right]: - error.extend(e.check_type_instance(rty.AnyInteger).entries) + error.extend(e.check_type_instance(ty.AnyInteger).entries) return error @property @@ -1806,7 +1806,7 @@ def __neg__(self) -> Expr: def _check_type_subexpr(self) -> RecordFluxError: error = RecordFluxError() for e in [self.left, self.right]: - error.extend(e.check_type_instance(rty.AnyInteger).entries) + error.extend(e.check_type_instance(ty.AnyInteger).entries) return error @property @@ -1822,7 +1822,7 @@ def __neg__(self) -> Expr: return Equal(self.left, self.right) def _check_type_subexpr(self) -> RecordFluxError: - error = self.left.check_type_instance(rty.Any) + error = self.left.check_type_instance(ty.Any) error.extend(self.right.check_type(self.left.type_).entries) return error @@ -1839,10 +1839,10 @@ def __neg__(self) -> Expr: return NotIn(self.left, self.right) def _check_type_subexpr(self) -> RecordFluxError: - error = self.left.check_type_instance(rty.Any) + error = self.left.check_type_instance(ty.Any) error.extend( self.right.check_type( - rty.Aggregate(self.left.type_), + ty.Aggregate(self.left.type_), ).entries, ) return error @@ -1857,10 +1857,10 @@ def __neg__(self) -> Expr: return In(self.left, self.right) def _check_type_subexpr(self) -> RecordFluxError: - error = self.left.check_type_instance(rty.Any) + error = self.left.check_type_instance(ty.Any) error.extend( self.right.check_type( - rty.Aggregate(self.left.type_), + ty.Aggregate(self.left.type_), ).entries, ) return error @@ -1877,7 +1877,7 @@ def __init__( else_expression: Expr | None = None, ) -> None: super().__init__( - rty.common_type( + ty.common_type( [ *[e.type_ for _, e in condition_expressions], *([else_expression.type_] if else_expression else []), @@ -1966,7 +1966,7 @@ def __init__( predicate: Expr, location: Location | None = None, ) -> None: - super().__init__(rty.BOOLEAN, location) + super().__init__(ty.BOOLEAN, location) self.parameter_identifier = ID(parameter_identifier) self.iterable = iterable self.predicate = predicate @@ -1980,17 +1980,17 @@ def _update_str(self) -> None: def _check_type_subexpr(self) -> RecordFluxError: def typify_variable(expr: Expr) -> Expr: if isinstance(expr, Variable) and expr.identifier == self.parameter_identifier: - if isinstance(self.iterable.type_, (rty.Aggregate, rty.Sequence)): + if isinstance(self.iterable.type_, (ty.Aggregate, ty.Sequence)): expr.type_ = self.iterable.type_.element else: - expr.type_ = rty.Any() + expr.type_ = ty.Any() return expr - error = self.iterable.check_type_instance(rty.Composite) + error = self.iterable.check_type_instance(ty.Composite) self.predicate = self.predicate.substituted(typify_variable) - error.extend(self.predicate.check_type(rty.BOOLEAN).entries) + error.extend(self.predicate.check_type(ty.BOOLEAN).entries) return error @property @@ -2081,7 +2081,7 @@ def keyword(self) -> str: class ValueRange(Expr): def __init__(self, lower: Expr, upper: Expr, location: Location | None = None): - super().__init__(rty.Any(), location) + super().__init__(ty.Any(), location) self.lower = lower self.upper = upper @@ -2091,7 +2091,7 @@ def _update_str(self) -> None: def _check_type_subexpr(self) -> RecordFluxError: error = RecordFluxError() for e in [self.lower, self.upper]: - error.extend(e.check_type_instance(rty.AnyInteger).entries) + error.extend(e.check_type_instance(ty.AnyInteger).entries) return error def __neg__(self) -> Expr: @@ -2125,8 +2125,8 @@ def __init__( self, identifier: StrID, argument: Expr, - type_: rty.Type = rty.UNDEFINED, - argument_types: Sequence[rty.Type] | None = None, + type_: ty.Type = ty.UNDEFINED, + argument_types: Sequence[ty.Type] | None = None, location: Location | None = None, ) -> None: super().__init__(type_, location) @@ -2138,7 +2138,7 @@ def _update_str(self) -> None: self._str = intern(f"{self.identifier} ({self.argument})") def _check_type_subexpr(self) -> RecordFluxError: - error = self.argument.check_type(rty.OPAQUE) + error = self.argument.check_type(ty.OPAQUE) if isinstance(self.argument, Selected): if self.argument_types: @@ -2160,7 +2160,7 @@ def _check_type_subexpr(self) -> RecordFluxError: self.location, ), ] - if isinstance(self.argument.prefix.type_, rty.Message) + if isinstance(self.argument.prefix.type_, ty.Message) else [] ), ), @@ -2244,7 +2244,7 @@ def simplified(self) -> Expr: class Comprehension(Expr): - type_: rty.Aggregate + type_: ty.Aggregate def __init__( self, @@ -2254,7 +2254,7 @@ def __init__( condition: Expr, location: Location | None = None, ) -> None: - super().__init__(rty.Aggregate(selector.type_), location) + super().__init__(ty.Aggregate(selector.type_), location) self.iterator = ID(iterator) self.sequence = sequence self.selector = selector @@ -2268,21 +2268,21 @@ def _update_str(self) -> None: def _check_type_subexpr(self) -> RecordFluxError: def typify_variable(expr: Expr) -> Expr: if isinstance(expr, Variable) and expr.identifier == self.iterator: - if isinstance(self.sequence.type_, (rty.Aggregate, rty.Sequence)): + if isinstance(self.sequence.type_, (ty.Aggregate, ty.Sequence)): expr.type_ = self.sequence.type_.element else: - expr.type_ = rty.Any() + expr.type_ = ty.Any() return expr - error = self.sequence.check_type_instance(rty.Composite) + error = self.sequence.check_type_instance(ty.Composite) self.selector = self.selector.substituted(typify_variable) self.condition = self.condition.substituted(typify_variable) - error.extend(self.selector.check_type_instance(rty.Any).entries) - error.extend(self.condition.check_type(rty.BOOLEAN).entries) + error.extend(self.selector.check_type_instance(ty.Any).entries) + error.extend(self.condition.check_type(ty.BOOLEAN).entries) - self.type_ = rty.Aggregate(self.selector.type_) + self.type_ = ty.Aggregate(self.selector.type_) return error @@ -2334,7 +2334,7 @@ def __init__( self, identifier: StrID, field_values: Mapping[StrID, Expr], - type_: rty.Type = rty.UNDEFINED, + type_: ty.Type = ty.UNDEFINED, location: Location | None = None, ) -> None: super().__init__(type_, location) @@ -2352,9 +2352,9 @@ def _update_str(self) -> None: def _check_type_subexpr(self) -> RecordFluxError: error = RecordFluxError() - if not isinstance(self.type_, rty.Message): + if not isinstance(self.type_, ty.Message): for d in self.field_values.values(): - error.extend(d.check_type_instance(rty.Any).entries) + error.extend(d.check_type_instance(ty.Any).entries) return error @@ -2363,7 +2363,7 @@ def _check_type_subexpr(self) -> RecordFluxError: return error def _field_combinations(self) -> set[tuple[str, ...]]: - assert isinstance(self.type_, rty.Message) + assert isinstance(self.type_, ty.Message) return set(self.type_.field_combinations) @@ -2376,7 +2376,7 @@ def _matching_field_combinations(self, field_position: int) -> set[tuple[str, .. } def _check_for_invalid_fields(self) -> RecordFluxError: - assert isinstance(self.type_, rty.Message) + assert isinstance(self.type_, ty.Message) error = RecordFluxError() @@ -2396,7 +2396,7 @@ def _check_for_invalid_fields(self) -> RecordFluxError: field_type = self.type_.types[field] - if field_type == rty.OPAQUE: + if field_type == ty.OPAQUE: if not any( r.field == field and expr.type_.is_compatible(r.sdu) for r in self.type_.refinements @@ -2567,7 +2567,7 @@ def __init__( choices: Sequence[tuple[Sequence[ID | Number], Expr]], location: Location | None = None, ) -> None: - super().__init__(rty.common_type([e.type_ for _, e in choices]), location) + super().__init__(ty.common_type([e.type_ for _, e in choices]), location) self.expr = expr self.choices = choices @@ -2576,7 +2576,7 @@ def _update_str(self) -> None: self._str = intern(f"(case {self.expr} is\n{data})") def _check_enumeration(self) -> RecordFluxError: - assert isinstance(self.expr.type_, rty.Enumeration) + assert isinstance(self.expr.type_, ty.Enumeration) assert self.expr.type_.literals error = RecordFluxError() @@ -2628,7 +2628,7 @@ def _check_enumeration(self) -> RecordFluxError: return error def _check_integer(self) -> RecordFluxError: - assert isinstance(self.expr.type_, rty.Integer) + assert isinstance(self.expr.type_, ty.Integer) assert self.expr.type_.bounds.lower assert self.expr.type_.bounds.upper @@ -2694,11 +2694,11 @@ def _check_integer(self) -> RecordFluxError: def _check_type_subexpr(self) -> RecordFluxError: error = RecordFluxError() - result_type: rty.Type = rty.Any() + result_type: ty.Type = ty.Any() literals = [c for (choice, _) in self.choices for c in choice] for _, expr in self.choices: - error.extend(expr.check_type_instance(rty.Any).entries) + error.extend(expr.check_type_instance(ty.Any).entries) result_type = result_type.common_type(expr.type_) for i1, (_, e1) in enumerate(self.choices): @@ -2720,7 +2720,7 @@ def _check_type_subexpr(self) -> RecordFluxError: ), ) - error.extend(self.expr.check_type_instance(rty.Any).entries) + error.extend(self.expr.check_type_instance(ty.Any).entries) error.propagate() duplicates = [ @@ -2749,9 +2749,9 @@ def _check_type_subexpr(self) -> RecordFluxError: ), ) - if isinstance(self.expr.type_, rty.Enumeration): + if isinstance(self.expr.type_, ty.Enumeration): error.extend(self._check_enumeration().entries) - elif isinstance(self.expr.type_, rty.Integer): + elif isinstance(self.expr.type_, ty.Integer): error.extend(self._check_integer().entries) else: assert self.expr.location is not None diff --git a/rflx/expr_conv.py b/rflx/expr_conv.py index 3900fcc91..f43195960 100644 --- a/rflx/expr_conv.py +++ b/rflx/expr_conv.py @@ -3,7 +3,7 @@ from functools import singledispatch from typing import Generator -from rflx import ada, expr, ir, typing_ as rty +from rflx import ada, expr, ir, ty from rflx.error import fail from rflx.identifier import ID @@ -257,8 +257,8 @@ def _(expression: expr.BoolAssExpr, variable_id: Generator[ID, None, None]) -> i return ir.ComplexBoolExpr( [ *right.stmts, - ir.VarDecl(right_id, rty.BOOLEAN, None, origin=right_origin), - ir.Assign(right_id, right.expr, rty.BOOLEAN, origin=right_origin), + ir.VarDecl(right_id, ty.BOOLEAN, None, origin=right_origin), + ir.Assign(right_id, right.expr, ty.BOOLEAN, origin=right_origin), *left_stmts, ], getattr(ir, expression.__class__.__name__)( @@ -276,7 +276,7 @@ def _(expression: expr.Number, _variable_id: Generator[ID, None, None]) -> ir.Co @to_ir.register def _(expression: expr.Neg, variable_id: Generator[ID, None, None]) -> ir.ComplexIntExpr: - assert isinstance(expression.type_, rty.AnyInteger) + assert isinstance(expression.type_, ty.AnyInteger) inner_stmts, inner_expr = _to_ir_basic_int(expression.expr, variable_id) return ir.ComplexIntExpr(inner_stmts, ir.Neg(inner_expr, origin=expression)) @@ -286,7 +286,7 @@ def _(expression: expr.MathAssExpr, variable_id: Generator[ID, None, None]) -> i if len(expression.terms) == 0: return ir.ComplexIntExpr([], ir.IntVal(0, origin=expression)) - assert isinstance(expression.type_, rty.AnyInteger) + assert isinstance(expression.type_, ty.AnyInteger) if len(expression.terms) == 1: first_stmts, first_expr = _to_ir_basic_int(expression.terms[0], variable_id) @@ -307,7 +307,7 @@ def _(expression: expr.MathAssExpr, variable_id: Generator[ID, None, None]) -> i location=expression.terms[1].location, ) - assert isinstance(right_origin.type_, rty.AnyInteger) + assert isinstance(right_origin.type_, ty.AnyInteger) right = to_ir(right_origin, variable_id) @@ -328,7 +328,7 @@ def _(expression: expr.MathAssExpr, variable_id: Generator[ID, None, None]) -> i @to_ir.register def _(expression: expr.MathBinExpr, variable_id: Generator[ID, None, None]) -> ir.ComplexIntExpr: - assert isinstance(expression.type_, rty.AnyInteger) + assert isinstance(expression.type_, ty.AnyInteger) left_stmts, left_expr = _to_ir_basic_int(expression.left, variable_id) right_stmts, right_expr = _to_ir_basic_int(expression.right, variable_id) @@ -340,9 +340,9 @@ def _(expression: expr.MathBinExpr, variable_id: Generator[ID, None, None]) -> i @to_ir.register def _(expression: expr.Literal, _variable_id: Generator[ID, None, None]) -> ir.ComplexExpr: - assert isinstance(expression.type_, rty.Enumeration) + assert isinstance(expression.type_, ty.Enumeration) - if expression.type_ == rty.BOOLEAN: + if expression.type_ == ty.BOOLEAN: if expression.identifier == ID("True"): return ir.ComplexBoolExpr([], ir.BoolVal(value=True, origin=expression)) assert expression.identifier == ID("False") @@ -353,22 +353,22 @@ def _(expression: expr.Literal, _variable_id: Generator[ID, None, None]) -> ir.C @to_ir.register def _(expression: expr.Variable, _variable_id: Generator[ID, None, None]) -> ir.ComplexExpr: - if expression.type_ == rty.BOOLEAN: + if expression.type_ == ty.BOOLEAN: return ir.ComplexBoolExpr([], ir.BoolVar(expression.name, origin=expression)) - if isinstance(expression.type_, rty.Integer): + if isinstance(expression.type_, ty.Integer): return ir.ComplexIntExpr( [], ir.IntVar(expression.name, expression.type_, origin=expression), ) - assert isinstance(expression.type_, rty.Any) + assert isinstance(expression.type_, ty.Any) return ir.ComplexExpr([], ir.ObjVar(expression.name, expression.type_, origin=expression)) @to_ir.register def _(expression: expr.Attribute, variable_id: Generator[ID, None, None]) -> ir.ComplexExpr: - assert isinstance(expression.type_, rty.Any) + assert isinstance(expression.type_, ty.Any) prefix_stmts, prefix_expr = to_ir_basic_expr(expression.prefix, variable_id) @@ -379,10 +379,10 @@ def _(expression: expr.Attribute, variable_id: Generator[ID, None, None]) -> ir. @to_ir.register def _(expression: expr.Size, variable_id: Generator[ID, None, None]) -> ir.ComplexExpr: - assert isinstance(expression.type_, rty.Any) + assert isinstance(expression.type_, ty.Any) if isinstance(expression.prefix, expr.Selected): - assert isinstance(expression.prefix.prefix.type_, rty.Compound) + assert isinstance(expression.prefix.prefix.type_, ty.Compound) assert isinstance(expression.prefix.prefix, expr.Variable) return ir.ComplexExpr( [], @@ -413,41 +413,41 @@ def _attribute_to_ir( @_attribute_to_ir.register def _(expression: expr.Size, prefix: ID) -> ir.Expr: - assert isinstance(expression.prefix.type_, rty.Any) + assert isinstance(expression.prefix.type_, ty.Any) return ir.Size(prefix, expression.prefix.type_, origin=expression) @_attribute_to_ir.register def _(expression: expr.Length, prefix: ID) -> ir.Expr: - assert isinstance(expression.prefix.type_, rty.Any) + assert isinstance(expression.prefix.type_, ty.Any) return ir.Length(prefix, expression.prefix.type_, origin=expression) @_attribute_to_ir.register def _(expression: expr.First, prefix: ID) -> ir.Expr: - assert isinstance(expression.prefix.type_, rty.Any) + assert isinstance(expression.prefix.type_, ty.Any) return ir.First(prefix, expression.prefix.type_, origin=expression) @_attribute_to_ir.register def _(expression: expr.Last, prefix: ID) -> ir.Expr: - assert isinstance(expression.prefix.type_, rty.Any) + assert isinstance(expression.prefix.type_, ty.Any) return ir.Last(prefix, expression.prefix.type_, origin=expression) @_attribute_to_ir.register def _(expression: expr.ValidChecksum, prefix: ID) -> ir.Expr: - assert isinstance(expression.prefix.type_, rty.Any) + assert isinstance(expression.prefix.type_, ty.Any) return ir.ValidChecksum(prefix, expression.prefix.type_, origin=expression) @to_ir.register def _(expression: expr.Valid, variable_id: Generator[ID, None, None]) -> ir.ComplexExpr: - assert isinstance(expression.type_, rty.Any) + assert isinstance(expression.type_, ty.Any) if isinstance(expression.prefix, expr.Selected): assert isinstance(expression.prefix.prefix, expr.Variable) - assert isinstance(expression.prefix.prefix.type_, rty.Compound) + assert isinstance(expression.prefix.prefix.type_, ty.Compound) return ir.ComplexExpr( [], ir.FieldValid( @@ -466,17 +466,17 @@ def _(expression: expr.Valid, variable_id: Generator[ID, None, None]) -> ir.Comp @_attribute_to_ir.register def _(expression: expr.Valid, prefix: ID) -> ir.Expr: - assert isinstance(expression.prefix.type_, rty.Any) + assert isinstance(expression.prefix.type_, ty.Any) return ir.Valid(prefix, expression.prefix.type_, origin=expression) @to_ir.register def _(expression: expr.Present, variable_id: Generator[ID, None, None]) -> ir.ComplexExpr: - assert isinstance(expression.type_, rty.Any) + assert isinstance(expression.type_, ty.Any) if isinstance(expression.prefix, expr.Selected): assert isinstance(expression.prefix.prefix, expr.Variable) - assert isinstance(expression.prefix.prefix.type_, rty.Compound) + assert isinstance(expression.prefix.prefix.type_, ty.Compound) return ir.ComplexExpr( [], ir.FieldPresent( @@ -495,19 +495,19 @@ def _(expression: expr.Present, variable_id: Generator[ID, None, None]) -> ir.Co @_attribute_to_ir.register def _(expression: expr.Present, prefix: ID) -> ir.Expr: - assert isinstance(expression.prefix.type_, rty.Any) + assert isinstance(expression.prefix.type_, ty.Any) return ir.Present(prefix, expression.prefix.type_, origin=expression) @_attribute_to_ir.register def _(expression: expr.HasData, prefix: ID) -> ir.Expr: - assert isinstance(expression.prefix.type_, rty.Any) + assert isinstance(expression.prefix.type_, ty.Any) return ir.HasData(prefix, expression.prefix.type_, origin=expression) @to_ir.register def _(expression: expr.Head, variable_id: Generator[ID, None, None]) -> ir.ComplexExpr: - assert isinstance(expression.type_, rty.Any) + assert isinstance(expression.type_, ty.Any) if isinstance(expression.prefix, expr.Comprehension): comprehension = to_ir(expression.prefix, variable_id) @@ -532,14 +532,14 @@ def _(expression: expr.Head, variable_id: Generator[ID, None, None]) -> ir.Compl @_attribute_to_ir.register def _(expression: expr.Head, prefix: ID) -> ir.Expr: - assert isinstance(expression.prefix.type_, rty.Composite) - assert isinstance(expression.prefix.type_.element, rty.Any) + assert isinstance(expression.prefix.type_, ty.Composite) + assert isinstance(expression.prefix.type_.element, ty.Any) return ir.Head(prefix, expression.prefix.type_, origin=expression) @_attribute_to_ir.register def _(expression: expr.Opaque, prefix: ID) -> ir.Expr: - assert isinstance(expression.prefix.type_, (rty.Sequence, rty.Message)) + assert isinstance(expression.prefix.type_, (ty.Sequence, ty.Message)) return ir.Opaque(prefix, expression.prefix.type_, origin=expression) @@ -569,11 +569,11 @@ def _( @to_ir.register def _(expression: expr.Selected, variable_id: Generator[ID, None, None]) -> ir.ComplexExpr: - assert isinstance(expression.type_, rty.Any) - assert isinstance(expression.prefix.type_, rty.Compound) + assert isinstance(expression.type_, ty.Any) + assert isinstance(expression.prefix.type_, ty.Compound) stmts, msg = to_ir_basic_expr(expression.prefix, variable_id) assert isinstance(msg, ir.ObjVar) - if expression.type_ == rty.BOOLEAN: + if expression.type_ == ty.BOOLEAN: return ir.ComplexExpr( stmts, ir.BoolFieldAccess( @@ -583,7 +583,7 @@ def _(expression: expr.Selected, variable_id: Generator[ID, None, None]) -> ir.C origin=expression, ), ) - if isinstance(expression.type_, rty.Integer): + if isinstance(expression.type_, ty.Integer): return ir.ComplexExpr( stmts, ir.IntFieldAccess( @@ -593,7 +593,7 @@ def _(expression: expr.Selected, variable_id: Generator[ID, None, None]) -> ir.C origin=expression, ), ) - if isinstance(expression.type_, (rty.Enumeration, rty.Sequence)): + if isinstance(expression.type_, (ty.Enumeration, ty.Sequence)): return ir.ComplexExpr( stmts, ir.ObjFieldAccess( @@ -616,10 +616,10 @@ def _(expression: expr.Call, variable_id: Generator[ID, None, None]) -> ir.Compl arguments_stmts.extend(a_ir.stmts) arguments_exprs.append(a_ir.expr) - assert all(isinstance(t, rty.Any) for t in expression.argument_types) - argument_types = [t for t in expression.argument_types if isinstance(t, rty.Any)] + assert all(isinstance(t, ty.Any) for t in expression.argument_types) + argument_types = [t for t in expression.argument_types if isinstance(t, ty.Any)] - if expression.type_ is rty.BOOLEAN: + if expression.type_ is ty.BOOLEAN: return ir.ComplexExpr( arguments_stmts, ir.BoolCall( @@ -630,7 +630,7 @@ def _(expression: expr.Call, variable_id: Generator[ID, None, None]) -> ir.Compl ), ) - if isinstance(expression.type_, rty.Integer): + if isinstance(expression.type_, ty.Integer): return ir.ComplexExpr( arguments_stmts, ir.IntCall( @@ -642,7 +642,7 @@ def _(expression: expr.Call, variable_id: Generator[ID, None, None]) -> ir.Compl ), ) - assert isinstance(expression.type_, (rty.Enumeration, rty.Structure, rty.Message)) + assert isinstance(expression.type_, (ty.Enumeration, ty.Structure, ty.Message)) return ir.ComplexExpr( arguments_stmts, ir.ObjCall( @@ -673,7 +673,7 @@ def _( @to_ir.register def _(expression: expr.Aggregate, variable_id: Generator[ID, None, None]) -> ir.ComplexExpr: - assert isinstance(expression.type_, rty.Any) + assert isinstance(expression.type_, ty.Any) elements = [] stmts = [] @@ -728,11 +728,11 @@ def _(expression: expr.IfExpr, variable_id: Generator[ID, None, None]) -> ir.Com condition = expression.condition_expressions[0][0] condition_stmts, condition_expr = _to_ir_basic_bool(condition, variable_id) - assert condition.type_ is rty.BOOLEAN + assert condition.type_ is ty.BOOLEAN then_expression = expression.condition_expressions[0][1] - if then_expression.type_ is rty.BOOLEAN and expression.else_expression.type_ is rty.BOOLEAN: + if then_expression.type_ is ty.BOOLEAN and expression.else_expression.type_ is ty.BOOLEAN: then_expr = to_ir(then_expression, variable_id) else_expr = to_ir(expression.else_expression, variable_id) assert isinstance(then_expr, ir.ComplexBoolExpr) @@ -747,9 +747,9 @@ def _(expression: expr.IfExpr, variable_id: Generator[ID, None, None]) -> ir.Com ), ) - assert isinstance(expression.type_, rty.AnyInteger) - assert isinstance(then_expression.type_, rty.AnyInteger) - assert isinstance(expression.else_expression.type_, rty.AnyInteger) + assert isinstance(expression.type_, ty.AnyInteger) + assert isinstance(then_expression.type_, ty.AnyInteger) + assert isinstance(expression.else_expression.type_, ty.AnyInteger) then_expr = to_ir(then_expression, variable_id) else_expr = to_ir(expression.else_expression, variable_id) assert isinstance(then_expr, ir.ComplexIntExpr) @@ -787,7 +787,7 @@ def _( @to_ir.register def _(expression: expr.Conversion, variable_id: Generator[ID, None, None]) -> ir.ComplexExpr: - assert isinstance(expression.type_, rty.NamedTypeClass) + assert isinstance(expression.type_, ty.NamedTypeClass) argument = to_ir(expression.argument, variable_id) return ir.ComplexExpr( argument.stmts, @@ -824,7 +824,7 @@ def _(expression: expr.Comprehension, variable_id: Generator[ID, None, None]) -> @to_ir.register def _(expression: expr.MessageAggregate, variable_id: Generator[ID, None, None]) -> ir.ComplexExpr: - assert isinstance(expression.type_, rty.Message) + assert isinstance(expression.type_, ty.Message) field_values = {} stmts = [] for i, e in expression.field_values.items(): @@ -842,7 +842,7 @@ def _( expression: expr.DeltaMessageAggregate, variable_id: Generator[ID, None, None], ) -> ir.ComplexExpr: - assert isinstance(expression.type_, rty.Message) + assert isinstance(expression.type_, ty.Message) field_values = {} stmts = [] for i, e in expression.field_values.items(): @@ -857,7 +857,7 @@ def _( @to_ir.register def _(expression: expr.CaseExpr, variable_id: Generator[ID, None, None]) -> ir.ComplexExpr: - assert isinstance(expression.type_, rty.Any) + assert isinstance(expression.type_, ty.Any) expression_stmts, expression_expr = to_ir_basic_expr(expression.expr, variable_id) choices = [] @@ -867,11 +867,11 @@ def _(expression: expr.CaseExpr, variable_id: Generator[ID, None, None]) -> ir.C # TODO(eng/recordflux/RecordFlux#633): Check for unsupported case expressions in model assert not e_stmts cs: list[ir.BasicExpr] - if isinstance(expression.expr.type_, rty.Enumeration): + if isinstance(expression.expr.type_, ty.Enumeration): assert all(isinstance(c, ID) for c in choice) cs = [ir.EnumLit(c, expression.expr.type_) for c in choice if isinstance(c, ID)] else: - assert isinstance(expression.expr.type_, rty.AnyInteger) + assert isinstance(expression.expr.type_, ty.AnyInteger) assert all(isinstance(c, expr.Number) for c in choice) cs = [ir.IntVal(int(c)) for c in choice if isinstance(c, expr.Number)] choices.append((cs, e_expr)) @@ -891,7 +891,7 @@ def _to_ir_basic_int( expression: expr.Expr, variable_id: Generator[ID, None, None], ) -> tuple[list[ir.Stmt], ir.BasicIntExpr]: - assert isinstance(expression.type_, rty.AnyInteger) + assert isinstance(expression.type_, ty.AnyInteger) result = to_ir(expression, variable_id) if isinstance(result.expr, ir.BasicIntExpr): @@ -913,7 +913,7 @@ def _to_ir_basic_bool( expression: expr.Expr, variable_id: Generator[ID, None, None], ) -> tuple[list[ir.Stmt], ir.BasicBoolExpr]: - assert expression.type_ == rty.BOOLEAN + assert expression.type_ == ty.BOOLEAN result = to_ir(expression, variable_id) if isinstance(result.expr, ir.BasicBoolExpr): @@ -924,8 +924,8 @@ def _to_ir_basic_bool( result_expr = ir.BoolVar(result_id, origin=expression) result_stmts = [ *result.stmts, - ir.VarDecl(result_id, rty.BOOLEAN, None, origin=expression), - ir.Assign(result_id, result.expr, rty.BOOLEAN, origin=expression), + ir.VarDecl(result_id, ty.BOOLEAN, None, origin=expression), + ir.Assign(result_id, result.expr, ty.BOOLEAN, origin=expression), ] return (result_stmts, result_expr) @@ -944,19 +944,19 @@ def to_ir_basic_expr( if isinstance(result.expr, ir.BoolExpr): result_expr = ir.BoolVar(result_id, origin=expression) elif isinstance(result.expr, ir.IntExpr): - assert isinstance(expression.type_, rty.AnyInteger) + assert isinstance(expression.type_, ty.AnyInteger) result_expr = ir.IntVar(result_id, ir.to_integer(expression.type_), origin=expression) else: - assert isinstance(expression.type_, rty.Any) + assert isinstance(expression.type_, ty.Any) result_expr = ir.ObjVar(result_id, expression.type_, origin=expression) - if isinstance(result_expr.type_, rty.Aggregate): + if isinstance(result_expr.type_, ty.Aggregate): # TODO(eng/recordflux/RecordFlux#1497): Support comparisons of opaque fields result_stmts = [ # pragma: no cover *result.stmts, ir.VarDecl( result_id, - rty.OPAQUE, + ty.OPAQUE, ir.ComplexExpr([], result.expr), origin=expression, ), @@ -964,7 +964,7 @@ def to_ir_basic_expr( else: result_type = result_expr.type_ - assert isinstance(result_type, rty.NamedTypeClass) + assert isinstance(result_type, ty.NamedTypeClass) result_stmts = [ *result.stmts, diff --git a/rflx/expr_proof.py b/rflx/expr_proof.py index cb67ad184..d03c52af3 100644 --- a/rflx/expr_proof.py +++ b/rflx/expr_proof.py @@ -10,7 +10,7 @@ import z3 -from rflx import expr, typing_ as rty +from rflx import expr, ty from rflx.const import MP_CONTEXT from rflx.error import are_all_locations_present from rflx.identifier import ID @@ -294,7 +294,7 @@ def _(expression: expr.Literal) -> z3.ExprRef: @_to_z3.register def _(expression: expr.Variable) -> z3.ExprRef: - if expression.type_ == rty.BOOLEAN: + if expression.type_ == ty.BOOLEAN: return z3.Bool(expression.name) return z3.Int(expression.name) diff --git a/rflx/generator/allocator.py b/rflx/generator/allocator.py index c38418bca..a1d3392d9 100644 --- a/rflx/generator/allocator.py +++ b/rflx/generator/allocator.py @@ -4,7 +4,7 @@ from dataclasses import dataclass from itertools import zip_longest -from rflx import ir, typing_ as rty +from rflx import ir, ty from rflx.ada import ( Add, And, @@ -352,8 +352,8 @@ def _create_finalize_proc(self, slots: Sequence[NumberedSlotInfo]) -> UnitPart: ) @staticmethod - def _needs_allocation(type_: rty.Type) -> bool: - return isinstance(type_, (rty.Message, rty.Sequence)) + def _needs_allocation(type_: ty.Type) -> bool: + return isinstance(type_, (ty.Message, ty.Sequence)) def _allocate_global_slots( self, @@ -421,8 +421,8 @@ def determine_allocation_requirements( if ( isinstance(statement, ir.Assign) and isinstance(statement.expression, ir.Comprehension) - and isinstance(statement.expression.sequence.type_, rty.Sequence) - and isinstance(statement.expression.sequence.type_.element, rty.Message) + and isinstance(statement.expression.sequence.type_, ty.Sequence) + and isinstance(statement.expression.sequence.type_.element, ty.Message) and isinstance(statement.expression.sequence, (ir.Var, ir.FieldAccess)) ): if isinstance(statement.expression.sequence, ir.FieldAccess): diff --git a/rflx/generator/common.py b/rflx/generator/common.py index d23181e7b..ef8c5158a 100644 --- a/rflx/generator/common.py +++ b/rflx/generator/common.py @@ -6,7 +6,7 @@ from collections.abc import Callable from dataclasses import dataclass -from rflx import expr, expr_conv, ir, model, typing_ as rty +from rflx import expr, expr_conv, ir, model, ty from rflx.ada import ( TRUE, Add, @@ -83,7 +83,7 @@ def buffer_id(self) -> ID: return ID(f"{self.identifier}_Buffer") -def type_to_id(type_: rty.NamedType) -> ID: +def type_to_id(type_: ty.NamedType) -> ID: if type_.identifier.parent == BUILTINS_PACKAGE: return const.TYPES * type_.identifier.name @@ -95,7 +95,7 @@ def substitution( prefix: str, embedded: bool = False, public: bool = False, - target_type: rty.NamedType = rty.BASE_INTEGER, + target_type: ty.NamedType = ty.BASE_INTEGER, ) -> Callable[[expr.Expr], expr.Expr]: facts = substitution_facts(message, prefix, embedded, public, target_type) @@ -140,7 +140,7 @@ def byte_aggregate(aggregate: expr.Aggregate) -> expr.Aggregate: expr.ValueRange( expr.Call( const.TYPES_TO_INDEX, - rty.INDEX, + ty.INDEX, [ expr.Selected( expr.Indexed( @@ -153,7 +153,7 @@ def byte_aggregate(aggregate: expr.Aggregate) -> expr.Aggregate: ), expr.Call( const.TYPES_TO_INDEX, - rty.INDEX, + ty.INDEX, [ expr.Selected( expr.Indexed( @@ -170,7 +170,7 @@ def byte_aggregate(aggregate: expr.Aggregate) -> expr.Aggregate: ) equal_call = expr.Call( "Equal", - rty.BOOLEAN, + ty.BOOLEAN, [expr.Variable("Ctx"), expr.Variable(field.affixed_name), aggregate], ) return equal_call if isinstance(expression, expr.Equal) else expr.Not(equal_call) @@ -222,14 +222,14 @@ def substitution_facts( prefix: str, embedded: bool = False, public: bool = False, - target_type: rty.NamedType = rty.BASE_INTEGER, + target_type: ty.NamedType = ty.BASE_INTEGER, ) -> dict[expr.Name, expr.Expr]: def prefixed(name: str) -> expr.Expr: return expr.Variable(ID("Ctx") * name) if not embedded else expr.Variable(name) first = prefixed("First") last = ( - expr.Call("Written_Last", rty.BIT_LENGTH, [expr.Variable("Ctx")]) + expr.Call("Written_Last", ty.BIT_LENGTH, [expr.Variable("Ctx")]) if public else prefixed("Written_Last") ) @@ -239,7 +239,7 @@ def field_first(field: model.Field) -> expr.Expr: if public: return expr.Call( "Field_First", - rty.BIT_INDEX, + ty.BIT_INDEX, [expr.Variable("Ctx"), expr.Variable(field.affixed_name)], ) return expr.Selected(expr.Indexed(cursors, expr.Variable(field.affixed_name)), "First") @@ -248,7 +248,7 @@ def field_last(field: model.Field) -> expr.Expr: if public: return expr.Call( "Field_Last", - rty.BIT_LENGTH, + ty.BIT_LENGTH, [expr.Variable("Ctx"), expr.Variable(field.affixed_name)], ) return expr.Selected(expr.Indexed(cursors, expr.Variable(field.affixed_name)), "Last") @@ -257,7 +257,7 @@ def field_size(field: model.Field) -> expr.Expr: if public: return expr.Call( "Field_Size", - rty.BIT_LENGTH, + ty.BIT_LENGTH, [expr.Variable("Ctx"), expr.Variable(field.affixed_name)], ) return expr.Add( @@ -277,7 +277,7 @@ def parameter_value(parameter: model.Field, parameter_type: model.TypeDecl) -> e if parameter_type == model.BOOLEAN: return var if isinstance(parameter_type, model.Enumeration): - return expr.Call("To_Base_Integer", rty.BASE_INTEGER, [var]) + return expr.Call("To_Base_Integer", ty.BASE_INTEGER, [var]) if isinstance(parameter_type, model.Scalar): return var @@ -297,7 +297,7 @@ def field_value(field: model.Field, field_type: model.TypeDecl) -> expr.Expr: return expr.Call( "To_Base_Integer", - rty.BASE_INTEGER, + ty.BASE_INTEGER, [call], ) value = expr.Selected( @@ -305,7 +305,7 @@ def field_value(field: model.Field, field_type: model.TypeDecl) -> expr.Expr: "Value", ) if field_type == model.BOOLEAN: - return expr.Call("To_Actual", rty.BOOLEAN, [value]) + return expr.Call("To_Actual", ty.BOOLEAN, [value]) return value if isinstance(field_type, model.Scalar): @@ -321,7 +321,7 @@ def field_value(field: model.Field, field_type: model.TypeDecl) -> expr.Expr: assert False, f'unexpected type "{type(field_type).__name__}"' def type_conversion(expression: expr.Expr) -> expr.Expr: - if expression.type_ == rty.BOOLEAN: + if expression.type_ == ty.BOOLEAN: return expression return expr.Call(type_to_id(target_type), target_type, [expression]) @@ -342,7 +342,7 @@ def type_conversion(expression: expr.Expr) -> expr.Expr: }, **{ expr.Literal(l): type_conversion( - expr.Call("To_Base_Integer", rty.BASE_INTEGER, [expr.Variable(l)]), + expr.Call("To_Base_Integer", ty.BASE_INTEGER, [expr.Variable(l)]), ) for t in message.types.values() if isinstance(t, model.Enumeration) and t != model.BOOLEAN @@ -352,7 +352,7 @@ def type_conversion(expression: expr.Expr) -> expr.Expr: expr.Literal(t.package * l): type_conversion( expr.Call( "To_Base_Integer", - rty.BASE_INTEGER, + ty.BASE_INTEGER, [expr.Variable(prefix * t.package * l)], ), ) @@ -390,14 +390,14 @@ def link_property(link: model.Link, unique: bool) -> Expr: field_type.size if isinstance(field_type, model.Scalar) else link.size.substituted( - substitution(message, prefix, embedded, target_type=rty.BIT_LENGTH), + substitution(message, prefix, embedded, target_type=ty.BIT_LENGTH), ).simplified() ) first = ( prefixed("First") if link.source == model.INITIAL else link.first.substituted( - substitution(message, prefix, embedded, target_type=rty.BIT_INDEX), + substitution(message, prefix, embedded, target_type=ty.BIT_INDEX), ) .substituted( mapping={ @@ -828,8 +828,7 @@ def has_scalar_value_dependent_condition(message: model.Message) -> bool: True for l in message.structure for v in l.condition.variables() - if v.identifier == l.source.identifier - and isinstance(v.type_, (rty.Integer, rty.Enumeration)) + if v.identifier == l.source.identifier and isinstance(v.type_, (ty.Integer, ty.Enumeration)) ) @@ -986,7 +985,7 @@ def substituted(expression: expr.Expr) -> Expr: substitution( message, prefix, - target_type=rty.BIT_LENGTH, + target_type=ty.BIT_LENGTH, embedded=True, ), ).simplified(), @@ -1082,7 +1081,7 @@ def external_io_buffers(state_machine: ir.StateMachine) -> list[Message]: if ( isinstance(action, ir.ChannelStmt) and isinstance(action.expression, ir.Var) - and isinstance(action.expression.type_, rty.Message) + and isinstance(action.expression.type_, ty.Message) ) }, ) diff --git a/rflx/generator/generator.py b/rflx/generator/generator.py index 1c3cacb73..83de078ea 100644 --- a/rflx/generator/generator.py +++ b/rflx/generator/generator.py @@ -7,7 +7,7 @@ from functools import cached_property from pathlib import Path -from rflx import __version__, expr, expr_conv, typing_ as rty +from rflx import __version__, expr, expr_conv, ty from rflx.ada import ( FALSE, TRUE, @@ -1650,7 +1650,7 @@ def _refinement_conditions( pdu_identifier = self._prefix * refinement.pdu.identifier conditions: list[expr.Expr] = [ - expr.Call(pdu_identifier * "Has_Buffer", rty.BOOLEAN, [expr.Variable(pdu_context)]), + expr.Call(pdu_identifier * "Has_Buffer", ty.BOOLEAN, [expr.Variable(pdu_context)]), ] if null_sdu: @@ -1658,7 +1658,7 @@ def _refinement_conditions( [ expr.Call( pdu_identifier * "Well_Formed", - rty.BOOLEAN, + ty.BOOLEAN, [ expr.Variable(pdu_context), expr.Variable(pdu_identifier * refinement.field.affixed_name), @@ -1667,7 +1667,7 @@ def _refinement_conditions( expr.Not( expr.Call( pdu_identifier * "Present", - rty.BOOLEAN, + ty.BOOLEAN, [ expr.Variable(pdu_context), expr.Variable(pdu_identifier * refinement.field.affixed_name), @@ -1680,7 +1680,7 @@ def _refinement_conditions( conditions.append( expr.Call( pdu_identifier * "Present", - rty.BOOLEAN, + ty.BOOLEAN, [ expr.Variable(pdu_context), expr.Variable(pdu_identifier * refinement.field.affixed_name), @@ -1692,7 +1692,7 @@ def _refinement_conditions( [ expr.Call( pdu_identifier * "Valid", - rty.BOOLEAN, + ty.BOOLEAN, [ expr.Variable(pdu_context), expr.Variable(pdu_identifier * f.affixed_name), diff --git a/rflx/generator/message.py b/rflx/generator/message.py index 4069aef97..f7eb19d49 100644 --- a/rflx/generator/message.py +++ b/rflx/generator/message.py @@ -2,7 +2,7 @@ from collections import abc -from rflx import expr, expr_conv, typing_ as rty +from rflx import expr, expr_conv, ty from rflx.ada import ( FALSE, NULL, @@ -420,7 +420,7 @@ def create_valid_predecessors_invariant_function( if l.source in composite_fields else "Valid" ), - rty.BOOLEAN, + ty.BOOLEAN, [ expr.Indexed( expr.Variable("Cursors"), @@ -639,7 +639,7 @@ def create_field_first_internal_function(message: Message, prefix: str) -> UnitP def recursive_call(fld: Field) -> expr.Expr: return expr.Call( "Field_First_" + fld.name, - rty.BIT_INDEX, + ty.BIT_INDEX, [ expr.Variable("Cursors"), expr.Variable("First"), @@ -653,7 +653,7 @@ def recursive_call(fld: Field) -> expr.Expr: def field_size_internal_call(fld: expr.Variable) -> expr.Expr: return expr.Call( "Field_Size_Internal", - rty.BIT_LENGTH, + ty.BIT_LENGTH, [ expr.Variable("Cursors"), expr.Variable("First"), @@ -678,7 +678,7 @@ def link_first_expr(link: Link) -> tuple[expr.Expr, expr.Expr]: expr.AndThen( expr.Call( "Well_Formed", - rty.BOOLEAN, + ty.BOOLEAN, [ expr.Indexed( expr.Variable("Cursors"), @@ -708,7 +708,7 @@ def fld_first_expr(fld: Field) -> expr.Expr: return first_expr[0][1] return expr.IfExpr( first_expr, - expr.Call("RFLX_Types.Unreachable", rty.BOOLEAN), + expr.Call("RFLX_Types.Unreachable", ty.BOOLEAN), ) assert first_node != fld return expr.Add( @@ -2156,16 +2156,16 @@ def condition(field: Field, message: Message) -> Expr: mapping={ expr.Size(field.name): expr.Call( const.TYPES_BASE_INT, - rty.BASE_INTEGER, + ty.BASE_INTEGER, [expr.Variable("Size")], ), expr.Last(field.name): expr.Call( const.TYPES_BASE_INT, - rty.BASE_INTEGER, + ty.BASE_INTEGER, [ expr.Call( "Field_Last", - rty.BIT_LENGTH, + ty.BIT_LENGTH, [ expr.Variable("Ctx"), expr.Variable(field.affixed_name, immutable=True), @@ -2180,7 +2180,7 @@ def condition(field: Field, message: Message) -> Expr: if message.field_types[field] == BOOLEAN: c = c.substituted( lambda x: ( - expr.Call("To_Actual", rty.BOOLEAN, [expr.Variable("Val")]) + expr.Call("To_Actual", ty.BOOLEAN, [expr.Variable("Val")]) if x == expr.Variable(field.name) else x ), @@ -3682,14 +3682,14 @@ def func(expression: expr.Expr) -> expr.Expr: if isinstance(field_type, Enumeration): return expr.Call( "To_Base_Integer", - rty.BASE_INTEGER, + ty.BASE_INTEGER, [var], ) if isinstance(field_type, Scalar): return expr.Call( const.TYPES_BASE_INT, - rty.BASE_INTEGER, + ty.BASE_INTEGER, [var], ) @@ -3728,7 +3728,7 @@ def _create_to_context_procedure(prefix: str, message: Message) -> UnitPart: lambda x: ( expr.Call( const.TYPES_BIT_LENGTH, - rty.BIT_LENGTH, + ty.BIT_LENGTH, [expr.Variable("Struct" * x.identifier)], ) if isinstance(x, expr.Variable) @@ -3837,7 +3837,7 @@ def substitute(expression: expr.Expr) -> expr.Expr: ): return expr.Call( f"Field_Size_{expression.prefix.identifier}", - rty.BIT_LENGTH, + ty.BIT_LENGTH, [expr.Variable("Struct")], ) if ( @@ -3846,7 +3846,7 @@ def substitute(expression: expr.Expr) -> expr.Expr: ): return expr.Call( const.TYPES_BIT_LENGTH, - rty.BIT_LENGTH, + ty.BIT_LENGTH, [ expr.Selected( expr.Variable("Struct"), diff --git a/rflx/generator/parser.py b/rflx/generator/parser.py index 7dba035cb..0f8fd6e44 100644 --- a/rflx/generator/parser.py +++ b/rflx/generator/parser.py @@ -2,7 +2,7 @@ from collections.abc import Mapping, Sequence -from rflx import expr, expr_conv, typing_ as rty +from rflx import expr, expr_conv, ty from rflx.ada import ( TRUE, Add, @@ -1191,7 +1191,7 @@ def valid_message_condition(self, message: Message, well_formed: bool = False) - and isinstance(message.field_types[l.source], Composite) else "Valid" ), - rty.BOOLEAN, + ty.BOOLEAN, [ expr.Variable("Ctx"), expr.Variable(l.source.affixed_name, immutable=True), diff --git a/rflx/generator/serializer.py b/rflx/generator/serializer.py index ee1c86844..b4975877f 100644 --- a/rflx/generator/serializer.py +++ b/rflx/generator/serializer.py @@ -3,7 +3,7 @@ from collections.abc import Mapping from enum import Enum -from rflx import expr, expr_conv, typing_ as rty +from rflx import expr, expr_conv, ty from rflx.ada import ( TRUE, Add, @@ -121,7 +121,7 @@ def create_valid_size_function(self, message: Message) -> UnitPart: common.substitution( message, self.prefix, - target_type=rty.BIT_LENGTH, + target_type=ty.BIT_LENGTH, ), ).simplified(), ), diff --git a/rflx/generator/state_machine.py b/rflx/generator/state_machine.py index f8ada7d03..ac2857f0f 100644 --- a/rflx/generator/state_machine.py +++ b/rflx/generator/state_machine.py @@ -1,6 +1,6 @@ from __future__ import annotations -import typing as ty +import typing from collections.abc import Callable, Iterable, Mapping, Sequence from dataclasses import dataclass, field as dataclass_field from functools import partial, singledispatchmethod @@ -8,7 +8,7 @@ from typing_extensions import Self -from rflx import ada, ir, model, typing_ as rty +from rflx import ada, ir, model, ty from rflx.ada import ( FALSE, TRUE, @@ -244,22 +244,22 @@ def _create_function(self, function: ir.FuncDecl) -> Sequence[SubprogramDeclarat ), ] - if function.type_ == rty.Undefined(): + if function.type_ == ty.Undefined(): fatal_fail( f'return type of function "{function.identifier}" is undefined', location=function.location, ) - if function.type_ == rty.OPAQUE: + if function.type_ == ty.OPAQUE: fatal_fail( f'Opaque as return type of function "{function.identifier}" not allowed', location=function.location, ) - if isinstance(function.type_, rty.Sequence): + if isinstance(function.type_, ty.Sequence): fail( f'sequence as return type of function "{function.identifier}" not yet supported', location=function.location, ) - if isinstance(function.type_, rty.Message): + if isinstance(function.type_, ty.Message): if not function.type_.is_definite: fatal_fail( "non-definite message" @@ -267,7 +267,7 @@ def _create_function(self, function: ir.FuncDecl) -> Sequence[SubprogramDeclarat location=function.location, ) if any( - isinstance(field_type, rty.Sequence) and field_type != rty.OPAQUE + isinstance(field_type, ty.Sequence) and field_type != ty.OPAQUE for field_type in function.type_.types.values() ): fail( @@ -279,7 +279,7 @@ def _create_function(self, function: ir.FuncDecl) -> Sequence[SubprogramDeclarat self._state_machine_context.referenced_types.append(function.return_type) for a in function.arguments: - if isinstance(a.type_, rty.Sequence) and a.type_ != rty.OPAQUE: + if isinstance(a.type_, ty.Sequence) and a.type_ != ty.OPAQUE: fail( f'sequence as parameter of function "{function.identifier}" not yet supported', location=function.location, @@ -289,13 +289,13 @@ def _create_function(self, function: ir.FuncDecl) -> Sequence[SubprogramDeclarat [a.identifier], ( const.TYPES_BYTES - if a.type_ == rty.OPAQUE + if a.type_ == ty.OPAQUE else ( ID("Boolean") - if a.type_ == rty.BOOLEAN + if a.type_ == ty.BOOLEAN else ( self._prefix * a.type_identifier * "Structure" - if isinstance(a.type_, rty.Message) + if isinstance(a.type_, ty.Message) else self._prefix * a.type_identifier ) ) @@ -303,7 +303,7 @@ def _create_function(self, function: ir.FuncDecl) -> Sequence[SubprogramDeclarat ), ) - assert isinstance(a.type_, (rty.Integer, rty.Enumeration, rty.Message, rty.Sequence)) + assert isinstance(a.type_, (ty.Integer, ty.Enumeration, ty.Message, ty.Sequence)) self._state_machine_context.referenced_types.append(a.type_.identifier) @@ -312,10 +312,10 @@ def _create_function(self, function: ir.FuncDecl) -> Sequence[SubprogramDeclarat [ID("RFLX_Result")], ( self._prefix * function.return_type * "Structure" - if isinstance(function.type_, rty.Message) + if isinstance(function.type_, ty.Message) else ( ID("Boolean") - if function.type_ == rty.BOOLEAN + if function.type_ == ty.BOOLEAN else self._prefix * function.return_type ) ), @@ -331,7 +331,7 @@ def _create_function(self, function: ir.FuncDecl) -> Sequence[SubprogramDeclarat [ *( [Precondition(Not(Constrained("RFLX_Result")))] - if isinstance(function.type_, rty.Enumeration) + if isinstance(function.type_, ty.Enumeration) and function.type_.always_valid else [] ), @@ -585,7 +585,7 @@ def _create_use_clauses_body(self) -> list[Declaration]: ), ] if any( - type_identifier == rty.BASE_INTEGER.identifier + type_identifier == ty.BASE_INTEGER.identifier for type_identifier in self._state_machine_context.used_types_body ) else [] @@ -617,7 +617,7 @@ def is_global(identifier: ID) -> bool: composite_globals = [ d for d in self._state_machine.declarations - if isinstance(d, ir.VarDecl) and isinstance(d.type_, (rty.Message, rty.Sequence)) + if isinstance(d, ir.VarDecl) and isinstance(d.type_, (ty.Message, ty.Sequence)) ] channel_reads = self._channel_io(self._state_machine, read=True) @@ -734,7 +734,7 @@ def _channel_io( or (isinstance(action, ir.Write) and write) ) and isinstance(action.expression, ir.Var) - and isinstance(action.expression.type_, rty.Message) + and isinstance(action.expression.type_, ty.Message) ): channels[action.channel].append( ChannelAccess( @@ -898,8 +898,8 @@ def _create_uninitialized_function( ), ) for declaration in composite_globals - if isinstance(declaration.type_, (rty.Message, rty.Sequence)) - and declaration.type_ != rty.OPAQUE + if isinstance(declaration.type_, (ty.Message, ty.Sequence)) + and declaration.type_ != ty.OPAQUE ], *( [ @@ -1123,7 +1123,7 @@ def _create_states( ] for d in declarations: - if isinstance(d.type_, (rty.Message, rty.Sequence)) and d.type_ != rty.OPAQUE: + if isinstance(d.type_, (ty.Message, ty.Sequence)) and d.type_ != ty.OPAQUE: self._state_machine_context.used_packages_body.append( const.TYPES_OPERATORS_PACKAGE, ) @@ -1478,7 +1478,7 @@ def _create_reset_messages_before_write_procedure( if ( isinstance(action, ir.Read) and isinstance(action.expression, ir.Var) - and isinstance(action.expression.type_, rty.Message) + and isinstance(action.expression.type_, ty.Message) ) ], ) @@ -1597,7 +1597,7 @@ def _create_in_io_state_function(state_machine: ir.StateMachine) -> UnitPart: if ( isinstance(action, (ir.Read, ir.Write)) and isinstance(action.expression, ir.Var) - and isinstance(action.expression.type_, rty.Message) + and isinstance(action.expression.type_, ty.Message) ) ) ] @@ -2920,7 +2920,7 @@ def always_true(_: ID) -> bool: declaration.expression, state_machine_global=state_machine_global, ) - if isinstance(declaration.type_, (rty.Message, rty.Sequence)): + if isinstance(declaration.type_, (ty.Message, ty.Sequence)): has_composite_declarations |= True if state_machine_global and self._allocator.required: @@ -3006,7 +3006,7 @@ def _state_action( def _declare( # noqa: PLR0912, PLR0913 self, identifier: ID, - type_: rty.Type, + type_: ty.Type, is_global: Callable[[ID], bool], alloc_id: Location | None, expression: ir.ComplexExpr | None = None, @@ -3021,7 +3021,7 @@ def _declare( # noqa: PLR0912, PLR0913 location=expression.expr.location, ) - if type_ == rty.OPAQUE: + if type_ == ty.OPAQUE: initialization = None object_type: Expr = Variable(const.TYPES_BYTES) aspects: list[Aspect] = [] @@ -3058,13 +3058,13 @@ def _declare( # noqa: PLR0912, PLR0913 ), ) - elif isinstance(type_, (rty.UniversalInteger, rty.Integer, rty.Enumeration)): + elif isinstance(type_, (ty.UniversalInteger, ty.Integer, ty.Enumeration)): result.global_declarations.append( ObjectDeclaration( [identifier], ( self._ada_type(type_.identifier) - if isinstance(type_, rty.NamedTypeClass) + if isinstance(type_, ty.NamedTypeClass) else const.TYPES_BASE_INT ), ( @@ -3091,7 +3091,7 @@ def _declare( # noqa: PLR0912, PLR0913 location=expression.expr.location, ) - elif isinstance(type_, (rty.Message, rty.Sequence)): + elif isinstance(type_, (ty.Message, ty.Sequence)): if expression is not None: fail( f"initialization for {type_} not yet supported", @@ -3120,9 +3120,9 @@ def _declare( # noqa: PLR0912, PLR0913 { n: First(self._ada_type(t.identifier)) for n, t in type_.parameter_types.items() - if isinstance(t, (rty.Integer, rty.Enumeration)) + if isinstance(t, (ty.Integer, ty.Enumeration)) } - if isinstance(type_, rty.Message) + if isinstance(type_, ty.Message) else None ), ), @@ -3133,7 +3133,7 @@ def _declare( # noqa: PLR0912, PLR0913 result.finalization.extend( self._free_context_buffer(identifier, type_identifier, is_global, alloc_id), ) - elif isinstance(type_, rty.Structure): + elif isinstance(type_, ty.Structure): # Messages with initialization clauses are not optimized assert expression is None @@ -3153,10 +3153,10 @@ def _declare( # noqa: PLR0912, PLR0913 location=identifier.location, ) - assert isinstance(type_, (rty.NamedTypeClass, rty.UniversalInteger)), type_ + assert isinstance(type_, (ty.NamedTypeClass, ty.UniversalInteger)), type_ type_identifier = ( - type_.identifier if isinstance(type_, rty.NamedTypeClass) else const.TYPES_BASE_INT + type_.identifier if isinstance(type_, ty.NamedTypeClass) else const.TYPES_BASE_INT ) if state_machine_global: self._state_machine_context.referenced_types.append(type_identifier) @@ -3168,7 +3168,7 @@ def _declare( # noqa: PLR0912, PLR0913 def _assign( # noqa: PLR0913 self, target: ID, - target_type: rty.Type, + target_type: ty.Type, expression: ir.Expr, exception_handler: ExceptionHandler, is_global: Callable[[ID], bool], @@ -3192,7 +3192,7 @@ def _assign( # noqa: PLR0913 ) if ( - isinstance(target_type, rty.Message) + isinstance(target_type, ty.Message) and isinstance(expression, ir.Var) and expression.identifier == target ): @@ -3216,7 +3216,7 @@ def _assign( # noqa: PLR0913 ) if isinstance(expression, ir.Comprehension): - assert isinstance(target_type, rty.Sequence) + assert isinstance(target_type, ty.Sequence) return self._assign_to_comprehension( target, target_type, @@ -3259,12 +3259,12 @@ def _assign( # noqa: PLR0913 ir.CaseExpr, ), ) and ( - isinstance(expression.type_, (rty.AnyInteger, rty.Enumeration, rty.Aggregate)) - or expression.type_ == rty.OPAQUE + isinstance(expression.type_, (ty.AnyInteger, ty.Enumeration, ty.Aggregate)) + or expression.type_ == ty.OPAQUE ): assert isinstance( target_type, - (rty.Integer, rty.Enumeration, rty.Message, rty.Sequence), + (ty.Integer, ty.Enumeration, ty.Message, ty.Sequence), ), target_type return [ Assignment( @@ -3275,7 +3275,7 @@ def _assign( # noqa: PLR0913 if isinstance(expression, ir.Var) and isinstance( expression.type_, - (rty.Message, rty.Sequence), + (ty.Message, ty.Sequence), ): _unsupported_expression(expression, "in assignment") @@ -3287,7 +3287,7 @@ def _assign_to_field_access( field_access: ir.FieldAccess, is_global: Callable[[ID], bool], ) -> Sequence[Statement]: - if isinstance(field_access.message_type, rty.Structure): + if isinstance(field_access.message_type, ty.Structure): return [ Assignment( Variable(variable_id(target, is_global)), @@ -3295,15 +3295,15 @@ def _assign_to_field_access( ), ] - assert isinstance(field_access.message_type, rty.Message) + assert isinstance(field_access.message_type, ty.Message) message_type_id = field_access.message_type.identifier message_context = context_id(field_access.message, is_global) field = field_access.field if ( - isinstance(field_access.type_, (rty.AnyInteger, rty.Enumeration)) - or field_access.type_ == rty.OPAQUE + isinstance(field_access.type_, (ty.AnyInteger, ty.Enumeration)) + or field_access.type_ == ty.OPAQUE ): if field in field_access.message_type.parameter_types: return [ @@ -3323,7 +3323,7 @@ def _assign_to_field_access( ), ] - if isinstance(field_access.type_, rty.Sequence): + if isinstance(field_access.type_, ty.Sequence): # Eng/RecordFlux/RecordFlux#577 # The relevant buffer part has to be copied from the message context into a # sequence context. With the current implementation the sequence needs to @@ -3348,7 +3348,7 @@ def _assign_to_message_aggregate( exception_handler: ExceptionHandler, is_global: Callable[[ID], bool], ) -> Sequence[Statement]: - assert isinstance(message_aggregate.type_, rty.Message) + assert isinstance(message_aggregate.type_, ty.Message) self._state_machine_context.used_types_body.append(const.TYPES_BIT_LENGTH) @@ -3359,7 +3359,7 @@ def _assign_to_message_aggregate( for f, v in message_aggregate.field_values.items() if f in message_aggregate.type_.parameter_types for t in [message_aggregate.type_.parameter_types[f]] - if isinstance(t, (rty.Integer, rty.Enumeration)) + if isinstance(t, (ty.Integer, ty.Enumeration)) ] return [ @@ -3388,7 +3388,7 @@ def _assign_to_delta_message_aggregate( exception_handler: ExceptionHandler, is_global: Callable[[ID], bool], ) -> Sequence[Statement]: - assert isinstance(delta_message_aggregate.type_, rty.Message) + assert isinstance(delta_message_aggregate.type_, ty.Message) self._state_machine_context.used_types_body.append(const.TYPES_BIT_LENGTH) @@ -3438,7 +3438,7 @@ def _assign_to_head( # noqa: PLR0913 state: ID, alloc_id: Location | None, ) -> Sequence[Statement]: - if not isinstance(head.type_, (rty.Integer, rty.Enumeration, rty.Message)): + if not isinstance(head.type_, (ty.Integer, ty.Enumeration, ty.Message)): fatal_fail( f"unexpected sequence element type {head.type_}" f' for "{head}" in assignment of "{target}"', @@ -3463,11 +3463,11 @@ def _assign_to_find( # noqa: PLR0913 state: ID, alloc_id: Location | None, ) -> Sequence[Statement]: - assert isinstance(find.sequence.type_, rty.Sequence) + assert isinstance(find.sequence.type_, ty.Sequence) sequence_type_id = find.sequence.type_.identifier sequence_element_type = find.sequence.type_.element - if isinstance(sequence_element_type, rty.Message): + if isinstance(sequence_element_type, ty.Message): if isinstance(find.sequence, ir.Var): sequence_id = ID(f"{find.sequence}") comprehension_sequence_id = copy_id(sequence_id) @@ -3485,13 +3485,13 @@ def _assign_to_find( # noqa: PLR0913 def comprehension_statements( local_exception_handler: ExceptionHandler, ) -> list[Statement]: - assert isinstance(find.type_, (rty.Integer, rty.Enumeration, rty.Message)) + assert isinstance(find.type_, (ty.Integer, ty.Enumeration, ty.Message)) assert isinstance( sequence_element_type, - (rty.Message, rty.Integer, rty.Enumeration), + (ty.Message, ty.Integer, ty.Enumeration), ) default_assignment = [] - if isinstance(find.type_, (rty.Integer, rty.Enumeration)): + if isinstance(find.type_, (ty.Integer, ty.Enumeration)): default_assignment = [Assignment(target, First(find.type_.identifier))] return [ Declare( @@ -3534,7 +3534,7 @@ def comprehension_statements( alloc_id, ) if isinstance(find.sequence, ir.FieldAccess): - assert isinstance(selected.message_type, rty.Message) + assert isinstance(selected.message_type, ty.Message) message_id = selected.message message_type = selected.message_type.identifier message_field = selected.field @@ -3579,8 +3579,8 @@ def _assign_to_head_sequence( # noqa: PLR0913 state: ID, alloc_id: Location | None, ) -> Sequence[Statement]: - assert isinstance(head.prefix_type, rty.Sequence) - assert isinstance(head.type_, (rty.Integer, rty.Enumeration, rty.Message)) + assert isinstance(head.prefix_type, ty.Sequence) + assert isinstance(head.type_, (ty.Integer, ty.Enumeration, ty.Message)) target_type = head.type_.identifier sequence_type = head.prefix_type.identifier @@ -3588,7 +3588,7 @@ def _assign_to_head_sequence( # noqa: PLR0913 sequence_context = context_id(sequence_id, is_global) sequence_identifier = ID(f"{head.prefix}") - if isinstance(head.type_, (rty.Integer, rty.Enumeration)): + if isinstance(head.type_, (ty.Integer, ty.Enumeration)): return [ # TODO(eng/recordflux/RecordFlux#1742): Move check into IR self._raise_exception_if( @@ -3621,7 +3621,7 @@ def _assign_to_head_sequence( # noqa: PLR0913 ), ] - assert isinstance(head.type_, rty.Message) + assert isinstance(head.type_, ty.Message) self._state_machine_context.used_types_body.append(const.TYPES_LENGTH) self._state_machine_context.used_packages_body.append(const.TYPES_OPERATORS_PACKAGE) @@ -3757,15 +3757,15 @@ def statements(exception_handler: ExceptionHandler) -> list[Statement]: def _assign_to_comprehension( # noqa: PLR0913 self, target: ID, - target_type: rty.Sequence, + target_type: ty.Sequence, comprehension: ir.Comprehension, exception_handler: ExceptionHandler, is_global: Callable[[ID], bool], state: ID, alloc_id: Location | None, ) -> Sequence[Statement]: - assert isinstance(comprehension.type_, (rty.Sequence, rty.Aggregate)) - assert isinstance(comprehension.sequence.type_, rty.Sequence) + assert isinstance(comprehension.type_, (ty.Sequence, ty.Aggregate)) + assert isinstance(comprehension.sequence.type_, ty.Sequence) self._state_machine_context.used_types_body.append(const.TYPES_BIT_LENGTH) @@ -3778,7 +3778,7 @@ def _assign_to_comprehension( # noqa: PLR0913 reset_target = CallStatement(target_type.identifier * "Reset", [Variable(target_context)]) - if isinstance(sequence_element_type, rty.Message): + if isinstance(sequence_element_type, ty.Message): iterator_type_id = sequence_element_type.identifier if isinstance(comprehension.sequence, ir.Var): @@ -3817,7 +3817,7 @@ def statements(local_exception_handler: ExceptionHandler) -> list[Statement]: if isinstance(comprehension.sequence, ir.FieldAccess): field_access = comprehension.sequence - assert isinstance(field_access.message_type, rty.Message) + assert isinstance(field_access.message_type, ty.Message) message_id = ID(field_access.message) message_type = field_access.message_type.identifier @@ -3892,7 +3892,7 @@ def _assign_to_call( target_id = variable_id(target, is_global) message_id = context_id(target, is_global) - if isinstance(call_expr.type_, rty.Message): + if isinstance(call_expr.type_, ty.Message): type_identifier = self._ada_type(call_expr.type_.identifier) local_declarations.append( ObjectDeclaration( @@ -3928,7 +3928,7 @@ def _assign_to_call( ], ) - elif isinstance(call_expr.type_, rty.Structure): + elif isinstance(call_expr.type_, ty.Structure): type_identifier = self._ada_type(call_expr.type_.identifier) post_call.append( # TODO(eng/recordflux/RecordFlux#1742): Move check into IR @@ -3960,7 +3960,7 @@ def _assign_to_call( ): _unsupported_expression(a, "as function argument") - if isinstance(a, ir.Var) and isinstance(a.type_, rty.Message): + if isinstance(a, ir.Var) and isinstance(a.type_, ty.Message): type_identifier = self._ada_type(a.type_.identifier) local_declarations.append( ObjectDeclaration( @@ -3978,8 +3978,8 @@ def _assign_to_call( ), ) arguments.append(self._to_ada_expr(a, is_global)) - elif isinstance(a, ir.FieldAccess) and a.type_ == rty.OPAQUE: - assert isinstance(a.type_, rty.Sequence) + elif isinstance(a, ir.FieldAccess) and a.type_ == ty.OPAQUE: + assert isinstance(a.type_, ty.Sequence) self._state_machine_context.used_packages_body.append(const.TYPES_OPERATORS_PACKAGE) argument_name = f"RFLX_{call_expr.identifier}_Arg_{i}_{a.message}" argument_length = f"{argument_name}_Length" @@ -4055,7 +4055,7 @@ def _assign_to_call( arguments.append(argument) elif isinstance(a, ir.Opaque) and isinstance( a.prefix_type, - (rty.Message, rty.Sequence), + (ty.Message, ty.Sequence), ): self._state_machine_context.used_types_body.append(const.TYPES_LENGTH) self._state_machine_context.used_packages_body.append(const.TYPES_OPERATORS_PACKAGE) @@ -4125,7 +4125,7 @@ def _assign_to_call( type_identifier * ( "Well_Formed_Message" - if isinstance(a.prefix_type, rty.Message) + if isinstance(a.prefix_type, ty.Message) else "Valid" ), [Variable(context)], @@ -4166,7 +4166,7 @@ def _assign_to_call( ), ] - if isinstance(call_expr.type_, rty.Structure): + if isinstance(call_expr.type_, ty.Structure): return [*call, *post_call] return call @@ -4178,7 +4178,7 @@ def _assign_to_conversion( exception_handler: ExceptionHandler, is_global: Callable[[ID], bool], ) -> Sequence[Statement]: - if not isinstance(conversion.type_, rty.Message): + if not isinstance(conversion.type_, ty.Message): return [ Assignment( variable_id(target, is_global), @@ -4186,10 +4186,10 @@ def _assign_to_conversion( ), ] - assert isinstance(conversion.type_, rty.Message) + assert isinstance(conversion.type_, ty.Message) assert isinstance(conversion.argument, ir.FieldAccess), f"{target}, {conversion}" - assert conversion.argument.type_ == rty.OPAQUE - assert isinstance(conversion.argument.message_type, rty.Message) + assert conversion.argument.type_ == ty.OPAQUE + assert isinstance(conversion.argument.message_type, ty.Message) pdu = conversion.argument.message_type sdu = conversion.type_ @@ -4272,12 +4272,12 @@ def _message_field_assign( # noqa: PLR0913 self, target: ID, target_field: ID, - message_type: rty.Type, + message_type: ty.Type, value: ir.Expr, exception_handler: ExceptionHandler, is_global: Callable[[ID], bool], ) -> Sequence[Statement]: - assert isinstance(message_type, rty.Message) + assert isinstance(message_type, ty.Message) target_context = context_id(target, is_global) @@ -4327,7 +4327,7 @@ def _append( exception_handler: ExceptionHandler, is_global: Callable[[ID], bool], ) -> Sequence[Statement]: - assert isinstance(append.type_, rty.Sequence) + assert isinstance(append.type_, ty.Sequence) self._state_machine_context.used_types_body.append(const.TYPES_BIT_LENGTH) @@ -4369,7 +4369,7 @@ def check( ), ] - if isinstance(append.type_.element, (rty.Integer, rty.Enumeration)): + if isinstance(append.type_.element, (ty.Integer, ty.Enumeration)): if isinstance(append.expression, (ir.Var, ir.EnumLit, ir.IntVal)): sequence_type = append.type_.identifier sequence_context = context_id(append.sequence, is_global) @@ -4388,7 +4388,7 @@ def check( _unsupported_expression(append.expression, "in Append statement") - if isinstance(append.type_.element, rty.Message): + if isinstance(append.type_.element, ty.Message): sequence_type = append.type_.identifier sequence_context = context_id(append.sequence, is_global) element_type = append.type_.element.identifier @@ -4445,7 +4445,7 @@ def check( def _read(read: ir.Read, is_global: Callable[[ID], bool]) -> Sequence[Statement]: if not isinstance(read.expression, ir.Var) or not isinstance( read.expression.type_, - rty.Message, + ty.Message, ): _unsupported_expression(read.expression, "in Read statement") @@ -4463,7 +4463,7 @@ def _write( ) -> Sequence[Statement]: if not isinstance(write.expression, ir.Var) or not isinstance( write.expression.type_, - rty.Message, + ty.Message, ): _unsupported_expression(write.expression, "in Write statement") @@ -4510,7 +4510,7 @@ def _reset( reset: ir.Reset, is_global: Callable[[ID], bool], ) -> Sequence[Statement]: - assert isinstance(reset.type_, (rty.Message, rty.Sequence)) + assert isinstance(reset.type_, (ty.Message, ty.Sequence)) target_type = reset.type_.identifier target_context = context_id(reset.identifier, is_global) @@ -4527,16 +4527,16 @@ def _to_ada_expr(self, expression: ir.Expr, is_global: Callable[[ID], bool]) -> raise NotImplementedError(f"{type(expression).__name__} is not yet supported") @_to_ada_expr.register - def _(self, expression: ir.Var, is_global: ty.Callable[[ID], bool]) -> Expr: + def _(self, expression: ir.Var, is_global: typing.Callable[[ID], bool]) -> Expr: # TODO(eng/recordflux/RecordFlux#1359): Replace typing.Callable by collections.abc.Callable return Variable(variable_id(expression.identifier, is_global)) @_to_ada_expr.register - def _(self, expression: ir.IntVar, is_global: ty.Callable[[ID], bool]) -> Expr: + def _(self, expression: ir.IntVar, is_global: typing.Callable[[ID], bool]) -> Expr: return Variable(variable_id(expression.identifier, is_global)) @_to_ada_expr.register - def _(self, expression: ir.EnumLit, _is_global: ty.Callable[[ID], bool]) -> Expr: + def _(self, expression: ir.EnumLit, _is_global: typing.Callable[[ID], bool]) -> Expr: literal = Literal(expression.identifier) if expression.type_.always_valid: @@ -4545,38 +4545,38 @@ def _(self, expression: ir.EnumLit, _is_global: ty.Callable[[ID], bool]) -> Expr return literal @_to_ada_expr.register - def _(self, expression: ir.IntVal, _is_global: ty.Callable[[ID], bool]) -> Expr: + def _(self, expression: ir.IntVal, _is_global: typing.Callable[[ID], bool]) -> Expr: return Number(expression.value) @_to_ada_expr.register - def _(self, expression: ir.BoolVal, _is_global: ty.Callable[[ID], bool]) -> Expr: + def _(self, expression: ir.BoolVal, _is_global: typing.Callable[[ID], bool]) -> Expr: return Literal(str(expression.value)) @_to_ada_expr.register - def _(self, expression: ir.First, _is_global: ty.Callable[[ID], bool]) -> Expr: - assert isinstance(expression.type_, (rty.AnyInteger, rty.Enumeration)) + def _(self, expression: ir.First, _is_global: typing.Callable[[ID], bool]) -> Expr: + assert isinstance(expression.type_, (ty.AnyInteger, ty.Enumeration)) return First(self._ada_type(expression.prefix)) @_to_ada_expr.register - def _(self, expression: ir.Last, _is_global: ty.Callable[[ID], bool]) -> Expr: - assert isinstance(expression.type_, (rty.AnyInteger, rty.Enumeration)) + def _(self, expression: ir.Last, _is_global: typing.Callable[[ID], bool]) -> Expr: + assert isinstance(expression.type_, (ty.AnyInteger, ty.Enumeration)) return Last(self._ada_type(expression.prefix)) @_to_ada_expr.register - def _(self, expression: ir.Valid, is_global: ty.Callable[[ID], bool]) -> Expr: - if isinstance(expression.prefix_type, rty.Message): + def _(self, expression: ir.Valid, is_global: typing.Callable[[ID], bool]) -> Expr: + if isinstance(expression.prefix_type, ty.Message): return Call( expression.prefix_type.identifier * "Well_Formed_Message", [Variable(context_id(expression.prefix, is_global))], ) - if isinstance(expression.prefix_type, rty.Structure): + if isinstance(expression.prefix_type, ty.Structure): return Call( expression.prefix_type.identifier * "Valid_Structure", [Variable(expression.prefix)], ) - if isinstance(expression.prefix_type, rty.Sequence): + if isinstance(expression.prefix_type, ty.Sequence): return Call( expression.prefix_type.identifier * "Valid", [Variable(context_id(expression.prefix, is_global))], @@ -4588,25 +4588,25 @@ def _convert_types_of_int_relation(self, expression: ir.Relation) -> ir.Relation if ( isinstance(expression.left, ir.IntExpr) and isinstance(expression.right, ir.IntExpr) - and isinstance(expression.left.type_, rty.Integer) - and isinstance(expression.right.type_, rty.Integer) + and isinstance(expression.left.type_, ty.Integer) + and isinstance(expression.right.type_, ty.Integer) and ( - expression.left.type_ != rty.BASE_INTEGER - or expression.right.type_ != rty.BASE_INTEGER + expression.left.type_ != ty.BASE_INTEGER + or expression.right.type_ != ty.BASE_INTEGER ) ): - self._state_machine_context.used_types_body.append(rty.BASE_INTEGER.identifier) - self._state_machine_context.referenced_types_body.append(rty.BASE_INTEGER.identifier) + self._state_machine_context.used_types_body.append(ty.BASE_INTEGER.identifier) + self._state_machine_context.referenced_types_body.append(ty.BASE_INTEGER.identifier) result = expression.__class__( ( - ir.IntConversion(rty.BASE_INTEGER, expression.left) - if expression.left.type_ != rty.BASE_INTEGER + ir.IntConversion(ty.BASE_INTEGER, expression.left) + if expression.left.type_ != ty.BASE_INTEGER else expression.left ), ( - ir.IntConversion(rty.BASE_INTEGER, expression.right) - if expression.right.type_ != rty.BASE_INTEGER + ir.IntConversion(ty.BASE_INTEGER, expression.right) + if expression.right.type_ != ty.BASE_INTEGER else expression.right ), ) @@ -4620,13 +4620,10 @@ def _convert_types_of_int_relation(self, expression: ir.Relation) -> ir.Relation def _relation_to_ada_expr( self, expression: ir.Relation, - is_global: ty.Callable[[ID], bool], + is_global: typing.Callable[[ID], bool], ) -> Expr: assert isinstance(expression, (ir.Equal, ir.NotEqual)) - if ( - isinstance(expression.left.type_, rty.Enumeration) - and expression.left.type_.always_valid - ): + if isinstance(expression.left.type_, ty.Enumeration) and expression.left.type_.always_valid: relation = Equal if isinstance(expression, ir.Equal) else NotEqual self._state_machine_context.used_types_body.append(expression.left.type_.identifier) @@ -4643,23 +4640,23 @@ def _relation_to_ada_expr( return result @_to_ada_expr.register - def _(self, expression: ir.Size, is_global: ty.Callable[[ID], bool]) -> Expr: + def _(self, expression: ir.Size, is_global: typing.Callable[[ID], bool]) -> Expr: if ( - isinstance(expression.prefix_type, rty.AnyInteger) + isinstance(expression.prefix_type, ty.AnyInteger) or ( - isinstance(expression.prefix_type, rty.Aggregate) - and isinstance(expression.prefix_type.element, rty.AnyInteger) + isinstance(expression.prefix_type, ty.Aggregate) + and isinstance(expression.prefix_type.element, ty.AnyInteger) ) or ( - isinstance(expression.prefix_type, (rty.Integer, rty.Enumeration)) + isinstance(expression.prefix_type, (ty.Integer, ty.Enumeration)) and expression.prefix == expression.prefix_type.identifier ) ): return Size(expression.prefix) if ( - isinstance(expression.prefix_type, (rty.Message, rty.Sequence)) - and expression.prefix_type != rty.OPAQUE + isinstance(expression.prefix_type, (ty.Message, ty.Sequence)) + and expression.prefix_type != ty.OPAQUE ): type_ = expression.prefix_type.identifier context = context_id(expression.prefix, is_global) @@ -4668,23 +4665,23 @@ def _(self, expression: ir.Size, is_global: ty.Callable[[ID], bool]) -> Expr: assert False @_to_ada_expr.register - def _(self, expression: ir.HasData, is_global: ty.Callable[[ID], bool]) -> Expr: - assert isinstance(expression.prefix_type, rty.Message) + def _(self, expression: ir.HasData, is_global: typing.Callable[[ID], bool]) -> Expr: + assert isinstance(expression.prefix_type, ty.Message) type_ = expression.prefix_type.identifier context = context_id(expression.prefix, is_global) return Greater(Call(type_ * "Byte_Size", [Variable(context)]), Number(0)) @_to_ada_expr.register - def _(self, expression: ir.Opaque, is_global: ty.Callable[[ID], bool]) -> Expr: + def _(self, expression: ir.Opaque, is_global: typing.Callable[[ID], bool]) -> Expr: raise NotImplementedError @_to_ada_expr.register - def _(self, expression: ir.Head, _is_global: ty.Callable[[ID], bool]) -> Expr: + def _(self, expression: ir.Head, _is_global: typing.Callable[[ID], bool]) -> Expr: _unsupported_expression(expression, "in expression") @_to_ada_expr.register - def _(self, expression: ir.FieldValidNext, is_global: ty.Callable[[ID], bool]) -> Expr: - assert isinstance(expression.message_type, rty.Message) + def _(self, expression: ir.FieldValidNext, is_global: typing.Callable[[ID], bool]) -> Expr: + assert isinstance(expression.message_type, ty.Message) type_name = expression.message_type.identifier return Call( type_name * "Valid_Next", @@ -4695,14 +4692,14 @@ def _(self, expression: ir.FieldValidNext, is_global: ty.Callable[[ID], bool]) - ) @_to_ada_expr.register - def _(self, expression: ir.FieldValid, is_global: ty.Callable[[ID], bool]) -> Expr: - assert isinstance(expression.message_type, rty.Message) + def _(self, expression: ir.FieldValid, is_global: typing.Callable[[ID], bool]) -> Expr: + assert isinstance(expression.message_type, ty.Message) type_name = expression.message_type.identifier return Call( type_name * ( "Valid" - if isinstance(expression.field_type, (rty.Integer, rty.Enumeration)) + if isinstance(expression.field_type, (ty.Integer, ty.Enumeration)) else "Well_Formed" ), [ @@ -4712,14 +4709,14 @@ def _(self, expression: ir.FieldValid, is_global: ty.Callable[[ID], bool]) -> Ex ) @_to_ada_expr.register - def _(self, expression: ir.FieldPresent, is_global: ty.Callable[[ID], bool]) -> Expr: - assert isinstance(expression.message_type, rty.Message) + def _(self, expression: ir.FieldPresent, is_global: typing.Callable[[ID], bool]) -> Expr: + assert isinstance(expression.message_type, ty.Message) type_name = expression.message_type.identifier return Call( type_name * ( "Valid" - if isinstance(expression.field_type, (rty.Integer, rty.Enumeration)) + if isinstance(expression.field_type, (ty.Integer, ty.Enumeration)) else "Well_Formed" ), [ @@ -4729,20 +4726,20 @@ def _(self, expression: ir.FieldPresent, is_global: ty.Callable[[ID], bool]) -> ) @_to_ada_expr.register - def _(self, expression: ir.FieldSize, is_global: ty.Callable[[ID], bool]) -> Expr: + def _(self, expression: ir.FieldSize, is_global: typing.Callable[[ID], bool]) -> Expr: type_ = expression.message_type.identifier - if isinstance(expression.message_type, rty.Message): + if isinstance(expression.message_type, ty.Message): context = context_id(expression.message, is_global) return Call( type_ * "Field_Size", [Variable(context), Variable(type_ * "F_" + expression.field)], ) - assert isinstance(expression.message_type, rty.Structure) + assert isinstance(expression.message_type, ty.Structure) return Call(type_ * f"Field_Size_{expression.field}", [Variable(expression.message)]) @_to_ada_expr.register - def _(self, expression: ir.UnaryExpr, is_global: ty.Callable[[ID], bool]) -> Expr: + def _(self, expression: ir.UnaryExpr, is_global: typing.Callable[[ID], bool]) -> Expr: result = getattr(ada, expression.__class__.__name__)( self._to_ada_expr(expression.expression, is_global), ) @@ -4750,7 +4747,7 @@ def _(self, expression: ir.UnaryExpr, is_global: ty.Callable[[ID], bool]) -> Exp return result @_to_ada_expr.register - def _(self, expression: ir.BinaryExpr, is_global: ty.Callable[[ID], bool]) -> Expr: + def _(self, expression: ir.BinaryExpr, is_global: typing.Callable[[ID], bool]) -> Expr: self._record_used_types(expression) name = expression.__class__.__name__ if name == "And": @@ -4765,7 +4762,7 @@ def _(self, expression: ir.BinaryExpr, is_global: ty.Callable[[ID], bool]) -> Ex return result @_to_ada_expr.register - def _(self, expression: ir.Relation, is_global: ty.Callable[[ID], bool]) -> Expr: + def _(self, expression: ir.Relation, is_global: typing.Callable[[ID], bool]) -> Expr: relation = self._convert_types_of_int_relation(expression) result = getattr(ada, relation.__class__.__name__)( self._to_ada_expr(relation.left, is_global), @@ -4775,7 +4772,7 @@ def _(self, expression: ir.Relation, is_global: ty.Callable[[ID], bool]) -> Expr return result @_to_ada_expr.register - def _(self, expression: ir.Equal, is_global: ty.Callable[[ID], bool]) -> Expr: + def _(self, expression: ir.Equal, is_global: typing.Callable[[ID], bool]) -> Expr: if expression.left == ir.BoolVal(value=True) and isinstance(expression.right, ir.Var): return Variable(variable_id(expression.right.identifier, is_global)) if isinstance(expression.left, ir.Var) and expression.right == ir.BoolVal(value=True): @@ -4790,7 +4787,7 @@ def _(self, expression: ir.Equal, is_global: ty.Callable[[ID], bool]) -> Expr: ) @_to_ada_expr.register - def _(self, expression: ir.NotEqual, is_global: ty.Callable[[ID], bool]) -> Expr: + def _(self, expression: ir.NotEqual, is_global: typing.Callable[[ID], bool]) -> Expr: if expression.left == ir.BoolVal(value=True) and isinstance(expression.right, ir.Var): return Not(Variable(variable_id(expression.right.identifier, is_global))) if isinstance(expression.left, ir.Var) and expression.right == ir.BoolVal(value=True): @@ -4805,14 +4802,14 @@ def _(self, expression: ir.NotEqual, is_global: ty.Callable[[ID], bool]) -> Expr ) @_to_ada_expr.register - def _(self, expression: ir.Call, is_global: ty.Callable[[ID], bool]) -> Expr: + def _(self, expression: ir.Call, is_global: typing.Callable[[ID], bool]) -> Expr: raise NotImplementedError @_to_ada_expr.register - def _(self, expression: ir.FieldAccess, is_global: ty.Callable[[ID], bool]) -> Expr: + def _(self, expression: ir.FieldAccess, is_global: typing.Callable[[ID], bool]) -> Expr: if expression.field in expression.message_type.parameter_types: return Selected(Variable(context_id(expression.message, is_global)), expression.field) - if isinstance(expression.message_type, rty.Structure): + if isinstance(expression.message_type, ty.Structure): raise NotImplementedError return Call( expression.message_type.identifier * f"Get_{expression.field}", @@ -4820,13 +4817,13 @@ def _(self, expression: ir.FieldAccess, is_global: ty.Callable[[ID], bool]) -> E ) @_to_ada_expr.register - def _(self, expression: ir.IntFieldAccess, is_global: ty.Callable[[ID], bool]) -> Expr: + def _(self, expression: ir.IntFieldAccess, is_global: typing.Callable[[ID], bool]) -> Expr: if expression.field in expression.message_type.parameter_types: return Selected( Variable(context_id(expression.message, is_global)), expression.field, ) - if isinstance(expression.message_type, rty.Structure): + if isinstance(expression.message_type, ty.Structure): return Selected(Variable(expression.message), expression.field) return Call( expression.message_type.identifier * f"Get_{expression.field}", @@ -4834,7 +4831,7 @@ def _(self, expression: ir.IntFieldAccess, is_global: ty.Callable[[ID], bool]) - ) @_to_ada_expr.register - def _(self, expression: ir.IfExpr, is_global: ty.Callable[[ID], bool]) -> Expr: + def _(self, expression: ir.IfExpr, is_global: typing.Callable[[ID], bool]) -> Expr: assert expression.then_expr.is_expr() assert expression.else_expr.is_expr() return ada.IfThenElse( @@ -4844,14 +4841,14 @@ def _(self, expression: ir.IfExpr, is_global: ty.Callable[[ID], bool]) -> Expr: ) @_to_ada_expr.register - def _(self, expression: ir.Conversion, is_global: ty.Callable[[ID], bool]) -> Expr: + def _(self, expression: ir.Conversion, is_global: typing.Callable[[ID], bool]) -> Expr: return Conversion( self._ada_type(expression.target_type.identifier), self._to_ada_expr(expression.argument, is_global), ) @_to_ada_expr.register - def _(self, expression: ir.Agg, is_global: ty.Callable[[ID], bool]) -> Expr: + def _(self, expression: ir.Agg, is_global: typing.Callable[[ID], bool]) -> Expr: assert len(expression.elements) > 0 if len(expression.elements) == 1: return NamedAggregate( @@ -4865,7 +4862,7 @@ def _(self, expression: ir.Agg, is_global: ty.Callable[[ID], bool]) -> Expr: ) @_to_ada_expr.register - def _(self, expression: ir.NamedAgg, is_global: ty.Callable[[ID], bool]) -> Expr: + def _(self, expression: ir.NamedAgg, is_global: typing.Callable[[ID], bool]) -> Expr: elements: list[tuple[ID | ada.Expr, ada.Expr]] = [ ( n if isinstance(n, ID) else self._to_ada_expr(n, is_global), @@ -4876,11 +4873,11 @@ def _(self, expression: ir.NamedAgg, is_global: ty.Callable[[ID], bool]) -> Expr return NamedAggregate(*elements) @_to_ada_expr.register - def _(self, expression: ir.Str, is_global: ty.Callable[[ID], bool]) -> Expr: + def _(self, expression: ir.Str, is_global: typing.Callable[[ID], bool]) -> Expr: raise NotImplementedError @_to_ada_expr.register - def _(self, expression: ir.CaseExpr, is_global: ty.Callable[[ID], bool]) -> Expr: + def _(self, expression: ir.CaseExpr, is_global: typing.Callable[[ID], bool]) -> Expr: choices = [ (self._to_ada_expr(choice, is_global), self._to_ada_expr(expr, is_global)) for choices, expr in expression.choices @@ -4889,7 +4886,7 @@ def _(self, expression: ir.CaseExpr, is_global: ty.Callable[[ID], bool]) -> Expr return ada.CaseExpr(self._to_ada_expr(expression.expression, is_global), choices) @_to_ada_expr.register - def _(self, expression: ir.SufficientSpace, is_global: ty.Callable[[ID], bool]) -> Expr: + def _(self, expression: ir.SufficientSpace, is_global: typing.Callable[[ID], bool]) -> Expr: return Call( expression.message_type.identifier * "Sufficient_Space", [ @@ -4901,7 +4898,7 @@ def _(self, expression: ir.SufficientSpace, is_global: ty.Callable[[ID], bool]) ) @_to_ada_expr.register - def _(self, expression: ir.HasElement, is_global: ty.Callable[[ID], bool]) -> Expr: + def _(self, expression: ir.HasElement, is_global: typing.Callable[[ID], bool]) -> Expr: return Call( expression.prefix_type.identifier * "Has_Element", [Variable(context_id(expression.prefix, is_global))], @@ -4909,8 +4906,8 @@ def _(self, expression: ir.HasElement, is_global: ty.Callable[[ID], bool]) -> Ex def _record_used_types(self, expression: ir.BinaryExpr) -> None: for e in [expression.left, expression.right]: - if isinstance(e.type_, rty.Integer) or ( - isinstance(e.type_, rty.Enumeration) and not e.type_.always_valid + if isinstance(e.type_, ty.Integer) or ( + isinstance(e.type_, ty.Enumeration) and not e.type_.always_valid ): self._state_machine_context.used_types_body.append(e.type_.identifier) self._state_machine_context.referenced_types_body.append(e.type_.identifier) @@ -5023,7 +5020,7 @@ def _set_message_fields( exception_handler: ExceptionHandler, is_global: Callable[[ID], bool], ) -> Sequence[Statement]: - assert isinstance(message_aggregate.type_, rty.Message) + assert isinstance(message_aggregate.type_, ty.Message) message_type = message_aggregate.type_ @@ -5050,12 +5047,12 @@ def _set_message_field( # noqa: PLR0913, PLR0912, PLR0915 self, message_context: ID, field: ID, - message_type: rty.Message, + message_type: ty.Message, value: ir.Expr, exception_handler: ExceptionHandler, is_global: Callable[[ID], bool], ) -> Sequence[Statement]: - if isinstance(value, ir.FieldAccess) and value.type_ == rty.OPAQUE: + if isinstance(value, ir.FieldAccess) and value.type_ == ty.OPAQUE: target_context = message_context target_message_type = message_type target_field = field @@ -5064,7 +5061,7 @@ def _set_message_field( # noqa: PLR0913, PLR0912, PLR0915 source_message_type = value.message_type self._state_machine_context.used_types_body.append(const.TYPES_LENGTH) - if isinstance(source_message_type, rty.Message): + if isinstance(source_message_type, ty.Message): return [ self._set_opaque_field_to_message_field( target_message_type.identifier, @@ -5077,7 +5074,7 @@ def _set_message_field( # noqa: PLR0913, PLR0912, PLR0915 ), ] - assert isinstance(source_message_type, rty.Structure) + assert isinstance(source_message_type, ty.Structure) return [ self._set_opaque_field_to_message_field_from_structure( target_message_type.identifier, @@ -5095,7 +5092,7 @@ def _set_message_field( # noqa: PLR0913, PLR0912, PLR0915 target_message_type = message_type target_field = field source_context = context_id(value.prefix, is_global) - assert isinstance(value.prefix_type, rty.Message) + assert isinstance(value.prefix_type, ty.Message) source_message_type = value.prefix_type return [ self._set_opaque_field_to_message( @@ -5112,13 +5109,13 @@ def _set_message_field( # noqa: PLR0913, PLR0912, PLR0915 message_model = self._model_type(message_type_id) assert isinstance(message_model, model.Message) field_type = message_type.field_types[field] - assert isinstance(field_type, rty.NamedTypeClass) + assert isinstance(field_type, ty.NamedTypeClass) statements: list[Statement] = [] result = statements - if isinstance(field_type, rty.Sequence): + if isinstance(field_type, ty.Sequence): size: Expr - if isinstance(value, ir.Var) and isinstance(value.type_, (rty.Message, rty.Sequence)): + if isinstance(value, ir.Var) and isinstance(value.type_, (ty.Message, ty.Sequence)): type_ = value.type_.identifier context = context_id(value.identifier, is_global) # TODO(eng/recordflux/RecordFlux#1742): Move check into IR @@ -5132,7 +5129,7 @@ def _set_message_field( # noqa: PLR0913, PLR0912, PLR0915 Variable(message_type_id * f"F_{field}"), ( Length(value.identifier) - if value.type_ == rty.OPAQUE + if value.type_ == ty.OPAQUE else Call(type_ * "Byte_Size", [Variable(context)]) ), ], @@ -5187,7 +5184,7 @@ def _set_message_field( # noqa: PLR0913, PLR0912, PLR0915 value, (ir.Var, ir.EnumLit, ir.BinaryIntExpr, ir.Size), ) - and isinstance(value.type_, (rty.AnyInteger, rty.Enumeration, rty.Aggregate)) + and isinstance(value.type_, (ty.AnyInteger, ty.Enumeration, ty.Aggregate)) ): if isinstance(value, ir.Agg) and len(value.elements) == 0: statements.append( @@ -5200,7 +5197,7 @@ def _set_message_field( # noqa: PLR0913, PLR0912, PLR0915 ada_value: ada.Expr if ( isinstance(value, ir.Var) - and isinstance(value.type_, rty.Enumeration) + and isinstance(value.type_, ty.Enumeration) and value.type_.always_valid ): ada_value = Selected(self._to_ada_expr(value, is_global), "Enum") @@ -5208,15 +5205,15 @@ def _set_message_field( # noqa: PLR0913, PLR0912, PLR0915 ada_value = Literal(value.identifier) else: ada_value = self._to_ada_expr(value, is_global) - if isinstance(field_type, (rty.Enumeration, rty.Integer)): + if isinstance(field_type, (ty.Enumeration, ty.Integer)): ada_value = QualifiedExpr( self._ada_type(field_type.identifier), ada_value, ) if ( not isinstance(ada_value, Number) - and isinstance(value.type_, rty.NamedTypeClass) - and value.type_ == rty.BOOLEAN + and isinstance(value.type_, ty.NamedTypeClass) + and value.type_ == ty.BOOLEAN and common.has_scalar_value_dependent_condition(message_model) ): to_base_integer = const.BUILTIN_TYPES_CONVERSIONS_PACKAGE * "To_Base_Integer" @@ -5243,7 +5240,7 @@ def _set_message_field( # noqa: PLR0913, PLR0912, PLR0915 ) if isinstance( value.type_, - (rty.Enumeration, rty.AnyInteger), + (ty.Enumeration, ty.AnyInteger), ) else None ), @@ -5253,7 +5250,7 @@ def _set_message_field( # noqa: PLR0913, PLR0912, PLR0915 if isinstance(value, ir.Agg) else ( Call(const.TYPES_TO_BIT_LENGTH, [Length(ada_value)]) - if isinstance(value.type_, rty.Sequence) + if isinstance(value.type_, ty.Sequence) else Number(0) ) ), @@ -5273,7 +5270,7 @@ def _set_message_field( # noqa: PLR0913, PLR0912, PLR0915 ), ], ) - elif isinstance(value, ir.Var) and isinstance(value.type_, rty.Sequence): + elif isinstance(value, ir.Var) and isinstance(value.type_, ty.Sequence): sequence_context = context_id(value.identifier, is_global) sequence_type_id = value.type_.identifier statements.extend( @@ -5307,7 +5304,7 @@ def _set_message_field( # noqa: PLR0913, PLR0912, PLR0915 ), ], ) - elif isinstance(value, ir.Var) and isinstance(value.type_, rty.Message): + elif isinstance(value, ir.Var) and isinstance(value.type_, ty.Message): _unsupported_expression(value, "in message aggregate") else: _unsupported_expression(value, "as value of message field") @@ -5820,7 +5817,7 @@ def _comprehension( # noqa: PLR0913 sequence_identifier: ID, sequence_type: ID, target_identifier: ID, - target_type: rty.Sequence | rty.Integer | rty.Enumeration | rty.Message, + target_type: ty.Sequence | ty.Integer | ty.Enumeration | ty.Message, iterator_identifier: ID, iterator_type: ID, selector_stmts: list[ir.Stmt], @@ -5835,9 +5832,9 @@ def _comprehension( # noqa: PLR0913 assert not isinstance(selector, ir.MsgAgg) assert ( - isinstance(target_type.element, (rty.Integer, rty.Enumeration, rty.Message)) - if isinstance(target_type, rty.Sequence) - else isinstance(target_type, (rty.Integer, rty.Enumeration, rty.Message)) + isinstance(target_type.element, (ty.Integer, ty.Enumeration, ty.Message)) + if isinstance(target_type, ty.Sequence) + else isinstance(target_type, (ty.Integer, ty.Enumeration, ty.Message)) ) target_type_id = target_type.identifier @@ -5885,7 +5882,7 @@ def _comprehension( # noqa: PLR0913 ], ), ] - if isinstance(target_type, (rty.Message, rty.Sequence)): + if isinstance(target_type, (ty.Message, ty.Sequence)): target_invariants += [ PragmaStatement( "Loop_Invariant", @@ -5932,7 +5929,7 @@ def _comprehension( # noqa: PLR0913 ), ] - if isinstance(target_type, rty.Sequence): + if isinstance(target_type, ty.Sequence): target_invariants += [ PragmaStatement( "Loop_Invariant", @@ -5956,7 +5953,7 @@ def _comprehension( # noqa: PLR0913 is_global, state, ) - if isinstance(target_type, rty.Sequence) + if isinstance(target_type, ty.Sequence) else self._comprehension_assign_element( target_identifier, target_type, @@ -6044,7 +6041,7 @@ def _comprehension( # noqa: PLR0913 def _comprehension_assign_element( # noqa: PLR0913 self, target_identifier: ID, - target_type: rty.Integer | rty.Enumeration | rty.Message, + target_type: ty.Integer | ty.Enumeration | ty.Message, selector_stmts: list[ir.Stmt], selector: ir.Expr, update_context: Sequence[Statement], @@ -6055,7 +6052,7 @@ def _comprehension_assign_element( # noqa: PLR0913 target_type_id = target_type.identifier assign_element: Sequence[Statement] - if isinstance(target_type, rty.Message): + if isinstance(target_type, ty.Message): if not isinstance(selector, ir.Var): fail( "expressions other than variables not yet supported" @@ -6128,7 +6125,7 @@ def _comprehension_assign_element( # noqa: PLR0913 ), ] - elif isinstance(target_type, (rty.Integer, rty.Enumeration)): + elif isinstance(target_type, (ty.Integer, ty.Enumeration)): assign_element = [ Assignment( Variable(variable_id(target_identifier, is_global)), @@ -6153,7 +6150,7 @@ def _comprehension_assign_element( # noqa: PLR0913 def _comprehension_append_element( # noqa: PLR0913 self, target_identifier: ID, - target_type: rty.Sequence, + target_type: ty.Sequence, selector_stmts: list[ir.Stmt], selector: ir.Expr, _: Sequence[Statement], @@ -6161,13 +6158,13 @@ def _comprehension_append_element( # noqa: PLR0913 is_global: Callable[[ID], bool], state: ID, ) -> Sequence[Statement]: - assert isinstance(target_type, rty.Sequence) + assert isinstance(target_type, ty.Sequence) target_type_id = target_type.identifier required_space: Expr append_element: list[Statement] - if isinstance(target_type.element, rty.Message): + if isinstance(target_type.element, ty.Message): if not isinstance(selector, ir.Var): fail( "expressions other than variables not yet supported" @@ -6212,13 +6209,13 @@ def _comprehension_append_element( # noqa: PLR0913 ), ] - elif isinstance(target_type.element, (rty.Integer, rty.Enumeration)): + elif isinstance(target_type.element, (ty.Integer, ty.Enumeration)): required_space = Size( ( target_type.element.identifier + "_Enum" if isinstance( target_type.element, - rty.Enumeration, + ty.Enumeration, ) and target_type.element.always_valid else target_type.element.identifier @@ -6489,7 +6486,7 @@ def _copy_to_buffer( def _convert_type( self, expression: ir.Expr, - target_type: rty.Type, + target_type: ty.Type, ) -> ir.Expr: if target_type.is_compatible_strong(expression.type_) and not isinstance( expression, @@ -6497,7 +6494,7 @@ def _convert_type( ): return expression - assert isinstance(target_type, (rty.Integer, rty.Enumeration)), target_type + assert isinstance(target_type, (ty.Integer, ty.Enumeration)), target_type self._state_machine_context.referenced_types_body.append(target_type.identifier) diff --git a/rflx/ir.py b/rflx/ir.py index 6ef13f309..d94654ac2 100644 --- a/rflx/ir.py +++ b/rflx/ir.py @@ -14,12 +14,12 @@ import z3 from attr import define, field, frozen -from rflx import typing_ as rty +from rflx import ty from rflx.common import Base from rflx.const import MAX_SCALAR_SIZE, MP_CONTEXT from rflx.error import info from rflx.identifier import ID, StrID -from rflx.rapidflux import Location, ty +from rflx.rapidflux import Location if TYPE_CHECKING: from rflx.model import type_decl @@ -170,7 +170,7 @@ def _update_str(self) -> None: @define(eq=False) class VarDecl(Stmt): identifier: ID = field(converter=ID) - type_: rty.NamedType + type_: ty.NamedType expression: ComplexExpr | None = None origin: Origin | None = None @@ -182,7 +182,7 @@ def preconditions(self, _variable_id: Generator[ID, None, None]) -> list[Cond]: return [] def to_z3_expr(self) -> z3.BoolRef: - if isinstance(self.type_, rty.Integer): + if isinstance(self.type_, ty.Integer): first = z3.Int(str(First(self.type_.identifier, self.type_))) last = z3.Int(str(Last(self.type_.identifier, self.type_))) return z3.And( @@ -202,7 +202,7 @@ def _update_str(self) -> None: class Assign(Stmt): target: ID = field(converter=ID) expression: Expr - type_: rty.NamedType + type_: ty.NamedType origin: Origin | None = None @property @@ -252,7 +252,7 @@ class FieldAssign(Stmt): message: ID = field(converter=ID) field: ID = field(converter=ID) expression: Expr - type_: rty.Message + type_: ty.Message origin: Origin | None = None @property @@ -293,7 +293,7 @@ def _update_str(self) -> None: class Append(Stmt): sequence: ID = field(converter=ID) expression: Expr - type_: rty.Sequence + type_: ty.Sequence origin: Origin | None = None @property @@ -326,7 +326,7 @@ def _update_str(self) -> None: class Extend(Stmt): sequence: ID = field(converter=ID) expression: Expr - type_: rty.Sequence + type_: ty.Sequence origin: Origin | None = None @property @@ -355,7 +355,7 @@ def _update_str(self) -> None: class Reset(Stmt): identifier: ID = field(converter=ID) parameter_values: Mapping[ID, Expr] - type_: rty.Any + type_: ty.Any origin: Origin | None = None @property @@ -477,7 +477,7 @@ def __str__(self) -> str: @property @abstractmethod - def type_(self) -> rty.Any: + def type_(self) -> ty.Any: raise NotImplementedError @property @@ -517,7 +517,7 @@ class BasicExpr(Expr): class IntExpr(Expr): @property @abstractmethod - def type_(self) -> rty.AnyInteger: + def type_(self) -> ty.AnyInteger: raise NotImplementedError @abstractmethod @@ -527,8 +527,8 @@ def to_z3_expr(self) -> z3.ArithRef: class BoolExpr(Expr): @property - def type_(self) -> rty.Enumeration: - return rty.BOOLEAN + def type_(self) -> ty.Enumeration: + return ty.BOOLEAN @abstractmethod def to_z3_expr(self) -> z3.BoolRef: @@ -538,7 +538,7 @@ def to_z3_expr(self) -> z3.BoolRef: class BasicIntExpr(BasicExpr, IntExpr): @property @abstractmethod - def type_(self) -> rty.AnyInteger: + def type_(self) -> ty.AnyInteger: raise NotImplementedError @@ -562,11 +562,11 @@ def _update_str(self) -> None: @define(eq=False) class IntVar(Var, BasicIntExpr): identifier: ID = field(converter=ID) - var_type: rty.AnyInteger + var_type: ty.AnyInteger origin: Origin | None = None @property - def type_(self) -> rty.AnyInteger: + def type_(self) -> ty.AnyInteger: return self.var_type def substituted(self, mapping: Mapping[ID, ID]) -> IntVar: @@ -584,8 +584,8 @@ class BoolVar(Var, BasicBoolExpr): origin: Origin | None = None @property - def type_(self) -> rty.Enumeration: - return rty.BOOLEAN + def type_(self) -> ty.Enumeration: + return ty.BOOLEAN def substituted(self, mapping: Mapping[ID, ID]) -> BoolVar: if self.identifier in mapping: @@ -599,11 +599,11 @@ def to_z3_expr(self) -> z3.BoolRef: @define(eq=False) class ObjVar(Var): identifier: ID = field(converter=ID) - var_type: rty.Any + var_type: ty.Any origin: Origin | None = None @property - def type_(self) -> rty.Any: + def type_(self) -> ty.Any: return self.var_type def substituted(self, mapping: Mapping[ID, ID]) -> ObjVar: @@ -618,11 +618,11 @@ def to_z3_expr(self) -> z3.ExprRef: @define(eq=False) class EnumLit(BasicExpr): identifier: ID = field(converter=ID) - enum_type: rty.Enumeration + enum_type: ty.Enumeration origin: Origin | None = None @property - def type_(self) -> rty.Enumeration: + def type_(self) -> ty.Enumeration: return self.enum_type @property @@ -645,8 +645,8 @@ class IntVal(BasicIntExpr): origin: Origin | None = None @property - def type_(self) -> rty.UniversalInteger: - return rty.UniversalInteger(ty.Bounds(self.value, self.value)) + def type_(self) -> ty.UniversalInteger: + return ty.UniversalInteger(ty.Bounds(self.value, self.value)) @property def accessed_vars(self) -> list[ID]: @@ -684,7 +684,7 @@ def _update_str(self) -> None: @define(eq=False) class Attr(Expr): prefix: ID = field(converter=ID) - prefix_type: rty.Any + prefix_type: ty.Any origin: Origin | None = None @property @@ -706,7 +706,7 @@ def _update_str(self) -> None: @define(eq=False) class IntAttr(Attr, IntExpr): prefix: ID = field(converter=ID) - prefix_type: rty.Any + prefix_type: ty.Any origin: Origin | None = None def substituted(self, mapping: Mapping[ID, ID]) -> IntAttr: @@ -723,33 +723,33 @@ def to_z3_expr(self) -> z3.ArithRef: @define(eq=False) class Size(IntAttr): @property - def type_(self) -> rty.AnyInteger: + def type_(self) -> ty.AnyInteger: return ( - rty.BIT_LENGTH - if isinstance(self.prefix_type, (rty.Composite, rty.Compound)) - else rty.UNIVERSAL_INTEGER + ty.BIT_LENGTH + if isinstance(self.prefix_type, (ty.Composite, ty.Compound)) + else ty.UNIVERSAL_INTEGER ) @define(eq=False) class Length(IntAttr): @property - def type_(self) -> rty.UniversalInteger: - return rty.UNIVERSAL_INTEGER + def type_(self) -> ty.UniversalInteger: + return ty.UNIVERSAL_INTEGER @define(eq=False) class First(IntAttr): @property - def type_(self) -> rty.UniversalInteger: - return rty.UNIVERSAL_INTEGER + def type_(self) -> ty.UniversalInteger: + return ty.UNIVERSAL_INTEGER @define(eq=False) class Last(IntAttr): @property - def type_(self) -> rty.UniversalInteger: - return rty.UNIVERSAL_INTEGER + def type_(self) -> ty.UniversalInteger: + return ty.UNIVERSAL_INTEGER @define(eq=False) @@ -767,8 +767,8 @@ def to_z3_expr(self) -> z3.BoolRef: @define(eq=False) class Present(Attr): @property - def type_(self) -> rty.Enumeration: - return rty.BOOLEAN + def type_(self) -> ty.Enumeration: + return ty.BOOLEAN def to_z3_expr(self) -> z3.ExprRef: return z3.Bool(str(self)) @@ -783,12 +783,12 @@ def to_z3_expr(self) -> z3.BoolRef: @define(eq=False) class Head(Attr): prefix: ID = field(converter=ID) - prefix_type: rty.Composite + prefix_type: ty.Composite origin: Origin | None = None @property - def type_(self) -> rty.Any: - assert isinstance(self.prefix_type.element, rty.Any) + def type_(self) -> ty.Any: + assert isinstance(self.prefix_type.element, ty.Any) return self.prefix_type.element def to_z3_expr(self) -> z3.ExprRef: @@ -798,12 +798,12 @@ def to_z3_expr(self) -> z3.ExprRef: @define(eq=False) class Opaque(Attr): prefix: ID = field(converter=ID) - prefix_type: rty.Message | rty.Sequence + prefix_type: ty.Message | ty.Sequence origin: Origin | None = None @property - def type_(self) -> rty.Sequence: - return rty.OPAQUE + def type_(self) -> ty.Sequence: + return ty.OPAQUE def preconditions(self, _variable_id: Generator[ID, None, None]) -> list[Cond]: return [ @@ -818,7 +818,7 @@ def to_z3_expr(self) -> z3.ExprRef: class FieldAccessAttr(Expr): message: ID = field(converter=ID) field: ID = field(converter=ID) - message_type: rty.Compound + message_type: ty.Compound origin: Origin | None = None @property @@ -826,9 +826,9 @@ def accessed_vars(self) -> list[ID]: return [self.message] @property - def field_type(self) -> rty.Any: + def field_type(self) -> ty.Any: type_ = self.message_type.field_types[self.field] - assert isinstance(type_, rty.Any) + assert isinstance(type_, ty.Any) return type_ def substituted(self, mapping: Mapping[ID, ID]) -> FieldAccessAttr: @@ -880,15 +880,15 @@ def _symbol(self) -> str: @define(eq=False) class FieldSize(FieldAccessAttr, IntExpr): @property - def type_(self) -> rty.Integer: - return rty.BIT_LENGTH + def type_(self) -> ty.Integer: + return ty.BIT_LENGTH def preconditions(self, _variable_id: Generator[ID, None, None]) -> list[Cond]: return ( [ Cond(FieldValidNext(self.message, self.field, self.message_type)), ] - if isinstance(self.message_type, rty.Message) + if isinstance(self.message_type, ty.Message) else [] ) @@ -919,7 +919,7 @@ class UnaryIntExpr(UnaryExpr, IntExpr): origin: Origin | None = None @property - def type_(self) -> rty.AnyInteger: + def type_(self) -> ty.AnyInteger: return self.expression.type_ @@ -978,14 +978,14 @@ class BinaryIntExpr(BinaryExpr, IntExpr): def preconditions( self, variable_id: Generator[ID, None, None], - target_type: rty.Type | None = None, + target_type: ty.Type | None = None, ) -> list[Cond]: raise NotImplementedError @property - def type_(self) -> rty.AnyInteger: + def type_(self) -> ty.AnyInteger: type_ = self.left.type_.common_type(self.right.type_) - assert isinstance(type_, rty.AnyInteger) + assert isinstance(type_, ty.AnyInteger) return type_ @@ -1013,14 +1013,14 @@ def to_z3_expr(self) -> z3.ArithRef: def preconditions( self, variable_id: Generator[ID, None, None], - target_type: rty.Type | None = None, + target_type: ty.Type | None = None, ) -> list[Cond]: target_type = target_type or self.type_ v_id = next(variable_id) - v_type = rty.BASE_INTEGER + v_type = ty.BASE_INTEGER upper_bound = ( target_type.bounds.upper - if isinstance(target_type, rty.AnyInteger) and target_type.bounds is not None + if isinstance(target_type, ty.AnyInteger) and target_type.bounds is not None else INT_MAX ) return [ @@ -1051,11 +1051,11 @@ def preconditions( IntVal(upper_bound), ( IntConversion( - rty.BASE_INTEGER, + ty.BASE_INTEGER, self.right, ) - if self.right.type_ != rty.BASE_INTEGER - and not isinstance(self.right.type_, rty.UniversalInteger) + if self.right.type_ != ty.BASE_INTEGER + and not isinstance(self.right.type_, ty.UniversalInteger) else self.right ), ), @@ -1079,7 +1079,7 @@ def to_z3_expr(self) -> z3.ArithRef: def preconditions( self, variable_id: Generator[ID, None, None], - _target_type: rty.Type | None = None, + _target_type: ty.Type | None = None, ) -> list[Cond]: return [ *self.left.preconditions(variable_id), @@ -1101,14 +1101,14 @@ def to_z3_expr(self) -> z3.ArithRef: def preconditions( self, variable_id: Generator[ID, None, None], - target_type: rty.Type | None = None, + target_type: ty.Type | None = None, ) -> list[Cond]: target_type = target_type or self.type_ v_id = next(variable_id) - v_type = rty.BASE_INTEGER + v_type = ty.BASE_INTEGER upper_bound = ( target_type.bounds.upper - if isinstance(target_type, rty.AnyInteger) and target_type.bounds is not None + if isinstance(target_type, ty.AnyInteger) and target_type.bounds is not None else INT_MAX ) return [ @@ -1151,7 +1151,7 @@ def to_z3_expr(self) -> z3.ArithRef: def preconditions( self, variable_id: Generator[ID, None, None], - _target_type: rty.Type | None = None, + _target_type: ty.Type | None = None, ) -> list[Cond]: return [ *self.left.preconditions(variable_id), @@ -1173,14 +1173,14 @@ def to_z3_expr(self) -> z3.ArithRef: def preconditions( self, variable_id: Generator[ID, None, None], - target_type: rty.Type | None = None, + target_type: ty.Type | None = None, ) -> list[Cond]: target_type = target_type or self.type_ v_id = next(variable_id) - v_type = rty.BASE_INTEGER + v_type = ty.BASE_INTEGER upper_bound = ( target_type.bounds.upper - if isinstance(target_type, rty.AnyInteger) and target_type.bounds is not None + if isinstance(target_type, ty.AnyInteger) and target_type.bounds is not None else INT_MAX ) return [ @@ -1212,7 +1212,7 @@ def to_z3_expr(self) -> z3.ArithRef: def preconditions( self, variable_id: Generator[ID, None, None], - _target_type: rty.Type | None = None, + _target_type: ty.Type | None = None, ) -> list[Cond]: return [ *self.left.preconditions(variable_id), @@ -1342,7 +1342,7 @@ def _symbol(self) -> str: class Call(Expr): identifier: ID = field(converter=ID) arguments: Sequence[Expr] - argument_types: Sequence[rty.Any] + argument_types: Sequence[ty.Any] origin: Origin | None = None _preconditions: list[Cond] = field(init=False, factory=list) @@ -1370,8 +1370,8 @@ def _update_str(self) -> None: class IntCall(Call, IntExpr): identifier: ID = field(converter=ID) arguments: Sequence[Expr] - argument_types: Sequence[rty.Any] - type_: rty.AnyInteger + argument_types: Sequence[ty.Any] + type_: ty.AnyInteger origin: Origin | None = None def substituted(self, mapping: Mapping[ID, ID]) -> IntCall: @@ -1415,8 +1415,8 @@ def to_z3_expr(self) -> z3.BoolRef: class ObjCall(Call): identifier: ID = field(converter=ID) arguments: Sequence[Expr] - argument_types: Sequence[rty.Any] - type_: rty.Any + argument_types: Sequence[ty.Any] + type_: ty.Any origin: Origin | None = None def substituted(self, mapping: Mapping[ID, ID]) -> ObjCall: @@ -1436,7 +1436,7 @@ def to_z3_expr(self) -> z3.ExprRef: class FieldAccess(Expr): message: ID = field(converter=ID) field: ID = field(converter=ID) - message_type: rty.Compound + message_type: ty.Compound origin: Origin | None = None @property @@ -1454,7 +1454,7 @@ def substituted(self, mapping: Mapping[ID, ID]) -> FieldAccess: def preconditions(self, _: Generator[ID, None, None]) -> list[Cond]: return ( [Cond(FieldValid(self.message, self.field, self.message_type, self.origin))] - if isinstance(self.message_type, rty.Message) + if isinstance(self.message_type, ty.Message) and self.field in self.message_type.field_types else [] ) @@ -1467,13 +1467,13 @@ def _update_str(self) -> None: class IntFieldAccess(FieldAccess, IntExpr): message: ID = field(converter=ID) field: ID = field(converter=ID) - message_type: rty.Compound + message_type: ty.Compound origin: Origin | None = None @property - def type_(self) -> rty.AnyInteger: + def type_(self) -> ty.AnyInteger: type_ = self.message_type.types[self.field] - assert isinstance(type_, rty.AnyInteger) + assert isinstance(type_, ty.AnyInteger) return type_ def substituted(self, mapping: Mapping[ID, ID]) -> IntFieldAccess: @@ -1501,13 +1501,13 @@ def to_z3_expr(self) -> z3.BoolRef: class ObjFieldAccess(FieldAccess): message: ID = field(converter=ID) field: ID = field(converter=ID) - message_type: rty.Compound + message_type: ty.Compound origin: Origin | None = None @property - def type_(self) -> rty.Any: + def type_(self) -> ty.Any: type_ = self.message_type.field_types[self.field] - assert isinstance(type_, rty.Any) + assert isinstance(type_, ty.Any) return type_ def substituted(self, mapping: Mapping[ID, ID]) -> ObjFieldAccess: @@ -1562,11 +1562,11 @@ class IntIfExpr(IfExpr, IntExpr): condition: BasicBoolExpr then_expr: ComplexIntExpr else_expr: ComplexIntExpr - return_type: rty.AnyInteger + return_type: ty.AnyInteger origin: Origin | None = None @property - def type_(self) -> rty.AnyInteger: + def type_(self) -> ty.AnyInteger: return self.return_type def to_z3_expr(self) -> z3.ArithRef: @@ -1583,8 +1583,8 @@ class BoolIfExpr(IfExpr, BoolExpr): origin: Origin | None = None @property - def type_(self) -> rty.Enumeration: - return rty.BOOLEAN + def type_(self) -> ty.Enumeration: + return ty.BOOLEAN def to_z3_expr(self) -> z3.BoolRef: result = super().to_z3_expr() @@ -1594,12 +1594,12 @@ def to_z3_expr(self) -> z3.BoolRef: @define(eq=False) class Conversion(Expr): - target_type: rty.NamedType + target_type: ty.NamedType argument: Expr origin: Origin | None = None @property - def type_(self) -> rty.Any: + def type_(self) -> ty.Any: return self.target_type @property @@ -1625,12 +1625,12 @@ def _update_str(self) -> None: @define(eq=False) class IntConversion(Conversion, BasicIntExpr): - target_type: rty.Integer + target_type: ty.Integer argument: IntExpr origin: Origin | None = None @property - def type_(self) -> rty.Integer: + def type_(self) -> ty.Integer: return self.target_type def preconditions(self, variable_id: Generator[ID, None, None]) -> list[Cond]: @@ -1640,11 +1640,11 @@ def preconditions(self, variable_id: Generator[ID, None, None]) -> list[Cond]: Cond( LessEqual( IntConversion( - rty.BASE_INTEGER, + ty.BASE_INTEGER, First(self.target_type.identifier, self.argument.type_), ), IntConversion( - rty.BASE_INTEGER, + ty.BASE_INTEGER, self.argument, ), ), @@ -1653,11 +1653,11 @@ def preconditions(self, variable_id: Generator[ID, None, None]) -> list[Cond]: Cond( LessEqual( IntConversion( - rty.BASE_INTEGER, + ty.BASE_INTEGER, self.argument, ), IntConversion( - rty.BASE_INTEGER, + ty.BASE_INTEGER, Last(self.target_type.identifier, self.argument.type_), ), ), @@ -1677,8 +1677,8 @@ class Comprehension(Expr): origin: Origin | None = None @property - def type_(self) -> rty.Aggregate: - return rty.Aggregate(self.selector.expr.type_) + def type_(self) -> ty.Aggregate: + return ty.Aggregate(self.selector.expr.type_) @property def accessed_vars(self) -> list[ID]: @@ -1724,7 +1724,7 @@ class Find(Expr): origin: Origin | None = None @property - def type_(self) -> rty.Any: + def type_(self) -> ty.Any: return self.selector.expr.type_ @property @@ -1768,8 +1768,8 @@ class Agg(Expr): origin: Origin | None = None @property - def type_(self) -> rty.Aggregate: - return rty.Aggregate(rty.common_type([e.type_ for e in self.elements])) + def type_(self) -> ty.Aggregate: + return ty.Aggregate(ty.common_type([e.type_ for e in self.elements])) @property def accessed_vars(self) -> list[ID]: @@ -1807,7 +1807,7 @@ class NamedAgg(Expr): origin: Origin | None = None @property - def type_(self) -> rty.Any: + def type_(self) -> ty.Any: raise NotImplementedError @property @@ -1833,8 +1833,8 @@ class Str(Expr): origin: Origin | None = None @property - def type_(self) -> rty.Sequence: - return rty.OPAQUE + def type_(self) -> ty.Sequence: + return ty.OPAQUE @property def accessed_vars(self) -> list[ID]: @@ -1857,7 +1857,7 @@ def _update_str(self) -> None: class MsgAgg(Expr): identifier: ID = field(converter=ID) field_values: Mapping[ID, Expr] - type_: rty.Message + type_: ty.Message origin: Origin | None = None @property @@ -1891,7 +1891,7 @@ def _update_str(self) -> None: class DeltaMsgAgg(Expr): identifier: ID = field(converter=ID) field_values: Mapping[ID, Expr] - type_: rty.Message + type_: ty.Message origin: Origin | None = None @property @@ -1925,7 +1925,7 @@ def _update_str(self) -> None: class CaseExpr(Expr): expression: BasicExpr choices: Sequence[tuple[Sequence[BasicExpr], BasicExpr]] - type_: rty.Any + type_: ty.Any origin: Origin | None = None @property @@ -1964,7 +1964,7 @@ def _update_str(self) -> None: class SufficientSpace(FieldAccessAttr, BoolExpr): message: ID = field(converter=ID) field: ID = field(converter=ID) - message_type: rty.Message + message_type: ty.Message origin: Origin | None = None def to_z3_expr(self) -> z3.BoolRef: @@ -1974,7 +1974,7 @@ def to_z3_expr(self) -> z3.BoolRef: @define(eq=False) class HasElement(Attr, BoolExpr): prefix: ID = field(converter=ID) - prefix_type: rty.Sequence + prefix_type: ty.Sequence origin: Origin | None = None def to_z3_expr(self) -> z3.BoolRef: @@ -1996,7 +1996,7 @@ class FormalDecl(Decl): class Argument: identifier: ID = field(converter=ID) type_identifier: ID = field(converter=ID) - type_: rty.Type + type_: ty.Type @frozen @@ -2004,7 +2004,7 @@ class FuncDecl(FormalDecl): identifier: ID = field(converter=ID) arguments: Sequence[Argument] return_type: ID = field(converter=ID) - type_: rty.Type + type_: ty.Type location: Location | None @@ -2117,7 +2117,7 @@ def __init__( # noqa: PLR0913 Assign( "RFLX_Transition_Condition", t.condition.expr, - rty.BOOLEAN, + ty.BOOLEAN, ), ], variable_id, @@ -2202,7 +2202,7 @@ def add_conversions(statements: Sequence[Stmt]) -> list[Stmt]: @singledispatch def _convert_expression( expression: Expr, # noqa: ARG001 - target_type: rty.Type, # noqa: ARG001 + target_type: ty.Type, # noqa: ARG001 ) -> Expr: raise NotImplementedError @@ -2211,7 +2211,7 @@ def _convert_expression( @_convert_expression.register(DeltaMsgAgg) def _( expression: MsgAgg | DeltaMsgAgg, - target_type: rty.Type, # noqa: ARG001 + target_type: ty.Type, # noqa: ARG001 ) -> Expr: field_values: dict[ID, Expr] = { f: _convert_expression(v, expression.type_.types[f]) @@ -2231,19 +2231,19 @@ def _( @_convert_expression.register(Expr) def _( expression: BinaryIntExpr | IntExpr | Expr, - target_type: rty.Type, + target_type: ty.Type, ) -> Expr: result: Expr if ( target_type.is_compatible_strong(expression.type_) and not isinstance(expression, BinaryIntExpr) - or not isinstance(target_type, (rty.Integer, rty.Enumeration)) + or not isinstance(target_type, (ty.Integer, ty.Enumeration)) ): return expression if isinstance(expression, BinaryIntExpr): - assert isinstance(target_type, rty.Integer) + assert isinstance(target_type, ty.Integer) left = ( expression.left if target_type.is_compatible_strong(expression.left.type_) @@ -2269,7 +2269,7 @@ def _( ) elif isinstance(expression, IntExpr): - assert isinstance(target_type, rty.Integer) + assert isinstance(target_type, ty.Integer) result = IntConversion( target_type, expression, @@ -2367,5 +2367,5 @@ def add_required_checks( return result -def to_integer(type_: rty.AnyInteger) -> rty.Integer: - return type_ if isinstance(type_, rty.Integer) else rty.BASE_INTEGER +def to_integer(type_: ty.AnyInteger) -> ty.Integer: + return type_ if isinstance(type_, ty.Integer) else ty.BASE_INTEGER diff --git a/rflx/model/declaration.py b/rflx/model/declaration.py index b080f2f5e..576a24593 100644 --- a/rflx/model/declaration.py +++ b/rflx/model/declaration.py @@ -4,8 +4,7 @@ from collections.abc import Callable, Generator, Sequence from typing import ClassVar -import rflx.typing_ as rty -from rflx import expr_conv, ir +from rflx import expr_conv, ir, ty from rflx.common import Base from rflx.error import fail from rflx.expr import Expr, Selected, Variable @@ -32,7 +31,7 @@ def reference(self) -> None: @property @abstractmethod - def type_(self) -> rty.Type: + def type_(self) -> ty.Type: raise NotImplementedError @property @@ -55,12 +54,12 @@ def __init__( self, identifier: StrID, type_identifier: StrID, - type_: rty.Type = rty.UNDEFINED, + type_: ty.Type = ty.UNDEFINED, location: Location | None = None, ): super().__init__(identifier, location) self._type_identifier = ID(type_identifier) - self._type: rty.Type = type_ + self._type: ty.Type = type_ @property def type_identifier(self) -> ID: @@ -71,17 +70,17 @@ def type_identifier(self, identifier: ID) -> None: self._type_identifier = identifier @property - def type_(self) -> rty.Type: + def type_(self) -> ty.Type: return self._type @type_.setter - def type_(self, value: rty.Type) -> None: + def type_(self, value: ty.Type) -> None: self._type = value @abstractmethod def check_type( self, - declaration_type: rty.Type, + declaration_type: ty.Type, typify_variable: Callable[[Expr], Expr], ) -> RecordFluxError: """Set the types of the declaration and variables, and check the types of expressions.""" @@ -96,7 +95,7 @@ def __init__( identifier: StrID, type_identifier: StrID, expression: Expr | None = None, - type_: rty.Type = rty.UNDEFINED, + type_: ty.Type = ty.UNDEFINED, location: Location | None = None, ): super().__init__(identifier, type_identifier, type_, location) @@ -108,7 +107,7 @@ def __str__(self) -> str: def check_type( self, - declaration_type: rty.Type, + declaration_type: ty.Type, typify_variable: Callable[[Expr], Expr], ) -> RecordFluxError: self.type_ = declaration_type @@ -125,7 +124,7 @@ def variables(self) -> Sequence[Variable]: return [] def to_ir(self, variable_id: Generator[ID, None, None]) -> ir.VarDecl: - assert isinstance(self.type_, rty.NamedTypeClass), self.type_ + assert isinstance(self.type_, ty.NamedTypeClass), self.type_ expression = expr_conv.to_ir(self.expression, variable_id) if self.expression else None return ir.VarDecl( self.identifier, @@ -143,7 +142,7 @@ def __init__( identifier: StrID, type_identifier: StrID, expression: Selected, - type_: rty.Type = rty.UNDEFINED, + type_: ty.Type = ty.UNDEFINED, location: Location | None = None, ): super().__init__(identifier, type_identifier, type_, location) @@ -154,7 +153,7 @@ def __str__(self) -> str: def check_type( self, - declaration_type: rty.Type, + declaration_type: ty.Type, typify_variable: Callable[[Expr], Expr], ) -> RecordFluxError: self.type_ = declaration_type @@ -162,11 +161,11 @@ def check_type( assert isinstance(expression, Selected) self.expression = expression - error = self.expression.prefix.check_type_instance(rty.Message) + error = self.expression.prefix.check_type_instance(ty.Message) if error.has_errors(): return error - assert isinstance(self.expression.prefix.type_, rty.Message) + assert isinstance(self.expression.prefix.type_, ty.Message) error = RecordFluxError() for r in self.expression.prefix.type_.refinements: @@ -194,7 +193,7 @@ def check_type( ), ], ) - error.extend(self.expression.check_type(rty.OPAQUE).entries) + error.extend(self.expression.check_type(ty.OPAQUE).entries) return error def variables(self) -> Sequence[Variable]: @@ -216,7 +215,7 @@ def to_ir(self) -> ir.FormalDecl: class Parameter(Base): - def __init__(self, identifier: StrID, type_identifier: StrID, type_: rty.Type = rty.UNDEFINED): + def __init__(self, identifier: StrID, type_identifier: StrID, type_: ty.Type = ty.UNDEFINED): super().__init__() self._identifier = ID(identifier) self._type_identifier = ID(type_identifier) @@ -245,7 +244,7 @@ def __init__( identifier: StrID, parameters: Sequence[Parameter], return_type: StrID, - type_: rty.Type = rty.UNDEFINED, + type_: ty.Type = ty.UNDEFINED, location: Location | None = None, ): super().__init__(identifier, return_type, type_, location) @@ -262,7 +261,7 @@ def __str__(self) -> str: def check_type( self, - declaration_type: rty.Type, + declaration_type: ty.Type, _typify_variable: Callable[[Expr], Expr], ) -> RecordFluxError: self.type_ = declaration_type @@ -311,8 +310,8 @@ def __str__(self) -> str: return f"{self.identifier} : Channel{with_aspects}" @property - def type_(self) -> rty.Type: - return rty.Channel(self.readable, self.writable) + def type_(self) -> ty.Type: + return ty.Channel(self.readable, self.writable) @property def readable(self) -> bool: diff --git a/rflx/model/message.py b/rflx/model/message.py index 72feca0cb..9f4f477ef 100644 --- a/rflx/model/message.py +++ b/rflx/model/message.py @@ -10,8 +10,7 @@ from functools import cached_property, partial from typing import Callable -import rflx.typing_ as rty -from rflx import expr, expr_proof +from rflx import expr, expr_proof, ty from rflx.common import Base, indent, indent_next, unique, verbose_repr from rflx.const import MP_CONTEXT from rflx.error import are_all_locations_present, fail, fatal_fail @@ -532,8 +531,8 @@ def set_refinements(self, refinements: list[Refinement]) -> None: self._refinements = refinements @property - def type_(self) -> rty.Message: - return rty.Message( + def type_(self) -> ty.Message: + return ty.Message( self.full_name, ( { @@ -548,7 +547,7 @@ def type_(self) -> rty.Message: ), {f.identifier: t.type_ for f, t in self._parameter_types.items()}, {f.identifier: t.type_ for f, t in self._field_types.items()}, - [rty.Refinement(r.field.identifier, r.sdu.type_, r.package) for r in self._refinements], + [ty.Refinement(r.field.identifier, r.sdu.type_, r.package) for r in self._refinements], self.is_definite, ) @@ -925,13 +924,13 @@ def _validate_types( if f in parameters: if not isinstance(t, type_decl.Scalar): assert f.identifier.location is not None - additionnal_annotations = [] + additional_annotations = [] if not ( type_decl.is_builtin_type(t.identifier) or type_decl.is_internal_type(t.identifier) ): assert t.identifier.location is not None - additionnal_annotations.append( + additional_annotations.append( Annotation( "type declared here", Severity.NOTE, @@ -940,11 +939,11 @@ def _validate_types( ) self.error.extend( - rty.check_type_instance( + ty.check_type_instance( t.type_, - (rty.Enumeration, rty.AnyInteger), + (ty.Enumeration, ty.AnyInteger), location=f.identifier.location, - additionnal_annotations=additionnal_annotations, + additional_annotations=additional_annotations, ).entries, ) elif isinstance(t, type_decl.Enumeration) and t.always_valid: @@ -1532,7 +1531,7 @@ def _verify_message_types(self) -> None: if expression == expr.UNDEFINED: continue for var in expression.variables(): - if var.type_ == rty.Undefined(): + if var.type_ == ty.Undefined(): self.error.push( ErrorEntry( f'undefined variable "{var.identifier}"', @@ -1550,10 +1549,10 @@ def typed_variable(expression: expr.Expr) -> expr.Expr: def remove_types(expression: expr.Expr) -> expr.Expr: if isinstance(expression, expr.Variable): expression = copy(expression) - expression.type_ = rty.Undefined() + expression.type_ = ty.Undefined() return expression - def check_expr_type(expression: expr.Expr, ty: rty.Type, path: Sequence[Link]) -> None: + def check_expr_type(expression: expr.Expr, ty: ty.Type, path: Sequence[Link]) -> None: entries = expression.check_type(ty).entries assert self.location is not None @@ -1594,7 +1593,7 @@ def check_expr_type(expression: expr.Expr, ty: rty.Type, path: Sequence[Link]) - types[l.source] = self.types[l.source] substituted_condition = l.condition.substituted(typed_variable) - check_expr_type(substituted_condition, rty.BOOLEAN, path) + check_expr_type(substituted_condition, ty.BOOLEAN, path) for expression in ( l.size.substituted(typed_variable), @@ -1603,7 +1602,7 @@ def check_expr_type(expression: expr.Expr, ty: rty.Type, path: Sequence[Link]) - if expression == expr.UNDEFINED: continue - check_expr_type(expression, rty.BASE_INTEGER, path) + check_expr_type(expression, ty.BASE_INTEGER, path) def _verify_links(self) -> None: for link in self.structure: @@ -2608,7 +2607,7 @@ def _verify_field(self) -> None: ) def _verify_condition(self) -> None: - self.error.extend(self.condition.check_type(rty.Any()).entries) + self.error.extend(self.condition.check_type(ty.Any()).entries) if not self.error.has_errors() and self.condition != expr.TRUE: for cond, val in [ @@ -3005,7 +3004,7 @@ def typed_variable(expression: expr.Expr) -> expr.Expr: .substituted(substitute_message_variables) .substituted(typed_variable) ) - merged_condition.check_type(rty.Any()).propagate() + merged_condition.check_type(ty.Any()).propagate() proof = expr_proof.Proof( merged_condition, [ @@ -3400,7 +3399,7 @@ def typed_expression( *qualified_type_names, }, f'variable "{expression.identifier}" has the same name as a literal' if expression.name.lower() == "message": - expression.type_ = rty.OPAQUE + expression.type_ = ty.OPAQUE elif Field(expression.identifier) in types: expression.type_ = types[Field(expression.identifier)].type_ if isinstance(expression, expr.Literal) and expression.identifier in qualified_enum_literals: diff --git a/rflx/model/state_machine.py b/rflx/model/state_machine.py index 86ab08e49..58ba31ccd 100644 --- a/rflx/model/state_machine.py +++ b/rflx/model/state_machine.py @@ -9,7 +9,7 @@ from functools import lru_cache from typing import Final -from rflx import expr, expr_conv, ir, typing_ as rty +from rflx import expr, expr_conv, ir, ty from rflx.common import Base, indent, indent_next, verbose_repr from rflx.identifier import ID, StrID, id_generator from rflx.rapidflux import Annotation, ErrorEntry, Location, RecordFluxError, Severity @@ -159,7 +159,7 @@ def has_expression_exceptions(expression: expr.Expr) -> bool: isinstance(a, (stmt.Append, stmt.Extend, stmt.MessageFieldAssignment)) or ( isinstance(a, stmt.VariableAssignment) - and (isinstance(a.type_, rty.Message) or (has_expression_exceptions(a.expression))) + and (isinstance(a.type_, ty.Message) or (has_expression_exceptions(a.expression))) ) for a in self._actions ) or any(has_expression_exceptions(t.condition) for t in self._transitions) @@ -277,11 +277,11 @@ def find_attribute_prefix(name: ID, expression: expr.Expr) -> bool: ), ) - def substituted(expression: expr.Expr, structure: rty.Structure) -> expr.Expr: + def substituted(expression: expr.Expr, structure: ty.Structure) -> expr.Expr: def replace_expression_type(expression: expr.Expr) -> expr.Expr: if ( isinstance(expression, (expr.Variable, expr.Call)) - and isinstance(expression.type_, rty.Message) + and isinstance(expression.type_, ty.Message) and expression.type_.identifier == structure.identifier ): expression.type_ = structure @@ -308,7 +308,7 @@ def contains_unsupported_feature(name: ID, action: stmt.Statement) -> bool: for name, declaration in self.declarations.items(): if ( not isinstance(declaration, decl.VariableDeclaration) - or not isinstance(declaration.type_, rty.Message) + or not isinstance(declaration.type_, ty.Message) or not declaration.type_.is_definite or ( isinstance(declaration, decl.VariableDeclaration) @@ -325,7 +325,7 @@ def contains_unsupported_feature(name: ID, action: stmt.Statement) -> bool: identifier=declaration.identifier, type_identifier=declaration.type_identifier, expression=declaration.expression, - type_=rty.Structure( + type_=ty.Structure( identifier=declaration.type_.identifier, field_combinations=declaration.type_.field_combinations, parameter_types=declaration.type_.parameter_types, @@ -333,7 +333,7 @@ def contains_unsupported_feature(name: ID, action: stmt.Statement) -> bool: ), location=declaration.location, ) - assert isinstance(message_decl.type_, rty.Structure) + assert isinstance(message_decl.type_, ty.Structure) self.declarations[name] = message_decl for action in self._actions: @@ -725,7 +725,7 @@ def undefined_type(type_identifier: StrID, location: Location | None) -> None: ) else: undefined_type(d.type_identifier, d.location) - d.type_ = rty.Any() + d.type_ = ty.Any() if isinstance(d, decl.FunctionDeclaration): for p in d.parameters: @@ -738,7 +738,7 @@ def undefined_type(type_identifier: StrID, location: Location | None) -> None: p.type_ = argument_type.type_ self._validate_function_parameter_type(parameter_id) else: - p.type_ = rty.Any() + p.type_ = ty.Any() undefined_type(p.type_identifier, d.location) return_type_id = type_decl.internal_type_identifier( @@ -772,7 +772,7 @@ def _validate_function_parameter_type(self, type_identifier: ID) -> None: ) if ( not isinstance(parameter_type, (type_decl.Scalar, Message)) - and parameter_type.identifier != rty.OPAQUE.identifier + and parameter_type.identifier != ty.OPAQUE.identifier ): assert type_identifier.location self.error.extend( @@ -843,7 +843,7 @@ def _validate_actions( try: type_ = declarations[a.identifier].type_ except KeyError: - type_ = rty.Undefined() + type_ = ty.Undefined() self.error.extend( a.check_type( @@ -952,7 +952,7 @@ def _validate_transitions( ) -> None: for t in state.transitions: t.condition = t.condition.substituted(lambda x: self._typify_variable(x, declarations)) - self.error.extend(t.condition.check_type(rty.BOOLEAN).entries) + self.error.extend(t.condition.check_type(ty.BOOLEAN).entries) self._reference_variable_declaration(t.condition.variables(), declarations) t.condition.substituted(lambda e: error_on_unsupported_expression(e, self.error)) @@ -1149,7 +1149,7 @@ def error_on_unsupported_expression(expression: expr.Expr, error: RecordFluxErro # TODO(eng/recordflux/RecordFlux#1497): Support comparisons of opaque fields if isinstance(expression, (expr.Equal, expr.NotEqual)): for e in [expression.left, expression.right]: - if isinstance(e, expr.Selected) and e.type_ == rty.OPAQUE: + if isinstance(e, expr.Selected) and e.type_ == ty.OPAQUE: error.push( ErrorEntry( "comparisons of opaque fields not yet supported", diff --git a/rflx/model/statement.py b/rflx/model/statement.py index d192d709f..406887511 100644 --- a/rflx/model/statement.py +++ b/rflx/model/statement.py @@ -3,7 +3,7 @@ from abc import abstractmethod from collections.abc import Callable, Generator, Mapping, Sequence -from rflx import expr_conv, ir, typing_ as rty +from rflx import expr_conv, ir, ty from rflx.common import Base from rflx.expr import Expr, Variable from rflx.identifier import ID, StrID @@ -14,7 +14,7 @@ class Statement(Base): def __init__( self, identifier: StrID, - type_: rty.Type = rty.UNDEFINED, + type_: ty.Type = ty.UNDEFINED, location: Location | None = None, ): self.identifier = ID(identifier) @@ -24,7 +24,7 @@ def __init__( @abstractmethod def check_type( self, - statement_type: rty.Type, + statement_type: ty.Type, typify_variable: Callable[[Expr], Expr], ) -> RecordFluxError: """Set the types of variables, and check the types of the statement and expressions.""" @@ -45,7 +45,7 @@ def __init__( self, identifier: StrID, expression: Expr, - type_: rty.Type = rty.UNDEFINED, + type_: ty.Type = ty.UNDEFINED, location: Location | None = None, ) -> None: super().__init__(identifier, type_, location) @@ -61,14 +61,14 @@ def __str__(self) -> str: def check_type( self, - statement_type: rty.Type, + statement_type: ty.Type, typify_variable: Callable[[Expr], Expr], ) -> RecordFluxError: self.type_ = statement_type self.expression = self.expression.substituted(typify_variable) - error = rty.check_type_instance( + error = ty.check_type_instance( statement_type, - rty.Any, + ty.Any, self.location, f'variable "{self.identifier}"', ) @@ -76,7 +76,7 @@ def check_type( return error def to_ir(self, variable_id: Generator[ID, None, None]) -> list[ir.Stmt]: - assert isinstance(self.type_, rty.NamedTypeClass) + assert isinstance(self.type_, ty.NamedTypeClass) expression = expr_conv.to_ir(self.expression, variable_id) return [*expression.stmts, ir.Assign(self.identifier, expression.expr, self.type_, self)] @@ -87,7 +87,7 @@ def __init__( message: StrID, field: StrID, expression: Expr, - type_: rty.Type = rty.UNDEFINED, + type_: ty.Type = ty.UNDEFINED, location: Location | None = None, ) -> None: super().__init__(message, expression, type_, location) @@ -99,12 +99,12 @@ def __str__(self) -> str: def check_type( self, - statement_type: rty.Type, + statement_type: ty.Type, typify_variable: Callable[[Expr], Expr], ) -> RecordFluxError: error = RecordFluxError() - field_type: rty.Type = rty.UNDEFINED - if isinstance(statement_type, rty.Message): + field_type: ty.Type = ty.UNDEFINED + if isinstance(statement_type, ty.Message): if self.field in statement_type.fields: field_type = statement_type.types[self.field] elif self.field in statement_type.parameters: @@ -136,9 +136,9 @@ def check_type( self.type_ = statement_type self.expression = self.expression.substituted(typify_variable) error.extend( - rty.check_type_instance( + ty.check_type_instance( statement_type, - rty.Message, + ty.Message, self.location, f'variable "{self.identifier}"', ).entries, @@ -147,7 +147,7 @@ def check_type( return error def to_ir(self, variable_id: Generator[ID, None, None]) -> list[ir.Stmt]: - assert isinstance(self.type_, rty.Message) + assert isinstance(self.type_, ty.Message) expression = expr_conv.to_ir(self.expression, variable_id) return [ @@ -162,7 +162,7 @@ def __init__( identifier: StrID, attribute: str, parameters: list[Expr], - type_: rty.Type = rty.UNDEFINED, + type_: ty.Type = ty.UNDEFINED, location: Location | None = None, ) -> None: super().__init__(identifier, type_, location) @@ -175,7 +175,7 @@ def __str__(self) -> str: def check_type( self, - statement_type: rty.Type, + statement_type: ty.Type, typify_variable: Callable[[Expr], Expr], ) -> RecordFluxError: raise NotImplementedError @@ -192,7 +192,7 @@ def __init__( self, identifier: StrID, parameter: Expr, - type_: rty.Type = rty.UNDEFINED, + type_: ty.Type = ty.UNDEFINED, location: Location | None = None, ) -> None: super().__init__(identifier, self.__class__.__name__, [parameter], type_, location) @@ -201,20 +201,20 @@ def __init__( class Append(ListAttributeStatement): def check_type( self, - statement_type: rty.Type, + statement_type: ty.Type, typify_variable: Callable[[Expr], Expr], ) -> RecordFluxError: self.type_ = statement_type self.parameter = self.parameter.substituted(typify_variable) - error = rty.check_type_instance( + error = ty.check_type_instance( statement_type, - rty.Sequence, + ty.Sequence, self.location, f'variable "{self.identifier}"', ) - if isinstance(statement_type, rty.Sequence): + if isinstance(statement_type, ty.Sequence): error.extend(self.parameter.check_type(statement_type.element).entries) - if isinstance(statement_type.element, rty.Message) and isinstance( + if isinstance(statement_type.element, ty.Message) and isinstance( self.parameter, Variable, ): @@ -247,7 +247,7 @@ def parameter(self, value: Expr) -> None: self.parameters[0] = value def to_ir(self, variable_id: Generator[ID, None, None]) -> list[ir.Stmt]: - assert isinstance(self.type_, rty.Sequence) + assert isinstance(self.type_, ty.Sequence) parameter = expr_conv.to_ir(self.parameter, variable_id) return [ *parameter.stmts, @@ -258,14 +258,14 @@ def to_ir(self, variable_id: Generator[ID, None, None]) -> list[ir.Stmt]: class Extend(ListAttributeStatement): def check_type( self, - statement_type: rty.Type, + statement_type: ty.Type, typify_variable: Callable[[Expr], Expr], ) -> RecordFluxError: self.type_ = statement_type self.parameter = self.parameter.substituted(typify_variable) - error = rty.check_type_instance( + error = ty.check_type_instance( statement_type, - rty.Sequence, + ty.Sequence, self.location, f'variable "{self.identifier}"', ) @@ -282,7 +282,7 @@ def parameter(self, value: Expr) -> None: self.parameters[0] = value def to_ir(self, variable_id: Generator[ID, None, None]) -> list[ir.Stmt]: - assert isinstance(self.type_, rty.Sequence) + assert isinstance(self.type_, ty.Sequence) parameter = expr_conv.to_ir(self.parameter, variable_id) return [ *parameter.stmts, @@ -295,7 +295,7 @@ def __init__( self, identifier: StrID, associations: Mapping[ID, Expr] | None = None, - type_: rty.Type = rty.UNDEFINED, + type_: ty.Type = ty.UNDEFINED, location: Location | None = None, ) -> None: super().__init__(identifier, self.__class__.__name__, [], type_, location) @@ -309,7 +309,7 @@ def __str__(self) -> str: def check_type( self, - statement_type: rty.Type, + statement_type: ty.Type, typify_variable: Callable[[Expr], Expr], ) -> RecordFluxError: error = RecordFluxError() @@ -317,7 +317,7 @@ def check_type( self.associations = { i: e.substituted(typify_variable) for i, e in self.associations.items() } - if isinstance(statement_type, rty.Sequence): + if isinstance(statement_type, ty.Sequence): for i, e in self.associations.items(): error.push( ErrorEntry( @@ -326,7 +326,7 @@ def check_type( e.location, ), ) - elif isinstance(statement_type, rty.Message): + elif isinstance(statement_type, ty.Message): for i, e in self.associations.items(): if i in statement_type.parameter_types: error.extend(e.check_type(statement_type.parameter_types[i]).entries) @@ -348,9 +348,9 @@ def check_type( ), ) error.extend( - rty.check_type_instance( + ty.check_type_instance( statement_type, - (rty.Sequence, rty.Message), + (ty.Sequence, ty.Message), self.location, f'variable "{self.identifier}"', ).entries, @@ -364,7 +364,7 @@ def variables(self) -> Sequence[Variable]: ] def to_ir(self, variable_id: Generator[ID, None, None]) -> list[ir.Stmt]: - assert isinstance(self.type_, (rty.Sequence, rty.Message)) + assert isinstance(self.type_, (ty.Sequence, ty.Message)) associations = {} stmts = [] for i, e in self.associations.items(): @@ -382,25 +382,25 @@ def __init__( self, identifier: StrID, parameter: Expr, - type_: rty.Type = rty.UNDEFINED, + type_: ty.Type = ty.UNDEFINED, location: Location | None = None, ) -> None: super().__init__(identifier, self.__class__.__name__, [parameter], type_, location) def check_type( self, - statement_type: rty.Type, + statement_type: ty.Type, typify_variable: Callable[[Expr], Expr], ) -> RecordFluxError: self.type_ = statement_type self.parameters = [self.parameter.substituted(typify_variable)] - error = rty.check_type( + error = ty.check_type( statement_type, self._expected_channel_type(), self.location, f'channel "{self.identifier}"', ) - error.extend(self.parameter.check_type_instance(rty.Message).entries) + error.extend(self.parameter.check_type_instance(ty.Message).entries) return error @property @@ -415,15 +415,15 @@ def to_ir(self, variable_id: Generator[ID, None, None]) -> list[ir.Stmt]: ] @abstractmethod - def _expected_channel_type(self) -> rty.Channel: + def _expected_channel_type(self) -> ty.Channel: raise NotImplementedError class Read(ChannelAttributeStatement): - def _expected_channel_type(self) -> rty.Channel: - return rty.Channel(readable=True, writable=False) + def _expected_channel_type(self) -> ty.Channel: + return ty.Channel(readable=True, writable=False) class Write(ChannelAttributeStatement): - def _expected_channel_type(self) -> rty.Channel: - return rty.Channel(readable=False, writable=True) + def _expected_channel_type(self) -> ty.Channel: + return ty.Channel(readable=False, writable=True) diff --git a/rflx/model/type_decl.py b/rflx/model/type_decl.py index 3041c2d84..c5b8ea10e 100644 --- a/rflx/model/type_decl.py +++ b/rflx/model/type_decl.py @@ -7,12 +7,11 @@ from pathlib import Path from typing import Literal -import rflx.typing_ as rty -from rflx import const, expr +from rflx import const, expr, ty from rflx.common import indent_next, verbose_repr from rflx.error import fail from rflx.identifier import ID, StrID -from rflx.rapidflux import Annotation, ErrorEntry, Location, RecordFluxError, Severity, ty +from rflx.rapidflux import Annotation, ErrorEntry, Location, RecordFluxError, Severity from . import message from .top_level_declaration import TopLevelDeclaration, UncheckedTopLevelDeclaration @@ -20,8 +19,8 @@ class TypeDecl(TopLevelDeclaration): @property - def type_(self) -> rty.Type: - return rty.Undefined() + def type_(self) -> ty.Type: + return ty.Undefined() @property def direct_dependencies(self) -> list[TypeDecl]: @@ -216,8 +215,8 @@ def __str__(self) -> str: ) @property - def type_(self) -> rty.Type: - return rty.Integer( + def type_(self) -> ty.Type: + return ty.Integer( self.full_name, ty.Bounds(self.first.value, self.last.value), location=self.location, @@ -432,8 +431,8 @@ def __str__(self) -> str: return result @property - def type_(self) -> rty.Type: - return rty.Enumeration( + def type_(self) -> ty.Type: + return ty.Enumeration( self.full_name, list(map(ID, self.literals.keys())), self.always_valid, @@ -610,8 +609,8 @@ def __str__(self) -> str: return f"type {self.name} is sequence of {self.element_type.qualified_identifier}" @property - def type_(self) -> rty.Type: - return rty.Sequence(self.full_name, self.element_type.type_) + def type_(self) -> ty.Type: + return ty.Sequence(self.full_name, self.element_type.type_) @property def element_size(self) -> expr.Expr: @@ -637,8 +636,8 @@ def __str__(self) -> str: return "" @property - def type_(self) -> rty.Type: - return rty.OPAQUE + def type_(self) -> ty.Type: + return ty.OPAQUE @property def element_size(self) -> expr.Expr: @@ -760,7 +759,7 @@ def checked( ], expr.Number(1), always_valid=False, - location=rty.BOOLEAN.location, + location=ty.BOOLEAN.location, ) BOOLEAN = UNCHECKED_BOOLEAN.checked([]) diff --git a/rflx/rapidflux/ty.pyi b/rflx/rapidflux/ty.pyi index f91acfe06..d06847922 100644 --- a/rflx/rapidflux/ty.pyi +++ b/rflx/rapidflux/ty.pyi @@ -2,9 +2,10 @@ from collections import abc from typing import Final from rflx.identifier import StrID -from rflx.rapidflux import ID, Location +from rflx.rapidflux import ID, Annotation, Location, RecordFluxError class Builtins: + UNDEFINED: Final[Undefined] BOOLEAN: Final[Enumeration] INDEX: Final[Integer] BIT_LENGTH: Final[Integer] @@ -158,3 +159,18 @@ class Bounds: @property def upper(self) -> int: ... def merge(self, bounds: Bounds) -> Bounds: ... + +def common_type(types: abc.Sequence[Type]) -> Type: ... +def check_type( + actual: Type, + expected: Type | tuple[Type, ...], + location: Location | None, + description: str, +) -> RecordFluxError: ... +def check_type_instance( + actual: Type, + expected: type[Type] | tuple[type[Type], ...], + location: Location | None, + description: str = "", + additional_annotations: abc.Sequence[Annotation] | None = None, +) -> RecordFluxError: ... diff --git a/rflx/specification/parser.py b/rflx/specification/parser.py index 78c948167..828c9dad4 100644 --- a/rflx/specification/parser.py +++ b/rflx/specification/parser.py @@ -7,8 +7,7 @@ from pathlib import Path from typing import Iterable -import rflx.typing_ as rty -from rflx import expr, lang, model +from rflx import expr, lang, model, ty from rflx.common import STDIN, unique from rflx.const import RESERVED_WORDS from rflx.error import fail @@ -600,7 +599,7 @@ def create_variable(error: RecordFluxError, expression: lang.Expr, filename: Pat return expr.Literal( create_id(error, expression.f_identifier, filename), location=location, - type_=rty.BOOLEAN, + type_=ty.BOOLEAN, ) return expr.Variable(create_id(error, expression.f_identifier, filename), location=location) @@ -676,7 +675,7 @@ def create_call(error: RecordFluxError, expression: lang.Expr, filename: Path) - assert isinstance(expression, lang.Call) return expr.Call( create_id(error, expression.f_identifier, filename), - rty.UNDEFINED, + ty.UNDEFINED, [create_expression(error, a, filename) for a in expression.f_arguments], location=node_location(expression, filename), ) @@ -1021,7 +1020,7 @@ def create_math_expression( } validate_handler(error, "math expression", expression, list(handlers.keys()), filename) result = handlers[expression.kind_name](error, expression, filename) - if result.type_ == rty.BOOLEAN: + if result.type_ == ty.BOOLEAN: fail( "boolean expression in math context", location=node_location(expression, filename), @@ -1286,7 +1285,7 @@ def extract_then( condition = ( create_bool_expression(error, then.f_condition, filename) if then.f_condition - else expr.Literal("True", type_=rty.BOOLEAN, location=node_location(then, filename)) + else expr.Literal("True", type_=ty.BOOLEAN, location=node_location(then, filename)) ) size, first = extract_aspect(then.f_aspects) return target, condition, size, first, node_location(then, filename) diff --git a/rflx/ty.py b/rflx/ty.py new file mode 100644 index 000000000..a87d6affa --- /dev/null +++ b/rflx/ty.py @@ -0,0 +1,40 @@ +from __future__ import annotations + +from typing import Final, Union + +from typing_extensions import TypeAlias + +from rflx.rapidflux.ty import ( + Aggregate as Aggregate, + Any as Any, + AnyInteger as AnyInteger, + Bounds as Bounds, + Builtins as Builtins, + Channel as Channel, + Composite as Composite, + Compound as Compound, + Enumeration as Enumeration, + Integer as Integer, + Message as Message, + Refinement as Refinement, + Sequence as Sequence, + Structure as Structure, + Type as Type, + Undefined as Undefined, + UniversalInteger as UniversalInteger, + check_type as check_type, + check_type_instance as check_type_instance, + common_type as common_type, +) + +NamedType: TypeAlias = Union[Enumeration, Integer, Message, Sequence, Structure] +NamedTypeClass = (Enumeration, Integer, Message, Sequence, Structure) + +UNDEFINED: Final = Builtins.UNDEFINED +BOOLEAN: Final = Builtins.BOOLEAN +INDEX: Final = Builtins.INDEX +BIT_LENGTH: Final = Builtins.BIT_LENGTH +BIT_INDEX: Final = Builtins.BIT_INDEX +UNIVERSAL_INTEGER: Final = Builtins.UNIVERSAL_INTEGER +BASE_INTEGER: Final = Builtins.BASE_INTEGER +OPAQUE: Final = Builtins.OPAQUE diff --git a/rflx/typing_.py b/rflx/typing_.py deleted file mode 100644 index 5650dc15a..000000000 --- a/rflx/typing_.py +++ /dev/null @@ -1,141 +0,0 @@ -from __future__ import annotations - -from collections import abc -from typing import Final, Union - -from typing_extensions import TypeAlias - -from rflx.rapidflux import Annotation, ErrorEntry, Location, RecordFluxError, Severity -from rflx.rapidflux.ty import ( - Aggregate as Aggregate, - Any as Any, - AnyInteger as AnyInteger, - Builtins as Builtins, - Channel as Channel, - Composite as Composite, - Compound as Compound, - Enumeration as Enumeration, - Integer as Integer, - Message as Message, - Refinement as Refinement, - Sequence as Sequence, - Structure as Structure, - Type as Type, - Undefined as Undefined, - UniversalInteger as UniversalInteger, -) - -NamedType: TypeAlias = Union[Enumeration, Integer, Message, Sequence, Structure] -NamedTypeClass = (Enumeration, Integer, Message, Sequence, Structure) - -UNDEFINED: Final = Undefined() -BOOLEAN: Final = Builtins.BOOLEAN -INDEX: Final = Builtins.INDEX -BIT_LENGTH: Final = Builtins.BIT_LENGTH -BIT_INDEX: Final = Builtins.BIT_INDEX -UNIVERSAL_INTEGER: Final = Builtins.UNIVERSAL_INTEGER -BASE_INTEGER: Final = Builtins.BASE_INTEGER -OPAQUE: Final = Builtins.OPAQUE - - -def common_type(types: abc.Sequence[Type]) -> Type: - result: Type = Any() - - for t in types: - result = result.common_type(t) - - return result - - -def check_type( - actual: Type, - expected: Type | tuple[Type, ...], - location: Location | None, - description: str, -) -> RecordFluxError: - assert expected, "empty expected types" - - if actual == Undefined(): - return _undefined_type(location, description) - - error = RecordFluxError() - - expected_types = [expected] if isinstance(expected, Type) else list(expected) - - if Undefined() not in [actual, *expected_types] and all( - not actual.is_compatible(t) for t in expected_types - ): - desc = ( - " or ".join(map(str, expected_types)) if isinstance(expected, tuple) else str(expected) - ) - - assert location is not None - error.push( - ErrorEntry( - f"expected {desc}", - Severity.ERROR, - location, - annotations=( - [ - Annotation( - f"found {actual}", - Severity.ERROR, - location, - ), - ] - ), - generate_default_annotation=False, - ), - ) - - return error - - -def check_type_instance( - actual: Type, - expected: type[Type] | tuple[type[Type], ...], - location: Location | None, - description: str = "", - additionnal_annotations: abc.Sequence[Annotation] | None = None, -) -> RecordFluxError: - assert expected, "empty expected types" - - if actual == Undefined(): - return _undefined_type(location, description) - - error = RecordFluxError() - - if not isinstance(actual, expected) and actual != Any(): - additionnal_annotations = additionnal_annotations or [] - desc = ( - " or ".join(e.DESCRIPTIVE_NAME for e in expected) - if isinstance(expected, tuple) - else expected.DESCRIPTIVE_NAME - ) - assert location is not None - error.push( - ErrorEntry( - f"expected {desc}", - Severity.ERROR, - location, - annotations=[ - Annotation(f"found {actual}", Severity.ERROR, location), - *additionnal_annotations, - ], - generate_default_annotation=False, - ), - ) - - return error - - -def _undefined_type(location: Location | None, description: str = "") -> RecordFluxError: - return RecordFluxError( - [ - ErrorEntry( - "undefined" + (f" {description}" if description else ""), - Severity.ERROR, - location, - ), - ], - ) diff --git a/tests/integration/specification_model_test.py b/tests/integration/specification_model_test.py index 9e0a9ac72..c333bd1de 100644 --- a/tests/integration/specification_model_test.py +++ b/tests/integration/specification_model_test.py @@ -6,7 +6,7 @@ import pytest -from rflx import expr, typing_ as rty +from rflx import expr, ty from rflx.const import RESERVED_WORDS from rflx.identifier import ID from rflx.model import ( @@ -646,7 +646,7 @@ def test_consistency_specification_parsing_generation(tmp_path: Path) -> None: condition=expr.And( expr.Equal(expr.Variable("Z"), expr.TRUE), expr.Equal( - expr.Call("G", rty.BOOLEAN, [expr.Variable("F")]), + expr.Call("G", ty.BOOLEAN, [expr.Variable("F")]), expr.TRUE, ), ), diff --git a/tests/property/strategies.py b/tests/property/strategies.py index 7f21f25d9..5dcaf835e 100644 --- a/tests/property/strategies.py +++ b/tests/property/strategies.py @@ -7,7 +7,7 @@ from hypothesis import assume, strategies as st -from rflx import const, expr, typing_ as rty +from rflx import const, expr, ty from rflx.identifier import ID from rflx.model import ( BUILTIN_TYPES, @@ -400,7 +400,7 @@ def attributes(draw: Draw, elements: st.SearchStrategy[expr.Expr]) -> expr.Expr: @st.composite def calls(draw: Draw, elements: st.SearchStrategy[expr.Expr]) -> expr.Call: return draw( - st.builds(expr.Call, identifiers(), st.just(rty.Undefined), st.lists(elements, min_size=1)), + st.builds(expr.Call, identifiers(), st.just(ty.Undefined), st.lists(elements, min_size=1)), ) diff --git a/tests/unit/ada_test.py b/tests/unit/ada_test.py index e829f7fdf..ee813e39f 100644 --- a/tests/unit/ada_test.py +++ b/tests/unit/ada_test.py @@ -5,7 +5,7 @@ import pytest -from rflx import ada, expr, typing_ as rty +from rflx import ada, expr, ty from rflx.identifier import ID from tests.utils import assert_equal @@ -203,7 +203,7 @@ def test_indexed_rflx_expr() -> None: def test_call_rflx_expr() -> None: assert ada.Call("X", [ada.Variable("Y"), ada.Variable("Z")]).rflx_expr() == expr.Call( "X", - rty.UNDEFINED, + ty.UNDEFINED, [expr.Variable("Y"), expr.Variable("Z")], ) diff --git a/tests/unit/expr_conv_test.py b/tests/unit/expr_conv_test.py index eb3fa4cbe..7aa6ebe7c 100644 --- a/tests/unit/expr_conv_test.py +++ b/tests/unit/expr_conv_test.py @@ -2,7 +2,7 @@ import pytest -from rflx import ada, expr_conv, ir, typing_ as rty +from rflx import ada, expr_conv, ir, ty from rflx.expr import ( FALSE, TRUE, @@ -58,12 +58,11 @@ Variable, ) from rflx.identifier import ID, id_generator -from rflx.rapidflux import ty -INT_TY = rty.Integer("I", ty.Bounds(10, 100)) -ENUM_TY = rty.Enumeration("E", [ID("E1"), ID("E2")]) -MSG_TY = rty.Message("M") -SEQ_TY = rty.Sequence("S", rty.Message("M")) +INT_TY = ty.Integer("I", ty.Bounds(10, 100)) +ENUM_TY = ty.Enumeration("E", [ID("E1"), ID("E2")]) +MSG_TY = ty.Message("M") +SEQ_TY = ty.Sequence("S", ty.Message("M")) @pytest.mark.parametrize("expression", [And, AndThen, Or, OrElse]) @@ -147,7 +146,7 @@ def test_to_ir_not() -> None: ir.Not(ir.BoolVal(value=True)), ) assert expr_conv.to_ir( - Not(Variable("X", type_=rty.BOOLEAN)), + Not(Variable("X", type_=ty.BOOLEAN)), id_generator(), ) == ir.ComplexBoolExpr( [], @@ -158,8 +157,8 @@ def test_to_ir_not() -> None: id_generator(), ) == ir.ComplexBoolExpr( [ - ir.VarDecl("T_0", rty.BOOLEAN), - ir.Assign("T_0", ir.Less(ir.IntVar("X", INT_TY), ir.IntVar("Y", INT_TY)), rty.BOOLEAN), + ir.VarDecl("T_0", ty.BOOLEAN), + ir.Assign("T_0", ir.Less(ir.IntVar("X", INT_TY), ir.IntVar("Y", INT_TY)), ty.BOOLEAN), ], ir.Not(ir.BoolVar("T_0")), ) @@ -175,7 +174,7 @@ def test_to_ir_and_or( # type: ignore[misc] ir.BoolVal(value=True), ) assert expr_conv.to_ir( - op(Variable("X", type_=rty.BOOLEAN)), + op(Variable("X", type_=ty.BOOLEAN)), id_generator(), ) == ir.ComplexBoolExpr( [], @@ -183,31 +182,31 @@ def test_to_ir_and_or( # type: ignore[misc] ) assert expr_conv.to_ir( op( - Variable("X", type_=rty.BOOLEAN), - Variable("Y", type_=rty.BOOLEAN), - Variable("Z", type_=rty.BOOLEAN), + Variable("X", type_=ty.BOOLEAN), + Variable("Y", type_=ty.BOOLEAN), + Variable("Z", type_=ty.BOOLEAN), ), id_generator(), ) == ir.ComplexBoolExpr( [ - ir.VarDecl("T_0", rty.BOOLEAN), - ir.Assign("T_0", ir_op(ir.BoolVar("Y"), ir.BoolVar("Z")), rty.BOOLEAN), + ir.VarDecl("T_0", ty.BOOLEAN), + ir.Assign("T_0", ir_op(ir.BoolVar("Y"), ir.BoolVar("Z")), ty.BOOLEAN), ], ir_op(ir.BoolVar("X"), ir.BoolVar("T_0")), ) assert expr_conv.to_ir( op( op( - Variable("X", type_=rty.BOOLEAN), - Variable("Y", type_=rty.BOOLEAN), + Variable("X", type_=ty.BOOLEAN), + Variable("Y", type_=ty.BOOLEAN), ), - Variable("Z", type_=rty.BOOLEAN), + Variable("Z", type_=ty.BOOLEAN), ), id_generator(), ) == ir.ComplexBoolExpr( [ - ir.VarDecl("T_0", rty.BOOLEAN), - ir.Assign("T_0", ir_op(ir.BoolVar("X"), ir.BoolVar("Y")), rty.BOOLEAN), + ir.VarDecl("T_0", ty.BOOLEAN), + ir.Assign("T_0", ir_op(ir.BoolVar("X"), ir.BoolVar("Y")), ty.BOOLEAN), ], ir_op(ir.BoolVar("T_0"), ir.BoolVar("Z")), ) @@ -227,14 +226,14 @@ def test_to_ir_neg() -> None: id_generator(), ) == ir.ComplexIntExpr( [ - ir.VarDecl("T_0", rty.BASE_INTEGER), + ir.VarDecl("T_0", ty.BASE_INTEGER), ir.Assign( "T_0", ir.Add(ir.IntVar("X", INT_TY), ir.IntVar("Y", INT_TY)), - rty.BASE_INTEGER, + ty.BASE_INTEGER, ), ], - ir.Neg(ir.IntVar("T_0", rty.BASE_INTEGER)), + ir.Neg(ir.IntVar("T_0", ty.BASE_INTEGER)), ) @@ -260,14 +259,14 @@ def test_to_ir_add_mul( # type: ignore[misc] id_generator(), ) == ir.ComplexIntExpr( [ - ir.VarDecl("T_0", rty.BASE_INTEGER), + ir.VarDecl("T_0", ty.BASE_INTEGER), ir.Assign( "T_0", ir_op(ir.IntVar("Y", INT_TY), ir.IntVar("Z", INT_TY)), - rty.BASE_INTEGER, + ty.BASE_INTEGER, ), ], - ir_op(ir.IntVar("X", INT_TY), ir.IntVar("T_0", rty.BASE_INTEGER)), + ir_op(ir.IntVar("X", INT_TY), ir.IntVar("T_0", ty.BASE_INTEGER)), ) assert expr_conv.to_ir( op( @@ -280,14 +279,14 @@ def test_to_ir_add_mul( # type: ignore[misc] id_generator(), ) == ir.ComplexIntExpr( [ - ir.VarDecl("T_0", rty.BASE_INTEGER), + ir.VarDecl("T_0", ty.BASE_INTEGER), ir.Assign( "T_0", ir_op(ir.IntVar("X", INT_TY), ir.IntVar("Y", INT_TY)), - rty.BASE_INTEGER, + ty.BASE_INTEGER, ), ], - ir_op(ir.IntVar("T_0", rty.BASE_INTEGER), ir.IntVar("Z", INT_TY)), + ir_op(ir.IntVar("T_0", ty.BASE_INTEGER), ir.IntVar("Z", INT_TY)), ) @@ -320,14 +319,14 @@ def test_to_ir_sub_div_pow_mod( # type: ignore[misc] id_generator(), ) == ir.ComplexIntExpr( [ - ir.VarDecl("T_0", rty.BASE_INTEGER), + ir.VarDecl("T_0", ty.BASE_INTEGER), ir.Assign( "T_0", ir_op(ir.IntVar("X", INT_TY), ir.IntVar("Y", INT_TY)), - rty.BASE_INTEGER, + ty.BASE_INTEGER, ), ], - ir_op(ir.IntVar("T_0", rty.BASE_INTEGER), ir.IntVar("Z", INT_TY)), + ir_op(ir.IntVar("T_0", ty.BASE_INTEGER), ir.IntVar("Z", INT_TY)), ) @@ -339,7 +338,7 @@ def test_to_ir_literal() -> None: def test_to_ir_variable() -> None: - assert expr_conv.to_ir(Variable("X", type_=rty.BOOLEAN), id_generator()) == ir.ComplexBoolExpr( + assert expr_conv.to_ir(Variable("X", type_=ty.BOOLEAN), id_generator()) == ir.ComplexBoolExpr( [], ir.BoolVar("X"), ) @@ -383,7 +382,7 @@ def test_to_ir_attribute_int(attribute: Expr, ir_attribute: ir.Expr) -> None: (Valid(Variable("X", type_=MSG_TY)), ir.Valid("X", MSG_TY)), (Present(Variable("X", type_=MSG_TY)), ir.Present("X", MSG_TY)), ( - HasData(Variable("X", type_=rty.Channel(readable=True, writable=False))), + HasData(Variable("X", type_=ty.Channel(readable=True, writable=False))), ir.HasData("X", MSG_TY), ), ], @@ -398,16 +397,16 @@ def test_to_ir_aggregate() -> None: id_generator(), ) == ir.ComplexExpr( [ - ir.VarDecl("T_0", rty.BASE_INTEGER), - ir.Assign("T_0", ir.First("X", INT_TY), rty.BASE_INTEGER), - ir.VarDecl("T_1", rty.BASE_INTEGER), + ir.VarDecl("T_0", ty.BASE_INTEGER), + ir.Assign("T_0", ir.First("X", INT_TY), ty.BASE_INTEGER), + ir.VarDecl("T_1", ty.BASE_INTEGER), ir.Assign( "T_1", - ir.Add(ir.IntVar("T_0", rty.BASE_INTEGER), ir.IntVal(1)), - rty.BASE_INTEGER, + ir.Add(ir.IntVar("T_0", ty.BASE_INTEGER), ir.IntVal(1)), + ty.BASE_INTEGER, ), ], - ir.Agg([ir.IntVar("T_1", rty.BASE_INTEGER), ir.IntVal(2)]), + ir.Agg([ir.IntVar("T_1", ty.BASE_INTEGER), ir.IntVal(2)]), ) @@ -435,8 +434,8 @@ def test_to_ir_relation( # type: ignore[misc] def test_to_ir_if_expr() -> None: assert expr_conv.to_ir( IfExpr( - [(Variable("X", type_=rty.BOOLEAN), Variable("Y", type_=rty.BOOLEAN))], - Variable("Z", type_=rty.BOOLEAN), + [(Variable("X", type_=ty.BOOLEAN), Variable("Y", type_=ty.BOOLEAN))], + Variable("Z", type_=ty.BOOLEAN), ), id_generator(), ) == ir.ComplexBoolExpr( @@ -445,12 +444,12 @@ def test_to_ir_if_expr() -> None: ir.BoolVar("X"), ir.ComplexBoolExpr([], ir.BoolVar("Y")), ir.ComplexBoolExpr([], ir.BoolVar("Z")), - rty.BOOLEAN, + ty.BOOLEAN, ), ) assert expr_conv.to_ir( IfExpr( - [(Variable("X", type_=rty.BOOLEAN), Variable("Y", type_=INT_TY))], + [(Variable("X", type_=ty.BOOLEAN), Variable("Y", type_=INT_TY))], Variable("Z", type_=INT_TY), ), id_generator(), @@ -467,7 +466,7 @@ def test_to_ir_if_expr() -> None: IfExpr( [ ( - And(Variable("X", type_=rty.BOOLEAN), TRUE), + And(Variable("X", type_=ty.BOOLEAN), TRUE), Add(Variable("Y", type_=INT_TY), Number(1)), ), ], @@ -476,8 +475,8 @@ def test_to_ir_if_expr() -> None: id_generator(), ) == ir.ComplexIntExpr( [ - ir.VarDecl("T_0", rty.BOOLEAN), - ir.Assign("T_0", ir.And(ir.BoolVar("X"), ir.BoolVal(value=True)), rty.BOOLEAN), + ir.VarDecl("T_0", ty.BOOLEAN), + ir.Assign("T_0", ir.And(ir.BoolVar("X"), ir.BoolVal(value=True)), ty.BOOLEAN), ], ir.IntIfExpr( ir.BoolVar("T_0"), @@ -494,15 +493,15 @@ def test_to_ir_number() -> None: def test_to_ir_selected() -> None: assert expr_conv.to_ir( - Selected(Variable("X", type_=rty.Message("M")), "Y", type_=rty.BOOLEAN), + Selected(Variable("X", type_=ty.Message("M")), "Y", type_=ty.BOOLEAN), id_generator(), ) == ir.ComplexExpr([], ir.BoolFieldAccess("X", "Y", MSG_TY)) assert expr_conv.to_ir( - Selected(Variable("X", type_=rty.Message("M")), "Y", type_=INT_TY), + Selected(Variable("X", type_=ty.Message("M")), "Y", type_=INT_TY), id_generator(), ) == ir.ComplexExpr([], ir.IntFieldAccess("X", "Y", MSG_TY)) assert expr_conv.to_ir( - Neg(Selected(Variable("X", type_=rty.Message("M")), "Y", type_=INT_TY)), + Neg(Selected(Variable("X", type_=ty.Message("M")), "Y", type_=INT_TY)), id_generator(), ) == ir.ComplexIntExpr( [ @@ -512,7 +511,7 @@ def test_to_ir_selected() -> None: ir.Neg(ir.IntVar("T_0", INT_TY)), ) assert expr_conv.to_ir( - Selected(Variable("X", type_=rty.Message("M")), "Y", type_=SEQ_TY), + Selected(Variable("X", type_=ty.Message("M")), "Y", type_=SEQ_TY), id_generator(), ) == ir.ComplexExpr([], ir.ObjFieldAccess("X", "Y", MSG_TY)) @@ -522,30 +521,30 @@ def test_to_ir_call() -> None: Call( "X", INT_TY, - [Variable("Y", type_=rty.BOOLEAN), Variable("Z", type_=INT_TY)], + [Variable("Y", type_=ty.BOOLEAN), Variable("Z", type_=INT_TY)], ), id_generator(), ) == ir.ComplexExpr( [], - ir.IntCall("X", [ir.BoolVar("Y"), ir.IntVar("Z", INT_TY)], [rty.BOOLEAN, INT_TY], INT_TY), + ir.IntCall("X", [ir.BoolVar("Y"), ir.IntVar("Z", INT_TY)], [ty.BOOLEAN, INT_TY], INT_TY), ) assert expr_conv.to_ir( Call( "X", - rty.BOOLEAN, - [Variable("Y", type_=rty.BOOLEAN), Variable("Z", type_=rty.BOOLEAN)], + ty.BOOLEAN, + [Variable("Y", type_=ty.BOOLEAN), Variable("Z", type_=ty.BOOLEAN)], ), id_generator(), ) == ir.ComplexExpr( [], - ir.BoolCall("X", [ir.BoolVar("Y"), ir.BoolVar("Z")], [rty.BOOLEAN, rty.BOOLEAN]), + ir.BoolCall("X", [ir.BoolVar("Y"), ir.BoolVar("Z")], [ty.BOOLEAN, ty.BOOLEAN]), ) assert expr_conv.to_ir( Call( "X", - rty.BOOLEAN, + ty.BOOLEAN, [ - And(Variable("X", type_=rty.BOOLEAN), TRUE), + And(Variable("X", type_=ty.BOOLEAN), TRUE), Add(Variable("Y", type_=INT_TY), Number(1)), ], ), @@ -558,25 +557,25 @@ def test_to_ir_call() -> None: ir.And(ir.BoolVar("X"), ir.BoolVal(value=True)), ir.Add(ir.IntVar("Y", INT_TY), ir.IntVal(1)), ], - [rty.BOOLEAN, INT_TY], + [ty.BOOLEAN, INT_TY], ), ) assert expr_conv.to_ir( Call( "X", MSG_TY, - [Variable("Y", type_=rty.BOOLEAN), Variable("Z", type_=INT_TY)], + [Variable("Y", type_=ty.BOOLEAN), Variable("Z", type_=INT_TY)], ), id_generator(), ) == ir.ComplexExpr( [], - ir.ObjCall("X", [ir.BoolVar("Y"), ir.IntVar("Z", INT_TY)], [rty.BOOLEAN, INT_TY], MSG_TY), + ir.ObjCall("X", [ir.BoolVar("Y"), ir.IntVar("Z", INT_TY)], [ty.BOOLEAN, INT_TY], MSG_TY), ) def test_to_ir_conversion() -> None: assert expr_conv.to_ir( - Conversion("I", Variable("Y", type_=rty.BOOLEAN), type_=INT_TY), + Conversion("I", Variable("Y", type_=ty.BOOLEAN), type_=INT_TY), id_generator(), ) == ir.ComplexExpr([], ir.Conversion(INT_TY, ir.BoolVar("Y"))) @@ -585,7 +584,7 @@ def test_to_ir_comprehension() -> None: assert expr_conv.to_ir( Comprehension( "X", - Selected(Variable("M", type_=rty.Message("M")), "Y", type_=rty.Sequence("S", INT_TY)), + Selected(Variable("M", type_=ty.Message("M")), "Y", type_=ty.Sequence("S", INT_TY)), Add(Variable("X", type_=INT_TY), Variable("Y", type_=INT_TY), Number(1)), Less(Sub(Variable("X", type_=INT_TY), Number(1)), Number(ir.INT_MAX)), ), @@ -597,21 +596,21 @@ def test_to_ir_comprehension() -> None: ir.ObjFieldAccess("M", ID("Y"), MSG_TY), ir.ComplexExpr( [ - ir.VarDecl("T_0", rty.BASE_INTEGER), + ir.VarDecl("T_0", ty.BASE_INTEGER), ir.Assign( "T_0", ir.Add(ir.IntVar("Y", INT_TY), ir.IntVal(1)), - rty.BASE_INTEGER, + ty.BASE_INTEGER, ), ], - ir.Add(ir.IntVar("X", INT_TY), ir.IntVar("T_0", rty.BASE_INTEGER)), + ir.Add(ir.IntVar("X", INT_TY), ir.IntVar("T_0", ty.BASE_INTEGER)), ), ir.ComplexBoolExpr( [ - ir.VarDecl("T_1", rty.BASE_INTEGER), - ir.Assign("T_1", ir.Add(ir.IntVar("X", INT_TY), ir.IntVal(-1)), rty.BOOLEAN), + ir.VarDecl("T_1", ty.BASE_INTEGER), + ir.Assign("T_1", ir.Add(ir.IntVar("X", INT_TY), ir.IntVal(-1)), ty.BOOLEAN), ], - ir.Less(ir.IntVar("T_1", rty.BASE_INTEGER), ir.IntVal(ir.INT_MAX)), + ir.Less(ir.IntVar("T_1", ty.BASE_INTEGER), ir.IntVal(ir.INT_MAX)), ), ), ) @@ -630,9 +629,9 @@ def test_to_ir_message_aggregate( # type: ignore[misc] "X", { "Y": Selected( - Variable("M", type_=rty.Message("M")), + Variable("M", type_=ty.Message("M")), "Y", - type_=rty.Sequence("S", INT_TY), + type_=ty.Sequence("S", INT_TY), ), "Z": Add(Variable("X", type_=INT_TY), Variable("Y", type_=INT_TY), Number(1)), }, @@ -641,14 +640,14 @@ def test_to_ir_message_aggregate( # type: ignore[misc] id_generator(), ) == ir.ComplexExpr( [ - ir.VarDecl("T_0", rty.BASE_INTEGER), - ir.Assign("T_0", ir.Add(ir.IntVar("Y", INT_TY), ir.IntVal(1)), rty.BASE_INTEGER), + ir.VarDecl("T_0", ty.BASE_INTEGER), + ir.Assign("T_0", ir.Add(ir.IntVar("Y", INT_TY), ir.IntVal(1)), ty.BASE_INTEGER), ], ir_agg( "X", { ID("Y"): ir.ObjFieldAccess("M", ID("Y"), MSG_TY), - ID("Z"): ir.Add(ir.IntVar("X", INT_TY), ir.IntVar("T_0", rty.BASE_INTEGER)), + ID("Z"): ir.Add(ir.IntVar("X", INT_TY), ir.IntVar("T_0", ty.BASE_INTEGER)), }, MSG_TY, ), diff --git a/tests/unit/expr_proof_test.py b/tests/unit/expr_proof_test.py index 999143f1b..2837431fe 100644 --- a/tests/unit/expr_proof_test.py +++ b/tests/unit/expr_proof_test.py @@ -3,7 +3,7 @@ import pytest import z3 -from rflx import typing_ as rty +from rflx import ty from rflx.expr import ( FALSE, TRUE, @@ -182,7 +182,7 @@ def test_to_z3_attribute(attribute: Expr, z3name: str) -> None: def test_to_z3_attribute_error() -> None: with pytest.raises(Z3TypeError): - _to_z3(First(Call("X", rty.BASE_INTEGER))) + _to_z3(First(Call("X", ty.BASE_INTEGER))) def test_to_z3_aggregate() -> None: @@ -192,7 +192,7 @@ def test_to_z3_aggregate() -> None: @pytest.mark.parametrize("relation", [Less, LessEqual, GreaterEqual, Greater]) def test_to_z3_relation_error(relation: Callable[[Expr, Expr], Expr]) -> None: with pytest.raises(Z3TypeError): - _to_z3(relation(Variable("X", type_=rty.BOOLEAN), Number(1))) + _to_z3(relation(Variable("X", type_=ty.BOOLEAN), Number(1))) def test_to_z3_less() -> None: diff --git a/tests/unit/expr_test.py b/tests/unit/expr_test.py index 7258108de..ec72b965e 100644 --- a/tests/unit/expr_test.py +++ b/tests/unit/expr_test.py @@ -5,7 +5,7 @@ import pytest -from rflx import typing_ as rty +from rflx import ty from rflx.expr import ( FALSE, TRUE, @@ -67,7 +67,7 @@ ) from rflx.identifier import ID, StrID from rflx.model import Integer -from rflx.rapidflux import Location, RecordFluxError, ty +from rflx.rapidflux import Location, RecordFluxError from tests.data import models from tests.utils import assert_equal, check_regex @@ -75,13 +75,13 @@ TINY_INT = Integer("P::Tiny", Number(1), Number(3), Number(8), location=Location((1, 2))) INT = Integer("P::Int", Number(1), Number(100), Number(8), location=Location((3, 2))) -INT_TY = rty.Integer("I", ty.Bounds(10, 100)) -ENUM_TY = rty.Enumeration("E", [ID("E1"), ID("E2")]) -MSG_TY = rty.Message("M") -SEQ_TY = rty.Sequence("S", rty.Message("M")) +INT_TY = ty.Integer("I", ty.Bounds(10, 100)) +ENUM_TY = ty.Enumeration("E", [ID("E1"), ID("E2")]) +MSG_TY = ty.Message("M") +SEQ_TY = ty.Sequence("S", ty.Message("M")) -def assert_type(expr: Expr, type_: rty.Type) -> None: +def assert_type(expr: Expr, type_: ty.Type) -> None: expr.check_type(type_).propagate() assert expr.type_ == type_ @@ -89,13 +89,13 @@ def assert_type(expr: Expr, type_: rty.Type) -> None: def assert_type_error(expr: Expr, regex: str) -> None: check_regex(regex) with pytest.raises(RecordFluxError, match=regex): - expr.check_type(rty.Any()).propagate() + expr.check_type(ty.Any()).propagate() def test_true_type() -> None: assert_type( TRUE, - rty.BOOLEAN, + ty.BOOLEAN, ) @@ -110,7 +110,7 @@ def test_true_variables() -> None: def test_false_type() -> None: assert_type( FALSE, - rty.BOOLEAN, + ty.BOOLEAN, ) @@ -124,8 +124,8 @@ def test_false_variables() -> None: def test_not_type() -> None: assert_type( - Not(Variable("X", type_=rty.BOOLEAN)), - rty.BOOLEAN, + Not(Variable("X", type_=ty.BOOLEAN)), + ty.BOOLEAN, ) @@ -389,8 +389,8 @@ def test_bool_expr_str() -> None: @pytest.mark.parametrize("operation", [And, Or]) def test_bool_expr_type(operation: Callable[[Expr, Expr], Expr]) -> None: assert_type( - operation(Variable("X", type_=rty.BOOLEAN), Variable("Y", type_=rty.BOOLEAN)), - rty.BOOLEAN, + operation(Variable("X", type_=ty.BOOLEAN), Variable("Y", type_=ty.BOOLEAN)), + ty.BOOLEAN, ) @@ -398,7 +398,7 @@ def test_bool_expr_type(operation: Callable[[Expr, Expr], Expr]) -> None: def test_bool_expr_type_error(operation: Callable[[Expr, Expr], Expr]) -> None: assert_type_error( operation( - Variable("X", type_=rty.Integer("A", ty.Bounds(0, 100)), location=Location((10, 20))), + Variable("X", type_=ty.Integer("A", ty.Bounds(0, 100)), location=Location((10, 20))), Number(1, location=Location((10, 30))), ), r'^:10:20: error: expected enumeration type "__BUILTINS__::Boolean"\n' @@ -498,7 +498,7 @@ def test_undefined_str() -> None: def test_number_type() -> None: assert_type( Number(1), - rty.UniversalInteger(ty.Bounds(1, 1)), + ty.UniversalInteger(ty.Bounds(1, 1)), ) @@ -617,7 +617,7 @@ def test_number_hashable() -> None: def test_math_expr_type(operation: Callable[[Expr, Expr], Expr]) -> None: assert_type( operation(Variable("X", type_=INT_TY), Variable("Y", type_=INT_TY)), - rty.BASE_INTEGER, + ty.BASE_INTEGER, ) @@ -625,8 +625,8 @@ def test_math_expr_type(operation: Callable[[Expr, Expr], Expr]) -> None: def test_math_expr_type_error(operation: Callable[[Expr, Expr], Expr]) -> None: assert_type_error( operation( - Variable("X", type_=rty.BOOLEAN, location=Location((10, 20))), - Variable("True", type_=rty.BOOLEAN, location=Location((10, 30))), + Variable("X", type_=ty.BOOLEAN, location=Location((10, 20))), + Variable("True", type_=ty.BOOLEAN, location=Location((10, 30))), ), r"^:10:20: error: expected integer type\n" r':10:20: error: found enumeration type "__BUILTINS__::Boolean"\n' @@ -649,7 +649,7 @@ def test_neg_type() -> None: def test_neg_type_error() -> None: assert_type_error( - Neg(Variable("X", type_=rty.BOOLEAN, location=Location((10, 20)))), + Neg(Variable("X", type_=ty.BOOLEAN, location=Location((10, 20)))), r"^:10:20: error: expected integer type\n" r':10:20: error: found enumeration type "__BUILTINS__::Boolean"$', ) @@ -693,8 +693,8 @@ def test_neg_simplified() -> None: def test_add_str() -> None: assert str(Add(Variable("X"), Number(1))) == "X + 1" assert str(-Add(Variable("X"), Number(1))) == "-X - 1" - assert str(Add(Number(1), Call("Test", rty.BASE_INTEGER, []))) == "1 + Test" - assert str(Add(Number(1), -Call("Test", rty.BASE_INTEGER, []))) == "1 - Test" + assert str(Add(Number(1), Call("Test", ty.BASE_INTEGER, []))) == "1 + Test" + assert str(Add(Number(1), -Call("Test", ty.BASE_INTEGER, []))) == "1 - Test" assert str(Add()) == "0" @@ -878,8 +878,8 @@ def test_variable_invalid_name() -> None: def test_variable_type() -> None: assert_type( - Variable("X", type_=rty.BOOLEAN), - rty.BOOLEAN, + Variable("X", type_=ty.BOOLEAN), + ty.BOOLEAN, ) assert_type( Variable("X", type_=INT_TY), @@ -937,27 +937,27 @@ def test_attribute() -> None: @pytest.mark.parametrize( ("attribute", "expr", "expected"), [ - (Size, Variable("X", type_=INT_TY), rty.UNIVERSAL_INTEGER), - (Length, Variable("X", type_=INT_TY), rty.UNIVERSAL_INTEGER), - (First, Variable("X", type_=INT_TY), rty.UNIVERSAL_INTEGER), - (Last, Variable("X", type_=INT_TY), rty.UNIVERSAL_INTEGER), - (ValidChecksum, Variable("X", type_=INT_TY), rty.BOOLEAN), - (Valid, Variable("X", type_=rty.Message("A")), rty.BOOLEAN), + (Size, Variable("X", type_=INT_TY), ty.UNIVERSAL_INTEGER), + (Length, Variable("X", type_=INT_TY), ty.UNIVERSAL_INTEGER), + (First, Variable("X", type_=INT_TY), ty.UNIVERSAL_INTEGER), + (Last, Variable("X", type_=INT_TY), ty.UNIVERSAL_INTEGER), + (ValidChecksum, Variable("X", type_=INT_TY), ty.BOOLEAN), + (Valid, Variable("X", type_=ty.Message("A")), ty.BOOLEAN), ( Present, Selected( - Variable("X", type_=rty.Message("M", {("F",)}, {ID("F"): INT_TY})), + Variable("X", type_=ty.Message("M", {("F",)}, {ID("F"): INT_TY})), "F", ), - rty.BOOLEAN, + ty.BOOLEAN, ), - (Head, Variable("X", type_=rty.Sequence("A", INT_TY)), INT_TY), - (Opaque, Variable("X", type_=rty.Message("A")), rty.OPAQUE), + (Head, Variable("X", type_=ty.Sequence("A", INT_TY)), INT_TY), + (Opaque, Variable("X", type_=ty.Message("A")), ty.OPAQUE), ( Head, Comprehension( "X", - Variable("Y", type_=rty.Sequence("A", INT_TY)), + Variable("Y", type_=ty.Sequence("A", INT_TY)), Variable("X", type_=INT_TY), TRUE, location=Location((10, 30)), @@ -968,14 +968,14 @@ def test_attribute() -> None: Head, Variable( "Z", - type_=rty.Sequence("Universal::Options", rty.Message("Universal::Option")), + type_=ty.Sequence("Universal::Options", ty.Message("Universal::Option")), location=Location((10, 20)), ), - rty.Message("Universal::Option"), + ty.Message("Universal::Option"), ), ], ) -def test_attribute_type(attribute: Callable[[Expr], Expr], expr: Expr, expected: rty.Type) -> None: +def test_attribute_type(attribute: Callable[[Expr], Expr], expr: Expr, expected: ty.Type) -> None: assert_type( attribute(expr), expected, @@ -994,7 +994,7 @@ def test_attribute_type(attribute: Callable[[Expr], Expr], expr: Expr, expected: Opaque( Variable( "X", - type_=rty.Sequence("A", INT_TY), + type_=ty.Sequence("A", INT_TY), location=Location((10, 30)), ), ), @@ -1005,7 +1005,7 @@ def test_attribute_type(attribute: Callable[[Expr], Expr], expr: Expr, expected: Opaque( Call( "X", - rty.UNDEFINED, + ty.UNDEFINED, [Variable("Y", location=Location((10, 30)))], location=Location((10, 20)), ), @@ -1037,12 +1037,12 @@ def test_attribute_substituted() -> None: Number(-42), ) assert_equal( - First("X").substituted(lambda x: Call("Y", rty.BASE_INTEGER) if x == Variable("X") else x), - First(Call("Y", rty.BASE_INTEGER)), + First("X").substituted(lambda x: Call("Y", ty.BASE_INTEGER) if x == Variable("X") else x), + First(Call("Y", ty.BASE_INTEGER)), ) assert_equal( - -First("X").substituted(lambda x: Call("Y", rty.BASE_INTEGER) if x == Variable("X") else x), - -First(Call("Y", rty.BASE_INTEGER)), + -First("X").substituted(lambda x: Call("Y", ty.BASE_INTEGER) if x == Variable("X") else x), + -First(Call("Y", ty.BASE_INTEGER)), ) assert_equal( -First("X").substituted( @@ -1076,7 +1076,7 @@ def test_attribute_str() -> None: def test_attribute_variables() -> None: assert First("X").variables() == [Variable("X")] - assert First(Call("X", rty.BASE_INTEGER, [Variable("Y")])).variables() == [ + assert First(Call("X", ty.BASE_INTEGER, [Variable("Y")])).variables() == [ Variable("X"), Variable("Y"), ] @@ -1108,7 +1108,7 @@ def test_val_str() -> None: def test_aggregate_type() -> None: assert_type( Aggregate(Number(0), Number(1)), - rty.Aggregate(rty.UniversalInteger(ty.Bounds(0, 1))), + ty.Aggregate(ty.UniversalInteger(ty.Bounds(0, 1))), ) @@ -1157,7 +1157,7 @@ def test_aggregate_precedence() -> None: def test_relation_integer_type(relation: Callable[[Expr, Expr], Expr]) -> None: assert_type( relation(Variable("X", type_=INT_TY), Variable("Y", type_=INT_TY)), - rty.BOOLEAN, + ty.BOOLEAN, ) @@ -1169,7 +1169,7 @@ def test_relation_integer_type_error(relation: Callable[[Expr, Expr], Expr]) -> assert_type_error( relation( Variable("X", type_=INT_TY), - Variable("True", type_=rty.BOOLEAN, location=Location((10, 30))), + Variable("True", type_=ty.BOOLEAN, location=Location((10, 30))), ), rf"^:10:30: error: expected {integer_type}\n" r':10:30: error: found enumeration type "__BUILTINS__::Boolean"$', @@ -1181,9 +1181,9 @@ def test_relation_composite_type(relation: Callable[[Expr, Expr], Expr]) -> None assert_type( relation( Variable("X", type_=INT_TY), - Variable("Y", type_=rty.Sequence("A", INT_TY)), + Variable("Y", type_=ty.Sequence("A", INT_TY)), ), - rty.BOOLEAN, + ty.BOOLEAN, ) @@ -1192,7 +1192,7 @@ def test_relation_composite_type_error(relation: Callable[[Expr, Expr], Expr]) - assert_type_error( relation( Variable("X", type_=INT_TY, location=Location((10, 20))), - Variable("True", type_=rty.BOOLEAN, location=Location((10, 30))), + Variable("True", type_=ty.BOOLEAN, location=Location((10, 30))), ), r"^:10:30: error: expected aggregate" r' with element integer type "I" \(10 \.\. 100\)\n' @@ -1201,7 +1201,7 @@ def test_relation_composite_type_error(relation: Callable[[Expr, Expr], Expr]) - assert_type_error( relation( Variable("X", type_=INT_TY, location=Location((10, 20))), - Variable("Y", type_=rty.Sequence("A", rty.BOOLEAN), location=Location((10, 30))), + Variable("Y", type_=ty.Sequence("A", ty.BOOLEAN), location=Location((10, 30))), ), r"^:10:30: error: expected aggregate" r' with element integer type "I" \(10 \.\. 100\)\n' @@ -1461,15 +1461,15 @@ def test_if_expr_simplified() -> None: def test_value_range_type() -> None: assert_type( ValueRange(Number(1), Number(42)), - rty.Any(), + ty.Any(), ) def test_value_range_type_error() -> None: assert_type_error( ValueRange( - Variable("X", type_=rty.BOOLEAN, location=Location((10, 30))), - Variable("Y", type_=rty.Sequence("A", INT_TY), location=Location((10, 40))), + Variable("X", type_=ty.BOOLEAN, location=Location((10, 30))), + Variable("Y", type_=ty.Sequence("A", INT_TY), location=Location((10, 40))), location=Location((10, 20)), ), r"^" @@ -1514,10 +1514,10 @@ def test_quantified_expression_type(expr: Callable[[str, Expr, Expr], Expr]) -> assert_type( expr( "X", - Variable("Y", type_=rty.Sequence("A", INT_TY)), - Variable("Z", type_=rty.BOOLEAN), + Variable("Y", type_=ty.Sequence("A", INT_TY)), + Variable("Z", type_=ty.BOOLEAN), ), - rty.BOOLEAN, + ty.BOOLEAN, ) @@ -1526,8 +1526,8 @@ def test_quantified_expression_type(expr: Callable[[str, Expr, Expr], Expr]) -> ("iterable", "predicate", "match"), [ ( - Variable("Y", type_=rty.BOOLEAN, location=Location((10, 30))), - Variable("Z", type_=rty.Sequence("A", INT_TY), location=Location((10, 40))), + Variable("Y", type_=ty.BOOLEAN, location=Location((10, 30))), + Variable("Z", type_=ty.Sequence("A", INT_TY), location=Location((10, 40))), r"^:10:30: error: expected composite type\n" r':10:30: error: found enumeration type "__BUILTINS__::Boolean"\n' r':10:40: error: expected enumeration type "__BUILTINS__::Boolean"\n' @@ -1535,13 +1535,13 @@ def test_quantified_expression_type(expr: Callable[[str, Expr, Expr], Expr]) -> r' with element integer type "I" \(10 \.\. 100\)$', ), ( - Variable("Y", type_=rty.BOOLEAN, location=Location((10, 30))), + Variable("Y", type_=ty.BOOLEAN, location=Location((10, 30))), Equal(Variable("X"), Number(1)), r"^:10:30: error: expected composite type\n" r':10:30: error: found enumeration type "__BUILTINS__::Boolean"$', ), ( - Variable("Y", type_=rty.Sequence("A", rty.BOOLEAN)), + Variable("Y", type_=ty.Sequence("A", ty.BOOLEAN)), Equal(Variable("X"), Number(1, location=Location((10, 30)))), r'^:10:30: error: expected enumeration type "__BUILTINS__::Boolean"\n' r":10:30: error: found type universal integer \(1 .. 1\)$", @@ -1711,7 +1711,7 @@ def test_expr_substituted_pre() -> None: with pytest.raises(AssertionError): Selected(Variable("X"), "F").substituted(lambda x: x, mapping) # pragma: no branch with pytest.raises(AssertionError): - Call("Sub", rty.BASE_INTEGER).substituted(lambda x: x, mapping) # pragma: no branch + Call("Sub", ty.BASE_INTEGER).substituted(lambda x: x, mapping) # pragma: no branch with pytest.raises(AssertionError): ForAllOf("X", Variable("Y"), Variable("Z")).substituted( # pragma: no branch lambda x: x, @@ -1834,7 +1834,7 @@ def test_string_str() -> None: def test_selected_type() -> None: assert_type( - Selected(Variable("X", type_=rty.Message("M", {("F",)}, {ID("F"): INT_TY})), "F"), + Selected(Variable("X", type_=ty.Message("M", {("F",)}, {ID("F"): INT_TY})), "F"), INT_TY, ) @@ -1843,7 +1843,7 @@ def test_selected_type() -> None: ("expr", "match"), [ ( - Selected(Variable("X", type_=rty.BOOLEAN, location=Location((10, 20))), "Y"), + Selected(Variable("X", type_=ty.BOOLEAN, location=Location((10, 20))), "Y"), r"^:10:20: error: expected message type\n" r':10:20: error: found enumeration type "__BUILTINS__::Boolean"$', ), @@ -1851,7 +1851,7 @@ def test_selected_type() -> None: Selected( Variable( "X", - type_=rty.Message("M", {("F",)}, {ID("F"): INT_TY}), + type_=ty.Message("M", {("F",)}, {ID("F"): INT_TY}), ), "Y", location=Location((10, 20)), @@ -1862,7 +1862,7 @@ def test_selected_type() -> None: Selected( Variable( "X", - type_=rty.Message( + type_=ty.Message( "M", {("F1",), ("F2",)}, {ID("F1"): INT_TY, ID("F2"): INT_TY}, @@ -1919,11 +1919,11 @@ def test_call_type() -> None: assert_type( Call( "X", - rty.BOOLEAN, + ty.BOOLEAN, [Variable("Y", type_=INT_TY)], argument_types=[INT_TY], ), - rty.BOOLEAN, + ty.BOOLEAN, ) @@ -1931,7 +1931,7 @@ def test_call_type_error() -> None: assert_type_error( Call( "X", - rty.UNDEFINED, + ty.UNDEFINED, [Variable("Y", location=Location((10, 30)))], location=Location((10, 20)), ), @@ -1941,13 +1941,13 @@ def test_call_type_error() -> None: assert_type_error( Call( "X", - rty.BOOLEAN, + ty.BOOLEAN, [ Variable("Y", type_=INT_TY, location=Location((10, 30))), - Variable("Z", type_=rty.BOOLEAN, location=Location((10, 40))), + Variable("Z", type_=ty.BOOLEAN, location=Location((10, 40))), ], argument_types=[ - rty.BOOLEAN, + ty.BOOLEAN, INT_TY, ], ), @@ -1959,13 +1959,13 @@ def test_call_type_error() -> None: def test_call_variables() -> None: - result = Call("Sub", rty.BASE_INTEGER, [Variable("A"), Variable("B")]).variables() + result = Call("Sub", ty.BASE_INTEGER, [Variable("A"), Variable("B")]).variables() expected = [Variable("Sub"), Variable("A"), Variable("B")] assert result == expected def test_call_findall() -> None: - assert Call("X", rty.BASE_INTEGER, [Variable("Y"), Variable("Z")]).findall( + assert Call("X", ty.BASE_INTEGER, [Variable("Y"), Variable("Z")]).findall( lambda x: isinstance(x, Variable), ) == [ Variable("Y"), @@ -1974,22 +1974,22 @@ def test_call_findall() -> None: def test_call_str() -> None: - assert str(Call("Test", rty.BASE_INTEGER, [])) == "Test" + assert str(Call("Test", ty.BASE_INTEGER, [])) == "Test" def test_call_neg() -> None: - assert -Call("Test", rty.BASE_INTEGER, []) == Neg(Call("Test", rty.BASE_INTEGER, [])) + assert -Call("Test", ty.BASE_INTEGER, []) == Neg(Call("Test", ty.BASE_INTEGER, [])) def test_conversion_type() -> None: assert_type( Conversion( "X", - Selected(Variable("Y", type_=rty.Message("Y", {("Z",)}, {ID("Z"): rty.OPAQUE})), "Z"), - type_=rty.Message("X"), - argument_types=[rty.Message("Y")], + Selected(Variable("Y", type_=ty.Message("Y", {("Z",)}, {ID("Z"): ty.OPAQUE})), "Z"), + type_=ty.Message("X"), + argument_types=[ty.Message("Y")], ), - rty.Message("X"), + ty.Message("X"), ) @@ -2047,11 +2047,11 @@ def test_comprehension_type() -> None: assert_type( Comprehension( "X", - Variable("Y", type_=rty.Sequence("A", INT_TY)), + Variable("Y", type_=ty.Sequence("A", INT_TY)), Add(Variable("X"), Variable("Z", type_=INT_TY)), TRUE, ), - rty.Aggregate(rty.BASE_INTEGER), + ty.Aggregate(ty.BASE_INTEGER), ) assert_type( Comprehension( @@ -2059,10 +2059,10 @@ def test_comprehension_type() -> None: Selected( Variable( "Y", - type_=rty.Message( + type_=ty.Message( "M", {("F",)}, - {ID("F"): rty.Sequence("A", INT_TY)}, + {ID("F"): ty.Sequence("A", INT_TY)}, ), ), "F", @@ -2070,7 +2070,7 @@ def test_comprehension_type() -> None: Variable("X"), Equal(Variable("X"), Number(1)), ), - rty.Aggregate(INT_TY), + ty.Aggregate(INT_TY), ) @@ -2146,8 +2146,8 @@ def test_comprehension_str() -> None: ("field_values", "type_"), [ ( - {"X": Variable("A", type_=INT_TY), "Y": Variable("B", type_=rty.BOOLEAN)}, - rty.Message( + {"X": Variable("A", type_=INT_TY), "Y": Variable("B", type_=ty.BOOLEAN)}, + ty.Message( "M", { ("X",), @@ -2156,30 +2156,30 @@ def test_comprehension_str() -> None: }, { ID("X"): INT_TY, - ID("Y"): rty.BOOLEAN, + ID("Y"): ty.BOOLEAN, }, ), ), ( - {"X": Variable("A", type_=rty.Message("I"))}, - rty.Message( + {"X": Variable("A", type_=ty.Message("I"))}, + ty.Message( "M", { ("X",), }, {}, { - ID("X"): rty.OPAQUE, + ID("X"): ty.OPAQUE, }, [ - rty.Refinement(ID("X"), rty.Message("I"), "P"), - rty.Refinement(ID("X"), rty.Message("J"), "P"), + ty.Refinement(ID("X"), ty.Message("I"), "P"), + ty.Refinement(ID("X"), ty.Message("J"), "P"), ], ), ), ], ) -def test_message_aggregate_type(field_values: Mapping[StrID, Expr], type_: rty.Type) -> None: +def test_message_aggregate_type(field_values: Mapping[StrID, Expr], type_: ty.Type) -> None: assert_type( MessageAggregate( "M", @@ -2198,14 +2198,14 @@ def test_message_aggregate_type(field_values: Mapping[StrID, Expr], type_: rty.T "X": Variable("A", location=Location((10, 30))), "Y": Variable("B", location=Location((10, 40))), }, - rty.Message( + ty.Message( "M", { ("X", "Y"), }, { ID("X"): INT_TY, - ID("Y"): rty.BOOLEAN, + ID("Y"): ty.BOOLEAN, }, ), r'^:10:30: error: undefined variable "A"\n' @@ -2214,34 +2214,34 @@ def test_message_aggregate_type(field_values: Mapping[StrID, Expr], type_: rty.T ( { "X": Variable("A", type_=INT_TY), - "Y": Variable("B", type_=rty.BOOLEAN), + "Y": Variable("B", type_=ty.BOOLEAN), ID("Z", location=Location((10, 50))): Variable("Z", type_=INT_TY), }, - rty.Message( + ty.Message( "M", { ("X", "Y"), }, { ID("X"): INT_TY, - ID("Y"): rty.BOOLEAN, + ID("Y"): ty.BOOLEAN, }, ), r'^:10:50: error: invalid field "Z" for message type "M"$', ), ( { - ID("Y", location=Location((10, 30))): Variable("B", type_=rty.BOOLEAN), + ID("Y", location=Location((10, 30))): Variable("B", type_=ty.BOOLEAN), "X": Variable("A", type_=INT_TY), }, - rty.Message( + ty.Message( "M", { ("X", "Y"), }, { ID("X"): INT_TY, - ID("Y"): rty.BOOLEAN, + ID("Y"): ty.BOOLEAN, }, ), r'^:10:30: error: invalid position for field "Y" of message type "M"$', @@ -2250,7 +2250,7 @@ def test_message_aggregate_type(field_values: Mapping[StrID, Expr], type_: rty.T { "X": Variable("A", type_=INT_TY), }, - rty.Message( + ty.Message( "M", { ("X", "Y"), @@ -2258,7 +2258,7 @@ def test_message_aggregate_type(field_values: Mapping[StrID, Expr], type_: rty.T }, { ID("X"): INT_TY, - ID("Y"): rty.BOOLEAN, + ID("Y"): ty.BOOLEAN, ID("Z"): INT_TY, }, ), @@ -2270,14 +2270,14 @@ def test_message_aggregate_type(field_values: Mapping[StrID, Expr], type_: rty.T "X": Variable("A", location=Location((10, 40))), "Y": Variable("B", location=Location((10, 30))), }, - rty.Message( + ty.Message( "M", { ("X", "Y", "Z"), }, { ID("X"): INT_TY, - ID("Y"): rty.BOOLEAN, + ID("Y"): ty.BOOLEAN, ID("Z"): INT_TY, }, ), @@ -2291,7 +2291,7 @@ def test_message_aggregate_type(field_values: Mapping[StrID, Expr], type_: rty.T "X": Variable("A", location=Location((10, 40))), "Y": Literal("B", location=Location((10, 30))), }, - rty.Undefined(), + ty.Undefined(), r'^:10:40: error: undefined variable "A"\n' r':10:30: error: undefined literal "B"\n' r':10:20: error: undefined message "X"$', @@ -2300,7 +2300,7 @@ def test_message_aggregate_type(field_values: Mapping[StrID, Expr], type_: rty.T ) def test_message_aggregate_type_error( field_values: Mapping[StrID, Expr], - type_: rty.Type, + type_: ty.Type, match: str, ) -> None: assert_type_error( @@ -2353,8 +2353,8 @@ def test_message_aggregate_variables() -> None: ("field_values", "type_"), [ ( - {"Y": Variable("A", type_=INT_TY), "Z": Variable("B", type_=rty.BOOLEAN)}, - rty.Message( + {"Y": Variable("A", type_=INT_TY), "Z": Variable("B", type_=ty.BOOLEAN)}, + ty.Message( "M", { ("X",), @@ -2364,30 +2364,30 @@ def test_message_aggregate_variables() -> None: {}, { ID("Y"): INT_TY, - ID("Z"): rty.BOOLEAN, + ID("Z"): ty.BOOLEAN, }, ), ), ( - {"Y": Variable("A", type_=rty.Message("I"))}, - rty.Message( + {"Y": Variable("A", type_=ty.Message("I"))}, + ty.Message( "M", { ("X", "Y", "Z"), }, {}, { - ID("Y"): rty.OPAQUE, + ID("Y"): ty.OPAQUE, }, [ - rty.Refinement(ID("Y"), rty.Message("I"), "P"), - rty.Refinement(ID("Y"), rty.Message("J"), "P"), + ty.Refinement(ID("Y"), ty.Message("I"), "P"), + ty.Refinement(ID("Y"), ty.Message("J"), "P"), ], ), ), ], ) -def test_delta_message_aggregate_type(field_values: Mapping[StrID, Expr], type_: rty.Type) -> None: +def test_delta_message_aggregate_type(field_values: Mapping[StrID, Expr], type_: ty.Type) -> None: assert_type( DeltaMessageAggregate( "M", @@ -2406,7 +2406,7 @@ def test_delta_message_aggregate_type(field_values: Mapping[StrID, Expr], type_: "X": Variable("A", location=Location((10, 30))), "Y": Variable("B", location=Location((10, 40))), }, - rty.Message( + ty.Message( "M", { ("X", "Y"), @@ -2414,7 +2414,7 @@ def test_delta_message_aggregate_type(field_values: Mapping[StrID, Expr], type_: {}, { ID("X"): INT_TY, - ID("Y"): rty.BOOLEAN, + ID("Y"): ty.BOOLEAN, }, ), r'^:10:30: error: undefined variable "A"\n' @@ -2423,10 +2423,10 @@ def test_delta_message_aggregate_type(field_values: Mapping[StrID, Expr], type_: ( { "X": Variable("A", type_=INT_TY), - "Y": Variable("B", type_=rty.BOOLEAN), + "Y": Variable("B", type_=ty.BOOLEAN), ID("Z", location=Location((10, 50))): Variable("Z", type_=INT_TY), }, - rty.Message( + ty.Message( "M", { ("X", "Y"), @@ -2434,17 +2434,17 @@ def test_delta_message_aggregate_type(field_values: Mapping[StrID, Expr], type_: {}, { ID("X"): INT_TY, - ID("Y"): rty.BOOLEAN, + ID("Y"): ty.BOOLEAN, }, ), r'^:10:50: error: invalid field "Z" for message type "M"$', ), ( { - "Y": Variable("B", type_=rty.BOOLEAN), + "Y": Variable("B", type_=ty.BOOLEAN), ID("X", location=Location((10, 30))): Variable("A", type_=INT_TY), }, - rty.Message( + ty.Message( "M", { ("X", "Y"), @@ -2452,7 +2452,7 @@ def test_delta_message_aggregate_type(field_values: Mapping[StrID, Expr], type_: {}, { ID("X"): INT_TY, - ID("Y"): rty.BOOLEAN, + ID("Y"): ty.BOOLEAN, }, ), r'^:10:30: error: invalid position for field "X" of message type "M"$', @@ -2462,7 +2462,7 @@ def test_delta_message_aggregate_type(field_values: Mapping[StrID, Expr], type_: "X": Variable("A", location=Location((10, 40))), "Y": Variable("B", location=Location((10, 30))), }, - rty.Undefined(), + ty.Undefined(), r'^:10:40: error: undefined variable "A"\n' r':10:30: error: undefined variable "B"\n' r':10:20: error: undefined message "T"$', @@ -2471,7 +2471,7 @@ def test_delta_message_aggregate_type(field_values: Mapping[StrID, Expr], type_: ) def test_delta_message_aggregate_type_error( field_values: Mapping[StrID, Expr], - type_: rty.Type, + type_: ty.Type, match: str, ) -> None: assert_type_error( @@ -2591,14 +2591,14 @@ def test_case_findall() -> None: def test_case_type() -> None: c1 = Variable("C", type_=models.enumeration().type_) - assert_type(CaseExpr(c1, [([ID("Zero"), ID("One")], TRUE), ([ID("Two")], FALSE)]), rty.BOOLEAN) + assert_type(CaseExpr(c1, [([ID("Zero"), ID("One")], TRUE), ([ID("Two")], FALSE)]), ty.BOOLEAN) assert_type( CaseExpr(c1, [([ID("Zero"), ID("One")], Number(1)), ([ID("Two")], Number(2))]), - rty.UniversalInteger(ty.Bounds(1, 2)), + ty.UniversalInteger(ty.Bounds(1, 2)), ) c2 = Variable("C", type_=TINY_INT.type_) - assert_type(CaseExpr(c2, [([Number(1), Number(2)], TRUE), ([Number(3)], FALSE)]), rty.BOOLEAN) + assert_type(CaseExpr(c2, [([Number(1), Number(2)], TRUE), ([Number(3)], FALSE)]), ty.BOOLEAN) assert_type_error( CaseExpr( @@ -2617,7 +2617,7 @@ def test_case_type() -> None: Opaque( Variable( ID("X", location=Location((1, 1))), - type_=rty.Message(ID("A", location=Location((1, 2)))), + type_=ty.Message(ID("A", location=Location((1, 2)))), location=Location((1, 3)), ), ), diff --git a/tests/unit/generator/allocator_test.py b/tests/unit/generator/allocator_test.py index 6566fc77c..1332c83f2 100644 --- a/tests/unit/generator/allocator_test.py +++ b/tests/unit/generator/allocator_test.py @@ -4,7 +4,7 @@ import pytest -from rflx import ir, typing_ as rty +from rflx import ir, ty from rflx.generator.allocator import AllocatorGenerator from rflx.identifier import ID, id_generator from rflx.integration import Integration, IntegrationFile, StateMachineIntegration @@ -41,7 +41,7 @@ [ ir.VarDecl( "Y", - rty.Message("T"), + ty.Message("T"), origin=ir.ConstructedOrigin("X : T", Location((2, 2))), ), ], @@ -52,7 +52,7 @@ declarations=[ ir.VarDecl( "X", - rty.Message("T"), + ty.Message("T"), origin=ir.ConstructedOrigin("X : T", Location((1, 1))), ), ], @@ -80,7 +80,7 @@ ), ], None, - [ir.Read("C", ir.ObjVar("X", rty.Message("T")))], + [ir.Read("C", ir.ObjVar("X", ty.Message("T")))], None, None, ), @@ -88,7 +88,7 @@ declarations=[ ir.VarDecl( "X", - rty.Message("T"), + ty.Message("T"), origin=ir.ConstructedOrigin("X : T", Location((1, 1))), ), ], @@ -133,7 +133,7 @@ def test_allocator( ), ], None, - [ir.Read("C", ir.ObjVar("X", rty.Message("T")))], + [ir.Read("C", ir.ObjVar("X", ty.Message("T")))], None, None, ), @@ -141,7 +141,7 @@ def test_allocator( declarations=[ ir.VarDecl( "X", - rty.Message("T"), + ty.Message("T"), origin=ir.ConstructedOrigin("X : T", Location((1, 1))), ), ], @@ -169,7 +169,7 @@ def test_allocator( ), ], None, - [ir.Read("C", ir.ObjVar("X", rty.Message("T")))], + [ir.Read("C", ir.ObjVar("X", ty.Message("T")))], None, None, ), @@ -177,12 +177,12 @@ def test_allocator( declarations=[ ir.VarDecl( "X", - rty.Message("T"), + ty.Message("T"), origin=ir.ConstructedOrigin("X : T", Location((1, 1))), ), ir.VarDecl( "Y", - rty.Message("T"), + ty.Message("T"), origin=ir.ConstructedOrigin("X : T", Location((2, 2))), ), ], diff --git a/tests/unit/generator/common_test.py b/tests/unit/generator/common_test.py index c0989cc5f..a27cd89e4 100644 --- a/tests/unit/generator/common_test.py +++ b/tests/unit/generator/common_test.py @@ -5,19 +5,19 @@ import pytest -from rflx import expr, typing_ as rty +from rflx import expr, ty from rflx.generator import common, const from rflx.identifier import ID from rflx.model import BUILTIN_TYPES, type_decl from rflx.model.message import FINAL, INITIAL, Field, Link, Message -from rflx.rapidflux import Location, ty +from rflx.rapidflux import Location from tests.data import models from tests.utils import assert_equal def test_type_translation() -> None: - assert (common.type_to_id(rty.BASE_INTEGER)) == const.TYPES_BASE_INT - assert (common.type_to_id(rty.Integer("P::mytype", ty.Bounds(1, 9)))) == ID("P::mytype") + assert (common.type_to_id(ty.BASE_INTEGER)) == const.TYPES_BASE_INT + assert (common.type_to_id(ty.Integer("P::mytype", ty.Bounds(1, 9)))) == ID("P::mytype") @pytest.mark.parametrize("embedded", [True, False]) @@ -47,7 +47,7 @@ def test_substitution_relation_aggregate( expr.ValueRange( expr.Call( const.TYPES_TO_INDEX, - rty.INDEX, + ty.INDEX, [ expr.Selected( expr.Indexed( @@ -60,7 +60,7 @@ def test_substitution_relation_aggregate( ), expr.Call( const.TYPES_TO_INDEX, - rty.INDEX, + ty.INDEX, [ expr.Selected( expr.Indexed( @@ -78,7 +78,7 @@ def test_substitution_relation_aggregate( else: equal_call = expr.Call( "Equal", - rty.BOOLEAN, + ty.BOOLEAN, [ expr.Variable("Ctx"), expr.Variable("F_Value"), @@ -98,11 +98,11 @@ def test_substitution_relation_aggregate( [ ( (expr.Variable("Length"), expr.Number(1)), - (expr.Call("Get_Length", rty.BASE_INTEGER, [expr.Variable("Ctx")]), expr.Number(1)), + (expr.Call("Get_Length", ty.BASE_INTEGER, [expr.Variable("Ctx")]), expr.Number(1)), ), ( (expr.Number(1), expr.Variable("Length")), - (expr.Number(1), expr.Call("Get_Length", rty.BASE_INTEGER, [expr.Variable("Ctx")])), + (expr.Number(1), expr.Call("Get_Length", ty.BASE_INTEGER, [expr.Variable("Ctx")])), ), ((expr.Number(1), expr.Variable("Unknown")), (expr.Number(1), expr.Variable("Unknown"))), ], @@ -172,13 +172,13 @@ def test_param_enumeration_condition() -> None: expr.Equal( expr.Call( "RFLX_Types::Base_Integer", - rty.BASE_INTEGER, - [expr.Call("To_Base_Integer", rty.BASE_INTEGER, [expr.Variable("Param")])], + ty.BASE_INTEGER, + [expr.Call("To_Base_Integer", ty.BASE_INTEGER, [expr.Variable("Param")])], ), expr.Call( "RFLX_Types::Base_Integer", - rty.BASE_INTEGER, - [expr.Call("To_Base_Integer", rty.BASE_INTEGER, [expr.Literal("E1")])], + ty.BASE_INTEGER, + [expr.Call("To_Base_Integer", ty.BASE_INTEGER, [expr.Literal("E1")])], ), ), ) diff --git a/tests/unit/generator/state_machine_test.py b/tests/unit/generator/state_machine_test.py index 730bbc882..7a54fb572 100644 --- a/tests/unit/generator/state_machine_test.py +++ b/tests/unit/generator/state_machine_test.py @@ -9,7 +9,7 @@ import z3 from attr import define -from rflx import ada, ir, typing_ as rty +from rflx import ada, ir, ty from rflx.error import FatalError from rflx.generator import const from rflx.generator.allocator import AllocatorGenerator @@ -22,12 +22,12 @@ ) from rflx.identifier import ID, id_generator from rflx.integration import Integration -from rflx.rapidflux import Location, RecordFluxError, ty +from rflx.rapidflux import Location, RecordFluxError from tests.data import models -INT_TY = rty.Integer("I", ty.Bounds(1, 100)) -MSG_TY = rty.Message(ID("M", Location((1, 1)))) -SEQ_TY = rty.Sequence("S", rty.Message(ID("M", Location((1, 1))))) +INT_TY = ty.Integer("I", ty.Bounds(1, 100)) +MSG_TY = ty.Message(ID("M", Location((1, 1)))) +SEQ_TY = ty.Sequence("S", ty.Message(ID("M", Location((1, 1))))) @lru_cache @@ -56,7 +56,7 @@ def dummy_state_machine() -> ir.StateMachine: ("parameter", "expected"), [ ( - ir.FuncDecl("F", [], "T", type_=rty.BOOLEAN, location=None), + ir.FuncDecl("F", [], "T", type_=ty.BOOLEAN, location=None), [ ada.SubprogramDeclaration( specification=ada.ProcedureSpecification( @@ -73,18 +73,18 @@ def dummy_state_machine() -> ir.StateMachine: ir.FuncDecl( "F", [ - ir.Argument("P1", "Boolean", type_=rty.BOOLEAN), - ir.Argument("P2", "T2", type_=rty.OPAQUE), + ir.Argument("P1", "Boolean", type_=ty.BOOLEAN), + ir.Argument("P2", "T2", type_=ty.OPAQUE), ir.Argument( "P3", "T3", - type_=rty.Enumeration("T4", [ID("E1")], always_valid=True), + type_=ty.Enumeration("T4", [ID("E1")], always_valid=True), ), ir.Argument("P4", "T4", type_=INT_TY), - ir.Argument("P5", "T5", type_=rty.Message("T5", is_definite=True)), + ir.Argument("P5", "T5", type_=ty.Message("T5", is_definite=True)), ], "T", - type_=rty.Message("T", is_definite=True), + type_=ty.Message("T", is_definite=True), location=None, ), [ @@ -123,7 +123,7 @@ def __str__(self) -> str: raise NotImplementedError @property - def type_(self) -> rty.Type: + def type_(self) -> ty.Type: raise NotImplementedError @@ -159,7 +159,7 @@ def test_state_machine_verify_formal_parameters( "F", [], "T", - rty.Undefined(), + ty.Undefined(), Location((10, 20)), ), FatalError, @@ -170,7 +170,7 @@ def test_state_machine_verify_formal_parameters( "F", [], "T", - rty.OPAQUE, + ty.OPAQUE, Location((10, 20)), ), FatalError, @@ -181,7 +181,7 @@ def test_state_machine_verify_formal_parameters( "F", [], "T", - rty.Sequence("A", INT_TY), + ty.Sequence("A", INT_TY), Location((10, 20)), ), RecordFluxError, @@ -192,7 +192,7 @@ def test_state_machine_verify_formal_parameters( "F", [], "T", - rty.Message("A", is_definite=False), + ty.Message("A", is_definite=False), Location((10, 20)), ), FatalError, @@ -203,10 +203,10 @@ def test_state_machine_verify_formal_parameters( "F", [], "T", - rty.Message( + ty.Message( "M", {("F",)}, - {ID("F"): rty.Sequence("A", INT_TY)}, + {ID("F"): ty.Sequence("A", INT_TY)}, is_definite=True, ), Location((10, 20)), @@ -217,9 +217,9 @@ def test_state_machine_verify_formal_parameters( ( ir.FuncDecl( "F", - [ir.Argument("P", "T", rty.Sequence("A", INT_TY))], + [ir.Argument("P", "T", ty.Sequence("A", INT_TY))], "T", - rty.BOOLEAN, + ty.BOOLEAN, Location((10, 20)), ), RecordFluxError, @@ -245,7 +245,7 @@ def test_state_machine_create_functions_error( ("declaration", "state_machine_global", "expected"), [ ( - ir.VarDecl("X", rty.BOOLEAN), + ir.VarDecl("X", ty.BOOLEAN), False, EvaluatedDeclaration(global_declarations=[ada.ObjectDeclaration("X", "Boolean")]), ), @@ -276,7 +276,7 @@ def test_state_machine_create_functions_error( ( ir.VarDecl( "X", - rty.Message("T"), + ty.Message("T"), origin=ir.ConstructedOrigin("X : T", Location((1, 1))), ), False, @@ -365,7 +365,7 @@ def test_state_machine_create_functions_error( ( ir.VarDecl( "X", - rty.Message("T"), + ty.Message("T"), origin=ir.ConstructedOrigin("X : T", Location((1, 1))), ), True, @@ -531,7 +531,7 @@ class EvaluatedDeclarationStr: ), ), ( - rty.Message("T"), + ty.Message("T"), None, False, False, @@ -559,7 +559,7 @@ class EvaluatedDeclarationStr: ), ), ( - rty.Message("T"), + ty.Message("T"), None, True, False, @@ -587,7 +587,7 @@ class EvaluatedDeclarationStr: ), ), ( - rty.Message("T"), + ty.Message("T"), None, False, True, @@ -615,7 +615,7 @@ class EvaluatedDeclarationStr: ), ), ( - rty.OPAQUE, + ty.OPAQUE, ir.ComplexExpr([], ir.Agg([])), False, False, @@ -626,7 +626,7 @@ class EvaluatedDeclarationStr: ), ), ( - rty.OPAQUE, + ty.OPAQUE, ir.ComplexExpr([], ir.Agg([])), True, False, @@ -637,7 +637,7 @@ class EvaluatedDeclarationStr: ), ), ( - rty.OPAQUE, + ty.OPAQUE, ir.ComplexExpr([], ir.Agg([ir.IntVal(1)])), False, False, @@ -649,7 +649,7 @@ class EvaluatedDeclarationStr: ), ), ( - rty.OPAQUE, + ty.OPAQUE, ir.ComplexExpr([], ir.Agg([ir.IntVal(1)])), True, False, @@ -662,7 +662,7 @@ class EvaluatedDeclarationStr: ), ), ( - rty.OPAQUE, + ty.OPAQUE, ir.ComplexExpr([], ir.Agg([ir.IntVal(1), ir.IntVal(2)])), False, False, @@ -674,7 +674,7 @@ class EvaluatedDeclarationStr: ), ), ( - rty.OPAQUE, + ty.OPAQUE, ir.ComplexExpr([], ir.Agg([ir.IntVal(1), ir.IntVal(2)])), False, True, @@ -686,7 +686,7 @@ class EvaluatedDeclarationStr: ), ), ( - rty.OPAQUE, + ty.OPAQUE, ir.ComplexExpr([], ir.Agg([ir.IntVal(1), ir.IntVal(2)])), True, False, @@ -699,7 +699,7 @@ class EvaluatedDeclarationStr: ), ), ( - rty.OPAQUE, + ty.OPAQUE, None, False, False, @@ -708,7 +708,7 @@ class EvaluatedDeclarationStr: ), ), ( - rty.OPAQUE, + ty.OPAQUE, None, False, True, @@ -717,7 +717,7 @@ class EvaluatedDeclarationStr: ), ), ( - rty.OPAQUE, + ty.OPAQUE, None, True, False, @@ -728,7 +728,7 @@ class EvaluatedDeclarationStr: ], ) def test_state_machine_declare( - type_: rty.Type, + type_: ty.Type, expression: ir.ComplexExpr | None, constant: bool, state_machine_global: bool, @@ -782,7 +782,7 @@ def test_state_machine_declare( r"initialization using function call not yet supported", ), ( - rty.OPAQUE, + ty.OPAQUE, ir.ComplexExpr( [ir.Assign("X", ir.IntVal(0), INT_TY)], ir.IntVar( @@ -795,12 +795,12 @@ def test_state_machine_declare( r"initialization not yet supported", ), ( - rty.Message("T"), + ty.Message("T"), ir.ComplexExpr( [], ir.EnumLit( "True", - rty.BOOLEAN, + ty.BOOLEAN, origin=ir.ConstructedOrigin("True", Location((10, 20))), ), ), @@ -821,7 +821,7 @@ def test_state_machine_declare( r"initialization with complex expression not yet supported", ), ( - rty.Undefined(), + ty.Undefined(), None, FatalError, r"unexpected variable declaration for undefined type", @@ -829,7 +829,7 @@ def test_state_machine_declare( ], ) def test_state_machine_declare_error( - type_: rty.Type, + type_: ty.Type, expression: ir.ComplexExpr | None, error_type: type[RecordFluxError], error_msg: str, @@ -882,20 +882,20 @@ def _update_str(self) -> None: { ID("Message_Type"): ir.EnumLit( ID("Universal::MT_Data"), - rty.Enumeration("Universal::Message_Type", [ID("Universal::MT_Data")]), + ty.Enumeration("Universal::Message_Type", [ID("Universal::MT_Data")]), ), ID("Length"): ir.IntVal(0), ID("Data"): ir.Agg([]), }, - type_=rty.Message( + type_=ty.Message( "Universal::Message", field_types={ - ID("Message_Type"): rty.Enumeration( + ID("Message_Type"): ty.Enumeration( "Universal::Message_Type", [ID("Universal::MT_Data")], ), - ID("Length"): rty.Integer("Universal::Length", ty.Bounds(0, 100)), - ID("Data"): rty.OPAQUE, + ID("Length"): ty.Integer("Universal::Length", ty.Bounds(0, 100)), + ID("Data"): ty.OPAQUE, }, ), ), @@ -946,13 +946,13 @@ def _update_str(self) -> None: ir.BoolCall( "F", [ - ir.ObjVar("A", rty.Message("Universal::Message")), + ir.ObjVar("A", ty.Message("Universal::Message")), ], [ - rty.Message("Universal::Message"), + ty.Message("Universal::Message"), ], ), - rty.BOOLEAN, + ty.BOOLEAN, origin=ir.ConstructedOrigin("", Location((1, 1))), ), """\ @@ -971,14 +971,14 @@ def _update_str(self) -> None: ir.ObjCall( "F", [ - ir.ObjVar("A", rty.Message("Universal::Message")), + ir.ObjVar("A", ty.Message("Universal::Message")), ], [ - rty.Message("Universal::Message"), + ty.Message("Universal::Message"), ], - rty.Message("Universal::Option"), + ty.Message("Universal::Option"), ), - rty.Message("Universal::Option"), + ty.Message("Universal::Option"), origin=ir.ConstructedOrigin("", Location((1, 1))), ), """\ @@ -1012,7 +1012,7 @@ def _update_str(self) -> None: ir.BoolVar("A"), ir.BoolVar("B"), ), - rty.BOOLEAN, + ty.BOOLEAN, origin=ir.ConstructedOrigin("", Location((1, 1))), ), "-- :1:1\nX := A\nand then B;", @@ -1021,7 +1021,7 @@ def _update_str(self) -> None: ir.Reset( "X", {}, - rty.Message("P::M"), + ty.Message("P::M"), origin=ir.ConstructedOrigin("", Location((1, 1))), ), "-- :1:1\nP.M.Reset (X_Ctx);", @@ -1030,7 +1030,7 @@ def _update_str(self) -> None: ir.Reset( "X", {}, - rty.Sequence("P::S", INT_TY), + ty.Sequence("P::S", INT_TY), origin=ir.ConstructedOrigin("", Location((1, 1))), ), "-- :1:1\nP.S.Reset (X_Ctx);", @@ -1038,7 +1038,7 @@ def _update_str(self) -> None: ( ir.Read( "X", - ir.ObjVar("Y", rty.Message("P::M")), + ir.ObjVar("Y", ty.Message("P::M")), origin=ir.ConstructedOrigin("", Location((1, 1))), ), "-- :1:1\nP.M.Verify_Message (Y_Ctx);", @@ -1046,7 +1046,7 @@ def _update_str(self) -> None: ( ir.Write( "X", - ir.ObjVar("Y", rty.Message("P::M")), + ir.ObjVar("Y", ty.Message("P::M")), origin=ir.ConstructedOrigin("", Location((1, 1))), ), "-- :1:1", @@ -1136,8 +1136,8 @@ def test_state_machine_state_action_error( @define class UnknownExpr(ir.Expr): @property - def type_(self) -> rty.Any: - return rty.Message("T") + def type_(self) -> ty.Any: + return ty.Message("T") @property def accessed_vars(self) -> list[ID]: @@ -1157,22 +1157,22 @@ def _update_str(self) -> None: ("type_", "expression", "error_type", "error_msg"), [ ( - rty.Sequence("A", INT_TY), + ty.Sequence("A", INT_TY), ir.ObjFieldAccess( "Z", "Z", - rty.Message("C", {("Z",)}, {}, {ID("Z"): rty.Sequence("A", INT_TY)}), + ty.Message("C", {("Z",)}, {}, {ID("Z"): ty.Sequence("A", INT_TY)}), origin=ir.ConstructedOrigin("", Location((10, 20))), ), RecordFluxError, r"copying of sequence not yet supported", ), ( - rty.Aggregate(INT_TY), + ty.Aggregate(INT_TY), ir.ObjFieldAccess( "Z", "Z", - rty.Message("B", {("Z",)}, {}, {ID("Z"): rty.Aggregate(INT_TY)}), + ty.Message("B", {("Z",)}, {}, {ID("Z"): ty.Aggregate(INT_TY)}), origin=ir.ConstructedOrigin("", Location((10, 20))), ), FatalError, @@ -1180,30 +1180,30 @@ def _update_str(self) -> None: r' in assignment of "X"', ), ( - rty.Message("A"), + ty.Message("A"), ir.MsgAgg( "Universal::Message", { ID("Message_Type"): ir.EnumLit( "Universal::MT_Data", - rty.Enumeration("Universal::Message_Type", [ID("Universal::MT_Data")]), + ty.Enumeration("Universal::Message_Type", [ID("Universal::MT_Data")]), ), ID("Length"): ir.IntVal(1), ID("Data"): ir.ObjVar( "Z", - rty.Message("Universal::Option"), + ty.Message("Universal::Option"), origin=ir.ConstructedOrigin("", Location((10, 20))), ), }, - rty.Message( + ty.Message( "Universal::Message", field_types={ - ID("Message_Type"): rty.Enumeration( + ID("Message_Type"): ty.Enumeration( "Universal::Message_Type", [ID("Universal::MT_Data")], ), - ID("Length"): rty.Integer("Universal::Length", ty.Bounds(0, 100)), - ID("Data"): rty.OPAQUE, + ID("Length"): ty.Integer("Universal::Length", ty.Bounds(0, 100)), + ID("Data"): ty.OPAQUE, }, ), ), @@ -1212,34 +1212,34 @@ def _update_str(self) -> None: r" not yet supported", ), ( - rty.Message("A"), + ty.Message("A"), ir.MsgAgg( "Universal::Message", { ID("Message_Type"): ir.EnumLit( "Universal::MT_Data", - rty.Enumeration("Universal::Message_Type", [ID("Universal::MT_Data")]), + ty.Enumeration("Universal::Message_Type", [ID("Universal::MT_Data")]), ), ID("Length"): ir.Last( "Z", - rty.Message("Universal::Option"), + ty.Message("Universal::Option"), origin=ir.ConstructedOrigin("", Location((10, 20))), ), ID("Data"): ir.ObjVar( "Z", - rty.Message("Universal::Option"), + ty.Message("Universal::Option"), origin=ir.ConstructedOrigin("", Location((10, 20))), ), }, - rty.Message( + ty.Message( "Universal::Message", field_types={ - ID("Message_Type"): rty.Enumeration( + ID("Message_Type"): ty.Enumeration( "Universal::Message_Type", [ID("Universal::MT_Data")], ), - ID("Length"): rty.Integer("Universal::Length", ty.Bounds(0, 100)), - ID("Data"): rty.OPAQUE, + ID("Length"): ty.Integer("Universal::Length", ty.Bounds(0, 100)), + ID("Data"): ty.OPAQUE, }, ), ), @@ -1248,30 +1248,30 @@ def _update_str(self) -> None: r" field not yet supported", ), ( - rty.Message("A"), + ty.Message("A"), ir.MsgAgg( "Universal::Message", { ID("Message_Type"): ir.EnumLit( "Universal::MT_Data", - rty.Enumeration("Universal::Message_Type", [ID("Universal::MT_Data")]), + ty.Enumeration("Universal::Message_Type", [ID("Universal::MT_Data")]), ), ID("Length"): ir.IntVal(1), ID("Data"): ir.Head( "Z", - rty.Sequence("Universal::Options", rty.Message("Universal::Option")), + ty.Sequence("Universal::Options", ty.Message("Universal::Option")), origin=ir.ConstructedOrigin("", Location((10, 20))), ), }, - rty.Message( + ty.Message( "Universal::Message", field_types={ - ID("Message_Type"): rty.Enumeration( + ID("Message_Type"): ty.Enumeration( "Universal::Message_Type", [ID("Universal::MT_Data")], ), - ID("Length"): rty.Integer("Universal::Length", ty.Bounds(0, 100)), - ID("Data"): rty.OPAQUE, + ID("Length"): ty.Integer("Universal::Length", ty.Bounds(0, 100)), + ID("Data"): ty.OPAQUE, }, ), ), @@ -1279,12 +1279,12 @@ def _update_str(self) -> None: r'Head with message type "Universal::Option" in expression not yet supported', ), ( - rty.Sequence("A", rty.Message("B")), + ty.Sequence("A", ty.Message("B")), ir.Comprehension( "E", ir.ObjVar( "L", - rty.Sequence("A", rty.Message("B")), + ty.Sequence("A", ty.Message("B")), ), ir.ComplexExpr( [], @@ -1292,7 +1292,7 @@ def _update_str(self) -> None: "F", [], [], - rty.Message("B"), + ty.Message("B"), origin=ir.ConstructedOrigin("", Location((10, 20))), ), ), @@ -1302,20 +1302,20 @@ def _update_str(self) -> None: "expressions other than variables not yet supported as selector for message types", ), ( - rty.Message("B"), + ty.Message("B"), ir.Find( "E", ir.ObjFieldAccess( "Y", "Z", - rty.Message("C", {("Z",)}, {}, {ID("Z"): rty.Sequence("D", rty.Message("B"))}), + ty.Message("C", {("Z",)}, {}, {ID("Z"): ty.Sequence("D", ty.Message("B"))}), ), ir.ComplexExpr( [], ir.ObjFieldAccess( "E", "Z", - rty.Message("B", {("Z",)}, {}, {ID("Z"): rty.Message("B")}), + ty.Message("B", {("Z",)}, {}, {ID("Z"): ty.Message("B")}), origin=ir.ConstructedOrigin("", Location((10, 20))), ), ), @@ -1327,7 +1327,7 @@ def _update_str(self) -> None: ir.IntFieldAccess( "E", "Z", - rty.Message("B", {("Z",)}, {}, {ID("Z"): INT_TY}), + ty.Message("B", {("Z",)}, {}, {ID("Z"): INT_TY}), ), INT_TY, origin=ir.ConstructedOrigin("", Location((20, 30))), @@ -1343,12 +1343,12 @@ def _update_str(self) -> None: "expressions other than variables not yet supported as selector for message types", ), ( - rty.Message("B"), + ty.Message("B"), ir.Find( "E", ir.ObjVar( "L", - rty.Sequence("A", rty.Message("B")), + ty.Sequence("A", ty.Message("B")), ), ir.ComplexExpr( [], @@ -1356,7 +1356,7 @@ def _update_str(self) -> None: "F", [], [], - rty.Message("B"), + ty.Message("B"), origin=ir.ConstructedOrigin("", Location((10, 20))), ), ), @@ -1368,7 +1368,7 @@ def _update_str(self) -> None: ir.IntFieldAccess( "E", "Z", - rty.Message("B", {("Z",)}, {}, {ID("Z"): INT_TY}), + ty.Message("B", {("Z",)}, {}, {ID("Z"): INT_TY}), ), INT_TY, origin=ir.ConstructedOrigin("", Location((20, 30))), @@ -1384,12 +1384,12 @@ def _update_str(self) -> None: "expressions other than variables not yet supported as selector for message types", ), ( - rty.Sequence("A", INT_TY), + ty.Sequence("A", INT_TY), ir.Comprehension( "E", ir.ObjVar( "L", - rty.Sequence("A", INT_TY), + ty.Sequence("A", INT_TY), origin=ir.ConstructedOrigin("", Location((10, 20))), ), ir.ComplexExpr([], ir.ObjVar("E", INT_TY)), @@ -1405,7 +1405,7 @@ def _update_str(self) -> None: "E", ir.ObjVar( "L", - rty.Sequence("A", INT_TY), + ty.Sequence("A", INT_TY), origin=ir.ConstructedOrigin("", Location((10, 20))), ), ir.ComplexExpr([], ir.ObjVar("E", INT_TY)), @@ -1419,7 +1419,7 @@ def _update_str(self) -> None: r" not yet supported", ), ( - rty.Sequence("A", INT_TY), + ty.Sequence("A", INT_TY), ir.BoolCall( "F", [ @@ -1437,37 +1437,37 @@ def _update_str(self) -> None: r'IntCall with integer type "I" \(1 \.\. 100\) as function argument not yet supported', ), ( - rty.Message("A"), + ty.Message("A"), ir.Conversion( - rty.Message("A"), - ir.ObjFieldAccess("Z", "Z", rty.Message("B", {("Z",)}, {}, {ID("Z"): rty.OPAQUE})), + ty.Message("A"), + ir.ObjFieldAccess("Z", "Z", ty.Message("B", {("Z",)}, {}, {ID("Z"): ty.OPAQUE})), origin=ir.ConstructedOrigin("", Location((10, 20))), ), FatalError, r'no refinement for field "Z" of message "B" leads to "A"', ), ( - rty.Message("A"), + ty.Message("A"), ir.ObjVar( "X", - rty.Message("A"), + ty.Message("A"), origin=ir.ConstructedOrigin("", Location((10, 20))), ), RecordFluxError, r'referencing assignment target "X" of type message in expression not yet supported', ), ( - rty.Message("A"), + ty.Message("A"), ir.ObjVar( "Y", - rty.Message("A"), + ty.Message("A"), origin=ir.ConstructedOrigin("", Location((10, 20))), ), RecordFluxError, r'ObjVar with message type "A" in assignment not yet supported', ), ( - rty.Message("A"), + ty.Message("A"), UnknownExpr( origin=ir.ConstructedOrigin("", Location((10, 20))), ), @@ -1475,10 +1475,10 @@ def _update_str(self) -> None: r'unexpected expression "UnknownExpr" with message type "T" in assignment', ), ( - rty.Message("A"), + ty.Message("A"), ir.Head( "X", - rty.Sequence("B", rty.OPAQUE), + ty.Sequence("B", ty.OPAQUE), origin=ir.ConstructedOrigin("", Location((10, 20))), ), FatalError, @@ -1488,7 +1488,7 @@ def _update_str(self) -> None: ], ) def test_state_machine_assign_error( - type_: rty.Type, + type_: ty.Type, expression: ir.Expr, error_type: type[RecordFluxError], error_msg: str, @@ -1535,10 +1535,10 @@ def test_state_machine_assign_error( "L", ir.ObjVar( "X", - rty.Message("A"), + ty.Message("A"), origin=ir.ConstructedOrigin("", Location((10, 20))), ), - rty.Sequence("B", rty.Message("A")), + ty.Sequence("B", ty.Message("A")), ), RecordFluxError, r'ObjVar with message type "A" in Append statement not yet supported', @@ -1547,7 +1547,7 @@ def test_state_machine_assign_error( ir.Append( "L", ir.ObjVar("X", MSG_TY, origin=ir.ConstructedOrigin("", Location((10, 20)))), - rty.Sequence("B", rty.Undefined()), + ty.Sequence("B", ty.Undefined()), ), FatalError, r"unexpected element type undefined type in Append statement", @@ -1562,7 +1562,7 @@ def test_state_machine_assign_error( INT_TY, origin=ir.ConstructedOrigin("", Location((10, 20))), ), - rty.Sequence("B", INT_TY), + ty.Sequence("B", INT_TY), ), RecordFluxError, r'IntCall with integer type "I" \(1 \.\. 100\) in Append statement not yet supported', @@ -1608,7 +1608,7 @@ def test_state_machine_append_error( "L", ir.EnumLit( "E", - rty.Enumeration("A", [ID("E")]), + ty.Enumeration("A", [ID("E")]), origin=ir.ConstructedOrigin("", Location((10, 20))), ), ), @@ -1644,7 +1644,7 @@ def test_state_machine_read_error( "L", ir.EnumLit( "E", - rty.Enumeration("A", [ID("E")]), + ty.Enumeration("A", [ID("E")]), origin=ir.ConstructedOrigin("", Location((10, 20))), ), ), @@ -1707,8 +1707,8 @@ def test_state_machine_write_error( ), ( ir.Equal( - ir.ObjVar("X", rty.Enumeration("P::E", [ID("P::E1")], always_valid=True)), - ir.EnumLit("P::E1", rty.Enumeration("P::E", [ID("P::E1")], always_valid=True)), + ir.ObjVar("X", ty.Enumeration("P::E", [ID("P::E1")], always_valid=True)), + ir.EnumLit("P::E1", ty.Enumeration("P::E", [ID("P::E1")], always_valid=True)), ), ada.Equal( ada.Variable("X"), @@ -1717,8 +1717,8 @@ def test_state_machine_write_error( ), ( ir.NotEqual( - ir.EnumLit("P::E1", rty.Enumeration("P::E", [ID("P::E1")], always_valid=True)), - ir.ObjVar("X", rty.Enumeration("P::E", [ID("P::E1")], always_valid=True)), + ir.EnumLit("P::E1", ty.Enumeration("P::E", [ID("P::E1")], always_valid=True)), + ir.ObjVar("X", ty.Enumeration("P::E", [ID("P::E1")], always_valid=True)), ), ada.NotEqual( ada.NamedAggregate(("Known", ada.Literal("True")), ("Enum", ada.Literal("P::E1"))), diff --git a/tests/unit/ir_test.py b/tests/unit/ir_test.py index 499466998..61184730b 100644 --- a/tests/unit/ir_test.py +++ b/tests/unit/ir_test.py @@ -5,16 +5,15 @@ import pytest import z3 -from rflx import expr, ir, typing_ as rty +from rflx import expr, ir, ty from rflx.error import Location from rflx.identifier import ID, id_generator -from rflx.rapidflux import ty PROOF_MANAGER = ir.ProofManager(2) -INT_TY = rty.Integer("I", ty.Bounds(10, 100)) -ENUM_TY = rty.Enumeration("E", [ID("E1"), ID("E2")]) -MSG_TY = rty.Message("M", {("E", "I")}, {}, {ID("E"): ENUM_TY, ID("I"): INT_TY}) -SEQ_TY = rty.Sequence("S", MSG_TY) +INT_TY = ty.Integer("I", ty.Bounds(10, 100)) +ENUM_TY = ty.Enumeration("E", [ID("E1"), ID("E2")]) +MSG_TY = ty.Message("M", {("E", "I")}, {}, {ID("E"): ENUM_TY, ID("I"): INT_TY}) +SEQ_TY = ty.Sequence("S", MSG_TY) def test_constructed_origin_location() -> None: @@ -85,7 +84,7 @@ def test_assign_to_z3_expr() -> None: assert ir.Assign("X", ir.IntVar("Y", INT_TY), INT_TY).to_z3_expr() == ( z3.Int("X") == z3.Int("Y") ) - assert ir.Assign("X", ir.BoolVar("Y"), rty.BOOLEAN).to_z3_expr() == ( + assert ir.Assign("X", ir.BoolVar("Y"), ty.BOOLEAN).to_z3_expr() == ( z3.Bool("X") == z3.Bool("Y") ) assert ir.Assign("X", ir.ObjVar("Y", MSG_TY), MSG_TY).to_z3_expr() == z3.BoolVal(val=True) @@ -93,7 +92,7 @@ def test_assign_to_z3_expr() -> None: def test_assign_target_var() -> None: assert ir.Assign("X", ir.IntVar("Y", INT_TY), INT_TY).target_var == ir.IntVar("X", INT_TY) - assert ir.Assign("X", ir.BoolVar("Y"), rty.BOOLEAN).target_var == ir.BoolVar("X") + assert ir.Assign("X", ir.BoolVar("Y"), ty.BOOLEAN).target_var == ir.BoolVar("X") assert ir.Assign("X", ir.ObjVar("Y", MSG_TY), MSG_TY).target_var == ir.ObjVar("X", MSG_TY) @@ -372,7 +371,7 @@ def test_bool_var_identifier() -> None: def test_bool_var_type() -> None: - assert ir.BoolVar("X").type_ == rty.BOOLEAN + assert ir.BoolVar("X").type_ == ty.BOOLEAN def test_bool_var_accessed_vars() -> None: @@ -500,20 +499,20 @@ def test_attr_str(attribute: ir.Attr, expected: str) -> None: @pytest.mark.parametrize( ("attribute", "expected"), [ - (ir.Size("X", MSG_TY), rty.BIT_LENGTH), - (ir.Size("X", ENUM_TY), rty.UNIVERSAL_INTEGER), - (ir.Length("X", MSG_TY), rty.UNIVERSAL_INTEGER), - (ir.First("X", MSG_TY), rty.UNIVERSAL_INTEGER), - (ir.Last("X", MSG_TY), rty.UNIVERSAL_INTEGER), - (ir.ValidChecksum("X", MSG_TY), rty.BOOLEAN), - (ir.Valid("X", MSG_TY), rty.BOOLEAN), - (ir.Present("X", MSG_TY), rty.BOOLEAN), - (ir.HasData("X", MSG_TY), rty.BOOLEAN), + (ir.Size("X", MSG_TY), ty.BIT_LENGTH), + (ir.Size("X", ENUM_TY), ty.UNIVERSAL_INTEGER), + (ir.Length("X", MSG_TY), ty.UNIVERSAL_INTEGER), + (ir.First("X", MSG_TY), ty.UNIVERSAL_INTEGER), + (ir.Last("X", MSG_TY), ty.UNIVERSAL_INTEGER), + (ir.ValidChecksum("X", MSG_TY), ty.BOOLEAN), + (ir.Valid("X", MSG_TY), ty.BOOLEAN), + (ir.Present("X", MSG_TY), ty.BOOLEAN), + (ir.HasData("X", MSG_TY), ty.BOOLEAN), (ir.Head("X", SEQ_TY), MSG_TY), - (ir.Opaque("X", MSG_TY), rty.OPAQUE), + (ir.Opaque("X", MSG_TY), ty.OPAQUE), ], ) -def test_attr_type(attribute: ir.Attr, expected: rty.Type) -> None: +def test_attr_type(attribute: ir.Attr, expected: ty.Type) -> None: assert attribute.type_ == expected @@ -597,13 +596,13 @@ def test_field_access_attr_str(attribute: ir.FieldAccessAttr, expected: str) -> @pytest.mark.parametrize( ("attribute", "expected"), [ - (ir.FieldValidNext("X", "Y", MSG_TY), rty.BOOLEAN), - (ir.FieldValid("X", "Y", MSG_TY), rty.BOOLEAN), - (ir.FieldPresent("X", "Y", MSG_TY), rty.BOOLEAN), - (ir.FieldSize("X", "Y", MSG_TY), rty.BIT_LENGTH), + (ir.FieldValidNext("X", "Y", MSG_TY), ty.BOOLEAN), + (ir.FieldValid("X", "Y", MSG_TY), ty.BOOLEAN), + (ir.FieldPresent("X", "Y", MSG_TY), ty.BOOLEAN), + (ir.FieldSize("X", "Y", MSG_TY), ty.BIT_LENGTH), ], ) -def test_field_access_attr_type(attribute: ir.FieldAccessAttr, expected: rty.Type) -> None: +def test_field_access_attr_type(attribute: ir.FieldAccessAttr, expected: ty.Type) -> None: assert attribute.type_ == expected @@ -815,15 +814,15 @@ def test_add_to_z3_expr() -> None: def test_add_preconditions() -> None: assert ir.Add( - ir.IntVar("X", rty.BASE_INTEGER), + ir.IntVar("X", ty.BASE_INTEGER), ir.IntVal(1), origin=expr.Add(expr.Variable("X"), expr.Number(1)), ).preconditions(id_generator()) == [ ir.Cond( ir.LessEqual(ir.IntVar("X", INT_TY), ir.IntVar("T_0", INT_TY)), [ - ir.VarDecl("T_0", rty.BASE_INTEGER), - ir.Assign("T_0", ir.Sub(ir.IntVal(ir.INT_MAX), ir.IntVal(1)), rty.BASE_INTEGER), + ir.VarDecl("T_0", ty.BASE_INTEGER), + ir.Assign("T_0", ir.Sub(ir.IntVal(ir.INT_MAX), ir.IntVal(1)), ty.BASE_INTEGER), ], ), ] @@ -858,8 +857,8 @@ def test_mul_preconditions() -> None: ir.Cond( ir.LessEqual(ir.IntVar("X", INT_TY), ir.IntVar("T_0", INT_TY)), [ - ir.VarDecl("T_0", rty.BASE_INTEGER), - ir.Assign("T_0", ir.Div(ir.IntVal(ir.INT_MAX), ir.IntVal(1)), rty.BASE_INTEGER), + ir.VarDecl("T_0", ty.BASE_INTEGER), + ir.Assign("T_0", ir.Div(ir.IntVal(ir.INT_MAX), ir.IntVal(1)), ty.BASE_INTEGER), ], ), ] @@ -894,8 +893,8 @@ def test_pow_preconditions() -> None: ir.Cond( ir.LessEqual(ir.IntVar("T_0", INT_TY), ir.IntVal(ir.INT_MAX)), [ - ir.VarDecl("T_0", rty.BASE_INTEGER), - ir.Assign("T_0", ir.Pow(ir.IntVar("X", INT_TY), ir.IntVal(1)), rty.BASE_INTEGER), + ir.VarDecl("T_0", ty.BASE_INTEGER), + ir.Assign("T_0", ir.Pow(ir.IntVar("X", INT_TY), ir.IntVal(1)), ty.BASE_INTEGER), ], ), ] @@ -1048,7 +1047,7 @@ def test_bool_call_str() -> None: def test_bool_call_type() -> None: - assert ir.BoolCall("X", [ir.BoolVar("Y")], [], ir.BoolVal(value=True)).type_ == rty.BOOLEAN + assert ir.BoolCall("X", [ir.BoolVar("Y")], [], ir.BoolVal(value=True)).type_ == ty.BOOLEAN def test_bool_call_substituted() -> None: @@ -1234,7 +1233,7 @@ def test_bool_if_expr_type() -> None: ir.ComplexBoolExpr([], ir.BoolVar("Y")), ir.ComplexBoolExpr([], ir.BoolVal(value=False)), ).type_ - == rty.BOOLEAN + == ty.BOOLEAN ) @@ -1349,7 +1348,7 @@ def test_comprehension_type() -> None: ir.ObjVar("Y", SEQ_TY), ir.ComplexExpr([], ir.ObjVar("X", MSG_TY)), ir.ComplexBoolExpr([], ir.BoolVal(value=True)), - ).type_ == rty.Aggregate(MSG_TY) + ).type_ == ty.Aggregate(MSG_TY) def test_comprehension_accessed_vars() -> None: @@ -1494,7 +1493,7 @@ def test_agg_str() -> None: def test_agg_type() -> None: - assert ir.Agg([ir.IntVar("X", INT_TY), ir.IntVal(10)]).type_ == rty.Aggregate(rty.BASE_INTEGER) + assert ir.Agg([ir.IntVar("X", INT_TY), ir.IntVal(10)]).type_ == ty.Aggregate(ty.BASE_INTEGER) def test_agg_accessed_vars() -> None: @@ -1529,7 +1528,7 @@ def test_str_str() -> None: def test_str_type() -> None: - assert ir.Str("X").type_ == rty.OPAQUE + assert ir.Str("X").type_ == ty.OPAQUE def test_str_accessed_vars() -> None: @@ -1744,24 +1743,24 @@ def test_add_required_checks() -> None: [ ir.Assign( "A", - ir.Add(ir.IntVar("Y", rty.BASE_INTEGER), ir.IntVal(1)), - rty.BASE_INTEGER, + ir.Add(ir.IntVar("Y", ty.BASE_INTEGER), ir.IntVal(1)), + ty.BASE_INTEGER, ), ir.Assign( "B", - ir.Div(ir.IntVar("A", rty.BASE_INTEGER), ir.IntVar("Z", rty.BASE_INTEGER)), - rty.BASE_INTEGER, + ir.Div(ir.IntVar("A", ty.BASE_INTEGER), ir.IntVar("Z", ty.BASE_INTEGER)), + ty.BASE_INTEGER, ), ir.Assign( "X", - ir.Sub(ir.IntVar("B", rty.BASE_INTEGER), ir.IntVal(1)), - rty.BASE_INTEGER, + ir.Sub(ir.IntVar("B", ty.BASE_INTEGER), ir.IntVal(1)), + ty.BASE_INTEGER, ), - ir.Assign("Z", ir.IntVal(0), rty.BASE_INTEGER), + ir.Assign("Z", ir.IntVal(0), ty.BASE_INTEGER), ir.Assign( "C", - ir.Add(ir.IntVar("Z", rty.BASE_INTEGER), ir.IntVal(1)), - rty.BASE_INTEGER, + ir.Add(ir.IntVar("Z", ty.BASE_INTEGER), ir.IntVal(1)), + ty.BASE_INTEGER, ), ], id_generator(), @@ -1790,23 +1789,23 @@ def test_add_conversions() -> None: [ ir.Assign( "A", - ir.Add(ir.IntVar("Y", rty.BASE_INTEGER), ir.IntVal(10)), + ir.Add(ir.IntVar("Y", ty.BASE_INTEGER), ir.IntVal(10)), INT_TY, ), ir.Assign( "B", - ir.Div(ir.IntVar("A", rty.BASE_INTEGER), ir.IntVar("Z", rty.BASE_INTEGER)), + ir.Div(ir.IntVar("A", ty.BASE_INTEGER), ir.IntVar("Z", ty.BASE_INTEGER)), INT_TY, ), ir.Assign( "X", - ir.Sub(ir.IntVar("B", rty.BASE_INTEGER), ir.IntVal(10)), + ir.Sub(ir.IntVar("B", ty.BASE_INTEGER), ir.IntVal(10)), INT_TY, ), ir.Assign("Z", ir.IntVal(10), INT_TY), ir.Assign( "C", - ir.Add(ir.IntVar("Z", rty.BASE_INTEGER), ir.IntVal(10)), + ir.Add(ir.IntVar("Z", ty.BASE_INTEGER), ir.IntVal(10)), INT_TY, ), ], diff --git a/tests/unit/model/message_test.py b/tests/unit/model/message_test.py index ec40efa41..36b83c94a 100644 --- a/tests/unit/model/message_test.py +++ b/tests/unit/model/message_test.py @@ -8,7 +8,7 @@ import pytest -from rflx import expr_proof, typing_ as rty +from rflx import expr_proof, ty from rflx.error import FatalError from rflx.expr import ( FALSE, @@ -6005,13 +6005,13 @@ def test_set_refinements() -> None: message.set_refinements([models.refinement()]) assert message.type_.refinements == [ - rty.Refinement( + ty.Refinement( "F", - rty.Message( + ty.Message( ID("P::M", Location((1, 1))), {("F",)}, {}, - {ID("F"): rty.OPAQUE}, + {ID("F"): ty.OPAQUE}, refinements=[], is_definite=True, ), @@ -6346,7 +6346,7 @@ def test_refinement_type_error_in_condition() -> None: message, Field("P"), message, - Equal(Variable("L"), Literal("True", type_=rty.BOOLEAN, location=Location((10, 20)))), + Equal(Variable("L"), Literal("True", type_=ty.BOOLEAN, location=Location((10, 20)))), ) diff --git a/tests/unit/model/state_machine_test.py b/tests/unit/model/state_machine_test.py index 6286db8ad..c1b7f8fe7 100644 --- a/tests/unit/model/state_machine_test.py +++ b/tests/unit/model/state_machine_test.py @@ -5,8 +5,7 @@ import pytest -import rflx.typing_ as rty -from rflx import expr +from rflx import expr, ty from rflx.identifier import ID from rflx.model import ( BOOLEAN, @@ -19,7 +18,7 @@ declaration as decl, statement as stmt, ) -from rflx.rapidflux import Location, RecordFluxError, ty +from rflx.rapidflux import Location, RecordFluxError from tests.data import models from tests.utils import assert_equal, assert_state_machine_model_error, get_test_model @@ -50,7 +49,7 @@ def test_str() -> None: condition=expr.And( expr.Equal(expr.Variable("Z"), expr.TRUE), expr.Equal( - expr.Call("G", rty.BOOLEAN, [expr.Variable("F")]), + expr.Call("G", ty.BOOLEAN, [expr.Variable("F")]), expr.TRUE, ), ), @@ -140,7 +139,7 @@ def test_identifier_normalization(monkeypatch: pytest.MonkeyPatch) -> None: condition=expr.And( expr.Equal(expr.Variable("z"), expr.TRUE), expr.Equal( - expr.Call("g", rty.BOOLEAN, [expr.Variable("f")]), + expr.Call("g", ty.BOOLEAN, [expr.Variable("f")]), expr.TRUE, ), ), @@ -269,7 +268,7 @@ def test_inconsistent_identifier_casing() -> None: expr.Equal( expr.Call( ID("g", location=Location((7, 7))), - rty.BOOLEAN, + ty.BOOLEAN, [expr.Variable(ID("f", location=Location((8, 8))))], ), expr.TRUE, @@ -860,7 +859,7 @@ def test_function_declaration_invalid_parameter_type( actions=[ stmt.VariableAssignment( "Result", - expr.Call("Function", rty.BOOLEAN, [expr.Variable("M")]), + expr.Call("Function", ty.BOOLEAN, [expr.Variable("M")]), ), ], ), @@ -915,7 +914,7 @@ def test_function_declaration_invalid_return_type( actions=[ stmt.VariableAssignment( "Result", - expr.Call("Function", rty.BOOLEAN, []), + expr.Call("Function", ty.BOOLEAN, []), ), ], ), @@ -948,7 +947,7 @@ def test_call_to_undeclared_function() -> None: "Global", expr.Call( "UndefSub", - rty.UNDEFINED, + ty.UNDEFINED, [expr.Variable("Global")], location=Location((10, 20)), ), @@ -984,7 +983,7 @@ def test_call_undeclared_variable() -> None: "Result", expr.Call( "SubProg", - rty.BOOLEAN, + ty.BOOLEAN, [expr.Variable("Undefined", location=Location((10, 20)))], ), ), @@ -1015,7 +1014,7 @@ def test_call_invalid_argument_type() -> None: "Result", expr.Call( "Function", - rty.BOOLEAN, + ty.BOOLEAN, [expr.Variable("Channel", location=Location((10, 20)))], ), ), @@ -1051,7 +1050,7 @@ def test_call_missing_arguments() -> None: "Result", expr.Call( "Function", - rty.BOOLEAN, + ty.BOOLEAN, location=Location((10, 20)), ), ), @@ -1081,7 +1080,7 @@ def test_call_too_many_arguments() -> None: "Result", expr.Call( "Function", - rty.BOOLEAN, + ty.BOOLEAN, [expr.TRUE, expr.Number(1)], location=Location((10, 20)), ), @@ -1290,7 +1289,7 @@ def test_undeclared_variable_in_function_call() -> None: "Result", expr.Call( "SubProg", - rty.BOOLEAN, + ty.BOOLEAN, [expr.Variable("Undefined", location=Location((10, 20)))], ), ), @@ -2030,7 +2029,7 @@ def test_undefined_type_in_parameters(parameters: abc.Sequence[decl.FormalDeclar transitions=[ Transition( target=ID("null"), - condition=expr.Equal(expr.Call("X", rty.BOOLEAN, [expr.TRUE]), expr.TRUE), + condition=expr.Equal(expr.Call("X", ty.BOOLEAN, [expr.TRUE]), expr.TRUE), ), Transition( target=ID("Start"), @@ -2468,15 +2467,15 @@ def test_resolving_of_function_calls() -> None: global_decl = state_machine.declarations[ID("Global")] assert isinstance(global_decl, decl.VariableDeclaration) - assert global_decl.expression == expr.Call("Func", rty.BOOLEAN) + assert global_decl.expression == expr.Call("Func", ty.BOOLEAN) local_decl = state_machine.states[0].declarations[ID("Local")] assert isinstance(local_decl, decl.VariableDeclaration) - assert local_decl.expression == expr.Call("Func", rty.BOOLEAN) + assert local_decl.expression == expr.Call("Func", ty.BOOLEAN) local_stmt = state_machine.states[0].actions[0] assert isinstance(local_stmt, stmt.VariableAssignment) - assert local_stmt.expression == expr.Call("Func", rty.BOOLEAN) + assert local_stmt.expression == expr.Call("Func", ty.BOOLEAN) @pytest.mark.parametrize( @@ -2735,7 +2734,7 @@ def test_state_normalization( decl.VariableDeclaration( "Msg", "Message", - type_=rty.Message("M", is_definite=False), + type_=ty.Message("M", is_definite=False), ), ], transitions=[Transition(target=ID("null"))], @@ -2746,7 +2745,7 @@ def test_state_normalization( decl.VariableDeclaration( "Msg", "Message", - type_=rty.Message("M", is_definite=False), + type_=ty.Message("M", is_definite=False), ), ], transitions=[Transition(target=ID("null"))], @@ -2759,7 +2758,7 @@ def test_state_normalization( decl.VariableDeclaration( "Msg", "Message", - type_=rty.Message( + type_=ty.Message( "M", is_definite=True, ), @@ -2774,7 +2773,7 @@ def test_state_normalization( decl.VariableDeclaration( "Msg", "Message", - type_=rty.Message( + type_=ty.Message( "M", is_definite=True, ), @@ -2791,7 +2790,7 @@ def test_state_normalization( decl.VariableDeclaration( "Int", "Integer", - type_=rty.Integer("Integer", ty.Bounds(0, 255)), + type_=ty.Integer("Integer", ty.Bounds(0, 255)), ), ], transitions=[Transition(target=ID("null"))], @@ -2802,7 +2801,7 @@ def test_state_normalization( decl.VariableDeclaration( "Int", "Integer", - type_=rty.Integer("Integer", ty.Bounds(0, 255)), + type_=ty.Integer("Integer", ty.Bounds(0, 255)), ), ], transitions=[Transition(target=ID("null"))], @@ -2815,7 +2814,7 @@ def test_state_normalization( decl.VariableDeclaration( "Msg", "Message", - type_=rty.Message( + type_=ty.Message( "M", is_definite=True, ), @@ -2830,7 +2829,7 @@ def test_state_normalization( decl.VariableDeclaration( "Msg", "Message", - type_=rty.Message( + type_=ty.Message( "M", is_definite=True, ), @@ -2847,7 +2846,7 @@ def test_state_normalization( decl.VariableDeclaration( "Msg", "Message", - type_=rty.Message( + type_=ty.Message( "M", is_definite=True, ), @@ -2865,7 +2864,7 @@ def test_state_normalization( decl.VariableDeclaration( "Msg", "Message", - type_=rty.Message( + type_=ty.Message( "M", is_definite=True, ), @@ -2885,7 +2884,7 @@ def test_state_normalization( decl.VariableDeclaration( "Msg", "Message", - type_=rty.Message( + type_=ty.Message( "M", is_definite=True, ), @@ -2905,7 +2904,7 @@ def test_state_normalization( decl.VariableDeclaration( "Msg", "Message", - type_=rty.Message( + type_=ty.Message( "M", is_definite=True, ), @@ -2927,7 +2926,7 @@ def test_state_normalization( decl.VariableDeclaration( "Msg", "Message", - type_=rty.Message( + type_=ty.Message( "M", is_definite=True, ), @@ -2948,7 +2947,7 @@ def test_state_normalization( decl.VariableDeclaration( "Msg", "Message", - type_=rty.Message( + type_=ty.Message( "M", is_definite=True, ), @@ -2971,7 +2970,7 @@ def test_state_normalization( decl.VariableDeclaration( "Msg", "Message", - type_=rty.Message( + type_=ty.Message( "M", is_definite=True, ), @@ -2979,7 +2978,7 @@ def test_state_normalization( decl.VariableDeclaration( "Msg2", "Message", - type_=rty.Message( + type_=ty.Message( "M", is_definite=True, ), @@ -2990,7 +2989,7 @@ def test_state_normalization( "Msg", expr.Call( "Func", - rty.Message( + ty.Message( "M", is_definite=True, ), @@ -3006,12 +3005,12 @@ def test_state_normalization( decl.VariableDeclaration( "Msg", "Message", - type_=rty.Structure("M"), + type_=ty.Structure("M"), ), decl.VariableDeclaration( "Msg2", "Message", - type_=rty.Message( + type_=ty.Message( "M", is_definite=True, ), @@ -3022,7 +3021,7 @@ def test_state_normalization( "Msg", expr.Call( "Func", - type_=rty.Structure( + type_=ty.Structure( "M", ), args=[expr.Opaque("Msg2")], @@ -3039,7 +3038,7 @@ def test_state_normalization( decl.VariableDeclaration( "Msg", "Message", - type_=rty.Message("M", is_definite=True), + type_=ty.Message("M", is_definite=True), ), ], actions=[ @@ -3048,7 +3047,7 @@ def test_state_normalization( expr.Selected( prefix=expr.Variable( "Msg", - type_=rty.Message( + type_=ty.Message( "M", is_definite=True, ), @@ -3065,7 +3064,7 @@ def test_state_normalization( decl.VariableDeclaration( "Msg", "Message", - type_=rty.Structure( + type_=ty.Structure( "M", ), ), @@ -3076,7 +3075,7 @@ def test_state_normalization( expr.Selected( prefix=expr.Variable( "Msg", - type_=rty.Structure( + type_=ty.Structure( "M", ), ), @@ -3094,7 +3093,7 @@ def test_state_normalization( decl.VariableDeclaration( "Msg", "Message", - type_=rty.Message( + type_=ty.Message( "M", is_definite=True, ), @@ -3114,7 +3113,7 @@ def test_state_normalization( decl.VariableDeclaration( "Msg", "Message", - type_=rty.Message( + type_=ty.Message( "M", is_definite=True, ), @@ -3145,7 +3144,7 @@ def test_message_assignment_from_function() -> None: transitions=[Transition(target=ID("null"))], exception_transition=Transition(target=ID("null")), declarations=[decl.VariableDeclaration("Msg", "Null_Msg::Message")], - actions=[stmt.VariableAssignment("Msg", expr.Call("SubProg", rty.BASE_INTEGER))], + actions=[stmt.VariableAssignment("Msg", expr.Call("SubProg", ty.BASE_INTEGER))], ), ], declarations=[], @@ -3179,7 +3178,7 @@ def test_unchecked_state_machine_checked() -> None: condition=expr.And( expr.Equal(expr.Variable("Z"), expr.TRUE), expr.Equal( - expr.Call("G", rty.BOOLEAN, [expr.Variable("F")]), + expr.Call("G", ty.BOOLEAN, [expr.Variable("F")]), expr.TRUE, ), ), @@ -3229,7 +3228,7 @@ def test_unchecked_state_machine_checked() -> None: condition=expr.And( expr.Equal(expr.Variable("Z"), expr.TRUE), expr.Equal( - expr.Call("G", rty.BOOLEAN, [expr.Variable("F")]), + expr.Call("G", ty.BOOLEAN, [expr.Variable("F")]), expr.TRUE, ), ), diff --git a/tests/unit/model/statement_test.py b/tests/unit/model/statement_test.py index 0f9276796..54452884a 100644 --- a/tests/unit/model/statement_test.py +++ b/tests/unit/model/statement_test.py @@ -1,13 +1,13 @@ import pytest -from rflx import expr, ir, typing_ as rty +from rflx import expr, ir, ty from rflx.identifier import ID, id_generator from rflx.model import statement as stmt -from rflx.rapidflux import Location, RecordFluxError, ty +from rflx.rapidflux import Location, RecordFluxError -INT_TY = rty.Integer("I", ty.Bounds(10, 100)) -MSG_TY = rty.Message("M") -SEQ_TY = rty.Sequence("S", rty.Message("M")) +INT_TY = ty.Integer("I", ty.Bounds(10, 100)) +MSG_TY = ty.Message("M") +SEQ_TY = ty.Sequence("S", ty.Message("M")) def test_variable_assignment_to_ir() -> None: @@ -16,7 +16,7 @@ def test_variable_assignment_to_ir() -> None: expr.Add(expr.Variable("Y", type_=INT_TY), expr.Number(1)), INT_TY, ).to_ir(id_generator()) == [ - ir.Assign("X", ir.Add(ir.IntVar("Y", INT_TY), ir.IntVal(1)), rty.BASE_INTEGER), + ir.Assign("X", ir.Add(ir.IntVar("Y", INT_TY), ir.IntVal(1)), ty.BASE_INTEGER), ] assert stmt.VariableAssignment( "X", @@ -26,12 +26,12 @@ def test_variable_assignment_to_ir() -> None: ), INT_TY, ).to_ir(id_generator()) == [ - ir.VarDecl("T_0", rty.BASE_INTEGER), - ir.Assign("T_0", ir.Sub(ir.IntVar("Z", INT_TY), ir.IntVal(1)), rty.BASE_INTEGER), + ir.VarDecl("T_0", ty.BASE_INTEGER), + ir.Assign("T_0", ir.Sub(ir.IntVar("Z", INT_TY), ir.IntVal(1)), ty.BASE_INTEGER), ir.Assign( "X", - ir.Add(ir.IntVar("Y", INT_TY), ir.IntVar("T_0", rty.BASE_INTEGER)), - rty.BASE_INTEGER, + ir.Add(ir.IntVar("Y", INT_TY), ir.IntVar("T_0", ty.BASE_INTEGER)), + ty.BASE_INTEGER, ), ] @@ -50,12 +50,12 @@ def test_message_field_assignment_to_ir() -> None: ), MSG_TY, ).to_ir(id_generator()) == [ - ir.VarDecl("T_0", rty.BASE_INTEGER), - ir.Assign("T_0", ir.Add(ir.IntVar("Z", INT_TY), ir.IntVal(1)), rty.BASE_INTEGER), + ir.VarDecl("T_0", ty.BASE_INTEGER), + ir.Assign("T_0", ir.Add(ir.IntVar("Z", INT_TY), ir.IntVal(1)), ty.BASE_INTEGER), ir.FieldAssign( "X", "Y", - ir.Add(ir.IntVar("Y", INT_TY), ir.IntVar("T_0", rty.BASE_INTEGER)), + ir.Add(ir.IntVar("Y", INT_TY), ir.IntVar("T_0", ty.BASE_INTEGER)), MSG_TY, ), ] @@ -74,9 +74,9 @@ def test_append_to_ir() -> None: ), SEQ_TY, ).to_ir(id_generator()) == [ - ir.VarDecl("T_0", rty.BASE_INTEGER), - ir.Assign("T_0", ir.Add(ir.IntVar("Z", INT_TY), ir.IntVal(1)), rty.BASE_INTEGER), - ir.Append("X", ir.Add(ir.IntVar("Y", INT_TY), ir.IntVar("T_0", rty.BASE_INTEGER)), SEQ_TY), + ir.VarDecl("T_0", ty.BASE_INTEGER), + ir.Assign("T_0", ir.Add(ir.IntVar("Z", INT_TY), ir.IntVal(1)), ty.BASE_INTEGER), + ir.Append("X", ir.Add(ir.IntVar("Y", INT_TY), ir.IntVar("T_0", ty.BASE_INTEGER)), SEQ_TY), ] @@ -89,10 +89,10 @@ def test_extend_to_ir() -> None: def test_reset_check_type() -> None: def typify_variable(expression: expr.Expr) -> expr.Expr: if isinstance(expression, expr.Variable) and expression.identifier == ID("M"): - return expr.Variable("M", type_=rty.Message("M", {("F",)}, {ID("F"): INT_TY})) + return expr.Variable("M", type_=ty.Message("M", {("F",)}, {ID("F"): INT_TY})) return expression - t = rty.Message("T", parameter_types={ID("Y"): INT_TY}) + t = ty.Message("T", parameter_types={ID("Y"): INT_TY}) reset = stmt.Reset("X", {ID("Y"): expr.Selected(expr.Variable("M"), "F")}) reset.check_type(t, typify_variable).propagate() @@ -114,7 +114,7 @@ def typify_variable(expression: expr.Expr) -> expr.Expr: match=r'^:1:2: error: undefined variable "M"$', ): reset.check_type( - rty.Message("T", parameter_types={ID("Y"): INT_TY}), + ty.Message("T", parameter_types={ID("Y"): INT_TY}), typify_variable, ).propagate() @@ -122,7 +122,7 @@ def typify_variable(expression: expr.Expr) -> expr.Expr: def test_reset_check_type_error_invalid_arguments() -> None: def typify_variable(expression: expr.Expr) -> expr.Expr: if isinstance(expression, expr.Variable) and expression.identifier == ID("M"): - return expr.Variable("M", type_=rty.Message("M", {("F",)}, {ID("F"): INT_TY})) + return expr.Variable("M", type_=ty.Message("M", {("F",)}, {ID("F"): INT_TY})) return expression reset = stmt.Reset( @@ -140,7 +140,7 @@ def typify_variable(expression: expr.Expr) -> expr.Expr: ), ): reset.check_type( - rty.Message("T", parameter_types={ID("Y"): INT_TY}), + ty.Message("T", parameter_types={ID("Y"): INT_TY}), typify_variable, ).propagate() @@ -148,7 +148,7 @@ def typify_variable(expression: expr.Expr) -> expr.Expr: def test_reset_check_type_error_unexpected_arguments() -> None: def typify_variable(expression: expr.Expr) -> expr.Expr: if isinstance(expression, expr.Variable) and expression.identifier == ID("M"): - return expr.Variable("M", type_=rty.Message("M", {("F",)}, {ID("F"): INT_TY})) + return expr.Variable("M", type_=ty.Message("M", {("F",)}, {ID("F"): INT_TY})) return expression reset = stmt.Reset( @@ -160,7 +160,7 @@ def typify_variable(expression: expr.Expr) -> expr.Expr: match=r'^:1:2: error: unexpected argument "Z"$', ): reset.check_type( - rty.Sequence("T", INT_TY), + ty.Sequence("T", INT_TY), typify_variable, ).propagate() @@ -185,7 +185,7 @@ def test_read_to_ir() -> None: assert stmt.Read("X", expr.Variable("M", type_=MSG_TY)).to_ir(id_generator()) == [ ir.Read("X", ir.ObjVar("M", MSG_TY)), ] - assert stmt.Read("X", expr.Call("Y", type_=rty.Message("M"))).to_ir(id_generator()) == [ + assert stmt.Read("X", expr.Call("Y", type_=ty.Message("M"))).to_ir(id_generator()) == [ ir.VarDecl("T_0", MSG_TY), ir.Assign("T_0", ir.ObjCall("Y", [], [], MSG_TY), MSG_TY), ir.Read("X", ir.ObjVar("T_0", MSG_TY)), @@ -196,7 +196,7 @@ def test_write_to_ir() -> None: assert stmt.Write("X", expr.Variable("M", type_=MSG_TY)).to_ir(id_generator()) == [ ir.Write("X", ir.ObjVar("M", MSG_TY)), ] - assert stmt.Write("X", expr.Call("Y", type_=rty.Message("M"))).to_ir(id_generator()) == [ + assert stmt.Write("X", expr.Call("Y", type_=ty.Message("M"))).to_ir(id_generator()) == [ ir.VarDecl("T_0", MSG_TY), ir.Assign("T_0", ir.ObjCall("Y", [], [], MSG_TY), MSG_TY), ir.Write("X", ir.ObjVar("T_0", MSG_TY)), diff --git a/tests/unit/model/type_decl_test.py b/tests/unit/model/type_decl_test.py index 0a6d0fdce..27c6abdf9 100644 --- a/tests/unit/model/type_decl_test.py +++ b/tests/unit/model/type_decl_test.py @@ -2,7 +2,7 @@ import pytest -import rflx.typing_ as rty +from rflx import ty from rflx.expr import Add, Aggregate, Equal, Mul, Number, Pow, Size, Sub, Variable from rflx.identifier import ID from rflx.model import ( @@ -48,7 +48,7 @@ def test_type_type() -> None: class NewType(TypeDecl): pass - assert NewType("P::T").type_ == rty.Undefined() + assert NewType("P::T").type_ == ty.Undefined() def test_type_dependencies() -> None: diff --git a/tests/unit/rapidflux/ty_test.py b/tests/unit/rapidflux/ty_test.py index 92d805f00..932f302fe 100644 --- a/tests/unit/rapidflux/ty_test.py +++ b/tests/unit/rapidflux/ty_test.py @@ -1,26 +1,736 @@ +from __future__ import annotations + import pickle +from collections import abc from pathlib import Path import pytest -from rflx.rapidflux import ID +from rflx.rapidflux import ID, Location, RecordFluxError from rflx.rapidflux.ty import ( Aggregate, Any, + AnyInteger, Bounds, Builtins, Channel, + Composite, + Compound, Enumeration, Integer, Message, Refinement, + Sequence, Structure, + Type, Undefined, + UniversalInteger, + check_type, + check_type_instance, + common_type, +) +from rflx.ty import BASE_INTEGER, UNDEFINED + +INT_A = Integer("A", Bounds(10, 100)) +ENUM_A = Enumeration("A", [ID("AE1"), ID("AE2")]) +ENUM_B = Enumeration("B", [ID("BE1"), ID("BE2"), ID("BE3")]) +SEQ_A = Sequence("A", INT_A) +MSG_A = Message("A") + + +@pytest.mark.parametrize( + ("enumeration", "other", "expected"), + [ + (ENUM_A, Any(), ENUM_A), + (ENUM_A, ENUM_A, ENUM_A), + (ENUM_A, Undefined(), Undefined()), + (ENUM_A, ENUM_B, Undefined()), + (ENUM_A, Integer("A", Bounds(10, 100)), Undefined()), + ], +) +def test_enumeration_common_type(enumeration: Type, other: Type, expected: Type) -> None: + assert enumeration.common_type(other) == expected + assert other.common_type(enumeration) == expected + + +@pytest.mark.parametrize( + ("enumeration", "other", "expected"), + [ + (ENUM_A, Any(), True), + (ENUM_A, ENUM_A, True), + (ENUM_A, Undefined(), False), + (ENUM_A, ENUM_B, False), + (ENUM_A, Integer("A", Bounds(10, 100)), False), + ], +) +def test_enumeration_is_compatible(enumeration: Type, other: Type, expected: bool) -> None: + assert enumeration.is_compatible(other) == expected + assert other.is_compatible(enumeration) == expected + + +@pytest.mark.parametrize( + ("base_integer", "other", "expected"), + [ + (BASE_INTEGER, Any(), BASE_INTEGER), + (BASE_INTEGER, BASE_INTEGER, BASE_INTEGER), + ( + BASE_INTEGER, + Integer("A", Bounds(10, 100)), + BASE_INTEGER, + ), + ( + BASE_INTEGER, + UniversalInteger(Bounds(10, 100)), + BASE_INTEGER, + ), + (BASE_INTEGER, Undefined(), Undefined()), + (BASE_INTEGER, ENUM_B, Undefined()), + ], +) +def test_base_integer_common_type(base_integer: Type, other: Type, expected: Type) -> None: + assert base_integer.common_type(other) == expected + assert other.common_type(base_integer) == expected + + +@pytest.mark.parametrize( + ("base_integer", "other", "expected"), + [ + (BASE_INTEGER, Any(), True), + (BASE_INTEGER, BASE_INTEGER, True), + ( + BASE_INTEGER, + Integer("A", Bounds(10, 100)), + True, + ), + ( + BASE_INTEGER, + UniversalInteger(Bounds(10, 100)), + True, + ), + (BASE_INTEGER, Undefined(), False), + (BASE_INTEGER, ENUM_B, False), + ], +) +def test_base_integer_is_compatible(base_integer: Type, other: Type, expected: bool) -> None: + assert base_integer.is_compatible(other) == expected + assert other.is_compatible(base_integer) == expected + + +@pytest.mark.parametrize( + ("universal_integer", "other", "expected"), + [ + ( + UniversalInteger(Bounds(10, 100)), + Any(), + UniversalInteger(Bounds(10, 100)), + ), + ( + UniversalInteger(Bounds(10, 100)), + BASE_INTEGER, + BASE_INTEGER, + ), + ( + UniversalInteger(Bounds(10, 100)), + UniversalInteger(Bounds(10, 100)), + UniversalInteger(Bounds(10, 100)), + ), + ( + UniversalInteger(Bounds(10, 100)), + Integer("A", Bounds(10, 100)), + BASE_INTEGER, + ), + ( + UniversalInteger(Bounds(20, 80)), + Integer("A", Bounds(10, 100)), + BASE_INTEGER, + ), + ( + UniversalInteger(Bounds(10, 100)), + Undefined(), + Undefined(), + ), + ( + UniversalInteger(Bounds(10, 100)), + ENUM_B, + Undefined(), + ), + ], +) +def test_universal_integer_common_type( + universal_integer: Type, + other: Type, + expected: Type, +) -> None: + assert universal_integer.common_type(other) == expected + assert other.common_type(universal_integer) == expected + + +@pytest.mark.parametrize( + ("universal_integer", "other", "expected"), + [ + (UniversalInteger(Bounds(10, 100)), Any(), True), + (UniversalInteger(Bounds(10, 100)), BASE_INTEGER, True), + (UniversalInteger(Bounds(10, 100)), UniversalInteger(Bounds(10, 100)), True), + ( + UniversalInteger(Bounds(10, 100)), + Integer("A", Bounds(10, 100)), + True, + ), + (UniversalInteger(Bounds(10, 100)), Undefined(), False), + (UniversalInteger(Bounds(10, 100)), ENUM_B, False), + ], +) +def test_universal_integer_is_compatible( + universal_integer: Type, + other: Type, + expected: bool, +) -> None: + assert universal_integer.is_compatible(other) == expected + assert other.is_compatible(universal_integer) == expected + + +@pytest.mark.parametrize( + ("integer", "other", "expected"), + [ + ( + Integer("A", Bounds(10, 100)), + Any(), + Integer("A", Bounds(10, 100)), + ), + ( + Integer("A", Bounds(10, 100)), + BASE_INTEGER, + BASE_INTEGER, + ), + ( + Integer("A", Bounds(10, 100)), + Integer("A", Bounds(10, 100)), + BASE_INTEGER, + ), + ( + Integer("A", Bounds(10, 100)), + UniversalInteger(Bounds(10, 100)), + BASE_INTEGER, + ), + ( + Integer("A", Bounds(10, 100)), + Integer("B", Bounds(10, 100)), + BASE_INTEGER, + ), + ( + Integer("A", Bounds(10, 100)), + UniversalInteger(Bounds(0, 200)), + BASE_INTEGER, + ), + ( + Integer("A", Bounds(10, 100)), + Undefined(), + Undefined(), + ), + ( + Integer("A", Bounds(10, 100)), + ENUM_B, + Undefined(), + ), + ], +) +def test_integer_common_type(integer: Type, other: Type, expected: Type) -> None: + assert integer.common_type(other) == expected + assert other.common_type(integer) == expected + + +@pytest.mark.parametrize( + ("integer", "other", "expected"), + [ + (Integer("A", Bounds(10, 100)), Any(), True), + (Integer("A", Bounds(10, 100)), BASE_INTEGER, True), + (Integer("A", Bounds(10, 100)), Integer("A", Bounds(10, 100)), True), + (Integer("A", Bounds(10, 100)), UniversalInteger(Bounds(10, 100)), True), + ( + Integer("A", Bounds(10, 100)), + Integer("B", Bounds(10, 100)), + True, + ), + ( + Integer("A", Bounds(0, 200)), + UniversalInteger(Bounds(10, 100)), + True, + ), + ( + Integer("A", Bounds(10, 100)), + UniversalInteger(Bounds(0, 200)), + True, + ), + (Integer("A", Bounds(10, 100)), Undefined(), False), + (Integer("A", Bounds(10, 100)), ENUM_B, False), + ], +) +def test_integer_is_compatible(integer: Type, other: Type, expected: bool) -> None: + assert integer.is_compatible(other) == expected + assert other.is_compatible(integer) == expected + + +@pytest.mark.parametrize( + ("integer", "other", "expected"), + [ + (Integer("A", Bounds(10, 100)), Any(), True), + (Integer("A", Bounds(10, 100)), BASE_INTEGER, False), + (Integer("A", Bounds(10, 100)), Integer("A", Bounds(10, 100)), True), + (Integer("A", Bounds(10, 100)), UniversalInteger(Bounds(10, 100)), True), + ( + Integer("A", Bounds(10, 100)), + Integer("B", Bounds(10, 100)), + False, + ), + ( + Integer("A", Bounds(0, 200)), + UniversalInteger(Bounds(10, 100)), + True, + ), + ( + Integer("A", Bounds(10, 100)), + UniversalInteger(Bounds(0, 200)), + False, + ), + (Integer("A", Bounds(10, 100)), Undefined(), False), + (Integer("A", Bounds(10, 100)), ENUM_B, False), + ], +) +def test_integer_is_compatible_strong(integer: Type, other: Type, expected: bool) -> None: + assert integer.is_compatible_strong(other) == expected + assert other.is_compatible_strong(integer) == expected + + +@pytest.mark.parametrize( + ("aggregate", "other", "expected"), + [ + ( + Aggregate(Integer("A", Bounds(10, 100))), + Any(), + Aggregate(Integer("A", Bounds(10, 100))), + ), + ( + Aggregate(Integer("A", Bounds(10, 100))), + Aggregate(Integer("A", Bounds(10, 100))), + Aggregate(Integer("A", Bounds(10, 100))), + ), + ( + Aggregate(Integer("A", Bounds(10, 100))), + Aggregate(Integer("B", Bounds(10, 100))), + Aggregate(BASE_INTEGER), + ), + ( + Aggregate(Integer("A", Bounds(10, 100))), + Aggregate(Integer("B", Bounds(20, 200))), + Aggregate(BASE_INTEGER), + ), + ( + Aggregate(UniversalInteger(Bounds(10, 100))), + Aggregate(UniversalInteger(Bounds(20, 200))), + Aggregate(UniversalInteger(Bounds(10, 200))), + ), + ( + Aggregate(Integer("A", Bounds(10, 100))), + Undefined(), + Undefined(), + ), + ], +) +def test_aggregate_common_type(aggregate: Type, other: Type, expected: Type) -> None: + assert aggregate.common_type(other) == expected + assert other.common_type(aggregate) == expected + + +@pytest.mark.parametrize( + ("aggregate", "other", "expected"), + [ + ( + Aggregate(Integer("A", Bounds(10, 100))), + Any(), + True, + ), + ( + Aggregate(Integer("A", Bounds(10, 100))), + Aggregate(Integer("A", Bounds(10, 100))), + True, + ), + ( + Aggregate(Integer("A", Bounds(10, 100))), + Aggregate(Integer("B", Bounds(10, 100))), + True, + ), + ( + Aggregate(Integer("A", Bounds(10, 100))), + Aggregate(Integer("A", Bounds(20, 200))), + True, + ), + ( + Aggregate(UniversalInteger(Bounds(10, 100))), + Aggregate(UniversalInteger(Bounds(20, 200))), + True, + ), + ( + Aggregate(Integer("A", Bounds(10, 100))), + Undefined(), + False, + ), + ], ) +def test_aggregate_is_compatible(aggregate: Type, other: Type, expected: bool) -> None: + assert aggregate.is_compatible(other) == expected + assert other.is_compatible(aggregate) == expected + -INTEGER_A = Integer("A", Bounds(10, 100)) -ENUMERATION_A = Enumeration("A", [ID("AE1"), ID("AE2")]) -ENUMERATION_B = Enumeration("B", [ID("BE1"), ID("BE2"), ID("BE3")]) +@pytest.mark.parametrize( + ("composite", "other", "expected"), + [ + ( + Sequence("A", Integer("B", Bounds(10, 100))), + Any(), + Sequence("A", Integer("B", Bounds(10, 100))), + ), + ( + Sequence("A", Integer("B", Bounds(10, 100))), + Sequence("A", Integer("B", Bounds(10, 100))), + Sequence("A", Integer("B", Bounds(10, 100))), + ), + ( + Sequence("A", Integer("B", Bounds(10, 100))), + Aggregate(Integer("B", Bounds(10, 100))), + Sequence("A", Integer("B", Bounds(10, 100))), + ), + ( + Sequence("A", Integer("B", Bounds(10, 100))), + Aggregate(UniversalInteger(Bounds(10, 100))), + Sequence("A", Integer("B", Bounds(10, 100))), + ), + ( + Sequence("A", Integer("B", Bounds(10, 100))), + Aggregate(Integer("C", Bounds(10, 100))), + Undefined(), + ), + ( + Sequence("A", Integer("B", Bounds(10, 100))), + Aggregate(Integer("C", Bounds(20, 200))), + Undefined(), + ), + ( + Sequence("A", Integer("B", Bounds(10, 100))), + Aggregate(UniversalInteger(Bounds(20, 200))), + Undefined(), + ), + ( + Sequence("A", Integer("B", Bounds(10, 100))), + Undefined(), + Undefined(), + ), + ], +) +def test_composite_common_type(composite: Type, other: Type, expected: Type) -> None: + assert composite.common_type(other) == expected + assert other.common_type(composite) == expected + + +@pytest.mark.parametrize( + ("composite", "other", "expected"), + [ + ( + Sequence("A", Integer("B", Bounds(10, 100))), + Any(), + True, + ), + ( + Sequence("A", Integer("B", Bounds(10, 100))), + Sequence("A", Integer("B", Bounds(10, 100))), + True, + ), + ( + Sequence("A", Integer("B", Bounds(10, 100))), + Aggregate(Any()), + True, + ), + ( + Sequence("A", Integer("B", Bounds(10, 100))), + Aggregate(Integer("B", Bounds(10, 100))), + True, + ), + ( + Sequence("A", Integer("B", Bounds(10, 100))), + Aggregate(UniversalInteger(Bounds(10, 100))), + True, + ), + ( + Sequence("A", Integer("B", Bounds(10, 100))), + Aggregate(Integer("C", Bounds(10, 100))), + False, + ), + ( + Sequence("A", Integer("B", Bounds(10, 100))), + Aggregate(Integer("C", Bounds(20, 200))), + False, + ), + ( + Sequence("A", Integer("B", Bounds(10, 100))), + Aggregate(UniversalInteger(Bounds(20, 200))), + False, + ), + ( + Sequence("A", Integer("B", Bounds(10, 100))), + Undefined(), + False, + ), + ], +) +def test_composite_is_compatible(composite: Type, other: Type, expected: bool) -> None: + assert composite.is_compatible(other) == expected + assert other.is_compatible(composite) == expected + + +@pytest.mark.parametrize( + ("message", "other", "expected"), + [ + (Message("A"), Any(), Message("A")), + (Message("A"), Message("A"), Message("A")), + (Message("A"), Message("B"), Undefined()), + (Message("A"), Undefined(), Undefined()), + (Message("A"), ENUM_B, Undefined()), + ], +) +def test_message_common_type(message: Type, other: Type, expected: Type) -> None: + assert message.common_type(other) == expected + assert other.common_type(message) == expected + + +@pytest.mark.parametrize( + ("message", "other", "expected"), + [ + (Message("A"), Any(), True), + (Message("A"), Message("A"), True), + (Message("A"), Message("B"), False), + (Message("A"), Undefined(), False), + (Message("A"), ENUM_B, False), + ], +) +def test_message_is_compatible(message: Type, other: Type, expected: bool) -> None: + assert message.is_compatible(other) == expected + assert other.is_compatible(message) == expected + + +@pytest.mark.parametrize( + ("channel", "other", "expected"), + [ + (Channel(readable=True, writable=False), Any(), Channel(readable=True, writable=False)), + ( + Channel(readable=True, writable=False), + Channel(readable=True, writable=False), + Channel(readable=True, writable=False), + ), + ( + Channel(readable=True, writable=False), + Channel(readable=False, writable=True), + Undefined(), + ), + (Channel(readable=True, writable=False), Undefined(), Undefined()), + ( + Channel(readable=True, writable=False), + ENUM_A, + Undefined(), + ), + ], +) +def test_channel_common_type(channel: Type, other: Type, expected: Type) -> None: + assert channel.common_type(other) == expected + assert other.common_type(channel) == expected + + +@pytest.mark.parametrize( + ("channel", "other", "expected"), + [ + (Channel(readable=True, writable=False), Any(), True), + (Any(), Channel(readable=True, writable=False), True), + (Channel(readable=True, writable=False), Channel(readable=True, writable=False), True), + (Channel(readable=True, writable=False), Channel(readable=False, writable=True), False), + (Channel(readable=False, writable=True), Channel(readable=True, writable=False), False), + (Channel(readable=True, writable=True), Channel(readable=False, writable=True), True), + (Channel(readable=True, writable=True), Channel(readable=True, writable=False), True), + (Channel(readable=False, writable=True), Channel(readable=True, writable=True), False), + (Channel(readable=True, writable=False), Channel(readable=True, writable=True), False), + (Channel(readable=True, writable=False), Undefined(), False), + (Channel(readable=True, writable=False), ENUM_A, False), + ], +) +def test_channel_is_compatible(channel: Type, other: Type, expected: bool) -> None: + assert channel.is_compatible(other) == expected + + +@pytest.mark.parametrize( + ("types", "expected"), + [ + ( + [], + Any(), + ), + ( + [INT_A, SEQ_A], + UNDEFINED, + ), + ( + [ + Integer("A", Bounds(10, 100)), + Integer("A", Bounds(10, 100)), + ], + BASE_INTEGER, + ), + ( + [ + UniversalInteger(Bounds(10, 50)), + Integer("A", Bounds(10, 100)), + UniversalInteger(Bounds(50, 100)), + ], + BASE_INTEGER, + ), + ( + [ + UniversalInteger(Bounds(10, 50)), + Integer("A", Bounds(10, 100)), + UniversalInteger(Bounds(20, 200)), + ], + BASE_INTEGER, + ), + ( + [ + Aggregate(Integer("A", Bounds(10, 100))), + Aggregate(UniversalInteger(Bounds(20, 100))), + Aggregate(Integer("B", Bounds(20, 200))), + ], + Aggregate(BASE_INTEGER), + ), + ( + [ + Aggregate(UniversalInteger(Bounds(10, 20))), + Aggregate(UniversalInteger(Bounds(50, 60))), + Aggregate(UniversalInteger(Bounds(90, 100))), + ], + Aggregate(UniversalInteger(Bounds(10, 100))), + ), + ], +) +def test_common_type(types: abc.Sequence[Type], expected: Type) -> None: + assert common_type(types) == expected + assert common_type(list(reversed(types))) == expected + + +@pytest.mark.parametrize( + ("actual", "expected"), + [ + ( + Any(), + (INT_A, ENUM_A), + ), + ( + INT_A, + UNDEFINED, + ), + ], +) +def test_check_type(actual: Type, expected: Type | tuple[Type, ...]) -> None: + check_type(actual, expected, Location((10, 20)), '"A"').propagate() + + +@pytest.mark.parametrize( + ("actual", "expected", "match"), + [ + ( + Message("A"), + Channel(readable=False, writable=True), + r"^:10:20: error: expected writable channel\n" + r':10:20: error: found message type "A"$', + ), + ( + BASE_INTEGER, + Message("A"), + r"^" + r':10:20: error: expected message type "A"\n' + r':10:20: error: found integer type "__BUILTINS__::Base_Integer"' + r" \(0 \.\. 2\*\*63 - 1\)" + r"$", + ), + ( + Undefined(), + Integer("A", Bounds(10, 100)), + r'^:10:20: error: undefined "A"$', + ), + ], +) +def test_check_type_error( + actual: Type, + expected: Type | tuple[Type, ...], + match: str, +) -> None: + with pytest.raises(RecordFluxError, match=match): + check_type(actual, expected, Location((10, 20)), '"A"').propagate() + + +@pytest.mark.parametrize( + ("actual", "expected"), + [ + ( + Any(), + Integer, + ), + ( + BASE_INTEGER, + AnyInteger, + ), + ( + SEQ_A, + Composite, + ), + ( + MSG_A, + Compound, + ), + ], +) +def test_check_type_instance( + actual: Type, + expected: type[Type] | tuple[type[Type], ...], +) -> None: + check_type_instance(actual, expected, Location((10, 20)), '"A"').propagate() + + +@pytest.mark.parametrize( + ("actual", "expected", "match"), + [ + ( + Message("M"), + Channel, + r"^:10:20: error: expected channel\n" + r':10:20: error: found message type "M"$', + ), + ( + BASE_INTEGER, + (Sequence, Message), + r"^" + r":10:20: error: expected sequence type or message type\n" + r':10:20: error: found integer type "__BUILTINS__::Base_Integer"' + r" \(0 \.\. 2\*\*63 - 1\)" + r"$", + ), + ( + Undefined(), + Integer, + r'^:10:20: error: undefined "A"$', + ), + ], +) +def test_check_type_instance_error( + actual: Type, + expected: type[Type] | tuple[type[Type], ...], + match: str, +) -> None: + with pytest.raises(RecordFluxError, match=match): + check_type_instance(actual, expected, Location((10, 20)), '"A"').propagate() @pytest.mark.parametrize( @@ -28,11 +738,11 @@ [ Undefined(), Any(), - Builtins.BOOLEAN, + ENUM_A, Builtins.UNIVERSAL_INTEGER, Builtins.BASE_INTEGER, - Aggregate(INTEGER_A), - Builtins.OPAQUE, + Aggregate(INT_A), + SEQ_A, Structure( "A", {("F",)}, diff --git a/tests/unit/specification/parser_test.py b/tests/unit/specification/parser_test.py index 3e4a11d5a..cd7d35687 100644 --- a/tests/unit/specification/parser_test.py +++ b/tests/unit/specification/parser_test.py @@ -30,7 +30,7 @@ from rflx.model.message import ByteOrder from rflx.rapidflux import Location, RecordFluxError, Severity from rflx.specification import parser -from rflx.typing_ import UNDEFINED +from rflx.ty import UNDEFINED from tests.const import SPEC_DIR from tests.data import models from tests.utils import ( diff --git a/tests/unit/ty_test.py b/tests/unit/ty_test.py new file mode 100644 index 000000000..ebfa4f3da --- /dev/null +++ b/tests/unit/ty_test.py @@ -0,0 +1,2 @@ +def test_dummy() -> None: + pass diff --git a/tests/unit/typing__test.py b/tests/unit/typing__test.py deleted file mode 100644 index 8efaef9ab..000000000 --- a/tests/unit/typing__test.py +++ /dev/null @@ -1,700 +0,0 @@ -from __future__ import annotations - -from collections import abc - -import pytest - -from rflx.identifier import ID -from rflx.rapidflux import Location, RecordFluxError -from rflx.rapidflux.ty import Bounds -from rflx.typing_ import ( - BASE_INTEGER, - Aggregate, - Any, - Channel, - Enumeration, - Integer, - Message, - Sequence, - Type, - Undefined, - UniversalInteger, - check_type, - check_type_instance, - common_type, -) - -INTEGER_A = Integer("A", Bounds(10, 100)) -ENUMERATION_A = Enumeration("A", [ID("AE1"), ID("AE2")]) -ENUMERATION_B = Enumeration("B", [ID("BE1"), ID("BE2"), ID("BE3")]) - - -@pytest.mark.parametrize( - ("enumeration", "other", "expected"), - [ - (ENUMERATION_A, Any(), ENUMERATION_A), - (ENUMERATION_A, ENUMERATION_A, ENUMERATION_A), - (ENUMERATION_A, Undefined(), Undefined()), - (ENUMERATION_A, ENUMERATION_B, Undefined()), - (ENUMERATION_A, Integer("A", Bounds(10, 100)), Undefined()), - ], -) -def test_enumeration_common_type(enumeration: Type, other: Type, expected: Type) -> None: - assert enumeration.common_type(other) == expected - assert other.common_type(enumeration) == expected - - -@pytest.mark.parametrize( - ("enumeration", "other", "expected"), - [ - (ENUMERATION_A, Any(), True), - (ENUMERATION_A, ENUMERATION_A, True), - (ENUMERATION_A, Undefined(), False), - (ENUMERATION_A, ENUMERATION_B, False), - (ENUMERATION_A, Integer("A", Bounds(10, 100)), False), - ], -) -def test_enumeration_is_compatible(enumeration: Type, other: Type, expected: bool) -> None: - assert enumeration.is_compatible(other) == expected - assert other.is_compatible(enumeration) == expected - - -@pytest.mark.parametrize( - ("base_integer", "other", "expected"), - [ - (BASE_INTEGER, Any(), BASE_INTEGER), - (BASE_INTEGER, BASE_INTEGER, BASE_INTEGER), - ( - BASE_INTEGER, - Integer("A", Bounds(10, 100)), - BASE_INTEGER, - ), - ( - BASE_INTEGER, - UniversalInteger(Bounds(10, 100)), - BASE_INTEGER, - ), - (BASE_INTEGER, Undefined(), Undefined()), - (BASE_INTEGER, ENUMERATION_B, Undefined()), - ], -) -def test_base_integer_common_type(base_integer: Type, other: Type, expected: Type) -> None: - assert base_integer.common_type(other) == expected - assert other.common_type(base_integer) == expected - - -@pytest.mark.parametrize( - ("base_integer", "other", "expected"), - [ - (BASE_INTEGER, Any(), True), - (BASE_INTEGER, BASE_INTEGER, True), - ( - BASE_INTEGER, - Integer("A", Bounds(10, 100)), - True, - ), - ( - BASE_INTEGER, - UniversalInteger(Bounds(10, 100)), - True, - ), - (BASE_INTEGER, Undefined(), False), - (BASE_INTEGER, ENUMERATION_B, False), - ], -) -def test_base_integer_is_compatible(base_integer: Type, other: Type, expected: bool) -> None: - assert base_integer.is_compatible(other) == expected - assert other.is_compatible(base_integer) == expected - - -@pytest.mark.parametrize( - ("universal_integer", "other", "expected"), - [ - ( - UniversalInteger(Bounds(10, 100)), - Any(), - UniversalInteger(Bounds(10, 100)), - ), - ( - UniversalInteger(Bounds(10, 100)), - BASE_INTEGER, - BASE_INTEGER, - ), - ( - UniversalInteger(Bounds(10, 100)), - UniversalInteger(Bounds(10, 100)), - UniversalInteger(Bounds(10, 100)), - ), - ( - UniversalInteger(Bounds(10, 100)), - Integer("A", Bounds(10, 100)), - BASE_INTEGER, - ), - ( - UniversalInteger(Bounds(20, 80)), - Integer("A", Bounds(10, 100)), - BASE_INTEGER, - ), - ( - UniversalInteger(Bounds(10, 100)), - Undefined(), - Undefined(), - ), - ( - UniversalInteger(Bounds(10, 100)), - ENUMERATION_B, - Undefined(), - ), - ], -) -def test_universal_integer_common_type( - universal_integer: Type, - other: Type, - expected: Type, -) -> None: - assert universal_integer.common_type(other) == expected - assert other.common_type(universal_integer) == expected - - -@pytest.mark.parametrize( - ("universal_integer", "other", "expected"), - [ - (UniversalInteger(Bounds(10, 100)), Any(), True), - (UniversalInteger(Bounds(10, 100)), BASE_INTEGER, True), - (UniversalInteger(Bounds(10, 100)), UniversalInteger(Bounds(10, 100)), True), - ( - UniversalInteger(Bounds(10, 100)), - Integer("A", Bounds(10, 100)), - True, - ), - (UniversalInteger(Bounds(10, 100)), Undefined(), False), - (UniversalInteger(Bounds(10, 100)), ENUMERATION_B, False), - ], -) -def test_universal_integer_is_compatible( - universal_integer: Type, - other: Type, - expected: bool, -) -> None: - assert universal_integer.is_compatible(other) == expected - assert other.is_compatible(universal_integer) == expected - - -@pytest.mark.parametrize( - ("integer", "other", "expected"), - [ - ( - Integer("A", Bounds(10, 100)), - Any(), - Integer("A", Bounds(10, 100)), - ), - ( - Integer("A", Bounds(10, 100)), - BASE_INTEGER, - BASE_INTEGER, - ), - ( - Integer("A", Bounds(10, 100)), - Integer("A", Bounds(10, 100)), - BASE_INTEGER, - ), - ( - Integer("A", Bounds(10, 100)), - UniversalInteger(Bounds(10, 100)), - BASE_INTEGER, - ), - ( - Integer("A", Bounds(10, 100)), - Integer("B", Bounds(10, 100)), - BASE_INTEGER, - ), - ( - Integer("A", Bounds(10, 100)), - UniversalInteger(Bounds(0, 200)), - BASE_INTEGER, - ), - ( - Integer("A", Bounds(10, 100)), - Undefined(), - Undefined(), - ), - ( - Integer("A", Bounds(10, 100)), - ENUMERATION_B, - Undefined(), - ), - ], -) -def test_integer_common_type(integer: Type, other: Type, expected: Type) -> None: - assert integer.common_type(other) == expected - assert other.common_type(integer) == expected - - -@pytest.mark.parametrize( - ("integer", "other", "expected"), - [ - (Integer("A", Bounds(10, 100)), Any(), True), - (Integer("A", Bounds(10, 100)), BASE_INTEGER, True), - (Integer("A", Bounds(10, 100)), Integer("A", Bounds(10, 100)), True), - (Integer("A", Bounds(10, 100)), UniversalInteger(Bounds(10, 100)), True), - ( - Integer("A", Bounds(10, 100)), - Integer("B", Bounds(10, 100)), - True, - ), - ( - Integer("A", Bounds(0, 200)), - UniversalInteger(Bounds(10, 100)), - True, - ), - ( - Integer("A", Bounds(10, 100)), - UniversalInteger(Bounds(0, 200)), - True, - ), - (Integer("A", Bounds(10, 100)), Undefined(), False), - (Integer("A", Bounds(10, 100)), ENUMERATION_B, False), - ], -) -def test_integer_is_compatible(integer: Type, other: Type, expected: bool) -> None: - assert integer.is_compatible(other) == expected - assert other.is_compatible(integer) == expected - - -@pytest.mark.parametrize( - ("integer", "other", "expected"), - [ - (Integer("A", Bounds(10, 100)), Any(), True), - (Integer("A", Bounds(10, 100)), BASE_INTEGER, False), - (Integer("A", Bounds(10, 100)), Integer("A", Bounds(10, 100)), True), - (Integer("A", Bounds(10, 100)), UniversalInteger(Bounds(10, 100)), True), - ( - Integer("A", Bounds(10, 100)), - Integer("B", Bounds(10, 100)), - False, - ), - ( - Integer("A", Bounds(0, 200)), - UniversalInteger(Bounds(10, 100)), - True, - ), - ( - Integer("A", Bounds(10, 100)), - UniversalInteger(Bounds(0, 200)), - False, - ), - (Integer("A", Bounds(10, 100)), Undefined(), False), - (Integer("A", Bounds(10, 100)), ENUMERATION_B, False), - ], -) -def test_integer_is_compatible_strong(integer: Type, other: Type, expected: bool) -> None: - assert integer.is_compatible_strong(other) == expected - assert other.is_compatible_strong(integer) == expected - - -@pytest.mark.parametrize( - ("aggregate", "other", "expected"), - [ - ( - Aggregate(Integer("A", Bounds(10, 100))), - Any(), - Aggregate(Integer("A", Bounds(10, 100))), - ), - ( - Aggregate(Integer("A", Bounds(10, 100))), - Aggregate(Integer("A", Bounds(10, 100))), - Aggregate(Integer("A", Bounds(10, 100))), - ), - ( - Aggregate(Integer("A", Bounds(10, 100))), - Aggregate(Integer("B", Bounds(10, 100))), - Aggregate(BASE_INTEGER), - ), - ( - Aggregate(Integer("A", Bounds(10, 100))), - Aggregate(Integer("B", Bounds(20, 200))), - Aggregate(BASE_INTEGER), - ), - ( - Aggregate(UniversalInteger(Bounds(10, 100))), - Aggregate(UniversalInteger(Bounds(20, 200))), - Aggregate(UniversalInteger(Bounds(10, 200))), - ), - ( - Aggregate(Integer("A", Bounds(10, 100))), - Undefined(), - Undefined(), - ), - ], -) -def test_aggregate_common_type(aggregate: Type, other: Type, expected: Type) -> None: - assert aggregate.common_type(other) == expected - assert other.common_type(aggregate) == expected - - -@pytest.mark.parametrize( - ("aggregate", "other", "expected"), - [ - ( - Aggregate(Integer("A", Bounds(10, 100))), - Any(), - True, - ), - ( - Aggregate(Integer("A", Bounds(10, 100))), - Aggregate(Integer("A", Bounds(10, 100))), - True, - ), - ( - Aggregate(Integer("A", Bounds(10, 100))), - Aggregate(Integer("B", Bounds(10, 100))), - True, - ), - ( - Aggregate(Integer("A", Bounds(10, 100))), - Aggregate(Integer("A", Bounds(20, 200))), - True, - ), - ( - Aggregate(UniversalInteger(Bounds(10, 100))), - Aggregate(UniversalInteger(Bounds(20, 200))), - True, - ), - ( - Aggregate(Integer("A", Bounds(10, 100))), - Undefined(), - False, - ), - ], -) -def test_aggregate_is_compatible(aggregate: Type, other: Type, expected: bool) -> None: - assert aggregate.is_compatible(other) == expected - assert other.is_compatible(aggregate) == expected - - -@pytest.mark.parametrize( - ("composite", "other", "expected"), - [ - ( - Sequence("A", Integer("B", Bounds(10, 100))), - Any(), - Sequence("A", Integer("B", Bounds(10, 100))), - ), - ( - Sequence("A", Integer("B", Bounds(10, 100))), - Sequence("A", Integer("B", Bounds(10, 100))), - Sequence("A", Integer("B", Bounds(10, 100))), - ), - ( - Sequence("A", Integer("B", Bounds(10, 100))), - Aggregate(Integer("B", Bounds(10, 100))), - Sequence("A", Integer("B", Bounds(10, 100))), - ), - ( - Sequence("A", Integer("B", Bounds(10, 100))), - Aggregate(UniversalInteger(Bounds(10, 100))), - Sequence("A", Integer("B", Bounds(10, 100))), - ), - ( - Sequence("A", Integer("B", Bounds(10, 100))), - Aggregate(Integer("C", Bounds(10, 100))), - Undefined(), - ), - ( - Sequence("A", Integer("B", Bounds(10, 100))), - Aggregate(Integer("C", Bounds(20, 200))), - Undefined(), - ), - ( - Sequence("A", Integer("B", Bounds(10, 100))), - Aggregate(UniversalInteger(Bounds(20, 200))), - Undefined(), - ), - ( - Sequence("A", Integer("B", Bounds(10, 100))), - Undefined(), - Undefined(), - ), - ], -) -def test_composite_common_type(composite: Type, other: Type, expected: Type) -> None: - assert composite.common_type(other) == expected - assert other.common_type(composite) == expected - - -@pytest.mark.parametrize( - ("composite", "other", "expected"), - [ - ( - Sequence("A", Integer("B", Bounds(10, 100))), - Any(), - True, - ), - ( - Sequence("A", Integer("B", Bounds(10, 100))), - Sequence("A", Integer("B", Bounds(10, 100))), - True, - ), - ( - Sequence("A", Integer("B", Bounds(10, 100))), - Aggregate(Any()), - True, - ), - ( - Sequence("A", Integer("B", Bounds(10, 100))), - Aggregate(Integer("B", Bounds(10, 100))), - True, - ), - ( - Sequence("A", Integer("B", Bounds(10, 100))), - Aggregate(UniversalInteger(Bounds(10, 100))), - True, - ), - ( - Sequence("A", Integer("B", Bounds(10, 100))), - Aggregate(Integer("C", Bounds(10, 100))), - False, - ), - ( - Sequence("A", Integer("B", Bounds(10, 100))), - Aggregate(Integer("C", Bounds(20, 200))), - False, - ), - ( - Sequence("A", Integer("B", Bounds(10, 100))), - Aggregate(UniversalInteger(Bounds(20, 200))), - False, - ), - ( - Sequence("A", Integer("B", Bounds(10, 100))), - Undefined(), - False, - ), - ], -) -def test_composite_is_compatible(composite: Type, other: Type, expected: bool) -> None: - assert composite.is_compatible(other) == expected - assert other.is_compatible(composite) == expected - - -@pytest.mark.parametrize( - ("message", "other", "expected"), - [ - (Message("A"), Any(), Message("A")), - (Message("A"), Message("A"), Message("A")), - (Message("A"), Message("B"), Undefined()), - (Message("A"), Undefined(), Undefined()), - (Message("A"), ENUMERATION_B, Undefined()), - ], -) -def test_message_common_type(message: Type, other: Type, expected: Type) -> None: - assert message.common_type(other) == expected - assert other.common_type(message) == expected - - -@pytest.mark.parametrize( - ("message", "other", "expected"), - [ - (Message("A"), Any(), True), - (Message("A"), Message("A"), True), - (Message("A"), Message("B"), False), - (Message("A"), Undefined(), False), - (Message("A"), ENUMERATION_B, False), - ], -) -def test_message_is_compatible(message: Type, other: Type, expected: bool) -> None: - assert message.is_compatible(other) == expected - assert other.is_compatible(message) == expected - - -@pytest.mark.parametrize( - ("channel", "other", "expected"), - [ - (Channel(readable=True, writable=False), Any(), Channel(readable=True, writable=False)), - ( - Channel(readable=True, writable=False), - Channel(readable=True, writable=False), - Channel(readable=True, writable=False), - ), - ( - Channel(readable=True, writable=False), - Channel(readable=False, writable=True), - Undefined(), - ), - (Channel(readable=True, writable=False), Undefined(), Undefined()), - ( - Channel(readable=True, writable=False), - ENUMERATION_A, - Undefined(), - ), - ], -) -def test_channel_common_type(channel: Type, other: Type, expected: Type) -> None: - assert channel.common_type(other) == expected - assert other.common_type(channel) == expected - - -@pytest.mark.parametrize( - ("channel", "other", "expected"), - [ - (Channel(readable=True, writable=False), Any(), True), - (Any(), Channel(readable=True, writable=False), True), - (Channel(readable=True, writable=False), Channel(readable=True, writable=False), True), - (Channel(readable=True, writable=False), Channel(readable=False, writable=True), False), - (Channel(readable=False, writable=True), Channel(readable=True, writable=False), False), - (Channel(readable=True, writable=True), Channel(readable=False, writable=True), True), - (Channel(readable=True, writable=True), Channel(readable=True, writable=False), True), - (Channel(readable=False, writable=True), Channel(readable=True, writable=True), False), - (Channel(readable=True, writable=False), Channel(readable=True, writable=True), False), - (Channel(readable=True, writable=False), Undefined(), False), - (Channel(readable=True, writable=False), ENUMERATION_A, False), - ], -) -def test_channel_is_compatible(channel: Type, other: Type, expected: bool) -> None: - assert channel.is_compatible(other) == expected - - -@pytest.mark.parametrize( - ("types", "expected"), - [ - ( - [], - Any(), - ), - ( - [ - Integer("A", Bounds(10, 100)), - Integer("A", Bounds(10, 100)), - ], - BASE_INTEGER, - ), - ( - [ - UniversalInteger(Bounds(10, 50)), - Integer("A", Bounds(10, 100)), - UniversalInteger(Bounds(50, 100)), - ], - BASE_INTEGER, - ), - ( - [ - UniversalInteger(Bounds(10, 50)), - Integer("A", Bounds(10, 100)), - UniversalInteger(Bounds(20, 200)), - ], - BASE_INTEGER, - ), - ( - [ - Aggregate(Integer("A", Bounds(10, 100))), - Aggregate(UniversalInteger(Bounds(20, 100))), - Aggregate(Integer("B", Bounds(20, 200))), - ], - Aggregate(BASE_INTEGER), - ), - ( - [ - Aggregate(UniversalInteger(Bounds(10, 20))), - Aggregate(UniversalInteger(Bounds(50, 60))), - Aggregate(UniversalInteger(Bounds(90, 100))), - ], - Aggregate(UniversalInteger(Bounds(10, 100))), - ), - ], -) -def test_common_type(types: abc.Sequence[Type], expected: Type) -> None: - assert common_type(types) == expected - assert common_type(list(reversed(types))) == expected - - -@pytest.mark.parametrize( - ("actual", "expected"), - [ - ( - Any(), - Integer("A", Bounds(10, 100)), - ), - ], -) -def test_check_type(actual: Type, expected: Type) -> None: - check_type(actual, expected, Location((10, 20)), '"A"').propagate() - - -@pytest.mark.parametrize( - ("actual", "expected", "match"), - [ - ( - Message("A"), - Channel(readable=False, writable=True), - r"^:10:20: error: expected writable channel\n" - r':10:20: error: found message type "A"$', - ), - ( - BASE_INTEGER, - Message("A"), - r"^" - r':10:20: error: expected message type "A"\n' - r':10:20: error: found integer type "__BUILTINS__::Base_Integer"' - r" \(0 \.\. 2\*\*63 - 1\)" - r"$", - ), - ( - Undefined(), - Integer("A", Bounds(10, 100)), - r'^:10:20: error: undefined "A"$', - ), - ], -) -def test_check_type_error(actual: Type, expected: Type, match: str) -> None: - with pytest.raises(RecordFluxError, match=match): - check_type(actual, expected, Location((10, 20)), '"A"').propagate() - - -@pytest.mark.parametrize( - ("actual", "expected"), - [ - ( - Any(), - Integer, - ), - ], -) -def test_check_type_instance( - actual: Type, - expected: type[Type] | tuple[type[Type], ...], -) -> None: - check_type_instance(actual, expected, Location((10, 20)), '"A"').propagate() - - -@pytest.mark.parametrize( - ("actual", "expected", "match"), - [ - ( - Message("M"), - Channel, - r"^:10:20: error: expected channel\n" - r':10:20: error: found message type "M"$', - ), - ( - BASE_INTEGER, - (Sequence, Message), - r"^" - r":10:20: error: expected sequence type or message type\n" - r':10:20: error: found integer type "__BUILTINS__::Base_Integer"' - r" \(0 \.\. 2\*\*63 - 1\)" - r"$", - ), - ( - Undefined(), - Integer, - r'^:10:20: error: undefined "A"$', - ), - ], -) -def test_check_type_instance_error( - actual: Type, - expected: type[Type] | tuple[type[Type], ...], - match: str, -) -> None: - with pytest.raises(RecordFluxError, match=match): - check_type_instance(actual, expected, Location((10, 20)), '"A"').propagate()