Skip to content
Snippets Groups Projects
Commit cb9cf486 authored by Christoph.Fischer's avatar Christoph.Fischer
Browse files

clean up and added functions to graph template

parent e5266cef
No related branches found
No related tags found
No related merge requests found
......@@ -36,9 +36,24 @@ for trackable_set in od.sets:
tracks = g.generate_tracks(apply_filter=True)
# can perform operations on tracks and nodes...
# track = tracks[0]
# parents of a node: track.get_parents(track.graph.edges[0].parent)
# childs of a node: track.get_childs(track.graph.edges[0].parent)
track = tracks[0]
# get earliest and latest nodes of that track
earliest_nodes = track.get_earliest_nodes()
print("First nodes:")
print([en.time + " ID " + str(en.object.id) for en in earliest_nodes])
latest_nodes = track.get_latest_nodes()
print("Latest nodes:")
print([ln.time + " ID " + str(ln.object.id) for ln in latest_nodes])
if len(latest_nodes):
parents = track.get_parents(latest_nodes[0])
print("Parents of a latest node:")
print([p.time + " ID " + str(p.object.id) for p in parents])
childs = track.get_childs(latest_nodes[0])
print("Childs of latest nodes:")
print([c.time + " ID " + str(c.object.id) for c in childs])
for track_id, track in enumerate(tracks):
print("Track " + str(track_id) + " has " + str(len(track.graph.edges)) + " nodes.")
......
......@@ -51,7 +51,7 @@ class TrackingTemplate(TrackingTechnique):
return connections # TODO..
def postprocess(self, obj_desc):
print("postprocess @ objcomp template tracking")
print("Postprocess @ TrackingTemplate")
return
def keep_track(self, track):
......
......@@ -36,7 +36,7 @@ class TrackingCompareTemplate(ObjectComparisonTracking):
return False
def postprocess(self, obj_desc):
print("postprocess @ objcomp template tracking")
print("Postprocess @ TrackingCompareTemplate")
return
# can be overwritten to filter tracks after the generation process
......
......@@ -92,9 +92,9 @@ class TrackingTechnique(ABC):
cons = squeeze_nodes(list(empty_cons) + list(cons)) # TODO could be None.
# sort them by time (string key)
cons = sorted(cons, key=lambda c: c.parent.time)
print("Remove transitive edges...")
# remove transitive edges: a->b->c and a->c --> remove the a->c node
cons = TrackingTechnique.remove_transitive(cons) # TODO more efficient maybe?
cons = TrackingTechnique.remove_transitive(cons) # TODO more efficient maybe?
# add graph nodes to ref graph
obj_set.graph.edges.extend(cons)
......@@ -120,22 +120,11 @@ class TrackingTechnique(ABC):
pass
# filter the generated tracks: for each track call the keep_track() function.
def filter_track(self, object_desc): # TODO dont need anymore?
"""
Filter the generated tracks. For each track call the keep_track() function.
Returns
-------
"""
for set_ in object_desc.sets:
for t_id in range(len(set_.tracks) - 1, -1, -1):
track = set_.tracks[t_id]
if not self.keep_track(track):
del set_.tracks[t_id]
def keep_track(self, track):
"""
Override this method to filter out certain tracks.
When generate_tracks() on a Graph is called, each track is checked via this method whether it should be kept.
For example, you could check here if the track holds together for longer than a certain time range.
Parameters
----------
......@@ -143,11 +132,10 @@ class TrackingTechnique(ABC):
Returns
-------
True if keep, else discard
True if track should be kept, else discard
"""
return True
@staticmethod
def get_downstream_node_indices(graph_list, start_idx, until_time=None):
"""
......@@ -172,7 +160,8 @@ class TrackingTechnique(ABC):
c_node_idx = obj_node_list.index(c_node)
# call recursively on this connected node
c_node_downstream_indices = TrackingTechnique.get_downstream_node_indices(graph_list, c_node_idx, until_time=until_time)
c_node_downstream_indices = TrackingTechnique.get_downstream_node_indices(graph_list, c_node_idx,
until_time=until_time)
node_indices.extend(c_node_downstream_indices)
return list(set(node_indices))
......@@ -265,9 +254,10 @@ class TrackingTechnique(ABC):
continue
for cld in childs:
child_time = pb_str_to_datetime(cld.time)
con.childs.remove(cld) # remove edge temporarily
con.childs.remove(cld) # remove edge temporarily
# and check if downstream nodes still contains end node of edge. if yes -> transitive
downstream_idxs = TrackingTechnique.get_downstream_node_indices(connections, con_id, until_time=child_time)
downstream_idxs = TrackingTechnique.get_downstream_node_indices(connections, con_id,
until_time=child_time)
downstream_objs = [connections[ds_idx].parent for ds_idx in downstream_idxs]
if cld in downstream_objs:
# transitive
......@@ -276,7 +266,7 @@ class TrackingTechnique(ABC):
# add edge back
con.childs.append(cld)
print(str(tn) + " transitive edges")
print("Removed " + str(tn) + " transitive edges in current set.")
return connections
# TODO what if tracks are the identification? e.g. AEW identification in Hovmoller
from enstools.feature.util.data_utils import pb_str_to_datetime
class DataGraph:
"""
Abstraction for graphs:
......@@ -9,9 +10,62 @@ class DataGraph:
def __init__(self, set_desc, tracking_tech):
self.set_desc = set_desc
# sort the edges in the graph by time.
# makes it easier to traverse later.
self.graph = set_desc.graph
self.graph.edges.sort(key=lambda item: item.parent.time)
self.tr_tech = tracking_tech
def get_earliest_nodes(self):
"""
Get the nodes which are earliest in the graph.
Note that multiple earliest nodes might exist.
Returns
-------
List of nodes with earliest timestamp.
"""
if len(self.graph.edges) == 0:
return []
earliest_time = self.graph.edges[0].parent.time
nodes = [self.graph.edges[0].parent]
# iterate over first nodes until time changes or end of list
idx = 1
while idx < len(self.graph.edges) and self.graph.edges[idx].parent.time == earliest_time:
nodes.append(self.graph.edges[idx].parent)
idx += 1
return nodes
def get_latest_nodes(self):
"""
Get the nodes which are latest in the graph.
Note that multiple latest nodes might exist.
Returns
-------
List of nodes with latest timestamp.
"""
if len(self.graph.edges) == 0:
return []
latest_time = self.graph.edges[-1].parent.time
nodes = [self.graph.edges[-1].parent]
# iterate over first nodes until time changes or end of list
idx = len(self.graph.edges) - 2
while idx >= 0 and self.graph.edges[idx].parent.time == latest_time:
nodes.append(self.graph.edges[idx].parent)
idx -= 1
return nodes
def get_parents(self, node):
"""
Get the parents of a node (previous timestep)
......@@ -131,12 +185,22 @@ class DataGraph:
@staticmethod
def get_downstream_node_indices(graph_list, start_idx, until_time=None):
"""
Helper method for the track generation. Searches all downstream node indices.
Helper method for the track generation. Searches all downstream node indices. (time-wise)
Parameters
----------
graph_list:
List of GraphNodes, sorted by parent time
start_idx:
Index of GraphNode to start
until_time:
latest time to consider, None if don't set a threshold
Returns
-------
list of downstream indices in list
List of indices in graph_list containing the downstream nodes of the given one.
"""
node_indices = [start_idx]
obj_node_list = [con_.parent for con_ in graph_list] # this up
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment