mirror of
https://github.com/lllyasviel/stable-diffusion-webui-forge.git
synced 2026-06-21 21:14:23 +08:00
26 lines
825 B
Python
26 lines
825 B
Python
import torch
|
|
from safetensors.torch import save_file
|
|
|
|
def convert_pth_to_safetensors(pth_file: str, safetensors_file: str) -> None:
    """
    Converts a PyTorch .pth file to a .safetensors file.

    The .pth file must contain a flat dictionary mapping string names to
    tensors (a plain ``state_dict``). Tensors are made contiguous before
    saving, because safetensors refuses to serialize non-contiguous tensors.

    Args:
        pth_file (str): Path to the .pth file to convert.
        safetensors_file (str): Path to save the .safetensors file.

    Returns:
        None

    Raises:
        ValueError: If the loaded object is not a dictionary, or if any of
            its keys is not a string or any of its values is not a tensor
            (e.g. a training checkpoint that stores metadata such as
            ``"epoch"`` next to the weights).
    """
    # Load the .pth file onto the CPU so conversion does not need a GPU.
    # NOTE(security): torch.load unpickles the file. On PyTorch versions where
    # ``weights_only`` defaults to False, a malicious .pth can execute
    # arbitrary code - only convert files from trusted sources.
    state_dict = torch.load(pth_file, map_location="cpu")

    # Ensure the state_dict is a dictionary
    if not isinstance(state_dict, dict):
        raise ValueError("The .pth file must contain a dictionary-like object.")

    # safetensors can only store a str -> tensor mapping. Report every
    # offending entry up front instead of letting save_file fail on the first.
    bad_entries = [
        repr(key)
        for key, value in state_dict.items()
        if not isinstance(key, str) or not isinstance(value, torch.Tensor)
    ]
    if bad_entries:
        raise ValueError(
            "The .pth file must map string names to tensors; "
            f"unsupported entries: {', '.join(bad_entries)}"
        )

    # save_file rejects non-contiguous tensors (e.g. transposed views);
    # .contiguous() returns the tensor itself when it already is contiguous.
    tensors = {key: value.contiguous() for key, value in state_dict.items()}

    # Save the tensors as a .safetensors file
    save_file(tensors, safetensors_file)

    print(f"Converted {pth_file} to {safetensors_file}")
|
|
|
|
if __name__ == "__main__":
    # Command-line entry point. The guard keeps the conversion from running
    # as a side effect when this module is imported. The defaults reproduce
    # the original hard-coded paths, so running with no arguments behaves
    # exactly as before.
    import argparse

    parser = argparse.ArgumentParser(
        description="Convert a PyTorch .pth file to a .safetensors file."
    )
    parser.add_argument(
        "pth_file",
        nargs="?",
        default="source.pth",
        help="path to the input .pth file (default: source.pth)",
    )
    parser.add_argument(
        "safetensors_file",
        nargs="?",
        default="output.safetensors",
        help="path for the output .safetensors file (default: output.safetensors)",
    )
    args = parser.parse_args()
    convert_pth_to_safetensors(args.pth_file, args.safetensors_file)