diff --git a/README.md b/README.md index 53cd479..d0fe56c 100644 --- a/README.md +++ b/README.md @@ -5,6 +5,7 @@ [![Build Status](https://travis-ci.org/idealo/image-super-resolution.svg?branch=master)](https://travis-ci.org/idealo/image-super-resolution) [![Docs](https://img.shields.io/badge/docs-online-brightgreen)](https://idealo.github.io/image-super-resolution/) [![License](https://img.shields.io/badge/License-Apache%202.0-orange.svg)](https://github.com/idealo/image-super-resolution/blob/master/LICENSE) + The goal of this project is to upscale and improve the quality of low resolution images. @@ -65,18 +66,18 @@ pip install 'h5py==2.10.0' --force-reinstall ## Pre-trained networks -The weights used to produced these images are available directly when creating the model object. +The weights used to produced these images are available directly when creating the model object. Currently 4 models are available: - RDN: psnr-large, psnr-small, noise-cancel - RRDN: gans - + Example usage: ``` model = RRDN(weights='gans') ``` - + The network parameters will be automatically chosen. (see [Additional Information](#additional-information)). diff --git a/cog.yaml b/cog.yaml new file mode 100644 index 0000000..53259f0 --- /dev/null +++ b/cog.yaml @@ -0,0 +1,8 @@ +build: + python_version: "3.7" + gpu: false + python_packages: + - ISR==2.2.0 + - h5py==2.10.0 --force-reinstall + +predict: "predict.py:ISRPredictor" diff --git a/predict.py b/predict.py new file mode 100644 index 0000000..097955b --- /dev/null +++ b/predict.py @@ -0,0 +1,43 @@ +import os +import tempfile +from pathlib import Path + +import cog +import numpy as np +from ISR.models import RDN, RRDN +from PIL import Image + + +class ISRPredictor(cog.Predictor): + def setup(self): + """Load the super-resolution ans noise canceling models""" + self.model_gans = RRDN(weights="gans") + self.model_noise_cancel = RDN(weights="noise-cancel") + + @cog.input("input", type=Path, help="Image path") + @cog.input( + "type", + type=str, + default="super-resolution", + options=["super-resolution", "noise-cancel"], + help="Precessing type: super-resolution or noise-cancel", + ) + def predict(self, input, type): + """Apply super-resolution or noise-canceling to input image""" + # compute super resolution + img = Image.open(str(input)) + lr_img = np.array(img) + + if type == "super-resolution": + img = self.model_gans.predict(np.array(img)) + elif type == "noise-cancel": + img = self.model_noise_cancel.predict(np.array(img)) + else: + raise NotImplementedError("Invalid processing type selected") + + img = Image.fromarray(img) + + output_path = Path(tempfile.mkdtemp()) / "output.png" + img.save(str(output_path), "PNG") + + return output_path