mirror of
https://github.com/lllyasviel/stable-diffusion-webui-forge.git
synced 2026-06-04 21:05:48 +08:00
prompt expansion
This commit is contained in:
parent
cecbcd8525
commit
63a337707a
@ -491,13 +491,14 @@ class ZImage(ForgeDiffusionEngine):
|
||||
_generation_processor = None
|
||||
|
||||
@torch.inference_mode()
|
||||
def expand_prompt(self, prompt: str, max_new_tokens: int = None, temperature: float = None) -> str:
|
||||
def expand_prompt(self, prompt: str, image=None, max_new_tokens: int = None, temperature: float = None) -> str:
|
||||
"""
|
||||
Expand a prompt using Qwen3-VL model for generation.
|
||||
Loads a separate pre-trained Qwen3-VL model for text generation.
|
||||
|
||||
Args:
|
||||
prompt: The user's input prompt to expand
|
||||
image: Optional PIL Image to use as context (from img2img)
|
||||
max_new_tokens: Maximum tokens to generate (uses settings if None)
|
||||
temperature: Generation temperature (uses settings if None)
|
||||
"""
|
||||
@ -553,10 +554,18 @@ User input prompt: '''
|
||||
# Format the expansion request as a chat message
|
||||
full_prompt = expansion_template + prompt
|
||||
|
||||
# Use chat format for Qwen3-VL (text-only, no images)
|
||||
messages = [
|
||||
{"role": "user", "content": [{"type": "text", "text": full_prompt}]}
|
||||
]
|
||||
# Build message content based on whether image is provided
|
||||
if image is not None:
|
||||
print("Using image context for prompt expansion...")
|
||||
# Include image in the message for vision-language understanding
|
||||
content = [
|
||||
{"type": "image", "image": image},
|
||||
{"type": "text", "text": full_prompt}
|
||||
]
|
||||
else:
|
||||
content = [{"type": "text", "text": full_prompt}]
|
||||
|
||||
messages = [{"role": "user", "content": content}]
|
||||
|
||||
# Apply chat template
|
||||
text_input = processor.apply_chat_template(
|
||||
@ -565,12 +574,20 @@ User input prompt: '''
|
||||
add_generation_prompt=True,
|
||||
)
|
||||
|
||||
# Process inputs (text only, no images)
|
||||
inputs = processor(
|
||||
text=[text_input],
|
||||
padding=True,
|
||||
return_tensors="pt",
|
||||
)
|
||||
# Process inputs (with or without image)
|
||||
if image is not None:
|
||||
inputs = processor(
|
||||
text=[text_input],
|
||||
images=[image],
|
||||
padding=True,
|
||||
return_tensors="pt",
|
||||
)
|
||||
else:
|
||||
inputs = processor(
|
||||
text=[text_input],
|
||||
padding=True,
|
||||
return_tensors="pt",
|
||||
)
|
||||
|
||||
# Move to device
|
||||
device = next(model.parameters()).device
|
||||
@ -594,17 +611,36 @@ User input prompt: '''
|
||||
raw_output = processor.decode(generated_ids, skip_special_tokens=True)
|
||||
|
||||
# Print full output to console (including thinking if present)
|
||||
print("\n" + "="*60)
|
||||
print("PROMPT EXPANSION OUTPUT:")
|
||||
print("="*60)
|
||||
print(raw_output)
|
||||
print("="*60 + "\n")
|
||||
print("\n" + "="*60, flush=True)
|
||||
print("PROMPT EXPANSION OUTPUT:", flush=True)
|
||||
print("="*60, flush=True)
|
||||
print(raw_output, flush=True)
|
||||
print("="*60, flush=True)
|
||||
|
||||
# Clean up the output - remove any thinking tags if present
|
||||
expanded_prompt = raw_output
|
||||
if "</think>" in expanded_prompt:
|
||||
expanded_prompt = expanded_prompt.split("</think>")[-1].strip()
|
||||
|
||||
print("\nCLEANED PROMPT:", flush=True)
|
||||
print("-"*60, flush=True)
|
||||
print(expanded_prompt, flush=True)
|
||||
print("="*60 + "\n", flush=True)
|
||||
|
||||
# Unload Qwen3-VL model to free VRAM for image generation
|
||||
print("Unloading Qwen3-VL model to free VRAM...", flush=True)
|
||||
if ZImage._generation_model is not None:
|
||||
del ZImage._generation_model
|
||||
ZImage._generation_model = None
|
||||
if ZImage._generation_processor is not None:
|
||||
del ZImage._generation_processor
|
||||
ZImage._generation_processor = None
|
||||
|
||||
import gc
|
||||
gc.collect()
|
||||
torch.cuda.empty_cache()
|
||||
print("Qwen3-VL model unloaded.", flush=True)
|
||||
|
||||
return expanded_prompt.strip()
|
||||
|
||||
@torch.inference_mode()
|
||||
|
||||
@ -153,7 +153,7 @@ def interrogate_deepbooru(image):
|
||||
return gr.update() if prompt is None else prompt
|
||||
|
||||
|
||||
def expand_prompt_with_llm(prompt):
|
||||
def expand_prompt_with_llm(prompt, image=None):
|
||||
"""Expand the prompt using the loaded model's LLM capabilities (Z-Image only)."""
|
||||
try:
|
||||
# Ensure model is loaded (forge loads models lazily)
|
||||
@ -161,33 +161,50 @@ def expand_prompt_with_llm(prompt):
|
||||
|
||||
if model is None or isinstance(model, sd_models.FakeInitialModel):
|
||||
gr.Warning("No model loaded. Please select a Z-Image model first.")
|
||||
return [prompt]
|
||||
return prompt
|
||||
|
||||
# Check if this is a Z-Image model with expand_prompt capability
|
||||
if not hasattr(model, 'expand_prompt'):
|
||||
model_type = type(model).__name__
|
||||
gr.Warning(f"Prompt expansion is only available for Z-Image models. Current model ({model_type}) does not support this feature.")
|
||||
return [prompt]
|
||||
return prompt
|
||||
|
||||
if not prompt or prompt.strip() == "":
|
||||
gr.Warning("Please enter a prompt to expand.")
|
||||
return [prompt]
|
||||
return prompt
|
||||
|
||||
gr.Info("Expanding prompt using Qwen3... This may take a moment.")
|
||||
expanded = model.expand_prompt(prompt.strip())
|
||||
# Extract PIL image if provided (from ForgeCanvas or gr.Image)
|
||||
pil_image = None
|
||||
if image is not None:
|
||||
if hasattr(image, 'convert'):
|
||||
# Already a PIL Image
|
||||
pil_image = image
|
||||
elif isinstance(image, dict) and 'image' in image:
|
||||
# ForgeCanvas format
|
||||
pil_image = image.get('image')
|
||||
elif isinstance(image, dict) and 'background' in image:
|
||||
# ForgeCanvas background format
|
||||
pil_image = image.get('background')
|
||||
|
||||
if pil_image is not None:
|
||||
gr.Info("Expanding prompt using Qwen3-VL with image context... This may take a moment.")
|
||||
else:
|
||||
gr.Info("Expanding prompt using Qwen3-VL... This may take a moment.")
|
||||
|
||||
expanded = model.expand_prompt(prompt.strip(), image=pil_image)
|
||||
|
||||
if expanded and expanded != prompt:
|
||||
gr.Info("Prompt expanded successfully!")
|
||||
return [expanded]
|
||||
return expanded
|
||||
else:
|
||||
gr.Warning("Prompt expansion returned empty result, keeping original.")
|
||||
return [prompt]
|
||||
return prompt
|
||||
|
||||
except Exception as e:
|
||||
import traceback
|
||||
traceback.print_exc()
|
||||
gr.Warning(f"Error during prompt expansion: {str(e)}")
|
||||
return [prompt]
|
||||
return prompt
|
||||
|
||||
|
||||
def connect_clear_prompt(button):
|
||||
@ -577,9 +594,9 @@ def create_ui():
|
||||
toprow.token_button.click(fn=wrap_queued_call(update_token_counter), inputs=[toprow.prompt, steps, toprow.ui_styles.dropdown], outputs=[toprow.token_counter])
|
||||
toprow.negative_token_button.click(fn=wrap_queued_call(update_negative_prompt_token_counter), inputs=[toprow.negative_prompt, steps, toprow.ui_styles.dropdown], outputs=[toprow.negative_token_counter])
|
||||
|
||||
# Connect expand prompt button for Z-Image models
|
||||
# Connect expand prompt button for Z-Image models (no performance wrapper - we don't want HTML in prompt)
|
||||
toprow.expand_prompt_button.click(
|
||||
fn=wrap_gradio_gpu_call(expand_prompt_with_llm),
|
||||
fn=expand_prompt_with_llm,
|
||||
inputs=[toprow.prompt],
|
||||
outputs=[toprow.prompt],
|
||||
show_progress=True,
|
||||
@ -932,10 +949,10 @@ def create_ui():
|
||||
toprow.token_button.click(fn=update_token_counter, inputs=[toprow.prompt, steps, toprow.ui_styles.dropdown], outputs=[toprow.token_counter])
|
||||
toprow.negative_token_button.click(fn=wrap_queued_call(update_negative_prompt_token_counter), inputs=[toprow.negative_prompt, steps, toprow.ui_styles.dropdown], outputs=[toprow.negative_token_counter])
|
||||
|
||||
# Connect expand prompt button for Z-Image models
|
||||
# Connect expand prompt button for Z-Image models (with image from img2img, no performance wrapper)
|
||||
toprow.expand_prompt_button.click(
|
||||
fn=wrap_gradio_gpu_call(expand_prompt_with_llm),
|
||||
inputs=[toprow.prompt],
|
||||
fn=expand_prompt_with_llm,
|
||||
inputs=[toprow.prompt, init_img.background],
|
||||
outputs=[toprow.prompt],
|
||||
show_progress=True,
|
||||
)
|
||||
|
||||
Loading…
Reference in New Issue
Block a user