diff --git a/backend/chromadct_memory_strategy.py b/backend/chromadct_memory_strategy.py index 0377b82c..375051e1 100644 --- a/backend/chromadct_memory_strategy.py +++ b/backend/chromadct_memory_strategy.py @@ -165,7 +165,8 @@ def get_chromadct_inference_memory_multiplier() -> float: if is_chromadct_model(None): # ChromaDCT processes in pixel space (3 channels) vs latent space (16 channels) # and has more efficient NeRF processing - return 1 # reset + # Reduce inference memory requirement for ChromaDCT + return 0.75 # Use 75% of normal inference memory return 1.0 diff --git a/backend/memory_management.py b/backend/memory_management.py index 2f5e25c3..b36aed22 100644 --- a/backend/memory_management.py +++ b/backend/memory_management.py @@ -691,7 +691,8 @@ class LoadedModel: # Reset signal_empty_cache after model loading is complete to prevent # unnecessary cache clearing during inference - signal_empty_cache = False + if not stream.should_use_stream(): + signal_empty_cache = False return self.real_model @@ -719,30 +720,24 @@ current_inference_memory = 1024 * 1024 * 1024 def minimum_inference_memory(): global current_inference_memory - + # Apply ChromaDCT-specific memory optimization try: - from modules import shared - if (hasattr(shared, 'sd_model') and shared.sd_model is not None and - hasattr(shared.sd_model, 'forge_objects') and - hasattr(shared.sd_model.forge_objects, 'vae') and - shared.sd_model.forge_objects.vae is None): - - # ChromaDCT models are more memory efficient - reduce inference memory requirement - # ChromaDCT works in pixel space (3 channels) vs latent space (16 channels) - # and has more efficient patch-based processing - chromadct_inference_memory = int(current_inference_memory * 1) # reset - + from backend import chromadct_memory_strategy + if chromadct_memory_strategy.is_chromadct_model(None): + multiplier = chromadct_memory_strategy.get_chromadct_inference_memory_multiplier() + chromadct_inference_memory = int(current_inference_memory * multiplier) + # Only print message once per session if not hasattr(minimum_inference_memory, '_chromadct_message_shown'): - print(f"ChromaDCT detected - reducing inference memory from {current_inference_memory / (1024**2):.0f} MB to {chromadct_inference_memory / (1024**2):.0f} MB") + print(f"ChromaDCT detected - optimizing inference memory from {current_inference_memory / (1024**2):.0f} MB to {chromadct_inference_memory / (1024**2):.0f} MB") minimum_inference_memory._chromadct_message_shown = True - + return chromadct_inference_memory - except: + except Exception as e: # Fallback to normal memory if detection fails pass - + return current_inference_memory @@ -800,6 +795,14 @@ def free_memory(memory_required, device, keep_loaded=[], free_all=False): def compute_model_gpu_memory_when_using_cpu_swap(current_free_mem, inference_memory): maximum_memory_available = current_free_mem - inference_memory + # When using async swap with user-specified GPU weights, respect that setting + from modules_forge import main_entry + if hasattr(main_entry, 'user_specified_model_memory') and main_entry.user_specified_model_memory is not None: + # User has explicitly set GPU weights - use that value + user_memory = main_entry.user_specified_model_memory * 1024 * 1024 # Convert MB to bytes + if user_memory <= maximum_memory_available: + return int(user_memory) + suggestion = max( maximum_memory_available / 1.3, maximum_memory_available - 1024 * 1024 * 1024 * 1.25 @@ -847,11 +850,21 @@ def load_models_gpu(models, memory_required=0, hard_memory_preservation=0): loaded_model.compute_inclusive_exclusive_memory() total_memory_required[loaded_model.device] = total_memory_required.get(loaded_model.device, 0) + loaded_model.exclusive_memory + loaded_model.inclusive_memory * 0.25 + # When using async swap, calculate total GPU budget to properly allocate between models + total_models_memory = sum(loaded_model.exclusive_memory for loaded_model in models_to_load) + gpu_memory_budget = None + if stream.should_use_stream() and len(models_to_load) > 1: + # Get the user-specified GPU memory budget + from modules_forge import main_entry + if hasattr(main_entry, 'user_specified_model_memory') and main_entry.user_specified_model_memory is not None: + gpu_memory_budget = main_entry.user_specified_model_memory * 1024 * 1024 # Convert MB to bytes + print(f"[Async Swap] Using user-specified GPU budget: {gpu_memory_budget / (1024 * 1024):.2f} MB for {len(models_to_load)} models") + for device in total_memory_required: if device != torch.device("cpu"): free_memory(total_memory_required[device] * 1.3 + memory_to_free, device, models_already_loaded) - for loaded_model in models_to_load: + for idx, loaded_model in enumerate(models_to_load): model = loaded_model.model torch_dev = model.load_device if is_device_cpu(torch_dev): @@ -869,11 +882,22 @@ def load_models_gpu(models, memory_required=0, hard_memory_preservation=0): print(f"[Memory Management] Target: {loaded_model.model.model.__class__.__name__}, Free GPU: {current_free_mem / (1024 * 1024):.2f} MB, Model Require: {model_require / (1024 * 1024):.2f} MB, Previously Loaded: {previously_loaded / (1024 * 1024):.2f} MB, Inference Require: {memory_for_inference / (1024 * 1024):.2f} MB, Remaining: {estimated_remaining_memory / (1024 * 1024):.2f} MB, ", end="") + # Use same memory allocation logic for both async and queue methods if estimated_remaining_memory < 0: vram_set_state = VRAMState.LOW_VRAM - model_gpu_memory_when_using_cpu_swap = compute_model_gpu_memory_when_using_cpu_swap(current_free_mem, memory_for_inference) - if previously_loaded > 0: - model_gpu_memory_when_using_cpu_swap = previously_loaded + if gpu_memory_budget is not None and total_models_memory > 0: + # Allocate GPU memory proportionally based on model size + model_proportion = loaded_model.exclusive_memory / total_models_memory + model_gpu_memory_when_using_cpu_swap = int(gpu_memory_budget * model_proportion) + # Ensure we leave some memory for the last model in case of rounding + if idx == len(models_to_load) - 1: + already_allocated = sum(getattr(m, 'allocated_gpu_memory', 0) for m in models_to_load[:idx]) + model_gpu_memory_when_using_cpu_swap = int(gpu_memory_budget - already_allocated) + loaded_model.allocated_gpu_memory = model_gpu_memory_when_using_cpu_swap + else: + model_gpu_memory_when_using_cpu_swap = compute_model_gpu_memory_when_using_cpu_swap(current_free_mem, memory_for_inference) + if previously_loaded > 0: + model_gpu_memory_when_using_cpu_swap = previously_loaded if vram_set_state == VRAMState.NO_VRAM: model_gpu_memory_when_using_cpu_swap = 0 diff --git a/backend/operations.py b/backend/operations.py index 7a7e6500..d8b2f7df 100644 --- a/backend/operations.py +++ b/backend/operations.py @@ -8,9 +8,6 @@ from backend import stream, memory_management, utils from backend.patcher.lora import merge_lora_to_weight -stash = {} - - def get_weight_and_bias(layer, weight_args=None, bias_args=None, weight_fn=None, bias_fn=None): scale_weight = getattr(layer, 'scale_weight', None) patches = getattr(layer, 'forge_online_loras', None) @@ -56,7 +53,7 @@ def get_weight_and_bias(layer, weight_args=None, bias_args=None, weight_fn=None, def weights_manual_cast(layer, x, skip_weight_dtype=False, skip_bias_dtype=False, weight_fn=None, bias_fn=None): weight, bias, signal = None, None, None - non_blocking = True + non_blocking = stream.should_use_stream() # Use async transfer if stream is enabled if getattr(x.device, 'type', None) == 'mps': non_blocking = False @@ -74,45 +71,21 @@ def weights_manual_cast(layer, x, skip_weight_dtype=False, skip_bias_dtype=False else: bias_args = dict(device=target_device, dtype=target_dtype, non_blocking=non_blocking) - if stream.should_use_stream(): - with stream.stream_context()(stream.mover_stream): - weight, bias = get_weight_and_bias(layer, weight_args, bias_args, weight_fn=weight_fn, bias_fn=bias_fn) - signal = stream.mover_stream.record_event() - else: - weight, bias = get_weight_and_bias(layer, weight_args, bias_args, weight_fn=weight_fn, bias_fn=bias_fn) + # Simply get weights with async flag, no special stream handling + weight, bias = get_weight_and_bias(layer, weight_args, bias_args, weight_fn=weight_fn, bias_fn=bias_fn) return weight, bias, signal @contextlib.contextmanager def main_stream_worker(weight, bias, signal): - if signal is None or not stream.should_use_stream(): - yield - return - - with stream.stream_context()(stream.current_stream): - stream.current_stream.wait_event(signal) - yield - finished_signal = stream.current_stream.record_event() - stash[id(finished_signal)] = (weight, bias, finished_signal) - - garbage = [] - for k, (w, b, s) in stash.items(): - if s.query(): - garbage.append(k) - - for k in garbage: - del stash[k] + # Simple pass-through - async transfers are handled by non_blocking flag + yield return def cleanup_cache(): - if not stream.should_use_stream(): - return - - stream.current_stream.synchronize() - stream.mover_stream.synchronize() - stash.clear() + # No longer needed since we're not using stash return diff --git a/backend/sampling/sampling_function.py b/backend/sampling/sampling_function.py index b7ed6624..67791edc 100644 --- a/backend/sampling/sampling_function.py +++ b/backend/sampling/sampling_function.py @@ -188,7 +188,9 @@ def calc_cond_uncond_batch(model, cond, uncond, x_in, timestep, model_options): to_batch = to_batch_temp[:1] if memory_management.signal_empty_cache: - memory_management.soft_empty_cache() + # Don't empty cache during async swap to maintain model weights on GPU + if not memory_management.stream.should_use_stream(): + memory_management.soft_empty_cache() free_memory = memory_management.get_free_memory(x_in.device) diff --git a/modules_forge/main_entry.py b/modules_forge/main_entry.py index 8960d3ca..3ecbe9f2 100644 --- a/modules_forge/main_entry.py +++ b/modules_forge/main_entry.py @@ -10,6 +10,7 @@ from modules.shared import cmd_opts total_vram = int(memory_management.total_vram) +user_specified_model_memory = None # Track user-specified GPU weights ui_forge_preset: gr.Radio = None @@ -173,6 +174,7 @@ def ui_refresh_memory_management_settings(model_memory, async_loading, pin_share ) def refresh_memory_management_settings(async_loading=None, inference_memory=None, pin_shared_memory=None, model_memory=None): + global user_specified_model_memory # Fallback to defaults if values are not passed async_loading = async_loading if async_loading is not None else shared.opts.forge_async_loading inference_memory = inference_memory if inference_memory is not None else shared.opts.forge_inference_memory @@ -181,8 +183,10 @@ def refresh_memory_management_settings(async_loading=None, inference_memory=None # If model_memory is provided, calculate inference memory accordingly, otherwise use inference_memory directly if model_memory is None: model_memory = total_vram - inference_memory + user_specified_model_memory = None # Using inference memory slider else: inference_memory = total_vram - model_memory + user_specified_model_memory = model_memory # User explicitly set GPU weights shared.opts.set('forge_async_loading', async_loading) shared.opts.set('forge_inference_memory', inference_memory)