Conjugate Pseudo-Inverse Guidance

NeurIPS'24


Official Implementation of the paper: Fast Samplers for Inverse Problems in Iterative Refinement Models

Overview

We propose a fast sampler based on Conjugate Integrators for solving inverse problems with pretrained diffusion or flow models. See also our official project page for qualitative results.

Code Setup

This repo builds on top of the official RED-Diff and Rectified Flow implementations. After cloning the repo, cd into either the diffusion or the flow directory and use it as the working directory for all downstream steps, such as setting up the dependencies and running inference.

Dependency Setup

Diffusion: Our requirements.txt file can be used to set up a conda environment directly with the following command:

conda create --name <env> --file requirements.txt

Flow: Please see Rectified Flow for setting up the dependencies.

Setting up Pretrained Checkpoints

Diffusion: In this work we use pretrained unconditional ImageNet checkpoints from the ADM repository and a pretrained unconditional FFHQ checkpoint from the DPS repository.

Flow: We use pretrained weights from Rectified Flow. A few additional packages are required, including pytorch_msssim, omegaconf, compressai, and lpips (mostly for evaluation and the degradation transforms).
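A quick way to confirm that the extra Flow packages are importable before running inference is sketched below. This check is not part of the repo; the helper `find_missing` and the constant `FLOW_EXTRA_PACKAGES` are hypothetical names, and the list uses the packages' import names (on PyPI, pytorch_msssim is published as `pytorch-msssim`).

```python
import importlib.util

# Import names of the extra Flow packages (hypothetical constant, not from the repo).
FLOW_EXTRA_PACKAGES = ["pytorch_msssim", "omegaconf", "compressai", "lpips"]


def find_missing(modules):
    """Return the top-level modules in `modules` that cannot be found in this environment."""
    return [name for name in modules if importlib.util.find_spec(name) is None]


if __name__ == "__main__":
    missing = find_missing(FLOW_EXTRA_PACKAGES)
    if missing:
        print("Missing packages:", ", ".join(missing))
    else:
        print("All extra Flow packages are importable.")
```

`importlib.util.find_spec` only locates a module without importing it, so the check is cheap and has no side effects.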

Config Management

Diffusion: Config management is done using Hydra configs. All configs are in the _configs/ directory, with algorithm-specific configs in _configs/algo.
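Hydra overrides such as `algo.w=15.0` or `exp.num_steps=5` address nested keys of the composed config using dot notation. The sketch below is a simplified, hypothetical illustration of that idea: `apply_overrides` is not a Hydra or repo function, and Hydra's real override grammar and type handling are richer. Note that in Hydra a bare override like `algo=cpgdm` selects an option from the `_configs/algo` config group rather than setting a value; this sketch only handles dotted value overrides.

```python
import ast


def apply_overrides(config, overrides):
    """Apply `key.subkey=value` overrides to a nested dict (hypothetical, simplified sketch).

    Values are parsed as Python literals where possible (e.g. 15.0, 1e-6, 0, False);
    anything else (e.g. bicubic4, uniform) is kept as a plain string.
    In this sketch, a later override of the same key replaces an earlier one.
    """
    for item in overrides:
        key, raw = item.split("=", 1)
        try:
            value = ast.literal_eval(raw)
        except (ValueError, SyntaxError):
            value = raw  # not a Python literal: keep the raw string
        *parents, leaf = key.split(".")
        node = config
        for part in parents:
            node = node.setdefault(part, {})  # create intermediate dicts as needed
        node[leaf] = value
    return config
```

For example, `apply_overrides({}, ["algo.w=15.0", "algo.deg=bicubic4"])` returns `{"algo": {"w": 15.0, "deg": "bicubic4"}}`.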

Diffusion Inference

We include a sample inference script used in this work at test_scripts/test.sh. For instance, the following command runs the noiseless linear Conjugate $\Pi$-GDM sampler on a 4x super-resolution task with 5 sampling steps:

samples_root=demo_samples/
# Set exp_root and ckpt_root to your own output and pretrained-checkpoint directories
exp_root=/home/pandeyk1/ciip_results/
ckpt_root=/home/pandeyk1/ciip_results/pretrained_chkpts/adm/
save_deg=True
save_ori=True
overwrite=True
smoke_test=1 # Controls the number of batches generated
batch_size=1


python main.py \
        diffusion=vpsde \
        classifier=none \
        algo=cpgdm \
        algo.lam=0 \
        algo.w=15.0 \
        algo.deg=bicubic4 \
        algo.num_eps=1e-6 \
        algo.denoise=False \
        loader=imagenet256_ddrmpp \
        loader.batch_size=$batch_size \
        loader.num_workers=2 \
        dist.num_processes_per_node=1 \
        exp.num_steps=5 \
        exp.seed=0 \
        exp.stride=uniform \
        exp.t_start=0.6 \
        exp.t_end=0.999 \
        exp.root=$exp_root \
        exp.name=demo \
        exp.ckpt_root=$ckpt_root \
        exp.samples_root=$samples_root \
        exp.overwrite=$overwrite \
        exp.save_ori=$save_ori \
        exp.save_deg=$save_deg \
        exp.smoke_test=$smoke_test
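For intuition about the guidance that $\Pi$-GDM uses in the noiseless linear setting, the sketch below computes the pseudo-inverse residual $H^\dagger y - H^\dagger H \hat{x}_0$, which $\Pi$-GDM back-propagates through the denoiser as $g = \big((H^\dagger y - H^\dagger H \hat{x}_0)^\top \, \partial \hat{x}_0 / \partial x_t\big)^\top$. It is a NumPy illustration of standard $\Pi$-GDM under stated assumptions, not the repo's Conjugate $\Pi$-GDM sampler. It assumes a 4x average-pooling degradation (whose pseudo-inverse is nearest-neighbour upsampling) instead of the repo's `bicubic4` operator. It also assumes a hypothetical linear denoiser $\hat{x}_0 = s\,x_t$, so that the vector-Jacobian product is explicit. All function names are hypothetical.

```python
import numpy as np


def downsample(x, factor=4):
    """H (assumed operator): average-pool an (H, W) image by `factor` in each dimension."""
    h, w = x.shape
    return x.reshape(h // factor, factor, w // factor, factor).mean(axis=(1, 3))


def pinv_upsample(y, factor=4):
    """H^+: Moore-Penrose pseudo-inverse of average pooling, i.e. nearest-neighbour upsampling."""
    return np.repeat(np.repeat(y, factor, axis=0), factor, axis=1)


def pinv_residual(y, x0_hat, factor=4):
    """Noiseless Pi-GDM residual H^+ y - H^+ H x0_hat (zero when x0_hat is consistent with y)."""
    return pinv_upsample(y, factor) - pinv_upsample(downsample(x0_hat, factor), factor)


def guidance_toy_linear_denoiser(y, x_t, s, factor=4):
    """Guidance for a hypothetical linear denoiser x0_hat = s * x_t.

    Its Jacobian is s * I, so the vector-Jacobian product reduces to s * residual.
    """
    x0_hat = s * x_t
    return s * pinv_residual(y, x0_hat, factor)
```

With a real network, the product with $\partial \hat{x}_0 / \partial x_t$ would typically be computed by automatic differentiation (a vector-Jacobian product) rather than in closed form.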

Flow Inference

See run.sh in the flow directory to run inference with the flow model.

Citation

If you find the code useful for your research, please consider citing our paper:

@misc{pandey2024fastsamplersinverseproblems,
      title={Fast Samplers for Inverse Problems in Iterative Refinement Models}, 
      author={Kushagra Pandey and Ruihan Yang and Stephan Mandt},
      year={2024},
      eprint={2405.17673},
      archivePrefix={arXiv},
      primaryClass={cs.CV},
      url={https://arxiv.org/abs/2405.17673}, 
}