def emergency_memory_cleanup():
    """
    Best-effort memory cleanup intended for recovery after an out-of-memory error.

    The goal is to release as much accelerator memory as possible so the
    application can keep running instead of needing a restart. Steps, in order
    (progress is printed with an "[OOM Recovery]" prefix):

    1. Call ``unload_all_models()``.
    2. Call ``model_unload(avoid_model_moving=True)`` on every entry still in
       ``current_loaded_models``, then empty that list in place.
    3. Run ``gc.collect()`` twice.
    4. Empty the allocator cache of the first available backend
       (CUDA, else Intel XPU, else MPS).
    5. Reset the module-level ``signal_empty_cache`` flag to ``False``.
    6. Run ``gc.collect()`` once more and print the free memory reported for
       ``get_torch_device()``.

    Any ``Exception`` raised inside a step is printed as a warning and the
    remaining steps still run. ``KeyboardInterrupt`` and ``SystemExit`` are
    deliberately not intercepted, so the user can still abort the process.

    Returns:
        None.
    """
    import gc

    # Only ``signal_empty_cache`` is rebound below; ``current_loaded_models``
    # is mutated in place and therefore needs no ``global`` declaration.
    global signal_empty_cache

    banner = "=" * 60

    print("\n" + banner)
    print("[OOM Recovery] Emergency memory cleanup initiated...")
    print(banner)

    # Step 1: Unload all models from GPU
    print("[OOM Recovery] Step 1: Unloading all models...")
    try:
        unload_all_models()
    except Exception as e:
        print(f"[OOM Recovery] Warning during model unload: {e}")

    # Step 2: Clear the current_loaded_models list
    print("[OOM Recovery] Step 2: Clearing model tracking list...")
    try:
        # Walk a snapshot in case an unload alters the tracking list itself.
        for model in list(current_loaded_models):
            try:
                model.model_unload(avoid_model_moving=True)
            except Exception as e:
                # One stubborn model must not prevent the others from being
                # released, but the failure should still be visible.
                print(f"[OOM Recovery] Warning: failed to unload a tracked model: {e}")
        current_loaded_models.clear()
    except Exception as e:
        print(f"[OOM Recovery] Warning during model list cleanup: {e}")

    # Step 3: Force Python garbage collection
    print("[OOM Recovery] Step 3: Running garbage collection...")
    gc.collect()
    gc.collect()  # Run twice for thorough cleanup

    # Step 4: Clear the accelerator allocator cache
    print("[OOM Recovery] Step 4: Clearing GPU cache...")
    try:
        if torch.cuda.is_available():
            torch.cuda.empty_cache()
            torch.cuda.ipc_collect()
            torch.cuda.synchronize()
        elif is_intel_xpu():
            torch.xpu.empty_cache()
        elif mps_mode():
            torch.mps.empty_cache()
    except Exception as e:
        print(f"[OOM Recovery] Warning during cache clearing: {e}")

    # Step 5: Reset signal flags
    signal_empty_cache = False

    # Step 6: Final garbage collection
    gc.collect()

    # Report memory status; fall back to a plain message if it cannot be queried.
    try:
        device = get_torch_device()
        free_mem = get_free_memory(device)
        print(f"[OOM Recovery] Cleanup complete. Free memory: {free_mem / (1024*1024):.2f} MB")
    except Exception:
        print("[OOM Recovery] Cleanup complete.")

    print(banner + "\n")
+ ) from e for x_sample in samples_ddim: p.latents_after_sampling.append(x_sample) @@ -1020,7 +1034,14 @@ def process_images_inner(p: StableDiffusionProcessing) -> Processed: if opts.sd_vae_decode_method != 'Full': p.extra_generation_params['VAE Decoder'] = opts.sd_vae_decode_method - x_samples_ddim = decode_latent_batch(p.sd_model, samples_ddim, target_device=devices.cpu, check_for_nans=True) + try: + x_samples_ddim = decode_latent_batch(p.sd_model, samples_ddim, target_device=devices.cpu, check_for_nans=True) + except memory_management.OOM_EXCEPTION as e: + memory_management.emergency_memory_cleanup() + raise RuntimeError( + "Out of memory during VAE decode. Memory has been cleared. " + "Please try again with a smaller resolution or use tiled VAE decoding." + ) from e x_samples_ddim = torch.stack(x_samples_ddim).float() x_samples_ddim = torch.clamp((x_samples_ddim + 1.0) / 2.0, min=0.0, max=1.0)