Compare commits
187 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
| 08c68820cd | |||
| f0f100e67b | |||
| 500de919ed | |||
| a15dd151ff | |||
| 2a40d3c603 | |||
| e44103264d | |||
| 6955c210b7 | |||
| d1750e5eca | |||
| c5a0c59a83 | |||
| f7f015e84b | |||
| f85b74763d | |||
| fd8674a4bc | |||
| d2e0c1ca13 | |||
| 3a9bf4ac10 | |||
| 5cedc8f9b2 | |||
| 86b99b1e98 | |||
| 066afda2f6 | |||
| 8fe1e19522 | |||
| 8aa51f682c | |||
| 5f36f6ab21 | |||
| 1463cea949 | |||
| 73a0b4bba6 | |||
| 9b471436b2 | |||
| 411da7c281 | |||
| 6d337bf23d | |||
| dea5e43c83 | |||
| bde439ef67 | |||
| fc83af4432 | |||
| 337bc4a2fb | |||
| 6fac65f334 | |||
| 5a031d9233 | |||
| e4e875fffe | |||
| b945ba716b | |||
| 2207ef363a | |||
| 3a13b0e762 | |||
| 6429c3db11 | |||
| 5a9dc1c0ca | |||
| 4f2a4a3615 | |||
| 97431f29fe | |||
| ffd0f8ddc3 | |||
| c0725ba2d0 | |||
| c40be2252a | |||
| 7021cdb1de | |||
| cdb60a690d | |||
| 236eb82c3a | |||
| 472c22cc8a | |||
| bcfaf3979a | |||
| eb667e715a | |||
| d6d0b22e66 | |||
| af45872fdb | |||
| b29fc6d4de | |||
| a292d2c47f | |||
| c1c816006e | |||
| 94e9669566 | |||
| 3bb32befe9 | |||
| 48d6102b31 | |||
| 520e52f846 | |||
| 7af576e745 | |||
| 294f8a514f | |||
| bc1a450124 | |||
| 0d1924c48b | |||
| 0fc7dc1c04 | |||
| 3a4a6c43a4 | |||
| 5432d93013 | |||
| 6a86b3ad9b | |||
| 7ff54005fe | |||
| 66767e3876 | |||
| 6d77a6e1c6 | |||
| 98fc525a2c | |||
| ff2952f105 | |||
| 9aa4d098f0 | |||
| a625a7bb81 | |||
| f9c14a8c8c | |||
| 5e80d9ee99 | |||
| 47bccbebae | |||
| 9ba991cad8 | |||
| 9c1c0da026 | |||
| 656437e0a5 | |||
| 6ad666e479 | |||
| 80d639a440 | |||
| 96ee3eff6c | |||
| ff805d8d0e | |||
| c3699d4fd1 | |||
| 4d4a9e7332 | |||
| 44c5097375 | |||
| 44db35fb1a | |||
| ff1609f91e | |||
| d9499f4301 | |||
| 16ab174290 | |||
| 046c7b053a | |||
| 6b8c661c49 | |||
| 2b06cefe66 | |||
| 7edd50f304 | |||
| bbf00a96af | |||
| 329c8bacce | |||
| 1dd25be037 | |||
| f6c8201e56 | |||
| fe1967a4c4 | |||
| 452ab8fe72 | |||
| 399baa54c2 | |||
| 21d561885e | |||
| 73c74baa6a | |||
| 1f373a2baa | |||
| 4afaaf8a02 | |||
| bda2ecdbf5 | |||
| 4c423f6d37 | |||
| cc80a09d82 | |||
| 8052a4971e | |||
| 759515316e | |||
| d727ddfccd | |||
| 65ccd6305f | |||
| a2fad6ee05 | |||
| fbc5c531b9 | |||
| 5121846d34 | |||
| dfc4c27b24 | |||
| 88b2ef3b04 | |||
| 6523edb8a4 | |||
| 3b8515d2c9 | |||
| 4a50c9638c | |||
| de8ee92ed8 | |||
| 76f5abdbdb | |||
| fce86ab7d7 | |||
| 7683547728 | |||
| 2d8c894b27 | |||
| 236dd55dbe | |||
| 443ca983ad | |||
| 464fbcd921 | |||
| 384fab9627 | |||
| 0550659ce6 | |||
| d10c4db57e | |||
| 321680ccd0 | |||
| eb01d7f0e0 | |||
| 853e21d98e | |||
| 1c6efdbba7 | |||
| ec718f76b5 | |||
| 861cbd5636 | |||
| d33cb2b812 | |||
| 3e223523ce | |||
| d295e97a0d | |||
| 77bd953da2 | |||
| 2f6ea8b103 | |||
| a3d9b011a3 | |||
| 282903bb67 | |||
| 0d65d0eabd | |||
| f00eaa4d00 | |||
| d4255506ff | |||
| 117ec71994 | |||
| 4be7b620c2 | |||
| a8cbe50c9f | |||
| 19f5795c27 | |||
| 6fe16a9e1a | |||
| eadef35512 | |||
| 771dac9c5f | |||
| 0619df9835 | |||
| 7cc96429f2 | |||
| 26500b8c1b | |||
| a109c7aeb8 | |||
| 27fdc26a74 | |||
| 3a66c3c9e1 | |||
| 499543cf1d | |||
| 902afa6b4c | |||
| fff1a0c74f | |||
| 954499a494 | |||
| 44d14bc32e | |||
| fbc8d21354 | |||
| 906d1179e9 | |||
| dbb10fbd8c | |||
| 9821625a76 | |||
| 3562b0dc74 | |||
| fd51b8501e | |||
| 09a2da835e | |||
| 770ee23f18 | |||
| 76010a51ef | |||
| e34949be52 | |||
| 35fd24e857 | |||
| f71e919ecb | |||
| 2d947175b9 | |||
| f8f4ff2bb8 | |||
| 702a1e1cc7 | |||
| 74b80e7211 | |||
| e785402b6a | |||
| f5959c1c30 | |||
| 25de9a785c | |||
| c3d51fc696 | |||
| 25189b29af | |||
| aab385d01b | |||
| 061a4a295d |
@@ -1,25 +1,45 @@
|
|||||||
name: Bug Report
|
name: Bug Report
|
||||||
description: You think somethings is broken in the UI
|
description: You think something is broken in the UI
|
||||||
title: "[Bug]: "
|
title: "[Bug]: "
|
||||||
labels: ["bug-report"]
|
labels: ["bug-report"]
|
||||||
|
|
||||||
body:
|
body:
|
||||||
- type: checkboxes
|
|
||||||
attributes:
|
|
||||||
label: Is there an existing issue for this?
|
|
||||||
description: Please search to see if an issue already exists for the bug you encountered, and that it hasn't been fixed in a recent build/commit.
|
|
||||||
options:
|
|
||||||
- label: I have searched the existing issues and checked the recent builds/commits
|
|
||||||
required: true
|
|
||||||
- type: markdown
|
- type: markdown
|
||||||
attributes:
|
attributes:
|
||||||
value: |
|
value: |
|
||||||
*Please fill this form with as much information as possible, don't forget to fill "What OS..." and "What browsers" and *provide screenshots if possible**
|
> The title of the bug report should be short and descriptive.
|
||||||
|
> Use relevant keywords for searchability.
|
||||||
|
> Do not leave it blank, but also do not put an entire error log in it.
|
||||||
|
- type: checkboxes
|
||||||
|
attributes:
|
||||||
|
label: Checklist
|
||||||
|
description: |
|
||||||
|
Please perform basic debugging to see if extensions or configuration is the cause of the issue.
|
||||||
|
Basic debug procedure
|
||||||
|
1. Disable all third-party extensions - check if extension is the cause
|
||||||
|
2. Update extensions and webui - sometimes things just need to be updated
|
||||||
|
3. Backup and remove your config.json and ui-config.json - check if the issue is caused by bad configuration
|
||||||
|
4. Delete venv with third-party extensions disabled - sometimes extensions might cause wrong libraries to be installed
|
||||||
|
5. Try a fresh installation webui in a different directory - see if a clean installation solves the issue
|
||||||
|
Before making a issue report please, check that the issue hasn't been reported recently.
|
||||||
|
options:
|
||||||
|
- label: The issue exists after disabling all extensions
|
||||||
|
- label: The issue exists on a clean installation of webui
|
||||||
|
- label: The issue is caused by an extension, but I believe it is caused by a bug in the webui
|
||||||
|
- label: The issue exists in the current version of the webui
|
||||||
|
- label: The issue has not been reported before recently
|
||||||
|
- label: The issue has been reported before but has not been fixed yet
|
||||||
|
- type: markdown
|
||||||
|
attributes:
|
||||||
|
value: |
|
||||||
|
> Please fill this form with as much information as possible. Don't forget to "Upload Sysinfo" and "What browsers" and provide screenshots if possible
|
||||||
- type: textarea
|
- type: textarea
|
||||||
id: what-did
|
id: what-did
|
||||||
attributes:
|
attributes:
|
||||||
label: What happened?
|
label: What happened?
|
||||||
description: Tell us what happened in a very clear and simple way
|
description: Tell us what happened in a very clear and simple way
|
||||||
|
placeholder: |
|
||||||
|
txt2img is not working as intended.
|
||||||
validations:
|
validations:
|
||||||
required: true
|
required: true
|
||||||
- type: textarea
|
- type: textarea
|
||||||
@@ -27,9 +47,9 @@ body:
|
|||||||
attributes:
|
attributes:
|
||||||
label: Steps to reproduce the problem
|
label: Steps to reproduce the problem
|
||||||
description: Please provide us with precise step by step instructions on how to reproduce the bug
|
description: Please provide us with precise step by step instructions on how to reproduce the bug
|
||||||
value: |
|
placeholder: |
|
||||||
1. Go to ....
|
1. Go to ...
|
||||||
2. Press ....
|
2. Press ...
|
||||||
3. ...
|
3. ...
|
||||||
validations:
|
validations:
|
||||||
required: true
|
required: true
|
||||||
@@ -38,13 +58,8 @@ body:
|
|||||||
attributes:
|
attributes:
|
||||||
label: What should have happened?
|
label: What should have happened?
|
||||||
description: Tell us what you think the normal behavior should be
|
description: Tell us what you think the normal behavior should be
|
||||||
validations:
|
placeholder: |
|
||||||
required: true
|
WebUI should ...
|
||||||
- type: textarea
|
|
||||||
id: sysinfo
|
|
||||||
attributes:
|
|
||||||
label: Sysinfo
|
|
||||||
description: System info file, generated by WebUI. You can generate it in settings, on the Sysinfo page. Drag the file into the field to upload it. If you submit your report without including the sysinfo file, the report will be closed. If needed, review the report to make sure it includes no personal information you don't want to share. If you can't start WebUI, you can use --dump-sysinfo commandline argument to generate the file.
|
|
||||||
validations:
|
validations:
|
||||||
required: true
|
required: true
|
||||||
- type: dropdown
|
- type: dropdown
|
||||||
@@ -58,12 +73,25 @@ body:
|
|||||||
- Brave
|
- Brave
|
||||||
- Apple Safari
|
- Apple Safari
|
||||||
- Microsoft Edge
|
- Microsoft Edge
|
||||||
|
- Android
|
||||||
|
- iOS
|
||||||
- Other
|
- Other
|
||||||
|
- type: textarea
|
||||||
|
id: sysinfo
|
||||||
|
attributes:
|
||||||
|
label: Sysinfo
|
||||||
|
description: System info file, generated by WebUI. You can generate it in settings, on the Sysinfo page. Drag the file into the field to upload it. If you submit your report without including the sysinfo file, the report will be closed. If needed, review the report to make sure it includes no personal information you don't want to share. If you can't start WebUI, you can use --dump-sysinfo commandline argument to generate the file.
|
||||||
|
placeholder: |
|
||||||
|
1. Go to WebUI Settings -> Sysinfo -> Download system info.
|
||||||
|
If WebUI fails to launch, use --dump-sysinfo commandline argument to generate the file
|
||||||
|
2. Upload the Sysinfo as a attached file, Do NOT paste it in as plain text.
|
||||||
|
validations:
|
||||||
|
required: true
|
||||||
- type: textarea
|
- type: textarea
|
||||||
id: logs
|
id: logs
|
||||||
attributes:
|
attributes:
|
||||||
label: Console logs
|
label: Console logs
|
||||||
description: Please provide **full** cmd/terminal logs from the moment you started UI to the end of it, after your bug happened. If it's very long, provide a link to pastebin or similar service.
|
description: Please provide **full** cmd/terminal logs from the moment you started UI to the end of it, after the bug occured. If it's very long, provide a link to pastebin or similar service.
|
||||||
render: Shell
|
render: Shell
|
||||||
validations:
|
validations:
|
||||||
required: true
|
required: true
|
||||||
@@ -71,4 +99,7 @@ body:
|
|||||||
id: misc
|
id: misc
|
||||||
attributes:
|
attributes:
|
||||||
label: Additional information
|
label: Additional information
|
||||||
description: Please provide us with any relevant additional info or context.
|
description: |
|
||||||
|
Please provide us with any relevant additional info or context.
|
||||||
|
Examples:
|
||||||
|
I have updated my GPU driver recently.
|
||||||
|
|||||||
@@ -20,7 +20,7 @@ jobs:
|
|||||||
# not to have GHA download an (at the time of writing) 4 GB cache
|
# not to have GHA download an (at the time of writing) 4 GB cache
|
||||||
# of PyTorch and other dependencies.
|
# of PyTorch and other dependencies.
|
||||||
- name: Install Ruff
|
- name: Install Ruff
|
||||||
run: pip install ruff==0.0.272
|
run: pip install ruff==0.1.6
|
||||||
- name: Run Ruff
|
- name: Run Ruff
|
||||||
run: ruff .
|
run: ruff .
|
||||||
lint-js:
|
lint-js:
|
||||||
|
|||||||
@@ -1,3 +1,8 @@
|
|||||||
|
## 1.6.1
|
||||||
|
|
||||||
|
### Bug Fixes:
|
||||||
|
* fix an error causing the webui to fail to start ([#13839](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/13839))
|
||||||
|
|
||||||
## 1.6.0
|
## 1.6.0
|
||||||
|
|
||||||
### Features:
|
### Features:
|
||||||
|
|||||||
@@ -91,6 +91,7 @@ A browser interface based on Gradio library for Stable Diffusion.
|
|||||||
- Eased resolution restriction: generated image's dimensions must be a multiple of 8 rather than 64
|
- Eased resolution restriction: generated image's dimensions must be a multiple of 8 rather than 64
|
||||||
- Now with a license!
|
- Now with a license!
|
||||||
- Reorder elements in the UI from settings screen
|
- Reorder elements in the UI from settings screen
|
||||||
|
- [Segmind Stable Diffusion](https://huggingface.co/segmind/SSD-1B) support
|
||||||
|
|
||||||
## Installation and Running
|
## Installation and Running
|
||||||
Make sure the required [dependencies](https://github.com/AUTOMATIC1111/stable-diffusion-webui/wiki/Dependencies) are met and follow the instructions available for:
|
Make sure the required [dependencies](https://github.com/AUTOMATIC1111/stable-diffusion-webui/wiki/Dependencies) are met and follow the instructions available for:
|
||||||
@@ -120,7 +121,9 @@ Alternatively, use online services (like Google Colab):
|
|||||||
# Debian-based:
|
# Debian-based:
|
||||||
sudo apt install wget git python3 python3-venv libgl1 libglib2.0-0
|
sudo apt install wget git python3 python3-venv libgl1 libglib2.0-0
|
||||||
# Red Hat-based:
|
# Red Hat-based:
|
||||||
sudo dnf install wget git python3
|
sudo dnf install wget git python3 gperftools-libs libglvnd-glx
|
||||||
|
# openSUSE-based:
|
||||||
|
sudo zypper install wget git python3 libtcmalloc4 libglvnd
|
||||||
# Arch-based:
|
# Arch-based:
|
||||||
sudo pacman -S wget git python3
|
sudo pacman -S wget git python3
|
||||||
```
|
```
|
||||||
@@ -146,7 +149,7 @@ For the purposes of getting Google and other search engines to crawl the wiki, h
|
|||||||
## Credits
|
## Credits
|
||||||
Licenses for borrowed code can be found in `Settings -> Licenses` screen, and also in `html/licenses.html` file.
|
Licenses for borrowed code can be found in `Settings -> Licenses` screen, and also in `html/licenses.html` file.
|
||||||
|
|
||||||
- Stable Diffusion - https://github.com/CompVis/stable-diffusion, https://github.com/CompVis/taming-transformers
|
- Stable Diffusion - https://github.com/Stability-AI/stablediffusion, https://github.com/CompVis/taming-transformers
|
||||||
- k-diffusion - https://github.com/crowsonkb/k-diffusion.git
|
- k-diffusion - https://github.com/crowsonkb/k-diffusion.git
|
||||||
- GFPGAN - https://github.com/TencentARC/GFPGAN.git
|
- GFPGAN - https://github.com/TencentARC/GFPGAN.git
|
||||||
- CodeFormer - https://github.com/sczhou/CodeFormer
|
- CodeFormer - https://github.com/sczhou/CodeFormer
|
||||||
@@ -173,5 +176,6 @@ Licenses for borrowed code can be found in `Settings -> Licenses` screen, and al
|
|||||||
- TAESD - Ollin Boer Bohan - https://github.com/madebyollin/taesd
|
- TAESD - Ollin Boer Bohan - https://github.com/madebyollin/taesd
|
||||||
- LyCORIS - KohakuBlueleaf
|
- LyCORIS - KohakuBlueleaf
|
||||||
- Restart sampling - lambertae - https://github.com/Newbeeer/diffusion_restart_sampling
|
- Restart sampling - lambertae - https://github.com/Newbeeer/diffusion_restart_sampling
|
||||||
|
- Hypertile - tfernd - https://github.com/tfernd/HyperTile
|
||||||
- Initial Gradio script - posted on 4chan by an Anonymous user. Thank you Anonymous user.
|
- Initial Gradio script - posted on 4chan by an Anonymous user. Thank you Anonymous user.
|
||||||
- (You)
|
- (You)
|
||||||
|
|||||||
@@ -0,0 +1,73 @@
|
|||||||
|
model:
|
||||||
|
base_learning_rate: 1.0e-04
|
||||||
|
target: ldm.models.diffusion.ddpm.LatentDiffusion
|
||||||
|
params:
|
||||||
|
linear_start: 0.00085
|
||||||
|
linear_end: 0.0120
|
||||||
|
num_timesteps_cond: 1
|
||||||
|
log_every_t: 200
|
||||||
|
timesteps: 1000
|
||||||
|
first_stage_key: "jpg"
|
||||||
|
cond_stage_key: "txt"
|
||||||
|
image_size: 64
|
||||||
|
channels: 4
|
||||||
|
cond_stage_trainable: false # Note: different from the one we trained before
|
||||||
|
conditioning_key: crossattn
|
||||||
|
monitor: val/loss_simple_ema
|
||||||
|
scale_factor: 0.18215
|
||||||
|
use_ema: False
|
||||||
|
|
||||||
|
scheduler_config: # 10000 warmup steps
|
||||||
|
target: ldm.lr_scheduler.LambdaLinearScheduler
|
||||||
|
params:
|
||||||
|
warm_up_steps: [ 10000 ]
|
||||||
|
cycle_lengths: [ 10000000000000 ] # incredibly large number to prevent corner cases
|
||||||
|
f_start: [ 1.e-6 ]
|
||||||
|
f_max: [ 1. ]
|
||||||
|
f_min: [ 1. ]
|
||||||
|
|
||||||
|
unet_config:
|
||||||
|
target: ldm.modules.diffusionmodules.openaimodel.UNetModel
|
||||||
|
params:
|
||||||
|
image_size: 32 # unused
|
||||||
|
in_channels: 4
|
||||||
|
out_channels: 4
|
||||||
|
model_channels: 320
|
||||||
|
attention_resolutions: [ 4, 2, 1 ]
|
||||||
|
num_res_blocks: 2
|
||||||
|
channel_mult: [ 1, 2, 4, 4 ]
|
||||||
|
num_head_channels: 64
|
||||||
|
use_spatial_transformer: True
|
||||||
|
use_linear_in_transformer: True
|
||||||
|
transformer_depth: 1
|
||||||
|
context_dim: 1024
|
||||||
|
use_checkpoint: True
|
||||||
|
legacy: False
|
||||||
|
|
||||||
|
first_stage_config:
|
||||||
|
target: ldm.models.autoencoder.AutoencoderKL
|
||||||
|
params:
|
||||||
|
embed_dim: 4
|
||||||
|
monitor: val/rec_loss
|
||||||
|
ddconfig:
|
||||||
|
double_z: true
|
||||||
|
z_channels: 4
|
||||||
|
resolution: 256
|
||||||
|
in_channels: 3
|
||||||
|
out_ch: 3
|
||||||
|
ch: 128
|
||||||
|
ch_mult:
|
||||||
|
- 1
|
||||||
|
- 2
|
||||||
|
- 4
|
||||||
|
- 4
|
||||||
|
num_res_blocks: 2
|
||||||
|
attn_resolutions: []
|
||||||
|
dropout: 0.0
|
||||||
|
lossconfig:
|
||||||
|
target: torch.nn.Identity
|
||||||
|
|
||||||
|
cond_stage_config:
|
||||||
|
target: modules.xlmr_m18.BertSeriesModelWithTransformation
|
||||||
|
params:
|
||||||
|
name: "XLMR-Large"
|
||||||
@@ -19,3 +19,50 @@ def rebuild_cp_decomposition(up, down, mid):
|
|||||||
up = up.reshape(up.size(0), -1)
|
up = up.reshape(up.size(0), -1)
|
||||||
down = down.reshape(down.size(0), -1)
|
down = down.reshape(down.size(0), -1)
|
||||||
return torch.einsum('n m k l, i n, m j -> i j k l', mid, up, down)
|
return torch.einsum('n m k l, i n, m j -> i j k l', mid, up, down)
|
||||||
|
|
||||||
|
|
||||||
|
# copied from https://github.com/KohakuBlueleaf/LyCORIS/blob/dev/lycoris/modules/lokr.py
|
||||||
|
def factorization(dimension: int, factor:int=-1) -> tuple[int, int]:
|
||||||
|
'''
|
||||||
|
return a tuple of two value of input dimension decomposed by the number closest to factor
|
||||||
|
second value is higher or equal than first value.
|
||||||
|
|
||||||
|
In LoRA with Kroneckor Product, first value is a value for weight scale.
|
||||||
|
secon value is a value for weight.
|
||||||
|
|
||||||
|
Becuase of non-commutative property, A⊗B ≠ B⊗A. Meaning of two matrices is slightly different.
|
||||||
|
|
||||||
|
examples)
|
||||||
|
factor
|
||||||
|
-1 2 4 8 16 ...
|
||||||
|
127 -> 1, 127 127 -> 1, 127 127 -> 1, 127 127 -> 1, 127 127 -> 1, 127
|
||||||
|
128 -> 8, 16 128 -> 2, 64 128 -> 4, 32 128 -> 8, 16 128 -> 8, 16
|
||||||
|
250 -> 10, 25 250 -> 2, 125 250 -> 2, 125 250 -> 5, 50 250 -> 10, 25
|
||||||
|
360 -> 8, 45 360 -> 2, 180 360 -> 4, 90 360 -> 8, 45 360 -> 12, 30
|
||||||
|
512 -> 16, 32 512 -> 2, 256 512 -> 4, 128 512 -> 8, 64 512 -> 16, 32
|
||||||
|
1024 -> 32, 32 1024 -> 2, 512 1024 -> 4, 256 1024 -> 8, 128 1024 -> 16, 64
|
||||||
|
'''
|
||||||
|
|
||||||
|
if factor > 0 and (dimension % factor) == 0:
|
||||||
|
m = factor
|
||||||
|
n = dimension // factor
|
||||||
|
if m > n:
|
||||||
|
n, m = m, n
|
||||||
|
return m, n
|
||||||
|
if factor < 0:
|
||||||
|
factor = dimension
|
||||||
|
m, n = 1, dimension
|
||||||
|
length = m + n
|
||||||
|
while m<n:
|
||||||
|
new_m = m + 1
|
||||||
|
while dimension%new_m != 0:
|
||||||
|
new_m += 1
|
||||||
|
new_n = dimension // new_m
|
||||||
|
if new_m + new_n > length or new_m>factor:
|
||||||
|
break
|
||||||
|
else:
|
||||||
|
m, n = new_m, new_n
|
||||||
|
if m > n:
|
||||||
|
n, m = m, n
|
||||||
|
return m, n
|
||||||
|
|
||||||
|
|||||||
@@ -0,0 +1,33 @@
|
|||||||
|
|
||||||
|
import network
|
||||||
|
|
||||||
|
class ModuleTypeGLora(network.ModuleType):
|
||||||
|
def create_module(self, net: network.Network, weights: network.NetworkWeights):
|
||||||
|
if all(x in weights.w for x in ["a1.weight", "a2.weight", "alpha", "b1.weight", "b2.weight"]):
|
||||||
|
return NetworkModuleGLora(net, weights)
|
||||||
|
|
||||||
|
return None
|
||||||
|
|
||||||
|
# adapted from https://github.com/KohakuBlueleaf/LyCORIS
|
||||||
|
class NetworkModuleGLora(network.NetworkModule):
|
||||||
|
def __init__(self, net: network.Network, weights: network.NetworkWeights):
|
||||||
|
super().__init__(net, weights)
|
||||||
|
|
||||||
|
if hasattr(self.sd_module, 'weight'):
|
||||||
|
self.shape = self.sd_module.weight.shape
|
||||||
|
|
||||||
|
self.w1a = weights.w["a1.weight"]
|
||||||
|
self.w1b = weights.w["b1.weight"]
|
||||||
|
self.w2a = weights.w["a2.weight"]
|
||||||
|
self.w2b = weights.w["b2.weight"]
|
||||||
|
|
||||||
|
def calc_updown(self, orig_weight):
|
||||||
|
w1a = self.w1a.to(orig_weight.device, dtype=orig_weight.dtype)
|
||||||
|
w1b = self.w1b.to(orig_weight.device, dtype=orig_weight.dtype)
|
||||||
|
w2a = self.w2a.to(orig_weight.device, dtype=orig_weight.dtype)
|
||||||
|
w2b = self.w2b.to(orig_weight.device, dtype=orig_weight.dtype)
|
||||||
|
|
||||||
|
output_shape = [w1a.size(0), w1b.size(1)]
|
||||||
|
updown = ((w2b @ w1b) + ((orig_weight @ w2a) @ w1a))
|
||||||
|
|
||||||
|
return self.finalize_updown(updown, orig_weight, output_shape)
|
||||||
@@ -0,0 +1,97 @@
|
|||||||
|
import torch
|
||||||
|
import network
|
||||||
|
from lyco_helpers import factorization
|
||||||
|
from einops import rearrange
|
||||||
|
|
||||||
|
|
||||||
|
class ModuleTypeOFT(network.ModuleType):
|
||||||
|
def create_module(self, net: network.Network, weights: network.NetworkWeights):
|
||||||
|
if all(x in weights.w for x in ["oft_blocks"]) or all(x in weights.w for x in ["oft_diag"]):
|
||||||
|
return NetworkModuleOFT(net, weights)
|
||||||
|
|
||||||
|
return None
|
||||||
|
|
||||||
|
# Supports both kohya-ss' implementation of COFT https://github.com/kohya-ss/sd-scripts/blob/main/networks/oft.py
|
||||||
|
# and KohakuBlueleaf's implementation of OFT/COFT https://github.com/KohakuBlueleaf/LyCORIS/blob/dev/lycoris/modules/diag_oft.py
|
||||||
|
class NetworkModuleOFT(network.NetworkModule):
|
||||||
|
def __init__(self, net: network.Network, weights: network.NetworkWeights):
|
||||||
|
|
||||||
|
super().__init__(net, weights)
|
||||||
|
|
||||||
|
self.lin_module = None
|
||||||
|
self.org_module: list[torch.Module] = [self.sd_module]
|
||||||
|
|
||||||
|
# kohya-ss
|
||||||
|
if "oft_blocks" in weights.w.keys():
|
||||||
|
self.is_kohya = True
|
||||||
|
self.oft_blocks = weights.w["oft_blocks"] # (num_blocks, block_size, block_size)
|
||||||
|
self.alpha = weights.w["alpha"] # alpha is constraint
|
||||||
|
self.dim = self.oft_blocks.shape[0] # lora dim
|
||||||
|
# LyCORIS
|
||||||
|
elif "oft_diag" in weights.w.keys():
|
||||||
|
self.is_kohya = False
|
||||||
|
self.oft_blocks = weights.w["oft_diag"]
|
||||||
|
# self.alpha is unused
|
||||||
|
self.dim = self.oft_blocks.shape[1] # (num_blocks, block_size, block_size)
|
||||||
|
|
||||||
|
is_linear = type(self.sd_module) in [torch.nn.Linear, torch.nn.modules.linear.NonDynamicallyQuantizableLinear]
|
||||||
|
is_conv = type(self.sd_module) in [torch.nn.Conv2d]
|
||||||
|
is_other_linear = type(self.sd_module) in [torch.nn.MultiheadAttention] # unsupported
|
||||||
|
|
||||||
|
if is_linear:
|
||||||
|
self.out_dim = self.sd_module.out_features
|
||||||
|
elif is_conv:
|
||||||
|
self.out_dim = self.sd_module.out_channels
|
||||||
|
elif is_other_linear:
|
||||||
|
self.out_dim = self.sd_module.embed_dim
|
||||||
|
|
||||||
|
if self.is_kohya:
|
||||||
|
self.constraint = self.alpha * self.out_dim
|
||||||
|
self.num_blocks = self.dim
|
||||||
|
self.block_size = self.out_dim // self.dim
|
||||||
|
else:
|
||||||
|
self.constraint = None
|
||||||
|
self.block_size, self.num_blocks = factorization(self.out_dim, self.dim)
|
||||||
|
|
||||||
|
def calc_updown_kb(self, orig_weight, multiplier):
|
||||||
|
oft_blocks = self.oft_blocks.to(orig_weight.device, dtype=orig_weight.dtype)
|
||||||
|
oft_blocks = oft_blocks - oft_blocks.transpose(1, 2) # ensure skew-symmetric orthogonal matrix
|
||||||
|
|
||||||
|
R = oft_blocks.to(orig_weight.device, dtype=orig_weight.dtype)
|
||||||
|
R = R * multiplier + torch.eye(self.block_size, device=orig_weight.device)
|
||||||
|
|
||||||
|
# This errors out for MultiheadAttention, might need to be handled up-stream
|
||||||
|
merged_weight = rearrange(orig_weight, '(k n) ... -> k n ...', k=self.num_blocks, n=self.block_size)
|
||||||
|
merged_weight = torch.einsum(
|
||||||
|
'k n m, k n ... -> k m ...',
|
||||||
|
R,
|
||||||
|
merged_weight
|
||||||
|
)
|
||||||
|
merged_weight = rearrange(merged_weight, 'k m ... -> (k m) ...')
|
||||||
|
|
||||||
|
updown = merged_weight.to(orig_weight.device, dtype=orig_weight.dtype) - orig_weight
|
||||||
|
output_shape = orig_weight.shape
|
||||||
|
return self.finalize_updown(updown, orig_weight, output_shape)
|
||||||
|
|
||||||
|
def calc_updown(self, orig_weight):
|
||||||
|
# if alpha is a very small number as in coft, calc_scale() will return a almost zero number so we ignore it
|
||||||
|
multiplier = self.multiplier()
|
||||||
|
return self.calc_updown_kb(orig_weight, multiplier)
|
||||||
|
|
||||||
|
# override to remove the multiplier/scale factor; it's already multiplied in get_weight
|
||||||
|
def finalize_updown(self, updown, orig_weight, output_shape, ex_bias=None):
|
||||||
|
if self.bias is not None:
|
||||||
|
updown = updown.reshape(self.bias.shape)
|
||||||
|
updown += self.bias.to(orig_weight.device, dtype=orig_weight.dtype)
|
||||||
|
updown = updown.reshape(output_shape)
|
||||||
|
|
||||||
|
if len(output_shape) == 4:
|
||||||
|
updown = updown.reshape(output_shape)
|
||||||
|
|
||||||
|
if orig_weight.size().numel() == updown.size().numel():
|
||||||
|
updown = updown.reshape(orig_weight.shape)
|
||||||
|
|
||||||
|
if ex_bias is not None:
|
||||||
|
ex_bias = ex_bias * self.multiplier()
|
||||||
|
|
||||||
|
return updown, ex_bias
|
||||||
@@ -5,17 +5,19 @@ import re
|
|||||||
import lora_patches
|
import lora_patches
|
||||||
import network
|
import network
|
||||||
import network_lora
|
import network_lora
|
||||||
|
import network_glora
|
||||||
import network_hada
|
import network_hada
|
||||||
import network_ia3
|
import network_ia3
|
||||||
import network_lokr
|
import network_lokr
|
||||||
import network_full
|
import network_full
|
||||||
import network_norm
|
import network_norm
|
||||||
|
import network_oft
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
from typing import Union
|
from typing import Union
|
||||||
|
|
||||||
from modules import shared, devices, sd_models, errors, scripts, sd_hijack
|
from modules import shared, devices, sd_models, errors, scripts, sd_hijack
|
||||||
from modules.textual_inversion.textual_inversion import Embedding
|
import modules.textual_inversion.textual_inversion as textual_inversion
|
||||||
|
|
||||||
from lora_logger import logger
|
from lora_logger import logger
|
||||||
|
|
||||||
@@ -26,6 +28,8 @@ module_types = [
|
|||||||
network_lokr.ModuleTypeLokr(),
|
network_lokr.ModuleTypeLokr(),
|
||||||
network_full.ModuleTypeFull(),
|
network_full.ModuleTypeFull(),
|
||||||
network_norm.ModuleTypeNorm(),
|
network_norm.ModuleTypeNorm(),
|
||||||
|
network_glora.ModuleTypeGLora(),
|
||||||
|
network_oft.ModuleTypeOFT(),
|
||||||
]
|
]
|
||||||
|
|
||||||
|
|
||||||
@@ -187,6 +191,17 @@ def load_network(name, network_on_disk):
|
|||||||
key = key_network_without_network_parts.replace("lora_te1_text_model", "transformer_text_model")
|
key = key_network_without_network_parts.replace("lora_te1_text_model", "transformer_text_model")
|
||||||
sd_module = shared.sd_model.network_layer_mapping.get(key, None)
|
sd_module = shared.sd_model.network_layer_mapping.get(key, None)
|
||||||
|
|
||||||
|
# kohya_ss OFT module
|
||||||
|
elif sd_module is None and "oft_unet" in key_network_without_network_parts:
|
||||||
|
key = key_network_without_network_parts.replace("oft_unet", "diffusion_model")
|
||||||
|
sd_module = shared.sd_model.network_layer_mapping.get(key, None)
|
||||||
|
|
||||||
|
# KohakuBlueLeaf OFT module
|
||||||
|
if sd_module is None and "oft_diag" in key:
|
||||||
|
key = key_network_without_network_parts.replace("lora_unet", "diffusion_model")
|
||||||
|
key = key_network_without_network_parts.replace("lora_te1_text_model", "0_transformer_text_model")
|
||||||
|
sd_module = shared.sd_model.network_layer_mapping.get(key, None)
|
||||||
|
|
||||||
if sd_module is None:
|
if sd_module is None:
|
||||||
keys_failed_to_match[key_network] = key
|
keys_failed_to_match[key_network] = key
|
||||||
continue
|
continue
|
||||||
@@ -210,34 +225,7 @@ def load_network(name, network_on_disk):
|
|||||||
|
|
||||||
embeddings = {}
|
embeddings = {}
|
||||||
for emb_name, data in bundle_embeddings.items():
|
for emb_name, data in bundle_embeddings.items():
|
||||||
# textual inversion embeddings
|
embedding = textual_inversion.create_embedding_from_data(data, emb_name, filename=network_on_disk.filename + "/" + emb_name)
|
||||||
if 'string_to_param' in data:
|
|
||||||
param_dict = data['string_to_param']
|
|
||||||
param_dict = getattr(param_dict, '_parameters', param_dict) # fix for torch 1.12.1 loading saved file from torch 1.11
|
|
||||||
assert len(param_dict) == 1, 'embedding file has multiple terms in it'
|
|
||||||
emb = next(iter(param_dict.items()))[1]
|
|
||||||
vec = emb.detach().to(devices.device, dtype=torch.float32)
|
|
||||||
shape = vec.shape[-1]
|
|
||||||
vectors = vec.shape[0]
|
|
||||||
elif type(data) == dict and 'clip_g' in data and 'clip_l' in data: # SDXL embedding
|
|
||||||
vec = {k: v.detach().to(devices.device, dtype=torch.float32) for k, v in data.items()}
|
|
||||||
shape = data['clip_g'].shape[-1] + data['clip_l'].shape[-1]
|
|
||||||
vectors = data['clip_g'].shape[0]
|
|
||||||
elif type(data) == dict and type(next(iter(data.values()))) == torch.Tensor: # diffuser concepts
|
|
||||||
assert len(data.keys()) == 1, 'embedding file has multiple terms in it'
|
|
||||||
|
|
||||||
emb = next(iter(data.values()))
|
|
||||||
if len(emb.shape) == 1:
|
|
||||||
emb = emb.unsqueeze(0)
|
|
||||||
vec = emb.detach().to(devices.device, dtype=torch.float32)
|
|
||||||
shape = vec.shape[-1]
|
|
||||||
vectors = vec.shape[0]
|
|
||||||
else:
|
|
||||||
raise Exception(f"Couldn't identify {emb_name} in lora: {name} as neither textual inversion embedding nor diffuser concept.")
|
|
||||||
|
|
||||||
embedding = Embedding(vec, emb_name)
|
|
||||||
embedding.vectors = vectors
|
|
||||||
embedding.shape = shape
|
|
||||||
embedding.loaded = None
|
embedding.loaded = None
|
||||||
embeddings[emb_name] = embedding
|
embeddings[emb_name] = embedding
|
||||||
|
|
||||||
|
|||||||
@@ -17,6 +17,8 @@ class ExtraNetworksPageLora(ui_extra_networks.ExtraNetworksPage):
|
|||||||
|
|
||||||
def create_item(self, name, index=None, enable_filter=True):
|
def create_item(self, name, index=None, enable_filter=True):
|
||||||
lora_on_disk = networks.available_networks.get(name)
|
lora_on_disk = networks.available_networks.get(name)
|
||||||
|
if lora_on_disk is None:
|
||||||
|
return
|
||||||
|
|
||||||
path, ext = os.path.splitext(lora_on_disk.filename)
|
path, ext = os.path.splitext(lora_on_disk.filename)
|
||||||
|
|
||||||
@@ -66,9 +68,10 @@ class ExtraNetworksPageLora(ui_extra_networks.ExtraNetworksPage):
|
|||||||
return item
|
return item
|
||||||
|
|
||||||
def list_items(self):
|
def list_items(self):
|
||||||
for index, name in enumerate(networks.available_networks):
|
# instantiate a list to protect against concurrent modification
|
||||||
|
names = list(networks.available_networks)
|
||||||
|
for index, name in enumerate(names):
|
||||||
item = self.create_item(name, index)
|
item = self.create_item(name, index)
|
||||||
|
|
||||||
if item is not None:
|
if item is not None:
|
||||||
yield item
|
yield item
|
||||||
|
|
||||||
|
|||||||
@@ -0,0 +1,345 @@
|
|||||||
|
"""
|
||||||
|
Hypertile module for splitting attention layers in SD-1.5 U-Net and SD-1.5 VAE
|
||||||
|
Warn: The patch works well only if the input image has a width and height that are multiples of 128
|
||||||
|
Original author: @tfernd Github: https://github.com/tfernd/HyperTile
|
||||||
|
"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import functools
|
||||||
|
from dataclasses import dataclass
|
||||||
|
from typing import Callable
|
||||||
|
|
||||||
|
from functools import wraps, cache
|
||||||
|
|
||||||
|
import math
|
||||||
|
import torch.nn as nn
|
||||||
|
import random
|
||||||
|
|
||||||
|
from einops import rearrange
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class HypertileParams:
|
||||||
|
depth = 0
|
||||||
|
layer_name = ""
|
||||||
|
tile_size: int = 0
|
||||||
|
swap_size: int = 0
|
||||||
|
aspect_ratio: float = 1.0
|
||||||
|
forward = None
|
||||||
|
enabled = False
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
# TODO add SD-XL layers
|
||||||
|
DEPTH_LAYERS = {
|
||||||
|
0: [
|
||||||
|
# SD 1.5 U-Net (diffusers)
|
||||||
|
"down_blocks.0.attentions.0.transformer_blocks.0.attn1",
|
||||||
|
"down_blocks.0.attentions.1.transformer_blocks.0.attn1",
|
||||||
|
"up_blocks.3.attentions.0.transformer_blocks.0.attn1",
|
||||||
|
"up_blocks.3.attentions.1.transformer_blocks.0.attn1",
|
||||||
|
"up_blocks.3.attentions.2.transformer_blocks.0.attn1",
|
||||||
|
# SD 1.5 U-Net (ldm)
|
||||||
|
"input_blocks.1.1.transformer_blocks.0.attn1",
|
||||||
|
"input_blocks.2.1.transformer_blocks.0.attn1",
|
||||||
|
"output_blocks.9.1.transformer_blocks.0.attn1",
|
||||||
|
"output_blocks.10.1.transformer_blocks.0.attn1",
|
||||||
|
"output_blocks.11.1.transformer_blocks.0.attn1",
|
||||||
|
# SD 1.5 VAE
|
||||||
|
"decoder.mid_block.attentions.0",
|
||||||
|
"decoder.mid.attn_1",
|
||||||
|
],
|
||||||
|
1: [
|
||||||
|
# SD 1.5 U-Net (diffusers)
|
||||||
|
"down_blocks.1.attentions.0.transformer_blocks.0.attn1",
|
||||||
|
"down_blocks.1.attentions.1.transformer_blocks.0.attn1",
|
||||||
|
"up_blocks.2.attentions.0.transformer_blocks.0.attn1",
|
||||||
|
"up_blocks.2.attentions.1.transformer_blocks.0.attn1",
|
||||||
|
"up_blocks.2.attentions.2.transformer_blocks.0.attn1",
|
||||||
|
# SD 1.5 U-Net (ldm)
|
||||||
|
"input_blocks.4.1.transformer_blocks.0.attn1",
|
||||||
|
"input_blocks.5.1.transformer_blocks.0.attn1",
|
||||||
|
"output_blocks.6.1.transformer_blocks.0.attn1",
|
||||||
|
"output_blocks.7.1.transformer_blocks.0.attn1",
|
||||||
|
"output_blocks.8.1.transformer_blocks.0.attn1",
|
||||||
|
],
|
||||||
|
2: [
|
||||||
|
# SD 1.5 U-Net (diffusers)
|
||||||
|
"down_blocks.2.attentions.0.transformer_blocks.0.attn1",
|
||||||
|
"down_blocks.2.attentions.1.transformer_blocks.0.attn1",
|
||||||
|
"up_blocks.1.attentions.0.transformer_blocks.0.attn1",
|
||||||
|
"up_blocks.1.attentions.1.transformer_blocks.0.attn1",
|
||||||
|
"up_blocks.1.attentions.2.transformer_blocks.0.attn1",
|
||||||
|
# SD 1.5 U-Net (ldm)
|
||||||
|
"input_blocks.7.1.transformer_blocks.0.attn1",
|
||||||
|
"input_blocks.8.1.transformer_blocks.0.attn1",
|
||||||
|
"output_blocks.3.1.transformer_blocks.0.attn1",
|
||||||
|
"output_blocks.4.1.transformer_blocks.0.attn1",
|
||||||
|
"output_blocks.5.1.transformer_blocks.0.attn1",
|
||||||
|
],
|
||||||
|
3: [
|
||||||
|
# SD 1.5 U-Net (diffusers)
|
||||||
|
"mid_block.attentions.0.transformer_blocks.0.attn1",
|
||||||
|
# SD 1.5 U-Net (ldm)
|
||||||
|
"middle_block.1.transformer_blocks.0.attn1",
|
||||||
|
],
|
||||||
|
}
|
||||||
|
# XL layers, thanks for GitHub@gel-crabs for the help
|
||||||
|
DEPTH_LAYERS_XL = {
|
||||||
|
0: [
|
||||||
|
# SD 1.5 U-Net (diffusers)
|
||||||
|
"down_blocks.0.attentions.0.transformer_blocks.0.attn1",
|
||||||
|
"down_blocks.0.attentions.1.transformer_blocks.0.attn1",
|
||||||
|
"up_blocks.3.attentions.0.transformer_blocks.0.attn1",
|
||||||
|
"up_blocks.3.attentions.1.transformer_blocks.0.attn1",
|
||||||
|
"up_blocks.3.attentions.2.transformer_blocks.0.attn1",
|
||||||
|
# SD 1.5 U-Net (ldm)
|
||||||
|
"input_blocks.4.1.transformer_blocks.0.attn1",
|
||||||
|
"input_blocks.5.1.transformer_blocks.0.attn1",
|
||||||
|
"output_blocks.3.1.transformer_blocks.0.attn1",
|
||||||
|
"output_blocks.4.1.transformer_blocks.0.attn1",
|
||||||
|
"output_blocks.5.1.transformer_blocks.0.attn1",
|
||||||
|
# SD 1.5 VAE
|
||||||
|
"decoder.mid_block.attentions.0",
|
||||||
|
"decoder.mid.attn_1",
|
||||||
|
],
|
||||||
|
1: [
|
||||||
|
# SD 1.5 U-Net (diffusers)
|
||||||
|
#"down_blocks.1.attentions.0.transformer_blocks.0.attn1",
|
||||||
|
#"down_blocks.1.attentions.1.transformer_blocks.0.attn1",
|
||||||
|
#"up_blocks.2.attentions.0.transformer_blocks.0.attn1",
|
||||||
|
#"up_blocks.2.attentions.1.transformer_blocks.0.attn1",
|
||||||
|
#"up_blocks.2.attentions.2.transformer_blocks.0.attn1",
|
||||||
|
# SD 1.5 U-Net (ldm)
|
||||||
|
"input_blocks.4.1.transformer_blocks.1.attn1",
|
||||||
|
"input_blocks.5.1.transformer_blocks.1.attn1",
|
||||||
|
"output_blocks.3.1.transformer_blocks.1.attn1",
|
||||||
|
"output_blocks.4.1.transformer_blocks.1.attn1",
|
||||||
|
"output_blocks.5.1.transformer_blocks.1.attn1",
|
||||||
|
"input_blocks.7.1.transformer_blocks.0.attn1",
|
||||||
|
"input_blocks.8.1.transformer_blocks.0.attn1",
|
||||||
|
"output_blocks.0.1.transformer_blocks.0.attn1",
|
||||||
|
"output_blocks.1.1.transformer_blocks.0.attn1",
|
||||||
|
"output_blocks.2.1.transformer_blocks.0.attn1",
|
||||||
|
"input_blocks.7.1.transformer_blocks.1.attn1",
|
||||||
|
"input_blocks.8.1.transformer_blocks.1.attn1",
|
||||||
|
"output_blocks.0.1.transformer_blocks.1.attn1",
|
||||||
|
"output_blocks.1.1.transformer_blocks.1.attn1",
|
||||||
|
"output_blocks.2.1.transformer_blocks.1.attn1",
|
||||||
|
"input_blocks.7.1.transformer_blocks.2.attn1",
|
||||||
|
"input_blocks.8.1.transformer_blocks.2.attn1",
|
||||||
|
"output_blocks.0.1.transformer_blocks.2.attn1",
|
||||||
|
"output_blocks.1.1.transformer_blocks.2.attn1",
|
||||||
|
"output_blocks.2.1.transformer_blocks.2.attn1",
|
||||||
|
"input_blocks.7.1.transformer_blocks.3.attn1",
|
||||||
|
"input_blocks.8.1.transformer_blocks.3.attn1",
|
||||||
|
"output_blocks.0.1.transformer_blocks.3.attn1",
|
||||||
|
"output_blocks.1.1.transformer_blocks.3.attn1",
|
||||||
|
"output_blocks.2.1.transformer_blocks.3.attn1",
|
||||||
|
"input_blocks.7.1.transformer_blocks.4.attn1",
|
||||||
|
"input_blocks.8.1.transformer_blocks.4.attn1",
|
||||||
|
"output_blocks.0.1.transformer_blocks.4.attn1",
|
||||||
|
"output_blocks.1.1.transformer_blocks.4.attn1",
|
||||||
|
"output_blocks.2.1.transformer_blocks.4.attn1",
|
||||||
|
"input_blocks.7.1.transformer_blocks.5.attn1",
|
||||||
|
"input_blocks.8.1.transformer_blocks.5.attn1",
|
||||||
|
"output_blocks.0.1.transformer_blocks.5.attn1",
|
||||||
|
"output_blocks.1.1.transformer_blocks.5.attn1",
|
||||||
|
"output_blocks.2.1.transformer_blocks.5.attn1",
|
||||||
|
"input_blocks.7.1.transformer_blocks.6.attn1",
|
||||||
|
"input_blocks.8.1.transformer_blocks.6.attn1",
|
||||||
|
"output_blocks.0.1.transformer_blocks.6.attn1",
|
||||||
|
"output_blocks.1.1.transformer_blocks.6.attn1",
|
||||||
|
"output_blocks.2.1.transformer_blocks.6.attn1",
|
||||||
|
"input_blocks.7.1.transformer_blocks.7.attn1",
|
||||||
|
"input_blocks.8.1.transformer_blocks.7.attn1",
|
||||||
|
"output_blocks.0.1.transformer_blocks.7.attn1",
|
||||||
|
"output_blocks.1.1.transformer_blocks.7.attn1",
|
||||||
|
"output_blocks.2.1.transformer_blocks.7.attn1",
|
||||||
|
"input_blocks.7.1.transformer_blocks.8.attn1",
|
||||||
|
"input_blocks.8.1.transformer_blocks.8.attn1",
|
||||||
|
"output_blocks.0.1.transformer_blocks.8.attn1",
|
||||||
|
"output_blocks.1.1.transformer_blocks.8.attn1",
|
||||||
|
"output_blocks.2.1.transformer_blocks.8.attn1",
|
||||||
|
"input_blocks.7.1.transformer_blocks.9.attn1",
|
||||||
|
"input_blocks.8.1.transformer_blocks.9.attn1",
|
||||||
|
"output_blocks.0.1.transformer_blocks.9.attn1",
|
||||||
|
"output_blocks.1.1.transformer_blocks.9.attn1",
|
||||||
|
"output_blocks.2.1.transformer_blocks.9.attn1",
|
||||||
|
],
|
||||||
|
2: [
|
||||||
|
# SD 1.5 U-Net (diffusers)
|
||||||
|
"mid_block.attentions.0.transformer_blocks.0.attn1",
|
||||||
|
# SD 1.5 U-Net (ldm)
|
||||||
|
"middle_block.1.transformer_blocks.0.attn1",
|
||||||
|
"middle_block.1.transformer_blocks.1.attn1",
|
||||||
|
"middle_block.1.transformer_blocks.2.attn1",
|
||||||
|
"middle_block.1.transformer_blocks.3.attn1",
|
||||||
|
"middle_block.1.transformer_blocks.4.attn1",
|
||||||
|
"middle_block.1.transformer_blocks.5.attn1",
|
||||||
|
"middle_block.1.transformer_blocks.6.attn1",
|
||||||
|
"middle_block.1.transformer_blocks.7.attn1",
|
||||||
|
"middle_block.1.transformer_blocks.8.attn1",
|
||||||
|
"middle_block.1.transformer_blocks.9.attn1",
|
||||||
|
],
|
||||||
|
3 : [] # TODO - separate layers for SD-XL
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
RNG_INSTANCE = random.Random()
|
||||||
|
|
||||||
|
|
||||||
|
def random_divisor(value: int, min_value: int, /, max_options: int = 1) -> int:
|
||||||
|
"""
|
||||||
|
Returns a random divisor of value that
|
||||||
|
x * min_value <= value
|
||||||
|
if max_options is 1, the behavior is deterministic
|
||||||
|
"""
|
||||||
|
min_value = min(min_value, value)
|
||||||
|
|
||||||
|
# All big divisors of value (inclusive)
|
||||||
|
divisors = [i for i in range(min_value, value + 1) if value % i == 0] # divisors in small -> big order
|
||||||
|
|
||||||
|
ns = [value // i for i in divisors[:max_options]] # has at least 1 element # big -> small order
|
||||||
|
|
||||||
|
idx = RNG_INSTANCE.randint(0, len(ns) - 1)
|
||||||
|
|
||||||
|
return ns[idx]
|
||||||
|
|
||||||
|
|
||||||
|
def set_hypertile_seed(seed: int) -> None:
|
||||||
|
RNG_INSTANCE.seed(seed)
|
||||||
|
|
||||||
|
|
||||||
|
@functools.cache
|
||||||
|
def largest_tile_size_available(width: int, height: int) -> int:
|
||||||
|
"""
|
||||||
|
Calculates the largest tile size available for a given width and height
|
||||||
|
Tile size is always a power of 2
|
||||||
|
"""
|
||||||
|
gcd = math.gcd(width, height)
|
||||||
|
largest_tile_size_available = 1
|
||||||
|
while gcd % (largest_tile_size_available * 2) == 0:
|
||||||
|
largest_tile_size_available *= 2
|
||||||
|
return largest_tile_size_available
|
||||||
|
|
||||||
|
|
||||||
|
def iterative_closest_divisors(hw:int, aspect_ratio:float) -> tuple[int, int]:
|
||||||
|
"""
|
||||||
|
Finds h and w such that h*w = hw and h/w = aspect_ratio
|
||||||
|
We check all possible divisors of hw and return the closest to the aspect ratio
|
||||||
|
"""
|
||||||
|
divisors = [i for i in range(2, hw + 1) if hw % i == 0] # all divisors of hw
|
||||||
|
pairs = [(i, hw // i) for i in divisors] # all pairs of divisors of hw
|
||||||
|
ratios = [w/h for h, w in pairs] # all ratios of pairs of divisors of hw
|
||||||
|
closest_ratio = min(ratios, key=lambda x: abs(x - aspect_ratio)) # closest ratio to aspect_ratio
|
||||||
|
closest_pair = pairs[ratios.index(closest_ratio)] # closest pair of divisors to aspect_ratio
|
||||||
|
return closest_pair
|
||||||
|
|
||||||
|
|
||||||
|
@cache
|
||||||
|
def find_hw_candidates(hw:int, aspect_ratio:float) -> tuple[int, int]:
|
||||||
|
"""
|
||||||
|
Finds h and w such that h*w = hw and h/w = aspect_ratio
|
||||||
|
"""
|
||||||
|
h, w = round(math.sqrt(hw * aspect_ratio)), round(math.sqrt(hw / aspect_ratio))
|
||||||
|
# find h and w such that h*w = hw and h/w = aspect_ratio
|
||||||
|
if h * w != hw:
|
||||||
|
w_candidate = hw / h
|
||||||
|
# check if w is an integer
|
||||||
|
if not w_candidate.is_integer():
|
||||||
|
h_candidate = hw / w
|
||||||
|
# check if h is an integer
|
||||||
|
if not h_candidate.is_integer():
|
||||||
|
return iterative_closest_divisors(hw, aspect_ratio)
|
||||||
|
else:
|
||||||
|
h = int(h_candidate)
|
||||||
|
else:
|
||||||
|
w = int(w_candidate)
|
||||||
|
return h, w
|
||||||
|
|
||||||
|
|
||||||
|
def self_attn_forward(params: HypertileParams, scale_depth=True) -> Callable:
|
||||||
|
|
||||||
|
@wraps(params.forward)
|
||||||
|
def wrapper(*args, **kwargs):
|
||||||
|
if not params.enabled:
|
||||||
|
return params.forward(*args, **kwargs)
|
||||||
|
|
||||||
|
latent_tile_size = max(128, params.tile_size) // 8
|
||||||
|
x = args[0]
|
||||||
|
|
||||||
|
# VAE
|
||||||
|
if x.ndim == 4:
|
||||||
|
b, c, h, w = x.shape
|
||||||
|
|
||||||
|
nh = random_divisor(h, latent_tile_size, params.swap_size)
|
||||||
|
nw = random_divisor(w, latent_tile_size, params.swap_size)
|
||||||
|
|
||||||
|
if nh * nw > 1:
|
||||||
|
x = rearrange(x, "b c (nh h) (nw w) -> (b nh nw) c h w", nh=nh, nw=nw) # split into nh * nw tiles
|
||||||
|
|
||||||
|
out = params.forward(x, *args[1:], **kwargs)
|
||||||
|
|
||||||
|
if nh * nw > 1:
|
||||||
|
out = rearrange(out, "(b nh nw) c h w -> b c (nh h) (nw w)", nh=nh, nw=nw)
|
||||||
|
|
||||||
|
# U-Net
|
||||||
|
else:
|
||||||
|
hw: int = x.size(1)
|
||||||
|
h, w = find_hw_candidates(hw, params.aspect_ratio)
|
||||||
|
assert h * w == hw, f"Invalid aspect ratio {params.aspect_ratio} for input of shape {x.shape}, hw={hw}, h={h}, w={w}"
|
||||||
|
|
||||||
|
factor = 2 ** params.depth if scale_depth else 1
|
||||||
|
nh = random_divisor(h, latent_tile_size * factor, params.swap_size)
|
||||||
|
nw = random_divisor(w, latent_tile_size * factor, params.swap_size)
|
||||||
|
|
||||||
|
if nh * nw > 1:
|
||||||
|
x = rearrange(x, "b (nh h nw w) c -> (b nh nw) (h w) c", h=h // nh, w=w // nw, nh=nh, nw=nw)
|
||||||
|
|
||||||
|
out = params.forward(x, *args[1:], **kwargs)
|
||||||
|
|
||||||
|
if nh * nw > 1:
|
||||||
|
out = rearrange(out, "(b nh nw) hw c -> b nh nw hw c", nh=nh, nw=nw)
|
||||||
|
out = rearrange(out, "b nh nw (h w) c -> b (nh h nw w) c", h=h // nh, w=w // nw)
|
||||||
|
|
||||||
|
return out
|
||||||
|
|
||||||
|
return wrapper
|
||||||
|
|
||||||
|
|
||||||
|
def hypertile_hook_model(model: nn.Module, width, height, *, enable=False, tile_size_max=128, swap_size=1, max_depth=3, is_sdxl=False):
|
||||||
|
hypertile_layers = getattr(model, "__webui_hypertile_layers", None)
|
||||||
|
if hypertile_layers is None:
|
||||||
|
if not enable:
|
||||||
|
return
|
||||||
|
|
||||||
|
hypertile_layers = {}
|
||||||
|
layers = DEPTH_LAYERS_XL if is_sdxl else DEPTH_LAYERS
|
||||||
|
|
||||||
|
for depth in range(4):
|
||||||
|
for layer_name, module in model.named_modules():
|
||||||
|
if any(layer_name.endswith(try_name) for try_name in layers[depth]):
|
||||||
|
params = HypertileParams()
|
||||||
|
module.__webui_hypertile_params = params
|
||||||
|
params.forward = module.forward
|
||||||
|
params.depth = depth
|
||||||
|
params.layer_name = layer_name
|
||||||
|
module.forward = self_attn_forward(params)
|
||||||
|
|
||||||
|
hypertile_layers[layer_name] = 1
|
||||||
|
|
||||||
|
model.__webui_hypertile_layers = hypertile_layers
|
||||||
|
|
||||||
|
aspect_ratio = width / height
|
||||||
|
tile_size = min(largest_tile_size_available(width, height), tile_size_max)
|
||||||
|
|
||||||
|
for layer_name, module in model.named_modules():
|
||||||
|
if layer_name in hypertile_layers:
|
||||||
|
params = module.__webui_hypertile_params
|
||||||
|
|
||||||
|
params.tile_size = tile_size
|
||||||
|
params.swap_size = swap_size
|
||||||
|
params.aspect_ratio = aspect_ratio
|
||||||
|
params.enabled = enable and params.depth <= max_depth
|
||||||
@@ -0,0 +1,73 @@
|
|||||||
|
import hypertile
|
||||||
|
from modules import scripts, script_callbacks, shared
|
||||||
|
|
||||||
|
|
||||||
|
class ScriptHypertile(scripts.Script):
|
||||||
|
name = "Hypertile"
|
||||||
|
|
||||||
|
def title(self):
|
||||||
|
return self.name
|
||||||
|
|
||||||
|
def show(self, is_img2img):
|
||||||
|
return scripts.AlwaysVisible
|
||||||
|
|
||||||
|
def process(self, p, *args):
|
||||||
|
hypertile.set_hypertile_seed(p.all_seeds[0])
|
||||||
|
|
||||||
|
configure_hypertile(p.width, p.height, enable_unet=shared.opts.hypertile_enable_unet)
|
||||||
|
|
||||||
|
def before_hr(self, p, *args):
|
||||||
|
configure_hypertile(p.hr_upscale_to_x, p.hr_upscale_to_y, enable_unet=shared.opts.hypertile_enable_unet_secondpass or shared.opts.hypertile_enable_unet)
|
||||||
|
|
||||||
|
|
||||||
|
def configure_hypertile(width, height, enable_unet=True):
|
||||||
|
hypertile.hypertile_hook_model(
|
||||||
|
shared.sd_model.first_stage_model,
|
||||||
|
width,
|
||||||
|
height,
|
||||||
|
swap_size=shared.opts.hypertile_swap_size_vae,
|
||||||
|
max_depth=shared.opts.hypertile_max_depth_vae,
|
||||||
|
tile_size_max=shared.opts.hypertile_max_tile_vae,
|
||||||
|
enable=shared.opts.hypertile_enable_vae,
|
||||||
|
)
|
||||||
|
|
||||||
|
hypertile.hypertile_hook_model(
|
||||||
|
shared.sd_model.model,
|
||||||
|
width,
|
||||||
|
height,
|
||||||
|
swap_size=shared.opts.hypertile_swap_size_unet,
|
||||||
|
max_depth=shared.opts.hypertile_max_depth_unet,
|
||||||
|
tile_size_max=shared.opts.hypertile_max_tile_unet,
|
||||||
|
enable=enable_unet,
|
||||||
|
is_sdxl=shared.sd_model.is_sdxl
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def on_ui_settings():
|
||||||
|
import gradio as gr
|
||||||
|
|
||||||
|
options = {
|
||||||
|
"hypertile_explanation": shared.OptionHTML("""
|
||||||
|
<a href='https://github.com/tfernd/HyperTile'>Hypertile</a> optimizes the self-attention layer within U-Net and VAE models,
|
||||||
|
resulting in a reduction in computation time ranging from 1 to 4 times. The larger the generated image is, the greater the
|
||||||
|
benefit.
|
||||||
|
"""),
|
||||||
|
|
||||||
|
"hypertile_enable_unet": shared.OptionInfo(False, "Enable Hypertile U-Net").info("noticeable change in details of the generated picture; if enabled, overrides the setting below"),
|
||||||
|
"hypertile_enable_unet_secondpass": shared.OptionInfo(False, "Enable Hypertile U-Net for hires fix second pass"),
|
||||||
|
"hypertile_max_depth_unet": shared.OptionInfo(3, "Hypertile U-Net max depth", gr.Slider, {"minimum": 0, "maximum": 3, "step": 1}),
|
||||||
|
"hypertile_max_tile_unet": shared.OptionInfo(256, "Hypertile U-net max tile size", gr.Slider, {"minimum": 0, "maximum": 512, "step": 16}),
|
||||||
|
"hypertile_swap_size_unet": shared.OptionInfo(3, "Hypertile U-net swap size", gr.Slider, {"minimum": 0, "maximum": 6, "step": 1}),
|
||||||
|
|
||||||
|
"hypertile_enable_vae": shared.OptionInfo(False, "Enable Hypertile VAE").info("minimal change in the generated picture"),
|
||||||
|
"hypertile_max_depth_vae": shared.OptionInfo(3, "Hypertile VAE max depth", gr.Slider, {"minimum": 0, "maximum": 3, "step": 1}),
|
||||||
|
"hypertile_max_tile_vae": shared.OptionInfo(128, "Hypertile VAE max tile size", gr.Slider, {"minimum": 0, "maximum": 512, "step": 16}),
|
||||||
|
"hypertile_swap_size_vae": shared.OptionInfo(3, "Hypertile VAE swap size ", gr.Slider, {"minimum": 0, "maximum": 6, "step": 1}),
|
||||||
|
}
|
||||||
|
|
||||||
|
for name, opt in options.items():
|
||||||
|
opt.section = ('hypertile', "Hypertile")
|
||||||
|
shared.opts.add_option(name, opt)
|
||||||
|
|
||||||
|
|
||||||
|
script_callbacks.on_ui_settings(on_ui_settings)
|
||||||
@@ -12,6 +12,8 @@ function isMobile() {
|
|||||||
}
|
}
|
||||||
|
|
||||||
function reportWindowSize() {
|
function reportWindowSize() {
|
||||||
|
if (gradioApp().querySelector('.toprow-compact-tools')) return; // not applicable for compact prompt layout
|
||||||
|
|
||||||
var currentlyMobile = isMobile();
|
var currentlyMobile = isMobile();
|
||||||
if (currentlyMobile == isSetupForMobile) return;
|
if (currentlyMobile == isSetupForMobile) return;
|
||||||
isSetupForMobile = currentlyMobile;
|
isSetupForMobile = currentlyMobile;
|
||||||
|
|||||||
@@ -19,16 +19,28 @@ function keyupEditAttention(event) {
|
|||||||
let beforeParen = before.lastIndexOf(OPEN);
|
let beforeParen = before.lastIndexOf(OPEN);
|
||||||
if (beforeParen == -1) return false;
|
if (beforeParen == -1) return false;
|
||||||
|
|
||||||
|
let beforeClosingParen = before.lastIndexOf(CLOSE);
|
||||||
|
if (beforeClosingParen != -1 && beforeClosingParen > beforeParen) return false;
|
||||||
|
|
||||||
// Find closing parenthesis around current cursor
|
// Find closing parenthesis around current cursor
|
||||||
const after = text.substring(selectionStart);
|
const after = text.substring(selectionStart);
|
||||||
let afterParen = after.indexOf(CLOSE);
|
let afterParen = after.indexOf(CLOSE);
|
||||||
if (afterParen == -1) return false;
|
if (afterParen == -1) return false;
|
||||||
|
|
||||||
|
let afterOpeningParen = after.indexOf(OPEN);
|
||||||
|
if (afterOpeningParen != -1 && afterOpeningParen < afterParen) return false;
|
||||||
|
|
||||||
// Set the selection to the text between the parenthesis
|
// Set the selection to the text between the parenthesis
|
||||||
const parenContent = text.substring(beforeParen + 1, selectionStart + afterParen);
|
const parenContent = text.substring(beforeParen + 1, selectionStart + afterParen);
|
||||||
|
if (/.*:-?[\d.]+/s.test(parenContent)) {
|
||||||
const lastColon = parenContent.lastIndexOf(":");
|
const lastColon = parenContent.lastIndexOf(":");
|
||||||
selectionStart = beforeParen + 1;
|
selectionStart = beforeParen + 1;
|
||||||
selectionEnd = selectionStart + lastColon;
|
selectionEnd = selectionStart + lastColon;
|
||||||
|
} else {
|
||||||
|
selectionStart = beforeParen + 1;
|
||||||
|
selectionEnd = selectionStart + parenContent.length;
|
||||||
|
}
|
||||||
|
|
||||||
target.setSelectionRange(selectionStart, selectionEnd);
|
target.setSelectionRange(selectionStart, selectionEnd);
|
||||||
return true;
|
return true;
|
||||||
}
|
}
|
||||||
@@ -57,7 +69,7 @@ function keyupEditAttention(event) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// If the user hasn't selected anything, let's select their current parenthesis block or word
|
// If the user hasn't selected anything, let's select their current parenthesis block or word
|
||||||
if (!selectCurrentParenthesisBlock('<', '>') && !selectCurrentParenthesisBlock('(', ')')) {
|
if (!selectCurrentParenthesisBlock('<', '>') && !selectCurrentParenthesisBlock('(', ')') && !selectCurrentParenthesisBlock('[', ']')) {
|
||||||
selectCurrentWord();
|
selectCurrentWord();
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -65,33 +77,54 @@ function keyupEditAttention(event) {
|
|||||||
|
|
||||||
var closeCharacter = ')';
|
var closeCharacter = ')';
|
||||||
var delta = opts.keyedit_precision_attention;
|
var delta = opts.keyedit_precision_attention;
|
||||||
|
var start = selectionStart > 0 ? text[selectionStart - 1] : "";
|
||||||
|
var end = text[selectionEnd];
|
||||||
|
|
||||||
if (selectionStart > 0 && text[selectionStart - 1] == '<') {
|
if (start == '<') {
|
||||||
closeCharacter = '>';
|
closeCharacter = '>';
|
||||||
delta = opts.keyedit_precision_extra;
|
delta = opts.keyedit_precision_extra;
|
||||||
} else if (selectionStart == 0 || text[selectionStart - 1] != "(") {
|
} else if (start == '(' && end == ')' || start == '[' && end == ']') { // convert old-style (((emphasis)))
|
||||||
|
let numParen = 0;
|
||||||
|
|
||||||
|
while (text[selectionStart - numParen - 1] == start && text[selectionEnd + numParen] == end) {
|
||||||
|
numParen++;
|
||||||
|
}
|
||||||
|
|
||||||
|
if (start == "[") {
|
||||||
|
weight = (1 / 1.1) ** numParen;
|
||||||
|
} else {
|
||||||
|
weight = 1.1 ** numParen;
|
||||||
|
}
|
||||||
|
|
||||||
|
weight = Math.round(weight / opts.keyedit_precision_attention) * opts.keyedit_precision_attention;
|
||||||
|
|
||||||
|
text = text.slice(0, selectionStart - numParen) + "(" + text.slice(selectionStart, selectionEnd) + ":" + weight + ")" + text.slice(selectionEnd + numParen);
|
||||||
|
selectionStart -= numParen - 1;
|
||||||
|
selectionEnd -= numParen - 1;
|
||||||
|
} else if (start != '(') {
|
||||||
// do not include spaces at the end
|
// do not include spaces at the end
|
||||||
while (selectionEnd > selectionStart && text[selectionEnd - 1] == ' ') {
|
while (selectionEnd > selectionStart && text[selectionEnd - 1] == ' ') {
|
||||||
selectionEnd -= 1;
|
selectionEnd--;
|
||||||
}
|
}
|
||||||
|
|
||||||
if (selectionStart == selectionEnd) {
|
if (selectionStart == selectionEnd) {
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
|
|
||||||
text = text.slice(0, selectionStart) + "(" + text.slice(selectionStart, selectionEnd) + ":1.0)" + text.slice(selectionEnd);
|
text = text.slice(0, selectionStart) + "(" + text.slice(selectionStart, selectionEnd) + ":1.0)" + text.slice(selectionEnd);
|
||||||
|
|
||||||
selectionStart += 1;
|
selectionStart++;
|
||||||
selectionEnd += 1;
|
selectionEnd++;
|
||||||
}
|
}
|
||||||
|
|
||||||
var end = text.slice(selectionEnd + 1).indexOf(closeCharacter) + 1;
|
if (text[selectionEnd] != ':') return;
|
||||||
var weight = parseFloat(text.slice(selectionEnd + 1, selectionEnd + end));
|
var weightLength = text.slice(selectionEnd + 1).indexOf(closeCharacter) + 1;
|
||||||
|
var weight = parseFloat(text.slice(selectionEnd + 1, selectionEnd + weightLength));
|
||||||
if (isNaN(weight)) return;
|
if (isNaN(weight)) return;
|
||||||
|
|
||||||
weight += isPlus ? delta : -delta;
|
weight += isPlus ? delta : -delta;
|
||||||
weight = parseFloat(weight.toPrecision(12));
|
weight = parseFloat(weight.toPrecision(12));
|
||||||
if (String(weight).length == 1) weight += ".0";
|
if (Number.isInteger(weight)) weight += ".0";
|
||||||
|
|
||||||
if (closeCharacter == ')' && weight == 1) {
|
if (closeCharacter == ')' && weight == 1) {
|
||||||
var endParenPos = text.substring(selectionEnd).indexOf(')');
|
var endParenPos = text.substring(selectionEnd).indexOf(')');
|
||||||
@@ -99,7 +132,7 @@ function keyupEditAttention(event) {
|
|||||||
selectionStart--;
|
selectionStart--;
|
||||||
selectionEnd--;
|
selectionEnd--;
|
||||||
} else {
|
} else {
|
||||||
text = text.slice(0, selectionEnd + 1) + weight + text.slice(selectionEnd + end);
|
text = text.slice(0, selectionEnd + 1) + weight + text.slice(selectionEnd + weightLength);
|
||||||
}
|
}
|
||||||
|
|
||||||
target.focus();
|
target.focus();
|
||||||
|
|||||||
+53
-11
@@ -26,8 +26,9 @@ function setupExtraNetworksForTab(tabname) {
|
|||||||
var refresh = gradioApp().getElementById(tabname + '_extra_refresh');
|
var refresh = gradioApp().getElementById(tabname + '_extra_refresh');
|
||||||
var showDirsDiv = gradioApp().getElementById(tabname + '_extra_show_dirs');
|
var showDirsDiv = gradioApp().getElementById(tabname + '_extra_show_dirs');
|
||||||
var showDirs = gradioApp().querySelector('#' + tabname + '_extra_show_dirs input');
|
var showDirs = gradioApp().querySelector('#' + tabname + '_extra_show_dirs input');
|
||||||
|
var promptContainer = gradioApp().querySelector('.prompt-container-compact#' + tabname + '_prompt_container');
|
||||||
|
var negativePrompt = gradioApp().querySelector('#' + tabname + '_neg_prompt');
|
||||||
|
|
||||||
sort.dataset.sortkey = 'sortDefault';
|
|
||||||
tabs.appendChild(searchDiv);
|
tabs.appendChild(searchDiv);
|
||||||
tabs.appendChild(sort);
|
tabs.appendChild(sort);
|
||||||
tabs.appendChild(sortOrder);
|
tabs.appendChild(sortOrder);
|
||||||
@@ -49,20 +50,23 @@ function setupExtraNetworksForTab(tabname) {
|
|||||||
|
|
||||||
elem.style.display = visible ? "" : "none";
|
elem.style.display = visible ? "" : "none";
|
||||||
});
|
});
|
||||||
|
|
||||||
|
applySort();
|
||||||
};
|
};
|
||||||
|
|
||||||
var applySort = function() {
|
var applySort = function() {
|
||||||
|
var cards = gradioApp().querySelectorAll('#' + tabname + '_extra_tabs div.card');
|
||||||
|
|
||||||
var reverse = sortOrder.classList.contains("sortReverse");
|
var reverse = sortOrder.classList.contains("sortReverse");
|
||||||
var sortKey = sort.querySelector("input").value.toLowerCase().replace("sort", "").replaceAll(" ", "_").replace(/_+$/, "").trim();
|
var sortKey = sort.querySelector("input").value.toLowerCase().replace("sort", "").replaceAll(" ", "_").replace(/_+$/, "").trim() || "name";
|
||||||
sortKey = sortKey ? "sort" + sortKey.charAt(0).toUpperCase() + sortKey.slice(1) : "";
|
sortKey = "sort" + sortKey.charAt(0).toUpperCase() + sortKey.slice(1);
|
||||||
var sortKeyStore = sortKey ? sortKey + (reverse ? "Reverse" : "") : "";
|
var sortKeyStore = sortKey + "-" + (reverse ? "Descending" : "Ascending") + "-" + cards.length;
|
||||||
if (!sortKey || sortKeyStore == sort.dataset.sortkey) {
|
|
||||||
|
if (sortKeyStore == sort.dataset.sortkey) {
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
|
|
||||||
sort.dataset.sortkey = sortKeyStore;
|
sort.dataset.sortkey = sortKeyStore;
|
||||||
|
|
||||||
var cards = gradioApp().querySelectorAll('#' + tabname + '_extra_tabs div.card');
|
|
||||||
cards.forEach(function(card) {
|
cards.forEach(function(card) {
|
||||||
card.originalParentElement = card.parentElement;
|
card.originalParentElement = card.parentElement;
|
||||||
});
|
});
|
||||||
@@ -88,15 +92,13 @@ function setupExtraNetworksForTab(tabname) {
|
|||||||
};
|
};
|
||||||
|
|
||||||
search.addEventListener("input", applyFilter);
|
search.addEventListener("input", applyFilter);
|
||||||
applyFilter();
|
|
||||||
["change", "blur", "click"].forEach(function(evt) {
|
|
||||||
sort.querySelector("input").addEventListener(evt, applySort);
|
|
||||||
});
|
|
||||||
sortOrder.addEventListener("click", function() {
|
sortOrder.addEventListener("click", function() {
|
||||||
sortOrder.classList.toggle("sortReverse");
|
sortOrder.classList.toggle("sortReverse");
|
||||||
applySort();
|
applySort();
|
||||||
});
|
});
|
||||||
|
applyFilter();
|
||||||
|
|
||||||
|
extraNetworksApplySort[tabname] = applySort;
|
||||||
extraNetworksApplyFilter[tabname] = applyFilter;
|
extraNetworksApplyFilter[tabname] = applyFilter;
|
||||||
|
|
||||||
var showDirsUpdate = function() {
|
var showDirsUpdate = function() {
|
||||||
@@ -109,11 +111,51 @@ function setupExtraNetworksForTab(tabname) {
|
|||||||
showDirsUpdate();
|
showDirsUpdate();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
function extraNetworksMovePromptToTab(tabname, id, showPrompt, showNegativePrompt) {
|
||||||
|
if (!gradioApp().querySelector('.toprow-compact-tools')) return; // only applicable for compact prompt layout
|
||||||
|
|
||||||
|
var promptContainer = gradioApp().getElementById(tabname + '_prompt_container');
|
||||||
|
var prompt = gradioApp().getElementById(tabname + '_prompt_row');
|
||||||
|
var negPrompt = gradioApp().getElementById(tabname + '_neg_prompt_row');
|
||||||
|
var elem = id ? gradioApp().getElementById(id) : null;
|
||||||
|
|
||||||
|
if (showNegativePrompt && elem) {
|
||||||
|
elem.insertBefore(negPrompt, elem.firstChild);
|
||||||
|
} else {
|
||||||
|
promptContainer.insertBefore(negPrompt, promptContainer.firstChild);
|
||||||
|
}
|
||||||
|
|
||||||
|
if (showPrompt && elem) {
|
||||||
|
elem.insertBefore(prompt, elem.firstChild);
|
||||||
|
} else {
|
||||||
|
promptContainer.insertBefore(prompt, promptContainer.firstChild);
|
||||||
|
}
|
||||||
|
|
||||||
|
if (elem) {
|
||||||
|
elem.classList.toggle('extra-page-prompts-active', showNegativePrompt || showPrompt);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
function extraNetworksUrelatedTabSelected(tabname) { // called from python when user selects an unrelated tab (generate)
|
||||||
|
extraNetworksMovePromptToTab(tabname, '', false, false);
|
||||||
|
}
|
||||||
|
|
||||||
|
function extraNetworksTabSelected(tabname, id, showPrompt, showNegativePrompt) { // called from python when user selects an extra networks tab
|
||||||
|
extraNetworksMovePromptToTab(tabname, id, showPrompt, showNegativePrompt);
|
||||||
|
|
||||||
|
}
|
||||||
|
|
||||||
function applyExtraNetworkFilter(tabname) {
|
function applyExtraNetworkFilter(tabname) {
|
||||||
setTimeout(extraNetworksApplyFilter[tabname], 1);
|
setTimeout(extraNetworksApplyFilter[tabname], 1);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
function applyExtraNetworkSort(tabname) {
|
||||||
|
setTimeout(extraNetworksApplySort[tabname], 1);
|
||||||
|
}
|
||||||
|
|
||||||
var extraNetworksApplyFilter = {};
|
var extraNetworksApplyFilter = {};
|
||||||
|
var extraNetworksApplySort = {};
|
||||||
var activePromptTextarea = {};
|
var activePromptTextarea = {};
|
||||||
|
|
||||||
function setupExtraNetworks() {
|
function setupExtraNetworks() {
|
||||||
|
|||||||
@@ -33,8 +33,11 @@ function updateOnBackgroundChange() {
|
|||||||
const modalImage = gradioApp().getElementById("modalImage");
|
const modalImage = gradioApp().getElementById("modalImage");
|
||||||
if (modalImage && modalImage.offsetParent) {
|
if (modalImage && modalImage.offsetParent) {
|
||||||
let currentButton = selected_gallery_button();
|
let currentButton = selected_gallery_button();
|
||||||
|
let preview = gradioApp().querySelectorAll('.livePreview > img');
|
||||||
if (currentButton?.children?.length > 0 && modalImage.src != currentButton.children[0].src) {
|
if (preview.length > 0) {
|
||||||
|
// show preview image if available
|
||||||
|
modalImage.src = preview[preview.length - 1].src;
|
||||||
|
} else if (currentButton?.children?.length > 0 && modalImage.src != currentButton.children[0].src) {
|
||||||
modalImage.src = currentButton.children[0].src;
|
modalImage.src = currentButton.children[0].src;
|
||||||
if (modalImage.style.display === 'none') {
|
if (modalImage.style.display === 'none') {
|
||||||
const modal = gradioApp().getElementById("lightboxModal");
|
const modal = gradioApp().getElementById("lightboxModal");
|
||||||
|
|||||||
@@ -1,37 +1,68 @@
|
|||||||
var observerAccordionOpen = new MutationObserver(function(mutations) {
|
function inputAccordionChecked(id, checked) {
|
||||||
mutations.forEach(function(mutationRecord) {
|
var accordion = gradioApp().getElementById(id);
|
||||||
var elem = mutationRecord.target;
|
accordion.visibleCheckbox.checked = checked;
|
||||||
var open = elem.classList.contains('open');
|
accordion.onVisibleCheckboxChange();
|
||||||
|
}
|
||||||
var accordion = elem.parentNode;
|
|
||||||
accordion.classList.toggle('input-accordion-open', open);
|
|
||||||
|
|
||||||
var checkbox = gradioApp().querySelector('#' + accordion.id + "-checkbox input");
|
|
||||||
checkbox.checked = open;
|
|
||||||
updateInput(checkbox);
|
|
||||||
|
|
||||||
|
function setupAccordion(accordion) {
|
||||||
|
var labelWrap = accordion.querySelector('.label-wrap');
|
||||||
|
var gradioCheckbox = gradioApp().querySelector('#' + accordion.id + "-checkbox input");
|
||||||
var extra = gradioApp().querySelector('#' + accordion.id + "-extra");
|
var extra = gradioApp().querySelector('#' + accordion.id + "-extra");
|
||||||
if (extra) {
|
var span = labelWrap.querySelector('span');
|
||||||
extra.style.display = open ? "" : "none";
|
var linked = true;
|
||||||
|
|
||||||
|
var isOpen = function() {
|
||||||
|
return labelWrap.classList.contains('open');
|
||||||
|
};
|
||||||
|
|
||||||
|
var observerAccordionOpen = new MutationObserver(function(mutations) {
|
||||||
|
mutations.forEach(function(mutationRecord) {
|
||||||
|
accordion.classList.toggle('input-accordion-open', isOpen());
|
||||||
|
|
||||||
|
if (linked) {
|
||||||
|
accordion.visibleCheckbox.checked = isOpen();
|
||||||
|
accordion.onVisibleCheckboxChange();
|
||||||
}
|
}
|
||||||
});
|
});
|
||||||
});
|
});
|
||||||
|
observerAccordionOpen.observe(labelWrap, {attributes: true, attributeFilter: ['class']});
|
||||||
|
|
||||||
function inputAccordionChecked(id, checked) {
|
if (extra) {
|
||||||
var label = gradioApp().querySelector('#' + id + " .label-wrap");
|
labelWrap.insertBefore(extra, labelWrap.lastElementChild);
|
||||||
if (label.classList.contains('open') != checked) {
|
|
||||||
label.click();
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
accordion.onChecked = function(checked) {
|
||||||
|
if (isOpen() != checked) {
|
||||||
|
labelWrap.click();
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
var visibleCheckbox = document.createElement('INPUT');
|
||||||
|
visibleCheckbox.type = 'checkbox';
|
||||||
|
visibleCheckbox.checked = isOpen();
|
||||||
|
visibleCheckbox.id = accordion.id + "-visible-checkbox";
|
||||||
|
visibleCheckbox.className = gradioCheckbox.className + " input-accordion-checkbox";
|
||||||
|
span.insertBefore(visibleCheckbox, span.firstChild);
|
||||||
|
|
||||||
|
accordion.visibleCheckbox = visibleCheckbox;
|
||||||
|
accordion.onVisibleCheckboxChange = function() {
|
||||||
|
if (linked && isOpen() != visibleCheckbox.checked) {
|
||||||
|
labelWrap.click();
|
||||||
|
}
|
||||||
|
|
||||||
|
gradioCheckbox.checked = visibleCheckbox.checked;
|
||||||
|
updateInput(gradioCheckbox);
|
||||||
|
};
|
||||||
|
|
||||||
|
visibleCheckbox.addEventListener('click', function(event) {
|
||||||
|
linked = false;
|
||||||
|
event.stopPropagation();
|
||||||
|
});
|
||||||
|
visibleCheckbox.addEventListener('input', accordion.onVisibleCheckboxChange);
|
||||||
}
|
}
|
||||||
|
|
||||||
onUiLoaded(function() {
|
onUiLoaded(function() {
|
||||||
for (var accordion of gradioApp().querySelectorAll('.input-accordion')) {
|
for (var accordion of gradioApp().querySelectorAll('.input-accordion')) {
|
||||||
var labelWrap = accordion.querySelector('.label-wrap');
|
setupAccordion(accordion);
|
||||||
observerAccordionOpen.observe(labelWrap, {attributes: true, attributeFilter: ['class']});
|
|
||||||
|
|
||||||
var extra = gradioApp().querySelector('#' + accordion.id + "-extra");
|
|
||||||
if (extra) {
|
|
||||||
labelWrap.insertBefore(extra, labelWrap.lastElementChild);
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
});
|
});
|
||||||
|
|||||||
@@ -26,7 +26,11 @@ onAfterUiUpdate(function() {
|
|||||||
lastHeadImg = headImg;
|
lastHeadImg = headImg;
|
||||||
|
|
||||||
// play notification sound if available
|
// play notification sound if available
|
||||||
gradioApp().querySelector('#audio_notification audio')?.play();
|
const notificationAudio = gradioApp().querySelector('#audio_notification audio');
|
||||||
|
if (notificationAudio) {
|
||||||
|
notificationAudio.volume = opts.notification_volume / 100.0 || 1.0;
|
||||||
|
notificationAudio.play();
|
||||||
|
}
|
||||||
|
|
||||||
if (document.hasFocus()) return;
|
if (document.hasFocus()) return;
|
||||||
|
|
||||||
|
|||||||
@@ -44,3 +44,28 @@ onUiLoaded(function() {
|
|||||||
|
|
||||||
buttonShowAllPages.addEventListener("click", settingsShowAllTabs);
|
buttonShowAllPages.addEventListener("click", settingsShowAllTabs);
|
||||||
});
|
});
|
||||||
|
|
||||||
|
|
||||||
|
onOptionsChanged(function() {
|
||||||
|
if (gradioApp().querySelector('#settings .settings-category')) return;
|
||||||
|
|
||||||
|
var sectionMap = {};
|
||||||
|
gradioApp().querySelectorAll('#settings > div > button').forEach(function(x) {
|
||||||
|
sectionMap[x.textContent.trim()] = x;
|
||||||
|
});
|
||||||
|
|
||||||
|
opts._categories.forEach(function(x) {
|
||||||
|
var section = x[0];
|
||||||
|
var category = x[1];
|
||||||
|
|
||||||
|
var span = document.createElement('SPAN');
|
||||||
|
span.textContent = category;
|
||||||
|
span.className = 'settings-category';
|
||||||
|
|
||||||
|
var sectionElem = sectionMap[section];
|
||||||
|
if (!sectionElem) return;
|
||||||
|
|
||||||
|
sectionElem.parentElement.insertBefore(span, sectionElem);
|
||||||
|
});
|
||||||
|
});
|
||||||
|
|
||||||
|
|||||||
+7
-7
@@ -17,15 +17,14 @@ from fastapi.encoders import jsonable_encoder
|
|||||||
from secrets import compare_digest
|
from secrets import compare_digest
|
||||||
|
|
||||||
import modules.shared as shared
|
import modules.shared as shared
|
||||||
from modules import sd_samplers, deepbooru, sd_hijack, images, scripts, ui, postprocessing, errors, restart, shared_items, script_callbacks, generation_parameters_copypaste
|
from modules import sd_samplers, deepbooru, sd_hijack, images, scripts, ui, postprocessing, errors, restart, shared_items, script_callbacks, generation_parameters_copypaste, sd_models
|
||||||
from modules.api import models
|
from modules.api import models
|
||||||
from modules.shared import opts
|
from modules.shared import opts
|
||||||
from modules.processing import StableDiffusionProcessingTxt2Img, StableDiffusionProcessingImg2Img, process_images
|
from modules.processing import StableDiffusionProcessingTxt2Img, StableDiffusionProcessingImg2Img, process_images
|
||||||
from modules.textual_inversion.textual_inversion import create_embedding, train_embedding
|
from modules.textual_inversion.textual_inversion import create_embedding, train_embedding
|
||||||
from modules.textual_inversion.preprocess import preprocess
|
from modules.textual_inversion.preprocess import preprocess
|
||||||
from modules.hypernetworks.hypernetwork import create_hypernetwork, train_hypernetwork
|
from modules.hypernetworks.hypernetwork import create_hypernetwork, train_hypernetwork
|
||||||
from PIL import PngImagePlugin,Image
|
from PIL import PngImagePlugin, Image
|
||||||
from modules.sd_models import unload_model_weights, reload_model_weights, checkpoint_aliases
|
|
||||||
from modules.sd_models_config import find_checkpoint_config_near_filename
|
from modules.sd_models_config import find_checkpoint_config_near_filename
|
||||||
from modules.realesrgan_model import get_realesrgan_models
|
from modules.realesrgan_model import get_realesrgan_models
|
||||||
from modules import devices
|
from modules import devices
|
||||||
@@ -103,7 +102,8 @@ def decode_base64_to_image(encoding):
|
|||||||
|
|
||||||
def encode_pil_to_base64(image):
|
def encode_pil_to_base64(image):
|
||||||
with io.BytesIO() as output_bytes:
|
with io.BytesIO() as output_bytes:
|
||||||
|
if isinstance(image, str):
|
||||||
|
return image
|
||||||
if opts.samples_format.lower() == 'png':
|
if opts.samples_format.lower() == 'png':
|
||||||
use_metadata = False
|
use_metadata = False
|
||||||
metadata = PngImagePlugin.PngInfo()
|
metadata = PngImagePlugin.PngInfo()
|
||||||
@@ -540,12 +540,12 @@ class Api:
|
|||||||
return {}
|
return {}
|
||||||
|
|
||||||
def unloadapi(self):
|
def unloadapi(self):
|
||||||
unload_model_weights()
|
sd_models.unload_model_weights()
|
||||||
|
|
||||||
return {}
|
return {}
|
||||||
|
|
||||||
def reloadapi(self):
|
def reloadapi(self):
|
||||||
reload_model_weights()
|
sd_models.send_model_to_device(shared.sd_model)
|
||||||
|
|
||||||
return {}
|
return {}
|
||||||
|
|
||||||
@@ -565,7 +565,7 @@ class Api:
|
|||||||
|
|
||||||
def set_config(self, req: dict[str, Any]):
|
def set_config(self, req: dict[str, Any]):
|
||||||
checkpoint_name = req.get("sd_model_checkpoint", None)
|
checkpoint_name = req.get("sd_model_checkpoint", None)
|
||||||
if checkpoint_name is not None and checkpoint_name not in checkpoint_aliases:
|
if checkpoint_name is not None and checkpoint_name not in sd_models.checkpoint_aliases:
|
||||||
raise RuntimeError(f"model {checkpoint_name!r} not found")
|
raise RuntimeError(f"model {checkpoint_name!r} not found")
|
||||||
|
|
||||||
for k, v in req.items():
|
for k, v in req.items():
|
||||||
|
|||||||
@@ -93,8 +93,8 @@ class PydanticModelGenerator:
|
|||||||
d.field: (d.field_type, Field(default=d.field_value, alias=d.field_alias, exclude=d.field_exclude)) for d in self._model_def
|
d.field: (d.field_type, Field(default=d.field_value, alias=d.field_alias, exclude=d.field_exclude)) for d in self._model_def
|
||||||
}
|
}
|
||||||
DynamicModel = create_model(self._model_name, **fields)
|
DynamicModel = create_model(self._model_name, **fields)
|
||||||
DynamicModel.__config__.allow_population_by_field_name = True
|
DynamicModel.model_config['populate_by_name'] = True
|
||||||
DynamicModel.__config__.allow_mutation = True
|
DynamicModel.model_config['frozen'] = True
|
||||||
return DynamicModel
|
return DynamicModel
|
||||||
|
|
||||||
StableDiffusionTxt2ImgProcessingAPI = PydanticModelGenerator(
|
StableDiffusionTxt2ImgProcessingAPI = PydanticModelGenerator(
|
||||||
|
|||||||
+1
-1
@@ -32,7 +32,7 @@ def dump_cache():
|
|||||||
with cache_lock:
|
with cache_lock:
|
||||||
cache_filename_tmp = cache_filename + "-"
|
cache_filename_tmp = cache_filename + "-"
|
||||||
with open(cache_filename_tmp, "w", encoding="utf8") as file:
|
with open(cache_filename_tmp, "w", encoding="utf8") as file:
|
||||||
json.dump(cache_data, file, indent=4)
|
json.dump(cache_data, file, indent=4, ensure_ascii=False)
|
||||||
|
|
||||||
os.replace(cache_filename_tmp, cache_filename)
|
os.replace(cache_filename_tmp, cache_filename)
|
||||||
|
|
||||||
|
|||||||
+1
-1
@@ -107,7 +107,7 @@ parser.add_argument("--tls-certfile", type=str, help="Partially enables TLS, req
|
|||||||
parser.add_argument("--disable-tls-verify", action="store_false", help="When passed, enables the use of self-signed certificates.", default=None)
|
parser.add_argument("--disable-tls-verify", action="store_false", help="When passed, enables the use of self-signed certificates.", default=None)
|
||||||
parser.add_argument("--server-name", type=str, help="Sets hostname of server", default=None)
|
parser.add_argument("--server-name", type=str, help="Sets hostname of server", default=None)
|
||||||
parser.add_argument("--gradio-queue", action='store_true', help="does not do anything", default=True)
|
parser.add_argument("--gradio-queue", action='store_true', help="does not do anything", default=True)
|
||||||
parser.add_argument("--no-gradio-queue", action='store_true', help="Disables gradio queue; causes the webpage to use http requests instead of websockets; was the defaul in earlier versions")
|
parser.add_argument("--no-gradio-queue", action='store_true', help="Disables gradio queue; causes the webpage to use http requests instead of websockets; was the default in earlier versions")
|
||||||
parser.add_argument("--skip-version-check", action='store_true', help="Do not check versions of torch and xformers")
|
parser.add_argument("--skip-version-check", action='store_true', help="Do not check versions of torch and xformers")
|
||||||
parser.add_argument("--no-hashing", action='store_true', help="disable sha256 hashing of checkpoints to help loading performance", default=False)
|
parser.add_argument("--no-hashing", action='store_true', help="disable sha256 hashing of checkpoints to help loading performance", default=False)
|
||||||
parser.add_argument("--no-download-sd-model", action='store_true', help="don't download SD1.5 model even if no model is found in --ckpt-dir", default=False)
|
parser.add_argument("--no-download-sd-model", action='store_true', help="don't download SD1.5 model even if no model is found in --ckpt-dir", default=False)
|
||||||
|
|||||||
+16
-2
@@ -6,6 +6,21 @@ import traceback
|
|||||||
exception_records = []
|
exception_records = []
|
||||||
|
|
||||||
|
|
||||||
|
def format_traceback(tb):
|
||||||
|
return [[f"{x.filename}, line {x.lineno}, {x.name}", x.line] for x in traceback.extract_tb(tb)]
|
||||||
|
|
||||||
|
|
||||||
|
def format_exception(e, tb):
|
||||||
|
return {"exception": str(e), "traceback": format_traceback(tb)}
|
||||||
|
|
||||||
|
|
||||||
|
def get_exceptions():
|
||||||
|
try:
|
||||||
|
return list(reversed(exception_records))
|
||||||
|
except Exception as e:
|
||||||
|
return str(e)
|
||||||
|
|
||||||
|
|
||||||
def record_exception():
|
def record_exception():
|
||||||
_, e, tb = sys.exc_info()
|
_, e, tb = sys.exc_info()
|
||||||
if e is None:
|
if e is None:
|
||||||
@@ -14,8 +29,7 @@ def record_exception():
|
|||||||
if exception_records and exception_records[-1] == e:
|
if exception_records and exception_records[-1] == e:
|
||||||
return
|
return
|
||||||
|
|
||||||
from modules import sysinfo
|
exception_records.append(format_exception(e, tb))
|
||||||
exception_records.append(sysinfo.format_exception(e, tb))
|
|
||||||
|
|
||||||
if len(exception_records) > 5:
|
if len(exception_records) > 5:
|
||||||
exception_records.pop(0)
|
exception_records.pop(0)
|
||||||
|
|||||||
+83
-11
@@ -1,11 +1,14 @@
|
|||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import configparser
|
||||||
import os
|
import os
|
||||||
import threading
|
import threading
|
||||||
|
import re
|
||||||
|
|
||||||
from modules import shared, errors, cache, scripts
|
from modules import shared, errors, cache, scripts
|
||||||
from modules.gitpython_hack import Repo
|
from modules.gitpython_hack import Repo
|
||||||
from modules.paths_internal import extensions_dir, extensions_builtin_dir, script_path # noqa: F401
|
from modules.paths_internal import extensions_dir, extensions_builtin_dir, script_path # noqa: F401
|
||||||
|
|
||||||
extensions = []
|
|
||||||
|
|
||||||
os.makedirs(extensions_dir, exist_ok=True)
|
os.makedirs(extensions_dir, exist_ok=True)
|
||||||
|
|
||||||
@@ -19,11 +22,55 @@ def active():
|
|||||||
return [x for x in extensions if x.enabled]
|
return [x for x in extensions if x.enabled]
|
||||||
|
|
||||||
|
|
||||||
|
class ExtensionMetadata:
|
||||||
|
filename = "metadata.ini"
|
||||||
|
config: configparser.ConfigParser
|
||||||
|
canonical_name: str
|
||||||
|
requires: list
|
||||||
|
|
||||||
|
def __init__(self, path, canonical_name):
|
||||||
|
self.config = configparser.ConfigParser()
|
||||||
|
|
||||||
|
filepath = os.path.join(path, self.filename)
|
||||||
|
if os.path.isfile(filepath):
|
||||||
|
try:
|
||||||
|
self.config.read(filepath)
|
||||||
|
except Exception:
|
||||||
|
errors.report(f"Error reading {self.filename} for extension {canonical_name}.", exc_info=True)
|
||||||
|
|
||||||
|
self.canonical_name = self.config.get("Extension", "Name", fallback=canonical_name)
|
||||||
|
self.canonical_name = canonical_name.lower().strip()
|
||||||
|
|
||||||
|
self.requires = self.get_script_requirements("Requires", "Extension")
|
||||||
|
|
||||||
|
def get_script_requirements(self, field, section, extra_section=None):
|
||||||
|
"""reads a list of requirements from the config; field is the name of the field in the ini file,
|
||||||
|
like Requires or Before, and section is the name of the [section] in the ini file; additionally,
|
||||||
|
reads more requirements from [extra_section] if specified."""
|
||||||
|
|
||||||
|
x = self.config.get(section, field, fallback='')
|
||||||
|
|
||||||
|
if extra_section:
|
||||||
|
x = x + ', ' + self.config.get(extra_section, field, fallback='')
|
||||||
|
|
||||||
|
return self.parse_list(x.lower())
|
||||||
|
|
||||||
|
def parse_list(self, text):
|
||||||
|
"""converts a line from config ("ext1 ext2, ext3 ") into a python list (["ext1", "ext2", "ext3"])"""
|
||||||
|
|
||||||
|
if not text:
|
||||||
|
return []
|
||||||
|
|
||||||
|
# both "," and " " are accepted as separator
|
||||||
|
return [x for x in re.split(r"[,\s]+", text.strip()) if x]
|
||||||
|
|
||||||
|
|
||||||
class Extension:
|
class Extension:
|
||||||
lock = threading.Lock()
|
lock = threading.Lock()
|
||||||
cached_fields = ['remote', 'commit_date', 'branch', 'commit_hash', 'version']
|
cached_fields = ['remote', 'commit_date', 'branch', 'commit_hash', 'version']
|
||||||
|
metadata: ExtensionMetadata
|
||||||
|
|
||||||
def __init__(self, name, path, enabled=True, is_builtin=False):
|
def __init__(self, name, path, enabled=True, is_builtin=False, metadata=None):
|
||||||
self.name = name
|
self.name = name
|
||||||
self.path = path
|
self.path = path
|
||||||
self.enabled = enabled
|
self.enabled = enabled
|
||||||
@@ -36,6 +83,8 @@ class Extension:
|
|||||||
self.branch = None
|
self.branch = None
|
||||||
self.remote = None
|
self.remote = None
|
||||||
self.have_info_from_repo = False
|
self.have_info_from_repo = False
|
||||||
|
self.metadata = metadata if metadata else ExtensionMetadata(self.path, name.lower())
|
||||||
|
self.canonical_name = metadata.canonical_name
|
||||||
|
|
||||||
def to_dict(self):
|
def to_dict(self):
|
||||||
return {x: getattr(self, x) for x in self.cached_fields}
|
return {x: getattr(self, x) for x in self.cached_fields}
|
||||||
@@ -56,6 +105,7 @@ class Extension:
|
|||||||
self.do_read_info_from_repo()
|
self.do_read_info_from_repo()
|
||||||
|
|
||||||
return self.to_dict()
|
return self.to_dict()
|
||||||
|
|
||||||
try:
|
try:
|
||||||
d = cache.cached_data_for_file('extensions-git', self.name, os.path.join(self.path, ".git"), read_from_repo)
|
d = cache.cached_data_for_file('extensions-git', self.name, os.path.join(self.path, ".git"), read_from_repo)
|
||||||
self.from_dict(d)
|
self.from_dict(d)
|
||||||
@@ -136,9 +186,6 @@ class Extension:
|
|||||||
def list_extensions():
|
def list_extensions():
|
||||||
extensions.clear()
|
extensions.clear()
|
||||||
|
|
||||||
if not os.path.isdir(extensions_dir):
|
|
||||||
return
|
|
||||||
|
|
||||||
if shared.cmd_opts.disable_all_extensions:
|
if shared.cmd_opts.disable_all_extensions:
|
||||||
print("*** \"--disable-all-extensions\" arg was used, will not load any extensions ***")
|
print("*** \"--disable-all-extensions\" arg was used, will not load any extensions ***")
|
||||||
elif shared.opts.disable_all_extensions == "all":
|
elif shared.opts.disable_all_extensions == "all":
|
||||||
@@ -148,18 +195,43 @@ def list_extensions():
|
|||||||
elif shared.opts.disable_all_extensions == "extra":
|
elif shared.opts.disable_all_extensions == "extra":
|
||||||
print("*** \"Disable all extensions\" option was set, will only load built-in extensions ***")
|
print("*** \"Disable all extensions\" option was set, will only load built-in extensions ***")
|
||||||
|
|
||||||
extension_paths = []
|
loaded_extensions = {}
|
||||||
for dirname in [extensions_dir, extensions_builtin_dir]:
|
|
||||||
|
# scan through extensions directory and load metadata
|
||||||
|
for dirname in [extensions_builtin_dir, extensions_dir]:
|
||||||
if not os.path.isdir(dirname):
|
if not os.path.isdir(dirname):
|
||||||
return
|
continue
|
||||||
|
|
||||||
for extension_dirname in sorted(os.listdir(dirname)):
|
for extension_dirname in sorted(os.listdir(dirname)):
|
||||||
path = os.path.join(dirname, extension_dirname)
|
path = os.path.join(dirname, extension_dirname)
|
||||||
if not os.path.isdir(path):
|
if not os.path.isdir(path):
|
||||||
continue
|
continue
|
||||||
|
|
||||||
extension_paths.append((extension_dirname, path, dirname == extensions_builtin_dir))
|
canonical_name = extension_dirname
|
||||||
|
metadata = ExtensionMetadata(path, canonical_name)
|
||||||
|
|
||||||
for dirname, path, is_builtin in extension_paths:
|
# check for duplicated canonical names
|
||||||
extension = Extension(name=dirname, path=path, enabled=dirname not in shared.opts.disabled_extensions, is_builtin=is_builtin)
|
already_loaded_extension = loaded_extensions.get(metadata.canonical_name)
|
||||||
|
if already_loaded_extension is not None:
|
||||||
|
errors.report(f'Duplicate canonical name "{canonical_name}" found in extensions "{extension_dirname}" and "{already_loaded_extension.name}". Former will be discarded.', exc_info=False)
|
||||||
|
continue
|
||||||
|
|
||||||
|
is_builtin = dirname == extensions_builtin_dir
|
||||||
|
extension = Extension(name=extension_dirname, path=path, enabled=extension_dirname not in shared.opts.disabled_extensions, is_builtin=is_builtin, metadata=metadata)
|
||||||
extensions.append(extension)
|
extensions.append(extension)
|
||||||
|
loaded_extensions[canonical_name] = extension
|
||||||
|
|
||||||
|
# check for requirements
|
||||||
|
for extension in extensions:
|
||||||
|
for req in extension.metadata.requires:
|
||||||
|
required_extension = loaded_extensions.get(req)
|
||||||
|
if required_extension is None:
|
||||||
|
errors.report(f'Extension "{extension.name}" requires "{req}" which is not installed.', exc_info=False)
|
||||||
|
continue
|
||||||
|
|
||||||
|
if not extension.enabled:
|
||||||
|
errors.report(f'Extension "{extension.name}" requires "{required_extension.name}" which is disabled.', exc_info=False)
|
||||||
|
continue
|
||||||
|
|
||||||
|
|
||||||
|
extensions: list[Extension] = []
|
||||||
|
|||||||
+20
-5
@@ -9,6 +9,7 @@ from modules import paths, shared, devices, modelloader, errors
|
|||||||
model_dir = "GFPGAN"
|
model_dir = "GFPGAN"
|
||||||
user_path = None
|
user_path = None
|
||||||
model_path = os.path.join(paths.models_path, model_dir)
|
model_path = os.path.join(paths.models_path, model_dir)
|
||||||
|
model_file_path = None
|
||||||
model_url = "https://github.com/TencentARC/GFPGAN/releases/download/v1.3.0/GFPGANv1.4.pth"
|
model_url = "https://github.com/TencentARC/GFPGAN/releases/download/v1.3.0/GFPGANv1.4.pth"
|
||||||
have_gfpgan = False
|
have_gfpgan = False
|
||||||
loaded_gfpgan_model = None
|
loaded_gfpgan_model = None
|
||||||
@@ -17,6 +18,7 @@ loaded_gfpgan_model = None
|
|||||||
def gfpgann():
|
def gfpgann():
|
||||||
global loaded_gfpgan_model
|
global loaded_gfpgan_model
|
||||||
global model_path
|
global model_path
|
||||||
|
global model_file_path
|
||||||
if loaded_gfpgan_model is not None:
|
if loaded_gfpgan_model is not None:
|
||||||
loaded_gfpgan_model.gfpgan.to(devices.device_gfpgan)
|
loaded_gfpgan_model.gfpgan.to(devices.device_gfpgan)
|
||||||
return loaded_gfpgan_model
|
return loaded_gfpgan_model
|
||||||
@@ -24,17 +26,24 @@ def gfpgann():
|
|||||||
if gfpgan_constructor is None:
|
if gfpgan_constructor is None:
|
||||||
return None
|
return None
|
||||||
|
|
||||||
models = modelloader.load_models(model_path, model_url, user_path, ext_filter="GFPGAN")
|
models = modelloader.load_models(model_path, model_url, user_path, ext_filter=['.pth'])
|
||||||
|
|
||||||
if len(models) == 1 and models[0].startswith("http"):
|
if len(models) == 1 and models[0].startswith("http"):
|
||||||
model_file = models[0]
|
model_file = models[0]
|
||||||
elif len(models) != 0:
|
elif len(models) != 0:
|
||||||
latest_file = max(models, key=os.path.getctime)
|
gfp_models = []
|
||||||
|
for item in models:
|
||||||
|
if 'GFPGAN' in os.path.basename(item):
|
||||||
|
gfp_models.append(item)
|
||||||
|
latest_file = max(gfp_models, key=os.path.getctime)
|
||||||
model_file = latest_file
|
model_file = latest_file
|
||||||
else:
|
else:
|
||||||
print("Unable to load gfpgan model!")
|
print("Unable to load gfpgan model!")
|
||||||
return None
|
return None
|
||||||
|
|
||||||
if hasattr(facexlib.detection.retinaface, 'device'):
|
if hasattr(facexlib.detection.retinaface, 'device'):
|
||||||
facexlib.detection.retinaface.device = devices.device_gfpgan
|
facexlib.detection.retinaface.device = devices.device_gfpgan
|
||||||
|
model_file_path = model_file
|
||||||
model = gfpgan_constructor(model_path=model_file, upscale=1, arch='clean', channel_multiplier=2, bg_upsampler=None, device=devices.device_gfpgan)
|
model = gfpgan_constructor(model_path=model_file, upscale=1, arch='clean', channel_multiplier=2, bg_upsampler=None, device=devices.device_gfpgan)
|
||||||
loaded_gfpgan_model = model
|
loaded_gfpgan_model = model
|
||||||
|
|
||||||
@@ -77,19 +86,25 @@ def setup_model(dirname):
|
|||||||
global user_path
|
global user_path
|
||||||
global have_gfpgan
|
global have_gfpgan
|
||||||
global gfpgan_constructor
|
global gfpgan_constructor
|
||||||
|
global model_file_path
|
||||||
|
|
||||||
|
facexlib_path = model_path
|
||||||
|
|
||||||
|
if dirname is not None:
|
||||||
|
facexlib_path = dirname
|
||||||
|
|
||||||
load_file_from_url_orig = gfpgan.utils.load_file_from_url
|
load_file_from_url_orig = gfpgan.utils.load_file_from_url
|
||||||
facex_load_file_from_url_orig = facexlib.detection.load_file_from_url
|
facex_load_file_from_url_orig = facexlib.detection.load_file_from_url
|
||||||
facex_load_file_from_url_orig2 = facexlib.parsing.load_file_from_url
|
facex_load_file_from_url_orig2 = facexlib.parsing.load_file_from_url
|
||||||
|
|
||||||
def my_load_file_from_url(**kwargs):
|
def my_load_file_from_url(**kwargs):
|
||||||
return load_file_from_url_orig(**dict(kwargs, model_dir=model_path))
|
return load_file_from_url_orig(**dict(kwargs, model_dir=model_file_path))
|
||||||
|
|
||||||
def facex_load_file_from_url(**kwargs):
|
def facex_load_file_from_url(**kwargs):
|
||||||
return facex_load_file_from_url_orig(**dict(kwargs, save_dir=model_path, model_dir=None))
|
return facex_load_file_from_url_orig(**dict(kwargs, save_dir=facexlib_path, model_dir=None))
|
||||||
|
|
||||||
def facex_load_file_from_url2(**kwargs):
|
def facex_load_file_from_url2(**kwargs):
|
||||||
return facex_load_file_from_url_orig2(**dict(kwargs, save_dir=model_path, model_dir=None))
|
return facex_load_file_from_url_orig2(**dict(kwargs, save_dir=facexlib_path, model_dir=None))
|
||||||
|
|
||||||
gfpgan.utils.load_file_from_url = my_load_file_from_url
|
gfpgan.utils.load_file_from_url = my_load_file_from_url
|
||||||
facexlib.detection.load_file_from_url = facex_load_file_from_url
|
facexlib.detection.load_file_from_url = facex_load_file_from_url
|
||||||
|
|||||||
@@ -1,3 +1,5 @@
|
|||||||
|
from inspect import signature
|
||||||
|
from functools import wraps
|
||||||
import gradio as gr
|
import gradio as gr
|
||||||
|
|
||||||
from modules import scripts, ui_tempdir, patches
|
from modules import scripts, ui_tempdir, patches
|
||||||
@@ -64,10 +66,77 @@ def Blocks_get_config_file(self, *args, **kwargs):
|
|||||||
return config
|
return config
|
||||||
|
|
||||||
|
|
||||||
original_IOComponent_init = patches.patch(__name__, obj=gr.components.IOComponent, field="__init__", replacement=IOComponent_init)
|
def gradio_component_compatibility_layer(component_function):
|
||||||
|
@wraps(component_function)
|
||||||
|
def patched_function(*args, **kwargs):
|
||||||
|
original_signature = signature(component_function).parameters
|
||||||
|
valid_kwargs = {k: v for k, v in kwargs.items() if k in original_signature}
|
||||||
|
result = component_function(*args, **valid_kwargs)
|
||||||
|
return result
|
||||||
|
|
||||||
|
return patched_function
|
||||||
|
|
||||||
|
|
||||||
|
sub_events = ['then', 'success']
|
||||||
|
|
||||||
|
|
||||||
|
def gradio_component_events_compatibility_layer(component_function):
|
||||||
|
@wraps(component_function)
|
||||||
|
def patched_function(*args, **kwargs):
|
||||||
|
kwargs['js'] = kwargs.get('js', kwargs.pop('_js', None))
|
||||||
|
original_signature = signature(component_function).parameters
|
||||||
|
valid_kwargs = {k: v for k, v in kwargs.items() if k in original_signature}
|
||||||
|
|
||||||
|
result = component_function(*args, **valid_kwargs)
|
||||||
|
|
||||||
|
for sub_event in sub_events:
|
||||||
|
component_event_then_function = getattr(result, sub_event, None)
|
||||||
|
if component_event_then_function:
|
||||||
|
patched_component_event_then_function = gradio_component_sub_events_compatibility_layer(component_event_then_function)
|
||||||
|
setattr(result, sub_event, patched_component_event_then_function)
|
||||||
|
# original_component_event_then_function = patches.patch(f'{__name__}.', obj=result, field='then', replacement=patched_component_event_then_function)
|
||||||
|
|
||||||
|
return result
|
||||||
|
|
||||||
|
return patched_function
|
||||||
|
|
||||||
|
|
||||||
|
def gradio_component_sub_events_compatibility_layer(component_function):
|
||||||
|
@wraps(component_function)
|
||||||
|
def patched_function(*args, **kwargs):
|
||||||
|
kwargs['js'] = kwargs.get('js', kwargs.pop('_js', None))
|
||||||
|
original_signature = signature(component_function).parameters
|
||||||
|
valid_kwargs = {k: v for k, v in kwargs.items() if k in original_signature}
|
||||||
|
result = component_function(*args, **valid_kwargs)
|
||||||
|
return result
|
||||||
|
|
||||||
|
return patched_function
|
||||||
|
|
||||||
|
|
||||||
|
for component_name in set(gr.components.__all__ + gr.layouts.__all__):
|
||||||
|
try:
|
||||||
|
component = getattr(gr, component_name)
|
||||||
|
component_init = getattr(component, '__init__')
|
||||||
|
patched_component_init = gradio_component_compatibility_layer(component_init)
|
||||||
|
original_IOComponent_init = patches.patch(f'{__name__}.{component_name}', obj=component, field="__init__", replacement=patched_component_init)
|
||||||
|
|
||||||
|
component_events = set(getattr(component, 'EVENTS'))
|
||||||
|
for component_event in component_events:
|
||||||
|
component_event_function = getattr(component, component_event)
|
||||||
|
patched_component_event_function = gradio_component_events_compatibility_layer(component_event_function)
|
||||||
|
original_component_event_function = patches.patch(f'{__name__}.{component_name}.{component_event}', obj=component, field=component_event, replacement=patched_component_event_function)
|
||||||
|
except Exception as e:
|
||||||
|
print(e)
|
||||||
|
pass
|
||||||
|
|
||||||
|
gr.Box = gr.Group
|
||||||
|
|
||||||
|
|
||||||
|
original_IOComponent_init = patches.patch(__name__, obj=gr.components.base.Component, field="__init__", replacement=IOComponent_init)
|
||||||
original_Block_get_config = patches.patch(__name__, obj=gr.blocks.Block, field="get_config", replacement=Block_get_config)
|
original_Block_get_config = patches.patch(__name__, obj=gr.blocks.Block, field="get_config", replacement=Block_get_config)
|
||||||
original_BlockContext_init = patches.patch(__name__, obj=gr.blocks.BlockContext, field="__init__", replacement=BlockContext_init)
|
original_BlockContext_init = patches.patch(__name__, obj=gr.blocks.BlockContext, field="__init__", replacement=BlockContext_init)
|
||||||
original_Blocks_get_config_file = patches.patch(__name__, obj=gr.blocks.Blocks, field="get_config_file", replacement=Blocks_get_config_file)
|
original_Blocks_get_config_file = patches.patch(__name__, obj=gr.blocks.Blocks, field="get_config_file", replacement=Blocks_get_config_file)
|
||||||
|
|
||||||
|
|
||||||
ui_tempdir.install_ui_tempdir_override()
|
ui_tempdir.install_ui_tempdir_override()
|
||||||
|
|
||||||
|
|||||||
+19
-3
@@ -44,6 +44,8 @@ def process_batch(p, input_dir, output_dir, inpaint_mask_dir, args, to_scale=Fal
|
|||||||
steps = p.steps
|
steps = p.steps
|
||||||
override_settings = p.override_settings
|
override_settings = p.override_settings
|
||||||
sd_model_checkpoint_override = get_closet_checkpoint_match(override_settings.get("sd_model_checkpoint", None))
|
sd_model_checkpoint_override = get_closet_checkpoint_match(override_settings.get("sd_model_checkpoint", None))
|
||||||
|
batch_results = None
|
||||||
|
discard_further_results = False
|
||||||
for i, image in enumerate(images):
|
for i, image in enumerate(images):
|
||||||
state.job = f"{i+1} out of {len(images)}"
|
state.job = f"{i+1} out of {len(images)}"
|
||||||
if state.skipped:
|
if state.skipped:
|
||||||
@@ -127,7 +129,21 @@ def process_batch(p, input_dir, output_dir, inpaint_mask_dir, args, to_scale=Fal
|
|||||||
|
|
||||||
if proc is None:
|
if proc is None:
|
||||||
p.override_settings.pop('save_images_replace_action', None)
|
p.override_settings.pop('save_images_replace_action', None)
|
||||||
process_images(p)
|
proc = process_images(p)
|
||||||
|
|
||||||
|
if not discard_further_results and proc:
|
||||||
|
if batch_results:
|
||||||
|
batch_results.images.extend(proc.images)
|
||||||
|
batch_results.infotexts.extend(proc.infotexts)
|
||||||
|
else:
|
||||||
|
batch_results = proc
|
||||||
|
|
||||||
|
if 0 <= shared.opts.img2img_batch_show_results_limit < len(batch_results.images):
|
||||||
|
discard_further_results = True
|
||||||
|
batch_results.images = batch_results.images[:int(shared.opts.img2img_batch_show_results_limit)]
|
||||||
|
batch_results.infotexts = batch_results.infotexts[:int(shared.opts.img2img_batch_show_results_limit)]
|
||||||
|
|
||||||
|
return batch_results
|
||||||
|
|
||||||
|
|
||||||
def img2img(id_task: str, mode: int, prompt: str, negative_prompt: str, prompt_styles, init_img, sketch, init_img_with_mask, inpaint_color_sketch, inpaint_color_sketch_orig, init_img_inpaint, init_mask_inpaint, steps: int, sampler_name: str, mask_blur: int, mask_alpha: float, inpainting_fill: int, n_iter: int, batch_size: int, cfg_scale: float, image_cfg_scale: float, denoising_strength: float, selected_scale_tab: int, height: int, width: int, scale_by: float, resize_mode: int, inpaint_full_res: bool, inpaint_full_res_padding: int, inpainting_mask_invert: int, img2img_batch_input_dir: str, img2img_batch_output_dir: str, img2img_batch_inpaint_mask_dir: str, override_settings_texts, img2img_batch_use_png_info: bool, img2img_batch_png_info_props: list, img2img_batch_png_info_dir: str, request: gr.Request, *args):
|
def img2img(id_task: str, mode: int, prompt: str, negative_prompt: str, prompt_styles, init_img, sketch, init_img_with_mask, inpaint_color_sketch, inpaint_color_sketch_orig, init_img_inpaint, init_mask_inpaint, steps: int, sampler_name: str, mask_blur: int, mask_alpha: float, inpainting_fill: int, n_iter: int, batch_size: int, cfg_scale: float, image_cfg_scale: float, denoising_strength: float, selected_scale_tab: int, height: int, width: int, scale_by: float, resize_mode: int, inpaint_full_res: bool, inpaint_full_res_padding: int, inpainting_mask_invert: int, img2img_batch_input_dir: str, img2img_batch_output_dir: str, img2img_batch_inpaint_mask_dir: str, override_settings_texts, img2img_batch_use_png_info: bool, img2img_batch_png_info_props: list, img2img_batch_png_info_dir: str, request: gr.Request, *args):
|
||||||
@@ -212,9 +228,9 @@ def img2img(id_task: str, mode: int, prompt: str, negative_prompt: str, prompt_s
|
|||||||
with closing(p):
|
with closing(p):
|
||||||
if is_batch:
|
if is_batch:
|
||||||
assert not shared.cmd_opts.hide_ui_dir_config, "Launched with --hide-ui-dir-config, batch img2img disabled"
|
assert not shared.cmd_opts.hide_ui_dir_config, "Launched with --hide-ui-dir-config, batch img2img disabled"
|
||||||
|
processed = process_batch(p, img2img_batch_input_dir, img2img_batch_output_dir, img2img_batch_inpaint_mask_dir, args, to_scale=selected_scale_tab == 1, scale_by=scale_by, use_png_info=img2img_batch_use_png_info, png_info_props=img2img_batch_png_info_props, png_info_dir=img2img_batch_png_info_dir)
|
||||||
|
|
||||||
process_batch(p, img2img_batch_input_dir, img2img_batch_output_dir, img2img_batch_inpaint_mask_dir, args, to_scale=selected_scale_tab == 1, scale_by=scale_by, use_png_info=img2img_batch_use_png_info, png_info_props=img2img_batch_png_info_props, png_info_dir=img2img_batch_png_info_dir)
|
if processed is None:
|
||||||
|
|
||||||
processed = Processed(p, [], p.seed, "")
|
processed = Processed(p, [], p.seed, "")
|
||||||
else:
|
else:
|
||||||
processed = modules.scripts.scripts_img2img.run(p, *args)
|
processed = modules.scripts.scripts_img2img.run(p, *args)
|
||||||
|
|||||||
@@ -150,9 +150,13 @@ def dumpstacks():
|
|||||||
|
|
||||||
def configure_sigint_handler():
|
def configure_sigint_handler():
|
||||||
# make the program just exit at ctrl+c without waiting for anything
|
# make the program just exit at ctrl+c without waiting for anything
|
||||||
|
|
||||||
|
from modules import shared
|
||||||
|
|
||||||
def sigint_handler(sig, frame):
|
def sigint_handler(sig, frame):
|
||||||
print(f'Interrupted with signal {sig} in {frame}')
|
print(f'Interrupted with signal {sig} in {frame}')
|
||||||
|
|
||||||
|
if shared.opts.dump_stacks_on_signal:
|
||||||
dumpstacks()
|
dumpstacks()
|
||||||
|
|
||||||
os._exit(0)
|
os._exit(0)
|
||||||
|
|||||||
@@ -441,7 +441,7 @@ def dump_sysinfo():
|
|||||||
import datetime
|
import datetime
|
||||||
|
|
||||||
text = sysinfo.get()
|
text = sysinfo.get()
|
||||||
filename = f"sysinfo-{datetime.datetime.utcnow().strftime('%Y-%m-%d-%H-%M')}.txt"
|
filename = f"sysinfo-{datetime.datetime.utcnow().strftime('%Y-%m-%d-%H-%M')}.json"
|
||||||
|
|
||||||
with open(filename, "w", encoding="utf8") as file:
|
with open(filename, "w", encoding="utf8") as file:
|
||||||
file.write(text)
|
file.write(text)
|
||||||
|
|||||||
@@ -1,16 +1,41 @@
|
|||||||
import os
|
import os
|
||||||
import logging
|
import logging
|
||||||
|
|
||||||
|
try:
|
||||||
|
from tqdm.auto import tqdm
|
||||||
|
|
||||||
|
class TqdmLoggingHandler(logging.Handler):
|
||||||
|
def __init__(self, level=logging.INFO):
|
||||||
|
super().__init__(level)
|
||||||
|
|
||||||
|
def emit(self, record):
|
||||||
|
try:
|
||||||
|
msg = self.format(record)
|
||||||
|
tqdm.write(msg)
|
||||||
|
self.flush()
|
||||||
|
except Exception:
|
||||||
|
self.handleError(record)
|
||||||
|
|
||||||
|
TQDM_IMPORTED = True
|
||||||
|
except ImportError:
|
||||||
|
# tqdm does not exist before first launch
|
||||||
|
# I will import once the UI finishes seting up the enviroment and reloads.
|
||||||
|
TQDM_IMPORTED = False
|
||||||
|
|
||||||
def setup_logging(loglevel):
|
def setup_logging(loglevel):
|
||||||
if loglevel is None:
|
if loglevel is None:
|
||||||
loglevel = os.environ.get("SD_WEBUI_LOG_LEVEL")
|
loglevel = os.environ.get("SD_WEBUI_LOG_LEVEL")
|
||||||
|
|
||||||
|
loghandlers = []
|
||||||
|
|
||||||
|
if TQDM_IMPORTED:
|
||||||
|
loghandlers.append(TqdmLoggingHandler())
|
||||||
|
|
||||||
if loglevel:
|
if loglevel:
|
||||||
log_level = getattr(logging, loglevel.upper(), None) or logging.INFO
|
log_level = getattr(logging, loglevel.upper(), None) or logging.INFO
|
||||||
logging.basicConfig(
|
logging.basicConfig(
|
||||||
level=log_level,
|
level=log_level,
|
||||||
format='%(asctime)s %(levelname)s [%(name)s] %(message)s',
|
format='%(asctime)s %(levelname)s [%(name)s] %(message)s',
|
||||||
datefmt='%Y-%m-%d %H:%M:%S',
|
datefmt='%Y-%m-%d %H:%M:%S',
|
||||||
|
handlers=loghandlers
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|||||||
+70
-9
@@ -1,5 +1,6 @@
|
|||||||
import json
|
import json
|
||||||
import sys
|
import sys
|
||||||
|
from dataclasses import dataclass
|
||||||
|
|
||||||
import gradio as gr
|
import gradio as gr
|
||||||
|
|
||||||
@@ -8,13 +9,14 @@ from modules.shared_cmd_options import cmd_opts
|
|||||||
|
|
||||||
|
|
||||||
class OptionInfo:
|
class OptionInfo:
|
||||||
def __init__(self, default=None, label="", component=None, component_args=None, onchange=None, section=None, refresh=None, comment_before='', comment_after='', infotext=None, restrict_api=False):
|
def __init__(self, default=None, label="", component=None, component_args=None, onchange=None, section=None, refresh=None, comment_before='', comment_after='', infotext=None, restrict_api=False, category_id=None):
|
||||||
self.default = default
|
self.default = default
|
||||||
self.label = label
|
self.label = label
|
||||||
self.component = component
|
self.component = component
|
||||||
self.component_args = component_args
|
self.component_args = component_args
|
||||||
self.onchange = onchange
|
self.onchange = onchange
|
||||||
self.section = section
|
self.section = section
|
||||||
|
self.category_id = category_id
|
||||||
self.refresh = refresh
|
self.refresh = refresh
|
||||||
self.do_not_save = False
|
self.do_not_save = False
|
||||||
|
|
||||||
@@ -63,7 +65,11 @@ class OptionHTML(OptionInfo):
|
|||||||
|
|
||||||
def options_section(section_identifier, options_dict):
|
def options_section(section_identifier, options_dict):
|
||||||
for v in options_dict.values():
|
for v in options_dict.values():
|
||||||
|
if len(section_identifier) == 2:
|
||||||
v.section = section_identifier
|
v.section = section_identifier
|
||||||
|
elif len(section_identifier) == 3:
|
||||||
|
v.section = section_identifier[0:2]
|
||||||
|
v.category_id = section_identifier[2]
|
||||||
|
|
||||||
return options_dict
|
return options_dict
|
||||||
|
|
||||||
@@ -76,7 +82,7 @@ class Options:
|
|||||||
|
|
||||||
def __init__(self, data_labels: dict[str, OptionInfo], restricted_opts):
|
def __init__(self, data_labels: dict[str, OptionInfo], restricted_opts):
|
||||||
self.data_labels = data_labels
|
self.data_labels = data_labels
|
||||||
self.data = {k: v.default for k, v in self.data_labels.items()}
|
self.data = {k: v.default for k, v in self.data_labels.items() if not v.do_not_save}
|
||||||
self.restricted_opts = restricted_opts
|
self.restricted_opts = restricted_opts
|
||||||
|
|
||||||
def __setattr__(self, key, value):
|
def __setattr__(self, key, value):
|
||||||
@@ -158,7 +164,7 @@ class Options:
|
|||||||
assert not cmd_opts.freeze_settings, "saving settings is disabled"
|
assert not cmd_opts.freeze_settings, "saving settings is disabled"
|
||||||
|
|
||||||
with open(filename, "w", encoding="utf8") as file:
|
with open(filename, "w", encoding="utf8") as file:
|
||||||
json.dump(self.data, file, indent=4)
|
json.dump(self.data, file, indent=4, ensure_ascii=False)
|
||||||
|
|
||||||
def same_type(self, x, y):
|
def same_type(self, x, y):
|
||||||
if x is None or y is None:
|
if x is None or y is None:
|
||||||
@@ -206,23 +212,59 @@ class Options:
|
|||||||
d = {k: self.data.get(k, v.default) for k, v in self.data_labels.items()}
|
d = {k: self.data.get(k, v.default) for k, v in self.data_labels.items()}
|
||||||
d["_comments_before"] = {k: v.comment_before for k, v in self.data_labels.items() if v.comment_before is not None}
|
d["_comments_before"] = {k: v.comment_before for k, v in self.data_labels.items() if v.comment_before is not None}
|
||||||
d["_comments_after"] = {k: v.comment_after for k, v in self.data_labels.items() if v.comment_after is not None}
|
d["_comments_after"] = {k: v.comment_after for k, v in self.data_labels.items() if v.comment_after is not None}
|
||||||
|
|
||||||
|
item_categories = {}
|
||||||
|
for item in self.data_labels.values():
|
||||||
|
category = categories.mapping.get(item.category_id)
|
||||||
|
category = "Uncategorized" if category is None else category.label
|
||||||
|
if category not in item_categories:
|
||||||
|
item_categories[category] = item.section[1]
|
||||||
|
|
||||||
|
# _categories is a list of pairs: [section, category]. Each section (a setting page) will get a special heading above it with the category as text.
|
||||||
|
d["_categories"] = [[v, k] for k, v in item_categories.items()] + [["Defaults", "Other"]]
|
||||||
|
|
||||||
return json.dumps(d)
|
return json.dumps(d)
|
||||||
|
|
||||||
def add_option(self, key, info):
|
def add_option(self, key, info):
|
||||||
self.data_labels[key] = info
|
self.data_labels[key] = info
|
||||||
if key not in self.data:
|
if key not in self.data and not info.do_not_save:
|
||||||
self.data[key] = info.default
|
self.data[key] = info.default
|
||||||
|
|
||||||
def reorder(self):
|
def reorder(self):
|
||||||
"""reorder settings so that all items related to section always go together"""
|
"""Reorder settings so that:
|
||||||
|
- all items related to section always go together
|
||||||
|
- all sections belonging to a category go together
|
||||||
|
- sections inside a category are ordered alphabetically
|
||||||
|
- categories are ordered by creation order
|
||||||
|
|
||||||
|
Category is a superset of sections: for category "postprocessing" there could be multiple sections: "face restoration", "upscaling".
|
||||||
|
|
||||||
|
This function also changes items' category_id so that all items belonging to a section have the same category_id.
|
||||||
|
"""
|
||||||
|
|
||||||
|
category_ids = {}
|
||||||
|
section_categories = {}
|
||||||
|
|
||||||
section_ids = {}
|
|
||||||
settings_items = self.data_labels.items()
|
settings_items = self.data_labels.items()
|
||||||
for _, item in settings_items:
|
for _, item in settings_items:
|
||||||
if item.section not in section_ids:
|
if item.section not in section_categories:
|
||||||
section_ids[item.section] = len(section_ids)
|
section_categories[item.section] = item.category_id
|
||||||
|
|
||||||
self.data_labels = dict(sorted(settings_items, key=lambda x: section_ids[x[1].section]))
|
for _, item in settings_items:
|
||||||
|
item.category_id = section_categories.get(item.section)
|
||||||
|
|
||||||
|
for category_id in categories.mapping:
|
||||||
|
if category_id not in category_ids:
|
||||||
|
category_ids[category_id] = len(category_ids)
|
||||||
|
|
||||||
|
def sort_key(x):
|
||||||
|
item: OptionInfo = x[1]
|
||||||
|
category_order = category_ids.get(item.category_id, len(category_ids))
|
||||||
|
section_order = item.section[1]
|
||||||
|
|
||||||
|
return category_order, section_order
|
||||||
|
|
||||||
|
self.data_labels = dict(sorted(settings_items, key=sort_key))
|
||||||
|
|
||||||
def cast_value(self, key, value):
|
def cast_value(self, key, value):
|
||||||
"""casts an arbitrary to the same type as this setting's value with key
|
"""casts an arbitrary to the same type as this setting's value with key
|
||||||
@@ -245,3 +287,22 @@ class Options:
|
|||||||
value = expected_type(value)
|
value = expected_type(value)
|
||||||
|
|
||||||
return value
|
return value
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class OptionsCategory:
|
||||||
|
id: str
|
||||||
|
label: str
|
||||||
|
|
||||||
|
class OptionsCategories:
|
||||||
|
def __init__(self):
|
||||||
|
self.mapping = {}
|
||||||
|
|
||||||
|
def register_category(self, category_id, label):
|
||||||
|
if category_id in self.mapping:
|
||||||
|
return category_id
|
||||||
|
|
||||||
|
self.mapping[category_id] = OptionsCategory(category_id, label)
|
||||||
|
|
||||||
|
|
||||||
|
categories = OptionsCategories()
|
||||||
|
|||||||
@@ -78,7 +78,7 @@ def run_postprocessing(extras_mode, image, image_folder, input_dir, output_dir,
|
|||||||
image_data.close()
|
image_data.close()
|
||||||
|
|
||||||
devices.torch_gc()
|
devices.torch_gc()
|
||||||
|
shared.state.end()
|
||||||
return outputs, ui_common.plaintext_to_html(infotext), ''
|
return outputs, ui_common.plaintext_to_html(infotext), ''
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@@ -296,7 +296,7 @@ class StableDiffusionProcessing:
|
|||||||
return conditioning
|
return conditioning
|
||||||
|
|
||||||
def edit_image_conditioning(self, source_image):
|
def edit_image_conditioning(self, source_image):
|
||||||
conditioning_image = images_tensor_to_samples(source_image*0.5+0.5, approximation_indexes.get(opts.sd_vae_encode_method))
|
conditioning_image = shared.sd_model.encode_first_stage(source_image).mode()
|
||||||
|
|
||||||
return conditioning_image
|
return conditioning_image
|
||||||
|
|
||||||
@@ -711,7 +711,7 @@ def process_images(p: StableDiffusionProcessing) -> Processed:
|
|||||||
if p.scripts is not None:
|
if p.scripts is not None:
|
||||||
p.scripts.before_process(p)
|
p.scripts.before_process(p)
|
||||||
|
|
||||||
stored_opts = {k: opts.data[k] for k in p.override_settings.keys() if k in opts.data}
|
stored_opts = {k: opts.data[k] if k in opts.data else opts.get_default(k) for k in p.override_settings.keys() if k in opts.data}
|
||||||
|
|
||||||
try:
|
try:
|
||||||
# if no checkpoint override or the override checkpoint can't be found, remove override entry and load opts checkpoint
|
# if no checkpoint override or the override checkpoint can't be found, remove override entry and load opts checkpoint
|
||||||
@@ -799,7 +799,6 @@ def process_images_inner(p: StableDiffusionProcessing) -> Processed:
|
|||||||
|
|
||||||
infotexts = []
|
infotexts = []
|
||||||
output_images = []
|
output_images = []
|
||||||
|
|
||||||
with torch.no_grad(), p.sd_model.ema_scope():
|
with torch.no_grad(), p.sd_model.ema_scope():
|
||||||
with devices.autocast():
|
with devices.autocast():
|
||||||
p.init(p.all_prompts, p.all_seeds, p.all_subseeds)
|
p.init(p.all_prompts, p.all_seeds, p.all_subseeds)
|
||||||
@@ -873,7 +872,6 @@ def process_images_inner(p: StableDiffusionProcessing) -> Processed:
|
|||||||
else:
|
else:
|
||||||
if opts.sd_vae_decode_method != 'Full':
|
if opts.sd_vae_decode_method != 'Full':
|
||||||
p.extra_generation_params['VAE Decoder'] = opts.sd_vae_decode_method
|
p.extra_generation_params['VAE Decoder'] = opts.sd_vae_decode_method
|
||||||
|
|
||||||
x_samples_ddim = decode_latent_batch(p.sd_model, samples_ddim, target_device=devices.cpu, check_for_nans=True)
|
x_samples_ddim = decode_latent_batch(p.sd_model, samples_ddim, target_device=devices.cpu, check_for_nans=True)
|
||||||
|
|
||||||
x_samples_ddim = torch.stack(x_samples_ddim).float()
|
x_samples_ddim = torch.stack(x_samples_ddim).float()
|
||||||
@@ -886,6 +884,8 @@ def process_images_inner(p: StableDiffusionProcessing) -> Processed:
|
|||||||
|
|
||||||
devices.torch_gc()
|
devices.torch_gc()
|
||||||
|
|
||||||
|
state.nextjob()
|
||||||
|
|
||||||
if p.scripts is not None:
|
if p.scripts is not None:
|
||||||
p.scripts.postprocess_batch(p, x_samples_ddim, batch_number=n)
|
p.scripts.postprocess_batch(p, x_samples_ddim, batch_number=n)
|
||||||
|
|
||||||
@@ -958,7 +958,8 @@ def process_images_inner(p: StableDiffusionProcessing) -> Processed:
|
|||||||
|
|
||||||
devices.torch_gc()
|
devices.torch_gc()
|
||||||
|
|
||||||
state.nextjob()
|
if not infotexts:
|
||||||
|
infotexts.append(Processed(p, []).infotext(p, 0))
|
||||||
|
|
||||||
p.color_corrections = None
|
p.color_corrections = None
|
||||||
|
|
||||||
@@ -1144,6 +1145,7 @@ class StableDiffusionProcessingTxt2Img(StableDiffusionProcessing):
|
|||||||
|
|
||||||
if not self.enable_hr:
|
if not self.enable_hr:
|
||||||
return samples
|
return samples
|
||||||
|
devices.torch_gc()
|
||||||
|
|
||||||
if self.latent_scale_mode is None:
|
if self.latent_scale_mode is None:
|
||||||
decoded_samples = torch.stack(decode_latent_batch(self.sd_model, samples, target_device=devices.cpu, check_for_nans=True)).to(dtype=torch.float32)
|
decoded_samples = torch.stack(decode_latent_batch(self.sd_model, samples, target_device=devices.cpu, check_for_nans=True)).to(dtype=torch.float32)
|
||||||
@@ -1153,8 +1155,6 @@ class StableDiffusionProcessingTxt2Img(StableDiffusionProcessing):
|
|||||||
with sd_models.SkipWritingToConfig():
|
with sd_models.SkipWritingToConfig():
|
||||||
sd_models.reload_model_weights(info=self.hr_checkpoint_info)
|
sd_models.reload_model_weights(info=self.hr_checkpoint_info)
|
||||||
|
|
||||||
devices.torch_gc()
|
|
||||||
|
|
||||||
return self.sample_hr_pass(samples, decoded_samples, seeds, subseeds, subseed_strength, prompts)
|
return self.sample_hr_pass(samples, decoded_samples, seeds, subseeds, subseed_strength, prompts)
|
||||||
|
|
||||||
def sample_hr_pass(self, samples, decoded_samples, seeds, subseeds, subseed_strength, prompts):
|
def sample_hr_pass(self, samples, decoded_samples, seeds, subseeds, subseed_strength, prompts):
|
||||||
@@ -1162,7 +1162,6 @@ class StableDiffusionProcessingTxt2Img(StableDiffusionProcessing):
|
|||||||
return samples
|
return samples
|
||||||
|
|
||||||
self.is_hr_pass = True
|
self.is_hr_pass = True
|
||||||
|
|
||||||
target_width = self.hr_upscale_to_x
|
target_width = self.hr_upscale_to_x
|
||||||
target_height = self.hr_upscale_to_y
|
target_height = self.hr_upscale_to_y
|
||||||
|
|
||||||
@@ -1251,7 +1250,6 @@ class StableDiffusionProcessingTxt2Img(StableDiffusionProcessing):
|
|||||||
decoded_samples = decode_latent_batch(self.sd_model, samples, target_device=devices.cpu, check_for_nans=True)
|
decoded_samples = decode_latent_batch(self.sd_model, samples, target_device=devices.cpu, check_for_nans=True)
|
||||||
|
|
||||||
self.is_hr_pass = False
|
self.is_hr_pass = False
|
||||||
|
|
||||||
return decoded_samples
|
return decoded_samples
|
||||||
|
|
||||||
def close(self):
|
def close(self):
|
||||||
|
|||||||
@@ -4,7 +4,7 @@ import re
|
|||||||
from collections import namedtuple
|
from collections import namedtuple
|
||||||
import lark
|
import lark
|
||||||
|
|
||||||
# a prompt like this: "fantasy landscape with a [mountain:lake:0.25] and [an oak:a christmas tree:0.75][ in foreground::0.6][ in background:0.25] [shoddy:masterful:0.5]"
|
# a prompt like this: "fantasy landscape with a [mountain:lake:0.25] and [an oak:a christmas tree:0.75][ in foreground::0.6][: in background:0.25] [shoddy:masterful:0.5]"
|
||||||
# will be represented with prompt_schedule like this (assuming steps=100):
|
# will be represented with prompt_schedule like this (assuming steps=100):
|
||||||
# [25, 'fantasy landscape with a mountain and an oak in foreground shoddy']
|
# [25, 'fantasy landscape with a mountain and an oak in foreground shoddy']
|
||||||
# [50, 'fantasy landscape with a lake and an oak in foreground in background shoddy']
|
# [50, 'fantasy landscape with a lake and an oak in foreground in background shoddy']
|
||||||
|
|||||||
+1
-1
@@ -110,7 +110,7 @@ class ImageRNG:
|
|||||||
self.is_first = True
|
self.is_first = True
|
||||||
|
|
||||||
def first(self):
|
def first(self):
|
||||||
noise_shape = self.shape if self.seed_resize_from_h <= 0 or self.seed_resize_from_w <= 0 else (self.shape[0], self.seed_resize_from_h // 8, self.seed_resize_from_w // 8)
|
noise_shape = self.shape if self.seed_resize_from_h <= 0 or self.seed_resize_from_w <= 0 else (self.shape[0], int(self.seed_resize_from_h) // 8, int(self.seed_resize_from_w // 8))
|
||||||
|
|
||||||
xs = []
|
xs = []
|
||||||
|
|
||||||
|
|||||||
+103
-16
@@ -311,20 +311,113 @@ scripts_data = []
|
|||||||
postprocessing_scripts_data = []
|
postprocessing_scripts_data = []
|
||||||
ScriptClassData = namedtuple("ScriptClassData", ["script_class", "path", "basedir", "module"])
|
ScriptClassData = namedtuple("ScriptClassData", ["script_class", "path", "basedir", "module"])
|
||||||
|
|
||||||
|
def topological_sort(dependencies):
|
||||||
|
"""Accepts a dictionary mapping name to its dependencies, returns a list of names ordered according to dependencies.
|
||||||
|
Ignores errors relating to missing dependeencies or circular dependencies
|
||||||
|
"""
|
||||||
|
|
||||||
|
visited = {}
|
||||||
|
result = []
|
||||||
|
|
||||||
|
def inner(name):
|
||||||
|
visited[name] = True
|
||||||
|
|
||||||
|
for dep in dependencies.get(name, []):
|
||||||
|
if dep in dependencies and dep not in visited:
|
||||||
|
inner(dep)
|
||||||
|
|
||||||
|
result.append(name)
|
||||||
|
|
||||||
|
for depname in dependencies:
|
||||||
|
if depname not in visited:
|
||||||
|
inner(depname)
|
||||||
|
|
||||||
|
return result
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class ScriptWithDependencies:
|
||||||
|
script_canonical_name: str
|
||||||
|
file: ScriptFile
|
||||||
|
requires: list
|
||||||
|
load_before: list
|
||||||
|
load_after: list
|
||||||
|
|
||||||
|
|
||||||
def list_scripts(scriptdirname, extension, *, include_extensions=True):
|
def list_scripts(scriptdirname, extension, *, include_extensions=True):
|
||||||
scripts_list = []
|
scripts = {}
|
||||||
|
|
||||||
basedir = os.path.join(paths.script_path, scriptdirname)
|
loaded_extensions = {ext.canonical_name: ext for ext in extensions.active()}
|
||||||
if os.path.exists(basedir):
|
loaded_extensions_scripts = {ext.canonical_name: [] for ext in extensions.active()}
|
||||||
for filename in sorted(os.listdir(basedir)):
|
|
||||||
scripts_list.append(ScriptFile(paths.script_path, filename, os.path.join(basedir, filename)))
|
# build script dependency map
|
||||||
|
root_script_basedir = os.path.join(paths.script_path, scriptdirname)
|
||||||
|
if os.path.exists(root_script_basedir):
|
||||||
|
for filename in sorted(os.listdir(root_script_basedir)):
|
||||||
|
if not os.path.isfile(os.path.join(root_script_basedir, filename)):
|
||||||
|
continue
|
||||||
|
|
||||||
|
if os.path.splitext(filename)[1].lower() != extension:
|
||||||
|
continue
|
||||||
|
|
||||||
|
script_file = ScriptFile(paths.script_path, filename, os.path.join(root_script_basedir, filename))
|
||||||
|
scripts[filename] = ScriptWithDependencies(filename, script_file, [], [], [])
|
||||||
|
|
||||||
if include_extensions:
|
if include_extensions:
|
||||||
for ext in extensions.active():
|
for ext in extensions.active():
|
||||||
scripts_list += ext.list_files(scriptdirname, extension)
|
extension_scripts_list = ext.list_files(scriptdirname, extension)
|
||||||
|
for extension_script in extension_scripts_list:
|
||||||
|
if not os.path.isfile(extension_script.path):
|
||||||
|
continue
|
||||||
|
|
||||||
scripts_list = [x for x in scripts_list if os.path.splitext(x.path)[1].lower() == extension and os.path.isfile(x.path)]
|
script_canonical_name = ("builtin/" if ext.is_builtin else "") + ext.canonical_name + "/" + extension_script.filename
|
||||||
|
relative_path = scriptdirname + "/" + extension_script.filename
|
||||||
|
|
||||||
|
script = ScriptWithDependencies(
|
||||||
|
script_canonical_name=script_canonical_name,
|
||||||
|
file=extension_script,
|
||||||
|
requires=ext.metadata.get_script_requirements("Requires", relative_path, scriptdirname),
|
||||||
|
load_before=ext.metadata.get_script_requirements("Before", relative_path, scriptdirname),
|
||||||
|
load_after=ext.metadata.get_script_requirements("After", relative_path, scriptdirname),
|
||||||
|
)
|
||||||
|
|
||||||
|
scripts[script_canonical_name] = script
|
||||||
|
loaded_extensions_scripts[ext.canonical_name].append(script)
|
||||||
|
|
||||||
|
for script_canonical_name, script in scripts.items():
|
||||||
|
# load before requires inverse dependency
|
||||||
|
# in this case, append the script name into the load_after list of the specified script
|
||||||
|
for load_before in script.load_before:
|
||||||
|
# if this requires an individual script to be loaded before
|
||||||
|
other_script = scripts.get(load_before)
|
||||||
|
if other_script:
|
||||||
|
other_script.load_after.append(script_canonical_name)
|
||||||
|
|
||||||
|
# if this requires an extension
|
||||||
|
other_extension_scripts = loaded_extensions_scripts.get(load_before)
|
||||||
|
if other_extension_scripts:
|
||||||
|
for other_script in other_extension_scripts:
|
||||||
|
other_script.load_after.append(script_canonical_name)
|
||||||
|
|
||||||
|
# if After mentions an extension, remove it and instead add all of its scripts
|
||||||
|
for load_after in list(script.load_after):
|
||||||
|
if load_after not in scripts and load_after in loaded_extensions_scripts:
|
||||||
|
script.load_after.remove(load_after)
|
||||||
|
|
||||||
|
for other_script in loaded_extensions_scripts.get(load_after, []):
|
||||||
|
script.load_after.append(other_script.script_canonical_name)
|
||||||
|
|
||||||
|
dependencies = {}
|
||||||
|
|
||||||
|
for script_canonical_name, script in scripts.items():
|
||||||
|
for required_script in script.requires:
|
||||||
|
if required_script not in scripts and required_script not in loaded_extensions:
|
||||||
|
errors.report(f'Script "{script_canonical_name}" requires "{required_script}" to be loaded, but it is not.', exc_info=False)
|
||||||
|
|
||||||
|
dependencies[script_canonical_name] = script.load_after
|
||||||
|
|
||||||
|
ordered_scripts = topological_sort(dependencies)
|
||||||
|
scripts_list = [scripts[script_canonical_name].file for script_canonical_name in ordered_scripts]
|
||||||
|
|
||||||
return scripts_list
|
return scripts_list
|
||||||
|
|
||||||
@@ -365,15 +458,9 @@ def load_scripts():
|
|||||||
elif issubclass(script_class, scripts_postprocessing.ScriptPostprocessing):
|
elif issubclass(script_class, scripts_postprocessing.ScriptPostprocessing):
|
||||||
postprocessing_scripts_data.append(ScriptClassData(script_class, scriptfile.path, scriptfile.basedir, module))
|
postprocessing_scripts_data.append(ScriptClassData(script_class, scriptfile.path, scriptfile.basedir, module))
|
||||||
|
|
||||||
def orderby(basedir):
|
# here the scripts_list is already ordered
|
||||||
# 1st webui, 2nd extensions-builtin, 3rd extensions
|
# processing_script is not considered though
|
||||||
priority = {os.path.join(paths.script_path, "extensions-builtin"):1, paths.script_path:0}
|
for scriptfile in scripts_list:
|
||||||
for key in priority:
|
|
||||||
if basedir.startswith(key):
|
|
||||||
return priority[key]
|
|
||||||
return 9999
|
|
||||||
|
|
||||||
for scriptfile in sorted(scripts_list, key=lambda x: [orderby(x.basedir), x]):
|
|
||||||
try:
|
try:
|
||||||
if scriptfile.basedir != paths.script_path:
|
if scriptfile.basedir != paths.script_path:
|
||||||
sys.path = [scriptfile.basedir] + sys.path
|
sys.path = [scriptfile.basedir] + sys.path
|
||||||
|
|||||||
+20
-2
@@ -5,7 +5,7 @@ from types import MethodType
|
|||||||
from modules import devices, sd_hijack_optimizations, shared, script_callbacks, errors, sd_unet, patches
|
from modules import devices, sd_hijack_optimizations, shared, script_callbacks, errors, sd_unet, patches
|
||||||
from modules.hypernetworks import hypernetwork
|
from modules.hypernetworks import hypernetwork
|
||||||
from modules.shared import cmd_opts
|
from modules.shared import cmd_opts
|
||||||
from modules import sd_hijack_clip, sd_hijack_open_clip, sd_hijack_unet, sd_hijack_xlmr, xlmr
|
from modules import sd_hijack_clip, sd_hijack_open_clip, sd_hijack_unet, sd_hijack_xlmr, xlmr, xlmr_m18
|
||||||
|
|
||||||
import ldm.modules.attention
|
import ldm.modules.attention
|
||||||
import ldm.modules.diffusionmodules.model
|
import ldm.modules.diffusionmodules.model
|
||||||
@@ -184,6 +184,20 @@ class StableDiffusionModelHijack:
|
|||||||
errors.display(e, "applying cross attention optimization")
|
errors.display(e, "applying cross attention optimization")
|
||||||
undo_optimizations()
|
undo_optimizations()
|
||||||
|
|
||||||
|
def convert_sdxl_to_ssd(self, m):
|
||||||
|
"""Converts an SDXL model to a Segmind Stable Diffusion model (see https://huggingface.co/segmind/SSD-1B)"""
|
||||||
|
|
||||||
|
delattr(m.model.diffusion_model.middle_block, '1')
|
||||||
|
delattr(m.model.diffusion_model.middle_block, '2')
|
||||||
|
for i in ['9', '8', '7', '6', '5', '4']:
|
||||||
|
delattr(m.model.diffusion_model.input_blocks[7][1].transformer_blocks, i)
|
||||||
|
delattr(m.model.diffusion_model.input_blocks[8][1].transformer_blocks, i)
|
||||||
|
delattr(m.model.diffusion_model.output_blocks[0][1].transformer_blocks, i)
|
||||||
|
delattr(m.model.diffusion_model.output_blocks[1][1].transformer_blocks, i)
|
||||||
|
delattr(m.model.diffusion_model.output_blocks[4][1].transformer_blocks, '1')
|
||||||
|
delattr(m.model.diffusion_model.output_blocks[5][1].transformer_blocks, '1')
|
||||||
|
devices.torch_gc()
|
||||||
|
|
||||||
def hijack(self, m):
|
def hijack(self, m):
|
||||||
conditioner = getattr(m, 'conditioner', None)
|
conditioner = getattr(m, 'conditioner', None)
|
||||||
if conditioner:
|
if conditioner:
|
||||||
@@ -211,7 +225,7 @@ class StableDiffusionModelHijack:
|
|||||||
else:
|
else:
|
||||||
m.cond_stage_model = conditioner
|
m.cond_stage_model = conditioner
|
||||||
|
|
||||||
if type(m.cond_stage_model) == xlmr.BertSeriesModelWithTransformation:
|
if type(m.cond_stage_model) == xlmr.BertSeriesModelWithTransformation or type(m.cond_stage_model) == xlmr_m18.BertSeriesModelWithTransformation:
|
||||||
model_embeddings = m.cond_stage_model.roberta.embeddings
|
model_embeddings = m.cond_stage_model.roberta.embeddings
|
||||||
model_embeddings.token_embedding = EmbeddingsWithFixes(model_embeddings.word_embeddings, self)
|
model_embeddings.token_embedding = EmbeddingsWithFixes(model_embeddings.word_embeddings, self)
|
||||||
m.cond_stage_model = sd_hijack_xlmr.FrozenXLMREmbedderWithCustomWords(m.cond_stage_model, self)
|
m.cond_stage_model = sd_hijack_xlmr.FrozenXLMREmbedderWithCustomWords(m.cond_stage_model, self)
|
||||||
@@ -242,8 +256,12 @@ class StableDiffusionModelHijack:
|
|||||||
|
|
||||||
self.layers = flatten(m)
|
self.layers = flatten(m)
|
||||||
|
|
||||||
|
import modules.models.diffusion.ddpm_edit
|
||||||
|
|
||||||
if isinstance(m, ldm.models.diffusion.ddpm.LatentDiffusion):
|
if isinstance(m, ldm.models.diffusion.ddpm.LatentDiffusion):
|
||||||
sd_unet.original_forward = ldm_original_forward
|
sd_unet.original_forward = ldm_original_forward
|
||||||
|
elif isinstance(m, modules.models.diffusion.ddpm_edit.LatentDiffusion):
|
||||||
|
sd_unet.original_forward = ldm_original_forward
|
||||||
elif isinstance(m, sgm.models.diffusion.DiffusionEngine):
|
elif isinstance(m, sgm.models.diffusion.DiffusionEngine):
|
||||||
sd_unet.original_forward = sgm_original_forward
|
sd_unet.original_forward = sgm_original_forward
|
||||||
else:
|
else:
|
||||||
|
|||||||
+8
-16
@@ -1,7 +1,6 @@
|
|||||||
import collections
|
import collections
|
||||||
import os.path
|
import os.path
|
||||||
import sys
|
import sys
|
||||||
import gc
|
|
||||||
import threading
|
import threading
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
@@ -353,16 +352,19 @@ def load_model_weights(model, checkpoint_info: CheckpointInfo, state_dict, timer
|
|||||||
model.is_sdxl = hasattr(model, 'conditioner')
|
model.is_sdxl = hasattr(model, 'conditioner')
|
||||||
model.is_sd2 = not model.is_sdxl and hasattr(model.cond_stage_model, 'model')
|
model.is_sd2 = not model.is_sdxl and hasattr(model.cond_stage_model, 'model')
|
||||||
model.is_sd1 = not model.is_sdxl and not model.is_sd2
|
model.is_sd1 = not model.is_sdxl and not model.is_sd2
|
||||||
|
model.is_ssd = model.is_sdxl and 'model.diffusion_model.middle_block.1.transformer_blocks.0.attn1.to_q.weight' not in state_dict.keys()
|
||||||
if model.is_sdxl:
|
if model.is_sdxl:
|
||||||
sd_models_xl.extend_sdxl(model)
|
sd_models_xl.extend_sdxl(model)
|
||||||
|
|
||||||
model.load_state_dict(state_dict, strict=False)
|
if model.is_ssd:
|
||||||
timer.record("apply weights to model")
|
sd_hijack.model_hijack.convert_sdxl_to_ssd(model)
|
||||||
|
|
||||||
if shared.opts.sd_checkpoint_cache > 0:
|
if shared.opts.sd_checkpoint_cache > 0:
|
||||||
# cache newly loaded model
|
# cache newly loaded model
|
||||||
checkpoints_loaded[checkpoint_info] = state_dict
|
checkpoints_loaded[checkpoint_info] = state_dict.copy()
|
||||||
|
|
||||||
|
model.load_state_dict(state_dict, strict=False)
|
||||||
|
timer.record("apply weights to model")
|
||||||
|
|
||||||
del state_dict
|
del state_dict
|
||||||
|
|
||||||
@@ -798,17 +800,7 @@ def reload_model_weights(sd_model=None, info=None):
|
|||||||
|
|
||||||
|
|
||||||
def unload_model_weights(sd_model=None, info=None):
|
def unload_model_weights(sd_model=None, info=None):
|
||||||
timer = Timer()
|
send_model_to_cpu(sd_model or shared.sd_model)
|
||||||
|
|
||||||
if model_data.sd_model:
|
|
||||||
model_data.sd_model.to(devices.cpu)
|
|
||||||
sd_hijack.model_hijack.undo_hijack(model_data.sd_model)
|
|
||||||
model_data.sd_model = None
|
|
||||||
sd_model = None
|
|
||||||
gc.collect()
|
|
||||||
devices.torch_gc()
|
|
||||||
|
|
||||||
print(f"Unloaded weights {timer.summary()}.")
|
|
||||||
|
|
||||||
return sd_model
|
return sd_model
|
||||||
|
|
||||||
|
|||||||
@@ -21,7 +21,7 @@ config_unopenclip = os.path.join(sd_repo_configs_path, "v2-1-stable-unclip-h-inf
|
|||||||
config_inpainting = os.path.join(sd_configs_path, "v1-inpainting-inference.yaml")
|
config_inpainting = os.path.join(sd_configs_path, "v1-inpainting-inference.yaml")
|
||||||
config_instruct_pix2pix = os.path.join(sd_configs_path, "instruct-pix2pix.yaml")
|
config_instruct_pix2pix = os.path.join(sd_configs_path, "instruct-pix2pix.yaml")
|
||||||
config_alt_diffusion = os.path.join(sd_configs_path, "alt-diffusion-inference.yaml")
|
config_alt_diffusion = os.path.join(sd_configs_path, "alt-diffusion-inference.yaml")
|
||||||
|
config_alt_diffusion_m18 = os.path.join(sd_configs_path, "alt-diffusion-m18-inference.yaml")
|
||||||
|
|
||||||
def is_using_v_parameterization_for_sd2(state_dict):
|
def is_using_v_parameterization_for_sd2(state_dict):
|
||||||
"""
|
"""
|
||||||
@@ -95,7 +95,10 @@ def guess_model_config_from_state_dict(sd, filename):
|
|||||||
if diffusion_model_input.shape[1] == 8:
|
if diffusion_model_input.shape[1] == 8:
|
||||||
return config_instruct_pix2pix
|
return config_instruct_pix2pix
|
||||||
|
|
||||||
|
|
||||||
if sd.get('cond_stage_model.roberta.embeddings.word_embeddings.weight', None) is not None:
|
if sd.get('cond_stage_model.roberta.embeddings.word_embeddings.weight', None) is not None:
|
||||||
|
if sd.get('cond_stage_model.transformation.weight').size()[0] == 1024:
|
||||||
|
return config_alt_diffusion_m18
|
||||||
return config_alt_diffusion
|
return config_alt_diffusion
|
||||||
|
|
||||||
return config_default
|
return config_default
|
||||||
|
|||||||
@@ -22,7 +22,10 @@ class WebuiSdModel(LatentDiffusion):
|
|||||||
"""structure with additional information about the file with model's weights"""
|
"""structure with additional information about the file with model's weights"""
|
||||||
|
|
||||||
is_sdxl: bool
|
is_sdxl: bool
|
||||||
"""True if the model's architecture is SDXL"""
|
"""True if the model's architecture is SDXL or SSD"""
|
||||||
|
|
||||||
|
is_ssd: bool
|
||||||
|
"""True if the model is SSD"""
|
||||||
|
|
||||||
is_sd2: bool
|
is_sd2: bool
|
||||||
"""True if the model's architecture is SD 2.x"""
|
"""True if the model's architecture is SD 2.x"""
|
||||||
|
|||||||
@@ -60,7 +60,7 @@ def restart_sampler(model, x, sigmas, extra_args=None, callback=None, disable=No
|
|||||||
sigma_restart = get_sigmas_karras(restart_steps, sigmas[min_idx].item(), sigmas[max_idx].item(), device=sigmas.device)[:-1]
|
sigma_restart = get_sigmas_karras(restart_steps, sigmas[min_idx].item(), sigmas[max_idx].item(), device=sigmas.device)[:-1]
|
||||||
while restart_times > 0:
|
while restart_times > 0:
|
||||||
restart_times -= 1
|
restart_times -= 1
|
||||||
step_list.extend([(old_sigma, new_sigma) for (old_sigma, new_sigma) in zip(sigma_restart[:-1], sigma_restart[1:])])
|
step_list.extend(zip(sigma_restart[:-1], sigma_restart[1:]))
|
||||||
|
|
||||||
last_sigma = None
|
last_sigma = None
|
||||||
for old_sigma, new_sigma in tqdm.tqdm(step_list, disable=disable):
|
for old_sigma, new_sigma in tqdm.tqdm(step_list, disable=disable):
|
||||||
|
|||||||
@@ -67,6 +67,8 @@ def reload_hypernetworks():
|
|||||||
|
|
||||||
|
|
||||||
ui_reorder_categories_builtin_items = [
|
ui_reorder_categories_builtin_items = [
|
||||||
|
"prompt",
|
||||||
|
"image",
|
||||||
"inpaint",
|
"inpaint",
|
||||||
"sampler",
|
"sampler",
|
||||||
"accordions",
|
"accordions",
|
||||||
|
|||||||
+38
-22
@@ -3,7 +3,7 @@ import gradio as gr
|
|||||||
from modules import localization, ui_components, shared_items, shared, interrogate, shared_gradio_themes
|
from modules import localization, ui_components, shared_items, shared, interrogate, shared_gradio_themes
|
||||||
from modules.paths_internal import models_path, script_path, data_path, sd_configs_path, sd_default_config, sd_model_file, default_sd_model_file, extensions_dir, extensions_builtin_dir # noqa: F401
|
from modules.paths_internal import models_path, script_path, data_path, sd_configs_path, sd_default_config, sd_model_file, default_sd_model_file, extensions_dir, extensions_builtin_dir # noqa: F401
|
||||||
from modules.shared_cmd_options import cmd_opts
|
from modules.shared_cmd_options import cmd_opts
|
||||||
from modules.options import options_section, OptionInfo, OptionHTML
|
from modules.options import options_section, OptionInfo, OptionHTML, categories
|
||||||
|
|
||||||
options_templates = {}
|
options_templates = {}
|
||||||
hide_dirs = shared.hide_dirs
|
hide_dirs = shared.hide_dirs
|
||||||
@@ -21,7 +21,14 @@ restricted_opts = {
|
|||||||
"outdir_init_images"
|
"outdir_init_images"
|
||||||
}
|
}
|
||||||
|
|
||||||
options_templates.update(options_section(('saving-images', "Saving images/grids"), {
|
categories.register_category("saving", "Saving images")
|
||||||
|
categories.register_category("sd", "Stable Diffusion")
|
||||||
|
categories.register_category("ui", "User Interface")
|
||||||
|
categories.register_category("system", "System")
|
||||||
|
categories.register_category("postprocessing", "Postprocessing")
|
||||||
|
categories.register_category("training", "Training")
|
||||||
|
|
||||||
|
options_templates.update(options_section(('saving-images', "Saving images/grids", "saving"), {
|
||||||
"samples_save": OptionInfo(True, "Always save all generated images"),
|
"samples_save": OptionInfo(True, "Always save all generated images"),
|
||||||
"samples_format": OptionInfo('png', 'File format for images'),
|
"samples_format": OptionInfo('png', 'File format for images'),
|
||||||
"samples_filename_pattern": OptionInfo("", "Images filename pattern", component_args=hide_dirs).link("wiki", "https://github.com/AUTOMATIC1111/stable-diffusion-webui/wiki/Custom-Images-Filename-Name-and-Subdirectory"),
|
"samples_filename_pattern": OptionInfo("", "Images filename pattern", component_args=hide_dirs).link("wiki", "https://github.com/AUTOMATIC1111/stable-diffusion-webui/wiki/Custom-Images-Filename-Name-and-Subdirectory"),
|
||||||
@@ -62,9 +69,12 @@ options_templates.update(options_section(('saving-images', "Saving images/grids"
|
|||||||
"clean_temp_dir_at_start": OptionInfo(False, "Cleanup non-default temporary directory when starting webui"),
|
"clean_temp_dir_at_start": OptionInfo(False, "Cleanup non-default temporary directory when starting webui"),
|
||||||
|
|
||||||
"save_incomplete_images": OptionInfo(False, "Save incomplete images").info("save images that has been interrupted in mid-generation; even if not saved, they will still show up in webui output."),
|
"save_incomplete_images": OptionInfo(False, "Save incomplete images").info("save images that has been interrupted in mid-generation; even if not saved, they will still show up in webui output."),
|
||||||
|
|
||||||
|
"notification_audio": OptionInfo(True, "Play notification sound after image generation").info("notification.mp3 should be present in the root directory").needs_reload_ui(),
|
||||||
|
"notification_volume": OptionInfo(100, "Notification sound volume", gr.Slider, {"minimum": 0, "maximum": 100, "step": 1}).info("in %"),
|
||||||
}))
|
}))
|
||||||
|
|
||||||
options_templates.update(options_section(('saving-paths', "Paths for saving"), {
|
options_templates.update(options_section(('saving-paths', "Paths for saving", "saving"), {
|
||||||
"outdir_samples": OptionInfo("", "Output directory for images; if empty, defaults to three directories below", component_args=hide_dirs),
|
"outdir_samples": OptionInfo("", "Output directory for images; if empty, defaults to three directories below", component_args=hide_dirs),
|
||||||
"outdir_txt2img_samples": OptionInfo("outputs/txt2img-images", 'Output directory for txt2img images', component_args=hide_dirs),
|
"outdir_txt2img_samples": OptionInfo("outputs/txt2img-images", 'Output directory for txt2img images', component_args=hide_dirs),
|
||||||
"outdir_img2img_samples": OptionInfo("outputs/img2img-images", 'Output directory for img2img images', component_args=hide_dirs),
|
"outdir_img2img_samples": OptionInfo("outputs/img2img-images", 'Output directory for img2img images', component_args=hide_dirs),
|
||||||
@@ -76,7 +86,7 @@ options_templates.update(options_section(('saving-paths', "Paths for saving"), {
|
|||||||
"outdir_init_images": OptionInfo("outputs/init-images", "Directory for saving init images when using img2img", component_args=hide_dirs),
|
"outdir_init_images": OptionInfo("outputs/init-images", "Directory for saving init images when using img2img", component_args=hide_dirs),
|
||||||
}))
|
}))
|
||||||
|
|
||||||
options_templates.update(options_section(('saving-to-dirs', "Saving to a directory"), {
|
options_templates.update(options_section(('saving-to-dirs', "Saving to a directory", "saving"), {
|
||||||
"save_to_dirs": OptionInfo(True, "Save images to a subdirectory"),
|
"save_to_dirs": OptionInfo(True, "Save images to a subdirectory"),
|
||||||
"grid_save_to_dirs": OptionInfo(True, "Save grids to a subdirectory"),
|
"grid_save_to_dirs": OptionInfo(True, "Save grids to a subdirectory"),
|
||||||
"use_save_to_dirs_for_ui": OptionInfo(False, "When using \"Save\" button, save images to a subdirectory"),
|
"use_save_to_dirs_for_ui": OptionInfo(False, "When using \"Save\" button, save images to a subdirectory"),
|
||||||
@@ -84,21 +94,21 @@ options_templates.update(options_section(('saving-to-dirs', "Saving to a directo
|
|||||||
"directories_max_prompt_words": OptionInfo(8, "Max prompt words for [prompt_words] pattern", gr.Slider, {"minimum": 1, "maximum": 20, "step": 1, **hide_dirs}),
|
"directories_max_prompt_words": OptionInfo(8, "Max prompt words for [prompt_words] pattern", gr.Slider, {"minimum": 1, "maximum": 20, "step": 1, **hide_dirs}),
|
||||||
}))
|
}))
|
||||||
|
|
||||||
options_templates.update(options_section(('upscaling', "Upscaling"), {
|
options_templates.update(options_section(('upscaling', "Upscaling", "postprocessing"), {
|
||||||
"ESRGAN_tile": OptionInfo(192, "Tile size for ESRGAN upscalers.", gr.Slider, {"minimum": 0, "maximum": 512, "step": 16}).info("0 = no tiling"),
|
"ESRGAN_tile": OptionInfo(192, "Tile size for ESRGAN upscalers.", gr.Slider, {"minimum": 0, "maximum": 512, "step": 16}).info("0 = no tiling"),
|
||||||
"ESRGAN_tile_overlap": OptionInfo(8, "Tile overlap for ESRGAN upscalers.", gr.Slider, {"minimum": 0, "maximum": 48, "step": 1}).info("Low values = visible seam"),
|
"ESRGAN_tile_overlap": OptionInfo(8, "Tile overlap for ESRGAN upscalers.", gr.Slider, {"minimum": 0, "maximum": 48, "step": 1}).info("Low values = visible seam"),
|
||||||
"realesrgan_enabled_models": OptionInfo(["R-ESRGAN 4x+", "R-ESRGAN 4x+ Anime6B"], "Select which Real-ESRGAN models to show in the web UI.", gr.CheckboxGroup, lambda: {"choices": shared_items.realesrgan_models_names()}),
|
"realesrgan_enabled_models": OptionInfo(["R-ESRGAN 4x+", "R-ESRGAN 4x+ Anime6B"], "Select which Real-ESRGAN models to show in the web UI.", gr.CheckboxGroup, lambda: {"choices": shared_items.realesrgan_models_names()}),
|
||||||
"upscaler_for_img2img": OptionInfo(None, "Upscaler for img2img", gr.Dropdown, lambda: {"choices": [x.name for x in shared.sd_upscalers]}),
|
"upscaler_for_img2img": OptionInfo(None, "Upscaler for img2img", gr.Dropdown, lambda: {"choices": [x.name for x in shared.sd_upscalers]}),
|
||||||
}))
|
}))
|
||||||
|
|
||||||
options_templates.update(options_section(('face-restoration', "Face restoration"), {
|
options_templates.update(options_section(('face-restoration', "Face restoration", "postprocessing"), {
|
||||||
"face_restoration": OptionInfo(False, "Restore faces", infotext='Face restoration').info("will use a third-party model on generation result to reconstruct faces"),
|
"face_restoration": OptionInfo(False, "Restore faces", infotext='Face restoration').info("will use a third-party model on generation result to reconstruct faces"),
|
||||||
"face_restoration_model": OptionInfo("CodeFormer", "Face restoration model", gr.Radio, lambda: {"choices": [x.name() for x in shared.face_restorers]}),
|
"face_restoration_model": OptionInfo("CodeFormer", "Face restoration model", gr.Radio, lambda: {"choices": [x.name() for x in shared.face_restorers]}),
|
||||||
"code_former_weight": OptionInfo(0.5, "CodeFormer weight", gr.Slider, {"minimum": 0, "maximum": 1, "step": 0.01}).info("0 = maximum effect; 1 = minimum effect"),
|
"code_former_weight": OptionInfo(0.5, "CodeFormer weight", gr.Slider, {"minimum": 0, "maximum": 1, "step": 0.01}).info("0 = maximum effect; 1 = minimum effect"),
|
||||||
"face_restoration_unload": OptionInfo(False, "Move face restoration model from VRAM into RAM after processing"),
|
"face_restoration_unload": OptionInfo(False, "Move face restoration model from VRAM into RAM after processing"),
|
||||||
}))
|
}))
|
||||||
|
|
||||||
options_templates.update(options_section(('system', "System"), {
|
options_templates.update(options_section(('system', "System", "system"), {
|
||||||
"auto_launch_browser": OptionInfo("Local", "Automatically open webui in browser on startup", gr.Radio, lambda: {"choices": ["Disable", "Local", "Remote"]}),
|
"auto_launch_browser": OptionInfo("Local", "Automatically open webui in browser on startup", gr.Radio, lambda: {"choices": ["Disable", "Local", "Remote"]}),
|
||||||
"enable_console_prompts": OptionInfo(shared.cmd_opts.enable_console_prompts, "Print prompts to console when generating with txt2img and img2img."),
|
"enable_console_prompts": OptionInfo(shared.cmd_opts.enable_console_prompts, "Print prompts to console when generating with txt2img and img2img."),
|
||||||
"show_warnings": OptionInfo(False, "Show warnings in console.").needs_reload_ui(),
|
"show_warnings": OptionInfo(False, "Show warnings in console.").needs_reload_ui(),
|
||||||
@@ -110,15 +120,16 @@ options_templates.update(options_section(('system', "System"), {
|
|||||||
"list_hidden_files": OptionInfo(True, "Load models/files in hidden directories").info("directory is hidden if its name starts with \".\""),
|
"list_hidden_files": OptionInfo(True, "Load models/files in hidden directories").info("directory is hidden if its name starts with \".\""),
|
||||||
"disable_mmap_load_safetensors": OptionInfo(False, "Disable memmapping for loading .safetensors files.").info("fixes very slow loading speed in some cases"),
|
"disable_mmap_load_safetensors": OptionInfo(False, "Disable memmapping for loading .safetensors files.").info("fixes very slow loading speed in some cases"),
|
||||||
"hide_ldm_prints": OptionInfo(True, "Prevent Stability-AI's ldm/sgm modules from printing noise to console."),
|
"hide_ldm_prints": OptionInfo(True, "Prevent Stability-AI's ldm/sgm modules from printing noise to console."),
|
||||||
|
"dump_stacks_on_signal": OptionInfo(False, "Print stack traces before exiting the program with ctrl+c."),
|
||||||
}))
|
}))
|
||||||
|
|
||||||
options_templates.update(options_section(('API', "API"), {
|
options_templates.update(options_section(('API', "API", "system"), {
|
||||||
"api_enable_requests": OptionInfo(True, "Allow http:// and https:// URLs for input images in API", restrict_api=True),
|
"api_enable_requests": OptionInfo(True, "Allow http:// and https:// URLs for input images in API", restrict_api=True),
|
||||||
"api_forbid_local_requests": OptionInfo(True, "Forbid URLs to local resources", restrict_api=True),
|
"api_forbid_local_requests": OptionInfo(True, "Forbid URLs to local resources", restrict_api=True),
|
||||||
"api_useragent": OptionInfo("", "User agent for requests", restrict_api=True),
|
"api_useragent": OptionInfo("", "User agent for requests", restrict_api=True),
|
||||||
}))
|
}))
|
||||||
|
|
||||||
options_templates.update(options_section(('training', "Training"), {
|
options_templates.update(options_section(('training', "Training", "training"), {
|
||||||
"unload_models_when_training": OptionInfo(False, "Move VAE and CLIP to RAM when training if possible. Saves VRAM."),
|
"unload_models_when_training": OptionInfo(False, "Move VAE and CLIP to RAM when training if possible. Saves VRAM."),
|
||||||
"pin_memory": OptionInfo(False, "Turn on pin_memory for DataLoader. Makes training slightly faster but can increase memory usage."),
|
"pin_memory": OptionInfo(False, "Turn on pin_memory for DataLoader. Makes training slightly faster but can increase memory usage."),
|
||||||
"save_optimizer_state": OptionInfo(False, "Saves Optimizer state as separate *.optim file. Training of embedding or HN can be resumed with the matching optim file."),
|
"save_optimizer_state": OptionInfo(False, "Saves Optimizer state as separate *.optim file. Training of embedding or HN can be resumed with the matching optim file."),
|
||||||
@@ -133,7 +144,7 @@ options_templates.update(options_section(('training', "Training"), {
|
|||||||
"training_tensorboard_flush_every": OptionInfo(120, "How often, in seconds, to flush the pending tensorboard events and summaries to disk."),
|
"training_tensorboard_flush_every": OptionInfo(120, "How often, in seconds, to flush the pending tensorboard events and summaries to disk."),
|
||||||
}))
|
}))
|
||||||
|
|
||||||
options_templates.update(options_section(('sd', "Stable Diffusion"), {
|
options_templates.update(options_section(('sd', "Stable Diffusion", "sd"), {
|
||||||
"sd_model_checkpoint": OptionInfo(None, "Stable Diffusion checkpoint", gr.Dropdown, lambda: {"choices": shared_items.list_checkpoint_tiles(shared.opts.sd_checkpoint_dropdown_use_short)}, refresh=shared_items.refresh_checkpoints, infotext='Model hash'),
|
"sd_model_checkpoint": OptionInfo(None, "Stable Diffusion checkpoint", gr.Dropdown, lambda: {"choices": shared_items.list_checkpoint_tiles(shared.opts.sd_checkpoint_dropdown_use_short)}, refresh=shared_items.refresh_checkpoints, infotext='Model hash'),
|
||||||
"sd_checkpoints_limit": OptionInfo(1, "Maximum number of checkpoints loaded at the same time", gr.Slider, {"minimum": 1, "maximum": 10, "step": 1}),
|
"sd_checkpoints_limit": OptionInfo(1, "Maximum number of checkpoints loaded at the same time", gr.Slider, {"minimum": 1, "maximum": 10, "step": 1}),
|
||||||
"sd_checkpoints_keep_in_cpu": OptionInfo(True, "Only keep one model on device").info("will keep models other than the currently used one in RAM rather than VRAM"),
|
"sd_checkpoints_keep_in_cpu": OptionInfo(True, "Only keep one model on device").info("will keep models other than the currently used one in RAM rather than VRAM"),
|
||||||
@@ -150,14 +161,14 @@ options_templates.update(options_section(('sd', "Stable Diffusion"), {
|
|||||||
"hires_fix_refiner_pass": OptionInfo("second pass", "Hires fix: which pass to enable refiner for", gr.Radio, {"choices": ["first pass", "second pass", "both passes"]}, infotext="Hires refiner"),
|
"hires_fix_refiner_pass": OptionInfo("second pass", "Hires fix: which pass to enable refiner for", gr.Radio, {"choices": ["first pass", "second pass", "both passes"]}, infotext="Hires refiner"),
|
||||||
}))
|
}))
|
||||||
|
|
||||||
options_templates.update(options_section(('sdxl', "Stable Diffusion XL"), {
|
options_templates.update(options_section(('sdxl', "Stable Diffusion XL", "sd"), {
|
||||||
"sdxl_crop_top": OptionInfo(0, "crop top coordinate"),
|
"sdxl_crop_top": OptionInfo(0, "crop top coordinate"),
|
||||||
"sdxl_crop_left": OptionInfo(0, "crop left coordinate"),
|
"sdxl_crop_left": OptionInfo(0, "crop left coordinate"),
|
||||||
"sdxl_refiner_low_aesthetic_score": OptionInfo(2.5, "SDXL low aesthetic score", gr.Number).info("used for refiner model negative prompt"),
|
"sdxl_refiner_low_aesthetic_score": OptionInfo(2.5, "SDXL low aesthetic score", gr.Number).info("used for refiner model negative prompt"),
|
||||||
"sdxl_refiner_high_aesthetic_score": OptionInfo(6.0, "SDXL high aesthetic score", gr.Number).info("used for refiner model prompt"),
|
"sdxl_refiner_high_aesthetic_score": OptionInfo(6.0, "SDXL high aesthetic score", gr.Number).info("used for refiner model prompt"),
|
||||||
}))
|
}))
|
||||||
|
|
||||||
options_templates.update(options_section(('vae', "VAE"), {
|
options_templates.update(options_section(('vae', "VAE", "sd"), {
|
||||||
"sd_vae_explanation": OptionHTML("""
|
"sd_vae_explanation": OptionHTML("""
|
||||||
<abbr title='Variational autoencoder'>VAE</abbr> is a neural network that transforms a standard <abbr title='red/green/blue'>RGB</abbr>
|
<abbr title='Variational autoencoder'>VAE</abbr> is a neural network that transforms a standard <abbr title='red/green/blue'>RGB</abbr>
|
||||||
image into latent space representation and back. Latent space representation is what stable diffusion is working on during sampling
|
image into latent space representation and back. Latent space representation is what stable diffusion is working on during sampling
|
||||||
@@ -172,7 +183,7 @@ For img2img, VAE is used to process user's input image before the sampling, and
|
|||||||
"sd_vae_decode_method": OptionInfo("Full", "VAE type for decode", gr.Radio, {"choices": ["Full", "TAESD"]}, infotext='VAE Decoder').info("method to decode latent to image"),
|
"sd_vae_decode_method": OptionInfo("Full", "VAE type for decode", gr.Radio, {"choices": ["Full", "TAESD"]}, infotext='VAE Decoder').info("method to decode latent to image"),
|
||||||
}))
|
}))
|
||||||
|
|
||||||
options_templates.update(options_section(('img2img', "img2img"), {
|
options_templates.update(options_section(('img2img', "img2img", "sd"), {
|
||||||
"inpainting_mask_weight": OptionInfo(1.0, "Inpainting conditioning mask strength", gr.Slider, {"minimum": 0.0, "maximum": 1.0, "step": 0.01}, infotext='Conditional mask weight'),
|
"inpainting_mask_weight": OptionInfo(1.0, "Inpainting conditioning mask strength", gr.Slider, {"minimum": 0.0, "maximum": 1.0, "step": 0.01}, infotext='Conditional mask weight'),
|
||||||
"initial_noise_multiplier": OptionInfo(1.0, "Noise multiplier for img2img", gr.Slider, {"minimum": 0.0, "maximum": 1.5, "step": 0.001}, infotext='Noise multiplier'),
|
"initial_noise_multiplier": OptionInfo(1.0, "Noise multiplier for img2img", gr.Slider, {"minimum": 0.0, "maximum": 1.5, "step": 0.001}, infotext='Noise multiplier'),
|
||||||
"img2img_extra_noise": OptionInfo(0.0, "Extra noise multiplier for img2img and hires fix", gr.Slider, {"minimum": 0.0, "maximum": 1.0, "step": 0.01}, infotext='Extra noise').info("0 = disabled (default); should be lower than denoising strength"),
|
"img2img_extra_noise": OptionInfo(0.0, "Extra noise multiplier for img2img and hires fix", gr.Slider, {"minimum": 0.0, "maximum": 1.0, "step": 0.01}, infotext='Extra noise').info("0 = disabled (default); should be lower than denoising strength"),
|
||||||
@@ -185,9 +196,10 @@ options_templates.update(options_section(('img2img', "img2img"), {
|
|||||||
"img2img_inpaint_sketch_default_brush_color": OptionInfo("#ffffff", "Inpaint sketch initial brush color", ui_components.FormColorPicker, {}).info("default brush color of img2img inpaint sketch").needs_reload_ui(),
|
"img2img_inpaint_sketch_default_brush_color": OptionInfo("#ffffff", "Inpaint sketch initial brush color", ui_components.FormColorPicker, {}).info("default brush color of img2img inpaint sketch").needs_reload_ui(),
|
||||||
"return_mask": OptionInfo(False, "For inpainting, include the greyscale mask in results for web"),
|
"return_mask": OptionInfo(False, "For inpainting, include the greyscale mask in results for web"),
|
||||||
"return_mask_composite": OptionInfo(False, "For inpainting, include masked composite in results for web"),
|
"return_mask_composite": OptionInfo(False, "For inpainting, include masked composite in results for web"),
|
||||||
|
"img2img_batch_show_results_limit": OptionInfo(32, "Show the first N batch img2img results in UI", gr.Slider, {"minimum": -1, "maximum": 1000, "step": 1}).info('0: disable, -1: show all images. Too many images can cause lag'),
|
||||||
}))
|
}))
|
||||||
|
|
||||||
options_templates.update(options_section(('optimizations', "Optimizations"), {
|
options_templates.update(options_section(('optimizations', "Optimizations", "sd"), {
|
||||||
"cross_attention_optimization": OptionInfo("Automatic", "Cross attention optimization", gr.Dropdown, lambda: {"choices": shared_items.cross_attention_optimizations()}),
|
"cross_attention_optimization": OptionInfo("Automatic", "Cross attention optimization", gr.Dropdown, lambda: {"choices": shared_items.cross_attention_optimizations()}),
|
||||||
"s_min_uncond": OptionInfo(0.0, "Negative Guidance minimum sigma", gr.Slider, {"minimum": 0.0, "maximum": 15.0, "step": 0.01}).link("PR", "https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/9177").info("skip negative prompt for some steps when the image is almost ready; 0=disable, higher=faster"),
|
"s_min_uncond": OptionInfo(0.0, "Negative Guidance minimum sigma", gr.Slider, {"minimum": 0.0, "maximum": 15.0, "step": 0.01}).link("PR", "https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/9177").info("skip negative prompt for some steps when the image is almost ready; 0=disable, higher=faster"),
|
||||||
"token_merging_ratio": OptionInfo(0.0, "Token merging ratio", gr.Slider, {"minimum": 0.0, "maximum": 0.9, "step": 0.1}, infotext='Token merging ratio').link("PR", "https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/9256").info("0=disable, higher=faster"),
|
"token_merging_ratio": OptionInfo(0.0, "Token merging ratio", gr.Slider, {"minimum": 0.0, "maximum": 0.9, "step": 0.1}, infotext='Token merging ratio').link("PR", "https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/9256").info("0=disable, higher=faster"),
|
||||||
@@ -198,7 +210,7 @@ options_templates.update(options_section(('optimizations', "Optimizations"), {
|
|||||||
"batch_cond_uncond": OptionInfo(True, "Batch cond/uncond").info("do both conditional and unconditional denoising in one batch; uses a bit more VRAM during sampling, but improves speed; previously this was controlled by --always-batch-cond-uncond comandline argument"),
|
"batch_cond_uncond": OptionInfo(True, "Batch cond/uncond").info("do both conditional and unconditional denoising in one batch; uses a bit more VRAM during sampling, but improves speed; previously this was controlled by --always-batch-cond-uncond comandline argument"),
|
||||||
}))
|
}))
|
||||||
|
|
||||||
options_templates.update(options_section(('compatibility', "Compatibility"), {
|
options_templates.update(options_section(('compatibility', "Compatibility", "sd"), {
|
||||||
"use_old_emphasis_implementation": OptionInfo(False, "Use old emphasis implementation. Can be useful to reproduce old seeds."),
|
"use_old_emphasis_implementation": OptionInfo(False, "Use old emphasis implementation. Can be useful to reproduce old seeds."),
|
||||||
"use_old_karras_scheduler_sigmas": OptionInfo(False, "Use old karras scheduler sigmas (0.1 to 10)."),
|
"use_old_karras_scheduler_sigmas": OptionInfo(False, "Use old karras scheduler sigmas (0.1 to 10)."),
|
||||||
"no_dpmpp_sde_batch_determinism": OptionInfo(False, "Do not make DPM++ SDE deterministic across different batch sizes."),
|
"no_dpmpp_sde_batch_determinism": OptionInfo(False, "Do not make DPM++ SDE deterministic across different batch sizes."),
|
||||||
@@ -223,7 +235,7 @@ options_templates.update(options_section(('interrogate', "Interrogate"), {
|
|||||||
"deepbooru_filter_tags": OptionInfo("", "deepbooru: filter out those tags").info("separate by comma"),
|
"deepbooru_filter_tags": OptionInfo("", "deepbooru: filter out those tags").info("separate by comma"),
|
||||||
}))
|
}))
|
||||||
|
|
||||||
options_templates.update(options_section(('extra_networks', "Extra Networks"), {
|
options_templates.update(options_section(('extra_networks', "Extra Networks", "sd"), {
|
||||||
"extra_networks_show_hidden_directories": OptionInfo(True, "Show hidden directories").info("directory is hidden if its name starts with \".\"."),
|
"extra_networks_show_hidden_directories": OptionInfo(True, "Show hidden directories").info("directory is hidden if its name starts with \".\"."),
|
||||||
"extra_networks_hidden_models": OptionInfo("When searched", "Show cards for models in hidden directories", gr.Radio, {"choices": ["Always", "When searched", "Never"]}).info('"When searched" option will only show the item when the search string has 4 characters or more'),
|
"extra_networks_hidden_models": OptionInfo("When searched", "Show cards for models in hidden directories", gr.Radio, {"choices": ["Always", "When searched", "Never"]}).info('"When searched" option will only show the item when the search string has 4 characters or more'),
|
||||||
"extra_networks_default_multiplier": OptionInfo(1.0, "Default multiplier for extra networks", gr.Slider, {"minimum": 0.0, "maximum": 2.0, "step": 0.01}),
|
"extra_networks_default_multiplier": OptionInfo(1.0, "Default multiplier for extra networks", gr.Slider, {"minimum": 0.0, "maximum": 2.0, "step": 0.01}),
|
||||||
@@ -231,6 +243,8 @@ options_templates.update(options_section(('extra_networks', "Extra Networks"), {
|
|||||||
"extra_networks_card_height": OptionInfo(0, "Card height for Extra Networks").info("in pixels"),
|
"extra_networks_card_height": OptionInfo(0, "Card height for Extra Networks").info("in pixels"),
|
||||||
"extra_networks_card_text_scale": OptionInfo(1.0, "Card text scale", gr.Slider, {"minimum": 0.0, "maximum": 2.0, "step": 0.01}).info("1 = original size"),
|
"extra_networks_card_text_scale": OptionInfo(1.0, "Card text scale", gr.Slider, {"minimum": 0.0, "maximum": 2.0, "step": 0.01}).info("1 = original size"),
|
||||||
"extra_networks_card_show_desc": OptionInfo(True, "Show description on card"),
|
"extra_networks_card_show_desc": OptionInfo(True, "Show description on card"),
|
||||||
|
"extra_networks_card_order_field": OptionInfo("Path", "Default order field for Extra Networks cards", gr.Dropdown, {"choices": ['Path', 'Name', 'Date Created', 'Date Modified']}).needs_reload_ui(),
|
||||||
|
"extra_networks_card_order": OptionInfo("Ascending", "Default order for Extra Networks cards", gr.Dropdown, {"choices": ['Ascending', 'Descending']}).needs_reload_ui(),
|
||||||
"extra_networks_add_text_separator": OptionInfo(" ", "Extra networks separator").info("extra text to add before <...> when adding extra network to prompt"),
|
"extra_networks_add_text_separator": OptionInfo(" ", "Extra networks separator").info("extra text to add before <...> when adding extra network to prompt"),
|
||||||
"ui_extra_networks_tab_reorder": OptionInfo("", "Extra networks tab order").needs_reload_ui(),
|
"ui_extra_networks_tab_reorder": OptionInfo("", "Extra networks tab order").needs_reload_ui(),
|
||||||
"textual_inversion_print_at_load": OptionInfo(False, "Print a list of Textual Inversion embeddings when loading model"),
|
"textual_inversion_print_at_load": OptionInfo(False, "Print a list of Textual Inversion embeddings when loading model"),
|
||||||
@@ -238,7 +252,7 @@ options_templates.update(options_section(('extra_networks', "Extra Networks"), {
|
|||||||
"sd_hypernetwork": OptionInfo("None", "Add hypernetwork to prompt", gr.Dropdown, lambda: {"choices": ["None", *shared.hypernetworks]}, refresh=shared_items.reload_hypernetworks),
|
"sd_hypernetwork": OptionInfo("None", "Add hypernetwork to prompt", gr.Dropdown, lambda: {"choices": ["None", *shared.hypernetworks]}, refresh=shared_items.reload_hypernetworks),
|
||||||
}))
|
}))
|
||||||
|
|
||||||
options_templates.update(options_section(('ui', "User interface"), {
|
options_templates.update(options_section(('ui', "User interface", "ui"), {
|
||||||
"localization": OptionInfo("None", "Localization", gr.Dropdown, lambda: {"choices": ["None"] + list(localization.localizations.keys())}, refresh=lambda: localization.list_localizations(cmd_opts.localizations_dir)).needs_reload_ui(),
|
"localization": OptionInfo("None", "Localization", gr.Dropdown, lambda: {"choices": ["None"] + list(localization.localizations.keys())}, refresh=lambda: localization.list_localizations(cmd_opts.localizations_dir)).needs_reload_ui(),
|
||||||
"gradio_theme": OptionInfo("Default", "Gradio theme", ui_components.DropdownEditable, lambda: {"choices": ["Default"] + shared_gradio_themes.gradio_hf_hub_themes}).info("you can also manually enter any of themes from the <a href='https://huggingface.co/spaces/gradio/theme-gallery'>gallery</a>.").needs_reload_ui(),
|
"gradio_theme": OptionInfo("Default", "Gradio theme", ui_components.DropdownEditable, lambda: {"choices": ["Default"] + shared_gradio_themes.gradio_hf_hub_themes}).info("you can also manually enter any of themes from the <a href='https://huggingface.co/spaces/gradio/theme-gallery'>gallery</a>.").needs_reload_ui(),
|
||||||
"gradio_themes_cache": OptionInfo(True, "Cache gradio themes locally").info("disable to update the selected Gradio theme"),
|
"gradio_themes_cache": OptionInfo(True, "Cache gradio themes locally").info("disable to update the selected Gradio theme"),
|
||||||
@@ -267,10 +281,13 @@ options_templates.update(options_section(('ui', "User interface"), {
|
|||||||
"hires_fix_show_sampler": OptionInfo(False, "Hires fix: show hires checkpoint and sampler selection").needs_reload_ui(),
|
"hires_fix_show_sampler": OptionInfo(False, "Hires fix: show hires checkpoint and sampler selection").needs_reload_ui(),
|
||||||
"hires_fix_show_prompts": OptionInfo(False, "Hires fix: show hires prompt and negative prompt").needs_reload_ui(),
|
"hires_fix_show_prompts": OptionInfo(False, "Hires fix: show hires prompt and negative prompt").needs_reload_ui(),
|
||||||
"disable_token_counters": OptionInfo(False, "Disable prompt token counters").needs_reload_ui(),
|
"disable_token_counters": OptionInfo(False, "Disable prompt token counters").needs_reload_ui(),
|
||||||
|
"txt2img_settings_accordion": OptionInfo(False, "Settings in txt2img hidden under Accordion").needs_reload_ui(),
|
||||||
|
"img2img_settings_accordion": OptionInfo(False, "Settings in img2img hidden under Accordion").needs_reload_ui(),
|
||||||
|
"compact_prompt_box": OptionInfo(False, "Compact prompt layout").info("puts prompt and negative prompt inside the Generate tab, leaving more vertical space for the image on the right").needs_reload_ui(),
|
||||||
}))
|
}))
|
||||||
|
|
||||||
|
|
||||||
options_templates.update(options_section(('infotext', "Infotext"), {
|
options_templates.update(options_section(('infotext', "Infotext", "ui"), {
|
||||||
"add_model_hash_to_info": OptionInfo(True, "Add model hash to generation information"),
|
"add_model_hash_to_info": OptionInfo(True, "Add model hash to generation information"),
|
||||||
"add_model_name_to_info": OptionInfo(True, "Add model name to generation information"),
|
"add_model_name_to_info": OptionInfo(True, "Add model name to generation information"),
|
||||||
"add_user_name_to_info": OptionInfo(False, "Add user name to generation information when authenticated"),
|
"add_user_name_to_info": OptionInfo(False, "Add user name to generation information when authenticated"),
|
||||||
@@ -285,7 +302,7 @@ options_templates.update(options_section(('infotext', "Infotext"), {
|
|||||||
|
|
||||||
}))
|
}))
|
||||||
|
|
||||||
options_templates.update(options_section(('ui', "Live previews"), {
|
options_templates.update(options_section(('ui', "Live previews", "ui"), {
|
||||||
"show_progressbar": OptionInfo(True, "Show progressbar"),
|
"show_progressbar": OptionInfo(True, "Show progressbar"),
|
||||||
"live_previews_enable": OptionInfo(True, "Show live previews of the created image"),
|
"live_previews_enable": OptionInfo(True, "Show live previews of the created image"),
|
||||||
"live_previews_image_format": OptionInfo("png", "Live preview file format", gr.Radio, {"choices": ["jpeg", "png", "webp"]}),
|
"live_previews_image_format": OptionInfo("png", "Live preview file format", gr.Radio, {"choices": ["jpeg", "png", "webp"]}),
|
||||||
@@ -298,7 +315,7 @@ options_templates.update(options_section(('ui', "Live previews"), {
|
|||||||
"live_preview_fast_interrupt": OptionInfo(False, "Return image with chosen live preview method on interrupt").info("makes interrupts faster"),
|
"live_preview_fast_interrupt": OptionInfo(False, "Return image with chosen live preview method on interrupt").info("makes interrupts faster"),
|
||||||
}))
|
}))
|
||||||
|
|
||||||
options_templates.update(options_section(('sampler-params', "Sampler parameters"), {
|
options_templates.update(options_section(('sampler-params', "Sampler parameters", "sd"), {
|
||||||
"hide_samplers": OptionInfo([], "Hide samplers in user interface", gr.CheckboxGroup, lambda: {"choices": [x.name for x in shared_items.list_samplers()]}).needs_reload_ui(),
|
"hide_samplers": OptionInfo([], "Hide samplers in user interface", gr.CheckboxGroup, lambda: {"choices": [x.name for x in shared_items.list_samplers()]}).needs_reload_ui(),
|
||||||
"eta_ddim": OptionInfo(0.0, "Eta for DDIM", gr.Slider, {"minimum": 0.0, "maximum": 1.0, "step": 0.01}, infotext='Eta DDIM').info("noise multiplier; higher = more unpredictable results"),
|
"eta_ddim": OptionInfo(0.0, "Eta for DDIM", gr.Slider, {"minimum": 0.0, "maximum": 1.0, "step": 0.01}, infotext='Eta DDIM').info("noise multiplier; higher = more unpredictable results"),
|
||||||
"eta_ancestral": OptionInfo(1.0, "Eta for k-diffusion samplers", gr.Slider, {"minimum": 0.0, "maximum": 1.0, "step": 0.01}, infotext='Eta').info("noise multiplier; currently only applies to ancestral samplers (i.e. Euler a) and SDE samplers"),
|
"eta_ancestral": OptionInfo(1.0, "Eta for k-diffusion samplers", gr.Slider, {"minimum": 0.0, "maximum": 1.0, "step": 0.01}, infotext='Eta').info("noise multiplier; currently only applies to ancestral samplers (i.e. Euler a) and SDE samplers"),
|
||||||
@@ -320,7 +337,7 @@ options_templates.update(options_section(('sampler-params', "Sampler parameters"
|
|||||||
'uni_pc_lower_order_final': OptionInfo(True, "UniPC lower order final", infotext='UniPC lower order final'),
|
'uni_pc_lower_order_final': OptionInfo(True, "UniPC lower order final", infotext='UniPC lower order final'),
|
||||||
}))
|
}))
|
||||||
|
|
||||||
options_templates.update(options_section(('postprocessing', "Postprocessing"), {
|
options_templates.update(options_section(('postprocessing', "Postprocessing", "postprocessing"), {
|
||||||
'postprocessing_enable_in_main_ui': OptionInfo([], "Enable postprocessing operations in txt2img and img2img tabs", ui_components.DropdownMulti, lambda: {"choices": [x.name for x in shared_items.postprocessing_scripts()]}),
|
'postprocessing_enable_in_main_ui': OptionInfo([], "Enable postprocessing operations in txt2img and img2img tabs", ui_components.DropdownMulti, lambda: {"choices": [x.name for x in shared_items.postprocessing_scripts()]}),
|
||||||
'postprocessing_operation_order': OptionInfo([], "Postprocessing operation order", ui_components.DropdownMulti, lambda: {"choices": [x.name for x in shared_items.postprocessing_scripts()]}),
|
'postprocessing_operation_order': OptionInfo([], "Postprocessing operation order", ui_components.DropdownMulti, lambda: {"choices": [x.name for x in shared_items.postprocessing_scripts()]}),
|
||||||
'upscaling_max_images_in_cache': OptionInfo(5, "Maximum number of images in upscaling cache", gr.Slider, {"minimum": 0, "maximum": 10, "step": 1}),
|
'upscaling_max_images_in_cache': OptionInfo(5, "Maximum number of images in upscaling cache", gr.Slider, {"minimum": 0, "maximum": 10, "step": 1}),
|
||||||
@@ -332,4 +349,3 @@ options_templates.update(options_section((None, "Hidden options"), {
|
|||||||
"restore_config_state_file": OptionInfo("", "Config state file to restore from, under 'config-states/' folder"),
|
"restore_config_state_file": OptionInfo("", "Config state file to restore from, under 'config-states/' folder"),
|
||||||
"sd_checkpoint_hash": OptionInfo("", "SHA256 hash of the current checkpoint"),
|
"sd_checkpoint_hash": OptionInfo("", "SHA256 hash of the current checkpoint"),
|
||||||
}))
|
}))
|
||||||
|
|
||||||
|
|||||||
+1
-17
@@ -1,7 +1,6 @@
|
|||||||
import json
|
import json
|
||||||
import os
|
import os
|
||||||
import sys
|
import sys
|
||||||
import traceback
|
|
||||||
|
|
||||||
import platform
|
import platform
|
||||||
import hashlib
|
import hashlib
|
||||||
@@ -84,7 +83,7 @@ def get_dict():
|
|||||||
"Checksum": checksum_token,
|
"Checksum": checksum_token,
|
||||||
"Commandline": get_argv(),
|
"Commandline": get_argv(),
|
||||||
"Torch env info": get_torch_sysinfo(),
|
"Torch env info": get_torch_sysinfo(),
|
||||||
"Exceptions": get_exceptions(),
|
"Exceptions": errors.get_exceptions(),
|
||||||
"CPU": {
|
"CPU": {
|
||||||
"model": platform.processor(),
|
"model": platform.processor(),
|
||||||
"count logical": psutil.cpu_count(logical=True),
|
"count logical": psutil.cpu_count(logical=True),
|
||||||
@@ -104,21 +103,6 @@ def get_dict():
|
|||||||
return res
|
return res
|
||||||
|
|
||||||
|
|
||||||
def format_traceback(tb):
|
|
||||||
return [[f"{x.filename}, line {x.lineno}, {x.name}", x.line] for x in traceback.extract_tb(tb)]
|
|
||||||
|
|
||||||
|
|
||||||
def format_exception(e, tb):
|
|
||||||
return {"exception": str(e), "traceback": format_traceback(tb)}
|
|
||||||
|
|
||||||
|
|
||||||
def get_exceptions():
|
|
||||||
try:
|
|
||||||
return list(reversed(errors.exception_records))
|
|
||||||
except Exception as e:
|
|
||||||
return str(e)
|
|
||||||
|
|
||||||
|
|
||||||
def get_environment():
|
def get_environment():
|
||||||
return {k: os.environ[k] for k in sorted(os.environ) if k in environment_whitelist}
|
return {k: os.environ[k] for k in sorted(os.environ) if k in environment_whitelist}
|
||||||
|
|
||||||
|
|||||||
@@ -181,40 +181,7 @@ class EmbeddingDatabase:
|
|||||||
else:
|
else:
|
||||||
return
|
return
|
||||||
|
|
||||||
|
embedding = create_embedding_from_data(data, name, filename=filename, filepath=path)
|
||||||
# textual inversion embeddings
|
|
||||||
if 'string_to_param' in data:
|
|
||||||
param_dict = data['string_to_param']
|
|
||||||
param_dict = getattr(param_dict, '_parameters', param_dict) # fix for torch 1.12.1 loading saved file from torch 1.11
|
|
||||||
assert len(param_dict) == 1, 'embedding file has multiple terms in it'
|
|
||||||
emb = next(iter(param_dict.items()))[1]
|
|
||||||
vec = emb.detach().to(devices.device, dtype=torch.float32)
|
|
||||||
shape = vec.shape[-1]
|
|
||||||
vectors = vec.shape[0]
|
|
||||||
elif type(data) == dict and 'clip_g' in data and 'clip_l' in data: # SDXL embedding
|
|
||||||
vec = {k: v.detach().to(devices.device, dtype=torch.float32) for k, v in data.items()}
|
|
||||||
shape = data['clip_g'].shape[-1] + data['clip_l'].shape[-1]
|
|
||||||
vectors = data['clip_g'].shape[0]
|
|
||||||
elif type(data) == dict and type(next(iter(data.values()))) == torch.Tensor: # diffuser concepts
|
|
||||||
assert len(data.keys()) == 1, 'embedding file has multiple terms in it'
|
|
||||||
|
|
||||||
emb = next(iter(data.values()))
|
|
||||||
if len(emb.shape) == 1:
|
|
||||||
emb = emb.unsqueeze(0)
|
|
||||||
vec = emb.detach().to(devices.device, dtype=torch.float32)
|
|
||||||
shape = vec.shape[-1]
|
|
||||||
vectors = vec.shape[0]
|
|
||||||
else:
|
|
||||||
raise Exception(f"Couldn't identify {filename} as neither textual inversion embedding nor diffuser concept.")
|
|
||||||
|
|
||||||
embedding = Embedding(vec, name)
|
|
||||||
embedding.step = data.get('step', None)
|
|
||||||
embedding.sd_checkpoint = data.get('sd_checkpoint', None)
|
|
||||||
embedding.sd_checkpoint_name = data.get('sd_checkpoint_name', None)
|
|
||||||
embedding.vectors = vectors
|
|
||||||
embedding.shape = shape
|
|
||||||
embedding.filename = path
|
|
||||||
embedding.set_hash(hashes.sha256(embedding.filename, "textual_inversion/" + name) or '')
|
|
||||||
|
|
||||||
if self.expected_shape == -1 or self.expected_shape == embedding.shape:
|
if self.expected_shape == -1 or self.expected_shape == embedding.shape:
|
||||||
self.register_embedding(embedding, shared.sd_model)
|
self.register_embedding(embedding, shared.sd_model)
|
||||||
@@ -313,6 +280,45 @@ def create_embedding(name, num_vectors_per_token, overwrite_old, init_text='*'):
|
|||||||
return fn
|
return fn
|
||||||
|
|
||||||
|
|
||||||
|
def create_embedding_from_data(data, name, filename='unknown embedding file', filepath=None):
|
||||||
|
if 'string_to_param' in data: # textual inversion embeddings
|
||||||
|
param_dict = data['string_to_param']
|
||||||
|
param_dict = getattr(param_dict, '_parameters', param_dict) # fix for torch 1.12.1 loading saved file from torch 1.11
|
||||||
|
assert len(param_dict) == 1, 'embedding file has multiple terms in it'
|
||||||
|
emb = next(iter(param_dict.items()))[1]
|
||||||
|
vec = emb.detach().to(devices.device, dtype=torch.float32)
|
||||||
|
shape = vec.shape[-1]
|
||||||
|
vectors = vec.shape[0]
|
||||||
|
elif type(data) == dict and 'clip_g' in data and 'clip_l' in data: # SDXL embedding
|
||||||
|
vec = {k: v.detach().to(devices.device, dtype=torch.float32) for k, v in data.items()}
|
||||||
|
shape = data['clip_g'].shape[-1] + data['clip_l'].shape[-1]
|
||||||
|
vectors = data['clip_g'].shape[0]
|
||||||
|
elif type(data) == dict and type(next(iter(data.values()))) == torch.Tensor: # diffuser concepts
|
||||||
|
assert len(data.keys()) == 1, 'embedding file has multiple terms in it'
|
||||||
|
|
||||||
|
emb = next(iter(data.values()))
|
||||||
|
if len(emb.shape) == 1:
|
||||||
|
emb = emb.unsqueeze(0)
|
||||||
|
vec = emb.detach().to(devices.device, dtype=torch.float32)
|
||||||
|
shape = vec.shape[-1]
|
||||||
|
vectors = vec.shape[0]
|
||||||
|
else:
|
||||||
|
raise Exception(f"Couldn't identify {filename} as neither textual inversion embedding nor diffuser concept.")
|
||||||
|
|
||||||
|
embedding = Embedding(vec, name)
|
||||||
|
embedding.step = data.get('step', None)
|
||||||
|
embedding.sd_checkpoint = data.get('sd_checkpoint', None)
|
||||||
|
embedding.sd_checkpoint_name = data.get('sd_checkpoint_name', None)
|
||||||
|
embedding.vectors = vectors
|
||||||
|
embedding.shape = shape
|
||||||
|
|
||||||
|
if filepath:
|
||||||
|
embedding.filename = filepath
|
||||||
|
embedding.set_hash(hashes.sha256(filepath, "textual_inversion/" + name) or '')
|
||||||
|
|
||||||
|
return embedding
|
||||||
|
|
||||||
|
|
||||||
def write_loss(log_directory, filename, step, epoch_len, values):
|
def write_loss(log_directory, filename, step, epoch_len, values):
|
||||||
if shared.opts.training_write_csv_every == 0:
|
if shared.opts.training_write_csv_every == 0:
|
||||||
return
|
return
|
||||||
|
|||||||
+41
-99
@@ -4,6 +4,7 @@ import os
|
|||||||
import sys
|
import sys
|
||||||
from functools import reduce
|
from functools import reduce
|
||||||
import warnings
|
import warnings
|
||||||
|
from contextlib import ExitStack
|
||||||
|
|
||||||
import gradio as gr
|
import gradio as gr
|
||||||
import gradio.utils
|
import gradio.utils
|
||||||
@@ -12,7 +13,7 @@ from PIL import Image, PngImagePlugin # noqa: F401
|
|||||||
from modules.call_queue import wrap_gradio_gpu_call, wrap_queued_call, wrap_gradio_call
|
from modules.call_queue import wrap_gradio_gpu_call, wrap_queued_call, wrap_gradio_call
|
||||||
|
|
||||||
from modules import gradio_extensons # noqa: F401
|
from modules import gradio_extensons # noqa: F401
|
||||||
from modules import sd_hijack, sd_models, script_callbacks, ui_extensions, deepbooru, extra_networks, ui_common, ui_postprocessing, progress, ui_loadsave, shared_items, ui_settings, timer, sysinfo, ui_checkpoint_merger, ui_prompt_styles, scripts, sd_samplers, processing, ui_extra_networks
|
from modules import sd_hijack, sd_models, script_callbacks, ui_extensions, deepbooru, extra_networks, ui_common, ui_postprocessing, progress, ui_loadsave, shared_items, ui_settings, timer, sysinfo, ui_checkpoint_merger, scripts, sd_samplers, processing, ui_extra_networks, ui_toprow
|
||||||
from modules.ui_components import FormRow, FormGroup, ToolButton, FormHTML, InputAccordion, ResizeHandleRow
|
from modules.ui_components import FormRow, FormGroup, ToolButton, FormHTML, InputAccordion, ResizeHandleRow
|
||||||
from modules.paths import script_path
|
from modules.paths import script_path
|
||||||
from modules.ui_common import create_refresh_button
|
from modules.ui_common import create_refresh_button
|
||||||
@@ -25,15 +26,14 @@ import modules.hypernetworks.ui as hypernetworks_ui
|
|||||||
import modules.textual_inversion.ui as textual_inversion_ui
|
import modules.textual_inversion.ui as textual_inversion_ui
|
||||||
import modules.textual_inversion.textual_inversion as textual_inversion
|
import modules.textual_inversion.textual_inversion as textual_inversion
|
||||||
import modules.shared as shared
|
import modules.shared as shared
|
||||||
import modules.images
|
|
||||||
from modules import prompt_parser
|
from modules import prompt_parser
|
||||||
from modules.sd_hijack import model_hijack
|
from modules.sd_hijack import model_hijack
|
||||||
from modules.generation_parameters_copypaste import image_from_url_text
|
from modules.generation_parameters_copypaste import image_from_url_text
|
||||||
|
|
||||||
create_setting_component = ui_settings.create_setting_component
|
create_setting_component = ui_settings.create_setting_component
|
||||||
|
|
||||||
warnings.filterwarnings("default" if opts.show_warnings else "ignore", category=UserWarning)
|
# warnings.filterwarnings("default" if opts.show_warnings else "ignore", category=UserWarning)
|
||||||
warnings.filterwarnings("default" if opts.show_gradio_deprecation_warnings else "ignore", category=gr.deprecation.GradioDeprecationWarning)
|
# warnings.filterwarnings("default" if opts.show_gradio_deprecation_warnings else "ignore", category=gr.deprecation.GradioDeprecationWarning)
|
||||||
|
|
||||||
# this is a fix for Windows users. Without it, javascript files will be served with text/html content-type and the browser will not show any UI
|
# this is a fix for Windows users. Without it, javascript files will be served with text/html content-type and the browser will not show any UI
|
||||||
mimetypes.init()
|
mimetypes.init()
|
||||||
@@ -177,79 +177,6 @@ def update_negative_prompt_token_counter(text, steps):
|
|||||||
return update_token_counter(text, steps, is_positive=False)
|
return update_token_counter(text, steps, is_positive=False)
|
||||||
|
|
||||||
|
|
||||||
class Toprow:
|
|
||||||
"""Creates a top row UI with prompts, generate button, styles, extra little buttons for things, and enables some functionality related to their operation"""
|
|
||||||
|
|
||||||
def __init__(self, is_img2img):
|
|
||||||
id_part = "img2img" if is_img2img else "txt2img"
|
|
||||||
self.id_part = id_part
|
|
||||||
|
|
||||||
with gr.Row(elem_id=f"{id_part}_toprow", variant="compact"):
|
|
||||||
with gr.Column(elem_id=f"{id_part}_prompt_container", scale=6):
|
|
||||||
with gr.Row():
|
|
||||||
with gr.Column(scale=80):
|
|
||||||
with gr.Row():
|
|
||||||
self.prompt = gr.Textbox(label="Prompt", elem_id=f"{id_part}_prompt", show_label=False, lines=3, placeholder="Prompt (press Ctrl+Enter or Alt+Enter to generate)", elem_classes=["prompt"])
|
|
||||||
self.prompt_img = gr.File(label="", elem_id=f"{id_part}_prompt_image", file_count="single", type="binary", visible=False)
|
|
||||||
|
|
||||||
with gr.Row():
|
|
||||||
with gr.Column(scale=80):
|
|
||||||
with gr.Row():
|
|
||||||
self.negative_prompt = gr.Textbox(label="Negative prompt", elem_id=f"{id_part}_neg_prompt", show_label=False, lines=3, placeholder="Negative prompt (press Ctrl+Enter or Alt+Enter to generate)", elem_classes=["prompt"])
|
|
||||||
|
|
||||||
self.button_interrogate = None
|
|
||||||
self.button_deepbooru = None
|
|
||||||
if is_img2img:
|
|
||||||
with gr.Column(scale=1, elem_classes="interrogate-col"):
|
|
||||||
self.button_interrogate = gr.Button('Interrogate\nCLIP', elem_id="interrogate")
|
|
||||||
self.button_deepbooru = gr.Button('Interrogate\nDeepBooru', elem_id="deepbooru")
|
|
||||||
|
|
||||||
with gr.Column(scale=1, elem_id=f"{id_part}_actions_column"):
|
|
||||||
with gr.Row(elem_id=f"{id_part}_generate_box", elem_classes="generate-box"):
|
|
||||||
self.interrupt = gr.Button('Interrupt', elem_id=f"{id_part}_interrupt", elem_classes="generate-box-interrupt")
|
|
||||||
self.skip = gr.Button('Skip', elem_id=f"{id_part}_skip", elem_classes="generate-box-skip")
|
|
||||||
self.submit = gr.Button('Generate', elem_id=f"{id_part}_generate", variant='primary')
|
|
||||||
|
|
||||||
self.skip.click(
|
|
||||||
fn=lambda: shared.state.skip(),
|
|
||||||
inputs=[],
|
|
||||||
outputs=[],
|
|
||||||
)
|
|
||||||
|
|
||||||
self.interrupt.click(
|
|
||||||
fn=lambda: shared.state.interrupt(),
|
|
||||||
inputs=[],
|
|
||||||
outputs=[],
|
|
||||||
)
|
|
||||||
|
|
||||||
with gr.Row(elem_id=f"{id_part}_tools"):
|
|
||||||
self.paste = ToolButton(value=paste_symbol, elem_id="paste", tooltip="Read generation parameters from prompt or last generation if prompt is empty into user interface.")
|
|
||||||
self.clear_prompt_button = ToolButton(value=clear_prompt_symbol, elem_id=f"{id_part}_clear_prompt", tooltip="Clear prompt")
|
|
||||||
self.apply_styles = ToolButton(value=ui_prompt_styles.styles_materialize_symbol, elem_id=f"{id_part}_style_apply", tooltip="Apply all selected styles to prompts.")
|
|
||||||
self.restore_progress_button = ToolButton(value=restore_progress_symbol, elem_id=f"{id_part}_restore_progress", visible=False, tooltip="Restore progress")
|
|
||||||
|
|
||||||
self.token_counter = gr.HTML(value="<span>0/75</span>", elem_id=f"{id_part}_token_counter", elem_classes=["token-counter"])
|
|
||||||
self.token_button = gr.Button(visible=False, elem_id=f"{id_part}_token_button")
|
|
||||||
self.negative_token_counter = gr.HTML(value="<span>0/75</span>", elem_id=f"{id_part}_negative_token_counter", elem_classes=["token-counter"])
|
|
||||||
self.negative_token_button = gr.Button(visible=False, elem_id=f"{id_part}_negative_token_button")
|
|
||||||
|
|
||||||
self.clear_prompt_button.click(
|
|
||||||
fn=lambda *x: x,
|
|
||||||
_js="confirm_clear_prompt",
|
|
||||||
inputs=[self.prompt, self.negative_prompt],
|
|
||||||
outputs=[self.prompt, self.negative_prompt],
|
|
||||||
)
|
|
||||||
|
|
||||||
self.ui_styles = ui_prompt_styles.UiPromptStyles(id_part, self.prompt, self.negative_prompt)
|
|
||||||
self.ui_styles.setup_apply_button(self.apply_styles)
|
|
||||||
|
|
||||||
self.prompt_img.change(
|
|
||||||
fn=modules.images.image_data,
|
|
||||||
inputs=[self.prompt_img],
|
|
||||||
outputs=[self.prompt, self.prompt_img],
|
|
||||||
show_progress=False,
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
def setup_progressbar(*args, **kwargs):
|
def setup_progressbar(*args, **kwargs):
|
||||||
pass
|
pass
|
||||||
@@ -288,8 +215,8 @@ def apply_setting(key, value):
|
|||||||
return getattr(opts, key)
|
return getattr(opts, key)
|
||||||
|
|
||||||
|
|
||||||
def create_output_panel(tabname, outdir):
|
def create_output_panel(tabname, outdir, toprow=None):
|
||||||
return ui_common.create_output_panel(tabname, outdir)
|
return ui_common.create_output_panel(tabname, outdir, toprow)
|
||||||
|
|
||||||
|
|
||||||
def create_sampler_and_steps_selection(choices, tabname):
|
def create_sampler_and_steps_selection(choices, tabname):
|
||||||
@@ -336,7 +263,7 @@ def create_ui():
|
|||||||
scripts.scripts_txt2img.initialize_scripts(is_img2img=False)
|
scripts.scripts_txt2img.initialize_scripts(is_img2img=False)
|
||||||
|
|
||||||
with gr.Blocks(analytics_enabled=False) as txt2img_interface:
|
with gr.Blocks(analytics_enabled=False) as txt2img_interface:
|
||||||
toprow = Toprow(is_img2img=False)
|
toprow = ui_toprow.Toprow(is_img2img=False, is_compact=shared.opts.compact_prompt_box)
|
||||||
|
|
||||||
dummy_component = gr.Label(visible=False)
|
dummy_component = gr.Label(visible=False)
|
||||||
|
|
||||||
@@ -344,10 +271,17 @@ def create_ui():
|
|||||||
extra_tabs.__enter__()
|
extra_tabs.__enter__()
|
||||||
|
|
||||||
with gr.Tab("Generation", id="txt2img_generation") as txt2img_generation_tab, ResizeHandleRow(equal_height=False):
|
with gr.Tab("Generation", id="txt2img_generation") as txt2img_generation_tab, ResizeHandleRow(equal_height=False):
|
||||||
with gr.Column(variant='compact', elem_id="txt2img_settings"):
|
with ExitStack() as stack:
|
||||||
|
if shared.opts.txt2img_settings_accordion:
|
||||||
|
stack.enter_context(gr.Accordion("Open for Settings", open=False))
|
||||||
|
stack.enter_context(gr.Column(variant='compact', elem_id="txt2img_settings"))
|
||||||
|
|
||||||
scripts.scripts_txt2img.prepare_ui()
|
scripts.scripts_txt2img.prepare_ui()
|
||||||
|
|
||||||
for category in ordered_ui_categories():
|
for category in ordered_ui_categories():
|
||||||
|
if category == "prompt":
|
||||||
|
toprow.create_inline_toprow_prompts()
|
||||||
|
|
||||||
if category == "sampler":
|
if category == "sampler":
|
||||||
steps, sampler_name = create_sampler_and_steps_selection(sd_samplers.visible_sampler_names(), "txt2img")
|
steps, sampler_name = create_sampler_and_steps_selection(sd_samplers.visible_sampler_names(), "txt2img")
|
||||||
|
|
||||||
@@ -442,7 +376,7 @@ def create_ui():
|
|||||||
show_progress=False,
|
show_progress=False,
|
||||||
)
|
)
|
||||||
|
|
||||||
txt2img_gallery, generation_info, html_info, html_log = create_output_panel("txt2img", opts.outdir_txt2img_samples)
|
txt2img_gallery, generation_info, html_info, html_log = create_output_panel("txt2img", opts.outdir_txt2img_samples, toprow)
|
||||||
|
|
||||||
txt2img_args = dict(
|
txt2img_args = dict(
|
||||||
fn=wrap_gradio_gpu_call(modules.txt2img.txt2img, extra_outputs=[None, '', '']),
|
fn=wrap_gradio_gpu_call(modules.txt2img.txt2img, extra_outputs=[None, '', '']),
|
||||||
@@ -554,13 +488,17 @@ def create_ui():
|
|||||||
scripts.scripts_img2img.initialize_scripts(is_img2img=True)
|
scripts.scripts_img2img.initialize_scripts(is_img2img=True)
|
||||||
|
|
||||||
with gr.Blocks(analytics_enabled=False) as img2img_interface:
|
with gr.Blocks(analytics_enabled=False) as img2img_interface:
|
||||||
toprow = Toprow(is_img2img=True)
|
toprow = ui_toprow.Toprow(is_img2img=True, is_compact=shared.opts.compact_prompt_box)
|
||||||
|
|
||||||
extra_tabs = gr.Tabs(elem_id="img2img_extra_tabs")
|
extra_tabs = gr.Tabs(elem_id="img2img_extra_tabs")
|
||||||
extra_tabs.__enter__()
|
extra_tabs.__enter__()
|
||||||
|
|
||||||
with gr.Tab("Generation", id="img2img_generation") as img2img_generation_tab, ResizeHandleRow(equal_height=False):
|
with gr.Tab("Generation", id="img2img_generation") as img2img_generation_tab, ResizeHandleRow(equal_height=False):
|
||||||
with gr.Column(variant='compact', elem_id="img2img_settings"):
|
with ExitStack() as stack:
|
||||||
|
if shared.opts.img2img_settings_accordion:
|
||||||
|
stack.enter_context(gr.Accordion("Open for Settings", open=False))
|
||||||
|
stack.enter_context(gr.Column(variant='compact', elem_id="img2img_settings"))
|
||||||
|
|
||||||
copy_image_buttons = []
|
copy_image_buttons = []
|
||||||
copy_image_destinations = {}
|
copy_image_destinations = {}
|
||||||
|
|
||||||
@@ -577,6 +515,13 @@ def create_ui():
|
|||||||
button = gr.Button(title)
|
button = gr.Button(title)
|
||||||
copy_image_buttons.append((button, name, elem))
|
copy_image_buttons.append((button, name, elem))
|
||||||
|
|
||||||
|
scripts.scripts_img2img.prepare_ui()
|
||||||
|
|
||||||
|
for category in ordered_ui_categories():
|
||||||
|
if category == "prompt":
|
||||||
|
toprow.create_inline_toprow_prompts()
|
||||||
|
|
||||||
|
if category == "image":
|
||||||
with gr.Tabs(elem_id="mode_img2img"):
|
with gr.Tabs(elem_id="mode_img2img"):
|
||||||
img2img_selected_tab = gr.State(0)
|
img2img_selected_tab = gr.State(0)
|
||||||
|
|
||||||
@@ -653,9 +598,6 @@ def create_ui():
|
|||||||
with FormRow():
|
with FormRow():
|
||||||
resize_mode = gr.Radio(label="Resize mode", elem_id="resize_mode", choices=["Just resize", "Crop and resize", "Resize and fill", "Just resize (latent upscale)"], type="index", value="Just resize")
|
resize_mode = gr.Radio(label="Resize mode", elem_id="resize_mode", choices=["Just resize", "Crop and resize", "Resize and fill", "Just resize (latent upscale)"], type="index", value="Just resize")
|
||||||
|
|
||||||
scripts.scripts_img2img.prepare_ui()
|
|
||||||
|
|
||||||
for category in ordered_ui_categories():
|
|
||||||
if category == "sampler":
|
if category == "sampler":
|
||||||
steps, sampler_name = create_sampler_and_steps_selection(sd_samplers.visible_sampler_names(), "img2img")
|
steps, sampler_name = create_sampler_and_steps_selection(sd_samplers.visible_sampler_names(), "img2img")
|
||||||
|
|
||||||
@@ -693,12 +635,6 @@ def create_ui():
|
|||||||
scale_by.release(**on_change_args)
|
scale_by.release(**on_change_args)
|
||||||
button_update_resize_to.click(**on_change_args)
|
button_update_resize_to.click(**on_change_args)
|
||||||
|
|
||||||
# the code below is meant to update the resolution label after the image in the image selection UI has changed.
|
|
||||||
# as it is now the event keeps firing continuously for inpaint edits, which ruins the page with constant requests.
|
|
||||||
# I assume this must be a gradio bug and for now we'll just do it for non-inpaint inputs.
|
|
||||||
for component in [init_img, sketch]:
|
|
||||||
component.change(fn=lambda: None, _js="updateImg2imgResizeToTextAfterChangingImage", inputs=[], outputs=[], show_progress=False)
|
|
||||||
|
|
||||||
tab_scale_to.select(fn=lambda: 0, inputs=[], outputs=[selected_scale_tab])
|
tab_scale_to.select(fn=lambda: 0, inputs=[], outputs=[selected_scale_tab])
|
||||||
tab_scale_by.select(fn=lambda: 1, inputs=[], outputs=[selected_scale_tab])
|
tab_scale_by.select(fn=lambda: 1, inputs=[], outputs=[selected_scale_tab])
|
||||||
|
|
||||||
@@ -756,6 +692,15 @@ def create_ui():
|
|||||||
with gr.Column(scale=4):
|
with gr.Column(scale=4):
|
||||||
inpaint_full_res_padding = gr.Slider(label='Only masked padding, pixels', minimum=0, maximum=256, step=4, value=32, elem_id="img2img_inpaint_full_res_padding")
|
inpaint_full_res_padding = gr.Slider(label='Only masked padding, pixels', minimum=0, maximum=256, step=4, value=32, elem_id="img2img_inpaint_full_res_padding")
|
||||||
|
|
||||||
|
if category not in {"accordions"}:
|
||||||
|
scripts.scripts_img2img.setup_ui_for_section(category)
|
||||||
|
|
||||||
|
# the code below is meant to update the resolution label after the image in the image selection UI has changed.
|
||||||
|
# as it is now the event keeps firing continuously for inpaint edits, which ruins the page with constant requests.
|
||||||
|
# I assume this must be a gradio bug and for now we'll just do it for non-inpaint inputs.
|
||||||
|
for component in [init_img, sketch]:
|
||||||
|
component.change(fn=lambda: None, _js="updateImg2imgResizeToTextAfterChangingImage", inputs=[], outputs=[], show_progress=False)
|
||||||
|
|
||||||
def select_img2img_tab(tab):
|
def select_img2img_tab(tab):
|
||||||
return gr.update(visible=tab in [2, 3, 4]), gr.update(visible=tab == 3),
|
return gr.update(visible=tab in [2, 3, 4]), gr.update(visible=tab == 3),
|
||||||
|
|
||||||
@@ -766,10 +711,7 @@ def create_ui():
|
|||||||
outputs=[inpaint_controls, mask_alpha],
|
outputs=[inpaint_controls, mask_alpha],
|
||||||
)
|
)
|
||||||
|
|
||||||
if category not in {"accordions"}:
|
img2img_gallery, generation_info, html_info, html_log = create_output_panel("img2img", opts.outdir_img2img_samples, toprow)
|
||||||
scripts.scripts_img2img.setup_ui_for_section(category)
|
|
||||||
|
|
||||||
img2img_gallery, generation_info, html_info, html_log = create_output_panel("img2img", opts.outdir_img2img_samples)
|
|
||||||
|
|
||||||
img2img_args = dict(
|
img2img_args = dict(
|
||||||
fn=wrap_gradio_gpu_call(modules.img2img.img2img, extra_outputs=[None, '', '']),
|
fn=wrap_gradio_gpu_call(modules.img2img.img2img, extra_outputs=[None, '', '']),
|
||||||
@@ -1296,7 +1238,7 @@ def create_ui():
|
|||||||
|
|
||||||
loadsave.setup_ui()
|
loadsave.setup_ui()
|
||||||
|
|
||||||
if os.path.exists(os.path.join(script_path, "notification.mp3")):
|
if os.path.exists(os.path.join(script_path, "notification.mp3")) and shared.opts.notification_audio:
|
||||||
gr.Audio(interactive=False, value=os.path.join(script_path, "notification.mp3"), elem_id="audio_notification", visible=False)
|
gr.Audio(interactive=False, value=os.path.join(script_path, "notification.mp3"), elem_id="audio_notification", visible=False)
|
||||||
|
|
||||||
footer = shared.html("footer.html")
|
footer = shared.html("footer.html")
|
||||||
@@ -1366,7 +1308,7 @@ def setup_ui_api(app):
|
|||||||
from fastapi.responses import PlainTextResponse
|
from fastapi.responses import PlainTextResponse
|
||||||
|
|
||||||
text = sysinfo.get()
|
text = sysinfo.get()
|
||||||
filename = f"sysinfo-{datetime.datetime.utcnow().strftime('%Y-%m-%d-%H-%M')}.txt"
|
filename = f"sysinfo-{datetime.datetime.utcnow().strftime('%Y-%m-%d-%H-%M')}.json"
|
||||||
|
|
||||||
return PlainTextResponse(text, headers={'Content-Disposition': f'{"attachment" if attachment else "inline"}; filename="{filename}"'})
|
return PlainTextResponse(text, headers={'Content-Disposition': f'{"attachment" if attachment else "inline"}; filename="{filename}"'})
|
||||||
|
|
||||||
|
|||||||
@@ -104,7 +104,7 @@ def save_files(js_data, images, do_make_zip, index):
|
|||||||
return gr.File.update(value=fullfns, visible=True), plaintext_to_html(f"Saved: {filenames[0]}")
|
return gr.File.update(value=fullfns, visible=True), plaintext_to_html(f"Saved: {filenames[0]}")
|
||||||
|
|
||||||
|
|
||||||
def create_output_panel(tabname, outdir):
|
def create_output_panel(tabname, outdir, toprow=None):
|
||||||
|
|
||||||
def open_folder(f):
|
def open_folder(f):
|
||||||
if not os.path.exists(f):
|
if not os.path.exists(f):
|
||||||
@@ -130,12 +130,15 @@ Requested path was: {f}
|
|||||||
else:
|
else:
|
||||||
sp.Popen(["xdg-open", path])
|
sp.Popen(["xdg-open", path])
|
||||||
|
|
||||||
with gr.Column(variant='panel', elem_id=f"{tabname}_results"):
|
with gr.Column(elem_id=f"{tabname}_results"):
|
||||||
|
if toprow:
|
||||||
|
toprow.create_inline_toprow_image()
|
||||||
|
|
||||||
|
with gr.Column(variant='panel', elem_id=f"{tabname}_results_panel"):
|
||||||
with gr.Group(elem_id=f"{tabname}_gallery_container"):
|
with gr.Group(elem_id=f"{tabname}_gallery_container"):
|
||||||
result_gallery = gr.Gallery(label='Output', show_label=False, elem_id=f"{tabname}_gallery", columns=4, preview=True, height=shared.opts.gallery_height or None)
|
result_gallery = gr.Gallery(label='Output', show_label=False, elem_id=f"{tabname}_gallery", columns=4, preview=True, height=shared.opts.gallery_height or None)
|
||||||
|
|
||||||
generation_info = None
|
generation_info = None
|
||||||
with gr.Column():
|
|
||||||
with gr.Row(elem_id=f"image_buttons_{tabname}", elem_classes="image-buttons"):
|
with gr.Row(elem_id=f"image_buttons_{tabname}", elem_classes="image-buttons"):
|
||||||
open_folder_button = ToolButton(folder_symbol, elem_id=f'{tabname}_open_folder', visible=not shared.cmd_opts.hide_ui_dir_config, tooltip="Open images output directory.")
|
open_folder_button = ToolButton(folder_symbol, elem_id=f'{tabname}_open_folder', visible=not shared.cmd_opts.hide_ui_dir_config, tooltip="Open images output directory.")
|
||||||
|
|
||||||
|
|||||||
@@ -65,7 +65,7 @@ def save_config_state(name):
|
|||||||
filename = os.path.join(config_states_dir, f"{timestamp}_{name}.json")
|
filename = os.path.join(config_states_dir, f"{timestamp}_{name}.json")
|
||||||
print(f"Saving backup of webui/extension state to {filename}.")
|
print(f"Saving backup of webui/extension state to {filename}.")
|
||||||
with open(filename, "w", encoding="utf-8") as f:
|
with open(filename, "w", encoding="utf-8") as f:
|
||||||
json.dump(current_config_state, f, indent=4)
|
json.dump(current_config_state, f, indent=4, ensure_ascii=False)
|
||||||
config_states.list_config_states()
|
config_states.list_config_states()
|
||||||
new_value = next(iter(config_states.all_config_states.keys()), "Current")
|
new_value = next(iter(config_states.all_config_states.keys()), "Current")
|
||||||
new_choices = ["Current"] + list(config_states.all_config_states.keys())
|
new_choices = ["Current"] + list(config_states.all_config_states.keys())
|
||||||
|
|||||||
@@ -103,6 +103,7 @@ class ExtraNetworksPage:
|
|||||||
self.name = title.lower()
|
self.name = title.lower()
|
||||||
self.id_page = self.name.replace(" ", "_")
|
self.id_page = self.name.replace(" ", "_")
|
||||||
self.card_page = shared.html("extra-networks-card.html")
|
self.card_page = shared.html("extra-networks-card.html")
|
||||||
|
self.allow_prompt = True
|
||||||
self.allow_negative_prompt = False
|
self.allow_negative_prompt = False
|
||||||
self.metadata = {}
|
self.metadata = {}
|
||||||
self.items = {}
|
self.items = {}
|
||||||
@@ -278,6 +279,7 @@ class ExtraNetworksPage:
|
|||||||
"date_created": int(stat.st_ctime or 0),
|
"date_created": int(stat.st_ctime or 0),
|
||||||
"date_modified": int(stat.st_mtime or 0),
|
"date_modified": int(stat.st_mtime or 0),
|
||||||
"name": pth.name.lower(),
|
"name": pth.name.lower(),
|
||||||
|
"path": str(pth.parent).lower(),
|
||||||
}
|
}
|
||||||
|
|
||||||
def find_preview(self, path):
|
def find_preview(self, path):
|
||||||
@@ -367,7 +369,10 @@ def create_ui(interface: gr.Blocks, unrelated_tabs, tabname):
|
|||||||
related_tabs = []
|
related_tabs = []
|
||||||
|
|
||||||
for page in ui.stored_extra_pages:
|
for page in ui.stored_extra_pages:
|
||||||
with gr.Tab(page.title, id=page.id_page) as tab:
|
with gr.Tab(page.title, elem_id=f"{tabname}_{page.id_page}", elem_classes=["extra-page"]) as tab:
|
||||||
|
with gr.Column(elem_id=f"{tabname}_{page.id_page}_prompts", elem_classes=["extra-page-prompts"]):
|
||||||
|
pass
|
||||||
|
|
||||||
elem_id = f"{tabname}_{page.id_page}_cards_html"
|
elem_id = f"{tabname}_{page.id_page}_cards_html"
|
||||||
page_elem = gr.HTML('Loading...', elem_id=elem_id)
|
page_elem = gr.HTML('Loading...', elem_id=elem_id)
|
||||||
ui.pages.append(page_elem)
|
ui.pages.append(page_elem)
|
||||||
@@ -381,19 +386,28 @@ def create_ui(interface: gr.Blocks, unrelated_tabs, tabname):
|
|||||||
related_tabs.append(tab)
|
related_tabs.append(tab)
|
||||||
|
|
||||||
edit_search = gr.Textbox('', show_label=False, elem_id=tabname+"_extra_search", elem_classes="search", placeholder="Search...", visible=False, interactive=True)
|
edit_search = gr.Textbox('', show_label=False, elem_id=tabname+"_extra_search", elem_classes="search", placeholder="Search...", visible=False, interactive=True)
|
||||||
dropdown_sort = gr.Dropdown(choices=['Default Sort', 'Date Created', 'Date Modified', 'Name'], value='Default Sort', elem_id=tabname+"_extra_sort", elem_classes="sort", multiselect=False, visible=False, show_label=False, interactive=True, label=tabname+"_extra_sort_order")
|
dropdown_sort = gr.Dropdown(choices=['Path', 'Name', 'Date Created', 'Date Modified', ], value=shared.opts.extra_networks_card_order_field, elem_id=tabname+"_extra_sort", elem_classes="sort", multiselect=False, visible=False, show_label=False, interactive=True, label=tabname+"_extra_sort_order")
|
||||||
button_sortorder = ToolButton(switch_values_symbol, elem_id=tabname+"_extra_sortorder", elem_classes="sortorder", visible=False, tooltip="Invert sort order")
|
button_sortorder = ToolButton(switch_values_symbol, elem_id=tabname+"_extra_sortorder", elem_classes=["sortorder"] + ([] if shared.opts.extra_networks_card_order == "Ascending" else ["sortReverse"]), visible=False, tooltip="Invert sort order")
|
||||||
button_refresh = gr.Button('Refresh', elem_id=tabname+"_extra_refresh", visible=False)
|
button_refresh = gr.Button('Refresh', elem_id=tabname+"_extra_refresh", visible=False)
|
||||||
checkbox_show_dirs = gr.Checkbox(True, label='Show dirs', elem_id=tabname+"_extra_show_dirs", elem_classes="show-dirs", visible=False)
|
checkbox_show_dirs = gr.Checkbox(True, label='Show dirs', elem_id=tabname+"_extra_show_dirs", elem_classes="show-dirs", visible=False)
|
||||||
|
|
||||||
ui.button_save_preview = gr.Button('Save preview', elem_id=tabname+"_save_preview", visible=False)
|
ui.button_save_preview = gr.Button('Save preview', elem_id=tabname+"_save_preview", visible=False)
|
||||||
ui.preview_target_filename = gr.Textbox('Preview save filename', elem_id=tabname+"_preview_filename", visible=False)
|
ui.preview_target_filename = gr.Textbox('Preview save filename', elem_id=tabname+"_preview_filename", visible=False)
|
||||||
|
|
||||||
for tab in unrelated_tabs:
|
tab_controls = [edit_search, dropdown_sort, button_sortorder, button_refresh, checkbox_show_dirs]
|
||||||
tab.select(fn=lambda: [gr.update(visible=False) for _ in range(5)], inputs=[], outputs=[edit_search, dropdown_sort, button_sortorder, button_refresh, checkbox_show_dirs], show_progress=False)
|
|
||||||
|
|
||||||
for tab in related_tabs:
|
for tab in unrelated_tabs:
|
||||||
tab.select(fn=lambda: [gr.update(visible=True) for _ in range(5)], inputs=[], outputs=[edit_search, dropdown_sort, button_sortorder, button_refresh, checkbox_show_dirs], show_progress=False)
|
tab.select(fn=lambda: [gr.update(visible=False) for _ in tab_controls], _js='function(){ extraNetworksUrelatedTabSelected("' + tabname + '"); }', inputs=[], outputs=tab_controls, show_progress=False)
|
||||||
|
|
||||||
|
for page, tab in zip(ui.stored_extra_pages, related_tabs):
|
||||||
|
allow_prompt = "true" if page.allow_prompt else "false"
|
||||||
|
allow_negative_prompt = "true" if page.allow_negative_prompt else "false"
|
||||||
|
|
||||||
|
jscode = 'extraNetworksTabSelected("' + tabname + '", "' + f"{tabname}_{page.id_page}_prompts" + '", ' + allow_prompt + ', ' + allow_negative_prompt + ');'
|
||||||
|
|
||||||
|
tab.select(fn=lambda: [gr.update(visible=True) for _ in tab_controls], _js='function(){ ' + jscode + ' }', inputs=[], outputs=tab_controls, show_progress=False)
|
||||||
|
|
||||||
|
dropdown_sort.change(fn=lambda: None, _js="function(){ applyExtraNetworkSort('" + tabname + "'); }")
|
||||||
|
|
||||||
def pages_html():
|
def pages_html():
|
||||||
if not ui.pages_contents:
|
if not ui.pages_contents:
|
||||||
|
|||||||
@@ -10,11 +10,16 @@ class ExtraNetworksPageCheckpoints(ui_extra_networks.ExtraNetworksPage):
|
|||||||
def __init__(self):
|
def __init__(self):
|
||||||
super().__init__('Checkpoints')
|
super().__init__('Checkpoints')
|
||||||
|
|
||||||
|
self.allow_prompt = False
|
||||||
|
|
||||||
def refresh(self):
|
def refresh(self):
|
||||||
shared.refresh_checkpoints()
|
shared.refresh_checkpoints()
|
||||||
|
|
||||||
def create_item(self, name, index=None, enable_filter=True):
|
def create_item(self, name, index=None, enable_filter=True):
|
||||||
checkpoint: sd_models.CheckpointInfo = sd_models.checkpoint_aliases.get(name)
|
checkpoint: sd_models.CheckpointInfo = sd_models.checkpoint_aliases.get(name)
|
||||||
|
if checkpoint is None:
|
||||||
|
return
|
||||||
|
|
||||||
path, ext = os.path.splitext(checkpoint.filename)
|
path, ext = os.path.splitext(checkpoint.filename)
|
||||||
return {
|
return {
|
||||||
"name": checkpoint.name_for_extra,
|
"name": checkpoint.name_for_extra,
|
||||||
@@ -30,9 +35,12 @@ class ExtraNetworksPageCheckpoints(ui_extra_networks.ExtraNetworksPage):
|
|||||||
}
|
}
|
||||||
|
|
||||||
def list_items(self):
|
def list_items(self):
|
||||||
|
# instantiate a list to protect against concurrent modification
|
||||||
names = list(sd_models.checkpoints_list)
|
names = list(sd_models.checkpoints_list)
|
||||||
for index, name in enumerate(names):
|
for index, name in enumerate(names):
|
||||||
yield self.create_item(name, index)
|
item = self.create_item(name, index)
|
||||||
|
if item is not None:
|
||||||
|
yield item
|
||||||
|
|
||||||
def allowed_directories_for_previews(self):
|
def allowed_directories_for_previews(self):
|
||||||
return [v for v in [shared.cmd_opts.ckpt_dir, sd_models.model_path] if v is not None]
|
return [v for v in [shared.cmd_opts.ckpt_dir, sd_models.model_path] if v is not None]
|
||||||
|
|||||||
@@ -13,7 +13,10 @@ class ExtraNetworksPageHypernetworks(ui_extra_networks.ExtraNetworksPage):
|
|||||||
shared.reload_hypernetworks()
|
shared.reload_hypernetworks()
|
||||||
|
|
||||||
def create_item(self, name, index=None, enable_filter=True):
|
def create_item(self, name, index=None, enable_filter=True):
|
||||||
full_path = shared.hypernetworks[name]
|
full_path = shared.hypernetworks.get(name)
|
||||||
|
if full_path is None:
|
||||||
|
return
|
||||||
|
|
||||||
path, ext = os.path.splitext(full_path)
|
path, ext = os.path.splitext(full_path)
|
||||||
sha256 = sha256_from_cache(full_path, f'hypernet/{name}')
|
sha256 = sha256_from_cache(full_path, f'hypernet/{name}')
|
||||||
shorthash = sha256[0:10] if sha256 else None
|
shorthash = sha256[0:10] if sha256 else None
|
||||||
@@ -31,8 +34,12 @@ class ExtraNetworksPageHypernetworks(ui_extra_networks.ExtraNetworksPage):
|
|||||||
}
|
}
|
||||||
|
|
||||||
def list_items(self):
|
def list_items(self):
|
||||||
for index, name in enumerate(shared.hypernetworks):
|
# instantiate a list to protect against concurrent modification
|
||||||
yield self.create_item(name, index)
|
names = list(shared.hypernetworks)
|
||||||
|
for index, name in enumerate(names):
|
||||||
|
item = self.create_item(name, index)
|
||||||
|
if item is not None:
|
||||||
|
yield item
|
||||||
|
|
||||||
def allowed_directories_for_previews(self):
|
def allowed_directories_for_previews(self):
|
||||||
return [shared.cmd_opts.hypernetwork_dir]
|
return [shared.cmd_opts.hypernetwork_dir]
|
||||||
|
|||||||
@@ -14,6 +14,8 @@ class ExtraNetworksPageTextualInversion(ui_extra_networks.ExtraNetworksPage):
|
|||||||
|
|
||||||
def create_item(self, name, index=None, enable_filter=True):
|
def create_item(self, name, index=None, enable_filter=True):
|
||||||
embedding = sd_hijack.model_hijack.embedding_db.word_embeddings.get(name)
|
embedding = sd_hijack.model_hijack.embedding_db.word_embeddings.get(name)
|
||||||
|
if embedding is None:
|
||||||
|
return
|
||||||
|
|
||||||
path, ext = os.path.splitext(embedding.filename)
|
path, ext = os.path.splitext(embedding.filename)
|
||||||
return {
|
return {
|
||||||
@@ -29,8 +31,12 @@ class ExtraNetworksPageTextualInversion(ui_extra_networks.ExtraNetworksPage):
|
|||||||
}
|
}
|
||||||
|
|
||||||
def list_items(self):
|
def list_items(self):
|
||||||
for index, name in enumerate(sd_hijack.model_hijack.embedding_db.word_embeddings):
|
# instantiate a list to protect against concurrent modification
|
||||||
yield self.create_item(name, index)
|
names = list(sd_hijack.model_hijack.embedding_db.word_embeddings)
|
||||||
|
for index, name in enumerate(names):
|
||||||
|
item = self.create_item(name, index)
|
||||||
|
if item is not None:
|
||||||
|
yield item
|
||||||
|
|
||||||
def allowed_directories_for_previews(self):
|
def allowed_directories_for_previews(self):
|
||||||
return list(sd_hijack.model_hijack.embedding_db.embedding_dirs)
|
return list(sd_hijack.model_hijack.embedding_db.embedding_dirs)
|
||||||
|
|||||||
@@ -134,7 +134,7 @@ class UserMetadataEditor:
|
|||||||
basename, ext = os.path.splitext(filename)
|
basename, ext = os.path.splitext(filename)
|
||||||
|
|
||||||
with open(basename + '.json', "w", encoding="utf8") as file:
|
with open(basename + '.json', "w", encoding="utf8") as file:
|
||||||
json.dump(metadata, file, indent=4)
|
json.dump(metadata, file, indent=4, ensure_ascii=False)
|
||||||
|
|
||||||
def save_user_metadata(self, name, desc, notes):
|
def save_user_metadata(self, name, desc, notes):
|
||||||
user_metadata = self.get_user_metadata(name)
|
user_metadata = self.get_user_metadata(name)
|
||||||
|
|||||||
@@ -141,7 +141,7 @@ class UiLoadsave:
|
|||||||
|
|
||||||
def write_to_file(self, current_ui_settings):
|
def write_to_file(self, current_ui_settings):
|
||||||
with open(self.filename, "w", encoding="utf8") as file:
|
with open(self.filename, "w", encoding="utf8") as file:
|
||||||
json.dump(current_ui_settings, file, indent=4)
|
json.dump(current_ui_settings, file, indent=4, ensure_ascii=False)
|
||||||
|
|
||||||
def dump_defaults(self):
|
def dump_defaults(self):
|
||||||
"""saves default values to a file unless tjhe file is present and there was an error loading default values at start"""
|
"""saves default values to a file unless tjhe file is present and there was an error loading default values at start"""
|
||||||
|
|||||||
@@ -68,10 +68,10 @@ class UiPromptStyles:
|
|||||||
self.copy = ui_components.ToolButton(value=styles_copy_symbol, elem_id=f"{tabname}_style_copy", tooltip="Copy main UI prompt to style.")
|
self.copy = ui_components.ToolButton(value=styles_copy_symbol, elem_id=f"{tabname}_style_copy", tooltip="Copy main UI prompt to style.")
|
||||||
|
|
||||||
with gr.Row():
|
with gr.Row():
|
||||||
self.prompt = gr.Textbox(label="Prompt", show_label=True, elem_id=f"{tabname}_edit_style_prompt", lines=3)
|
self.prompt = gr.Textbox(label="Prompt", show_label=True, elem_id=f"{tabname}_edit_style_prompt", lines=3, elem_classes=["prompt"])
|
||||||
|
|
||||||
with gr.Row():
|
with gr.Row():
|
||||||
self.neg_prompt = gr.Textbox(label="Negative prompt", show_label=True, elem_id=f"{tabname}_edit_style_neg_prompt", lines=3)
|
self.neg_prompt = gr.Textbox(label="Negative prompt", show_label=True, elem_id=f"{tabname}_edit_style_neg_prompt", lines=3, elem_classes=["prompt"])
|
||||||
|
|
||||||
with gr.Row():
|
with gr.Row():
|
||||||
self.save = gr.Button('Save', variant='primary', elem_id=f'{tabname}_edit_style_save', visible=False)
|
self.save = gr.Button('Save', variant='primary', elem_id=f'{tabname}_edit_style_save', visible=False)
|
||||||
|
|||||||
+17
-7
@@ -1,6 +1,6 @@
|
|||||||
import gradio as gr
|
import gradio as gr
|
||||||
|
|
||||||
from modules import ui_common, shared, script_callbacks, scripts, sd_models, sysinfo
|
from modules import ui_common, shared, script_callbacks, scripts, sd_models, sysinfo, timer
|
||||||
from modules.call_queue import wrap_gradio_call
|
from modules.call_queue import wrap_gradio_call
|
||||||
from modules.shared import opts
|
from modules.shared import opts
|
||||||
from modules.ui_components import FormRow
|
from modules.ui_components import FormRow
|
||||||
@@ -177,8 +177,8 @@ class UiSettings:
|
|||||||
download_localization = gr.Button(value='Download localization template', elem_id="download_localization")
|
download_localization = gr.Button(value='Download localization template', elem_id="download_localization")
|
||||||
reload_script_bodies = gr.Button(value='Reload custom script bodies (No ui updates, No restart)', variant='secondary', elem_id="settings_reload_script_bodies")
|
reload_script_bodies = gr.Button(value='Reload custom script bodies (No ui updates, No restart)', variant='secondary', elem_id="settings_reload_script_bodies")
|
||||||
with gr.Row():
|
with gr.Row():
|
||||||
unload_sd_model = gr.Button(value='Unload SD checkpoint to free VRAM', elem_id="sett_unload_sd_model")
|
unload_sd_model = gr.Button(value='Unload SD checkpoint to RAM', elem_id="sett_unload_sd_model")
|
||||||
reload_sd_model = gr.Button(value='Reload the last SD checkpoint back into VRAM', elem_id="sett_reload_sd_model")
|
reload_sd_model = gr.Button(value='Load SD checkpoint to VRAM from RAM', elem_id="sett_reload_sd_model")
|
||||||
with gr.Row():
|
with gr.Row():
|
||||||
calculate_all_checkpoint_hash = gr.Button(value='Calculate hash for all checkpoint', elem_id="calculate_all_checkpoint_hash")
|
calculate_all_checkpoint_hash = gr.Button(value='Calculate hash for all checkpoint', elem_id="calculate_all_checkpoint_hash")
|
||||||
calculate_all_checkpoint_hash_threads = gr.Number(value=1, label="Number of parallel calculations", elem_id="calculate_all_checkpoint_hash_threads", precision=0, minimum=1)
|
calculate_all_checkpoint_hash_threads = gr.Number(value=1, label="Number of parallel calculations", elem_id="calculate_all_checkpoint_hash_threads", precision=0, minimum=1)
|
||||||
@@ -194,16 +194,26 @@ class UiSettings:
|
|||||||
|
|
||||||
self.text_settings = gr.Textbox(elem_id="settings_json", value=lambda: opts.dumpjson(), visible=False)
|
self.text_settings = gr.Textbox(elem_id="settings_json", value=lambda: opts.dumpjson(), visible=False)
|
||||||
|
|
||||||
|
def call_func_and_return_text(func, text):
|
||||||
|
def handler():
|
||||||
|
t = timer.Timer()
|
||||||
|
func()
|
||||||
|
t.record(text)
|
||||||
|
|
||||||
|
return f'{text} in {t.total:.1f}s'
|
||||||
|
|
||||||
|
return handler
|
||||||
|
|
||||||
unload_sd_model.click(
|
unload_sd_model.click(
|
||||||
fn=sd_models.unload_model_weights,
|
fn=call_func_and_return_text(sd_models.unload_model_weights, 'Unloaded the checkpoint'),
|
||||||
inputs=[],
|
inputs=[],
|
||||||
outputs=[]
|
outputs=[self.result]
|
||||||
)
|
)
|
||||||
|
|
||||||
reload_sd_model.click(
|
reload_sd_model.click(
|
||||||
fn=sd_models.reload_model_weights,
|
fn=call_func_and_return_text(lambda: sd_models.send_model_to_device(shared.sd_model), 'Loaded the checkpoint'),
|
||||||
inputs=[],
|
inputs=[],
|
||||||
outputs=[]
|
outputs=[self.result]
|
||||||
)
|
)
|
||||||
|
|
||||||
request_notifications.click(
|
request_notifications.click(
|
||||||
|
|||||||
@@ -61,7 +61,7 @@ def save_pil_to_file(self, pil_image, dir=None, format="png"):
|
|||||||
|
|
||||||
def install_ui_tempdir_override():
|
def install_ui_tempdir_override():
|
||||||
"""override save to file function so that it also writes PNG info"""
|
"""override save to file function so that it also writes PNG info"""
|
||||||
gradio.components.IOComponent.pil_to_temp_file = save_pil_to_file
|
# gradio.components.IOComponent.pil_to_temp_file = save_pil_to_file
|
||||||
|
|
||||||
|
|
||||||
def on_tmpdir_changed():
|
def on_tmpdir_changed():
|
||||||
|
|||||||
@@ -0,0 +1,141 @@
|
|||||||
|
import gradio as gr
|
||||||
|
|
||||||
|
from modules import shared, ui_prompt_styles
|
||||||
|
import modules.images
|
||||||
|
|
||||||
|
from modules.ui_components import ToolButton
|
||||||
|
|
||||||
|
|
||||||
|
class Toprow:
|
||||||
|
"""Creates a top row UI with prompts, generate button, styles, extra little buttons for things, and enables some functionality related to their operation"""
|
||||||
|
|
||||||
|
prompt = None
|
||||||
|
prompt_img = None
|
||||||
|
negative_prompt = None
|
||||||
|
|
||||||
|
button_interrogate = None
|
||||||
|
button_deepbooru = None
|
||||||
|
|
||||||
|
interrupt = None
|
||||||
|
skip = None
|
||||||
|
submit = None
|
||||||
|
|
||||||
|
paste = None
|
||||||
|
clear_prompt_button = None
|
||||||
|
apply_styles = None
|
||||||
|
restore_progress_button = None
|
||||||
|
|
||||||
|
token_counter = None
|
||||||
|
token_button = None
|
||||||
|
negative_token_counter = None
|
||||||
|
negative_token_button = None
|
||||||
|
|
||||||
|
ui_styles = None
|
||||||
|
|
||||||
|
submit_box = None
|
||||||
|
|
||||||
|
def __init__(self, is_img2img, is_compact=False):
|
||||||
|
id_part = "img2img" if is_img2img else "txt2img"
|
||||||
|
self.id_part = id_part
|
||||||
|
self.is_img2img = is_img2img
|
||||||
|
self.is_compact = is_compact
|
||||||
|
|
||||||
|
if not is_compact:
|
||||||
|
with gr.Row(elem_id=f"{id_part}_toprow", variant="compact"):
|
||||||
|
self.create_classic_toprow()
|
||||||
|
else:
|
||||||
|
self.create_submit_box()
|
||||||
|
|
||||||
|
def create_classic_toprow(self):
|
||||||
|
self.create_prompts()
|
||||||
|
|
||||||
|
with gr.Column(scale=1, elem_id=f"{self.id_part}_actions_column"):
|
||||||
|
self.create_submit_box()
|
||||||
|
|
||||||
|
self.create_tools_row()
|
||||||
|
|
||||||
|
self.create_styles_ui()
|
||||||
|
|
||||||
|
def create_inline_toprow_prompts(self):
|
||||||
|
if not self.is_compact:
|
||||||
|
return
|
||||||
|
|
||||||
|
self.create_prompts()
|
||||||
|
|
||||||
|
with gr.Row(elem_classes=["toprow-compact-stylerow"]):
|
||||||
|
with gr.Column(elem_classes=["toprow-compact-tools"]):
|
||||||
|
self.create_tools_row()
|
||||||
|
with gr.Column():
|
||||||
|
self.create_styles_ui()
|
||||||
|
|
||||||
|
def create_inline_toprow_image(self):
|
||||||
|
if not self.is_compact:
|
||||||
|
return
|
||||||
|
|
||||||
|
self.submit_box.render()
|
||||||
|
|
||||||
|
def create_prompts(self):
|
||||||
|
with gr.Column(elem_id=f"{self.id_part}_prompt_container", elem_classes=["prompt-container-compact"] if self.is_compact else [], scale=6):
|
||||||
|
with gr.Row(elem_id=f"{self.id_part}_prompt_row", elem_classes=["prompt-row"]):
|
||||||
|
self.prompt = gr.Textbox(label="Prompt", elem_id=f"{self.id_part}_prompt", show_label=False, lines=3, placeholder="Prompt (press Ctrl+Enter or Alt+Enter to generate)", elem_classes=["prompt"])
|
||||||
|
self.prompt_img = gr.File(label="", elem_id=f"{self.id_part}_prompt_image", file_count="single", type="binary", visible=False)
|
||||||
|
|
||||||
|
with gr.Row(elem_id=f"{self.id_part}_neg_prompt_row", elem_classes=["prompt-row"]):
|
||||||
|
self.negative_prompt = gr.Textbox(label="Negative prompt", elem_id=f"{self.id_part}_neg_prompt", show_label=False, lines=3, placeholder="Negative prompt (press Ctrl+Enter or Alt+Enter to generate)", elem_classes=["prompt"])
|
||||||
|
|
||||||
|
self.prompt_img.change(
|
||||||
|
fn=modules.images.image_data,
|
||||||
|
inputs=[self.prompt_img],
|
||||||
|
outputs=[self.prompt, self.prompt_img],
|
||||||
|
show_progress=False,
|
||||||
|
)
|
||||||
|
|
||||||
|
def create_submit_box(self):
|
||||||
|
with gr.Row(elem_id=f"{self.id_part}_generate_box", elem_classes=["generate-box"] + (["generate-box-compact"] if self.is_compact else []), render=not self.is_compact) as submit_box:
|
||||||
|
self.submit_box = submit_box
|
||||||
|
|
||||||
|
self.interrupt = gr.Button('Interrupt', elem_id=f"{self.id_part}_interrupt", elem_classes="generate-box-interrupt")
|
||||||
|
self.skip = gr.Button('Skip', elem_id=f"{self.id_part}_skip", elem_classes="generate-box-skip")
|
||||||
|
self.submit = gr.Button('Generate', elem_id=f"{self.id_part}_generate", variant='primary')
|
||||||
|
|
||||||
|
self.skip.click(
|
||||||
|
fn=lambda: shared.state.skip(),
|
||||||
|
inputs=[],
|
||||||
|
outputs=[],
|
||||||
|
)
|
||||||
|
|
||||||
|
self.interrupt.click(
|
||||||
|
fn=lambda: shared.state.interrupt(),
|
||||||
|
inputs=[],
|
||||||
|
outputs=[],
|
||||||
|
)
|
||||||
|
|
||||||
|
def create_tools_row(self):
|
||||||
|
with gr.Row(elem_id=f"{self.id_part}_tools"):
|
||||||
|
from modules.ui import paste_symbol, clear_prompt_symbol, restore_progress_symbol
|
||||||
|
|
||||||
|
self.paste = ToolButton(value=paste_symbol, elem_id="paste", tooltip="Read generation parameters from prompt or last generation if prompt is empty into user interface.")
|
||||||
|
self.clear_prompt_button = ToolButton(value=clear_prompt_symbol, elem_id=f"{self.id_part}_clear_prompt", tooltip="Clear prompt")
|
||||||
|
self.apply_styles = ToolButton(value=ui_prompt_styles.styles_materialize_symbol, elem_id=f"{self.id_part}_style_apply", tooltip="Apply all selected styles to prompts.")
|
||||||
|
|
||||||
|
if self.is_img2img:
|
||||||
|
self.button_interrogate = ToolButton('📎', tooltip='Interrogate CLIP - use CLIP neural network to create a text describing the image, and put it into the prompt field', elem_id="interrogate")
|
||||||
|
self.button_deepbooru = ToolButton('📦', tooltip='Interrogate DeepBooru - use DeepBooru neural network to create a text describing the image, and put it into the prompt field', elem_id="deepbooru")
|
||||||
|
|
||||||
|
self.restore_progress_button = ToolButton(value=restore_progress_symbol, elem_id=f"{self.id_part}_restore_progress", visible=False, tooltip="Restore progress")
|
||||||
|
|
||||||
|
self.token_counter = gr.HTML(value="<span>0/75</span>", elem_id=f"{self.id_part}_token_counter", elem_classes=["token-counter"])
|
||||||
|
self.token_button = gr.Button(visible=False, elem_id=f"{self.id_part}_token_button")
|
||||||
|
self.negative_token_counter = gr.HTML(value="<span>0/75</span>", elem_id=f"{self.id_part}_negative_token_counter", elem_classes=["token-counter"])
|
||||||
|
self.negative_token_button = gr.Button(visible=False, elem_id=f"{self.id_part}_negative_token_button")
|
||||||
|
|
||||||
|
self.clear_prompt_button.click(
|
||||||
|
fn=lambda *x: x,
|
||||||
|
_js="confirm_clear_prompt",
|
||||||
|
inputs=[self.prompt, self.negative_prompt],
|
||||||
|
outputs=[self.prompt, self.negative_prompt],
|
||||||
|
)
|
||||||
|
|
||||||
|
def create_styles_ui(self):
|
||||||
|
self.ui_styles = ui_prompt_styles.UiPromptStyles(self.id_part, self.prompt, self.negative_prompt)
|
||||||
|
self.ui_styles.setup_apply_button(self.apply_styles)
|
||||||
@@ -0,0 +1,164 @@
|
|||||||
|
from transformers import BertPreTrainedModel,BertConfig
|
||||||
|
import torch.nn as nn
|
||||||
|
import torch
|
||||||
|
from transformers.models.xlm_roberta.configuration_xlm_roberta import XLMRobertaConfig
|
||||||
|
from transformers import XLMRobertaModel,XLMRobertaTokenizer
|
||||||
|
from typing import Optional
|
||||||
|
|
||||||
|
class BertSeriesConfig(BertConfig):
|
||||||
|
def __init__(self, vocab_size=30522, hidden_size=768, num_hidden_layers=12, num_attention_heads=12, intermediate_size=3072, hidden_act="gelu", hidden_dropout_prob=0.1, attention_probs_dropout_prob=0.1, max_position_embeddings=512, type_vocab_size=2, initializer_range=0.02, layer_norm_eps=1e-12, pad_token_id=0, position_embedding_type="absolute", use_cache=True, classifier_dropout=None,project_dim=512, pooler_fn="average",learn_encoder=False,model_type='bert',**kwargs):
|
||||||
|
|
||||||
|
super().__init__(vocab_size, hidden_size, num_hidden_layers, num_attention_heads, intermediate_size, hidden_act, hidden_dropout_prob, attention_probs_dropout_prob, max_position_embeddings, type_vocab_size, initializer_range, layer_norm_eps, pad_token_id, position_embedding_type, use_cache, classifier_dropout, **kwargs)
|
||||||
|
self.project_dim = project_dim
|
||||||
|
self.pooler_fn = pooler_fn
|
||||||
|
self.learn_encoder = learn_encoder
|
||||||
|
|
||||||
|
class RobertaSeriesConfig(XLMRobertaConfig):
|
||||||
|
def __init__(self, pad_token_id=1, bos_token_id=0, eos_token_id=2,project_dim=512,pooler_fn='cls',learn_encoder=False, **kwargs):
|
||||||
|
super().__init__(pad_token_id=pad_token_id, bos_token_id=bos_token_id, eos_token_id=eos_token_id, **kwargs)
|
||||||
|
self.project_dim = project_dim
|
||||||
|
self.pooler_fn = pooler_fn
|
||||||
|
self.learn_encoder = learn_encoder
|
||||||
|
|
||||||
|
|
||||||
|
class BertSeriesModelWithTransformation(BertPreTrainedModel):
|
||||||
|
|
||||||
|
_keys_to_ignore_on_load_unexpected = [r"pooler"]
|
||||||
|
_keys_to_ignore_on_load_missing = [r"position_ids", r"predictions.decoder.bias"]
|
||||||
|
config_class = BertSeriesConfig
|
||||||
|
|
||||||
|
def __init__(self, config=None, **kargs):
|
||||||
|
# modify initialization for autoloading
|
||||||
|
if config is None:
|
||||||
|
config = XLMRobertaConfig()
|
||||||
|
config.attention_probs_dropout_prob= 0.1
|
||||||
|
config.bos_token_id=0
|
||||||
|
config.eos_token_id=2
|
||||||
|
config.hidden_act='gelu'
|
||||||
|
config.hidden_dropout_prob=0.1
|
||||||
|
config.hidden_size=1024
|
||||||
|
config.initializer_range=0.02
|
||||||
|
config.intermediate_size=4096
|
||||||
|
config.layer_norm_eps=1e-05
|
||||||
|
config.max_position_embeddings=514
|
||||||
|
|
||||||
|
config.num_attention_heads=16
|
||||||
|
config.num_hidden_layers=24
|
||||||
|
config.output_past=True
|
||||||
|
config.pad_token_id=1
|
||||||
|
config.position_embedding_type= "absolute"
|
||||||
|
|
||||||
|
config.type_vocab_size= 1
|
||||||
|
config.use_cache=True
|
||||||
|
config.vocab_size= 250002
|
||||||
|
config.project_dim = 1024
|
||||||
|
config.learn_encoder = False
|
||||||
|
super().__init__(config)
|
||||||
|
self.roberta = XLMRobertaModel(config)
|
||||||
|
self.transformation = nn.Linear(config.hidden_size,config.project_dim)
|
||||||
|
# self.pre_LN=nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
|
||||||
|
self.tokenizer = XLMRobertaTokenizer.from_pretrained('xlm-roberta-large')
|
||||||
|
# self.pooler = lambda x: x[:,0]
|
||||||
|
# self.post_init()
|
||||||
|
|
||||||
|
self.has_pre_transformation = True
|
||||||
|
if self.has_pre_transformation:
|
||||||
|
self.transformation_pre = nn.Linear(config.hidden_size, config.project_dim)
|
||||||
|
self.pre_LN = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
|
||||||
|
self.post_init()
|
||||||
|
|
||||||
|
def encode(self,c):
|
||||||
|
device = next(self.parameters()).device
|
||||||
|
text = self.tokenizer(c,
|
||||||
|
truncation=True,
|
||||||
|
max_length=77,
|
||||||
|
return_length=False,
|
||||||
|
return_overflowing_tokens=False,
|
||||||
|
padding="max_length",
|
||||||
|
return_tensors="pt")
|
||||||
|
text["input_ids"] = torch.tensor(text["input_ids"]).to(device)
|
||||||
|
text["attention_mask"] = torch.tensor(
|
||||||
|
text['attention_mask']).to(device)
|
||||||
|
features = self(**text)
|
||||||
|
return features['projection_state']
|
||||||
|
|
||||||
|
def forward(
|
||||||
|
self,
|
||||||
|
input_ids: Optional[torch.Tensor] = None,
|
||||||
|
attention_mask: Optional[torch.Tensor] = None,
|
||||||
|
token_type_ids: Optional[torch.Tensor] = None,
|
||||||
|
position_ids: Optional[torch.Tensor] = None,
|
||||||
|
head_mask: Optional[torch.Tensor] = None,
|
||||||
|
inputs_embeds: Optional[torch.Tensor] = None,
|
||||||
|
encoder_hidden_states: Optional[torch.Tensor] = None,
|
||||||
|
encoder_attention_mask: Optional[torch.Tensor] = None,
|
||||||
|
output_attentions: Optional[bool] = None,
|
||||||
|
return_dict: Optional[bool] = None,
|
||||||
|
output_hidden_states: Optional[bool] = None,
|
||||||
|
) :
|
||||||
|
r"""
|
||||||
|
"""
|
||||||
|
|
||||||
|
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
|
||||||
|
|
||||||
|
|
||||||
|
outputs = self.roberta(
|
||||||
|
input_ids=input_ids,
|
||||||
|
attention_mask=attention_mask,
|
||||||
|
token_type_ids=token_type_ids,
|
||||||
|
position_ids=position_ids,
|
||||||
|
head_mask=head_mask,
|
||||||
|
inputs_embeds=inputs_embeds,
|
||||||
|
encoder_hidden_states=encoder_hidden_states,
|
||||||
|
encoder_attention_mask=encoder_attention_mask,
|
||||||
|
output_attentions=output_attentions,
|
||||||
|
output_hidden_states=True,
|
||||||
|
return_dict=return_dict,
|
||||||
|
)
|
||||||
|
|
||||||
|
# # last module outputs
|
||||||
|
# sequence_output = outputs[0]
|
||||||
|
|
||||||
|
|
||||||
|
# # project every module
|
||||||
|
# sequence_output_ln = self.pre_LN(sequence_output)
|
||||||
|
|
||||||
|
# # pooler
|
||||||
|
# pooler_output = self.pooler(sequence_output_ln)
|
||||||
|
# pooler_output = self.transformation(pooler_output)
|
||||||
|
# projection_state = self.transformation(outputs.last_hidden_state)
|
||||||
|
|
||||||
|
if self.has_pre_transformation:
|
||||||
|
sequence_output2 = outputs["hidden_states"][-2]
|
||||||
|
sequence_output2 = self.pre_LN(sequence_output2)
|
||||||
|
projection_state2 = self.transformation_pre(sequence_output2)
|
||||||
|
|
||||||
|
return {
|
||||||
|
"projection_state": projection_state2,
|
||||||
|
"last_hidden_state": outputs.last_hidden_state,
|
||||||
|
"hidden_states": outputs.hidden_states,
|
||||||
|
"attentions": outputs.attentions,
|
||||||
|
}
|
||||||
|
else:
|
||||||
|
projection_state = self.transformation(outputs.last_hidden_state)
|
||||||
|
return {
|
||||||
|
"projection_state": projection_state,
|
||||||
|
"last_hidden_state": outputs.last_hidden_state,
|
||||||
|
"hidden_states": outputs.hidden_states,
|
||||||
|
"attentions": outputs.attentions,
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
# return {
|
||||||
|
# 'pooler_output':pooler_output,
|
||||||
|
# 'last_hidden_state':outputs.last_hidden_state,
|
||||||
|
# 'hidden_states':outputs.hidden_states,
|
||||||
|
# 'attentions':outputs.attentions,
|
||||||
|
# 'projection_state':projection_state,
|
||||||
|
# 'sequence_out': sequence_output
|
||||||
|
# }
|
||||||
|
|
||||||
|
|
||||||
|
class RobertaSeriesModelWithTransformation(BertSeriesModelWithTransformation):
|
||||||
|
base_model_prefix = 'roberta'
|
||||||
|
config_class= RobertaSeriesConfig
|
||||||
@@ -16,6 +16,7 @@ exclude = [
|
|||||||
|
|
||||||
ignore = [
|
ignore = [
|
||||||
"E501", # Line too long
|
"E501", # Line too long
|
||||||
|
"E721", # Do not compare types, use `isinstance`
|
||||||
"E731", # Do not assign a `lambda` expression, use a `def`
|
"E731", # Do not assign a `lambda` expression, use a `def`
|
||||||
|
|
||||||
"I001", # Import block is un-sorted or un-formatted
|
"I001", # Import block is un-sorted or un-formatted
|
||||||
|
|||||||
+1
-1
@@ -8,7 +8,7 @@ clean-fid
|
|||||||
einops
|
einops
|
||||||
fastapi>=0.90.1
|
fastapi>=0.90.1
|
||||||
gfpgan
|
gfpgan
|
||||||
gradio==3.41.2
|
gradio==4.7.1
|
||||||
inflection
|
inflection
|
||||||
jsonmerge
|
jsonmerge
|
||||||
kornia
|
kornia
|
||||||
|
|||||||
@@ -5,9 +5,9 @@ basicsr==1.4.2
|
|||||||
blendmodes==2022
|
blendmodes==2022
|
||||||
clean-fid==0.1.35
|
clean-fid==0.1.35
|
||||||
einops==0.4.1
|
einops==0.4.1
|
||||||
fastapi==0.94.0
|
fastapi==0.104.1
|
||||||
gfpgan==1.3.8
|
gfpgan==1.3.8
|
||||||
gradio==3.41.2
|
gradio==4.7.1
|
||||||
httpcore==0.15
|
httpcore==0.15
|
||||||
inflection==0.5.1
|
inflection==0.5.1
|
||||||
jsonmerge==1.8.0
|
jsonmerge==1.8.0
|
||||||
@@ -29,3 +29,4 @@ torch
|
|||||||
torchdiffeq==0.2.3
|
torchdiffeq==0.2.3
|
||||||
torchsde==0.2.6
|
torchsde==0.2.6
|
||||||
transformers==4.30.2
|
transformers==4.30.2
|
||||||
|
httpx==0.24.1
|
||||||
|
|||||||
@@ -124,16 +124,29 @@ document.addEventListener("DOMContentLoaded", function() {
|
|||||||
* Add a ctrl+enter as a shortcut to start a generation
|
* Add a ctrl+enter as a shortcut to start a generation
|
||||||
*/
|
*/
|
||||||
document.addEventListener('keydown', function(e) {
|
document.addEventListener('keydown', function(e) {
|
||||||
var handled = false;
|
const isEnter = e.key === 'Enter' || e.keyCode === 13;
|
||||||
if (e.key !== undefined) {
|
const isModifierKey = e.metaKey || e.ctrlKey || e.altKey;
|
||||||
if ((e.key == "Enter" && (e.metaKey || e.ctrlKey || e.altKey))) handled = true;
|
|
||||||
} else if (e.keyCode !== undefined) {
|
const interruptButton = get_uiCurrentTabContent().querySelector('button[id$=_interrupt]');
|
||||||
if ((e.keyCode == 13 && (e.metaKey || e.ctrlKey || e.altKey))) handled = true;
|
const generateButton = get_uiCurrentTabContent().querySelector('button[id$=_generate]');
|
||||||
|
|
||||||
|
if (isEnter && isModifierKey) {
|
||||||
|
if (interruptButton.style.display === 'block') {
|
||||||
|
interruptButton.click();
|
||||||
|
const callback = (mutationList) => {
|
||||||
|
for (const mutation of mutationList) {
|
||||||
|
if (mutation.type === 'attributes' && mutation.attributeName === 'style') {
|
||||||
|
if (interruptButton.style.display === 'none') {
|
||||||
|
generateButton.click();
|
||||||
|
observer.disconnect();
|
||||||
}
|
}
|
||||||
if (handled) {
|
}
|
||||||
var button = get_uiCurrentTabContent().querySelector('button[id$=_generate]');
|
}
|
||||||
if (button) {
|
};
|
||||||
button.click();
|
const observer = new MutationObserver(callback);
|
||||||
|
observer.observe(interruptButton, {attributes: true});
|
||||||
|
} else {
|
||||||
|
generateButton.click();
|
||||||
}
|
}
|
||||||
e.preventDefault();
|
e.preventDefault();
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -114,6 +114,7 @@ class Script(scripts.Script):
|
|||||||
def ui(self, is_img2img):
|
def ui(self, is_img2img):
|
||||||
checkbox_iterate = gr.Checkbox(label="Iterate seed every line", value=False, elem_id=self.elem_id("checkbox_iterate"))
|
checkbox_iterate = gr.Checkbox(label="Iterate seed every line", value=False, elem_id=self.elem_id("checkbox_iterate"))
|
||||||
checkbox_iterate_batch = gr.Checkbox(label="Use same random seed for all lines", value=False, elem_id=self.elem_id("checkbox_iterate_batch"))
|
checkbox_iterate_batch = gr.Checkbox(label="Use same random seed for all lines", value=False, elem_id=self.elem_id("checkbox_iterate_batch"))
|
||||||
|
prompt_position = gr.Radio(["start", "end"], label="Insert prompts at the", elem_id=self.elem_id("prompt_position"), value="start")
|
||||||
|
|
||||||
prompt_txt = gr.Textbox(label="List of prompt inputs", lines=1, elem_id=self.elem_id("prompt_txt"))
|
prompt_txt = gr.Textbox(label="List of prompt inputs", lines=1, elem_id=self.elem_id("prompt_txt"))
|
||||||
file = gr.File(label="Upload prompt inputs", type='binary', elem_id=self.elem_id("file"))
|
file = gr.File(label="Upload prompt inputs", type='binary', elem_id=self.elem_id("file"))
|
||||||
@@ -124,9 +125,9 @@ class Script(scripts.Script):
|
|||||||
# We don't shrink back to 1, because that causes the control to ignore [enter], and it may
|
# We don't shrink back to 1, because that causes the control to ignore [enter], and it may
|
||||||
# be unclear to the user that shift-enter is needed.
|
# be unclear to the user that shift-enter is needed.
|
||||||
prompt_txt.change(lambda tb: gr.update(lines=7) if ("\n" in tb) else gr.update(lines=2), inputs=[prompt_txt], outputs=[prompt_txt], show_progress=False)
|
prompt_txt.change(lambda tb: gr.update(lines=7) if ("\n" in tb) else gr.update(lines=2), inputs=[prompt_txt], outputs=[prompt_txt], show_progress=False)
|
||||||
return [checkbox_iterate, checkbox_iterate_batch, prompt_txt]
|
return [checkbox_iterate, checkbox_iterate_batch, prompt_position, prompt_txt]
|
||||||
|
|
||||||
def run(self, p, checkbox_iterate, checkbox_iterate_batch, prompt_txt: str):
|
def run(self, p, checkbox_iterate, checkbox_iterate_batch, prompt_position, prompt_txt: str):
|
||||||
lines = [x for x in (x.strip() for x in prompt_txt.splitlines()) if x]
|
lines = [x for x in (x.strip() for x in prompt_txt.splitlines()) if x]
|
||||||
|
|
||||||
p.do_not_save_grid = True
|
p.do_not_save_grid = True
|
||||||
@@ -167,6 +168,18 @@ class Script(scripts.Script):
|
|||||||
else:
|
else:
|
||||||
setattr(copy_p, k, v)
|
setattr(copy_p, k, v)
|
||||||
|
|
||||||
|
if args.get("prompt") and p.prompt:
|
||||||
|
if prompt_position == "start":
|
||||||
|
copy_p.prompt = args.get("prompt") + " " + p.prompt
|
||||||
|
else:
|
||||||
|
copy_p.prompt = p.prompt + " " + args.get("prompt")
|
||||||
|
|
||||||
|
if args.get("negative_prompt") and p.negative_prompt:
|
||||||
|
if prompt_position == "start":
|
||||||
|
copy_p.negative_prompt = args.get("negative_prompt") + " " + p.negative_prompt
|
||||||
|
else:
|
||||||
|
copy_p.negative_prompt = p.negative_prompt + " " + args.get("negative_prompt")
|
||||||
|
|
||||||
proc = process_images(copy_p)
|
proc = process_images(copy_p)
|
||||||
images += proc.images
|
images += proc.images
|
||||||
|
|
||||||
|
|||||||
@@ -204,6 +204,11 @@ div.block.gradio-accordion {
|
|||||||
padding: 8px 8px;
|
padding: 8px 8px;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
input[type="checkbox"].input-accordion-checkbox{
|
||||||
|
vertical-align: sub;
|
||||||
|
margin-right: 0.5em;
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
/* txt2img/img2img specific */
|
/* txt2img/img2img specific */
|
||||||
|
|
||||||
@@ -291,6 +296,13 @@ div.block.gradio-accordion {
|
|||||||
min-height: 4.5em;
|
min-height: 4.5em;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#txt2img_generate, #img2img_generate {
|
||||||
|
min-height: 4.5em;
|
||||||
|
}
|
||||||
|
.generate-box-compact #txt2img_generate, .generate-box-compact #img2img_generate {
|
||||||
|
min-height: 3em;
|
||||||
|
}
|
||||||
|
|
||||||
@media screen and (min-width: 2500px) {
|
@media screen and (min-width: 2500px) {
|
||||||
#txt2img_gallery, #img2img_gallery {
|
#txt2img_gallery, #img2img_gallery {
|
||||||
min-height: 768px;
|
min-height: 768px;
|
||||||
@@ -398,6 +410,15 @@ div#extras_scale_to_tab div.form{
|
|||||||
min-width: 0.5em;
|
min-width: 0.5em;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
div.toprow-compact-stylerow{
|
||||||
|
margin: 0.5em 0;
|
||||||
|
}
|
||||||
|
|
||||||
|
div.toprow-compact-tools{
|
||||||
|
min-width: fit-content !important;
|
||||||
|
max-width: fit-content;
|
||||||
|
}
|
||||||
|
|
||||||
/* settings */
|
/* settings */
|
||||||
#quicksettings {
|
#quicksettings {
|
||||||
align-items: end;
|
align-items: end;
|
||||||
@@ -441,6 +462,15 @@ div#extras_scale_to_tab div.form{
|
|||||||
padding: 4px;
|
padding: 4px;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#settings > div.tab-nav .settings-category{
|
||||||
|
display: block;
|
||||||
|
margin: 1em 0 0.25em 0;
|
||||||
|
font-weight: bold;
|
||||||
|
text-decoration: underline;
|
||||||
|
cursor: default;
|
||||||
|
user-select: none;
|
||||||
|
}
|
||||||
|
|
||||||
#settings_result{
|
#settings_result{
|
||||||
height: 1.4em;
|
height: 1.4em;
|
||||||
margin: 0 1.2em;
|
margin: 0 1.2em;
|
||||||
@@ -520,7 +550,8 @@ table.popup-table .link{
|
|||||||
height: 20px;
|
height: 20px;
|
||||||
background: #b4c0cc;
|
background: #b4c0cc;
|
||||||
border-radius: 3px !important;
|
border-radius: 3px !important;
|
||||||
top: -20px;
|
top: -14px;
|
||||||
|
left: 0px;
|
||||||
width: 100%;
|
width: 100%;
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -818,6 +849,18 @@ footer {
|
|||||||
|
|
||||||
/* extra networks UI */
|
/* extra networks UI */
|
||||||
|
|
||||||
|
.extra-page > div.gap{
|
||||||
|
gap: 0;
|
||||||
|
}
|
||||||
|
|
||||||
|
.extra-page-prompts{
|
||||||
|
margin-bottom: 0;
|
||||||
|
}
|
||||||
|
|
||||||
|
.extra-page-prompts.extra-page-prompts-active{
|
||||||
|
margin-bottom: 1em;
|
||||||
|
}
|
||||||
|
|
||||||
.extra-network-cards{
|
.extra-network-cards{
|
||||||
height: calc(100vh - 24rem);
|
height: calc(100vh - 24rem);
|
||||||
overflow: clip scroll;
|
overflow: clip scroll;
|
||||||
|
|||||||
@@ -1,5 +1,9 @@
|
|||||||
@echo off
|
@echo off
|
||||||
|
|
||||||
|
if exist webui.settings.bat (
|
||||||
|
call webui.settings.bat
|
||||||
|
)
|
||||||
|
|
||||||
if not defined PYTHON (set PYTHON=python)
|
if not defined PYTHON (set PYTHON=python)
|
||||||
if defined GIT (set "GIT_PYTHON_GIT_EXECUTABLE=%GIT%")
|
if defined GIT (set "GIT_PYTHON_GIT_EXECUTABLE=%GIT%")
|
||||||
if not defined VENV_DIR (set "VENV_DIR=%~dp0%venv")
|
if not defined VENV_DIR (set "VENV_DIR=%~dp0%venv")
|
||||||
|
|||||||
@@ -89,7 +89,7 @@ delimiter="################################################################"
|
|||||||
|
|
||||||
printf "\n%s\n" "${delimiter}"
|
printf "\n%s\n" "${delimiter}"
|
||||||
printf "\e[1m\e[32mInstall script for stable-diffusion + Web UI\n"
|
printf "\e[1m\e[32mInstall script for stable-diffusion + Web UI\n"
|
||||||
printf "\e[1m\e[34mTested on Debian 11 (Bullseye)\e[0m"
|
printf "\e[1m\e[34mTested on Debian 11 (Bullseye), Fedora 34+ and openSUSE Leap 15.4 or newer.\e[0m"
|
||||||
printf "\n%s\n" "${delimiter}"
|
printf "\n%s\n" "${delimiter}"
|
||||||
|
|
||||||
# Do not run as root
|
# Do not run as root
|
||||||
@@ -223,7 +223,7 @@ fi
|
|||||||
# Try using TCMalloc on Linux
|
# Try using TCMalloc on Linux
|
||||||
prepare_tcmalloc() {
|
prepare_tcmalloc() {
|
||||||
if [[ "${OSTYPE}" == "linux"* ]] && [[ -z "${NO_TCMALLOC}" ]] && [[ -z "${LD_PRELOAD}" ]]; then
|
if [[ "${OSTYPE}" == "linux"* ]] && [[ -z "${NO_TCMALLOC}" ]] && [[ -z "${LD_PRELOAD}" ]]; then
|
||||||
TCMALLOC="$(PATH=/usr/sbin:$PATH ldconfig -p | grep -Po "libtcmalloc(_minimal|)\.so\.\d" | head -n 1)"
|
TCMALLOC="$(PATH=/sbin:$PATH ldconfig -p | grep -Po "libtcmalloc(_minimal|)\.so\.\d" | head -n 1)"
|
||||||
if [[ ! -z "${TCMALLOC}" ]]; then
|
if [[ ! -z "${TCMALLOC}" ]]; then
|
||||||
echo "Using TCMalloc: ${TCMALLOC}"
|
echo "Using TCMalloc: ${TCMALLOC}"
|
||||||
export LD_PRELOAD="${TCMALLOC}"
|
export LD_PRELOAD="${TCMALLOC}"
|
||||||
|
|||||||
Reference in New Issue
Block a user