fix: Black linting

This commit is contained in:
blessedcoolant 2023-07-29 17:34:43 +12:00
parent 6ed1bf7084
commit 6d82a1019a
3 changed files with 327 additions and 321 deletions

View File

@ -6,8 +6,7 @@ from pydantic import Field
from invokeai.app.invocations.prompt import PromptOutput from invokeai.app.invocations.prompt import PromptOutput
from .baseinvocation import (BaseInvocation, BaseInvocationOutput, from .baseinvocation import BaseInvocation, BaseInvocationOutput, InvocationConfig, InvocationContext
InvocationConfig, InvocationContext)
from .math import FloatOutput, IntOutput from .math import FloatOutput, IntOutput
# Pass-through parameter nodes - used by subgraphs # Pass-through parameter nodes - used by subgraphs
@ -68,6 +67,7 @@ class ParamStringInvocation(BaseInvocation):
def invoke(self, context: InvocationContext) -> StringOutput: def invoke(self, context: InvocationContext) -> StringOutput:
return StringOutput(text=self.text) return StringOutput(text=self.text)
class ParamPromptInvocation(BaseInvocation): class ParamPromptInvocation(BaseInvocation):
"""A prompt input parameter""" """A prompt input parameter"""
@ -80,4 +80,4 @@ class ParamPromptInvocation(BaseInvocation):
} }
def invoke(self, context: InvocationContext) -> PromptOutput: def invoke(self, context: InvocationContext) -> PromptOutput:
return PromptOutput(prompt=self.prompt) return PromptOutput(prompt=self.prompt)

View File

@ -1,281 +1,283 @@
{ {
"cells": [ "cells": [
{ {
"cell_type": "markdown", "cell_type": "markdown",
"metadata": { "metadata": {
"id": "ycYWcsEKc6w7" "id": "ycYWcsEKc6w7"
}, },
"source": [ "source": [
"# Stable Diffusion AI Notebook (Release 2.0.0)\n", "# Stable Diffusion AI Notebook (Release 2.0.0)\n",
"\n", "\n",
"<img src=\"https://user-images.githubusercontent.com/60411196/186547976-d9de378a-9de8-4201-9c25-c057a9c59bad.jpeg\" alt=\"stable-diffusion-ai\" width=\"170px\"/> <br>\n", "<img src=\"https://user-images.githubusercontent.com/60411196/186547976-d9de378a-9de8-4201-9c25-c057a9c59bad.jpeg\" alt=\"stable-diffusion-ai\" width=\"170px\"/> <br>\n",
"#### Instructions:\n", "#### Instructions:\n",
"1. Execute each cell in order to mount a Dream bot and create images from text. <br>\n", "1. Execute each cell in order to mount a Dream bot and create images from text. <br>\n",
"2. Once cells 1-8 were run correctly you'll be executing a terminal in cell #9, you'll need to enter `python scripts/dream.py` command to run Dream bot.<br> \n", "2. Once cells 1-8 were run correctly you'll be executing a terminal in cell #9, you'll need to enter `python scripts/dream.py` command to run Dream bot.<br> \n",
"3. After launching dream bot, you'll see: <br> `Dream > ` in terminal. <br> Insert a command, eg. `Dream > Astronaut floating in a distant galaxy`, or type `-h` for help.\n", "3. After launching dream bot, you'll see: <br> `Dream > ` in terminal. <br> Insert a command, eg. `Dream > Astronaut floating in a distant galaxy`, or type `-h` for help.\n",
"3. After completion you'll see your generated images in path `stable-diffusion/outputs/img-samples/`, you can also show last generated images in cell #10.\n", "3. After completion you'll see your generated images in path `stable-diffusion/outputs/img-samples/`, you can also show last generated images in cell #10.\n",
"4. To quit Dream bot use `q` command. <br> \n", "4. To quit Dream bot use `q` command. <br> \n",
"---\n", "---\n",
"<font color=\"red\">Note:</font> It takes some time to load, but after installing all dependencies you can use the bot all time you want while colab instance is up. <br>\n", "<font color=\"red\">Note:</font> It takes some time to load, but after installing all dependencies you can use the bot all time you want while colab instance is up. <br>\n",
"<font color=\"red\">Requirements:</font> For this notebook to work you need to have [Stable-Diffusion-v-1-4](https://huggingface.co/CompVis/stable-diffusion-v-1-4-original) stored in your Google Drive, it will be needed in cell #7\n", "<font color=\"red\">Requirements:</font> For this notebook to work you need to have [Stable-Diffusion-v-1-4](https://huggingface.co/CompVis/stable-diffusion-v-1-4-original) stored in your Google Drive, it will be needed in cell #7\n",
"##### For more details visit Github repository: [invoke-ai/InvokeAI](https://github.com/invoke-ai/InvokeAI)\n", "##### For more details visit Github repository: [invoke-ai/InvokeAI](https://github.com/invoke-ai/InvokeAI)\n",
"---\n" "---\n"
] ]
},
{
"cell_type": "markdown",
"metadata": {
"id": "dr32VLxlnouf"
},
"source": [
"## ◢ Installation"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"cellView": "form",
"id": "a2Z5Qu_o8VtQ"
},
"outputs": [],
"source": [
"#@title 1. Check current GPU assigned\n",
"!nvidia-smi -L\n",
"!nvidia-smi"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"cellView": "form",
"id": "vbI9ZsQHzjqF"
},
"outputs": [],
"source": [
"#@title 2. Download stable-diffusion Repository\n",
"from os.path import exists\n",
"\n",
"!git clone --quiet https://github.com/invoke-ai/InvokeAI.git # Original repo\n",
"%cd /content/InvokeAI/\n",
"!git checkout --quiet tags/v2.0.0"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"cellView": "form",
"id": "QbXcGXYEFSNB"
},
"outputs": [],
"source": [
"#@title 3. Install dependencies\n",
"import gc\n",
"\n",
"!wget https://raw.githubusercontent.com/invoke-ai/InvokeAI/development/environments-and-requirements/requirements-base.txt\n",
"!wget https://raw.githubusercontent.com/invoke-ai/InvokeAI/development/environments-and-requirements/requirements-win-colab-cuda.txt\n",
"!pip install colab-xterm\n",
"!pip install -r requirements-lin-win-colab-CUDA.txt\n",
"!pip install clean-fid torchtext\n",
"!pip install transformers\n",
"gc.collect()"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"cellView": "form",
"id": "8rSMhgnAttQa"
},
"outputs": [],
"source": [
"#@title 4. Restart Runtime\n",
"exit()"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"cellView": "form",
"id": "ChIDWxLVHGGJ"
},
"outputs": [],
"source": [
"#@title 5. Load small ML models required\n",
"import gc\n",
"%cd /content/InvokeAI/\n",
"!python scripts/preload_models.py\n",
"gc.collect()"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "795x1tMoo8b1"
},
"source": [
"## ◢ Configuration"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"cellView": "form",
"id": "YEWPV-sF1RDM"
},
"outputs": [],
"source": [
"#@title 6. Mount google Drive\n",
"from google.colab import drive\n",
"drive.mount('/content/drive')"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"cellView": "form",
"id": "zRTJeZ461WGu"
},
"outputs": [],
"source": [
"#@title 7. Drive Path to model\n",
"#@markdown Path should start with /content/drive/path-to-your-file <br>\n",
"#@markdown <font color=\"red\">Note:</font> Model should be downloaded from https://huggingface.co <br>\n",
"#@markdown Lastest release: [Stable-Diffusion-v-1-4](https://huggingface.co/CompVis/stable-diffusion-v-1-4-original)\n",
"from os.path import exists\n",
"\n",
"model_path = \"\" #@param {type:\"string\"}\n",
"if exists(model_path):\n",
" print(\"✅ Valid directory\")\n",
"else: \n",
" print(\"❌ File doesn't exist\")"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"cellView": "form",
"id": "UY-NNz4I8_aG"
},
"outputs": [],
"source": [
"#@title 8. Symlink to model\n",
"\n",
"from os.path import exists\n",
"import os \n",
"\n",
"# Folder creation if it doesn't exist\n",
"if exists(\"/content/InvokeAI/models/ldm/stable-diffusion-v1\"):\n",
" print(\"❗ Dir stable-diffusion-v1 already exists\")\n",
"else:\n",
" %mkdir /content/InvokeAI/models/ldm/stable-diffusion-v1\n",
" print(\"✅ Dir stable-diffusion-v1 created\")\n",
"\n",
"# Symbolic link if it doesn't exist\n",
"if exists(\"/content/InvokeAI/models/ldm/stable-diffusion-v1/model.ckpt\"):\n",
" print(\"❗ Symlink already created\")\n",
"else: \n",
" src = model_path\n",
" dst = '/content/InvokeAI/models/ldm/stable-diffusion-v1/model.ckpt'\n",
" os.symlink(src, dst) \n",
" print(\"✅ Symbolic link created successfully\")"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "Mc28N0_NrCQH"
},
"source": [
"## ◢ Execution"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"cellView": "form",
"id": "ir4hCrMIuUpl"
},
"outputs": [],
"source": [
"#@title 9. Run Terminal and Execute Dream bot\n",
"#@markdown <font color=\"blue\">Steps:</font> <br>\n",
"#@markdown 1. Execute command `python scripts/invoke.py` to run InvokeAI.<br>\n",
"#@markdown 2. After initialized you'll see `Dream>` line.<br>\n",
"#@markdown 3. Example text: `Astronaut floating in a distant galaxy` <br>\n",
"#@markdown 4. To quit Dream bot use: `q` command.<br>\n",
"\n",
"%load_ext colabxterm\n",
"%xterm\n",
"gc.collect()"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"cellView": "form",
"id": "qnLohSHmKoGk"
},
"outputs": [],
"source": [
"#@title 10. Show the last 15 generated images\n",
"import glob\n",
"import matplotlib.pyplot as plt\n",
"import matplotlib.image as mpimg\n",
"%matplotlib inline\n",
"\n",
"images = []\n",
"for img_path in sorted(glob.glob('/content/InvokeAI/outputs/img-samples/*.png'), reverse=True):\n",
" images.append(mpimg.imread(img_path))\n",
"\n",
"images = images[:15] \n",
"\n",
"plt.figure(figsize=(20,10))\n",
"\n",
"columns = 5\n",
"for i, image in enumerate(images):\n",
" ax = plt.subplot(len(images) / columns + 1, columns, i + 1)\n",
" ax.axes.xaxis.set_visible(False)\n",
" ax.axes.yaxis.set_visible(False)\n",
" ax.axis('off')\n",
" plt.imshow(image)\n",
" gc.collect()\n",
"\n"
]
}
],
"metadata": {
"accelerator": "GPU",
"colab": {
"collapsed_sections": [],
"private_outputs": true,
"provenance": []
},
"gpuClass": "standard",
"kernelspec": {
"display_name": "Python 3.9.12 64-bit",
"language": "python",
"name": "python3"
},
"language_info": {
"name": "python",
"version": "3.9.12"
},
"vscode": {
"interpreter": {
"hash": "4e870c5c5fe42db7e2c5647ae5af656ff3391bf8c2b729cbf7fa0e16ca8cb5af"
}
}
}, },
"nbformat": 4, {
"nbformat_minor": 0 "cell_type": "markdown",
"metadata": {
"id": "dr32VLxlnouf"
},
"source": [
"## ◢ Installation"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"cellView": "form",
"id": "a2Z5Qu_o8VtQ"
},
"outputs": [],
"source": [
"# @title 1. Check current GPU assigned\n",
"!nvidia-smi -L\n",
"!nvidia-smi"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"cellView": "form",
"id": "vbI9ZsQHzjqF"
},
"outputs": [],
"source": [
"# @title 2. Download stable-diffusion Repository\n",
"from os.path import exists\n",
"\n",
"!git clone --quiet https://github.com/invoke-ai/InvokeAI.git # Original repo\n",
"%cd /content/InvokeAI/\n",
"!git checkout --quiet tags/v2.0.0"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"cellView": "form",
"id": "QbXcGXYEFSNB"
},
"outputs": [],
"source": [
"# @title 3. Install dependencies\n",
"import gc\n",
"\n",
"!wget https://raw.githubusercontent.com/invoke-ai/InvokeAI/development/environments-and-requirements/requirements-base.txt\n",
"!wget https://raw.githubusercontent.com/invoke-ai/InvokeAI/development/environments-and-requirements/requirements-win-colab-cuda.txt\n",
"!pip install colab-xterm\n",
"!pip install -r requirements-lin-win-colab-CUDA.txt\n",
"!pip install clean-fid torchtext\n",
"!pip install transformers\n",
"gc.collect()"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"cellView": "form",
"id": "8rSMhgnAttQa"
},
"outputs": [],
"source": [
"# @title 4. Restart Runtime\n",
"exit()"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"cellView": "form",
"id": "ChIDWxLVHGGJ"
},
"outputs": [],
"source": [
"# @title 5. Load small ML models required\n",
"import gc\n",
"\n",
"%cd /content/InvokeAI/\n",
"!python scripts/preload_models.py\n",
"gc.collect()"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "795x1tMoo8b1"
},
"source": [
"## ◢ Configuration"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"cellView": "form",
"id": "YEWPV-sF1RDM"
},
"outputs": [],
"source": [
"# @title 6. Mount google Drive\n",
"from google.colab import drive\n",
"\n",
"drive.mount(\"/content/drive\")"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"cellView": "form",
"id": "zRTJeZ461WGu"
},
"outputs": [],
"source": [
"# @title 7. Drive Path to model\n",
"# @markdown Path should start with /content/drive/path-to-your-file <br>\n",
"# @markdown <font color=\"red\">Note:</font> Model should be downloaded from https://huggingface.co <br>\n",
"# @markdown Lastest release: [Stable-Diffusion-v-1-4](https://huggingface.co/CompVis/stable-diffusion-v-1-4-original)\n",
"from os.path import exists\n",
"\n",
"model_path = \"\" # @param {type:\"string\"}\n",
"if exists(model_path):\n",
" print(\"✅ Valid directory\")\n",
"else:\n",
" print(\"❌ File doesn't exist\")"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"cellView": "form",
"id": "UY-NNz4I8_aG"
},
"outputs": [],
"source": [
"# @title 8. Symlink to model\n",
"\n",
"from os.path import exists\n",
"import os\n",
"\n",
"# Folder creation if it doesn't exist\n",
"if exists(\"/content/InvokeAI/models/ldm/stable-diffusion-v1\"):\n",
" print(\"❗ Dir stable-diffusion-v1 already exists\")\n",
"else:\n",
" %mkdir /content/InvokeAI/models/ldm/stable-diffusion-v1\n",
" print(\"✅ Dir stable-diffusion-v1 created\")\n",
"\n",
"# Symbolic link if it doesn't exist\n",
"if exists(\"/content/InvokeAI/models/ldm/stable-diffusion-v1/model.ckpt\"):\n",
" print(\"❗ Symlink already created\")\n",
"else:\n",
" src = model_path\n",
" dst = \"/content/InvokeAI/models/ldm/stable-diffusion-v1/model.ckpt\"\n",
" os.symlink(src, dst)\n",
" print(\"✅ Symbolic link created successfully\")"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "Mc28N0_NrCQH"
},
"source": [
"## ◢ Execution"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"cellView": "form",
"id": "ir4hCrMIuUpl"
},
"outputs": [],
"source": [
"# @title 9. Run Terminal and Execute Dream bot\n",
"# @markdown <font color=\"blue\">Steps:</font> <br>\n",
"# @markdown 1. Execute command `python scripts/invoke.py` to run InvokeAI.<br>\n",
"# @markdown 2. After initialized you'll see `Dream>` line.<br>\n",
"# @markdown 3. Example text: `Astronaut floating in a distant galaxy` <br>\n",
"# @markdown 4. To quit Dream bot use: `q` command.<br>\n",
"\n",
"%load_ext colabxterm\n",
"%xterm\n",
"gc.collect()"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"cellView": "form",
"id": "qnLohSHmKoGk"
},
"outputs": [],
"source": [
"#@title 10. Show the last 15 generated images\n",
"import glob\n",
"import matplotlib.pyplot as plt\n",
"import matplotlib.image as mpimg\n",
"%matplotlib inline\n",
"\n",
"images = []\n",
"for img_path in sorted(glob.glob('/content/InvokeAI/outputs/img-samples/*.png'), reverse=True):\n",
" images.append(mpimg.imread(img_path))\n",
"\n",
"images = images[:15] \n",
"\n",
"plt.figure(figsize=(20,10))\n",
"\n",
"columns = 5\n",
"for i, image in enumerate(images):\n",
" ax = plt.subplot(len(images) / columns + 1, columns, i + 1)\n",
" ax.axes.xaxis.set_visible(False)\n",
" ax.axes.yaxis.set_visible(False)\n",
" ax.axis('off')\n",
" plt.imshow(image)\n",
" gc.collect()\n",
"\n"
]
}
],
"metadata": {
"accelerator": "GPU",
"colab": {
"collapsed_sections": [],
"private_outputs": true,
"provenance": []
},
"gpuClass": "standard",
"kernelspec": {
"display_name": "Python 3.9.12 64-bit",
"language": "python",
"name": "python3"
},
"language_info": {
"name": "python",
"version": "3.9.12"
},
"vscode": {
"interpreter": {
"hash": "4e870c5c5fe42db7e2c5647ae5af656ff3391bf8c2b729cbf7fa0e16ca8cb5af"
}
}
},
"nbformat": 4,
"nbformat_minor": 0
} }

View File

@ -52,17 +52,17 @@
"name": "stdout", "name": "stdout",
"text": [ "text": [
"Cloning into 'latent-diffusion'...\n", "Cloning into 'latent-diffusion'...\n",
"remote: Enumerating objects: 992, done.\u001B[K\n", "remote: Enumerating objects: 992, done.\u001b[K\n",
"remote: Counting objects: 100% (695/695), done.\u001B[K\n", "remote: Counting objects: 100% (695/695), done.\u001b[K\n",
"remote: Compressing objects: 100% (397/397), done.\u001B[K\n", "remote: Compressing objects: 100% (397/397), done.\u001b[K\n",
"remote: Total 992 (delta 375), reused 564 (delta 253), pack-reused 297\u001B[K\n", "remote: Total 992 (delta 375), reused 564 (delta 253), pack-reused 297\u001b[K\n",
"Receiving objects: 100% (992/992), 30.78 MiB | 29.43 MiB/s, done.\n", "Receiving objects: 100% (992/992), 30.78 MiB | 29.43 MiB/s, done.\n",
"Resolving deltas: 100% (510/510), done.\n", "Resolving deltas: 100% (510/510), done.\n",
"Cloning into 'taming-transformers'...\n", "Cloning into 'taming-transformers'...\n",
"remote: Enumerating objects: 1335, done.\u001B[K\n", "remote: Enumerating objects: 1335, done.\u001b[K\n",
"remote: Counting objects: 100% (525/525), done.\u001B[K\n", "remote: Counting objects: 100% (525/525), done.\u001b[K\n",
"remote: Compressing objects: 100% (493/493), done.\u001B[K\n", "remote: Compressing objects: 100% (493/493), done.\u001b[K\n",
"remote: Total 1335 (delta 58), reused 481 (delta 30), pack-reused 810\u001B[K\n", "remote: Total 1335 (delta 58), reused 481 (delta 30), pack-reused 810\u001b[K\n",
"Receiving objects: 100% (1335/1335), 412.35 MiB | 30.53 MiB/s, done.\n", "Receiving objects: 100% (1335/1335), 412.35 MiB | 30.53 MiB/s, done.\n",
"Resolving deltas: 100% (267/267), done.\n", "Resolving deltas: 100% (267/267), done.\n",
"Obtaining file:///content/taming-transformers\n", "Obtaining file:///content/taming-transformers\n",
@ -73,23 +73,24 @@
"Installing collected packages: taming-transformers\n", "Installing collected packages: taming-transformers\n",
" Running setup.py develop for taming-transformers\n", " Running setup.py develop for taming-transformers\n",
"Successfully installed taming-transformers-0.0.1\n", "Successfully installed taming-transformers-0.0.1\n",
"\u001B[31mERROR: pip's dependency resolver does not currently take into account all the packages that are installed. This behaviour is the source of the following dependency conflicts.\n", "\u001b[31mERROR: pip's dependency resolver does not currently take into account all the packages that are installed. This behaviour is the source of the following dependency conflicts.\n",
"tensorflow 2.8.0 requires tf-estimator-nightly==2.8.0.dev2021122109, which is not installed.\n", "tensorflow 2.8.0 requires tf-estimator-nightly==2.8.0.dev2021122109, which is not installed.\n",
"arviz 0.11.4 requires typing-extensions<4,>=3.7.4.3, but you have typing-extensions 4.1.1 which is incompatible.\u001B[0m\n" "arviz 0.11.4 requires typing-extensions<4,>=3.7.4.3, but you have typing-extensions 4.1.1 which is incompatible.\u001b[0m\n"
] ]
} }
], ],
"source": [ "source": [
"#@title Installation\n", "# @title Installation\n",
"!git clone https://github.com/CompVis/latent-diffusion.git\n", "!git clone https://github.com/CompVis/latent-diffusion.git\n",
"!git clone https://github.com/CompVis/taming-transformers\n", "!git clone https://github.com/CompVis/taming-transformers\n",
"!pip install -e ./taming-transformers\n", "!pip install -e ./taming-transformers\n",
"!pip install omegaconf>=2.0.0 pytorch-lightning>=1.0.8 torch-fidelity einops\n", "!pip install omegaconf>=2.0.0 pytorch-lightning>=1.0.8 torch-fidelity einops\n",
"\n", "\n",
"import sys\n", "import sys\n",
"\n",
"sys.path.append(\".\")\n", "sys.path.append(\".\")\n",
"sys.path.append('./taming-transformers')\n", "sys.path.append(\"./taming-transformers\")\n",
"from taming.models import vqgan " "from taming.models import vqgan"
] ]
}, },
{ {
@ -104,11 +105,11 @@
{ {
"cell_type": "code", "cell_type": "code",
"source": [ "source": [
"#@title Download\n", "# @title Download\n",
"%cd latent-diffusion/ \n", "%cd latent-diffusion/\n",
"\n", "\n",
"!mkdir -p models/ldm/cin256-v2/\n", "!mkdir -p models/ldm/cin256-v2/\n",
"!wget -O models/ldm/cin256-v2/model.ckpt https://ommer-lab.com/files/latent-diffusion/nitro/cin/model.ckpt " "!wget -O models/ldm/cin256-v2/model.ckpt https://ommer-lab.com/files/latent-diffusion/nitro/cin/model.ckpt"
], ],
"metadata": { "metadata": {
"colab": { "colab": {
@ -203,7 +204,7 @@
{ {
"cell_type": "code", "cell_type": "code",
"source": [ "source": [
"#@title loading utils\n", "# @title loading utils\n",
"import torch\n", "import torch\n",
"from omegaconf import OmegaConf\n", "from omegaconf import OmegaConf\n",
"\n", "\n",
@ -212,7 +213,7 @@
"\n", "\n",
"def load_model_from_config(config, ckpt):\n", "def load_model_from_config(config, ckpt):\n",
" print(f\"Loading model from {ckpt}\")\n", " print(f\"Loading model from {ckpt}\")\n",
" pl_sd = torch.load(ckpt)#, map_location=\"cpu\")\n", " pl_sd = torch.load(ckpt) # , map_location=\"cpu\")\n",
" sd = pl_sd[\"state_dict\"]\n", " sd = pl_sd[\"state_dict\"]\n",
" model = instantiate_from_config(config.model)\n", " model = instantiate_from_config(config.model)\n",
" m, u = model.load_state_dict(sd, strict=False)\n", " m, u = model.load_state_dict(sd, strict=False)\n",
@ -222,7 +223,7 @@
"\n", "\n",
"\n", "\n",
"def get_model():\n", "def get_model():\n",
" config = OmegaConf.load(\"configs/latent-diffusion/cin256-v2.yaml\") \n", " config = OmegaConf.load(\"configs/latent-diffusion/cin256-v2.yaml\")\n",
" model = load_model_from_config(config, \"models/ldm/cin256-v2/model.ckpt\")\n", " model = load_model_from_config(config, \"models/ldm/cin256-v2/model.ckpt\")\n",
" return model" " return model"
], ],
@ -276,18 +277,18 @@
{ {
"cell_type": "code", "cell_type": "code",
"source": [ "source": [
"import numpy as np \n", "import numpy as np\n",
"from PIL import Image\n", "from PIL import Image\n",
"from einops import rearrange\n", "from einops import rearrange\n",
"from torchvision.utils import make_grid\n", "from torchvision.utils import make_grid\n",
"\n", "\n",
"\n", "\n",
"classes = [25, 187, 448, 992] # define classes to be sampled here\n", "classes = [25, 187, 448, 992] # define classes to be sampled here\n",
"n_samples_per_class = 6\n", "n_samples_per_class = 6\n",
"\n", "\n",
"ddim_steps = 20\n", "ddim_steps = 20\n",
"ddim_eta = 0.0\n", "ddim_eta = 0.0\n",
"scale = 3.0 # for unconditional guidance\n", "scale = 3.0 # for unconditional guidance\n",
"\n", "\n",
"\n", "\n",
"all_samples = list()\n", "all_samples = list()\n",
@ -295,36 +296,39 @@
"with torch.no_grad():\n", "with torch.no_grad():\n",
" with model.ema_scope():\n", " with model.ema_scope():\n",
" uc = model.get_learned_conditioning(\n", " uc = model.get_learned_conditioning(\n",
" {model.cond_stage_key: torch.tensor(n_samples_per_class*[1000]).to(model.device)}\n", " {model.cond_stage_key: torch.tensor(n_samples_per_class * [1000]).to(model.device)}\n",
" )\n", " )\n",
" \n", "\n",
" for class_label in classes:\n", " for class_label in classes:\n",
" print(f\"rendering {n_samples_per_class} examples of class '{class_label}' in {ddim_steps} steps and using s={scale:.2f}.\")\n", " print(\n",
" xc = torch.tensor(n_samples_per_class*[class_label])\n", " f\"rendering {n_samples_per_class} examples of class '{class_label}' in {ddim_steps} steps and using s={scale:.2f}.\"\n",
" )\n",
" xc = torch.tensor(n_samples_per_class * [class_label])\n",
" c = model.get_learned_conditioning({model.cond_stage_key: xc.to(model.device)})\n", " c = model.get_learned_conditioning({model.cond_stage_key: xc.to(model.device)})\n",
" \n", "\n",
" samples_ddim, _ = sampler.sample(S=ddim_steps,\n", " samples_ddim, _ = sampler.sample(\n",
" conditioning=c,\n", " S=ddim_steps,\n",
" batch_size=n_samples_per_class,\n", " conditioning=c,\n",
" shape=[3, 64, 64],\n", " batch_size=n_samples_per_class,\n",
" verbose=False,\n", " shape=[3, 64, 64],\n",
" unconditional_guidance_scale=scale,\n", " verbose=False,\n",
" unconditional_conditioning=uc, \n", " unconditional_guidance_scale=scale,\n",
" eta=ddim_eta)\n", " unconditional_conditioning=uc,\n",
" eta=ddim_eta,\n",
" )\n",
"\n", "\n",
" x_samples_ddim = model.decode_first_stage(samples_ddim)\n", " x_samples_ddim = model.decode_first_stage(samples_ddim)\n",
" x_samples_ddim = torch.clamp((x_samples_ddim+1.0)/2.0, \n", " x_samples_ddim = torch.clamp((x_samples_ddim + 1.0) / 2.0, min=0.0, max=1.0)\n",
" min=0.0, max=1.0)\n",
" all_samples.append(x_samples_ddim)\n", " all_samples.append(x_samples_ddim)\n",
"\n", "\n",
"\n", "\n",
"# display as grid\n", "# display as grid\n",
"grid = torch.stack(all_samples, 0)\n", "grid = torch.stack(all_samples, 0)\n",
"grid = rearrange(grid, 'n b c h w -> (n b) c h w')\n", "grid = rearrange(grid, \"n b c h w -> (n b) c h w\")\n",
"grid = make_grid(grid, nrow=n_samples_per_class)\n", "grid = make_grid(grid, nrow=n_samples_per_class)\n",
"\n", "\n",
"# to image\n", "# to image\n",
"grid = 255. * rearrange(grid, 'c h w -> h w c').cpu().numpy()\n", "grid = 255.0 * rearrange(grid, \"c h w -> h w c\").cpu().numpy()\n",
"Image.fromarray(grid.astype(np.uint8))" "Image.fromarray(grid.astype(np.uint8))"
], ],
"metadata": { "metadata": {