Skip to content
Snippets Groups Projects
Commit 4299be40 authored by he7273's avatar he7273
Browse files

merge for forecasts, support steps

parents fc6744ba 636fde10
No related branches found
No related tags found
No related merge requests found
...@@ -10,7 +10,7 @@ import cartopy.crs as ccrs ...@@ -10,7 +10,7 @@ import cartopy.crs as ccrs
from .processing import populate_object, compute_cv from .processing import populate_object, compute_cv
from skimage.draw import line_aa from skimage.draw import line_aa
from enstools.feature.util.enstools_utils import get_u_var, get_v_var, get_vertical_dim, get_longitude_dim, get_latitude_dim from enstools.feature.util.enstools_utils import get_u_var, get_v_var, get_vertical_dim, get_longitude_dim, get_latitude_dim, get_init_time_dim, get_valid_time_dim
import threading import threading
from skimage.draw import line from skimage.draw import line
...@@ -59,9 +59,24 @@ class AEWIdentification(IdentificationStrategy): ...@@ -59,9 +59,24 @@ class AEWIdentification(IdentificationStrategy):
lat_range = self.config.data_lat lat_range = self.config.data_lat
clim_file = self.config.get_clim_file() clim_file = self.config.get_clim_file()
u_name = self.config.u_dim if self.config.u_dim is not None else get_u_var(dataset)
v_name = self.config.v_dim if self.config.v_dim is not None else get_v_var(dataset)
if u_name is None or v_name is None:
print("Could not locate u and v fields in dataset. Needed to compute advection terms.")
exit()
# restrict dimensions only to the ones present in u/v
all_dims = [d for d in dataset.dims]
for d in all_dims:
if d not in dataset[u_name].dims:
dataset = dataset.drop_dims(d)
level_str = get_vertical_dim(dataset) level_str = get_vertical_dim(dataset)
lat_str = get_latitude_dim(dataset) lat_str = get_latitude_dim(dataset)
lon_str = get_longitude_dim(dataset) lon_str = get_longitude_dim(dataset)
init_time_str = get_init_time_dim(dataset)
valid_time_str = get_valid_time_dim(dataset)
if os.path.isfile(clim_file): if os.path.isfile(clim_file):
cv_clim = xr.open_dataset(clim_file) cv_clim = xr.open_dataset(clim_file)
...@@ -83,37 +98,37 @@ class AEWIdentification(IdentificationStrategy): ...@@ -83,37 +98,37 @@ class AEWIdentification(IdentificationStrategy):
# --------------- SUBSET DATA ACCORDING TO CFG # --------------- SUBSET DATA ACCORDING TO CFG
start_date_dt = np.datetime64(self.config.start_date) if self.config.start_date is not None else None start_date_dt = np.datetime64(self.config.start_date) if self.config.start_date is not None else None
end_date_dt = np.datetime64(self.config.end_date) if self.config.end_date is not None else None end_date_dt = np.datetime64(self.config.end_date) if self.config.end_date is not None else None
# if data is lon=0..360, change it to -180..180 # if data is lon=0..360, change it to -180..180
dataset.coords[lon_str] = (dataset.coords[lon_str] + 180) % 360 - 180 dataset.coords[lon_str] = (dataset.coords[lon_str] + 180) % 360 - 180
dataset = dataset.sortby(dataset[lon_str]) dataset = dataset.sortby(dataset[lon_str])
filter_time_str = init_time_str if init_time_str is not None else valid_time_str
# get the data we want to investigate # get the data we want to investigate
dataset = dataset.sortby(lat_str) # in case of descending dataset = dataset.sortby(lat_str) # in case of descending
dataset = dataset.sel( dataset = dataset.sel(
**{lat_str: slice(lat_range[0], lat_range[1])}, **{lat_str: slice(lat_range[0], lat_range[1])},
**{lon_str: slice(lon_range[0], lon_range[1])}, **{lon_str: slice(lon_range[0], lon_range[1])},
time=slice(start_date_dt, end_date_dt)) **{filter_time_str: slice(start_date_dt, end_date_dt)})
if len(dataset.time.values) == 0: if len(dataset.time.values) == 0:
print("Given start and end time leads to no data to process.") print("Given start and end time leads to no data to process.")
exit(1) exit(1)
u_name = self.config.u_dim if self.config.u_dim is not None else get_u_var(dataset)
v_name = self.config.v_dim if self.config.v_dim is not None else get_v_var(dataset)
if u_name is None or v_name is None:
print("Could not locate u and v fields in dataset. Needed to compute advection terms.")
exit()
# dataset = dataset.expand_dims('level') # dataset = dataset.expand_dims('level')
# level_str = 'level' # level_str = 'level'
if level_str in dataset[u_name].dims: if level_str in dataset[u_name].dims:
dataset = dataset.sel(**{level_str: self.config.levels}) # 3-D wind field, select levels dataset = dataset.sel(**{level_str: self.config.levels}) # 3-D wind field, select levels
else: elif level_str is not None:
dataset = dataset.sel(**{level_str: self.config.levels}) # 2-D, reduce 3-D field too dataset = dataset.sel(**{level_str: self.config.levels}) # 2-D, reduce 3-D field too
dataset[u_name] = dataset[u_name].expand_dims('level') dataset[u_name] = dataset[u_name].expand_dims('level')
dataset[v_name] = dataset[v_name].expand_dims('level') dataset[v_name] = dataset[v_name].expand_dims('level')
else:
dataset[u_name] = dataset[u_name].expand_dims(level=self.config.levels)
dataset[v_name] = dataset[v_name].expand_dims(level=self.config.levels)
level_str = 'level'
# rename cv_clim dimensions to be same as in data. # rename cv_clim dimensions to be same as in data.
cv_clim = cv_clim.rename({'lat': lat_str, 'lon': lon_str}) cv_clim = cv_clim.rename({'lat': lat_str, 'lon': lon_str})
...@@ -207,6 +222,8 @@ class AEWIdentification(IdentificationStrategy): ...@@ -207,6 +222,8 @@ class AEWIdentification(IdentificationStrategy):
def identify(self, data_chunk: xr.Dataset, **kwargs): def identify(self, data_chunk: xr.Dataset, **kwargs):
objs = [] objs = []
trough_mask_cur = data_chunk.trough_mask trough_mask_cur = data_chunk.trough_mask
if np.isnan(trough_mask_cur).all():
return data_chunk, objs
def clip(tup, mint, maxt): def clip(tup, mint, maxt):
return np.clip(tup, mint, maxt) return np.clip(tup, mint, maxt)
......
...@@ -17,8 +17,6 @@ xr.set_options(keep_attrs=True) ...@@ -17,8 +17,6 @@ xr.set_options(keep_attrs=True)
import numpy as np import numpy as np
from pprint import pprint from pprint import pprint
pipeline = FeaturePipeline(african_easterly_waves_pb2, processing_mode='2d') pipeline = FeaturePipeline(african_easterly_waves_pb2, processing_mode='2d')
# in_files_all_cv_data = cfg.cv_data_ex # in_files_all_cv_data = cfg.cv_data_ex
...@@ -84,11 +82,11 @@ for trackable_set in od.sets: ...@@ -84,11 +82,11 @@ for trackable_set in od.sets:
ds = pipeline.get_data() ds = pipeline.get_data()
ds_set = get_subset_by_description(ds, trackable_set, '2d') ds_set = get_subset_by_description(ds, trackable_set, '2d')
# plot differences (passed filtering / did not pass) # plot differences (passed filtering / did not pass)
plot_differences(g, tracks, cv=ds_set.cv) # plot_differences(g, tracks, cv=ds_set.cv)
# no out data besides plots on kitweather # no out data besides plots on kitweather
if sys.argv[1] == '-kw': if len(sys.argv) > 1 and sys.argv[1] == '-kw':
# delete old plots # delete old plots
subdirs = [dI for dI in os.listdir(plot_dir) if os.path.isdir(os.path.join(plot_dir,dI))] subdirs = [dI for dI in os.listdir(plot_dir) if os.path.isdir(os.path.join(plot_dir,dI))]
for sd in subdirs: # for each subdir in plot dir for sd in subdirs: # for each subdir in plot dir
......
...@@ -8,7 +8,7 @@ import bisect ...@@ -8,7 +8,7 @@ import bisect
import logging import logging
from enstools.feature.util.data_utils import SplitDimension, print_lock from enstools.feature.util.data_utils import SplitDimension, print_lock
from itertools import product from itertools import product
from enstools.feature.util.data_utils import get_split_dimensions, get_time_split_dimension from enstools.feature.util.data_utils import get_split_dimensions, get_valid_time_split_dimension, valid_time_to_dt
class IdentificationStrategy(ABC): class IdentificationStrategy(ABC):
...@@ -158,7 +158,7 @@ class IdentificationStrategy(ABC): ...@@ -158,7 +158,7 @@ class IdentificationStrategy(ABC):
# dimension names along which data is split (init_time, level, ens, ...) # dimension names along which data is split (init_time, level, ens, ...)
split_dimensions = get_split_dimensions(dataset, split_dimensions = get_split_dimensions(dataset,
self.processing_mode) # {ENUM.init_time: "init_time", len, is_scalar} self.processing_mode) # {ENUM.init_time: "init_time", len, is_scalar}
valid_time_sd = get_time_split_dimension(dataset) valid_time_sd = get_valid_time_split_dimension(dataset)
# make sure is sorted so indices work out at the end # make sure is sorted so indices work out at the end
dataset = dataset.sortby([sd.name for sd in split_dimensions]) dataset = dataset.sortby([sd.name for sd in split_dimensions])
...@@ -191,6 +191,7 @@ class IdentificationStrategy(ABC): ...@@ -191,6 +191,7 @@ class IdentificationStrategy(ABC):
# init time # init time
# TODO simplify this? # TODO simplify this?
init_time_cur = None
if SplitDimension.SplitDimensionDim.INIT_TIME in split_dims_set_enum: if SplitDimension.SplitDimensionDim.INIT_TIME in split_dims_set_enum:
init_time_idx = split_dims_set_enum.index(SplitDimension.SplitDimensionDim.INIT_TIME) init_time_idx = split_dims_set_enum.index(SplitDimension.SplitDimensionDim.INIT_TIME)
init_time_cur = dataset.coords[split_dims_set[init_time_idx].name].data[elem[init_time_idx]] init_time_cur = dataset.coords[split_dims_set[init_time_idx].name].data[elem[init_time_idx]]
...@@ -216,11 +217,13 @@ class IdentificationStrategy(ABC): ...@@ -216,11 +217,13 @@ class IdentificationStrategy(ABC):
for t in range(valid_time_sd.size): for t in range(valid_time_sd.size):
timestep = s_i.timesteps.add() timestep = s_i.timesteps.add()
valid_time_cur = dataset.coords[valid_time_sd.name].data[t] if not valid_time_sd.is_scalar else dataset[ valid_time_data = dataset.coords[valid_time_sd.name].data[t] if not valid_time_sd.is_scalar else dataset[
valid_time_sd.name] valid_time_sd.name]
timestep.valid_time = str(np.datetime_as_string(valid_time_cur, unit='s')) valid_time_dt = valid_time_to_dt(valid_time_data, init_time_cur)
timestep.valid_time = str(np.datetime_as_string(valid_time_dt, unit='s'))
if not valid_time_sd.is_scalar: if not valid_time_sd.is_scalar:
cur_indices[valid_time_sd.name] = valid_time_cur cur_indices[valid_time_sd.name] = valid_time_data
# ID to access this object block in protobuf set # ID to access this object block in protobuf set
sets_id = len(self.pb_dataset.sets) - 1 sets_id = len(self.pb_dataset.sets) - 1
...@@ -233,7 +236,6 @@ class IdentificationStrategy(ABC): ...@@ -233,7 +236,6 @@ class IdentificationStrategy(ABC):
# map the blocks, computation here # map the blocks, computation here
mb = xr.map_blocks(self.identify_block, ds_rechunked, args=[index_field], kwargs=dict_split_arrays, mb = xr.map_blocks(self.identify_block, ds_rechunked, args=[index_field], kwargs=dict_split_arrays,
template=ds_rechunked) template=ds_rechunked)
print("Start")
ds_mapped = mb.compute() # num_workers=4 ds_mapped = mb.compute() # num_workers=4
# postprocess # postprocess
......
...@@ -4,7 +4,7 @@ from datetime import datetime ...@@ -4,7 +4,7 @@ from datetime import datetime
from enstools.misc import get_ensemble_dim, get_time_dim from enstools.misc import get_ensemble_dim, get_time_dim
from enstools.feature.util.enstools_utils import get_vertical_dim, get_init_time_dim, get_longitude_dim, \ from enstools.feature.util.enstools_utils import get_vertical_dim, get_init_time_dim, get_longitude_dim, \
get_latitude_dim get_latitude_dim, get_possible_time_dims, get_valid_time_dim
from multiprocessing import Lock from multiprocessing import Lock
lock = Lock() lock = Lock()
...@@ -46,7 +46,8 @@ def print_lock(msg): ...@@ -46,7 +46,8 @@ def print_lock(msg):
with lock: with lock:
print(msg) print(msg)
def get_time_split_dimension(dataset): def get_valid_time_split_dimension(dataset):
""" """
Get the split-dimension object for the valid_time dimension. Get the split-dimension object for the valid_time dimension.
...@@ -59,7 +60,7 @@ def get_time_split_dimension(dataset): ...@@ -59,7 +60,7 @@ def get_time_split_dimension(dataset):
The SplitDimension object of the valid time. The SplitDimension object of the valid time.
""" """
# search time # search time
time_dim = get_time_dim(dataset) time_dim = get_valid_time_dim(dataset)
time_dim_in_input = (time_dim is not None) time_dim_in_input = (time_dim is not None)
if time_dim_in_input and time_dim in dataset.dims: if time_dim_in_input and time_dim in dataset.dims:
time_dim_len = len(dataset.coords[time_dim].values) time_dim_len = len(dataset.coords[time_dim].values)
...@@ -71,6 +72,21 @@ def get_time_split_dimension(dataset): ...@@ -71,6 +72,21 @@ def get_time_split_dimension(dataset):
return split_dim return split_dim
def valid_time_to_dt(valid_time_data, init_time_data):
    """
    Convert a valid-time value into an absolute datetime64.

    Parameters
    ----------
    valid_time_data : np.datetime64 or np.timedelta64
        The value read from the valid-time coordinate. May already be an
        absolute datetime, or a step (timedelta) relative to the init time.
    init_time_data : np.datetime64 or None
        The init time to add a timedelta step onto; only required when
        valid_time_data is a timedelta.

    Returns
    -------
    np.datetime64 :
        The absolute valid time. Exits the process if the value can not
        be converted.
    """
    value_dtype = valid_time_data.dtype

    # already an absolute datetime: nothing to do
    if np.issubdtype(value_dtype, np.datetime64):
        return valid_time_data

    # a step relative to the init time: need the init time to resolve it
    if np.issubdtype(value_dtype, np.timedelta64):
        if init_time_data is None:
            print("Timedeltas in valid time, but no init time given.")
            exit(1)
        return init_time_data + valid_time_data

    print("Cant convert valid time dimension value " + str(valid_time_data) + " to datetime. Make sure it is a datetime or the step dimension is a timedelta, and the init time a datetime.")
    exit(1)
def get_split_dimensions(dataset: xr.Dataset, processing_mode): def get_split_dimensions(dataset: xr.Dataset, processing_mode):
""" """
...@@ -89,10 +105,25 @@ def get_split_dimensions(dataset: xr.Dataset, processing_mode): ...@@ -89,10 +105,25 @@ def get_split_dimensions(dataset: xr.Dataset, processing_mode):
# TODO copy paste here - simplify with getattr() and enum strings? -> care time dimension! # TODO copy paste here - simplify with getattr() and enum strings? -> care time dimension!
# TODO but then also need to generalize get_X_dim(ds)... # TODO but then also need to generalize get_X_dim(ds)...
# TODO changing:
# get all dimensions besides latitude and longitude, and try to parallelize them.
lat_dim = get_latitude_dim(dataset)
lon_dim = get_longitude_dim(dataset)
lev_dim = get_vertical_dim(dataset)
possible_split_dims = []
for dim in dataset.dims:
if dim != lat_dim and dim != lon_dim:
if dim == lev_dim and processing_mode == '3d':
continue
possible_split_dims.append(dim)
init_time_dim = get_init_time_dim(dataset)
valid_time_dim = get_valid_time_dim(dataset)
# search ensemble # search ensemble
ensemble_dim = get_ensemble_dim(dataset) ensemble_dim = get_ensemble_dim(dataset)
ensemble_dim_in_input = (ensemble_dim is not None) ensemble_dim_in_input = (ensemble_dim is not None)
if ensemble_dim_in_input and ensemble_dim in dataset.dims: if ensemble_dim_in_input and ensemble_dim in possible_split_dims:
ensemble_dim_len = len(dataset.coords[ensemble_dim].values) ensemble_dim_len = len(dataset.coords[ensemble_dim].values)
split_dims.append( split_dims.append(
SplitDimension(SplitDimension.SplitDimensionDim.ENSEMBLE_MEMBER, ensemble_dim, ensemble_dim_len, False)) SplitDimension(SplitDimension.SplitDimensionDim.ENSEMBLE_MEMBER, ensemble_dim, ensemble_dim_len, False))
...@@ -101,12 +132,11 @@ def get_split_dimensions(dataset: xr.Dataset, processing_mode): ...@@ -101,12 +132,11 @@ def get_split_dimensions(dataset: xr.Dataset, processing_mode):
split_dims.append(SplitDimension(SplitDimension.SplitDimensionDim.ENSEMBLE_MEMBER, ensemble_dim, 1, True)) split_dims.append(SplitDimension(SplitDimension.SplitDimensionDim.ENSEMBLE_MEMBER, ensemble_dim, 1, True))
# search init_time # search init_time
init_time_dim = get_init_time_dim(dataset)
init_time_dim_in_input = (init_time_dim is not None) init_time_dim_in_input = (init_time_dim is not None)
if init_time_dim_in_input and init_time_dim in dataset.dims: if init_time_dim_in_input and init_time_dim in possible_split_dims:
init_time_dim_len = len(dataset.coords[init_time_dim].values) init_time_dim_len = len(dataset.coords[init_time_dim].values)
split_dims.append( split_dims.append(
SplitDimension(SplitDimension.SplitDimensionDim.INIT_TIME, init_time_dim, init_time_dim_len, False)) SplitDimension(SplitDimension.SplitDimensionDim.INIT_TIME, init_time_dim, init_time_dim_len, False))
elif not init_time_dim_in_input and init_time_dim in dataset.variables: elif not init_time_dim_in_input and init_time_dim in dataset.variables:
init_time_dim_len = 1 init_time_dim_len = 1
split_dims.append(SplitDimension(SplitDimension.SplitDimensionDim.INIT_TIME, init_time_dim, 1, True)) split_dims.append(SplitDimension(SplitDimension.SplitDimensionDim.INIT_TIME, init_time_dim, 1, True))
......
...@@ -136,6 +136,27 @@ def get_vertical_dim(ds): ...@@ -136,6 +136,27 @@ def get_vertical_dim(ds):
return None return None
def get_possible_time_dims(ds):
    """
    Get the list of dimensions that can be associated with 'time'.

    Checks the dataset's dimensions against the known set of time-like
    dimension names and returns all matches, preserving the priority
    order of the candidate list.

    Parameters
    ----------
    ds : xarray.Dataset or xarray.DataArray

    Returns
    -------
    list of str :
        names of all time-like dimensions found in ds; empty list if none.
    """
    v_names = ["init_time", "time", "times", "valid_time", "step"]
    found = []
    for v_name in v_names:
        if v_name in ds.dims:
            logging.debug("get_possible_time_dims: found name '%s'" % v_name)
            found.append(v_name)
    return found
def get_init_time_dim(ds):
    """
    Get the name of the init_time dimension from a dataset or array.

    Parameters
    ----------
    ds : xarray.Dataset or xarray.DataArray

    Returns
    -------
    str or None :
        if no init_time dimension was found, None is returned.
    """
    time_dims = get_possible_time_dims(ds)

    # an explicit 'init_time' dimension always wins
    if 'init_time' in time_dims:
        return 'init_time'

    if len(time_dims) > 2:
        print("Cant resolve that many time dimensions: " + str(time_dims))
        exit(1)

    if len(time_dims) == 2:
        # no explicit 'init_time'. With a 'step' dimension present,
        # the other time dimension must be the init time.
        if 'step' not in time_dims:
            print("Cant resolve init_time: " + str(time_dims) + " Consider renaming to init_time.")
            exit(1)
        return time_dims[1] if time_dims[0] == 'step' else time_dims[0]

    return None
def get_valid_time_dim(ds):
    """
    Get the name of the valid_time dimension from a dataset or array.

    Parameters
    ----------
    ds : xarray.Dataset or xarray.DataArray

    Returns
    -------
    str or None :
        if no valid time dimension was found, None is returned.
    """
    time_dims = get_possible_time_dims(ds)

    # an explicit 'valid_time' dimension always wins
    if 'valid_time' in time_dims:
        return 'valid_time'

    # a lone 'time' dimension is taken to be the valid time
    if len(time_dims) == 1 and time_dims[0] == 'time':
        return 'time'

    if len(time_dims) > 2:
        print("Cant resolve that many time dimensions: " + str(time_dims))
        exit(1)

    if len(time_dims) == 2:
        if 'init_time' in time_dims:
            # the non-init dimension holds the valid times
            return time_dims[1] if time_dims[0] == 'init_time' else time_dims[0]
        if 'step' in time_dims:
            # if step in times, use step
            return 'step'
        print("Cant resolve valid_time: " + str(time_dims) + " Consider renaming to valid_time.")
        exit(1)

    return None
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment