mirror of
https://github.com/lllyasviel/stable-diffusion-webui-forge.git
synced 2026-06-04 21:05:48 +08:00
add oom reset
This commit is contained in:
parent
f82182b082
commit
814016a4e4
@ -1449,15 +1449,56 @@ def emergency_memory_cleanup():
|
||||
print("[OOM Recovery] Emergency memory cleanup initiated...")
|
||||
print("="*60)
|
||||
|
||||
# Step 1: Unload all models from GPU
|
||||
print("[OOM Recovery] Step 1: Unloading all models...")
|
||||
# Step 1: Delete all model references (don't try to move - we're OOM)
|
||||
# Note: forge_objects, forge_objects_original, and forge_objects_after_applying_lora
|
||||
# are shallow copies that share the SAME model references, so we must clear all three
|
||||
print("[OOM Recovery] Step 1: Deleting all model references...")
|
||||
try:
|
||||
from modules import shared
|
||||
if hasattr(shared, 'sd_model') and shared.sd_model is not None:
|
||||
model = shared.sd_model
|
||||
|
||||
# Collect all forge_objects variants (they share the same underlying models)
|
||||
forge_objects_list = []
|
||||
for attr in ['forge_objects', 'forge_objects_original', 'forge_objects_after_applying_lora']:
|
||||
if hasattr(model, attr) and getattr(model, attr) is not None:
|
||||
forge_objects_list.append(getattr(model, attr))
|
||||
setattr(model, attr, None)
|
||||
|
||||
# Clear all component references from all forge_objects
|
||||
for fo in forge_objects_list:
|
||||
# Clear unet reference
|
||||
if hasattr(fo, 'unet'):
|
||||
fo.unet = None
|
||||
# Clear clip reference
|
||||
if hasattr(fo, 'clip'):
|
||||
fo.clip = None
|
||||
# Clear vae reference
|
||||
if hasattr(fo, 'vae'):
|
||||
fo.vae = None
|
||||
# Clear clipvision reference
|
||||
if hasattr(fo, 'clipvision'):
|
||||
fo.clipvision = None
|
||||
|
||||
del forge_objects_list
|
||||
print("[OOM Recovery] Cleared forge_objects references")
|
||||
|
||||
# Clear the main model reference
|
||||
shared.sd_model = None
|
||||
del model
|
||||
print("[OOM Recovery] Cleared shared.sd_model")
|
||||
except Exception as e:
|
||||
print(f"[OOM Recovery] Warning during shared model cleanup: {e}")
|
||||
|
||||
# Step 2: Unload all tracked models from GPU
|
||||
print("[OOM Recovery] Step 2: Unloading all tracked models...")
|
||||
try:
|
||||
unload_all_models()
|
||||
except Exception as e:
|
||||
print(f"[OOM Recovery] Warning during model unload: {e}")
|
||||
|
||||
# Step 2: Clear the current_loaded_models list
|
||||
print("[OOM Recovery] Step 2: Clearing model tracking list...")
|
||||
# Step 3: Clear the current_loaded_models list
|
||||
print("[OOM Recovery] Step 3: Clearing model tracking list...")
|
||||
global current_loaded_models
|
||||
try:
|
||||
for model in current_loaded_models:
|
||||
@ -1469,18 +1510,72 @@ def emergency_memory_cleanup():
|
||||
except Exception as e:
|
||||
print(f"[OOM Recovery] Warning during model list cleanup: {e}")
|
||||
|
||||
# Step 3: Force Python garbage collection
|
||||
print("[OOM Recovery] Step 3: Running garbage collection...")
|
||||
gc.collect()
|
||||
gc.collect() # Run twice for thorough cleanup
|
||||
# Step 4: Clear model_data if it exists
|
||||
print("[OOM Recovery] Step 4: Clearing model data...")
|
||||
try:
|
||||
from modules import sd_models
|
||||
if hasattr(sd_models, 'model_data'):
|
||||
sd_models.model_data.sd_model = None
|
||||
sd_models.model_data.forge_hash = None
|
||||
except Exception as e:
|
||||
print(f"[OOM Recovery] Warning during model_data cleanup: {e}")
|
||||
|
||||
# Step 4: Clear PyTorch CUDA cache
|
||||
print("[OOM Recovery] Step 4: Clearing GPU cache...")
|
||||
# Step 5: Move shared.sd_model (if it is still set) to CPU, then count the CUDA tensors that remain
|
||||
print("[OOM Recovery] Step 5: Moving all GPU tensors to CPU...")
|
||||
try:
|
||||
# Get all modules that might have been loaded
|
||||
modules_to_check = []
|
||||
# Check for any nn.Module in the shared namespace
|
||||
try:
|
||||
from modules import shared
|
||||
if hasattr(shared, 'sd_model') and shared.sd_model is not None:
|
||||
modules_to_check.append(shared.sd_model)
|
||||
except:
|
||||
pass
|
||||
|
||||
# Move any found modules to CPU
|
||||
for mod in modules_to_check:
|
||||
if mod is not None and isinstance(mod, torch.nn.Module):
|
||||
try:
|
||||
mod.to('cpu')
|
||||
print(f"[OOM Recovery] Moved {mod.__class__.__name__} to CPU")
|
||||
except:
|
||||
pass
|
||||
|
||||
# Diagnostic pass: walk every object tracked by the garbage collector
|
||||
# and count the tensors that are on CUDA (nothing is deleted or moved here)
|
||||
moved_count = 0
|
||||
for obj in gc.get_objects():
|
||||
try:
|
||||
if isinstance(obj, torch.Tensor) and obj.device.type == 'cuda':
|
||||
# Only counted here: this loop neither moves the tensor nor drops any reference to it
|
||||
moved_count += 1
|
||||
except:
|
||||
pass
|
||||
if moved_count > 0:
|
||||
print(f"[OOM Recovery] Found {moved_count} CUDA tensors (will be freed after gc)")
|
||||
except Exception as e:
|
||||
print(f"[OOM Recovery] Warning during tensor scan: {e}")
|
||||
|
||||
# Step 6: Force Python garbage collection (multiple passes)
|
||||
print("[OOM Recovery] Step 6: Running garbage collection...")
|
||||
gc.collect()
|
||||
gc.collect()
|
||||
gc.collect() # Run three times for thorough cleanup
|
||||
|
||||
# Step 7: Clear PyTorch CUDA cache aggressively
|
||||
print("[OOM Recovery] Step 7: Clearing GPU cache...")
|
||||
try:
|
||||
if torch.cuda.is_available():
|
||||
torch.cuda.empty_cache()
|
||||
torch.cuda.ipc_collect()
|
||||
torch.cuda.synchronize()
|
||||
# Reset memory stats
|
||||
try:
|
||||
torch.cuda.reset_peak_memory_stats()
|
||||
torch.cuda.reset_accumulated_memory_stats()
|
||||
except:
|
||||
pass
|
||||
elif is_intel_xpu():
|
||||
torch.xpu.empty_cache()
|
||||
elif mps_mode():
|
||||
@ -1488,13 +1583,13 @@ def emergency_memory_cleanup():
|
||||
except Exception as e:
|
||||
print(f"[OOM Recovery] Warning during cache clearing: {e}")
|
||||
|
||||
# Step 5: Reset signal flags
|
||||
# Step 8: Final garbage collection after cache clear
|
||||
gc.collect()
|
||||
|
||||
# Step 9: Reset signal flags
|
||||
global signal_empty_cache
|
||||
signal_empty_cache = False
|
||||
|
||||
# Step 6: Final garbage collection
|
||||
gc.collect()
|
||||
|
||||
# Report memory status
|
||||
try:
|
||||
device = get_torch_device()
|
||||
|
||||
@ -501,7 +501,14 @@ def forge_model_reload():
|
||||
dynamic_args['forge_unet_storage_dtype'] = model_data.forge_loading_parameters.get('unet_storage_dtype', None)
|
||||
dynamic_args['embedding_dir'] = cmd_opts.embeddings_dir
|
||||
dynamic_args['emphasis_name'] = opts.emphasis
|
||||
sd_model = forge_loader(state_dict, additional_state_dicts=additional_state_dicts, preset=shared.opts.forge_preset)
|
||||
try:
|
||||
sd_model = forge_loader(state_dict, additional_state_dicts=additional_state_dicts, preset=shared.opts.forge_preset)
|
||||
except memory_management.OOM_EXCEPTION as e:
|
||||
memory_management.emergency_memory_cleanup()
|
||||
raise RuntimeError(
|
||||
"Out of memory during model loading. Memory has been cleared. "
|
||||
"Please try again with a lower GPU Weights setting or use a smaller model."
|
||||
) from e
|
||||
timer.record("forge model load")
|
||||
|
||||
sd_model.extra_generation_params = {}
|
||||
|
||||
Loading…
Reference in New Issue
Block a user