diff --git a/vlm.py b/vlm.py index 232c1c30..86f73dcc 100644 --- a/vlm.py +++ b/vlm.py @@ -123,6 +123,11 @@ class VLMManager: with open(config_path, "r", encoding="utf-8") as f: config = json.load(f) model_type = config.get("model_type", "").lower() + # Check for MoE models first + if "moe" in model_type: + if "qwen3" in model_type: + return "qwen3_vl_moe" + # Then standard models if "qwen3" in model_type: return "qwen3_vl" elif "qwen2_5" in model_type or "qwen2.5" in model_type: @@ -185,7 +190,9 @@ class VLMManager: progress(0.5, desc=f"Loading model weights ({model_type})...") # Select the correct model class based on detected type - if model_type == "qwen3_vl": + if model_type == "qwen3_vl_moe": + from transformers import Qwen3VLMoeForConditionalGeneration as ModelClass + elif model_type == "qwen3_vl": from transformers import Qwen3VLForConditionalGeneration as ModelClass else: from transformers import Qwen2_5_VLForConditionalGeneration as ModelClass