Trung Dao, Thuan Hoang Nguyen, Thanh Le, Duc Vu, Khoi Nguyen, Cuong Pham, Anh Tran
In this paper, we aim to enhance the performance of SwiftBrush, a prominent one-step text-to-image diffusion model, to be competitive with its multi-step Stable Diffusion counterpart. Initially, we explore the quality-diversity trade-off between SwiftBrush and SD Turbo: the former excels in image diversity, while the latter excels in image quality. This observation motivates our proposed modifications in the training methodology, including better weight initialization and efficient LoRA training. Moreover, our introduction of a novel clamped CLIP loss enhances image-text alignment and results in improved image quality. Remarkably, by combining the weights of models trained with efficient LoRA and full training, we obtain a new state-of-the-art one-step diffusion model, achieving an FID of 8.14 and surpassing all GAN-based and multi-step Stable Diffusion models.
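As a rough illustration of the weight-combination step mentioned above, merging two students amounts to a simple linear interpolation of their weights. The sketch below is not the repository's actual code: the checkpoint filenames, the mixing ratio `alpha`, and the assumption that the files are plain state dicts are all placeholders.

```python
# Illustrative weight merging between two student checkpoints (e.g. a
# LoRA-trained student and a fully trained student). Filenames, the mixing
# ratio, and the plain-state-dict assumption are placeholders.
import torch

def merge_state_dicts(sd_a, sd_b, alpha=0.5):
    """Return alpha * sd_a + (1 - alpha) * sd_b for every shared key."""
    return {k: alpha * sd_a[k] + (1.0 - alpha) * sd_b[k] for k in sd_a}

sd_lora = torch.load("student_lora.ckpt", map_location="cpu")
sd_full = torch.load("student_full.ckpt", map_location="cpu")
torch.save(merge_state_dicts(sd_lora, sd_full, alpha=0.5), "student_merged.ckpt")
```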
Please CITE our paper whenever this repository is used to help produce published results or incorporated into other software:
@inproceedings{dao2024swiftbrushv2,
title={SwiftBrush v2: Make Your One-step Diffusion Model Better Than Its Teacher},
author={Trung Dao and Thuan Hoang Nguyen and Thanh Le and Duc Vu and Khoi Nguyen and Cuong Pham and Anh Tran},
booktitle={Proceedings of the European Conference on Computer Vision (ECCV)},
year={2024}
}
- We provide the following files and folders:
  - `tools`: utility code, such as text-embedding extraction and resizing
  - `eval`: evaluation code
  - `dataset.json`: the MS-COCO2014 prompts used for evaluation
- We also provide the checkpoint at this link.
- NOTE: While our codebase uses the 3-Clause BSD License, our model is derived from SD-Turbo and therefore must comply with SD-Turbo's original license.
- First, create an environment with CUDA-enabled PyTorch:
conda create -n sbv2 python=3.10
conda activate sbv2
conda install pytorch==2.2.2 torchvision==0.17.2 torchaudio==2.2.2 pytorch-cuda=11.8 -c pytorch -c nvidia
- Install the remaining dependencies:
pip install -r requirements.txt
- For evaluating `recall`, please create another environment:
conda create -n sbv2_recall python=3.10
conda activate sbv2_recall
pip install -r requirements_eval_recall.txt
- Infer a single prompt or a txt file of prompts using `infer_by_prompt.py`:
python infer_by_prompt.py <prompt> <ckpt_path>
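For example (the prompt is arbitrary and `<ckpt_path>` is wherever you saved the downloaded checkpoint):
python infer_by_prompt.py "a photo of an astronaut riding a horse" <ckpt_path>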
- Infer the COCO2014 prompts from the JSON we provide, using the corresponding `infer.py` script.
- Easier to run, but slower:
python infer.py <ckpt_path> --caption-path=dataset.json
- Faster route (a rough sketch of the embedding-precomputation step follows this list):
- Generate the embeddings:
python tools/prepare.py dataset.json
- Infer using the embeddings:
python infer.py <ckpt_path> --caption-path=dataset.json --precomputed-text-path=<generated_embeds_path>
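For reference, precomputing the prompt embeddings amounts to running all prompts through the text encoder once and caching the results so inference can skip that step. The snippet below only illustrates that idea and is not the actual `tools/prepare.py`; the model ID, output filename, and the assumption that `dataset.json` holds a flat list of prompt strings are all placeholders.

```python
# Illustrative prompt-embedding precomputation; not the repository's
# tools/prepare.py. Model ID, file layout, and output path are assumptions.
import json
import torch
from diffusers import StableDiffusionPipeline

device = "cuda"
pipe = StableDiffusionPipeline.from_pretrained("stabilityai/sd-turbo")
tokenizer, text_encoder = pipe.tokenizer, pipe.text_encoder.to(device).eval()

# Assumption: dataset.json is a flat list of prompt strings.
prompts = json.load(open("dataset.json"))

embeds = []
with torch.no_grad():
    for i in range(0, len(prompts), 64):
        tokens = tokenizer(
            prompts[i : i + 64],
            padding="max_length",
            max_length=tokenizer.model_max_length,
            truncation=True,
            return_tensors="pt",
        )
        embeds.append(text_encoder(tokens.input_ids.to(device))[0].cpu())

# Cache the embeddings so inference can reuse them.
torch.save(torch.cat(embeds), "precomputed_text_embeds.pt")
```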
- Following the GigaGAN paper, we evaluate the model with the following flow:
- Center-crop the COCO2014 images to 256x256:
python tools/resize.py <coco2014_path> <coco2014_resized_path> --nsamples=-1
- Resize the inferred folder:
python tools/resize.py main <infer_folder> <infer_resized_folder>
- Evaluate:
- FID:
python eval/fid.py <infer_resized_folder> --ref-dir=<coco2014_resized_path> --no-crop
- CLIP Score:
python eval/clip_score.py <infer_resized_folder> --batch-size=1024 --prompt_path=<coco2014_prompt_path>
- Precision/Recall:
python eval/recall.py <coco2014_resized_path> <infer_resized_folder>
- One can use `scripts/eval.sh` to automate the evaluation steps above, but remember to update the paths. A rough standalone FID sketch is given below.
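For a quick sanity check of the FID step outside the provided `eval/fid.py`, something like the minimal clean-fid snippet below can be used. This is only an illustration, assuming both folders already contain the 256x256 center-cropped images produced above; the repository's own script may use a different FID implementation and options.

```python
# Minimal FID sanity check between two image folders using the clean-fid
# package (pip install clean-fid). Illustrative stand-in, not eval/fid.py.
from cleanfid import fid

score = fid.compute_fid(
    "coco2014_resized",  # placeholder: center-cropped COCO2014 reference folder
    "infer_resized",     # placeholder: resized folder of generated images
)
print(f"FID: {score:.2f}")
```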
- For HPSv2 evaluation metrics:
- Infer using the corresponding `infer_hps.py` script:
python infer_hps.py <ckpt_path>
- Get the final score with:
python eval/hps.py <hps_infer_folder>
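Under the hood, HPSv2 scoring compares each generated image against its prompt with the HPSv2 reward model. The snippet below is a rough illustration using the public `hpsv2` package; the image path and prompt are placeholders, and the repository's `eval/hps.py` may batch and aggregate differently.

```python
# Illustrative HPSv2 scoring with the public `hpsv2` package (pip install hpsv2).
# Paths below are placeholders; eval/hps.py may differ.
import hpsv2

# Official HPSv2 benchmark prompts, grouped by style (anime, photo, ...).
prompts_by_style = hpsv2.benchmark_prompts("all")

# Score one generated image against its prompt; hpsv2.score also accepts a
# list of image paths and returns a list of scores.
scores = hpsv2.score(
    "hps_infer_folder/photo/00000.jpg",  # placeholder image path
    prompts_by_style["photo"][0],
    hps_version="v2.1",
)
print(scores)
```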
Copyright (c) 2024 VinAI
Licensed under the 3-Clause BSD License.
You may obtain a copy of the License at
https://opensource.org/license/bsd-3-clause