Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 3 additions & 3 deletions Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -51,10 +51,10 @@ MLUtils = "0.4"
NNlib = "0.9"
Optimisers = "0.4"
OptimizationOptimisers = "0.3"
OrdinaryDiffEqAdamsBashforthMoulton = "1, 2"
OrdinaryDiffEqAdamsBashforthMoulton = "2"
Random = "1"
SciMLBase = "2, 3"
SciMLLogging = "1, 2"
SciMLBase = "3"
SciMLLogging = "2"
SciMLSensitivity = "7"
ScientificTypesBase = "3"
Statistics = "1"
Expand Down
2 changes: 1 addition & 1 deletion benchmark/benchmarks.jl
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@ import ADTypes,

ndata = 2^10
ndimensions = 1
data_dist = Distributions.Beta(2.0, 4.0)
data_dist = Distributions.Beta(2.0f0, 4.0f0)
r = rand(data_dist, ndimensions, ndata)

nvariables = size(r, 1)
Expand Down
29 changes: 14 additions & 15 deletions examples/usage.jl
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@ global_logger(TerminalLogger())
using Distributions
ndata = 1024
ndimensions = 1
data_dist = Beta(2.0, 4.0)
data_dist = Beta(2.0f0, 4.0f0)
r = rand(data_dist, ndimensions, ndata)

## Parameters
Expand Down Expand Up @@ -37,11 +37,11 @@ icnf = ICNF(;
nvariables = nvariables, # number of variables
naugments = naugments, # number of augmented dimensions
nconditions = 0, # number of conditioning inputs
λ₁ = 0.01, # regulate flow
λ₂ = 0.01, # regulate volume change
λ₃ = 0.01, # regulate augmented dimensions
steer_rate = 0.1, # add random noise to end of the time span
tspan = (0.0, 1.0), # time span
λ₁ = 0.01f0, # regulate flow
λ₂ = 0.01f0, # regulate volume change
λ₃ = 0.01f0, # regulate augmented dimensions
steer_rate = 0.1f0, # add random noise to end of the time span
tspan = (0.0f0, 1.0f0), # time span
device = cpu_device(), # process data by CPU
# device = gpu_device(), # process data by GPU
autonomous = false, # using non-autonomous flow
Expand All @@ -51,14 +51,14 @@ icnf = ICNF(;
sol_kwargs = (;
save_everystep = false,
maxiters = typemax(Int),
reltol = 1.0e-4,
abstol = 1.0e-8,
reltol = 1.0f-4,
abstol = 1.0f-4,
alg = VCABM(),
sensealg = QuadratureAdjoint(;
autodiff = true,
autojacvec = ZygoteVJP(),
reltol = 1.0e-4,
abstol = 1.0e-8,
reltol = 1.0f-4,
abstol = 1.0f-4,
),
progress = false,
verbose = Detailed(),
Expand Down Expand Up @@ -86,9 +86,8 @@ if !isfile(icnf_mach_fn)
epochs = 300,
callback = opt_callback,
alg = OptimiserChain(
WeightDecay(; lambda = 1.0e-4),
ClipNorm(10.0, 2.0; throw = true),
Adam(; eta = 0.001, beta = (0.9, 0.999), epsilon = 1.0e-8),
WeightDecay(; lambda = 1.0f-4),
Adam(; eta = 0.001f0, beta = (0.9f0, 0.999f0), epsilon = 1.0f-8),
),
progress = true,
verbose = Detailed(),
Expand Down Expand Up @@ -120,8 +119,8 @@ display(res_df)
using CairoMakie
f = Figure()
ax = Axis(f[1, 1]; title = "Result")
lines!(ax, 0.0 .. 1.0, x -> pdf(data_dist, x); label = "Actual")
lines!(ax, 0.0 .. 1.0, x -> pdf(d, vcat(x)); label = "Estimated")
lines!(ax, 0.0f0 .. 1.0f0, x -> pdf(data_dist, x); label = "Actual")
lines!(ax, 0.0f0 .. 1.0f0, x -> pdf(d, vcat(x)); label = "Estimated")
axislegend(ax)
save("result-figure.svg", f)
save("result-figure.png", f)
6 changes: 3 additions & 3 deletions src/core/icnf.jl
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,7 @@ struct ICNF{
end

function ICNF(;
data_type::Type{<:AbstractFloat} = Float64,
data_type::Type{<:AbstractFloat} = Float32,
compute_mode::ComputeMode = LuxVecJacMatrixMode(ADTypes.AutoZygote()),
inplace::Bool = false,
autonomous::Bool = false,
Expand Down Expand Up @@ -85,7 +85,7 @@ function ICNF(;
save_everystep = false,
maxiters = typemax(Int),
reltol = convert(data_type, 1.0e-4),
abstol = convert(data_type, 1.0e-8),
abstol = convert(data_type, 1.0e-4),
alg = OrdinaryDiffEqAdamsBashforthMoulton.VCABM(),
sensealg = SciMLSensitivity.QuadratureAdjoint(;
autodiff = true,
Expand All @@ -95,7 +95,7 @@ function ICNF(;
true,
),
reltol = convert(data_type, 1.0e-4),
abstol = convert(data_type, 1.0e-8),
abstol = convert(data_type, 1.0e-4),
),
progress = false,
verbose = SciMLLogging.Detailed(),
Expand Down
5 changes: 0 additions & 5 deletions src/exts/mlj_ext/core_cond_icnf.jl
Original file line number Diff line number Diff line change
Expand Up @@ -16,11 +16,6 @@ function CondICNFModel(;
callback = make_opt_callback(64),
alg = Optimisers.OptimiserChain(
Optimisers.WeightDecay(; lambda = convert(eltype(icnf), 1.0e-4)),
Optimisers.ClipNorm(
convert(eltype(icnf), 10.0),
convert(eltype(icnf), 2.0);
throw = true,
),
Optimisers.Adam(;
eta = convert(eltype(icnf), 0.001),
beta = (convert(eltype(icnf), 0.9), convert(eltype(icnf), 0.999)),
Expand Down
5 changes: 0 additions & 5 deletions src/exts/mlj_ext/core_icnf.jl
Original file line number Diff line number Diff line change
Expand Up @@ -16,11 +16,6 @@ function ICNFModel(;
callback = make_opt_callback(64),
alg = Optimisers.OptimiserChain(
Optimisers.WeightDecay(; lambda = convert(eltype(icnf), 1.0e-4)),
Optimisers.ClipNorm(
convert(eltype(icnf), 10.0),
convert(eltype(icnf), 2.0);
throw = true,
),
Optimisers.Adam(;
eta = convert(eltype(icnf), 0.001),
beta = (convert(eltype(icnf), 0.9), convert(eltype(icnf), 0.999)),
Expand Down
2 changes: 1 addition & 1 deletion src/layers/planar_layer.jl
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ function PlanarLayer(
mapping::Pair{<:Int, <:Int},
activation::Any = identity;
init_weight::Any = WeightInitializers.glorot_uniform,
init_bias::Any = WeightInitializers.zeros64,
init_bias::Any = WeightInitializers.zeros32,
use_bias::Bool = true,
)
return PlanarLayer{
Expand Down
2 changes: 1 addition & 1 deletion test/ci_tests/regression_tests.jl
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
Test.@testset verbose = true showtiming = true failfast = false "Regression Tests" begin
ndata = 2^10
ndimensions = 1
data_dist = Distributions.Beta(2.0, 4.0)
data_dist = Distributions.Beta(2.0f0, 4.0f0)
r = rand(data_dist, ndimensions, ndata)

nvariables = size(r, 1)
Expand Down
5 changes: 2 additions & 3 deletions test/ci_tests/smoke_tests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -9,8 +9,8 @@ Test.@testset verbose = true showtiming = true failfast = false "Smoke Tests" be

ndata = 4
ndimensions = 2
data_dist = Distributions.Beta(2.0, 4.0)
data_dist2 = Distributions.Beta(2.0, 4.0)
data_dist = Distributions.Beta(2.0f0, 4.0f0)
data_dist2 = Distributions.Beta(2.0f0, 4.0f0)
if compute_mode isa ContinuousNormalizingFlows.VectorMode
r = rand(data_dist, ndimensions)
r2 = rand(data_dist2, ndimensions)
Expand Down Expand Up @@ -124,7 +124,6 @@ Test.@testset verbose = true showtiming = true failfast = false "Smoke Tests" be

Test.@testset verbose = true showtiming = true failfast = false "$adtype on loss" for adtype in
adtypes

Test.@test !isnothing(DifferentiationInterface.gradient(diff_loss, adtype, ps))
Test.@test !isnothing(DifferentiationInterface.gradient(diff2_loss, adtype, r))

Expand Down
4 changes: 2 additions & 2 deletions test/quality_tests/checkby_JET_tests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -13,8 +13,8 @@ Test.@testset verbose = true showtiming = true failfast = false "CheckByJET" beg

ndata = 4
ndimensions = 2
data_dist = Distributions.Beta(2.0, 4.0)
data_dist2 = Distributions.Beta(2.0, 4.0)
data_dist = Distributions.Beta(2.0f0, 4.0f0)
data_dist2 = Distributions.Beta(2.0f0, 4.0f0)
if compute_mode isa ContinuousNormalizingFlows.VectorMode
r = rand(data_dist, ndimensions)
r2 = rand(data_dist2, ndimensions)
Expand Down
Loading