Compare commits

...

4 Commits

Author          SHA1        Message                                               Date
AUTOMATIC1111   f1d7c07a5a  repair img2img                                        2023-08-07 12:21:05 +03:00
AUTOMATIC1111   686598387f  send noisy latent into refiner without adding noise   2023-08-07 12:10:16 +03:00
AUTOMATIC1111   3f82820612  apply unet overrides after switching model            2023-08-07 08:16:30 +03:00
AUTOMATIC1111   6c7b6ecb81  alternative refiner implementation                    2023-08-06 22:08:52 +03:00
6 changed files with 107 additions and 16 deletions

View File

@@ -344,6 +344,8 @@ infotext_to_setting_name_mapping = [
     ('Pad conds', 'pad_cond_uncond'),
     ('VAE Encoder', 'sd_vae_encode_method'),
     ('VAE Decoder', 'sd_vae_decode_method'),
+    ('Refiner', 'sd_refiner_checkpoint'),
+    ('Refiner switch at', 'sd_refiner_switch_at'),
 ]
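
These two entries let a pasted infotext such as "Refiner: sd_xl_refiner_1.0, Refiner switch at: 0.8" be applied back to the corresponding settings. A minimal sketch of how such a mapping is consumed (simplified; webui's real infotext parser also handles quoting and type conversion):

infotext_to_setting_name_mapping = [
    ('Refiner', 'sd_refiner_checkpoint'),
    ('Refiner switch at', 'sd_refiner_switch_at'),
]

def settings_from_infotext(params):
    # map infotext labels back to option names, keeping only the labels present
    return {setting: params[label] for label, setting in infotext_to_setting_name_mapping if label in params}

print(settings_from_infotext({'Refiner': 'sd_xl_refiner_1.0', 'Refiner switch at': '0.8'}))
# -> {'sd_refiner_checkpoint': 'sd_xl_refiner_1.0', 'sd_refiner_switch_at': '0.8'}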

View File

@@ -178,6 +178,8 @@ class StableDiffusionProcessing:
         self.extra_network_data = None
         self.seeds = None
         self.subseeds = None
+        self.recorded_checkpoint = None
+        self.recorded_checkpoint_hash = None
         self.step_multiplier = 1
         self.cached_uc = StableDiffusionProcessing.cached_uc
@@ -186,6 +188,7 @@ class StableDiffusionProcessing:
         self.c = None
         self.user = None
+        self.image_conditioning = None

     @property
     def sd_model(self):
@@ -276,10 +279,10 @@ class StableDiffusionProcessing:
         if self.sd_model.cond_stage_key == "edit":
             return self.edit_image_conditioning(source_image)

-        if self.sampler.conditioning_key in {'hybrid', 'concat'}:
+        if self.sd_model.model.conditioning_key in {'hybrid', 'concat'}:
             return self.inpainting_image_conditioning(source_image, latent_image, image_mask=image_mask)

-        if self.sampler.conditioning_key == "crossattn-adm":
+        if self.sd_model.model.conditioning_key == "crossattn-adm":
             return self.unclip_image_conditioning(source_image)

         # Dummy zero conditioning if we're not using inpainting or depth model.
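
The switch from self.sampler.conditioning_key to self.sd_model.model.conditioning_key matters for the refiner: run_refiner() below sets self.sampler = None before reloading weights, so conditioning dispatch can no longer assume a live sampler object. A toy illustration of the same dispatch, with hypothetical stand-in names:

def pick_conditioning(sd_model):
    key = sd_model.model.conditioning_key    # read from the wrapped diffusion model, not a sampler
    if key in {'hybrid', 'concat'}:
        return 'inpainting'
    if key == 'crossattn-adm':
        return 'unclip'
    return 'dummy-zero'
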
@@ -377,6 +380,54 @@ class StableDiffusionProcessing:
"""Returns whether generated images need to be written to disk"""
return opts.samples_save and not self.do_not_save_samples and (opts.save_incomplete_images or not state.interrupted and not state.skipped)
def run_refiner(self, samples):
shared.state.nextjob()
stopped_at = self.sampler.stop_at
noisy_output = self.sampler.noisy_output
self.sampler = None
a_is_sdxl = shared.sd_model.is_sdxl
decoded_noisy = decode_latent_batch(shared.sd_model, noisy_output, target_device=devices.cpu, check_for_nans=True)
refiner_checkpoint_info = sd_models.get_closet_checkpoint_match(shared.opts.sd_refiner_checkpoint)
if refiner_checkpoint_info is None:
raise Exception(f'Could not find checkpoint with name {shared.opts.sd_refiner_checkpoint}')
self.recorded_checkpoint = shared.sd_model.sd_checkpoint_info.name_for_extra
self.recorded_checkpoint_hash = shared.sd_model.sd_model_hash
self.extra_generation_params['Refiner'] = refiner_checkpoint_info.short_title
self.extra_generation_params['Refiner switch at'] = shared.opts.sd_refiner_switch_at
with sd_models.SkipWritingToConfig():
sd_models.reload_model_weights(info=refiner_checkpoint_info)
devices.torch_gc()
self.setup_conds()
b_is_sdxl = shared.sd_model.is_sdxl
if a_is_sdxl != b_is_sdxl:
decoded_noisy = torch.stack(decoded_noisy).float()
decoded_noisy = torch.clamp((decoded_noisy + 1.0) / 2.0, min=0.0, max=1.0)
noisy_latent = images_tensor_to_samples(decoded_noisy, approximation_indexes.get(opts.sd_vae_encode_method), shared.sd_model)
else:
noisy_latent = noisy_output
x = torch.zeros_like(noisy_latent)
with devices.without_autocast() if devices.unet_needs_upcast else devices.autocast():
denoising_strength = self.denoising_strength
self.denoising_strength = 1.0 - (stopped_at + 1) / self.steps
self.image_conditioning = txt2img_image_conditioning(shared.sd_model, noisy_latent, self.width, self.height)
self.sampler = sd_samplers.create_sampler(self.sampler_name, shared.sd_model)
samples = self.sampler.sample_img2img(self, noisy_latent, x, self.c, self.uc, image_conditioning=self.image_conditioning, steps=max(1, self.steps - stopped_at - 1))
self.denoising_strength = denoising_strength
return samples
class Processed:
def __init__(self, p: StableDiffusionProcessing, images_list, seed=-1, info="", subseed=None, all_prompts=None, all_negative_prompts=None, all_seeds=None, all_subseeds=None, index_of_first_image=0, infotexts=None, comments=""):
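
The step arithmetic here interlocks with the stop_at assignment further down in process_images_inner(). A worked example in plain Python, assuming steps=20 and sd_refiner_switch_at=0.8:

steps = 20
switch_at = 0.8

stop_at = max(1, int(switch_at * steps - 1))        # 15: the base sampler is interrupted after this step
denoising_strength = 1.0 - (stop_at + 1) / steps    # 0.2: fraction of denoising left for the refiner
refiner_steps = max(1, steps - stop_at - 1)         # 4: img2img steps handed to the refiner

print(stop_at, denoising_strength, refiner_steps)   # 15 0.19999999999999996 4

Note also the cross-architecture branch above: if the base and refiner models disagree on is_sdxl, the noisy latent is decoded to image space and re-encoded with the refiner's VAE, since the two latent spaces are not interchangeable.
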
@@ -553,6 +604,9 @@ class DecodedSamples(list):
 def decode_latent_batch(model, batch, target_device=None, check_for_nans=False):
+    if getattr(batch, 'already_decoded', False):
+        return batch
+
     samples = DecodedSamples()

     for i in range(batch.shape[0]):
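
decode_latent_batch() is now idempotent: callers can pass it a batch that was already decoded, such as what hires fix's sample() returns, and the guard turns the second decode into a no-op. A stand-alone sketch of the flag pattern, with simplified stand-ins for the real classes:

class DecodedSamples(list):
    already_decoded = True

def decode_latent_batch(batch):
    if getattr(batch, 'already_decoded', False):
        return batch                               # already image-space: nothing to do
    return DecodedSamples(x * 2 for x in batch)    # placeholder for the real VAE decode

once = decode_latent_batch([1, 2])
twice = decode_latent_batch(once)
assert once is twice
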
@@ -632,8 +686,8 @@ def create_infotext(p, all_prompts, all_seeds, all_subseeds, comments=None, iter
"Seed": p.all_seeds[0] if use_main_prompt else all_seeds[index],
"Face restoration": (opts.face_restoration_model if p.restore_faces else None),
"Size": f"{p.width}x{p.height}",
"Model hash": getattr(p, 'sd_model_hash', None if not opts.add_model_hash_to_info or not shared.sd_model.sd_model_hash else shared.sd_model.sd_model_hash),
"Model": (None if not opts.add_model_name_to_info else shared.sd_model.sd_checkpoint_info.name_for_extra),
"Model hash": getattr(p, 'sd_model_hash', None if not opts.add_model_hash_to_info or not shared.sd_model.sd_model_hash else p.recorded_checkpoint_hash or shared.sd_model.sd_model_hash),
"Model": (None if not opts.add_model_name_to_info else p.recorded_checkpoint or shared.sd_model.sd_checkpoint_info.name_for_extra),
"Variation seed": (None if p.subseed_strength == 0 else (p.all_subseeds[0] if use_main_prompt else all_subseeds[index])),
"Variation seed strength": (None if p.subseed_strength == 0 else p.subseed_strength),
"Seed resize from": (None if p.seed_resize_from_w <= 0 or p.seed_resize_from_h <= 0 else f"{p.seed_resize_from_w}x{p.seed_resize_from_h}"),
@@ -666,6 +720,10 @@ def process_images(p: StableDiffusionProcessing) -> Processed:
     stored_opts = {k: opts.data[k] for k in p.override_settings.keys()}

     try:
+        # after running refiner, the refiner model is not unloaded - webui swaps back to main model here
+        if shared.sd_model.sd_checkpoint_info.title != opts.sd_model_checkpoint:
+            sd_models.reload_model_weights()
+
         # if no checkpoint override or the override checkpoint can't be found, remove override entry and load opts checkpoint
         if sd_models.checkpoint_aliases.get(p.override_settings.get('sd_model_checkpoint')) is None:
             p.override_settings.pop('sd_model_checkpoint', None)
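
run_refiner() deliberately leaves the refiner weights loaded, so back-to-back refined generations avoid one reload; the cost is that the next process_images() call starts by comparing the loaded checkpoint's title against opts.sd_model_checkpoint and swapping the main model back in. A minimal reproduction of the check, with illustrative titles:

loaded_title = 'sd_xl_refiner_1.0.safetensors'    # left loaded by run_refiner()
configured_title = 'sd_xl_base_1.0.safetensors'   # opts.sd_model_checkpoint
if loaded_title != configured_title:
    pass   # here webui would call sd_models.reload_model_weights()
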
@@ -737,6 +795,8 @@ def process_images_inner(p: StableDiffusionProcessing) -> Processed:
     infotexts = []
     output_images = []
+    have_refiner = shared.opts.sd_refiner_switch_at < 1.0 and shared.sd_model.sd_checkpoint_info.title != shared.opts.sd_refiner_checkpoint

     with torch.no_grad(), p.sd_model.ema_scope():
         with devices.autocast():
             p.init(p.all_prompts, p.all_seeds, p.all_subseeds)
@@ -750,6 +810,10 @@ def process_images_inner(p: StableDiffusionProcessing) -> Processed:
         if state.job_count == -1:
             state.job_count = p.n_iter
+        if have_refiner:
+            state.job_count *= 2
+            shared.total_tqdm.updateTotal(p.steps * state.job_count // 2)

         for n in range(p.n_iter):
             p.iteration = n
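
With a refiner, every image costs two jobs (a base pass and a refiner pass), so job_count doubles; dividing by two in updateTotal keeps the progress-bar total at roughly steps per image. Worked numbers, assuming n_iter=2 and steps=20:

n_iter, steps = 2, 20
job_count = n_iter                     # 2 jobs without a refiner
job_count *= 2                         # 4 jobs with one (base + refiner per image)
tqdm_total = steps * job_count // 2    # 40, i.e. still steps * n_iter overall
assert tqdm_total == steps * n_iter
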
@@ -759,6 +823,8 @@ def process_images_inner(p: StableDiffusionProcessing) -> Processed:
             if state.interrupted:
                 break

+            sd_models.reload_model_weights()  # model can be changed for example by refiner
+
             p.prompts = p.all_prompts[n * p.batch_size:(n + 1) * p.batch_size]
             p.negative_prompts = p.all_negative_prompts[n * p.batch_size:(n + 1) * p.batch_size]
             p.seeds = p.all_seeds[n * p.batch_size:(n + 1) * p.batch_size]
@@ -799,15 +865,20 @@ def process_images_inner(p: StableDiffusionProcessing) -> Processed:
shared.state.job = f"Batch {n+1} out of {p.n_iter}"
with devices.without_autocast() if devices.unet_needs_upcast else devices.autocast():
p.sampler = sd_samplers.create_sampler(p.sampler_name, p.sd_model)
if have_refiner:
p.sampler.stop_at = max(1, int(shared.opts.sd_refiner_switch_at * p.steps - 1))
samples_ddim = p.sample(conditioning=p.c, unconditional_conditioning=p.uc, seeds=p.seeds, subseeds=p.subseeds, subseed_strength=p.subseed_strength, prompts=p.prompts)
if getattr(samples_ddim, 'already_decoded', False):
x_samples_ddim = samples_ddim
else:
if opts.sd_vae_decode_method != 'Full':
p.extra_generation_params['VAE Decoder'] = opts.sd_vae_decode_method
if opts.sd_vae_decode_method != 'Full':
p.extra_generation_params['VAE Decoder'] = opts.sd_vae_decode_method
x_samples_ddim = decode_latent_batch(p.sd_model, samples_ddim, target_device=devices.cpu, check_for_nans=True)
if have_refiner:
samples_ddim = p.run_refiner(samples_ddim)
x_samples_ddim = decode_latent_batch(p.sd_model, samples_ddim, target_device=devices.cpu, check_for_nans=True)
x_samples_ddim = torch.stack(x_samples_ddim).float()
x_samples_ddim = torch.clamp((x_samples_ddim + 1.0) / 2.0, min=0.0, max=1.0)
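
Read together, one refined batch flows like this: (1) a sampler is created for the base model and stop_at is set from sd_refiner_switch_at; (2) p.sample(...) runs until the sampler callback raises InterruptedException past stop_at, leaving noisy_output behind; (3) p.run_refiner(samples_ddim) swaps in the refiner checkpoint under SkipWritingToConfig, rebuilds conditioning, and resumes denoising via sample_img2img for the remaining steps; (4) decode_latent_batch(...) then decodes the refined latents with the currently loaded refiner model.
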
@@ -1065,8 +1136,6 @@ class StableDiffusionProcessingTxt2Img(StableDiffusionProcessing):
self.extra_generation_params["Hires upscaler"] = self.hr_upscaler
def sample(self, conditioning, unconditional_conditioning, seeds, subseeds, subseed_strength, prompts):
self.sampler = sd_samplers.create_sampler(self.sampler_name, self.sd_model)
x = create_random_tensors([opt_C, self.height // opt_f, self.width // opt_f], seeds=seeds, subseeds=subseeds, subseed_strength=self.subseed_strength, seed_resize_from_h=self.seed_resize_from_h, seed_resize_from_w=self.seed_resize_from_w, p=self)
samples = self.sampler.sample(self, x, conditioning, unconditional_conditioning, image_conditioning=self.txt2img_image_conditioning(x))
del x
@@ -1288,7 +1357,6 @@ class StableDiffusionProcessingImg2Img(StableDiffusionProcessing):
         self.image_conditioning = None

     def init(self, all_prompts, all_seeds, all_subseeds):
-        self.sampler = sd_samplers.create_sampler(self.sampler_name, self.sd_model)
         crop_region = None

         image_mask = self.image_mask

View File

@@ -289,11 +289,27 @@ def get_checkpoint_state_dict(checkpoint_info: CheckpointInfo, timer):
     return res


+class SkipWritingToConfig:
+    """This context manager prevents load_model_weights from writing the checkpoint name to the config when it loads weights."""
+
+    skip = False
+    previous = None
+
+    def __enter__(self):
+        self.previous = SkipWritingToConfig.skip
+        SkipWritingToConfig.skip = True
+        return self
+
+    def __exit__(self, exc_type, exc_value, exc_traceback):
+        SkipWritingToConfig.skip = self.previous
+
+
 def load_model_weights(model, checkpoint_info: CheckpointInfo, state_dict, timer):
     sd_model_hash = checkpoint_info.calculate_shorthash()
     timer.record("calculate hash")

-    shared.opts.data["sd_model_checkpoint"] = checkpoint_info.title
+    if not SkipWritingToConfig.skip:
+        shared.opts.data["sd_model_checkpoint"] = checkpoint_info.title

     if state_dict is None:
         state_dict = get_checkpoint_state_dict(checkpoint_info, timer)
@@ -699,6 +715,7 @@ def reload_model_weights(sd_model=None, info=None):
print(f"Weights loaded in {timer.summary()}.")
model_data.set_sd_model(sd_model)
sd_unet.apply_unet()
return sd_model
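
SkipWritingToConfig toggles a class-level flag and restores the previous value on exit, so nested uses unwind correctly. A stand-alone reproduction of the pattern and its behavior:

class SkipWritingToConfig:
    skip = False

    def __enter__(self):
        self.previous = SkipWritingToConfig.skip
        SkipWritingToConfig.skip = True
        return self

    def __exit__(self, exc_type, exc_value, exc_traceback):
        SkipWritingToConfig.skip = self.previous

with SkipWritingToConfig():
    assert SkipWritingToConfig.skip
    with SkipWritingToConfig():          # nesting is safe
        assert SkipWritingToConfig.skip
    assert SkipWritingToConfig.skip      # outer scope still suppressing writes
assert not SkipWritingToConfig.skip      # default restored

The added sd_unet.apply_unet() call is what the "apply unet overrides after switching model" commit refers to: unet overrides are now reapplied every time reload_model_weights() finishes.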

View File

@@ -44,7 +44,7 @@ class VanillaStableDiffusionSampler:
         return 0

     def launch_sampling(self, steps, func):
-        state.sampling_steps = steps
+        state.sampling_steps = self.stop_at if self.stop_at is not None else steps
         state.sampling_step = 0

         try:

View File

@@ -276,6 +276,7 @@ class KDiffusionSampler:
         self.model_wrap_cfg = CFGDenoiser(self.model_wrap)
         self.sampler_noises = None
         self.stop_at = None
+        self.noisy_output = None
         self.eta = None
         self.config = None  # set by the function calling the constructor
         self.last_latent = None
@@ -297,6 +298,7 @@ class KDiffusionSampler:
         if opts.live_preview_content == "Combined":
             sd_samplers_common.store_latent(latent)
         self.last_latent = latent
+        self.noisy_output = d['x']

         if self.stop_at is not None and step > self.stop_at:
             raise sd_samplers_common.InterruptedException
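
Two latents are tracked on every callback: d['denoised'] (the model's current clean estimate, kept in last_latent for live previews) and d['x'] (the still-noisy latent at the current sigma, kept in noisy_output so run_refiner can resume denoising exactly where the base model stopped). Raising InterruptedException is the pre-existing mechanism for aborting a k-diffusion sampling loop from inside its callback; a stripped-down sketch of the pattern, with stand-in types:

class InterruptedException(BaseException):
    pass

def callback_state(self, d):
    self.last_latent = d['denoised']   # clean estimate, good for previews
    self.noisy_output = d['x']         # noisy latent at the current sigma, resumable state
    if self.stop_at is not None and d['i'] > self.stop_at:
        raise InterruptedException     # unwinds out of the sampling loop
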
@@ -305,7 +307,7 @@ class KDiffusionSampler:
             shared.total_tqdm.update()

     def launch_sampling(self, steps, func):
-        state.sampling_steps = steps
+        state.sampling_steps = self.stop_at if self.stop_at is not None else steps
         state.sampling_step = 0

         try:
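
Both samplers apply the same correction so the step counter matches the truncated base pass. In isolation, assuming steps=20 and stop_at=15:

steps, stop_at = 20, 15
sampling_steps = stop_at if stop_at is not None else steps
assert sampling_steps == 15    # progress ends where the base model hands off to the refiner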

View File

@@ -461,6 +461,8 @@ options_templates.update(options_section(('sd', "Stable Diffusion"), {
"CLIP_stop_at_last_layers": OptionInfo(1, "Clip skip", gr.Slider, {"minimum": 1, "maximum": 12, "step": 1}).link("wiki", "https://github.com/AUTOMATIC1111/stable-diffusion-webui/wiki/Features#clip-skip").info("ignore last layers of CLIP network; 1 ignores none, 2 ignores one layer"),
"upcast_attn": OptionInfo(False, "Upcast cross attention layer to float32"),
"randn_source": OptionInfo("GPU", "Random number generator source.", gr.Radio, {"choices": ["GPU", "CPU", "NV"]}).info("changes seeds drastically; use CPU to produce the same picture across different videocard vendors; use NV to produce same picture as on NVidia videocards"),
"sd_refiner_checkpoint": OptionInfo(None, "Refiner checkpoint", gr.Dropdown, lambda: {"choices": list_checkpoint_tiles()}, refresh=refresh_checkpoints).info("switch to another model in the middle of generation"),
"sd_refiner_switch_at": OptionInfo(1.0, "Refiner switch at", gr.Slider, {"minimum": 0.01, "maximum": 1.0, "step": 0.01}).info("fraction of sampling steps when the swtch to refiner model should happen; 1=never, 0.5=switch in the middle of generation"),
}))
options_templates.update(options_section(('sdxl', "Stable Diffusion XL"), {
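
A note on defaults: with sd_refiner_switch_at left at 1.0 the feature is inert, because the have_refiner check in processing.py requires the fraction to be strictly below 1.0 and the refiner checkpoint to differ from the loaded one. Illustrative values:

sd_refiner_switch_at = 1.0                       # default: never switch
sd_refiner_checkpoint = None                     # default: no refiner selected
loaded_title = 'sd_xl_base_1.0.safetensors'
have_refiner = sd_refiner_switch_at < 1.0 and loaded_title != sd_refiner_checkpoint
assert not have_refiner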