# tracking.py
import numpy as np
from abc import ABC, abstractmethod
import xarray as xr
# from enstools.feature.identification._proto_gen import identification_pb2
from enstools.feature.util.data_utils import get_subset_by_description

class TrackingTechnique(ABC):
    """
    Base abstract class for feature tracking algorithms.
    Implementations need to override the abstract track() method.

    execute() additionally calls self.postprocess_set(tracking_set) once per set; that method is
    not defined in this class, so subclasses have to provide it as well.
    """

    def __init__(self):
        # Object providing the message classes used by execute() (DatasetDescription,
        # ObjectConnection). It is None here and has to be assigned before execute() is called.
        self.pb_reference = None
        # Graph description built by the last execute() call; None until execute() has run.
        self.graph = None

    @abstractmethod
    def track(self, track_set, subset: "xr.Dataset"):  # TODO update docstrings
        """
        Abstract tracking method. This gets called for each timesteps list of the feature descriptions.
        This timeline can be for example a reforecast or a member of an ensemble forecast, in which detected objects should
        be tracked over multiple timestamps. This method gets called in parallel for all timelines in the dataset.

        This method should compute links of corresponding objects of two consecutive timestamps. Each object in a
        timestep has its unique ID, and a link states that object id1 from timestamp t is the same object as id2
        from timestamp t+1. One link is an edge in a tracking graph.

        Parameters
        ----------
        track_set : iterable of identification_pb2.TrackingSet
                The data to be tracked
        subset : xarray.Dataset
                Subset to be tracked. Forwarded from identification

        Returns
        -------
        connections : list
                The connections forming the path of the objects in the timesteps. execute() reads the
                attributes n1 and n2 of every entry, and time_index and object_id of each of those.
        """
        connections = []
        return connections

    @staticmethod
    def _find_object(timestep, object_id):
        """
        Return the first object in timestep.objects whose id equals object_id.

        Raises
        ------
        IndexError
                If no object of the timestep has the requested ID.
        """
        for obj in timestep.objects:
            if obj.id == object_id:
                return obj
        raise IndexError(f"no object with id {object_id!r} in timestep with valid_time {timestep.valid_time!r}")

    def _to_object_connection(self, tracking_set, connection):
        """
        Build an ObjectConnection (time + copy of the object at both ends) from an index connection.

        Parameters
        ----------
        tracking_set :
                The set whose timesteps the indices of the connection refer to.
        connection :
                Entry returned by track(); n1/n2 are read, each with time_index and object_id.

        Returns
        -------
        obj_con :
                New self.pb_reference.ObjectConnection() with n1 and n2 filled in.
        """
        obj_con = self.pb_reference.ObjectConnection()
        # Same treatment for both ends of the edge: start node first, then end node.
        for target, node in ((obj_con.n1, connection.n1), (obj_con.n2, connection.n2)):
            timestep = tracking_set.timesteps[node.time_index]
            target.time = timestep.valid_time
            target.object.CopyFrom(self._find_object(timestep, node.object_id))
        return obj_con

    def execute(self, object_desc, dataset_ref: "xr.Dataset"):
        """
        Execute the tracking procedure.
        The description is split into the different timelines (sets); track() is called for each of them.
        The resulting graph description is stored in self.graph and can be fetched with get_graph().

        Parameters
        ----------
        object_desc : identification_pb2.DatasetDescription TODO?
                The description of the detected features from the identification technique.
                It is modified in place: the connections returned by track() are appended to
                ref_graph.connections of each set.
        dataset_ref : xarray.Dataset
                Reference to the dataset used in the pipeline.

        Raises
        ------
        IndexError
                If a connection refers to an object ID that does not exist in its timestep.
        """
        tracking_sets = object_desc.sets

        # The output graph starts as a copy of the input description.
        graph_ds = self.pb_reference.DatasetDescription()
        graph_ds.CopyFrom(object_desc)

        # TODO parallelize sets (intra-set is parallelized in comparer)
        for set_idx, tracking_set in enumerate(tracking_sets):
            # extract reference to where in graph_ds
            graph_set = graph_ds.sets[set_idx]
            # The copied set keeps no timesteps; objects are embedded into the connections below.
            del graph_set.timesteps[:]
            # NOTE(review): the previous code did `refg = graph_set.ref_graph; del refg`, which only
            # removed the local name and left graph_set.ref_graph untouched. That no-op was dropped
            # without changing behaviour. If ref_graph was meant to be cleared in the output
            # (presumably via ClearField("ref_graph")), that still has to be done - confirm intent.

            dataset_sel = get_subset_by_description(dataset_ref, tracking_set)
            connections = self.track(tracking_set, dataset_sel)  # TODO check subsets in here on more complex DS

            # create object connections from index connections for output graph
            tracking_set.ref_graph.connections.extend(connections)
            for c in connections:
                graph_set.object_graph.connections.append(self._to_object_connection(tracking_set, c))

            self.postprocess_set(tracking_set)  # TODO return smth?

        self.graph = graph_ds
        # TODO next link pairs together to "path"
        # track then has N-list of IDs

        # object_desc has now the links.
        # TODO more to do? is edge list of graph enough? maybe wrapper for graph
        # TODO what if tracks are the identification? e.g. AEW identification in Hovmoller

    def get_graph(self):
        """Return the graph description built by the last execute() call (None if it has not run yet)."""
        return self.graph