Skip to content

Commit 0a01a2f

Browse files
authored
feat: Add array clone and discard ops (#2100)
1 parent 078dfc5 commit 0a01a2f

File tree

5 files changed

+758
-0
lines changed

5 files changed

+758
-0
lines changed

hugr-core/src/std_extensions/collections/array.rs

+14
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,7 @@
11
//! Fixed-length array type and operations extension.
22
3+
mod array_clone;
4+
mod array_discard;
35
mod array_kind;
46
mod array_op;
57
mod array_repeat;
@@ -20,6 +22,8 @@ use crate::types::type_param::{TypeArg, TypeParam};
2022
use crate::types::{CustomCheckFailure, Type, TypeBound, TypeName};
2123
use crate::Extension;
2224

25+
pub use array_clone::{GenericArrayClone, GenericArrayCloneDef, ARRAY_CLONE_OP_ID};
26+
pub use array_discard::{GenericArrayDiscard, GenericArrayDiscardDef, ARRAY_DISCARD_OP_ID};
2327
pub use array_kind::ArrayKind;
2428
pub use array_op::{GenericArrayOp, GenericArrayOpDef};
2529
pub use array_repeat::{GenericArrayRepeat, GenericArrayRepeatDef, ARRAY_REPEAT_OP_ID};
@@ -57,13 +61,21 @@ impl ArrayKind for Array {
5761

5862
/// Array operation definitions.
5963
pub type ArrayOpDef = GenericArrayOpDef<Array>;
64+
/// Array clone operation definition.
65+
pub type ArrayCloneDef = GenericArrayCloneDef<Array>;
66+
/// Array discard operation definition.
67+
pub type ArrayDiscardDef = GenericArrayDiscardDef<Array>;
6068
/// Array repeat operation definition.
6169
pub type ArrayRepeatDef = GenericArrayRepeatDef<Array>;
6270
/// Array scan operation definition.
6371
pub type ArrayScanDef = GenericArrayScanDef<Array>;
6472

6573
/// Array operations.
6674
pub type ArrayOp = GenericArrayOp<Array>;
75+
/// The array clone operation.
76+
pub type ArrayClone = GenericArrayClone<Array>;
77+
/// The array discard operation.
78+
pub type ArrayDiscard = GenericArrayDiscard<Array>;
6779
/// The array repeat operation.
6880
pub type ArrayRepeat = GenericArrayRepeat<Array>;
6981
/// The array scan operation.
@@ -87,6 +99,8 @@ lazy_static! {
8799
.unwrap();
88100

89101
ArrayOpDef::load_all_ops(extension, extension_ref).unwrap();
102+
ArrayCloneDef::new().add_to_extension(extension, extension_ref).unwrap();
103+
ArrayDiscardDef::new().add_to_extension(extension, extension_ref).unwrap();
90104
ArrayRepeatDef::new().add_to_extension(extension, extension_ref).unwrap();
91105
ArrayScanDef::new().add_to_extension(extension, extension_ref).unwrap();
92106
})
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,234 @@
1+
//! Definition of the array clone operation.
2+
3+
use std::marker::PhantomData;
4+
use std::str::FromStr;
5+
use std::sync::{Arc, Weak};
6+
7+
use crate::extension::simple_op::{
8+
HasConcrete, HasDef, MakeExtensionOp, MakeOpDef, MakeRegisteredOp, OpLoadError,
9+
};
10+
use crate::extension::{ExtensionId, OpDef, SignatureError, SignatureFunc, TypeDef};
11+
use crate::ops::{ExtensionOp, NamedOp, OpName};
12+
use crate::types::type_param::{TypeArg, TypeParam};
13+
use crate::types::{FuncValueType, PolyFuncTypeRV, Type, TypeBound};
14+
use crate::Extension;
15+
16+
use super::array_kind::ArrayKind;
17+
18+
/// Name of the operation to clone an array
19+
pub const ARRAY_CLONE_OP_ID: OpName = OpName::new_inline("clone");
20+
21+
/// Definition of the array clone operation. Generic over the concrete array implementation.
22+
#[derive(Clone, Copy, Debug, Hash, PartialEq, Eq)]
23+
pub struct GenericArrayCloneDef<AK: ArrayKind>(PhantomData<AK>);
24+
25+
impl<AK: ArrayKind> GenericArrayCloneDef<AK> {
26+
/// Creates a new clone operation definition.
27+
pub fn new() -> Self {
28+
GenericArrayCloneDef(PhantomData)
29+
}
30+
}
31+
32+
impl<AK: ArrayKind> Default for GenericArrayCloneDef<AK> {
33+
fn default() -> Self {
34+
Self::new()
35+
}
36+
}
37+
38+
impl<AK: ArrayKind> NamedOp for GenericArrayCloneDef<AK> {
39+
fn name(&self) -> OpName {
40+
ARRAY_CLONE_OP_ID
41+
}
42+
}
43+
44+
impl<AK: ArrayKind> FromStr for GenericArrayCloneDef<AK> {
45+
type Err = ();
46+
47+
fn from_str(s: &str) -> Result<Self, Self::Err> {
48+
if s == ARRAY_CLONE_OP_ID {
49+
Ok(GenericArrayCloneDef::new())
50+
} else {
51+
Err(())
52+
}
53+
}
54+
}
55+
56+
impl<AK: ArrayKind> GenericArrayCloneDef<AK> {
57+
/// To avoid recursion when defining the extension, take the type definition as an argument.
58+
fn signature_from_def(&self, array_def: &TypeDef) -> SignatureFunc {
59+
let params = vec![TypeParam::max_nat(), TypeBound::Copyable.into()];
60+
let size = TypeArg::new_var_use(0, TypeParam::max_nat());
61+
let element_ty = Type::new_var_use(1, TypeBound::Copyable);
62+
let array_ty = AK::instantiate_ty(array_def, size, element_ty)
63+
.expect("Array type instantiation failed");
64+
PolyFuncTypeRV::new(
65+
params,
66+
FuncValueType::new(array_ty.clone(), vec![array_ty; 2]),
67+
)
68+
.into()
69+
}
70+
}
71+
72+
impl<AK: ArrayKind> MakeOpDef for GenericArrayCloneDef<AK> {
73+
fn from_def(op_def: &OpDef) -> Result<Self, OpLoadError>
74+
where
75+
Self: Sized,
76+
{
77+
crate::extension::simple_op::try_from_name(op_def.name(), op_def.extension_id())
78+
}
79+
80+
fn init_signature(&self, _extension_ref: &Weak<Extension>) -> SignatureFunc {
81+
self.signature_from_def(AK::type_def())
82+
}
83+
84+
fn extension_ref(&self) -> Weak<Extension> {
85+
Arc::downgrade(AK::extension())
86+
}
87+
88+
fn extension(&self) -> ExtensionId {
89+
AK::EXTENSION_ID
90+
}
91+
92+
fn description(&self) -> String {
93+
"Clones an array with copyable elements".into()
94+
}
95+
96+
/// Add an operation implemented as a [MakeOpDef], which can provide the data
97+
/// required to define an [OpDef], to an extension.
98+
//
99+
// This method is re-defined here since we need to pass the array type def while
100+
// computing the signature, to avoid recursive loops initializing the extension.
101+
fn add_to_extension(
102+
&self,
103+
extension: &mut Extension,
104+
extension_ref: &Weak<Extension>,
105+
) -> Result<(), crate::extension::ExtensionBuildError> {
106+
let sig = self.signature_from_def(extension.get_type(&AK::TYPE_NAME).unwrap());
107+
let def = extension.add_op(self.name(), self.description(), sig, extension_ref)?;
108+
self.post_opdef(def);
109+
Ok(())
110+
}
111+
}
112+
113+
/// Definition of the array clone op. Generic over the concrete array implementation.
114+
#[derive(Clone, Debug, PartialEq)]
115+
pub struct GenericArrayClone<AK: ArrayKind> {
116+
/// The element type of the array.
117+
pub elem_ty: Type,
118+
/// Size of the array.
119+
pub size: u64,
120+
_kind: PhantomData<AK>,
121+
}
122+
123+
impl<AK: ArrayKind> GenericArrayClone<AK> {
124+
/// Creates a new array clone op.
125+
///
126+
/// # Errors
127+
///
128+
/// If the provided element type is not copyable.
129+
pub fn new(elem_ty: Type, size: u64) -> Result<Self, OpLoadError> {
130+
elem_ty
131+
.copyable()
132+
.then_some(GenericArrayClone {
133+
elem_ty,
134+
size,
135+
_kind: PhantomData,
136+
})
137+
.ok_or(SignatureError::InvalidTypeArgs.into())
138+
}
139+
}
140+
141+
impl<AK: ArrayKind> NamedOp for GenericArrayClone<AK> {
142+
fn name(&self) -> OpName {
143+
ARRAY_CLONE_OP_ID
144+
}
145+
}
146+
147+
impl<AK: ArrayKind> MakeExtensionOp for GenericArrayClone<AK> {
148+
fn from_extension_op(ext_op: &ExtensionOp) -> Result<Self, OpLoadError>
149+
where
150+
Self: Sized,
151+
{
152+
let def = GenericArrayCloneDef::<AK>::from_def(ext_op.def())?;
153+
def.instantiate(ext_op.args())
154+
}
155+
156+
fn type_args(&self) -> Vec<TypeArg> {
157+
vec![
158+
TypeArg::BoundedNat { n: self.size },
159+
self.elem_ty.clone().into(),
160+
]
161+
}
162+
}
163+
164+
impl<AK: ArrayKind> MakeRegisteredOp for GenericArrayClone<AK> {
165+
fn extension_id(&self) -> ExtensionId {
166+
AK::EXTENSION_ID
167+
}
168+
169+
fn extension_ref(&self) -> Weak<Extension> {
170+
Arc::downgrade(AK::extension())
171+
}
172+
}
173+
174+
impl<AK: ArrayKind> HasDef for GenericArrayClone<AK> {
175+
type Def = GenericArrayCloneDef<AK>;
176+
}
177+
178+
impl<AK: ArrayKind> HasConcrete for GenericArrayCloneDef<AK> {
179+
type Concrete = GenericArrayClone<AK>;
180+
181+
fn instantiate(&self, type_args: &[TypeArg]) -> Result<Self::Concrete, OpLoadError> {
182+
match type_args {
183+
[TypeArg::BoundedNat { n }, TypeArg::Type { ty }] if ty.copyable() => {
184+
Ok(GenericArrayClone::new(ty.clone(), *n).unwrap())
185+
}
186+
_ => Err(SignatureError::InvalidTypeArgs.into()),
187+
}
188+
}
189+
}
190+
191+
#[cfg(test)]
192+
mod tests {
193+
use rstest::rstest;
194+
195+
use crate::extension::prelude::bool_t;
196+
use crate::std_extensions::collections::array::Array;
197+
use crate::{
198+
extension::prelude::qb_t,
199+
ops::{OpTrait, OpType},
200+
};
201+
202+
use super::*;
203+
204+
#[rstest]
205+
#[case(Array)]
206+
fn test_clone_def<AK: ArrayKind>(#[case] _kind: AK) {
207+
let op = GenericArrayClone::<AK>::new(bool_t(), 2).unwrap();
208+
let optype: OpType = op.clone().into();
209+
let new_op: GenericArrayClone<AK> = optype.cast().unwrap();
210+
assert_eq!(new_op, op);
211+
212+
assert_eq!(
213+
GenericArrayClone::<AK>::new(qb_t(), 2),
214+
Err(OpLoadError::InvalidArgs(SignatureError::InvalidTypeArgs))
215+
);
216+
}
217+
218+
#[rstest]
219+
#[case(Array)]
220+
fn test_clone<AK: ArrayKind>(#[case] _kind: AK) {
221+
let size = 2;
222+
let element_ty = bool_t();
223+
let op = GenericArrayClone::<AK>::new(element_ty.clone(), size).unwrap();
224+
let optype: OpType = op.into();
225+
let sig = optype.dataflow_signature().unwrap();
226+
assert_eq!(
227+
sig.io(),
228+
(
229+
&vec![AK::ty(size, element_ty.clone())].into(),
230+
&vec![AK::ty(size, element_ty.clone()); 2].into(),
231+
)
232+
);
233+
}
234+
}

0 commit comments

Comments
 (0)