@@ -178,6 +178,8 @@ class StableDiffusionProcessing:
         self.extra_network_data = None
         self.seeds = None
         self.subseeds = None
+        self.recorded_checkpoint = None
+        self.recorded_checkpoint_hash = None
 
         self.step_multiplier = 1
         self.cached_uc = StableDiffusionProcessing.cached_uc
@@ -186,6 +188,7 @@ class StableDiffusionProcessing:
         self.c = None
 
         self.user = None
+        self.image_conditioning = None
 
     @property
     def sd_model(self):
@@ -276,10 +279,10 @@ class StableDiffusionProcessing:
         if self.sd_model.cond_stage_key == "edit":
             return self.edit_image_conditioning(source_image)
 
-        if self.sampler.conditioning_key in {'hybrid', 'concat'}:
+        if self.sd_model.model.conditioning_key in {'hybrid', 'concat'}:
             return self.inpainting_image_conditioning(source_image, latent_image, image_mask=image_mask)
 
-        if self.sampler.conditioning_key == "crossattn-adm":
+        if self.sd_model.model.conditioning_key == "crossattn-adm":
             return self.unclip_image_conditioning(source_image)
 
         # Dummy zero conditioning if we're not using inpainting or depth model.
@@ -377,6 +380,63 @@ class StableDiffusionProcessing:
         """Returns whether generated images need to be written to disk"""
         return opts.samples_save and not self.do_not_save_samples and (opts.save_incomplete_images or not state.interrupted and not state.skipped)
 
+    def run_refiner(self, samples):
+        shared.state.nextjob()
+
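+        # the base sampler was told to halt partway through; keep the step it
+        # stopped at and its partially denoised latent before discarding it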
+        stopped_at = self.sampler.stop_at
+        noisy_output = self.sampler.noisy_output
+        self.sampler = None
+
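+        # decode to image space now, while the base model's VAE is still loaded;
+        # this is only needed if the refiner's latent space differs (see below)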
+        a_is_sdxl = shared.sd_model.is_sdxl
+        decoded_noisy = decode_latent_batch(shared.sd_model, noisy_output, target_device=devices.cpu, check_for_nans=True)
+
+        refiner_checkpoint_info = sd_models.get_closet_checkpoint_match(shared.opts.sd_refiner_checkpoint)
+        if refiner_checkpoint_info is None:
+            raise Exception(f'Could not find checkpoint with name {shared.opts.sd_refiner_checkpoint}')
+
+        self.recorded_checkpoint = shared.sd_model.sd_checkpoint_info.name_for_extra
+        self.recorded_checkpoint_hash = shared.sd_model.sd_model_hash
+        self.extra_generation_params['Refiner'] = refiner_checkpoint_info.short_title
+        self.extra_generation_params['Refiner switch at'] = shared.opts.sd_refiner_switch_at
+
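+        # swap in the refiner weights without recording them as the new default checkpoint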
+        with sd_models.SkipWritingToConfig():
+            sd_models.reload_model_weights(info=refiner_checkpoint_info)
+
+        devices.torch_gc()
+        self.setup_conds()
+
+        b_is_sdxl = shared.sd_model.is_sdxl
+
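+        # a base and refiner of different architectures (e.g. SD1 base, SDXL refiner)
+        # have incompatible latent spaces, so re-encode the decoded images instead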
+        if a_is_sdxl != b_is_sdxl:
+            decoded_noisy = torch.stack(decoded_noisy).float()
+            decoded_noisy = torch.clamp((decoded_noisy + 1.0) / 2.0, min=0.0, max=1.0)
+            noisy_latent = images_tensor_to_samples(decoded_noisy, approximation_indexes.get(opts.sd_vae_encode_method), shared.sd_model)
+        else:
+            noisy_latent = noisy_output
+
+        x = torch.zeros_like(noisy_latent)
+
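+        # resume as img2img from the noisy latent over the remaining steps, with
+        # denoising_strength temporarily overridden to the fraction of steps left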
+        with devices.without_autocast() if devices.unet_needs_upcast else devices.autocast():
+            denoising_strength = self.denoising_strength
+
+            self.denoising_strength = 1.0 - (stopped_at + 1) / self.steps
+            self.image_conditioning = txt2img_image_conditioning(shared.sd_model, noisy_latent, self.width, self.height)
+            self.sampler = sd_samplers.create_sampler(self.sampler_name, shared.sd_model)
+            samples = self.sampler.sample_img2img(self, noisy_latent, x, self.c, self.uc, image_conditioning=self.image_conditioning, steps=max(1, self.steps - stopped_at - 1))
+
+            self.denoising_strength = denoising_strength
+
+        return samples
+
 
 class Processed:
     def __init__(self, p: StableDiffusionProcessing, images_list, seed=-1, info="", subseed=None, all_prompts=None, all_negative_prompts=None, all_seeds=None, all_subseeds=None, index_of_first_image=0, infotexts=None, comments=""):
@@ -553,6 +613,10 @@ class DecodedSamples(list):
 
 
 def decode_latent_batch(model, batch, target_device=None, check_for_nans=False):
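+    # batches already decoded to image space (e.g. by the hires-fix pass) go through unchanged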
+    if getattr(batch, 'already_decoded', False):
+        return batch
+
     samples = DecodedSamples()
 
     for i in range(batch.shape[0]):
@@ -632,8 +696,9 @@ def create_infotext(p, all_prompts, all_seeds, all_subseeds, comments=None, iter
         "Seed": p.all_seeds[0] if use_main_prompt else all_seeds[index],
         "Face restoration": (opts.face_restoration_model if p.restore_faces else None),
         "Size": f"{p.width}x{p.height}",
-        "Model hash": getattr(p, 'sd_model_hash', None if not opts.add_model_hash_to_info or not shared.sd_model.sd_model_hash else shared.sd_model.sd_model_hash),
-        "Model": (None if not opts.add_model_name_to_info else shared.sd_model.sd_checkpoint_info.name_for_extra),
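+        # prefer the checkpoint recorded before the refiner swap, so the infotext names the base model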
+        "Model hash": getattr(p, 'sd_model_hash', None if not opts.add_model_hash_to_info or not shared.sd_model.sd_model_hash else p.recorded_checkpoint_hash or shared.sd_model.sd_model_hash),
+        "Model": (None if not opts.add_model_name_to_info else p.recorded_checkpoint or shared.sd_model.sd_checkpoint_info.name_for_extra),
         "Variation seed": (None if p.subseed_strength == 0 else (p.all_subseeds[0] if use_main_prompt else all_subseeds[index])),
         "Variation seed strength": (None if p.subseed_strength == 0 else p.subseed_strength),
         "Seed resize from": (None if p.seed_resize_from_w <= 0 or p.seed_resize_from_h <= 0 else f"{p.seed_resize_from_w}x{p.seed_resize_from_h}"),
@@ -666,6 +731,10 @@ def process_images(p: StableDiffusionProcessing) -> Processed:
     stored_opts = {k: opts.data[k] for k in p.override_settings.keys()}
 
     try:
+        # after running the refiner, the refiner model is not unloaded - webui swaps back to the main model here
+        if shared.sd_model.sd_checkpoint_info.title != opts.sd_model_checkpoint:
+            sd_models.reload_model_weights()
+
         # if no checkpoint override or the override checkpoint can't be found, remove override entry and load opts checkpoint
         if sd_models.checkpoint_aliases.get(p.override_settings.get('sd_model_checkpoint')) is None:
             p.override_settings.pop('sd_model_checkpoint', None)
@@ -737,6 +806,9 @@ def process_images_inner(p: StableDiffusionProcessing) -> Processed:
     infotexts = []
     output_images = []
 
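+    # engage the refiner only if the switch happens before the end of sampling and the refiner is not already the active checkpoint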
+    have_refiner = shared.opts.sd_refiner_switch_at < 1.0 and shared.sd_model.sd_checkpoint_info.title != shared.opts.sd_refiner_checkpoint
+
     with torch.no_grad(), p.sd_model.ema_scope():
         with devices.autocast():
             p.init(p.all_prompts, p.all_seeds, p.all_subseeds)
@@ -750,6 +822,11 @@ def process_images_inner(p: StableDiffusionProcessing) -> Processed:
         if state.job_count == -1:
             state.job_count = p.n_iter
 
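+        # the refiner pass counts as a second job per image; the progress bar total stays at the original step count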
+        if have_refiner:
+            state.job_count *= 2
+            shared.total_tqdm.updateTotal(p.steps * state.job_count // 2)
+
         for n in range(p.n_iter):
             p.iteration = n
 
@@ -759,6 +836,8 @@ def process_images_inner(p: StableDiffusionProcessing) -> Processed:
             if state.interrupted:
                 break
 
+            sd_models.reload_model_weights()  # the model can be changed, for example, by the refiner
+
             p.prompts = p.all_prompts[n * p.batch_size:(n + 1) * p.batch_size]
             p.negative_prompts = p.all_negative_prompts[n * p.batch_size:(n + 1) * p.batch_size]
             p.seeds = p.all_seeds[n * p.batch_size:(n + 1) * p.batch_size]
@@ -799,15 +878,21 @@ def process_images_inner(p: StableDiffusionProcessing) -> Processed:
             shared.state.job = f"Batch {n+1} out of {p.n_iter}"
 
             with devices.without_autocast() if devices.unet_needs_upcast else devices.autocast():
+                p.sampler = sd_samplers.create_sampler(p.sampler_name, p.sd_model)
+
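+                # halt the base pass at the configured switch point; p.run_refiner() resumes from there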
+                if have_refiner:
+                    p.sampler.stop_at = max(1, int(shared.opts.sd_refiner_switch_at * p.steps - 1))
+
                 samples_ddim = p.sample(conditioning=p.c, unconditional_conditioning=p.uc, seeds=p.seeds, subseeds=p.subseeds, subseed_strength=p.subseed_strength, prompts=p.prompts)
 
-            if getattr(samples_ddim, 'already_decoded', False):
-                x_samples_ddim = samples_ddim
-            else:
-                if opts.sd_vae_decode_method != 'Full':
-                    p.extra_generation_params['VAE Decoder'] = opts.sd_vae_decode_method
+            if opts.sd_vae_decode_method != 'Full':
+                p.extra_generation_params['VAE Decoder'] = opts.sd_vae_decode_method
 
-                x_samples_ddim = decode_latent_batch(p.sd_model, samples_ddim, target_device=devices.cpu, check_for_nans=True)
+            if have_refiner:
+                samples_ddim = p.run_refiner(samples_ddim)
+
+            x_samples_ddim = decode_latent_batch(p.sd_model, samples_ddim, target_device=devices.cpu, check_for_nans=True)
 
             x_samples_ddim = torch.stack(x_samples_ddim).float()
             x_samples_ddim = torch.clamp((x_samples_ddim + 1.0) / 2.0, min=0.0, max=1.0)
@@ -1065,8 +1150,6 @@ class StableDiffusionProcessingTxt2Img(StableDiffusionProcessing):
             self.extra_generation_params["Hires upscaler"] = self.hr_upscaler
 
     def sample(self, conditioning, unconditional_conditioning, seeds, subseeds, subseed_strength, prompts):
-        self.sampler = sd_samplers.create_sampler(self.sampler_name, self.sd_model)
-
         x = create_random_tensors([opt_C, self.height // opt_f, self.width // opt_f], seeds=seeds, subseeds=subseeds, subseed_strength=self.subseed_strength, seed_resize_from_h=self.seed_resize_from_h, seed_resize_from_w=self.seed_resize_from_w, p=self)
         samples = self.sampler.sample(self, x, conditioning, unconditional_conditioning, image_conditioning=self.txt2img_image_conditioning(x))
         del x
@@ -1288,7 +1371,6 @@ class StableDiffusionProcessingImg2Img(StableDiffusionProcessing):
         self.image_conditioning = None
 
     def init(self, all_prompts, all_seeds, all_subseeds):
-        self.sampler = sd_samplers.create_sampler(self.sampler_name, self.sd_model)
         crop_region = None
 
         image_mask = self.image_mask