diff --git a/backend/memory_management.py b/backend/memory_management.py index 32936d43..031baf1e 100644 --- a/backend/memory_management.py +++ b/backend/memory_management.py @@ -1449,15 +1449,56 @@ def emergency_memory_cleanup(): print("[OOM Recovery] Emergency memory cleanup initiated...") print("="*60) - # Step 1: Unload all models from GPU - print("[OOM Recovery] Step 1: Unloading all models...") + # Step 1: Delete all model references (don't try to move - we're OOM) + # Note: forge_objects, forge_objects_original, and forge_objects_after_applying_lora + # are shallow copies that share the SAME model references, so we must clear all three + print("[OOM Recovery] Step 1: Deleting all model references...") + try: + from modules import shared + if hasattr(shared, 'sd_model') and shared.sd_model is not None: + model = shared.sd_model + + # Collect all forge_objects variants (they share the same underlying models) + forge_objects_list = [] + for attr in ['forge_objects', 'forge_objects_original', 'forge_objects_after_applying_lora']: + if hasattr(model, attr) and getattr(model, attr) is not None: + forge_objects_list.append(getattr(model, attr)) + setattr(model, attr, None) + + # Clear all component references from all forge_objects + for fo in forge_objects_list: + # Clear unet reference + if hasattr(fo, 'unet'): + fo.unet = None + # Clear clip reference + if hasattr(fo, 'clip'): + fo.clip = None + # Clear vae reference + if hasattr(fo, 'vae'): + fo.vae = None + # Clear clipvision reference + if hasattr(fo, 'clipvision'): + fo.clipvision = None + + del forge_objects_list + print("[OOM Recovery] Cleared forge_objects references") + + # Clear the main model reference + shared.sd_model = None + del model + print("[OOM Recovery] Cleared shared.sd_model") + except Exception as e: + print(f"[OOM Recovery] Warning during shared model cleanup: {e}") + + # Step 2: Unload all tracked models from GPU + print("[OOM Recovery] Step 2: Unloading all tracked models...") try: unload_all_models() except Exception as e: print(f"[OOM Recovery] Warning during model unload: {e}") - # Step 2: Clear the current_loaded_models list - print("[OOM Recovery] Step 2: Clearing model tracking list...") + # Step 3: Clear the current_loaded_models list + print("[OOM Recovery] Step 3: Clearing model tracking list...") global current_loaded_models try: for model in current_loaded_models: @@ -1469,18 +1510,72 @@ def emergency_memory_cleanup(): except Exception as e: print(f"[OOM Recovery] Warning during model list cleanup: {e}") - # Step 3: Force Python garbage collection - print("[OOM Recovery] Step 3: Running garbage collection...") - gc.collect() - gc.collect() # Run twice for thorough cleanup + # Step 4: Clear model_data if it exists + print("[OOM Recovery] Step 4: Clearing model data...") + try: + from modules import sd_models + if hasattr(sd_models, 'model_data'): + sd_models.model_data.sd_model = None + sd_models.model_data.forge_hash = None + except Exception as e: + print(f"[OOM Recovery] Warning during model_data cleanup: {e}") - # Step 4: Clear PyTorch CUDA cache - print("[OOM Recovery] Step 4: Clearing GPU cache...") + # Step 5: Aggressively move ALL modules to CPU by scanning globals + print("[OOM Recovery] Step 5: Moving all GPU tensors to CPU...") + try: + # Get all modules that might have been loaded + modules_to_check = [] + # Check for any nn.Module in the shared namespace + try: + from modules import shared + if hasattr(shared, 'sd_model') and shared.sd_model is not None: + modules_to_check.append(shared.sd_model) + except: + pass + + # Move any found modules to CPU + for mod in modules_to_check: + if mod is not None and isinstance(mod, torch.nn.Module): + try: + mod.to('cpu') + print(f"[OOM Recovery] Moved {mod.__class__.__name__} to CPU") + except: + pass + + # Nuclear option: iterate through ALL tensors tracked by PyTorch + # and delete any that are on CUDA + moved_count = 0 + for obj in gc.get_objects(): + try: + if isinstance(obj, torch.Tensor) and obj.device.type == 'cuda': + # Can't directly move, but we can help GC by removing refs + moved_count += 1 + except: + pass + if moved_count > 0: + print(f"[OOM Recovery] Found {moved_count} CUDA tensors (will be freed after gc)") + except Exception as e: + print(f"[OOM Recovery] Warning during tensor scan: {e}") + + # Step 6: Force Python garbage collection (multiple passes) + print("[OOM Recovery] Step 6: Running garbage collection...") + gc.collect() + gc.collect() + gc.collect() # Run three times for thorough cleanup + + # Step 7: Clear PyTorch CUDA cache aggressively + print("[OOM Recovery] Step 7: Clearing GPU cache...") try: if torch.cuda.is_available(): torch.cuda.empty_cache() torch.cuda.ipc_collect() torch.cuda.synchronize() + # Reset memory stats + try: + torch.cuda.reset_peak_memory_stats() + torch.cuda.reset_accumulated_memory_stats() + except: + pass elif is_intel_xpu(): torch.xpu.empty_cache() elif mps_mode(): @@ -1488,13 +1583,13 @@ def emergency_memory_cleanup(): except Exception as e: print(f"[OOM Recovery] Warning during cache clearing: {e}") - # Step 5: Reset signal flags + # Step 8: Final garbage collection after cache clear + gc.collect() + + # Step 9: Reset signal flags global signal_empty_cache signal_empty_cache = False - # Step 6: Final garbage collection - gc.collect() - # Report memory status try: device = get_torch_device() diff --git a/modules/sd_models.py b/modules/sd_models.py index 1d68f4cf..da38cbf0 100644 --- a/modules/sd_models.py +++ b/modules/sd_models.py @@ -501,7 +501,14 @@ def forge_model_reload(): dynamic_args['forge_unet_storage_dtype'] = model_data.forge_loading_parameters.get('unet_storage_dtype', None) dynamic_args['embedding_dir'] = cmd_opts.embeddings_dir dynamic_args['emphasis_name'] = opts.emphasis - sd_model = forge_loader(state_dict, additional_state_dicts=additional_state_dicts, preset=shared.opts.forge_preset) + try: + sd_model = forge_loader(state_dict, additional_state_dicts=additional_state_dicts, preset=shared.opts.forge_preset) + except memory_management.OOM_EXCEPTION as e: + memory_management.emergency_memory_cleanup() + raise RuntimeError( + "Out of memory during model loading. Memory has been cleared. " + "Please try again with a lower GPU Weights setting or use a smaller model." + ) from e timer.record("forge model load") sd_model.extra_generation_params = {}