mirror of
https://github.com/lllyasviel/stable-diffusion-webui-forge.git
synced 2026-06-04 21:05:48 +08:00
add 30b moe support
This commit is contained in:
parent
ecd8d4e66d
commit
8f66d5f6cb
@ -258,13 +258,14 @@ def expand_prompt_with_llm(prompt, image=None, llm_model=None, system_prompt=Non
|
||||
_expansion_model_cache = {
|
||||
'model': None,
|
||||
'processor': None,
|
||||
'model_path': None
|
||||
'model_path': None,
|
||||
'model_type': None # 'vl' for standard VL, 'vl_moe' for VL+MoE
|
||||
}
|
||||
|
||||
|
||||
def expand_prompt_standalone(prompt: str, model_path: str, system_prompt: str = None, image=None, is_negative: bool = False, positive_prompt: str = None):
|
||||
"""
|
||||
Standalone prompt expansion using Qwen3-VL model.
|
||||
Standalone prompt expansion using Qwen3-VL models (standard VL or VL+MoE).
|
||||
|
||||
Args:
|
||||
prompt: The user's input prompt to expand
|
||||
@ -280,6 +281,7 @@ def expand_prompt_standalone(prompt: str, model_path: str, system_prompt: str =
|
||||
import torch
|
||||
import gc
|
||||
import time
|
||||
import json
|
||||
from modules.shared import opts
|
||||
|
||||
global _expansion_model_cache
|
||||
@ -307,6 +309,28 @@ def expand_prompt_standalone(prompt: str, model_path: str, system_prompt: str =
|
||||
pass
|
||||
return None
|
||||
|
||||
def detect_model_type(model_path):
    """Detect whether the model at *model_path* is standard VL or VL+MoE.

    Reads ``config.json`` from the model directory and inspects its
    ``architectures`` and ``model_type`` entries.

    Args:
        model_path: Directory expected to contain the model's ``config.json``.

    Returns:
        ``"vl_moe"`` when any architecture name contains ``"Moe"``/``"MoE"``
        (e.g. ``Qwen3VLMoeForConditionalGeneration``) or when ``model_type``
        contains ``"moe"`` in any letter case; ``"vl"`` otherwise, including
        when the config file is missing, unreadable, or not a JSON object.
    """
    config_path = os.path.join(model_path, "config.json")
    if not os.path.exists(config_path):
        return "vl"

    # Only I/O and parsing belong in the try body: a malformed *field* must
    # not abort detection before the model_type fallback has been consulted.
    try:
        with open(config_path, 'r', encoding='utf-8') as f:
            config = json.load(f)
    except (OSError, ValueError) as e:
        # json.JSONDecodeError and UnicodeDecodeError are both ValueError.
        log_step(f" Warning: Could not read config.json: {e}")
        return "vl"

    if not isinstance(config, dict):
        log_step(" Warning: Could not read config.json: top-level value is not an object")
        return "vl"

    # Normalise "architectures": tolerate null, a bare string, or junk types.
    architectures = config.get("architectures") or []
    if isinstance(architectures, str):
        architectures = [architectures]
    elif not isinstance(architectures, (list, tuple)):
        architectures = []

    # Check for VL+MoE models (e.g., Qwen3VLMoeForConditionalGeneration).
    # The match is deliberately case-sensitive: a lowercase comparison would
    # also hit unrelated CamelCase names such as "DemoEncoder".
    if any(
        isinstance(arch, str) and ("Moe" in arch or "MoE" in arch)
        for arch in architectures
    ):
        return "vl_moe"

    model_type = config.get("model_type")
    if isinstance(model_type, str) and "moe" in model_type.lower():
        return "vl_moe"

    # Default to standard VL
    return "vl"
|
||||
|
||||
prompt_type = "NEGATIVE" if is_negative else "POSITIVE"
|
||||
print("\n" + "="*70, flush=True)
|
||||
log_step(f"{prompt_type} PROMPT EXPANSION PIPELINE STARTED")
|
||||
@ -330,6 +354,7 @@ def expand_prompt_standalone(prompt: str, model_path: str, system_prompt: str =
|
||||
del _expansion_model_cache['processor']
|
||||
_expansion_model_cache['model'] = None
|
||||
_expansion_model_cache['processor'] = None
|
||||
_expansion_model_cache['model_type'] = None
|
||||
gc.collect()
|
||||
torch.cuda.empty_cache()
|
||||
log_step("Previous model unloaded", unload_start)
|
||||
@ -337,8 +362,12 @@ def expand_prompt_standalone(prompt: str, model_path: str, system_prompt: str =
|
||||
log_step(f"Loading LLM model: {model_path}")
|
||||
load_start = time.time()
|
||||
|
||||
# Detect model type
|
||||
detected_type = detect_model_type(model_path)
|
||||
log_step(f" Detected model type: {detected_type}")
|
||||
|
||||
try:
|
||||
from transformers import Qwen3VLForConditionalGeneration, AutoProcessor
|
||||
from transformers import AutoProcessor
|
||||
|
||||
log_step(" Loading processor...")
|
||||
processor_start = time.time()
|
||||
@ -354,12 +383,48 @@ def expand_prompt_standalone(prompt: str, model_path: str, system_prompt: str =
|
||||
device_index = main_device.index if hasattr(main_device, 'index') and main_device.index is not None else 0
|
||||
log_step(f" Target device: cuda:{device_index}")
|
||||
|
||||
_expansion_model_cache['model'] = Qwen3VLForConditionalGeneration.from_pretrained(
|
||||
model_path,
|
||||
torch_dtype=torch.bfloat16,
|
||||
device_map={"": f"cuda:{device_index}"},
|
||||
)
|
||||
# Calculate available VRAM for LLM (leave some headroom for diffusion model)
|
||||
total_vram = torch.cuda.get_device_properties(device_index).total_memory / 1024**3
|
||||
# Use 90% of total VRAM, let accelerate handle the split with CPU
|
||||
max_gpu_memory = f"{int(total_vram * 0.9)}GiB"
|
||||
max_memory = {device_index: max_gpu_memory, "cpu": "32GiB"}
|
||||
log_step(f" GPU {device_index} has {total_vram:.1f}GB, allowing up to {max_gpu_memory}")
|
||||
|
||||
if detected_type == "vl_moe":
|
||||
# Try loading VL+MoE model
|
||||
try:
|
||||
from transformers import Qwen3VLMoeForConditionalGeneration
|
||||
log_step(" Using Qwen3VLMoeForConditionalGeneration")
|
||||
_expansion_model_cache['model'] = Qwen3VLMoeForConditionalGeneration.from_pretrained(
|
||||
model_path,
|
||||
torch_dtype=torch.bfloat16,
|
||||
device_map="auto",
|
||||
max_memory=max_memory,
|
||||
)
|
||||
except ImportError:
|
||||
# Fall back to AutoModelForVision2Seq if specific class not available
|
||||
log_step(" Qwen3VLMoeForConditionalGeneration not available, using AutoModelForVision2Seq")
|
||||
from transformers import AutoModelForVision2Seq
|
||||
_expansion_model_cache['model'] = AutoModelForVision2Seq.from_pretrained(
|
||||
model_path,
|
||||
torch_dtype=torch.bfloat16,
|
||||
device_map="auto",
|
||||
max_memory=max_memory,
|
||||
trust_remote_code=True,
|
||||
)
|
||||
else:
|
||||
# Standard VL model
|
||||
from transformers import Qwen3VLForConditionalGeneration
|
||||
log_step(" Using Qwen3VLForConditionalGeneration")
|
||||
_expansion_model_cache['model'] = Qwen3VLForConditionalGeneration.from_pretrained(
|
||||
model_path,
|
||||
torch_dtype=torch.bfloat16,
|
||||
device_map="auto",
|
||||
max_memory=max_memory,
|
||||
)
|
||||
|
||||
_expansion_model_cache['model_path'] = model_path
|
||||
_expansion_model_cache['model_type'] = detected_type
|
||||
log_step(" Model weights loaded", model_start)
|
||||
|
||||
gpu_mem = get_gpu_memory()
|
||||
@ -511,16 +576,32 @@ def expand_prompt_standalone(prompt: str, model_path: str, system_prompt: str =
|
||||
if gpu_before:
|
||||
log_step(f" Before unload: {gpu_before}")
|
||||
|
||||
# Properly unload Hugging Face model with device_map
|
||||
if _expansion_model_cache['model'] is not None:
|
||||
try:
|
||||
# Move model to CPU first to release GPU memory
|
||||
_expansion_model_cache['model'].to('cpu')
|
||||
except:
|
||||
pass
|
||||
# Clear any internal hooks from accelerate
|
||||
try:
|
||||
from accelerate.hooks import remove_hook_from_submodules
|
||||
remove_hook_from_submodules(_expansion_model_cache['model'])
|
||||
except:
|
||||
pass
|
||||
del _expansion_model_cache['model']
|
||||
_expansion_model_cache['model'] = None
|
||||
|
||||
if _expansion_model_cache['processor'] is not None:
|
||||
del _expansion_model_cache['processor']
|
||||
_expansion_model_cache['processor'] = None
|
||||
_expansion_model_cache['model_path'] = None
|
||||
|
||||
# Force garbage collection multiple times for thorough cleanup
|
||||
gc.collect()
|
||||
gc.collect()
|
||||
torch.cuda.empty_cache()
|
||||
torch.cuda.synchronize()
|
||||
|
||||
gpu_after = get_gpu_memory()
|
||||
if gpu_after:
|
||||
|
||||
Loading…
Reference in New Issue
Block a user