Pruning, save on GPU, test on CPU. #152

Open · wants to merge 55 commits into master

Commits (55)
318e504
Add args input and pruning
bryanbocao Jun 22, 2022
38912e4
Update args and pruning reference
bryanbocao Jun 22, 2022
713fa09
Add commands with args
bryanbocao Jun 22, 2022
2937365
Add nohup version
bryanbocao Jun 22, 2022
b782bba
--train --test with action | Save on GPU, test on CPU
bryanbocao Jun 24, 2022
aea8bde
--train --test with action
bryanbocao Jun 24, 2022
b9f693a
Test on CPU
bryanbocao Jun 24, 2022
3102ecf
Clean main.py
bryanbocao Jun 24, 2022
c85353c
Print summary of net
bryanbocao Jun 24, 2022
18edadd
Print layer params
bryanbocao Jun 24, 2022
ba1df42
Add Trained Weights
bryanbocao Jun 30, 2022
b4cbcd2
test_sim.py
bryanbocao Jul 2, 2022
10d8828
main_n_cls.py
bryanbocao Jul 7, 2022
16cb8ab
Start training for n_cls experiments
bryanbocao Jul 7, 2022
f8a84cf
Save n_cls checkpoints
bryanbocao Jul 7, 2022
40399b7
Merge branch 'wip' of https://github.com/bryanbo-cao/pytorch-cifar in…
bryanbocao Jul 7, 2022
383d959
main_n_cls_nohup.py
bryanbocao Jul 7, 2022
c2c7714
Check if inputs are empty before feeding into the network
bryanbocao Jul 7, 2022
f978249
Check if inputs are empty before feeding into the network
bryanbocao Jul 7, 2022
610e452
Correct num_classes layer for MobileNetV2
bryanbocao Jul 8, 2022
dad10af
Correct num_class layer for MobileNetV2
bryanbocao Jul 8, 2022
7ea9c17
Move inputs,targets to GPU
bryanbocao Jul 8, 2022
62a444a
Save models at epoch=0
bryanbocao Jul 8, 2022
a27d19c
Move net.to(device) after loading the weights
bryanbocao Jul 9, 2022
2e8d7ae
Save models
bryanbocao Jul 9, 2022
9ac704a
Save models
bryanbocao Jul 9, 2022
4ae4620
Load epoch
bryanbocao Jul 9, 2022
e007fe5
Add num_classes args for models
bryanbocao Jul 10, 2022
4aded52
rm __pycache__
bryanbocao Jul 10, 2022
5217d90
Add resnet num_classes
bryanbocao Jul 10, 2022
861f042
Create sweep_n_cls_0.sh
bryanbocao Jul 10, 2022
1113952
Create sweep_n_cls_1.sh
bryanbocao Jul 10, 2022
5ba5748
Add BasicBlock arg for simpleDLA
bryanbocao Jul 10, 2022
148b18b
Split works
bryanbocao Jul 10, 2022
1644d32
training and testing devices
bryanbocao Jul 10, 2022
bb0c72b
Add more bash scripts for training
bryanbocao Jul 11, 2022
fa96e89
Add CUDA_3
bryanbocao Jul 11, 2022
f048194
Add CUDA_3
bryanbocao Jul 11, 2022
d6aca36
Add arg num_classes in EfficientNetB0
bryanbocao Jul 11, 2022
e349cf0
Correct bash scripts
bryanbocao Jul 11, 2022
6bc1098
GoogLeNet CUDA_0
bryanbocao Jul 11, 2022
9adf370
Add CUDA
bryanbocao Jul 12, 2022
aa20d8e
Add num_class
bryanbocao Jul 13, 2022
ecac0bf
Save 10 classes by default
bryanbocao Jul 13, 2022
c3a1e13
Correct args
bryanbocao Jul 13, 2022
848a806
D and S groups
bryanbocao Jul 18, 2022
0f86d7f
Add num_class groups
bryanbocao Jul 18, 2022
c792b84
Add D2G1
bryanbocao Jul 18, 2022
2a98e5d
Correct class group
bryanbocao Jul 18, 2022
1c8bc3d
Convert target into [0..n] range
bryanbocao Jul 18, 2022
f55c39d
Add tensor
bryanbocao Jul 18, 2022
00fd6fb
Correct the bug of index
bryanbocao Jul 18, 2022
4fd9728
Correct the bug of index
bryanbocao Jul 18, 2022
5535497
Correct the bug of index
bryanbocao Jul 18, 2022
39b2f62
Add class group in checkpoint path
bryanbocao Jul 19, 2022
README.md: 23 changes (21 additions, 2 deletions)

@@ -9,11 +9,29 @@ I'm playing with [PyTorch](http://pytorch.org/) on the CIFAR10 dataset.
## Training
```
# Start training with:
python main.py
python main.py --net ResNet18 --train --test

# Start training for n_cls experiments:
python main_n_cls.py --net MobileNetV2 --train --test --num_class 5

# You can manually resume the training with:
python main.py --resume --lr=0.01
python main.py --net ResNet18 --train --test --resume --lr=0.01
```
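The `--num_class` run of main_n_cls.py above, together with commits 1c8bc3d ("Convert target into [0..n] range") and 848a806 ("D and S groups"), points at restricting CIFAR-10 to a subset of classes with contiguous labels. main_n_cls.py itself is not shown in this diff, so the following is only a minimal sketch of the idea; `class_group` is an illustrative choice, not the PR's actual grouping:

```
# Hypothetical sketch: keep a subset of CIFAR-10 classes and remap
# their labels into [0..n) so an n-way classifier head can be trained.
import torch
import torchvision
import torchvision.transforms as transforms

class_group = [3, 5, 7]  # illustrative subset, not the PR's grouping
remap = {c: i for i, c in enumerate(class_group)}  # original label -> [0..n)

trainset = torchvision.datasets.CIFAR10(
    root='./data', train=True, download=True, transform=transforms.ToTensor())

# Drop samples outside the group, remap the rest in place.
keep = [i for i, t in enumerate(trainset.targets) if t in remap]
for i in keep:
    trainset.targets[i] = remap[trainset.targets[i]]

subset = torch.utils.data.Subset(trainset, keep)
trainloader = torch.utils.data.DataLoader(subset, batch_size=128, shuffle=True)
```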

## Testing
```
# Test only on GPU
python main.py --net ResNet18 --test

# Test only on GPU with pruning (0.3)
python main.py --net ResNet18 --test --prune --pruning_rate 0.3

# Test only on CPU
python main.py --net ResNet18 --test --select_device cpu
```

## Trained Weights
[Google Drive](https://drive.google.com/drive/folders/1DRcb7uw1goot8doydHAc0ip3us5zjilk?usp=sharing)

## Accuracy
| Model | Acc. |
@@ -33,3 +51,4 @@ python main.py --resume --lr=0.01
| [DPN92](https://arxiv.org/abs/1707.01629) | 95.16% |
| [DLA](https://arxiv.org/pdf/1707.06484.pdf) | 95.47% |

Pruning [Reference Link](https://github.com/ultralytics/yolov5/blob/a2a1ed201d150343a4f9912d644be2b210206984/utils/torch_utils.py#L174)
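The PR's headline feature, saving checkpoints on GPU and testing them on CPU, comes down to `torch.load`'s `map_location` argument (see the `--select_device` handling in main.py below). A minimal sketch of the round trip, assuming the `{net}_ckpt.pth` checkpoint layout with `'net'`/`'acc'`/`'epoch'` keys that main.py saves:

```
# Minimal sketch: load a GPU-trained checkpoint onto CPU for testing.
import torch
from models import ResNet18  # assumes this repo's models package

device = 'cpu'  # e.g. selected via --select_device cpu
checkpoint = torch.load('./checkpoint/ResNet18_ckpt.pth', map_location=device)

# If the checkpoint was saved from a DataParallel wrapper, its keys
# carry a 'module.' prefix that must be stripped before loading.
state = {k[len('module.'):] if k.startswith('module.') else k: v
         for k, v in checkpoint['net'].items()}

net = ResNet18()
net.load_state_dict(state)
net.to(device).eval()
print('loaded epoch %s, acc %.2f%%' % (checkpoint['epoch'], checkpoint['acc']))
```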
main.py: 130 changes (105 additions, 25 deletions)

@@ -4,6 +4,7 @@
import torch.optim as optim
import torch.nn.functional as F
import torch.backends.cudnn as cudnn
from torchinfo import summary

import torchvision
import torchvision.transforms as transforms
@@ -13,17 +14,29 @@

from models import *
from utils import progress_bar

import time

parser = argparse.ArgumentParser(description='PyTorch CIFAR10 Training')
parser.add_argument('--lr', default=0.1, type=float, help='learning rate')
parser.add_argument('--resume', '-r', action='store_true',
help='resume from checkpoint')
parser.add_argument('--net', default='SimpleDLA')
parser.add_argument('--train', action='store_true')
parser.add_argument('--test', action='store_true')
parser.add_argument('--epochs', type=int, default=200)
parser.add_argument('--prune', action='store_true')
parser.add_argument('--pruning_rate', type=float, default=0.30)
parser.add_argument('--test_batch_size', type=int, default=100)
parser.add_argument('--select_device', type=str, default='gpu', help='gpu | cpu')
parser.add_argument('--save_model_epoch_interval', type=int, default=10)
parser.add_argument('--load_epoch', type=str, default='best', help='best | <epoch>')

args = parser.parse_args()

device = 'cuda' if torch.cuda.is_available() else 'cpu'
device = 'cuda' if torch.cuda.is_available() and args.select_device == 'gpu' else 'cpu'
best_acc = 0 # best test accuracy
start_epoch = 0 # start from epoch 0 or last checkpoint epoch
num_class = 10

# Data
print('==> Preparing data..')
@@ -47,28 +60,70 @@
testset = torchvision.datasets.CIFAR10(
root='./data', train=False, download=True, transform=transform_test)
testloader = torch.utils.data.DataLoader(
testset, batch_size=100, shuffle=False, num_workers=2)
testset, batch_size=args.test_batch_size, shuffle=False, num_workers=1)

classes = ('plane', 'car', 'bird', 'cat', 'deer',
'dog', 'frog', 'horse', 'ship', 'truck')

# Model
print('==> Building model..')
# net = VGG('VGG19')
# net = ResNet18()
# net = PreActResNet18()
# net = GoogLeNet()
# net = DenseNet121()
# net = ResNeXt29_2x64d()
# net = MobileNet()
# net = MobileNetV2()
# net = DPN92()
# net = ShuffleNetG2()
# net = SENet18()
# net = ShuffleNetV2(1)
# net = EfficientNetB0()
# net = RegNetX_200MF()
net = SimpleDLA()
if args.net == 'VGG19': net = VGG('VGG19')
elif args.net == 'ResNet18': net = ResNet18()
elif args.net == 'PreActResNet18': net = PreActResNet18()
elif args.net == 'GoogLeNet': net = GoogLeNet()
elif args.net == 'DenseNet121': net = DenseNet121()
elif args.net == 'ResNeXt29_2x64d': net = ResNeXt29_2x64d()
elif args.net == 'MobileNet': net = MobileNet()
elif args.net == 'MobileNetV2': net = MobileNetV2()
elif args.net == 'DPN92': net = DPN92()
elif args.net == 'ShuffleNetG2': net = ShuffleNetG2()
elif args.net == 'SENet18': net = SENet18()
elif args.net == 'ShuffleNetV2': net = ShuffleNetV2(1)
elif args.net == 'EfficientNetB0': net = EfficientNetB0()
elif args.net == 'RegNetX_200MF': net = RegNetX_200MF()
elif args.net == 'SimpleDLA': net = SimpleDLA()

# Borrow sparsity() and prune() from
# https://github.com/ultralytics/yolov5/blob/a2a1ed201d150343a4f9912d644be2b210206984/utils/torch_utils.py#L174
def sparsity(model):
# Return global model sparsity
a, b = 0, 0
for p in model.parameters():
a += p.numel()
b += (p == 0).sum()
return b / a

def prune(model, amount=0.3):
# Prune model to requested global sparsity
import torch.nn.utils.prune as prune
print('Pruning model... ', end='')
for name, m in model.named_modules():
if isinstance(m, nn.Conv2d):
prune.l1_unstructured(m, name='weight', amount=amount) # prune
prune.remove(m, 'weight') # make permanent
print(' %.3g global sparsity' % sparsity(model))


def count_layer_params(model, layer_name=nn.Conv2d):
print('\n\n layer_name: ', layer_name)
total_params = 0
total_traina_params = 0
n_layers = 0
for name, m in model.named_modules():
if isinstance(m, layer_name):
# print('\nm:', m)
# print('\ndir(m): ', dir(m))

for name, parameter in m.named_parameters():
params = parameter.numel()
total_params += params
if not parameter.requires_grad: continue
n_layers += 1
total_traina_params += params
print('\n\nlayer_name: {}, total_params: {}, total_traina_params: {}, n_layers: {}'.\
format(layer_name, total_params, total_traina_params, n_layers))
time.sleep(100)

net = net.to(device)
if device == 'cuda':
net = torch.nn.DataParallel(net)
@@ -78,8 +133,10 @@
# Load checkpoint.
print('==> Resuming from checkpoint..')
assert os.path.isdir('checkpoint'), 'Error: no checkpoint directory found!'
checkpoint = torch.load('./checkpoint/ckpt.pth')
net.load_state_dict(checkpoint['net'])

print('\n\ndevice: ', device)
checkpoint = torch.load('./checkpoint/{}_ckpt.pth'.format(args.net), map_location=device)
net.load_state_dict(checkpoint['net'], strict=False)
best_acc = checkpoint['acc']
start_epoch = checkpoint['epoch']

@@ -115,12 +172,20 @@ def train(epoch):

def test(epoch):
global best_acc
if args.prune:
prune(net, args.pruning_rate)
input_size = (1, 3, 32, 32)
summary(net, input_size)
count_layer_params(net)


net.eval()
test_loss = 0
correct = 0
total = 0
with torch.no_grad():
for batch_idx, (inputs, targets) in enumerate(testloader):
print('device: ', device)
inputs, targets = inputs.to(device), targets.to(device)
outputs = net(inputs)
loss = criterion(outputs, targets)
@@ -135,6 +200,18 @@ def test(epoch):

# Save checkpoint.
acc = 100.*correct/total
if epoch % args.save_model_epoch_interval == 0:
print('Saving..')
state = {
'net': net.state_dict(),
'acc': acc,
'epoch': epoch,
}
if not os.path.isdir('checkpoint'):
os.mkdir('checkpoint')
torch.save(state, './checkpoint/{}_n_cls_{}_epoch_{}_ckpt.pth'.\
format(args.net, num_class, str(epoch)))
if acc > best_acc:
print('Saving..')
state = {
@@ -144,11 +221,14 @@
}
if not os.path.isdir('checkpoint'):
os.mkdir('checkpoint')
torch.save(state, './checkpoint/ckpt.pth')
torch.save(state, './checkpoint/{}_n_cls_{}_epoch_best_ckpt.pth'.\
format(args.net, num_class))
best_acc = acc


for epoch in range(start_epoch, start_epoch+200):
train(epoch)
test(epoch)
print('\n\nargs.train: ', args.train, ', args.test:', args.test)
for epoch in range(args.epochs):
if args.train: train(epoch)
if args.test:
test(epoch)
if not args.train: break
scheduler.step()
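For reference, the `sparsity()` and `prune()` helpers above follow the linked yolov5 utility. A self-contained sketch of the same L1 unstructured pruning plus a global-sparsity check, run on a toy model rather than one of the repo's networks:

```
# Self-contained sketch of yolov5-style pruning: L1-prune every Conv2d,
# make the masks permanent, then report global sparsity.
import torch
import torch.nn as nn
import torch.nn.utils.prune as prune

def prune_conv_layers(model, amount=0.3):
    for m in model.modules():
        if isinstance(m, nn.Conv2d):
            prune.l1_unstructured(m, name='weight', amount=amount)  # mask smallest |w|
            prune.remove(m, 'weight')  # bake the mask into the weight tensor
    zeros = sum((p == 0).sum().item() for p in model.parameters())
    total = sum(p.numel() for p in model.parameters())
    return zeros / total  # fraction of zeroed parameters

model = nn.Sequential(nn.Conv2d(3, 8, 3), nn.ReLU(), nn.Conv2d(8, 16, 3))
print('global sparsity: %.3f' % prune_conv_layers(model, amount=0.3))
```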