stable-diffusion-webui-forge/convertpth.py

26 lines
825 B
Python

import torch
from safetensors.torch import save_file
def convert_pth_to_safetensors(pth_file: str, safetensors_file: str) -> None:
    """Convert a PyTorch ``.pth`` checkpoint into a ``.safetensors`` file.

    The checkpoint is loaded onto the CPU. If the loaded dictionary wraps the
    weights under a ``"state_dict"`` key whose value is itself a dictionary,
    that inner dictionary is converted; otherwise the loaded dictionary is
    converted directly.

    Args:
        pth_file: Path to the .pth file to convert.
        safetensors_file: Path to save the .safetensors file.

    Raises:
        ValueError: If the file does not contain a dictionary, or if the
            dictionary to convert has non-string keys or non-tensor values.
    """
    # SECURITY: torch.load unpickles the file, which can execute arbitrary
    # code on torch versions where weights_only is not the default. Only
    # convert checkpoints from trusted sources.
    loaded = torch.load(pth_file, map_location="cpu")

    if not isinstance(loaded, dict):
        raise ValueError("The .pth file must contain a dictionary-like object.")

    # Many checkpoints nest the weights under "state_dict" next to training
    # metadata; unwrap only when that entry is itself a dictionary.
    nested = loaded.get("state_dict")
    state_dict = nested if isinstance(nested, dict) else loaded

    # Validate up front so the caller gets a clear error naming the offending
    # entries instead of a failure from inside the serializer.
    invalid = [
        repr(key)
        for key, value in state_dict.items()
        if not isinstance(key, str) or not isinstance(value, torch.Tensor)
    ]
    if invalid:
        preview = ", ".join(invalid[:5])
        if len(invalid) > 5:
            preview += f", ... ({len(invalid)} total)"
        raise ValueError(
            "The .pth file must map string names to tensors; "
            f"invalid entries: {preview}"
        )

    # save_file rejects non-contiguous tensors, so normalise the layout first.
    # detach() keeps autograd history out of the tensors handed to it.
    tensors = {
        name: tensor.detach().contiguous()
        for name, tensor in state_dict.items()
    }

    save_file(tensors, safetensors_file)
    print(f"Converted {pth_file} to {safetensors_file}")
# Run the conversion only when executed as a script, so importing this module
# has no side effects. With no arguments the original default paths are used.
if __name__ == "__main__":
    import argparse

    parser = argparse.ArgumentParser(
        description="Convert a PyTorch .pth file to a .safetensors file."
    )
    parser.add_argument(
        "pth_file",
        nargs="?",
        default="source.pth",
        help="path to the .pth file to convert (default: source.pth)",
    )
    parser.add_argument(
        "safetensors_file",
        nargs="?",
        default="output.safetensors",
        help="path to save the .safetensors file (default: output.safetensors)",
    )
    args = parser.parse_args()
    convert_pth_to_safetensors(args.pth_file, args.safetensors_file)