🖼️ Dream Bench 📊

What does it do?

dream_bench provides a simplified interface for benchmarking your image-generation models.

This repository also hosts common prompt-lists for benchmarking, as well as instructions for preparing a dataset to perform a more comprehensive evaluation.

How does it work?

To start off, you will need to create an evaluation adapter for your model so that it can interface with dream_bench. dream_bench will call this function with a set of conditioning arguments and expect an image in return.

class MyImageModel():
    def __init__(self, *args, **kwargs):
        super().__init__()
        ...

    def sample(self, img_emb, txt_emb):
        ...

    def my_evaluation_harness(self, conditioning_args):

        # extract what you need to from the conditioning arguments

        img_emb = conditioning_args["prior_image_embedding.npy"]
        txt_emb = conditioning_args["clip_text_embedding.npy"]

        # sample with your model's function

        predicted_image = self.sample(img_emb=img_emb, txt_emb=txt_emb)

        # return the image(s) to be evaluated by dream bench

        return predicted_image
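Before handing an adapter to dream_bench, it can help to call it once by hand. The following is a minimal sketch, not part of dream_bench: `DummyImageModel` is a hypothetical stand-in, the conditioning arguments are assumed to arrive as a dict keyed by the names used above, and plain Python lists stand in for the real embeddings and image.

```python
# Hypothetical stand-in model, used only to exercise the adapter locally.
class DummyImageModel:
    def sample(self, img_emb, txt_emb):
        # A real model would run its sampler here. This placeholder returns
        # a 2x2 grayscale "image" whose pixel value depends on the inputs.
        value = (sum(img_emb) + sum(txt_emb)) % 256
        return [[value, value], [value, value]]

    def my_evaluation_harness(self, conditioning_args):
        # Same shape as the adapter above: unpack, sample, return.
        img_emb = conditioning_args["prior_image_embedding.npy"]
        txt_emb = conditioning_args["clip_text_embedding.npy"]
        return self.sample(img_emb=img_emb, txt_emb=txt_emb)


# Hand-built conditioning arguments (assumed layout, illustrative values).
fake_conditioning_args = {
    "prior_image_embedding.npy": [1, 2, 3],
    "clip_text_embedding.npy": [4, 5, 6],
}

model = DummyImageModel()
image = model.my_evaluation_harness(fake_conditioning_args)
print(image)
```

If this call raises a `KeyError`, the adapter is asking for a conditioning argument that the hand-built dict does not provide, which is worth catching before a full benchmark run.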

Once you have a function that will accept conditioning_args and return an image, you can pass this function to dream_bench to handle the rest!

from dream_bench import benchmark, DreamBenchConfig

# specify what you'd like to benchmark

config = DreamBenchConfig.from_json_path("<path to your config>")

def train(model, dataloader, epochs):
    for epoch in range(epochs):
        for x,y in dataloader:
            # do training

            ...

            # benchmark on some interval

            if time_to_benchmark:
                benchmark(adapter=model.my_evaluation_harness, config=config)
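The snippet above leaves `time_to_benchmark` undefined. One possible way to fill it in is a simple step counter, sketched below. The names `benchmark_every` and `benchmark_fn` are hypothetical and not part of dream_bench; in real use, `benchmark_fn` could wrap the `benchmark(adapter=..., config=...)` call shown above.

```python
def train(dataloader, epochs, benchmark_fn, benchmark_every=100):
    """Toy training loop that benchmarks every `benchmark_every` steps."""
    step = 0
    for epoch in range(epochs):
        for x, y in dataloader:
            # do training (omitted in this sketch)
            step += 1

            # benchmark on a fixed step interval
            if step % benchmark_every == 0:
                benchmark_fn()
    return step


# Stand-ins so the sketch runs without a model or dream_bench installed.
calls = []
toy_dataloader = [(i, i) for i in range(10)]

total_steps = train(
    toy_dataloader,
    epochs=3,
    benchmark_fn=lambda: calls.append("benchmark"),
    benchmark_every=5,
)
print(total_steps, len(calls))
```

Counting steps rather than epochs keeps the benchmarking interval independent of dataset size, which tends to matter for long training runs.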

If you have finished training and would like to benchmark a pre-trained model, you can do so as follows.

from dream_bench import benchmark, DreamBenchConfig

# specify what you'd like to benchmark

config = DreamBenchConfig.from_json_path("<path to your config>")

# if your model doesn't have an adapter, you can create one now

def my_evaluation_harness(model, conditioning_args):

    # extract what you need to from the conditioning arguments

    img_emb = conditioning_args["prior_image_embedding.npy"]
    txt_emb = conditioning_args["clip_text_embedding.npy"]

    # sample with your model's function

    predicted_image = model.sample(img_emb=img_emb, txt_emb=txt_emb)

    # return the image(s) to be evaluated by dream bench

    return predicted_image

def main():
    # load your model (load_model is a placeholder for your own loading code)

    model = load_model()

    # bind the model so that the adapter only takes conditioning_args

    def adapter(conditioning_args):
        return my_evaluation_harness(model, conditioning_args)

    # call benchmark outside of training

    benchmark(adapter=adapter, config=config)

if __name__ == "__main__":
    main()

Setup

dream_bench works with datasets in the WebDataset format.

Before you run an evaluation, you must preprocess your dataset/prompt-list so that it is compatible with dream_bench.

For more information on how this can be done, take a look at the dedicated readme.
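For readers new to WebDataset, a shard is an ordinary tar archive in which files that share a basename form one sample, and everything after the first dot becomes the field name. The sketch below builds such a shard in memory with the standard library and groups it back into samples. The `caption.txt` field and the placeholder bytes are illustrative assumptions, not the schema dream_bench requires; the dedicated readme remains the reference for the actual fields.

```python
import io
import tarfile
from collections import defaultdict


def add_file(tar, name, payload):
    """Add an in-memory file to an open tar archive."""
    info = tarfile.TarInfo(name=name)
    info.size = len(payload)
    tar.addfile(info, io.BytesIO(payload))


# Write a tiny shard: two samples, two fields each (illustrative names).
buffer = io.BytesIO()
with tarfile.open(fileobj=buffer, mode="w") as tar:
    add_file(tar, "000000.caption.txt", b"a red bicycle")
    add_file(tar, "000000.clip_text_embedding.npy", b"<placeholder bytes>")
    add_file(tar, "000001.caption.txt", b"a bowl of soup")
    add_file(tar, "000001.clip_text_embedding.npy", b"<placeholder bytes>")

# Read it back, grouping files into samples by their shared key.
buffer.seek(0)
samples = defaultdict(dict)
with tarfile.open(fileobj=buffer, mode="r") as tar:
    for member in tar:
        key, _, field = member.name.partition(".")
        samples[key][field] = tar.extractfile(member).read()

for key in sorted(samples):
    print(key, sorted(samples[key]))
```

The keys read by the adapter earlier (`prior_image_embedding.npy`, `clip_text_embedding.npy`) have the same shape as these extension-derived field names.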

Next, you will need to create a configuration file for your evaluation. This file tells dream_bench how to benchmark your run. For information on how to build it, see its dedicated readme.

Tracking

To track your experiments, dream_bench uses Weights & Biases. With wandb you can view your generations in an easy-to-read format, compile reports, and query within and across your runs.

Additional support may be added for other trackers in the future.


ToDo

  • Determine dataset format
  • Provide scripts for generating datasets from prompt lists
  • Block in more metrics
  • Add continuous integration
  • Add wandb integration
  • Build configuration
  • Lint project & capture unnecessary warnings
  • Build out benchmarking interface for clip-based models
  • Track runs with wandb (look for efficiency later)
  • Track averages of each metric
  • Add FID metric
  • Add option for distributed evaluation
  • Prioritize benchmarking for ongoing training runs
  • Provide guide/scripts for formatting other datasets/url-lists
  • Add more tests
  • Profile code & look for optimizations
  • Publish to PyPI
  • Add support for distributed benchmarking