diff --git a/scripts/extract_sd_keys_and_shapes.py b/scripts/extract_sd_keys_and_shapes.py new file mode 100644 index 0000000000..33c7e83627 --- /dev/null +++ b/scripts/extract_sd_keys_and_shapes.py @@ -0,0 +1,30 @@ +import argparse +import json + +from safetensors.torch import load_file + + +def extract_sd_keys_and_shapes(safetensors_file: str): + sd = load_file(safetensors_file) + + keys_to_shapes = {k: v.shape for k, v in sd.items()} + + out_file = "keys_and_shapes.json" + with open(out_file, "w") as f: + json.dump(keys_to_shapes, f, indent=4) + + print(f"Keys and shapes written to '{out_file}'.") + + +def main(): + parser = argparse.ArgumentParser( + description="Extracts the keys and shapes from the state dict in a safetensors file. Intended for creating " + + "dummy state dicts for use in unit tests." + ) + parser.add_argument("safetensors_file", type=str, help="Path to the safetensors file.") + args = parser.parse_args() + extract_sd_keys_and_shapes(args.safetensors_file) + + +if __name__ == "__main__": + main()