Compare commits
380 Commits
refiner_alt
...
v1.6.0-RC
| Author | SHA1 | Date | |
|---|---|---|---|
| 935d9d899c | |||
| 189229bbf9 | |||
| b6c0217405 | |||
| 995ff5902f | |||
| b0211ff7f8 | |||
| 0027ce1f6e | |||
| 06f18186dc | |||
| 2c570f641c | |||
| fa68d66c98 | |||
| 32e790a47e | |||
| ddf3d1a7ac | |||
| 31f2be3dce | |||
| 250c416474 | |||
| 12171ca961 | |||
| bae91855f5 | |||
| f29b4cd7cb | |||
| 0232a987bb | |||
| 6a87e35bef | |||
| 8fd1558179 | |||
| 04cfcf91d9 | |||
| 3ec5ce9416 | |||
| 016554e437 | |||
| bb7dd7b646 | |||
| 9c82b34be7 | |||
| 54fbdcf467 | |||
| 2e9289bcbf | |||
| 7fd0ccdffc | |||
| ed49c7c246 | |||
| 0d90064e9e | |||
| 9158d0fd12 | |||
| c4b11ec54e | |||
| 9e4019c5ff | |||
| 96edfb560b | |||
| f6c52f4f41 | |||
| 7d94e5f33b | |||
| e8a9d213e4 | |||
| 0998256fc5 | |||
| a459075d26 | |||
| 70283a9f4a | |||
| e1b37a066d | |||
| d7c9c61420 | |||
| 79fd17ee63 | |||
| 7a3a6e3855 | |||
| f83996cd9f | |||
| 7da73cbcca | |||
| 299b8096bc | |||
| aed52d1632 | |||
| 9dce2aa735 | |||
| 953c3eab7b | |||
| 18fb522660 | |||
| bd6f070882 | |||
| a3fdef4ed4 | |||
| dfd6ea3fca | |||
| 71a0f6ef85 | |||
| d02c4da483 | |||
| df595ae313 | |||
| b4d21e7113 | |||
| d722d6de36 | |||
| 76ae1019b9 | |||
| a7f18b2297 | |||
| d3632368e6 | |||
| 5a3fe7a8d1 | |||
| be301f224d | |||
| db6c7ff084 | |||
| 268dc9b308 | |||
| 549b0fc526 | |||
| 42b72fe246 | |||
| f65d0dc081 | |||
| af5d2e8e5f | |||
| 5159edbf0e | |||
| 4a2bf65fea | |||
| db5c304e29 | |||
| a0d721e109 | |||
| 2c10fda399 | |||
| 7ca20adc6d | |||
| e0e64bcdf6 | |||
| 499cef3c2b | |||
| 2571767204 | |||
| 36ecff71ae | |||
| a3c8510c05 | |||
| 042e1d5d0b | |||
| ae17c775dc | |||
| 8ce613bb3a | |||
| 9d2299ed0b | |||
| 35db3665b3 | |||
| 5a5913828c | |||
| 448d6bef37 | |||
| 7056fdf2be | |||
| 3d81fd714b | |||
| 58a9082411 | |||
| 99a64edea8 | |||
| d75b521af8 | |||
| 296c8f6a4a | |||
| 99cd8de234 | |||
| 5590be7a8c | |||
| f084e6bbd0 | |||
| cd719b08bd | |||
| 90e560bb75 | |||
| 9182dd7e5d | |||
| f739e3e05d | |||
| e7a044a2d1 | |||
| ca72db23d2 | |||
| e4a2a705ad | |||
| bb91bb5e83 | |||
| 4760c3c0b5 | |||
| 1631e96a98 | |||
| 61c1261e4e | |||
| 956e1d8d90 | |||
| 453a5ac1d0 | |||
| 64d5fa1efd | |||
| 9d1d63afca | |||
| 44d4e7c500 | |||
| f89f01f9d8 | |||
| 640cb1bb8d | |||
| a81dc43fcd | |||
| 8a1f32b6a5 | |||
| f9c2216ffa | |||
| 959f8b32d5 | |||
| 13f1357b7f | |||
| 3ce5fb8e5c | |||
| 46e8898f65 | |||
| 3003b10e0a | |||
| 0dc74545c0 | |||
| 254be4eeb2 | |||
| 541ef9247c | |||
| e1a29266b2 | |||
| fc3a57ff96 | |||
| 0cf85b24df | |||
| eaba3d7349 | |||
| 57e59c14c8 | |||
| 0815c45bcd | |||
| 023a3a98a1 | |||
| 86221269f9 | |||
| d9ddc5d4cd | |||
| a7f7701b64 | |||
| fd563e3274 | |||
| d09d33bc2d | |||
| 7083391931 | |||
| 0f77139253 | |||
| 5b28b7dbc7 | |||
| 85fcb7b8df | |||
| 8b181c812f | |||
| f01682ee01 | |||
| aa57a89a21 | |||
| 7327be97aa | |||
| 63f881a5f0 | |||
| dc0e63a48a | |||
| f117bb64fc | |||
| 54209c1639 | |||
| ec505bac41 | |||
| 2154662826 | |||
| 9ab52caf02 | |||
| bc61ad9ec8 | |||
| b0a6d61d73 | |||
| 371b24b17c | |||
| 79d4e81984 | |||
| 7e77a38cbc | |||
| d6b79b9963 | |||
| 6f86573247 | |||
| 45be87afc6 | |||
| 5daf7983d1 | |||
| f23e5ce2da | |||
| e56b7c8419 | |||
| 2359c07ddf | |||
| bc63339df3 | |||
| a2e213bc7b | |||
| 6bfd4dfecf | |||
| 99ab3d43a7 | |||
| 353c876172 | |||
| d61e31bae6 | |||
| f3b96d4998 | |||
| abbecb3e73 | |||
| b39d9364d8 | |||
| c7c16f805c | |||
| f37cc5f5e1 | |||
| 3a4bee1096 | |||
| c1a31ec9f7 | |||
| f70ded8936 | |||
| aa26f8eb40 | |||
| cda2f0a162 | |||
| aeb76ef174 | |||
| e7c03ccdce | |||
| d9cc27cb29 | |||
| 0ea61a74be | |||
| 007ecfbb29 | |||
| 9cd0475c08 | |||
| 8452708560 | |||
| 16781ba09a | |||
| 09ff5b5416 | |||
| f093c9d39d | |||
| 2035cbbd5d | |||
| 5df535b7c2 | |||
| 232c931f40 | |||
| f4dbb0c820 | |||
| 9058620cec | |||
| 2489252099 | |||
| 87dd685224 | |||
| abfa4ad8bc | |||
| 3163d1269a | |||
| 1c6ca09992 | |||
| d73db17ee3 | |||
| 127ab9114f | |||
| d53f3b5596 | |||
| d41a5bb97d | |||
| 551d2fabcc | |||
| db40d26d08 | |||
| 525b55b1e9 | |||
| ce0829d711 | |||
| ac790fc49b | |||
| f4757032e7 | |||
| d1a70c3f05 | |||
| d8419762c1 | |||
| 60a7405165 | |||
| 1ae9dacb4b | |||
| 69f49c8d39 | |||
| 822597db49 | |||
| 7fa5ee54b1 | |||
| da80d649fd | |||
| 61673451ff | |||
| 599f61a1e0 | |||
| 0e3bac8132 | |||
| fa9370b741 | |||
| 5881dcb887 | |||
| a2b8305096 | |||
| bd4da4474b | |||
| dc5b5ee9c6 | |||
| 299eb54308 | |||
| 8d9ca46e0a | |||
| b2080756fc | |||
| 9d0ec13596 | |||
| 6816ad5ed8 | |||
| 4e8690906c | |||
| f0b72b8121 | |||
| 7a68ac6615 | |||
| f131f84e13 | |||
| 6aa26a26d5 | |||
| fd617fad00 | |||
| d20eb11c9e | |||
| c8d453e915 | |||
| b293ed3061 | |||
| 64311faa68 | |||
| 26c92f056a | |||
| ebc1bafb03 | |||
| 9dae70da79 | |||
| f57bc1a21b | |||
| af27b716e5 | |||
| 7c9c19b2a2 | |||
| 3b2f51602d | |||
| ae6b30907d | |||
| 77c52ea701 | |||
| 3c00e41ec0 | |||
| 340c1cc68d | |||
| 2c79f2af6e | |||
| 4fafc34e49 | |||
| d456fb797a | |||
| 458eda1321 | |||
| 54f926b11d | |||
| a75d756a6f | |||
| 863613293e | |||
| 9af5cce4c7 | |||
| e0906096c5 | |||
| 4549f2a9cc | |||
| f4979422dd | |||
| 5a705c2468 | |||
| 36762f0eaf | |||
| ac8a5d18d3 | |||
| 70a01cd444 | |||
| 959404e0e2 | |||
| 887bcfdf65 | |||
| 40ccd26b19 | |||
| 4412398c4b | |||
| 942d7a118a | |||
| 070b034cd5 | |||
| 9d78d317ae | |||
| 045f740892 | |||
| b13806c150 | |||
| 4f6582cb66 | |||
| 1b3093fe3a | |||
| 237b704172 | |||
| 4d93f48f09 | |||
| ed01d2ee3b | |||
| 386202895f | |||
| 0883810592 | |||
| faca86620d | |||
| 6c23061a7d | |||
| 33446acf47 | |||
| 0a0a9d4fe9 | |||
| 9199b6b7eb | |||
| 2c5106ed06 | |||
| 6ed1541ef5 | |||
| 736aaf348b | |||
| f0edd26998 | |||
| ff1bfd01ba | |||
| 2ceb4f81e2 | |||
| 259805947e | |||
| 66c32e40e8 | |||
| edfae9e78a | |||
| d1ba46b6e1 | |||
| c7b9394daf | |||
| ab42f81c75 | |||
| 8b7b99f8d5 | |||
| 4a64d34001 | |||
| 95821f0132 | |||
| a2a97e57f0 | |||
| f2ebcee7c4 | |||
| eed963e972 | |||
| 7ba8f11688 | |||
| aa10faa591 | |||
| 358f55db6a | |||
| c8c48640e6 | |||
| 0cac6ab615 | |||
| 2617598b7a | |||
| 8eea891718 | |||
| 386245a264 | |||
| 7d81ecbea6 | |||
| 8cf8fc6794 | |||
| da0712ee7d | |||
| a6f840b4dc | |||
| 0d5dc9a6e7 | |||
| d81d3fa8cd | |||
| c102780693 | |||
| 7f9dbc45b1 | |||
| 08e538e2e6 | |||
| bd4b4292ef | |||
| e12a1be1ca | |||
| a74c014425 | |||
| a2360de3f3 | |||
| 0e83c67525 | |||
| 1aefb50259 | |||
| ec194b6374 | |||
| f8ff8c0638 | |||
| 54c3e5c913 | |||
| 70c63c1208 | |||
| bc7906e6d6 | |||
| ae1bde1aa1 | |||
| a8a256f9b5 | |||
| 8285a149d8 | |||
| 2a72d76d6f | |||
| 2d8e4a6544 | |||
| c721884cf5 | |||
| ee2b8f2e1b | |||
| a3e27019e4 | |||
| 7e88f57aaa | |||
| 902f8cf292 | |||
| f17c8c2eff | |||
| c75bda867b | |||
| 8c200c2156 | |||
| b0f7f4a991 | |||
| 01997f45ba | |||
| 251140fc88 | |||
| aea0fa9fd5 | |||
| 912356133a | |||
| 250a95b6fe | |||
| fd67eafc65 | |||
| 4c72377bbf | |||
| 7d8f55ec7c | |||
| 0ea20a0d52 | |||
| 5cf37ca89f | |||
| 3453710d10 | |||
| 6e7828e1d2 | |||
| c96e4750d8 | |||
| 7bcfb4654f | |||
| 976963ab6d | |||
| 5a0db84b6c | |||
| 5a38a9c0ee | |||
| 956e69bf3a | |||
| f1975b0213 | |||
| e866c35462 | |||
| 56888644a6 | |||
| 3ca3c7f1c6 | |||
| a1825ee741 | |||
| 8b036d8a82 | |||
| c46525b70b | |||
| 955542a654 | |||
| 2f1d5b6b04 | |||
| 56236dfd3f | |||
| f2a4073aea | |||
| 9421c11346 | |||
| b2f0040da7 | |||
| 7afe7375e1 |
+3
-1
@@ -90,6 +90,8 @@ module.exports = {
|
|||||||
// localStorage.js
|
// localStorage.js
|
||||||
localSet: "readonly",
|
localSet: "readonly",
|
||||||
localGet: "readonly",
|
localGet: "readonly",
|
||||||
localRemove: "readonly"
|
localRemove: "readonly",
|
||||||
|
// resizeHandle.js
|
||||||
|
setupResizeHandle: "writable"
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|||||||
+138
@@ -1,3 +1,141 @@
|
|||||||
|
## 1.6.0
|
||||||
|
|
||||||
|
### Features:
|
||||||
|
* refiner support [#12371](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/12371)
|
||||||
|
* add NV option for Random number generator source setting, which allows to generate same pictures on CPU/AMD/Mac as on NVidia videocards
|
||||||
|
* add style editor dialog
|
||||||
|
* hires fix: add an option to use a different checkpoint for second pass ([#12181](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/12181))
|
||||||
|
* option to keep multiple loaded models in memory ([#12227](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/12227))
|
||||||
|
* new samplers: Restart, DPM++ 2M SDE Exponential, DPM++ 2M SDE Heun, DPM++ 2M SDE Heun Karras, DPM++ 2M SDE Heun Exponential, DPM++ 3M SDE, DPM++ 3M SDE Karras, DPM++ 3M SDE Exponential ([#12300](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/12300), [#12519](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/12519), [#12542](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/12542))
|
||||||
|
* rework DDIM, PLMS, UniPC to use CFG denoiser same as in k-diffusion samplers:
|
||||||
|
* makes all of them work with img2img
|
||||||
|
* makes prompt composition posssible (AND)
|
||||||
|
* makes them available for SDXL
|
||||||
|
* always show extra networks tabs in the UI ([#11808](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/11808))
|
||||||
|
* use less RAM when creating models ([#11958](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/11958), [#12599](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/12599))
|
||||||
|
* textual inversion inference support for SDXL
|
||||||
|
* extra networks UI: show metadata for SD checkpoints
|
||||||
|
* checkpoint merger: add metadata support
|
||||||
|
* prompt editing and attention: add support for whitespace after the number ([ red : green : 0.5 ]) (seed breaking change) ([#12177](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/12177))
|
||||||
|
* VAE: allow selecting own VAE for each checkpoint (in user metadata editor)
|
||||||
|
* VAE: add selected VAE to infotext
|
||||||
|
* options in main UI: add own separate setting for txt2img and img2img, correctly read values from pasted infotext, add setting for column count ([#12551](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/12551))
|
||||||
|
* add resize handle to txt2img and img2img tabs, allowing to change the amount of horizontable space given to generation parameters and resulting image gallery ([#12687](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/12687), [#12723](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/12723))
|
||||||
|
* change default behavior for batching cond/uncond -- now it's on by default, and is disabled by an UI setting (Optimizatios -> Batch cond/uncond) - if you are on lowvram/medvram and are getting OOM exceptions, you will need to enable it
|
||||||
|
* show current position in queue and make it so that requests are processed in the order of arrival ([#12707](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/12707))
|
||||||
|
* add `--medvram-sdxl` flag that only enables `--medvram` for SDXL models
|
||||||
|
* prompt editing timeline has separate range for first pass and hires-fix pass (seed breaking change) ([#12457](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/12457))
|
||||||
|
|
||||||
|
### Minor:
|
||||||
|
* img2img batch: RAM savings, VRAM savings, .tif, .tiff in img2img batch ([#12120](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/12120), [#12514](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/12514), [#12515](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/12515))
|
||||||
|
* postprocessing/extras: RAM savings ([#12479](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/12479))
|
||||||
|
* XYZ: in the axis labels, remove pathnames from model filenames
|
||||||
|
* XYZ: support hires sampler ([#12298](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/12298))
|
||||||
|
* XYZ: new option: use text inputs instead of dropdowns ([#12491](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/12491))
|
||||||
|
* add gradio version warning
|
||||||
|
* sort list of VAE checkpoints ([#12297](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/12297))
|
||||||
|
* use transparent white for mask in inpainting, along with an option to select the color ([#12326](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/12326))
|
||||||
|
* move some settings to their own section: img2img, VAE
|
||||||
|
* add checkbox to show/hide dirs for extra networks
|
||||||
|
* Add TAESD(or more) options for all the VAE encode/decode operation ([#12311](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/12311))
|
||||||
|
* gradio theme cache, new gradio themes, along with explanation that the user can input his own values ([#12346](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/12346), [#12355](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/12355))
|
||||||
|
* sampler fixes/tweaks: s_tmax, s_churn, s_noise, s_tmax ([#12354](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/12354), [#12356](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/12356), [#12357](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/12357), [#12358](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/12358), [#12375](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/12375), [#12521](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/12521))
|
||||||
|
* update README.md with correct instructions for Linux installation ([#12352](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/12352))
|
||||||
|
* option to not save incomplete images, on by default ([#12338](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/12338))
|
||||||
|
* enable cond cache by default
|
||||||
|
* git autofix for repos that are corrupted ([#12230](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/12230))
|
||||||
|
* allow to open images in new browser tab by middle mouse button ([#12379](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/12379))
|
||||||
|
* automatically open webui in browser when running "locally" ([#12254](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/12254))
|
||||||
|
* put commonly used samplers on top, make DPM++ 2M Karras the default choice
|
||||||
|
* zoom and pan: option to auto-expand a wide image, improved integration ([#12413](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/12413), [#12727](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/12727))
|
||||||
|
* option to cache Lora networks in memory
|
||||||
|
* rework hires fix UI to use accordion
|
||||||
|
* face restoration and tiling moved to settings - use "Options in main UI" setting if you want them back
|
||||||
|
* change quicksettings items to have variable width
|
||||||
|
* Lora: add Norm module, add support for bias ([#12503](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/12503))
|
||||||
|
* Lora: output warnings in UI rather than fail for unfitting loras; switch to logging for error output in console
|
||||||
|
* support search and display of hashes for all extra network items ([#12510](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/12510))
|
||||||
|
* add extra noise param for img2img operations ([#12564](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/12564))
|
||||||
|
* support for Lora with bias ([#12584](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/12584))
|
||||||
|
* make interrupt quicker ([#12634](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/12634))
|
||||||
|
* configurable gallery height ([#12648](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/12648))
|
||||||
|
* make results column sticky ([#12645](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/12645))
|
||||||
|
* more hash filename patterns ([#12639](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/12639))
|
||||||
|
* make image viewer actually fit the whole page ([#12635](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/12635))
|
||||||
|
* make progress bar work independently from live preview display which results in it being updated a lot more often
|
||||||
|
* forbid Full live preview method for medvram and add a setting to undo the forbidding
|
||||||
|
* make it possible to localize tooltips and placeholders
|
||||||
|
|
||||||
|
### Extensions and API:
|
||||||
|
* gradio 3.41.0
|
||||||
|
* also bump versions for packages: transformers, GitPython, accelerate, scikit-image, timm, tomesd
|
||||||
|
* support tooltip kwarg for gradio elements: gr.Textbox(label='hello', tooltip='world')
|
||||||
|
* properly clear the total console progressbar when using txt2img and img2img from API
|
||||||
|
* add cmd_arg --disable-extra-extensions and --disable-all-extensions ([#12294](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/12294))
|
||||||
|
* shared.py and webui.py split into many files
|
||||||
|
* add --loglevel commandline argument for logging
|
||||||
|
* add a custom UI element that combines accordion and checkbox
|
||||||
|
* avoid importing gradio in tests because it spams warnings
|
||||||
|
* put infotext label for setting into OptionInfo definition rather than in a separate list
|
||||||
|
* make `StableDiffusionProcessingImg2Img.mask_blur` a property, make more inline with PIL `GaussianBlur` ([#12470](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/12470))
|
||||||
|
* option to make scripts UI without gr.Group
|
||||||
|
* add a way for scripts to register a callback for before/after just a single component's creation
|
||||||
|
* use dataclass for StableDiffusionProcessing
|
||||||
|
* store patches for Lora in a specialized module instead of inside torch
|
||||||
|
* support http/https URLs in API ([#12663](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/12663), [#12698](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/12698))
|
||||||
|
* add extra noise callback ([#12616](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/12616))
|
||||||
|
* dump current stack traces when exiting with SIGINT
|
||||||
|
* add type annotations for extra fields of shared.sd_model
|
||||||
|
|
||||||
|
### Bug Fixes:
|
||||||
|
* Don't crash if out of local storage quota for javascriot localStorage
|
||||||
|
* XYZ plot do not fail if an exception occurs
|
||||||
|
* fix missing TI hash in infotext if generation uses both negative and positive TI ([#12269](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/12269))
|
||||||
|
* localization fixes ([#12307](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/12307))
|
||||||
|
* fix sdxl model invalid configuration after the hijack
|
||||||
|
* correctly toggle extras checkbox for infotext paste ([#12304](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/12304))
|
||||||
|
* open raw sysinfo link in new page ([#12318](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/12318))
|
||||||
|
* prompt parser: Account for empty field in alternating words syntax ([#12319](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/12319))
|
||||||
|
* add tab and carriage return to invalid filename chars ([#12327](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/12327))
|
||||||
|
* fix api only Lora not working ([#12387](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/12387))
|
||||||
|
* fix options in main UI misbehaving when there's just one element
|
||||||
|
* make it possible to use a sampler from infotext even if it's hidden in the dropdown
|
||||||
|
* fix styles missing from the prompt in infotext when making a grid of batch of multiplie images
|
||||||
|
* prevent bogus progress output in console when calculating hires fix dimensions
|
||||||
|
* fix --use-textbox-seed
|
||||||
|
* fix broken `Lora/Networks: use old method` option ([#12466](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/12466))
|
||||||
|
* properly return `None` for VAE hash when using `--no-hashing` ([#12463](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/12463))
|
||||||
|
* MPS/macOS fixes and optimizations ([#12526](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/12526))
|
||||||
|
* add second_order to samplers that mistakenly didn't have it
|
||||||
|
* when refreshing cards in extra networks UI, do not discard user's custom resolution
|
||||||
|
* fix processing error that happens if batch_size is not a multiple of how many prompts/negative prompts there are ([#12509](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/12509))
|
||||||
|
* fix inpaint upload for alpha masks ([#12588](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/12588))
|
||||||
|
* fix exception when image sizes are not integers ([#12586](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/12586))
|
||||||
|
* fix incorrect TAESD Latent scale ([#12596](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/12596))
|
||||||
|
* auto add data-dir to gradio-allowed-path ([#12603](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/12603))
|
||||||
|
* fix exception if extensuions dir is missing ([#12607](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/12607))
|
||||||
|
* fix issues with api model-refresh and vae-refresh ([#12638](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/12638))
|
||||||
|
* fix img2img background color for transparent images option not being used ([#12633](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/12633))
|
||||||
|
* attempt to resolve NaN issue with unstable VAEs in fp32 mk2 ([#12630](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/12630))
|
||||||
|
* implement missing undo hijack for SDXL
|
||||||
|
* fix xyz swap axes ([#12684](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/12684))
|
||||||
|
* fix errors in backup/restore tab if any of config files are broken ([#12689](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/12689))
|
||||||
|
* fix SD VAE switch error after model reuse ([#12685](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/12685))
|
||||||
|
* fix trying to create images too large for the chosen format ([#12667](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/12667))
|
||||||
|
* create Gradio temp directory if necessary ([#12717](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/12717))
|
||||||
|
* prevent possible cache loss if exiting as it's being written by using an atomic operation to replace the cache with the new version
|
||||||
|
* set devices.dtype_unet correctly
|
||||||
|
* run RealESRGAN on GPU for non-CUDA devices ([#12737](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/12737))
|
||||||
|
* prevent extra network buttons being obscured by description for very small card sizes ([#12745](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/12745))
|
||||||
|
|
||||||
|
|
||||||
|
## 1.5.2
|
||||||
|
|
||||||
|
### Bug Fixes:
|
||||||
|
* fix memory leak when generation fails
|
||||||
|
* update doggettx cross attention optimization to not use an unreasonable amount of memory in some edge cases -- suggestion by MorkTheOrk
|
||||||
|
|
||||||
|
|
||||||
## 1.5.1
|
## 1.5.1
|
||||||
|
|
||||||
### Minor:
|
### Minor:
|
||||||
|
|||||||
@@ -0,0 +1,7 @@
|
|||||||
|
cff-version: 1.2.0
|
||||||
|
message: "If you use this software, please cite it as below."
|
||||||
|
authors:
|
||||||
|
- given-names: AUTOMATIC1111
|
||||||
|
title: "Stable Diffusion Web UI"
|
||||||
|
date-released: 2022-08-22
|
||||||
|
url: "https://github.com/AUTOMATIC1111/stable-diffusion-webui"
|
||||||
@@ -78,7 +78,7 @@ A browser interface based on Gradio library for Stable Diffusion.
|
|||||||
- Clip skip
|
- Clip skip
|
||||||
- Hypernetworks
|
- Hypernetworks
|
||||||
- Loras (same as Hypernetworks but more pretty)
|
- Loras (same as Hypernetworks but more pretty)
|
||||||
- A sparate UI where you can choose, with preview, which embeddings, hypernetworks or Loras to add to your prompt
|
- A separate UI where you can choose, with preview, which embeddings, hypernetworks or Loras to add to your prompt
|
||||||
- Can select to load a different VAE from settings screen
|
- Can select to load a different VAE from settings screen
|
||||||
- Estimated completion time in progress bar
|
- Estimated completion time in progress bar
|
||||||
- API
|
- API
|
||||||
@@ -93,7 +93,10 @@ A browser interface based on Gradio library for Stable Diffusion.
|
|||||||
- Reorder elements in the UI from settings screen
|
- Reorder elements in the UI from settings screen
|
||||||
|
|
||||||
## Installation and Running
|
## Installation and Running
|
||||||
Make sure the required [dependencies](https://github.com/AUTOMATIC1111/stable-diffusion-webui/wiki/Dependencies) are met and follow the instructions available for both [NVidia](https://github.com/AUTOMATIC1111/stable-diffusion-webui/wiki/Install-and-Run-on-NVidia-GPUs) (recommended) and [AMD](https://github.com/AUTOMATIC1111/stable-diffusion-webui/wiki/Install-and-Run-on-AMD-GPUs) GPUs.
|
Make sure the required [dependencies](https://github.com/AUTOMATIC1111/stable-diffusion-webui/wiki/Dependencies) are met and follow the instructions available for:
|
||||||
|
- [NVidia](https://github.com/AUTOMATIC1111/stable-diffusion-webui/wiki/Install-and-Run-on-NVidia-GPUs) (recommended)
|
||||||
|
- [AMD](https://github.com/AUTOMATIC1111/stable-diffusion-webui/wiki/Install-and-Run-on-AMD-GPUs) GPUs.
|
||||||
|
- [Intel CPUs, Intel GPUs (both integrated and discrete)](https://github.com/openvinotoolkit/stable-diffusion-webui/wiki/Installation-on-Intel-Silicon) (external wiki page)
|
||||||
|
|
||||||
Alternatively, use online services (like Google Colab):
|
Alternatively, use online services (like Google Colab):
|
||||||
|
|
||||||
|
|||||||
@@ -6,9 +6,14 @@ class ExtraNetworkLora(extra_networks.ExtraNetwork):
|
|||||||
def __init__(self):
|
def __init__(self):
|
||||||
super().__init__('lora')
|
super().__init__('lora')
|
||||||
|
|
||||||
|
self.errors = {}
|
||||||
|
"""mapping of network names to the number of errors the network had during operation"""
|
||||||
|
|
||||||
def activate(self, p, params_list):
|
def activate(self, p, params_list):
|
||||||
additional = shared.opts.sd_lora
|
additional = shared.opts.sd_lora
|
||||||
|
|
||||||
|
self.errors.clear()
|
||||||
|
|
||||||
if additional != "None" and additional in networks.available_networks and not any(x for x in params_list if x.items[0] == additional):
|
if additional != "None" and additional in networks.available_networks and not any(x for x in params_list if x.items[0] == additional):
|
||||||
p.all_prompts = [x + f"<lora:{additional}:{shared.opts.extra_networks_default_multiplier}>" for x in p.all_prompts]
|
p.all_prompts = [x + f"<lora:{additional}:{shared.opts.extra_networks_default_multiplier}>" for x in p.all_prompts]
|
||||||
params_list.append(extra_networks.ExtraNetworkParams(items=[additional, shared.opts.extra_networks_default_multiplier]))
|
params_list.append(extra_networks.ExtraNetworkParams(items=[additional, shared.opts.extra_networks_default_multiplier]))
|
||||||
@@ -56,4 +61,7 @@ class ExtraNetworkLora(extra_networks.ExtraNetwork):
|
|||||||
p.extra_generation_params["Lora hashes"] = ", ".join(network_hashes)
|
p.extra_generation_params["Lora hashes"] = ", ".join(network_hashes)
|
||||||
|
|
||||||
def deactivate(self, p):
|
def deactivate(self, p):
|
||||||
pass
|
if self.errors:
|
||||||
|
p.comment("Networks with errors: " + ", ".join(f"{k} ({v})" for k, v in self.errors.items()))
|
||||||
|
|
||||||
|
self.errors.clear()
|
||||||
|
|||||||
@@ -0,0 +1,31 @@
|
|||||||
|
import torch
|
||||||
|
|
||||||
|
import networks
|
||||||
|
from modules import patches
|
||||||
|
|
||||||
|
|
||||||
|
class LoraPatches:
|
||||||
|
def __init__(self):
|
||||||
|
self.Linear_forward = patches.patch(__name__, torch.nn.Linear, 'forward', networks.network_Linear_forward)
|
||||||
|
self.Linear_load_state_dict = patches.patch(__name__, torch.nn.Linear, '_load_from_state_dict', networks.network_Linear_load_state_dict)
|
||||||
|
self.Conv2d_forward = patches.patch(__name__, torch.nn.Conv2d, 'forward', networks.network_Conv2d_forward)
|
||||||
|
self.Conv2d_load_state_dict = patches.patch(__name__, torch.nn.Conv2d, '_load_from_state_dict', networks.network_Conv2d_load_state_dict)
|
||||||
|
self.GroupNorm_forward = patches.patch(__name__, torch.nn.GroupNorm, 'forward', networks.network_GroupNorm_forward)
|
||||||
|
self.GroupNorm_load_state_dict = patches.patch(__name__, torch.nn.GroupNorm, '_load_from_state_dict', networks.network_GroupNorm_load_state_dict)
|
||||||
|
self.LayerNorm_forward = patches.patch(__name__, torch.nn.LayerNorm, 'forward', networks.network_LayerNorm_forward)
|
||||||
|
self.LayerNorm_load_state_dict = patches.patch(__name__, torch.nn.LayerNorm, '_load_from_state_dict', networks.network_LayerNorm_load_state_dict)
|
||||||
|
self.MultiheadAttention_forward = patches.patch(__name__, torch.nn.MultiheadAttention, 'forward', networks.network_MultiheadAttention_forward)
|
||||||
|
self.MultiheadAttention_load_state_dict = patches.patch(__name__, torch.nn.MultiheadAttention, '_load_from_state_dict', networks.network_MultiheadAttention_load_state_dict)
|
||||||
|
|
||||||
|
def undo(self):
|
||||||
|
self.Linear_forward = patches.undo(__name__, torch.nn.Linear, 'forward')
|
||||||
|
self.Linear_load_state_dict = patches.undo(__name__, torch.nn.Linear, '_load_from_state_dict')
|
||||||
|
self.Conv2d_forward = patches.undo(__name__, torch.nn.Conv2d, 'forward')
|
||||||
|
self.Conv2d_load_state_dict = patches.undo(__name__, torch.nn.Conv2d, '_load_from_state_dict')
|
||||||
|
self.GroupNorm_forward = patches.undo(__name__, torch.nn.GroupNorm, 'forward')
|
||||||
|
self.GroupNorm_load_state_dict = patches.undo(__name__, torch.nn.GroupNorm, '_load_from_state_dict')
|
||||||
|
self.LayerNorm_forward = patches.undo(__name__, torch.nn.LayerNorm, 'forward')
|
||||||
|
self.LayerNorm_load_state_dict = patches.undo(__name__, torch.nn.LayerNorm, '_load_from_state_dict')
|
||||||
|
self.MultiheadAttention_forward = patches.undo(__name__, torch.nn.MultiheadAttention, 'forward')
|
||||||
|
self.MultiheadAttention_load_state_dict = patches.undo(__name__, torch.nn.MultiheadAttention, '_load_from_state_dict')
|
||||||
|
|
||||||
@@ -133,7 +133,7 @@ class NetworkModule:
|
|||||||
|
|
||||||
return 1.0
|
return 1.0
|
||||||
|
|
||||||
def finalize_updown(self, updown, orig_weight, output_shape):
|
def finalize_updown(self, updown, orig_weight, output_shape, ex_bias=None):
|
||||||
if self.bias is not None:
|
if self.bias is not None:
|
||||||
updown = updown.reshape(self.bias.shape)
|
updown = updown.reshape(self.bias.shape)
|
||||||
updown += self.bias.to(orig_weight.device, dtype=orig_weight.dtype)
|
updown += self.bias.to(orig_weight.device, dtype=orig_weight.dtype)
|
||||||
@@ -145,7 +145,10 @@ class NetworkModule:
|
|||||||
if orig_weight.size().numel() == updown.size().numel():
|
if orig_weight.size().numel() == updown.size().numel():
|
||||||
updown = updown.reshape(orig_weight.shape)
|
updown = updown.reshape(orig_weight.shape)
|
||||||
|
|
||||||
return updown * self.calc_scale() * self.multiplier()
|
if ex_bias is not None:
|
||||||
|
ex_bias = ex_bias * self.multiplier()
|
||||||
|
|
||||||
|
return updown * self.calc_scale() * self.multiplier(), ex_bias
|
||||||
|
|
||||||
def calc_updown(self, target):
|
def calc_updown(self, target):
|
||||||
raise NotImplementedError()
|
raise NotImplementedError()
|
||||||
|
|||||||
@@ -14,9 +14,14 @@ class NetworkModuleFull(network.NetworkModule):
|
|||||||
super().__init__(net, weights)
|
super().__init__(net, weights)
|
||||||
|
|
||||||
self.weight = weights.w.get("diff")
|
self.weight = weights.w.get("diff")
|
||||||
|
self.ex_bias = weights.w.get("diff_b")
|
||||||
|
|
||||||
def calc_updown(self, orig_weight):
|
def calc_updown(self, orig_weight):
|
||||||
output_shape = self.weight.shape
|
output_shape = self.weight.shape
|
||||||
updown = self.weight.to(orig_weight.device, dtype=orig_weight.dtype)
|
updown = self.weight.to(orig_weight.device, dtype=orig_weight.dtype)
|
||||||
|
if self.ex_bias is not None:
|
||||||
|
ex_bias = self.ex_bias.to(orig_weight.device, dtype=orig_weight.dtype)
|
||||||
|
else:
|
||||||
|
ex_bias = None
|
||||||
|
|
||||||
return self.finalize_updown(updown, orig_weight, output_shape)
|
return self.finalize_updown(updown, orig_weight, output_shape, ex_bias)
|
||||||
|
|||||||
@@ -0,0 +1,28 @@
|
|||||||
|
import network
|
||||||
|
|
||||||
|
|
||||||
|
class ModuleTypeNorm(network.ModuleType):
|
||||||
|
def create_module(self, net: network.Network, weights: network.NetworkWeights):
|
||||||
|
if all(x in weights.w for x in ["w_norm", "b_norm"]):
|
||||||
|
return NetworkModuleNorm(net, weights)
|
||||||
|
|
||||||
|
return None
|
||||||
|
|
||||||
|
|
||||||
|
class NetworkModuleNorm(network.NetworkModule):
|
||||||
|
def __init__(self, net: network.Network, weights: network.NetworkWeights):
|
||||||
|
super().__init__(net, weights)
|
||||||
|
|
||||||
|
self.w_norm = weights.w.get("w_norm")
|
||||||
|
self.b_norm = weights.w.get("b_norm")
|
||||||
|
|
||||||
|
def calc_updown(self, orig_weight):
|
||||||
|
output_shape = self.w_norm.shape
|
||||||
|
updown = self.w_norm.to(orig_weight.device, dtype=orig_weight.dtype)
|
||||||
|
|
||||||
|
if self.b_norm is not None:
|
||||||
|
ex_bias = self.b_norm.to(orig_weight.device, dtype=orig_weight.dtype)
|
||||||
|
else:
|
||||||
|
ex_bias = None
|
||||||
|
|
||||||
|
return self.finalize_updown(updown, orig_weight, output_shape, ex_bias)
|
||||||
@@ -1,12 +1,15 @@
|
|||||||
|
import logging
|
||||||
import os
|
import os
|
||||||
import re
|
import re
|
||||||
|
|
||||||
|
import lora_patches
|
||||||
import network
|
import network
|
||||||
import network_lora
|
import network_lora
|
||||||
import network_hada
|
import network_hada
|
||||||
import network_ia3
|
import network_ia3
|
||||||
import network_lokr
|
import network_lokr
|
||||||
import network_full
|
import network_full
|
||||||
|
import network_norm
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
from typing import Union
|
from typing import Union
|
||||||
@@ -19,6 +22,7 @@ module_types = [
|
|||||||
network_ia3.ModuleTypeIa3(),
|
network_ia3.ModuleTypeIa3(),
|
||||||
network_lokr.ModuleTypeLokr(),
|
network_lokr.ModuleTypeLokr(),
|
||||||
network_full.ModuleTypeFull(),
|
network_full.ModuleTypeFull(),
|
||||||
|
network_norm.ModuleTypeNorm(),
|
||||||
]
|
]
|
||||||
|
|
||||||
|
|
||||||
@@ -31,6 +35,8 @@ suffix_conversion = {
|
|||||||
"resnets": {
|
"resnets": {
|
||||||
"conv1": "in_layers_2",
|
"conv1": "in_layers_2",
|
||||||
"conv2": "out_layers_3",
|
"conv2": "out_layers_3",
|
||||||
|
"norm1": "in_layers_0",
|
||||||
|
"norm2": "out_layers_0",
|
||||||
"time_emb_proj": "emb_layers_1",
|
"time_emb_proj": "emb_layers_1",
|
||||||
"conv_shortcut": "skip_connection",
|
"conv_shortcut": "skip_connection",
|
||||||
}
|
}
|
||||||
@@ -190,11 +196,19 @@ def load_network(name, network_on_disk):
|
|||||||
net.modules[key] = net_module
|
net.modules[key] = net_module
|
||||||
|
|
||||||
if keys_failed_to_match:
|
if keys_failed_to_match:
|
||||||
print(f"Failed to match keys when loading network {network_on_disk.filename}: {keys_failed_to_match}")
|
logging.debug(f"Network {network_on_disk.filename} didn't match keys: {keys_failed_to_match}")
|
||||||
|
|
||||||
return net
|
return net
|
||||||
|
|
||||||
|
|
||||||
|
def purge_networks_from_memory():
|
||||||
|
while len(networks_in_memory) > shared.opts.lora_in_memory_limit and len(networks_in_memory) > 0:
|
||||||
|
name = next(iter(networks_in_memory))
|
||||||
|
networks_in_memory.pop(name, None)
|
||||||
|
|
||||||
|
devices.torch_gc()
|
||||||
|
|
||||||
|
|
||||||
def load_networks(names, te_multipliers=None, unet_multipliers=None, dyn_dims=None):
|
def load_networks(names, te_multipliers=None, unet_multipliers=None, dyn_dims=None):
|
||||||
already_loaded = {}
|
already_loaded = {}
|
||||||
|
|
||||||
@@ -212,15 +226,19 @@ def load_networks(names, te_multipliers=None, unet_multipliers=None, dyn_dims=No
|
|||||||
|
|
||||||
failed_to_load_networks = []
|
failed_to_load_networks = []
|
||||||
|
|
||||||
for i, name in enumerate(names):
|
for i, (network_on_disk, name) in enumerate(zip(networks_on_disk, names)):
|
||||||
net = already_loaded.get(name, None)
|
net = already_loaded.get(name, None)
|
||||||
|
|
||||||
network_on_disk = networks_on_disk[i]
|
|
||||||
|
|
||||||
if network_on_disk is not None:
|
if network_on_disk is not None:
|
||||||
|
if net is None:
|
||||||
|
net = networks_in_memory.get(name)
|
||||||
|
|
||||||
if net is None or os.path.getmtime(network_on_disk.filename) > net.mtime:
|
if net is None or os.path.getmtime(network_on_disk.filename) > net.mtime:
|
||||||
try:
|
try:
|
||||||
net = load_network(name, network_on_disk)
|
net = load_network(name, network_on_disk)
|
||||||
|
|
||||||
|
networks_in_memory.pop(name, None)
|
||||||
|
networks_in_memory[name] = net
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
errors.display(e, f"loading network {network_on_disk.filename}")
|
errors.display(e, f"loading network {network_on_disk.filename}")
|
||||||
continue
|
continue
|
||||||
@@ -231,7 +249,7 @@ def load_networks(names, te_multipliers=None, unet_multipliers=None, dyn_dims=No
|
|||||||
|
|
||||||
if net is None:
|
if net is None:
|
||||||
failed_to_load_networks.append(name)
|
failed_to_load_networks.append(name)
|
||||||
print(f"Couldn't find network with name {name}")
|
logging.info(f"Couldn't find network with name {name}")
|
||||||
continue
|
continue
|
||||||
|
|
||||||
net.te_multiplier = te_multipliers[i] if te_multipliers else 1.0
|
net.te_multiplier = te_multipliers[i] if te_multipliers else 1.0
|
||||||
@@ -240,23 +258,38 @@ def load_networks(names, te_multipliers=None, unet_multipliers=None, dyn_dims=No
|
|||||||
loaded_networks.append(net)
|
loaded_networks.append(net)
|
||||||
|
|
||||||
if failed_to_load_networks:
|
if failed_to_load_networks:
|
||||||
sd_hijack.model_hijack.comments.append("Failed to find networks: " + ", ".join(failed_to_load_networks))
|
sd_hijack.model_hijack.comments.append("Networks not found: " + ", ".join(failed_to_load_networks))
|
||||||
|
|
||||||
|
purge_networks_from_memory()
|
||||||
|
|
||||||
|
|
||||||
def network_restore_weights_from_backup(self: Union[torch.nn.Conv2d, torch.nn.Linear, torch.nn.MultiheadAttention]):
|
def network_restore_weights_from_backup(self: Union[torch.nn.Conv2d, torch.nn.Linear, torch.nn.GroupNorm, torch.nn.LayerNorm, torch.nn.MultiheadAttention]):
|
||||||
weights_backup = getattr(self, "network_weights_backup", None)
|
weights_backup = getattr(self, "network_weights_backup", None)
|
||||||
|
bias_backup = getattr(self, "network_bias_backup", None)
|
||||||
|
|
||||||
if weights_backup is None:
|
if weights_backup is None and bias_backup is None:
|
||||||
return
|
return
|
||||||
|
|
||||||
|
if weights_backup is not None:
|
||||||
if isinstance(self, torch.nn.MultiheadAttention):
|
if isinstance(self, torch.nn.MultiheadAttention):
|
||||||
self.in_proj_weight.copy_(weights_backup[0])
|
self.in_proj_weight.copy_(weights_backup[0])
|
||||||
self.out_proj.weight.copy_(weights_backup[1])
|
self.out_proj.weight.copy_(weights_backup[1])
|
||||||
else:
|
else:
|
||||||
self.weight.copy_(weights_backup)
|
self.weight.copy_(weights_backup)
|
||||||
|
|
||||||
|
if bias_backup is not None:
|
||||||
|
if isinstance(self, torch.nn.MultiheadAttention):
|
||||||
|
self.out_proj.bias.copy_(bias_backup)
|
||||||
|
else:
|
||||||
|
self.bias.copy_(bias_backup)
|
||||||
|
else:
|
||||||
|
if isinstance(self, torch.nn.MultiheadAttention):
|
||||||
|
self.out_proj.bias = None
|
||||||
|
else:
|
||||||
|
self.bias = None
|
||||||
|
|
||||||
def network_apply_weights(self: Union[torch.nn.Conv2d, torch.nn.Linear, torch.nn.MultiheadAttention]):
|
|
||||||
|
def network_apply_weights(self: Union[torch.nn.Conv2d, torch.nn.Linear, torch.nn.GroupNorm, torch.nn.LayerNorm, torch.nn.MultiheadAttention]):
|
||||||
"""
|
"""
|
||||||
Applies the currently selected set of networks to the weights of torch layer self.
|
Applies the currently selected set of networks to the weights of torch layer self.
|
||||||
If weights already have this particular set of networks applied, does nothing.
|
If weights already have this particular set of networks applied, does nothing.
|
||||||
@@ -271,7 +304,10 @@ def network_apply_weights(self: Union[torch.nn.Conv2d, torch.nn.Linear, torch.nn
|
|||||||
wanted_names = tuple((x.name, x.te_multiplier, x.unet_multiplier, x.dyn_dim) for x in loaded_networks)
|
wanted_names = tuple((x.name, x.te_multiplier, x.unet_multiplier, x.dyn_dim) for x in loaded_networks)
|
||||||
|
|
||||||
weights_backup = getattr(self, "network_weights_backup", None)
|
weights_backup = getattr(self, "network_weights_backup", None)
|
||||||
if weights_backup is None:
|
if weights_backup is None and wanted_names != ():
|
||||||
|
if current_names != ():
|
||||||
|
raise RuntimeError("no backup weights found and current weights are not unchanged")
|
||||||
|
|
||||||
if isinstance(self, torch.nn.MultiheadAttention):
|
if isinstance(self, torch.nn.MultiheadAttention):
|
||||||
weights_backup = (self.in_proj_weight.to(devices.cpu, copy=True), self.out_proj.weight.to(devices.cpu, copy=True))
|
weights_backup = (self.in_proj_weight.to(devices.cpu, copy=True), self.out_proj.weight.to(devices.cpu, copy=True))
|
||||||
else:
|
else:
|
||||||
@@ -279,20 +315,40 @@ def network_apply_weights(self: Union[torch.nn.Conv2d, torch.nn.Linear, torch.nn
|
|||||||
|
|
||||||
self.network_weights_backup = weights_backup
|
self.network_weights_backup = weights_backup
|
||||||
|
|
||||||
|
bias_backup = getattr(self, "network_bias_backup", None)
|
||||||
|
if bias_backup is None:
|
||||||
|
if isinstance(self, torch.nn.MultiheadAttention) and self.out_proj.bias is not None:
|
||||||
|
bias_backup = self.out_proj.bias.to(devices.cpu, copy=True)
|
||||||
|
elif getattr(self, 'bias', None) is not None:
|
||||||
|
bias_backup = self.bias.to(devices.cpu, copy=True)
|
||||||
|
else:
|
||||||
|
bias_backup = None
|
||||||
|
self.network_bias_backup = bias_backup
|
||||||
|
|
||||||
if current_names != wanted_names:
|
if current_names != wanted_names:
|
||||||
network_restore_weights_from_backup(self)
|
network_restore_weights_from_backup(self)
|
||||||
|
|
||||||
for net in loaded_networks:
|
for net in loaded_networks:
|
||||||
module = net.modules.get(network_layer_name, None)
|
module = net.modules.get(network_layer_name, None)
|
||||||
if module is not None and hasattr(self, 'weight'):
|
if module is not None and hasattr(self, 'weight'):
|
||||||
|
try:
|
||||||
with torch.no_grad():
|
with torch.no_grad():
|
||||||
updown = module.calc_updown(self.weight)
|
updown, ex_bias = module.calc_updown(self.weight)
|
||||||
|
|
||||||
if len(self.weight.shape) == 4 and self.weight.shape[1] == 9:
|
if len(self.weight.shape) == 4 and self.weight.shape[1] == 9:
|
||||||
# inpainting model. zero pad updown to make channel[1] 4 to 9
|
# inpainting model. zero pad updown to make channel[1] 4 to 9
|
||||||
updown = torch.nn.functional.pad(updown, (0, 0, 0, 0, 0, 5))
|
updown = torch.nn.functional.pad(updown, (0, 0, 0, 0, 0, 5))
|
||||||
|
|
||||||
self.weight += updown
|
self.weight += updown
|
||||||
|
if ex_bias is not None and hasattr(self, 'bias'):
|
||||||
|
if self.bias is None:
|
||||||
|
self.bias = torch.nn.Parameter(ex_bias)
|
||||||
|
else:
|
||||||
|
self.bias += ex_bias
|
||||||
|
except RuntimeError as e:
|
||||||
|
logging.debug(f"Network {net.name} layer {network_layer_name}: {e}")
|
||||||
|
extra_network_lora.errors[net.name] = extra_network_lora.errors.get(net.name, 0) + 1
|
||||||
|
|
||||||
continue
|
continue
|
||||||
|
|
||||||
module_q = net.modules.get(network_layer_name + "_q_proj", None)
|
module_q = net.modules.get(network_layer_name + "_q_proj", None)
|
||||||
@@ -301,21 +357,33 @@ def network_apply_weights(self: Union[torch.nn.Conv2d, torch.nn.Linear, torch.nn
|
|||||||
module_out = net.modules.get(network_layer_name + "_out_proj", None)
|
module_out = net.modules.get(network_layer_name + "_out_proj", None)
|
||||||
|
|
||||||
if isinstance(self, torch.nn.MultiheadAttention) and module_q and module_k and module_v and module_out:
|
if isinstance(self, torch.nn.MultiheadAttention) and module_q and module_k and module_v and module_out:
|
||||||
|
try:
|
||||||
with torch.no_grad():
|
with torch.no_grad():
|
||||||
updown_q = module_q.calc_updown(self.in_proj_weight)
|
updown_q, _ = module_q.calc_updown(self.in_proj_weight)
|
||||||
updown_k = module_k.calc_updown(self.in_proj_weight)
|
updown_k, _ = module_k.calc_updown(self.in_proj_weight)
|
||||||
updown_v = module_v.calc_updown(self.in_proj_weight)
|
updown_v, _ = module_v.calc_updown(self.in_proj_weight)
|
||||||
updown_qkv = torch.vstack([updown_q, updown_k, updown_v])
|
updown_qkv = torch.vstack([updown_q, updown_k, updown_v])
|
||||||
updown_out = module_out.calc_updown(self.out_proj.weight)
|
updown_out, ex_bias = module_out.calc_updown(self.out_proj.weight)
|
||||||
|
|
||||||
self.in_proj_weight += updown_qkv
|
self.in_proj_weight += updown_qkv
|
||||||
self.out_proj.weight += updown_out
|
self.out_proj.weight += updown_out
|
||||||
|
if ex_bias is not None:
|
||||||
|
if self.out_proj.bias is None:
|
||||||
|
self.out_proj.bias = torch.nn.Parameter(ex_bias)
|
||||||
|
else:
|
||||||
|
self.out_proj.bias += ex_bias
|
||||||
|
|
||||||
|
except RuntimeError as e:
|
||||||
|
logging.debug(f"Network {net.name} layer {network_layer_name}: {e}")
|
||||||
|
extra_network_lora.errors[net.name] = extra_network_lora.errors.get(net.name, 0) + 1
|
||||||
|
|
||||||
continue
|
continue
|
||||||
|
|
||||||
if module is None:
|
if module is None:
|
||||||
continue
|
continue
|
||||||
|
|
||||||
print(f'failed to calculate network weights for layer {network_layer_name}')
|
logging.debug(f"Network {net.name} layer {network_layer_name}: couldn't find supported operation")
|
||||||
|
extra_network_lora.errors[net.name] = extra_network_lora.errors.get(net.name, 0) + 1
|
||||||
|
|
||||||
self.network_current_names = wanted_names
|
self.network_current_names = wanted_names
|
||||||
|
|
||||||
@@ -342,7 +410,7 @@ def network_forward(module, input, original_forward):
|
|||||||
if module is None:
|
if module is None:
|
||||||
continue
|
continue
|
||||||
|
|
||||||
y = module.forward(y, input)
|
y = module.forward(input, y)
|
||||||
|
|
||||||
return y
|
return y
|
||||||
|
|
||||||
@@ -354,44 +422,74 @@ def network_reset_cached_weight(self: Union[torch.nn.Conv2d, torch.nn.Linear]):
|
|||||||
|
|
||||||
def network_Linear_forward(self, input):
|
def network_Linear_forward(self, input):
|
||||||
if shared.opts.lora_functional:
|
if shared.opts.lora_functional:
|
||||||
return network_forward(self, input, torch.nn.Linear_forward_before_network)
|
return network_forward(self, input, originals.Linear_forward)
|
||||||
|
|
||||||
network_apply_weights(self)
|
network_apply_weights(self)
|
||||||
|
|
||||||
return torch.nn.Linear_forward_before_network(self, input)
|
return originals.Linear_forward(self, input)
|
||||||
|
|
||||||
|
|
||||||
def network_Linear_load_state_dict(self, *args, **kwargs):
|
def network_Linear_load_state_dict(self, *args, **kwargs):
|
||||||
network_reset_cached_weight(self)
|
network_reset_cached_weight(self)
|
||||||
|
|
||||||
return torch.nn.Linear_load_state_dict_before_network(self, *args, **kwargs)
|
return originals.Linear_load_state_dict(self, *args, **kwargs)
|
||||||
|
|
||||||
|
|
||||||
def network_Conv2d_forward(self, input):
|
def network_Conv2d_forward(self, input):
|
||||||
if shared.opts.lora_functional:
|
if shared.opts.lora_functional:
|
||||||
return network_forward(self, input, torch.nn.Conv2d_forward_before_network)
|
return network_forward(self, input, originals.Conv2d_forward)
|
||||||
|
|
||||||
network_apply_weights(self)
|
network_apply_weights(self)
|
||||||
|
|
||||||
return torch.nn.Conv2d_forward_before_network(self, input)
|
return originals.Conv2d_forward(self, input)
|
||||||
|
|
||||||
|
|
||||||
def network_Conv2d_load_state_dict(self, *args, **kwargs):
|
def network_Conv2d_load_state_dict(self, *args, **kwargs):
|
||||||
network_reset_cached_weight(self)
|
network_reset_cached_weight(self)
|
||||||
|
|
||||||
return torch.nn.Conv2d_load_state_dict_before_network(self, *args, **kwargs)
|
return originals.Conv2d_load_state_dict(self, *args, **kwargs)
|
||||||
|
|
||||||
|
|
||||||
|
def network_GroupNorm_forward(self, input):
|
||||||
|
if shared.opts.lora_functional:
|
||||||
|
return network_forward(self, input, originals.GroupNorm_forward)
|
||||||
|
|
||||||
|
network_apply_weights(self)
|
||||||
|
|
||||||
|
return originals.GroupNorm_forward(self, input)
|
||||||
|
|
||||||
|
|
||||||
|
def network_GroupNorm_load_state_dict(self, *args, **kwargs):
|
||||||
|
network_reset_cached_weight(self)
|
||||||
|
|
||||||
|
return originals.GroupNorm_load_state_dict(self, *args, **kwargs)
|
||||||
|
|
||||||
|
|
||||||
|
def network_LayerNorm_forward(self, input):
|
||||||
|
if shared.opts.lora_functional:
|
||||||
|
return network_forward(self, input, originals.LayerNorm_forward)
|
||||||
|
|
||||||
|
network_apply_weights(self)
|
||||||
|
|
||||||
|
return originals.LayerNorm_forward(self, input)
|
||||||
|
|
||||||
|
|
||||||
|
def network_LayerNorm_load_state_dict(self, *args, **kwargs):
|
||||||
|
network_reset_cached_weight(self)
|
||||||
|
|
||||||
|
return originals.LayerNorm_load_state_dict(self, *args, **kwargs)
|
||||||
|
|
||||||
|
|
||||||
def network_MultiheadAttention_forward(self, *args, **kwargs):
|
def network_MultiheadAttention_forward(self, *args, **kwargs):
|
||||||
network_apply_weights(self)
|
network_apply_weights(self)
|
||||||
|
|
||||||
return torch.nn.MultiheadAttention_forward_before_network(self, *args, **kwargs)
|
return originals.MultiheadAttention_forward(self, *args, **kwargs)
|
||||||
|
|
||||||
|
|
||||||
def network_MultiheadAttention_load_state_dict(self, *args, **kwargs):
|
def network_MultiheadAttention_load_state_dict(self, *args, **kwargs):
|
||||||
network_reset_cached_weight(self)
|
network_reset_cached_weight(self)
|
||||||
|
|
||||||
return torch.nn.MultiheadAttention_load_state_dict_before_network(self, *args, **kwargs)
|
return originals.MultiheadAttention_load_state_dict(self, *args, **kwargs)
|
||||||
|
|
||||||
|
|
||||||
def list_available_networks():
|
def list_available_networks():
|
||||||
@@ -459,9 +557,14 @@ def infotext_pasted(infotext, params):
|
|||||||
params["Prompt"] += "\n" + "".join(added)
|
params["Prompt"] += "\n" + "".join(added)
|
||||||
|
|
||||||
|
|
||||||
|
originals: lora_patches.LoraPatches = None
|
||||||
|
|
||||||
|
extra_network_lora = None
|
||||||
|
|
||||||
available_networks = {}
|
available_networks = {}
|
||||||
available_network_aliases = {}
|
available_network_aliases = {}
|
||||||
loaded_networks = []
|
loaded_networks = []
|
||||||
|
networks_in_memory = {}
|
||||||
available_network_hash_lookup = {}
|
available_network_hash_lookup = {}
|
||||||
forbidden_network_aliases = {}
|
forbidden_network_aliases = {}
|
||||||
|
|
||||||
|
|||||||
@@ -1,57 +1,30 @@
|
|||||||
import re
|
import re
|
||||||
|
|
||||||
import torch
|
|
||||||
import gradio as gr
|
import gradio as gr
|
||||||
from fastapi import FastAPI
|
from fastapi import FastAPI
|
||||||
|
|
||||||
import network
|
import network
|
||||||
import networks
|
import networks
|
||||||
import lora # noqa:F401
|
import lora # noqa:F401
|
||||||
|
import lora_patches
|
||||||
import extra_networks_lora
|
import extra_networks_lora
|
||||||
import ui_extra_networks_lora
|
import ui_extra_networks_lora
|
||||||
from modules import script_callbacks, ui_extra_networks, extra_networks, shared
|
from modules import script_callbacks, ui_extra_networks, extra_networks, shared
|
||||||
|
|
||||||
|
|
||||||
def unload():
|
def unload():
|
||||||
torch.nn.Linear.forward = torch.nn.Linear_forward_before_network
|
networks.originals.undo()
|
||||||
torch.nn.Linear._load_from_state_dict = torch.nn.Linear_load_state_dict_before_network
|
|
||||||
torch.nn.Conv2d.forward = torch.nn.Conv2d_forward_before_network
|
|
||||||
torch.nn.Conv2d._load_from_state_dict = torch.nn.Conv2d_load_state_dict_before_network
|
|
||||||
torch.nn.MultiheadAttention.forward = torch.nn.MultiheadAttention_forward_before_network
|
|
||||||
torch.nn.MultiheadAttention._load_from_state_dict = torch.nn.MultiheadAttention_load_state_dict_before_network
|
|
||||||
|
|
||||||
|
|
||||||
def before_ui():
|
def before_ui():
|
||||||
ui_extra_networks.register_page(ui_extra_networks_lora.ExtraNetworksPageLora())
|
ui_extra_networks.register_page(ui_extra_networks_lora.ExtraNetworksPageLora())
|
||||||
|
|
||||||
extra_network = extra_networks_lora.ExtraNetworkLora()
|
networks.extra_network_lora = extra_networks_lora.ExtraNetworkLora()
|
||||||
extra_networks.register_extra_network(extra_network)
|
extra_networks.register_extra_network(networks.extra_network_lora)
|
||||||
extra_networks.register_extra_network_alias(extra_network, "lyco")
|
extra_networks.register_extra_network_alias(networks.extra_network_lora, "lyco")
|
||||||
|
|
||||||
|
|
||||||
if not hasattr(torch.nn, 'Linear_forward_before_network'):
|
networks.originals = lora_patches.LoraPatches()
|
||||||
torch.nn.Linear_forward_before_network = torch.nn.Linear.forward
|
|
||||||
|
|
||||||
if not hasattr(torch.nn, 'Linear_load_state_dict_before_network'):
|
|
||||||
torch.nn.Linear_load_state_dict_before_network = torch.nn.Linear._load_from_state_dict
|
|
||||||
|
|
||||||
if not hasattr(torch.nn, 'Conv2d_forward_before_network'):
|
|
||||||
torch.nn.Conv2d_forward_before_network = torch.nn.Conv2d.forward
|
|
||||||
|
|
||||||
if not hasattr(torch.nn, 'Conv2d_load_state_dict_before_network'):
|
|
||||||
torch.nn.Conv2d_load_state_dict_before_network = torch.nn.Conv2d._load_from_state_dict
|
|
||||||
|
|
||||||
if not hasattr(torch.nn, 'MultiheadAttention_forward_before_network'):
|
|
||||||
torch.nn.MultiheadAttention_forward_before_network = torch.nn.MultiheadAttention.forward
|
|
||||||
|
|
||||||
if not hasattr(torch.nn, 'MultiheadAttention_load_state_dict_before_network'):
|
|
||||||
torch.nn.MultiheadAttention_load_state_dict_before_network = torch.nn.MultiheadAttention._load_from_state_dict
|
|
||||||
|
|
||||||
torch.nn.Linear.forward = networks.network_Linear_forward
|
|
||||||
torch.nn.Linear._load_from_state_dict = networks.network_Linear_load_state_dict
|
|
||||||
torch.nn.Conv2d.forward = networks.network_Conv2d_forward
|
|
||||||
torch.nn.Conv2d._load_from_state_dict = networks.network_Conv2d_load_state_dict
|
|
||||||
torch.nn.MultiheadAttention.forward = networks.network_MultiheadAttention_forward
|
|
||||||
torch.nn.MultiheadAttention._load_from_state_dict = networks.network_MultiheadAttention_load_state_dict
|
|
||||||
|
|
||||||
script_callbacks.on_model_loaded(networks.assign_network_names_to_compvis_modules)
|
script_callbacks.on_model_loaded(networks.assign_network_names_to_compvis_modules)
|
||||||
script_callbacks.on_script_unloaded(unload)
|
script_callbacks.on_script_unloaded(unload)
|
||||||
@@ -65,6 +38,7 @@ shared.options_templates.update(shared.options_section(('extra_networks', "Extra
|
|||||||
"lora_add_hashes_to_infotext": shared.OptionInfo(True, "Add Lora hashes to infotext"),
|
"lora_add_hashes_to_infotext": shared.OptionInfo(True, "Add Lora hashes to infotext"),
|
||||||
"lora_show_all": shared.OptionInfo(False, "Always show all networks on the Lora page").info("otherwise, those detected as for incompatible version of Stable Diffusion will be hidden"),
|
"lora_show_all": shared.OptionInfo(False, "Always show all networks on the Lora page").info("otherwise, those detected as for incompatible version of Stable Diffusion will be hidden"),
|
||||||
"lora_hide_unknown_for_versions": shared.OptionInfo([], "Hide networks of unknown versions for model versions", gr.CheckboxGroup, {"choices": ["SD1", "SD2", "SDXL"]}),
|
"lora_hide_unknown_for_versions": shared.OptionInfo([], "Hide networks of unknown versions for model versions", gr.CheckboxGroup, {"choices": ["SD1", "SD2", "SDXL"]}),
|
||||||
|
"lora_in_memory_limit": shared.OptionInfo(0, "Number of Lora networks to keep cached in memory", gr.Number, {"precision": 0}),
|
||||||
}))
|
}))
|
||||||
|
|
||||||
|
|
||||||
@@ -121,3 +95,5 @@ def infotext_pasted(infotext, d):
|
|||||||
|
|
||||||
|
|
||||||
script_callbacks.on_infotext_pasted(infotext_pasted)
|
script_callbacks.on_infotext_pasted(infotext_pasted)
|
||||||
|
|
||||||
|
shared.opts.onchange("lora_in_memory_limit", networks.purge_networks_from_memory)
|
||||||
|
|||||||
@@ -25,9 +25,10 @@ class ExtraNetworksPageLora(ui_extra_networks.ExtraNetworksPage):
|
|||||||
item = {
|
item = {
|
||||||
"name": name,
|
"name": name,
|
||||||
"filename": lora_on_disk.filename,
|
"filename": lora_on_disk.filename,
|
||||||
|
"shorthash": lora_on_disk.shorthash,
|
||||||
"preview": self.find_preview(path),
|
"preview": self.find_preview(path),
|
||||||
"description": self.find_description(path),
|
"description": self.find_description(path),
|
||||||
"search_term": self.search_terms_from_path(lora_on_disk.filename),
|
"search_term": self.search_terms_from_path(lora_on_disk.filename) + " " + (lora_on_disk.hash or ""),
|
||||||
"local_preview": f"{path}.{shared.opts.samples_format}",
|
"local_preview": f"{path}.{shared.opts.samples_format}",
|
||||||
"metadata": lora_on_disk.metadata,
|
"metadata": lora_on_disk.metadata,
|
||||||
"sort_keys": {'default': index, **self.get_sort_keys(lora_on_disk.filename)},
|
"sort_keys": {'default': index, **self.get_sort_keys(lora_on_disk.filename)},
|
||||||
|
|||||||
@@ -12,8 +12,22 @@ onUiLoaded(async() => {
|
|||||||
"Sketch": elementIDs.sketch
|
"Sketch": elementIDs.sketch
|
||||||
};
|
};
|
||||||
|
|
||||||
|
|
||||||
// Helper functions
|
// Helper functions
|
||||||
// Get active tab
|
// Get active tab
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Waits for an element to be present in the DOM.
|
||||||
|
*/
|
||||||
|
const waitForElement = (id) => new Promise(resolve => {
|
||||||
|
const checkForElement = () => {
|
||||||
|
const element = document.querySelector(id);
|
||||||
|
if (element) return resolve(element);
|
||||||
|
setTimeout(checkForElement, 100);
|
||||||
|
};
|
||||||
|
checkForElement();
|
||||||
|
});
|
||||||
|
|
||||||
function getActiveTab(elements, all = false) {
|
function getActiveTab(elements, all = false) {
|
||||||
const tabs = elements.img2imgTabs.querySelectorAll("button");
|
const tabs = elements.img2imgTabs.querySelectorAll("button");
|
||||||
|
|
||||||
@@ -34,7 +48,7 @@ onUiLoaded(async() => {
|
|||||||
|
|
||||||
// Wait until opts loaded
|
// Wait until opts loaded
|
||||||
async function waitForOpts() {
|
async function waitForOpts() {
|
||||||
for (;;) {
|
for (; ;) {
|
||||||
if (window.opts && Object.keys(window.opts).length) {
|
if (window.opts && Object.keys(window.opts).length) {
|
||||||
return window.opts;
|
return window.opts;
|
||||||
}
|
}
|
||||||
@@ -42,6 +56,11 @@ onUiLoaded(async() => {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Detect whether the element has a horizontal scroll bar
|
||||||
|
function hasHorizontalScrollbar(element) {
|
||||||
|
return element.scrollWidth > element.clientWidth;
|
||||||
|
}
|
||||||
|
|
||||||
// Function for defining the "Ctrl", "Shift" and "Alt" keys
|
// Function for defining the "Ctrl", "Shift" and "Alt" keys
|
||||||
function isModifierKey(event, key) {
|
function isModifierKey(event, key) {
|
||||||
switch (key) {
|
switch (key) {
|
||||||
@@ -201,7 +220,8 @@ onUiLoaded(async() => {
|
|||||||
canvas_hotkey_overlap: "KeyO",
|
canvas_hotkey_overlap: "KeyO",
|
||||||
canvas_disabled_functions: [],
|
canvas_disabled_functions: [],
|
||||||
canvas_show_tooltip: true,
|
canvas_show_tooltip: true,
|
||||||
canvas_blur_prompt: false
|
canvas_auto_expand: true,
|
||||||
|
canvas_blur_prompt: false,
|
||||||
};
|
};
|
||||||
|
|
||||||
const functionMap = {
|
const functionMap = {
|
||||||
@@ -249,7 +269,7 @@ onUiLoaded(async() => {
|
|||||||
input?.addEventListener("input", () => restoreImgRedMask(elements));
|
input?.addEventListener("input", () => restoreImgRedMask(elements));
|
||||||
}
|
}
|
||||||
|
|
||||||
function applyZoomAndPan(elemId) {
|
function applyZoomAndPan(elemId, isExtension = true) {
|
||||||
const targetElement = gradioApp().querySelector(elemId);
|
const targetElement = gradioApp().querySelector(elemId);
|
||||||
|
|
||||||
if (!targetElement) {
|
if (!targetElement) {
|
||||||
@@ -361,6 +381,10 @@ onUiLoaded(async() => {
|
|||||||
panY: 0
|
panY: 0
|
||||||
};
|
};
|
||||||
|
|
||||||
|
if (isExtension) {
|
||||||
|
targetElement.style.overflow = "hidden";
|
||||||
|
}
|
||||||
|
|
||||||
fixCanvas();
|
fixCanvas();
|
||||||
targetElement.style.transform = `scale(${elemData[elemId].zoomLevel}) translate(${elemData[elemId].panX}px, ${elemData[elemId].panY}px)`;
|
targetElement.style.transform = `scale(${elemData[elemId].zoomLevel}) translate(${elemData[elemId].panX}px, ${elemData[elemId].panY}px)`;
|
||||||
|
|
||||||
@@ -371,8 +395,27 @@ onUiLoaded(async() => {
|
|||||||
toggleOverlap("off");
|
toggleOverlap("off");
|
||||||
fullScreenMode = false;
|
fullScreenMode = false;
|
||||||
|
|
||||||
|
const closeBtn = targetElement.querySelector("button[aria-label='Remove Image']");
|
||||||
|
if (closeBtn) {
|
||||||
|
closeBtn.addEventListener("click", resetZoom);
|
||||||
|
}
|
||||||
|
|
||||||
|
if (canvas && isExtension) {
|
||||||
|
const parentElement = targetElement.closest('[id^="component-"]');
|
||||||
if (
|
if (
|
||||||
canvas &&
|
canvas &&
|
||||||
|
parseFloat(canvas.style.width) > parentElement.offsetWidth &&
|
||||||
|
parseFloat(targetElement.style.width) > parentElement.offsetWidth
|
||||||
|
) {
|
||||||
|
fitToElement();
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
}
|
||||||
|
|
||||||
|
if (
|
||||||
|
canvas &&
|
||||||
|
!isExtension &&
|
||||||
parseFloat(canvas.style.width) > 865 &&
|
parseFloat(canvas.style.width) > 865 &&
|
||||||
parseFloat(targetElement.style.width) > 865
|
parseFloat(targetElement.style.width) > 865
|
||||||
) {
|
) {
|
||||||
@@ -381,9 +424,6 @@ onUiLoaded(async() => {
|
|||||||
}
|
}
|
||||||
|
|
||||||
targetElement.style.width = "";
|
targetElement.style.width = "";
|
||||||
if (canvas) {
|
|
||||||
targetElement.style.height = canvas.style.height;
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// Toggle the zIndex of the target element between two values, allowing it to overlap or be overlapped by other elements
|
// Toggle the zIndex of the target element between two values, allowing it to overlap or be overlapped by other elements
|
||||||
@@ -450,6 +490,10 @@ onUiLoaded(async() => {
|
|||||||
targetElement.style.transform = `translate(${elemData[elemId].panX}px, ${elemData[elemId].panY}px) scale(${newZoomLevel})`;
|
targetElement.style.transform = `translate(${elemData[elemId].panX}px, ${elemData[elemId].panY}px) scale(${newZoomLevel})`;
|
||||||
|
|
||||||
toggleOverlap("on");
|
toggleOverlap("on");
|
||||||
|
if (isExtension) {
|
||||||
|
targetElement.style.overflow = "visible";
|
||||||
|
}
|
||||||
|
|
||||||
return newZoomLevel;
|
return newZoomLevel;
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -489,10 +533,19 @@ onUiLoaded(async() => {
|
|||||||
//Reset Zoom
|
//Reset Zoom
|
||||||
targetElement.style.transform = `translate(${0}px, ${0}px) scale(${1})`;
|
targetElement.style.transform = `translate(${0}px, ${0}px) scale(${1})`;
|
||||||
|
|
||||||
|
let parentElement;
|
||||||
|
|
||||||
|
if (isExtension) {
|
||||||
|
parentElement = targetElement.closest('[id^="component-"]');
|
||||||
|
} else {
|
||||||
|
parentElement = targetElement.parentElement;
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
// Get element and screen dimensions
|
// Get element and screen dimensions
|
||||||
const elementWidth = targetElement.offsetWidth;
|
const elementWidth = targetElement.offsetWidth;
|
||||||
const elementHeight = targetElement.offsetHeight;
|
const elementHeight = targetElement.offsetHeight;
|
||||||
const parentElement = targetElement.parentElement;
|
|
||||||
const screenWidth = parentElement.clientWidth;
|
const screenWidth = parentElement.clientWidth;
|
||||||
const screenHeight = parentElement.clientHeight;
|
const screenHeight = parentElement.clientHeight;
|
||||||
|
|
||||||
@@ -543,10 +596,15 @@ onUiLoaded(async() => {
|
|||||||
`${elemId} canvas[key="interface"]`
|
`${elemId} canvas[key="interface"]`
|
||||||
);
|
);
|
||||||
|
|
||||||
|
if (isExtension) {
|
||||||
|
targetElement.style.overflow = "visible";
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
if (!canvas) return;
|
if (!canvas) return;
|
||||||
|
|
||||||
if (canvas.offsetWidth > 862) {
|
if (canvas.offsetWidth > 862 || isExtension) {
|
||||||
targetElement.style.width = canvas.offsetWidth + "px";
|
targetElement.style.width = (canvas.offsetWidth + 2) + "px";
|
||||||
}
|
}
|
||||||
|
|
||||||
if (fullScreenMode) {
|
if (fullScreenMode) {
|
||||||
@@ -648,8 +706,48 @@ onUiLoaded(async() => {
|
|||||||
mouseY = e.offsetY;
|
mouseY = e.offsetY;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Simulation of the function to put a long image into the screen.
|
||||||
|
// We detect if an image has a scroll bar or not, make a fullscreen to reveal the image, then reduce it to fit into the element.
|
||||||
|
// We hide the image and show it to the user when it is ready.
|
||||||
|
|
||||||
|
targetElement.isExpanded = false;
|
||||||
|
function autoExpand() {
|
||||||
|
const canvas = document.querySelector(`${elemId} canvas[key="interface"]`);
|
||||||
|
if (canvas) {
|
||||||
|
if (hasHorizontalScrollbar(targetElement) && targetElement.isExpanded === false) {
|
||||||
|
targetElement.style.visibility = "hidden";
|
||||||
|
setTimeout(() => {
|
||||||
|
fitToScreen();
|
||||||
|
resetZoom();
|
||||||
|
targetElement.style.visibility = "visible";
|
||||||
|
targetElement.isExpanded = true;
|
||||||
|
}, 10);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
targetElement.addEventListener("mousemove", getMousePosition);
|
targetElement.addEventListener("mousemove", getMousePosition);
|
||||||
|
|
||||||
|
//observers
|
||||||
|
// Creating an observer with a callback function to handle DOM changes
|
||||||
|
const observer = new MutationObserver((mutationsList, observer) => {
|
||||||
|
for (let mutation of mutationsList) {
|
||||||
|
// If the style attribute of the canvas has changed, by observation it happens only when the picture changes
|
||||||
|
if (mutation.type === 'attributes' && mutation.attributeName === 'style' &&
|
||||||
|
mutation.target.tagName.toLowerCase() === 'canvas') {
|
||||||
|
targetElement.isExpanded = false;
|
||||||
|
setTimeout(resetZoom, 10);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
});
|
||||||
|
|
||||||
|
// Apply auto expand if enabled
|
||||||
|
if (hotkeysConfig.canvas_auto_expand) {
|
||||||
|
targetElement.addEventListener("mousemove", autoExpand);
|
||||||
|
// Set up an observer to track attribute changes
|
||||||
|
observer.observe(targetElement, {attributes: true, childList: true, subtree: true});
|
||||||
|
}
|
||||||
|
|
||||||
// Handle events only inside the targetElement
|
// Handle events only inside the targetElement
|
||||||
let isKeyDownHandlerAttached = false;
|
let isKeyDownHandlerAttached = false;
|
||||||
|
|
||||||
@@ -754,6 +852,11 @@ onUiLoaded(async() => {
|
|||||||
if (isMoving && elemId === activeElement) {
|
if (isMoving && elemId === activeElement) {
|
||||||
updatePanPosition(e.movementX, e.movementY);
|
updatePanPosition(e.movementX, e.movementY);
|
||||||
targetElement.style.pointerEvents = "none";
|
targetElement.style.pointerEvents = "none";
|
||||||
|
|
||||||
|
if (isExtension) {
|
||||||
|
targetElement.style.overflow = "visible";
|
||||||
|
}
|
||||||
|
|
||||||
} else {
|
} else {
|
||||||
targetElement.style.pointerEvents = "auto";
|
targetElement.style.pointerEvents = "auto";
|
||||||
}
|
}
|
||||||
@@ -767,10 +870,57 @@ onUiLoaded(async() => {
|
|||||||
gradioApp().addEventListener("mousemove", handleMoveByKey);
|
gradioApp().addEventListener("mousemove", handleMoveByKey);
|
||||||
}
|
}
|
||||||
|
|
||||||
applyZoomAndPan(elementIDs.sketch);
|
applyZoomAndPan(elementIDs.sketch, false);
|
||||||
applyZoomAndPan(elementIDs.inpaint);
|
applyZoomAndPan(elementIDs.inpaint, false);
|
||||||
applyZoomAndPan(elementIDs.inpaintSketch);
|
applyZoomAndPan(elementIDs.inpaintSketch, false);
|
||||||
|
|
||||||
// Make the function global so that other extensions can take advantage of this solution
|
// Make the function global so that other extensions can take advantage of this solution
|
||||||
window.applyZoomAndPan = applyZoomAndPan;
|
const applyZoomAndPanIntegration = async(id, elementIDs) => {
|
||||||
|
const mainEl = document.querySelector(id);
|
||||||
|
if (id.toLocaleLowerCase() === "none") {
|
||||||
|
for (const elementID of elementIDs) {
|
||||||
|
const el = await waitForElement(elementID);
|
||||||
|
if (!el) break;
|
||||||
|
applyZoomAndPan(elementID);
|
||||||
|
}
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
if (!mainEl) return;
|
||||||
|
mainEl.addEventListener("click", async() => {
|
||||||
|
for (const elementID of elementIDs) {
|
||||||
|
const el = await waitForElement(elementID);
|
||||||
|
if (!el) break;
|
||||||
|
applyZoomAndPan(elementID);
|
||||||
|
}
|
||||||
|
}, {once: true});
|
||||||
|
};
|
||||||
|
|
||||||
|
window.applyZoomAndPan = applyZoomAndPan; // Only 1 elements, argument elementID, for example applyZoomAndPan("#txt2img_controlnet_ControlNet_input_image")
|
||||||
|
|
||||||
|
window.applyZoomAndPanIntegration = applyZoomAndPanIntegration; // for any extension
|
||||||
|
|
||||||
|
/*
|
||||||
|
The function `applyZoomAndPanIntegration` takes two arguments:
|
||||||
|
|
||||||
|
1. `id`: A string identifier for the element to which zoom and pan functionality will be applied on click.
|
||||||
|
If the `id` value is "none", the functionality will be applied to all elements specified in the second argument without a click event.
|
||||||
|
|
||||||
|
2. `elementIDs`: An array of string identifiers for elements. Zoom and pan functionality will be applied to each of these elements on click of the element specified by the first argument.
|
||||||
|
If "none" is specified in the first argument, the functionality will be applied to each of these elements without a click event.
|
||||||
|
|
||||||
|
Example usage:
|
||||||
|
applyZoomAndPanIntegration("#txt2img_controlnet", ["#txt2img_controlnet_ControlNet_input_image"]);
|
||||||
|
In this example, zoom and pan functionality will be applied to the element with the identifier "txt2img_controlnet_ControlNet_input_image" upon clicking the element with the identifier "txt2img_controlnet".
|
||||||
|
*/
|
||||||
|
|
||||||
|
// More examples
|
||||||
|
// Add integration with ControlNet txt2img One TAB
|
||||||
|
// applyZoomAndPanIntegration("#txt2img_controlnet", ["#txt2img_controlnet_ControlNet_input_image"]);
|
||||||
|
|
||||||
|
// Add integration with ControlNet txt2img Tabs
|
||||||
|
// applyZoomAndPanIntegration("#txt2img_controlnet",Array.from({ length: 10 }, (_, i) => `#txt2img_controlnet_ControlNet-${i}_input_image`));
|
||||||
|
|
||||||
|
// Add integration with Inpaint Anything
|
||||||
|
// applyZoomAndPanIntegration("None", ["#ia_sam_image", "#ia_sel_mask"]);
|
||||||
});
|
});
|
||||||
|
|||||||
@@ -9,6 +9,7 @@ shared.options_templates.update(shared.options_section(('canvas_hotkey', "Canvas
|
|||||||
"canvas_hotkey_reset": shared.OptionInfo("R", "Reset zoom and canvas positon"),
|
"canvas_hotkey_reset": shared.OptionInfo("R", "Reset zoom and canvas positon"),
|
||||||
"canvas_hotkey_overlap": shared.OptionInfo("O", "Toggle overlap").info("Technical button, neededs for testing"),
|
"canvas_hotkey_overlap": shared.OptionInfo("O", "Toggle overlap").info("Technical button, neededs for testing"),
|
||||||
"canvas_show_tooltip": shared.OptionInfo(True, "Enable tooltip on the canvas"),
|
"canvas_show_tooltip": shared.OptionInfo(True, "Enable tooltip on the canvas"),
|
||||||
|
"canvas_auto_expand": shared.OptionInfo(True, "Automatically expands an image that does not fit completely in the canvas area, similar to manually pressing the S and R buttons"),
|
||||||
"canvas_blur_prompt": shared.OptionInfo(False, "Take the focus off the prompt when working with a canvas"),
|
"canvas_blur_prompt": shared.OptionInfo(False, "Take the focus off the prompt when working with a canvas"),
|
||||||
"canvas_disabled_functions": shared.OptionInfo(["Overlap"], "Disable function that you don't use", gr.CheckboxGroup, {"choices": ["Zoom","Adjust brush size", "Moving canvas","Fullscreen","Reset Zoom","Overlap"]}),
|
"canvas_disabled_functions": shared.OptionInfo(["Overlap"], "Disable function that you don't use", gr.CheckboxGroup, {"choices": ["Zoom","Adjust brush size", "Moving canvas","Fullscreen","Reset Zoom","Overlap"]}),
|
||||||
}))
|
}))
|
||||||
|
|||||||
@@ -61,3 +61,6 @@
|
|||||||
to {opacity: 1;}
|
to {opacity: 1;}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
.styler {
|
||||||
|
overflow:inherit !important;
|
||||||
|
}
|
||||||
@@ -1,5 +1,7 @@
|
|||||||
|
import math
|
||||||
|
|
||||||
import gradio as gr
|
import gradio as gr
|
||||||
from modules import scripts, shared, ui_components, ui_settings
|
from modules import scripts, shared, ui_components, ui_settings, generation_parameters_copypaste
|
||||||
from modules.ui_components import FormColumn
|
from modules.ui_components import FormColumn
|
||||||
|
|
||||||
|
|
||||||
@@ -19,18 +21,38 @@ class ExtraOptionsSection(scripts.Script):
|
|||||||
def ui(self, is_img2img):
|
def ui(self, is_img2img):
|
||||||
self.comps = []
|
self.comps = []
|
||||||
self.setting_names = []
|
self.setting_names = []
|
||||||
|
self.infotext_fields = []
|
||||||
|
extra_options = shared.opts.extra_options_img2img if is_img2img else shared.opts.extra_options_txt2img
|
||||||
|
|
||||||
|
mapping = {k: v for v, k in generation_parameters_copypaste.infotext_to_setting_name_mapping}
|
||||||
|
|
||||||
with gr.Blocks() as interface:
|
with gr.Blocks() as interface:
|
||||||
with gr.Accordion("Options", open=False) if shared.opts.extra_options_accordion and shared.opts.extra_options else gr.Group(), gr.Row():
|
with gr.Accordion("Options", open=False) if shared.opts.extra_options_accordion and extra_options else gr.Group():
|
||||||
for setting_name in shared.opts.extra_options:
|
|
||||||
|
row_count = math.ceil(len(extra_options) / shared.opts.extra_options_cols)
|
||||||
|
|
||||||
|
for row in range(row_count):
|
||||||
|
with gr.Row():
|
||||||
|
for col in range(shared.opts.extra_options_cols):
|
||||||
|
index = row * shared.opts.extra_options_cols + col
|
||||||
|
if index >= len(extra_options):
|
||||||
|
break
|
||||||
|
|
||||||
|
setting_name = extra_options[index]
|
||||||
|
|
||||||
with FormColumn():
|
with FormColumn():
|
||||||
comp = ui_settings.create_setting_component(setting_name)
|
comp = ui_settings.create_setting_component(setting_name)
|
||||||
|
|
||||||
self.comps.append(comp)
|
self.comps.append(comp)
|
||||||
self.setting_names.append(setting_name)
|
self.setting_names.append(setting_name)
|
||||||
|
|
||||||
|
setting_infotext_name = mapping.get(setting_name)
|
||||||
|
if setting_infotext_name is not None:
|
||||||
|
self.infotext_fields.append((comp, setting_infotext_name))
|
||||||
|
|
||||||
def get_settings_values():
|
def get_settings_values():
|
||||||
return [ui_settings.get_value_for_setting(key) for key in self.setting_names]
|
res = [ui_settings.get_value_for_setting(key) for key in self.setting_names]
|
||||||
|
return res[0] if len(res) == 1 else res
|
||||||
|
|
||||||
interface.load(fn=get_settings_values, inputs=[], outputs=self.comps, queue=False, show_progress=False)
|
interface.load(fn=get_settings_values, inputs=[], outputs=self.comps, queue=False, show_progress=False)
|
||||||
|
|
||||||
@@ -43,6 +65,10 @@ class ExtraOptionsSection(scripts.Script):
|
|||||||
|
|
||||||
|
|
||||||
shared.options_templates.update(shared.options_section(('ui', "User interface"), {
|
shared.options_templates.update(shared.options_section(('ui', "User interface"), {
|
||||||
"extra_options": shared.OptionInfo([], "Options in main UI", ui_components.DropdownMulti, lambda: {"choices": list(shared.opts.data_labels.keys())}).js("info", "settingsHintsShowQuicksettings").info("setting entries that also appear in txt2img/img2img interfaces").needs_reload_ui(),
|
"extra_options_txt2img": shared.OptionInfo([], "Options in main UI - txt2img", ui_components.DropdownMulti, lambda: {"choices": list(shared.opts.data_labels.keys())}).js("info", "settingsHintsShowQuicksettings").info("setting entries that also appear in txt2img interfaces").needs_reload_ui(),
|
||||||
"extra_options_accordion": shared.OptionInfo(False, "Place options in main UI into an accordion").needs_restart()
|
"extra_options_img2img": shared.OptionInfo([], "Options in main UI - img2img", ui_components.DropdownMulti, lambda: {"choices": list(shared.opts.data_labels.keys())}).js("info", "settingsHintsShowQuicksettings").info("setting entries that also appear in img2img interfaces").needs_reload_ui(),
|
||||||
|
"extra_options_cols": shared.OptionInfo(1, "Options in main UI - number of columns", gr.Number, {"precision": 0}).needs_reload_ui(),
|
||||||
|
"extra_options_accordion": shared.OptionInfo(False, "Options in main UI - place into an accordion").needs_reload_ui()
|
||||||
}))
|
}))
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@@ -20,7 +20,13 @@ function reportWindowSize() {
|
|||||||
var button = gradioApp().getElementById(tab + '_generate_box');
|
var button = gradioApp().getElementById(tab + '_generate_box');
|
||||||
var target = gradioApp().getElementById(currentlyMobile ? tab + '_results' : tab + '_actions_column');
|
var target = gradioApp().getElementById(currentlyMobile ? tab + '_results' : tab + '_actions_column');
|
||||||
target.insertBefore(button, target.firstElementChild);
|
target.insertBefore(button, target.firstElementChild);
|
||||||
|
|
||||||
|
gradioApp().getElementById(tab + '_results').classList.toggle('mobile', currentlyMobile);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
window.addEventListener("resize", reportWindowSize);
|
window.addEventListener("resize", reportWindowSize);
|
||||||
|
|
||||||
|
onUiLoaded(function() {
|
||||||
|
reportWindowSize();
|
||||||
|
});
|
||||||
|
|||||||
@@ -332,7 +332,7 @@ function extraNetworksRefreshSingleCard(page, tabname, name) {
|
|||||||
newDiv.innerHTML = data.html;
|
newDiv.innerHTML = data.html;
|
||||||
var newCard = newDiv.firstElementChild;
|
var newCard = newDiv.firstElementChild;
|
||||||
|
|
||||||
newCard.style = '';
|
newCard.style.display = '';
|
||||||
card.parentElement.insertBefore(newCard, card);
|
card.parentElement.insertBefore(newCard, card);
|
||||||
card.parentElement.removeChild(card);
|
card.parentElement.removeChild(card);
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -136,6 +136,11 @@ function setupImageForLightbox(e) {
|
|||||||
var event = isFirefox ? 'mousedown' : 'click';
|
var event = isFirefox ? 'mousedown' : 'click';
|
||||||
|
|
||||||
e.addEventListener(event, function(evt) {
|
e.addEventListener(event, function(evt) {
|
||||||
|
if (evt.button == 1) {
|
||||||
|
open(evt.target.src);
|
||||||
|
evt.preventDefault();
|
||||||
|
return;
|
||||||
|
}
|
||||||
if (!opts.js_modal_lightbox || evt.button != 0) return;
|
if (!opts.js_modal_lightbox || evt.button != 0) return;
|
||||||
|
|
||||||
modalZoomSet(gradioApp().getElementById('modalImage'), opts.js_modal_lightbox_initially_zoomed);
|
modalZoomSet(gradioApp().getElementById('modalImage'), opts.js_modal_lightbox_initially_zoomed);
|
||||||
|
|||||||
@@ -0,0 +1,37 @@
|
|||||||
|
var observerAccordionOpen = new MutationObserver(function(mutations) {
|
||||||
|
mutations.forEach(function(mutationRecord) {
|
||||||
|
var elem = mutationRecord.target;
|
||||||
|
var open = elem.classList.contains('open');
|
||||||
|
|
||||||
|
var accordion = elem.parentNode;
|
||||||
|
accordion.classList.toggle('input-accordion-open', open);
|
||||||
|
|
||||||
|
var checkbox = gradioApp().querySelector('#' + accordion.id + "-checkbox input");
|
||||||
|
checkbox.checked = open;
|
||||||
|
updateInput(checkbox);
|
||||||
|
|
||||||
|
var extra = gradioApp().querySelector('#' + accordion.id + "-extra");
|
||||||
|
if (extra) {
|
||||||
|
extra.style.display = open ? "" : "none";
|
||||||
|
}
|
||||||
|
});
|
||||||
|
});
|
||||||
|
|
||||||
|
function inputAccordionChecked(id, checked) {
|
||||||
|
var label = gradioApp().querySelector('#' + id + " .label-wrap");
|
||||||
|
if (label.classList.contains('open') != checked) {
|
||||||
|
label.click();
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
onUiLoaded(function() {
|
||||||
|
for (var accordion of gradioApp().querySelectorAll('.input-accordion')) {
|
||||||
|
var labelWrap = accordion.querySelector('.label-wrap');
|
||||||
|
observerAccordionOpen.observe(labelWrap, {attributes: true, attributeFilter: ['class']});
|
||||||
|
|
||||||
|
var extra = gradioApp().querySelector('#' + accordion.id + "-extra");
|
||||||
|
if (extra) {
|
||||||
|
labelWrap.insertBefore(extra, labelWrap.lastElementChild);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
});
|
||||||
@@ -107,12 +107,41 @@ function processNode(node) {
|
|||||||
});
|
});
|
||||||
}
|
}
|
||||||
|
|
||||||
|
function localizeWholePage() {
|
||||||
|
processNode(gradioApp());
|
||||||
|
|
||||||
|
function elem(comp) {
|
||||||
|
var elem_id = comp.props.elem_id ? comp.props.elem_id : "component-" + comp.id;
|
||||||
|
return gradioApp().getElementById(elem_id);
|
||||||
|
}
|
||||||
|
|
||||||
|
for (var comp of window.gradio_config.components) {
|
||||||
|
if (comp.props.webui_tooltip) {
|
||||||
|
let e = elem(comp);
|
||||||
|
|
||||||
|
let tl = e ? getTranslation(e.title) : undefined;
|
||||||
|
if (tl !== undefined) {
|
||||||
|
e.title = tl;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if (comp.props.placeholder) {
|
||||||
|
let e = elem(comp);
|
||||||
|
let textbox = e ? e.querySelector('[placeholder]') : null;
|
||||||
|
|
||||||
|
let tl = textbox ? getTranslation(textbox.placeholder) : undefined;
|
||||||
|
if (tl !== undefined) {
|
||||||
|
textbox.placeholder = tl;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
function dumpTranslations() {
|
function dumpTranslations() {
|
||||||
if (!hasLocalization()) {
|
if (!hasLocalization()) {
|
||||||
// If we don't have any localization,
|
// If we don't have any localization,
|
||||||
// we will not have traversed the app to find
|
// we will not have traversed the app to find
|
||||||
// original_lines, so do that now.
|
// original_lines, so do that now.
|
||||||
processNode(gradioApp());
|
localizeWholePage();
|
||||||
}
|
}
|
||||||
var dumped = {};
|
var dumped = {};
|
||||||
if (localization.rtl) {
|
if (localization.rtl) {
|
||||||
@@ -154,7 +183,7 @@ document.addEventListener("DOMContentLoaded", function() {
|
|||||||
});
|
});
|
||||||
});
|
});
|
||||||
|
|
||||||
processNode(gradioApp());
|
localizeWholePage();
|
||||||
|
|
||||||
if (localization.rtl) { // if the language is from right to left,
|
if (localization.rtl) { // if the language is from right to left,
|
||||||
(new MutationObserver((mutations, observer) => { // wait for the style to load
|
(new MutationObserver((mutations, observer) => { // wait for the style to load
|
||||||
|
|||||||
+38
-29
@@ -69,7 +69,6 @@ function requestProgress(id_task, progressbarContainer, gallery, atEnd, onProgre
|
|||||||
var dateStart = new Date();
|
var dateStart = new Date();
|
||||||
var wasEverActive = false;
|
var wasEverActive = false;
|
||||||
var parentProgressbar = progressbarContainer.parentNode;
|
var parentProgressbar = progressbarContainer.parentNode;
|
||||||
var parentGallery = gallery ? gallery.parentNode : null;
|
|
||||||
|
|
||||||
var divProgress = document.createElement('div');
|
var divProgress = document.createElement('div');
|
||||||
divProgress.className = 'progressDiv';
|
divProgress.className = 'progressDiv';
|
||||||
@@ -80,32 +79,26 @@ function requestProgress(id_task, progressbarContainer, gallery, atEnd, onProgre
|
|||||||
divProgress.appendChild(divInner);
|
divProgress.appendChild(divInner);
|
||||||
parentProgressbar.insertBefore(divProgress, progressbarContainer);
|
parentProgressbar.insertBefore(divProgress, progressbarContainer);
|
||||||
|
|
||||||
if (parentGallery) {
|
var livePreview = null;
|
||||||
var livePreview = document.createElement('div');
|
|
||||||
livePreview.className = 'livePreview';
|
|
||||||
parentGallery.insertBefore(livePreview, gallery);
|
|
||||||
}
|
|
||||||
|
|
||||||
var removeProgressBar = function() {
|
var removeProgressBar = function() {
|
||||||
|
if (!divProgress) return;
|
||||||
|
|
||||||
setTitle("");
|
setTitle("");
|
||||||
parentProgressbar.removeChild(divProgress);
|
parentProgressbar.removeChild(divProgress);
|
||||||
if (parentGallery) parentGallery.removeChild(livePreview);
|
if (gallery && livePreview) gallery.removeChild(livePreview);
|
||||||
atEnd();
|
atEnd();
|
||||||
|
|
||||||
|
divProgress = null;
|
||||||
};
|
};
|
||||||
|
|
||||||
var fun = function(id_task, id_live_preview) {
|
var funProgress = function(id_task) {
|
||||||
request("./internal/progress", {id_task: id_task, id_live_preview: id_live_preview}, function(res) {
|
request("./internal/progress", {id_task: id_task, live_preview: false}, function(res) {
|
||||||
if (res.completed) {
|
if (res.completed) {
|
||||||
removeProgressBar();
|
removeProgressBar();
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
|
|
||||||
var rect = progressbarContainer.getBoundingClientRect();
|
|
||||||
|
|
||||||
if (rect.width) {
|
|
||||||
divProgress.style.width = rect.width + "px";
|
|
||||||
}
|
|
||||||
|
|
||||||
let progressText = "";
|
let progressText = "";
|
||||||
|
|
||||||
divInner.style.width = ((res.progress || 0) * 100.0) + '%';
|
divInner.style.width = ((res.progress || 0) * 100.0) + '%';
|
||||||
@@ -119,7 +112,6 @@ function requestProgress(id_task, progressbarContainer, gallery, atEnd, onProgre
|
|||||||
progressText += " ETA: " + formatTime(res.eta);
|
progressText += " ETA: " + formatTime(res.eta);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
setTitle(progressText);
|
setTitle(progressText);
|
||||||
|
|
||||||
if (res.textinfo && res.textinfo.indexOf("\n") == -1) {
|
if (res.textinfo && res.textinfo.indexOf("\n") == -1) {
|
||||||
@@ -142,16 +134,33 @@ function requestProgress(id_task, progressbarContainer, gallery, atEnd, onProgre
|
|||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if (onProgress) {
|
||||||
if (res.live_preview && gallery) {
|
onProgress(res);
|
||||||
rect = gallery.getBoundingClientRect();
|
|
||||||
if (rect.width) {
|
|
||||||
livePreview.style.width = rect.width + "px";
|
|
||||||
livePreview.style.height = rect.height + "px";
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
setTimeout(() => {
|
||||||
|
funProgress(id_task, res.id_live_preview);
|
||||||
|
}, opts.live_preview_refresh_period || 500);
|
||||||
|
}, function() {
|
||||||
|
removeProgressBar();
|
||||||
|
});
|
||||||
|
};
|
||||||
|
|
||||||
|
var funLivePreview = function(id_task, id_live_preview) {
|
||||||
|
request("./internal/progress", {id_task: id_task, id_live_preview: id_live_preview}, function(res) {
|
||||||
|
if (!divProgress) {
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
if (res.live_preview && gallery) {
|
||||||
var img = new Image();
|
var img = new Image();
|
||||||
img.onload = function() {
|
img.onload = function() {
|
||||||
|
if (!livePreview) {
|
||||||
|
livePreview = document.createElement('div');
|
||||||
|
livePreview.className = 'livePreview';
|
||||||
|
gallery.insertBefore(livePreview, gallery.firstElementChild);
|
||||||
|
}
|
||||||
|
|
||||||
livePreview.appendChild(img);
|
livePreview.appendChild(img);
|
||||||
if (livePreview.childElementCount > 2) {
|
if (livePreview.childElementCount > 2) {
|
||||||
livePreview.removeChild(livePreview.firstElementChild);
|
livePreview.removeChild(livePreview.firstElementChild);
|
||||||
@@ -160,18 +169,18 @@ function requestProgress(id_task, progressbarContainer, gallery, atEnd, onProgre
|
|||||||
img.src = res.live_preview;
|
img.src = res.live_preview;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
if (onProgress) {
|
|
||||||
onProgress(res);
|
|
||||||
}
|
|
||||||
|
|
||||||
setTimeout(() => {
|
setTimeout(() => {
|
||||||
fun(id_task, res.id_live_preview);
|
funLivePreview(id_task, res.id_live_preview);
|
||||||
}, opts.live_preview_refresh_period || 500);
|
}, opts.live_preview_refresh_period || 500);
|
||||||
}, function() {
|
}, function() {
|
||||||
removeProgressBar();
|
removeProgressBar();
|
||||||
});
|
});
|
||||||
};
|
};
|
||||||
|
|
||||||
fun(id_task, 0);
|
funProgress(id_task, 0);
|
||||||
|
|
||||||
|
if (gallery) {
|
||||||
|
funLivePreview(id_task, 0);
|
||||||
|
}
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -0,0 +1,139 @@
|
|||||||
|
(function() {
|
||||||
|
const GRADIO_MIN_WIDTH = 320;
|
||||||
|
const GRID_TEMPLATE_COLUMNS = '1fr 16px 1fr';
|
||||||
|
const PAD = 16;
|
||||||
|
const DEBOUNCE_TIME = 100;
|
||||||
|
|
||||||
|
const R = {
|
||||||
|
tracking: false,
|
||||||
|
parent: null,
|
||||||
|
parentWidth: null,
|
||||||
|
leftCol: null,
|
||||||
|
leftColStartWidth: null,
|
||||||
|
screenX: null,
|
||||||
|
};
|
||||||
|
|
||||||
|
let resizeTimer;
|
||||||
|
let parents = [];
|
||||||
|
|
||||||
|
function setLeftColGridTemplate(el, width) {
|
||||||
|
el.style.gridTemplateColumns = `${width}px 16px 1fr`;
|
||||||
|
}
|
||||||
|
|
||||||
|
function displayResizeHandle(parent) {
|
||||||
|
if (window.innerWidth < GRADIO_MIN_WIDTH * 2 + PAD * 4) {
|
||||||
|
parent.style.display = 'flex';
|
||||||
|
if (R.handle != null) {
|
||||||
|
R.handle.style.opacity = '0';
|
||||||
|
}
|
||||||
|
return false;
|
||||||
|
} else {
|
||||||
|
parent.style.display = 'grid';
|
||||||
|
if (R.handle != null) {
|
||||||
|
R.handle.style.opacity = '100';
|
||||||
|
}
|
||||||
|
return true;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
function afterResize(parent) {
|
||||||
|
if (displayResizeHandle(parent) && parent.style.gridTemplateColumns != GRID_TEMPLATE_COLUMNS) {
|
||||||
|
const oldParentWidth = R.parentWidth;
|
||||||
|
const newParentWidth = parent.offsetWidth;
|
||||||
|
const widthL = parseInt(parent.style.gridTemplateColumns.split(' ')[0]);
|
||||||
|
|
||||||
|
const ratio = newParentWidth / oldParentWidth;
|
||||||
|
|
||||||
|
const newWidthL = Math.max(Math.floor(ratio * widthL), GRADIO_MIN_WIDTH);
|
||||||
|
setLeftColGridTemplate(parent, newWidthL);
|
||||||
|
|
||||||
|
R.parentWidth = newParentWidth;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
function setup(parent) {
|
||||||
|
const leftCol = parent.firstElementChild;
|
||||||
|
const rightCol = parent.lastElementChild;
|
||||||
|
|
||||||
|
parents.push(parent);
|
||||||
|
|
||||||
|
parent.style.display = 'grid';
|
||||||
|
parent.style.gap = '0';
|
||||||
|
parent.style.gridTemplateColumns = GRID_TEMPLATE_COLUMNS;
|
||||||
|
|
||||||
|
const resizeHandle = document.createElement('div');
|
||||||
|
resizeHandle.classList.add('resize-handle');
|
||||||
|
parent.insertBefore(resizeHandle, rightCol);
|
||||||
|
|
||||||
|
resizeHandle.addEventListener('mousedown', (evt) => {
|
||||||
|
if (evt.button !== 0) return;
|
||||||
|
|
||||||
|
evt.preventDefault();
|
||||||
|
evt.stopPropagation();
|
||||||
|
|
||||||
|
document.body.classList.add('resizing');
|
||||||
|
|
||||||
|
R.tracking = true;
|
||||||
|
R.parent = parent;
|
||||||
|
R.parentWidth = parent.offsetWidth;
|
||||||
|
R.handle = resizeHandle;
|
||||||
|
R.leftCol = leftCol;
|
||||||
|
R.leftColStartWidth = leftCol.offsetWidth;
|
||||||
|
R.screenX = evt.screenX;
|
||||||
|
});
|
||||||
|
|
||||||
|
resizeHandle.addEventListener('dblclick', (evt) => {
|
||||||
|
evt.preventDefault();
|
||||||
|
evt.stopPropagation();
|
||||||
|
|
||||||
|
parent.style.gridTemplateColumns = GRID_TEMPLATE_COLUMNS;
|
||||||
|
});
|
||||||
|
|
||||||
|
afterResize(parent);
|
||||||
|
}
|
||||||
|
|
||||||
|
window.addEventListener('mousemove', (evt) => {
|
||||||
|
if (evt.button !== 0) return;
|
||||||
|
|
||||||
|
if (R.tracking) {
|
||||||
|
evt.preventDefault();
|
||||||
|
evt.stopPropagation();
|
||||||
|
|
||||||
|
const delta = R.screenX - evt.screenX;
|
||||||
|
const leftColWidth = Math.max(Math.min(R.leftColStartWidth - delta, R.parent.offsetWidth - GRADIO_MIN_WIDTH - PAD), GRADIO_MIN_WIDTH);
|
||||||
|
setLeftColGridTemplate(R.parent, leftColWidth);
|
||||||
|
}
|
||||||
|
});
|
||||||
|
|
||||||
|
window.addEventListener('mouseup', (evt) => {
|
||||||
|
if (evt.button !== 0) return;
|
||||||
|
|
||||||
|
if (R.tracking) {
|
||||||
|
evt.preventDefault();
|
||||||
|
evt.stopPropagation();
|
||||||
|
|
||||||
|
R.tracking = false;
|
||||||
|
|
||||||
|
document.body.classList.remove('resizing');
|
||||||
|
}
|
||||||
|
});
|
||||||
|
|
||||||
|
|
||||||
|
window.addEventListener('resize', () => {
|
||||||
|
clearTimeout(resizeTimer);
|
||||||
|
|
||||||
|
resizeTimer = setTimeout(function() {
|
||||||
|
for (const parent of parents) {
|
||||||
|
afterResize(parent);
|
||||||
|
}
|
||||||
|
}, DEBOUNCE_TIME);
|
||||||
|
});
|
||||||
|
|
||||||
|
setupResizeHandle = setup;
|
||||||
|
})();
|
||||||
|
|
||||||
|
onUiLoaded(function() {
|
||||||
|
for (var elem of gradioApp().querySelectorAll('.resize-handle-row')) {
|
||||||
|
setupResizeHandle(elem);
|
||||||
|
}
|
||||||
|
});
|
||||||
+2
-19
@@ -19,28 +19,11 @@ function all_gallery_buttons() {
|
|||||||
}
|
}
|
||||||
|
|
||||||
function selected_gallery_button() {
|
function selected_gallery_button() {
|
||||||
var allCurrentButtons = gradioApp().querySelectorAll('[style="display: block;"].tabitem div[id$=_gallery].gradio-gallery .thumbnail-item.thumbnail-small.selected');
|
return all_gallery_buttons().find(elem => elem.classList.contains('selected')) ?? null;
|
||||||
var visibleCurrentButton = null;
|
|
||||||
allCurrentButtons.forEach(function(elem) {
|
|
||||||
if (elem.parentElement.offsetParent) {
|
|
||||||
visibleCurrentButton = elem;
|
|
||||||
}
|
|
||||||
});
|
|
||||||
return visibleCurrentButton;
|
|
||||||
}
|
}
|
||||||
|
|
||||||
function selected_gallery_index() {
|
function selected_gallery_index() {
|
||||||
var buttons = all_gallery_buttons();
|
return all_gallery_buttons().findIndex(elem => elem.classList.contains('selected'));
|
||||||
var button = selected_gallery_button();
|
|
||||||
|
|
||||||
var result = -1;
|
|
||||||
buttons.forEach(function(v, i) {
|
|
||||||
if (v == button) {
|
|
||||||
result = i;
|
|
||||||
}
|
|
||||||
});
|
|
||||||
|
|
||||||
return result;
|
|
||||||
}
|
}
|
||||||
|
|
||||||
function extract_image_from_gallery(gallery) {
|
function extract_image_from_gallery(gallery) {
|
||||||
|
|||||||
+44
-5
@@ -4,6 +4,8 @@ import os
|
|||||||
import time
|
import time
|
||||||
import datetime
|
import datetime
|
||||||
import uvicorn
|
import uvicorn
|
||||||
|
import ipaddress
|
||||||
|
import requests
|
||||||
import gradio as gr
|
import gradio as gr
|
||||||
from threading import Lock
|
from threading import Lock
|
||||||
from io import BytesIO
|
from io import BytesIO
|
||||||
@@ -23,8 +25,7 @@ from modules.textual_inversion.textual_inversion import create_embedding, train_
|
|||||||
from modules.textual_inversion.preprocess import preprocess
|
from modules.textual_inversion.preprocess import preprocess
|
||||||
from modules.hypernetworks.hypernetwork import create_hypernetwork, train_hypernetwork
|
from modules.hypernetworks.hypernetwork import create_hypernetwork, train_hypernetwork
|
||||||
from PIL import PngImagePlugin,Image
|
from PIL import PngImagePlugin,Image
|
||||||
from modules.sd_models import checkpoints_list, unload_model_weights, reload_model_weights, checkpoint_aliases
|
from modules.sd_models import unload_model_weights, reload_model_weights, checkpoint_aliases
|
||||||
from modules.sd_vae import vae_dict
|
|
||||||
from modules.sd_models_config import find_checkpoint_config_near_filename
|
from modules.sd_models_config import find_checkpoint_config_near_filename
|
||||||
from modules.realesrgan_model import get_realesrgan_models
|
from modules.realesrgan_model import get_realesrgan_models
|
||||||
from modules import devices
|
from modules import devices
|
||||||
@@ -56,7 +57,41 @@ def setUpscalers(req: dict):
|
|||||||
return reqDict
|
return reqDict
|
||||||
|
|
||||||
|
|
||||||
|
def verify_url(url):
|
||||||
|
"""Returns True if the url refers to a global resource."""
|
||||||
|
|
||||||
|
import socket
|
||||||
|
from urllib.parse import urlparse
|
||||||
|
try:
|
||||||
|
parsed_url = urlparse(url)
|
||||||
|
domain_name = parsed_url.netloc
|
||||||
|
host = socket.gethostbyname_ex(domain_name)
|
||||||
|
for ip in host[2]:
|
||||||
|
ip_addr = ipaddress.ip_address(ip)
|
||||||
|
if not ip_addr.is_global:
|
||||||
|
return False
|
||||||
|
except Exception:
|
||||||
|
return False
|
||||||
|
|
||||||
|
return True
|
||||||
|
|
||||||
|
|
||||||
def decode_base64_to_image(encoding):
|
def decode_base64_to_image(encoding):
|
||||||
|
if encoding.startswith("http://") or encoding.startswith("https://"):
|
||||||
|
if not opts.api_enable_requests:
|
||||||
|
raise HTTPException(status_code=500, detail="Requests not allowed")
|
||||||
|
|
||||||
|
if opts.api_forbid_local_requests and not verify_url(encoding):
|
||||||
|
raise HTTPException(status_code=500, detail="Request to local resource not allowed")
|
||||||
|
|
||||||
|
headers = {'user-agent': opts.api_useragent} if opts.api_useragent else {}
|
||||||
|
response = requests.get(encoding, timeout=30, headers=headers)
|
||||||
|
try:
|
||||||
|
image = Image.open(BytesIO(response.content))
|
||||||
|
return image
|
||||||
|
except Exception as e:
|
||||||
|
raise HTTPException(status_code=500, detail="Invalid image url") from e
|
||||||
|
|
||||||
if encoding.startswith("data:image/"):
|
if encoding.startswith("data:image/"):
|
||||||
encoding = encoding.split(";")[1].split(",")[1]
|
encoding = encoding.split(";")[1].split(",")[1]
|
||||||
try:
|
try:
|
||||||
@@ -330,6 +365,7 @@ class Api:
|
|||||||
|
|
||||||
with self.queue_lock:
|
with self.queue_lock:
|
||||||
with closing(StableDiffusionProcessingTxt2Img(sd_model=shared.sd_model, **args)) as p:
|
with closing(StableDiffusionProcessingTxt2Img(sd_model=shared.sd_model, **args)) as p:
|
||||||
|
p.is_api = True
|
||||||
p.scripts = script_runner
|
p.scripts = script_runner
|
||||||
p.outpath_grids = opts.outdir_txt2img_grids
|
p.outpath_grids = opts.outdir_txt2img_grids
|
||||||
p.outpath_samples = opts.outdir_txt2img_samples
|
p.outpath_samples = opts.outdir_txt2img_samples
|
||||||
@@ -390,6 +426,7 @@ class Api:
|
|||||||
with self.queue_lock:
|
with self.queue_lock:
|
||||||
with closing(StableDiffusionProcessingImg2Img(sd_model=shared.sd_model, **args)) as p:
|
with closing(StableDiffusionProcessingImg2Img(sd_model=shared.sd_model, **args)) as p:
|
||||||
p.init_images = [decode_base64_to_image(x) for x in init_images]
|
p.init_images = [decode_base64_to_image(x) for x in init_images]
|
||||||
|
p.is_api = True
|
||||||
p.scripts = script_runner
|
p.scripts = script_runner
|
||||||
p.outpath_grids = opts.outdir_img2img_grids
|
p.outpath_grids = opts.outdir_img2img_grids
|
||||||
p.outpath_samples = opts.outdir_img2img_samples
|
p.outpath_samples = opts.outdir_img2img_samples
|
||||||
@@ -533,7 +570,7 @@ class Api:
|
|||||||
raise RuntimeError(f"model {checkpoint_name!r} not found")
|
raise RuntimeError(f"model {checkpoint_name!r} not found")
|
||||||
|
|
||||||
for k, v in req.items():
|
for k, v in req.items():
|
||||||
shared.opts.set(k, v)
|
shared.opts.set(k, v, is_api=True)
|
||||||
|
|
||||||
shared.opts.save(shared.config_filename)
|
shared.opts.save(shared.config_filename)
|
||||||
return
|
return
|
||||||
@@ -565,10 +602,12 @@ class Api:
|
|||||||
]
|
]
|
||||||
|
|
||||||
def get_sd_models(self):
|
def get_sd_models(self):
|
||||||
return [{"title": x.title, "model_name": x.model_name, "hash": x.shorthash, "sha256": x.sha256, "filename": x.filename, "config": find_checkpoint_config_near_filename(x)} for x in checkpoints_list.values()]
|
import modules.sd_models as sd_models
|
||||||
|
return [{"title": x.title, "model_name": x.model_name, "hash": x.shorthash, "sha256": x.sha256, "filename": x.filename, "config": find_checkpoint_config_near_filename(x)} for x in sd_models.checkpoints_list.values()]
|
||||||
|
|
||||||
def get_sd_vaes(self):
|
def get_sd_vaes(self):
|
||||||
return [{"model_name": x, "filename": vae_dict[x]} for x in vae_dict.keys()]
|
import modules.sd_vae as sd_vae
|
||||||
|
return [{"model_name": x, "filename": sd_vae.vae_dict[x]} for x in sd_vae.vae_dict.keys()]
|
||||||
|
|
||||||
def get_hypernetworks(self):
|
def get_hypernetworks(self):
|
||||||
return [{"name": name, "path": shared.hypernetworks[name]} for name in shared.hypernetworks]
|
return [{"name": name, "path": shared.hypernetworks[name]} for name in shared.hypernetworks]
|
||||||
|
|||||||
@@ -50,10 +50,12 @@ class PydanticModelGenerator:
|
|||||||
additional_fields = None,
|
additional_fields = None,
|
||||||
):
|
):
|
||||||
def field_type_generator(k, v):
|
def field_type_generator(k, v):
|
||||||
# field_type = str if not overrides.get(k) else overrides[k]["type"]
|
|
||||||
# print(k, v.annotation, v.default)
|
|
||||||
field_type = v.annotation
|
field_type = v.annotation
|
||||||
|
|
||||||
|
if field_type == 'Image':
|
||||||
|
# images are sent as base64 strings via API
|
||||||
|
field_type = 'str'
|
||||||
|
|
||||||
return Optional[field_type]
|
return Optional[field_type]
|
||||||
|
|
||||||
def merge_class_params(class_):
|
def merge_class_params(class_):
|
||||||
@@ -63,7 +65,6 @@ class PydanticModelGenerator:
|
|||||||
parameters = {**parameters, **inspect.signature(classes.__init__).parameters}
|
parameters = {**parameters, **inspect.signature(classes.__init__).parameters}
|
||||||
return parameters
|
return parameters
|
||||||
|
|
||||||
|
|
||||||
self._model_name = model_name
|
self._model_name = model_name
|
||||||
self._class_data = merge_class_params(class_instance)
|
self._class_data = merge_class_params(class_instance)
|
||||||
|
|
||||||
@@ -72,7 +73,7 @@ class PydanticModelGenerator:
|
|||||||
field=underscore(k),
|
field=underscore(k),
|
||||||
field_alias=k,
|
field_alias=k,
|
||||||
field_type=field_type_generator(k, v),
|
field_type=field_type_generator(k, v),
|
||||||
field_value=v.default
|
field_value=None if isinstance(v.default, property) else v.default
|
||||||
)
|
)
|
||||||
for (k,v) in self._class_data.items() if k not in API_NOT_ALLOWED
|
for (k,v) in self._class_data.items() if k not in API_NOT_ALLOWED
|
||||||
]
|
]
|
||||||
|
|||||||
+6
-2
@@ -1,11 +1,12 @@
|
|||||||
import json
|
import json
|
||||||
|
import os
|
||||||
import os.path
|
import os.path
|
||||||
import threading
|
import threading
|
||||||
import time
|
import time
|
||||||
|
|
||||||
from modules.paths import data_path, script_path
|
from modules.paths import data_path, script_path
|
||||||
|
|
||||||
cache_filename = os.path.join(data_path, "cache.json")
|
cache_filename = os.environ.get('SD_WEBUI_CACHE_FILE', os.path.join(data_path, "cache.json"))
|
||||||
cache_data = None
|
cache_data = None
|
||||||
cache_lock = threading.Lock()
|
cache_lock = threading.Lock()
|
||||||
|
|
||||||
@@ -29,9 +30,12 @@ def dump_cache():
|
|||||||
time.sleep(1)
|
time.sleep(1)
|
||||||
|
|
||||||
with cache_lock:
|
with cache_lock:
|
||||||
with open(cache_filename, "w", encoding="utf8") as file:
|
cache_filename_tmp = cache_filename + "-"
|
||||||
|
with open(cache_filename_tmp, "w", encoding="utf8") as file:
|
||||||
json.dump(cache_data, file, indent=4)
|
json.dump(cache_data, file, indent=4)
|
||||||
|
|
||||||
|
os.replace(cache_filename_tmp, cache_filename)
|
||||||
|
|
||||||
dump_cache_after = None
|
dump_cache_after = None
|
||||||
dump_cache_thread = None
|
dump_cache_thread = None
|
||||||
|
|
||||||
|
|||||||
@@ -1,11 +1,10 @@
|
|||||||
from functools import wraps
|
from functools import wraps
|
||||||
import html
|
import html
|
||||||
import threading
|
|
||||||
import time
|
import time
|
||||||
|
|
||||||
from modules import shared, progress, errors, devices
|
from modules import shared, progress, errors, devices, fifo_lock
|
||||||
|
|
||||||
queue_lock = threading.Lock()
|
queue_lock = fifo_lock.FIFOLock()
|
||||||
|
|
||||||
|
|
||||||
def wrap_queued_call(func):
|
def wrap_queued_call(func):
|
||||||
|
|||||||
+4
-2
@@ -16,6 +16,7 @@ parser.add_argument("--test-server", action='store_true', help="launch.py argume
|
|||||||
parser.add_argument("--log-startup", action='store_true', help="launch.py argument: print a detailed log of what's happening at startup")
|
parser.add_argument("--log-startup", action='store_true', help="launch.py argument: print a detailed log of what's happening at startup")
|
||||||
parser.add_argument("--skip-prepare-environment", action='store_true', help="launch.py argument: skip all environment preparation")
|
parser.add_argument("--skip-prepare-environment", action='store_true', help="launch.py argument: skip all environment preparation")
|
||||||
parser.add_argument("--skip-install", action='store_true', help="launch.py argument: skip installation of packages")
|
parser.add_argument("--skip-install", action='store_true', help="launch.py argument: skip installation of packages")
|
||||||
|
parser.add_argument("--loglevel", type=str, help="log level; one of: CRITICAL, ERROR, WARNING, INFO, DEBUG", default=None)
|
||||||
parser.add_argument("--do-not-download-clip", action='store_true', help="do not download CLIP model even if it's not included in the checkpoint")
|
parser.add_argument("--do-not-download-clip", action='store_true', help="do not download CLIP model even if it's not included in the checkpoint")
|
||||||
parser.add_argument("--data-dir", type=str, default=os.path.dirname(os.path.dirname(os.path.realpath(__file__))), help="base path where all user data is stored")
|
parser.add_argument("--data-dir", type=str, default=os.path.dirname(os.path.dirname(os.path.realpath(__file__))), help="base path where all user data is stored")
|
||||||
parser.add_argument("--config", type=str, default=sd_default_config, help="path to config which constructs model",)
|
parser.add_argument("--config", type=str, default=sd_default_config, help="path to config which constructs model",)
|
||||||
@@ -34,9 +35,10 @@ parser.add_argument("--hypernetwork-dir", type=str, default=os.path.join(models_
|
|||||||
parser.add_argument("--localizations-dir", type=str, default=os.path.join(script_path, 'localizations'), help="localizations directory")
|
parser.add_argument("--localizations-dir", type=str, default=os.path.join(script_path, 'localizations'), help="localizations directory")
|
||||||
parser.add_argument("--allow-code", action='store_true', help="allow custom script execution from webui")
|
parser.add_argument("--allow-code", action='store_true', help="allow custom script execution from webui")
|
||||||
parser.add_argument("--medvram", action='store_true', help="enable stable diffusion model optimizations for sacrificing a little speed for low VRM usage")
|
parser.add_argument("--medvram", action='store_true', help="enable stable diffusion model optimizations for sacrificing a little speed for low VRM usage")
|
||||||
|
parser.add_argument("--medvram-sdxl", action='store_true', help="enable --medvram optimization just for SDXL models")
|
||||||
parser.add_argument("--lowvram", action='store_true', help="enable stable diffusion model optimizations for sacrificing a lot of speed for very low VRM usage")
|
parser.add_argument("--lowvram", action='store_true', help="enable stable diffusion model optimizations for sacrificing a lot of speed for very low VRM usage")
|
||||||
parser.add_argument("--lowram", action='store_true', help="load stable diffusion checkpoint weights to VRAM instead of RAM")
|
parser.add_argument("--lowram", action='store_true', help="load stable diffusion checkpoint weights to VRAM instead of RAM")
|
||||||
parser.add_argument("--always-batch-cond-uncond", action='store_true', help="disables cond/uncond batching that is enabled to save memory with --medvram or --lowvram")
|
parser.add_argument("--always-batch-cond-uncond", action='store_true', help="does not do anything")
|
||||||
parser.add_argument("--unload-gfpgan", action='store_true', help="does not do anything.")
|
parser.add_argument("--unload-gfpgan", action='store_true', help="does not do anything.")
|
||||||
parser.add_argument("--precision", type=str, help="evaluate at this precision", choices=["full", "autocast"], default="autocast")
|
parser.add_argument("--precision", type=str, help="evaluate at this precision", choices=["full", "autocast"], default="autocast")
|
||||||
parser.add_argument("--upcast-sampling", action='store_true', help="upcast sampling. No effect with --no-half. Usually produces similar results to --no-half with better performance while using less memory.")
|
parser.add_argument("--upcast-sampling", action='store_true', help="upcast sampling. No effect with --no-half. Usually produces similar results to --no-half with better performance while using less memory.")
|
||||||
@@ -80,7 +82,7 @@ parser.add_argument("--gradio-auth", type=str, help='set gradio authentication l
|
|||||||
parser.add_argument("--gradio-auth-path", type=str, help='set gradio authentication file path ex. "/path/to/auth/file" same auth format as --gradio-auth', default=None)
|
parser.add_argument("--gradio-auth-path", type=str, help='set gradio authentication file path ex. "/path/to/auth/file" same auth format as --gradio-auth', default=None)
|
||||||
parser.add_argument("--gradio-img2img-tool", type=str, help='does not do anything')
|
parser.add_argument("--gradio-img2img-tool", type=str, help='does not do anything')
|
||||||
parser.add_argument("--gradio-inpaint-tool", type=str, help="does not do anything")
|
parser.add_argument("--gradio-inpaint-tool", type=str, help="does not do anything")
|
||||||
parser.add_argument("--gradio-allowed-path", action='append', help="add path to gradio's allowed_paths, make it possible to serve files from it")
|
parser.add_argument("--gradio-allowed-path", action='append', help="add path to gradio's allowed_paths, make it possible to serve files from it", default=[data_path])
|
||||||
parser.add_argument("--opt-channelslast", action='store_true', help="change memory type for stable diffusion to channels last")
|
parser.add_argument("--opt-channelslast", action='store_true', help="change memory type for stable diffusion to channels last")
|
||||||
parser.add_argument("--styles-file", type=str, help="filename to use for styles", default=os.path.join(data_path, 'styles.csv'))
|
parser.add_argument("--styles-file", type=str, help="filename to use for styles", default=os.path.join(data_path, 'styles.csv'))
|
||||||
parser.add_argument("--autolaunch", action='store_true', help="open the webui URL in the system's default browser upon launch", default=False)
|
parser.add_argument("--autolaunch", action='store_true', help="open the webui URL in the system's default browser upon launch", default=False)
|
||||||
|
|||||||
@@ -8,14 +8,12 @@ import time
|
|||||||
import tqdm
|
import tqdm
|
||||||
|
|
||||||
from datetime import datetime
|
from datetime import datetime
|
||||||
from collections import OrderedDict
|
|
||||||
import git
|
import git
|
||||||
|
|
||||||
from modules import shared, extensions, errors
|
from modules import shared, extensions, errors
|
||||||
from modules.paths_internal import script_path, config_states_dir
|
from modules.paths_internal import script_path, config_states_dir
|
||||||
|
|
||||||
|
all_config_states = {}
|
||||||
all_config_states = OrderedDict()
|
|
||||||
|
|
||||||
|
|
||||||
def list_config_states():
|
def list_config_states():
|
||||||
@@ -28,10 +26,14 @@ def list_config_states():
|
|||||||
for filename in os.listdir(config_states_dir):
|
for filename in os.listdir(config_states_dir):
|
||||||
if filename.endswith(".json"):
|
if filename.endswith(".json"):
|
||||||
path = os.path.join(config_states_dir, filename)
|
path = os.path.join(config_states_dir, filename)
|
||||||
|
try:
|
||||||
with open(path, "r", encoding="utf-8") as f:
|
with open(path, "r", encoding="utf-8") as f:
|
||||||
j = json.load(f)
|
j = json.load(f)
|
||||||
|
assert "created_at" in j, '"created_at" does not exist'
|
||||||
j["filepath"] = path
|
j["filepath"] = path
|
||||||
config_states.append(j)
|
config_states.append(j)
|
||||||
|
except Exception as e:
|
||||||
|
print(f'[ERROR]: Config states {path}, {e}')
|
||||||
|
|
||||||
config_states = sorted(config_states, key=lambda cs: cs["created_at"], reverse=True)
|
config_states = sorted(config_states, key=lambda cs: cs["created_at"], reverse=True)
|
||||||
|
|
||||||
|
|||||||
+2
-87
@@ -3,7 +3,7 @@ import contextlib
|
|||||||
from functools import lru_cache
|
from functools import lru_cache
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
from modules import errors, rng_philox
|
from modules import errors, shared
|
||||||
|
|
||||||
if sys.platform == "darwin":
|
if sys.platform == "darwin":
|
||||||
from modules import mac_specific
|
from modules import mac_specific
|
||||||
@@ -17,8 +17,6 @@ def has_mps() -> bool:
|
|||||||
|
|
||||||
|
|
||||||
def get_cuda_device_string():
|
def get_cuda_device_string():
|
||||||
from modules import shared
|
|
||||||
|
|
||||||
if shared.cmd_opts.device_id is not None:
|
if shared.cmd_opts.device_id is not None:
|
||||||
return f"cuda:{shared.cmd_opts.device_id}"
|
return f"cuda:{shared.cmd_opts.device_id}"
|
||||||
|
|
||||||
@@ -40,8 +38,6 @@ def get_optimal_device():
|
|||||||
|
|
||||||
|
|
||||||
def get_device_for(task):
|
def get_device_for(task):
|
||||||
from modules import shared
|
|
||||||
|
|
||||||
if task in shared.cmd_opts.use_cpu:
|
if task in shared.cmd_opts.use_cpu:
|
||||||
return cpu
|
return cpu
|
||||||
|
|
||||||
@@ -96,87 +92,7 @@ def cond_cast_float(input):
|
|||||||
nv_rng = None
|
nv_rng = None
|
||||||
|
|
||||||
|
|
||||||
def randn(seed, shape):
|
|
||||||
"""Generate a tensor with random numbers from a normal distribution using seed.
|
|
||||||
|
|
||||||
Uses the seed parameter to set the global torch seed; to generate more with that seed, use randn_like/randn_without_seed."""
|
|
||||||
|
|
||||||
from modules.shared import opts
|
|
||||||
|
|
||||||
manual_seed(seed)
|
|
||||||
|
|
||||||
if opts.randn_source == "NV":
|
|
||||||
return torch.asarray(nv_rng.randn(shape), device=device)
|
|
||||||
|
|
||||||
if opts.randn_source == "CPU" or device.type == 'mps':
|
|
||||||
return torch.randn(shape, device=cpu).to(device)
|
|
||||||
|
|
||||||
return torch.randn(shape, device=device)
|
|
||||||
|
|
||||||
|
|
||||||
def randn_local(seed, shape):
|
|
||||||
"""Generate a tensor with random numbers from a normal distribution using seed.
|
|
||||||
|
|
||||||
Does not change the global random number generator. You can only generate the seed's first tensor using this function."""
|
|
||||||
|
|
||||||
from modules.shared import opts
|
|
||||||
|
|
||||||
if opts.randn_source == "NV":
|
|
||||||
rng = rng_philox.Generator(seed)
|
|
||||||
return torch.asarray(rng.randn(shape), device=device)
|
|
||||||
|
|
||||||
local_device = cpu if opts.randn_source == "CPU" or device.type == 'mps' else device
|
|
||||||
local_generator = torch.Generator(local_device).manual_seed(int(seed))
|
|
||||||
return torch.randn(shape, device=local_device, generator=local_generator).to(device)
|
|
||||||
|
|
||||||
|
|
||||||
def randn_like(x):
|
|
||||||
"""Generate a tensor with random numbers from a normal distribution using the previously initialized genrator.
|
|
||||||
|
|
||||||
Use either randn() or manual_seed() to initialize the generator."""
|
|
||||||
|
|
||||||
from modules.shared import opts
|
|
||||||
|
|
||||||
if opts.randn_source == "NV":
|
|
||||||
return torch.asarray(nv_rng.randn(x.shape), device=x.device, dtype=x.dtype)
|
|
||||||
|
|
||||||
if opts.randn_source == "CPU" or x.device.type == 'mps':
|
|
||||||
return torch.randn_like(x, device=cpu).to(x.device)
|
|
||||||
|
|
||||||
return torch.randn_like(x)
|
|
||||||
|
|
||||||
|
|
||||||
def randn_without_seed(shape):
|
|
||||||
"""Generate a tensor with random numbers from a normal distribution using the previously initialized genrator.
|
|
||||||
|
|
||||||
Use either randn() or manual_seed() to initialize the generator."""
|
|
||||||
|
|
||||||
from modules.shared import opts
|
|
||||||
|
|
||||||
if opts.randn_source == "NV":
|
|
||||||
return torch.asarray(nv_rng.randn(shape), device=device)
|
|
||||||
|
|
||||||
if opts.randn_source == "CPU" or device.type == 'mps':
|
|
||||||
return torch.randn(shape, device=cpu).to(device)
|
|
||||||
|
|
||||||
return torch.randn(shape, device=device)
|
|
||||||
|
|
||||||
|
|
||||||
def manual_seed(seed):
|
|
||||||
"""Set up a global random number generator using the specified seed."""
|
|
||||||
from modules.shared import opts
|
|
||||||
|
|
||||||
if opts.randn_source == "NV":
|
|
||||||
global nv_rng
|
|
||||||
nv_rng = rng_philox.Generator(seed)
|
|
||||||
return
|
|
||||||
|
|
||||||
torch.manual_seed(seed)
|
|
||||||
|
|
||||||
|
|
||||||
def autocast(disable=False):
|
def autocast(disable=False):
|
||||||
from modules import shared
|
|
||||||
|
|
||||||
if disable:
|
if disable:
|
||||||
return contextlib.nullcontext()
|
return contextlib.nullcontext()
|
||||||
|
|
||||||
@@ -195,8 +111,6 @@ class NansException(Exception):
|
|||||||
|
|
||||||
|
|
||||||
def test_for_nans(x, where):
|
def test_for_nans(x, where):
|
||||||
from modules import shared
|
|
||||||
|
|
||||||
if shared.cmd_opts.disable_nan_check:
|
if shared.cmd_opts.disable_nan_check:
|
||||||
return
|
return
|
||||||
|
|
||||||
@@ -236,3 +150,4 @@ def first_time_calculation():
|
|||||||
x = torch.zeros((1, 1, 3, 3)).to(device, dtype)
|
x = torch.zeros((1, 1, 3, 3)).to(device, dtype)
|
||||||
conv2d = torch.nn.Conv2d(1, 1, (3, 3)).to(device, dtype)
|
conv2d = torch.nn.Conv2d(1, 1, (3, 3)).to(device, dtype)
|
||||||
conv2d(x)
|
conv2d(x)
|
||||||
|
|
||||||
|
|||||||
+1
-1
@@ -95,7 +95,7 @@ def check_versions():
|
|||||||
|
|
||||||
expected_torch_version = "2.0.0"
|
expected_torch_version = "2.0.0"
|
||||||
expected_xformers_version = "0.0.20"
|
expected_xformers_version = "0.0.20"
|
||||||
expected_gradio_version = "3.39.0"
|
expected_gradio_version = "3.41.0"
|
||||||
|
|
||||||
if version.parse(torch.__version__) < version.parse(expected_torch_version):
|
if version.parse(torch.__version__) < version.parse(expected_torch_version):
|
||||||
print_error_explanation(f"""
|
print_error_explanation(f"""
|
||||||
|
|||||||
@@ -1,7 +1,7 @@
|
|||||||
import os
|
import os
|
||||||
import threading
|
import threading
|
||||||
|
|
||||||
from modules import shared, errors, cache
|
from modules import shared, errors, cache, scripts
|
||||||
from modules.gitpython_hack import Repo
|
from modules.gitpython_hack import Repo
|
||||||
from modules.paths_internal import extensions_dir, extensions_builtin_dir, script_path # noqa: F401
|
from modules.paths_internal import extensions_dir, extensions_builtin_dir, script_path # noqa: F401
|
||||||
|
|
||||||
@@ -90,8 +90,6 @@ class Extension:
|
|||||||
self.have_info_from_repo = True
|
self.have_info_from_repo = True
|
||||||
|
|
||||||
def list_files(self, subdir, extension):
|
def list_files(self, subdir, extension):
|
||||||
from modules import scripts
|
|
||||||
|
|
||||||
dirpath = os.path.join(self.path, subdir)
|
dirpath = os.path.join(self.path, subdir)
|
||||||
if not os.path.isdir(dirpath):
|
if not os.path.isdir(dirpath):
|
||||||
return []
|
return []
|
||||||
|
|||||||
@@ -0,0 +1,37 @@
|
|||||||
|
import threading
|
||||||
|
import collections
|
||||||
|
|
||||||
|
|
||||||
|
# reference: https://gist.github.com/vitaliyp/6d54dd76ca2c3cdfc1149d33007dc34a
|
||||||
|
class FIFOLock(object):
|
||||||
|
def __init__(self):
|
||||||
|
self._lock = threading.Lock()
|
||||||
|
self._inner_lock = threading.Lock()
|
||||||
|
self._pending_threads = collections.deque()
|
||||||
|
|
||||||
|
def acquire(self, blocking=True):
|
||||||
|
with self._inner_lock:
|
||||||
|
lock_acquired = self._lock.acquire(False)
|
||||||
|
if lock_acquired:
|
||||||
|
return True
|
||||||
|
elif not blocking:
|
||||||
|
return False
|
||||||
|
|
||||||
|
release_event = threading.Event()
|
||||||
|
self._pending_threads.append(release_event)
|
||||||
|
|
||||||
|
release_event.wait()
|
||||||
|
return self._lock.acquire()
|
||||||
|
|
||||||
|
def release(self):
|
||||||
|
with self._inner_lock:
|
||||||
|
if self._pending_threads:
|
||||||
|
release_event = self._pending_threads.popleft()
|
||||||
|
release_event.set()
|
||||||
|
|
||||||
|
self._lock.release()
|
||||||
|
|
||||||
|
__enter__ = acquire
|
||||||
|
|
||||||
|
def __exit__(self, t, v, tb):
|
||||||
|
self.release()
|
||||||
@@ -6,7 +6,7 @@ import re
|
|||||||
|
|
||||||
import gradio as gr
|
import gradio as gr
|
||||||
from modules.paths import data_path
|
from modules.paths import data_path
|
||||||
from modules import shared, ui_tempdir, script_callbacks
|
from modules import shared, ui_tempdir, script_callbacks, processing
|
||||||
from PIL import Image
|
from PIL import Image
|
||||||
|
|
||||||
re_param_code = r'\s*([\w ]+):\s*("(?:\\"[^,]|\\"|\\|[^\"])+"|[^,]*)(?:,|$)'
|
re_param_code = r'\s*([\w ]+):\s*("(?:\\"[^,]|\\"|\\|[^\"])+"|[^,]*)(?:,|$)'
|
||||||
@@ -198,7 +198,6 @@ def restore_old_hires_fix_params(res):
|
|||||||
height = int(res.get("Size-2", 512))
|
height = int(res.get("Size-2", 512))
|
||||||
|
|
||||||
if firstpass_width == 0 or firstpass_height == 0:
|
if firstpass_width == 0 or firstpass_height == 0:
|
||||||
from modules import processing
|
|
||||||
firstpass_width, firstpass_height = processing.old_hires_fix_first_pass_dimensions(width, height)
|
firstpass_width, firstpass_height = processing.old_hires_fix_first_pass_dimensions(width, height)
|
||||||
|
|
||||||
res['Size-1'] = firstpass_width
|
res['Size-1'] = firstpass_width
|
||||||
@@ -317,34 +316,18 @@ Steps: 20, Sampler: Euler a, CFG scale: 7, Seed: 965400086, Size: 512x512, Model
|
|||||||
|
|
||||||
|
|
||||||
infotext_to_setting_name_mapping = [
|
infotext_to_setting_name_mapping = [
|
||||||
('Clip skip', 'CLIP_stop_at_last_layers', ),
|
|
||||||
|
]
|
||||||
|
"""Mapping of infotext labels to setting names. Only left for backwards compatibility - use OptionInfo(..., infotext='...') instead.
|
||||||
|
Example content:
|
||||||
|
|
||||||
|
infotext_to_setting_name_mapping = [
|
||||||
('Conditional mask weight', 'inpainting_mask_weight'),
|
('Conditional mask weight', 'inpainting_mask_weight'),
|
||||||
('Model hash', 'sd_model_checkpoint'),
|
('Model hash', 'sd_model_checkpoint'),
|
||||||
('ENSD', 'eta_noise_seed_delta'),
|
('ENSD', 'eta_noise_seed_delta'),
|
||||||
('Schedule type', 'k_sched_type'),
|
('Schedule type', 'k_sched_type'),
|
||||||
('Schedule max sigma', 'sigma_max'),
|
|
||||||
('Schedule min sigma', 'sigma_min'),
|
|
||||||
('Schedule rho', 'rho'),
|
|
||||||
('Noise multiplier', 'initial_noise_multiplier'),
|
|
||||||
('Eta', 'eta_ancestral'),
|
|
||||||
('Eta DDIM', 'eta_ddim'),
|
|
||||||
('Sigma churn', 's_churn'),
|
|
||||||
('Sigma tmin', 's_tmin'),
|
|
||||||
('Sigma tmax', 's_tmax'),
|
|
||||||
('Sigma noise', 's_noise'),
|
|
||||||
('Discard penultimate sigma', 'always_discard_next_to_last_sigma'),
|
|
||||||
('UniPC variant', 'uni_pc_variant'),
|
|
||||||
('UniPC skip type', 'uni_pc_skip_type'),
|
|
||||||
('UniPC order', 'uni_pc_order'),
|
|
||||||
('UniPC lower order final', 'uni_pc_lower_order_final'),
|
|
||||||
('Token merging ratio', 'token_merging_ratio'),
|
|
||||||
('Token merging ratio hr', 'token_merging_ratio_hr'),
|
|
||||||
('RNG', 'randn_source'),
|
|
||||||
('NGMS', 's_min_uncond'),
|
|
||||||
('Pad conds', 'pad_cond_uncond'),
|
|
||||||
('VAE Encoder', 'sd_vae_encode_method'),
|
|
||||||
('VAE Decoder', 'sd_vae_decode_method'),
|
|
||||||
]
|
]
|
||||||
|
"""
|
||||||
|
|
||||||
|
|
||||||
def create_override_settings_dict(text_pairs):
|
def create_override_settings_dict(text_pairs):
|
||||||
@@ -365,7 +348,8 @@ def create_override_settings_dict(text_pairs):
|
|||||||
|
|
||||||
params[k] = v.strip()
|
params[k] = v.strip()
|
||||||
|
|
||||||
for param_name, setting_name in infotext_to_setting_name_mapping:
|
mapping = [(info.infotext, k) for k, info in shared.opts.data_labels.items() if info.infotext]
|
||||||
|
for param_name, setting_name in mapping + infotext_to_setting_name_mapping:
|
||||||
value = params.get(param_name, None)
|
value = params.get(param_name, None)
|
||||||
|
|
||||||
if value is None:
|
if value is None:
|
||||||
@@ -414,10 +398,16 @@ def connect_paste(button, paste_fields, input_comp, override_settings_component,
|
|||||||
return res
|
return res
|
||||||
|
|
||||||
if override_settings_component is not None:
|
if override_settings_component is not None:
|
||||||
|
already_handled_fields = {key: 1 for _, key in paste_fields}
|
||||||
|
|
||||||
def paste_settings(params):
|
def paste_settings(params):
|
||||||
vals = {}
|
vals = {}
|
||||||
|
|
||||||
for param_name, setting_name in infotext_to_setting_name_mapping:
|
mapping = [(info.infotext, k) for k, info in shared.opts.data_labels.items() if info.infotext]
|
||||||
|
for param_name, setting_name in mapping + infotext_to_setting_name_mapping:
|
||||||
|
if param_name in already_handled_fields:
|
||||||
|
continue
|
||||||
|
|
||||||
v = params.get(param_name, None)
|
v = params.get(param_name, None)
|
||||||
if v is None:
|
if v is None:
|
||||||
continue
|
continue
|
||||||
|
|||||||
@@ -1,6 +1,7 @@
|
|||||||
import gradio as gr
|
import gradio as gr
|
||||||
|
|
||||||
from modules import scripts
|
from modules import scripts, ui_tempdir, patches
|
||||||
|
|
||||||
|
|
||||||
def add_classes_to_gradio_component(comp):
|
def add_classes_to_gradio_component(comp):
|
||||||
"""
|
"""
|
||||||
@@ -40,6 +41,8 @@ def Block_get_config(self):
|
|||||||
if webui_tooltip:
|
if webui_tooltip:
|
||||||
config["webui_tooltip"] = webui_tooltip
|
config["webui_tooltip"] = webui_tooltip
|
||||||
|
|
||||||
|
config.pop('example_inputs', None)
|
||||||
|
|
||||||
return config
|
return config
|
||||||
|
|
||||||
|
|
||||||
@@ -51,10 +54,20 @@ def BlockContext_init(self, *args, **kwargs):
|
|||||||
return res
|
return res
|
||||||
|
|
||||||
|
|
||||||
original_IOComponent_init = gr.components.IOComponent.__init__
|
def Blocks_get_config_file(self, *args, **kwargs):
|
||||||
original_Block_get_config = gr.blocks.Block.get_config
|
config = original_Blocks_get_config_file(self, *args, **kwargs)
|
||||||
original_BlockContext_init = gr.blocks.BlockContext.__init__
|
|
||||||
|
|
||||||
gr.components.IOComponent.__init__ = IOComponent_init
|
for comp_config in config["components"]:
|
||||||
gr.blocks.Block.get_config = Block_get_config
|
if "example_inputs" in comp_config:
|
||||||
gr.blocks.BlockContext.__init__ = BlockContext_init
|
comp_config["example_inputs"] = {"serialized": []}
|
||||||
|
|
||||||
|
return config
|
||||||
|
|
||||||
|
|
||||||
|
original_IOComponent_init = patches.patch(__name__, obj=gr.components.IOComponent, field="__init__", replacement=IOComponent_init)
|
||||||
|
original_Block_get_config = patches.patch(__name__, obj=gr.blocks.Block, field="get_config", replacement=Block_get_config)
|
||||||
|
original_BlockContext_init = patches.patch(__name__, obj=gr.blocks.BlockContext, field="__init__", replacement=BlockContext_init)
|
||||||
|
original_Blocks_get_config_file = patches.patch(__name__, obj=gr.blocks.Blocks, field="get_config_file", replacement=Blocks_get_config_file)
|
||||||
|
|
||||||
|
|
||||||
|
ui_tempdir.install_ui_tempdir_override()
|
||||||
|
|||||||
+34
-14
@@ -21,8 +21,6 @@ from modules import sd_samplers, shared, script_callbacks, errors
|
|||||||
from modules.paths_internal import roboto_ttf_file
|
from modules.paths_internal import roboto_ttf_file
|
||||||
from modules.shared import opts
|
from modules.shared import opts
|
||||||
|
|
||||||
import modules.sd_vae as sd_vae
|
|
||||||
|
|
||||||
LANCZOS = (Image.Resampling.LANCZOS if hasattr(Image, 'Resampling') else Image.LANCZOS)
|
LANCZOS = (Image.Resampling.LANCZOS if hasattr(Image, 'Resampling') else Image.LANCZOS)
|
||||||
|
|
||||||
|
|
||||||
@@ -342,16 +340,6 @@ def sanitize_filename_part(text, replace_spaces=True):
|
|||||||
|
|
||||||
|
|
||||||
class FilenameGenerator:
|
class FilenameGenerator:
|
||||||
def get_vae_filename(self): #get the name of the VAE file.
|
|
||||||
if sd_vae.loaded_vae_file is None:
|
|
||||||
return "NoneType"
|
|
||||||
file_name = os.path.basename(sd_vae.loaded_vae_file)
|
|
||||||
split_file_name = file_name.split('.')
|
|
||||||
if len(split_file_name) > 1 and split_file_name[0] == '':
|
|
||||||
return split_file_name[1] # if the first character of the filename is "." then [1] is obtained.
|
|
||||||
else:
|
|
||||||
return split_file_name[0]
|
|
||||||
|
|
||||||
replacements = {
|
replacements = {
|
||||||
'seed': lambda self: self.seed if self.seed is not None else '',
|
'seed': lambda self: self.seed if self.seed is not None else '',
|
||||||
'seed_first': lambda self: self.seed if self.p.batch_size == 1 else self.p.all_seeds[0],
|
'seed_first': lambda self: self.seed if self.p.batch_size == 1 else self.p.all_seeds[0],
|
||||||
@@ -367,7 +355,9 @@ class FilenameGenerator:
|
|||||||
'date': lambda self: datetime.datetime.now().strftime('%Y-%m-%d'),
|
'date': lambda self: datetime.datetime.now().strftime('%Y-%m-%d'),
|
||||||
'datetime': lambda self, *args: self.datetime(*args), # accepts formats: [datetime], [datetime<Format>], [datetime<Format><Time Zone>]
|
'datetime': lambda self, *args: self.datetime(*args), # accepts formats: [datetime], [datetime<Format>], [datetime<Format><Time Zone>]
|
||||||
'job_timestamp': lambda self: getattr(self.p, "job_timestamp", shared.state.job_timestamp),
|
'job_timestamp': lambda self: getattr(self.p, "job_timestamp", shared.state.job_timestamp),
|
||||||
'prompt_hash': lambda self: hashlib.sha256(self.prompt.encode()).hexdigest()[0:8],
|
'prompt_hash': lambda self, *args: self.string_hash(self.prompt, *args),
|
||||||
|
'negative_prompt_hash': lambda self, *args: self.string_hash(self.p.negative_prompt, *args),
|
||||||
|
'full_prompt_hash': lambda self, *args: self.string_hash(f"{self.p.prompt} {self.p.negative_prompt}", *args), # a space in between to create a unique string
|
||||||
'prompt': lambda self: sanitize_filename_part(self.prompt),
|
'prompt': lambda self: sanitize_filename_part(self.prompt),
|
||||||
'prompt_no_styles': lambda self: self.prompt_no_style(),
|
'prompt_no_styles': lambda self: self.prompt_no_style(),
|
||||||
'prompt_spaces': lambda self: sanitize_filename_part(self.prompt, replace_spaces=False),
|
'prompt_spaces': lambda self: sanitize_filename_part(self.prompt, replace_spaces=False),
|
||||||
@@ -380,7 +370,8 @@ class FilenameGenerator:
|
|||||||
'denoising': lambda self: self.p.denoising_strength if self.p and self.p.denoising_strength else NOTHING_AND_SKIP_PREVIOUS_TEXT,
|
'denoising': lambda self: self.p.denoising_strength if self.p and self.p.denoising_strength else NOTHING_AND_SKIP_PREVIOUS_TEXT,
|
||||||
'user': lambda self: self.p.user,
|
'user': lambda self: self.p.user,
|
||||||
'vae_filename': lambda self: self.get_vae_filename(),
|
'vae_filename': lambda self: self.get_vae_filename(),
|
||||||
'none': lambda self: '', # Overrides the default so you can get just the sequence number
|
'none': lambda self: '', # Overrides the default, so you can get just the sequence number
|
||||||
|
'image_hash': lambda self, *args: self.image_hash(*args) # accepts formats: [image_hash<length>] default full hash
|
||||||
}
|
}
|
||||||
default_time_format = '%Y%m%d%H%M%S'
|
default_time_format = '%Y%m%d%H%M%S'
|
||||||
|
|
||||||
@@ -391,6 +382,22 @@ class FilenameGenerator:
|
|||||||
self.image = image
|
self.image = image
|
||||||
self.zip = zip
|
self.zip = zip
|
||||||
|
|
||||||
|
def get_vae_filename(self):
|
||||||
|
"""Get the name of the VAE file."""
|
||||||
|
|
||||||
|
import modules.sd_vae as sd_vae
|
||||||
|
|
||||||
|
if sd_vae.loaded_vae_file is None:
|
||||||
|
return "NoneType"
|
||||||
|
|
||||||
|
file_name = os.path.basename(sd_vae.loaded_vae_file)
|
||||||
|
split_file_name = file_name.split('.')
|
||||||
|
if len(split_file_name) > 1 and split_file_name[0] == '':
|
||||||
|
return split_file_name[1] # if the first character of the filename is "." then [1] is obtained.
|
||||||
|
else:
|
||||||
|
return split_file_name[0]
|
||||||
|
|
||||||
|
|
||||||
def hasprompt(self, *args):
|
def hasprompt(self, *args):
|
||||||
lower = self.prompt.lower()
|
lower = self.prompt.lower()
|
||||||
if self.p is None or self.prompt is None:
|
if self.p is None or self.prompt is None:
|
||||||
@@ -444,6 +451,14 @@ class FilenameGenerator:
|
|||||||
|
|
||||||
return sanitize_filename_part(formatted_time, replace_spaces=False)
|
return sanitize_filename_part(formatted_time, replace_spaces=False)
|
||||||
|
|
||||||
|
def image_hash(self, *args):
|
||||||
|
length = int(args[0]) if (args and args[0] != "") else None
|
||||||
|
return hashlib.sha256(self.image.tobytes()).hexdigest()[0:length]
|
||||||
|
|
||||||
|
def string_hash(self, text, *args):
|
||||||
|
length = int(args[0]) if (args and args[0] != "") else 8
|
||||||
|
return hashlib.sha256(text.encode()).hexdigest()[0:length]
|
||||||
|
|
||||||
def apply(self, x):
|
def apply(self, x):
|
||||||
res = ''
|
res = ''
|
||||||
|
|
||||||
@@ -585,6 +600,11 @@ def save_image(image, path, basename, seed=None, prompt=None, extension='png', i
|
|||||||
"""
|
"""
|
||||||
namegen = FilenameGenerator(p, seed, prompt, image)
|
namegen = FilenameGenerator(p, seed, prompt, image)
|
||||||
|
|
||||||
|
# WebP and JPG formats have maximum dimension limits of 16383 and 65535 respectively. switch to PNG which has a much higher limit
|
||||||
|
if (image.height > 65535 or image.width > 65535) and extension.lower() in ("jpg", "jpeg") or (image.height > 16383 or image.width > 16383) and extension.lower() == "webp":
|
||||||
|
print('Image dimensions too large; saving as PNG')
|
||||||
|
extension = ".png"
|
||||||
|
|
||||||
if save_to_dirs is None:
|
if save_to_dirs is None:
|
||||||
save_to_dirs = (grid and opts.grid_save_to_dirs) or (not grid and opts.save_to_dirs and not no_prompt)
|
save_to_dirs = (grid and opts.grid_save_to_dirs) or (not grid and opts.save_to_dirs and not no_prompt)
|
||||||
|
|
||||||
|
|||||||
+6
-16
@@ -6,7 +6,7 @@ import numpy as np
|
|||||||
from PIL import Image, ImageOps, ImageFilter, ImageEnhance, UnidentifiedImageError
|
from PIL import Image, ImageOps, ImageFilter, ImageEnhance, UnidentifiedImageError
|
||||||
import gradio as gr
|
import gradio as gr
|
||||||
|
|
||||||
from modules import sd_samplers, images as imgutil
|
from modules import images as imgutil
|
||||||
from modules.generation_parameters_copypaste import create_override_settings_dict, parse_generation_parameters
|
from modules.generation_parameters_copypaste import create_override_settings_dict, parse_generation_parameters
|
||||||
from modules.processing import Processed, StableDiffusionProcessingImg2Img, process_images
|
from modules.processing import Processed, StableDiffusionProcessingImg2Img, process_images
|
||||||
from modules.shared import opts, state
|
from modules.shared import opts, state
|
||||||
@@ -116,21 +116,20 @@ def process_batch(p, input_dir, output_dir, inpaint_mask_dir, args, to_scale=Fal
|
|||||||
process_images(p)
|
process_images(p)
|
||||||
|
|
||||||
|
|
||||||
def img2img(id_task: str, mode: int, prompt: str, negative_prompt: str, prompt_styles, init_img, sketch, init_img_with_mask, inpaint_color_sketch, inpaint_color_sketch_orig, init_img_inpaint, init_mask_inpaint, steps: int, sampler_index: int, mask_blur: int, mask_alpha: float, inpainting_fill: int, restore_faces: bool, tiling: bool, n_iter: int, batch_size: int, cfg_scale: float, image_cfg_scale: float, denoising_strength: float, seed: int, subseed: int, subseed_strength: float, seed_resize_from_h: int, seed_resize_from_w: int, seed_enable_extras: bool, selected_scale_tab: int, height: int, width: int, scale_by: float, resize_mode: int, inpaint_full_res: bool, inpaint_full_res_padding: int, inpainting_mask_invert: int, img2img_batch_input_dir: str, img2img_batch_output_dir: str, img2img_batch_inpaint_mask_dir: str, override_settings_texts, img2img_batch_use_png_info: bool, img2img_batch_png_info_props: list, img2img_batch_png_info_dir: str, request: gr.Request, *args):
|
def img2img(id_task: str, mode: int, prompt: str, negative_prompt: str, prompt_styles, init_img, sketch, init_img_with_mask, inpaint_color_sketch, inpaint_color_sketch_orig, init_img_inpaint, init_mask_inpaint, steps: int, sampler_name: str, mask_blur: int, mask_alpha: float, inpainting_fill: int, n_iter: int, batch_size: int, cfg_scale: float, image_cfg_scale: float, denoising_strength: float, selected_scale_tab: int, height: int, width: int, scale_by: float, resize_mode: int, inpaint_full_res: bool, inpaint_full_res_padding: int, inpainting_mask_invert: int, img2img_batch_input_dir: str, img2img_batch_output_dir: str, img2img_batch_inpaint_mask_dir: str, override_settings_texts, img2img_batch_use_png_info: bool, img2img_batch_png_info_props: list, img2img_batch_png_info_dir: str, request: gr.Request, *args):
|
||||||
override_settings = create_override_settings_dict(override_settings_texts)
|
override_settings = create_override_settings_dict(override_settings_texts)
|
||||||
|
|
||||||
is_batch = mode == 5
|
is_batch = mode == 5
|
||||||
|
|
||||||
if mode == 0: # img2img
|
if mode == 0: # img2img
|
||||||
image = init_img.convert("RGB")
|
image = init_img
|
||||||
mask = None
|
mask = None
|
||||||
elif mode == 1: # img2img sketch
|
elif mode == 1: # img2img sketch
|
||||||
image = sketch.convert("RGB")
|
image = sketch
|
||||||
mask = None
|
mask = None
|
||||||
elif mode == 2: # inpaint
|
elif mode == 2: # inpaint
|
||||||
image, mask = init_img_with_mask["image"], init_img_with_mask["mask"]
|
image, mask = init_img_with_mask["image"], init_img_with_mask["mask"]
|
||||||
mask = mask.split()[-1].convert("L").point(lambda x: 255 if x > 128 else 0)
|
mask = processing.create_binary_mask(mask)
|
||||||
image = image.convert("RGB")
|
|
||||||
elif mode == 3: # inpaint sketch
|
elif mode == 3: # inpaint sketch
|
||||||
image = inpaint_color_sketch
|
image = inpaint_color_sketch
|
||||||
orig = inpaint_color_sketch_orig or inpaint_color_sketch
|
orig = inpaint_color_sketch_orig or inpaint_color_sketch
|
||||||
@@ -139,7 +138,6 @@ def img2img(id_task: str, mode: int, prompt: str, negative_prompt: str, prompt_s
|
|||||||
mask = ImageEnhance.Brightness(mask).enhance(1 - mask_alpha / 100)
|
mask = ImageEnhance.Brightness(mask).enhance(1 - mask_alpha / 100)
|
||||||
blur = ImageFilter.GaussianBlur(mask_blur)
|
blur = ImageFilter.GaussianBlur(mask_blur)
|
||||||
image = Image.composite(image.filter(blur), orig, mask.filter(blur))
|
image = Image.composite(image.filter(blur), orig, mask.filter(blur))
|
||||||
image = image.convert("RGB")
|
|
||||||
elif mode == 4: # inpaint upload mask
|
elif mode == 4: # inpaint upload mask
|
||||||
image = init_img_inpaint
|
image = init_img_inpaint
|
||||||
mask = init_mask_inpaint
|
mask = init_mask_inpaint
|
||||||
@@ -166,21 +164,13 @@ def img2img(id_task: str, mode: int, prompt: str, negative_prompt: str, prompt_s
|
|||||||
prompt=prompt,
|
prompt=prompt,
|
||||||
negative_prompt=negative_prompt,
|
negative_prompt=negative_prompt,
|
||||||
styles=prompt_styles,
|
styles=prompt_styles,
|
||||||
seed=seed,
|
sampler_name=sampler_name,
|
||||||
subseed=subseed,
|
|
||||||
subseed_strength=subseed_strength,
|
|
||||||
seed_resize_from_h=seed_resize_from_h,
|
|
||||||
seed_resize_from_w=seed_resize_from_w,
|
|
||||||
seed_enable_extras=seed_enable_extras,
|
|
||||||
sampler_name=sd_samplers.samplers_for_img2img[sampler_index].name,
|
|
||||||
batch_size=batch_size,
|
batch_size=batch_size,
|
||||||
n_iter=n_iter,
|
n_iter=n_iter,
|
||||||
steps=steps,
|
steps=steps,
|
||||||
cfg_scale=cfg_scale,
|
cfg_scale=cfg_scale,
|
||||||
width=width,
|
width=width,
|
||||||
height=height,
|
height=height,
|
||||||
restore_faces=restore_faces,
|
|
||||||
tiling=tiling,
|
|
||||||
init_images=[image],
|
init_images=[image],
|
||||||
mask=mask,
|
mask=mask,
|
||||||
mask_blur=mask_blur,
|
mask_blur=mask_blur,
|
||||||
|
|||||||
@@ -0,0 +1,168 @@
|
|||||||
|
import importlib
|
||||||
|
import logging
|
||||||
|
import sys
|
||||||
|
import warnings
|
||||||
|
from threading import Thread
|
||||||
|
|
||||||
|
from modules.timer import startup_timer
|
||||||
|
|
||||||
|
|
||||||
|
def imports():
|
||||||
|
logging.getLogger("torch.distributed.nn").setLevel(logging.ERROR) # sshh...
|
||||||
|
logging.getLogger("xformers").addFilter(lambda record: 'A matching Triton is not available' not in record.getMessage())
|
||||||
|
|
||||||
|
import torch # noqa: F401
|
||||||
|
startup_timer.record("import torch")
|
||||||
|
import pytorch_lightning # noqa: F401
|
||||||
|
startup_timer.record("import torch")
|
||||||
|
warnings.filterwarnings(action="ignore", category=DeprecationWarning, module="pytorch_lightning")
|
||||||
|
warnings.filterwarnings(action="ignore", category=UserWarning, module="torchvision")
|
||||||
|
|
||||||
|
import gradio # noqa: F401
|
||||||
|
startup_timer.record("import gradio")
|
||||||
|
|
||||||
|
from modules import paths, timer, import_hook, errors # noqa: F401
|
||||||
|
startup_timer.record("setup paths")
|
||||||
|
|
||||||
|
import ldm.modules.encoders.modules # noqa: F401
|
||||||
|
startup_timer.record("import ldm")
|
||||||
|
|
||||||
|
import sgm.modules.encoders.modules # noqa: F401
|
||||||
|
startup_timer.record("import sgm")
|
||||||
|
|
||||||
|
from modules import shared_init
|
||||||
|
shared_init.initialize()
|
||||||
|
startup_timer.record("initialize shared")
|
||||||
|
|
||||||
|
from modules import processing, gradio_extensons, ui # noqa: F401
|
||||||
|
startup_timer.record("other imports")
|
||||||
|
|
||||||
|
|
||||||
|
def check_versions():
|
||||||
|
from modules.shared_cmd_options import cmd_opts
|
||||||
|
|
||||||
|
if not cmd_opts.skip_version_check:
|
||||||
|
from modules import errors
|
||||||
|
errors.check_versions()
|
||||||
|
|
||||||
|
|
||||||
|
def initialize():
|
||||||
|
from modules import initialize_util
|
||||||
|
initialize_util.fix_torch_version()
|
||||||
|
initialize_util.fix_asyncio_event_loop_policy()
|
||||||
|
initialize_util.validate_tls_options()
|
||||||
|
initialize_util.configure_sigint_handler()
|
||||||
|
initialize_util.configure_opts_onchange()
|
||||||
|
|
||||||
|
from modules import modelloader
|
||||||
|
modelloader.cleanup_models()
|
||||||
|
|
||||||
|
from modules import sd_models
|
||||||
|
sd_models.setup_model()
|
||||||
|
startup_timer.record("setup SD model")
|
||||||
|
|
||||||
|
from modules.shared_cmd_options import cmd_opts
|
||||||
|
|
||||||
|
from modules import codeformer_model
|
||||||
|
warnings.filterwarnings(action="ignore", category=UserWarning, module="torchvision.transforms.functional_tensor")
|
||||||
|
codeformer_model.setup_model(cmd_opts.codeformer_models_path)
|
||||||
|
startup_timer.record("setup codeformer")
|
||||||
|
|
||||||
|
from modules import gfpgan_model
|
||||||
|
gfpgan_model.setup_model(cmd_opts.gfpgan_models_path)
|
||||||
|
startup_timer.record("setup gfpgan")
|
||||||
|
|
||||||
|
initialize_rest(reload_script_modules=False)
|
||||||
|
|
||||||
|
|
||||||
|
def initialize_rest(*, reload_script_modules=False):
|
||||||
|
"""
|
||||||
|
Called both from initialize() and when reloading the webui.
|
||||||
|
"""
|
||||||
|
from modules.shared_cmd_options import cmd_opts
|
||||||
|
|
||||||
|
from modules import sd_samplers
|
||||||
|
sd_samplers.set_samplers()
|
||||||
|
startup_timer.record("set samplers")
|
||||||
|
|
||||||
|
from modules import extensions
|
||||||
|
extensions.list_extensions()
|
||||||
|
startup_timer.record("list extensions")
|
||||||
|
|
||||||
|
from modules import initialize_util
|
||||||
|
initialize_util.restore_config_state_file()
|
||||||
|
startup_timer.record("restore config state file")
|
||||||
|
|
||||||
|
from modules import shared, upscaler, scripts
|
||||||
|
if cmd_opts.ui_debug_mode:
|
||||||
|
shared.sd_upscalers = upscaler.UpscalerLanczos().scalers
|
||||||
|
scripts.load_scripts()
|
||||||
|
return
|
||||||
|
|
||||||
|
from modules import sd_models
|
||||||
|
sd_models.list_models()
|
||||||
|
startup_timer.record("list SD models")
|
||||||
|
|
||||||
|
from modules import localization
|
||||||
|
localization.list_localizations(cmd_opts.localizations_dir)
|
||||||
|
startup_timer.record("list localizations")
|
||||||
|
|
||||||
|
with startup_timer.subcategory("load scripts"):
|
||||||
|
scripts.load_scripts()
|
||||||
|
|
||||||
|
if reload_script_modules:
|
||||||
|
for module in [module for name, module in sys.modules.items() if name.startswith("modules.ui")]:
|
||||||
|
importlib.reload(module)
|
||||||
|
startup_timer.record("reload script modules")
|
||||||
|
|
||||||
|
from modules import modelloader
|
||||||
|
modelloader.load_upscalers()
|
||||||
|
startup_timer.record("load upscalers")
|
||||||
|
|
||||||
|
from modules import sd_vae
|
||||||
|
sd_vae.refresh_vae_list()
|
||||||
|
startup_timer.record("refresh VAE")
|
||||||
|
|
||||||
|
from modules import textual_inversion
|
||||||
|
textual_inversion.textual_inversion.list_textual_inversion_templates()
|
||||||
|
startup_timer.record("refresh textual inversion templates")
|
||||||
|
|
||||||
|
from modules import script_callbacks, sd_hijack_optimizations, sd_hijack
|
||||||
|
script_callbacks.on_list_optimizers(sd_hijack_optimizations.list_optimizers)
|
||||||
|
sd_hijack.list_optimizers()
|
||||||
|
startup_timer.record("scripts list_optimizers")
|
||||||
|
|
||||||
|
from modules import sd_unet
|
||||||
|
sd_unet.list_unets()
|
||||||
|
startup_timer.record("scripts list_unets")
|
||||||
|
|
||||||
|
def load_model():
|
||||||
|
"""
|
||||||
|
Accesses shared.sd_model property to load model.
|
||||||
|
After it's available, if it has been loaded before this access by some extension,
|
||||||
|
its optimization may be None because the list of optimizaers has neet been filled
|
||||||
|
by that time, so we apply optimization again.
|
||||||
|
"""
|
||||||
|
|
||||||
|
shared.sd_model # noqa: B018
|
||||||
|
|
||||||
|
if sd_hijack.current_optimizer is None:
|
||||||
|
sd_hijack.apply_optimizations()
|
||||||
|
|
||||||
|
from modules import devices
|
||||||
|
devices.first_time_calculation()
|
||||||
|
|
||||||
|
Thread(target=load_model).start()
|
||||||
|
|
||||||
|
from modules import shared_items
|
||||||
|
shared_items.reload_hypernetworks()
|
||||||
|
startup_timer.record("reload hypernetworks")
|
||||||
|
|
||||||
|
from modules import ui_extra_networks
|
||||||
|
ui_extra_networks.initialize()
|
||||||
|
ui_extra_networks.register_default_pages()
|
||||||
|
|
||||||
|
from modules import extra_networks
|
||||||
|
extra_networks.initialize()
|
||||||
|
extra_networks.register_default_extra_networks()
|
||||||
|
startup_timer.record("initialize extra networks")
|
||||||
@@ -0,0 +1,202 @@
|
|||||||
|
import json
|
||||||
|
import os
|
||||||
|
import signal
|
||||||
|
import sys
|
||||||
|
import re
|
||||||
|
|
||||||
|
from modules.timer import startup_timer
|
||||||
|
|
||||||
|
|
||||||
|
def gradio_server_name():
|
||||||
|
from modules.shared_cmd_options import cmd_opts
|
||||||
|
|
||||||
|
if cmd_opts.server_name:
|
||||||
|
return cmd_opts.server_name
|
||||||
|
else:
|
||||||
|
return "0.0.0.0" if cmd_opts.listen else None
|
||||||
|
|
||||||
|
|
||||||
|
def fix_torch_version():
|
||||||
|
import torch
|
||||||
|
|
||||||
|
# Truncate version number of nightly/local build of PyTorch to not cause exceptions with CodeFormer or Safetensors
|
||||||
|
if ".dev" in torch.__version__ or "+git" in torch.__version__:
|
||||||
|
torch.__long_version__ = torch.__version__
|
||||||
|
torch.__version__ = re.search(r'[\d.]+[\d]', torch.__version__).group(0)
|
||||||
|
|
||||||
|
|
||||||
|
def fix_asyncio_event_loop_policy():
|
||||||
|
"""
|
||||||
|
The default `asyncio` event loop policy only automatically creates
|
||||||
|
event loops in the main threads. Other threads must create event
|
||||||
|
loops explicitly or `asyncio.get_event_loop` (and therefore
|
||||||
|
`.IOLoop.current`) will fail. Installing this policy allows event
|
||||||
|
loops to be created automatically on any thread, matching the
|
||||||
|
behavior of Tornado versions prior to 5.0 (or 5.0 on Python 2).
|
||||||
|
"""
|
||||||
|
|
||||||
|
import asyncio
|
||||||
|
|
||||||
|
if sys.platform == "win32" and hasattr(asyncio, "WindowsSelectorEventLoopPolicy"):
|
||||||
|
# "Any thread" and "selector" should be orthogonal, but there's not a clean
|
||||||
|
# interface for composing policies so pick the right base.
|
||||||
|
_BasePolicy = asyncio.WindowsSelectorEventLoopPolicy # type: ignore
|
||||||
|
else:
|
||||||
|
_BasePolicy = asyncio.DefaultEventLoopPolicy
|
||||||
|
|
||||||
|
class AnyThreadEventLoopPolicy(_BasePolicy): # type: ignore
|
||||||
|
"""Event loop policy that allows loop creation on any thread.
|
||||||
|
Usage::
|
||||||
|
|
||||||
|
asyncio.set_event_loop_policy(AnyThreadEventLoopPolicy())
|
||||||
|
"""
|
||||||
|
|
||||||
|
def get_event_loop(self) -> asyncio.AbstractEventLoop:
|
||||||
|
try:
|
||||||
|
return super().get_event_loop()
|
||||||
|
except (RuntimeError, AssertionError):
|
||||||
|
# This was an AssertionError in python 3.4.2 (which ships with debian jessie)
|
||||||
|
# and changed to a RuntimeError in 3.4.3.
|
||||||
|
# "There is no current event loop in thread %r"
|
||||||
|
loop = self.new_event_loop()
|
||||||
|
self.set_event_loop(loop)
|
||||||
|
return loop
|
||||||
|
|
||||||
|
asyncio.set_event_loop_policy(AnyThreadEventLoopPolicy())
|
||||||
|
|
||||||
|
|
||||||
|
def restore_config_state_file():
|
||||||
|
from modules import shared, config_states
|
||||||
|
|
||||||
|
config_state_file = shared.opts.restore_config_state_file
|
||||||
|
if config_state_file == "":
|
||||||
|
return
|
||||||
|
|
||||||
|
shared.opts.restore_config_state_file = ""
|
||||||
|
shared.opts.save(shared.config_filename)
|
||||||
|
|
||||||
|
if os.path.isfile(config_state_file):
|
||||||
|
print(f"*** About to restore extension state from file: {config_state_file}")
|
||||||
|
with open(config_state_file, "r", encoding="utf-8") as f:
|
||||||
|
config_state = json.load(f)
|
||||||
|
config_states.restore_extension_config(config_state)
|
||||||
|
startup_timer.record("restore extension config")
|
||||||
|
elif config_state_file:
|
||||||
|
print(f"!!! Config state backup not found: {config_state_file}")
|
||||||
|
|
||||||
|
|
||||||
|
def validate_tls_options():
|
||||||
|
from modules.shared_cmd_options import cmd_opts
|
||||||
|
|
||||||
|
if not (cmd_opts.tls_keyfile and cmd_opts.tls_certfile):
|
||||||
|
return
|
||||||
|
|
||||||
|
try:
|
||||||
|
if not os.path.exists(cmd_opts.tls_keyfile):
|
||||||
|
print("Invalid path to TLS keyfile given")
|
||||||
|
if not os.path.exists(cmd_opts.tls_certfile):
|
||||||
|
print(f"Invalid path to TLS certfile: '{cmd_opts.tls_certfile}'")
|
||||||
|
except TypeError:
|
||||||
|
cmd_opts.tls_keyfile = cmd_opts.tls_certfile = None
|
||||||
|
print("TLS setup invalid, running webui without TLS")
|
||||||
|
else:
|
||||||
|
print("Running with TLS")
|
||||||
|
startup_timer.record("TLS")
|
||||||
|
|
||||||
|
|
||||||
|
def get_gradio_auth_creds():
|
||||||
|
"""
|
||||||
|
Convert the gradio_auth and gradio_auth_path commandline arguments into
|
||||||
|
an iterable of (username, password) tuples.
|
||||||
|
"""
|
||||||
|
from modules.shared_cmd_options import cmd_opts
|
||||||
|
|
||||||
|
def process_credential_line(s):
|
||||||
|
s = s.strip()
|
||||||
|
if not s:
|
||||||
|
return None
|
||||||
|
return tuple(s.split(':', 1))
|
||||||
|
|
||||||
|
if cmd_opts.gradio_auth:
|
||||||
|
for cred in cmd_opts.gradio_auth.split(','):
|
||||||
|
cred = process_credential_line(cred)
|
||||||
|
if cred:
|
||||||
|
yield cred
|
||||||
|
|
||||||
|
if cmd_opts.gradio_auth_path:
|
||||||
|
with open(cmd_opts.gradio_auth_path, 'r', encoding="utf8") as file:
|
||||||
|
for line in file.readlines():
|
||||||
|
for cred in line.strip().split(','):
|
||||||
|
cred = process_credential_line(cred)
|
||||||
|
if cred:
|
||||||
|
yield cred
|
||||||
|
|
||||||
|
|
||||||
|
def dumpstacks():
|
||||||
|
import threading
|
||||||
|
import traceback
|
||||||
|
|
||||||
|
id2name = {th.ident: th.name for th in threading.enumerate()}
|
||||||
|
code = []
|
||||||
|
for threadId, stack in sys._current_frames().items():
|
||||||
|
code.append(f"\n# Thread: {id2name.get(threadId, '')}({threadId})")
|
||||||
|
for filename, lineno, name, line in traceback.extract_stack(stack):
|
||||||
|
code.append(f"""File: "{filename}", line {lineno}, in {name}""")
|
||||||
|
if line:
|
||||||
|
code.append(" " + line.strip())
|
||||||
|
|
||||||
|
print("\n".join(code))
|
||||||
|
|
||||||
|
|
||||||
|
def configure_sigint_handler():
|
||||||
|
# make the program just exit at ctrl+c without waiting for anything
|
||||||
|
def sigint_handler(sig, frame):
|
||||||
|
print(f'Interrupted with signal {sig} in {frame}')
|
||||||
|
|
||||||
|
dumpstacks()
|
||||||
|
|
||||||
|
os._exit(0)
|
||||||
|
|
||||||
|
if not os.environ.get("COVERAGE_RUN"):
|
||||||
|
# Don't install the immediate-quit handler when running under coverage,
|
||||||
|
# as then the coverage report won't be generated.
|
||||||
|
signal.signal(signal.SIGINT, sigint_handler)
|
||||||
|
|
||||||
|
|
||||||
|
def configure_opts_onchange():
|
||||||
|
from modules import shared, sd_models, sd_vae, ui_tempdir, sd_hijack
|
||||||
|
from modules.call_queue import wrap_queued_call
|
||||||
|
|
||||||
|
shared.opts.onchange("sd_model_checkpoint", wrap_queued_call(lambda: sd_models.reload_model_weights()), call=False)
|
||||||
|
shared.opts.onchange("sd_vae", wrap_queued_call(lambda: sd_vae.reload_vae_weights()), call=False)
|
||||||
|
shared.opts.onchange("sd_vae_overrides_per_model_preferences", wrap_queued_call(lambda: sd_vae.reload_vae_weights()), call=False)
|
||||||
|
shared.opts.onchange("temp_dir", ui_tempdir.on_tmpdir_changed)
|
||||||
|
shared.opts.onchange("gradio_theme", shared.reload_gradio_theme)
|
||||||
|
shared.opts.onchange("cross_attention_optimization", wrap_queued_call(lambda: sd_hijack.model_hijack.redo_hijack(shared.sd_model)), call=False)
|
||||||
|
startup_timer.record("opts onchange")
|
||||||
|
|
||||||
|
|
||||||
|
def setup_middleware(app):
|
||||||
|
from starlette.middleware.gzip import GZipMiddleware
|
||||||
|
|
||||||
|
app.middleware_stack = None # reset current middleware to allow modifying user provided list
|
||||||
|
app.add_middleware(GZipMiddleware, minimum_size=1000)
|
||||||
|
configure_cors_middleware(app)
|
||||||
|
app.build_middleware_stack() # rebuild middleware stack on-the-fly
|
||||||
|
|
||||||
|
|
||||||
|
def configure_cors_middleware(app):
|
||||||
|
from starlette.middleware.cors import CORSMiddleware
|
||||||
|
from modules.shared_cmd_options import cmd_opts
|
||||||
|
|
||||||
|
cors_options = {
|
||||||
|
"allow_methods": ["*"],
|
||||||
|
"allow_headers": ["*"],
|
||||||
|
"allow_credentials": True,
|
||||||
|
}
|
||||||
|
if cmd_opts.cors_allow_origins:
|
||||||
|
cors_options["allow_origins"] = cmd_opts.cors_allow_origins.split(',')
|
||||||
|
if cmd_opts.cors_allow_origins_regex:
|
||||||
|
cors_options["allow_origin_regex"] = cmd_opts.cors_allow_origins_regex
|
||||||
|
app.add_middleware(CORSMiddleware, **cors_options)
|
||||||
|
|
||||||
@@ -186,7 +186,6 @@ class InterrogateModels:
|
|||||||
res = ""
|
res = ""
|
||||||
shared.state.begin(job="interrogate")
|
shared.state.begin(job="interrogate")
|
||||||
try:
|
try:
|
||||||
if shared.cmd_opts.lowvram or shared.cmd_opts.medvram:
|
|
||||||
lowvram.send_everything_to_cpu()
|
lowvram.send_everything_to_cpu()
|
||||||
devices.torch_gc()
|
devices.torch_gc()
|
||||||
|
|
||||||
|
|||||||
+41
-13
@@ -1,7 +1,9 @@
|
|||||||
# this scripts installs necessary requirements and launches main program in webui.py
|
# this scripts installs necessary requirements and launches main program in webui.py
|
||||||
|
import logging
|
||||||
import re
|
import re
|
||||||
import subprocess
|
import subprocess
|
||||||
import os
|
import os
|
||||||
|
import shutil
|
||||||
import sys
|
import sys
|
||||||
import importlib.util
|
import importlib.util
|
||||||
import platform
|
import platform
|
||||||
@@ -11,8 +13,10 @@ from functools import lru_cache
|
|||||||
from modules import cmd_args, errors
|
from modules import cmd_args, errors
|
||||||
from modules.paths_internal import script_path, extensions_dir
|
from modules.paths_internal import script_path, extensions_dir
|
||||||
from modules.timer import startup_timer
|
from modules.timer import startup_timer
|
||||||
|
from modules import logging_config
|
||||||
|
|
||||||
args, _ = cmd_args.parser.parse_known_args()
|
args, _ = cmd_args.parser.parse_known_args()
|
||||||
|
logging_config.setup_logging(args.loglevel)
|
||||||
|
|
||||||
python = sys.executable
|
python = sys.executable
|
||||||
git = os.environ.get('GIT', "git")
|
git = os.environ.get('GIT', "git")
|
||||||
@@ -139,6 +143,25 @@ def check_run_python(code: str) -> bool:
|
|||||||
return result.returncode == 0
|
return result.returncode == 0
|
||||||
|
|
||||||
|
|
||||||
|
def git_fix_workspace(dir, name):
|
||||||
|
run(f'"{git}" -C "{dir}" fetch --refetch --no-auto-gc', f"Fetching all contents for {name}", f"Couldn't fetch {name}", live=True)
|
||||||
|
run(f'"{git}" -C "{dir}" gc --aggressive --prune=now', f"Pruning {name}", f"Couldn't prune {name}", live=True)
|
||||||
|
return
|
||||||
|
|
||||||
|
|
||||||
|
def run_git(dir, name, command, desc=None, errdesc=None, custom_env=None, live: bool = default_command_live, autofix=True):
|
||||||
|
try:
|
||||||
|
return run(f'"{git}" -C "{dir}" {command}', desc=desc, errdesc=errdesc, custom_env=custom_env, live=live)
|
||||||
|
except RuntimeError:
|
||||||
|
if not autofix:
|
||||||
|
raise
|
||||||
|
|
||||||
|
print(f"{errdesc}, attempting autofix...")
|
||||||
|
git_fix_workspace(dir, name)
|
||||||
|
|
||||||
|
return run(f'"{git}" -C "{dir}" {command}', desc=desc, errdesc=errdesc, custom_env=custom_env, live=live)
|
||||||
|
|
||||||
|
|
||||||
def git_clone(url, dir, name, commithash=None):
|
def git_clone(url, dir, name, commithash=None):
|
||||||
# TODO clone into temporary dir and move if successful
|
# TODO clone into temporary dir and move if successful
|
||||||
|
|
||||||
@@ -146,15 +169,24 @@ def git_clone(url, dir, name, commithash=None):
|
|||||||
if commithash is None:
|
if commithash is None:
|
||||||
return
|
return
|
||||||
|
|
||||||
current_hash = run(f'"{git}" -C "{dir}" rev-parse HEAD', None, f"Couldn't determine {name}'s hash: {commithash}", live=False).strip()
|
current_hash = run_git(dir, name, 'rev-parse HEAD', None, f"Couldn't determine {name}'s hash: {commithash}", live=False).strip()
|
||||||
if current_hash == commithash:
|
if current_hash == commithash:
|
||||||
return
|
return
|
||||||
|
|
||||||
run(f'"{git}" -C "{dir}" fetch', f"Fetching updates for {name}...", f"Couldn't fetch {name}")
|
if run_git(dir, name, 'config --get remote.origin.url', None, f"Couldn't determine {name}'s origin URL", live=False).strip() != url:
|
||||||
run(f'"{git}" -C "{dir}" checkout {commithash}', f"Checking out commit for {name} with hash: {commithash}...", f"Couldn't checkout commit {commithash} for {name}", live=True)
|
run_git(dir, name, f'remote set-url origin "{url}"', None, f"Failed to set {name}'s origin URL", live=False)
|
||||||
|
|
||||||
|
run_git(dir, name, 'fetch', f"Fetching updates for {name}...", f"Couldn't fetch {name}", autofix=False)
|
||||||
|
|
||||||
|
run_git(dir, name, f'checkout {commithash}', f"Checking out commit for {name} with hash: {commithash}...", f"Couldn't checkout commit {commithash} for {name}", live=True)
|
||||||
|
|
||||||
return
|
return
|
||||||
|
|
||||||
|
try:
|
||||||
run(f'"{git}" clone "{url}" "{dir}"', f"Cloning {name} into {dir}...", f"Couldn't clone {name}", live=True)
|
run(f'"{git}" clone "{url}" "{dir}"', f"Cloning {name} into {dir}...", f"Couldn't clone {name}", live=True)
|
||||||
|
except RuntimeError:
|
||||||
|
shutil.rmtree(dir, ignore_errors=True)
|
||||||
|
raise
|
||||||
|
|
||||||
if commithash is not None:
|
if commithash is not None:
|
||||||
run(f'"{git}" -C "{dir}" checkout {commithash}', None, "Couldn't checkout {name}'s hash: {commithash}")
|
run(f'"{git}" -C "{dir}" checkout {commithash}', None, "Couldn't checkout {name}'s hash: {commithash}")
|
||||||
@@ -214,7 +246,7 @@ def list_extensions(settings_file):
|
|||||||
disabled_extensions = set(settings.get('disabled_extensions', []))
|
disabled_extensions = set(settings.get('disabled_extensions', []))
|
||||||
disable_all_extensions = settings.get('disable_all_extensions', 'none')
|
disable_all_extensions = settings.get('disable_all_extensions', 'none')
|
||||||
|
|
||||||
if disable_all_extensions != 'none':
|
if disable_all_extensions != 'none' or args.disable_extra_extensions or args.disable_all_extensions or not os.path.isdir(extensions_dir):
|
||||||
return []
|
return []
|
||||||
|
|
||||||
return [x for x in os.listdir(extensions_dir) if x not in disabled_extensions]
|
return [x for x in os.listdir(extensions_dir) if x not in disabled_extensions]
|
||||||
@@ -226,6 +258,8 @@ def run_extensions_installers(settings_file):
|
|||||||
|
|
||||||
with startup_timer.subcategory("run extensions installers"):
|
with startup_timer.subcategory("run extensions installers"):
|
||||||
for dirname_extension in list_extensions(settings_file):
|
for dirname_extension in list_extensions(settings_file):
|
||||||
|
logging.debug(f"Installing {dirname_extension}")
|
||||||
|
|
||||||
path = os.path.join(extensions_dir, dirname_extension)
|
path = os.path.join(extensions_dir, dirname_extension)
|
||||||
|
|
||||||
if os.path.isdir(path):
|
if os.path.isdir(path):
|
||||||
@@ -277,7 +311,6 @@ def prepare_environment():
|
|||||||
requirements_file = os.environ.get('REQS_FILE', "requirements_versions.txt")
|
requirements_file = os.environ.get('REQS_FILE', "requirements_versions.txt")
|
||||||
|
|
||||||
xformers_package = os.environ.get('XFORMERS_PACKAGE', 'xformers==0.0.20')
|
xformers_package = os.environ.get('XFORMERS_PACKAGE', 'xformers==0.0.20')
|
||||||
gfpgan_package = os.environ.get('GFPGAN_PACKAGE', "https://github.com/TencentARC/GFPGAN/archive/8d2447a2d918f8eba5a4a01463fd48e45126a379.zip")
|
|
||||||
clip_package = os.environ.get('CLIP_PACKAGE', "https://github.com/openai/CLIP/archive/d50d76daa670286dd6cacf3bcd80b5e4823fc8e1.zip")
|
clip_package = os.environ.get('CLIP_PACKAGE', "https://github.com/openai/CLIP/archive/d50d76daa670286dd6cacf3bcd80b5e4823fc8e1.zip")
|
||||||
openclip_package = os.environ.get('OPENCLIP_PACKAGE', "https://github.com/mlfoundations/open_clip/archive/bb6e834e9c70d9c27d0dc3ecedeebeaeb1ffad6b.zip")
|
openclip_package = os.environ.get('OPENCLIP_PACKAGE', "https://github.com/mlfoundations/open_clip/archive/bb6e834e9c70d9c27d0dc3ecedeebeaeb1ffad6b.zip")
|
||||||
|
|
||||||
@@ -288,13 +321,13 @@ def prepare_environment():
|
|||||||
blip_repo = os.environ.get('BLIP_REPO', 'https://github.com/salesforce/BLIP.git')
|
blip_repo = os.environ.get('BLIP_REPO', 'https://github.com/salesforce/BLIP.git')
|
||||||
|
|
||||||
stable_diffusion_commit_hash = os.environ.get('STABLE_DIFFUSION_COMMIT_HASH', "cf1d67a6fd5ea1aa600c4df58e5b47da45f6bdbf")
|
stable_diffusion_commit_hash = os.environ.get('STABLE_DIFFUSION_COMMIT_HASH', "cf1d67a6fd5ea1aa600c4df58e5b47da45f6bdbf")
|
||||||
stable_diffusion_xl_commit_hash = os.environ.get('STABLE_DIFFUSION_XL_COMMIT_HASH', "5c10deee76adad0032b412294130090932317a87")
|
stable_diffusion_xl_commit_hash = os.environ.get('STABLE_DIFFUSION_XL_COMMIT_HASH', "45c443b316737a4ab6e40413d7794a7f5657c19f")
|
||||||
k_diffusion_commit_hash = os.environ.get('K_DIFFUSION_COMMIT_HASH', "c9fe758757e022f05ca5a53fa8fac28889e4f1cf")
|
k_diffusion_commit_hash = os.environ.get('K_DIFFUSION_COMMIT_HASH', "ab527a9a6d347f364e3d185ba6d714e22d80cb3c")
|
||||||
codeformer_commit_hash = os.environ.get('CODEFORMER_COMMIT_HASH', "c5b4593074ba6214284d6acd5f1719b6c5d739af")
|
codeformer_commit_hash = os.environ.get('CODEFORMER_COMMIT_HASH', "c5b4593074ba6214284d6acd5f1719b6c5d739af")
|
||||||
blip_commit_hash = os.environ.get('BLIP_COMMIT_HASH', "48211a1594f1321b00f14c9f7a5b4813144b2fb9")
|
blip_commit_hash = os.environ.get('BLIP_COMMIT_HASH', "48211a1594f1321b00f14c9f7a5b4813144b2fb9")
|
||||||
|
|
||||||
try:
|
try:
|
||||||
# the existance of this file is a signal to webui.sh/bat that webui needs to be restarted when it stops execution
|
# the existence of this file is a signal to webui.sh/bat that webui needs to be restarted when it stops execution
|
||||||
os.remove(os.path.join(script_path, "tmp", "restart"))
|
os.remove(os.path.join(script_path, "tmp", "restart"))
|
||||||
os.environ.setdefault('SD_WEBUI_RESTARTING', '1')
|
os.environ.setdefault('SD_WEBUI_RESTARTING', '1')
|
||||||
except OSError:
|
except OSError:
|
||||||
@@ -324,11 +357,6 @@ def prepare_environment():
|
|||||||
)
|
)
|
||||||
startup_timer.record("torch GPU test")
|
startup_timer.record("torch GPU test")
|
||||||
|
|
||||||
|
|
||||||
if not is_installed("gfpgan"):
|
|
||||||
run_pip(f"install {gfpgan_package}", "gfpgan")
|
|
||||||
startup_timer.record("install gfpgan")
|
|
||||||
|
|
||||||
if not is_installed("clip"):
|
if not is_installed("clip"):
|
||||||
run_pip(f"install {clip_package}", "clip")
|
run_pip(f"install {clip_package}", "clip")
|
||||||
startup_timer.record("install clip")
|
startup_timer.record("install clip")
|
||||||
|
|||||||
@@ -1,7 +1,7 @@
|
|||||||
import json
|
import json
|
||||||
import os
|
import os
|
||||||
|
|
||||||
from modules import errors
|
from modules import errors, scripts
|
||||||
|
|
||||||
localizations = {}
|
localizations = {}
|
||||||
|
|
||||||
@@ -16,7 +16,6 @@ def list_localizations(dirname):
|
|||||||
|
|
||||||
localizations[fn] = os.path.join(dirname, file)
|
localizations[fn] = os.path.join(dirname, file)
|
||||||
|
|
||||||
from modules import scripts
|
|
||||||
for file in scripts.list_scripts("localizations", ".json"):
|
for file in scripts.list_scripts("localizations", ".json"):
|
||||||
fn, ext = os.path.splitext(file.filename)
|
fn, ext = os.path.splitext(file.filename)
|
||||||
localizations[fn] = file.path
|
localizations[fn] = file.path
|
||||||
|
|||||||
@@ -0,0 +1,16 @@
|
|||||||
|
import os
|
||||||
|
import logging
|
||||||
|
|
||||||
|
|
||||||
|
def setup_logging(loglevel):
|
||||||
|
if loglevel is None:
|
||||||
|
loglevel = os.environ.get("SD_WEBUI_LOG_LEVEL")
|
||||||
|
|
||||||
|
if loglevel:
|
||||||
|
log_level = getattr(logging, loglevel.upper(), None) or logging.INFO
|
||||||
|
logging.basicConfig(
|
||||||
|
level=log_level,
|
||||||
|
format='%(asctime)s %(levelname)s [%(name)s] %(message)s',
|
||||||
|
datefmt='%Y-%m-%d %H:%M:%S',
|
||||||
|
)
|
||||||
|
|
||||||
+16
-2
@@ -1,5 +1,5 @@
|
|||||||
import torch
|
import torch
|
||||||
from modules import devices
|
from modules import devices, shared
|
||||||
|
|
||||||
module_in_gpu = None
|
module_in_gpu = None
|
||||||
cpu = torch.device("cpu")
|
cpu = torch.device("cpu")
|
||||||
@@ -14,6 +14,20 @@ def send_everything_to_cpu():
|
|||||||
module_in_gpu = None
|
module_in_gpu = None
|
||||||
|
|
||||||
|
|
||||||
|
def is_needed(sd_model):
|
||||||
|
return shared.cmd_opts.lowvram or shared.cmd_opts.medvram or shared.cmd_opts.medvram_sdxl and hasattr(sd_model, 'conditioner')
|
||||||
|
|
||||||
|
|
||||||
|
def apply(sd_model):
|
||||||
|
enable = is_needed(sd_model)
|
||||||
|
shared.parallel_processing_allowed = not enable
|
||||||
|
|
||||||
|
if enable:
|
||||||
|
setup_for_low_vram(sd_model, not shared.cmd_opts.lowvram)
|
||||||
|
else:
|
||||||
|
sd_model.lowvram = False
|
||||||
|
|
||||||
|
|
||||||
def setup_for_low_vram(sd_model, use_medvram):
|
def setup_for_low_vram(sd_model, use_medvram):
|
||||||
if getattr(sd_model, 'lowvram', False):
|
if getattr(sd_model, 'lowvram', False):
|
||||||
return
|
return
|
||||||
@@ -130,4 +144,4 @@ def setup_for_low_vram(sd_model, use_medvram):
|
|||||||
|
|
||||||
|
|
||||||
def is_enabled(sd_model):
|
def is_enabled(sd_model):
|
||||||
return getattr(sd_model, 'lowvram', False)
|
return sd_model.lowvram
|
||||||
|
|||||||
@@ -4,6 +4,7 @@ import torch
|
|||||||
import platform
|
import platform
|
||||||
from modules.sd_hijack_utils import CondFunc
|
from modules.sd_hijack_utils import CondFunc
|
||||||
from packaging import version
|
from packaging import version
|
||||||
|
from modules import shared
|
||||||
|
|
||||||
log = logging.getLogger(__name__)
|
log = logging.getLogger(__name__)
|
||||||
|
|
||||||
@@ -30,8 +31,7 @@ has_mps = check_for_mps()
|
|||||||
|
|
||||||
def torch_mps_gc() -> None:
|
def torch_mps_gc() -> None:
|
||||||
try:
|
try:
|
||||||
from modules.shared import state
|
if shared.state.current_latent is not None:
|
||||||
if state.current_latent is not None:
|
|
||||||
log.debug("`current_latent` is set, skipping MPS garbage collection")
|
log.debug("`current_latent` is set, skipping MPS garbage collection")
|
||||||
return
|
return
|
||||||
from torch.mps import empty_cache
|
from torch.mps import empty_cache
|
||||||
@@ -52,9 +52,6 @@ def cumsum_fix(input, cumsum_func, *args, **kwargs):
|
|||||||
|
|
||||||
|
|
||||||
if has_mps:
|
if has_mps:
|
||||||
# MPS fix for randn in torchsde
|
|
||||||
CondFunc('torchsde._brownian.brownian_interval._randn', lambda _, size, dtype, device, seed: torch.randn(size, dtype=dtype, device=torch.device("cpu"), generator=torch.Generator(torch.device("cpu")).manual_seed(int(seed))).to(device), lambda _, size, dtype, device, seed: device.type == 'mps')
|
|
||||||
|
|
||||||
if platform.mac_ver()[0].startswith("13.2."):
|
if platform.mac_ver()[0].startswith("13.2."):
|
||||||
# MPS workaround for https://github.com/pytorch/pytorch/issues/95188, thanks to danieldk (https://github.com/explosion/curated-transformers/pull/124)
|
# MPS workaround for https://github.com/pytorch/pytorch/issues/95188, thanks to danieldk (https://github.com/explosion/curated-transformers/pull/124)
|
||||||
CondFunc('torch.nn.functional.linear', lambda _, input, weight, bias: (torch.matmul(input, weight.t()) + bias) if bias is not None else torch.matmul(input, weight.t()), lambda _, input, weight, bias: input.numel() > 10485760)
|
CondFunc('torch.nn.functional.linear', lambda _, input, weight, bias: (torch.matmul(input, weight.t()) + bias) if bias is not None else torch.matmul(input, weight.t()), lambda _, input, weight, bias: input.numel() > 10485760)
|
||||||
|
|||||||
@@ -0,0 +1,245 @@
|
|||||||
|
import json
|
||||||
|
import sys
|
||||||
|
|
||||||
|
import gradio as gr
|
||||||
|
|
||||||
|
from modules import errors
|
||||||
|
from modules.shared_cmd_options import cmd_opts
|
||||||
|
|
||||||
|
|
||||||
|
class OptionInfo:
|
||||||
|
def __init__(self, default=None, label="", component=None, component_args=None, onchange=None, section=None, refresh=None, comment_before='', comment_after='', infotext=None, restrict_api=False):
|
||||||
|
self.default = default
|
||||||
|
self.label = label
|
||||||
|
self.component = component
|
||||||
|
self.component_args = component_args
|
||||||
|
self.onchange = onchange
|
||||||
|
self.section = section
|
||||||
|
self.refresh = refresh
|
||||||
|
self.do_not_save = False
|
||||||
|
|
||||||
|
self.comment_before = comment_before
|
||||||
|
"""HTML text that will be added after label in UI"""
|
||||||
|
|
||||||
|
self.comment_after = comment_after
|
||||||
|
"""HTML text that will be added before label in UI"""
|
||||||
|
|
||||||
|
self.infotext = infotext
|
||||||
|
|
||||||
|
self.restrict_api = restrict_api
|
||||||
|
"""If True, the setting will not be accessible via API"""
|
||||||
|
|
||||||
|
def link(self, label, url):
|
||||||
|
self.comment_before += f"[<a href='{url}' target='_blank'>{label}</a>]"
|
||||||
|
return self
|
||||||
|
|
||||||
|
def js(self, label, js_func):
|
||||||
|
self.comment_before += f"[<a onclick='{js_func}(); return false'>{label}</a>]"
|
||||||
|
return self
|
||||||
|
|
||||||
|
def info(self, info):
|
||||||
|
self.comment_after += f"<span class='info'>({info})</span>"
|
||||||
|
return self
|
||||||
|
|
||||||
|
def html(self, html):
|
||||||
|
self.comment_after += html
|
||||||
|
return self
|
||||||
|
|
||||||
|
def needs_restart(self):
|
||||||
|
self.comment_after += " <span class='info'>(requires restart)</span>"
|
||||||
|
return self
|
||||||
|
|
||||||
|
def needs_reload_ui(self):
|
||||||
|
self.comment_after += " <span class='info'>(requires Reload UI)</span>"
|
||||||
|
return self
|
||||||
|
|
||||||
|
|
||||||
|
class OptionHTML(OptionInfo):
|
||||||
|
def __init__(self, text):
|
||||||
|
super().__init__(str(text).strip(), label='', component=lambda **kwargs: gr.HTML(elem_classes="settings-info", **kwargs))
|
||||||
|
|
||||||
|
self.do_not_save = True
|
||||||
|
|
||||||
|
|
||||||
|
def options_section(section_identifier, options_dict):
|
||||||
|
for v in options_dict.values():
|
||||||
|
v.section = section_identifier
|
||||||
|
|
||||||
|
return options_dict
|
||||||
|
|
||||||
|
|
||||||
|
options_builtin_fields = {"data_labels", "data", "restricted_opts", "typemap"}
|
||||||
|
|
||||||
|
|
||||||
|
class Options:
|
||||||
|
typemap = {int: float}
|
||||||
|
|
||||||
|
def __init__(self, data_labels: dict[str, OptionInfo], restricted_opts):
|
||||||
|
self.data_labels = data_labels
|
||||||
|
self.data = {k: v.default for k, v in self.data_labels.items()}
|
||||||
|
self.restricted_opts = restricted_opts
|
||||||
|
|
||||||
|
def __setattr__(self, key, value):
|
||||||
|
if key in options_builtin_fields:
|
||||||
|
return super(Options, self).__setattr__(key, value)
|
||||||
|
|
||||||
|
if self.data is not None:
|
||||||
|
if key in self.data or key in self.data_labels:
|
||||||
|
assert not cmd_opts.freeze_settings, "changing settings is disabled"
|
||||||
|
|
||||||
|
info = self.data_labels.get(key, None)
|
||||||
|
if info.do_not_save:
|
||||||
|
return
|
||||||
|
|
||||||
|
comp_args = info.component_args if info else None
|
||||||
|
if isinstance(comp_args, dict) and comp_args.get('visible', True) is False:
|
||||||
|
raise RuntimeError(f"not possible to set {key} because it is restricted")
|
||||||
|
|
||||||
|
if cmd_opts.hide_ui_dir_config and key in self.restricted_opts:
|
||||||
|
raise RuntimeError(f"not possible to set {key} because it is restricted")
|
||||||
|
|
||||||
|
self.data[key] = value
|
||||||
|
return
|
||||||
|
|
||||||
|
return super(Options, self).__setattr__(key, value)
|
||||||
|
|
||||||
|
def __getattr__(self, item):
|
||||||
|
if item in options_builtin_fields:
|
||||||
|
return super(Options, self).__getattribute__(item)
|
||||||
|
|
||||||
|
if self.data is not None:
|
||||||
|
if item in self.data:
|
||||||
|
return self.data[item]
|
||||||
|
|
||||||
|
if item in self.data_labels:
|
||||||
|
return self.data_labels[item].default
|
||||||
|
|
||||||
|
return super(Options, self).__getattribute__(item)
|
||||||
|
|
||||||
|
def set(self, key, value, is_api=False, run_callbacks=True):
|
||||||
|
"""sets an option and calls its onchange callback, returning True if the option changed and False otherwise"""
|
||||||
|
|
||||||
|
oldval = self.data.get(key, None)
|
||||||
|
if oldval == value:
|
||||||
|
return False
|
||||||
|
|
||||||
|
option = self.data_labels[key]
|
||||||
|
if option.do_not_save:
|
||||||
|
return False
|
||||||
|
|
||||||
|
if is_api and option.restrict_api:
|
||||||
|
return False
|
||||||
|
|
||||||
|
try:
|
||||||
|
setattr(self, key, value)
|
||||||
|
except RuntimeError:
|
||||||
|
return False
|
||||||
|
|
||||||
|
if run_callbacks and option.onchange is not None:
|
||||||
|
try:
|
||||||
|
option.onchange()
|
||||||
|
except Exception as e:
|
||||||
|
errors.display(e, f"changing setting {key} to {value}")
|
||||||
|
setattr(self, key, oldval)
|
||||||
|
return False
|
||||||
|
|
||||||
|
return True
|
||||||
|
|
||||||
|
def get_default(self, key):
|
||||||
|
"""returns the default value for the key"""
|
||||||
|
|
||||||
|
data_label = self.data_labels.get(key)
|
||||||
|
if data_label is None:
|
||||||
|
return None
|
||||||
|
|
||||||
|
return data_label.default
|
||||||
|
|
||||||
|
def save(self, filename):
|
||||||
|
assert not cmd_opts.freeze_settings, "saving settings is disabled"
|
||||||
|
|
||||||
|
with open(filename, "w", encoding="utf8") as file:
|
||||||
|
json.dump(self.data, file, indent=4)
|
||||||
|
|
||||||
|
def same_type(self, x, y):
|
||||||
|
if x is None or y is None:
|
||||||
|
return True
|
||||||
|
|
||||||
|
type_x = self.typemap.get(type(x), type(x))
|
||||||
|
type_y = self.typemap.get(type(y), type(y))
|
||||||
|
|
||||||
|
return type_x == type_y
|
||||||
|
|
||||||
|
def load(self, filename):
|
||||||
|
with open(filename, "r", encoding="utf8") as file:
|
||||||
|
self.data = json.load(file)
|
||||||
|
|
||||||
|
# 1.6.0 VAE defaults
|
||||||
|
if self.data.get('sd_vae_as_default') is not None and self.data.get('sd_vae_overrides_per_model_preferences') is None:
|
||||||
|
self.data['sd_vae_overrides_per_model_preferences'] = not self.data.get('sd_vae_as_default')
|
||||||
|
|
||||||
|
# 1.1.1 quicksettings list migration
|
||||||
|
if self.data.get('quicksettings') is not None and self.data.get('quicksettings_list') is None:
|
||||||
|
self.data['quicksettings_list'] = [i.strip() for i in self.data.get('quicksettings').split(',')]
|
||||||
|
|
||||||
|
# 1.4.0 ui_reorder
|
||||||
|
if isinstance(self.data.get('ui_reorder'), str) and self.data.get('ui_reorder') and "ui_reorder_list" not in self.data:
|
||||||
|
self.data['ui_reorder_list'] = [i.strip() for i in self.data.get('ui_reorder').split(',')]
|
||||||
|
|
||||||
|
bad_settings = 0
|
||||||
|
for k, v in self.data.items():
|
||||||
|
info = self.data_labels.get(k, None)
|
||||||
|
if info is not None and not self.same_type(info.default, v):
|
||||||
|
print(f"Warning: bad setting value: {k}: {v} ({type(v).__name__}; expected {type(info.default).__name__})", file=sys.stderr)
|
||||||
|
bad_settings += 1
|
||||||
|
|
||||||
|
if bad_settings > 0:
|
||||||
|
print(f"The program is likely to not work with bad settings.\nSettings file: {filename}\nEither fix the file, or delete it and restart.", file=sys.stderr)
|
||||||
|
|
||||||
|
def onchange(self, key, func, call=True):
|
||||||
|
item = self.data_labels.get(key)
|
||||||
|
item.onchange = func
|
||||||
|
|
||||||
|
if call:
|
||||||
|
func()
|
||||||
|
|
||||||
|
def dumpjson(self):
|
||||||
|
d = {k: self.data.get(k, v.default) for k, v in self.data_labels.items()}
|
||||||
|
d["_comments_before"] = {k: v.comment_before for k, v in self.data_labels.items() if v.comment_before is not None}
|
||||||
|
d["_comments_after"] = {k: v.comment_after for k, v in self.data_labels.items() if v.comment_after is not None}
|
||||||
|
return json.dumps(d)
|
||||||
|
|
||||||
|
def add_option(self, key, info):
|
||||||
|
self.data_labels[key] = info
|
||||||
|
|
||||||
|
def reorder(self):
|
||||||
|
"""reorder settings so that all items related to section always go together"""
|
||||||
|
|
||||||
|
section_ids = {}
|
||||||
|
settings_items = self.data_labels.items()
|
||||||
|
for _, item in settings_items:
|
||||||
|
if item.section not in section_ids:
|
||||||
|
section_ids[item.section] = len(section_ids)
|
||||||
|
|
||||||
|
self.data_labels = dict(sorted(settings_items, key=lambda x: section_ids[x[1].section]))
|
||||||
|
|
||||||
|
def cast_value(self, key, value):
|
||||||
|
"""casts an arbitrary to the same type as this setting's value with key
|
||||||
|
Example: cast_value("eta_noise_seed_delta", "12") -> returns 12 (an int rather than str)
|
||||||
|
"""
|
||||||
|
|
||||||
|
if value is None:
|
||||||
|
return None
|
||||||
|
|
||||||
|
default_value = self.data_labels[key].default
|
||||||
|
if default_value is None:
|
||||||
|
default_value = getattr(self, key, None)
|
||||||
|
if default_value is None:
|
||||||
|
return None
|
||||||
|
|
||||||
|
expected_type = type(default_value)
|
||||||
|
if expected_type == bool and value == "False":
|
||||||
|
value = False
|
||||||
|
else:
|
||||||
|
value = expected_type(value)
|
||||||
|
|
||||||
|
return value
|
||||||
@@ -0,0 +1,64 @@
|
|||||||
|
from collections import defaultdict
|
||||||
|
|
||||||
|
|
||||||
|
def patch(key, obj, field, replacement):
|
||||||
|
"""Replaces a function in a module or a class.
|
||||||
|
|
||||||
|
Also stores the original function in this module, possible to be retrieved via original(key, obj, field).
|
||||||
|
If the function is already replaced by this caller (key), an exception is raised -- use undo() before that.
|
||||||
|
|
||||||
|
Arguments:
|
||||||
|
key: identifying information for who is doing the replacement. You can use __name__.
|
||||||
|
obj: the module or the class
|
||||||
|
field: name of the function as a string
|
||||||
|
replacement: the new function
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
the original function
|
||||||
|
"""
|
||||||
|
|
||||||
|
patch_key = (obj, field)
|
||||||
|
if patch_key in originals[key]:
|
||||||
|
raise RuntimeError(f"patch for {field} is already applied")
|
||||||
|
|
||||||
|
original_func = getattr(obj, field)
|
||||||
|
originals[key][patch_key] = original_func
|
||||||
|
|
||||||
|
setattr(obj, field, replacement)
|
||||||
|
|
||||||
|
return original_func
|
||||||
|
|
||||||
|
|
||||||
|
def undo(key, obj, field):
|
||||||
|
"""Undoes the peplacement by the patch().
|
||||||
|
|
||||||
|
If the function is not replaced, raises an exception.
|
||||||
|
|
||||||
|
Arguments:
|
||||||
|
key: identifying information for who is doing the replacement. You can use __name__.
|
||||||
|
obj: the module or the class
|
||||||
|
field: name of the function as a string
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Always None
|
||||||
|
"""
|
||||||
|
|
||||||
|
patch_key = (obj, field)
|
||||||
|
|
||||||
|
if patch_key not in originals[key]:
|
||||||
|
raise RuntimeError(f"there is no patch for {field} to undo")
|
||||||
|
|
||||||
|
original_func = originals[key].pop(patch_key)
|
||||||
|
setattr(obj, field, original_func)
|
||||||
|
|
||||||
|
return None
|
||||||
|
|
||||||
|
|
||||||
|
def original(key, obj, field):
|
||||||
|
"""Returns the original function for the patch created by the patch() function"""
|
||||||
|
patch_key = (obj, field)
|
||||||
|
|
||||||
|
return originals[key].get(patch_key, None)
|
||||||
|
|
||||||
|
|
||||||
|
originals = defaultdict(dict)
|
||||||
+11
-12
@@ -11,10 +11,9 @@ def run_postprocessing(extras_mode, image, image_folder, input_dir, output_dir,
|
|||||||
|
|
||||||
shared.state.begin(job="extras")
|
shared.state.begin(job="extras")
|
||||||
|
|
||||||
image_data = []
|
|
||||||
image_names = []
|
|
||||||
outputs = []
|
outputs = []
|
||||||
|
|
||||||
|
def get_images(extras_mode, image, image_folder, input_dir):
|
||||||
if extras_mode == 1:
|
if extras_mode == 1:
|
||||||
for img in image_folder:
|
for img in image_folder:
|
||||||
if isinstance(img, Image.Image):
|
if isinstance(img, Image.Image):
|
||||||
@@ -23,8 +22,7 @@ def run_postprocessing(extras_mode, image, image_folder, input_dir, output_dir,
|
|||||||
else:
|
else:
|
||||||
image = Image.open(os.path.abspath(img.name))
|
image = Image.open(os.path.abspath(img.name))
|
||||||
fn = os.path.splitext(img.orig_name)[0]
|
fn = os.path.splitext(img.orig_name)[0]
|
||||||
image_data.append(image)
|
yield image, fn
|
||||||
image_names.append(fn)
|
|
||||||
elif extras_mode == 2:
|
elif extras_mode == 2:
|
||||||
assert not shared.cmd_opts.hide_ui_dir_config, '--hide-ui-dir-config option must be disabled'
|
assert not shared.cmd_opts.hide_ui_dir_config, '--hide-ui-dir-config option must be disabled'
|
||||||
assert input_dir, 'input directory not selected'
|
assert input_dir, 'input directory not selected'
|
||||||
@@ -35,13 +33,10 @@ def run_postprocessing(extras_mode, image, image_folder, input_dir, output_dir,
|
|||||||
image = Image.open(filename)
|
image = Image.open(filename)
|
||||||
except Exception:
|
except Exception:
|
||||||
continue
|
continue
|
||||||
image_data.append(image)
|
yield image, filename
|
||||||
image_names.append(filename)
|
|
||||||
else:
|
else:
|
||||||
assert image, 'image not selected'
|
assert image, 'image not selected'
|
||||||
|
yield image, None
|
||||||
image_data.append(image)
|
|
||||||
image_names.append(None)
|
|
||||||
|
|
||||||
if extras_mode == 2 and output_dir != '':
|
if extras_mode == 2 and output_dir != '':
|
||||||
outpath = output_dir
|
outpath = output_dir
|
||||||
@@ -50,14 +45,16 @@ def run_postprocessing(extras_mode, image, image_folder, input_dir, output_dir,
|
|||||||
|
|
||||||
infotext = ''
|
infotext = ''
|
||||||
|
|
||||||
for image, name in zip(image_data, image_names):
|
for image_data, name in get_images(extras_mode, image, image_folder, input_dir):
|
||||||
|
image_data: Image.Image
|
||||||
|
|
||||||
shared.state.textinfo = name
|
shared.state.textinfo = name
|
||||||
|
|
||||||
parameters, existing_pnginfo = images.read_info_from_image(image)
|
parameters, existing_pnginfo = images.read_info_from_image(image_data)
|
||||||
if parameters:
|
if parameters:
|
||||||
existing_pnginfo["parameters"] = parameters
|
existing_pnginfo["parameters"] = parameters
|
||||||
|
|
||||||
pp = scripts_postprocessing.PostprocessedImage(image.convert("RGB"))
|
pp = scripts_postprocessing.PostprocessedImage(image_data.convert("RGB"))
|
||||||
|
|
||||||
scripts.scripts_postproc.run(pp, args)
|
scripts.scripts_postproc.run(pp, args)
|
||||||
|
|
||||||
@@ -78,6 +75,8 @@ def run_postprocessing(extras_mode, image, image_folder, input_dir, output_dir,
|
|||||||
if extras_mode != 2 or show_extras_results:
|
if extras_mode != 2 or show_extras_results:
|
||||||
outputs.append(pp.image)
|
outputs.append(pp.image)
|
||||||
|
|
||||||
|
image_data.close()
|
||||||
|
|
||||||
devices.torch_gc()
|
devices.torch_gc()
|
||||||
|
|
||||||
return outputs, ui_common.plaintext_to_html(infotext), ''
|
return outputs, ui_common.plaintext_to_html(infotext), ''
|
||||||
|
|||||||
+382
-273
@@ -1,9 +1,11 @@
|
|||||||
|
from __future__ import annotations
|
||||||
import json
|
import json
|
||||||
import logging
|
import logging
|
||||||
import math
|
import math
|
||||||
import os
|
import os
|
||||||
import sys
|
import sys
|
||||||
import hashlib
|
import hashlib
|
||||||
|
from dataclasses import dataclass, field
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
import numpy as np
|
import numpy as np
|
||||||
@@ -11,10 +13,11 @@ from PIL import Image, ImageOps
|
|||||||
import random
|
import random
|
||||||
import cv2
|
import cv2
|
||||||
from skimage import exposure
|
from skimage import exposure
|
||||||
from typing import Any, Dict, List
|
from typing import Any
|
||||||
|
|
||||||
import modules.sd_hijack
|
import modules.sd_hijack
|
||||||
from modules import devices, prompt_parser, masking, sd_samplers, lowvram, generation_parameters_copypaste, extra_networks, sd_vae_approx, scripts, sd_samplers_common, sd_unet, errors
|
from modules import devices, prompt_parser, masking, sd_samplers, lowvram, generation_parameters_copypaste, extra_networks, sd_vae_approx, scripts, sd_samplers_common, sd_unet, errors, rng
|
||||||
|
from modules.rng import slerp # noqa: F401
|
||||||
from modules.sd_hijack import model_hijack
|
from modules.sd_hijack import model_hijack
|
||||||
from modules.sd_samplers_common import images_tensor_to_samples, decode_first_stage, approximation_indexes
|
from modules.sd_samplers_common import images_tensor_to_samples, decode_first_stage, approximation_indexes
|
||||||
from modules.shared import opts, cmd_opts, state
|
from modules.shared import opts, cmd_opts, state
|
||||||
@@ -56,7 +59,7 @@ def apply_color_correction(correction, original_image):
|
|||||||
|
|
||||||
image = blendLayers(image, original_image, BlendType.LUMINOSITY)
|
image = blendLayers(image, original_image, BlendType.LUMINOSITY)
|
||||||
|
|
||||||
return image
|
return image.convert('RGB')
|
||||||
|
|
||||||
|
|
||||||
def apply_overlay(image, paste_loc, index, overlays):
|
def apply_overlay(image, paste_loc, index, overlays):
|
||||||
@@ -78,6 +81,12 @@ def apply_overlay(image, paste_loc, index, overlays):
|
|||||||
|
|
||||||
return image
|
return image
|
||||||
|
|
||||||
|
def create_binary_mask(image):
|
||||||
|
if image.mode == 'RGBA' and image.getextrema()[-1] != (255, 255):
|
||||||
|
image = image.split()[-1].convert("L").point(lambda x: 255 if x > 128 else 0)
|
||||||
|
else:
|
||||||
|
image = image.convert('L')
|
||||||
|
return image
|
||||||
|
|
||||||
def txt2img_image_conditioning(sd_model, x, width, height):
|
def txt2img_image_conditioning(sd_model, x, width, height):
|
||||||
if sd_model.model.conditioning_key in {'hybrid', 'concat'}: # Inpainting models
|
if sd_model.model.conditioning_key in {'hybrid', 'concat'}: # Inpainting models
|
||||||
@@ -103,94 +112,165 @@ def txt2img_image_conditioning(sd_model, x, width, height):
|
|||||||
return x.new_zeros(x.shape[0], 5, 1, 1, dtype=x.dtype, device=x.device)
|
return x.new_zeros(x.shape[0], 5, 1, 1, dtype=x.dtype, device=x.device)
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass(repr=False)
|
||||||
class StableDiffusionProcessing:
|
class StableDiffusionProcessing:
|
||||||
"""
|
sd_model: object = None
|
||||||
The first set of paramaters: sd_models -> do_not_reload_embeddings represent the minimum required to create a StableDiffusionProcessing
|
outpath_samples: str = None
|
||||||
"""
|
outpath_grids: str = None
|
||||||
|
prompt: str = ""
|
||||||
|
prompt_for_display: str = None
|
||||||
|
negative_prompt: str = ""
|
||||||
|
styles: list[str] = None
|
||||||
|
seed: int = -1
|
||||||
|
subseed: int = -1
|
||||||
|
subseed_strength: float = 0
|
||||||
|
seed_resize_from_h: int = -1
|
||||||
|
seed_resize_from_w: int = -1
|
||||||
|
seed_enable_extras: bool = True
|
||||||
|
sampler_name: str = None
|
||||||
|
batch_size: int = 1
|
||||||
|
n_iter: int = 1
|
||||||
|
steps: int = 50
|
||||||
|
cfg_scale: float = 7.0
|
||||||
|
width: int = 512
|
||||||
|
height: int = 512
|
||||||
|
restore_faces: bool = None
|
||||||
|
tiling: bool = None
|
||||||
|
do_not_save_samples: bool = False
|
||||||
|
do_not_save_grid: bool = False
|
||||||
|
extra_generation_params: dict[str, Any] = None
|
||||||
|
overlay_images: list = None
|
||||||
|
eta: float = None
|
||||||
|
do_not_reload_embeddings: bool = False
|
||||||
|
denoising_strength: float = 0
|
||||||
|
ddim_discretize: str = None
|
||||||
|
s_min_uncond: float = None
|
||||||
|
s_churn: float = None
|
||||||
|
s_tmax: float = None
|
||||||
|
s_tmin: float = None
|
||||||
|
s_noise: float = None
|
||||||
|
override_settings: dict[str, Any] = None
|
||||||
|
override_settings_restore_afterwards: bool = True
|
||||||
|
sampler_index: int = None
|
||||||
|
refiner_checkpoint: str = None
|
||||||
|
refiner_switch_at: float = None
|
||||||
|
token_merging_ratio = 0
|
||||||
|
token_merging_ratio_hr = 0
|
||||||
|
disable_extra_networks: bool = False
|
||||||
|
|
||||||
|
scripts_value: scripts.ScriptRunner = field(default=None, init=False)
|
||||||
|
script_args_value: list = field(default=None, init=False)
|
||||||
|
scripts_setup_complete: bool = field(default=False, init=False)
|
||||||
|
|
||||||
cached_uc = [None, None]
|
cached_uc = [None, None]
|
||||||
cached_c = [None, None]
|
cached_c = [None, None]
|
||||||
|
|
||||||
def __init__(self, sd_model=None, outpath_samples=None, outpath_grids=None, prompt: str = "", styles: List[str] = None, seed: int = -1, subseed: int = -1, subseed_strength: float = 0, seed_resize_from_h: int = -1, seed_resize_from_w: int = -1, seed_enable_extras: bool = True, sampler_name: str = None, batch_size: int = 1, n_iter: int = 1, steps: int = 50, cfg_scale: float = 7.0, width: int = 512, height: int = 512, restore_faces: bool = False, tiling: bool = False, do_not_save_samples: bool = False, do_not_save_grid: bool = False, extra_generation_params: Dict[Any, Any] = None, overlay_images: Any = None, negative_prompt: str = None, eta: float = None, do_not_reload_embeddings: bool = False, denoising_strength: float = 0, ddim_discretize: str = None, s_min_uncond: float = 0.0, s_churn: float = 0.0, s_tmax: float = None, s_tmin: float = 0.0, s_noise: float = None, override_settings: Dict[str, Any] = None, override_settings_restore_afterwards: bool = True, sampler_index: int = None, script_args: list = None):
|
comments: dict = None
|
||||||
if sampler_index is not None:
|
sampler: sd_samplers_common.Sampler | None = field(default=None, init=False)
|
||||||
|
is_using_inpainting_conditioning: bool = field(default=False, init=False)
|
||||||
|
paste_to: tuple | None = field(default=None, init=False)
|
||||||
|
|
||||||
|
is_hr_pass: bool = field(default=False, init=False)
|
||||||
|
|
||||||
|
c: tuple = field(default=None, init=False)
|
||||||
|
uc: tuple = field(default=None, init=False)
|
||||||
|
|
||||||
|
rng: rng.ImageRNG | None = field(default=None, init=False)
|
||||||
|
step_multiplier: int = field(default=1, init=False)
|
||||||
|
color_corrections: list = field(default=None, init=False)
|
||||||
|
|
||||||
|
all_prompts: list = field(default=None, init=False)
|
||||||
|
all_negative_prompts: list = field(default=None, init=False)
|
||||||
|
all_seeds: list = field(default=None, init=False)
|
||||||
|
all_subseeds: list = field(default=None, init=False)
|
||||||
|
iteration: int = field(default=0, init=False)
|
||||||
|
main_prompt: str = field(default=None, init=False)
|
||||||
|
main_negative_prompt: str = field(default=None, init=False)
|
||||||
|
|
||||||
|
prompts: list = field(default=None, init=False)
|
||||||
|
negative_prompts: list = field(default=None, init=False)
|
||||||
|
seeds: list = field(default=None, init=False)
|
||||||
|
subseeds: list = field(default=None, init=False)
|
||||||
|
extra_network_data: dict = field(default=None, init=False)
|
||||||
|
|
||||||
|
user: str = field(default=None, init=False)
|
||||||
|
|
||||||
|
sd_model_name: str = field(default=None, init=False)
|
||||||
|
sd_model_hash: str = field(default=None, init=False)
|
||||||
|
sd_vae_name: str = field(default=None, init=False)
|
||||||
|
sd_vae_hash: str = field(default=None, init=False)
|
||||||
|
|
||||||
|
is_api: bool = field(default=False, init=False)
|
||||||
|
|
||||||
|
def __post_init__(self):
|
||||||
|
if self.sampler_index is not None:
|
||||||
print("sampler_index argument for StableDiffusionProcessing does not do anything; use sampler_name", file=sys.stderr)
|
print("sampler_index argument for StableDiffusionProcessing does not do anything; use sampler_name", file=sys.stderr)
|
||||||
|
|
||||||
self.outpath_samples: str = outpath_samples
|
self.comments = {}
|
||||||
self.outpath_grids: str = outpath_grids
|
|
||||||
self.prompt: str = prompt
|
|
||||||
self.prompt_for_display: str = None
|
|
||||||
self.negative_prompt: str = (negative_prompt or "")
|
|
||||||
self.styles: list = styles or []
|
|
||||||
self.seed: int = seed
|
|
||||||
self.subseed: int = subseed
|
|
||||||
self.subseed_strength: float = subseed_strength
|
|
||||||
self.seed_resize_from_h: int = seed_resize_from_h
|
|
||||||
self.seed_resize_from_w: int = seed_resize_from_w
|
|
||||||
self.sampler_name: str = sampler_name
|
|
||||||
self.batch_size: int = batch_size
|
|
||||||
self.n_iter: int = n_iter
|
|
||||||
self.steps: int = steps
|
|
||||||
self.cfg_scale: float = cfg_scale
|
|
||||||
self.width: int = width
|
|
||||||
self.height: int = height
|
|
||||||
self.restore_faces: bool = restore_faces
|
|
||||||
self.tiling: bool = tiling
|
|
||||||
self.do_not_save_samples: bool = do_not_save_samples
|
|
||||||
self.do_not_save_grid: bool = do_not_save_grid
|
|
||||||
self.extra_generation_params: dict = extra_generation_params or {}
|
|
||||||
self.overlay_images = overlay_images
|
|
||||||
self.eta = eta
|
|
||||||
self.do_not_reload_embeddings = do_not_reload_embeddings
|
|
||||||
self.paste_to = None
|
|
||||||
self.color_corrections = None
|
|
||||||
self.denoising_strength: float = denoising_strength
|
|
||||||
self.sampler_noise_scheduler_override = None
|
|
||||||
self.ddim_discretize = ddim_discretize or opts.ddim_discretize
|
|
||||||
self.s_min_uncond = s_min_uncond or opts.s_min_uncond
|
|
||||||
self.s_churn = s_churn or opts.s_churn
|
|
||||||
self.s_tmin = s_tmin or opts.s_tmin
|
|
||||||
self.s_tmax = (s_tmax if s_tmax is not None else opts.s_tmax) or float('inf')
|
|
||||||
self.s_noise = s_noise if s_noise is not None else opts.s_noise
|
|
||||||
self.override_settings = {k: v for k, v in (override_settings or {}).items() if k not in shared.restricted_opts}
|
|
||||||
self.override_settings_restore_afterwards = override_settings_restore_afterwards
|
|
||||||
self.is_using_inpainting_conditioning = False
|
|
||||||
self.disable_extra_networks = False
|
|
||||||
self.token_merging_ratio = 0
|
|
||||||
self.token_merging_ratio_hr = 0
|
|
||||||
|
|
||||||
if not seed_enable_extras:
|
if self.styles is None:
|
||||||
|
self.styles = []
|
||||||
|
|
||||||
|
self.sampler_noise_scheduler_override = None
|
||||||
|
self.s_min_uncond = self.s_min_uncond if self.s_min_uncond is not None else opts.s_min_uncond
|
||||||
|
self.s_churn = self.s_churn if self.s_churn is not None else opts.s_churn
|
||||||
|
self.s_tmin = self.s_tmin if self.s_tmin is not None else opts.s_tmin
|
||||||
|
self.s_tmax = (self.s_tmax if self.s_tmax is not None else opts.s_tmax) or float('inf')
|
||||||
|
self.s_noise = self.s_noise if self.s_noise is not None else opts.s_noise
|
||||||
|
|
||||||
|
self.extra_generation_params = self.extra_generation_params or {}
|
||||||
|
self.override_settings = self.override_settings or {}
|
||||||
|
self.script_args = self.script_args or {}
|
||||||
|
|
||||||
|
self.refiner_checkpoint_info = None
|
||||||
|
|
||||||
|
if not self.seed_enable_extras:
|
||||||
self.subseed = -1
|
self.subseed = -1
|
||||||
self.subseed_strength = 0
|
self.subseed_strength = 0
|
||||||
self.seed_resize_from_h = 0
|
self.seed_resize_from_h = 0
|
||||||
self.seed_resize_from_w = 0
|
self.seed_resize_from_w = 0
|
||||||
|
|
||||||
self.scripts = None
|
|
||||||
self.script_args = script_args
|
|
||||||
self.all_prompts = None
|
|
||||||
self.all_negative_prompts = None
|
|
||||||
self.all_seeds = None
|
|
||||||
self.all_subseeds = None
|
|
||||||
self.iteration = 0
|
|
||||||
self.is_hr_pass = False
|
|
||||||
self.sampler = None
|
|
||||||
|
|
||||||
self.prompts = None
|
|
||||||
self.negative_prompts = None
|
|
||||||
self.extra_network_data = None
|
|
||||||
self.seeds = None
|
|
||||||
self.subseeds = None
|
|
||||||
|
|
||||||
self.step_multiplier = 1
|
|
||||||
self.cached_uc = StableDiffusionProcessing.cached_uc
|
self.cached_uc = StableDiffusionProcessing.cached_uc
|
||||||
self.cached_c = StableDiffusionProcessing.cached_c
|
self.cached_c = StableDiffusionProcessing.cached_c
|
||||||
self.uc = None
|
|
||||||
self.c = None
|
|
||||||
|
|
||||||
self.user = None
|
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def sd_model(self):
|
def sd_model(self):
|
||||||
return shared.sd_model
|
return shared.sd_model
|
||||||
|
|
||||||
|
@sd_model.setter
|
||||||
|
def sd_model(self, value):
|
||||||
|
pass
|
||||||
|
|
||||||
|
@property
|
||||||
|
def scripts(self):
|
||||||
|
return self.scripts_value
|
||||||
|
|
||||||
|
@scripts.setter
|
||||||
|
def scripts(self, value):
|
||||||
|
self.scripts_value = value
|
||||||
|
|
||||||
|
if self.scripts_value and self.script_args_value and not self.scripts_setup_complete:
|
||||||
|
self.setup_scripts()
|
||||||
|
|
||||||
|
@property
|
||||||
|
def script_args(self):
|
||||||
|
return self.script_args_value
|
||||||
|
|
||||||
|
@script_args.setter
|
||||||
|
def script_args(self, value):
|
||||||
|
self.script_args_value = value
|
||||||
|
|
||||||
|
if self.scripts_value and self.script_args_value and not self.scripts_setup_complete:
|
||||||
|
self.setup_scripts()
|
||||||
|
|
||||||
|
def setup_scripts(self):
|
||||||
|
self.scripts_setup_complete = True
|
||||||
|
|
||||||
|
self.scripts.setup_scrips(self, is_ui=not self.is_api)
|
||||||
|
|
||||||
|
def comment(self, text):
|
||||||
|
self.comments[text] = 1
|
||||||
|
|
||||||
def txt2img_image_conditioning(self, x, width=None, height=None):
|
def txt2img_image_conditioning(self, x, width=None, height=None):
|
||||||
self.is_using_inpainting_conditioning = self.sd_model.model.conditioning_key in {'hybrid', 'concat'}
|
self.is_using_inpainting_conditioning = self.sd_model.model.conditioning_key in {'hybrid', 'concat'}
|
||||||
|
|
||||||
@@ -306,25 +386,35 @@ class StableDiffusionProcessing:
|
|||||||
return self.token_merging_ratio or opts.token_merging_ratio
|
return self.token_merging_ratio or opts.token_merging_ratio
|
||||||
|
|
||||||
def setup_prompts(self):
|
def setup_prompts(self):
|
||||||
if type(self.prompt) == list:
|
if isinstance(self.prompt,list):
|
||||||
self.all_prompts = self.prompt
|
self.all_prompts = self.prompt
|
||||||
|
elif isinstance(self.negative_prompt, list):
|
||||||
|
self.all_prompts = [self.prompt] * len(self.negative_prompt)
|
||||||
else:
|
else:
|
||||||
self.all_prompts = self.batch_size * self.n_iter * [self.prompt]
|
self.all_prompts = self.batch_size * self.n_iter * [self.prompt]
|
||||||
|
|
||||||
if type(self.negative_prompt) == list:
|
if isinstance(self.negative_prompt, list):
|
||||||
self.all_negative_prompts = self.negative_prompt
|
self.all_negative_prompts = self.negative_prompt
|
||||||
else:
|
else:
|
||||||
self.all_negative_prompts = self.batch_size * self.n_iter * [self.negative_prompt]
|
self.all_negative_prompts = [self.negative_prompt] * len(self.all_prompts)
|
||||||
|
|
||||||
|
if len(self.all_prompts) != len(self.all_negative_prompts):
|
||||||
|
raise RuntimeError(f"Received a different number of prompts ({len(self.all_prompts)}) and negative prompts ({len(self.all_negative_prompts)})")
|
||||||
|
|
||||||
self.all_prompts = [shared.prompt_styles.apply_styles_to_prompt(x, self.styles) for x in self.all_prompts]
|
self.all_prompts = [shared.prompt_styles.apply_styles_to_prompt(x, self.styles) for x in self.all_prompts]
|
||||||
self.all_negative_prompts = [shared.prompt_styles.apply_negative_styles_to_prompt(x, self.styles) for x in self.all_negative_prompts]
|
self.all_negative_prompts = [shared.prompt_styles.apply_negative_styles_to_prompt(x, self.styles) for x in self.all_negative_prompts]
|
||||||
|
|
||||||
def cached_params(self, required_prompts, steps, extra_network_data):
|
self.main_prompt = self.all_prompts[0]
|
||||||
|
self.main_negative_prompt = self.all_negative_prompts[0]
|
||||||
|
|
||||||
|
def cached_params(self, required_prompts, steps, extra_network_data, hires_steps=None, use_old_scheduling=False):
|
||||||
"""Returns parameters that invalidate the cond cache if changed"""
|
"""Returns parameters that invalidate the cond cache if changed"""
|
||||||
|
|
||||||
return (
|
return (
|
||||||
required_prompts,
|
required_prompts,
|
||||||
steps,
|
steps,
|
||||||
|
hires_steps,
|
||||||
|
use_old_scheduling,
|
||||||
opts.CLIP_stop_at_last_layers,
|
opts.CLIP_stop_at_last_layers,
|
||||||
shared.sd_model.sd_checkpoint_info,
|
shared.sd_model.sd_checkpoint_info,
|
||||||
extra_network_data,
|
extra_network_data,
|
||||||
@@ -334,7 +424,7 @@ class StableDiffusionProcessing:
|
|||||||
self.height,
|
self.height,
|
||||||
)
|
)
|
||||||
|
|
||||||
def get_conds_with_caching(self, function, required_prompts, steps, caches, extra_network_data):
|
def get_conds_with_caching(self, function, required_prompts, steps, caches, extra_network_data, hires_steps=None):
|
||||||
"""
|
"""
|
||||||
Returns the result of calling function(shared.sd_model, required_prompts, steps)
|
Returns the result of calling function(shared.sd_model, required_prompts, steps)
|
||||||
using a cache to store the result if the same arguments have been used before.
|
using a cache to store the result if the same arguments have been used before.
|
||||||
@@ -347,7 +437,13 @@ class StableDiffusionProcessing:
|
|||||||
caches is a list with items described above.
|
caches is a list with items described above.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
cached_params = self.cached_params(required_prompts, steps, extra_network_data)
|
if shared.opts.use_old_scheduling:
|
||||||
|
old_schedules = prompt_parser.get_learned_conditioning_prompt_schedules(required_prompts, steps, hires_steps, False)
|
||||||
|
new_schedules = prompt_parser.get_learned_conditioning_prompt_schedules(required_prompts, steps, hires_steps, True)
|
||||||
|
if old_schedules != new_schedules:
|
||||||
|
self.extra_generation_params["Old prompt editing timelines"] = True
|
||||||
|
|
||||||
|
cached_params = self.cached_params(required_prompts, steps, extra_network_data, hires_steps, shared.opts.use_old_scheduling)
|
||||||
|
|
||||||
for cache in caches:
|
for cache in caches:
|
||||||
if cache[0] is not None and cached_params == cache[0]:
|
if cache[0] is not None and cached_params == cache[0]:
|
||||||
@@ -356,7 +452,7 @@ class StableDiffusionProcessing:
|
|||||||
cache = caches[0]
|
cache = caches[0]
|
||||||
|
|
||||||
with devices.autocast():
|
with devices.autocast():
|
||||||
cache[1] = function(shared.sd_model, required_prompts, steps)
|
cache[1] = function(shared.sd_model, required_prompts, steps, hires_steps, shared.opts.use_old_scheduling)
|
||||||
|
|
||||||
cache[0] = cached_params
|
cache[0] = cached_params
|
||||||
return cache[1]
|
return cache[1]
|
||||||
@@ -366,9 +462,15 @@ class StableDiffusionProcessing:
|
|||||||
negative_prompts = prompt_parser.SdConditioning(self.negative_prompts, width=self.width, height=self.height, is_negative_prompt=True)
|
negative_prompts = prompt_parser.SdConditioning(self.negative_prompts, width=self.width, height=self.height, is_negative_prompt=True)
|
||||||
|
|
||||||
sampler_config = sd_samplers.find_sampler_config(self.sampler_name)
|
sampler_config = sd_samplers.find_sampler_config(self.sampler_name)
|
||||||
self.step_multiplier = 2 if sampler_config and sampler_config.options.get("second_order", False) else 1
|
total_steps = sampler_config.total_steps(self.steps) if sampler_config else self.steps
|
||||||
self.uc = self.get_conds_with_caching(prompt_parser.get_learned_conditioning, negative_prompts, self.steps * self.step_multiplier, [self.cached_uc], self.extra_network_data)
|
self.step_multiplier = total_steps // self.steps
|
||||||
self.c = self.get_conds_with_caching(prompt_parser.get_multicond_learned_conditioning, prompts, self.steps * self.step_multiplier, [self.cached_c], self.extra_network_data)
|
self.firstpass_steps = total_steps
|
||||||
|
|
||||||
|
self.uc = self.get_conds_with_caching(prompt_parser.get_learned_conditioning, negative_prompts, total_steps, [self.cached_uc], self.extra_network_data)
|
||||||
|
self.c = self.get_conds_with_caching(prompt_parser.get_multicond_learned_conditioning, prompts, total_steps, [self.cached_c], self.extra_network_data)
|
||||||
|
|
||||||
|
def get_conds(self):
|
||||||
|
return self.c, self.uc
|
||||||
|
|
||||||
def parse_extra_network_prompts(self):
|
def parse_extra_network_prompts(self):
|
||||||
self.prompts, self.extra_network_data = extra_networks.parse_prompts(self.prompts)
|
self.prompts, self.extra_network_data = extra_networks.parse_prompts(self.prompts)
|
||||||
@@ -387,7 +489,7 @@ class Processed:
|
|||||||
self.subseed = subseed
|
self.subseed = subseed
|
||||||
self.subseed_strength = p.subseed_strength
|
self.subseed_strength = p.subseed_strength
|
||||||
self.info = info
|
self.info = info
|
||||||
self.comments = comments
|
self.comments = "".join(f"{comment}\n" for comment in p.comments)
|
||||||
self.width = p.width
|
self.width = p.width
|
||||||
self.height = p.height
|
self.height = p.height
|
||||||
self.sampler_name = p.sampler_name
|
self.sampler_name = p.sampler_name
|
||||||
@@ -397,7 +499,10 @@ class Processed:
|
|||||||
self.batch_size = p.batch_size
|
self.batch_size = p.batch_size
|
||||||
self.restore_faces = p.restore_faces
|
self.restore_faces = p.restore_faces
|
||||||
self.face_restoration_model = opts.face_restoration_model if p.restore_faces else None
|
self.face_restoration_model = opts.face_restoration_model if p.restore_faces else None
|
||||||
self.sd_model_hash = shared.sd_model.sd_model_hash
|
self.sd_model_name = p.sd_model_name
|
||||||
|
self.sd_model_hash = p.sd_model_hash
|
||||||
|
self.sd_vae_name = p.sd_vae_name
|
||||||
|
self.sd_vae_hash = p.sd_vae_hash
|
||||||
self.seed_resize_from_w = p.seed_resize_from_w
|
self.seed_resize_from_w = p.seed_resize_from_w
|
||||||
self.seed_resize_from_h = p.seed_resize_from_h
|
self.seed_resize_from_h = p.seed_resize_from_h
|
||||||
self.denoising_strength = getattr(p, 'denoising_strength', None)
|
self.denoising_strength = getattr(p, 'denoising_strength', None)
|
||||||
@@ -417,10 +522,10 @@ class Processed:
|
|||||||
self.s_noise = p.s_noise
|
self.s_noise = p.s_noise
|
||||||
self.s_min_uncond = p.s_min_uncond
|
self.s_min_uncond = p.s_min_uncond
|
||||||
self.sampler_noise_scheduler_override = p.sampler_noise_scheduler_override
|
self.sampler_noise_scheduler_override = p.sampler_noise_scheduler_override
|
||||||
self.prompt = self.prompt if type(self.prompt) != list else self.prompt[0]
|
self.prompt = self.prompt if not isinstance(self.prompt, list) else self.prompt[0]
|
||||||
self.negative_prompt = self.negative_prompt if type(self.negative_prompt) != list else self.negative_prompt[0]
|
self.negative_prompt = self.negative_prompt if not isinstance(self.negative_prompt, list) else self.negative_prompt[0]
|
||||||
self.seed = int(self.seed if type(self.seed) != list else self.seed[0]) if self.seed is not None else -1
|
self.seed = int(self.seed if not isinstance(self.seed, list) else self.seed[0]) if self.seed is not None else -1
|
||||||
self.subseed = int(self.subseed if type(self.subseed) != list else self.subseed[0]) if self.subseed is not None else -1
|
self.subseed = int(self.subseed if not isinstance(self.subseed, list) else self.subseed[0]) if self.subseed is not None else -1
|
||||||
self.is_using_inpainting_conditioning = p.is_using_inpainting_conditioning
|
self.is_using_inpainting_conditioning = p.is_using_inpainting_conditioning
|
||||||
|
|
||||||
self.all_prompts = all_prompts or p.all_prompts or [self.prompt]
|
self.all_prompts = all_prompts or p.all_prompts or [self.prompt]
|
||||||
@@ -448,7 +553,10 @@ class Processed:
|
|||||||
"batch_size": self.batch_size,
|
"batch_size": self.batch_size,
|
||||||
"restore_faces": self.restore_faces,
|
"restore_faces": self.restore_faces,
|
||||||
"face_restoration_model": self.face_restoration_model,
|
"face_restoration_model": self.face_restoration_model,
|
||||||
|
"sd_model_name": self.sd_model_name,
|
||||||
"sd_model_hash": self.sd_model_hash,
|
"sd_model_hash": self.sd_model_hash,
|
||||||
|
"sd_vae_name": self.sd_vae_name,
|
||||||
|
"sd_vae_hash": self.sd_vae_hash,
|
||||||
"seed_resize_from_w": self.seed_resize_from_w,
|
"seed_resize_from_w": self.seed_resize_from_w,
|
||||||
"seed_resize_from_h": self.seed_resize_from_h,
|
"seed_resize_from_h": self.seed_resize_from_h,
|
||||||
"denoising_strength": self.denoising_strength,
|
"denoising_strength": self.denoising_strength,
|
||||||
@@ -470,82 +578,9 @@ class Processed:
|
|||||||
return self.token_merging_ratio_hr if for_hr else self.token_merging_ratio
|
return self.token_merging_ratio_hr if for_hr else self.token_merging_ratio
|
||||||
|
|
||||||
|
|
||||||
# from https://discuss.pytorch.org/t/help-regarding-slerp-function-for-generative-model-sampling/32475/3
|
|
||||||
def slerp(val, low, high):
|
|
||||||
low_norm = low/torch.norm(low, dim=1, keepdim=True)
|
|
||||||
high_norm = high/torch.norm(high, dim=1, keepdim=True)
|
|
||||||
dot = (low_norm*high_norm).sum(1)
|
|
||||||
|
|
||||||
if dot.mean() > 0.9995:
|
|
||||||
return low * val + high * (1 - val)
|
|
||||||
|
|
||||||
omega = torch.acos(dot)
|
|
||||||
so = torch.sin(omega)
|
|
||||||
res = (torch.sin((1.0-val)*omega)/so).unsqueeze(1)*low + (torch.sin(val*omega)/so).unsqueeze(1) * high
|
|
||||||
return res
|
|
||||||
|
|
||||||
|
|
||||||
def create_random_tensors(shape, seeds, subseeds=None, subseed_strength=0.0, seed_resize_from_h=0, seed_resize_from_w=0, p=None):
|
def create_random_tensors(shape, seeds, subseeds=None, subseed_strength=0.0, seed_resize_from_h=0, seed_resize_from_w=0, p=None):
|
||||||
eta_noise_seed_delta = opts.eta_noise_seed_delta or 0
|
g = rng.ImageRNG(shape, seeds, subseeds=subseeds, subseed_strength=subseed_strength, seed_resize_from_h=seed_resize_from_h, seed_resize_from_w=seed_resize_from_w)
|
||||||
xs = []
|
return g.next()
|
||||||
|
|
||||||
# if we have multiple seeds, this means we are working with batch size>1; this then
|
|
||||||
# enables the generation of additional tensors with noise that the sampler will use during its processing.
|
|
||||||
# Using those pre-generated tensors instead of simple torch.randn allows a batch with seeds [100, 101] to
|
|
||||||
# produce the same images as with two batches [100], [101].
|
|
||||||
if p is not None and p.sampler is not None and (len(seeds) > 1 and opts.enable_batch_seeds or eta_noise_seed_delta > 0):
|
|
||||||
sampler_noises = [[] for _ in range(p.sampler.number_of_needed_noises(p))]
|
|
||||||
else:
|
|
||||||
sampler_noises = None
|
|
||||||
|
|
||||||
for i, seed in enumerate(seeds):
|
|
||||||
noise_shape = shape if seed_resize_from_h <= 0 or seed_resize_from_w <= 0 else (shape[0], seed_resize_from_h//8, seed_resize_from_w//8)
|
|
||||||
|
|
||||||
subnoise = None
|
|
||||||
if subseeds is not None and subseed_strength != 0:
|
|
||||||
subseed = 0 if i >= len(subseeds) else subseeds[i]
|
|
||||||
|
|
||||||
subnoise = devices.randn(subseed, noise_shape)
|
|
||||||
|
|
||||||
# randn results depend on device; gpu and cpu get different results for same seed;
|
|
||||||
# the way I see it, it's better to do this on CPU, so that everyone gets same result;
|
|
||||||
# but the original script had it like this, so I do not dare change it for now because
|
|
||||||
# it will break everyone's seeds.
|
|
||||||
noise = devices.randn(seed, noise_shape)
|
|
||||||
|
|
||||||
if subnoise is not None:
|
|
||||||
noise = slerp(subseed_strength, noise, subnoise)
|
|
||||||
|
|
||||||
if noise_shape != shape:
|
|
||||||
x = devices.randn(seed, shape)
|
|
||||||
dx = (shape[2] - noise_shape[2]) // 2
|
|
||||||
dy = (shape[1] - noise_shape[1]) // 2
|
|
||||||
w = noise_shape[2] if dx >= 0 else noise_shape[2] + 2 * dx
|
|
||||||
h = noise_shape[1] if dy >= 0 else noise_shape[1] + 2 * dy
|
|
||||||
tx = 0 if dx < 0 else dx
|
|
||||||
ty = 0 if dy < 0 else dy
|
|
||||||
dx = max(-dx, 0)
|
|
||||||
dy = max(-dy, 0)
|
|
||||||
|
|
||||||
x[:, ty:ty+h, tx:tx+w] = noise[:, dy:dy+h, dx:dx+w]
|
|
||||||
noise = x
|
|
||||||
|
|
||||||
if sampler_noises is not None:
|
|
||||||
cnt = p.sampler.number_of_needed_noises(p)
|
|
||||||
|
|
||||||
if eta_noise_seed_delta > 0:
|
|
||||||
devices.manual_seed(seed + eta_noise_seed_delta)
|
|
||||||
|
|
||||||
for j in range(cnt):
|
|
||||||
sampler_noises[j].append(devices.randn_without_seed(tuple(noise_shape)))
|
|
||||||
|
|
||||||
xs.append(noise)
|
|
||||||
|
|
||||||
if sampler_noises is not None:
|
|
||||||
p.sampler.sampler_noises = [torch.stack(n).to(shared.device) for n in sampler_noises]
|
|
||||||
|
|
||||||
x = torch.stack(xs).to(shared.device)
|
|
||||||
return x
|
|
||||||
|
|
||||||
|
|
||||||
class DecodedSamples(list):
|
class DecodedSamples(list):
|
||||||
@@ -568,7 +603,7 @@ def decode_latent_batch(model, batch, target_device=None, check_for_nans=False):
|
|||||||
errors.print_error_explanation(
|
errors.print_error_explanation(
|
||||||
"A tensor with all NaNs was produced in VAE.\n"
|
"A tensor with all NaNs was produced in VAE.\n"
|
||||||
"Web UI will now convert VAE into 32-bit float and retry.\n"
|
"Web UI will now convert VAE into 32-bit float and retry.\n"
|
||||||
"To disable this behavior, disable the 'Automaticlly revert VAE to 32-bit floats' setting.\n"
|
"To disable this behavior, disable the 'Automatically revert VAE to 32-bit floats' setting.\n"
|
||||||
"To always start with 32-bit VAE, use --no-half-vae commandline flag."
|
"To always start with 32-bit VAE, use --no-half-vae commandline flag."
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -587,7 +622,15 @@ def decode_latent_batch(model, batch, target_device=None, check_for_nans=False):
|
|||||||
|
|
||||||
|
|
||||||
def get_fixed_seed(seed):
|
def get_fixed_seed(seed):
|
||||||
if seed is None or seed == '' or seed == -1:
|
if seed == '' or seed is None:
|
||||||
|
seed = -1
|
||||||
|
elif isinstance(seed, str):
|
||||||
|
try:
|
||||||
|
seed = int(seed)
|
||||||
|
except Exception:
|
||||||
|
seed = -1
|
||||||
|
|
||||||
|
if seed == -1:
|
||||||
return int(random.randrange(4294967294))
|
return int(random.randrange(4294967294))
|
||||||
|
|
||||||
return seed
|
return seed
|
||||||
@@ -630,10 +673,12 @@ def create_infotext(p, all_prompts, all_seeds, all_subseeds, comments=None, iter
|
|||||||
"CFG scale": p.cfg_scale,
|
"CFG scale": p.cfg_scale,
|
||||||
"Image CFG scale": getattr(p, 'image_cfg_scale', None),
|
"Image CFG scale": getattr(p, 'image_cfg_scale', None),
|
||||||
"Seed": p.all_seeds[0] if use_main_prompt else all_seeds[index],
|
"Seed": p.all_seeds[0] if use_main_prompt else all_seeds[index],
|
||||||
"Face restoration": (opts.face_restoration_model if p.restore_faces else None),
|
"Face restoration": opts.face_restoration_model if p.restore_faces else None,
|
||||||
"Size": f"{p.width}x{p.height}",
|
"Size": f"{p.width}x{p.height}",
|
||||||
"Model hash": getattr(p, 'sd_model_hash', None if not opts.add_model_hash_to_info or not shared.sd_model.sd_model_hash else shared.sd_model.sd_model_hash),
|
"Model hash": p.sd_model_hash if opts.add_model_hash_to_info else None,
|
||||||
"Model": (None if not opts.add_model_name_to_info else shared.sd_model.sd_checkpoint_info.name_for_extra),
|
"Model": p.sd_model_name if opts.add_model_name_to_info else None,
|
||||||
|
"VAE hash": p.sd_vae_hash if opts.add_model_hash_to_info else None,
|
||||||
|
"VAE": p.sd_vae_name if opts.add_model_name_to_info else None,
|
||||||
"Variation seed": (None if p.subseed_strength == 0 else (p.all_subseeds[0] if use_main_prompt else all_subseeds[index])),
|
"Variation seed": (None if p.subseed_strength == 0 else (p.all_subseeds[0] if use_main_prompt else all_subseeds[index])),
|
||||||
"Variation seed strength": (None if p.subseed_strength == 0 else p.subseed_strength),
|
"Variation seed strength": (None if p.subseed_strength == 0 else p.subseed_strength),
|
||||||
"Seed resize from": (None if p.seed_resize_from_w <= 0 or p.seed_resize_from_h <= 0 else f"{p.seed_resize_from_w}x{p.seed_resize_from_h}"),
|
"Seed resize from": (None if p.seed_resize_from_w <= 0 or p.seed_resize_from_h <= 0 else f"{p.seed_resize_from_w}x{p.seed_resize_from_h}"),
|
||||||
@@ -646,6 +691,7 @@ def create_infotext(p, all_prompts, all_seeds, all_subseeds, comments=None, iter
|
|||||||
"Init image hash": getattr(p, 'init_img_hash', None),
|
"Init image hash": getattr(p, 'init_img_hash', None),
|
||||||
"RNG": opts.randn_source if opts.randn_source != "GPU" and opts.randn_source != "NV" else None,
|
"RNG": opts.randn_source if opts.randn_source != "GPU" and opts.randn_source != "NV" else None,
|
||||||
"NGMS": None if p.s_min_uncond == 0 else p.s_min_uncond,
|
"NGMS": None if p.s_min_uncond == 0 else p.s_min_uncond,
|
||||||
|
"Tiling": "True" if p.tiling else None,
|
||||||
**p.extra_generation_params,
|
**p.extra_generation_params,
|
||||||
"Version": program_version() if opts.add_version_to_infotext else None,
|
"Version": program_version() if opts.add_version_to_infotext else None,
|
||||||
"User": p.user if opts.add_user_name_to_info else None,
|
"User": p.user if opts.add_user_name_to_info else None,
|
||||||
@@ -653,8 +699,8 @@ def create_infotext(p, all_prompts, all_seeds, all_subseeds, comments=None, iter
|
|||||||
|
|
||||||
generation_params_text = ", ".join([k if k == v else f'{k}: {generation_parameters_copypaste.quote(v)}' for k, v in generation_params.items() if v is not None])
|
generation_params_text = ", ".join([k if k == v else f'{k}: {generation_parameters_copypaste.quote(v)}' for k, v in generation_params.items() if v is not None])
|
||||||
|
|
||||||
prompt_text = p.prompt if use_main_prompt else all_prompts[index]
|
prompt_text = p.main_prompt if use_main_prompt else all_prompts[index]
|
||||||
negative_prompt_text = f"\nNegative prompt: {all_negative_prompts[index]}" if all_negative_prompts[index] else ""
|
negative_prompt_text = f"\nNegative prompt: {p.main_negative_prompt if use_main_prompt else all_negative_prompts[index]}" if all_negative_prompts[index] else ""
|
||||||
|
|
||||||
return f"{prompt_text}{negative_prompt_text}\n{generation_params_text}".strip()
|
return f"{prompt_text}{negative_prompt_text}\n{generation_params_text}".strip()
|
||||||
|
|
||||||
@@ -667,12 +713,13 @@ def process_images(p: StableDiffusionProcessing) -> Processed:
|
|||||||
|
|
||||||
try:
|
try:
|
||||||
# if no checkpoint override or the override checkpoint can't be found, remove override entry and load opts checkpoint
|
# if no checkpoint override or the override checkpoint can't be found, remove override entry and load opts checkpoint
|
||||||
|
# and if after running refiner, the refiner model is not unloaded - webui swaps back to main model here, if model over is present it will be reloaded afterwards
|
||||||
if sd_models.checkpoint_aliases.get(p.override_settings.get('sd_model_checkpoint')) is None:
|
if sd_models.checkpoint_aliases.get(p.override_settings.get('sd_model_checkpoint')) is None:
|
||||||
p.override_settings.pop('sd_model_checkpoint', None)
|
p.override_settings.pop('sd_model_checkpoint', None)
|
||||||
sd_models.reload_model_weights()
|
sd_models.reload_model_weights()
|
||||||
|
|
||||||
for k, v in p.override_settings.items():
|
for k, v in p.override_settings.items():
|
||||||
setattr(opts, k, v)
|
opts.set(k, v, is_api=True, run_callbacks=False)
|
||||||
|
|
||||||
if k == 'sd_model_checkpoint':
|
if k == 'sd_model_checkpoint':
|
||||||
sd_models.reload_model_weights()
|
sd_models.reload_model_weights()
|
||||||
@@ -701,7 +748,7 @@ def process_images(p: StableDiffusionProcessing) -> Processed:
|
|||||||
def process_images_inner(p: StableDiffusionProcessing) -> Processed:
|
def process_images_inner(p: StableDiffusionProcessing) -> Processed:
|
||||||
"""this is the main loop that both txt2img and img2img use; it calls func_init once inside all the scopes and func_sample once per batch"""
|
"""this is the main loop that both txt2img and img2img use; it calls func_init once inside all the scopes and func_sample once per batch"""
|
||||||
|
|
||||||
if type(p.prompt) == list:
|
if isinstance(p.prompt, list):
|
||||||
assert(len(p.prompt) > 0)
|
assert(len(p.prompt) > 0)
|
||||||
else:
|
else:
|
||||||
assert p.prompt is not None
|
assert p.prompt is not None
|
||||||
@@ -711,19 +758,33 @@ def process_images_inner(p: StableDiffusionProcessing) -> Processed:
|
|||||||
seed = get_fixed_seed(p.seed)
|
seed = get_fixed_seed(p.seed)
|
||||||
subseed = get_fixed_seed(p.subseed)
|
subseed = get_fixed_seed(p.subseed)
|
||||||
|
|
||||||
|
if p.restore_faces is None:
|
||||||
|
p.restore_faces = opts.face_restoration
|
||||||
|
|
||||||
|
if p.tiling is None:
|
||||||
|
p.tiling = opts.tiling
|
||||||
|
|
||||||
|
if p.refiner_checkpoint not in (None, "", "None", "none"):
|
||||||
|
p.refiner_checkpoint_info = sd_models.get_closet_checkpoint_match(p.refiner_checkpoint)
|
||||||
|
if p.refiner_checkpoint_info is None:
|
||||||
|
raise Exception(f'Could not find checkpoint with name {p.refiner_checkpoint}')
|
||||||
|
|
||||||
|
p.sd_model_name = shared.sd_model.sd_checkpoint_info.name_for_extra
|
||||||
|
p.sd_model_hash = shared.sd_model.sd_model_hash
|
||||||
|
p.sd_vae_name = sd_vae.get_loaded_vae_name()
|
||||||
|
p.sd_vae_hash = sd_vae.get_loaded_vae_hash()
|
||||||
|
|
||||||
modules.sd_hijack.model_hijack.apply_circular(p.tiling)
|
modules.sd_hijack.model_hijack.apply_circular(p.tiling)
|
||||||
modules.sd_hijack.model_hijack.clear_comments()
|
modules.sd_hijack.model_hijack.clear_comments()
|
||||||
|
|
||||||
comments = {}
|
|
||||||
|
|
||||||
p.setup_prompts()
|
p.setup_prompts()
|
||||||
|
|
||||||
if type(seed) == list:
|
if isinstance(seed, list):
|
||||||
p.all_seeds = seed
|
p.all_seeds = seed
|
||||||
else:
|
else:
|
||||||
p.all_seeds = [int(seed) + (x if p.subseed_strength == 0 else 0) for x in range(len(p.all_prompts))]
|
p.all_seeds = [int(seed) + (x if p.subseed_strength == 0 else 0) for x in range(len(p.all_prompts))]
|
||||||
|
|
||||||
if type(subseed) == list:
|
if isinstance(subseed, list):
|
||||||
p.all_subseeds = subseed
|
p.all_subseeds = subseed
|
||||||
else:
|
else:
|
||||||
p.all_subseeds = [int(subseed) + x for x in range(len(p.all_prompts))]
|
p.all_subseeds = [int(subseed) + x for x in range(len(p.all_prompts))]
|
||||||
@@ -759,11 +820,15 @@ def process_images_inner(p: StableDiffusionProcessing) -> Processed:
|
|||||||
if state.interrupted:
|
if state.interrupted:
|
||||||
break
|
break
|
||||||
|
|
||||||
|
sd_models.reload_model_weights() # model can be changed for example by refiner
|
||||||
|
|
||||||
p.prompts = p.all_prompts[n * p.batch_size:(n + 1) * p.batch_size]
|
p.prompts = p.all_prompts[n * p.batch_size:(n + 1) * p.batch_size]
|
||||||
p.negative_prompts = p.all_negative_prompts[n * p.batch_size:(n + 1) * p.batch_size]
|
p.negative_prompts = p.all_negative_prompts[n * p.batch_size:(n + 1) * p.batch_size]
|
||||||
p.seeds = p.all_seeds[n * p.batch_size:(n + 1) * p.batch_size]
|
p.seeds = p.all_seeds[n * p.batch_size:(n + 1) * p.batch_size]
|
||||||
p.subseeds = p.all_subseeds[n * p.batch_size:(n + 1) * p.batch_size]
|
p.subseeds = p.all_subseeds[n * p.batch_size:(n + 1) * p.batch_size]
|
||||||
|
|
||||||
|
p.rng = rng.ImageRNG((opt_C, p.height // opt_f, p.width // opt_f), p.seeds, subseeds=p.subseeds, subseed_strength=p.subseed_strength, seed_resize_from_h=p.seed_resize_from_h, seed_resize_from_w=p.seed_resize_from_w)
|
||||||
|
|
||||||
if p.scripts is not None:
|
if p.scripts is not None:
|
||||||
p.scripts.before_process_batch(p, batch_number=n, prompts=p.prompts, seeds=p.seeds, subseeds=p.subseeds)
|
p.scripts.before_process_batch(p, batch_number=n, prompts=p.prompts, seeds=p.seeds, subseeds=p.subseeds)
|
||||||
|
|
||||||
@@ -785,13 +850,13 @@ def process_images_inner(p: StableDiffusionProcessing) -> Processed:
|
|||||||
# strength, which is saved as "Model Strength: 1.0" in the infotext
|
# strength, which is saved as "Model Strength: 1.0" in the infotext
|
||||||
if n == 0:
|
if n == 0:
|
||||||
with open(os.path.join(paths.data_path, "params.txt"), "w", encoding="utf8") as file:
|
with open(os.path.join(paths.data_path, "params.txt"), "w", encoding="utf8") as file:
|
||||||
processed = Processed(p, [], p.seed, "")
|
processed = Processed(p, [])
|
||||||
file.write(processed.infotext(p, 0))
|
file.write(processed.infotext(p, 0))
|
||||||
|
|
||||||
p.setup_conds()
|
p.setup_conds()
|
||||||
|
|
||||||
for comment in model_hijack.comments:
|
for comment in model_hijack.comments:
|
||||||
comments[comment] = 1
|
p.comment(comment)
|
||||||
|
|
||||||
p.extra_generation_params.update(model_hijack.extra_generation_params)
|
p.extra_generation_params.update(model_hijack.extra_generation_params)
|
||||||
|
|
||||||
@@ -920,7 +985,6 @@ def process_images_inner(p: StableDiffusionProcessing) -> Processed:
|
|||||||
images_list=output_images,
|
images_list=output_images,
|
||||||
seed=p.all_seeds[0],
|
seed=p.all_seeds[0],
|
||||||
info=infotexts[0],
|
info=infotexts[0],
|
||||||
comments="".join(f"{comment}\n" for comment in comments),
|
|
||||||
subseed=p.all_subseeds[0],
|
subseed=p.all_subseeds[0],
|
||||||
index_of_first_image=index_of_first_image,
|
index_of_first_image=index_of_first_image,
|
||||||
infotexts=infotexts,
|
infotexts=infotexts,
|
||||||
@@ -944,74 +1008,53 @@ def old_hires_fix_first_pass_dimensions(width, height):
|
|||||||
return width, height
|
return width, height
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass(repr=False)
|
||||||
class StableDiffusionProcessingTxt2Img(StableDiffusionProcessing):
|
class StableDiffusionProcessingTxt2Img(StableDiffusionProcessing):
|
||||||
sampler = None
|
enable_hr: bool = False
|
||||||
|
denoising_strength: float = 0.75
|
||||||
|
firstphase_width: int = 0
|
||||||
|
firstphase_height: int = 0
|
||||||
|
hr_scale: float = 2.0
|
||||||
|
hr_upscaler: str = None
|
||||||
|
hr_second_pass_steps: int = 0
|
||||||
|
hr_resize_x: int = 0
|
||||||
|
hr_resize_y: int = 0
|
||||||
|
hr_checkpoint_name: str = None
|
||||||
|
hr_sampler_name: str = None
|
||||||
|
hr_prompt: str = ''
|
||||||
|
hr_negative_prompt: str = ''
|
||||||
|
|
||||||
cached_hr_uc = [None, None]
|
cached_hr_uc = [None, None]
|
||||||
cached_hr_c = [None, None]
|
cached_hr_c = [None, None]
|
||||||
|
|
||||||
def __init__(self, enable_hr: bool = False, denoising_strength: float = 0.75, firstphase_width: int = 0, firstphase_height: int = 0, hr_scale: float = 2.0, hr_upscaler: str = None, hr_second_pass_steps: int = 0, hr_resize_x: int = 0, hr_resize_y: int = 0, hr_checkpoint_name: str = None, hr_sampler_name: str = None, hr_prompt: str = '', hr_negative_prompt: str = '', **kwargs):
|
hr_checkpoint_info: dict = field(default=None, init=False)
|
||||||
super().__init__(**kwargs)
|
hr_upscale_to_x: int = field(default=0, init=False)
|
||||||
self.enable_hr = enable_hr
|
hr_upscale_to_y: int = field(default=0, init=False)
|
||||||
self.denoising_strength = denoising_strength
|
truncate_x: int = field(default=0, init=False)
|
||||||
self.hr_scale = hr_scale
|
truncate_y: int = field(default=0, init=False)
|
||||||
self.hr_upscaler = hr_upscaler
|
applied_old_hires_behavior_to: tuple = field(default=None, init=False)
|
||||||
self.hr_second_pass_steps = hr_second_pass_steps
|
latent_scale_mode: dict = field(default=None, init=False)
|
||||||
self.hr_resize_x = hr_resize_x
|
hr_c: tuple | None = field(default=None, init=False)
|
||||||
self.hr_resize_y = hr_resize_y
|
hr_uc: tuple | None = field(default=None, init=False)
|
||||||
self.hr_upscale_to_x = hr_resize_x
|
all_hr_prompts: list = field(default=None, init=False)
|
||||||
self.hr_upscale_to_y = hr_resize_y
|
all_hr_negative_prompts: list = field(default=None, init=False)
|
||||||
self.hr_checkpoint_name = hr_checkpoint_name
|
hr_prompts: list = field(default=None, init=False)
|
||||||
self.hr_checkpoint_info = None
|
hr_negative_prompts: list = field(default=None, init=False)
|
||||||
self.hr_sampler_name = hr_sampler_name
|
hr_extra_network_data: list = field(default=None, init=False)
|
||||||
self.hr_prompt = hr_prompt
|
|
||||||
self.hr_negative_prompt = hr_negative_prompt
|
|
||||||
self.all_hr_prompts = None
|
|
||||||
self.all_hr_negative_prompts = None
|
|
||||||
self.latent_scale_mode = None
|
|
||||||
|
|
||||||
if firstphase_width != 0 or firstphase_height != 0:
|
def __post_init__(self):
|
||||||
|
super().__post_init__()
|
||||||
|
|
||||||
|
if self.firstphase_width != 0 or self.firstphase_height != 0:
|
||||||
self.hr_upscale_to_x = self.width
|
self.hr_upscale_to_x = self.width
|
||||||
self.hr_upscale_to_y = self.height
|
self.hr_upscale_to_y = self.height
|
||||||
self.width = firstphase_width
|
self.width = self.firstphase_width
|
||||||
self.height = firstphase_height
|
self.height = self.firstphase_height
|
||||||
|
|
||||||
self.truncate_x = 0
|
|
||||||
self.truncate_y = 0
|
|
||||||
self.applied_old_hires_behavior_to = None
|
|
||||||
|
|
||||||
self.hr_prompts = None
|
|
||||||
self.hr_negative_prompts = None
|
|
||||||
self.hr_extra_network_data = None
|
|
||||||
|
|
||||||
self.cached_hr_uc = StableDiffusionProcessingTxt2Img.cached_hr_uc
|
self.cached_hr_uc = StableDiffusionProcessingTxt2Img.cached_hr_uc
|
||||||
self.cached_hr_c = StableDiffusionProcessingTxt2Img.cached_hr_c
|
self.cached_hr_c = StableDiffusionProcessingTxt2Img.cached_hr_c
|
||||||
self.hr_c = None
|
|
||||||
self.hr_uc = None
|
|
||||||
|
|
||||||
def init(self, all_prompts, all_seeds, all_subseeds):
|
|
||||||
if self.enable_hr:
|
|
||||||
if self.hr_checkpoint_name:
|
|
||||||
self.hr_checkpoint_info = sd_models.get_closet_checkpoint_match(self.hr_checkpoint_name)
|
|
||||||
|
|
||||||
if self.hr_checkpoint_info is None:
|
|
||||||
raise Exception(f'Could not find checkpoint with name {self.hr_checkpoint_name}')
|
|
||||||
|
|
||||||
self.extra_generation_params["Hires checkpoint"] = self.hr_checkpoint_info.short_title
|
|
||||||
|
|
||||||
if self.hr_sampler_name is not None and self.hr_sampler_name != self.sampler_name:
|
|
||||||
self.extra_generation_params["Hires sampler"] = self.hr_sampler_name
|
|
||||||
|
|
||||||
if tuple(self.hr_prompt) != tuple(self.prompt):
|
|
||||||
self.extra_generation_params["Hires prompt"] = self.hr_prompt
|
|
||||||
|
|
||||||
if tuple(self.hr_negative_prompt) != tuple(self.negative_prompt):
|
|
||||||
self.extra_generation_params["Hires negative prompt"] = self.hr_negative_prompt
|
|
||||||
|
|
||||||
self.latent_scale_mode = shared.latent_upscale_modes.get(self.hr_upscaler, None) if self.hr_upscaler is not None else shared.latent_upscale_modes.get(shared.latent_upscale_default_mode, "nearest")
|
|
||||||
if self.enable_hr and self.latent_scale_mode is None:
|
|
||||||
if not any(x.name == self.hr_upscaler for x in shared.sd_upscalers):
|
|
||||||
raise Exception(f"could not find upscaler named {self.hr_upscaler}")
|
|
||||||
|
|
||||||
|
def calculate_target_resolution(self):
|
||||||
if opts.use_old_hires_fix_width_height and self.applied_old_hires_behavior_to != (self.width, self.height):
|
if opts.use_old_hires_fix_width_height and self.applied_old_hires_behavior_to != (self.width, self.height):
|
||||||
self.hr_resize_x = self.width
|
self.hr_resize_x = self.width
|
||||||
self.hr_resize_y = self.height
|
self.hr_resize_y = self.height
|
||||||
@@ -1050,6 +1093,32 @@ class StableDiffusionProcessingTxt2Img(StableDiffusionProcessing):
|
|||||||
self.truncate_x = (self.hr_upscale_to_x - target_w) // opt_f
|
self.truncate_x = (self.hr_upscale_to_x - target_w) // opt_f
|
||||||
self.truncate_y = (self.hr_upscale_to_y - target_h) // opt_f
|
self.truncate_y = (self.hr_upscale_to_y - target_h) // opt_f
|
||||||
|
|
||||||
|
def init(self, all_prompts, all_seeds, all_subseeds):
|
||||||
|
if self.enable_hr:
|
||||||
|
if self.hr_checkpoint_name:
|
||||||
|
self.hr_checkpoint_info = sd_models.get_closet_checkpoint_match(self.hr_checkpoint_name)
|
||||||
|
|
||||||
|
if self.hr_checkpoint_info is None:
|
||||||
|
raise Exception(f'Could not find checkpoint with name {self.hr_checkpoint_name}')
|
||||||
|
|
||||||
|
self.extra_generation_params["Hires checkpoint"] = self.hr_checkpoint_info.short_title
|
||||||
|
|
||||||
|
if self.hr_sampler_name is not None and self.hr_sampler_name != self.sampler_name:
|
||||||
|
self.extra_generation_params["Hires sampler"] = self.hr_sampler_name
|
||||||
|
|
||||||
|
if tuple(self.hr_prompt) != tuple(self.prompt):
|
||||||
|
self.extra_generation_params["Hires prompt"] = self.hr_prompt
|
||||||
|
|
||||||
|
if tuple(self.hr_negative_prompt) != tuple(self.negative_prompt):
|
||||||
|
self.extra_generation_params["Hires negative prompt"] = self.hr_negative_prompt
|
||||||
|
|
||||||
|
self.latent_scale_mode = shared.latent_upscale_modes.get(self.hr_upscaler, None) if self.hr_upscaler is not None else shared.latent_upscale_modes.get(shared.latent_upscale_default_mode, "nearest")
|
||||||
|
if self.enable_hr and self.latent_scale_mode is None:
|
||||||
|
if not any(x.name == self.hr_upscaler for x in shared.sd_upscalers):
|
||||||
|
raise Exception(f"could not find upscaler named {self.hr_upscaler}")
|
||||||
|
|
||||||
|
self.calculate_target_resolution()
|
||||||
|
|
||||||
if not state.processing_has_refined_job_count:
|
if not state.processing_has_refined_job_count:
|
||||||
if state.job_count == -1:
|
if state.job_count == -1:
|
||||||
state.job_count = self.n_iter
|
state.job_count = self.n_iter
|
||||||
@@ -1067,7 +1136,7 @@ class StableDiffusionProcessingTxt2Img(StableDiffusionProcessing):
|
|||||||
def sample(self, conditioning, unconditional_conditioning, seeds, subseeds, subseed_strength, prompts):
|
def sample(self, conditioning, unconditional_conditioning, seeds, subseeds, subseed_strength, prompts):
|
||||||
self.sampler = sd_samplers.create_sampler(self.sampler_name, self.sd_model)
|
self.sampler = sd_samplers.create_sampler(self.sampler_name, self.sd_model)
|
||||||
|
|
||||||
x = create_random_tensors([opt_C, self.height // opt_f, self.width // opt_f], seeds=seeds, subseeds=subseeds, subseed_strength=self.subseed_strength, seed_resize_from_h=self.seed_resize_from_h, seed_resize_from_w=self.seed_resize_from_w, p=self)
|
x = self.rng.next()
|
||||||
samples = self.sampler.sample(self, x, conditioning, unconditional_conditioning, image_conditioning=self.txt2img_image_conditioning(x))
|
samples = self.sampler.sample(self, x, conditioning, unconditional_conditioning, image_conditioning=self.txt2img_image_conditioning(x))
|
||||||
del x
|
del x
|
||||||
|
|
||||||
@@ -1093,6 +1162,9 @@ class StableDiffusionProcessingTxt2Img(StableDiffusionProcessing):
|
|||||||
devices.torch_gc()
|
devices.torch_gc()
|
||||||
|
|
||||||
def sample_hr_pass(self, samples, decoded_samples, seeds, subseeds, subseed_strength, prompts):
|
def sample_hr_pass(self, samples, decoded_samples, seeds, subseeds, subseed_strength, prompts):
|
||||||
|
if shared.state.interrupted:
|
||||||
|
return samples
|
||||||
|
|
||||||
self.is_hr_pass = True
|
self.is_hr_pass = True
|
||||||
|
|
||||||
target_width = self.hr_upscale_to_x
|
target_width = self.hr_upscale_to_x
|
||||||
@@ -1112,9 +1184,6 @@ class StableDiffusionProcessingTxt2Img(StableDiffusionProcessing):
|
|||||||
|
|
||||||
img2img_sampler_name = self.hr_sampler_name or self.sampler_name
|
img2img_sampler_name = self.hr_sampler_name or self.sampler_name
|
||||||
|
|
||||||
if self.sampler_name in ['PLMS', 'UniPC']: # PLMS/UniPC do not support img2img so we just silently switch to DDIM
|
|
||||||
img2img_sampler_name = 'DDIM'
|
|
||||||
|
|
||||||
self.sampler = sd_samplers.create_sampler(img2img_sampler_name, self.sd_model)
|
self.sampler = sd_samplers.create_sampler(img2img_sampler_name, self.sd_model)
|
||||||
|
|
||||||
if self.latent_scale_mode is not None:
|
if self.latent_scale_mode is not None:
|
||||||
@@ -1158,7 +1227,8 @@ class StableDiffusionProcessingTxt2Img(StableDiffusionProcessing):
|
|||||||
|
|
||||||
samples = samples[:, :, self.truncate_y//2:samples.shape[2]-(self.truncate_y+1)//2, self.truncate_x//2:samples.shape[3]-(self.truncate_x+1)//2]
|
samples = samples[:, :, self.truncate_y//2:samples.shape[2]-(self.truncate_y+1)//2, self.truncate_x//2:samples.shape[3]-(self.truncate_x+1)//2]
|
||||||
|
|
||||||
noise = create_random_tensors(samples.shape[1:], seeds=seeds, subseeds=subseeds, subseed_strength=subseed_strength, p=self)
|
self.rng = rng.ImageRNG(samples.shape[1:], self.seeds, subseeds=self.subseeds, subseed_strength=self.subseed_strength, seed_resize_from_h=self.seed_resize_from_h, seed_resize_from_w=self.seed_resize_from_w)
|
||||||
|
noise = self.rng.next()
|
||||||
|
|
||||||
# GC now before running the next img2img to prevent running out of memory
|
# GC now before running the next img2img to prevent running out of memory
|
||||||
devices.torch_gc()
|
devices.torch_gc()
|
||||||
@@ -1179,6 +1249,9 @@ class StableDiffusionProcessingTxt2Img(StableDiffusionProcessing):
|
|||||||
|
|
||||||
sd_models.apply_token_merging(self.sd_model, self.get_token_merging_ratio())
|
sd_models.apply_token_merging(self.sd_model, self.get_token_merging_ratio())
|
||||||
|
|
||||||
|
self.sampler = None
|
||||||
|
devices.torch_gc()
|
||||||
|
|
||||||
decoded_samples = decode_latent_batch(self.sd_model, samples, target_device=devices.cpu, check_for_nans=True)
|
decoded_samples = decode_latent_batch(self.sd_model, samples, target_device=devices.cpu, check_for_nans=True)
|
||||||
|
|
||||||
self.is_hr_pass = False
|
self.is_hr_pass = False
|
||||||
@@ -1205,12 +1278,12 @@ class StableDiffusionProcessingTxt2Img(StableDiffusionProcessing):
|
|||||||
if self.hr_negative_prompt == '':
|
if self.hr_negative_prompt == '':
|
||||||
self.hr_negative_prompt = self.negative_prompt
|
self.hr_negative_prompt = self.negative_prompt
|
||||||
|
|
||||||
if type(self.hr_prompt) == list:
|
if isinstance(self.hr_prompt, list):
|
||||||
self.all_hr_prompts = self.hr_prompt
|
self.all_hr_prompts = self.hr_prompt
|
||||||
else:
|
else:
|
||||||
self.all_hr_prompts = self.batch_size * self.n_iter * [self.hr_prompt]
|
self.all_hr_prompts = self.batch_size * self.n_iter * [self.hr_prompt]
|
||||||
|
|
||||||
if type(self.hr_negative_prompt) == list:
|
if isinstance(self.hr_negative_prompt, list):
|
||||||
self.all_hr_negative_prompts = self.hr_negative_prompt
|
self.all_hr_negative_prompts = self.hr_negative_prompt
|
||||||
else:
|
else:
|
||||||
self.all_hr_negative_prompts = self.batch_size * self.n_iter * [self.hr_negative_prompt]
|
self.all_hr_negative_prompts = self.batch_size * self.n_iter * [self.hr_negative_prompt]
|
||||||
@@ -1225,10 +1298,20 @@ class StableDiffusionProcessingTxt2Img(StableDiffusionProcessing):
|
|||||||
hr_prompts = prompt_parser.SdConditioning(self.hr_prompts, width=self.hr_upscale_to_x, height=self.hr_upscale_to_y)
|
hr_prompts = prompt_parser.SdConditioning(self.hr_prompts, width=self.hr_upscale_to_x, height=self.hr_upscale_to_y)
|
||||||
hr_negative_prompts = prompt_parser.SdConditioning(self.hr_negative_prompts, width=self.hr_upscale_to_x, height=self.hr_upscale_to_y, is_negative_prompt=True)
|
hr_negative_prompts = prompt_parser.SdConditioning(self.hr_negative_prompts, width=self.hr_upscale_to_x, height=self.hr_upscale_to_y, is_negative_prompt=True)
|
||||||
|
|
||||||
self.hr_uc = self.get_conds_with_caching(prompt_parser.get_learned_conditioning, hr_negative_prompts, self.steps * self.step_multiplier, [self.cached_hr_uc, self.cached_uc], self.hr_extra_network_data)
|
sampler_config = sd_samplers.find_sampler_config(self.hr_sampler_name or self.sampler_name)
|
||||||
self.hr_c = self.get_conds_with_caching(prompt_parser.get_multicond_learned_conditioning, hr_prompts, self.steps * self.step_multiplier, [self.cached_hr_c, self.cached_c], self.hr_extra_network_data)
|
steps = self.hr_second_pass_steps or self.steps
|
||||||
|
total_steps = sampler_config.total_steps(steps) if sampler_config else steps
|
||||||
|
|
||||||
|
self.hr_uc = self.get_conds_with_caching(prompt_parser.get_learned_conditioning, hr_negative_prompts, self.firstpass_steps, [self.cached_hr_uc, self.cached_uc], self.hr_extra_network_data, total_steps)
|
||||||
|
self.hr_c = self.get_conds_with_caching(prompt_parser.get_multicond_learned_conditioning, hr_prompts, self.firstpass_steps, [self.cached_hr_c, self.cached_c], self.hr_extra_network_data, total_steps)
|
||||||
|
|
||||||
def setup_conds(self):
|
def setup_conds(self):
|
||||||
|
if self.is_hr_pass:
|
||||||
|
# if we are in hr pass right now, the call is being made from the refiner, and we don't need to setup firstpass cons or switch model
|
||||||
|
self.hr_c = None
|
||||||
|
self.calculate_hr_conds()
|
||||||
|
return
|
||||||
|
|
||||||
super().setup_conds()
|
super().setup_conds()
|
||||||
|
|
||||||
self.hr_uc = None
|
self.hr_uc = None
|
||||||
@@ -1247,6 +1330,12 @@ class StableDiffusionProcessingTxt2Img(StableDiffusionProcessing):
|
|||||||
with devices.autocast():
|
with devices.autocast():
|
||||||
extra_networks.activate(self, self.extra_network_data)
|
extra_networks.activate(self, self.extra_network_data)
|
||||||
|
|
||||||
|
def get_conds(self):
|
||||||
|
if self.is_hr_pass:
|
||||||
|
return self.hr_c, self.hr_uc
|
||||||
|
|
||||||
|
return super().get_conds()
|
||||||
|
|
||||||
def parse_extra_network_prompts(self):
|
def parse_extra_network_prompts(self):
|
||||||
res = super().parse_extra_network_prompts()
|
res = super().parse_extra_network_prompts()
|
||||||
|
|
||||||
@@ -1259,55 +1348,75 @@ class StableDiffusionProcessingTxt2Img(StableDiffusionProcessing):
|
|||||||
return res
|
return res
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass(repr=False)
|
||||||
class StableDiffusionProcessingImg2Img(StableDiffusionProcessing):
|
class StableDiffusionProcessingImg2Img(StableDiffusionProcessing):
|
||||||
sampler = None
|
init_images: list = None
|
||||||
|
resize_mode: int = 0
|
||||||
|
denoising_strength: float = 0.75
|
||||||
|
image_cfg_scale: float = None
|
||||||
|
mask: Any = None
|
||||||
|
mask_blur_x: int = 4
|
||||||
|
mask_blur_y: int = 4
|
||||||
|
mask_blur: int = None
|
||||||
|
inpainting_fill: int = 0
|
||||||
|
inpaint_full_res: bool = True
|
||||||
|
inpaint_full_res_padding: int = 0
|
||||||
|
inpainting_mask_invert: int = 0
|
||||||
|
initial_noise_multiplier: float = None
|
||||||
|
latent_mask: Image = None
|
||||||
|
|
||||||
def __init__(self, init_images: list = None, resize_mode: int = 0, denoising_strength: float = 0.75, image_cfg_scale: float = None, mask: Any = None, mask_blur: int = None, mask_blur_x: int = 4, mask_blur_y: int = 4, inpainting_fill: int = 0, inpaint_full_res: bool = True, inpaint_full_res_padding: int = 0, inpainting_mask_invert: int = 0, initial_noise_multiplier: float = None, **kwargs):
|
image_mask: Any = field(default=None, init=False)
|
||||||
super().__init__(**kwargs)
|
|
||||||
|
|
||||||
self.init_images = init_images
|
nmask: torch.Tensor = field(default=None, init=False)
|
||||||
self.resize_mode: int = resize_mode
|
image_conditioning: torch.Tensor = field(default=None, init=False)
|
||||||
self.denoising_strength: float = denoising_strength
|
init_img_hash: str = field(default=None, init=False)
|
||||||
self.image_cfg_scale: float = image_cfg_scale if shared.sd_model.cond_stage_key == "edit" else None
|
mask_for_overlay: Image = field(default=None, init=False)
|
||||||
self.init_latent = None
|
init_latent: torch.Tensor = field(default=None, init=False)
|
||||||
self.image_mask = mask
|
|
||||||
self.latent_mask = None
|
def __post_init__(self):
|
||||||
self.mask_for_overlay = None
|
super().__post_init__()
|
||||||
if mask_blur is not None:
|
|
||||||
mask_blur_x = mask_blur
|
self.image_mask = self.mask
|
||||||
mask_blur_y = mask_blur
|
|
||||||
self.mask_blur_x = mask_blur_x
|
|
||||||
self.mask_blur_y = mask_blur_y
|
|
||||||
self.inpainting_fill = inpainting_fill
|
|
||||||
self.inpaint_full_res = inpaint_full_res
|
|
||||||
self.inpaint_full_res_padding = inpaint_full_res_padding
|
|
||||||
self.inpainting_mask_invert = inpainting_mask_invert
|
|
||||||
self.initial_noise_multiplier = opts.initial_noise_multiplier if initial_noise_multiplier is None else initial_noise_multiplier
|
|
||||||
self.mask = None
|
self.mask = None
|
||||||
self.nmask = None
|
self.initial_noise_multiplier = opts.initial_noise_multiplier if self.initial_noise_multiplier is None else self.initial_noise_multiplier
|
||||||
self.image_conditioning = None
|
|
||||||
|
@property
|
||||||
|
def mask_blur(self):
|
||||||
|
if self.mask_blur_x == self.mask_blur_y:
|
||||||
|
return self.mask_blur_x
|
||||||
|
return None
|
||||||
|
|
||||||
|
@mask_blur.setter
|
||||||
|
def mask_blur(self, value):
|
||||||
|
if isinstance(value, int):
|
||||||
|
self.mask_blur_x = value
|
||||||
|
self.mask_blur_y = value
|
||||||
|
|
||||||
def init(self, all_prompts, all_seeds, all_subseeds):
|
def init(self, all_prompts, all_seeds, all_subseeds):
|
||||||
|
self.image_cfg_scale: float = self.image_cfg_scale if shared.sd_model.cond_stage_key == "edit" else None
|
||||||
|
|
||||||
self.sampler = sd_samplers.create_sampler(self.sampler_name, self.sd_model)
|
self.sampler = sd_samplers.create_sampler(self.sampler_name, self.sd_model)
|
||||||
crop_region = None
|
crop_region = None
|
||||||
|
|
||||||
image_mask = self.image_mask
|
image_mask = self.image_mask
|
||||||
|
|
||||||
if image_mask is not None:
|
if image_mask is not None:
|
||||||
image_mask = image_mask.convert('L')
|
# image_mask is passed in as RGBA by Gradio to support alpha masks,
|
||||||
|
# but we still want to support binary masks.
|
||||||
|
image_mask = create_binary_mask(image_mask)
|
||||||
|
|
||||||
if self.inpainting_mask_invert:
|
if self.inpainting_mask_invert:
|
||||||
image_mask = ImageOps.invert(image_mask)
|
image_mask = ImageOps.invert(image_mask)
|
||||||
|
|
||||||
if self.mask_blur_x > 0:
|
if self.mask_blur_x > 0:
|
||||||
np_mask = np.array(image_mask)
|
np_mask = np.array(image_mask)
|
||||||
kernel_size = 2 * int(4 * self.mask_blur_x + 0.5) + 1
|
kernel_size = 2 * int(2.5 * self.mask_blur_x + 0.5) + 1
|
||||||
np_mask = cv2.GaussianBlur(np_mask, (kernel_size, 1), self.mask_blur_x)
|
np_mask = cv2.GaussianBlur(np_mask, (kernel_size, 1), self.mask_blur_x)
|
||||||
image_mask = Image.fromarray(np_mask)
|
image_mask = Image.fromarray(np_mask)
|
||||||
|
|
||||||
if self.mask_blur_y > 0:
|
if self.mask_blur_y > 0:
|
||||||
np_mask = np.array(image_mask)
|
np_mask = np.array(image_mask)
|
||||||
kernel_size = 2 * int(4 * self.mask_blur_y + 0.5) + 1
|
kernel_size = 2 * int(2.5 * self.mask_blur_y + 0.5) + 1
|
||||||
np_mask = cv2.GaussianBlur(np_mask, (1, kernel_size), self.mask_blur_y)
|
np_mask = cv2.GaussianBlur(np_mask, (1, kernel_size), self.mask_blur_y)
|
||||||
image_mask = Image.fromarray(np_mask)
|
image_mask = Image.fromarray(np_mask)
|
||||||
|
|
||||||
@@ -1413,10 +1522,10 @@ class StableDiffusionProcessingImg2Img(StableDiffusionProcessing):
|
|||||||
elif self.inpainting_fill == 3:
|
elif self.inpainting_fill == 3:
|
||||||
self.init_latent = self.init_latent * self.mask
|
self.init_latent = self.init_latent * self.mask
|
||||||
|
|
||||||
self.image_conditioning = self.img2img_image_conditioning(image, self.init_latent, image_mask)
|
self.image_conditioning = self.img2img_image_conditioning(image * 2 - 1, self.init_latent, image_mask)
|
||||||
|
|
||||||
def sample(self, conditioning, unconditional_conditioning, seeds, subseeds, subseed_strength, prompts):
|
def sample(self, conditioning, unconditional_conditioning, seeds, subseeds, subseed_strength, prompts):
|
||||||
x = create_random_tensors([opt_C, self.height // opt_f, self.width // opt_f], seeds=seeds, subseeds=subseeds, subseed_strength=self.subseed_strength, seed_resize_from_h=self.seed_resize_from_h, seed_resize_from_w=self.seed_resize_from_w, p=self)
|
x = self.rng.next()
|
||||||
|
|
||||||
if self.initial_noise_multiplier != 1.0:
|
if self.initial_noise_multiplier != 1.0:
|
||||||
self.extra_generation_params["Noise multiplier"] = self.initial_noise_multiplier
|
self.extra_generation_params["Noise multiplier"] = self.initial_noise_multiplier
|
||||||
|
|||||||
@@ -0,0 +1,49 @@
|
|||||||
|
import gradio as gr
|
||||||
|
|
||||||
|
from modules import scripts, sd_models
|
||||||
|
from modules.ui_common import create_refresh_button
|
||||||
|
from modules.ui_components import InputAccordion
|
||||||
|
|
||||||
|
|
||||||
|
class ScriptRefiner(scripts.ScriptBuiltinUI):
|
||||||
|
section = "accordions"
|
||||||
|
create_group = False
|
||||||
|
|
||||||
|
def __init__(self):
|
||||||
|
pass
|
||||||
|
|
||||||
|
def title(self):
|
||||||
|
return "Refiner"
|
||||||
|
|
||||||
|
def show(self, is_img2img):
|
||||||
|
return scripts.AlwaysVisible
|
||||||
|
|
||||||
|
def ui(self, is_img2img):
|
||||||
|
with InputAccordion(False, label="Refiner", elem_id=self.elem_id("enable")) as enable_refiner:
|
||||||
|
with gr.Row():
|
||||||
|
refiner_checkpoint = gr.Dropdown(label='Checkpoint', elem_id=self.elem_id("checkpoint"), choices=sd_models.checkpoint_tiles(), value='', tooltip="switch to another model in the middle of generation")
|
||||||
|
create_refresh_button(refiner_checkpoint, sd_models.list_models, lambda: {"choices": sd_models.checkpoint_tiles()}, self.elem_id("checkpoint_refresh"))
|
||||||
|
|
||||||
|
refiner_switch_at = gr.Slider(value=0.8, label="Switch at", minimum=0.01, maximum=1.0, step=0.01, elem_id=self.elem_id("switch_at"), tooltip="fraction of sampling steps when the switch to refiner model should happen; 1=never, 0.5=switch in the middle of generation")
|
||||||
|
|
||||||
|
def lookup_checkpoint(title):
|
||||||
|
info = sd_models.get_closet_checkpoint_match(title)
|
||||||
|
return None if info is None else info.title
|
||||||
|
|
||||||
|
self.infotext_fields = [
|
||||||
|
(enable_refiner, lambda d: 'Refiner' in d),
|
||||||
|
(refiner_checkpoint, lambda d: lookup_checkpoint(d.get('Refiner'))),
|
||||||
|
(refiner_switch_at, 'Refiner switch at'),
|
||||||
|
]
|
||||||
|
|
||||||
|
return enable_refiner, refiner_checkpoint, refiner_switch_at
|
||||||
|
|
||||||
|
def setup(self, p, enable_refiner, refiner_checkpoint, refiner_switch_at):
|
||||||
|
# the actual implementation is in sd_samplers_common.py, apply_refiner
|
||||||
|
|
||||||
|
if not enable_refiner or refiner_checkpoint in (None, "", "None"):
|
||||||
|
p.refiner_checkpoint = None
|
||||||
|
p.refiner_switch_at = None
|
||||||
|
else:
|
||||||
|
p.refiner_checkpoint = refiner_checkpoint
|
||||||
|
p.refiner_switch_at = refiner_switch_at
|
||||||
@@ -0,0 +1,111 @@
|
|||||||
|
import json
|
||||||
|
|
||||||
|
import gradio as gr
|
||||||
|
|
||||||
|
from modules import scripts, ui, errors
|
||||||
|
from modules.shared import cmd_opts
|
||||||
|
from modules.ui_components import ToolButton
|
||||||
|
|
||||||
|
|
||||||
|
class ScriptSeed(scripts.ScriptBuiltinUI):
|
||||||
|
section = "seed"
|
||||||
|
create_group = False
|
||||||
|
|
||||||
|
def __init__(self):
|
||||||
|
self.seed = None
|
||||||
|
self.reuse_seed = None
|
||||||
|
self.reuse_subseed = None
|
||||||
|
|
||||||
|
def title(self):
|
||||||
|
return "Seed"
|
||||||
|
|
||||||
|
def show(self, is_img2img):
|
||||||
|
return scripts.AlwaysVisible
|
||||||
|
|
||||||
|
def ui(self, is_img2img):
|
||||||
|
with gr.Row(elem_id=self.elem_id("seed_row")):
|
||||||
|
if cmd_opts.use_textbox_seed:
|
||||||
|
self.seed = gr.Textbox(label='Seed', value="", elem_id=self.elem_id("seed"), min_width=100)
|
||||||
|
else:
|
||||||
|
self.seed = gr.Number(label='Seed', value=-1, elem_id=self.elem_id("seed"), min_width=100, precision=0)
|
||||||
|
|
||||||
|
random_seed = ToolButton(ui.random_symbol, elem_id=self.elem_id("random_seed"), label='Random seed')
|
||||||
|
reuse_seed = ToolButton(ui.reuse_symbol, elem_id=self.elem_id("reuse_seed"), label='Reuse seed')
|
||||||
|
|
||||||
|
seed_checkbox = gr.Checkbox(label='Extra', elem_id=self.elem_id("subseed_show"), value=False)
|
||||||
|
|
||||||
|
with gr.Group(visible=False, elem_id=self.elem_id("seed_extras")) as seed_extras:
|
||||||
|
with gr.Row(elem_id=self.elem_id("subseed_row")):
|
||||||
|
subseed = gr.Number(label='Variation seed', value=-1, elem_id=self.elem_id("subseed"), precision=0)
|
||||||
|
random_subseed = ToolButton(ui.random_symbol, elem_id=self.elem_id("random_subseed"))
|
||||||
|
reuse_subseed = ToolButton(ui.reuse_symbol, elem_id=self.elem_id("reuse_subseed"))
|
||||||
|
subseed_strength = gr.Slider(label='Variation strength', value=0.0, minimum=0, maximum=1, step=0.01, elem_id=self.elem_id("subseed_strength"))
|
||||||
|
|
||||||
|
with gr.Row(elem_id=self.elem_id("seed_resize_from_row")):
|
||||||
|
seed_resize_from_w = gr.Slider(minimum=0, maximum=2048, step=8, label="Resize seed from width", value=0, elem_id=self.elem_id("seed_resize_from_w"))
|
||||||
|
seed_resize_from_h = gr.Slider(minimum=0, maximum=2048, step=8, label="Resize seed from height", value=0, elem_id=self.elem_id("seed_resize_from_h"))
|
||||||
|
|
||||||
|
random_seed.click(fn=None, _js="function(){setRandomSeed('" + self.elem_id("seed") + "')}", show_progress=False, inputs=[], outputs=[])
|
||||||
|
random_subseed.click(fn=None, _js="function(){setRandomSeed('" + self.elem_id("subseed") + "')}", show_progress=False, inputs=[], outputs=[])
|
||||||
|
|
||||||
|
seed_checkbox.change(lambda x: gr.update(visible=x), show_progress=False, inputs=[seed_checkbox], outputs=[seed_extras])
|
||||||
|
|
||||||
|
self.infotext_fields = [
|
||||||
|
(self.seed, "Seed"),
|
||||||
|
(seed_checkbox, lambda d: "Variation seed" in d or "Seed resize from-1" in d),
|
||||||
|
(subseed, "Variation seed"),
|
||||||
|
(subseed_strength, "Variation seed strength"),
|
||||||
|
(seed_resize_from_w, "Seed resize from-1"),
|
||||||
|
(seed_resize_from_h, "Seed resize from-2"),
|
||||||
|
]
|
||||||
|
|
||||||
|
self.on_after_component(lambda x: connect_reuse_seed(self.seed, reuse_seed, x.component, False), elem_id=f'generation_info_{self.tabname}')
|
||||||
|
self.on_after_component(lambda x: connect_reuse_seed(subseed, reuse_subseed, x.component, True), elem_id=f'generation_info_{self.tabname}')
|
||||||
|
|
||||||
|
return self.seed, seed_checkbox, subseed, subseed_strength, seed_resize_from_w, seed_resize_from_h
|
||||||
|
|
||||||
|
def setup(self, p, seed, seed_checkbox, subseed, subseed_strength, seed_resize_from_w, seed_resize_from_h):
|
||||||
|
p.seed = seed
|
||||||
|
|
||||||
|
if seed_checkbox and subseed_strength > 0:
|
||||||
|
p.subseed = subseed
|
||||||
|
p.subseed_strength = subseed_strength
|
||||||
|
|
||||||
|
if seed_checkbox and seed_resize_from_w > 0 and seed_resize_from_h > 0:
|
||||||
|
p.seed_resize_from_w = seed_resize_from_w
|
||||||
|
p.seed_resize_from_h = seed_resize_from_h
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
def connect_reuse_seed(seed: gr.Number, reuse_seed: gr.Button, generation_info: gr.Textbox, is_subseed):
|
||||||
|
""" Connects a 'reuse (sub)seed' button's click event so that it copies last used
|
||||||
|
(sub)seed value from generation info the to the seed field. If copying subseed and subseed strength
|
||||||
|
was 0, i.e. no variation seed was used, it copies the normal seed value instead."""
|
||||||
|
|
||||||
|
def copy_seed(gen_info_string: str, index):
|
||||||
|
res = -1
|
||||||
|
|
||||||
|
try:
|
||||||
|
gen_info = json.loads(gen_info_string)
|
||||||
|
index -= gen_info.get('index_of_first_image', 0)
|
||||||
|
|
||||||
|
if is_subseed and gen_info.get('subseed_strength', 0) > 0:
|
||||||
|
all_subseeds = gen_info.get('all_subseeds', [-1])
|
||||||
|
res = all_subseeds[index if 0 <= index < len(all_subseeds) else 0]
|
||||||
|
else:
|
||||||
|
all_seeds = gen_info.get('all_seeds', [-1])
|
||||||
|
res = all_seeds[index if 0 <= index < len(all_seeds) else 0]
|
||||||
|
|
||||||
|
except json.decoder.JSONDecodeError:
|
||||||
|
if gen_info_string:
|
||||||
|
errors.report(f"Error parsing JSON generation info: {gen_info_string}")
|
||||||
|
|
||||||
|
return [res, gr.update()]
|
||||||
|
|
||||||
|
reuse_seed.click(
|
||||||
|
fn=copy_seed,
|
||||||
|
_js="(x, y) => [x, selected_gallery_index()]",
|
||||||
|
show_progress=False,
|
||||||
|
inputs=[generation_info, seed],
|
||||||
|
outputs=[seed, seed]
|
||||||
|
)
|
||||||
+11
-6
@@ -48,6 +48,7 @@ def add_task_to_queue(id_job):
|
|||||||
class ProgressRequest(BaseModel):
|
class ProgressRequest(BaseModel):
|
||||||
id_task: str = Field(default=None, title="Task ID", description="id of the task to get progress for")
|
id_task: str = Field(default=None, title="Task ID", description="id of the task to get progress for")
|
||||||
id_live_preview: int = Field(default=-1, title="Live preview image ID", description="id of last received last preview image")
|
id_live_preview: int = Field(default=-1, title="Live preview image ID", description="id of last received last preview image")
|
||||||
|
live_preview: bool = Field(default=True, title="Include live preview", description="boolean flag indicating whether to include the live preview image")
|
||||||
|
|
||||||
|
|
||||||
class ProgressResponse(BaseModel):
|
class ProgressResponse(BaseModel):
|
||||||
@@ -71,7 +72,12 @@ def progressapi(req: ProgressRequest):
|
|||||||
completed = req.id_task in finished_tasks
|
completed = req.id_task in finished_tasks
|
||||||
|
|
||||||
if not active:
|
if not active:
|
||||||
return ProgressResponse(active=active, queued=queued, completed=completed, id_live_preview=-1, textinfo="In queue..." if queued else "Waiting...")
|
textinfo = "Waiting..."
|
||||||
|
if queued:
|
||||||
|
sorted_queued = sorted(pending_tasks.keys(), key=lambda x: pending_tasks[x])
|
||||||
|
queue_index = sorted_queued.index(req.id_task)
|
||||||
|
textinfo = "In queue: {}/{}".format(queue_index + 1, len(sorted_queued))
|
||||||
|
return ProgressResponse(active=active, queued=queued, completed=completed, id_live_preview=-1, textinfo=textinfo)
|
||||||
|
|
||||||
progress = 0
|
progress = 0
|
||||||
|
|
||||||
@@ -89,9 +95,12 @@ def progressapi(req: ProgressRequest):
|
|||||||
predicted_duration = elapsed_since_start / progress if progress > 0 else None
|
predicted_duration = elapsed_since_start / progress if progress > 0 else None
|
||||||
eta = predicted_duration - elapsed_since_start if predicted_duration is not None else None
|
eta = predicted_duration - elapsed_since_start if predicted_duration is not None else None
|
||||||
|
|
||||||
|
live_preview = None
|
||||||
id_live_preview = req.id_live_preview
|
id_live_preview = req.id_live_preview
|
||||||
|
|
||||||
|
if opts.live_previews_enable and req.live_preview:
|
||||||
shared.state.set_current_image()
|
shared.state.set_current_image()
|
||||||
if opts.live_previews_enable and shared.state.id_live_preview != req.id_live_preview:
|
if shared.state.id_live_preview != req.id_live_preview:
|
||||||
image = shared.state.current_image
|
image = shared.state.current_image
|
||||||
if image is not None:
|
if image is not None:
|
||||||
buffered = io.BytesIO()
|
buffered = io.BytesIO()
|
||||||
@@ -110,10 +119,6 @@ def progressapi(req: ProgressRequest):
|
|||||||
base64_image = base64.b64encode(buffered.getvalue()).decode('ascii')
|
base64_image = base64.b64encode(buffered.getvalue()).decode('ascii')
|
||||||
live_preview = f"data:image/{opts.live_previews_image_format};base64,{base64_image}"
|
live_preview = f"data:image/{opts.live_previews_image_format};base64,{base64_image}"
|
||||||
id_live_preview = shared.state.id_live_preview
|
id_live_preview = shared.state.id_live_preview
|
||||||
else:
|
|
||||||
live_preview = None
|
|
||||||
else:
|
|
||||||
live_preview = None
|
|
||||||
|
|
||||||
return ProgressResponse(active=active, queued=queued, completed=completed, progress=progress, eta=eta, live_preview=live_preview, id_live_preview=id_live_preview, textinfo=shared.state.textinfo)
|
return ProgressResponse(active=active, queued=queued, completed=completed, progress=progress, eta=eta, live_preview=live_preview, id_live_preview=id_live_preview, textinfo=shared.state.textinfo)
|
||||||
|
|
||||||
|
|||||||
+31
-10
@@ -26,7 +26,7 @@ plain: /([^\\\[\]():|]|\\.)+/
|
|||||||
%import common.SIGNED_NUMBER -> NUMBER
|
%import common.SIGNED_NUMBER -> NUMBER
|
||||||
""")
|
""")
|
||||||
|
|
||||||
def get_learned_conditioning_prompt_schedules(prompts, steps):
|
def get_learned_conditioning_prompt_schedules(prompts, base_steps, hires_steps=None, use_old_scheduling=False):
|
||||||
"""
|
"""
|
||||||
>>> g = lambda p: get_learned_conditioning_prompt_schedules([p], 10)[0]
|
>>> g = lambda p: get_learned_conditioning_prompt_schedules([p], 10)[0]
|
||||||
>>> g("test")
|
>>> g("test")
|
||||||
@@ -57,17 +57,38 @@ def get_learned_conditioning_prompt_schedules(prompts, steps):
|
|||||||
[[1, 'female'], [2, 'male'], [3, 'female'], [4, 'male'], [5, 'female'], [6, 'male'], [7, 'female'], [8, 'male'], [9, 'female'], [10, 'male']]
|
[[1, 'female'], [2, 'male'], [3, 'female'], [4, 'male'], [5, 'female'], [6, 'male'], [7, 'female'], [8, 'male'], [9, 'female'], [10, 'male']]
|
||||||
>>> g("[fe|||]male")
|
>>> g("[fe|||]male")
|
||||||
[[1, 'female'], [2, 'male'], [3, 'male'], [4, 'male'], [5, 'female'], [6, 'male'], [7, 'male'], [8, 'male'], [9, 'female'], [10, 'male']]
|
[[1, 'female'], [2, 'male'], [3, 'male'], [4, 'male'], [5, 'female'], [6, 'male'], [7, 'male'], [8, 'male'], [9, 'female'], [10, 'male']]
|
||||||
|
>>> g = lambda p: get_learned_conditioning_prompt_schedules([p], 10, 10)[0]
|
||||||
|
>>> g("a [b:.5] c")
|
||||||
|
[[10, 'a b c']]
|
||||||
|
>>> g("a [b:1.5] c")
|
||||||
|
[[5, 'a c'], [10, 'a b c']]
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
if hires_steps is None or use_old_scheduling:
|
||||||
|
int_offset = 0
|
||||||
|
flt_offset = 0
|
||||||
|
steps = base_steps
|
||||||
|
else:
|
||||||
|
int_offset = base_steps
|
||||||
|
flt_offset = 1.0
|
||||||
|
steps = hires_steps
|
||||||
|
|
||||||
def collect_steps(steps, tree):
|
def collect_steps(steps, tree):
|
||||||
res = [steps]
|
res = [steps]
|
||||||
|
|
||||||
class CollectSteps(lark.Visitor):
|
class CollectSteps(lark.Visitor):
|
||||||
def scheduled(self, tree):
|
def scheduled(self, tree):
|
||||||
tree.children[-2] = float(tree.children[-2])
|
s = tree.children[-2]
|
||||||
if tree.children[-2] < 1:
|
v = float(s)
|
||||||
tree.children[-2] *= steps
|
if use_old_scheduling:
|
||||||
tree.children[-2] = min(steps, int(tree.children[-2]))
|
v = v*steps if v<1 else v
|
||||||
|
else:
|
||||||
|
if "." in s:
|
||||||
|
v = (v - flt_offset) * steps
|
||||||
|
else:
|
||||||
|
v = (v - int_offset)
|
||||||
|
tree.children[-2] = min(steps, int(v))
|
||||||
|
if tree.children[-2] >= 1:
|
||||||
res.append(tree.children[-2])
|
res.append(tree.children[-2])
|
||||||
|
|
||||||
def alternate(self, tree):
|
def alternate(self, tree):
|
||||||
@@ -86,7 +107,7 @@ def get_learned_conditioning_prompt_schedules(prompts, steps):
|
|||||||
yield args[(step - 1) % len(args)]
|
yield args[(step - 1) % len(args)]
|
||||||
def start(self, args):
|
def start(self, args):
|
||||||
def flatten(x):
|
def flatten(x):
|
||||||
if type(x) == str:
|
if isinstance(x, str):
|
||||||
yield x
|
yield x
|
||||||
else:
|
else:
|
||||||
for gen in x:
|
for gen in x:
|
||||||
@@ -134,7 +155,7 @@ class SdConditioning(list):
|
|||||||
|
|
||||||
|
|
||||||
|
|
||||||
def get_learned_conditioning(model, prompts: SdConditioning | list[str], steps):
|
def get_learned_conditioning(model, prompts: SdConditioning | list[str], steps, hires_steps=None, use_old_scheduling=False):
|
||||||
"""converts a list of prompts into a list of prompt schedules - each schedule is a list of ScheduledPromptConditioning, specifying the comdition (cond),
|
"""converts a list of prompts into a list of prompt schedules - each schedule is a list of ScheduledPromptConditioning, specifying the comdition (cond),
|
||||||
and the sampling step at which this condition is to be replaced by the next one.
|
and the sampling step at which this condition is to be replaced by the next one.
|
||||||
|
|
||||||
@@ -154,7 +175,7 @@ def get_learned_conditioning(model, prompts: SdConditioning | list[str], steps):
|
|||||||
"""
|
"""
|
||||||
res = []
|
res = []
|
||||||
|
|
||||||
prompt_schedules = get_learned_conditioning_prompt_schedules(prompts, steps)
|
prompt_schedules = get_learned_conditioning_prompt_schedules(prompts, steps, hires_steps, use_old_scheduling)
|
||||||
cache = {}
|
cache = {}
|
||||||
|
|
||||||
for prompt, prompt_schedule in zip(prompts, prompt_schedules):
|
for prompt, prompt_schedule in zip(prompts, prompt_schedules):
|
||||||
@@ -229,7 +250,7 @@ class MulticondLearnedConditioning:
|
|||||||
self.batch: List[List[ComposableScheduledPromptConditioning]] = batch
|
self.batch: List[List[ComposableScheduledPromptConditioning]] = batch
|
||||||
|
|
||||||
|
|
||||||
def get_multicond_learned_conditioning(model, prompts, steps) -> MulticondLearnedConditioning:
|
def get_multicond_learned_conditioning(model, prompts, steps, hires_steps=None, use_old_scheduling=False) -> MulticondLearnedConditioning:
|
||||||
"""same as get_learned_conditioning, but returns a list of ScheduledPromptConditioning along with the weight objects for each prompt.
|
"""same as get_learned_conditioning, but returns a list of ScheduledPromptConditioning along with the weight objects for each prompt.
|
||||||
For each prompt, the list is obtained by splitting the prompt using the AND separator.
|
For each prompt, the list is obtained by splitting the prompt using the AND separator.
|
||||||
|
|
||||||
@@ -238,7 +259,7 @@ def get_multicond_learned_conditioning(model, prompts, steps) -> MulticondLearne
|
|||||||
|
|
||||||
res_indexes, prompt_flat_list, prompt_indexes = get_multicond_prompt_list(prompts)
|
res_indexes, prompt_flat_list, prompt_indexes = get_multicond_prompt_list(prompts)
|
||||||
|
|
||||||
learned_conditioning = get_learned_conditioning(model, prompt_flat_list, steps)
|
learned_conditioning = get_learned_conditioning(model, prompt_flat_list, steps, hires_steps, use_old_scheduling)
|
||||||
|
|
||||||
res = []
|
res = []
|
||||||
for indexes in res_indexes:
|
for indexes in res_indexes:
|
||||||
|
|||||||
@@ -55,6 +55,7 @@ class UpscalerRealESRGAN(Upscaler):
|
|||||||
half=not cmd_opts.no_half and not cmd_opts.upcast_sampling,
|
half=not cmd_opts.no_half and not cmd_opts.upcast_sampling,
|
||||||
tile=opts.ESRGAN_tile,
|
tile=opts.ESRGAN_tile,
|
||||||
tile_pad=opts.ESRGAN_tile_overlap,
|
tile_pad=opts.ESRGAN_tile_overlap,
|
||||||
|
device=self.device,
|
||||||
)
|
)
|
||||||
|
|
||||||
upsampled = upsampler.enhance(np.array(img), outscale=info.scale)[0]
|
upsampled = upsampler.enhance(np.array(img), outscale=info.scale)[0]
|
||||||
|
|||||||
+170
@@ -0,0 +1,170 @@
|
|||||||
|
import torch
|
||||||
|
|
||||||
|
from modules import devices, rng_philox, shared
|
||||||
|
|
||||||
|
|
||||||
|
def randn(seed, shape, generator=None):
|
||||||
|
"""Generate a tensor with random numbers from a normal distribution using seed.
|
||||||
|
|
||||||
|
Uses the seed parameter to set the global torch seed; to generate more with that seed, use randn_like/randn_without_seed."""
|
||||||
|
|
||||||
|
manual_seed(seed)
|
||||||
|
|
||||||
|
if shared.opts.randn_source == "NV":
|
||||||
|
return torch.asarray((generator or nv_rng).randn(shape), device=devices.device)
|
||||||
|
|
||||||
|
if shared.opts.randn_source == "CPU" or devices.device.type == 'mps':
|
||||||
|
return torch.randn(shape, device=devices.cpu, generator=generator).to(devices.device)
|
||||||
|
|
||||||
|
return torch.randn(shape, device=devices.device, generator=generator)
|
||||||
|
|
||||||
|
|
||||||
|
def randn_local(seed, shape):
|
||||||
|
"""Generate a tensor with random numbers from a normal distribution using seed.
|
||||||
|
|
||||||
|
Does not change the global random number generator. You can only generate the seed's first tensor using this function."""
|
||||||
|
|
||||||
|
if shared.opts.randn_source == "NV":
|
||||||
|
rng = rng_philox.Generator(seed)
|
||||||
|
return torch.asarray(rng.randn(shape), device=devices.device)
|
||||||
|
|
||||||
|
local_device = devices.cpu if shared.opts.randn_source == "CPU" or devices.device.type == 'mps' else devices.device
|
||||||
|
local_generator = torch.Generator(local_device).manual_seed(int(seed))
|
||||||
|
return torch.randn(shape, device=local_device, generator=local_generator).to(devices.device)
|
||||||
|
|
||||||
|
|
||||||
|
def randn_like(x):
|
||||||
|
"""Generate a tensor with random numbers from a normal distribution using the previously initialized genrator.
|
||||||
|
|
||||||
|
Use either randn() or manual_seed() to initialize the generator."""
|
||||||
|
|
||||||
|
if shared.opts.randn_source == "NV":
|
||||||
|
return torch.asarray(nv_rng.randn(x.shape), device=x.device, dtype=x.dtype)
|
||||||
|
|
||||||
|
if shared.opts.randn_source == "CPU" or x.device.type == 'mps':
|
||||||
|
return torch.randn_like(x, device=devices.cpu).to(x.device)
|
||||||
|
|
||||||
|
return torch.randn_like(x)
|
||||||
|
|
||||||
|
|
||||||
|
def randn_without_seed(shape, generator=None):
|
||||||
|
"""Generate a tensor with random numbers from a normal distribution using the previously initialized genrator.
|
||||||
|
|
||||||
|
Use either randn() or manual_seed() to initialize the generator."""
|
||||||
|
|
||||||
|
if shared.opts.randn_source == "NV":
|
||||||
|
return torch.asarray((generator or nv_rng).randn(shape), device=devices.device)
|
||||||
|
|
||||||
|
if shared.opts.randn_source == "CPU" or devices.device.type == 'mps':
|
||||||
|
return torch.randn(shape, device=devices.cpu, generator=generator).to(devices.device)
|
||||||
|
|
||||||
|
return torch.randn(shape, device=devices.device, generator=generator)
|
||||||
|
|
||||||
|
|
||||||
|
def manual_seed(seed):
|
||||||
|
"""Set up a global random number generator using the specified seed."""
|
||||||
|
|
||||||
|
if shared.opts.randn_source == "NV":
|
||||||
|
global nv_rng
|
||||||
|
nv_rng = rng_philox.Generator(seed)
|
||||||
|
return
|
||||||
|
|
||||||
|
torch.manual_seed(seed)
|
||||||
|
|
||||||
|
|
||||||
|
def create_generator(seed):
|
||||||
|
if shared.opts.randn_source == "NV":
|
||||||
|
return rng_philox.Generator(seed)
|
||||||
|
|
||||||
|
device = devices.cpu if shared.opts.randn_source == "CPU" or devices.device.type == 'mps' else devices.device
|
||||||
|
generator = torch.Generator(device).manual_seed(int(seed))
|
||||||
|
return generator
|
||||||
|
|
||||||
|
|
||||||
|
# from https://discuss.pytorch.org/t/help-regarding-slerp-function-for-generative-model-sampling/32475/3
|
||||||
|
def slerp(val, low, high):
|
||||||
|
low_norm = low/torch.norm(low, dim=1, keepdim=True)
|
||||||
|
high_norm = high/torch.norm(high, dim=1, keepdim=True)
|
||||||
|
dot = (low_norm*high_norm).sum(1)
|
||||||
|
|
||||||
|
if dot.mean() > 0.9995:
|
||||||
|
return low * val + high * (1 - val)
|
||||||
|
|
||||||
|
omega = torch.acos(dot)
|
||||||
|
so = torch.sin(omega)
|
||||||
|
res = (torch.sin((1.0-val)*omega)/so).unsqueeze(1)*low + (torch.sin(val*omega)/so).unsqueeze(1) * high
|
||||||
|
return res
|
||||||
|
|
||||||
|
|
||||||
|
class ImageRNG:
|
||||||
|
def __init__(self, shape, seeds, subseeds=None, subseed_strength=0.0, seed_resize_from_h=0, seed_resize_from_w=0):
|
||||||
|
self.shape = tuple(map(int, shape))
|
||||||
|
self.seeds = seeds
|
||||||
|
self.subseeds = subseeds
|
||||||
|
self.subseed_strength = subseed_strength
|
||||||
|
self.seed_resize_from_h = seed_resize_from_h
|
||||||
|
self.seed_resize_from_w = seed_resize_from_w
|
||||||
|
|
||||||
|
self.generators = [create_generator(seed) for seed in seeds]
|
||||||
|
|
||||||
|
self.is_first = True
|
||||||
|
|
||||||
|
def first(self):
|
||||||
|
noise_shape = self.shape if self.seed_resize_from_h <= 0 or self.seed_resize_from_w <= 0 else (self.shape[0], self.seed_resize_from_h // 8, self.seed_resize_from_w // 8)
|
||||||
|
|
||||||
|
xs = []
|
||||||
|
|
||||||
|
for i, (seed, generator) in enumerate(zip(self.seeds, self.generators)):
|
||||||
|
subnoise = None
|
||||||
|
if self.subseeds is not None and self.subseed_strength != 0:
|
||||||
|
subseed = 0 if i >= len(self.subseeds) else self.subseeds[i]
|
||||||
|
subnoise = randn(subseed, noise_shape)
|
||||||
|
|
||||||
|
if noise_shape != self.shape:
|
||||||
|
noise = randn(seed, noise_shape)
|
||||||
|
else:
|
||||||
|
noise = randn(seed, self.shape, generator=generator)
|
||||||
|
|
||||||
|
if subnoise is not None:
|
||||||
|
noise = slerp(self.subseed_strength, noise, subnoise)
|
||||||
|
|
||||||
|
if noise_shape != self.shape:
|
||||||
|
x = randn(seed, self.shape, generator=generator)
|
||||||
|
dx = (self.shape[2] - noise_shape[2]) // 2
|
||||||
|
dy = (self.shape[1] - noise_shape[1]) // 2
|
||||||
|
w = noise_shape[2] if dx >= 0 else noise_shape[2] + 2 * dx
|
||||||
|
h = noise_shape[1] if dy >= 0 else noise_shape[1] + 2 * dy
|
||||||
|
tx = 0 if dx < 0 else dx
|
||||||
|
ty = 0 if dy < 0 else dy
|
||||||
|
dx = max(-dx, 0)
|
||||||
|
dy = max(-dy, 0)
|
||||||
|
|
||||||
|
x[:, ty:ty + h, tx:tx + w] = noise[:, dy:dy + h, dx:dx + w]
|
||||||
|
noise = x
|
||||||
|
|
||||||
|
xs.append(noise)
|
||||||
|
|
||||||
|
eta_noise_seed_delta = shared.opts.eta_noise_seed_delta or 0
|
||||||
|
if eta_noise_seed_delta:
|
||||||
|
self.generators = [create_generator(seed + eta_noise_seed_delta) for seed in self.seeds]
|
||||||
|
|
||||||
|
return torch.stack(xs).to(shared.device)
|
||||||
|
|
||||||
|
def next(self):
|
||||||
|
if self.is_first:
|
||||||
|
self.is_first = False
|
||||||
|
return self.first()
|
||||||
|
|
||||||
|
xs = []
|
||||||
|
for generator in self.generators:
|
||||||
|
x = randn_without_seed(self.shape, generator=generator)
|
||||||
|
xs.append(x)
|
||||||
|
|
||||||
|
return torch.stack(xs).to(shared.device)
|
||||||
|
|
||||||
|
|
||||||
|
devices.randn = randn
|
||||||
|
devices.randn_local = randn_local
|
||||||
|
devices.randn_like = randn_like
|
||||||
|
devices.randn_without_seed = randn_without_seed
|
||||||
|
devices.manual_seed = manual_seed
|
||||||
@@ -28,6 +28,15 @@ class ImageSaveParams:
|
|||||||
"""dictionary with parameters for image's PNG info data; infotext will have the key 'parameters'"""
|
"""dictionary with parameters for image's PNG info data; infotext will have the key 'parameters'"""
|
||||||
|
|
||||||
|
|
||||||
|
class ExtraNoiseParams:
|
||||||
|
def __init__(self, noise, x):
|
||||||
|
self.noise = noise
|
||||||
|
"""Random noise generated by the seed"""
|
||||||
|
|
||||||
|
self.x = x
|
||||||
|
"""Latent image representation of the image"""
|
||||||
|
|
||||||
|
|
||||||
class CFGDenoiserParams:
|
class CFGDenoiserParams:
|
||||||
def __init__(self, x, image_cond, sigma, sampling_step, total_sampling_steps, text_cond, text_uncond):
|
def __init__(self, x, image_cond, sigma, sampling_step, total_sampling_steps, text_cond, text_uncond):
|
||||||
self.x = x
|
self.x = x
|
||||||
@@ -100,6 +109,7 @@ callback_map = dict(
|
|||||||
callbacks_ui_settings=[],
|
callbacks_ui_settings=[],
|
||||||
callbacks_before_image_saved=[],
|
callbacks_before_image_saved=[],
|
||||||
callbacks_image_saved=[],
|
callbacks_image_saved=[],
|
||||||
|
callbacks_extra_noise=[],
|
||||||
callbacks_cfg_denoiser=[],
|
callbacks_cfg_denoiser=[],
|
||||||
callbacks_cfg_denoised=[],
|
callbacks_cfg_denoised=[],
|
||||||
callbacks_cfg_after_cfg=[],
|
callbacks_cfg_after_cfg=[],
|
||||||
@@ -189,6 +199,14 @@ def image_saved_callback(params: ImageSaveParams):
|
|||||||
report_exception(c, 'image_saved_callback')
|
report_exception(c, 'image_saved_callback')
|
||||||
|
|
||||||
|
|
||||||
|
def extra_noise_callback(params: ExtraNoiseParams):
|
||||||
|
for c in callback_map['callbacks_extra_noise']:
|
||||||
|
try:
|
||||||
|
c.callback(params)
|
||||||
|
except Exception:
|
||||||
|
report_exception(c, 'callbacks_extra_noise')
|
||||||
|
|
||||||
|
|
||||||
def cfg_denoiser_callback(params: CFGDenoiserParams):
|
def cfg_denoiser_callback(params: CFGDenoiserParams):
|
||||||
for c in callback_map['callbacks_cfg_denoiser']:
|
for c in callback_map['callbacks_cfg_denoiser']:
|
||||||
try:
|
try:
|
||||||
@@ -367,6 +385,14 @@ def on_image_saved(callback):
|
|||||||
add_callback(callback_map['callbacks_image_saved'], callback)
|
add_callback(callback_map['callbacks_image_saved'], callback)
|
||||||
|
|
||||||
|
|
||||||
|
def on_extra_noise(callback):
|
||||||
|
"""register a function to be called before adding extra noise in img2img or hires fix;
|
||||||
|
The callback is called with one argument:
|
||||||
|
- params: ExtraNoiseParams - contains noise determined by seed and latent representation of image
|
||||||
|
"""
|
||||||
|
add_callback(callback_map['callbacks_extra_noise'], callback)
|
||||||
|
|
||||||
|
|
||||||
def on_cfg_denoiser(callback):
|
def on_cfg_denoiser(callback):
|
||||||
"""register a function to be called in the kdiffussion cfg_denoiser method after building the inner model inputs.
|
"""register a function to be called in the kdiffussion cfg_denoiser method after building the inner model inputs.
|
||||||
The callback is called with one argument:
|
The callback is called with one argument:
|
||||||
|
|||||||
+131
-6
@@ -3,6 +3,7 @@ import re
|
|||||||
import sys
|
import sys
|
||||||
import inspect
|
import inspect
|
||||||
from collections import namedtuple
|
from collections import namedtuple
|
||||||
|
from dataclasses import dataclass
|
||||||
|
|
||||||
import gradio as gr
|
import gradio as gr
|
||||||
|
|
||||||
@@ -21,6 +22,11 @@ class PostprocessBatchListArgs:
|
|||||||
self.images = images
|
self.images = images
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class OnComponent:
|
||||||
|
component: gr.blocks.Block
|
||||||
|
|
||||||
|
|
||||||
class Script:
|
class Script:
|
||||||
name = None
|
name = None
|
||||||
"""script's internal name derived from title"""
|
"""script's internal name derived from title"""
|
||||||
@@ -35,9 +41,13 @@ class Script:
|
|||||||
|
|
||||||
is_txt2img = False
|
is_txt2img = False
|
||||||
is_img2img = False
|
is_img2img = False
|
||||||
|
tabname = None
|
||||||
|
|
||||||
group = None
|
group = None
|
||||||
"""A gr.Group component that has all script's UI inside it"""
|
"""A gr.Group component that has all script's UI inside it."""
|
||||||
|
|
||||||
|
create_group = True
|
||||||
|
"""If False, for alwayson scripts, a group component will not be created."""
|
||||||
|
|
||||||
infotext_fields = None
|
infotext_fields = None
|
||||||
"""if set in ui(), this is a list of pairs of gradio component + text; the text will be used when
|
"""if set in ui(), this is a list of pairs of gradio component + text; the text will be used when
|
||||||
@@ -52,6 +62,15 @@ class Script:
|
|||||||
api_info = None
|
api_info = None
|
||||||
"""Generated value of type modules.api.models.ScriptInfo with information about the script for API"""
|
"""Generated value of type modules.api.models.ScriptInfo with information about the script for API"""
|
||||||
|
|
||||||
|
on_before_component_elem_id = None
|
||||||
|
"""list of callbacks to be called before a component with an elem_id is created"""
|
||||||
|
|
||||||
|
on_after_component_elem_id = None
|
||||||
|
"""list of callbacks to be called after a component with an elem_id is created"""
|
||||||
|
|
||||||
|
setup_for_ui_only = False
|
||||||
|
"""If true, the script setup will only be run in Gradio UI, not in API"""
|
||||||
|
|
||||||
def title(self):
|
def title(self):
|
||||||
"""this function should return the title of the script. This is what will be displayed in the dropdown menu."""
|
"""this function should return the title of the script. This is what will be displayed in the dropdown menu."""
|
||||||
|
|
||||||
@@ -90,9 +109,16 @@ class Script:
|
|||||||
|
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
def setup(self, p, *args):
|
||||||
|
"""For AlwaysVisible scripts, this function is called when the processing object is set up, before any processing starts.
|
||||||
|
args contains all values returned by components from ui().
|
||||||
|
"""
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
def before_process(self, p, *args):
|
def before_process(self, p, *args):
|
||||||
"""
|
"""
|
||||||
This function is called very early before processing begins for AlwaysVisible scripts.
|
This function is called very early during processing begins for AlwaysVisible scripts.
|
||||||
You can modify the processing object (p) here, inject hooks, etc.
|
You can modify the processing object (p) here, inject hooks, etc.
|
||||||
args contains all values returned by components from ui()
|
args contains all values returned by components from ui()
|
||||||
"""
|
"""
|
||||||
@@ -212,6 +238,29 @@ class Script:
|
|||||||
|
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
def on_before_component(self, callback, *, elem_id):
|
||||||
|
"""
|
||||||
|
Calls callback before a component is created. The callback function is called with a single argument of type OnComponent.
|
||||||
|
|
||||||
|
May be called in show() or ui() - but it may be too late in latter as some components may already be created.
|
||||||
|
|
||||||
|
This function is an alternative to before_component in that it also cllows to run before a component is created, but
|
||||||
|
it doesn't require to be called for every created component - just for the one you need.
|
||||||
|
"""
|
||||||
|
if self.on_before_component_elem_id is None:
|
||||||
|
self.on_before_component_elem_id = []
|
||||||
|
|
||||||
|
self.on_before_component_elem_id.append((elem_id, callback))
|
||||||
|
|
||||||
|
def on_after_component(self, callback, *, elem_id):
|
||||||
|
"""
|
||||||
|
Calls callback after a component is created. The callback function is called with a single argument of type OnComponent.
|
||||||
|
"""
|
||||||
|
if self.on_after_component_elem_id is None:
|
||||||
|
self.on_after_component_elem_id = []
|
||||||
|
|
||||||
|
self.on_after_component_elem_id.append((elem_id, callback))
|
||||||
|
|
||||||
def describe(self):
|
def describe(self):
|
||||||
"""unused"""
|
"""unused"""
|
||||||
return ""
|
return ""
|
||||||
@@ -220,7 +269,7 @@ class Script:
|
|||||||
"""helper function to generate id for a HTML element, constructs final id out of script name, tab and user-supplied item_id"""
|
"""helper function to generate id for a HTML element, constructs final id out of script name, tab and user-supplied item_id"""
|
||||||
|
|
||||||
need_tabname = self.show(True) == self.show(False)
|
need_tabname = self.show(True) == self.show(False)
|
||||||
tabkind = 'img2img' if self.is_img2img else 'txt2txt'
|
tabkind = 'img2img' if self.is_img2img else 'txt2img'
|
||||||
tabname = f"{tabkind}_" if need_tabname else ""
|
tabname = f"{tabkind}_" if need_tabname else ""
|
||||||
title = re.sub(r'[^a-z_0-9]', '', re.sub(r'\s', '_', self.title().lower()))
|
title = re.sub(r'[^a-z_0-9]', '', re.sub(r'\s', '_', self.title().lower()))
|
||||||
|
|
||||||
@@ -232,6 +281,19 @@ class Script:
|
|||||||
"""
|
"""
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
class ScriptBuiltinUI(Script):
|
||||||
|
setup_for_ui_only = True
|
||||||
|
|
||||||
|
def elem_id(self, item_id):
|
||||||
|
"""helper function to generate id for a HTML element, constructs final id out of tab and user-supplied item_id"""
|
||||||
|
|
||||||
|
need_tabname = self.show(True) == self.show(False)
|
||||||
|
tabname = ('img2img' if self.is_img2img else 'txt2img') + "_" if need_tabname else ""
|
||||||
|
|
||||||
|
return f'{tabname}{item_id}'
|
||||||
|
|
||||||
|
|
||||||
current_basedir = paths.script_path
|
current_basedir = paths.script_path
|
||||||
|
|
||||||
|
|
||||||
@@ -250,7 +312,7 @@ postprocessing_scripts_data = []
|
|||||||
ScriptClassData = namedtuple("ScriptClassData", ["script_class", "path", "basedir", "module"])
|
ScriptClassData = namedtuple("ScriptClassData", ["script_class", "path", "basedir", "module"])
|
||||||
|
|
||||||
|
|
||||||
def list_scripts(scriptdirname, extension):
|
def list_scripts(scriptdirname, extension, *, include_extensions=True):
|
||||||
scripts_list = []
|
scripts_list = []
|
||||||
|
|
||||||
basedir = os.path.join(paths.script_path, scriptdirname)
|
basedir = os.path.join(paths.script_path, scriptdirname)
|
||||||
@@ -258,6 +320,7 @@ def list_scripts(scriptdirname, extension):
|
|||||||
for filename in sorted(os.listdir(basedir)):
|
for filename in sorted(os.listdir(basedir)):
|
||||||
scripts_list.append(ScriptFile(paths.script_path, filename, os.path.join(basedir, filename)))
|
scripts_list.append(ScriptFile(paths.script_path, filename, os.path.join(basedir, filename)))
|
||||||
|
|
||||||
|
if include_extensions:
|
||||||
for ext in extensions.active():
|
for ext in extensions.active():
|
||||||
scripts_list += ext.list_files(scriptdirname, extension)
|
scripts_list += ext.list_files(scriptdirname, extension)
|
||||||
|
|
||||||
@@ -288,7 +351,7 @@ def load_scripts():
|
|||||||
postprocessing_scripts_data.clear()
|
postprocessing_scripts_data.clear()
|
||||||
script_callbacks.clear_callbacks()
|
script_callbacks.clear_callbacks()
|
||||||
|
|
||||||
scripts_list = list_scripts("scripts", ".py")
|
scripts_list = list_scripts("scripts", ".py") + list_scripts("modules/processing_scripts", ".py", include_extensions=False)
|
||||||
|
|
||||||
syspath = sys.path
|
syspath = sys.path
|
||||||
|
|
||||||
@@ -349,10 +412,17 @@ class ScriptRunner:
|
|||||||
self.selectable_scripts = []
|
self.selectable_scripts = []
|
||||||
self.alwayson_scripts = []
|
self.alwayson_scripts = []
|
||||||
self.titles = []
|
self.titles = []
|
||||||
|
self.title_map = {}
|
||||||
self.infotext_fields = []
|
self.infotext_fields = []
|
||||||
self.paste_field_names = []
|
self.paste_field_names = []
|
||||||
self.inputs = [None]
|
self.inputs = [None]
|
||||||
|
|
||||||
|
self.on_before_component_elem_id = {}
|
||||||
|
"""dict of callbacks to be called before an element is created; key=elem_id, value=list of callbacks"""
|
||||||
|
|
||||||
|
self.on_after_component_elem_id = {}
|
||||||
|
"""dict of callbacks to be called after an element is created; key=elem_id, value=list of callbacks"""
|
||||||
|
|
||||||
def initialize_scripts(self, is_img2img):
|
def initialize_scripts(self, is_img2img):
|
||||||
from modules import scripts_auto_postprocessing
|
from modules import scripts_auto_postprocessing
|
||||||
|
|
||||||
@@ -367,6 +437,7 @@ class ScriptRunner:
|
|||||||
script.filename = script_data.path
|
script.filename = script_data.path
|
||||||
script.is_txt2img = not is_img2img
|
script.is_txt2img = not is_img2img
|
||||||
script.is_img2img = is_img2img
|
script.is_img2img = is_img2img
|
||||||
|
script.tabname = "img2img" if is_img2img else "txt2img"
|
||||||
|
|
||||||
visibility = script.show(script.is_img2img)
|
visibility = script.show(script.is_img2img)
|
||||||
|
|
||||||
@@ -379,6 +450,28 @@ class ScriptRunner:
|
|||||||
self.scripts.append(script)
|
self.scripts.append(script)
|
||||||
self.selectable_scripts.append(script)
|
self.selectable_scripts.append(script)
|
||||||
|
|
||||||
|
self.apply_on_before_component_callbacks()
|
||||||
|
|
||||||
|
def apply_on_before_component_callbacks(self):
|
||||||
|
for script in self.scripts:
|
||||||
|
on_before = script.on_before_component_elem_id or []
|
||||||
|
on_after = script.on_after_component_elem_id or []
|
||||||
|
|
||||||
|
for elem_id, callback in on_before:
|
||||||
|
if elem_id not in self.on_before_component_elem_id:
|
||||||
|
self.on_before_component_elem_id[elem_id] = []
|
||||||
|
|
||||||
|
self.on_before_component_elem_id[elem_id].append((callback, script))
|
||||||
|
|
||||||
|
for elem_id, callback in on_after:
|
||||||
|
if elem_id not in self.on_after_component_elem_id:
|
||||||
|
self.on_after_component_elem_id[elem_id] = []
|
||||||
|
|
||||||
|
self.on_after_component_elem_id[elem_id].append((callback, script))
|
||||||
|
|
||||||
|
on_before.clear()
|
||||||
|
on_after.clear()
|
||||||
|
|
||||||
def create_script_ui(self, script):
|
def create_script_ui(self, script):
|
||||||
import modules.api.models as api_models
|
import modules.api.models as api_models
|
||||||
|
|
||||||
@@ -429,15 +522,20 @@ class ScriptRunner:
|
|||||||
if script.alwayson and script.section != section:
|
if script.alwayson and script.section != section:
|
||||||
continue
|
continue
|
||||||
|
|
||||||
|
if script.create_group:
|
||||||
with gr.Group(visible=script.alwayson) as group:
|
with gr.Group(visible=script.alwayson) as group:
|
||||||
self.create_script_ui(script)
|
self.create_script_ui(script)
|
||||||
|
|
||||||
script.group = group
|
script.group = group
|
||||||
|
else:
|
||||||
|
self.create_script_ui(script)
|
||||||
|
|
||||||
def prepare_ui(self):
|
def prepare_ui(self):
|
||||||
self.inputs = [None]
|
self.inputs = [None]
|
||||||
|
|
||||||
def setup_ui(self):
|
def setup_ui(self):
|
||||||
|
all_titles = [wrap_call(script.title, script.filename, "title") or script.filename for script in self.scripts]
|
||||||
|
self.title_map = {title.lower(): script for title, script in zip(all_titles, self.scripts)}
|
||||||
self.titles = [wrap_call(script.title, script.filename, "title") or f"{script.filename} [error]" for script in self.selectable_scripts]
|
self.titles = [wrap_call(script.title, script.filename, "title") or f"{script.filename} [error]" for script in self.selectable_scripts]
|
||||||
|
|
||||||
self.setup_ui_for_section(None)
|
self.setup_ui_for_section(None)
|
||||||
@@ -484,6 +582,8 @@ class ScriptRunner:
|
|||||||
self.infotext_fields.append((dropdown, lambda x: gr.update(value=x.get('Script', 'None'))))
|
self.infotext_fields.append((dropdown, lambda x: gr.update(value=x.get('Script', 'None'))))
|
||||||
self.infotext_fields.extend([(script.group, onload_script_visibility) for script in self.selectable_scripts])
|
self.infotext_fields.extend([(script.group, onload_script_visibility) for script in self.selectable_scripts])
|
||||||
|
|
||||||
|
self.apply_on_before_component_callbacks()
|
||||||
|
|
||||||
return self.inputs
|
return self.inputs
|
||||||
|
|
||||||
def run(self, p, *args):
|
def run(self, p, *args):
|
||||||
@@ -577,6 +677,12 @@ class ScriptRunner:
|
|||||||
errors.report(f"Error running postprocess_image: {script.filename}", exc_info=True)
|
errors.report(f"Error running postprocess_image: {script.filename}", exc_info=True)
|
||||||
|
|
||||||
def before_component(self, component, **kwargs):
|
def before_component(self, component, **kwargs):
|
||||||
|
for callback, script in self.on_before_component_elem_id.get(kwargs.get("elem_id"), []):
|
||||||
|
try:
|
||||||
|
callback(OnComponent(component=component))
|
||||||
|
except Exception:
|
||||||
|
errors.report(f"Error running on_before_component: {script.filename}", exc_info=True)
|
||||||
|
|
||||||
for script in self.scripts:
|
for script in self.scripts:
|
||||||
try:
|
try:
|
||||||
script.before_component(component, **kwargs)
|
script.before_component(component, **kwargs)
|
||||||
@@ -584,12 +690,21 @@ class ScriptRunner:
|
|||||||
errors.report(f"Error running before_component: {script.filename}", exc_info=True)
|
errors.report(f"Error running before_component: {script.filename}", exc_info=True)
|
||||||
|
|
||||||
def after_component(self, component, **kwargs):
|
def after_component(self, component, **kwargs):
|
||||||
|
for callback, script in self.on_after_component_elem_id.get(component.elem_id, []):
|
||||||
|
try:
|
||||||
|
callback(OnComponent(component=component))
|
||||||
|
except Exception:
|
||||||
|
errors.report(f"Error running on_after_component: {script.filename}", exc_info=True)
|
||||||
|
|
||||||
for script in self.scripts:
|
for script in self.scripts:
|
||||||
try:
|
try:
|
||||||
script.after_component(component, **kwargs)
|
script.after_component(component, **kwargs)
|
||||||
except Exception:
|
except Exception:
|
||||||
errors.report(f"Error running after_component: {script.filename}", exc_info=True)
|
errors.report(f"Error running after_component: {script.filename}", exc_info=True)
|
||||||
|
|
||||||
|
def script(self, title):
|
||||||
|
return self.title_map.get(title.lower())
|
||||||
|
|
||||||
def reload_sources(self, cache):
|
def reload_sources(self, cache):
|
||||||
for si, script in list(enumerate(self.scripts)):
|
for si, script in list(enumerate(self.scripts)):
|
||||||
args_from = script.args_from
|
args_from = script.args_from
|
||||||
@@ -608,7 +723,6 @@ class ScriptRunner:
|
|||||||
self.scripts[si].args_from = args_from
|
self.scripts[si].args_from = args_from
|
||||||
self.scripts[si].args_to = args_to
|
self.scripts[si].args_to = args_to
|
||||||
|
|
||||||
|
|
||||||
def before_hr(self, p):
|
def before_hr(self, p):
|
||||||
for script in self.alwayson_scripts:
|
for script in self.alwayson_scripts:
|
||||||
try:
|
try:
|
||||||
@@ -617,6 +731,17 @@ class ScriptRunner:
|
|||||||
except Exception:
|
except Exception:
|
||||||
errors.report(f"Error running before_hr: {script.filename}", exc_info=True)
|
errors.report(f"Error running before_hr: {script.filename}", exc_info=True)
|
||||||
|
|
||||||
|
def setup_scrips(self, p, *, is_ui=True):
|
||||||
|
for script in self.alwayson_scripts:
|
||||||
|
if not is_ui and script.setup_for_ui_only:
|
||||||
|
continue
|
||||||
|
|
||||||
|
try:
|
||||||
|
script_args = p.script_args[script.args_from:script.args_to]
|
||||||
|
script.setup(p, *script_args)
|
||||||
|
except Exception:
|
||||||
|
errors.report(f"Error running setup: {script.filename}", exc_info=True)
|
||||||
|
|
||||||
|
|
||||||
scripts_txt2img: ScriptRunner = None
|
scripts_txt2img: ScriptRunner = None
|
||||||
scripts_img2img: ScriptRunner = None
|
scripts_img2img: ScriptRunner = None
|
||||||
|
|||||||
@@ -155,10 +155,16 @@ class LoadStateDictOnMeta(ReplaceHelper):
|
|||||||
```
|
```
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self, state_dict, device):
|
def __init__(self, state_dict, device, weight_dtype_conversion=None):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.state_dict = state_dict
|
self.state_dict = state_dict
|
||||||
self.device = device
|
self.device = device
|
||||||
|
self.weight_dtype_conversion = weight_dtype_conversion or {}
|
||||||
|
self.default_dtype = self.weight_dtype_conversion.get('')
|
||||||
|
|
||||||
|
def get_weight_dtype(self, key):
|
||||||
|
key_first_term, _ = key.split('.', 1)
|
||||||
|
return self.weight_dtype_conversion.get(key_first_term, self.default_dtype)
|
||||||
|
|
||||||
def __enter__(self):
|
def __enter__(self):
|
||||||
if shared.cmd_opts.disable_model_loading_ram_optimization:
|
if shared.cmd_opts.disable_model_loading_ram_optimization:
|
||||||
@@ -167,23 +173,60 @@ class LoadStateDictOnMeta(ReplaceHelper):
|
|||||||
sd = self.state_dict
|
sd = self.state_dict
|
||||||
device = self.device
|
device = self.device
|
||||||
|
|
||||||
def load_from_state_dict(original, self, state_dict, prefix, *args, **kwargs):
|
def load_from_state_dict(original, module, state_dict, prefix, *args, **kwargs):
|
||||||
params = [(name, param) for name, param in self._parameters.items() if param is not None and param.is_meta]
|
used_param_keys = []
|
||||||
|
|
||||||
for name, param in params:
|
for name, param in module._parameters.items():
|
||||||
if param.is_meta:
|
if param is None:
|
||||||
self._parameters[name] = torch.nn.parameter.Parameter(torch.zeros_like(param, device=device), requires_grad=param.requires_grad)
|
continue
|
||||||
|
|
||||||
original(self, state_dict, prefix, *args, **kwargs)
|
|
||||||
|
|
||||||
for name, _ in params:
|
|
||||||
key = prefix + name
|
key = prefix + name
|
||||||
if key in sd:
|
sd_param = sd.pop(key, None)
|
||||||
del sd[key]
|
if sd_param is not None:
|
||||||
|
state_dict[key] = sd_param.to(dtype=self.get_weight_dtype(key))
|
||||||
|
used_param_keys.append(key)
|
||||||
|
|
||||||
|
if param.is_meta:
|
||||||
|
dtype = sd_param.dtype if sd_param is not None else param.dtype
|
||||||
|
module._parameters[name] = torch.nn.parameter.Parameter(torch.zeros_like(param, device=device, dtype=dtype), requires_grad=param.requires_grad)
|
||||||
|
|
||||||
|
for name in module._buffers:
|
||||||
|
key = prefix + name
|
||||||
|
|
||||||
|
sd_param = sd.pop(key, None)
|
||||||
|
if sd_param is not None:
|
||||||
|
state_dict[key] = sd_param
|
||||||
|
used_param_keys.append(key)
|
||||||
|
|
||||||
|
original(module, state_dict, prefix, *args, **kwargs)
|
||||||
|
|
||||||
|
for key in used_param_keys:
|
||||||
|
state_dict.pop(key, None)
|
||||||
|
|
||||||
|
def load_state_dict(original, module, state_dict, strict=True):
|
||||||
|
"""torch makes a lot of copies of the dictionary with weights, so just deleting entries from state_dict does not help
|
||||||
|
because the same values are stored in multiple copies of the dict. The trick used here is to give torch a dict with
|
||||||
|
all weights on meta device, i.e. deleted, and then it doesn't matter how many copies torch makes.
|
||||||
|
|
||||||
|
In _load_from_state_dict, the correct weight will be obtained from a single dict with the right weights (sd).
|
||||||
|
|
||||||
|
The dangerous thing about this is if _load_from_state_dict is not called, (if some exotic module overloads
|
||||||
|
the function and does not call the original) the state dict will just fail to load because weights
|
||||||
|
would be on the meta device.
|
||||||
|
"""
|
||||||
|
|
||||||
|
if state_dict == sd:
|
||||||
|
state_dict = {k: v.to(device="meta", dtype=v.dtype) for k, v in state_dict.items()}
|
||||||
|
|
||||||
|
original(module, state_dict, strict=strict)
|
||||||
|
|
||||||
|
module_load_state_dict = self.replace(torch.nn.Module, 'load_state_dict', lambda *args, **kwargs: load_state_dict(module_load_state_dict, *args, **kwargs))
|
||||||
|
module_load_from_state_dict = self.replace(torch.nn.Module, '_load_from_state_dict', lambda *args, **kwargs: load_from_state_dict(module_load_from_state_dict, *args, **kwargs))
|
||||||
linear_load_from_state_dict = self.replace(torch.nn.Linear, '_load_from_state_dict', lambda *args, **kwargs: load_from_state_dict(linear_load_from_state_dict, *args, **kwargs))
|
linear_load_from_state_dict = self.replace(torch.nn.Linear, '_load_from_state_dict', lambda *args, **kwargs: load_from_state_dict(linear_load_from_state_dict, *args, **kwargs))
|
||||||
conv2d_load_from_state_dict = self.replace(torch.nn.Conv2d, '_load_from_state_dict', lambda *args, **kwargs: load_from_state_dict(conv2d_load_from_state_dict, *args, **kwargs))
|
conv2d_load_from_state_dict = self.replace(torch.nn.Conv2d, '_load_from_state_dict', lambda *args, **kwargs: load_from_state_dict(conv2d_load_from_state_dict, *args, **kwargs))
|
||||||
mha_load_from_state_dict = self.replace(torch.nn.MultiheadAttention, '_load_from_state_dict', lambda *args, **kwargs: load_from_state_dict(mha_load_from_state_dict, *args, **kwargs))
|
mha_load_from_state_dict = self.replace(torch.nn.MultiheadAttention, '_load_from_state_dict', lambda *args, **kwargs: load_from_state_dict(mha_load_from_state_dict, *args, **kwargs))
|
||||||
|
layer_norm_load_from_state_dict = self.replace(torch.nn.LayerNorm, '_load_from_state_dict', lambda *args, **kwargs: load_from_state_dict(layer_norm_load_from_state_dict, *args, **kwargs))
|
||||||
|
group_norm_load_from_state_dict = self.replace(torch.nn.GroupNorm, '_load_from_state_dict', lambda *args, **kwargs: load_from_state_dict(group_norm_load_from_state_dict, *args, **kwargs))
|
||||||
|
|
||||||
def __exit__(self, exc_type, exc_val, exc_tb):
|
def __exit__(self, exc_type, exc_val, exc_tb):
|
||||||
self.restore()
|
self.restore()
|
||||||
|
|||||||
+16
-4
@@ -5,7 +5,7 @@ from types import MethodType
|
|||||||
from modules import devices, sd_hijack_optimizations, shared, script_callbacks, errors, sd_unet
|
from modules import devices, sd_hijack_optimizations, shared, script_callbacks, errors, sd_unet
|
||||||
from modules.hypernetworks import hypernetwork
|
from modules.hypernetworks import hypernetwork
|
||||||
from modules.shared import cmd_opts
|
from modules.shared import cmd_opts
|
||||||
from modules import sd_hijack_clip, sd_hijack_open_clip, sd_hijack_unet, sd_hijack_xlmr, xlmr, sd_hijack_inpainting
|
from modules import sd_hijack_clip, sd_hijack_open_clip, sd_hijack_unet, sd_hijack_xlmr, xlmr
|
||||||
|
|
||||||
import ldm.modules.attention
|
import ldm.modules.attention
|
||||||
import ldm.modules.diffusionmodules.model
|
import ldm.modules.diffusionmodules.model
|
||||||
@@ -34,8 +34,6 @@ ldm.modules.diffusionmodules.model.print = shared.ldm_print
|
|||||||
ldm.util.print = shared.ldm_print
|
ldm.util.print = shared.ldm_print
|
||||||
ldm.models.diffusion.ddpm.print = shared.ldm_print
|
ldm.models.diffusion.ddpm.print = shared.ldm_print
|
||||||
|
|
||||||
sd_hijack_inpainting.do_inpainting_hijack()
|
|
||||||
|
|
||||||
optimizers = []
|
optimizers = []
|
||||||
current_optimizer: sd_hijack_optimizations.SdOptimization = None
|
current_optimizer: sd_hijack_optimizations.SdOptimization = None
|
||||||
|
|
||||||
@@ -247,7 +245,21 @@ class StableDiffusionModelHijack:
|
|||||||
ldm.modules.diffusionmodules.openaimodel.UNetModel.forward = sd_unet.UNetModel_forward
|
ldm.modules.diffusionmodules.openaimodel.UNetModel.forward = sd_unet.UNetModel_forward
|
||||||
|
|
||||||
def undo_hijack(self, m):
|
def undo_hijack(self, m):
|
||||||
if type(m.cond_stage_model) == sd_hijack_xlmr.FrozenXLMREmbedderWithCustomWords:
|
conditioner = getattr(m, 'conditioner', None)
|
||||||
|
if conditioner:
|
||||||
|
for i in range(len(conditioner.embedders)):
|
||||||
|
embedder = conditioner.embedders[i]
|
||||||
|
if isinstance(embedder, (sd_hijack_open_clip.FrozenOpenCLIPEmbedderWithCustomWords, sd_hijack_open_clip.FrozenOpenCLIPEmbedder2WithCustomWords)):
|
||||||
|
embedder.wrapped.model.token_embedding = embedder.wrapped.model.token_embedding.wrapped
|
||||||
|
conditioner.embedders[i] = embedder.wrapped
|
||||||
|
if isinstance(embedder, sd_hijack_clip.FrozenCLIPEmbedderForSDXLWithCustomWords):
|
||||||
|
embedder.wrapped.transformer.text_model.embeddings.token_embedding = embedder.wrapped.transformer.text_model.embeddings.token_embedding.wrapped
|
||||||
|
conditioner.embedders[i] = embedder.wrapped
|
||||||
|
|
||||||
|
if hasattr(m, 'cond_stage_model'):
|
||||||
|
delattr(m, 'cond_stage_model')
|
||||||
|
|
||||||
|
elif type(m.cond_stage_model) == sd_hijack_xlmr.FrozenXLMREmbedderWithCustomWords:
|
||||||
m.cond_stage_model = m.cond_stage_model.wrapped
|
m.cond_stage_model = m.cond_stage_model.wrapped
|
||||||
|
|
||||||
elif type(m.cond_stage_model) == sd_hijack_clip.FrozenCLIPEmbedderWithCustomWords:
|
elif type(m.cond_stage_model) == sd_hijack_clip.FrozenCLIPEmbedderWithCustomWords:
|
||||||
|
|||||||
@@ -1,95 +0,0 @@
|
|||||||
import torch
|
|
||||||
|
|
||||||
import ldm.models.diffusion.ddpm
|
|
||||||
import ldm.models.diffusion.ddim
|
|
||||||
import ldm.models.diffusion.plms
|
|
||||||
|
|
||||||
from ldm.models.diffusion.ddim import noise_like
|
|
||||||
from ldm.models.diffusion.sampling_util import norm_thresholding
|
|
||||||
|
|
||||||
|
|
||||||
@torch.no_grad()
|
|
||||||
def p_sample_plms(self, x, c, t, index, repeat_noise=False, use_original_steps=False, quantize_denoised=False,
|
|
||||||
temperature=1., noise_dropout=0., score_corrector=None, corrector_kwargs=None,
|
|
||||||
unconditional_guidance_scale=1., unconditional_conditioning=None, old_eps=None, t_next=None, dynamic_threshold=None):
|
|
||||||
b, *_, device = *x.shape, x.device
|
|
||||||
|
|
||||||
def get_model_output(x, t):
|
|
||||||
if unconditional_conditioning is None or unconditional_guidance_scale == 1.:
|
|
||||||
e_t = self.model.apply_model(x, t, c)
|
|
||||||
else:
|
|
||||||
x_in = torch.cat([x] * 2)
|
|
||||||
t_in = torch.cat([t] * 2)
|
|
||||||
|
|
||||||
if isinstance(c, dict):
|
|
||||||
assert isinstance(unconditional_conditioning, dict)
|
|
||||||
c_in = {}
|
|
||||||
for k in c:
|
|
||||||
if isinstance(c[k], list):
|
|
||||||
c_in[k] = [
|
|
||||||
torch.cat([unconditional_conditioning[k][i], c[k][i]])
|
|
||||||
for i in range(len(c[k]))
|
|
||||||
]
|
|
||||||
else:
|
|
||||||
c_in[k] = torch.cat([unconditional_conditioning[k], c[k]])
|
|
||||||
else:
|
|
||||||
c_in = torch.cat([unconditional_conditioning, c])
|
|
||||||
|
|
||||||
e_t_uncond, e_t = self.model.apply_model(x_in, t_in, c_in).chunk(2)
|
|
||||||
e_t = e_t_uncond + unconditional_guidance_scale * (e_t - e_t_uncond)
|
|
||||||
|
|
||||||
if score_corrector is not None:
|
|
||||||
assert self.model.parameterization == "eps"
|
|
||||||
e_t = score_corrector.modify_score(self.model, e_t, x, t, c, **corrector_kwargs)
|
|
||||||
|
|
||||||
return e_t
|
|
||||||
|
|
||||||
alphas = self.model.alphas_cumprod if use_original_steps else self.ddim_alphas
|
|
||||||
alphas_prev = self.model.alphas_cumprod_prev if use_original_steps else self.ddim_alphas_prev
|
|
||||||
sqrt_one_minus_alphas = self.model.sqrt_one_minus_alphas_cumprod if use_original_steps else self.ddim_sqrt_one_minus_alphas
|
|
||||||
sigmas = self.model.ddim_sigmas_for_original_num_steps if use_original_steps else self.ddim_sigmas
|
|
||||||
|
|
||||||
def get_x_prev_and_pred_x0(e_t, index):
|
|
||||||
# select parameters corresponding to the currently considered timestep
|
|
||||||
a_t = torch.full((b, 1, 1, 1), alphas[index], device=device)
|
|
||||||
a_prev = torch.full((b, 1, 1, 1), alphas_prev[index], device=device)
|
|
||||||
sigma_t = torch.full((b, 1, 1, 1), sigmas[index], device=device)
|
|
||||||
sqrt_one_minus_at = torch.full((b, 1, 1, 1), sqrt_one_minus_alphas[index],device=device)
|
|
||||||
|
|
||||||
# current prediction for x_0
|
|
||||||
pred_x0 = (x - sqrt_one_minus_at * e_t) / a_t.sqrt()
|
|
||||||
if quantize_denoised:
|
|
||||||
pred_x0, _, *_ = self.model.first_stage_model.quantize(pred_x0)
|
|
||||||
if dynamic_threshold is not None:
|
|
||||||
pred_x0 = norm_thresholding(pred_x0, dynamic_threshold)
|
|
||||||
# direction pointing to x_t
|
|
||||||
dir_xt = (1. - a_prev - sigma_t**2).sqrt() * e_t
|
|
||||||
noise = sigma_t * noise_like(x.shape, device, repeat_noise) * temperature
|
|
||||||
if noise_dropout > 0.:
|
|
||||||
noise = torch.nn.functional.dropout(noise, p=noise_dropout)
|
|
||||||
x_prev = a_prev.sqrt() * pred_x0 + dir_xt + noise
|
|
||||||
return x_prev, pred_x0
|
|
||||||
|
|
||||||
e_t = get_model_output(x, t)
|
|
||||||
if len(old_eps) == 0:
|
|
||||||
# Pseudo Improved Euler (2nd order)
|
|
||||||
x_prev, pred_x0 = get_x_prev_and_pred_x0(e_t, index)
|
|
||||||
e_t_next = get_model_output(x_prev, t_next)
|
|
||||||
e_t_prime = (e_t + e_t_next) / 2
|
|
||||||
elif len(old_eps) == 1:
|
|
||||||
# 2nd order Pseudo Linear Multistep (Adams-Bashforth)
|
|
||||||
e_t_prime = (3 * e_t - old_eps[-1]) / 2
|
|
||||||
elif len(old_eps) == 2:
|
|
||||||
# 3nd order Pseudo Linear Multistep (Adams-Bashforth)
|
|
||||||
e_t_prime = (23 * e_t - 16 * old_eps[-1] + 5 * old_eps[-2]) / 12
|
|
||||||
elif len(old_eps) >= 3:
|
|
||||||
# 4nd order Pseudo Linear Multistep (Adams-Bashforth)
|
|
||||||
e_t_prime = (55 * e_t - 59 * old_eps[-1] + 37 * old_eps[-2] - 9 * old_eps[-3]) / 24
|
|
||||||
|
|
||||||
x_prev, pred_x0 = get_x_prev_and_pred_x0(e_t_prime, index)
|
|
||||||
|
|
||||||
return x_prev, pred_x0, e_t
|
|
||||||
|
|
||||||
|
|
||||||
def do_inpainting_hijack():
|
|
||||||
ldm.models.diffusion.plms.PLMSSampler.p_sample_plms = p_sample_plms
|
|
||||||
@@ -1,6 +1,7 @@
|
|||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
import math
|
import math
|
||||||
import psutil
|
import psutil
|
||||||
|
import platform
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
from torch import einsum
|
from torch import einsum
|
||||||
@@ -94,7 +95,10 @@ class SdOptimizationSdp(SdOptimizationSdpNoMem):
|
|||||||
class SdOptimizationSubQuad(SdOptimization):
|
class SdOptimizationSubQuad(SdOptimization):
|
||||||
name = "sub-quadratic"
|
name = "sub-quadratic"
|
||||||
cmd_opt = "opt_sub_quad_attention"
|
cmd_opt = "opt_sub_quad_attention"
|
||||||
priority = 10
|
|
||||||
|
@property
|
||||||
|
def priority(self):
|
||||||
|
return 1000 if shared.device.type == 'mps' else 10
|
||||||
|
|
||||||
def apply(self):
|
def apply(self):
|
||||||
ldm.modules.attention.CrossAttention.forward = sub_quad_attention_forward
|
ldm.modules.attention.CrossAttention.forward = sub_quad_attention_forward
|
||||||
@@ -120,7 +124,7 @@ class SdOptimizationInvokeAI(SdOptimization):
|
|||||||
|
|
||||||
@property
|
@property
|
||||||
def priority(self):
|
def priority(self):
|
||||||
return 1000 if not torch.cuda.is_available() else 10
|
return 1000 if shared.device.type != 'mps' and not torch.cuda.is_available() else 10
|
||||||
|
|
||||||
def apply(self):
|
def apply(self):
|
||||||
ldm.modules.attention.CrossAttention.forward = split_cross_attention_forward_invokeAI
|
ldm.modules.attention.CrossAttention.forward = split_cross_attention_forward_invokeAI
|
||||||
@@ -427,7 +431,10 @@ def sub_quad_attention(q, k, v, q_chunk_size=1024, kv_chunk_size=None, kv_chunk_
|
|||||||
qk_matmul_size_bytes = batch_x_heads * bytes_per_token * q_tokens * k_tokens
|
qk_matmul_size_bytes = batch_x_heads * bytes_per_token * q_tokens * k_tokens
|
||||||
|
|
||||||
if chunk_threshold is None:
|
if chunk_threshold is None:
|
||||||
chunk_threshold_bytes = int(get_available_vram() * 0.9) if q.device.type == 'mps' else int(get_available_vram() * 0.7)
|
if q.device.type == 'mps':
|
||||||
|
chunk_threshold_bytes = 268435456 * (2 if platform.processor() == 'i386' else bytes_per_token)
|
||||||
|
else:
|
||||||
|
chunk_threshold_bytes = int(get_available_vram() * 0.7)
|
||||||
elif chunk_threshold == 0:
|
elif chunk_threshold == 0:
|
||||||
chunk_threshold_bytes = None
|
chunk_threshold_bytes = None
|
||||||
else:
|
else:
|
||||||
|
|||||||
+74
-22
@@ -14,7 +14,7 @@ import ldm.modules.midas as midas
|
|||||||
|
|
||||||
from ldm.util import instantiate_from_config
|
from ldm.util import instantiate_from_config
|
||||||
|
|
||||||
from modules import paths, shared, modelloader, devices, script_callbacks, sd_vae, sd_disable_initialization, errors, hashes, sd_models_config, sd_unet, sd_models_xl, cache
|
from modules import paths, shared, modelloader, devices, script_callbacks, sd_vae, sd_disable_initialization, errors, hashes, sd_models_config, sd_unet, sd_models_xl, cache, extra_networks, processing, lowvram, sd_hijack
|
||||||
from modules.timer import Timer
|
from modules.timer import Timer
|
||||||
import tomesd
|
import tomesd
|
||||||
|
|
||||||
@@ -68,7 +68,9 @@ class CheckpointInfo:
|
|||||||
self.title = name if self.shorthash is None else f'{name} [{self.shorthash}]'
|
self.title = name if self.shorthash is None else f'{name} [{self.shorthash}]'
|
||||||
self.short_title = self.name_for_extra if self.shorthash is None else f'{self.name_for_extra} [{self.shorthash}]'
|
self.short_title = self.name_for_extra if self.shorthash is None else f'{self.name_for_extra} [{self.shorthash}]'
|
||||||
|
|
||||||
self.ids = [self.hash, self.model_name, self.title, name, self.name_for_extra, f'{name} [{self.hash}]'] + ([self.shorthash, self.sha256, f'{self.name} [{self.shorthash}]'] if self.shorthash else [])
|
self.ids = [self.hash, self.model_name, self.title, name, self.name_for_extra, f'{name} [{self.hash}]']
|
||||||
|
if self.shorthash:
|
||||||
|
self.ids += [self.shorthash, self.sha256, f'{self.name} [{self.shorthash}]', f'{self.name_for_extra} [{self.shorthash}]']
|
||||||
|
|
||||||
def register(self):
|
def register(self):
|
||||||
checkpoints_list[self.title] = self
|
checkpoints_list[self.title] = self
|
||||||
@@ -80,10 +82,14 @@ class CheckpointInfo:
|
|||||||
if self.sha256 is None:
|
if self.sha256 is None:
|
||||||
return
|
return
|
||||||
|
|
||||||
self.shorthash = self.sha256[0:10]
|
shorthash = self.sha256[0:10]
|
||||||
|
if self.shorthash == self.sha256[0:10]:
|
||||||
|
return self.shorthash
|
||||||
|
|
||||||
|
self.shorthash = shorthash
|
||||||
|
|
||||||
if self.shorthash not in self.ids:
|
if self.shorthash not in self.ids:
|
||||||
self.ids += [self.shorthash, self.sha256, f'{self.name} [{self.shorthash}]']
|
self.ids += [self.shorthash, self.sha256, f'{self.name} [{self.shorthash}]', f'{self.name_for_extra} [{self.shorthash}]']
|
||||||
|
|
||||||
checkpoints_list.pop(self.title, None)
|
checkpoints_list.pop(self.title, None)
|
||||||
self.title = f'{self.name} [{self.shorthash}]'
|
self.title = f'{self.name} [{self.shorthash}]'
|
||||||
@@ -141,6 +147,9 @@ re_strip_checksum = re.compile(r"\s*\[[^]]+]\s*$")
|
|||||||
|
|
||||||
|
|
||||||
def get_closet_checkpoint_match(search_string):
|
def get_closet_checkpoint_match(search_string):
|
||||||
|
if not search_string:
|
||||||
|
return None
|
||||||
|
|
||||||
checkpoint_info = checkpoint_aliases.get(search_string, None)
|
checkpoint_info = checkpoint_aliases.get(search_string, None)
|
||||||
if checkpoint_info is not None:
|
if checkpoint_info is not None:
|
||||||
return checkpoint_info
|
return checkpoint_info
|
||||||
@@ -289,10 +298,26 @@ def get_checkpoint_state_dict(checkpoint_info: CheckpointInfo, timer):
|
|||||||
return res
|
return res
|
||||||
|
|
||||||
|
|
||||||
|
class SkipWritingToConfig:
|
||||||
|
"""This context manager prevents load_model_weights from writing checkpoint name to the config when it loads weight."""
|
||||||
|
|
||||||
|
skip = False
|
||||||
|
previous = None
|
||||||
|
|
||||||
|
def __enter__(self):
|
||||||
|
self.previous = SkipWritingToConfig.skip
|
||||||
|
SkipWritingToConfig.skip = True
|
||||||
|
return self
|
||||||
|
|
||||||
|
def __exit__(self, exc_type, exc_value, exc_traceback):
|
||||||
|
SkipWritingToConfig.skip = self.previous
|
||||||
|
|
||||||
|
|
||||||
def load_model_weights(model, checkpoint_info: CheckpointInfo, state_dict, timer):
|
def load_model_weights(model, checkpoint_info: CheckpointInfo, state_dict, timer):
|
||||||
sd_model_hash = checkpoint_info.calculate_shorthash()
|
sd_model_hash = checkpoint_info.calculate_shorthash()
|
||||||
timer.record("calculate hash")
|
timer.record("calculate hash")
|
||||||
|
|
||||||
|
if not SkipWritingToConfig.skip:
|
||||||
shared.opts.data["sd_model_checkpoint"] = checkpoint_info.title
|
shared.opts.data["sd_model_checkpoint"] = checkpoint_info.title
|
||||||
|
|
||||||
if state_dict is None:
|
if state_dict is None:
|
||||||
@@ -318,7 +343,11 @@ def load_model_weights(model, checkpoint_info: CheckpointInfo, state_dict, timer
|
|||||||
model.to(memory_format=torch.channels_last)
|
model.to(memory_format=torch.channels_last)
|
||||||
timer.record("apply channels_last")
|
timer.record("apply channels_last")
|
||||||
|
|
||||||
if not shared.cmd_opts.no_half:
|
if shared.cmd_opts.no_half:
|
||||||
|
model.float()
|
||||||
|
devices.dtype_unet = torch.float32
|
||||||
|
timer.record("apply float()")
|
||||||
|
else:
|
||||||
vae = model.first_stage_model
|
vae = model.first_stage_model
|
||||||
depth_model = getattr(model, 'depth_model', None)
|
depth_model = getattr(model, 'depth_model', None)
|
||||||
|
|
||||||
@@ -334,9 +363,9 @@ def load_model_weights(model, checkpoint_info: CheckpointInfo, state_dict, timer
|
|||||||
if depth_model:
|
if depth_model:
|
||||||
model.depth_model = depth_model
|
model.depth_model = depth_model
|
||||||
|
|
||||||
|
devices.dtype_unet = torch.float16
|
||||||
timer.record("apply half()")
|
timer.record("apply half()")
|
||||||
|
|
||||||
devices.dtype_unet = torch.float16 if model.is_sdxl and not shared.cmd_opts.no_half else model.model.diffusion_model.dtype
|
|
||||||
devices.unet_needs_upcast = shared.cmd_opts.upcast_sampling and devices.dtype == torch.float16 and devices.dtype_unet == torch.float16
|
devices.unet_needs_upcast = shared.cmd_opts.upcast_sampling and devices.dtype == torch.float16 and devices.dtype_unet == torch.float16
|
||||||
|
|
||||||
model.first_stage_model.to(devices.dtype_vae)
|
model.first_stage_model.to(devices.dtype_vae)
|
||||||
@@ -356,7 +385,7 @@ def load_model_weights(model, checkpoint_info: CheckpointInfo, state_dict, timer
|
|||||||
|
|
||||||
sd_vae.delete_base_vae()
|
sd_vae.delete_base_vae()
|
||||||
sd_vae.clear_loaded_vae()
|
sd_vae.clear_loaded_vae()
|
||||||
vae_file, vae_source = sd_vae.resolve_vae(checkpoint_info.filename)
|
vae_file, vae_source = sd_vae.resolve_vae(checkpoint_info.filename).tuple()
|
||||||
sd_vae.load_vae(model, vae_file, vae_source)
|
sd_vae.load_vae(model, vae_file, vae_source)
|
||||||
timer.record("load VAE")
|
timer.record("load VAE")
|
||||||
|
|
||||||
@@ -457,8 +486,12 @@ class SdModelData:
|
|||||||
|
|
||||||
return self.sd_model
|
return self.sd_model
|
||||||
|
|
||||||
def set_sd_model(self, v):
|
def set_sd_model(self, v, already_loaded=False):
|
||||||
self.sd_model = v
|
self.sd_model = v
|
||||||
|
if already_loaded:
|
||||||
|
sd_vae.base_vae = getattr(v, "base_vae", None)
|
||||||
|
sd_vae.loaded_vae_file = getattr(v, "loaded_vae_file", None)
|
||||||
|
sd_vae.checkpoint_info = v.sd_checkpoint_info
|
||||||
|
|
||||||
try:
|
try:
|
||||||
self.loaded_sd_models.remove(v)
|
self.loaded_sd_models.remove(v)
|
||||||
@@ -473,7 +506,6 @@ model_data = SdModelData()
|
|||||||
|
|
||||||
|
|
||||||
def get_empty_cond(sd_model):
|
def get_empty_cond(sd_model):
|
||||||
from modules import extra_networks, processing
|
|
||||||
|
|
||||||
p = processing.StableDiffusionProcessingTxt2Img()
|
p = processing.StableDiffusionProcessingTxt2Img()
|
||||||
extra_networks.activate(p, {})
|
extra_networks.activate(p, {})
|
||||||
@@ -486,9 +518,7 @@ def get_empty_cond(sd_model):
|
|||||||
|
|
||||||
|
|
||||||
def send_model_to_cpu(m):
|
def send_model_to_cpu(m):
|
||||||
from modules import lowvram
|
if m.lowvram:
|
||||||
|
|
||||||
if shared.cmd_opts.lowvram or shared.cmd_opts.medvram:
|
|
||||||
lowvram.send_everything_to_cpu()
|
lowvram.send_everything_to_cpu()
|
||||||
else:
|
else:
|
||||||
m.to(devices.cpu)
|
m.to(devices.cpu)
|
||||||
@@ -496,12 +526,17 @@ def send_model_to_cpu(m):
|
|||||||
devices.torch_gc()
|
devices.torch_gc()
|
||||||
|
|
||||||
|
|
||||||
def send_model_to_device(m):
|
def model_target_device(m):
|
||||||
from modules import lowvram
|
if lowvram.is_needed(m):
|
||||||
|
return devices.cpu
|
||||||
if shared.cmd_opts.lowvram or shared.cmd_opts.medvram:
|
|
||||||
lowvram.setup_for_low_vram(m, shared.cmd_opts.medvram)
|
|
||||||
else:
|
else:
|
||||||
|
return devices.device
|
||||||
|
|
||||||
|
|
||||||
|
def send_model_to_device(m):
|
||||||
|
lowvram.apply(m)
|
||||||
|
|
||||||
|
if not m.lowvram:
|
||||||
m.to(shared.device)
|
m.to(shared.device)
|
||||||
|
|
||||||
|
|
||||||
@@ -559,7 +594,15 @@ def load_model(checkpoint_info=None, already_loaded_state_dict=None):
|
|||||||
|
|
||||||
timer.record("create model")
|
timer.record("create model")
|
||||||
|
|
||||||
with sd_disable_initialization.LoadStateDictOnMeta(state_dict, devices.cpu):
|
if shared.cmd_opts.no_half:
|
||||||
|
weight_dtype_conversion = None
|
||||||
|
else:
|
||||||
|
weight_dtype_conversion = {
|
||||||
|
'first_stage_model': None,
|
||||||
|
'': torch.float16,
|
||||||
|
}
|
||||||
|
|
||||||
|
with sd_disable_initialization.LoadStateDictOnMeta(state_dict, device=model_target_device(sd_model), weight_dtype_conversion=weight_dtype_conversion):
|
||||||
load_model_weights(sd_model, checkpoint_info, state_dict, timer)
|
load_model_weights(sd_model, checkpoint_info, state_dict, timer)
|
||||||
timer.record("load weights from state dict")
|
timer.record("load weights from state dict")
|
||||||
|
|
||||||
@@ -622,8 +665,14 @@ def reuse_model_from_already_loaded(sd_model, checkpoint_info, timer):
|
|||||||
send_model_to_device(already_loaded)
|
send_model_to_device(already_loaded)
|
||||||
timer.record("send model to device")
|
timer.record("send model to device")
|
||||||
|
|
||||||
model_data.set_sd_model(already_loaded)
|
model_data.set_sd_model(already_loaded, already_loaded=True)
|
||||||
|
|
||||||
|
if not SkipWritingToConfig.skip:
|
||||||
|
shared.opts.data["sd_model_checkpoint"] = already_loaded.sd_checkpoint_info.title
|
||||||
|
shared.opts.data["sd_checkpoint_hash"] = already_loaded.sd_checkpoint_info.sha256
|
||||||
|
|
||||||
print(f"Using already loaded model {already_loaded.sd_checkpoint_info.title}: done in {timer.summary()}")
|
print(f"Using already loaded model {already_loaded.sd_checkpoint_info.title}: done in {timer.summary()}")
|
||||||
|
sd_vae.reload_vae_weights(already_loaded)
|
||||||
return model_data.sd_model
|
return model_data.sd_model
|
||||||
elif shared.opts.sd_checkpoints_limit > 1 and len(model_data.loaded_sd_models) < shared.opts.sd_checkpoints_limit:
|
elif shared.opts.sd_checkpoints_limit > 1 and len(model_data.loaded_sd_models) < shared.opts.sd_checkpoints_limit:
|
||||||
print(f"Loading model {checkpoint_info.title} ({len(model_data.loaded_sd_models) + 1} out of {shared.opts.sd_checkpoints_limit})")
|
print(f"Loading model {checkpoint_info.title} ({len(model_data.loaded_sd_models) + 1} out of {shared.opts.sd_checkpoints_limit})")
|
||||||
@@ -635,6 +684,10 @@ def reuse_model_from_already_loaded(sd_model, checkpoint_info, timer):
|
|||||||
sd_model = model_data.loaded_sd_models.pop()
|
sd_model = model_data.loaded_sd_models.pop()
|
||||||
model_data.sd_model = sd_model
|
model_data.sd_model = sd_model
|
||||||
|
|
||||||
|
sd_vae.base_vae = getattr(sd_model, "base_vae", None)
|
||||||
|
sd_vae.loaded_vae_file = getattr(sd_model, "loaded_vae_file", None)
|
||||||
|
sd_vae.checkpoint_info = sd_model.sd_checkpoint_info
|
||||||
|
|
||||||
print(f"Reusing loaded model {sd_model.sd_checkpoint_info.title} to load {checkpoint_info.title}")
|
print(f"Reusing loaded model {sd_model.sd_checkpoint_info.title} to load {checkpoint_info.title}")
|
||||||
return sd_model
|
return sd_model
|
||||||
else:
|
else:
|
||||||
@@ -642,7 +695,6 @@ def reuse_model_from_already_loaded(sd_model, checkpoint_info, timer):
|
|||||||
|
|
||||||
|
|
||||||
def reload_model_weights(sd_model=None, info=None):
|
def reload_model_weights(sd_model=None, info=None):
|
||||||
from modules import devices, sd_hijack
|
|
||||||
checkpoint_info = info or select_checkpoint()
|
checkpoint_info = info or select_checkpoint()
|
||||||
|
|
||||||
timer = Timer()
|
timer = Timer()
|
||||||
@@ -692,19 +744,19 @@ def reload_model_weights(sd_model=None, info=None):
|
|||||||
script_callbacks.model_loaded_callback(sd_model)
|
script_callbacks.model_loaded_callback(sd_model)
|
||||||
timer.record("script callbacks")
|
timer.record("script callbacks")
|
||||||
|
|
||||||
if not shared.cmd_opts.lowvram and not shared.cmd_opts.medvram:
|
if not sd_model.lowvram:
|
||||||
sd_model.to(devices.device)
|
sd_model.to(devices.device)
|
||||||
timer.record("move model to device")
|
timer.record("move model to device")
|
||||||
|
|
||||||
print(f"Weights loaded in {timer.summary()}.")
|
print(f"Weights loaded in {timer.summary()}.")
|
||||||
|
|
||||||
model_data.set_sd_model(sd_model)
|
model_data.set_sd_model(sd_model)
|
||||||
|
sd_unet.apply_unet()
|
||||||
|
|
||||||
return sd_model
|
return sd_model
|
||||||
|
|
||||||
|
|
||||||
def unload_model_weights(sd_model=None, info=None):
|
def unload_model_weights(sd_model=None, info=None):
|
||||||
from modules import devices, sd_hijack
|
|
||||||
timer = Timer()
|
timer = Timer()
|
||||||
|
|
||||||
if model_data.sd_model:
|
if model_data.sd_model:
|
||||||
|
|||||||
@@ -2,7 +2,7 @@ import os
|
|||||||
|
|
||||||
import torch
|
import torch
|
||||||
|
|
||||||
from modules import shared, paths, sd_disable_initialization
|
from modules import shared, paths, sd_disable_initialization, devices
|
||||||
|
|
||||||
sd_configs_path = shared.sd_configs_path
|
sd_configs_path = shared.sd_configs_path
|
||||||
sd_repo_configs_path = os.path.join(paths.paths['Stable Diffusion'], "configs", "stable-diffusion")
|
sd_repo_configs_path = os.path.join(paths.paths['Stable Diffusion'], "configs", "stable-diffusion")
|
||||||
@@ -29,7 +29,6 @@ def is_using_v_parameterization_for_sd2(state_dict):
|
|||||||
"""
|
"""
|
||||||
|
|
||||||
import ldm.modules.diffusionmodules.openaimodel
|
import ldm.modules.diffusionmodules.openaimodel
|
||||||
from modules import devices
|
|
||||||
|
|
||||||
device = devices.cpu
|
device = devices.cpu
|
||||||
|
|
||||||
|
|||||||
@@ -0,0 +1,31 @@
|
|||||||
|
from ldm.models.diffusion.ddpm import LatentDiffusion
|
||||||
|
from typing import TYPE_CHECKING
|
||||||
|
|
||||||
|
|
||||||
|
if TYPE_CHECKING:
|
||||||
|
from modules.sd_models import CheckpointInfo
|
||||||
|
|
||||||
|
|
||||||
|
class WebuiSdModel(LatentDiffusion):
|
||||||
|
"""This class is not actually instantinated, but its fields are created and fieeld by webui"""
|
||||||
|
|
||||||
|
lowvram: bool
|
||||||
|
"""True if lowvram/medvram optimizations are enabled -- see modules.lowvram for more info"""
|
||||||
|
|
||||||
|
sd_model_hash: str
|
||||||
|
"""short hash, 10 first characters of SHA1 hash of the model file; may be None if --no-hashing flag is used"""
|
||||||
|
|
||||||
|
sd_model_checkpoint: str
|
||||||
|
"""path to the file on disk that model weights were obtained from"""
|
||||||
|
|
||||||
|
sd_checkpoint_info: 'CheckpointInfo'
|
||||||
|
"""structure with additional information about the file with model's weights"""
|
||||||
|
|
||||||
|
is_sdxl: bool
|
||||||
|
"""True if the model's architecture is SDXL"""
|
||||||
|
|
||||||
|
is_sd2: bool
|
||||||
|
"""True if the model's architecture is SD 2.x"""
|
||||||
|
|
||||||
|
is_sd1: bool
|
||||||
|
"""True if the model's architecture is SD 1.x"""
|
||||||
+11
-8
@@ -1,17 +1,18 @@
|
|||||||
from modules import sd_samplers_compvis, sd_samplers_kdiffusion, shared
|
from modules import sd_samplers_kdiffusion, sd_samplers_timesteps, shared
|
||||||
|
|
||||||
# imports for functions that previously were here and are used by other modules
|
# imports for functions that previously were here and are used by other modules
|
||||||
from modules.sd_samplers_common import samples_to_image_grid, sample_to_image # noqa: F401
|
from modules.sd_samplers_common import samples_to_image_grid, sample_to_image # noqa: F401
|
||||||
|
|
||||||
all_samplers = [
|
all_samplers = [
|
||||||
*sd_samplers_kdiffusion.samplers_data_k_diffusion,
|
*sd_samplers_kdiffusion.samplers_data_k_diffusion,
|
||||||
*sd_samplers_compvis.samplers_data_compvis,
|
*sd_samplers_timesteps.samplers_data_timesteps,
|
||||||
]
|
]
|
||||||
all_samplers_map = {x.name: x for x in all_samplers}
|
all_samplers_map = {x.name: x for x in all_samplers}
|
||||||
|
|
||||||
samplers = []
|
samplers = []
|
||||||
samplers_for_img2img = []
|
samplers_for_img2img = []
|
||||||
samplers_map = {}
|
samplers_map = {}
|
||||||
|
samplers_hidden = {}
|
||||||
|
|
||||||
|
|
||||||
def find_sampler_config(name):
|
def find_sampler_config(name):
|
||||||
@@ -38,13 +39,11 @@ def create_sampler(name, model):
|
|||||||
|
|
||||||
|
|
||||||
def set_samplers():
|
def set_samplers():
|
||||||
global samplers, samplers_for_img2img
|
global samplers, samplers_for_img2img, samplers_hidden
|
||||||
|
|
||||||
hidden = set(shared.opts.hide_samplers)
|
samplers_hidden = set(shared.opts.hide_samplers)
|
||||||
hidden_img2img = set(shared.opts.hide_samplers + ['PLMS', 'UniPC'])
|
samplers = all_samplers
|
||||||
|
samplers_for_img2img = all_samplers
|
||||||
samplers = [x for x in all_samplers if x.name not in hidden]
|
|
||||||
samplers_for_img2img = [x for x in all_samplers if x.name not in hidden_img2img]
|
|
||||||
|
|
||||||
samplers_map.clear()
|
samplers_map.clear()
|
||||||
for sampler in all_samplers:
|
for sampler in all_samplers:
|
||||||
@@ -53,4 +52,8 @@ def set_samplers():
|
|||||||
samplers_map[alias.lower()] = sampler.name
|
samplers_map[alias.lower()] = sampler.name
|
||||||
|
|
||||||
|
|
||||||
|
def visible_sampler_names():
|
||||||
|
return [x.name for x in samplers if x.name not in samplers_hidden]
|
||||||
|
|
||||||
|
|
||||||
set_samplers()
|
set_samplers()
|
||||||
|
|||||||
@@ -0,0 +1,230 @@
|
|||||||
|
import torch
|
||||||
|
from modules import prompt_parser, devices, sd_samplers_common
|
||||||
|
|
||||||
|
from modules.shared import opts, state
|
||||||
|
import modules.shared as shared
|
||||||
|
from modules.script_callbacks import CFGDenoiserParams, cfg_denoiser_callback
|
||||||
|
from modules.script_callbacks import CFGDenoisedParams, cfg_denoised_callback
|
||||||
|
from modules.script_callbacks import AfterCFGCallbackParams, cfg_after_cfg_callback
|
||||||
|
|
||||||
|
|
||||||
|
def catenate_conds(conds):
|
||||||
|
if not isinstance(conds[0], dict):
|
||||||
|
return torch.cat(conds)
|
||||||
|
|
||||||
|
return {key: torch.cat([x[key] for x in conds]) for key in conds[0].keys()}
|
||||||
|
|
||||||
|
|
||||||
|
def subscript_cond(cond, a, b):
|
||||||
|
if not isinstance(cond, dict):
|
||||||
|
return cond[a:b]
|
||||||
|
|
||||||
|
return {key: vec[a:b] for key, vec in cond.items()}
|
||||||
|
|
||||||
|
|
||||||
|
def pad_cond(tensor, repeats, empty):
|
||||||
|
if not isinstance(tensor, dict):
|
||||||
|
return torch.cat([tensor, empty.repeat((tensor.shape[0], repeats, 1))], axis=1)
|
||||||
|
|
||||||
|
tensor['crossattn'] = pad_cond(tensor['crossattn'], repeats, empty)
|
||||||
|
return tensor
|
||||||
|
|
||||||
|
|
||||||
|
class CFGDenoiser(torch.nn.Module):
|
||||||
|
"""
|
||||||
|
Classifier free guidance denoiser. A wrapper for stable diffusion model (specifically for unet)
|
||||||
|
that can take a noisy picture and produce a noise-free picture using two guidances (prompts)
|
||||||
|
instead of one. Originally, the second prompt is just an empty string, but we use non-empty
|
||||||
|
negative prompt.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, sampler):
|
||||||
|
super().__init__()
|
||||||
|
self.model_wrap = None
|
||||||
|
self.mask = None
|
||||||
|
self.nmask = None
|
||||||
|
self.init_latent = None
|
||||||
|
self.steps = None
|
||||||
|
"""number of steps as specified by user in UI"""
|
||||||
|
|
||||||
|
self.total_steps = None
|
||||||
|
"""expected number of calls to denoiser calculated from self.steps and specifics of the selected sampler"""
|
||||||
|
|
||||||
|
self.step = 0
|
||||||
|
self.image_cfg_scale = None
|
||||||
|
self.padded_cond_uncond = False
|
||||||
|
self.sampler = sampler
|
||||||
|
self.model_wrap = None
|
||||||
|
self.p = None
|
||||||
|
self.mask_before_denoising = False
|
||||||
|
|
||||||
|
@property
|
||||||
|
def inner_model(self):
|
||||||
|
raise NotImplementedError()
|
||||||
|
|
||||||
|
def combine_denoised(self, x_out, conds_list, uncond, cond_scale):
|
||||||
|
denoised_uncond = x_out[-uncond.shape[0]:]
|
||||||
|
denoised = torch.clone(denoised_uncond)
|
||||||
|
|
||||||
|
for i, conds in enumerate(conds_list):
|
||||||
|
for cond_index, weight in conds:
|
||||||
|
denoised[i] += (x_out[cond_index] - denoised_uncond[i]) * (weight * cond_scale)
|
||||||
|
|
||||||
|
return denoised
|
||||||
|
|
||||||
|
def combine_denoised_for_edit_model(self, x_out, cond_scale):
|
||||||
|
out_cond, out_img_cond, out_uncond = x_out.chunk(3)
|
||||||
|
denoised = out_uncond + cond_scale * (out_cond - out_img_cond) + self.image_cfg_scale * (out_img_cond - out_uncond)
|
||||||
|
|
||||||
|
return denoised
|
||||||
|
|
||||||
|
def get_pred_x0(self, x_in, x_out, sigma):
|
||||||
|
return x_out
|
||||||
|
|
||||||
|
def update_inner_model(self):
|
||||||
|
self.model_wrap = None
|
||||||
|
|
||||||
|
c, uc = self.p.get_conds()
|
||||||
|
self.sampler.sampler_extra_args['cond'] = c
|
||||||
|
self.sampler.sampler_extra_args['uncond'] = uc
|
||||||
|
|
||||||
|
def forward(self, x, sigma, uncond, cond, cond_scale, s_min_uncond, image_cond):
|
||||||
|
if state.interrupted or state.skipped:
|
||||||
|
raise sd_samplers_common.InterruptedException
|
||||||
|
|
||||||
|
if sd_samplers_common.apply_refiner(self):
|
||||||
|
cond = self.sampler.sampler_extra_args['cond']
|
||||||
|
uncond = self.sampler.sampler_extra_args['uncond']
|
||||||
|
|
||||||
|
# at self.image_cfg_scale == 1.0 produced results for edit model are the same as with normal sampling,
|
||||||
|
# so is_edit_model is set to False to support AND composition.
|
||||||
|
is_edit_model = shared.sd_model.cond_stage_key == "edit" and self.image_cfg_scale is not None and self.image_cfg_scale != 1.0
|
||||||
|
|
||||||
|
conds_list, tensor = prompt_parser.reconstruct_multicond_batch(cond, self.step)
|
||||||
|
uncond = prompt_parser.reconstruct_cond_batch(uncond, self.step)
|
||||||
|
|
||||||
|
assert not is_edit_model or all(len(conds) == 1 for conds in conds_list), "AND is not supported for InstructPix2Pix checkpoint (unless using Image CFG scale = 1.0)"
|
||||||
|
|
||||||
|
if self.mask_before_denoising and self.mask is not None:
|
||||||
|
x = self.init_latent * self.mask + self.nmask * x
|
||||||
|
|
||||||
|
batch_size = len(conds_list)
|
||||||
|
repeats = [len(conds_list[i]) for i in range(batch_size)]
|
||||||
|
|
||||||
|
if shared.sd_model.model.conditioning_key == "crossattn-adm":
|
||||||
|
image_uncond = torch.zeros_like(image_cond)
|
||||||
|
make_condition_dict = lambda c_crossattn, c_adm: {"c_crossattn": [c_crossattn], "c_adm": c_adm}
|
||||||
|
else:
|
||||||
|
image_uncond = image_cond
|
||||||
|
if isinstance(uncond, dict):
|
||||||
|
make_condition_dict = lambda c_crossattn, c_concat: {**c_crossattn, "c_concat": [c_concat]}
|
||||||
|
else:
|
||||||
|
make_condition_dict = lambda c_crossattn, c_concat: {"c_crossattn": [c_crossattn], "c_concat": [c_concat]}
|
||||||
|
|
||||||
|
if not is_edit_model:
|
||||||
|
x_in = torch.cat([torch.stack([x[i] for _ in range(n)]) for i, n in enumerate(repeats)] + [x])
|
||||||
|
sigma_in = torch.cat([torch.stack([sigma[i] for _ in range(n)]) for i, n in enumerate(repeats)] + [sigma])
|
||||||
|
image_cond_in = torch.cat([torch.stack([image_cond[i] for _ in range(n)]) for i, n in enumerate(repeats)] + [image_uncond])
|
||||||
|
else:
|
||||||
|
x_in = torch.cat([torch.stack([x[i] for _ in range(n)]) for i, n in enumerate(repeats)] + [x] + [x])
|
||||||
|
sigma_in = torch.cat([torch.stack([sigma[i] for _ in range(n)]) for i, n in enumerate(repeats)] + [sigma] + [sigma])
|
||||||
|
image_cond_in = torch.cat([torch.stack([image_cond[i] for _ in range(n)]) for i, n in enumerate(repeats)] + [image_uncond] + [torch.zeros_like(self.init_latent)])
|
||||||
|
|
||||||
|
denoiser_params = CFGDenoiserParams(x_in, image_cond_in, sigma_in, state.sampling_step, state.sampling_steps, tensor, uncond)
|
||||||
|
cfg_denoiser_callback(denoiser_params)
|
||||||
|
x_in = denoiser_params.x
|
||||||
|
image_cond_in = denoiser_params.image_cond
|
||||||
|
sigma_in = denoiser_params.sigma
|
||||||
|
tensor = denoiser_params.text_cond
|
||||||
|
uncond = denoiser_params.text_uncond
|
||||||
|
skip_uncond = False
|
||||||
|
|
||||||
|
# alternating uncond allows for higher thresholds without the quality loss normally expected from raising it
|
||||||
|
if self.step % 2 and s_min_uncond > 0 and sigma[0] < s_min_uncond and not is_edit_model:
|
||||||
|
skip_uncond = True
|
||||||
|
x_in = x_in[:-batch_size]
|
||||||
|
sigma_in = sigma_in[:-batch_size]
|
||||||
|
|
||||||
|
self.padded_cond_uncond = False
|
||||||
|
if shared.opts.pad_cond_uncond and tensor.shape[1] != uncond.shape[1]:
|
||||||
|
empty = shared.sd_model.cond_stage_model_empty_prompt
|
||||||
|
num_repeats = (tensor.shape[1] - uncond.shape[1]) // empty.shape[1]
|
||||||
|
|
||||||
|
if num_repeats < 0:
|
||||||
|
tensor = pad_cond(tensor, -num_repeats, empty)
|
||||||
|
self.padded_cond_uncond = True
|
||||||
|
elif num_repeats > 0:
|
||||||
|
uncond = pad_cond(uncond, num_repeats, empty)
|
||||||
|
self.padded_cond_uncond = True
|
||||||
|
|
||||||
|
if tensor.shape[1] == uncond.shape[1] or skip_uncond:
|
||||||
|
if is_edit_model:
|
||||||
|
cond_in = catenate_conds([tensor, uncond, uncond])
|
||||||
|
elif skip_uncond:
|
||||||
|
cond_in = tensor
|
||||||
|
else:
|
||||||
|
cond_in = catenate_conds([tensor, uncond])
|
||||||
|
|
||||||
|
if shared.opts.batch_cond_uncond:
|
||||||
|
x_out = self.inner_model(x_in, sigma_in, cond=make_condition_dict(cond_in, image_cond_in))
|
||||||
|
else:
|
||||||
|
x_out = torch.zeros_like(x_in)
|
||||||
|
for batch_offset in range(0, x_out.shape[0], batch_size):
|
||||||
|
a = batch_offset
|
||||||
|
b = a + batch_size
|
||||||
|
x_out[a:b] = self.inner_model(x_in[a:b], sigma_in[a:b], cond=make_condition_dict(subscript_cond(cond_in, a, b), image_cond_in[a:b]))
|
||||||
|
else:
|
||||||
|
x_out = torch.zeros_like(x_in)
|
||||||
|
batch_size = batch_size*2 if shared.opts.batch_cond_uncond else batch_size
|
||||||
|
for batch_offset in range(0, tensor.shape[0], batch_size):
|
||||||
|
a = batch_offset
|
||||||
|
b = min(a + batch_size, tensor.shape[0])
|
||||||
|
|
||||||
|
if not is_edit_model:
|
||||||
|
c_crossattn = subscript_cond(tensor, a, b)
|
||||||
|
else:
|
||||||
|
c_crossattn = torch.cat([tensor[a:b]], uncond)
|
||||||
|
|
||||||
|
x_out[a:b] = self.inner_model(x_in[a:b], sigma_in[a:b], cond=make_condition_dict(c_crossattn, image_cond_in[a:b]))
|
||||||
|
|
||||||
|
if not skip_uncond:
|
||||||
|
x_out[-uncond.shape[0]:] = self.inner_model(x_in[-uncond.shape[0]:], sigma_in[-uncond.shape[0]:], cond=make_condition_dict(uncond, image_cond_in[-uncond.shape[0]:]))
|
||||||
|
|
||||||
|
denoised_image_indexes = [x[0][0] for x in conds_list]
|
||||||
|
if skip_uncond:
|
||||||
|
fake_uncond = torch.cat([x_out[i:i+1] for i in denoised_image_indexes])
|
||||||
|
x_out = torch.cat([x_out, fake_uncond]) # we skipped uncond denoising, so we put cond-denoised image to where the uncond-denoised image should be
|
||||||
|
|
||||||
|
denoised_params = CFGDenoisedParams(x_out, state.sampling_step, state.sampling_steps, self.inner_model)
|
||||||
|
cfg_denoised_callback(denoised_params)
|
||||||
|
|
||||||
|
devices.test_for_nans(x_out, "unet")
|
||||||
|
|
||||||
|
if is_edit_model:
|
||||||
|
denoised = self.combine_denoised_for_edit_model(x_out, cond_scale)
|
||||||
|
elif skip_uncond:
|
||||||
|
denoised = self.combine_denoised(x_out, conds_list, uncond, 1.0)
|
||||||
|
else:
|
||||||
|
denoised = self.combine_denoised(x_out, conds_list, uncond, cond_scale)
|
||||||
|
|
||||||
|
if not self.mask_before_denoising and self.mask is not None:
|
||||||
|
denoised = self.init_latent * self.mask + self.nmask * denoised
|
||||||
|
|
||||||
|
self.sampler.last_latent = self.get_pred_x0(torch.cat([x_in[i:i + 1] for i in denoised_image_indexes]), torch.cat([x_out[i:i + 1] for i in denoised_image_indexes]), sigma)
|
||||||
|
|
||||||
|
if opts.live_preview_content == "Prompt":
|
||||||
|
preview = self.sampler.last_latent
|
||||||
|
elif opts.live_preview_content == "Negative prompt":
|
||||||
|
preview = self.get_pred_x0(x_in[-uncond.shape[0]:], x_out[-uncond.shape[0]:], sigma)
|
||||||
|
else:
|
||||||
|
preview = self.get_pred_x0(torch.cat([x_in[i:i+1] for i in denoised_image_indexes]), torch.cat([denoised[i:i+1] for i in denoised_image_indexes]), sigma)
|
||||||
|
|
||||||
|
sd_samplers_common.store_latent(preview)
|
||||||
|
|
||||||
|
after_cfg_callback_params = AfterCFGCallbackParams(denoised, state.sampling_step, state.sampling_steps)
|
||||||
|
cfg_after_cfg_callback(after_cfg_callback_params)
|
||||||
|
denoised = after_cfg_callback_params.x
|
||||||
|
|
||||||
|
self.step += 1
|
||||||
|
return denoised
|
||||||
|
|
||||||
@@ -1,11 +1,22 @@
|
|||||||
|
import inspect
|
||||||
from collections import namedtuple
|
from collections import namedtuple
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import torch
|
import torch
|
||||||
from PIL import Image
|
from PIL import Image
|
||||||
from modules import devices, images, sd_vae_approx, sd_samplers, sd_vae_taesd, shared
|
from modules import devices, images, sd_vae_approx, sd_samplers, sd_vae_taesd, shared, sd_models
|
||||||
from modules.shared import opts, state
|
from modules.shared import opts, state
|
||||||
|
import k_diffusion.sampling
|
||||||
|
|
||||||
SamplerData = namedtuple('SamplerData', ['name', 'constructor', 'aliases', 'options'])
|
|
||||||
|
SamplerDataTuple = namedtuple('SamplerData', ['name', 'constructor', 'aliases', 'options'])
|
||||||
|
|
||||||
|
|
||||||
|
class SamplerData(SamplerDataTuple):
|
||||||
|
def total_steps(self, steps):
|
||||||
|
if self.options.get("second_order", False):
|
||||||
|
steps = steps * 2
|
||||||
|
|
||||||
|
return steps
|
||||||
|
|
||||||
|
|
||||||
def setup_img2img_steps(p, steps=None):
|
def setup_img2img_steps(p, steps=None):
|
||||||
@@ -24,21 +35,26 @@ approximation_indexes = {"Full": 0, "Approx NN": 1, "Approx cheap": 2, "TAESD":
|
|||||||
|
|
||||||
|
|
||||||
def samples_to_images_tensor(sample, approximation=None, model=None):
|
def samples_to_images_tensor(sample, approximation=None, model=None):
|
||||||
'''latents -> images [-1, 1]'''
|
"""Transforms 4-channel latent space images into 3-channel RGB image tensors, with values in range [-1, 1]."""
|
||||||
if approximation is None:
|
|
||||||
|
if approximation is None or (shared.state.interrupted and opts.live_preview_fast_interrupt):
|
||||||
approximation = approximation_indexes.get(opts.show_progress_type, 0)
|
approximation = approximation_indexes.get(opts.show_progress_type, 0)
|
||||||
|
|
||||||
|
from modules import lowvram
|
||||||
|
if approximation == 0 and lowvram.is_enabled(shared.sd_model) and not shared.opts.live_preview_allow_lowvram_full:
|
||||||
|
approximation = 1
|
||||||
|
|
||||||
if approximation == 2:
|
if approximation == 2:
|
||||||
x_sample = sd_vae_approx.cheap_approximation(sample)
|
x_sample = sd_vae_approx.cheap_approximation(sample)
|
||||||
elif approximation == 1:
|
elif approximation == 1:
|
||||||
x_sample = sd_vae_approx.model()(sample.to(devices.device, devices.dtype)).detach()
|
x_sample = sd_vae_approx.model()(sample.to(devices.device, devices.dtype)).detach()
|
||||||
elif approximation == 3:
|
elif approximation == 3:
|
||||||
x_sample = sample * 1.5
|
x_sample = sd_vae_taesd.decoder_model()(sample.to(devices.device, devices.dtype)).detach()
|
||||||
x_sample = sd_vae_taesd.decoder_model()(x_sample.to(devices.device, devices.dtype)).detach()
|
|
||||||
x_sample = x_sample * 2 - 1
|
x_sample = x_sample * 2 - 1
|
||||||
else:
|
else:
|
||||||
if model is None:
|
if model is None:
|
||||||
model = shared.sd_model
|
model = shared.sd_model
|
||||||
|
with devices.without_autocast(): # fixes an issue with unstable VAEs that are flaky even in fp32
|
||||||
x_sample = model.decode_first_stage(sample.to(model.first_stage_model.dtype))
|
x_sample = model.decode_first_stage(sample.to(model.first_stage_model.dtype))
|
||||||
|
|
||||||
return x_sample
|
return x_sample
|
||||||
@@ -81,6 +97,14 @@ def images_tensor_to_samples(image, approximation=None, model=None):
|
|||||||
model = shared.sd_model
|
model = shared.sd_model
|
||||||
image = image.to(shared.device, dtype=devices.dtype_vae)
|
image = image.to(shared.device, dtype=devices.dtype_vae)
|
||||||
image = image * 2 - 1
|
image = image * 2 - 1
|
||||||
|
if len(image) > 1:
|
||||||
|
x_latent = torch.stack([
|
||||||
|
model.get_first_stage_encoding(
|
||||||
|
model.encode_first_stage(torch.unsqueeze(img, 0))
|
||||||
|
)[0]
|
||||||
|
for img in image
|
||||||
|
])
|
||||||
|
else:
|
||||||
x_latent = model.get_first_stage_encoding(model.encode_first_stage(image))
|
x_latent = model.get_first_stage_encoding(model.encode_first_stage(image))
|
||||||
|
|
||||||
return x_latent
|
return x_latent
|
||||||
@@ -127,3 +151,176 @@ def replace_torchsde_browinan():
|
|||||||
|
|
||||||
|
|
||||||
replace_torchsde_browinan()
|
replace_torchsde_browinan()
|
||||||
|
|
||||||
|
|
||||||
|
def apply_refiner(cfg_denoiser):
|
||||||
|
completed_ratio = cfg_denoiser.step / cfg_denoiser.total_steps
|
||||||
|
refiner_switch_at = cfg_denoiser.p.refiner_switch_at
|
||||||
|
refiner_checkpoint_info = cfg_denoiser.p.refiner_checkpoint_info
|
||||||
|
|
||||||
|
if refiner_switch_at is not None and completed_ratio < refiner_switch_at:
|
||||||
|
return False
|
||||||
|
|
||||||
|
if refiner_checkpoint_info is None or shared.sd_model.sd_checkpoint_info == refiner_checkpoint_info:
|
||||||
|
return False
|
||||||
|
|
||||||
|
if getattr(cfg_denoiser.p, "enable_hr", False) and not cfg_denoiser.p.is_hr_pass:
|
||||||
|
return False
|
||||||
|
|
||||||
|
cfg_denoiser.p.extra_generation_params['Refiner'] = refiner_checkpoint_info.short_title
|
||||||
|
cfg_denoiser.p.extra_generation_params['Refiner switch at'] = refiner_switch_at
|
||||||
|
|
||||||
|
with sd_models.SkipWritingToConfig():
|
||||||
|
sd_models.reload_model_weights(info=refiner_checkpoint_info)
|
||||||
|
|
||||||
|
devices.torch_gc()
|
||||||
|
cfg_denoiser.p.setup_conds()
|
||||||
|
cfg_denoiser.update_inner_model()
|
||||||
|
|
||||||
|
return True
|
||||||
|
|
||||||
|
|
||||||
|
class TorchHijack:
|
||||||
|
"""This is here to replace torch.randn_like of k-diffusion.
|
||||||
|
|
||||||
|
k-diffusion has random_sampler argument for most samplers, but not for all, so
|
||||||
|
this is needed to properly replace every use of torch.randn_like.
|
||||||
|
|
||||||
|
We need to replace to make images generated in batches to be same as images generated individually."""
|
||||||
|
|
||||||
|
def __init__(self, p):
|
||||||
|
self.rng = p.rng
|
||||||
|
|
||||||
|
def __getattr__(self, item):
|
||||||
|
if item == 'randn_like':
|
||||||
|
return self.randn_like
|
||||||
|
|
||||||
|
if hasattr(torch, item):
|
||||||
|
return getattr(torch, item)
|
||||||
|
|
||||||
|
raise AttributeError(f"'{type(self).__name__}' object has no attribute '{item}'")
|
||||||
|
|
||||||
|
def randn_like(self, x):
|
||||||
|
return self.rng.next()
|
||||||
|
|
||||||
|
|
||||||
|
class Sampler:
|
||||||
|
def __init__(self, funcname):
|
||||||
|
self.funcname = funcname
|
||||||
|
self.func = funcname
|
||||||
|
self.extra_params = []
|
||||||
|
self.sampler_noises = None
|
||||||
|
self.stop_at = None
|
||||||
|
self.eta = None
|
||||||
|
self.config: SamplerData = None # set by the function calling the constructor
|
||||||
|
self.last_latent = None
|
||||||
|
self.s_min_uncond = None
|
||||||
|
self.s_churn = 0.0
|
||||||
|
self.s_tmin = 0.0
|
||||||
|
self.s_tmax = float('inf')
|
||||||
|
self.s_noise = 1.0
|
||||||
|
|
||||||
|
self.eta_option_field = 'eta_ancestral'
|
||||||
|
self.eta_infotext_field = 'Eta'
|
||||||
|
self.eta_default = 1.0
|
||||||
|
|
||||||
|
self.conditioning_key = shared.sd_model.model.conditioning_key
|
||||||
|
|
||||||
|
self.p = None
|
||||||
|
self.model_wrap_cfg = None
|
||||||
|
self.sampler_extra_args = None
|
||||||
|
self.options = {}
|
||||||
|
|
||||||
|
def callback_state(self, d):
|
||||||
|
step = d['i']
|
||||||
|
|
||||||
|
if self.stop_at is not None and step > self.stop_at:
|
||||||
|
raise InterruptedException
|
||||||
|
|
||||||
|
state.sampling_step = step
|
||||||
|
shared.total_tqdm.update()
|
||||||
|
|
||||||
|
def launch_sampling(self, steps, func):
|
||||||
|
self.model_wrap_cfg.steps = steps
|
||||||
|
self.model_wrap_cfg.total_steps = self.config.total_steps(steps)
|
||||||
|
state.sampling_steps = steps
|
||||||
|
state.sampling_step = 0
|
||||||
|
|
||||||
|
try:
|
||||||
|
return func()
|
||||||
|
except RecursionError:
|
||||||
|
print(
|
||||||
|
'Encountered RecursionError during sampling, returning last latent. '
|
||||||
|
'rho >5 with a polyexponential scheduler may cause this error. '
|
||||||
|
'You should try to use a smaller rho value instead.'
|
||||||
|
)
|
||||||
|
return self.last_latent
|
||||||
|
except InterruptedException:
|
||||||
|
return self.last_latent
|
||||||
|
|
||||||
|
def number_of_needed_noises(self, p):
|
||||||
|
return p.steps
|
||||||
|
|
||||||
|
def initialize(self, p) -> dict:
|
||||||
|
self.p = p
|
||||||
|
self.model_wrap_cfg.p = p
|
||||||
|
self.model_wrap_cfg.mask = p.mask if hasattr(p, 'mask') else None
|
||||||
|
self.model_wrap_cfg.nmask = p.nmask if hasattr(p, 'nmask') else None
|
||||||
|
self.model_wrap_cfg.step = 0
|
||||||
|
self.model_wrap_cfg.image_cfg_scale = getattr(p, 'image_cfg_scale', None)
|
||||||
|
self.eta = p.eta if p.eta is not None else getattr(opts, self.eta_option_field, 0.0)
|
||||||
|
self.s_min_uncond = getattr(p, 's_min_uncond', 0.0)
|
||||||
|
|
||||||
|
k_diffusion.sampling.torch = TorchHijack(p)
|
||||||
|
|
||||||
|
extra_params_kwargs = {}
|
||||||
|
for param_name in self.extra_params:
|
||||||
|
if hasattr(p, param_name) and param_name in inspect.signature(self.func).parameters:
|
||||||
|
extra_params_kwargs[param_name] = getattr(p, param_name)
|
||||||
|
|
||||||
|
if 'eta' in inspect.signature(self.func).parameters:
|
||||||
|
if self.eta != self.eta_default:
|
||||||
|
p.extra_generation_params[self.eta_infotext_field] = self.eta
|
||||||
|
|
||||||
|
extra_params_kwargs['eta'] = self.eta
|
||||||
|
|
||||||
|
if len(self.extra_params) > 0:
|
||||||
|
s_churn = getattr(opts, 's_churn', p.s_churn)
|
||||||
|
s_tmin = getattr(opts, 's_tmin', p.s_tmin)
|
||||||
|
s_tmax = getattr(opts, 's_tmax', p.s_tmax) or self.s_tmax # 0 = inf
|
||||||
|
s_noise = getattr(opts, 's_noise', p.s_noise)
|
||||||
|
|
||||||
|
if 's_churn' in extra_params_kwargs and s_churn != self.s_churn:
|
||||||
|
extra_params_kwargs['s_churn'] = s_churn
|
||||||
|
p.s_churn = s_churn
|
||||||
|
p.extra_generation_params['Sigma churn'] = s_churn
|
||||||
|
if 's_tmin' in extra_params_kwargs and s_tmin != self.s_tmin:
|
||||||
|
extra_params_kwargs['s_tmin'] = s_tmin
|
||||||
|
p.s_tmin = s_tmin
|
||||||
|
p.extra_generation_params['Sigma tmin'] = s_tmin
|
||||||
|
if 's_tmax' in extra_params_kwargs and s_tmax != self.s_tmax:
|
||||||
|
extra_params_kwargs['s_tmax'] = s_tmax
|
||||||
|
p.s_tmax = s_tmax
|
||||||
|
p.extra_generation_params['Sigma tmax'] = s_tmax
|
||||||
|
if 's_noise' in extra_params_kwargs and s_noise != self.s_noise:
|
||||||
|
extra_params_kwargs['s_noise'] = s_noise
|
||||||
|
p.s_noise = s_noise
|
||||||
|
p.extra_generation_params['Sigma noise'] = s_noise
|
||||||
|
|
||||||
|
return extra_params_kwargs
|
||||||
|
|
||||||
|
def create_noise_sampler(self, x, sigmas, p):
|
||||||
|
"""For DPM++ SDE: manually create noise sampler to enable deterministic results across different batch sizes"""
|
||||||
|
if shared.opts.no_dpmpp_sde_batch_determinism:
|
||||||
|
return None
|
||||||
|
|
||||||
|
from k_diffusion.sampling import BrownianTreeNoiseSampler
|
||||||
|
sigma_min, sigma_max = sigmas[sigmas > 0].min(), sigmas.max()
|
||||||
|
current_iter_seeds = p.all_seeds[p.iteration * p.batch_size:(p.iteration + 1) * p.batch_size]
|
||||||
|
return BrownianTreeNoiseSampler(x, sigma_min, sigma_max, seed=current_iter_seeds)
|
||||||
|
|
||||||
|
def sample(self, p, x, conditioning, unconditional_conditioning, steps=None, image_conditioning=None):
|
||||||
|
raise NotImplementedError()
|
||||||
|
|
||||||
|
def sample_img2img(self, p, x, noise, conditioning, unconditional_conditioning, steps=None, image_conditioning=None):
|
||||||
|
raise NotImplementedError()
|
||||||
|
|||||||
@@ -1,224 +0,0 @@
|
|||||||
import math
|
|
||||||
import ldm.models.diffusion.ddim
|
|
||||||
import ldm.models.diffusion.plms
|
|
||||||
|
|
||||||
import numpy as np
|
|
||||||
import torch
|
|
||||||
|
|
||||||
from modules.shared import state
|
|
||||||
from modules import sd_samplers_common, prompt_parser, shared
|
|
||||||
import modules.models.diffusion.uni_pc
|
|
||||||
|
|
||||||
|
|
||||||
samplers_data_compvis = [
|
|
||||||
sd_samplers_common.SamplerData('DDIM', lambda model: VanillaStableDiffusionSampler(ldm.models.diffusion.ddim.DDIMSampler, model), [], {"default_eta_is_0": True, "uses_ensd": True, "no_sdxl": True}),
|
|
||||||
sd_samplers_common.SamplerData('PLMS', lambda model: VanillaStableDiffusionSampler(ldm.models.diffusion.plms.PLMSSampler, model), [], {"no_sdxl": True}),
|
|
||||||
sd_samplers_common.SamplerData('UniPC', lambda model: VanillaStableDiffusionSampler(modules.models.diffusion.uni_pc.UniPCSampler, model), [], {"no_sdxl": True}),
|
|
||||||
]
|
|
||||||
|
|
||||||
|
|
||||||
class VanillaStableDiffusionSampler:
|
|
||||||
def __init__(self, constructor, sd_model):
|
|
||||||
self.sampler = constructor(sd_model)
|
|
||||||
self.is_ddim = hasattr(self.sampler, 'p_sample_ddim')
|
|
||||||
self.is_plms = hasattr(self.sampler, 'p_sample_plms')
|
|
||||||
self.is_unipc = isinstance(self.sampler, modules.models.diffusion.uni_pc.UniPCSampler)
|
|
||||||
self.orig_p_sample_ddim = None
|
|
||||||
if self.is_plms:
|
|
||||||
self.orig_p_sample_ddim = self.sampler.p_sample_plms
|
|
||||||
elif self.is_ddim:
|
|
||||||
self.orig_p_sample_ddim = self.sampler.p_sample_ddim
|
|
||||||
self.mask = None
|
|
||||||
self.nmask = None
|
|
||||||
self.init_latent = None
|
|
||||||
self.sampler_noises = None
|
|
||||||
self.step = 0
|
|
||||||
self.stop_at = None
|
|
||||||
self.eta = None
|
|
||||||
self.config = None
|
|
||||||
self.last_latent = None
|
|
||||||
|
|
||||||
self.conditioning_key = sd_model.model.conditioning_key
|
|
||||||
|
|
||||||
def number_of_needed_noises(self, p):
|
|
||||||
return 0
|
|
||||||
|
|
||||||
def launch_sampling(self, steps, func):
|
|
||||||
state.sampling_steps = steps
|
|
||||||
state.sampling_step = 0
|
|
||||||
|
|
||||||
try:
|
|
||||||
return func()
|
|
||||||
except sd_samplers_common.InterruptedException:
|
|
||||||
return self.last_latent
|
|
||||||
|
|
||||||
def p_sample_ddim_hook(self, x_dec, cond, ts, unconditional_conditioning, *args, **kwargs):
|
|
||||||
x_dec, ts, cond, unconditional_conditioning = self.before_sample(x_dec, ts, cond, unconditional_conditioning)
|
|
||||||
|
|
||||||
res = self.orig_p_sample_ddim(x_dec, cond, ts, *args, unconditional_conditioning=unconditional_conditioning, **kwargs)
|
|
||||||
|
|
||||||
x_dec, ts, cond, unconditional_conditioning, res = self.after_sample(x_dec, ts, cond, unconditional_conditioning, res)
|
|
||||||
|
|
||||||
return res
|
|
||||||
|
|
||||||
def before_sample(self, x, ts, cond, unconditional_conditioning):
|
|
||||||
if state.interrupted or state.skipped:
|
|
||||||
raise sd_samplers_common.InterruptedException
|
|
||||||
|
|
||||||
if self.stop_at is not None and self.step > self.stop_at:
|
|
||||||
raise sd_samplers_common.InterruptedException
|
|
||||||
|
|
||||||
# Have to unwrap the inpainting conditioning here to perform pre-processing
|
|
||||||
image_conditioning = None
|
|
||||||
uc_image_conditioning = None
|
|
||||||
if isinstance(cond, dict):
|
|
||||||
if self.conditioning_key == "crossattn-adm":
|
|
||||||
image_conditioning = cond["c_adm"]
|
|
||||||
uc_image_conditioning = unconditional_conditioning["c_adm"]
|
|
||||||
else:
|
|
||||||
image_conditioning = cond["c_concat"][0]
|
|
||||||
cond = cond["c_crossattn"][0]
|
|
||||||
unconditional_conditioning = unconditional_conditioning["c_crossattn"][0]
|
|
||||||
|
|
||||||
conds_list, tensor = prompt_parser.reconstruct_multicond_batch(cond, self.step)
|
|
||||||
unconditional_conditioning = prompt_parser.reconstruct_cond_batch(unconditional_conditioning, self.step)
|
|
||||||
|
|
||||||
assert all(len(conds) == 1 for conds in conds_list), 'composition via AND is not supported for DDIM/PLMS samplers'
|
|
||||||
cond = tensor
|
|
||||||
|
|
||||||
# for DDIM, shapes must match, we can't just process cond and uncond independently;
|
|
||||||
# filling unconditional_conditioning with repeats of the last vector to match length is
|
|
||||||
# not 100% correct but should work well enough
|
|
||||||
if unconditional_conditioning.shape[1] < cond.shape[1]:
|
|
||||||
last_vector = unconditional_conditioning[:, -1:]
|
|
||||||
last_vector_repeated = last_vector.repeat([1, cond.shape[1] - unconditional_conditioning.shape[1], 1])
|
|
||||||
unconditional_conditioning = torch.hstack([unconditional_conditioning, last_vector_repeated])
|
|
||||||
elif unconditional_conditioning.shape[1] > cond.shape[1]:
|
|
||||||
unconditional_conditioning = unconditional_conditioning[:, :cond.shape[1]]
|
|
||||||
|
|
||||||
if self.mask is not None:
|
|
||||||
img_orig = self.sampler.model.q_sample(self.init_latent, ts)
|
|
||||||
x = img_orig * self.mask + self.nmask * x
|
|
||||||
|
|
||||||
# Wrap the image conditioning back up since the DDIM code can accept the dict directly.
|
|
||||||
# Note that they need to be lists because it just concatenates them later.
|
|
||||||
if image_conditioning is not None:
|
|
||||||
if self.conditioning_key == "crossattn-adm":
|
|
||||||
cond = {"c_adm": image_conditioning, "c_crossattn": [cond]}
|
|
||||||
unconditional_conditioning = {"c_adm": uc_image_conditioning, "c_crossattn": [unconditional_conditioning]}
|
|
||||||
else:
|
|
||||||
cond = {"c_concat": [image_conditioning], "c_crossattn": [cond]}
|
|
||||||
unconditional_conditioning = {"c_concat": [image_conditioning], "c_crossattn": [unconditional_conditioning]}
|
|
||||||
|
|
||||||
return x, ts, cond, unconditional_conditioning
|
|
||||||
|
|
||||||
def update_step(self, last_latent):
|
|
||||||
if self.mask is not None:
|
|
||||||
self.last_latent = self.init_latent * self.mask + self.nmask * last_latent
|
|
||||||
else:
|
|
||||||
self.last_latent = last_latent
|
|
||||||
|
|
||||||
sd_samplers_common.store_latent(self.last_latent)
|
|
||||||
|
|
||||||
self.step += 1
|
|
||||||
state.sampling_step = self.step
|
|
||||||
shared.total_tqdm.update()
|
|
||||||
|
|
||||||
def after_sample(self, x, ts, cond, uncond, res):
|
|
||||||
if not self.is_unipc:
|
|
||||||
self.update_step(res[1])
|
|
||||||
|
|
||||||
return x, ts, cond, uncond, res
|
|
||||||
|
|
||||||
def unipc_after_update(self, x, model_x):
|
|
||||||
self.update_step(x)
|
|
||||||
|
|
||||||
def initialize(self, p):
|
|
||||||
if self.is_ddim:
|
|
||||||
self.eta = p.eta if p.eta is not None else shared.opts.eta_ddim
|
|
||||||
else:
|
|
||||||
self.eta = 0.0
|
|
||||||
|
|
||||||
if self.eta != 0.0:
|
|
||||||
p.extra_generation_params["Eta DDIM"] = self.eta
|
|
||||||
|
|
||||||
if self.is_unipc:
|
|
||||||
keys = [
|
|
||||||
('UniPC variant', 'uni_pc_variant'),
|
|
||||||
('UniPC skip type', 'uni_pc_skip_type'),
|
|
||||||
('UniPC order', 'uni_pc_order'),
|
|
||||||
('UniPC lower order final', 'uni_pc_lower_order_final'),
|
|
||||||
]
|
|
||||||
|
|
||||||
for name, key in keys:
|
|
||||||
v = getattr(shared.opts, key)
|
|
||||||
if v != shared.opts.get_default(key):
|
|
||||||
p.extra_generation_params[name] = v
|
|
||||||
|
|
||||||
for fieldname in ['p_sample_ddim', 'p_sample_plms']:
|
|
||||||
if hasattr(self.sampler, fieldname):
|
|
||||||
setattr(self.sampler, fieldname, self.p_sample_ddim_hook)
|
|
||||||
if self.is_unipc:
|
|
||||||
self.sampler.set_hooks(lambda x, t, c, u: self.before_sample(x, t, c, u), lambda x, t, c, u, r: self.after_sample(x, t, c, u, r), lambda x, mx: self.unipc_after_update(x, mx))
|
|
||||||
|
|
||||||
self.mask = p.mask if hasattr(p, 'mask') else None
|
|
||||||
self.nmask = p.nmask if hasattr(p, 'nmask') else None
|
|
||||||
|
|
||||||
|
|
||||||
def adjust_steps_if_invalid(self, p, num_steps):
|
|
||||||
if ((self.config.name == 'DDIM') and p.ddim_discretize == 'uniform') or (self.config.name == 'PLMS') or (self.config.name == 'UniPC'):
|
|
||||||
if self.config.name == 'UniPC' and num_steps < shared.opts.uni_pc_order:
|
|
||||||
num_steps = shared.opts.uni_pc_order
|
|
||||||
valid_step = 999 / (1000 // num_steps)
|
|
||||||
if valid_step == math.floor(valid_step):
|
|
||||||
return int(valid_step) + 1
|
|
||||||
|
|
||||||
return num_steps
|
|
||||||
|
|
||||||
def sample_img2img(self, p, x, noise, conditioning, unconditional_conditioning, steps=None, image_conditioning=None):
|
|
||||||
steps, t_enc = sd_samplers_common.setup_img2img_steps(p, steps)
|
|
||||||
steps = self.adjust_steps_if_invalid(p, steps)
|
|
||||||
self.initialize(p)
|
|
||||||
|
|
||||||
self.sampler.make_schedule(ddim_num_steps=steps, ddim_eta=self.eta, ddim_discretize=p.ddim_discretize, verbose=False)
|
|
||||||
x1 = self.sampler.stochastic_encode(x, torch.tensor([t_enc] * int(x.shape[0])).to(shared.device), noise=noise)
|
|
||||||
|
|
||||||
self.init_latent = x
|
|
||||||
self.last_latent = x
|
|
||||||
self.step = 0
|
|
||||||
|
|
||||||
# Wrap the conditioning models with additional image conditioning for inpainting model
|
|
||||||
if image_conditioning is not None:
|
|
||||||
if self.conditioning_key == "crossattn-adm":
|
|
||||||
conditioning = {"c_adm": image_conditioning, "c_crossattn": [conditioning]}
|
|
||||||
unconditional_conditioning = {"c_adm": torch.zeros_like(image_conditioning), "c_crossattn": [unconditional_conditioning]}
|
|
||||||
else:
|
|
||||||
conditioning = {"c_concat": [image_conditioning], "c_crossattn": [conditioning]}
|
|
||||||
unconditional_conditioning = {"c_concat": [image_conditioning], "c_crossattn": [unconditional_conditioning]}
|
|
||||||
|
|
||||||
samples = self.launch_sampling(t_enc + 1, lambda: self.sampler.decode(x1, conditioning, t_enc, unconditional_guidance_scale=p.cfg_scale, unconditional_conditioning=unconditional_conditioning))
|
|
||||||
|
|
||||||
return samples
|
|
||||||
|
|
||||||
def sample(self, p, x, conditioning, unconditional_conditioning, steps=None, image_conditioning=None):
|
|
||||||
self.initialize(p)
|
|
||||||
|
|
||||||
self.init_latent = None
|
|
||||||
self.last_latent = x
|
|
||||||
self.step = 0
|
|
||||||
|
|
||||||
steps = self.adjust_steps_if_invalid(p, steps or p.steps)
|
|
||||||
|
|
||||||
# Wrap the conditioning models with additional image conditioning for inpainting model
|
|
||||||
# dummy_for_plms is needed because PLMS code checks the first item in the dict to have the right shape
|
|
||||||
if image_conditioning is not None:
|
|
||||||
if self.conditioning_key == "crossattn-adm":
|
|
||||||
conditioning = {"dummy_for_plms": np.zeros((conditioning.shape[0],)), "c_crossattn": [conditioning], "c_adm": image_conditioning}
|
|
||||||
unconditional_conditioning = {"c_crossattn": [unconditional_conditioning], "c_adm": torch.zeros_like(image_conditioning)}
|
|
||||||
else:
|
|
||||||
conditioning = {"dummy_for_plms": np.zeros((conditioning.shape[0],)), "c_crossattn": [conditioning], "c_concat": [image_conditioning]}
|
|
||||||
unconditional_conditioning = {"c_crossattn": [unconditional_conditioning], "c_concat": [image_conditioning]}
|
|
||||||
|
|
||||||
samples_ddim = self.launch_sampling(steps, lambda: self.sampler.sample(S=steps, conditioning=conditioning, batch_size=int(x.shape[0]), shape=x[0].shape, verbose=False, unconditional_guidance_scale=p.cfg_scale, unconditional_conditioning=unconditional_conditioning, x_T=x, eta=self.eta)[0])
|
|
||||||
|
|
||||||
return samples_ddim
|
|
||||||
|
|||||||
@@ -1,38 +1,41 @@
|
|||||||
from collections import deque
|
|
||||||
import torch
|
import torch
|
||||||
import inspect
|
import inspect
|
||||||
import k_diffusion.sampling
|
import k_diffusion.sampling
|
||||||
from modules import prompt_parser, devices, sd_samplers_common, sd_samplers_extra
|
from modules import sd_samplers_common, sd_samplers_extra, sd_samplers_cfg_denoiser
|
||||||
|
from modules.sd_samplers_cfg_denoiser import CFGDenoiser # noqa: F401
|
||||||
|
from modules.script_callbacks import ExtraNoiseParams, extra_noise_callback
|
||||||
|
|
||||||
from modules.processing import StableDiffusionProcessing
|
from modules.shared import opts
|
||||||
from modules.shared import opts, state
|
|
||||||
import modules.shared as shared
|
import modules.shared as shared
|
||||||
from modules.script_callbacks import CFGDenoiserParams, cfg_denoiser_callback
|
|
||||||
from modules.script_callbacks import CFGDenoisedParams, cfg_denoised_callback
|
|
||||||
from modules.script_callbacks import AfterCFGCallbackParams, cfg_after_cfg_callback
|
|
||||||
|
|
||||||
samplers_k_diffusion = [
|
samplers_k_diffusion = [
|
||||||
|
('DPM++ 2M Karras', 'sample_dpmpp_2m', ['k_dpmpp_2m_ka'], {'scheduler': 'karras'}),
|
||||||
|
('DPM++ SDE Karras', 'sample_dpmpp_sde', ['k_dpmpp_sde_ka'], {'scheduler': 'karras', "second_order": True, "brownian_noise": True}),
|
||||||
|
('DPM++ 2M SDE Exponential', 'sample_dpmpp_2m_sde', ['k_dpmpp_2m_sde_exp'], {'scheduler': 'exponential', "brownian_noise": True}),
|
||||||
|
('DPM++ 2M SDE Karras', 'sample_dpmpp_2m_sde', ['k_dpmpp_2m_sde_ka'], {'scheduler': 'karras', "brownian_noise": True}),
|
||||||
('Euler a', 'sample_euler_ancestral', ['k_euler_a', 'k_euler_ancestral'], {"uses_ensd": True}),
|
('Euler a', 'sample_euler_ancestral', ['k_euler_a', 'k_euler_ancestral'], {"uses_ensd": True}),
|
||||||
('Euler', 'sample_euler', ['k_euler'], {}),
|
('Euler', 'sample_euler', ['k_euler'], {}),
|
||||||
('LMS', 'sample_lms', ['k_lms'], {}),
|
('LMS', 'sample_lms', ['k_lms'], {}),
|
||||||
('Heun', 'sample_heun', ['k_heun'], {"second_order": True}),
|
('Heun', 'sample_heun', ['k_heun'], {"second_order": True}),
|
||||||
('DPM2', 'sample_dpm_2', ['k_dpm_2'], {'discard_next_to_last_sigma': True}),
|
('DPM2', 'sample_dpm_2', ['k_dpm_2'], {'discard_next_to_last_sigma': True, "second_order": True}),
|
||||||
('DPM2 a', 'sample_dpm_2_ancestral', ['k_dpm_2_a'], {'discard_next_to_last_sigma': True, "uses_ensd": True}),
|
('DPM2 a', 'sample_dpm_2_ancestral', ['k_dpm_2_a'], {'discard_next_to_last_sigma': True, "uses_ensd": True, "second_order": True}),
|
||||||
('DPM++ 2S a', 'sample_dpmpp_2s_ancestral', ['k_dpmpp_2s_a'], {"uses_ensd": True, "second_order": True}),
|
('DPM++ 2S a', 'sample_dpmpp_2s_ancestral', ['k_dpmpp_2s_a'], {"uses_ensd": True, "second_order": True}),
|
||||||
('DPM++ 2M', 'sample_dpmpp_2m', ['k_dpmpp_2m'], {}),
|
('DPM++ 2M', 'sample_dpmpp_2m', ['k_dpmpp_2m'], {}),
|
||||||
('DPM++ SDE', 'sample_dpmpp_sde', ['k_dpmpp_sde'], {"second_order": True, "brownian_noise": True}),
|
('DPM++ SDE', 'sample_dpmpp_sde', ['k_dpmpp_sde'], {"second_order": True, "brownian_noise": True}),
|
||||||
('DPM++ 2M SDE', 'sample_dpmpp_2m_sde', ['k_dpmpp_2m_sde_ka'], {"brownian_noise": True}),
|
('DPM++ 2M SDE', 'sample_dpmpp_2m_sde', ['k_dpmpp_2m_sde_ka'], {"brownian_noise": True}),
|
||||||
|
('DPM++ 2M SDE Heun', 'sample_dpmpp_2m_sde', ['k_dpmpp_2m_sde_heun'], {"brownian_noise": True, "solver_type": "heun"}),
|
||||||
|
('DPM++ 2M SDE Heun Karras', 'sample_dpmpp_2m_sde', ['k_dpmpp_2m_sde_heun_ka'], {'scheduler': 'karras', "brownian_noise": True, "solver_type": "heun"}),
|
||||||
|
('DPM++ 2M SDE Heun Exponential', 'sample_dpmpp_2m_sde', ['k_dpmpp_2m_sde_heun_exp'], {'scheduler': 'exponential', "brownian_noise": True, "solver_type": "heun"}),
|
||||||
|
('DPM++ 3M SDE', 'sample_dpmpp_3m_sde', ['k_dpmpp_3m_sde'], {'discard_next_to_last_sigma': True, "brownian_noise": True}),
|
||||||
|
('DPM++ 3M SDE Karras', 'sample_dpmpp_3m_sde', ['k_dpmpp_3m_sde_ka'], {'scheduler': 'karras', 'discard_next_to_last_sigma': True, "brownian_noise": True}),
|
||||||
|
('DPM++ 3M SDE Exponential', 'sample_dpmpp_3m_sde', ['k_dpmpp_3m_sde_exp'], {'scheduler': 'exponential', 'discard_next_to_last_sigma': True, "brownian_noise": True}),
|
||||||
('DPM fast', 'sample_dpm_fast', ['k_dpm_fast'], {"uses_ensd": True}),
|
('DPM fast', 'sample_dpm_fast', ['k_dpm_fast'], {"uses_ensd": True}),
|
||||||
('DPM adaptive', 'sample_dpm_adaptive', ['k_dpm_ad'], {"uses_ensd": True}),
|
('DPM adaptive', 'sample_dpm_adaptive', ['k_dpm_ad'], {"uses_ensd": True}),
|
||||||
('LMS Karras', 'sample_lms', ['k_lms_ka'], {'scheduler': 'karras'}),
|
('LMS Karras', 'sample_lms', ['k_lms_ka'], {'scheduler': 'karras'}),
|
||||||
('DPM2 Karras', 'sample_dpm_2', ['k_dpm_2_ka'], {'scheduler': 'karras', 'discard_next_to_last_sigma': True, "uses_ensd": True, "second_order": True}),
|
('DPM2 Karras', 'sample_dpm_2', ['k_dpm_2_ka'], {'scheduler': 'karras', 'discard_next_to_last_sigma': True, "uses_ensd": True, "second_order": True}),
|
||||||
('DPM2 a Karras', 'sample_dpm_2_ancestral', ['k_dpm_2_a_ka'], {'scheduler': 'karras', 'discard_next_to_last_sigma': True, "uses_ensd": True, "second_order": True}),
|
('DPM2 a Karras', 'sample_dpm_2_ancestral', ['k_dpm_2_a_ka'], {'scheduler': 'karras', 'discard_next_to_last_sigma': True, "uses_ensd": True, "second_order": True}),
|
||||||
('DPM++ 2S a Karras', 'sample_dpmpp_2s_ancestral', ['k_dpmpp_2s_a_ka'], {'scheduler': 'karras', "uses_ensd": True, "second_order": True}),
|
('DPM++ 2S a Karras', 'sample_dpmpp_2s_ancestral', ['k_dpmpp_2s_a_ka'], {'scheduler': 'karras', "uses_ensd": True, "second_order": True}),
|
||||||
('DPM++ 2M Karras', 'sample_dpmpp_2m', ['k_dpmpp_2m_ka'], {'scheduler': 'karras'}),
|
('Restart', sd_samplers_extra.restart_sampler, ['restart'], {'scheduler': 'karras', "second_order": True}),
|
||||||
('DPM++ SDE Karras', 'sample_dpmpp_sde', ['k_dpmpp_sde_ka'], {'scheduler': 'karras', "second_order": True, "brownian_noise": True}),
|
|
||||||
('DPM++ 2M SDE Karras', 'sample_dpmpp_2m_sde', ['k_dpmpp_2m_sde_ka'], {'scheduler': 'karras', "brownian_noise": True}),
|
|
||||||
('DPM++ 2M SDE Exponential', 'sample_dpmpp_2m_sde', ['k_dpmpp_2m_sde_exp'], {'scheduler': 'exponential', "brownian_noise": True}),
|
|
||||||
('Restart', sd_samplers_extra.restart_sampler, ['restart'], {'scheduler': 'karras'}),
|
|
||||||
]
|
]
|
||||||
|
|
||||||
|
|
||||||
@@ -46,6 +49,12 @@ sampler_extra_params = {
|
|||||||
'sample_euler': ['s_churn', 's_tmin', 's_tmax', 's_noise'],
|
'sample_euler': ['s_churn', 's_tmin', 's_tmax', 's_noise'],
|
||||||
'sample_heun': ['s_churn', 's_tmin', 's_tmax', 's_noise'],
|
'sample_heun': ['s_churn', 's_tmin', 's_tmax', 's_noise'],
|
||||||
'sample_dpm_2': ['s_churn', 's_tmin', 's_tmax', 's_noise'],
|
'sample_dpm_2': ['s_churn', 's_tmin', 's_tmax', 's_noise'],
|
||||||
|
'sample_dpm_fast': ['s_noise'],
|
||||||
|
'sample_dpm_2_ancestral': ['s_noise'],
|
||||||
|
'sample_dpmpp_2s_ancestral': ['s_noise'],
|
||||||
|
'sample_dpmpp_sde': ['s_noise'],
|
||||||
|
'sample_dpmpp_2m_sde': ['s_noise'],
|
||||||
|
'sample_dpmpp_3m_sde': ['s_noise'],
|
||||||
}
|
}
|
||||||
|
|
||||||
k_diffusion_samplers_map = {x.name: x for x in samplers_data_k_diffusion}
|
k_diffusion_samplers_map = {x.name: x for x in samplers_data_k_diffusion}
|
||||||
@@ -57,317 +66,27 @@ k_diffusion_scheduler = {
|
|||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
def catenate_conds(conds):
|
class CFGDenoiserKDiffusion(sd_samplers_cfg_denoiser.CFGDenoiser):
|
||||||
if not isinstance(conds[0], dict):
|
@property
|
||||||
return torch.cat(conds)
|
def inner_model(self):
|
||||||
|
if self.model_wrap is None:
|
||||||
|
denoiser = k_diffusion.external.CompVisVDenoiser if shared.sd_model.parameterization == "v" else k_diffusion.external.CompVisDenoiser
|
||||||
|
self.model_wrap = denoiser(shared.sd_model, quantize=shared.opts.enable_quantization)
|
||||||
|
|
||||||
return {key: torch.cat([x[key] for x in conds]) for key in conds[0].keys()}
|
return self.model_wrap
|
||||||
|
|
||||||
|
|
||||||
def subscript_cond(cond, a, b):
|
class KDiffusionSampler(sd_samplers_common.Sampler):
|
||||||
if not isinstance(cond, dict):
|
def __init__(self, funcname, sd_model, options=None):
|
||||||
return cond[a:b]
|
super().__init__(funcname)
|
||||||
|
|
||||||
return {key: vec[a:b] for key, vec in cond.items()}
|
|
||||||
|
|
||||||
|
|
||||||
def pad_cond(tensor, repeats, empty):
|
|
||||||
if not isinstance(tensor, dict):
|
|
||||||
return torch.cat([tensor, empty.repeat((tensor.shape[0], repeats, 1))], axis=1)
|
|
||||||
|
|
||||||
tensor['crossattn'] = pad_cond(tensor['crossattn'], repeats, empty)
|
|
||||||
return tensor
|
|
||||||
|
|
||||||
|
|
||||||
class CFGDenoiser(torch.nn.Module):
|
|
||||||
"""
|
|
||||||
Classifier free guidance denoiser. A wrapper for stable diffusion model (specifically for unet)
|
|
||||||
that can take a noisy picture and produce a noise-free picture using two guidances (prompts)
|
|
||||||
instead of one. Originally, the second prompt is just an empty string, but we use non-empty
|
|
||||||
negative prompt.
|
|
||||||
"""
|
|
||||||
|
|
||||||
def __init__(self, model):
|
|
||||||
super().__init__()
|
|
||||||
self.inner_model = model
|
|
||||||
self.mask = None
|
|
||||||
self.nmask = None
|
|
||||||
self.init_latent = None
|
|
||||||
self.step = 0
|
|
||||||
self.image_cfg_scale = None
|
|
||||||
self.padded_cond_uncond = False
|
|
||||||
|
|
||||||
def combine_denoised(self, x_out, conds_list, uncond, cond_scale):
|
|
||||||
denoised_uncond = x_out[-uncond.shape[0]:]
|
|
||||||
denoised = torch.clone(denoised_uncond)
|
|
||||||
|
|
||||||
for i, conds in enumerate(conds_list):
|
|
||||||
for cond_index, weight in conds:
|
|
||||||
denoised[i] += (x_out[cond_index] - denoised_uncond[i]) * (weight * cond_scale)
|
|
||||||
|
|
||||||
return denoised
|
|
||||||
|
|
||||||
def combine_denoised_for_edit_model(self, x_out, cond_scale):
|
|
||||||
out_cond, out_img_cond, out_uncond = x_out.chunk(3)
|
|
||||||
denoised = out_uncond + cond_scale * (out_cond - out_img_cond) + self.image_cfg_scale * (out_img_cond - out_uncond)
|
|
||||||
|
|
||||||
return denoised
|
|
||||||
|
|
||||||
def forward(self, x, sigma, uncond, cond, cond_scale, s_min_uncond, image_cond):
|
|
||||||
if state.interrupted or state.skipped:
|
|
||||||
raise sd_samplers_common.InterruptedException
|
|
||||||
|
|
||||||
# at self.image_cfg_scale == 1.0 produced results for edit model are the same as with normal sampling,
|
|
||||||
# so is_edit_model is set to False to support AND composition.
|
|
||||||
is_edit_model = shared.sd_model.cond_stage_key == "edit" and self.image_cfg_scale is not None and self.image_cfg_scale != 1.0
|
|
||||||
|
|
||||||
conds_list, tensor = prompt_parser.reconstruct_multicond_batch(cond, self.step)
|
|
||||||
uncond = prompt_parser.reconstruct_cond_batch(uncond, self.step)
|
|
||||||
|
|
||||||
assert not is_edit_model or all(len(conds) == 1 for conds in conds_list), "AND is not supported for InstructPix2Pix checkpoint (unless using Image CFG scale = 1.0)"
|
|
||||||
|
|
||||||
batch_size = len(conds_list)
|
|
||||||
repeats = [len(conds_list[i]) for i in range(batch_size)]
|
|
||||||
|
|
||||||
if shared.sd_model.model.conditioning_key == "crossattn-adm":
|
|
||||||
image_uncond = torch.zeros_like(image_cond)
|
|
||||||
make_condition_dict = lambda c_crossattn, c_adm: {"c_crossattn": [c_crossattn], "c_adm": c_adm}
|
|
||||||
else:
|
|
||||||
image_uncond = image_cond
|
|
||||||
if isinstance(uncond, dict):
|
|
||||||
make_condition_dict = lambda c_crossattn, c_concat: {**c_crossattn, "c_concat": [c_concat]}
|
|
||||||
else:
|
|
||||||
make_condition_dict = lambda c_crossattn, c_concat: {"c_crossattn": [c_crossattn], "c_concat": [c_concat]}
|
|
||||||
|
|
||||||
if not is_edit_model:
|
|
||||||
x_in = torch.cat([torch.stack([x[i] for _ in range(n)]) for i, n in enumerate(repeats)] + [x])
|
|
||||||
sigma_in = torch.cat([torch.stack([sigma[i] for _ in range(n)]) for i, n in enumerate(repeats)] + [sigma])
|
|
||||||
image_cond_in = torch.cat([torch.stack([image_cond[i] for _ in range(n)]) for i, n in enumerate(repeats)] + [image_uncond])
|
|
||||||
else:
|
|
||||||
x_in = torch.cat([torch.stack([x[i] for _ in range(n)]) for i, n in enumerate(repeats)] + [x] + [x])
|
|
||||||
sigma_in = torch.cat([torch.stack([sigma[i] for _ in range(n)]) for i, n in enumerate(repeats)] + [sigma] + [sigma])
|
|
||||||
image_cond_in = torch.cat([torch.stack([image_cond[i] for _ in range(n)]) for i, n in enumerate(repeats)] + [image_uncond] + [torch.zeros_like(self.init_latent)])
|
|
||||||
|
|
||||||
denoiser_params = CFGDenoiserParams(x_in, image_cond_in, sigma_in, state.sampling_step, state.sampling_steps, tensor, uncond)
|
|
||||||
cfg_denoiser_callback(denoiser_params)
|
|
||||||
x_in = denoiser_params.x
|
|
||||||
image_cond_in = denoiser_params.image_cond
|
|
||||||
sigma_in = denoiser_params.sigma
|
|
||||||
tensor = denoiser_params.text_cond
|
|
||||||
uncond = denoiser_params.text_uncond
|
|
||||||
skip_uncond = False
|
|
||||||
|
|
||||||
# alternating uncond allows for higher thresholds without the quality loss normally expected from raising it
|
|
||||||
if self.step % 2 and s_min_uncond > 0 and sigma[0] < s_min_uncond and not is_edit_model:
|
|
||||||
skip_uncond = True
|
|
||||||
x_in = x_in[:-batch_size]
|
|
||||||
sigma_in = sigma_in[:-batch_size]
|
|
||||||
|
|
||||||
self.padded_cond_uncond = False
|
|
||||||
if shared.opts.pad_cond_uncond and tensor.shape[1] != uncond.shape[1]:
|
|
||||||
empty = shared.sd_model.cond_stage_model_empty_prompt
|
|
||||||
num_repeats = (tensor.shape[1] - uncond.shape[1]) // empty.shape[1]
|
|
||||||
|
|
||||||
if num_repeats < 0:
|
|
||||||
tensor = pad_cond(tensor, -num_repeats, empty)
|
|
||||||
self.padded_cond_uncond = True
|
|
||||||
elif num_repeats > 0:
|
|
||||||
uncond = pad_cond(uncond, num_repeats, empty)
|
|
||||||
self.padded_cond_uncond = True
|
|
||||||
|
|
||||||
if tensor.shape[1] == uncond.shape[1] or skip_uncond:
|
|
||||||
if is_edit_model:
|
|
||||||
cond_in = catenate_conds([tensor, uncond, uncond])
|
|
||||||
elif skip_uncond:
|
|
||||||
cond_in = tensor
|
|
||||||
else:
|
|
||||||
cond_in = catenate_conds([tensor, uncond])
|
|
||||||
|
|
||||||
if shared.batch_cond_uncond:
|
|
||||||
x_out = self.inner_model(x_in, sigma_in, cond=make_condition_dict(cond_in, image_cond_in))
|
|
||||||
else:
|
|
||||||
x_out = torch.zeros_like(x_in)
|
|
||||||
for batch_offset in range(0, x_out.shape[0], batch_size):
|
|
||||||
a = batch_offset
|
|
||||||
b = a + batch_size
|
|
||||||
x_out[a:b] = self.inner_model(x_in[a:b], sigma_in[a:b], cond=make_condition_dict(subscript_cond(cond_in, a, b), image_cond_in[a:b]))
|
|
||||||
else:
|
|
||||||
x_out = torch.zeros_like(x_in)
|
|
||||||
batch_size = batch_size*2 if shared.batch_cond_uncond else batch_size
|
|
||||||
for batch_offset in range(0, tensor.shape[0], batch_size):
|
|
||||||
a = batch_offset
|
|
||||||
b = min(a + batch_size, tensor.shape[0])
|
|
||||||
|
|
||||||
if not is_edit_model:
|
|
||||||
c_crossattn = subscript_cond(tensor, a, b)
|
|
||||||
else:
|
|
||||||
c_crossattn = torch.cat([tensor[a:b]], uncond)
|
|
||||||
|
|
||||||
x_out[a:b] = self.inner_model(x_in[a:b], sigma_in[a:b], cond=make_condition_dict(c_crossattn, image_cond_in[a:b]))
|
|
||||||
|
|
||||||
if not skip_uncond:
|
|
||||||
x_out[-uncond.shape[0]:] = self.inner_model(x_in[-uncond.shape[0]:], sigma_in[-uncond.shape[0]:], cond=make_condition_dict(uncond, image_cond_in[-uncond.shape[0]:]))
|
|
||||||
|
|
||||||
denoised_image_indexes = [x[0][0] for x in conds_list]
|
|
||||||
if skip_uncond:
|
|
||||||
fake_uncond = torch.cat([x_out[i:i+1] for i in denoised_image_indexes])
|
|
||||||
x_out = torch.cat([x_out, fake_uncond]) # we skipped uncond denoising, so we put cond-denoised image to where the uncond-denoised image should be
|
|
||||||
|
|
||||||
denoised_params = CFGDenoisedParams(x_out, state.sampling_step, state.sampling_steps, self.inner_model)
|
|
||||||
cfg_denoised_callback(denoised_params)
|
|
||||||
|
|
||||||
devices.test_for_nans(x_out, "unet")
|
|
||||||
|
|
||||||
if opts.live_preview_content == "Prompt":
|
|
||||||
sd_samplers_common.store_latent(torch.cat([x_out[i:i+1] for i in denoised_image_indexes]))
|
|
||||||
elif opts.live_preview_content == "Negative prompt":
|
|
||||||
sd_samplers_common.store_latent(x_out[-uncond.shape[0]:])
|
|
||||||
|
|
||||||
if is_edit_model:
|
|
||||||
denoised = self.combine_denoised_for_edit_model(x_out, cond_scale)
|
|
||||||
elif skip_uncond:
|
|
||||||
denoised = self.combine_denoised(x_out, conds_list, uncond, 1.0)
|
|
||||||
else:
|
|
||||||
denoised = self.combine_denoised(x_out, conds_list, uncond, cond_scale)
|
|
||||||
|
|
||||||
if self.mask is not None:
|
|
||||||
denoised = self.init_latent * self.mask + self.nmask * denoised
|
|
||||||
|
|
||||||
after_cfg_callback_params = AfterCFGCallbackParams(denoised, state.sampling_step, state.sampling_steps)
|
|
||||||
cfg_after_cfg_callback(after_cfg_callback_params)
|
|
||||||
denoised = after_cfg_callback_params.x
|
|
||||||
|
|
||||||
self.step += 1
|
|
||||||
return denoised
|
|
||||||
|
|
||||||
|
|
||||||
class TorchHijack:
|
|
||||||
def __init__(self, sampler_noises):
|
|
||||||
# Using a deque to efficiently receive the sampler_noises in the same order as the previous index-based
|
|
||||||
# implementation.
|
|
||||||
self.sampler_noises = deque(sampler_noises)
|
|
||||||
|
|
||||||
def __getattr__(self, item):
|
|
||||||
if item == 'randn_like':
|
|
||||||
return self.randn_like
|
|
||||||
|
|
||||||
if hasattr(torch, item):
|
|
||||||
return getattr(torch, item)
|
|
||||||
|
|
||||||
raise AttributeError(f"'{type(self).__name__}' object has no attribute '{item}'")
|
|
||||||
|
|
||||||
def randn_like(self, x):
|
|
||||||
if self.sampler_noises:
|
|
||||||
noise = self.sampler_noises.popleft()
|
|
||||||
if noise.shape == x.shape:
|
|
||||||
return noise
|
|
||||||
|
|
||||||
return devices.randn_like(x)
|
|
||||||
|
|
||||||
|
|
||||||
class KDiffusionSampler:
|
|
||||||
def __init__(self, funcname, sd_model):
|
|
||||||
denoiser = k_diffusion.external.CompVisVDenoiser if sd_model.parameterization == "v" else k_diffusion.external.CompVisDenoiser
|
|
||||||
|
|
||||||
self.model_wrap = denoiser(sd_model, quantize=shared.opts.enable_quantization)
|
|
||||||
self.funcname = funcname
|
|
||||||
self.func = funcname if callable(funcname) else getattr(k_diffusion.sampling, self.funcname)
|
|
||||||
self.extra_params = sampler_extra_params.get(funcname, [])
|
self.extra_params = sampler_extra_params.get(funcname, [])
|
||||||
self.model_wrap_cfg = CFGDenoiser(self.model_wrap)
|
|
||||||
self.sampler_noises = None
|
|
||||||
self.stop_at = None
|
|
||||||
self.eta = None
|
|
||||||
self.config = None # set by the function calling the constructor
|
|
||||||
self.last_latent = None
|
|
||||||
self.s_min_uncond = None
|
|
||||||
|
|
||||||
# NOTE: These are also defined in the StableDiffusionProcessing class.
|
self.options = options or {}
|
||||||
# They should have been here to begin with but we're going to
|
self.func = funcname if callable(funcname) else getattr(k_diffusion.sampling, self.funcname)
|
||||||
# leave that class __init__ signature alone.
|
|
||||||
self.s_churn = 0.0
|
|
||||||
self.s_tmin = 0.0
|
|
||||||
self.s_tmax = float('inf')
|
|
||||||
self.s_noise = 1.0
|
|
||||||
|
|
||||||
self.conditioning_key = sd_model.model.conditioning_key
|
self.model_wrap_cfg = CFGDenoiserKDiffusion(self)
|
||||||
|
self.model_wrap = self.model_wrap_cfg.inner_model
|
||||||
def callback_state(self, d):
|
|
||||||
step = d['i']
|
|
||||||
latent = d["denoised"]
|
|
||||||
if opts.live_preview_content == "Combined":
|
|
||||||
sd_samplers_common.store_latent(latent)
|
|
||||||
self.last_latent = latent
|
|
||||||
|
|
||||||
if self.stop_at is not None and step > self.stop_at:
|
|
||||||
raise sd_samplers_common.InterruptedException
|
|
||||||
|
|
||||||
state.sampling_step = step
|
|
||||||
shared.total_tqdm.update()
|
|
||||||
|
|
||||||
def launch_sampling(self, steps, func):
|
|
||||||
state.sampling_steps = steps
|
|
||||||
state.sampling_step = 0
|
|
||||||
|
|
||||||
try:
|
|
||||||
return func()
|
|
||||||
except RecursionError:
|
|
||||||
print(
|
|
||||||
'Encountered RecursionError during sampling, returning last latent. '
|
|
||||||
'rho >5 with a polyexponential scheduler may cause this error. '
|
|
||||||
'You should try to use a smaller rho value instead.'
|
|
||||||
)
|
|
||||||
return self.last_latent
|
|
||||||
except sd_samplers_common.InterruptedException:
|
|
||||||
return self.last_latent
|
|
||||||
|
|
||||||
def number_of_needed_noises(self, p):
|
|
||||||
return p.steps
|
|
||||||
|
|
||||||
def initialize(self, p: StableDiffusionProcessing):
|
|
||||||
self.model_wrap_cfg.mask = p.mask if hasattr(p, 'mask') else None
|
|
||||||
self.model_wrap_cfg.nmask = p.nmask if hasattr(p, 'nmask') else None
|
|
||||||
self.model_wrap_cfg.step = 0
|
|
||||||
self.model_wrap_cfg.image_cfg_scale = getattr(p, 'image_cfg_scale', None)
|
|
||||||
self.eta = p.eta if p.eta is not None else opts.eta_ancestral
|
|
||||||
self.s_min_uncond = getattr(p, 's_min_uncond', 0.0)
|
|
||||||
|
|
||||||
k_diffusion.sampling.torch = TorchHijack(self.sampler_noises if self.sampler_noises is not None else [])
|
|
||||||
|
|
||||||
extra_params_kwargs = {}
|
|
||||||
for param_name in self.extra_params:
|
|
||||||
if hasattr(p, param_name) and param_name in inspect.signature(self.func).parameters:
|
|
||||||
extra_params_kwargs[param_name] = getattr(p, param_name)
|
|
||||||
|
|
||||||
if 'eta' in inspect.signature(self.func).parameters:
|
|
||||||
if self.eta != 1.0:
|
|
||||||
p.extra_generation_params["Eta"] = self.eta
|
|
||||||
|
|
||||||
extra_params_kwargs['eta'] = self.eta
|
|
||||||
|
|
||||||
if len(self.extra_params) > 0:
|
|
||||||
s_churn = getattr(opts, 's_churn', p.s_churn)
|
|
||||||
s_tmin = getattr(opts, 's_tmin', p.s_tmin)
|
|
||||||
s_tmax = getattr(opts, 's_tmax', p.s_tmax) or self.s_tmax # 0 = inf
|
|
||||||
s_noise = getattr(opts, 's_noise', p.s_noise)
|
|
||||||
|
|
||||||
if s_churn != self.s_churn:
|
|
||||||
extra_params_kwargs['s_churn'] = s_churn
|
|
||||||
p.s_churn = s_churn
|
|
||||||
p.extra_generation_params['Sigma churn'] = s_churn
|
|
||||||
if s_tmin != self.s_tmin:
|
|
||||||
extra_params_kwargs['s_tmin'] = s_tmin
|
|
||||||
p.s_tmin = s_tmin
|
|
||||||
p.extra_generation_params['Sigma tmin'] = s_tmin
|
|
||||||
if s_tmax != self.s_tmax:
|
|
||||||
extra_params_kwargs['s_tmax'] = s_tmax
|
|
||||||
p.s_tmax = s_tmax
|
|
||||||
p.extra_generation_params['Sigma tmax'] = s_tmax
|
|
||||||
if s_noise != self.s_noise:
|
|
||||||
extra_params_kwargs['s_noise'] = s_noise
|
|
||||||
p.s_noise = s_noise
|
|
||||||
p.extra_generation_params['Sigma noise'] = s_noise
|
|
||||||
|
|
||||||
return extra_params_kwargs
|
|
||||||
|
|
||||||
def get_sigmas(self, p, steps):
|
def get_sigmas(self, p, steps):
|
||||||
discard_next_to_last_sigma = self.config is not None and self.config.options.get('discard_next_to_last_sigma', False)
|
discard_next_to_last_sigma = self.config is not None and self.config.options.get('discard_next_to_last_sigma', False)
|
||||||
@@ -419,24 +138,21 @@ class KDiffusionSampler:
|
|||||||
|
|
||||||
return sigmas
|
return sigmas
|
||||||
|
|
||||||
def create_noise_sampler(self, x, sigmas, p):
|
|
||||||
"""For DPM++ SDE: manually create noise sampler to enable deterministic results across different batch sizes"""
|
|
||||||
if shared.opts.no_dpmpp_sde_batch_determinism:
|
|
||||||
return None
|
|
||||||
|
|
||||||
from k_diffusion.sampling import BrownianTreeNoiseSampler
|
|
||||||
sigma_min, sigma_max = sigmas[sigmas > 0].min(), sigmas.max()
|
|
||||||
current_iter_seeds = p.all_seeds[p.iteration * p.batch_size:(p.iteration + 1) * p.batch_size]
|
|
||||||
return BrownianTreeNoiseSampler(x, sigma_min, sigma_max, seed=current_iter_seeds)
|
|
||||||
|
|
||||||
def sample_img2img(self, p, x, noise, conditioning, unconditional_conditioning, steps=None, image_conditioning=None):
|
def sample_img2img(self, p, x, noise, conditioning, unconditional_conditioning, steps=None, image_conditioning=None):
|
||||||
steps, t_enc = sd_samplers_common.setup_img2img_steps(p, steps)
|
steps, t_enc = sd_samplers_common.setup_img2img_steps(p, steps)
|
||||||
|
|
||||||
sigmas = self.get_sigmas(p, steps)
|
sigmas = self.get_sigmas(p, steps)
|
||||||
|
|
||||||
sigma_sched = sigmas[steps - t_enc - 1:]
|
sigma_sched = sigmas[steps - t_enc - 1:]
|
||||||
|
|
||||||
xi = x + noise * sigma_sched[0]
|
xi = x + noise * sigma_sched[0]
|
||||||
|
|
||||||
|
if opts.img2img_extra_noise > 0:
|
||||||
|
p.extra_generation_params["Extra noise"] = opts.img2img_extra_noise
|
||||||
|
extra_noise_params = ExtraNoiseParams(noise, x)
|
||||||
|
extra_noise_callback(extra_noise_params)
|
||||||
|
noise = extra_noise_params.noise
|
||||||
|
xi += noise * opts.img2img_extra_noise
|
||||||
|
|
||||||
extra_params_kwargs = self.initialize(p)
|
extra_params_kwargs = self.initialize(p)
|
||||||
parameters = inspect.signature(self.func).parameters
|
parameters = inspect.signature(self.func).parameters
|
||||||
|
|
||||||
@@ -456,9 +172,12 @@ class KDiffusionSampler:
|
|||||||
noise_sampler = self.create_noise_sampler(x, sigmas, p)
|
noise_sampler = self.create_noise_sampler(x, sigmas, p)
|
||||||
extra_params_kwargs['noise_sampler'] = noise_sampler
|
extra_params_kwargs['noise_sampler'] = noise_sampler
|
||||||
|
|
||||||
|
if self.config.options.get('solver_type', None) == 'heun':
|
||||||
|
extra_params_kwargs['solver_type'] = 'heun'
|
||||||
|
|
||||||
self.model_wrap_cfg.init_latent = x
|
self.model_wrap_cfg.init_latent = x
|
||||||
self.last_latent = x
|
self.last_latent = x
|
||||||
extra_args = {
|
self.sampler_extra_args = {
|
||||||
'cond': conditioning,
|
'cond': conditioning,
|
||||||
'image_cond': image_conditioning,
|
'image_cond': image_conditioning,
|
||||||
'uncond': unconditional_conditioning,
|
'uncond': unconditional_conditioning,
|
||||||
@@ -466,7 +185,7 @@ class KDiffusionSampler:
|
|||||||
's_min_uncond': self.s_min_uncond
|
's_min_uncond': self.s_min_uncond
|
||||||
}
|
}
|
||||||
|
|
||||||
samples = self.launch_sampling(t_enc + 1, lambda: self.func(self.model_wrap_cfg, xi, extra_args=extra_args, disable=False, callback=self.callback_state, **extra_params_kwargs))
|
samples = self.launch_sampling(t_enc + 1, lambda: self.func(self.model_wrap_cfg, xi, extra_args=self.sampler_extra_args, disable=False, callback=self.callback_state, **extra_params_kwargs))
|
||||||
|
|
||||||
if self.model_wrap_cfg.padded_cond_uncond:
|
if self.model_wrap_cfg.padded_cond_uncond:
|
||||||
p.extra_generation_params["Pad conds"] = True
|
p.extra_generation_params["Pad conds"] = True
|
||||||
@@ -483,29 +202,37 @@ class KDiffusionSampler:
|
|||||||
extra_params_kwargs = self.initialize(p)
|
extra_params_kwargs = self.initialize(p)
|
||||||
parameters = inspect.signature(self.func).parameters
|
parameters = inspect.signature(self.func).parameters
|
||||||
|
|
||||||
|
if 'n' in parameters:
|
||||||
|
extra_params_kwargs['n'] = steps
|
||||||
|
|
||||||
if 'sigma_min' in parameters:
|
if 'sigma_min' in parameters:
|
||||||
extra_params_kwargs['sigma_min'] = self.model_wrap.sigmas[0].item()
|
extra_params_kwargs['sigma_min'] = self.model_wrap.sigmas[0].item()
|
||||||
extra_params_kwargs['sigma_max'] = self.model_wrap.sigmas[-1].item()
|
extra_params_kwargs['sigma_max'] = self.model_wrap.sigmas[-1].item()
|
||||||
if 'n' in parameters:
|
|
||||||
extra_params_kwargs['n'] = steps
|
if 'sigmas' in parameters:
|
||||||
else:
|
|
||||||
extra_params_kwargs['sigmas'] = sigmas
|
extra_params_kwargs['sigmas'] = sigmas
|
||||||
|
|
||||||
if self.config.options.get('brownian_noise', False):
|
if self.config.options.get('brownian_noise', False):
|
||||||
noise_sampler = self.create_noise_sampler(x, sigmas, p)
|
noise_sampler = self.create_noise_sampler(x, sigmas, p)
|
||||||
extra_params_kwargs['noise_sampler'] = noise_sampler
|
extra_params_kwargs['noise_sampler'] = noise_sampler
|
||||||
|
|
||||||
|
if self.config.options.get('solver_type', None) == 'heun':
|
||||||
|
extra_params_kwargs['solver_type'] = 'heun'
|
||||||
|
|
||||||
self.last_latent = x
|
self.last_latent = x
|
||||||
samples = self.launch_sampling(steps, lambda: self.func(self.model_wrap_cfg, x, extra_args={
|
self.sampler_extra_args = {
|
||||||
'cond': conditioning,
|
'cond': conditioning,
|
||||||
'image_cond': image_conditioning,
|
'image_cond': image_conditioning,
|
||||||
'uncond': unconditional_conditioning,
|
'uncond': unconditional_conditioning,
|
||||||
'cond_scale': p.cfg_scale,
|
'cond_scale': p.cfg_scale,
|
||||||
's_min_uncond': self.s_min_uncond
|
's_min_uncond': self.s_min_uncond
|
||||||
}, disable=False, callback=self.callback_state, **extra_params_kwargs))
|
}
|
||||||
|
|
||||||
|
samples = self.launch_sampling(steps, lambda: self.func(self.model_wrap_cfg, x, extra_args=self.sampler_extra_args, disable=False, callback=self.callback_state, **extra_params_kwargs))
|
||||||
|
|
||||||
if self.model_wrap_cfg.padded_cond_uncond:
|
if self.model_wrap_cfg.padded_cond_uncond:
|
||||||
p.extra_generation_params["Pad conds"] = True
|
p.extra_generation_params["Pad conds"] = True
|
||||||
|
|
||||||
return samples
|
return samples
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@@ -0,0 +1,167 @@
|
|||||||
|
import torch
|
||||||
|
import inspect
|
||||||
|
import sys
|
||||||
|
from modules import devices, sd_samplers_common, sd_samplers_timesteps_impl
|
||||||
|
from modules.sd_samplers_cfg_denoiser import CFGDenoiser
|
||||||
|
from modules.script_callbacks import ExtraNoiseParams, extra_noise_callback
|
||||||
|
|
||||||
|
from modules.shared import opts
|
||||||
|
import modules.shared as shared
|
||||||
|
|
||||||
|
samplers_timesteps = [
|
||||||
|
('DDIM', sd_samplers_timesteps_impl.ddim, ['ddim'], {}),
|
||||||
|
('PLMS', sd_samplers_timesteps_impl.plms, ['plms'], {}),
|
||||||
|
('UniPC', sd_samplers_timesteps_impl.unipc, ['unipc'], {}),
|
||||||
|
]
|
||||||
|
|
||||||
|
|
||||||
|
samplers_data_timesteps = [
|
||||||
|
sd_samplers_common.SamplerData(label, lambda model, funcname=funcname: CompVisSampler(funcname, model), aliases, options)
|
||||||
|
for label, funcname, aliases, options in samplers_timesteps
|
||||||
|
]
|
||||||
|
|
||||||
|
|
||||||
|
class CompVisTimestepsDenoiser(torch.nn.Module):
|
||||||
|
def __init__(self, model, *args, **kwargs):
|
||||||
|
super().__init__(*args, **kwargs)
|
||||||
|
self.inner_model = model
|
||||||
|
|
||||||
|
def forward(self, input, timesteps, **kwargs):
|
||||||
|
return self.inner_model.apply_model(input, timesteps, **kwargs)
|
||||||
|
|
||||||
|
|
||||||
|
class CompVisTimestepsVDenoiser(torch.nn.Module):
|
||||||
|
def __init__(self, model, *args, **kwargs):
|
||||||
|
super().__init__(*args, **kwargs)
|
||||||
|
self.inner_model = model
|
||||||
|
|
||||||
|
def predict_eps_from_z_and_v(self, x_t, t, v):
|
||||||
|
return self.inner_model.sqrt_alphas_cumprod[t.to(torch.int), None, None, None] * v + self.inner_model.sqrt_one_minus_alphas_cumprod[t.to(torch.int), None, None, None] * x_t
|
||||||
|
|
||||||
|
def forward(self, input, timesteps, **kwargs):
|
||||||
|
model_output = self.inner_model.apply_model(input, timesteps, **kwargs)
|
||||||
|
e_t = self.predict_eps_from_z_and_v(input, timesteps, model_output)
|
||||||
|
return e_t
|
||||||
|
|
||||||
|
|
||||||
|
class CFGDenoiserTimesteps(CFGDenoiser):
|
||||||
|
|
||||||
|
def __init__(self, sampler):
|
||||||
|
super().__init__(sampler)
|
||||||
|
|
||||||
|
self.alphas = shared.sd_model.alphas_cumprod
|
||||||
|
self.mask_before_denoising = True
|
||||||
|
|
||||||
|
def get_pred_x0(self, x_in, x_out, sigma):
|
||||||
|
ts = sigma.to(dtype=int)
|
||||||
|
|
||||||
|
a_t = self.alphas[ts][:, None, None, None]
|
||||||
|
sqrt_one_minus_at = (1 - a_t).sqrt()
|
||||||
|
|
||||||
|
pred_x0 = (x_in - sqrt_one_minus_at * x_out) / a_t.sqrt()
|
||||||
|
|
||||||
|
return pred_x0
|
||||||
|
|
||||||
|
@property
|
||||||
|
def inner_model(self):
|
||||||
|
if self.model_wrap is None:
|
||||||
|
denoiser = CompVisTimestepsVDenoiser if shared.sd_model.parameterization == "v" else CompVisTimestepsDenoiser
|
||||||
|
self.model_wrap = denoiser(shared.sd_model)
|
||||||
|
|
||||||
|
return self.model_wrap
|
||||||
|
|
||||||
|
|
||||||
|
class CompVisSampler(sd_samplers_common.Sampler):
|
||||||
|
def __init__(self, funcname, sd_model):
|
||||||
|
super().__init__(funcname)
|
||||||
|
|
||||||
|
self.eta_option_field = 'eta_ddim'
|
||||||
|
self.eta_infotext_field = 'Eta DDIM'
|
||||||
|
self.eta_default = 0.0
|
||||||
|
|
||||||
|
self.model_wrap_cfg = CFGDenoiserTimesteps(self)
|
||||||
|
|
||||||
|
def get_timesteps(self, p, steps):
|
||||||
|
discard_next_to_last_sigma = self.config is not None and self.config.options.get('discard_next_to_last_sigma', False)
|
||||||
|
if opts.always_discard_next_to_last_sigma and not discard_next_to_last_sigma:
|
||||||
|
discard_next_to_last_sigma = True
|
||||||
|
p.extra_generation_params["Discard penultimate sigma"] = True
|
||||||
|
|
||||||
|
steps += 1 if discard_next_to_last_sigma else 0
|
||||||
|
|
||||||
|
timesteps = torch.clip(torch.asarray(list(range(0, 1000, 1000 // steps)), device=devices.device) + 1, 0, 999)
|
||||||
|
|
||||||
|
return timesteps
|
||||||
|
|
||||||
|
def sample_img2img(self, p, x, noise, conditioning, unconditional_conditioning, steps=None, image_conditioning=None):
|
||||||
|
steps, t_enc = sd_samplers_common.setup_img2img_steps(p, steps)
|
||||||
|
|
||||||
|
timesteps = self.get_timesteps(p, steps)
|
||||||
|
timesteps_sched = timesteps[:t_enc]
|
||||||
|
|
||||||
|
alphas_cumprod = shared.sd_model.alphas_cumprod
|
||||||
|
sqrt_alpha_cumprod = torch.sqrt(alphas_cumprod[timesteps[t_enc]])
|
||||||
|
sqrt_one_minus_alpha_cumprod = torch.sqrt(1 - alphas_cumprod[timesteps[t_enc]])
|
||||||
|
|
||||||
|
xi = x * sqrt_alpha_cumprod + noise * sqrt_one_minus_alpha_cumprod
|
||||||
|
|
||||||
|
if opts.img2img_extra_noise > 0:
|
||||||
|
p.extra_generation_params["Extra noise"] = opts.img2img_extra_noise
|
||||||
|
extra_noise_params = ExtraNoiseParams(noise, x)
|
||||||
|
extra_noise_callback(extra_noise_params)
|
||||||
|
noise = extra_noise_params.noise
|
||||||
|
xi += noise * opts.img2img_extra_noise * sqrt_alpha_cumprod
|
||||||
|
|
||||||
|
extra_params_kwargs = self.initialize(p)
|
||||||
|
parameters = inspect.signature(self.func).parameters
|
||||||
|
|
||||||
|
if 'timesteps' in parameters:
|
||||||
|
extra_params_kwargs['timesteps'] = timesteps_sched
|
||||||
|
if 'is_img2img' in parameters:
|
||||||
|
extra_params_kwargs['is_img2img'] = True
|
||||||
|
|
||||||
|
self.model_wrap_cfg.init_latent = x
|
||||||
|
self.last_latent = x
|
||||||
|
self.sampler_extra_args = {
|
||||||
|
'cond': conditioning,
|
||||||
|
'image_cond': image_conditioning,
|
||||||
|
'uncond': unconditional_conditioning,
|
||||||
|
'cond_scale': p.cfg_scale,
|
||||||
|
's_min_uncond': self.s_min_uncond
|
||||||
|
}
|
||||||
|
|
||||||
|
samples = self.launch_sampling(t_enc + 1, lambda: self.func(self.model_wrap_cfg, xi, extra_args=self.sampler_extra_args, disable=False, callback=self.callback_state, **extra_params_kwargs))
|
||||||
|
|
||||||
|
if self.model_wrap_cfg.padded_cond_uncond:
|
||||||
|
p.extra_generation_params["Pad conds"] = True
|
||||||
|
|
||||||
|
return samples
|
||||||
|
|
||||||
|
def sample(self, p, x, conditioning, unconditional_conditioning, steps=None, image_conditioning=None):
|
||||||
|
steps = steps or p.steps
|
||||||
|
timesteps = self.get_timesteps(p, steps)
|
||||||
|
|
||||||
|
extra_params_kwargs = self.initialize(p)
|
||||||
|
parameters = inspect.signature(self.func).parameters
|
||||||
|
|
||||||
|
if 'timesteps' in parameters:
|
||||||
|
extra_params_kwargs['timesteps'] = timesteps
|
||||||
|
|
||||||
|
self.last_latent = x
|
||||||
|
self.sampler_extra_args = {
|
||||||
|
'cond': conditioning,
|
||||||
|
'image_cond': image_conditioning,
|
||||||
|
'uncond': unconditional_conditioning,
|
||||||
|
'cond_scale': p.cfg_scale,
|
||||||
|
's_min_uncond': self.s_min_uncond
|
||||||
|
}
|
||||||
|
samples = self.launch_sampling(steps, lambda: self.func(self.model_wrap_cfg, x, extra_args=self.sampler_extra_args, disable=False, callback=self.callback_state, **extra_params_kwargs))
|
||||||
|
|
||||||
|
if self.model_wrap_cfg.padded_cond_uncond:
|
||||||
|
p.extra_generation_params["Pad conds"] = True
|
||||||
|
|
||||||
|
return samples
|
||||||
|
|
||||||
|
|
||||||
|
sys.modules['modules.sd_samplers_compvis'] = sys.modules[__name__]
|
||||||
|
VanillaStableDiffusionSampler = CompVisSampler # temp. compatibility with older extensions
|
||||||
@@ -0,0 +1,137 @@
|
|||||||
|
import torch
|
||||||
|
import tqdm
|
||||||
|
import k_diffusion.sampling
|
||||||
|
import numpy as np
|
||||||
|
|
||||||
|
from modules import shared
|
||||||
|
from modules.models.diffusion.uni_pc import uni_pc
|
||||||
|
|
||||||
|
|
||||||
|
@torch.no_grad()
|
||||||
|
def ddim(model, x, timesteps, extra_args=None, callback=None, disable=None, eta=0.0):
|
||||||
|
alphas_cumprod = model.inner_model.inner_model.alphas_cumprod
|
||||||
|
alphas = alphas_cumprod[timesteps]
|
||||||
|
alphas_prev = alphas_cumprod[torch.nn.functional.pad(timesteps[:-1], pad=(1, 0))].to(torch.float64 if x.device.type != 'mps' else torch.float32)
|
||||||
|
sqrt_one_minus_alphas = torch.sqrt(1 - alphas)
|
||||||
|
sigmas = eta * np.sqrt((1 - alphas_prev.cpu().numpy()) / (1 - alphas.cpu()) * (1 - alphas.cpu() / alphas_prev.cpu().numpy()))
|
||||||
|
|
||||||
|
extra_args = {} if extra_args is None else extra_args
|
||||||
|
s_in = x.new_ones((x.shape[0]))
|
||||||
|
s_x = x.new_ones((x.shape[0], 1, 1, 1))
|
||||||
|
for i in tqdm.trange(len(timesteps) - 1, disable=disable):
|
||||||
|
index = len(timesteps) - 1 - i
|
||||||
|
|
||||||
|
e_t = model(x, timesteps[index].item() * s_in, **extra_args)
|
||||||
|
|
||||||
|
a_t = alphas[index].item() * s_x
|
||||||
|
a_prev = alphas_prev[index].item() * s_x
|
||||||
|
sigma_t = sigmas[index].item() * s_x
|
||||||
|
sqrt_one_minus_at = sqrt_one_minus_alphas[index].item() * s_x
|
||||||
|
|
||||||
|
pred_x0 = (x - sqrt_one_minus_at * e_t) / a_t.sqrt()
|
||||||
|
dir_xt = (1. - a_prev - sigma_t ** 2).sqrt() * e_t
|
||||||
|
noise = sigma_t * k_diffusion.sampling.torch.randn_like(x)
|
||||||
|
x = a_prev.sqrt() * pred_x0 + dir_xt + noise
|
||||||
|
|
||||||
|
if callback is not None:
|
||||||
|
callback({'x': x, 'i': i, 'sigma': 0, 'sigma_hat': 0, 'denoised': pred_x0})
|
||||||
|
|
||||||
|
return x
|
||||||
|
|
||||||
|
|
||||||
|
@torch.no_grad()
|
||||||
|
def plms(model, x, timesteps, extra_args=None, callback=None, disable=None):
|
||||||
|
alphas_cumprod = model.inner_model.inner_model.alphas_cumprod
|
||||||
|
alphas = alphas_cumprod[timesteps]
|
||||||
|
alphas_prev = alphas_cumprod[torch.nn.functional.pad(timesteps[:-1], pad=(1, 0))].to(torch.float64 if x.device.type != 'mps' else torch.float32)
|
||||||
|
sqrt_one_minus_alphas = torch.sqrt(1 - alphas)
|
||||||
|
|
||||||
|
extra_args = {} if extra_args is None else extra_args
|
||||||
|
s_in = x.new_ones([x.shape[0]])
|
||||||
|
s_x = x.new_ones((x.shape[0], 1, 1, 1))
|
||||||
|
old_eps = []
|
||||||
|
|
||||||
|
def get_x_prev_and_pred_x0(e_t, index):
|
||||||
|
# select parameters corresponding to the currently considered timestep
|
||||||
|
a_t = alphas[index].item() * s_x
|
||||||
|
a_prev = alphas_prev[index].item() * s_x
|
||||||
|
sqrt_one_minus_at = sqrt_one_minus_alphas[index].item() * s_x
|
||||||
|
|
||||||
|
# current prediction for x_0
|
||||||
|
pred_x0 = (x - sqrt_one_minus_at * e_t) / a_t.sqrt()
|
||||||
|
|
||||||
|
# direction pointing to x_t
|
||||||
|
dir_xt = (1. - a_prev).sqrt() * e_t
|
||||||
|
x_prev = a_prev.sqrt() * pred_x0 + dir_xt
|
||||||
|
return x_prev, pred_x0
|
||||||
|
|
||||||
|
for i in tqdm.trange(len(timesteps) - 1, disable=disable):
|
||||||
|
index = len(timesteps) - 1 - i
|
||||||
|
ts = timesteps[index].item() * s_in
|
||||||
|
t_next = timesteps[max(index - 1, 0)].item() * s_in
|
||||||
|
|
||||||
|
e_t = model(x, ts, **extra_args)
|
||||||
|
|
||||||
|
if len(old_eps) == 0:
|
||||||
|
# Pseudo Improved Euler (2nd order)
|
||||||
|
x_prev, pred_x0 = get_x_prev_and_pred_x0(e_t, index)
|
||||||
|
e_t_next = model(x_prev, t_next, **extra_args)
|
||||||
|
e_t_prime = (e_t + e_t_next) / 2
|
||||||
|
elif len(old_eps) == 1:
|
||||||
|
# 2nd order Pseudo Linear Multistep (Adams-Bashforth)
|
||||||
|
e_t_prime = (3 * e_t - old_eps[-1]) / 2
|
||||||
|
elif len(old_eps) == 2:
|
||||||
|
# 3nd order Pseudo Linear Multistep (Adams-Bashforth)
|
||||||
|
e_t_prime = (23 * e_t - 16 * old_eps[-1] + 5 * old_eps[-2]) / 12
|
||||||
|
else:
|
||||||
|
# 4nd order Pseudo Linear Multistep (Adams-Bashforth)
|
||||||
|
e_t_prime = (55 * e_t - 59 * old_eps[-1] + 37 * old_eps[-2] - 9 * old_eps[-3]) / 24
|
||||||
|
|
||||||
|
x_prev, pred_x0 = get_x_prev_and_pred_x0(e_t_prime, index)
|
||||||
|
|
||||||
|
old_eps.append(e_t)
|
||||||
|
if len(old_eps) >= 4:
|
||||||
|
old_eps.pop(0)
|
||||||
|
|
||||||
|
x = x_prev
|
||||||
|
|
||||||
|
if callback is not None:
|
||||||
|
callback({'x': x, 'i': i, 'sigma': 0, 'sigma_hat': 0, 'denoised': pred_x0})
|
||||||
|
|
||||||
|
return x
|
||||||
|
|
||||||
|
|
||||||
|
class UniPCCFG(uni_pc.UniPC):
|
||||||
|
def __init__(self, cfg_model, extra_args, callback, *args, **kwargs):
|
||||||
|
super().__init__(None, *args, **kwargs)
|
||||||
|
|
||||||
|
def after_update(x, model_x):
|
||||||
|
callback({'x': x, 'i': self.index, 'sigma': 0, 'sigma_hat': 0, 'denoised': model_x})
|
||||||
|
self.index += 1
|
||||||
|
|
||||||
|
self.cfg_model = cfg_model
|
||||||
|
self.extra_args = extra_args
|
||||||
|
self.callback = callback
|
||||||
|
self.index = 0
|
||||||
|
self.after_update = after_update
|
||||||
|
|
||||||
|
def get_model_input_time(self, t_continuous):
|
||||||
|
return (t_continuous - 1. / self.noise_schedule.total_N) * 1000.
|
||||||
|
|
||||||
|
def model(self, x, t):
|
||||||
|
t_input = self.get_model_input_time(t)
|
||||||
|
|
||||||
|
res = self.cfg_model(x, t_input, **self.extra_args)
|
||||||
|
|
||||||
|
return res
|
||||||
|
|
||||||
|
|
||||||
|
def unipc(model, x, timesteps, extra_args=None, callback=None, disable=None, is_img2img=False):
|
||||||
|
alphas_cumprod = model.inner_model.inner_model.alphas_cumprod
|
||||||
|
|
||||||
|
ns = uni_pc.NoiseScheduleVP('discrete', alphas_cumprod=alphas_cumprod)
|
||||||
|
t_start = timesteps[-1] / 1000 + 1 / 1000 if is_img2img else None # this is likely off by a bit - if someone wants to fix it please by all means
|
||||||
|
unipc_sampler = UniPCCFG(model, extra_args, callback, ns, predict_x0=True, thresholding=False, variant=shared.opts.uni_pc_variant)
|
||||||
|
x = unipc_sampler.sample(x, steps=len(timesteps), t_start=t_start, skip_type=shared.opts.uni_pc_skip_type, method="multistep", order=shared.opts.uni_pc_order, lower_order_final=shared.opts.uni_pc_lower_order_final)
|
||||||
|
|
||||||
|
return x
|
||||||
+1
-1
@@ -47,7 +47,7 @@ def apply_unet(option=None):
|
|||||||
if current_unet_option is None:
|
if current_unet_option is None:
|
||||||
current_unet = None
|
current_unet = None
|
||||||
|
|
||||||
if not (shared.cmd_opts.lowvram or shared.cmd_opts.medvram):
|
if not shared.sd_model.lowvram:
|
||||||
shared.sd_model.model.diffusion_model.to(devices.device)
|
shared.sd_model.model.diffusion_model.to(devices.device)
|
||||||
|
|
||||||
return
|
return
|
||||||
|
|||||||
+80
-25
@@ -1,6 +1,9 @@
|
|||||||
import os
|
import os
|
||||||
import collections
|
import collections
|
||||||
from modules import paths, shared, devices, script_callbacks, sd_models, extra_networks
|
from dataclasses import dataclass
|
||||||
|
|
||||||
|
from modules import paths, shared, devices, script_callbacks, sd_models, extra_networks, lowvram, sd_hijack, hashes
|
||||||
|
|
||||||
import glob
|
import glob
|
||||||
from copy import deepcopy
|
from copy import deepcopy
|
||||||
|
|
||||||
@@ -17,6 +20,22 @@ checkpoint_info = None
|
|||||||
checkpoints_loaded = collections.OrderedDict()
|
checkpoints_loaded = collections.OrderedDict()
|
||||||
|
|
||||||
|
|
||||||
|
def get_loaded_vae_name():
|
||||||
|
if loaded_vae_file is None:
|
||||||
|
return None
|
||||||
|
|
||||||
|
return os.path.basename(loaded_vae_file)
|
||||||
|
|
||||||
|
|
||||||
|
def get_loaded_vae_hash():
|
||||||
|
if loaded_vae_file is None:
|
||||||
|
return None
|
||||||
|
|
||||||
|
sha256 = hashes.sha256(loaded_vae_file, 'vae')
|
||||||
|
|
||||||
|
return sha256[0:10] if sha256 else None
|
||||||
|
|
||||||
|
|
||||||
def get_base_vae(model):
|
def get_base_vae(model):
|
||||||
if base_vae is not None and checkpoint_info == model.sd_checkpoint_info and model:
|
if base_vae is not None and checkpoint_info == model.sd_checkpoint_info and model:
|
||||||
return base_vae
|
return base_vae
|
||||||
@@ -51,7 +70,6 @@ def get_filename(filepath):
|
|||||||
|
|
||||||
|
|
||||||
def refresh_vae_list():
|
def refresh_vae_list():
|
||||||
global vae_dict
|
|
||||||
vae_dict.clear()
|
vae_dict.clear()
|
||||||
|
|
||||||
paths = [
|
paths = [
|
||||||
@@ -85,7 +103,7 @@ def refresh_vae_list():
|
|||||||
name = get_filename(filepath)
|
name = get_filename(filepath)
|
||||||
vae_dict[name] = filepath
|
vae_dict[name] = filepath
|
||||||
|
|
||||||
vae_dict = dict(sorted(vae_dict.items(), key=lambda item: shared.natural_sort_key(item[0])))
|
vae_dict.update(dict(sorted(vae_dict.items(), key=lambda item: shared.natural_sort_key(item[0]))))
|
||||||
|
|
||||||
|
|
||||||
def find_vae_near_checkpoint(checkpoint_file):
|
def find_vae_near_checkpoint(checkpoint_file):
|
||||||
@@ -97,37 +115,74 @@ def find_vae_near_checkpoint(checkpoint_file):
|
|||||||
return None
|
return None
|
||||||
|
|
||||||
|
|
||||||
def resolve_vae(checkpoint_file):
|
@dataclass
|
||||||
if shared.cmd_opts.vae_path is not None:
|
class VaeResolution:
|
||||||
return shared.cmd_opts.vae_path, 'from commandline argument'
|
vae: str = None
|
||||||
|
source: str = None
|
||||||
|
resolved: bool = True
|
||||||
|
|
||||||
|
def tuple(self):
|
||||||
|
return self.vae, self.source
|
||||||
|
|
||||||
|
|
||||||
|
def is_automatic():
|
||||||
|
return shared.opts.sd_vae in {"Automatic", "auto"} # "auto" for people with old config
|
||||||
|
|
||||||
|
|
||||||
|
def resolve_vae_from_setting() -> VaeResolution:
|
||||||
|
if shared.opts.sd_vae == "None":
|
||||||
|
return VaeResolution()
|
||||||
|
|
||||||
|
vae_from_options = vae_dict.get(shared.opts.sd_vae, None)
|
||||||
|
if vae_from_options is not None:
|
||||||
|
return VaeResolution(vae_from_options, 'specified in settings')
|
||||||
|
|
||||||
|
if not is_automatic():
|
||||||
|
print(f"Couldn't find VAE named {shared.opts.sd_vae}; using None instead")
|
||||||
|
|
||||||
|
return VaeResolution(resolved=False)
|
||||||
|
|
||||||
|
|
||||||
|
def resolve_vae_from_user_metadata(checkpoint_file) -> VaeResolution:
|
||||||
metadata = extra_networks.get_user_metadata(checkpoint_file)
|
metadata = extra_networks.get_user_metadata(checkpoint_file)
|
||||||
vae_metadata = metadata.get("vae", None)
|
vae_metadata = metadata.get("vae", None)
|
||||||
if vae_metadata is not None and vae_metadata != "Automatic":
|
if vae_metadata is not None and vae_metadata != "Automatic":
|
||||||
if vae_metadata == "None":
|
if vae_metadata == "None":
|
||||||
return None, None
|
return VaeResolution()
|
||||||
|
|
||||||
vae_from_metadata = vae_dict.get(vae_metadata, None)
|
vae_from_metadata = vae_dict.get(vae_metadata, None)
|
||||||
if vae_from_metadata is not None:
|
if vae_from_metadata is not None:
|
||||||
return vae_from_metadata, "from user metadata"
|
return VaeResolution(vae_from_metadata, "from user metadata")
|
||||||
|
|
||||||
is_automatic = shared.opts.sd_vae in {"Automatic", "auto"} # "auto" for people with old config
|
return VaeResolution(resolved=False)
|
||||||
|
|
||||||
|
|
||||||
|
def resolve_vae_near_checkpoint(checkpoint_file) -> VaeResolution:
|
||||||
vae_near_checkpoint = find_vae_near_checkpoint(checkpoint_file)
|
vae_near_checkpoint = find_vae_near_checkpoint(checkpoint_file)
|
||||||
if vae_near_checkpoint is not None and (shared.opts.sd_vae_as_default or is_automatic):
|
if vae_near_checkpoint is not None and (not shared.opts.sd_vae_overrides_per_model_preferences or is_automatic):
|
||||||
return vae_near_checkpoint, 'found near the checkpoint'
|
return VaeResolution(vae_near_checkpoint, 'found near the checkpoint')
|
||||||
|
|
||||||
if shared.opts.sd_vae == "None":
|
return VaeResolution(resolved=False)
|
||||||
return None, None
|
|
||||||
|
|
||||||
vae_from_options = vae_dict.get(shared.opts.sd_vae, None)
|
|
||||||
if vae_from_options is not None:
|
|
||||||
return vae_from_options, 'specified in settings'
|
|
||||||
|
|
||||||
if not is_automatic:
|
def resolve_vae(checkpoint_file) -> VaeResolution:
|
||||||
print(f"Couldn't find VAE named {shared.opts.sd_vae}; using None instead")
|
if shared.cmd_opts.vae_path is not None:
|
||||||
|
return VaeResolution(shared.cmd_opts.vae_path, 'from commandline argument')
|
||||||
|
|
||||||
return None, None
|
if shared.opts.sd_vae_overrides_per_model_preferences and not is_automatic():
|
||||||
|
return resolve_vae_from_setting()
|
||||||
|
|
||||||
|
res = resolve_vae_from_user_metadata(checkpoint_file)
|
||||||
|
if res.resolved:
|
||||||
|
return res
|
||||||
|
|
||||||
|
res = resolve_vae_near_checkpoint(checkpoint_file)
|
||||||
|
if res.resolved:
|
||||||
|
return res
|
||||||
|
|
||||||
|
res = resolve_vae_from_setting()
|
||||||
|
|
||||||
|
return res
|
||||||
|
|
||||||
|
|
||||||
def load_vae_dict(filename, map_location):
|
def load_vae_dict(filename, map_location):
|
||||||
@@ -137,7 +192,7 @@ def load_vae_dict(filename, map_location):
|
|||||||
|
|
||||||
|
|
||||||
def load_vae(model, vae_file=None, vae_source="from unknown source"):
|
def load_vae(model, vae_file=None, vae_source="from unknown source"):
|
||||||
global vae_dict, loaded_vae_file
|
global vae_dict, base_vae, loaded_vae_file
|
||||||
# save_settings = False
|
# save_settings = False
|
||||||
|
|
||||||
cache_enabled = shared.opts.sd_vae_checkpoint_cache > 0
|
cache_enabled = shared.opts.sd_vae_checkpoint_cache > 0
|
||||||
@@ -175,6 +230,8 @@ def load_vae(model, vae_file=None, vae_source="from unknown source"):
|
|||||||
restore_base_vae(model)
|
restore_base_vae(model)
|
||||||
|
|
||||||
loaded_vae_file = vae_file
|
loaded_vae_file = vae_file
|
||||||
|
model.base_vae = base_vae
|
||||||
|
model.loaded_vae_file = loaded_vae_file
|
||||||
|
|
||||||
|
|
||||||
# don't call this from outside
|
# don't call this from outside
|
||||||
@@ -192,8 +249,6 @@ unspecified = object()
|
|||||||
|
|
||||||
|
|
||||||
def reload_vae_weights(sd_model=None, vae_file=unspecified):
|
def reload_vae_weights(sd_model=None, vae_file=unspecified):
|
||||||
from modules import lowvram, devices, sd_hijack
|
|
||||||
|
|
||||||
if not sd_model:
|
if not sd_model:
|
||||||
sd_model = shared.sd_model
|
sd_model = shared.sd_model
|
||||||
|
|
||||||
@@ -201,14 +256,14 @@ def reload_vae_weights(sd_model=None, vae_file=unspecified):
|
|||||||
checkpoint_file = checkpoint_info.filename
|
checkpoint_file = checkpoint_info.filename
|
||||||
|
|
||||||
if vae_file == unspecified:
|
if vae_file == unspecified:
|
||||||
vae_file, vae_source = resolve_vae(checkpoint_file)
|
vae_file, vae_source = resolve_vae(checkpoint_file).tuple()
|
||||||
else:
|
else:
|
||||||
vae_source = "from function argument"
|
vae_source = "from function argument"
|
||||||
|
|
||||||
if loaded_vae_file == vae_file:
|
if loaded_vae_file == vae_file:
|
||||||
return
|
return
|
||||||
|
|
||||||
if shared.cmd_opts.lowvram or shared.cmd_opts.medvram:
|
if sd_model.lowvram:
|
||||||
lowvram.send_everything_to_cpu()
|
lowvram.send_everything_to_cpu()
|
||||||
else:
|
else:
|
||||||
sd_model.to(devices.cpu)
|
sd_model.to(devices.cpu)
|
||||||
@@ -220,7 +275,7 @@ def reload_vae_weights(sd_model=None, vae_file=unspecified):
|
|||||||
sd_hijack.model_hijack.hijack(sd_model)
|
sd_hijack.model_hijack.hijack(sd_model)
|
||||||
script_callbacks.model_loaded_callback(sd_model)
|
script_callbacks.model_loaded_callback(sd_model)
|
||||||
|
|
||||||
if not shared.cmd_opts.lowvram and not shared.cmd_opts.medvram:
|
if not sd_model.lowvram:
|
||||||
sd_model.to(devices.device)
|
sd_model.to(devices.device)
|
||||||
|
|
||||||
print("VAE weights loaded.")
|
print("VAE weights loaded.")
|
||||||
|
|||||||
+38
-922
@@ -1,838 +1,51 @@
|
|||||||
import datetime
|
|
||||||
import json
|
|
||||||
import os
|
|
||||||
import re
|
|
||||||
import sys
|
import sys
|
||||||
import threading
|
|
||||||
import time
|
|
||||||
import logging
|
|
||||||
|
|
||||||
import gradio as gr
|
import gradio as gr
|
||||||
import torch
|
|
||||||
import tqdm
|
|
||||||
|
|
||||||
import launch
|
from modules import shared_cmd_options, shared_gradio_themes, options, shared_items, sd_models_types
|
||||||
import modules.interrogate
|
|
||||||
import modules.memmon
|
|
||||||
import modules.styles
|
|
||||||
import modules.devices as devices
|
|
||||||
from modules import localization, script_loading, errors, ui_components, shared_items, cmd_args
|
|
||||||
from modules.paths_internal import models_path, script_path, data_path, sd_configs_path, sd_default_config, sd_model_file, default_sd_model_file, extensions_dir, extensions_builtin_dir # noqa: F401
|
from modules.paths_internal import models_path, script_path, data_path, sd_configs_path, sd_default_config, sd_model_file, default_sd_model_file, extensions_dir, extensions_builtin_dir # noqa: F401
|
||||||
from ldm.models.diffusion.ddpm import LatentDiffusion
|
from modules import util
|
||||||
from typing import Optional
|
|
||||||
|
|
||||||
log = logging.getLogger(__name__)
|
cmd_opts = shared_cmd_options.cmd_opts
|
||||||
|
parser = shared_cmd_options.parser
|
||||||
|
|
||||||
|
batch_cond_uncond = True # old field, unused now in favor of shared.opts.batch_cond_uncond
|
||||||
|
parallel_processing_allowed = True
|
||||||
|
styles_filename = cmd_opts.styles_file
|
||||||
|
config_filename = cmd_opts.ui_settings_file
|
||||||
|
hide_dirs = {"visible": not cmd_opts.hide_ui_dir_config}
|
||||||
|
|
||||||
demo = None
|
demo = None
|
||||||
|
|
||||||
parser = cmd_args.parser
|
device = None
|
||||||
|
|
||||||
script_loading.preload_extensions(extensions_dir, parser, extension_list=launch.list_extensions(launch.args.ui_settings_file))
|
weight_load_location = None
|
||||||
script_loading.preload_extensions(extensions_builtin_dir, parser)
|
|
||||||
|
|
||||||
if os.environ.get('IGNORE_CMD_ARGS_ERRORS', None) is None:
|
|
||||||
cmd_opts = parser.parse_args()
|
|
||||||
else:
|
|
||||||
cmd_opts, _ = parser.parse_known_args()
|
|
||||||
|
|
||||||
|
|
||||||
restricted_opts = {
|
|
||||||
"samples_filename_pattern",
|
|
||||||
"directories_filename_pattern",
|
|
||||||
"outdir_samples",
|
|
||||||
"outdir_txt2img_samples",
|
|
||||||
"outdir_img2img_samples",
|
|
||||||
"outdir_extras_samples",
|
|
||||||
"outdir_grids",
|
|
||||||
"outdir_txt2img_grids",
|
|
||||||
"outdir_save",
|
|
||||||
"outdir_init_images"
|
|
||||||
}
|
|
||||||
|
|
||||||
# https://huggingface.co/datasets/freddyaboulton/gradio-theme-subdomains/resolve/main/subdomains.json
|
|
||||||
gradio_hf_hub_themes = [
|
|
||||||
"gradio/base",
|
|
||||||
"gradio/glass",
|
|
||||||
"gradio/monochrome",
|
|
||||||
"gradio/seafoam",
|
|
||||||
"gradio/soft",
|
|
||||||
"gradio/dracula_test",
|
|
||||||
"abidlabs/dracula_test",
|
|
||||||
"abidlabs/Lime",
|
|
||||||
"abidlabs/pakistan",
|
|
||||||
"Ama434/neutral-barlow",
|
|
||||||
"dawood/microsoft_windows",
|
|
||||||
"finlaymacklon/smooth_slate",
|
|
||||||
"Franklisi/darkmode",
|
|
||||||
"freddyaboulton/dracula_revamped",
|
|
||||||
"freddyaboulton/test-blue",
|
|
||||||
"gstaff/xkcd",
|
|
||||||
"Insuz/Mocha",
|
|
||||||
"Insuz/SimpleIndigo",
|
|
||||||
"JohnSmith9982/small_and_pretty",
|
|
||||||
"nota-ai/theme",
|
|
||||||
"nuttea/Softblue",
|
|
||||||
"ParityError/Anime",
|
|
||||||
"reilnuud/polite",
|
|
||||||
"remilia/Ghostly",
|
|
||||||
"rottenlittlecreature/Moon_Goblin",
|
|
||||||
"step-3-profit/Midnight-Deep",
|
|
||||||
"Taithrah/Minimal",
|
|
||||||
"ysharma/huggingface",
|
|
||||||
"ysharma/steampunk"
|
|
||||||
]
|
|
||||||
|
|
||||||
|
|
||||||
cmd_opts.disable_extension_access = (cmd_opts.share or cmd_opts.listen or cmd_opts.server_name) and not cmd_opts.enable_insecure_extension_access
|
|
||||||
|
|
||||||
devices.device, devices.device_interrogate, devices.device_gfpgan, devices.device_esrgan, devices.device_codeformer = \
|
|
||||||
(devices.cpu if any(y in cmd_opts.use_cpu for y in [x, 'all']) else devices.get_optimal_device() for x in ['sd', 'interrogate', 'gfpgan', 'esrgan', 'codeformer'])
|
|
||||||
|
|
||||||
devices.dtype = torch.float32 if cmd_opts.no_half else torch.float16
|
|
||||||
devices.dtype_vae = torch.float32 if cmd_opts.no_half or cmd_opts.no_half_vae else torch.float16
|
|
||||||
|
|
||||||
device = devices.device
|
|
||||||
weight_load_location = None if cmd_opts.lowram else "cpu"
|
|
||||||
|
|
||||||
batch_cond_uncond = cmd_opts.always_batch_cond_uncond or not (cmd_opts.lowvram or cmd_opts.medvram)
|
|
||||||
parallel_processing_allowed = not cmd_opts.lowvram and not cmd_opts.medvram
|
|
||||||
xformers_available = False
|
xformers_available = False
|
||||||
config_filename = cmd_opts.ui_settings_file
|
|
||||||
|
|
||||||
os.makedirs(cmd_opts.hypernetwork_dir, exist_ok=True)
|
|
||||||
hypernetworks = {}
|
hypernetworks = {}
|
||||||
|
|
||||||
loaded_hypernetworks = []
|
loaded_hypernetworks = []
|
||||||
|
|
||||||
|
state = None
|
||||||
|
|
||||||
def reload_hypernetworks():
|
prompt_styles = None
|
||||||
from modules.hypernetworks import hypernetwork
|
|
||||||
global hypernetworks
|
|
||||||
|
|
||||||
hypernetworks = hypernetwork.list_hypernetworks(cmd_opts.hypernetwork_dir)
|
interrogator = None
|
||||||
|
|
||||||
|
|
||||||
class State:
|
|
||||||
skipped = False
|
|
||||||
interrupted = False
|
|
||||||
job = ""
|
|
||||||
job_no = 0
|
|
||||||
job_count = 0
|
|
||||||
processing_has_refined_job_count = False
|
|
||||||
job_timestamp = '0'
|
|
||||||
sampling_step = 0
|
|
||||||
sampling_steps = 0
|
|
||||||
current_latent = None
|
|
||||||
current_image = None
|
|
||||||
current_image_sampling_step = 0
|
|
||||||
id_live_preview = 0
|
|
||||||
textinfo = None
|
|
||||||
time_start = None
|
|
||||||
server_start = None
|
|
||||||
_server_command_signal = threading.Event()
|
|
||||||
_server_command: Optional[str] = None
|
|
||||||
|
|
||||||
@property
|
|
||||||
def need_restart(self) -> bool:
|
|
||||||
# Compatibility getter for need_restart.
|
|
||||||
return self.server_command == "restart"
|
|
||||||
|
|
||||||
@need_restart.setter
|
|
||||||
def need_restart(self, value: bool) -> None:
|
|
||||||
# Compatibility setter for need_restart.
|
|
||||||
if value:
|
|
||||||
self.server_command = "restart"
|
|
||||||
|
|
||||||
@property
|
|
||||||
def server_command(self):
|
|
||||||
return self._server_command
|
|
||||||
|
|
||||||
@server_command.setter
|
|
||||||
def server_command(self, value: Optional[str]) -> None:
|
|
||||||
"""
|
|
||||||
Set the server command to `value` and signal that it's been set.
|
|
||||||
"""
|
|
||||||
self._server_command = value
|
|
||||||
self._server_command_signal.set()
|
|
||||||
|
|
||||||
def wait_for_server_command(self, timeout: Optional[float] = None) -> Optional[str]:
|
|
||||||
"""
|
|
||||||
Wait for server command to get set; return and clear the value and signal.
|
|
||||||
"""
|
|
||||||
if self._server_command_signal.wait(timeout):
|
|
||||||
self._server_command_signal.clear()
|
|
||||||
req = self._server_command
|
|
||||||
self._server_command = None
|
|
||||||
return req
|
|
||||||
return None
|
|
||||||
|
|
||||||
def request_restart(self) -> None:
|
|
||||||
self.interrupt()
|
|
||||||
self.server_command = "restart"
|
|
||||||
log.info("Received restart request")
|
|
||||||
|
|
||||||
def skip(self):
|
|
||||||
self.skipped = True
|
|
||||||
log.info("Received skip request")
|
|
||||||
|
|
||||||
def interrupt(self):
|
|
||||||
self.interrupted = True
|
|
||||||
log.info("Received interrupt request")
|
|
||||||
|
|
||||||
def nextjob(self):
|
|
||||||
if opts.live_previews_enable and opts.show_progress_every_n_steps == -1:
|
|
||||||
self.do_set_current_image()
|
|
||||||
|
|
||||||
self.job_no += 1
|
|
||||||
self.sampling_step = 0
|
|
||||||
self.current_image_sampling_step = 0
|
|
||||||
|
|
||||||
def dict(self):
|
|
||||||
obj = {
|
|
||||||
"skipped": self.skipped,
|
|
||||||
"interrupted": self.interrupted,
|
|
||||||
"job": self.job,
|
|
||||||
"job_count": self.job_count,
|
|
||||||
"job_timestamp": self.job_timestamp,
|
|
||||||
"job_no": self.job_no,
|
|
||||||
"sampling_step": self.sampling_step,
|
|
||||||
"sampling_steps": self.sampling_steps,
|
|
||||||
}
|
|
||||||
|
|
||||||
return obj
|
|
||||||
|
|
||||||
def begin(self, job: str = "(unknown)"):
|
|
||||||
self.sampling_step = 0
|
|
||||||
self.job_count = -1
|
|
||||||
self.processing_has_refined_job_count = False
|
|
||||||
self.job_no = 0
|
|
||||||
self.job_timestamp = datetime.datetime.now().strftime("%Y%m%d%H%M%S")
|
|
||||||
self.current_latent = None
|
|
||||||
self.current_image = None
|
|
||||||
self.current_image_sampling_step = 0
|
|
||||||
self.id_live_preview = 0
|
|
||||||
self.skipped = False
|
|
||||||
self.interrupted = False
|
|
||||||
self.textinfo = None
|
|
||||||
self.time_start = time.time()
|
|
||||||
self.job = job
|
|
||||||
devices.torch_gc()
|
|
||||||
log.info("Starting job %s", job)
|
|
||||||
|
|
||||||
def end(self):
|
|
||||||
duration = time.time() - self.time_start
|
|
||||||
log.info("Ending job %s (%.2f seconds)", self.job, duration)
|
|
||||||
self.job = ""
|
|
||||||
self.job_count = 0
|
|
||||||
|
|
||||||
devices.torch_gc()
|
|
||||||
|
|
||||||
def set_current_image(self):
|
|
||||||
"""sets self.current_image from self.current_latent if enough sampling steps have been made after the last call to this"""
|
|
||||||
if not parallel_processing_allowed:
|
|
||||||
return
|
|
||||||
|
|
||||||
if self.sampling_step - self.current_image_sampling_step >= opts.show_progress_every_n_steps and opts.live_previews_enable and opts.show_progress_every_n_steps != -1:
|
|
||||||
self.do_set_current_image()
|
|
||||||
|
|
||||||
def do_set_current_image(self):
|
|
||||||
if self.current_latent is None:
|
|
||||||
return
|
|
||||||
|
|
||||||
import modules.sd_samplers
|
|
||||||
|
|
||||||
try:
|
|
||||||
if opts.show_progress_grid:
|
|
||||||
self.assign_current_image(modules.sd_samplers.samples_to_image_grid(self.current_latent))
|
|
||||||
else:
|
|
||||||
self.assign_current_image(modules.sd_samplers.sample_to_image(self.current_latent))
|
|
||||||
|
|
||||||
self.current_image_sampling_step = self.sampling_step
|
|
||||||
|
|
||||||
except Exception:
|
|
||||||
# when switching models during genration, VAE would be on CPU, so creating an image will fail.
|
|
||||||
# we silently ignore this error
|
|
||||||
errors.record_exception()
|
|
||||||
|
|
||||||
def assign_current_image(self, image):
|
|
||||||
self.current_image = image
|
|
||||||
self.id_live_preview += 1
|
|
||||||
|
|
||||||
|
|
||||||
state = State()
|
|
||||||
state.server_start = time.time()
|
|
||||||
|
|
||||||
styles_filename = cmd_opts.styles_file
|
|
||||||
prompt_styles = modules.styles.StyleDatabase(styles_filename)
|
|
||||||
|
|
||||||
interrogator = modules.interrogate.InterrogateModels("interrogate")
|
|
||||||
|
|
||||||
face_restorers = []
|
face_restorers = []
|
||||||
|
|
||||||
|
options_templates = None
|
||||||
|
opts = None
|
||||||
|
restricted_opts = None
|
||||||
|
|
||||||
class OptionInfo:
|
sd_model: sd_models_types.WebuiSdModel = None
|
||||||
def __init__(self, default=None, label="", component=None, component_args=None, onchange=None, section=None, refresh=None, comment_before='', comment_after=''):
|
|
||||||
self.default = default
|
|
||||||
self.label = label
|
|
||||||
self.component = component
|
|
||||||
self.component_args = component_args
|
|
||||||
self.onchange = onchange
|
|
||||||
self.section = section
|
|
||||||
self.refresh = refresh
|
|
||||||
self.do_not_save = False
|
|
||||||
|
|
||||||
self.comment_before = comment_before
|
|
||||||
"""HTML text that will be added after label in UI"""
|
|
||||||
|
|
||||||
self.comment_after = comment_after
|
|
||||||
"""HTML text that will be added before label in UI"""
|
|
||||||
|
|
||||||
def link(self, label, url):
|
|
||||||
self.comment_before += f"[<a href='{url}' target='_blank'>{label}</a>]"
|
|
||||||
return self
|
|
||||||
|
|
||||||
def js(self, label, js_func):
|
|
||||||
self.comment_before += f"[<a onclick='{js_func}(); return false'>{label}</a>]"
|
|
||||||
return self
|
|
||||||
|
|
||||||
def info(self, info):
|
|
||||||
self.comment_after += f"<span class='info'>({info})</span>"
|
|
||||||
return self
|
|
||||||
|
|
||||||
def html(self, html):
|
|
||||||
self.comment_after += html
|
|
||||||
return self
|
|
||||||
|
|
||||||
def needs_restart(self):
|
|
||||||
self.comment_after += " <span class='info'>(requires restart)</span>"
|
|
||||||
return self
|
|
||||||
|
|
||||||
def needs_reload_ui(self):
|
|
||||||
self.comment_after += " <span class='info'>(requires Reload UI)</span>"
|
|
||||||
return self
|
|
||||||
|
|
||||||
|
|
||||||
class OptionHTML(OptionInfo):
|
|
||||||
def __init__(self, text):
|
|
||||||
super().__init__(str(text).strip(), label='', component=lambda **kwargs: gr.HTML(elem_classes="settings-info", **kwargs))
|
|
||||||
|
|
||||||
self.do_not_save = True
|
|
||||||
|
|
||||||
|
|
||||||
def options_section(section_identifier, options_dict):
|
|
||||||
for v in options_dict.values():
|
|
||||||
v.section = section_identifier
|
|
||||||
|
|
||||||
return options_dict
|
|
||||||
|
|
||||||
|
|
||||||
def list_checkpoint_tiles():
|
|
||||||
import modules.sd_models
|
|
||||||
return modules.sd_models.checkpoint_tiles()
|
|
||||||
|
|
||||||
|
|
||||||
def refresh_checkpoints():
|
|
||||||
import modules.sd_models
|
|
||||||
return modules.sd_models.list_models()
|
|
||||||
|
|
||||||
|
|
||||||
def list_samplers():
|
|
||||||
import modules.sd_samplers
|
|
||||||
return modules.sd_samplers.all_samplers
|
|
||||||
|
|
||||||
|
|
||||||
hide_dirs = {"visible": not cmd_opts.hide_ui_dir_config}
|
|
||||||
tab_names = []
|
|
||||||
|
|
||||||
options_templates = {}
|
|
||||||
|
|
||||||
options_templates.update(options_section(('saving-images', "Saving images/grids"), {
|
|
||||||
"samples_save": OptionInfo(True, "Always save all generated images"),
|
|
||||||
"samples_format": OptionInfo('png', 'File format for images'),
|
|
||||||
"samples_filename_pattern": OptionInfo("", "Images filename pattern", component_args=hide_dirs).link("wiki", "https://github.com/AUTOMATIC1111/stable-diffusion-webui/wiki/Custom-Images-Filename-Name-and-Subdirectory"),
|
|
||||||
"save_images_add_number": OptionInfo(True, "Add number to filename when saving", component_args=hide_dirs),
|
|
||||||
|
|
||||||
"grid_save": OptionInfo(True, "Always save all generated image grids"),
|
|
||||||
"grid_format": OptionInfo('png', 'File format for grids'),
|
|
||||||
"grid_extended_filename": OptionInfo(False, "Add extended info (seed, prompt) to filename when saving grid"),
|
|
||||||
"grid_only_if_multiple": OptionInfo(True, "Do not save grids consisting of one picture"),
|
|
||||||
"grid_prevent_empty_spots": OptionInfo(False, "Prevent empty spots in grid (when set to autodetect)"),
|
|
||||||
"grid_zip_filename_pattern": OptionInfo("", "Archive filename pattern", component_args=hide_dirs).link("wiki", "https://github.com/AUTOMATIC1111/stable-diffusion-webui/wiki/Custom-Images-Filename-Name-and-Subdirectory"),
|
|
||||||
"n_rows": OptionInfo(-1, "Grid row count; use -1 for autodetect and 0 for it to be same as batch size", gr.Slider, {"minimum": -1, "maximum": 16, "step": 1}),
|
|
||||||
"font": OptionInfo("", "Font for image grids that have text"),
|
|
||||||
"grid_text_active_color": OptionInfo("#000000", "Text color for image grids", ui_components.FormColorPicker, {}),
|
|
||||||
"grid_text_inactive_color": OptionInfo("#999999", "Inactive text color for image grids", ui_components.FormColorPicker, {}),
|
|
||||||
"grid_background_color": OptionInfo("#ffffff", "Background color for image grids", ui_components.FormColorPicker, {}),
|
|
||||||
|
|
||||||
"enable_pnginfo": OptionInfo(True, "Save text information about generation parameters as chunks to png files"),
|
|
||||||
"save_txt": OptionInfo(False, "Create a text file next to every image with generation parameters."),
|
|
||||||
"save_images_before_face_restoration": OptionInfo(False, "Save a copy of image before doing face restoration."),
|
|
||||||
"save_images_before_highres_fix": OptionInfo(False, "Save a copy of image before applying highres fix."),
|
|
||||||
"save_images_before_color_correction": OptionInfo(False, "Save a copy of image before applying color correction to img2img results"),
|
|
||||||
"save_mask": OptionInfo(False, "For inpainting, save a copy of the greyscale mask"),
|
|
||||||
"save_mask_composite": OptionInfo(False, "For inpainting, save a masked composite"),
|
|
||||||
"jpeg_quality": OptionInfo(80, "Quality for saved jpeg images", gr.Slider, {"minimum": 1, "maximum": 100, "step": 1}),
|
|
||||||
"webp_lossless": OptionInfo(False, "Use lossless compression for webp images"),
|
|
||||||
"export_for_4chan": OptionInfo(True, "Save copy of large images as JPG").info("if the file size is above the limit, or either width or height are above the limit"),
|
|
||||||
"img_downscale_threshold": OptionInfo(4.0, "File size limit for the above option, MB", gr.Number),
|
|
||||||
"target_side_length": OptionInfo(4000, "Width/height limit for the above option, in pixels", gr.Number),
|
|
||||||
"img_max_size_mp": OptionInfo(200, "Maximum image size", gr.Number).info("in megapixels"),
|
|
||||||
|
|
||||||
"use_original_name_batch": OptionInfo(True, "Use original name for output filename during batch process in extras tab"),
|
|
||||||
"use_upscaler_name_as_suffix": OptionInfo(False, "Use upscaler name as filename suffix in the extras tab"),
|
|
||||||
"save_selected_only": OptionInfo(True, "When using 'Save' button, only save a single selected image"),
|
|
||||||
"save_init_img": OptionInfo(False, "Save init images when using img2img"),
|
|
||||||
|
|
||||||
"temp_dir": OptionInfo("", "Directory for temporary images; leave empty for default"),
|
|
||||||
"clean_temp_dir_at_start": OptionInfo(False, "Cleanup non-default temporary directory when starting webui"),
|
|
||||||
|
|
||||||
"save_incomplete_images": OptionInfo(False, "Save incomplete images").info("save images that has been interrupted in mid-generation; even if not saved, they will still show up in webui output."),
|
|
||||||
}))
|
|
||||||
|
|
||||||
options_templates.update(options_section(('saving-paths', "Paths for saving"), {
|
|
||||||
"outdir_samples": OptionInfo("", "Output directory for images; if empty, defaults to three directories below", component_args=hide_dirs),
|
|
||||||
"outdir_txt2img_samples": OptionInfo("outputs/txt2img-images", 'Output directory for txt2img images', component_args=hide_dirs),
|
|
||||||
"outdir_img2img_samples": OptionInfo("outputs/img2img-images", 'Output directory for img2img images', component_args=hide_dirs),
|
|
||||||
"outdir_extras_samples": OptionInfo("outputs/extras-images", 'Output directory for images from extras tab', component_args=hide_dirs),
|
|
||||||
"outdir_grids": OptionInfo("", "Output directory for grids; if empty, defaults to two directories below", component_args=hide_dirs),
|
|
||||||
"outdir_txt2img_grids": OptionInfo("outputs/txt2img-grids", 'Output directory for txt2img grids', component_args=hide_dirs),
|
|
||||||
"outdir_img2img_grids": OptionInfo("outputs/img2img-grids", 'Output directory for img2img grids', component_args=hide_dirs),
|
|
||||||
"outdir_save": OptionInfo("log/images", "Directory for saving images using the Save button", component_args=hide_dirs),
|
|
||||||
"outdir_init_images": OptionInfo("outputs/init-images", "Directory for saving init images when using img2img", component_args=hide_dirs),
|
|
||||||
}))
|
|
||||||
|
|
||||||
options_templates.update(options_section(('saving-to-dirs', "Saving to a directory"), {
|
|
||||||
"save_to_dirs": OptionInfo(True, "Save images to a subdirectory"),
|
|
||||||
"grid_save_to_dirs": OptionInfo(True, "Save grids to a subdirectory"),
|
|
||||||
"use_save_to_dirs_for_ui": OptionInfo(False, "When using \"Save\" button, save images to a subdirectory"),
|
|
||||||
"directories_filename_pattern": OptionInfo("[date]", "Directory name pattern", component_args=hide_dirs).link("wiki", "https://github.com/AUTOMATIC1111/stable-diffusion-webui/wiki/Custom-Images-Filename-Name-and-Subdirectory"),
|
|
||||||
"directories_max_prompt_words": OptionInfo(8, "Max prompt words for [prompt_words] pattern", gr.Slider, {"minimum": 1, "maximum": 20, "step": 1, **hide_dirs}),
|
|
||||||
}))
|
|
||||||
|
|
||||||
options_templates.update(options_section(('upscaling', "Upscaling"), {
|
|
||||||
"ESRGAN_tile": OptionInfo(192, "Tile size for ESRGAN upscalers.", gr.Slider, {"minimum": 0, "maximum": 512, "step": 16}).info("0 = no tiling"),
|
|
||||||
"ESRGAN_tile_overlap": OptionInfo(8, "Tile overlap for ESRGAN upscalers.", gr.Slider, {"minimum": 0, "maximum": 48, "step": 1}).info("Low values = visible seam"),
|
|
||||||
"realesrgan_enabled_models": OptionInfo(["R-ESRGAN 4x+", "R-ESRGAN 4x+ Anime6B"], "Select which Real-ESRGAN models to show in the web UI.", gr.CheckboxGroup, lambda: {"choices": shared_items.realesrgan_models_names()}),
|
|
||||||
"upscaler_for_img2img": OptionInfo(None, "Upscaler for img2img", gr.Dropdown, lambda: {"choices": [x.name for x in sd_upscalers]}),
|
|
||||||
}))
|
|
||||||
|
|
||||||
options_templates.update(options_section(('face-restoration', "Face restoration"), {
|
|
||||||
"face_restoration_model": OptionInfo("CodeFormer", "Face restoration model", gr.Radio, lambda: {"choices": [x.name() for x in face_restorers]}),
|
|
||||||
"code_former_weight": OptionInfo(0.5, "CodeFormer weight", gr.Slider, {"minimum": 0, "maximum": 1, "step": 0.01}).info("0 = maximum effect; 1 = minimum effect"),
|
|
||||||
"face_restoration_unload": OptionInfo(False, "Move face restoration model from VRAM into RAM after processing"),
|
|
||||||
}))
|
|
||||||
|
|
||||||
options_templates.update(options_section(('system', "System"), {
|
|
||||||
"show_warnings": OptionInfo(False, "Show warnings in console.").needs_reload_ui(),
|
|
||||||
"show_gradio_deprecation_warnings": OptionInfo(True, "Show gradio deprecation warnings in console.").needs_reload_ui(),
|
|
||||||
"memmon_poll_rate": OptionInfo(8, "VRAM usage polls per second during generation.", gr.Slider, {"minimum": 0, "maximum": 40, "step": 1}).info("0 = disable"),
|
|
||||||
"samples_log_stdout": OptionInfo(False, "Always print all generation info to standard output"),
|
|
||||||
"multiple_tqdm": OptionInfo(True, "Add a second progress bar to the console that shows progress for an entire job."),
|
|
||||||
"print_hypernet_extra": OptionInfo(False, "Print extra hypernetwork information to console."),
|
|
||||||
"list_hidden_files": OptionInfo(True, "Load models/files in hidden directories").info("directory is hidden if its name starts with \".\""),
|
|
||||||
"disable_mmap_load_safetensors": OptionInfo(False, "Disable memmapping for loading .safetensors files.").info("fixes very slow loading speed in some cases"),
|
|
||||||
"hide_ldm_prints": OptionInfo(True, "Prevent Stability-AI's ldm/sgm modules from printing noise to console."),
|
|
||||||
}))
|
|
||||||
|
|
||||||
options_templates.update(options_section(('training', "Training"), {
|
|
||||||
"unload_models_when_training": OptionInfo(False, "Move VAE and CLIP to RAM when training if possible. Saves VRAM."),
|
|
||||||
"pin_memory": OptionInfo(False, "Turn on pin_memory for DataLoader. Makes training slightly faster but can increase memory usage."),
|
|
||||||
"save_optimizer_state": OptionInfo(False, "Saves Optimizer state as separate *.optim file. Training of embedding or HN can be resumed with the matching optim file."),
|
|
||||||
"save_training_settings_to_txt": OptionInfo(True, "Save textual inversion and hypernet settings to a text file whenever training starts."),
|
|
||||||
"dataset_filename_word_regex": OptionInfo("", "Filename word regex"),
|
|
||||||
"dataset_filename_join_string": OptionInfo(" ", "Filename join string"),
|
|
||||||
"training_image_repeats_per_epoch": OptionInfo(1, "Number of repeats for a single input image per epoch; used only for displaying epoch number", gr.Number, {"precision": 0}),
|
|
||||||
"training_write_csv_every": OptionInfo(500, "Save an csv containing the loss to log directory every N steps, 0 to disable"),
|
|
||||||
"training_xattention_optimizations": OptionInfo(False, "Use cross attention optimizations while training"),
|
|
||||||
"training_enable_tensorboard": OptionInfo(False, "Enable tensorboard logging."),
|
|
||||||
"training_tensorboard_save_images": OptionInfo(False, "Save generated images within tensorboard."),
|
|
||||||
"training_tensorboard_flush_every": OptionInfo(120, "How often, in seconds, to flush the pending tensorboard events and summaries to disk."),
|
|
||||||
}))
|
|
||||||
|
|
||||||
options_templates.update(options_section(('sd', "Stable Diffusion"), {
|
|
||||||
"sd_model_checkpoint": OptionInfo(None, "Stable Diffusion checkpoint", gr.Dropdown, lambda: {"choices": list_checkpoint_tiles()}, refresh=refresh_checkpoints),
|
|
||||||
"sd_checkpoints_limit": OptionInfo(1, "Maximum number of checkpoints loaded at the same time", gr.Slider, {"minimum": 1, "maximum": 10, "step": 1}),
|
|
||||||
"sd_checkpoints_keep_in_cpu": OptionInfo(True, "Only keep one model on device").info("will keep models other than the currently used one in RAM rather than VRAM"),
|
|
||||||
"sd_checkpoint_cache": OptionInfo(0, "Checkpoints to cache in RAM", gr.Slider, {"minimum": 0, "maximum": 10, "step": 1}).info("obsolete; set to 0 and use the two settings above instead"),
|
|
||||||
"sd_unet": OptionInfo("Automatic", "SD Unet", gr.Dropdown, lambda: {"choices": shared_items.sd_unet_items()}, refresh=shared_items.refresh_unet_list).info("choose Unet model: Automatic = use one with same filename as checkpoint; None = use Unet from checkpoint"),
|
|
||||||
"enable_quantization": OptionInfo(False, "Enable quantization in K samplers for sharper and cleaner results. This may change existing seeds").needs_reload_ui(),
|
|
||||||
"enable_emphasis": OptionInfo(True, "Enable emphasis").info("use (text) to make model pay more attention to text and [text] to make it pay less attention"),
|
|
||||||
"enable_batch_seeds": OptionInfo(True, "Make K-diffusion samplers produce same images in a batch as when making a single image"),
|
|
||||||
"comma_padding_backtrack": OptionInfo(20, "Prompt word wrap length limit", gr.Slider, {"minimum": 0, "maximum": 74, "step": 1}).info("in tokens - for texts shorter than specified, if they don't fit into 75 token limit, move them to the next 75 token chunk"),
|
|
||||||
"CLIP_stop_at_last_layers": OptionInfo(1, "Clip skip", gr.Slider, {"minimum": 1, "maximum": 12, "step": 1}).link("wiki", "https://github.com/AUTOMATIC1111/stable-diffusion-webui/wiki/Features#clip-skip").info("ignore last layers of CLIP network; 1 ignores none, 2 ignores one layer"),
|
|
||||||
"upcast_attn": OptionInfo(False, "Upcast cross attention layer to float32"),
|
|
||||||
"randn_source": OptionInfo("GPU", "Random number generator source.", gr.Radio, {"choices": ["GPU", "CPU", "NV"]}).info("changes seeds drastically; use CPU to produce the same picture across different videocard vendors; use NV to produce same picture as on NVidia videocards"),
|
|
||||||
}))
|
|
||||||
|
|
||||||
options_templates.update(options_section(('sdxl', "Stable Diffusion XL"), {
|
|
||||||
"sdxl_crop_top": OptionInfo(0, "crop top coordinate"),
|
|
||||||
"sdxl_crop_left": OptionInfo(0, "crop left coordinate"),
|
|
||||||
"sdxl_refiner_low_aesthetic_score": OptionInfo(2.5, "SDXL low aesthetic score", gr.Number).info("used for refiner model negative prompt"),
|
|
||||||
"sdxl_refiner_high_aesthetic_score": OptionInfo(6.0, "SDXL high aesthetic score", gr.Number).info("used for refiner model prompt"),
|
|
||||||
}))
|
|
||||||
|
|
||||||
options_templates.update(options_section(('vae', "VAE"), {
|
|
||||||
"sd_vae_explanation": OptionHTML("""
|
|
||||||
<abbr title='Variational autoencoder'>VAE</abbr> is a neural network that transforms a standard <abbr title='red/green/blue'>RGB</abbr>
|
|
||||||
image into latent space representation and back. Latent space representation is what stable diffusion is working on during sampling
|
|
||||||
(i.e. when the progress bar is between empty and full). For txt2img, VAE is used to create a resulting image after the sampling is finished.
|
|
||||||
For img2img, VAE is used to process user's input image before the sampling, and to create an image after sampling.
|
|
||||||
"""),
|
|
||||||
"sd_vae_checkpoint_cache": OptionInfo(0, "VAE Checkpoints to cache in RAM", gr.Slider, {"minimum": 0, "maximum": 10, "step": 1}),
|
|
||||||
"sd_vae": OptionInfo("Automatic", "SD VAE", gr.Dropdown, lambda: {"choices": shared_items.sd_vae_items()}, refresh=shared_items.refresh_vae_list).info("choose VAE model: Automatic = use one with same filename as checkpoint; None = use VAE from checkpoint"),
|
|
||||||
"sd_vae_as_default": OptionInfo(True, "Ignore selected VAE for stable diffusion checkpoints that have their own .vae.pt next to them"),
|
|
||||||
"auto_vae_precision": OptionInfo(True, "Automaticlly revert VAE to 32-bit floats").info("triggers when a tensor with NaNs is produced in VAE; disabling the option in this case will result in a black square image"),
|
|
||||||
"sd_vae_encode_method": OptionInfo("Full", "VAE type for encode", gr.Radio, {"choices": ["Full", "TAESD"]}).info("method to encode image to latent (use in img2img, hires-fix or inpaint mask)"),
|
|
||||||
"sd_vae_decode_method": OptionInfo("Full", "VAE type for decode", gr.Radio, {"choices": ["Full", "TAESD"]}).info("method to decode latent to image"),
|
|
||||||
}))
|
|
||||||
|
|
||||||
options_templates.update(options_section(('img2img', "img2img"), {
|
|
||||||
"inpainting_mask_weight": OptionInfo(1.0, "Inpainting conditioning mask strength", gr.Slider, {"minimum": 0.0, "maximum": 1.0, "step": 0.01}),
|
|
||||||
"initial_noise_multiplier": OptionInfo(1.0, "Noise multiplier for img2img", gr.Slider, {"minimum": 0.5, "maximum": 1.5, "step": 0.01}),
|
|
||||||
"img2img_color_correction": OptionInfo(False, "Apply color correction to img2img results to match original colors."),
|
|
||||||
"img2img_fix_steps": OptionInfo(False, "With img2img, do exactly the amount of steps the slider specifies.").info("normally you'd do less with less denoising"),
|
|
||||||
"img2img_background_color": OptionInfo("#ffffff", "With img2img, fill transparent parts of the input image with this color.", ui_components.FormColorPicker, {}),
|
|
||||||
"img2img_editor_height": OptionInfo(720, "Height of the image editor", gr.Slider, {"minimum": 80, "maximum": 1600, "step": 1}).info("in pixels").needs_reload_ui(),
|
|
||||||
"img2img_sketch_default_brush_color": OptionInfo("#ffffff", "Sketch initial brush color", ui_components.FormColorPicker, {}).info("default brush color of img2img sketch").needs_reload_ui(),
|
|
||||||
"img2img_inpaint_mask_brush_color": OptionInfo("#ffffff", "Inpaint mask brush color", ui_components.FormColorPicker, {}).info("brush color of inpaint mask").needs_reload_ui(),
|
|
||||||
"img2img_inpaint_sketch_default_brush_color": OptionInfo("#ffffff", "Inpaint sketch initial brush color", ui_components.FormColorPicker, {}).info("default brush color of img2img inpaint sketch").needs_reload_ui(),
|
|
||||||
"return_mask": OptionInfo(False, "For inpainting, include the greyscale mask in results for web"),
|
|
||||||
"return_mask_composite": OptionInfo(False, "For inpainting, include masked composite in results for web"),
|
|
||||||
}))
|
|
||||||
|
|
||||||
options_templates.update(options_section(('optimizations', "Optimizations"), {
|
|
||||||
"cross_attention_optimization": OptionInfo("Automatic", "Cross attention optimization", gr.Dropdown, lambda: {"choices": shared_items.cross_attention_optimizations()}),
|
|
||||||
"s_min_uncond": OptionInfo(0.0, "Negative Guidance minimum sigma", gr.Slider, {"minimum": 0.0, "maximum": 15.0, "step": 0.01}).link("PR", "https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/9177").info("skip negative prompt for some steps when the image is almost ready; 0=disable, higher=faster"),
|
|
||||||
"token_merging_ratio": OptionInfo(0.0, "Token merging ratio", gr.Slider, {"minimum": 0.0, "maximum": 0.9, "step": 0.1}).link("PR", "https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/9256").info("0=disable, higher=faster"),
|
|
||||||
"token_merging_ratio_img2img": OptionInfo(0.0, "Token merging ratio for img2img", gr.Slider, {"minimum": 0.0, "maximum": 0.9, "step": 0.1}).info("only applies if non-zero and overrides above"),
|
|
||||||
"token_merging_ratio_hr": OptionInfo(0.0, "Token merging ratio for high-res pass", gr.Slider, {"minimum": 0.0, "maximum": 0.9, "step": 0.1}).info("only applies if non-zero and overrides above"),
|
|
||||||
"pad_cond_uncond": OptionInfo(False, "Pad prompt/negative prompt to be same length").info("improves performance when prompt and negative prompt have different lengths; changes seeds"),
|
|
||||||
"persistent_cond_cache": OptionInfo(True, "Persistent cond cache").info("Do not recalculate conds from prompts if prompts have not changed since previous calculation"),
|
|
||||||
}))
|
|
||||||
|
|
||||||
options_templates.update(options_section(('compatibility', "Compatibility"), {
|
|
||||||
"use_old_emphasis_implementation": OptionInfo(False, "Use old emphasis implementation. Can be useful to reproduce old seeds."),
|
|
||||||
"use_old_karras_scheduler_sigmas": OptionInfo(False, "Use old karras scheduler sigmas (0.1 to 10)."),
|
|
||||||
"no_dpmpp_sde_batch_determinism": OptionInfo(False, "Do not make DPM++ SDE deterministic across different batch sizes."),
|
|
||||||
"use_old_hires_fix_width_height": OptionInfo(False, "For hires fix, use width/height sliders to set final resolution rather than first pass (disables Upscale by, Resize width/height to)."),
|
|
||||||
"dont_fix_second_order_samplers_schedule": OptionInfo(False, "Do not fix prompt schedule for second order samplers."),
|
|
||||||
"hires_fix_use_firstpass_conds": OptionInfo(False, "For hires fix, calculate conds of second pass using extra networks of first pass."),
|
|
||||||
}))
|
|
||||||
|
|
||||||
options_templates.update(options_section(('interrogate', "Interrogate"), {
|
|
||||||
"interrogate_keep_models_in_memory": OptionInfo(False, "Keep models in VRAM"),
|
|
||||||
"interrogate_return_ranks": OptionInfo(False, "Include ranks of model tags matches in results.").info("booru only"),
|
|
||||||
"interrogate_clip_num_beams": OptionInfo(1, "BLIP: num_beams", gr.Slider, {"minimum": 1, "maximum": 16, "step": 1}),
|
|
||||||
"interrogate_clip_min_length": OptionInfo(24, "BLIP: minimum description length", gr.Slider, {"minimum": 1, "maximum": 128, "step": 1}),
|
|
||||||
"interrogate_clip_max_length": OptionInfo(48, "BLIP: maximum description length", gr.Slider, {"minimum": 1, "maximum": 256, "step": 1}),
|
|
||||||
"interrogate_clip_dict_limit": OptionInfo(1500, "CLIP: maximum number of lines in text file").info("0 = No limit"),
|
|
||||||
"interrogate_clip_skip_categories": OptionInfo([], "CLIP: skip inquire categories", gr.CheckboxGroup, lambda: {"choices": modules.interrogate.category_types()}, refresh=modules.interrogate.category_types),
|
|
||||||
"interrogate_deepbooru_score_threshold": OptionInfo(0.5, "deepbooru: score threshold", gr.Slider, {"minimum": 0, "maximum": 1, "step": 0.01}),
|
|
||||||
"deepbooru_sort_alpha": OptionInfo(True, "deepbooru: sort tags alphabetically").info("if not: sort by score"),
|
|
||||||
"deepbooru_use_spaces": OptionInfo(True, "deepbooru: use spaces in tags").info("if not: use underscores"),
|
|
||||||
"deepbooru_escape": OptionInfo(True, "deepbooru: escape (\\) brackets").info("so they are used as literal brackets and not for emphasis"),
|
|
||||||
"deepbooru_filter_tags": OptionInfo("", "deepbooru: filter out those tags").info("separate by comma"),
|
|
||||||
}))
|
|
||||||
|
|
||||||
options_templates.update(options_section(('extra_networks', "Extra Networks"), {
|
|
||||||
"extra_networks_show_hidden_directories": OptionInfo(True, "Show hidden directories").info("directory is hidden if its name starts with \".\"."),
|
|
||||||
"extra_networks_hidden_models": OptionInfo("When searched", "Show cards for models in hidden directories", gr.Radio, {"choices": ["Always", "When searched", "Never"]}).info('"When searched" option will only show the item when the search string has 4 characters or more'),
|
|
||||||
"extra_networks_default_multiplier": OptionInfo(1.0, "Default multiplier for extra networks", gr.Slider, {"minimum": 0.0, "maximum": 2.0, "step": 0.01}),
|
|
||||||
"extra_networks_card_width": OptionInfo(0, "Card width for Extra Networks").info("in pixels"),
|
|
||||||
"extra_networks_card_height": OptionInfo(0, "Card height for Extra Networks").info("in pixels"),
|
|
||||||
"extra_networks_card_text_scale": OptionInfo(1.0, "Card text scale", gr.Slider, {"minimum": 0.0, "maximum": 2.0, "step": 0.01}).info("1 = original size"),
|
|
||||||
"extra_networks_card_show_desc": OptionInfo(True, "Show description on card"),
|
|
||||||
"extra_networks_add_text_separator": OptionInfo(" ", "Extra networks separator").info("extra text to add before <...> when adding extra network to prompt"),
|
|
||||||
"ui_extra_networks_tab_reorder": OptionInfo("", "Extra networks tab order").needs_reload_ui(),
|
|
||||||
"textual_inversion_print_at_load": OptionInfo(False, "Print a list of Textual Inversion embeddings when loading model"),
|
|
||||||
"textual_inversion_add_hashes_to_infotext": OptionInfo(True, "Add Textual Inversion hashes to infotext"),
|
|
||||||
"sd_hypernetwork": OptionInfo("None", "Add hypernetwork to prompt", gr.Dropdown, lambda: {"choices": ["None", *hypernetworks]}, refresh=reload_hypernetworks),
|
|
||||||
}))
|
|
||||||
|
|
||||||
options_templates.update(options_section(('ui', "User interface"), {
|
|
||||||
"localization": OptionInfo("None", "Localization", gr.Dropdown, lambda: {"choices": ["None"] + list(localization.localizations.keys())}, refresh=lambda: localization.list_localizations(cmd_opts.localizations_dir)).needs_reload_ui(),
|
|
||||||
"gradio_theme": OptionInfo("Default", "Gradio theme", ui_components.DropdownEditable, lambda: {"choices": ["Default"] + gradio_hf_hub_themes}).info("you can also manually enter any of themes from the <a href='https://huggingface.co/spaces/gradio/theme-gallery'>gallery</a>.").needs_reload_ui(),
|
|
||||||
"gradio_themes_cache": OptionInfo(True, "Cache gradio themes locally").info("disable to update the selected Gradio theme"),
|
|
||||||
"return_grid": OptionInfo(True, "Show grid in results for web"),
|
|
||||||
"do_not_show_images": OptionInfo(False, "Do not show any images in results for web"),
|
|
||||||
"send_seed": OptionInfo(True, "Send seed when sending prompt or image to other interface"),
|
|
||||||
"send_size": OptionInfo(True, "Send size when sending prompt or image to another interface"),
|
|
||||||
"js_modal_lightbox": OptionInfo(True, "Enable full page image viewer"),
|
|
||||||
"js_modal_lightbox_initially_zoomed": OptionInfo(True, "Show images zoomed in by default in full page image viewer"),
|
|
||||||
"js_modal_lightbox_gamepad": OptionInfo(False, "Navigate image viewer with gamepad"),
|
|
||||||
"js_modal_lightbox_gamepad_repeat": OptionInfo(250, "Gamepad repeat period, in milliseconds"),
|
|
||||||
"show_progress_in_title": OptionInfo(True, "Show generation progress in window title."),
|
|
||||||
"samplers_in_dropdown": OptionInfo(True, "Use dropdown for sampler selection instead of radio group").needs_reload_ui(),
|
|
||||||
"dimensions_and_batch_together": OptionInfo(True, "Show Width/Height and Batch sliders in same row").needs_reload_ui(),
|
|
||||||
"keyedit_precision_attention": OptionInfo(0.1, "Ctrl+up/down precision when editing (attention:1.1)", gr.Slider, {"minimum": 0.01, "maximum": 0.2, "step": 0.001}),
|
|
||||||
"keyedit_precision_extra": OptionInfo(0.05, "Ctrl+up/down precision when editing <extra networks:0.9>", gr.Slider, {"minimum": 0.01, "maximum": 0.2, "step": 0.001}),
|
|
||||||
"keyedit_delimiters": OptionInfo(".,\\/!?%^*;:{}=`~()", "Ctrl+up/down word delimiters"),
|
|
||||||
"keyedit_move": OptionInfo(True, "Alt+left/right moves prompt elements"),
|
|
||||||
"quicksettings_list": OptionInfo(["sd_model_checkpoint"], "Quicksettings list", ui_components.DropdownMulti, lambda: {"choices": list(opts.data_labels.keys())}).js("info", "settingsHintsShowQuicksettings").info("setting entries that appear at the top of page rather than in settings tab").needs_reload_ui(),
|
|
||||||
"ui_tab_order": OptionInfo([], "UI tab order", ui_components.DropdownMulti, lambda: {"choices": list(tab_names)}).needs_reload_ui(),
|
|
||||||
"hidden_tabs": OptionInfo([], "Hidden UI tabs", ui_components.DropdownMulti, lambda: {"choices": list(tab_names)}).needs_reload_ui(),
|
|
||||||
"ui_reorder_list": OptionInfo([], "txt2img/img2img UI item order", ui_components.DropdownMulti, lambda: {"choices": list(shared_items.ui_reorder_categories())}).info("selected items appear first").needs_reload_ui(),
|
|
||||||
"hires_fix_show_sampler": OptionInfo(False, "Hires fix: show hires checkpoint and sampler selection").needs_reload_ui(),
|
|
||||||
"hires_fix_show_prompts": OptionInfo(False, "Hires fix: show hires prompt and negative prompt").needs_reload_ui(),
|
|
||||||
"disable_token_counters": OptionInfo(False, "Disable prompt token counters").needs_reload_ui(),
|
|
||||||
}))
|
|
||||||
|
|
||||||
|
|
||||||
options_templates.update(options_section(('infotext', "Infotext"), {
|
|
||||||
"add_model_hash_to_info": OptionInfo(True, "Add model hash to generation information"),
|
|
||||||
"add_model_name_to_info": OptionInfo(True, "Add model name to generation information"),
|
|
||||||
"add_user_name_to_info": OptionInfo(False, "Add user name to generation information when authenticated"),
|
|
||||||
"add_version_to_infotext": OptionInfo(True, "Add program version to generation information"),
|
|
||||||
"disable_weights_auto_swap": OptionInfo(True, "Disregard checkpoint information from pasted infotext").info("when reading generation parameters from text into UI"),
|
|
||||||
"infotext_styles": OptionInfo("Apply if any", "Infer styles from prompts of pasted infotext", gr.Radio, {"choices": ["Ignore", "Apply", "Discard", "Apply if any"]}).info("when reading generation parameters from text into UI)").html("""<ul style='margin-left: 1.5em'>
|
|
||||||
<li>Ignore: keep prompt and styles dropdown as it is.</li>
|
|
||||||
<li>Apply: remove style text from prompt, always replace styles dropdown value with found styles (even if none are found).</li>
|
|
||||||
<li>Discard: remove style text from prompt, keep styles dropdown as it is.</li>
|
|
||||||
<li>Apply if any: remove style text from prompt; if any styles are found in prompt, put them into styles dropdown, otherwise keep it as it is.</li>
|
|
||||||
</ul>"""),
|
|
||||||
|
|
||||||
}))
|
|
||||||
|
|
||||||
options_templates.update(options_section(('ui', "Live previews"), {
|
|
||||||
"show_progressbar": OptionInfo(True, "Show progressbar"),
|
|
||||||
"live_previews_enable": OptionInfo(True, "Show live previews of the created image"),
|
|
||||||
"live_previews_image_format": OptionInfo("png", "Live preview file format", gr.Radio, {"choices": ["jpeg", "png", "webp"]}),
|
|
||||||
"show_progress_grid": OptionInfo(True, "Show previews of all images generated in a batch as a grid"),
|
|
||||||
"show_progress_every_n_steps": OptionInfo(10, "Live preview display period", gr.Slider, {"minimum": -1, "maximum": 32, "step": 1}).info("in sampling steps - show new live preview image every N sampling steps; -1 = only show after completion of batch"),
|
|
||||||
"show_progress_type": OptionInfo("Approx NN", "Live preview method", gr.Radio, {"choices": ["Full", "Approx NN", "Approx cheap", "TAESD"]}).info("Full = slow but pretty; Approx NN and TAESD = fast but low quality; Approx cheap = super fast but terrible otherwise"),
|
|
||||||
"live_preview_content": OptionInfo("Prompt", "Live preview subject", gr.Radio, {"choices": ["Combined", "Prompt", "Negative prompt"]}),
|
|
||||||
"live_preview_refresh_period": OptionInfo(1000, "Progressbar and preview update period").info("in milliseconds"),
|
|
||||||
}))
|
|
||||||
|
|
||||||
options_templates.update(options_section(('sampler-params', "Sampler parameters"), {
|
|
||||||
"hide_samplers": OptionInfo([], "Hide samplers in user interface", gr.CheckboxGroup, lambda: {"choices": [x.name for x in list_samplers()]}).needs_reload_ui(),
|
|
||||||
"eta_ddim": OptionInfo(0.0, "Eta for DDIM", gr.Slider, {"minimum": 0.0, "maximum": 1.0, "step": 0.01}).info("noise multiplier; higher = more unperdictable results"),
|
|
||||||
"eta_ancestral": OptionInfo(1.0, "Eta for ancestral samplers", gr.Slider, {"minimum": 0.0, "maximum": 1.0, "step": 0.01}).info("noise multiplier; applies to Euler a and other samplers that have a in them"),
|
|
||||||
"ddim_discretize": OptionInfo('uniform', "img2img DDIM discretize", gr.Radio, {"choices": ['uniform', 'quad']}),
|
|
||||||
's_churn': OptionInfo(0.0, "sigma churn", gr.Slider, {"minimum": 0.0, "maximum": 100.0, "step": 0.01}),
|
|
||||||
's_tmin': OptionInfo(0.0, "sigma tmin", gr.Slider, {"minimum": 0.0, "maximum": 1.0, "step": 0.01}),
|
|
||||||
's_tmax': OptionInfo(0.0, "sigma tmax", gr.Slider, {"minimum": 0.0, "maximum": 999.0, "step": 0.01}).info("0 = inf"),
|
|
||||||
's_noise': OptionInfo(1.0, "sigma noise", gr.Slider, {"minimum": 0.0, "maximum": 1.0, "step": 0.01}),
|
|
||||||
'k_sched_type': OptionInfo("Automatic", "scheduler type", gr.Dropdown, {"choices": ["Automatic", "karras", "exponential", "polyexponential"]}).info("lets you override the noise schedule for k-diffusion samplers; choosing Automatic disables the three parameters below"),
|
|
||||||
'sigma_min': OptionInfo(0.0, "sigma min", gr.Number).info("0 = default (~0.03); minimum noise strength for k-diffusion noise scheduler"),
|
|
||||||
'sigma_max': OptionInfo(0.0, "sigma max", gr.Number).info("0 = default (~14.6); maximum noise strength for k-diffusion noise schedule"),
|
|
||||||
'rho': OptionInfo(0.0, "rho", gr.Number).info("0 = default (7 for karras, 1 for polyexponential); higher values result in a more steep noise schedule (decreases faster)"),
|
|
||||||
'eta_noise_seed_delta': OptionInfo(0, "Eta noise seed delta", gr.Number, {"precision": 0}).info("ENSD; does not improve anything, just produces different results for ancestral samplers - only useful for reproducing images"),
|
|
||||||
'always_discard_next_to_last_sigma': OptionInfo(False, "Always discard next-to-last sigma").link("PR", "https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/6044"),
|
|
||||||
'uni_pc_variant': OptionInfo("bh1", "UniPC variant", gr.Radio, {"choices": ["bh1", "bh2", "vary_coeff"]}),
|
|
||||||
'uni_pc_skip_type': OptionInfo("time_uniform", "UniPC skip type", gr.Radio, {"choices": ["time_uniform", "time_quadratic", "logSNR"]}),
|
|
||||||
'uni_pc_order': OptionInfo(3, "UniPC order", gr.Slider, {"minimum": 1, "maximum": 50, "step": 1}).info("must be < sampling steps"),
|
|
||||||
'uni_pc_lower_order_final': OptionInfo(True, "UniPC lower order final"),
|
|
||||||
}))
|
|
||||||
|
|
||||||
options_templates.update(options_section(('postprocessing', "Postprocessing"), {
|
|
||||||
'postprocessing_enable_in_main_ui': OptionInfo([], "Enable postprocessing operations in txt2img and img2img tabs", ui_components.DropdownMulti, lambda: {"choices": [x.name for x in shared_items.postprocessing_scripts()]}),
|
|
||||||
'postprocessing_operation_order': OptionInfo([], "Postprocessing operation order", ui_components.DropdownMulti, lambda: {"choices": [x.name for x in shared_items.postprocessing_scripts()]}),
|
|
||||||
'upscaling_max_images_in_cache': OptionInfo(5, "Maximum number of images in upscaling cache", gr.Slider, {"minimum": 0, "maximum": 10, "step": 1}),
|
|
||||||
}))
|
|
||||||
|
|
||||||
options_templates.update(options_section((None, "Hidden options"), {
|
|
||||||
"disabled_extensions": OptionInfo([], "Disable these extensions"),
|
|
||||||
"disable_all_extensions": OptionInfo("none", "Disable all extensions (preserves the list of disabled extensions)", gr.Radio, {"choices": ["none", "extra", "all"]}),
|
|
||||||
"restore_config_state_file": OptionInfo("", "Config state file to restore from, under 'config-states/' folder"),
|
|
||||||
"sd_checkpoint_hash": OptionInfo("", "SHA256 hash of the current checkpoint"),
|
|
||||||
}))
|
|
||||||
|
|
||||||
|
|
||||||
options_templates.update()
|
|
||||||
|
|
||||||
|
|
||||||
class Options:
|
|
||||||
data = None
|
|
||||||
data_labels = options_templates
|
|
||||||
typemap = {int: float}
|
|
||||||
|
|
||||||
def __init__(self):
|
|
||||||
self.data = {k: v.default for k, v in self.data_labels.items()}
|
|
||||||
|
|
||||||
def __setattr__(self, key, value):
|
|
||||||
if self.data is not None:
|
|
||||||
if key in self.data or key in self.data_labels:
|
|
||||||
assert not cmd_opts.freeze_settings, "changing settings is disabled"
|
|
||||||
|
|
||||||
info = opts.data_labels.get(key, None)
|
|
||||||
if info.do_not_save:
|
|
||||||
return
|
|
||||||
|
|
||||||
comp_args = info.component_args if info else None
|
|
||||||
if isinstance(comp_args, dict) and comp_args.get('visible', True) is False:
|
|
||||||
raise RuntimeError(f"not possible to set {key} because it is restricted")
|
|
||||||
|
|
||||||
if cmd_opts.hide_ui_dir_config and key in restricted_opts:
|
|
||||||
raise RuntimeError(f"not possible to set {key} because it is restricted")
|
|
||||||
|
|
||||||
self.data[key] = value
|
|
||||||
return
|
|
||||||
|
|
||||||
return super(Options, self).__setattr__(key, value)
|
|
||||||
|
|
||||||
def __getattr__(self, item):
|
|
||||||
if self.data is not None:
|
|
||||||
if item in self.data:
|
|
||||||
return self.data[item]
|
|
||||||
|
|
||||||
if item in self.data_labels:
|
|
||||||
return self.data_labels[item].default
|
|
||||||
|
|
||||||
return super(Options, self).__getattribute__(item)
|
|
||||||
|
|
||||||
def set(self, key, value):
|
|
||||||
"""sets an option and calls its onchange callback, returning True if the option changed and False otherwise"""
|
|
||||||
|
|
||||||
oldval = self.data.get(key, None)
|
|
||||||
if oldval == value:
|
|
||||||
return False
|
|
||||||
|
|
||||||
if self.data_labels[key].do_not_save:
|
|
||||||
return False
|
|
||||||
|
|
||||||
try:
|
|
||||||
setattr(self, key, value)
|
|
||||||
except RuntimeError:
|
|
||||||
return False
|
|
||||||
|
|
||||||
if self.data_labels[key].onchange is not None:
|
|
||||||
try:
|
|
||||||
self.data_labels[key].onchange()
|
|
||||||
except Exception as e:
|
|
||||||
errors.display(e, f"changing setting {key} to {value}")
|
|
||||||
setattr(self, key, oldval)
|
|
||||||
return False
|
|
||||||
|
|
||||||
return True
|
|
||||||
|
|
||||||
def get_default(self, key):
|
|
||||||
"""returns the default value for the key"""
|
|
||||||
|
|
||||||
data_label = self.data_labels.get(key)
|
|
||||||
if data_label is None:
|
|
||||||
return None
|
|
||||||
|
|
||||||
return data_label.default
|
|
||||||
|
|
||||||
def save(self, filename):
|
|
||||||
assert not cmd_opts.freeze_settings, "saving settings is disabled"
|
|
||||||
|
|
||||||
with open(filename, "w", encoding="utf8") as file:
|
|
||||||
json.dump(self.data, file, indent=4)
|
|
||||||
|
|
||||||
def same_type(self, x, y):
|
|
||||||
if x is None or y is None:
|
|
||||||
return True
|
|
||||||
|
|
||||||
type_x = self.typemap.get(type(x), type(x))
|
|
||||||
type_y = self.typemap.get(type(y), type(y))
|
|
||||||
|
|
||||||
return type_x == type_y
|
|
||||||
|
|
||||||
def load(self, filename):
|
|
||||||
with open(filename, "r", encoding="utf8") as file:
|
|
||||||
self.data = json.load(file)
|
|
||||||
|
|
||||||
# 1.1.1 quicksettings list migration
|
|
||||||
if self.data.get('quicksettings') is not None and self.data.get('quicksettings_list') is None:
|
|
||||||
self.data['quicksettings_list'] = [i.strip() for i in self.data.get('quicksettings').split(',')]
|
|
||||||
|
|
||||||
# 1.4.0 ui_reorder
|
|
||||||
if isinstance(self.data.get('ui_reorder'), str) and self.data.get('ui_reorder') and "ui_reorder_list" not in self.data:
|
|
||||||
self.data['ui_reorder_list'] = [i.strip() for i in self.data.get('ui_reorder').split(',')]
|
|
||||||
|
|
||||||
bad_settings = 0
|
|
||||||
for k, v in self.data.items():
|
|
||||||
info = self.data_labels.get(k, None)
|
|
||||||
if info is not None and not self.same_type(info.default, v):
|
|
||||||
print(f"Warning: bad setting value: {k}: {v} ({type(v).__name__}; expected {type(info.default).__name__})", file=sys.stderr)
|
|
||||||
bad_settings += 1
|
|
||||||
|
|
||||||
if bad_settings > 0:
|
|
||||||
print(f"The program is likely to not work with bad settings.\nSettings file: {filename}\nEither fix the file, or delete it and restart.", file=sys.stderr)
|
|
||||||
|
|
||||||
def onchange(self, key, func, call=True):
|
|
||||||
item = self.data_labels.get(key)
|
|
||||||
item.onchange = func
|
|
||||||
|
|
||||||
if call:
|
|
||||||
func()
|
|
||||||
|
|
||||||
def dumpjson(self):
|
|
||||||
d = {k: self.data.get(k, v.default) for k, v in self.data_labels.items()}
|
|
||||||
d["_comments_before"] = {k: v.comment_before for k, v in self.data_labels.items() if v.comment_before is not None}
|
|
||||||
d["_comments_after"] = {k: v.comment_after for k, v in self.data_labels.items() if v.comment_after is not None}
|
|
||||||
return json.dumps(d)
|
|
||||||
|
|
||||||
def add_option(self, key, info):
|
|
||||||
self.data_labels[key] = info
|
|
||||||
|
|
||||||
def reorder(self):
|
|
||||||
"""reorder settings so that all items related to section always go together"""
|
|
||||||
|
|
||||||
section_ids = {}
|
|
||||||
settings_items = self.data_labels.items()
|
|
||||||
for _, item in settings_items:
|
|
||||||
if item.section not in section_ids:
|
|
||||||
section_ids[item.section] = len(section_ids)
|
|
||||||
|
|
||||||
self.data_labels = dict(sorted(settings_items, key=lambda x: section_ids[x[1].section]))
|
|
||||||
|
|
||||||
def cast_value(self, key, value):
|
|
||||||
"""casts an arbitrary to the same type as this setting's value with key
|
|
||||||
Example: cast_value("eta_noise_seed_delta", "12") -> returns 12 (an int rather than str)
|
|
||||||
"""
|
|
||||||
|
|
||||||
if value is None:
|
|
||||||
return None
|
|
||||||
|
|
||||||
default_value = self.data_labels[key].default
|
|
||||||
if default_value is None:
|
|
||||||
default_value = getattr(self, key, None)
|
|
||||||
if default_value is None:
|
|
||||||
return None
|
|
||||||
|
|
||||||
expected_type = type(default_value)
|
|
||||||
if expected_type == bool and value == "False":
|
|
||||||
value = False
|
|
||||||
else:
|
|
||||||
value = expected_type(value)
|
|
||||||
|
|
||||||
return value
|
|
||||||
|
|
||||||
|
|
||||||
opts = Options()
|
|
||||||
if os.path.exists(config_filename):
|
|
||||||
opts.load(config_filename)
|
|
||||||
|
|
||||||
|
|
||||||
class Shared(sys.modules[__name__].__class__):
|
|
||||||
"""
|
|
||||||
this class is here to provide sd_model field as a property, so that it can be created and loaded on demand rather than
|
|
||||||
at program startup.
|
|
||||||
"""
|
|
||||||
|
|
||||||
sd_model_val = None
|
|
||||||
|
|
||||||
@property
|
|
||||||
def sd_model(self):
|
|
||||||
import modules.sd_models
|
|
||||||
|
|
||||||
return modules.sd_models.model_data.get_sd_model()
|
|
||||||
|
|
||||||
@sd_model.setter
|
|
||||||
def sd_model(self, value):
|
|
||||||
import modules.sd_models
|
|
||||||
|
|
||||||
modules.sd_models.model_data.set_sd_model(value)
|
|
||||||
|
|
||||||
|
|
||||||
sd_model: LatentDiffusion = None # this var is here just for IDE's type checking; it cannot be accessed because the class field above will be accessed instead
|
|
||||||
sys.modules[__name__].__class__ = Shared
|
|
||||||
|
|
||||||
settings_components = None
|
settings_components = None
|
||||||
"""assinged from ui.py, a mapping on setting names to gradio components repsponsible for those settings"""
|
"""assinged from ui.py, a mapping on setting names to gradio components repsponsible for those settings"""
|
||||||
|
|
||||||
|
tab_names = []
|
||||||
|
|
||||||
latent_upscale_default_mode = "Latent"
|
latent_upscale_default_mode = "Latent"
|
||||||
latent_upscale_modes = {
|
latent_upscale_modes = {
|
||||||
"Latent": {"mode": "bilinear", "antialias": False},
|
"Latent": {"mode": "bilinear", "antialias": False},
|
||||||
@@ -851,121 +64,24 @@ progress_print_out = sys.stdout
|
|||||||
|
|
||||||
gradio_theme = gr.themes.Base()
|
gradio_theme = gr.themes.Base()
|
||||||
|
|
||||||
|
total_tqdm = None
|
||||||
|
|
||||||
def reload_gradio_theme(theme_name=None):
|
mem_mon = None
|
||||||
global gradio_theme
|
|
||||||
if not theme_name:
|
|
||||||
theme_name = opts.gradio_theme
|
|
||||||
|
|
||||||
default_theme_args = dict(
|
options_section = options.options_section
|
||||||
font=["Source Sans Pro", 'ui-sans-serif', 'system-ui', 'sans-serif'],
|
OptionInfo = options.OptionInfo
|
||||||
font_mono=['IBM Plex Mono', 'ui-monospace', 'Consolas', 'monospace'],
|
OptionHTML = options.OptionHTML
|
||||||
)
|
|
||||||
|
|
||||||
if theme_name == "Default":
|
natural_sort_key = util.natural_sort_key
|
||||||
gradio_theme = gr.themes.Default(**default_theme_args)
|
listfiles = util.listfiles
|
||||||
else:
|
html_path = util.html_path
|
||||||
try:
|
html = util.html
|
||||||
theme_cache_dir = os.path.join(script_path, 'tmp', 'gradio_themes')
|
walk_files = util.walk_files
|
||||||
theme_cache_path = os.path.join(theme_cache_dir, f'{theme_name.replace("/", "_")}.json')
|
ldm_print = util.ldm_print
|
||||||
if opts.gradio_themes_cache and os.path.exists(theme_cache_path):
|
|
||||||
gradio_theme = gr.themes.ThemeClass.load(theme_cache_path)
|
|
||||||
else:
|
|
||||||
os.makedirs(theme_cache_dir, exist_ok=True)
|
|
||||||
gradio_theme = gr.themes.ThemeClass.from_hub(theme_name)
|
|
||||||
gradio_theme.dump(theme_cache_path)
|
|
||||||
except Exception as e:
|
|
||||||
errors.display(e, "changing gradio theme")
|
|
||||||
gradio_theme = gr.themes.Default(**default_theme_args)
|
|
||||||
|
|
||||||
|
reload_gradio_theme = shared_gradio_themes.reload_gradio_theme
|
||||||
|
|
||||||
class TotalTQDM:
|
list_checkpoint_tiles = shared_items.list_checkpoint_tiles
|
||||||
def __init__(self):
|
refresh_checkpoints = shared_items.refresh_checkpoints
|
||||||
self._tqdm = None
|
list_samplers = shared_items.list_samplers
|
||||||
|
reload_hypernetworks = shared_items.reload_hypernetworks
|
||||||
def reset(self):
|
|
||||||
self._tqdm = tqdm.tqdm(
|
|
||||||
desc="Total progress",
|
|
||||||
total=state.job_count * state.sampling_steps,
|
|
||||||
position=1,
|
|
||||||
file=progress_print_out
|
|
||||||
)
|
|
||||||
|
|
||||||
def update(self):
|
|
||||||
if not opts.multiple_tqdm or cmd_opts.disable_console_progressbars:
|
|
||||||
return
|
|
||||||
if self._tqdm is None:
|
|
||||||
self.reset()
|
|
||||||
self._tqdm.update()
|
|
||||||
|
|
||||||
def updateTotal(self, new_total):
|
|
||||||
if not opts.multiple_tqdm or cmd_opts.disable_console_progressbars:
|
|
||||||
return
|
|
||||||
if self._tqdm is None:
|
|
||||||
self.reset()
|
|
||||||
self._tqdm.total = new_total
|
|
||||||
|
|
||||||
def clear(self):
|
|
||||||
if self._tqdm is not None:
|
|
||||||
self._tqdm.refresh()
|
|
||||||
self._tqdm.close()
|
|
||||||
self._tqdm = None
|
|
||||||
|
|
||||||
|
|
||||||
total_tqdm = TotalTQDM()
|
|
||||||
|
|
||||||
mem_mon = modules.memmon.MemUsageMonitor("MemMon", device, opts)
|
|
||||||
mem_mon.start()
|
|
||||||
|
|
||||||
|
|
||||||
def natural_sort_key(s, regex=re.compile('([0-9]+)')):
|
|
||||||
return [int(text) if text.isdigit() else text.lower() for text in regex.split(s)]
|
|
||||||
|
|
||||||
|
|
||||||
def listfiles(dirname):
|
|
||||||
filenames = [os.path.join(dirname, x) for x in sorted(os.listdir(dirname), key=natural_sort_key) if not x.startswith(".")]
|
|
||||||
return [file for file in filenames if os.path.isfile(file)]
|
|
||||||
|
|
||||||
|
|
||||||
def html_path(filename):
|
|
||||||
return os.path.join(script_path, "html", filename)
|
|
||||||
|
|
||||||
|
|
||||||
def html(filename):
|
|
||||||
path = html_path(filename)
|
|
||||||
|
|
||||||
if os.path.exists(path):
|
|
||||||
with open(path, encoding="utf8") as file:
|
|
||||||
return file.read()
|
|
||||||
|
|
||||||
return ""
|
|
||||||
|
|
||||||
|
|
||||||
def walk_files(path, allowed_extensions=None):
|
|
||||||
if not os.path.exists(path):
|
|
||||||
return
|
|
||||||
|
|
||||||
if allowed_extensions is not None:
|
|
||||||
allowed_extensions = set(allowed_extensions)
|
|
||||||
|
|
||||||
items = list(os.walk(path, followlinks=True))
|
|
||||||
items = sorted(items, key=lambda x: natural_sort_key(x[0]))
|
|
||||||
|
|
||||||
for root, _, files in items:
|
|
||||||
for filename in sorted(files, key=natural_sort_key):
|
|
||||||
if allowed_extensions is not None:
|
|
||||||
_, ext = os.path.splitext(filename)
|
|
||||||
if ext not in allowed_extensions:
|
|
||||||
continue
|
|
||||||
|
|
||||||
if not opts.list_hidden_files and ("/." in root or "\\." in root):
|
|
||||||
continue
|
|
||||||
|
|
||||||
yield os.path.join(root, filename)
|
|
||||||
|
|
||||||
|
|
||||||
def ldm_print(*args, **kwargs):
|
|
||||||
if opts.hide_ldm_prints:
|
|
||||||
return
|
|
||||||
|
|
||||||
print(*args, **kwargs)
|
|
||||||
|
|||||||
@@ -0,0 +1,18 @@
|
|||||||
|
import os
|
||||||
|
|
||||||
|
import launch
|
||||||
|
from modules import cmd_args, script_loading
|
||||||
|
from modules.paths_internal import models_path, script_path, data_path, sd_configs_path, sd_default_config, sd_model_file, default_sd_model_file, extensions_dir, extensions_builtin_dir # noqa: F401
|
||||||
|
|
||||||
|
parser = cmd_args.parser
|
||||||
|
|
||||||
|
script_loading.preload_extensions(extensions_dir, parser, extension_list=launch.list_extensions(launch.args.ui_settings_file))
|
||||||
|
script_loading.preload_extensions(extensions_builtin_dir, parser)
|
||||||
|
|
||||||
|
if os.environ.get('IGNORE_CMD_ARGS_ERRORS', None) is None:
|
||||||
|
cmd_opts = parser.parse_args()
|
||||||
|
else:
|
||||||
|
cmd_opts, _ = parser.parse_known_args()
|
||||||
|
|
||||||
|
|
||||||
|
cmd_opts.disable_extension_access = (cmd_opts.share or cmd_opts.listen or cmd_opts.server_name) and not cmd_opts.enable_insecure_extension_access
|
||||||
@@ -0,0 +1,67 @@
|
|||||||
|
import os
|
||||||
|
|
||||||
|
import gradio as gr
|
||||||
|
|
||||||
|
from modules import errors, shared
|
||||||
|
from modules.paths_internal import script_path
|
||||||
|
|
||||||
|
|
||||||
|
# https://huggingface.co/datasets/freddyaboulton/gradio-theme-subdomains/resolve/main/subdomains.json
|
||||||
|
gradio_hf_hub_themes = [
|
||||||
|
"gradio/base",
|
||||||
|
"gradio/glass",
|
||||||
|
"gradio/monochrome",
|
||||||
|
"gradio/seafoam",
|
||||||
|
"gradio/soft",
|
||||||
|
"gradio/dracula_test",
|
||||||
|
"abidlabs/dracula_test",
|
||||||
|
"abidlabs/Lime",
|
||||||
|
"abidlabs/pakistan",
|
||||||
|
"Ama434/neutral-barlow",
|
||||||
|
"dawood/microsoft_windows",
|
||||||
|
"finlaymacklon/smooth_slate",
|
||||||
|
"Franklisi/darkmode",
|
||||||
|
"freddyaboulton/dracula_revamped",
|
||||||
|
"freddyaboulton/test-blue",
|
||||||
|
"gstaff/xkcd",
|
||||||
|
"Insuz/Mocha",
|
||||||
|
"Insuz/SimpleIndigo",
|
||||||
|
"JohnSmith9982/small_and_pretty",
|
||||||
|
"nota-ai/theme",
|
||||||
|
"nuttea/Softblue",
|
||||||
|
"ParityError/Anime",
|
||||||
|
"reilnuud/polite",
|
||||||
|
"remilia/Ghostly",
|
||||||
|
"rottenlittlecreature/Moon_Goblin",
|
||||||
|
"step-3-profit/Midnight-Deep",
|
||||||
|
"Taithrah/Minimal",
|
||||||
|
"ysharma/huggingface",
|
||||||
|
"ysharma/steampunk",
|
||||||
|
"NoCrypt/miku"
|
||||||
|
]
|
||||||
|
|
||||||
|
|
||||||
|
def reload_gradio_theme(theme_name=None):
|
||||||
|
if not theme_name:
|
||||||
|
theme_name = shared.opts.gradio_theme
|
||||||
|
|
||||||
|
default_theme_args = dict(
|
||||||
|
font=["Source Sans Pro", 'ui-sans-serif', 'system-ui', 'sans-serif'],
|
||||||
|
font_mono=['IBM Plex Mono', 'ui-monospace', 'Consolas', 'monospace'],
|
||||||
|
)
|
||||||
|
|
||||||
|
if theme_name == "Default":
|
||||||
|
shared.gradio_theme = gr.themes.Default(**default_theme_args)
|
||||||
|
else:
|
||||||
|
try:
|
||||||
|
theme_cache_dir = os.path.join(script_path, 'tmp', 'gradio_themes')
|
||||||
|
theme_cache_path = os.path.join(theme_cache_dir, f'{theme_name.replace("/", "_")}.json')
|
||||||
|
if shared.opts.gradio_themes_cache and os.path.exists(theme_cache_path):
|
||||||
|
shared.gradio_theme = gr.themes.ThemeClass.load(theme_cache_path)
|
||||||
|
else:
|
||||||
|
os.makedirs(theme_cache_dir, exist_ok=True)
|
||||||
|
shared.gradio_theme = gr.themes.ThemeClass.from_hub(theme_name)
|
||||||
|
shared.gradio_theme.dump(theme_cache_path)
|
||||||
|
except Exception as e:
|
||||||
|
errors.display(e, "changing gradio theme")
|
||||||
|
shared.gradio_theme = gr.themes.Default(**default_theme_args)
|
||||||
@@ -0,0 +1,49 @@
|
|||||||
|
import os
|
||||||
|
|
||||||
|
import torch
|
||||||
|
|
||||||
|
from modules import shared
|
||||||
|
from modules.shared import cmd_opts
|
||||||
|
|
||||||
|
|
||||||
|
def initialize():
|
||||||
|
"""Initializes fields inside the shared module in a controlled manner.
|
||||||
|
|
||||||
|
Should be called early because some other modules you can import mingt need these fields to be already set.
|
||||||
|
"""
|
||||||
|
|
||||||
|
os.makedirs(cmd_opts.hypernetwork_dir, exist_ok=True)
|
||||||
|
|
||||||
|
from modules import options, shared_options
|
||||||
|
shared.options_templates = shared_options.options_templates
|
||||||
|
shared.opts = options.Options(shared_options.options_templates, shared_options.restricted_opts)
|
||||||
|
shared.restricted_opts = shared_options.restricted_opts
|
||||||
|
if os.path.exists(shared.config_filename):
|
||||||
|
shared.opts.load(shared.config_filename)
|
||||||
|
|
||||||
|
from modules import devices
|
||||||
|
devices.device, devices.device_interrogate, devices.device_gfpgan, devices.device_esrgan, devices.device_codeformer = \
|
||||||
|
(devices.cpu if any(y in cmd_opts.use_cpu for y in [x, 'all']) else devices.get_optimal_device() for x in ['sd', 'interrogate', 'gfpgan', 'esrgan', 'codeformer'])
|
||||||
|
|
||||||
|
devices.dtype = torch.float32 if cmd_opts.no_half else torch.float16
|
||||||
|
devices.dtype_vae = torch.float32 if cmd_opts.no_half or cmd_opts.no_half_vae else torch.float16
|
||||||
|
|
||||||
|
shared.device = devices.device
|
||||||
|
shared.weight_load_location = None if cmd_opts.lowram else "cpu"
|
||||||
|
|
||||||
|
from modules import shared_state
|
||||||
|
shared.state = shared_state.State()
|
||||||
|
|
||||||
|
from modules import styles
|
||||||
|
shared.prompt_styles = styles.StyleDatabase(shared.styles_filename)
|
||||||
|
|
||||||
|
from modules import interrogate
|
||||||
|
shared.interrogator = interrogate.InterrogateModels("interrogate")
|
||||||
|
|
||||||
|
from modules import shared_total_tqdm
|
||||||
|
shared.total_tqdm = shared_total_tqdm.TotalTQDM()
|
||||||
|
|
||||||
|
from modules import memmon, devices
|
||||||
|
shared.mem_mon = memmon.MemUsageMonitor("MemMon", devices.device, shared.opts)
|
||||||
|
shared.mem_mon.start()
|
||||||
|
|
||||||
+52
-2
@@ -1,3 +1,6 @@
|
|||||||
|
import sys
|
||||||
|
|
||||||
|
from modules.shared_cmd_options import cmd_opts
|
||||||
|
|
||||||
|
|
||||||
def realesrgan_models_names():
|
def realesrgan_models_names():
|
||||||
@@ -41,13 +44,36 @@ def refresh_unet_list():
|
|||||||
modules.sd_unet.list_unets()
|
modules.sd_unet.list_unets()
|
||||||
|
|
||||||
|
|
||||||
|
def list_checkpoint_tiles():
|
||||||
|
import modules.sd_models
|
||||||
|
return modules.sd_models.checkpoint_tiles()
|
||||||
|
|
||||||
|
|
||||||
|
def refresh_checkpoints():
|
||||||
|
import modules.sd_models
|
||||||
|
return modules.sd_models.list_models()
|
||||||
|
|
||||||
|
|
||||||
|
def list_samplers():
|
||||||
|
import modules.sd_samplers
|
||||||
|
return modules.sd_samplers.all_samplers
|
||||||
|
|
||||||
|
|
||||||
|
def reload_hypernetworks():
|
||||||
|
from modules.hypernetworks import hypernetwork
|
||||||
|
from modules import shared
|
||||||
|
|
||||||
|
shared.hypernetworks = hypernetwork.list_hypernetworks(cmd_opts.hypernetwork_dir)
|
||||||
|
|
||||||
|
|
||||||
ui_reorder_categories_builtin_items = [
|
ui_reorder_categories_builtin_items = [
|
||||||
"inpaint",
|
"inpaint",
|
||||||
"sampler",
|
"sampler",
|
||||||
|
"accordions",
|
||||||
"checkboxes",
|
"checkboxes",
|
||||||
"hires_fix",
|
|
||||||
"dimensions",
|
"dimensions",
|
||||||
"cfg",
|
"cfg",
|
||||||
|
"denoising",
|
||||||
"seed",
|
"seed",
|
||||||
"batch",
|
"batch",
|
||||||
"override_settings",
|
"override_settings",
|
||||||
@@ -61,9 +87,33 @@ def ui_reorder_categories():
|
|||||||
|
|
||||||
sections = {}
|
sections = {}
|
||||||
for script in scripts.scripts_txt2img.scripts + scripts.scripts_img2img.scripts:
|
for script in scripts.scripts_txt2img.scripts + scripts.scripts_img2img.scripts:
|
||||||
if isinstance(script.section, str):
|
if isinstance(script.section, str) and script.section not in ui_reorder_categories_builtin_items:
|
||||||
sections[script.section] = 1
|
sections[script.section] = 1
|
||||||
|
|
||||||
yield from sections
|
yield from sections
|
||||||
|
|
||||||
yield "scripts"
|
yield "scripts"
|
||||||
|
|
||||||
|
|
||||||
|
class Shared(sys.modules[__name__].__class__):
|
||||||
|
"""
|
||||||
|
this class is here to provide sd_model field as a property, so that it can be created and loaded on demand rather than
|
||||||
|
at program startup.
|
||||||
|
"""
|
||||||
|
|
||||||
|
sd_model_val = None
|
||||||
|
|
||||||
|
@property
|
||||||
|
def sd_model(self):
|
||||||
|
import modules.sd_models
|
||||||
|
|
||||||
|
return modules.sd_models.model_data.get_sd_model()
|
||||||
|
|
||||||
|
@sd_model.setter
|
||||||
|
def sd_model(self, value):
|
||||||
|
import modules.sd_models
|
||||||
|
|
||||||
|
modules.sd_models.model_data.set_sd_model(value)
|
||||||
|
|
||||||
|
|
||||||
|
sys.modules['modules.shared'].__class__ = Shared
|
||||||
|
|||||||
@@ -0,0 +1,330 @@
|
|||||||
|
import gradio as gr
|
||||||
|
|
||||||
|
from modules import localization, ui_components, shared_items, shared, interrogate, shared_gradio_themes
|
||||||
|
from modules.paths_internal import models_path, script_path, data_path, sd_configs_path, sd_default_config, sd_model_file, default_sd_model_file, extensions_dir, extensions_builtin_dir # noqa: F401
|
||||||
|
from modules.shared_cmd_options import cmd_opts
|
||||||
|
from modules.options import options_section, OptionInfo, OptionHTML
|
||||||
|
|
||||||
|
options_templates = {}
|
||||||
|
hide_dirs = shared.hide_dirs
|
||||||
|
|
||||||
|
restricted_opts = {
|
||||||
|
"samples_filename_pattern",
|
||||||
|
"directories_filename_pattern",
|
||||||
|
"outdir_samples",
|
||||||
|
"outdir_txt2img_samples",
|
||||||
|
"outdir_img2img_samples",
|
||||||
|
"outdir_extras_samples",
|
||||||
|
"outdir_grids",
|
||||||
|
"outdir_txt2img_grids",
|
||||||
|
"outdir_save",
|
||||||
|
"outdir_init_images"
|
||||||
|
}
|
||||||
|
|
||||||
|
options_templates.update(options_section(('saving-images', "Saving images/grids"), {
|
||||||
|
"samples_save": OptionInfo(True, "Always save all generated images"),
|
||||||
|
"samples_format": OptionInfo('png', 'File format for images'),
|
||||||
|
"samples_filename_pattern": OptionInfo("", "Images filename pattern", component_args=hide_dirs).link("wiki", "https://github.com/AUTOMATIC1111/stable-diffusion-webui/wiki/Custom-Images-Filename-Name-and-Subdirectory"),
|
||||||
|
"save_images_add_number": OptionInfo(True, "Add number to filename when saving", component_args=hide_dirs),
|
||||||
|
|
||||||
|
"grid_save": OptionInfo(True, "Always save all generated image grids"),
|
||||||
|
"grid_format": OptionInfo('png', 'File format for grids'),
|
||||||
|
"grid_extended_filename": OptionInfo(False, "Add extended info (seed, prompt) to filename when saving grid"),
|
||||||
|
"grid_only_if_multiple": OptionInfo(True, "Do not save grids consisting of one picture"),
|
||||||
|
"grid_prevent_empty_spots": OptionInfo(False, "Prevent empty spots in grid (when set to autodetect)"),
|
||||||
|
"grid_zip_filename_pattern": OptionInfo("", "Archive filename pattern", component_args=hide_dirs).link("wiki", "https://github.com/AUTOMATIC1111/stable-diffusion-webui/wiki/Custom-Images-Filename-Name-and-Subdirectory"),
|
||||||
|
"n_rows": OptionInfo(-1, "Grid row count; use -1 for autodetect and 0 for it to be same as batch size", gr.Slider, {"minimum": -1, "maximum": 16, "step": 1}),
|
||||||
|
"font": OptionInfo("", "Font for image grids that have text"),
|
||||||
|
"grid_text_active_color": OptionInfo("#000000", "Text color for image grids", ui_components.FormColorPicker, {}),
|
||||||
|
"grid_text_inactive_color": OptionInfo("#999999", "Inactive text color for image grids", ui_components.FormColorPicker, {}),
|
||||||
|
"grid_background_color": OptionInfo("#ffffff", "Background color for image grids", ui_components.FormColorPicker, {}),
|
||||||
|
|
||||||
|
"enable_pnginfo": OptionInfo(True, "Save text information about generation parameters as chunks to png files"),
|
||||||
|
"save_txt": OptionInfo(False, "Create a text file next to every image with generation parameters."),
|
||||||
|
"save_images_before_face_restoration": OptionInfo(False, "Save a copy of image before doing face restoration."),
|
||||||
|
"save_images_before_highres_fix": OptionInfo(False, "Save a copy of image before applying highres fix."),
|
||||||
|
"save_images_before_color_correction": OptionInfo(False, "Save a copy of image before applying color correction to img2img results"),
|
||||||
|
"save_mask": OptionInfo(False, "For inpainting, save a copy of the greyscale mask"),
|
||||||
|
"save_mask_composite": OptionInfo(False, "For inpainting, save a masked composite"),
|
||||||
|
"jpeg_quality": OptionInfo(80, "Quality for saved jpeg images", gr.Slider, {"minimum": 1, "maximum": 100, "step": 1}),
|
||||||
|
"webp_lossless": OptionInfo(False, "Use lossless compression for webp images"),
|
||||||
|
"export_for_4chan": OptionInfo(True, "Save copy of large images as JPG").info("if the file size is above the limit, or either width or height are above the limit"),
|
||||||
|
"img_downscale_threshold": OptionInfo(4.0, "File size limit for the above option, MB", gr.Number),
|
||||||
|
"target_side_length": OptionInfo(4000, "Width/height limit for the above option, in pixels", gr.Number),
|
||||||
|
"img_max_size_mp": OptionInfo(200, "Maximum image size", gr.Number).info("in megapixels"),
|
||||||
|
|
||||||
|
"use_original_name_batch": OptionInfo(True, "Use original name for output filename during batch process in extras tab"),
|
||||||
|
"use_upscaler_name_as_suffix": OptionInfo(False, "Use upscaler name as filename suffix in the extras tab"),
|
||||||
|
"save_selected_only": OptionInfo(True, "When using 'Save' button, only save a single selected image"),
|
||||||
|
"save_init_img": OptionInfo(False, "Save init images when using img2img"),
|
||||||
|
|
||||||
|
"temp_dir": OptionInfo("", "Directory for temporary images; leave empty for default"),
|
||||||
|
"clean_temp_dir_at_start": OptionInfo(False, "Cleanup non-default temporary directory when starting webui"),
|
||||||
|
|
||||||
|
"save_incomplete_images": OptionInfo(False, "Save incomplete images").info("save images that has been interrupted in mid-generation; even if not saved, they will still show up in webui output."),
|
||||||
|
}))
|
||||||
|
|
||||||
|
options_templates.update(options_section(('saving-paths', "Paths for saving"), {
|
||||||
|
"outdir_samples": OptionInfo("", "Output directory for images; if empty, defaults to three directories below", component_args=hide_dirs),
|
||||||
|
"outdir_txt2img_samples": OptionInfo("outputs/txt2img-images", 'Output directory for txt2img images', component_args=hide_dirs),
|
||||||
|
"outdir_img2img_samples": OptionInfo("outputs/img2img-images", 'Output directory for img2img images', component_args=hide_dirs),
|
||||||
|
"outdir_extras_samples": OptionInfo("outputs/extras-images", 'Output directory for images from extras tab', component_args=hide_dirs),
|
||||||
|
"outdir_grids": OptionInfo("", "Output directory for grids; if empty, defaults to two directories below", component_args=hide_dirs),
|
||||||
|
"outdir_txt2img_grids": OptionInfo("outputs/txt2img-grids", 'Output directory for txt2img grids', component_args=hide_dirs),
|
||||||
|
"outdir_img2img_grids": OptionInfo("outputs/img2img-grids", 'Output directory for img2img grids', component_args=hide_dirs),
|
||||||
|
"outdir_save": OptionInfo("log/images", "Directory for saving images using the Save button", component_args=hide_dirs),
|
||||||
|
"outdir_init_images": OptionInfo("outputs/init-images", "Directory for saving init images when using img2img", component_args=hide_dirs),
|
||||||
|
}))
|
||||||
|
|
||||||
|
options_templates.update(options_section(('saving-to-dirs', "Saving to a directory"), {
|
||||||
|
"save_to_dirs": OptionInfo(True, "Save images to a subdirectory"),
|
||||||
|
"grid_save_to_dirs": OptionInfo(True, "Save grids to a subdirectory"),
|
||||||
|
"use_save_to_dirs_for_ui": OptionInfo(False, "When using \"Save\" button, save images to a subdirectory"),
|
||||||
|
"directories_filename_pattern": OptionInfo("[date]", "Directory name pattern", component_args=hide_dirs).link("wiki", "https://github.com/AUTOMATIC1111/stable-diffusion-webui/wiki/Custom-Images-Filename-Name-and-Subdirectory"),
|
||||||
|
"directories_max_prompt_words": OptionInfo(8, "Max prompt words for [prompt_words] pattern", gr.Slider, {"minimum": 1, "maximum": 20, "step": 1, **hide_dirs}),
|
||||||
|
}))
|
||||||
|
|
||||||
|
options_templates.update(options_section(('upscaling', "Upscaling"), {
|
||||||
|
"ESRGAN_tile": OptionInfo(192, "Tile size for ESRGAN upscalers.", gr.Slider, {"minimum": 0, "maximum": 512, "step": 16}).info("0 = no tiling"),
|
||||||
|
"ESRGAN_tile_overlap": OptionInfo(8, "Tile overlap for ESRGAN upscalers.", gr.Slider, {"minimum": 0, "maximum": 48, "step": 1}).info("Low values = visible seam"),
|
||||||
|
"realesrgan_enabled_models": OptionInfo(["R-ESRGAN 4x+", "R-ESRGAN 4x+ Anime6B"], "Select which Real-ESRGAN models to show in the web UI.", gr.CheckboxGroup, lambda: {"choices": shared_items.realesrgan_models_names()}),
|
||||||
|
"upscaler_for_img2img": OptionInfo(None, "Upscaler for img2img", gr.Dropdown, lambda: {"choices": [x.name for x in shared.sd_upscalers]}),
|
||||||
|
}))
|
||||||
|
|
||||||
|
options_templates.update(options_section(('face-restoration', "Face restoration"), {
|
||||||
|
"face_restoration": OptionInfo(False, "Restore faces", infotext='Face restoration').info("will use a third-party model on generation result to reconstruct faces"),
|
||||||
|
"face_restoration_model": OptionInfo("CodeFormer", "Face restoration model", gr.Radio, lambda: {"choices": [x.name() for x in shared.face_restorers]}),
|
||||||
|
"code_former_weight": OptionInfo(0.5, "CodeFormer weight", gr.Slider, {"minimum": 0, "maximum": 1, "step": 0.01}).info("0 = maximum effect; 1 = minimum effect"),
|
||||||
|
"face_restoration_unload": OptionInfo(False, "Move face restoration model from VRAM into RAM after processing"),
|
||||||
|
}))
|
||||||
|
|
||||||
|
options_templates.update(options_section(('system', "System"), {
|
||||||
|
"auto_launch_browser": OptionInfo("Local", "Automatically open webui in browser on startup", gr.Radio, lambda: {"choices": ["Disable", "Local", "Remote"]}),
|
||||||
|
"show_warnings": OptionInfo(False, "Show warnings in console.").needs_reload_ui(),
|
||||||
|
"show_gradio_deprecation_warnings": OptionInfo(True, "Show gradio deprecation warnings in console.").needs_reload_ui(),
|
||||||
|
"memmon_poll_rate": OptionInfo(8, "VRAM usage polls per second during generation.", gr.Slider, {"minimum": 0, "maximum": 40, "step": 1}).info("0 = disable"),
|
||||||
|
"samples_log_stdout": OptionInfo(False, "Always print all generation info to standard output"),
|
||||||
|
"multiple_tqdm": OptionInfo(True, "Add a second progress bar to the console that shows progress for an entire job."),
|
||||||
|
"print_hypernet_extra": OptionInfo(False, "Print extra hypernetwork information to console."),
|
||||||
|
"list_hidden_files": OptionInfo(True, "Load models/files in hidden directories").info("directory is hidden if its name starts with \".\""),
|
||||||
|
"disable_mmap_load_safetensors": OptionInfo(False, "Disable memmapping for loading .safetensors files.").info("fixes very slow loading speed in some cases"),
|
||||||
|
"hide_ldm_prints": OptionInfo(True, "Prevent Stability-AI's ldm/sgm modules from printing noise to console."),
|
||||||
|
}))
|
||||||
|
|
||||||
|
options_templates.update(options_section(('API', "API"), {
|
||||||
|
"api_enable_requests": OptionInfo(True, "Allow http:// and https:// URLs for input images in API", restrict_api=True),
|
||||||
|
"api_forbid_local_requests": OptionInfo(True, "Forbid URLs to local resources", restrict_api=True),
|
||||||
|
"api_useragent": OptionInfo("", "User agent for requests", restrict_api=True),
|
||||||
|
}))
|
||||||
|
|
||||||
|
options_templates.update(options_section(('training', "Training"), {
|
||||||
|
"unload_models_when_training": OptionInfo(False, "Move VAE and CLIP to RAM when training if possible. Saves VRAM."),
|
||||||
|
"pin_memory": OptionInfo(False, "Turn on pin_memory for DataLoader. Makes training slightly faster but can increase memory usage."),
|
||||||
|
"save_optimizer_state": OptionInfo(False, "Saves Optimizer state as separate *.optim file. Training of embedding or HN can be resumed with the matching optim file."),
|
||||||
|
"save_training_settings_to_txt": OptionInfo(True, "Save textual inversion and hypernet settings to a text file whenever training starts."),
|
||||||
|
"dataset_filename_word_regex": OptionInfo("", "Filename word regex"),
|
||||||
|
"dataset_filename_join_string": OptionInfo(" ", "Filename join string"),
|
||||||
|
"training_image_repeats_per_epoch": OptionInfo(1, "Number of repeats for a single input image per epoch; used only for displaying epoch number", gr.Number, {"precision": 0}),
|
||||||
|
"training_write_csv_every": OptionInfo(500, "Save an csv containing the loss to log directory every N steps, 0 to disable"),
|
||||||
|
"training_xattention_optimizations": OptionInfo(False, "Use cross attention optimizations while training"),
|
||||||
|
"training_enable_tensorboard": OptionInfo(False, "Enable tensorboard logging."),
|
||||||
|
"training_tensorboard_save_images": OptionInfo(False, "Save generated images within tensorboard."),
|
||||||
|
"training_tensorboard_flush_every": OptionInfo(120, "How often, in seconds, to flush the pending tensorboard events and summaries to disk."),
|
||||||
|
}))
|
||||||
|
|
||||||
|
options_templates.update(options_section(('sd', "Stable Diffusion"), {
|
||||||
|
"sd_model_checkpoint": OptionInfo(None, "Stable Diffusion checkpoint", gr.Dropdown, lambda: {"choices": shared_items.list_checkpoint_tiles()}, refresh=shared_items.refresh_checkpoints, infotext='Model hash'),
|
||||||
|
"sd_checkpoints_limit": OptionInfo(1, "Maximum number of checkpoints loaded at the same time", gr.Slider, {"minimum": 1, "maximum": 10, "step": 1}),
|
||||||
|
"sd_checkpoints_keep_in_cpu": OptionInfo(True, "Only keep one model on device").info("will keep models other than the currently used one in RAM rather than VRAM"),
|
||||||
|
"sd_checkpoint_cache": OptionInfo(0, "Checkpoints to cache in RAM", gr.Slider, {"minimum": 0, "maximum": 10, "step": 1}).info("obsolete; set to 0 and use the two settings above instead"),
|
||||||
|
"sd_unet": OptionInfo("Automatic", "SD Unet", gr.Dropdown, lambda: {"choices": shared_items.sd_unet_items()}, refresh=shared_items.refresh_unet_list).info("choose Unet model: Automatic = use one with same filename as checkpoint; None = use Unet from checkpoint"),
|
||||||
|
"enable_quantization": OptionInfo(False, "Enable quantization in K samplers for sharper and cleaner results. This may change existing seeds").needs_reload_ui(),
|
||||||
|
"enable_emphasis": OptionInfo(True, "Enable emphasis").info("use (text) to make model pay more attention to text and [text] to make it pay less attention"),
|
||||||
|
"enable_batch_seeds": OptionInfo(True, "Make K-diffusion samplers produce same images in a batch as when making a single image"),
|
||||||
|
"comma_padding_backtrack": OptionInfo(20, "Prompt word wrap length limit", gr.Slider, {"minimum": 0, "maximum": 74, "step": 1}).info("in tokens - for texts shorter than specified, if they don't fit into 75 token limit, move them to the next 75 token chunk"),
|
||||||
|
"CLIP_stop_at_last_layers": OptionInfo(1, "Clip skip", gr.Slider, {"minimum": 1, "maximum": 12, "step": 1}, infotext="Clip skip").link("wiki", "https://github.com/AUTOMATIC1111/stable-diffusion-webui/wiki/Features#clip-skip").info("ignore last layers of CLIP network; 1 ignores none, 2 ignores one layer"),
|
||||||
|
"upcast_attn": OptionInfo(False, "Upcast cross attention layer to float32"),
|
||||||
|
"randn_source": OptionInfo("GPU", "Random number generator source.", gr.Radio, {"choices": ["GPU", "CPU", "NV"]}).info("changes seeds drastically; use CPU to produce the same picture across different videocard vendors; use NV to produce same picture as on NVidia videocards"),
|
||||||
|
"tiling": OptionInfo(False, "Tiling", infotext='Tiling').info("produce a tileable picture"),
|
||||||
|
}))
|
||||||
|
|
||||||
|
options_templates.update(options_section(('sdxl', "Stable Diffusion XL"), {
|
||||||
|
"sdxl_crop_top": OptionInfo(0, "crop top coordinate"),
|
||||||
|
"sdxl_crop_left": OptionInfo(0, "crop left coordinate"),
|
||||||
|
"sdxl_refiner_low_aesthetic_score": OptionInfo(2.5, "SDXL low aesthetic score", gr.Number).info("used for refiner model negative prompt"),
|
||||||
|
"sdxl_refiner_high_aesthetic_score": OptionInfo(6.0, "SDXL high aesthetic score", gr.Number).info("used for refiner model prompt"),
|
||||||
|
}))
|
||||||
|
|
||||||
|
options_templates.update(options_section(('vae', "VAE"), {
|
||||||
|
"sd_vae_explanation": OptionHTML("""
|
||||||
|
<abbr title='Variational autoencoder'>VAE</abbr> is a neural network that transforms a standard <abbr title='red/green/blue'>RGB</abbr>
|
||||||
|
image into latent space representation and back. Latent space representation is what stable diffusion is working on during sampling
|
||||||
|
(i.e. when the progress bar is between empty and full). For txt2img, VAE is used to create a resulting image after the sampling is finished.
|
||||||
|
For img2img, VAE is used to process user's input image before the sampling, and to create an image after sampling.
|
||||||
|
"""),
|
||||||
|
"sd_vae_checkpoint_cache": OptionInfo(0, "VAE Checkpoints to cache in RAM", gr.Slider, {"minimum": 0, "maximum": 10, "step": 1}),
|
||||||
|
"sd_vae": OptionInfo("Automatic", "SD VAE", gr.Dropdown, lambda: {"choices": shared_items.sd_vae_items()}, refresh=shared_items.refresh_vae_list, infotext='VAE').info("choose VAE model: Automatic = use one with same filename as checkpoint; None = use VAE from checkpoint"),
|
||||||
|
"sd_vae_overrides_per_model_preferences": OptionInfo(True, "Selected VAE overrides per-model preferences").info("you can set per-model VAE either by editing user metadata for checkpoints, or by making the VAE have same name as checkpoint"),
|
||||||
|
"auto_vae_precision": OptionInfo(True, "Automatically revert VAE to 32-bit floats").info("triggers when a tensor with NaNs is produced in VAE; disabling the option in this case will result in a black square image"),
|
||||||
|
"sd_vae_encode_method": OptionInfo("Full", "VAE type for encode", gr.Radio, {"choices": ["Full", "TAESD"]}, infotext='VAE Encoder').info("method to encode image to latent (use in img2img, hires-fix or inpaint mask)"),
|
||||||
|
"sd_vae_decode_method": OptionInfo("Full", "VAE type for decode", gr.Radio, {"choices": ["Full", "TAESD"]}, infotext='VAE Decoder').info("method to decode latent to image"),
|
||||||
|
}))
|
||||||
|
|
||||||
|
options_templates.update(options_section(('img2img', "img2img"), {
|
||||||
|
"inpainting_mask_weight": OptionInfo(1.0, "Inpainting conditioning mask strength", gr.Slider, {"minimum": 0.0, "maximum": 1.0, "step": 0.01}, infotext='Conditional mask weight'),
|
||||||
|
"initial_noise_multiplier": OptionInfo(1.0, "Noise multiplier for img2img", gr.Slider, {"minimum": 0.0, "maximum": 1.5, "step": 0.001}, infotext='Noise multiplier'),
|
||||||
|
"img2img_extra_noise": OptionInfo(0.0, "Extra noise multiplier for img2img and hires fix", gr.Slider, {"minimum": 0.0, "maximum": 1.0, "step": 0.01}, infotext='Extra noise').info("0 = disabled (default); should be lower than denoising strength"),
|
||||||
|
"img2img_color_correction": OptionInfo(False, "Apply color correction to img2img results to match original colors."),
|
||||||
|
"img2img_fix_steps": OptionInfo(False, "With img2img, do exactly the amount of steps the slider specifies.").info("normally you'd do less with less denoising"),
|
||||||
|
"img2img_background_color": OptionInfo("#ffffff", "With img2img, fill transparent parts of the input image with this color.", ui_components.FormColorPicker, {}),
|
||||||
|
"img2img_editor_height": OptionInfo(720, "Height of the image editor", gr.Slider, {"minimum": 80, "maximum": 1600, "step": 1}).info("in pixels").needs_reload_ui(),
|
||||||
|
"img2img_sketch_default_brush_color": OptionInfo("#ffffff", "Sketch initial brush color", ui_components.FormColorPicker, {}).info("default brush color of img2img sketch").needs_reload_ui(),
|
||||||
|
"img2img_inpaint_mask_brush_color": OptionInfo("#ffffff", "Inpaint mask brush color", ui_components.FormColorPicker, {}).info("brush color of inpaint mask").needs_reload_ui(),
|
||||||
|
"img2img_inpaint_sketch_default_brush_color": OptionInfo("#ffffff", "Inpaint sketch initial brush color", ui_components.FormColorPicker, {}).info("default brush color of img2img inpaint sketch").needs_reload_ui(),
|
||||||
|
"return_mask": OptionInfo(False, "For inpainting, include the greyscale mask in results for web"),
|
||||||
|
"return_mask_composite": OptionInfo(False, "For inpainting, include masked composite in results for web"),
|
||||||
|
}))
|
||||||
|
|
||||||
|
options_templates.update(options_section(('optimizations', "Optimizations"), {
|
||||||
|
"cross_attention_optimization": OptionInfo("Automatic", "Cross attention optimization", gr.Dropdown, lambda: {"choices": shared_items.cross_attention_optimizations()}),
|
||||||
|
"s_min_uncond": OptionInfo(0.0, "Negative Guidance minimum sigma", gr.Slider, {"minimum": 0.0, "maximum": 15.0, "step": 0.01}).link("PR", "https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/9177").info("skip negative prompt for some steps when the image is almost ready; 0=disable, higher=faster"),
|
||||||
|
"token_merging_ratio": OptionInfo(0.0, "Token merging ratio", gr.Slider, {"minimum": 0.0, "maximum": 0.9, "step": 0.1}, infotext='Token merging ratio').link("PR", "https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/9256").info("0=disable, higher=faster"),
|
||||||
|
"token_merging_ratio_img2img": OptionInfo(0.0, "Token merging ratio for img2img", gr.Slider, {"minimum": 0.0, "maximum": 0.9, "step": 0.1}).info("only applies if non-zero and overrides above"),
|
||||||
|
"token_merging_ratio_hr": OptionInfo(0.0, "Token merging ratio for high-res pass", gr.Slider, {"minimum": 0.0, "maximum": 0.9, "step": 0.1}, infotext='Token merging ratio hr').info("only applies if non-zero and overrides above"),
|
||||||
|
"pad_cond_uncond": OptionInfo(False, "Pad prompt/negative prompt to be same length", infotext='Pad conds').info("improves performance when prompt and negative prompt have different lengths; changes seeds"),
|
||||||
|
"persistent_cond_cache": OptionInfo(True, "Persistent cond cache").info("do not recalculate conds from prompts if prompts have not changed since previous calculation"),
|
||||||
|
"batch_cond_uncond": OptionInfo(True, "Batch cond/uncond").info("do both conditional and unconditional denoising in one batch; uses a bit more VRAM during sampling, but improves speed; previously this was controlled by --always-batch-cond-uncond comandline argument"),
|
||||||
|
}))
|
||||||
|
|
||||||
|
options_templates.update(options_section(('compatibility', "Compatibility"), {
|
||||||
|
"use_old_emphasis_implementation": OptionInfo(False, "Use old emphasis implementation. Can be useful to reproduce old seeds."),
|
||||||
|
"use_old_karras_scheduler_sigmas": OptionInfo(False, "Use old karras scheduler sigmas (0.1 to 10)."),
|
||||||
|
"no_dpmpp_sde_batch_determinism": OptionInfo(False, "Do not make DPM++ SDE deterministic across different batch sizes."),
|
||||||
|
"use_old_hires_fix_width_height": OptionInfo(False, "For hires fix, use width/height sliders to set final resolution rather than first pass (disables Upscale by, Resize width/height to)."),
|
||||||
|
"dont_fix_second_order_samplers_schedule": OptionInfo(False, "Do not fix prompt schedule for second order samplers."),
|
||||||
|
"hires_fix_use_firstpass_conds": OptionInfo(False, "For hires fix, calculate conds of second pass using extra networks of first pass."),
|
||||||
|
"use_old_scheduling": OptionInfo(False, "Use old prompt editing timelines.", infotext="Old prompt editing timelines").info("For [red:green:N]; old: If N < 1, it's a fraction of steps (and hires fix uses range from 0 to 1), if N >= 1, it's an absolute number of steps; new: If N has a decimal point in it, it's a fraction of steps (and hires fix uses range from 1 to 2), othewrwise it's an absolute number of steps"),
|
||||||
|
}))
|
||||||
|
|
||||||
|
options_templates.update(options_section(('interrogate', "Interrogate"), {
|
||||||
|
"interrogate_keep_models_in_memory": OptionInfo(False, "Keep models in VRAM"),
|
||||||
|
"interrogate_return_ranks": OptionInfo(False, "Include ranks of model tags matches in results.").info("booru only"),
|
||||||
|
"interrogate_clip_num_beams": OptionInfo(1, "BLIP: num_beams", gr.Slider, {"minimum": 1, "maximum": 16, "step": 1}),
|
||||||
|
"interrogate_clip_min_length": OptionInfo(24, "BLIP: minimum description length", gr.Slider, {"minimum": 1, "maximum": 128, "step": 1}),
|
||||||
|
"interrogate_clip_max_length": OptionInfo(48, "BLIP: maximum description length", gr.Slider, {"minimum": 1, "maximum": 256, "step": 1}),
|
||||||
|
"interrogate_clip_dict_limit": OptionInfo(1500, "CLIP: maximum number of lines in text file").info("0 = No limit"),
|
||||||
|
"interrogate_clip_skip_categories": OptionInfo([], "CLIP: skip inquire categories", gr.CheckboxGroup, lambda: {"choices": interrogate.category_types()}, refresh=interrogate.category_types),
|
||||||
|
"interrogate_deepbooru_score_threshold": OptionInfo(0.5, "deepbooru: score threshold", gr.Slider, {"minimum": 0, "maximum": 1, "step": 0.01}),
|
||||||
|
"deepbooru_sort_alpha": OptionInfo(True, "deepbooru: sort tags alphabetically").info("if not: sort by score"),
|
||||||
|
"deepbooru_use_spaces": OptionInfo(True, "deepbooru: use spaces in tags").info("if not: use underscores"),
|
||||||
|
"deepbooru_escape": OptionInfo(True, "deepbooru: escape (\\) brackets").info("so they are used as literal brackets and not for emphasis"),
|
||||||
|
"deepbooru_filter_tags": OptionInfo("", "deepbooru: filter out those tags").info("separate by comma"),
|
||||||
|
}))
|
||||||
|
|
||||||
|
options_templates.update(options_section(('extra_networks', "Extra Networks"), {
|
||||||
|
"extra_networks_show_hidden_directories": OptionInfo(True, "Show hidden directories").info("directory is hidden if its name starts with \".\"."),
|
||||||
|
"extra_networks_hidden_models": OptionInfo("When searched", "Show cards for models in hidden directories", gr.Radio, {"choices": ["Always", "When searched", "Never"]}).info('"When searched" option will only show the item when the search string has 4 characters or more'),
|
||||||
|
"extra_networks_default_multiplier": OptionInfo(1.0, "Default multiplier for extra networks", gr.Slider, {"minimum": 0.0, "maximum": 2.0, "step": 0.01}),
|
||||||
|
"extra_networks_card_width": OptionInfo(0, "Card width for Extra Networks").info("in pixels"),
|
||||||
|
"extra_networks_card_height": OptionInfo(0, "Card height for Extra Networks").info("in pixels"),
|
||||||
|
"extra_networks_card_text_scale": OptionInfo(1.0, "Card text scale", gr.Slider, {"minimum": 0.0, "maximum": 2.0, "step": 0.01}).info("1 = original size"),
|
||||||
|
"extra_networks_card_show_desc": OptionInfo(True, "Show description on card"),
|
||||||
|
"extra_networks_add_text_separator": OptionInfo(" ", "Extra networks separator").info("extra text to add before <...> when adding extra network to prompt"),
|
||||||
|
"ui_extra_networks_tab_reorder": OptionInfo("", "Extra networks tab order").needs_reload_ui(),
|
||||||
|
"textual_inversion_print_at_load": OptionInfo(False, "Print a list of Textual Inversion embeddings when loading model"),
|
||||||
|
"textual_inversion_add_hashes_to_infotext": OptionInfo(True, "Add Textual Inversion hashes to infotext"),
|
||||||
|
"sd_hypernetwork": OptionInfo("None", "Add hypernetwork to prompt", gr.Dropdown, lambda: {"choices": ["None", *shared.hypernetworks]}, refresh=shared_items.reload_hypernetworks),
|
||||||
|
}))
|
||||||
|
|
||||||
|
options_templates.update(options_section(('ui', "User interface"), {
|
||||||
|
"localization": OptionInfo("None", "Localization", gr.Dropdown, lambda: {"choices": ["None"] + list(localization.localizations.keys())}, refresh=lambda: localization.list_localizations(cmd_opts.localizations_dir)).needs_reload_ui(),
|
||||||
|
"gradio_theme": OptionInfo("Default", "Gradio theme", ui_components.DropdownEditable, lambda: {"choices": ["Default"] + shared_gradio_themes.gradio_hf_hub_themes}).info("you can also manually enter any of themes from the <a href='https://huggingface.co/spaces/gradio/theme-gallery'>gallery</a>.").needs_reload_ui(),
|
||||||
|
"gradio_themes_cache": OptionInfo(True, "Cache gradio themes locally").info("disable to update the selected Gradio theme"),
|
||||||
|
"gallery_height": OptionInfo("", "Gallery height", gr.Textbox).info("an be any valid CSS value").needs_reload_ui(),
|
||||||
|
"return_grid": OptionInfo(True, "Show grid in results for web"),
|
||||||
|
"do_not_show_images": OptionInfo(False, "Do not show any images in results for web"),
|
||||||
|
"send_seed": OptionInfo(True, "Send seed when sending prompt or image to other interface"),
|
||||||
|
"send_size": OptionInfo(True, "Send size when sending prompt or image to another interface"),
|
||||||
|
"js_modal_lightbox": OptionInfo(True, "Enable full page image viewer"),
|
||||||
|
"js_modal_lightbox_initially_zoomed": OptionInfo(True, "Show images zoomed in by default in full page image viewer"),
|
||||||
|
"js_modal_lightbox_gamepad": OptionInfo(False, "Navigate image viewer with gamepad"),
|
||||||
|
"js_modal_lightbox_gamepad_repeat": OptionInfo(250, "Gamepad repeat period, in milliseconds"),
|
||||||
|
"show_progress_in_title": OptionInfo(True, "Show generation progress in window title."),
|
||||||
|
"samplers_in_dropdown": OptionInfo(True, "Use dropdown for sampler selection instead of radio group").needs_reload_ui(),
|
||||||
|
"dimensions_and_batch_together": OptionInfo(True, "Show Width/Height and Batch sliders in same row").needs_reload_ui(),
|
||||||
|
"keyedit_precision_attention": OptionInfo(0.1, "Ctrl+up/down precision when editing (attention:1.1)", gr.Slider, {"minimum": 0.01, "maximum": 0.2, "step": 0.001}),
|
||||||
|
"keyedit_precision_extra": OptionInfo(0.05, "Ctrl+up/down precision when editing <extra networks:0.9>", gr.Slider, {"minimum": 0.01, "maximum": 0.2, "step": 0.001}),
|
||||||
|
"keyedit_delimiters": OptionInfo(".,\\/!?%^*;:{}=`~()", "Ctrl+up/down word delimiters"),
|
||||||
|
"keyedit_move": OptionInfo(True, "Alt+left/right moves prompt elements"),
|
||||||
|
"quicksettings_list": OptionInfo(["sd_model_checkpoint"], "Quicksettings list", ui_components.DropdownMulti, lambda: {"choices": list(shared.opts.data_labels.keys())}).js("info", "settingsHintsShowQuicksettings").info("setting entries that appear at the top of page rather than in settings tab").needs_reload_ui(),
|
||||||
|
"ui_tab_order": OptionInfo([], "UI tab order", ui_components.DropdownMulti, lambda: {"choices": list(shared.tab_names)}).needs_reload_ui(),
|
||||||
|
"hidden_tabs": OptionInfo([], "Hidden UI tabs", ui_components.DropdownMulti, lambda: {"choices": list(shared.tab_names)}).needs_reload_ui(),
|
||||||
|
"ui_reorder_list": OptionInfo([], "txt2img/img2img UI item order", ui_components.DropdownMulti, lambda: {"choices": list(shared_items.ui_reorder_categories())}).info("selected items appear first").needs_reload_ui(),
|
||||||
|
"hires_fix_show_sampler": OptionInfo(False, "Hires fix: show hires checkpoint and sampler selection").needs_reload_ui(),
|
||||||
|
"hires_fix_show_prompts": OptionInfo(False, "Hires fix: show hires prompt and negative prompt").needs_reload_ui(),
|
||||||
|
"disable_token_counters": OptionInfo(False, "Disable prompt token counters").needs_reload_ui(),
|
||||||
|
}))
|
||||||
|
|
||||||
|
|
||||||
|
options_templates.update(options_section(('infotext', "Infotext"), {
|
||||||
|
"add_model_hash_to_info": OptionInfo(True, "Add model hash to generation information"),
|
||||||
|
"add_model_name_to_info": OptionInfo(True, "Add model name to generation information"),
|
||||||
|
"add_user_name_to_info": OptionInfo(False, "Add user name to generation information when authenticated"),
|
||||||
|
"add_version_to_infotext": OptionInfo(True, "Add program version to generation information"),
|
||||||
|
"disable_weights_auto_swap": OptionInfo(True, "Disregard checkpoint information from pasted infotext").info("when reading generation parameters from text into UI"),
|
||||||
|
"infotext_styles": OptionInfo("Apply if any", "Infer styles from prompts of pasted infotext", gr.Radio, {"choices": ["Ignore", "Apply", "Discard", "Apply if any"]}).info("when reading generation parameters from text into UI)").html("""<ul style='margin-left: 1.5em'>
|
||||||
|
<li>Ignore: keep prompt and styles dropdown as it is.</li>
|
||||||
|
<li>Apply: remove style text from prompt, always replace styles dropdown value with found styles (even if none are found).</li>
|
||||||
|
<li>Discard: remove style text from prompt, keep styles dropdown as it is.</li>
|
||||||
|
<li>Apply if any: remove style text from prompt; if any styles are found in prompt, put them into styles dropdown, otherwise keep it as it is.</li>
|
||||||
|
</ul>"""),
|
||||||
|
|
||||||
|
}))
|
||||||
|
|
||||||
|
options_templates.update(options_section(('ui', "Live previews"), {
|
||||||
|
"show_progressbar": OptionInfo(True, "Show progressbar"),
|
||||||
|
"live_previews_enable": OptionInfo(True, "Show live previews of the created image"),
|
||||||
|
"live_previews_image_format": OptionInfo("png", "Live preview file format", gr.Radio, {"choices": ["jpeg", "png", "webp"]}),
|
||||||
|
"show_progress_grid": OptionInfo(True, "Show previews of all images generated in a batch as a grid"),
|
||||||
|
"show_progress_every_n_steps": OptionInfo(10, "Live preview display period", gr.Slider, {"minimum": -1, "maximum": 32, "step": 1}).info("in sampling steps - show new live preview image every N sampling steps; -1 = only show after completion of batch"),
|
||||||
|
"show_progress_type": OptionInfo("Approx NN", "Live preview method", gr.Radio, {"choices": ["Full", "Approx NN", "Approx cheap", "TAESD"]}).info("Full = slow but pretty; Approx NN and TAESD = fast but low quality; Approx cheap = super fast but terrible otherwise"),
|
||||||
|
"live_preview_allow_lowvram_full": OptionInfo(False, "Allow Full live preview method with lowvram/medvram").info("If not, Approx NN will be used instead; Full live preview method is very detrimental to speed if lowvram/medvram optimizations are enabled"),
|
||||||
|
"live_preview_content": OptionInfo("Prompt", "Live preview subject", gr.Radio, {"choices": ["Combined", "Prompt", "Negative prompt"]}),
|
||||||
|
"live_preview_refresh_period": OptionInfo(1000, "Progressbar and preview update period").info("in milliseconds"),
|
||||||
|
"live_preview_fast_interrupt": OptionInfo(False, "Return image with chosen live preview method on interrupt").info("makes interrupts faster"),
|
||||||
|
}))
|
||||||
|
|
||||||
|
options_templates.update(options_section(('sampler-params', "Sampler parameters"), {
|
||||||
|
"hide_samplers": OptionInfo([], "Hide samplers in user interface", gr.CheckboxGroup, lambda: {"choices": [x.name for x in shared_items.list_samplers()]}).needs_reload_ui(),
|
||||||
|
"eta_ddim": OptionInfo(0.0, "Eta for DDIM", gr.Slider, {"minimum": 0.0, "maximum": 1.0, "step": 0.01}, infotext='Eta DDIM').info("noise multiplier; higher = more unpredictable results"),
|
||||||
|
"eta_ancestral": OptionInfo(1.0, "Eta for k-diffusion samplers", gr.Slider, {"minimum": 0.0, "maximum": 1.0, "step": 0.01}, infotext='Eta').info("noise multiplier; currently only applies to ancestral samplers (i.e. Euler a) and SDE samplers"),
|
||||||
|
"ddim_discretize": OptionInfo('uniform', "img2img DDIM discretize", gr.Radio, {"choices": ['uniform', 'quad']}),
|
||||||
|
's_churn': OptionInfo(0.0, "sigma churn", gr.Slider, {"minimum": 0.0, "maximum": 100.0, "step": 0.01}, infotext='Sigma churn').info('amount of stochasticity; only applies to Euler, Heun, and DPM2'),
|
||||||
|
's_tmin': OptionInfo(0.0, "sigma tmin", gr.Slider, {"minimum": 0.0, "maximum": 10.0, "step": 0.01}, infotext='Sigma tmin').info('enable stochasticity; start value of the sigma range; only applies to Euler, Heun, and DPM2'),
|
||||||
|
's_tmax': OptionInfo(0.0, "sigma tmax", gr.Slider, {"minimum": 0.0, "maximum": 999.0, "step": 0.01}, infotext='Sigma tmax').info("0 = inf; end value of the sigma range; only applies to Euler, Heun, and DPM2"),
|
||||||
|
's_noise': OptionInfo(1.0, "sigma noise", gr.Slider, {"minimum": 0.0, "maximum": 1.1, "step": 0.001}, infotext='Sigma noise').info('amount of additional noise to counteract loss of detail during sampling'),
|
||||||
|
'k_sched_type': OptionInfo("Automatic", "Scheduler type", gr.Dropdown, {"choices": ["Automatic", "karras", "exponential", "polyexponential"]}, infotext='Schedule type').info("lets you override the noise schedule for k-diffusion samplers; choosing Automatic disables the three parameters below"),
|
||||||
|
'sigma_min': OptionInfo(0.0, "sigma min", gr.Number, infotext='Schedule max sigma').info("0 = default (~0.03); minimum noise strength for k-diffusion noise scheduler"),
|
||||||
|
'sigma_max': OptionInfo(0.0, "sigma max", gr.Number, infotext='Schedule min sigma').info("0 = default (~14.6); maximum noise strength for k-diffusion noise scheduler"),
|
||||||
|
'rho': OptionInfo(0.0, "rho", gr.Number, infotext='Schedule rho').info("0 = default (7 for karras, 1 for polyexponential); higher values result in a steeper noise schedule (decreases faster)"),
|
||||||
|
'eta_noise_seed_delta': OptionInfo(0, "Eta noise seed delta", gr.Number, {"precision": 0}, infotext='ENSD').info("ENSD; does not improve anything, just produces different results for ancestral samplers - only useful for reproducing images"),
|
||||||
|
'always_discard_next_to_last_sigma': OptionInfo(False, "Always discard next-to-last sigma", infotext='Discard penultimate sigma').link("PR", "https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/6044"),
|
||||||
|
'uni_pc_variant': OptionInfo("bh1", "UniPC variant", gr.Radio, {"choices": ["bh1", "bh2", "vary_coeff"]}, infotext='UniPC variant'),
|
||||||
|
'uni_pc_skip_type': OptionInfo("time_uniform", "UniPC skip type", gr.Radio, {"choices": ["time_uniform", "time_quadratic", "logSNR"]}, infotext='UniPC skip type'),
|
||||||
|
'uni_pc_order': OptionInfo(3, "UniPC order", gr.Slider, {"minimum": 1, "maximum": 50, "step": 1}, infotext='UniPC order').info("must be < sampling steps"),
|
||||||
|
'uni_pc_lower_order_final': OptionInfo(True, "UniPC lower order final", infotext='UniPC lower order final'),
|
||||||
|
}))
|
||||||
|
|
||||||
|
options_templates.update(options_section(('postprocessing', "Postprocessing"), {
|
||||||
|
'postprocessing_enable_in_main_ui': OptionInfo([], "Enable postprocessing operations in txt2img and img2img tabs", ui_components.DropdownMulti, lambda: {"choices": [x.name for x in shared_items.postprocessing_scripts()]}),
|
||||||
|
'postprocessing_operation_order': OptionInfo([], "Postprocessing operation order", ui_components.DropdownMulti, lambda: {"choices": [x.name for x in shared_items.postprocessing_scripts()]}),
|
||||||
|
'upscaling_max_images_in_cache': OptionInfo(5, "Maximum number of images in upscaling cache", gr.Slider, {"minimum": 0, "maximum": 10, "step": 1}),
|
||||||
|
}))
|
||||||
|
|
||||||
|
options_templates.update(options_section((None, "Hidden options"), {
|
||||||
|
"disabled_extensions": OptionInfo([], "Disable these extensions"),
|
||||||
|
"disable_all_extensions": OptionInfo("none", "Disable all extensions (preserves the list of disabled extensions)", gr.Radio, {"choices": ["none", "extra", "all"]}),
|
||||||
|
"restore_config_state_file": OptionInfo("", "Config state file to restore from, under 'config-states/' folder"),
|
||||||
|
"sd_checkpoint_hash": OptionInfo("", "SHA256 hash of the current checkpoint"),
|
||||||
|
}))
|
||||||
|
|
||||||
@@ -0,0 +1,159 @@
|
|||||||
|
import datetime
|
||||||
|
import logging
|
||||||
|
import threading
|
||||||
|
import time
|
||||||
|
|
||||||
|
from modules import errors, shared, devices
|
||||||
|
from typing import Optional
|
||||||
|
|
||||||
|
log = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
class State:
|
||||||
|
skipped = False
|
||||||
|
interrupted = False
|
||||||
|
job = ""
|
||||||
|
job_no = 0
|
||||||
|
job_count = 0
|
||||||
|
processing_has_refined_job_count = False
|
||||||
|
job_timestamp = '0'
|
||||||
|
sampling_step = 0
|
||||||
|
sampling_steps = 0
|
||||||
|
current_latent = None
|
||||||
|
current_image = None
|
||||||
|
current_image_sampling_step = 0
|
||||||
|
id_live_preview = 0
|
||||||
|
textinfo = None
|
||||||
|
time_start = None
|
||||||
|
server_start = None
|
||||||
|
_server_command_signal = threading.Event()
|
||||||
|
_server_command: Optional[str] = None
|
||||||
|
|
||||||
|
def __init__(self):
|
||||||
|
self.server_start = time.time()
|
||||||
|
|
||||||
|
@property
|
||||||
|
def need_restart(self) -> bool:
|
||||||
|
# Compatibility getter for need_restart.
|
||||||
|
return self.server_command == "restart"
|
||||||
|
|
||||||
|
@need_restart.setter
|
||||||
|
def need_restart(self, value: bool) -> None:
|
||||||
|
# Compatibility setter for need_restart.
|
||||||
|
if value:
|
||||||
|
self.server_command = "restart"
|
||||||
|
|
||||||
|
@property
|
||||||
|
def server_command(self):
|
||||||
|
return self._server_command
|
||||||
|
|
||||||
|
@server_command.setter
|
||||||
|
def server_command(self, value: Optional[str]) -> None:
|
||||||
|
"""
|
||||||
|
Set the server command to `value` and signal that it's been set.
|
||||||
|
"""
|
||||||
|
self._server_command = value
|
||||||
|
self._server_command_signal.set()
|
||||||
|
|
||||||
|
def wait_for_server_command(self, timeout: Optional[float] = None) -> Optional[str]:
|
||||||
|
"""
|
||||||
|
Wait for server command to get set; return and clear the value and signal.
|
||||||
|
"""
|
||||||
|
if self._server_command_signal.wait(timeout):
|
||||||
|
self._server_command_signal.clear()
|
||||||
|
req = self._server_command
|
||||||
|
self._server_command = None
|
||||||
|
return req
|
||||||
|
return None
|
||||||
|
|
||||||
|
def request_restart(self) -> None:
|
||||||
|
self.interrupt()
|
||||||
|
self.server_command = "restart"
|
||||||
|
log.info("Received restart request")
|
||||||
|
|
||||||
|
def skip(self):
|
||||||
|
self.skipped = True
|
||||||
|
log.info("Received skip request")
|
||||||
|
|
||||||
|
def interrupt(self):
|
||||||
|
self.interrupted = True
|
||||||
|
log.info("Received interrupt request")
|
||||||
|
|
||||||
|
def nextjob(self):
|
||||||
|
if shared.opts.live_previews_enable and shared.opts.show_progress_every_n_steps == -1:
|
||||||
|
self.do_set_current_image()
|
||||||
|
|
||||||
|
self.job_no += 1
|
||||||
|
self.sampling_step = 0
|
||||||
|
self.current_image_sampling_step = 0
|
||||||
|
|
||||||
|
def dict(self):
|
||||||
|
obj = {
|
||||||
|
"skipped": self.skipped,
|
||||||
|
"interrupted": self.interrupted,
|
||||||
|
"job": self.job,
|
||||||
|
"job_count": self.job_count,
|
||||||
|
"job_timestamp": self.job_timestamp,
|
||||||
|
"job_no": self.job_no,
|
||||||
|
"sampling_step": self.sampling_step,
|
||||||
|
"sampling_steps": self.sampling_steps,
|
||||||
|
}
|
||||||
|
|
||||||
|
return obj
|
||||||
|
|
||||||
|
def begin(self, job: str = "(unknown)"):
|
||||||
|
self.sampling_step = 0
|
||||||
|
self.job_count = -1
|
||||||
|
self.processing_has_refined_job_count = False
|
||||||
|
self.job_no = 0
|
||||||
|
self.job_timestamp = datetime.datetime.now().strftime("%Y%m%d%H%M%S")
|
||||||
|
self.current_latent = None
|
||||||
|
self.current_image = None
|
||||||
|
self.current_image_sampling_step = 0
|
||||||
|
self.id_live_preview = 0
|
||||||
|
self.skipped = False
|
||||||
|
self.interrupted = False
|
||||||
|
self.textinfo = None
|
||||||
|
self.time_start = time.time()
|
||||||
|
self.job = job
|
||||||
|
devices.torch_gc()
|
||||||
|
log.info("Starting job %s", job)
|
||||||
|
|
||||||
|
def end(self):
|
||||||
|
duration = time.time() - self.time_start
|
||||||
|
log.info("Ending job %s (%.2f seconds)", self.job, duration)
|
||||||
|
self.job = ""
|
||||||
|
self.job_count = 0
|
||||||
|
|
||||||
|
devices.torch_gc()
|
||||||
|
|
||||||
|
def set_current_image(self):
|
||||||
|
"""if enough sampling steps have been made after the last call to this, sets self.current_image from self.current_latent, and modifies self.id_live_preview accordingly"""
|
||||||
|
if not shared.parallel_processing_allowed:
|
||||||
|
return
|
||||||
|
|
||||||
|
if self.sampling_step - self.current_image_sampling_step >= shared.opts.show_progress_every_n_steps and shared.opts.live_previews_enable and shared.opts.show_progress_every_n_steps != -1:
|
||||||
|
self.do_set_current_image()
|
||||||
|
|
||||||
|
def do_set_current_image(self):
|
||||||
|
if self.current_latent is None:
|
||||||
|
return
|
||||||
|
|
||||||
|
import modules.sd_samplers
|
||||||
|
|
||||||
|
try:
|
||||||
|
if shared.opts.show_progress_grid:
|
||||||
|
self.assign_current_image(modules.sd_samplers.samples_to_image_grid(self.current_latent))
|
||||||
|
else:
|
||||||
|
self.assign_current_image(modules.sd_samplers.sample_to_image(self.current_latent))
|
||||||
|
|
||||||
|
self.current_image_sampling_step = self.sampling_step
|
||||||
|
|
||||||
|
except Exception:
|
||||||
|
# when switching models during genration, VAE would be on CPU, so creating an image will fail.
|
||||||
|
# we silently ignore this error
|
||||||
|
errors.record_exception()
|
||||||
|
|
||||||
|
def assign_current_image(self, image):
|
||||||
|
self.current_image = image
|
||||||
|
self.id_live_preview += 1
|
||||||
@@ -0,0 +1,37 @@
|
|||||||
|
import tqdm
|
||||||
|
|
||||||
|
from modules import shared
|
||||||
|
|
||||||
|
|
||||||
|
class TotalTQDM:
|
||||||
|
def __init__(self):
|
||||||
|
self._tqdm = None
|
||||||
|
|
||||||
|
def reset(self):
|
||||||
|
self._tqdm = tqdm.tqdm(
|
||||||
|
desc="Total progress",
|
||||||
|
total=shared.state.job_count * shared.state.sampling_steps,
|
||||||
|
position=1,
|
||||||
|
file=shared.progress_print_out
|
||||||
|
)
|
||||||
|
|
||||||
|
def update(self):
|
||||||
|
if not shared.opts.multiple_tqdm or shared.cmd_opts.disable_console_progressbars:
|
||||||
|
return
|
||||||
|
if self._tqdm is None:
|
||||||
|
self.reset()
|
||||||
|
self._tqdm.update()
|
||||||
|
|
||||||
|
def updateTotal(self, new_total):
|
||||||
|
if not shared.opts.multiple_tqdm or shared.cmd_opts.disable_console_progressbars:
|
||||||
|
return
|
||||||
|
if self._tqdm is None:
|
||||||
|
self.reset()
|
||||||
|
self._tqdm.total = new_total
|
||||||
|
|
||||||
|
def clear(self):
|
||||||
|
if self._tqdm is not None:
|
||||||
|
self._tqdm.refresh()
|
||||||
|
self._tqdm.close()
|
||||||
|
self._tqdm = None
|
||||||
|
|
||||||
@@ -58,7 +58,7 @@ def _summarize_chunk(
|
|||||||
scale: float,
|
scale: float,
|
||||||
) -> AttnChunk:
|
) -> AttnChunk:
|
||||||
attn_weights = torch.baddbmm(
|
attn_weights = torch.baddbmm(
|
||||||
torch.empty(1, 1, 1, device=query.device, dtype=query.dtype),
|
torch.zeros(1, 1, 1, device=query.device, dtype=query.dtype),
|
||||||
query,
|
query,
|
||||||
key.transpose(1,2),
|
key.transpose(1,2),
|
||||||
alpha=scale,
|
alpha=scale,
|
||||||
@@ -121,7 +121,7 @@ def _get_attention_scores_no_kv_chunking(
|
|||||||
scale: float,
|
scale: float,
|
||||||
) -> Tensor:
|
) -> Tensor:
|
||||||
attn_scores = torch.baddbmm(
|
attn_scores = torch.baddbmm(
|
||||||
torch.empty(1, 1, 1, device=query.device, dtype=query.dtype),
|
torch.zeros(1, 1, 1, device=query.device, dtype=query.dtype),
|
||||||
query,
|
query,
|
||||||
key.transpose(1,2),
|
key.transpose(1,2),
|
||||||
alpha=scale,
|
alpha=scale,
|
||||||
|
|||||||
+1
-7
@@ -10,7 +10,7 @@ import psutil
|
|||||||
import re
|
import re
|
||||||
|
|
||||||
import launch
|
import launch
|
||||||
from modules import paths_internal, timer
|
from modules import paths_internal, timer, shared, extensions, errors
|
||||||
|
|
||||||
checksum_token = "DontStealMyGamePlz__WINNERS_DONT_USE_DRUGS__DONT_COPY_THAT_FLOPPY"
|
checksum_token = "DontStealMyGamePlz__WINNERS_DONT_USE_DRUGS__DONT_COPY_THAT_FLOPPY"
|
||||||
environment_whitelist = {
|
environment_whitelist = {
|
||||||
@@ -23,7 +23,6 @@ environment_whitelist = {
|
|||||||
"TORCH_COMMAND",
|
"TORCH_COMMAND",
|
||||||
"REQS_FILE",
|
"REQS_FILE",
|
||||||
"XFORMERS_PACKAGE",
|
"XFORMERS_PACKAGE",
|
||||||
"GFPGAN_PACKAGE",
|
|
||||||
"CLIP_PACKAGE",
|
"CLIP_PACKAGE",
|
||||||
"OPENCLIP_PACKAGE",
|
"OPENCLIP_PACKAGE",
|
||||||
"STABLE_DIFFUSION_REPO",
|
"STABLE_DIFFUSION_REPO",
|
||||||
@@ -115,8 +114,6 @@ def format_exception(e, tb):
|
|||||||
|
|
||||||
def get_exceptions():
|
def get_exceptions():
|
||||||
try:
|
try:
|
||||||
from modules import errors
|
|
||||||
|
|
||||||
return list(reversed(errors.exception_records))
|
return list(reversed(errors.exception_records))
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
return str(e)
|
return str(e)
|
||||||
@@ -142,8 +139,6 @@ def get_torch_sysinfo():
|
|||||||
def get_extensions(*, enabled):
|
def get_extensions(*, enabled):
|
||||||
|
|
||||||
try:
|
try:
|
||||||
from modules import extensions
|
|
||||||
|
|
||||||
def to_json(x: extensions.Extension):
|
def to_json(x: extensions.Extension):
|
||||||
return {
|
return {
|
||||||
"name": x.name,
|
"name": x.name,
|
||||||
@@ -160,7 +155,6 @@ def get_extensions(*, enabled):
|
|||||||
|
|
||||||
def get_config():
|
def get_config():
|
||||||
try:
|
try:
|
||||||
from modules import shared
|
|
||||||
return shared.opts.data
|
return shared.opts.data
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
return str(e)
|
return str(e)
|
||||||
|
|||||||
+4
-12
@@ -1,7 +1,7 @@
|
|||||||
from contextlib import closing
|
from contextlib import closing
|
||||||
|
|
||||||
import modules.scripts
|
import modules.scripts
|
||||||
from modules import sd_samplers, processing
|
from modules import processing
|
||||||
from modules.generation_parameters_copypaste import create_override_settings_dict
|
from modules.generation_parameters_copypaste import create_override_settings_dict
|
||||||
from modules.shared import opts, cmd_opts
|
from modules.shared import opts, cmd_opts
|
||||||
import modules.shared as shared
|
import modules.shared as shared
|
||||||
@@ -9,7 +9,7 @@ from modules.ui import plaintext_to_html
|
|||||||
import gradio as gr
|
import gradio as gr
|
||||||
|
|
||||||
|
|
||||||
def txt2img(id_task: str, prompt: str, negative_prompt: str, prompt_styles, steps: int, sampler_index: int, restore_faces: bool, tiling: bool, n_iter: int, batch_size: int, cfg_scale: float, seed: int, subseed: int, subseed_strength: float, seed_resize_from_h: int, seed_resize_from_w: int, seed_enable_extras: bool, height: int, width: int, enable_hr: bool, denoising_strength: float, hr_scale: float, hr_upscaler: str, hr_second_pass_steps: int, hr_resize_x: int, hr_resize_y: int, hr_checkpoint_name: str, hr_sampler_index: int, hr_prompt: str, hr_negative_prompt, override_settings_texts, request: gr.Request, *args):
|
def txt2img(id_task: str, prompt: str, negative_prompt: str, prompt_styles, steps: int, sampler_name: str, n_iter: int, batch_size: int, cfg_scale: float, height: int, width: int, enable_hr: bool, denoising_strength: float, hr_scale: float, hr_upscaler: str, hr_second_pass_steps: int, hr_resize_x: int, hr_resize_y: int, hr_checkpoint_name: str, hr_sampler_name: str, hr_prompt: str, hr_negative_prompt, override_settings_texts, request: gr.Request, *args):
|
||||||
override_settings = create_override_settings_dict(override_settings_texts)
|
override_settings = create_override_settings_dict(override_settings_texts)
|
||||||
|
|
||||||
p = processing.StableDiffusionProcessingTxt2Img(
|
p = processing.StableDiffusionProcessingTxt2Img(
|
||||||
@@ -19,21 +19,13 @@ def txt2img(id_task: str, prompt: str, negative_prompt: str, prompt_styles, step
|
|||||||
prompt=prompt,
|
prompt=prompt,
|
||||||
styles=prompt_styles,
|
styles=prompt_styles,
|
||||||
negative_prompt=negative_prompt,
|
negative_prompt=negative_prompt,
|
||||||
seed=seed,
|
sampler_name=sampler_name,
|
||||||
subseed=subseed,
|
|
||||||
subseed_strength=subseed_strength,
|
|
||||||
seed_resize_from_h=seed_resize_from_h,
|
|
||||||
seed_resize_from_w=seed_resize_from_w,
|
|
||||||
seed_enable_extras=seed_enable_extras,
|
|
||||||
sampler_name=sd_samplers.samplers[sampler_index].name,
|
|
||||||
batch_size=batch_size,
|
batch_size=batch_size,
|
||||||
n_iter=n_iter,
|
n_iter=n_iter,
|
||||||
steps=steps,
|
steps=steps,
|
||||||
cfg_scale=cfg_scale,
|
cfg_scale=cfg_scale,
|
||||||
width=width,
|
width=width,
|
||||||
height=height,
|
height=height,
|
||||||
restore_faces=restore_faces,
|
|
||||||
tiling=tiling,
|
|
||||||
enable_hr=enable_hr,
|
enable_hr=enable_hr,
|
||||||
denoising_strength=denoising_strength if enable_hr else None,
|
denoising_strength=denoising_strength if enable_hr else None,
|
||||||
hr_scale=hr_scale,
|
hr_scale=hr_scale,
|
||||||
@@ -42,7 +34,7 @@ def txt2img(id_task: str, prompt: str, negative_prompt: str, prompt_styles, step
|
|||||||
hr_resize_x=hr_resize_x,
|
hr_resize_x=hr_resize_x,
|
||||||
hr_resize_y=hr_resize_y,
|
hr_resize_y=hr_resize_y,
|
||||||
hr_checkpoint_name=None if hr_checkpoint_name == 'Use same checkpoint' else hr_checkpoint_name,
|
hr_checkpoint_name=None if hr_checkpoint_name == 'Use same checkpoint' else hr_checkpoint_name,
|
||||||
hr_sampler_name=sd_samplers.samplers_for_img2img[hr_sampler_index - 1].name if hr_sampler_index != 0 else None,
|
hr_sampler_name=None if hr_sampler_name == 'Use same sampler' else hr_sampler_name,
|
||||||
hr_prompt=hr_prompt,
|
hr_prompt=hr_prompt,
|
||||||
hr_negative_prompt=hr_negative_prompt,
|
hr_negative_prompt=hr_negative_prompt,
|
||||||
override_settings=override_settings,
|
override_settings=override_settings,
|
||||||
|
|||||||
+49
-156
@@ -1,5 +1,4 @@
|
|||||||
import datetime
|
import datetime
|
||||||
import json
|
|
||||||
import mimetypes
|
import mimetypes
|
||||||
import os
|
import os
|
||||||
import sys
|
import sys
|
||||||
@@ -13,8 +12,8 @@ from PIL import Image, PngImagePlugin # noqa: F401
|
|||||||
from modules.call_queue import wrap_gradio_gpu_call, wrap_queued_call, wrap_gradio_call
|
from modules.call_queue import wrap_gradio_gpu_call, wrap_queued_call, wrap_gradio_call
|
||||||
|
|
||||||
from modules import gradio_extensons # noqa: F401
|
from modules import gradio_extensons # noqa: F401
|
||||||
from modules import sd_hijack, sd_models, script_callbacks, ui_extensions, deepbooru, extra_networks, ui_common, ui_postprocessing, progress, ui_loadsave, errors, shared_items, ui_settings, timer, sysinfo, ui_checkpoint_merger, ui_prompt_styles, scripts
|
from modules import sd_hijack, sd_models, script_callbacks, ui_extensions, deepbooru, extra_networks, ui_common, ui_postprocessing, progress, ui_loadsave, shared_items, ui_settings, timer, sysinfo, ui_checkpoint_merger, ui_prompt_styles, scripts, sd_samplers, processing, ui_extra_networks
|
||||||
from modules.ui_components import FormRow, FormGroup, ToolButton, FormHTML
|
from modules.ui_components import FormRow, FormGroup, ToolButton, FormHTML, InputAccordion, ResizeHandleRow
|
||||||
from modules.paths import script_path
|
from modules.paths import script_path
|
||||||
from modules.ui_common import create_refresh_button
|
from modules.ui_common import create_refresh_button
|
||||||
from modules.ui_gradio_extensions import reload_javascript
|
from modules.ui_gradio_extensions import reload_javascript
|
||||||
@@ -29,7 +28,6 @@ import modules.shared as shared
|
|||||||
import modules.images
|
import modules.images
|
||||||
from modules import prompt_parser
|
from modules import prompt_parser
|
||||||
from modules.sd_hijack import model_hijack
|
from modules.sd_hijack import model_hijack
|
||||||
from modules.sd_samplers import samplers, samplers_for_img2img
|
|
||||||
from modules.generation_parameters_copypaste import image_from_url_text
|
from modules.generation_parameters_copypaste import image_from_url_text
|
||||||
|
|
||||||
create_setting_component = ui_settings.create_setting_component
|
create_setting_component = ui_settings.create_setting_component
|
||||||
@@ -41,6 +39,9 @@ warnings.filterwarnings("default" if opts.show_gradio_deprecation_warnings else
|
|||||||
mimetypes.init()
|
mimetypes.init()
|
||||||
mimetypes.add_type('application/javascript', '.js')
|
mimetypes.add_type('application/javascript', '.js')
|
||||||
|
|
||||||
|
# Likewise, add explicit content-type header for certain missing image types
|
||||||
|
mimetypes.add_type('image/webp', '.webp')
|
||||||
|
|
||||||
if not cmd_opts.share and not cmd_opts.listen:
|
if not cmd_opts.share and not cmd_opts.listen:
|
||||||
# fix gradio phoning home
|
# fix gradio phoning home
|
||||||
gradio.utils.version_check = lambda: None
|
gradio.utils.version_check = lambda: None
|
||||||
@@ -76,7 +77,6 @@ extra_networks_symbol = '\U0001F3B4' # 🎴
|
|||||||
switch_values_symbol = '\U000021C5' # ⇅
|
switch_values_symbol = '\U000021C5' # ⇅
|
||||||
restore_progress_symbol = '\U0001F300' # 🌀
|
restore_progress_symbol = '\U0001F300' # 🌀
|
||||||
detect_image_size_symbol = '\U0001F4D0' # 📐
|
detect_image_size_symbol = '\U0001F4D0' # 📐
|
||||||
up_down_symbol = '\u2195\ufe0f' # ↕️
|
|
||||||
|
|
||||||
|
|
||||||
plaintext_to_html = ui_common.plaintext_to_html
|
plaintext_to_html = ui_common.plaintext_to_html
|
||||||
@@ -89,17 +89,13 @@ def send_gradio_gallery_to_image(x):
|
|||||||
|
|
||||||
|
|
||||||
def calc_resolution_hires(enable, width, height, hr_scale, hr_resize_x, hr_resize_y):
|
def calc_resolution_hires(enable, width, height, hr_scale, hr_resize_x, hr_resize_y):
|
||||||
from modules import processing, devices
|
|
||||||
|
|
||||||
if not enable:
|
if not enable:
|
||||||
return ""
|
return ""
|
||||||
|
|
||||||
p = processing.StableDiffusionProcessingTxt2Img(width=width, height=height, enable_hr=True, hr_scale=hr_scale, hr_resize_x=hr_resize_x, hr_resize_y=hr_resize_y)
|
p = processing.StableDiffusionProcessingTxt2Img(width=width, height=height, enable_hr=True, hr_scale=hr_scale, hr_resize_x=hr_resize_x, hr_resize_y=hr_resize_y)
|
||||||
|
p.calculate_target_resolution()
|
||||||
|
|
||||||
with devices.autocast():
|
return f"from <span class='resolution'>{p.width}x{p.height}</span> to <span class='resolution'>{p.hr_resize_x or p.hr_upscale_to_x}x{p.hr_resize_y or p.hr_upscale_to_y}</span>"
|
||||||
p.init([""], [0], [0])
|
|
||||||
|
|
||||||
return f"resize: from <span class='resolution'>{p.width}x{p.height}</span> to <span class='resolution'>{p.hr_resize_x or p.hr_upscale_to_x}x{p.hr_resize_y or p.hr_upscale_to_y}</span>"
|
|
||||||
|
|
||||||
|
|
||||||
def resize_from_to_html(width, height, scale_by):
|
def resize_from_to_html(width, height, scale_by):
|
||||||
@@ -145,41 +141,6 @@ def interrogate_deepbooru(image):
|
|||||||
return gr.update() if prompt is None else prompt
|
return gr.update() if prompt is None else prompt
|
||||||
|
|
||||||
|
|
||||||
def create_seed_inputs(target_interface):
|
|
||||||
with FormRow(elem_id=f"{target_interface}_seed_row", variant="compact"):
|
|
||||||
seed = (gr.Textbox if cmd_opts.use_textbox_seed else gr.Number)(label='Seed', value=-1, elem_id=f"{target_interface}_seed")
|
|
||||||
random_seed = ToolButton(random_symbol, elem_id=f"{target_interface}_random_seed", label='Random seed')
|
|
||||||
reuse_seed = ToolButton(reuse_symbol, elem_id=f"{target_interface}_reuse_seed", label='Reuse seed')
|
|
||||||
|
|
||||||
seed_checkbox = gr.Checkbox(label='Extra', elem_id=f"{target_interface}_subseed_show", value=False)
|
|
||||||
|
|
||||||
# Components to show/hide based on the 'Extra' checkbox
|
|
||||||
seed_extras = []
|
|
||||||
|
|
||||||
with FormRow(visible=False, elem_id=f"{target_interface}_subseed_row") as seed_extra_row_1:
|
|
||||||
seed_extras.append(seed_extra_row_1)
|
|
||||||
subseed = gr.Number(label='Variation seed', value=-1, elem_id=f"{target_interface}_subseed")
|
|
||||||
random_subseed = ToolButton(random_symbol, elem_id=f"{target_interface}_random_subseed")
|
|
||||||
reuse_subseed = ToolButton(reuse_symbol, elem_id=f"{target_interface}_reuse_subseed")
|
|
||||||
subseed_strength = gr.Slider(label='Variation strength', value=0.0, minimum=0, maximum=1, step=0.01, elem_id=f"{target_interface}_subseed_strength")
|
|
||||||
|
|
||||||
with FormRow(visible=False) as seed_extra_row_2:
|
|
||||||
seed_extras.append(seed_extra_row_2)
|
|
||||||
seed_resize_from_w = gr.Slider(minimum=0, maximum=2048, step=8, label="Resize seed from width", value=0, elem_id=f"{target_interface}_seed_resize_from_w")
|
|
||||||
seed_resize_from_h = gr.Slider(minimum=0, maximum=2048, step=8, label="Resize seed from height", value=0, elem_id=f"{target_interface}_seed_resize_from_h")
|
|
||||||
|
|
||||||
random_seed.click(fn=None, _js="function(){setRandomSeed('" + target_interface + "_seed')}", show_progress=False, inputs=[], outputs=[])
|
|
||||||
random_subseed.click(fn=None, _js="function(){setRandomSeed('" + target_interface + "_subseed')}", show_progress=False, inputs=[], outputs=[])
|
|
||||||
|
|
||||||
def change_visibility(show):
|
|
||||||
return {comp: gr_show(show) for comp in seed_extras}
|
|
||||||
|
|
||||||
seed_checkbox.change(change_visibility, show_progress=False, inputs=[seed_checkbox], outputs=seed_extras)
|
|
||||||
|
|
||||||
return seed, reuse_seed, subseed, reuse_subseed, subseed_strength, seed_resize_from_h, seed_resize_from_w, seed_checkbox
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
def connect_clear_prompt(button):
|
def connect_clear_prompt(button):
|
||||||
"""Given clear button, prompt, and token_counter objects, setup clear prompt button click event"""
|
"""Given clear button, prompt, and token_counter objects, setup clear prompt button click event"""
|
||||||
button.click(
|
button.click(
|
||||||
@@ -190,39 +151,6 @@ def connect_clear_prompt(button):
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
def connect_reuse_seed(seed: gr.Number, reuse_seed: gr.Button, generation_info: gr.Textbox, dummy_component, is_subseed):
|
|
||||||
""" Connects a 'reuse (sub)seed' button's click event so that it copies last used
|
|
||||||
(sub)seed value from generation info the to the seed field. If copying subseed and subseed strength
|
|
||||||
was 0, i.e. no variation seed was used, it copies the normal seed value instead."""
|
|
||||||
def copy_seed(gen_info_string: str, index):
|
|
||||||
res = -1
|
|
||||||
|
|
||||||
try:
|
|
||||||
gen_info = json.loads(gen_info_string)
|
|
||||||
index -= gen_info.get('index_of_first_image', 0)
|
|
||||||
|
|
||||||
if is_subseed and gen_info.get('subseed_strength', 0) > 0:
|
|
||||||
all_subseeds = gen_info.get('all_subseeds', [-1])
|
|
||||||
res = all_subseeds[index if 0 <= index < len(all_subseeds) else 0]
|
|
||||||
else:
|
|
||||||
all_seeds = gen_info.get('all_seeds', [-1])
|
|
||||||
res = all_seeds[index if 0 <= index < len(all_seeds) else 0]
|
|
||||||
|
|
||||||
except json.decoder.JSONDecodeError:
|
|
||||||
if gen_info_string:
|
|
||||||
errors.report(f"Error parsing JSON generation info: {gen_info_string}")
|
|
||||||
|
|
||||||
return [res, gr_show(False)]
|
|
||||||
|
|
||||||
reuse_seed.click(
|
|
||||||
fn=copy_seed,
|
|
||||||
_js="(x, y) => [x, selected_gallery_index()]",
|
|
||||||
show_progress=False,
|
|
||||||
inputs=[generation_info, dummy_component],
|
|
||||||
outputs=[seed, dummy_component]
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
def update_token_counter(text, steps):
|
def update_token_counter(text, steps):
|
||||||
try:
|
try:
|
||||||
text, _ = extra_networks.parse_prompt(text)
|
text, _ = extra_networks.parse_prompt(text)
|
||||||
@@ -357,14 +285,14 @@ def create_output_panel(tabname, outdir):
|
|||||||
def create_sampler_and_steps_selection(choices, tabname):
|
def create_sampler_and_steps_selection(choices, tabname):
|
||||||
if opts.samplers_in_dropdown:
|
if opts.samplers_in_dropdown:
|
||||||
with FormRow(elem_id=f"sampler_selection_{tabname}"):
|
with FormRow(elem_id=f"sampler_selection_{tabname}"):
|
||||||
sampler_index = gr.Dropdown(label='Sampling method', elem_id=f"{tabname}_sampling", choices=[x.name for x in choices], value=choices[0].name, type="index")
|
sampler_name = gr.Dropdown(label='Sampling method', elem_id=f"{tabname}_sampling", choices=choices, value=choices[0])
|
||||||
steps = gr.Slider(minimum=1, maximum=150, step=1, elem_id=f"{tabname}_steps", label="Sampling steps", value=20)
|
steps = gr.Slider(minimum=1, maximum=150, step=1, elem_id=f"{tabname}_steps", label="Sampling steps", value=20)
|
||||||
else:
|
else:
|
||||||
with FormGroup(elem_id=f"sampler_selection_{tabname}"):
|
with FormGroup(elem_id=f"sampler_selection_{tabname}"):
|
||||||
steps = gr.Slider(minimum=1, maximum=150, step=1, elem_id=f"{tabname}_steps", label="Sampling steps", value=20)
|
steps = gr.Slider(minimum=1, maximum=150, step=1, elem_id=f"{tabname}_steps", label="Sampling steps", value=20)
|
||||||
sampler_index = gr.Radio(label='Sampling method', elem_id=f"{tabname}_sampling", choices=[x.name for x in choices], value=choices[0].name, type="index")
|
sampler_name = gr.Radio(label='Sampling method', elem_id=f"{tabname}_sampling", choices=choices, value=choices[0])
|
||||||
|
|
||||||
return steps, sampler_index
|
return steps, sampler_name
|
||||||
|
|
||||||
|
|
||||||
def ordered_ui_categories():
|
def ordered_ui_categories():
|
||||||
@@ -405,13 +333,13 @@ def create_ui():
|
|||||||
extra_tabs = gr.Tabs(elem_id="txt2img_extra_tabs")
|
extra_tabs = gr.Tabs(elem_id="txt2img_extra_tabs")
|
||||||
extra_tabs.__enter__()
|
extra_tabs.__enter__()
|
||||||
|
|
||||||
with gr.Tab("Generation", id="txt2img_generation") as txt2img_generation_tab, gr.Row().style(equal_height=False):
|
with gr.Tab("Generation", id="txt2img_generation") as txt2img_generation_tab, ResizeHandleRow(equal_height=False):
|
||||||
with gr.Column(variant='compact', elem_id="txt2img_settings"):
|
with gr.Column(variant='compact', elem_id="txt2img_settings"):
|
||||||
scripts.scripts_txt2img.prepare_ui()
|
scripts.scripts_txt2img.prepare_ui()
|
||||||
|
|
||||||
for category in ordered_ui_categories():
|
for category in ordered_ui_categories():
|
||||||
if category == "sampler":
|
if category == "sampler":
|
||||||
steps, sampler_index = create_sampler_and_steps_selection(samplers, "txt2img")
|
steps, sampler_name = create_sampler_and_steps_selection(sd_samplers.visible_sampler_names(), "txt2img")
|
||||||
|
|
||||||
elif category == "dimensions":
|
elif category == "dimensions":
|
||||||
with FormRow():
|
with FormRow():
|
||||||
@@ -428,20 +356,19 @@ def create_ui():
|
|||||||
batch_size = gr.Slider(minimum=1, maximum=8, step=1, label='Batch size', value=1, elem_id="txt2img_batch_size")
|
batch_size = gr.Slider(minimum=1, maximum=8, step=1, label='Batch size', value=1, elem_id="txt2img_batch_size")
|
||||||
|
|
||||||
elif category == "cfg":
|
elif category == "cfg":
|
||||||
|
with gr.Row():
|
||||||
cfg_scale = gr.Slider(minimum=1.0, maximum=30.0, step=0.5, label='CFG Scale', value=7.0, elem_id="txt2img_cfg_scale")
|
cfg_scale = gr.Slider(minimum=1.0, maximum=30.0, step=0.5, label='CFG Scale', value=7.0, elem_id="txt2img_cfg_scale")
|
||||||
|
|
||||||
elif category == "seed":
|
|
||||||
seed, reuse_seed, subseed, reuse_subseed, subseed_strength, seed_resize_from_h, seed_resize_from_w, seed_checkbox = create_seed_inputs('txt2img')
|
|
||||||
|
|
||||||
elif category == "checkboxes":
|
elif category == "checkboxes":
|
||||||
with FormRow(elem_classes="checkboxes-row", variant="compact"):
|
with FormRow(elem_classes="checkboxes-row", variant="compact"):
|
||||||
restore_faces = gr.Checkbox(label='Restore faces', value=False, visible=len(shared.face_restorers) > 1, elem_id="txt2img_restore_faces")
|
pass
|
||||||
tiling = gr.Checkbox(label='Tiling', value=False, elem_id="txt2img_tiling")
|
|
||||||
enable_hr = gr.Checkbox(label='Hires. fix', value=False, elem_id="txt2img_enable_hr")
|
elif category == "accordions":
|
||||||
hr_final_resolution = FormHTML(value="", elem_id="txtimg_hr_finalres", label="Upscaled resolution", interactive=False)
|
with gr.Row(elem_id="txt2img_accordions", elem_classes="accordions"):
|
||||||
|
with InputAccordion(False, label="Hires. fix", elem_id="txt2img_hr") as enable_hr:
|
||||||
|
with enable_hr.extra():
|
||||||
|
hr_final_resolution = FormHTML(value="", elem_id="txtimg_hr_finalres", label="Upscaled resolution", interactive=False, min_width=0)
|
||||||
|
|
||||||
elif category == "hires_fix":
|
|
||||||
with FormGroup(visible=False, elem_id="txt2img_hires_fix") as hr_options:
|
|
||||||
with FormRow(elem_id="txt2img_hires_fix_row1", variant="compact"):
|
with FormRow(elem_id="txt2img_hires_fix_row1", variant="compact"):
|
||||||
hr_upscaler = gr.Dropdown(label="Upscaler", elem_id="txt2img_hr_upscaler", choices=[*shared.latent_upscale_modes, *[x.name for x in shared.sd_upscalers]], value=shared.latent_upscale_default_mode)
|
hr_upscaler = gr.Dropdown(label="Upscaler", elem_id="txt2img_hr_upscaler", choices=[*shared.latent_upscale_modes, *[x.name for x in shared.sd_upscalers]], value=shared.latent_upscale_default_mode)
|
||||||
hr_second_pass_steps = gr.Slider(minimum=0, maximum=150, step=1, label='Hires steps', value=0, elem_id="txt2img_hires_steps")
|
hr_second_pass_steps = gr.Slider(minimum=0, maximum=150, step=1, label='Hires steps', value=0, elem_id="txt2img_hires_steps")
|
||||||
@@ -457,7 +384,7 @@ def create_ui():
|
|||||||
hr_checkpoint_name = gr.Dropdown(label='Hires checkpoint', elem_id="hr_checkpoint", choices=["Use same checkpoint"] + modules.sd_models.checkpoint_tiles(use_short=True), value="Use same checkpoint")
|
hr_checkpoint_name = gr.Dropdown(label='Hires checkpoint', elem_id="hr_checkpoint", choices=["Use same checkpoint"] + modules.sd_models.checkpoint_tiles(use_short=True), value="Use same checkpoint")
|
||||||
create_refresh_button(hr_checkpoint_name, modules.sd_models.list_models, lambda: {"choices": ["Use same checkpoint"] + modules.sd_models.checkpoint_tiles(use_short=True)}, "hr_checkpoint_refresh")
|
create_refresh_button(hr_checkpoint_name, modules.sd_models.list_models, lambda: {"choices": ["Use same checkpoint"] + modules.sd_models.checkpoint_tiles(use_short=True)}, "hr_checkpoint_refresh")
|
||||||
|
|
||||||
hr_sampler_index = gr.Dropdown(label='Hires sampling method', elem_id="hr_sampler", choices=["Use same sampler"] + [x.name for x in samplers_for_img2img], value="Use same sampler", type="index")
|
hr_sampler_name = gr.Dropdown(label='Hires sampling method', elem_id="hr_sampler", choices=["Use same sampler"] + sd_samplers.visible_sampler_names(), value="Use same sampler")
|
||||||
|
|
||||||
with FormRow(elem_id="txt2img_hires_fix_row4", variant="compact", visible=opts.hires_fix_show_prompts) as hr_prompts_container:
|
with FormRow(elem_id="txt2img_hires_fix_row4", variant="compact", visible=opts.hires_fix_show_prompts) as hr_prompts_container:
|
||||||
with gr.Column(scale=80):
|
with gr.Column(scale=80):
|
||||||
@@ -467,6 +394,8 @@ def create_ui():
|
|||||||
with gr.Row():
|
with gr.Row():
|
||||||
hr_negative_prompt = gr.Textbox(label="Hires negative prompt", elem_id="hires_neg_prompt", show_label=False, lines=3, placeholder="Negative prompt for hires fix pass.\nLeave empty to use the same negative prompt as in first pass.", elem_classes=["prompt"])
|
hr_negative_prompt = gr.Textbox(label="Hires negative prompt", elem_id="hires_neg_prompt", show_label=False, lines=3, placeholder="Negative prompt for hires fix pass.\nLeave empty to use the same negative prompt as in first pass.", elem_classes=["prompt"])
|
||||||
|
|
||||||
|
scripts.scripts_txt2img.setup_ui_for_section(category)
|
||||||
|
|
||||||
elif category == "batch":
|
elif category == "batch":
|
||||||
if not opts.dimensions_and_batch_together:
|
if not opts.dimensions_and_batch_together:
|
||||||
with FormRow(elem_id="txt2img_column_batch"):
|
with FormRow(elem_id="txt2img_column_batch"):
|
||||||
@@ -481,7 +410,7 @@ def create_ui():
|
|||||||
with FormGroup(elem_id="txt2img_script_container"):
|
with FormGroup(elem_id="txt2img_script_container"):
|
||||||
custom_inputs = scripts.scripts_txt2img.setup_ui()
|
custom_inputs = scripts.scripts_txt2img.setup_ui()
|
||||||
|
|
||||||
else:
|
if category not in {"accordions"}:
|
||||||
scripts.scripts_txt2img.setup_ui_for_section(category)
|
scripts.scripts_txt2img.setup_ui_for_section(category)
|
||||||
|
|
||||||
hr_resolution_preview_inputs = [enable_hr, width, height, hr_scale, hr_resize_x, hr_resize_y]
|
hr_resolution_preview_inputs = [enable_hr, width, height, hr_scale, hr_resize_x, hr_resize_y]
|
||||||
@@ -505,9 +434,6 @@ def create_ui():
|
|||||||
|
|
||||||
txt2img_gallery, generation_info, html_info, html_log = create_output_panel("txt2img", opts.outdir_txt2img_samples)
|
txt2img_gallery, generation_info, html_info, html_log = create_output_panel("txt2img", opts.outdir_txt2img_samples)
|
||||||
|
|
||||||
connect_reuse_seed(seed, reuse_seed, generation_info, dummy_component, is_subseed=False)
|
|
||||||
connect_reuse_seed(subseed, reuse_subseed, generation_info, dummy_component, is_subseed=True)
|
|
||||||
|
|
||||||
txt2img_args = dict(
|
txt2img_args = dict(
|
||||||
fn=wrap_gradio_gpu_call(modules.txt2img.txt2img, extra_outputs=[None, '', '']),
|
fn=wrap_gradio_gpu_call(modules.txt2img.txt2img, extra_outputs=[None, '', '']),
|
||||||
_js="submit",
|
_js="submit",
|
||||||
@@ -517,14 +443,10 @@ def create_ui():
|
|||||||
toprow.negative_prompt,
|
toprow.negative_prompt,
|
||||||
toprow.ui_styles.dropdown,
|
toprow.ui_styles.dropdown,
|
||||||
steps,
|
steps,
|
||||||
sampler_index,
|
sampler_name,
|
||||||
restore_faces,
|
|
||||||
tiling,
|
|
||||||
batch_count,
|
batch_count,
|
||||||
batch_size,
|
batch_size,
|
||||||
cfg_scale,
|
cfg_scale,
|
||||||
seed,
|
|
||||||
subseed, subseed_strength, seed_resize_from_h, seed_resize_from_w, seed_checkbox,
|
|
||||||
height,
|
height,
|
||||||
width,
|
width,
|
||||||
enable_hr,
|
enable_hr,
|
||||||
@@ -535,7 +457,7 @@ def create_ui():
|
|||||||
hr_resize_x,
|
hr_resize_x,
|
||||||
hr_resize_y,
|
hr_resize_y,
|
||||||
hr_checkpoint_name,
|
hr_checkpoint_name,
|
||||||
hr_sampler_index,
|
hr_sampler_name,
|
||||||
hr_prompt,
|
hr_prompt,
|
||||||
hr_negative_prompt,
|
hr_negative_prompt,
|
||||||
override_settings,
|
override_settings,
|
||||||
@@ -569,40 +491,25 @@ def create_ui():
|
|||||||
show_progress=False,
|
show_progress=False,
|
||||||
)
|
)
|
||||||
|
|
||||||
enable_hr.change(
|
|
||||||
fn=lambda x: gr_show(x),
|
|
||||||
inputs=[enable_hr],
|
|
||||||
outputs=[hr_options],
|
|
||||||
show_progress = False,
|
|
||||||
)
|
|
||||||
|
|
||||||
txt2img_paste_fields = [
|
txt2img_paste_fields = [
|
||||||
(toprow.prompt, "Prompt"),
|
(toprow.prompt, "Prompt"),
|
||||||
(toprow.negative_prompt, "Negative prompt"),
|
(toprow.negative_prompt, "Negative prompt"),
|
||||||
(steps, "Steps"),
|
(steps, "Steps"),
|
||||||
(sampler_index, "Sampler"),
|
(sampler_name, "Sampler"),
|
||||||
(restore_faces, "Face restoration"),
|
|
||||||
(cfg_scale, "CFG scale"),
|
(cfg_scale, "CFG scale"),
|
||||||
(seed, "Seed"),
|
|
||||||
(width, "Size-1"),
|
(width, "Size-1"),
|
||||||
(height, "Size-2"),
|
(height, "Size-2"),
|
||||||
(batch_size, "Batch size"),
|
(batch_size, "Batch size"),
|
||||||
(seed_checkbox, lambda d: "Variation seed" in d or "Seed resize from-1" in d),
|
|
||||||
(subseed, "Variation seed"),
|
|
||||||
(subseed_strength, "Variation seed strength"),
|
|
||||||
(seed_resize_from_w, "Seed resize from-1"),
|
|
||||||
(seed_resize_from_h, "Seed resize from-2"),
|
|
||||||
(toprow.ui_styles.dropdown, lambda d: d["Styles array"] if isinstance(d.get("Styles array"), list) else gr.update()),
|
(toprow.ui_styles.dropdown, lambda d: d["Styles array"] if isinstance(d.get("Styles array"), list) else gr.update()),
|
||||||
(denoising_strength, "Denoising strength"),
|
(denoising_strength, "Denoising strength"),
|
||||||
(enable_hr, lambda d: "Denoising strength" in d and ("Hires upscale" in d or "Hires upscaler" in d or "Hires resize-1" in d)),
|
(enable_hr, lambda d: "Denoising strength" in d and ("Hires upscale" in d or "Hires upscaler" in d or "Hires resize-1" in d)),
|
||||||
(hr_options, lambda d: gr.Row.update(visible="Denoising strength" in d and ("Hires upscale" in d or "Hires upscaler" in d or "Hires resize-1" in d))),
|
|
||||||
(hr_scale, "Hires upscale"),
|
(hr_scale, "Hires upscale"),
|
||||||
(hr_upscaler, "Hires upscaler"),
|
(hr_upscaler, "Hires upscaler"),
|
||||||
(hr_second_pass_steps, "Hires steps"),
|
(hr_second_pass_steps, "Hires steps"),
|
||||||
(hr_resize_x, "Hires resize-1"),
|
(hr_resize_x, "Hires resize-1"),
|
||||||
(hr_resize_y, "Hires resize-2"),
|
(hr_resize_y, "Hires resize-2"),
|
||||||
(hr_checkpoint_name, "Hires checkpoint"),
|
(hr_checkpoint_name, "Hires checkpoint"),
|
||||||
(hr_sampler_index, "Hires sampler"),
|
(hr_sampler_name, "Hires sampler"),
|
||||||
(hr_sampler_container, lambda d: gr.update(visible=True) if d.get("Hires sampler", "Use same sampler") != "Use same sampler" or d.get("Hires checkpoint", "Use same checkpoint") != "Use same checkpoint" else gr.update()),
|
(hr_sampler_container, lambda d: gr.update(visible=True) if d.get("Hires sampler", "Use same sampler") != "Use same sampler" or d.get("Hires checkpoint", "Use same checkpoint") != "Use same checkpoint" else gr.update()),
|
||||||
(hr_prompt, "Hires prompt"),
|
(hr_prompt, "Hires prompt"),
|
||||||
(hr_negative_prompt, "Hires negative prompt"),
|
(hr_negative_prompt, "Hires negative prompt"),
|
||||||
@@ -618,9 +525,9 @@ def create_ui():
|
|||||||
toprow.prompt,
|
toprow.prompt,
|
||||||
toprow.negative_prompt,
|
toprow.negative_prompt,
|
||||||
steps,
|
steps,
|
||||||
sampler_index,
|
sampler_name,
|
||||||
cfg_scale,
|
cfg_scale,
|
||||||
seed,
|
scripts.scripts_txt2img.script('Seed').seed,
|
||||||
width,
|
width,
|
||||||
height,
|
height,
|
||||||
]
|
]
|
||||||
@@ -628,7 +535,6 @@ def create_ui():
|
|||||||
toprow.token_button.click(fn=wrap_queued_call(update_token_counter), inputs=[toprow.prompt, steps], outputs=[toprow.token_counter])
|
toprow.token_button.click(fn=wrap_queued_call(update_token_counter), inputs=[toprow.prompt, steps], outputs=[toprow.token_counter])
|
||||||
toprow.negative_token_button.click(fn=wrap_queued_call(update_token_counter), inputs=[toprow.negative_prompt, steps], outputs=[toprow.negative_token_counter])
|
toprow.negative_token_button.click(fn=wrap_queued_call(update_token_counter), inputs=[toprow.negative_prompt, steps], outputs=[toprow.negative_token_counter])
|
||||||
|
|
||||||
from modules import ui_extra_networks
|
|
||||||
extra_networks_ui = ui_extra_networks.create_ui(txt2img_interface, [txt2img_generation_tab], 'txt2img')
|
extra_networks_ui = ui_extra_networks.create_ui(txt2img_interface, [txt2img_generation_tab], 'txt2img')
|
||||||
ui_extra_networks.setup_ui(extra_networks_ui, txt2img_gallery)
|
ui_extra_networks.setup_ui(extra_networks_ui, txt2img_gallery)
|
||||||
|
|
||||||
@@ -643,7 +549,7 @@ def create_ui():
|
|||||||
extra_tabs = gr.Tabs(elem_id="img2img_extra_tabs")
|
extra_tabs = gr.Tabs(elem_id="img2img_extra_tabs")
|
||||||
extra_tabs.__enter__()
|
extra_tabs.__enter__()
|
||||||
|
|
||||||
with gr.Tab("Generation", id="img2img_generation") as img2img_generation_tab, FormRow().style(equal_height=False):
|
with gr.Tab("Generation", id="img2img_generation") as img2img_generation_tab, ResizeHandleRow(equal_height=False):
|
||||||
with gr.Column(variant='compact', elem_id="img2img_settings"):
|
with gr.Column(variant='compact', elem_id="img2img_settings"):
|
||||||
copy_image_buttons = []
|
copy_image_buttons = []
|
||||||
copy_image_destinations = {}
|
copy_image_destinations = {}
|
||||||
@@ -669,7 +575,7 @@ def create_ui():
|
|||||||
add_copy_image_controls('img2img', init_img)
|
add_copy_image_controls('img2img', init_img)
|
||||||
|
|
||||||
with gr.TabItem('Sketch', id='img2img_sketch', elem_id="img2img_img2img_sketch_tab") as tab_sketch:
|
with gr.TabItem('Sketch', id='img2img_sketch', elem_id="img2img_img2img_sketch_tab") as tab_sketch:
|
||||||
sketch = gr.Image(label="Image for img2img", elem_id="img2img_sketch", show_label=False, source="upload", interactive=True, type="pil", tool="color-sketch", image_mode="RGBA", height=opts.img2img_editor_height, brush_color=opts.img2img_sketch_default_brush_color)
|
sketch = gr.Image(label="Image for img2img", elem_id="img2img_sketch", show_label=False, source="upload", interactive=True, type="pil", tool="color-sketch", image_mode="RGB", height=opts.img2img_editor_height, brush_color=opts.img2img_sketch_default_brush_color)
|
||||||
add_copy_image_controls('sketch', sketch)
|
add_copy_image_controls('sketch', sketch)
|
||||||
|
|
||||||
with gr.TabItem('Inpaint', id='inpaint', elem_id="img2img_inpaint_tab") as tab_inpaint:
|
with gr.TabItem('Inpaint', id='inpaint', elem_id="img2img_inpaint_tab") as tab_inpaint:
|
||||||
@@ -677,7 +583,7 @@ def create_ui():
|
|||||||
add_copy_image_controls('inpaint', init_img_with_mask)
|
add_copy_image_controls('inpaint', init_img_with_mask)
|
||||||
|
|
||||||
with gr.TabItem('Inpaint sketch', id='inpaint_sketch', elem_id="img2img_inpaint_sketch_tab") as tab_inpaint_color:
|
with gr.TabItem('Inpaint sketch', id='inpaint_sketch', elem_id="img2img_inpaint_sketch_tab") as tab_inpaint_color:
|
||||||
inpaint_color_sketch = gr.Image(label="Color sketch inpainting", show_label=False, elem_id="inpaint_sketch", source="upload", interactive=True, type="pil", tool="color-sketch", image_mode="RGBA", height=opts.img2img_editor_height, brush_color=opts.img2img_inpaint_sketch_default_brush_color)
|
inpaint_color_sketch = gr.Image(label="Color sketch inpainting", show_label=False, elem_id="inpaint_sketch", source="upload", interactive=True, type="pil", tool="color-sketch", image_mode="RGB", height=opts.img2img_editor_height, brush_color=opts.img2img_inpaint_sketch_default_brush_color)
|
||||||
inpaint_color_sketch_orig = gr.State(None)
|
inpaint_color_sketch_orig = gr.State(None)
|
||||||
add_copy_image_controls('inpaint_sketch', inpaint_color_sketch)
|
add_copy_image_controls('inpaint_sketch', inpaint_color_sketch)
|
||||||
|
|
||||||
@@ -692,7 +598,7 @@ def create_ui():
|
|||||||
|
|
||||||
with gr.TabItem('Inpaint upload', id='inpaint_upload', elem_id="img2img_inpaint_upload_tab") as tab_inpaint_upload:
|
with gr.TabItem('Inpaint upload', id='inpaint_upload', elem_id="img2img_inpaint_upload_tab") as tab_inpaint_upload:
|
||||||
init_img_inpaint = gr.Image(label="Image for img2img", show_label=False, source="upload", interactive=True, type="pil", elem_id="img_inpaint_base")
|
init_img_inpaint = gr.Image(label="Image for img2img", show_label=False, source="upload", interactive=True, type="pil", elem_id="img_inpaint_base")
|
||||||
init_mask_inpaint = gr.Image(label="Mask", source="upload", interactive=True, type="pil", elem_id="img_inpaint_mask")
|
init_mask_inpaint = gr.Image(label="Mask", source="upload", interactive=True, type="pil", image_mode="RGBA", elem_id="img_inpaint_mask")
|
||||||
|
|
||||||
with gr.TabItem('Batch', id='batch', elem_id="img2img_batch_tab") as tab_batch:
|
with gr.TabItem('Batch', id='batch', elem_id="img2img_batch_tab") as tab_batch:
|
||||||
hidden = '<br>Disabled when launched with --hide-ui-dir-config.' if shared.cmd_opts.hide_ui_dir_config else ''
|
hidden = '<br>Disabled when launched with --hide-ui-dir-config.' if shared.cmd_opts.hide_ui_dir_config else ''
|
||||||
@@ -741,7 +647,7 @@ def create_ui():
|
|||||||
|
|
||||||
for category in ordered_ui_categories():
|
for category in ordered_ui_categories():
|
||||||
if category == "sampler":
|
if category == "sampler":
|
||||||
steps, sampler_index = create_sampler_and_steps_selection(samplers_for_img2img, "img2img")
|
steps, sampler_name = create_sampler_and_steps_selection(sd_samplers.visible_sampler_names(), "img2img")
|
||||||
|
|
||||||
elif category == "dimensions":
|
elif category == "dimensions":
|
||||||
with FormRow():
|
with FormRow():
|
||||||
@@ -791,20 +697,21 @@ def create_ui():
|
|||||||
batch_count = gr.Slider(minimum=1, step=1, label='Batch count', value=1, elem_id="img2img_batch_count")
|
batch_count = gr.Slider(minimum=1, step=1, label='Batch count', value=1, elem_id="img2img_batch_count")
|
||||||
batch_size = gr.Slider(minimum=1, maximum=8, step=1, label='Batch size', value=1, elem_id="img2img_batch_size")
|
batch_size = gr.Slider(minimum=1, maximum=8, step=1, label='Batch size', value=1, elem_id="img2img_batch_size")
|
||||||
|
|
||||||
elif category == "cfg":
|
elif category == "denoising":
|
||||||
with FormGroup():
|
|
||||||
with FormRow():
|
|
||||||
cfg_scale = gr.Slider(minimum=1.0, maximum=30.0, step=0.5, label='CFG Scale', value=7.0, elem_id="img2img_cfg_scale")
|
|
||||||
image_cfg_scale = gr.Slider(minimum=0, maximum=3.0, step=0.05, label='Image CFG Scale', value=1.5, elem_id="img2img_image_cfg_scale", visible=False)
|
|
||||||
denoising_strength = gr.Slider(minimum=0.0, maximum=1.0, step=0.01, label='Denoising strength', value=0.75, elem_id="img2img_denoising_strength")
|
denoising_strength = gr.Slider(minimum=0.0, maximum=1.0, step=0.01, label='Denoising strength', value=0.75, elem_id="img2img_denoising_strength")
|
||||||
|
|
||||||
elif category == "seed":
|
elif category == "cfg":
|
||||||
seed, reuse_seed, subseed, reuse_subseed, subseed_strength, seed_resize_from_h, seed_resize_from_w, seed_checkbox = create_seed_inputs('img2img')
|
with gr.Row():
|
||||||
|
cfg_scale = gr.Slider(minimum=1.0, maximum=30.0, step=0.5, label='CFG Scale', value=7.0, elem_id="img2img_cfg_scale")
|
||||||
|
image_cfg_scale = gr.Slider(minimum=0, maximum=3.0, step=0.05, label='Image CFG Scale', value=1.5, elem_id="img2img_image_cfg_scale", visible=False)
|
||||||
|
|
||||||
elif category == "checkboxes":
|
elif category == "checkboxes":
|
||||||
with FormRow(elem_classes="checkboxes-row", variant="compact"):
|
with FormRow(elem_classes="checkboxes-row", variant="compact"):
|
||||||
restore_faces = gr.Checkbox(label='Restore faces', value=False, visible=len(shared.face_restorers) > 1, elem_id="img2img_restore_faces")
|
pass
|
||||||
tiling = gr.Checkbox(label='Tiling', value=False, elem_id="img2img_tiling")
|
|
||||||
|
elif category == "accordions":
|
||||||
|
with gr.Row(elem_id="img2img_accordions", elem_classes="accordions"):
|
||||||
|
scripts.scripts_img2img.setup_ui_for_section(category)
|
||||||
|
|
||||||
elif category == "batch":
|
elif category == "batch":
|
||||||
if not opts.dimensions_and_batch_together:
|
if not opts.dimensions_and_batch_together:
|
||||||
@@ -848,14 +755,12 @@ def create_ui():
|
|||||||
inputs=[],
|
inputs=[],
|
||||||
outputs=[inpaint_controls, mask_alpha],
|
outputs=[inpaint_controls, mask_alpha],
|
||||||
)
|
)
|
||||||
else:
|
|
||||||
|
if category not in {"accordions"}:
|
||||||
scripts.scripts_img2img.setup_ui_for_section(category)
|
scripts.scripts_img2img.setup_ui_for_section(category)
|
||||||
|
|
||||||
img2img_gallery, generation_info, html_info, html_log = create_output_panel("img2img", opts.outdir_img2img_samples)
|
img2img_gallery, generation_info, html_info, html_log = create_output_panel("img2img", opts.outdir_img2img_samples)
|
||||||
|
|
||||||
connect_reuse_seed(seed, reuse_seed, generation_info, dummy_component, is_subseed=False)
|
|
||||||
connect_reuse_seed(subseed, reuse_subseed, generation_info, dummy_component, is_subseed=True)
|
|
||||||
|
|
||||||
img2img_args = dict(
|
img2img_args = dict(
|
||||||
fn=wrap_gradio_gpu_call(modules.img2img.img2img, extra_outputs=[None, '', '']),
|
fn=wrap_gradio_gpu_call(modules.img2img.img2img, extra_outputs=[None, '', '']),
|
||||||
_js="submit_img2img",
|
_js="submit_img2img",
|
||||||
@@ -873,19 +778,15 @@ def create_ui():
|
|||||||
init_img_inpaint,
|
init_img_inpaint,
|
||||||
init_mask_inpaint,
|
init_mask_inpaint,
|
||||||
steps,
|
steps,
|
||||||
sampler_index,
|
sampler_name,
|
||||||
mask_blur,
|
mask_blur,
|
||||||
mask_alpha,
|
mask_alpha,
|
||||||
inpainting_fill,
|
inpainting_fill,
|
||||||
restore_faces,
|
|
||||||
tiling,
|
|
||||||
batch_count,
|
batch_count,
|
||||||
batch_size,
|
batch_size,
|
||||||
cfg_scale,
|
cfg_scale,
|
||||||
image_cfg_scale,
|
image_cfg_scale,
|
||||||
denoising_strength,
|
denoising_strength,
|
||||||
seed,
|
|
||||||
subseed, subseed_strength, seed_resize_from_h, seed_resize_from_w, seed_checkbox,
|
|
||||||
selected_scale_tab,
|
selected_scale_tab,
|
||||||
height,
|
height,
|
||||||
width,
|
width,
|
||||||
@@ -969,19 +870,12 @@ def create_ui():
|
|||||||
(toprow.prompt, "Prompt"),
|
(toprow.prompt, "Prompt"),
|
||||||
(toprow.negative_prompt, "Negative prompt"),
|
(toprow.negative_prompt, "Negative prompt"),
|
||||||
(steps, "Steps"),
|
(steps, "Steps"),
|
||||||
(sampler_index, "Sampler"),
|
(sampler_name, "Sampler"),
|
||||||
(restore_faces, "Face restoration"),
|
|
||||||
(cfg_scale, "CFG scale"),
|
(cfg_scale, "CFG scale"),
|
||||||
(image_cfg_scale, "Image CFG scale"),
|
(image_cfg_scale, "Image CFG scale"),
|
||||||
(seed, "Seed"),
|
|
||||||
(width, "Size-1"),
|
(width, "Size-1"),
|
||||||
(height, "Size-2"),
|
(height, "Size-2"),
|
||||||
(batch_size, "Batch size"),
|
(batch_size, "Batch size"),
|
||||||
(seed_checkbox, lambda d: "Variation seed" in d or "Seed resize from-1" in d),
|
|
||||||
(subseed, "Variation seed"),
|
|
||||||
(subseed_strength, "Variation seed strength"),
|
|
||||||
(seed_resize_from_w, "Seed resize from-1"),
|
|
||||||
(seed_resize_from_h, "Seed resize from-2"),
|
|
||||||
(toprow.ui_styles.dropdown, lambda d: d["Styles array"] if isinstance(d.get("Styles array"), list) else gr.update()),
|
(toprow.ui_styles.dropdown, lambda d: d["Styles array"] if isinstance(d.get("Styles array"), list) else gr.update()),
|
||||||
(denoising_strength, "Denoising strength"),
|
(denoising_strength, "Denoising strength"),
|
||||||
(mask_blur, "Mask blur"),
|
(mask_blur, "Mask blur"),
|
||||||
@@ -993,7 +887,6 @@ def create_ui():
|
|||||||
paste_button=toprow.paste, tabname="img2img", source_text_component=toprow.prompt, source_image_component=None,
|
paste_button=toprow.paste, tabname="img2img", source_text_component=toprow.prompt, source_image_component=None,
|
||||||
))
|
))
|
||||||
|
|
||||||
from modules import ui_extra_networks
|
|
||||||
extra_networks_ui_img2img = ui_extra_networks.create_ui(img2img_interface, [img2img_generation_tab], 'img2img')
|
extra_networks_ui_img2img = ui_extra_networks.create_ui(img2img_interface, [img2img_generation_tab], 'img2img')
|
||||||
ui_extra_networks.setup_ui(extra_networks_ui_img2img, img2img_gallery)
|
ui_extra_networks.setup_ui(extra_networks_ui_img2img, img2img_gallery)
|
||||||
|
|
||||||
|
|||||||
+10
-8
@@ -11,7 +11,7 @@ from modules import call_queue, shared
|
|||||||
from modules.generation_parameters_copypaste import image_from_url_text
|
from modules.generation_parameters_copypaste import image_from_url_text
|
||||||
import modules.images
|
import modules.images
|
||||||
from modules.ui_components import ToolButton
|
from modules.ui_components import ToolButton
|
||||||
|
import modules.generation_parameters_copypaste as parameters_copypaste
|
||||||
|
|
||||||
folder_symbol = '\U0001f4c2' # 📂
|
folder_symbol = '\U0001f4c2' # 📂
|
||||||
refresh_symbol = '\U0001f504' # 🔄
|
refresh_symbol = '\U0001f504' # 🔄
|
||||||
@@ -105,8 +105,6 @@ def save_files(js_data, images, do_make_zip, index):
|
|||||||
|
|
||||||
|
|
||||||
def create_output_panel(tabname, outdir):
|
def create_output_panel(tabname, outdir):
|
||||||
from modules import shared
|
|
||||||
import modules.generation_parameters_copypaste as parameters_copypaste
|
|
||||||
|
|
||||||
def open_folder(f):
|
def open_folder(f):
|
||||||
if not os.path.exists(f):
|
if not os.path.exists(f):
|
||||||
@@ -134,18 +132,22 @@ Requested path was: {f}
|
|||||||
|
|
||||||
with gr.Column(variant='panel', elem_id=f"{tabname}_results"):
|
with gr.Column(variant='panel', elem_id=f"{tabname}_results"):
|
||||||
with gr.Group(elem_id=f"{tabname}_gallery_container"):
|
with gr.Group(elem_id=f"{tabname}_gallery_container"):
|
||||||
result_gallery = gr.Gallery(label='Output', show_label=False, elem_id=f"{tabname}_gallery", columns=4)
|
result_gallery = gr.Gallery(label='Output', show_label=False, elem_id=f"{tabname}_gallery", columns=4, preview=True, height=shared.opts.gallery_height or None)
|
||||||
|
|
||||||
generation_info = None
|
generation_info = None
|
||||||
with gr.Column():
|
with gr.Column():
|
||||||
with gr.Row(elem_id=f"image_buttons_{tabname}", elem_classes="image-buttons"):
|
with gr.Row(elem_id=f"image_buttons_{tabname}", elem_classes="image-buttons"):
|
||||||
open_folder_button = gr.Button(folder_symbol, visible=not shared.cmd_opts.hide_ui_dir_config)
|
open_folder_button = ToolButton(folder_symbol, elem_id=f'{tabname}_open_folder', visible=not shared.cmd_opts.hide_ui_dir_config, tooltip="Open images output directory.")
|
||||||
|
|
||||||
if tabname != "extras":
|
if tabname != "extras":
|
||||||
save = gr.Button('Save', elem_id=f'save_{tabname}')
|
save = ToolButton('💾', elem_id=f'save_{tabname}', tooltip=f"Save the image to a dedicated directory ({shared.opts.outdir_save}).")
|
||||||
save_zip = gr.Button('Zip', elem_id=f'save_zip_{tabname}')
|
save_zip = ToolButton('🗃️', elem_id=f'save_zip_{tabname}', tooltip=f"Save zip archive with images to a dedicated directory ({shared.opts.outdir_save})")
|
||||||
|
|
||||||
buttons = parameters_copypaste.create_buttons(["img2img", "inpaint", "extras"])
|
buttons = {
|
||||||
|
'img2img': ToolButton('🖼️', elem_id=f'{tabname}_send_to_img2img', tooltip="Send image and generation parameters to img2img tab."),
|
||||||
|
'inpaint': ToolButton('🎨️', elem_id=f'{tabname}_send_to_inpaint', tooltip="Send image and generation parameters to img2img inpaint tab."),
|
||||||
|
'extras': ToolButton('📐', elem_id=f'{tabname}_send_to_extras', tooltip="Send image and generation parameters to extras tab.")
|
||||||
|
}
|
||||||
|
|
||||||
open_folder_button.click(
|
open_folder_button.click(
|
||||||
fn=lambda: open_folder(shared.opts.outdir_samples or outdir),
|
fn=lambda: open_folder(shared.opts.outdir_samples or outdir),
|
||||||
|
|||||||
@@ -20,6 +20,18 @@ class ToolButton(FormComponent, gr.Button):
|
|||||||
return "button"
|
return "button"
|
||||||
|
|
||||||
|
|
||||||
|
class ResizeHandleRow(gr.Row):
|
||||||
|
"""Same as gr.Row but fits inside gradio forms"""
|
||||||
|
|
||||||
|
def __init__(self, **kwargs):
|
||||||
|
super().__init__(**kwargs)
|
||||||
|
|
||||||
|
self.elem_classes.append("resize-handle-row")
|
||||||
|
|
||||||
|
def get_block_name(self):
|
||||||
|
return "row"
|
||||||
|
|
||||||
|
|
||||||
class FormRow(FormComponent, gr.Row):
|
class FormRow(FormComponent, gr.Row):
|
||||||
"""Same as gr.Row but fits inside gradio forms"""
|
"""Same as gr.Row but fits inside gradio forms"""
|
||||||
|
|
||||||
@@ -72,3 +84,62 @@ class DropdownEditable(FormComponent, gr.Dropdown):
|
|||||||
def get_block_name(self):
|
def get_block_name(self):
|
||||||
return "dropdown"
|
return "dropdown"
|
||||||
|
|
||||||
|
|
||||||
|
class InputAccordion(gr.Checkbox):
|
||||||
|
"""A gr.Accordion that can be used as an input - returns True if open, False if closed.
|
||||||
|
|
||||||
|
Actaully just a hidden checkbox, but creates an accordion that follows and is followed by the state of the checkbox.
|
||||||
|
"""
|
||||||
|
|
||||||
|
global_index = 0
|
||||||
|
|
||||||
|
def __init__(self, value, **kwargs):
|
||||||
|
self.accordion_id = kwargs.get('elem_id')
|
||||||
|
if self.accordion_id is None:
|
||||||
|
self.accordion_id = f"input-accordion-{InputAccordion.global_index}"
|
||||||
|
InputAccordion.global_index += 1
|
||||||
|
|
||||||
|
kwargs_checkbox = {
|
||||||
|
**kwargs,
|
||||||
|
"elem_id": f"{self.accordion_id}-checkbox",
|
||||||
|
"visible": False,
|
||||||
|
}
|
||||||
|
super().__init__(value, **kwargs_checkbox)
|
||||||
|
|
||||||
|
self.change(fn=None, _js='function(checked){ inputAccordionChecked("' + self.accordion_id + '", checked); }', inputs=[self])
|
||||||
|
|
||||||
|
kwargs_accordion = {
|
||||||
|
**kwargs,
|
||||||
|
"elem_id": self.accordion_id,
|
||||||
|
"label": kwargs.get('label', 'Accordion'),
|
||||||
|
"elem_classes": ['input-accordion'],
|
||||||
|
"open": value,
|
||||||
|
}
|
||||||
|
self.accordion = gr.Accordion(**kwargs_accordion)
|
||||||
|
|
||||||
|
def extra(self):
|
||||||
|
"""Allows you to put something into the label of the accordion.
|
||||||
|
|
||||||
|
Use it like this:
|
||||||
|
|
||||||
|
```
|
||||||
|
with InputAccordion(False, label="Accordion") as acc:
|
||||||
|
with acc.extra():
|
||||||
|
FormHTML(value="hello", min_width=0)
|
||||||
|
|
||||||
|
...
|
||||||
|
```
|
||||||
|
"""
|
||||||
|
|
||||||
|
return gr.Column(elem_id=self.accordion_id + '-extra', elem_classes='input-accordion-extra', min_width=0)
|
||||||
|
|
||||||
|
def __enter__(self):
|
||||||
|
self.accordion.__enter__()
|
||||||
|
return self
|
||||||
|
|
||||||
|
def __exit__(self, exc_type, exc_val, exc_tb):
|
||||||
|
self.accordion.__exit__(exc_type, exc_val, exc_tb)
|
||||||
|
|
||||||
|
def get_block_name(self):
|
||||||
|
return "checkbox"
|
||||||
|
|
||||||
|
|||||||
+36
-26
@@ -65,7 +65,7 @@ def save_config_state(name):
|
|||||||
filename = os.path.join(config_states_dir, f"{timestamp}_{name}.json")
|
filename = os.path.join(config_states_dir, f"{timestamp}_{name}.json")
|
||||||
print(f"Saving backup of webui/extension state to {filename}.")
|
print(f"Saving backup of webui/extension state to {filename}.")
|
||||||
with open(filename, "w", encoding="utf-8") as f:
|
with open(filename, "w", encoding="utf-8") as f:
|
||||||
json.dump(current_config_state, f)
|
json.dump(current_config_state, f, indent=4)
|
||||||
config_states.list_config_states()
|
config_states.list_config_states()
|
||||||
new_value = next(iter(config_states.all_config_states.keys()), "Current")
|
new_value = next(iter(config_states.all_config_states.keys()), "Current")
|
||||||
new_choices = ["Current"] + list(config_states.all_config_states.keys())
|
new_choices = ["Current"] + list(config_states.all_config_states.keys())
|
||||||
@@ -200,8 +200,7 @@ def update_config_states_table(state_name):
|
|||||||
created_date = time.asctime(time.gmtime(config_state["created_at"]))
|
created_date = time.asctime(time.gmtime(config_state["created_at"]))
|
||||||
filepath = config_state.get("filepath", "<unknown>")
|
filepath = config_state.get("filepath", "<unknown>")
|
||||||
|
|
||||||
code = f"""<!-- {time.time()} -->"""
|
try:
|
||||||
|
|
||||||
webui_remote = config_state["webui"]["remote"] or ""
|
webui_remote = config_state["webui"]["remote"] or ""
|
||||||
webui_branch = config_state["webui"]["branch"]
|
webui_branch = config_state["webui"]["branch"]
|
||||||
webui_commit_hash = config_state["webui"]["commit_hash"] or "<unknown>"
|
webui_commit_hash = config_state["webui"]["commit_hash"] or "<unknown>"
|
||||||
@@ -227,12 +226,12 @@ def update_config_states_table(state_name):
|
|||||||
if current_webui["commit_hash"] != webui_commit_hash:
|
if current_webui["commit_hash"] != webui_commit_hash:
|
||||||
style_commit = STYLE_PRIMARY
|
style_commit = STYLE_PRIMARY
|
||||||
|
|
||||||
code += f"""<h2>Config Backup: {config_name}</h2>
|
code = f"""<!-- {time.time()} -->
|
||||||
<div><b>Filepath:</b> {filepath}</div>
|
<h2>Config Backup: {config_name}</h2>
|
||||||
<div><b>Created at:</b> {created_date}</div>"""
|
<div><b>Filepath:</b> {filepath}</div>
|
||||||
|
<div><b>Created at:</b> {created_date}</div>
|
||||||
code += f"""<h2>WebUI State</h2>
|
<h2>WebUI State</h2>
|
||||||
<table id="config_state_webui">
|
<table id="config_state_webui">
|
||||||
<thead>
|
<thead>
|
||||||
<tr>
|
<tr>
|
||||||
<th>URL</th>
|
<th>URL</th>
|
||||||
@@ -243,17 +242,23 @@ def update_config_states_table(state_name):
|
|||||||
</thead>
|
</thead>
|
||||||
<tbody>
|
<tbody>
|
||||||
<tr>
|
<tr>
|
||||||
<td><label{style_remote}>{remote}</label></td>
|
<td>
|
||||||
<td><label{style_branch}>{webui_branch}</label></td>
|
<label{style_remote}>{remote}</label>
|
||||||
<td><label{style_commit}>{commit_link}</label></td>
|
</td>
|
||||||
<td><label{style_commit}>{date_link}</label></td>
|
<td>
|
||||||
|
<label{style_branch}>{webui_branch}</label>
|
||||||
|
</td>
|
||||||
|
<td>
|
||||||
|
<label{style_commit}>{commit_link}</label>
|
||||||
|
</td>
|
||||||
|
<td>
|
||||||
|
<label{style_commit}>{date_link}</label>
|
||||||
|
</td>
|
||||||
</tr>
|
</tr>
|
||||||
</tbody>
|
</tbody>
|
||||||
</table>
|
</table>
|
||||||
"""
|
<h2>Extension State</h2>
|
||||||
|
<table id="config_state_extensions">
|
||||||
code += """<h2>Extension State</h2>
|
|
||||||
<table id="config_state_extensions">
|
|
||||||
<thead>
|
<thead>
|
||||||
<tr>
|
<tr>
|
||||||
<th>Extension</th>
|
<th>Extension</th>
|
||||||
@@ -264,7 +269,7 @@ def update_config_states_table(state_name):
|
|||||||
</tr>
|
</tr>
|
||||||
</thead>
|
</thead>
|
||||||
<tbody>
|
<tbody>
|
||||||
"""
|
"""
|
||||||
|
|
||||||
ext_map = {ext.name: ext for ext in extensions.extensions}
|
ext_map = {ext.name: ext for ext in extensions.extensions}
|
||||||
|
|
||||||
@@ -299,20 +304,25 @@ def update_config_states_table(state_name):
|
|||||||
if current_ext.commit_hash != ext_commit_hash:
|
if current_ext.commit_hash != ext_commit_hash:
|
||||||
style_commit = STYLE_PRIMARY
|
style_commit = STYLE_PRIMARY
|
||||||
|
|
||||||
code += f"""
|
code += f""" <tr>
|
||||||
<tr>
|
|
||||||
<td><label{style_enabled}><input class="gr-check-radio gr-checkbox" type="checkbox" disabled="true" {'checked="checked"' if ext_enabled else ''}>{html.escape(ext_name)}</label></td>
|
<td><label{style_enabled}><input class="gr-check-radio gr-checkbox" type="checkbox" disabled="true" {'checked="checked"' if ext_enabled else ''}>{html.escape(ext_name)}</label></td>
|
||||||
<td><label{style_remote}>{remote}</label></td>
|
<td><label{style_remote}>{remote}</label></td>
|
||||||
<td><label{style_branch}>{ext_branch}</label></td>
|
<td><label{style_branch}>{ext_branch}</label></td>
|
||||||
<td><label{style_commit}>{commit_link}</label></td>
|
<td><label{style_commit}>{commit_link}</label></td>
|
||||||
<td><label{style_commit}>{date_link}</label></td>
|
<td><label{style_commit}>{date_link}</label></td>
|
||||||
</tr>
|
</tr>
|
||||||
"""
|
"""
|
||||||
|
|
||||||
code += """
|
code += """ </tbody>
|
||||||
</tbody>
|
</table>"""
|
||||||
</table>
|
|
||||||
"""
|
except Exception as e:
|
||||||
|
print(f"[ERROR]: Config states {filepath}, {e}")
|
||||||
|
code = f"""<!-- {time.time()} -->
|
||||||
|
<h2>Config Backup: {config_name}</h2>
|
||||||
|
<div><b>Filepath:</b> {filepath}</div>
|
||||||
|
<div><b>Created at:</b> {created_date}</div>
|
||||||
|
<h2>This file is corrupted</h2>"""
|
||||||
|
|
||||||
return code
|
return code
|
||||||
|
|
||||||
|
|||||||
@@ -4,7 +4,6 @@ from pathlib import Path
|
|||||||
|
|
||||||
from modules import shared, ui_extra_networks_user_metadata, errors, extra_networks
|
from modules import shared, ui_extra_networks_user_metadata, errors, extra_networks
|
||||||
from modules.images import read_info_from_image, save_image_with_geninfo
|
from modules.images import read_info_from_image, save_image_with_geninfo
|
||||||
from modules.ui import up_down_symbol
|
|
||||||
import gradio as gr
|
import gradio as gr
|
||||||
import json
|
import json
|
||||||
import html
|
import html
|
||||||
@@ -348,6 +347,8 @@ def pages_in_preferred_order(pages):
|
|||||||
|
|
||||||
|
|
||||||
def create_ui(interface: gr.Blocks, unrelated_tabs, tabname):
|
def create_ui(interface: gr.Blocks, unrelated_tabs, tabname):
|
||||||
|
from modules.ui import switch_values_symbol
|
||||||
|
|
||||||
ui = ExtraNetworksUi()
|
ui = ExtraNetworksUi()
|
||||||
ui.pages = []
|
ui.pages = []
|
||||||
ui.pages_contents = []
|
ui.pages_contents = []
|
||||||
@@ -373,7 +374,7 @@ def create_ui(interface: gr.Blocks, unrelated_tabs, tabname):
|
|||||||
|
|
||||||
edit_search = gr.Textbox('', show_label=False, elem_id=tabname+"_extra_search", elem_classes="search", placeholder="Search...", visible=False, interactive=True)
|
edit_search = gr.Textbox('', show_label=False, elem_id=tabname+"_extra_search", elem_classes="search", placeholder="Search...", visible=False, interactive=True)
|
||||||
dropdown_sort = gr.Dropdown(choices=['Default Sort', 'Date Created', 'Date Modified', 'Name'], value='Default Sort', elem_id=tabname+"_extra_sort", elem_classes="sort", multiselect=False, visible=False, show_label=False, interactive=True, label=tabname+"_extra_sort_order")
|
dropdown_sort = gr.Dropdown(choices=['Default Sort', 'Date Created', 'Date Modified', 'Name'], value='Default Sort', elem_id=tabname+"_extra_sort", elem_classes="sort", multiselect=False, visible=False, show_label=False, interactive=True, label=tabname+"_extra_sort_order")
|
||||||
button_sortorder = ToolButton(up_down_symbol, elem_id=tabname+"_extra_sortorder", elem_classes="sortorder", visible=False)
|
button_sortorder = ToolButton(switch_values_symbol, elem_id=tabname+"_extra_sortorder", elem_classes="sortorder", visible=False)
|
||||||
button_refresh = gr.Button('Refresh', elem_id=tabname+"_extra_refresh", visible=False)
|
button_refresh = gr.Button('Refresh', elem_id=tabname+"_extra_refresh", visible=False)
|
||||||
checkbox_show_dirs = gr.Checkbox(True, label='Show dirs', elem_id=tabname+"_extra_show_dirs", elem_classes="show-dirs", visible=False)
|
checkbox_show_dirs = gr.Checkbox(True, label='Show dirs', elem_id=tabname+"_extra_show_dirs", elem_classes="show-dirs", visible=False)
|
||||||
|
|
||||||
|
|||||||
@@ -19,6 +19,7 @@ class ExtraNetworksPageCheckpoints(ui_extra_networks.ExtraNetworksPage):
|
|||||||
return {
|
return {
|
||||||
"name": checkpoint.name_for_extra,
|
"name": checkpoint.name_for_extra,
|
||||||
"filename": checkpoint.filename,
|
"filename": checkpoint.filename,
|
||||||
|
"shorthash": checkpoint.shorthash,
|
||||||
"preview": self.find_preview(path),
|
"preview": self.find_preview(path),
|
||||||
"description": self.find_description(path),
|
"description": self.find_description(path),
|
||||||
"search_term": self.search_terms_from_path(checkpoint.filename) + " " + (checkpoint.sha256 or ""),
|
"search_term": self.search_terms_from_path(checkpoint.filename) + " " + (checkpoint.sha256 or ""),
|
||||||
@@ -29,7 +30,8 @@ class ExtraNetworksPageCheckpoints(ui_extra_networks.ExtraNetworksPage):
|
|||||||
}
|
}
|
||||||
|
|
||||||
def list_items(self):
|
def list_items(self):
|
||||||
for index, name in enumerate(sd_models.checkpoints_list):
|
names = list(sd_models.checkpoints_list)
|
||||||
|
for index, name in enumerate(names):
|
||||||
yield self.create_item(name, index)
|
yield self.create_item(name, index)
|
||||||
|
|
||||||
def allowed_directories_for_previews(self):
|
def allowed_directories_for_previews(self):
|
||||||
|
|||||||
@@ -1,6 +1,6 @@
|
|||||||
import gradio as gr
|
import gradio as gr
|
||||||
|
|
||||||
from modules import ui_extra_networks_user_metadata, sd_vae
|
from modules import ui_extra_networks_user_metadata, sd_vae, shared
|
||||||
from modules.ui_common import create_refresh_button
|
from modules.ui_common import create_refresh_button
|
||||||
|
|
||||||
|
|
||||||
@@ -18,6 +18,10 @@ class CheckpointUserMetadataEditor(ui_extra_networks_user_metadata.UserMetadataE
|
|||||||
|
|
||||||
self.write_user_metadata(name, user_metadata)
|
self.write_user_metadata(name, user_metadata)
|
||||||
|
|
||||||
|
def update_vae(self, name):
|
||||||
|
if name == shared.sd_model.sd_checkpoint_info.name_for_extra:
|
||||||
|
sd_vae.reload_vae_weights()
|
||||||
|
|
||||||
def put_values_into_components(self, name):
|
def put_values_into_components(self, name):
|
||||||
user_metadata = self.get_user_metadata(name)
|
user_metadata = self.get_user_metadata(name)
|
||||||
values = super().put_values_into_components(name)
|
values = super().put_values_into_components(name)
|
||||||
@@ -58,3 +62,5 @@ class CheckpointUserMetadataEditor(ui_extra_networks_user_metadata.UserMetadataE
|
|||||||
]
|
]
|
||||||
|
|
||||||
self.setup_save_handler(self.button_save, self.save_user_metadata, edited_components)
|
self.setup_save_handler(self.button_save, self.save_user_metadata, edited_components)
|
||||||
|
self.button_save.click(fn=self.update_vae, inputs=[self.edit_name_input])
|
||||||
|
|
||||||
|
|||||||
@@ -2,6 +2,7 @@ import os
|
|||||||
|
|
||||||
from modules import shared, ui_extra_networks
|
from modules import shared, ui_extra_networks
|
||||||
from modules.ui_extra_networks import quote_js
|
from modules.ui_extra_networks import quote_js
|
||||||
|
from modules.hashes import sha256_from_cache
|
||||||
|
|
||||||
|
|
||||||
class ExtraNetworksPageHypernetworks(ui_extra_networks.ExtraNetworksPage):
|
class ExtraNetworksPageHypernetworks(ui_extra_networks.ExtraNetworksPage):
|
||||||
@@ -14,13 +15,16 @@ class ExtraNetworksPageHypernetworks(ui_extra_networks.ExtraNetworksPage):
|
|||||||
def create_item(self, name, index=None, enable_filter=True):
|
def create_item(self, name, index=None, enable_filter=True):
|
||||||
full_path = shared.hypernetworks[name]
|
full_path = shared.hypernetworks[name]
|
||||||
path, ext = os.path.splitext(full_path)
|
path, ext = os.path.splitext(full_path)
|
||||||
|
sha256 = sha256_from_cache(full_path, f'hypernet/{name}')
|
||||||
|
shorthash = sha256[0:10] if sha256 else None
|
||||||
|
|
||||||
return {
|
return {
|
||||||
"name": name,
|
"name": name,
|
||||||
"filename": full_path,
|
"filename": full_path,
|
||||||
|
"shorthash": shorthash,
|
||||||
"preview": self.find_preview(path),
|
"preview": self.find_preview(path),
|
||||||
"description": self.find_description(path),
|
"description": self.find_description(path),
|
||||||
"search_term": self.search_terms_from_path(path),
|
"search_term": self.search_terms_from_path(path) + " " + (sha256 or ""),
|
||||||
"prompt": quote_js(f"<hypernet:{name}:") + " + opts.extra_networks_default_multiplier + " + quote_js(">"),
|
"prompt": quote_js(f"<hypernet:{name}:") + " + opts.extra_networks_default_multiplier + " + quote_js(">"),
|
||||||
"local_preview": f"{path}.preview.{shared.opts.samples_format}",
|
"local_preview": f"{path}.preview.{shared.opts.samples_format}",
|
||||||
"sort_keys": {'default': index, **self.get_sort_keys(path + ext)},
|
"sort_keys": {'default': index, **self.get_sort_keys(path + ext)},
|
||||||
|
|||||||
@@ -19,9 +19,10 @@ class ExtraNetworksPageTextualInversion(ui_extra_networks.ExtraNetworksPage):
|
|||||||
return {
|
return {
|
||||||
"name": name,
|
"name": name,
|
||||||
"filename": embedding.filename,
|
"filename": embedding.filename,
|
||||||
|
"shorthash": embedding.shorthash,
|
||||||
"preview": self.find_preview(path),
|
"preview": self.find_preview(path),
|
||||||
"description": self.find_description(path),
|
"description": self.find_description(path),
|
||||||
"search_term": self.search_terms_from_path(embedding.filename),
|
"search_term": self.search_terms_from_path(embedding.filename) + " " + (embedding.hash or ""),
|
||||||
"prompt": quote_js(embedding.name),
|
"prompt": quote_js(embedding.name),
|
||||||
"local_preview": f"{path}.preview.{shared.opts.samples_format}",
|
"local_preview": f"{path}.preview.{shared.opts.samples_format}",
|
||||||
"sort_keys": {'default': index, **self.get_sort_keys(embedding.filename)},
|
"sort_keys": {'default': index, **self.get_sort_keys(embedding.filename)},
|
||||||
|
|||||||
@@ -36,8 +36,8 @@ class UserMetadataEditor:
|
|||||||
item = self.page.items.get(name, {})
|
item = self.page.items.get(name, {})
|
||||||
|
|
||||||
user_metadata = item.get('user_metadata', None)
|
user_metadata = item.get('user_metadata', None)
|
||||||
if user_metadata is None:
|
if not user_metadata:
|
||||||
user_metadata = {}
|
user_metadata = {'description': item.get('description', '')}
|
||||||
item['user_metadata'] = user_metadata
|
item['user_metadata'] = user_metadata
|
||||||
|
|
||||||
return user_metadata
|
return user_metadata
|
||||||
@@ -93,11 +93,13 @@ class UserMetadataEditor:
|
|||||||
item = self.page.items.get(name, {})
|
item = self.page.items.get(name, {})
|
||||||
try:
|
try:
|
||||||
filename = item["filename"]
|
filename = item["filename"]
|
||||||
|
shorthash = item.get("shorthash", None)
|
||||||
|
|
||||||
stats = os.stat(filename)
|
stats = os.stat(filename)
|
||||||
params = [
|
params = [
|
||||||
('Filename: ', os.path.basename(filename)),
|
('Filename: ', os.path.basename(filename)),
|
||||||
('File size: ', sysinfo.pretty_bytes(stats.st_size)),
|
('File size: ', sysinfo.pretty_bytes(stats.st_size)),
|
||||||
|
('Hash: ', shorthash),
|
||||||
('Modified: ', datetime.datetime.fromtimestamp(stats.st_mtime).strftime('%Y-%m-%d %H:%M')),
|
('Modified: ', datetime.datetime.fromtimestamp(stats.st_mtime).strftime('%Y-%m-%d %H:%M')),
|
||||||
]
|
]
|
||||||
|
|
||||||
@@ -115,7 +117,7 @@ class UserMetadataEditor:
|
|||||||
errors.display(e, f"reading metadata info for {name}")
|
errors.display(e, f"reading metadata info for {name}")
|
||||||
params = []
|
params = []
|
||||||
|
|
||||||
table = '<table class="file-metadata">' + "".join(f"<tr><th>{name}</th><td>{value}</td></tr>" for name, value in params) + '</table>'
|
table = '<table class="file-metadata">' + "".join(f"<tr><th>{name}</th><td>{value}</td></tr>" for name, value in params if value is not None) + '</table>'
|
||||||
|
|
||||||
return html.escape(name), user_metadata.get('description', ''), table, self.get_card_html(name), user_metadata.get('notes', '')
|
return html.escape(name), user_metadata.get('description', ''), table, self.get_card_html(name), user_metadata.get('notes', '')
|
||||||
|
|
||||||
@@ -125,7 +127,7 @@ class UserMetadataEditor:
|
|||||||
basename, ext = os.path.splitext(filename)
|
basename, ext = os.path.splitext(filename)
|
||||||
|
|
||||||
with open(basename + '.json', "w", encoding="utf8") as file:
|
with open(basename + '.json', "w", encoding="utf8") as file:
|
||||||
json.dump(metadata, file)
|
json.dump(metadata, file, indent=4)
|
||||||
|
|
||||||
def save_user_metadata(self, name, desc, notes):
|
def save_user_metadata(self, name, desc, notes):
|
||||||
user_metadata = self.get_user_metadata(name)
|
user_metadata = self.get_user_metadata(name)
|
||||||
|
|||||||
@@ -8,7 +8,7 @@ from modules.ui_components import ToolButton
|
|||||||
|
|
||||||
|
|
||||||
class UiLoadsave:
|
class UiLoadsave:
|
||||||
"""allows saving and restorig default values for gradio components"""
|
"""allows saving and restoring default values for gradio components"""
|
||||||
|
|
||||||
def __init__(self, filename):
|
def __init__(self, filename):
|
||||||
self.filename = filename
|
self.filename = filename
|
||||||
@@ -48,6 +48,14 @@ class UiLoadsave:
|
|||||||
elif condition and not condition(saved_value):
|
elif condition and not condition(saved_value):
|
||||||
pass
|
pass
|
||||||
else:
|
else:
|
||||||
|
if isinstance(x, gr.Textbox) and field == 'value': # due to an undesirable behavior of gr.Textbox, if you give it an int value instead of str, everything dies
|
||||||
|
saved_value = str(saved_value)
|
||||||
|
elif isinstance(x, gr.Number) and field == 'value':
|
||||||
|
try:
|
||||||
|
saved_value = float(saved_value)
|
||||||
|
except ValueError:
|
||||||
|
return
|
||||||
|
|
||||||
setattr(obj, field, saved_value)
|
setattr(obj, field, saved_value)
|
||||||
if init_field is not None:
|
if init_field is not None:
|
||||||
init_field(saved_value)
|
init_field(saved_value)
|
||||||
|
|||||||
@@ -44,6 +44,8 @@ def save_pil_to_file(self, pil_image, dir=None, format="png"):
|
|||||||
|
|
||||||
if shared.opts.temp_dir != "":
|
if shared.opts.temp_dir != "":
|
||||||
dir = shared.opts.temp_dir
|
dir = shared.opts.temp_dir
|
||||||
|
else:
|
||||||
|
os.makedirs(dir, exist_ok=True)
|
||||||
|
|
||||||
use_metadata = False
|
use_metadata = False
|
||||||
metadata = PngImagePlugin.PngInfo()
|
metadata = PngImagePlugin.PngInfo()
|
||||||
@@ -57,8 +59,9 @@ def save_pil_to_file(self, pil_image, dir=None, format="png"):
|
|||||||
return file_obj.name
|
return file_obj.name
|
||||||
|
|
||||||
|
|
||||||
# override save to file function so that it also writes PNG info
|
def install_ui_tempdir_override():
|
||||||
gradio.components.IOComponent.pil_to_temp_file = save_pil_to_file
|
"""override save to file function so that it also writes PNG info"""
|
||||||
|
gradio.components.IOComponent.pil_to_temp_file = save_pil_to_file
|
||||||
|
|
||||||
|
|
||||||
def on_tmpdir_changed():
|
def on_tmpdir_changed():
|
||||||
|
|||||||
@@ -0,0 +1,58 @@
|
|||||||
|
import os
|
||||||
|
import re
|
||||||
|
|
||||||
|
from modules import shared
|
||||||
|
from modules.paths_internal import script_path
|
||||||
|
|
||||||
|
|
||||||
|
def natural_sort_key(s, regex=re.compile('([0-9]+)')):
|
||||||
|
return [int(text) if text.isdigit() else text.lower() for text in regex.split(s)]
|
||||||
|
|
||||||
|
|
||||||
|
def listfiles(dirname):
|
||||||
|
filenames = [os.path.join(dirname, x) for x in sorted(os.listdir(dirname), key=natural_sort_key) if not x.startswith(".")]
|
||||||
|
return [file for file in filenames if os.path.isfile(file)]
|
||||||
|
|
||||||
|
|
||||||
|
def html_path(filename):
|
||||||
|
return os.path.join(script_path, "html", filename)
|
||||||
|
|
||||||
|
|
||||||
|
def html(filename):
|
||||||
|
path = html_path(filename)
|
||||||
|
|
||||||
|
if os.path.exists(path):
|
||||||
|
with open(path, encoding="utf8") as file:
|
||||||
|
return file.read()
|
||||||
|
|
||||||
|
return ""
|
||||||
|
|
||||||
|
|
||||||
|
def walk_files(path, allowed_extensions=None):
|
||||||
|
if not os.path.exists(path):
|
||||||
|
return
|
||||||
|
|
||||||
|
if allowed_extensions is not None:
|
||||||
|
allowed_extensions = set(allowed_extensions)
|
||||||
|
|
||||||
|
items = list(os.walk(path, followlinks=True))
|
||||||
|
items = sorted(items, key=lambda x: natural_sort_key(x[0]))
|
||||||
|
|
||||||
|
for root, _, files in items:
|
||||||
|
for filename in sorted(files, key=natural_sort_key):
|
||||||
|
if allowed_extensions is not None:
|
||||||
|
_, ext = os.path.splitext(filename)
|
||||||
|
if ext not in allowed_extensions:
|
||||||
|
continue
|
||||||
|
|
||||||
|
if not shared.opts.list_hidden_files and ("/." in root or "\\." in root):
|
||||||
|
continue
|
||||||
|
|
||||||
|
yield os.path.join(root, filename)
|
||||||
|
|
||||||
|
|
||||||
|
def ldm_print(*args, **kwargs):
|
||||||
|
if shared.opts.hide_ldm_prints:
|
||||||
|
return
|
||||||
|
|
||||||
|
print(*args, **kwargs)
|
||||||
+2
-1
@@ -6,8 +6,9 @@ basicsr
|
|||||||
blendmodes
|
blendmodes
|
||||||
clean-fid
|
clean-fid
|
||||||
einops
|
einops
|
||||||
|
fastapi>=0.90.1
|
||||||
gfpgan
|
gfpgan
|
||||||
gradio==3.39.0
|
gradio==3.41.0
|
||||||
inflection
|
inflection
|
||||||
jsonmerge
|
jsonmerge
|
||||||
kornia
|
kornia
|
||||||
|
|||||||
@@ -7,7 +7,7 @@ clean-fid==0.1.35
|
|||||||
einops==0.4.1
|
einops==0.4.1
|
||||||
fastapi==0.94.0
|
fastapi==0.94.0
|
||||||
gfpgan==1.3.8
|
gfpgan==1.3.8
|
||||||
gradio==3.39.0
|
gradio==3.41.0
|
||||||
httpcore==0.15
|
httpcore==0.15
|
||||||
inflection==0.5.1
|
inflection==0.5.1
|
||||||
jsonmerge==1.8.0
|
jsonmerge==1.8.0
|
||||||
|
|||||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user