# pipeline.py
from enstools.feature.identification import IdentificationTechnique
from enstools.feature.tracking import TrackingTechnique

from datetime import datetime

class FeaturePipeline:
    """
    This class encapsulates the feature detection pipeline.

    The pipeline consists of an identification step and an optional tracking
    step. Typical usage is: set the strategies, set the data, call
    ``execute()`` and then fetch or save the results.
    """

    def __init__(self, proto_ref=None):
        """
        Create an empty pipeline without strategies or data.

        Parameters
        ----------
        proto_ref : object, optional
                Reference that is handed to each strategy as its
                ``pb_reference`` attribute when the strategy is set.
                NOTE(review): the constructor header was missing from the
                reviewed source; this signature is reconstructed from the
                uses of ``self.pb_reference`` - confirm against callers.
        """
        self.id_tech = None
        self.tr_tech = None

        self.object_desc = None
        self.graph_desc = None

        self.dataset = None
        self.dataset_path = None

        self.pb_reference = proto_ref

    def set_identification_strategy(self, strategy: "IdentificationTechnique"):
        """
        Set the strategy to use for the identification.

        Parameters
        ----------
        strategy : enstools.feature.identification.IdentificationTechnique
        """
        self.id_tech = strategy
        self.id_tech.pb_reference = self.pb_reference

    def set_tracking_strategy(self, strategy: "TrackingTechnique"):
        """
        Set the strategy to use for the tracking.

        Parameters
        ----------
        strategy : enstools.feature.tracking.TrackingTechnique
        """
        self.tr_tech = strategy
        self.tr_tech.pb_reference = self.pb_reference

    def set_data_path(self, path):
        """
        Set the path to the dataset(s) to process.
        This function calls enstools.io.read and therefore can read directories using wildcards.

        Parameters
        ----------
        path : list of str or tuple of str
                names of individual files or filename pattern

        Raises
        ------
        Exception
                If ``path`` is None.
        """
        if path is None:
            raise Exception("None path provided.")

        # Imported lazily so the module can be imported without enstools.io.
        from enstools.io import read
        self.dataset = read(path)
        self.dataset_path = path

    def set_data(self, dataset: "xr.Dataset"):
        """
        Set the dataset to process.
        The function set_data_path() can be used instead.

        Parameters
        ----------
        dataset : xr.Dataset
                the xarray Dataset

        Raises
        ------
        Exception
                If ``dataset`` is None.
        """
        # The annotation is a string because xarray is not imported at module
        # level; evaluating it eagerly would fail at class definition time.
        if dataset is None:
            raise Exception("None dataset provided.")

        self.dataset = dataset
        # No file is associated with a dataset that was passed in directly.
        self.dataset_path = ""

    def execute(self):
        """
        Execute the feature detection based on the set data and set techniques.

        Runs the identification strategy, stamps the resulting description
        with the dataset path and the current time, and then runs the tracking
        strategy if one is set.

        Raises
        ------
        Exception
                If no identification strategy has been set.
        """
        # TODO need API to check if identification output type fits to tracking input type.
        if self.id_tech is None:
            raise Exception("No identification strategy set.")

        # TODO need to return here dataset? or inplace?
        self.object_desc, self.dataset = self.id_tech.execute(self.dataset)
        self.object_desc.file = str(self.dataset_path)
        self.object_desc.run_time = datetime.now().isoformat()

        # A graph cached from a previous run no longer matches the new result.
        self.graph_desc = None

        if self.tr_tech is not None:
            # alters maybe dataset and adds connections to object_desc
            self.tr_tech.execute(self.object_desc, self.dataset)

    def generate_graph(self):
        """
        Generate the graph description from the object description.

        The graph is built by the tracking strategy and cached in
        ``graph_desc``; it is only rebuilt when no cached graph exists.

        Raises
        ------
        SystemExit
                If no tracking strategy is set or the pipeline has not been
                executed yet. The message is printed before exiting.
        """
        # SystemExit is raised directly instead of calling the interactive
        # exit() helper, which is not guaranteed to exist in every
        # interpreter; a non-zero status marks the exit as an error.
        if self.tr_tech is None:
            print("Graph requires set and executed tracking strategy. Exit.")
            raise SystemExit(1)
        if self.object_desc is None:
            print("Need to execute pipeline first. Exit.")
            raise SystemExit(1)

        if self.graph_desc is None:
            # generate graph out of object desc
            self.graph_desc = self.tr_tech.generate_graph(self.object_desc)

    def get_object_desc(self):
        """Return the object description of the last execution (or None)."""
        return self.object_desc

    def get_graph_desc(self):
        """Return the graph description, generating it first if necessary."""
        self.generate_graph()
        return self.graph_desc

    def get_data(self):
        """Return the currently held dataset (or None)."""
        return self.dataset

    def generate_tracks(self):
        """Let the tracking strategy generate its tracks and then filter them."""
        self.tr_tech.generate_tracks()
        self.tr_tech.filter_tracks()

    def get_json_object(self):
        """
        Get the JSON type message of the currently saved result.

        Returns
        -------
        JSON object of identification/tracking result.
        """
        from google.protobuf.json_format import MessageToJson
        json_dataset = MessageToJson(self.object_desc)
        return json_dataset

    def save_result(self, description_path=None, description_type='json', dataset_path=None, graph_path=None):
        """
        Save the result of the detection process.

        Parameters
        ----------
        description_path : str
                Path to the file where the feature descriptions will be stored if not None.
        description_type : {'json', 'binary'}
                Type of the descriptions, either in JSON or in a binary format. Default is JSON.
        dataset_path : str
                Path to the file where the (altered) dataset should be stored if not None.
        graph_path : str
                Path to the file where the graph will be stored as JSON if not
                None. Only written when a tracking strategy is set.
        """
        if description_path is not None:

            if description_type == 'binary':
                print("writing binary to " + description_path)
                with open(description_path, "wb") as file:
                    file.write(self.object_desc.SerializeToString())

            elif description_type == 'json':
                from google.protobuf.json_format import MessageToJson
                print("writing json to " + description_path)
                with open(description_path, "w") as file:
                    json_dataset = MessageToJson(self.object_desc)
                    # TODO do smth like including_default_value_fields=True
                    #  -> to get empty objects list if nothing detected. but also would add default values for ALL optional fields
                    file.write(json_dataset)

            else:
                print("Unknown type format, supported are 'binary' and 'json'.")

        if graph_path is not None and self.tr_tech is not None:
            print("writing graph to " + graph_path)

            from google.protobuf.json_format import MessageToJson
            # Build and serialize the graph before opening the file, so that a
            # failure here does not leave an empty, truncated file behind.
            self.generate_graph()
            json_graph = MessageToJson(self.graph_desc)
            with open(graph_path, "w") as file:
                file.write(json_graph)

        if dataset_path is not None:
            # write netcdf dataset to path
            print("writing netcdf to " + dataset_path)

            from enstools.io import write
            write(self.dataset, dataset_path)