mirror of
https://github.com/lllyasviel/stable-diffusion-webui-forge.git
synced 2026-06-13 21:01:06 +08:00
improve memory management when using async and normal chroma
This commit is contained in:
parent
7aa721693a
commit
58f545bfba
@ -165,7 +165,8 @@ def get_chromadct_inference_memory_multiplier() -> float:
|
||||
if is_chromadct_model(None):
|
||||
# ChromaDCT processes in pixel space (3 channels) vs latent space (16 channels)
|
||||
# and has more efficient NeRF processing
|
||||
return 1 # reset
|
||||
# Reduce inference memory requirement for ChromaDCT
|
||||
return 0.75 # Use 75% of normal inference memory
|
||||
return 1.0
|
||||
|
||||
|
||||
|
||||
@ -691,7 +691,8 @@ class LoadedModel:
|
||||
|
||||
# Reset signal_empty_cache after model loading is complete to prevent
|
||||
# unnecessary cache clearing during inference
|
||||
signal_empty_cache = False
|
||||
if not stream.should_use_stream():
|
||||
signal_empty_cache = False
|
||||
|
||||
return self.real_model
|
||||
|
||||
@ -719,30 +720,24 @@ current_inference_memory = 1024 * 1024 * 1024
|
||||
|
||||
def minimum_inference_memory():
|
||||
global current_inference_memory
|
||||
|
||||
|
||||
# Apply ChromaDCT-specific memory optimization
|
||||
try:
|
||||
from modules import shared
|
||||
if (hasattr(shared, 'sd_model') and shared.sd_model is not None and
|
||||
hasattr(shared.sd_model, 'forge_objects') and
|
||||
hasattr(shared.sd_model.forge_objects, 'vae') and
|
||||
shared.sd_model.forge_objects.vae is None):
|
||||
|
||||
# ChromaDCT models are more memory efficient - reduce inference memory requirement
|
||||
# ChromaDCT works in pixel space (3 channels) vs latent space (16 channels)
|
||||
# and has more efficient patch-based processing
|
||||
chromadct_inference_memory = int(current_inference_memory * 1) # reset
|
||||
|
||||
from backend import chromadct_memory_strategy
|
||||
if chromadct_memory_strategy.is_chromadct_model(None):
|
||||
multiplier = chromadct_memory_strategy.get_chromadct_inference_memory_multiplier()
|
||||
chromadct_inference_memory = int(current_inference_memory * multiplier)
|
||||
|
||||
# Only print message once per session
|
||||
if not hasattr(minimum_inference_memory, '_chromadct_message_shown'):
|
||||
print(f"ChromaDCT detected - reducing inference memory from {current_inference_memory / (1024**2):.0f} MB to {chromadct_inference_memory / (1024**2):.0f} MB")
|
||||
print(f"ChromaDCT detected - optimizing inference memory from {current_inference_memory / (1024**2):.0f} MB to {chromadct_inference_memory / (1024**2):.0f} MB")
|
||||
minimum_inference_memory._chromadct_message_shown = True
|
||||
|
||||
|
||||
return chromadct_inference_memory
|
||||
except:
|
||||
except Exception as e:
|
||||
# Fallback to normal memory if detection fails
|
||||
pass
|
||||
|
||||
|
||||
return current_inference_memory
|
||||
|
||||
|
||||
@ -800,6 +795,14 @@ def free_memory(memory_required, device, keep_loaded=[], free_all=False):
|
||||
def compute_model_gpu_memory_when_using_cpu_swap(current_free_mem, inference_memory):
|
||||
maximum_memory_available = current_free_mem - inference_memory
|
||||
|
||||
# When using async swap with user-specified GPU weights, respect that setting
|
||||
from modules_forge import main_entry
|
||||
if hasattr(main_entry, 'user_specified_model_memory') and main_entry.user_specified_model_memory is not None:
|
||||
# User has explicitly set GPU weights - use that value
|
||||
user_memory = main_entry.user_specified_model_memory * 1024 * 1024 # Convert MB to bytes
|
||||
if user_memory <= maximum_memory_available:
|
||||
return int(user_memory)
|
||||
|
||||
suggestion = max(
|
||||
maximum_memory_available / 1.3,
|
||||
maximum_memory_available - 1024 * 1024 * 1024 * 1.25
|
||||
@ -847,11 +850,21 @@ def load_models_gpu(models, memory_required=0, hard_memory_preservation=0):
|
||||
loaded_model.compute_inclusive_exclusive_memory()
|
||||
total_memory_required[loaded_model.device] = total_memory_required.get(loaded_model.device, 0) + loaded_model.exclusive_memory + loaded_model.inclusive_memory * 0.25
|
||||
|
||||
# When using async swap, calculate total GPU budget to properly allocate between models
|
||||
total_models_memory = sum(loaded_model.exclusive_memory for loaded_model in models_to_load)
|
||||
gpu_memory_budget = None
|
||||
if stream.should_use_stream() and len(models_to_load) > 1:
|
||||
# Get the user-specified GPU memory budget
|
||||
from modules_forge import main_entry
|
||||
if hasattr(main_entry, 'user_specified_model_memory') and main_entry.user_specified_model_memory is not None:
|
||||
gpu_memory_budget = main_entry.user_specified_model_memory * 1024 * 1024 # Convert MB to bytes
|
||||
print(f"[Async Swap] Using user-specified GPU budget: {gpu_memory_budget / (1024 * 1024):.2f} MB for {len(models_to_load)} models")
|
||||
|
||||
for device in total_memory_required:
|
||||
if device != torch.device("cpu"):
|
||||
free_memory(total_memory_required[device] * 1.3 + memory_to_free, device, models_already_loaded)
|
||||
|
||||
for loaded_model in models_to_load:
|
||||
for idx, loaded_model in enumerate(models_to_load):
|
||||
model = loaded_model.model
|
||||
torch_dev = model.load_device
|
||||
if is_device_cpu(torch_dev):
|
||||
@ -869,11 +882,22 @@ def load_models_gpu(models, memory_required=0, hard_memory_preservation=0):
|
||||
|
||||
print(f"[Memory Management] Target: {loaded_model.model.model.__class__.__name__}, Free GPU: {current_free_mem / (1024 * 1024):.2f} MB, Model Require: {model_require / (1024 * 1024):.2f} MB, Previously Loaded: {previously_loaded / (1024 * 1024):.2f} MB, Inference Require: {memory_for_inference / (1024 * 1024):.2f} MB, Remaining: {estimated_remaining_memory / (1024 * 1024):.2f} MB, ", end="")
|
||||
|
||||
# Use same memory allocation logic for both async and queue methods
|
||||
if estimated_remaining_memory < 0:
|
||||
vram_set_state = VRAMState.LOW_VRAM
|
||||
model_gpu_memory_when_using_cpu_swap = compute_model_gpu_memory_when_using_cpu_swap(current_free_mem, memory_for_inference)
|
||||
if previously_loaded > 0:
|
||||
model_gpu_memory_when_using_cpu_swap = previously_loaded
|
||||
if gpu_memory_budget is not None and total_models_memory > 0:
|
||||
# Allocate GPU memory proportionally based on model size
|
||||
model_proportion = loaded_model.exclusive_memory / total_models_memory
|
||||
model_gpu_memory_when_using_cpu_swap = int(gpu_memory_budget * model_proportion)
|
||||
# Ensure we leave some memory for the last model in case of rounding
|
||||
if idx == len(models_to_load) - 1:
|
||||
already_allocated = sum(getattr(m, 'allocated_gpu_memory', 0) for m in models_to_load[:idx])
|
||||
model_gpu_memory_when_using_cpu_swap = int(gpu_memory_budget - already_allocated)
|
||||
loaded_model.allocated_gpu_memory = model_gpu_memory_when_using_cpu_swap
|
||||
else:
|
||||
model_gpu_memory_when_using_cpu_swap = compute_model_gpu_memory_when_using_cpu_swap(current_free_mem, memory_for_inference)
|
||||
if previously_loaded > 0:
|
||||
model_gpu_memory_when_using_cpu_swap = previously_loaded
|
||||
|
||||
if vram_set_state == VRAMState.NO_VRAM:
|
||||
model_gpu_memory_when_using_cpu_swap = 0
|
||||
|
||||
@ -8,9 +8,6 @@ from backend import stream, memory_management, utils
|
||||
from backend.patcher.lora import merge_lora_to_weight
|
||||
|
||||
|
||||
stash = {}
|
||||
|
||||
|
||||
def get_weight_and_bias(layer, weight_args=None, bias_args=None, weight_fn=None, bias_fn=None):
|
||||
scale_weight = getattr(layer, 'scale_weight', None)
|
||||
patches = getattr(layer, 'forge_online_loras', None)
|
||||
@ -56,7 +53,7 @@ def get_weight_and_bias(layer, weight_args=None, bias_args=None, weight_fn=None,
|
||||
|
||||
def weights_manual_cast(layer, x, skip_weight_dtype=False, skip_bias_dtype=False, weight_fn=None, bias_fn=None):
|
||||
weight, bias, signal = None, None, None
|
||||
non_blocking = True
|
||||
non_blocking = stream.should_use_stream() # Use async transfer if stream is enabled
|
||||
|
||||
if getattr(x.device, 'type', None) == 'mps':
|
||||
non_blocking = False
|
||||
@ -74,45 +71,21 @@ def weights_manual_cast(layer, x, skip_weight_dtype=False, skip_bias_dtype=False
|
||||
else:
|
||||
bias_args = dict(device=target_device, dtype=target_dtype, non_blocking=non_blocking)
|
||||
|
||||
if stream.should_use_stream():
|
||||
with stream.stream_context()(stream.mover_stream):
|
||||
weight, bias = get_weight_and_bias(layer, weight_args, bias_args, weight_fn=weight_fn, bias_fn=bias_fn)
|
||||
signal = stream.mover_stream.record_event()
|
||||
else:
|
||||
weight, bias = get_weight_and_bias(layer, weight_args, bias_args, weight_fn=weight_fn, bias_fn=bias_fn)
|
||||
# Simply get weights with async flag, no special stream handling
|
||||
weight, bias = get_weight_and_bias(layer, weight_args, bias_args, weight_fn=weight_fn, bias_fn=bias_fn)
|
||||
|
||||
return weight, bias, signal
|
||||
|
||||
|
||||
@contextlib.contextmanager
|
||||
def main_stream_worker(weight, bias, signal):
|
||||
if signal is None or not stream.should_use_stream():
|
||||
yield
|
||||
return
|
||||
|
||||
with stream.stream_context()(stream.current_stream):
|
||||
stream.current_stream.wait_event(signal)
|
||||
yield
|
||||
finished_signal = stream.current_stream.record_event()
|
||||
stash[id(finished_signal)] = (weight, bias, finished_signal)
|
||||
|
||||
garbage = []
|
||||
for k, (w, b, s) in stash.items():
|
||||
if s.query():
|
||||
garbage.append(k)
|
||||
|
||||
for k in garbage:
|
||||
del stash[k]
|
||||
# Simple pass-through - async transfers are handled by non_blocking flag
|
||||
yield
|
||||
return
|
||||
|
||||
|
||||
def cleanup_cache():
|
||||
if not stream.should_use_stream():
|
||||
return
|
||||
|
||||
stream.current_stream.synchronize()
|
||||
stream.mover_stream.synchronize()
|
||||
stash.clear()
|
||||
# No longer needed since we're not using stash
|
||||
return
|
||||
|
||||
|
||||
|
||||
@ -188,7 +188,9 @@ def calc_cond_uncond_batch(model, cond, uncond, x_in, timestep, model_options):
|
||||
to_batch = to_batch_temp[:1]
|
||||
|
||||
if memory_management.signal_empty_cache:
|
||||
memory_management.soft_empty_cache()
|
||||
# Don't empty cache during async swap to maintain model weights on GPU
|
||||
if not memory_management.stream.should_use_stream():
|
||||
memory_management.soft_empty_cache()
|
||||
|
||||
free_memory = memory_management.get_free_memory(x_in.device)
|
||||
|
||||
|
||||
@ -10,6 +10,7 @@ from modules.shared import cmd_opts
|
||||
|
||||
|
||||
total_vram = int(memory_management.total_vram)
|
||||
user_specified_model_memory = None # Track user-specified GPU weights
|
||||
|
||||
ui_forge_preset: gr.Radio = None
|
||||
|
||||
@ -173,6 +174,7 @@ def ui_refresh_memory_management_settings(model_memory, async_loading, pin_share
|
||||
)
|
||||
|
||||
def refresh_memory_management_settings(async_loading=None, inference_memory=None, pin_shared_memory=None, model_memory=None):
|
||||
global user_specified_model_memory
|
||||
# Fallback to defaults if values are not passed
|
||||
async_loading = async_loading if async_loading is not None else shared.opts.forge_async_loading
|
||||
inference_memory = inference_memory if inference_memory is not None else shared.opts.forge_inference_memory
|
||||
@ -181,8 +183,10 @@ def refresh_memory_management_settings(async_loading=None, inference_memory=None
|
||||
# If model_memory is provided, calculate inference memory accordingly, otherwise use inference_memory directly
|
||||
if model_memory is None:
|
||||
model_memory = total_vram - inference_memory
|
||||
user_specified_model_memory = None # Using inference memory slider
|
||||
else:
|
||||
inference_memory = total_vram - model_memory
|
||||
user_specified_model_memory = model_memory # User explicitly set GPU weights
|
||||
|
||||
shared.opts.set('forge_async_loading', async_loading)
|
||||
shared.opts.set('forge_inference_memory', inference_memory)
|
||||
|
||||
Loading…
Reference in New Issue
Block a user