Skip to main content
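Inference outputs are written as Zarr (default), HDF5, or NPZ. All three contain the same datasets under the same keys, so a single loader can handle any format. The example below implements the Zarr branch only and raises ValueError for other formats; HDF5 and NPZ branches would read the same keys.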

Load an output file

from pathlib import Path
import numpy as np

def load_prediction(path):
    """Read an inference output store into a plain dictionary.

    Only the Zarr layout is handled: a path counts as Zarr when its suffix
    is ``.zarr`` or when it points at an existing directory.

    Args:
        path: Location of the output, as a string or ``Path``.

    Returns:
        A dict with ``"outputs"``, ``"coords"`` and ``"thumbnail"`` as NumPy
        arrays and ``"metadata"`` as a dict built from the store attributes.
        ``"gene_names"`` (a list of ``str``) is added only when the store
        contains an ``output_gene_names`` dataset.

    Raises:
        ValueError: If the path is neither a ``.zarr`` path nor a directory.
    """
    path = Path(path)
    looks_like_zarr = path.suffix == ".zarr" or path.is_dir()
    if not looks_like_zarr:
        # HDF5 (.h5) via h5py, NPZ (.npz) via np.load(allow_pickle=True) follow the same keys.
        raise ValueError(f"Unsupported format: {path.suffix}")

    import zarr  # imported only when a Zarr store is actually opened

    root = zarr.open(str(path), mode="r")
    data = {key: np.array(root[key]) for key in ("outputs", "coords", "thumbnail")}
    data["metadata"] = dict(root.attrs)

    if "output_gene_names" in root:
        # Names may come back as bytes or as str-like scalars; normalise to str.
        names = []
        for raw in np.array(root["output_gene_names"]):
            names.append(raw.decode() if isinstance(raw, bytes) else str(raw))
        data["gene_names"] = names
    return data
Datasets: outputs (n_tiles, D), coords (n_tiles, 2), tissue_ratios, thumbnail, tissue_mask, and (M-Optimus) output_gene_names / input_gene_names. Metadata attributes include slide_dimensions, slide_dimensions_at_mpp, tile_size, mpp, and num_tiles.

Tile-grid heatmap

Build a heatmap at tile-grid resolution (one cell per tile), then resize it to the thumbnail's dimensions — this is fast and memory-light. Grid cells that no tile maps to are filled with 0 before resizing.
from PIL import Image

def tile_heatmap(values, coords, meta, thumb):
    """Rasterise per-tile values onto the tile grid and resize to the thumbnail.

    Args:
        values: One value per tile, aligned row-for-row with ``coords``.
        coords: Array-like whose column 0 is the tile x position and column 1
            the tile y position. Assumed to be in the same pixel space as the
            tiling dimensions read from ``meta`` — TODO confirm against the
            writer.
        meta: Metadata mapping. ``slide_dimensions_at_mpp`` is preferred and
            ``slide_dimensions`` is used only when it is absent (or ``None``);
            ``tile_size`` defaults to ``[224, 224]``.
        thumb: Thumbnail array; only ``thumb.shape[0]`` (height) and
            ``thumb.shape[1]`` (width) are used.

    Returns:
        A 2-D array of shape ``(thumb.shape[0], thumb.shape[1])`` holding the
        bilinearly resized grid. Cells no tile maps to are 0 before resizing.

    Raises:
        KeyError: If ``meta`` has neither dimension key.
    """
    # Resolve the fallback lazily. Passing meta["slide_dimensions"] as the
    # default argument of .get() evaluates it up front and raises KeyError
    # even when the preferred key is present.
    dims = meta.get("slide_dimensions_at_mpp")
    if dims is None:
        dims = meta["slide_dimensions"]
    tiling_w, tiling_h = dims
    tw, th = meta.get("tile_size", [224, 224])

    # Ceiling division: a partial tile at the right/bottom edge still gets a cell.
    grid_w = (tiling_w + tw - 1) // tw
    grid_h = (tiling_h + th - 1) // th

    # Accept plain sequences as well as arrays for the column indexing below.
    coords = np.asarray(coords)
    # Clip so coordinates outside the tiling extent land on the border cells
    # instead of raising IndexError.
    gx = np.clip((coords[:, 0] / tw).astype(int), 0, grid_w - 1)
    gy = np.clip((coords[:, 1] / th).astype(int), 0, grid_h - 1)

    # NaN marks "no tile here" while filling; it is zeroed just before the
    # resize so the interpolation does not spread NaN into neighbours.
    grid = np.full((grid_h, grid_w), np.nan, np.float32)
    grid[gy, gx] = values

    out = np.array(Image.fromarray(np.nan_to_num(grid)).resize(
        (thumb.shape[1], thumb.shape[0]), Image.BILINEAR))
    return out

Overlay a gene

import matplotlib.pyplot as plt

# Load the run and find the output column for the gene of interest.
data = load_prediction("/data/output/tcga_coad.zarr")
gene = "ENSG00000198851"  # CD3E — T-cell infiltrate
idx = data["gene_names"].index(gene)

# Rasterise that gene's per-tile predictions, resized to the thumbnail.
heat = tile_heatmap(
    data["outputs"][:, idx],
    data["coords"],
    data["metadata"],
    data["thumbnail"],
)

# Draw the thumbnail first, then the semi-transparent heatmap on top of it.
plt.imshow(data["thumbnail"])
plt.imshow(heat, cmap="inferno", alpha=0.55)
plt.axis("off")
plt.colorbar(label="Predicted expression")
plt.show()