-
Notifications
You must be signed in to change notification settings - Fork 1
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Merge pull request #1 from casus/master
Add the code
- Loading branch information
Showing
45 changed files
with
2,219 additions
and
1 deletion.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,46 @@ | ||
__pycache__/ | ||
*.py[cod] | ||
*$py.class | ||
|
||
# Distribution / packaging | ||
.Python | ||
build/ | ||
develop-eggs/ | ||
dist/ | ||
downloads/ | ||
eggs/ | ||
.eggs/ | ||
lib/ | ||
lib64/ | ||
parts/ | ||
sdist/ | ||
var/ | ||
wheels/ | ||
share/python-wheels/ | ||
*.egg-info/ | ||
.installed.cfg | ||
*.egg | ||
MANIFEST | ||
|
||
# Jupyter Notebook | ||
.ipynb_checkpoints | ||
|
||
# PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm | ||
__pypackages__/ | ||
|
||
# Environments | ||
.env | ||
.venv | ||
env/ | ||
venv/ | ||
|
||
# mypy files | ||
.mypy_cache/ | ||
|
||
# Visual Studio & PyCharm files | ||
.idea/ | ||
.vscode/ | ||
|
||
# Data & models | ||
/data/ | ||
models/ |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1 +1,47 @@ | ||
# cvdm | ||
# Conditional Variational Diffusion Models | ||
|
||
This code implements the Conditional Variational Diffusion Models as described [in the paper](https://arxiv.org/abs/2312.02246). | ||
|
||
## Where to get the data? | ||
|
||
The datasets that we are using are available online: | ||
- [BioSR](https://github.com/qc17-THU/DL-SR); the data we use has been converted to .npy files | ||
- [ImageNet from ILSVRC2012](https://www.image-net.org/challenges/LSVRC/2012/) | ||
- [HCOCO](https://github.com/bcmi/Image-Harmonization-Dataset-iHarmony4?tab=readme-ov-file) - only used in model evaluation | ||
|
||
It is assumed that for: | ||
- BioSR super-resolution task, data can be found in the directory specified as dataset_path in configs/biosr.yaml, in two files, x.npy (input) and y.npy (ground truth) | ||
- BioSR phase task, data can be found in the directory specified as dataset_path in configs/biosr_phase.yaml, in one file, y.npy (ground truth). Input to the model will be generated based on the ground truth. | ||
- ImageNet super-resolution task, data can be found in the directory specified as dataset_path in configs/imagenet_sr.yaml as a collection of JPEG files. Input to the model will be generated based on the ground truth. | ||
- ImageNet phase task, data can be found in the directory specified as dataset_path in configs/imagenet_phase.yaml as a collection of JPEG files. Input to the model will be generated based on the ground truth. | ||
- HCOCO phase evaluation task, data can be found in the directory specified as dataset_path in configs/hcoco_phase_eval.yaml as a collection of JPEG files. Input to the model will be generated based on the ground truth. | ||
|
||
## How to prepare the environment? | ||
|
||
Run the following code: | ||
``` | ||
conda create -n myenv python=3.10 | ||
conda activate myenv | ||
pip install -r requirements.txt | ||
pip install -e . | ||
``` | ||
|
||
## How to run the training code? | ||
|
||
1. Download the data. | ||
2. Modify the config in the `configs/` directory with the path to the data you want to use and the directory for outputs. | ||
3. Run the code from the root directory: `python scripts/train.py --config-path $PATH_TO_CONFIG --neptune-token $NEPTUNE_TOKEN`. | ||
|
||
`--neptune-token` argument is optional. | ||
|
||
## How to run the evaluation code? | ||
|
||
1. Download the data. | ||
2. Modify the config in the `configs/` directory with the path to the data you want to use and the directory for outputs. | ||
3. Run the code from the root directory: `python scripts/eval.py --config-path $PATH_TO_CONFIG --neptune-token $NEPTUNE_TOKEN`. | ||
|
||
`--neptune-token` argument is optional. | ||
|
||
## License | ||
This repository is released under the MIT License (refer to the LICENSE file for details). | ||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,30 @@ | ||
task: "biosr_sr" | ||
|
||
model: | ||
noise_model_type: "unet" | ||
alpha: 0.001 | ||
load_weights: null | ||
snr_expansion_n: 1 | ||
|
||
training: | ||
lr: 0.0001 | ||
epochs: 10 | ||
|
||
eval: | ||
output_path: "outputs/biosr" | ||
generation_timesteps: 1000 | ||
checkpoint_freq: 1000 | ||
log_freq: 10 | ||
image_freq: 100 | ||
val_freq: 200 | ||
val_len: 100 | ||
|
||
data: | ||
dataset_path: "/bigdata/casus/MLID/maria/biosr_sample" | ||
n_samples: 100 | ||
batch_size: 2 | ||
im_size: 256 | ||
|
||
neptune: | ||
name: "Virtual_Stain" | ||
project: "mlid/test" |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,30 @@ | ||
task: "biosr_phase" | ||
|
||
model: | ||
noise_model_type: "unet" | ||
alpha: 0.001 | ||
load_weights: null | ||
snr_expansion_n: 1 | ||
|
||
training: | ||
lr: 0.0001 | ||
epochs: 10 | ||
|
||
eval: | ||
output_path: "outputs/biosr" | ||
generation_timesteps: 1000 | ||
checkpoint_freq: 1000 | ||
log_freq: 10 | ||
image_freq: 100 | ||
val_freq: 200 | ||
val_len: 100 | ||
|
||
data: | ||
dataset_path: "/bigdata/casus/MLID/maria/biosr_sample" | ||
n_samples: 100 | ||
batch_size: 2 | ||
im_size: 256 | ||
|
||
neptune: | ||
name: "Virtual_Stain" | ||
project: "mlid/test" |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,30 @@ | ||
task: "hcoco_phase" | ||
|
||
model: | ||
noise_model_type: "unet" | ||
alpha: 0.001 | ||
load_weights: null | ||
snr_expansion_n: 1 | ||
|
||
training: | ||
lr: 0.0001 | ||
epochs: 100 | ||
|
||
eval: | ||
output_path: "outputs/hcoco" | ||
generation_timesteps: 1000 | ||
checkpoint_freq: 1000 | ||
log_freq: 10 | ||
image_freq: 100 | ||
val_freq: 200 | ||
val_len: 100 | ||
|
||
data: | ||
dataset_path: "/bigdata/casus/MLID/maria/hcoco_sample" | ||
n_samples: 100 | ||
batch_size: 1 | ||
im_size: 256 | ||
|
||
neptune: | ||
name: "Virtual_Stain" | ||
project: "mlid/test" |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,30 @@ | ||
task: "imagenet_phase" | ||
|
||
model: | ||
noise_model_type: "unet" | ||
alpha: 0.001 | ||
load_weights: null | ||
snr_expansion_n: 1 | ||
|
||
training: | ||
lr: 0.0001 | ||
epochs: 100 | ||
|
||
eval: | ||
output_path: "outputs/imagenet" | ||
generation_timesteps: 1000 | ||
checkpoint_freq: 1000 | ||
log_freq: 10 | ||
image_freq: 100 | ||
val_freq: 200 | ||
val_len: 100 | ||
|
||
data: | ||
dataset_path: "/bigdata/casus/MLID/maria/imagenet_sample" | ||
n_samples: 100 | ||
batch_size: 1 | ||
im_size: 256 | ||
|
||
neptune: | ||
name: "Virtual_Stain" | ||
project: "mlid/test" |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,29 @@ | ||
task: "imagenet_sr" | ||
model: | ||
noise_model_type: "unet" | ||
alpha: 0.001 | ||
load_weights: null | ||
snr_expansion_n: 1 | ||
|
||
training: | ||
lr: 0.0001 | ||
epochs: 100 | ||
|
||
eval: | ||
output_path: "outputs/imagenet" | ||
generation_timesteps: 1000 | ||
checkpoint_freq: 1000 | ||
log_freq: 10 | ||
image_freq: 100 | ||
val_freq: 200 | ||
val_len: 100 | ||
|
||
data: | ||
dataset_path: "/bigdata/casus/MLID/maria/imagenet_sample" | ||
n_samples: 100 | ||
batch_size: 2 | ||
im_size: 256 | ||
|
||
neptune: | ||
name: "Virtual_Stain" | ||
project: "mlid/test" |
Empty file.
Empty file.
Empty file.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,64 @@ | ||
from __future__ import annotations | ||
|
||
from typing import Any, Dict | ||
|
||
import numpy as np | ||
import tensorflow as tf | ||
from tensorflow.keras.initializers import VarianceScaling | ||
from tensorflow.keras.layers import Attention, Layer | ||
from tensorflow_addons.layers import GroupNormalization | ||
|
||
|
||
class AttentionVectorLayer(Layer):
    """
    Learned linear projection used to build the query, key or value
    for self-attention from the feature map.

    The last axis of the input is contracted with a square
    ``(n_channels, n_channels)`` weight matrix and a per-channel bias is
    added, so the output has the same shape as the input.
    """

    def __init__(self, **kwargs: Any) -> None:
        super().__init__(**kwargs)

    def build(self, input_shape: tf.TensorShape) -> None:
        """
        Create the projection weights.

        Args:
            input_shape: shape of the incoming feature map; only its last
                entry (the number of channels) is used.
        """
        self.n_channels = input_shape[-1]
        self.w = self.add_weight(
            shape=(self.n_channels, self.n_channels),
            initializer=VarianceScaling(
                scale=1.0, mode="fan_avg", distribution="uniform"
            ),
            trainable=True,
            name="attention_w",
        )
        self.b = self.add_weight(
            shape=(self.n_channels,),
            initializer="zero",
            trainable=True,
            name="attention_b",
        )

    def call(self, x: tf.Tensor) -> tf.Tensor:
        """Return ``x @ w + b``, contracting the last axis of ``x``."""
        out: tf.Tensor = tf.tensordot(x, self.w, 1) + self.b
        return out

    def get_config(self) -> Dict[str, Any]:
        """
        Return the serialisable configuration of the layer.

        The layer has no constructor arguments of its own, but the base
        ``Layer`` config (``name``, ``trainable``, ``dtype``) must be kept;
        returning an empty dict would discard it whenever a model is
        saved and reloaded.
        """
        config: Dict[str, Any] = super().get_config()
        return config

    @classmethod
    def from_config(cls, config: Dict[str, Any]) -> AttentionVectorLayer:
        """Rebuild the layer from a dict produced by ``get_config``."""
        return cls(**config)
|
||
|
||
def attention_block(x: tf.Tensor) -> tf.Tensor:
    """
    Self-attention block, as mentioned in
    https://arxiv.org/pdf/1809.11096.pdf

    The input is group-normalised (32 groups over the last axis), then
    projected into query, value and key tensors by three independent
    ``AttentionVectorLayer`` instances and fed to a Keras ``Attention``
    layer. The attention output is added to the *normalised* input
    (not to the raw input) and returned.
    """
    normalised = GroupNormalization(groups=32, axis=-1)(x)

    # The projection layers are created in the same order as before
    # (query, value, key), so layer creation order is unchanged.
    query = AttentionVectorLayer()(normalised)
    value = AttentionVectorLayer()(normalised)
    key = AttentionVectorLayer()(normalised)

    # Keras `Attention` expects its inputs as [query, value, key].
    attended: tf.Tensor = Attention()([query, value, key])

    return normalised + attended
74 changes: 74 additions & 0 deletions
74
cvdm/architectures/components/conditional_instance_normalization.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,74 @@ | ||
from __future__ import annotations | ||
|
||
from collections.abc import Iterable | ||
from typing import Any, Dict | ||
|
||
import numpy as np | ||
import tensorflow as tf | ||
from tensorflow.keras.initializers import RandomNormal | ||
from tensorflow.keras.layers import Layer | ||
|
||
|
||
class ConditionalInstanceNormalization(Layer):
    """
    The goal of conditional instance normalization is to make the
    model aware of the amount of noise to be removed (i.e how many
    steps in the noise diffusion process are being considered). The
    implementation was informed by the appendix in
    https://arxiv.org/pdf/1907.05600.pdf and implementation at
    https://github.com/ermongroup/ncsn/blob/master/models/cond_refinenet_dilated.py

    The layer takes a pair ``(x, noise_embedding)`` and returns a tensor
    computed as::

        gamma * (x - mu) / sigma + beta + alpha * (mu - m) / v

    where ``mu`` / ``sigma`` are the mean / standard deviation of ``x``
    over axes 1 and 2, ``m`` / ``v`` are the mean / standard deviation of
    ``mu`` over the last axis, and ``gamma``, ``beta``, ``alpha`` are
    linear projections of ``noise_embedding`` through the weights
    ``w1``, ``b`` and ``w2`` respectively.
    """

    def __init__(self, **kwargs: Any) -> None:
        super().__init__(**kwargs)

    def build(self, input_shape: Iterable[tf.TensorShape]) -> None:
        """
        Create the three projection matrices.

        Args:
            input_shape: pair of shapes, one for the feature map and one
                for the noise embedding. The feature-map shape is read at
                indices 0..3 as (batch, height, width, channels); the
                embedding shape is read at index 1 for its dimension.
        """
        feature_shape, embedding_shape = input_shape
        self.batch_size = feature_shape[0]
        self.height = feature_shape[1]
        self.width = feature_shape[2]
        self.n_channels = feature_shape[3]
        self.embedding_dim = embedding_shape[1]

        projection_shape = (self.embedding_dim, self.n_channels)
        # Weights are created in the order w1, b, w2 on purpose: the
        # order defines the layer's weight list.
        self.w1 = self.add_weight(
            shape=projection_shape,
            initializer=RandomNormal(mean=1.0, stddev=0.02),
            trainable=True,
            name="conditional_w1",
        )
        self.b = self.add_weight(
            shape=projection_shape,
            initializer="zero",
            trainable=True,
            name="conditional_b",
        )
        self.w2 = self.add_weight(
            shape=projection_shape,
            initializer=RandomNormal(mean=1.0, stddev=0.02),
            trainable=True,
            name="conditional_w2",
        )

    @staticmethod
    def _project(noise_embedding: tf.Tensor, weight: tf.Tensor) -> tf.Tensor:
        """Project the embedding through ``weight`` and insert two unit
        axes at position 1 so the result broadcasts against the feature map."""
        projected = tf.tensordot(noise_embedding, weight, 1)
        return tf.expand_dims(tf.expand_dims(projected, 1), 1)

    def call(self, inputs: Iterable[tf.Tensor]) -> tf.Tensor:
        """
        Apply the noise-conditioned normalization.

        Args:
            inputs: pair ``(x, noise_embedding)``.

        Returns:
            The normalized, scaled and shifted feature map.
        """
        x, noise_embedding = inputs

        # Per-sample, per-channel statistics over axes 1 and 2; the small
        # constant keeps the divisions below away from zero.
        feature_map_means = tf.math.reduce_mean(x, axis=(1, 2), keepdims=True)
        feature_map_std_dev = tf.math.reduce_std(x, axis=(1, 2), keepdims=True) + 1e-5

        # Statistics of the per-channel means, taken across the last axis.
        m = tf.math.reduce_mean(feature_map_means, axis=-1, keepdims=True)
        v = tf.math.reduce_std(feature_map_means, axis=-1, keepdims=True) + 1e-5

        gamma = self._project(noise_embedding, self.w1)
        beta = self._project(noise_embedding, self.b)
        alpha = self._project(noise_embedding, self.w2)

        instance_norm = (x - feature_map_means) / feature_map_std_dev
        x = gamma * instance_norm + beta + alpha * (feature_map_means - m) / v
        return x

    def get_config(self) -> Dict[str, Any]:
        """
        Return the serialisable configuration of the layer.

        The layer has no constructor arguments of its own, but the base
        ``Layer`` config (``name``, ``trainable``, ``dtype``) must be kept;
        returning an empty dict would discard it whenever a model is
        saved and reloaded.
        """
        config: Dict[str, Any] = super().get_config()
        return config

    @classmethod
    def from_config(cls, config: Dict[str, Any]) -> ConditionalInstanceNormalization:
        """Rebuild the layer from a dict produced by ``get_config``."""
        return cls(**config)
Oops, something went wrong.