Commit
Merge pull request #160 from geocryology/bugfix-pl-elev-interp
fix/improve interpolation quality
nicholas512 authored Jan 29, 2025
2 parents 1af867e + 61b64ba commit 5e8d828
Showing 10 changed files with 222 additions and 54 deletions.
2 changes: 1 addition & 1 deletion globsim/_version.py
@@ -1,3 +1,3 @@
__version__ = "4.1.5"
__version__ = "4.2.0"


2 changes: 1 addition & 1 deletion globsim/download/JraDownloadHandler.py
@@ -332,7 +332,7 @@ def copy_variable_in_chunks(src_var, dst_var, max_mem_gb = 4):

size_in_gb = src_var.data.nbytes * 1e-9
n_time = src_var.shape[0]
n_chunks = np.ceil(size_in_gb / max_mem_gb)
n_chunks = int(np.ceil(size_in_gb / max_mem_gb))
chunk_size = int(n_time / n_chunks)
chunk_dims = list(src_var.shape)
chunk_dims[0] = chunk_size
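The cast matters because np.ceil returns a NumPy float, and the chunked copy loop (not shown in this hunk) presumably needs an integer count. A minimal sketch with hypothetical numbers, not taken from this file:

    import numpy as np

    n_chunks = np.ceil(7.3 / 4)          # 2.0 -- a numpy.float64, not an int
    # range(n_chunks)                    # would raise TypeError: cannot be interpreted as an integer
    n_chunks = int(np.ceil(7.3 / 4))     # 2  -- safe to use as a loop count
    list(range(n_chunks))                # [0, 1]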
25 changes: 20 additions & 5 deletions globsim/globsim_cli.py
@@ -7,25 +7,30 @@
from globsim import globsim_convert, globsim_download, globsim_scale, globsim_interpolate
from globsim.globsim_convert import export_styles
from globsim._version import __version__
from globsim.view.interp_vis import main_args as interp_vis_main


def configure_logging(args):
def configure_logging(args: argparse.Namespace):
# logging.basicConfig(format='%(asctime)s %(asctime)s ')
logger.setLevel(args.level)
try:
level = args.level
except AttributeError:
level = logging.INFO

logger.setLevel(level)
console_formatter = logging.Formatter('%(asctime)s %(levelname)s %(message)s', datefmt="%H:%M:%S")
# console_formatter = logging.Formatter('%(message)s', datefmt="%H:%M:%S")

ch = logging.StreamHandler()
ch.setLevel(args.level)
ch.setLevel(level)
ch.setFormatter(console_formatter)
logger.addHandler(ch)

if args.logfile:
if getattr(args, "logfile", None):
file_formatter = logging.Formatter('%(asctime)s %(name)s %(levelname)s %(message)s', datefmt="%Y-%m-%d %H:%M:%S")
logfile = args.logfile # TODO: write logfile to project directory if missing
fh = logging.FileHandler(logfile)
fh.setLevel(args.level)
fh.setLevel(level)
fh.setFormatter(file_formatter)
logger.addHandler(fh)
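The try/except and getattr guards let configure_logging accept argparse Namespaces that lack a level or logfile attribute, presumably the case for a subcommand that does not define those options. A minimal sketch of the pattern (not from the diff; getattr is equivalent to the try/except used for level):

    import argparse
    import logging

    args = argparse.Namespace()                      # no 'level' or 'logfile' attributes set
    level = getattr(args, "level", logging.INFO)     # falls back to INFO instead of raising AttributeError
    logfile = getattr(args, "logfile", None)         # None skips the file handler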

@@ -50,6 +55,7 @@ def main():
interpolate = subparsers.add_parser("interpolate")
scale = subparsers.add_parser("scale")
convert = subparsers.add_parser("convert")
view = subparsers.add_parser("view")

mainparser.add_argument("--version", action='version', version=f"GlobSim version {__version__}")

@@ -94,6 +100,15 @@ def main():
help="(optional) The name of the site you want to export. If not provided, all sites will be exported")
convert.add_argument('-p', "--profile", dest='profile', default=None, type=str, help="Path to an 'export profile' TOML file (geotop only) ")

view.set_defaults(func=interp_vis_main)
view.add_argument("file", nargs="?", type=str, help="file to plot")
view.add_argument("--file", dest='file', type=str, help="file to plot")
view.add_argument("-v", "--var", type=str, dest='variable', help="variable to plot")
view.add_argument("-a", "--agg", choices=["1h", "6h", "D", "ME", "YE"], dest='aggregate', default="ME", help="aggregate data")
view.add_argument("-o", "--output", type=str, dest='output', help="output directory")



if len(sys.argv) == 1:
mainparser.print_help(sys.stderr)
sys.exit(1)
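With the new subparser wired to interp_vis_main, interpolated output can be plotted from the command line, along the lines of: globsim view <interpolated_pl.nc> -v <variable> -a D -o <output_dir>. The file, variable, and output directory here are placeholders; the -a choices are the aggregation codes listed above.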
8 changes: 0 additions & 8 deletions globsim/globsim_main.py
@@ -33,19 +33,11 @@

from multiprocessing.dummy import Pool as ThreadPool
from globsim.LazyLoader import LazyLoader
from globsim.download.era5_monthly import download_threadded

download = LazyLoader('globsim.download')
interpolate = LazyLoader('globsim.interpolate')
scale = LazyLoader('globsim.scale')

# from globsim.download import *
# from globsim.scale import *
# from globsim.interpolate import *

# from globsim.JRA import JRAdownload, JRAinterpolate, JRAscale


def GlobsimDownload(pfile, multithread=True,
ERA5=True,
ERA5ENS=True, MERRA=True, JRA=True, JRA3Q=True, JRA3QG=True):
88 changes: 71 additions & 17 deletions globsim/interp.py
@@ -11,7 +11,7 @@ def ele_interpolate(elevation: np.ndarray, h: float, nl: int):
Parameters
----------
elevation : np.ndarray
elevation of data at (time, level) in meters. Must be monotonically decreasing along axis 1
elevation of data at (time, level) in meters. Must be monotonically increasing/decreasing along axis 1
h : float
elevation of station
nl : int
@@ -26,27 +26,81 @@
vb : np.ndarray
array of indices corresponding to nearest level below station
"""

# difference in elevation, level directly above will be >= 0
if np.all(np.diff(elevation, axis=1) > 0): # monotonically increasing
elev_diff = -(elevation - h)
elif np.all(np.diff(elevation, axis=1) < 0): # monotonically decreasing
elev_diff = elevation - h
# difference in elevation, level directly above station will be >= 0
elev_diff = elevation - h
i_ravel = np.arange(elevation.shape[0]) * elevation.shape[1] # indices for first level in each time
if np.all(np.diff(elevation, axis=1) > 0): # elevations monotonically increasing
inverted = True
elif np.all(np.diff(elevation, axis=1) < 0): # elevations monotonically decreasing
inverted = False
else:
raise ValueError("Elevation must be monotonically decreasing or increasing")
if inverted:
va = np.argmin(elev_diff + (elev_diff < 0) * 1e6, axis=1) # level indices that are directly above station.
mask = (va != 0) # station is above lowest level
va += i_ravel # will be added to raveled data
vb = va - mask # next-lowest level index when va != 0
else:
# vector of level indices that fall directly above station.
# Apply after ravel() of data.
va = np.argmin(elev_diff + (elev_diff < 0) * 100000, axis=1)
# mask for situations where station is below lowest level
mask = va < (nl - 1)
va += i_ravel

# Vector level indices that fall directly below station.
# Apply after ravel() of data.
vb = va + mask # +1 when OK, +0 when below lowest level

# vector of level indices that fall directly above station.
# Apply after ravel() of data.
va = np.argmin(elev_diff + (elev_diff < 0) * 100000, axis=1)
# mask for situations where station is below lowest level
mask = va < (nl - 1)
va += np.arange(elevation.shape[0]) * elevation.shape[1]
return elev_diff, va, vb

# Vector level indices that fall directly below station.
# Apply after ravel() of data.
vb = va + mask # +1 when OK, +0 when below lowest level

return elev_diff, va, vb
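For orientation, a minimal sketch (not part of this commit) of how the bracketing indices behave in the monotonically decreasing case, using a hypothetical two-timestep, three-level grid:

    import numpy as np

    elevation = np.array([[3000., 2000., 1000.],
                          [3100., 2100., 1100.]])    # (time, level), decreasing along axis 1
    h = 1500.0                                       # station elevation

    elev_diff = elevation - h                        # >= 0 for levels above the station
    va = np.argmin(elev_diff + (elev_diff < 0) * 100000, axis=1)    # array([1, 1]): level just above
    mask = va < (elevation.shape[1] - 1)             # False only when the station sits below the lowest level
    va = va + np.arange(elevation.shape[0]) * elevation.shape[1]    # array([1, 4]): offsets into the raveled (time*level) data
    vb = va + mask                                   # array([2, 5]): level just below (unchanged when below the grid)

    data = np.array([[270., 280., 290.],
                     [271., 281., 291.]]).ravel()
    data[va], data[vb]                               # ([280., 281.], [290., 291.]) -- the pair blended as data[va]*wa + data[vb]*wb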
def extrapolate_below_grid(elevation: np.ndarray, data:np.ndarray, h: float):
"""
Parameters
----------
elevation : np.ndarray
elevation of data at (time, level) in meters. Must be monotonically increasing/decreasing along axis 1
data : np.ndarray
data to interpolate.
h : float
elevation of station
Returns
-------
epol : np.ndarray
extrapolated data for station below lowest level (masked where station is above lowest level)
"""
nl = elevation.shape[1]
i_ravel = np.arange(elevation.shape[0]) * nl # indices for first level in each time

if np.all(np.diff(elevation, axis=1) > 0):
inverted = True # elevations monotonically increasing
elif np.all(np.diff(elevation, axis=1) < 0):
inverted = False # elevations monotonically decreasing
else:
raise ValueError("Elevation must be monotonically decreasing or increasing")

if inverted:
delta_s = elevation[:, 0] - h
delta_L = elevation[:, 1] - elevation[:, 0]
i_lowest_diff = i_ravel
i_lowest_data = i_ravel
R = -(delta_s / delta_L)

else:
delta_s = elevation[:, -1] - h
delta_L = elevation[:, -2] - elevation[:, -1]
i_lowest_diff = i_ravel + nl - 2
i_lowest_data = i_ravel + nl - 1
R = delta_s / delta_L

below_lowest = np.where(np.min(elevation, axis=1) > h)[0] # indices of times where station is below lowest level
delta_V = np.diff(data)[i_lowest_diff] # difference between levels at each time [nt,]
epol = np.ma.MaskedArray(data=data[i_lowest_data] + R * delta_V, mask=True) # extrapolated values
epol.mask[below_lowest] = False

return epol
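A worked sketch of the arithmetic (hypothetical numbers, decreasing levels) for one timestep where the station sits below the lowest level:

    # lowest level: 1200 m, value 5.0; second-lowest level: 2200 m, value -1.5; station: h = 1000 m
    delta_s = 1200.0 - 1000.0        # 200 m below the lowest level
    delta_L = 2200.0 - 1200.0        # 1000 m between the two lowest levels
    R = delta_s / delta_L            # 0.2
    delta_V = 5.0 - (-1.5)           # lowest minus second-lowest value, as np.diff yields on the raveled data
    epol = 5.0 + R * delta_V         # 6.3 -- the between-level gradient extended downward to the station

At timesteps where the station is not below the grid, the returned value stays masked, so the callers keep the ordinary two-level interpolation there.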


def calculate_weights(elev_diff, va, vb) -> tuple:
13 changes: 7 additions & 6 deletions globsim/interpolate/ERA5interpolate.py
@@ -12,7 +12,7 @@
from globsim.common_utils import variables_skip, str_encode
from globsim.interpolate.GenericInterpolate import GenericInterpolate
from globsim.nc_elements import netcdf_base, new_interpolated_netcdf
from globsim.interp import ele_interpolate, calculate_weights
from globsim.interp import ele_interpolate, calculate_weights, extrapolate_below_grid
import globsim.constants as const

logger = logging.getLogger('globsim.interpolate')
@@ -402,8 +402,13 @@ def levels2elevation(self, ncfile_in, ncfile_out):
# read data from netCDF
logger.debug(f"Reading {var}")
data = ncf.variables[var][:,:,n].ravel()

ipol = data[va] * wa + data[vb] * wb # interpolated value

if self.extrapolate_below_grid:
extrapolated_values = extrapolate_below_grid(elevation, data, h)
ipol = np.where(~extrapolated_values.mask, extrapolated_values, ipol)

rootgrp.variables[var][:,n] = ipol # write to file

rootgrp.vars_written = " ".join(set(str(rootgrp.vars_written).split(" ") + [var]))
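The merge step keeps the between-level interpolation wherever the extrapolated array is masked and substitutes the extrapolated value only where the station falls below the lowest level; a minimal sketch with hypothetical values:

    import numpy as np

    ipol = np.array([2.0, 4.0, 6.0])                                        # between-level interpolation
    extrap = np.ma.MaskedArray([0.0, 5.5, 0.0], mask=[True, False, True])   # valid only at index 1
    np.where(~extrap.mask, extrap.data, ipol)                               # array([2. , 5.5, 6. ])
    # the diff passes the MaskedArray itself to np.where; the selected values are the same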
@@ -487,7 +492,3 @@ def _process_sf(self):
with xr.open_mfdataset(self.get_input_file_paths('sf'), decode_times=False) as sf:
self.ERA2station(sf, self.getOutFile('sf'),
self.stations, varlist, date=self.date)


# vv = nc.Dataset(self.getOutFile('pl'))
# xx = xr.open_dataset(self.getOutFile('sa'))
21 changes: 19 additions & 2 deletions globsim/interpolate/GenericInterpolate.py
@@ -62,8 +62,25 @@ def __init__(self, ifile: str, **kwargs):
self._skip_sa = kwargs.get('skip_sa', False)
self._skip_sf = kwargs.get('skip_sf', False)
self._skip_pl = kwargs.get('skip_pl', False)
self.resume = kwargs.get('resume', False)

self.resume = bool(self.read_and_report(kwargs, 'resume', False))
self.extrapolate_below_grid = bool(self.read_and_report(kwargs, 'extrapolate_below_grid', True))

def read_and_report(self, kwargs, name=None, default=None):
value = kwargs.get(name, "MISSING FROM KWARGS")

if value == "MISSING FROM KWARGS":
value = self.par.get(name, "MISSING FROM TOML")
if value == "MISSING FROM TOML":
value = default
setfrom = "DEFAULT"
else:
setfrom = "TOML "
else:
setfrom = "CLI " # keep the value supplied via the command line
logger.debug(f"{setfrom} {name}: {value}")
return value
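read_and_report resolves each option by checking the keyword arguments (labelled CLI) first, then the control-file (TOML) settings in self.par, then the hard-coded default, and logs which source supplied the value. A minimal illustration with hypothetical keys and values:

    # called on a GenericInterpolate subclass instance 'interp', assuming the control
    # file's interpolate section contains: resume = false
    interp.read_and_report({"resume": True}, "resume", False)   # -> True,  logged as CLI
    interp.read_and_report({}, "resume", False)                 # -> False, logged as TOML
    interp.read_and_report({}, "some_new_option", 1.0)          # -> 1.0,   logged as DEFAULT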

@property
def vn_time(self):
return 'time'
12 changes: 7 additions & 5 deletions globsim/interpolate/JRAinterpolate.py
@@ -12,7 +12,7 @@
from globsim.common_utils import str_encode, variables_skip
from globsim.interpolate.GenericInterpolate import GenericInterpolate
from globsim.nc_elements import netcdf_base, new_interpolated_netcdf
from globsim.interp import calculate_weights, ele_interpolate
from globsim.interp import calculate_weights, ele_interpolate, extrapolate_below_grid

logger = logging.getLogger('globsim.interpolate')

@@ -333,12 +333,14 @@ def levels2elevation(self, ncfile_in, ncfile_out):
data = np.repeat([ncf.variables['level'][:]],
len(time),axis=0).ravel()
else:
# read data from netCDF
data = ncf.variables[var][:,:,n].ravel()

multvawa = np.multiply(data[va], wa)
multvbwb = np.multiply(data[vb], wb)
ipol = multvawa + multvbwb
ipol = np.multiply(data[va], wa) + np.multiply(data[vb], wb)

if self.extrapolate_below_grid:
extrapolated_values = extrapolate_below_grid(elevation, data, h)
ipol = np.where(~extrapolated_values.mask, extrapolated_values, ipol)

rootgrp.variables[var][:,n] = ipol # assign to file
rootgrp.vars_written = " ".join(set(str(rootgrp.vars_written).split(" ") + [var]))

21 changes: 12 additions & 9 deletions globsim/interpolate/MERRAinterpolate.py
@@ -12,7 +12,7 @@
from globsim.common_utils import str_encode, variables_skip
from globsim.interpolate.GenericInterpolate import GenericInterpolate
from globsim.nc_elements import netcdf_base
from globsim.interp import calculate_weights, ele_interpolate
from globsim.interp import calculate_weights, ele_interpolate, extrapolate_below_grid

import warnings
warnings.filterwarnings("ignore", category=UserWarning, module='netCDF4')
@@ -353,20 +353,23 @@ def levels2elevation(self, ncfile_in, ncfile_out):
# pressure [hPa] variable from levels, shape: (time, level)
data = np.repeat([ncf.variables['level'][:]],
len(time),axis=0).ravel()
ipol = data[va] * wa + data[vb] * wb # interpolated value

# if mask[pixel] == false, pass the maximum of pressure level to pixles
level_highest = ncf.variables['level'][:][-1]

# 2025-01-28 [NB]: I don't think this block of code is necessary
"""level_highest = ncf.variables['level'][:][-1]
level_lowest = ncf.variables['level'][:][0]
for j, value in enumerate(ipol):
if value == level_highest:
ipol[j] = level_lowest

ipol[j] = level_lowest"""
else:
# read data from netCDF
data = ncf.variables[var][:,:,n].ravel()
ipol = data[va] * wa + data[vb] * wb # interpolated value

ipol = data[va] * wa + data[vb] * wb # interpolated value

if self.extrapolate_below_grid:
extrapolated_values = extrapolate_below_grid(elevation, data, h)
ipol = np.where(~extrapolated_values.mask, extrapolated_values, ipol)

rootgrp.variables[var][:,n] = ipol # assign to file
rootgrp.vars_written = " ".join(set(str(rootgrp.vars_written).split(" ") + [var]))

