""" ChromaDCT Memory Optimization Integration Automatically applies memory optimizations to ChromaDCT models during loading """ import torch from typing import Optional # Import the optimization components try: from myerflow.src.models.chroma.memory_optimization_helper import ( apply_memory_optimization, should_use_optimized_offloading, print_memory_optimization_info ) OPTIMIZATION_AVAILABLE = True except ImportError: print("Warning: ChromaDCT optimization components not found") OPTIMIZATION_AVAILABLE = False def patch_chromadct_model_loading(): """ Patch the model loading system to automatically apply ChromaDCT optimizations """ if not OPTIMIZATION_AVAILABLE: return # This would be called during model initialization # For now, this serves as a reference implementation pass def optimize_model_if_needed(model, device: torch.device, strategy: Optional[str] = None): """ Apply optimization to model if it's a ChromaDCT model Args: model: Model to potentially optimize device: Target device strategy: Offloading strategy (auto-detect if None) Returns: Optimized model or original model """ if not OPTIMIZATION_AVAILABLE: return model return apply_memory_optimization(model, device, strategy) # Example usage for manual optimization: def example_optimize_chromadct(): """ Example of how to manually apply ChromaDCT optimization """ print("ChromaDCT Memory Optimization Example") print("=====================================") # This would typically be called during model loading: # # # Load your ChromaDCT model # model = load_chromadct_model() # device = torch.device('cuda') # # # Apply optimization # optimized_model = optimize_model_if_needed(model, device, strategy='balanced') # # # Print optimization status # print_memory_optimization_info(optimized_model) # # # Use optimized model for inference # result = optimized_model(img, img_ids, txt, txt_ids, txt_mask, timesteps, guidance) # # # Print performance stats # if hasattr(optimized_model, 'print_performance_summary'): # optimized_model.print_performance_summary() print("See chromadct_optimization_integration.py for implementation details") if __name__ == "__main__": example_optimize_chromadct()