"""Plotting helpers for detected African easterly wave troughs and tracks.

Every public function renders onto a PlateCarree map and writes a PNG file.
"""
import os.path
from collections import defaultdict
from contextlib import contextmanager
from datetime import datetime
from pathlib import Path

from matplotlib import pyplot as plt
import numpy as np
import matplotlib
import matplotlib.patches as patches
import matplotlib.path as mpath
import cartopy.crs as ccrs
import cartopy.feature as cfeature

from enstools.feature.util.data_utils import pb_str_to_datetime
import enstools.feature.identification.african_easterly_waves.configuration as cfg

# Shared map layout used by every plot in this module.
_FIG_SIZE = (11, 4)
_X_TICKS = [-100, -95, -85, -75, -65, -55, -45, -35, -25, -15, -5, 5, 15, 25, 35]
_Y_TICKS = [0, 10, 20, 30]
# NOTE(review): the eastern bound (-45) lies west of most x ticks; matplotlib
# may widen the view when the ticks are set afterwards - confirm the intended
# extent. The values are kept exactly as before.
_EXTENT = [-100, -45, -10, 35]
# Filled-contour levels for the optional background field.
_CV_LEVELS = np.asarray([0, 0.5, 1, 2, 3]) * 1e-5


@contextmanager
def _map_figure():
    """Yield ``(fig, ax)`` for a single PlateCarree map and always close it.

    The figure is closed explicitly on exit. The earlier cleanup sequence
    (``plt.figure().clear(); plt.close(); plt.cla(); plt.clf()``) closed a
    freshly created empty figure instead and left the plotted one open, so
    figures accumulated with every call.
    """
    fig, ax = plt.subplots(nrows=1, ncols=1, figsize=_FIG_SIZE,
                           subplot_kw=dict(projection=ccrs.PlateCarree()))
    try:
        yield fig, ax
    finally:
        plt.close(fig)


def _plot_background_field(ax, cv):
    """Draw ``cv`` as filled contours on ``ax``; do nothing if ``cv`` is None."""
    if cv is not None:
        cv.plot.contourf(ax=ax, levels=_CV_LEVELS, vmin=0, extend='max', cmap='Blues')


def _add_trough_line(ax, line_pts, color):
    """Add one trough line (points with ``lon``/``lat``) to ``ax`` in ``color``."""
    vertices = [[p.lon, p.lat] for p in line_pts]
    if not vertices:
        # A path needs at least one vertex; there is nothing to draw.
        return
    outline = mpath.Path(vertices)
    ax.add_patch(patches.PathPatch(outline, linewidth=2, facecolor='none', edgecolor=color))


def _finish_and_save(fig, ax, figure_name, title=None):
    """Add coastlines, borders, extent and ticks, then write the PNG."""
    ax.coastlines()
    ax.add_feature(cfeature.BORDERS.with_scale('50m'))
    ax.set_extent(_EXTENT, crs=ccrs.PlateCarree())
    ax.set_yticks(_Y_TICKS, crs=ccrs.PlateCarree())
    ax.set_xticks(_X_TICKS, crs=ccrs.PlateCarree())
    if title is not None:
        ax.set_title(title)
    fig.savefig(figure_name, format='png')


def _plot_track_nodes(nodes, figure_name, cv=None):
    """Plot the troughs of ``nodes`` coloured by time and save to ``figure_name``.

    The colour of each trough is its time mapped linearly between the time of
    the first and the last node onto the 'rainbow' colormap. Raises IndexError
    if ``nodes`` is empty.
    """
    # Resolve the time bounds before a figure is created, so an empty node
    # list fails without leaving a figure behind.
    min_time = pb_str_to_datetime(nodes[0].time).timestamp()
    max_time = pb_str_to_datetime(nodes[-1].time).timestamp()
    time_span = max_time - min_time
    # matplotlib.cm.get_cmap was removed in matplotlib 3.9.
    cmap = plt.get_cmap('rainbow')

    with _map_figure() as (fig, ax):
        _plot_background_field(ax, cv)
        for node in nodes:
            if time_span > 0:
                node_time = pb_str_to_datetime(node.time).timestamp()
                time_weight = (node_time - min_time) / time_span
            else:
                time_weight = 1.0
            _add_trough_line(ax, node.object.properties.line_pts, cmap(time_weight))

        print("Plot to " + str(figure_name))
        _finish_and_save(fig, ax, figure_name,
                         title=nodes[0].time + " - " + nodes[-1].time)
    return figure_name


def plot_wavetroughs(ts, fig_name, cv=None):
    """Plot the wave state: all wave troughs of one time step of a set.

    :param ts: time step (pb2.Timestep) whose ``objects`` are drawn in red.
    :param fig_name: output name prefix; ':' is replaced by '_' and
        '_aew_troughs.png' is appended.
    :param cv: optional field drawn as filled contours below the troughs.
    :return: name of the written figure file.
    """
    figure_name = fig_name.replace(':', '_') + '_aew_troughs.png'
    with _map_figure() as (fig, ax):
        _plot_background_field(ax, cv)
        for obj in ts.objects:
            _add_trough_line(ax, obj.properties.line_pts, 'red')
        _finish_and_save(fig, ax, figure_name)
    return figure_name


def plot_timesteps_from_desc(object_desc, cv=None):
    """Plot, for each set and each time step, everything that was detected.

    :param object_desc: description providing ``sets``, each with ``timesteps``.
    :param cv: optional dataset; for each time step its ``cv`` variable is
        selected and used as background. If None, no background is drawn.
    """
    from enstools.feature.util.data_utils import get_subset_by_description

    for set_idx, od_set in enumerate(object_desc.sets):
        fn = cfg.plot_dir + "ts_set_" + str(set_idx)
        # cv is optional: only subset it when it was actually provided.
        cv_set = get_subset_by_description(cv, od_set, '2d') if cv is not None else None

        for ts in od_set.timesteps:
            cv_st = cv_set.sel(time=ts.valid_time).cv if cv_set is not None else None
            fnt = fn + "_" + ts.valid_time
            print(fnt)
            fout_name = plot_wavetroughs(ts, fnt, cv=cv_st)
            print("Plot to " + fout_name)

    return None


def plot_track(track, fn):
    """Plot all troughs of one track, coloured by time.

    :param track: track whose edge parents are the nodes to plot.
    :param fn: file name (without extension) below ``cfg.plot_dir``.
    :return: name of the written figure file.
    """
    nodes = [edge.parent for edge in track.edges]
    return _plot_track_nodes(nodes, cfg.plot_dir + fn + '.png')


def plot_differences(set_graph, tracks, cv=None):
    """Plot, per time step, which troughs of the graph are part of a track.

    Troughs of ``set_graph`` that also occur in one of ``tracks`` and those
    that do not are passed separately to ``plot_ts_part_of_track``.

    :param set_graph: object whose ``graph.edges`` parents are all troughs.
    :param tracks: iterable of tracks (each with ``edges``).
    :param cv: optional field; the matching time is selected per plot.
    """
    set_nodes = [edge.parent for edge in set_graph.graph.edges]
    in_track = [False] * len(set_nodes)

    for track in tracks:
        for edge in track.edges:
            try:
                in_track[set_nodes.index(edge.parent)] = True
            except ValueError:
                # track node is not part of the graph node list
                continue

    # group the troughs by their time string
    wts_in_tracks = defaultdict(list)
    wts_not_in_tracks = defaultdict(list)
    for node, is_tracked in zip(set_nodes, in_track):
        target = wts_in_tracks if is_tracked else wts_not_in_tracks
        target[node.time].append(node)

    for date in sorted(set(wts_in_tracks) | set(wts_not_in_tracks)):
        fig_name = cfg.plot_dir + "part_of_wave_" + date.replace(':', '_') + ".png"
        # cv defaults to None, so it must not be indexed unconditionally.
        cv_ss = cv.sel(time=date) if cv is not None else None
        plot_ts_part_of_track(wts_in_tracks[date], wts_not_in_tracks[date], fig_name, cv_ss)


def plot_ts_part_of_track(wts_part_of_tracks, wt_not_part_of_tracks, fig_name, cv=None):
    """Plot troughs that belong to a track in lime and all others in red.

    :param wts_part_of_tracks: nodes drawn in lime.
    :param wt_not_part_of_tracks: nodes drawn in red (drawn afterwards).
    :param fig_name: full name of the PNG file to write.
    :param cv: optional field drawn as filled contours below the troughs.
    """
    with _map_figure() as (fig, ax):
        _plot_background_field(ax, cv)
        for node in wts_part_of_tracks:
            _add_trough_line(ax, node.object.properties.line_pts, 'lime')
        for node in wt_not_part_of_tracks:
            _add_trough_line(ax, node.object.properties.line_pts, 'red')

        print("Save to " + fig_name)
        _finish_and_save(fig, ax, fig_name)


def plot_track_from_graph(track_desc, fig_name_prefix, cv=None):
    """Plot all troughs of a track description, coloured by time.

    :param track_desc: track whose edge parents are the nodes to plot.
    :param fig_name_prefix: output prefix; '_troughs.png' is appended.
    :param cv: optional field; its first time index is used as background.
    :return: name of the written figure file.
    """
    nodes = [edge.parent for edge in track_desc.edges]
    if cv is not None:
        cv = cv.isel(time=0)
    return _plot_track_nodes(nodes, fig_name_prefix + '_troughs.png', cv=cv)


def plot_wt_list(nodes, fn):
    """Plot the troughs of the given nodes in red and save to ``fn + '.png'``.

    :return: name of the written figure file.
    """
    figure_name = fn + '.png'
    with _map_figure() as (fig, ax):
        for node in nodes:
            _add_trough_line(ax, node.object.properties.line_pts, 'red')
        _finish_and_save(fig, ax, figure_name)
    return figure_name


def plot_track_in_ts(track):
    """Plot one figure per time step containing that step's troughs of a track."""
    prefix = cfg.plot_dir + 'singletrack_'

    nodes_by_time = defaultdict(list)
    for edge in track.edges:
        node = edge.parent
        nodes_by_time[node.time.replace(':', '_')].append(node)

    for time_key, nodes in nodes_by_time.items():
        plot_wt_list(nodes, prefix + time_key)


def plot_tracks_from_desc(graph_desc, ds=None):
    """Plot every track of every set in ``graph_desc``.

    :param graph_desc: description providing ``sets``, each with ``tracks``.
    :param ds: currently unused; tracks are plotted without a background field.
    """
    for set_idx, od_set in enumerate(graph_desc.sets):
        for track_idx, track in enumerate(od_set.tracks):
            fn = cfg.plot_dir + 'set_' + str(set_idx) + "_track_" + str(track_idx)
            plot_track_from_graph(track, fn, cv=None)