Skip to content

Commit 8cd78c4

Browse files
committed
Ingestion of ML inference code from Danni Du
1 parent fab56f9 commit 8cd78c4

File tree

2 files changed

+645
-57
lines changed

2 files changed

+645
-57
lines changed

src/ocean_data_assim/MOM_oda_driver.F90

+43-20
Original file line numberDiff line numberDiff line change
@@ -30,7 +30,7 @@ module MOM_oda_driver_mod
3030
use ocean_da_types_mod, only : ensemble_control_struct, ocean_control_struct
3131
use ocean_da_core_mod, only : ocean_da_core_init, get_profiles
3232
use MOM_oda_ml_mod, only: oda_ml_init, oda_ml_end, oda_ml_inference
33-
use MOM_oda_ml_mod, only: ocean_oda_ml_struct
33+
use MOM_oda_ml_mod, only: ocean_oda_ml_data, ocean_oda_ml_config
3434
!This preprocessing directive enables the SPEAR online ensemble data assimilation
3535
!configuration. Existing community based APIs for data assimilation are currently
3636
!called offline for forecast applications using information read from a MOM6 state file.
@@ -101,7 +101,8 @@ module MOM_oda_driver_mod
101101
type(ensemble_control_struct), pointer :: Ocean_increment=> NULL() !< A separate structure for
102102
!! increment diagnostics
103103
type(ocean_control_struct), pointer :: Ocean_background_ave=> NULL() !< ocean averaged prior states in model space
104-
type(ocean_oda_ml_struct), pointer :: ml_CS => NULL()
104+
type(ocean_oda_ml_data), pointer :: ml_data => NULL()
105+
type(ocean_oda_ml_config), pointer :: ml_config => NULL()
105106
integer :: nk !< number of vertical layers used for DA
106107
type(ocean_grid_type), pointer :: Grid => NULL() !< MOM6 grid type and decomposition for the DA
107108
type(ocean_grid_type), pointer :: model_G => NULL() !< MOM6 grid type and decomposition for the model
@@ -159,6 +160,8 @@ module MOM_oda_driver_mod
159160
type(INC_CS) :: INC_CS !< A Structure containing integer file handles for bias adjustment
160161
integer :: id_inc_t !< A diagnostic handle for the temperature climatological adjustment
161162
integer :: id_inc_s !< A diagnostic handle for the salinity climatological adjustment
163+
integer :: id_inc_ml_t !< A diagnostic handle for the temperature climatological adjustment
164+
integer :: id_inc_ml_s !< A diagnostic handle for the salinity climatological adjustment
162165
integer :: answer_date !< The vintage of the order of arithmetic and expressions in the
163166
!! remapping invoked by the ODA driver. Values below 20190101 recover
164167
!! the answers from the end of 2018, while higher values use updated
@@ -435,9 +438,15 @@ subroutine init_oda(Time, G, GV, US, diag_CS, CS)
435438

436439
if (CS%do_ml_bias_adjustment) then
437440

438-
allocate(CS%ml_CS)
439-
call oda_ml_init(CS%ml_CS, CS%GV%ke)
441+
allocate(CS%ml_data)
442+
allocate(CS%ml_config)
443+
call oda_ml_init(CS%ml_config, CS%ml_data, CS%GV)
440444

445+
CS%id_inc_ml_t = register_diag_field('ocean_model', 'temp_ml_increment', diag_CS%axesTL, &
446+
Time, 'ocean potential temperature increments predicted by ML', 'degC', conversion=US%C_to_degC)
447+
CS%id_inc_ml_s = register_diag_field('ocean_model', 'salt_ml_increment', diag_CS%axesTL, &
448+
Time, 'ocean salinity increments predicted by ML', 'psu', conversion=US%S_to_ppt)
449+
441450
allocate(CS%T_ml_tend(G%isd:G%ied,G%jsd:G%jed,CS%GV%ke), source=0.0)
442451
allocate(CS%S_ml_tend(G%isd:G%ied,G%jsd:G%jed,CS%GV%ke), source=0.0)
443452

@@ -781,23 +790,37 @@ subroutine get_ML_bias_correction(Time, US, CS)
781790
!! Loop through all local gridpoints
782791
do j=CS%model_G%jsc,CS%model_G%jec ; do i=CS%model_G%isc,CS%model_G%iec
783792

784-
!! put local variables into ml_CS
785-
CS%ml_CS%T = CS%Ocean_background_ave%T(i,j,:)
786-
CS%ml_CS%S = CS%Ocean_background_ave%S(i,j,:)
787-
CS%ml_CS%latent = CS%Ocean_background_ave%latent(i,j)
788-
CS%ml_CS%sensible = CS%Ocean_background_ave%sensible(i,j)
789-
CS%ml_CS%lw = CS%Ocean_background_ave%lw(i,j)
790-
CS%ml_CS%sw = CS%Ocean_background_ave%sw(i,j)
793+
!! put local variables into ml_data
794+
CS%ml_data%T = CS%Ocean_background_ave%T(i,j,:)
795+
CS%ml_data%S = CS%Ocean_background_ave%S(i,j,:)
796+
CS%ml_data%U_left = CS%Ocean_background_ave%U(i-1,j,:)
797+
CS%ml_data%U_right = CS%Ocean_background_ave%U(i,j,:)
798+
CS%ml_data%V_north = CS%Ocean_background_ave%V(i,j,:)
799+
CS%ml_data%V_south = CS%Ocean_background_ave%V(i,j-1,:)
800+
CS%ml_data%latent = CS%Ocean_background_ave%latent(i,j)
801+
CS%ml_data%sensible = CS%Ocean_background_ave%sensible(i,j)
802+
CS%ml_data%lw = CS%Ocean_background_ave%lw(i,j)
803+
CS%ml_data%sw = CS%Ocean_background_ave%sw(i,j)
804+
CS%ml_data%taux_left = CS%Ocean_background_ave%taux(i-1,j)
805+
CS%ml_data%taux_right = CS%Ocean_background_ave%taux(i,j)
806+
CS%ml_data%tauy_north = CS%Ocean_background_ave%tauy(i,j)
807+
CS%ml_data%tauy_south = CS%Ocean_background_ave%tauy(i,j-1)
791808

809+
CS%ml_data%dyCu_left = CS%model_G%dyCu(i-1,j)
810+
CS%ml_data%dyCu_right = CS%model_G%dyCu(i,j)
811+
CS%ml_data%dxCv_north = CS%model_G%dxCv(i,j)
812+
CS%ml_data%dxCv_south = CS%model_G%dxCv(i,j-1)
813+
CS%ml_data%areacello = CS%model_G%areaT(i,j)
814+
792815
!! Call inference subroutine with the concatenated vector
793-
call oda_ml_inference(CS%ml_CS)
816+
call oda_ml_inference(CS%ml_config, CS%ml_data)
794817

795-
! CS%T_ml_tend(i,j,:) = CS%ml_CS%T_inc
796-
! CS%S_ml_tend(i,j,:) = CS%ml_CS%S_inc
818+
CS%T_ml_tend(i,j,:) = CS%ml_data%T_inc
819+
CS%S_ml_tend(i,j,:) = CS%ml_data%S_inc
797820
enddo; enddo
798821

799-
CS%T_ml_tend = CS%T_bc_tend * CS%ml_bias_adjustment_multiplier
800-
CS%S_ml_tend = CS%S_bc_tend * CS%ml_bias_adjustment_multiplier
822+
CS%T_ml_tend = CS%T_ml_tend * CS%ml_bias_adjustment_multiplier
823+
CS%S_ml_tend = CS%S_ml_tend * CS%ml_bias_adjustment_multiplier
801824

802825
call pass_var(CS%T_ml_tend, CS%domains(CS%ensemble_id))
803826
call pass_var(CS%S_ml_tend, CS%domains(CS%ensemble_id))
@@ -936,10 +959,10 @@ subroutine apply_oda_tracer_increments(dt, Time_end, G, GV, tv, h, CS)
936959
T_tend = T_tend + CS%T_tend
937960
S_tend = S_tend + CS%S_tend
938961
endif
939-
! if (CS%do_bias_adjustment ) then
940-
! T_tend = T_tend + CS%T_bc_tend
941-
! S_tend = S_tend + CS%S_bc_tend
942-
! endif
962+
if (CS%do_bias_adjustment ) then
963+
T_tend = T_tend + CS%T_bc_tend
964+
S_tend = S_tend + CS%S_bc_tend
965+
endif
943966
if (CS%do_ml_bias_adjustment ) then
944967
T_tend = T_tend + CS%T_ml_tend
945968
S_tend = S_tend + CS%S_ml_tend

0 commit comments

Comments
 (0)