diff --git a/DifferentiationInterface/ext/DifferentiationInterfaceFiniteDiffExt/onearg.jl b/DifferentiationInterface/ext/DifferentiationInterfaceFiniteDiffExt/onearg.jl index 82b769ef6..8c4b84031 100644 --- a/DifferentiationInterface/ext/DifferentiationInterfaceFiniteDiffExt/onearg.jl +++ b/DifferentiationInterface/ext/DifferentiationInterfaceFiniteDiffExt/onearg.jl @@ -550,3 +550,140 @@ function DI.value_gradient_and_hessian!( ) return fc(x), grad, hess end + +## HVP + +struct FiniteDiffHVPPrep{SIG, C1, C2, RG, AG, RH, AH, H} <: DI.HVPPrep{SIG} + _sig::Val{SIG} + gradient_cache::C1 + hessian_cache::C2 + relstep_g::RG + absstep_g::AG + relstep_h::RH + absstep_h::AH + hess::H +end + +function DI.prepare_hvp_nokwarg( + strict::Val, f, backend::AutoFiniteDiff, x, tx::NTuple, contexts::Vararg{DI.Context, C} + ) where {C} + _sig = DI.signature(f, backend, x, tx, contexts...; strict) + fc = DI.fix_tail(f, map(DI.unwrap, contexts)...) + y = fc(x) + df = zero(y) .* x + gradient_cache = GradientCache(df, x, fdtype(backend)) + hessian_cache = HessianCache(x, fdhtype(backend)) + relstep_g = if isnothing(backend.relstep) + default_relstep(fdtype(backend), eltype(x)) + else + backend.relstep + end + relstep_h = if isnothing(backend.relstep) + default_relstep(fdhtype(backend), eltype(x)) + else + backend.relstep + end + absstep_g = if isnothing(backend.absstep) + relstep_g + else + backend.absstep + end + absstep_h = if isnothing(backend.absstep) + relstep_h + else + backend.absstep + end + hess = similar(x, eltype(x), (length(x), length(x))) + return FiniteDiffHVPPrep( + _sig, gradient_cache, hessian_cache, relstep_g, absstep_g, relstep_h, absstep_h, hess + ) +end + +function DI.hvp( + f, + prep::FiniteDiffHVPPrep, + backend::AutoFiniteDiff, + x, + tx::NTuple, + contexts::Vararg{DI.Context, C}, + ) where {C} + DI.check_prep(f, prep, backend, x, tx, contexts...) + (; relstep_h, absstep_h, hess) = prep + fc = DI.fix_tail(f, map(DI.unwrap, contexts)...) + finite_difference_hessian!( + hess, fc, x, prep.hessian_cache; relstep = relstep_h, absstep = absstep_h + ) + tg = map(tx) do dx + reshape(hess * vec(dx), size(x)) + end + return tg +end + +function DI.hvp!( + f, + tg::NTuple, + prep::FiniteDiffHVPPrep, + backend::AutoFiniteDiff, + x, + tx::NTuple, + contexts::Vararg{DI.Context, C}, + ) where {C} + DI.check_prep(f, prep, backend, x, tx, contexts...) + (; relstep_h, absstep_h, hess) = prep + fc = DI.fix_tail(f, map(DI.unwrap, contexts)...) + finite_difference_hessian!( + hess, fc, x, prep.hessian_cache; relstep = relstep_h, absstep = absstep_h + ) + for b in eachindex(tx, tg) + mul!(vec(tg[b]), hess, vec(tx[b])) + end + return tg +end + +function DI.gradient_and_hvp( + f, + prep::FiniteDiffHVPPrep, + backend::AutoFiniteDiff, + x, + tx::NTuple, + contexts::Vararg{DI.Context, C}, + ) where {C} + DI.check_prep(f, prep, backend, x, tx, contexts...) + (; relstep_g, absstep_g, relstep_h, absstep_h, hess) = prep + fc = DI.fix_tail(f, map(DI.unwrap, contexts)...) + grad = finite_difference_gradient( + fc, x, prep.gradient_cache; relstep = relstep_g, absstep = absstep_g + ) + finite_difference_hessian!( + hess, fc, x, prep.hessian_cache; relstep = relstep_h, absstep = absstep_h + ) + tg = map(tx) do dx + reshape(hess * vec(dx), size(x)) + end + return grad, tg +end + +function DI.gradient_and_hvp!( + f, + grad, + tg::NTuple, + prep::FiniteDiffHVPPrep, + backend::AutoFiniteDiff, + x, + tx::NTuple, + contexts::Vararg{DI.Context, C}, + ) where {C} + DI.check_prep(f, prep, backend, x, tx, contexts...) + (; relstep_g, absstep_g, relstep_h, absstep_h, hess) = prep + fc = DI.fix_tail(f, map(DI.unwrap, contexts)...) + finite_difference_gradient!( + grad, fc, x, prep.gradient_cache; relstep = relstep_g, absstep = absstep_g + ) + finite_difference_hessian!( + hess, fc, x, prep.hessian_cache; relstep = relstep_h, absstep = absstep_h + ) + for b in eachindex(tx, tg) + mul!(vec(tg[b]), hess, vec(tx[b])) + end + return grad, tg +end diff --git a/DifferentiationInterface/test/Back/FiniteDiff/test.jl b/DifferentiationInterface/test/Back/FiniteDiff/test.jl index d2a63f03e..04f416e92 100644 --- a/DifferentiationInterface/test/Back/FiniteDiff/test.jl +++ b/DifferentiationInterface/test/Back/FiniteDiff/test.jl @@ -26,7 +26,7 @@ end use_tuples = true, include_smaller = true, ); - excluded = [:second_derivative, :hvp], + excluded = [:second_derivative], logging = LOGGING, ) @@ -43,7 +43,7 @@ end AutoFiniteDiff(; relstep = cbrt(eps(Float64)), absstep = cbrt(eps(Float64))), AutoFiniteDiff(; dir = 0.5), ]; - excluded = [:second_derivative, :hvp], + excluded = [:second_derivative], logging = LOGGING, ) end @@ -90,6 +90,11 @@ end; @test prep.absstep_h == 1000 @test prep.relstep_g == 0.1 @test prep.relstep_h == 0.1 + prep = prepare_hvp(sum, backend, [1.0], ([1.0],)) + @test prep.absstep_g == 1000 + @test prep.absstep_h == 1000 + @test prep.relstep_g == 0.1 + @test prep.relstep_h == 0.1 backend = AutoFiniteDiff(; relstep = 0.1) preps = [ @@ -110,6 +115,55 @@ end; @test prep.absstep_h == 0.1 @test prep.relstep_g == 0.1 @test prep.relstep_h == 0.1 + prep = prepare_hvp(sum, backend, [1.0], ([1.0],)) + @test prep.absstep_g == 0.1 + @test prep.absstep_h == 0.1 + @test prep.relstep_g == 0.1 + @test prep.relstep_h == 0.1 +end + +@testset "HVP accuracy (issue 1012)" begin + # hvp should match hessian * v for default AutoFiniteDiff() + # Previously, hvp used fdtype (forward) while hessian used fdhtype (central), + # causing significant accuracy differences + backend = AutoFiniteDiff() + + for (f, x, v) in [ + (x -> sum(x .^ 2), [1.0, 2.0, 3.0], [1.0, 0.0, 0.0]), + (x -> sum(x .^ 3), [1.0, 2.0, 3.0], [1.0, 0.0, 0.0]), + (x -> sum(x .^ 4), [1.0, 2.0, 3.0], [1.0, 0.0, 0.0]), + (x -> x' * [1 2; 3 4] * x, [1.0, 2.0], [1.0, 0.0]), + ] + H = hessian(f, backend, x) + Hv_direct = H * v + Hv_hvp = hvp(f, backend, x, (v,))[1] + @test Hv_hvp ≈ Hv_direct rtol = 1e-10 + end + + # Also test hvp!, gradient_and_hvp, gradient_and_hvp! + f(x) = sum(x .^ 2) + x = [1.0, 2.0, 3.0] + v = [1.0, 0.0, 0.0] + H = hessian(f, backend, x) + expected_Hv = H * v + expected_grad = [2.0, 4.0, 6.0] + + # hvp! + tg = (similar(x),) + hvp!(f, tg, backend, x, (v,)) + @test tg[1] ≈ expected_Hv rtol = 1e-10 + + # gradient_and_hvp + grad, tg = gradient_and_hvp(f, backend, x, (v,)) + @test grad ≈ expected_grad rtol = 1e-6 + @test tg[1] ≈ expected_Hv rtol = 1e-10 + + # gradient_and_hvp! + grad = similar(x) + tg = (similar(x),) + gradient_and_hvp!(f, grad, tg, backend, x, (v,)) + @test grad ≈ expected_grad rtol = 1e-6 + @test tg[1] ≈ expected_Hv rtol = 1e-10 end include("benchmark.jl")