improve memory management when using async and normal chroma

This commit is contained in:
maybleMyers 2025-09-26 10:52:26 -07:00
parent 7aa721693a
commit 58f545bfba
5 changed files with 60 additions and 56 deletions

View File

@ -165,7 +165,8 @@ def get_chromadct_inference_memory_multiplier() -> float:
if is_chromadct_model(None):
# ChromaDCT processes in pixel space (3 channels) vs latent space (16 channels)
# and has more efficient NeRF processing
return 1 # reset
# Reduce inference memory requirement for ChromaDCT
return 0.75 # Use 75% of normal inference memory
return 1.0

View File

@ -691,7 +691,8 @@ class LoadedModel:
# Reset signal_empty_cache after model loading is complete to prevent
# unnecessary cache clearing during inference
signal_empty_cache = False
if not stream.should_use_stream():
signal_empty_cache = False
return self.real_model
@ -719,30 +720,24 @@ current_inference_memory = 1024 * 1024 * 1024
def minimum_inference_memory():
global current_inference_memory
# Apply ChromaDCT-specific memory optimization
try:
from modules import shared
if (hasattr(shared, 'sd_model') and shared.sd_model is not None and
hasattr(shared.sd_model, 'forge_objects') and
hasattr(shared.sd_model.forge_objects, 'vae') and
shared.sd_model.forge_objects.vae is None):
# ChromaDCT models are more memory efficient - reduce inference memory requirement
# ChromaDCT works in pixel space (3 channels) vs latent space (16 channels)
# and has more efficient patch-based processing
chromadct_inference_memory = int(current_inference_memory * 1) # reset
from backend import chromadct_memory_strategy
if chromadct_memory_strategy.is_chromadct_model(None):
multiplier = chromadct_memory_strategy.get_chromadct_inference_memory_multiplier()
chromadct_inference_memory = int(current_inference_memory * multiplier)
# Only print message once per session
if not hasattr(minimum_inference_memory, '_chromadct_message_shown'):
print(f"ChromaDCT detected - reducing inference memory from {current_inference_memory / (1024**2):.0f} MB to {chromadct_inference_memory / (1024**2):.0f} MB")
print(f"ChromaDCT detected - optimizing inference memory from {current_inference_memory / (1024**2):.0f} MB to {chromadct_inference_memory / (1024**2):.0f} MB")
minimum_inference_memory._chromadct_message_shown = True
return chromadct_inference_memory
except:
except Exception as e:
# Fallback to normal memory if detection fails
pass
return current_inference_memory
@ -800,6 +795,14 @@ def free_memory(memory_required, device, keep_loaded=[], free_all=False):
def compute_model_gpu_memory_when_using_cpu_swap(current_free_mem, inference_memory):
maximum_memory_available = current_free_mem - inference_memory
# When using async swap with user-specified GPU weights, respect that setting
from modules_forge import main_entry
if hasattr(main_entry, 'user_specified_model_memory') and main_entry.user_specified_model_memory is not None:
# User has explicitly set GPU weights - use that value
user_memory = main_entry.user_specified_model_memory * 1024 * 1024 # Convert MB to bytes
if user_memory <= maximum_memory_available:
return int(user_memory)
suggestion = max(
maximum_memory_available / 1.3,
maximum_memory_available - 1024 * 1024 * 1024 * 1.25
@ -847,11 +850,21 @@ def load_models_gpu(models, memory_required=0, hard_memory_preservation=0):
loaded_model.compute_inclusive_exclusive_memory()
total_memory_required[loaded_model.device] = total_memory_required.get(loaded_model.device, 0) + loaded_model.exclusive_memory + loaded_model.inclusive_memory * 0.25
# When using async swap, calculate total GPU budget to properly allocate between models
total_models_memory = sum(loaded_model.exclusive_memory for loaded_model in models_to_load)
gpu_memory_budget = None
if stream.should_use_stream() and len(models_to_load) > 1:
# Get the user-specified GPU memory budget
from modules_forge import main_entry
if hasattr(main_entry, 'user_specified_model_memory') and main_entry.user_specified_model_memory is not None:
gpu_memory_budget = main_entry.user_specified_model_memory * 1024 * 1024 # Convert MB to bytes
print(f"[Async Swap] Using user-specified GPU budget: {gpu_memory_budget / (1024 * 1024):.2f} MB for {len(models_to_load)} models")
for device in total_memory_required:
if device != torch.device("cpu"):
free_memory(total_memory_required[device] * 1.3 + memory_to_free, device, models_already_loaded)
for loaded_model in models_to_load:
for idx, loaded_model in enumerate(models_to_load):
model = loaded_model.model
torch_dev = model.load_device
if is_device_cpu(torch_dev):
@ -869,11 +882,22 @@ def load_models_gpu(models, memory_required=0, hard_memory_preservation=0):
print(f"[Memory Management] Target: {loaded_model.model.model.__class__.__name__}, Free GPU: {current_free_mem / (1024 * 1024):.2f} MB, Model Require: {model_require / (1024 * 1024):.2f} MB, Previously Loaded: {previously_loaded / (1024 * 1024):.2f} MB, Inference Require: {memory_for_inference / (1024 * 1024):.2f} MB, Remaining: {estimated_remaining_memory / (1024 * 1024):.2f} MB, ", end="")
# Use same memory allocation logic for both async and queue methods
if estimated_remaining_memory < 0:
vram_set_state = VRAMState.LOW_VRAM
model_gpu_memory_when_using_cpu_swap = compute_model_gpu_memory_when_using_cpu_swap(current_free_mem, memory_for_inference)
if previously_loaded > 0:
model_gpu_memory_when_using_cpu_swap = previously_loaded
if gpu_memory_budget is not None and total_models_memory > 0:
# Allocate GPU memory proportionally based on model size
model_proportion = loaded_model.exclusive_memory / total_models_memory
model_gpu_memory_when_using_cpu_swap = int(gpu_memory_budget * model_proportion)
# Ensure we leave some memory for the last model in case of rounding
if idx == len(models_to_load) - 1:
already_allocated = sum(getattr(m, 'allocated_gpu_memory', 0) for m in models_to_load[:idx])
model_gpu_memory_when_using_cpu_swap = int(gpu_memory_budget - already_allocated)
loaded_model.allocated_gpu_memory = model_gpu_memory_when_using_cpu_swap
else:
model_gpu_memory_when_using_cpu_swap = compute_model_gpu_memory_when_using_cpu_swap(current_free_mem, memory_for_inference)
if previously_loaded > 0:
model_gpu_memory_when_using_cpu_swap = previously_loaded
if vram_set_state == VRAMState.NO_VRAM:
model_gpu_memory_when_using_cpu_swap = 0

View File

@ -8,9 +8,6 @@ from backend import stream, memory_management, utils
from backend.patcher.lora import merge_lora_to_weight
stash = {}
def get_weight_and_bias(layer, weight_args=None, bias_args=None, weight_fn=None, bias_fn=None):
scale_weight = getattr(layer, 'scale_weight', None)
patches = getattr(layer, 'forge_online_loras', None)
@ -56,7 +53,7 @@ def get_weight_and_bias(layer, weight_args=None, bias_args=None, weight_fn=None,
def weights_manual_cast(layer, x, skip_weight_dtype=False, skip_bias_dtype=False, weight_fn=None, bias_fn=None):
weight, bias, signal = None, None, None
non_blocking = True
non_blocking = stream.should_use_stream() # Use async transfer if stream is enabled
if getattr(x.device, 'type', None) == 'mps':
non_blocking = False
@ -74,45 +71,21 @@ def weights_manual_cast(layer, x, skip_weight_dtype=False, skip_bias_dtype=False
else:
bias_args = dict(device=target_device, dtype=target_dtype, non_blocking=non_blocking)
if stream.should_use_stream():
with stream.stream_context()(stream.mover_stream):
weight, bias = get_weight_and_bias(layer, weight_args, bias_args, weight_fn=weight_fn, bias_fn=bias_fn)
signal = stream.mover_stream.record_event()
else:
weight, bias = get_weight_and_bias(layer, weight_args, bias_args, weight_fn=weight_fn, bias_fn=bias_fn)
# Simply get weights with async flag, no special stream handling
weight, bias = get_weight_and_bias(layer, weight_args, bias_args, weight_fn=weight_fn, bias_fn=bias_fn)
return weight, bias, signal
@contextlib.contextmanager
def main_stream_worker(weight, bias, signal):
if signal is None or not stream.should_use_stream():
yield
return
with stream.stream_context()(stream.current_stream):
stream.current_stream.wait_event(signal)
yield
finished_signal = stream.current_stream.record_event()
stash[id(finished_signal)] = (weight, bias, finished_signal)
garbage = []
for k, (w, b, s) in stash.items():
if s.query():
garbage.append(k)
for k in garbage:
del stash[k]
# Simple pass-through - async transfers are handled by non_blocking flag
yield
return
def cleanup_cache():
if not stream.should_use_stream():
return
stream.current_stream.synchronize()
stream.mover_stream.synchronize()
stash.clear()
# No longer needed since we're not using stash
return

View File

@ -188,7 +188,9 @@ def calc_cond_uncond_batch(model, cond, uncond, x_in, timestep, model_options):
to_batch = to_batch_temp[:1]
if memory_management.signal_empty_cache:
memory_management.soft_empty_cache()
# Don't empty cache during async swap to maintain model weights on GPU
if not memory_management.stream.should_use_stream():
memory_management.soft_empty_cache()
free_memory = memory_management.get_free_memory(x_in.device)

View File

@ -10,6 +10,7 @@ from modules.shared import cmd_opts
total_vram = int(memory_management.total_vram)
user_specified_model_memory = None # Track user-specified GPU weights
ui_forge_preset: gr.Radio = None
@ -173,6 +174,7 @@ def ui_refresh_memory_management_settings(model_memory, async_loading, pin_share
)
def refresh_memory_management_settings(async_loading=None, inference_memory=None, pin_shared_memory=None, model_memory=None):
global user_specified_model_memory
# Fallback to defaults if values are not passed
async_loading = async_loading if async_loading is not None else shared.opts.forge_async_loading
inference_memory = inference_memory if inference_memory is not None else shared.opts.forge_inference_memory
@ -181,8 +183,10 @@ def refresh_memory_management_settings(async_loading=None, inference_memory=None
# If model_memory is provided, calculate inference memory accordingly, otherwise use inference_memory directly
if model_memory is None:
model_memory = total_vram - inference_memory
user_specified_model_memory = None # Using inference memory slider
else:
inference_memory = total_vram - model_memory
user_specified_model_memory = model_memory # User explicitly set GPU weights
shared.opts.set('forge_async_loading', async_loading)
shared.opts.set('forge_inference_memory', inference_memory)