Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Adding simple GAN network #915

Open
wants to merge 38 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
38 commits
Select commit Hold shift + click to select a range
5e7ce10
Add GAN model 25/7
Shubbair Jul 25, 2024
d426586
Updating GAN Code...
Shubbair Jul 26, 2024
591074b
Updating GAN Code...
Shubbair Jul 26, 2024
959c623
Updating GAN Code...
Shubbair Jul 26, 2024
f176cce
Updating GAN Code...
Shubbair Jul 26, 2024
a05608c
Updating GAN Code...
Shubbair Jul 26, 2024
147cb3d
Updating GAN Code...
Shubbair Jul 26, 2024
f8b7094
Updating GAN Code...
Shubbair Jul 26, 2024
8b17137
Updating GAN Code...
Shubbair Jul 26, 2024
88a20b7
Updating GAN Code...
Shubbair Jul 27, 2024
3716501
Updating GAN Code...
Shubbair Jul 28, 2024
3e63cd9
Updating GAN Code...
Shubbair Jul 28, 2024
d17d293
Updating GAN Code...
Shubbair Jul 28, 2024
c0c8293
Updating GAN Code...
Shubbair Jul 28, 2024
a07ef6d
Updating GAN Code...
Shubbair Jul 28, 2024
4de0583
Updating GAN Code...
Shubbair Jul 28, 2024
8d27be1
Updating GAN Code...
Shubbair Jul 28, 2024
bacaa9e
Updating GAN Code...
Shubbair Jul 28, 2024
306e53c
Updating GAN Code...
Shubbair Jul 29, 2024
4e80759
Updating GAN Code...
Shubbair Jul 29, 2024
f505fe6
Updating GAN Code...
Shubbair Jul 29, 2024
7fea34d
Updating GAN Code...
Shubbair Jul 29, 2024
7438b54
Updating GAN Code...
Shubbair Jul 29, 2024
1e386b5
Updating GAN Code...
Shubbair Jul 29, 2024
ba52447
Updating GAN Code...
Shubbair Jul 30, 2024
c2d731d
Updating GAN Code...
Shubbair Jul 30, 2024
3bea855
Updating GAN Code...
Shubbair Jul 30, 2024
ad2b664
Updating GAN Code...
Shubbair Jul 30, 2024
0644cc1
Updating MLX Notebook
Shubbair Jul 30, 2024
6f7a660
Updating MLX Notebook
Shubbair Jul 30, 2024
f70cef9
Updating GAN Code...
Shubbair Jul 31, 2024
a8ffa9c
Updating GAN Code...
Shubbair Jul 31, 2024
1ef3ad2
Updating GAN Code...
Shubbair Jul 31, 2024
4d17f80
Updating GAN Code...
Shubbair Jul 31, 2024
37bbf3e
Updating GAN Code...
Shubbair Jul 31, 2024
7e0bdac
Code Arrangement
Shubbair Aug 1, 2024
f84b231
Code Arrangement
Shubbair Aug 1, 2024
a5752be
Code Arrangement
Shubbair Aug 1, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Binary file added gan/gen_images/img_0.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file added gan/gen_images/img_100.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file added gan/gen_images/img_200.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file added gan/gen_images/img_300.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file added gan/gen_images/img_400.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
195 changes: 195 additions & 0 deletions gan/main.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,195 @@
import argparse
from collections.abc import Iterator

import matplotlib.pyplot as plt
import mlx.core as mx
import mlx.nn as nn
import mlx.optimizers as optim
import numpy as np
from tqdm import tqdm

import mnist

# Generator Block
def GenBlock(in_dim: int, out_dim: int):
    """Build one generator stage: Linear -> BatchNorm -> LeakyReLU.

    Args:
        in_dim: Number of input features of the linear layer.
        out_dim: Number of output features (also the BatchNorm width).

    Returns:
        An ``nn.Sequential`` holding the three layers in order.
    """
    stage = [
        nn.Linear(in_dim, out_dim),
        # NOTE(review): the second positional argument of mlx.nn.BatchNorm is
        # `eps`, so 0.8 acts as epsilon here, not as momentum. Kept as-is to
        # preserve behaviour -- confirm which one was intended.
        nn.BatchNorm(out_dim, eps=0.8),
        nn.LeakyReLU(negative_slope=0.2),
    ]
    return nn.Sequential(*stage)

# Generator Model
class Generator(nn.Module):
    """Fully connected generator mapping a noise vector to a flat image.

    Three ``GenBlock`` stages widen the features from ``hidden_dim`` to
    ``hidden_dim * 4``; a final linear layer projects to ``im_dim`` and the
    result is squashed with ``tanh``.
    """

    def __init__(self, z_dim: int = 32, im_dim: int = 784, hidden_dim: int = 256):
        super().__init__()

        # Feature widths at the boundaries of the three GenBlock stages.
        widths = [z_dim, hidden_dim, hidden_dim * 2, hidden_dim * 4]
        stages = [GenBlock(n_in, n_out) for n_in, n_out in zip(widths, widths[1:])]

        self.gen = nn.Sequential(*stages, nn.Linear(widths[-1], im_dim))

    def __call__(self, noise):
        """Return ``tanh`` of the network output for ``noise``."""
        return mx.tanh(self.gen(noise))

# make 2D noise with shape n_samples x z_dim
def get_noise(n_samples: int, z_dim: int) -> mx.array:
    """Sample a 2D batch of noise from ``mx.random.normal``.

    Args:
        n_samples: Number of noise vectors (rows) to draw.
        z_dim: Length of each noise vector (columns).

    Returns:
        An ``mx.array`` of shape ``(n_samples, z_dim)``.
    """
    return mx.random.normal(shape=(n_samples, z_dim))

#---------------------------------------------#

# Discriminator Block
def DisBlock(in_dim: int, out_dim: int):
    """Build one discriminator stage: Linear -> LeakyReLU -> Dropout.

    Args:
        in_dim: Number of input features of the linear layer.
        out_dim: Number of output features of the linear layer.

    Returns:
        An ``nn.Sequential`` holding the three layers in order.
    """
    stage = (
        nn.Linear(in_dim, out_dim),
        nn.LeakyReLU(negative_slope=0.2),
        nn.Dropout(p=0.3),
    )
    return nn.Sequential(*stage)

# Discriminator Model
class Discriminator(nn.Module):
    """Fully connected discriminator producing one real/fake logit per image.

    Three ``DisBlock`` stages narrow the features from ``hidden_dim * 4`` down
    to ``hidden_dim``; a final linear layer emits a single raw score.
    """

    def __init__(self, im_dim: int = 784, hidden_dim: int = 256):
        super().__init__()

        self.disc = nn.Sequential(
            DisBlock(im_dim, hidden_dim * 4),
            DisBlock(hidden_dim * 4, hidden_dim * 2),
            DisBlock(hidden_dim * 2, hidden_dim),
            # The output is deliberately left as a raw logit: disc_loss and
            # gen_loss call binary_cross_entropy(..., with_logits=True), which
            # applies the sigmoid itself. A Sigmoid layer here would squash
            # the score twice.
            nn.Linear(hidden_dim, 1),
        )

    def __call__(self, noise):
        """Return the raw (pre-sigmoid) score for each row of ``noise``.

        The parameter keeps its historical name; it is the batch of flattened
        images to classify.
        """
        return self.disc(noise)

# Discriminator Loss
def disc_loss(gen, disc, real, num_images, z_dim):
    """Compute the discriminator loss on one fake batch and one real batch.

    Args:
        gen: Generator used to produce fake images from noise.
        disc: Discriminator being scored.
        real: Batch of real images; its first dimension sets the number of
            real labels.
        num_images: Number of fake images to generate.
        z_dim: Length of each noise vector fed to ``gen``.

    Returns:
        The average of the binary cross-entropy on fake images (target 0)
        and on real images (target 1). Discriminator outputs are passed with
        ``with_logits=True``.
    """
    # get_noise and the model calls already return MLX arrays, so no extra
    # mx.array(...) wrapping is needed.
    fake_images = gen(get_noise(num_images, z_dim))
    fake_scores = disc(fake_images)
    fake_targets = mx.zeros((fake_images.shape[0], 1))
    fake_loss = mx.mean(
        nn.losses.binary_cross_entropy(fake_scores, fake_targets, with_logits=True)
    )

    real_scores = disc(real)
    real_targets = mx.ones((real.shape[0], 1))
    real_loss = mx.mean(
        nn.losses.binary_cross_entropy(real_scores, real_targets, with_logits=True)
    )

    return (fake_loss + real_loss) / 2.0

# Generator Loss
def gen_loss(gen, disc, num_images, z_dim):
    """Compute the generator loss: how well fake images pass as real.

    Args:
        gen: Generator being scored.
        disc: Discriminator used to judge the generated images.
        num_images: Number of fake images to generate.
        z_dim: Length of each noise vector fed to ``gen``.

    Returns:
        Mean binary cross-entropy between the discriminator's output on the
        generated images and an all-ones target. Discriminator outputs are
        passed with ``with_logits=True``.
    """
    # get_noise and the model calls already return MLX arrays, so no extra
    # mx.array(...) wrapping is needed.
    fake_images = gen(get_noise(num_images, z_dim))
    fake_scores = disc(fake_images)

    # The generator wants its fakes labelled "real", hence targets of 1.
    real_targets = mx.ones((fake_images.shape[0], 1))

    loss = nn.losses.binary_cross_entropy(fake_scores, real_targets, with_logits=True)
    return mx.mean(loss)

# make batch of images
def batch_iterate(batch_size: int, ipt: np.ndarray) -> Iterator[np.ndarray]:
    """Yield shuffled mini-batches of ``ipt``.

    A fresh random permutation of the row indices is drawn with
    ``np.random.permutation`` on each call, and ``ipt`` is sliced along its
    first axis. The last batch may be smaller than ``batch_size``.

    Args:
        batch_size: Maximum number of rows per batch; must be at least 1.
        ipt: Array supporting integer-array indexing (e.g. a NumPy array).

    Yields:
        ``ipt[ids]`` for each consecutive slice ``ids`` of the permutation.

    Raises:
        ValueError: If ``batch_size`` is smaller than 1. Raised when iteration
            starts, because this is a generator.
    """
    if batch_size < 1:
        # A negative step would make range() empty and silently yield nothing.
        raise ValueError(f"batch_size must be a positive integer, got {batch_size}")

    n_rows = len(ipt)
    perm = np.random.permutation(n_rows)
    for start in range(0, n_rows, batch_size):
        yield ipt[perm[start : start + batch_size]]

# plot batch of images at epoch steps
def show_images(epoch_num: int, imgs: mx.array, num_imgs: int = 25):
    """Plot generated images on a 5x5 grid, save the figure, and display it.

    Args:
        epoch_num: Epoch number, used in the output file name
            ``gen_images/img_<epoch_num>.png``.
        imgs: Batch of images; each ``imgs[i]`` is reshaped to ``(28, 28)``.
        num_imgs: Maximum number of images to draw. The grid holds at most 25;
            unused cells are left blank.

    Does nothing when ``imgs`` is empty.
    """
    import os

    if imgs.shape[0] == 0:
        return

    fig, axes = plt.subplots(5, 5, figsize=(5, 5))
    cells = list(axes.flat)

    # Never index past the available images or the grid size.
    n_show = min(num_imgs, imgs.shape[0], len(cells))

    for i, ax in enumerate(cells):
        ax.axis('off')
        if i < n_show:
            ax.imshow(np.array(imgs[i]).reshape(28, 28), cmap='gray')

    plt.tight_layout()

    # savefig fails if the target directory is missing.
    os.makedirs('gen_images', exist_ok=True)
    plt.savefig('gen_images/img_{}.png'.format(epoch_num))
    plt.show()

    # Release the figure so repeated calls during training do not pile up.
    plt.close(fig)

def main(args: argparse.Namespace):
    """Train the GAN and periodically save a grid of generated samples.

    Args:
        args: Parsed command-line arguments. ``args.dataset`` names the loader
            function to call on the ``mnist`` module; when the attribute is
            missing, ``"mnist"`` is used.

    Hyperparameters (seed, epochs, noise size, batch size, learning rate) are
    fixed inside the function. Every 100 epochs the last batch's losses are
    printed and a batch of generated images is passed to ``show_images``.
    """
    seed = 42
    n_epochs = 500
    z_dim = 128
    batch_size = 128
    lr = 2e-5

    mx.random.seed(seed)

    # Load the data, honouring --dataset instead of always loading "mnist".
    dataset = getattr(args, "dataset", "mnist")
    train_images, *_ = map(np.array, getattr(mnist, dataset)())

    # Rescale images from [0, 1] to [-1, 1] to match the generator's tanh.
    train_images = train_images * 2.0 - 1.0

    gen = Generator(z_dim)
    mx.eval(gen.parameters())
    gen_opt = optim.Adam(learning_rate=lr, betas=[0.5, 0.999])

    disc = Discriminator()
    mx.eval(disc.parameters())
    disc_opt = optim.Adam(learning_rate=lr, betas=[0.5, 0.999])

    # Each callable returns (loss, grads) with grads taken w.r.t. the model
    # passed as the first argument to nn.value_and_grad.
    D_loss_grad = nn.value_and_grad(disc, disc_loss)
    G_loss_grad = nn.value_and_grad(gen, gen_loss)

    for epoch in tqdm(range(n_epochs)):

        for idx, real in enumerate(batch_iterate(batch_size, train_images)):

            # Discriminator step
            D_loss, D_grads = D_loss_grad(gen, disc, mx.array(real), batch_size, z_dim)
            disc_opt.update(disc, D_grads)
            # Force evaluation of the updated parameters and optimizer state.
            mx.eval(disc.parameters(), disc_opt.state)

            # Generator step
            G_loss, G_grads = G_loss_grad(gen, disc, batch_size, z_dim)
            gen_opt.update(gen, G_grads)
            # Force evaluation of the updated parameters and optimizer state.
            mx.eval(gen.parameters(), gen_opt.state)

        if epoch % 100 == 0:
            # Losses and idx refer to the last batch of this epoch.
            print("Epoch: {}, iteration: {}, Discriminator Loss:{}, Generator Loss: {}".format(epoch,idx,D_loss,G_loss))
            fake = gen(get_noise(batch_size, z_dim))
            show_images(epoch, fake)

if __name__ == "__main__":
    # Pass the sentence as `description`; the first positional argument of
    # ArgumentParser is `prog`, which would put it in the usage line as the
    # program name.
    parser = argparse.ArgumentParser(
        description="Train a simple GAN on MNIST with MLX."
    )
    parser.add_argument("--gpu", action="store_true", help="Use the Metal back-end.")
    parser.add_argument(
        "--dataset",
        type=str,
        default="mnist",
        choices=["mnist", "fashion_mnist"],
        help="The dataset to use.",
    )
    args = parser.parse_args()

    # Run on the CPU unless --gpu was given.
    if not args.gpu:
        mx.set_default_device(mx.cpu)
    main(args)
83 changes: 83 additions & 0 deletions gan/mnist.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,83 @@
# Copyright © 2023 Apple Inc.

import gzip
import os
import pickle
from urllib import request

import numpy as np


def mnist(
    save_dir="/tmp",
    base_url="https://raw.githubusercontent.com/fgnt/mnist/master/",
    filename="mnist.pkl",
):
    """
    Load the MNIST dataset in 4 tensors: train images, train labels,
    test images, and test labels.

    Checks `save_dir` for already downloaded data otherwise downloads.
    Both the raw ``.gz`` archives and the pickled cache are stored in
    `save_dir`.

    Args:
        save_dir: Directory holding the cache file and the downloads.
        base_url: URL prefix the four archive names are appended to.
        filename: Name of the pickled cache file inside `save_dir`.

    Returns:
        A tuple ``(train_images, train_labels, test_images, test_labels)``.
        Images are float32 arrays scaled by 1/255; labels are uint32.

    Download code modified from:
    https://github.com/hsjeong5/MNIST-for-Numpy
    """

    def download_and_save(save_file):
        # (dictionary key, remote archive name); images first, labels last.
        archives = [
            ["training_images", "train-images-idx3-ubyte.gz"],
            ["test_images", "t10k-images-idx3-ubyte.gz"],
            ["training_labels", "train-labels-idx1-ubyte.gz"],
            ["test_labels", "t10k-labels-idx1-ubyte.gz"],
        ]

        # Download next to the cache file rather than a hard-coded /tmp.
        os.makedirs(save_dir, exist_ok=True)

        arrays = {}
        for _, archive in archives:
            request.urlretrieve(base_url + archive, os.path.join(save_dir, archive))
        for key, archive in archives[:2]:
            with gzip.open(os.path.join(save_dir, archive), "rb") as f:
                # Skip the 16 leading bytes, then one row of 28 * 28 per image.
                arrays[key] = np.frombuffer(f.read(), np.uint8, offset=16).reshape(
                    -1, 28 * 28
                )
        for key, archive in archives[-2:]:
            with gzip.open(os.path.join(save_dir, archive), "rb") as f:
                # Skip the 8 leading bytes of the label files.
                arrays[key] = np.frombuffer(f.read(), np.uint8, offset=8)
        with open(save_file, "wb") as f:
            pickle.dump(arrays, f)

    save_file = os.path.join(save_dir, filename)
    if not os.path.exists(save_file):
        download_and_save(save_file)

    # NOTE: pickle.load can execute arbitrary code; only point `save_dir` at a
    # location whose contents you trust.
    with open(save_file, "rb") as f:
        data = pickle.load(f)

    def preproc(x):
        return x.astype(np.float32) / 255.0

    return (
        preproc(data["training_images"]),
        data["training_labels"].astype(np.uint32),
        preproc(data["test_images"]),
        data["test_labels"].astype(np.uint32),
    )


def fashion_mnist(save_dir="/tmp"):
    """Load Fashion-MNIST through ``mnist`` with its own URL and cache name.

    Args:
        save_dir: Directory forwarded to ``mnist`` for the cached data.

    Returns:
        Whatever ``mnist`` returns for the Fashion-MNIST source.
    """
    fashion_url = "http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/"
    cache_name = "fashion_mnist.pkl"
    return mnist(save_dir, base_url=fashion_url, filename=cache_name)


if __name__ == "__main__":
    # Self-check: load the default dataset and verify the shape of each split.
    train_x, train_y, test_x, test_y = mnist()

    expected = (
        (train_x, (60000, 28 * 28), "Wrong training set size"),
        (train_y, (60000,), "Wrong training set size"),
        (test_x, (10000, 28 * 28), "Wrong test set size"),
        (test_y, (10000,), "Wrong test set size"),
    )
    for split, shape, message in expected:
        assert split.shape == shape, message
Loading