import numpy as np
from abc import ABC, abstractmethod
import xarray as xr
# from enstools.feature.identification._proto_gen import identification_pb2
from enstools.feature.util.data_utils import get_subset_by_description, squeeze_nodes
from multiprocessing.pool import ThreadPool as Pool
from functools import partial

class TrackingTechnique(ABC):
    """
    Base abstract class for feature tracking algorithms.
    Implementations need to override the abstract track() method.
    """

    def __init__(self):
        self.pb_reference = None
        self.graph = None

    @abstractmethod
    def track(self, track_set, subset: xr.Dataset):  # TODO update docstrings
        """
        Abstract tracking method. This gets called for each timeline (list of timesteps) of the feature
        descriptions. Such a timeline can for example be a reforecast or a member of an ensemble forecast, in
        which detected objects should be tracked over multiple timestamps. This method gets called in parallel
        for all timelines in the dataset.

        This method should compute links between corresponding objects of two consecutive timestamps. Each
        object in a timestep has its unique ID, and a computed tuple (id1, id2) indicates that object id1 from
        timestamp t is the same object as id2 from timestamp t+1. One tuple is an edge in the tracking graph.

        Parameters
        ----------
        track_set : identification_pb2.TrackingSet
            The data to be tracked.
        subset : xarray.Dataset
            Subset to be tracked. Forwarded from the identification.

        Returns
        -------
        connections : list of tuples of int
            The connections forming the paths of the objects through the timesteps.
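
        Examples
        --------
        A hedged sketch of a trivial implementation: link objects of two
        consecutive timesteps whenever their IDs match (pure persistence;
        real trackers would match by position or spatial overlap instead)::

            connections = []
            for t_idx in range(len(track_set.timesteps) - 1):
                cur_ids = {o.id for o in track_set.timesteps[t_idx].objects}
                next_ids = {o.id for o in track_set.timesteps[t_idx + 1].objects}
                connections.extend((i, i) for i in cur_ids & next_ids)
            return connections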
"""
connections = []
return connections

    @abstractmethod
    def postprocess(self, object_desc):  # TODO update docstrings
        """Hook called once on the object description after all sets have been tracked."""
        pass

    def track_set(self, set_idx, obj_desc=None, dataset=None):
        obj_set = obj_desc.sets[set_idx]
        # get the according subset
        dataset_sel = get_subset_by_description(dataset, obj_set)
        print("Track " + str(obj_set)[:40] + " (" + str(set_idx) + ")")
        nodes = self.track(obj_set, dataset_sel)  # TODO check subsets in here on more complex DS
        print("Nodes before squeeze: " + str(len(nodes)))
        nodes = squeeze_nodes(nodes)
        print("Nodes after squeeze: " + str(len(nodes)))
        # TODO if squeezed, generate tracks from the squeezed nodes.
        obj_set.ref_graph.connections.extend(nodes)

    def execute(self, object_desc, dataset_ref: xr.Dataset):
        """
        Execute the tracking procedure.
        The description is split into the different timelines, which can be processed in parallel.

        Parameters
        ----------
        object_desc : identification_pb2.DatasetDescription TODO?
            The description of the detected features from the identification technique.
        dataset_ref : xarray.Dataset
            Reference to the dataset used in the pipeline.
        """
        # track all sets in parallel
        pool = Pool()
        pool.map(partial(self.track_set, obj_desc=object_desc, dataset=dataset_ref),
                 range(len(object_desc.sets)))  # iterate over sets
        self.postprocess(object_desc)  # only postprocess the object description, THEN convert to graph
        # TODO what if tracks are the identification? e.g. AEW identification in Hovmoller
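
    # Hedged usage sketch (SomeTracking is an illustrative subclass name, not
    # part of this module):
    #
    #     technique = SomeTracking()
    #     technique.execute(object_desc, dataset)    # fills ref_graph connections per set
    #     graph = technique.get_graph(object_desc)   # object-level graph (cached)
    #     technique.get_tracks()                     # group connections into disjoint waves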

    def get_graph(self, object_desc):
        if self.graph is not None:
            return self.graph  # use cached graph
        self.generate_graph(object_desc)
        return self.graph

    # create object connections from index connections for the output graph
    # not really efficient...
    def generate_graph(self, object_desc):
        self.graph = self.pb_reference.DatasetDescription()
        self.graph.CopyFrom(object_desc)
        for set_idx, objdesc_set in enumerate(object_desc.sets):
            graph_set = self.graph.sets[set_idx]
            # empty the graph set
            del graph_set.timesteps[:]
            graph_set.ref_graph.Clear()
            # build object connections from the reference graph
            for c in objdesc_set.ref_graph.connections:
                obj_con = self.pb_reference.ObjectConnection()
                start, end = c.n1, c.n2
                obj_con.n1.time = objdesc_set.timesteps[start.time_index].valid_time
                obj_with_startid = \
                    [objindex for objindex, obj in enumerate(objdesc_set.timesteps[start.time_index].objects) if
                     obj.id == start.object_id][0]
                obj_con.n1.object.CopyFrom(objdesc_set.timesteps[start.time_index].objects[obj_with_startid])
                obj_con.n2.time = objdesc_set.timesteps[end.time_index].valid_time
                obj_with_endid = \
                    [objindex for objindex, obj in enumerate(objdesc_set.timesteps[end.time_index].objects) if
                     obj.id == end.object_id][0]
                obj_con.n2.object.CopyFrom(objdesc_set.timesteps[end.time_index].objects[obj_with_endid])
                graph_set.object_graph.connections.append(obj_con)
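        # Illustration (hedged): each ref_graph connection stores index-based
        # endpoints (time_index, object_id); the loops above resolve both
        # endpoints to their valid_time and full object messages, so the
        # object_graph is self-contained.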
    # TODO add (n1, None) pairs for isolated nodes
    # TODO also (n1, None) pairs for end nodes --> so n1 in (n1, n2) covers ALL nodes.

    # after the graph has been computed, compute "tracks": a list of disjoint subgraphs of the total graph
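    # Illustration (hedged): if a set's time-sorted connections are
    # (A->B), (B->C), (D->E), then get_tracks() below assigns wave id 0 to
    # {A->B, B->C} and wave id 1 to {D->E}.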
    def get_tracks(self):
        if self.graph is None:
            print("Compute graph first.")
            exit(1)
        # for each set:
        #   order connections by time of their first node
        print("get tracks")
        for graph_set in self.graph.sets:
            # object_set = self.obj_ref.set_idx
            print("new set get tracks")
            # sort connections by time of their first node, as a list here
            # TODO assert comparison by string yields a time sorted list (yyyymmdd)
            time_sorted_nodes = list(sorted(graph_set.object_graph.connections, key=lambda item: item.n1.time))
            wave_id_per_node = [None] * len(time_sorted_nodes)
            cur_id = 0
            # iterate over all time sorted identified connections
            # search temporal downstream tracked troughs and group them using a set id
            for con_idx, con in enumerate(time_sorted_nodes):
                if wave_id_per_node[con_idx] is not None:  # already part of a wave
                    continue
                # not part of a wave -> get (temporal) downstream connections = wave (returns their indices)
                downstream_wave_node_indices = self.get_downstream_node_indices(time_sorted_nodes, con_idx)
                # any of the downstream nodes already part of a wave?
                connected_wave_id = None
                for ds_node_idx in downstream_wave_node_indices:
                    if wave_id_per_node[ds_node_idx] is not None:
                        if connected_wave_id is not None:
                            print("Node has downstream connections to multiple wave IDs...")
                        connected_wave_id = wave_id_per_node[ds_node_idx]
                # if so, set all downstream nodes to this found id
                if connected_wave_id is not None:
                    for ds_node_idx in downstream_wave_node_indices:
                        wave_id_per_node[ds_node_idx] = connected_wave_id
                    continue
                # else start a new wave id for all downstream nodes
                cur_id_needs_update = False
                for ds_node_idx in downstream_wave_node_indices:
                    wave_id_per_node[ds_node_idx] = cur_id
                    cur_id_needs_update = True
                if cur_id_needs_update:
                    cur_id += 1
            print(wave_id_per_node)
            # done, now extract every wave by id and put them into subgraphs
            waves = []  # list of waves according to the above policy
            for wave_id in range(cur_id):
                wave_idxs = [i for i in range(len(wave_id_per_node)) if wave_id_per_node[i] == wave_id]
                cur_wave_nodes = [time_sorted_nodes[i] for i in wave_idxs]  # troughs of this wave
                # cur_troughs = cur_troughs.sortbytime # already sorted?
                wave = cur_wave_nodes
                waves.append(wave)
                print(wave)
            print(waves)  # waves for this set only
        return None
"""

    @staticmethod
    def get_downstream_node_indices(graph_list, start_idx):
        node_indices = [start_idx]
        node, connected_nodes = graph_list[start_idx]
        for c_node in connected_nodes:
            # get index of the connected node in the graph
            c_node_idx = [n_ for n_, cv_ in graph_list].index(c_node)
            # recurse on this connected node
            c_node_downstream_indices = TrackingTechnique.get_downstream_node_indices(graph_list, c_node_idx)
            node_indices.extend(c_node_downstream_indices)
        return list(set(node_indices))
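    # Illustration (hedged): get_downstream_node_indices expects entries of the
    # form (node, connected_nodes). For
    #     graph_list = [(a, [b]), (b, [c]), (c, [])]
    # a call with start_idx=0 recurses a -> b -> c and returns [0, 1, 2]
    # (deduplicated via a set, so the order is not guaranteed).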

    def get_items(self):
        return self._graph.items()
    # extract wave objects from the graph; question here: what is a wave?
    # e.g. before/after merging, what belongs to the same wave
    # a wave object is a list of wavetroughs, i.e. wavetroughs for consecutive timesteps
    def extract_waves(self):
        """