Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
24 changes: 23 additions & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,29 @@ julia> fft(rand(Double64, 2))
2-element Vector{Complex{Double64}}:
0.4026739024263829 + 0.0im
0.3969515892883767 + 0.0im
```
```

## Usage for low-precision FFTs

```julia
julia> using GenericFFT, BFloat16s

julia> fs = 1000.0
julia> t = 0:1/fs:1-1/fs
julia> f1, f2 = 50.0, 120.0

julia> T = Float16
julia> # see: https://www.mathworks.com/help/matlab/ref/fft.html
julia> x = T.(0.7*sin.(2π * f1 * t) .+ 0.5 * sin.(2π * f2 * t)) .+ T(0.8)
julia> X = fft(x)
julia> println("Max round-trip error: ", maximum(abs.(x - real(ifft(X)))))

julia> T = BFloat16
julia> x = T.(0.7*sin.(2π * f1 * t) .+ 0.5 * sin.(2π * f2 * t)) .+ T(0.8)
julia> X = fft(x)
julia> println("Max round-trip error: ", maximum(abs.(x - real(ifft(X)))))
```


## History

Expand Down
116 changes: 89 additions & 27 deletions src/fft.jl
Original file line number Diff line number Diff line change
@@ -1,12 +1,9 @@
const AbstractFloats = Union{AbstractFloat,Complex{T} where T<:AbstractFloat}

# We use these type definitions for clarity
const RealFloats = T where T<:AbstractFloat
const ComplexFloats = Complex{T} where T<:AbstractFloat

const AbstractFloats = Union{RealFloats, ComplexFloats}

# The following implements Bluestein's algorithm, following http://www.dsprelated.com/dspbooks/mdft/Bluestein_s_FFT_Algorithm.html
# To add more types, add them in the union of the function's signature.

function generic_fft!(x::AbstractVector{Complex{T}}) where {T<:AbstractFloat}
if ispow2(length(x))
Expand All @@ -23,7 +20,7 @@ function generic_fft!(x::AbstractVector{Complex{T}}, region::Integer) where {T<:
end

function _generic_fft_first_dim!(x, Ipost)
Threads.@threads for I in Ipost
for I in Ipost
generic_fft!(@view x[:, I])
end
x
Expand Down Expand Up @@ -81,18 +78,24 @@ function generic_fft!(x)
end


generic_fft(x, region) = generic_fft!(copy(x), region)
# generic_fft(x, region) = generic_fft!(copy(complex(x)), region)
# generic_fft(x) = generic_fft!(copy(complex(x)))

copycomplex(A::AbstractArray{<:Complex}) = copy(A)
copycomplex(A::AbstractArray{<:Real}) = complex(A)
generic_fft(x, region) = generic_fft!(copycomplex(x), region)
generic_fft(x) = generic_fft!(copycomplex(x))

generic_fft(x) = generic_fft!(copy(x))

function generic_fft(x::AbstractVector{T}) where T<:AbstractFloats
n = length(x)
ispow2(n) && return generic_fft_pow2(x)
ks = range(zero(real(T)),stop=n-one(real(T)),length=n)
Wks = @. cispi(-T(ks^2/n))
S = promote_type(real(T), Float64)
ks = range(zero(S), stop=S(n)-one(S), length=n)
Wks = Complex{real(T)}.(cispi.(-ks.^2 ./ S(n))) # always Complex
Wksrev = @view Wks[reverse(eachindex(Wks))]
xq, wq = x.*Wks, conj!([cispi(-T(n)); Wksrev; @view Wks[2:end]])
return Wks.* @view _conv!(xq,wq)[n+1:2n]
xq, wq = complex(x).*Wks, conj!([Complex{real(T)}(cispi(-S(n))); Wksrev; @view Wks[2:end]])
return Wks .* @view _conv!(xq,wq)[n+1:2n]
end

generic_bfft(x::AbstractArray{T, N}, region) where {T <: AbstractFloats, N} = conj!(generic_fft(conj(x), region))
Expand All @@ -105,27 +108,78 @@ generic_ifft(x::AbstractArray{T, N}, region) where {T<:AbstractFloats, N} = ldiv
generic_ifft!(x::AbstractArray{T, N}, region) where {T<:AbstractFloats, N} = ldiv!(T(_regionscale(x, region)), conj!(generic_fft!(conj!(x), region)))

generic_rfft(v::AbstractVector{T}, region) where T<:AbstractFloats = generic_fft(v, region)[1:div(length(v),2)+1]

function generic_rfft(x::AbstractArray{T, N}, region) where {T<:AbstractFloats, N}
d = first(region)
if length(region) > 1
return generic_fft(generic_rfft(x, d), region[2:end])
end

nout = size(x, d) ÷ 2 + 1
sz = collect(size(x))
sz[d] = nout
out = similar(x, Complex{real(T)}, tuple(sz...))

# CartesianIndices enables iterating over slices in arbitrary dimensions
Rpre = CartesianIndices(size(x)[1:d-1])
Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Using CartesianIndices is somewhat obscure so I would suggest adding a comment here to explain this.

Copy link
Copy Markdown
Author

@jamesquinlan jamesquinlan May 30, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I will add (based on Julia's help): "# CartesianIndices enables iterating over slices in arbitrary dimensions"

or more detailed (but I tend to be parsimonious with comments). If preferred: # The generic_rfft function needs to apply a 1D FFT along dimension d of an array that could be 2D, 3D, 4D, etc. CartesianIndices lets us iterate over all the slices perpendicular to dimension d without knowing in advance how many dimensions the array has. Without it we'd need to write separate loops for 2D, 3D, 4D arrays, or use eachslice which is less efficient. With CartesianIndices one loop handles any dimensions

Rpost = CartesianIndices(size(x)[d+1:end])

for Ipost in Rpost
for Ipre in Rpre
out[Ipre, :, Ipost] .= generic_rfft(view(x, Ipre, :, Ipost), 1)
end
end
return out
end

function generic_irfft(v::AbstractVector{T}, n::Integer, region) where T<:ComplexFloats
@assert length(v) == n>>1 + 1
r = Vector{T}(undef, n)
r[1:length(v)]=v
r[length(v)+1:n]=reverse(conj(v[2:end])[1:n-length(v)])
real(generic_ifft(r, region))
return real(generic_ifft(r, region))
end

function generic_irfft(x::AbstractArray{T, N}, n::Integer, region) where {T<:ComplexFloats, N}
d = first(region)
if length(region) > 1
return generic_irfft(generic_ifft(x, region[2:end]), n, d)
end

sz = collect(size(x))
sz[d] = n
out = similar(x, real(T), tuple(sz...))

Rpre = CartesianIndices(size(x)[1:d-1])
Rpost = CartesianIndices(size(x)[d+1:end])

for Ipost in Rpost
for Ipre in Rpre
out[Ipre, :, Ipost] .= generic_irfft(view(x, Ipre, :, Ipost), n, 1)
end
end
return out
end

function generic_brfft(v::AbstractArray, n::Integer, region)
scale = n * _regionscale(v, region isa Integer ? () : region[2:end])
return generic_irfft(v, n, region) * scale
end
generic_brfft(v::AbstractArray, n::Integer, region) = generic_irfft(v, n, region)*n

function _conv!(u::AbstractVector{T}, v::AbstractVector{T}) where T<:AbstractFloats
nu = length(u)
nv = length(v)
n = nu + nv - 1
nu, nv = length(u), length(v)
n = nu + nv - 1
np2 = nextpow(2, n)
append!(u, zeros(T, np2-nu))
append!(v, zeros(T, np2-nv))
y = generic_ifft_pow2(generic_fft_pow2(u).*generic_fft_pow2(v))
#TODO This would not handle Dual/ComplexDual numbers correctly
y = T<:Real ? real(y[1:n]) : y[1:n]
S = promote_type(real(T), Float64)
uf = Complex{S}.(u)
vf = Complex{S}.(v)
y = generic_ifft_pow2(generic_fft_pow2(uf) .* generic_fft_pow2(vf))
y = T <: Real ? T.(real(y[1:n])) : T.(y[1:n])
end


# This is a Cooley-Tukey FFT algorithm inspired by many widely available algorithms including:
# c_radix2.c in the GNU Scientific Library and four1 in the Numerical Recipes in C.
# However, the trigonometric recurrence is improved for greater efficiency.
Expand Down Expand Up @@ -262,7 +316,7 @@ for P in (:DummyFFTPlan, :DummyiFFTPlan, :DummybFFTPlan, :DummyDCTPlan, :DummyiD
@eval begin
mutable struct $P{T,inplace,G} <: DummyPlan{T}
region::G # region (iterable) of dims that are transformed
pinv::DummyPlan{T}
pinv::Plan
$P{T,inplace,G}(region::G) where {T<:AbstractFloats, inplace, G} = new(region)
end
end
Expand All @@ -271,8 +325,8 @@ for P in (:DummyrFFTPlan, :DummyirFFTPlan, :DummybrFFTPlan)
@eval begin
mutable struct $P{T,inplace,G} <: DummyPlan{T}
n::Integer
region::G # region (iterable) of dims that are transformed
pinv::DummyPlan{T}
region::G
pinv::Plan
$P{T,inplace,G}(n::Integer, region::G) where {T<:AbstractFloats, inplace, G} = new(n, region)
end
end
Expand All @@ -287,8 +341,8 @@ for (Plan,iPlan) in ((:DummyFFTPlan,:DummyiFFTPlan),
end

# Specific for rfft, irfft and brfft:
plan_inv(p::DummyirFFTPlan{T,inplace,G}) where {T,inplace,G} = DummyrFFTPlan{T,inplace,G}(p.n, p.region)
plan_inv(p::DummyrFFTPlan{T,inplace,G}) where {T,inplace,G} = DummyirFFTPlan{T,inplace,G}(p.n, p.region)
plan_inv(p::DummyirFFTPlan{T,inplace,G}) where {T,inplace,G} = DummyrFFTPlan{real(T),inplace,G}(p.n, p.region)
plan_inv(p::DummyrFFTPlan{T,inplace,G}) where {T,inplace,G} = DummyirFFTPlan{Complex{T},inplace,G}(p.n, p.region)



Expand Down Expand Up @@ -331,6 +385,14 @@ end

plan_fft(x::StridedArray{T}, region) where {T <: ComplexFloats} = DummyFFTPlan{T,false,typeof(region)}(region)
plan_fft!(x::StridedArray{T}, region) where {T <: ComplexFloats} = DummyFFTPlan{T,true,typeof(region)}(region)
plan_fft(x::StridedArray{T}, region; kws...) where {T <: RealFloats} =
T <: FFTW.fftwReal ? invoke(plan_fft, Tuple{AbstractArray{<:Real}, Any}, x, region; kws...) : DummyFFTPlan{Complex{T},false,typeof(region)}(region)
plan_fft!(x::StridedArray{T}, region; kws...) where {T <: RealFloats} =
T <: FFTW.fftwReal ? invoke(plan_fft!, Tuple{AbstractArray, Any}, x, region; kws...) : DummyFFTPlan{Complex{T},true,typeof(region)}(region)

# intercept fft(x) before AbstractFFTs gets a chance for any non-FFTW float type.
fft(x::StridedArray{T}) where {T<:AbstractFloats} = generic_fft(x)
fft(x::StridedArray{T}, region) where {T<:AbstractFloats} = generic_fft(x, region)

plan_bfft(x::StridedArray{T}, region) where {T <: ComplexFloats} = DummybFFTPlan{T,false,typeof(region)}(region)
plan_bfft!(x::StridedArray{T}, region) where {T <: ComplexFloats} = DummybFFTPlan{T,true,typeof(region)}(region)
Expand All @@ -345,11 +407,11 @@ plan_dct!(x::StridedArray{T}, region) where {T <: AbstractFloats} = DummyDCTPlan
plan_idct(x::StridedArray{T}, region) where {T <: AbstractFloats} = DummyiDCTPlan{T,false,typeof(region)}(region)
plan_idct!(x::StridedArray{T}, region) where {T <: AbstractFloats} = DummyiDCTPlan{T,true,typeof(region)}(region)

plan_rfft(x::StridedArray{T}, region) where {T <: RealFloats} = DummyrFFTPlan{T,false,typeof(region)}(length(x), region)
plan_rfft(x::StridedArray{T}, region) where {T <: RealFloats} = DummyrFFTPlan{T,false,typeof(region)}(size(x, first(region)), region)
plan_brfft(x::StridedArray{T}, n::Integer, region) where {T <: ComplexFloats} = DummybrFFTPlan{T,false,typeof(region)}(n, region)

# A plan for irfft is created in terms of a plan for brfft.
# plan_irfft(x::StridedArray{T}, n::Integer, region) where {T <: ComplexFloats} = DummyirFFTPlan{Complex{real(T)},false,typeof(region)}(n, region)
# Explicitly define plan_irfft to ensure correct scaling
plan_irfft(x::StridedArray{T}, n::Integer, region) where {T <: ComplexFloats} = DummyirFFTPlan{T,false,typeof(region)}(n, region)

# These don't exist for now:
# plan_rfft!(x::StridedArray{T}) where {T <: RealFloats} = DummyrFFTPlan{Complex{real(T)},true}()
Expand Down
1 change: 1 addition & 0 deletions test/Project.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
[deps]
AbstractFFTs = "621f4979-c628-5d54-868e-fcf4e3e8185c"
Aqua = "4c88cf16-eb10-579e-8560-4a9242c79595"
BFloat16s = "ab4f0b2a-ad5b-11e8-123f-65d77653426b"
DoubleFloats = "497a8b3b-efae-58df-a0af-a86822472b78"
FFTW = "7a1cc6ca-52ef-59f5-83cd-3a7055c09341"
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
Expand Down
Loading