# (web page residue: "Skip to content" / "Snippets Groups Projects")
# identification.py (3.21 KiB)
from ..identification import IdentificationTechnique
from .processing import mask_to_proto

import operator
import xarray as xr


class MultiThresholdIdentification(IdentificationTechnique):
    """Template identification technique configured by a field, thresholds and a comparison.

    The constructor stores the name of the dataset variable to work on, the
    thresholds, and the comparison function selected by ``comparison_operator``.
    ``identify`` currently contains placeholder code only: it builds five dummy
    objects and does not yet use ``thresholds`` or ``compare``.

    NOTE(review): the base class ``__init__`` is not called here, matching the
    original behaviour; confirm whether ``IdentificationTechnique`` needs it.
    """

    # Every accepted spelling of a comparison operator mapped to its function.
    _COMPARISON_FUNCTIONS = {
        ">=": operator.ge,
        "ge": operator.ge,
        ">": operator.gt,
        "gt": operator.gt,
        "<=": operator.le,
        "le": operator.le,
        "<": operator.lt,
        "lt": operator.lt,
    }

    # Supported values for ``processing_mode``.
    _PROCESSING_MODES = frozenset({"2d", "3d"})

    def __init__(self, field, thresholds, comparison_operator, processing_mode="2d", compress=True, **kwargs):
        """Store the configuration and resolve the comparison function.

        Args:
            field: Name of the dataset variable the identification works on.
            thresholds: Threshold values; stored as given, without validation.
            comparison_operator: One of ``">="``/``"ge"``, ``">"``/``"gt"``,
                ``"<="``/``"le"``, ``"<"``/``"lt"``.
            processing_mode: ``"2d"`` to identify on 2d levels, ``"3d"`` to
                identify per 3d block.
            compress: Whether to compress the object mask array in the
                protobuf objects.
            **kwargs: Accepted for interface compatibility; ignored here.

        Raises:
            ValueError: If ``comparison_operator`` or ``processing_mode`` is
                not one of the supported values.
        """
        self.field = field
        self.thresholds = thresholds  # TODO checks?
        # Save comparison operator name and get associated comparison function
        self.comparison_operator = comparison_operator
        compare = self._COMPARISON_FUNCTIONS.get(comparison_operator)
        if compare is None:
            raise ValueError(f"unknown comparison operator '{comparison_operator}'")
        self.compare = compare
        # Compression of object mask array in protobuf objects
        self.compress = compress
        # Specify processing mode 2d or 3d: In 2d, identification will be performed on 2d levels,
        #   in 3d the identification will be performed per 3d block.
        # Raise explicitly instead of asserting so the check survives `python -O`.
        if processing_mode not in self._PROCESSING_MODES:
            raise ValueError(
                f"unknown processing mode '{processing_mode}', expected '2d' or '3d'"
            )
        self.processing_mode = processing_mode

    def precompute(self, dataset: xr.Dataset, **kwargs):
        """Check that the configured field exists and return the dataset unchanged.

        Args:
            dataset: Dataset that has to contain the variable ``self.field``.
            **kwargs: Ignored.

        Returns:
            The same ``dataset`` object that was passed in.

        Raises:
            KeyError: If ``self.field`` is not contained in ``dataset``.
        """
        # Explicit raise instead of assert so the check survives `python -O`.
        if self.field not in dataset:
            raise KeyError(f"field '{self.field}' not found in dataset")
        return dataset

    def identify(self, dataset: xr.Dataset, **kwargs):
        """Identify objects in one spatial subset of the data (placeholder implementation).

        Args:
            dataset: Subset to process. Per the framework notes below it holds
                only spatial dimensions: (lat, lon) in '2d' mode or
                (lat, lon, level) in '3d' mode.
            **kwargs: Ignored.

        Returns:
            A tuple ``(dataset, obj_list)`` of the unchanged dataset and the
            list of created objects.
        """
        # Look up the variable by its configured name (not a hard-coded key),
        # consistent with the check in precompute(). Not used by the placeholder
        # code below; it is the input for the real identification algorithm.
        field = dataset[self.field]
        # Identification, here is the place for your identification algorithm.
        # This is called in parallel for each spatial subset of the data, e.g., for each member and for each timestep
        # Therefore, dataset contains only spatial dimensions (lat,lon) if self.processing_mode = '2d' is set,
        #   (see __init__), or (lat,lon,level) if '3d' is set.

        # here your code ... identify objects for this subset.

        # Let's say you detected 5 objects:
        obj_list = []

        for i in range(5):
            # get an instance of a new object, can pass an ID or set it manually afterwards
            obj = self.get_new_object()
            # set some ID to it
            obj.id = i + 1

            # get properties of object and populate them (like defined in template.proto)
            properties = obj.properties

            properties.size = 42
            properties.centroid.x = i
            properties.centroid.y = 2*i
            properties.centroid.z = 3*i

            properties.list_of_something.append("type x")
            properties.list_of_something.append("type z")

            obj_list.append(obj)

        # return the dataset (can be changed here), and the list of objects
        return dataset, obj_list

    def postprocess(self, dataset: xr.Dataset, pb2_desc, **kwargs):
        """Post-process hook, called once after the whole dataset was processed.

        Currently a pass-through: both arguments are returned unchanged.

        Args:
            dataset: The processed dataset.
            pb2_desc: Finished protobuf description of the objects (cf. the
                JSON output). ``pb2_desc.sets`` is the list of trackable sets;
                each one contains the timesteps list, where each entry is a
                valid time and contains a list of objects.
            **kwargs: Ignored.

        Returns:
            A tuple ``(dataset, pb2_desc)``.
        """
        # Here post process: called once after processing the whole dataset.
        # If needed, the dataset or objects can be changed here.
        return dataset, pb2_desc