diff --git a/modules/ui.py b/modules/ui.py index 8218eadc..860b8cd9 100644 --- a/modules/ui.py +++ b/modules/ui.py @@ -258,13 +258,14 @@ def expand_prompt_with_llm(prompt, image=None, llm_model=None, system_prompt=Non _expansion_model_cache = { 'model': None, 'processor': None, - 'model_path': None + 'model_path': None, + 'model_type': None # 'vl' for standard VL, 'vl_moe' for VL+MoE } def expand_prompt_standalone(prompt: str, model_path: str, system_prompt: str = None, image=None, is_negative: bool = False, positive_prompt: str = None): """ - Standalone prompt expansion using Qwen3-VL model. + Standalone prompt expansion using Qwen3-VL models (standard VL or VL+MoE). Args: prompt: The user's input prompt to expand @@ -280,6 +281,7 @@ def expand_prompt_standalone(prompt: str, model_path: str, system_prompt: str = import torch import gc import time + import json from modules.shared import opts global _expansion_model_cache @@ -307,6 +309,28 @@ def expand_prompt_standalone(prompt: str, model_path: str, system_prompt: str = pass return None + def detect_model_type(model_path): + """Detect whether model is standard VL or VL+MoE from config.""" + config_path = os.path.join(model_path, "config.json") + if os.path.exists(config_path): + try: + with open(config_path, 'r', encoding='utf-8') as f: + config = json.load(f) + architectures = config.get("architectures", []) + model_type = config.get("model_type", "") + + # Check for VL+MoE models (e.g., Qwen3VLMoeForConditionalGeneration) + if any("Moe" in arch or "MoE" in arch for arch in architectures): + return "vl_moe" + if "moe" in model_type.lower(): + return "vl_moe" + + # Default to standard VL + return "vl" + except Exception as e: + log_step(f" Warning: Could not read config.json: {e}") + return "vl" + prompt_type = "NEGATIVE" if is_negative else "POSITIVE" print("\n" + "="*70, flush=True) log_step(f"{prompt_type} PROMPT EXPANSION PIPELINE STARTED") @@ -330,6 +354,7 @@ def expand_prompt_standalone(prompt: str, model_path: str, system_prompt: str = del _expansion_model_cache['processor'] _expansion_model_cache['model'] = None _expansion_model_cache['processor'] = None + _expansion_model_cache['model_type'] = None gc.collect() torch.cuda.empty_cache() log_step("Previous model unloaded", unload_start) @@ -337,8 +362,12 @@ def expand_prompt_standalone(prompt: str, model_path: str, system_prompt: str = log_step(f"Loading LLM model: {model_path}") load_start = time.time() + # Detect model type + detected_type = detect_model_type(model_path) + log_step(f" Detected model type: {detected_type}") + try: - from transformers import Qwen3VLForConditionalGeneration, AutoProcessor + from transformers import AutoProcessor log_step(" Loading processor...") processor_start = time.time() @@ -354,12 +383,48 @@ def expand_prompt_standalone(prompt: str, model_path: str, system_prompt: str = device_index = main_device.index if hasattr(main_device, 'index') and main_device.index is not None else 0 log_step(f" Target device: cuda:{device_index}") - _expansion_model_cache['model'] = Qwen3VLForConditionalGeneration.from_pretrained( - model_path, - torch_dtype=torch.bfloat16, - device_map={"": f"cuda:{device_index}"}, - ) + # Calculate available VRAM for LLM (leave some headroom for diffusion model) + total_vram = torch.cuda.get_device_properties(device_index).total_memory / 1024**3 + # Use 90% of total VRAM, let accelerate handle the split with CPU + max_gpu_memory = f"{int(total_vram * 0.9)}GiB" + max_memory = {device_index: max_gpu_memory, "cpu": "32GiB"} + log_step(f" GPU {device_index} has {total_vram:.1f}GB, allowing up to {max_gpu_memory}") + + if detected_type == "vl_moe": + # Try loading VL+MoE model + try: + from transformers import Qwen3VLMoeForConditionalGeneration + log_step(" Using Qwen3VLMoeForConditionalGeneration") + _expansion_model_cache['model'] = Qwen3VLMoeForConditionalGeneration.from_pretrained( + model_path, + torch_dtype=torch.bfloat16, + device_map="auto", + max_memory=max_memory, + ) + except ImportError: + # Fall back to AutoModelForVision2Seq if specific class not available + log_step(" Qwen3VLMoeForConditionalGeneration not available, using AutoModelForVision2Seq") + from transformers import AutoModelForVision2Seq + _expansion_model_cache['model'] = AutoModelForVision2Seq.from_pretrained( + model_path, + torch_dtype=torch.bfloat16, + device_map="auto", + max_memory=max_memory, + trust_remote_code=True, + ) + else: + # Standard VL model + from transformers import Qwen3VLForConditionalGeneration + log_step(" Using Qwen3VLForConditionalGeneration") + _expansion_model_cache['model'] = Qwen3VLForConditionalGeneration.from_pretrained( + model_path, + torch_dtype=torch.bfloat16, + device_map="auto", + max_memory=max_memory, + ) + _expansion_model_cache['model_path'] = model_path + _expansion_model_cache['model_type'] = detected_type log_step(" Model weights loaded", model_start) gpu_mem = get_gpu_memory() @@ -511,16 +576,32 @@ def expand_prompt_standalone(prompt: str, model_path: str, system_prompt: str = if gpu_before: log_step(f" Before unload: {gpu_before}") + # Properly unload Hugging Face model with device_map if _expansion_model_cache['model'] is not None: + try: + # Move model to CPU first to release GPU memory + _expansion_model_cache['model'].to('cpu') + except: + pass + # Clear any internal hooks from accelerate + try: + from accelerate.hooks import remove_hook_from_submodules + remove_hook_from_submodules(_expansion_model_cache['model']) + except: + pass del _expansion_model_cache['model'] _expansion_model_cache['model'] = None + if _expansion_model_cache['processor'] is not None: del _expansion_model_cache['processor'] _expansion_model_cache['processor'] = None _expansion_model_cache['model_path'] = None + # Force garbage collection multiple times for thorough cleanup + gc.collect() gc.collect() torch.cuda.empty_cache() + torch.cuda.synchronize() gpu_after = get_gpu_memory() if gpu_after: