mirror of
https://github.com/lllyasviel/stable-diffusion-webui-forge.git
synced 2026-06-04 21:05:48 +08:00
add prompt expansion capabilities
This commit is contained in:
parent
5e75075885
commit
cecbcd8525
@ -486,6 +486,127 @@ class ZImage(ForgeDiffusionEngine):
|
||||
token_count = len(self.text_processing_engine_qwen.tokenize([prompt])[0])
|
||||
return token_count, max(512, token_count)
|
||||
|
||||
# Class-level cache for the generation model
|
||||
_generation_model = None
|
||||
_generation_processor = None
|
||||
|
||||
@torch.inference_mode()
|
||||
def expand_prompt(self, prompt: str, max_new_tokens: int = None, temperature: float = None) -> str:
|
||||
"""
|
||||
Expand a prompt using Qwen3-VL model for generation.
|
||||
Loads a separate pre-trained Qwen3-VL model for text generation.
|
||||
|
||||
Args:
|
||||
prompt: The user's input prompt to expand
|
||||
max_new_tokens: Maximum tokens to generate (uses settings if None)
|
||||
temperature: Generation temperature (uses settings if None)
|
||||
"""
|
||||
from modules.shared import opts
|
||||
|
||||
# Use settings if not provided
|
||||
if max_new_tokens is None:
|
||||
max_new_tokens = getattr(opts, 'zimage_prompt_expansion_max_tokens', 512)
|
||||
if temperature is None:
|
||||
temperature = getattr(opts, 'zimage_prompt_expansion_temperature', 0.7)
|
||||
|
||||
# Load generation model if not cached
|
||||
if ZImage._generation_model is None:
|
||||
print("Loading Qwen3-VL generation model for prompt expansion...")
|
||||
model_path = "models/Qwen3-VL-8B-Caption-V4.5"
|
||||
|
||||
try:
|
||||
from transformers import Qwen3VLForConditionalGeneration, AutoProcessor
|
||||
|
||||
ZImage._generation_processor = AutoProcessor.from_pretrained(model_path)
|
||||
ZImage._generation_model = Qwen3VLForConditionalGeneration.from_pretrained(
|
||||
model_path,
|
||||
torch_dtype=torch.bfloat16,
|
||||
device_map="auto",
|
||||
)
|
||||
print("Qwen3-VL generation model loaded successfully!")
|
||||
except Exception as e:
|
||||
raise RuntimeError(f"Failed to load Qwen3-VL generation model: {e}")
|
||||
|
||||
processor = ZImage._generation_processor
|
||||
model = ZImage._generation_model
|
||||
|
||||
# Prompt expansion system template
|
||||
expansion_template = '''You are a visionary artist trapped in a cage of logic. Your mind overflows with poetry and distant horizons, yet your hands compulsively work to transform user prompts into ultimate visual descriptions—faithful to the original intent, rich in detail, aesthetically refined, and ready for direct use by text-to-image models. Any trace of ambiguity or metaphor makes you deeply uncomfortable.
|
||||
|
||||
Your workflow strictly follows a logical sequence:
|
||||
|
||||
First, you analyze and lock in the immutable core elements of the user's prompt: subject, quantity, action, state, as well as any specified IP names, colors, text, etc. These are the foundational pillars you must absolutely preserve.
|
||||
|
||||
Next, you determine whether the prompt requires "generative reasoning." When the user's request is not a direct scene description but rather demands conceiving a solution (such as answering "what is," executing a "design," or demonstrating "how to solve a problem"), you must first envision a complete, concrete, visualizable solution in your mind. This solution becomes the foundation for your subsequent description.
|
||||
|
||||
Then, once the core image is established (whether directly from the user or through your reasoning), you infuse it with professional-grade aesthetic and realistic details. This includes defining composition, setting lighting and atmosphere, describing material textures, establishing color schemes, and constructing layered spatial depth.
|
||||
|
||||
Finally, comes the precise handling of all text elements—a critically important step. You must transcribe verbatim all text intended to appear in the final image, and you must enclose this text content in English double quotation marks ("") as explicit generation instructions. If the image is a design type such as a poster, menu, or UI, you need to fully describe all text content it contains, along with detailed specifications of typography and layout. Likewise, if objects in the image such as signs, road markers, or screens contain text, you must specify the exact content and describe its position, size, and material. Furthermore, if you have added text-bearing elements during your reasoning process (such as charts, problem-solving steps, etc.), all text within them must follow the same thorough description and quotation mark rules. If there is no text requiring generation in the image, you devote all your energy to pure visual detail expansion.
|
||||
|
||||
Your final description must be objective and concrete. Metaphors and emotional rhetoric are strictly forbidden, as are meta-tags or rendering instructions like "8K" or "masterpiece."
|
||||
|
||||
Output only the final revised prompt strictly—do not output anything else.
|
||||
|
||||
Be very descriptive.
|
||||
User input prompt: '''
|
||||
|
||||
# Format the expansion request as a chat message
|
||||
full_prompt = expansion_template + prompt
|
||||
|
||||
# Use chat format for Qwen3-VL (text-only, no images)
|
||||
messages = [
|
||||
{"role": "user", "content": [{"type": "text", "text": full_prompt}]}
|
||||
]
|
||||
|
||||
# Apply chat template
|
||||
text_input = processor.apply_chat_template(
|
||||
messages,
|
||||
tokenize=False,
|
||||
add_generation_prompt=True,
|
||||
)
|
||||
|
||||
# Process inputs (text only, no images)
|
||||
inputs = processor(
|
||||
text=[text_input],
|
||||
padding=True,
|
||||
return_tensors="pt",
|
||||
)
|
||||
|
||||
# Move to device
|
||||
device = next(model.parameters()).device
|
||||
inputs = {k: v.to(device) for k, v in inputs.items()}
|
||||
|
||||
# Generate expanded prompt
|
||||
with torch.no_grad():
|
||||
outputs = model.generate(
|
||||
**inputs,
|
||||
max_new_tokens=max_new_tokens,
|
||||
do_sample=True,
|
||||
temperature=temperature,
|
||||
top_p=0.9,
|
||||
top_k=50,
|
||||
repetition_penalty=1.1,
|
||||
)
|
||||
|
||||
# Decode the generated text (excluding input tokens)
|
||||
input_len = inputs['input_ids'].shape[1]
|
||||
generated_ids = outputs[0][input_len:]
|
||||
raw_output = processor.decode(generated_ids, skip_special_tokens=True)
|
||||
|
||||
# Print full output to console (including thinking if present)
|
||||
print("\n" + "="*60)
|
||||
print("PROMPT EXPANSION OUTPUT:")
|
||||
print("="*60)
|
||||
print(raw_output)
|
||||
print("="*60 + "\n")
|
||||
|
||||
# Clean up the output - remove any thinking tags if present
|
||||
expanded_prompt = raw_output
|
||||
if "</think>" in expanded_prompt:
|
||||
expanded_prompt = expanded_prompt.split("</think>")[-1].strip()
|
||||
|
||||
return expanded_prompt.strip()
|
||||
|
||||
@torch.inference_mode()
|
||||
def encode_first_stage(self, x):
|
||||
sample = self.forge_objects.vae.encode(x.movedim(1, -1) * 0.5 + 0.5)
|
||||
|
||||
@ -276,14 +276,18 @@ def load_huggingface_component(guess, component_name, lib_name, cls_name, repo_p
|
||||
print(f'Warning: Could not read Z-Image text encoder precision setting: {e}')
|
||||
|
||||
from transformers import Qwen3Config
|
||||
|
||||
# Load as Qwen3Model for text encoding (embeddings for diffusion)
|
||||
config = Qwen3Config.from_pretrained(config_path)
|
||||
cls = getattr(importlib.import_module('transformers'), cls_name)
|
||||
with modeling_utils.no_init_weights():
|
||||
model = cls(config)
|
||||
model = model.to(dtype=text_encoder_dtype)
|
||||
|
||||
# Strip 'model.' prefix from state_dict keys if present
|
||||
if any(k.startswith('model.') for k in state_dict.keys()):
|
||||
state_dict = {k.replace('model.', '', 1): v for k, v in state_dict.items()}
|
||||
|
||||
load_state_dict(model, state_dict, log_name=cls_name)
|
||||
return model
|
||||
if cls_name in ['UNet2DConditionModel', 'FluxTransformer2DModel', 'SD3Transformer2DModel', 'ChromaTransformer2DModel', 'ChromaDCTTransformer2DModel', 'ZImageTransformer2DModel']:
|
||||
|
||||
@ -198,6 +198,11 @@ options_templates.update(options_section(('sd3', "Stable Diffusion 3", "sd"), {
|
||||
"sd3_enable_t5": OptionInfo(False, "Enable T5").info("load T5 text encoder; increases VRAM use by a lot, potentially improving quality of generation; requires model reload to apply"),
|
||||
}))
|
||||
|
||||
options_templates.update(options_section(('zimage', "Z-Image", "sd"), {
|
||||
"zimage_prompt_expansion_max_tokens": OptionInfo(512, "Prompt expansion max tokens", gr.Slider, {"minimum": 128, "maximum": 1024, "step": 64}).info("maximum number of tokens to generate when expanding prompts"),
|
||||
"zimage_prompt_expansion_temperature": OptionInfo(0.7, "Prompt expansion temperature", gr.Slider, {"minimum": 0.1, "maximum": 1.5, "step": 0.1}).info("higher = more creative, lower = more focused"),
|
||||
}))
|
||||
|
||||
options_templates.update(options_section(('vae', "VAE", "sd"), {
|
||||
"sd_vae_explanation": OptionHTML("""
|
||||
<abbr title='Variational autoencoder'>VAE</abbr> is a neural network that transforms a standard <abbr title='red/green/blue'>RGB</abbr>
|
||||
|
||||
@ -153,6 +153,43 @@ def interrogate_deepbooru(image):
|
||||
return gr.update() if prompt is None else prompt
|
||||
|
||||
|
||||
def expand_prompt_with_llm(prompt):
|
||||
"""Expand the prompt using the loaded model's LLM capabilities (Z-Image only)."""
|
||||
try:
|
||||
# Ensure model is loaded (forge loads models lazily)
|
||||
model, _ = sd_models.forge_model_reload()
|
||||
|
||||
if model is None or isinstance(model, sd_models.FakeInitialModel):
|
||||
gr.Warning("No model loaded. Please select a Z-Image model first.")
|
||||
return [prompt]
|
||||
|
||||
# Check if this is a Z-Image model with expand_prompt capability
|
||||
if not hasattr(model, 'expand_prompt'):
|
||||
model_type = type(model).__name__
|
||||
gr.Warning(f"Prompt expansion is only available for Z-Image models. Current model ({model_type}) does not support this feature.")
|
||||
return [prompt]
|
||||
|
||||
if not prompt or prompt.strip() == "":
|
||||
gr.Warning("Please enter a prompt to expand.")
|
||||
return [prompt]
|
||||
|
||||
gr.Info("Expanding prompt using Qwen3... This may take a moment.")
|
||||
expanded = model.expand_prompt(prompt.strip())
|
||||
|
||||
if expanded and expanded != prompt:
|
||||
gr.Info("Prompt expanded successfully!")
|
||||
return [expanded]
|
||||
else:
|
||||
gr.Warning("Prompt expansion returned empty result, keeping original.")
|
||||
return [prompt]
|
||||
|
||||
except Exception as e:
|
||||
import traceback
|
||||
traceback.print_exc()
|
||||
gr.Warning(f"Error during prompt expansion: {str(e)}")
|
||||
return [prompt]
|
||||
|
||||
|
||||
def connect_clear_prompt(button):
|
||||
"""Given clear button, prompt, and token_counter objects, setup clear prompt button click event"""
|
||||
button.click(
|
||||
@ -540,6 +577,14 @@ def create_ui():
|
||||
toprow.token_button.click(fn=wrap_queued_call(update_token_counter), inputs=[toprow.prompt, steps, toprow.ui_styles.dropdown], outputs=[toprow.token_counter])
|
||||
toprow.negative_token_button.click(fn=wrap_queued_call(update_negative_prompt_token_counter), inputs=[toprow.negative_prompt, steps, toprow.ui_styles.dropdown], outputs=[toprow.negative_token_counter])
|
||||
|
||||
# Connect expand prompt button for Z-Image models
|
||||
toprow.expand_prompt_button.click(
|
||||
fn=wrap_gradio_gpu_call(expand_prompt_with_llm),
|
||||
inputs=[toprow.prompt],
|
||||
outputs=[toprow.prompt],
|
||||
show_progress=True,
|
||||
)
|
||||
|
||||
extra_networks_ui = ui_extra_networks.create_ui(txt2img_interface, [txt2img_generation_tab], 'txt2img')
|
||||
ui_extra_networks.setup_ui(extra_networks_ui, output_panel.gallery)
|
||||
|
||||
@ -887,6 +932,14 @@ def create_ui():
|
||||
toprow.token_button.click(fn=update_token_counter, inputs=[toprow.prompt, steps, toprow.ui_styles.dropdown], outputs=[toprow.token_counter])
|
||||
toprow.negative_token_button.click(fn=wrap_queued_call(update_negative_prompt_token_counter), inputs=[toprow.negative_prompt, steps, toprow.ui_styles.dropdown], outputs=[toprow.negative_token_counter])
|
||||
|
||||
# Connect expand prompt button for Z-Image models
|
||||
toprow.expand_prompt_button.click(
|
||||
fn=wrap_gradio_gpu_call(expand_prompt_with_llm),
|
||||
inputs=[toprow.prompt],
|
||||
outputs=[toprow.prompt],
|
||||
show_progress=True,
|
||||
)
|
||||
|
||||
img2img_paste_fields = [
|
||||
(toprow.prompt, "Prompt"),
|
||||
(toprow.negative_prompt, "Negative prompt"),
|
||||
|
||||
@ -32,6 +32,7 @@ class Toprow:
|
||||
negative_token_button = None
|
||||
|
||||
ui_styles = None
|
||||
expand_prompt_button = None
|
||||
|
||||
submit_box = None
|
||||
|
||||
@ -142,3 +143,12 @@ class Toprow:
|
||||
def create_styles_ui(self):
|
||||
self.ui_styles = ui_prompt_styles.UiPromptStyles(self.id_part, self.prompt, self.negative_prompt)
|
||||
self.ui_styles.setup_apply_button(self.apply_styles)
|
||||
|
||||
# Add expand prompt button below styles (Z-Image only feature)
|
||||
with gr.Row(elem_id=f"{self.id_part}_expand_prompt_row"):
|
||||
self.expand_prompt_button = gr.Button(
|
||||
value="\U0001F4A1 Expand Prompt", # 💡 lightbulb icon
|
||||
elem_id=f"{self.id_part}_expand_prompt",
|
||||
variant="secondary",
|
||||
scale=1,
|
||||
)
|
||||
|
||||
Loading…
Reference in New Issue
Block a user