From dc638c4be1ddf24aea0af3d534766ec12c9b07fe Mon Sep 17 00:00:00 2001 From: Mark Koch Date: Tue, 22 Apr 2025 09:09:25 +0100 Subject: [PATCH 1/5] feat: Make arrays linear and add copyable value arrays --- hugr-core/src/ops/constant.rs | 26 + hugr-core/src/std_extensions.rs | 1 + hugr-core/src/std_extensions/collections.rs | 1 + .../src/std_extensions/collections/array.rs | 254 +++----- .../collections/array/array_kind.rs | 93 +++ .../collections/array/array_op.rs | 198 +++--- .../collections/array/array_repeat.rs | 100 +-- .../collections/array/array_scan.rs | 115 ++-- .../collections/array/array_value.rs | 166 +++++ .../std_extensions/collections/value_array.rs | 133 ++++ hugr-core/src/utils.rs | 12 + hugr-llvm/src/emit/test.rs | 1 + hugr-llvm/src/extension/collections.rs | 2 +- ...lue_array__test__emit_all_ops@llvm14.snap} | 2 +- ...est__emit_all_ops@pre-mem2reg@llvm14.snap} | 2 +- ...array__test__emit_array_value@llvm14.snap} | 2 +- ..._emit_array_value@pre-mem2reg@llvm14.snap} | 2 +- ...__value_array__test__emit_get@llvm14.snap} | 2 +- ...y__test__emit_get@pre-mem2reg@llvm14.snap} | 2 +- .../collections/{array.rs => value_array.rs} | 169 ++--- hugr-llvm/src/utils/array_op_builder.rs | 50 +- hugr-passes/src/monomorphize.rs | 34 +- hugr-passes/src/replace_types.rs | 62 +- hugr-passes/src/replace_types/handlers.rs | 65 +- hugr-passes/src/replace_types/linearize.rs | 41 +- .../std/_json_defs/collections/array.json | 8 +- .../_json_defs/collections/value_array.json | 589 ++++++++++++++++++ hugr-py/src/hugr/std/collections/array.py | 2 +- .../src/hugr/std/collections/value_array.py | 83 +++ hugr-py/tests/test_tys.py | 23 + .../std_extensions/collections/array.json | 8 +- .../collections/value_array.json | 589 ++++++++++++++++++ uv.lock | 2 +- 33 files changed, 2300 insertions(+), 539 deletions(-) create mode 100644 hugr-core/src/std_extensions/collections/array/array_kind.rs create mode 100644 hugr-core/src/std_extensions/collections/array/array_value.rs create mode 100644 hugr-core/src/std_extensions/collections/value_array.rs rename hugr-llvm/src/extension/collections/snapshots/{hugr_llvm__extension__collections__array__test__emit_all_ops@llvm14.snap => hugr_llvm__extension__collections__value_array__test__emit_all_ops@llvm14.snap} (99%) rename hugr-llvm/src/extension/collections/snapshots/{hugr_llvm__extension__collections__array__test__emit_all_ops@pre-mem2reg@llvm14.snap => hugr_llvm__extension__collections__value_array__test__emit_all_ops@pre-mem2reg@llvm14.snap} (99%) rename hugr-llvm/src/extension/collections/snapshots/{hugr_llvm__extension__collections__array__test__emit_array_value@llvm14.snap => hugr_llvm__extension__collections__value_array__test__emit_array_value@llvm14.snap} (82%) rename hugr-llvm/src/extension/collections/snapshots/{hugr_llvm__extension__collections__array__test__emit_array_value@pre-mem2reg@llvm14.snap => hugr_llvm__extension__collections__value_array__test__emit_array_value@pre-mem2reg@llvm14.snap} (90%) rename hugr-llvm/src/extension/collections/snapshots/{hugr_llvm__extension__collections__array__test__emit_get@llvm14.snap => hugr_llvm__extension__collections__value_array__test__emit_get@llvm14.snap} (94%) rename hugr-llvm/src/extension/collections/snapshots/{hugr_llvm__extension__collections__array__test__emit_get@pre-mem2reg@llvm14.snap => hugr_llvm__extension__collections__value_array__test__emit_get@pre-mem2reg@llvm14.snap} (96%) rename hugr-llvm/src/extension/collections/{array.rs => value_array.rs} (91%) create mode 100644 hugr-py/src/hugr/std/_json_defs/collections/value_array.json create mode 100644 hugr-py/src/hugr/std/collections/value_array.py create mode 100644 specification/std_extensions/collections/value_array.json diff --git a/hugr-core/src/ops/constant.rs b/hugr-core/src/ops/constant.rs index 794e6eaaa..925cfd828 100644 --- a/hugr-core/src/ops/constant.rs +++ b/hugr-core/src/ops/constant.rs @@ -605,6 +605,7 @@ pub(crate) mod test { use crate::extension::PRELUDE; use crate::std_extensions::arithmetic::int_types::ConstInt; use crate::std_extensions::collections::array::{array_type, ArrayValue}; + use crate::std_extensions::collections::value_array::{value_array_type, VArrayValue}; use crate::{ builder::{BuildError, DFGBuilder, Dataflow, DataflowHugr}, extension::{ @@ -778,6 +779,11 @@ pub(crate) mod test { ArrayValue::new(bool_t(), [Value::true_val(), Value::false_val()]).into() } + #[fixture] + fn const_value_array_bool() -> Value { + VArrayValue::new(bool_t(), [Value::true_val(), Value::false_val()]).into() + } + #[fixture] fn const_array_options() -> Value { let some_true = Value::some([Value::true_val()]); @@ -786,17 +792,35 @@ pub(crate) mod test { ArrayValue::new(elem_ty.into(), [some_true, none]).into() } + #[fixture] + fn const_value_array_options() -> Value { + let some_true = Value::some([Value::true_val()]); + let none = Value::none(vec![bool_t()]); + let elem_ty = SumType::new_option(vec![bool_t()]); + VArrayValue::new(elem_ty.into(), [some_true, none]).into() + } + #[rstest] #[case(Value::unit(), Type::UNIT, "const:seq:{}")] #[case(const_usize(), usize_t(), "const:custom:ConstUsize(")] #[case(serialized_float(17.4), float64_type(), "const:custom:json:Object")] #[case(const_tuple(), Type::new_tuple(vec![usize_t(), bool_t()]), "const:seq:{")] #[case(const_array_bool(), array_type(2, bool_t()), "const:custom:array")] + #[case( + const_value_array_bool(), + value_array_type(2, bool_t()), + "const:custom:value_array" + )] #[case( const_array_options(), array_type(2, SumType::new_option(vec![bool_t()]).into()), "const:custom:array" )] + #[case( + const_value_array_options(), + value_array_type(2, SumType::new_option(vec![bool_t()]).into()), + "const:custom:value_array" + )] fn const_type( #[case] const_value: Value, #[case] expected_type: Type, @@ -816,7 +840,9 @@ pub(crate) mod test { #[case(const_serialized_usize(), const_usize())] #[case(const_tuple_serialized(), const_tuple())] #[case(const_array_bool(), const_array_bool())] + #[case(const_value_array_bool(), const_value_array_bool())] #[case(const_array_options(), const_array_options())] + #[case(const_value_array_options(), const_value_array_options())] // Opaque constants don't get resolved into concrete types when running miri, // as the `typetag` machinery is not available. #[cfg_attr(miri, ignore)] diff --git a/hugr-core/src/std_extensions.rs b/hugr-core/src/std_extensions.rs index 7892e8fec..cf582f8a1 100644 --- a/hugr-core/src/std_extensions.rs +++ b/hugr-core/src/std_extensions.rs @@ -21,6 +21,7 @@ pub fn std_reg() -> ExtensionRegistry { collections::array::EXTENSION.to_owned(), collections::list::EXTENSION.to_owned(), collections::static_array::EXTENSION.to_owned(), + collections::value_array::EXTENSION.to_owned(), logic::EXTENSION.to_owned(), ptr::EXTENSION.to_owned(), ]); diff --git a/hugr-core/src/std_extensions/collections.rs b/hugr-core/src/std_extensions/collections.rs index 13f5c007e..efd53c805 100644 --- a/hugr-core/src/std_extensions/collections.rs +++ b/hugr-core/src/std_extensions/collections.rs @@ -3,3 +3,4 @@ pub mod array; pub mod list; pub mod static_array; +pub mod value_array; diff --git a/hugr-core/src/std_extensions/collections/array.rs b/hugr-core/src/std_extensions/collections/array.rs index aa43b403e..6e8fae312 100644 --- a/hugr-core/src/std_extensions/collections/array.rs +++ b/hugr-core/src/std_extensions/collections/array.rs @@ -1,162 +1,76 @@ //! Fixed-length array type and operations extension. +mod array_kind; mod array_op; mod array_repeat; mod array_scan; +mod array_value; use std::sync::Arc; -use itertools::Itertools as _; +use delegate::delegate; use lazy_static::lazy_static; -use serde::{Deserialize, Serialize}; -use std::hash::{Hash, Hasher}; -use crate::extension::resolution::{ - resolve_type_extensions, resolve_value_extensions, ExtensionResolutionError, - WeakExtensionRegistry, -}; +use crate::extension::resolution::{ExtensionResolutionError, WeakExtensionRegistry}; use crate::extension::simple_op::{MakeOpDef, MakeRegisteredOp}; use crate::extension::{ExtensionId, ExtensionSet, SignatureError, TypeDef, TypeDefBound}; -use crate::ops::constant::{maybe_hash_values, CustomConst, TryHash, ValueName}; -use crate::ops::{ExtensionOp, OpName, Value}; +use crate::ops::constant::{CustomConst, ValueName}; +use crate::ops::{ExtensionOp, OpName}; use crate::types::type_param::{TypeArg, TypeParam}; -use crate::types::{CustomCheckFailure, CustomType, Type, TypeBound, TypeName}; +use crate::types::{CustomCheckFailure, Type, TypeBound, TypeName}; use crate::Extension; -pub use array_op::{ArrayOp, ArrayOpDef, ArrayOpDefIter}; -pub use array_repeat::{ArrayRepeat, ArrayRepeatDef, ARRAY_REPEAT_OP_ID}; -pub use array_scan::{ArrayScan, ArrayScanDef, ARRAY_SCAN_OP_ID}; +pub use array_kind::ArrayKind; +pub use array_op::{GenericArrayOp, GenericArrayOpDef}; +pub use array_repeat::{GenericArrayRepeat, GenericArrayRepeatDef, ARRAY_REPEAT_OP_ID}; +pub use array_scan::{GenericArrayScan, GenericArrayScanDef, ARRAY_SCAN_OP_ID}; +pub use array_value::GenericArrayValue; /// Reported unique name of the array type. pub const ARRAY_TYPENAME: TypeName = TypeName::new_inline("array"); +/// Reported unique name of the array value. +pub const ARRAY_VALUENAME: TypeName = TypeName::new_inline("array"); /// Reported unique name of the extension pub const EXTENSION_ID: ExtensionId = ExtensionId::new_unchecked("collections.array"); /// Extension version. pub const VERSION: semver::Version = semver::Version::new(0, 1, 0); -#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)] -/// Statically sized array of values, all of the same type. -pub struct ArrayValue { - values: Vec, - typ: Type, -} - -impl ArrayValue { - /// Name of the constructor for creating constant arrays. - #[cfg_attr(not(feature = "model_unstable"), allow(dead_code))] - pub(crate) const CTR_NAME: &'static str = "collections.array.const"; - - /// Create a new [CustomConst] for an array of values of type `typ`. - /// That all values are of type `typ` is not checked here. - pub fn new(typ: Type, contents: impl IntoIterator) -> Self { - Self { - values: contents.into_iter().collect_vec(), - typ, - } - } - - /// Create a new [CustomConst] for an empty array of values of type `typ`. - pub fn new_empty(typ: Type) -> Self { - Self { - values: vec![], - typ, - } - } - - /// Returns the type of the `[ArrayValue]` as a `[CustomType]`.` - pub fn custom_type(&self) -> CustomType { - array_custom_type(self.values.len() as u64, self.typ.clone()) - } +/// A linear, fixed-length collection of values. +/// +/// Arrays are linear, even if their elements are copyable. +#[derive(Clone, Copy, Debug, derive_more::Display, Eq, PartialEq, Default)] +pub struct Array; - /// Returns the type of values inside the `[ArrayValue]`. - pub fn get_element_type(&self) -> &Type { - &self.typ - } +impl ArrayKind for Array { + const EXTENSION_ID: ExtensionId = EXTENSION_ID; + const TYPE_NAME: TypeName = ARRAY_TYPENAME; + const VALUE_NAME: ValueName = ARRAY_VALUENAME; - /// Returns the values contained inside the `[ArrayValue]`. - pub fn get_contents(&self) -> &[Value] { - &self.values + fn extension() -> &'static Arc { + &EXTENSION } -} -impl TryHash for ArrayValue { - fn try_hash(&self, mut st: &mut dyn Hasher) -> bool { - maybe_hash_values(&self.values, &mut st) && { - self.typ.hash(&mut st); - true - } + fn type_def() -> &'static TypeDef { + EXTENSION.get_type(&ARRAY_TYPENAME).unwrap() } } -#[typetag::serde] -impl CustomConst for ArrayValue { - fn name(&self) -> ValueName { - ValueName::new_inline("array") - } - - fn get_type(&self) -> Type { - self.custom_type().into() - } +/// Array operation definitions. +pub type ArrayOpDef = GenericArrayOpDef; +/// Array repeat operation definition. +pub type ArrayRepeatDef = GenericArrayRepeatDef; +/// Array scan operation definition. +pub type ArrayScanDef = GenericArrayScanDef; - fn validate(&self) -> Result<(), CustomCheckFailure> { - let typ = self.custom_type(); - - EXTENSION - .get_type(&ARRAY_TYPENAME) - .unwrap() - .check_custom(&typ) - .map_err(|_| { - CustomCheckFailure::Message(format!( - "Custom typ {typ} is not a valid instantiation of array." - )) - })?; - - // constant can only hold classic type. - let ty = match typ.args() { - [TypeArg::BoundedNat { n }, TypeArg::Type { ty }] - if *n as usize == self.values.len() => - { - ty - } - _ => { - return Err(CustomCheckFailure::Message(format!( - "Invalid array type arguments: {:?}", - typ.args() - ))) - } - }; - - // check all values are instances of the element type - for v in &self.values { - if v.get_type() != *ty { - return Err(CustomCheckFailure::Message(format!( - "Array element {v:?} is not of expected type {ty}" - ))); - } - } - - Ok(()) - } - - fn equal_consts(&self, other: &dyn CustomConst) -> bool { - crate::ops::constant::downcast_equal_consts(self, other) - } - - fn extension_reqs(&self) -> ExtensionSet { - ExtensionSet::union_over(self.values.iter().map(Value::extension_reqs)) - .union(EXTENSION_ID.into()) - } +/// Array operations. +pub type ArrayOp = GenericArrayOp; +/// The array repeat operation. +pub type ArrayRepeat = GenericArrayRepeat; +/// The array scan operation. +pub type ArrayScan = GenericArrayScan; - fn update_extensions( - &mut self, - extensions: &WeakExtensionRegistry, - ) -> Result<(), ExtensionResolutionError> { - for val in &mut self.values { - resolve_value_extensions(val, extensions)?; - } - resolve_type_extensions(&mut self.typ, extensions) - } -} +/// An array extension value. +pub type ArrayValue = GenericArrayValue; lazy_static! { /// Extension for array operations. @@ -166,22 +80,49 @@ lazy_static! { ARRAY_TYPENAME, vec![ TypeParam::max_nat(), TypeBound::Any.into()], "Fixed-length array".into(), - TypeDefBound::from_params(vec![1] ), + // Default array is linear, even if the elements are copyable + TypeDefBound::any(), extension_ref, ) .unwrap(); - array_op::ArrayOpDef::load_all_ops(extension, extension_ref).unwrap(); - array_repeat::ArrayRepeatDef.add_to_extension(extension, extension_ref).unwrap(); - array_scan::ArrayScanDef.add_to_extension(extension, extension_ref).unwrap(); + ArrayOpDef::load_all_ops(extension, extension_ref).unwrap(); + ArrayRepeatDef::new().add_to_extension(extension, extension_ref).unwrap(); + ArrayScanDef::new().add_to_extension(extension, extension_ref).unwrap(); }) }; } +impl ArrayValue { + /// Name of the constructor for creating constant arrays. + #[cfg_attr(not(feature = "model_unstable"), allow(dead_code))] + pub(crate) const CTR_NAME: &'static str = "collections.array.const"; +} + +#[typetag::serde(name = "ArrayValue")] +impl CustomConst for ArrayValue { + delegate! { + to self { + fn name(&self) -> ValueName; + fn extension_reqs(&self) -> ExtensionSet; + fn validate(&self) -> Result<(), CustomCheckFailure>; + fn update_extensions( + &mut self, + extensions: &WeakExtensionRegistry, + ) -> Result<(), ExtensionResolutionError>; + fn get_type(&self) -> Type; + } + } + + fn equal_consts(&self, other: &dyn CustomConst) -> bool { + crate::ops::constant::downcast_equal_consts(self, other) + } +} + /// Gets the [TypeDef] for arrays. Note that instantiations are more easily /// created via [array_type] and [array_type_parametric] pub fn array_type_def() -> &'static TypeDef { - EXTENSION.get_type(&ARRAY_TYPENAME).unwrap() + Array::type_def() } /// Instantiate a new array type given a size argument and element type. @@ -189,7 +130,7 @@ pub fn array_type_def() -> &'static TypeDef { /// This method is equivalent to [`array_type_parametric`], but uses concrete /// arguments types to ensure no errors are possible. pub fn array_type(size: u64, element_ty: Type) -> Type { - array_custom_type(size, element_ty).into() + Array::ty(size, element_ty) } /// Instantiate a new array type given the size and element type parameters. @@ -199,28 +140,7 @@ pub fn array_type_parametric( size: impl Into, element_ty: impl Into, ) -> Result { - instantiate_array(array_type_def(), size, element_ty) -} - -fn array_custom_type(size: impl Into, element_ty: impl Into) -> CustomType { - instantiate_array_custom(array_type_def(), size, element_ty) - .expect("array parameters are valid") -} - -fn instantiate_array_custom( - array_def: &TypeDef, - size: impl Into, - element_ty: impl Into, -) -> Result { - array_def.instantiate(vec![size.into(), element_ty.into()]) -} - -fn instantiate_array( - array_def: &TypeDef, - size: impl Into, - element_ty: impl Into, -) -> Result { - instantiate_array_custom(array_def, size, element_ty).map(Into::into) + Array::ty_parametric(size, element_ty) } /// Name of the operation in the prelude for creating new arrays. @@ -228,18 +148,16 @@ pub const NEW_ARRAY_OP_ID: OpName = OpName::new_inline("new_array"); /// Initialize a new array op of element type `element_ty` of length `size` pub fn new_array_op(element_ty: Type, size: u64) -> ExtensionOp { - let op = array_op::ArrayOpDef::new_array.to_concrete(element_ty, size); + let op = ArrayOpDef::new_array.to_concrete(element_ty, size); op.to_extension_op().unwrap() } #[cfg(test)] mod test { use crate::builder::{inout_sig, DFGBuilder, Dataflow, DataflowHugr}; - use crate::extension::prelude::{qb_t, usize_t, ConstUsize}; - use crate::ops::constant::CustomConst; - use crate::std_extensions::arithmetic::float_types::ConstF64; + use crate::extension::prelude::qb_t; - use super::{array_type, new_array_op, ArrayValue}; + use super::{array_type, new_array_op}; #[test] /// Test building a HUGR involving a new_array operation. @@ -255,20 +173,4 @@ mod test { b.finish_hugr_with_outputs(out.outputs()).unwrap(); } - - #[test] - fn test_array_value() { - let array_value = ArrayValue { - values: vec![ConstUsize::new(3).into()], - typ: usize_t(), - }; - - array_value.validate().unwrap(); - - let wrong_array_value = ArrayValue { - values: vec![ConstF64::new(1.2).into()], - typ: usize_t(), - }; - assert!(wrong_array_value.validate().is_err()); - } } diff --git a/hugr-core/src/std_extensions/collections/array/array_kind.rs b/hugr-core/src/std_extensions/collections/array/array_kind.rs new file mode 100644 index 000000000..61a8cd3ae --- /dev/null +++ b/hugr-core/src/std_extensions/collections/array/array_kind.rs @@ -0,0 +1,93 @@ +use std::sync::Arc; + +use crate::{ + extension::{ExtensionId, SignatureError, TypeDef}, + ops::constant::ValueName, + types::{CustomType, Type, TypeArg, TypeName}, + Extension, +}; + +/// Trait capturing a concrete array implementation in an extension. +/// +/// Array operations are generically defined over this trait so the different +/// array extensions can share parts of their implementation. See for example +/// [`GenericArrayOpDef`] or [`GenericArrayValue`] +/// +/// Currently the available kinds of array are [`Array`] (the default one) and +/// [`ValueArray`]. +/// +/// [`GenericArrayOpDef`]: super::GenericArrayOpDef +/// [`GenericArrayValue`]: super::GenericArrayValue +/// [`Array`]: super::Array +/// [`ValueArray`]: crate::std_extensions::collections::value_array::ValueArray +pub trait ArrayKind: + Clone + + Copy + + std::fmt::Debug + + std::fmt::Display + + Eq + + PartialEq + + Default + + Send + + Sync + + 'static +{ + /// Identifier of the extension containing the array. + const EXTENSION_ID: ExtensionId; + + /// Name of the array type. + const TYPE_NAME: TypeName; + + /// Name of the array value. + const VALUE_NAME: ValueName; + + /// Returns the extension containing the array. + fn extension() -> &'static Arc; + + /// Returns the definition for the array type. + fn type_def() -> &'static TypeDef; + + /// Instantiates an array [CustomType] from its definition given a size and + /// element type argument. + fn instantiate_custom_ty( + array_def: &TypeDef, + size: impl Into, + element_ty: impl Into, + ) -> Result { + array_def.instantiate(vec![size.into(), element_ty.into()]) + } + + /// Instantiates an array type from its definition given a size and element + /// type argument. + fn instantiate_ty( + array_def: &TypeDef, + size: impl Into, + element_ty: impl Into, + ) -> Result { + Self::instantiate_custom_ty(array_def, size, element_ty).map(Into::into) + } + + /// Instantiates an array [CustomType] given a size and element type argument. + fn custom_ty(size: impl Into, element_ty: impl Into) -> CustomType { + Self::instantiate_custom_ty(Self::type_def(), size, element_ty) + .expect("array parameters are valid") + } + + /// Instantiate a new array type given a size argument and element type. + /// + /// This method is equivalent to [`ArrayKind::ty_parametric`], but uses concrete + /// arguments types to ensure no errors are possible. + fn ty(size: u64, element_ty: Type) -> Type { + Self::custom_ty(size, element_ty).into() + } + + /// Instantiate a new array type given the size and element type parameters. + /// + /// This is a generic version of [`ArrayKind::ty`]. + fn ty_parametric( + size: impl Into, + element_ty: impl Into, + ) -> Result { + Self::instantiate_ty(Self::type_def(), size, element_ty) + } +} diff --git a/hugr-core/src/std_extensions/collections/array/array_op.rs b/hugr-core/src/std_extensions/collections/array/array_op.rs index 197536032..188656807 100644 --- a/hugr-core/src/std_extensions/collections/array/array_op.rs +++ b/hugr-core/src/std_extensions/collections/array/array_op.rs @@ -1,5 +1,6 @@ //! Definitions of `ArrayOp` and `ArrayOpDef`. +use std::marker::PhantomData; use std::sync::{Arc, Weak}; use strum::{EnumIter, EnumString, IntoStaticStr}; @@ -12,19 +13,19 @@ use crate::extension::{ ExtensionId, OpDef, SignatureError, SignatureFromArgs, SignatureFunc, TypeDef, }; use crate::ops::{ExtensionOp, NamedOp, OpName}; -use crate::std_extensions::collections::array::instantiate_array; use crate::type_row; use crate::types::type_param::{TypeArg, TypeParam}; use crate::types::{FuncValueType, PolyFuncTypeRV, Type, TypeBound}; +use crate::utils::Never; use crate::Extension; -use super::{array_type, array_type_def, ARRAY_TYPENAME}; +use super::array_kind::ArrayKind; -/// Array operation definitions. -#[derive(Clone, Copy, Debug, Hash, PartialEq, Eq, EnumIter, IntoStaticStr, EnumString)] +/// Array operation definitions. Generic over the conrete array implementation. +#[derive(Clone, Copy, Debug, Hash, PartialEq, Eq, IntoStaticStr, EnumIter, EnumString)] #[allow(non_camel_case_types)] #[non_exhaustive] -pub enum ArrayOpDef { +pub enum GenericArrayOpDef { /// Makes a new array, given distinct inputs equal to its length: /// `new_array: (elemty)^SIZE -> array` /// where `SIZE` must be statically known (not a variable) @@ -53,26 +54,30 @@ pub enum ArrayOpDef { /// Allows discarding a 0-element array of linear type. /// `discard_empty: array<0, elemty> -> ` (no outputs) discard_empty, + /// Not an actual operation definition, but an unhabitable variant that + /// references `AK` to ensure that the type parameter is used. + #[strum(disabled)] + _phantom(PhantomData, Never), } /// Static parameters for array operations. Includes array size. Type is part of the type scheme. const STATIC_SIZE_PARAM: &[TypeParam; 1] = &[TypeParam::max_nat()]; -impl SignatureFromArgs for ArrayOpDef { +impl SignatureFromArgs for GenericArrayOpDef { fn compute_signature(&self, arg_values: &[TypeArg]) -> Result { let [TypeArg::BoundedNat { n }] = *arg_values else { return Err(SignatureError::InvalidTypeArgs); }; let elem_ty_var = Type::new_var_use(0, TypeBound::Any); - let array_ty = array_type(n, elem_ty_var.clone()); + let array_ty = AK::ty(n, elem_ty_var.clone()); let params = vec![TypeBound::Any.into()]; let poly_func_ty = match self { - ArrayOpDef::new_array => PolyFuncTypeRV::new( + GenericArrayOpDef::new_array => PolyFuncTypeRV::new( params, FuncValueType::new(vec![elem_ty_var.clone(); n as usize], array_ty), ), - ArrayOpDef::pop_left | ArrayOpDef::pop_right => { - let popped_array_ty = array_type(n - 1, elem_ty_var.clone()); + GenericArrayOpDef::pop_left | GenericArrayOpDef::pop_right => { + let popped_array_ty = AK::ty(n - 1, elem_ty_var.clone()); PolyFuncTypeRV::new( params, FuncValueType::new( @@ -81,6 +86,7 @@ impl SignatureFromArgs for ArrayOpDef { ), ) } + GenericArrayOpDef::_phantom(_, never) => match *never {}, _ => unreachable!( "Operation {} should not need custom computation.", self.name() @@ -94,16 +100,16 @@ impl SignatureFromArgs for ArrayOpDef { } } -impl ArrayOpDef { +impl GenericArrayOpDef { /// Instantiate a new array operation with the given element type and array size. - pub fn to_concrete(self, elem_ty: Type, size: u64) -> ArrayOp { - if self == ArrayOpDef::discard_empty { + pub fn to_concrete(self, elem_ty: Type, size: u64) -> GenericArrayOp { + if self == GenericArrayOpDef::discard_empty { debug_assert_eq!( size, 0, "discard_empty should only be called on empty arrays" ); } - ArrayOp { + GenericArrayOp { def: self, elem_ty, size, @@ -116,7 +122,7 @@ impl ArrayOpDef { array_def: &TypeDef, _extension_ref: &Weak, ) -> SignatureFunc { - use ArrayOpDef::*; + use GenericArrayOpDef::*; if let new_array | pop_left | pop_right = self { // implements SignatureFromArgs // signature computed dynamically, so can rely on type definition in extension. @@ -124,7 +130,7 @@ impl ArrayOpDef { } else { let size_var = TypeArg::new_var_use(0, TypeParam::max_nat()); let elem_ty_var = Type::new_var_use(1, TypeBound::Any); - let array_ty = instantiate_array(array_def, size_var.clone(), elem_ty_var.clone()) + let array_ty = AK::instantiate_ty(array_def, size_var.clone(), elem_ty_var.clone()) .expect("Array type instantiation failed"); let standard_params = vec![TypeParam::max_nat(), TypeBound::Any.into()]; @@ -137,7 +143,7 @@ impl ArrayOpDef { let params = vec![TypeParam::max_nat(), TypeBound::Copyable.into()]; let copy_elem_ty = Type::new_var_use(1, TypeBound::Copyable); let copy_array_ty = - instantiate_array(array_def, size_var, copy_elem_ty.clone()) + AK::instantiate_ty(array_def, size_var, copy_elem_ty.clone()) .expect("Array type instantiation failed"); let option_type: Type = option_type(copy_elem_ty).into(); PolyFuncTypeRV::new( @@ -166,11 +172,12 @@ impl ArrayOpDef { discard_empty => PolyFuncTypeRV::new( vec![TypeBound::Any.into()], FuncValueType::new( - instantiate_array(array_def, 0, Type::new_var_use(0, TypeBound::Any)) + AK::instantiate_ty(array_def, 0, Type::new_var_use(0, TypeBound::Any)) .expect("Array type instantiation failed"), type_row![], ), ), + _phantom(_, never) => match *never {}, new_array | pop_left | pop_right => unreachable!(), } .into() @@ -178,7 +185,7 @@ impl ArrayOpDef { } } -impl MakeOpDef for ArrayOpDef { +impl MakeOpDef for GenericArrayOpDef { fn from_def(op_def: &OpDef) -> Result where Self: Sized, @@ -187,26 +194,27 @@ impl MakeOpDef for ArrayOpDef { } fn init_signature(&self, extension_ref: &Weak) -> SignatureFunc { - self.signature_from_def(array_type_def(), extension_ref) + self.signature_from_def(AK::type_def(), extension_ref) } fn extension_ref(&self) -> Weak { - Arc::downgrade(&super::EXTENSION) + Arc::downgrade(AK::extension()) } fn extension(&self) -> ExtensionId { - super::EXTENSION_ID + AK::EXTENSION_ID } fn description(&self) -> String { match self { - ArrayOpDef::new_array => "Create a new array from elements", - ArrayOpDef::get => "Get an element from an array", - ArrayOpDef::set => "Set an element in an array", - ArrayOpDef::swap => "Swap two elements in an array", - ArrayOpDef::pop_left => "Pop an element from the left of an array", - ArrayOpDef::pop_right => "Pop an element from the right of an array", - ArrayOpDef::discard_empty => "Discard an empty array", + GenericArrayOpDef::new_array => "Create a new array from elements", + GenericArrayOpDef::get => "Get an element from an array", + GenericArrayOpDef::set => "Set an element in an array", + GenericArrayOpDef::swap => "Swap two elements in an array", + GenericArrayOpDef::pop_left => "Pop an element from the left of an array", + GenericArrayOpDef::pop_right => "Pop an element from the right of an array", + GenericArrayOpDef::discard_empty => "Discard an empty array", + GenericArrayOpDef::_phantom(_, never) => match *never {}, } .into() } @@ -222,7 +230,7 @@ impl MakeOpDef for ArrayOpDef { extension_ref: &Weak, ) -> Result<(), crate::extension::ExtensionBuildError> { let sig = - self.signature_from_def(extension.get_type(&ARRAY_TYPENAME).unwrap(), extension_ref); + self.signature_from_def(extension.get_type(&AK::TYPE_NAME).unwrap(), extension_ref); let def = extension.add_op(self.name(), self.description(), sig, extension_ref)?; self.post_opdef(def); @@ -232,33 +240,33 @@ impl MakeOpDef for ArrayOpDef { } #[derive(Clone, Debug, PartialEq)] -/// Concrete array operation. -pub struct ArrayOp { +/// Concrete array operation. Generic over the actual array implemenation. +pub struct GenericArrayOp { /// The operation definition. - pub def: ArrayOpDef, + pub def: GenericArrayOpDef, /// The element type of the array. pub elem_ty: Type, /// The size of the array. pub size: u64, } -impl NamedOp for ArrayOp { +impl NamedOp for GenericArrayOp { fn name(&self) -> OpName { self.def.name() } } -impl MakeExtensionOp for ArrayOp { +impl MakeExtensionOp for GenericArrayOp { fn from_extension_op(ext_op: &ExtensionOp) -> Result where Self: Sized, { - let def = ArrayOpDef::from_def(ext_op.def())?; + let def = GenericArrayOpDef::from_def(ext_op.def())?; def.instantiate(ext_op.args()) } fn type_args(&self) -> Vec { - use ArrayOpDef::*; + use GenericArrayOpDef::*; let ty_arg = TypeArg::Type { ty: self.elem_ty.clone(), }; @@ -273,30 +281,31 @@ impl MakeExtensionOp for ArrayOp { new_array | pop_left | pop_right | get | set | swap => { vec![TypeArg::BoundedNat { n: self.size }, ty_arg] } + _phantom(_, never) => match never {}, } } } -impl MakeRegisteredOp for ArrayOp { +impl MakeRegisteredOp for GenericArrayOp { fn extension_id(&self) -> ExtensionId { - super::EXTENSION_ID + AK::EXTENSION_ID } fn extension_ref(&self) -> Weak { - Arc::downgrade(&super::EXTENSION) + Arc::downgrade(AK::extension()) } } -impl HasDef for ArrayOp { - type Def = ArrayOpDef; +impl HasDef for GenericArrayOp { + type Def = GenericArrayOpDef; } -impl HasConcrete for ArrayOpDef { - type Concrete = ArrayOp; +impl HasConcrete for GenericArrayOpDef { + type Concrete = GenericArrayOp; fn instantiate(&self, type_args: &[TypeArg]) -> Result { let (ty, size) = match (self, type_args) { - (ArrayOpDef::discard_empty, [TypeArg::Type { ty }]) => (ty.clone(), 0), + (GenericArrayOpDef::discard_empty, [TypeArg::Type { ty }]) => (ty.clone(), 0), (_, [TypeArg::BoundedNat { n }, TypeArg::Type { ty }]) => (ty.clone(), *n), _ => return Err(SignatureError::InvalidTypeArgs.into()), }; @@ -307,11 +316,13 @@ impl HasConcrete for ArrayOpDef { #[cfg(test)] mod tests { + use rstest::rstest; use strum::IntoEnumIterator; use crate::extension::prelude::usize_t; use crate::std_extensions::arithmetic::float_types::float64_type; - use crate::std_extensions::collections::array::new_array_op; + use crate::std_extensions::collections::array::Array; + use crate::std_extensions::collections::value_array::ValueArray; use crate::{ builder::{inout_sig, DFGBuilder, Dataflow, DataflowHugr}, extension::prelude::{bool_t, qb_t}, @@ -320,46 +331,51 @@ mod tests { use super::*; - #[test] - fn test_array_ops() { - for def in ArrayOpDef::iter() { - let ty = if def == ArrayOpDef::get { + #[rstest] + #[case(Array)] + #[case(ValueArray)] + fn test_array_ops(#[case] _kind: AK) { + for def in GenericArrayOpDef::::iter() { + let ty = if def == GenericArrayOpDef::get { bool_t() } else { qb_t() }; - let size = if def == ArrayOpDef::discard_empty { + let size = if def == GenericArrayOpDef::discard_empty { 0 } else { 2 }; let op = def.to_concrete(ty, size); let optype: OpType = op.clone().into(); - let new_op: ArrayOp = optype.cast().unwrap(); + let new_op: GenericArrayOp = optype.cast().unwrap(); assert_eq!(new_op, op); } } - #[test] + #[rstest] + #[case(Array)] + #[case(ValueArray)] /// Test building a HUGR involving a new_array operation. - fn test_new_array() { - let mut b = - DFGBuilder::new(inout_sig(vec![qb_t(), qb_t()], array_type(2, qb_t()))).unwrap(); + fn test_new_array(#[case] _kind: AK) { + let mut b = DFGBuilder::new(inout_sig(vec![qb_t(), qb_t()], AK::ty(2, qb_t()))).unwrap(); let [q1, q2] = b.input_wires_arr(); - let op = new_array_op(qb_t(), 2); + let op = GenericArrayOpDef::::new_array.to_concrete(qb_t(), 2); let out = b.add_dataflow_op(op, [q1, q2]).unwrap(); b.finish_hugr_with_outputs(out.outputs()).unwrap(); } - #[test] - fn test_get() { + #[rstest] + #[case(Array)] + #[case(ValueArray)] + fn test_get(#[case] _kind: AK) { let size = 2; let element_ty = bool_t(); - let op = ArrayOpDef::get.to_concrete(element_ty.clone(), size); + let op = GenericArrayOpDef::::get.to_concrete(element_ty.clone(), size); let optype: OpType = op.into(); @@ -368,22 +384,24 @@ mod tests { assert_eq!( sig.io(), ( - &vec![array_type(size, element_ty.clone()), usize_t()].into(), + &vec![AK::ty(size, element_ty.clone()), usize_t()].into(), &vec![option_type(element_ty.clone()).into()].into() ) ); } - #[test] - fn test_set() { + #[rstest] + #[case(Array)] + #[case(ValueArray)] + fn test_set(#[case] _kind: AK) { let size = 2; let element_ty = bool_t(); - let op = ArrayOpDef::set.to_concrete(element_ty.clone(), size); + let op = GenericArrayOpDef::::set.to_concrete(element_ty.clone(), size); let optype: OpType = op.into(); let sig = optype.dataflow_signature().unwrap(); - let array_ty = array_type(size, element_ty.clone()); + let array_ty = AK::ty(size, element_ty.clone()); let result_row = vec![element_ty.clone(), array_ty.clone()]; assert_eq!( sig.io(), @@ -394,16 +412,18 @@ mod tests { ); } - #[test] - fn test_swap() { + #[rstest] + #[case(Array)] + #[case(ValueArray)] + fn test_swap(#[case] _kind: AK) { let size = 2; let element_ty = bool_t(); - let op = ArrayOpDef::swap.to_concrete(element_ty.clone(), size); + let op = GenericArrayOpDef::::swap.to_concrete(element_ty.clone(), size); let optype: OpType = op.into(); let sig = optype.dataflow_signature().unwrap(); - let array_ty = array_type(size, element_ty.clone()); + let array_ty = AK::ty(size, element_ty.clone()); assert_eq!( sig.io(), ( @@ -413,11 +433,18 @@ mod tests { ); } - #[test] - fn test_pops() { + #[rstest] + #[case(Array)] + #[case(ValueArray)] + fn test_pops(#[case] _kind: AK) { let size = 2; let element_ty = bool_t(); - for op in [ArrayOpDef::pop_left, ArrayOpDef::pop_right].iter() { + for op in [ + GenericArrayOpDef::::pop_left, + GenericArrayOpDef::::pop_right, + ] + .iter() + { let op = op.to_concrete(element_ty.clone(), size); let optype: OpType = op.into(); @@ -426,10 +453,10 @@ mod tests { assert_eq!( sig.io(), ( - &vec![array_type(size, element_ty.clone())].into(), + &vec![AK::ty(size, element_ty.clone())].into(), &vec![option_type(vec![ element_ty.clone(), - array_type(size - 1, element_ty.clone()) + AK::ty(size - 1, element_ty.clone()) ]) .into()] .into() @@ -438,11 +465,13 @@ mod tests { } } - #[test] - fn test_discard_empty() { + #[rstest] + #[case(Array)] + #[case(ValueArray)] + fn test_discard_empty(#[case] _kind: AK) { let size = 0; let element_ty = bool_t(); - let op = ArrayOpDef::discard_empty.to_concrete(element_ty.clone(), size); + let op = GenericArrayOpDef::::discard_empty.to_concrete(element_ty.clone(), size); let optype: OpType = op.into(); @@ -450,19 +479,18 @@ mod tests { assert_eq!( sig.io(), - ( - &vec![array_type(size, element_ty.clone())].into(), - &type_row![] - ) + (&vec![AK::ty(size, element_ty.clone())].into(), &type_row![]) ); } - #[test] + #[rstest] + #[case(Array)] + #[case(ValueArray)] /// Initialize an array operation where the element type is not from the prelude. - fn test_non_prelude_op() { + fn test_non_prelude_op(#[case] _kind: AK) { let size = 2; let element_ty = float64_type(); - let op = ArrayOpDef::get.to_concrete(element_ty.clone(), size); + let op = GenericArrayOpDef::::get.to_concrete(element_ty.clone(), size); let optype: OpType = op.into(); @@ -471,7 +499,7 @@ mod tests { assert_eq!( sig.io(), ( - &vec![array_type(size, element_ty.clone()), usize_t()].into(), + &vec![AK::ty(size, element_ty.clone()), usize_t()].into(), &vec![option_type(element_ty.clone()).into()].into() ) ); diff --git a/hugr-core/src/std_extensions/collections/array/array_repeat.rs b/hugr-core/src/std_extensions/collections/array/array_repeat.rs index 544866970..5ee927aeb 100644 --- a/hugr-core/src/std_extensions/collections/array/array_repeat.rs +++ b/hugr-core/src/std_extensions/collections/array/array_repeat.rs @@ -1,5 +1,6 @@ //! Definition of the array repeat operation. +use std::marker::PhantomData; use std::str::FromStr; use std::sync::{Arc, Weak}; @@ -12,34 +13,47 @@ use crate::types::type_param::{TypeArg, TypeParam}; use crate::types::{FuncValueType, PolyFuncTypeRV, Signature, Type, TypeBound}; use crate::Extension; -use super::{array_type_def, instantiate_array, ARRAY_TYPENAME}; +use super::array_kind::ArrayKind; /// Name of the operation to repeat a value multiple times pub const ARRAY_REPEAT_OP_ID: OpName = OpName::new_inline("repeat"); -/// Definition of the array repeat op. +/// Definition of the array repeat op. Generic over the concrete array implementation. #[derive(Clone, Copy, Debug, Hash, PartialEq, Eq)] -pub struct ArrayRepeatDef; +pub struct GenericArrayRepeatDef(PhantomData); -impl NamedOp for ArrayRepeatDef { +impl GenericArrayRepeatDef { + /// Creates a new array repeat operation definition. + pub fn new() -> Self { + GenericArrayRepeatDef(PhantomData) + } +} + +impl Default for GenericArrayRepeatDef { + fn default() -> Self { + Self::new() + } +} + +impl NamedOp for GenericArrayRepeatDef { fn name(&self) -> OpName { ARRAY_REPEAT_OP_ID } } -impl FromStr for ArrayRepeatDef { +impl FromStr for GenericArrayRepeatDef { type Err = (); fn from_str(s: &str) -> Result { - if s == ArrayRepeatDef.name() { - Ok(Self) + if s == ARRAY_REPEAT_OP_ID { + Ok(GenericArrayRepeatDef::new()) } else { Err(()) } } } -impl ArrayRepeatDef { +impl GenericArrayRepeatDef { /// To avoid recursion when defining the extension, take the type definition as an argument. fn signature_from_def(&self, array_def: &TypeDef) -> SignatureFunc { let params = vec![ @@ -52,12 +66,13 @@ impl ArrayRepeatDef { let es = ExtensionSet::type_var(2); let func = Type::new_function(Signature::new(vec![], vec![t.clone()]).with_extension_delta(es)); - let array_ty = instantiate_array(array_def, n, t).expect("Array type instantiation failed"); + let array_ty = + AK::instantiate_ty(array_def, n, t).expect("Array type instantiation failed"); PolyFuncTypeRV::new(params, FuncValueType::new(vec![func], array_ty)).into() } } -impl MakeOpDef for ArrayRepeatDef { +impl MakeOpDef for GenericArrayRepeatDef { fn from_def(op_def: &OpDef) -> Result where Self: Sized, @@ -66,15 +81,15 @@ impl MakeOpDef for ArrayRepeatDef { } fn init_signature(&self, _extension_ref: &Weak) -> SignatureFunc { - self.signature_from_def(array_type_def()) + self.signature_from_def(AK::type_def()) } fn extension_ref(&self) -> Weak { - Arc::downgrade(&super::EXTENSION) + Arc::downgrade(AK::extension()) } fn extension(&self) -> ExtensionId { - super::EXTENSION_ID + AK::EXTENSION_ID } fn description(&self) -> String { @@ -93,7 +108,7 @@ impl MakeOpDef for ArrayRepeatDef { extension: &mut Extension, extension_ref: &Weak, ) -> Result<(), crate::extension::ExtensionBuildError> { - let sig = self.signature_from_def(extension.get_type(&ARRAY_TYPENAME).unwrap()); + let sig = self.signature_from_def(extension.get_type(&AK::TYPE_NAME).unwrap()); let def = extension.add_op(self.name(), self.description(), sig, extension_ref)?; self.post_opdef(def); @@ -102,40 +117,42 @@ impl MakeOpDef for ArrayRepeatDef { } } -/// Definition of the array repeat op. +/// Definition of the array repeat op. Generic over the concrete array implementation. #[derive(Clone, Debug, PartialEq)] -pub struct ArrayRepeat { +pub struct GenericArrayRepeat { /// The element type of the resulting array. pub elem_ty: Type, /// Size of the array. pub size: u64, /// The extensions required by the function that generates the array elements. pub extension_reqs: ExtensionSet, + _kind: PhantomData, } -impl ArrayRepeat { +impl GenericArrayRepeat { /// Creates a new array repeat op. pub fn new(elem_ty: Type, size: u64, extension_reqs: ExtensionSet) -> Self { - ArrayRepeat { + GenericArrayRepeat { elem_ty, size, extension_reqs, + _kind: PhantomData, } } } -impl NamedOp for ArrayRepeat { +impl NamedOp for GenericArrayRepeat { fn name(&self) -> OpName { ARRAY_REPEAT_OP_ID } } -impl MakeExtensionOp for ArrayRepeat { +impl MakeExtensionOp for GenericArrayRepeat { fn from_extension_op(ext_op: &ExtensionOp) -> Result where Self: Sized, { - let def = ArrayRepeatDef::from_def(ext_op.def())?; + let def = GenericArrayRepeatDef::::from_def(ext_op.def())?; def.instantiate(ext_op.args()) } @@ -150,27 +167,27 @@ impl MakeExtensionOp for ArrayRepeat { } } -impl MakeRegisteredOp for ArrayRepeat { +impl MakeRegisteredOp for GenericArrayRepeat { fn extension_id(&self) -> ExtensionId { - super::EXTENSION_ID + AK::EXTENSION_ID } fn extension_ref(&self) -> Weak { - Arc::downgrade(&super::EXTENSION) + Arc::downgrade(AK::extension()) } } -impl HasDef for ArrayRepeat { - type Def = ArrayRepeatDef; +impl HasDef for GenericArrayRepeat { + type Def = GenericArrayRepeatDef; } -impl HasConcrete for ArrayRepeatDef { - type Concrete = ArrayRepeat; +impl HasConcrete for GenericArrayRepeatDef { + type Concrete = GenericArrayRepeat; fn instantiate(&self, type_args: &[TypeArg]) -> Result { match type_args { [TypeArg::BoundedNat { n }, TypeArg::Type { ty }, TypeArg::Extensions { es }] => { - Ok(ArrayRepeat::new(ty.clone(), *n, es.clone())) + Ok(GenericArrayRepeat::new(ty.clone(), *n, es.clone())) } _ => Err(SignatureError::InvalidTypeArgs.into()), } @@ -179,7 +196,10 @@ impl HasConcrete for ArrayRepeatDef { #[cfg(test)] mod tests { - use crate::std_extensions::collections::array::{array_type, EXTENSION_ID}; + use rstest::rstest; + + use crate::std_extensions::collections::array::{Array, EXTENSION_ID}; + use crate::std_extensions::collections::value_array::ValueArray; use crate::{ extension::prelude::qb_t, ops::{OpTrait, OpType}, @@ -188,20 +208,24 @@ mod tests { use super::*; - #[test] - fn test_repeat_def() { - let op = ArrayRepeat::new(qb_t(), 2, ExtensionSet::singleton(EXTENSION_ID)); + #[rstest] + #[case(Array)] + #[case(ValueArray)] + fn test_repeat_def(#[case] _kind: AK) { + let op = GenericArrayRepeat::::new(qb_t(), 2, ExtensionSet::singleton(EXTENSION_ID)); let optype: OpType = op.clone().into(); - let new_op: ArrayRepeat = optype.cast().unwrap(); + let new_op: GenericArrayRepeat = optype.cast().unwrap(); assert_eq!(new_op, op); } - #[test] - fn test_repeat() { + #[rstest] + #[case(Array)] + #[case(ValueArray)] + fn test_repeat(#[case] _kind: AK) { let size = 2; let element_ty = qb_t(); let es = ExtensionSet::singleton(EXTENSION_ID); - let op = ArrayRepeat::new(element_ty.clone(), size, es.clone()); + let op = GenericArrayRepeat::::new(element_ty.clone(), size, es.clone()); let optype: OpType = op.into(); @@ -214,7 +238,7 @@ mod tests { Signature::new(vec![], vec![qb_t()]).with_extension_delta(es) )] .into(), - &vec![array_type(size, element_ty.clone())].into(), + &vec![AK::ty(size, element_ty.clone())].into(), ) ); } diff --git a/hugr-core/src/std_extensions/collections/array/array_scan.rs b/hugr-core/src/std_extensions/collections/array/array_scan.rs index 86a0fe94e..8352de3f7 100644 --- a/hugr-core/src/std_extensions/collections/array/array_scan.rs +++ b/hugr-core/src/std_extensions/collections/array/array_scan.rs @@ -1,5 +1,6 @@ //! Array scanning operation +use std::marker::PhantomData; use std::str::FromStr; use std::sync::{Arc, Weak}; @@ -14,34 +15,47 @@ use crate::types::type_param::{TypeArg, TypeParam}; use crate::types::{FuncTypeBase, PolyFuncTypeRV, RowVariable, Type, TypeBound, TypeRV}; use crate::Extension; -use super::{array_type_def, instantiate_array, ARRAY_TYPENAME}; +use super::array_kind::ArrayKind; /// Name of the operation for the combined map/fold operation pub const ARRAY_SCAN_OP_ID: OpName = OpName::new_inline("scan"); -/// Definition of the array scan op. +/// Definition of the array scan op. Generic over the concrete array implementation. #[derive(Clone, Copy, Debug, Hash, PartialEq, Eq)] -pub struct ArrayScanDef; +pub struct GenericArrayScanDef(PhantomData); -impl NamedOp for ArrayScanDef { +impl GenericArrayScanDef { + /// Creates a new array scan operation definition. + pub fn new() -> Self { + GenericArrayScanDef(PhantomData) + } +} + +impl Default for GenericArrayScanDef { + fn default() -> Self { + Self::new() + } +} + +impl NamedOp for GenericArrayScanDef { fn name(&self) -> OpName { ARRAY_SCAN_OP_ID } } -impl FromStr for ArrayScanDef { +impl FromStr for GenericArrayScanDef { type Err = (); fn from_str(s: &str) -> Result { - if s == ArrayScanDef.name() { - Ok(Self) + if s == ARRAY_SCAN_OP_ID { + Ok(Self::new()) } else { Err(()) } } } -impl ArrayScanDef { +impl GenericArrayScanDef { /// To avoid recursion when defining the extension, take the type definition /// and a reference to the extension as an argument. fn signature_from_def(&self, array_def: &TypeDef) -> SignatureFunc { @@ -62,7 +76,7 @@ impl ArrayScanDef { params, FuncTypeBase::::new( vec![ - instantiate_array(array_def, n.clone(), t1.clone()) + AK::instantiate_ty(array_def, n.clone(), t1.clone()) .expect("Array type instantiation failed") .into(), Type::new_function( @@ -76,7 +90,7 @@ impl ArrayScanDef { s.clone(), ], vec![ - instantiate_array(array_def, n, t2) + AK::instantiate_ty(array_def, n, t2) .expect("Array type instantiation failed") .into(), s, @@ -87,7 +101,7 @@ impl ArrayScanDef { } } -impl MakeOpDef for ArrayScanDef { +impl MakeOpDef for GenericArrayScanDef { fn from_def(op_def: &OpDef) -> Result where Self: Sized, @@ -96,15 +110,15 @@ impl MakeOpDef for ArrayScanDef { } fn init_signature(&self, _extension_ref: &Weak) -> SignatureFunc { - self.signature_from_def(array_type_def()) + self.signature_from_def(AK::type_def()) } fn extension_ref(&self) -> Weak { - Arc::downgrade(&super::EXTENSION) + Arc::downgrade(AK::extension()) } fn extension(&self) -> ExtensionId { - super::EXTENSION_ID + AK::EXTENSION_ID } fn description(&self) -> String { @@ -125,7 +139,7 @@ impl MakeOpDef for ArrayScanDef { extension: &mut Extension, extension_ref: &Weak, ) -> Result<(), crate::extension::ExtensionBuildError> { - let sig = self.signature_from_def(extension.get_type(&ARRAY_TYPENAME).unwrap()); + let sig = self.signature_from_def(extension.get_type(&AK::TYPE_NAME).unwrap()); let def = extension.add_op(self.name(), self.description(), sig, extension_ref)?; self.post_opdef(def); @@ -134,9 +148,9 @@ impl MakeOpDef for ArrayScanDef { } } -/// Definition of the array scan op. +/// Definition of the array scan op. Generic over the concrete array implementation. #[derive(Clone, Debug, PartialEq)] -pub struct ArrayScan { +pub struct GenericArrayScan { /// The element type of the input array. pub src_ty: Type, /// The target element type of the output array. @@ -147,9 +161,10 @@ pub struct ArrayScan { pub size: u64, /// The extensions required by the scan function. pub extension_reqs: ExtensionSet, + _kind: PhantomData, } -impl ArrayScan { +impl GenericArrayScan { /// Creates a new array scan op. pub fn new( src_ty: Type, @@ -158,28 +173,29 @@ impl ArrayScan { size: u64, extension_reqs: ExtensionSet, ) -> Self { - ArrayScan { + GenericArrayScan { src_ty, tgt_ty, acc_tys, size, extension_reqs, + _kind: PhantomData, } } } -impl NamedOp for ArrayScan { +impl NamedOp for GenericArrayScan { fn name(&self) -> OpName { ARRAY_SCAN_OP_ID } } -impl MakeExtensionOp for ArrayScan { +impl MakeExtensionOp for GenericArrayScan { fn from_extension_op(ext_op: &ExtensionOp) -> Result where Self: Sized, { - let def = ArrayScanDef::from_def(ext_op.def())?; + let def = GenericArrayScanDef::::from_def(ext_op.def())?; def.instantiate(ext_op.args()) } @@ -198,22 +214,22 @@ impl MakeExtensionOp for ArrayScan { } } -impl MakeRegisteredOp for ArrayScan { +impl MakeRegisteredOp for GenericArrayScan { fn extension_id(&self) -> ExtensionId { - super::EXTENSION_ID + AK::EXTENSION_ID } fn extension_ref(&self) -> Weak { - Arc::downgrade(&super::EXTENSION) + Arc::downgrade(AK::extension()) } } -impl HasDef for ArrayScan { - type Def = ArrayScanDef; +impl HasDef for GenericArrayScan { + type Def = GenericArrayScanDef; } -impl HasConcrete for ArrayScanDef { - type Concrete = ArrayScan; +impl HasConcrete for GenericArrayScanDef { + type Concrete = GenericArrayScan; fn instantiate(&self, type_args: &[TypeArg]) -> Result { match type_args { @@ -226,7 +242,7 @@ impl HasConcrete for ArrayScanDef { _ => Err(SignatureError::InvalidTypeArgs.into()), }) .collect(); - Ok(ArrayScan::new( + Ok(GenericArrayScan::new( src_ty.clone(), tgt_ty.clone(), acc_tys?, @@ -241,9 +257,11 @@ impl HasConcrete for ArrayScanDef { #[cfg(test)] mod tests { + use rstest::rstest; use crate::extension::prelude::usize_t; - use crate::std_extensions::collections::array::{array_type, EXTENSION_ID}; + use crate::std_extensions::collections::array::{Array, EXTENSION_ID}; + use crate::std_extensions::collections::value_array::ValueArray; use crate::{ extension::prelude::{bool_t, qb_t}, ops::{OpTrait, OpType}, @@ -252,9 +270,11 @@ mod tests { use super::*; - #[test] - fn test_scan_def() { - let op = ArrayScan::new( + #[rstest] + #[case(Array)] + #[case(ValueArray)] + fn test_scan_def(#[case] _kind: AK) { + let op = GenericArrayScan::::new( bool_t(), qb_t(), vec![usize_t()], @@ -262,18 +282,21 @@ mod tests { ExtensionSet::singleton(EXTENSION_ID), ); let optype: OpType = op.clone().into(); - let new_op: ArrayScan = optype.cast().unwrap(); + let new_op: GenericArrayScan = optype.cast().unwrap(); assert_eq!(new_op, op); } - #[test] - fn test_scan_map() { + #[rstest] + #[case(Array)] + #[case(ValueArray)] + fn test_scan_map(#[case] _kind: AK) { let size = 2; let src_ty = qb_t(); let tgt_ty = bool_t(); let es = ExtensionSet::singleton(EXTENSION_ID); - let op = ArrayScan::new(src_ty.clone(), tgt_ty.clone(), vec![], size, es.clone()); + let op = + GenericArrayScan::::new(src_ty.clone(), tgt_ty.clone(), vec![], size, es.clone()); let optype: OpType = op.into(); let sig = optype.dataflow_signature().unwrap(); @@ -281,19 +304,21 @@ mod tests { sig.io(), ( &vec![ - array_type(size, src_ty.clone()), + AK::ty(size, src_ty.clone()), Type::new_function( Signature::new(vec![src_ty], vec![tgt_ty.clone()]).with_extension_delta(es) ) ] .into(), - &vec![array_type(size, tgt_ty)].into(), + &vec![AK::ty(size, tgt_ty)].into(), ) ); } - #[test] - fn test_scan_accs() { + #[rstest] + #[case(Array)] + #[case(ValueArray)] + fn test_scan_accs(#[case] _kind: AK) { let size = 2; let src_ty = qb_t(); let tgt_ty = bool_t(); @@ -301,7 +326,7 @@ mod tests { let acc_ty2 = qb_t(); let es = ExtensionSet::singleton(EXTENSION_ID); - let op = ArrayScan::new( + let op = GenericArrayScan::::new( src_ty.clone(), tgt_ty.clone(), vec![acc_ty1.clone(), acc_ty2.clone()], @@ -315,7 +340,7 @@ mod tests { sig.io(), ( &vec![ - array_type(size, src_ty.clone()), + AK::ty(size, src_ty.clone()), Type::new_function( Signature::new( vec![src_ty, acc_ty1.clone(), acc_ty2.clone()], @@ -327,7 +352,7 @@ mod tests { acc_ty2.clone() ] .into(), - &vec![array_type(size, tgt_ty), acc_ty1, acc_ty2].into(), + &vec![AK::ty(size, tgt_ty), acc_ty1, acc_ty2].into(), ) ); } diff --git a/hugr-core/src/std_extensions/collections/array/array_value.rs b/hugr-core/src/std_extensions/collections/array/array_value.rs new file mode 100644 index 000000000..c3fb94b4d --- /dev/null +++ b/hugr-core/src/std_extensions/collections/array/array_value.rs @@ -0,0 +1,166 @@ +use itertools::Itertools as _; +use serde::{Deserialize, Serialize}; +use std::hash::{Hash, Hasher}; +use std::marker::PhantomData; + +use crate::extension::resolution::{ + resolve_type_extensions, resolve_value_extensions, ExtensionResolutionError, + WeakExtensionRegistry, +}; +use crate::extension::ExtensionSet; +use crate::ops::constant::{maybe_hash_values, TryHash, ValueName}; +use crate::ops::Value; +use crate::types::type_param::TypeArg; +use crate::types::{CustomCheckFailure, CustomType, Type}; + +use super::array_kind::ArrayKind; + +#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)] +/// Statically sized array of values, all of the same type. +pub struct GenericArrayValue { + values: Vec, + typ: Type, + _kind: PhantomData, +} + +impl GenericArrayValue { + /// Create a new [CustomConst] for an array of values of type `typ`. + /// That all values are of type `typ` is not checked here. + /// + /// [CustomConst]: crate::ops::constant::CustomConst + pub fn new(typ: Type, contents: impl IntoIterator) -> Self { + Self { + values: contents.into_iter().collect_vec(), + typ, + _kind: PhantomData, + } + } + + /// Create a new [CustomConst] for an empty array of values of type `typ`. + /// + /// [CustomConst]: crate::ops::constant::CustomConst + pub fn new_empty(typ: Type) -> Self { + Self { + values: vec![], + typ, + _kind: PhantomData, + } + } + + /// Returns the type of the `[GenericArrayValue]` as a `[CustomType]`.` + pub fn custom_type(&self) -> CustomType { + AK::custom_ty(self.values.len() as u64, self.typ.clone()) + } + + /// Returns the type of the `[GenericArrayValue]`. + pub fn get_type(&self) -> Type { + self.custom_type().into() + } + + /// Returns the type of values inside the `[ArrayValue]`. + pub fn get_element_type(&self) -> &Type { + &self.typ + } + + /// Returns the values contained inside the `[ArrayValue]`. + pub fn get_contents(&self) -> &[Value] { + &self.values + } + + /// Returns the name of the value. + pub fn name(&self) -> ValueName { + AK::VALUE_NAME + } + + /// Validates the array value. + pub fn validate(&self) -> Result<(), CustomCheckFailure> { + let typ = self.custom_type(); + + AK::extension() + .get_type(&AK::TYPE_NAME) + .unwrap() + .check_custom(&typ) + .map_err(|_| { + CustomCheckFailure::Message(format!( + "Custom typ {typ} is not a valid instantiation of array." + )) + })?; + + // constant can only hold classic type. + let ty = match typ.args() { + [TypeArg::BoundedNat { n }, TypeArg::Type { ty }] + if *n as usize == self.values.len() => + { + ty + } + _ => { + return Err(CustomCheckFailure::Message(format!( + "Invalid array type arguments: {:?}", + typ.args() + ))) + } + }; + + // check all values are instances of the element type + for v in &self.values { + if v.get_type() != *ty { + return Err(CustomCheckFailure::Message(format!( + "Array element {v:?} is not of expected type {ty}" + ))); + } + } + + Ok(()) + } + + /// Returns the extension requirements for the array value. + pub fn extension_reqs(&self) -> ExtensionSet { + ExtensionSet::union_over(self.values.iter().map(Value::extension_reqs)) + .union(AK::EXTENSION_ID.into()) + } + + /// Update the extensions associated with the internal values. + pub fn update_extensions( + &mut self, + extensions: &WeakExtensionRegistry, + ) -> Result<(), ExtensionResolutionError> { + for val in &mut self.values { + resolve_value_extensions(val, extensions)?; + } + resolve_type_extensions(&mut self.typ, extensions) + } +} + +impl TryHash for GenericArrayValue { + fn try_hash(&self, mut st: &mut dyn Hasher) -> bool { + maybe_hash_values(&self.values, &mut st) && { + self.typ.hash(&mut st); + true + } + } +} + +#[cfg(test)] +mod test { + use rstest::rstest; + + use crate::extension::prelude::{usize_t, ConstUsize}; + use crate::std_extensions::arithmetic::float_types::ConstF64; + + use crate::std_extensions::collections::array::Array; + use crate::std_extensions::collections::value_array::ValueArray; + + use super::*; + + #[rstest] + #[case(Array)] + #[case(ValueArray)] + fn test_array_value(#[case] _kind: AK) { + let array_value = GenericArrayValue::::new(usize_t(), vec![ConstUsize::new(3).into()]); + array_value.validate().unwrap(); + + let wrong_array_value = + GenericArrayValue::::new(usize_t(), vec![ConstF64::new(1.2).into()]); + assert!(wrong_array_value.validate().is_err()); + } +} diff --git a/hugr-core/src/std_extensions/collections/value_array.rs b/hugr-core/src/std_extensions/collections/value_array.rs new file mode 100644 index 000000000..d5883b7a5 --- /dev/null +++ b/hugr-core/src/std_extensions/collections/value_array.rs @@ -0,0 +1,133 @@ +//! A version of the standard fixed-length array extension where arrays of copyable types +//! are copyable themselves. +//! +//! Supports all regular array operations apart from `clone` and `discard`. + +use std::sync::Arc; + +use delegate::delegate; +use lazy_static::lazy_static; + +use crate::extension::resolution::{ExtensionResolutionError, WeakExtensionRegistry}; +use crate::extension::simple_op::MakeOpDef; +use crate::extension::{ExtensionId, ExtensionSet, SignatureError, TypeDef, TypeDefBound}; +use crate::ops::constant::{CustomConst, ValueName}; +use crate::types::type_param::{TypeArg, TypeParam}; +use crate::types::{CustomCheckFailure, Type, TypeBound, TypeName}; +use crate::Extension; + +use super::array::{ + ArrayKind, GenericArrayOp, GenericArrayOpDef, GenericArrayRepeat, GenericArrayRepeatDef, + GenericArrayScan, GenericArrayScanDef, GenericArrayValue, +}; + +/// Reported unique name of the value array type. +pub const VALUE_ARRAY_TYPENAME: TypeName = TypeName::new_inline("value_array"); +/// Reported unique name of the value array value. +pub const VALUE_ARRAY_VALUENAME: TypeName = TypeName::new_inline("value_array"); +/// Reported unique name of the extension +pub const EXTENSION_ID: ExtensionId = ExtensionId::new_static_unchecked("collections.value_array"); +/// Extension version. +pub const VERSION: semver::Version = semver::Version::new(0, 1, 0); + +/// A fixed-length collection of values. +/// +/// A value array inherits its linearity from its elements. +#[derive(Clone, Copy, Debug, derive_more::Display, Eq, PartialEq, Default)] +pub struct ValueArray; + +impl ArrayKind for ValueArray { + const EXTENSION_ID: ExtensionId = EXTENSION_ID; + const TYPE_NAME: TypeName = VALUE_ARRAY_TYPENAME; + const VALUE_NAME: ValueName = VALUE_ARRAY_VALUENAME; + + fn extension() -> &'static Arc { + &EXTENSION + } + + fn type_def() -> &'static TypeDef { + EXTENSION.get_type(&VALUE_ARRAY_TYPENAME).unwrap() + } +} + +/// Value array operation definitions. +pub type VArrayOpDef = GenericArrayOpDef; +/// Value array repeat operation definition. +pub type VArrayRepeatDef = GenericArrayRepeatDef; +/// Value array scan operation definition. +pub type VArrayScanDef = GenericArrayScanDef; + +/// Value array operations. +pub type VArrayOp = GenericArrayOp; +/// The value array repeat operation. +pub type VArrayRepeat = GenericArrayRepeat; +/// The value array scan operation. +pub type VArrayScan = GenericArrayScan; + +/// A value array extension value. +pub type VArrayValue = GenericArrayValue; + +lazy_static! { + /// Extension for value array operations. + pub static ref EXTENSION: Arc = { + Extension::new_arc(EXTENSION_ID, VERSION, |extension, extension_ref| { + extension.add_type( + VALUE_ARRAY_TYPENAME, + vec![ TypeParam::max_nat(), TypeBound::Any.into()], + "Fixed-length value array".into(), + // Value arrays are copyable iff their elements are + TypeDefBound::from_params(vec![1]), + extension_ref, + ) + .unwrap(); + + VArrayOpDef::load_all_ops(extension, extension_ref).unwrap(); + VArrayRepeatDef::new().add_to_extension(extension, extension_ref).unwrap(); + VArrayScanDef::new().add_to_extension(extension, extension_ref).unwrap(); + }) + }; +} + +#[typetag::serde(name = "VArrayValue")] +impl CustomConst for VArrayValue { + delegate! { + to self { + fn name(&self) -> ValueName; + fn extension_reqs(&self) -> ExtensionSet; + fn validate(&self) -> Result<(), CustomCheckFailure>; + fn update_extensions( + &mut self, + extensions: &WeakExtensionRegistry, + ) -> Result<(), ExtensionResolutionError>; + fn get_type(&self) -> Type; + } + } + + fn equal_consts(&self, other: &dyn CustomConst) -> bool { + crate::ops::constant::downcast_equal_consts(self, other) + } +} + +/// Gets the [TypeDef] for value arrays. Note that instantiations are more easily +/// created via [value_array_type] and [value_array_type_parametric] +pub fn value_array_type_def() -> &'static TypeDef { + ValueArray::type_def() +} + +/// Instantiate a new value array type given a size argument and element type. +/// +/// This method is equivalent to [`value_array_type_parametric`], but uses concrete +/// arguments types to ensure no errors are possible. +pub fn value_array_type(size: u64, element_ty: Type) -> Type { + ValueArray::ty(size, element_ty) +} + +/// Instantiate a new value array type given the size and element type parameters. +/// +/// This is a generic version of [`value_array_type`]. +pub fn value_array_type_parametric( + size: impl Into, + element_ty: impl Into, +) -> Result { + ValueArray::ty_parametric(size, element_ty) +} diff --git a/hugr-core/src/utils.rs b/hugr-core/src/utils.rs index f44b075f1..efa0eef84 100644 --- a/hugr-core/src/utils.rs +++ b/hugr-core/src/utils.rs @@ -101,6 +101,18 @@ pub(crate) fn is_default(t: &T) -> bool { *t == Default::default() } +/// An empty type. +/// +/// # Example +/// +/// ```ignore +/// fn foo(never: Never) -> ! { +/// match never {} +/// } +/// ``` +#[derive(Clone, Copy, Debug, Hash, PartialEq, Eq)] +pub enum Never {} + #[cfg(test)] pub(crate) mod test_quantum_extension { use std::sync::Arc; diff --git a/hugr-llvm/src/emit/test.rs b/hugr-llvm/src/emit/test.rs index 3f6977a8c..9fbb7ed8b 100644 --- a/hugr-llvm/src/emit/test.rs +++ b/hugr-llvm/src/emit/test.rs @@ -164,6 +164,7 @@ impl SimpleHugrConfig { collections::array::EXTENSION_ID, collections::list::EXTENSION_ID, collections::static_array::EXTENSION_ID, + collections::value_array::EXTENSION_ID, ]), ), ) diff --git a/hugr-llvm/src/extension/collections.rs b/hugr-llvm/src/extension/collections.rs index 6c10d3ed1..f50015b8b 100644 --- a/hugr-llvm/src/extension/collections.rs +++ b/hugr-llvm/src/extension/collections.rs @@ -1,5 +1,5 @@ //! Emission logic for collections. -pub mod array; pub mod list; pub mod static_array; +pub mod value_array; diff --git a/hugr-llvm/src/extension/collections/snapshots/hugr_llvm__extension__collections__array__test__emit_all_ops@llvm14.snap b/hugr-llvm/src/extension/collections/snapshots/hugr_llvm__extension__collections__value_array__test__emit_all_ops@llvm14.snap similarity index 99% rename from hugr-llvm/src/extension/collections/snapshots/hugr_llvm__extension__collections__array__test__emit_all_ops@llvm14.snap rename to hugr-llvm/src/extension/collections/snapshots/hugr_llvm__extension__collections__value_array__test__emit_all_ops@llvm14.snap index bc9aa19c6..b56595a95 100644 --- a/hugr-llvm/src/extension/collections/snapshots/hugr_llvm__extension__collections__array__test__emit_all_ops@llvm14.snap +++ b/hugr-llvm/src/extension/collections/snapshots/hugr_llvm__extension__collections__value_array__test__emit_all_ops@llvm14.snap @@ -1,5 +1,5 @@ --- -source: hugr-llvm/src/extension/collections/array.rs +source: hugr-llvm/src/extension/collections/value_array.rs expression: mod_str --- ; ModuleID = 'test_context' diff --git a/hugr-llvm/src/extension/collections/snapshots/hugr_llvm__extension__collections__array__test__emit_all_ops@pre-mem2reg@llvm14.snap b/hugr-llvm/src/extension/collections/snapshots/hugr_llvm__extension__collections__value_array__test__emit_all_ops@pre-mem2reg@llvm14.snap similarity index 99% rename from hugr-llvm/src/extension/collections/snapshots/hugr_llvm__extension__collections__array__test__emit_all_ops@pre-mem2reg@llvm14.snap rename to hugr-llvm/src/extension/collections/snapshots/hugr_llvm__extension__collections__value_array__test__emit_all_ops@pre-mem2reg@llvm14.snap index 9b294486d..39604d600 100644 --- a/hugr-llvm/src/extension/collections/snapshots/hugr_llvm__extension__collections__array__test__emit_all_ops@pre-mem2reg@llvm14.snap +++ b/hugr-llvm/src/extension/collections/snapshots/hugr_llvm__extension__collections__value_array__test__emit_all_ops@pre-mem2reg@llvm14.snap @@ -1,5 +1,5 @@ --- -source: hugr-llvm/src/extension/collections/array.rs +source: hugr-llvm/src/extension/collections/value_array.rs expression: mod_str --- ; ModuleID = 'test_context' diff --git a/hugr-llvm/src/extension/collections/snapshots/hugr_llvm__extension__collections__array__test__emit_array_value@llvm14.snap b/hugr-llvm/src/extension/collections/snapshots/hugr_llvm__extension__collections__value_array__test__emit_array_value@llvm14.snap similarity index 82% rename from hugr-llvm/src/extension/collections/snapshots/hugr_llvm__extension__collections__array__test__emit_array_value@llvm14.snap rename to hugr-llvm/src/extension/collections/snapshots/hugr_llvm__extension__collections__value_array__test__emit_array_value@llvm14.snap index 3a718f7f2..00474a526 100644 --- a/hugr-llvm/src/extension/collections/snapshots/hugr_llvm__extension__collections__array__test__emit_array_value@llvm14.snap +++ b/hugr-llvm/src/extension/collections/snapshots/hugr_llvm__extension__collections__value_array__test__emit_array_value@llvm14.snap @@ -1,5 +1,5 @@ --- -source: hugr-llvm/src/extension/collections/array.rs +source: hugr-llvm/src/extension/collections/value_array.rs expression: mod_str --- ; ModuleID = 'test_context' diff --git a/hugr-llvm/src/extension/collections/snapshots/hugr_llvm__extension__collections__array__test__emit_array_value@pre-mem2reg@llvm14.snap b/hugr-llvm/src/extension/collections/snapshots/hugr_llvm__extension__collections__value_array__test__emit_array_value@pre-mem2reg@llvm14.snap similarity index 90% rename from hugr-llvm/src/extension/collections/snapshots/hugr_llvm__extension__collections__array__test__emit_array_value@pre-mem2reg@llvm14.snap rename to hugr-llvm/src/extension/collections/snapshots/hugr_llvm__extension__collections__value_array__test__emit_array_value@pre-mem2reg@llvm14.snap index 5befaf3df..310722041 100644 --- a/hugr-llvm/src/extension/collections/snapshots/hugr_llvm__extension__collections__array__test__emit_array_value@pre-mem2reg@llvm14.snap +++ b/hugr-llvm/src/extension/collections/snapshots/hugr_llvm__extension__collections__value_array__test__emit_array_value@pre-mem2reg@llvm14.snap @@ -1,5 +1,5 @@ --- -source: hugr-llvm/src/extension/collections/array.rs +source: hugr-llvm/src/extension/collections/value_array.rs expression: mod_str --- ; ModuleID = 'test_context' diff --git a/hugr-llvm/src/extension/collections/snapshots/hugr_llvm__extension__collections__array__test__emit_get@llvm14.snap b/hugr-llvm/src/extension/collections/snapshots/hugr_llvm__extension__collections__value_array__test__emit_get@llvm14.snap similarity index 94% rename from hugr-llvm/src/extension/collections/snapshots/hugr_llvm__extension__collections__array__test__emit_get@llvm14.snap rename to hugr-llvm/src/extension/collections/snapshots/hugr_llvm__extension__collections__value_array__test__emit_get@llvm14.snap index 1c638784d..c31061551 100644 --- a/hugr-llvm/src/extension/collections/snapshots/hugr_llvm__extension__collections__array__test__emit_get@llvm14.snap +++ b/hugr-llvm/src/extension/collections/snapshots/hugr_llvm__extension__collections__value_array__test__emit_get@llvm14.snap @@ -1,5 +1,5 @@ --- -source: hugr-llvm/src/extension/collections/array.rs +source: hugr-llvm/src/extension/collections/value_array.rs expression: mod_str --- ; ModuleID = 'test_context' diff --git a/hugr-llvm/src/extension/collections/snapshots/hugr_llvm__extension__collections__array__test__emit_get@pre-mem2reg@llvm14.snap b/hugr-llvm/src/extension/collections/snapshots/hugr_llvm__extension__collections__value_array__test__emit_get@pre-mem2reg@llvm14.snap similarity index 96% rename from hugr-llvm/src/extension/collections/snapshots/hugr_llvm__extension__collections__array__test__emit_get@pre-mem2reg@llvm14.snap rename to hugr-llvm/src/extension/collections/snapshots/hugr_llvm__extension__collections__value_array__test__emit_get@pre-mem2reg@llvm14.snap index 15902b579..2bc368ef0 100644 --- a/hugr-llvm/src/extension/collections/snapshots/hugr_llvm__extension__collections__array__test__emit_get@pre-mem2reg@llvm14.snap +++ b/hugr-llvm/src/extension/collections/snapshots/hugr_llvm__extension__collections__value_array__test__emit_get@pre-mem2reg@llvm14.snap @@ -1,5 +1,5 @@ --- -source: hugr-llvm/src/extension/collections/array.rs +source: hugr-llvm/src/extension/collections/value_array.rs expression: mod_str --- ; ModuleID = 'test_context' diff --git a/hugr-llvm/src/extension/collections/array.rs b/hugr-llvm/src/extension/collections/value_array.rs similarity index 91% rename from hugr-llvm/src/extension/collections/array.rs rename to hugr-llvm/src/extension/collections/value_array.rs index 65e0599ea..c7373c953 100644 --- a/hugr-llvm/src/extension/collections/array.rs +++ b/hugr-llvm/src/extension/collections/value_array.rs @@ -1,12 +1,13 @@ -//! Codegen for prelude array operations. +//! Codegen for prelude value_array operations. use std::iter; use anyhow::{anyhow, Ok, Result}; use hugr_core::extension::prelude::option_type; use hugr_core::extension::simple_op::{MakeExtensionOp, MakeRegisteredOp}; use hugr_core::ops::DataflowOpTrait; -use hugr_core::std_extensions::collections::array::{ - self, array_type, ArrayOp, ArrayOpDef, ArrayRepeat, ArrayScan, +use hugr_core::std_extensions::collections::array; +use hugr_core::std_extensions::collections::value_array::{ + self, value_array_type, VArrayOp, VArrayOpDef, VArrayRepeat, VArrayScan, }; use hugr_core::types::{TypeArg, TypeEnum}; use hugr_core::{HugrView, Node}; @@ -28,21 +29,21 @@ use crate::{CodegenExtension, CodegenExtsBuilder}; impl<'a, H: HugrView + 'a> CodegenExtsBuilder<'a, H> { /// Add a [ArrayCodegenExtension] to the given [CodegenExtsBuilder] using `ccg` /// as the implementation. - pub fn add_default_array_extensions(self) -> Self { - self.add_array_extensions(DefaultArrayCodegen) + pub fn add_default_value_array_extensions(self) -> Self { + self.add_value_array_extensions(DefaultVArrayCodegen) } /// Add a [ArrayCodegenExtension] to the given [CodegenExtsBuilder] using - /// [DefaultArrayCodegen] as the implementation. - pub fn add_array_extensions(self, ccg: impl ArrayCodegen + 'a) -> Self { + /// [DefaultVArrayCodegen] as the implementation. + pub fn add_value_array_extensions(self, ccg: impl VArrayCodegen + 'a) -> Self { self.add_extension(ArrayCodegenExtension::from(ccg)) } } -/// A helper trait for customising the lowering of [hugr_core::std_extensions::collections::array] +/// A helper trait for customising the lowering of [hugr_core::std_extensions::collections::value_array] /// types, [hugr_core::ops::constant::CustomConst]s, and ops. -pub trait ArrayCodegen: Clone { - /// Return the llvm type of [hugr_core::std_extensions::collections::array::ARRAY_TYPENAME]. +pub trait VArrayCodegen: Clone { + /// Return the llvm type of [hugr_core::std_extensions::collections::value_array::VALUE_ARRAY_TYPENAME]. fn array_type<'c>( &self, _session: &TypingSession<'c, '_>, @@ -52,43 +53,43 @@ pub trait ArrayCodegen: Clone { elem_ty.array_type(size as u32) } - /// Emit a [hugr_core::std_extensions::collections::array::ArrayValue]. + /// Emit a [hugr_core::std_extensions::collections::value_array::VArrayValue]. fn emit_array_value<'c, H: HugrView>( &self, ctx: &mut EmitFuncContext<'c, '_, H>, - value: &array::ArrayValue, + value: &value_array::VArrayValue, ) -> Result> { emit_array_value(self, ctx, value) } - /// Emit a [hugr_core::std_extensions::collections::array::ArrayOp]. + /// Emit a [hugr_core::std_extensions::collections::value_array::VArrayOp]. fn emit_array_op<'c, H: HugrView>( &self, ctx: &mut EmitFuncContext<'c, '_, H>, - op: ArrayOp, + op: VArrayOp, inputs: Vec>, outputs: RowPromise<'c>, ) -> Result<()> { emit_array_op(self, ctx, op, inputs, outputs) } - /// Emit a [hugr_core::std_extensions::collections::array::ArrayRepeat] op. + /// Emit a [hugr_core::std_extensions::collections::value_array::VArrayRepeat] op. fn emit_array_repeat<'c, H: HugrView>( &self, ctx: &mut EmitFuncContext<'c, '_, H>, - op: ArrayRepeat, + op: VArrayRepeat, func: BasicValueEnum<'c>, ) -> Result> { emit_repeat_op(ctx, op, func) } - /// Emit a [hugr_core::std_extensions::collections::array::ArrayScan] op. + /// Emit a [hugr_core::std_extensions::collections::value_array::VArrayScan] op. /// /// Returns the resulting array and the final values of the accumulators. fn emit_array_scan<'c, H: HugrView>( &self, ctx: &mut EmitFuncContext<'c, '_, H>, - op: ArrayScan, + op: VArrayScan, src_array: BasicValueEnum<'c>, func: BasicValueEnum<'c>, initial_accs: &[BasicValueEnum<'c>], @@ -97,29 +98,29 @@ pub trait ArrayCodegen: Clone { } } -/// A trivial implementation of [ArrayCodegen] which passes all methods +/// A trivial implementation of [VArrayCodegen] which passes all methods /// through to their default implementations. #[derive(Default, Clone)] -pub struct DefaultArrayCodegen; +pub struct DefaultVArrayCodegen; -impl ArrayCodegen for DefaultArrayCodegen {} +impl VArrayCodegen for DefaultVArrayCodegen {} #[derive(Clone, Debug, Default)] pub struct ArrayCodegenExtension(CCG); -impl ArrayCodegenExtension { +impl ArrayCodegenExtension { pub fn new(ccg: CCG) -> Self { Self(ccg) } } -impl From for ArrayCodegenExtension { +impl From for ArrayCodegenExtension { fn from(ccg: CCG) -> Self { Self::new(ccg) } } -impl CodegenExtension for ArrayCodegenExtension { +impl CodegenExtension for ArrayCodegenExtension { fn add_extension<'a, H: HugrView + 'a>( self, builder: CodegenExtsBuilder<'a, H>, @@ -128,47 +129,51 @@ impl CodegenExtension for ArrayCodegenExtension { Self: 'a, { builder - .custom_type((array::EXTENSION_ID, array::ARRAY_TYPENAME), { - let ccg = self.0.clone(); - move |ts, hugr_type| { - let [TypeArg::BoundedNat { n }, TypeArg::Type { ty }] = hugr_type.args() else { - return Err(anyhow!("Invalid type args for array type")); - }; - let elem_ty = ts.llvm_type(ty)?; - Ok(ccg.array_type(&ts, elem_ty, *n).as_basic_type_enum()) - } - }) - .custom_const::({ + .custom_type( + (value_array::EXTENSION_ID, value_array::VALUE_ARRAY_TYPENAME), + { + let ccg = self.0.clone(); + move |ts, hugr_type| { + let [TypeArg::BoundedNat { n }, TypeArg::Type { ty }] = hugr_type.args() + else { + return Err(anyhow!("Invalid type args for array type")); + }; + let elem_ty = ts.llvm_type(ty)?; + Ok(ccg.array_type(&ts, elem_ty, *n).as_basic_type_enum()) + } + }, + ) + .custom_const::({ let ccg = self.0.clone(); move |context, k| ccg.emit_array_value(context, k) }) - .simple_extension_op::({ + .simple_extension_op::({ let ccg = self.0.clone(); move |context, args, _| { ccg.emit_array_op( context, - ArrayOp::from_extension_op(args.node().as_ref())?, + VArrayOp::from_extension_op(args.node().as_ref())?, args.inputs, args.outputs, ) } }) - .extension_op(array::EXTENSION_ID, array::ARRAY_REPEAT_OP_ID, { + .extension_op(value_array::EXTENSION_ID, array::ARRAY_REPEAT_OP_ID, { let ccg = self.0.clone(); move |context, args| { let func = args.inputs[0]; - let op = ArrayRepeat::from_extension_op(args.node().as_ref())?; + let op = VArrayRepeat::from_extension_op(args.node().as_ref())?; let arr = ccg.emit_array_repeat(context, op, func)?; args.outputs.finish(context.builder(), [arr]) } }) - .extension_op(array::EXTENSION_ID, array::ARRAY_SCAN_OP_ID, { + .extension_op(value_array::EXTENSION_ID, array::ARRAY_SCAN_OP_ID, { let ccg = self.0.clone(); move |context, args| { let src_array = args.inputs[0]; let func = args.inputs[1]; let initial_accs = &args.inputs[2..]; - let op = ArrayScan::from_extension_op(args.node().as_ref())?; + let op = VArrayScan::from_extension_op(args.node().as_ref())?; let (tgt_array, final_accs) = ccg.emit_array_scan(context, op, src_array, func, initial_accs)?; args.outputs @@ -258,9 +263,9 @@ fn build_loop<'c, T, H: HugrView>( } pub fn emit_array_value<'c, H: HugrView>( - ccg: &impl ArrayCodegen, + ccg: &impl VArrayCodegen, ctx: &mut EmitFuncContext<'c, '_, H>, - value: &array::ArrayValue, + value: &value_array::VArrayValue, ) -> Result> { let ts = ctx.typing_session(); let llvm_array_ty = ccg @@ -283,9 +288,9 @@ pub fn emit_array_value<'c, H: HugrView>( } pub fn emit_array_op<'c, H: HugrView>( - ccg: &impl ArrayCodegen, + ccg: &impl VArrayCodegen, ctx: &mut EmitFuncContext<'c, '_, H>, - op: ArrayOp, + op: VArrayOp, inputs: Vec>, outputs: RowPromise<'c>, ) -> Result<()> { @@ -297,7 +302,7 @@ pub fn emit_array_op<'c, H: HugrView>( .unwrap() .signature() .into_owned(); - let ArrayOp { + let VArrayOp { def, ref elem_ty, size, @@ -307,7 +312,7 @@ pub fn emit_array_op<'c, H: HugrView>( .as_basic_type_enum() .into_array_type(); match def { - ArrayOpDef::new_array => { + VArrayOpDef::new_array => { let mut array_v = llvm_array_ty.get_undef(); for (i, v) in inputs.into_iter().enumerate() { array_v = builder @@ -316,7 +321,7 @@ pub fn emit_array_op<'c, H: HugrView>( } outputs.finish(builder, [array_v.as_basic_value_enum()]) } - ArrayOpDef::get => { + VArrayOpDef::get => { let [array_v, index_v] = inputs .try_into() .map_err(|_| anyhow!("ArrayOpDef::get expects two arguments"))?; @@ -378,7 +383,7 @@ pub fn emit_array_op<'c, H: HugrView>( builder.position_at_end(exit_block); Ok(()) } - ArrayOpDef::set => { + VArrayOpDef::set => { let [array_v0, index_v, value_v] = inputs .try_into() .map_err(|_| anyhow!("ArrayOpDef::set expects three arguments"))?; @@ -451,7 +456,7 @@ pub fn emit_array_op<'c, H: HugrView>( builder.position_at_end(exit_block); Ok(()) } - ArrayOpDef::swap => { + VArrayOpDef::swap => { let [array_v0, index1_v, index2_v] = inputs .try_into() .map_err(|_| anyhow!("ArrayOpDef::swap expects three arguments"))?; @@ -543,7 +548,7 @@ pub fn emit_array_op<'c, H: HugrView>( builder.position_at_end(exit_block); Ok(()) } - ArrayOpDef::pop_left => { + VArrayOpDef::pop_left => { let [array_v] = inputs .try_into() .map_err(|_| anyhow!("ArrayOpDef::pop_left expects one argument"))?; @@ -557,7 +562,7 @@ pub fn emit_array_op<'c, H: HugrView>( )?; outputs.finish(ctx.builder(), [r]) } - ArrayOpDef::pop_right => { + VArrayOpDef::pop_right => { let [array_v] = inputs .try_into() .map_err(|_| anyhow!("ArrayOpDef::pop_right expects one argument"))?; @@ -571,7 +576,7 @@ pub fn emit_array_op<'c, H: HugrView>( )?; outputs.finish(ctx.builder(), [r]) } - ArrayOpDef::discard_empty => Ok(()), + VArrayOpDef::discard_empty => Ok(()), _ => todo!(), } } @@ -587,7 +592,7 @@ fn emit_pop_op<'c>( ) -> Result> { let ret_ty = ts.llvm_sum_type(option_type(vec![ elem_ty.clone(), - array_type(size.saturating_add_signed(-1), elem_ty), + value_array_type(size.saturating_add_signed(-1), elem_ty), ]))?; if size == 0 { return Ok(ret_ty.build_tag(builder, 0, vec![])?.into()); @@ -620,10 +625,10 @@ fn emit_pop_op<'c>( Ok(ret_ty.build_tag(builder, 1, vec![elem_v, array_v])?.into()) } -/// Emits an [ArrayRepeat] op. +/// Emits an [VArrayRepeat] op. pub fn emit_repeat_op<'c, H: HugrView>( ctx: &mut EmitFuncContext<'c, '_, H>, - op: ArrayRepeat, + op: VArrayRepeat, func: BasicValueEnum<'c>, ) -> Result> { let builder = ctx.builder(); @@ -649,12 +654,12 @@ pub fn emit_repeat_op<'c, H: HugrView>( Ok(array_v) } -/// Emits an [ArrayScan] op. +/// Emits an [VArrayScan] op. /// /// Returns the resulting array and the final values of the accumulators. pub fn emit_scan_op<'c, H: HugrView>( ctx: &mut EmitFuncContext<'c, '_, H>, - op: ArrayScan, + op: VArrayScan, src_array: BasicValueEnum<'c>, func: BasicValueEnum<'c>, initial_accs: &[BasicValueEnum<'c>], @@ -709,7 +714,9 @@ mod test { use hugr_core::builder::Container as _; use hugr_core::extension::ExtensionSet; use hugr_core::ops::Tag; - use hugr_core::std_extensions::collections::array::{self, array_type, ArrayRepeat, ArrayScan}; + use hugr_core::std_extensions::collections::value_array::{ + self, value_array_type, VArrayRepeat, VArrayScan, + }; use hugr_core::std_extensions::STD_REG; use hugr_core::types::Type; use hugr_core::{ @@ -744,14 +751,14 @@ mod test { let hugr = SimpleHugrConfig::new() .with_extensions(STD_REG.to_owned()) .finish(|mut builder| { - array_op_builder::test::all_array_ops(builder.dfg_builder_endo([]).unwrap()) + array_op_builder::test::all_value_array_ops(builder.dfg_builder_endo([]).unwrap()) .finish_sub_container() .unwrap(); builder.finish_sub_container().unwrap() }); llvm_ctx.add_extensions(|cge| { cge.add_default_prelude_extensions() - .add_default_array_extensions() + .add_default_value_array_extensions() }); check_emission!(hugr, llvm_ctx); } @@ -769,7 +776,7 @@ mod test { }); llvm_ctx.add_extensions(|cge| { cge.add_default_prelude_extensions() - .add_default_array_extensions() + .add_default_value_array_extensions() }); check_emission!(hugr, llvm_ctx); } @@ -778,15 +785,15 @@ mod test { fn emit_array_value(mut llvm_ctx: TestContext) { let hugr = SimpleHugrConfig::new() .with_extensions(STD_REG.to_owned()) - .with_outs(vec![array_type(2, usize_t())]) + .with_outs(vec![value_array_type(2, usize_t())]) .finish(|mut builder| { let vs = vec![ConstUsize::new(1).into(), ConstUsize::new(2).into()]; - let arr = builder.add_load_value(array::ArrayValue::new(usize_t(), vs)); + let arr = builder.add_load_value(value_array::VArrayValue::new(usize_t(), vs)); builder.finish_with_outputs([arr]).unwrap() }); llvm_ctx.add_extensions(|cge| { cge.add_default_prelude_extensions() - .add_default_array_extensions() + .add_default_value_array_extensions() }); check_emission!(hugr, llvm_ctx); } @@ -797,7 +804,7 @@ mod test { int_ops::EXTENSION.to_owned(), logic::EXTENSION.to_owned(), prelude::PRELUDE.to_owned(), - array::EXTENSION.to_owned(), + value_array::EXTENSION.to_owned(), ]) } @@ -807,7 +814,7 @@ mod test { int_ops::EXTENSION_ID, logic::EXTENSION_ID, prelude::PRELUDE_ID, - array::EXTENSION_ID, + value_array::EXTENSION_ID, ]) } @@ -854,7 +861,7 @@ mod test { }); exec_ctx.add_extensions(|cge| { cge.add_default_prelude_extensions() - .add_default_array_extensions() + .add_default_value_array_extensions() }); assert_eq!(expected, exec_ctx.exec_hugr_u64(hugr, "main")); } @@ -896,7 +903,7 @@ mod test { .unwrap(); let r = { let res_sum_ty = { - let row = vec![int_ty.clone(), array_type(2, int_ty.clone())]; + let row = vec![int_ty.clone(), value_array_type(2, int_ty.clone())]; either_type(row.clone(), row) }; let variants = (0..res_sum_ty.num_variants()) @@ -962,7 +969,7 @@ mod test { }); exec_ctx.add_extensions(|cge| { cge.add_default_prelude_extensions() - .add_default_array_extensions() + .add_default_value_array_extensions() .add_default_int_extensions() .add_logic_extensions() }); @@ -991,7 +998,7 @@ mod test { // - The swap operation succeeded iff `expected_succeeded` let int_ty = int_type(3); - let arr_ty = array_type(2, int_ty.clone()); + let arr_ty = value_array_type(2, int_ty.clone()); let hugr = SimpleHugrConfig::new() .with_outs(usize_t()) .with_extensions(exec_registry()) @@ -1072,7 +1079,7 @@ mod test { }); exec_ctx.add_extensions(|cge| { cge.add_default_prelude_extensions() - .add_default_array_extensions() + .add_default_value_array_extensions() .add_default_int_extensions() .add_logic_extensions() }); @@ -1129,7 +1136,7 @@ mod test { 1, option_type(vec![ int_ty.clone(), - array_type(array_size - 1, int_ty.clone()), + value_array_type(array_size - 1, int_ty.clone()), ]), pop_res, ) @@ -1141,7 +1148,7 @@ mod test { }); exec_ctx.add_extensions(|cge| { cge.add_default_prelude_extensions() - .add_default_array_extensions() + .add_default_value_array_extensions() .add_default_int_extensions() }); assert_eq!(expected, exec_ctx.exec_hugr_u64(hugr, "main")); @@ -1179,7 +1186,7 @@ mod test { let v = func.add_load_value(ConstInt::new_u(6, value).unwrap()); let func_id = func.finish_with_outputs(vec![v]).unwrap(); let func_v = builder.load_func(func_id.handle(), &[]).unwrap(); - let repeat = ArrayRepeat::new(int_ty.clone(), size, exec_extension_set()); + let repeat = VArrayRepeat::new(int_ty.clone(), size, exec_extension_set()); let arr = builder .add_dataflow_op(repeat, vec![func_v]) .unwrap() @@ -1195,7 +1202,7 @@ mod test { }); exec_ctx.add_extensions(|cge| { cge.add_default_prelude_extensions() - .add_default_array_extensions() + .add_default_value_array_extensions() .add_default_int_extensions() }); assert_eq!(value, exec_ctx.exec_hugr_u64(hugr, "main")); @@ -1236,7 +1243,7 @@ mod test { let out = func.add_iadd(6, elem, delta).unwrap(); let func_id = func.finish_with_outputs(vec![out]).unwrap(); let func_v = builder.load_func(func_id.handle(), &[]).unwrap(); - let scan = ArrayScan::new( + let scan = VArrayScan::new( int_ty.clone(), int_ty.clone(), vec![], @@ -1258,7 +1265,7 @@ mod test { 1, option_type(vec![ int_ty.clone(), - array_type(array_size - 1, int_ty.clone()), + value_array_type(array_size - 1, int_ty.clone()), ]), pop_res, ) @@ -1270,7 +1277,7 @@ mod test { }); exec_ctx.add_extensions(|cge| { cge.add_default_prelude_extensions() - .add_default_array_extensions() + .add_default_value_array_extensions() .add_default_int_extensions() }); let expected: u64 = (inc..size + inc).sum(); @@ -1316,7 +1323,7 @@ mod test { .out_wire(0); let func_id = func.finish_with_outputs(vec![unit, acc]).unwrap(); let func_v = builder.load_func(func_id.handle(), &[]).unwrap(); - let scan = ArrayScan::new( + let scan = VArrayScan::new( int_ty.clone(), Type::UNIT, vec![int_ty.clone()], @@ -1332,7 +1339,7 @@ mod test { }); exec_ctx.add_extensions(|cge| { cge.add_default_prelude_extensions() - .add_default_array_extensions() + .add_default_value_array_extensions() .add_default_int_extensions() }); let expected: u64 = (0..size).sum(); diff --git a/hugr-llvm/src/utils/array_op_builder.rs b/hugr-llvm/src/utils/array_op_builder.rs index dfe2faba4..8ec56c8bb 100644 --- a/hugr-llvm/src/utils/array_op_builder.rs +++ b/hugr-llvm/src/utils/array_op_builder.rs @@ -1,4 +1,5 @@ -use hugr_core::std_extensions::collections::array::{new_array_op, ArrayOpDef}; +use hugr_core::std_extensions::collections::array::{ArrayKind, GenericArrayOpDef}; +use hugr_core::std_extensions::collections::value_array::ValueArray; use hugr_core::{ builder::{BuildError, Dataflow}, extension::simple_op::HasConcrete as _, @@ -7,7 +8,7 @@ use hugr_core::{ }; use itertools::Itertools as _; -pub trait ArrayOpBuilder: Dataflow { +pub trait ArrayOpBuilder: Dataflow { fn add_new_array( &mut self, elem_ty: Type, @@ -15,7 +16,10 @@ pub trait ArrayOpBuilder: Dataflow { ) -> Result { let inputs = values.into_iter().collect_vec(); let [out] = self - .add_dataflow_op(new_array_op(elem_ty, inputs.len() as u64), inputs)? + .add_dataflow_op( + GenericArrayOpDef::::new_array.to_concrete(elem_ty, inputs.len() as u64), + inputs, + )? .outputs_arr(); Ok(out) } @@ -28,7 +32,7 @@ pub trait ArrayOpBuilder: Dataflow { index: Wire, ) -> Result { // TODO Add an OpLoadError variant to BuildError. - let op = ArrayOpDef::get + let op = GenericArrayOpDef::::get .instantiate(&[size.into(), elem_ty.into()]) .unwrap(); let [out] = self.add_dataflow_op(op, vec![input, index])?.outputs_arr(); @@ -44,7 +48,7 @@ pub trait ArrayOpBuilder: Dataflow { value: Wire, ) -> Result { // TODO Add an OpLoadError variant to BuildError - let op = ArrayOpDef::set + let op = GenericArrayOpDef::::set .instantiate(&[size.into(), elem_ty.into()]) .unwrap(); let [out] = self @@ -62,7 +66,7 @@ pub trait ArrayOpBuilder: Dataflow { index2: Wire, ) -> Result { // TODO Add an OpLoadError variant to BuildError - let op = ArrayOpDef::swap + let op = GenericArrayOpDef::::swap .instantiate(&[size.into(), elem_ty.into()]) .unwrap(); let [out] = self @@ -78,7 +82,7 @@ pub trait ArrayOpBuilder: Dataflow { input: Wire, ) -> Result { // TODO Add an OpLoadError variant to BuildError - let op = ArrayOpDef::pop_left + let op = GenericArrayOpDef::::pop_left .instantiate(&[size.into(), elem_ty.into()]) .unwrap(); Ok(self.add_dataflow_op(op, vec![input])?.out_wire(0)) @@ -91,7 +95,7 @@ pub trait ArrayOpBuilder: Dataflow { input: Wire, ) -> Result { // TODO Add an OpLoadError variant to BuildError - let op = ArrayOpDef::pop_right + let op = GenericArrayOpDef::::pop_right .instantiate(&[size.into(), elem_ty.into()]) .unwrap(); Ok(self.add_dataflow_op(op, vec![input])?.out_wire(0)) @@ -100,7 +104,7 @@ pub trait ArrayOpBuilder: Dataflow { fn add_array_discard_empty(&mut self, elem_ty: Type, input: Wire) -> Result<(), BuildError> { // TODO Add an OpLoadError variant to BuildError self.add_dataflow_op( - ArrayOpDef::discard_empty + GenericArrayOpDef::::discard_empty .instantiate(&[elem_ty.into()]) .unwrap(), [input], @@ -109,13 +113,13 @@ pub trait ArrayOpBuilder: Dataflow { } } -impl ArrayOpBuilder for D {} +impl ArrayOpBuilder for D {} #[cfg(test)] pub mod test { use hugr_core::extension::prelude::PRELUDE_ID; use hugr_core::extension::ExtensionSet; - use hugr_core::std_extensions::collections::array::{self, array_type}; + use hugr_core::std_extensions::collections::value_array::{self, value_array_type}; use hugr_core::{ builder::{DFGBuilder, HugrBuilder}, extension::prelude::{either_type, option_type, usize_t, ConstUsize, UnwrapBuilder as _}, @@ -128,11 +132,11 @@ pub mod test { #[rstest::fixture] #[default(DFGBuilder)] - pub fn all_array_ops( + pub fn all_value_array_ops( #[default(DFGBuilder::new(Signature::new_endo(Type::EMPTY_TYPEROW) .with_extension_delta(ExtensionSet::from_iter([ PRELUDE_ID, - array::EXTENSION_ID + value_array::EXTENSION_ID ]))).unwrap())] mut builder: B, ) -> B { @@ -143,7 +147,7 @@ pub mod test { let [arr] = { let r = builder.add_array_swap(usize_t(), 2, arr, us0, us1).unwrap(); let res_sum_ty = { - let array_type = array_type(2, usize_t()); + let array_type = value_array_type(2, usize_t()); either_type(array_type.clone(), array_type) }; builder.build_unwrap_sum(1, res_sum_ty, r).unwrap() @@ -161,7 +165,7 @@ pub mod test { .add_array_set(usize_t(), 2, arr, us1, elem_0) .unwrap(); let res_sum_ty = { - let row = vec![usize_t(), array_type(2, usize_t())]; + let row = vec![usize_t(), value_array_type(2, usize_t())]; either_type(row.clone(), row) }; builder.build_unwrap_sum(1, res_sum_ty, r).unwrap() @@ -170,13 +174,21 @@ pub mod test { let [_elem_left, arr] = { let r = builder.add_array_pop_left(usize_t(), 2, arr).unwrap(); builder - .build_unwrap_sum(1, option_type(vec![usize_t(), array_type(1, usize_t())]), r) + .build_unwrap_sum( + 1, + option_type(vec![usize_t(), value_array_type(1, usize_t())]), + r, + ) .unwrap() }; let [_elem_right, arr] = { let r = builder.add_array_pop_right(usize_t(), 1, arr).unwrap(); builder - .build_unwrap_sum(1, option_type(vec![usize_t(), array_type(0, usize_t())]), r) + .build_unwrap_sum( + 1, + option_type(vec![usize_t(), value_array_type(0, usize_t())]), + r, + ) .unwrap() }; @@ -185,7 +197,7 @@ pub mod test { } #[rstest] - fn build_all_ops(all_array_ops: DFGBuilder) { - all_array_ops.finish_hugr().unwrap(); + fn build_all_ops(all_value_array_ops: DFGBuilder) { + all_value_array_ops.finish_hugr().unwrap(); } } diff --git a/hugr-passes/src/monomorphize.rs b/hugr-passes/src/monomorphize.rs index 4f4e9bda2..6d1132d0a 100644 --- a/hugr-passes/src/monomorphize.rs +++ b/hugr-passes/src/monomorphize.rs @@ -366,7 +366,8 @@ mod test { use hugr_core::extension::simple_op::MakeRegisteredOp as _; use hugr_core::std_extensions::collections; - use hugr_core::std_extensions::collections::array::{array_type_parametric, ArrayOpDef}; + use hugr_core::std_extensions::collections::array::ArrayKind; + use hugr_core::std_extensions::collections::value_array::{VArrayOpDef, ValueArray}; use hugr_core::types::type_param::TypeParam; use itertools::Itertools; @@ -521,26 +522,35 @@ mod test { let mut outer = FunctionBuilder::new( "mainish", prelusig( - array_type_parametric(sa(n), array_type_parametric(sa(2), usize_t()).unwrap()) - .unwrap(), + ValueArray::ty_parametric( + sa(n), + ValueArray::ty_parametric(sa(2), usize_t()).unwrap(), + ) + .unwrap(), vec![usize_t(); 2], ) - .with_extension_delta(collections::array::EXTENSION_ID), + .with_extension_delta(collections::value_array::EXTENSION_ID), ) .unwrap(); - let arr2u = || array_type_parametric(sa(2), usize_t()).unwrap(); + let arr2u = || ValueArray::ty_parametric(sa(2), usize_t()).unwrap(); let pf1t = PolyFuncType::new( [TypeParam::max_nat()], - prelusig(array_type_parametric(sv(0), arr2u()).unwrap(), usize_t()) - .with_extension_delta(collections::array::EXTENSION_ID), + prelusig( + ValueArray::ty_parametric(sv(0), arr2u()).unwrap(), + usize_t(), + ) + .with_extension_delta(collections::value_array::EXTENSION_ID), ); let mut pf1 = outer.define_function("pf1", pf1t).unwrap(); let pf2t = PolyFuncType::new( [TypeParam::max_nat(), TypeBound::Copyable.into()], - prelusig(vec![array_type_parametric(sv(0), tv(1)).unwrap()], tv(1)) - .with_extension_delta(collections::array::EXTENSION_ID), + prelusig( + vec![ValueArray::ty_parametric(sv(0), tv(1)).unwrap()], + tv(1), + ) + .with_extension_delta(collections::value_array::EXTENSION_ID), ); let mut pf2 = pf1.define_function("pf2", pf2t).unwrap(); @@ -549,7 +559,7 @@ mod test { .define_function( "get_usz", prelusig(vec![], usize_t()) - .with_extension_delta(collections::array::EXTENSION_ID), + .with_extension_delta(collections::value_array::EXTENSION_ID), ) .unwrap(); let cst0 = fb.add_load_value(ConstUsize::new(1)); @@ -558,7 +568,7 @@ mod test { let pf2 = { let [inw] = pf2.input_wires_arr(); let [idx] = pf2.call(mono_func.handle(), &[], []).unwrap().outputs_arr(); - let op_def = collections::array::EXTENSION.get_op("get").unwrap(); + let op_def = collections::value_array::EXTENSION.get_op("get").unwrap(); let op = hugr_core::ops::ExtensionOp::new(op_def.clone(), vec![sv(0), tv(1).into()]) .unwrap(); let [get] = pf2.add_dataflow_op(op, [inw, idx]).unwrap().outputs_arr(); @@ -584,7 +594,7 @@ mod test { .call(pf1.handle(), &[sa(n)], outer.input_wires()) .unwrap() .outputs_arr(); - let popleft = ArrayOpDef::pop_left.to_concrete(arr2u(), n); + let popleft = VArrayOpDef::pop_left.to_concrete(arr2u(), n); let ar2 = outer .add_dataflow_op(popleft.clone(), outer.input_wires()) .unwrap(); diff --git a/hugr-passes/src/replace_types.rs b/hugr-passes/src/replace_types.rs index 3ed7337a9..844eb528b 100644 --- a/hugr-passes/src/replace_types.rs +++ b/hugr-passes/src/replace_types.rs @@ -9,6 +9,7 @@ use std::sync::Arc; use handlers::list_const; use hugr_core::std_extensions::collections::array::array_type_def; use hugr_core::std_extensions::collections::list::list_type_def; +use hugr_core::std_extensions::collections::value_array::value_array_type_def; use thiserror::Error; use hugr_core::builder::{BuildError, BuildHandle, Dataflow}; @@ -151,6 +152,7 @@ impl Default for ReplaceTypes { let mut res = Self::new_empty(); res.linearize = DelegatingLinearizer::default(); res.replace_consts_parametrized(array_type_def(), handlers::array_const); + res.replace_consts_parametrized(value_array_type_def(), handlers::value_array_const); res.replace_consts_parametrized(list_type_def(), list_const); res } @@ -535,16 +537,18 @@ mod test { use hugr_core::extension::simple_op::MakeExtensionOp; use hugr_core::extension::{TypeDefBound, Version}; + use hugr_core::ops::constant::CustomConst; use hugr_core::ops::constant::OpaqueValue; use hugr_core::ops::{ExtensionOp, NamedOp, OpTrait, OpType, Tag, Value}; use hugr_core::std_extensions::arithmetic::int_types::ConstInt; use hugr_core::std_extensions::arithmetic::{conversions::ConvertOpDef, int_types::INT_TYPES}; - use hugr_core::std_extensions::collections::array::{ - array_type, array_type_def, ArrayOp, ArrayOpDef, ArrayValue, - }; + use hugr_core::std_extensions::collections::array::{Array, ArrayKind, GenericArrayValue}; use hugr_core::std_extensions::collections::list::{ list_type, list_type_def, ListOp, ListValue, }; + use hugr_core::std_extensions::collections::value_array::{ + value_array_type, VArrayOp, VArrayOpDef, VArrayValue, ValueArray, + }; use hugr_core::hugr::ValidationError; use hugr_core::types::{PolyFuncType, Signature, SumType, Type, TypeArg, TypeBound, TypeRow}; @@ -552,6 +556,7 @@ mod test { use itertools::Itertools; use rstest::rstest; + use crate::replace_types::handlers::generic_array_const; use crate::validation::ValidatePassError; use super::ReplaceTypesError; @@ -619,7 +624,7 @@ mod test { fn lowered_read(args: &[TypeArg]) -> Option { let ty = just_elem_type(args); let mut dfb = DFGBuilder::new(inout_sig( - vec![array_type(64, ty.clone()), i64_t()], + vec![value_array_type(64, ty.clone()), i64_t()], ty.clone(), )) .unwrap(); @@ -629,7 +634,7 @@ mod test { .unwrap() .outputs_arr(); let [opt] = dfb - .add_dataflow_op(ArrayOpDef::get.to_concrete(ty.clone(), 64), [val, idx]) + .add_dataflow_op(VArrayOpDef::get.to_concrete(ty.clone(), 64), [val, idx]) .unwrap() .outputs_arr(); let [res] = dfb @@ -644,7 +649,7 @@ mod test { lw.replace_type(pv.instantiate([bool_t().into()]).unwrap(), i64_t()); lw.replace_parametrized_type( pv, - Box::new(|args: &[TypeArg]| Some(array_type(64, just_elem_type(args).clone()))), + Box::new(|args: &[TypeArg]| Some(value_array_type(64, just_elem_type(args).clone()))), ); lw.replace_op( &read_op(ext, bool_t()), @@ -768,10 +773,10 @@ mod test { // The PackedVec> becomes an array let [array_get] = ext_ops .into_iter() - .filter_map(|e| ArrayOp::from_extension_op(e).ok()) + .filter_map(|e| VArrayOp::from_extension_op(e).ok()) .collect_array() .unwrap(); - assert_eq!(array_get, ArrayOpDef::get.to_concrete(i64_t(), 64)); + assert_eq!(array_get, VArrayOpDef::get.to_concrete(i64_t(), 64)); } #[test] @@ -801,7 +806,7 @@ mod test { // 1. Lower List to Array<10, T> UNLESS T is usize_t() or i64_t lowerer.replace_parametrized_type(list_type_def(), |args| { let ty = just_elem_type(args); - (![usize_t(), i64_t()].contains(ty)).then_some(array_type(10, ty.clone())) + (![usize_t(), i64_t()].contains(ty)).then_some(value_array_type(10, ty.clone())) }); { let mut h = backup.clone(); @@ -809,7 +814,7 @@ mod test { let sig = h.signature(h.root()).unwrap(); assert_eq!( sig.input(), - &TypeRow::from(vec![list_type(usize_t()), array_type(10, bool_t())]) + &TypeRow::from(vec![list_type(usize_t()), value_array_type(10, bool_t())]) ); assert_eq!(sig.input(), sig.output()); } @@ -831,7 +836,7 @@ mod test { let sig = h.signature(h.root()).unwrap(); assert_eq!( sig.input(), - &TypeRow::from(vec![list_type(i64_t()), array_type(10, bool_t())]) + &TypeRow::from(vec![list_type(i64_t()), value_array_type(10, bool_t())]) ); assert_eq!(sig.input(), sig.output()); // This will have to update inside the Const @@ -848,7 +853,7 @@ mod test { let mut h = backup; lowerer.replace_parametrized_type( list_type_def(), - Box::new(|args: &[TypeArg]| Some(array_type(4, just_elem_type(args).clone()))), + Box::new(|args: &[TypeArg]| Some(value_array_type(4, just_elem_type(args).clone()))), ); lowerer.replace_consts_parametrized(list_type_def(), |opaq, repl| { // First recursively transform the contents @@ -858,7 +863,7 @@ mod test { let lv = opaq.value().downcast_ref::().unwrap(); Ok(Some( - ArrayValue::new(lv.get_element_type().clone(), lv.get_contents().to_vec()).into(), + VArrayValue::new(lv.get_element_type().clone(), lv.get_contents().to_vec()).into(), )) }); lowerer.run(&mut h).unwrap(); @@ -867,7 +872,10 @@ mod test { h.get_optype(pred.node()) .as_load_constant() .map(|lc| lc.constant_type()), - Some(&Type::new_sum(vec![Type::from(array_type(4, i64_t())); 2])) + Some(&Type::new_sum(vec![ + Type::from(value_array_type(4, i64_t())); + 2 + ])) ); } @@ -956,17 +964,19 @@ mod test { } #[rstest] - #[case(&[])] - #[case(&[3])] - #[case(&[5,7,11,13,17,19])] - fn array_const(#[case] vals: &[u64]) { - use super::handlers::array_const; - let mut dfb = DFGBuilder::new(inout_sig( - type_row![], - array_type(vals.len() as _, usize_t()), - )) - .unwrap(); - let c = dfb.add_load_value(ArrayValue::new( + #[case(&[], Array)] + #[case(&[], ValueArray)] + #[case(&[3], Array)] + #[case(&[3], ValueArray)] + #[case(&[5,7,11,13,17,19], Array)] + #[case(&[5,7,11,13,17,19], ValueArray)] + fn array_const(#[case] vals: &[u64], #[case] _kind: AK) + where + GenericArrayValue: CustomConst, + { + let mut dfb = + DFGBuilder::new(inout_sig(type_row![], AK::ty(vals.len() as _, usize_t()))).unwrap(); + let c = dfb.add_load_value(GenericArrayValue::::new( usize_t(), vals.iter().map(|u| ConstUsize::new(*u).into()), )); @@ -984,7 +994,7 @@ mod test { err: ValidationError::IncompatiblePorts {from, to, ..}, .. })) if backup.get_optype(from).is_const() && to == c.node()) ); - repl.replace_consts_parametrized(array_type_def(), array_const); + repl.replace_consts_parametrized(AK::type_def(), generic_array_const::); let mut h = backup; repl.run(&mut h).unwrap(); // Includes validation } diff --git a/hugr-passes/src/replace_types/handlers.rs b/hugr-passes/src/replace_types/handlers.rs index e835a2d9b..8e6e9ada4 100644 --- a/hugr-passes/src/replace_types/handlers.rs +++ b/hugr-passes/src/replace_types/handlers.rs @@ -4,15 +4,17 @@ use hugr_core::builder::{endo_sig, inout_sig, DFGBuilder, Dataflow, DataflowHugr}; use hugr_core::extension::prelude::{option_type, UnwrapBuilder}; use hugr_core::extension::ExtensionSet; +use hugr_core::ops::constant::CustomConst; use hugr_core::ops::{constant::OpaqueValue, Value}; use hugr_core::ops::{OpTrait, OpType, Tag}; use hugr_core::std_extensions::arithmetic::conversions::ConvertOpDef; use hugr_core::std_extensions::arithmetic::int_ops::IntOpDef; use hugr_core::std_extensions::arithmetic::int_types::{ConstInt, INT_TYPES}; -use hugr_core::std_extensions::collections::array::{ - array_type, ArrayOpDef, ArrayRepeat, ArrayScan, ArrayValue, -}; +use hugr_core::std_extensions::collections::array::{Array, ArrayKind, GenericArrayValue}; use hugr_core::std_extensions::collections::list::ListValue; +use hugr_core::std_extensions::collections::value_array::{ + value_array_type, VArrayOpDef, VArrayRepeat, VArrayScan, ValueArray, +}; use hugr_core::types::{SumType, Transformable, Type, TypeArg}; use hugr_core::{type_row, Hugr, HugrView}; use itertools::Itertools; @@ -44,14 +46,17 @@ pub fn list_const( Ok(Some(ListValue::new(elem_t, vals).into())) } -/// Handler for [ArrayValue] constants that recursively +/// Handler for [GenericArrayValue] constants that recursively /// [ReplaceTypes::change_value]s the elements of the list. /// Included in [ReplaceTypes::default]. -pub fn array_const( +pub fn generic_array_const( val: &OpaqueValue, repl: &ReplaceTypes, -) -> Result, ReplaceTypesError> { - let Some(av) = val.value().downcast_ref::() else { +) -> Result, ReplaceTypesError> +where + GenericArrayValue: CustomConst, +{ + let Some(av) = val.value().downcast_ref::>() else { return Ok(None); }; let mut elem_t = av.get_element_type().clone(); @@ -64,18 +69,42 @@ pub fn array_const( for v in vals.iter_mut() { repl.change_value(v)?; } - Ok(Some(ArrayValue::new(elem_t, vals).into())) + Ok(Some(GenericArrayValue::::new(elem_t, vals).into())) +} + +/// Handler for [ArrayValue] constants that recursively +/// [ReplaceTypes::change_value]s the elements of the list. +/// Included in [ReplaceTypes::default]. +/// +/// [ArrayValue]: hugr_core::std_extensions::collections::array::ArrayValue +pub fn array_const( + val: &OpaqueValue, + repl: &ReplaceTypes, +) -> Result, ReplaceTypesError> { + generic_array_const::(val, repl) +} + +/// Handler for [VArrayValue] constants that recursively +/// [ReplaceTypes::change_value]s the elements of the list. +/// Included in [ReplaceTypes::default]. +/// +/// [VArrayValue]: hugr_core::std_extensions::collections::value_array::VArrayValue +pub fn value_array_const( + val: &OpaqueValue, + repl: &ReplaceTypes, +) -> Result, ReplaceTypesError> { + generic_array_const::(val, repl) } fn runtime_reqs(h: &Hugr) -> ExtensionSet { h.signature(h.root()).unwrap().runtime_reqs.clone() } -/// Handler for copying/discarding arrays if their elements have become linear. +/// Handler for copying/discarding value arrays if their elements have become linear. /// Included in [ReplaceTypes::default] and [DelegatingLinearizer::default]. /// /// [DelegatingLinearizer::default]: super::DelegatingLinearizer::default -pub fn linearize_array( +pub fn linearize_value_array( args: &[TypeArg], num_outports: usize, lin: &CallbackHandler, @@ -97,8 +126,8 @@ pub fn linearize_array( dfb.finish_hugr_with_outputs([ret]).unwrap() }; // Now array.scan that over the input array to get an array of unit (which can be discarded) - let array_scan = ArrayScan::new(ty.clone(), Type::UNIT, vec![], *n, runtime_reqs(&map_fn)); - let in_type = array_type(*n, ty.clone()); + let array_scan = VArrayScan::new(ty.clone(), Type::UNIT, vec![], *n, runtime_reqs(&map_fn)); + let in_type = value_array_type(*n, ty.clone()); return Ok(NodeTemplate::CompoundOp(Box::new({ let mut dfb = DFGBuilder::new(inout_sig(in_type, type_row![])).unwrap(); let [in_array] = dfb.input_wires_arr(); @@ -113,7 +142,7 @@ pub fn linearize_array( // The num_outports>1 case will simplify, and unify with the previous, when we have a // more general ArrayScan https://github.com/CQCL/hugr/issues/2041. In the meantime: let num_new = num_outports - 1; - let array_ty = array_type(*n, ty.clone()); + let array_ty = value_array_type(*n, ty.clone()); let mut dfb = DFGBuilder::new(inout_sig( array_ty.clone(), vec![array_ty.clone(); num_outports], @@ -132,7 +161,7 @@ pub fn linearize_array( dfb.finish_hugr_with_outputs(none.outputs()).unwrap() }; let repeats = - vec![ArrayRepeat::new(option_ty.clone(), *n, runtime_reqs(&fn_none)); num_new]; + vec![VArrayRepeat::new(option_ty.clone(), *n, runtime_reqs(&fn_none)); num_new]; let fn_none = dfb.add_load_value(Value::function(fn_none).unwrap()); repeats .into_iter() @@ -146,7 +175,7 @@ pub fn linearize_array( // 2. use a scan through the input array, copying the element num_outputs times; // return the first copy, and put each of the other copies into one of the array