from datetime import datetime

import xarray as xr

from enstools.feature.identification import IdentificationTechnique
from enstools.feature.tracking import TrackingTechnique


class FeaturePipeline:
    """
    This class encapsulates the feature detection pipeline. The pipeline
    consists of an identification and a tracking procedure.
    """

    def __init__(self, proto_ref):
        self.id_tech = None
        self.tr_tech = None
        self.pb_identification_result = None
        self.dataset = None
        self.dataset_path = None
        self.pb_reference = proto_ref

    def set_identification_strategy(self, strategy: IdentificationTechnique):
        """
        Set the strategy to use for the identification.

        Parameters
        ----------
        strategy : enstools.feature.identification.IdentificationTechnique
        """
        self.id_tech = strategy
        self.id_tech.pb_reference = self.pb_reference

    def set_tracking_strategy(self, strategy: TrackingTechnique):
        """
        Set the strategy to use for the tracking.

        Parameters
        ----------
        strategy : enstools.feature.tracking.TrackingTechnique
        """
        self.tr_tech = strategy
        self.tr_tech.pb_reference = self.pb_reference

    def set_data_path(self, path):
        """
        Set the path to the dataset(s) to process. This function calls
        enstools.io.read and can therefore read directories using wildcards.

        Parameters
        ----------
        path : list of str or tuple of str
            names of individual files or filename pattern
        """
        if path is None:
            raise ValueError("None path provided.")
        from enstools.io import read
        self.dataset = read(path)
        self.dataset_path = path

    def set_data(self, dataset: xr.Dataset):
        """
        Set the dataset to process. The function set_data_path() can be used instead.

        Parameters
        ----------
        dataset : xr.Dataset
            the xarray Dataset
        """
        if dataset is None:
            raise ValueError("None dataset provided.")
        self.dataset = dataset
        self.dataset_path = ""

    def execute(self):
        """
        Execute the feature detection based on the set data and set techniques.
        """
        # TODO need API to check if identification output type fits to tracking input type.
        self.pb_identification_result, self.dataset = self.id_tech.execute(self.dataset)
        self.pb_identification_result.file = str(self.dataset_path)
        self.pb_identification_result.run_time = datetime.now().isoformat()

        if self.tr_tech is not None:
            self.tr_tech.execute(self.pb_identification_result, self.dataset)

    def get_json_object(self):
        """
        Get the JSON type message of the currently saved result.

        Returns
        -------
        JSON object of identification/tracking result.
        """
        from google.protobuf.json_format import MessageToJson
        json_dataset = MessageToJson(self.pb_identification_result)
        return json_dataset

    def save_result(self, description_path=None, description_type='json', dataset_path=None, graph_path=None):
        """
        Save the result of the detection process.

        Parameters
        ----------
        description_path : str
            Path to the file where the feature descriptions will be stored if not None.
        description_type : {'json', 'binary'}
            Type of the descriptions, either in JSON or in a binary format. Default is JSON.
        dataset_path : str
            Path to the file where the (altered) dataset should be stored if not None.
        graph_path : str
            Path to the file where the tracking graph will be stored if not None.
        """
        if description_path is not None:
            if description_type == 'binary':
                print("writing binary to " + description_path)
                with open(description_path, "wb") as file:
                    file.write(self.pb_identification_result.SerializeToString())
            elif description_type == 'json':
                from google.protobuf.json_format import MessageToJson
                print("writing json to " + description_path)
                with open(description_path, "w") as file:
                    json_dataset = MessageToJson(self.pb_identification_result)
                    # TODO do something like including_default_value_fields=True
                    #  -> would yield an empty objects list if nothing is detected,
                    #     but would also add default values for ALL optional fields.
                    file.write(json_dataset)
            else:
                print("Unknown type format, supported are 'binary' and 'json'.")

        if graph_path is not None and self.tr_tech is not None:
            print("writing graph to " + graph_path)
            from google.protobuf.json_format import MessageToJson
            with open(graph_path, "w") as file:
                json_graph = MessageToJson(self.tr_tech.get_graph())
                file.write(json_graph)

        if dataset_path is not None:
            # write netcdf dataset to path
            print("writing netcdf to " + dataset_path)
            from enstools.io import write
            write(self.dataset, dataset_path)