"""
Module for splitting images into sub-images
"""
import copy
import logging
import numpy as np
from mirar.data import Dataset, Image, ImageBatch
from mirar.paths import BASE_NAME_KEY, LATEST_SAVE_KEY, LATEST_WEIGHT_SAVE_KEY
from mirar.processors.base_processor import BaseImageProcessor
# Module-level logger, named after this module
logger = logging.getLogger(__name__)
# Header key holding the integer index of a sub-image within its parent image;
# SplitImage.update_dataset reads it to regroup sub-images into batches
SUB_ID_KEY = "SUBDETID"
# Header key holding the sub-image grid coordinate as a string of the form "x_y"
SUB_COORD_KEY = "SUBCOORD"
[docs]
class SplitImage(BaseImageProcessor):
    """
    Processor for splitting images

    Every input image is cut into a grid of ``n_x`` by ``n_y`` sub-images,
    where "x" refers to the first axis of the data array and "y" to the
    second. Each sub-image receives a copy of the parent header, updated
    with keywords describing its position in the grid.
    """

    base_key = "split"

    def __init__(self, buffer_pixels: int = 0, n_x: int = 1, n_y: int = 1):
        """
        :param buffer_pixels: number of extra pixels added on each side of
            a chunk (clipped to the image bounds), so neighbours overlap
        :param n_x: number of chunks along the first data axis
        :param n_y: number of chunks along the second data axis
        """
        super().__init__()
        self.buffer_pixels = buffer_pixels
        self.n_x = n_x
        self.n_y = n_y

    def __str__(self) -> str:
        return (
            f"Processor to split images into "
            f"{self.n_x}x{self.n_y}={self.n_x*self.n_y} smaller images."
        )

    def get_range(
        self,
        n_chunks: int,
        pixel_width: int,
        i: int,
    ) -> tuple[int, int]:
        """
        Function to return pixel index range for sub images

        The axis is divided into ``n_chunks`` chunks of equal nominal width
        (``pixel_width // n_chunks``), each widened by ``self.buffer_pixels``
        on both sides and clipped to ``[0, pixel_width]``. The final chunk
        always extends to the end of the axis, so that no trailing pixels
        are lost when ``pixel_width`` is not a multiple of ``n_chunks``.

        :param n_chunks: number of chunks to divide axis into
        :param pixel_width: total pixel width of axis
        :param i: index of chunk to evaluate
        :return: lower pixel index and upper pixel index of chunk
        """
        chunk_width = pixel_width // n_chunks

        lower = max(0, i * chunk_width - self.buffer_pixels)

        if i == n_chunks - 1:
            # Last chunk absorbs the remainder of the integer division
            upper = pixel_width
        else:
            upper = min(pixel_width, (i + 1) * chunk_width + self.buffer_pixels)

        return lower, upper

    def _apply_to_images(
        self,
        batch: ImageBatch,
    ) -> ImageBatch:
        """
        Split every image of the batch into ``n_x * n_y`` sub-images

        :param batch: images to split
        :return: new batch containing all sub-images of all input images
        """
        new_images = ImageBatch()

        logger.debug(f"Splitting each data into {self.n_x*self.n_y} sub-images")

        # Keywords copied from the parent header that are removed from each
        # sub-image header
        stale_keys = [
            "DETSIZE",
            "INFOSEC",
            "TRIMSEC",
            "DATASEC",
            LATEST_SAVE_KEY,
            LATEST_WEIGHT_SAVE_KEY,
        ]

        for image in batch:
            # Fetch per-image quantities once rather than once per sub-image
            data = image.get_data()
            header = image.get_header()
            base_name = image[BASE_NAME_KEY]

            pix_width_x, pix_width_y = data.shape

            for index_x in range(self.n_x):
                x_0, x_1 = self.get_range(self.n_x, pix_width_x, index_x)

                for index_y in range(self.n_y):
                    y_0, y_1 = self.get_range(self.n_y, pix_width_y, index_y)

                    # np.array makes an independent copy of the slice
                    new_data = np.array(data[x_0:x_1, y_0:y_1])

                    new_header = copy.copy(header)

                    for key in stale_keys:
                        if key in new_header:
                            del new_header[key]

                    sub_img_id = f"{index_x}_{index_y}"

                    new_header[SUB_COORD_KEY] = (
                        sub_img_id,
                        "Sub-data coordinate, in form x_y",
                    )
                    new_header["SUBNX"] = (index_x + 1, "Sub-data x index")
                    new_header["SUBNY"] = (index_y + 1, "Sub-data y index")
                    new_header["SUBNXTOT"] = (self.n_x, "Total number of sub-data in x")
                    new_header["SUBNYTOT"] = (self.n_y, "Total number of sub-data in y")

                    # Running index of the sub-image, counted with y varying
                    # fastest; update_dataset uses it to regroup the images
                    new_header[SUB_ID_KEY] = index_x * self.n_y + index_y

                    new_header["SRCIMAGE"] = (
                        base_name,
                        "Source data name, from which sub-data was made",
                    )

                    # FITS axis order is the reverse of the numpy shape:
                    # NAXIS1 is the last (fastest-varying) array axis
                    n_rows, n_cols = new_data.shape
                    new_header["NAXIS1"] = n_cols
                    new_header["NAXIS2"] = n_rows

                    new_header[BASE_NAME_KEY] = base_name.replace(
                        ".fits", f"_{sub_img_id}.fits"
                    )

                    new_images.append(Image(data=new_data, header=new_header))

        return new_images

    def update_dataset(self, dataset: Dataset) -> Dataset:
        """
        Regroup a dataset so that each batch holds one sub-image position

        Every incoming batch is replaced by ``n_x * n_y`` batches, one per
        value of the ``SUB_ID_KEY`` header keyword. A position for which a
        batch contains no image yields an empty batch.

        :param dataset: dataset whose images carry the ``SUB_ID_KEY`` keyword
        :return: new dataset with the regrouped batches
        """
        n_sub_images = self.n_x * self.n_y

        all_new_batches = []

        for batch in dataset:
            grouped = [[] for _ in range(n_sub_images)]

            for image in batch:
                grouped[image[SUB_ID_KEY]].append(image)

            all_new_batches.extend(ImageBatch(group) for group in grouped)

        return Dataset(all_new_batches)