Compare commits


4 Commits

Author            SHA1         Message                              Date
Kohaku-Blueleaf   a36a30fb93   add gc after using consistency dec   2023-11-07 13:01:10 +08:00
Kohaku-Blueleaf   2ea8726597   custom schedule                      2023-11-07 12:35:56 +08:00
Kohaku-Blueleaf   5dbd0355b0   Fix linting                          2023-11-07 11:00:24 +08:00
Kohaku-Blueleaf   64fd916334   Add consistency decoder              2023-11-07 10:52:29 +08:00
5 changed files with 48 additions and 3 deletions

modules/sd_samplers_common.py

@@ -3,7 +3,7 @@ from collections import namedtuple
 import numpy as np
 import torch
 from PIL import Image
-from modules import devices, images, sd_vae_approx, sd_samplers, sd_vae_taesd, shared, sd_models
+from modules import devices, images, sd_vae_approx, sd_samplers, sd_vae_taesd, sd_vae_consistency, shared, sd_models
 from modules.shared import opts, state
 import k_diffusion.sampling
@@ -31,7 +31,7 @@ def setup_img2img_steps(p, steps=None):
     return steps, t_enc
-approximation_indexes = {"Full": 0, "Approx NN": 1, "Approx cheap": 2, "TAESD": 3}
+approximation_indexes = {"Full": 0, "Approx NN": 1, "Approx cheap": 2, "TAESD": 3, "Consistency Decoder": 4}
 def samples_to_images_tensor(sample, approximation=None, model=None):
@@ -51,6 +51,13 @@ def samples_to_images_tensor(sample, approximation=None, model=None):
     elif approximation == 3:
         x_sample = sd_vae_taesd.decoder_model()(sample.to(devices.device, devices.dtype)).detach()
         x_sample = x_sample * 2 - 1
+    elif approximation == 4:
+        with devices.autocast(), torch.no_grad():
+            x_sample = sd_vae_consistency.decoder_model()(
+                sample.detach().to(devices.device, devices.dtype)/0.18215,
+                schedule=[float(i.strip()) for i in shared.opts.sd_vae_consistency_schedule.split(',')],
+            )
+        sd_vae_consistency.unload()
     else:
         if model is None:
             model = shared.sd_model
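For reference, the new `approximation == 4` branch can be exercised outside the webui roughly as below, calling the openai consistencydecoder package the same way the hunk above does. This is a minimal sketch only: the device string, checkpoint directory, latent shape, and the [-1, 1] output range are assumptions on my part, not something stated in the diff.

```python
import torch
from consistencydecoder import ConsistencyDecoder

# Assumed device and download directory; the webui passes devices.device and
# its models/consistencydecoder path instead.
decoder = ConsistencyDecoder("cuda", "consistencydecoder-ckpt")

# Random stand-in for an SD1.x latent of a 512x512 image (batch 1, 4 channels, 64x64).
latent = torch.randn(1, 4, 64, 64, device="cuda")

with torch.no_grad():
    # Undo the SD latent scale factor (0.18215) before decoding, as the diff does,
    # and pass the schedule parsed from the "1.0, 0.5" option string.
    schedule = [float(s.strip()) for s in "1.0, 0.5".split(",")]
    images = decoder(latent / 0.18215, schedule=schedule)  # presumably in [-1, 1], like the other decode paths
```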

modules/sd_vae_consistency.py (new file)

@@ -0,0 +1,34 @@
+"""
+Consistency Decoder
+Improved decoding for stable diffusion vaes.
+https://github.com/openai/consistencydecoder
+"""
+import os
+
+from modules import devices, paths_internal, shared
+from consistencydecoder import ConsistencyDecoder
+
+
+sd_vae_consistency_models = None
+model_path = os.path.join(paths_internal.models_path, 'consistencydecoder')
+
+
+def decoder_model():
+    global sd_vae_consistency_models
+    if getattr(shared.sd_model, 'is_sdxl', False):
+        raise NotImplementedError("SDXL is not supported for consistency decoder")
+    if sd_vae_consistency_models is not None:
+        sd_vae_consistency_models.ckpt.to(devices.device)
+        return sd_vae_consistency_models
+
+    loaded_model = ConsistencyDecoder(devices.device, model_path)
+    sd_vae_consistency_models = loaded_model
+    return loaded_model
+
+
+def unload():
+    global sd_vae_consistency_models
+    if sd_vae_consistency_models is not None:
+        devices.torch_gc()
+        sd_vae_consistency_models.ckpt.to('cpu')
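The module follows a load-once / offload-to-CPU pattern: decoder_model() caches a single ConsistencyDecoder and only moves it back onto the GPU on reuse, while unload() runs torch_gc() and parks the checkpoint on the CPU. The same pattern in isolation looks roughly like the sketch below (generic names and a plain nn.Module stand-in; this is not code from the PR):

```python
import gc
import torch

_cached = None

def get_model(device="cuda"):
    """Load once, then reuse: later calls only move the cached weights back to the GPU."""
    global _cached
    if _cached is not None:
        return _cached.to(device)
    _cached = torch.nn.Linear(4, 4).to(device)  # stand-in for the real decoder checkpoint
    return _cached

def unload():
    """Park the cached weights on the CPU and release cached VRAM, mirroring sd_vae_consistency.unload()."""
    global _cached
    if _cached is not None:
        _cached.to("cpu")
        gc.collect()
        if torch.cuda.is_available():
            torch.cuda.empty_cache()
```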

modules/shared_options.py

@@ -172,7 +172,8 @@ For img2img, VAE is used to process user's input image before the sampling, and
     "sd_vae_overrides_per_model_preferences": OptionInfo(True, "Selected VAE overrides per-model preferences").info("you can set per-model VAE either by editing user metadata for checkpoints, or by making the VAE have same name as checkpoint"),
     "auto_vae_precision": OptionInfo(True, "Automatically revert VAE to 32-bit floats").info("triggers when a tensor with NaNs is produced in VAE; disabling the option in this case will result in a black square image"),
     "sd_vae_encode_method": OptionInfo("Full", "VAE type for encode", gr.Radio, {"choices": ["Full", "TAESD"]}, infotext='VAE Encoder').info("method to encode image to latent (use in img2img, hires-fix or inpaint mask)"),
-    "sd_vae_decode_method": OptionInfo("Full", "VAE type for decode", gr.Radio, {"choices": ["Full", "TAESD"]}, infotext='VAE Decoder').info("method to decode latent to image"),
+    "sd_vae_decode_method": OptionInfo("Full", "VAE type for decode", gr.Radio, {"choices": ["Full", "TAESD", "Consistency Decoder"]}, infotext='VAE Decoder').info("method to decode latent to image"),
+    "sd_vae_consistency_schedule": OptionInfo("1.0, 0.5", "consistency schedule").info("sampling schedule for consistency decoder."),
 }))

 options_templates.update(options_section(('img2img', "img2img"), {
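The new sd_vae_consistency_schedule option stores the schedule as a plain comma-separated string, and the samples_to_images_tensor hunk above parses it inline. A tiny helper-style sketch of that parsing (the helper name is hypothetical, not something this PR adds):

```python
def parse_consistency_schedule(value: str) -> list[float]:
    # Same parsing as the inline expression in the decode branch:
    # split on commas, strip whitespace, convert each entry to float.
    return [float(part.strip()) for part in value.split(',')]


assert parse_consistency_schedule("1.0, 0.5") == [1.0, 0.5]
```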

requirements.txt

@@ -32,3 +32,5 @@ torch
 torchdiffeq
 torchsde
 transformers==4.30.2
+
+git+https://github.com/openai/consistencydecoder.git

requirements_versions.txt

@@ -30,3 +30,4 @@ torchdiffeq==0.2.3
 torchsde==0.2.6
 transformers==4.30.2
 httpx==0.24.1
+git+https://github.com/openai/consistencydecoder.git