"""
Module containing the :class:`~mirar.processors.base_processor.BaseProcessor`
"""
import datetime
import getpass
import hashlib
import logging
import socket
import threading
from abc import ABC
from pathlib import Path
from queue import Queue
from threading import Thread
import numpy as np
import pandas as pd
from tqdm.auto import tqdm
from mirar.data import DataBatch, Dataset, Image, ImageBatch, SourceBatch
from mirar.errors import (
ErrorReport,
ErrorStack,
NoncriticalProcessingError,
ProcessorError,
)
from mirar.io import open_fits, save_fits
from mirar.paths import (
BASE_NAME_KEY,
CAL_OUTPUT_SUB_DIR,
LATEST_WEIGHT_SAVE_KEY,
PACKAGE_NAME,
PROC_HISTORY_KEY,
RAW_IMG_KEY,
get_mask_path,
get_output_path,
max_n_cpu,
)
# Module-level logger, named after this module so its output can be filtered
logger = logging.getLogger(name=__name__)
class PrerequisiteError(ProcessorError):
    """
    Error raised when a processor needs some other processor as a
    prerequisite, and that other processor is not present
    """
class NoCandidatesError(ProcessorError):
    """
    An error raised if a source/candidate generator (e.g. a subclass of
    :class:`~mirar.processors.base_processor.BaseSourceGenerator`) produces
    no candidates
    """
class BaseProcessor:
    """
    Base processor class, to be inherited from for all processors
    """

    @property
    def base_key(self):
        """
        Unique key for the processor, to be used e.g. in processing history
        tracking. Must be overridden by subclasses.

        :return: Processor key
        :raises NotImplementedError: if a subclass does not override it
        """
        raise NotImplementedError

    # Upper bound on the number of worker threads used by base_apply
    max_n_cpu: int = max_n_cpu

    # Registry of processor classes, filled in by __init_subclass__
    subclasses = {}

    def __init__(self):
        self.night = None
        self.night_sub_dir = None
        self.preceding_steps = None

        # For caching/multithreading: each dict is keyed by the ident of the
        # thread which called base_apply (see base_apply)
        self.passed_batches = {}
        self.err_stack = {}
        self.progress = {}

    @classmethod
    def __init_subclass__(cls, **kwargs):
        """
        Register every subclass in ``subclasses``, keyed by ``cls.base_key``
        """
        super().__init_subclass__(**kwargs)
        # NOTE(review): for a subclass that does not override ``base_key``
        # with a class attribute, ``cls.base_key`` is the inherited property
        # object, so all such classes share a single registry key.
        cls.subclasses[cls.base_key] = cls

    def set_preceding_steps(self, previous_steps: list):
        """
        Provides processor with the list of preceding processors, and saves this

        :param previous_steps: list of processors
        :return: None
        """
        self.preceding_steps = previous_steps

    def set_night(self, night_sub_dir: str | int = ""):
        """
        Sets the night subdirectory for the processor to read/write data.
        The night itself is taken as the final '/'-separated component.

        :param night_sub_dir: String/int subdirectory for night (ints are
            converted to strings)
        :return: None
        """
        # Accept ints as advertised by the type hint: str has .split, int does not
        night_sub_dir = str(night_sub_dir)
        self.night_sub_dir = night_sub_dir
        self.night = night_sub_dir.split("/")[-1]

    def generate_error_report(
        self, exception: Exception, batch: DataBatch
    ) -> ErrorReport:
        """
        Generates an error report based on a python Exception

        :param exception: exception raised
        :param batch: batch which generated exception
        :return: error report
        """
        return ErrorReport(exception, self.__module__, batch.get_raw_image_names())

    def update_dataset(self, dataset: Dataset) -> Dataset:
        """
        Update a dataset after processing. The base implementation returns
        the dataset unchanged.

        :param dataset: Initial dataset
        :return: Updated dataset
        """
        return dataset

    def check_prerequisites(
        self,
    ):
        """
        Check to see if any prerequisite processors are missing.
        The base implementation performs no checks.

        :return: None
        """

    def clean_cache(self, cache_id: int):
        """
        Function to clean the internal cache filled by base_apply

        :param cache_id: key for cache
        :return: None
        """
        del self.passed_batches[cache_id]
        del self.err_stack[cache_id]
        # A progress bar is only registered for non-empty datasets
        self.progress.pop(cache_id, None)

    def base_apply(self, dataset: Dataset) -> tuple[Dataset, ErrorStack]:
        """
        Core function to act on a dataset, and return an updated dataset.

        Batches are processed by a pool of worker threads (at most
        ``max_n_cpu``). Results are reassembled in the original batch order
        and passed through :meth:`update_dataset`. Batches which raise an
        exception other than ``NoncriticalProcessingError`` are dropped from
        the output; every caught error is recorded in the returned stack.

        :param dataset: Input dataset
        :return: Updated dataset, and any caught errors
        """
        # Keyed by the calling thread, so that separate threads can share
        # one processor instance without mixing their results
        cache_id = threading.get_ident()

        self.passed_batches[cache_id] = {}
        self.err_stack[cache_id] = ErrorStack()

        if len(dataset) > 0:
            # Always use at least one worker, otherwise queue.join() never returns
            n_cpu = max(1, min([self.max_n_cpu, len(dataset)]))
            logger.info(f"Running {self.__class__.__name__} on {n_cpu} threads")

            watchdog_queue = Queue()

            workers = []

            for _ in range(n_cpu):
                # Set up a worker thread to process database load
                worker = Thread(
                    target=self.apply_to_batch, args=(watchdog_queue, cache_id)
                )
                worker.daemon = True
                worker.start()
                workers.append(worker)

            # The context manager closes the progress bar on exit
            with tqdm(total=len(dataset), position=0, leave=False) as progress:
                self.progress[cache_id] = progress

                # Loop over batches to add to queue
                for j, batch in enumerate(dataset):
                    watchdog_queue.put(item=(j, batch))

                # Wait for the queue to empty
                watchdog_queue.join()
                progress.refresh()

            # Shut the workers down (one sentinel each) so that the threads
            # exit, rather than blocking forever on an abandoned queue
            for _ in workers:
                watchdog_queue.put(item=None)
            for worker in workers:
                worker.join()

        results = self.passed_batches[cache_id]
        new_dataset = [results[key] for key in sorted(results)]

        dataset = self.update_dataset(Dataset(new_dataset))
        err_stack = self.err_stack[cache_id]
        self.clean_cache(cache_id=cache_id)
        return dataset, err_stack

    def apply_to_batch(self, queue, cache_id: int):
        """
        Function to run self.apply on a batch in the queue, catch any errors, and then
        update the internal cache with the results.

        Queue items are ``(index, batch)`` tuples; a ``None`` item is a
        shutdown sentinel which makes the worker return.

        :param queue: python threading queue
        :param cache_id: key for cache
        :return: None
        """
        while True:
            item = queue.get()

            if item is None:
                queue.task_done()
                break

            j, batch = item
            try:
                batch = self.apply(batch)
                self.passed_batches[cache_id][j] = batch
            except NoncriticalProcessingError as exc:
                # Noncritical: record the error but keep the (input) batch
                err = self.generate_error_report(exc, batch)
                logger.error(err.generate_log_message())
                self.err_stack[cache_id].add_report(err)
                self.passed_batches[cache_id][j] = batch
            except Exception as exc:  # pylint: disable=broad-except
                # Critical: record the error and drop the batch
                err = self.generate_error_report(exc, batch)
                logger.error(err.generate_log_message())
                self.err_stack[cache_id].add_report(err)
            finally:
                # Always mark the item done, so queue.join() in base_apply
                # cannot hang even if error reporting itself fails
                self.progress[cache_id].update(1)
                self.progress[cache_id].refresh()
                queue.task_done()

    def apply(self, batch: DataBatch):
        """
        Function applying the processor to a
        :class:`~mirar.data.base_data.DataBatch`.
        Also updates the processing history.

        :param batch: input data batch
        :return: updated data batch
        """
        batch = self._apply(batch)
        batch = self._update_processing_history(batch)
        return batch

    def _apply(self, batch: DataBatch) -> DataBatch:
        """
        Core function to update the :class:`~mirar.data.base_data.DataBatch`

        :param batch: Input data batch
        :return: updated data batch
        :raises NotImplementedError: must be overridden by subclasses
        """
        raise NotImplementedError

    def _update_processing_history(
        self,
        batch: DataBatch,
    ) -> DataBatch:
        """
        Function to update the processing history of each
        :class:`~mirar.data.base_data.DataBlock` object in a
        :class:`~mirar.data.base_data.DataBatch`.

        Appends ``base_key`` plus a comma to the history entry, and records
        the user, host, local time and software name of the reduction.

        :param batch: Input data batch
        :return: Updated data batch
        """
        for i, data_block in enumerate(batch):
            data_block[PROC_HISTORY_KEY] += self.base_key + ","
            data_block["REDUCER"] = getpass.getuser()
            data_block["REDMACH"] = socket.gethostname()
            data_block["REDTIME"] = str(datetime.datetime.now())
            data_block["REDSOFT"] = PACKAGE_NAME
            batch[i] = data_block
        return batch
class CleanupProcessor(BaseProcessor, ABC):
    """
    Processor which 'cleans up' a dataset by discarding its empty batches
    """

    def update_dataset(self, dataset: Dataset) -> Dataset:
        """
        Drop every empty batch from the processed dataset

        :param dataset: Dataset after processing
        :return: New dataset holding only the non-empty batches
        """
        kept_batches = []
        for data_batch in dataset.get_batches():
            if len(data_batch) > 0:
                kept_batches.append(data_batch)
        return Dataset(kept_batches)
class ImageHandler:
    """
    Base class for handling images
    """

    @staticmethod
    def open_fits(path: str | Path) -> Image:
        """
        Read a fits file from disk and wrap it in an Image object.
        If the header lacks the raw-image or base-name keys, they are
        populated from the path.

        :param path: Path of image
        :return: Image object
        """
        path_str = str(path)
        data, header = open_fits(path_str)
        header_defaults = (
            (RAW_IMG_KEY, path_str),
            (BASE_NAME_KEY, Path(path_str).name),
        )
        for key, default in header_defaults:
            if key not in header:
                header[key] = default
        return Image(data=data, header=header)

    @staticmethod
    def save_fits(
        image: Image,
        path: str | Path,
    ):
        """
        Write an Image to disk

        :param image: Image to write
        :param path: Destination path
        :return: None
        """
        save_fits(image, path)

    def save_mask_image(self, image: Image, img_path: Path) -> Path:
        """
        Write a mask image alongside the parent image, using the astromatic
        software convention (masked value = 0., non-masked value = 1.).
        If the image header records a latest saved weight map, the mask is
        multiplied by that weight map.

        :param image: Science image
        :param img_path: Path of parent image
        :return: Path of mask image
        """
        out_path = get_mask_path(img_path)
        mask_header = image.get_header()
        mask_data = image.get_mask()

        if LATEST_WEIGHT_SAVE_KEY in image.header:
            weight_image = self.open_fits(image.header[LATEST_WEIGHT_SAVE_KEY])
            mask_data = mask_data * weight_image.get_data()

        self.save_fits(Image(mask_data.astype(float), mask_header), out_path)
        return out_path

    @staticmethod
    def get_hash(image_batch: ImageBatch):
        """
        Build a hash identifying an image batch, from the base name and
        processing history of each image (order-independent)

        :param image_batch: image batch
        :return: hex digest of the sha1 hash for that batch
        """
        identifiers = sorted(
            image[BASE_NAME_KEY] + image[PROC_HISTORY_KEY] for image in image_batch
        )
        combined = "".join(identifiers)
        return hashlib.sha1(combined.encode()).hexdigest()
class BaseImageProcessor(BaseProcessor, ImageHandler, ABC):
    """
    Base processor handling images in/images out
    """

    def _apply(self, batch: ImageBatch) -> ImageBatch:
        """
        Delegate processing of the image batch to :meth:`_apply_to_images`

        :param batch: Input image batch
        :return: Processed image batch
        """
        processed_batch = self._apply_to_images(batch)
        return processed_batch

    def _apply_to_images(
        self,
        batch: ImageBatch,
    ) -> ImageBatch:
        """
        Image-processing step, to be implemented by subclasses

        :param batch: Input image batch
        :return: Processed image batch
        :raises NotImplementedError: in the base class
        """
        raise NotImplementedError
class ProcessorWithCache(BaseImageProcessor, ABC):
    """
    Image processor with cached images associated to it, e.g. a master flat
    """

    def __init__(
        self,
        try_load_cache: bool = True,
        write_to_cache: bool = True,
        overwrite: bool = True,
        cache_sub_dir: str = CAL_OUTPUT_SUB_DIR,
    ):
        """
        :param try_load_cache: Load the cached image from disk if it exists
        :param write_to_cache: Save a newly-made cached image to disk
        :param overwrite: When writing, replace an already-existing cache file
        :param cache_sub_dir: Output directory root for cached images
        """
        super().__init__()
        self.try_load_cache = try_load_cache
        self.write_to_cache = write_to_cache
        self.overwrite = overwrite
        self.cache_sub_dir = cache_sub_dir

    def select_cache_images(self, images: ImageBatch) -> ImageBatch:
        """
        Select the appropriate cached image for the batch

        :param images: images to process
        :return: cached images to use
        :raises NotImplementedError: must be overridden by subclasses
        """
        raise NotImplementedError

    def get_cache_path(self, images: ImageBatch) -> Path:
        """
        Gets path for saving/loading cached image.
        The parent directory is created if it does not exist.

        :param images: images to process
        :return: cache path
        """
        file_name = self.get_cache_file_name(images)
        output_path = get_output_path(
            base_name=file_name, dir_root=self.cache_sub_dir, sub_dir=self.night_sub_dir
        )

        output_path.parent.mkdir(parents=True, exist_ok=True)
        return output_path

    def get_cache_file_name(self, images: ImageBatch) -> str:
        """
        Get unique cache name for images, from the processor key and a hash
        of the selected cache images

        :param images: images to process
        :return: unique hashed name
        """
        cache_images = self.select_cache_images(images)
        return f"{self.base_key}_{self.get_hash(cache_images)}.fits"

    def get_cache_file(self, images: ImageBatch) -> Image:
        """
        Return the appropriate cached image for the batch.

        The image is loaded from disk if allowed and present; otherwise it is
        made with :meth:`make_image`, and written to disk if caching is
        enabled and either the file does not exist or overwriting is allowed.

        :param images: images to process
        :return: cached image to use
        """
        path = self.get_cache_path(images)

        exists = path.exists()

        if self.try_load_cache and exists:
            logger.debug(f"Loading cached file {path}")
            return self.open_fits(path)

        image = self.make_image(images)

        if self.write_to_cache and (self.overwrite or not exists):
            self.save_fits(image, path)

        return image

    def make_image(self, images: ImageBatch) -> Image:
        """
        Make a cached image (e.g. master flat)

        :param images: images to use
        :return: cached image
        :raises NotImplementedError: must be overridden by subclasses
        """
        raise NotImplementedError
class ProcessorPremadeCache(ProcessorWithCache, ABC):
    """
    Cached-image processor whose master image already exists at a fixed path
    """

    def __init__(self, master_image_path: str | Path, *args, **kwargs):
        """
        :param master_image_path: Location of the pre-made master image
        :param args: Positional arguments passed on to the parent initialiser
        :param kwargs: Keyword arguments passed on to the parent initialiser
        """
        super().__init__(*args, **kwargs)
        self.master_image_path = Path(master_image_path)

    def get_cache_path(self, images: ImageBatch) -> Path:
        """
        Return the fixed master image path, whatever the input batch

        :param images: Images to process (not used)
        :return: Path to the pre-made master image
        """
        premade_path = self.master_image_path
        return premade_path
class BaseSourceGenerator(CleanupProcessor, ImageHandler, ABC):
    """
    Base source generator processor (image batch in, source batch out).

    Subclass registration is handled by ``BaseProcessor.__init_subclass__``,
    so no override is needed here.
    """

    def _apply(self, batch: ImageBatch) -> SourceBatch:
        """
        Generate sources from an image batch, logging a warning if none
        are found

        :param batch: Input image batch
        :return: Source batch (may be empty)
        """
        source_batch = self._apply_to_images(batch)

        if len(source_batch) == 0:
            msg = "No sources found in image batch"
            logger.warning(msg)

        return source_batch

    def _apply_to_images(self, batch: ImageBatch) -> SourceBatch:
        """
        Source-generation step, to be implemented by subclasses

        :param batch: Input image batch
        :return: Source batch
        :raises NotImplementedError: in the base class
        """
        raise NotImplementedError
class BaseSourceProcessor(BaseProcessor, ABC):
    """
    Base dataframe processor (Source batch in, source batch out).

    Subclass registration is handled by ``BaseProcessor.__init_subclass__``,
    so no override is needed here.
    """

    def _apply(self, batch: SourceBatch) -> SourceBatch:
        """
        Delegate processing of the source batch to :meth:`_apply_to_sources`

        :param batch: Input source batch
        :return: Processed source batch
        """
        return self._apply_to_sources(batch)

    def _apply_to_sources(
        self,
        batch: SourceBatch,
    ) -> SourceBatch:
        """
        Source-processing step, to be implemented by subclasses

        :param batch: Input source batch
        :return: Processed source batch
        :raises NotImplementedError: in the base class
        """
        raise NotImplementedError

    @staticmethod
    def generate_super_dict(metadata: dict, source_row: pd.Series) -> dict:
        """
        Generate a dictionary of metadata and candidate row, with lower case keys.
        Where a (lower-cased) key appears in both, the row value wins.

        :param metadata: Metadata for the source table
        :param source_row: Individual row of the source table
        :return: Combined dictionary
        """
        super_dict = {key.lower(): val for key, val in metadata.items()}
        super_dict.update(
            {key.lower(): val for key, val in source_row.to_dict().items()}
        )
        return super_dict