diff --git a/.gitignore b/.gitignore index 7637172..814dca6 100644 --- a/.gitignore +++ b/.gitignore @@ -42,5 +42,8 @@ venv/ .vscode/ # Data & models -/data/ +data/ models/ + +# Exclude 'imnet_sample' from being ignored +!data/imnet_sample/ \ No newline at end of file diff --git a/Dockerfile b/Dockerfile new file mode 100644 index 0000000..02b8121 --- /dev/null +++ b/Dockerfile @@ -0,0 +1,55 @@ +# Use an official lightweight image as a base +FROM ubuntu:22.04 + +# Set environment variables to avoid interactive prompts during installation +ENV DEBIAN_FRONTEND=noninteractive + +# Install system dependencies +RUN apt-get update && apt-get install -y \ + curl \ + wget \ + bzip2 \ + ca-certificates \ + libglib2.0-0 \ + libxext6 \ + libsm6 \ + libxrender1 \ + git \ + && apt-get clean \ + && rm -rf /var/lib/apt/lists/* + +# Install micromamba +RUN curl -L https://micro.mamba.pm/api/micromamba/linux-64/latest | tar -xvj -C /usr/local/bin --strip-components=1 bin/micromamba + +# Set up working directory +WORKDIR /app + +# Copy current directory contents into the container +COPY . . + +# Set environment name +ENV ENV_NAME=cvdm + +ENV MAMBA_ROOT_PREFIX=/opt/micromamba +ENV PATH=$MAMBA_ROOT_PREFIX/envs/$ENV_NAME/bin:$PATH +ENV PYTHONPATH=$MAMBA_ROOT_PREFIX/envs/$ENV_NAME/lib/python3.10/site-packages:$PYTHONPATH + +RUN apt-get update && apt-get install ffmpeg libsm6 libxext6 -y +# Set up micromamba environment, install dependencies and pip packages +RUN micromamba create -n $ENV_NAME -y && \ + micromamba install -n $ENV_NAME \ + tensorflow-gpu==2.15.0 \ + keras==2.15.* \ + matplotlib==3.8.0 \ + tqdm==4.65.0 \ + scikit-learn==1.4.2 \ + scikit-image==0.22.0 \ + einops==0.7.0 \ + neptune==1.10.2 -y && \ + /usr/local/bin/micromamba run -n $ENV_NAME pip3 install opencv-python==4.9.0.80 \ + tensorflow-addons==0.23.0 \ + cupy-cuda12x==13.3.0 && \ + /usr/local/bin/micromamba run -n $ENV_NAME python -m pip install . + +# Default command when the container starts (optional) +CMD ["/bin/bash"] diff --git a/README.md b/README.md index 662d91a..647edcf 100644 --- a/README.md +++ b/README.md @@ -1,6 +1,6 @@ # Conditional Variational Diffusion Models -This code implements the Conditional Variational Diffusion Models as described [in the paper](https://arxiv.org/abs/2312.02246). +Diffusion models have become popular for their ability to solve complex problems where hidden information needs to be estimated from observed data. Among others, their use is popular in image generation tasks. These models rely on a key hyperparameter of the variance schedule that impacts how well they learn, but recent work shows that allowing the model to automatically learn this hyperparameter can improve both performance and efficiency. Our CVDM package implements Conditional Variational Diffusion Models (CVDM) as described [in the paper](https://arxiv.org/abs/2312.02246) that build on this idea, with the addition of [Zero-Mean Diffusion (ZMD)](https://arxiv.org/pdf/2406.04388), a technique that enhances performance in certain imaging tasks, aiming to make these approaches more accessible to researchers. ## Where to get the data? @@ -14,27 +14,34 @@ It is assumed that for: - BioSR phase task, data can be found in the directory specified as dataset_path in configs/biosr_phase.yaml, in one file, y.npy (ground truth). Input to the model will be generated based on the ground truth. - ImageNet super-resolution task, data can be found in the directory specified as dataset_path in configs/imagenet_sr.yaml as a collection of JPEG files. Input to the model will be generated based on the ground truth. - ImageNet phase task, data can be found in the directory specified as dataset_path in configs/imagenet_phase.yaml as a collection of JPEG files. Input to the model will be generated based on the ground truth. -- HCOCO phase evaluation task, data can be found in the directory specified as dataset_path in configs/hcoco_phase_eval.yaml as a collection of JPEG files. Input to the model will be generated based on the ground truth. +- HCOCO phase evaluation task, data can be found in the directory specified as dataset_path in configs/hcoco_phase.yaml as a collection of JPEG files. Input to the model will be generated based on the ground truth. ## How to prepare environment? -Run the following code: + +We provide a Dockerfile to prepare the environment. Run the following code in the root of this repository: +``` +docker build -t my-image . +docker run -it my-image +``` +Inside the image run: ``` -conda create -n myenv python=3.10 -conda activate myenv -pip install -r requirements.txt -pip install -e . +eval "$(micromamba shell hook --shell bash)" +micromamba activate cvdm ``` +If you encounter issues with cupy installation (required only for the phase tasks) such as [these](https://github.com/cupy/cupy/issues/8466), you can modify the `cvdm/utils/phase_utils.py` to use pure numpy. + ## How to run the training code? -1. Download the data. -1. Modify the config in `configs/` directory with the path to the data you want to use and the directory for outputs. -2. Run the code from the root directory: `python scripts/train.py --config-path $PATH_TO_CONFIG --neptune-token $NEPTUNE_TOKEN`. +1. Download the data or use the sample data available in the data/ directory. The sample data is a fraction of the ImageNet dataset and can be used with configs `imagenet_sr_sample.yaml` or `imagenet_phase_sample.yaml`. You can also use your own data as long as it is in ".npy" format. To do so, use the task type "other". +2. Modify the config in `configs/` directory with the path to the data you want to use and the directory for outputs. For the description of each parameter, check the documentation in `cvdm/configs/` files. +3. Run the code from the root directory: `python scripts/train.py --config-path $PATH_TO_CONFIG --neptune-token $NEPTUNE_TOKEN`. `--neptune-token` argument is optional. -## How to run the training code? + +## How to run the evaluation code? 1. Download the data. 1. Modify the config in `configs/` directory with the path to the data you want to use and the directory for outputs. @@ -42,6 +49,10 @@ pip install -e . `--neptune-token` argument is optional. +## How to contribute? + +To contribute to the software or seek support, please leave an issue or pull request. + ## License This repository is released under the MIT License (refer to the LICENSE file for details). diff --git a/configs/biosr.yaml b/configs/biosr.yaml index 1f9fe93..2a43a69 100644 --- a/configs/biosr.yaml +++ b/configs/biosr.yaml @@ -4,7 +4,10 @@ model: noise_model_type: "unet" alpha: 0.001 load_weights: null + load_mu_weights: null snr_expansion_n: 1 + zmd: False + diff_inp: False training: lr: 0.0001 @@ -12,7 +15,7 @@ training: eval: output_path: "outputs/biosr" - generation_timesteps: 1000 + generation_timesteps: 200 checkpoint_freq: 1000 log_freq: 10 image_freq: 100 @@ -26,5 +29,5 @@ data: im_size: 256 neptune: - name: "Virtual_Stain" + name: "CVDM" project: "mlid/test" \ No newline at end of file diff --git a/configs/biosr_phase.yaml b/configs/biosr_phase.yaml index 1d23184..b4f58f4 100644 --- a/configs/biosr_phase.yaml +++ b/configs/biosr_phase.yaml @@ -4,7 +4,10 @@ model: noise_model_type: "unet" alpha: 0.001 load_weights: null + load_mu_weights: null snr_expansion_n: 1 + zmd: False + diff_inp: True training: lr: 0.0001 @@ -12,7 +15,7 @@ training: eval: output_path: "outputs/biosr" - generation_timesteps: 1000 + generation_timesteps: 200 checkpoint_freq: 1000 log_freq: 10 image_freq: 100 @@ -26,5 +29,5 @@ data: im_size: 256 neptune: - name: "Virtual_Stain" + name: "CVDM" project: "mlid/test" \ No newline at end of file diff --git a/configs/hcoco_phase_eval.yaml b/configs/hcoco_phase.yaml similarity index 78% rename from configs/hcoco_phase_eval.yaml rename to configs/hcoco_phase.yaml index 2d0179f..866c923 100644 --- a/configs/hcoco_phase_eval.yaml +++ b/configs/hcoco_phase.yaml @@ -4,7 +4,10 @@ model: noise_model_type: "unet" alpha: 0.001 load_weights: null + load_mu_weights: null snr_expansion_n: 1 + zmd: False + diff_inp: False training: lr: 0.0001 @@ -12,7 +15,7 @@ training: eval: output_path: "outputs/hcoco" - generation_timesteps: 1000 + generation_timesteps: 200 checkpoint_freq: 1000 log_freq: 10 image_freq: 100 @@ -22,9 +25,9 @@ eval: data: dataset_path: "/bigdata/casus/MLID/maria/hcoco_sample" n_samples: 100 - batch_size: 1 + batch_size: 2 im_size: 256 neptune: - name: "Virtual_Stain" + name: "CVDM" project: "mlid/test" \ No newline at end of file diff --git a/configs/imagenet_phase.yaml b/configs/imagenet_phase.yaml index 35dbdb0..95e0119 100644 --- a/configs/imagenet_phase.yaml +++ b/configs/imagenet_phase.yaml @@ -4,7 +4,10 @@ model: noise_model_type: "unet" alpha: 0.001 load_weights: null + load_mu_weights: null snr_expansion_n: 1 + zmd: False + diff_inp: False training: lr: 0.0001 @@ -12,7 +15,7 @@ training: eval: output_path: "outputs/imagenet" - generation_timesteps: 1000 + generation_timesteps: 200 checkpoint_freq: 1000 log_freq: 10 image_freq: 100 @@ -20,11 +23,11 @@ eval: val_len: 100 data: - dataset_path: "/bigdata/casus/MLID/maria/imagenet_sample" + dataset_path: "/bigdata/imnet" n_samples: 100 - batch_size: 1 + batch_size: 2 im_size: 256 neptune: - name: "Virtual_Stain" + name: "CVDM" project: "mlid/test" \ No newline at end of file diff --git a/configs/imagenet_phase_sample.yaml b/configs/imagenet_phase_sample.yaml new file mode 100644 index 0000000..1bbfa20 --- /dev/null +++ b/configs/imagenet_phase_sample.yaml @@ -0,0 +1,33 @@ +task: "imagenet_phase" + +model: + noise_model_type: "unet" + alpha: 0.001 + load_weights: null + load_mu_weights: null + snr_expansion_n: 1 + zmd: False + diff_inp: False + +training: + lr: 0.0001 + epochs: 100 + +eval: + output_path: "outputs/imagenet" + generation_timesteps: 200 + checkpoint_freq: 1000 + log_freq: 10 + image_freq: 100 + val_freq: 200 + val_len: 10 + +data: + dataset_path: "data/imnet_sample" + n_samples: 100 + batch_size: 2 + im_size: 256 + +neptune: + name: "CVDM" + project: "mlid/test" \ No newline at end of file diff --git a/configs/imagenet_sr.yaml b/configs/imagenet_sr.yaml index 9b849b0..c943b8a 100644 --- a/configs/imagenet_sr.yaml +++ b/configs/imagenet_sr.yaml @@ -3,7 +3,10 @@ model: noise_model_type: "unet" alpha: 0.001 load_weights: null + load_mu_weights: null snr_expansion_n: 1 + zmd: False + diff_inp: True training: lr: 0.0001 @@ -11,7 +14,7 @@ training: eval: output_path: "outputs/imagenet" - generation_timesteps: 1000 + generation_timesteps: 200 checkpoint_freq: 1000 log_freq: 10 image_freq: 100 @@ -19,11 +22,11 @@ eval: val_len: 100 data: - dataset_path: "/bigdata/casus/MLID/maria/imagenet_sample" + dataset_path: "/bigdata/imnet" n_samples: 100 batch_size: 2 im_size: 256 neptune: - name: "Virtual_Stain" + name: "CVDM" project: "mlid/test" \ No newline at end of file diff --git a/configs/imagenet_sr_sample.yaml b/configs/imagenet_sr_sample.yaml new file mode 100644 index 0000000..b8dfc66 --- /dev/null +++ b/configs/imagenet_sr_sample.yaml @@ -0,0 +1,32 @@ +task: "imagenet_sr" +model: + noise_model_type: "unet" + alpha: 0.001 + load_weights: null + load_mu_weights: null + snr_expansion_n: 1 + zmd: False + diff_inp: True + +training: + lr: 0.0001 + epochs: 100 + +eval: + output_path: "outputs/imagenet" + generation_timesteps: 200 + checkpoint_freq: 1000 + log_freq: 10 + image_freq: 1000 + val_freq: 2000 + val_len: 10 + +data: + dataset_path: "data/imnet_sample" + n_samples: 100 + batch_size: 2 + im_size: 256 + +neptune: + name: "CVDM" + project: "mlid/test" \ No newline at end of file diff --git a/cvdm/architectures/components/residual_block.py b/cvdm/architectures/components/residual_block.py index 6b39216..1239ccd 100644 --- a/cvdm/architectures/components/residual_block.py +++ b/cvdm/architectures/components/residual_block.py @@ -6,10 +6,6 @@ from tensorflow.keras.layers import Concatenate, Conv2D, Dropout from tensorflow_addons.layers import GroupNormalization -from cvdm.architectures.components.conditional_instance_normalization import ( - ConditionalInstanceNormalization, -) - def resblock( x: tf.Tensor, noise_embedding: tf.Tensor, n_out_channels: int, dropout: float = 0.0 diff --git a/cvdm/architectures/sr3.py b/cvdm/architectures/sr3.py index 79b7029..44d8b95 100644 --- a/cvdm/architectures/sr3.py +++ b/cvdm/architectures/sr3.py @@ -1,15 +1,17 @@ from typing import Tuple + import tensorflow as tf +from tensorflow.keras.activations import swish +from tensorflow.keras.layers import AveragePooling2D, Conv2D, Input, UpSampling2D +from tensorflow.keras.models import Model +from tensorflow_addons.layers import GroupNormalization + from cvdm.architectures.components.attention_block import attention_block from cvdm.architectures.components.deep_residual_block import ( deep_resblock, up_deep_resblock, ) from cvdm.architectures.components.residual_block import resblock, up_resblock -from tensorflow.keras.models import Model -from tensorflow.keras.activations import swish -from tensorflow.keras.layers import AveragePooling2D, Conv2D, Input, UpSampling2D -from tensorflow_addons.layers import GroupNormalization def upsample(x: tf.Tensor, use_conv: bool = False) -> tf.Tensor: diff --git a/cvdm/architectures/unet.py b/cvdm/architectures/unet.py index 7f67c50..01a04a3 100644 --- a/cvdm/architectures/unet.py +++ b/cvdm/architectures/unet.py @@ -2,17 +2,15 @@ import numpy as np import tensorflow as tf -from tensorflow.keras.models import Model from tensorflow.keras.layers import ( Add, - Attention, Concatenate, Conv2D, Conv2DTranspose, Input, - Lambda, MaxPooling2D, ) +from tensorflow.keras.models import Model from tensorflow_addons.layers import InstanceNormalization diff --git a/cvdm/configs/data_config.py b/cvdm/configs/data_config.py index 8a12501..56684d0 100644 --- a/cvdm/configs/data_config.py +++ b/cvdm/configs/data_config.py @@ -1,13 +1,20 @@ from dataclasses import dataclass -from typing import Optional, Tuple @dataclass class DataConfig: + """ + Configuration for the dataset. + + Attributes: + dataset_path (str): Path to the dataset directory. For NpyDataloader it is expected that it will contain x.npy and y.npy files. + For PhasePolychromeDataloader it should contain y.npy. For the rest of dataloaders, the directory should include .JPEG images. + n_samples (int): Number of samples to use from the dataset. + batch_size (int): Number of samples per batch during training. + im_size (int): The size of the patches of images (both height and width) to use. + """ + dataset_path: str n_samples: int batch_size: int im_size: int - - - diff --git a/cvdm/configs/eval_config.py b/cvdm/configs/eval_config.py index c5bf3e6..1c9e67d 100644 --- a/cvdm/configs/eval_config.py +++ b/cvdm/configs/eval_config.py @@ -3,6 +3,19 @@ @dataclass class EvalConfig: + """ + Configuration settings for the evaluation process during model training. + + Attributes: + generation_timesteps (int): Number of timesteps to use for image generation. + output_path (str): Directory where generated images and results will be saved. + image_freq (float): Frequency (in terms of steps) at which images are generated for inspection. + checkpoint_freq (int): Frequency (in steps) at which the model checkpoints are saved. + log_freq (int): Frequency (in steps) at which loss logs are recorded. + val_freq (int): Frequency (in steps) at which validation is performed. + val_len (int): Number of samples used for validation. + """ + generation_timesteps: int output_path: str image_freq: float diff --git a/cvdm/configs/model_config.py b/cvdm/configs/model_config.py index d961792..2e15953 100644 --- a/cvdm/configs/model_config.py +++ b/cvdm/configs/model_config.py @@ -3,7 +3,23 @@ @dataclass class ModelConfig: + """ + Configuration settings for the model architecture and initialization. + + Attributes: + noise_model_type (str): The type of noise model used (unet or sr3). + alpha (float): The scaling factor for L_gamma element of CVDM loss. + snr_expansion_n (int): The expansion factor for the Taylor expansion of SNR. + load_weights (str): Path to the pre-trained model weights to be loaded. + load_mu_weights (str): Path to the pre-trained weights for the mean (mu) model, if applicable. + zmd (bool): Indicates whether Zero-Mean Diffusion (ZMD) should be used in the model. + inp_diff (bool): Indicates whether the model should predict just the noise to remove in each diffusion step or the image with the noise already removed. + """ + noise_model_type: str alpha: float snr_expansion_n: int load_weights: str + load_mu_weights: str + zmd: bool + diff_inp: bool diff --git a/cvdm/configs/neptune_config.py b/cvdm/configs/neptune_config.py index 3184203..426eb24 100644 --- a/cvdm/configs/neptune_config.py +++ b/cvdm/configs/neptune_config.py @@ -3,5 +3,13 @@ @dataclass class NeptuneConfig: + """ + Configuration settings for logging and tracking experiments with Neptune. Can be omitted if Neptune should not be used. + + Attributes: + name (str): The name of te run to be logged in Neptune. + project (str): The Neptune project identifier where the experiment will be tracked, typically in the format 'workspace/project-name'. + """ + name: str project: str diff --git a/cvdm/configs/training_config.py b/cvdm/configs/training_config.py index 91594af..d5b6f24 100644 --- a/cvdm/configs/training_config.py +++ b/cvdm/configs/training_config.py @@ -1,8 +1,15 @@ from dataclasses import dataclass -from typing import Optional @dataclass class TrainingConfig: + """ + Configuration settings for the model training process. + + Attributes: + lr (float): Learning rate for the optimizer during training. + epochs (int): Number of complete passes through the training dataset. + """ + lr: float epochs: int diff --git a/cvdm/data/npy_dataloader.py b/cvdm/data/npy_dataloader.py index 6fcd26a..30abeaf 100644 --- a/cvdm/data/npy_dataloader.py +++ b/cvdm/data/npy_dataloader.py @@ -2,7 +2,7 @@ import numpy as np -from cvdm.utils.data_utils import center_crop, sample_norm_01 +from cvdm.utils.data_utils import center_crop class NpyDataloader: @@ -20,11 +20,14 @@ def __init__( def __len__(self) -> int: return self._n_samples + def get_channels(self) -> Tuple[int, int]: + return self._x.shape[-1], self._y.shape[-1] + def __getitem__(self, idx: int) -> Tuple[np.ndarray, np.ndarray]: x, y = self._x[idx], self._y[idx] - x = sample_norm_01(center_crop(x, crop_size=2000)) - y = sample_norm_01(center_crop(y, crop_size=2000)) + x = center_crop(x, crop_size=2000) + y = center_crop(y, crop_size=2000) if x.shape[0] > self._im_size or x.shape[1] > self._im_size: center_x = np.random.randint( self._im_size // 2, x.shape[1] - self._im_size // 2 diff --git a/cvdm/data/phase_2shot_dataloader.py b/cvdm/data/phase_2shot_dataloader.py index eca4348..431a264 100644 --- a/cvdm/data/phase_2shot_dataloader.py +++ b/cvdm/data/phase_2shot_dataloader.py @@ -1,10 +1,7 @@ -from typing import Iterator, List, Tuple +from typing import Iterator, Tuple import cv2 import numpy as np -from PIL import Image -from skimage.transform import resize -from sklearn.feature_extraction.image import extract_patches_2d from cvdm.utils.data_utils import read_and_patch_image_from_filename from cvdm.utils.phase_utils import FresnelPropagator diff --git a/cvdm/diffusion_models/joint_model.py b/cvdm/diffusion_models/joint_model.py index aa47c1b..2804b79 100644 --- a/cvdm/diffusion_models/joint_model.py +++ b/cvdm/diffusion_models/joint_model.py @@ -1,10 +1,11 @@ -from typing import Tuple +from typing import Optional, Tuple import tensorflow as tf from tensorflow.keras.layers import Input, Lambda from tensorflow.keras.models import Model from cvdm.configs.model_config import ModelConfig +from cvdm.diffusion_models.mean_model import mean_model from cvdm.diffusion_models.noise_model import noise_model from cvdm.diffusion_models.variance_model import variance_model from cvdm.utils.data_utils import obtain_noisy_sample @@ -17,7 +18,7 @@ def create_joint_model( timesteps: int, out_channels: int, model_config: ModelConfig, -) -> Tuple[Model, Model, Model]: +) -> Tuple[Model, Model, Model, Optional[Model]]: """This function creates a Keras model for the image denoising pipeline.""" ground_truth = Input(input_shape_condition[:-1] + (out_channels,)) @@ -32,13 +33,39 @@ def create_joint_model( sch_params_Lt = sch_model([dirty_img, timesteps_Lt]) sch_params_L0 = sch_model([dirty_img, timesteps_L0]) - n_model = noise_model( - input_shape_condition, out_channels, model_type=model_config.noise_model_type - ) + if model_config.zmd: + pass + + sigma = 0.5 + n_model = noise_model( + input_shape_condition, + out_channels, + model_type=model_config.noise_model_type, + zmd=model_config.zmd, + ) + mu_model = mean_model(input_shape_condition, out_channels) + mean_pred = mu_model(dirty_img)[0] + mean_pred_sg = tf.stop_gradient(mean_pred) + n_sample_LT = Lambda(obtain_noisy_sample)( + [(ground_truth - mean_pred_sg) / sigma, sch_params_LT[0]] + ) + n_sample_Lt = Lambda(obtain_noisy_sample)( + [(ground_truth - mean_pred_sg) / sigma, sch_params_Lt[0]] + ) + pred_noise_Lt = n_model( + [n_sample_Lt[0], dirty_img, mean_pred_sg, sch_params_Lt[0]] + ) - n_sample_LT = Lambda(obtain_noisy_sample)([ground_truth, sch_params_LT[0]]) - n_sample_Lt = Lambda(obtain_noisy_sample)([ground_truth, sch_params_Lt[0]]) - pred_noise_Lt = n_model([n_sample_Lt[0], dirty_img, sch_params_Lt[0]]) + else: + n_model = noise_model( + input_shape_condition, + out_channels, + model_type=model_config.noise_model_type, + ) + + n_sample_LT = Lambda(obtain_noisy_sample)([ground_truth, sch_params_LT[0]]) + n_sample_Lt = Lambda(obtain_noisy_sample)([ground_truth, sch_params_Lt[0]]) + pred_noise_Lt = n_model([n_sample_Lt[0], dirty_img, sch_params_Lt[0]]) d_alpha_t = Lambda(time_grad)([sch_params_Lt[0], timesteps_Lt]) @@ -64,12 +91,21 @@ def create_joint_model( + tf.square(sch_params_LT[0]) ) L_gamma = model_config.alpha * tf.square(d2_alpha_n_t) - joint_model: Model = Model( - [ground_truth, dirty_img, timesteps_Lt], - [delta_noise, L_beta, kl_divergence_T, L_gamma], - ) - joint_model.summary() - return n_model, joint_model, sch_model + joint_model: Model + if model_config.zmd: + delta_mean = tf.square(ground_truth - mean_pred) + joint_model = Model( + [ground_truth, dirty_img, timesteps_Lt], + [delta_noise, L_beta, kl_divergence_T, L_gamma, delta_mean], + ) + return n_model, joint_model, sch_model, mu_model + else: + joint_model = Model( + [ground_truth, dirty_img, timesteps_Lt], + [delta_noise, L_beta, kl_divergence_T, L_gamma], + ) + joint_model.summary() + return n_model, joint_model, sch_model, None def instantiate_cvdm( @@ -78,12 +114,22 @@ def instantiate_cvdm( cond_shape: tf.TensorShape, out_shape: tf.TensorShape, model_config: ModelConfig, -) -> Tuple[Model, Model, Model]: +) -> Tuple[Model, Model, Model, Optional[Model]]: opt_m = tf.keras.optimizers.Adam(learning_rate=lr) out_channels = out_shape[-1] assert out_channels is not None - noise_model, joint_model, schedule_model = create_joint_model( - cond_shape, generation_timesteps, out_channels, model_config - ) - joint_model.compile(loss=linear_loss, loss_weights=[1, 2, 2, 2], optimizer=opt_m) # type: ignore - return noise_model, joint_model, schedule_model + + if model_config.zmd: + models = create_joint_model( + cond_shape, generation_timesteps, out_channels, model_config + ) + noise_model, joint_model, schedule_model, mu_model = models + joint_model.compile(loss=linear_loss, loss_weights=[1, 2, 2, 2, 2], optimizer=opt_m) # type: ignore + else: + models = create_joint_model( + cond_shape, generation_timesteps, out_channels, model_config + ) + noise_model, joint_model, schedule_model, _ = models + joint_model.compile(loss=linear_loss, loss_weights=[1, 2, 2, 2], optimizer=opt_m) # type: ignore + + return models diff --git a/cvdm/diffusion_models/mean_model.py b/cvdm/diffusion_models/mean_model.py new file mode 100644 index 0000000..5d050c6 --- /dev/null +++ b/cvdm/diffusion_models/mean_model.py @@ -0,0 +1,15 @@ +import tensorflow as tf +from tensorflow.keras.layers import Input +from tensorflow.keras.models import Model + +from cvdm.architectures.unet import UNet + + +def mean_model(input_shape_condition: tf.TensorShape, out_channels: int) -> Model: + mean_model_input = Input(input_shape_condition) + unet_output = UNet( + input_shape_condition, inputs=mean_model_input, out_filters=out_channels + ) + mean_condition, condition_latents = unet_output + mean_model_out: Model = Model(mean_model_input, [mean_condition, condition_latents]) + return mean_model_out diff --git a/cvdm/diffusion_models/noise_model.py b/cvdm/diffusion_models/noise_model.py index b656abf..07c116c 100644 --- a/cvdm/diffusion_models/noise_model.py +++ b/cvdm/diffusion_models/noise_model.py @@ -1,36 +1,44 @@ import keras import tensorflow as tf -from tensorflow.keras.models import Model from tensorflow.keras.layers import Concatenate, Input +from tensorflow.keras.models import Model from cvdm.architectures.sr3 import sr3 from cvdm.architectures.unet import UNet def noise_model( - input_shape_noisy: tf.TensorShape, out_channels: int, model_type: str + input_shape_noisy: tf.TensorShape, out_channels: int, model_type: str, zmd=False ) -> Model: - noisy_input = Input(input_shape_noisy[:-1] + (out_channels,)) + noisy_inp = Input(input_shape_noisy[:-1] + (out_channels,)) ref_frame = Input(input_shape_noisy) - c_inpt = Concatenate()([noisy_input, ref_frame]) + c_inp = Concatenate()([noisy_inp, ref_frame]) gamma_inp = Input(input_shape_noisy[:-1] + (out_channels,)) model: Model + + if zmd: + mean_inp = Input(input_shape_noisy[:-1] + (out_channels,)) + c_inp = Concatenate()([c_inp, mean_inp]) + if model_type == "sr3": s_model = sr3( - keras.backend.int_shape(c_inpt)[1:], + keras.backend.int_shape(c_inp)[1:], input_shape_noisy[:-1] + (out_channels,), out_channels=out_channels, ) - - noise_out = s_model([c_inpt, gamma_inp]) - + noise_out = s_model([c_inp, gamma_inp]) else: noise_out, _ = UNet( input_shape_noisy, - inputs=c_inpt, + inputs=c_inp, gamma_inp=gamma_inp, out_filters=out_channels, ) - model = Model([noisy_input, ref_frame, gamma_inp], noise_out) + + if zmd: + model = Model([noisy_inp, ref_frame, mean_inp, gamma_inp], noise_out) + + else: + model = Model([noisy_inp, ref_frame, gamma_inp], noise_out) return model diff --git a/cvdm/diffusion_models/time_model.py b/cvdm/diffusion_models/time_model.py index 9eda77a..2c94612 100644 --- a/cvdm/diffusion_models/time_model.py +++ b/cvdm/diffusion_models/time_model.py @@ -1,6 +1,6 @@ import tensorflow as tf -from tensorflow.keras.models import Model from tensorflow.keras.layers import Activation, Conv2D, Input +from tensorflow.keras.models import Model def time_model( diff --git a/cvdm/diffusion_models/variance_model.py b/cvdm/diffusion_models/variance_model.py index 01a253f..8928bf2 100644 --- a/cvdm/diffusion_models/variance_model.py +++ b/cvdm/diffusion_models/variance_model.py @@ -1,8 +1,9 @@ import contextlib from typing import Any, Dict, Iterator + import tensorflow as tf -from tensorflow.keras.models import Model from tensorflow.keras.layers import Activation, Input +from tensorflow.keras.models import Model from cvdm.architectures.unet import UNet from cvdm.diffusion_models.time_model import time_model diff --git a/cvdm/utils/data_utils.py b/cvdm/utils/data_utils.py index a65ab1f..33e7178 100644 --- a/cvdm/utils/data_utils.py +++ b/cvdm/utils/data_utils.py @@ -20,15 +20,6 @@ def read_and_patch_image_from_filename(filename: str, im_size: int) -> Image.Ima return img_patch -def sample_norm_01(x: np.ndarray) -> np.ndarray: - x = x.astype(np.float32) - n_x: np.ndarray = (x - np.amin(x, axis=(0, 1), keepdims=True)) / ( - np.amax(x, axis=(0, 1), keepdims=True) - np.amin(x, axis=(0, 1), keepdims=True) - ) - - return n_x * 2 - 1 - - def center_crop(x: np.ndarray, crop_size: int = 2048) -> np.ndarray: x_center = x.shape[1] // 2 y_center = x.shape[0] // 2 diff --git a/cvdm/utils/inference_utils.py b/cvdm/utils/inference_utils.py index c0bae3e..0d774b1 100644 --- a/cvdm/utils/inference_utils.py +++ b/cvdm/utils/inference_utils.py @@ -1,7 +1,7 @@ +import os from typing import Dict, Optional, Tuple, Union import numpy as np -import tensorflow as tf from matplotlib import pyplot as plt from neptune import Run from neptune.types import File @@ -17,13 +17,15 @@ def ddpm_obtain_sr_img( timesteps_test: int, noise_model: Model, schedule_model: Model, + mu_model: Optional[Model], out_shape: Optional[Tuple[int, ...]] = None, ) -> Tuple[np.ndarray, np.ndarray, np.ndarray]: if out_shape == None: out_shape = x.shape assert out_shape is not None pred_sr = np.random.normal(0, 1, out_shape) - + if mu_model is not None: + mu_pred = mu_model.predict(x, verbose=0)[0] alpha_vec = np.zeros(out_shape + (timesteps_test,)) for t in tqdm(range(timesteps_test)): t_inp = np.clip( @@ -53,11 +55,18 @@ def ddpm_obtain_sr_img( + np.sqrt(1 - gamma_t - beta_factor) * pred_noise + np.sqrt(beta_factor) * z ) - pred_noise = noise_model.predict([pred_sr, x, gamma_t], verbose=0) + if mu_model is not None: + pred_noise = noise_model.predict([pred_sr, x, mu_pred, gamma_t], verbose=0) + else: + pred_noise = noise_model.predict([pred_sr, x, gamma_t], verbose=0) pred_sr = (pred_sr - np.sqrt(1 - gamma_t) * pred_noise) / np.sqrt(gamma_t) count += 1 - - return pred_sr, gamma_vec, alpha_vec + if mu_model is not None: + sigma = 0.5 + pred_diff = sigma * pred_sr + mu_pred + else: + pred_diff = pred_sr + return pred_diff, gamma_vec, alpha_vec def create_output_montage( @@ -95,6 +104,23 @@ def log_loss(run: Optional[Run], avg_loss: np.ndarray, prefix: str) -> None: run[f"{prefix}_loss_beta"].log(avg_loss[2]) run[f"{prefix}_loss_KL"].log(avg_loss[3]) run[f"{prefix}_loss_gamma"].log(avg_loss[4]) + if len(avg_loss) == 6: + run[f"{prefix}_loss_mean"].log(avg_loss[5]) + else: + loss_labels = [ + "Loss Sum", + "Delta Noise Loss", + "Beta Loss", + "KL Loss", + "Gamma Loss", + ] + formatted_losses = [ + f"{label}: {loss:.6f}" for label, loss in zip(loss_labels, avg_loss[:5]) + ] + for loss in formatted_losses: + print(loss) + if len(avg_loss) == 6: + print(f"Mean Loss: {avg_loss[5]:.6f}") def log_metrics( @@ -103,18 +129,37 @@ def log_metrics( if run is not None: for metric_name, metric_value in metrics_dict.items(): run[f"{prefix}_" + metric_name].log(metric_value) + else: + print(f"{prefix.capitalize()} Metrics:") + for metric_name, metric_value in metrics_dict.items(): + print(f"{metric_name}: {metric_value:.6f}") -def save_weighs( - run: Optional[Run], model: Model, step: int, output_path: str, run_id: str +def save_weights( + run: Optional[Run], + model: Model, + mu_model: Optional[Model], + step: int, + output_path: str, + run_id: str, ) -> None: + weights_dir = f"{output_path}/weights" + os.makedirs(weights_dir, exist_ok=True) - model.save_weights(f"{output_path}/weights/model_{str(step)}_{run_id}.h5", True) + model_weights_path = f"{weights_dir}/model_{str(step)}_{run_id}.h5" + model.save_weights(model_weights_path, True) if run is not None: - run[f"model_weights/model_{str(step)}.h5"].upload( - f"{output_path}/weights/model_{str(step)}_{run_id}.h5" - ) + run[f"model_weights/model_{str(step)}.h5"].upload(model_weights_path) + + if mu_model is not None: + mu_model_weights_path = f"{weights_dir}/mu_model_{str(step)}_{run_id}.h5" + mu_model.save_weights(mu_model_weights_path, True) + + if run is not None: + run[f"mu_model_weights/mu_model_{str(step)}.h5"].upload( + mu_model_weights_path + ) def save_output_montage( @@ -126,16 +171,15 @@ def save_output_montage( prefix: str, cmap: Optional[str] = None, ) -> None: + output_dir = f"{output_path}/images" + os.makedirs(output_dir, exist_ok=True) - plt.imsave( - f"{output_path}/images/{prefix}_output_{str(step)}_{run_id}.png", - output_montage, - cmap=cmap, - ) + image_path = f"{output_dir}/{prefix}_output_{str(step)}_{run_id}.png" + plt.imsave(image_path, output_montage, cmap=cmap) if run is not None: run[f"{prefix}_images"].append( - File(f"{output_path}/images/{prefix}_output_{str(step)}_{run_id}.png"), + File(image_path), description=f"Step {step}, {prefix}", ) @@ -145,16 +189,18 @@ def obtain_output_montage_and_metrics( batch_y: np.ndarray, noise_model: Model, schedule_model: Model, + mu_model: Optional[Model], generation_timesteps: int, + diff_inp: bool, task: str, ) -> Tuple[np.ndarray, Dict]: - diff_inp = task in ["biosr_sr", "imagenet_sr"] pred_diff, gamma_vec, _ = ddpm_obtain_sr_img( batch_x, generation_timesteps, noise_model, schedule_model, + mu_model, batch_y.shape, ) if diff_inp: diff --git a/cvdm/utils/loss_utils.py b/cvdm/utils/loss_utils.py index d75c9d6..cf8b646 100644 --- a/cvdm/utils/loss_utils.py +++ b/cvdm/utils/loss_utils.py @@ -1,4 +1,5 @@ from typing import List + import tensorflow as tf diff --git a/cvdm/utils/metrics_utils.py b/cvdm/utils/metrics_utils.py index 58b83cd..5e9ce81 100644 --- a/cvdm/utils/metrics_utils.py +++ b/cvdm/utils/metrics_utils.py @@ -1,4 +1,5 @@ from typing import Dict, Optional + import numpy as np from skimage.metrics import peak_signal_noise_ratio, structural_similarity @@ -8,24 +9,9 @@ def nmae(y_pred: np.ndarray, y_real: np.ndarray) -> float: return nmae -def dice(y_pred_masks: np.ndarray, y_real_masks: np.ndarray) -> float: - intersection = np.sum(y_pred_masks * y_real_masks) - dice: float = intersection * 2.0 / (np.sum(y_real_masks) + np.sum(y_pred_masks)) - return dice - - -def iou(y_pred_masks: np.ndarray, y_real_masks: np.ndarray) -> float: - intersection = np.sum(y_pred_masks * y_real_masks) - union = np.sum(y_real_masks + y_pred_masks) - intersection - iou: float = intersection / union - return iou - - def calculate_metrics( y_pred_batch: np.ndarray, y_real_batch: np.ndarray, - y_pred_masks_batch: Optional[np.ndarray] = None, - y_real_masks_batch: Optional[np.ndarray] = None, ) -> Dict[str, float]: y_pred_batch = np.array(y_pred_batch) y_real_batch = np.array(y_real_batch) @@ -66,9 +52,4 @@ def calculate_metrics( ), } - # Optional mask-related metrics - if y_pred_masks_batch is not None and y_real_masks_batch is not None: - metrics["dice"] = np.mean(dice(y_pred_masks_batch, y_real_masks_batch)) - metrics["iou"] = np.mean(iou(y_pred_masks_batch, y_real_masks_batch)) - return metrics diff --git a/cvdm/utils/phase_utils.py b/cvdm/utils/phase_utils.py index 0898e1e..b32a6f2 100644 --- a/cvdm/utils/phase_utils.py +++ b/cvdm/utils/phase_utils.py @@ -5,6 +5,7 @@ from cupy.fft import fftn, fftshift, ifftn, ifftshift +# If you encounter issues with cupy installation, you can use the numpy implementation of fft def easy_fft(data: np.ndarray, axes: Optional[Sequence[int]] = None) -> cp.ndarray: """FFT that includes shifting.""" return fftshift(fftn(ifftshift(data, axes=axes), axes=axes), axes=axes) diff --git a/cvdm/utils/training_utils.py b/cvdm/utils/training_utils.py index ad2c45c..ce948c6 100644 --- a/cvdm/utils/training_utils.py +++ b/cvdm/utils/training_utils.py @@ -24,8 +24,6 @@ def prepare_dataset( n_samples=data_config.n_samples, im_size=data_config.im_size, ) - x_channels = 1 - y_channels = x_channels elif task == "imagenet_sr": dataloader = ImageDirDataloader( @@ -54,6 +52,15 @@ def prepare_dataset( ) x_channels = 2 y_channels = 1 + elif task == "other": + dataloader = NpyDataloader( + path=data_config.dataset_path, + n_samples=data_config.n_samples, + im_size=data_config.im_size, + ) + x_channels, y_channels = dataloader.get_channels() + else: + raise NotImplementedError() x_shape = tf.TensorShape([data_config.im_size, data_config.im_size, x_channels]) y_shape = tf.TensorShape([data_config.im_size, data_config.im_size, y_channels]) diff --git a/data/imnet_sample/n01518878_10151.JPEG b/data/imnet_sample/n01518878_10151.JPEG new file mode 100755 index 0000000..bfc08bd Binary files /dev/null and b/data/imnet_sample/n01518878_10151.JPEG differ diff --git a/data/imnet_sample/n01687978_6241.JPEG b/data/imnet_sample/n01687978_6241.JPEG new file mode 100755 index 0000000..2abb9ad Binary files /dev/null and b/data/imnet_sample/n01687978_6241.JPEG differ diff --git a/data/imnet_sample/n01773797_5942.JPEG b/data/imnet_sample/n01773797_5942.JPEG new file mode 100755 index 0000000..8564065 Binary files /dev/null and b/data/imnet_sample/n01773797_5942.JPEG differ diff --git a/data/imnet_sample/n01819313_5400.JPEG b/data/imnet_sample/n01819313_5400.JPEG new file mode 100755 index 0000000..ceb3c79 Binary files /dev/null and b/data/imnet_sample/n01819313_5400.JPEG differ diff --git a/data/imnet_sample/n01824575_9598.JPEG b/data/imnet_sample/n01824575_9598.JPEG new file mode 100755 index 0000000..2f566d3 Binary files /dev/null and b/data/imnet_sample/n01824575_9598.JPEG differ diff --git a/data/imnet_sample/n01833805_8954.JPEG b/data/imnet_sample/n01833805_8954.JPEG new file mode 100755 index 0000000..eede6bf Binary files /dev/null and b/data/imnet_sample/n01833805_8954.JPEG differ diff --git a/data/imnet_sample/n01871265_4223.JPEG b/data/imnet_sample/n01871265_4223.JPEG new file mode 100755 index 0000000..2185602 Binary files /dev/null and b/data/imnet_sample/n01871265_4223.JPEG differ diff --git a/data/imnet_sample/n01968897_5251.JPEG b/data/imnet_sample/n01968897_5251.JPEG new file mode 100755 index 0000000..f0d02cd Binary files /dev/null and b/data/imnet_sample/n01968897_5251.JPEG differ diff --git a/data/imnet_sample/n01978455_1216.JPEG b/data/imnet_sample/n01978455_1216.JPEG new file mode 100755 index 0000000..9c28788 Binary files /dev/null and b/data/imnet_sample/n01978455_1216.JPEG differ diff --git a/data/imnet_sample/n01983481_26090.JPEG b/data/imnet_sample/n01983481_26090.JPEG new file mode 100755 index 0000000..b409dba Binary files /dev/null and b/data/imnet_sample/n01983481_26090.JPEG differ diff --git a/data/imnet_sample/n02002724_578.JPEG b/data/imnet_sample/n02002724_578.JPEG new file mode 100755 index 0000000..8ff201f Binary files /dev/null and b/data/imnet_sample/n02002724_578.JPEG differ diff --git a/data/imnet_sample/n02009912_36790.JPEG b/data/imnet_sample/n02009912_36790.JPEG new file mode 100755 index 0000000..1138da3 Binary files /dev/null and b/data/imnet_sample/n02009912_36790.JPEG differ diff --git a/data/imnet_sample/n02025239_6181.JPEG b/data/imnet_sample/n02025239_6181.JPEG new file mode 100755 index 0000000..24b819f Binary files /dev/null and b/data/imnet_sample/n02025239_6181.JPEG differ diff --git a/data/imnet_sample/n02056570_74.JPEG b/data/imnet_sample/n02056570_74.JPEG new file mode 100755 index 0000000..7ea47cb Binary files /dev/null and b/data/imnet_sample/n02056570_74.JPEG differ diff --git a/data/imnet_sample/n02086910_2749.JPEG b/data/imnet_sample/n02086910_2749.JPEG new file mode 100755 index 0000000..4274540 Binary files /dev/null and b/data/imnet_sample/n02086910_2749.JPEG differ diff --git a/data/imnet_sample/n02088364_5604.JPEG b/data/imnet_sample/n02088364_5604.JPEG new file mode 100755 index 0000000..5d3f509 Binary files /dev/null and b/data/imnet_sample/n02088364_5604.JPEG differ diff --git a/data/imnet_sample/n02088466_1085.JPEG b/data/imnet_sample/n02088466_1085.JPEG new file mode 100755 index 0000000..9b115b4 Binary files /dev/null and b/data/imnet_sample/n02088466_1085.JPEG differ diff --git a/data/imnet_sample/n02093428_316.JPEG b/data/imnet_sample/n02093428_316.JPEG new file mode 100755 index 0000000..460dafb Binary files /dev/null and b/data/imnet_sample/n02093428_316.JPEG differ diff --git a/data/imnet_sample/n02094258_1004.JPEG b/data/imnet_sample/n02094258_1004.JPEG new file mode 100755 index 0000000..8333b42 Binary files /dev/null and b/data/imnet_sample/n02094258_1004.JPEG differ diff --git a/data/imnet_sample/n02096051_137.JPEG b/data/imnet_sample/n02096051_137.JPEG new file mode 100755 index 0000000..423db6b Binary files /dev/null and b/data/imnet_sample/n02096051_137.JPEG differ diff --git a/data/imnet_sample/n02096294_12931.JPEG b/data/imnet_sample/n02096294_12931.JPEG new file mode 100755 index 0000000..370c4be Binary files /dev/null and b/data/imnet_sample/n02096294_12931.JPEG differ diff --git a/data/imnet_sample/n02097298_16152.JPEG b/data/imnet_sample/n02097298_16152.JPEG new file mode 100755 index 0000000..969dc32 Binary files /dev/null and b/data/imnet_sample/n02097298_16152.JPEG differ diff --git a/data/imnet_sample/n02105056_1952.JPEG b/data/imnet_sample/n02105056_1952.JPEG new file mode 100755 index 0000000..f90a7f6 Binary files /dev/null and b/data/imnet_sample/n02105056_1952.JPEG differ diff --git a/data/imnet_sample/n02105412_1856.JPEG b/data/imnet_sample/n02105412_1856.JPEG new file mode 100755 index 0000000..1a77b2c Binary files /dev/null and b/data/imnet_sample/n02105412_1856.JPEG differ diff --git a/data/imnet_sample/n02108422_4094.JPEG b/data/imnet_sample/n02108422_4094.JPEG new file mode 100755 index 0000000..2be1172 Binary files /dev/null and b/data/imnet_sample/n02108422_4094.JPEG differ diff --git a/data/imnet_sample/n02109047_7510.JPEG b/data/imnet_sample/n02109047_7510.JPEG new file mode 100755 index 0000000..e2bcaa4 Binary files /dev/null and b/data/imnet_sample/n02109047_7510.JPEG differ diff --git a/data/imnet_sample/n02119022_7363.JPEG b/data/imnet_sample/n02119022_7363.JPEG new file mode 100755 index 0000000..39c2fbf Binary files /dev/null and b/data/imnet_sample/n02119022_7363.JPEG differ diff --git a/data/imnet_sample/n02120079_34645.JPEG b/data/imnet_sample/n02120079_34645.JPEG new file mode 100755 index 0000000..e6d9478 Binary files /dev/null and b/data/imnet_sample/n02120079_34645.JPEG differ diff --git a/data/imnet_sample/n02127052_22549.JPEG b/data/imnet_sample/n02127052_22549.JPEG new file mode 100755 index 0000000..87252b0 Binary files /dev/null and b/data/imnet_sample/n02127052_22549.JPEG differ diff --git a/data/imnet_sample/n02129165_3553.JPEG b/data/imnet_sample/n02129165_3553.JPEG new file mode 100755 index 0000000..95ced64 Binary files /dev/null and b/data/imnet_sample/n02129165_3553.JPEG differ diff --git a/data/imnet_sample/n02256656_17395.JPEG b/data/imnet_sample/n02256656_17395.JPEG new file mode 100755 index 0000000..450657b Binary files /dev/null and b/data/imnet_sample/n02256656_17395.JPEG differ diff --git a/data/imnet_sample/n02268853_17452.JPEG b/data/imnet_sample/n02268853_17452.JPEG new file mode 100755 index 0000000..e50641a Binary files /dev/null and b/data/imnet_sample/n02268853_17452.JPEG differ diff --git a/data/imnet_sample/n02281787_2598.JPEG b/data/imnet_sample/n02281787_2598.JPEG new file mode 100755 index 0000000..788e924 Binary files /dev/null and b/data/imnet_sample/n02281787_2598.JPEG differ diff --git a/data/imnet_sample/n02317335_9258.JPEG b/data/imnet_sample/n02317335_9258.JPEG new file mode 100755 index 0000000..92b2cdb Binary files /dev/null and b/data/imnet_sample/n02317335_9258.JPEG differ diff --git a/data/imnet_sample/n02319095_4544.JPEG b/data/imnet_sample/n02319095_4544.JPEG new file mode 100755 index 0000000..a80c722 Binary files /dev/null and b/data/imnet_sample/n02319095_4544.JPEG differ diff --git a/data/imnet_sample/n02422106_9113.JPEG b/data/imnet_sample/n02422106_9113.JPEG new file mode 100755 index 0000000..52ac526 Binary files /dev/null and b/data/imnet_sample/n02422106_9113.JPEG differ diff --git a/data/imnet_sample/n02442845_18886.JPEG b/data/imnet_sample/n02442845_18886.JPEG new file mode 100755 index 0000000..1bcfa22 Binary files /dev/null and b/data/imnet_sample/n02442845_18886.JPEG differ diff --git a/data/imnet_sample/n02447366_98.JPEG b/data/imnet_sample/n02447366_98.JPEG new file mode 100755 index 0000000..2cc9501 Binary files /dev/null and b/data/imnet_sample/n02447366_98.JPEG differ diff --git a/data/imnet_sample/n02486410_3043.JPEG b/data/imnet_sample/n02486410_3043.JPEG new file mode 100755 index 0000000..2af11fa Binary files /dev/null and b/data/imnet_sample/n02486410_3043.JPEG differ diff --git a/data/imnet_sample/n02489166_10125.JPEG b/data/imnet_sample/n02489166_10125.JPEG new file mode 100755 index 0000000..46b4b3c Binary files /dev/null and b/data/imnet_sample/n02489166_10125.JPEG differ diff --git a/data/imnet_sample/n02492660_11410.JPEG b/data/imnet_sample/n02492660_11410.JPEG new file mode 100755 index 0000000..682a84e Binary files /dev/null and b/data/imnet_sample/n02492660_11410.JPEG differ diff --git a/data/imnet_sample/n02509815_15551.JPEG b/data/imnet_sample/n02509815_15551.JPEG new file mode 100755 index 0000000..55017ed Binary files /dev/null and b/data/imnet_sample/n02509815_15551.JPEG differ diff --git a/data/imnet_sample/n02640242_203.JPEG b/data/imnet_sample/n02640242_203.JPEG new file mode 100755 index 0000000..83c014c Binary files /dev/null and b/data/imnet_sample/n02640242_203.JPEG differ diff --git a/data/imnet_sample/n02730930_19384.JPEG b/data/imnet_sample/n02730930_19384.JPEG new file mode 100755 index 0000000..a6d0d7e Binary files /dev/null and b/data/imnet_sample/n02730930_19384.JPEG differ diff --git a/data/imnet_sample/n02749479_4070.JPEG b/data/imnet_sample/n02749479_4070.JPEG new file mode 100755 index 0000000..aa7e9a5 Binary files /dev/null and b/data/imnet_sample/n02749479_4070.JPEG differ diff --git a/data/imnet_sample/n02950826_8185.JPEG b/data/imnet_sample/n02950826_8185.JPEG new file mode 100755 index 0000000..f358233 Binary files /dev/null and b/data/imnet_sample/n02950826_8185.JPEG differ diff --git a/data/imnet_sample/n02951358_1545.JPEG b/data/imnet_sample/n02951358_1545.JPEG new file mode 100755 index 0000000..45d98a0 Binary files /dev/null and b/data/imnet_sample/n02951358_1545.JPEG differ diff --git a/data/imnet_sample/n02980441_34980.JPEG b/data/imnet_sample/n02980441_34980.JPEG new file mode 100755 index 0000000..0d91147 Binary files /dev/null and b/data/imnet_sample/n02980441_34980.JPEG differ diff --git a/data/imnet_sample/n03000247_10805.JPEG b/data/imnet_sample/n03000247_10805.JPEG new file mode 100755 index 0000000..9c4ca97 Binary files /dev/null and b/data/imnet_sample/n03000247_10805.JPEG differ diff --git a/data/imnet_sample/n03014705_11365.JPEG b/data/imnet_sample/n03014705_11365.JPEG new file mode 100755 index 0000000..ef201f2 Binary files /dev/null and b/data/imnet_sample/n03014705_11365.JPEG differ diff --git a/data/imnet_sample/n03018349_11340.JPEG b/data/imnet_sample/n03018349_11340.JPEG new file mode 100755 index 0000000..6ea5109 Binary files /dev/null and b/data/imnet_sample/n03018349_11340.JPEG differ diff --git a/data/imnet_sample/n03028079_7271.JPEG b/data/imnet_sample/n03028079_7271.JPEG new file mode 100755 index 0000000..a9da87f Binary files /dev/null and b/data/imnet_sample/n03028079_7271.JPEG differ diff --git a/data/imnet_sample/n03089624_64438.JPEG b/data/imnet_sample/n03089624_64438.JPEG new file mode 100755 index 0000000..4d19068 Binary files /dev/null and b/data/imnet_sample/n03089624_64438.JPEG differ diff --git a/data/imnet_sample/n03134739_1972.JPEG b/data/imnet_sample/n03134739_1972.JPEG new file mode 100755 index 0000000..b6d821a Binary files /dev/null and b/data/imnet_sample/n03134739_1972.JPEG differ diff --git a/data/imnet_sample/n03207941_13371.JPEG b/data/imnet_sample/n03207941_13371.JPEG new file mode 100755 index 0000000..90a1ac2 Binary files /dev/null and b/data/imnet_sample/n03207941_13371.JPEG differ diff --git a/data/imnet_sample/n03345487_8639.JPEG b/data/imnet_sample/n03345487_8639.JPEG new file mode 100755 index 0000000..008742b Binary files /dev/null and b/data/imnet_sample/n03345487_8639.JPEG differ diff --git a/data/imnet_sample/n03388043_9050.JPEG b/data/imnet_sample/n03388043_9050.JPEG new file mode 100755 index 0000000..f03cf59 Binary files /dev/null and b/data/imnet_sample/n03388043_9050.JPEG differ diff --git a/data/imnet_sample/n03388043_9872.JPEG b/data/imnet_sample/n03388043_9872.JPEG new file mode 100755 index 0000000..88fb8af Binary files /dev/null and b/data/imnet_sample/n03388043_9872.JPEG differ diff --git a/data/imnet_sample/n03400231_14172.JPEG b/data/imnet_sample/n03400231_14172.JPEG new file mode 100755 index 0000000..39f317b Binary files /dev/null and b/data/imnet_sample/n03400231_14172.JPEG differ diff --git a/data/imnet_sample/n03400231_16862.JPEG b/data/imnet_sample/n03400231_16862.JPEG new file mode 100755 index 0000000..b378d96 Binary files /dev/null and b/data/imnet_sample/n03400231_16862.JPEG differ diff --git a/data/imnet_sample/n03404251_8619.JPEG b/data/imnet_sample/n03404251_8619.JPEG new file mode 100755 index 0000000..340ca87 Binary files /dev/null and b/data/imnet_sample/n03404251_8619.JPEG differ diff --git a/data/imnet_sample/n03483316_2248.JPEG b/data/imnet_sample/n03483316_2248.JPEG new file mode 100755 index 0000000..4707fe8 Binary files /dev/null and b/data/imnet_sample/n03483316_2248.JPEG differ diff --git a/data/imnet_sample/n03532672_13032.JPEG b/data/imnet_sample/n03532672_13032.JPEG new file mode 100755 index 0000000..bdfbbd2 Binary files /dev/null and b/data/imnet_sample/n03532672_13032.JPEG differ diff --git a/data/imnet_sample/n03657121_3221.JPEG b/data/imnet_sample/n03657121_3221.JPEG new file mode 100755 index 0000000..8feb007 Binary files /dev/null and b/data/imnet_sample/n03657121_3221.JPEG differ diff --git a/data/imnet_sample/n03720891_15866.JPEG b/data/imnet_sample/n03720891_15866.JPEG new file mode 100755 index 0000000..c959804 Binary files /dev/null and b/data/imnet_sample/n03720891_15866.JPEG differ diff --git a/data/imnet_sample/n03733805_7060.JPEG b/data/imnet_sample/n03733805_7060.JPEG new file mode 100755 index 0000000..77834a8 Binary files /dev/null and b/data/imnet_sample/n03733805_7060.JPEG differ diff --git a/data/imnet_sample/n03782006_2336.JPEG b/data/imnet_sample/n03782006_2336.JPEG new file mode 100755 index 0000000..3345dae Binary files /dev/null and b/data/imnet_sample/n03782006_2336.JPEG differ diff --git a/data/imnet_sample/n03785016_25153.JPEG b/data/imnet_sample/n03785016_25153.JPEG new file mode 100755 index 0000000..22dbef6 Binary files /dev/null and b/data/imnet_sample/n03785016_25153.JPEG differ diff --git a/data/imnet_sample/n03814906_36618.JPEG b/data/imnet_sample/n03814906_36618.JPEG new file mode 100755 index 0000000..5e422b3 Binary files /dev/null and b/data/imnet_sample/n03814906_36618.JPEG differ diff --git a/data/imnet_sample/n03874599_15484.JPEG b/data/imnet_sample/n03874599_15484.JPEG new file mode 100755 index 0000000..03a8828 Binary files /dev/null and b/data/imnet_sample/n03874599_15484.JPEG differ diff --git a/data/imnet_sample/n03877845_202.JPEG b/data/imnet_sample/n03877845_202.JPEG new file mode 100755 index 0000000..d56c13c Binary files /dev/null and b/data/imnet_sample/n03877845_202.JPEG differ diff --git a/data/imnet_sample/n03935335_13067.JPEG b/data/imnet_sample/n03935335_13067.JPEG new file mode 100755 index 0000000..9a5ff00 Binary files /dev/null and b/data/imnet_sample/n03935335_13067.JPEG differ diff --git a/data/imnet_sample/n04067472_8614.JPEG b/data/imnet_sample/n04067472_8614.JPEG new file mode 100755 index 0000000..52c27f8 Binary files /dev/null and b/data/imnet_sample/n04067472_8614.JPEG differ diff --git a/data/imnet_sample/n04081281_6555.JPEG b/data/imnet_sample/n04081281_6555.JPEG new file mode 100755 index 0000000..eec5dfc Binary files /dev/null and b/data/imnet_sample/n04081281_6555.JPEG differ diff --git a/data/imnet_sample/n04086273_5973.JPEG b/data/imnet_sample/n04086273_5973.JPEG new file mode 100755 index 0000000..535c19e Binary files /dev/null and b/data/imnet_sample/n04086273_5973.JPEG differ diff --git a/data/imnet_sample/n04125021_881.JPEG b/data/imnet_sample/n04125021_881.JPEG new file mode 100755 index 0000000..2f8f849 Binary files /dev/null and b/data/imnet_sample/n04125021_881.JPEG differ diff --git a/data/imnet_sample/n04209239_1891.JPEG b/data/imnet_sample/n04209239_1891.JPEG new file mode 100755 index 0000000..9967371 Binary files /dev/null and b/data/imnet_sample/n04209239_1891.JPEG differ diff --git a/data/imnet_sample/n04235860_6220.JPEG b/data/imnet_sample/n04235860_6220.JPEG new file mode 100755 index 0000000..8ef3a89 Binary files /dev/null and b/data/imnet_sample/n04235860_6220.JPEG differ diff --git a/data/imnet_sample/n04254120_3647.JPEG b/data/imnet_sample/n04254120_3647.JPEG new file mode 100755 index 0000000..6a290b3 Binary files /dev/null and b/data/imnet_sample/n04254120_3647.JPEG differ diff --git a/data/imnet_sample/n04270147_3050.JPEG b/data/imnet_sample/n04270147_3050.JPEG new file mode 100755 index 0000000..a99689e Binary files /dev/null and b/data/imnet_sample/n04270147_3050.JPEG differ diff --git a/data/imnet_sample/n04285008_6901.JPEG b/data/imnet_sample/n04285008_6901.JPEG new file mode 100755 index 0000000..16642a6 Binary files /dev/null and b/data/imnet_sample/n04285008_6901.JPEG differ diff --git a/data/imnet_sample/n04335435_12746.JPEG b/data/imnet_sample/n04335435_12746.JPEG new file mode 100755 index 0000000..a42e935 Binary files /dev/null and b/data/imnet_sample/n04335435_12746.JPEG differ diff --git a/data/imnet_sample/n04389033_13016.JPEG b/data/imnet_sample/n04389033_13016.JPEG new file mode 100755 index 0000000..92f5068 Binary files /dev/null and b/data/imnet_sample/n04389033_13016.JPEG differ diff --git a/data/imnet_sample/n04429376_28621.JPEG b/data/imnet_sample/n04429376_28621.JPEG new file mode 100755 index 0000000..f9840c4 Binary files /dev/null and b/data/imnet_sample/n04429376_28621.JPEG differ diff --git a/data/imnet_sample/n04536866_11653.JPEG b/data/imnet_sample/n04536866_11653.JPEG new file mode 100755 index 0000000..7d59814 Binary files /dev/null and b/data/imnet_sample/n04536866_11653.JPEG differ diff --git a/data/imnet_sample/n04589890_908.JPEG b/data/imnet_sample/n04589890_908.JPEG new file mode 100755 index 0000000..53fef05 Binary files /dev/null and b/data/imnet_sample/n04589890_908.JPEG differ diff --git a/data/imnet_sample/n04590129_7262.JPEG b/data/imnet_sample/n04590129_7262.JPEG new file mode 100755 index 0000000..5a80d1c Binary files /dev/null and b/data/imnet_sample/n04590129_7262.JPEG differ diff --git a/data/imnet_sample/n06785654_45113.JPEG b/data/imnet_sample/n06785654_45113.JPEG new file mode 100755 index 0000000..3d7131d Binary files /dev/null and b/data/imnet_sample/n06785654_45113.JPEG differ diff --git a/data/imnet_sample/n07697313_5440.JPEG b/data/imnet_sample/n07697313_5440.JPEG new file mode 100755 index 0000000..d9a2d9e Binary files /dev/null and b/data/imnet_sample/n07697313_5440.JPEG differ diff --git a/data/imnet_sample/n07717410_304.JPEG b/data/imnet_sample/n07717410_304.JPEG new file mode 100755 index 0000000..75e5215 Binary files /dev/null and b/data/imnet_sample/n07717410_304.JPEG differ diff --git a/data/imnet_sample/n07742313_10884.JPEG b/data/imnet_sample/n07742313_10884.JPEG new file mode 100755 index 0000000..b4e636e Binary files /dev/null and b/data/imnet_sample/n07742313_10884.JPEG differ diff --git a/data/imnet_sample/n07802026_2160.JPEG b/data/imnet_sample/n07802026_2160.JPEG new file mode 100755 index 0000000..dbceaa9 Binary files /dev/null and b/data/imnet_sample/n07802026_2160.JPEG differ diff --git a/data/imnet_sample/n07831146_2196.JPEG b/data/imnet_sample/n07831146_2196.JPEG new file mode 100755 index 0000000..cfd26cf Binary files /dev/null and b/data/imnet_sample/n07831146_2196.JPEG differ diff --git a/data/imnet_sample/n07871810_3638.JPEG b/data/imnet_sample/n07871810_3638.JPEG new file mode 100755 index 0000000..f9f63fc Binary files /dev/null and b/data/imnet_sample/n07871810_3638.JPEG differ diff --git a/data/imnet_sample/n07920052_13898.JPEG b/data/imnet_sample/n07920052_13898.JPEG new file mode 100755 index 0000000..6158407 Binary files /dev/null and b/data/imnet_sample/n07920052_13898.JPEG differ diff --git a/data/imnet_sample/n11939491_58904.JPEG b/data/imnet_sample/n11939491_58904.JPEG new file mode 100755 index 0000000..4592f17 Binary files /dev/null and b/data/imnet_sample/n11939491_58904.JPEG differ diff --git a/data/imnet_sample/n12144580_1509.JPEG b/data/imnet_sample/n12144580_1509.JPEG new file mode 100755 index 0000000..f929481 Binary files /dev/null and b/data/imnet_sample/n12144580_1509.JPEG differ diff --git a/data/imnet_sample/n12267677_1604.JPEG b/data/imnet_sample/n12267677_1604.JPEG new file mode 100755 index 0000000..7190d65 Binary files /dev/null and b/data/imnet_sample/n12267677_1604.JPEG differ diff --git a/data/imnet_sample/n12998815_14972.JPEG b/data/imnet_sample/n12998815_14972.JPEG new file mode 100755 index 0000000..a63c8b4 Binary files /dev/null and b/data/imnet_sample/n12998815_14972.JPEG differ diff --git a/data/imnet_sample/n13052670_8304.JPEG b/data/imnet_sample/n13052670_8304.JPEG new file mode 100755 index 0000000..28d1b41 Binary files /dev/null and b/data/imnet_sample/n13052670_8304.JPEG differ diff --git a/pyproject.toml b/pyproject.toml new file mode 100644 index 0000000..374a4de --- /dev/null +++ b/pyproject.toml @@ -0,0 +1,39 @@ +[build-system] +requires = ["setuptools>=61.0", "wheel"] +build-backend = "setuptools.build_meta" + +[project] +name = "cvdm" +version = "0.1.0" +description = "Code for Conditional Variational Diffusion Models" +authors = [ + {name = "Gabriel Della Maggiora"}, + {name = "Luis Alberto Croquevielle"}, + {name = "Maria Wyrzykowska", email = "m.wyrzykowska@hzdr.de"}, + {name = "Nikita Deshpande"}, + {name = "Harry Horsley"}, + {name = "Thomas Heinis"}, + {name = "Artur Yakimovich"} +] +readme = "README.md" +requires-python = ">=3.7" +license = {text = "MIT"} +urls = {homepage = "https://github.com/casus/cvdm"} + +dependencies = [ + "tensorflow==2.15.0", + "keras==2.15.0", + "matplotlib==3.8.0", + "tqdm==4.65.0", + "scikit-learn==1.4.2", + "scikit-image==0.22.0", + "einops==0.7.0", + "neptune==1.10.2", + "opencv-python==4.9.0.80", + "tensorflow-addons==0.23.0", + "cupy-cuda12x==13.3.0" +] + +[tool.setuptools.packages.find] +where = ["cvdm"] +include = ["*"] \ No newline at end of file diff --git a/requirements.txt b/requirements.txt deleted file mode 100644 index 9788c76..0000000 --- a/requirements.txt +++ /dev/null @@ -1,10 +0,0 @@ -tensorflow[and-cuda]==2.15.* -keras==2.15.* -matplotlib==3.8.0 -tqdm==4.65.0 -scikit-learn==1.4.2 -scikit-image==0.22.0 -einops==0.7.0 -tensorflow-addons==0.23.0 -neptune==1.10.2 -opencv-python==4.9.0.80 \ No newline at end of file diff --git a/scripts/eval.py b/scripts/eval.py index a453a33..2c358cb 100644 --- a/scripts/eval.py +++ b/scripts/eval.py @@ -1,5 +1,7 @@ import argparse import uuid +from collections import defaultdict +from typing import Dict import neptune as neptune import numpy as np @@ -47,7 +49,8 @@ def main() -> None: "biosr_phase", "imagenet_phase", "hcoco_phase", - ], "Possible tasks are: biosr_sr, imagenet_sr, biosr_phase, imagenet_phase, hcoco_phase" + "other", + ], "Possible tasks are: biosr_sr, imagenet_sr, biosr_phase, imagenet_phase, hcoco_phase, other" print("Getting data...") batch_size = data_config.batch_size @@ -58,7 +61,7 @@ def main() -> None: generation_timesteps = eval_config.generation_timesteps print("Creating model...") - noise_model, joint_model, schedule_model = instantiate_cvdm( + noise_model, joint_model, schedule_model, mu_model = instantiate_cvdm( lr=0.0, generation_timesteps=generation_timesteps, cond_shape=x_shape, @@ -67,6 +70,8 @@ def main() -> None: ) if model_config.load_weights is not None: joint_model.load_weights(model_config.load_weights) + if model_config.load_mu_weights is not None and mu_model is not None: + mu_model.load_weights(model_config.load_mu_weights) run = None if args.neptune_token is not None and neptune_config is not None: @@ -78,40 +83,55 @@ def main() -> None: run["config.yaml"].upload(args.config_path) output_path = eval_config.output_path + diff_inp = model_config.diff_inp cumulative_loss = np.zeros(5) run_id = str(uuid.uuid4()) step = 0 + cumulative_metrics: Dict[str, float] = defaultdict(float) + total_samples = 0 + for batch in dataset: batch_x, batch_y = batch - diff_inp = task in ["biosr_sr", "imagenet_sr"] - cmap = "gray" if task in ["biosr_phase", "imagenet_phase", "hcoco_phase"] else None + + cmap = ( + "gray" if task in ["biosr_phase", "imagenet_phase", "hcoco_phase"] else None + ) model_input = prepare_model_input(batch_x, batch_y, diff_inp=diff_inp) cumulative_loss += joint_model.evaluate( model_input, np.zeros_like(batch_y), verbose=0 ) - if step % eval_config.image_freq == 0: - output_montage, metrics = obtain_output_montage_and_metrics( - batch_x, - batch_y.numpy(), - noise_model, - schedule_model, - generation_timesteps, - task, - ) - log_metrics(run, metrics, prefix="val") - save_output_montage( - run=run, - output_montage=output_montage, - step=step, - output_path=output_path, - run_id=run_id, - prefix="val", - cmap=cmap, - ) + output_montage, metrics = obtain_output_montage_and_metrics( + batch_x, + batch_y.numpy(), + noise_model, + schedule_model, + mu_model, + generation_timesteps, + diff_inp, + task, + ) + for metric_name, metric_value in metrics.items(): + cumulative_metrics[metric_name] += metric_value * batch_size + total_samples += batch_size step += 1 + average_metrics = { + metric_name: total / total_samples + for metric_name, total in cumulative_metrics.items() + } + + log_metrics(run, average_metrics, prefix="val") + save_output_montage( + run=run, + output_montage=output_montage, + step=step, + output_path=output_path, + run_id=run_id, + prefix="val", + cmap=cmap, + ) print("Loss: ", cumulative_loss) log_loss(run=run, avg_loss=cumulative_loss / (step + 1), prefix="val") diff --git a/scripts/train.py b/scripts/train.py index c4f6cca..d94b85a 100644 --- a/scripts/train.py +++ b/scripts/train.py @@ -20,7 +20,7 @@ log_metrics, obtain_output_montage_and_metrics, save_output_montage, - save_weighs, + save_weights, ) from cvdm.utils.training_utils import ( prepare_dataset, @@ -32,6 +32,11 @@ def main() -> None: + + # The script accepts the following command-line arguments: + # - `--config-path`: The path to the YAML configuration file (required). + # - `--neptune-token`: The API token for Neptune (optional). + parser = argparse.ArgumentParser() parser.add_argument( "--config-path", help="Path to the configuration file", required=True @@ -56,7 +61,8 @@ def main() -> None: "imagenet_sr", "biosr_phase", "imagenet_phase", - ], "Possible tasks are biosr_sr, imagenet_sr, biosr_phase, imagenet_phase" + "other", + ], "Possible tasks are biosr_sr, imagenet_sr, biosr_phase, imagenet_phase, other" print("Getting data...") batch_size = data_config.batch_size @@ -76,15 +82,18 @@ def main() -> None: generation_timesteps = eval_config.generation_timesteps print("Creating model...") - noise_model, joint_model, schedule_model = instantiate_cvdm( + models = instantiate_cvdm( lr=training_config.lr, generation_timesteps=generation_timesteps, cond_shape=x_shape, out_shape=y_shape, model_config=model_config, ) + noise_model, joint_model, schedule_model, mu_model = models if model_config.load_weights is not None: joint_model.load_weights(model_config.load_weights) + if model_config.load_mu_weights is not None and mu_model is not None: + mu_model.load_weights(model_config.load_mu_weights) run = None if args.neptune_token is not None and neptune_config is not None: @@ -100,15 +109,19 @@ def main() -> None: image_freq = eval_config.image_freq val_freq = eval_config.val_freq output_path = eval_config.output_path + diff_inp = model_config.diff_inp print("Starting training...") - cumulative_loss = np.zeros(5) + if model_config.zmd: + cumulative_loss = np.zeros(6) + else: + cumulative_loss = np.zeros(5) step = 0 run_id = str(uuid.uuid4()) for _ in trange(epochs): for batch in dataset: batch_x, batch_y = batch - diff_inp = task in ["biosr_sr", "imagenet_sr"] + cmap = "gray" if task in ["biosr_phase", "imagenet_phase"] else None cumulative_loss += train_on_batch_cvdm( batch_x, batch_y, joint_model, diff_inp=diff_inp @@ -118,9 +131,10 @@ def main() -> None: log_loss(run=run, avg_loss=cumulative_loss / (step + 1), prefix="train") if step % checkpoint_freq == 0: - save_weighs( + save_weights( run=run, model=joint_model, + mu_model=mu_model, step=step, output_path=output_path, run_id=run_id, @@ -132,7 +146,9 @@ def main() -> None: batch_y.numpy(), noise_model, schedule_model, + mu_model, generation_timesteps, + diff_inp, task, ) log_metrics(run, metrics, prefix="train") @@ -143,11 +159,14 @@ def main() -> None: output_path=output_path, run_id=run_id, prefix="train", - cmap=cmap + cmap=cmap, ) if step % val_freq == 0: - val_loss = np.zeros(5) + if model_config.zmd: + val_loss = np.zeros(6) + else: + val_loss = np.zeros(5) for batch in val_dataset: batch_x, batch_y = batch model_input = prepare_model_input( @@ -158,7 +177,7 @@ def main() -> None: ) log_loss(run=run, avg_loss=val_loss, prefix="val") - + # To speed up, images are only generated and metrics are calculated only for one batch. random_batch = val_dataset.take(1) for batch_x, batch_y in random_batch: output_montage, metrics = obtain_output_montage_and_metrics( @@ -166,7 +185,9 @@ def main() -> None: batch_y.numpy(), noise_model, schedule_model, + mu_model, generation_timesteps, + diff_inp, task, ) log_metrics(run, metrics, prefix="val") @@ -177,7 +198,7 @@ def main() -> None: output_path=output_path, run_id=run_id, prefix="val", - cmap=cmap + cmap=cmap, ) step += 1 diff --git a/setup.py b/setup.py deleted file mode 100644 index 6cf86e4..0000000 --- a/setup.py +++ /dev/null @@ -1,13 +0,0 @@ -from setuptools import find_packages, setup - -setup( - name="cvdm", - version="0.1", - description="Code for Conditional Variational Diffusion Models", - url="TODO", - author="Gabriel Della Maggiora, Luis Alberto Croquevielle, Maria Wyrzykowska, Nikita Deshpande, Harry Horsley, Thomas Heinis, Artur Yakimovich", - author_email="m.wyrzykowska@hzdr.de", - license="MIT", - packages=find_packages(), - zip_safe=False, -)