# pb2_properties_api.py
class ObjectProperties:
    """
    Adapter around the protobuf API for extension-based object properties.

    Usage::

        obj = ObjectProperties.get_instance(pb2_reference)
        obj.set('attr', value)
        value = obj.get('attr')

    where ``pb2_reference`` is the Python module generated by protobuf from the
    user's feature definitions. Working with protobuf extensions directly
    (``message.Extensions[...]``) is cumbersome, so ``set``/``get`` accept a
    plain, optionally dotted, attribute name instead.
    """

    def __init__(self, ref, instance):
        """
        Wrap an existing ``Properties`` message.

        Parameters
        ----------
        ref
            Generated protobuf module. ``set``/``get`` look up extension
            handles on it by attribute name.
        instance
            The ``Properties`` message instance to wrap.
        """
        object.__setattr__(self, 'pb2_ref', ref)
        object.__setattr__(self, 'pb2_instance', instance)

    @staticmethod
    def get_instance(ref_pb2):
        """
        Create a wrapper around a new, empty ``Properties`` message.

        Parameters
        ----------
        ref_pb2
            Generated protobuf module. It must expose
            ``identification__pb2.Properties`` (the message class that is
            instantiated) and one attribute per extension field.

        Returns
        -------
        ObjectProperties
            The object properties to be filled.
        """
        instance = ref_pb2.identification__pb2.Properties()
        return ObjectProperties(ref_pb2, instance)

    @staticmethod
    def add_to_objects(object_block_ref, obj_prop, id=None):
        """
        Append a new (id, properties) entry to the object block.

        If ``id`` is None, the smallest positive integer not already used as an
        ID in this object block is assigned.

        Parameters
        ----------
        object_block_ref
            Object block for this subset. It is iterated for entries with an
            ``id`` attribute, and ``add()`` is called on it to create the new
            entry.
        obj_prop : ObjectProperties
            The properties to copy into the new entry.
        id : int, optional
            ID to set for the entry, e.g. the ID taken from a labeled field.
        """
        # NOTE: ``id`` shadows the builtin but is part of the public keyword API.
        if id is None:
            # A set keeps each membership test O(1) instead of O(n).
            ids_in_use = {o.id for o in object_block_ref}
            # There are len(ids_in_use) + 1 candidates but at most
            # len(ids_in_use) taken IDs, so a free one always exists.
            id = next(
                cur_id
                for cur_id in range(1, len(ids_in_use) + 2)
                if cur_id not in ids_in_use
            )

        obj = object_block_ref.add()
        obj.id = id
        obj.properties.CopyFrom(obj_prop.as_pb2())

    def as_pb2(self):
        """Return the wrapped protobuf ``Properties`` message."""
        return object.__getattribute__(self, 'pb2_instance')

    @staticmethod
    def remove_from_objects(object_block_ref, obj_instance):
        """
        Remove the first entry whose properties equal those of ``obj_instance``.

        Parameters
        ----------
        object_block_ref
            Object block for this subset.
        obj_instance : ObjectProperties
            The object whose properties identify the entry to remove.

        Raises
        ------
        ValueError
            If no entry in the block has equal properties.
        """
        target = obj_instance.pb2_instance
        for o in object_block_ref:
            if o.properties == target:
                # Return right away: the block must not be iterated further
                # after it has been mutated.
                object_block_ref.remove(o)
                return
        raise ValueError("ObjectProperties to be removed not in given block.")

    def _parent_of(self, attr_splits):
        """Follow a dotted path (two or more parts) to the object owning its last part."""
        accessor = self.pb2_instance.Extensions[getattr(self.pb2_ref, attr_splits[0])]
        for name in attr_splits[1:-1]:
            accessor = getattr(accessor, name)
        return accessor

    # TODO: path segments are resolved with getattr, so indexed access into
    # repeated fields (e.g. 'foo[0].bar') is not supported yet.
    def set(self, attr, value):
        """
        Set an attribute of the object description.

        Parameters
        ----------
        attr : str
            Name of the attribute, e.g. 'foo', or a dotted path such as
            'foo.bar'. The first part names an extension and the remaining
            parts are plain message fields.
        value
            The value to assign.
        """
        attr_splits = attr.split('.')

        if len(attr_splits) == 1:
            self.pb2_instance.Extensions[getattr(self.pb2_ref, attr)] = value
        else:
            setattr(self._parent_of(attr_splits), attr_splits[-1], value)

    def get(self, attr):
        """
        Get an attribute of the object description.

        Parameters
        ----------
        attr : str
            Name of the attribute, e.g. 'foo', or a dotted path such as
            'foo.bar' (same format as for ``set``).

        Returns
        -------
        The value of the attribute.
        """
        attr_splits = attr.split('.')

        if len(attr_splits) == 1:
            return self.pb2_instance.Extensions[getattr(self.pb2_ref, attr)]
        return getattr(self._parent_of(attr_splits), attr_splits[-1])