Skip to content
Snippets Groups Projects
tracking.py 2.41 KiB
from ..object_compare_tracking import ObjectComparisonTracking
from datetime import datetime, timedelta
from enstools.feature.util.data_utils import pb_str_to_datetime


class TrackingCompareTemplate(ObjectComparisonTracking):
    """
    Template for simple object-comparison tracking: objects of consecutive timestamps are considered the same
    if correspond() returns True.
    This can be used if the tracking can simply be done by this pairwise comparison. More complex algorithms, which
    rely on a broader state or which need information from multiple time steps before, need a more general approach.

    The example thresholds used by correspond() and keep_track() are class attributes, so they can be adjusted by a
    subclass or on an instance without rewriting the methods.
    """

    # correspond(): two objects match if the x coordinates of their centroids differ by strictly less than this.
    centroid_x_tolerance = 1
    # keep_track(): a track is kept only if its first-to-last node time span is strictly longer than this.
    min_track_duration = timedelta(hours=4)

    def __init__(self, **kwargs):
        # called from example_template.py, to set parameters.
        # NOTE(review): like the original template, this does not call the parent's __init__; confirm that
        # ObjectComparisonTracking needs no initialization of its own.
        pass

    def correspond(self, time1, obj1, time2, obj2):
        """
        Decide whether obj1 (at time1) and obj2 (at time2) are the same object.

        Objects are compared based on their .proto definition. Each object has by default an ID (``obj.id``,
        unique for a given timestep). It can have a positional representation (volume, boundary), and it can have
        the descriptions set in the identification setup, represented by the .proto file (e.g. template_pb2).
        For example, to track objects from the identification template one can compare my_classification, size or
        their centroids.

        This example assumes use with IdentificationTemplate. It applies a simple illustrative heuristic: the
        objects match if ``abs(centroid.x difference) < centroid_x_tolerance``.

        :param time1: timestamp of obj1 (not used by this heuristic)
        :param obj1: object at timestep t
        :param time2: timestamp of obj2 (not used by this heuristic)
        :param obj2: object at timestep t+1
        :return: True if obj1 and obj2 are considered the same object, otherwise False.
        """
        # access the properties set during identification
        properties1 = obj1.properties
        properties2 = obj2.properties

        # some random heuristic as example
        return abs(properties1.centroid.x - properties2.centroid.x) < self.centroid_x_tolerance

    def postprocess(self, obj_desc):
        """
        Post-processing hook; this template only prints a marker message and does not modify obj_desc.
        """
        print("postprocess @ objcomp template tracking")
        return

    def keep_track(self, track):
        """
        Filter hook that can be overwritten to filter tracks after the generation process.

        This example keeps only tracks that last longer than ``min_track_duration``. The duration is the time
        between the first and the last node.

        :param track: generated track. Its ``nodes`` are read in order, and each node's ``this_node.time`` is
            parsed with pb_str_to_datetime.
        :return: True to keep the track, False to discard it. A track without nodes is discarded.
        """
        if not track.nodes:
            # No nodes means no time span at all; discard instead of failing with an IndexError.
            return False

        # get time of first and last node in track
        start_time = pb_str_to_datetime(track.nodes[0].this_node.time)
        end_time = pb_str_to_datetime(track.nodes[-1].this_node.time)

        return end_time - start_time > self.min_track_duration