mirror of
https://github.com/lllyasviel/stable-diffusion-webui-forge.git
synced 2026-06-04 21:05:48 +08:00
add oom reset
This commit is contained in:
parent
9007164a4e
commit
f82182b082
@ -1435,3 +1435,72 @@ def soft_empty_cache(force=False):
|
||||
|
||||
def unload_all_models():
    """Ask ``free_memory`` to release everything on the active torch device.

    The call passes ``free_all=True`` together with a very large requested
    size (``1e30``); presumably the size is chosen so that no loaded model
    can satisfy the request without being unloaded -- confirm against
    ``free_memory``.
    """
    target_device = get_torch_device()
    free_memory(1e30, target_device, free_all=True)
|
||||
|
||||
|
||||
def emergency_memory_cleanup():
    """
    Emergency memory cleanup function for OOM recovery.

    This aggressively clears all GPU memory and unloads all models
    to allow the application to continue without requiring a restart.

    The steps run in order:

    1. ``unload_all_models()``.
    2. ``model_unload(avoid_model_moving=True)`` on every entry still left in
       ``current_loaded_models``, then the list is cleared.
    3. Two passes of ``gc.collect()``.
    4. The backend cache is emptied (CUDA, Intel XPU or MPS).
    5. The module-level ``signal_empty_cache`` flag is reset to ``False``.
    6. A final ``gc.collect()`` and a free-memory report.

    Steps 1, 2 and 4 and the final report are best effort: an ``Exception``
    raised there is printed as a warning and the cleanup carries on.
    ``KeyboardInterrupt`` and ``SystemExit`` are deliberately *not* swallowed,
    so the user can still abort the process while cleanup is running.
    """
    import gc

    global signal_empty_cache

    banner = "=" * 60

    print("\n" + banner)
    print("[OOM Recovery] Emergency memory cleanup initiated...")
    print(banner)

    # Step 1: Unload all models from GPU
    print("[OOM Recovery] Step 1: Unloading all models...")
    try:
        unload_all_models()
    except Exception as e:
        print(f"[OOM Recovery] Warning during model unload: {e}")

    # Step 2: Clear the current_loaded_models list
    # (the list is only mutated in place, so no ``global`` statement is needed)
    print("[OOM Recovery] Step 2: Clearing model tracking list...")
    try:
        # Iterate over a snapshot so the loop stays valid even if
        # model_unload() touches the tracking list.
        for model in list(current_loaded_models):
            try:
                model.model_unload(avoid_model_moving=True)
            except Exception as e:
                # One stubborn model must not stop the others from being
                # released, but the failure should be visible.
                print(f"[OOM Recovery] Warning while unloading a model: {e}")
        current_loaded_models.clear()
    except Exception as e:
        print(f"[OOM Recovery] Warning during model list cleanup: {e}")

    # Step 3: Force Python garbage collection
    print("[OOM Recovery] Step 3: Running garbage collection...")
    gc.collect()
    gc.collect()  # Run twice for thorough cleanup

    # Step 4: Clear PyTorch CUDA cache
    print("[OOM Recovery] Step 4: Clearing GPU cache...")
    try:
        if torch.cuda.is_available():
            torch.cuda.empty_cache()
            torch.cuda.ipc_collect()
            torch.cuda.synchronize()
        elif is_intel_xpu():
            torch.xpu.empty_cache()
        elif mps_mode():
            torch.mps.empty_cache()
    except Exception as e:
        print(f"[OOM Recovery] Warning during cache clearing: {e}")

    # Step 5: Reset signal flags
    signal_empty_cache = False

    # Step 6: Final garbage collection
    gc.collect()

    # Report memory status; the report is informational only, so a failure
    # here falls back to the plain completion message.
    try:
        device = get_torch_device()
        free_mem = get_free_memory(device)
        print(f"[OOM Recovery] Cleanup complete. Free memory: {free_mem / (1024*1024):.2f} MB")
    except Exception:
        print("[OOM Recovery] Cleanup complete.")

    print(banner + "\n")
|
||||
|
||||
@ -972,7 +972,14 @@ def process_images_inner(p: StableDiffusionProcessing) -> Processed:
|
||||
if p.scripts is not None:
|
||||
p.scripts.process_batch(p, batch_number=n, prompts=p.prompts, seeds=p.seeds, subseeds=p.subseeds)
|
||||
|
||||
p.setup_conds()
|
||||
try:
|
||||
p.setup_conds()
|
||||
except memory_management.OOM_EXCEPTION as e:
|
||||
memory_management.emergency_memory_cleanup()
|
||||
raise RuntimeError(
|
||||
"Out of memory during text encoding. Memory has been cleared. "
|
||||
"Please try again with a shorter prompt or reduce GPU Weights slider."
|
||||
) from e
|
||||
|
||||
p.extra_generation_params.update(p.sd_model.extra_generation_params)
|
||||
|
||||
@ -1000,7 +1007,14 @@ def process_images_inner(p: StableDiffusionProcessing) -> Processed:
|
||||
sigmas_backup = p.sd_model.forge_objects.unet.model.predictor.sigmas
|
||||
p.sd_model.forge_objects.unet.model.predictor.set_sigmas(rescale_zero_terminal_snr_sigmas(p.sd_model.forge_objects.unet.model.predictor.sigmas))
|
||||
|
||||
samples_ddim = p.sample(conditioning=p.c, unconditional_conditioning=p.uc, seeds=p.seeds, subseeds=p.subseeds, subseed_strength=p.subseed_strength, prompts=p.prompts)
|
||||
try:
|
||||
samples_ddim = p.sample(conditioning=p.c, unconditional_conditioning=p.uc, seeds=p.seeds, subseeds=p.subseeds, subseed_strength=p.subseed_strength, prompts=p.prompts)
|
||||
except memory_management.OOM_EXCEPTION as e:
|
||||
memory_management.emergency_memory_cleanup()
|
||||
raise RuntimeError(
|
||||
"Out of memory during sampling. Memory has been cleared. "
|
||||
"Please try again with a smaller resolution, fewer steps, or reduce GPU Weights slider."
|
||||
) from e
|
||||
|
||||
for x_sample in samples_ddim:
|
||||
p.latents_after_sampling.append(x_sample)
|
||||
@ -1020,7 +1034,14 @@ def process_images_inner(p: StableDiffusionProcessing) -> Processed:
|
||||
|
||||
if opts.sd_vae_decode_method != 'Full':
|
||||
p.extra_generation_params['VAE Decoder'] = opts.sd_vae_decode_method
|
||||
x_samples_ddim = decode_latent_batch(p.sd_model, samples_ddim, target_device=devices.cpu, check_for_nans=True)
|
||||
try:
|
||||
x_samples_ddim = decode_latent_batch(p.sd_model, samples_ddim, target_device=devices.cpu, check_for_nans=True)
|
||||
except memory_management.OOM_EXCEPTION as e:
|
||||
memory_management.emergency_memory_cleanup()
|
||||
raise RuntimeError(
|
||||
"Out of memory during VAE decode. Memory has been cleared. "
|
||||
"Please try again with a smaller resolution or use tiled VAE decoding."
|
||||
) from e
|
||||
|
||||
x_samples_ddim = torch.stack(x_samples_ddim).float()
|
||||
x_samples_ddim = torch.clamp((x_samples_ddim + 1.0) / 2.0, min=0.0, max=1.0)
|
||||
|
||||
Loading…
Reference in New Issue
Block a user