diff --git a/brainles_preprocessing/brain_extraction/__init__.py b/brainles_preprocessing/brain_extraction/__init__.py index 92a981b..8c97b5c 100644 --- a/brainles_preprocessing/brain_extraction/__init__.py +++ b/brainles_preprocessing/brain_extraction/__init__.py @@ -1 +1,6 @@ -from .brain_extractor import HDBetExtractor +from .hd_bet import HDBetExtractor + +try: + from .synthstrip import SynthStripExtractor +except ImportError: + pass diff --git a/brainles_preprocessing/brain_extraction/brain_extractor.py b/brainles_preprocessing/brain_extraction/brain_extractor.py index 03a716b..908982a 100644 --- a/brainles_preprocessing/brain_extraction/brain_extractor.py +++ b/brainles_preprocessing/brain_extraction/brain_extractor.py @@ -1,20 +1,12 @@ # TODO add typing and docs -import shutil -from abc import ABC, abstractmethod +from abc import abstractmethod, ABC from pathlib import Path -from typing import Optional, Union -from enum import Enum +from typing import Union from auxiliary.io import read_image, write_image -from brainles_hd_bet import run_hd_bet -class Mode(Enum): - FAST = "fast" - ACCURATE = "accurate" - - -class BrainExtractor: +class BrainExtractor(ABC): @abstractmethod def extract( self, @@ -75,67 +67,3 @@ def apply_mask( ) except Exception as e: raise RuntimeError(f"Error writing output file: {e}") from e - - -class HDBetExtractor(BrainExtractor): - def extract( - self, - input_image_path: Union[str, Path], - masked_image_path: Union[str, Path], - brain_mask_path: Union[str, Path], - mode: Union[str, Mode] = Mode.ACCURATE, - device: Optional[Union[int, str]] = 0, - do_tta: bool = True, - **kwargs, - ) -> None: - # GPU + accurate + TTA - """ - Skull-strips images with HD-BET and generates a skull-stripped file and mask. - - Args: - input_image_path (str or Path): Path to the input image. - masked_image_path (str or Path): Path where the brain-extracted image will be saved. - brain_mask_path (str or Path): Path where the brain mask will be saved. - mode (str or Mode): Extraction mode ('fast' or 'accurate'). - device (str or int): Device to use for computation (e.g., 0 for GPU 0, 'cpu' for CPU). - do_tta (bool): whether to do test time data augmentation by mirroring along all axes. - """ - - # Ensure mode is a Mode enum instance - if isinstance(mode, str): - try: - mode_enum = Mode(mode.lower()) - except ValueError: - raise ValueError(f"'{mode}' is not a valid Mode.") - elif isinstance(mode, Mode): - mode_enum = mode - else: - raise TypeError("Mode must be a string or a Mode enum instance.") - - # Run HD-BET - run_hd_bet( - mri_fnames=[str(input_image_path)], - output_fnames=[str(masked_image_path)], - mode=mode_enum.value, - device=device, - # TODO consider postprocessing - postprocess=False, - do_tta=do_tta, - keep_mask=True, - overwrite=True, - ) - - # Construct the path to the generated mask - masked_image_path = Path(masked_image_path) - hdbet_mask_path = masked_image_path.with_name( - masked_image_path.name.replace(".nii.gz", "_mask.nii.gz") - ) - - if hdbet_mask_path.resolve() != Path(brain_mask_path).resolve(): - try: - shutil.copyfile( - src=str(hdbet_mask_path), - dst=str(brain_mask_path), - ) - except Exception as e: - raise RuntimeError(f"Error copying mask file: {e}") from e diff --git a/brainles_preprocessing/brain_extraction/hd_bet.py b/brainles_preprocessing/brain_extraction/hd_bet.py new file mode 100644 index 0000000..cd631ea --- /dev/null +++ b/brainles_preprocessing/brain_extraction/hd_bet.py @@ -0,0 +1,76 @@ +from pathlib import Path +from typing import Optional, Union +import shutil +from enum import Enum + +from brainles_preprocessing.brain_extraction.brain_extractor import BrainExtractor +from brainles_hd_bet import run_hd_bet + + +class Mode(Enum): + FAST = "fast" + ACCURATE = "accurate" + + +class HDBetExtractor(BrainExtractor): + def extract( + self, + input_image_path: Union[str, Path], + masked_image_path: Union[str, Path], + brain_mask_path: Union[str, Path], + mode: Union[str, Mode] = Mode.ACCURATE, + device: Optional[Union[int, str]] = 0, + do_tta: bool = True, + **kwargs, + ) -> None: + # GPU + accurate + TTA + """ + Skull-strips images with HD-BET and generates a skull-stripped file and mask. + + Args: + input_image_path (str or Path): Path to the input image. + masked_image_path (str or Path): Path where the brain-extracted image will be saved. + brain_mask_path (str or Path): Path where the brain mask will be saved. + mode (str or Mode): Extraction mode ('fast' or 'accurate'). + device (str or int): Device to use for computation (e.g., 0 for GPU 0, 'cpu' for CPU). + do_tta (bool): whether to do test time data augmentation by mirroring along all axes. + """ + + # Ensure mode is a Mode enum instance + if isinstance(mode, str): + try: + mode_enum = Mode(mode.lower()) + except ValueError: + raise ValueError(f"'{mode}' is not a valid Mode.") + elif isinstance(mode, Mode): + mode_enum = mode + else: + raise TypeError("Mode must be a string or a Mode enum instance.") + + # Run HD-BET + run_hd_bet( + mri_fnames=[str(input_image_path)], + output_fnames=[str(masked_image_path)], + mode=mode_enum.value, + device=device, + # TODO consider postprocessing + postprocess=False, + do_tta=do_tta, + keep_mask=True, + overwrite=True, + ) + + # Construct the path to the generated mask + masked_image_path = Path(masked_image_path) + hdbet_mask_path = masked_image_path.with_name( + masked_image_path.name.replace(".nii.gz", "_mask.nii.gz") + ) + + if hdbet_mask_path.resolve() != Path(brain_mask_path).resolve(): + try: + shutil.copyfile( + src=str(hdbet_mask_path), + dst=str(brain_mask_path), + ) + except Exception as e: + raise RuntimeError(f"Error copying mask file: {e}") from e diff --git a/brainles_preprocessing/preprocessor/preprocessor.py b/brainles_preprocessing/preprocessor/preprocessor.py index 2983f39..797fafe 100644 --- a/brainles_preprocessing/preprocessor/preprocessor.py +++ b/brainles_preprocessing/preprocessor/preprocessor.py @@ -11,10 +11,8 @@ from loguru import logger -from brainles_preprocessing.brain_extraction.brain_extractor import ( - BrainExtractor, - HDBetExtractor, -) +from brainles_preprocessing.brain_extraction.brain_extractor import BrainExtractor +from brainles_preprocessing.brain_extraction.hd_bet import HDBetExtractor from brainles_preprocessing.constants import PreprocessorSteps from brainles_preprocessing.defacing import Defacer, QuickshearDefacer from brainles_preprocessing.modality import CenterModality, Modality