From 750f09fbed98a4a8bd663f5be0cdc836d77b787d Mon Sep 17 00:00:00 2001 From: Lincoln Stein Date: Mon, 7 Aug 2023 21:01:59 -0400 Subject: [PATCH] blackify --- scripts/create_checkpoint_template.py | 20 +++----------------- scripts/verify_checkpoint_template.py | 24 +++++------------------- 2 files changed, 8 insertions(+), 36 deletions(-) diff --git a/scripts/create_checkpoint_template.py b/scripts/create_checkpoint_template.py index 5b8fca8b58..7ff201c841 100755 --- a/scripts/create_checkpoint_template.py +++ b/scripts/create_checkpoint_template.py @@ -13,18 +13,8 @@ from pathlib import Path from invokeai.backend.model_management.models.base import read_checkpoint_meta parser = argparse.ArgumentParser(description="Create a .json template from checkpoint/safetensors model") -parser.add_argument( - "--checkpoint", - "--in", - type=Path, - help="Path to the input checkpoint/safetensors file" -) -parser.add_argument( - "--template", - "--out", - type=Path, - help="Path to the output .json file" -) +parser.add_argument("--checkpoint", "--in", type=Path, help="Path to the input checkpoint/safetensors file") +parser.add_argument("--template", "--out", type=Path, help="Path to the output .json file") opt = parser.parse_args() ckpt = read_checkpoint_meta(opt.checkpoint) @@ -37,12 +27,8 @@ for key, tensor in ckpt.items(): tmpl[key] = list(tensor.shape) try: - with open(opt.template,'w') as f: + with open(opt.template, "w") as f: json.dump(tmpl, f) print(f"Template written out as {opt.template}") except Exception as e: print(f"An exception occurred while writing template: {str(e)}") - - - - diff --git a/scripts/verify_checkpoint_template.py b/scripts/verify_checkpoint_template.py index 42c7acca3a..68ed72037e 100755 --- a/scripts/verify_checkpoint_template.py +++ b/scripts/verify_checkpoint_template.py @@ -13,18 +13,8 @@ from pathlib import Path from invokeai.backend.model_management.models.base import read_checkpoint_meta parser = argparse.ArgumentParser(description="Create a .json template from checkpoint/safetensors model") -parser.add_argument( - "--checkpoint", - "--in", - type=Path, - help="Path to the input checkpoint/safetensors file" -) -parser.add_argument( - "--template", - "--out", - type=Path, - help="Path to the template .json file to match against" -) +parser.add_argument("--checkpoint", "--in", type=Path, help="Path to the input checkpoint/safetensors file") +parser.add_argument("--template", "--out", type=Path, help="Path to the template .json file to match against") opt = parser.parse_args() ckpt = read_checkpoint_meta(opt.checkpoint) @@ -36,16 +26,12 @@ checkpoint_metadata = {} for key, tensor in ckpt.items(): checkpoint_metadata[key] = list(tensor.shape) -with open(opt.template,'r') as f: +with open(opt.template, "r") as f: template = json.load(f) if checkpoint_metadata == template: - print('True') + print("True") sys.exit(0) else: - print('False') + print("False") sys.exit(-1) - - - -