resolved conflicts with main
@ -20,13 +20,13 @@ def calc_images_mean_L1(image1_path, image2_path):
|
||||
|
||||
def parse_args():
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument('image1_path')
|
||||
parser.add_argument('image2_path')
|
||||
parser.add_argument("image1_path")
|
||||
parser.add_argument("image2_path")
|
||||
args = parser.parse_args()
|
||||
return args
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
if __name__ == "__main__":
|
||||
args = parse_args()
|
||||
mean_L1 = calc_images_mean_L1(args.image1_path, args.image2_path)
|
||||
print(mean_L1)
|
||||
|
@ -1 +1,2 @@
|
||||
b3dccfaeb636599c02effc377cdd8a87d658256c
|
||||
218b6d0546b990fc449c876fb99f44b50c4daa35
|
||||
|
9
.github/workflows/close-inactive-issues.yml
vendored
@ -1,11 +1,11 @@
|
||||
name: Close inactive issues
|
||||
on:
|
||||
schedule:
|
||||
- cron: "00 6 * * *"
|
||||
- cron: "00 4 * * *"
|
||||
|
||||
env:
|
||||
DAYS_BEFORE_ISSUE_STALE: 14
|
||||
DAYS_BEFORE_ISSUE_CLOSE: 28
|
||||
DAYS_BEFORE_ISSUE_STALE: 30
|
||||
DAYS_BEFORE_ISSUE_CLOSE: 14
|
||||
|
||||
jobs:
|
||||
close-issues:
|
||||
@ -14,7 +14,7 @@ jobs:
|
||||
issues: write
|
||||
pull-requests: write
|
||||
steps:
|
||||
- uses: actions/stale@v5
|
||||
- uses: actions/stale@v8
|
||||
with:
|
||||
days-before-issue-stale: ${{ env.DAYS_BEFORE_ISSUE_STALE }}
|
||||
days-before-issue-close: ${{ env.DAYS_BEFORE_ISSUE_CLOSE }}
|
||||
@ -23,5 +23,6 @@ jobs:
|
||||
close-issue-message: "Due to inactivity, this issue was automatically closed. If you are still experiencing the issue, please recreate the issue."
|
||||
days-before-pr-stale: -1
|
||||
days-before-pr-close: -1
|
||||
exempt-issue-labels: "Active Issue"
|
||||
repo-token: ${{ secrets.GITHUB_TOKEN }}
|
||||
operations-per-run: 500
|
||||
|
27
.github/workflows/style-checks.yml
vendored
Normal file
@ -0,0 +1,27 @@
|
||||
name: Black # TODO: add isort and flake8 later
|
||||
|
||||
on:
|
||||
pull_request: {}
|
||||
push:
|
||||
branches: master
|
||||
tags: "*"
|
||||
|
||||
jobs:
|
||||
test:
|
||||
runs-on: ubuntu-latest
|
||||
steps:
|
||||
- uses: actions/checkout@v3
|
||||
|
||||
- name: Setup Python
|
||||
uses: actions/setup-python@v4
|
||||
with:
|
||||
python-version: '3.10'
|
||||
|
||||
- name: Install dependencies with pip
|
||||
run: |
|
||||
pip install --upgrade pip wheel
|
||||
pip install .[test]
|
||||
|
||||
# - run: isort --check-only .
|
||||
- run: black --check .
|
||||
# - run: flake8
|
1
.gitignore
vendored
@ -38,7 +38,6 @@ develop-eggs/
|
||||
downloads/
|
||||
eggs/
|
||||
.eggs/
|
||||
lib/
|
||||
lib64/
|
||||
parts/
|
||||
sdist/
|
||||
|
10
.pre-commit-config.yaml
Normal file
@ -0,0 +1,10 @@
|
||||
# See https://pre-commit.com/ for usage and config
|
||||
repos:
|
||||
- repo: local
|
||||
hooks:
|
||||
- id: black
|
||||
name: black
|
||||
stages: [commit]
|
||||
language: system
|
||||
entry: black
|
||||
types: [python]
|
290
LICENSE-SDXL.txt
Normal file
@ -0,0 +1,290 @@
|
||||
Copyright (c) 2023 Stability AI
|
||||
CreativeML Open RAIL++-M License dated July 26, 2023
|
||||
|
||||
Section I: PREAMBLE
|
||||
|
||||
Multimodal generative models are being widely adopted and used, and
|
||||
have the potential to transform the way artists, among other
|
||||
individuals, conceive and benefit from AI or ML technologies as a tool
|
||||
for content creation.
|
||||
|
||||
Notwithstanding the current and potential benefits that these
|
||||
artifacts can bring to society at large, there are also concerns about
|
||||
potential misuses of them, either due to their technical limitations
|
||||
or ethical considerations.
|
||||
|
||||
In short, this license strives for both the open and responsible
|
||||
downstream use of the accompanying model. When it comes to the open
|
||||
character, we took inspiration from open source permissive licenses
|
||||
regarding the grant of IP rights. Referring to the downstream
|
||||
responsible use, we added use-based restrictions not permitting the
|
||||
use of the model in very specific scenarios, in order for the licensor
|
||||
to be able to enforce the license in case potential misuses of the
|
||||
Model may occur. At the same time, we strive to promote open and
|
||||
responsible research on generative models for art and content
|
||||
generation.
|
||||
|
||||
Even though downstream derivative versions of the model could be
|
||||
released under different licensing terms, the latter will always have
|
||||
to include - at minimum - the same use-based restrictions as the ones
|
||||
in the original license (this license). We believe in the intersection
|
||||
between open and responsible AI development; thus, this agreement aims
|
||||
to strike a balance between both in order to enable responsible
|
||||
open-science in the field of AI.
|
||||
|
||||
This CreativeML Open RAIL++-M License governs the use of the model
|
||||
(and its derivatives) and is informed by the model card associated
|
||||
with the model.
|
||||
|
||||
NOW THEREFORE, You and Licensor agree as follows:
|
||||
|
||||
Definitions
|
||||
|
||||
"License" means the terms and conditions for use, reproduction, and
|
||||
Distribution as defined in this document.
|
||||
|
||||
"Data" means a collection of information and/or content extracted from
|
||||
the dataset used with the Model, including to train, pretrain, or
|
||||
otherwise evaluate the Model. The Data is not licensed under this
|
||||
License.
|
||||
|
||||
"Output" means the results of operating a Model as embodied in
|
||||
informational content resulting therefrom.
|
||||
|
||||
"Model" means any accompanying machine-learning based assemblies
|
||||
(including checkpoints), consisting of learnt weights, parameters
|
||||
(including optimizer states), corresponding to the model architecture
|
||||
as embodied in the Complementary Material, that have been trained or
|
||||
tuned, in whole or in part on the Data, using the Complementary
|
||||
Material.
|
||||
|
||||
"Derivatives of the Model" means all modifications to the Model, works
|
||||
based on the Model, or any other model which is created or initialized
|
||||
by transfer of patterns of the weights, parameters, activations or
|
||||
output of the Model, to the other model, in order to cause the other
|
||||
model to perform similarly to the Model, including - but not limited
|
||||
to - distillation methods entailing the use of intermediate data
|
||||
representations or methods based on the generation of synthetic data
|
||||
by the Model for training the other model.
|
||||
|
||||
"Complementary Material" means the accompanying source code and
|
||||
scripts used to define, run, load, benchmark or evaluate the Model,
|
||||
and used to prepare data for training or evaluation, if any. This
|
||||
includes any accompanying documentation, tutorials, examples, etc, if
|
||||
any.
|
||||
|
||||
"Distribution" means any transmission, reproduction, publication or
|
||||
other sharing of the Model or Derivatives of the Model to a third
|
||||
party, including providing the Model as a hosted service made
|
||||
available by electronic or other remote means - e.g. API-based or web
|
||||
access.
|
||||
|
||||
"Licensor" means the copyright owner or entity authorized by the
|
||||
copyright owner that is granting the License, including the persons or
|
||||
entities that may have rights in the Model and/or distributing the
|
||||
Model.
|
||||
|
||||
"You" (or "Your") means an individual or Legal Entity exercising
|
||||
permissions granted by this License and/or making use of the Model for
|
||||
whichever purpose and in any field of use, including usage of the
|
||||
Model in an end-use application - e.g. chatbot, translator, image
|
||||
generator.
|
||||
|
||||
"Third Parties" means individuals or legal entities that are not under
|
||||
common control with Licensor or You.
|
||||
|
||||
"Contribution" means any work of authorship, including the original
|
||||
version of the Model and any modifications or additions to that Model
|
||||
or Derivatives of the Model thereof, that is intentionally submitted
|
||||
to Licensor for inclusion in the Model by the copyright owner or by an
|
||||
individual or Legal Entity authorized to submit on behalf of the
|
||||
copyright owner. For the purposes of this definition, "submitted"
|
||||
means any form of electronic, verbal, or written communication sent to
|
||||
the Licensor or its representatives, including but not limited to
|
||||
communication on electronic mailing lists, source code control
|
||||
systems, and issue tracking systems that are managed by, or on behalf
|
||||
of, the Licensor for the purpose of discussing and improving the
|
||||
Model, but excluding communication that is conspicuously marked or
|
||||
otherwise designated in writing by the copyright owner as "Not a
|
||||
Contribution."
|
||||
|
||||
"Contributor" means Licensor and any individual or Legal Entity on
|
||||
behalf of whom a Contribution has been received by Licensor and
|
||||
subsequently incorporated within the Model.
|
||||
|
||||
Section II: INTELLECTUAL PROPERTY RIGHTS
|
||||
|
||||
Both copyright and patent grants apply to the Model, Derivatives of
|
||||
the Model and Complementary Material. The Model and Derivatives of the
|
||||
Model are subject to additional terms as described in
|
||||
|
||||
Section III.
|
||||
|
||||
Grant of Copyright License. Subject to the terms and conditions of
|
||||
this License, each Contributor hereby grants to You a perpetual,
|
||||
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
|
||||
copyright license to reproduce, prepare, publicly display, publicly
|
||||
perform, sublicense, and distribute the Complementary Material, the
|
||||
Model, and Derivatives of the Model.
|
||||
|
||||
Grant of Patent License. Subject to the terms and conditions of this
|
||||
License and where and as applicable, each Contributor hereby grants to
|
||||
You a perpetual, worldwide, non-exclusive, no-charge, royalty-free,
|
||||
irrevocable (except as stated in this paragraph) patent license to
|
||||
make, have made, use, offer to sell, sell, import, and otherwise
|
||||
transfer the Model and the Complementary Material, where such license
|
||||
applies only to those patent claims licensable by such Contributor
|
||||
that are necessarily infringed by their Contribution(s) alone or by
|
||||
combination of their Contribution(s) with the Model to which such
|
||||
Contribution(s) was submitted. If You institute patent litigation
|
||||
against any entity (including a cross-claim or counterclaim in a
|
||||
lawsuit) alleging that the Model and/or Complementary Material or a
|
||||
Contribution incorporated within the Model and/or Complementary
|
||||
Material constitutes direct or contributory patent infringement, then
|
||||
any patent licenses granted to You under this License for the Model
|
||||
and/or Work shall terminate as of the date such litigation is asserted
|
||||
or filed.
|
||||
|
||||
Section III: CONDITIONS OF USAGE, DISTRIBUTION AND REDISTRIBUTION
|
||||
|
||||
Distribution and Redistribution. You may host for Third Party remote
|
||||
access purposes (e.g. software-as-a-service), reproduce and distribute
|
||||
copies of the Model or Derivatives of the Model thereof in any medium,
|
||||
with or without modifications, provided that You meet the following
|
||||
conditions: Use-based restrictions as referenced in paragraph 5 MUST
|
||||
be included as an enforceable provision by You in any type of legal
|
||||
agreement (e.g. a license) governing the use and/or distribution of
|
||||
the Model or Derivatives of the Model, and You shall give notice to
|
||||
subsequent users You Distribute to, that the Model or Derivatives of
|
||||
the Model are subject to paragraph 5. This provision does not apply to
|
||||
the use of Complementary Material. You must give any Third Party
|
||||
recipients of the Model or Derivatives of the Model a copy of this
|
||||
License; You must cause any modified files to carry prominent notices
|
||||
stating that You changed the files; You must retain all copyright,
|
||||
patent, trademark, and attribution notices excluding those notices
|
||||
that do not pertain to any part of the Model, Derivatives of the
|
||||
Model. You may add Your own copyright statement to Your modifications
|
||||
and may provide additional or different license terms and conditions -
|
||||
respecting paragraph 4.a. - for use, reproduction, or Distribution of
|
||||
Your modifications, or for any such Derivatives of the Model as a
|
||||
whole, provided Your use, reproduction, and Distribution of the Model
|
||||
otherwise complies with the conditions stated in this License.
|
||||
|
||||
Use-based restrictions. The restrictions set forth in Attachment A are
|
||||
considered Use-based restrictions. Therefore You cannot use the Model
|
||||
and the Derivatives of the Model for the specified restricted
|
||||
uses. You may use the Model subject to this License, including only
|
||||
for lawful purposes and in accordance with the License. Use may
|
||||
include creating any content with, finetuning, updating, running,
|
||||
training, evaluating and/or reparametrizing the Model. You shall
|
||||
require all of Your users who use the Model or a Derivative of the
|
||||
Model to comply with the terms of this paragraph (paragraph 5).
|
||||
|
||||
The Output You Generate. Except as set forth herein, Licensor claims
|
||||
no rights in the Output You generate using the Model. You are
|
||||
accountable for the Output you generate and its subsequent uses. No
|
||||
use of the output can contravene any provision as stated in the
|
||||
License.
|
||||
|
||||
Section IV: OTHER PROVISIONS
|
||||
|
||||
Updates and Runtime Restrictions. To the maximum extent permitted by
|
||||
law, Licensor reserves the right to restrict (remotely or otherwise)
|
||||
usage of the Model in violation of this License.
|
||||
|
||||
Trademarks and related. Nothing in this License permits You to make
|
||||
use of Licensors’ trademarks, trade names, logos or to otherwise
|
||||
suggest endorsement or misrepresent the relationship between the
|
||||
parties; and any rights not expressly granted herein are reserved by
|
||||
the Licensors.
|
||||
|
||||
Disclaimer of Warranty. Unless required by applicable law or agreed to
|
||||
in writing, Licensor provides the Model and the Complementary Material
|
||||
(and each Contributor provides its Contributions) on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
|
||||
implied, including, without limitation, any warranties or conditions
|
||||
of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
|
||||
PARTICULAR PURPOSE. You are solely responsible for determining the
|
||||
appropriateness of using or redistributing the Model, Derivatives of
|
||||
the Model, and the Complementary Material and assume any risks
|
||||
associated with Your exercise of permissions under this License.
|
||||
|
||||
Limitation of Liability. In no event and under no legal theory,
|
||||
whether in tort (including negligence), contract, or otherwise, unless
|
||||
required by applicable law (such as deliberate and grossly negligent
|
||||
acts) or agreed to in writing, shall any Contributor be liable to You
|
||||
for damages, including any direct, indirect, special, incidental, or
|
||||
consequential damages of any character arising as a result of this
|
||||
License or out of the use or inability to use the Model and the
|
||||
Complementary Material (including but not limited to damages for loss
|
||||
of goodwill, work stoppage, computer failure or malfunction, or any
|
||||
and all other commercial damages or losses), even if such Contributor
|
||||
has been advised of the possibility of such damages.
|
||||
|
||||
Accepting Warranty or Additional Liability. While redistributing the
|
||||
Model, Derivatives of the Model and the Complementary Material
|
||||
thereof, You may choose to offer, and charge a fee for, acceptance of
|
||||
support, warranty, indemnity, or other liability obligations and/or
|
||||
rights consistent with this License. However, in accepting such
|
||||
obligations, You may act only on Your own behalf and on Your sole
|
||||
responsibility, not on behalf of any other Contributor, and only if
|
||||
You agree to indemnify, defend, and hold each Contributor harmless for
|
||||
any liability incurred by, or claims asserted against, such
|
||||
Contributor by reason of your accepting any such warranty or
|
||||
additional liability.
|
||||
|
||||
If any provision of this License is held to be invalid, illegal or
|
||||
unenforceable, the remaining provisions shall be unaffected thereby
|
||||
and remain valid as if such provision had not been set forth herein.
|
||||
|
||||
|
||||
END OF TERMS AND CONDITIONS
|
||||
|
||||
Attachment A
|
||||
|
||||
Use Restrictions
|
||||
|
||||
You agree not to use the Model or Derivatives of the Model:
|
||||
|
||||
* In any way that violates any applicable national, federal, state,
|
||||
local or international law or regulation;
|
||||
|
||||
* For the purpose of exploiting, harming or attempting to exploit or
|
||||
harm minors in any way;
|
||||
|
||||
* To generate or disseminate verifiably false information and/or
|
||||
content with the purpose of harming others;
|
||||
|
||||
* To generate or disseminate personal identifiable information that
|
||||
can be used to harm an individual;
|
||||
|
||||
* To defame, disparage or otherwise harass others;
|
||||
|
||||
* For fully automated decision making that adversely impacts an
|
||||
individual’s legal rights or otherwise creates or modifies a
|
||||
binding, enforceable obligation;
|
||||
|
||||
* For any use intended to or which has the effect of discriminating
|
||||
against or harming individuals or groups based on online or offline
|
||||
social behavior or known or predicted personal or personality
|
||||
characteristics;
|
||||
|
||||
* To exploit any of the vulnerabilities of a specific group of persons
|
||||
based on their age, social, physical or mental characteristics, in
|
||||
order to materially distort the behavior of a person pertaining to
|
||||
that group in a manner that causes or is likely to cause that person
|
||||
or another person physical or psychological harm;
|
||||
|
||||
* For any use intended to or which has the effect of discriminating
|
||||
against individuals or groups based on legally protected
|
||||
characteristics or categories;
|
||||
|
||||
* To provide medical advice and medical results interpretation;
|
||||
|
||||
* To generate or disseminate information for the purpose to be used
|
||||
for administration of justice, law enforcement, immigration or
|
||||
asylum processes, such as predicting an individual will commit
|
||||
fraud/crime commitment (e.g. by text profiling, drawing causal
|
||||
relationships between assertions made in documents, indiscriminate
|
||||
and arbitrarily-targeted use).
|
||||
|
BIN
docs/assets/nodes/groupsallscale.png
Normal file
After Width: | Height: | Size: 490 KiB |
BIN
docs/assets/nodes/groupsconditioning.png
Normal file
After Width: | Height: | Size: 335 KiB |
BIN
docs/assets/nodes/groupscontrol.png
Normal file
After Width: | Height: | Size: 217 KiB |
BIN
docs/assets/nodes/groupsimgvae.png
Normal file
After Width: | Height: | Size: 244 KiB |
BIN
docs/assets/nodes/groupsiterate.png
Normal file
After Width: | Height: | Size: 948 KiB |
BIN
docs/assets/nodes/groupslora.png
Normal file
After Width: | Height: | Size: 292 KiB |
BIN
docs/assets/nodes/groupsmultigenseeding.png
Normal file
After Width: | Height: | Size: 420 KiB |
BIN
docs/assets/nodes/groupsnoise.png
Normal file
After Width: | Height: | Size: 179 KiB |
BIN
docs/assets/nodes/groupsrandseed.png
Normal file
After Width: | Height: | Size: 216 KiB |
BIN
docs/assets/nodes/nodescontrol.png
Normal file
After Width: | Height: | Size: 439 KiB |
BIN
docs/assets/nodes/nodesi2i.png
Normal file
After Width: | Height: | Size: 563 KiB |
BIN
docs/assets/nodes/nodest2i.png
Normal file
After Width: | Height: | Size: 353 KiB |
@ -65,7 +65,6 @@ InvokeAI:
|
||||
esrgan: true
|
||||
internet_available: true
|
||||
log_tokenization: false
|
||||
nsfw_checker: false
|
||||
patchmatch: true
|
||||
restore: true
|
||||
...
|
||||
@ -136,19 +135,16 @@ command-line options by giving the `--help` argument:
|
||||
|
||||
```
|
||||
(.venv) > invokeai-web --help
|
||||
usage: InvokeAI [-h] [--host HOST] [--port PORT] [--allow_origins [ALLOW_ORIGINS ...]] [--allow_credentials | --no-allow_credentials]
|
||||
[--allow_methods [ALLOW_METHODS ...]] [--allow_headers [ALLOW_HEADERS ...]] [--esrgan | --no-esrgan]
|
||||
[--internet_available | --no-internet_available] [--log_tokenization | --no-log_tokenization]
|
||||
[--nsfw_checker | --no-nsfw_checker] [--patchmatch | --no-patchmatch] [--restore | --no-restore]
|
||||
[--always_use_cpu | --no-always_use_cpu] [--free_gpu_mem | --no-free_gpu_mem] [--max_cache_size MAX_CACHE_SIZE]
|
||||
[--max_vram_cache_size MAX_VRAM_CACHE_SIZE] [--precision {auto,float16,float32,autocast}]
|
||||
[--sequential_guidance | --no-sequential_guidance] [--xformers_enabled | --no-xformers_enabled]
|
||||
[--tiled_decode | --no-tiled_decode] [--root ROOT] [--autoimport_dir AUTOIMPORT_DIR] [--lora_dir LORA_DIR]
|
||||
[--embedding_dir EMBEDDING_DIR] [--controlnet_dir CONTROLNET_DIR] [--conf_path CONF_PATH] [--models_dir MODELS_DIR]
|
||||
[--legacy_conf_dir LEGACY_CONF_DIR] [--db_dir DB_DIR] [--outdir OUTDIR] [--from_file FROM_FILE]
|
||||
[--use_memory_db | --no-use_memory_db] [--model MODEL] [--log_handlers [LOG_HANDLERS ...]]
|
||||
[--log_format {plain,color,syslog,legacy}] [--log_level {debug,info,warning,error,critical}]
|
||||
...
|
||||
usage: InvokeAI [-h] [--host HOST] [--port PORT] [--allow_origins [ALLOW_ORIGINS ...]] [--allow_credentials | --no-allow_credentials] [--allow_methods [ALLOW_METHODS ...]]
|
||||
[--allow_headers [ALLOW_HEADERS ...]] [--esrgan | --no-esrgan] [--internet_available | --no-internet_available] [--log_tokenization | --no-log_tokenization]
|
||||
[--patchmatch | --no-patchmatch] [--restore | --no-restore]
|
||||
[--always_use_cpu | --no-always_use_cpu] [--free_gpu_mem | --no-free_gpu_mem] [--max_loaded_models MAX_LOADED_MODELS] [--max_cache_size MAX_CACHE_SIZE]
|
||||
[--max_vram_cache_size MAX_VRAM_CACHE_SIZE] [--gpu_mem_reserved GPU_MEM_RESERVED] [--precision {auto,float16,float32,autocast}]
|
||||
[--sequential_guidance | --no-sequential_guidance] [--xformers_enabled | --no-xformers_enabled] [--tiled_decode | --no-tiled_decode] [--root ROOT]
|
||||
[--autoimport_dir AUTOIMPORT_DIR] [--lora_dir LORA_DIR] [--embedding_dir EMBEDDING_DIR] [--controlnet_dir CONTROLNET_DIR] [--conf_path CONF_PATH]
|
||||
[--models_dir MODELS_DIR] [--legacy_conf_dir LEGACY_CONF_DIR] [--db_dir DB_DIR] [--outdir OUTDIR] [--from_file FROM_FILE]
|
||||
[--use_memory_db | --no-use_memory_db] [--model MODEL] [--log_handlers [LOG_HANDLERS ...]] [--log_format {plain,color,syslog,legacy}]
|
||||
[--log_level {debug,info,warning,error,critical}] [--version | --no-version]
|
||||
```
|
||||
|
||||
## The Configuration Settings
|
||||
@ -178,7 +174,6 @@ These configuration settings allow you to enable and disable various InvokeAI fe
|
||||
| `esrgan` | `true` | Activate the ESRGAN upscaling options|
|
||||
| `internet_available` | `true` | When a resource is not available locally, try to fetch it via the internet |
|
||||
| `log_tokenization` | `false` | Before each text2image generation, print a color-coded representation of the prompt to the console; this can help understand why a prompt is not working as expected |
|
||||
| `nsfw_checker` | `true` | Activate the NSFW checker to blur out risque images |
|
||||
| `patchmatch` | `true` | Activate the "patchmatch" algorithm for improved inpainting |
|
||||
| `restore` | `true` | Activate the facial restoration features (DEPRECATED; restoration features will be removed in 3.0.0) |
|
||||
|
||||
|
@ -61,11 +61,13 @@ A noise scheduler (eg. DPM++ 2M Karras) schedules the subtraction of noise from
|
||||
| ImageInverseLerp | Inverse linear interpolation of all pixels of an image |
|
||||
| ImageLerp | Linear interpolation of all pixels of an image |
|
||||
| ImageMultiply | Multiplies two images together using `PIL.ImageChops.Multiply()` |
|
||||
| ImageNSFWBlurInvocation | Detects and blurs images that may contain sexually explicit content |
|
||||
| ImagePaste | Pastes an image into another image |
|
||||
| ImageProcessor | Base class for invocations that reprocess images for ControlNet |
|
||||
| ImageResize | Resizes an image to specific dimensions |
|
||||
| ImageScale | Scales an image by a factor |
|
||||
| ImageToLatents | Scales latents by a given factor |
|
||||
| ImageWatermarkInvocation | Adds an invisible watermark to images |
|
||||
| InfillColor | Infills transparent areas of an image with a solid color |
|
||||
| InfillPatchMatch | Infills transparent areas of an image using the PatchMatch algorithm |
|
||||
| InfillTile | Infills transparent areas of an image with tiles of the image |
|
||||
@ -116,49 +118,49 @@ There are several node grouping concepts that can be examined with a narrow focu
|
||||
|
||||
As described, an initial noise tensor is necessary for the latent diffusion process. As a result, all non-image *ToLatents nodes require a noise node input.
|
||||
|
||||
<img width="654" alt="groupsnoise" src="https://github.com/ymgenesis/InvokeAI/assets/25252829/2e8d297e-ad55-4d27-bc93-c119dad2a2c5">
|
||||
![groupsnoise](../assets/nodes/groupsnoise.png)
|
||||
|
||||
### Conditioning
|
||||
|
||||
As described, conditioning is necessary for the latent diffusion process, whether empty or not. As a result, all non-image *ToLatents nodes require positive and negative conditioning inputs. Conditioning is reliant on a CLIP tokenizer provided by the Model Loader node.
|
||||
|
||||
<img width="1024" alt="groupsconditioning" src="https://github.com/ymgenesis/InvokeAI/assets/25252829/f8f7ad8a-8d9c-418e-b5ad-1437b774b27e">
|
||||
![groupsconditioning](../assets/nodes/groupsconditioning.png)
|
||||
|
||||
### Image Space & VAE
|
||||
|
||||
The ImageToLatents node doesn't require a noise node input, but requires a VAE input to convert the image from image space into latent space. In reverse, the LatentsToImage node requires a VAE input to convert from latent space back into image space.
|
||||
|
||||
<img width="637" alt="groupsimgvae" src="https://github.com/ymgenesis/InvokeAI/assets/25252829/dd99969c-e0a8-4f78-9b17-3ffe179cef9a">
|
||||
![groupsimgvae](../assets/nodes/groupsimgvae.png)
|
||||
|
||||
### Defined & Random Seeds
|
||||
|
||||
It is common to want to use both the same seed (for continuity) and random seeds (for variance). To define a seed, simply enter it into the 'Seed' field on a noise node. Conversely, the RandomInt node generates a random integer between 'Low' and 'High', and can be used as input to the 'Seed' edge point on a noise node to randomize your seed.
|
||||
|
||||
<img width="922" alt="groupsrandseed" src="https://github.com/ymgenesis/InvokeAI/assets/25252829/af55bc20-60f6-438e-aba5-3ec871443710">
|
||||
![groupsrandseed](../assets/nodes/groupsrandseed.png)
|
||||
|
||||
### Control
|
||||
|
||||
Control means to guide the diffusion process to adhere to a defined input or structure. Control can be provided as input to non-image *ToLatents nodes from ControlNet nodes. ControlNet nodes usually require an image processor which converts an input image for use with ControlNet.
|
||||
|
||||
<img width="805" alt="groupscontrol" src="https://github.com/ymgenesis/InvokeAI/assets/25252829/cc9c5de7-23a7-46c8-bbad-1f3609d999a6">
|
||||
![groupscontrol](../assets/nodes/groupscontrol.png)
|
||||
|
||||
### LoRA
|
||||
|
||||
The Lora Loader node lets you load a LoRA (say that ten times fast) and pass it as output to both the Prompt (Compel) and non-image *ToLatents nodes. A model's CLIP tokenizer is passed through the LoRA into Prompt (Compel), where it affects conditioning. A model's U-Net is also passed through the LoRA into a non-image *ToLatents node, where it affects noise prediction.
|
||||
|
||||
<img width="993" alt="groupslora" src="https://github.com/ymgenesis/InvokeAI/assets/25252829/630962b0-d914-4505-b3ea-ccae9b0269da">
|
||||
![groupslora](../assets/nodes/groupslora.png)
|
||||
|
||||
### Scaling
|
||||
|
||||
Use the ImageScale, ScaleLatents, and Upscale nodes to upscale images and/or latent images. The chosen method differs across contexts. However, be aware that latents are already noisy and compressed at their original resolution; scaling an image could produce more detailed results.
|
||||
|
||||
<img width="644" alt="groupsallscale" src="https://github.com/ymgenesis/InvokeAI/assets/25252829/99314f05-dd9f-4b6d-b378-31de55346a13">
|
||||
![groupsallscale](../assets/nodes/groupsallscale.png)
|
||||
|
||||
### Iteration + Multiple Images as Input
|
||||
|
||||
Iteration is a common concept in any processing, and means to repeat a process with given input. In nodes, you're able to use the Iterate node to iterate through collections usually gathered by the Collect node. The Iterate node has many potential uses, from processing a collection of images one after another, to varying seeds across multiple image generations and more. This screenshot demonstrates how to collect several images and pass them out one at a time.
|
||||
|
||||
<img width="788" alt="groupsiterate" src="https://github.com/ymgenesis/InvokeAI/assets/25252829/4af5ca27-82c9-4018-8c5b-024d3ee0a121">
|
||||
![groupsiterate](../assets/nodes/groupsiterate.png)
|
||||
|
||||
### Multiple Image Generation + Random Seeds
|
||||
|
||||
@ -166,7 +168,7 @@ Multiple image generation in the node editor is done using the RandomRange node.
|
||||
|
||||
To control seeds across generations takes some care. The first row in the screenshot will generate multiple images with different seeds, but using the same RandomRange parameters across invocations will result in the same group of random seeds being used across the images, producing repeatable results. In the second row, adding the RandomInt node as input to RandomRange's 'Seed' edge point will ensure that seeds are varied across all images across invocations, producing varied results.
|
||||
|
||||
<img width="1027" alt="groupsmultigenseeding" src="https://github.com/ymgenesis/InvokeAI/assets/25252829/518d1b2b-fed1-416b-a052-ab06552521b3">
|
||||
![groupsmultigenseeding](../assets/nodes/groupsmultigenseeding.png)
|
||||
|
||||
## Examples
|
||||
|
||||
@ -174,7 +176,7 @@ With our knowledge of node grouping and the diffusion process, let’s break dow
|
||||
|
||||
### Basic text-to-image Node Graph
|
||||
|
||||
<img width="875" alt="nodest2i" src="https://github.com/ymgenesis/InvokeAI/assets/25252829/17c67720-c376-4db8-94f0-5e00381a61ee">
|
||||
![nodest2i](../assets/nodes/nodest2i.png)
|
||||
|
||||
- Model Loader: A necessity to generating images (as we’ve read above). We choose our model from the dropdown. It outputs a U-Net, CLIP tokenizer, and VAE.
|
||||
- Prompt (Compel): Another necessity. Two prompt nodes are created. One will output positive conditioning (what you want, ‘dog’), one will output negative (what you don’t want, ‘cat’). They both input the CLIP tokenizer that the Model Loader node outputs.
|
||||
@ -184,7 +186,7 @@ With our knowledge of node grouping and the diffusion process, let’s break dow
|
||||
|
||||
### Basic image-to-image Node Graph
|
||||
|
||||
<img width="998" alt="nodesi2i" src="https://github.com/ymgenesis/InvokeAI/assets/25252829/3f2c95d5-cee7-4415-9b79-b46ee60a92fe">
|
||||
![nodesi2i](../assets/nodes/nodesi2i.png)
|
||||
|
||||
- Model Loader: Choose a model from the dropdown.
|
||||
- Prompt (Compel): Two prompt nodes. One positive (dog), one negative (dog). Same CLIP inputs from the Model Loader node as before.
|
||||
@ -195,7 +197,7 @@ With our knowledge of node grouping and the diffusion process, let’s break dow
|
||||
|
||||
### Basic ControlNet Node Graph
|
||||
|
||||
<img width="703" alt="nodescontrol" src="https://github.com/ymgenesis/InvokeAI/assets/25252829/b02ded86-ceb4-44a2-9910-e19ad184d471">
|
||||
![nodescontrol](../assets/nodes/nodescontrol.png)
|
||||
|
||||
- Model Loader
|
||||
- Prompt (Compel)
|
||||
|
@ -16,21 +16,24 @@ Output Example:
|
||||
|
||||
---
|
||||
|
||||
## **Seamless Tiling**
|
||||
## **Invisible Watermark**
|
||||
|
||||
The seamless tiling mode causes generated images to seamlessly tile
|
||||
with itself creating repetitive wallpaper-like patterns. To use it,
|
||||
activate the Seamless Tiling option in the Web GUI and then select
|
||||
whether to tile on the X (horizontal) and/or Y (vertical) axes. Tiling
|
||||
will then be active for the next set of generations.
|
||||
In keeping with the principles for responsible AI generation, and to
|
||||
help AI researchers avoid synthetic images contaminating their
|
||||
training sets, InvokeAI adds an invisible watermark to each of the
|
||||
final images it generates. The watermark consists of the text
|
||||
"InvokeAI" and can be viewed using the
|
||||
[invisible-watermarks](https://github.com/ShieldMnt/invisible-watermark)
|
||||
tool.
|
||||
|
||||
A nice prompt to test seamless tiling with is:
|
||||
Watermarking is controlled using the `invisible-watermark` setting in
|
||||
`invokeai.yaml`. To turn it off, add the following line under the `Features`
|
||||
category.
|
||||
|
||||
```
|
||||
pond garden with lotus by claude monet"
|
||||
invisible_watermark: false
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
## **Weighted Prompts**
|
||||
|
||||
@ -39,34 +42,10 @@ priority to them, by adding `:<percent>` to the end of the section you wish to u
|
||||
example consider this prompt:
|
||||
|
||||
```bash
|
||||
tabby cat:0.25 white duck:0.75 hybrid
|
||||
(tabby cat):0.25 (white duck):0.75 hybrid
|
||||
```
|
||||
|
||||
This will tell the sampler to invest 25% of its effort on the tabby cat aspect of the image and 75%
|
||||
on the white duck aspect (surprisingly, this example actually works). The prompt weights can use any
|
||||
combination of integers and floating point numbers, and they do not need to add up to 1.
|
||||
|
||||
## **Thresholding and Perlin Noise Initialization Options**
|
||||
|
||||
Under the Noise section of the Web UI, you will find two options named
|
||||
Perlin Noise and Noise Threshold. [Perlin
|
||||
noise](https://en.wikipedia.org/wiki/Perlin_noise) is a type of
|
||||
structured noise used to simulate terrain and other natural
|
||||
textures. The slider controls the percentage of perlin noise that will
|
||||
be mixed into the image at the beginning of generation. Adding a little
|
||||
perlin noise to a generation will alter the image substantially.
|
||||
|
||||
The noise threshold limits the range of the latent values during
|
||||
sampling and helps combat the oversharpening seem with higher CFG
|
||||
scale values.
|
||||
|
||||
For better intuition into what these options do in practice:
|
||||
|
||||
![here is a graphic demonstrating them both](../assets/truncation_comparison.jpg)
|
||||
|
||||
In generating this graphic, perlin noise at initialization was
|
||||
programmatically varied going across on the diagram by values 0.0,
|
||||
0.1, 0.2, 0.4, 0.5, 0.6, 0.8, 0.9, 1.0; and the threshold was varied
|
||||
going down from 0, 1, 2, 3, 4, 5, 10, 20, 100. The other options are
|
||||
fixed using the prompt "a portrait of a beautiful young lady" a CFG of
|
||||
20, 100 steps, and a seed of 1950357039.
|
||||
|
@ -1,12 +1,40 @@
|
||||
---
|
||||
title: The NSFW Checker
|
||||
title: Watermarking, NSFW Image Checking
|
||||
---
|
||||
|
||||
# :material-image-off: NSFW Checker
|
||||
# :material-image-off: Invisible Watermark and the NSFW Checker
|
||||
|
||||
## Watermarking
|
||||
|
||||
InvokeAI does not apply watermarking to images by default. However,
|
||||
many computer scientists working in the field of generative AI worry
|
||||
that a flood of computer-generated imagery will contaminate the image
|
||||
data sets needed to train future generations of generative models.
|
||||
|
||||
InvokeAI offers an optional watermarking mode that writes a small bit
|
||||
of text, **InvokeAI**, into each image that it generates using an
|
||||
"invisible" watermarking library that spreads the information
|
||||
throughout the image in a way that is not perceptible to the human
|
||||
eye. If you are planning to share your generated images on
|
||||
internet-accessible services, we encourage you to activate the
|
||||
invisible watermark mode in order to help preserve the digital image
|
||||
environment.
|
||||
|
||||
The downside of watermarking is that it increases the size of the
|
||||
image moderately, and has been reported by some individuals to degrade
|
||||
image quality. Your mileage may vary.
|
||||
|
||||
To read the watermark in an image, activate the InvokeAI virtual
|
||||
environment (called the "developer's console" in the launcher) and run
|
||||
the command:
|
||||
|
||||
```
|
||||
invisible-watermark -a decode -t bytes -m dwtDct -l 64 /path/to/image.png
|
||||
```
|
||||
|
||||
## The NSFW ("Safety") Checker
|
||||
|
||||
The Stable Diffusion image generation models will produce sexual
|
||||
Stable Diffusion 1.5-based image generation models will produce sexual
|
||||
imagery if deliberately prompted, and will occasionally produce such
|
||||
images when this is not intended. Such images are colloquially known
|
||||
as "Not Safe for Work" (NSFW). This behavior is due to the nature of
|
||||
@ -18,35 +46,17 @@ jurisdictions it may be illegal to publicly distribute such imagery,
|
||||
including mounting a publicly-available server that provides
|
||||
unfiltered images to the public. Furthermore, the [Stable Diffusion
|
||||
weights
|
||||
License](https://github.com/invoke-ai/InvokeAI/blob/main/LICENSE-ModelWeights.txt)
|
||||
forbids the model from being used to "exploit any of the
|
||||
License](https://github.com/invoke-ai/InvokeAI/blob/main/LICENSE-SD1+SD2.txt),
|
||||
and the [Stable Diffusion XL
|
||||
License][https://github.com/invoke-ai/InvokeAI/blob/main/LICENSE-SDXL.txt]
|
||||
both forbid the models from being used to "exploit any of the
|
||||
vulnerabilities of a specific group of persons."
|
||||
|
||||
For these reasons Stable Diffusion offers a "safety checker," a
|
||||
machine learning model trained to recognize potentially disturbing
|
||||
imagery. When a potentially NSFW image is detected, the checker will
|
||||
blur the image and paste a warning icon on top. The checker can be
|
||||
turned on and off on the command line using `--nsfw_checker` and
|
||||
`--no-nsfw_checker`.
|
||||
|
||||
At installation time, InvokeAI will ask whether the checker should be
|
||||
activated by default (neither argument given on the command line). The
|
||||
response is stored in the InvokeAI initialization file
|
||||
(`invokeai.yaml` in the InvokeAI root directory). You can change the
|
||||
default at any time by opening this file in a text editor and
|
||||
changing the line `nsfw_checker:` from true to false or vice-versa:
|
||||
|
||||
|
||||
```
|
||||
...
|
||||
Features:
|
||||
esrgan: true
|
||||
internet_available: true
|
||||
log_tokenization: false
|
||||
nsfw_checker: true
|
||||
patchmatch: true
|
||||
restore: true
|
||||
```
|
||||
turned on and off in the Web interface under Settings.
|
||||
|
||||
## Caveats
|
||||
|
||||
@ -84,10 +94,3 @@ are encouraged to turn **off** intermediate image rendering when you
|
||||
are using the checker. Future versions of InvokeAI will apply
|
||||
additional blurring to intermediate images when the checker is active.
|
||||
|
||||
### Watermarking
|
||||
|
||||
InvokeAI does not apply any sort of watermark to images it
|
||||
generates. However, it does write metadata into the PNG data area,
|
||||
including the prompt used to generate the image and relevant parameter
|
||||
settings. These fields can be examined using the `sd-metadata.py`
|
||||
script that comes with the InvokeAI package.
|
@ -148,7 +148,7 @@ images in full-precision mode:
|
||||
- [Model Merging](features/MODEL_MERGING.md)
|
||||
- [ControlNet Models](features/CONTROLNET.md)
|
||||
- [Style/Subject Concepts and Embeddings](features/CONCEPTS.md)
|
||||
- [Not Safe for Work (NSFW) Checker](features/NSFW.md)
|
||||
- [Watermarking and the Not Safe for Work (NSFW) Checker](features/WATERMARK+NSFW.md)
|
||||
<!-- seperator -->
|
||||
### Prompt Engineering
|
||||
- [Prompt Syntax](features/PROMPTS.md)
|
||||
|
@ -213,17 +213,6 @@ experimental versions later.
|
||||
Generally the defaults are fine, and you can come back to this screen at
|
||||
any time to tweak your system. Here are the options you can adjust:
|
||||
|
||||
- ***Output directory for images***
|
||||
This is the path to a directory in which InvokeAI will store all its
|
||||
generated images.
|
||||
|
||||
- ***NSFW checker***
|
||||
If checked, InvokeAI will test images for potential sexual content
|
||||
and blur them out if found. Note that the NSFW checker consumes
|
||||
an additional 0.6 GB of VRAM on top of the 2-3 GB of VRAM used
|
||||
by most image models. If you have a low VRAM GPU (4-6 GB), you
|
||||
can reduce out of memory errors by disabling the checker.
|
||||
|
||||
- ***HuggingFace Access Token***
|
||||
InvokeAI has the ability to download embedded styles and subjects
|
||||
from the HuggingFace Concept Library on-demand. However, some of
|
||||
@ -255,20 +244,30 @@ experimental versions later.
|
||||
and graphics cards. The "autocast" option is deprecated and
|
||||
shouldn't be used unless you are asked to by a member of the team.
|
||||
|
||||
- ***Number of models to cache in CPU memory***
|
||||
- **Size of the RAM cache used for fast model switching***
|
||||
This allows you to keep models in memory and switch rapidly among
|
||||
them rather than having them load from disk each time. This slider
|
||||
controls how many models to keep loaded at once. Each
|
||||
model will use 2-4 GB of RAM, so use this cautiously
|
||||
controls how many models to keep loaded at once. A typical SD-1 or SD-2 model
|
||||
uses 2-3 GB of memory. A typical SDXL model uses 6-7 GB. Providing more
|
||||
RAM will allow more models to be co-resident.
|
||||
|
||||
- ***Directory containing embedding/textual inversion files***
|
||||
This is the directory in which you can place custom embedding
|
||||
files (.pt or .bin). During startup, this directory will be
|
||||
scanned and InvokeAI will print out the text terms that
|
||||
are available to trigger the embeddings.
|
||||
- ***Output directory for images***
|
||||
This is the path to a directory in which InvokeAI will store all its
|
||||
generated images.
|
||||
|
||||
- ***Autoimport Folder***
|
||||
This is the directory in which you can place models you have
|
||||
downloaded and wish to load into InvokeAI. You can place a variety
|
||||
of models in this directory, including diffusers folders, .ckpt files,
|
||||
.safetensors files, as well as LoRAs, ControlNet and Textual Inversion
|
||||
files (both folder and file versions). To help organize this folder,
|
||||
you can create several levels of subfolders and drop your models into
|
||||
whichever ones you want.
|
||||
|
||||
- ***Autoimport FolderLICENSE***
|
||||
|
||||
At the bottom of the screen you will see a checkbox for accepting
|
||||
the CreativeML Responsible AI License. You need to accept the license
|
||||
the CreativeML Responsible AI Licenses. You need to accept the license
|
||||
in order to download Stable Diffusion models from the next screen.
|
||||
|
||||
_You can come back to the startup options form_ as many times as you like.
|
||||
|
@ -141,15 +141,16 @@ class Installer:
|
||||
|
||||
# upgrade pip in Python 3.9 environments
|
||||
if int(platform.python_version_tuple()[1]) == 9:
|
||||
|
||||
from plumbum import FG, local
|
||||
|
||||
pip = local[get_pip_from_venv(venv_dir)]
|
||||
pip[ "install", "--upgrade", "pip"] & FG
|
||||
pip["install", "--upgrade", "pip"] & FG
|
||||
|
||||
return venv_dir
|
||||
|
||||
def install(self, root: str = "~/invokeai-3", version: str = "latest", yes_to_all=False, find_links: Path = None) -> None:
|
||||
def install(
|
||||
self, root: str = "~/invokeai-3", version: str = "latest", yes_to_all=False, find_links: Path = None
|
||||
) -> None:
|
||||
"""
|
||||
Install the InvokeAI application into the given runtime path
|
||||
|
||||
@ -175,7 +176,7 @@ class Installer:
|
||||
self.instance = InvokeAiInstance(runtime=self.dest, venv=self.venv, version=version)
|
||||
|
||||
# install dependencies and the InvokeAI application
|
||||
(extra_index_url,optional_modules) = get_torch_source() if not yes_to_all else (None,None)
|
||||
(extra_index_url, optional_modules) = get_torch_source() if not yes_to_all else (None, None)
|
||||
self.instance.install(
|
||||
extra_index_url,
|
||||
optional_modules,
|
||||
@ -188,6 +189,7 @@ class Installer:
|
||||
# run through the configuration flow
|
||||
self.instance.configure()
|
||||
|
||||
|
||||
class InvokeAiInstance:
|
||||
"""
|
||||
Manages an installed instance of InvokeAI, comprising a virtual environment and a runtime directory.
|
||||
@ -196,7 +198,6 @@ class InvokeAiInstance:
|
||||
"""
|
||||
|
||||
def __init__(self, runtime: Path, venv: Path, version: str) -> None:
|
||||
|
||||
self.runtime = runtime
|
||||
self.venv = venv
|
||||
self.pip = get_pip_from_venv(venv)
|
||||
@ -312,7 +313,7 @@ class InvokeAiInstance:
|
||||
"install",
|
||||
"--require-virtualenv",
|
||||
"--use-pep517",
|
||||
str(src)+(optional_modules if optional_modules else ''),
|
||||
str(src) + (optional_modules if optional_modules else ""),
|
||||
"--find-links" if find_links is not None else None,
|
||||
find_links,
|
||||
"--extra-index-url" if extra_index_url is not None else None,
|
||||
@ -329,15 +330,15 @@ class InvokeAiInstance:
|
||||
|
||||
# set sys.argv to a consistent state
|
||||
new_argv = [sys.argv[0]]
|
||||
for i in range(1,len(sys.argv)):
|
||||
for i in range(1, len(sys.argv)):
|
||||
el = sys.argv[i]
|
||||
if el in ['-r','--root']:
|
||||
if el in ["-r", "--root"]:
|
||||
new_argv.append(el)
|
||||
new_argv.append(sys.argv[i+1])
|
||||
elif el in ['-y','--yes','--yes-to-all']:
|
||||
new_argv.append(sys.argv[i + 1])
|
||||
elif el in ["-y", "--yes", "--yes-to-all"]:
|
||||
new_argv.append(el)
|
||||
sys.argv = new_argv
|
||||
|
||||
|
||||
import requests # to catch download exceptions
|
||||
from messages import introduction
|
||||
|
||||
@ -353,16 +354,16 @@ class InvokeAiInstance:
|
||||
invokeai_configure()
|
||||
succeeded = True
|
||||
except requests.exceptions.ConnectionError as e:
|
||||
print(f'\nA network error was encountered during configuration and download: {str(e)}')
|
||||
print(f"\nA network error was encountered during configuration and download: {str(e)}")
|
||||
except OSError as e:
|
||||
print(f'\nAn OS error was encountered during configuration and download: {str(e)}')
|
||||
print(f"\nAn OS error was encountered during configuration and download: {str(e)}")
|
||||
except Exception as e:
|
||||
print(f'\nA problem was encountered during the configuration and download steps: {str(e)}')
|
||||
print(f"\nA problem was encountered during the configuration and download steps: {str(e)}")
|
||||
finally:
|
||||
if not succeeded:
|
||||
print('To try again, find the "invokeai" directory, run the script "invoke.sh" or "invoke.bat"')
|
||||
print('and choose option 7 to fix a broken install, optionally followed by option 5 to install models.')
|
||||
print('Alternatively you can relaunch the installer.')
|
||||
print("and choose option 7 to fix a broken install, optionally followed by option 5 to install models.")
|
||||
print("Alternatively you can relaunch the installer.")
|
||||
|
||||
def install_user_scripts(self):
|
||||
"""
|
||||
@ -371,11 +372,11 @@ class InvokeAiInstance:
|
||||
|
||||
ext = "bat" if OS == "Windows" else "sh"
|
||||
|
||||
#scripts = ['invoke', 'update']
|
||||
scripts = ['invoke']
|
||||
|
||||
# scripts = ['invoke', 'update']
|
||||
scripts = ["invoke"]
|
||||
|
||||
for script in scripts:
|
||||
src = Path(__file__).parent / '..' / "templates" / f"{script}.{ext}.in"
|
||||
src = Path(__file__).parent / ".." / "templates" / f"{script}.{ext}.in"
|
||||
dest = self.runtime / f"{script}.{ext}"
|
||||
shutil.copy(src, dest)
|
||||
os.chmod(dest, 0o0755)
|
||||
@ -420,11 +421,7 @@ def set_sys_path(venv_path: Path) -> None:
|
||||
# filter out any paths in sys.path that may be system- or user-wide
|
||||
# but leave the temporary bootstrap virtualenv as it contains packages we
|
||||
# temporarily need at install time
|
||||
sys.path = list(filter(
|
||||
lambda p: not p.endswith("-packages")
|
||||
or p.find(BOOTSTRAP_VENV_PREFIX) != -1,
|
||||
sys.path
|
||||
))
|
||||
sys.path = list(filter(lambda p: not p.endswith("-packages") or p.find(BOOTSTRAP_VENV_PREFIX) != -1, sys.path))
|
||||
|
||||
# determine site-packages/lib directory location for the venv
|
||||
lib = "Lib" if OS == "Windows" else f"lib/python{sys.version_info.major}.{sys.version_info.minor}"
|
||||
@ -433,7 +430,7 @@ def set_sys_path(venv_path: Path) -> None:
|
||||
sys.path.append(str(Path(venv_path, lib, "site-packages").expanduser().resolve()))
|
||||
|
||||
|
||||
def get_torch_source() -> (Union[str, None],str):
|
||||
def get_torch_source() -> (Union[str, None], str):
|
||||
"""
|
||||
Determine the extra index URL for pip to use for torch installation.
|
||||
This depends on the OS and the graphics accelerator in use.
|
||||
@ -461,9 +458,9 @@ def get_torch_source() -> (Union[str, None],str):
|
||||
elif device == "cpu":
|
||||
url = "https://download.pytorch.org/whl/cpu"
|
||||
|
||||
if device == 'cuda':
|
||||
url = 'https://download.pytorch.org/whl/cu117'
|
||||
optional_modules = '[xformers]'
|
||||
if device == "cuda":
|
||||
url = "https://download.pytorch.org/whl/cu117"
|
||||
optional_modules = "[xformers]"
|
||||
|
||||
# in all other cases, Torch wheels should be coming from PyPi as of Torch 1.13
|
||||
|
||||
|
@ -41,7 +41,7 @@ if __name__ == "__main__":
|
||||
type=Path,
|
||||
default=None,
|
||||
)
|
||||
|
||||
|
||||
args = parser.parse_args()
|
||||
|
||||
inst = Installer()
|
||||
|
@ -36,13 +36,15 @@ else:
|
||||
|
||||
|
||||
def welcome():
|
||||
|
||||
@group()
|
||||
def text():
|
||||
if (platform_specific := _platform_specific_help()) != "":
|
||||
yield platform_specific
|
||||
yield ""
|
||||
yield Text.from_markup("Some of the installation steps take a long time to run. Please be patient. If the script appears to hang for more than 10 minutes, please interrupt with [i]Control-C[/] and retry.", justify="center")
|
||||
yield Text.from_markup(
|
||||
"Some of the installation steps take a long time to run. Please be patient. If the script appears to hang for more than 10 minutes, please interrupt with [i]Control-C[/] and retry.",
|
||||
justify="center",
|
||||
)
|
||||
|
||||
console.rule()
|
||||
print(
|
||||
@ -58,6 +60,7 @@ def welcome():
|
||||
)
|
||||
console.line()
|
||||
|
||||
|
||||
def confirm_install(dest: Path) -> bool:
|
||||
if dest.exists():
|
||||
print(f":exclamation: Directory {dest} already exists :exclamation:")
|
||||
@ -92,7 +95,6 @@ def dest_path(dest=None) -> Path:
|
||||
dest_confirmed = confirm_install(dest)
|
||||
|
||||
while not dest_confirmed:
|
||||
|
||||
# if the given destination already exists, the starting point for browsing is its parent directory.
|
||||
# the user may have made a typo, or otherwise wants to place the root dir next to an existing one.
|
||||
# if the destination dir does NOT exist, then the user must have changed their mind about the selection.
|
||||
@ -300,15 +302,20 @@ def introduction() -> None:
|
||||
)
|
||||
console.line(2)
|
||||
|
||||
def _platform_specific_help()->str:
|
||||
|
||||
def _platform_specific_help() -> str:
|
||||
if OS == "Darwin":
|
||||
text = Text.from_markup("""[b wheat1]macOS Users![/]\n\nPlease be sure you have the [b wheat1]Xcode command-line tools[/] installed before continuing.\nIf not, cancel with [i]Control-C[/] and follow the Xcode install instructions at [deep_sky_blue1]https://www.freecodecamp.org/news/install-xcode-command-line-tools/[/].""")
|
||||
text = Text.from_markup(
|
||||
"""[b wheat1]macOS Users![/]\n\nPlease be sure you have the [b wheat1]Xcode command-line tools[/] installed before continuing.\nIf not, cancel with [i]Control-C[/] and follow the Xcode install instructions at [deep_sky_blue1]https://www.freecodecamp.org/news/install-xcode-command-line-tools/[/]."""
|
||||
)
|
||||
elif OS == "Windows":
|
||||
text = Text.from_markup("""[b wheat1]Windows Users![/]\n\nBefore you start, please do the following:
|
||||
text = Text.from_markup(
|
||||
"""[b wheat1]Windows Users![/]\n\nBefore you start, please do the following:
|
||||
1. Double-click on the file [b wheat1]WinLongPathsEnabled.reg[/] in order to
|
||||
enable long path support on your system.
|
||||
2. Make sure you have the [b wheat1]Visual C++ core libraries[/] installed. If not, install from
|
||||
[deep_sky_blue1]https://learn.microsoft.com/en-US/cpp/windows/latest-supported-vc-redist?view=msvc-170[/]""")
|
||||
[deep_sky_blue1]https://learn.microsoft.com/en-US/cpp/windows/latest-supported-vc-redist?view=msvc-170[/]"""
|
||||
)
|
||||
else:
|
||||
text = ""
|
||||
return text
|
||||
|
@ -78,9 +78,7 @@ class ApiDependencies:
|
||||
image_record_storage = SqliteImageRecordStorage(db_location)
|
||||
image_file_storage = DiskImageFileStorage(f"{output_folder}/images")
|
||||
names = SimpleNameService()
|
||||
latents = ForwardCacheLatentsStorage(
|
||||
DiskLatentsStorage(f"{output_folder}/latents")
|
||||
)
|
||||
latents = ForwardCacheLatentsStorage(DiskLatentsStorage(f"{output_folder}/latents"))
|
||||
|
||||
board_record_storage = SqliteBoardRecordStorage(db_location)
|
||||
board_image_record_storage = SqliteBoardImageRecordStorage(db_location)
|
||||
@ -125,9 +123,7 @@ class ApiDependencies:
|
||||
boards=boards,
|
||||
board_images=board_images,
|
||||
queue=MemoryInvocationQueue(),
|
||||
graph_library=SqliteItemStorage[LibraryGraph](
|
||||
filename=db_location, table_name="graphs"
|
||||
),
|
||||
graph_library=SqliteItemStorage[LibraryGraph](filename=db_location, table_name="graphs"),
|
||||
graph_execution_manager=graph_execution_manager,
|
||||
processor=DefaultInvocationProcessor(),
|
||||
configuration=config,
|
||||
|
@ -1,14 +1,21 @@
|
||||
import typing
|
||||
from enum import Enum
|
||||
from fastapi import Body
|
||||
from fastapi.routing import APIRouter
|
||||
from pathlib import Path
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from invokeai.backend.image_util.patchmatch import PatchMatch
|
||||
from invokeai.backend.image_util.safety_checker import SafetyChecker
|
||||
from invokeai.backend.image_util.invisible_watermark import InvisibleWatermark
|
||||
from invokeai.app.invocations.upscale import ESRGAN_MODELS
|
||||
|
||||
from invokeai.version import __version__
|
||||
|
||||
from ..dependencies import ApiDependencies
|
||||
from invokeai.backend.util.logging import logging
|
||||
|
||||
|
||||
class LogLevel(int, Enum):
|
||||
NotSet = logging.NOTSET
|
||||
Debug = logging.DEBUG
|
||||
@ -16,7 +23,13 @@ class LogLevel(int, Enum):
|
||||
Warning = logging.WARNING
|
||||
Error = logging.ERROR
|
||||
Critical = logging.CRITICAL
|
||||
|
||||
|
||||
|
||||
class Upscaler(BaseModel):
|
||||
upscaling_method: str = Field(description="Name of upscaling method")
|
||||
upscaling_models: list[str] = Field(description="List of upscaling models for this method")
|
||||
|
||||
|
||||
app_router = APIRouter(prefix="/v1/app", tags=["app"])
|
||||
|
||||
|
||||
@ -30,43 +43,62 @@ class AppConfig(BaseModel):
|
||||
"""App Config Response"""
|
||||
|
||||
infill_methods: list[str] = Field(description="List of available infill methods")
|
||||
upscaling_methods: list[Upscaler] = Field(description="List of upscaling methods")
|
||||
nsfw_methods: list[str] = Field(description="List of NSFW checking methods")
|
||||
watermarking_methods: list[str] = Field(description="List of invisible watermark methods")
|
||||
|
||||
|
||||
@app_router.get(
|
||||
"/version", operation_id="app_version", status_code=200, response_model=AppVersion
|
||||
)
|
||||
@app_router.get("/version", operation_id="app_version", status_code=200, response_model=AppVersion)
|
||||
async def get_version() -> AppVersion:
|
||||
return AppVersion(version=__version__)
|
||||
|
||||
|
||||
@app_router.get(
|
||||
"/config", operation_id="get_config", status_code=200, response_model=AppConfig
|
||||
)
|
||||
@app_router.get("/config", operation_id="get_config", status_code=200, response_model=AppConfig)
|
||||
async def get_config() -> AppConfig:
|
||||
infill_methods = ['tile']
|
||||
infill_methods = ["tile"]
|
||||
if PatchMatch.patchmatch_available():
|
||||
infill_methods.append('patchmatch')
|
||||
return AppConfig(infill_methods=infill_methods)
|
||||
infill_methods.append("patchmatch")
|
||||
|
||||
upscaling_models = []
|
||||
for model in typing.get_args(ESRGAN_MODELS):
|
||||
upscaling_models.append(str(Path(model).stem))
|
||||
upscaler = Upscaler(upscaling_method="esrgan", upscaling_models=upscaling_models)
|
||||
|
||||
nsfw_methods = []
|
||||
if SafetyChecker.safety_checker_available():
|
||||
nsfw_methods.append("nsfw_checker")
|
||||
|
||||
watermarking_methods = []
|
||||
if InvisibleWatermark.invisible_watermark_available():
|
||||
watermarking_methods.append("invisible_watermark")
|
||||
|
||||
return AppConfig(
|
||||
infill_methods=infill_methods,
|
||||
upscaling_methods=[upscaler],
|
||||
nsfw_methods=nsfw_methods,
|
||||
watermarking_methods=watermarking_methods,
|
||||
)
|
||||
|
||||
|
||||
@app_router.get(
|
||||
"/logging",
|
||||
operation_id="get_log_level",
|
||||
responses={200: {"description" : "The operation was successful"}},
|
||||
response_model = LogLevel,
|
||||
responses={200: {"description": "The operation was successful"}},
|
||||
response_model=LogLevel,
|
||||
)
|
||||
async def get_log_level(
|
||||
) -> LogLevel:
|
||||
async def get_log_level() -> LogLevel:
|
||||
"""Returns the log level"""
|
||||
return LogLevel(ApiDependencies.invoker.services.logger.level)
|
||||
|
||||
|
||||
@app_router.post(
|
||||
"/logging",
|
||||
operation_id="set_log_level",
|
||||
responses={200: {"description" : "The operation was successful"}},
|
||||
response_model = LogLevel,
|
||||
responses={200: {"description": "The operation was successful"}},
|
||||
response_model=LogLevel,
|
||||
)
|
||||
async def set_log_level(
|
||||
level: LogLevel = Body(description="New log verbosity level"),
|
||||
level: LogLevel = Body(description="New log verbosity level"),
|
||||
) -> LogLevel:
|
||||
"""Sets the log verbosity level"""
|
||||
ApiDependencies.invoker.services.logger.setLevel(level)
|
||||
|
@ -52,4 +52,3 @@ async def remove_board_image(
|
||||
return result
|
||||
except Exception as e:
|
||||
raise HTTPException(status_code=500, detail="Failed to update board")
|
||||
|
||||
|
@ -18,9 +18,7 @@ class DeleteBoardResult(BaseModel):
|
||||
deleted_board_images: list[str] = Field(
|
||||
description="The image names of the board-images relationships that were deleted."
|
||||
)
|
||||
deleted_images: list[str] = Field(
|
||||
description="The names of the images that were deleted."
|
||||
)
|
||||
deleted_images: list[str] = Field(description="The names of the images that were deleted.")
|
||||
|
||||
|
||||
@boards_router.post(
|
||||
@ -73,22 +71,16 @@ async def update_board(
|
||||
) -> BoardDTO:
|
||||
"""Updates a board"""
|
||||
try:
|
||||
result = ApiDependencies.invoker.services.boards.update(
|
||||
board_id=board_id, changes=changes
|
||||
)
|
||||
result = ApiDependencies.invoker.services.boards.update(board_id=board_id, changes=changes)
|
||||
return result
|
||||
except Exception as e:
|
||||
raise HTTPException(status_code=500, detail="Failed to update board")
|
||||
|
||||
|
||||
@boards_router.delete(
|
||||
"/{board_id}", operation_id="delete_board", response_model=DeleteBoardResult
|
||||
)
|
||||
@boards_router.delete("/{board_id}", operation_id="delete_board", response_model=DeleteBoardResult)
|
||||
async def delete_board(
|
||||
board_id: str = Path(description="The id of board to delete"),
|
||||
include_images: Optional[bool] = Query(
|
||||
description="Permanently delete all images on the board", default=False
|
||||
),
|
||||
include_images: Optional[bool] = Query(description="Permanently delete all images on the board", default=False),
|
||||
) -> DeleteBoardResult:
|
||||
"""Deletes a board"""
|
||||
try:
|
||||
@ -96,9 +88,7 @@ async def delete_board(
|
||||
deleted_images = ApiDependencies.invoker.services.board_images.get_all_board_image_names_for_board(
|
||||
board_id=board_id
|
||||
)
|
||||
ApiDependencies.invoker.services.images.delete_images_on_board(
|
||||
board_id=board_id
|
||||
)
|
||||
ApiDependencies.invoker.services.images.delete_images_on_board(board_id=board_id)
|
||||
ApiDependencies.invoker.services.boards.delete(board_id=board_id)
|
||||
return DeleteBoardResult(
|
||||
board_id=board_id,
|
||||
@ -127,9 +117,7 @@ async def delete_board(
|
||||
async def list_boards(
|
||||
all: Optional[bool] = Query(default=None, description="Whether to list all boards"),
|
||||
offset: Optional[int] = Query(default=None, description="The page offset"),
|
||||
limit: Optional[int] = Query(
|
||||
default=None, description="The number of boards per page"
|
||||
),
|
||||
limit: Optional[int] = Query(default=None, description="The number of boards per page"),
|
||||
) -> Union[OffsetPaginatedResults[BoardDTO], list[BoardDTO]]:
|
||||
"""Gets a list of boards"""
|
||||
if all:
|
||||
|
@ -40,15 +40,9 @@ async def upload_image(
|
||||
response: Response,
|
||||
image_category: ImageCategory = Query(description="The category of the image"),
|
||||
is_intermediate: bool = Query(description="Whether this is an intermediate image"),
|
||||
board_id: Optional[str] = Query(
|
||||
default=None, description="The board to add this image to, if any"
|
||||
),
|
||||
session_id: Optional[str] = Query(
|
||||
default=None, description="The session ID associated with this upload, if any"
|
||||
),
|
||||
crop_visible: Optional[bool] = Query(
|
||||
default=False, description="Whether to crop the image"
|
||||
),
|
||||
board_id: Optional[str] = Query(default=None, description="The board to add this image to, if any"),
|
||||
session_id: Optional[str] = Query(default=None, description="The session ID associated with this upload, if any"),
|
||||
crop_visible: Optional[bool] = Query(default=False, description="Whether to crop the image"),
|
||||
) -> ImageDTO:
|
||||
"""Uploads an image"""
|
||||
if not file.content_type.startswith("image"):
|
||||
@ -115,9 +109,7 @@ async def clear_intermediates() -> int:
|
||||
)
|
||||
async def update_image(
|
||||
image_name: str = Path(description="The name of the image to update"),
|
||||
image_changes: ImageRecordChanges = Body(
|
||||
description="The changes to apply to the image"
|
||||
),
|
||||
image_changes: ImageRecordChanges = Body(description="The changes to apply to the image"),
|
||||
) -> ImageDTO:
|
||||
"""Updates an image"""
|
||||
|
||||
@ -212,15 +204,11 @@ async def get_image_thumbnail(
|
||||
"""Gets a thumbnail image file"""
|
||||
|
||||
try:
|
||||
path = ApiDependencies.invoker.services.images.get_path(
|
||||
image_name, thumbnail=True
|
||||
)
|
||||
path = ApiDependencies.invoker.services.images.get_path(image_name, thumbnail=True)
|
||||
if not ApiDependencies.invoker.services.images.validate_path(path):
|
||||
raise HTTPException(status_code=404)
|
||||
|
||||
response = FileResponse(
|
||||
path, media_type="image/webp", content_disposition_type="inline"
|
||||
)
|
||||
response = FileResponse(path, media_type="image/webp", content_disposition_type="inline")
|
||||
response.headers["Cache-Control"] = f"max-age={IMAGE_MAX_AGE}"
|
||||
return response
|
||||
except Exception as e:
|
||||
@ -239,9 +227,7 @@ async def get_image_urls(
|
||||
|
||||
try:
|
||||
image_url = ApiDependencies.invoker.services.images.get_url(image_name)
|
||||
thumbnail_url = ApiDependencies.invoker.services.images.get_url(
|
||||
image_name, thumbnail=True
|
||||
)
|
||||
thumbnail_url = ApiDependencies.invoker.services.images.get_url(image_name, thumbnail=True)
|
||||
return ImageUrlsDTO(
|
||||
image_name=image_name,
|
||||
image_url=image_url,
|
||||
@ -257,15 +243,9 @@ async def get_image_urls(
|
||||
response_model=OffsetPaginatedResults[ImageDTO],
|
||||
)
|
||||
async def list_image_dtos(
|
||||
image_origin: Optional[ResourceOrigin] = Query(
|
||||
default=None, description="The origin of images to list."
|
||||
),
|
||||
categories: Optional[list[ImageCategory]] = Query(
|
||||
default=None, description="The categories of image to include."
|
||||
),
|
||||
is_intermediate: Optional[bool] = Query(
|
||||
default=None, description="Whether to list intermediate images."
|
||||
),
|
||||
image_origin: Optional[ResourceOrigin] = Query(default=None, description="The origin of images to list."),
|
||||
categories: Optional[list[ImageCategory]] = Query(default=None, description="The categories of image to include."),
|
||||
is_intermediate: Optional[bool] = Query(default=None, description="Whether to list intermediate images."),
|
||||
board_id: Optional[str] = Query(
|
||||
default=None,
|
||||
description="The board id to filter by. Use 'none' to find images without a board.",
|
||||
|
@ -30,9 +30,7 @@ session_router = APIRouter(prefix="/v1/sessions", tags=["sessions"])
|
||||
},
|
||||
)
|
||||
async def create_session(
|
||||
graph: Optional[Graph] = Body(
|
||||
default=None, description="The graph to initialize the session with"
|
||||
)
|
||||
graph: Optional[Graph] = Body(default=None, description="The graph to initialize the session with")
|
||||
) -> GraphExecutionState:
|
||||
"""Creates a new session, optionally initializing it with an invocation graph"""
|
||||
session = ApiDependencies.invoker.create_execution_state(graph)
|
||||
@ -51,13 +49,9 @@ async def list_sessions(
|
||||
) -> PaginatedResults[GraphExecutionState]:
|
||||
"""Gets a list of sessions, optionally searching"""
|
||||
if query == "":
|
||||
result = ApiDependencies.invoker.services.graph_execution_manager.list(
|
||||
page, per_page
|
||||
)
|
||||
result = ApiDependencies.invoker.services.graph_execution_manager.list(page, per_page)
|
||||
else:
|
||||
result = ApiDependencies.invoker.services.graph_execution_manager.search(
|
||||
query, page, per_page
|
||||
)
|
||||
result = ApiDependencies.invoker.services.graph_execution_manager.search(query, page, per_page)
|
||||
return result
|
||||
|
||||
|
||||
@ -91,9 +85,9 @@ async def get_session(
|
||||
)
|
||||
async def add_node(
|
||||
session_id: str = Path(description="The id of the session"),
|
||||
node: Annotated[
|
||||
Union[BaseInvocation.get_invocations()], Field(discriminator="type") # type: ignore
|
||||
] = Body(description="The node to add"),
|
||||
node: Annotated[Union[BaseInvocation.get_invocations()], Field(discriminator="type")] = Body( # type: ignore
|
||||
description="The node to add"
|
||||
),
|
||||
) -> str:
|
||||
"""Adds a node to the graph"""
|
||||
session = ApiDependencies.invoker.services.graph_execution_manager.get(session_id)
|
||||
@ -124,9 +118,9 @@ async def add_node(
|
||||
async def update_node(
|
||||
session_id: str = Path(description="The id of the session"),
|
||||
node_path: str = Path(description="The path to the node in the graph"),
|
||||
node: Annotated[
|
||||
Union[BaseInvocation.get_invocations()], Field(discriminator="type") # type: ignore
|
||||
] = Body(description="The new node"),
|
||||
node: Annotated[Union[BaseInvocation.get_invocations()], Field(discriminator="type")] = Body( # type: ignore
|
||||
description="The new node"
|
||||
),
|
||||
) -> GraphExecutionState:
|
||||
"""Updates a node in the graph and removes all linked edges"""
|
||||
session = ApiDependencies.invoker.services.graph_execution_manager.get(session_id)
|
||||
@ -230,7 +224,7 @@ async def delete_edge(
|
||||
try:
|
||||
edge = Edge(
|
||||
source=EdgeConnection(node_id=from_node_id, field=from_field),
|
||||
destination=EdgeConnection(node_id=to_node_id, field=to_field)
|
||||
destination=EdgeConnection(node_id=to_node_id, field=to_field),
|
||||
)
|
||||
session.delete_edge(edge)
|
||||
ApiDependencies.invoker.services.graph_execution_manager.set(
|
||||
@ -255,9 +249,7 @@ async def delete_edge(
|
||||
)
|
||||
async def invoke_session(
|
||||
session_id: str = Path(description="The id of the session to invoke"),
|
||||
all: bool = Query(
|
||||
default=False, description="Whether or not to invoke all remaining invocations"
|
||||
),
|
||||
all: bool = Query(default=False, description="Whether or not to invoke all remaining invocations"),
|
||||
) -> Response:
|
||||
"""Invokes a session"""
|
||||
session = ApiDependencies.invoker.services.graph_execution_manager.get(session_id)
|
||||
@ -274,9 +266,7 @@ async def invoke_session(
|
||||
@session_router.delete(
|
||||
"/{session_id}/invoke",
|
||||
operation_id="cancel_session_invoke",
|
||||
responses={
|
||||
202: {"description": "The invocation is canceled"}
|
||||
},
|
||||
responses={202: {"description": "The invocation is canceled"}},
|
||||
)
|
||||
async def cancel_session_invoke(
|
||||
session_id: str = Path(description="The id of the session to cancel"),
|
||||
|
@ -16,9 +16,7 @@ class SocketIO:
|
||||
self.__sio.on("subscribe", handler=self._handle_sub)
|
||||
self.__sio.on("unsubscribe", handler=self._handle_unsub)
|
||||
|
||||
local_handler.register(
|
||||
event_name=EventServiceBase.session_event, _func=self._handle_session_event
|
||||
)
|
||||
local_handler.register(event_name=EventServiceBase.session_event, _func=self._handle_session_event)
|
||||
|
||||
async def _handle_session_event(self, event: Event):
|
||||
await self.__sio.emit(
|
||||
|
@ -16,9 +16,10 @@ from fastapi_events.middleware import EventHandlerASGIMiddleware
|
||||
from pathlib import Path
|
||||
from pydantic.schema import schema
|
||||
|
||||
#This should come early so that modules can log their initialization properly
|
||||
# This should come early so that modules can log their initialization properly
|
||||
from .services.config import InvokeAIAppConfig
|
||||
from ..backend.util.logging import InvokeAILogger
|
||||
|
||||
app_config = InvokeAIAppConfig.get_config()
|
||||
app_config.parse_args()
|
||||
logger = InvokeAILogger.getLogger(config=app_config)
|
||||
@ -27,7 +28,7 @@ from invokeai.version.invokeai_version import __version__
|
||||
# we call this early so that the message appears before
|
||||
# other invokeai initialization messages
|
||||
if app_config.version:
|
||||
print(f'InvokeAI version {__version__}')
|
||||
print(f"InvokeAI version {__version__}")
|
||||
sys.exit(0)
|
||||
|
||||
import invokeai.frontend.web as web_dir
|
||||
@ -37,17 +38,18 @@ from .api.dependencies import ApiDependencies
|
||||
from .api.routers import sessions, models, images, boards, board_images, app_info
|
||||
from .api.sockets import SocketIO
|
||||
from .invocations.baseinvocation import BaseInvocation
|
||||
|
||||
|
||||
|
||||
import torch
|
||||
import invokeai.backend.util.hotfixes
|
||||
|
||||
if torch.backends.mps.is_available():
|
||||
import invokeai.backend.util.mps_fixes
|
||||
|
||||
# fix for windows mimetypes registry entries being borked
|
||||
# see https://github.com/invoke-ai/InvokeAI/discussions/3684#discussioncomment-6391352
|
||||
mimetypes.add_type('application/javascript', '.js')
|
||||
mimetypes.add_type('text/css', '.css')
|
||||
mimetypes.add_type("application/javascript", ".js")
|
||||
mimetypes.add_type("text/css", ".css")
|
||||
|
||||
# Create the app
|
||||
# TODO: create this all in a method so configuration/etc. can be passed in?
|
||||
@ -57,14 +59,13 @@ app = FastAPI(title="Invoke AI", docs_url=None, redoc_url=None)
|
||||
event_handler_id: int = id(app)
|
||||
app.add_middleware(
|
||||
EventHandlerASGIMiddleware,
|
||||
handlers=[
|
||||
local_handler
|
||||
], # TODO: consider doing this in services to support different configurations
|
||||
handlers=[local_handler], # TODO: consider doing this in services to support different configurations
|
||||
middleware_id=event_handler_id,
|
||||
)
|
||||
|
||||
socket_io = SocketIO(app)
|
||||
|
||||
|
||||
# Add startup event to load dependencies
|
||||
@app.on_event("startup")
|
||||
async def startup_event():
|
||||
@ -76,9 +77,7 @@ async def startup_event():
|
||||
allow_headers=app_config.allow_headers,
|
||||
)
|
||||
|
||||
ApiDependencies.initialize(
|
||||
config=app_config, event_handler_id=event_handler_id, logger=logger
|
||||
)
|
||||
ApiDependencies.initialize(config=app_config, event_handler_id=event_handler_id, logger=logger)
|
||||
|
||||
|
||||
# Shut down threads
|
||||
@ -103,7 +102,8 @@ app.include_router(boards.boards_router, prefix="/api")
|
||||
|
||||
app.include_router(board_images.board_images_router, prefix="/api")
|
||||
|
||||
app.include_router(app_info.app_router, prefix='/api')
|
||||
app.include_router(app_info.app_router, prefix="/api")
|
||||
|
||||
|
||||
# Build a custom OpenAPI to include all outputs
|
||||
# TODO: can outputs be included on metadata of invocation schemas somehow?
|
||||
@ -144,6 +144,7 @@ def custom_openapi():
|
||||
invoker_schema["output"] = outputs_ref
|
||||
|
||||
from invokeai.backend.model_management.models import get_model_config_enums
|
||||
|
||||
for model_config_format_enum in set(get_model_config_enums()):
|
||||
name = model_config_format_enum.__qualname__
|
||||
|
||||
@ -166,7 +167,8 @@ def custom_openapi():
|
||||
app.openapi = custom_openapi
|
||||
|
||||
# Override API doc favicons
|
||||
app.mount("/static", StaticFiles(directory=Path(web_dir.__path__[0], 'static/dream_web')), name="static")
|
||||
app.mount("/static", StaticFiles(directory=Path(web_dir.__path__[0], "static/dream_web")), name="static")
|
||||
|
||||
|
||||
@app.get("/docs", include_in_schema=False)
|
||||
def overridden_swagger():
|
||||
@ -187,11 +189,8 @@ def overridden_redoc():
|
||||
|
||||
|
||||
# Must mount *after* the other routes else it borks em
|
||||
app.mount("/",
|
||||
StaticFiles(directory=Path(web_dir.__path__[0],"dist"),
|
||||
html=True
|
||||
), name="ui"
|
||||
)
|
||||
app.mount("/", StaticFiles(directory=Path(web_dir.__path__[0], "dist"), html=True), name="ui")
|
||||
|
||||
|
||||
def invoke_api():
|
||||
def find_port(port: int):
|
||||
@ -204,6 +203,10 @@ def invoke_api():
|
||||
else:
|
||||
return port
|
||||
|
||||
from invokeai.backend.install.check_root import check_invokeai_root
|
||||
|
||||
check_invokeai_root(app_config) # note, may exit with an exception if root not set up
|
||||
|
||||
port = find_port(app_config.port)
|
||||
if port != app_config.port:
|
||||
logger.warn(f"Port {app_config.port} in use, using port {port}")
|
||||
@ -214,5 +217,6 @@ def invoke_api():
|
||||
server = uvicorn.Server(config)
|
||||
loop.run_until_complete(server.serve())
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
invoke_api()
|
||||
|
Before Width: | Height: | Size: 33 KiB After Width: | Height: | Size: 33 KiB |
@ -14,8 +14,14 @@ from ..services.graph import GraphExecutionState, LibraryGraph, Edge
|
||||
from ..services.invoker import Invoker
|
||||
|
||||
|
||||
def add_field_argument(command_parser, name: str, field, default_override = None):
|
||||
default = default_override if default_override is not None else field.default if field.default_factory is None else field.default_factory()
|
||||
def add_field_argument(command_parser, name: str, field, default_override=None):
|
||||
default = (
|
||||
default_override
|
||||
if default_override is not None
|
||||
else field.default
|
||||
if field.default_factory is None
|
||||
else field.default_factory()
|
||||
)
|
||||
if get_origin(field.type_) == Literal:
|
||||
allowed_values = get_args(field.type_)
|
||||
allowed_types = set()
|
||||
@ -47,8 +53,8 @@ def add_parsers(
|
||||
commands: list[type],
|
||||
command_field: str = "type",
|
||||
exclude_fields: list[str] = ["id", "type"],
|
||||
add_arguments: Union[Callable[[argparse.ArgumentParser], None],None] = None
|
||||
):
|
||||
add_arguments: Union[Callable[[argparse.ArgumentParser], None], None] = None,
|
||||
):
|
||||
"""Adds parsers for each command to the subparsers"""
|
||||
|
||||
# Create subparsers for each command
|
||||
@ -61,7 +67,7 @@ def add_parsers(
|
||||
add_arguments(command_parser)
|
||||
|
||||
# Convert all fields to arguments
|
||||
fields = command.__fields__ # type: ignore
|
||||
fields = command.__fields__ # type: ignore
|
||||
for name, field in fields.items():
|
||||
if name in exclude_fields:
|
||||
continue
|
||||
@ -70,13 +76,11 @@ def add_parsers(
|
||||
|
||||
|
||||
def add_graph_parsers(
|
||||
subparsers,
|
||||
graphs: list[LibraryGraph],
|
||||
add_arguments: Union[Callable[[argparse.ArgumentParser], None], None] = None
|
||||
subparsers, graphs: list[LibraryGraph], add_arguments: Union[Callable[[argparse.ArgumentParser], None], None] = None
|
||||
):
|
||||
for graph in graphs:
|
||||
command_parser = subparsers.add_parser(graph.name, help=graph.description)
|
||||
|
||||
|
||||
if add_arguments is not None:
|
||||
add_arguments(command_parser)
|
||||
|
||||
@ -128,6 +132,7 @@ class CliContext:
|
||||
|
||||
class ExitCli(Exception):
|
||||
"""Exception to exit the CLI"""
|
||||
|
||||
pass
|
||||
|
||||
|
||||
@ -155,7 +160,7 @@ class BaseCommand(ABC, BaseModel):
|
||||
@classmethod
|
||||
def get_commands_map(cls):
|
||||
# Get the type strings out of the literals and into a dictionary
|
||||
return dict(map(lambda t: (get_args(get_type_hints(t)['type'])[0], t),BaseCommand.get_all_subclasses()))
|
||||
return dict(map(lambda t: (get_args(get_type_hints(t)["type"])[0], t), BaseCommand.get_all_subclasses()))
|
||||
|
||||
@abstractmethod
|
||||
def run(self, context: CliContext) -> None:
|
||||
@ -165,7 +170,8 @@ class BaseCommand(ABC, BaseModel):
|
||||
|
||||
class ExitCommand(BaseCommand):
|
||||
"""Exits the CLI"""
|
||||
type: Literal['exit'] = 'exit'
|
||||
|
||||
type: Literal["exit"] = "exit"
|
||||
|
||||
def run(self, context: CliContext) -> None:
|
||||
raise ExitCli()
|
||||
@ -173,7 +179,8 @@ class ExitCommand(BaseCommand):
|
||||
|
||||
class HelpCommand(BaseCommand):
|
||||
"""Shows help"""
|
||||
type: Literal['help'] = 'help'
|
||||
|
||||
type: Literal["help"] = "help"
|
||||
|
||||
def run(self, context: CliContext) -> None:
|
||||
context.parser.print_help()
|
||||
@ -183,11 +190,7 @@ def get_graph_execution_history(
|
||||
graph_execution_state: GraphExecutionState,
|
||||
) -> Iterable[str]:
|
||||
"""Gets the history of fully-executed invocations for a graph execution"""
|
||||
return (
|
||||
n
|
||||
for n in reversed(graph_execution_state.executed_history)
|
||||
if n in graph_execution_state.graph.nodes
|
||||
)
|
||||
return (n for n in reversed(graph_execution_state.executed_history) if n in graph_execution_state.graph.nodes)
|
||||
|
||||
|
||||
def get_invocation_command(invocation) -> str:
|
||||
@ -218,7 +221,8 @@ def get_invocation_command(invocation) -> str:
|
||||
|
||||
class HistoryCommand(BaseCommand):
|
||||
"""Shows the invocation history"""
|
||||
type: Literal['history'] = 'history'
|
||||
|
||||
type: Literal["history"] = "history"
|
||||
|
||||
# Inputs
|
||||
# fmt: off
|
||||
@ -235,7 +239,8 @@ class HistoryCommand(BaseCommand):
|
||||
|
||||
class SetDefaultCommand(BaseCommand):
|
||||
"""Sets a default value for a field"""
|
||||
type: Literal['default'] = 'default'
|
||||
|
||||
type: Literal["default"] = "default"
|
||||
|
||||
# Inputs
|
||||
# fmt: off
|
||||
@ -253,7 +258,8 @@ class SetDefaultCommand(BaseCommand):
|
||||
|
||||
class DrawGraphCommand(BaseCommand):
|
||||
"""Debugs a graph"""
|
||||
type: Literal['draw_graph'] = 'draw_graph'
|
||||
|
||||
type: Literal["draw_graph"] = "draw_graph"
|
||||
|
||||
def run(self, context: CliContext) -> None:
|
||||
session: GraphExecutionState = context.invoker.services.graph_execution_manager.get(context.session.id)
|
||||
@ -271,7 +277,8 @@ class DrawGraphCommand(BaseCommand):
|
||||
|
||||
class DrawExecutionGraphCommand(BaseCommand):
|
||||
"""Debugs an execution graph"""
|
||||
type: Literal['draw_xgraph'] = 'draw_xgraph'
|
||||
|
||||
type: Literal["draw_xgraph"] = "draw_xgraph"
|
||||
|
||||
def run(self, context: CliContext) -> None:
|
||||
session: GraphExecutionState = context.invoker.services.graph_execution_manager.get(context.session.id)
|
||||
@ -286,6 +293,7 @@ class DrawExecutionGraphCommand(BaseCommand):
|
||||
plt.axis("off")
|
||||
plt.show()
|
||||
|
||||
|
||||
class SortedHelpFormatter(argparse.HelpFormatter):
|
||||
def _iter_indented_subactions(self, action):
|
||||
try:
|
||||
|
@ -19,8 +19,8 @@ from ..services.invocation_services import InvocationServices
|
||||
# singleton object, class variable
|
||||
completer = None
|
||||
|
||||
|
||||
class Completer(object):
|
||||
|
||||
def __init__(self, model_manager: ModelManager):
|
||||
self.commands = self.get_commands()
|
||||
self.matches = None
|
||||
@ -43,7 +43,7 @@ class Completer(object):
|
||||
except IndexError:
|
||||
pass
|
||||
options = options or list(self.parse_commands().keys())
|
||||
|
||||
|
||||
if not text: # first time
|
||||
self.matches = options
|
||||
else:
|
||||
@ -56,17 +56,17 @@ class Completer(object):
|
||||
return match
|
||||
|
||||
@classmethod
|
||||
def get_commands(self)->List[object]:
|
||||
def get_commands(self) -> List[object]:
|
||||
"""
|
||||
Return a list of all the client commands and invocations.
|
||||
"""
|
||||
return BaseCommand.get_commands() + BaseInvocation.get_invocations()
|
||||
|
||||
def get_current_command(self, buffer: str)->tuple[str, str]:
|
||||
def get_current_command(self, buffer: str) -> tuple[str, str]:
|
||||
"""
|
||||
Parse the readline buffer to find the most recent command and its switch.
|
||||
"""
|
||||
if len(buffer)==0:
|
||||
if len(buffer) == 0:
|
||||
return None, None
|
||||
tokens = shlex.split(buffer)
|
||||
command = None
|
||||
@ -78,11 +78,11 @@ class Completer(object):
|
||||
else:
|
||||
switch = t
|
||||
# don't try to autocomplete switches that are already complete
|
||||
if switch and buffer.endswith(' '):
|
||||
switch=None
|
||||
return command or '', switch or ''
|
||||
if switch and buffer.endswith(" "):
|
||||
switch = None
|
||||
return command or "", switch or ""
|
||||
|
||||
def parse_commands(self)->Dict[str, List[str]]:
|
||||
def parse_commands(self) -> Dict[str, List[str]]:
|
||||
"""
|
||||
Return a dict in which the keys are the command name
|
||||
and the values are the parameters the command takes.
|
||||
@ -90,11 +90,11 @@ class Completer(object):
|
||||
result = dict()
|
||||
for command in self.commands:
|
||||
hints = get_type_hints(command)
|
||||
name = get_args(hints['type'])[0]
|
||||
result.update({name:hints})
|
||||
name = get_args(hints["type"])[0]
|
||||
result.update({name: hints})
|
||||
return result
|
||||
|
||||
def get_command_options(self, command: str, switch: str)->List[str]:
|
||||
def get_command_options(self, command: str, switch: str) -> List[str]:
|
||||
"""
|
||||
Return all the parameters that can be passed to the command as
|
||||
command-line switches. Returns None if the command is unrecognized.
|
||||
@ -102,42 +102,46 @@ class Completer(object):
|
||||
parsed_commands = self.parse_commands()
|
||||
if command not in parsed_commands:
|
||||
return None
|
||||
|
||||
|
||||
# handle switches in the format "-foo=bar"
|
||||
argument = None
|
||||
if switch and '=' in switch:
|
||||
switch, argument = switch.split('=')
|
||||
|
||||
parameter = switch.strip('-')
|
||||
if switch and "=" in switch:
|
||||
switch, argument = switch.split("=")
|
||||
|
||||
parameter = switch.strip("-")
|
||||
if parameter in parsed_commands[command]:
|
||||
if argument is None:
|
||||
return self.get_parameter_options(parameter, parsed_commands[command][parameter])
|
||||
else:
|
||||
return [f"--{parameter}={x}" for x in self.get_parameter_options(parameter, parsed_commands[command][parameter])]
|
||||
return [
|
||||
f"--{parameter}={x}"
|
||||
for x in self.get_parameter_options(parameter, parsed_commands[command][parameter])
|
||||
]
|
||||
else:
|
||||
return [f"--{x}" for x in parsed_commands[command].keys()]
|
||||
|
||||
def get_parameter_options(self, parameter: str, typehint)->List[str]:
|
||||
def get_parameter_options(self, parameter: str, typehint) -> List[str]:
|
||||
"""
|
||||
Given a parameter type (such as Literal), offers autocompletions.
|
||||
"""
|
||||
if get_origin(typehint) == Literal:
|
||||
return get_args(typehint)
|
||||
if parameter == 'model':
|
||||
if parameter == "model":
|
||||
return self.manager.model_names()
|
||||
|
||||
|
||||
def _pre_input_hook(self):
|
||||
if self.linebuffer:
|
||||
readline.insert_text(self.linebuffer)
|
||||
readline.redisplay()
|
||||
self.linebuffer = None
|
||||
|
||||
|
||||
|
||||
def set_autocompleter(services: InvocationServices) -> Completer:
|
||||
global completer
|
||||
|
||||
|
||||
if completer:
|
||||
return completer
|
||||
|
||||
|
||||
completer = Completer(services.model_manager)
|
||||
|
||||
readline.set_completer(completer.complete)
|
||||
@ -162,8 +166,6 @@ def set_autocompleter(services: InvocationServices) -> Completer:
|
||||
pass
|
||||
except OSError: # file likely corrupted
|
||||
newname = f"{histfile}.old"
|
||||
logger.error(
|
||||
f"Your history file {histfile} couldn't be loaded and may be corrupted. Renaming it to {newname}"
|
||||
)
|
||||
logger.error(f"Your history file {histfile} couldn't be loaded and may be corrupted. Renaming it to {newname}")
|
||||
histfile.replace(Path(newname))
|
||||
atexit.register(readline.write_history_file, histfile)
|
||||
|
@ -13,6 +13,7 @@ from pydantic.fields import Field
|
||||
# This should come early so that the logger can pick up its configuration options
|
||||
from .services.config import InvokeAIAppConfig
|
||||
from invokeai.backend.util.logging import InvokeAILogger
|
||||
|
||||
config = InvokeAIAppConfig.get_config()
|
||||
config.parse_args()
|
||||
logger = InvokeAILogger().getLogger(config=config)
|
||||
@ -20,7 +21,7 @@ from invokeai.version.invokeai_version import __version__
|
||||
|
||||
# we call this early so that the message appears before other invokeai initialization messages
|
||||
if config.version:
|
||||
print(f'InvokeAI version {__version__}')
|
||||
print(f"InvokeAI version {__version__}")
|
||||
sys.exit(0)
|
||||
|
||||
from invokeai.app.services.board_image_record_storage import (
|
||||
@ -36,18 +37,21 @@ from invokeai.app.services.image_record_storage import SqliteImageRecordStorage
|
||||
from invokeai.app.services.images import ImageService, ImageServiceDependencies
|
||||
from invokeai.app.services.resource_name import SimpleNameService
|
||||
from invokeai.app.services.urls import LocalUrlService
|
||||
from .services.default_graphs import (default_text_to_image_graph_id,
|
||||
create_system_graphs)
|
||||
from .services.default_graphs import default_text_to_image_graph_id, create_system_graphs
|
||||
from .services.latent_storage import DiskLatentsStorage, ForwardCacheLatentsStorage
|
||||
|
||||
from .cli.commands import (BaseCommand, CliContext, ExitCli,
|
||||
SortedHelpFormatter, add_graph_parsers, add_parsers)
|
||||
from .cli.commands import BaseCommand, CliContext, ExitCli, SortedHelpFormatter, add_graph_parsers, add_parsers
|
||||
from .cli.completer import set_autocompleter
|
||||
from .invocations.baseinvocation import BaseInvocation
|
||||
from .services.events import EventServiceBase
|
||||
from .services.graph import (Edge, EdgeConnection, GraphExecutionState,
|
||||
GraphInvocation, LibraryGraph,
|
||||
are_connection_types_compatible)
|
||||
from .services.graph import (
|
||||
Edge,
|
||||
EdgeConnection,
|
||||
GraphExecutionState,
|
||||
GraphInvocation,
|
||||
LibraryGraph,
|
||||
are_connection_types_compatible,
|
||||
)
|
||||
from .services.image_file_storage import DiskImageFileStorage
|
||||
from .services.invocation_queue import MemoryInvocationQueue
|
||||
from .services.invocation_services import InvocationServices
|
||||
@ -58,6 +62,7 @@ from .services.sqlite import SqliteItemStorage
|
||||
|
||||
import torch
|
||||
import invokeai.backend.util.hotfixes
|
||||
|
||||
if torch.backends.mps.is_available():
|
||||
import invokeai.backend.util.mps_fixes
|
||||
|
||||
@ -69,6 +74,7 @@ class CliCommand(BaseModel):
|
||||
class InvalidArgs(Exception):
|
||||
pass
|
||||
|
||||
|
||||
def add_invocation_args(command_parser):
|
||||
# Add linking capability
|
||||
command_parser.add_argument(
|
||||
@ -113,7 +119,7 @@ def get_command_parser(services: InvocationServices) -> argparse.ArgumentParser:
|
||||
return parser
|
||||
|
||||
|
||||
class NodeField():
|
||||
class NodeField:
|
||||
alias: str
|
||||
node_path: str
|
||||
field: str
|
||||
@ -126,15 +132,20 @@ class NodeField():
|
||||
self.field_type = field_type
|
||||
|
||||
|
||||
def fields_from_type_hints(hints: dict[str, type], node_path: str) -> dict[str,NodeField]:
|
||||
return {k:NodeField(alias=k, node_path=node_path, field=k, field_type=v) for k, v in hints.items()}
|
||||
def fields_from_type_hints(hints: dict[str, type], node_path: str) -> dict[str, NodeField]:
|
||||
return {k: NodeField(alias=k, node_path=node_path, field=k, field_type=v) for k, v in hints.items()}
|
||||
|
||||
|
||||
def get_node_input_field(graph: LibraryGraph, field_alias: str, node_id: str) -> NodeField:
|
||||
"""Gets the node field for the specified field alias"""
|
||||
exposed_input = next(e for e in graph.exposed_inputs if e.alias == field_alias)
|
||||
node_type = type(graph.graph.get_node(exposed_input.node_path))
|
||||
return NodeField(alias=exposed_input.alias, node_path=f'{node_id}.{exposed_input.node_path}', field=exposed_input.field, field_type=get_type_hints(node_type)[exposed_input.field])
|
||||
return NodeField(
|
||||
alias=exposed_input.alias,
|
||||
node_path=f"{node_id}.{exposed_input.node_path}",
|
||||
field=exposed_input.field,
|
||||
field_type=get_type_hints(node_type)[exposed_input.field],
|
||||
)
|
||||
|
||||
|
||||
def get_node_output_field(graph: LibraryGraph, field_alias: str, node_id: str) -> NodeField:
|
||||
@ -142,7 +153,12 @@ def get_node_output_field(graph: LibraryGraph, field_alias: str, node_id: str) -
|
||||
exposed_output = next(e for e in graph.exposed_outputs if e.alias == field_alias)
|
||||
node_type = type(graph.graph.get_node(exposed_output.node_path))
|
||||
node_output_type = node_type.get_output_type()
|
||||
return NodeField(alias=exposed_output.alias, node_path=f'{node_id}.{exposed_output.node_path}', field=exposed_output.field, field_type=get_type_hints(node_output_type)[exposed_output.field])
|
||||
return NodeField(
|
||||
alias=exposed_output.alias,
|
||||
node_path=f"{node_id}.{exposed_output.node_path}",
|
||||
field=exposed_output.field,
|
||||
field_type=get_type_hints(node_output_type)[exposed_output.field],
|
||||
)
|
||||
|
||||
|
||||
def get_node_inputs(invocation: BaseInvocation, context: CliContext) -> dict[str, NodeField]:
|
||||
@ -165,9 +181,7 @@ def get_node_outputs(invocation: BaseInvocation, context: CliContext) -> dict[st
|
||||
return {e.alias: get_node_output_field(graph, e.alias, invocation.id) for e in graph.exposed_outputs}
|
||||
|
||||
|
||||
def generate_matching_edges(
|
||||
a: BaseInvocation, b: BaseInvocation, context: CliContext
|
||||
) -> list[Edge]:
|
||||
def generate_matching_edges(a: BaseInvocation, b: BaseInvocation, context: CliContext) -> list[Edge]:
|
||||
"""Generates all possible edges between two invocations"""
|
||||
afields = get_node_outputs(a, context)
|
||||
bfields = get_node_inputs(b, context)
|
||||
@ -179,12 +193,14 @@ def generate_matching_edges(
|
||||
matching_fields = matching_fields.difference(invalid_fields)
|
||||
|
||||
# Validate types
|
||||
matching_fields = [f for f in matching_fields if are_connection_types_compatible(afields[f].field_type, bfields[f].field_type)]
|
||||
matching_fields = [
|
||||
f for f in matching_fields if are_connection_types_compatible(afields[f].field_type, bfields[f].field_type)
|
||||
]
|
||||
|
||||
edges = [
|
||||
Edge(
|
||||
source=EdgeConnection(node_id=afields[alias].node_path, field=afields[alias].field),
|
||||
destination=EdgeConnection(node_id=bfields[alias].node_path, field=bfields[alias].field)
|
||||
destination=EdgeConnection(node_id=bfields[alias].node_path, field=bfields[alias].field),
|
||||
)
|
||||
for alias in matching_fields
|
||||
]
|
||||
@ -193,6 +209,7 @@ def generate_matching_edges(
|
||||
|
||||
class SessionError(Exception):
|
||||
"""Raised when a session error has occurred"""
|
||||
|
||||
pass
|
||||
|
||||
|
||||
@ -209,22 +226,23 @@ def invoke_all(context: CliContext):
|
||||
context.invoker.services.logger.error(
|
||||
f"Error in node {n} (source node {context.session.prepared_source_mapping[n]}): {context.session.errors[n]}"
|
||||
)
|
||||
|
||||
|
||||
raise SessionError()
|
||||
|
||||
|
||||
def invoke_cli():
|
||||
logger.info(f'InvokeAI version {__version__}')
|
||||
logger.info(f"InvokeAI version {__version__}")
|
||||
# get the optional list of invocations to execute on the command line
|
||||
parser = config.get_parser()
|
||||
parser.add_argument('commands',nargs='*')
|
||||
parser.add_argument("commands", nargs="*")
|
||||
invocation_commands = parser.parse_args().commands
|
||||
|
||||
# get the optional file to read commands from.
|
||||
# Simplest is to use it for STDIN
|
||||
if infile := config.from_file:
|
||||
sys.stdin = open(infile,"r")
|
||||
|
||||
model_manager = ModelManagerService(config,logger)
|
||||
sys.stdin = open(infile, "r")
|
||||
|
||||
model_manager = ModelManagerService(config, logger)
|
||||
|
||||
events = EventServiceBase()
|
||||
output_folder = config.output_path
|
||||
@ -234,13 +252,13 @@ def invoke_cli():
|
||||
db_location = ":memory:"
|
||||
else:
|
||||
db_location = config.db_path
|
||||
db_location.parent.mkdir(parents=True,exist_ok=True)
|
||||
db_location.parent.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
logger.info(f'InvokeAI database location is "{db_location}"')
|
||||
|
||||
graph_execution_manager = SqliteItemStorage[GraphExecutionState](
|
||||
filename=db_location, table_name="graph_executions"
|
||||
)
|
||||
filename=db_location, table_name="graph_executions"
|
||||
)
|
||||
|
||||
urls = LocalUrlService()
|
||||
image_record_storage = SqliteImageRecordStorage(db_location)
|
||||
@ -281,24 +299,21 @@ def invoke_cli():
|
||||
graph_execution_manager=graph_execution_manager,
|
||||
)
|
||||
)
|
||||
|
||||
|
||||
services = InvocationServices(
|
||||
model_manager=model_manager,
|
||||
events=events,
|
||||
latents = ForwardCacheLatentsStorage(DiskLatentsStorage(f'{output_folder}/latents')),
|
||||
latents=ForwardCacheLatentsStorage(DiskLatentsStorage(f"{output_folder}/latents")),
|
||||
images=images,
|
||||
boards=boards,
|
||||
board_images=board_images,
|
||||
queue=MemoryInvocationQueue(),
|
||||
graph_library=SqliteItemStorage[LibraryGraph](
|
||||
filename=db_location, table_name="graphs"
|
||||
),
|
||||
graph_library=SqliteItemStorage[LibraryGraph](filename=db_location, table_name="graphs"),
|
||||
graph_execution_manager=graph_execution_manager,
|
||||
processor=DefaultInvocationProcessor(),
|
||||
logger=logger,
|
||||
configuration=config,
|
||||
)
|
||||
|
||||
|
||||
system_graphs = create_system_graphs(services.graph_library)
|
||||
system_graph_names = set([g.name for g in system_graphs])
|
||||
@ -308,7 +323,7 @@ def invoke_cli():
|
||||
session: GraphExecutionState = invoker.create_execution_state()
|
||||
parser = get_command_parser(services)
|
||||
|
||||
re_negid = re.compile('^-[0-9]+$')
|
||||
re_negid = re.compile("^-[0-9]+$")
|
||||
|
||||
# Uncomment to print out previous sessions at startup
|
||||
# print(services.session_manager.list())
|
||||
@ -318,7 +333,7 @@ def invoke_cli():
|
||||
|
||||
command_line_args_exist = len(invocation_commands) > 0
|
||||
done = False
|
||||
|
||||
|
||||
while not done:
|
||||
try:
|
||||
if command_line_args_exist:
|
||||
@ -332,7 +347,7 @@ def invoke_cli():
|
||||
|
||||
try:
|
||||
# Refresh the state of the session
|
||||
#history = list(get_graph_execution_history(context.session))
|
||||
# history = list(get_graph_execution_history(context.session))
|
||||
history = list(reversed(context.nodes_added))
|
||||
|
||||
# Split the command for piping
|
||||
@ -353,17 +368,17 @@ def invoke_cli():
|
||||
args[field_name] = field_default
|
||||
|
||||
# Parse invocation
|
||||
command: CliCommand = None # type:ignore
|
||||
command: CliCommand = None # type:ignore
|
||||
system_graph: Optional[LibraryGraph] = None
|
||||
if args['type'] in system_graph_names:
|
||||
system_graph = next(filter(lambda g: g.name == args['type'], system_graphs))
|
||||
if args["type"] in system_graph_names:
|
||||
system_graph = next(filter(lambda g: g.name == args["type"], system_graphs))
|
||||
invocation = GraphInvocation(graph=system_graph.graph, id=str(current_id))
|
||||
for exposed_input in system_graph.exposed_inputs:
|
||||
if exposed_input.alias in args:
|
||||
node = invocation.graph.get_node(exposed_input.node_path)
|
||||
field = exposed_input.field
|
||||
setattr(node, field, args[exposed_input.alias])
|
||||
command = CliCommand(command = invocation)
|
||||
command = CliCommand(command=invocation)
|
||||
context.graph_nodes[invocation.id] = system_graph.id
|
||||
else:
|
||||
args["id"] = current_id
|
||||
@ -385,17 +400,13 @@ def invoke_cli():
|
||||
# Pipe previous command output (if there was a previous command)
|
||||
edges: list[Edge] = list()
|
||||
if len(history) > 0 or current_id != start_id:
|
||||
from_id = (
|
||||
history[0] if current_id == start_id else str(current_id - 1)
|
||||
)
|
||||
from_id = history[0] if current_id == start_id else str(current_id - 1)
|
||||
from_node = (
|
||||
next(filter(lambda n: n[0].id == from_id, new_invocations))[0]
|
||||
if current_id != start_id
|
||||
else context.session.graph.get_node(from_id)
|
||||
)
|
||||
matching_edges = generate_matching_edges(
|
||||
from_node, command.command, context
|
||||
)
|
||||
matching_edges = generate_matching_edges(from_node, command.command, context)
|
||||
edges.extend(matching_edges)
|
||||
|
||||
# Parse provided links
|
||||
@ -406,16 +417,18 @@ def invoke_cli():
|
||||
node_id = str(current_id + int(node_id))
|
||||
|
||||
link_node = context.session.graph.get_node(node_id)
|
||||
matching_edges = generate_matching_edges(
|
||||
link_node, command.command, context
|
||||
)
|
||||
matching_edges = generate_matching_edges(link_node, command.command, context)
|
||||
matching_destinations = [e.destination for e in matching_edges]
|
||||
edges = [e for e in edges if e.destination not in matching_destinations]
|
||||
edges.extend(matching_edges)
|
||||
|
||||
if "link" in args and args["link"]:
|
||||
for link in args["link"]:
|
||||
edges = [e for e in edges if e.destination.node_id != command.command.id or e.destination.field != link[2]]
|
||||
edges = [
|
||||
e
|
||||
for e in edges
|
||||
if e.destination.node_id != command.command.id or e.destination.field != link[2]
|
||||
]
|
||||
|
||||
node_id = link[0]
|
||||
if re_negid.match(node_id):
|
||||
@ -428,7 +441,7 @@ def invoke_cli():
|
||||
edges.append(
|
||||
Edge(
|
||||
source=EdgeConnection(node_id=node_output.node_path, field=node_output.field),
|
||||
destination=EdgeConnection(node_id=node_input.node_path, field=node_input.field)
|
||||
destination=EdgeConnection(node_id=node_input.node_path, field=node_input.field),
|
||||
)
|
||||
)
|
||||
|
||||
|
@ -4,9 +4,5 @@ __all__ = []
|
||||
|
||||
dirname = os.path.dirname(os.path.abspath(__file__))
|
||||
for f in os.listdir(dirname):
|
||||
if (
|
||||
f != "__init__.py"
|
||||
and os.path.isfile("%s/%s" % (dirname, f))
|
||||
and f[-3:] == ".py"
|
||||
):
|
||||
if f != "__init__.py" and os.path.isfile("%s/%s" % (dirname, f)) and f[-3:] == ".py":
|
||||
__all__.append(f[:-3])
|
||||
|
@ -4,8 +4,7 @@ from __future__ import annotations
|
||||
|
||||
from abc import ABC, abstractmethod
|
||||
from inspect import signature
|
||||
from typing import (TYPE_CHECKING, Dict, List, Literal, TypedDict, get_args,
|
||||
get_type_hints)
|
||||
from typing import TYPE_CHECKING, Dict, List, Literal, TypedDict, get_args, get_type_hints
|
||||
|
||||
from pydantic import BaseConfig, BaseModel, Field
|
||||
|
||||
|
@ -8,8 +8,7 @@ from pydantic import Field, validator
|
||||
from invokeai.app.models.image import ImageField
|
||||
from invokeai.app.util.misc import SEED_MAX, get_random_seed
|
||||
|
||||
from .baseinvocation import (BaseInvocation, BaseInvocationOutput,
|
||||
InvocationConfig, InvocationContext, UIConfig)
|
||||
from .baseinvocation import BaseInvocation, BaseInvocationOutput, InvocationConfig, InvocationContext, UIConfig
|
||||
|
||||
|
||||
class IntCollectionOutput(BaseInvocationOutput):
|
||||
@ -27,8 +26,7 @@ class FloatCollectionOutput(BaseInvocationOutput):
|
||||
type: Literal["float_collection"] = "float_collection"
|
||||
|
||||
# Outputs
|
||||
collection: list[float] = Field(
|
||||
default=[], description="The float collection")
|
||||
collection: list[float] = Field(default=[], description="The float collection")
|
||||
|
||||
|
||||
class ImageCollectionOutput(BaseInvocationOutput):
|
||||
@ -37,8 +35,7 @@ class ImageCollectionOutput(BaseInvocationOutput):
|
||||
type: Literal["image_collection"] = "image_collection"
|
||||
|
||||
# Outputs
|
||||
collection: list[ImageField] = Field(
|
||||
default=[], description="The output images")
|
||||
collection: list[ImageField] = Field(default=[], description="The output images")
|
||||
|
||||
class Config:
|
||||
schema_extra = {"required": ["type", "collection"]}
|
||||
@ -56,10 +53,7 @@ class RangeInvocation(BaseInvocation):
|
||||
|
||||
class Config(InvocationConfig):
|
||||
schema_extra = {
|
||||
"ui": {
|
||||
"title": "Range",
|
||||
"tags": ["range", "integer", "collection"]
|
||||
},
|
||||
"ui": {"title": "Range", "tags": ["range", "integer", "collection"]},
|
||||
}
|
||||
|
||||
@validator("stop")
|
||||
@ -69,9 +63,7 @@ class RangeInvocation(BaseInvocation):
|
||||
return v
|
||||
|
||||
def invoke(self, context: InvocationContext) -> IntCollectionOutput:
|
||||
return IntCollectionOutput(
|
||||
collection=list(range(self.start, self.stop, self.step))
|
||||
)
|
||||
return IntCollectionOutput(collection=list(range(self.start, self.stop, self.step)))
|
||||
|
||||
|
||||
class RangeOfSizeInvocation(BaseInvocation):
|
||||
@ -86,18 +78,11 @@ class RangeOfSizeInvocation(BaseInvocation):
|
||||
|
||||
class Config(InvocationConfig):
|
||||
schema_extra = {
|
||||
"ui": {
|
||||
"title": "Sized Range",
|
||||
"tags": ["range", "integer", "size", "collection"]
|
||||
},
|
||||
"ui": {"title": "Sized Range", "tags": ["range", "integer", "size", "collection"]},
|
||||
}
|
||||
|
||||
def invoke(self, context: InvocationContext) -> IntCollectionOutput:
|
||||
return IntCollectionOutput(
|
||||
collection=list(
|
||||
range(
|
||||
self.start, self.start + self.size,
|
||||
self.step)))
|
||||
return IntCollectionOutput(collection=list(range(self.start, self.start + self.size, self.step)))
|
||||
|
||||
|
||||
class RandomRangeInvocation(BaseInvocation):
|
||||
@ -107,9 +92,7 @@ class RandomRangeInvocation(BaseInvocation):
|
||||
|
||||
# Inputs
|
||||
low: int = Field(default=0, description="The inclusive low value")
|
||||
high: int = Field(
|
||||
default=np.iinfo(np.int32).max, description="The exclusive high value"
|
||||
)
|
||||
high: int = Field(default=np.iinfo(np.int32).max, description="The exclusive high value")
|
||||
size: int = Field(default=1, description="The number of values to generate")
|
||||
seed: int = Field(
|
||||
ge=0,
|
||||
@ -120,19 +103,12 @@ class RandomRangeInvocation(BaseInvocation):
|
||||
|
||||
class Config(InvocationConfig):
|
||||
schema_extra = {
|
||||
"ui": {
|
||||
"title": "Random Range",
|
||||
"tags": ["range", "integer", "random", "collection"]
|
||||
},
|
||||
"ui": {"title": "Random Range", "tags": ["range", "integer", "random", "collection"]},
|
||||
}
|
||||
|
||||
def invoke(self, context: InvocationContext) -> IntCollectionOutput:
|
||||
rng = np.random.default_rng(self.seed)
|
||||
return IntCollectionOutput(
|
||||
collection=list(
|
||||
rng.integers(
|
||||
low=self.low, high=self.high,
|
||||
size=self.size)))
|
||||
return IntCollectionOutput(collection=list(rng.integers(low=self.low, high=self.high, size=self.size)))
|
||||
|
||||
|
||||
class ImageCollectionInvocation(BaseInvocation):
|
||||
|
@ -3,64 +3,63 @@ from pydantic import BaseModel, Field
|
||||
import re
|
||||
import torch
|
||||
from compel import Compel, ReturnedEmbeddingsType
|
||||
from compel.prompt_parser import (Blend, Conjunction,
|
||||
CrossAttentionControlSubstitute,
|
||||
FlattenedPrompt, Fragment)
|
||||
from compel.prompt_parser import Blend, Conjunction, CrossAttentionControlSubstitute, FlattenedPrompt, Fragment
|
||||
from ...backend.util.devices import torch_dtype
|
||||
from ...backend.model_management import ModelType
|
||||
from ...backend.model_management.models import ModelNotFoundException
|
||||
from ...backend.model_management.lora import ModelPatcher
|
||||
from ...backend.stable_diffusion.diffusion import InvokeAIDiffuserComponent
|
||||
from .baseinvocation import (BaseInvocation, BaseInvocationOutput,
|
||||
InvocationConfig, InvocationContext)
|
||||
from .baseinvocation import BaseInvocation, BaseInvocationOutput, InvocationConfig, InvocationContext
|
||||
from .model import ClipField
|
||||
from dataclasses import dataclass
|
||||
|
||||
|
||||
class ConditioningField(BaseModel):
|
||||
conditioning_name: Optional[str] = Field(
|
||||
default=None, description="The name of conditioning data")
|
||||
conditioning_name: Optional[str] = Field(default=None, description="The name of conditioning data")
|
||||
|
||||
class Config:
|
||||
schema_extra = {"required": ["conditioning_name"]}
|
||||
|
||||
|
||||
@dataclass
|
||||
class BasicConditioningInfo:
|
||||
#type: Literal["basic_conditioning"] = "basic_conditioning"
|
||||
# type: Literal["basic_conditioning"] = "basic_conditioning"
|
||||
embeds: torch.Tensor
|
||||
extra_conditioning: Optional[InvokeAIDiffuserComponent.ExtraConditioningInfo]
|
||||
# weight: float
|
||||
# mode: ConditioningAlgo
|
||||
|
||||
|
||||
@dataclass
|
||||
class SDXLConditioningInfo(BasicConditioningInfo):
|
||||
#type: Literal["sdxl_conditioning"] = "sdxl_conditioning"
|
||||
# type: Literal["sdxl_conditioning"] = "sdxl_conditioning"
|
||||
pooled_embeds: torch.Tensor
|
||||
add_time_ids: torch.Tensor
|
||||
|
||||
ConditioningInfoType = Annotated[
|
||||
Union[BasicConditioningInfo, SDXLConditioningInfo],
|
||||
Field(discriminator="type")
|
||||
]
|
||||
|
||||
ConditioningInfoType = Annotated[Union[BasicConditioningInfo, SDXLConditioningInfo], Field(discriminator="type")]
|
||||
|
||||
|
||||
@dataclass
|
||||
class ConditioningFieldData:
|
||||
conditionings: List[Union[BasicConditioningInfo, SDXLConditioningInfo]]
|
||||
#unconditioned: Optional[torch.Tensor]
|
||||
# unconditioned: Optional[torch.Tensor]
|
||||
|
||||
#class ConditioningAlgo(str, Enum):
|
||||
|
||||
# class ConditioningAlgo(str, Enum):
|
||||
# Compose = "compose"
|
||||
# ComposeEx = "compose_ex"
|
||||
# PerpNeg = "perp_neg"
|
||||
|
||||
|
||||
class CompelOutput(BaseInvocationOutput):
|
||||
"""Compel parser output"""
|
||||
|
||||
#fmt: off
|
||||
# fmt: off
|
||||
type: Literal["compel_output"] = "compel_output"
|
||||
|
||||
conditioning: ConditioningField = Field(default=None, description="Conditioning")
|
||||
#fmt: on
|
||||
# fmt: on
|
||||
|
||||
|
||||
class CompelInvocation(BaseInvocation):
|
||||
@ -74,33 +73,28 @@ class CompelInvocation(BaseInvocation):
|
||||
# Schema customisation
|
||||
class Config(InvocationConfig):
|
||||
schema_extra = {
|
||||
"ui": {
|
||||
"title": "Prompt (Compel)",
|
||||
"tags": ["prompt", "compel"],
|
||||
"type_hints": {
|
||||
"model": "model"
|
||||
}
|
||||
},
|
||||
"ui": {"title": "Prompt (Compel)", "tags": ["prompt", "compel"], "type_hints": {"model": "model"}},
|
||||
}
|
||||
|
||||
@torch.no_grad()
|
||||
def invoke(self, context: InvocationContext) -> CompelOutput:
|
||||
tokenizer_info = context.services.model_manager.get_model(
|
||||
**self.clip.tokenizer.dict(), context=context,
|
||||
**self.clip.tokenizer.dict(),
|
||||
context=context,
|
||||
)
|
||||
text_encoder_info = context.services.model_manager.get_model(
|
||||
**self.clip.text_encoder.dict(), context=context,
|
||||
**self.clip.text_encoder.dict(),
|
||||
context=context,
|
||||
)
|
||||
|
||||
def _lora_loader():
|
||||
for lora in self.clip.loras:
|
||||
lora_info = context.services.model_manager.get_model(
|
||||
**lora.dict(exclude={"weight"}))
|
||||
lora_info = context.services.model_manager.get_model(**lora.dict(exclude={"weight"}), context=context)
|
||||
yield (lora_info.context.model, lora.weight)
|
||||
del lora_info
|
||||
return
|
||||
|
||||
#loras = [(context.services.model_manager.get_model(**lora.dict(exclude={"weight"})).context.model, lora.weight) for lora in self.clip.loras]
|
||||
# loras = [(context.services.model_manager.get_model(**lora.dict(exclude={"weight"})).context.model, lora.weight) for lora in self.clip.loras]
|
||||
|
||||
ti_list = []
|
||||
for trigger in re.findall(r"<[a-zA-Z0-9., _-]+>", self.prompt):
|
||||
@ -116,15 +110,18 @@ class CompelInvocation(BaseInvocation):
|
||||
)
|
||||
except ModelNotFoundException:
|
||||
# print(e)
|
||||
#import traceback
|
||||
#print(traceback.format_exc())
|
||||
print(f"Warn: trigger: \"{trigger}\" not found")
|
||||
|
||||
with ModelPatcher.apply_lora_text_encoder(text_encoder_info.context.model, _lora_loader()),\
|
||||
ModelPatcher.apply_ti(tokenizer_info.context.model, text_encoder_info.context.model, ti_list) as (tokenizer, ti_manager),\
|
||||
ModelPatcher.apply_clip_skip(text_encoder_info.context.model, self.clip.skipped_layers),\
|
||||
text_encoder_info as text_encoder:
|
||||
# import traceback
|
||||
# print(traceback.format_exc())
|
||||
print(f'Warn: trigger: "{trigger}" not found')
|
||||
|
||||
with ModelPatcher.apply_lora_text_encoder(
|
||||
text_encoder_info.context.model, _lora_loader()
|
||||
), ModelPatcher.apply_ti(tokenizer_info.context.model, text_encoder_info.context.model, ti_list) as (
|
||||
tokenizer,
|
||||
ti_manager,
|
||||
), ModelPatcher.apply_clip_skip(
|
||||
text_encoder_info.context.model, self.clip.skipped_layers
|
||||
), text_encoder_info as text_encoder:
|
||||
compel = Compel(
|
||||
tokenizer=tokenizer,
|
||||
text_encoder=text_encoder,
|
||||
@ -139,14 +136,12 @@ class CompelInvocation(BaseInvocation):
|
||||
if context.services.configuration.log_tokenization:
|
||||
log_tokenization_for_prompt_object(prompt, tokenizer)
|
||||
|
||||
c, options = compel.build_conditioning_tensor_for_prompt_object(
|
||||
prompt)
|
||||
c, options = compel.build_conditioning_tensor_for_prompt_object(prompt)
|
||||
|
||||
ec = InvokeAIDiffuserComponent.ExtraConditioningInfo(
|
||||
tokens_count_including_eos_bos=get_max_token_count(
|
||||
tokenizer, conjunction),
|
||||
cross_attention_control_args=options.get(
|
||||
"cross_attention_control", None),)
|
||||
tokens_count_including_eos_bos=get_max_token_count(tokenizer, conjunction),
|
||||
cross_attention_control_args=options.get("cross_attention_control", None),
|
||||
)
|
||||
|
||||
c = c.detach().to("cpu")
|
||||
|
||||
@ -168,24 +163,26 @@ class CompelInvocation(BaseInvocation):
|
||||
),
|
||||
)
|
||||
|
||||
|
||||
class SDXLPromptInvocationBase:
|
||||
def run_clip_raw(self, context, clip_field, prompt, get_pooled):
|
||||
tokenizer_info = context.services.model_manager.get_model(
|
||||
**clip_field.tokenizer.dict(),
|
||||
context=context,
|
||||
)
|
||||
text_encoder_info = context.services.model_manager.get_model(
|
||||
**clip_field.text_encoder.dict(),
|
||||
context=context,
|
||||
)
|
||||
|
||||
def _lora_loader():
|
||||
for lora in clip_field.loras:
|
||||
lora_info = context.services.model_manager.get_model(
|
||||
**lora.dict(exclude={"weight"}))
|
||||
lora_info = context.services.model_manager.get_model(**lora.dict(exclude={"weight"}), context=context)
|
||||
yield (lora_info.context.model, lora.weight)
|
||||
del lora_info
|
||||
return
|
||||
|
||||
#loras = [(context.services.model_manager.get_model(**lora.dict(exclude={"weight"})).context.model, lora.weight) for lora in self.clip.loras]
|
||||
# loras = [(context.services.model_manager.get_model(**lora.dict(exclude={"weight"})).context.model, lora.weight) for lora in self.clip.loras]
|
||||
|
||||
ti_list = []
|
||||
for trigger in re.findall(r"<[a-zA-Z0-9., _-]+>", prompt):
|
||||
@ -196,19 +193,23 @@ class SDXLPromptInvocationBase:
|
||||
model_name=name,
|
||||
base_model=clip_field.text_encoder.base_model,
|
||||
model_type=ModelType.TextualInversion,
|
||||
context=context,
|
||||
).context.model
|
||||
)
|
||||
except ModelNotFoundException:
|
||||
# print(e)
|
||||
#import traceback
|
||||
#print(traceback.format_exc())
|
||||
print(f"Warn: trigger: \"{trigger}\" not found")
|
||||
|
||||
with ModelPatcher.apply_lora_text_encoder(text_encoder_info.context.model, _lora_loader()),\
|
||||
ModelPatcher.apply_ti(tokenizer_info.context.model, text_encoder_info.context.model, ti_list) as (tokenizer, ti_manager),\
|
||||
ModelPatcher.apply_clip_skip(text_encoder_info.context.model, clip_field.skipped_layers),\
|
||||
text_encoder_info as text_encoder:
|
||||
# import traceback
|
||||
# print(traceback.format_exc())
|
||||
print(f'Warn: trigger: "{trigger}" not found')
|
||||
|
||||
with ModelPatcher.apply_lora_text_encoder(
|
||||
text_encoder_info.context.model, _lora_loader()
|
||||
), ModelPatcher.apply_ti(tokenizer_info.context.model, text_encoder_info.context.model, ti_list) as (
|
||||
tokenizer,
|
||||
ti_manager,
|
||||
), ModelPatcher.apply_clip_skip(
|
||||
text_encoder_info.context.model, clip_field.skipped_layers
|
||||
), text_encoder_info as text_encoder:
|
||||
text_inputs = tokenizer(
|
||||
prompt,
|
||||
padding="max_length",
|
||||
@ -241,20 +242,21 @@ class SDXLPromptInvocationBase:
|
||||
def run_clip_compel(self, context, clip_field, prompt, get_pooled):
|
||||
tokenizer_info = context.services.model_manager.get_model(
|
||||
**clip_field.tokenizer.dict(),
|
||||
context=context,
|
||||
)
|
||||
text_encoder_info = context.services.model_manager.get_model(
|
||||
**clip_field.text_encoder.dict(),
|
||||
context=context,
|
||||
)
|
||||
|
||||
def _lora_loader():
|
||||
for lora in clip_field.loras:
|
||||
lora_info = context.services.model_manager.get_model(
|
||||
**lora.dict(exclude={"weight"}))
|
||||
lora_info = context.services.model_manager.get_model(**lora.dict(exclude={"weight"}), context=context)
|
||||
yield (lora_info.context.model, lora.weight)
|
||||
del lora_info
|
||||
return
|
||||
|
||||
#loras = [(context.services.model_manager.get_model(**lora.dict(exclude={"weight"})).context.model, lora.weight) for lora in self.clip.loras]
|
||||
# loras = [(context.services.model_manager.get_model(**lora.dict(exclude={"weight"})).context.model, lora.weight) for lora in self.clip.loras]
|
||||
|
||||
ti_list = []
|
||||
for trigger in re.findall(r"<[a-zA-Z0-9., _-]+>", prompt):
|
||||
@ -265,26 +267,30 @@ class SDXLPromptInvocationBase:
|
||||
model_name=name,
|
||||
base_model=clip_field.text_encoder.base_model,
|
||||
model_type=ModelType.TextualInversion,
|
||||
context=context,
|
||||
).context.model
|
||||
)
|
||||
except ModelNotFoundException:
|
||||
# print(e)
|
||||
#import traceback
|
||||
#print(traceback.format_exc())
|
||||
print(f"Warn: trigger: \"{trigger}\" not found")
|
||||
|
||||
with ModelPatcher.apply_lora_text_encoder(text_encoder_info.context.model, _lora_loader()),\
|
||||
ModelPatcher.apply_ti(tokenizer_info.context.model, text_encoder_info.context.model, ti_list) as (tokenizer, ti_manager),\
|
||||
ModelPatcher.apply_clip_skip(text_encoder_info.context.model, clip_field.skipped_layers),\
|
||||
text_encoder_info as text_encoder:
|
||||
# import traceback
|
||||
# print(traceback.format_exc())
|
||||
print(f'Warn: trigger: "{trigger}" not found')
|
||||
|
||||
with ModelPatcher.apply_lora_text_encoder(
|
||||
text_encoder_info.context.model, _lora_loader()
|
||||
), ModelPatcher.apply_ti(tokenizer_info.context.model, text_encoder_info.context.model, ti_list) as (
|
||||
tokenizer,
|
||||
ti_manager,
|
||||
), ModelPatcher.apply_clip_skip(
|
||||
text_encoder_info.context.model, clip_field.skipped_layers
|
||||
), text_encoder_info as text_encoder:
|
||||
compel = Compel(
|
||||
tokenizer=tokenizer,
|
||||
text_encoder=text_encoder,
|
||||
textual_inversion_manager=ti_manager,
|
||||
dtype_for_device_getter=torch_dtype,
|
||||
truncate_long_prompts=True, # TODO:
|
||||
returned_embeddings_type=ReturnedEmbeddingsType.PENULTIMATE_HIDDEN_STATES_NON_NORMALIZED, # TODO: clip skip
|
||||
returned_embeddings_type=ReturnedEmbeddingsType.PENULTIMATE_HIDDEN_STATES_NON_NORMALIZED, # TODO: clip skip
|
||||
requires_pooled=True,
|
||||
)
|
||||
|
||||
@ -318,6 +324,7 @@ class SDXLPromptInvocationBase:
|
||||
|
||||
return c, c_pooled, ec
|
||||
|
||||
|
||||
class SDXLCompelPromptInvocation(BaseInvocation, SDXLPromptInvocationBase):
|
||||
"""Parse prompt using compel package to conditioning."""
|
||||
|
||||
@ -337,13 +344,7 @@ class SDXLCompelPromptInvocation(BaseInvocation, SDXLPromptInvocationBase):
|
||||
# Schema customisation
|
||||
class Config(InvocationConfig):
|
||||
schema_extra = {
|
||||
"ui": {
|
||||
"title": "SDXL Prompt (Compel)",
|
||||
"tags": ["prompt", "compel"],
|
||||
"type_hints": {
|
||||
"model": "model"
|
||||
}
|
||||
},
|
||||
"ui": {"title": "SDXL Prompt (Compel)", "tags": ["prompt", "compel"], "type_hints": {"model": "model"}},
|
||||
}
|
||||
|
||||
@torch.no_grad()
|
||||
@ -358,9 +359,7 @@ class SDXLCompelPromptInvocation(BaseInvocation, SDXLPromptInvocationBase):
|
||||
crop_coords = (self.crop_top, self.crop_left)
|
||||
target_size = (self.target_height, self.target_width)
|
||||
|
||||
add_time_ids = torch.tensor([
|
||||
original_size + crop_coords + target_size
|
||||
])
|
||||
add_time_ids = torch.tensor([original_size + crop_coords + target_size])
|
||||
|
||||
conditioning_data = ConditioningFieldData(
|
||||
conditionings=[
|
||||
@ -382,12 +381,13 @@ class SDXLCompelPromptInvocation(BaseInvocation, SDXLPromptInvocationBase):
|
||||
),
|
||||
)
|
||||
|
||||
|
||||
class SDXLRefinerCompelPromptInvocation(BaseInvocation, SDXLPromptInvocationBase):
|
||||
"""Parse prompt using compel package to conditioning."""
|
||||
|
||||
type: Literal["sdxl_refiner_compel_prompt"] = "sdxl_refiner_compel_prompt"
|
||||
|
||||
style: str = Field(default="", description="Style prompt") # TODO: ?
|
||||
style: str = Field(default="", description="Style prompt") # TODO: ?
|
||||
original_width: int = Field(1024, description="")
|
||||
original_height: int = Field(1024, description="")
|
||||
crop_top: int = Field(0, description="")
|
||||
@ -401,9 +401,7 @@ class SDXLRefinerCompelPromptInvocation(BaseInvocation, SDXLPromptInvocationBase
|
||||
"ui": {
|
||||
"title": "SDXL Refiner Prompt (Compel)",
|
||||
"tags": ["prompt", "compel"],
|
||||
"type_hints": {
|
||||
"model": "model"
|
||||
}
|
||||
"type_hints": {"model": "model"},
|
||||
},
|
||||
}
|
||||
|
||||
@ -414,9 +412,7 @@ class SDXLRefinerCompelPromptInvocation(BaseInvocation, SDXLPromptInvocationBase
|
||||
original_size = (self.original_height, self.original_width)
|
||||
crop_coords = (self.crop_top, self.crop_left)
|
||||
|
||||
add_time_ids = torch.tensor([
|
||||
original_size + crop_coords + (self.aesthetic_score,)
|
||||
])
|
||||
add_time_ids = torch.tensor([original_size + crop_coords + (self.aesthetic_score,)])
|
||||
|
||||
conditioning_data = ConditioningFieldData(
|
||||
conditionings=[
|
||||
@ -424,7 +420,7 @@ class SDXLRefinerCompelPromptInvocation(BaseInvocation, SDXLPromptInvocationBase
|
||||
embeds=c2,
|
||||
pooled_embeds=c2_pooled,
|
||||
add_time_ids=add_time_ids,
|
||||
extra_conditioning=ec2, # or None
|
||||
extra_conditioning=ec2, # or None
|
||||
)
|
||||
]
|
||||
)
|
||||
@ -438,6 +434,7 @@ class SDXLRefinerCompelPromptInvocation(BaseInvocation, SDXLPromptInvocationBase
|
||||
),
|
||||
)
|
||||
|
||||
|
||||
class SDXLRawPromptInvocation(BaseInvocation, SDXLPromptInvocationBase):
|
||||
"""Pass unmodified prompt to conditioning without compel processing."""
|
||||
|
||||
@ -457,13 +454,7 @@ class SDXLRawPromptInvocation(BaseInvocation, SDXLPromptInvocationBase):
|
||||
# Schema customisation
|
||||
class Config(InvocationConfig):
|
||||
schema_extra = {
|
||||
"ui": {
|
||||
"title": "SDXL Prompt (Raw)",
|
||||
"tags": ["prompt", "compel"],
|
||||
"type_hints": {
|
||||
"model": "model"
|
||||
}
|
||||
},
|
||||
"ui": {"title": "SDXL Prompt (Raw)", "tags": ["prompt", "compel"], "type_hints": {"model": "model"}},
|
||||
}
|
||||
|
||||
@torch.no_grad()
|
||||
@ -478,9 +469,7 @@ class SDXLRawPromptInvocation(BaseInvocation, SDXLPromptInvocationBase):
|
||||
crop_coords = (self.crop_top, self.crop_left)
|
||||
target_size = (self.target_height, self.target_width)
|
||||
|
||||
add_time_ids = torch.tensor([
|
||||
original_size + crop_coords + target_size
|
||||
])
|
||||
add_time_ids = torch.tensor([original_size + crop_coords + target_size])
|
||||
|
||||
conditioning_data = ConditioningFieldData(
|
||||
conditionings=[
|
||||
@ -502,12 +491,13 @@ class SDXLRawPromptInvocation(BaseInvocation, SDXLPromptInvocationBase):
|
||||
),
|
||||
)
|
||||
|
||||
|
||||
class SDXLRefinerRawPromptInvocation(BaseInvocation, SDXLPromptInvocationBase):
|
||||
"""Parse prompt using compel package to conditioning."""
|
||||
|
||||
type: Literal["sdxl_refiner_raw_prompt"] = "sdxl_refiner_raw_prompt"
|
||||
|
||||
style: str = Field(default="", description="Style prompt") # TODO: ?
|
||||
style: str = Field(default="", description="Style prompt") # TODO: ?
|
||||
original_width: int = Field(1024, description="")
|
||||
original_height: int = Field(1024, description="")
|
||||
crop_top: int = Field(0, description="")
|
||||
@ -521,9 +511,7 @@ class SDXLRefinerRawPromptInvocation(BaseInvocation, SDXLPromptInvocationBase):
|
||||
"ui": {
|
||||
"title": "SDXL Refiner Prompt (Raw)",
|
||||
"tags": ["prompt", "compel"],
|
||||
"type_hints": {
|
||||
"model": "model"
|
||||
}
|
||||
"type_hints": {"model": "model"},
|
||||
},
|
||||
}
|
||||
|
||||
@ -534,9 +522,7 @@ class SDXLRefinerRawPromptInvocation(BaseInvocation, SDXLPromptInvocationBase):
|
||||
original_size = (self.original_height, self.original_width)
|
||||
crop_coords = (self.crop_top, self.crop_left)
|
||||
|
||||
add_time_ids = torch.tensor([
|
||||
original_size + crop_coords + (self.aesthetic_score,)
|
||||
])
|
||||
add_time_ids = torch.tensor([original_size + crop_coords + (self.aesthetic_score,)])
|
||||
|
||||
conditioning_data = ConditioningFieldData(
|
||||
conditionings=[
|
||||
@ -544,7 +530,7 @@ class SDXLRefinerRawPromptInvocation(BaseInvocation, SDXLPromptInvocationBase):
|
||||
embeds=c2,
|
||||
pooled_embeds=c2_pooled,
|
||||
add_time_ids=add_time_ids,
|
||||
extra_conditioning=ec2, # or None
|
||||
extra_conditioning=ec2, # or None
|
||||
)
|
||||
]
|
||||
)
|
||||
@ -561,11 +547,14 @@ class SDXLRefinerRawPromptInvocation(BaseInvocation, SDXLPromptInvocationBase):
|
||||
|
||||
class ClipSkipInvocationOutput(BaseInvocationOutput):
|
||||
"""Clip skip node output"""
|
||||
|
||||
type: Literal["clip_skip_output"] = "clip_skip_output"
|
||||
clip: ClipField = Field(None, description="Clip with skipped layers")
|
||||
|
||||
|
||||
class ClipSkipInvocation(BaseInvocation):
|
||||
"""Skip layers in clip text_encoder model."""
|
||||
|
||||
type: Literal["clip_skip"] = "clip_skip"
|
||||
|
||||
clip: ClipField = Field(None, description="Clip to use")
|
||||
@ -573,10 +562,7 @@ class ClipSkipInvocation(BaseInvocation):
|
||||
|
||||
class Config(InvocationConfig):
|
||||
schema_extra = {
|
||||
"ui": {
|
||||
"title": "CLIP Skip",
|
||||
"tags": ["clip", "skip"]
|
||||
},
|
||||
"ui": {"title": "CLIP Skip", "tags": ["clip", "skip"]},
|
||||
}
|
||||
|
||||
def invoke(self, context: InvocationContext) -> ClipSkipInvocationOutput:
|
||||
@ -587,46 +573,26 @@ class ClipSkipInvocation(BaseInvocation):
|
||||
|
||||
|
||||
def get_max_token_count(
|
||||
tokenizer, prompt: Union[FlattenedPrompt, Blend, Conjunction],
|
||||
truncate_if_too_long=False) -> int:
|
||||
tokenizer, prompt: Union[FlattenedPrompt, Blend, Conjunction], truncate_if_too_long=False
|
||||
) -> int:
|
||||
if type(prompt) is Blend:
|
||||
blend: Blend = prompt
|
||||
return max(
|
||||
[
|
||||
get_max_token_count(tokenizer, p, truncate_if_too_long)
|
||||
for p in blend.prompts
|
||||
]
|
||||
)
|
||||
return max([get_max_token_count(tokenizer, p, truncate_if_too_long) for p in blend.prompts])
|
||||
elif type(prompt) is Conjunction:
|
||||
conjunction: Conjunction = prompt
|
||||
return sum(
|
||||
[
|
||||
get_max_token_count(tokenizer, p, truncate_if_too_long)
|
||||
for p in conjunction.prompts
|
||||
]
|
||||
)
|
||||
return sum([get_max_token_count(tokenizer, p, truncate_if_too_long) for p in conjunction.prompts])
|
||||
else:
|
||||
return len(
|
||||
get_tokens_for_prompt_object(
|
||||
tokenizer, prompt, truncate_if_too_long))
|
||||
return len(get_tokens_for_prompt_object(tokenizer, prompt, truncate_if_too_long))
|
||||
|
||||
|
||||
def get_tokens_for_prompt_object(
|
||||
tokenizer, parsed_prompt: FlattenedPrompt, truncate_if_too_long=True
|
||||
) -> List[str]:
|
||||
def get_tokens_for_prompt_object(tokenizer, parsed_prompt: FlattenedPrompt, truncate_if_too_long=True) -> List[str]:
|
||||
if type(parsed_prompt) is Blend:
|
||||
raise ValueError(
|
||||
"Blend is not supported here - you need to get tokens for each of its .children"
|
||||
)
|
||||
raise ValueError("Blend is not supported here - you need to get tokens for each of its .children")
|
||||
|
||||
text_fragments = [
|
||||
x.text
|
||||
if type(x) is Fragment
|
||||
else (
|
||||
" ".join([f.text for f in x.original])
|
||||
if type(x) is CrossAttentionControlSubstitute
|
||||
else str(x)
|
||||
)
|
||||
else (" ".join([f.text for f in x.original]) if type(x) is CrossAttentionControlSubstitute else str(x))
|
||||
for x in parsed_prompt.children
|
||||
]
|
||||
text = " ".join(text_fragments)
|
||||
@ -637,25 +603,17 @@ def get_tokens_for_prompt_object(
|
||||
return tokens
|
||||
|
||||
|
||||
def log_tokenization_for_conjunction(
|
||||
c: Conjunction, tokenizer, display_label_prefix=None
|
||||
):
|
||||
def log_tokenization_for_conjunction(c: Conjunction, tokenizer, display_label_prefix=None):
|
||||
display_label_prefix = display_label_prefix or ""
|
||||
for i, p in enumerate(c.prompts):
|
||||
if len(c.prompts) > 1:
|
||||
this_display_label_prefix = f"{display_label_prefix}(conjunction part {i + 1}, weight={c.weights[i]})"
|
||||
else:
|
||||
this_display_label_prefix = display_label_prefix
|
||||
log_tokenization_for_prompt_object(
|
||||
p,
|
||||
tokenizer,
|
||||
display_label_prefix=this_display_label_prefix
|
||||
)
|
||||
log_tokenization_for_prompt_object(p, tokenizer, display_label_prefix=this_display_label_prefix)
|
||||
|
||||
|
||||
def log_tokenization_for_prompt_object(
|
||||
p: Union[Blend, FlattenedPrompt], tokenizer, display_label_prefix=None
|
||||
):
|
||||
def log_tokenization_for_prompt_object(p: Union[Blend, FlattenedPrompt], tokenizer, display_label_prefix=None):
|
||||
display_label_prefix = display_label_prefix or ""
|
||||
if type(p) is Blend:
|
||||
blend: Blend = p
|
||||
@ -692,13 +650,10 @@ def log_tokenization_for_prompt_object(
|
||||
)
|
||||
else:
|
||||
text = " ".join([x.text for x in flattened_prompt.children])
|
||||
log_tokenization_for_text(
|
||||
text, tokenizer, display_label=display_label_prefix
|
||||
)
|
||||
log_tokenization_for_text(text, tokenizer, display_label=display_label_prefix)
|
||||
|
||||
|
||||
def log_tokenization_for_text(
|
||||
text, tokenizer, display_label=None, truncate_if_too_long=False):
|
||||
def log_tokenization_for_text(text, tokenizer, display_label=None, truncate_if_too_long=False):
|
||||
"""shows how the prompt is tokenized
|
||||
# usually tokens have '</w>' to indicate end-of-word,
|
||||
# but for readability it has been replaced with ' '
|
||||
|
@ -6,21 +6,30 @@ from typing import Dict, List, Literal, Optional, Union
|
||||
|
||||
import cv2
|
||||
import numpy as np
|
||||
from controlnet_aux import (CannyDetector, ContentShuffleDetector, HEDdetector,
|
||||
LeresDetector, LineartAnimeDetector,
|
||||
LineartDetector, MediapipeFaceDetector,
|
||||
MidasDetector, MLSDdetector, NormalBaeDetector,
|
||||
OpenposeDetector, PidiNetDetector, SamDetector,
|
||||
ZoeDetector)
|
||||
from controlnet_aux import (
|
||||
CannyDetector,
|
||||
ContentShuffleDetector,
|
||||
HEDdetector,
|
||||
LeresDetector,
|
||||
LineartAnimeDetector,
|
||||
LineartDetector,
|
||||
MediapipeFaceDetector,
|
||||
MidasDetector,
|
||||
MLSDdetector,
|
||||
NormalBaeDetector,
|
||||
OpenposeDetector,
|
||||
PidiNetDetector,
|
||||
SamDetector,
|
||||
ZoeDetector,
|
||||
)
|
||||
from controlnet_aux.util import HWC3, ade_palette
|
||||
from PIL import Image
|
||||
from pydantic import BaseModel, Field, validator
|
||||
|
||||
from ...backend.model_management import BaseModelType, ModelType
|
||||
from ..models.image import ImageCategory, ImageField, ResourceOrigin
|
||||
from .baseinvocation import (BaseInvocation, BaseInvocationOutput,
|
||||
InvocationConfig, InvocationContext)
|
||||
from .image import ImageOutput, PILInvocationConfig
|
||||
from .baseinvocation import BaseInvocation, BaseInvocationOutput, InvocationConfig, InvocationContext
|
||||
from ..models.image import ImageOutput, PILInvocationConfig
|
||||
|
||||
CONTROLNET_DEFAULT_MODELS = [
|
||||
###########################################
|
||||
@ -34,7 +43,6 @@ CONTROLNET_DEFAULT_MODELS = [
|
||||
"lllyasviel/sd-controlnet-scribble",
|
||||
"lllyasviel/sd-controlnet-normal",
|
||||
"lllyasviel/sd-controlnet-mlsd",
|
||||
|
||||
#############################################
|
||||
# lllyasviel sd v1.5, ControlNet v1.1 models
|
||||
#############################################
|
||||
@ -56,7 +64,6 @@ CONTROLNET_DEFAULT_MODELS = [
|
||||
"lllyasviel/control_v11e_sd15_shuffle",
|
||||
"lllyasviel/control_v11e_sd15_ip2p",
|
||||
"lllyasviel/control_v11f1e_sd15_tile",
|
||||
|
||||
#################################################
|
||||
# thibaud sd v2.1 models (ControlNet v1.0? or v1.1?
|
||||
##################################################
|
||||
@ -71,7 +78,6 @@ CONTROLNET_DEFAULT_MODELS = [
|
||||
"thibaud/controlnet-sd21-lineart-diffusers",
|
||||
"thibaud/controlnet-sd21-normalbae-diffusers",
|
||||
"thibaud/controlnet-sd21-ade20k-diffusers",
|
||||
|
||||
##############################################
|
||||
# ControlNetMediaPipeface, ControlNet v1.1
|
||||
##############################################
|
||||
@ -83,10 +89,17 @@ CONTROLNET_DEFAULT_MODELS = [
|
||||
]
|
||||
|
||||
CONTROLNET_NAME_VALUES = Literal[tuple(CONTROLNET_DEFAULT_MODELS)]
|
||||
CONTROLNET_MODE_VALUES = Literal[tuple(
|
||||
["balanced", "more_prompt", "more_control", "unbalanced"])]
|
||||
CONTROLNET_RESIZE_VALUES = Literal[tuple(
|
||||
["just_resize", "crop_resize", "fill_resize", "just_resize_simple",])]
|
||||
CONTROLNET_MODE_VALUES = Literal[tuple(["balanced", "more_prompt", "more_control", "unbalanced"])]
|
||||
CONTROLNET_RESIZE_VALUES = Literal[
|
||||
tuple(
|
||||
[
|
||||
"just_resize",
|
||||
"crop_resize",
|
||||
"fill_resize",
|
||||
"just_resize_simple",
|
||||
]
|
||||
)
|
||||
]
|
||||
|
||||
|
||||
class ControlNetModelField(BaseModel):
|
||||
@ -98,21 +111,17 @@ class ControlNetModelField(BaseModel):
|
||||
|
||||
class ControlField(BaseModel):
|
||||
image: ImageField = Field(default=None, description="The control image")
|
||||
control_model: Optional[ControlNetModelField] = Field(
|
||||
default=None, description="The ControlNet model to use")
|
||||
control_model: Optional[ControlNetModelField] = Field(default=None, description="The ControlNet model to use")
|
||||
# control_weight: Optional[float] = Field(default=1, description="weight given to controlnet")
|
||||
control_weight: Union[float, List[float]] = Field(
|
||||
default=1, description="The weight given to the ControlNet")
|
||||
control_weight: Union[float, List[float]] = Field(default=1, description="The weight given to the ControlNet")
|
||||
begin_step_percent: float = Field(
|
||||
default=0, ge=0, le=1,
|
||||
description="When the ControlNet is first applied (% of total steps)")
|
||||
default=0, ge=0, le=1, description="When the ControlNet is first applied (% of total steps)"
|
||||
)
|
||||
end_step_percent: float = Field(
|
||||
default=1, ge=0, le=1,
|
||||
description="When the ControlNet is last applied (% of total steps)")
|
||||
control_mode: CONTROLNET_MODE_VALUES = Field(
|
||||
default="balanced", description="The control mode to use")
|
||||
resize_mode: CONTROLNET_RESIZE_VALUES = Field(
|
||||
default="just_resize", description="The resize mode to use")
|
||||
default=1, ge=0, le=1, description="When the ControlNet is last applied (% of total steps)"
|
||||
)
|
||||
control_mode: CONTROLNET_MODE_VALUES = Field(default="balanced", description="The control mode to use")
|
||||
resize_mode: CONTROLNET_RESIZE_VALUES = Field(default="just_resize", description="The resize mode to use")
|
||||
|
||||
@validator("control_weight")
|
||||
def validate_control_weight(cls, v):
|
||||
@ -120,11 +129,10 @@ class ControlField(BaseModel):
|
||||
if isinstance(v, list):
|
||||
for i in v:
|
||||
if i < -1 or i > 2:
|
||||
raise ValueError(
|
||||
'Control weights must be within -1 to 2 range')
|
||||
raise ValueError("Control weights must be within -1 to 2 range")
|
||||
else:
|
||||
if v < -1 or v > 2:
|
||||
raise ValueError('Control weights must be within -1 to 2 range')
|
||||
raise ValueError("Control weights must be within -1 to 2 range")
|
||||
return v
|
||||
|
||||
class Config:
|
||||
@ -136,12 +144,13 @@ class ControlField(BaseModel):
|
||||
"control_model": "controlnet_model",
|
||||
# "control_weight": "number",
|
||||
}
|
||||
}
|
||||
},
|
||||
}
|
||||
|
||||
|
||||
class ControlOutput(BaseInvocationOutput):
|
||||
"""node output for ControlNet info"""
|
||||
|
||||
# fmt: off
|
||||
type: Literal["control_output"] = "control_output"
|
||||
control: ControlField = Field(default=None, description="The control info")
|
||||
@ -150,6 +159,7 @@ class ControlOutput(BaseInvocationOutput):
|
||||
|
||||
class ControlNetInvocation(BaseInvocation):
|
||||
"""Collects ControlNet info to pass to other nodes"""
|
||||
|
||||
# fmt: off
|
||||
type: Literal["controlnet"] = "controlnet"
|
||||
# Inputs
|
||||
@ -176,7 +186,7 @@ class ControlNetInvocation(BaseInvocation):
|
||||
# "cfg_scale": "float",
|
||||
"cfg_scale": "number",
|
||||
"control_weight": "float",
|
||||
}
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
@ -205,10 +215,7 @@ class ImageProcessorInvocation(BaseInvocation, PILInvocationConfig):
|
||||
|
||||
class Config(InvocationConfig):
|
||||
schema_extra = {
|
||||
"ui": {
|
||||
"title": "Image Processor",
|
||||
"tags": ["image", "processor"]
|
||||
},
|
||||
"ui": {"title": "Image Processor", "tags": ["image", "processor"]},
|
||||
}
|
||||
|
||||
def run_processor(self, image):
|
||||
@ -233,7 +240,7 @@ class ImageProcessorInvocation(BaseInvocation, PILInvocationConfig):
|
||||
image_category=ImageCategory.CONTROL,
|
||||
session_id=context.graph_execution_state_id,
|
||||
node_id=self.id,
|
||||
is_intermediate=self.is_intermediate
|
||||
is_intermediate=self.is_intermediate,
|
||||
)
|
||||
|
||||
"""Builds an ImageOutput and its ImageField"""
|
||||
@ -248,9 +255,9 @@ class ImageProcessorInvocation(BaseInvocation, PILInvocationConfig):
|
||||
)
|
||||
|
||||
|
||||
class CannyImageProcessorInvocation(
|
||||
ImageProcessorInvocation, PILInvocationConfig):
|
||||
class CannyImageProcessorInvocation(ImageProcessorInvocation, PILInvocationConfig):
|
||||
"""Canny edge detection for ControlNet"""
|
||||
|
||||
# fmt: off
|
||||
type: Literal["canny_image_processor"] = "canny_image_processor"
|
||||
# Input
|
||||
@ -260,22 +267,18 @@ class CannyImageProcessorInvocation(
|
||||
|
||||
class Config(InvocationConfig):
|
||||
schema_extra = {
|
||||
"ui": {
|
||||
"title": "Canny Processor",
|
||||
"tags": ["controlnet", "canny", "image", "processor"]
|
||||
},
|
||||
"ui": {"title": "Canny Processor", "tags": ["controlnet", "canny", "image", "processor"]},
|
||||
}
|
||||
|
||||
def run_processor(self, image):
|
||||
canny_processor = CannyDetector()
|
||||
processed_image = canny_processor(
|
||||
image, self.low_threshold, self.high_threshold)
|
||||
processed_image = canny_processor(image, self.low_threshold, self.high_threshold)
|
||||
return processed_image
|
||||
|
||||
|
||||
class HedImageProcessorInvocation(
|
||||
ImageProcessorInvocation, PILInvocationConfig):
|
||||
class HedImageProcessorInvocation(ImageProcessorInvocation, PILInvocationConfig):
|
||||
"""Applies HED edge detection to image"""
|
||||
|
||||
# fmt: off
|
||||
type: Literal["hed_image_processor"] = "hed_image_processor"
|
||||
# Inputs
|
||||
@ -288,27 +291,25 @@ class HedImageProcessorInvocation(
|
||||
|
||||
class Config(InvocationConfig):
|
||||
schema_extra = {
|
||||
"ui": {
|
||||
"title": "Softedge(HED) Processor",
|
||||
"tags": ["controlnet", "softedge", "hed", "image", "processor"]
|
||||
},
|
||||
"ui": {"title": "Softedge(HED) Processor", "tags": ["controlnet", "softedge", "hed", "image", "processor"]},
|
||||
}
|
||||
|
||||
def run_processor(self, image):
|
||||
hed_processor = HEDdetector.from_pretrained("lllyasviel/Annotators")
|
||||
processed_image = hed_processor(image,
|
||||
detect_resolution=self.detect_resolution,
|
||||
image_resolution=self.image_resolution,
|
||||
# safe not supported in controlnet_aux v0.0.3
|
||||
# safe=self.safe,
|
||||
scribble=self.scribble,
|
||||
)
|
||||
processed_image = hed_processor(
|
||||
image,
|
||||
detect_resolution=self.detect_resolution,
|
||||
image_resolution=self.image_resolution,
|
||||
# safe not supported in controlnet_aux v0.0.3
|
||||
# safe=self.safe,
|
||||
scribble=self.scribble,
|
||||
)
|
||||
return processed_image
|
||||
|
||||
|
||||
class LineartImageProcessorInvocation(
|
||||
ImageProcessorInvocation, PILInvocationConfig):
|
||||
class LineartImageProcessorInvocation(ImageProcessorInvocation, PILInvocationConfig):
|
||||
"""Applies line art processing to image"""
|
||||
|
||||
# fmt: off
|
||||
type: Literal["lineart_image_processor"] = "lineart_image_processor"
|
||||
# Inputs
|
||||
@ -319,24 +320,20 @@ class LineartImageProcessorInvocation(
|
||||
|
||||
class Config(InvocationConfig):
|
||||
schema_extra = {
|
||||
"ui": {
|
||||
"title": "Lineart Processor",
|
||||
"tags": ["controlnet", "lineart", "image", "processor"]
|
||||
},
|
||||
"ui": {"title": "Lineart Processor", "tags": ["controlnet", "lineart", "image", "processor"]},
|
||||
}
|
||||
|
||||
def run_processor(self, image):
|
||||
lineart_processor = LineartDetector.from_pretrained(
|
||||
"lllyasviel/Annotators")
|
||||
lineart_processor = LineartDetector.from_pretrained("lllyasviel/Annotators")
|
||||
processed_image = lineart_processor(
|
||||
image, detect_resolution=self.detect_resolution,
|
||||
image_resolution=self.image_resolution, coarse=self.coarse)
|
||||
image, detect_resolution=self.detect_resolution, image_resolution=self.image_resolution, coarse=self.coarse
|
||||
)
|
||||
return processed_image
|
||||
|
||||
|
||||
class LineartAnimeImageProcessorInvocation(
|
||||
ImageProcessorInvocation, PILInvocationConfig):
|
||||
class LineartAnimeImageProcessorInvocation(ImageProcessorInvocation, PILInvocationConfig):
|
||||
"""Applies line art anime processing to image"""
|
||||
|
||||
# fmt: off
|
||||
type: Literal["lineart_anime_image_processor"] = "lineart_anime_image_processor"
|
||||
# Inputs
|
||||
@ -348,23 +345,23 @@ class LineartAnimeImageProcessorInvocation(
|
||||
schema_extra = {
|
||||
"ui": {
|
||||
"title": "Lineart Anime Processor",
|
||||
"tags": ["controlnet", "lineart", "anime", "image", "processor"]
|
||||
"tags": ["controlnet", "lineart", "anime", "image", "processor"],
|
||||
},
|
||||
}
|
||||
|
||||
def run_processor(self, image):
|
||||
processor = LineartAnimeDetector.from_pretrained(
|
||||
"lllyasviel/Annotators")
|
||||
processed_image = processor(image,
|
||||
detect_resolution=self.detect_resolution,
|
||||
image_resolution=self.image_resolution,
|
||||
)
|
||||
processor = LineartAnimeDetector.from_pretrained("lllyasviel/Annotators")
|
||||
processed_image = processor(
|
||||
image,
|
||||
detect_resolution=self.detect_resolution,
|
||||
image_resolution=self.image_resolution,
|
||||
)
|
||||
return processed_image
|
||||
|
||||
|
||||
class OpenposeImageProcessorInvocation(
|
||||
ImageProcessorInvocation, PILInvocationConfig):
|
||||
class OpenposeImageProcessorInvocation(ImageProcessorInvocation, PILInvocationConfig):
|
||||
"""Applies Openpose processing to image"""
|
||||
|
||||
# fmt: off
|
||||
type: Literal["openpose_image_processor"] = "openpose_image_processor"
|
||||
# Inputs
|
||||
@ -375,25 +372,23 @@ class OpenposeImageProcessorInvocation(
|
||||
|
||||
class Config(InvocationConfig):
|
||||
schema_extra = {
|
||||
"ui": {
|
||||
"title": "Openpose Processor",
|
||||
"tags": ["controlnet", "openpose", "image", "processor"]
|
||||
},
|
||||
"ui": {"title": "Openpose Processor", "tags": ["controlnet", "openpose", "image", "processor"]},
|
||||
}
|
||||
|
||||
def run_processor(self, image):
|
||||
openpose_processor = OpenposeDetector.from_pretrained(
|
||||
"lllyasviel/Annotators")
|
||||
openpose_processor = OpenposeDetector.from_pretrained("lllyasviel/Annotators")
|
||||
processed_image = openpose_processor(
|
||||
image, detect_resolution=self.detect_resolution,
|
||||
image,
|
||||
detect_resolution=self.detect_resolution,
|
||||
image_resolution=self.image_resolution,
|
||||
hand_and_face=self.hand_and_face,)
|
||||
hand_and_face=self.hand_and_face,
|
||||
)
|
||||
return processed_image
|
||||
|
||||
|
||||
class MidasDepthImageProcessorInvocation(
|
||||
ImageProcessorInvocation, PILInvocationConfig):
|
||||
class MidasDepthImageProcessorInvocation(ImageProcessorInvocation, PILInvocationConfig):
|
||||
"""Applies Midas depth processing to image"""
|
||||
|
||||
# fmt: off
|
||||
type: Literal["midas_depth_image_processor"] = "midas_depth_image_processor"
|
||||
# Inputs
|
||||
@ -405,26 +400,24 @@ class MidasDepthImageProcessorInvocation(
|
||||
|
||||
class Config(InvocationConfig):
|
||||
schema_extra = {
|
||||
"ui": {
|
||||
"title": "Midas (Depth) Processor",
|
||||
"tags": ["controlnet", "midas", "depth", "image", "processor"]
|
||||
},
|
||||
"ui": {"title": "Midas (Depth) Processor", "tags": ["controlnet", "midas", "depth", "image", "processor"]},
|
||||
}
|
||||
|
||||
def run_processor(self, image):
|
||||
midas_processor = MidasDetector.from_pretrained("lllyasviel/Annotators")
|
||||
processed_image = midas_processor(image,
|
||||
a=np.pi * self.a_mult,
|
||||
bg_th=self.bg_th,
|
||||
# dept_and_normal not supported in controlnet_aux v0.0.3
|
||||
# depth_and_normal=self.depth_and_normal,
|
||||
)
|
||||
processed_image = midas_processor(
|
||||
image,
|
||||
a=np.pi * self.a_mult,
|
||||
bg_th=self.bg_th,
|
||||
# dept_and_normal not supported in controlnet_aux v0.0.3
|
||||
# depth_and_normal=self.depth_and_normal,
|
||||
)
|
||||
return processed_image
|
||||
|
||||
|
||||
class NormalbaeImageProcessorInvocation(
|
||||
ImageProcessorInvocation, PILInvocationConfig):
|
||||
class NormalbaeImageProcessorInvocation(ImageProcessorInvocation, PILInvocationConfig):
|
||||
"""Applies NormalBae processing to image"""
|
||||
|
||||
# fmt: off
|
||||
type: Literal["normalbae_image_processor"] = "normalbae_image_processor"
|
||||
# Inputs
|
||||
@ -434,24 +427,20 @@ class NormalbaeImageProcessorInvocation(
|
||||
|
||||
class Config(InvocationConfig):
|
||||
schema_extra = {
|
||||
"ui": {
|
||||
"title": "Normal BAE Processor",
|
||||
"tags": ["controlnet", "normal", "bae", "image", "processor"]
|
||||
},
|
||||
"ui": {"title": "Normal BAE Processor", "tags": ["controlnet", "normal", "bae", "image", "processor"]},
|
||||
}
|
||||
|
||||
def run_processor(self, image):
|
||||
normalbae_processor = NormalBaeDetector.from_pretrained(
|
||||
"lllyasviel/Annotators")
|
||||
normalbae_processor = NormalBaeDetector.from_pretrained("lllyasviel/Annotators")
|
||||
processed_image = normalbae_processor(
|
||||
image, detect_resolution=self.detect_resolution,
|
||||
image_resolution=self.image_resolution)
|
||||
image, detect_resolution=self.detect_resolution, image_resolution=self.image_resolution
|
||||
)
|
||||
return processed_image
|
||||
|
||||
|
||||
class MlsdImageProcessorInvocation(
|
||||
ImageProcessorInvocation, PILInvocationConfig):
|
||||
class MlsdImageProcessorInvocation(ImageProcessorInvocation, PILInvocationConfig):
|
||||
"""Applies MLSD processing to image"""
|
||||
|
||||
# fmt: off
|
||||
type: Literal["mlsd_image_processor"] = "mlsd_image_processor"
|
||||
# Inputs
|
||||
@ -463,24 +452,24 @@ class MlsdImageProcessorInvocation(
|
||||
|
||||
class Config(InvocationConfig):
|
||||
schema_extra = {
|
||||
"ui": {
|
||||
"title": "MLSD Processor",
|
||||
"tags": ["controlnet", "mlsd", "image", "processor"]
|
||||
},
|
||||
"ui": {"title": "MLSD Processor", "tags": ["controlnet", "mlsd", "image", "processor"]},
|
||||
}
|
||||
|
||||
def run_processor(self, image):
|
||||
mlsd_processor = MLSDdetector.from_pretrained("lllyasviel/Annotators")
|
||||
processed_image = mlsd_processor(
|
||||
image, detect_resolution=self.detect_resolution,
|
||||
image_resolution=self.image_resolution, thr_v=self.thr_v,
|
||||
thr_d=self.thr_d)
|
||||
image,
|
||||
detect_resolution=self.detect_resolution,
|
||||
image_resolution=self.image_resolution,
|
||||
thr_v=self.thr_v,
|
||||
thr_d=self.thr_d,
|
||||
)
|
||||
return processed_image
|
||||
|
||||
|
||||
class PidiImageProcessorInvocation(
|
||||
ImageProcessorInvocation, PILInvocationConfig):
|
||||
class PidiImageProcessorInvocation(ImageProcessorInvocation, PILInvocationConfig):
|
||||
"""Applies PIDI processing to image"""
|
||||
|
||||
# fmt: off
|
||||
type: Literal["pidi_image_processor"] = "pidi_image_processor"
|
||||
# Inputs
|
||||
@ -492,25 +481,24 @@ class PidiImageProcessorInvocation(
|
||||
|
||||
class Config(InvocationConfig):
|
||||
schema_extra = {
|
||||
"ui": {
|
||||
"title": "PIDI Processor",
|
||||
"tags": ["controlnet", "pidi", "image", "processor"]
|
||||
},
|
||||
"ui": {"title": "PIDI Processor", "tags": ["controlnet", "pidi", "image", "processor"]},
|
||||
}
|
||||
|
||||
def run_processor(self, image):
|
||||
pidi_processor = PidiNetDetector.from_pretrained(
|
||||
"lllyasviel/Annotators")
|
||||
pidi_processor = PidiNetDetector.from_pretrained("lllyasviel/Annotators")
|
||||
processed_image = pidi_processor(
|
||||
image, detect_resolution=self.detect_resolution,
|
||||
image_resolution=self.image_resolution, safe=self.safe,
|
||||
scribble=self.scribble)
|
||||
image,
|
||||
detect_resolution=self.detect_resolution,
|
||||
image_resolution=self.image_resolution,
|
||||
safe=self.safe,
|
||||
scribble=self.scribble,
|
||||
)
|
||||
return processed_image
|
||||
|
||||
|
||||
class ContentShuffleImageProcessorInvocation(
|
||||
ImageProcessorInvocation, PILInvocationConfig):
|
||||
class ContentShuffleImageProcessorInvocation(ImageProcessorInvocation, PILInvocationConfig):
|
||||
"""Applies content shuffle processing to image"""
|
||||
|
||||
# fmt: off
|
||||
type: Literal["content_shuffle_image_processor"] = "content_shuffle_image_processor"
|
||||
# Inputs
|
||||
@ -525,48 +513,45 @@ class ContentShuffleImageProcessorInvocation(
|
||||
schema_extra = {
|
||||
"ui": {
|
||||
"title": "Content Shuffle Processor",
|
||||
"tags": ["controlnet", "contentshuffle", "image", "processor"]
|
||||
"tags": ["controlnet", "contentshuffle", "image", "processor"],
|
||||
},
|
||||
}
|
||||
|
||||
def run_processor(self, image):
|
||||
content_shuffle_processor = ContentShuffleDetector()
|
||||
processed_image = content_shuffle_processor(image,
|
||||
detect_resolution=self.detect_resolution,
|
||||
image_resolution=self.image_resolution,
|
||||
h=self.h,
|
||||
w=self.w,
|
||||
f=self.f
|
||||
)
|
||||
processed_image = content_shuffle_processor(
|
||||
image,
|
||||
detect_resolution=self.detect_resolution,
|
||||
image_resolution=self.image_resolution,
|
||||
h=self.h,
|
||||
w=self.w,
|
||||
f=self.f,
|
||||
)
|
||||
return processed_image
|
||||
|
||||
|
||||
# should work with controlnet_aux >= 0.0.4 and timm <= 0.6.13
|
||||
class ZoeDepthImageProcessorInvocation(
|
||||
ImageProcessorInvocation, PILInvocationConfig):
|
||||
class ZoeDepthImageProcessorInvocation(ImageProcessorInvocation, PILInvocationConfig):
|
||||
"""Applies Zoe depth processing to image"""
|
||||
|
||||
# fmt: off
|
||||
type: Literal["zoe_depth_image_processor"] = "zoe_depth_image_processor"
|
||||
# fmt: on
|
||||
|
||||
class Config(InvocationConfig):
|
||||
schema_extra = {
|
||||
"ui": {
|
||||
"title": "Zoe (Depth) Processor",
|
||||
"tags": ["controlnet", "zoe", "depth", "image", "processor"]
|
||||
},
|
||||
"ui": {"title": "Zoe (Depth) Processor", "tags": ["controlnet", "zoe", "depth", "image", "processor"]},
|
||||
}
|
||||
|
||||
def run_processor(self, image):
|
||||
zoe_depth_processor = ZoeDetector.from_pretrained(
|
||||
"lllyasviel/Annotators")
|
||||
zoe_depth_processor = ZoeDetector.from_pretrained("lllyasviel/Annotators")
|
||||
processed_image = zoe_depth_processor(image)
|
||||
return processed_image
|
||||
|
||||
|
||||
class MediapipeFaceProcessorInvocation(
|
||||
ImageProcessorInvocation, PILInvocationConfig):
|
||||
class MediapipeFaceProcessorInvocation(ImageProcessorInvocation, PILInvocationConfig):
|
||||
"""Applies mediapipe face processing to image"""
|
||||
|
||||
# fmt: off
|
||||
type: Literal["mediapipe_face_processor"] = "mediapipe_face_processor"
|
||||
# Inputs
|
||||
@ -576,26 +561,22 @@ class MediapipeFaceProcessorInvocation(
|
||||
|
||||
class Config(InvocationConfig):
|
||||
schema_extra = {
|
||||
"ui": {
|
||||
"title": "Mediapipe Processor",
|
||||
"tags": ["controlnet", "mediapipe", "image", "processor"]
|
||||
},
|
||||
"ui": {"title": "Mediapipe Processor", "tags": ["controlnet", "mediapipe", "image", "processor"]},
|
||||
}
|
||||
|
||||
def run_processor(self, image):
|
||||
# MediaPipeFaceDetector throws an error if image has alpha channel
|
||||
# so convert to RGB if needed
|
||||
if image.mode == 'RGBA':
|
||||
image = image.convert('RGB')
|
||||
if image.mode == "RGBA":
|
||||
image = image.convert("RGB")
|
||||
mediapipe_face_processor = MediapipeFaceDetector()
|
||||
processed_image = mediapipe_face_processor(
|
||||
image, max_faces=self.max_faces, min_confidence=self.min_confidence)
|
||||
processed_image = mediapipe_face_processor(image, max_faces=self.max_faces, min_confidence=self.min_confidence)
|
||||
return processed_image
|
||||
|
||||
|
||||
class LeresImageProcessorInvocation(
|
||||
ImageProcessorInvocation, PILInvocationConfig):
|
||||
class LeresImageProcessorInvocation(ImageProcessorInvocation, PILInvocationConfig):
|
||||
"""Applies leres processing to image"""
|
||||
|
||||
# fmt: off
|
||||
type: Literal["leres_image_processor"] = "leres_image_processor"
|
||||
# Inputs
|
||||
@ -608,24 +589,23 @@ class LeresImageProcessorInvocation(
|
||||
|
||||
class Config(InvocationConfig):
|
||||
schema_extra = {
|
||||
"ui": {
|
||||
"title": "Leres (Depth) Processor",
|
||||
"tags": ["controlnet", "leres", "depth", "image", "processor"]
|
||||
},
|
||||
"ui": {"title": "Leres (Depth) Processor", "tags": ["controlnet", "leres", "depth", "image", "processor"]},
|
||||
}
|
||||
|
||||
def run_processor(self, image):
|
||||
leres_processor = LeresDetector.from_pretrained("lllyasviel/Annotators")
|
||||
processed_image = leres_processor(
|
||||
image, thr_a=self.thr_a, thr_b=self.thr_b, boost=self.boost,
|
||||
image,
|
||||
thr_a=self.thr_a,
|
||||
thr_b=self.thr_b,
|
||||
boost=self.boost,
|
||||
detect_resolution=self.detect_resolution,
|
||||
image_resolution=self.image_resolution)
|
||||
image_resolution=self.image_resolution,
|
||||
)
|
||||
return processed_image
|
||||
|
||||
|
||||
class TileResamplerProcessorInvocation(
|
||||
ImageProcessorInvocation, PILInvocationConfig):
|
||||
|
||||
class TileResamplerProcessorInvocation(ImageProcessorInvocation, PILInvocationConfig):
|
||||
# fmt: off
|
||||
type: Literal["tile_image_processor"] = "tile_image_processor"
|
||||
# Inputs
|
||||
@ -637,16 +617,17 @@ class TileResamplerProcessorInvocation(
|
||||
schema_extra = {
|
||||
"ui": {
|
||||
"title": "Tile Resample Processor",
|
||||
"tags": ["controlnet", "tile", "resample", "image", "processor"]
|
||||
"tags": ["controlnet", "tile", "resample", "image", "processor"],
|
||||
},
|
||||
}
|
||||
|
||||
# tile_resample copied from sd-webui-controlnet/scripts/processor.py
|
||||
def tile_resample(self,
|
||||
np_img: np.ndarray,
|
||||
res=512, # never used?
|
||||
down_sampling_rate=1.0,
|
||||
):
|
||||
def tile_resample(
|
||||
self,
|
||||
np_img: np.ndarray,
|
||||
res=512, # never used?
|
||||
down_sampling_rate=1.0,
|
||||
):
|
||||
np_img = HWC3(np_img)
|
||||
if down_sampling_rate < 1.1:
|
||||
return np_img
|
||||
@ -658,36 +639,41 @@ class TileResamplerProcessorInvocation(
|
||||
|
||||
def run_processor(self, img):
|
||||
np_img = np.array(img, dtype=np.uint8)
|
||||
processed_np_image = self.tile_resample(np_img,
|
||||
# res=self.tile_size,
|
||||
down_sampling_rate=self.down_sampling_rate
|
||||
)
|
||||
processed_np_image = self.tile_resample(
|
||||
np_img,
|
||||
# res=self.tile_size,
|
||||
down_sampling_rate=self.down_sampling_rate,
|
||||
)
|
||||
processed_image = Image.fromarray(processed_np_image)
|
||||
return processed_image
|
||||
|
||||
|
||||
class SegmentAnythingProcessorInvocation(
|
||||
ImageProcessorInvocation, PILInvocationConfig):
|
||||
class SegmentAnythingProcessorInvocation(ImageProcessorInvocation, PILInvocationConfig):
|
||||
"""Applies segment anything processing to image"""
|
||||
|
||||
# fmt: off
|
||||
type: Literal["segment_anything_processor"] = "segment_anything_processor"
|
||||
# fmt: on
|
||||
|
||||
class Config(InvocationConfig):
|
||||
schema_extra = {"ui": {"title": "Segment Anything Processor", "tags": [
|
||||
"controlnet", "segment", "anything", "sam", "image", "processor"]}, }
|
||||
schema_extra = {
|
||||
"ui": {
|
||||
"title": "Segment Anything Processor",
|
||||
"tags": ["controlnet", "segment", "anything", "sam", "image", "processor"],
|
||||
},
|
||||
}
|
||||
|
||||
def run_processor(self, image):
|
||||
# segment_anything_processor = SamDetector.from_pretrained("ybelkada/segment-anything", subfolder="checkpoints")
|
||||
segment_anything_processor = SamDetectorReproducibleColors.from_pretrained(
|
||||
"ybelkada/segment-anything", subfolder="checkpoints")
|
||||
"ybelkada/segment-anything", subfolder="checkpoints"
|
||||
)
|
||||
np_img = np.array(image, dtype=np.uint8)
|
||||
processed_image = segment_anything_processor(np_img)
|
||||
return processed_image
|
||||
|
||||
|
||||
class SamDetectorReproducibleColors(SamDetector):
|
||||
|
||||
# overriding SamDetector.show_anns() method to use reproducible colors for segmentation image
|
||||
# base class show_anns() method randomizes colors,
|
||||
# which seems to also lead to non-reproducible image generation
|
||||
@ -695,19 +681,15 @@ class SamDetectorReproducibleColors(SamDetector):
|
||||
def show_anns(self, anns: List[Dict]):
|
||||
if len(anns) == 0:
|
||||
return
|
||||
sorted_anns = sorted(anns, key=(lambda x: x['area']), reverse=True)
|
||||
h, w = anns[0]['segmentation'].shape
|
||||
final_img = Image.fromarray(
|
||||
np.zeros((h, w, 3), dtype=np.uint8), mode="RGB")
|
||||
sorted_anns = sorted(anns, key=(lambda x: x["area"]), reverse=True)
|
||||
h, w = anns[0]["segmentation"].shape
|
||||
final_img = Image.fromarray(np.zeros((h, w, 3), dtype=np.uint8), mode="RGB")
|
||||
palette = ade_palette()
|
||||
for i, ann in enumerate(sorted_anns):
|
||||
m = ann['segmentation']
|
||||
m = ann["segmentation"]
|
||||
img = np.empty((m.shape[0], m.shape[1], 3), dtype=np.uint8)
|
||||
# doing modulo just in case number of annotated regions exceeds number of colors in palette
|
||||
ann_color = palette[i % len(palette)]
|
||||
img[:, :] = ann_color
|
||||
final_img.paste(
|
||||
Image.fromarray(img, mode="RGB"),
|
||||
(0, 0),
|
||||
Image.fromarray(np.uint8(m * 255)))
|
||||
final_img.paste(Image.fromarray(img, mode="RGB"), (0, 0), Image.fromarray(np.uint8(m * 255)))
|
||||
return np.array(final_img, dtype=np.uint8)
|
||||
|
@ -37,10 +37,7 @@ class CvInpaintInvocation(BaseInvocation, CvInvocationConfig):
|
||||
|
||||
class Config(InvocationConfig):
|
||||
schema_extra = {
|
||||
"ui": {
|
||||
"title": "OpenCV Inpaint",
|
||||
"tags": ["opencv", "inpaint"]
|
||||
},
|
||||
"ui": {"title": "OpenCV Inpaint", "tags": ["opencv", "inpaint"]},
|
||||
}
|
||||
|
||||
def invoke(self, context: InvocationContext) -> ImageOutput:
|
||||
|
@ -6,8 +6,7 @@ from typing import Literal, Optional, get_args
|
||||
import torch
|
||||
from pydantic import Field
|
||||
|
||||
from invokeai.app.models.image import (ColorField, ImageCategory, ImageField,
|
||||
ResourceOrigin)
|
||||
from invokeai.app.models.image import ColorField, ImageCategory, ImageField, ResourceOrigin
|
||||
from invokeai.app.util.misc import SEED_MAX, get_random_seed
|
||||
from invokeai.backend.generator.inpaint import infill_methods
|
||||
|
||||
@ -25,13 +24,12 @@ from contextlib import contextmanager, ExitStack, ContextDecorator
|
||||
|
||||
SAMPLER_NAME_VALUES = Literal[tuple(InvokeAIGenerator.schedulers())]
|
||||
INFILL_METHODS = Literal[tuple(infill_methods())]
|
||||
DEFAULT_INFILL_METHOD = (
|
||||
"patchmatch" if "patchmatch" in get_args(INFILL_METHODS) else "tile"
|
||||
)
|
||||
DEFAULT_INFILL_METHOD = "patchmatch" if "patchmatch" in get_args(INFILL_METHODS) else "tile"
|
||||
|
||||
|
||||
from .latent import get_scheduler
|
||||
|
||||
|
||||
class OldModelContext(ContextDecorator):
|
||||
model: StableDiffusionGeneratorPipeline
|
||||
|
||||
@ -44,6 +42,7 @@ class OldModelContext(ContextDecorator):
|
||||
def __exit__(self, *exc):
|
||||
return False
|
||||
|
||||
|
||||
class OldModelInfo:
|
||||
name: str
|
||||
hash: str
|
||||
@ -64,20 +63,34 @@ class InpaintInvocation(BaseInvocation):
|
||||
|
||||
positive_conditioning: Optional[ConditioningField] = Field(description="Positive conditioning for generation")
|
||||
negative_conditioning: Optional[ConditioningField] = Field(description="Negative conditioning for generation")
|
||||
seed: int = Field(ge=0, le=SEED_MAX, description="The seed to use (omit for random)", default_factory=get_random_seed)
|
||||
steps: int = Field(default=30, gt=0, description="The number of steps to use to generate the image")
|
||||
width: int = Field(default=512, multiple_of=8, gt=0, description="The width of the resulting image", )
|
||||
height: int = Field(default=512, multiple_of=8, gt=0, description="The height of the resulting image", )
|
||||
cfg_scale: float = Field(default=7.5, ge=1, description="The Classifier-Free Guidance, higher values may result in a result closer to the prompt", )
|
||||
scheduler: SAMPLER_NAME_VALUES = Field(default="euler", description="The scheduler to use" )
|
||||
seed: int = Field(
|
||||
ge=0, le=SEED_MAX, description="The seed to use (omit for random)", default_factory=get_random_seed
|
||||
)
|
||||
steps: int = Field(default=30, gt=0, description="The number of steps to use to generate the image")
|
||||
width: int = Field(
|
||||
default=512,
|
||||
multiple_of=8,
|
||||
gt=0,
|
||||
description="The width of the resulting image",
|
||||
)
|
||||
height: int = Field(
|
||||
default=512,
|
||||
multiple_of=8,
|
||||
gt=0,
|
||||
description="The height of the resulting image",
|
||||
)
|
||||
cfg_scale: float = Field(
|
||||
default=7.5,
|
||||
ge=1,
|
||||
description="The Classifier-Free Guidance, higher values may result in a result closer to the prompt",
|
||||
)
|
||||
scheduler: SAMPLER_NAME_VALUES = Field(default="euler", description="The scheduler to use")
|
||||
unet: UNetField = Field(default=None, description="UNet model")
|
||||
vae: VaeField = Field(default=None, description="Vae model")
|
||||
|
||||
# Inputs
|
||||
image: Optional[ImageField] = Field(description="The input image")
|
||||
strength: float = Field(
|
||||
default=0.75, gt=0, le=1, description="The strength of the original image"
|
||||
)
|
||||
strength: float = Field(default=0.75, gt=0, le=1, description="The strength of the original image")
|
||||
fit: bool = Field(
|
||||
default=True,
|
||||
description="Whether or not the result should be fit to the aspect ratio of the input image",
|
||||
@ -86,18 +99,10 @@ class InpaintInvocation(BaseInvocation):
|
||||
# Inputs
|
||||
mask: Optional[ImageField] = Field(description="The mask")
|
||||
seam_size: int = Field(default=96, ge=1, description="The seam inpaint size (px)")
|
||||
seam_blur: int = Field(
|
||||
default=16, ge=0, description="The seam inpaint blur radius (px)"
|
||||
)
|
||||
seam_strength: float = Field(
|
||||
default=0.75, gt=0, le=1, description="The seam inpaint strength"
|
||||
)
|
||||
seam_steps: int = Field(
|
||||
default=30, ge=1, description="The number of steps to use for seam inpaint"
|
||||
)
|
||||
tile_size: int = Field(
|
||||
default=32, ge=1, description="The tile infill method size (px)"
|
||||
)
|
||||
seam_blur: int = Field(default=16, ge=0, description="The seam inpaint blur radius (px)")
|
||||
seam_strength: float = Field(default=0.75, gt=0, le=1, description="The seam inpaint strength")
|
||||
seam_steps: int = Field(default=30, ge=1, description="The number of steps to use for seam inpaint")
|
||||
tile_size: int = Field(default=32, ge=1, description="The tile infill method size (px)")
|
||||
infill_method: INFILL_METHODS = Field(
|
||||
default=DEFAULT_INFILL_METHOD,
|
||||
description="The method used to infill empty regions (px)",
|
||||
@ -128,10 +133,7 @@ class InpaintInvocation(BaseInvocation):
|
||||
# Schema customisation
|
||||
class Config(InvocationConfig):
|
||||
schema_extra = {
|
||||
"ui": {
|
||||
"tags": ["stable-diffusion", "image"],
|
||||
"title": "Inpaint"
|
||||
},
|
||||
"ui": {"tags": ["stable-diffusion", "image"], "title": "Inpaint"},
|
||||
}
|
||||
|
||||
def dispatch_progress(
|
||||
@ -162,18 +164,23 @@ class InpaintInvocation(BaseInvocation):
|
||||
def _lora_loader():
|
||||
for lora in self.unet.loras:
|
||||
lora_info = context.services.model_manager.get_model(
|
||||
**lora.dict(exclude={"weight"}), context=context,)
|
||||
**lora.dict(exclude={"weight"}),
|
||||
context=context,
|
||||
)
|
||||
yield (lora_info.context.model, lora.weight)
|
||||
del lora_info
|
||||
return
|
||||
|
||||
unet_info = context.services.model_manager.get_model(**self.unet.unet.dict(), context=context,)
|
||||
vae_info = context.services.model_manager.get_model(**self.vae.vae.dict(), context=context,)
|
||||
|
||||
with vae_info as vae,\
|
||||
ModelPatcher.apply_lora_unet(unet_info.context.model, _lora_loader()),\
|
||||
unet_info as unet:
|
||||
unet_info = context.services.model_manager.get_model(
|
||||
**self.unet.unet.dict(),
|
||||
context=context,
|
||||
)
|
||||
vae_info = context.services.model_manager.get_model(
|
||||
**self.vae.vae.dict(),
|
||||
context=context,
|
||||
)
|
||||
|
||||
with vae_info as vae, ModelPatcher.apply_lora_unet(unet_info.context.model, _lora_loader()), unet_info as unet:
|
||||
device = context.services.model_manager.mgr.cache.execution_device
|
||||
dtype = context.services.model_manager.mgr.cache.precision
|
||||
|
||||
@ -197,21 +204,11 @@ class InpaintInvocation(BaseInvocation):
|
||||
)
|
||||
|
||||
def invoke(self, context: InvocationContext) -> ImageOutput:
|
||||
image = (
|
||||
None
|
||||
if self.image is None
|
||||
else context.services.images.get_pil_image(self.image.image_name)
|
||||
)
|
||||
mask = (
|
||||
None
|
||||
if self.mask is None
|
||||
else context.services.images.get_pil_image(self.mask.image_name)
|
||||
)
|
||||
image = None if self.image is None else context.services.images.get_pil_image(self.image.image_name)
|
||||
mask = None if self.mask is None else context.services.images.get_pil_image(self.mask.image_name)
|
||||
|
||||
# Get the source node id (we are invoking the prepared node)
|
||||
graph_execution_state = context.services.graph_execution_manager.get(
|
||||
context.graph_execution_state_id
|
||||
)
|
||||
graph_execution_state = context.services.graph_execution_manager.get(context.graph_execution_state_id)
|
||||
source_node_id = graph_execution_state.prepared_source_mapping[self.id]
|
||||
|
||||
scheduler = get_scheduler(
|
||||
|
@ -4,60 +4,25 @@ from typing import Literal, Optional
|
||||
|
||||
import numpy
|
||||
from PIL import Image, ImageFilter, ImageOps, ImageChops
|
||||
from pydantic import BaseModel, Field
|
||||
from pydantic import Field
|
||||
from pathlib import Path
|
||||
from typing import Union
|
||||
|
||||
from ..models.image import ImageCategory, ImageField, ResourceOrigin
|
||||
from invokeai.app.invocations.metadata import CoreMetadata
|
||||
from ..models.image import (
|
||||
ImageCategory,
|
||||
ImageField,
|
||||
ResourceOrigin,
|
||||
PILInvocationConfig,
|
||||
ImageOutput,
|
||||
MaskOutput,
|
||||
)
|
||||
from .baseinvocation import (
|
||||
BaseInvocation,
|
||||
BaseInvocationOutput,
|
||||
InvocationContext,
|
||||
InvocationConfig,
|
||||
)
|
||||
|
||||
|
||||
class PILInvocationConfig(BaseModel):
|
||||
"""Helper class to provide all PIL invocations with additional config"""
|
||||
|
||||
class Config(InvocationConfig):
|
||||
schema_extra = {
|
||||
"ui": {
|
||||
"tags": ["PIL", "image"],
|
||||
},
|
||||
}
|
||||
|
||||
|
||||
class ImageOutput(BaseInvocationOutput):
|
||||
"""Base class for invocations that output an image"""
|
||||
|
||||
# fmt: off
|
||||
type: Literal["image_output"] = "image_output"
|
||||
image: ImageField = Field(default=None, description="The output image")
|
||||
width: int = Field(description="The width of the image in pixels")
|
||||
height: int = Field(description="The height of the image in pixels")
|
||||
# fmt: on
|
||||
|
||||
class Config:
|
||||
schema_extra = {"required": ["type", "image", "width", "height"]}
|
||||
|
||||
|
||||
class MaskOutput(BaseInvocationOutput):
|
||||
"""Base class for invocations that output a mask"""
|
||||
|
||||
# fmt: off
|
||||
type: Literal["mask"] = "mask"
|
||||
mask: ImageField = Field(default=None, description="The output mask")
|
||||
width: int = Field(description="The width of the mask in pixels")
|
||||
height: int = Field(description="The height of the mask in pixels")
|
||||
# fmt: on
|
||||
|
||||
class Config:
|
||||
schema_extra = {
|
||||
"required": [
|
||||
"type",
|
||||
"mask",
|
||||
]
|
||||
}
|
||||
from invokeai.backend.image_util.safety_checker import SafetyChecker
|
||||
from invokeai.backend.image_util.invisible_watermark import InvisibleWatermark
|
||||
|
||||
|
||||
class LoadImageInvocation(BaseInvocation):
|
||||
@ -74,10 +39,7 @@ class LoadImageInvocation(BaseInvocation):
|
||||
|
||||
class Config(InvocationConfig):
|
||||
schema_extra = {
|
||||
"ui": {
|
||||
"title": "Load Image",
|
||||
"tags": ["image", "load"]
|
||||
},
|
||||
"ui": {"title": "Load Image", "tags": ["image", "load"]},
|
||||
}
|
||||
|
||||
def invoke(self, context: InvocationContext) -> ImageOutput:
|
||||
@ -96,16 +58,11 @@ class ShowImageInvocation(BaseInvocation):
|
||||
type: Literal["show_image"] = "show_image"
|
||||
|
||||
# Inputs
|
||||
image: Optional[ImageField] = Field(
|
||||
default=None, description="The image to show"
|
||||
)
|
||||
image: Optional[ImageField] = Field(default=None, description="The image to show")
|
||||
|
||||
class Config(InvocationConfig):
|
||||
schema_extra = {
|
||||
"ui": {
|
||||
"title": "Show Image",
|
||||
"tags": ["image", "show"]
|
||||
},
|
||||
"ui": {"title": "Show Image", "tags": ["image", "show"]},
|
||||
}
|
||||
|
||||
def invoke(self, context: InvocationContext) -> ImageOutput:
|
||||
@ -138,18 +95,13 @@ class ImageCropInvocation(BaseInvocation, PILInvocationConfig):
|
||||
|
||||
class Config(InvocationConfig):
|
||||
schema_extra = {
|
||||
"ui": {
|
||||
"title": "Crop Image",
|
||||
"tags": ["image", "crop"]
|
||||
},
|
||||
"ui": {"title": "Crop Image", "tags": ["image", "crop"]},
|
||||
}
|
||||
|
||||
def invoke(self, context: InvocationContext) -> ImageOutput:
|
||||
image = context.services.images.get_pil_image(self.image.image_name)
|
||||
|
||||
image_crop = Image.new(
|
||||
mode="RGBA", size=(self.width, self.height), color=(0, 0, 0, 0)
|
||||
)
|
||||
image_crop = Image.new(mode="RGBA", size=(self.width, self.height), color=(0, 0, 0, 0))
|
||||
image_crop.paste(image, (-self.x, -self.y))
|
||||
|
||||
image_dto = context.services.images.create(
|
||||
@ -184,21 +136,14 @@ class ImagePasteInvocation(BaseInvocation, PILInvocationConfig):
|
||||
|
||||
class Config(InvocationConfig):
|
||||
schema_extra = {
|
||||
"ui": {
|
||||
"title": "Paste Image",
|
||||
"tags": ["image", "paste"]
|
||||
},
|
||||
"ui": {"title": "Paste Image", "tags": ["image", "paste"]},
|
||||
}
|
||||
|
||||
def invoke(self, context: InvocationContext) -> ImageOutput:
|
||||
base_image = context.services.images.get_pil_image(self.base_image.image_name)
|
||||
image = context.services.images.get_pil_image(self.image.image_name)
|
||||
mask = (
|
||||
None
|
||||
if self.mask is None
|
||||
else ImageOps.invert(
|
||||
context.services.images.get_pil_image(self.mask.image_name)
|
||||
)
|
||||
None if self.mask is None else ImageOps.invert(context.services.images.get_pil_image(self.mask.image_name))
|
||||
)
|
||||
# TODO: probably shouldn't invert mask here... should user be required to do it?
|
||||
|
||||
@ -207,9 +152,7 @@ class ImagePasteInvocation(BaseInvocation, PILInvocationConfig):
|
||||
max_x = max(base_image.width, image.width + self.x)
|
||||
max_y = max(base_image.height, image.height + self.y)
|
||||
|
||||
new_image = Image.new(
|
||||
mode="RGBA", size=(max_x - min_x, max_y - min_y), color=(0, 0, 0, 0)
|
||||
)
|
||||
new_image = Image.new(mode="RGBA", size=(max_x - min_x, max_y - min_y), color=(0, 0, 0, 0))
|
||||
new_image.paste(base_image, (abs(min_x), abs(min_y)))
|
||||
new_image.paste(image, (max(0, self.x), max(0, self.y)), mask=mask)
|
||||
|
||||
@ -242,10 +185,7 @@ class MaskFromAlphaInvocation(BaseInvocation, PILInvocationConfig):
|
||||
|
||||
class Config(InvocationConfig):
|
||||
schema_extra = {
|
||||
"ui": {
|
||||
"title": "Mask From Alpha",
|
||||
"tags": ["image", "mask", "alpha"]
|
||||
},
|
||||
"ui": {"title": "Mask From Alpha", "tags": ["image", "mask", "alpha"]},
|
||||
}
|
||||
|
||||
def invoke(self, context: InvocationContext) -> MaskOutput:
|
||||
@ -284,10 +224,7 @@ class ImageMultiplyInvocation(BaseInvocation, PILInvocationConfig):
|
||||
|
||||
class Config(InvocationConfig):
|
||||
schema_extra = {
|
||||
"ui": {
|
||||
"title": "Multiply Images",
|
||||
"tags": ["image", "multiply"]
|
||||
},
|
||||
"ui": {"title": "Multiply Images", "tags": ["image", "multiply"]},
|
||||
}
|
||||
|
||||
def invoke(self, context: InvocationContext) -> ImageOutput:
|
||||
@ -328,10 +265,7 @@ class ImageChannelInvocation(BaseInvocation, PILInvocationConfig):
|
||||
|
||||
class Config(InvocationConfig):
|
||||
schema_extra = {
|
||||
"ui": {
|
||||
"title": "Image Channel",
|
||||
"tags": ["image", "channel"]
|
||||
},
|
||||
"ui": {"title": "Image Channel", "tags": ["image", "channel"]},
|
||||
}
|
||||
|
||||
def invoke(self, context: InvocationContext) -> ImageOutput:
|
||||
@ -371,10 +305,7 @@ class ImageConvertInvocation(BaseInvocation, PILInvocationConfig):
|
||||
|
||||
class Config(InvocationConfig):
|
||||
schema_extra = {
|
||||
"ui": {
|
||||
"title": "Convert Image",
|
||||
"tags": ["image", "convert"]
|
||||
},
|
||||
"ui": {"title": "Convert Image", "tags": ["image", "convert"]},
|
||||
}
|
||||
|
||||
def invoke(self, context: InvocationContext) -> ImageOutput:
|
||||
@ -412,19 +343,14 @@ class ImageBlurInvocation(BaseInvocation, PILInvocationConfig):
|
||||
|
||||
class Config(InvocationConfig):
|
||||
schema_extra = {
|
||||
"ui": {
|
||||
"title": "Blur Image",
|
||||
"tags": ["image", "blur"]
|
||||
},
|
||||
"ui": {"title": "Blur Image", "tags": ["image", "blur"]},
|
||||
}
|
||||
|
||||
def invoke(self, context: InvocationContext) -> ImageOutput:
|
||||
image = context.services.images.get_pil_image(self.image.image_name)
|
||||
|
||||
blur = (
|
||||
ImageFilter.GaussianBlur(self.radius)
|
||||
if self.blur_type == "gaussian"
|
||||
else ImageFilter.BoxBlur(self.radius)
|
||||
ImageFilter.GaussianBlur(self.radius) if self.blur_type == "gaussian" else ImageFilter.BoxBlur(self.radius)
|
||||
)
|
||||
blur_image = image.filter(blur)
|
||||
|
||||
@ -479,10 +405,7 @@ class ImageResizeInvocation(BaseInvocation, PILInvocationConfig):
|
||||
|
||||
class Config(InvocationConfig):
|
||||
schema_extra = {
|
||||
"ui": {
|
||||
"title": "Resize Image",
|
||||
"tags": ["image", "resize"]
|
||||
},
|
||||
"ui": {"title": "Resize Image", "tags": ["image", "resize"]},
|
||||
}
|
||||
|
||||
def invoke(self, context: InvocationContext) -> ImageOutput:
|
||||
@ -525,10 +448,7 @@ class ImageScaleInvocation(BaseInvocation, PILInvocationConfig):
|
||||
|
||||
class Config(InvocationConfig):
|
||||
schema_extra = {
|
||||
"ui": {
|
||||
"title": "Scale Image",
|
||||
"tags": ["image", "scale"]
|
||||
},
|
||||
"ui": {"title": "Scale Image", "tags": ["image", "scale"]},
|
||||
}
|
||||
|
||||
def invoke(self, context: InvocationContext) -> ImageOutput:
|
||||
@ -573,10 +493,7 @@ class ImageLerpInvocation(BaseInvocation, PILInvocationConfig):
|
||||
|
||||
class Config(InvocationConfig):
|
||||
schema_extra = {
|
||||
"ui": {
|
||||
"title": "Image Linear Interpolation",
|
||||
"tags": ["image", "linear", "interpolation", "lerp"]
|
||||
},
|
||||
"ui": {"title": "Image Linear Interpolation", "tags": ["image", "linear", "interpolation", "lerp"]},
|
||||
}
|
||||
|
||||
def invoke(self, context: InvocationContext) -> ImageOutput:
|
||||
@ -619,7 +536,7 @@ class ImageInverseLerpInvocation(BaseInvocation, PILInvocationConfig):
|
||||
schema_extra = {
|
||||
"ui": {
|
||||
"title": "Image Inverse Linear Interpolation",
|
||||
"tags": ["image", "linear", "interpolation", "inverse"]
|
||||
"tags": ["image", "linear", "interpolation", "inverse"],
|
||||
},
|
||||
}
|
||||
|
||||
@ -627,12 +544,7 @@ class ImageInverseLerpInvocation(BaseInvocation, PILInvocationConfig):
|
||||
image = context.services.images.get_pil_image(self.image.image_name)
|
||||
|
||||
image_arr = numpy.asarray(image, dtype=numpy.float32)
|
||||
image_arr = (
|
||||
numpy.minimum(
|
||||
numpy.maximum(image_arr - self.min, 0) / float(self.max - self.min), 1
|
||||
)
|
||||
* 255
|
||||
)
|
||||
image_arr = numpy.minimum(numpy.maximum(image_arr - self.min, 0) / float(self.max - self.min), 1) * 255
|
||||
|
||||
ilerp_image = Image.fromarray(numpy.uint8(image_arr))
|
||||
|
||||
@ -650,3 +562,91 @@ class ImageInverseLerpInvocation(BaseInvocation, PILInvocationConfig):
|
||||
width=image_dto.width,
|
||||
height=image_dto.height,
|
||||
)
|
||||
|
||||
|
||||
class ImageNSFWBlurInvocation(BaseInvocation, PILInvocationConfig):
|
||||
"""Add blur to NSFW-flagged images"""
|
||||
|
||||
# fmt: off
|
||||
type: Literal["img_nsfw"] = "img_nsfw"
|
||||
|
||||
# Inputs
|
||||
image: Optional[ImageField] = Field(default=None, description="The image to check")
|
||||
metadata: Optional[CoreMetadata] = Field(default=None, description="Optional core metadata to be written to the image")
|
||||
# fmt: on
|
||||
|
||||
class Config(InvocationConfig):
|
||||
schema_extra = {
|
||||
"ui": {"title": "Blur NSFW Images", "tags": ["image", "nsfw", "checker"]},
|
||||
}
|
||||
|
||||
def invoke(self, context: InvocationContext) -> ImageOutput:
|
||||
image = context.services.images.get_pil_image(self.image.image_name)
|
||||
|
||||
logger = context.services.logger
|
||||
logger.debug("Running NSFW checker")
|
||||
if SafetyChecker.has_nsfw_concept(image):
|
||||
logger.info("A potentially NSFW image has been detected. Image will be blurred.")
|
||||
blurry_image = image.filter(filter=ImageFilter.GaussianBlur(radius=32))
|
||||
caution = self._get_caution_img()
|
||||
blurry_image.paste(caution, (0, 0), caution)
|
||||
image = blurry_image
|
||||
|
||||
image_dto = context.services.images.create(
|
||||
image=image,
|
||||
image_origin=ResourceOrigin.INTERNAL,
|
||||
image_category=ImageCategory.GENERAL,
|
||||
node_id=self.id,
|
||||
session_id=context.graph_execution_state_id,
|
||||
is_intermediate=self.is_intermediate,
|
||||
metadata=self.metadata.dict() if self.metadata else None,
|
||||
)
|
||||
|
||||
return ImageOutput(
|
||||
image=ImageField(image_name=image_dto.image_name),
|
||||
width=image_dto.width,
|
||||
height=image_dto.height,
|
||||
)
|
||||
|
||||
def _get_caution_img(self) -> Image:
|
||||
import invokeai.app.assets.images as image_assets
|
||||
|
||||
caution = Image.open(Path(image_assets.__path__[0]) / "caution.png")
|
||||
return caution.resize((caution.width // 2, caution.height // 2))
|
||||
|
||||
|
||||
class ImageWatermarkInvocation(BaseInvocation, PILInvocationConfig):
|
||||
"""Add an invisible watermark to an image"""
|
||||
|
||||
# fmt: off
|
||||
type: Literal["img_watermark"] = "img_watermark"
|
||||
|
||||
# Inputs
|
||||
image: Optional[ImageField] = Field(default=None, description="The image to check")
|
||||
text: str = Field(default='InvokeAI', description="Watermark text")
|
||||
metadata: Optional[CoreMetadata] = Field(default=None, description="Optional core metadata to be written to the image")
|
||||
# fmt: on
|
||||
|
||||
class Config(InvocationConfig):
|
||||
schema_extra = {
|
||||
"ui": {"title": "Add Invisible Watermark", "tags": ["image", "watermark", "invisible"]},
|
||||
}
|
||||
|
||||
def invoke(self, context: InvocationContext) -> ImageOutput:
|
||||
image = context.services.images.get_pil_image(self.image.image_name)
|
||||
new_image = InvisibleWatermark.add_watermark(image, self.text)
|
||||
image_dto = context.services.images.create(
|
||||
image=new_image,
|
||||
image_origin=ResourceOrigin.INTERNAL,
|
||||
image_category=ImageCategory.GENERAL,
|
||||
node_id=self.id,
|
||||
session_id=context.graph_execution_state_id,
|
||||
is_intermediate=self.is_intermediate,
|
||||
metadata=self.metadata.dict() if self.metadata else None,
|
||||
)
|
||||
|
||||
return ImageOutput(
|
||||
image=ImageField(image_name=image_dto.image_name),
|
||||
width=image_dto.width,
|
||||
height=image_dto.height,
|
||||
)
|
||||
|
@ -30,9 +30,7 @@ def infill_methods() -> list[str]:
|
||||
|
||||
|
||||
INFILL_METHODS = Literal[tuple(infill_methods())]
|
||||
DEFAULT_INFILL_METHOD = (
|
||||
"patchmatch" if "patchmatch" in get_args(INFILL_METHODS) else "tile"
|
||||
)
|
||||
DEFAULT_INFILL_METHOD = "patchmatch" if "patchmatch" in get_args(INFILL_METHODS) else "tile"
|
||||
|
||||
|
||||
def infill_patchmatch(im: Image.Image) -> Image.Image:
|
||||
@ -44,9 +42,7 @@ def infill_patchmatch(im: Image.Image) -> Image.Image:
|
||||
return im
|
||||
|
||||
# Patchmatch (note, we may want to expose patch_size? Increasing it significantly impacts performance though)
|
||||
im_patched_np = PatchMatch.inpaint(
|
||||
im.convert("RGB"), ImageOps.invert(im.split()[-1]), patch_size=3
|
||||
)
|
||||
im_patched_np = PatchMatch.inpaint(im.convert("RGB"), ImageOps.invert(im.split()[-1]), patch_size=3)
|
||||
im_patched = Image.fromarray(im_patched_np, mode="RGB")
|
||||
return im_patched
|
||||
|
||||
@ -68,9 +64,7 @@ def get_tile_images(image: np.ndarray, width=8, height=8):
|
||||
)
|
||||
|
||||
|
||||
def tile_fill_missing(
|
||||
im: Image.Image, tile_size: int = 16, seed: Optional[int] = None
|
||||
) -> Image.Image:
|
||||
def tile_fill_missing(im: Image.Image, tile_size: int = 16, seed: Optional[int] = None) -> Image.Image:
|
||||
# Only fill if there's an alpha layer
|
||||
if im.mode != "RGBA":
|
||||
return im
|
||||
@ -103,9 +97,7 @@ def tile_fill_missing(
|
||||
# Find all invalid tiles and replace with a random valid tile
|
||||
replace_count = (tiles_mask == False).sum()
|
||||
rng = np.random.default_rng(seed=seed)
|
||||
tiles_all[np.logical_not(tiles_mask)] = filtered_tiles[
|
||||
rng.choice(filtered_tiles.shape[0], replace_count), :, :, :
|
||||
]
|
||||
tiles_all[np.logical_not(tiles_mask)] = filtered_tiles[rng.choice(filtered_tiles.shape[0], replace_count), :, :, :]
|
||||
|
||||
# Convert back to an image
|
||||
tiles_all = tiles_all.reshape(tshape)
|
||||
@ -126,9 +118,7 @@ class InfillColorInvocation(BaseInvocation):
|
||||
"""Infills transparent areas of an image with a solid color"""
|
||||
|
||||
type: Literal["infill_rgba"] = "infill_rgba"
|
||||
image: Optional[ImageField] = Field(
|
||||
default=None, description="The image to infill"
|
||||
)
|
||||
image: Optional[ImageField] = Field(default=None, description="The image to infill")
|
||||
color: ColorField = Field(
|
||||
default=ColorField(r=127, g=127, b=127, a=255),
|
||||
description="The color to use to infill",
|
||||
@ -136,10 +126,7 @@ class InfillColorInvocation(BaseInvocation):
|
||||
|
||||
class Config(InvocationConfig):
|
||||
schema_extra = {
|
||||
"ui": {
|
||||
"title": "Color Infill",
|
||||
"tags": ["image", "inpaint", "color", "infill"]
|
||||
},
|
||||
"ui": {"title": "Color Infill", "tags": ["image", "inpaint", "color", "infill"]},
|
||||
}
|
||||
|
||||
def invoke(self, context: InvocationContext) -> ImageOutput:
|
||||
@ -171,9 +158,7 @@ class InfillTileInvocation(BaseInvocation):
|
||||
|
||||
type: Literal["infill_tile"] = "infill_tile"
|
||||
|
||||
image: Optional[ImageField] = Field(
|
||||
default=None, description="The image to infill"
|
||||
)
|
||||
image: Optional[ImageField] = Field(default=None, description="The image to infill")
|
||||
tile_size: int = Field(default=32, ge=1, description="The tile size (px)")
|
||||
seed: int = Field(
|
||||
ge=0,
|
||||
@ -184,18 +169,13 @@ class InfillTileInvocation(BaseInvocation):
|
||||
|
||||
class Config(InvocationConfig):
|
||||
schema_extra = {
|
||||
"ui": {
|
||||
"title": "Tile Infill",
|
||||
"tags": ["image", "inpaint", "tile", "infill"]
|
||||
},
|
||||
"ui": {"title": "Tile Infill", "tags": ["image", "inpaint", "tile", "infill"]},
|
||||
}
|
||||
|
||||
def invoke(self, context: InvocationContext) -> ImageOutput:
|
||||
image = context.services.images.get_pil_image(self.image.image_name)
|
||||
|
||||
infilled = tile_fill_missing(
|
||||
image.copy(), seed=self.seed, tile_size=self.tile_size
|
||||
)
|
||||
infilled = tile_fill_missing(image.copy(), seed=self.seed, tile_size=self.tile_size)
|
||||
infilled.paste(image, (0, 0), image.split()[-1])
|
||||
|
||||
image_dto = context.services.images.create(
|
||||
@ -219,16 +199,11 @@ class InfillPatchMatchInvocation(BaseInvocation):
|
||||
|
||||
type: Literal["infill_patchmatch"] = "infill_patchmatch"
|
||||
|
||||
image: Optional[ImageField] = Field(
|
||||
default=None, description="The image to infill"
|
||||
)
|
||||
image: Optional[ImageField] = Field(default=None, description="The image to infill")
|
||||
|
||||
class Config(InvocationConfig):
|
||||
schema_extra = {
|
||||
"ui": {
|
||||
"title": "Patch Match Infill",
|
||||
"tags": ["image", "inpaint", "patchmatch", "infill"]
|
||||
},
|
||||
"ui": {"title": "Patch Match Infill", "tags": ["image", "inpaint", "patchmatch", "infill"]},
|
||||
}
|
||||
|
||||
def invoke(self, context: InvocationContext) -> ImageOutput:
|
||||
|
@ -375,7 +375,7 @@ class TextToLatentsInvocation(BaseInvocation):
|
||||
|
||||
name = f"{context.graph_execution_state_id}__{self.id}"
|
||||
context.services.latents.save(name, result_latents)
|
||||
return build_latents_output(latents_name=name, latents=result_latents)
|
||||
return build_latents_output(latents_name=name, latents=result_latents)
|
||||
|
||||
|
||||
class LatentsToLatentsInvocation(TextToLatentsInvocation):
|
||||
@ -492,7 +492,7 @@ class LatentsToImageInvocation(BaseInvocation):
|
||||
# Inputs
|
||||
latents: Optional[LatentsField] = Field(description="The latents to generate an image from")
|
||||
vae: VaeField = Field(default=None, description="Vae submodel")
|
||||
tiled: bool = Field(default=False, description="Decode latents by overlaping tiles(less memory consumption)")
|
||||
tiled: bool = Field(default=False, description="Decode latents by overlaping tiles (less memory consumption)")
|
||||
fp32: bool = Field(DEFAULT_PRECISION == "float32", description="Decode in full precision")
|
||||
metadata: Optional[CoreMetadata] = Field(
|
||||
default=None, description="Optional core metadata to be written to the image"
|
||||
|
@ -54,10 +54,7 @@ class AddInvocation(BaseInvocation, MathInvocationConfig):
|
||||
|
||||
class Config(InvocationConfig):
|
||||
schema_extra = {
|
||||
"ui": {
|
||||
"title": "Add",
|
||||
"tags": ["math", "add"]
|
||||
},
|
||||
"ui": {"title": "Add", "tags": ["math", "add"]},
|
||||
}
|
||||
|
||||
def invoke(self, context: InvocationContext) -> IntOutput:
|
||||
@ -75,10 +72,7 @@ class SubtractInvocation(BaseInvocation, MathInvocationConfig):
|
||||
|
||||
class Config(InvocationConfig):
|
||||
schema_extra = {
|
||||
"ui": {
|
||||
"title": "Subtract",
|
||||
"tags": ["math", "subtract"]
|
||||
},
|
||||
"ui": {"title": "Subtract", "tags": ["math", "subtract"]},
|
||||
}
|
||||
|
||||
def invoke(self, context: InvocationContext) -> IntOutput:
|
||||
@ -96,10 +90,7 @@ class MultiplyInvocation(BaseInvocation, MathInvocationConfig):
|
||||
|
||||
class Config(InvocationConfig):
|
||||
schema_extra = {
|
||||
"ui": {
|
||||
"title": "Multiply",
|
||||
"tags": ["math", "multiply"]
|
||||
},
|
||||
"ui": {"title": "Multiply", "tags": ["math", "multiply"]},
|
||||
}
|
||||
|
||||
def invoke(self, context: InvocationContext) -> IntOutput:
|
||||
@ -117,10 +108,7 @@ class DivideInvocation(BaseInvocation, MathInvocationConfig):
|
||||
|
||||
class Config(InvocationConfig):
|
||||
schema_extra = {
|
||||
"ui": {
|
||||
"title": "Divide",
|
||||
"tags": ["math", "divide"]
|
||||
},
|
||||
"ui": {"title": "Divide", "tags": ["math", "divide"]},
|
||||
}
|
||||
|
||||
def invoke(self, context: InvocationContext) -> IntOutput:
|
||||
@ -140,10 +128,7 @@ class RandomIntInvocation(BaseInvocation):
|
||||
|
||||
class Config(InvocationConfig):
|
||||
schema_extra = {
|
||||
"ui": {
|
||||
"title": "Random Integer",
|
||||
"tags": ["math", "random", "integer"]
|
||||
},
|
||||
"ui": {"title": "Random Integer", "tags": ["math", "random", "integer"]},
|
||||
}
|
||||
|
||||
def invoke(self, context: InvocationContext) -> IntOutput:
|
||||
|
@ -2,16 +2,19 @@ from typing import Literal, Optional, Union
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from invokeai.app.invocations.baseinvocation import (BaseInvocation,
|
||||
BaseInvocationOutput, InvocationConfig,
|
||||
InvocationContext)
|
||||
from invokeai.app.invocations.baseinvocation import (
|
||||
BaseInvocation,
|
||||
BaseInvocationOutput,
|
||||
InvocationConfig,
|
||||
InvocationContext,
|
||||
)
|
||||
from invokeai.app.invocations.controlnet_image_processors import ControlField
|
||||
from invokeai.app.invocations.model import (LoRAModelField, MainModelField,
|
||||
VAEModelField)
|
||||
from invokeai.app.invocations.model import LoRAModelField, MainModelField, VAEModelField
|
||||
|
||||
|
||||
class LoRAMetadataField(BaseModel):
|
||||
"""LoRA metadata for an image generated in InvokeAI."""
|
||||
|
||||
lora: LoRAModelField = Field(description="The LoRA model")
|
||||
weight: float = Field(description="The weight of the LoRA model")
|
||||
|
||||
@ -19,7 +22,9 @@ class LoRAMetadataField(BaseModel):
|
||||
class CoreMetadata(BaseModel):
|
||||
"""Core generation metadata for an image generated in InvokeAI."""
|
||||
|
||||
generation_mode: str = Field(description="The generation mode that output this image",)
|
||||
generation_mode: str = Field(
|
||||
description="The generation mode that output this image",
|
||||
)
|
||||
positive_prompt: str = Field(description="The positive prompt parameter")
|
||||
negative_prompt: str = Field(description="The negative prompt parameter")
|
||||
width: int = Field(description="The width parameter")
|
||||
@ -29,22 +34,41 @@ class CoreMetadata(BaseModel):
|
||||
cfg_scale: float = Field(description="The classifier-free guidance scale parameter")
|
||||
steps: int = Field(description="The number of steps used for inference")
|
||||
scheduler: str = Field(description="The scheduler used for inference")
|
||||
clip_skip: int = Field(description="The number of skipped CLIP layers",)
|
||||
clip_skip: int = Field(
|
||||
description="The number of skipped CLIP layers",
|
||||
)
|
||||
model: MainModelField = Field(description="The main model used for inference")
|
||||
controlnets: list[ControlField]= Field(description="The ControlNets used for inference")
|
||||
controlnets: list[ControlField] = Field(description="The ControlNets used for inference")
|
||||
loras: list[LoRAMetadataField] = Field(description="The LoRAs used for inference")
|
||||
strength: Union[float, None] = Field(
|
||||
default=None,
|
||||
description="The strength used for latents-to-latents",
|
||||
)
|
||||
init_image: Union[str, None] = Field(
|
||||
default=None, description="The name of the initial image"
|
||||
)
|
||||
vae: Union[VAEModelField, None] = Field(
|
||||
default=None,
|
||||
description="The VAE used for decoding, if the main model's default was not used",
|
||||
)
|
||||
|
||||
# Latents-to-Latents
|
||||
strength: Union[float, None] = Field(
|
||||
default=None,
|
||||
description="The strength used for latents-to-latents",
|
||||
)
|
||||
init_image: Union[str, None] = Field(default=None, description="The name of the initial image")
|
||||
|
||||
# SDXL
|
||||
positive_style_prompt: Union[str, None] = Field(default=None, description="The positive style prompt parameter")
|
||||
negative_style_prompt: Union[str, None] = Field(default=None, description="The negative style prompt parameter")
|
||||
|
||||
# SDXL Refiner
|
||||
refiner_model: Union[MainModelField, None] = Field(default=None, description="The SDXL Refiner model used")
|
||||
refiner_cfg_scale: Union[float, None] = Field(
|
||||
default=None,
|
||||
description="The classifier-free guidance scale parameter used for the refiner",
|
||||
)
|
||||
refiner_steps: Union[int, None] = Field(default=None, description="The number of steps used for the refiner")
|
||||
refiner_scheduler: Union[str, None] = Field(default=None, description="The scheduler used for the refiner")
|
||||
refiner_aesthetic_store: Union[float, None] = Field(
|
||||
default=None, description="The aesthetic score used for the refiner"
|
||||
)
|
||||
refiner_start: Union[float, None] = Field(default=None, description="The start value used for refiner denoising")
|
||||
|
||||
|
||||
class ImageMetadata(BaseModel):
|
||||
"""An image's generation metadata"""
|
||||
@ -53,9 +77,7 @@ class ImageMetadata(BaseModel):
|
||||
default=None,
|
||||
description="The image's core metadata, if it was created in the Linear or Canvas UI",
|
||||
)
|
||||
graph: Optional[dict] = Field(
|
||||
default=None, description="The graph that created the image"
|
||||
)
|
||||
graph: Optional[dict] = Field(default=None, description="The graph that created the image")
|
||||
|
||||
|
||||
class MetadataAccumulatorOutput(BaseInvocationOutput):
|
||||
@ -71,7 +93,9 @@ class MetadataAccumulatorInvocation(BaseInvocation):
|
||||
|
||||
type: Literal["metadata_accumulator"] = "metadata_accumulator"
|
||||
|
||||
generation_mode: str = Field(description="The generation mode that output this image",)
|
||||
generation_mode: str = Field(
|
||||
description="The generation mode that output this image",
|
||||
)
|
||||
positive_prompt: str = Field(description="The positive prompt parameter")
|
||||
negative_prompt: str = Field(description="The negative prompt parameter")
|
||||
width: int = Field(description="The width parameter")
|
||||
@ -81,52 +105,48 @@ class MetadataAccumulatorInvocation(BaseInvocation):
|
||||
cfg_scale: float = Field(description="The classifier-free guidance scale parameter")
|
||||
steps: int = Field(description="The number of steps used for inference")
|
||||
scheduler: str = Field(description="The scheduler used for inference")
|
||||
clip_skip: int = Field(description="The number of skipped CLIP layers",)
|
||||
clip_skip: int = Field(
|
||||
description="The number of skipped CLIP layers",
|
||||
)
|
||||
model: MainModelField = Field(description="The main model used for inference")
|
||||
controlnets: list[ControlField]= Field(description="The ControlNets used for inference")
|
||||
controlnets: list[ControlField] = Field(description="The ControlNets used for inference")
|
||||
loras: list[LoRAMetadataField] = Field(description="The LoRAs used for inference")
|
||||
strength: Union[float, None] = Field(
|
||||
default=None,
|
||||
description="The strength used for latents-to-latents",
|
||||
)
|
||||
init_image: Union[str, None] = Field(
|
||||
default=None, description="The name of the initial image"
|
||||
)
|
||||
init_image: Union[str, None] = Field(default=None, description="The name of the initial image")
|
||||
vae: Union[VAEModelField, None] = Field(
|
||||
default=None,
|
||||
description="The VAE used for decoding, if the main model's default was not used",
|
||||
)
|
||||
|
||||
# SDXL
|
||||
positive_style_prompt: Union[str, None] = Field(default=None, description="The positive style prompt parameter")
|
||||
negative_style_prompt: Union[str, None] = Field(default=None, description="The negative style prompt parameter")
|
||||
|
||||
# SDXL Refiner
|
||||
refiner_model: Union[MainModelField, None] = Field(default=None, description="The SDXL Refiner model used")
|
||||
refiner_cfg_scale: Union[float, None] = Field(
|
||||
default=None,
|
||||
description="The classifier-free guidance scale parameter used for the refiner",
|
||||
)
|
||||
refiner_steps: Union[int, None] = Field(default=None, description="The number of steps used for the refiner")
|
||||
refiner_scheduler: Union[str, None] = Field(default=None, description="The scheduler used for the refiner")
|
||||
refiner_aesthetic_store: Union[float, None] = Field(
|
||||
default=None, description="The aesthetic score used for the refiner"
|
||||
)
|
||||
refiner_start: Union[float, None] = Field(default=None, description="The start value used for refiner denoising")
|
||||
|
||||
class Config(InvocationConfig):
|
||||
schema_extra = {
|
||||
"ui": {
|
||||
"title": "Metadata Accumulator",
|
||||
"tags": ["image", "metadata", "generation"]
|
||||
"tags": ["image", "metadata", "generation"],
|
||||
},
|
||||
}
|
||||
|
||||
|
||||
def invoke(self, context: InvocationContext) -> MetadataAccumulatorOutput:
|
||||
"""Collects and outputs a CoreMetadata object"""
|
||||
|
||||
return MetadataAccumulatorOutput(
|
||||
metadata=CoreMetadata(
|
||||
generation_mode=self.generation_mode,
|
||||
positive_prompt=self.positive_prompt,
|
||||
negative_prompt=self.negative_prompt,
|
||||
width=self.width,
|
||||
height=self.height,
|
||||
seed=self.seed,
|
||||
rand_device=self.rand_device,
|
||||
cfg_scale=self.cfg_scale,
|
||||
steps=self.steps,
|
||||
scheduler=self.scheduler,
|
||||
model=self.model,
|
||||
strength=self.strength,
|
||||
init_image=self.init_image,
|
||||
vae=self.vae,
|
||||
controlnets=self.controlnets,
|
||||
loras=self.loras,
|
||||
clip_skip=self.clip_skip,
|
||||
)
|
||||
)
|
||||
return MetadataAccumulatorOutput(metadata=CoreMetadata(**self.dict()))
|
||||
|
@ -4,17 +4,14 @@ from typing import List, Literal, Optional, Union
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from ...backend.model_management import BaseModelType, ModelType, SubModelType
|
||||
from .baseinvocation import (BaseInvocation, BaseInvocationOutput,
|
||||
InvocationConfig, InvocationContext)
|
||||
from .baseinvocation import BaseInvocation, BaseInvocationOutput, InvocationConfig, InvocationContext
|
||||
|
||||
|
||||
class ModelInfo(BaseModel):
|
||||
model_name: str = Field(description="Info to load submodel")
|
||||
base_model: BaseModelType = Field(description="Base model")
|
||||
model_type: ModelType = Field(description="Info to load submodel")
|
||||
submodel: Optional[SubModelType] = Field(
|
||||
default=None, description="Info to load submodel"
|
||||
)
|
||||
submodel: Optional[SubModelType] = Field(default=None, description="Info to load submodel")
|
||||
|
||||
|
||||
class LoraInfo(ModelInfo):
|
||||
@ -33,6 +30,7 @@ class ClipField(BaseModel):
|
||||
skipped_layers: int = Field(description="Number of skipped layers in text_encoder")
|
||||
loras: List[LoraInfo] = Field(description="Loras to apply on model loading")
|
||||
|
||||
|
||||
class VaeField(BaseModel):
|
||||
# TODO: better naming?
|
||||
vae: ModelInfo = Field(description="Info to load vae submodel")
|
||||
@ -49,6 +47,7 @@ class ModelLoaderOutput(BaseInvocationOutput):
|
||||
vae: VaeField = Field(default=None, description="Vae submodel")
|
||||
# fmt: on
|
||||
|
||||
|
||||
class MainModelField(BaseModel):
|
||||
"""Main model field"""
|
||||
|
||||
@ -62,6 +61,7 @@ class LoRAModelField(BaseModel):
|
||||
model_name: str = Field(description="Name of the LoRA model")
|
||||
base_model: BaseModelType = Field(description="Base model")
|
||||
|
||||
|
||||
class MainModelLoaderInvocation(BaseInvocation):
|
||||
"""Loads a main model, outputting its submodels."""
|
||||
|
||||
@ -180,7 +180,7 @@ class MainModelLoaderInvocation(BaseInvocation):
|
||||
),
|
||||
)
|
||||
|
||||
|
||||
|
||||
class LoraLoaderOutput(BaseInvocationOutput):
|
||||
"""Model loader output"""
|
||||
|
||||
@ -197,9 +197,7 @@ class LoraLoaderInvocation(BaseInvocation):
|
||||
|
||||
type: Literal["lora_loader"] = "lora_loader"
|
||||
|
||||
lora: Union[LoRAModelField, None] = Field(
|
||||
default=None, description="Lora model name"
|
||||
)
|
||||
lora: Union[LoRAModelField, None] = Field(default=None, description="Lora model name")
|
||||
weight: float = Field(default=0.75, description="With what weight to apply lora")
|
||||
|
||||
unet: Optional[UNetField] = Field(description="UNet model for applying lora")
|
||||
@ -228,14 +226,10 @@ class LoraLoaderInvocation(BaseInvocation):
|
||||
):
|
||||
raise Exception(f"Unkown lora name: {lora_name}!")
|
||||
|
||||
if self.unet is not None and any(
|
||||
lora.model_name == lora_name for lora in self.unet.loras
|
||||
):
|
||||
if self.unet is not None and any(lora.model_name == lora_name for lora in self.unet.loras):
|
||||
raise Exception(f'Lora "{lora_name}" already applied to unet')
|
||||
|
||||
if self.clip is not None and any(
|
||||
lora.model_name == lora_name for lora in self.clip.loras
|
||||
):
|
||||
if self.clip is not None and any(lora.model_name == lora_name for lora in self.clip.loras):
|
||||
raise Exception(f'Lora "{lora_name}" already applied to clip')
|
||||
|
||||
output = LoraLoaderOutput()
|
||||
|
@ -12,16 +12,37 @@ import matplotlib.pyplot as plt
|
||||
|
||||
from easing_functions import (
|
||||
LinearInOut,
|
||||
QuadEaseInOut, QuadEaseIn, QuadEaseOut,
|
||||
CubicEaseInOut, CubicEaseIn, CubicEaseOut,
|
||||
QuarticEaseInOut, QuarticEaseIn, QuarticEaseOut,
|
||||
QuinticEaseInOut, QuinticEaseIn, QuinticEaseOut,
|
||||
SineEaseInOut, SineEaseIn, SineEaseOut,
|
||||
CircularEaseIn, CircularEaseInOut, CircularEaseOut,
|
||||
ExponentialEaseInOut, ExponentialEaseIn, ExponentialEaseOut,
|
||||
ElasticEaseIn, ElasticEaseInOut, ElasticEaseOut,
|
||||
BackEaseIn, BackEaseInOut, BackEaseOut,
|
||||
BounceEaseIn, BounceEaseInOut, BounceEaseOut)
|
||||
QuadEaseInOut,
|
||||
QuadEaseIn,
|
||||
QuadEaseOut,
|
||||
CubicEaseInOut,
|
||||
CubicEaseIn,
|
||||
CubicEaseOut,
|
||||
QuarticEaseInOut,
|
||||
QuarticEaseIn,
|
||||
QuarticEaseOut,
|
||||
QuinticEaseInOut,
|
||||
QuinticEaseIn,
|
||||
QuinticEaseOut,
|
||||
SineEaseInOut,
|
||||
SineEaseIn,
|
||||
SineEaseOut,
|
||||
CircularEaseIn,
|
||||
CircularEaseInOut,
|
||||
CircularEaseOut,
|
||||
ExponentialEaseInOut,
|
||||
ExponentialEaseIn,
|
||||
ExponentialEaseOut,
|
||||
ElasticEaseIn,
|
||||
ElasticEaseInOut,
|
||||
ElasticEaseOut,
|
||||
BackEaseIn,
|
||||
BackEaseInOut,
|
||||
BackEaseOut,
|
||||
BounceEaseIn,
|
||||
BounceEaseInOut,
|
||||
BounceEaseOut,
|
||||
)
|
||||
|
||||
from .baseinvocation import (
|
||||
BaseInvocation,
|
||||
@ -45,17 +66,12 @@ class FloatLinearRangeInvocation(BaseInvocation):
|
||||
|
||||
class Config(InvocationConfig):
|
||||
schema_extra = {
|
||||
"ui": {
|
||||
"title": "Linear Range (Float)",
|
||||
"tags": ["math", "float", "linear", "range"]
|
||||
},
|
||||
"ui": {"title": "Linear Range (Float)", "tags": ["math", "float", "linear", "range"]},
|
||||
}
|
||||
|
||||
def invoke(self, context: InvocationContext) -> FloatCollectionOutput:
|
||||
param_list = list(np.linspace(self.start, self.stop, self.steps))
|
||||
return FloatCollectionOutput(
|
||||
collection=param_list
|
||||
)
|
||||
return FloatCollectionOutput(collection=param_list)
|
||||
|
||||
|
||||
EASING_FUNCTIONS_MAP = {
|
||||
@ -92,9 +108,7 @@ EASING_FUNCTIONS_MAP = {
|
||||
"BounceInOut": BounceEaseInOut,
|
||||
}
|
||||
|
||||
EASING_FUNCTION_KEYS: Any = Literal[
|
||||
tuple(list(EASING_FUNCTIONS_MAP.keys()))
|
||||
]
|
||||
EASING_FUNCTION_KEYS: Any = Literal[tuple(list(EASING_FUNCTIONS_MAP.keys()))]
|
||||
|
||||
|
||||
# actually I think for now could just use CollectionOutput (which is list[Any]
|
||||
@ -123,13 +137,9 @@ class StepParamEasingInvocation(BaseInvocation):
|
||||
|
||||
class Config(InvocationConfig):
|
||||
schema_extra = {
|
||||
"ui": {
|
||||
"title": "Param Easing By Step",
|
||||
"tags": ["param", "step", "easing"]
|
||||
},
|
||||
"ui": {"title": "Param Easing By Step", "tags": ["param", "step", "easing"]},
|
||||
}
|
||||
|
||||
|
||||
def invoke(self, context: InvocationContext) -> FloatCollectionOutput:
|
||||
log_diagnostics = False
|
||||
# convert from start_step_percent to nearest step <= (steps * start_step_percent)
|
||||
@ -170,12 +180,13 @@ class StepParamEasingInvocation(BaseInvocation):
|
||||
# and create reverse copy of list[1:end-1]
|
||||
# but if even then number_of_steps/2 === ceil(number_of_steps/2), so can just use ceil always
|
||||
|
||||
base_easing_duration = int(np.ceil(num_easing_steps/2.0))
|
||||
if log_diagnostics: context.services.logger.debug("base easing duration: " + str(base_easing_duration))
|
||||
even_num_steps = (num_easing_steps % 2 == 0) # even number of steps
|
||||
easing_function = easing_class(start=self.start_value,
|
||||
end=self.end_value,
|
||||
duration=base_easing_duration - 1)
|
||||
base_easing_duration = int(np.ceil(num_easing_steps / 2.0))
|
||||
if log_diagnostics:
|
||||
context.services.logger.debug("base easing duration: " + str(base_easing_duration))
|
||||
even_num_steps = num_easing_steps % 2 == 0 # even number of steps
|
||||
easing_function = easing_class(
|
||||
start=self.start_value, end=self.end_value, duration=base_easing_duration - 1
|
||||
)
|
||||
base_easing_vals = list()
|
||||
for step_index in range(base_easing_duration):
|
||||
easing_val = easing_function.ease(step_index)
|
||||
@ -214,9 +225,7 @@ class StepParamEasingInvocation(BaseInvocation):
|
||||
#
|
||||
|
||||
else: # no mirroring (default)
|
||||
easing_function = easing_class(start=self.start_value,
|
||||
end=self.end_value,
|
||||
duration=num_easing_steps - 1)
|
||||
easing_function = easing_class(start=self.start_value, end=self.end_value, duration=num_easing_steps - 1)
|
||||
for step_index in range(num_easing_steps):
|
||||
step_val = easing_function.ease(step_index)
|
||||
easing_list.append(step_val)
|
||||
@ -240,13 +249,11 @@ class StepParamEasingInvocation(BaseInvocation):
|
||||
ax = plt.gca()
|
||||
ax.xaxis.set_major_locator(MaxNLocator(integer=True))
|
||||
buf = io.BytesIO()
|
||||
plt.savefig(buf, format='png')
|
||||
plt.savefig(buf, format="png")
|
||||
buf.seek(0)
|
||||
im = PIL.Image.open(buf)
|
||||
im.show()
|
||||
buf.close()
|
||||
|
||||
# output array of size steps, each entry list[i] is param value for step i
|
||||
return FloatCollectionOutput(
|
||||
collection=param_list
|
||||
)
|
||||
return FloatCollectionOutput(collection=param_list)
|
||||
|
@ -4,67 +4,63 @@ from typing import Literal
|
||||
|
||||
from pydantic import Field
|
||||
|
||||
from .baseinvocation import (BaseInvocation, BaseInvocationOutput,
|
||||
InvocationConfig, InvocationContext)
|
||||
from .baseinvocation import BaseInvocation, BaseInvocationOutput, InvocationConfig, InvocationContext
|
||||
from .math import FloatOutput, IntOutput
|
||||
|
||||
# Pass-through parameter nodes - used by subgraphs
|
||||
|
||||
|
||||
class ParamIntInvocation(BaseInvocation):
|
||||
"""An integer parameter"""
|
||||
#fmt: off
|
||||
|
||||
# fmt: off
|
||||
type: Literal["param_int"] = "param_int"
|
||||
a: int = Field(default=0, description="The integer value")
|
||||
#fmt: on
|
||||
# fmt: on
|
||||
|
||||
class Config(InvocationConfig):
|
||||
schema_extra = {
|
||||
"ui": {
|
||||
"tags": ["param", "integer"],
|
||||
"title": "Integer Parameter"
|
||||
},
|
||||
}
|
||||
schema_extra = {
|
||||
"ui": {"tags": ["param", "integer"], "title": "Integer Parameter"},
|
||||
}
|
||||
|
||||
def invoke(self, context: InvocationContext) -> IntOutput:
|
||||
return IntOutput(a=self.a)
|
||||
|
||||
|
||||
class ParamFloatInvocation(BaseInvocation):
|
||||
"""A float parameter"""
|
||||
#fmt: off
|
||||
|
||||
# fmt: off
|
||||
type: Literal["param_float"] = "param_float"
|
||||
param: float = Field(default=0.0, description="The float value")
|
||||
#fmt: on
|
||||
# fmt: on
|
||||
|
||||
class Config(InvocationConfig):
|
||||
schema_extra = {
|
||||
"ui": {
|
||||
"tags": ["param", "float"],
|
||||
"title": "Float Parameter"
|
||||
},
|
||||
}
|
||||
schema_extra = {
|
||||
"ui": {"tags": ["param", "float"], "title": "Float Parameter"},
|
||||
}
|
||||
|
||||
def invoke(self, context: InvocationContext) -> FloatOutput:
|
||||
return FloatOutput(param=self.param)
|
||||
|
||||
|
||||
class StringOutput(BaseInvocationOutput):
|
||||
"""A string output"""
|
||||
|
||||
type: Literal["string_output"] = "string_output"
|
||||
text: str = Field(default=None, description="The output string")
|
||||
|
||||
|
||||
class ParamStringInvocation(BaseInvocation):
|
||||
"""A string parameter"""
|
||||
type: Literal['param_string'] = 'param_string'
|
||||
text: str = Field(default='', description='The string value')
|
||||
|
||||
type: Literal["param_string"] = "param_string"
|
||||
text: str = Field(default="", description="The string value")
|
||||
|
||||
class Config(InvocationConfig):
|
||||
schema_extra = {
|
||||
"ui": {
|
||||
"tags": ["param", "string"],
|
||||
"title": "String Parameter"
|
||||
},
|
||||
}
|
||||
schema_extra = {
|
||||
"ui": {"tags": ["param", "string"], "title": "String Parameter"},
|
||||
}
|
||||
|
||||
def invoke(self, context: InvocationContext) -> StringOutput:
|
||||
return StringOutput(text=self.text)
|
||||
|
@ -7,19 +7,21 @@ from pydantic import Field, validator
|
||||
from .baseinvocation import BaseInvocation, BaseInvocationOutput, InvocationConfig, InvocationContext
|
||||
from dynamicprompts.generators import RandomPromptGenerator, CombinatorialPromptGenerator
|
||||
|
||||
|
||||
class PromptOutput(BaseInvocationOutput):
|
||||
"""Base class for invocations that output a prompt"""
|
||||
#fmt: off
|
||||
|
||||
# fmt: off
|
||||
type: Literal["prompt"] = "prompt"
|
||||
|
||||
prompt: str = Field(default=None, description="The output prompt")
|
||||
#fmt: on
|
||||
# fmt: on
|
||||
|
||||
class Config:
|
||||
schema_extra = {
|
||||
'required': [
|
||||
'type',
|
||||
'prompt',
|
||||
"required": [
|
||||
"type",
|
||||
"prompt",
|
||||
]
|
||||
}
|
||||
|
||||
@ -44,16 +46,11 @@ class DynamicPromptInvocation(BaseInvocation):
|
||||
type: Literal["dynamic_prompt"] = "dynamic_prompt"
|
||||
prompt: str = Field(description="The prompt to parse with dynamicprompts")
|
||||
max_prompts: int = Field(default=1, description="The number of prompts to generate")
|
||||
combinatorial: bool = Field(
|
||||
default=False, description="Whether to use the combinatorial generator"
|
||||
)
|
||||
combinatorial: bool = Field(default=False, description="Whether to use the combinatorial generator")
|
||||
|
||||
class Config(InvocationConfig):
|
||||
schema_extra = {
|
||||
"ui": {
|
||||
"title": "Dynamic Prompt",
|
||||
"tags": ["prompt", "dynamic"]
|
||||
},
|
||||
"ui": {"title": "Dynamic Prompt", "tags": ["prompt", "dynamic"]},
|
||||
}
|
||||
|
||||
def invoke(self, context: InvocationContext) -> PromptCollectionOutput:
|
||||
@ -65,10 +62,11 @@ class DynamicPromptInvocation(BaseInvocation):
|
||||
prompts = generator.generate(self.prompt, num_images=self.max_prompts)
|
||||
|
||||
return PromptCollectionOutput(prompt_collection=prompts, count=len(prompts))
|
||||
|
||||
|
||||
|
||||
class PromptsFromFileInvocation(BaseInvocation):
|
||||
'''Loads prompts from a text file'''
|
||||
"""Loads prompts from a text file"""
|
||||
|
||||
# fmt: off
|
||||
type: Literal['prompt_from_file'] = 'prompt_from_file'
|
||||
|
||||
@ -78,14 +76,11 @@ class PromptsFromFileInvocation(BaseInvocation):
|
||||
post_prompt: Optional[str] = Field(description="String to append to each prompt")
|
||||
start_line: int = Field(default=1, ge=1, description="Line in the file to start start from")
|
||||
max_prompts: int = Field(default=1, ge=0, description="Max lines to read from file (0=all)")
|
||||
#fmt: on
|
||||
# fmt: on
|
||||
|
||||
class Config(InvocationConfig):
|
||||
schema_extra = {
|
||||
"ui": {
|
||||
"title": "Prompts From File",
|
||||
"tags": ["prompt", "file"]
|
||||
},
|
||||
"ui": {"title": "Prompts From File", "tags": ["prompt", "file"]},
|
||||
}
|
||||
|
||||
@validator("file_path")
|
||||
@ -103,11 +98,13 @@ class PromptsFromFileInvocation(BaseInvocation):
|
||||
with open(file_path) as f:
|
||||
for i, line in enumerate(f):
|
||||
if i >= start_line and i < end_line:
|
||||
prompts.append((pre_prompt or '') + line.strip() + (post_prompt or ''))
|
||||
prompts.append((pre_prompt or "") + line.strip() + (post_prompt or ""))
|
||||
if i >= end_line:
|
||||
break
|
||||
return prompts
|
||||
|
||||
def invoke(self, context: InvocationContext) -> PromptCollectionOutput:
|
||||
prompts = self.promptsFromFile(self.file_path, self.pre_prompt, self.post_prompt, self.start_line, self.max_prompts)
|
||||
prompts = self.promptsFromFile(
|
||||
self.file_path, self.pre_prompt, self.post_prompt, self.start_line, self.max_prompts
|
||||
)
|
||||
return PromptCollectionOutput(prompt_collection=prompts, count=len(prompts))
|
||||
|
@ -7,13 +7,13 @@ from pydantic import Field, validator
|
||||
|
||||
from ...backend.model_management import ModelType, SubModelType
|
||||
from invokeai.app.util.step_callback import stable_diffusion_xl_step_callback
|
||||
from .baseinvocation import (BaseInvocation, BaseInvocationOutput,
|
||||
InvocationConfig, InvocationContext)
|
||||
from .baseinvocation import BaseInvocation, BaseInvocationOutput, InvocationConfig, InvocationContext
|
||||
|
||||
from .model import UNetField, ClipField, VaeField, MainModelField, ModelInfo
|
||||
from .compel import ConditioningField
|
||||
from .latent import LatentsField, SAMPLER_NAME_VALUES, LatentsOutput, get_scheduler, build_latents_output
|
||||
|
||||
|
||||
class SDXLModelLoaderOutput(BaseInvocationOutput):
|
||||
"""SDXL base model loader output"""
|
||||
|
||||
@ -26,16 +26,19 @@ class SDXLModelLoaderOutput(BaseInvocationOutput):
|
||||
vae: VaeField = Field(default=None, description="Vae submodel")
|
||||
# fmt: on
|
||||
|
||||
|
||||
class SDXLRefinerModelLoaderOutput(BaseInvocationOutput):
|
||||
"""SDXL refiner model loader output"""
|
||||
|
||||
# fmt: off
|
||||
type: Literal["sdxl_refiner_model_loader_output"] = "sdxl_refiner_model_loader_output"
|
||||
unet: UNetField = Field(default=None, description="UNet submodel")
|
||||
clip2: ClipField = Field(default=None, description="Tokenizer and text_encoder submodels")
|
||||
vae: VaeField = Field(default=None, description="Vae submodel")
|
||||
# fmt: on
|
||||
#fmt: on
|
||||
|
||||
# fmt: on
|
||||
|
||||
|
||||
class SDXLModelLoaderInvocation(BaseInvocation):
|
||||
"""Loads an sdxl base model, outputting its submodels."""
|
||||
|
||||
@ -125,8 +128,10 @@ class SDXLModelLoaderInvocation(BaseInvocation):
|
||||
),
|
||||
)
|
||||
|
||||
|
||||
class SDXLRefinerModelLoaderInvocation(BaseInvocation):
|
||||
"""Loads an sdxl refiner model, outputting its submodels."""
|
||||
|
||||
type: Literal["sdxl_refiner_model_loader"] = "sdxl_refiner_model_loader"
|
||||
|
||||
model: MainModelField = Field(description="The model to load")
|
||||
@ -138,7 +143,7 @@ class SDXLRefinerModelLoaderInvocation(BaseInvocation):
|
||||
"ui": {
|
||||
"title": "SDXL Refiner Model Loader",
|
||||
"tags": ["model", "loader", "sdxl_refiner"],
|
||||
"type_hints": {"model": "model"},
|
||||
"type_hints": {"model": "refiner_model"},
|
||||
},
|
||||
}
|
||||
|
||||
@ -196,7 +201,8 @@ class SDXLRefinerModelLoaderInvocation(BaseInvocation):
|
||||
),
|
||||
),
|
||||
)
|
||||
|
||||
|
||||
|
||||
# Text to image
|
||||
class SDXLTextToLatentsInvocation(BaseInvocation):
|
||||
"""Generates latents from conditionings."""
|
||||
@ -213,9 +219,9 @@ class SDXLTextToLatentsInvocation(BaseInvocation):
|
||||
scheduler: SAMPLER_NAME_VALUES = Field(default="euler", description="The scheduler to use" )
|
||||
unet: UNetField = Field(default=None, description="UNet submodel")
|
||||
denoising_end: float = Field(default=1.0, gt=0, le=1, description="")
|
||||
#control: Union[ControlField, list[ControlField]] = Field(default=None, description="The control to use")
|
||||
#seamless: bool = Field(default=False, description="Whether or not to generate an image that can tile without seams", )
|
||||
#seamless_axes: str = Field(default="", description="The axes to tile the image on, 'x' and/or 'y'")
|
||||
# control: Union[ControlField, list[ControlField]] = Field(default=None, description="The control to use")
|
||||
# seamless: bool = Field(default=False, description="Whether or not to generate an image that can tile without seams", )
|
||||
# seamless_axes: str = Field(default="", description="The axes to tile the image on, 'x' and/or 'y'")
|
||||
# fmt: on
|
||||
|
||||
@validator("cfg_scale")
|
||||
@ -224,10 +230,10 @@ class SDXLTextToLatentsInvocation(BaseInvocation):
|
||||
if isinstance(v, list):
|
||||
for i in v:
|
||||
if i < 1:
|
||||
raise ValueError('cfg_scale must be greater than 1')
|
||||
raise ValueError("cfg_scale must be greater than 1")
|
||||
else:
|
||||
if v < 1:
|
||||
raise ValueError('cfg_scale must be greater than 1')
|
||||
raise ValueError("cfg_scale must be greater than 1")
|
||||
return v
|
||||
|
||||
# Schema customisation
|
||||
@ -237,10 +243,10 @@ class SDXLTextToLatentsInvocation(BaseInvocation):
|
||||
"title": "SDXL Text To Latents",
|
||||
"tags": ["latents"],
|
||||
"type_hints": {
|
||||
"model": "model",
|
||||
# "cfg_scale": "float",
|
||||
"cfg_scale": "number"
|
||||
}
|
||||
"model": "model",
|
||||
# "cfg_scale": "float",
|
||||
"cfg_scale": "number",
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
@ -265,9 +271,7 @@ class SDXLTextToLatentsInvocation(BaseInvocation):
|
||||
# https://github.com/huggingface/diffusers/blob/3ebbaf7c96801271f9e6c21400033b6aa5ffcf29/src/diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion.py#L375
|
||||
@torch.no_grad()
|
||||
def invoke(self, context: InvocationContext) -> LatentsOutput:
|
||||
graph_execution_state = context.services.graph_execution_manager.get(
|
||||
context.graph_execution_state_id
|
||||
)
|
||||
graph_execution_state = context.services.graph_execution_manager.get(context.graph_execution_state_id)
|
||||
source_node_id = graph_execution_state.prepared_source_mapping[self.id]
|
||||
latents = context.services.latents.get(self.noise.latents_name)
|
||||
|
||||
@ -293,14 +297,10 @@ class SDXLTextToLatentsInvocation(BaseInvocation):
|
||||
|
||||
latents = latents * scheduler.init_noise_sigma
|
||||
|
||||
|
||||
unet_info = context.services.model_manager.get_model(
|
||||
**self.unet.unet.dict()
|
||||
)
|
||||
unet_info = context.services.model_manager.get_model(**self.unet.unet.dict(), context=context)
|
||||
do_classifier_free_guidance = True
|
||||
cross_attention_kwargs = None
|
||||
with unet_info as unet:
|
||||
|
||||
extra_step_kwargs = dict()
|
||||
if "eta" in set(inspect.signature(scheduler.step).parameters.keys()):
|
||||
extra_step_kwargs.update(
|
||||
@ -350,10 +350,10 @@ class SDXLTextToLatentsInvocation(BaseInvocation):
|
||||
if do_classifier_free_guidance:
|
||||
noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
|
||||
noise_pred = noise_pred_uncond + self.cfg_scale * (noise_pred_text - noise_pred_uncond)
|
||||
#del noise_pred_uncond
|
||||
#del noise_pred_text
|
||||
# del noise_pred_uncond
|
||||
# del noise_pred_text
|
||||
|
||||
#if do_classifier_free_guidance and guidance_rescale > 0.0:
|
||||
# if do_classifier_free_guidance and guidance_rescale > 0.0:
|
||||
# # Based on 3.4. in https://arxiv.org/pdf/2305.08891.pdf
|
||||
# noise_pred = rescale_noise_cfg(noise_pred, noise_pred_text, guidance_rescale=guidance_rescale)
|
||||
|
||||
@ -364,7 +364,7 @@ class SDXLTextToLatentsInvocation(BaseInvocation):
|
||||
if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % scheduler.order == 0):
|
||||
progress_bar.update()
|
||||
self.dispatch_progress(context, source_node_id, latents, i, num_inference_steps)
|
||||
#if callback is not None and i % callback_steps == 0:
|
||||
# if callback is not None and i % callback_steps == 0:
|
||||
# callback(i, t, latents)
|
||||
else:
|
||||
negative_pooled_prompt_embeds = negative_pooled_prompt_embeds.to(device=unet.device, dtype=unet.dtype)
|
||||
@ -378,13 +378,13 @@ class SDXLTextToLatentsInvocation(BaseInvocation):
|
||||
with tqdm(total=num_inference_steps) as progress_bar:
|
||||
for i, t in enumerate(timesteps):
|
||||
# expand the latents if we are doing classifier free guidance
|
||||
#latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
|
||||
# latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
|
||||
|
||||
latent_model_input = scheduler.scale_model_input(latents, t)
|
||||
|
||||
#import gc
|
||||
#gc.collect()
|
||||
#torch.cuda.empty_cache()
|
||||
# import gc
|
||||
# gc.collect()
|
||||
# torch.cuda.empty_cache()
|
||||
|
||||
# predict the noise residual
|
||||
|
||||
@ -411,42 +411,41 @@ class SDXLTextToLatentsInvocation(BaseInvocation):
|
||||
# perform guidance
|
||||
noise_pred = noise_pred_uncond + self.cfg_scale * (noise_pred_text - noise_pred_uncond)
|
||||
|
||||
#del noise_pred_text
|
||||
#del noise_pred_uncond
|
||||
#import gc
|
||||
#gc.collect()
|
||||
#torch.cuda.empty_cache()
|
||||
# del noise_pred_text
|
||||
# del noise_pred_uncond
|
||||
# import gc
|
||||
# gc.collect()
|
||||
# torch.cuda.empty_cache()
|
||||
|
||||
#if do_classifier_free_guidance and guidance_rescale > 0.0:
|
||||
# if do_classifier_free_guidance and guidance_rescale > 0.0:
|
||||
# # Based on 3.4. in https://arxiv.org/pdf/2305.08891.pdf
|
||||
# noise_pred = rescale_noise_cfg(noise_pred, noise_pred_text, guidance_rescale=guidance_rescale)
|
||||
|
||||
# compute the previous noisy sample x_t -> x_t-1
|
||||
latents = scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0]
|
||||
|
||||
#del noise_pred
|
||||
#import gc
|
||||
#gc.collect()
|
||||
#torch.cuda.empty_cache()
|
||||
# del noise_pred
|
||||
# import gc
|
||||
# gc.collect()
|
||||
# torch.cuda.empty_cache()
|
||||
|
||||
# call the callback, if provided
|
||||
if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % scheduler.order == 0):
|
||||
progress_bar.update()
|
||||
self.dispatch_progress(context, source_node_id, latents, i, num_inference_steps)
|
||||
#if callback is not None and i % callback_steps == 0:
|
||||
# if callback is not None and i % callback_steps == 0:
|
||||
# callback(i, t, latents)
|
||||
|
||||
|
||||
|
||||
#################
|
||||
|
||||
latents = latents.to("cpu")
|
||||
torch.cuda.empty_cache()
|
||||
|
||||
name = f'{context.graph_execution_state_id}__{self.id}'
|
||||
name = f"{context.graph_execution_state_id}__{self.id}"
|
||||
context.services.latents.save(name, latents)
|
||||
return build_latents_output(latents_name=name, latents=latents)
|
||||
|
||||
|
||||
class SDXLLatentsToLatentsInvocation(BaseInvocation):
|
||||
"""Generates latents from conditionings."""
|
||||
|
||||
@ -463,12 +462,12 @@ class SDXLLatentsToLatentsInvocation(BaseInvocation):
|
||||
unet: UNetField = Field(default=None, description="UNet submodel")
|
||||
latents: Optional[LatentsField] = Field(description="Initial latents")
|
||||
|
||||
denoising_start: float = Field(default=0.0, ge=0, lt=1, description="")
|
||||
denoising_end: float = Field(default=1.0, gt=0, le=1, description="")
|
||||
denoising_start: float = Field(default=0.0, ge=0, le=1, description="")
|
||||
denoising_end: float = Field(default=1.0, ge=0, le=1, description="")
|
||||
|
||||
#control: Union[ControlField, list[ControlField]] = Field(default=None, description="The control to use")
|
||||
#seamless: bool = Field(default=False, description="Whether or not to generate an image that can tile without seams", )
|
||||
#seamless_axes: str = Field(default="", description="The axes to tile the image on, 'x' and/or 'y'")
|
||||
# control: Union[ControlField, list[ControlField]] = Field(default=None, description="The control to use")
|
||||
# seamless: bool = Field(default=False, description="Whether or not to generate an image that can tile without seams", )
|
||||
# seamless_axes: str = Field(default="", description="The axes to tile the image on, 'x' and/or 'y'")
|
||||
# fmt: on
|
||||
|
||||
@validator("cfg_scale")
|
||||
@ -477,10 +476,10 @@ class SDXLLatentsToLatentsInvocation(BaseInvocation):
|
||||
if isinstance(v, list):
|
||||
for i in v:
|
||||
if i < 1:
|
||||
raise ValueError('cfg_scale must be greater than 1')
|
||||
raise ValueError("cfg_scale must be greater than 1")
|
||||
else:
|
||||
if v < 1:
|
||||
raise ValueError('cfg_scale must be greater than 1')
|
||||
raise ValueError("cfg_scale must be greater than 1")
|
||||
return v
|
||||
|
||||
# Schema customisation
|
||||
@ -490,10 +489,10 @@ class SDXLLatentsToLatentsInvocation(BaseInvocation):
|
||||
"title": "SDXL Latents to Latents",
|
||||
"tags": ["latents"],
|
||||
"type_hints": {
|
||||
"model": "model",
|
||||
# "cfg_scale": "float",
|
||||
"cfg_scale": "number"
|
||||
}
|
||||
"model": "model",
|
||||
# "cfg_scale": "float",
|
||||
"cfg_scale": "number",
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
@ -518,9 +517,7 @@ class SDXLLatentsToLatentsInvocation(BaseInvocation):
|
||||
# https://github.com/huggingface/diffusers/blob/3ebbaf7c96801271f9e6c21400033b6aa5ffcf29/src/diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion.py#L375
|
||||
@torch.no_grad()
|
||||
def invoke(self, context: InvocationContext) -> LatentsOutput:
|
||||
graph_execution_state = context.services.graph_execution_manager.get(
|
||||
context.graph_execution_state_id
|
||||
)
|
||||
graph_execution_state = context.services.graph_execution_manager.get(context.graph_execution_state_id)
|
||||
source_node_id = graph_execution_state.prepared_source_mapping[self.id]
|
||||
latents = context.services.latents.get(self.latents.latents_name)
|
||||
|
||||
@ -545,22 +542,22 @@ class SDXLLatentsToLatentsInvocation(BaseInvocation):
|
||||
scheduler.set_timesteps(num_inference_steps)
|
||||
|
||||
t_start = int(round(self.denoising_start * num_inference_steps))
|
||||
timesteps = scheduler.timesteps[t_start * scheduler.order:]
|
||||
timesteps = scheduler.timesteps[t_start * scheduler.order :]
|
||||
num_inference_steps = num_inference_steps - t_start
|
||||
|
||||
# apply noise(if provided)
|
||||
if self.noise is not None:
|
||||
if self.noise is not None and timesteps.shape[0] > 0:
|
||||
noise = context.services.latents.get(self.noise.latents_name)
|
||||
latents = scheduler.add_noise(latents, noise, timesteps[:1])
|
||||
del noise
|
||||
|
||||
unet_info = context.services.model_manager.get_model(
|
||||
**self.unet.unet.dict()
|
||||
**self.unet.unet.dict(),
|
||||
context=context,
|
||||
)
|
||||
do_classifier_free_guidance = True
|
||||
cross_attention_kwargs = None
|
||||
with unet_info as unet:
|
||||
|
||||
# apply scheduler extra args
|
||||
extra_step_kwargs = dict()
|
||||
if "eta" in set(inspect.signature(scheduler.step).parameters.keys()):
|
||||
@ -611,10 +608,10 @@ class SDXLLatentsToLatentsInvocation(BaseInvocation):
|
||||
if do_classifier_free_guidance:
|
||||
noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
|
||||
noise_pred = noise_pred_uncond + self.cfg_scale * (noise_pred_text - noise_pred_uncond)
|
||||
#del noise_pred_uncond
|
||||
#del noise_pred_text
|
||||
# del noise_pred_uncond
|
||||
# del noise_pred_text
|
||||
|
||||
#if do_classifier_free_guidance and guidance_rescale > 0.0:
|
||||
# if do_classifier_free_guidance and guidance_rescale > 0.0:
|
||||
# # Based on 3.4. in https://arxiv.org/pdf/2305.08891.pdf
|
||||
# noise_pred = rescale_noise_cfg(noise_pred, noise_pred_text, guidance_rescale=guidance_rescale)
|
||||
|
||||
@ -625,7 +622,7 @@ class SDXLLatentsToLatentsInvocation(BaseInvocation):
|
||||
if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % scheduler.order == 0):
|
||||
progress_bar.update()
|
||||
self.dispatch_progress(context, source_node_id, latents, i, num_inference_steps)
|
||||
#if callback is not None and i % callback_steps == 0:
|
||||
# if callback is not None and i % callback_steps == 0:
|
||||
# callback(i, t, latents)
|
||||
else:
|
||||
negative_pooled_prompt_embeds = negative_pooled_prompt_embeds.to(device=unet.device, dtype=unet.dtype)
|
||||
@ -639,13 +636,13 @@ class SDXLLatentsToLatentsInvocation(BaseInvocation):
|
||||
with tqdm(total=num_inference_steps) as progress_bar:
|
||||
for i, t in enumerate(timesteps):
|
||||
# expand the latents if we are doing classifier free guidance
|
||||
#latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
|
||||
# latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
|
||||
|
||||
latent_model_input = scheduler.scale_model_input(latents, t)
|
||||
|
||||
#import gc
|
||||
#gc.collect()
|
||||
#torch.cuda.empty_cache()
|
||||
# import gc
|
||||
# gc.collect()
|
||||
# torch.cuda.empty_cache()
|
||||
|
||||
# predict the noise residual
|
||||
|
||||
@ -672,38 +669,36 @@ class SDXLLatentsToLatentsInvocation(BaseInvocation):
|
||||
# perform guidance
|
||||
noise_pred = noise_pred_uncond + self.cfg_scale * (noise_pred_text - noise_pred_uncond)
|
||||
|
||||
#del noise_pred_text
|
||||
#del noise_pred_uncond
|
||||
#import gc
|
||||
#gc.collect()
|
||||
#torch.cuda.empty_cache()
|
||||
# del noise_pred_text
|
||||
# del noise_pred_uncond
|
||||
# import gc
|
||||
# gc.collect()
|
||||
# torch.cuda.empty_cache()
|
||||
|
||||
#if do_classifier_free_guidance and guidance_rescale > 0.0:
|
||||
# if do_classifier_free_guidance and guidance_rescale > 0.0:
|
||||
# # Based on 3.4. in https://arxiv.org/pdf/2305.08891.pdf
|
||||
# noise_pred = rescale_noise_cfg(noise_pred, noise_pred_text, guidance_rescale=guidance_rescale)
|
||||
|
||||
# compute the previous noisy sample x_t -> x_t-1
|
||||
latents = scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0]
|
||||
|
||||
#del noise_pred
|
||||
#import gc
|
||||
#gc.collect()
|
||||
#torch.cuda.empty_cache()
|
||||
# del noise_pred
|
||||
# import gc
|
||||
# gc.collect()
|
||||
# torch.cuda.empty_cache()
|
||||
|
||||
# call the callback, if provided
|
||||
if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % scheduler.order == 0):
|
||||
progress_bar.update()
|
||||
self.dispatch_progress(context, source_node_id, latents, i, num_inference_steps)
|
||||
#if callback is not None and i % callback_steps == 0:
|
||||
# if callback is not None and i % callback_steps == 0:
|
||||
# callback(i, t, latents)
|
||||
|
||||
|
||||
|
||||
#################
|
||||
|
||||
latents = latents.to("cpu")
|
||||
torch.cuda.empty_cache()
|
||||
|
||||
name = f'{context.graph_execution_state_id}__{self.id}'
|
||||
name = f"{context.graph_execution_state_id}__{self.id}"
|
||||
context.services.latents.save(name, latents)
|
||||
return build_latents_output(latents_name=name, latents=latents)
|
||||
|
@ -29,16 +29,11 @@ class ESRGANInvocation(BaseInvocation):
|
||||
|
||||
type: Literal["esrgan"] = "esrgan"
|
||||
image: Union[ImageField, None] = Field(default=None, description="The input image")
|
||||
model_name: ESRGAN_MODELS = Field(
|
||||
default="RealESRGAN_x4plus.pth", description="The Real-ESRGAN model to use"
|
||||
)
|
||||
model_name: ESRGAN_MODELS = Field(default="RealESRGAN_x4plus.pth", description="The Real-ESRGAN model to use")
|
||||
|
||||
class Config(InvocationConfig):
|
||||
schema_extra = {
|
||||
"ui": {
|
||||
"title": "Upscale (RealESRGAN)",
|
||||
"tags": ["image", "upscale", "realesrgan"]
|
||||
},
|
||||
"ui": {"title": "Upscale (RealESRGAN)", "tags": ["image", "upscale", "realesrgan"]},
|
||||
}
|
||||
|
||||
def invoke(self, context: InvocationContext) -> ImageOutput:
|
||||
@ -108,9 +103,7 @@ class ESRGANInvocation(BaseInvocation):
|
||||
upscaled_image, img_mode = upsampler.enhance(cv_image)
|
||||
|
||||
# back to PIL
|
||||
pil_image = Image.fromarray(
|
||||
cv.cvtColor(upscaled_image, cv.COLOR_BGR2RGB)
|
||||
).convert("RGBA")
|
||||
pil_image = Image.fromarray(cv.cvtColor(upscaled_image, cv.COLOR_BGR2RGB)).convert("RGBA")
|
||||
|
||||
image_dto = context.services.images.create(
|
||||
image=pil_image,
|
||||
|
@ -1,3 +1,4 @@
|
||||
class CanceledException(Exception):
|
||||
"""Execution canceled by user."""
|
||||
|
||||
pass
|
||||
|
@ -1,8 +1,83 @@
|
||||
from enum import Enum
|
||||
from typing import Optional, Tuple
|
||||
from typing import Optional, Tuple, Literal
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from invokeai.app.util.metaenum import MetaEnum
|
||||
from ..invocations.baseinvocation import (
|
||||
BaseInvocationOutput,
|
||||
InvocationConfig,
|
||||
)
|
||||
|
||||
|
||||
class ImageField(BaseModel):
|
||||
"""An image field used for passing image objects between invocations"""
|
||||
|
||||
image_name: Optional[str] = Field(default=None, description="The name of the image")
|
||||
|
||||
class Config:
|
||||
schema_extra = {"required": ["image_name"]}
|
||||
|
||||
|
||||
class ColorField(BaseModel):
|
||||
r: int = Field(ge=0, le=255, description="The red component")
|
||||
g: int = Field(ge=0, le=255, description="The green component")
|
||||
b: int = Field(ge=0, le=255, description="The blue component")
|
||||
a: int = Field(ge=0, le=255, description="The alpha component")
|
||||
|
||||
def tuple(self) -> Tuple[int, int, int, int]:
|
||||
return (self.r, self.g, self.b, self.a)
|
||||
|
||||
|
||||
class ProgressImage(BaseModel):
|
||||
"""The progress image sent intermittently during processing"""
|
||||
|
||||
width: int = Field(description="The effective width of the image in pixels")
|
||||
height: int = Field(description="The effective height of the image in pixels")
|
||||
dataURL: str = Field(description="The image data as a b64 data URL")
|
||||
|
||||
|
||||
class PILInvocationConfig(BaseModel):
|
||||
"""Helper class to provide all PIL invocations with additional config"""
|
||||
|
||||
class Config(InvocationConfig):
|
||||
schema_extra = {
|
||||
"ui": {
|
||||
"tags": ["PIL", "image"],
|
||||
},
|
||||
}
|
||||
|
||||
|
||||
class ImageOutput(BaseInvocationOutput):
|
||||
"""Base class for invocations that output an image"""
|
||||
|
||||
# fmt: off
|
||||
type: Literal["image_output"] = "image_output"
|
||||
image: ImageField = Field(default=None, description="The output image")
|
||||
width: int = Field(description="The width of the image in pixels")
|
||||
height: int = Field(description="The height of the image in pixels")
|
||||
# fmt: on
|
||||
|
||||
class Config:
|
||||
schema_extra = {"required": ["type", "image", "width", "height"]}
|
||||
|
||||
|
||||
class MaskOutput(BaseInvocationOutput):
|
||||
"""Base class for invocations that output a mask"""
|
||||
|
||||
# fmt: off
|
||||
type: Literal["mask"] = "mask"
|
||||
mask: ImageField = Field(default=None, description="The output mask")
|
||||
width: int = Field(description="The width of the mask in pixels")
|
||||
height: int = Field(description="The height of the mask in pixels")
|
||||
# fmt: on
|
||||
|
||||
class Config:
|
||||
schema_extra = {
|
||||
"required": [
|
||||
"type",
|
||||
"mask",
|
||||
]
|
||||
}
|
||||
|
||||
|
||||
class ResourceOrigin(str, Enum, metaclass=MetaEnum):
|
||||
@ -61,30 +136,3 @@ class InvalidImageCategoryException(ValueError):
|
||||
|
||||
def __init__(self, message="Invalid image category."):
|
||||
super().__init__(message)
|
||||
|
||||
|
||||
class ImageField(BaseModel):
|
||||
"""An image field used for passing image objects between invocations"""
|
||||
|
||||
image_name: Optional[str] = Field(default=None, description="The name of the image")
|
||||
|
||||
class Config:
|
||||
schema_extra = {"required": ["image_name"]}
|
||||
|
||||
|
||||
class ColorField(BaseModel):
|
||||
r: int = Field(ge=0, le=255, description="The red component")
|
||||
g: int = Field(ge=0, le=255, description="The green component")
|
||||
b: int = Field(ge=0, le=255, description="The blue component")
|
||||
a: int = Field(ge=0, le=255, description="The alpha component")
|
||||
|
||||
def tuple(self) -> Tuple[int, int, int, int]:
|
||||
return (self.r, self.g, self.b, self.a)
|
||||
|
||||
|
||||
class ProgressImage(BaseModel):
|
||||
"""The progress image sent intermittently during processing"""
|
||||
|
||||
width: int = Field(description="The effective width of the image in pixels")
|
||||
height: int = Field(description="The effective height of the image in pixels")
|
||||
dataURL: str = Field(description="The image data as a b64 data URL")
|
||||
|
@ -207,9 +207,7 @@ class SqliteBoardImageRecordStorage(BoardImageRecordStorageBase):
|
||||
raise e
|
||||
finally:
|
||||
self._lock.release()
|
||||
return OffsetPaginatedResults(
|
||||
items=images, offset=offset, limit=limit, total=count
|
||||
)
|
||||
return OffsetPaginatedResults(items=images, offset=offset, limit=limit, total=count)
|
||||
|
||||
def get_all_board_image_names_for_board(self, board_id: str) -> list[str]:
|
||||
try:
|
||||
|
@ -102,9 +102,7 @@ class BoardImagesService(BoardImagesServiceABC):
|
||||
self,
|
||||
board_id: str,
|
||||
) -> list[str]:
|
||||
return self._services.board_image_records.get_all_board_image_names_for_board(
|
||||
board_id
|
||||
)
|
||||
return self._services.board_image_records.get_all_board_image_names_for_board(board_id)
|
||||
|
||||
def get_board_for_image(
|
||||
self,
|
||||
@ -114,9 +112,7 @@ class BoardImagesService(BoardImagesServiceABC):
|
||||
return board_id
|
||||
|
||||
|
||||
def board_record_to_dto(
|
||||
board_record: BoardRecord, cover_image_name: Optional[str], image_count: int
|
||||
) -> BoardDTO:
|
||||
def board_record_to_dto(board_record: BoardRecord, cover_image_name: Optional[str], image_count: int) -> BoardDTO:
|
||||
"""Converts a board record to a board DTO."""
|
||||
return BoardDTO(
|
||||
**board_record.dict(exclude={"cover_image_name"}),
|
||||
|
@ -15,9 +15,7 @@ from pydantic import BaseModel, Field, Extra
|
||||
|
||||
class BoardChanges(BaseModel, extra=Extra.forbid):
|
||||
board_name: Optional[str] = Field(description="The board's new name.")
|
||||
cover_image_name: Optional[str] = Field(
|
||||
description="The name of the board's new cover image."
|
||||
)
|
||||
cover_image_name: Optional[str] = Field(description="The name of the board's new cover image.")
|
||||
|
||||
|
||||
class BoardRecordNotFoundException(Exception):
|
||||
@ -292,9 +290,7 @@ class SqliteBoardRecordStorage(BoardRecordStorageBase):
|
||||
|
||||
count = cast(int, self._cursor.fetchone()[0])
|
||||
|
||||
return OffsetPaginatedResults[BoardRecord](
|
||||
items=boards, offset=offset, limit=limit, total=count
|
||||
)
|
||||
return OffsetPaginatedResults[BoardRecord](items=boards, offset=offset, limit=limit, total=count)
|
||||
|
||||
except sqlite3.Error as e:
|
||||
self._conn.rollback()
|
||||
|
@ -108,16 +108,12 @@ class BoardService(BoardServiceABC):
|
||||
|
||||
def get_dto(self, board_id: str) -> BoardDTO:
|
||||
board_record = self._services.board_records.get(board_id)
|
||||
cover_image = self._services.image_records.get_most_recent_image_for_board(
|
||||
board_record.board_id
|
||||
)
|
||||
cover_image = self._services.image_records.get_most_recent_image_for_board(board_record.board_id)
|
||||
if cover_image:
|
||||
cover_image_name = cover_image.image_name
|
||||
else:
|
||||
cover_image_name = None
|
||||
image_count = self._services.board_image_records.get_image_count_for_board(
|
||||
board_id
|
||||
)
|
||||
image_count = self._services.board_image_records.get_image_count_for_board(board_id)
|
||||
return board_record_to_dto(board_record, cover_image_name, image_count)
|
||||
|
||||
def update(
|
||||
@ -126,60 +122,44 @@ class BoardService(BoardServiceABC):
|
||||
changes: BoardChanges,
|
||||
) -> BoardDTO:
|
||||
board_record = self._services.board_records.update(board_id, changes)
|
||||
cover_image = self._services.image_records.get_most_recent_image_for_board(
|
||||
board_record.board_id
|
||||
)
|
||||
cover_image = self._services.image_records.get_most_recent_image_for_board(board_record.board_id)
|
||||
if cover_image:
|
||||
cover_image_name = cover_image.image_name
|
||||
else:
|
||||
cover_image_name = None
|
||||
|
||||
image_count = self._services.board_image_records.get_image_count_for_board(
|
||||
board_id
|
||||
)
|
||||
image_count = self._services.board_image_records.get_image_count_for_board(board_id)
|
||||
return board_record_to_dto(board_record, cover_image_name, image_count)
|
||||
|
||||
def delete(self, board_id: str) -> None:
|
||||
self._services.board_records.delete(board_id)
|
||||
|
||||
def get_many(
|
||||
self, offset: int = 0, limit: int = 10
|
||||
) -> OffsetPaginatedResults[BoardDTO]:
|
||||
def get_many(self, offset: int = 0, limit: int = 10) -> OffsetPaginatedResults[BoardDTO]:
|
||||
board_records = self._services.board_records.get_many(offset, limit)
|
||||
board_dtos = []
|
||||
for r in board_records.items:
|
||||
cover_image = self._services.image_records.get_most_recent_image_for_board(
|
||||
r.board_id
|
||||
)
|
||||
cover_image = self._services.image_records.get_most_recent_image_for_board(r.board_id)
|
||||
if cover_image:
|
||||
cover_image_name = cover_image.image_name
|
||||
else:
|
||||
cover_image_name = None
|
||||
|
||||
image_count = self._services.board_image_records.get_image_count_for_board(
|
||||
r.board_id
|
||||
)
|
||||
image_count = self._services.board_image_records.get_image_count_for_board(r.board_id)
|
||||
board_dtos.append(board_record_to_dto(r, cover_image_name, image_count))
|
||||
|
||||
return OffsetPaginatedResults[BoardDTO](
|
||||
items=board_dtos, offset=offset, limit=limit, total=len(board_dtos)
|
||||
)
|
||||
return OffsetPaginatedResults[BoardDTO](items=board_dtos, offset=offset, limit=limit, total=len(board_dtos))
|
||||
|
||||
def get_all(self) -> list[BoardDTO]:
|
||||
board_records = self._services.board_records.get_all()
|
||||
board_dtos = []
|
||||
for r in board_records:
|
||||
cover_image = self._services.image_records.get_most_recent_image_for_board(
|
||||
r.board_id
|
||||
)
|
||||
cover_image = self._services.image_records.get_most_recent_image_for_board(r.board_id)
|
||||
if cover_image:
|
||||
cover_image_name = cover_image.image_name
|
||||
else:
|
||||
cover_image_name = None
|
||||
|
||||
image_count = self._services.board_image_records.get_image_count_for_board(
|
||||
r.board_id
|
||||
)
|
||||
image_count = self._services.board_image_records.get_image_count_for_board(r.board_id)
|
||||
board_dtos.append(board_record_to_dto(r, cover_image_name, image_count))
|
||||
|
||||
return board_dtos
|
||||
return board_dtos
|
||||
|
@ -1,6 +1,6 @@
|
||||
# Copyright (c) 2023 Lincoln Stein (https://github.com/lstein) and the InvokeAI Development Team
|
||||
|
||||
'''Invokeai configuration system.
|
||||
"""Invokeai configuration system.
|
||||
|
||||
Arguments and fields are taken from the pydantic definition of the
|
||||
model. Defaults can be set by creating a yaml configuration file that
|
||||
@ -28,7 +28,6 @@ InvokeAI:
|
||||
always_use_cpu: false
|
||||
free_gpu_mem: false
|
||||
Features:
|
||||
nsfw_checker: true
|
||||
restore: true
|
||||
esrgan: true
|
||||
patchmatch: true
|
||||
@ -92,18 +91,18 @@ Typical usage at the top level file:
|
||||
|
||||
from invokeai.app.services.config import InvokeAIAppConfig
|
||||
|
||||
# get global configuration and print its nsfw_checker value
|
||||
# get global configuration and print its cache size
|
||||
conf = InvokeAIAppConfig.get_config()
|
||||
conf.parse_args()
|
||||
print(conf.nsfw_checker)
|
||||
print(conf.max_cache_size)
|
||||
|
||||
Typical usage in a backend module:
|
||||
|
||||
from invokeai.app.services.config import InvokeAIAppConfig
|
||||
|
||||
# get global configuration and print its nsfw_checker value
|
||||
# get global configuration and print its cache size value
|
||||
conf = InvokeAIAppConfig.get_config()
|
||||
print(conf.nsfw_checker)
|
||||
print(conf.max_cache_size)
|
||||
|
||||
|
||||
Computed properties:
|
||||
@ -159,7 +158,7 @@ two configs are kept in separate sections of the config file:
|
||||
outdir: outputs
|
||||
...
|
||||
|
||||
'''
|
||||
"""
|
||||
from __future__ import annotations
|
||||
import argparse
|
||||
import pydoc
|
||||
@ -171,64 +170,68 @@ from pathlib import Path
|
||||
from pydantic import BaseSettings, Field, parse_obj_as
|
||||
from typing import ClassVar, Dict, List, Set, Literal, Union, get_origin, get_type_hints, get_args
|
||||
|
||||
INIT_FILE = Path('invokeai.yaml')
|
||||
MODEL_CORE = Path('models/core')
|
||||
DB_FILE = Path('invokeai.db')
|
||||
LEGACY_INIT_FILE = Path('invokeai.init')
|
||||
INIT_FILE = Path("invokeai.yaml")
|
||||
MODEL_CORE = Path("models/core")
|
||||
DB_FILE = Path("invokeai.db")
|
||||
LEGACY_INIT_FILE = Path("invokeai.init")
|
||||
|
||||
|
||||
class InvokeAISettings(BaseSettings):
|
||||
'''
|
||||
"""
|
||||
Runtime configuration settings in which default values are
|
||||
read from an omegaconf .yaml file.
|
||||
'''
|
||||
initconf : ClassVar[DictConfig] = None
|
||||
argparse_groups : ClassVar[Dict] = {}
|
||||
"""
|
||||
|
||||
def parse_args(self, argv: list=sys.argv[1:]):
|
||||
initconf: ClassVar[DictConfig] = None
|
||||
argparse_groups: ClassVar[Dict] = {}
|
||||
|
||||
def parse_args(self, argv: list = sys.argv[1:]):
|
||||
parser = self.get_parser()
|
||||
opt = parser.parse_args(argv)
|
||||
for name in self.__fields__:
|
||||
if name not in self._excluded():
|
||||
setattr(self, name, getattr(opt,name))
|
||||
setattr(self, name, getattr(opt, name))
|
||||
|
||||
def to_yaml(self)->str:
|
||||
def to_yaml(self) -> str:
|
||||
"""
|
||||
Return a YAML string representing our settings. This can be used
|
||||
as the contents of `invokeai.yaml` to restore settings later.
|
||||
"""
|
||||
cls = self.__class__
|
||||
type = get_args(get_type_hints(cls)['type'])[0]
|
||||
field_dict = dict({type:dict()})
|
||||
for name,field in self.__fields__.items():
|
||||
type = get_args(get_type_hints(cls)["type"])[0]
|
||||
field_dict = dict({type: dict()})
|
||||
for name, field in self.__fields__.items():
|
||||
if name in cls._excluded_from_yaml():
|
||||
continue
|
||||
category = field.field_info.extra.get("category") or "Uncategorized"
|
||||
value = getattr(self,name)
|
||||
value = getattr(self, name)
|
||||
if category not in field_dict[type]:
|
||||
field_dict[type][category] = dict()
|
||||
# keep paths as strings to make it easier to read
|
||||
field_dict[type][category][name] = str(value) if isinstance(value,Path) else value
|
||||
field_dict[type][category][name] = str(value) if isinstance(value, Path) else value
|
||||
conf = OmegaConf.create(field_dict)
|
||||
return OmegaConf.to_yaml(conf)
|
||||
|
||||
@classmethod
|
||||
def add_parser_arguments(cls, parser):
|
||||
if 'type' in get_type_hints(cls):
|
||||
settings_stanza = get_args(get_type_hints(cls)['type'])[0]
|
||||
if "type" in get_type_hints(cls):
|
||||
settings_stanza = get_args(get_type_hints(cls)["type"])[0]
|
||||
else:
|
||||
settings_stanza = "Uncategorized"
|
||||
|
||||
env_prefix = cls.Config.env_prefix if hasattr(cls.Config,'env_prefix') else settings_stanza.upper()
|
||||
env_prefix = cls.Config.env_prefix if hasattr(cls.Config, "env_prefix") else settings_stanza.upper()
|
||||
|
||||
initconf = cls.initconf.get(settings_stanza) \
|
||||
if cls.initconf and settings_stanza in cls.initconf \
|
||||
else OmegaConf.create()
|
||||
initconf = (
|
||||
cls.initconf.get(settings_stanza)
|
||||
if cls.initconf and settings_stanza in cls.initconf
|
||||
else OmegaConf.create()
|
||||
)
|
||||
|
||||
# create an upcase version of the environment in
|
||||
# order to achieve case-insensitive environment
|
||||
# variables (the way Windows does)
|
||||
upcase_environ = dict()
|
||||
for key,value in os.environ.items():
|
||||
for key, value in os.environ.items():
|
||||
upcase_environ[key.upper()] = value
|
||||
|
||||
fields = cls.__fields__
|
||||
@ -238,8 +241,8 @@ class InvokeAISettings(BaseSettings):
|
||||
if name not in cls._excluded():
|
||||
current_default = field.default
|
||||
|
||||
category = field.field_info.extra.get("category","Uncategorized")
|
||||
env_name = env_prefix + '_' + name
|
||||
category = field.field_info.extra.get("category", "Uncategorized")
|
||||
env_name = env_prefix + "_" + name
|
||||
if category in initconf and name in initconf.get(category):
|
||||
field.default = initconf.get(category).get(name)
|
||||
if env_name.upper() in upcase_environ:
|
||||
@ -249,15 +252,15 @@ class InvokeAISettings(BaseSettings):
|
||||
field.default = current_default
|
||||
|
||||
@classmethod
|
||||
def cmd_name(self, command_field: str='type')->str:
|
||||
def cmd_name(self, command_field: str = "type") -> str:
|
||||
hints = get_type_hints(self)
|
||||
if command_field in hints:
|
||||
return get_args(hints[command_field])[0]
|
||||
else:
|
||||
return 'Uncategorized'
|
||||
return "Uncategorized"
|
||||
|
||||
@classmethod
|
||||
def get_parser(cls)->ArgumentParser:
|
||||
def get_parser(cls) -> ArgumentParser:
|
||||
parser = PagingArgumentParser(
|
||||
prog=cls.cmd_name(),
|
||||
description=cls.__doc__,
|
||||
@ -270,24 +273,41 @@ class InvokeAISettings(BaseSettings):
|
||||
parser.add_parser(cls.cmd_name(), help=cls.__doc__)
|
||||
|
||||
@classmethod
|
||||
def _excluded(self)->List[str]:
|
||||
def _excluded(self) -> List[str]:
|
||||
# internal fields that shouldn't be exposed as command line options
|
||||
return ['type','initconf']
|
||||
|
||||
return ["type", "initconf"]
|
||||
|
||||
@classmethod
|
||||
def _excluded_from_yaml(self)->List[str]:
|
||||
def _excluded_from_yaml(self) -> List[str]:
|
||||
# combination of deprecated parameters and internal ones that shouldn't be exposed as invokeai.yaml options
|
||||
return ['type','initconf', 'gpu_mem_reserved', 'max_loaded_models', 'version', 'from_file', 'model', 'restore', 'root']
|
||||
return [
|
||||
"type",
|
||||
"initconf",
|
||||
"gpu_mem_reserved",
|
||||
"max_loaded_models",
|
||||
"version",
|
||||
"from_file",
|
||||
"model",
|
||||
"restore",
|
||||
"root",
|
||||
"nsfw_checker",
|
||||
]
|
||||
|
||||
class Config:
|
||||
env_file_encoding = 'utf-8'
|
||||
env_file_encoding = "utf-8"
|
||||
arbitrary_types_allowed = True
|
||||
case_sensitive = True
|
||||
|
||||
@classmethod
|
||||
def add_field_argument(cls, command_parser, name: str, field, default_override = None):
|
||||
def add_field_argument(cls, command_parser, name: str, field, default_override=None):
|
||||
field_type = get_type_hints(cls).get(name)
|
||||
default = default_override if default_override is not None else field.default if field.default_factory is None else field.default_factory()
|
||||
default = (
|
||||
default_override
|
||||
if default_override is not None
|
||||
else field.default
|
||||
if field.default_factory is None
|
||||
else field.default_factory()
|
||||
)
|
||||
if category := field.field_info.extra.get("category"):
|
||||
if category not in cls.argparse_groups:
|
||||
cls.argparse_groups[category] = command_parser.add_argument_group(category)
|
||||
@ -316,10 +336,10 @@ class InvokeAISettings(BaseSettings):
|
||||
argparse_group.add_argument(
|
||||
f"--{name}",
|
||||
dest=name,
|
||||
nargs='*',
|
||||
nargs="*",
|
||||
type=field.type_,
|
||||
default=default,
|
||||
action=argparse.BooleanOptionalAction if field.type_==bool else 'store',
|
||||
action=argparse.BooleanOptionalAction if field.type_ == bool else "store",
|
||||
help=field.field_info.description,
|
||||
)
|
||||
else:
|
||||
@ -328,31 +348,35 @@ class InvokeAISettings(BaseSettings):
|
||||
dest=name,
|
||||
type=field.type_,
|
||||
default=default,
|
||||
action=argparse.BooleanOptionalAction if field.type_==bool else 'store',
|
||||
action=argparse.BooleanOptionalAction if field.type_ == bool else "store",
|
||||
help=field.field_info.description,
|
||||
)
|
||||
def _find_root()->Path:
|
||||
|
||||
|
||||
def _find_root() -> Path:
|
||||
venv = Path(os.environ.get("VIRTUAL_ENV") or ".")
|
||||
if os.environ.get("INVOKEAI_ROOT"):
|
||||
root = Path(os.environ.get("INVOKEAI_ROOT")).resolve()
|
||||
elif any([(venv.parent/x).exists() for x in [INIT_FILE, LEGACY_INIT_FILE, MODEL_CORE]]):
|
||||
elif any([(venv.parent / x).exists() for x in [INIT_FILE, LEGACY_INIT_FILE, MODEL_CORE]]):
|
||||
root = (venv.parent).resolve()
|
||||
else:
|
||||
root = Path("~/invokeai").expanduser().resolve()
|
||||
return root
|
||||
|
||||
|
||||
class InvokeAIAppConfig(InvokeAISettings):
|
||||
'''
|
||||
Generate images using Stable Diffusion. Use "invokeai" to launch
|
||||
the command-line client (recommended for experts only), or
|
||||
"invokeai-web" to launch the web server. Global options
|
||||
can be changed by editing the file "INVOKEAI_ROOT/invokeai.yaml" or by
|
||||
setting environment variables INVOKEAI_<setting>.
|
||||
'''
|
||||
"""
|
||||
Generate images using Stable Diffusion. Use "invokeai" to launch
|
||||
the command-line client (recommended for experts only), or
|
||||
"invokeai-web" to launch the web server. Global options
|
||||
can be changed by editing the file "INVOKEAI_ROOT/invokeai.yaml" or by
|
||||
setting environment variables INVOKEAI_<setting>.
|
||||
"""
|
||||
|
||||
singleton_config: ClassVar[InvokeAIAppConfig] = None
|
||||
singleton_init: ClassVar[Dict] = None
|
||||
|
||||
#fmt: off
|
||||
# fmt: off
|
||||
type: Literal["InvokeAI"] = "InvokeAI"
|
||||
host : str = Field(default="127.0.0.1", description="IP address to bind to", category='Web Server')
|
||||
port : int = Field(default=9090, description="Port to bind to", category='Web Server')
|
||||
@ -364,7 +388,6 @@ setting environment variables INVOKEAI_<setting>.
|
||||
esrgan : bool = Field(default=True, description="Enable/disable upscaling code", category='Features')
|
||||
internet_available : bool = Field(default=True, description="If true, attempt to download models on the fly; otherwise only use local models", category='Features')
|
||||
log_tokenization : bool = Field(default=False, description="Enable logging of parsed prompt tokens.", category='Features')
|
||||
nsfw_checker : bool = Field(default=True, description="Enable/disable the NSFW checker", category='Features')
|
||||
patchmatch : bool = Field(default=True, description="Enable/disable patchmatch inpaint code", category='Features')
|
||||
restore : bool = Field(default=True, description="Enable/disable face restoration code (DEPRECATED)", category='DEPRECATED')
|
||||
|
||||
@ -374,6 +397,7 @@ setting environment variables INVOKEAI_<setting>.
|
||||
max_cache_size : float = Field(default=6.0, gt=0, description="Maximum memory amount used by model cache for rapid switching", category='Memory/Performance')
|
||||
max_vram_cache_size : float = Field(default=2.75, ge=0, description="Amount of VRAM reserved for model storage", category='Memory/Performance')
|
||||
gpu_mem_reserved : float = Field(default=2.75, ge=0, description="DEPRECATED: use max_vram_cache_size. Amount of VRAM reserved for model storage", category='DEPRECATED')
|
||||
nsfw_checker : bool = Field(default=True, description="DEPRECATED: use Web settings to enable/disable", category='DEPRECATED')
|
||||
precision : Literal[tuple(['auto','float16','float32','autocast'])] = Field(default='auto',description='Floating point precision', category='Memory/Performance')
|
||||
sequential_guidance : bool = Field(default=False, description="Whether to calculate guidance in serial instead of in parallel, lowering memory requirements", category='Memory/Performance')
|
||||
xformers_enabled : bool = Field(default=True, description="Enable/disable memory-efficient attention", category='Memory/Performance')
|
||||
@ -400,16 +424,16 @@ setting environment variables INVOKEAI_<setting>.
|
||||
log_level : Literal[tuple(["debug","info","warning","error","critical"])] = Field(default="info", description="Emit logging messages at this level or higher", category="Logging")
|
||||
|
||||
version : bool = Field(default=False, description="Show InvokeAI version and exit", category="Other")
|
||||
#fmt: on
|
||||
# fmt: on
|
||||
|
||||
def parse_args(self, argv: List[str]=None, conf: DictConfig = None, clobber=False):
|
||||
'''
|
||||
def parse_args(self, argv: List[str] = None, conf: DictConfig = None, clobber=False):
|
||||
"""
|
||||
Update settings with contents of init file, environment, and
|
||||
command-line settings.
|
||||
:param conf: alternate Omegaconf dictionary object
|
||||
:param argv: aternate sys.argv list
|
||||
:param clobber: ovewrite any initialization parameters passed during initialization
|
||||
'''
|
||||
"""
|
||||
# Set the runtime root directory. We parse command-line switches here
|
||||
# in order to pick up the --root_dir option.
|
||||
super().parse_args(argv)
|
||||
@ -426,125 +450,139 @@ setting environment variables INVOKEAI_<setting>.
|
||||
if self.singleton_init and not clobber:
|
||||
hints = get_type_hints(self.__class__)
|
||||
for k in self.singleton_init:
|
||||
setattr(self,k,parse_obj_as(hints[k],self.singleton_init[k]))
|
||||
setattr(self, k, parse_obj_as(hints[k], self.singleton_init[k]))
|
||||
|
||||
@classmethod
|
||||
def get_config(cls,**kwargs)->InvokeAIAppConfig:
|
||||
'''
|
||||
def get_config(cls, **kwargs) -> InvokeAIAppConfig:
|
||||
"""
|
||||
This returns a singleton InvokeAIAppConfig configuration object.
|
||||
'''
|
||||
if cls.singleton_config is None \
|
||||
or type(cls.singleton_config)!=cls \
|
||||
or (kwargs and cls.singleton_init != kwargs):
|
||||
"""
|
||||
if (
|
||||
cls.singleton_config is None
|
||||
or type(cls.singleton_config) != cls
|
||||
or (kwargs and cls.singleton_init != kwargs)
|
||||
):
|
||||
cls.singleton_config = cls(**kwargs)
|
||||
cls.singleton_init = kwargs
|
||||
return cls.singleton_config
|
||||
|
||||
@property
|
||||
def root_path(self)->Path:
|
||||
'''
|
||||
def root_path(self) -> Path:
|
||||
"""
|
||||
Path to the runtime root directory
|
||||
'''
|
||||
"""
|
||||
if self.root:
|
||||
return Path(self.root).expanduser().absolute()
|
||||
else:
|
||||
return self.find_root()
|
||||
|
||||
@property
|
||||
def root_dir(self)->Path:
|
||||
'''
|
||||
def root_dir(self) -> Path:
|
||||
"""
|
||||
Alias for above.
|
||||
'''
|
||||
"""
|
||||
return self.root_path
|
||||
|
||||
def _resolve(self,partial_path:Path)->Path:
|
||||
def _resolve(self, partial_path: Path) -> Path:
|
||||
return (self.root_path / partial_path).resolve()
|
||||
|
||||
@property
|
||||
def init_file_path(self)->Path:
|
||||
'''
|
||||
def init_file_path(self) -> Path:
|
||||
"""
|
||||
Path to invokeai.yaml
|
||||
'''
|
||||
"""
|
||||
return self._resolve(INIT_FILE)
|
||||
|
||||
@property
|
||||
def output_path(self)->Path:
|
||||
'''
|
||||
def output_path(self) -> Path:
|
||||
"""
|
||||
Path to defaults outputs directory.
|
||||
'''
|
||||
"""
|
||||
return self._resolve(self.outdir)
|
||||
|
||||
@property
|
||||
def db_path(self)->Path:
|
||||
'''
|
||||
def db_path(self) -> Path:
|
||||
"""
|
||||
Path to the invokeai.db file.
|
||||
'''
|
||||
"""
|
||||
return self._resolve(self.db_dir) / DB_FILE
|
||||
|
||||
@property
|
||||
def model_conf_path(self)->Path:
|
||||
'''
|
||||
def model_conf_path(self) -> Path:
|
||||
"""
|
||||
Path to models configuration file.
|
||||
'''
|
||||
"""
|
||||
return self._resolve(self.conf_path)
|
||||
|
||||
@property
|
||||
def legacy_conf_path(self)->Path:
|
||||
'''
|
||||
def legacy_conf_path(self) -> Path:
|
||||
"""
|
||||
Path to directory of legacy configuration files (e.g. v1-inference.yaml)
|
||||
'''
|
||||
"""
|
||||
return self._resolve(self.legacy_conf_dir)
|
||||
|
||||
@property
|
||||
def models_path(self)->Path:
|
||||
'''
|
||||
def models_path(self) -> Path:
|
||||
"""
|
||||
Path to the models directory
|
||||
'''
|
||||
"""
|
||||
return self._resolve(self.models_dir)
|
||||
|
||||
@property
|
||||
def autoconvert_path(self)->Path:
|
||||
'''
|
||||
def autoconvert_path(self) -> Path:
|
||||
"""
|
||||
Path to the directory containing models to be imported automatically at startup.
|
||||
'''
|
||||
"""
|
||||
return self._resolve(self.autoconvert_dir) if self.autoconvert_dir else None
|
||||
|
||||
# the following methods support legacy calls leftover from the Globals era
|
||||
@property
|
||||
def full_precision(self)->bool:
|
||||
def full_precision(self) -> bool:
|
||||
"""Return true if precision set to float32"""
|
||||
return self.precision=='float32'
|
||||
return self.precision == "float32"
|
||||
|
||||
@property
|
||||
def disable_xformers(self)->bool:
|
||||
def disable_xformers(self) -> bool:
|
||||
"""Return true if xformers_enabled is false"""
|
||||
return not self.xformers_enabled
|
||||
|
||||
@property
|
||||
def try_patchmatch(self)->bool:
|
||||
def try_patchmatch(self) -> bool:
|
||||
"""Return true if patchmatch true"""
|
||||
return self.patchmatch
|
||||
|
||||
@property
|
||||
def nsfw_checker(self) -> bool:
|
||||
"""NSFW node is always active and disabled from Web UIe"""
|
||||
return True
|
||||
|
||||
@property
|
||||
def invisible_watermark(self) -> bool:
|
||||
"""invisible watermark node is always active and disabled from Web UIe"""
|
||||
return True
|
||||
|
||||
@staticmethod
|
||||
def find_root()->Path:
|
||||
'''
|
||||
def find_root() -> Path:
|
||||
"""
|
||||
Choose the runtime root directory when not specified on command line or
|
||||
init file.
|
||||
'''
|
||||
"""
|
||||
return _find_root()
|
||||
|
||||
|
||||
class PagingArgumentParser(argparse.ArgumentParser):
|
||||
'''
|
||||
"""
|
||||
A custom ArgumentParser that uses pydoc to page its output.
|
||||
It also supports reading defaults from an init file.
|
||||
'''
|
||||
"""
|
||||
|
||||
def print_help(self, file=None):
|
||||
text = self.format_help()
|
||||
pydoc.pager(text)
|
||||
|
||||
def get_invokeai_config(**kwargs)->InvokeAIAppConfig:
|
||||
'''
|
||||
|
||||
def get_invokeai_config(**kwargs) -> InvokeAIAppConfig:
|
||||
"""
|
||||
Legacy function which returns InvokeAIAppConfig.get_config()
|
||||
'''
|
||||
"""
|
||||
return InvokeAIAppConfig.get_config(**kwargs)
|
||||
|
@ -1,4 +1,5 @@
|
||||
from ..invocations.latent import LatentsToImageInvocation, TextToLatentsInvocation
|
||||
from ..invocations.image import ImageNSFWBlurInvocation
|
||||
from ..invocations.noise import NoiseInvocation
|
||||
from ..invocations.compel import CompelInvocation
|
||||
from ..invocations.params import ParamIntInvocation
|
||||
@ -6,55 +7,80 @@ from .graph import Edge, EdgeConnection, ExposedNodeInput, ExposedNodeOutput, Gr
|
||||
from .item_storage import ItemStorageABC
|
||||
|
||||
|
||||
default_text_to_image_graph_id = '539b2af5-2b4d-4d8c-8071-e54a3255fc74'
|
||||
default_text_to_image_graph_id = "539b2af5-2b4d-4d8c-8071-e54a3255fc74"
|
||||
|
||||
|
||||
def create_text_to_image() -> LibraryGraph:
|
||||
return LibraryGraph(
|
||||
id=default_text_to_image_graph_id,
|
||||
name='t2i',
|
||||
description='Converts text to an image',
|
||||
name="t2i",
|
||||
description="Converts text to an image",
|
||||
graph=Graph(
|
||||
nodes={
|
||||
'width': ParamIntInvocation(id='width', a=512),
|
||||
'height': ParamIntInvocation(id='height', a=512),
|
||||
'seed': ParamIntInvocation(id='seed', a=-1),
|
||||
'3': NoiseInvocation(id='3'),
|
||||
'4': CompelInvocation(id='4'),
|
||||
'5': CompelInvocation(id='5'),
|
||||
'6': TextToLatentsInvocation(id='6'),
|
||||
'7': LatentsToImageInvocation(id='7'),
|
||||
"width": ParamIntInvocation(id="width", a=512),
|
||||
"height": ParamIntInvocation(id="height", a=512),
|
||||
"seed": ParamIntInvocation(id="seed", a=-1),
|
||||
"3": NoiseInvocation(id="3"),
|
||||
"4": CompelInvocation(id="4"),
|
||||
"5": CompelInvocation(id="5"),
|
||||
"6": TextToLatentsInvocation(id="6"),
|
||||
"7": LatentsToImageInvocation(id="7"),
|
||||
"8": ImageNSFWBlurInvocation(id="8"),
|
||||
},
|
||||
edges=[
|
||||
Edge(source=EdgeConnection(node_id='width', field='a'), destination=EdgeConnection(node_id='3', field='width')),
|
||||
Edge(source=EdgeConnection(node_id='height', field='a'), destination=EdgeConnection(node_id='3', field='height')),
|
||||
Edge(source=EdgeConnection(node_id='seed', field='a'), destination=EdgeConnection(node_id='3', field='seed')),
|
||||
Edge(source=EdgeConnection(node_id='3', field='noise'), destination=EdgeConnection(node_id='6', field='noise')),
|
||||
Edge(source=EdgeConnection(node_id='6', field='latents'), destination=EdgeConnection(node_id='7', field='latents')),
|
||||
Edge(source=EdgeConnection(node_id='4', field='conditioning'), destination=EdgeConnection(node_id='6', field='positive_conditioning')),
|
||||
Edge(source=EdgeConnection(node_id='5', field='conditioning'), destination=EdgeConnection(node_id='6', field='negative_conditioning')),
|
||||
]
|
||||
Edge(
|
||||
source=EdgeConnection(node_id="width", field="a"),
|
||||
destination=EdgeConnection(node_id="3", field="width"),
|
||||
),
|
||||
Edge(
|
||||
source=EdgeConnection(node_id="height", field="a"),
|
||||
destination=EdgeConnection(node_id="3", field="height"),
|
||||
),
|
||||
Edge(
|
||||
source=EdgeConnection(node_id="seed", field="a"),
|
||||
destination=EdgeConnection(node_id="3", field="seed"),
|
||||
),
|
||||
Edge(
|
||||
source=EdgeConnection(node_id="3", field="noise"),
|
||||
destination=EdgeConnection(node_id="6", field="noise"),
|
||||
),
|
||||
Edge(
|
||||
source=EdgeConnection(node_id="6", field="latents"),
|
||||
destination=EdgeConnection(node_id="7", field="latents"),
|
||||
),
|
||||
Edge(
|
||||
source=EdgeConnection(node_id="4", field="conditioning"),
|
||||
destination=EdgeConnection(node_id="6", field="positive_conditioning"),
|
||||
),
|
||||
Edge(
|
||||
source=EdgeConnection(node_id="5", field="conditioning"),
|
||||
destination=EdgeConnection(node_id="6", field="negative_conditioning"),
|
||||
),
|
||||
Edge(
|
||||
source=EdgeConnection(node_id="7", field="image"),
|
||||
destination=EdgeConnection(node_id="8", field="image"),
|
||||
),
|
||||
],
|
||||
),
|
||||
exposed_inputs=[
|
||||
ExposedNodeInput(node_path='4', field='prompt', alias='positive_prompt'),
|
||||
ExposedNodeInput(node_path='5', field='prompt', alias='negative_prompt'),
|
||||
ExposedNodeInput(node_path='width', field='a', alias='width'),
|
||||
ExposedNodeInput(node_path='height', field='a', alias='height'),
|
||||
ExposedNodeInput(node_path='seed', field='a', alias='seed'),
|
||||
ExposedNodeInput(node_path="4", field="prompt", alias="positive_prompt"),
|
||||
ExposedNodeInput(node_path="5", field="prompt", alias="negative_prompt"),
|
||||
ExposedNodeInput(node_path="width", field="a", alias="width"),
|
||||
ExposedNodeInput(node_path="height", field="a", alias="height"),
|
||||
ExposedNodeInput(node_path="seed", field="a", alias="seed"),
|
||||
],
|
||||
exposed_outputs=[
|
||||
ExposedNodeOutput(node_path='7', field='image', alias='image')
|
||||
])
|
||||
exposed_outputs=[ExposedNodeOutput(node_path="8", field="image", alias="image")],
|
||||
)
|
||||
|
||||
|
||||
def create_system_graphs(graph_library: ItemStorageABC[LibraryGraph]) -> list[LibraryGraph]:
|
||||
"""Creates the default system graphs, or adds new versions if the old ones don't match"""
|
||||
|
||||
|
||||
# TODO: Uncomment this when we are ready to fix this up to prevent breaking changes
|
||||
graphs: list[LibraryGraph] = list()
|
||||
|
||||
# text_to_image = graph_library.get(default_text_to_image_graph_id)
|
||||
|
||||
|
||||
# # TODO: Check if the graph is the same as the default one, and if not, update it
|
||||
# #if text_to_image is None:
|
||||
text_to_image = create_text_to_image()
|
||||
|
@ -44,9 +44,7 @@ class EventServiceBase:
|
||||
graph_execution_state_id=graph_execution_state_id,
|
||||
node=node,
|
||||
source_node_id=source_node_id,
|
||||
progress_image=progress_image.dict()
|
||||
if progress_image is not None
|
||||
else None,
|
||||
progress_image=progress_image.dict() if progress_image is not None else None,
|
||||
step=step,
|
||||
total_steps=total_steps,
|
||||
),
|
||||
@ -90,9 +88,7 @@ class EventServiceBase:
|
||||
),
|
||||
)
|
||||
|
||||
def emit_invocation_started(
|
||||
self, graph_execution_state_id: str, node: dict, source_node_id: str
|
||||
) -> None:
|
||||
def emit_invocation_started(self, graph_execution_state_id: str, node: dict, source_node_id: str) -> None:
|
||||
"""Emitted when an invocation has started"""
|
||||
self.__emit_session_event(
|
||||
event_name="invocation_started",
|
||||
|
@ -28,6 +28,7 @@ from ..invocations.baseinvocation import (
|
||||
# in 3.10 this would be "from types import NoneType"
|
||||
NoneType = type(None)
|
||||
|
||||
|
||||
class EdgeConnection(BaseModel):
|
||||
node_id: str = Field(description="The id of the node for this edge connection")
|
||||
field: str = Field(description="The field for this connection")
|
||||
@ -61,6 +62,7 @@ def get_input_field(node: BaseInvocation, field: str) -> Any:
|
||||
node_input_field = node_inputs.get(field) or None
|
||||
return node_input_field
|
||||
|
||||
|
||||
def is_union_subtype(t1, t2):
|
||||
t1_args = get_args(t1)
|
||||
t2_args = get_args(t2)
|
||||
@ -71,6 +73,7 @@ def is_union_subtype(t1, t2):
|
||||
# t1 is a Union, check that all of its types are in t2_args
|
||||
return all(arg in t2_args for arg in t1_args)
|
||||
|
||||
|
||||
def is_list_or_contains_list(t):
|
||||
t_args = get_args(t)
|
||||
|
||||
@ -154,15 +157,17 @@ class GraphInvocationOutput(BaseInvocationOutput):
|
||||
|
||||
class Config:
|
||||
schema_extra = {
|
||||
'required': [
|
||||
'type',
|
||||
'image',
|
||||
"required": [
|
||||
"type",
|
||||
"image",
|
||||
]
|
||||
}
|
||||
|
||||
|
||||
# TODO: Fill this out and move to invocations
|
||||
class GraphInvocation(BaseInvocation):
|
||||
"""Execute a graph"""
|
||||
|
||||
type: Literal["graph"] = "graph"
|
||||
|
||||
# TODO: figure out how to create a default here
|
||||
@ -182,23 +187,21 @@ class IterateInvocationOutput(BaseInvocationOutput):
|
||||
|
||||
class Config:
|
||||
schema_extra = {
|
||||
'required': [
|
||||
'type',
|
||||
'item',
|
||||
"required": [
|
||||
"type",
|
||||
"item",
|
||||
]
|
||||
}
|
||||
|
||||
|
||||
# TODO: Fill this out and move to invocations
|
||||
class IterateInvocation(BaseInvocation):
|
||||
"""Iterates over a list of items"""
|
||||
|
||||
type: Literal["iterate"] = "iterate"
|
||||
|
||||
collection: list[Any] = Field(
|
||||
description="The list of items to iterate over", default_factory=list
|
||||
)
|
||||
index: int = Field(
|
||||
description="The index, will be provided on executed iterators", default=0
|
||||
)
|
||||
collection: list[Any] = Field(description="The list of items to iterate over", default_factory=list)
|
||||
index: int = Field(description="The index, will be provided on executed iterators", default=0)
|
||||
|
||||
def invoke(self, context: InvocationContext) -> IterateInvocationOutput:
|
||||
"""Produces the outputs as values"""
|
||||
@ -212,12 +215,13 @@ class CollectInvocationOutput(BaseInvocationOutput):
|
||||
|
||||
class Config:
|
||||
schema_extra = {
|
||||
'required': [
|
||||
'type',
|
||||
'collection',
|
||||
"required": [
|
||||
"type",
|
||||
"collection",
|
||||
]
|
||||
}
|
||||
|
||||
|
||||
class CollectInvocation(BaseInvocation):
|
||||
"""Collects values into a collection"""
|
||||
|
||||
@ -269,9 +273,7 @@ class Graph(BaseModel):
|
||||
if node_path in self.nodes:
|
||||
return (self, node_path)
|
||||
|
||||
node_id = (
|
||||
node_path if "." not in node_path else node_path[: node_path.index(".")]
|
||||
)
|
||||
node_id = node_path if "." not in node_path else node_path[: node_path.index(".")]
|
||||
if node_id not in self.nodes:
|
||||
raise NodeNotFoundError(f"Node {node_path} not found in graph")
|
||||
|
||||
@ -333,9 +335,7 @@ class Graph(BaseModel):
|
||||
return False
|
||||
|
||||
# Validate all edges reference nodes in the graph
|
||||
node_ids = set(
|
||||
[e.source.node_id for e in self.edges] + [e.destination.node_id for e in self.edges]
|
||||
)
|
||||
node_ids = set([e.source.node_id for e in self.edges] + [e.destination.node_id for e in self.edges])
|
||||
if not all((self.has_node(node_id) for node_id in node_ids)):
|
||||
return False
|
||||
|
||||
@ -361,22 +361,14 @@ class Graph(BaseModel):
|
||||
# Validate all iterators
|
||||
# TODO: may need to validate all iterators in subgraphs so edge connections in parent graphs will be available
|
||||
if not all(
|
||||
(
|
||||
self._is_iterator_connection_valid(n.id)
|
||||
for n in self.nodes.values()
|
||||
if isinstance(n, IterateInvocation)
|
||||
)
|
||||
(self._is_iterator_connection_valid(n.id) for n in self.nodes.values() if isinstance(n, IterateInvocation))
|
||||
):
|
||||
return False
|
||||
|
||||
# Validate all collectors
|
||||
# TODO: may need to validate all collectors in subgraphs so edge connections in parent graphs will be available
|
||||
if not all(
|
||||
(
|
||||
self._is_collector_connection_valid(n.id)
|
||||
for n in self.nodes.values()
|
||||
if isinstance(n, CollectInvocation)
|
||||
)
|
||||
(self._is_collector_connection_valid(n.id) for n in self.nodes.values() if isinstance(n, CollectInvocation))
|
||||
):
|
||||
return False
|
||||
|
||||
@ -395,48 +387,51 @@ class Graph(BaseModel):
|
||||
# Validate that an edge to this node+field doesn't already exist
|
||||
input_edges = self._get_input_edges(edge.destination.node_id, edge.destination.field)
|
||||
if len(input_edges) > 0 and not isinstance(to_node, CollectInvocation):
|
||||
raise InvalidEdgeError(f'Edge to node {edge.destination.node_id} field {edge.destination.field} already exists')
|
||||
raise InvalidEdgeError(
|
||||
f"Edge to node {edge.destination.node_id} field {edge.destination.field} already exists"
|
||||
)
|
||||
|
||||
# Validate that no cycles would be created
|
||||
g = self.nx_graph_flat()
|
||||
g.add_edge(edge.source.node_id, edge.destination.node_id)
|
||||
if not nx.is_directed_acyclic_graph(g):
|
||||
raise InvalidEdgeError(f'Edge creates a cycle in the graph: {edge.source.node_id} -> {edge.destination.node_id}')
|
||||
raise InvalidEdgeError(
|
||||
f"Edge creates a cycle in the graph: {edge.source.node_id} -> {edge.destination.node_id}"
|
||||
)
|
||||
|
||||
# Validate that the field types are compatible
|
||||
if not are_connections_compatible(
|
||||
from_node, edge.source.field, to_node, edge.destination.field
|
||||
):
|
||||
raise InvalidEdgeError(f'Fields are incompatible: cannot connect {edge.source.node_id}.{edge.source.field} to {edge.destination.node_id}.{edge.destination.field}')
|
||||
if not are_connections_compatible(from_node, edge.source.field, to_node, edge.destination.field):
|
||||
raise InvalidEdgeError(
|
||||
f"Fields are incompatible: cannot connect {edge.source.node_id}.{edge.source.field} to {edge.destination.node_id}.{edge.destination.field}"
|
||||
)
|
||||
|
||||
# Validate if iterator output type matches iterator input type (if this edge results in both being set)
|
||||
if isinstance(to_node, IterateInvocation) and edge.destination.field == "collection":
|
||||
if not self._is_iterator_connection_valid(
|
||||
edge.destination.node_id, new_input=edge.source
|
||||
):
|
||||
raise InvalidEdgeError(f'Iterator input type does not match iterator output type: {edge.source.node_id}.{edge.source.field} to {edge.destination.node_id}.{edge.destination.field}')
|
||||
if not self._is_iterator_connection_valid(edge.destination.node_id, new_input=edge.source):
|
||||
raise InvalidEdgeError(
|
||||
f"Iterator input type does not match iterator output type: {edge.source.node_id}.{edge.source.field} to {edge.destination.node_id}.{edge.destination.field}"
|
||||
)
|
||||
|
||||
# Validate if iterator input type matches output type (if this edge results in both being set)
|
||||
if isinstance(from_node, IterateInvocation) and edge.source.field == "item":
|
||||
if not self._is_iterator_connection_valid(
|
||||
edge.source.node_id, new_output=edge.destination
|
||||
):
|
||||
raise InvalidEdgeError(f'Iterator output type does not match iterator input type:, {edge.source.node_id}.{edge.source.field} to {edge.destination.node_id}.{edge.destination.field}')
|
||||
if not self._is_iterator_connection_valid(edge.source.node_id, new_output=edge.destination):
|
||||
raise InvalidEdgeError(
|
||||
f"Iterator output type does not match iterator input type:, {edge.source.node_id}.{edge.source.field} to {edge.destination.node_id}.{edge.destination.field}"
|
||||
)
|
||||
|
||||
# Validate if collector input type matches output type (if this edge results in both being set)
|
||||
if isinstance(to_node, CollectInvocation) and edge.destination.field == "item":
|
||||
if not self._is_collector_connection_valid(
|
||||
edge.destination.node_id, new_input=edge.source
|
||||
):
|
||||
raise InvalidEdgeError(f'Collector output type does not match collector input type: {edge.source.node_id}.{edge.source.field} to {edge.destination.node_id}.{edge.destination.field}')
|
||||
if not self._is_collector_connection_valid(edge.destination.node_id, new_input=edge.source):
|
||||
raise InvalidEdgeError(
|
||||
f"Collector output type does not match collector input type: {edge.source.node_id}.{edge.source.field} to {edge.destination.node_id}.{edge.destination.field}"
|
||||
)
|
||||
|
||||
# Validate if collector output type matches input type (if this edge results in both being set)
|
||||
if isinstance(from_node, CollectInvocation) and edge.source.field == "collection":
|
||||
if not self._is_collector_connection_valid(
|
||||
edge.source.node_id, new_output=edge.destination
|
||||
):
|
||||
raise InvalidEdgeError(f'Collector input type does not match collector output type: {edge.source.node_id}.{edge.source.field} to {edge.destination.node_id}.{edge.destination.field}')
|
||||
|
||||
if not self._is_collector_connection_valid(edge.source.node_id, new_output=edge.destination):
|
||||
raise InvalidEdgeError(
|
||||
f"Collector input type does not match collector output type: {edge.source.node_id}.{edge.source.field} to {edge.destination.node_id}.{edge.destination.field}"
|
||||
)
|
||||
|
||||
def has_node(self, node_path: str) -> bool:
|
||||
"""Determines whether or not a node exists in the graph."""
|
||||
@ -465,17 +460,13 @@ class Graph(BaseModel):
|
||||
|
||||
# Ensure the node type matches the new node
|
||||
if type(node) != type(new_node):
|
||||
raise TypeError(
|
||||
f"Node {node_path} is type {type(node)} but new node is type {type(new_node)}"
|
||||
)
|
||||
raise TypeError(f"Node {node_path} is type {type(node)} but new node is type {type(new_node)}")
|
||||
|
||||
# Ensure the new id is either the same or is not in the graph
|
||||
prefix = None if "." not in node_path else node_path[: node_path.rindex(".")]
|
||||
new_path = self._get_node_path(new_node.id, prefix=prefix)
|
||||
if new_node.id != node.id and self.has_node(new_path):
|
||||
raise NodeAlreadyInGraphError(
|
||||
"Node with id {new_node.id} already exists in graph"
|
||||
)
|
||||
raise NodeAlreadyInGraphError("Node with id {new_node.id} already exists in graph")
|
||||
|
||||
# Set the new node in the graph
|
||||
graph.nodes[new_node.id] = new_node
|
||||
@ -497,9 +488,7 @@ class Graph(BaseModel):
|
||||
graph.add_edge(
|
||||
Edge(
|
||||
source=edge.source,
|
||||
destination=EdgeConnection(
|
||||
node_id=new_graph_node_path, field=edge.destination.field
|
||||
)
|
||||
destination=EdgeConnection(node_id=new_graph_node_path, field=edge.destination.field),
|
||||
)
|
||||
)
|
||||
|
||||
@ -512,16 +501,12 @@ class Graph(BaseModel):
|
||||
)
|
||||
graph.add_edge(
|
||||
Edge(
|
||||
source=EdgeConnection(
|
||||
node_id=new_graph_node_path, field=edge.source.field
|
||||
),
|
||||
destination=edge.destination
|
||||
source=EdgeConnection(node_id=new_graph_node_path, field=edge.source.field),
|
||||
destination=edge.destination,
|
||||
)
|
||||
)
|
||||
|
||||
def _get_input_edges(
|
||||
self, node_path: str, field: Optional[str] = None
|
||||
) -> list[Edge]:
|
||||
def _get_input_edges(self, node_path: str, field: Optional[str] = None) -> list[Edge]:
|
||||
"""Gets all input edges for a node"""
|
||||
edges = self._get_input_edges_and_graphs(node_path)
|
||||
|
||||
@ -538,7 +523,7 @@ class Graph(BaseModel):
|
||||
destination=EdgeConnection(
|
||||
node_id=self._get_node_path(e.destination.node_id, prefix=prefix),
|
||||
field=e.destination.field,
|
||||
)
|
||||
),
|
||||
)
|
||||
for _, prefix, e in filtered_edges
|
||||
]
|
||||
@ -550,32 +535,20 @@ class Graph(BaseModel):
|
||||
edges = list()
|
||||
|
||||
# Return any input edges that appear in this graph
|
||||
edges.extend(
|
||||
[(self, prefix, e) for e in self.edges if e.destination.node_id == node_path]
|
||||
)
|
||||
edges.extend([(self, prefix, e) for e in self.edges if e.destination.node_id == node_path])
|
||||
|
||||
node_id = (
|
||||
node_path if "." not in node_path else node_path[: node_path.index(".")]
|
||||
)
|
||||
node_id = node_path if "." not in node_path else node_path[: node_path.index(".")]
|
||||
node = self.nodes[node_id]
|
||||
|
||||
if isinstance(node, GraphInvocation):
|
||||
graph = node.graph
|
||||
graph_path = (
|
||||
node.id
|
||||
if prefix is None or prefix == ""
|
||||
else self._get_node_path(node.id, prefix=prefix)
|
||||
)
|
||||
graph_edges = graph._get_input_edges_and_graphs(
|
||||
node_path[(len(node_id) + 1) :], prefix=graph_path
|
||||
)
|
||||
graph_path = node.id if prefix is None or prefix == "" else self._get_node_path(node.id, prefix=prefix)
|
||||
graph_edges = graph._get_input_edges_and_graphs(node_path[(len(node_id) + 1) :], prefix=graph_path)
|
||||
edges.extend(graph_edges)
|
||||
|
||||
return edges
|
||||
|
||||
def _get_output_edges(
|
||||
self, node_path: str, field: str
|
||||
) -> list[Edge]:
|
||||
def _get_output_edges(self, node_path: str, field: str) -> list[Edge]:
|
||||
"""Gets all output edges for a node"""
|
||||
edges = self._get_output_edges_and_graphs(node_path)
|
||||
|
||||
@ -592,7 +565,7 @@ class Graph(BaseModel):
|
||||
destination=EdgeConnection(
|
||||
node_id=self._get_node_path(e.destination.node_id, prefix=prefix),
|
||||
field=e.destination.field,
|
||||
)
|
||||
),
|
||||
)
|
||||
for _, prefix, e in filtered_edges
|
||||
]
|
||||
@ -604,25 +577,15 @@ class Graph(BaseModel):
|
||||
edges = list()
|
||||
|
||||
# Return any input edges that appear in this graph
|
||||
edges.extend(
|
||||
[(self, prefix, e) for e in self.edges if e.source.node_id == node_path]
|
||||
)
|
||||
edges.extend([(self, prefix, e) for e in self.edges if e.source.node_id == node_path])
|
||||
|
||||
node_id = (
|
||||
node_path if "." not in node_path else node_path[: node_path.index(".")]
|
||||
)
|
||||
node_id = node_path if "." not in node_path else node_path[: node_path.index(".")]
|
||||
node = self.nodes[node_id]
|
||||
|
||||
if isinstance(node, GraphInvocation):
|
||||
graph = node.graph
|
||||
graph_path = (
|
||||
node.id
|
||||
if prefix is None or prefix == ""
|
||||
else self._get_node_path(node.id, prefix=prefix)
|
||||
)
|
||||
graph_edges = graph._get_output_edges_and_graphs(
|
||||
node_path[(len(node_id) + 1) :], prefix=graph_path
|
||||
)
|
||||
graph_path = node.id if prefix is None or prefix == "" else self._get_node_path(node.id, prefix=prefix)
|
||||
graph_edges = graph._get_output_edges_and_graphs(node_path[(len(node_id) + 1) :], prefix=graph_path)
|
||||
edges.extend(graph_edges)
|
||||
|
||||
return edges
|
||||
@ -646,12 +609,8 @@ class Graph(BaseModel):
|
||||
return False
|
||||
|
||||
# Get input and output fields (the fields linked to the iterator's input/output)
|
||||
input_field = get_output_field(
|
||||
self.get_node(inputs[0].node_id), inputs[0].field
|
||||
)
|
||||
output_fields = list(
|
||||
[get_input_field(self.get_node(e.node_id), e.field) for e in outputs]
|
||||
)
|
||||
input_field = get_output_field(self.get_node(inputs[0].node_id), inputs[0].field)
|
||||
output_fields = list([get_input_field(self.get_node(e.node_id), e.field) for e in outputs])
|
||||
|
||||
# Input type must be a list
|
||||
if get_origin(input_field) != list:
|
||||
@ -659,12 +618,7 @@ class Graph(BaseModel):
|
||||
|
||||
# Validate that all outputs match the input type
|
||||
input_field_item_type = get_args(input_field)[0]
|
||||
if not all(
|
||||
(
|
||||
are_connection_types_compatible(input_field_item_type, f)
|
||||
for f in output_fields
|
||||
)
|
||||
):
|
||||
if not all((are_connection_types_compatible(input_field_item_type, f) for f in output_fields)):
|
||||
return False
|
||||
|
||||
return True
|
||||
@ -684,35 +638,21 @@ class Graph(BaseModel):
|
||||
outputs.append(new_output)
|
||||
|
||||
# Get input and output fields (the fields linked to the iterator's input/output)
|
||||
input_fields = list(
|
||||
[get_output_field(self.get_node(e.node_id), e.field) for e in inputs]
|
||||
)
|
||||
output_fields = list(
|
||||
[get_input_field(self.get_node(e.node_id), e.field) for e in outputs]
|
||||
)
|
||||
input_fields = list([get_output_field(self.get_node(e.node_id), e.field) for e in inputs])
|
||||
output_fields = list([get_input_field(self.get_node(e.node_id), e.field) for e in outputs])
|
||||
|
||||
# Validate that all inputs are derived from or match a single type
|
||||
input_field_types = set(
|
||||
[
|
||||
t
|
||||
for input_field in input_fields
|
||||
for t in (
|
||||
[input_field]
|
||||
if get_origin(input_field) == None
|
||||
else get_args(input_field)
|
||||
)
|
||||
for t in ([input_field] if get_origin(input_field) == None else get_args(input_field))
|
||||
if t != NoneType
|
||||
]
|
||||
) # Get unique types
|
||||
type_tree = nx.DiGraph()
|
||||
type_tree.add_nodes_from(input_field_types)
|
||||
type_tree.add_edges_from(
|
||||
[
|
||||
e
|
||||
for e in itertools.permutations(input_field_types, 2)
|
||||
if issubclass(e[1], e[0])
|
||||
]
|
||||
)
|
||||
type_tree.add_edges_from([e for e in itertools.permutations(input_field_types, 2) if issubclass(e[1], e[0])])
|
||||
type_degrees = type_tree.in_degree(type_tree.nodes)
|
||||
if sum((t[1] == 0 for t in type_degrees)) != 1: # type: ignore
|
||||
return False # There is more than one root type
|
||||
@ -729,9 +669,7 @@ class Graph(BaseModel):
|
||||
return False
|
||||
|
||||
# Verify that all outputs match the input type (are a base class or the same class)
|
||||
if not all(
|
||||
(issubclass(input_root_type, get_args(f)[0]) for f in output_fields)
|
||||
):
|
||||
if not all((issubclass(input_root_type, get_args(f)[0]) for f in output_fields)):
|
||||
return False
|
||||
|
||||
return True
|
||||
@ -751,9 +689,7 @@ class Graph(BaseModel):
|
||||
g.add_edges_from(set([(e.source.node_id, e.destination.node_id) for e in self.edges]))
|
||||
return g
|
||||
|
||||
def nx_graph_flat(
|
||||
self, nx_graph: Optional[nx.DiGraph] = None, prefix: Optional[str] = None
|
||||
) -> nx.DiGraph:
|
||||
def nx_graph_flat(self, nx_graph: Optional[nx.DiGraph] = None, prefix: Optional[str] = None) -> nx.DiGraph:
|
||||
"""Returns a flattened NetworkX DiGraph, including all subgraphs (but not with iterations expanded)"""
|
||||
g = nx_graph or nx.DiGraph()
|
||||
|
||||
@ -762,26 +698,18 @@ class Graph(BaseModel):
|
||||
[
|
||||
self._get_node_path(n.id, prefix)
|
||||
for n in self.nodes.values()
|
||||
if not isinstance(n, GraphInvocation)
|
||||
and not isinstance(n, IterateInvocation)
|
||||
if not isinstance(n, GraphInvocation) and not isinstance(n, IterateInvocation)
|
||||
]
|
||||
)
|
||||
|
||||
# Expand graph nodes
|
||||
for sgn in (
|
||||
gn for gn in self.nodes.values() if isinstance(gn, GraphInvocation)
|
||||
):
|
||||
for sgn in (gn for gn in self.nodes.values() if isinstance(gn, GraphInvocation)):
|
||||
g = sgn.graph.nx_graph_flat(g, self._get_node_path(sgn.id, prefix))
|
||||
|
||||
# TODO: figure out if iteration nodes need to be expanded
|
||||
|
||||
unique_edges = set([(e.source.node_id, e.destination.node_id) for e in self.edges])
|
||||
g.add_edges_from(
|
||||
[
|
||||
(self._get_node_path(e[0], prefix), self._get_node_path(e[1], prefix))
|
||||
for e in unique_edges
|
||||
]
|
||||
)
|
||||
g.add_edges_from([(self._get_node_path(e[0], prefix), self._get_node_path(e[1], prefix)) for e in unique_edges])
|
||||
return g
|
||||
|
||||
|
||||
@ -800,23 +728,19 @@ class GraphExecutionState(BaseModel):
|
||||
)
|
||||
|
||||
# Nodes that have been executed
|
||||
executed: set[str] = Field(
|
||||
description="The set of node ids that have been executed", default_factory=set
|
||||
)
|
||||
executed: set[str] = Field(description="The set of node ids that have been executed", default_factory=set)
|
||||
executed_history: list[str] = Field(
|
||||
description="The list of node ids that have been executed, in order of execution",
|
||||
default_factory=list,
|
||||
)
|
||||
|
||||
# The results of executed nodes
|
||||
results: dict[
|
||||
str, Annotated[InvocationOutputsUnion, Field(discriminator="type")]
|
||||
] = Field(description="The results of node executions", default_factory=dict)
|
||||
results: dict[str, Annotated[InvocationOutputsUnion, Field(discriminator="type")]] = Field(
|
||||
description="The results of node executions", default_factory=dict
|
||||
)
|
||||
|
||||
# Errors raised when executing nodes
|
||||
errors: dict[str, str] = Field(
|
||||
description="Errors raised when executing nodes", default_factory=dict
|
||||
)
|
||||
errors: dict[str, str] = Field(description="Errors raised when executing nodes", default_factory=dict)
|
||||
|
||||
# Map of prepared/executed nodes to their original nodes
|
||||
prepared_source_mapping: dict[str, str] = Field(
|
||||
@ -832,16 +756,16 @@ class GraphExecutionState(BaseModel):
|
||||
|
||||
class Config:
|
||||
schema_extra = {
|
||||
'required': [
|
||||
'id',
|
||||
'graph',
|
||||
'execution_graph',
|
||||
'executed',
|
||||
'executed_history',
|
||||
'results',
|
||||
'errors',
|
||||
'prepared_source_mapping',
|
||||
'source_prepared_mapping',
|
||||
"required": [
|
||||
"id",
|
||||
"graph",
|
||||
"execution_graph",
|
||||
"executed",
|
||||
"executed_history",
|
||||
"results",
|
||||
"errors",
|
||||
"prepared_source_mapping",
|
||||
"source_prepared_mapping",
|
||||
]
|
||||
}
|
||||
|
||||
@ -899,9 +823,7 @@ class GraphExecutionState(BaseModel):
|
||||
"""Returns true if the graph has any errors"""
|
||||
return len(self.errors) > 0
|
||||
|
||||
def _create_execution_node(
|
||||
self, node_path: str, iteration_node_map: list[tuple[str, str]]
|
||||
) -> list[str]:
|
||||
def _create_execution_node(self, node_path: str, iteration_node_map: list[tuple[str, str]]) -> list[str]:
|
||||
"""Prepares an iteration node and connects all edges, returning the new node id"""
|
||||
|
||||
node = self.graph.get_node(node_path)
|
||||
@ -911,20 +833,12 @@ class GraphExecutionState(BaseModel):
|
||||
# If this is an iterator node, we must create a copy for each iteration
|
||||
if isinstance(node, IterateInvocation):
|
||||
# Get input collection edge (should error if there are no inputs)
|
||||
input_collection_edge = next(
|
||||
iter(self.graph._get_input_edges(node_path, "collection"))
|
||||
)
|
||||
input_collection_edge = next(iter(self.graph._get_input_edges(node_path, "collection")))
|
||||
input_collection_prepared_node_id = next(
|
||||
n[1]
|
||||
for n in iteration_node_map
|
||||
if n[0] == input_collection_edge.source.node_id
|
||||
)
|
||||
input_collection_prepared_node_output = self.results[
|
||||
input_collection_prepared_node_id
|
||||
]
|
||||
input_collection = getattr(
|
||||
input_collection_prepared_node_output, input_collection_edge.source.field
|
||||
n[1] for n in iteration_node_map if n[0] == input_collection_edge.source.node_id
|
||||
)
|
||||
input_collection_prepared_node_output = self.results[input_collection_prepared_node_id]
|
||||
input_collection = getattr(input_collection_prepared_node_output, input_collection_edge.source.field)
|
||||
self_iteration_count = len(input_collection)
|
||||
|
||||
new_nodes = list()
|
||||
@ -939,9 +853,7 @@ class GraphExecutionState(BaseModel):
|
||||
# For collect nodes, this may contain multiple inputs to the same field
|
||||
new_edges = list()
|
||||
for edge in input_edges:
|
||||
for input_node_id in (
|
||||
n[1] for n in iteration_node_map if n[0] == edge.source.node_id
|
||||
):
|
||||
for input_node_id in (n[1] for n in iteration_node_map if n[0] == edge.source.node_id):
|
||||
new_edge = Edge(
|
||||
source=EdgeConnection(node_id=input_node_id, field=edge.source.field),
|
||||
destination=EdgeConnection(node_id="", field=edge.destination.field),
|
||||
@ -982,11 +894,7 @@ class GraphExecutionState(BaseModel):
|
||||
def _iterator_graph(self) -> nx.DiGraph:
|
||||
"""Gets a DiGraph with edges to collectors removed so an ancestor search produces all active iterators for any node"""
|
||||
g = self.graph.nx_graph_flat()
|
||||
collectors = (
|
||||
n
|
||||
for n in self.graph.nodes
|
||||
if isinstance(self.graph.get_node(n), CollectInvocation)
|
||||
)
|
||||
collectors = (n for n in self.graph.nodes if isinstance(self.graph.get_node(n), CollectInvocation))
|
||||
for c in collectors:
|
||||
g.remove_edges_from(list(g.in_edges(c)))
|
||||
return g
|
||||
@ -994,11 +902,7 @@ class GraphExecutionState(BaseModel):
|
||||
def _get_node_iterators(self, node_id: str) -> list[str]:
|
||||
"""Gets iterators for a node"""
|
||||
g = self._iterator_graph()
|
||||
iterators = [
|
||||
n
|
||||
for n in nx.ancestors(g, node_id)
|
||||
if isinstance(self.graph.get_node(n), IterateInvocation)
|
||||
]
|
||||
iterators = [n for n in nx.ancestors(g, node_id) if isinstance(self.graph.get_node(n), IterateInvocation)]
|
||||
return iterators
|
||||
|
||||
def _prepare(self) -> Optional[str]:
|
||||
@ -1045,29 +949,18 @@ class GraphExecutionState(BaseModel):
|
||||
if isinstance(next_node, CollectInvocation):
|
||||
# Collapse all iterator input mappings and create a single execution node for the collect invocation
|
||||
all_iteration_mappings = list(
|
||||
itertools.chain(
|
||||
*(
|
||||
((s, p) for p in self.source_prepared_mapping[s])
|
||||
for s in next_node_parents
|
||||
)
|
||||
)
|
||||
itertools.chain(*(((s, p) for p in self.source_prepared_mapping[s]) for s in next_node_parents))
|
||||
)
|
||||
# all_iteration_mappings = list(set(itertools.chain(*prepared_parent_mappings)))
|
||||
create_results = self._create_execution_node(
|
||||
next_node_id, all_iteration_mappings
|
||||
)
|
||||
create_results = self._create_execution_node(next_node_id, all_iteration_mappings)
|
||||
if create_results is not None:
|
||||
new_node_ids.extend(create_results)
|
||||
else: # Iterators or normal nodes
|
||||
# Get all iterator combinations for this node
|
||||
# Will produce a list of lists of prepared iterator nodes, from which results can be iterated
|
||||
iterator_nodes = self._get_node_iterators(next_node_id)
|
||||
iterator_nodes_prepared = [
|
||||
list(self.source_prepared_mapping[n]) for n in iterator_nodes
|
||||
]
|
||||
iterator_node_prepared_combinations = list(
|
||||
itertools.product(*iterator_nodes_prepared)
|
||||
)
|
||||
iterator_nodes_prepared = [list(self.source_prepared_mapping[n]) for n in iterator_nodes]
|
||||
iterator_node_prepared_combinations = list(itertools.product(*iterator_nodes_prepared))
|
||||
|
||||
# Select the correct prepared parents for each iteration
|
||||
# For every iterator, the parent must either not be a child of that iterator, or must match the prepared iteration for that iterator
|
||||
@ -1096,31 +989,16 @@ class GraphExecutionState(BaseModel):
|
||||
return next(iter(prepared_nodes))
|
||||
|
||||
# Check if the requested node is an iterator
|
||||
prepared_iterator = next(
|
||||
(n for n in prepared_nodes if n in prepared_iterator_nodes), None
|
||||
)
|
||||
prepared_iterator = next((n for n in prepared_nodes if n in prepared_iterator_nodes), None)
|
||||
if prepared_iterator is not None:
|
||||
return prepared_iterator
|
||||
|
||||
# Filter to only iterator nodes that are a parent of the specified node, in tuple format (prepared, source)
|
||||
iterator_source_node_mapping = [
|
||||
(n, self.prepared_source_mapping[n]) for n in prepared_iterator_nodes
|
||||
]
|
||||
parent_iterators = [
|
||||
itn
|
||||
for itn in iterator_source_node_mapping
|
||||
if nx.has_path(graph, itn[1], source_node_path)
|
||||
]
|
||||
iterator_source_node_mapping = [(n, self.prepared_source_mapping[n]) for n in prepared_iterator_nodes]
|
||||
parent_iterators = [itn for itn in iterator_source_node_mapping if nx.has_path(graph, itn[1], source_node_path)]
|
||||
|
||||
return next(
|
||||
(
|
||||
n
|
||||
for n in prepared_nodes
|
||||
if all(
|
||||
nx.has_path(execution_graph, pit[0], n)
|
||||
for pit in parent_iterators
|
||||
)
|
||||
),
|
||||
(n for n in prepared_nodes if all(nx.has_path(execution_graph, pit[0], n) for pit in parent_iterators)),
|
||||
None,
|
||||
)
|
||||
|
||||
@ -1130,13 +1008,13 @@ class GraphExecutionState(BaseModel):
|
||||
|
||||
# Depth-first search with pre-order traversal is a depth-first topological sort
|
||||
sorted_nodes = nx.dfs_preorder_nodes(g)
|
||||
|
||||
|
||||
next_node = next(
|
||||
(
|
||||
n
|
||||
for n in sorted_nodes
|
||||
if n not in self.executed # the node must not already be executed...
|
||||
and all((e[0] in self.executed for e in g.in_edges(n))) # ...and all its inputs must be executed
|
||||
if n not in self.executed # the node must not already be executed...
|
||||
and all((e[0] in self.executed for e in g.in_edges(n))) # ...and all its inputs must be executed
|
||||
),
|
||||
None,
|
||||
)
|
||||
@ -1221,15 +1099,18 @@ class ExposedNodeOutput(BaseModel):
|
||||
field: str = Field(description="The field name of the output")
|
||||
alias: str = Field(description="The alias of the output")
|
||||
|
||||
|
||||
class LibraryGraph(BaseModel):
|
||||
id: str = Field(description="The unique identifier for this library graph", default_factory=uuid.uuid4)
|
||||
graph: Graph = Field(description="The graph")
|
||||
name: str = Field(description="The name of the graph")
|
||||
description: str = Field(description="The description of the graph")
|
||||
exposed_inputs: list[ExposedNodeInput] = Field(description="The inputs exposed by this graph", default_factory=list)
|
||||
exposed_outputs: list[ExposedNodeOutput] = Field(description="The outputs exposed by this graph", default_factory=list)
|
||||
exposed_outputs: list[ExposedNodeOutput] = Field(
|
||||
description="The outputs exposed by this graph", default_factory=list
|
||||
)
|
||||
|
||||
@validator('exposed_inputs', 'exposed_outputs')
|
||||
@validator("exposed_inputs", "exposed_outputs")
|
||||
def validate_exposed_aliases(cls, v):
|
||||
if len(v) != len(set(i.alias for i in v)):
|
||||
raise ValueError("Duplicate exposed alias")
|
||||
@ -1237,23 +1118,27 @@ class LibraryGraph(BaseModel):
|
||||
|
||||
@root_validator
|
||||
def validate_exposed_nodes(cls, values):
|
||||
graph = values['graph']
|
||||
graph = values["graph"]
|
||||
|
||||
# Validate exposed inputs
|
||||
for exposed_input in values['exposed_inputs']:
|
||||
for exposed_input in values["exposed_inputs"]:
|
||||
if not graph.has_node(exposed_input.node_path):
|
||||
raise ValueError(f"Exposed input node {exposed_input.node_path} does not exist")
|
||||
node = graph.get_node(exposed_input.node_path)
|
||||
if get_input_field(node, exposed_input.field) is None:
|
||||
raise ValueError(f"Exposed input field {exposed_input.field} does not exist on node {exposed_input.node_path}")
|
||||
raise ValueError(
|
||||
f"Exposed input field {exposed_input.field} does not exist on node {exposed_input.node_path}"
|
||||
)
|
||||
|
||||
# Validate exposed outputs
|
||||
for exposed_output in values['exposed_outputs']:
|
||||
for exposed_output in values["exposed_outputs"]:
|
||||
if not graph.has_node(exposed_output.node_path):
|
||||
raise ValueError(f"Exposed output node {exposed_output.node_path} does not exist")
|
||||
node = graph.get_node(exposed_output.node_path)
|
||||
if get_output_field(node, exposed_output.field) is None:
|
||||
raise ValueError(f"Exposed output field {exposed_output.field} does not exist on node {exposed_output.node_path}")
|
||||
raise ValueError(
|
||||
f"Exposed output field {exposed_output.field} does not exist on node {exposed_output.node_path}"
|
||||
)
|
||||
|
||||
return values
|
||||
|
||||
|
@ -85,9 +85,7 @@ class DiskImageFileStorage(ImageFileStorageBase):
|
||||
self.__cache_ids = Queue()
|
||||
self.__max_cache_size = 10 # TODO: get this from config
|
||||
|
||||
self.__output_folder: Path = (
|
||||
output_folder if isinstance(output_folder, Path) else Path(output_folder)
|
||||
)
|
||||
self.__output_folder: Path = output_folder if isinstance(output_folder, Path) else Path(output_folder)
|
||||
self.__thumbnails_folder = self.__output_folder / "thumbnails"
|
||||
|
||||
# Validate required output folders at launch
|
||||
@ -120,7 +118,7 @@ class DiskImageFileStorage(ImageFileStorageBase):
|
||||
image_path = self.get_path(image_name)
|
||||
|
||||
pnginfo = PngImagePlugin.PngInfo()
|
||||
|
||||
|
||||
if metadata is not None:
|
||||
pnginfo.add_text("invokeai_metadata", json.dumps(metadata))
|
||||
if graph is not None:
|
||||
@ -183,9 +181,7 @@ class DiskImageFileStorage(ImageFileStorageBase):
|
||||
def __set_cache(self, image_name: Path, image: PILImageType):
|
||||
if not image_name in self.__cache:
|
||||
self.__cache[image_name] = image
|
||||
self.__cache_ids.put(
|
||||
image_name
|
||||
) # TODO: this should refresh position for LRU cache
|
||||
self.__cache_ids.put(image_name) # TODO: this should refresh position for LRU cache
|
||||
if len(self.__cache) > self.__max_cache_size:
|
||||
cache_id = self.__cache_ids.get()
|
||||
if cache_id in self.__cache:
|
||||
|
@ -426,9 +426,7 @@ class SqliteImageRecordStorage(ImageRecordStorageBase):
|
||||
finally:
|
||||
self._lock.release()
|
||||
|
||||
return OffsetPaginatedResults(
|
||||
items=images, offset=offset, limit=limit, total=count
|
||||
)
|
||||
return OffsetPaginatedResults(items=images, offset=offset, limit=limit, total=count)
|
||||
|
||||
def delete(self, image_name: str) -> None:
|
||||
try:
|
||||
@ -466,7 +464,6 @@ class SqliteImageRecordStorage(ImageRecordStorageBase):
|
||||
finally:
|
||||
self._lock.release()
|
||||
|
||||
|
||||
def delete_intermediates(self) -> list[str]:
|
||||
try:
|
||||
self._lock.acquire()
|
||||
@ -505,9 +502,7 @@ class SqliteImageRecordStorage(ImageRecordStorageBase):
|
||||
is_intermediate: bool = False,
|
||||
) -> datetime:
|
||||
try:
|
||||
metadata_json = (
|
||||
None if metadata is None else json.dumps(metadata)
|
||||
)
|
||||
metadata_json = None if metadata is None else json.dumps(metadata)
|
||||
self._lock.acquire()
|
||||
self._cursor.execute(
|
||||
"""--sql
|
||||
|
@ -216,16 +216,9 @@ class ImageService(ImageServiceABC):
|
||||
metadata=metadata,
|
||||
session_id=session_id,
|
||||
)
|
||||
|
||||
if board_id is not None:
|
||||
self._services.board_image_records.add_image_to_board(
|
||||
board_id=board_id, image_name=image_name
|
||||
)
|
||||
|
||||
self._services.image_files.save(
|
||||
image_name=image_name, image=image, metadata=metadata, graph=graph
|
||||
)
|
||||
|
||||
self._services.board_image_records.add_image_to_board(board_id=board_id, image_name=image_name)
|
||||
self._services.image_files.save(image_name=image_name, image=image, metadata=metadata, graph=graph)
|
||||
image_dto = self.get_dto(image_name)
|
||||
|
||||
return image_dto
|
||||
@ -236,7 +229,7 @@ class ImageService(ImageServiceABC):
|
||||
self._services.logger.error("Failed to save image file")
|
||||
raise
|
||||
except Exception as e:
|
||||
self._services.logger.error("Problem saving image record and file")
|
||||
self._services.logger.error(f"Problem saving image record and file: {str(e)}")
|
||||
raise e
|
||||
|
||||
def update(
|
||||
@ -300,9 +293,7 @@ class ImageService(ImageServiceABC):
|
||||
if not image_record.session_id:
|
||||
return ImageMetadata()
|
||||
|
||||
session_raw = self._services.graph_execution_manager.get_raw(
|
||||
image_record.session_id
|
||||
)
|
||||
session_raw = self._services.graph_execution_manager.get_raw(image_record.session_id)
|
||||
graph = None
|
||||
|
||||
if session_raw:
|
||||
@ -367,9 +358,7 @@ class ImageService(ImageServiceABC):
|
||||
r,
|
||||
self._services.urls.get_image_url(r.image_name),
|
||||
self._services.urls.get_image_url(r.image_name, True),
|
||||
self._services.board_image_records.get_board_for_image(
|
||||
r.image_name
|
||||
),
|
||||
self._services.board_image_records.get_board_for_image(r.image_name),
|
||||
),
|
||||
results.items,
|
||||
)
|
||||
@ -401,11 +390,7 @@ class ImageService(ImageServiceABC):
|
||||
|
||||
def delete_images_on_board(self, board_id: str):
|
||||
try:
|
||||
image_names = (
|
||||
self._services.board_image_records.get_all_board_image_names_for_board(
|
||||
board_id
|
||||
)
|
||||
)
|
||||
image_names = self._services.board_image_records.get_all_board_image_names_for_board(board_id)
|
||||
for image_name in image_names:
|
||||
self._services.image_files.delete(image_name)
|
||||
self._services.image_records.delete_many(image_names)
|
||||
|
@ -7,6 +7,7 @@ from queue import Queue
|
||||
from pydantic import BaseModel, Field
|
||||
from typing import Optional
|
||||
|
||||
|
||||
class InvocationQueueItem(BaseModel):
|
||||
graph_execution_state_id: str = Field(description="The ID of the graph execution state")
|
||||
invocation_id: str = Field(description="The ID of the node being invoked")
|
||||
@ -45,9 +46,11 @@ class MemoryInvocationQueue(InvocationQueueABC):
|
||||
def get(self) -> InvocationQueueItem:
|
||||
item = self.__queue.get()
|
||||
|
||||
while isinstance(item, InvocationQueueItem) \
|
||||
and item.graph_execution_state_id in self.__cancellations \
|
||||
and self.__cancellations[item.graph_execution_state_id] > item.timestamp:
|
||||
while (
|
||||
isinstance(item, InvocationQueueItem)
|
||||
and item.graph_execution_state_id in self.__cancellations
|
||||
and self.__cancellations[item.graph_execution_state_id] > item.timestamp
|
||||
):
|
||||
item = self.__queue.get()
|
||||
|
||||
# Clear old items
|
||||
|
@ -7,6 +7,7 @@ from .graph import Graph, GraphExecutionState
|
||||
from .invocation_queue import InvocationQueueItem
|
||||
from .invocation_services import InvocationServices
|
||||
|
||||
|
||||
class Invoker:
|
||||
"""The invoker, used to execute invocations"""
|
||||
|
||||
@ -16,9 +17,7 @@ class Invoker:
|
||||
self.services = services
|
||||
self._start()
|
||||
|
||||
def invoke(
|
||||
self, graph_execution_state: GraphExecutionState, invoke_all: bool = False
|
||||
) -> Optional[str]:
|
||||
def invoke(self, graph_execution_state: GraphExecutionState, invoke_all: bool = False) -> Optional[str]:
|
||||
"""Determines the next node to invoke and enqueues it, preparing if needed.
|
||||
Returns the id of the queued node, or `None` if there are no nodes left to enqueue."""
|
||||
|
||||
|
@ -9,13 +9,15 @@ T = TypeVar("T", bound=BaseModel)
|
||||
|
||||
class PaginatedResults(GenericModel, Generic[T]):
|
||||
"""Paginated results"""
|
||||
#fmt: off
|
||||
|
||||
# fmt: off
|
||||
items: list[T] = Field(description="Items")
|
||||
page: int = Field(description="Current Page")
|
||||
pages: int = Field(description="Total number of pages")
|
||||
per_page: int = Field(description="Number of items per page")
|
||||
total: int = Field(description="Total number of items in result")
|
||||
#fmt: on
|
||||
# fmt: on
|
||||
|
||||
|
||||
class ItemStorageABC(ABC, Generic[T]):
|
||||
_on_changed_callbacks: list[Callable[[T], None]]
|
||||
@ -48,9 +50,7 @@ class ItemStorageABC(ABC, Generic[T]):
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def search(
|
||||
self, query: str, page: int = 0, per_page: int = 10
|
||||
) -> PaginatedResults[T]:
|
||||
def search(self, query: str, page: int = 0, per_page: int = 10) -> PaginatedResults[T]:
|
||||
pass
|
||||
|
||||
def on_changed(self, on_changed: Callable[[T], None]) -> None:
|
||||
|
@ -7,6 +7,7 @@ from typing import Dict, Union, Optional
|
||||
|
||||
import torch
|
||||
|
||||
|
||||
class LatentsStorageBase(ABC):
|
||||
"""Responsible for storing and retrieving latents."""
|
||||
|
||||
@ -25,7 +26,7 @@ class LatentsStorageBase(ABC):
|
||||
|
||||
class ForwardCacheLatentsStorage(LatentsStorageBase):
|
||||
"""Caches the latest N latents in memory, writing-thorugh to and reading from underlying storage"""
|
||||
|
||||
|
||||
__cache: Dict[str, torch.Tensor]
|
||||
__cache_ids: Queue
|
||||
__max_cache_size: int
|
||||
@ -87,8 +88,6 @@ class DiskLatentsStorage(LatentsStorageBase):
|
||||
def delete(self, name: str) -> None:
|
||||
latent_path = self.get_path(name)
|
||||
latent_path.unlink()
|
||||
|
||||
|
||||
def get_path(self, name: str) -> Path:
|
||||
return self.__output_folder / name
|
||||
|
||||
|
@ -103,7 +103,7 @@ class ModelManagerServiceBase(ABC):
|
||||
}
|
||||
"""
|
||||
pass
|
||||
|
||||
|
||||
@abstractmethod
|
||||
def list_model(self, model_name: str, base_model: BaseModelType, model_type: ModelType) -> dict:
|
||||
"""
|
||||
@ -125,7 +125,7 @@ class ModelManagerServiceBase(ABC):
|
||||
base_model: BaseModelType,
|
||||
model_type: ModelType,
|
||||
model_attributes: dict,
|
||||
clobber: bool = False
|
||||
clobber: bool = False,
|
||||
) -> AddModelResult:
|
||||
"""
|
||||
Update the named model with a dictionary of attributes. Will fail with an
|
||||
@ -148,12 +148,12 @@ class ModelManagerServiceBase(ABC):
|
||||
Update the named model with a dictionary of attributes. Will fail with a
|
||||
ModelNotFoundException if the name does not already exist.
|
||||
|
||||
On a successful update, the config will be changed in memory. Will fail
|
||||
with an assertion error if provided attributes are incorrect or
|
||||
On a successful update, the config will be changed in memory. Will fail
|
||||
with an assertion error if provided attributes are incorrect or
|
||||
the model name is missing. Call commit() to write changes to disk.
|
||||
"""
|
||||
pass
|
||||
|
||||
|
||||
@abstractmethod
|
||||
def del_model(
|
||||
self,
|
||||
@ -169,21 +169,20 @@ class ModelManagerServiceBase(ABC):
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def rename_model(self,
|
||||
model_name: str,
|
||||
base_model: BaseModelType,
|
||||
model_type: ModelType,
|
||||
new_name: str,
|
||||
):
|
||||
def rename_model(
|
||||
self,
|
||||
model_name: str,
|
||||
base_model: BaseModelType,
|
||||
model_type: ModelType,
|
||||
new_name: str,
|
||||
):
|
||||
"""
|
||||
Rename the indicated model.
|
||||
"""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def list_checkpoint_configs(
|
||||
self
|
||||
)->List[Path]:
|
||||
def list_checkpoint_configs(self) -> List[Path]:
|
||||
"""
|
||||
List the checkpoint config paths from ROOT/configs/stable-diffusion.
|
||||
"""
|
||||
@ -194,7 +193,7 @@ class ModelManagerServiceBase(ABC):
|
||||
self,
|
||||
model_name: str,
|
||||
base_model: BaseModelType,
|
||||
model_type: Union[ModelType.Main,ModelType.Vae],
|
||||
model_type: Union[ModelType.Main, ModelType.Vae],
|
||||
) -> AddModelResult:
|
||||
"""
|
||||
Convert a checkpoint file into a diffusers folder, deleting the cached
|
||||
@ -211,11 +210,12 @@ class ModelManagerServiceBase(ABC):
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def heuristic_import(self,
|
||||
items_to_import: set[str],
|
||||
prediction_type_helper: Optional[Callable[[Path],SchedulerPredictionType]]=None,
|
||||
)->dict[str, AddModelResult]:
|
||||
'''Import a list of paths, repo_ids or URLs. Returns the set of
|
||||
def heuristic_import(
|
||||
self,
|
||||
items_to_import: set[str],
|
||||
prediction_type_helper: Optional[Callable[[Path], SchedulerPredictionType]] = None,
|
||||
) -> dict[str, AddModelResult]:
|
||||
"""Import a list of paths, repo_ids or URLs. Returns the set of
|
||||
successfully imported items.
|
||||
:param items_to_import: Set of strings corresponding to models to be imported.
|
||||
:param prediction_type_helper: A callback that receives the Path of a Stable Diffusion 2 checkpoint model and returns a SchedulerPredictionType.
|
||||
@ -230,19 +230,23 @@ class ModelManagerServiceBase(ABC):
|
||||
The result is a set of successfully installed models. Each element
|
||||
of the set is a dict corresponding to the newly-created OmegaConf stanza for
|
||||
that model.
|
||||
'''
|
||||
"""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def merge_models(
|
||||
self,
|
||||
model_names: List[str] = Field(default=None, min_items=2, max_items=3, description="List of model names to merge"),
|
||||
base_model: Union[BaseModelType,str] = Field(default=None, description="Base model shared by all models to be merged"),
|
||||
merged_model_name: str = Field(default=None, description="Name of destination model after merging"),
|
||||
alpha: Optional[float] = 0.5,
|
||||
interp: Optional[MergeInterpolationMethod] = None,
|
||||
force: Optional[bool] = False,
|
||||
merge_dest_directory: Optional[Path] = None
|
||||
self,
|
||||
model_names: List[str] = Field(
|
||||
default=None, min_items=2, max_items=3, description="List of model names to merge"
|
||||
),
|
||||
base_model: Union[BaseModelType, str] = Field(
|
||||
default=None, description="Base model shared by all models to be merged"
|
||||
),
|
||||
merged_model_name: str = Field(default=None, description="Name of destination model after merging"),
|
||||
alpha: Optional[float] = 0.5,
|
||||
interp: Optional[MergeInterpolationMethod] = None,
|
||||
force: Optional[bool] = False,
|
||||
merge_dest_directory: Optional[Path] = None,
|
||||
) -> AddModelResult:
|
||||
"""
|
||||
Merge two to three diffusrs pipeline models and save as a new model.
|
||||
@ -250,27 +254,27 @@ class ModelManagerServiceBase(ABC):
|
||||
:param base_model: Base model to use for all models
|
||||
:param merged_model_name: Name of destination merged model
|
||||
:param alpha: Alpha strength to apply to 2d and 3d model
|
||||
:param interp: Interpolation method. None (default)
|
||||
:param interp: Interpolation method. None (default)
|
||||
:param merge_dest_directory: Save the merged model to the designated directory (with 'merged_model_name' appended)
|
||||
"""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def search_for_models(self, directory: Path)->List[Path]:
|
||||
def search_for_models(self, directory: Path) -> List[Path]:
|
||||
"""
|
||||
Return list of all models found in the designated directory.
|
||||
"""
|
||||
pass
|
||||
|
||||
|
||||
@abstractmethod
|
||||
def sync_to_config(self):
|
||||
"""
|
||||
Re-read models.yaml, rescan the models directory, and reimport models
|
||||
Re-read models.yaml, rescan the models directory, and reimport models
|
||||
in the autoimport directories. Call after making changes outside the
|
||||
model manager API.
|
||||
"""
|
||||
pass
|
||||
|
||||
|
||||
@abstractmethod
|
||||
def commit(self, conf_file: Optional[Path] = None) -> None:
|
||||
"""
|
||||
@ -280,9 +284,11 @@ class ModelManagerServiceBase(ABC):
|
||||
"""
|
||||
pass
|
||||
|
||||
|
||||
# simple implementation
|
||||
class ModelManagerService(ModelManagerServiceBase):
|
||||
"""Responsible for managing models on disk and in memory"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
config: InvokeAIAppConfig,
|
||||
@ -298,17 +304,17 @@ class ModelManagerService(ModelManagerServiceBase):
|
||||
config_file = config.model_conf_path
|
||||
else:
|
||||
config_file = config.root_dir / "configs/models.yaml"
|
||||
|
||||
logger.debug(f'Config file={config_file}')
|
||||
|
||||
logger.debug(f"Config file={config_file}")
|
||||
|
||||
device = torch.device(choose_torch_device())
|
||||
device_name = torch.cuda.get_device_name() if device==torch.device('cuda') else ''
|
||||
logger.info(f'GPU device = {device} {device_name}')
|
||||
device_name = torch.cuda.get_device_name() if device == torch.device("cuda") else ""
|
||||
logger.info(f"GPU device = {device} {device_name}")
|
||||
|
||||
precision = config.precision
|
||||
if precision == "auto":
|
||||
precision = choose_precision(device)
|
||||
dtype = torch.float32 if precision == 'float32' else torch.float16
|
||||
dtype = torch.float32 if precision == "float32" else torch.float16
|
||||
|
||||
# this is transitional backward compatibility
|
||||
# support for the deprecated `max_loaded_models`
|
||||
@ -316,9 +322,7 @@ class ModelManagerService(ModelManagerServiceBase):
|
||||
# cache size is set to 2.5 GB times
|
||||
# the number of max_loaded_models. Otherwise
|
||||
# use new `max_cache_size` config setting
|
||||
max_cache_size = config.max_cache_size \
|
||||
if hasattr(config,'max_cache_size') \
|
||||
else config.max_loaded_models * 2.5
|
||||
max_cache_size = config.max_cache_size if hasattr(config, "max_cache_size") else config.max_loaded_models * 2.5
|
||||
|
||||
logger.debug(f"Maximum RAM cache size: {max_cache_size} GiB")
|
||||
|
||||
@ -332,7 +336,7 @@ class ModelManagerService(ModelManagerServiceBase):
|
||||
sequential_offload=sequential_offload,
|
||||
logger=logger,
|
||||
)
|
||||
logger.info('Model manager service initialized')
|
||||
logger.info("Model manager service initialized")
|
||||
|
||||
def get_model(
|
||||
self,
|
||||
@ -371,7 +375,7 @@ class ModelManagerService(ModelManagerServiceBase):
|
||||
base_model=base_model,
|
||||
model_type=model_type,
|
||||
submodel=submodel,
|
||||
model_info=model_info
|
||||
model_info=model_info,
|
||||
)
|
||||
|
||||
return model_info
|
||||
@ -405,9 +409,7 @@ class ModelManagerService(ModelManagerServiceBase):
|
||||
return self.mgr.model_names()
|
||||
|
||||
def list_models(
|
||||
self,
|
||||
base_model: Optional[BaseModelType] = None,
|
||||
model_type: Optional[ModelType] = None
|
||||
self, base_model: Optional[BaseModelType] = None, model_type: Optional[ModelType] = None
|
||||
) -> list[dict]:
|
||||
"""
|
||||
Return a list of models.
|
||||
@ -418,9 +420,7 @@ class ModelManagerService(ModelManagerServiceBase):
|
||||
"""
|
||||
Return information about the model using the same format as list_models()
|
||||
"""
|
||||
return self.mgr.list_model(model_name=model_name,
|
||||
base_model=base_model,
|
||||
model_type=model_type)
|
||||
return self.mgr.list_model(model_name=model_name, base_model=base_model, model_type=model_type)
|
||||
|
||||
def add_model(
|
||||
self,
|
||||
@ -429,7 +429,7 @@ class ModelManagerService(ModelManagerServiceBase):
|
||||
model_type: ModelType,
|
||||
model_attributes: dict,
|
||||
clobber: bool = False,
|
||||
)->None:
|
||||
) -> None:
|
||||
"""
|
||||
Update the named model with a dictionary of attributes. Will fail with an
|
||||
assertion error if the name already exists. Pass clobber=True to overwrite.
|
||||
@ -437,7 +437,7 @@ class ModelManagerService(ModelManagerServiceBase):
|
||||
with an assertion error if provided attributes are incorrect or
|
||||
the model name is missing. Call commit() to write changes to disk.
|
||||
"""
|
||||
self.logger.debug(f'add/update model {model_name}')
|
||||
self.logger.debug(f"add/update model {model_name}")
|
||||
return self.mgr.add_model(model_name, base_model, model_type, model_attributes, clobber)
|
||||
|
||||
def update_model(
|
||||
@ -450,15 +450,15 @@ class ModelManagerService(ModelManagerServiceBase):
|
||||
"""
|
||||
Update the named model with a dictionary of attributes. Will fail with a
|
||||
ModelNotFoundException exception if the name does not already exist.
|
||||
On a successful update, the config will be changed in memory. Will fail
|
||||
with an assertion error if provided attributes are incorrect or
|
||||
On a successful update, the config will be changed in memory. Will fail
|
||||
with an assertion error if provided attributes are incorrect or
|
||||
the model name is missing. Call commit() to write changes to disk.
|
||||
"""
|
||||
self.logger.debug(f'update model {model_name}')
|
||||
self.logger.debug(f"update model {model_name}")
|
||||
if not self.model_exists(model_name, base_model, model_type):
|
||||
raise ModelNotFoundException(f"Unknown model {model_name}")
|
||||
return self.add_model(model_name, base_model, model_type, model_attributes, clobber=True)
|
||||
|
||||
|
||||
def del_model(
|
||||
self,
|
||||
model_name: str,
|
||||
@ -470,7 +470,7 @@ class ModelManagerService(ModelManagerServiceBase):
|
||||
then the underlying weight file or diffusers directory will be deleted
|
||||
as well.
|
||||
"""
|
||||
self.logger.debug(f'delete model {model_name}')
|
||||
self.logger.debug(f"delete model {model_name}")
|
||||
self.mgr.del_model(model_name, base_model, model_type)
|
||||
self.mgr.commit()
|
||||
|
||||
@ -478,8 +478,10 @@ class ModelManagerService(ModelManagerServiceBase):
|
||||
self,
|
||||
model_name: str,
|
||||
base_model: BaseModelType,
|
||||
model_type: Union[ModelType.Main,ModelType.Vae],
|
||||
convert_dest_directory: Optional[Path] = Field(default=None, description="Optional directory location for merged model"),
|
||||
model_type: Union[ModelType.Main, ModelType.Vae],
|
||||
convert_dest_directory: Optional[Path] = Field(
|
||||
default=None, description="Optional directory location for merged model"
|
||||
),
|
||||
) -> AddModelResult:
|
||||
"""
|
||||
Convert a checkpoint file into a diffusers folder, deleting the cached
|
||||
@ -494,10 +496,10 @@ class ModelManagerService(ModelManagerServiceBase):
|
||||
also raise a ValueError in the event that there is a similarly-named diffusers
|
||||
directory already in place.
|
||||
"""
|
||||
self.logger.debug(f'convert model {model_name}')
|
||||
self.logger.debug(f"convert model {model_name}")
|
||||
return self.mgr.convert_model(model_name, base_model, model_type, convert_dest_directory)
|
||||
|
||||
def commit(self, conf_file: Optional[Path]=None):
|
||||
def commit(self, conf_file: Optional[Path] = None):
|
||||
"""
|
||||
Write current configuration out to the indicated file.
|
||||
If no conf_file is provided, then replaces the
|
||||
@ -524,7 +526,7 @@ class ModelManagerService(ModelManagerServiceBase):
|
||||
base_model=base_model,
|
||||
model_type=model_type,
|
||||
submodel=submodel,
|
||||
model_info=model_info
|
||||
model_info=model_info,
|
||||
)
|
||||
else:
|
||||
context.services.events.emit_model_load_started(
|
||||
@ -535,16 +537,16 @@ class ModelManagerService(ModelManagerServiceBase):
|
||||
submodel=submodel,
|
||||
)
|
||||
|
||||
|
||||
@property
|
||||
def logger(self):
|
||||
return self.mgr.logger
|
||||
|
||||
def heuristic_import(self,
|
||||
items_to_import: set[str],
|
||||
prediction_type_helper: Optional[Callable[[Path],SchedulerPredictionType]]=None,
|
||||
)->dict[str, AddModelResult]:
|
||||
'''Import a list of paths, repo_ids or URLs. Returns the set of
|
||||
def heuristic_import(
|
||||
self,
|
||||
items_to_import: set[str],
|
||||
prediction_type_helper: Optional[Callable[[Path], SchedulerPredictionType]] = None,
|
||||
) -> dict[str, AddModelResult]:
|
||||
"""Import a list of paths, repo_ids or URLs. Returns the set of
|
||||
successfully imported items.
|
||||
:param items_to_import: Set of strings corresponding to models to be imported.
|
||||
:param prediction_type_helper: A callback that receives the Path of a Stable Diffusion 2 checkpoint model and returns a SchedulerPredictionType.
|
||||
@ -559,18 +561,24 @@ class ModelManagerService(ModelManagerServiceBase):
|
||||
The result is a set of successfully installed models. Each element
|
||||
of the set is a dict corresponding to the newly-created OmegaConf stanza for
|
||||
that model.
|
||||
'''
|
||||
return self.mgr.heuristic_import(items_to_import, prediction_type_helper)
|
||||
"""
|
||||
return self.mgr.heuristic_import(items_to_import, prediction_type_helper)
|
||||
|
||||
def merge_models(
|
||||
self,
|
||||
model_names: List[str] = Field(default=None, min_items=2, max_items=3, description="List of model names to merge"),
|
||||
base_model: Union[BaseModelType,str] = Field(default=None, description="Base model shared by all models to be merged"),
|
||||
merged_model_name: str = Field(default=None, description="Name of destination model after merging"),
|
||||
alpha: Optional[float] = 0.5,
|
||||
interp: Optional[MergeInterpolationMethod] = None,
|
||||
force: Optional[bool] = False,
|
||||
merge_dest_directory: Optional[Path] = Field(default=None, description="Optional directory location for merged model"),
|
||||
self,
|
||||
model_names: List[str] = Field(
|
||||
default=None, min_items=2, max_items=3, description="List of model names to merge"
|
||||
),
|
||||
base_model: Union[BaseModelType, str] = Field(
|
||||
default=None, description="Base model shared by all models to be merged"
|
||||
),
|
||||
merged_model_name: str = Field(default=None, description="Name of destination model after merging"),
|
||||
alpha: Optional[float] = 0.5,
|
||||
interp: Optional[MergeInterpolationMethod] = None,
|
||||
force: Optional[bool] = False,
|
||||
merge_dest_directory: Optional[Path] = Field(
|
||||
default=None, description="Optional directory location for merged model"
|
||||
),
|
||||
) -> AddModelResult:
|
||||
"""
|
||||
Merge two to three diffusrs pipeline models and save as a new model.
|
||||
@ -578,25 +586,25 @@ class ModelManagerService(ModelManagerServiceBase):
|
||||
:param base_model: Base model to use for all models
|
||||
:param merged_model_name: Name of destination merged model
|
||||
:param alpha: Alpha strength to apply to 2d and 3d model
|
||||
:param interp: Interpolation method. None (default)
|
||||
:param interp: Interpolation method. None (default)
|
||||
:param merge_dest_directory: Save the merged model to the designated directory (with 'merged_model_name' appended)
|
||||
"""
|
||||
merger = ModelMerger(self.mgr)
|
||||
try:
|
||||
result = merger.merge_diffusion_models_and_save(
|
||||
model_names = model_names,
|
||||
base_model = base_model,
|
||||
merged_model_name = merged_model_name,
|
||||
alpha = alpha,
|
||||
interp = interp,
|
||||
force = force,
|
||||
model_names=model_names,
|
||||
base_model=base_model,
|
||||
merged_model_name=merged_model_name,
|
||||
alpha=alpha,
|
||||
interp=interp,
|
||||
force=force,
|
||||
merge_dest_directory=merge_dest_directory,
|
||||
)
|
||||
except AssertionError as e:
|
||||
raise ValueError(e)
|
||||
return result
|
||||
|
||||
def search_for_models(self, directory: Path)->List[Path]:
|
||||
def search_for_models(self, directory: Path) -> List[Path]:
|
||||
"""
|
||||
Return list of all models found in the designated directory.
|
||||
"""
|
||||
@ -605,28 +613,29 @@ class ModelManagerService(ModelManagerServiceBase):
|
||||
|
||||
def sync_to_config(self):
|
||||
"""
|
||||
Re-read models.yaml, rescan the models directory, and reimport models
|
||||
Re-read models.yaml, rescan the models directory, and reimport models
|
||||
in the autoimport directories. Call after making changes outside the
|
||||
model manager API.
|
||||
"""
|
||||
return self.mgr.sync_to_config()
|
||||
|
||||
def list_checkpoint_configs(self)->List[Path]:
|
||||
def list_checkpoint_configs(self) -> List[Path]:
|
||||
"""
|
||||
List the checkpoint config paths from ROOT/configs/stable-diffusion.
|
||||
"""
|
||||
config = self.mgr.app_config
|
||||
conf_path = config.legacy_conf_path
|
||||
root_path = config.root_path
|
||||
return [(conf_path / x).relative_to(root_path) for x in conf_path.glob('**/*.yaml')]
|
||||
return [(conf_path / x).relative_to(root_path) for x in conf_path.glob("**/*.yaml")]
|
||||
|
||||
def rename_model(self,
|
||||
model_name: str,
|
||||
base_model: BaseModelType,
|
||||
model_type: ModelType,
|
||||
new_name: str = None,
|
||||
new_base: BaseModelType = None,
|
||||
):
|
||||
def rename_model(
|
||||
self,
|
||||
model_name: str,
|
||||
base_model: BaseModelType,
|
||||
model_type: ModelType,
|
||||
new_name: str = None,
|
||||
new_base: BaseModelType = None,
|
||||
):
|
||||
"""
|
||||
Rename the indicated model. Can provide a new name and/or a new base.
|
||||
:param model_name: Current name of the model
|
||||
@ -635,10 +644,10 @@ class ModelManagerService(ModelManagerServiceBase):
|
||||
:param new_name: New name for the model
|
||||
:param new_base: New base for the model
|
||||
"""
|
||||
self.mgr.rename_model(base_model = base_model,
|
||||
model_type = model_type,
|
||||
model_name = model_name,
|
||||
new_name = new_name,
|
||||
new_base = new_base,
|
||||
)
|
||||
|
||||
self.mgr.rename_model(
|
||||
base_model=base_model,
|
||||
model_type=model_type,
|
||||
model_name=model_name,
|
||||
new_name=new_name,
|
||||
new_base=new_base,
|
||||
)
|
||||
|
@ -11,30 +11,20 @@ class BoardRecord(BaseModel):
|
||||
"""The unique ID of the board."""
|
||||
board_name: str = Field(description="The name of the board.")
|
||||
"""The name of the board."""
|
||||
created_at: Union[datetime, str] = Field(
|
||||
description="The created timestamp of the board."
|
||||
)
|
||||
created_at: Union[datetime, str] = Field(description="The created timestamp of the board.")
|
||||
"""The created timestamp of the image."""
|
||||
updated_at: Union[datetime, str] = Field(
|
||||
description="The updated timestamp of the board."
|
||||
)
|
||||
updated_at: Union[datetime, str] = Field(description="The updated timestamp of the board.")
|
||||
"""The updated timestamp of the image."""
|
||||
deleted_at: Union[datetime, str, None] = Field(
|
||||
description="The deleted timestamp of the board."
|
||||
)
|
||||
deleted_at: Union[datetime, str, None] = Field(description="The deleted timestamp of the board.")
|
||||
"""The updated timestamp of the image."""
|
||||
cover_image_name: Optional[str] = Field(
|
||||
description="The name of the cover image of the board."
|
||||
)
|
||||
cover_image_name: Optional[str] = Field(description="The name of the cover image of the board.")
|
||||
"""The name of the cover image of the board."""
|
||||
|
||||
|
||||
class BoardDTO(BoardRecord):
|
||||
"""Deserialized board record with cover image URL and image count."""
|
||||
|
||||
cover_image_name: Optional[str] = Field(
|
||||
description="The name of the board's cover image."
|
||||
)
|
||||
cover_image_name: Optional[str] = Field(description="The name of the board's cover image.")
|
||||
"""The URL of the thumbnail of the most recent image in the board."""
|
||||
image_count: int = Field(description="The number of images in the board.")
|
||||
"""The number of images in the board."""
|
||||
|
@ -20,17 +20,11 @@ class ImageRecord(BaseModel):
|
||||
"""The actual width of the image in px. This may be different from the width in metadata."""
|
||||
height: int = Field(description="The height of the image in px.")
|
||||
"""The actual height of the image in px. This may be different from the height in metadata."""
|
||||
created_at: Union[datetime.datetime, str] = Field(
|
||||
description="The created timestamp of the image."
|
||||
)
|
||||
created_at: Union[datetime.datetime, str] = Field(description="The created timestamp of the image.")
|
||||
"""The created timestamp of the image."""
|
||||
updated_at: Union[datetime.datetime, str] = Field(
|
||||
description="The updated timestamp of the image."
|
||||
)
|
||||
updated_at: Union[datetime.datetime, str] = Field(description="The updated timestamp of the image.")
|
||||
"""The updated timestamp of the image."""
|
||||
deleted_at: Union[datetime.datetime, str, None] = Field(
|
||||
description="The deleted timestamp of the image."
|
||||
)
|
||||
deleted_at: Union[datetime.datetime, str, None] = Field(description="The deleted timestamp of the image.")
|
||||
"""The deleted timestamp of the image."""
|
||||
is_intermediate: bool = Field(description="Whether this is an intermediate image.")
|
||||
"""Whether this is an intermediate image."""
|
||||
@ -55,18 +49,14 @@ class ImageRecordChanges(BaseModel, extra=Extra.forbid):
|
||||
- `is_intermediate`: change the image's `is_intermediate` flag
|
||||
"""
|
||||
|
||||
image_category: Optional[ImageCategory] = Field(
|
||||
description="The image's new category."
|
||||
)
|
||||
image_category: Optional[ImageCategory] = Field(description="The image's new category.")
|
||||
"""The image's new category."""
|
||||
session_id: Optional[StrictStr] = Field(
|
||||
default=None,
|
||||
description="The image's new session ID.",
|
||||
)
|
||||
"""The image's new session ID."""
|
||||
is_intermediate: Optional[StrictBool] = Field(
|
||||
default=None, description="The image's new `is_intermediate` flag."
|
||||
)
|
||||
is_intermediate: Optional[StrictBool] = Field(default=None, description="The image's new `is_intermediate` flag.")
|
||||
"""The image's new `is_intermediate` flag."""
|
||||
|
||||
|
||||
@ -84,9 +74,7 @@ class ImageUrlsDTO(BaseModel):
|
||||
class ImageDTO(ImageRecord, ImageUrlsDTO):
|
||||
"""Deserialized image record, enriched for the frontend."""
|
||||
|
||||
board_id: Optional[str] = Field(
|
||||
description="The id of the board the image belongs to, if one exists."
|
||||
)
|
||||
board_id: Optional[str] = Field(description="The id of the board the image belongs to, if one exists.")
|
||||
"""The id of the board the image belongs to, if one exists."""
|
||||
pass
|
||||
|
||||
@ -110,12 +98,8 @@ def deserialize_image_record(image_dict: dict) -> ImageRecord:
|
||||
|
||||
# TODO: do we really need to handle default values here? ideally the data is the correct shape...
|
||||
image_name = image_dict.get("image_name", "unknown")
|
||||
image_origin = ResourceOrigin(
|
||||
image_dict.get("image_origin", ResourceOrigin.INTERNAL.value)
|
||||
)
|
||||
image_category = ImageCategory(
|
||||
image_dict.get("image_category", ImageCategory.GENERAL.value)
|
||||
)
|
||||
image_origin = ResourceOrigin(image_dict.get("image_origin", ResourceOrigin.INTERNAL.value))
|
||||
image_category = ImageCategory(image_dict.get("image_category", ImageCategory.GENERAL.value))
|
||||
width = image_dict.get("width", 0)
|
||||
height = image_dict.get("height", 0)
|
||||
session_id = image_dict.get("session_id", None)
|
||||
|
@ -8,6 +8,8 @@ from .invoker import InvocationProcessorABC, Invoker
|
||||
from ..models.exceptions import CanceledException
|
||||
|
||||
import invokeai.backend.util.logging as logger
|
||||
|
||||
|
||||
class DefaultInvocationProcessor(InvocationProcessorABC):
|
||||
__invoker_thread: Thread
|
||||
__stop_event: Event
|
||||
@ -24,9 +26,7 @@ class DefaultInvocationProcessor(InvocationProcessorABC):
|
||||
target=self.__process,
|
||||
kwargs=dict(stop_event=self.__stop_event),
|
||||
)
|
||||
self.__invoker_thread.daemon = (
|
||||
True # TODO: make async and do not use threads
|
||||
)
|
||||
self.__invoker_thread.daemon = True # TODO: make async and do not use threads
|
||||
self.__invoker_thread.start()
|
||||
|
||||
def stop(self, *args, **kwargs) -> None:
|
||||
@ -47,10 +47,8 @@ class DefaultInvocationProcessor(InvocationProcessorABC):
|
||||
continue
|
||||
|
||||
try:
|
||||
graph_execution_state = (
|
||||
self.__invoker.services.graph_execution_manager.get(
|
||||
queue_item.graph_execution_state_id
|
||||
)
|
||||
graph_execution_state = self.__invoker.services.graph_execution_manager.get(
|
||||
queue_item.graph_execution_state_id
|
||||
)
|
||||
except Exception as e:
|
||||
self.__invoker.services.logger.error("Exception while retrieving session:\n%s" % e)
|
||||
@ -60,11 +58,9 @@ class DefaultInvocationProcessor(InvocationProcessorABC):
|
||||
error=traceback.format_exc(),
|
||||
)
|
||||
continue
|
||||
|
||||
|
||||
try:
|
||||
invocation = graph_execution_state.execution_graph.get_node(
|
||||
queue_item.invocation_id
|
||||
)
|
||||
invocation = graph_execution_state.execution_graph.get_node(queue_item.invocation_id)
|
||||
except Exception as e:
|
||||
self.__invoker.services.logger.error("Exception while retrieving invocation:\n%s" % e)
|
||||
self.__invoker.services.events.emit_invocation_retrieval_error(
|
||||
@ -82,7 +78,7 @@ class DefaultInvocationProcessor(InvocationProcessorABC):
|
||||
self.__invoker.services.events.emit_invocation_started(
|
||||
graph_execution_state_id=graph_execution_state.id,
|
||||
node=invocation.dict(),
|
||||
source_node_id=source_node_id
|
||||
source_node_id=source_node_id,
|
||||
)
|
||||
|
||||
# Invoke
|
||||
@ -95,18 +91,14 @@ class DefaultInvocationProcessor(InvocationProcessorABC):
|
||||
)
|
||||
|
||||
# Check queue to see if this is canceled, and skip if so
|
||||
if self.__invoker.services.queue.is_canceled(
|
||||
graph_execution_state.id
|
||||
):
|
||||
if self.__invoker.services.queue.is_canceled(graph_execution_state.id):
|
||||
continue
|
||||
|
||||
# Save outputs and history
|
||||
graph_execution_state.complete(invocation.id, outputs)
|
||||
|
||||
# Save the state changes
|
||||
self.__invoker.services.graph_execution_manager.set(
|
||||
graph_execution_state
|
||||
)
|
||||
self.__invoker.services.graph_execution_manager.set(graph_execution_state)
|
||||
|
||||
# Send complete event
|
||||
self.__invoker.services.events.emit_invocation_complete(
|
||||
@ -130,9 +122,7 @@ class DefaultInvocationProcessor(InvocationProcessorABC):
|
||||
graph_execution_state.set_node_error(invocation.id, error)
|
||||
|
||||
# Save the state changes
|
||||
self.__invoker.services.graph_execution_manager.set(
|
||||
graph_execution_state
|
||||
)
|
||||
self.__invoker.services.graph_execution_manager.set(graph_execution_state)
|
||||
|
||||
self.__invoker.services.logger.error("Error while invoking:\n%s" % e)
|
||||
# Send error event
|
||||
@ -147,9 +137,7 @@ class DefaultInvocationProcessor(InvocationProcessorABC):
|
||||
pass
|
||||
|
||||
# Check queue to see if this is canceled, and skip if so
|
||||
if self.__invoker.services.queue.is_canceled(
|
||||
graph_execution_state.id
|
||||
):
|
||||
if self.__invoker.services.queue.is_canceled(graph_execution_state.id):
|
||||
continue
|
||||
|
||||
# Queue any further commands if invoking all
|
||||
@ -164,12 +152,10 @@ class DefaultInvocationProcessor(InvocationProcessorABC):
|
||||
node=invocation.dict(),
|
||||
source_node_id=source_node_id,
|
||||
error_type=e.__class__.__name__,
|
||||
error=traceback.format_exc()
|
||||
error=traceback.format_exc(),
|
||||
)
|
||||
elif is_complete:
|
||||
self.__invoker.services.events.emit_graph_execution_complete(
|
||||
graph_execution_state.id
|
||||
)
|
||||
self.__invoker.services.events.emit_graph_execution_complete(graph_execution_state.id)
|
||||
|
||||
except KeyboardInterrupt:
|
||||
pass # Log something? KeyboardInterrupt is probably not going to be seen by the processor
|
||||
|
@ -66,9 +66,7 @@ class SqliteItemStorage(ItemStorageABC, Generic[T]):
|
||||
def get(self, id: str) -> Optional[T]:
|
||||
try:
|
||||
self._lock.acquire()
|
||||
self._cursor.execute(
|
||||
f"""SELECT item FROM {self._table_name} WHERE id = ?;""", (str(id),)
|
||||
)
|
||||
self._cursor.execute(f"""SELECT item FROM {self._table_name} WHERE id = ?;""", (str(id),))
|
||||
result = self._cursor.fetchone()
|
||||
finally:
|
||||
self._lock.release()
|
||||
@ -81,9 +79,7 @@ class SqliteItemStorage(ItemStorageABC, Generic[T]):
|
||||
def get_raw(self, id: str) -> Optional[str]:
|
||||
try:
|
||||
self._lock.acquire()
|
||||
self._cursor.execute(
|
||||
f"""SELECT item FROM {self._table_name} WHERE id = ?;""", (str(id),)
|
||||
)
|
||||
self._cursor.execute(f"""SELECT item FROM {self._table_name} WHERE id = ?;""", (str(id),))
|
||||
result = self._cursor.fetchone()
|
||||
finally:
|
||||
self._lock.release()
|
||||
@ -96,9 +92,7 @@ class SqliteItemStorage(ItemStorageABC, Generic[T]):
|
||||
def delete(self, id: str):
|
||||
try:
|
||||
self._lock.acquire()
|
||||
self._cursor.execute(
|
||||
f"""DELETE FROM {self._table_name} WHERE id = ?;""", (str(id),)
|
||||
)
|
||||
self._cursor.execute(f"""DELETE FROM {self._table_name} WHERE id = ?;""", (str(id),))
|
||||
self._conn.commit()
|
||||
finally:
|
||||
self._lock.release()
|
||||
@ -122,13 +116,9 @@ class SqliteItemStorage(ItemStorageABC, Generic[T]):
|
||||
|
||||
pageCount = int(count / per_page) + 1
|
||||
|
||||
return PaginatedResults[T](
|
||||
items=items, page=page, pages=pageCount, per_page=per_page, total=count
|
||||
)
|
||||
return PaginatedResults[T](items=items, page=page, pages=pageCount, per_page=per_page, total=count)
|
||||
|
||||
def search(
|
||||
self, query: str, page: int = 0, per_page: int = 10
|
||||
) -> PaginatedResults[T]:
|
||||
def search(self, query: str, page: int = 0, per_page: int = 10) -> PaginatedResults[T]:
|
||||
try:
|
||||
self._lock.acquire()
|
||||
self._cursor.execute(
|
||||
@ -149,6 +139,4 @@ class SqliteItemStorage(ItemStorageABC, Generic[T]):
|
||||
|
||||
pageCount = int(count / per_page) + 1
|
||||
|
||||
return PaginatedResults[T](
|
||||
items=items, page=page, pages=pageCount, per_page=per_page, total=count
|
||||
)
|
||||
return PaginatedResults[T](items=items, page=page, pages=pageCount, per_page=per_page, total=count)
|
||||
|
@ -17,16 +17,8 @@ from controlnet_aux.util import HWC3, resize_image
|
||||
# If you use this, please Cite "High Quality Edge Thinning using Pure Python", Lvmin Zhang, In Mikubill/sd-webui-controlnet.
|
||||
|
||||
lvmin_kernels_raw = [
|
||||
np.array([
|
||||
[-1, -1, -1],
|
||||
[0, 1, 0],
|
||||
[1, 1, 1]
|
||||
], dtype=np.int32),
|
||||
np.array([
|
||||
[0, -1, -1],
|
||||
[1, 1, -1],
|
||||
[0, 1, 0]
|
||||
], dtype=np.int32)
|
||||
np.array([[-1, -1, -1], [0, 1, 0], [1, 1, 1]], dtype=np.int32),
|
||||
np.array([[0, -1, -1], [1, 1, -1], [0, 1, 0]], dtype=np.int32),
|
||||
]
|
||||
|
||||
lvmin_kernels = []
|
||||
@ -36,16 +28,8 @@ lvmin_kernels += [np.rot90(x, k=2, axes=(0, 1)) for x in lvmin_kernels_raw]
|
||||
lvmin_kernels += [np.rot90(x, k=3, axes=(0, 1)) for x in lvmin_kernels_raw]
|
||||
|
||||
lvmin_prunings_raw = [
|
||||
np.array([
|
||||
[-1, -1, -1],
|
||||
[-1, 1, -1],
|
||||
[0, 0, -1]
|
||||
], dtype=np.int32),
|
||||
np.array([
|
||||
[-1, -1, -1],
|
||||
[-1, 1, -1],
|
||||
[-1, 0, 0]
|
||||
], dtype=np.int32)
|
||||
np.array([[-1, -1, -1], [-1, 1, -1], [0, 0, -1]], dtype=np.int32),
|
||||
np.array([[-1, -1, -1], [-1, 1, -1], [-1, 0, 0]], dtype=np.int32),
|
||||
]
|
||||
|
||||
lvmin_prunings = []
|
||||
@ -99,10 +83,10 @@ def nake_nms(x):
|
||||
################################################################################
|
||||
# FIXME: not using yet, if used in the future will most likely require modification of preprocessors
|
||||
def pixel_perfect_resolution(
|
||||
image: np.ndarray,
|
||||
target_H: int,
|
||||
target_W: int,
|
||||
resize_mode: str,
|
||||
image: np.ndarray,
|
||||
target_H: int,
|
||||
target_W: int,
|
||||
resize_mode: str,
|
||||
) -> int:
|
||||
"""
|
||||
Calculate the estimated resolution for resizing an image while preserving aspect ratio.
|
||||
@ -135,7 +119,7 @@ def pixel_perfect_resolution(
|
||||
|
||||
if resize_mode == "fill_resize":
|
||||
estimation = min(k0, k1) * float(min(raw_H, raw_W))
|
||||
else: # "crop_resize" or "just_resize" (or possibly "just_resize_simple"?)
|
||||
else: # "crop_resize" or "just_resize" (or possibly "just_resize_simple"?)
|
||||
estimation = max(k0, k1) * float(min(raw_H, raw_W))
|
||||
|
||||
# print(f"Pixel Perfect Computation:")
|
||||
@ -154,13 +138,7 @@ def pixel_perfect_resolution(
|
||||
# modified for InvokeAI
|
||||
###########################################################################
|
||||
# def detectmap_proc(detected_map, module, resize_mode, h, w):
|
||||
def np_img_resize(
|
||||
np_img: np.ndarray,
|
||||
resize_mode: str,
|
||||
h: int,
|
||||
w: int,
|
||||
device: torch.device = torch.device('cpu')
|
||||
):
|
||||
def np_img_resize(np_img: np.ndarray, resize_mode: str, h: int, w: int, device: torch.device = torch.device("cpu")):
|
||||
# if 'inpaint' in module:
|
||||
# np_img = np_img.astype(np.float32)
|
||||
# else:
|
||||
@ -184,15 +162,14 @@ def np_img_resize(
|
||||
# below is very boring but do not change these. If you change these Apple or Mac may fail.
|
||||
y = torch.from_numpy(y)
|
||||
y = y.float() / 255.0
|
||||
y = rearrange(y, 'h w c -> 1 c h w')
|
||||
y = rearrange(y, "h w c -> 1 c h w")
|
||||
y = y.clone()
|
||||
# y = y.to(devices.get_device_for("controlnet"))
|
||||
y = y.to(device)
|
||||
y = y.clone()
|
||||
return y
|
||||
|
||||
def high_quality_resize(x: np.ndarray,
|
||||
size):
|
||||
def high_quality_resize(x: np.ndarray, size):
|
||||
# Written by lvmin
|
||||
# Super high-quality control map up-scaling, considering binary, seg, and one-pixel edges
|
||||
inpaint_mask = None
|
||||
@ -244,7 +221,7 @@ def np_img_resize(
|
||||
return y
|
||||
|
||||
# if resize_mode == external_code.ResizeMode.RESIZE:
|
||||
if resize_mode == "just_resize": # RESIZE
|
||||
if resize_mode == "just_resize": # RESIZE
|
||||
np_img = high_quality_resize(np_img, (w, h))
|
||||
np_img = safe_numpy(np_img)
|
||||
return get_pytorch_control(np_img), np_img
|
||||
@ -270,20 +247,21 @@ def np_img_resize(
|
||||
new_h, new_w, _ = np_img.shape
|
||||
pad_h = max(0, (h - new_h) // 2)
|
||||
pad_w = max(0, (w - new_w) // 2)
|
||||
high_quality_background[pad_h:pad_h + new_h, pad_w:pad_w + new_w] = np_img
|
||||
high_quality_background[pad_h : pad_h + new_h, pad_w : pad_w + new_w] = np_img
|
||||
np_img = high_quality_background
|
||||
np_img = safe_numpy(np_img)
|
||||
return get_pytorch_control(np_img), np_img
|
||||
else: # resize_mode == "crop_resize" (INNER_FIT)
|
||||
else: # resize_mode == "crop_resize" (INNER_FIT)
|
||||
k = max(k0, k1)
|
||||
np_img = high_quality_resize(np_img, (safeint(old_w * k), safeint(old_h * k)))
|
||||
new_h, new_w, _ = np_img.shape
|
||||
pad_h = max(0, (new_h - h) // 2)
|
||||
pad_w = max(0, (new_w - w) // 2)
|
||||
np_img = np_img[pad_h:pad_h + h, pad_w:pad_w + w]
|
||||
np_img = np_img[pad_h : pad_h + h, pad_w : pad_w + w]
|
||||
np_img = safe_numpy(np_img)
|
||||
return get_pytorch_control(np_img), np_img
|
||||
|
||||
|
||||
def prepare_control_image(
|
||||
# image used to be Union[PIL.Image.Image, List[PIL.Image.Image], torch.Tensor, List[torch.Tensor]]
|
||||
# but now should be able to assume that image is a single PIL.Image, which simplifies things
|
||||
@ -301,15 +279,17 @@ def prepare_control_image(
|
||||
resize_mode="just_resize_simple",
|
||||
):
|
||||
# FIXME: implement "crop_resize_simple" and "fill_resize_simple", or pull them out
|
||||
if (resize_mode == "just_resize_simple" or
|
||||
resize_mode == "crop_resize_simple" or
|
||||
resize_mode == "fill_resize_simple"):
|
||||
if (
|
||||
resize_mode == "just_resize_simple"
|
||||
or resize_mode == "crop_resize_simple"
|
||||
or resize_mode == "fill_resize_simple"
|
||||
):
|
||||
image = image.convert("RGB")
|
||||
if (resize_mode == "just_resize_simple"):
|
||||
if resize_mode == "just_resize_simple":
|
||||
image = image.resize((width, height), resample=PIL_INTERPOLATION["lanczos"])
|
||||
elif (resize_mode == "crop_resize_simple"): # not yet implemented
|
||||
elif resize_mode == "crop_resize_simple": # not yet implemented
|
||||
pass
|
||||
elif (resize_mode == "fill_resize_simple"): # not yet implemented
|
||||
elif resize_mode == "fill_resize_simple": # not yet implemented
|
||||
pass
|
||||
nimage = np.array(image)
|
||||
nimage = nimage[None, :]
|
||||
@ -320,7 +300,7 @@ def prepare_control_image(
|
||||
timage = torch.from_numpy(nimage)
|
||||
|
||||
# use fancy lvmin controlnet resizing
|
||||
elif (resize_mode == "just_resize" or resize_mode == "crop_resize" or resize_mode == "fill_resize"):
|
||||
elif resize_mode == "just_resize" or resize_mode == "crop_resize" or resize_mode == "fill_resize":
|
||||
nimage = np.array(image)
|
||||
timage, nimage = np_img_resize(
|
||||
np_img=nimage,
|
||||
@ -336,7 +316,7 @@ def prepare_control_image(
|
||||
exit(1)
|
||||
|
||||
timage = timage.to(device=device, dtype=dtype)
|
||||
cfg_injection = (control_mode == "more_control" or control_mode == "unbalanced")
|
||||
cfg_injection = control_mode == "more_control" or control_mode == "unbalanced"
|
||||
if do_classifier_free_guidance and not cfg_injection:
|
||||
timage = torch.cat([timage] * 2)
|
||||
return timage
|
||||
|
@ -9,19 +9,16 @@ from ...backend.stable_diffusion import PipelineIntermediateState
|
||||
from invokeai.app.services.config import InvokeAIAppConfig
|
||||
|
||||
|
||||
def sample_to_lowres_estimated_image(samples, latent_rgb_factors, smooth_matrix = None):
|
||||
def sample_to_lowres_estimated_image(samples, latent_rgb_factors, smooth_matrix=None):
|
||||
latent_image = samples[0].permute(1, 2, 0) @ latent_rgb_factors
|
||||
|
||||
if smooth_matrix is not None:
|
||||
latent_image = latent_image.unsqueeze(0).permute(3, 0, 1, 2)
|
||||
latent_image = torch.nn.functional.conv2d(latent_image, smooth_matrix.reshape((1,1,3,3)), padding=1)
|
||||
latent_image = torch.nn.functional.conv2d(latent_image, smooth_matrix.reshape((1, 1, 3, 3)), padding=1)
|
||||
latent_image = latent_image.permute(1, 2, 3, 0).squeeze(0)
|
||||
|
||||
latents_ubyte = (
|
||||
((latent_image + 1) / 2)
|
||||
.clamp(0, 1) # change scale from -1..1 to 0..1
|
||||
.mul(0xFF) # to 0..255
|
||||
.byte()
|
||||
((latent_image + 1) / 2).clamp(0, 1).mul(0xFF).byte() # change scale from -1..1 to 0..1 # to 0..255
|
||||
).cpu()
|
||||
|
||||
return Image.fromarray(latents_ubyte.numpy())
|
||||
@ -92,6 +89,7 @@ def stable_diffusion_step_callback(
|
||||
total_steps=node["steps"],
|
||||
)
|
||||
|
||||
|
||||
def stable_diffusion_xl_step_callback(
|
||||
context: InvocationContext,
|
||||
node: dict,
|
||||
@ -106,9 +104,9 @@ def stable_diffusion_xl_step_callback(
|
||||
sdxl_latent_rgb_factors = torch.tensor(
|
||||
[
|
||||
# R G B
|
||||
[ 0.3816, 0.4930, 0.5320],
|
||||
[-0.3753, 0.1631, 0.1739],
|
||||
[ 0.1770, 0.3588, -0.2048],
|
||||
[0.3816, 0.4930, 0.5320],
|
||||
[-0.3753, 0.1631, 0.1739],
|
||||
[0.1770, 0.3588, -0.2048],
|
||||
[-0.4350, -0.2644, -0.4289],
|
||||
],
|
||||
dtype=sample.dtype,
|
||||
@ -117,9 +115,9 @@ def stable_diffusion_xl_step_callback(
|
||||
|
||||
sdxl_smooth_matrix = torch.tensor(
|
||||
[
|
||||
#[ 0.0478, 0.1285, 0.0478],
|
||||
#[ 0.1285, 0.2948, 0.1285],
|
||||
#[ 0.0478, 0.1285, 0.0478],
|
||||
# [ 0.0478, 0.1285, 0.0478],
|
||||
# [ 0.1285, 0.2948, 0.1285],
|
||||
# [ 0.0478, 0.1285, 0.0478],
|
||||
[0.0358, 0.0964, 0.0358],
|
||||
[0.0964, 0.4711, 0.0964],
|
||||
[0.0358, 0.0964, 0.0358],
|
||||
@ -143,4 +141,4 @@ def stable_diffusion_xl_step_callback(
|
||||
progress_image=ProgressImage(width=width, height=height, dataURL=dataURL),
|
||||
step=step,
|
||||
total_steps=total_steps,
|
||||
)
|
||||
)
|
||||
|
@ -1,15 +1,6 @@
|
||||
"""
|
||||
Initialization file for invokeai.backend
|
||||
"""
|
||||
from .generator import (
|
||||
InvokeAIGeneratorBasicParams,
|
||||
InvokeAIGenerator,
|
||||
InvokeAIGeneratorOutput,
|
||||
Img2Img,
|
||||
Inpaint
|
||||
)
|
||||
from .model_management import (
|
||||
ModelManager, ModelCache, BaseModelType,
|
||||
ModelType, SubModelType, ModelInfo
|
||||
)
|
||||
from .safety_checker import SafetyChecker
|
||||
from .generator import InvokeAIGeneratorBasicParams, InvokeAIGenerator, InvokeAIGeneratorOutput, Img2Img, Inpaint
|
||||
from .model_management import ModelManager, ModelCache, BaseModelType, ModelType, SubModelType, ModelInfo
|
||||
from .model_management.models import SilenceWarnings
|
||||
|
@ -28,68 +28,71 @@ from diffusers.schedulers import SchedulerMixin as Scheduler
|
||||
import invokeai.backend.util.logging as logger
|
||||
from ..image_util import configure_model_padding
|
||||
from ..util.util import rand_perlin_2d
|
||||
from ..safety_checker import SafetyChecker
|
||||
from ..stable_diffusion.diffusers_pipeline import StableDiffusionGeneratorPipeline
|
||||
from ..stable_diffusion.schedulers import SCHEDULER_MAP
|
||||
|
||||
downsampling = 8
|
||||
|
||||
|
||||
@dataclass
|
||||
class InvokeAIGeneratorBasicParams:
|
||||
seed: Optional[int]=None
|
||||
width: int=512
|
||||
height: int=512
|
||||
cfg_scale: float=7.5
|
||||
steps: int=20
|
||||
ddim_eta: float=0.0
|
||||
scheduler: str='ddim'
|
||||
precision: str='float16'
|
||||
perlin: float=0.0
|
||||
threshold: float=0.0
|
||||
seamless: bool=False
|
||||
seamless_axes: List[str]=field(default_factory=lambda: ['x', 'y'])
|
||||
h_symmetry_time_pct: Optional[float]=None
|
||||
v_symmetry_time_pct: Optional[float]=None
|
||||
seed: Optional[int] = None
|
||||
width: int = 512
|
||||
height: int = 512
|
||||
cfg_scale: float = 7.5
|
||||
steps: int = 20
|
||||
ddim_eta: float = 0.0
|
||||
scheduler: str = "ddim"
|
||||
precision: str = "float16"
|
||||
perlin: float = 0.0
|
||||
threshold: float = 0.0
|
||||
seamless: bool = False
|
||||
seamless_axes: List[str] = field(default_factory=lambda: ["x", "y"])
|
||||
h_symmetry_time_pct: Optional[float] = None
|
||||
v_symmetry_time_pct: Optional[float] = None
|
||||
variation_amount: float = 0.0
|
||||
with_variations: list=field(default_factory=list)
|
||||
safety_checker: Optional[SafetyChecker]=None
|
||||
with_variations: list = field(default_factory=list)
|
||||
|
||||
|
||||
@dataclass
|
||||
class InvokeAIGeneratorOutput:
|
||||
'''
|
||||
"""
|
||||
InvokeAIGeneratorOutput is a dataclass that contains the outputs of a generation
|
||||
operation, including the image, its seed, the model name used to generate the image
|
||||
and the model hash, as well as all the generate() parameters that went into
|
||||
generating the image (in .params, also available as attributes)
|
||||
'''
|
||||
"""
|
||||
|
||||
image: Image.Image
|
||||
seed: int
|
||||
model_hash: str
|
||||
attention_maps_images: List[Image.Image]
|
||||
params: Namespace
|
||||
|
||||
|
||||
# we are interposing a wrapper around the original Generator classes so that
|
||||
# old code that calls Generate will continue to work.
|
||||
class InvokeAIGenerator(metaclass=ABCMeta):
|
||||
def __init__(self,
|
||||
model_info: dict,
|
||||
params: InvokeAIGeneratorBasicParams=InvokeAIGeneratorBasicParams(),
|
||||
**kwargs,
|
||||
):
|
||||
self.model_info=model_info
|
||||
self.params=params
|
||||
def __init__(
|
||||
self,
|
||||
model_info: dict,
|
||||
params: InvokeAIGeneratorBasicParams = InvokeAIGeneratorBasicParams(),
|
||||
**kwargs,
|
||||
):
|
||||
self.model_info = model_info
|
||||
self.params = params
|
||||
self.kwargs = kwargs
|
||||
|
||||
def generate(
|
||||
self,
|
||||
conditioning: tuple,
|
||||
scheduler,
|
||||
callback: Optional[Callable]=None,
|
||||
step_callback: Optional[Callable]=None,
|
||||
iterations: int=1,
|
||||
callback: Optional[Callable] = None,
|
||||
step_callback: Optional[Callable] = None,
|
||||
iterations: int = 1,
|
||||
**keyword_args,
|
||||
)->Iterator[InvokeAIGeneratorOutput]:
|
||||
'''
|
||||
) -> Iterator[InvokeAIGeneratorOutput]:
|
||||
"""
|
||||
Return an iterator across the indicated number of generations.
|
||||
Each time the iterator is called it will return an InvokeAIGeneratorOutput
|
||||
object. Use like this:
|
||||
@ -109,7 +112,7 @@ class InvokeAIGenerator(metaclass=ABCMeta):
|
||||
for o in outputs:
|
||||
print(o.image, o.seed)
|
||||
|
||||
'''
|
||||
"""
|
||||
generator_args = dataclasses.asdict(self.params)
|
||||
generator_args.update(keyword_args)
|
||||
|
||||
@ -120,22 +123,21 @@ class InvokeAIGenerator(metaclass=ABCMeta):
|
||||
gen_class = self._generator_class()
|
||||
generator = gen_class(model, self.params.precision, **self.kwargs)
|
||||
if self.params.variation_amount > 0:
|
||||
generator.set_variation(generator_args.get('seed'),
|
||||
generator_args.get('variation_amount'),
|
||||
generator_args.get('with_variations')
|
||||
)
|
||||
generator.set_variation(
|
||||
generator_args.get("seed"),
|
||||
generator_args.get("variation_amount"),
|
||||
generator_args.get("with_variations"),
|
||||
)
|
||||
|
||||
if isinstance(model, DiffusionPipeline):
|
||||
for component in [model.unet, model.vae]:
|
||||
configure_model_padding(component,
|
||||
generator_args.get('seamless',False),
|
||||
generator_args.get('seamless_axes')
|
||||
)
|
||||
configure_model_padding(
|
||||
component, generator_args.get("seamless", False), generator_args.get("seamless_axes")
|
||||
)
|
||||
else:
|
||||
configure_model_padding(model,
|
||||
generator_args.get('seamless',False),
|
||||
generator_args.get('seamless_axes')
|
||||
)
|
||||
configure_model_padding(
|
||||
model, generator_args.get("seamless", False), generator_args.get("seamless_axes")
|
||||
)
|
||||
|
||||
iteration_count = range(iterations) if iterations else itertools.count(start=0, step=1)
|
||||
for i in iteration_count:
|
||||
@ -149,66 +151,66 @@ class InvokeAIGenerator(metaclass=ABCMeta):
|
||||
image=results[0][0],
|
||||
seed=results[0][1],
|
||||
attention_maps_images=results[0][2],
|
||||
model_hash = model_hash,
|
||||
params=Namespace(model_name=model_name,**generator_args),
|
||||
model_hash=model_hash,
|
||||
params=Namespace(model_name=model_name, **generator_args),
|
||||
)
|
||||
if callback:
|
||||
callback(output)
|
||||
yield output
|
||||
|
||||
@classmethod
|
||||
def schedulers(self)->List[str]:
|
||||
'''
|
||||
def schedulers(self) -> List[str]:
|
||||
"""
|
||||
Return list of all the schedulers that we currently handle.
|
||||
'''
|
||||
"""
|
||||
return list(SCHEDULER_MAP.keys())
|
||||
|
||||
def load_generator(self, model: StableDiffusionGeneratorPipeline, generator_class: Type[Generator]):
|
||||
return generator_class(model, self.params.precision)
|
||||
|
||||
@classmethod
|
||||
def _generator_class(cls)->Type[Generator]:
|
||||
'''
|
||||
def _generator_class(cls) -> Type[Generator]:
|
||||
"""
|
||||
In derived classes return the name of the generator to apply.
|
||||
If you don't override will return the name of the derived
|
||||
class, which nicely parallels the generator class names.
|
||||
'''
|
||||
"""
|
||||
return Generator
|
||||
|
||||
|
||||
# ------------------------------------
|
||||
class Img2Img(InvokeAIGenerator):
|
||||
def generate(self,
|
||||
init_image: Union[Image.Image, torch.FloatTensor],
|
||||
strength: float=0.75,
|
||||
**keyword_args
|
||||
)->Iterator[InvokeAIGeneratorOutput]:
|
||||
return super().generate(init_image=init_image,
|
||||
strength=strength,
|
||||
**keyword_args
|
||||
)
|
||||
def generate(
|
||||
self, init_image: Union[Image.Image, torch.FloatTensor], strength: float = 0.75, **keyword_args
|
||||
) -> Iterator[InvokeAIGeneratorOutput]:
|
||||
return super().generate(init_image=init_image, strength=strength, **keyword_args)
|
||||
|
||||
@classmethod
|
||||
def _generator_class(cls):
|
||||
from .img2img import Img2Img
|
||||
|
||||
return Img2Img
|
||||
|
||||
|
||||
# ------------------------------------
|
||||
# Takes all the arguments of Img2Img and adds the mask image and the seam/infill stuff
|
||||
class Inpaint(Img2Img):
|
||||
def generate(self,
|
||||
mask_image: Union[Image.Image, torch.FloatTensor],
|
||||
# Seam settings - when 0, doesn't fill seam
|
||||
seam_size: int = 96,
|
||||
seam_blur: int = 16,
|
||||
seam_strength: float = 0.7,
|
||||
seam_steps: int = 30,
|
||||
tile_size: int = 32,
|
||||
inpaint_replace=False,
|
||||
infill_method=None,
|
||||
inpaint_width=None,
|
||||
inpaint_height=None,
|
||||
inpaint_fill: tuple(int) = (0x7F, 0x7F, 0x7F, 0xFF),
|
||||
**keyword_args
|
||||
)->Iterator[InvokeAIGeneratorOutput]:
|
||||
def generate(
|
||||
self,
|
||||
mask_image: Union[Image.Image, torch.FloatTensor],
|
||||
# Seam settings - when 0, doesn't fill seam
|
||||
seam_size: int = 96,
|
||||
seam_blur: int = 16,
|
||||
seam_strength: float = 0.7,
|
||||
seam_steps: int = 30,
|
||||
tile_size: int = 32,
|
||||
inpaint_replace=False,
|
||||
infill_method=None,
|
||||
inpaint_width=None,
|
||||
inpaint_height=None,
|
||||
inpaint_fill: tuple(int) = (0x7F, 0x7F, 0x7F, 0xFF),
|
||||
**keyword_args,
|
||||
) -> Iterator[InvokeAIGeneratorOutput]:
|
||||
return super().generate(
|
||||
mask_image=mask_image,
|
||||
seam_size=seam_size,
|
||||
@ -221,13 +223,16 @@ class Inpaint(Img2Img):
|
||||
inpaint_width=inpaint_width,
|
||||
inpaint_height=inpaint_height,
|
||||
inpaint_fill=inpaint_fill,
|
||||
**keyword_args
|
||||
**keyword_args,
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def _generator_class(cls):
|
||||
from .inpaint import Inpaint
|
||||
|
||||
return Inpaint
|
||||
|
||||
|
||||
class Generator:
|
||||
downsampling_factor: int
|
||||
latent_channels: int
|
||||
@ -240,7 +245,6 @@ class Generator:
|
||||
self.seed = None
|
||||
self.latent_channels = model.unet.config.in_channels
|
||||
self.downsampling_factor = downsampling # BUG: should come from model or config
|
||||
self.safety_checker = None
|
||||
self.perlin = 0.0
|
||||
self.threshold = 0
|
||||
self.variation_amount = 0
|
||||
@ -254,9 +258,7 @@ class Generator:
|
||||
Returns a function returning an image derived from the prompt and the initial image
|
||||
Return value depends on the seed at the time you call it
|
||||
"""
|
||||
raise NotImplementedError(
|
||||
"image_iterator() must be implemented in a descendent class"
|
||||
)
|
||||
raise NotImplementedError("image_iterator() must be implemented in a descendent class")
|
||||
|
||||
def set_variation(self, seed, variation_amount, with_variations):
|
||||
self.seed = seed
|
||||
@ -277,17 +279,13 @@ class Generator:
|
||||
perlin=0.0,
|
||||
h_symmetry_time_pct=None,
|
||||
v_symmetry_time_pct=None,
|
||||
safety_checker: SafetyChecker=None,
|
||||
free_gpu_mem: bool = False,
|
||||
**kwargs,
|
||||
):
|
||||
scope = nullcontext
|
||||
self.safety_checker = safety_checker
|
||||
self.free_gpu_mem = free_gpu_mem
|
||||
attention_maps_images = []
|
||||
attention_maps_callback = lambda saver: attention_maps_images.append(
|
||||
saver.get_stacked_maps_image()
|
||||
)
|
||||
attention_maps_callback = lambda saver: attention_maps_images.append(saver.get_stacked_maps_image())
|
||||
make_image = self.get_make_image(
|
||||
sampler=sampler,
|
||||
init_image=init_image,
|
||||
@ -329,17 +327,10 @@ class Generator:
|
||||
# Pass on the seed in case a layer beneath us needs to generate noise on its own.
|
||||
image = make_image(x_T, seed)
|
||||
|
||||
if self.safety_checker is not None:
|
||||
image = self.safety_checker.check(image)
|
||||
|
||||
results.append([image, seed, attention_maps_images])
|
||||
|
||||
if image_callback is not None:
|
||||
attention_maps_image = (
|
||||
None
|
||||
if len(attention_maps_images) == 0
|
||||
else attention_maps_images[-1]
|
||||
)
|
||||
attention_maps_image = None if len(attention_maps_images) == 0 else attention_maps_images[-1]
|
||||
image_callback(
|
||||
image,
|
||||
seed,
|
||||
@ -350,9 +341,7 @@ class Generator:
|
||||
seed = self.new_seed()
|
||||
|
||||
# Free up memory from the last generation.
|
||||
clear_cuda_cache = (
|
||||
kwargs["clear_cuda_cache"] if "clear_cuda_cache" in kwargs else None
|
||||
)
|
||||
clear_cuda_cache = kwargs["clear_cuda_cache"] if "clear_cuda_cache" in kwargs else None
|
||||
if clear_cuda_cache is not None:
|
||||
clear_cuda_cache()
|
||||
|
||||
@ -379,14 +368,8 @@ class Generator:
|
||||
|
||||
# Get the original alpha channel of the mask if there is one.
|
||||
# Otherwise it is some other black/white image format ('1', 'L' or 'RGB')
|
||||
pil_init_mask = (
|
||||
init_mask.getchannel("A")
|
||||
if init_mask.mode == "RGBA"
|
||||
else init_mask.convert("L")
|
||||
)
|
||||
pil_init_image = init_image.convert(
|
||||
"RGBA"
|
||||
) # Add an alpha channel if one doesn't exist
|
||||
pil_init_mask = init_mask.getchannel("A") if init_mask.mode == "RGBA" else init_mask.convert("L")
|
||||
pil_init_image = init_image.convert("RGBA") # Add an alpha channel if one doesn't exist
|
||||
|
||||
# Build an image with only visible pixels from source to use as reference for color-matching.
|
||||
init_rgb_pixels = np.asarray(init_image.convert("RGB"), dtype=np.uint8)
|
||||
@ -412,10 +395,7 @@ class Generator:
|
||||
np_matched_result[:, :, :] = (
|
||||
(
|
||||
(
|
||||
(
|
||||
np_matched_result[:, :, :].astype(np.float32)
|
||||
- gen_means[None, None, :]
|
||||
)
|
||||
(np_matched_result[:, :, :].astype(np.float32) - gen_means[None, None, :])
|
||||
/ gen_std[None, None, :]
|
||||
)
|
||||
* init_std[None, None, :]
|
||||
@ -441,9 +421,7 @@ class Generator:
|
||||
else:
|
||||
blurred_init_mask = pil_init_mask
|
||||
|
||||
multiplied_blurred_init_mask = ImageChops.multiply(
|
||||
blurred_init_mask, self.pil_image.split()[-1]
|
||||
)
|
||||
multiplied_blurred_init_mask = ImageChops.multiply(blurred_init_mask, self.pil_image.split()[-1])
|
||||
|
||||
# Paste original on color-corrected generation (using blurred mask)
|
||||
matched_result.paste(init_image, (0, 0), mask=multiplied_blurred_init_mask)
|
||||
@ -469,10 +447,7 @@ class Generator:
|
||||
|
||||
latent_image = samples[0].permute(1, 2, 0) @ v1_5_latent_rgb_factors
|
||||
latents_ubyte = (
|
||||
((latent_image + 1) / 2)
|
||||
.clamp(0, 1) # change scale from -1..1 to 0..1
|
||||
.mul(0xFF) # to 0..255
|
||||
.byte()
|
||||
((latent_image + 1) / 2).clamp(0, 1).mul(0xFF).byte() # change scale from -1..1 to 0..1 # to 0..255
|
||||
).cpu()
|
||||
|
||||
return Image.fromarray(latents_ubyte.numpy())
|
||||
@ -502,9 +477,7 @@ class Generator:
|
||||
temp_height = int((height + 7) / 8) * 8
|
||||
noise = torch.stack(
|
||||
[
|
||||
rand_perlin_2d(
|
||||
(temp_height, temp_width), (8, 8), device=self.model.device
|
||||
).to(fixdevice)
|
||||
rand_perlin_2d((temp_height, temp_width), (8, 8), device=self.model.device).to(fixdevice)
|
||||
for _ in range(input_channels)
|
||||
],
|
||||
dim=0,
|
||||
@ -581,8 +554,6 @@ class Generator:
|
||||
device=device,
|
||||
)
|
||||
if self.perlin > 0.0:
|
||||
perlin_noise = self.get_perlin_noise(
|
||||
width // self.downsampling_factor, height // self.downsampling_factor
|
||||
)
|
||||
perlin_noise = self.get_perlin_noise(width // self.downsampling_factor, height // self.downsampling_factor)
|
||||
x = (1 - self.perlin) * x + self.perlin * perlin_noise
|
||||
return x
|
||||
|
@ -77,10 +77,7 @@ class Img2Img(Generator):
|
||||
callback=step_callback,
|
||||
seed=seed,
|
||||
)
|
||||
if (
|
||||
pipeline_output.attention_map_saver is not None
|
||||
and attention_maps_callback is not None
|
||||
):
|
||||
if pipeline_output.attention_map_saver is not None and attention_maps_callback is not None:
|
||||
attention_maps_callback(pipeline_output.attention_map_saver)
|
||||
return pipeline.numpy_to_pil(pipeline_output.images)[0]
|
||||
|
||||
@ -91,7 +88,5 @@ class Img2Img(Generator):
|
||||
x = torch.randn_like(like, device=device)
|
||||
if self.perlin > 0.0:
|
||||
shape = like.shape
|
||||
x = (1 - self.perlin) * x + self.perlin * self.get_perlin_noise(
|
||||
shape[3], shape[2]
|
||||
)
|
||||
x = (1 - self.perlin) * x + self.perlin * self.get_perlin_noise(shape[3], shape[2])
|
||||
return x
|
||||
|
@ -68,15 +68,11 @@ class Inpaint(Img2Img):
|
||||
return im
|
||||
|
||||
# Patchmatch (note, we may want to expose patch_size? Increasing it significantly impacts performance though)
|
||||
im_patched_np = PatchMatch.inpaint(
|
||||
im.convert("RGB"), ImageOps.invert(im.split()[-1]), patch_size=3
|
||||
)
|
||||
im_patched_np = PatchMatch.inpaint(im.convert("RGB"), ImageOps.invert(im.split()[-1]), patch_size=3)
|
||||
im_patched = Image.fromarray(im_patched_np, mode="RGB")
|
||||
return im_patched
|
||||
|
||||
def tile_fill_missing(
|
||||
self, im: Image.Image, tile_size: int = 16, seed: Optional[int] = None
|
||||
) -> Image.Image:
|
||||
def tile_fill_missing(self, im: Image.Image, tile_size: int = 16, seed: Optional[int] = None) -> Image.Image:
|
||||
# Only fill if there's an alpha layer
|
||||
if im.mode != "RGBA":
|
||||
return im
|
||||
@ -127,15 +123,11 @@ class Inpaint(Img2Img):
|
||||
|
||||
return si
|
||||
|
||||
def mask_edge(
|
||||
self, mask: Image.Image, edge_size: int, edge_blur: int
|
||||
) -> Image.Image:
|
||||
def mask_edge(self, mask: Image.Image, edge_size: int, edge_blur: int) -> Image.Image:
|
||||
npimg = np.asarray(mask, dtype=np.uint8)
|
||||
|
||||
# Detect any partially transparent regions
|
||||
npgradient = np.uint8(
|
||||
255 * (1.0 - np.floor(np.abs(0.5 - np.float32(npimg) / 255.0) * 2.0))
|
||||
)
|
||||
npgradient = np.uint8(255 * (1.0 - np.floor(np.abs(0.5 - np.float32(npimg) / 255.0) * 2.0)))
|
||||
|
||||
# Detect hard edges
|
||||
npedge = cv2.Canny(npimg, threshold1=100, threshold2=200)
|
||||
@ -144,9 +136,7 @@ class Inpaint(Img2Img):
|
||||
npmask = npgradient + npedge
|
||||
|
||||
# Expand
|
||||
npmask = cv2.dilate(
|
||||
npmask, np.ones((3, 3), np.uint8), iterations=int(edge_size / 2)
|
||||
)
|
||||
npmask = cv2.dilate(npmask, np.ones((3, 3), np.uint8), iterations=int(edge_size / 2))
|
||||
|
||||
new_mask = Image.fromarray(npmask)
|
||||
|
||||
@ -242,25 +232,19 @@ class Inpaint(Img2Img):
|
||||
if infill_method == "patchmatch" and PatchMatch.patchmatch_available():
|
||||
init_filled = self.infill_patchmatch(self.pil_image.copy())
|
||||
elif infill_method == "tile":
|
||||
init_filled = self.tile_fill_missing(
|
||||
self.pil_image.copy(), seed=self.seed, tile_size=tile_size
|
||||
)
|
||||
init_filled = self.tile_fill_missing(self.pil_image.copy(), seed=self.seed, tile_size=tile_size)
|
||||
elif infill_method == "solid":
|
||||
solid_bg = Image.new("RGBA", init_image.size, inpaint_fill)
|
||||
init_filled = Image.alpha_composite(solid_bg, init_image)
|
||||
else:
|
||||
raise ValueError(
|
||||
f"Non-supported infill type {infill_method}", infill_method
|
||||
)
|
||||
raise ValueError(f"Non-supported infill type {infill_method}", infill_method)
|
||||
init_filled.paste(init_image, (0, 0), init_image.split()[-1])
|
||||
|
||||
# Resize if requested for inpainting
|
||||
if inpaint_width and inpaint_height:
|
||||
init_filled = init_filled.resize((inpaint_width, inpaint_height))
|
||||
|
||||
debug_image(
|
||||
init_filled, "init_filled", debug_status=self.enable_image_debugging
|
||||
)
|
||||
debug_image(init_filled, "init_filled", debug_status=self.enable_image_debugging)
|
||||
|
||||
# Create init tensor
|
||||
init_image = image_resized_to_grid_as_tensor(init_filled.convert("RGB"))
|
||||
@ -289,9 +273,7 @@ class Inpaint(Img2Img):
|
||||
"mask_image AFTER multiply with pil_image",
|
||||
debug_status=self.enable_image_debugging,
|
||||
)
|
||||
mask: torch.FloatTensor = image_resized_to_grid_as_tensor(
|
||||
mask_image, normalize=False
|
||||
)
|
||||
mask: torch.FloatTensor = image_resized_to_grid_as_tensor(mask_image, normalize=False)
|
||||
else:
|
||||
mask: torch.FloatTensor = mask_image
|
||||
|
||||
@ -302,9 +284,9 @@ class Inpaint(Img2Img):
|
||||
|
||||
# todo: support cross-attention control
|
||||
uc, c, _ = conditioning
|
||||
conditioning_data = ConditioningData(
|
||||
uc, c, cfg_scale
|
||||
).add_scheduler_args_if_applicable(pipeline.scheduler, eta=ddim_eta)
|
||||
conditioning_data = ConditioningData(uc, c, cfg_scale).add_scheduler_args_if_applicable(
|
||||
pipeline.scheduler, eta=ddim_eta
|
||||
)
|
||||
|
||||
def make_image(x_T: torch.Tensor, seed: int):
|
||||
pipeline_output = pipeline.inpaint_from_embeddings(
|
||||
@ -318,15 +300,10 @@ class Inpaint(Img2Img):
|
||||
seed=seed,
|
||||
)
|
||||
|
||||
if (
|
||||
pipeline_output.attention_map_saver is not None
|
||||
and attention_maps_callback is not None
|
||||
):
|
||||
if pipeline_output.attention_map_saver is not None and attention_maps_callback is not None:
|
||||
attention_maps_callback(pipeline_output.attention_map_saver)
|
||||
|
||||
result = self.postprocess_size_and_mask(
|
||||
pipeline.numpy_to_pil(pipeline_output.images)[0]
|
||||
)
|
||||
result = self.postprocess_size_and_mask(pipeline.numpy_to_pil(pipeline_output.images)[0])
|
||||
|
||||
# Seam paint if this is our first pass (seam_size set to 0 during seam painting)
|
||||
if seam_size > 0:
|
||||
|
@ -8,9 +8,7 @@ from .txt2mask import Txt2Mask
|
||||
from .util import InitImageResizer, make_grid
|
||||
|
||||
|
||||
def debug_image(
|
||||
debug_image, debug_text, debug_show=True, debug_result=False, debug_status=False
|
||||
):
|
||||
def debug_image(debug_image, debug_text, debug_show=True, debug_result=False, debug_status=False):
|
||||
if not debug_status:
|
||||
return
|
||||
|
||||
|
34
invokeai/backend/image_util/invisible_watermark.py
Normal file
@ -0,0 +1,34 @@
|
||||
"""
|
||||
This module defines a singleton object, "invisible_watermark" that
|
||||
wraps the invisible watermark model. It respects the global "invisible_watermark"
|
||||
configuration variable, that allows the watermarking to be supressed.
|
||||
"""
|
||||
import numpy as np
|
||||
import cv2
|
||||
from PIL import Image
|
||||
from imwatermark import WatermarkEncoder
|
||||
from invokeai.app.services.config import InvokeAIAppConfig
|
||||
import invokeai.backend.util.logging as logger
|
||||
|
||||
config = InvokeAIAppConfig.get_config()
|
||||
|
||||
|
||||
class InvisibleWatermark:
|
||||
"""
|
||||
Wrapper around InvisibleWatermark module.
|
||||
"""
|
||||
|
||||
@classmethod
|
||||
def invisible_watermark_available(self) -> bool:
|
||||
return config.invisible_watermark
|
||||
|
||||
@classmethod
|
||||
def add_watermark(self, image: Image, watermark_text: str) -> Image:
|
||||
if not self.invisible_watermark_available():
|
||||
return image
|
||||
logger.debug(f'Applying invisible watermark "{watermark_text}"')
|
||||
bgr = cv2.cvtColor(np.array(image.convert("RGB")), cv2.COLOR_RGB2BGR)
|
||||
encoder = WatermarkEncoder()
|
||||
encoder.set_watermark("bytes", watermark_text.encode("utf-8"))
|
||||
bgr_encoded = encoder.encode(bgr, "dwtDct")
|
||||
return Image.fromarray(cv2.cvtColor(bgr_encoded, cv2.COLOR_BGR2RGB)).convert("RGBA")
|
@ -7,8 +7,10 @@ be suppressed or deferred
|
||||
import numpy as np
|
||||
import invokeai.backend.util.logging as logger
|
||||
from invokeai.app.services.config import InvokeAIAppConfig
|
||||
|
||||
config = InvokeAIAppConfig.get_config()
|
||||
|
||||
|
||||
class PatchMatch:
|
||||
"""
|
||||
Thin class wrapper around the patchmatch function.
|
||||
|
@ -34,9 +34,7 @@ class PngWriter:
|
||||
|
||||
# saves image named _image_ to outdir/name, writing metadata from prompt
|
||||
# returns full path of output
|
||||
def save_image_and_prompt_to_png(
|
||||
self, image, dream_prompt, name, metadata=None, compress_level=6
|
||||
):
|
||||
def save_image_and_prompt_to_png(self, image, dream_prompt, name, metadata=None, compress_level=6):
|
||||
path = os.path.join(self.outdir, name)
|
||||
info = PngImagePlugin.PngInfo()
|
||||
info.add_text("Dream", dream_prompt)
|
||||
@ -114,8 +112,6 @@ class PromptFormatter:
|
||||
if opt.variation_amount > 0:
|
||||
switches.append(f"-v{opt.variation_amount}")
|
||||
if opt.with_variations:
|
||||
formatted_variations = ",".join(
|
||||
f"{seed}:{weight}" for seed, weight in opt.with_variations
|
||||
)
|
||||
formatted_variations = ",".join(f"{seed}:{weight}" for seed, weight in opt.with_variations)
|
||||
switches.append(f"-V{formatted_variations}")
|
||||
return " ".join(switches)
|
||||
|
64
invokeai/backend/image_util/safety_checker.py
Normal file
@ -0,0 +1,64 @@
|
||||
"""
|
||||
This module defines a singleton object, "safety_checker" that
|
||||
wraps the safety_checker model. It respects the global "nsfw_checker"
|
||||
configuration variable, that allows the checker to be supressed.
|
||||
"""
|
||||
import numpy as np
|
||||
from PIL import Image
|
||||
from invokeai.backend import SilenceWarnings
|
||||
from invokeai.app.services.config import InvokeAIAppConfig
|
||||
from invokeai.backend.util.devices import choose_torch_device
|
||||
import invokeai.backend.util.logging as logger
|
||||
|
||||
config = InvokeAIAppConfig.get_config()
|
||||
|
||||
CHECKER_PATH = "core/convert/stable-diffusion-safety-checker"
|
||||
|
||||
|
||||
class SafetyChecker:
|
||||
"""
|
||||
Wrapper around SafetyChecker model.
|
||||
"""
|
||||
|
||||
safety_checker = None
|
||||
feature_extractor = None
|
||||
tried_load: bool = False
|
||||
|
||||
@classmethod
|
||||
def _load_safety_checker(self):
|
||||
if self.tried_load:
|
||||
return
|
||||
|
||||
if config.nsfw_checker:
|
||||
try:
|
||||
from diffusers.pipelines.stable_diffusion.safety_checker import StableDiffusionSafetyChecker
|
||||
from transformers import AutoFeatureExtractor
|
||||
|
||||
self.safety_checker = StableDiffusionSafetyChecker.from_pretrained(config.models_path / CHECKER_PATH)
|
||||
self.feature_extractor = AutoFeatureExtractor.from_pretrained(config.models_path / CHECKER_PATH)
|
||||
logger.info("NSFW checker initialized")
|
||||
except Exception as e:
|
||||
logger.warning(f"Could not load NSFW checker: {str(e)}")
|
||||
else:
|
||||
logger.info("NSFW checker loading disabled")
|
||||
self.tried_load = True
|
||||
|
||||
@classmethod
|
||||
def safety_checker_available(self) -> bool:
|
||||
self._load_safety_checker()
|
||||
return self.safety_checker is not None
|
||||
|
||||
@classmethod
|
||||
def has_nsfw_concept(self, image: Image) -> bool:
|
||||
if not self.safety_checker_available():
|
||||
return False
|
||||
|
||||
device = choose_torch_device()
|
||||
features = self.feature_extractor([image], return_tensors="pt")
|
||||
features.to(device)
|
||||
self.safety_checker.to(device)
|
||||
x_image = np.array(image).astype(np.float32) / 255.0
|
||||
x_image = x_image[None].transpose(0, 3, 1, 2)
|
||||
with SilenceWarnings():
|
||||
checked_image, has_nsfw_concept = self.safety_checker(images=x_image, clip_input=features.pixel_values)
|
||||
return has_nsfw_concept[0]
|
@ -5,12 +5,8 @@ def _conv_forward_asymmetric(self, input, weight, bias):
|
||||
"""
|
||||
Patch for Conv2d._conv_forward that supports asymmetric padding
|
||||
"""
|
||||
working = nn.functional.pad(
|
||||
input, self.asymmetric_padding["x"], mode=self.asymmetric_padding_mode["x"]
|
||||
)
|
||||
working = nn.functional.pad(
|
||||
working, self.asymmetric_padding["y"], mode=self.asymmetric_padding_mode["y"]
|
||||
)
|
||||
working = nn.functional.pad(input, self.asymmetric_padding["x"], mode=self.asymmetric_padding_mode["x"])
|
||||
working = nn.functional.pad(working, self.asymmetric_padding["y"], mode=self.asymmetric_padding_mode["y"])
|
||||
return nn.functional.conv2d(
|
||||
working,
|
||||
weight,
|
||||
@ -32,18 +28,14 @@ def configure_model_padding(model, seamless, seamless_axes):
|
||||
if seamless:
|
||||
m.asymmetric_padding_mode = {}
|
||||
m.asymmetric_padding = {}
|
||||
m.asymmetric_padding_mode["x"] = (
|
||||
"circular" if ("x" in seamless_axes) else "constant"
|
||||
)
|
||||
m.asymmetric_padding_mode["x"] = "circular" if ("x" in seamless_axes) else "constant"
|
||||
m.asymmetric_padding["x"] = (
|
||||
m._reversed_padding_repeated_twice[0],
|
||||
m._reversed_padding_repeated_twice[1],
|
||||
0,
|
||||
0,
|
||||
)
|
||||
m.asymmetric_padding_mode["y"] = (
|
||||
"circular" if ("y" in seamless_axes) else "constant"
|
||||
)
|
||||
m.asymmetric_padding_mode["y"] = "circular" if ("y" in seamless_axes) else "constant"
|
||||
m.asymmetric_padding["y"] = (
|
||||
0,
|
||||
0,
|
||||
|
@ -39,23 +39,18 @@ CLIPSEG_MODEL = "CIDAS/clipseg-rd64-refined"
|
||||
CLIPSEG_SIZE = 352
|
||||
config = InvokeAIAppConfig.get_config()
|
||||
|
||||
|
||||
class SegmentedGrayscale(object):
|
||||
def __init__(self, image: Image, heatmap: torch.Tensor):
|
||||
self.heatmap = heatmap
|
||||
self.image = image
|
||||
|
||||
def to_grayscale(self, invert: bool = False) -> Image:
|
||||
return self._rescale(
|
||||
Image.fromarray(
|
||||
np.uint8(255 - self.heatmap * 255 if invert else self.heatmap * 255)
|
||||
)
|
||||
)
|
||||
return self._rescale(Image.fromarray(np.uint8(255 - self.heatmap * 255 if invert else self.heatmap * 255)))
|
||||
|
||||
def to_mask(self, threshold: float = 0.5) -> Image:
|
||||
discrete_heatmap = self.heatmap.lt(threshold).int()
|
||||
return self._rescale(
|
||||
Image.fromarray(np.uint8(discrete_heatmap * 255), mode="L")
|
||||
)
|
||||
return self._rescale(Image.fromarray(np.uint8(discrete_heatmap * 255), mode="L"))
|
||||
|
||||
def to_transparent(self, invert: bool = False) -> Image:
|
||||
transparent_image = self.image.copy()
|
||||
@ -67,11 +62,7 @@ class SegmentedGrayscale(object):
|
||||
|
||||
# unscales and uncrops the 352x352 heatmap so that it matches the image again
|
||||
def _rescale(self, heatmap: Image) -> Image:
|
||||
size = (
|
||||
self.image.width
|
||||
if (self.image.width > self.image.height)
|
||||
else self.image.height
|
||||
)
|
||||
size = self.image.width if (self.image.width > self.image.height) else self.image.height
|
||||
resized_image = heatmap.resize((size, size), resample=Image.Resampling.LANCZOS)
|
||||
return resized_image.crop((0, 0, self.image.width, self.image.height))
|
||||
|
||||
@ -87,12 +78,8 @@ class Txt2Mask(object):
|
||||
|
||||
# BUG: we are not doing anything with the device option at this time
|
||||
self.device = device
|
||||
self.processor = AutoProcessor.from_pretrained(
|
||||
CLIPSEG_MODEL, cache_dir=config.cache_dir
|
||||
)
|
||||
self.model = CLIPSegForImageSegmentation.from_pretrained(
|
||||
CLIPSEG_MODEL, cache_dir=config.cache_dir
|
||||
)
|
||||
self.processor = AutoProcessor.from_pretrained(CLIPSEG_MODEL, cache_dir=config.cache_dir)
|
||||
self.model = CLIPSegForImageSegmentation.from_pretrained(CLIPSEG_MODEL, cache_dir=config.cache_dir)
|
||||
|
||||
@torch.no_grad()
|
||||
def segment(self, image, prompt: str) -> SegmentedGrayscale:
|
||||
@ -107,9 +94,7 @@ class Txt2Mask(object):
|
||||
image = ImageOps.exif_transpose(image)
|
||||
img = self._scale_and_crop(image)
|
||||
|
||||
inputs = self.processor(
|
||||
text=[prompt], images=[img], padding=True, return_tensors="pt"
|
||||
)
|
||||
inputs = self.processor(text=[prompt], images=[img], padding=True, return_tensors="pt")
|
||||
outputs = self.model(**inputs)
|
||||
heatmap = torch.sigmoid(outputs.logits)
|
||||
return SegmentedGrayscale(image, heatmap)
|
||||
|
36
invokeai/backend/install/check_root.py
Normal file
@ -0,0 +1,36 @@
|
||||
"""
|
||||
Check that the invokeai_root is correctly configured and exit if not.
|
||||
"""
|
||||
import sys
|
||||
from invokeai.app.services.config import (
|
||||
InvokeAIAppConfig,
|
||||
)
|
||||
|
||||
|
||||
def check_invokeai_root(config: InvokeAIAppConfig):
|
||||
try:
|
||||
assert config.model_conf_path.exists(), f"{config.model_conf_path} not found"
|
||||
assert config.db_path.parent.exists(), f"{config.db_path.parent} not found"
|
||||
assert config.models_path.exists(), f"{config.models_path} not found"
|
||||
for model in [
|
||||
"CLIP-ViT-bigG-14-laion2B-39B-b160k",
|
||||
"bert-base-uncased",
|
||||
"clip-vit-large-patch14",
|
||||
"sd-vae-ft-mse",
|
||||
"stable-diffusion-2-clip",
|
||||
"stable-diffusion-safety-checker",
|
||||
]:
|
||||
path = config.models_path / f"core/convert/{model}"
|
||||
assert path.exists(), f"{path} is missing"
|
||||
except Exception as e:
|
||||
print()
|
||||
print(f"An exception has occurred: {str(e)}")
|
||||
print("== STARTUP ABORTED ==")
|
||||
print("** One or more necessary files is missing from your InvokeAI root directory **")
|
||||
print("** Please rerun the configuration script to fix this problem. **")
|
||||
print("** From the launcher, selection option [7]. **")
|
||||
print(
|
||||
'** From the command line, activate the virtual environment and run "invokeai-configure --yes --skip-sd-weights" **'
|
||||
)
|
||||
input("Press any key to continue...")
|
||||
sys.exit(0)
|
@ -13,8 +13,8 @@ import os
|
||||
import shutil
|
||||
import textwrap
|
||||
import traceback
|
||||
import warnings
|
||||
import yaml
|
||||
import warnings
|
||||
from argparse import Namespace
|
||||
from pathlib import Path
|
||||
from shutil import get_terminal_size
|
||||
@ -32,6 +32,7 @@ from omegaconf import OmegaConf
|
||||
from tqdm import tqdm
|
||||
from transformers import (
|
||||
CLIPTextModel,
|
||||
CLIPTextConfig,
|
||||
CLIPTokenizer,
|
||||
AutoFeatureExtractor,
|
||||
BertTokenizerFast,
|
||||
@ -44,6 +45,7 @@ from invokeai.app.services.config import (
|
||||
from invokeai.backend.util.logging import InvokeAILogger
|
||||
from invokeai.frontend.install.model_install import addModelsForm, process_and_execute
|
||||
from invokeai.frontend.install.widgets import (
|
||||
SingleSelectColumns,
|
||||
CenteredButtonPress,
|
||||
FileBox,
|
||||
IntTitleSlider,
|
||||
@ -58,9 +60,7 @@ from invokeai.backend.install.model_install_backend import (
|
||||
InstallSelections,
|
||||
ModelInstall,
|
||||
)
|
||||
from invokeai.backend.model_management.model_probe import (
|
||||
ModelType, BaseModelType
|
||||
)
|
||||
from invokeai.backend.model_management.model_probe import ModelType, BaseModelType
|
||||
|
||||
warnings.filterwarnings("ignore")
|
||||
transformers.logging.set_verbosity_error()
|
||||
@ -75,7 +75,7 @@ Model_dir = "models"
|
||||
Default_config_file = config.model_conf_path
|
||||
SD_Configs = config.legacy_conf_path
|
||||
|
||||
PRECISION_CHOICES = ['auto','float16','float32']
|
||||
PRECISION_CHOICES = ["auto", "float16", "float32"]
|
||||
|
||||
INIT_FILE_PREAMBLE = """# InvokeAI initialization file
|
||||
# This is the InvokeAI initialization file, which contains command-line default values.
|
||||
@ -83,7 +83,8 @@ INIT_FILE_PREAMBLE = """# InvokeAI initialization file
|
||||
# or renaming it and then running invokeai-configure again.
|
||||
"""
|
||||
|
||||
logger=InvokeAILogger.getLogger()
|
||||
logger = InvokeAILogger.getLogger()
|
||||
|
||||
|
||||
# --------------------------------------------
|
||||
def postscript(errors: None):
|
||||
@ -106,7 +107,9 @@ Add the '--help' argument to see all of the command-line switches available for
|
||||
"""
|
||||
|
||||
else:
|
||||
message = "\n** There were errors during installation. It is possible some of the models were not fully downloaded.\n"
|
||||
message = (
|
||||
"\n** There were errors during installation. It is possible some of the models were not fully downloaded.\n"
|
||||
)
|
||||
for err in errors:
|
||||
message += f"\t - {err}\n"
|
||||
message += "Please check the logs above and correct any issues."
|
||||
@ -167,9 +170,7 @@ def download_with_progress_bar(model_url: str, model_dest: str, label: str = "th
|
||||
logger.info(f"Installing {label} model file {model_url}...")
|
||||
if not os.path.exists(model_dest):
|
||||
os.makedirs(os.path.dirname(model_dest), exist_ok=True)
|
||||
request.urlretrieve(
|
||||
model_url, model_dest, ProgressBar(os.path.basename(model_dest))
|
||||
)
|
||||
request.urlretrieve(model_url, model_dest, ProgressBar(os.path.basename(model_dest)))
|
||||
logger.info("...downloaded successfully")
|
||||
else:
|
||||
logger.info("...exists")
|
||||
@ -180,81 +181,93 @@ def download_with_progress_bar(model_url: str, model_dest: str, label: str = "th
|
||||
|
||||
|
||||
def download_conversion_models():
|
||||
target_dir = config.root_path / 'models/core/convert'
|
||||
target_dir = config.root_path / "models/core/convert"
|
||||
kwargs = dict() # for future use
|
||||
try:
|
||||
logger.info('Downloading core tokenizers and text encoders')
|
||||
logger.info("Downloading core tokenizers and text encoders")
|
||||
|
||||
# bert
|
||||
with warnings.catch_warnings():
|
||||
warnings.filterwarnings("ignore", category=DeprecationWarning)
|
||||
bert = BertTokenizerFast.from_pretrained("bert-base-uncased", **kwargs)
|
||||
bert.save_pretrained(target_dir / 'bert-base-uncased', safe_serialization=True)
|
||||
|
||||
bert.save_pretrained(target_dir / "bert-base-uncased", safe_serialization=True)
|
||||
|
||||
# sd-1
|
||||
repo_id = 'openai/clip-vit-large-patch14'
|
||||
hf_download_from_pretrained(CLIPTokenizer, repo_id, target_dir / 'clip-vit-large-patch14')
|
||||
hf_download_from_pretrained(CLIPTextModel, repo_id, target_dir / 'clip-vit-large-patch14')
|
||||
repo_id = "openai/clip-vit-large-patch14"
|
||||
hf_download_from_pretrained(CLIPTokenizer, repo_id, target_dir / "clip-vit-large-patch14")
|
||||
hf_download_from_pretrained(CLIPTextModel, repo_id, target_dir / "clip-vit-large-patch14")
|
||||
|
||||
# sd-2
|
||||
repo_id = "stabilityai/stable-diffusion-2"
|
||||
pipeline = CLIPTokenizer.from_pretrained(repo_id, subfolder="tokenizer", **kwargs)
|
||||
pipeline.save_pretrained(target_dir / 'stable-diffusion-2-clip' / 'tokenizer', safe_serialization=True)
|
||||
pipeline.save_pretrained(target_dir / "stable-diffusion-2-clip" / "tokenizer", safe_serialization=True)
|
||||
|
||||
pipeline = CLIPTextModel.from_pretrained(repo_id, subfolder="text_encoder", **kwargs)
|
||||
pipeline.save_pretrained(target_dir / 'stable-diffusion-2-clip' / 'text_encoder', safe_serialization=True)
|
||||
pipeline.save_pretrained(target_dir / "stable-diffusion-2-clip" / "text_encoder", safe_serialization=True)
|
||||
|
||||
# sd-xl - tokenizer_2
|
||||
repo_id = "laion/CLIP-ViT-bigG-14-laion2B-39B-b160k"
|
||||
_, model_name = repo_id.split("/")
|
||||
pipeline = CLIPTokenizer.from_pretrained(repo_id, **kwargs)
|
||||
pipeline.save_pretrained(target_dir / model_name, safe_serialization=True)
|
||||
|
||||
pipeline = CLIPTextConfig.from_pretrained(repo_id, **kwargs)
|
||||
pipeline.save_pretrained(target_dir / model_name, safe_serialization=True)
|
||||
|
||||
# VAE
|
||||
logger.info('Downloading stable diffusion VAE')
|
||||
vae = AutoencoderKL.from_pretrained('stabilityai/sd-vae-ft-mse', **kwargs)
|
||||
vae.save_pretrained(target_dir / 'sd-vae-ft-mse', safe_serialization=True)
|
||||
logger.info("Downloading stable diffusion VAE")
|
||||
vae = AutoencoderKL.from_pretrained("stabilityai/sd-vae-ft-mse", **kwargs)
|
||||
vae.save_pretrained(target_dir / "sd-vae-ft-mse", safe_serialization=True)
|
||||
|
||||
# safety checking
|
||||
logger.info('Downloading safety checker')
|
||||
logger.info("Downloading safety checker")
|
||||
repo_id = "CompVis/stable-diffusion-safety-checker"
|
||||
pipeline = AutoFeatureExtractor.from_pretrained(repo_id,**kwargs)
|
||||
pipeline.save_pretrained(target_dir / 'stable-diffusion-safety-checker', safe_serialization=True)
|
||||
pipeline = AutoFeatureExtractor.from_pretrained(repo_id, **kwargs)
|
||||
pipeline.save_pretrained(target_dir / "stable-diffusion-safety-checker", safe_serialization=True)
|
||||
|
||||
pipeline = StableDiffusionSafetyChecker.from_pretrained(repo_id,**kwargs)
|
||||
pipeline.save_pretrained(target_dir / 'stable-diffusion-safety-checker', safe_serialization=True)
|
||||
pipeline = StableDiffusionSafetyChecker.from_pretrained(repo_id, **kwargs)
|
||||
pipeline.save_pretrained(target_dir / "stable-diffusion-safety-checker", safe_serialization=True)
|
||||
except KeyboardInterrupt:
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.error(str(e))
|
||||
|
||||
|
||||
# ---------------------------------------------
|
||||
def download_realesrgan():
|
||||
logger.info("Installing ESRGAN Upscaling models...")
|
||||
URLs = [
|
||||
dict(
|
||||
url = "https://github.com/xinntao/Real-ESRGAN/releases/download/v0.1.0/RealESRGAN_x4plus.pth",
|
||||
dest = "core/upscaling/realesrgan/RealESRGAN_x4plus.pth",
|
||||
description = "RealESRGAN_x4plus.pth",
|
||||
url="https://github.com/xinntao/Real-ESRGAN/releases/download/v0.1.0/RealESRGAN_x4plus.pth",
|
||||
dest="core/upscaling/realesrgan/RealESRGAN_x4plus.pth",
|
||||
description="RealESRGAN_x4plus.pth",
|
||||
),
|
||||
dict(
|
||||
url = "https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.2.4/RealESRGAN_x4plus_anime_6B.pth",
|
||||
dest = "core/upscaling/realesrgan/RealESRGAN_x4plus_anime_6B.pth",
|
||||
description = "RealESRGAN_x4plus_anime_6B.pth",
|
||||
url="https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.2.4/RealESRGAN_x4plus_anime_6B.pth",
|
||||
dest="core/upscaling/realesrgan/RealESRGAN_x4plus_anime_6B.pth",
|
||||
description="RealESRGAN_x4plus_anime_6B.pth",
|
||||
),
|
||||
dict(
|
||||
url= "https://github.com/xinntao/Real-ESRGAN/releases/download/v0.1.1/ESRGAN_SRx4_DF2KOST_official-ff704c30.pth",
|
||||
dest= "core/upscaling/realesrgan/ESRGAN_SRx4_DF2KOST_official-ff704c30.pth",
|
||||
description = "ESRGAN_SRx4_DF2KOST_official.pth",
|
||||
url="https://github.com/xinntao/Real-ESRGAN/releases/download/v0.1.1/ESRGAN_SRx4_DF2KOST_official-ff704c30.pth",
|
||||
dest="core/upscaling/realesrgan/ESRGAN_SRx4_DF2KOST_official-ff704c30.pth",
|
||||
description="ESRGAN_SRx4_DF2KOST_official.pth",
|
||||
),
|
||||
dict(
|
||||
url= "https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.1/RealESRGAN_x2plus.pth",
|
||||
dest= "core/upscaling/realesrgan/RealESRGAN_x2plus.pth",
|
||||
description = "RealESRGAN_x2plus.pth",
|
||||
url="https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.1/RealESRGAN_x2plus.pth",
|
||||
dest="core/upscaling/realesrgan/RealESRGAN_x2plus.pth",
|
||||
description="RealESRGAN_x2plus.pth",
|
||||
),
|
||||
]
|
||||
for model in URLs:
|
||||
download_with_progress_bar(model['url'], config.models_path / model['dest'], model['description'])
|
||||
download_with_progress_bar(model["url"], config.models_path / model["dest"], model["description"])
|
||||
|
||||
|
||||
# ---------------------------------------------
|
||||
def download_support_models():
|
||||
download_realesrgan()
|
||||
download_conversion_models()
|
||||
|
||||
|
||||
# -------------------------------------
|
||||
def get_root(root: str = None) -> str:
|
||||
if root:
|
||||
@ -264,6 +277,7 @@ def get_root(root: str = None) -> str:
|
||||
else:
|
||||
return str(config.root_path)
|
||||
|
||||
|
||||
# -------------------------------------
|
||||
class editOptsForm(CyclingForm, npyscreen.FormMultiPage):
|
||||
# for responsive resizing - disabled
|
||||
@ -272,14 +286,14 @@ class editOptsForm(CyclingForm, npyscreen.FormMultiPage):
|
||||
def create(self):
|
||||
program_opts = self.parentApp.program_opts
|
||||
old_opts = self.parentApp.invokeai_opts
|
||||
first_time = not (config.root_path / 'invokeai.yaml').exists()
|
||||
first_time = not (config.root_path / "invokeai.yaml").exists()
|
||||
access_token = HfFolder.get_token()
|
||||
window_width, window_height = get_terminal_size()
|
||||
label = """Configure startup settings. You can come back and change these later.
|
||||
Use ctrl-N and ctrl-P to move to the <N>ext and <P>revious fields.
|
||||
Use cursor arrows to make a checkbox selection, and space to toggle.
|
||||
"""
|
||||
for i in textwrap.wrap(label,width=window_width-6):
|
||||
for i in textwrap.wrap(label, width=window_width - 6):
|
||||
self.add_widget_intelligent(
|
||||
npyscreen.FixedText,
|
||||
value=i,
|
||||
@ -287,50 +301,9 @@ Use cursor arrows to make a checkbox selection, and space to toggle.
|
||||
color="CONTROL",
|
||||
)
|
||||
|
||||
self.nextrely += 1
|
||||
self.add_widget_intelligent(
|
||||
npyscreen.TitleFixedText,
|
||||
name="== BASIC OPTIONS ==",
|
||||
begin_entry_at=0,
|
||||
editable=False,
|
||||
color="CONTROL",
|
||||
scroll_exit=True,
|
||||
)
|
||||
self.nextrely -= 1
|
||||
self.add_widget_intelligent(
|
||||
npyscreen.FixedText,
|
||||
value="Select an output directory for images:",
|
||||
editable=False,
|
||||
color="CONTROL",
|
||||
)
|
||||
self.outdir = self.add_widget_intelligent(
|
||||
npyscreen.TitleFilename,
|
||||
name="(<tab> autocompletes, ctrl-N advances):",
|
||||
value=str(default_output_dir()),
|
||||
select_dir=True,
|
||||
must_exist=False,
|
||||
use_two_lines=False,
|
||||
labelColor="GOOD",
|
||||
begin_entry_at=40,
|
||||
scroll_exit=True,
|
||||
)
|
||||
self.nextrely += 1
|
||||
self.add_widget_intelligent(
|
||||
npyscreen.FixedText,
|
||||
value="Activate the NSFW checker to blur images showing potential sexual imagery:",
|
||||
editable=False,
|
||||
color="CONTROL",
|
||||
)
|
||||
self.nsfw_checker = self.add_widget_intelligent(
|
||||
npyscreen.Checkbox,
|
||||
name="NSFW checker",
|
||||
value=old_opts.nsfw_checker,
|
||||
relx=5,
|
||||
scroll_exit=True,
|
||||
)
|
||||
self.nextrely += 1
|
||||
label = """HuggingFace access token (OPTIONAL) for automatic model downloads. See https://huggingface.co/settings/tokens."""
|
||||
for line in textwrap.wrap(label,width=window_width-6):
|
||||
for line in textwrap.wrap(label, width=window_width - 6):
|
||||
self.add_widget_intelligent(
|
||||
npyscreen.FixedText,
|
||||
value=line,
|
||||
@ -347,15 +320,6 @@ Use cursor arrows to make a checkbox selection, and space to toggle.
|
||||
scroll_exit=True,
|
||||
)
|
||||
self.nextrely += 1
|
||||
self.add_widget_intelligent(
|
||||
npyscreen.TitleFixedText,
|
||||
name="== ADVANCED OPTIONS ==",
|
||||
begin_entry_at=0,
|
||||
editable=False,
|
||||
color="CONTROL",
|
||||
scroll_exit=True,
|
||||
)
|
||||
self.nextrely -= 1
|
||||
self.add_widget_intelligent(
|
||||
npyscreen.TitleFixedText,
|
||||
name="GPU Management",
|
||||
@ -369,34 +333,47 @@ Use cursor arrows to make a checkbox selection, and space to toggle.
|
||||
npyscreen.Checkbox,
|
||||
name="Free GPU memory after each generation",
|
||||
value=old_opts.free_gpu_mem,
|
||||
max_width=45,
|
||||
relx=5,
|
||||
scroll_exit=True,
|
||||
)
|
||||
self.nextrely -= 1
|
||||
self.xformers_enabled = self.add_widget_intelligent(
|
||||
npyscreen.Checkbox,
|
||||
name="Enable xformers support if available",
|
||||
name="Enable xformers support",
|
||||
value=old_opts.xformers_enabled,
|
||||
relx=5,
|
||||
max_width=30,
|
||||
relx=50,
|
||||
scroll_exit=True,
|
||||
)
|
||||
self.nextrely -= 1
|
||||
self.always_use_cpu = self.add_widget_intelligent(
|
||||
npyscreen.Checkbox,
|
||||
name="Force CPU to be used on GPU systems",
|
||||
value=old_opts.always_use_cpu,
|
||||
relx=5,
|
||||
relx=80,
|
||||
scroll_exit=True,
|
||||
)
|
||||
precision = old_opts.precision or (
|
||||
"float32" if program_opts.full_precision else "auto"
|
||||
precision = old_opts.precision or ("float32" if program_opts.full_precision else "auto")
|
||||
self.nextrely += 1
|
||||
self.add_widget_intelligent(
|
||||
npyscreen.TitleFixedText,
|
||||
name="Floating Point Precision",
|
||||
begin_entry_at=0,
|
||||
editable=False,
|
||||
color="CONTROL",
|
||||
scroll_exit=True,
|
||||
)
|
||||
self.nextrely -= 1
|
||||
self.precision = self.add_widget_intelligent(
|
||||
npyscreen.TitleSelectOne,
|
||||
columns = 2,
|
||||
SingleSelectColumns,
|
||||
columns=3,
|
||||
name="Precision",
|
||||
values=PRECISION_CHOICES,
|
||||
value=PRECISION_CHOICES.index(precision),
|
||||
begin_entry_at=3,
|
||||
max_height=len(PRECISION_CHOICES) + 1,
|
||||
max_height=2,
|
||||
max_width=80,
|
||||
scroll_exit=True,
|
||||
)
|
||||
self.max_cache_size = self.add_widget_intelligent(
|
||||
@ -409,40 +386,38 @@ Use cursor arrows to make a checkbox selection, and space to toggle.
|
||||
scroll_exit=True,
|
||||
)
|
||||
self.nextrely += 1
|
||||
self.add_widget_intelligent(
|
||||
npyscreen.FixedText,
|
||||
value="Folder to recursively scan for new checkpoints, ControlNets, LoRAs and TI models (<tab> autocompletes, ctrl-N advances):",
|
||||
editable=False,
|
||||
color="CONTROL",
|
||||
)
|
||||
self.autoimport_dirs = {}
|
||||
self.autoimport_dirs['autoimport_dir'] = self.add_widget_intelligent(
|
||||
FileBox,
|
||||
name=f'Autoimport Folder',
|
||||
value=str(config.root_path / config.autoimport_dir),
|
||||
select_dir=True,
|
||||
must_exist=False,
|
||||
use_two_lines=False,
|
||||
labelColor="GOOD",
|
||||
begin_entry_at=32,
|
||||
max_height = 3,
|
||||
scroll_exit=True
|
||||
)
|
||||
self.nextrely += 1
|
||||
self.add_widget_intelligent(
|
||||
npyscreen.TitleFixedText,
|
||||
name="== LICENSE ==",
|
||||
begin_entry_at=0,
|
||||
editable=False,
|
||||
color="CONTROL",
|
||||
self.outdir = self.add_widget_intelligent(
|
||||
FileBox,
|
||||
name="Output directory for images (<tab> autocompletes, ctrl-N advances):",
|
||||
value=str(default_output_dir()),
|
||||
select_dir=True,
|
||||
must_exist=False,
|
||||
use_two_lines=False,
|
||||
labelColor="GOOD",
|
||||
begin_entry_at=40,
|
||||
max_height=3,
|
||||
scroll_exit=True,
|
||||
)
|
||||
self.nextrely -= 1
|
||||
self.autoimport_dirs = {}
|
||||
self.autoimport_dirs["autoimport_dir"] = self.add_widget_intelligent(
|
||||
FileBox,
|
||||
name=f"Folder to recursively scan for new checkpoints, ControlNets, LoRAs and TI models",
|
||||
value=str(config.root_path / config.autoimport_dir),
|
||||
select_dir=True,
|
||||
must_exist=False,
|
||||
use_two_lines=False,
|
||||
labelColor="GOOD",
|
||||
begin_entry_at=32,
|
||||
max_height=3,
|
||||
scroll_exit=True,
|
||||
)
|
||||
self.nextrely += 1
|
||||
label = """BY DOWNLOADING THE STABLE DIFFUSION WEIGHT FILES, YOU AGREE TO HAVE READ
|
||||
AND ACCEPTED THE CREATIVEML RESPONSIBLE AI LICENSE LOCATED AT
|
||||
https://huggingface.co/spaces/CompVis/stable-diffusion-license
|
||||
AND ACCEPTED THE CREATIVEML RESPONSIBLE AI LICENSES LOCATED AT
|
||||
https://huggingface.co/spaces/CompVis/stable-diffusion-license and
|
||||
https://huggingface.co/stabilityai/stable-diffusion-xl-base-1.0/blob/main/LICENSE.md
|
||||
"""
|
||||
for i in textwrap.wrap(label,width=window_width-6):
|
||||
for i in textwrap.wrap(label, width=window_width - 6):
|
||||
self.add_widget_intelligent(
|
||||
npyscreen.FixedText,
|
||||
value=i,
|
||||
@ -451,22 +426,17 @@ https://huggingface.co/spaces/CompVis/stable-diffusion-license
|
||||
)
|
||||
self.license_acceptance = self.add_widget_intelligent(
|
||||
npyscreen.Checkbox,
|
||||
name="I accept the CreativeML Responsible AI License",
|
||||
name="I accept the CreativeML Responsible AI Licenses",
|
||||
value=not first_time,
|
||||
relx=2,
|
||||
scroll_exit=True,
|
||||
)
|
||||
self.nextrely += 1
|
||||
label = (
|
||||
"DONE"
|
||||
if program_opts.skip_sd_weights or program_opts.default_only
|
||||
else "NEXT"
|
||||
)
|
||||
label = "DONE" if program_opts.skip_sd_weights or program_opts.default_only else "NEXT"
|
||||
self.ok_button = self.add_widget_intelligent(
|
||||
CenteredButtonPress,
|
||||
name=label,
|
||||
relx=(window_width - len(label)) // 2,
|
||||
rely=-3,
|
||||
when_pressed_function=self.on_ok,
|
||||
)
|
||||
|
||||
@ -481,13 +451,11 @@ https://huggingface.co/spaces/CompVis/stable-diffusion-license
|
||||
self.editing = False
|
||||
else:
|
||||
self.editing = True
|
||||
|
||||
|
||||
def validate_field_values(self, opt: Namespace) -> bool:
|
||||
bad_fields = []
|
||||
if not opt.license_acceptance:
|
||||
bad_fields.append(
|
||||
"Please accept the license terms before proceeding to model downloads"
|
||||
)
|
||||
bad_fields.append("Please accept the license terms before proceeding to model downloads")
|
||||
if not Path(opt.outdir).parent.exists():
|
||||
bad_fields.append(
|
||||
f"The output directory does not seem to be valid. Please check that {str(Path(opt.outdir).parent)} is an existing directory."
|
||||
@ -505,12 +473,11 @@ https://huggingface.co/spaces/CompVis/stable-diffusion-license
|
||||
new_opts = Namespace()
|
||||
|
||||
for attr in [
|
||||
"outdir",
|
||||
"nsfw_checker",
|
||||
"free_gpu_mem",
|
||||
"max_cache_size",
|
||||
"xformers_enabled",
|
||||
"always_use_cpu",
|
||||
"outdir",
|
||||
"free_gpu_mem",
|
||||
"max_cache_size",
|
||||
"xformers_enabled",
|
||||
"always_use_cpu",
|
||||
]:
|
||||
setattr(new_opts, attr, getattr(self, attr).value)
|
||||
|
||||
@ -523,7 +490,7 @@ https://huggingface.co/spaces/CompVis/stable-diffusion-license
|
||||
new_opts.hf_token = self.hf_token.value
|
||||
new_opts.license_acceptance = self.license_acceptance.value
|
||||
new_opts.precision = PRECISION_CHOICES[self.precision.value[0]]
|
||||
|
||||
|
||||
return new_opts
|
||||
|
||||
|
||||
@ -542,7 +509,7 @@ class EditOptApplication(npyscreen.NPSAppManaged):
|
||||
"MAIN",
|
||||
editOptsForm,
|
||||
name="InvokeAI Startup Options",
|
||||
cycle_widgets=True,
|
||||
cycle_widgets=False,
|
||||
)
|
||||
if not (self.program_opts.skip_sd_weights or self.program_opts.default_only):
|
||||
self.model_select = self.addForm(
|
||||
@ -550,7 +517,7 @@ class EditOptApplication(npyscreen.NPSAppManaged):
|
||||
addModelsForm,
|
||||
name="Install Stable Diffusion Models",
|
||||
multipage=True,
|
||||
cycle_widgets=True,
|
||||
cycle_widgets=False,
|
||||
)
|
||||
|
||||
def new_opts(self):
|
||||
@ -562,21 +529,20 @@ def edit_opts(program_opts: Namespace, invokeai_opts: Namespace) -> argparse.Nam
|
||||
editApp.run()
|
||||
return editApp.new_opts()
|
||||
|
||||
|
||||
def default_startup_options(init_file: Path) -> Namespace:
|
||||
opts = InvokeAIAppConfig.get_config()
|
||||
if not init_file.exists():
|
||||
opts.nsfw_checker = True
|
||||
return opts
|
||||
|
||||
|
||||
def default_user_selections(program_opts: Namespace) -> InstallSelections:
|
||||
|
||||
try:
|
||||
installer = ModelInstall(config)
|
||||
except omegaconf.errors.ConfigKeyError:
|
||||
logger.warning('Your models.yaml file is corrupt or out of date. Reinitializing')
|
||||
logger.warning("Your models.yaml file is corrupt or out of date. Reinitializing")
|
||||
initialize_rootdir(config.root_path, True)
|
||||
installer = ModelInstall(config)
|
||||
|
||||
|
||||
models = installer.all_models()
|
||||
return InstallSelections(
|
||||
install_models=[models[installer.default_model()].path or models[installer.default_model()].repo_id]
|
||||
@ -586,44 +552,46 @@ def default_user_selections(program_opts: Namespace) -> InstallSelections:
|
||||
else list(),
|
||||
)
|
||||
|
||||
|
||||
# -------------------------------------
|
||||
def initialize_rootdir(root: Path, yes_to_all: bool = False):
|
||||
logger.info("** INITIALIZING INVOKEAI RUNTIME DIRECTORY **")
|
||||
for name in (
|
||||
"models",
|
||||
"databases",
|
||||
"text-inversion-output",
|
||||
"text-inversion-training-data",
|
||||
"configs"
|
||||
):
|
||||
logger.info("Initializing InvokeAI runtime directory")
|
||||
for name in ("models", "databases", "text-inversion-output", "text-inversion-training-data", "configs"):
|
||||
os.makedirs(os.path.join(root, name), exist_ok=True)
|
||||
for model_type in ModelType:
|
||||
Path(root, 'autoimport', model_type.value).mkdir(parents=True, exist_ok=True)
|
||||
Path(root, "autoimport", model_type.value).mkdir(parents=True, exist_ok=True)
|
||||
|
||||
configs_src = Path(configs.__path__[0])
|
||||
configs_dest = root / "configs"
|
||||
if not os.path.samefile(configs_src, configs_dest):
|
||||
shutil.copytree(configs_src, configs_dest, dirs_exist_ok=True)
|
||||
|
||||
dest = root / 'models'
|
||||
dest = root / "models"
|
||||
for model_base in BaseModelType:
|
||||
for model_type in ModelType:
|
||||
path = dest / model_base.value / model_type.value
|
||||
path.mkdir(parents=True, exist_ok=True)
|
||||
path = dest / 'core'
|
||||
path = dest / "core"
|
||||
path.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
with open(root / 'configs' / 'models.yaml','w') as yaml_file:
|
||||
yaml_file.write(yaml.dump({'__metadata__':
|
||||
{'version':'3.0.0'}
|
||||
}
|
||||
)
|
||||
)
|
||||
|
||||
maybe_create_models_yaml(root)
|
||||
|
||||
|
||||
def maybe_create_models_yaml(root: Path):
|
||||
models_yaml = root / "configs" / "models.yaml"
|
||||
if models_yaml.exists():
|
||||
if OmegaConf.load(models_yaml).get("__metadata__"): # up to date
|
||||
return
|
||||
else:
|
||||
logger.info("Creating new models.yaml, original saved as models.yaml.orig")
|
||||
models_yaml.rename(models_yaml.parent / "models.yaml.orig")
|
||||
|
||||
with open(models_yaml, "w") as yaml_file:
|
||||
yaml_file.write(yaml.dump({"__metadata__": {"version": "3.0.0"}}))
|
||||
|
||||
|
||||
# -------------------------------------
|
||||
def run_console_ui(
|
||||
program_opts: Namespace, initfile: Path = None
|
||||
) -> (Namespace, Namespace):
|
||||
def run_console_ui(program_opts: Namespace, initfile: Path = None) -> (Namespace, Namespace):
|
||||
# parse_args() will read from init file if present
|
||||
invokeai_opts = default_startup_options(initfile)
|
||||
invokeai_opts.root = program_opts.root
|
||||
@ -635,8 +603,9 @@ def run_console_ui(
|
||||
# the install-models application spawns a subprocess to install
|
||||
# models, and will crash unless this is set before running.
|
||||
import torch
|
||||
|
||||
torch.multiprocessing.set_start_method("spawn")
|
||||
|
||||
|
||||
editApp = EditOptApplication(program_opts, invokeai_opts)
|
||||
editApp.run()
|
||||
if editApp.user_cancelled:
|
||||
@ -653,81 +622,86 @@ def write_opts(opts: Namespace, init_file: Path):
|
||||
# this will load current settings
|
||||
new_config = InvokeAIAppConfig.get_config()
|
||||
new_config.root = config.root
|
||||
|
||||
for key,value in opts.__dict__.items():
|
||||
if hasattr(new_config,key):
|
||||
setattr(new_config,key,value)
|
||||
|
||||
with open(init_file,'w', encoding='utf-8') as file:
|
||||
for key, value in opts.__dict__.items():
|
||||
if hasattr(new_config, key):
|
||||
setattr(new_config, key, value)
|
||||
|
||||
with open(init_file, "w", encoding="utf-8") as file:
|
||||
file.write(new_config.to_yaml())
|
||||
|
||||
if hasattr(opts,'hf_token') and opts.hf_token:
|
||||
if hasattr(opts, "hf_token") and opts.hf_token:
|
||||
HfLogin(opts.hf_token)
|
||||
|
||||
|
||||
# -------------------------------------
|
||||
def default_output_dir() -> Path:
|
||||
return config.root_path / "outputs"
|
||||
|
||||
|
||||
# -------------------------------------
|
||||
def write_default_options(program_opts: Namespace, initfile: Path):
|
||||
opt = default_startup_options(initfile)
|
||||
write_opts(opt, initfile)
|
||||
|
||||
|
||||
# -------------------------------------
|
||||
# Here we bring in
|
||||
# the legacy Args object in order to parse
|
||||
# the old init file and write out the new
|
||||
# yaml format.
|
||||
def migrate_init_file(legacy_format:Path):
|
||||
old = legacy_parser.parse_args([f'@{str(legacy_format)}'])
|
||||
def migrate_init_file(legacy_format: Path):
|
||||
old = legacy_parser.parse_args([f"@{str(legacy_format)}"])
|
||||
new = InvokeAIAppConfig.get_config()
|
||||
|
||||
fields = list(get_type_hints(InvokeAIAppConfig).keys())
|
||||
for attr in fields:
|
||||
if hasattr(old,attr):
|
||||
setattr(new,attr,getattr(old,attr))
|
||||
if hasattr(old, attr):
|
||||
setattr(new, attr, getattr(old, attr))
|
||||
|
||||
# a few places where the field names have changed and we have to
|
||||
# manually add in the new names/values
|
||||
new.nsfw_checker = old.safety_checker
|
||||
new.xformers_enabled = old.xformers
|
||||
new.conf_path = old.conf
|
||||
new.root = legacy_format.parent.resolve()
|
||||
|
||||
invokeai_yaml = legacy_format.parent / 'invokeai.yaml'
|
||||
with open(invokeai_yaml,"w", encoding="utf-8") as outfile:
|
||||
invokeai_yaml = legacy_format.parent / "invokeai.yaml"
|
||||
with open(invokeai_yaml, "w", encoding="utf-8") as outfile:
|
||||
outfile.write(new.to_yaml())
|
||||
|
||||
legacy_format.replace(legacy_format.parent / 'invokeai.init.orig')
|
||||
legacy_format.replace(legacy_format.parent / "invokeai.init.orig")
|
||||
|
||||
|
||||
# -------------------------------------
|
||||
def migrate_models(root: Path):
|
||||
from invokeai.backend.install.migrate_to_3 import do_migrate
|
||||
|
||||
do_migrate(root, root)
|
||||
|
||||
def migrate_if_needed(opt: Namespace, root: Path)->bool:
|
||||
# We check for to see if the runtime directory is correctly initialized.
|
||||
old_init_file = root / 'invokeai.init'
|
||||
new_init_file = root / 'invokeai.yaml'
|
||||
old_hub = root / 'models/hub'
|
||||
migration_needed = (old_init_file.exists() and not new_init_file.exists()) and old_hub.exists()
|
||||
|
||||
if migration_needed:
|
||||
if opt.yes_to_all or \
|
||||
yes_or_no(f'{str(config.root_path)} appears to be a 2.3 format root directory. Convert to version 3.0?'):
|
||||
|
||||
logger.info('** Migrating invokeai.init to invokeai.yaml')
|
||||
def migrate_if_needed(opt: Namespace, root: Path) -> bool:
|
||||
# We check for to see if the runtime directory is correctly initialized.
|
||||
old_init_file = root / "invokeai.init"
|
||||
new_init_file = root / "invokeai.yaml"
|
||||
old_hub = root / "models/hub"
|
||||
migration_needed = (old_init_file.exists() and not new_init_file.exists()) and old_hub.exists()
|
||||
|
||||
if migration_needed:
|
||||
if opt.yes_to_all or yes_or_no(
|
||||
f"{str(config.root_path)} appears to be a 2.3 format root directory. Convert to version 3.0?"
|
||||
):
|
||||
logger.info("** Migrating invokeai.init to invokeai.yaml")
|
||||
migrate_init_file(old_init_file)
|
||||
config.parse_args(argv=[],conf=OmegaConf.load(new_init_file))
|
||||
config.parse_args(argv=[], conf=OmegaConf.load(new_init_file))
|
||||
|
||||
if old_hub.exists():
|
||||
migrate_models(config.root_path)
|
||||
else:
|
||||
print('Cannot continue without conversion. Aborting.')
|
||||
|
||||
print("Cannot continue without conversion. Aborting.")
|
||||
|
||||
return migration_needed
|
||||
|
||||
|
||||
|
||||
# -------------------------------------
|
||||
def main():
|
||||
parser = argparse.ArgumentParser(description="InvokeAI model downloader")
|
||||
@ -784,9 +758,9 @@ def main():
|
||||
|
||||
invoke_args = []
|
||||
if opt.root:
|
||||
invoke_args.extend(['--root',opt.root])
|
||||
invoke_args.extend(["--root", opt.root])
|
||||
if opt.full_precision:
|
||||
invoke_args.extend(['--precision','float32'])
|
||||
invoke_args.extend(["--precision", "float32"])
|
||||
config.parse_args(invoke_args)
|
||||
logger = InvokeAILogger().getLogger(config=config)
|
||||
|
||||
@ -798,41 +772,36 @@ def main():
|
||||
if migrate_if_needed(opt, config.root_path):
|
||||
sys.exit(0)
|
||||
|
||||
if not config.model_conf_path.exists():
|
||||
initialize_rootdir(config.root_path, opt.yes_to_all)
|
||||
# run this unconditionally in case new directories need to be added
|
||||
initialize_rootdir(config.root_path, opt.yes_to_all)
|
||||
|
||||
models_to_download = default_user_selections(opt)
|
||||
new_init_file = config.root_path / 'invokeai.yaml'
|
||||
new_init_file = config.root_path / "invokeai.yaml"
|
||||
if opt.yes_to_all:
|
||||
write_default_options(opt, new_init_file)
|
||||
init_options = Namespace(
|
||||
precision="float32" if opt.full_precision else "float16"
|
||||
)
|
||||
init_options = Namespace(precision="float32" if opt.full_precision else "float16")
|
||||
else:
|
||||
init_options, models_to_download = run_console_ui(opt, new_init_file)
|
||||
if init_options:
|
||||
write_opts(init_options, new_init_file)
|
||||
else:
|
||||
logger.info(
|
||||
'\n** CANCELLED AT USER\'S REQUEST. USE THE "invoke.sh" LAUNCHER TO RUN LATER **\n'
|
||||
)
|
||||
logger.info('\n** CANCELLED AT USER\'S REQUEST. USE THE "invoke.sh" LAUNCHER TO RUN LATER **\n')
|
||||
sys.exit(0)
|
||||
|
||||
|
||||
if opt.skip_support_models:
|
||||
logger.info("SKIPPING SUPPORT MODEL DOWNLOADS PER USER REQUEST")
|
||||
logger.info("Skipping support models at user's request")
|
||||
else:
|
||||
logger.info("CHECKING/UPDATING SUPPORT MODELS")
|
||||
logger.info("Installing support models")
|
||||
download_support_models()
|
||||
|
||||
if opt.skip_sd_weights:
|
||||
logger.warning("SKIPPING DIFFUSION WEIGHTS DOWNLOAD PER USER REQUEST")
|
||||
logger.warning("Skipping diffusion weights download per user request")
|
||||
elif models_to_download:
|
||||
logger.info("DOWNLOADING DIFFUSION WEIGHTS")
|
||||
process_and_execute(opt, models_to_download)
|
||||
|
||||
postscript(errors=errors)
|
||||
if not opt.yes_to_all:
|
||||
input('Press any key to continue...')
|
||||
input("Press any key to continue...")
|
||||
except KeyboardInterrupt:
|
||||
print("\nGoodbye! Come back soon.")
|
||||
|
||||
|
@ -47,17 +47,18 @@ PRECISION_CHOICES = [
|
||||
"float16",
|
||||
]
|
||||
|
||||
|
||||
class FileArgumentParser(ArgumentParser):
|
||||
"""
|
||||
Supports reading defaults from an init file.
|
||||
"""
|
||||
|
||||
def convert_arg_line_to_args(self, arg_line):
|
||||
return shlex.split(arg_line, comments=True)
|
||||
|
||||
|
||||
legacy_parser = FileArgumentParser(
|
||||
description=
|
||||
"""
|
||||
description="""
|
||||
Generate images using Stable Diffusion.
|
||||
Use --web to launch the web interface.
|
||||
Use --from_file to load prompts from a file path or standard input ("-").
|
||||
@ -65,304 +66,279 @@ Generate images using Stable Diffusion.
|
||||
Other command-line arguments are defaults that can usually be overridden
|
||||
prompt the command prompt.
|
||||
""",
|
||||
fromfile_prefix_chars='@',
|
||||
fromfile_prefix_chars="@",
|
||||
)
|
||||
general_group = legacy_parser.add_argument_group('General')
|
||||
model_group = legacy_parser.add_argument_group('Model selection')
|
||||
file_group = legacy_parser.add_argument_group('Input/output')
|
||||
web_server_group = legacy_parser.add_argument_group('Web server')
|
||||
render_group = legacy_parser.add_argument_group('Rendering')
|
||||
postprocessing_group = legacy_parser.add_argument_group('Postprocessing')
|
||||
deprecated_group = legacy_parser.add_argument_group('Deprecated options')
|
||||
general_group = legacy_parser.add_argument_group("General")
|
||||
model_group = legacy_parser.add_argument_group("Model selection")
|
||||
file_group = legacy_parser.add_argument_group("Input/output")
|
||||
web_server_group = legacy_parser.add_argument_group("Web server")
|
||||
render_group = legacy_parser.add_argument_group("Rendering")
|
||||
postprocessing_group = legacy_parser.add_argument_group("Postprocessing")
|
||||
deprecated_group = legacy_parser.add_argument_group("Deprecated options")
|
||||
|
||||
deprecated_group.add_argument('--laion400m')
|
||||
deprecated_group.add_argument('--weights') # deprecated
|
||||
general_group.add_argument(
|
||||
'--version','-V',
|
||||
action='store_true',
|
||||
help='Print InvokeAI version number'
|
||||
)
|
||||
deprecated_group.add_argument("--laion400m")
|
||||
deprecated_group.add_argument("--weights") # deprecated
|
||||
general_group.add_argument("--version", "-V", action="store_true", help="Print InvokeAI version number")
|
||||
model_group.add_argument(
|
||||
'--root_dir',
|
||||
"--root_dir",
|
||||
default=None,
|
||||
help='Path to directory containing "models", "outputs" and "configs". If not present will read from environment variable INVOKEAI_ROOT. Defaults to ~/invokeai.',
|
||||
)
|
||||
model_group.add_argument(
|
||||
'--config',
|
||||
'-c',
|
||||
'-config',
|
||||
dest='conf',
|
||||
default='./configs/models.yaml',
|
||||
help='Path to configuration file for alternate models.',
|
||||
"--config",
|
||||
"-c",
|
||||
"-config",
|
||||
dest="conf",
|
||||
default="./configs/models.yaml",
|
||||
help="Path to configuration file for alternate models.",
|
||||
)
|
||||
model_group.add_argument(
|
||||
'--model',
|
||||
"--model",
|
||||
help='Indicates which diffusion model to load (defaults to "default" stanza in configs/models.yaml)',
|
||||
)
|
||||
model_group.add_argument(
|
||||
'--weight_dirs',
|
||||
nargs='+',
|
||||
"--weight_dirs",
|
||||
nargs="+",
|
||||
type=str,
|
||||
help='List of one or more directories that will be auto-scanned for new model weights to import',
|
||||
help="List of one or more directories that will be auto-scanned for new model weights to import",
|
||||
)
|
||||
model_group.add_argument(
|
||||
'--png_compression','-z',
|
||||
"--png_compression",
|
||||
"-z",
|
||||
type=int,
|
||||
default=6,
|
||||
choices=range(0,9),
|
||||
dest='png_compression',
|
||||
help='level of PNG compression, from 0 (none) to 9 (maximum). Default is 6.'
|
||||
choices=range(0, 9),
|
||||
dest="png_compression",
|
||||
help="level of PNG compression, from 0 (none) to 9 (maximum). Default is 6.",
|
||||
)
|
||||
model_group.add_argument(
|
||||
'-F',
|
||||
'--full_precision',
|
||||
dest='full_precision',
|
||||
action='store_true',
|
||||
help='Deprecated way to set --precision=float32',
|
||||
"-F",
|
||||
"--full_precision",
|
||||
dest="full_precision",
|
||||
action="store_true",
|
||||
help="Deprecated way to set --precision=float32",
|
||||
)
|
||||
model_group.add_argument(
|
||||
'--max_loaded_models',
|
||||
dest='max_loaded_models',
|
||||
"--max_loaded_models",
|
||||
dest="max_loaded_models",
|
||||
type=int,
|
||||
default=2,
|
||||
help='Maximum number of models to keep in memory for fast switching, including the one in GPU',
|
||||
help="Maximum number of models to keep in memory for fast switching, including the one in GPU",
|
||||
)
|
||||
model_group.add_argument(
|
||||
'--free_gpu_mem',
|
||||
dest='free_gpu_mem',
|
||||
action='store_true',
|
||||
help='Force free gpu memory before final decoding',
|
||||
"--free_gpu_mem",
|
||||
dest="free_gpu_mem",
|
||||
action="store_true",
|
||||
help="Force free gpu memory before final decoding",
|
||||
)
|
||||
model_group.add_argument(
|
||||
'--sequential_guidance',
|
||||
dest='sequential_guidance',
|
||||
action='store_true',
|
||||
help="Calculate guidance in serial instead of in parallel, lowering memory requirement "
|
||||
"at the expense of speed",
|
||||
"--sequential_guidance",
|
||||
dest="sequential_guidance",
|
||||
action="store_true",
|
||||
help="Calculate guidance in serial instead of in parallel, lowering memory requirement " "at the expense of speed",
|
||||
)
|
||||
model_group.add_argument(
|
||||
'--xformers',
|
||||
"--xformers",
|
||||
action=argparse.BooleanOptionalAction,
|
||||
default=True,
|
||||
help='Enable/disable xformers support (default enabled if installed)',
|
||||
help="Enable/disable xformers support (default enabled if installed)",
|
||||
)
|
||||
model_group.add_argument(
|
||||
"--always_use_cpu",
|
||||
dest="always_use_cpu",
|
||||
action="store_true",
|
||||
help="Force use of CPU even if GPU is available"
|
||||
"--always_use_cpu", dest="always_use_cpu", action="store_true", help="Force use of CPU even if GPU is available"
|
||||
)
|
||||
model_group.add_argument(
|
||||
'--precision',
|
||||
dest='precision',
|
||||
"--precision",
|
||||
dest="precision",
|
||||
type=str,
|
||||
choices=PRECISION_CHOICES,
|
||||
metavar='PRECISION',
|
||||
metavar="PRECISION",
|
||||
help=f'Set model precision. Defaults to auto selected based on device. Options: {", ".join(PRECISION_CHOICES)}',
|
||||
default='auto',
|
||||
default="auto",
|
||||
)
|
||||
model_group.add_argument(
|
||||
'--ckpt_convert',
|
||||
"--ckpt_convert",
|
||||
action=argparse.BooleanOptionalAction,
|
||||
dest='ckpt_convert',
|
||||
dest="ckpt_convert",
|
||||
default=True,
|
||||
help='Deprecated option. Legacy ckpt files are now always converted to diffusers when loaded.'
|
||||
help="Deprecated option. Legacy ckpt files are now always converted to diffusers when loaded.",
|
||||
)
|
||||
model_group.add_argument(
|
||||
'--internet',
|
||||
"--internet",
|
||||
action=argparse.BooleanOptionalAction,
|
||||
dest='internet_available',
|
||||
dest="internet_available",
|
||||
default=True,
|
||||
help='Indicate whether internet is available for just-in-time model downloading (default: probe automatically).',
|
||||
help="Indicate whether internet is available for just-in-time model downloading (default: probe automatically).",
|
||||
)
|
||||
model_group.add_argument(
|
||||
'--nsfw_checker',
|
||||
'--safety_checker',
|
||||
"--nsfw_checker",
|
||||
"--safety_checker",
|
||||
action=argparse.BooleanOptionalAction,
|
||||
dest='safety_checker',
|
||||
dest="safety_checker",
|
||||
default=False,
|
||||
help='Check for and blur potentially NSFW images. Use --no-nsfw_checker to disable.',
|
||||
help="Check for and blur potentially NSFW images. Use --no-nsfw_checker to disable.",
|
||||
)
|
||||
model_group.add_argument(
|
||||
'--autoimport',
|
||||
"--autoimport",
|
||||
default=None,
|
||||
type=str,
|
||||
help='Check the indicated directory for .ckpt/.safetensors weights files at startup and import directly',
|
||||
help="Check the indicated directory for .ckpt/.safetensors weights files at startup and import directly",
|
||||
)
|
||||
model_group.add_argument(
|
||||
'--autoconvert',
|
||||
"--autoconvert",
|
||||
default=None,
|
||||
type=str,
|
||||
help='Check the indicated directory for .ckpt/.safetensors weights files at startup and import as optimized diffuser models',
|
||||
help="Check the indicated directory for .ckpt/.safetensors weights files at startup and import as optimized diffuser models",
|
||||
)
|
||||
model_group.add_argument(
|
||||
'--patchmatch',
|
||||
"--patchmatch",
|
||||
action=argparse.BooleanOptionalAction,
|
||||
default=True,
|
||||
help='Load the patchmatch extension for outpainting. Use --no-patchmatch to disable.',
|
||||
help="Load the patchmatch extension for outpainting. Use --no-patchmatch to disable.",
|
||||
)
|
||||
file_group.add_argument(
|
||||
'--from_file',
|
||||
dest='infile',
|
||||
"--from_file",
|
||||
dest="infile",
|
||||
type=str,
|
||||
help='If specified, load prompts from this file',
|
||||
help="If specified, load prompts from this file",
|
||||
)
|
||||
file_group.add_argument(
|
||||
'--outdir',
|
||||
'-o',
|
||||
"--outdir",
|
||||
"-o",
|
||||
type=str,
|
||||
help='Directory to save generated images and a log of prompts and seeds. Default: ROOTDIR/outputs',
|
||||
default='outputs',
|
||||
help="Directory to save generated images and a log of prompts and seeds. Default: ROOTDIR/outputs",
|
||||
default="outputs",
|
||||
)
|
||||
file_group.add_argument(
|
||||
'--prompt_as_dir',
|
||||
'-p',
|
||||
action='store_true',
|
||||
help='Place images in subdirectories named after the prompt.',
|
||||
"--prompt_as_dir",
|
||||
"-p",
|
||||
action="store_true",
|
||||
help="Place images in subdirectories named after the prompt.",
|
||||
)
|
||||
render_group.add_argument(
|
||||
'--fnformat',
|
||||
default='{prefix}.{seed}.png',
|
||||
"--fnformat",
|
||||
default="{prefix}.{seed}.png",
|
||||
type=str,
|
||||
help='Overwrite the filename format. You can use any argument as wildcard enclosed in curly braces. Default is {prefix}.{seed}.png',
|
||||
help="Overwrite the filename format. You can use any argument as wildcard enclosed in curly braces. Default is {prefix}.{seed}.png",
|
||||
)
|
||||
render_group.add_argument("-s", "--steps", type=int, default=50, help="Number of steps")
|
||||
render_group.add_argument(
|
||||
'-s',
|
||||
'--steps',
|
||||
"-W",
|
||||
"--width",
|
||||
type=int,
|
||||
default=50,
|
||||
help='Number of steps'
|
||||
help="Image width, multiple of 64",
|
||||
)
|
||||
render_group.add_argument(
|
||||
'-W',
|
||||
'--width',
|
||||
"-H",
|
||||
"--height",
|
||||
type=int,
|
||||
help='Image width, multiple of 64',
|
||||
help="Image height, multiple of 64",
|
||||
)
|
||||
render_group.add_argument(
|
||||
'-H',
|
||||
'--height',
|
||||
type=int,
|
||||
help='Image height, multiple of 64',
|
||||
)
|
||||
render_group.add_argument(
|
||||
'-C',
|
||||
'--cfg_scale',
|
||||
"-C",
|
||||
"--cfg_scale",
|
||||
default=7.5,
|
||||
type=float,
|
||||
help='Classifier free guidance (CFG) scale - higher numbers cause generator to "try" harder.',
|
||||
)
|
||||
render_group.add_argument(
|
||||
'--sampler',
|
||||
'-A',
|
||||
'-m',
|
||||
dest='sampler_name',
|
||||
"--sampler",
|
||||
"-A",
|
||||
"-m",
|
||||
dest="sampler_name",
|
||||
type=str,
|
||||
choices=SAMPLER_CHOICES,
|
||||
metavar='SAMPLER_NAME',
|
||||
metavar="SAMPLER_NAME",
|
||||
help=f'Set the default sampler. Supported samplers: {", ".join(SAMPLER_CHOICES)}',
|
||||
default='k_lms',
|
||||
default="k_lms",
|
||||
)
|
||||
render_group.add_argument(
|
||||
'--log_tokenization',
|
||||
'-t',
|
||||
action='store_true',
|
||||
help='shows how the prompt is split into tokens'
|
||||
"--log_tokenization", "-t", action="store_true", help="shows how the prompt is split into tokens"
|
||||
)
|
||||
render_group.add_argument(
|
||||
'-f',
|
||||
'--strength',
|
||||
"-f",
|
||||
"--strength",
|
||||
type=float,
|
||||
help='img2img strength for noising/unnoising. 0.0 preserves image exactly, 1.0 replaces it completely',
|
||||
help="img2img strength for noising/unnoising. 0.0 preserves image exactly, 1.0 replaces it completely",
|
||||
)
|
||||
render_group.add_argument(
|
||||
'-T',
|
||||
'-fit',
|
||||
'--fit',
|
||||
"-T",
|
||||
"-fit",
|
||||
"--fit",
|
||||
action=argparse.BooleanOptionalAction,
|
||||
help='If specified, will resize the input image to fit within the dimensions of width x height (512x512 default)',
|
||||
help="If specified, will resize the input image to fit within the dimensions of width x height (512x512 default)",
|
||||
)
|
||||
|
||||
render_group.add_argument("--grid", "-g", action=argparse.BooleanOptionalAction, help="generate a grid")
|
||||
render_group.add_argument(
|
||||
'--grid',
|
||||
'-g',
|
||||
action=argparse.BooleanOptionalAction,
|
||||
help='generate a grid'
|
||||
)
|
||||
render_group.add_argument(
|
||||
'--embedding_directory',
|
||||
'--embedding_path',
|
||||
dest='embedding_path',
|
||||
default='embeddings',
|
||||
"--embedding_directory",
|
||||
"--embedding_path",
|
||||
dest="embedding_path",
|
||||
default="embeddings",
|
||||
type=str,
|
||||
help='Path to a directory containing .bin and/or .pt files, or a single .bin/.pt file. You may use subdirectories. (default is ROOTDIR/embeddings)'
|
||||
help="Path to a directory containing .bin and/or .pt files, or a single .bin/.pt file. You may use subdirectories. (default is ROOTDIR/embeddings)",
|
||||
)
|
||||
render_group.add_argument(
|
||||
'--lora_directory',
|
||||
dest='lora_path',
|
||||
default='loras',
|
||||
"--lora_directory",
|
||||
dest="lora_path",
|
||||
default="loras",
|
||||
type=str,
|
||||
help='Path to a directory containing LoRA files; subdirectories are not supported. (default is ROOTDIR/loras)'
|
||||
help="Path to a directory containing LoRA files; subdirectories are not supported. (default is ROOTDIR/loras)",
|
||||
)
|
||||
render_group.add_argument(
|
||||
'--embeddings',
|
||||
"--embeddings",
|
||||
action=argparse.BooleanOptionalAction,
|
||||
default=True,
|
||||
help='Enable embedding directory (default). Use --no-embeddings to disable.',
|
||||
help="Enable embedding directory (default). Use --no-embeddings to disable.",
|
||||
)
|
||||
render_group.add_argument("--enable_image_debugging", action="store_true", help="Generates debugging image to display")
|
||||
render_group.add_argument(
|
||||
'--enable_image_debugging',
|
||||
action='store_true',
|
||||
help='Generates debugging image to display'
|
||||
)
|
||||
render_group.add_argument(
|
||||
'--karras_max',
|
||||
"--karras_max",
|
||||
type=int,
|
||||
default=None,
|
||||
help="control the point at which the K* samplers will shift from using the Karras noise schedule (good for low step counts) to the LatentDiffusion noise schedule (good for high step counts). Set to 0 to use LatentDiffusion for all step values, and to a high value (e.g. 1000) to use Karras for all step values. [29]."
|
||||
help="control the point at which the K* samplers will shift from using the Karras noise schedule (good for low step counts) to the LatentDiffusion noise schedule (good for high step counts). Set to 0 to use LatentDiffusion for all step values, and to a high value (e.g. 1000) to use Karras for all step values. [29].",
|
||||
)
|
||||
# Restoration related args
|
||||
postprocessing_group.add_argument(
|
||||
'--no_restore',
|
||||
dest='restore',
|
||||
action='store_false',
|
||||
help='Disable face restoration with GFPGAN or codeformer',
|
||||
"--no_restore",
|
||||
dest="restore",
|
||||
action="store_false",
|
||||
help="Disable face restoration with GFPGAN or codeformer",
|
||||
)
|
||||
postprocessing_group.add_argument(
|
||||
'--no_upscale',
|
||||
dest='esrgan',
|
||||
action='store_false',
|
||||
help='Disable upscaling with ESRGAN',
|
||||
"--no_upscale",
|
||||
dest="esrgan",
|
||||
action="store_false",
|
||||
help="Disable upscaling with ESRGAN",
|
||||
)
|
||||
postprocessing_group.add_argument(
|
||||
'--esrgan_bg_tile',
|
||||
"--esrgan_bg_tile",
|
||||
type=int,
|
||||
default=400,
|
||||
help='Tile size for background sampler, 0 for no tile during testing. Default: 400.',
|
||||
help="Tile size for background sampler, 0 for no tile during testing. Default: 400.",
|
||||
)
|
||||
postprocessing_group.add_argument(
|
||||
'--esrgan_denoise_str',
|
||||
"--esrgan_denoise_str",
|
||||
type=float,
|
||||
default=0.75,
|
||||
help='esrgan denoise str. 0 is no denoise, 1 is max denoise. Default: 0.75',
|
||||
help="esrgan denoise str. 0 is no denoise, 1 is max denoise. Default: 0.75",
|
||||
)
|
||||
postprocessing_group.add_argument(
|
||||
'--gfpgan_model_path',
|
||||
"--gfpgan_model_path",
|
||||
type=str,
|
||||
default='./models/gfpgan/GFPGANv1.4.pth',
|
||||
help='Indicates the path to the GFPGAN model',
|
||||
default="./models/gfpgan/GFPGANv1.4.pth",
|
||||
help="Indicates the path to the GFPGAN model",
|
||||
)
|
||||
web_server_group.add_argument(
|
||||
'--web',
|
||||
dest='web',
|
||||
action='store_true',
|
||||
help='Start in web server mode.',
|
||||
"--web",
|
||||
dest="web",
|
||||
action="store_true",
|
||||
help="Start in web server mode.",
|
||||
)
|
||||
web_server_group.add_argument(
|
||||
'--web_develop',
|
||||
dest='web_develop',
|
||||
action='store_true',
|
||||
help='Start in web server development mode.',
|
||||
"--web_develop",
|
||||
dest="web_develop",
|
||||
action="store_true",
|
||||
help="Start in web server development mode.",
|
||||
)
|
||||
web_server_group.add_argument(
|
||||
"--web_verbose",
|
||||
@ -376,32 +352,27 @@ web_server_group.add_argument(
|
||||
help="Additional allowed origins, comma-separated",
|
||||
)
|
||||
web_server_group.add_argument(
|
||||
'--host',
|
||||
"--host",
|
||||
type=str,
|
||||
default='127.0.0.1',
|
||||
help='Web server: Host or IP to listen on. Set to 0.0.0.0 to accept traffic from other devices on your network.'
|
||||
default="127.0.0.1",
|
||||
help="Web server: Host or IP to listen on. Set to 0.0.0.0 to accept traffic from other devices on your network.",
|
||||
)
|
||||
web_server_group.add_argument("--port", type=int, default="9090", help="Web server: Port to listen on")
|
||||
web_server_group.add_argument(
|
||||
'--port',
|
||||
type=int,
|
||||
default='9090',
|
||||
help='Web server: Port to listen on'
|
||||
)
|
||||
web_server_group.add_argument(
|
||||
'--certfile',
|
||||
"--certfile",
|
||||
type=str,
|
||||
default=None,
|
||||
help='Web server: Path to certificate file to use for SSL. Use together with --keyfile'
|
||||
help="Web server: Path to certificate file to use for SSL. Use together with --keyfile",
|
||||
)
|
||||
web_server_group.add_argument(
|
||||
'--keyfile',
|
||||
"--keyfile",
|
||||
type=str,
|
||||
default=None,
|
||||
help='Web server: Path to private key file to use for SSL. Use together with --certfile'
|
||||
help="Web server: Path to private key file to use for SSL. Use together with --certfile",
|
||||
)
|
||||
web_server_group.add_argument(
|
||||
'--gui',
|
||||
dest='gui',
|
||||
action='store_true',
|
||||
help='Start InvokeAI GUI',
|
||||
"--gui",
|
||||
dest="gui",
|
||||
action="store_true",
|
||||
help="Start InvokeAI GUI",
|
||||
)
|
||||
|
@ -1,7 +1,7 @@
|
||||
'''
|
||||
"""
|
||||
Migrate the models directory and models.yaml file from an existing
|
||||
InvokeAI 2.3 installation to 3.0.0.
|
||||
'''
|
||||
"""
|
||||
|
||||
import os
|
||||
import argparse
|
||||
@ -29,14 +29,13 @@ from transformers import (
|
||||
import invokeai.backend.util.logging as logger
|
||||
from invokeai.app.services.config import InvokeAIAppConfig
|
||||
from invokeai.backend.model_management import ModelManager
|
||||
from invokeai.backend.model_management.model_probe import (
|
||||
ModelProbe, ModelType, BaseModelType, ModelProbeInfo
|
||||
)
|
||||
from invokeai.backend.model_management.model_probe import ModelProbe, ModelType, BaseModelType, ModelProbeInfo
|
||||
|
||||
warnings.filterwarnings("ignore")
|
||||
transformers.logging.set_verbosity_error()
|
||||
diffusers.logging.set_verbosity_error()
|
||||
|
||||
|
||||
# holder for paths that we will migrate
|
||||
@dataclass
|
||||
class ModelPaths:
|
||||
@ -45,81 +44,82 @@ class ModelPaths:
|
||||
loras: Path
|
||||
controlnets: Path
|
||||
|
||||
|
||||
class MigrateTo3(object):
|
||||
def __init__(self,
|
||||
from_root: Path,
|
||||
to_models: Path,
|
||||
model_manager: ModelManager,
|
||||
src_paths: ModelPaths,
|
||||
):
|
||||
def __init__(
|
||||
self,
|
||||
from_root: Path,
|
||||
to_models: Path,
|
||||
model_manager: ModelManager,
|
||||
src_paths: ModelPaths,
|
||||
):
|
||||
self.root_directory = from_root
|
||||
self.dest_models = to_models
|
||||
self.mgr = model_manager
|
||||
self.src_paths = src_paths
|
||||
|
||||
|
||||
@classmethod
|
||||
def initialize_yaml(cls, yaml_file: Path):
|
||||
with open(yaml_file, 'w') as file:
|
||||
file.write(
|
||||
yaml.dump(
|
||||
{
|
||||
'__metadata__': {'version':'3.0.0'}
|
||||
}
|
||||
)
|
||||
)
|
||||
|
||||
with open(yaml_file, "w") as file:
|
||||
file.write(yaml.dump({"__metadata__": {"version": "3.0.0"}}))
|
||||
|
||||
def create_directory_structure(self):
|
||||
'''
|
||||
"""
|
||||
Create the basic directory structure for the models folder.
|
||||
'''
|
||||
for model_base in [BaseModelType.StableDiffusion1,BaseModelType.StableDiffusion2]:
|
||||
for model_type in [ModelType.Main, ModelType.Vae, ModelType.Lora,
|
||||
ModelType.ControlNet,ModelType.TextualInversion]:
|
||||
"""
|
||||
for model_base in [BaseModelType.StableDiffusion1, BaseModelType.StableDiffusion2]:
|
||||
for model_type in [
|
||||
ModelType.Main,
|
||||
ModelType.Vae,
|
||||
ModelType.Lora,
|
||||
ModelType.ControlNet,
|
||||
ModelType.TextualInversion,
|
||||
]:
|
||||
path = self.dest_models / model_base.value / model_type.value
|
||||
path.mkdir(parents=True, exist_ok=True)
|
||||
path = self.dest_models / 'core'
|
||||
path = self.dest_models / "core"
|
||||
path.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
@staticmethod
|
||||
def copy_file(src:Path,dest:Path):
|
||||
'''
|
||||
def copy_file(src: Path, dest: Path):
|
||||
"""
|
||||
copy a single file with logging
|
||||
'''
|
||||
"""
|
||||
if dest.exists():
|
||||
logger.info(f'Skipping existing {str(dest)}')
|
||||
logger.info(f"Skipping existing {str(dest)}")
|
||||
return
|
||||
logger.info(f'Copying {str(src)} to {str(dest)}')
|
||||
logger.info(f"Copying {str(src)} to {str(dest)}")
|
||||
try:
|
||||
shutil.copy(src, dest)
|
||||
except Exception as e:
|
||||
logger.error(f'COPY FAILED: {str(e)}')
|
||||
logger.error(f"COPY FAILED: {str(e)}")
|
||||
|
||||
@staticmethod
|
||||
def copy_dir(src:Path,dest:Path):
|
||||
'''
|
||||
def copy_dir(src: Path, dest: Path):
|
||||
"""
|
||||
Recursively copy a directory with logging
|
||||
'''
|
||||
"""
|
||||
if dest.exists():
|
||||
logger.info(f'Skipping existing {str(dest)}')
|
||||
logger.info(f"Skipping existing {str(dest)}")
|
||||
return
|
||||
|
||||
logger.info(f'Copying {str(src)} to {str(dest)}')
|
||||
|
||||
logger.info(f"Copying {str(src)} to {str(dest)}")
|
||||
try:
|
||||
shutil.copytree(src, dest)
|
||||
except Exception as e:
|
||||
logger.error(f'COPY FAILED: {str(e)}')
|
||||
logger.error(f"COPY FAILED: {str(e)}")
|
||||
|
||||
def migrate_models(self, src_dir: Path):
|
||||
'''
|
||||
"""
|
||||
Recursively walk through src directory, probe anything
|
||||
that looks like a model, and copy the model into the
|
||||
appropriate location within the destination models directory.
|
||||
'''
|
||||
"""
|
||||
directories_scanned = set()
|
||||
for root, dirs, files in os.walk(src_dir):
|
||||
for d in dirs:
|
||||
try:
|
||||
model = Path(root,d)
|
||||
model = Path(root, d)
|
||||
info = ModelProbe().heuristic_probe(model)
|
||||
if not info:
|
||||
continue
|
||||
@ -136,9 +136,9 @@ class MigrateTo3(object):
|
||||
# don't copy raw learned_embeds.bin or pytorch_lora_weights.bin
|
||||
# let them be copied as part of a tree copy operation
|
||||
try:
|
||||
if f in {'learned_embeds.bin','pytorch_lora_weights.bin'}:
|
||||
if f in {"learned_embeds.bin", "pytorch_lora_weights.bin"}:
|
||||
continue
|
||||
model = Path(root,f)
|
||||
model = Path(root, f)
|
||||
if model.parent in directories_scanned:
|
||||
continue
|
||||
info = ModelProbe().heuristic_probe(model)
|
||||
@ -154,148 +154,146 @@ class MigrateTo3(object):
|
||||
logger.error(str(e))
|
||||
|
||||
def migrate_support_models(self):
|
||||
'''
|
||||
"""
|
||||
Copy the clipseg, upscaler, and restoration models to their new
|
||||
locations.
|
||||
'''
|
||||
"""
|
||||
dest_directory = self.dest_models
|
||||
if (self.root_directory / 'models/clipseg').exists():
|
||||
self.copy_dir(self.root_directory / 'models/clipseg', dest_directory / 'core/misc/clipseg')
|
||||
if (self.root_directory / 'models/realesrgan').exists():
|
||||
self.copy_dir(self.root_directory / 'models/realesrgan', dest_directory / 'core/upscaling/realesrgan')
|
||||
for d in ['codeformer','gfpgan']:
|
||||
path = self.root_directory / 'models' / d
|
||||
if (self.root_directory / "models/clipseg").exists():
|
||||
self.copy_dir(self.root_directory / "models/clipseg", dest_directory / "core/misc/clipseg")
|
||||
if (self.root_directory / "models/realesrgan").exists():
|
||||
self.copy_dir(self.root_directory / "models/realesrgan", dest_directory / "core/upscaling/realesrgan")
|
||||
for d in ["codeformer", "gfpgan"]:
|
||||
path = self.root_directory / "models" / d
|
||||
if path.exists():
|
||||
self.copy_dir(path,dest_directory / f'core/face_restoration/{d}')
|
||||
self.copy_dir(path, dest_directory / f"core/face_restoration/{d}")
|
||||
|
||||
def migrate_tuning_models(self):
|
||||
'''
|
||||
"""
|
||||
Migrate the embeddings, loras and controlnets directories to their new homes.
|
||||
'''
|
||||
"""
|
||||
for src in [self.src_paths.embeddings, self.src_paths.loras, self.src_paths.controlnets]:
|
||||
if not src:
|
||||
continue
|
||||
if src.is_dir():
|
||||
logger.info(f'Scanning {src}')
|
||||
logger.info(f"Scanning {src}")
|
||||
self.migrate_models(src)
|
||||
else:
|
||||
logger.info(f'{src} directory not found; skipping')
|
||||
logger.info(f"{src} directory not found; skipping")
|
||||
continue
|
||||
|
||||
def migrate_conversion_models(self):
|
||||
'''
|
||||
"""
|
||||
Migrate all the models that are needed by the ckpt_to_diffusers conversion
|
||||
script.
|
||||
'''
|
||||
"""
|
||||
|
||||
dest_directory = self.dest_models
|
||||
kwargs = dict(
|
||||
cache_dir = self.root_directory / 'models/hub',
|
||||
#local_files_only = True
|
||||
cache_dir=self.root_directory / "models/hub",
|
||||
# local_files_only = True
|
||||
)
|
||||
try:
|
||||
logger.info('Migrating core tokenizers and text encoders')
|
||||
target_dir = dest_directory / 'core' / 'convert'
|
||||
logger.info("Migrating core tokenizers and text encoders")
|
||||
target_dir = dest_directory / "core" / "convert"
|
||||
|
||||
self._migrate_pretrained(BertTokenizerFast,
|
||||
repo_id='bert-base-uncased',
|
||||
dest = target_dir / 'bert-base-uncased',
|
||||
**kwargs)
|
||||
self._migrate_pretrained(
|
||||
BertTokenizerFast, repo_id="bert-base-uncased", dest=target_dir / "bert-base-uncased", **kwargs
|
||||
)
|
||||
|
||||
# sd-1
|
||||
repo_id = 'openai/clip-vit-large-patch14'
|
||||
self._migrate_pretrained(CLIPTokenizer,
|
||||
repo_id= repo_id,
|
||||
dest= target_dir / 'clip-vit-large-patch14',
|
||||
**kwargs)
|
||||
self._migrate_pretrained(CLIPTextModel,
|
||||
repo_id = repo_id,
|
||||
dest = target_dir / 'clip-vit-large-patch14',
|
||||
force = True,
|
||||
**kwargs)
|
||||
repo_id = "openai/clip-vit-large-patch14"
|
||||
self._migrate_pretrained(
|
||||
CLIPTokenizer, repo_id=repo_id, dest=target_dir / "clip-vit-large-patch14", **kwargs
|
||||
)
|
||||
self._migrate_pretrained(
|
||||
CLIPTextModel, repo_id=repo_id, dest=target_dir / "clip-vit-large-patch14", force=True, **kwargs
|
||||
)
|
||||
|
||||
# sd-2
|
||||
repo_id = "stabilityai/stable-diffusion-2"
|
||||
self._migrate_pretrained(CLIPTokenizer,
|
||||
repo_id = repo_id,
|
||||
dest = target_dir / 'stable-diffusion-2-clip' / 'tokenizer',
|
||||
**{'subfolder':'tokenizer',**kwargs}
|
||||
)
|
||||
self._migrate_pretrained(CLIPTextModel,
|
||||
repo_id = repo_id,
|
||||
dest = target_dir / 'stable-diffusion-2-clip' / 'text_encoder',
|
||||
**{'subfolder':'text_encoder',**kwargs}
|
||||
)
|
||||
self._migrate_pretrained(
|
||||
CLIPTokenizer,
|
||||
repo_id=repo_id,
|
||||
dest=target_dir / "stable-diffusion-2-clip" / "tokenizer",
|
||||
**{"subfolder": "tokenizer", **kwargs},
|
||||
)
|
||||
self._migrate_pretrained(
|
||||
CLIPTextModel,
|
||||
repo_id=repo_id,
|
||||
dest=target_dir / "stable-diffusion-2-clip" / "text_encoder",
|
||||
**{"subfolder": "text_encoder", **kwargs},
|
||||
)
|
||||
|
||||
# VAE
|
||||
logger.info('Migrating stable diffusion VAE')
|
||||
self._migrate_pretrained(AutoencoderKL,
|
||||
repo_id = 'stabilityai/sd-vae-ft-mse',
|
||||
dest = target_dir / 'sd-vae-ft-mse',
|
||||
**kwargs)
|
||||
|
||||
logger.info("Migrating stable diffusion VAE")
|
||||
self._migrate_pretrained(
|
||||
AutoencoderKL, repo_id="stabilityai/sd-vae-ft-mse", dest=target_dir / "sd-vae-ft-mse", **kwargs
|
||||
)
|
||||
|
||||
# safety checking
|
||||
logger.info('Migrating safety checker')
|
||||
logger.info("Migrating safety checker")
|
||||
repo_id = "CompVis/stable-diffusion-safety-checker"
|
||||
self._migrate_pretrained(AutoFeatureExtractor,
|
||||
repo_id = repo_id,
|
||||
dest = target_dir / 'stable-diffusion-safety-checker',
|
||||
**kwargs)
|
||||
self._migrate_pretrained(StableDiffusionSafetyChecker,
|
||||
repo_id = repo_id,
|
||||
dest = target_dir / 'stable-diffusion-safety-checker',
|
||||
**kwargs)
|
||||
self._migrate_pretrained(
|
||||
AutoFeatureExtractor, repo_id=repo_id, dest=target_dir / "stable-diffusion-safety-checker", **kwargs
|
||||
)
|
||||
self._migrate_pretrained(
|
||||
StableDiffusionSafetyChecker,
|
||||
repo_id=repo_id,
|
||||
dest=target_dir / "stable-diffusion-safety-checker",
|
||||
**kwargs,
|
||||
)
|
||||
except KeyboardInterrupt:
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.error(str(e))
|
||||
|
||||
def _model_probe_to_path(self, info: ModelProbeInfo)->Path:
|
||||
def _model_probe_to_path(self, info: ModelProbeInfo) -> Path:
|
||||
return Path(self.dest_models, info.base_type.value, info.model_type.value)
|
||||
|
||||
def _migrate_pretrained(self, model_class, repo_id: str, dest: Path, force:bool=False, **kwargs):
|
||||
def _migrate_pretrained(self, model_class, repo_id: str, dest: Path, force: bool = False, **kwargs):
|
||||
if dest.exists() and not force:
|
||||
logger.info(f'Skipping existing {dest}')
|
||||
logger.info(f"Skipping existing {dest}")
|
||||
return
|
||||
model = model_class.from_pretrained(repo_id, **kwargs)
|
||||
self._save_pretrained(model, dest, overwrite=force)
|
||||
|
||||
def _save_pretrained(self, model, dest: Path, overwrite: bool=False):
|
||||
def _save_pretrained(self, model, dest: Path, overwrite: bool = False):
|
||||
model_name = dest.name
|
||||
if overwrite:
|
||||
model.save_pretrained(dest, safe_serialization=True)
|
||||
else:
|
||||
download_path = dest.with_name(f'{model_name}.downloading')
|
||||
download_path = dest.with_name(f"{model_name}.downloading")
|
||||
model.save_pretrained(download_path, safe_serialization=True)
|
||||
download_path.replace(dest)
|
||||
|
||||
def _download_vae(self, repo_id: str, subfolder:str=None)->Path:
|
||||
vae = AutoencoderKL.from_pretrained(repo_id, cache_dir=self.root_directory / 'models/hub', subfolder=subfolder)
|
||||
def _download_vae(self, repo_id: str, subfolder: str = None) -> Path:
|
||||
vae = AutoencoderKL.from_pretrained(repo_id, cache_dir=self.root_directory / "models/hub", subfolder=subfolder)
|
||||
info = ModelProbe().heuristic_probe(vae)
|
||||
_, model_name = repo_id.split('/')
|
||||
_, model_name = repo_id.split("/")
|
||||
dest = self._model_probe_to_path(info) / self.unique_name(model_name, info)
|
||||
vae.save_pretrained(dest, safe_serialization=True)
|
||||
return dest
|
||||
|
||||
def _vae_path(self, vae: Union[str,dict])->Path:
|
||||
'''
|
||||
def _vae_path(self, vae: Union[str, dict]) -> Path:
|
||||
"""
|
||||
Convert 2.3 VAE stanza to a straight path.
|
||||
'''
|
||||
"""
|
||||
vae_path = None
|
||||
|
||||
|
||||
# First get a path
|
||||
if isinstance(vae,str):
|
||||
if isinstance(vae, str):
|
||||
vae_path = vae
|
||||
|
||||
elif isinstance(vae,DictConfig):
|
||||
if p := vae.get('path'):
|
||||
elif isinstance(vae, DictConfig):
|
||||
if p := vae.get("path"):
|
||||
vae_path = p
|
||||
elif repo_id := vae.get('repo_id'):
|
||||
if repo_id=='stabilityai/sd-vae-ft-mse': # this guy is already downloaded
|
||||
vae_path = 'models/core/convert/sd-vae-ft-mse'
|
||||
elif repo_id := vae.get("repo_id"):
|
||||
if repo_id == "stabilityai/sd-vae-ft-mse": # this guy is already downloaded
|
||||
vae_path = "models/core/convert/sd-vae-ft-mse"
|
||||
return vae_path
|
||||
else:
|
||||
vae_path = self._download_vae(repo_id, vae.get('subfolder'))
|
||||
vae_path = self._download_vae(repo_id, vae.get("subfolder"))
|
||||
|
||||
assert vae_path is not None, "Couldn't find VAE for this model"
|
||||
|
||||
@ -307,152 +305,144 @@ class MigrateTo3(object):
|
||||
dest = self._model_probe_to_path(info) / vae_path.name
|
||||
if not dest.exists():
|
||||
if vae_path.is_dir():
|
||||
self.copy_dir(vae_path,dest)
|
||||
self.copy_dir(vae_path, dest)
|
||||
else:
|
||||
self.copy_file(vae_path,dest)
|
||||
self.copy_file(vae_path, dest)
|
||||
vae_path = dest
|
||||
|
||||
if vae_path.is_relative_to(self.dest_models):
|
||||
rel_path = vae_path.relative_to(self.dest_models)
|
||||
return Path('models',rel_path)
|
||||
return Path("models", rel_path)
|
||||
else:
|
||||
return vae_path
|
||||
|
||||
def migrate_repo_id(self, repo_id: str, model_name: str=None, **extra_config):
|
||||
'''
|
||||
def migrate_repo_id(self, repo_id: str, model_name: str = None, **extra_config):
|
||||
"""
|
||||
Migrate a locally-cached diffusers pipeline identified with a repo_id
|
||||
'''
|
||||
"""
|
||||
dest_dir = self.dest_models
|
||||
|
||||
cache = self.root_directory / 'models/hub'
|
||||
|
||||
cache = self.root_directory / "models/hub"
|
||||
kwargs = dict(
|
||||
cache_dir = cache,
|
||||
safety_checker = None,
|
||||
cache_dir=cache,
|
||||
safety_checker=None,
|
||||
# local_files_only = True,
|
||||
)
|
||||
|
||||
owner,repo_name = repo_id.split('/')
|
||||
owner, repo_name = repo_id.split("/")
|
||||
model_name = model_name or repo_name
|
||||
model = cache / '--'.join(['models',owner,repo_name])
|
||||
|
||||
if len(list(model.glob('snapshots/**/model_index.json')))==0:
|
||||
model = cache / "--".join(["models", owner, repo_name])
|
||||
|
||||
if len(list(model.glob("snapshots/**/model_index.json"))) == 0:
|
||||
return
|
||||
revisions = [x.name for x in model.glob('refs/*')]
|
||||
revisions = [x.name for x in model.glob("refs/*")]
|
||||
|
||||
# if an fp16 is available we use that
|
||||
revision = 'fp16' if len(revisions) > 1 and 'fp16' in revisions else revisions[0]
|
||||
pipeline = StableDiffusionPipeline.from_pretrained(
|
||||
repo_id,
|
||||
revision=revision,
|
||||
**kwargs)
|
||||
revision = "fp16" if len(revisions) > 1 and "fp16" in revisions else revisions[0]
|
||||
pipeline = StableDiffusionPipeline.from_pretrained(repo_id, revision=revision, **kwargs)
|
||||
|
||||
info = ModelProbe().heuristic_probe(pipeline)
|
||||
if not info:
|
||||
return
|
||||
|
||||
if self.mgr.model_exists(model_name, info.base_type, info.model_type):
|
||||
logger.warning(f'A model named {model_name} already exists at the destination. Skipping migration.')
|
||||
logger.warning(f"A model named {model_name} already exists at the destination. Skipping migration.")
|
||||
return
|
||||
|
||||
dest = self._model_probe_to_path(info) / model_name
|
||||
self._save_pretrained(pipeline, dest)
|
||||
|
||||
rel_path = Path('models',dest.relative_to(dest_dir))
|
||||
|
||||
rel_path = Path("models", dest.relative_to(dest_dir))
|
||||
self._add_model(model_name, info, rel_path, **extra_config)
|
||||
|
||||
def migrate_path(self, location: Path, model_name: str=None, **extra_config):
|
||||
'''
|
||||
def migrate_path(self, location: Path, model_name: str = None, **extra_config):
|
||||
"""
|
||||
Migrate a model referred to using 'weights' or 'path'
|
||||
'''
|
||||
"""
|
||||
|
||||
# handle relative paths
|
||||
dest_dir = self.dest_models
|
||||
location = self.root_directory / location
|
||||
model_name = model_name or location.stem
|
||||
|
||||
|
||||
info = ModelProbe().heuristic_probe(location)
|
||||
if not info:
|
||||
return
|
||||
|
||||
|
||||
if self.mgr.model_exists(model_name, info.base_type, info.model_type):
|
||||
logger.warning(f'A model named {model_name} already exists at the destination. Skipping migration.')
|
||||
logger.warning(f"A model named {model_name} already exists at the destination. Skipping migration.")
|
||||
return
|
||||
|
||||
# uh oh, weights is in the old models directory - move it into the new one
|
||||
if Path(location).is_relative_to(self.src_paths.models):
|
||||
dest = Path(dest_dir, info.base_type.value, info.model_type.value, location.name)
|
||||
if location.is_dir():
|
||||
self.copy_dir(location,dest)
|
||||
self.copy_dir(location, dest)
|
||||
else:
|
||||
self.copy_file(location,dest)
|
||||
location = Path('models', info.base_type.value, info.model_type.value, location.name)
|
||||
self.copy_file(location, dest)
|
||||
location = Path("models", info.base_type.value, info.model_type.value, location.name)
|
||||
|
||||
self._add_model(model_name, info, location, **extra_config)
|
||||
|
||||
def _add_model(self,
|
||||
model_name: str,
|
||||
info: ModelProbeInfo,
|
||||
location: Path,
|
||||
**extra_config):
|
||||
def _add_model(self, model_name: str, info: ModelProbeInfo, location: Path, **extra_config):
|
||||
if info.model_type != ModelType.Main:
|
||||
return
|
||||
|
||||
self.mgr.add_model(
|
||||
model_name = model_name,
|
||||
base_model = info.base_type,
|
||||
model_type = info.model_type,
|
||||
clobber = True,
|
||||
model_attributes = {
|
||||
'path': str(location),
|
||||
'description': f'A {info.base_type.value} {info.model_type.value} model',
|
||||
'model_format': info.format,
|
||||
'variant': info.variant_type.value,
|
||||
**extra_config,
|
||||
}
|
||||
)
|
||||
|
||||
def migrate_defined_models(self):
|
||||
'''
|
||||
Migrate models defined in models.yaml
|
||||
'''
|
||||
# find any models referred to in old models.yaml
|
||||
conf = OmegaConf.load(self.root_directory / 'configs/models.yaml')
|
||||
|
||||
for model_name, stanza in conf.items():
|
||||
|
||||
self.mgr.add_model(
|
||||
model_name=model_name,
|
||||
base_model=info.base_type,
|
||||
model_type=info.model_type,
|
||||
clobber=True,
|
||||
model_attributes={
|
||||
"path": str(location),
|
||||
"description": f"A {info.base_type.value} {info.model_type.value} model",
|
||||
"model_format": info.format,
|
||||
"variant": info.variant_type.value,
|
||||
**extra_config,
|
||||
},
|
||||
)
|
||||
|
||||
def migrate_defined_models(self):
|
||||
"""
|
||||
Migrate models defined in models.yaml
|
||||
"""
|
||||
# find any models referred to in old models.yaml
|
||||
conf = OmegaConf.load(self.root_directory / "configs/models.yaml")
|
||||
|
||||
for model_name, stanza in conf.items():
|
||||
try:
|
||||
passthru_args = {}
|
||||
|
||||
if vae := stanza.get('vae'):
|
||||
|
||||
if vae := stanza.get("vae"):
|
||||
try:
|
||||
passthru_args['vae'] = str(self._vae_path(vae))
|
||||
passthru_args["vae"] = str(self._vae_path(vae))
|
||||
except Exception as e:
|
||||
logger.warning(f'Could not find a VAE matching "{vae}" for model "{model_name}"')
|
||||
logger.warning(str(e))
|
||||
|
||||
if config := stanza.get('config'):
|
||||
passthru_args['config'] = config
|
||||
if config := stanza.get("config"):
|
||||
passthru_args["config"] = config
|
||||
|
||||
if description:= stanza.get('description'):
|
||||
passthru_args['description'] = description
|
||||
|
||||
if repo_id := stanza.get('repo_id'):
|
||||
logger.info(f'Migrating diffusers model {model_name}')
|
||||
if description := stanza.get("description"):
|
||||
passthru_args["description"] = description
|
||||
|
||||
if repo_id := stanza.get("repo_id"):
|
||||
logger.info(f"Migrating diffusers model {model_name}")
|
||||
self.migrate_repo_id(repo_id, model_name, **passthru_args)
|
||||
|
||||
elif location := stanza.get('weights'):
|
||||
logger.info(f'Migrating checkpoint model {model_name}')
|
||||
elif location := stanza.get("weights"):
|
||||
logger.info(f"Migrating checkpoint model {model_name}")
|
||||
self.migrate_path(Path(location), model_name, **passthru_args)
|
||||
|
||||
elif location := stanza.get('path'):
|
||||
logger.info(f'Migrating diffusers model {model_name}')
|
||||
|
||||
elif location := stanza.get("path"):
|
||||
logger.info(f"Migrating diffusers model {model_name}")
|
||||
self.migrate_path(Path(location), model_name, **passthru_args)
|
||||
|
||||
|
||||
except KeyboardInterrupt:
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.error(str(e))
|
||||
|
||||
|
||||
def migrate(self):
|
||||
self.create_directory_structure()
|
||||
# the configure script is doing this
|
||||
@ -461,67 +451,71 @@ class MigrateTo3(object):
|
||||
self.migrate_tuning_models()
|
||||
self.migrate_defined_models()
|
||||
|
||||
def _parse_legacy_initfile(root: Path, initfile: Path)->ModelPaths:
|
||||
'''
|
||||
|
||||
def _parse_legacy_initfile(root: Path, initfile: Path) -> ModelPaths:
|
||||
"""
|
||||
Returns tuple of (embedding_path, lora_path, controlnet_path)
|
||||
'''
|
||||
parser = argparse.ArgumentParser(fromfile_prefix_chars='@')
|
||||
"""
|
||||
parser = argparse.ArgumentParser(fromfile_prefix_chars="@")
|
||||
parser.add_argument(
|
||||
'--embedding_directory',
|
||||
'--embedding_path',
|
||||
"--embedding_directory",
|
||||
"--embedding_path",
|
||||
type=Path,
|
||||
dest='embedding_path',
|
||||
default=Path('embeddings'),
|
||||
dest="embedding_path",
|
||||
default=Path("embeddings"),
|
||||
)
|
||||
parser.add_argument(
|
||||
'--lora_directory',
|
||||
dest='lora_path',
|
||||
"--lora_directory",
|
||||
dest="lora_path",
|
||||
type=Path,
|
||||
default=Path('loras'),
|
||||
default=Path("loras"),
|
||||
)
|
||||
opt,_ = parser.parse_known_args([f'@{str(initfile)}'])
|
||||
opt, _ = parser.parse_known_args([f"@{str(initfile)}"])
|
||||
return ModelPaths(
|
||||
models = root / 'models',
|
||||
embeddings = root / str(opt.embedding_path).strip('"'),
|
||||
loras = root / str(opt.lora_path).strip('"'),
|
||||
controlnets = root / 'controlnets',
|
||||
models=root / "models",
|
||||
embeddings=root / str(opt.embedding_path).strip('"'),
|
||||
loras=root / str(opt.lora_path).strip('"'),
|
||||
controlnets=root / "controlnets",
|
||||
)
|
||||
|
||||
def _parse_legacy_yamlfile(root: Path, initfile: Path)->ModelPaths:
|
||||
'''
|
||||
|
||||
def _parse_legacy_yamlfile(root: Path, initfile: Path) -> ModelPaths:
|
||||
"""
|
||||
Returns tuple of (embedding_path, lora_path, controlnet_path)
|
||||
'''
|
||||
"""
|
||||
# Don't use the config object because it is unforgiving of version updates
|
||||
# Just use omegaconf directly
|
||||
opt = OmegaConf.load(initfile)
|
||||
paths = opt.InvokeAI.Paths
|
||||
models = paths.get('models_dir','models')
|
||||
embeddings = paths.get('embedding_dir','embeddings')
|
||||
loras = paths.get('lora_dir','loras')
|
||||
controlnets = paths.get('controlnet_dir','controlnets')
|
||||
models = paths.get("models_dir", "models")
|
||||
embeddings = paths.get("embedding_dir", "embeddings")
|
||||
loras = paths.get("lora_dir", "loras")
|
||||
controlnets = paths.get("controlnet_dir", "controlnets")
|
||||
return ModelPaths(
|
||||
models = root / models,
|
||||
embeddings = root / embeddings,
|
||||
loras = root /loras,
|
||||
controlnets = root / controlnets,
|
||||
models=root / models,
|
||||
embeddings=root / embeddings,
|
||||
loras=root / loras,
|
||||
controlnets=root / controlnets,
|
||||
)
|
||||
|
||||
|
||||
|
||||
def get_legacy_embeddings(root: Path) -> ModelPaths:
|
||||
path = root / 'invokeai.init'
|
||||
path = root / "invokeai.init"
|
||||
if path.exists():
|
||||
return _parse_legacy_initfile(root, path)
|
||||
path = root / 'invokeai.yaml'
|
||||
path = root / "invokeai.yaml"
|
||||
if path.exists():
|
||||
return _parse_legacy_yamlfile(root, path)
|
||||
|
||||
|
||||
def do_migrate(src_directory: Path, dest_directory: Path):
|
||||
"""
|
||||
Migrate models from src to dest InvokeAI root directories
|
||||
"""
|
||||
config_file = dest_directory / 'configs' / 'models.yaml.3'
|
||||
dest_models = dest_directory / 'models.3'
|
||||
|
||||
version_3 = (dest_directory / 'models' / 'core').exists()
|
||||
config_file = dest_directory / "configs" / "models.yaml.3"
|
||||
dest_models = dest_directory / "models.3"
|
||||
|
||||
version_3 = (dest_directory / "models" / "core").exists()
|
||||
|
||||
# Here we create the destination models.yaml file.
|
||||
# If we are writing into a version 3 directory and the
|
||||
@ -530,80 +524,80 @@ def do_migrate(src_directory: Path, dest_directory: Path):
|
||||
# create a new empty one.
|
||||
if version_3: # write into the dest directory
|
||||
try:
|
||||
shutil.copy(dest_directory / 'configs' / 'models.yaml', config_file)
|
||||
shutil.copy(dest_directory / "configs" / "models.yaml", config_file)
|
||||
except:
|
||||
MigrateTo3.initialize_yaml(config_file)
|
||||
mgr = ModelManager(config_file) # important to initialize BEFORE moving the models directory
|
||||
(dest_directory / 'models').replace(dest_models)
|
||||
mgr = ModelManager(config_file) # important to initialize BEFORE moving the models directory
|
||||
(dest_directory / "models").replace(dest_models)
|
||||
else:
|
||||
MigrateTo3.initialize_yaml(config_file)
|
||||
mgr = ModelManager(config_file)
|
||||
|
||||
|
||||
paths = get_legacy_embeddings(src_directory)
|
||||
migrator = MigrateTo3(
|
||||
from_root = src_directory,
|
||||
to_models = dest_models,
|
||||
model_manager = mgr,
|
||||
src_paths = paths
|
||||
)
|
||||
migrator = MigrateTo3(from_root=src_directory, to_models=dest_models, model_manager=mgr, src_paths=paths)
|
||||
migrator.migrate()
|
||||
print("Migration successful.")
|
||||
|
||||
if not version_3:
|
||||
(dest_directory / 'models').replace(src_directory / 'models.orig')
|
||||
print(f'Original models directory moved to {dest_directory}/models.orig')
|
||||
|
||||
(dest_directory / 'configs' / 'models.yaml').replace(src_directory / 'configs' / 'models.yaml.orig')
|
||||
print(f'Original models.yaml file moved to {dest_directory}/configs/models.yaml.orig')
|
||||
|
||||
config_file.replace(config_file.with_suffix(''))
|
||||
dest_models.replace(dest_models.with_suffix(''))
|
||||
|
||||
(dest_directory / "models").replace(src_directory / "models.orig")
|
||||
print(f"Original models directory moved to {dest_directory}/models.orig")
|
||||
|
||||
(dest_directory / "configs" / "models.yaml").replace(src_directory / "configs" / "models.yaml.orig")
|
||||
print(f"Original models.yaml file moved to {dest_directory}/configs/models.yaml.orig")
|
||||
|
||||
config_file.replace(config_file.with_suffix(""))
|
||||
dest_models.replace(dest_models.with_suffix(""))
|
||||
|
||||
|
||||
def main():
|
||||
parser = argparse.ArgumentParser(prog="invokeai-migrate3",
|
||||
description="""
|
||||
parser = argparse.ArgumentParser(
|
||||
prog="invokeai-migrate3",
|
||||
description="""
|
||||
This will copy and convert the models directory and the configs/models.yaml from the InvokeAI 2.3 format
|
||||
'--from-directory' root to the InvokeAI 3.0 '--to-directory' root. These may be abbreviated '--from' and '--to'.a
|
||||
|
||||
The old models directory and config file will be renamed 'models.orig' and 'models.yaml.orig' respectively.
|
||||
It is safe to provide the same directory for both arguments, but it is better to use the invokeai_configure
|
||||
script, which will perform a full upgrade in place."""
|
||||
)
|
||||
parser.add_argument('--from-directory',
|
||||
dest='src_root',
|
||||
type=Path,
|
||||
required=True,
|
||||
help='Source InvokeAI 2.3 root directory (containing "invokeai.init" or "invokeai.yaml")'
|
||||
)
|
||||
parser.add_argument('--to-directory',
|
||||
dest='dest_root',
|
||||
type=Path,
|
||||
required=True,
|
||||
help='Destination InvokeAI 3.0 directory (containing "invokeai.yaml")'
|
||||
)
|
||||
script, which will perform a full upgrade in place.""",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--from-directory",
|
||||
dest="src_root",
|
||||
type=Path,
|
||||
required=True,
|
||||
help='Source InvokeAI 2.3 root directory (containing "invokeai.init" or "invokeai.yaml")',
|
||||
)
|
||||
parser.add_argument(
|
||||
"--to-directory",
|
||||
dest="dest_root",
|
||||
type=Path,
|
||||
required=True,
|
||||
help='Destination InvokeAI 3.0 directory (containing "invokeai.yaml")',
|
||||
)
|
||||
args = parser.parse_args()
|
||||
src_root = args.src_root
|
||||
assert src_root.is_dir(), f"{src_root} is not a valid directory"
|
||||
assert (src_root / 'models').is_dir(), f"{src_root} does not contain a 'models' subdirectory"
|
||||
assert (src_root / 'models' / 'hub').exists(), f"{src_root} does not contain a version 2.3 models directory"
|
||||
assert (src_root / 'invokeai.init').exists() or (src_root / 'invokeai.yaml').exists(), f"{src_root} does not contain an InvokeAI init file."
|
||||
assert (src_root / "models").is_dir(), f"{src_root} does not contain a 'models' subdirectory"
|
||||
assert (src_root / "models" / "hub").exists(), f"{src_root} does not contain a version 2.3 models directory"
|
||||
assert (src_root / "invokeai.init").exists() or (
|
||||
src_root / "invokeai.yaml"
|
||||
).exists(), f"{src_root} does not contain an InvokeAI init file."
|
||||
|
||||
dest_root = args.dest_root
|
||||
assert dest_root.is_dir(), f"{dest_root} is not a valid directory"
|
||||
config = InvokeAIAppConfig.get_config()
|
||||
config.parse_args(['--root',str(dest_root)])
|
||||
config.parse_args(["--root", str(dest_root)])
|
||||
|
||||
# TODO: revisit - don't rely on invokeai.yaml to exist yet!
|
||||
dest_is_setup = (dest_root / 'models/core').exists() and (dest_root / 'databases').exists()
|
||||
dest_is_setup = (dest_root / "models/core").exists() and (dest_root / "databases").exists()
|
||||
if not dest_is_setup:
|
||||
import invokeai.frontend.install.invokeai_configure
|
||||
from invokeai.backend.install.invokeai_configure import initialize_rootdir
|
||||
|
||||
initialize_rootdir(dest_root, True)
|
||||
|
||||
do_migrate(src_root,dest_root)
|
||||
do_migrate(src_root, dest_root)
|
||||
|
||||
if __name__ == '__main__':
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
|
||||
|
||||
|
||||
|
@ -4,7 +4,7 @@ Utility (backend) functions used by model_install.py
|
||||
import os
|
||||
import shutil
|
||||
import warnings
|
||||
from dataclasses import dataclass,field
|
||||
from dataclasses import dataclass, field
|
||||
from pathlib import Path
|
||||
from tempfile import TemporaryDirectory
|
||||
from typing import List, Dict, Callable, Union, Set
|
||||
@ -28,7 +28,7 @@ warnings.filterwarnings("ignore")
|
||||
|
||||
# --------------------------globals-----------------------
|
||||
config = InvokeAIAppConfig.get_config()
|
||||
logger = InvokeAILogger.getLogger(name='InvokeAI')
|
||||
logger = InvokeAILogger.getLogger(name="InvokeAI")
|
||||
|
||||
# the initial "configs" dir is now bundled in the `invokeai.configs` package
|
||||
Dataset_path = Path(configs.__path__[0]) / "INITIAL_MODELS.yaml"
|
||||
@ -45,51 +45,63 @@ Config_preamble = """
|
||||
|
||||
LEGACY_CONFIGS = {
|
||||
BaseModelType.StableDiffusion1: {
|
||||
ModelVariantType.Normal: 'v1-inference.yaml',
|
||||
ModelVariantType.Inpaint: 'v1-inpainting-inference.yaml',
|
||||
ModelVariantType.Normal: "v1-inference.yaml",
|
||||
ModelVariantType.Inpaint: "v1-inpainting-inference.yaml",
|
||||
},
|
||||
|
||||
BaseModelType.StableDiffusion2: {
|
||||
ModelVariantType.Normal: {
|
||||
SchedulerPredictionType.Epsilon: 'v2-inference.yaml',
|
||||
SchedulerPredictionType.VPrediction: 'v2-inference-v.yaml',
|
||||
SchedulerPredictionType.Epsilon: "v2-inference.yaml",
|
||||
SchedulerPredictionType.VPrediction: "v2-inference-v.yaml",
|
||||
},
|
||||
ModelVariantType.Inpaint: {
|
||||
SchedulerPredictionType.Epsilon: 'v2-inpainting-inference.yaml',
|
||||
SchedulerPredictionType.VPrediction: 'v2-inpainting-inference-v.yaml',
|
||||
}
|
||||
}
|
||||
SchedulerPredictionType.Epsilon: "v2-inpainting-inference.yaml",
|
||||
SchedulerPredictionType.VPrediction: "v2-inpainting-inference-v.yaml",
|
||||
},
|
||||
},
|
||||
BaseModelType.StableDiffusionXL: {
|
||||
ModelVariantType.Normal: "sd_xl_base.yaml",
|
||||
},
|
||||
BaseModelType.StableDiffusionXLRefiner: {
|
||||
ModelVariantType.Normal: "sd_xl_refiner.yaml",
|
||||
},
|
||||
}
|
||||
|
||||
|
||||
@dataclass
|
||||
class ModelInstallList:
|
||||
'''Class for listing models to be installed/removed'''
|
||||
"""Class for listing models to be installed/removed"""
|
||||
|
||||
install_models: List[str] = field(default_factory=list)
|
||||
remove_models: List[str] = field(default_factory=list)
|
||||
|
||||
@dataclass
|
||||
class InstallSelections():
|
||||
install_models: List[str]= field(default_factory=list)
|
||||
remove_models: List[str]=field(default_factory=list)
|
||||
|
||||
@dataclass
|
||||
class ModelLoadInfo():
|
||||
class InstallSelections:
|
||||
install_models: List[str] = field(default_factory=list)
|
||||
remove_models: List[str] = field(default_factory=list)
|
||||
|
||||
|
||||
@dataclass
|
||||
class ModelLoadInfo:
|
||||
name: str
|
||||
model_type: ModelType
|
||||
base_type: BaseModelType
|
||||
path: Path = None
|
||||
repo_id: str = None
|
||||
description: str = ''
|
||||
description: str = ""
|
||||
installed: bool = False
|
||||
recommended: bool = False
|
||||
default: bool = False
|
||||
|
||||
|
||||
class ModelInstall(object):
|
||||
def __init__(self,
|
||||
config:InvokeAIAppConfig,
|
||||
prediction_type_helper: Callable[[Path],SchedulerPredictionType]=None,
|
||||
model_manager: ModelManager = None,
|
||||
access_token:str = None):
|
||||
def __init__(
|
||||
self,
|
||||
config: InvokeAIAppConfig,
|
||||
prediction_type_helper: Callable[[Path], SchedulerPredictionType] = None,
|
||||
model_manager: ModelManager = None,
|
||||
access_token: str = None,
|
||||
):
|
||||
self.config = config
|
||||
self.mgr = model_manager or ModelManager(config.model_conf_path)
|
||||
self.datasets = OmegaConf.load(Dataset_path)
|
||||
@ -97,65 +109,66 @@ class ModelInstall(object):
|
||||
self.access_token = access_token or HfFolder.get_token()
|
||||
self.reverse_paths = self._reverse_paths(self.datasets)
|
||||
|
||||
def all_models(self)->Dict[str,ModelLoadInfo]:
|
||||
'''
|
||||
def all_models(self) -> Dict[str, ModelLoadInfo]:
|
||||
"""
|
||||
Return dict of model_key=>ModelLoadInfo objects.
|
||||
This method consolidates and simplifies the entries in both
|
||||
models.yaml and INITIAL_MODELS.yaml so that they can
|
||||
be treated uniformly. It also sorts the models alphabetically
|
||||
by their name, to improve the display somewhat.
|
||||
'''
|
||||
"""
|
||||
model_dict = dict()
|
||||
|
||||
|
||||
# first populate with the entries in INITIAL_MODELS.yaml
|
||||
for key, value in self.datasets.items():
|
||||
name,base,model_type = ModelManager.parse_key(key)
|
||||
value['name'] = name
|
||||
value['base_type'] = base
|
||||
value['model_type'] = model_type
|
||||
name, base, model_type = ModelManager.parse_key(key)
|
||||
value["name"] = name
|
||||
value["base_type"] = base
|
||||
value["model_type"] = model_type
|
||||
model_dict[key] = ModelLoadInfo(**value)
|
||||
|
||||
# supplement with entries in models.yaml
|
||||
installed_models = self.mgr.list_models()
|
||||
|
||||
|
||||
for md in installed_models:
|
||||
base = md['base_model']
|
||||
model_type = md['model_type']
|
||||
name = md['model_name']
|
||||
base = md["base_model"]
|
||||
model_type = md["model_type"]
|
||||
name = md["model_name"]
|
||||
key = ModelManager.create_key(name, base, model_type)
|
||||
if key in model_dict:
|
||||
model_dict[key].installed = True
|
||||
else:
|
||||
model_dict[key] = ModelLoadInfo(
|
||||
name = name,
|
||||
base_type = base,
|
||||
model_type = model_type,
|
||||
path = value.get('path'),
|
||||
installed = True,
|
||||
name=name,
|
||||
base_type=base,
|
||||
model_type=model_type,
|
||||
path=value.get("path"),
|
||||
installed=True,
|
||||
)
|
||||
return {x : model_dict[x] for x in sorted(model_dict.keys(),key=lambda y: model_dict[y].name.lower())}
|
||||
return {x: model_dict[x] for x in sorted(model_dict.keys(), key=lambda y: model_dict[y].name.lower())}
|
||||
|
||||
def list_models(self, model_type):
|
||||
installed = self.mgr.list_models(model_type=model_type)
|
||||
print(f'Installed models of type `{model_type}`:')
|
||||
print(f"Installed models of type `{model_type}`:")
|
||||
for i in installed:
|
||||
print(f"{i['model_name']}\t{i['base_model']}\t{i['path']}")
|
||||
|
||||
def starter_models(self)->Set[str]:
|
||||
# logic here a little reversed to maintain backward compatibility
|
||||
def starter_models(self, all_models: bool = False) -> Set[str]:
|
||||
models = set()
|
||||
for key, value in self.datasets.items():
|
||||
name,base,model_type = ModelManager.parse_key(key)
|
||||
if model_type==ModelType.Main:
|
||||
name, base, model_type = ModelManager.parse_key(key)
|
||||
if all_models or model_type in [ModelType.Main, ModelType.Vae]:
|
||||
models.add(key)
|
||||
return models
|
||||
|
||||
def recommended_models(self)->Set[str]:
|
||||
def recommended_models(self) -> Set[str]:
|
||||
starters = self.starter_models(all_models=True)
|
||||
return set([x for x in starters if self.datasets[x].get("recommended", False)])
|
||||
|
||||
def default_model(self) -> str:
|
||||
starters = self.starter_models()
|
||||
return set([x for x in starters if self.datasets[x].get('recommended',False)])
|
||||
|
||||
def default_model(self)->str:
|
||||
starters = self.starter_models()
|
||||
defaults = [x for x in starters if self.datasets[x].get('default',False)]
|
||||
defaults = [x for x in starters if self.datasets[x].get("default", False)]
|
||||
return defaults[0]
|
||||
|
||||
def install(self, selections: InstallSelections):
|
||||
@ -164,54 +177,57 @@ class ModelInstall(object):
|
||||
|
||||
job = 1
|
||||
jobs = len(selections.remove_models) + len(selections.install_models)
|
||||
|
||||
|
||||
# remove requested models
|
||||
for key in selections.remove_models:
|
||||
name,base,mtype = self.mgr.parse_key(key)
|
||||
logger.info(f'Deleting {mtype} model {name} [{job}/{jobs}]')
|
||||
name, base, mtype = self.mgr.parse_key(key)
|
||||
logger.info(f"Deleting {mtype} model {name} [{job}/{jobs}]")
|
||||
try:
|
||||
self.mgr.del_model(name,base,mtype)
|
||||
self.mgr.del_model(name, base, mtype)
|
||||
except FileNotFoundError as e:
|
||||
logger.warning(e)
|
||||
job += 1
|
||||
|
||||
|
||||
# add requested models
|
||||
for path in selections.install_models:
|
||||
logger.info(f'Installing {path} [{job}/{jobs}]')
|
||||
logger.info(f"Installing {path} [{job}/{jobs}]")
|
||||
try:
|
||||
self.heuristic_import(path)
|
||||
except (ValueError, KeyError) as e:
|
||||
logger.error(str(e))
|
||||
job += 1
|
||||
|
||||
|
||||
dlogging.set_verbosity(verbosity)
|
||||
self.mgr.commit()
|
||||
|
||||
def heuristic_import(self,
|
||||
model_path_id_or_url: Union[str,Path],
|
||||
models_installed: Set[Path]=None,
|
||||
)->Dict[str, AddModelResult]:
|
||||
'''
|
||||
def heuristic_import(
|
||||
self,
|
||||
model_path_id_or_url: Union[str, Path],
|
||||
models_installed: Set[Path] = None,
|
||||
) -> Dict[str, AddModelResult]:
|
||||
"""
|
||||
:param model_path_id_or_url: A Path to a local model to import, or a string representing its repo_id or URL
|
||||
:param models_installed: Set of installed models, used for recursive invocation
|
||||
Returns a set of dict objects corresponding to newly-created stanzas in models.yaml.
|
||||
'''
|
||||
"""
|
||||
|
||||
if not models_installed:
|
||||
models_installed = dict()
|
||||
|
||||
|
||||
# A little hack to allow nested routines to retrieve info on the requested ID
|
||||
self.current_id = model_path_id_or_url
|
||||
path = Path(model_path_id_or_url)
|
||||
# checkpoint file, or similar
|
||||
if path.is_file():
|
||||
models_installed.update({str(path):self._install_path(path)})
|
||||
models_installed.update({str(path): self._install_path(path)})
|
||||
|
||||
# folders style or similar
|
||||
elif path.is_dir() and any([(path/x).exists() for x in \
|
||||
{'config.json','model_index.json','learned_embeds.bin','pytorch_lora_weights.bin'}
|
||||
]
|
||||
):
|
||||
elif path.is_dir() and any(
|
||||
[
|
||||
(path / x).exists()
|
||||
for x in {"config.json", "model_index.json", "learned_embeds.bin", "pytorch_lora_weights.bin"}
|
||||
]
|
||||
):
|
||||
models_installed.update({str(model_path_id_or_url): self._install_path(path)})
|
||||
|
||||
# recursive scan
|
||||
@ -220,7 +236,7 @@ class ModelInstall(object):
|
||||
self.heuristic_import(child, models_installed=models_installed)
|
||||
|
||||
# huggingface repo
|
||||
elif len(str(model_path_id_or_url).split('/')) == 2:
|
||||
elif len(str(model_path_id_or_url).split("/")) == 2:
|
||||
models_installed.update({str(model_path_id_or_url): self._install_repo(str(model_path_id_or_url))})
|
||||
|
||||
# a URL
|
||||
@ -228,42 +244,43 @@ class ModelInstall(object):
|
||||
models_installed.update({str(model_path_id_or_url): self._install_url(model_path_id_or_url)})
|
||||
|
||||
else:
|
||||
raise KeyError(f'{str(model_path_id_or_url)} is not recognized as a local path, repo ID or URL. Skipping')
|
||||
raise KeyError(f"{str(model_path_id_or_url)} is not recognized as a local path, repo ID or URL. Skipping")
|
||||
|
||||
return models_installed
|
||||
|
||||
# install a model from a local path. The optional info parameter is there to prevent
|
||||
# the model from being probed twice in the event that it has already been probed.
|
||||
def _install_path(self, path: Path, info: ModelProbeInfo=None)->AddModelResult:
|
||||
info = info or ModelProbe().heuristic_probe(path,self.prediction_helper)
|
||||
def _install_path(self, path: Path, info: ModelProbeInfo = None) -> AddModelResult:
|
||||
info = info or ModelProbe().heuristic_probe(path, self.prediction_helper)
|
||||
if not info:
|
||||
logger.warning(f'Unable to parse format of {path}')
|
||||
logger.warning(f"Unable to parse format of {path}")
|
||||
return None
|
||||
model_name = path.stem if path.is_file() else path.name
|
||||
if self.mgr.model_exists(model_name, info.base_type, info.model_type):
|
||||
raise ValueError(f'A model named "{model_name}" is already installed.')
|
||||
attributes = self._make_attributes(path,info)
|
||||
return self.mgr.add_model(model_name = model_name,
|
||||
base_model = info.base_type,
|
||||
model_type = info.model_type,
|
||||
model_attributes = attributes,
|
||||
)
|
||||
attributes = self._make_attributes(path, info)
|
||||
return self.mgr.add_model(
|
||||
model_name=model_name,
|
||||
base_model=info.base_type,
|
||||
model_type=info.model_type,
|
||||
model_attributes=attributes,
|
||||
)
|
||||
|
||||
def _install_url(self, url: str)->AddModelResult:
|
||||
def _install_url(self, url: str) -> AddModelResult:
|
||||
with TemporaryDirectory(dir=self.config.models_path) as staging:
|
||||
location = download_with_resume(url,Path(staging))
|
||||
location = download_with_resume(url, Path(staging))
|
||||
if not location:
|
||||
logger.error(f'Unable to download {url}. Skipping.')
|
||||
logger.error(f"Unable to download {url}. Skipping.")
|
||||
info = ModelProbe().heuristic_probe(location)
|
||||
dest = self.config.models_path / info.base_type.value / info.model_type.value / location.name
|
||||
models_path = shutil.move(location,dest)
|
||||
models_path = shutil.move(location, dest)
|
||||
|
||||
# staged version will be garbage-collected at this time
|
||||
return self._install_path(Path(models_path), info)
|
||||
|
||||
def _install_repo(self, repo_id: str)->AddModelResult:
|
||||
def _install_repo(self, repo_id: str) -> AddModelResult:
|
||||
hinfo = HfApi().model_info(repo_id)
|
||||
|
||||
|
||||
# we try to figure out how to download this most economically
|
||||
# list all the files in the repo
|
||||
files = [x.rfilename for x in hinfo.siblings]
|
||||
@ -271,42 +288,49 @@ class ModelInstall(object):
|
||||
|
||||
with TemporaryDirectory(dir=self.config.models_path) as staging:
|
||||
staging = Path(staging)
|
||||
if 'model_index.json' in files:
|
||||
location = self._download_hf_pipeline(repo_id, staging) # pipeline
|
||||
if "model_index.json" in files:
|
||||
location = self._download_hf_pipeline(repo_id, staging) # pipeline
|
||||
else:
|
||||
for suffix in ['safetensors','bin']:
|
||||
if f'pytorch_lora_weights.{suffix}' in files:
|
||||
location = self._download_hf_model(repo_id, ['pytorch_lora_weights.bin'], staging) # LoRA
|
||||
for suffix in ["safetensors", "bin"]:
|
||||
if f"pytorch_lora_weights.{suffix}" in files:
|
||||
location = self._download_hf_model(repo_id, ["pytorch_lora_weights.bin"], staging) # LoRA
|
||||
break
|
||||
elif self.config.precision=='float16' and f'diffusion_pytorch_model.fp16.{suffix}' in files: # vae, controlnet or some other standalone
|
||||
files = ['config.json', f'diffusion_pytorch_model.fp16.{suffix}']
|
||||
elif (
|
||||
self.config.precision == "float16" and f"diffusion_pytorch_model.fp16.{suffix}" in files
|
||||
): # vae, controlnet or some other standalone
|
||||
files = ["config.json", f"diffusion_pytorch_model.fp16.{suffix}"]
|
||||
location = self._download_hf_model(repo_id, files, staging)
|
||||
break
|
||||
elif f'diffusion_pytorch_model.{suffix}' in files:
|
||||
files = ['config.json', f'diffusion_pytorch_model.{suffix}']
|
||||
elif f"diffusion_pytorch_model.{suffix}" in files:
|
||||
files = ["config.json", f"diffusion_pytorch_model.{suffix}"]
|
||||
location = self._download_hf_model(repo_id, files, staging)
|
||||
break
|
||||
elif f'learned_embeds.{suffix}' in files:
|
||||
location = self._download_hf_model(repo_id, [f'learned_embeds.{suffix}'], staging)
|
||||
elif f"learned_embeds.{suffix}" in files:
|
||||
location = self._download_hf_model(repo_id, [f"learned_embeds.{suffix}"], staging)
|
||||
break
|
||||
if not location:
|
||||
logger.warning(f'Could not determine type of repo {repo_id}. Skipping install.')
|
||||
logger.warning(f"Could not determine type of repo {repo_id}. Skipping install.")
|
||||
return {}
|
||||
|
||||
info = ModelProbe().heuristic_probe(location, self.prediction_helper)
|
||||
if not info:
|
||||
logger.warning(f'Could not probe {location}. Skipping install.')
|
||||
logger.warning(f"Could not probe {location}. Skipping install.")
|
||||
return {}
|
||||
dest = self.config.models_path / info.base_type.value / info.model_type.value / self._get_model_name(repo_id,location)
|
||||
dest = (
|
||||
self.config.models_path
|
||||
/ info.base_type.value
|
||||
/ info.model_type.value
|
||||
/ self._get_model_name(repo_id, location)
|
||||
)
|
||||
if dest.exists():
|
||||
shutil.rmtree(dest)
|
||||
shutil.copytree(location,dest)
|
||||
shutil.copytree(location, dest)
|
||||
return self._install_path(dest, info)
|
||||
|
||||
def _get_model_name(self,path_name: str, location: Path)->str:
|
||||
'''
|
||||
def _get_model_name(self, path_name: str, location: Path) -> str:
|
||||
"""
|
||||
Calculate a name for the model - primitive implementation.
|
||||
'''
|
||||
"""
|
||||
if key := self.reverse_paths.get(path_name):
|
||||
(name, base, mtype) = ModelManager.parse_key(key)
|
||||
return name
|
||||
@ -315,92 +339,103 @@ class ModelInstall(object):
|
||||
else:
|
||||
return location.stem
|
||||
|
||||
def _make_attributes(self, path: Path, info: ModelProbeInfo)->dict:
|
||||
def _make_attributes(self, path: Path, info: ModelProbeInfo) -> dict:
|
||||
model_name = path.name if path.is_dir() else path.stem
|
||||
description = f'{info.base_type.value} {info.model_type.value} model {model_name}'
|
||||
description = f"{info.base_type.value} {info.model_type.value} model {model_name}"
|
||||
if key := self.reverse_paths.get(self.current_id):
|
||||
if key in self.datasets:
|
||||
description = self.datasets[key].get('description') or description
|
||||
description = self.datasets[key].get("description") or description
|
||||
|
||||
rel_path = self.relative_to_root(path)
|
||||
|
||||
attributes = dict(
|
||||
path = str(rel_path),
|
||||
description = str(description),
|
||||
model_format = info.format,
|
||||
)
|
||||
path=str(rel_path),
|
||||
description=str(description),
|
||||
model_format=info.format,
|
||||
)
|
||||
legacy_conf = None
|
||||
if info.model_type == ModelType.Main:
|
||||
attributes.update(dict(variant = info.variant_type,))
|
||||
if info.format=="checkpoint":
|
||||
attributes.update(
|
||||
dict(
|
||||
variant=info.variant_type,
|
||||
)
|
||||
)
|
||||
if info.format == "checkpoint":
|
||||
try:
|
||||
possible_conf = path.with_suffix('.yaml')
|
||||
possible_conf = path.with_suffix(".yaml")
|
||||
if possible_conf.exists():
|
||||
legacy_conf = str(self.relative_to_root(possible_conf))
|
||||
elif info.base_type == BaseModelType.StableDiffusion2:
|
||||
legacy_conf = Path(self.config.legacy_conf_dir, LEGACY_CONFIGS[info.base_type][info.variant_type][info.prediction_type])
|
||||
legacy_conf = Path(
|
||||
self.config.legacy_conf_dir,
|
||||
LEGACY_CONFIGS[info.base_type][info.variant_type][info.prediction_type],
|
||||
)
|
||||
else:
|
||||
legacy_conf = Path(self.config.legacy_conf_dir, LEGACY_CONFIGS[info.base_type][info.variant_type])
|
||||
legacy_conf = Path(
|
||||
self.config.legacy_conf_dir, LEGACY_CONFIGS[info.base_type][info.variant_type]
|
||||
)
|
||||
except KeyError:
|
||||
legacy_conf = Path(self.config.legacy_conf_dir, 'v1-inference.yaml') # best guess
|
||||
|
||||
attributes.update(
|
||||
dict(
|
||||
config = str(legacy_conf)
|
||||
)
|
||||
)
|
||||
legacy_conf = Path(self.config.legacy_conf_dir, "v1-inference.yaml") # best guess
|
||||
|
||||
if info.model_type == ModelType.ControlNet and info.format == "checkpoint":
|
||||
possible_conf = path.with_suffix(".yaml")
|
||||
if possible_conf.exists():
|
||||
legacy_conf = str(self.relative_to_root(possible_conf))
|
||||
|
||||
if legacy_conf:
|
||||
attributes.update(dict(config=str(legacy_conf)))
|
||||
return attributes
|
||||
|
||||
def relative_to_root(self, path: Path)->Path:
|
||||
def relative_to_root(self, path: Path) -> Path:
|
||||
root = self.config.root_path
|
||||
if path.is_relative_to(root):
|
||||
return path.relative_to(root)
|
||||
else:
|
||||
return path
|
||||
|
||||
def _download_hf_pipeline(self, repo_id: str, staging: Path)->Path:
|
||||
'''
|
||||
def _download_hf_pipeline(self, repo_id: str, staging: Path) -> Path:
|
||||
"""
|
||||
This retrieves a StableDiffusion model from cache or remote and then
|
||||
does a save_pretrained() to the indicated staging area.
|
||||
'''
|
||||
_,name = repo_id.split("/")
|
||||
revisions = ['fp16','main'] if self.config.precision=='float16' else ['main']
|
||||
"""
|
||||
_, name = repo_id.split("/")
|
||||
revisions = ["fp16", "main"] if self.config.precision == "float16" else ["main"]
|
||||
model = None
|
||||
for revision in revisions:
|
||||
try:
|
||||
model = DiffusionPipeline.from_pretrained(repo_id,revision=revision,safety_checker=None)
|
||||
model = DiffusionPipeline.from_pretrained(repo_id, revision=revision, safety_checker=None)
|
||||
except: # most errors are due to fp16 not being present. Fix this to catch other errors
|
||||
pass
|
||||
if model:
|
||||
break
|
||||
if not model:
|
||||
logger.error(f'Diffusers model {repo_id} could not be downloaded. Skipping.')
|
||||
logger.error(f"Diffusers model {repo_id} could not be downloaded. Skipping.")
|
||||
return None
|
||||
model.save_pretrained(staging / name, safe_serialization=True)
|
||||
return staging / name
|
||||
|
||||
def _download_hf_model(self, repo_id: str, files: List[str], staging: Path)->Path:
|
||||
_,name = repo_id.split("/")
|
||||
def _download_hf_model(self, repo_id: str, files: List[str], staging: Path) -> Path:
|
||||
_, name = repo_id.split("/")
|
||||
location = staging / name
|
||||
paths = list()
|
||||
for filename in files:
|
||||
p = hf_download_with_resume(repo_id,
|
||||
model_dir=location,
|
||||
model_name=filename,
|
||||
access_token = self.access_token
|
||||
)
|
||||
p = hf_download_with_resume(
|
||||
repo_id, model_dir=location, model_name=filename, access_token=self.access_token
|
||||
)
|
||||
if p:
|
||||
paths.append(p)
|
||||
else:
|
||||
logger.warning(f'Could not download {filename} from {repo_id}.')
|
||||
|
||||
return location if len(paths)>0 else None
|
||||
logger.warning(f"Could not download {filename} from {repo_id}.")
|
||||
|
||||
return location if len(paths) > 0 else None
|
||||
|
||||
@classmethod
|
||||
def _reverse_paths(cls,datasets)->dict:
|
||||
'''
|
||||
def _reverse_paths(cls, datasets) -> dict:
|
||||
"""
|
||||
Reverse mapping from repo_id/path to destination name.
|
||||
'''
|
||||
return {v.get('path') or v.get('repo_id') : k for k, v in datasets.items()}
|
||||
"""
|
||||
return {v.get("path") or v.get("repo_id"): k for k, v in datasets.items()}
|
||||
|
||||
|
||||
# -------------------------------------
|
||||
def yes_or_no(prompt: str, default_yes=True):
|
||||
@ -411,13 +446,12 @@ def yes_or_no(prompt: str, default_yes=True):
|
||||
else:
|
||||
return response[0] in ("y", "Y")
|
||||
|
||||
|
||||
# ---------------------------------------------
|
||||
def hf_download_from_pretrained(
|
||||
model_class: object, model_name: str, destination: Path, **kwargs
|
||||
):
|
||||
logger = InvokeAILogger.getLogger('InvokeAI')
|
||||
logger.addFilter(lambda x: 'fp16 is not a valid' not in x.getMessage())
|
||||
|
||||
def hf_download_from_pretrained(model_class: object, model_name: str, destination: Path, **kwargs):
|
||||
logger = InvokeAILogger.getLogger("InvokeAI")
|
||||
logger.addFilter(lambda x: "fp16 is not a valid" not in x.getMessage())
|
||||
|
||||
model = model_class.from_pretrained(
|
||||
model_name,
|
||||
resume_download=True,
|
||||
@ -426,13 +460,14 @@ def hf_download_from_pretrained(
|
||||
model.save_pretrained(destination, safe_serialization=True)
|
||||
return destination
|
||||
|
||||
|
||||
# ---------------------------------------------
|
||||
def hf_download_with_resume(
|
||||
repo_id: str,
|
||||
model_dir: str,
|
||||
model_name: str,
|
||||
model_dest: Path = None,
|
||||
access_token: str = None,
|
||||
repo_id: str,
|
||||
model_dir: str,
|
||||
model_name: str,
|
||||
model_dest: Path = None,
|
||||
access_token: str = None,
|
||||
) -> Path:
|
||||
model_dest = model_dest or Path(os.path.join(model_dir, model_name))
|
||||
os.makedirs(model_dir, exist_ok=True)
|
||||
@ -451,9 +486,7 @@ def hf_download_with_resume(
|
||||
resp = requests.get(url, headers=header, stream=True)
|
||||
total = int(resp.headers.get("content-length", 0))
|
||||
|
||||
if (
|
||||
resp.status_code == 416
|
||||
): # "range not satisfiable", which means nothing to return
|
||||
if resp.status_code == 416: # "range not satisfiable", which means nothing to return
|
||||
logger.info(f"{model_name}: complete file found. Skipping.")
|
||||
return model_dest
|
||||
elif resp.status_code == 404:
|
||||
@ -482,5 +515,3 @@ def hf_download_with_resume(
|
||||
logger.error(f"An error occurred while downloading {model_name}: {str(e)}")
|
||||
return None
|
||||
return model_dest
|
||||
|
||||
|
||||
|
@ -3,6 +3,12 @@ Initialization file for invokeai.backend.model_management
|
||||
"""
|
||||
from .model_manager import ModelManager, ModelInfo, AddModelResult, SchedulerPredictionType
|
||||
from .model_cache import ModelCache
|
||||
from .models import BaseModelType, ModelType, SubModelType, ModelVariantType, ModelNotFoundException, DuplicateModelException
|
||||
from .models import (
|
||||
BaseModelType,
|
||||
ModelType,
|
||||
SubModelType,
|
||||
ModelVariantType,
|
||||
ModelNotFoundException,
|
||||
DuplicateModelException,
|
||||
)
|
||||
from .model_merge import ModelMerger, MergeInterpolationMethod
|
||||
|
||||
|