mirror of
https://github.com/lllyasviel/stable-diffusion-webui-forge.git
synced 2026-06-04 21:05:48 +08:00
update 30b req
This commit is contained in:
parent
aeafe08c14
commit
8c604355bf
39
vlm.py
39
vlm.py
@ -148,13 +148,14 @@ class VLMManager:
|
||||
|
||||
return "qwen2_5_vl" # Default fallback
|
||||
|
||||
def load_model(self, model_name: str, quantization: str = "none", use_flash_attn: bool = False, progress=gr.Progress()) -> str:
|
||||
def load_model(self, model_name: str, quantization: str = "none", use_flash_attn: bool = False, vram_buffer: int = 0, progress=gr.Progress()) -> str:
|
||||
"""Load a Qwen VL model.
|
||||
|
||||
Args:
|
||||
model_name: Name of the model to load
|
||||
quantization: "none", "4bit", or "8bit"
|
||||
use_flash_attn: Whether to use Flash Attention 2
|
||||
vram_buffer: GB of VRAM to reserve (for loading large models)
|
||||
progress: Gradio progress callback
|
||||
"""
|
||||
if model_name == "No models found":
|
||||
@ -231,13 +232,27 @@ class VLMManager:
|
||||
torch.cuda.empty_cache()
|
||||
torch.cuda.synchronize()
|
||||
|
||||
print("Using 8-bit quantization...")
|
||||
|
||||
load_kwargs["quantization_config"] = BitsAndBytesConfig(
|
||||
load_in_8bit=True,
|
||||
)
|
||||
# Load entirely to GPU
|
||||
load_kwargs["device_map"] = {"": 0}
|
||||
|
||||
# Apply VRAM buffer if specified
|
||||
if vram_buffer > 0 and torch.cuda.is_available():
|
||||
total_mem_gb = torch.cuda.get_device_properties(0).total_memory / (1024**3)
|
||||
max_gpu = int(total_mem_gb - vram_buffer)
|
||||
if max_gpu > 0:
|
||||
load_kwargs["device_map"] = "auto"
|
||||
load_kwargs["max_memory"] = {0: f"{max_gpu}GiB", "cpu": "100GiB"}
|
||||
offload_dir = Path(tempfile.gettempdir()) / "vlm_offload"
|
||||
offload_dir.mkdir(exist_ok=True)
|
||||
load_kwargs["offload_folder"] = str(offload_dir)
|
||||
print(f"Using 8-bit quantization (max GPU: {max_gpu}GB, buffer: {vram_buffer}GB)")
|
||||
else:
|
||||
load_kwargs["device_map"] = "auto"
|
||||
print("Using 8-bit quantization...")
|
||||
else:
|
||||
load_kwargs["device_map"] = "auto"
|
||||
print("Using 8-bit quantization...")
|
||||
|
||||
except ImportError as e:
|
||||
print(f"Warning: bitsandbytes not installed, falling back to bfloat16")
|
||||
@ -422,11 +437,11 @@ def initialize_manager(low_vram: bool = False):
|
||||
vlm_manager = VLMManager(low_vram=low_vram)
|
||||
|
||||
|
||||
def load_model_handler(model_name: str, quantization: str, use_flash_attn: bool, vram_buffer: int = 0, progress=gr.Progress()):
    """Handle model loading from UI.

    Args:
        model_name: Name of the model to load.
        quantization: "none", "4bit", or "8bit"; forwarded unchanged.
        use_flash_attn: Whether to use Flash Attention 2.
        vram_buffer: GB of VRAM to reserve while loading. Defaults to 0
            (no reservation) so callers using the older three-argument
            form keep working.
        progress: Gradio progress callback.

    Returns:
        The status string produced by ``vlm_manager.load_model``, or
        "Manager not initialized" when no manager has been created yet.
    """
    if vlm_manager is None:
        return "Manager not initialized"

    # The UI value is coerced to int before forwarding; a missing value
    # (None) is treated as "no buffer" instead of raising TypeError.
    buffer_gb = int(vram_buffer or 0)

    # Pass the trailing arguments by keyword: `progress` used to be the
    # fourth positional parameter of load_model, so a positional call would
    # silently bind it to `vram_buffer` if the two signatures drift apart.
    return vlm_manager.load_model(
        model_name,
        quantization,
        use_flash_attn,
        vram_buffer=buffer_gb,
        progress=progress,
    )
|
||||
|
||||
|
||||
def unload_model_handler():
|
||||
@ -782,6 +797,14 @@ def create_ui():
|
||||
)
|
||||
|
||||
gr.Markdown("### Memory Settings")
|
||||
vram_buffer = gr.Slider(
|
||||
minimum=0,
|
||||
maximum=32,
|
||||
value=0,
|
||||
step=1,
|
||||
label="VRAM Buffer (GB)",
|
||||
info="Reserve GPU memory during loading. Useful for large models with 8-bit quantization.",
|
||||
)
|
||||
auto_unload = gr.Checkbox(
|
||||
label="Auto-unload after generation",
|
||||
value=False,
|
||||
@ -870,7 +893,7 @@ def create_ui():
|
||||
|
||||
load_btn.click(
|
||||
fn=load_model_handler,
|
||||
inputs=[model_dropdown, quantization_dropdown, use_flash_attn],
|
||||
inputs=[model_dropdown, quantization_dropdown, use_flash_attn, vram_buffer],
|
||||
outputs=[model_status],
|
||||
)
|
||||
|
||||
|
||||
Loading…
Reference in New Issue
Block a user