Skip to content

Commit f03b5d2

Browse files
committed
A few more specializations for truncation and diagonal
1 parent 677d141 commit f03b5d2

2 files changed

Lines changed: 47 additions & 0 deletions

File tree

src/linalg/factorizations.jl

Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -188,3 +188,30 @@ for f! in (
188188
@eval MAK.$f!(::AbstractBlockTensorMap, x, ::DiagonalAlgorithm) =
189189
error("Blocktensors are incompatible with diagonal algorithm")
190190
end
191+
192+
function TensorKit.Factorizations.truncate_domain!(tdst::AbstractBlockTensorMap, tsrc::AbstractBlockTensorMap, inds)
193+
TensorKit.foreachblock(tdst, tsrc) do c, (dst_block, src_block)
194+
I = get(inds, c, nothing)
195+
dst_dense = copy_dense!(similar_dense(dst_block), dst_block)
196+
src_dense = copy_dense!(similar_dense(src_block), src_block)
197+
@assert !isnothing(I)
198+
@views dst_dense .= src_dense[:, I]
199+
# deal with the case where the output is not in-place
200+
dst_dense === dst_block || copyto!(dst_block, dst_dense)
201+
return nothing
202+
end
203+
return tdst
204+
end
205+
function TensorKit.Factorizations.truncate_codomain!(tdst::AbstractBlockTensorMap, tsrc::AbstractBlockTensorMap, inds)
206+
TensorKit.foreachblock(tdst, tsrc) do c, (dst_block, src_block)
207+
I = get(inds, c, nothing)
208+
dst_dense = copy_dense!(similar_dense(dst_block), dst_block)
209+
src_dense = copy_dense!(similar_dense(src_block), src_block)
210+
@assert !isnothing(I)
211+
@views dst_dense .= src_dense[I, :]
212+
# deal with the case where the output is not in-place
213+
dst_dense === dst_block || copyto!(dst_block, dst_dense)
214+
return nothing
215+
end
216+
return tdst
217+
end

src/linalg/linalg.jl

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -239,3 +239,23 @@ function LinearAlgebra.isposdef!(t::AbstractBlockTensorMap)
239239
end
240240
return true
241241
end
242+
243+
function LinearAlgebra.lmul!(D::DiagonalTensorMap, t::AbstractBlockTensorMap)
244+
domain(D) == codomain(t) || throw(SpaceMismatch())
245+
TensorKit.foreachblock(t, D) do c, (tblock, bs...)
246+
tblock′ = lmul!(bs..., copy_dense!(similar_dense(tblock), tblock))
247+
tblock === tblock′ || copyto!(tblock, tblock′)
248+
return tblock
249+
end
250+
return t
251+
end
252+
253+
function LinearAlgebra.rmul!(t::AbstractBlockTensorMap, D::DiagonalTensorMap)
254+
codomain(D) == domain(t) || throw(SpaceMismatch())
255+
TensorKit.foreachblock(t, D) do c, (tblock, bs...)
256+
tblock′ = rmul!(copy_dense!(similar_dense(tblock), tblock), bs...)
257+
tblock === tblock′ || copyto!(tblock, tblock′)
258+
return tblock
259+
end
260+
return t
261+
end

0 commit comments

Comments
 (0)