from enstools.feature.identification import IdentificationTechnique
from enstools.feature.tracking import TrackingTechnique

from datetime import datetime

class FeaturePipeline:
    """
    This class encapsulates the feature detection pipeline. The pipeline consists of an identification procedure followed by an optional tracking procedure.
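
    Examples
    --------
    A minimal usage sketch; ``MyIdentification`` and ``MyTracking`` are
    hypothetical subclasses of IdentificationTechnique and TrackingTechnique,
    and all file paths are purely illustrative::

        pipeline = FeaturePipeline()
        pipeline.set_identification_strategy(MyIdentification())
        pipeline.set_tracking_strategy(MyTracking())
        pipeline.set_data_path("data/ens_*.nc")
        pipeline.execute()
        pipeline.save_result(description_path="result.json", dataset_path="result.nc")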
    """

    def __init__(self):
        self.id_tech = None
        self.tr_tech = None
        self.pb_identification_result = None
        self.dataset = None
        self.dataset_path = None

    def set_identification_strategy(self, strategy: IdentificationTechnique):
        """
        Set the strategy to use for the identification.

        Parameters
        ----------
        strategy : enstools.feature.identification.IdentificationTechnique
        """
        self.id_tech = strategy

    def set_tracking_strategy(self, strategy: TrackingTechnique):
        """
        Set the strategy to use for the tracking.

        Parameters
        ----------
        strategy : enstools.feature.tracking.TrackingTechnique
        """
        self.tr_tech = strategy

    def set_data_path(self, path):
        """
        Set the path to the dataset(s) to process.
        This function calls enstools.io.read and therefore supports reading multiple files via wildcard patterns.

        Parameters
        ----------
        path : list of str or tuple of str
                names of individual files or filename pattern
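
        Examples
        --------
        Both explicit file names and wildcard patterns are accepted; the
        paths below are purely illustrative::

            pipeline.set_data_path("forecast_2020.nc")
            pipeline.set_data_path("data/ens_*.nc")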
        """
        if path is None:
            raise ValueError("No path provided.")

        from enstools.io import read
        self.dataset = read(path)
        self.dataset_path = path

    def execute(self):
        """
        Execute the feature detection using the configured data and techniques: identification runs first, followed by tracking if a tracking strategy has been set.
        """
        # TODO need an API to check whether the identification output type matches the tracking input type.

        if self.id_tech is None:
            raise ValueError("No identification strategy set.")

        self.pb_identification_result, self.dataset = self.id_tech.execute(self.dataset)

        self.pb_identification_result.file = str(self.dataset_path)
        self.pb_identification_result.run_time = datetime.now().isoformat()

        if self.tr_tech is not None:
            self.tr_tech.execute(self.pb_identification_result, self.dataset)

    def get_json_object(self):
        """
        Get the JSON representation of the currently stored identification/tracking result.

        Returns
        -------
        str
                The result message serialized as a JSON string.
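
        Examples
        --------
        Sketch only, assuming ``execute`` has already been run; the returned
        value is a plain JSON string::

            import json
            result = json.loads(pipeline.get_json_object())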
        """
        from google.protobuf.json_format import MessageToJson
        json_dataset = MessageToJson(self.pb_identification_result)
        return json_dataset

    def save_result(self, description_path=None, description_type='json', dataset_path=None):
        """
        Save the result of the detection process.

        Parameters
        ----------
        description_path : str
                If not None, path to the file where the feature descriptions will be stored.
        description_type : {'json', 'binary'}
                Format of the descriptions, either JSON or the binary protobuf format. Default is 'json'.
        dataset_path : str
                If not None, path to the file where the (possibly altered) dataset will be stored.
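
        Examples
        --------
        Either or both targets can be written; output paths are purely
        illustrative::

            pipeline.save_result(description_path="result.json")
            pipeline.save_result(description_path="result.pb", description_type='binary',
                                 dataset_path="result.nc")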
        """
        if description_path is not None:

            if description_type == 'binary':
                print("writing binary to " + description_path)
                with open(description_path, "wb") as file:
                    file.write(self.pb_identification_result.SerializeToString())

            elif description_type == 'json':
                from google.protobuf.json_format import MessageToJson
                print("writing json to " + description_path)
                with open(description_path, "w") as file:
                    json_dataset = MessageToJson(self.pb_identification_result)
                    # TODO consider including_default_value_fields=True to get an empty
                    #  object list if nothing was detected; but that would also add
                    #  default values for ALL optional fields.
                    file.write(json_dataset)

            else:
                raise ValueError("Unknown description type, supported are 'binary' and 'json'.")

        if dataset_path is not None:
            # write netcdf dataset to path
            print("writing netcdf to " + dataset_path)

            from enstools.io import write
            write(self.dataset, dataset_path)