From ef4de275634dbc1aed8eaf817594a0db46f3d65a Mon Sep 17 00:00:00 2001 From: Hossein Pourbozorg Date: Fri, 12 Jun 2026 12:09:51 +0330 Subject: [PATCH 1/5] remove `ClipNorm` --- examples/usage.jl | 1 - src/exts/mlj_ext/core_cond_icnf.jl | 5 ----- src/exts/mlj_ext/core_icnf.jl | 5 ----- 3 files changed, 11 deletions(-) diff --git a/examples/usage.jl b/examples/usage.jl index 1b1fac82..eeffb079 100644 --- a/examples/usage.jl +++ b/examples/usage.jl @@ -87,7 +87,6 @@ if !isfile(icnf_mach_fn) callback = opt_callback, alg = OptimiserChain( WeightDecay(; lambda = 1.0e-4), - ClipNorm(10.0, 2.0; throw = true), Adam(; eta = 0.001, beta = (0.9, 0.999), epsilon = 1.0e-8), ), progress = true, diff --git a/src/exts/mlj_ext/core_cond_icnf.jl b/src/exts/mlj_ext/core_cond_icnf.jl index 4c8979a0..4ac401a6 100644 --- a/src/exts/mlj_ext/core_cond_icnf.jl +++ b/src/exts/mlj_ext/core_cond_icnf.jl @@ -16,11 +16,6 @@ function CondICNFModel(; callback = make_opt_callback(64), alg = Optimisers.OptimiserChain( Optimisers.WeightDecay(; lambda = convert(eltype(icnf), 1.0e-4)), - Optimisers.ClipNorm( - convert(eltype(icnf), 10.0), - convert(eltype(icnf), 2.0); - throw = true, - ), Optimisers.Adam(; eta = convert(eltype(icnf), 0.001), beta = (convert(eltype(icnf), 0.9), convert(eltype(icnf), 0.999)), diff --git a/src/exts/mlj_ext/core_icnf.jl b/src/exts/mlj_ext/core_icnf.jl index 7f5e980f..2dba0235 100644 --- a/src/exts/mlj_ext/core_icnf.jl +++ b/src/exts/mlj_ext/core_icnf.jl @@ -16,11 +16,6 @@ function ICNFModel(; callback = make_opt_callback(64), alg = Optimisers.OptimiserChain( Optimisers.WeightDecay(; lambda = convert(eltype(icnf), 1.0e-4)), - Optimisers.ClipNorm( - convert(eltype(icnf), 10.0), - convert(eltype(icnf), 2.0); - throw = true, - ), Optimisers.Adam(; eta = convert(eltype(icnf), 0.001), beta = (convert(eltype(icnf), 0.9), convert(eltype(icnf), 0.999)), From c8d1a8e1cb755884d54660e58161f3485ff77b14 Mon Sep 17 00:00:00 2001 From: Hossein Pourbozorg Date: Fri, 12 Jun 2026 13:44:02 
+0330 Subject: [PATCH 2/5] switch to `Float32` and loosen `abstol` --- examples/usage.jl | 22 +++++++++++----------- src/core/icnf.jl | 6 +++--- src/layers/planar_layer.jl | 2 +- 3 files changed, 15 insertions(+), 15 deletions(-) diff --git a/examples/usage.jl b/examples/usage.jl index eeffb079..2d818966 100644 --- a/examples/usage.jl +++ b/examples/usage.jl @@ -37,11 +37,11 @@ icnf = ICNF(; nvariables = nvariables, # number of variables naugments = naugments, # number of augmented dimensions nconditions = 0, # number of conditioning inputs - λ₁ = 0.01, # regulate flow - λ₂ = 0.01, # regulate volume change - λ₃ = 0.01, # regulate augmented dimensions - steer_rate = 0.1, # add random noise to end of the time span - tspan = (0.0, 1.0), # time span + λ₁ = 0.01f0, # regulate flow + λ₂ = 0.01f0, # regulate volume change + λ₃ = 0.01f0, # regulate augmented dimensions + steer_rate = 0.1f0, # add random noise to end of the time span + tspan = (0.0f0, 1.0f0), # time span device = cpu_device(), # process data by CPU # device = gpu_device(), # process data by GPU autonomous = false, # using non-autonomous flow @@ -51,14 +51,14 @@ icnf = ICNF(; sol_kwargs = (; save_everystep = false, maxiters = typemax(Int), - reltol = 1.0e-4, - abstol = 1.0e-8, + reltol = 1.0f-4, + abstol = 1.0f-4, alg = VCABM(), sensealg = QuadratureAdjoint(; autodiff = true, autojacvec = ZygoteVJP(), - reltol = 1.0e-4, - abstol = 1.0e-8, + reltol = 1.0f-4, + abstol = 1.0f-4, ), progress = false, verbose = Detailed(), @@ -86,8 +86,8 @@ if !isfile(icnf_mach_fn) epochs = 300, callback = opt_callback, alg = OptimiserChain( - WeightDecay(; lambda = 1.0e-4), - Adam(; eta = 0.001, beta = (0.9, 0.999), epsilon = 1.0e-8), + WeightDecay(; lambda = 1.0f-4), + Adam(; eta = 0.001f0, beta = (0.9f0, 0.999f0), epsilon = 1.0f-8), ), progress = true, verbose = Detailed(), diff --git a/src/core/icnf.jl b/src/core/icnf.jl index 1b6477f9..3e07dbb9 100644 --- a/src/core/icnf.jl +++ b/src/core/icnf.jl @@ -51,7 +51,7 @@ struct
ICNF{ end function ICNF(; - data_type::Type{<:AbstractFloat} = Float64, + data_type::Type{<:AbstractFloat} = Float32, compute_mode::ComputeMode = LuxVecJacMatrixMode(ADTypes.AutoZygote()), inplace::Bool = false, autonomous::Bool = false, @@ -85,7 +85,7 @@ function ICNF(; save_everystep = false, maxiters = typemax(Int), reltol = convert(data_type, 1.0e-4), - abstol = convert(data_type, 1.0e-8), + abstol = convert(data_type, 1.0e-4), alg = OrdinaryDiffEqAdamsBashforthMoulton.VCABM(), sensealg = SciMLSensitivity.QuadratureAdjoint(; autodiff = true, @@ -95,7 +95,7 @@ function ICNF(; true, ), reltol = convert(data_type, 1.0e-4), - abstol = convert(data_type, 1.0e-8), + abstol = convert(data_type, 1.0e-4), ), progress = false, verbose = SciMLLogging.Detailed(), diff --git a/src/layers/planar_layer.jl b/src/layers/planar_layer.jl index aaca4a83..3c2f42c9 100644 --- a/src/layers/planar_layer.jl +++ b/src/layers/planar_layer.jl @@ -15,7 +15,7 @@ function PlanarLayer( mapping::Pair{<:Int, <:Int}, activation::Any = identity; init_weight::Any = WeightInitializers.glorot_uniform, - init_bias::Any = WeightInitializers.zeros64, + init_bias::Any = WeightInitializers.zeros32, use_bias::Bool = true, ) return PlanarLayer{ From 2d61f044584e67652f04fcd36902d82163f57a26 Mon Sep 17 00:00:00 2001 From: Hossein Pourbozorg Date: Fri, 12 Jun 2026 13:47:33 +0330 Subject: [PATCH 3/5] drop old compats --- Project.toml | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/Project.toml b/Project.toml index 5f4ee713..d7ada4fd 100644 --- a/Project.toml +++ b/Project.toml @@ -51,10 +51,10 @@ MLUtils = "0.4" NNlib = "0.9" Optimisers = "0.4" OptimizationOptimisers = "0.3" -OrdinaryDiffEqAdamsBashforthMoulton = "1, 2" +OrdinaryDiffEqAdamsBashforthMoulton = "2" Random = "1" -SciMLBase = "2, 3" -SciMLLogging = "1, 2" +SciMLBase = "3" +SciMLLogging = "2" SciMLSensitivity = "7" ScientificTypesBase = "3" Statistics = "1" From 2685d05d27aa6782dc182ae97dc687b149157284 Mon Sep 17 00:00:00 
2001 From: Hossein Pourbozorg Date: Fri, 12 Jun 2026 13:56:10 +0330 Subject: [PATCH 4/5] fix --- examples/usage.jl | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/examples/usage.jl b/examples/usage.jl index 2d818966..d0a6295f 100644 --- a/examples/usage.jl +++ b/examples/usage.jl @@ -6,7 +6,7 @@ global_logger(TerminalLogger()) using Distributions ndata = 1024 ndimensions = 1 -data_dist = Beta(2.0, 4.0) +data_dist = Beta(2.0f0, 4.0f0) r = rand(data_dist, ndimensions, ndata) ## Parameters @@ -119,8 +119,8 @@ display(res_df) using CairoMakie f = Figure() ax = Axis(f[1, 1]; title = "Result") -lines!(ax, 0.0 .. 1.0, x -> pdf(data_dist, x); label = "Actual") -lines!(ax, 0.0 .. 1.0, x -> pdf(d, vcat(x)); label = "Estimated") +lines!(ax, 0.0f0 .. 1.0f0, x -> pdf(data_dist, x); label = "Actual") +lines!(ax, 0.0f0 .. 1.0f0, x -> pdf(d, vcat(x)); label = "Estimated") axislegend(ax) save("result-figure.svg", f) save("result-figure.png", f) From d912ed0485f5f19ad9a9f6ffa7775a91aadb3659 Mon Sep 17 00:00:00 2001 From: Hossein Pourbozorg Date: Fri, 12 Jun 2026 14:07:40 +0330 Subject: [PATCH 5/5] fix 2 --- benchmark/benchmarks.jl | 2 +- test/ci_tests/regression_tests.jl | 2 +- test/ci_tests/smoke_tests.jl | 5 ++--- test/quality_tests/checkby_JET_tests.jl | 4 ++-- 4 files changed, 6 insertions(+), 7 deletions(-) diff --git a/benchmark/benchmarks.jl b/benchmark/benchmarks.jl index 33ee69fb..3ca8ae8b 100644 --- a/benchmark/benchmarks.jl +++ b/benchmark/benchmarks.jl @@ -10,7 +10,7 @@ import ADTypes, ndata = 2^10 ndimensions = 1 -data_dist = Distributions.Beta(2.0, 4.0) +data_dist = Distributions.Beta(2.0f0, 4.0f0) r = rand(data_dist, ndimensions, ndata) nvariables = size(r, 1) diff --git a/test/ci_tests/regression_tests.jl b/test/ci_tests/regression_tests.jl index abf9811a..2f588e99 100644 --- a/test/ci_tests/regression_tests.jl +++ b/test/ci_tests/regression_tests.jl @@ -1,7 +1,7 @@ Test.@testset verbose = true showtiming = true failfast = false "Regression 
Tests" begin ndata = 2^10 ndimensions = 1 - data_dist = Distributions.Beta(2.0, 4.0) + data_dist = Distributions.Beta(2.0f0, 4.0f0) r = rand(data_dist, ndimensions, ndata) nvariables = size(r, 1) diff --git a/test/ci_tests/smoke_tests.jl b/test/ci_tests/smoke_tests.jl index 7b37a4e5..a1b9210f 100644 --- a/test/ci_tests/smoke_tests.jl +++ b/test/ci_tests/smoke_tests.jl @@ -9,8 +9,8 @@ Test.@testset verbose = true showtiming = true failfast = false "Smoke Tests" be ndata = 4 ndimensions = 2 - data_dist = Distributions.Beta(2.0, 4.0) - data_dist2 = Distributions.Beta(2.0, 4.0) + data_dist = Distributions.Beta(2.0f0, 4.0f0) + data_dist2 = Distributions.Beta(2.0f0, 4.0f0) if compute_mode isa ContinuousNormalizingFlows.VectorMode r = rand(data_dist, ndimensions) r2 = rand(data_dist2, ndimensions) @@ -124,7 +124,6 @@ Test.@testset verbose = true showtiming = true failfast = false "Smoke Tests" be Test.@testset verbose = true showtiming = true failfast = false "$adtype on loss" for adtype in adtypes - Test.@test !isnothing(DifferentiationInterface.gradient(diff_loss, adtype, ps)) Test.@test !isnothing(DifferentiationInterface.gradient(diff2_loss, adtype, r)) diff --git a/test/quality_tests/checkby_JET_tests.jl b/test/quality_tests/checkby_JET_tests.jl index e669602d..5c883765 100644 --- a/test/quality_tests/checkby_JET_tests.jl +++ b/test/quality_tests/checkby_JET_tests.jl @@ -13,8 +13,8 @@ Test.@testset verbose = true showtiming = true failfast = false "CheckByJET" beg ndata = 4 ndimensions = 2 - data_dist = Distributions.Beta(2.0, 4.0) - data_dist2 = Distributions.Beta(2.0, 4.0) + data_dist = Distributions.Beta(2.0f0, 4.0f0) + data_dist2 = Distributions.Beta(2.0f0, 4.0f0) if compute_mode isa ContinuousNormalizingFlows.VectorMode r = rand(data_dist, ndimensions) r2 = rand(data_dist2, ndimensions)