Skip to content
Snippets Groups Projects
Commit 1318e02e authored by Oriol Tintó's avatar Oriol Tintó
Browse files

Adding file with code to slice and chunk arrays.

parent 04813bed
No related branches found
No related tags found
1 merge request!10Prepare release 2023.4
"""
This file contains classes and methods to divide a multi dimensional slice in different parts, via selecting a
chunk size or the number of desired parts.
"""
import logging
from typing import List, Tuple
from itertools import product
import numpy as np
def split_array(array: np.ndarray, parts: int):
"""
Splits the given array into a specified number of parts.
The function returns a list of chunks, where each chunk is a numpy array.
:param array: A numpy array to be split
:param parts: An int, the number of parts to split the array into
:return: A list of numpy arrays representing the chunks
"""
if parts == -1:
parts = array.size
shape = array.shape
possible_chunk_sizes = []
# Generate all possible chunk sizes for the given array shape
for chunk_size in product(*[range(1, shape[i] + 1) for i in range(len(shape))]):
# Check if the number of chunks generated by the current chunk size is equal to the desired number of parts
if np.prod(
[shape[i] // chunk_size[i] + int(shape[i] % chunk_size[i] != 0) for i in range(len(shape))]) == parts:
possible_chunk_sizes.append(chunk_size)
# Sort the possible chunk sizes in ascending order of the sum of the squares of their dimensions
possible_chunk_sizes.sort(key=lambda x: np.sum(np.array(x) ** 2)) # type: ignore
if not possible_chunk_sizes:
logging.warning(f"Could not divide the domain in {parts} parts. Trying with parts={parts - 1}.")
return split_array(array=array, parts=parts - 1)
chunk_size = possible_chunk_sizes[0]
chunks = []
# Get the number of chunks for the first possible chunk size
n_chunks = [shape[i] // chunk_size[i] + int(shape[i] % chunk_size[i] != 0) for i in range(len(shape))]
indexes = [range(n_chunks[i]) for i in range(len(shape))]
# Iterate over the chunks and append the corresponding slice of the array to the chunks list
for indx in product(*indexes):
sl = tuple(
slice(chunk_size[i] * indx[i], min(chunk_size[i] * (indx[i] + 1), shape[i])) for i in range(len(shape)))
chunks.append(array[sl])
return chunks
class MultiDimensionalSlice:
def __init__(self, indices: Tuple[int, ...], slices: Tuple[slice, ...] = (slice(0, 0, 0),)):
"""
Initialize the MultiDimensionalSlice object with indices and slice
:param indices: Tuple of indices
:param slices: Tuple of slice objects representing the slice indices
"""
self.indices = indices
self.slices = slices
def __repr__(self) -> str:
"""
Return the string representation of the MultiDimensionalSlice object
"""
return f"{self.__class__.__name__}({self.indices})"
@property
def size(self):
"""
Return the size of the slice
"""
size = 1
for s in self.slices:
size *= s.stop - s.start
return size
class MultiDimensionalSliceCollection:
def __init__(self, *, objects_array: np.ndarray = None, shape: Tuple[int, ...] = None,
chunk_sizes: Tuple[int, ...] = None):
"""
Initializes the MultiDimensionalSliceCollection class.
Args:
objects_array (np.ndarray): an array of MultiDimensionalSlice objects
shape (Tuple[int,...]): the shape of the array to be divided into chunks
chunk_sizes (Tuple[int,...]): the size of the chunks in each dimension
Raises:
AssertionError: if objects_array is not provided and shape and chunk_sizes are not provided
"""
if objects_array is not None:
assert shape is None and chunk_sizes is None
self.__initialize_from_array(objects_array=objects_array)
elif shape:
assert chunk_sizes is not None
self.__initialize_from_shape_and_chunk_size(shape, chunk_sizes)
else:
raise AssertionError("objects_array or shape and chunksize should be provided")
def __initialize_from_array(self, objects_array: np.ndarray):
"""
Create a MultiDimensionalSliceCollection given an array of MultiDimensionalSlice
Parameters
----------
objects_array: numpy array of MultiDimensionalSlice
Returns
-------
MultiDimensionalSliceCollection
"""
self.objects = objects_array
self.collection_shape = self.objects.shape
def __initialize_from_shape_and_chunk_size(self, shape: Tuple[int, ...], chunk_size: Tuple[int, ...]):
"""
Create a MultiDimensionalSliceCollection given a shape and a chunk size
Parameters
----------
shape
chunk_size
Returns
-------
MultiDimensionalSliceCollection
"""
cs = chunk_size
collection_shape = tuple(
[shape[i] // cs[i] + int(shape[i] % cs[i] != 0) for i in range(len(shape))])
# Calculates the shape of the collection of chunks, by dividing the shape of the array by the size of the chunks
# and adding 1 if there is a remainder
objects = np.empty(collection_shape, dtype=MultiDimensionalSlice)
# Initializes an empty array of the same shape as the collection with the dtype of MyObject
n_chunks = [shape[i] // cs[i] + int(shape[i] % cs[i] != 0) for i in range(len(shape))]
# Calculates the number of chunks in each dimension
indexes = [range(n_chunks[i]) for i in range(len(shape))]
# Create the indexes to iterate over the collection
for indx in product(*indexes):
sl = tuple(slice(cs[i] * indx[i], min(cs[i] * (indx[i] + 1), shape[i])) for i in range(len(shape)))
# for each index of the collection create a slice that corresponds
# to the portion of the array contained in the chunk
objects[indx] = MultiDimensionalSlice(indices=indx, slices=sl)
# assigns the created object to the corresponding position in the collection array
self.__initialize_from_array(objects_array=objects)
def __getitem__(self, args):
return self.objects[args] # returns the object at the specified position of the collection array
def __repr__(self) -> str:
return f"{self.__class__.__name__}({self.collection_shape})" # returns a string representation of the class
@property
def slice(self) -> Tuple[slice, ...]:
total_slice = tuple(slice(None) for _ in self.collection_shape)
for obj in self.objects.flat:
for i, s in enumerate(obj.slices):
if total_slice[i].start is None:
total_slice = total_slice[:i] + (s,) + total_slice[i + 1:]
else:
if s.start < total_slice[i].start:
total_slice = total_slice[:i] + (
slice(s.start, total_slice[i].stop, total_slice[i].step),) + total_slice[i + 1:]
if s.stop > total_slice[i].stop:
total_slice = total_slice[:i] + (
slice(total_slice[i].start, s.stop, total_slice[i].step),) + total_slice[i + 1:]
return total_slice
def split(self, parts: int) -> List['MultiDimensionalSliceCollection']:
"""
Divide the MultiDimensionalSliceCollection into a different parts of similar size
Args:
parts (int): Number of parts to divide the group
Returns:
_type_: _description_
"""
array_parts = split_array(self.objects, parts)
return [MultiDimensionalSliceCollection(objects_array=p) for p in array_parts]
@property
def size(self) -> int:
return sum([ob.size for ob in self.objects.ravel()])
def __len__(self) -> int:
return self.objects.size
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment