Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 3 additions & 1 deletion Project.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
name = "TensorAlgebra"
uuid = "68bd88dc-f39d-4e12-b2ca-f046b68fcc6a"
version = "0.9.1"
version = "0.9.2"
authors = ["ITensor developers <support@itensor.org> and contributors"]

[workspace]
Expand All @@ -11,6 +11,7 @@ EllipsisNotation = "da5c29d0-fa7d-589e-88eb-ea29b0a81949"
FunctionImplementations = "7c7cc465-9c6a-495f-bdd1-f42428e86d0c"
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
MatrixAlgebraKit = "6c742aac-3347-4629-af66-fc926824e5e4"
Strided = "5e0ebb24-38b0-5f93-81fe-25c709ecae67"
StridedViews = "4db3bf67-4bd7-4b4e-b153-31dc3fb37143"
TupleTools = "9d95972d-f1c8-5527-a6e0-b4b365fa01f6"
TypeParameterAccessors = "7e5a90cf-f82e-492e-a09b-e3e26432c138"
Expand All @@ -35,6 +36,7 @@ GPUArraysCore = "0.2"
LinearAlgebra = "1.10"
MatrixAlgebraKit = "0.2, 0.3, 0.4, 0.5, 0.6"
Mooncake = "0.4.202, 0.5"
Strided = "2.3.5"
StridedViews = "0.4.1, 0.5"
TensorOperations = "5"
TupleTools = "1.6"
Expand Down
22 changes: 13 additions & 9 deletions src/permutedimsadd.jl
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import StridedViews as SV
using FunctionImplementations: permuteddims
using Strided: Strided

# Specify if an array is on CPU. This is helpful for backends that don't support
# operations on GPU, such as Strided.jl.
Expand Down Expand Up @@ -50,19 +51,22 @@ function bipermutedimsopadd!(
perm = (perm_codomain..., perm_domain...)
check_input(bipermutedimsopadd!, dest, src, perm_codomain, perm_domain)

# TODO: Remove this 0-dimensional special case once GradedArray is its own type
# (not an alias for BlockSparseArray), so the GradedArray overload catches the
# 0-dimensional contraction result.
# 0-dim short-circuit: avoid the permute-broadcast path entirely so that
# downstream array types (e.g. `BlockSparseArray{T, 0}`) don't have to define
# `getindex` on a 0-dim `PermutedDimsArray` wrapper around them.
# The `iszero(β)` guard follows the BLAS convention that `β = 0` means `dest`
# is write-only — its slot need not be defined. This matters for element types
# whose `undef` storage is unreadable, e.g. `Array{BigFloat, 0}(undef)[]` throws
# `UndefRefError`.
if iszero(ndims(dest))
dest[] = β * dest[] + α * op(src[])
if iszero(β)
dest[] = α * op(src[])
else
dest[] = β * dest[] + α * op(src[])
end
return dest
end

# This works around a bug in Strided.jl v2.3.4 and below when broadcasting
# empty StridedViews: https://github.com/QuantumKitHub/Strided.jl/pull/50
# TODO: Delete this and bump the version of Strided.jl once that is fixed.
isempty(dest) && return dest

dest′, src′ = maybestrided(dest, permuteddims(src, perm))
if op === identity
if iszero(β)
Expand Down
22 changes: 21 additions & 1 deletion test/test_permutedimsadd.jl
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
using Adapt: adapt
using JLArrays: JLArray
using TensorAlgebra: add!, permutedimsadd!, permutedimsopadd!
using TensorAlgebra: add!, bipermutedimsopadd!, permutedimsadd!, permutedimsopadd!
using Test: @test, @testset

@testset "[permutedims]add!" begin
Expand Down Expand Up @@ -47,6 +47,26 @@ using Test: @test, @testset
@test b′ ≈ β * b + α * permutedims(a, perm)
end
end
@testset "bipermutedimsopadd! 0-dim with β=0 must not read dest (eltype=$T)" for T in
(
Float64,
BigFloat,
)
# With β=0, `dest` is write-only by BLAS convention; its contents need not be
# defined. For element types whose `undef` storage is unreadable (e.g. mutable
# `BigFloat`), reading the slot would throw `UndefRefError`.
src = fill(T(7))
for op in (identity, conj)
dest = Array{T, 0}(undef)
bipermutedimsopadd!(dest, op, src, (), (), true, false)
@test dest[] == op(src[])
end
# With β nonzero, both reads and writes go through with the accumulating
# semantics `dest = β * dest + α * op(src)`.
dest = fill(T(2))
bipermutedimsopadd!(dest, identity, src, (), (), T(3), T(5))
@test dest[] == 3 * 7 + 5 * 2
end
@testset "permutedimsopadd! (arraytype=$arrayt)" for arrayt in (Array,)
dev = adapt(arrayt)
a = dev(randn(ComplexF64, 2, 2, 2))
Expand Down
Loading