Skip to content

Commit

Permalink
Add support for OneScalar cloud optics.
Browse files Browse the repository at this point in the history
Simplify test interface.
  • Loading branch information
sriharshakandala committed Jun 20, 2024
1 parent 968cf09 commit 566832f
Show file tree
Hide file tree
Showing 15 changed files with 193 additions and 130 deletions.
12 changes: 10 additions & 2 deletions ext/cuda/rte_longwave_1scalar.jl
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,7 @@ function rte_lw_noscat_solve!(
angle_disc::AngularDiscretization,
as::AtmosphericState,
lookup_lw::LookUpLW,
lookup_lw_cld::Union{LookUpCld, PadeCld, Nothing} = nothing,
lookup_lw_cld::Union{LookUpCld, Nothing} = nothing,
)
nlay, ncol = AtmosphericStates.get_dims(as)
nlev = nlay + 1
Expand All @@ -72,7 +72,7 @@ function rte_lw_noscat_solve_CUDA!(
ncol,
as::AtmosphericState,
lookup_lw::LookUpLW,
lookup_lw_cld::Union{LookUpCld, PadeCld, Nothing} = nothing,
lookup_lw_cld::Union{LookUpCld, Nothing} = nothing,
)
gcol = threadIdx().x + (blockIdx().x - 1) * blockDim().x # global id
nlev = nlay + 1
Expand All @@ -84,8 +84,16 @@ function rte_lw_noscat_solve_CUDA!(
flux_up_lw = flux_lw.flux_up
flux_dn_lw = flux_lw.flux_dn
flux_net_lw = flux_lw.flux_net
cloud_state = as.cloud_state
@inbounds for igpt in 1:n_gpt
ibnd = major_gpt2bnd[igpt]
if cloud_state isa CloudState
Optics.build_cloud_mask!(
view(cloud_state.mask_lw, :, gcol),
view(cloud_state.cld_frac, :, gcol),
cloud_state.mask_type,
)
end
igpt == 1 && set_flux_to_zero!(flux_lw, gcol)
compute_optical_props!(op, as, src_lw, gcol, igpt, lookup_lw, lookup_lw_cld)
rte_lw_noscat!(src_lw, bcs_lw, op, angle_disc, gcol, flux, igpt, ibnd, nlay, nlev)
Expand Down
17 changes: 10 additions & 7 deletions perf/benchmark.jl
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ import Logging
@info "------------------------------------------------- Benchmark: gray_atm"
@suppress_out begin
include(joinpath(root_dir, "test", "gray_atm_utils.jl"))
gray_atmos_lw_equil(ClimaComms.context(), OneScalar, NoScatLWRTE, FT; exfiltrate = true)
gray_atmos_lw_equil(ClimaComms.context(), NoScatLWRTE, FT; exfiltrate = true)
end
(; slv_lw, gray_as) = Infiltrator.exfiltrated
@info "gray_atm lw"
Expand All @@ -32,7 +32,7 @@ end
show(stdout, MIME("text/plain"), trial)
println()

gray_atmos_sw_test(ClimaComms.context(), OneScalar, NoScatSWRTE, FT, 1; exfiltrate = true)
gray_atmos_sw_test(ClimaComms.context(), NoScatSWRTE, FT, 1; exfiltrate = true)
(; slv_sw, as) = Infiltrator.exfiltrated
solve_sw!(slv_sw, as) # compile first
@info "gray_atm sw"
Expand All @@ -56,8 +56,6 @@ toler_sw = Dict(Float64 => Float64(1e-3), Float32 => Float32(0.04))

clear_sky(
ClimaComms.context(),
TwoStream,
TwoStream,
TwoStreamLWRTE,
TwoStreamSWRTE,
VmrGM,
Expand Down Expand Up @@ -96,13 +94,18 @@ println()
@info "------------------------------------------------- Benchmark: all_sky"
# @suppress_out begin
include(joinpath(root_dir, "test", "all_sky_utils.jl"))

toler_lw_noscat = Dict(Float64 => Float64(1e-5), Float32 => Float32(0.05))
toler_lw_2stream = Dict(Float64 => Float64(5), Float32 => Float32(5))
toler_sw = Dict(Float64 => Float64(1e-5), Float32 => Float32(0.06))

all_sky(
ClimaComms.context(),
TwoStream,
TwoStream,
TwoStreamLWRTE,
TwoStreamSWRTE,
FT;
FT,
toler_lw_2stream,
toler_sw;
use_lut = true,
cldfrac = FT(1),
exfiltrate = true,
Expand Down
51 changes: 50 additions & 1 deletion src/optics/CloudOptics.jl
Original file line number Diff line number Diff line change
@@ -1,4 +1,3 @@

"""
add_cloud_optics_2stream!(
τ,
Expand Down Expand Up @@ -57,6 +56,56 @@ to the TwoStream gas optics properties.
return nothing
end

@inline function add_cloud_optics_1scalar!(
τ,
cld_mask,
cld_r_eff_liq,
cld_r_eff_ice,
cld_path_liq,
cld_path_ice,
ice_rgh,
lkp_cld::LookUpCld,
ibnd;
)
@inbounds begin
nlay = length(τ)
lut_extliq, lut_ssaliq, lut_asyliq = LookUpTables.getview_liqdata(lkp_cld, ibnd)
lut_extice, lut_ssaice, lut_asyice = LookUpTables.getview_icedata(lkp_cld, ibnd, ice_rgh)
_, _, nsize_liq, nsize_ice, _ = lkp_cld.dims
radliq_lwr, radliq_upr, _, radice_lwr, radice_upr, _ = lkp_cld.bounds

for glay in 1:nlay
if cld_mask[glay]
# cloud liquid particles
τl, τl_ssa, _ = compute_lookup_cld_liq_props(
nsize_liq,
radliq_lwr,
radliq_upr,
lut_extliq,
lut_ssaliq,
lut_asyliq,
cld_r_eff_liq[glay],
cld_path_liq[glay],
)
# cloud ice particles
τi, τi_ssa, _ = compute_lookup_cld_ice_props(
nsize_ice,
radice_lwr,
radice_upr,
lut_extice,
lut_ssaice,
lut_asyice,
cld_r_eff_ice[glay],
cld_path_ice[glay],
)
# add cloud optical optics
τ[glay] += (τl - τl_ssa) + (τi - τi_ssa)
end
end
end
return nothing
end

@inline function add_cloud_optics_2stream!(
τ,
ssa,
Expand Down
22 changes: 21 additions & 1 deletion src/optics/Optics.jl
Original file line number Diff line number Diff line change
Expand Up @@ -138,7 +138,7 @@ Computes optical properties for the longwave problem.
gcol::Int,
igpt::Int,
lkp::LookUpLW,
lkp_cld::Union{LookUpCld, PadeCld, Nothing} = nothing,
lkp_cld::Union{LookUpCld, Nothing} = nothing,
)
nlay = AtmosphericStates.get_nlay(as)
(; vmr) = as
Expand Down Expand Up @@ -177,6 +177,26 @@ Computes optical properties for the longwave problem.
t_lev_dec = t_lev_inc
end
lev_source[nlay + 1, gcol] = lev_src_inc_prev
if !isnothing(lkp_cld)
cloud_state = as.cloud_state
cld_r_eff_liq = view(cloud_state.cld_r_eff_liq, :, gcol)
cld_r_eff_ice = view(cloud_state.cld_r_eff_ice, :, gcol)
cld_path_liq = view(cloud_state.cld_path_liq, :, gcol)
cld_path_ice = view(cloud_state.cld_path_ice, :, gcol)
cld_mask = view(cloud_state.mask_lw, :, gcol)

add_cloud_optics_1scalar!(
τ,
cld_mask,
cld_r_eff_liq,
cld_r_eff_ice,
cld_path_liq,
cld_path_ice,
cloud_state.ice_rgh,
lkp_cld,
ibnd;
)
end
end
return nothing
end
Expand Down
21 changes: 3 additions & 18 deletions src/optics/RTE.jl
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@ configurations for a non-scattering longwave simulation.
# Fields
$(DocStringExtensions.FIELDS)
"""
struct NoScatLWRTE{C, OP, SL <: SourceLWNoScat, BC <: LwBCs, FXBL, FXL <: FluxLW, AD}
struct NoScatLWRTE{C, OP <: OneScalar, SL <: SourceLWNoScat, BC <: LwBCs, FXBL, FXL <: FluxLW, AD}
"ClimaComms context"
context::C
"optical properties"
Expand All @@ -51,18 +51,8 @@ struct NoScatLWRTE{C, OP, SL <: SourceLWNoScat, BC <: LwBCs, FXBL, FXL <: FluxLW
end
Adapt.@adapt_structure NoScatLWRTE

function NoScatLWRTE(
::Type{FT},
::Type{DA},
::Type{OP},
context,
param_set,
nlay,
ncol,
sfc_emis,
inc_flux,
) where {FT, DA, OP}
op = OP(FT, ncol, nlay, DA)
function NoScatLWRTE(::Type{FT}, ::Type{DA}, context, param_set, nlay, ncol, sfc_emis, inc_flux) where {FT, DA}
op = OneScalar(FT, ncol, nlay, DA)
src = SourceLWNoScat(param_set, FT, DA, nlay, ncol)
bcs = LwBCs(sfc_emis, inc_flux)
fluxb = FluxLW(ncol, nlay, FT, DA)
Expand Down Expand Up @@ -106,8 +96,6 @@ function TwoStreamLWRTE(::Type{FT}, ::Type{DA}, context, param_set, nlay, ncol,
return TwoStreamLWRTE(context, op, src, bcs, fluxb, flux)
end

TwoStreamLWRTE(::Type{FT}, ::Type{DA}, ::Type{OP}, args...) where {FT, DA, OP} = TwoStreamLWRTE(FT, DA, args...)

"""
NoScatSWRTE(::Type{FT}, ::Type{DA}, context, nlay, ncol, swbcs...)
Expand Down Expand Up @@ -140,8 +128,6 @@ function NoScatSWRTE(::Type{FT}, ::Type{DA}, context, nlay, ncol, swbcs...) wher
return NoScatSWRTE(context, op, bcs, fluxb, flux)
end

NoScatSWRTE(::Type{FT}, ::Type{DA}, ::Type{OP}, args...) where {FT, DA, OP} = NoScatSWRTE(FT, DA, args...)

"""
TwoStreamSWRTE(::Type{FT}, ::Type{DA}, context, nlay, ncol, swbcs...)
Expand Down Expand Up @@ -177,5 +163,4 @@ function TwoStreamSWRTE(::Type{FT}, ::Type{DA}, context, nlay, ncol, swbcs...) w
return TwoStreamSWRTE(context, op, src, bcs, fluxb, flux)
end

TwoStreamSWRTE(::Type{FT}, ::Type{DA}, ::Type{OP}, args...) where {FT, DA, OP} = TwoStreamSWRTE(FT, DA, args...)
end
2 changes: 1 addition & 1 deletion src/rte/RTESolver.jl
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,7 @@ solve_lw!(
(; context, fluxb, flux, src, bcs, op, angle_disc)::NoScatLWRTE,
as::AtmosphericState,
lookup_lw::LookUpLW,
lookup_lw_cld::Union{LookUpCld, PadeCld},
lookup_lw_cld::LookUpCld,
) = rte_lw_noscat_solve!(context.device, fluxb, flux, src, bcs, op, angle_disc, as, lookup_lw, lookup_lw_cld)

solve_lw!((; context, fluxb, flux, src, bcs, op, angle_disc)::NoScatLWRTE, as::AtmosphericState, lookup_lw::LookUpLW) =
Expand Down
13 changes: 10 additions & 3 deletions src/rte/longwave1scalar.jl
Original file line number Diff line number Diff line change
Expand Up @@ -35,21 +35,28 @@ function rte_lw_noscat_solve!(
angle_disc::AngularDiscretization,
as::AtmosphericState,
lookup_lw::LookUpLW,
lookup_lw_cld::Union{LookUpCld, PadeCld, Nothing} = nothing,
lookup_lw_cld::Union{LookUpCld, Nothing} = nothing,
)
nlay, ncol = AtmosphericStates.get_dims(as)
nlev = nlay + 1
(; major_gpt2bnd) = lookup_lw.band_data
n_gpt = length(major_gpt2bnd)
τ = op.τ
Ds = angle_disc.gauss_Ds
cloud_state = as.cloud_state
bld_cld_mask = cloud_state isa CloudState
flux_up_lw = flux_lw.flux_up
flux_dn_lw = flux_lw.flux_dn
flux_net_lw = flux_lw.flux_net
@inbounds begin
for igpt in 1:n_gpt
ClimaComms.@threaded device for gcol in 1:ncol
ibnd = major_gpt2bnd[igpt]
if bld_cld_mask
Optics.build_cloud_mask!(
view(cloud_state.mask_lw, :, gcol),
view(cloud_state.cld_frac, :, gcol),
cloud_state.mask_type,
)
end
igpt == 1 && set_flux_to_zero!(flux_lw, gcol)
compute_optical_props!(op, as, src_lw, gcol, igpt, lookup_lw, lookup_lw_cld)
rte_lw_noscat!(src_lw, bcs_lw, op, angle_disc, gcol, flux, igpt, ibnd, nlay, nlev)
Expand Down
27 changes: 23 additions & 4 deletions test/all_sky.jl
Original file line number Diff line number Diff line change
Expand Up @@ -2,14 +2,33 @@ FT = get(ARGS, 1, Float64) == "Float32" ? Float32 : Float64
include("all_sky_utils.jl")

context = ClimaComms.context()
@testset "Cloudy (all-sky, Two-stream calculations using lookup table method" begin

toler_lw_noscat = Dict(Float64 => Float64(1e-5), Float32 => Float32(0.05))
toler_lw_2stream = Dict(Float64 => Float64(5), Float32 => Float32(5))
toler_sw = Dict(Float64 => Float64(1e-5), Float32 => Float32(0.06))

@testset "Cloudy-sky (gas + clouds) calculations using lookup table method, with non-scattering LW and TwoStream SW solvers" begin
@time all_sky(
context,
NoScatLWRTE,
TwoStreamSWRTE,
FT,
toler_lw_noscat,
toler_sw;
ncol = 128,
use_lut = true,
cldfrac = FT(1),
)
end

@testset "Cloudy-sky (gas + clouds) Two-stream calculations using lookup table method" begin
@time all_sky(
context,
TwoStream,
TwoStream,
TwoStreamLWRTE,
TwoStreamSWRTE,
FT;
FT,
toler_lw_2stream,
toler_sw;
ncol = 128,
use_lut = true,
cldfrac = FT(1),
Expand Down
29 changes: 10 additions & 19 deletions test/all_sky_dyamond_gpu_benchmark.jl
Original file line number Diff line number Diff line change
Expand Up @@ -28,15 +28,13 @@ include("read_all_sky.jl")

function benchmark_all_sky(
context,
::Type{OPLW},
::Type{OPSW},
::Type{SLVLW},
::Type{SLVSW},
::Type{FT};
ncol = 128,# repeats col#1 ncol times per RRTMGP example
use_lut::Bool = true,
cldfrac = FT(1),
) where {FT <: AbstractFloat, OPLW, OPSW, SLVLW, SLVSW}
) where {FT <: AbstractFloat, SLVLW, SLVSW}
overrides = (; grav = 9.80665, molmass_dryair = 0.028964, molmass_water = 0.018016)
param_set = RRTMGPParameters(FT, overrides)

Expand Down Expand Up @@ -97,11 +95,11 @@ function benchmark_all_sky(

# Setting up longwave problem---------------------------------------
inc_flux = nothing
slv_lw = SLVLW(FT, DA, OPLW, context, param_set, nlay, ncol, sfc_emis, inc_flux)
slv_lw = SLVLW(FT, DA, context, param_set, nlay, ncol, sfc_emis, inc_flux)
# Setting up shortwave problem---------------------------------------
inc_flux_diffuse = nothing
swbcs = (cos_zenith, toa_flux, sfc_alb_direct, inc_flux_diffuse, sfc_alb_diffuse)
slv_sw = SLVSW(FT, DA, OPSW, context, nlay, ncol, swbcs...)
slv_sw = SLVSW(FT, DA, context, nlay, ncol, swbcs...)
#------calling solvers
solve_lw!(slv_lw, as, lookup_lw, lookup_lw_cld)
trial_lw = @benchmark CUDA.@sync solve_lw!($slv_lw, $as, $lookup_lw, $lookup_lw_cld)
Expand All @@ -111,13 +109,14 @@ function benchmark_all_sky(
return trial_lw, trial_sw
end

function generate_gpu_allsky_benchmarks(FT, npts)
function generate_gpu_allsky_benchmarks(FT, npts, ::Type{SLVLW}, ::Type{SLVSW}) where {SLVLW, SLVSW}
context = ClimaComms.context()
# compute equivalent ncols for DYAMOND resolution
helems, nlevels, nlev_test, nq = 30, 64, 73, 4
ncols_dyamond = Int(ceil(helems * helems * 6 * nq * nq * (nlevels / nlev_test)))
println("\n")
printstyled("Running DYAMOND all-sky benchmark on $(context.device) device with $FT precision\n", color = 130)
printstyled("Longwave solver = $SLVLW; Shortwave solver = $SLVSW\n", color = 130)
printstyled("==============|====================================|==================================\n", color = 130)
printstyled(
" ncols | median time for longwave solver | median time for shortwave solver \n",
Expand All @@ -128,17 +127,7 @@ function generate_gpu_allsky_benchmarks(FT, npts)
ncols = unsafe_trunc(Int, cld(ncols_dyamond, 2^(pts - 1)))
ndof = ncols * nlev_test
sz_per_fld_gb = ndof * sizeof(FT) / 1024 / 1024 / 1024
trial_lw, trial_sw = benchmark_all_sky(
context,
TwoStream,
TwoStream,
TwoStreamLWRTE,
TwoStreamSWRTE,
FT;
ncol = ncols,
use_lut = true,
cldfrac = FT(1),
)
trial_lw, trial_sw = benchmark_all_sky(context, SLVLW, SLVSW, FT; ncol = ncols, use_lut = true, cldfrac = FT(1))
Printf.@printf(
"%10i | %25s| %25s \n",
ncols,
Expand All @@ -150,5 +139,7 @@ function generate_gpu_allsky_benchmarks(FT, npts)
return nothing
end

generate_gpu_allsky_benchmarks(Float64, 4)
generate_gpu_allsky_benchmarks(Float32, 4)
for FT in (Float32, Float64)
generate_gpu_allsky_benchmarks(FT, 4, NoScatLWRTE, TwoStreamSWRTE)
generate_gpu_allsky_benchmarks(FT, 4, TwoStreamLWRTE, TwoStreamSWRTE)
end
Loading

0 comments on commit 566832f

Please sign in to comment.