from ..object_compare_tracking import ObjectComparisonTracking
import enstools.feature.identification.african_easterly_waves.configuration as cfg
from datetime import datetime, timedelta
from collections import defaultdict
import statistics
from shapely.geometry import Polygon, LineString
from enstools.feature.util.data_utils import pb_str_to_datetime
from enstools.feature.util.graph import DataGraph

zero_dt = timedelta(seconds=0)


class AEWTracking(ObjectComparisonTracking):
    """
    Tracking of AEWs. Based on comparisons on wavetroughs
    """

    def __init__(self, **kwargs):
        # kwargs are accepted but not used.
        # NOTE(review): ObjectComparisonTracking.__init__ is not called here; confirm the
        # base class does not rely on it.
        self.config = cfg
        self.set_max_delta_compare(cfg.max_cmp_delta)

    def correspond(self, time1, obj1, time2, obj2):
        """Return True if ``obj2`` at ``time2`` is the same wave trough as ``obj1`` at ``time1``.

        The trough line of ``obj1`` is shifted zonally by the minimum and the maximum
        configured speed (``config.speed_deg_per_h[0]`` / ``[1]``, degrees per hour) over
        the time gap. The polygon enclosed by those two shifted lines is the region where
        the trough is expected at ``time2``. The objects correspond if the trough line of
        ``obj2`` intersects that region.

        Troughs with fewer than two points cannot form a shapely polygon/line string and
        never correspond.
        """
        pts1 = list(obj1.properties.line_pts)
        pts2 = list(obj2.properties.line_pts)
        if len(pts1) < 2 or len(pts2) < 2:
            return False

        # total_seconds(): ``timedelta.seconds`` ignores the days component and integer
        # division would truncate sub-hour remainders.
        time_diff_h = (time2 - time1).total_seconds() / 3600.0

        min_speed = self.config.speed_deg_per_h[0]
        max_speed = self.config.speed_deg_per_h[1]

        # Line shifted with min speed (forward), then with max speed (backward), so the
        # coordinates trace a closed ring around the predicted area.
        predicted_polygon = Polygon(
            [(pt.lon + min_speed * time_diff_h, pt.lat) for pt in pts1]
            + [(pt.lon + max_speed * time_diff_h, pt.lat) for pt in reversed(pts1)]
        )
        actual_linestring = LineString([(pt.lon, pt.lat) for pt in pts2])

        return bool(predicted_polygon.intersects(actual_linestring))

    def postprocess(self, obj_desc):
        """Post-processing hook; for AEWs it only prints a marker and leaves ``obj_desc`` untouched."""
        print("postprocess @ aew")
        return

    @staticmethod
    def get_nodes_after(time_delta, edges, start_index):
        """Follow the graph from ``edges[start_index]`` for (at least) ``time_delta``.

        Children of the start node are followed recursively; each step consumes the time
        difference between parent and child. A path ends when the remaining time is
        <= 0 or the node has no children.

        Args:
            time_delta: timedelta still to be covered.
            edges: sequence of graph edges; ``edges[i].parent`` is node ``i`` and
                ``edges[i].childs`` are its successors.
            start_index: index of the start node in ``edges``.

        Returns:
            defaultdict(list) mapping the index of each end node to the indices of the
            nodes on the path(s) to it, including the end node and ``start_index``.
            Indices may repeat if several paths reach the same end node.
        """
        start_edge = edges[start_index]
        if time_delta <= zero_dt or not start_edge.childs:
            paths = defaultdict(list)
            paths[start_index] = [start_index]
            return paths

        start_time = pb_str_to_datetime(start_edge.parent.time)
        obj_node_list = [edge.parent for edge in edges]

        paths = defaultdict(list)
        for child in start_edge.childs:
            remaining_time_delta = time_delta - (pb_str_to_datetime(child.time) - start_time)
            # index of the connected node in the graph
            child_idx = obj_node_list.index(child)
            after_nodes = AEWTracking.get_nodes_after(remaining_time_delta, edges, child_idx)
            # prepend-by-append this node to every path found below it
            for end_idx, path in after_nodes.items():
                paths[end_idx].extend(path + [start_index])
        return paths

    def adjust_track(self, tracking_graph):
        """Filter a track, keeping only the nodes that behave like an AEW trough.

        Tracks without edges or spanning less than ``config.duration_threshold`` are
        discarded (``None`` is returned). Otherwise a node is kept if it lies on a
        path (see ``get_nodes_after``) starting at some node that
          * lasts at least ``config.avg_speed_t``,
          * consists of troughs whose mean bounding-box lon/lat extent ratio is <= 1
            (i.e. not predominantly zonally elongated), and
          * has an average bounding-box-centre longitude change per hour below
            ``config.avg_speed_min_deg_per_h``.

        Returns:
            ``None`` for discarded tracks, else a ``DataGraph`` of the kept edges. If no
            edge was removed, the graph itself is appended as its only track; otherwise
            tracks are regenerated with ``generate_tracks(apply_filter=True)``.
        """
        track = tracking_graph.graph
        nodes = [edge.parent for edge in track.edges]
        if not nodes:
            return None

        times = [pb_str_to_datetime(node.time) for node in nodes]
        # min/max instead of first/last so the duration does not depend on edge order
        duration = max(times) - min(times)
        if duration < self.config.duration_threshold:
            return None

        boxes = [node.object.properties.bb for node in nodes]
        center_lons = [(bb.max.lon + bb.min.lon) / 2.0 for bb in boxes]
        lon_lat_ratios = []
        for bb in boxes:
            extent_lon = bb.max.lon - bb.min.lon
            extent_lat = bb.max.lat - bb.min.lat
            # zero latitudinal extent = purely zonal object -> treat as infinitely wide
            lon_lat_ratios.append(extent_lon / extent_lat if extent_lat != 0 else float("inf"))

        avg_speed_t = self.config.avg_speed_t
        avg_speed_window_h = avg_speed_t.total_seconds() / 3600

        # every node starts as "discard"; nodes on a qualifying path are switched to keep
        keep_node = [False] * len(nodes)

        for node_idx in range(len(nodes)):
            # {end_node_idx: [node indices on the way]}
            ds_nodes = AEWTracking.get_nodes_after(avg_speed_t, track.edges, node_idx)
            for end_node_idx, way_node_idxs in ds_nodes.items():
                # unique: indices repeat if multiple paths end in the same node
                way_node_idxs = list(set(way_node_idxs))

                # track ended before avg_speed_t elapsed -> cannot judge speed from here
                if times[end_node_idx] - times[node_idx] < avg_speed_t:
                    continue

                # troughs along the way predominantly horizontally oriented -> discard
                if statistics.mean([lon_lat_ratios[i] for i in way_node_idxs]) > 1:
                    continue

                # average zonal speed of the bounding-box centre over the window
                deg_per_hr_avg = (center_lons[end_node_idx] - center_lons[node_idx]) / avg_speed_window_h
                # NOTE(review): "<" means fast enough only if the threshold is negative
                # (westward motion); confirm against the configuration.
                if deg_per_hr_avg < self.config.avg_speed_min_deg_per_h:
                    for way_node_idx in way_node_idxs:
                        keep_node[way_node_idx] = True

        kept_edges = [edge for edge, keep in zip(track.edges, keep_node) if keep]

        # Return object: copy of the input set with only the kept edges and no tracks.
        # Regenerating tracks from it also splits tracks and removes short ones.
        reduced_tg = tracking_graph.tr_tech.pb_reference.TrackableSet()
        reduced_tg.CopyFrom(tracking_graph.set_desc)
        del reduced_tg.graph.edges[:]
        del reduced_tg.tracks[:]
        reduced_tg.graph.edges.extend(kept_edges)

        # clean up -> degenerate nodes (child without connections)
        dg = DataGraph(reduced_tg, self, fix=True)

        # If the edge count is unchanged nothing was filtered: don't regenerate tracks,
        # otherwise we get stuck in an infinite loop.
        if len(dg.graph.edges) == len(tracking_graph.graph.edges):
            dg.set_desc.tracks.append(dg.set_desc.graph)
            print("Nothing changed.")
            return dg

        print("Regenerate.")
        dg.generate_tracks(apply_filter=True)
        # TODO: redo clim, plot which WTs not considered
        return dg