Add scripts/extract_sd_keys_and_shapes.py

This commit is contained in:
Ryan Dick 2024-10-07 16:02:30 +00:00
parent 2cd14dd066
commit 1a7eece695

View File

@ -0,0 +1,30 @@
import argparse
import json
from safetensors.torch import load_file
def extract_sd_keys_and_shapes(safetensors_file: str):
sd = load_file(safetensors_file)
keys_to_shapes = {k: v.shape for k, v in sd.items()}
out_file = "keys_and_shapes.json"
with open(out_file, "w") as f:
json.dump(keys_to_shapes, f, indent=4)
print(f"Keys and shapes written to '{out_file}'.")
def main():
parser = argparse.ArgumentParser(
description="Extracts the keys and shapes from the state dict in a safetensors file. Intended for creating "
+ "dummy state dicts for use in unit tests."
)
parser.add_argument("safetensors_file", type=str, help="Path to the safetensors file.")
args = parser.parse_args()
extract_sd_keys_and_shapes(args.safetensors_file)
if __name__ == "__main__":
main()