Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[rocFFT] Don't perform a copy for each call of FFT #728

Merged
merged 5 commits into from
Jan 30, 2025
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
100 changes: 65 additions & 35 deletions src/fft/fft.jl
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,9 @@
release_plan!(plan)
plan.handle = C_NULL
end
if !isnothing(plan.buffer)
unsafe_free!(plan.buffer)
end
unsafe_free!(plan.workarea)
rocfft_execution_info_destroy(plan.execution_info)
end
Expand All @@ -22,27 +25,28 @@
# For R2C -> cast array to Complex first

# K is the flag selecting the forward/inverse transform direction
mutable struct cROCFFTPlan{T, K, inplace, N} <: ROCFFTPlan{T, K, inplace}
mutable struct cROCFFTPlan{T,K,inplace,N,R,B} <: ROCFFTPlan{T, K, inplace}
handle::rocfft_plan
stream::HIPStream
workarea::ROCVector{Int8}
execution_info::rocfft_execution_info
sz::NTuple{N, Int} # Julia size of input array
osz::NTuple{N, Int} # Julia size of output array
xtype::rocfft_transform_type
region::Any
region::NTuple{R,Int}
buffer::B
# These two fields are used in the dtor as a key for HandleCache.
input_sz_as_key::Bool
key_T::Type
# required by AbstractFFTs API
pinv::ScaledPlan

function cROCFFTPlan{T,K,inplace,N}(
function cROCFFTPlan{T,K,inplace,N,R,B}(
handle::rocfft_plan, workarea::ROCVector{Int8},
X::ROCArray{T,N}, sizey::Tuple,
xtype::rocfft_transform_type, region,
input_sz_as_key::Bool, key_T::Type,
) where {T,inplace,N,K}
xtype::rocfft_transform_type, region::NTuple{R,Int},
buffer::B, input_sz_as_key::Bool, key_T::Type,
) where {T,K,inplace,N,R,B}
info_ref = Ref{rocfft_execution_info}()
rocfft_execution_info_create(info_ref)
info = info_ref[]
Expand All @@ -53,31 +57,32 @@
if length(workarea) > 0
rocfft_execution_info_set_work_buffer(info, workarea, length(workarea))
end
p = new(handle, stream, workarea, info, size(X), sizey, xtype, region, input_sz_as_key, key_T)
p = new(handle, stream, workarea, info, size(X), sizey, xtype, region, buffer, input_sz_as_key, key_T)
return finalizer(AMDGPU.unsafe_free!, p)
end
end

mutable struct rROCFFTPlan{T,K,inplace,N} <: ROCFFTPlan{T,K,inplace}
mutable struct rROCFFTPlan{T,K,inplace,N,R,B} <: ROCFFTPlan{T,K,inplace}
handle::rocfft_plan
stream::HIPStream
workarea::ROCVector{Int8}
execution_info::rocfft_execution_info
sz::NTuple{N,Int} # Julia size of input array
osz::NTuple{N,Int} # Julia size of output array
xtype::rocfft_transform_type
region::Any
region::NTuple{R,Int}
buffer::B
# These two fields are used in the dtor as a key for HandleCache.
input_sz_as_key::Bool
key_T::Type
# required by AbstractFFTs API
pinv::ScaledPlan

function rROCFFTPlan{T,K,inplace,N}(
function rROCFFTPlan{T,K,inplace,N,R,B}(
handle::rocfft_plan, workarea::ROCVector{Int8}, X::ROCArray{T,N},
sizey::Tuple, xtype::rocfft_transform_type, region,
input_sz_as_key::Bool, key_T::Type,
) where {T,inplace,N,K}
sizey::Tuple, xtype::rocfft_transform_type, region::NTuple{R,Int},
buffer::B, input_sz_as_key::Bool, key_T::Type,
) where {T,inplace,N,K,R,B}
info_ref = Ref{rocfft_execution_info}()
rocfft_execution_info_create(info_ref)
info = info_ref[]
Expand All @@ -87,7 +92,7 @@
if length(workarea) > 0
rocfft_execution_info_set_work_buffer(info, workarea, length(workarea))
end
p = new(handle, stream, workarea, info, size(X), sizey, xtype, region, input_sz_as_key, key_T)
p = new(handle, stream, workarea, info, size(X), sizey, xtype, region, buffer, input_sz_as_key, key_T)
return finalizer(unsafe_free!, p)
end
end
Expand Down Expand Up @@ -134,67 +139,84 @@
@eval function $f(X::ROCArray{T, N}, region) where {T <: rocfftComplexes, N}
_inplace = $(inplace)
_xtype = $(xtype)
R = length(region)
region = NTuple{R,Int}(region)
pp = get_plan(_xtype, size(X), T, _inplace, region)
return cROCFFTPlan{T, $forward, _inplace, N}(pp..., X, size(X), _xtype, region, false, T)
return cROCFFTPlan{T,$forward,_inplace,N,R,Nothing}(pp..., X, size(X), _xtype, region, nothing, false, T)
end
end

# Build an out-of-place real-to-complex plan (`rocfft_transform_type_real_forward`)
# for `X` over the dimensions listed in `region`.
# The output size equals `size(X)` except along `region[1]`, where it is
# `div(n, 2) + 1`.
# NOTE(review): this span is a rendered diff hunk — the 4-parameter `return` looks
# like the removed line and the 6-parameter `return` at the end like its
# replacement; confirm against the merged file.
function plan_rfft(X::ROCArray{T,N}, region) where {T<:rocfftReals,N}
inplace = false
xtype = rocfft_transform_type_real_forward
# Normalize `region` to a tuple of `Int` whose length `R` becomes a type parameter.
R = length(region)
region = NTuple{R,Int}(region)
# The plan lookup uses the input size here (the constructor is given `true` for
# `input_sz_as_key`).
pp = get_plan(xtype, size(X), T, inplace, region)
ydims = collect(size(X))
ydims[region[1]] = div(ydims[region[1]],2) + 1
return rROCFFTPlan{T,ROCFFT_FORWARD,inplace,N}(pp..., X, (ydims...,), xtype, region, true, T)

# The buffer is not needed for real-to-complex (`mul!`),
# but it’s required for complex-to-real (`ldiv!`).
buffer = ROCArray{complex(T)}(undef, ydims...)
B = typeof(buffer)

return rROCFFTPlan{T,ROCFFT_FORWARD,inplace,N,R,B}(pp..., X, (ydims...,), xtype, region, buffer, true, T)
end

# Build an out-of-place complex-to-real plan (`rocfft_transform_type_real_inverse`)
# for `X` over the dimensions listed in `region`.
# `d` is the length of the real output along `region[1]`; every other output
# dimension equals the corresponding dimension of `X`.
# NOTE(review): this span is a rendered diff hunk — the first signature line
# (`region::Any`) and the 4-parameter `return` look like removed lines, each
# followed later by its replacement; confirm against the merged file.
function plan_brfft(X::ROCArray{T,N}, d::Integer, region::Any) where {T <: rocfftComplexes, N}
function plan_brfft(X::ROCArray{T,N}, d::Integer, region) where {T <: rocfftComplexes, N}
inplace = false
xtype = rocfft_transform_type_real_inverse
# Normalize `region` to a tuple of `Int` whose length `R` becomes a type parameter.
R = length(region)
region = NTuple{R,Int}(region)
ydims = collect(size(X))
ydims[region[1]] = d
# The plan lookup uses the output size here (the constructor is given `false` for
# `input_sz_as_key`).
pp = get_plan(xtype, (ydims...,), T, inplace, region)
return rROCFFTPlan{T,ROCFFT_INVERSE,inplace,N}(pp..., X, (ydims...,), xtype, region, false, T)

# Scratch buffer with the same element type and size as `X`, kept on the plan so
# that the complex-to-real FFT does not have to modify the caller's input.
buffer = ROCArray{T}(undef, size(X))
B = typeof(buffer)

return rROCFFTPlan{T,ROCFFT_INVERSE,inplace,N,R,B}(pp..., X, (ydims...,), xtype, region, buffer, false, T)
end

# FIXME: plan_inv methods allocate needlessly (to provide type parameters and normalization function)
# Perhaps use FakeArray types to avoid this.
# Inverse of a forward complex plan: builds a complex-inverse plan of the same size
# and region as `p` and wraps it in a `ScaledPlan` with `normalization(X, p.region)`.
# NOTE(review): this span is a rendered diff hunk — the first signature line and the
# 4-parameter constructor call look like removed lines, each followed by its
# replacement; confirm against the merged file.
function plan_inv(p::cROCFFTPlan{T,ROCFFT_FORWARD,inplace,N}) where {T<:rocfftComplexes,N,inplace}
function plan_inv(p::cROCFFTPlan{T,ROCFFT_FORWARD,inplace,N,R,B}) where {T<:rocfftComplexes,N,inplace,R,B}
# `X` is an uninitialized array passed to the plan constructor and to
# `normalization`.
X = ROCArray{T}(undef, p.sz)
xtype = rocfft_transform_type_complex_inverse
pp = get_plan(xtype, p.sz, T, inplace, p.region)
ScaledPlan(
cROCFFTPlan{T,ROCFFT_INVERSE,inplace,N}(pp..., X, p.sz, xtype, p.region, false, T),
cROCFFTPlan{T,ROCFFT_INVERSE,inplace,N,R,B}(pp..., X, p.sz, xtype, p.region, p.buffer, false, T),
normalization(X, p.region))
end

# Inverse of an inverse complex plan: builds a complex-forward plan of the same size
# and region as `p` and wraps it in a `ScaledPlan` with `normalization(X, p.region)`.
# NOTE(review): this span is a rendered diff hunk — the first signature line and the
# 4-parameter constructor call look like removed lines, each followed by its
# replacement. The "Check warning … Codecov" text between the signature and the body
# is a page annotation, not part of the function; confirm against the merged file.
function plan_inv(p::cROCFFTPlan{T,ROCFFT_INVERSE,inplace,N}) where {T<:rocfftComplexes,N,inplace}
function plan_inv(p::cROCFFTPlan{T,ROCFFT_INVERSE,inplace,N,R,B}) where {T<:rocfftComplexes,N,inplace,R,B}

Check warning on line 193 in src/fft/fft.jl

View check run for this annotation

Codecov / codecov/patch

src/fft/fft.jl#L193

Added line #L193 was not covered by tests
# `X` is an uninitialized array passed to the plan constructor and to
# `normalization`.
X = ROCArray{T}(undef, p.sz)
xtype = rocfft_transform_type_complex_forward
pp = get_plan(xtype, p.sz, T, inplace, p.region)
ScaledPlan(
cROCFFTPlan{T,ROCFFT_FORWARD,inplace,N}(pp..., X, p.sz, xtype, p.region, false, T),
cROCFFTPlan{T,ROCFFT_FORWARD,inplace,N,R,B}(pp..., X, p.sz, xtype, p.region, p.buffer, false, T),
normalization(X, p.region))
end

# Inverse of a real-to-complex plan: builds a complex-to-real plan whose input has
# size `p.osz` and whose output has size `p.sz`, wrapped in a `ScaledPlan` scaled by
# `normalization(Y, p.region)` (computed on the real-sized array `Y`).
# NOTE(review): this span is a rendered diff hunk — the first signature line and the
# 4-parameter constructor call look like removed lines, each followed by its
# replacement; confirm against the merged file.
# NOTE(review): the new plan receives `p.buffer` itself, so both plans reference the
# same buffer object — confirm this is safe with respect to the plan finalizer,
# which frees `plan.buffer`.
# NOTE(review): `get_plan` and `key_T` receive the real `T` here, whereas
# `plan_brfft` passes the complex element type for the same transform type —
# confirm the cache key is consistent.
function plan_inv(p::rROCFFTPlan{T,ROCFFT_FORWARD,inplace,N}) where {T<:rocfftReals,N,inplace}
function plan_inv(p::rROCFFTPlan{T,ROCFFT_FORWARD,inplace,N,R,B}) where {T<:rocfftReals,N,inplace,R,B}
# `X` (complex, output-sized) is passed to the plan constructor; `Y` (real,
# input-sized) is used only to compute the normalization factor.
X = ROCArray{complex(T)}(undef, p.osz)
Y = ROCArray{T}(undef, p.sz)
xtype = rocfft_transform_type_real_inverse
pp = get_plan(xtype, p.sz, T, inplace, p.region)
scale = normalization(Y, p.region)
ScaledPlan(
rROCFFTPlan{complex(T),ROCFFT_INVERSE,inplace,N}(pp..., X, p.sz, xtype, p.region, false, T),
rROCFFTPlan{complex(T),ROCFFT_INVERSE,inplace,N,R,B}(pp..., X, p.sz, xtype, p.region, p.buffer, false, T),
scale)
end

# Inverse of a complex-to-real plan: builds a real-to-complex plan whose input has
# size `p.osz` and whose output has size `p.sz`, wrapped in a `ScaledPlan` scaled by
# `normalization(X, p.region)` (computed on the real-sized array `X`).
# NOTE(review): this span is a rendered diff hunk — the first signature line and the
# 4-parameter constructor call look like removed lines, each followed by its
# replacement; confirm against the merged file.
# NOTE(review): the new plan receives `p.buffer` itself, so both plans reference the
# same buffer object — confirm this is safe with respect to the plan finalizer,
# which frees `plan.buffer`.
# NOTE(review): `get_plan` and `key_T` receive the complex `T` here, whereas
# `plan_rfft` passes the real element type for the same transform type — confirm
# the cache key is consistent.
function plan_inv(p::rROCFFTPlan{T,ROCFFT_INVERSE,inplace,N}) where {T<:rocfftComplexes,N,inplace}
function plan_inv(p::rROCFFTPlan{T,ROCFFT_INVERSE,inplace,N,R,B}) where {T<:rocfftComplexes,N,inplace,R,B}
# `X` is an uninitialized real array of the plan's output size; it is passed to the
# constructor and to `normalization`.
X = ROCArray{real(T)}(undef, p.osz)
xtype = rocfft_transform_type_real_forward
pp = get_plan(xtype, p.osz, T, inplace, p.region)
scale = normalization(X, p.region)
ScaledPlan(
rROCFFTPlan{real(T),ROCFFT_FORWARD,inplace,N}(pp..., X, p.sz, xtype, p.region, true, T),
rROCFFTPlan{real(T),ROCFFT_FORWARD,inplace,N,R,B}(pp..., X, p.sz, xtype, p.region, p.buffer, true, T),
scale)
end

Expand Down Expand Up @@ -227,8 +249,6 @@
# Execute an out-of-place complex plan, reading from `X` and writing to `Y`.
# No size or type checks are performed here beyond the method signature.
# NOTE(review): this span is a rendered diff hunk — the `X = copy(X)` line and the
# TODO below it look like lines removed by this change (the PR title is "Don't
# perform a copy for each call of FFT"); confirm against the merged file.
function unsafe_execute!(
plan::cROCFFTPlan{T,K,false,N}, X::ROCArray{T,N}, Y::ROCArray{T},
) where {T,N,K}
X = copy(X) # since input array can also be modified
# TODO on 1.11 we need to manually cast `pointer(X)` to `Ptr{Cvoid}`.
update_stream!(plan)
# Input and output pointers are each passed as a one-element vector.
rocfft_execute(plan, [pointer(X),], [pointer(Y),], plan.execution_info)
end
Expand All @@ -238,24 +258,30 @@
X::ROCArray{T,N}, Y::ROCArray{<:rocfftComplexes,N},
) where {T<:rocfftReals,N}
@assert plan.xtype == rocfft_transform_type_real_forward
Xcopy = copy(X)
update_stream!(plan)
rocfft_execute(plan, [pointer(Xcopy),], [pointer(Y),], plan.execution_info)
rocfft_execute(plan, [pointer(X),], [pointer(Y),], plan.execution_info)
end

# Execute an out-of-place complex-to-real plan, reading from the complex array `X`
# and writing to the real array `Y`. The plan's transform type is asserted to be
# `rocfft_transform_type_real_inverse`.
# NOTE(review): this span is a rendered diff hunk — `Xcopy = copy(X)` and the
# `rocfft_execute` call on `Xcopy` look like removed lines, with the call on `X` as
# the replacement; confirm against the merged file.
function unsafe_execute!(
plan::rROCFFTPlan{T,ROCFFT_INVERSE,false,N},
X::ROCArray{T,N}, Y::ROCArray{<:rocfftReals,N},
) where {T<:rocfftComplexes,N}
@assert plan.xtype == rocfft_transform_type_real_inverse
Xcopy = copy(X)
update_stream!(plan)
rocfft_execute(plan, [pointer(Xcopy),], [pointer(Y),], plan.execution_info)
# Input and output pointers are each passed as a one-element vector.
rocfft_execute(plan, [pointer(X),], [pointer(Y),], plan.execution_info)
end

# Out-of-place `mul!`: run plan `p` on `x`, store the result in `y`, and return `y`.
# When the input element type `T` is complex and the output element type `Ty` is
# real, `x` is first copied into `p.buffer` and the transform reads from that copy;
# otherwise `x` is passed through directly.
# NOTE(review): this span is a rendered diff hunk — the first
# `assert_applicable`/`unsafe_execute!` pair (on `x`) looks like the removed code
# and the pair after the `if` (on `z`) like its replacement; confirm against the
# merged file.
function LinearAlgebra.mul!(y::ROCArray{Ty}, p::ROCFFTPlan{T,K,false}, x::ROCArray{T}) where {T,Ty,K}
assert_applicable(p, x, y)
unsafe_execute!(p, x, y)
if T<:Complex && Ty<:Real
# Out-of-place complex-to-real FFT will always overwrite input x.
# We copy the input x into an auxiliary buffer.
z = p.buffer
copyto!(z, x)
else
# No staging needed: use the caller's array as the transform input.
z = x
end
assert_applicable(p, z, y)
unsafe_execute!(p, z, y)
return y
end

Expand All @@ -279,5 +305,9 @@
# Apply an out-of-place complex-to-real plan to `x` and return the result in a newly
# allocated real array of size `p.osz`. The plan's transform type is asserted to be
# `rocfft_transform_type_real_inverse`.
# NOTE(review): this span is a rendered diff hunk — `mul!(y, p, x)` looks like the
# removed line and the lines after it like the replacement; confirm against the
# merged file.
# NOTE(review): `mul!` as shown in this diff already copies a complex input into
# `p.buffer` when the output is real, so staging `x` here as well appears to copy
# twice (the second time from `p.buffer` onto itself) — confirm whether
# `mul!(y, p, x)` alone would suffice.
function Base.:(*)(p::rROCFFTPlan{T,ROCFFT_INVERSE,false,N}, x::ROCArray{T,N}) where {T<:rocfftComplexes,N}
@assert p.xtype == rocfft_transform_type_real_inverse
y = ROCArray{real(T)}(undef, p.osz)
mul!(y, p, x)
# Out-of-place complex-to-real FFT will always overwrite input x.
# We copy the input x into an auxiliary buffer.
z = p.buffer
copyto!(z, x)
mul!(y, p, z)
end