prompt expansion

This commit is contained in:
maybleMyers 2025-12-04 16:47:20 -08:00
parent cecbcd8525
commit 63a337707a
2 changed files with 83 additions and 30 deletions

View File

@ -491,13 +491,14 @@ class ZImage(ForgeDiffusionEngine):
_generation_processor = None
@torch.inference_mode()
def expand_prompt(self, prompt: str, max_new_tokens: int = None, temperature: float = None) -> str:
def expand_prompt(self, prompt: str, image=None, max_new_tokens: int = None, temperature: float = None) -> str:
"""
Expand a prompt using Qwen3-VL model for generation.
Loads a separate pre-trained Qwen3-VL model for text generation.
Args:
prompt: The user's input prompt to expand
image: Optional PIL Image to use as context (from img2img)
max_new_tokens: Maximum tokens to generate (uses settings if None)
temperature: Generation temperature (uses settings if None)
"""
@ -553,10 +554,18 @@ User input prompt: '''
# Format the expansion request as a chat message
full_prompt = expansion_template + prompt
# Use chat format for Qwen3-VL (text-only, no images)
messages = [
{"role": "user", "content": [{"type": "text", "text": full_prompt}]}
]
# Build message content based on whether image is provided
if image is not None:
print("Using image context for prompt expansion...")
# Include image in the message for vision-language understanding
content = [
{"type": "image", "image": image},
{"type": "text", "text": full_prompt}
]
else:
content = [{"type": "text", "text": full_prompt}]
messages = [{"role": "user", "content": content}]
# Apply chat template
text_input = processor.apply_chat_template(
@ -565,12 +574,20 @@ User input prompt: '''
add_generation_prompt=True,
)
# Process inputs (text only, no images)
inputs = processor(
text=[text_input],
padding=True,
return_tensors="pt",
)
# Process inputs (with or without image)
if image is not None:
inputs = processor(
text=[text_input],
images=[image],
padding=True,
return_tensors="pt",
)
else:
inputs = processor(
text=[text_input],
padding=True,
return_tensors="pt",
)
# Move to device
device = next(model.parameters()).device
@ -594,17 +611,36 @@ User input prompt: '''
raw_output = processor.decode(generated_ids, skip_special_tokens=True)
# Print full output to console (including thinking if present)
print("\n" + "="*60)
print("PROMPT EXPANSION OUTPUT:")
print("="*60)
print(raw_output)
print("="*60 + "\n")
print("\n" + "="*60, flush=True)
print("PROMPT EXPANSION OUTPUT:", flush=True)
print("="*60, flush=True)
print(raw_output, flush=True)
print("="*60, flush=True)
# Clean up the output - remove any thinking tags if present
expanded_prompt = raw_output
if "</think>" in expanded_prompt:
expanded_prompt = expanded_prompt.split("</think>")[-1].strip()
print("\nCLEANED PROMPT:", flush=True)
print("-"*60, flush=True)
print(expanded_prompt, flush=True)
print("="*60 + "\n", flush=True)
# Unload Qwen3-VL model to free VRAM for image generation
print("Unloading Qwen3-VL model to free VRAM...", flush=True)
if ZImage._generation_model is not None:
del ZImage._generation_model
ZImage._generation_model = None
if ZImage._generation_processor is not None:
del ZImage._generation_processor
ZImage._generation_processor = None
import gc
gc.collect()
torch.cuda.empty_cache()
print("Qwen3-VL model unloaded.", flush=True)
return expanded_prompt.strip()
@torch.inference_mode()

View File

@ -153,7 +153,7 @@ def interrogate_deepbooru(image):
return gr.update() if prompt is None else prompt
def expand_prompt_with_llm(prompt):
def expand_prompt_with_llm(prompt, image=None):
"""Expand the prompt using the loaded model's LLM capabilities (Z-Image only)."""
try:
# Ensure model is loaded (forge loads models lazily)
@ -161,33 +161,50 @@ def expand_prompt_with_llm(prompt):
if model is None or isinstance(model, sd_models.FakeInitialModel):
gr.Warning("No model loaded. Please select a Z-Image model first.")
return [prompt]
return prompt
# Check if this is a Z-Image model with expand_prompt capability
if not hasattr(model, 'expand_prompt'):
model_type = type(model).__name__
gr.Warning(f"Prompt expansion is only available for Z-Image models. Current model ({model_type}) does not support this feature.")
return [prompt]
return prompt
if not prompt or prompt.strip() == "":
gr.Warning("Please enter a prompt to expand.")
return [prompt]
return prompt
gr.Info("Expanding prompt using Qwen3... This may take a moment.")
expanded = model.expand_prompt(prompt.strip())
# Extract PIL image if provided (from ForgeCanvas or gr.Image)
pil_image = None
if image is not None:
if hasattr(image, 'convert'):
# Already a PIL Image
pil_image = image
elif isinstance(image, dict) and 'image' in image:
# ForgeCanvas format
pil_image = image.get('image')
elif isinstance(image, dict) and 'background' in image:
# ForgeCanvas background format
pil_image = image.get('background')
if pil_image is not None:
gr.Info("Expanding prompt using Qwen3-VL with image context... This may take a moment.")
else:
gr.Info("Expanding prompt using Qwen3-VL... This may take a moment.")
expanded = model.expand_prompt(prompt.strip(), image=pil_image)
if expanded and expanded != prompt:
gr.Info("Prompt expanded successfully!")
return [expanded]
return expanded
else:
gr.Warning("Prompt expansion returned empty result, keeping original.")
return [prompt]
return prompt
except Exception as e:
import traceback
traceback.print_exc()
gr.Warning(f"Error during prompt expansion: {str(e)}")
return [prompt]
return prompt
def connect_clear_prompt(button):
@ -577,9 +594,9 @@ def create_ui():
toprow.token_button.click(fn=wrap_queued_call(update_token_counter), inputs=[toprow.prompt, steps, toprow.ui_styles.dropdown], outputs=[toprow.token_counter])
toprow.negative_token_button.click(fn=wrap_queued_call(update_negative_prompt_token_counter), inputs=[toprow.negative_prompt, steps, toprow.ui_styles.dropdown], outputs=[toprow.negative_token_counter])
# Connect expand prompt button for Z-Image models
# Connect expand prompt button for Z-Image models (no performance wrapper - we don't want HTML in prompt)
toprow.expand_prompt_button.click(
fn=wrap_gradio_gpu_call(expand_prompt_with_llm),
fn=expand_prompt_with_llm,
inputs=[toprow.prompt],
outputs=[toprow.prompt],
show_progress=True,
@ -932,10 +949,10 @@ def create_ui():
toprow.token_button.click(fn=update_token_counter, inputs=[toprow.prompt, steps, toprow.ui_styles.dropdown], outputs=[toprow.token_counter])
toprow.negative_token_button.click(fn=wrap_queued_call(update_negative_prompt_token_counter), inputs=[toprow.negative_prompt, steps, toprow.ui_styles.dropdown], outputs=[toprow.negative_token_counter])
# Connect expand prompt button for Z-Image models
# Connect expand prompt button for Z-Image models (with image from img2img, no performance wrapper)
toprow.expand_prompt_button.click(
fn=wrap_gradio_gpu_call(expand_prompt_with_llm),
inputs=[toprow.prompt],
fn=expand_prompt_with_llm,
inputs=[toprow.prompt, init_img.background],
outputs=[toprow.prompt],
show_progress=True,
)