Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
67 changes: 61 additions & 6 deletions src/datasets/features/nifti.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import base64
import os
import uuid
from dataclasses import dataclass, field
from io import BytesIO
from pathlib import Path
from typing import TYPE_CHECKING, Any, ClassVar, Dict, Optional, Union

Expand All @@ -18,6 +19,62 @@

from .features import FeatureType

if config.NIBABEL_AVAILABLE:
import nibabel as nib

class Nifti1ImageWrapper(nib.nifti1.Nifti1Image):
"""
A wrapper around nibabel's Nifti1Image to customize its representation.
"""

def __init__(self, nifti_image: nib.nifti1.Nifti1Image):
super().__init__(
dataobj=nifti_image.get_fdata(),
affine=nifti_image.affine,
header=nifti_image.header,
extra=nifti_image.extra,
file_map=nifti_image.file_map,
dtype=nifti_image.get_data_dtype(),
)
self.nifti_image = nifti_image

def _repr_html_(self):
bytes_ = self.nifti_image.to_bytes()
b64 = base64.b64encode(bytes_).decode("utf-8")

self.nifti_data_url = f"data:application/octet-stream;base64,{b64}"
viewer_id = f"papaya-{uuid.uuid4().hex[:8]}"

html = f"""
<div id="{viewer_id}" style="width: 100%; height: 800px;"></div>
<link rel="stylesheet" type="text/css" href="https://cdn.jsdelivr.net/npm/[email protected]/release/current/standard/papaya.css" />
<script src="https://cdn.jsdelivr.net/npm/[email protected]/release/current/standard/papaya.js"></script>
<script type="text/javascript">
(function() {{
// Wait for Papaya to load
function initPapaya() {{
if (typeof papaya === 'undefined' || typeof papaya.Container === 'undefined') {{
setTimeout(initPapaya, 100);
return;
}}
// Papaya loaded - manually initialize
var params = {{}};
params["images"] = ["{self.nifti_data_url}"];
params["kioskMode"] = false;
params["showControls"] = true;
// Manual initialization
papaya.Container.startPapaya();
papaya.Container.addViewer("{viewer_id}", params);
}}
initPapaya();
}})();
</script>
"""
return html


@dataclass
class Nifti:
Expand Down Expand Up @@ -106,7 +163,7 @@ def encode_example(self, value: Union[str, bytes, bytearray, dict, "nib.Nifti1Im
f"A nifti sample should be a string, bytes, Path, nibabel image, or dict, but got {type(value)}."
)

def decode_example(self, value: dict, token_per_repo_id=None) -> "nib.nifti1.Nifti1Image":
def decode_example(self, value: dict, token_per_repo_id=None) -> "Nifti1ImageWrapper":
"""Decode example NIfTI file into nibabel image object.
Args:
Expand Down Expand Up @@ -165,11 +222,9 @@ def decode_example(self, value: dict, token_per_repo_id=None) -> "nib.nifti1.Nif
): # gzip magic number, see https://stackoverflow.com/a/76055284/9534390 or "Magic number" on https://en.wikipedia.org/wiki/Gzip
bytes_ = gzip.decompress(bytes_)

bio = BytesIO(bytes_)
fh = nib.FileHolder(fileobj=bio)
nifti = nib.Nifti1Image.from_file_map({"header": fh, "image": fh})
nifti = nib.Nifti1Image.from_bytes(bytes_)

return nifti
return Nifti1ImageWrapper(nifti)

def embed_storage(self, storage: pa.StructArray, token_per_repo_id=None) -> pa.StructArray:
"""Embed NifTI files into the Arrow array.
Expand Down
Loading