Skip to content

Commit

Permalink
feat!: Allow CustomConsts to (optionally) be hashable (#1397)
Browse files Browse the repository at this point in the history
* Add trait TryHash as prereq for CustomConst
* Automatically impl'd if your const impl's Hash
* Can also trivially implement (i.e. `impl TryHash for Foo { }`) to say
"no, not hashable"
* Derive Hash for most consts, but not ConstF64

BREAKING CHANGE: any `impl CustomConst` will need to either `impl Hash`
or `impl MaybeHash`
  • Loading branch information
acl-cqc authored Sep 11, 2024
1 parent 123321e commit 07b2f58
Show file tree
Hide file tree
Showing 6 changed files with 131 additions and 15 deletions.
8 changes: 4 additions & 4 deletions hugr-core/src/extension/prelude.rs
Original file line number Diff line number Diff line change
Expand Up @@ -149,7 +149,7 @@ pub const STRING_CUSTOM_TYPE: CustomType =
/// String type.
pub const STRING_TYPE: Type = Type::new_extension(STRING_CUSTOM_TYPE);

#[derive(Debug, Clone, PartialEq, serde::Serialize, serde::Deserialize)]
#[derive(Debug, Clone, PartialEq, Hash, serde::Serialize, serde::Deserialize)]
/// Structure for holding constant string values.
pub struct ConstString(String);

Expand Down Expand Up @@ -329,7 +329,7 @@ pub fn const_fail_tuple(
const_left_tuple(values, ty_ok)
}

#[derive(Debug, Clone, PartialEq, serde::Serialize, serde::Deserialize)]
#[derive(Debug, Clone, PartialEq, Hash, serde::Serialize, serde::Deserialize)]
/// Structure for holding constant usize values.
pub struct ConstUsize(u64);

Expand Down Expand Up @@ -364,7 +364,7 @@ impl CustomConst for ConstUsize {
}
}

#[derive(Debug, Clone, PartialEq, serde::Serialize, serde::Deserialize)]
#[derive(Debug, Clone, PartialEq, Hash, serde::Serialize, serde::Deserialize)]
/// Structure for holding constant usize values.
pub struct ConstError {
/// Integer tag/signal for the error.
Expand Down Expand Up @@ -409,7 +409,7 @@ impl CustomConst for ConstError {
}
}

#[derive(Debug, Clone, PartialEq, Eq, serde::Serialize, serde::Deserialize)]
#[derive(Debug, Clone, PartialEq, Eq, Hash, serde::Serialize, serde::Deserialize)]
/// A structure for holding references to external symbols.
pub struct ConstExternalSymbol {
/// The symbol name that this value refers to. Must be nonempty.
Expand Down
75 changes: 73 additions & 2 deletions hugr-core/src/ops/constant.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,9 @@
mod custom;

use std::collections::hash_map::DefaultHasher; // Moves into std::hash in Rust 1.76.
use std::hash::{Hash, Hasher};

use super::{NamedOp, OpName, OpTrait, StaticTag};
use super::{OpTag, OpType};
use crate::extension::ExtensionSet;
Expand All @@ -16,7 +19,7 @@ use thiserror::Error;

pub use custom::{
downcast_equal_consts, get_pair_of_input_values, get_single_input_value, CustomConst,
CustomSerialized,
CustomSerialized, TryHash,
};

#[derive(Debug, Clone, PartialEq, serde::Serialize, serde::Deserialize)]
Expand Down Expand Up @@ -134,6 +137,24 @@ impl Sum {
// For valid instances, the type row will not have any row variables.
self.sum_type.as_tuple().map(|_| self.values.as_ref())
}

fn try_hash<H: Hasher>(&self, st: &mut H) -> bool {
maybe_hash_values(&self.values, st) && {
st.write_usize(self.tag);
self.sum_type.hash(st);
true
}
}
}

pub(crate) fn maybe_hash_values<H: Hasher>(vals: &[Value], st: &mut H) -> bool {
// We can't mutate the Hasher with the first element
// if any element, even the last, fails.
let mut hasher = DefaultHasher::new();
vals.iter().all(|e| e.try_hash(&mut hasher)) && {
st.write_u64(hasher.finish());
true
}
}

impl TryFrom<SerialSum> for Sum {
Expand Down Expand Up @@ -508,6 +529,17 @@ impl Value {
None
}
}

/// Hashes this value, if possible. [Value::Extension]s are hashable according
/// to their implementation of [TryHash]; [Value::Function]s never are;
/// [Value::Sum]s are if their contents are.
pub fn try_hash<H: Hasher>(&self, st: &mut H) -> bool {
match self {
Value::Extension { e } => e.value().try_hash(&mut *st),
Value::Function { .. } => false,
Value::Sum(s) => s.try_hash(st),
}
}
}

impl<T> From<T> for Value
Expand All @@ -527,6 +559,8 @@ pub type ValueNameRef = str;

#[cfg(test)]
mod test {
use std::collections::HashSet;

use super::Value;
use crate::builder::inout_sig;
use crate::builder::test::simple_dfg_hugr;
Expand All @@ -547,7 +581,7 @@ mod test {

use super::*;

#[derive(Debug, Clone, PartialEq, serde::Serialize, serde::Deserialize)]
#[derive(Debug, Clone, PartialEq, Hash, serde::Serialize, serde::Deserialize)]
/// A custom constant value used in testing
pub(crate) struct CustomTestValue(pub CustomType);

Expand Down Expand Up @@ -727,6 +761,43 @@ mod test {
assert_ne!(json_const.get_type(), t);
}

#[rstest]
fn hash_tuple(const_tuple: Value) {
let vals = [
Value::unit(),
Value::true_val(),
Value::false_val(),
ConstUsize::new(13).into(),
Value::tuple([ConstUsize::new(13).into()]),
Value::tuple([ConstUsize::new(13).into(), ConstUsize::new(14).into()]),
Value::tuple([ConstUsize::new(13).into(), ConstUsize::new(15).into()]),
const_tuple,
];

let num_vals = vals.len();
let hashes = vals.map(|v| {
let mut h = DefaultHasher::new();
v.try_hash(&mut h).then_some(()).unwrap();
h.finish()
});
assert_eq!(HashSet::from(hashes).len(), num_vals); // all distinct
}

#[test]
fn unhashable_tuple() {
let tup = Value::tuple([ConstUsize::new(5).into(), ConstF64::new(4.97).into()]);
let mut h1 = DefaultHasher::new();
let r = tup.try_hash(&mut h1);
assert!(!r);

// Check that didn't do anything, by checking the hasher behaves
// just like one which never saw the tuple
h1.write_usize(5);
let mut h2 = DefaultHasher::new();
h2.write_usize(5);
assert_eq!(h1.finish(), h2.finish());
}

mod proptest {
use super::super::{OpaqueValue, Sum};
use crate::{
Expand Down
44 changes: 38 additions & 6 deletions hugr-core/src/ops/constant/custom.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,19 +5,17 @@
//! [`Const`]: crate::ops::Const
use std::any::Any;
use std::hash::{Hash, Hasher};

use downcast_rs::{impl_downcast, Downcast};
use thiserror::Error;

use crate::extension::ExtensionSet;
use crate::macros::impl_box_clone;

use crate::types::{CustomCheckFailure, Type};
use crate::IncomingPort;

use super::Value;

use super::ValueName;
use super::{Value, ValueName};

/// Extensible constant values.
///
Expand All @@ -37,7 +35,7 @@ use super::ValueName;
/// extension::ExtensionSet, std_extensions::arithmetic::int_types};
/// use serde_json::json;
///
/// #[derive(std::fmt::Debug, Clone, Serialize,Deserialize)]
/// #[derive(std::fmt::Debug, Clone, Hash, Serialize,Deserialize)]
/// struct CC(i64);
///
/// #[typetag::serde]
Expand All @@ -55,7 +53,7 @@ use super::ValueName;
/// ```
#[typetag::serde(tag = "c", content = "v")]
pub trait CustomConst:
Send + Sync + std::fmt::Debug + CustomConstBoxClone + Any + Downcast
Send + Sync + std::fmt::Debug + TryHash + CustomConstBoxClone + Any + Downcast
{
/// An identifier for the constant.
fn name(&self) -> ValueName;
Expand Down Expand Up @@ -90,6 +88,32 @@ pub trait CustomConst:
fn get_type(&self) -> Type;
}

/// Prerequisite for `CustomConst`. Allows to declare a custom hash function,
/// but the easiest options are either to `impl TryHash for ... {}` to indicate
/// "not hashable", or else to implement/derive [Hash].
pub trait TryHash {
/// Hashes the value, if possible; else return `false` without mutating the `Hasher`.
/// This relates with [CustomConst::equal_consts] just like [Hash] with [Eq]:
/// * if `x.equal_consts(y)` ==> `x.try_hash(s)` behaves equivalently to `y.try_hash(s)`
/// * if `x.hash(s)` behaves differently from `y.hash(s)` ==> `x.equal_consts(y) == false`
///
/// As with [Hash], these requirements can trivially be satisfied by either
/// * `equal_consts` always returning `false`, or
/// * `try_hash` always behaving the same (e.g. returning `false`, as it does by default)
///
/// Note: uses `dyn` rather than being parametrized by `<H: Hasher>` to be object-safe.
fn try_hash(&self, _state: &mut dyn Hasher) -> bool {
false
}
}

impl<T: Hash> TryHash for T {
fn try_hash(&self, mut st: &mut dyn Hasher) -> bool {
Hash::hash(self, &mut st);
true
}
}

impl PartialEq for dyn CustomConst {
fn eq(&self, other: &Self) -> bool {
(*self).equal_consts(other)
Expand Down Expand Up @@ -253,6 +277,14 @@ impl CustomSerialized {
}
}

impl TryHash for CustomSerialized {
fn try_hash(&self, mut st: &mut dyn Hasher) -> bool {
// Consistent with equality, same serialization <=> same hash.
self.value.to_string().hash(&mut st);
true
}
}

#[typetag::serde]
impl CustomConst for CustomSerialized {
fn name(&self) -> ValueName {
Expand Down
4 changes: 3 additions & 1 deletion hugr-core/src/std_extensions/arithmetic/float_types.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
//! Basic floating-point types
use crate::ops::constant::ValueName;
use crate::ops::constant::{TryHash, ValueName};
use crate::types::TypeName;
use crate::{
extension::{ExtensionId, ExtensionSet},
Expand Down Expand Up @@ -56,6 +56,8 @@ impl ConstF64 {
}
}

impl TryHash for ConstF64 {}

#[typetag::serde]
impl CustomConst for ConstF64 {
fn name(&self) -> ValueName {
Expand Down
2 changes: 1 addition & 1 deletion hugr-core/src/std_extensions/arithmetic/int_types.rs
Original file line number Diff line number Diff line change
Expand Up @@ -87,7 +87,7 @@ const fn type_arg(log_width: u8) -> TypeArg {
}

/// An integer (either signed or unsigned)
#[derive(Clone, Debug, Eq, PartialEq, serde::Serialize, serde::Deserialize)]
#[derive(Clone, Debug, Eq, PartialEq, Hash, serde::Serialize, serde::Deserialize)]
pub struct ConstInt {
log_width: u8,
// We always use a u64 for the value. The interpretation is:
Expand Down
13 changes: 12 additions & 1 deletion hugr-core/src/std_extensions/collections.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
//! List type and operations.
use std::hash::{Hash, Hasher};

mod list_fold;

use std::str::FromStr;
Expand All @@ -12,7 +14,7 @@ use strum_macros::{EnumIter, EnumString, IntoStaticStr};
use crate::extension::prelude::{either_type, option_type, USIZE_T};
use crate::extension::simple_op::{MakeOpDef, MakeRegisteredOp};
use crate::extension::{ExtensionBuildError, OpDef, SignatureFunc, PRELUDE};
use crate::ops::constant::ValueName;
use crate::ops::constant::{maybe_hash_values, TryHash, ValueName};
use crate::ops::{OpName, Value};
use crate::types::{TypeName, TypeRowRV};
use crate::{
Expand Down Expand Up @@ -58,6 +60,15 @@ impl ListValue {
}
}

impl TryHash for ListValue {
fn try_hash(&self, mut st: &mut dyn Hasher) -> bool {
maybe_hash_values(&self.0, &mut st) && {
self.1.hash(&mut st);
true
}
}
}

#[typetag::serde]
impl CustomConst for ListValue {
fn name(&self) -> ValueName {
Expand Down

0 comments on commit 07b2f58

Please sign in to comment.