From 63a337707acff69f31abe3851d9105df4088726a Mon Sep 17 00:00:00 2001 From: maybleMyers Date: Thu, 4 Dec 2025 16:47:20 -0800 Subject: [PATCH] prompt expansion --- backend/diffusion_engine/zimage.py | 68 +++++++++++++++++++++++------- modules/ui.py | 45 ++++++++++++++------ 2 files changed, 83 insertions(+), 30 deletions(-) diff --git a/backend/diffusion_engine/zimage.py b/backend/diffusion_engine/zimage.py index ebbd0bc1..a8a2be8c 100644 --- a/backend/diffusion_engine/zimage.py +++ b/backend/diffusion_engine/zimage.py @@ -491,13 +491,14 @@ class ZImage(ForgeDiffusionEngine): _generation_processor = None @torch.inference_mode() - def expand_prompt(self, prompt: str, max_new_tokens: int = None, temperature: float = None) -> str: + def expand_prompt(self, prompt: str, image=None, max_new_tokens: int = None, temperature: float = None) -> str: """ Expand a prompt using Qwen3-VL model for generation. Loads a separate pre-trained Qwen3-VL model for text generation. Args: prompt: The user's input prompt to expand + image: Optional PIL Image to use as context (from img2img) max_new_tokens: Maximum tokens to generate (uses settings if None) temperature: Generation temperature (uses settings if None) """ @@ -553,10 +554,18 @@ User input prompt: ''' # Format the expansion request as a chat message full_prompt = expansion_template + prompt - # Use chat format for Qwen3-VL (text-only, no images) - messages = [ - {"role": "user", "content": [{"type": "text", "text": full_prompt}]} - ] + # Build message content based on whether image is provided + if image is not None: + print("Using image context for prompt expansion...") + # Include image in the message for vision-language understanding + content = [ + {"type": "image", "image": image}, + {"type": "text", "text": full_prompt} + ] + else: + content = [{"type": "text", "text": full_prompt}] + + messages = [{"role": "user", "content": content}] # Apply chat template text_input = processor.apply_chat_template( @@ -565,12 +574,20 @@ 
User input prompt: ''' add_generation_prompt=True, ) - # Process inputs (text only, no images) - inputs = processor( - text=[text_input], - padding=True, - return_tensors="pt", - ) + # Process inputs (with or without image) + if image is not None: + inputs = processor( + text=[text_input], + images=[image], + padding=True, + return_tensors="pt", + ) + else: + inputs = processor( + text=[text_input], + padding=True, + return_tensors="pt", + ) # Move to device device = next(model.parameters()).device @@ -594,17 +611,36 @@ User input prompt: ''' raw_output = processor.decode(generated_ids, skip_special_tokens=True) # Print full output to console (including thinking if present) - print("\n" + "="*60) - print("PROMPT EXPANSION OUTPUT:") - print("="*60) - print(raw_output) - print("="*60 + "\n") + print("\n" + "="*60, flush=True) + print("PROMPT EXPANSION OUTPUT:", flush=True) + print("="*60, flush=True) + print(raw_output, flush=True) + print("="*60, flush=True) # Clean up the output - remove any thinking tags if present expanded_prompt = raw_output if "</think>" in expanded_prompt: expanded_prompt = expanded_prompt.split("</think>")[-1].strip() + print("\nCLEANED PROMPT:", flush=True) + print("-"*60, flush=True) + print(expanded_prompt, flush=True) + print("="*60 + "\n", flush=True) + + # Unload Qwen3-VL model to free VRAM for image generation + print("Unloading Qwen3-VL model to free VRAM...", flush=True) + if ZImage._generation_model is not None: + del ZImage._generation_model + ZImage._generation_model = None + if ZImage._generation_processor is not None: + del ZImage._generation_processor + ZImage._generation_processor = None + + import gc + gc.collect() + torch.cuda.empty_cache() + print("Qwen3-VL model unloaded.", flush=True) + return expanded_prompt.strip() @torch.inference_mode() diff --git a/modules/ui.py b/modules/ui.py index 2d97d140..20fc3407 100644 --- a/modules/ui.py +++ b/modules/ui.py @@ -153,7 +153,7 @@ def interrogate_deepbooru(image): return gr.update() if prompt is 
None else prompt -def expand_prompt_with_llm(prompt): +def expand_prompt_with_llm(prompt, image=None): """Expand the prompt using the loaded model's LLM capabilities (Z-Image only).""" try: # Ensure model is loaded (forge loads models lazily) @@ -161,33 +161,50 @@ def expand_prompt_with_llm(prompt): if model is None or isinstance(model, sd_models.FakeInitialModel): gr.Warning("No model loaded. Please select a Z-Image model first.") - return [prompt] + return prompt # Check if this is a Z-Image model with expand_prompt capability if not hasattr(model, 'expand_prompt'): model_type = type(model).__name__ gr.Warning(f"Prompt expansion is only available for Z-Image models. Current model ({model_type}) does not support this feature.") - return [prompt] + return prompt if not prompt or prompt.strip() == "": gr.Warning("Please enter a prompt to expand.") - return [prompt] + return prompt - gr.Info("Expanding prompt using Qwen3... This may take a moment.") - expanded = model.expand_prompt(prompt.strip()) + # Extract PIL image if provided (from ForgeCanvas or gr.Image) + pil_image = None + if image is not None: + if hasattr(image, 'convert'): + # Already a PIL Image + pil_image = image + elif isinstance(image, dict) and 'image' in image: + # ForgeCanvas format + pil_image = image.get('image') + elif isinstance(image, dict) and 'background' in image: + # ForgeCanvas background format + pil_image = image.get('background') + + if pil_image is not None: + gr.Info("Expanding prompt using Qwen3-VL with image context... This may take a moment.") + else: + gr.Info("Expanding prompt using Qwen3-VL... 
This may take a moment.") + + expanded = model.expand_prompt(prompt.strip(), image=pil_image) if expanded and expanded != prompt: gr.Info("Prompt expanded successfully!") - return [expanded] + return expanded else: gr.Warning("Prompt expansion returned empty result, keeping original.") - return [prompt] + return prompt except Exception as e: import traceback traceback.print_exc() gr.Warning(f"Error during prompt expansion: {str(e)}") - return [prompt] + return prompt def connect_clear_prompt(button): @@ -577,9 +594,9 @@ def create_ui(): toprow.token_button.click(fn=wrap_queued_call(update_token_counter), inputs=[toprow.prompt, steps, toprow.ui_styles.dropdown], outputs=[toprow.token_counter]) toprow.negative_token_button.click(fn=wrap_queued_call(update_negative_prompt_token_counter), inputs=[toprow.negative_prompt, steps, toprow.ui_styles.dropdown], outputs=[toprow.negative_token_counter]) - # Connect expand prompt button for Z-Image models + # Connect expand prompt button for Z-Image models (no performance wrapper - we don't want HTML in prompt) toprow.expand_prompt_button.click( - fn=wrap_gradio_gpu_call(expand_prompt_with_llm), + fn=expand_prompt_with_llm, inputs=[toprow.prompt], outputs=[toprow.prompt], show_progress=True, @@ -932,10 +949,10 @@ def create_ui(): toprow.token_button.click(fn=update_token_counter, inputs=[toprow.prompt, steps, toprow.ui_styles.dropdown], outputs=[toprow.token_counter]) toprow.negative_token_button.click(fn=wrap_queued_call(update_negative_prompt_token_counter), inputs=[toprow.negative_prompt, steps, toprow.ui_styles.dropdown], outputs=[toprow.negative_token_counter]) - # Connect expand prompt button for Z-Image models + # Connect expand prompt button for Z-Image models (with image from img2img, no performance wrapper) toprow.expand_prompt_button.click( - fn=wrap_gradio_gpu_call(expand_prompt_with_llm), - inputs=[toprow.prompt], + fn=expand_prompt_with_llm, + inputs=[toprow.prompt, init_img.background], outputs=[toprow.prompt], 
show_progress=True, )