Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
31 changes: 25 additions & 6 deletions src/pullbacks/qr.jl
Original file line number Diff line number Diff line change
Expand Up @@ -31,9 +31,18 @@ function check_and_prepare_qr_cotangents(
ΔR₁₁ = UpperTriangular(view(ΔR, 1:p, 1:p))
ΔR₁₂ = view(ΔR, 1:p, (p + 1):n)
ΔR₂₂ = view(ΔR, (p + 1):minmn, (p + 1):n)
Δgauge_R = norm(view(ΔR₂₂, uppertriangularind(ΔR₂₂)), Inf)
Δgauge_R = max(Δgauge_R, norm(view(ΔR₂₂, diagind(ΔR₂₂)), Inf))
Δgauge = max(Δgauge, Δgauge_R)
if p < minmn # otherwise ΔR₂₂ is empty
# uppertriangularind generates linear indices
# compute the appropriate offset in ΔR so we aren't
# operating on a view-of-view, which doesn't work
# for GPU arrays
I = uppertriangularind(ΔR₂₂)
upper_inds = view(LinearIndices(ΔR), (p + 1):minmn, (p + 1):n)[I]
ΔR₂₂upper = view(ΔR, upper_inds)
Δgauge_R = norm(ΔR₂₂upper, Inf)
Δgauge_R = max(Δgauge_R, norm(view(ΔR₂₂, diagind(ΔR₂₂)), Inf))
Δgauge = max(Δgauge, Δgauge_R)
end
else
ΔR₁₁ = nothing
ΔR₁₂ = nothing
Expand Down Expand Up @@ -75,7 +84,7 @@ function qr_pullback!(


Q₁ = view(Q, :, 1:p)
R₁₁ = UpperTriangular(view(R, 1:p, 1:p))
R₁₁ = UpperTriangular(R[1:p, 1:p])
R₁₂ = view(R, 1:p, (p + 1):n)

ΔA₁ = view(ΔA, :, 1:p)
Expand All @@ -101,7 +110,8 @@ function qr_pullback!(
Md = diagview(M)
Md .= real.(Md)
end
ΔA₁ .+= rdiv!(mul!(ΔQ₁, Q₁, M, +1, 1), R₁₁')
mul!(ΔQ₁, Q₁, M, +1, 1)
ΔA₁ .+= rdiv!(ΔQ₁, R₁₁')
return ΔA
end

Expand Down Expand Up @@ -160,7 +170,16 @@ function remove_qr_gauge_dependence!(ΔQ, ΔR, A, Q, R; rank_atol = MatrixAlgebr
end
ΔR₂₂ = view(ΔR, (r + 1):minmn, (r + 1):size(R, 2))
zero!(diagview(ΔR₂₂))
zero!(view(ΔR₂₂, uppertriangularind(ΔR₂₂)))
if r < minmn # otherwise ΔR₂₂ is empty
    # `uppertriangularind(ΔR₂₂)` generates linear indices relative to the
    # ΔR₂₂ sub-block. That block has fewer rows than ΔR, so its column
    # stride differs and a constant offset cannot translate the indices.
    # Look them up in the matching window of `LinearIndices(ΔR)` instead:
    # this yields indices valid for ΔR itself, so we can take a single
    # view of ΔR rather than a view-of-view, which doesn't work for GPU
    # arrays.
    block_inds = view(LinearIndices(ΔR), (r + 1):minmn, (r + 1):size(R, 2))
    upper_inds = block_inds[uppertriangularind(ΔR₂₂)]
    zero!(view(ΔR, upper_inds))
end
return ΔQ, ΔR
end

Expand Down
7 changes: 7 additions & 0 deletions test/mooncake/qr.jl
Original file line number Diff line number Diff line change
Expand Up @@ -20,4 +20,11 @@ for T in (BLASFloats..., GenericFloats...), n in (17, m, 23)
TestSuite.test_mooncake_qr(AT, (m, m); atol = m * n * TestSuite.precision(T), rtol = m * n * TestSuite.precision(T))
end
end
if T ∈ BLASFloats && CUDA.functional()
TestSuite.test_mooncake_qr(CuMatrix{T}, (m, n); atol = m * n * TestSuite.precision(T), rtol = m * n * TestSuite.precision(T))
#=if m == n
AT = Diagonal{T, CuVector{T}}
TestSuite.test_mooncake_qr(AT, (m, m); atol = m * n * TestSuite.precision(T), rtol = m * n * TestSuite.precision(T))
end=# # currently broken
end
end
Loading