Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Cleaning #689

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
88 changes: 25 additions & 63 deletions clinicadl/API/complicated_case.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,41 +2,26 @@

import torchio.transforms as transforms

from clinicadl.dataset.dataloader_config import DataLoaderConfig
from clinicadl.dataset.datasets.caps_dataset import CapsDataset
from clinicadl.dataset.datasets.concat import ConcatDataset
from clinicadl.dataset.preprocessing import (
PreprocessingCustom,
from clinicadl.data.dataloader import DataLoaderConfig
from clinicadl.data.datasets.caps_dataset import CapsDataset
from clinicadl.data.datasets.concat import ConcatDataset
from clinicadl.data.preprocessing import (
PreprocessingPET,
PreprocessingT1,
)
from clinicadl.dataset.readers.caps_reader import CapsReader
from clinicadl.experiment_manager.experiment_manager import ExperimentManager
from clinicadl.losses.config import CrossEntropyLossConfig
from clinicadl.losses.factory import get_loss_function
from clinicadl.model.clinicadl_model import ClinicaDLModel
from clinicadl.networks.config import ImplementedNetworks
from clinicadl.networks.factory import (
ConvEncoderOptions,
create_network_config,
get_network_from_config,
)
from clinicadl.optimization.optimizer.config import AdamConfig, OptimizerConfig
from clinicadl.optimization.optimizer.factory import get_optimizer
from clinicadl.predictor.predictor import Predictor
from clinicadl.splitter.kfold import KFolder
from clinicadl.splitter.split import get_single_split, split_tsv
from clinicadl.networks.config.resnet import ResNetConfig
from clinicadl.optim.optimizers.config import AdamConfig
from clinicadl.splitter import KFold, make_kfold, make_split
from clinicadl.trainer.trainer import Trainer
from clinicadl.transforms.extraction import ROI, BaseExtraction, Image, Patch, Slice
from clinicadl.transforms.extraction import Extraction, Image, Patch, Slice
from clinicadl.transforms.transforms import Transforms

# Create the Maps Manager / Read/write manager /
maps_path = Path("/")
manager = ExperimentManager(
maps_path, overwrite=False
) # a ajouter dans le manager: mlflow/ profiler/ etc ...

caps_directory = Path("caps_directory") # output of clinica pipelines
caps_directory = Path(
"/Users/camille.brianceau/aramis/CLINICADL/caps"
) # output of clinica pipelines

sub_ses_t1 = Path("/Users/camille.brianceau/aramis/CLINICADL/caps/subjects_t1.tsv")
preprocessing_t1 = PreprocessingT1()
Expand All @@ -60,7 +45,7 @@
sub_ses_pet_45 = Path(
"/Users/camille.brianceau/aramis/CLINICADL/caps/subjects_pet_18FAV45.tsv"
)
preprocessing_pet_45 = PreprocessingPET(tracer="18FAV45", suvr_reference_region="pons2")
preprocessing_pet_45 = PreprocessingPET(tracer="18FAV45", suvr_reference_region="pons2") # type: ignore

dataset_pet_image = CapsDataset(
caps_directory=caps_directory,
Expand All @@ -79,47 +64,27 @@
) # 3 train.tsv en entrée qu'il faut concat et pareil pour les transforms à faire attention


config_file = Path("config_file")
trainer = Trainer.from_json(config_file=config_file, manager=manager)

# CAS CROSS-VALIDATION
splitter = KFolder(caps_dataset=dataset_multi_modality_multi_extract, manager=manager)
split_dir = splitter.make_splits(
n_splits=3, output_dir=Path(""), subset_name="validation", stratification=""
) # Optional data tsv and output_dir

dataloader_config = DataLoaderConfig(n_procs=3, batch_size=10)
split_dir = make_split(sub_ses_t1, n_test=0.2) # Optional data tsv and output_dir
fold_dir = make_kfold(split_dir / "train.tsv", n_splits=2)

splitter = KFold(fold_dir)

# CAS 1

# Prérequis : déjà avoir des fichiers avec les listes train et validation
split_dir = make_kfold(
"dataset.tsv"
) # lit dataset.tsv => fait le kfold => ecrit la sortie dans split_dir
splitter = KFolder(
dataset_multi_modality, split_dir
) # c'est plutôt un iterable de dataloader
maps_path = Path("/")
manager = ExperimentManager(maps_path, overwrite=False)

# CAS 2
splitter = KFolder(caps_dataset=dataset_t1_image)
splitter.make_splits(n_splits=3)
splitter.write(split_dir)
config_file = Path("config_file")
trainer = Trainer.from_json(config_file=config_file, manager=manager)

# or
splitter = KFolder(caps_dataset=dataset_t1_image)
splitter.read(split_dir)

for split in splitter.get_splits(splits=(0, 3, 4), dataloader_config=dataloader_config):
# bien définir ce qu'il y a dans l'objet split
for split in splitter.get_splits(dataset=dataset_t1_image):
train_loader = split.build_train_loader(batch_size=2)
val_loader = split.build_val_loader(DataLoaderConfig())

network_config = create_network_config(ImplementedNetworks.CNN)(
in_shape=[2, 2, 2],
num_outputs=1,
conv_args=ConvEncoderOptions(channels=[3, 2, 2]),
)
model = ClinicaDLModelClassif.from_config(
network_config=network_config,
model = ClinicaDLModel.from_config(
network_config=ResNetConfig(num_outputs=1, spatial_dims=1, in_channels=1),
loss_config=CrossEntropyLossConfig(),
optimizer_config=AdamConfig(),
)
Expand All @@ -133,9 +98,6 @@
dataset_test = CapsDataset(
caps_directory=caps_directory,
preprocessing=preprocessing_t1,
sub_ses_tsv=Path("test.tsv"), # test only on data from the first dataset
data=Path("test.tsv"), # test only on data from the first dataset
transforms=transforms_image,
)

predictor = Predictor(model=model, manager=manager)
predictor.predict(dataset_test=dataset_test, split_number=2)
27 changes: 12 additions & 15 deletions clinicadl/API/cross_val.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,10 @@
from pathlib import Path

from clinicadl.dataset.datasets.caps_dataset import CapsDataset
from clinicadl.experiment_manager.experiment_manager import ExperimentManager
from clinicadl.predictor.predictor import Predictor
from clinicadl.splitter.new_splitter.dataloader import DataLoaderConfig
from clinicadl.splitter.new_splitter.splitter.kfold import KFold
from clinicadl.trainer.trainer import Trainer
from clinicadl.data.dataloader import DataLoaderConfig
from clinicadl.data.datasets import CapsDataset
from clinicadl.experiment_manager import ExperimentManager
from clinicadl.splitter import KFold, make_kfold, make_split
from clinicadl.trainer import Trainer

# SIMPLE EXPERIMENT WITH A CAPS ALREADY EXISTING

Expand All @@ -19,20 +18,18 @@
config_file=config_file, manager=manager
) # gpu, amp, fsdp, seed

splitter = KFold(dataset=dataset_t1_image)
splitter.make_splits(n_splits=3)
split_dir = Path("")
splitter.write(split_dir)
split_dir = make_split(
dataset_t1_image.df, n_test=0.2, subset_name="validation", output_dir="test"
) # Optional data tsv and output_dir
fold_dir = make_kfold(split_dir / "train.tsv", n_splits=2)

splitter = KFold(fold_dir)

splitter.read(split_dir)

# define the needed parameters for the dataloader
dataloader_config = DataLoaderConfig(num_workers=3, batch_size=10)


for split in splitter.get_splits(splits=(0, 3, 4)):
print(split)
for split in splitter.get_splits(dataset=dataset_t1_image):
split.build_train_loader(dataloader_config)
split.build_val_loader(num_workers=3, batch_size=10)

print(split)
61 changes: 26 additions & 35 deletions clinicadl/API/dataset_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,26 +2,25 @@

import torchio.transforms as transforms

from clinicadl.dataset.datasets.caps_dataset import CapsDataset
from clinicadl.dataset.datasets.concat import ConcatDataset
from clinicadl.dataset.preprocessing import (
from clinicadl.data.datasets import CapsDataset, ConcatDataset
from clinicadl.data.preprocessing import (
BasePreprocessing,
PreprocessingFlair,
PreprocessingPET,
PreprocessingT1,
)
from clinicadl.experiment_manager.experiment_manager import ExperimentManager
from clinicadl.data.preprocessing.pet import SUVRReferenceRegions, Tracer
from clinicadl.experiment_manager import ExperimentManager
from clinicadl.losses.config import CrossEntropyLossConfig
from clinicadl.model.clinicadl_model import ClinicaDLModel
from clinicadl.networks.factory import (
ConvEncoderOptions,
create_network_config,
get_network_from_config,
)
from clinicadl.splitter.kfold import KFolder
from clinicadl.splitter.split import get_single_split, split_tsv
from clinicadl.transforms.extraction import ROI, Image, Patch, Slice
from clinicadl.transforms.transforms import Transforms
from clinicadl.splitter import KFold, make_kfold, make_split
from clinicadl.transforms import Transforms
from clinicadl.transforms.extraction import Image, Patch, Slice

sub_ses_t1 = Path("/Users/camille.brianceau/aramis/CLINICADL/caps/subjects_t1.tsv")
sub_ses_pet_45 = Path(
Expand All @@ -38,8 +37,12 @@
"/Users/camille.brianceau/aramis/CLINICADL/caps"
) # output of clinica pipelines

preprocessing_pet_45 = PreprocessingPET(tracer="18FAV45", suvr_reference_region="pons2")
preprocessing_pet_11 = PreprocessingPET(tracer="11CPIB", suvr_reference_region="pons2")
preprocessing_pet_45 = PreprocessingPET(
tracer=Tracer.FAV45, suvr_reference_region=SUVRReferenceRegions.PONS2
)
preprocessing_pet_11 = PreprocessingPET(
tracer=Tracer.CPIB, suvr_reference_region=SUVRReferenceRegions.PONS2
)

preprocessing_t1 = PreprocessingT1()
preprocessing_flair = PreprocessingFlair()
Expand All @@ -55,18 +58,6 @@

transforms_slice = Transforms(extraction=Slice())

transforms_roi = Transforms(
object_augmentation=[transforms.Ghosting(2, 1, 0.1, 0.1)],
object_transforms=[transforms.RandomMotion()],
extraction=ROI(
roi_list=["leftHippocampusBox", "rightHippocampusBox"],
roi_mask_location=Path(
"/Users/camille.brianceau/aramis/CLINICADL/caps/masks/tpl-MNI152NLin2009cSym"
),
roi_crop_input=True,
),
)

transforms_image = Transforms(
image_augmentation=[transforms.RandomMotion()],
extraction=Image(),
Expand Down Expand Up @@ -96,25 +87,25 @@
)


print("Pet 11 and ROI ")
print("Pet 11 and Image ")

dataset_pet_11_roi = CapsDataset(
dataset_pet_11_image = CapsDataset(
caps_directory=caps_directory,
data=sub_ses_pet_11,
preprocessing=preprocessing_pet_11,
transforms=transforms_roi,
transforms=transforms_image,
)
dataset_pet_11_roi.prepare_data(
dataset_pet_11_image.prepare_data(
n_proc=2
) # to extract the tensor of the PET file this time

print(dataset_pet_11_roi)
print(dataset_pet_11_roi.__len__())
print(dataset_pet_11_roi._get_meta_data(0))
print(dataset_pet_11_roi._get_meta_data(1))
# print(dataset_pet_11_roi._get_full_image())
print(dataset_pet_11_roi.__getitem__(1).elem_idx)
print(dataset_pet_11_roi.elem_per_image)
print(dataset_pet_11_image)
print(dataset_pet_11_image.__len__())
print(dataset_pet_11_image._get_meta_data(0))
print(dataset_pet_11_image._get_meta_data(1))
# print(dataset_pet_11_image._get_full_image())
print(dataset_pet_11_image.__getitem__(1).elem_idx)
print(dataset_pet_11_image.elem_per_image)


print("T1 and image ")
Expand Down Expand Up @@ -161,7 +152,7 @@

lity_multi_extract = ConcatDataset(
[
dataset_t1,
dataset_pet,
dataset_t1_image,
dataset_pet_11_image,
]
) # 3 train.tsv en entrée qu'il faut concat et pareil pour les transforms à faire attention
Loading
Loading