add 30b moe support

This commit is contained in:
maybleMyers 2025-12-05 04:58:52 -08:00
parent ecd8d4e66d
commit 8f66d5f6cb

View File

@ -258,13 +258,14 @@ def expand_prompt_with_llm(prompt, image=None, llm_model=None, system_prompt=Non
_expansion_model_cache = {
'model': None,
'processor': None,
'model_path': None
'model_path': None,
'model_type': None # 'vl' for standard VL, 'vl_moe' for VL+MoE
}
def expand_prompt_standalone(prompt: str, model_path: str, system_prompt: str = None, image=None, is_negative: bool = False, positive_prompt: str = None):
"""
Standalone prompt expansion using Qwen3-VL model.
Standalone prompt expansion using Qwen3-VL models (standard VL or VL+MoE).
Args:
prompt: The user's input prompt to expand
@ -280,6 +281,7 @@ def expand_prompt_standalone(prompt: str, model_path: str, system_prompt: str =
import torch
import gc
import time
import json
from modules.shared import opts
global _expansion_model_cache
@ -307,6 +309,28 @@ def expand_prompt_standalone(prompt: str, model_path: str, system_prompt: str =
pass
return None
def detect_model_type(model_path):
"""Detect whether model is standard VL or VL+MoE from config."""
config_path = os.path.join(model_path, "config.json")
if os.path.exists(config_path):
try:
with open(config_path, 'r', encoding='utf-8') as f:
config = json.load(f)
architectures = config.get("architectures", [])
model_type = config.get("model_type", "")
# Check for VL+MoE models (e.g., Qwen3VLMoeForConditionalGeneration)
if any("Moe" in arch or "MoE" in arch for arch in architectures):
return "vl_moe"
if "moe" in model_type.lower():
return "vl_moe"
# Default to standard VL
return "vl"
except Exception as e:
log_step(f" Warning: Could not read config.json: {e}")
return "vl"
prompt_type = "NEGATIVE" if is_negative else "POSITIVE"
print("\n" + "="*70, flush=True)
log_step(f"{prompt_type} PROMPT EXPANSION PIPELINE STARTED")
@ -330,6 +354,7 @@ def expand_prompt_standalone(prompt: str, model_path: str, system_prompt: str =
del _expansion_model_cache['processor']
_expansion_model_cache['model'] = None
_expansion_model_cache['processor'] = None
_expansion_model_cache['model_type'] = None
gc.collect()
torch.cuda.empty_cache()
log_step("Previous model unloaded", unload_start)
@ -337,8 +362,12 @@ def expand_prompt_standalone(prompt: str, model_path: str, system_prompt: str =
log_step(f"Loading LLM model: {model_path}")
load_start = time.time()
# Detect model type
detected_type = detect_model_type(model_path)
log_step(f" Detected model type: {detected_type}")
try:
from transformers import Qwen3VLForConditionalGeneration, AutoProcessor
from transformers import AutoProcessor
log_step(" Loading processor...")
processor_start = time.time()
@ -354,12 +383,48 @@ def expand_prompt_standalone(prompt: str, model_path: str, system_prompt: str =
device_index = main_device.index if hasattr(main_device, 'index') and main_device.index is not None else 0
log_step(f" Target device: cuda:{device_index}")
_expansion_model_cache['model'] = Qwen3VLForConditionalGeneration.from_pretrained(
model_path,
torch_dtype=torch.bfloat16,
device_map={"": f"cuda:{device_index}"},
)
# Calculate available VRAM for LLM (leave some headroom for diffusion model)
total_vram = torch.cuda.get_device_properties(device_index).total_memory / 1024**3
# Use 90% of total VRAM, let accelerate handle the split with CPU
max_gpu_memory = f"{int(total_vram * 0.9)}GiB"
max_memory = {device_index: max_gpu_memory, "cpu": "32GiB"}
log_step(f" GPU {device_index} has {total_vram:.1f}GB, allowing up to {max_gpu_memory}")
if detected_type == "vl_moe":
# Try loading VL+MoE model
try:
from transformers import Qwen3VLMoeForConditionalGeneration
log_step(" Using Qwen3VLMoeForConditionalGeneration")
_expansion_model_cache['model'] = Qwen3VLMoeForConditionalGeneration.from_pretrained(
model_path,
torch_dtype=torch.bfloat16,
device_map="auto",
max_memory=max_memory,
)
except ImportError:
# Fall back to AutoModelForVision2Seq if specific class not available
log_step(" Qwen3VLMoeForConditionalGeneration not available, using AutoModelForVision2Seq")
from transformers import AutoModelForVision2Seq
_expansion_model_cache['model'] = AutoModelForVision2Seq.from_pretrained(
model_path,
torch_dtype=torch.bfloat16,
device_map="auto",
max_memory=max_memory,
trust_remote_code=True,
)
else:
# Standard VL model
from transformers import Qwen3VLForConditionalGeneration
log_step(" Using Qwen3VLForConditionalGeneration")
_expansion_model_cache['model'] = Qwen3VLForConditionalGeneration.from_pretrained(
model_path,
torch_dtype=torch.bfloat16,
device_map="auto",
max_memory=max_memory,
)
_expansion_model_cache['model_path'] = model_path
_expansion_model_cache['model_type'] = detected_type
log_step(" Model weights loaded", model_start)
gpu_mem = get_gpu_memory()
@ -511,16 +576,32 @@ def expand_prompt_standalone(prompt: str, model_path: str, system_prompt: str =
if gpu_before:
log_step(f" Before unload: {gpu_before}")
# Properly unload Hugging Face model with device_map
if _expansion_model_cache['model'] is not None:
try:
# Move model to CPU first to release GPU memory
_expansion_model_cache['model'].to('cpu')
except:
pass
# Clear any internal hooks from accelerate
try:
from accelerate.hooks import remove_hook_from_submodules
remove_hook_from_submodules(_expansion_model_cache['model'])
except:
pass
del _expansion_model_cache['model']
_expansion_model_cache['model'] = None
if _expansion_model_cache['processor'] is not None:
del _expansion_model_cache['processor']
_expansion_model_cache['processor'] = None
_expansion_model_cache['model_path'] = None
# Force garbage collection multiple times for thorough cleanup
gc.collect()
gc.collect()
torch.cuda.empty_cache()
torch.cuda.synchronize()
gpu_after = get_gpu_memory()
if gpu_after: