add oom reset

This commit is contained in:
maybleMyers 2025-12-03 14:39:36 -08:00
parent f82182b082
commit 814016a4e4
2 changed files with 117 additions and 15 deletions

View File

@ -1449,15 +1449,56 @@ def emergency_memory_cleanup():
print("[OOM Recovery] Emergency memory cleanup initiated...")
print("="*60)
# Step 1: Unload all models from GPU
print("[OOM Recovery] Step 1: Unloading all models...")
# Step 1: Delete all model references (don't try to move - we're OOM)
# Note: forge_objects, forge_objects_original, and forge_objects_after_applying_lora
# are shallow copies that share the SAME model references, so we must clear all three
print("[OOM Recovery] Step 1: Deleting all model references...")
try:
from modules import shared
if hasattr(shared, 'sd_model') and shared.sd_model is not None:
model = shared.sd_model
# Collect all forge_objects variants (they share the same underlying models)
forge_objects_list = []
for attr in ['forge_objects', 'forge_objects_original', 'forge_objects_after_applying_lora']:
if hasattr(model, attr) and getattr(model, attr) is not None:
forge_objects_list.append(getattr(model, attr))
setattr(model, attr, None)
# Clear all component references from all forge_objects
for fo in forge_objects_list:
# Clear unet reference
if hasattr(fo, 'unet'):
fo.unet = None
# Clear clip reference
if hasattr(fo, 'clip'):
fo.clip = None
# Clear vae reference
if hasattr(fo, 'vae'):
fo.vae = None
# Clear clipvision reference
if hasattr(fo, 'clipvision'):
fo.clipvision = None
del forge_objects_list
print("[OOM Recovery] Cleared forge_objects references")
# Clear the main model reference
shared.sd_model = None
del model
print("[OOM Recovery] Cleared shared.sd_model")
except Exception as e:
print(f"[OOM Recovery] Warning during shared model cleanup: {e}")
# Step 2: Unload all tracked models from GPU
print("[OOM Recovery] Step 2: Unloading all tracked models...")
try:
unload_all_models()
except Exception as e:
print(f"[OOM Recovery] Warning during model unload: {e}")
# Step 2: Clear the current_loaded_models list
print("[OOM Recovery] Step 2: Clearing model tracking list...")
# Step 3: Clear the current_loaded_models list
print("[OOM Recovery] Step 3: Clearing model tracking list...")
global current_loaded_models
try:
for model in current_loaded_models:
@ -1469,18 +1510,72 @@ def emergency_memory_cleanup():
except Exception as e:
print(f"[OOM Recovery] Warning during model list cleanup: {e}")
# Step 3: Force Python garbage collection
print("[OOM Recovery] Step 3: Running garbage collection...")
gc.collect()
gc.collect() # Run twice for thorough cleanup
# Step 4: Clear model_data if it exists
print("[OOM Recovery] Step 4: Clearing model data...")
try:
from modules import sd_models
if hasattr(sd_models, 'model_data'):
sd_models.model_data.sd_model = None
sd_models.model_data.forge_hash = None
except Exception as e:
print(f"[OOM Recovery] Warning during model_data cleanup: {e}")
# Step 4: Clear PyTorch CUDA cache
print("[OOM Recovery] Step 4: Clearing GPU cache...")
# Step 5: Move shared.sd_model to CPU if it is still set, then count the CUDA tensors that remain
print("[OOM Recovery] Step 5: Moving all GPU tensors to CPU...")
try:
# Get all modules that might have been loaded
modules_to_check = []
# Only shared.sd_model is checked; Step 1 normally sets it to None, so this matters only if that step failed
try:
from modules import shared
if hasattr(shared, 'sd_model') and shared.sd_model is not None:
modules_to_check.append(shared.sd_model)
except:
pass
# Move any found modules to CPU
for mod in modules_to_check:
if mod is not None and isinstance(mod, torch.nn.Module):
try:
mod.to('cpu')
print(f"[OOM Recovery] Moved {mod.__class__.__name__} to CPU")
except:
pass
# Diagnostic scan: walk every object tracked by the garbage collector
# and count the tensors still on CUDA (nothing is moved or deleted here)
moved_count = 0
for obj in gc.get_objects():
try:
if isinstance(obj, torch.Tensor) and obj.device.type == 'cuda':
# Counted only: no reference is removed here, despite the name moved_count
moved_count += 1
except:
pass
if moved_count > 0:
print(f"[OOM Recovery] Found {moved_count} CUDA tensors (will be freed after gc)")
except Exception as e:
print(f"[OOM Recovery] Warning during tensor scan: {e}")
# Step 6: Force Python garbage collection (multiple passes)
print("[OOM Recovery] Step 6: Running garbage collection...")
gc.collect()
gc.collect()
gc.collect() # Run three times for thorough cleanup
# Step 7: Clear PyTorch CUDA cache aggressively
print("[OOM Recovery] Step 7: Clearing GPU cache...")
try:
if torch.cuda.is_available():
torch.cuda.empty_cache()
torch.cuda.ipc_collect()
torch.cuda.synchronize()
# Reset memory stats
try:
torch.cuda.reset_peak_memory_stats()
torch.cuda.reset_accumulated_memory_stats()
except:
pass
elif is_intel_xpu():
torch.xpu.empty_cache()
elif mps_mode():
@ -1488,13 +1583,13 @@ def emergency_memory_cleanup():
except Exception as e:
print(f"[OOM Recovery] Warning during cache clearing: {e}")
# Step 5: Reset signal flags
# Step 8: Final garbage collection after cache clear
gc.collect()
# Step 9: Reset signal flags
global signal_empty_cache
signal_empty_cache = False
# Step 6: Final garbage collection
gc.collect()
# Report memory status
try:
device = get_torch_device()

View File

@ -501,7 +501,14 @@ def forge_model_reload():
dynamic_args['forge_unet_storage_dtype'] = model_data.forge_loading_parameters.get('unet_storage_dtype', None)
dynamic_args['embedding_dir'] = cmd_opts.embeddings_dir
dynamic_args['emphasis_name'] = opts.emphasis
sd_model = forge_loader(state_dict, additional_state_dicts=additional_state_dicts, preset=shared.opts.forge_preset)
try:
sd_model = forge_loader(state_dict, additional_state_dicts=additional_state_dicts, preset=shared.opts.forge_preset)
except memory_management.OOM_EXCEPTION as e:
memory_management.emergency_memory_cleanup()
raise RuntimeError(
"Out of memory during model loading. Memory has been cleared. "
"Please try again with a lower GPU Weights setting or use a smaller model."
) from e
timer.record("forge model load")
sd_model.extra_generation_params = {}