Merged
36 changes: 20 additions & 16 deletions src/algorithms.jl
@@ -208,16 +208,19 @@ abstract type ProjectionAlgorithm end
if als.check isa FitCheck
if als.check.iter == 0
println("Warning: FitCheck is not enabled for $(als.mttkrp_alg) will run $(als.check.max_counter) iterations.")
+if verbose
+println("Warning: Sampled fit will be provided")
+end
end
als.check.iter += 1
if verbose
-inner_prod = real((had_contract([als.target, dag.(factors)...], cprank) * dag(λ))[])
-partial_gram = [fact * dag(prime(fact; tags=tags(cprank))) for fact in factors];
-fact_square = ITensorCPD.norm_factors(partial_gram, λ)
-normResidual =
-sqrt(abs(als.check.ref_norm * als.check.ref_norm + fact_square - 2 * abs(inner_prod)))
-elt = typeof(inner_prod)
-println("$(dim(cprank))\t$(als.check.iter)\t$(one(elt) - normResidual / norm(als.check.ref_norm))")
+cpd = CPD{ITensor}(factors, λ)
+krpproj = had_contract(factors[1], λ, cprank) * compute_krp(als.mttkrp_alg, als, factors, cpd, cprank, 1)
+tproj = matricize_tensor(als.mttkrp_alg, als, factors, cpd, cprank, 1)
+
+cpfit = one(real(eltype(krpproj))) - norm(tproj - krpproj) / norm(tproj)
+
+println("$(dim(cprank))\t$(als.check.iter)\t$(cpfit)")
end
if als.check.iter == als.check.max_counter
als.check.iter = 0
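Note: the verbose branch above now reports a sampled fit rather than the exact FitCheck residual: it compares the sampled matricization of the target (tproj) against the factor-times-Khatri-Rao product on the same samples (krpproj). A minimal sketch of that quantity, under the assumption that both operands are plain dense matrices (the real code works on ITensors):

```julia
using LinearAlgebra

# Relative fit between a sampled matricization of the target (tproj) and
# the corresponding factor * Khatri-Rao product (krpproj); 1 means exact.
sampled_fit(tproj, krpproj) =
    one(real(eltype(krpproj))) - norm(tproj - krpproj) / norm(tproj)

tproj = rand(6, 4)
krpproj = tproj .+ 0.01 .* randn(6, 4)  # a near-exact reconstruction
sampled_fit(tproj, krpproj)             # close to 1 when the CPD fits well
```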
@@ -258,14 +261,14 @@ abstract type ProjectionAlgorithm end
sampled_cols = sample_factor_matrices(nsamps, fact, als.additional_items[:factor_weights])
## Write new samples to pivot tensor
dRis = dims(inds(cp)[1:end .!= fact])
-data(als.additional_items[:pivot_tensors][fact]) .= multi_coords_to_column(dRis, sampled_cols)
+data(als.additional_items[:projects_tensors][fact]) .= multi_coords_to_column(dRis, sampled_cols)

-return pivot_hadamard(factors, rank, sampled_cols, inds(als.additional_items[:pivot_tensors][fact])[end])
+return pivot_hadamard(factors, rank, sampled_cols, inds(als.additional_items[:projects_tensors][fact])[end])
end

function matricize_tensor(::LevScoreSampled, als, factors, cp, rank::Index, fact::Int)
## I need to turn this into an ITensor and then pass it to the computed algorithm.
-return fused_flatten_sample(als.target, fact, als.additional_items[:pivot_tensors][fact])
+return fused_flatten_sample(als.target, fact, als.additional_items[:projects_tensors][fact])
end
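Note: the sampled solver funnels its draws through multi_coords_to_column, which appears to fuse one coordinate per non-solved mode into a single linear column index of the matricized target (the block variant below uses the same helper). Its implementation is not shown in this diff; a hypothetical sketch, assuming column-major fusion as in Julia's LinearIndices:

```julia
# Hypothetical sketch: dRis holds the dimensions of every mode except the
# factor being solved for; coords is one sampled coordinate per such mode.
# Whether the package fuses column-major like this is an assumption.
multi_coords_to_column_sketch(dRis, coords) = LinearIndices(dRis)[coords...]

multi_coords_to_column_sketch((20, 30), (3, 5))  # 3 + (5 - 1) * 20 == 83
```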


@@ -288,6 +291,7 @@ abstract type ProjectionAlgorithm end
BlockLevScoreSampled(n::Int, m::Int) = BlockLevScoreSampled((n,), (m,))
BlockLevScoreSampled(n::Tuple) = BlockLevScoreSampled{n, (1,)}()
BlockLevScoreSampled(n::Int, m::Tuple) = BlockLevScoreSampled((n,), m)
+BlockLevScoreSampled(n::Tuple, m::Int) = BlockLevScoreSampled(n, (m,))

nsamples(alg::BlockLevScoreSampled) = alg.NSamples
blocks(alg::BlockLevScoreSampled) = alg.Blocks
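Note: the added method completes the Int/Tuple promotion square, so every argument combination reaches the canonical Tuple/Tuple constructor. A self-contained sketch of the pattern (the struct here is a stand-in, not the package's definition):

```julia
# Every Int/Tuple combination funnels into the Tuple/Tuple inner
# constructor, so Sampled(100, 3) and Sampled((100,), (3,)) are the same.
struct Sampled{N,B}
    Sampled(n::Tuple, m::Tuple) = new{n,m}()
end
Sampled(n::Int, m::Int) = Sampled((n,), (m,))
Sampled(n::Int, m::Tuple) = Sampled((n,), m)
Sampled(n::Tuple, m::Int) = Sampled(n, (m,))  # the case this diff adds

Sampled(100, 3) === Sampled((100,), (3,))  # true: same singleton type
```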
@@ -302,14 +306,14 @@ abstract type ProjectionAlgorithm end
sampled_cols = block_sample_factor_matrices(nsamps, als.additional_items[:factor_weights], block_size, fact)
## Write new samples to pivot tensor
dRis = dims(inds(cp)[1:end .!= fact])
-data(als.additional_items[:pivot_tensors][fact]) .= multi_coords_to_column(dRis, sampled_cols)
+data(als.additional_items[:projects_tensors][fact]) .= multi_coords_to_column(dRis, sampled_cols)

-return pivot_hadamard(factors, rank, sampled_cols, inds(als.additional_items[:pivot_tensors][fact])[end])
+return pivot_hadamard(factors, rank, sampled_cols, inds(als.additional_items[:projects_tensors][fact])[end])
end

function matricize_tensor(::BlockLevScoreSampled, als, factors, cp, rank::Index, fact::Int)
## I need to turn this into an ITensor and then pass it to the computed algorithm.
-return fused_flatten_sample(als.target, fact, als.additional_items[:pivot_tensors][fact])
+return fused_flatten_sample(als.target, fact, als.additional_items[:projects_tensors][fact])
end


@@ -376,7 +380,7 @@ abstract type ProjectionAlgorithm end
## need to recompute the QR. This is a "dumb" algorithm because it resamples the full
## target tensor so a future algorithm should just modify the target to reduce the amount of work.
## reshuffle redoes the sampling of the pivots beyond the rank of the matrix.
-function update_samples(als, new_num_end; reshuffle = false, new_num_start = 0)
+function update_samples(target, als, new_num_end; reshuffle = false, new_num_start = 0)
@assert(als.mttkrp_alg isa PivotBasedSolvers)

## Make an updated alg with correct new range
@@ -391,7 +395,7 @@ abstract type ProjectionAlgorithm end
## This is reshuffling the indices
if reshuffle
p1 = p[1:meff]
-p_rest = p[m+1:end]
+p_rest = p[meff+1:end]
p2 = p_rest[randperm(length(p_rest))]
pshuff = vcat(p1, p2)

@@ -415,7 +419,7 @@ abstract type ProjectionAlgorithm end

projectors[pos] = itensor(tensor(Diag(pivots[pos][int_start:int_end]), (Ris..., piv_id)))

-targets[pos] = fused_flatten_sample(als.target, pos, projectors[pos])
+targets[pos] = fused_flatten_sample(target, pos, projectors[pos])
end

extra_args = Dict(
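Note: the one-character fix above (p[meff+1:end] instead of p[m+1:end]) matters whenever the effective rank meff is smaller than m: reshuffling must keep the first meff pivots in place and permute only the tail, otherwise the pivots between meff and m would be silently dropped from the permutation. A minimal sketch of the intended behavior:

```julia
using Random

# Keep the first meff pivots (those inside the effective rank) in order
# and randomly permute only the remainder of the pivot list.
function reshuffle_pivots(p::Vector{Int}, meff::Int)
    p1 = p[1:meff]
    p_rest = p[(meff + 1):end]
    return vcat(p1, p_rest[randperm(length(p_rest))])
end

reshuffle_pivots(collect(1:10), 4)  # first four entries stay 1, 2, 3, 4
```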
2 changes: 2 additions & 0 deletions src/converge_checks/converge_checks.jl
@@ -10,6 +10,8 @@ function norm_factors(partial_gram::Vector, λ::ITensor)
return real(had*(λ*dag(prime(λ))))[]
end

+save_mttkrp(::ConvergeAlg, ::ITensor) = nothing
+
include("no_check.jl")
include("fit_check.jl")
include("cp_diff_check.jl")
4 changes: 1 addition & 3 deletions src/converge_checks/cp_angle_check.jl
@@ -69,6 +69,4 @@ function check_converge(check::CPAngleCheck, factors, λ, partial_gram; verbose
return false
end

-CPDFit(check::CPAngleCheck) = check.final_fit
-
-function save_mttkrp(::CPAngleCheck, ::ITensor) end
+CPDFit(check::CPAngleCheck) = check.final_fit
4 changes: 1 addition & 3 deletions src/converge_checks/cp_diff_check.jl
@@ -67,6 +67,4 @@ function check_converge(check::CPDiffCheck, factors, λ, partial_gram; verbose =
return false
end

-CPDFit(check::CPDiffCheck) = check.final_fit
-
-function save_mttkrp(::CPDiffCheck, ::ITensor) end
+CPDFit(check::CPDiffCheck) = check.final_fit
2 changes: 0 additions & 2 deletions src/converge_checks/no_check.jl
@@ -17,5 +17,3 @@ function check_converge(check::NoCheck, factors, λ, partial_gram; verbose = fal
check.counter += 1
return false
end

-function save_mttkrp(::NoCheck, ::ITensor) end
4 changes: 2 additions & 2 deletions src/math_tools/probability.jl
@@ -4,9 +4,9 @@ function compute_leverage_score_probabilitiy(A, row::Index)
## This only works on matrices for now.
@assert ndims(A) == 2
q, _ = qr(A, row)
-ITensors.hadamard_product!(q, q, q)
+ITensors.hadamard_product!(q, q, dag(q))
ni = dim(q, 1)
-return [sum(array(q)[i,:]) for i in 1:ni] ./ minimum(dims(A))
+return [real(sum(array(q)[i,:])) for i in 1:ni] ./ minimum(dims(A))
end

function samples_from_probability_vector(PW::Vector, samples)
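Note: both fixes in this file are about complex-valued factors: the leverage score of row i is the squared row norm of Q from A = QR, so the Hadamard product needs q .* conj(q) (hence dag(q)), and the summed score is wrapped in real to strip the zero imaginary residue. A sketch of the same computation with a plain matrix and LinearAlgebra:

```julia
using LinearAlgebra

# Leverage-score probabilities of the rows of a (possibly complex) tall
# matrix: with A = QR, score(i) = ||Q[i, :]||^2 / ncols(Q), and the
# scores sum to 1 because the columns of Q are orthonormal.
function leverage_scores_sketch(A::AbstractMatrix)
    Q = Matrix(qr(A).Q)  # thin Q factor
    return [real(sum(abs2, @view Q[i, :])) for i in axes(Q, 1)] ./ size(Q, 2)
end

p = leverage_scores_sketch(randn(ComplexF64, 20, 4))
sum(p) ≈ 1  # true: a valid probability vector
```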
16 changes: 8 additions & 8 deletions src/optimizers/ALS/als.jl
@@ -233,7 +233,7 @@ function compute_als(
extra_args[:qr_factors] = qr_factors
extra_args[:effective_ranks] = effective_ranks

-return ALS(target, alg, extra_args, check)
+return ALS(ITensor(inds(target)), alg, extra_args, check)
end

function compute_als(
@@ -317,7 +317,7 @@ function compute_als(
extra_args[:qr_factors] = qr_factors
extra_args[:effective_ranks] = effective_ranks

-return ALS(target, alg, extra_args, check)
+return ALS(ITensor(inds(target)), alg, extra_args, check)
end

function compute_als(
@@ -362,7 +362,7 @@ function compute_als(
)
## For each factor matrix compute its weights
extra_args[:factor_weights] = [compute_leverage_score_probabilitiy(cp[i], ind(cp, i)) for i in 1:length(cp)]
-pivot_tensors = Vector{ITensor}()
+projects_tensors = Vector{ITensor}()
for fact in 1:length(cp)
## grab the tensor indices for all other factors but fact
Ris = inds(cp)[1:end .!= fact]
@@ -380,9 +380,9 @@ function compute_als(
piv_ind = Index(length(sampled_tensor_cols), "selector_$(fact)")

## make the canonical pivot tensor. This list of pivots will be overwritten each ALS iteration
-push!(pivot_tensors, itensor(tensor(Diag(sampled_tensor_cols), (Ris..., piv_ind))))
+push!(projects_tensors, itensor(tensor(Diag(sampled_tensor_cols), (Ris..., piv_ind))))
end
-extra_args[:pivot_tensors] = pivot_tensors
+extra_args[:projects_tensors] = projects_tensors
return ALS(target, alg, extra_args, check)
end

@@ -396,7 +396,7 @@ function compute_als(
)
## For each factor matrix compute its weights
extra_args[:factor_weights] = [compute_leverage_score_probabilitiy(cp[i], ind(cp, i)) for i in 1:length(cp)]
-pivot_tensors = Vector{ITensor}()
+projects_tensors = Vector{ITensor}()
for fact in 1:length(cp)
## grab the tensor indices for all other factors but fact
Ris = inds(cp)[1:end .!= fact]
@@ -417,11 +417,11 @@ function compute_als(
piv_ind = Index(length(sampled_tensor_cols), "selector_$(fact)")

## make the canonical pivot tensor. This list of pivots will be overwritten each ALS iteration
-push!(pivot_tensors, itensor(tensor(Diag(sampled_tensor_cols), (Ris..., piv_ind))))
+push!(projects_tensors, itensor(tensor(Diag(sampled_tensor_cols), (Ris..., piv_ind))))
end
## Notice the pivot tensor is actually a low rank tensor: it stores the diagonal pivot values
## in α form (rows of the matricized tensor) and the indices which are captured in the pivot.
## The order of indices are (indices which connect to the pivot, pivot_index).
-extra_args[:pivot_tensors] = pivot_tensors
+extra_args[:projects_tensors] = projects_tensors
return ALS(target, alg, extra_args, check)
end
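Note the pattern across this file: the two QR-pivot compute_als methods now hand ALS a data-free placeholder, ITensor(inds(target)), instead of the target itself, which is why update_samples in src/algorithms.jl grew an explicit target argument. Usage after this change, mirroring the updated tests (A, cp_A, and check as set up in test/standard_cpd.jl):

```julia
alg = ITensorCPD.QRPivProjected(800)
als = ITensorCPD.compute_als(A, cp_A; alg, check, trunc_tol = 4)

# The target is now passed explicitly, since `als` no longer stores it.
als = ITensorCPD.update_samples(A, als, 900; reshuffle = false)
```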
32 changes: 29 additions & 3 deletions test/standard_cpd.jl
@@ -1,5 +1,4 @@
@testset "Standard CPD-ALS, elt=$elt" for elt in [Float64, ComplexF64]
-elt = Float64
verbose = false
i, j, k = Index.((20, 30, 40))
r = Index(400, "CP_rank")
@@ -53,7 +52,7 @@
alg = ITensorCPD.QRPivProjected(800)
als = ITensorCPD.compute_als(A, cp_A; alg, check, trunc_tol=4);

-als = ITensorCPD.update_samples(als, 900; reshuffle = false);
+als = ITensorCPD.update_samples(A, als, 900; reshuffle = false);
@test ITensorCPD.stop(als.mttkrp_alg) == 900
@test ITensorCPD.start(als.mttkrp_alg) == 1
@test typeof(als.mttkrp_alg) == ITensorCPD.QRPivProjected
@@ -68,7 +67,7 @@
alg = ITensorCPD.SEQRCSPivProjected(1, 800, (1,2,3),(100,100,100))
als = ITensorCPD.compute_als(A, cp_A; alg, check);

-als = ITensorCPD.update_samples(als, 600; reshuffle = true);
+als = ITensorCPD.update_samples(A, als, 600; reshuffle = true);
@test ITensorCPD.stop(als.mttkrp_alg) == 600
@test ITensorCPD.start(als.mttkrp_alg) == 1
@test typeof(als.mttkrp_alg) == ITensorCPD.SEQRCSPivProjected
@@ -131,6 +130,25 @@
min_val = val < min_val ? val : min_val
end
@test min_val < 0.1

+### Tests for block leverage score sampled CPD
+alg = ITensorCPD.BlockLevScoreSampled((50, 50, 500), 1)
+min_val = 1
+for i in 1:3
+cpd_opt = ITensorCPD.als_optimize(T, cpd; alg, check, verbose);
+val = norm(reconstruct(cpd_opt) - T) / norm(T)
+min_val = val < min_val ? val : min_val
+end
+@test min_val < 0.1
+
+alg = ITensorCPD.BlockLevScoreSampled((50, 50, 500), 12)
+min_val = 1
+for i in 1:3
+cpd_opt = ITensorCPD.als_optimize(T, cpd; alg, check, verbose);
+val = norm(reconstruct(cpd_opt) - T) / norm(T)
+min_val = val < min_val ? val : min_val
+end
+@test min_val < 0.1
end

@testset "Standard CPD-ALS, elt=$elt" for elt in [Float32, ComplexF32]
@@ -227,6 +245,14 @@ end
alg = ITensorCPD.LevScoreSampled(100)
cpd_opt = ITensorCPD.als_optimize(T, cpd; alg, check, verbose);
@test norm(reconstruct(cpd_opt) - T) / norm(T) < 0.1

+alg = ITensorCPD.BlockLevScoreSampled(100, 1)
+cpd_opt = ITensorCPD.als_optimize(T, cpd; alg, check, verbose);
+@test norm(reconstruct(cpd_opt) - T) / norm(T) < 0.1
+
+alg = ITensorCPD.BlockLevScoreSampled(100, 3)
+cpd_opt = ITensorCPD.als_optimize(T, cpd; alg, check, verbose);
+@test norm(reconstruct(cpd_opt) - T) / norm(T) < 0.1
end

