# Mirror of https://github.com/lllyasviel/stable-diffusion-webui-forge.git
# (synced 2026-06-21 21:14:23 +08:00; 83 lines, 2.4 KiB, Python)
"""
ChromaDCT Memory Optimization Integration

Automatically applies memory optimizations to ChromaDCT models during loading.
"""
|
|
|
|
import torch
|
|
from typing import Optional
|
|
|
|
# Import the optimization components
|
|
# The optimization helpers are optional. OPTIMIZATION_AVAILABLE records whether
# the import succeeded so code in this module can check it before using them.
try:
    from myerflow.src.models.chroma.memory_optimization_helper import (
        apply_memory_optimization,
        should_use_optimized_offloading,
        print_memory_optimization_info,
    )
except ImportError:
    # Missing helpers are not fatal: warn and mark the feature as unavailable.
    print("Warning: ChromaDCT optimization components not found")
    OPTIMIZATION_AVAILABLE = False
else:
    OPTIMIZATION_AVAILABLE = True
|
|
|
|
|
|
def patch_chromadct_model_loading():
    """
    Placeholder hook for wiring ChromaDCT optimizations into model loading.

    No patching is implemented: the function reads OPTIMIZATION_AVAILABLE,
    performs no work in either case, and returns None.
    """
    if OPTIMIZATION_AVAILABLE:
        # Intended to run during model initialization. For now this only
        # serves as a reference implementation and does nothing.
        pass
    return None
|
|
|
|
|
|
def optimize_model_if_needed(model, device: torch.device, strategy: Optional[str] = None):
    """
    Run the ChromaDCT memory optimization on a model when the helpers exist.

    Args:
        model: Model to potentially optimize
        device: Target device
        strategy: Offloading strategy, forwarded unchanged to the helper
            (None is intended to mean auto-detect; that decision is not
            made in this function)

    Returns:
        The result of apply_memory_optimization(model, device, strategy), or
        the original model when the optimization helpers are unavailable.
    """
    if OPTIMIZATION_AVAILABLE:
        # No model-type check is done here; every model is handed to the helper.
        return apply_memory_optimization(model, device, strategy)

    # Helpers could not be imported: hand the model back untouched.
    return model
|
|
|
|
|
|
# Example usage for manual optimization:
|
|
def example_optimize_chromadct():
    """
    Print a short banner describing the manual ChromaDCT optimization example.

    Nothing is loaded or optimized here; the function only prints three lines.
    A manual flow, typically run during model loading, would look like::

        # Load your ChromaDCT model
        model = load_chromadct_model()
        device = torch.device('cuda')

        # Apply optimization
        optimized_model = optimize_model_if_needed(model, device, strategy='balanced')

        # Print optimization status
        print_memory_optimization_info(optimized_model)

        # Use optimized model for inference
        result = optimized_model(img, img_ids, txt, txt_ids, txt_mask, timesteps, guidance)

        # Print performance stats
        if hasattr(optimized_model, 'print_performance_summary'):
            optimized_model.print_performance_summary()
    """
    banner_lines = (
        "ChromaDCT Memory Optimization Example",
        "=====================================",
        "See chromadct_optimization_integration.py for implementation details",
    )
    # One print call per line, in order, matching the original output exactly.
    for text in banner_lines:
        print(text)
|
|
|
|
|
|
# Executing this file as a script prints the example banner; importing it does not.
if __name__ == "__main__":
    example_optimize_chromadct()