add oom reset

This commit is contained in:
maybleMyers 2025-12-03 14:23:25 -08:00
parent 9007164a4e
commit f82182b082
2 changed files with 93 additions and 3 deletions

View File

@ -1435,3 +1435,72 @@ def soft_empty_cache(force=False):
def unload_all_models():
free_memory(1e30, get_torch_device(), free_all=True)
def emergency_memory_cleanup():
"""
Emergency memory cleanup function for OOM recovery.
This aggressively clears all GPU memory and unloads all models
to allow the application to continue without requiring a restart.
"""
import gc
print("\n" + "="*60)
print("[OOM Recovery] Emergency memory cleanup initiated...")
print("="*60)
# Step 1: Unload all models from GPU
print("[OOM Recovery] Step 1: Unloading all models...")
try:
unload_all_models()
except Exception as e:
print(f"[OOM Recovery] Warning during model unload: {e}")
# Step 2: Clear the current_loaded_models list
print("[OOM Recovery] Step 2: Clearing model tracking list...")
global current_loaded_models
try:
for model in current_loaded_models:
try:
model.model_unload(avoid_model_moving=True)
except:
pass
current_loaded_models.clear()
except Exception as e:
print(f"[OOM Recovery] Warning during model list cleanup: {e}")
# Step 3: Force Python garbage collection
print("[OOM Recovery] Step 3: Running garbage collection...")
gc.collect()
gc.collect() # Run twice for thorough cleanup
# Step 4: Clear PyTorch CUDA cache
print("[OOM Recovery] Step 4: Clearing GPU cache...")
try:
if torch.cuda.is_available():
torch.cuda.empty_cache()
torch.cuda.ipc_collect()
torch.cuda.synchronize()
elif is_intel_xpu():
torch.xpu.empty_cache()
elif mps_mode():
torch.mps.empty_cache()
except Exception as e:
print(f"[OOM Recovery] Warning during cache clearing: {e}")
# Step 5: Reset signal flags
global signal_empty_cache
signal_empty_cache = False
# Step 6: Final garbage collection
gc.collect()
# Report memory status
try:
device = get_torch_device()
free_mem = get_free_memory(device)
print(f"[OOM Recovery] Cleanup complete. Free memory: {free_mem / (1024*1024):.2f} MB")
except:
print("[OOM Recovery] Cleanup complete.")
print("="*60 + "\n")

View File

@ -972,7 +972,14 @@ def process_images_inner(p: StableDiffusionProcessing) -> Processed:
if p.scripts is not None:
p.scripts.process_batch(p, batch_number=n, prompts=p.prompts, seeds=p.seeds, subseeds=p.subseeds)
p.setup_conds()
try:
p.setup_conds()
except memory_management.OOM_EXCEPTION as e:
memory_management.emergency_memory_cleanup()
raise RuntimeError(
"Out of memory during text encoding. Memory has been cleared. "
"Please try again with a shorter prompt or reduce GPU Weights slider."
) from e
p.extra_generation_params.update(p.sd_model.extra_generation_params)
@ -1000,7 +1007,14 @@ def process_images_inner(p: StableDiffusionProcessing) -> Processed:
sigmas_backup = p.sd_model.forge_objects.unet.model.predictor.sigmas
p.sd_model.forge_objects.unet.model.predictor.set_sigmas(rescale_zero_terminal_snr_sigmas(p.sd_model.forge_objects.unet.model.predictor.sigmas))
samples_ddim = p.sample(conditioning=p.c, unconditional_conditioning=p.uc, seeds=p.seeds, subseeds=p.subseeds, subseed_strength=p.subseed_strength, prompts=p.prompts)
try:
samples_ddim = p.sample(conditioning=p.c, unconditional_conditioning=p.uc, seeds=p.seeds, subseeds=p.subseeds, subseed_strength=p.subseed_strength, prompts=p.prompts)
except memory_management.OOM_EXCEPTION as e:
memory_management.emergency_memory_cleanup()
raise RuntimeError(
"Out of memory during sampling. Memory has been cleared. "
"Please try again with a smaller resolution, fewer steps, or reduce GPU Weights slider."
) from e
for x_sample in samples_ddim:
p.latents_after_sampling.append(x_sample)
@ -1020,7 +1034,14 @@ def process_images_inner(p: StableDiffusionProcessing) -> Processed:
if opts.sd_vae_decode_method != 'Full':
p.extra_generation_params['VAE Decoder'] = opts.sd_vae_decode_method
x_samples_ddim = decode_latent_batch(p.sd_model, samples_ddim, target_device=devices.cpu, check_for_nans=True)
try:
x_samples_ddim = decode_latent_batch(p.sd_model, samples_ddim, target_device=devices.cpu, check_for_nans=True)
except memory_management.OOM_EXCEPTION as e:
memory_management.emergency_memory_cleanup()
raise RuntimeError(
"Out of memory during VAE decode. Memory has been cleared. "
"Please try again with a smaller resolution or use tiled VAE decoding."
) from e
x_samples_ddim = torch.stack(x_samples_ddim).float()
x_samples_ddim = torch.clamp((x_samples_ddim + 1.0) / 2.0, min=0.0, max=1.0)