In this project, an image classifier was built with PyTorch using a pre-trained deep neural network and trained on a flower dataset to recognize different species of flowers. (You can imagine using something like this in a phone app that tells you the name of the flower your camera is looking at.) The project was first written in a Jupyter Notebook and then converted into a command-line application.
- PyTorch
- Python
- NumPy
- Matplotlib
- GPU
An accuracy of 82% was reached during training (with the network trained on the train set), and a test accuracy of 84% was reached when the trained classifier was evaluated on the test set.
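Both figures are accuracies, read here as the share of images whose predicted class matches the true label. A minimal plain-Python sketch of that arithmetic (the `accuracy` helper and the example labels are hypothetical, not taken from the project's code):

```python
def accuracy(predicted, actual):
    """Return the fraction of predictions that match the true labels."""
    if len(predicted) != len(actual):
        raise ValueError("predicted and actual must have the same length")
    if not actual:
        raise ValueError("cannot compute accuracy of an empty set")
    # Each matching pair counts as True (1), each mismatch as False (0).
    correct = sum(p == a for p, a in zip(predicted, actual))
    return correct / len(actual)


print(accuracy(["rose", "daisy", "tulip", "rose"], ["rose", "daisy", "rose", "rose"]))  # 0.75
```

Three correct predictions out of four give 0.75, i.e. 75%.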
- Before running the application, add a directory named `flowers` to the project directory with three sub-directories (`train`, `valid` and `test`), populated with the images used for training, validation and testing.
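A quick way to confirm that this layout is in place before training is a small check script. The sketch below is an assumption, not part of the project: the function name `missing_or_empty_splits` is hypothetical, and it expects to be run from the project directory. If the images are loaded with torchvision's `ImageFolder` (the README does not say), each split must additionally contain one sub-directory per flower class.

```python
from pathlib import Path

REQUIRED_SPLITS = ("train", "valid", "test")


def missing_or_empty_splits(data_dir="flowers"):
    """Return the names of required sub-directories that are missing or contain no files."""
    root = Path(data_dir)
    problems = []
    for split in REQUIRED_SPLITS:
        split_dir = root / split
        # A split counts as usable only if it exists and holds at least one file somewhere below it.
        if not split_dir.is_dir() or not any(p.is_file() for p in split_dir.rglob("*")):
            problems.append(split)
    return problems


if __name__ == "__main__":
    problems = missing_or_empty_splits()
    if problems:
        print("Missing or empty:", ", ".join(problems))
    else:
        print("flowers/ is ready")
```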
The project consists of two files. `train.py` trains a new network on a dataset and saves the model as a checkpoint; `predict.py` uses a trained network to predict the class of an input image.
The following basic usage will help you run the application with default arguments.
- Basic usage: `python train.py`
  - Prints out training loss, validation loss and validation accuracy as the network trains
- Basic usage: `python predict.py`
  - Return the top K most likely classes: `python predict.py input checkpoint --top_k 3`
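The documented `predict.py` usage takes an input image, a checkpoint and an optional `--top_k`. The sketch below shows one way such a command line and the top-K ranking could be wired up in plain Python; it is not the project's actual code. The default paths and the default `--top_k` value are hypothetical placeholders, and `top_k` here works on a list of raw scores rather than on a PyTorch tensor.

```python
import argparse
import math


def parse_args(argv=None):
    # Hypothetical mirror of the documented usage: predict.py input checkpoint --top_k K.
    # nargs="?" lets both positionals be omitted, so a bare `python predict.py` still parses.
    # The default paths and top_k value are placeholders, not the project's real defaults.
    parser = argparse.ArgumentParser(description="Predict the flower class of an image.")
    parser.add_argument("input", nargs="?", default="flowers/test/example.jpg",
                        help="path to the input image")
    parser.add_argument("checkpoint", nargs="?", default="checkpoint.pth",
                        help="path to the saved model checkpoint")
    parser.add_argument("--top_k", type=int, default=1,
                        help="number of most likely classes to return")
    return parser.parse_args(argv)


def top_k(logits, class_names, k):
    """Return the k most likely (class, probability) pairs from raw model scores."""
    # Softmax, shifted by the maximum score for numerical stability.
    m = max(logits)
    exps = [math.exp(x - m) for x in logits]
    total = sum(exps)
    probs = [e / total for e in exps]
    # Sort classes by probability, highest first, and keep the first k.
    ranked = sorted(zip(class_names, probs), key=lambda pair: pair[1], reverse=True)
    return ranked[:k]
```

In PyTorch the same ranking is typically done on the model output with `torch.softmax` followed by `Tensor.topk(k)`.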