# Source code for vmpt.image_io

"""Image loaders (FITS / JPG+sidecar) and display stretching."""

from __future__ import annotations

from dataclasses import dataclass
from typing import Optional

import numpy as np
from astropy.io import fits
from astropy.wcs import WCS
from PIL import Image

# Disable Pillow's image-size limit (its decompression-bomb check) so that very
# large images can be opened by the loaders in this module.
# NOTE(review): this is a module-level assignment on PIL.Image, so it changes
# the limit for every PIL user in the process as soon as this module is imported.
Image.MAX_IMAGE_PIXELS = None


@dataclass
class LoadedImage:
    """An image's pixel array bundled with its world-coordinate solution.

    ``load_fits`` builds these with ``mode="fits"``; ``load_jpg_with_sidecar``
    builds them with ``mode="jpg+sidecar"`` and also records the sidecar path.
    """

    # Pixel array: float32 from load_fits; the PIL pixel array from
    # load_jpg_with_sidecar (2-D, or (h, w, channels) for colour images).
    data: np.ndarray
    # World coordinate system describing ``data``'s pixel grid.
    wcs: WCS
    # (rows, cols) of the pixel grid.
    shape: tuple[int, ...]
    # File the pixels were read from.
    source_path: str
    # Which loader produced this instance ("fits" or "jpg+sidecar").
    mode: str
    # Path of the FITS file that supplied the WCS; set only for jpg+sidecar mode.
    wcs_sidecar_path: Optional[str] = None
def _first_image_hdu(hdul: fits.HDUList) -> int:
    """Return the index of the first HDU in ``hdul`` whose data is 2-D.

    Raises:
        ValueError: if no HDU carries 2-D data.
    """
    for i, h in enumerate(hdul):
        d = h.data
        if d is not None and getattr(d, "ndim", 0) == 2:
            return i
    raise ValueError("No 2D image HDU found")


def load_fits(path: str, hdu: Optional[int] = None) -> LoadedImage:
    """Load a FITS image as float32 together with its WCS.

    Args:
        path: Path to the FITS file.
        hdu: Index of the HDU to read. When ``None``, the first HDU holding
            2-D data is used.

    Returns:
        A ``LoadedImage`` with ``mode="fits"``. NaN and +/-inf pixels are
        replaced by 0.0. The WCS is reduced to its celestial axes when it has
        any; otherwise the full header WCS is kept.

    Raises:
        ValueError: if ``hdu`` is ``None`` and the file has no 2-D image HDU,
            or if the selected HDU carries no data.
    """
    with fits.open(path) as hdul:
        idx = _first_image_hdu(hdul) if hdu is None else hdu
        selected = hdul[idx]
        if selected.data is None:
            # np.asarray(None, dtype=float32) would silently yield a 0-d NaN.
            raise ValueError(f"HDU {idx} of {path!r} contains no image data")
        header = selected.header
        data = np.asarray(selected.data, dtype=np.float32)
        data = np.nan_to_num(data, nan=0.0, posinf=0.0, neginf=0.0)
        # Build the WCS once and keep only the sky axes when present.
        wcs = WCS(header)
        if wcs.has_celestial:
            wcs = wcs.celestial
    return LoadedImage(
        data=data, wcs=wcs, shape=data.shape, source_path=path, mode="fits"
    )


def _scale_sip_coeffs(coeffs, factor: int):
    """Rescale one SIP coefficient matrix for pixels binned by ``factor``.

    With pixel offsets shrinking by ``factor``, the term of order p+q must be
    multiplied by ``factor ** (p + q - 1)``. ``None`` (absent matrix) is
    passed through.
    """
    if coeffs is None:
        return None
    c = np.asarray(coeffs, dtype=np.float64)
    p, q = np.indices(c.shape)
    return c * float(factor) ** (p + q - 1)


def _scale_wcs(wcs: WCS, factor: int) -> WCS:
    """Return a copy of ``wcs`` describing the image binned by ``factor``.

    Output pixel ``j`` (0-based) covers input pixels ``[j*factor, (j+1)*factor)``,
    so a 1-based FITS coordinate ``p`` maps to ``(p - 0.5) / factor + 0.5``.
    The linear part (CD matrix, or CDELT when no CD is present) grows by
    ``factor``. SIP polynomials, if present, are rescaled with their own
    reference pixel. Lookup-table distortions are not touched.
    """
    from astropy.wcs import Sip  # astropy is already a dependency of this module

    w = wcs.deepcopy()
    w.wcs.crpix = (np.asarray(wcs.wcs.crpix, dtype=np.float64) - 0.5) / factor + 0.5
    if wcs.wcs.has_cd():
        w.wcs.cd = wcs.wcs.cd * factor
    else:
        # PC-matrix and plain-CDELT headers both carry the scale in CDELT.
        w.wcs.cdelt = np.asarray(wcs.wcs.cdelt) * factor

    sip = getattr(wcs, "sip", None)
    if sip is not None:
        # SIP keeps its own copy of CRPIX; leaving it stale (or the
        # coefficients unscaled) corrupts every pixel<->world conversion.
        sip_crpix = (np.asarray(sip.crpix, dtype=np.float64) - 0.5) / factor + 0.5
        w.sip = Sip(
            _scale_sip_coeffs(sip.a, factor),
            _scale_sip_coeffs(sip.b, factor),
            _scale_sip_coeffs(sip.ap, factor),
            _scale_sip_coeffs(sip.bp, factor),
            sip_crpix,
        )
    return w


def _sidecar_dims(header: fits.Header) -> Optional[tuple]:
    """Return ``(width, height, label)`` hinted at by a sidecar header, or None.

    Sources are tried in order: NAXIS1/NAXIS2, IMAGEW/IMAGEH (keywords some
    plate solvers write), and finally ``2 * (CRPIX - 0.5)``. The last only
    equals the image size when the reference pixel sits at the image centre.
    """
    for key_w, key_h in (("NAXIS1", "NAXIS2"), ("IMAGEW", "IMAGEH")):
        w, h = header.get(key_w), header.get(key_h)
        if w and h:
            return float(w), float(h), f"{key_w}/{key_h}"
    crpix1, crpix2 = header.get("CRPIX1"), header.get("CRPIX2")
    if crpix1 is not None and crpix2 is not None:
        return 2 * (crpix1 - 0.5), 2 * (crpix2 - 0.5), "CRPIX-implied"
    return None


# PIL modes whose raw channels are not greyscale/RGB intensities; they are
# converted to RGB so that stretch_for_display does not misread them.
_CONVERT_TO_RGB_MODES = frozenset({"CMYK", "YCbCr", "P"})


def load_jpg_with_sidecar(
    jpg_path: str,
    sidecar_fits_path: str,
    max_dim: int = 8000,
) -> LoadedImage:
    """Load a JPG (or other PIL-readable image) whose WCS lives in a FITS sidecar.

    The sidecar's primary header supplies the WCS. Its NAXIS keywords are
    overwritten with the JPG's dimensions. If the header hints at an image
    size that differs by more than 10% on either axis, a warning is printed
    and the JPG dimensions are used anyway.

    Images whose longest side exceeds ``max_dim`` are downsampled by the
    integer factor ``ceil(longest / max_dim)``. The WCS, including SIP
    distortion if present, is rescaled to match.

    Args:
        jpg_path: Path of the raster image.
        sidecar_fits_path: FITS file whose primary header holds the WCS.
        max_dim: Maximum allowed length, in pixels, of the longest side.

    Returns:
        A ``LoadedImage`` with ``mode="jpg+sidecar"``. ``data`` is
        ``np.asarray`` of the (possibly converted and downsampled) PIL image.

    Raises:
        ValueError: if ``max_dim`` is smaller than 1.
    """
    if max_dim < 1:
        raise ValueError(f"max_dim must be >= 1, got {max_dim}")

    with Image.open(jpg_path) as src:
        jpg_w, jpg_h = src.size

        with fits.open(sidecar_fits_path) as hdul:
            header = hdul[0].header.copy()

        ref = _sidecar_dims(header)
        if ref is not None:
            ref_w, ref_h, label = ref
            if (
                abs(ref_w - jpg_w) / max(ref_w, 1) > 0.1
                or abs(ref_h - jpg_h) / max(ref_h, 1) > 0.1
            ):
                print(
                    f"WARNING: JPG dims ({jpg_w}x{jpg_h}) disagree with {label} "
                    f"sidecar dims ({ref_w:.0f}x{ref_h:.0f}); using JPG dims."
                )

        header["NAXIS"] = 2
        header["NAXIS1"] = jpg_w
        header["NAXIS2"] = jpg_h
        wcs = WCS(header)
        if wcs.has_celestial:
            wcs = wcs.celestial

        img = src.convert("RGB") if src.mode in _CONVERT_TO_RGB_MODES else src

        longest = max(jpg_w, jpg_h)
        if longest > max_dim:
            factor = int(np.ceil(longest / max_dim))
            new_w = jpg_w // factor
            new_h = jpg_h // factor
            # Resample exactly the top-left (new_w*factor, new_h*factor) region
            # so both axes shrink by exactly `factor`. This is the mapping
            # _scale_wcs models. Using the full extent would give a per-axis,
            # non-integer scale.
            img = img.resize(
                (new_w, new_h),
                Image.BILINEAR,
                box=(0, 0, new_w * factor, new_h * factor),
            )
            wcs = _scale_wcs(wcs, factor)

        # Materialise the pixels while the source file is still open.
        arr = np.asarray(img)

    return LoadedImage(
        data=arr,
        wcs=wcs,
        shape=arr.shape[:2],
        source_path=jpg_path,
        mode="jpg+sidecar",
        wcs_sidecar_path=sidecar_fits_path,
    )


def _apply_stretch(x: np.ndarray, stretch: str) -> np.ndarray:
    """Apply a named display stretch to ``x``, expected to lie in [0, 1].

    "sqrt" and "log" clip their input to [0, 1]; "linear" and "asinh" do not.

    Raises:
        ValueError: for an unknown ``stretch`` name.
    """
    if stretch == "linear":
        return x
    if stretch == "sqrt":
        return np.sqrt(np.clip(x, 0, 1))
    if stretch == "asinh":
        a = 0.1
        return np.arcsinh(x / a) / np.arcsinh(1.0 / a)
    if stretch == "log":
        a = 1000.0
        return np.log1p(a * np.clip(x, 0, 1)) / np.log1p(a)
    raise ValueError(f"Unknown stretch: {stretch}")


def _pack_rgba(rgb: np.ndarray) -> np.ndarray:
    """Pack an (h, w, 3) uint8 array (or one broadcastable to it) into (h, w) uint32.

    Each word holds the bytes R, G, B, A=255 in memory order, so its integer
    value depends on the platform's byte order.
    """
    h, w = rgb.shape[:2]
    rgba = np.empty((h, w, 4), dtype=np.uint8)
    rgba[..., :3] = rgb
    rgba[..., 3] = 255
    return rgba.view(np.uint32).reshape(h, w)


def stretch_for_display(
    arr: np.ndarray,
    stretch: str = "asinh",
    percentile_lo: float = 1.0,
    percentile_hi: float = 99.5,
) -> np.ndarray:
    """Convert an image array into packed RGBA uint32 pixels for display.

    Colour input (3-D with at least 3 channels) is scaled by 1/255, which
    assumes 8-bit channel values, then stretched per channel. Percentiles are
    not applied and any channels beyond the third are dropped. Any other input
    is treated as greyscale: 3-D input with fewer than 3 channels uses its
    first channel. It is normalised between the given percentiles of its
    finite values and stretched. NaN pixels render black.

    Args:
        arr: Image array.
        stretch: One of "linear", "sqrt", "asinh", "log".
        percentile_lo: Lower percentile mapped to black (greyscale only).
        percentile_hi: Upper percentile mapped to white (greyscale only).

    Returns:
        An (h, w) uint32 array of packed RGBA bytes with alpha 255.

    Raises:
        ValueError: for an unknown ``stretch`` name.
    """
    if arr.ndim == 3 and arr.shape[2] >= 3:
        rgb = arr[..., :3].astype(np.float32) / 255.0
        stretched = _apply_stretch(rgb, stretch)
        return _pack_rgba(np.clip(stretched * 255.0, 0, 255).astype(np.uint8))

    x = arr.astype(np.float32)
    if x.ndim == 3:
        # (h, w, 1) or (h, w, 2), e.g. grey + alpha: display the first channel.
        x = x[..., 0]
    finite = np.isfinite(x)
    if finite.any():
        lo, hi = np.percentile(x[finite], [percentile_lo, percentile_hi])
    else:
        lo, hi = 0.0, 1.0
    if hi <= lo:
        hi = lo + 1.0
    norm = np.clip((x - lo) / (hi - lo), 0.0, 1.0)
    # NaN passes through np.clip unchanged; map it to black before the uint8
    # cast, whose result for NaN is undefined.
    norm = np.nan_to_num(norm, nan=0.0)
    grey = np.clip(_apply_stretch(norm, stretch) * 255.0, 0, 255).astype(np.uint8)
    return _pack_rgba(grey[..., None])