# Source code for vmpt.session_io

"""Session JSON save/load.

The MPT plan file (bundles name it `MPT_plan.json`; older bundles used
`session_MPT_plan.json`) is written as a pure APT MPT plan JSON — no
vMPT-only keys, no file paths — so APT's MPT loader accepts it directly.
A sibling `vMPT_workspace.json` in the same parent directory (older
bundles: `vmpt_workspace.json`) carries the bits MPT doesn't preserve:
the exact V3 PA, per-shutter target_id / role, the highlighted set, image
+ catalog paths, and the slitlet height. vMPT reads both on import; APT
only sees the MPT plan.

Old-style sessions (a single file with a flat top-level `open_shutters`
list and a `pointing` block) are still accepted on import.
"""

from __future__ import annotations

import json
from collections import Counter
from dataclasses import dataclass, field
from datetime import datetime, timezone
from pathlib import Path
from typing import Optional

from vmpt.coords import V3_IDL_Y_ANGLE
from vmpt.empt_io import OpenShutter


# Default `Session.tool_version` (written as the workspace's "vmpt_version")
# and the suffix of the plan config's "version" string ("vMPT-<version>").
SESSION_TOOL_VERSION = "1.4"
# Bundle filenames — chosen so the prefix telegraphs the role of each file:
#   MPT_*   → load these into APT MPT (plan JSON + primaries catalog)
#   vMPT_*  → vMPT-only state (image / catalog paths, target_id+role per shutter)
#   eMPT_*  → use these with the European eMPT pipeline (or any tool that
#             reads the eMPT shutter-mask / observed-targets / pointing-summary)
MPT_PLAN_FILENAME = "MPT_plan.json"
# The stem of the catalog name below is also the default catalog name written
# into the plan JSON when the session has no `catalog_path`.
MPT_CATALOG_FILENAME = "MPT_catalog.cat"   # primaries catalog APT can import
# Sidecar written next to the plan by export_session_json.
WORKSPACE_FILENAME = "vMPT_workspace.json"
# NOTE(review): the eMPT_* names are not used by the save/load functions in
# this module; presumably consumed by an eMPT exporter elsewhere.
EMPT_OBSERVED_FILENAME = "eMPT_observed_targets.cat"
EMPT_POINTING_FILENAME = "eMPT_pointing_summary.txt"
EMPT_SHUTTER_MASK_FILENAME = "eMPT_shutter_mask.csv"
# Legacy / back-compat — older bundles used these names; the importer
# falls back to them if the current names are missing.
_LEGACY_MPT_PLAN_FILENAMES = ("session_MPT_plan.json",)
_LEGACY_WORKSPACE_FILENAMES = ("vmpt_workspace.json",)

# Maps slitlet height → APT's `msaSlitlet` enum value (per the reference
# G140H+G235H+G395H and a370 plans). 3 is the common case; vMPT only
# exposes 1, 3, 5 in the UI dropdown. Heights missing from this map are
# exported as "THREE_SHUTTER".
_MSA_SLITLET_ENUM = {
    1: "ONE_SHUTTER",
    2: "TWO_SHUTTER",
    3: "THREE_SHUTTER",
    5: "FIVE_SHUTTER",
}


@dataclass
class Session:
    """One vMPT planning session: pointing, instrument setup and open shutters.

    This is the object `export_session_json` writes out (as an MPT plan plus
    a vMPT workspace sidecar) and `import_session_json` reconstructs.
    """

    pointing_ra_deg: float
    pointing_dec_deg: float
    # V3 PA, NOT the aperture PA. The exported plan's aperturePA is
    # (pa_v3_deg + V3_IDL_Y_ANGLE) % 360.
    pa_v3_deg: float  # V3 PA, NOT the aperture PA
    # Joined as "<disperser>_<filter_name>" into the plan's gratingFilter.
    disperser: str
    filter_name: str
    # Shutters per slitlet; exported through _MSA_SLITLET_ENUM.
    slitlet_height: int
    open_shutters: list[OpenShutter]
    # (q, s, d) shutter triples.
    highlighted: list[tuple[int, int, int]] = field(default_factory=list)
    image_path: Optional[str] = None
    wcs_sidecar_path: Optional[str] = None
    catalog_path: Optional[str] = None
    # vMPT 1.1+: multi-catalog support. Each entry: {"path": str,
    # "enabled": bool}. `catalog_path` (single) is preserved for
    # backward compatibility with vMPT 1.0 bundles — set to the first
    # entry's path when catalog_paths is non-empty.
    catalog_paths: list[dict] = field(default_factory=list)
    tool_version: str = SESSION_TOOL_VERSION
    # Creation timestamp string; when None the exporters stamp the current
    # UTC time ("YYYY-MM-DDTHH:MM:SSZ").
    created: Optional[str] = None
    # Plan name; when None the export uses "vMPT session — <created>".
    name: Optional[str] = None
def _utc_now_iso() -> str: return datetime.now(timezone.utc).strftime("%Y-%m-%dT%H:%M:%SZ") def _group_into_slitlets( open_shutters: list[OpenShutter], ) -> list[tuple[dict, Optional[str]]]: """Compress flat OpenShutter list back into MPT slitlets. Returns a list of (slitlet_dict, primary_id) tuples. Slitlets whose member shutters carry a non-None target_id are placed FIRST in the output (so positional `primaryIds[j] ↔ slitlets[j]` alignment works the way parse_mpt_json expects); manual / target-less slitlets come after, with primary_id=None. """ # Bucket shutters by (q, d). For each bucket, sort by s and walk runs. buckets: dict[tuple[int, int], list[OpenShutter]] = {} for sh in open_shutters: buckets.setdefault((int(sh.q), int(sh.d)), []).append(sh) targeted: list[tuple[dict, str]] = [] manual: list[dict] = [] for (q, d), members in buckets.items(): members.sort(key=lambda sh: int(sh.s)) # Group consecutive s into runs. run: list[OpenShutter] = [] runs: list[list[OpenShutter]] = [] for sh in members: if run and int(sh.s) == int(run[-1].s) + 1: run.append(sh) else: if run: runs.append(run) run = [sh] if run: runs.append(run) for r in runs: sl = {"q": q, "d": d, "s": int(r[0].s), "h": len(r)} # If any shutter in the run has a target_id, attach the most # common one as the slitlet's primary. tids = [sh.target_id for sh in r if sh.target_id is not None] if tids: primary = Counter(tids).most_common(1)[0][0] targeted.append((sl, str(primary))) else: manual.append(sl) # Targeted first → primaryIds positional alignment holds for them. out: list[tuple[dict, Optional[str]]] = [(sl, t) for sl, t in targeted] out.extend((sl, None) for sl in manual) return out def _unfold_slitlets(slitlets: list[dict]) -> list[OpenShutter]: """Expand each {q,d,s,h} into h OpenShutter rows. Middle one → 'target', others → 'sky'. 
Matches APT MPT semantics.""" out: list[OpenShutter] = [] for sl in slitlets: q = int(sl["q"]); d = int(sl["d"]) s0 = int(sl["s"]); h = int(sl.get("h", 1)) mid = h // 2 for off in range(h): role = "target" if off == mid else "sky" out.append(OpenShutter(q=q, s=s0 + off, d=d, target_id=None, role=role)) return out def _spectral_offset_map(disperser: str) -> str: """APT's spectralOverlapShutterOffsetMap value, per disperser. APT names the map per individual grating (e.g. JWST_NIRSPEC_G395H), not per resolution class — verified against the reference G395H, G140H, and PRISM plan exports. """ d = (disperser or "").upper() if d == "PRISM": return "JWST_NIRSPEC_PRISM" if d in {"G140M", "G235M", "G395M", "G140H", "G235H", "G395H"}: return f"JWST_NIRSPEC_{d}" return "JWST_NIRSPEC_PRISM" def _build_mpt_payload(session: Session) -> dict: """Pure APT MPT plan JSON — no vMPT-only keys. Structure mirrors the reference G140H+G235H+G395H and a370 plan exports byte-for-byte schema-wise (we don't reproduce the full plannerSpecification, but the top-level shape is identical).""" apa = (float(session.pa_v3_deg) + V3_IDL_Y_ANGLE) % 360.0 pairs = _group_into_slitlets(session.open_shutters) slitlets = [sl for sl, _ in pairs] primary_ids: list[int] = [] for _, tid in pairs: if tid is None: continue try: primary_ids.append(int(tid)) except (TypeError, ValueError): # Non-numeric target_ids can't sit in primaryIds (APT uses int); # silently skip — the workspace sidecar carries them losslessly. pass grating_filter = f"{session.disperser}_{session.filter_name}" msa_slitlet = _MSA_SLITLET_ENUM.get(int(session.slitlet_height), "THREE_SHUTTER") created = session.created if session.created is not None else _utc_now_iso() name = session.name or f"vMPT session — {created}" # `catalog_basename` is the identifier APT will look for in its Target # List database. 
We always export a primaries catalog file in the # bundle whose name MATCHES this basename (so importing the .cat under # its default name in APT lines up automatically with the plan). # Default is the stem of MPT_CATALOG_FILENAME (e.g. "MPT_catalog"); # if a user-loaded catalog file is available we use its stem instead. from pathlib import Path as _P catalog_basename = _P(MPT_CATALOG_FILENAME).stem if session.catalog_path: catalog_basename = _P(session.catalog_path).stem n_targets = len(primary_ids) return { "instrument": "JWST/NIRSpec", "name": name, "aperturePA": apa, "theta": 0.0, "catalog": { "name": catalog_basename, "primariesName": catalog_basename, "fillersName": None, "primaries": None, "fillers": None, }, "referencePointing": { "ra": float(session.pointing_ra_deg), "dec": float(session.pointing_dec_deg), }, "configs": [{ "name": "c1", "version": f"vMPT-{SESSION_TOOL_VERSION}", "info": {"fixedSlit": None}, "masterBackground": False, "slitlets": slitlets, "exposures": [{ "name": "c1e1", "gratingFilter": grating_filter, "msaSlitlet": msa_slitlet, "ra": float(session.pointing_ra_deg), "dec": float(session.pointing_dec_deg), "sourceIds": primary_ids, }], "primaryIds": primary_ids, "fillerIds": [], }], "stats": [{ "name": "s0", "score": float(n_targets), "numberOfConfigurations": 1, "numberOfTargets": n_targets, "duration": 0.0, "totalDuration": 0.0, }], "errors": [], "plannerSpecification": { "gratingSpecification": { "gratings": [grating_filter], "allowContamination": False, "multiplexLimit": None, "multiplexingMinimum": None, }, "planName": name, "planAngle": apa, "theta": None, "candidates": { "fillers": None, "primaries": catalog_basename, "catalog": catalog_basename, "slitSources": None, }, "slitSpecification": { "sweetSpot": "PERCENT_0", "slitlet": msa_slitlet, }, "searchParameters": { "useWeights": False, "enableMonteCarlo": False, "monteCarloShuffles": None, "ignoreStuckOpen": False, "spectralOverlapThreshold": 1.5, "numberOfConfigurations": 1, 
"allowMultiSourceShutters": False, "spectralOverlapShutterOffsetMap": _spectral_offset_map(session.disperser), }, "slitSearchSpecification": None, "maskingSpecification": { "fillerMask": None, "primaryMask": None, "noGapFiller": False, "noGapPrimary": False, "noRedCutoffPrimary": False, "noBlueCutoffPrimary": False, "noRedCutoffFiller": False, "noBlueCutoffFiller": False, }, "pointingSpecification": { "ditherType": "NONE", "shouldNod": False, "pointingMode": "GRID_SEARCH", "numberOfNods": 5, "fixedDitherOffsets": [{"spatial": 0, "dispersion": 0}], "fixedPointings": None, "partiallyCompletedPrimaries": False, "minPrimaryDitherPoints": 0, "partiallyCompletedFillers": False, "minFillerDitherPoints": None, }, "searchGridSpecification": { "searchArea": { "width": 10.0, "height": 10.0, "center": {"x": 0.0, "y": 0.0}, "ylength": 10.0, "xlength": 10.0, "corners": [], "offsetToBottomCorner": {"x": -5.0, "y": -5.0}, }, "searchAreaCenter": { "ra": float(session.pointing_ra_deg), "dec": float(session.pointing_dec_deg), }, "searchAreaHeight": 10.0, "searchAreaWidth": 10.0, "searchStepSize": 0.5, }, "wavelengthRangeSpecification": { "primaryRange1": False, "primaryRange2": False, "primaryRange3": False, "primaryRange4": False, "primaryRange5": False, "fillerRange1": False, "fillerRange2": False, "fillerRange3": False, "fillerRange4": False, "fillerRange5": False, }, }, } def _build_workspace_payload(session: Session) -> dict: """vMPT-only extras: per-shutter target_id/role, highlighted set, image + catalog paths. 
Written next to the MPT session.json so APT never sees it.""" return { "vmpt_version": session.tool_version, "created": session.created if session.created is not None else _utc_now_iso(), "pa_v3_deg": float(session.pa_v3_deg), "slitlet_height": int(session.slitlet_height), "open_shutters": [ { "q": int(sh.q), "d": int(sh.d), "s": int(sh.s), "target_id": (str(sh.target_id) if sh.target_id is not None else None), "role": sh.role, } for sh in session.open_shutters ], "highlighted": [[int(q), int(s), int(d)] for (q, s, d) in session.highlighted], "image_path": session.image_path, "wcs_sidecar_path": session.wcs_sidecar_path, "catalog_path": session.catalog_path, "catalog_paths": [ {"path": str(e.get("path")), "enabled": bool(e.get("enabled", True))} for e in (session.catalog_paths or []) if e.get("path") ], }
def export_session_json(session: Session, path: str) -> None:
    """Write `session` as an APT MPT plan JSON plus its vMPT workspace sidecar.

    The MPT plan is written to `path`; the vMPT-only extras go to
    `WORKSPACE_FILENAME` in the same directory. Parent directories are
    created as needed. If `session.created` is None, a single UTC timestamp
    is generated here and used for both files, so the plan's default name
    and the workspace's "created" field agree.

    Raises:
        ValueError: if `path` is itself named like the workspace sidecar
            (compared case-insensitively). The second write would otherwise
            overwrite the plan just written, or — on case-insensitive
            filesystems / with the legacy lowercase name — produce a plan the
            importer mistakes for a workspace.
    """
    import dataclasses

    p = Path(path)
    if p.name.casefold() == WORKSPACE_FILENAME.casefold():
        raise ValueError(
            f"cannot write the MPT plan to {p.name!r}: that name is reserved "
            f"for the vMPT workspace sidecar"
        )
    if session.created is None:
        # Each payload builder would otherwise call _utc_now_iso() on its own
        # and the two files could disagree by a second.
        session = dataclasses.replace(session, created=_utc_now_iso())

    p.parent.mkdir(parents=True, exist_ok=True)
    p.write_text(json.dumps(_build_mpt_payload(session), indent=2), encoding="utf-8")
    workspace_path = p.parent / WORKSPACE_FILENAME
    workspace_path.write_text(
        json.dumps(_build_workspace_payload(session), indent=2), encoding="utf-8"
    )
def _parse_grating_filter(gf: Optional[str]) -> tuple[Optional[str], Optional[str]]: if not isinstance(gf, str): return None, None for sep in ("_", "/"): if sep in gf: a, b = gf.split(sep, 1) return a, b return None, None def _parse_catalog_paths( raw: object, legacy_single: Optional[str] = None, ) -> list: """Normalise the workspace's catalog_paths field into a clean list of `{"path": str, "enabled": bool}` dicts. Older bundles (vMPT ≤ 1.0) only stored a single `catalog_path`; if the new field is absent or empty, synthesise a single-entry list from that fallback so multi-catalog UIs still have one row to show. """ out: list = [] if isinstance(raw, list): for entry in raw: if isinstance(entry, dict) and entry.get("path"): out.append({ "path": str(entry["path"]), "enabled": bool(entry.get("enabled", True)), }) elif isinstance(entry, str) and entry: out.append({"path": entry, "enabled": True}) if not out and legacy_single: out.append({"path": str(legacy_single), "enabled": True}) return out def _import_mpt(data: dict, sidecar: dict) -> Session: """Combine an MPT-format session.json + vmpt_workspace.json into a Session.""" configs = data.get("configs") or [] if not configs or not isinstance(configs[0], dict): raise ValueError("session has no usable configs") cfg = configs[0] # Pick first DISPERSED exposure (same logic as mpt_io.parse_mpt_json). exps = cfg.get("exposures") or [] primary_exp = None for e in exps: if isinstance(e, dict) and e.get("gratingFilter"): primary_exp = e break if primary_exp is None and exps and isinstance(exps[0], dict): primary_exp = exps[0] ra = dec = None grating = filt = None if primary_exp is not None: try: ra = float(primary_exp["ra"]); dec = float(primary_exp["dec"]) except (KeyError, TypeError, ValueError) as e: raise ValueError(f"malformed exposure ra/dec: {e}") from e grating, filt = _parse_grating_filter(primary_exp.get("gratingFilter")) # PA from the workspace's exact V3 PA, else derived from APA. 
try: if "pa_v3_deg" in sidecar: pa_v3 = float(sidecar["pa_v3_deg"]) else: apa = float(data["aperturePA"]) pa_v3 = (apa - V3_IDL_Y_ANGLE) % 360.0 except (KeyError, TypeError, ValueError) as e: raise ValueError(f"malformed aperturePA / pa_v3_deg: {e}") from e slitlet_height = int(sidecar.get("slitlet_height", 3)) # Prefer the workspace's lossless open_shutters list (carries target_id # and role). Fall back to unfolding the MPT slitlets, which loses # those — but at least restores the shutter grid. if isinstance(sidecar.get("open_shutters"), list): opens: list[OpenShutter] = [] for i, sh in enumerate(sidecar["open_shutters"]): try: opens.append(OpenShutter( q=int(sh["q"]), s=int(sh["s"]), d=int(sh["d"]), target_id=sh.get("target_id"), role=sh.get("role", "target"), )) except (KeyError, TypeError, ValueError) as e: raise ValueError(f"malformed workspace open_shutters[{i}]: {e}") from e else: opens = _unfold_slitlets(cfg.get("slitlets") or []) # If primaryIds positionally aligns with slitlets, recover target_id # for the middle ('target') shutter of each slitlet. 
primary_ids = cfg.get("primaryIds") or [] slitlets = cfg.get("slitlets") or [] if primary_ids and len(primary_ids) <= len(slitlets): cursor = 0 for j, sl in enumerate(slitlets): h = int(sl.get("h", 1)) tid = (str(primary_ids[j]) if j < len(primary_ids) else None) if tid is not None: for off in range(h): opens[cursor + off].target_id = tid cursor += h highlighted: list[tuple[int, int, int]] = [] for i, hl in enumerate(sidecar.get("highlighted") or []): try: q, s, d = hl highlighted.append((int(q), int(s), int(d))) except (TypeError, ValueError) as e: raise ValueError(f"malformed workspace highlighted[{i}]: {e}") from e if grating is None or filt is None: raise ValueError("could not determine disperser/filter from session") catalog_paths = _parse_catalog_paths( sidecar.get("catalog_paths"), sidecar.get("catalog_path"), ) return Session( pointing_ra_deg=ra if ra is not None else 0.0, pointing_dec_deg=dec if dec is not None else 0.0, pa_v3_deg=pa_v3, disperser=str(grating), filter_name=str(filt), slitlet_height=slitlet_height, open_shutters=opens, highlighted=highlighted, image_path=sidecar.get("image_path"), wcs_sidecar_path=sidecar.get("wcs_sidecar_path"), catalog_path=sidecar.get("catalog_path"), catalog_paths=catalog_paths, tool_version=str(sidecar.get("vmpt_version", SESSION_TOOL_VERSION)), created=sidecar.get("created") or data.get("created"), name=data.get("name"), ) def _import_legacy(data: dict) -> Session: """Import the original vMPT-only session schema (flat `open_shutters`).""" for key in ("pointing", "instrument", "open_shutters"): if key not in data: raise ValueError(f"missing required key: {key!r}") pointing = data["pointing"] instrument = data["instrument"] try: ra = float(pointing["ra_deg"]) dec = float(pointing["dec_deg"]) pa = float(pointing["apa_v3_deg"]) except (KeyError, TypeError, ValueError) as e: raise ValueError(f"malformed pointing: {e}") from e try: disperser = str(instrument["disperser"]) filter_name = str(instrument["filter"]) except 
(KeyError, TypeError) as e: raise ValueError(f"malformed instrument: {e}") from e slitlet_height = int(instrument.get("slitlet_height", 3)) opens: list[OpenShutter] = [] for i, sh in enumerate(data["open_shutters"]): try: opens.append(OpenShutter( q=int(sh["q"]), s=int(sh["s"]), d=int(sh["d"]), target_id=sh.get("target_id"), role=sh.get("role", "target"), )) except (KeyError, TypeError, ValueError) as e: raise ValueError(f"malformed open_shutters[{i}]: {e}") from e highlighted: list[tuple[int, int, int]] = [] for i, hl in enumerate(data.get("highlighted", [])): try: q, s, d = hl highlighted.append((int(q), int(s), int(d))) except (TypeError, ValueError) as e: raise ValueError(f"malformed highlighted[{i}]: {e}") from e catalog_paths = _parse_catalog_paths( data.get("catalog_paths"), data.get("catalog_path"), ) return Session( pointing_ra_deg=ra, pointing_dec_deg=dec, pa_v3_deg=pa, disperser=disperser, filter_name=filter_name, slitlet_height=slitlet_height, open_shutters=opens, highlighted=highlighted, image_path=data.get("image_path"), wcs_sidecar_path=data.get("wcs_sidecar_path"), catalog_path=data.get("catalog_path"), catalog_paths=catalog_paths, tool_version=str(data.get("version", "1.0")), created=data.get("created"), ) def _load_json_or_empty(p: Path) -> dict: """Return {} if the file is missing or malformed.""" if not p.exists(): return {} try: with open(p) as f: data = json.load(f) except (OSError, json.JSONDecodeError): return {} return data if isinstance(data, dict) else {}
def _find_sibling_plan(p: Path) -> dict:
    """Load the MPT plan that accompanies the workspace file at `p`.

    Tries MPT_PLAN_FILENAME, then the legacy plan names, then any other
    ``*.json`` sibling (sorted by path) that has both "configs" and
    "aperturePA" keys.

    Raises:
        ValueError: if no suitable sibling is found.
    """
    for fname in (MPT_PLAN_FILENAME, *_LEGACY_MPT_PLAN_FILENAMES):
        mpt_data = _load_json_or_empty(p.parent / fname)
        if mpt_data:
            return mpt_data
    # Last resort: any *.json sibling whose shape says MPT plan.
    for candidate in sorted(p.parent.glob("*.json")):
        if candidate == p:
            continue
        d = _load_json_or_empty(candidate)
        if "configs" in d and "aperturePA" in d:
            return d
    raise ValueError(
        f"workspace at {p.name} needs a sibling MPT plan JSON "
        f"(expected {MPT_PLAN_FILENAME}); none found"
    )


def _find_sibling_workspace(p: Path) -> dict:
    """Load the workspace sidecar next to the plan at `p`, or {} if absent."""
    for fname in (WORKSPACE_FILENAME, *_LEGACY_WORKSPACE_FILENAMES):
        sidecar = _load_json_or_empty(p.parent / fname)
        if sidecar:
            return sidecar
    return {}


def import_session_json(path: str) -> Session:
    """Parse a saved session back into a Session.

    The user can point at EITHER file in a bundle:

    • the MPT plan (`MPT_plan.json`; older bundles `session_MPT_plan.json`)
      → pure MPT plan; we look for a sibling workspace (`vMPT_workspace.json`,
      or legacy `vmpt_workspace.json`) to merge in target_ids, roles, image
      path. A missing workspace is allowed.
    • the workspace → vMPT extras; we look for a sibling MPT plan under the
      names above, or failing that any `*.json` sibling with "configs" and
      "aperturePA" keys, to pull pointing / PA / disperser / slitlet geometry.

    Legacy single-file sessions (`open_shutters` plus `pointing` at top
    level) still load.

    Raises:
        ValueError: if the file can't be read or decoded, its root is not an
            object, its schema is unrecognized, a required sibling is
            missing, or its contents are malformed.
    """
    p = Path(path)
    try:
        data = json.loads(p.read_bytes())
    except (OSError, ValueError) as e:
        # ValueError covers JSONDecodeError and UnicodeDecodeError.
        raise ValueError(f"could not read session JSON: {e}") from e
    if not isinstance(data, dict):
        raise ValueError("session JSON root must be an object")

    # Case A: user pointed at the workspace sidecar — find the MPT plan sibling.
    # A legacy single-file session also has "open_shutters", but always
    # alongside "pointing".
    is_workspace = "vmpt_version" in data or (
        "open_shutters" in data and "pointing" not in data
    )
    if is_workspace:
        return _import_mpt(_find_sibling_plan(p), data)

    # Case B: user pointed at the MPT plan — find the workspace sidecar.
    if "configs" in data:
        return _import_mpt(data, _find_sibling_workspace(p))

    # Case C: legacy single-file vMPT session.
    if "open_shutters" in data:
        return _import_legacy(data)

    raise ValueError(
        "not a recognized session JSON (no 'configs', 'open_shutters', "
        "or 'vmpt_version' top-level key)"
    )