Compare commits

653 Commits

Author SHA1 Message Date
w-e-w c7fbcb5789 remove space after comma and colon 2024-01-23 03:44:21 +09:00
w-e-w 28cc18cbdf add exception in case the infotext is damaged 2024-01-23 03:34:49 +09:00
w-e-w 9a042c907b compact Soft Inpainting infotext 2024-01-23 02:44:21 +09:00
w-e-w 36fb876da6 compact infotext info_json 2024-01-23 02:17:21 +09:00
AUTOMATIC1111 8a6a4ad894 Merge pull request #14709 from AUTOMATIC1111/improve-get_crop_region
improve get_crop_region
2024-01-21 16:01:44 +03:00
w-e-w e36827af32 improve get_crop_region 2024-01-21 09:02:18 +09:00
AUTOMATIC1111 3a5196de1b Merge pull request #14663 from aalbacetef/feature/add-model-to-log-file
feature/add-model-to-log-file
2024-01-21 02:05:18 +03:00
Arturo Albacete 4aa99f77ab add docstring 2024-01-20 22:04:53 +01:00
Arturo Albacete f190b85182 restore saving fields 2024-01-20 21:27:38 +01:00
Arturo Albacete 8459015017 skip if headers haven't changed 2024-01-20 21:19:53 +01:00
Arturo Albacete d0b65e148b merge dev 2024-01-20 21:15:57 +01:00
AUTOMATIC1111 f939bce845 Merge pull request #14702 from light-and-ray/keep_postprocessing_upscale_selected_tab_after_restart
[Bug] Keep postprocessing upscale selected tab after restart
2024-01-20 14:56:53 +03:00
Andray ed383eb5a0 keep postprocessing upscale selected tab after restart 2024-01-20 15:37:49 +04:00
AUTOMATIC1111 c1713bfeac Merge pull request #14689 from AUTOMATIC1111/fix-nested-manual-cast
Fix nested manual cast
2024-01-20 11:43:51 +03:00
Kohaku-Blueleaf 4a66d2fb22 Avoid exceptions being silenced 2024-01-20 16:33:59 +08:00
Kohaku-Blueleaf 81126027f5 Avoid early disable 2024-01-20 16:31:12 +08:00
AUTOMATIC1111 1dbee391b4 Merge pull request #14637 from light-and-ray/fix_tab_indexes_resets_after_restart_ui
[Bug] Fix tab indexes being reset after UI restart
2024-01-20 11:00:29 +03:00
Andray 56676ff923 fix tab indexes reset after restart ui 2024-01-20 11:49:05 +04:00
AUTOMATIC1111 a06ae54a18 Merge pull request #14639 from light-and-ray/fix_extension_check_for_requirements
[Bug] Fix extension check for requirements
2024-01-20 10:30:28 +03:00
AUTOMATIC1111 58a142b56b Merge pull request #14640 from WebDev9000/replace-hashtags-in-filenames
Add # to the invalid_filename_chars list
2024-01-20 10:29:42 +03:00
AUTOMATIC1111 0f2de4cfbe Merge pull request #14655 from chi2nagisa/lora-alias-fix
[bug] fix using wrong model caused by alias
2024-01-20 10:27:54 +03:00
AUTOMATIC1111 cfa326bbcc Merge pull request #14659 from AUTOMATIC1111/immediately-stop-on-second-interrupt
immediately stop on second interrupt
2024-01-20 10:26:52 +03:00
AUTOMATIC1111 7581da5d92 Merge pull request #14645 from AUTOMATIC1111/infotexts
geninfo from Infotexts
2024-01-20 10:26:03 +03:00
AUTOMATIC1111 41c2121e51 Merge pull request #14657 from AUTOMATIC1111/callback-postprocess_image_after_composite
callback postprocess_image_after_composite
2024-01-20 10:25:29 +03:00
AUTOMATIC1111 5df56b3980 Merge pull request #14626 from AUTOMATIC1111/hr-button-fix
more Hr button fix
2024-01-20 10:24:59 +03:00
AUTOMATIC1111 9cdd161160 Merge pull request #14690 from n0kovo/dev
Add support for DAT upscaler models
2024-01-20 10:15:59 +03:00
AUTOMATIC1111 7c30c5eec4 Merge pull request #14699 from light-and-ray/fix_extras_big_batch_crashes
[Bug] Fix extras big batch crashes
2024-01-20 10:12:16 +03:00
AUTOMATIC1111 5cbda43d63 Merge pull request #14638 from WebDev9000/brush-size-adjust
Adjust brush size with hotkeys
2024-01-20 09:42:10 +03:00
WebDev 30f31d23c6 Update hotkey_config.py 2024-01-19 16:39:33 -08:00
WebDev bcdcc8be7d Update hotkey_config.py 2024-01-19 16:39:07 -08:00
WebDev 4dde96109b Update zoom.js 2024-01-19 16:38:34 -08:00
Andray d31dc7a739 fix extras big batch crashes 2024-01-20 00:40:03 +04:00
n0kovo 1ddb886a80 Fix wrong options value 2024-01-19 00:48:46 +01:00
n0kovo 2e7efe47b6 Minor cleanup 2024-01-19 00:39:14 +01:00
n0kovo a97147bc8a Add support for DAT upscaler models 2024-01-19 00:10:02 +01:00
Kohaku-Blueleaf 0181c1f76b Fix nested manual cast 2024-01-19 00:14:03 +08:00
w-e-w 2cf23099eb fix console total progress bar when using txt2img_upscale
add p.txt2img_upscale as indicator
2024-01-18 04:44:21 +09:00
w-e-w e1dfd452c0 parse_generation_parameters with no skip_fields 2024-01-18 02:39:42 +09:00
w-e-w 6916de5c0b parse_generation_parameters skip_fields 2024-01-18 02:39:15 +09:00
w-e-w 45a51c07e2 parse_generation_parameters with no skip_fields 2024-01-18 02:36:19 +09:00
w-e-w d224fed0ce parse_generation_parameters skip_fields 2024-01-18 02:36:19 +09:00
w-e-w 6acd8e28fc save_files info base on infotexts 2024-01-18 02:33:43 +09:00
w-e-w 0b83f4c263 reuse seed from infotexts 2024-01-18 02:33:43 +09:00
Arturo Albacete 315e40a49c reuse variable for log file path 2024-01-16 19:11:28 +01:00
Arturo Albacete a75dfe1c0d - expand fields to include model name and hash
- write these in the CSV log file
- ensure old log files are updated w.r.t delimiter count
2024-01-16 19:03:48 +01:00
w-e-w 14b9762bca immediately stop on second interrupt
Revert "immediately stop on second interrupt"

This reverts commit ab409072a1f2a9c911a63aee98a6b42081803cdc.

immediately stop on second interrupt
2024-01-16 17:18:05 +09:00
w-e-w c1e04c63b3 callback postprocess_image_after_composite 2024-01-16 14:18:20 +09:00
chi2nagisa e280eb4055 fix using wrong model caused by alias 2024-01-16 03:45:19 +08:00
WebDev 39db1d09d1 Update zoom.js 2024-01-14 01:43:23 -08:00
WebDev 92032f2da8 Update hotkey_config.py 2024-01-14 01:42:03 -08:00
w-e-w 208ccfbe7c seed info from infotexts 2024-01-14 17:57:49 +09:00
w-e-w ee9d487081 fix gallery black image issue 2024-01-14 17:57:42 +09:00
w-e-w cfb90a938e allow hr pass to return multiple images 2024-01-14 17:57:26 +09:00
w-e-w 92501d4f80 disable saving images before highres fix 2024-01-14 17:56:34 +09:00
Andray b6dc307c99 fix_extension_check_for_requirements 2024-01-13 14:45:15 +04:00
WebDev 47b52d9b28 Add # to the invalid_filename_chars list 2024-01-13 02:31:26 -08:00
WebDev 541881e318 Adjust brush size with hotkeys. 2024-01-13 02:11:06 -08:00
AUTOMATIC1111 cb5b335acd Merge pull request #14618 from akx/log-fmt-fix
Logging: set formatter correctly for fallback logger too
2024-01-11 13:49:36 +03:00
Aarni Koskela 0011640ab1 Logging: set formatter correctly for fallback logger too 2024-01-11 08:29:42 +02:00
AUTOMATIC1111 85bf2eb441 Merge pull request #14598 from AUTOMATIC1111/fix-txt2img_upscale
hires button, fix seeds
2024-01-09 20:05:10 +03:00
w-e-w 4d9f2c3ec8 update p.seed and p.subseed 2024-01-10 01:56:44 +09:00
AUTOMATIC1111 abb630108e Merge pull request #14589 from Sj-Si/hotfix/remove-upscaler-size-limits
Increase Upscaler Limits
2024-01-09 19:52:54 +03:00
Sj-Si 9dd2534824 Restore scale factor limit changes to master branch. 2024-01-09 11:47:39 -05:00
Sj-Si 3cc7572f5d Restore scale factor limit changes to master branch. 2024-01-09 11:46:10 -05:00
AUTOMATIC1111 751d014cd6 Merge pull request #14583 from continue-revolution/conrevo/lcm-sampler
Official LCM Sampler Support
2024-01-09 19:37:35 +03:00
AUTOMATIC1111 639f22ea7c Merge pull request #14593 from papuSpartan/api_tls
Allow TLS with API only mode (--nowebui)
2024-01-09 19:33:31 +03:00
AUTOMATIC1111 905b14237f Merge pull request #14597 from AUTOMATIC1111/improved-manual-cast
Improve the implementation of Manual Cast and IPEX support
2024-01-09 19:33:00 +03:00
Kohaku-Blueleaf ca671e5d7b rearrange if-statements for cpu 2024-01-09 23:30:55 +08:00
Kohaku-Blueleaf 58d5b042cd Apply the correct behavior of precision='full' 2024-01-09 23:23:40 +08:00
Kohaku-Blueleaf 1fd69655fe Revert "Apply correct inference precision implementation"
This reverts commit e00365962b.
2024-01-09 23:15:05 +08:00
Kohaku-Blueleaf e00365962b Apply correct inference precision implementation 2024-01-09 23:13:34 +08:00
Kohaku-Blueleaf c2c05fcca8 linting and debugs 2024-01-09 22:53:58 +08:00
KohakuBlueleaf 42e6df723c Fix bugs when arg dtype doesn't match 2024-01-09 22:39:39 +08:00
Kohaku-Blueleaf 209c26a1cb improve efficiency and support more devices 2024-01-09 22:11:44 +08:00
unknown 8d986727b3 include tls arguments in api uvicorn init 2024-01-09 03:01:20 -06:00
Sj-Si 46413b20a4 Increase limits for upscalers. 2024-01-08 16:23:35 -05:00
continue-revolution 8e292373ec lcm sampler 2024-01-08 06:43:39 -06:00
AUTOMATIC1111 6869d95890 Merge pull request #14573 from continue-revolution/conrevo/add-p-to-cfg
Add self to CFGDenoiserParams
2024-01-08 10:23:08 +03:00
Chengsong Zhang 37906e429a make denoiser None by default 2024-01-07 20:17:42 -06:00
continue-revolution f56cebf5ba add self instead 2024-01-07 12:35:35 -06:00
continue-revolution 425507bd10 add p to cfgdenoiserparams 2024-01-07 10:25:01 -06:00
AUTOMATIC1111 2f98a35fc4 add assets repo; serve fonts locally rather than from google's servers 2024-01-07 09:21:21 +03:00
AUTOMATIC1111 30aa5e0a7c Merge pull request #14560 from Nuullll/condfunc-warning
Handle CondFunc exception when resolving attributes
2024-01-07 08:23:08 +03:00
AUTOMATIC1111 cab1d839b5 Merge pull request #14563 from Nuullll/model-loaded-callback
Execute model_loaded_callback after moving to target device
2024-01-07 08:22:17 +03:00
AUTOMATIC1111 71e0057137 Merge pull request #14562 from Nuullll/fix-ipex-xpu-generator
[IPEX] Fix xpu generator
2024-01-07 08:21:43 +03:00
Nuullll a183de04e3 Execute model_loaded_callback after moving to target device 2024-01-06 20:03:33 +08:00
Nuullll 818d6a11e7 Fix format 2024-01-06 19:14:06 +08:00
Nuullll 73786c047f [IPEX] Fix torch.Generator hijack 2024-01-06 19:09:56 +08:00
AUTOMATIC1111 b00b429477 Merge pull request #14559 from Nuullll/ipex-sdpa-fix
[IPEX] Fix SDPA attn_mask dtype
2024-01-06 13:14:18 +03:00
Nuullll ec9acb3145 Handle CondFunc exception when resolving attributes 2024-01-06 17:18:38 +08:00
Nuullll 16b4d2cf3f [IPEX] Fix SDPA attn_mask dtype 2024-01-06 16:32:18 +08:00
AUTOMATIC1111 8b6848c6db Merge pull request #14546 from AUTOMATIC1111/fix-oft-dtype
Fix dtype casting in OFT module
2024-01-06 10:50:38 +03:00
AUTOMATIC1111 a4ee64050a Merge pull request #14547 from AUTOMATIC1111/lyco-forward
Implement general forward method for all methods in built-in lora ext
2024-01-06 10:50:06 +03:00
AUTOMATIC1111 942617f828 Merge pull request #14548 from keshav-nischal/patch-2
Update README.md
2024-01-06 10:49:45 +03:00
AUTOMATIC1111 233c66b36e Make the upscale button update the gallery with the new image rather than replace it. 2024-01-05 12:28:41 +03:00
Keshav Nischal 88ba095fd0 Update README.md 2024-01-05 14:15:58 +05:30
Kohaku-Blueleaf 44744d6005 linting 2024-01-05 16:38:05 +08:00
Kohaku-Blueleaf 18ca987c92 Add general forward method for all modules. 2024-01-05 16:32:19 +08:00
Kohaku-Blueleaf f8f38c7c28 Fix dtype casting for OFT module 2024-01-05 16:31:48 +08:00
AUTOMATIC1111 a06dab8d7a Merge pull request #14538 from akx/log-wut
Fix logging configuration again
2024-01-05 11:04:14 +03:00
AUTOMATIC1111 6ffbff0857 Merge pull request #14537 from akx/gradio-analytics-enabled-again
Ensure GRADIO_ANALYTICS_ENABLED is set early enough
2024-01-05 11:02:50 +03:00
Aarni Koskela 6fa42e919f Fix logging configuration again
* Only use `tqdm.write()` if `tqdm` is active, defer to stderr
* Correct log formatter for TqdmLoggingHandler
* If `rich` is installed and `SD_WEBUI_RICH_LOG` is set, use `rich`'s formatter
2024-01-04 19:32:03 +02:00
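For illustration, a sketch of the tqdm-aware handler behavior described in the commit above, assuming tqdm's internal `_instances` set as the "is a progress bar active" signal (the actual implementation may check this differently):

```
import logging
import sys
import tqdm

class TqdmLoggingHandler(logging.StreamHandler):
    def __init__(self):
        super().__init__(sys.stderr)

    def emit(self, record):
        try:
            if tqdm.tqdm._instances:
                # a progress bar is active: write through tqdm so the
                # log line doesn't corrupt the bar
                tqdm.tqdm.write(self.format(record), file=sys.stderr)
            else:
                # no active bar: defer to plain stderr
                super().emit(record)
        except Exception:
            self.handleError(record)
```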
Aarni Koskela 9805f35c6f Ensure GRADIO_ANALYTICS_ENABLED is set early enough 2024-01-04 19:13:47 +02:00
AUTOMATIC1111 15ec54dd96 Have upscale button use the same seed as hires fix. 2024-01-04 19:47:00 +03:00
AUTOMATIC1111 f903b4dda3 Merge pull request #14523 from AUTOMATIC1111/paste-infotext-cast-int-as-float
paste infotext cast int as float
2024-01-04 11:19:18 +03:00
AUTOMATIC1111 3f7f61e541 Merge pull request #14524 from akx/fix-swinir-issues
Fix SwinIR issues
2024-01-04 11:17:20 +03:00
AUTOMATIC1111 1e7a8ce5e4 Merge pull request #14525 from AUTOMATIC1111/handle-config.json-failed-to-load
handle config.json failed to load
2024-01-04 11:16:37 +03:00
AUTOMATIC1111 397251ba0c Merge pull request #14527 from akx/avoid-isfiles
Avoid unnecessary `isfile`/`exists` calls
2024-01-04 11:15:56 +03:00
AUTOMATIC1111 df62ffbd25 Merge branch 'dev' into avoid-isfiles 2024-01-04 11:15:50 +03:00
AUTOMATIC1111 149c9d2234 Merge pull request #14528 from AUTOMATIC1111/mass-file-lister
mass file lister as an attempt to tackle #14507
2024-01-04 11:09:59 +03:00
AUTOMATIC1111 320a217b78 forgot something 2024-01-04 02:39:02 +03:00
AUTOMATIC1111 420f56c2e8 mass file lister as an attempt to tackle #14507 2024-01-04 02:28:05 +03:00
Aarni Koskela d9034b48a5 Avoid unnecessary isfile/exists calls 2024-01-04 00:26:30 +02:00
w-e-w 50158a1fc9 handle config.json failed to load 2024-01-04 06:30:52 +09:00
Aarni Koskela 62470ee234 upscale_2: cast image to model's dtype 2024-01-03 22:39:12 +02:00
Aarni Koskela 3d31d5c27b SwinIR: pass model.scale 2024-01-03 22:38:49 +02:00
Aarni Koskela dfdc51246c SwinIR: use prefer_half 2024-01-03 22:38:13 +02:00
w-e-w bfc48fbc24 paste infotext cast int as float 2024-01-04 03:46:05 +09:00
AUTOMATIC1111 04a005f0e9 Merge pull request #14512 from AUTOMATIC1111/remove-excessive-extra-networks-reload
reduce unnecessary re-indexing extra networks directory
2024-01-03 19:15:46 +03:00
w-e-w fccd0b00c2 reduce unnecessary re-indexing extra networks dir 2024-01-03 19:25:06 +09:00
AUTOMATIC1111 9c6ea5386b Merge pull request #14504 from akx/you-spin-me-round
torch_bgr_to_pil_image: round, don't truncate
2024-01-03 12:42:05 +03:00
Aarni Koskela 7ad6899bf9 torch_bgr_to_pil_image: round, don't truncate
This matches what `realesrgan` does.
2024-01-02 17:14:05 +02:00
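A minimal sketch of the numeric difference the commit above addresses when converting float image data to uint8 (NumPy used for brevity; not the actual webui code):

```
import numpy as np

x = np.array([0.9999, 0.5002, 0.0001])
print((x * 255).astype(np.uint8))          # truncation: [254 127   0]
print((x * 255).round().astype(np.uint8))  # rounding:   [255 128   0]
```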
AUTOMATIC1111 e4dcdcc955 Merge pull request #14501 from akx/credits-remove-copypaste
Remove licenses and README mentions for code that's no longer copy-pasted
2024-01-02 16:25:26 +03:00
Aarni Koskela 62bd7624d2 Remove licenses for code that's no longer copy-pasted; adjust README 2024-01-02 11:46:42 +02:00
AUTOMATIC1111 7c3ab416ad Merge pull request #14500 from akx/spandrel-prefer-half
Spandrel: "prefer half" instead of "force half"
2024-01-02 12:23:23 +03:00
Aarni Koskela 2cacbc124c load_spandrel_model: make half prefer_half
As discussed with the Spandrel folks, it's good to heed Spandrel's
"supports half precision" flag to avoid e.g. black blotches and what-not.
2024-01-02 10:44:38 +02:00
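A minimal sketch of the "prefer half" idea described above, assuming a model descriptor exposing a `supports_half` flag (the attribute name is an assumption, not necessarily Spandrel's exact API):

```
import torch

def pick_dtype(descriptor, prefer_half: bool) -> torch.dtype:
    # prefer, rather than force, half precision: fall back to fp32
    # when the architecture is known not to support it
    if prefer_half and getattr(descriptor, "supports_half", False):
        return torch.float16
    return torch.float32
```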
AUTOMATIC1111 51f1cca852 Merge pull request #14484 from akx/swinir-resample-for-div8
Refactor Torch-space upscale fully out of ScuNET/SwinIR
2024-01-02 10:56:37 +03:00
Aarni Koskela cf14a6a7aa Refactor upscale_2 helper out of ScuNET/SwinIR; make sure devices are right 2024-01-02 08:57:12 +02:00
AUTOMATIC1111 980970d390 final touches 2024-01-02 07:08:32 +03:00
AUTOMATIC1111 80873b1538 fix #14497 2024-01-02 07:05:05 +03:00
AUTOMATIC1111 1341b22081 add an option to hide upscaling progressbar 2024-01-02 06:47:26 +03:00
AUTOMATIC1111 6f9fcfdbb7 Merge pull request #14497 from Jibaku789/dev
Add inpaint arguments in .txt file
2024-01-02 06:42:21 +03:00
Jibaku789 a5b6a5a3ad Add inpaint options to img2img.py 2024-01-01 14:58:55 -06:00
Jibaku789 c2ea571005 Add inpaint options to paste fields 2024-01-01 14:57:41 -06:00
AUTOMATIC1111 ac3cc1adc5 Merge pull request #14495 from akx/fix-js-lint
Fix lint issue from 501993eb
2024-01-01 21:02:47 +03:00
Aarni Koskela c32c51a0fc Fix lint issue from 501993eb 2024-01-01 19:20:54 +02:00
AUTOMATIC1111 501993ebf2 added a button to run hires fix on selected image in the gallery 2024-01-01 19:31:06 +03:00
AUTOMATIC1111 5d7d1823af rename infotext.py again, this time to infotext_utils.py; I didn't realize infotext would be used for variable names in multiple places, which makes it awkward to import the module; also fix the bug I caused by this rename that breaks tests 2024-01-01 17:25:30 +03:00
AUTOMATIC1111 1ffdedc11d restore lines lost from #13789 merge 2024-01-01 17:03:08 +03:00
AUTOMATIC1111 c507d7b252 Merge pull request #13789 from nickpharrison/finer-settings-freezing-control
Finer settings freezing control
2024-01-01 17:01:28 +03:00
AUTOMATIC1111 7ba02e0b7c Merge branch 'dev' into finer-settings-freezing-control 2024-01-01 17:01:06 +03:00
AUTOMATIC1111 15156cde18 Merge pull request #14291 from AUTOMATIC1111/on-mouse-hover-show-hide-modal-image-viewer-icons
on mouse hover show / hide modal image viewer icons
2024-01-01 16:53:33 +03:00
AUTOMATIC1111 0aa7c53c0b fix borked merge, rename fields to better match what they do, change setting default to true for #13653 2024-01-01 16:50:59 +03:00
AUTOMATIC1111 2a7ad70db5 Merge pull request #13653 from antfu/feat/interrupted-end
Interrupt after current generation
2024-01-01 16:40:02 +03:00
AUTOMATIC1111 dfd6438221 Merge branch 'dev' into feat/interrupted-end 2024-01-01 16:39:51 +03:00
AUTOMATIC1111 0ce67cb618 Merge pull request #14352 from AUTOMATIC1111/reduce-unnecessary-ui-config-write
only rewrite ui-config when there is change
2024-01-01 16:35:07 +03:00
AUTOMATIC1111 cba6fba123 Merge pull request #14353 from Nuullll/ipex-sdpa
[IPEX] Slice SDPA into smaller chunks
2024-01-01 16:33:55 +03:00
AUTOMATIC1111 ac0ecf3b4b option to convert VAE to bfloat16 (implementation of #9295) 2024-01-01 16:28:58 +03:00
AUTOMATIC1111 0743ee9b3e re-layout checkboxes for XYZ grid a bit 2024-01-01 15:50:47 +03:00
AUTOMATIC1111 c352008c95 Merge remote-tracking branch 'rubberbaron/xyz-grid-vary-seeds' into dev 2024-01-01 15:50:10 +03:00
AUTOMATIC1111 d8126be578 linter 2024-01-01 15:00:39 +03:00
AUTOMATIC1111 45b7bba3d0 add automatic version support for zero terminal SNR noise schedule option from #14145 2024-01-01 14:51:56 +03:00
AUTOMATIC1111 267fd5d76b Merge pull request #14145 from drhead/zero-terminal-snr
Implement zero terminal SNR noise schedule option
2024-01-01 14:45:12 +03:00
AUTOMATIC1111 d613cd17c7 add automatic backwards version compatibility 2024-01-01 14:38:29 +03:00
AUTOMATIC1111 d859cec696 infotext.py: rename usages in the codebase 2024-01-01 13:53:12 +03:00
AUTOMATIC1111 c5496c7646 infotext.py: add support for old modules.generation_parameters_copypaste name 2024-01-01 13:52:37 +03:00
AUTOMATIC1111 003b91f083 rename generation_parameters_copypaste module to infotext 2024-01-01 13:45:18 +03:00
AUTOMATIC1111 5692bf1517 add missing field for DDIM sampler that was breaking img2img 2024-01-01 11:11:14 +03:00
AUTOMATIC1111 e55fec9d9a Merge pull request #14487 from AUTOMATIC1111/handle-selectable-script_index-is-None
handle selectable script_index is None
2024-01-01 10:19:58 +03:00
w-e-w 00901bfbe0 handle selectable script_index is None 2024-01-01 15:47:57 +09:00
AUTOMATIC1111 a70dfb64a8 change import statements for #14478 2023-12-31 22:38:30 +03:00
AUTOMATIC1111 be5f1acc8f Merge pull request #14478 from akx/dtype-inspect
Add utility to inspect a model's dtype/device
2023-12-31 22:33:32 +03:00
AUTOMATIC1111 f3af8c8d04 Merge pull request #14475 from Learwin/negative_prompt
Adding negative prompts to Loras in extra networks
2023-12-31 22:32:28 +03:00
Learwin b6f74e936e Revert change from linting for unrelated file 2023-12-31 13:36:36 +01:00
Learwin d4945f4422 Removed weight slider for negative prompts 2023-12-31 13:22:30 +01:00
Aarni Koskela 5768afc776 Add utility to inspect a model's parameters (to get dtype/device) 2023-12-31 13:22:43 +02:00
AUTOMATIC1111 a84e842189 Merge pull request #14476 from akx/dedupe-tiled-weighted-inference
Deduplicate tiled inference code from SwinIR/ScuNET
2023-12-31 09:41:49 +03:00
Aarni Koskela 6f86b62a1b Deduplicate tiled inference code from SwinIR/ScuNET 2023-12-31 01:13:30 +02:00
AUTOMATIC1111 ce21840a04 Merge pull request #14477 from akx/spandrel-type-fix
Be more clear about Spandrel model nomenclature and types
2023-12-31 01:38:43 +03:00
AUTOMATIC1111 ae124439c4 Merge pull request #14471 from akx/bump-numpy
Bump numpy to 1.26.2
2023-12-31 01:37:56 +03:00
Aarni Koskela 777af661a2 Be more clear about Spandrel model nomenclature 2023-12-31 00:22:58 +02:00
Aarni Koskela c0ca6348e8 load_spandrel_model: always return a model descriptor 2023-12-31 00:04:47 +02:00
AUTOMATIC1111 3be9074031 fix for the previous fix. 2023-12-31 00:43:41 +03:00
Learwin a2f23f9d22 Code Style fixes 2023-12-30 22:16:51 +01:00
Learwin bc5ae74c7d Added negative prompts to extra networks lora 2023-12-30 21:52:27 +01:00
AUTOMATIC1111 8100e901ab fix error with RealESRGAN model failing to upscale fp32 image 2023-12-30 22:41:53 +03:00
AUTOMATIC1111 c2fd7c0344 Merge pull request #14474 from akx/realesrgan-is-esrgan
Correct RealESRGAN expected architecture type to ESRGAN
2023-12-30 22:40:11 +03:00
AUTOMATIC1111 7c13ffdbb1 Merge pull request #14472 from akx/drop-move-code
Remove `cleanup_models` code
2023-12-30 22:15:30 +03:00
AUTOMATIC1111 a86f4411cb Merge pull request #14473 from akx/soften-model-arch-check
Soften Spandrel model-architecture check to just a warning
2023-12-30 22:13:29 +03:00
Aarni Koskela 393a5b82ba Correct RealESRGAN expected architecture type to ESRGAN 2023-12-30 21:12:32 +02:00
Aarni Koskela af050dcaa7 Soften Spandrel model-architecture check to just a warning 2023-12-30 21:05:59 +02:00
Aarni Koskela 5fbb13e0da Remove cleanup_models code 2023-12-30 20:47:12 +02:00
AUTOMATIC1111 16848f950b Merge pull request #14467 from akx/drop-basicsr
Drop basicsr dependency
2023-12-30 21:27:33 +03:00
Aarni Koskela 48a2a1a437 Don't wait for 10 minutes for test server to come up 2023-12-30 19:44:38 +02:00
Aarni Koskela 1465dab715 Make Tensorboard a late import (it was implicitly installed by basicsr) 2023-12-30 19:44:05 +02:00
AUTOMATIC1111 79c9151802 Merge pull request #14421 from lanyeeee/api_thread_safe
fix API thread safe issues of txt2img and img2img
2023-12-30 20:21:13 +03:00
lanyeeee f651405427 remove locks, move init code to __init__ 2023-12-31 01:09:13 +08:00
Aarni Koskela b58ed1b243 Bump numpy to 1.26.2
This avoids it being downgraded during `launch.py`
2023-12-30 18:02:01 +02:00
Aarni Koskela c9174253fb Drop dependency on basicsr 2023-12-30 17:53:19 +02:00
lanyeeee 91560e98c4 fix format issue 2023-12-30 23:42:10 +08:00
Aarni Koskela f476649c02 Correct arg type for restore_face 2023-12-30 17:41:29 +02:00
AUTOMATIC1111 cd12c0e15c Merge pull request #14425 from akx/spandrel
Use Spandrel for upscaling and face restoration architectures
2023-12-30 18:06:31 +03:00
AUTOMATIC1111 05230c0260 fix img2img api that i broke when implementing infotext support 2023-12-30 18:02:51 +03:00
Aarni Koskela 4ad0c0c0a8 Verify architecture for loaded Spandrel models 2023-12-30 16:37:03 +02:00
Aarni Koskela c756133541 Add experimental HAT model 2023-12-30 16:30:49 +02:00
Aarni Koskela b621a63cf6 Unify CodeFormer and GFPGAN restoration backends, use Spandrel for GFPGAN 2023-12-30 16:30:49 +02:00
Aarni Koskela b0f5934234 Use Spandrel for upscaling and face restoration architectures (aside from GFPGAN and LDSR) 2023-12-30 16:24:01 +02:00
Aarni Koskela e472383acb Refactor esrgan_upscale to more generic upscale_with_model 2023-12-30 16:24:01 +02:00
Aarni Koskela 12c6f37f8e Add tile_count property to Grid 2023-12-30 16:24:01 +02:00
Aarni Koskela 7aa27b000a Add types to split_grid 2023-12-30 16:24:01 +02:00
AUTOMATIC1111 31992eff9b make it possible again to extract styles that have whitespace at the end. 2023-12-30 16:51:13 +03:00
kurisu_u d05f9e8124 Merge branch 'dev' into api_thread_safe 2023-12-30 21:47:59 +08:00
lanyeeee c069c2c562 add locks to ensure init args are thread-safe 2023-12-30 21:32:22 +08:00
AUTOMATIC1111 adcd65ba34 Merge pull request #14367 from AUTOMATIC1111/reorder-post-processing-modules
reorder training preprocessing modules in extras tab
2023-12-30 15:23:52 +03:00
AUTOMATIC1111 f0e2e8b930 update #14354 2023-12-30 15:12:48 +03:00
AUTOMATIC1111 83c0758d90 Merge pull request #14354 from ranareehanaslam/master
Update webui.py: add (fix) IPv6 functionality when no webui argument is passed
2023-12-30 15:11:49 +03:00
AUTOMATIC1111 1d603eb5a8 Merge pull request #14394 from AUTOMATIC1111/minor-xyz-fix
xyz grid handle axis_type is None
2023-12-30 15:06:39 +03:00
AUTOMATIC1111 4b6eb8072b Merge pull request #14407 from AUTOMATIC1111/prevent-crash-due-to-Script-__init__-exception
prevent crash due to Script __init__ exception
2023-12-30 14:54:31 +03:00
AUTOMATIC1111 908fb4ea71 Merge pull request #14390 from wangqyqq/sdxl-inpaint
Supporting for SDXL-Inpaint Model
2023-12-30 14:49:52 +03:00
AUTOMATIC1111 c9c105c7db Merge pull request #14446 from AUTOMATIC1111/base-output-path-off-data_path
Base output path off data path
2023-12-30 14:45:28 +03:00
AUTOMATIC1111 a79890efd6 Merge pull request #14452 from AUTOMATIC1111/save-info-of-init-image
save info of init image
2023-12-30 14:41:39 +03:00
AUTOMATIC1111 32862e4379 Merge pull request #14464 from AUTOMATIC1111/more-lora-not-found-warning
More lora not found warning
2023-12-30 14:04:17 +03:00
AUTOMATIC1111 8f18263759 fix bad values read from infotext for API, add comment 2023-12-30 13:48:25 +03:00
AUTOMATIC1111 11a435b469 img2img support for infotext API 2023-12-30 13:34:46 +03:00
AUTOMATIC1111 0aacd4c72b add support for alwayson scripts for infotext API 2023-12-30 13:33:18 +03:00
AUTOMATIC1111 8b08b78c03 make it so that if an option from infotext conflicts with an argument from API, the latter overrides the former 2023-12-30 12:27:23 +03:00
AUTOMATIC1111 ba92135a2b add override_settings support for infotext API 2023-12-30 12:11:09 +03:00
w-e-w 59d060fd5e More lora not found warning 2023-12-30 17:11:03 +09:00
AUTOMATIC1111 bb07cb6a0d a 2023-12-30 10:42:42 +03:00
w-e-w dc57ec0296 save info of init image 2023-12-29 01:56:48 +09:00
w-e-w 892e703b59 webpath use truncate_path 2023-12-28 06:52:41 +09:00
w-e-w af2951ed53 base default image output on data_path
Co-Authored-By: Alberto Cano <34340962+canoalberto@users.noreply.github.com>
2023-12-28 06:52:33 +09:00
w-e-w de04573438 create utility truncate_path
util.truncate_path(target_path, base_path)
returns target_path relative to base_path if target_path is a sub-path of base_path, else returns the absolute path
2023-12-28 06:22:51 +09:00
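A minimal sketch matching the semantics described above (the actual implementation in the repository may differ):

```
import os

def truncate_path(target_path, base_path):
    # return target_path relative to base_path if it is a sub-path
    # of base_path; otherwise return the absolute path unchanged
    abs_target = os.path.abspath(target_path)
    abs_base = os.path.abspath(base_path)
    if os.path.commonpath([abs_target, abs_base]) == abs_base:
        return os.path.relpath(abs_target, abs_base)
    return abs_target
```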
wangqyqq bfe418a58d add some code for robustness 2023-12-27 10:20:56 +08:00
lanyeeee 00d4a4d4ac move thread-unsafe code to __init__ 2023-12-26 14:46:29 +08:00
w-e-w edfae95d90 prevent crash due to Script __init__ exception 2023-12-23 01:21:00 +09:00
w-e-w de1809bd14 handle axis_type is None 2023-12-22 00:37:30 +09:00
wangqyqq 9feb034e34 support for sdxl-inpaint model 2023-12-21 20:15:51 +08:00
w-e-w 3e068de0dc reorder training preprocessing modules in extras tab
using the order from before the rework
11d23e8ca5
2023-12-19 18:48:49 +09:00
Muhammad Rehan Aslam 0d5941edbc Update webui.py
Co-authored-by: Aarni Koskela <akx@iki.fi>
2023-12-19 09:50:38 +05:00
Muhammad Rehan Aslam fe4d084390 Update webui.py
Added (fixed) IPv6 functionality when there is no webui argument passed
2023-12-18 17:50:00 +05:00
Nuullll f586f4973a Fix device id 2023-12-18 19:44:52 +08:00
Nuullll e4b4a9c4ac [IPEX] Slice SDPA into smaller chunks 2023-12-18 18:01:09 +08:00
w-e-w 10945aa41a only rewrite ui-config when there is a change
also fix a typo
2023-12-18 15:27:41 +09:00
AUTOMATIC1111 de03882d6c make task ids for API work without force_task_id 2023-12-17 08:55:35 +03:00
AUTOMATIC1111 3d9a0d9e4b Merge pull request #14330 from AUTOMATIC1111/fix-extras-caption-BLIP
fix extras caption BLIP
2023-12-16 16:41:40 +03:00
w-e-w 98c5fa9201 fix extras caption BLIP
#14328
2023-12-16 22:14:39 +09:00
AUTOMATIC1111 7428ce52ab Merge pull request #14327 from AUTOMATIC1111/fp8-cond-cache-fix
Fix FP8 non-reproducible problem
2023-12-16 14:59:35 +03:00
Kohaku-Blueleaf a978320334 Let fp8-related settings to invalidate cond_cache 2023-12-16 19:39:43 +08:00
AUTOMATIC1111 4f5281a92e Merge pull request #14227 from kingljl/kingljl-patch-memory-leak
Long running memory leak problem
2023-12-16 11:24:07 +03:00
AUTOMATIC1111 86b3aa94e2 rename pending tasks api endpoint to be more in line with others 2023-12-16 11:04:59 +03:00
AUTOMATIC1111 5b7d86d42b Merge pull request #14314 from gayshub/master
Allow specifying the task id and getting a task's position in the pending-task queue
2023-12-16 11:01:42 +03:00
AUTOMATIC1111 93eae69895 move soft inpainting to a built-in extension 2023-12-16 11:00:42 +03:00
AUTOMATIC1111 cd9ce2e31c Use radio for FP8 mode selection 2023-12-16 10:40:20 +03:00
AUTOMATIC1111 c121f8c315 Merge pull request #14031 from AUTOMATIC1111/test-fp8
A big improvement for dtype casting system with fp8 storage type and manual cast
2023-12-16 10:22:51 +03:00
AUTOMATIC1111 8edb9144cc Merge branch 'dev' into test-fp8 2023-12-16 10:22:16 +03:00
AUTOMATIC1111 60186c7b9d Merge pull request #14107 from AUTOMATIC1111/torch210
Torch210
2023-12-16 10:15:55 +03:00
AUTOMATIC1111 7745db6fc0 torch 2.1.2 2023-12-16 10:15:08 +03:00
Kohaku-Blueleaf ea272152e0 Add FP8 settings into PNG info 2023-12-16 15:08:08 +08:00
AUTOMATIC1111 e9c6325fc6 Merge branch 'dev' into torch210 2023-12-16 10:05:10 +03:00
AUTOMATIC1111 7504f14503 Merge branch 'master' into dev 2023-12-16 09:59:47 +03:00
AUTOMATIC1111 cf2772fab0 Merge branch 'release_candidate' 2023-12-16 09:58:07 +03:00
AUTOMATIC1111 0dfffe53ec Merge pull request #14307 from AUTOMATIC1111/default-Falst-js_live_preview_in_modal_lightbox
default False js_live_preview_in_modal_lightbox
2023-12-16 09:25:33 +03:00
AUTOMATIC1111 c16fcb7f46 Merge pull request #14307 from AUTOMATIC1111/default-Falst-js_live_preview_in_modal_lightbox
default False js_live_preview_in_modal_lightbox
2023-12-16 09:25:08 +03:00
gayshub 6d7e57ba6a fix ruff lint errors reported by GitHub CI 2023-12-15 18:03:14 +08:00
gayshub da45e73b4f fix ruff lint errors reported by GitHub CI 2023-12-15 17:57:58 +08:00
gayshub d859de37d9 fix ruff lint errors reported by GitHub CI 2023-12-15 17:48:20 +08:00
gayshub 1242ba08e1 allow specifying the task id and getting a task's position in the pending-task queue 2023-12-15 16:57:17 +08:00
w-e-w 0c5427960b make modal toolbar and icon opacity adjustable 2023-12-15 17:11:59 +09:00
w-e-w 3c0c277579 default False js_live_preview_in_modal_lightbox 2023-12-15 00:48:37 +09:00
Kohaku-Blueleaf 0fb34b57b8 Merge branch 'dev' into test-fp8 2023-12-14 16:54:45 +08:00
AUTOMATIC1111 2be85f8fe0 Merge pull request #14237 from ReneKroon/dev
#13354 : solve lora loading issue
2023-12-14 10:15:36 +03:00
AUTOMATIC1111 eb52c803b8 Merge pull request #14216 from wfjsw/state-dict-ref-comparison
change state dict comparison to ref compare
2023-12-14 10:15:22 +03:00
AUTOMATIC1111 f8871dedcf Merge pull request #14230 from AUTOMATIC1111/add-option-Live-preview-in-full-page-image-viewer
add option: Live preview in full page image viewer
2023-12-14 10:15:18 +03:00
AUTOMATIC1111 b7e0d4a7e1 Merge pull request #14229 from Nuullll/ipex-embedding
[IPEX] Fix embedding and ControlNet
2023-12-14 10:14:59 +03:00
AUTOMATIC1111 5cb1ce470d Merge pull request #14266 from kaalibro/dev
Re-add setting lost as part of e294e46
2023-12-14 10:14:54 +03:00
AUTOMATIC1111 888b928f0d Merge pull request #14276 from AUTOMATIC1111/fix-styles
Fix styles
2023-12-14 10:14:50 +03:00
AUTOMATIC1111 b55f09c4e1 Merge pull request #14270 from kaalibro/extra-options-elem-id
Assign id for "extra_options". Replace numeric field with slider.
2023-12-14 10:14:46 +03:00
AUTOMATIC1111 c7cd9b441d Merge pull request #14296 from akx/paste-resolution
Allow pasting in WIDTHxHEIGHT strings into the width/height fields
2023-12-14 10:14:41 +03:00
AUTOMATIC1111 6ef0ff39f2 Merge pull request #14300 from AUTOMATIC1111/oft_fixes
Fix wrong implementation in network_oft
2023-12-14 10:14:19 +03:00
AUTOMATIC1111 aeaf1c510f Merge pull request #14293 from HinaHyugaHime/master
Bump torch-rocm to 5.6/5.7
2023-12-14 10:10:58 +03:00
AUTOMATIC1111 097140ac1a Merge branch 'dev' into master 2023-12-14 10:10:43 +03:00
AUTOMATIC1111 778a30a95e Merge pull request #14237 from ReneKroon/dev
#13354 : solve lora loading issue
2023-12-14 10:08:03 +03:00
AUTOMATIC1111 96c393a7a7 Merge pull request #14269 from kaalibro/skip-interrupt-keyb-shortcuts
Add keyboard shortcuts for generate/skip/interrupt
2023-12-14 10:04:17 +03:00
AUTOMATIC1111 09013b357c Merge pull request #14216 from wfjsw/state-dict-ref-comparison
change state dict comparison to ref compare
2023-12-14 10:03:14 +03:00
AUTOMATIC1111 d45f790f58 Merge pull request #14230 from AUTOMATIC1111/add-option-Live-preview-in-full-page-image-viewer
add option: Live preview in full page image viewer
2023-12-14 09:59:48 +03:00
AUTOMATIC1111 8c32594d3b Merge pull request #14208 from CodeHatchling/soft-inpainting
Soft Inpainting
2023-12-14 09:56:12 +03:00
AUTOMATIC1111 f3cc5f8382 Merge pull request #14229 from Nuullll/ipex-embedding
[IPEX] Fix embedding and ControlNet
2023-12-14 09:52:23 +03:00
AUTOMATIC1111 28bafffdc2 Merge pull request #14266 from kaalibro/dev
Re-add setting lost as part of e294e46
2023-12-14 09:48:36 +03:00
AUTOMATIC1111 5db09d1865 Merge pull request #14276 from AUTOMATIC1111/fix-styles
Fix styles
2023-12-14 09:48:14 +03:00
AUTOMATIC1111 c5631aa90d Merge pull request #14270 from kaalibro/extra-options-elem-id
Assign id for "extra_options". Replace numeric field with slider.
2023-12-14 09:46:05 +03:00
AUTOMATIC1111 206de1a6b0 Merge pull request #14296 from akx/paste-resolution
Allow pasting in WIDTHxHEIGHT strings into the width/height fields
2023-12-14 09:41:18 +03:00
AUTOMATIC1111 b943eebb1d Merge pull request #14300 from AUTOMATIC1111/oft_fixes
Fix wrong implementation in network_oft
2023-12-14 09:39:57 +03:00
Kohaku-Blueleaf 3772a82a70 better naming and correct order for device. 2023-12-14 01:47:13 +08:00
Kohaku-Blueleaf 8fc67f3851 remove debug print 2023-12-14 01:44:49 +08:00
Kohaku-Blueleaf 265bc26c21 Use self.scale instead of custom finalize 2023-12-14 01:43:24 +08:00
Kohaku-Blueleaf 735c9e8059 Fix network_oft 2023-12-14 01:38:32 +08:00
Aarni Koskela 89cfbc3bbe Allow pasting in WIDTHxHEIGHT strings into the width/height fields 2023-12-13 12:22:13 +02:00
Hina bda86f0fd9 Update webui.sh 2023-12-12 19:39:14 -06:00
w-e-w cc41cc4349 on mouse hover show / hide modal image viewer icons 2023-12-13 02:06:56 +09:00
kaalibro 6513470f0d Remove unnecessary 'else', add 'lightboxModal' check 2023-12-11 18:06:08 +06:00
kaalibro cee1a40651 Fix linter issues 2023-12-10 17:06:12 +06:00
kaalibro 1d42babd32 Replace Ctrl+Alt+Enter with Esc 2023-12-10 16:28:56 +06:00
kaalibro 6b8143a84e Number of columns slider: max count set to 20, add description info 2023-12-10 15:35:06 +06:00
w-e-w 8b74389e76 fix styles.csv filename 2023-12-10 15:48:16 +09:00
w-e-w 23a0e60b9b fix save styles 2023-12-10 15:48:00 +09:00
drhead 5381405eaa re-derive sqrt alpha bar and sqrt one minus alphabar
This is the only place these values are ever referenced outside of training code, so this change is well justified and more consistent.
2023-12-09 14:09:28 -05:00
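The quantities named above correspond to standard DDPM schedule terms; a sketch of the derivation from alphas_cumprod (the schedule values here are illustrative):

```
import torch

betas = torch.linspace(1e-4, 2e-2, 1000)            # illustrative schedule
alphas_cumprod = torch.cumprod(1.0 - betas, dim=0)

# deriving both terms from alphas_cumprod keeps them consistent
# instead of carrying separately stored copies
sqrt_alphas_cumprod = torch.sqrt(alphas_cumprod)
sqrt_one_minus_alphas_cumprod = torch.sqrt(1.0 - alphas_cumprod)
```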
kaalibro 1a79a5049b Assign id for "extra_options". Replace numeric field with slider in Settings. 2023-12-09 22:35:31 +06:00
kaalibro 9c201550dd Add keyboard shortcuts for generation
(Removed Alt+Enter) Ctrl+Enter to start/restart generation
(New) Alt/Option+Enter to skip generation
(New) Ctrl+Alt/Option+Enter to interrupt generation
2023-12-09 21:04:45 +06:00
kaalibro 39ec4cfea9 Re-add setting lost as part of e294e46 2023-12-09 19:12:59 +06:00
Nuullll 049d5642e5 Fix format 2023-12-09 18:11:26 +08:00
Nuullll 5942979344 Fix ControlNet 2023-12-09 18:09:45 +08:00
CodeHatchling f1ff932caf Formatted soft_inpainting. 2023-12-08 17:33:11 -07:00
CodeHatchling b2414476ef soft_inpainting now appears in the "inpaint" section, and will not activate unless inpainting is activated. 2023-12-08 17:32:41 -07:00
Rene Kroon 16bdcce92d #13354: solve lora loading issue 2023-12-08 21:20:55 +01:00
CodeHatchling 659f62e120 Fixed grammar error. 2023-12-07 21:39:54 -07:00
CodeHatchling fc3e246c0f Fixed complaint about whitespace, updated help section for a parameter. 2023-12-07 20:28:38 -07:00
CodeHatchling f284ae23bc Added parameters for the composite stage, fixed batched generation. 2023-12-07 20:19:35 -07:00
CodeHatchling 0ef4a4cb23 Fixed error that occurs when using vanilla samplers (somehow). 2023-12-07 14:54:26 -07:00
CodeHatchling 56604f08a1 Moved image filters used by soft inpainting into soft_inpainting.py from images.py 2023-12-07 14:53:44 -07:00
CodeHatchling 8dbacc7d01 Fixed "No newline at end of file". 2023-12-07 14:30:30 -07:00
CodeHatchling 2abc417834 Re-implemented soft inpainting via a script. Also fixed some mistakes with the previous hooks, removed unnecessary formatting changes, removed code that I had forgotten to. 2023-12-07 14:28:02 -07:00
Kohaku-Blueleaf 39ebd5684b Merge remote-tracking branch 'origin/release_candidate' into test-fp8 2023-12-07 20:48:59 +08:00
CodeHatchling ac45789123 Removed soft inpainting, added hooks for soft inpainting to work instead. 2023-12-06 21:16:27 -07:00
CodeHatchling 4608f6236f Removed changes in some scripts since the arguments for soft inpainting are no longer passed through the same path as "mask_blur". 2023-12-06 18:11:17 -07:00
CodeHatchling e90d4334ad A custom blending function can be provided by p, replacing the use of soft_inpainting. 2023-12-06 18:02:07 -07:00
w-e-w 9d2cbf8e97 add option: Live preview in full page image viewer
make #13459 "show the preview image in the modal view if available" optional
2023-12-06 23:06:32 +09:00
Nuullll 746783f7a4 [IPEX] Fix embedding
Cast `torch.bmm` args into same `dtype`.

Fixes the following error when using Text Inversion embedding (#14224):

```
RuntimeError: could not create a primitive descriptor for a matmul
primitive
```
2023-12-06 20:55:47 +08:00
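A hedged sketch of the kind of hijack the commit above describes, casting torch.bmm's second operand to match the first; the actual fix may cast differently:

```
import torch

_original_bmm = torch.bmm

def bmm_same_dtype(input, mat2, *, out=None):
    # oneDNN's matmul primitive requires both operands to share a dtype
    if input.dtype != mat2.dtype:
        mat2 = mat2.to(input.dtype)
    return _original_bmm(input, mat2, out=out)

torch.bmm = bmm_same_dtype
```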
fuchen.ljl c2bdbb67b6 Merge branch 'dev' into kingljl-patch-memory-leak 2023-12-06 20:42:04 +08:00
fuchen.ljl 4d56383025 Long-running memory overflow issue
Problem: memory slowly grows with each generation until the process is restarted.
Observation: GC analysis shows nothing is actually being retained, so the underlying allocator is suspected.
Reason: under Linux, glibc allocates memory via brk and mmap, and memory allocated by brk cannot be returned to the OS until the higher-address memory above it is released. In other words, if two blocks A and B are allocated through brk, A cannot be released before B is; it remains attributed to the process and shows up as a suspected "memory leak" in top.
So I switched to TCMalloc, but found that libtcmalloc_minimal could not resolve pthread_key_create; analysis showed that pthread support had not been enabled at compile time.
2023-12-06 20:23:56 +08:00
Kohaku-Blueleaf 294ec5ac37 Merge branch 'dev' into test-fp8 2023-12-06 15:16:49 +08:00
Kohaku-Blueleaf 672dc4efa8 Fix forced reload 2023-12-06 15:16:10 +08:00
Jabasukuriputo Wang 895456c4a2 change state dict comparison to ref compare 2023-12-05 18:00:48 -06:00
AUTOMATIC1111 120a84bd2f Merge pull request #14203 from AUTOMATIC1111/remove-clean_text()
remove clean_text()
2023-12-05 07:15:54 +03:00
AUTOMATIC1111 f92d61497a Merge pull request #14203 from AUTOMATIC1111/remove-clean_text()
remove clean_text()
2023-12-05 07:15:39 +03:00
CodeHatchling 38864816fa Merge remote-tracking branch 'origin2/dev' into soft-inpainting
# Conflicts:
#	modules/processing.py
2023-12-04 20:38:13 -07:00
CodeHatchling 49bbf11407 Fixed unused import. 2023-12-04 19:47:40 -07:00
CodeHatchling 6fc12428e3 Fixed issue where batched inpainting (batch size > 1) wouldn't work because of mismatched tensor sizes. The 'already_decoded' case should also be handled correctly (tested indirectly). 2023-12-04 19:42:59 -07:00
CodeHatchling b32a334e3d Applies a convert('RGBA') operation early to mimic previous behaviour. 2023-12-04 17:57:10 -07:00
CodeHatchling 60c602232f Restored original formatting. 2023-12-04 17:55:14 -07:00
CodeHatchling 57f29bd61d Re-introduce latent blending step from the vanilla inpainting procedure. 2023-12-04 17:41:18 -07:00
CodeHatchling 1455159cf4 Fixed issue with whitespace, removed commented out code that was meant to be used as a reference. 2023-12-04 16:43:57 -07:00
CodeHatchling 976c1053ef Cleaned up code, moved main code contributions into soft_inpainting.py 2023-12-04 16:06:58 -07:00
w-e-w 854f8c318c remove clean_text() 2023-12-05 04:41:09 +09:00
AUTOMATIC1111 368d66c9cc add hypertile infotext 2023-12-04 15:56:11 +03:00
AUTOMATIC1111 22e23dbf29 add hypertile infotext 2023-12-04 15:56:03 +03:00
AUTOMATIC1111 81105ee013 repair old handler for postprocessing API in a way that doesn't break interface 2023-12-04 13:11:12 +03:00
AUTOMATIC1111 883d6a2b34 repair old handler for postprocessing API in a way that doesn't break interface 2023-12-04 13:11:00 +03:00
AUTOMATIC1111 24dae9bc4c repair old handler for postprocessing API 2023-12-04 12:36:56 +03:00
AUTOMATIC1111 15322e1b1a repair old handler for postprocessing API 2023-12-04 12:36:41 +03:00
CodeHatchling 259d33c3c8 Enables the original functionality to be toggled on and off. 2023-12-04 01:57:21 -07:00
Kohaku-Blueleaf f5f89780cc Merge branch 'dev' into test-fp8 2023-12-04 16:47:41 +08:00
CodeHatchling aaacf48232 Organized the settings and UI of soft inpainting to allow for toggling the feature, and centralizes default values to reduce the amount of copy-pasta. 2023-12-04 01:27:22 -07:00
AUTOMATIC1111 48fae7ccdc update changelog 2023-12-04 09:35:52 +03:00
AUTOMATIC1111 9e1f3feb12 make webui not crash when running with --disable-all-extensions option 2023-12-04 09:15:19 +03:00
AUTOMATIC1111 208760f348 Merge pull request #14192 from illtellyoulater/patch-1
Update launch_utils.py - fixes repetead package reinstalls
2023-12-04 08:14:40 +03:00
missionfloyd 06725af40b Lint 2023-12-03 21:26:12 -07:00
illtellyoulater 639ccf254b Update launch_utils.py to fix wrong dep. checks and reinstalls
Fixes failing dependency checks for extensions whose package name differs from their import name (for example ffmpeg-python / ffmpeg), which currently causes unneeded reinstalls of packages at runtime.

In fact, with the current code, the same string is used when installing a package and when checking for its presence, as you can see in the following example:

> launch_utils.run_pip("install ffmpeg-python", "required package")
[ Installing required package: "ffmpeg-python" ... ]
[ Installed ]

> launch_utils.is_installed("ffmpeg-python")
False

... whereas it would actually return True with:

> launch_utils.is_installed("ffmpeg")
True
2023-12-04 02:35:35 +00:00
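One way to make such a check robust is to consult both the import name and the pip distribution name; a sketch under that assumption (not necessarily the exact change in the PR):

```
import importlib.metadata
import importlib.util

def is_installed(name: str) -> bool:
    # try the import name first ("ffmpeg"), then fall back to the
    # pip distribution name ("ffmpeg-python")
    if importlib.util.find_spec(name.replace("-", "_")) is not None:
        return True
    try:
        importlib.metadata.distribution(name)
        return True
    except importlib.metadata.PackageNotFoundError:
        return False
```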
CodeHatchling 552f8bc832 "Uncrop" the original denoised image for the composite step, fixing a "ValueError: Images do not match" *shudder* 2023-12-03 14:49:41 -07:00
CodeHatchling 28a2b5b4aa Fixed a math mistake. 2023-12-03 14:20:20 -07:00
AUTOMATIC1111 334298d473 Merge pull request #14186 from akx/torchvision-basicsr-hack
Add import_hook hack to work around basicsr/torchvision incompatibility
2023-12-03 19:58:53 +03:00
AUTOMATIC1111 2d5507fce5 Merge pull request #14181 from AUTOMATIC1111/rework-mask-and-mask_composite-logic
slight optimization for mask and mask_composite
2023-12-03 19:58:14 +03:00
Aarni Koskela d92ce145bb Add import_hook hack to work around basicsr incompatibility
Fixes #13985
2023-12-03 16:55:38 +02:00
w-e-w d3fdc4af61 rework mask and mask_composite logic 2023-12-03 18:22:41 +09:00
AUTOMATIC1111 b4776ea3a2 Merge pull request #14177 from catboxanon/fix/mask-composite-save
Fix `save_samples` being checked early when saving masked composite
2023-12-03 11:57:14 +03:00
CodeHatchling 3bd3a09160 Merge remote-tracking branch 'origin/dev' into soft-inpainting
# Conflicts:
#	modules/processing.py
2023-12-02 21:14:02 -07:00
CodeHatchling bb04d400c9 Rewrote latent_blend() to use in-place operations and to aggressively "del" references with the intention of minimizing allocations and easing garbage collection. 2023-12-02 21:08:26 -07:00
CodeHatchling 73ab982d1b Blend masks are now produced afterward, based on an estimate of the visual difference between the original and modified latent images. This should remove ghosting and clipping artifacts from masks, while preserving the details of largely unchanged content. 2023-12-02 21:07:02 -07:00
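A minimal sketch of producing a blend mask from the latent difference, as the commit above describes; tensor layout (NCHW) and normalization are assumptions:

```
import torch

def estimate_blend_mask(original, modified):
    # per-pixel magnitude of change across latent channels;
    # larger change -> blend more strongly toward the modified image
    diff = (modified - original).abs().mean(dim=1, keepdim=True)
    return (diff / (diff.max() + 1e-8)).clamp(0.0, 1.0)
```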
AUTOMATIC1111 fed5b1d55c Merge pull request #14178 from catboxanon/fix/missing-setting-v1
Re-add `keyedit_delimiters_whitespace` setting lost as part of commit e294e46
2023-12-03 06:33:16 +03:00
Kohaku-Blueleaf 9a15ae2a92 Merge branch 'dev' into test-fp8 2023-12-03 10:54:54 +08:00
CodeHatchling 609dea36ea Added utility functions related to processing masks. 2023-12-02 18:56:49 -07:00
catboxanon 9528d66c94 Re-add setting lost as part of e294e46 2023-12-02 14:56:26 -05:00
drhead 78acdcf677 fix variable 2023-12-02 14:09:18 -05:00
drhead dc1adeecdd Create alphas_cumprod_original on full precision path 2023-12-02 14:06:56 -05:00
drhead 4a43334376 Revert 309a606c 2023-12-02 14:05:42 -05:00
catboxanon 83e8c32276 Fix save_samples being checked early when saving masked composite 2023-12-02 13:30:53 -05:00
drhead 81c4ddf6eb fix linting 2023-12-02 13:11:00 -05:00
drhead 309a606c2f ensure that original alpha bar always exists 2023-12-02 13:07:45 -05:00
AUTOMATIC1111 ac02216e54 alternate implementation for unet forward replacement that does not depend on hijack being applied 2023-12-02 19:35:47 +03:00
AUTOMATIC1111 af5f0734c9 Merge pull request #14171 from Nuullll/ipex
Initial IPEX support for Intel Arc GPU
2023-12-02 19:22:32 +03:00
AUTOMATIC1111 a5f61aa8c5 potential fix for #14172 2023-12-02 18:03:34 +03:00
AUTOMATIC1111 11d23e8ca5 remove Train/Preprocessing tab and put all its functionality into extras batch images mode 2023-12-02 18:01:11 +03:00
Kohaku-Blueleaf 50a21cb09f Ensure the cached weight will not be affected 2023-12-02 22:06:47 +08:00
Nuullll 96871e4f74 Remove webui-ipex-user.bat 2023-12-02 17:11:31 +08:00
AUTOMATIC1111 4a666381bf extras tab batch: actually use original filename
preprocessing upscale: do not do an extra upscale step if it's not needed
2023-12-02 12:11:21 +03:00
Kohaku-Blueleaf 110485d5bb Merge branch 'dev' into test-fp8 2023-12-02 17:00:09 +08:00
Nuullll 87cd07b3af Fix fp64 2023-12-02 15:54:25 +08:00
AUTOMATIC1111 0bb6e00ba3 Merge pull request #13957 from h43lb1t0/extra_network_subdirs
dir buttons start with / so only the correct dir will be shown and no…
2023-12-02 09:59:29 +03:00
AUTOMATIC1111 87d973e389 Merge pull request #14063 from wfjsw/use-ext-name-for-installed
use extension name to determine whether an extension is installed in the index
2023-12-02 09:58:44 +03:00
AUTOMATIC1111 ef6b8123dc put code that can cause an exception into its own function for #14120 2023-12-02 09:57:39 +03:00
AUTOMATIC1111 5ed7daa3d9 Merge pull request #14120 from AUTOMATIC1111/protect-against-bad-ui-creation-scripts
catch uncaught exception with ui creation scripts
2023-12-02 09:54:21 +03:00
AUTOMATIC1111 ef1723ef41 Merge pull request #14125 from cjj1977/dev
Allow use of multiple styles csv files
2023-12-02 09:53:27 +03:00
AUTOMATIC1111 7547d7c791 Merge pull request #14126 from aria1th/hypertile-xyz
Support XYZ scripts / split hires path from unet
2023-12-02 09:48:40 +03:00
AUTOMATIC1111 88736b5557 Merge pull request #14131 from read-0nly/patch-1
Update devices.py - Make 'use-cpu all' actually apply to 'all'
2023-12-02 09:46:19 +03:00
AUTOMATIC1111 9eadc4f146 Merge pull request #14121 from AUTOMATIC1111/fix-Auto-focal-point-crop-for-opencv-4.8.x
Fix auto focal point crop for opencv >= 4.8
2023-12-02 09:46:00 +03:00
AUTOMATIC1111 97c8e7e0c7 Merge pull request #14119 from AUTOMATIC1111/add-Block-component-creation-callback
add Block component creation callback
2023-12-02 09:45:03 +03:00
AUTOMATIC1111 e12a26c253 Merge pull request #14046 from hidenorly/AddFP32FallbackSupportOnSdVaeApprox
Add FP32 fallback support on sd_vae_approx
2023-12-02 09:44:00 +03:00
AUTOMATIC1111 600036d158 Merge pull request #14156 from AUTOMATIC1111/metadata-pop-up-size-limit
fix being unable to exit the metadata popup when the popup is too big
2023-12-02 09:30:27 +03:00
AUTOMATIC1111 4125552752 Merge pull request #14170 from MrCheeze/sd-turbo
Add support for SD 2.1 Turbo
2023-12-02 09:30:07 +03:00
AUTOMATIC1111 e294e46d46 split UI settings page into many 2023-12-02 09:26:38 +03:00
Nuullll 7499148ad4 Disable ipex autocast due to its bad perf 2023-12-02 14:00:46 +08:00
AUTOMATIC1111 b58d061e41 infotext updates: add option to disregard certain infotext fields, add option to not include VAE in infotext, add explanation to infotext settings page, move some options to infotext settings page 2023-12-02 08:33:28 +03:00
MrCheeze 6080045b2a Add support for SD 2.1 Turbo, by converting the state dict from SGM to LDM on load 2023-12-01 22:58:05 -05:00
MrCheeze 293f44e6c1 Fix bug where is_using_v_parameterization_for_sd2 fails because the sd_hijack is only partially undone 2023-12-01 22:56:08 -05:00
missionfloyd 01c8f1803a Close popups with escape key 2023-11-30 22:36:12 -07:00
w-e-w c2ed413203 add max-height/width to global-popup-inner
prevent the pop-up from being so big that exiting it becomes impossible
2023-12-01 02:59:41 +09:00
Nuullll 8b40f475a3 Initial IPEX support 2023-11-30 20:22:46 +08:00
drhead 668ae34e21 remove debug print 2023-11-29 22:48:31 -05:00
catboxanon de79597ab9 Only apply ztSNR related code if alphas_cumprod exists 2023-11-29 18:33:32 -05:00
catboxanon ffa7f8201d Lint 2023-11-29 18:10:43 -05:00
catboxanon ec6ee5c13b Fix infotext for ztSNR 2023-11-29 18:10:27 -05:00
drhead 6d0a8dcd89 Implement zero terminal SNR schedule option 2023-11-29 17:42:07 -05:00
drhead 588a52891d Add options for zero terminal SNR 2023-11-29 17:40:23 -05:00
drhead b25c126ccd Protect alphas_cumprod from downcasting 2023-11-29 17:38:53 -05:00
CodeHatchling c7a1ff8720 Tweaked default values. 2023-11-28 23:31:10 -07:00
CodeHatchling 284fd8f415 Tweaked UI sliders and labels. 2023-11-28 23:03:50 -07:00
CodeHatchling c5c7fa06aa Added slider for detail preservation strength, removed largely needless offset parameter, changed labels in UI and for saving to/pasting data from PNG files. 2023-11-28 22:35:07 -07:00
CodeHatchling debf836fcc Added UI elements to control blending parameters. 2023-11-28 16:15:36 -07:00
CodeHatchling a6e5846453 Nerfs the aggressive post-processing step of overlaying the original image. 2023-11-28 16:13:42 -07:00
CodeHatchling e715e46b6a Implements "scheduling" for blending of the original latents and a latent blending formula that preserves details in blend transition areas. 2023-11-28 16:10:22 -07:00
CodeHatchling bbba133f05 Removed conflicting step that replaces the softly inpainted latents with a naive blend with the original latents. 2023-11-28 15:09:43 -07:00
CodeHatchling dec791d35d Removed code which forces the inpainting mask to be 0 or 1. Now fractional values (e.g. 0.5) are accepted. 2023-11-28 15:05:01 -07:00
hidenorly 81c00728b8 Fix the Ruff error about unused import 2023-11-29 04:59:35 +09:00
hidenorly a0096c5897 Add FP32 fallback support on torch.nn.functional.interpolate
This tries to execute interpolate with FP32 if it failed.

Background: on some environments, such as Mx-chip macOS devices, we get an error as follows:

```
"torch/nn/functional.py", line 3931, in interpolate
        return torch._C._nn.upsample_nearest2d(input, output_size, scale_factors)
               ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
    RuntimeError: "upsample_nearest2d_channels_last" not implemented for 'Half'
```

In this case, ```--no-half``` doesn't help. Therefore this commit adds an FP32 fallback execution to solve it.

Note that ```upsample_nearest2d``` is called from ```torch.nn.functional.interpolate```, and the fallback for torch.nn.functional.interpolate is necessary in
```modules/sd_vae_approx.py```'s ```VAEApprox.forward``` and
```repositories/stable-diffusion-stability-ai/ldm/modules/diffusionmodules/openaimodel.py```'s ```Upsample.forward```.
2023-11-29 04:45:04 +09:00
hidenorly 39eae9f009 Revert "Add FP32 fallback support on sd_vae_approx"
This reverts commit 58c19545c8.
Since the modification is expected to move to mac_specific.py
(https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/14046#issuecomment-1826731532)
2023-11-29 04:07:48 +09:00
w-e-w d608926f81 reformat file with uniform indentation 2023-11-28 12:12:27 +09:00
w-e-w 03ee297aa2 fix Auto focal point crop for opencv >= 4.8.x
autocrop.download_and_cache_models:
in opencv >= 4.8 the face detection model was updated;
download the model based on the opencv version;
return the model path or raise an exception
2023-11-28 12:09:51 +09:00
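A hedged sketch of version-dependent model selection as described above; the model file names follow OpenCV Zoo naming and are assumptions, not confirmed from the commit:

```
import cv2

def face_detection_model_name():
    major, minor = (int(p) for p in cv2.__version__.split(".")[:2])
    # opencv >= 4.8 ships an updated YuNet face detection model
    if (major, minor) >= (4, 8):
        return "face_detection_yunet_2023mar.onnx"
    return "face_detection_yunet_2022mar.onnx"
```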
obsol 3cd6e1d0a0 Update devices.py
fixes issue where "--use-cpu all" properly makes SD run on the CPU but leaves ControlNet (and presumably other extensions) pointed at the GPU, causing a crash in ControlNet due to a device mismatch between SD and CN

https://github.com/AUTOMATIC1111/stable-diffusion-webui/issues/14097
2023-11-27 19:21:43 -05:00
aria1th ec78354efa hypertile_xyz: we don't need isnumeric check for AxisOption 2023-11-27 22:25:28 +09:00
aria1th 524d6a4dba fix ruff - set comprehension 2023-11-27 22:13:18 +09:00
aria1th f207eb7a0d fix ruff in hypertile_xyz.py 2023-11-27 22:11:28 +09:00
aria1th 601a7b4ce5 cache divisors / fix ruff 2023-11-27 22:10:31 +09:00
Charlie Joynt 0cd5b0ed54 Merge branch 'dev' of https://github.com/cjj1977/stable-diffusion-webui into dev 2023-11-27 12:11:06 +00:00
aria1th 23c36f59b4 Support XYZ scripts / split hires path from unet 2023-11-27 21:10:26 +09:00
Charlie Joynt 26a0c29587 Allow use of multiple styles csv files
* https://github.com/AUTOMATIC1111/stable-diffusion-webui/issues/14122
Fix edge case where style text has multiple {prompt} placeholders
* https://github.com/AUTOMATIC1111/stable-diffusion-webui/issues/14005
2023-11-27 12:08:51 +00:00
MisterSeajay a75314b41f bugfix for warning message (#6)
* bugfix for warning message

* bugfix error message
2023-11-27 12:03:42 +00:00
MisterSeajay 1c64bb7140 bugfix for warning message (#6) 2023-11-27 11:57:27 +00:00
Charlie Joynt 9621ca4d64 Allow use of multiple styles csv files 2023-11-27 11:39:50 +00:00
w-e-w 8a6e4bda21 catch uncaught exception with ui creation scripts
prevent total webui crash
2023-11-27 14:18:17 +09:00
w-e-w b30cc87b78 add Block component creation callback 2023-11-27 13:15:17 +09:00
Jabasukuriputo Wang 1f6844eb7e also consider extension url 2023-11-26 10:04:39 -06:00
AUTOMATIC1111 f0f100e67b add categories to settings 2023-11-26 17:56:22 +03:00
AUTOMATIC1111 500de919ed Merge pull request #14108 from AUTOMATIC1111/json.dump(ensure_ascii=False)
json.dump(ensure_ascii=False)
2023-11-26 16:15:56 +03:00
w-e-w a15dd151ff json.dump(ensure_ascii=False)
improve json readability
2023-11-26 21:56:21 +09:00
AUTOMATIC1111 2a40d3c603 compact prompt layout: preserve scroll when switching between lora tabs 2023-11-26 14:58:56 +03:00
Kohaku-Blueleaf 3d341ebc7d Merge branch 'dev' into test-fp8 2023-11-26 17:32:52 +08:00
AUTOMATIC1111 29f04149b6 update torch to 2.1.0 2023-11-26 12:07:33 +03:00
AUTOMATIC1111 e44103264d Merge pull request #13936 from cabelo/compatibility
Compatibility
2023-11-26 11:57:13 +03:00
AUTOMATIC1111 6955c210b7 Merge pull request #14059 from akx/upruff
Update Ruff to 0.1.6
2023-11-26 11:54:36 +03:00
AUTOMATIC1111 d1750e5eca fix linter errors 2023-11-26 11:37:12 +03:00
AUTOMATIC1111 c5a0c59a83 do not save HTML explanations from options page to config 2023-11-26 11:36:17 +03:00
AUTOMATIC1111 f7f015e84b Merge pull request #14084 from wfjsw/move-from-sysinfo-to-errors
Move exception_records related methods to errors.py
2023-11-26 11:29:27 +03:00
AUTOMATIC1111 f85b74763d Merge branch 'hypertile-in-sample' into dev 2023-11-26 11:18:49 +03:00
AUTOMATIC1111 fd8674a4bc Merge pull request #13948 from aria1th/hypertile-in-sample
support HyperTile optimization
2023-11-26 11:18:25 +03:00
AUTOMATIC1111 d2e0c1ca13 rework hypertile into a built-in extension 2023-11-26 11:17:38 +03:00
AUTOMATIC1111 3a9bf4ac10 move file 2023-11-26 08:29:12 +03:00
Kohaku-Blueleaf 40ac134c55 Fix pre-fp8 2023-11-25 12:35:09 +08:00
Jabasukuriputo Wang 5cedc8f9b2 remove traceback in sysinfo 2023-11-24 11:30:30 -06:00
Jabasukuriputo Wang 86b99b1e98 Move exception_records related methods to errors.py 2023-11-24 11:28:54 -06:00
wfjsw ac2a981c4f use extension name to determine whether an extension is installed in the index 2023-11-22 22:40:24 -06:00
Aarni Koskela 066afda2f6 Simplify restart_sampler (suggested by ruff) 2023-11-22 18:05:12 +02:00
Aarni Koskela 8fe1e19522 Update ruff to 0.1.6 2023-11-22 18:05:12 +02:00
Kohaku-Blueleaf f5d719d1f1 Add forced reload for fp16 cache 2023-11-22 01:45:56 +08:00
Kohaku-Blueleaf 370a77f8e7 Option for using fp16 weight when apply lora 2023-11-21 19:59:34 +08:00
AUTOMATIC1111 8aa51f682c fix [Bug]: (Dev Branch) Placing "Dimensions" first in "ui_reorder_list" prevents start #14047 2023-11-21 08:32:07 +03:00
hidenorly 58c19545c8 Add FP32 fallback support on sd_vae_approx
This tries to execute interpolate with FP32 if it failed.

Background: on some environments, such as Mx-chip macOS devices, we get an error as follows:

```
"torch/nn/functional.py", line 3931, in interpolate
        return torch._C._nn.upsample_nearest2d(input, output_size, scale_factors)
               ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
    RuntimeError: "upsample_nearest2d_channels_last" not implemented for 'Half'
```

In this case, ```--no-half``` doesn't help to solve. Therefore this commits add the FP32 fallback execution to solve it.

Note that the submodule may require additional modifications. The following is the example modification on the other submodule.

```repositories/stable-diffusion-stability-ai/ldm/modules/diffusionmodules/openaimodel.py

class Upsample(nn.Module):
..snip..
    def forward(self, x):
        assert x.shape[1] == self.channels
        if self.dims == 3:
            x = F.interpolate(
                x, (x.shape[2], x.shape[3] * 2, x.shape[4] * 2), mode="nearest"
            )
        else:
            try:
                x = F.interpolate(x, scale_factor=2, mode="nearest")
            except:
                x = F.interpolate(x.to(th.float32), scale_factor=2, mode="nearest").to(x.dtype)
        if self.use_conv:
            x = self.conv(x)
        return x
..snip..
```

You can see the FP32 fallback execution as same as sd_vae_approx.py.
2023-11-21 01:13:53 +09:00
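The same fallback pattern can be factored into a small helper; a minimal sketch (the helper name is illustrative, not from the repository):
```
import torch
import torch.nn.functional as F

def interpolate_with_fp32_fallback(x: torch.Tensor, **kwargs) -> torch.Tensor:
    """Try interpolate in the tensor's own dtype; retry in FP32 if the kernel is missing."""
    try:
        return F.interpolate(x, **kwargs)
    except RuntimeError:
        # e.g. "upsample_nearest2d_channels_last" not implemented for 'Half' on MPS
        return F.interpolate(x.to(torch.float32), **kwargs).to(x.dtype)
```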
Tom Haelbich 314ae1535e added option for default behavior of dir buttons 2023-11-20 16:19:54 +01:00
AUTOMATIC1111 5f36f6ab21 Merge pull request #14009 from AUTOMATIC1111/Option-to-show-batch-img2img-results-in-UI
Option to show batch img2img results in UI
2023-11-20 17:44:58 +03:00
AUTOMATIC1111 1463cea949 Merge branch 'dag' into dev 2023-11-20 14:50:01 +03:00
AUTOMATIC1111 73a0b4bba6 Merge pull request #13944 from wfjsw/dag
implementing script metadata and DAG sorting mechanism
2023-11-20 14:49:46 +03:00
AUTOMATIC1111 9b471436b2 rework extensions metadata: use custom sorter that doesn't mess the order as much and ignores cyclic errors, use classes with named fields instead of dictionaries, eliminate some duplicated code 2023-11-20 14:47:09 +03:00
Kohaku-Blueleaf b2e039d07b Update webui-macos-env.sh 2023-11-20 14:05:32 +08:00
AUTOMATIC1111 411da7c281 Merge pull request #14035 from AUTOMATIC1111/sysinfo-json
save sysinfo as .json
2023-11-20 08:56:45 +03:00
w-e-w 6d337bf23d save sysinfo as .json
GitHub now allows uploading of .json files in issues
2023-11-20 01:38:31 +09:00
w-e-w dea5e43c83 Option to show batch img2img results in UI
shared.opts.img2img_batch_show_results_limit
limits the number of images returned to the UI for batch img2img
default limit: 32
0: no images are shown
-1: unlimited, all images are shown
2023-11-19 17:37:32 +09:00
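The semantics of the limit can be sketched as follows (an illustrative helper; only `img2img_batch_show_results_limit` is a real option name):
```
def apply_batch_results_limit(images: list, limit: int) -> list:
    """Select which batch img2img results to send to the UI."""
    if limit == 0:    # show no images
        return []
    if limit == -1:   # unlimited: show all images
        return images
    return images[:limit]  # show at most `limit` images (default 32)
```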
Kohaku-Blueleaf 043d2edcf6 Better naming 2023-11-19 15:56:31 +08:00
Kohaku-Blueleaf f383af2729 update xformers/torch versions 2023-11-19 15:56:23 +08:00
Kohaku-Blueleaf 890181e1d4 Update the xformers/torch versions 2023-11-19 15:54:39 +08:00
Kohaku-Blueleaf 598da5cd49 Use options instead of cmd_args 2023-11-19 15:50:06 +08:00
Kohaku-Blueleaf b60e1088db Merge branch 'dev' into test-fp8 2023-11-19 15:24:57 +08:00
wfjsw bde439ef67 use metadata.ini for meta filename 2023-11-19 00:58:47 -06:00
AUTOMATIC1111 fc83af4432 Merge pull request #13931 from AUTOMATIC1111/style-hotkeys
Enable prompt hotkeys in style editor
2023-11-19 09:11:49 +03:00
AUTOMATIC1111 337bc4a2fb Merge pull request #13014 from AUTOMATIC1111/thread-safe-extranetworks-list_items
thread safe extra network list_items
2023-11-19 09:09:21 +03:00
AUTOMATIC1111 6fac65f334 Merge pull request #13929 from kingljl/fix-dependency-address-patch-1
Fix dependency address patch 1
2023-11-19 09:01:39 +03:00
AUTOMATIC1111 5a031d9233 Merge pull request #13962 from kaalibro/dev
Fixes generation restart not working for some users when 'Ctrl+Enter' is pressed
2023-11-19 09:01:11 +03:00
AUTOMATIC1111 e4e875fffe Merge pull request #13968 from kaalibro/extranetworks-path-sorting
Adds 'Path' sorting for Extra network cards
2023-11-19 09:00:05 +03:00
AUTOMATIC1111 b945ba716b Merge pull request #13977 from AUTOMATIC1111/hotfix-postprocessing-state-end
Hotfix: call shared.state.end() after postprocessing done
2023-11-19 08:59:32 +03:00
AUTOMATIC1111 2207ef363a Merge pull request #13692 from v0xie/network-oft
Support inference with OFT networks
2023-11-19 08:59:09 +03:00
AUTOMATIC1111 3a13b0e762 Merge pull request #13996 from Luxter77/patch-1
Adds tqdm handler to logging_config.py for progress bar integration
2023-11-19 08:57:14 +03:00
AUTOMATIC1111 6429c3db11 Merge pull request #13826 from ezxzeng/ui_mobile_optimizations
added accordion settings options
2023-11-19 08:42:58 +03:00
AUTOMATIC1111 5a9dc1c0ca Merge pull request #14004 from storyicon/master
feat: fix randn found element of type float at pos 2
2023-11-19 08:40:29 +03:00
storyicon 4f2a4a3615 feat: fix randn found element of type float at pos 2
Signed-off-by: storyicon <storyicon@foxmail.com>
2023-11-17 09:48:18 +00:00
aria1th 97431f29fe fix double gc and decoding with unet context 2023-11-17 10:05:28 +09:00
aria1th ffd0f8ddc3 set empty value for SD XL 3rd layer 2023-11-17 09:54:33 +09:00
aria1th c0725ba2d0 Fix inverted option issue
I'm pretty sure I was sleepy while implementing this
2023-11-17 09:34:50 +09:00
aria1th c40be2252a Fix critical issue - unet apply 2023-11-17 09:22:27 +09:00
Your Name 7021cdb1de actually adds handler to logging_config.py 2023-11-16 17:53:57 -03:00
Lucas Daniel Velazquez M cdb60a690d Take into account tqdm not being installed before first boot for logging 2023-11-16 16:49:59 -03:00
Lucas Daniel Velazquez M 236eb82c3a Adds tqdm handler to logging_config.py for progress bar integration 2023-11-16 13:20:33 -03:00
Kohaku-Blueleaf cd12256575 Merge branch 'dev' into test-fp8 2023-11-16 21:53:13 +08:00
AngelBottomless 472c22cc8a fix ruff - add newline 2023-11-16 19:03:45 +09:00
AngelBottomless bcfaf3979a convert/add hypertile options 2023-11-16 18:43:16 +09:00
v0xie eb667e715a feat: LyCORIS/kohya OFT network support 2023-11-15 18:28:48 -08:00
v0xie d6d0b22e66 fix: ignore calc_scale() for COFT which has very small alpha 2023-11-15 03:08:50 -08:00
aria1th af45872fdb copy LDM VAE key from XL 2023-11-15 15:15:14 +09:00
aria1th b29fc6d4de Implement Hypertile
Co-Authored-By: Kieran Hunt <kph@hotmail.ca>
2023-11-15 15:13:39 +09:00
AngelBottomless a292d2c47f hotfix: call shared.state.end() after postprocessing done 2023-11-15 14:26:37 +09:00
kaalibro c1c816006e Adds 'Path' sorting for Extra network cards 2023-11-13 22:01:52 +06:00
kaalibro 94e9669566 Fixes generation restart not working for some users when 'Ctrl+Enter' is pressed 2023-11-13 14:51:06 +06:00
missionfloyd 8048f36072 Lint 2023-11-12 17:12:50 -07:00
Tom Haelbich f6762d2ad9 dir buttons start with / so only the correct dir will be shown, and not dirs whose names merely contain it as a substring 2023-11-12 14:14:16 +01:00
wfjsw 3bb32befe9 bug fix 2023-11-11 11:58:19 -06:00
wfjsw 48d6102b31 fix 2023-11-11 11:17:26 -06:00
wfjsw 520e52f846 allow comma and whitespace as separator 2023-11-11 10:58:26 -06:00
wfjsw 7af576e745 remove the assumption of same name 2023-11-11 10:46:47 -06:00
aria1th 294f8a514f add hyperTile
https://github.com/tfernd/HyperTile
2023-11-11 23:28:12 +09:00
wfjsw bc1a450124 reverse the extension load order so builtin extensions load earlier natively 2023-11-11 04:08:45 -06:00
wfjsw 0d1924c48b populate loaded_extensions from extension list instead 2023-11-11 04:03:55 -06:00
wfjsw 0fc7dc1c04 implementing script metadata and DAG sorting mechanism 2023-11-11 04:01:13 -06:00
Emily Zeng 3a4a6c43a4 ExitStack as alternative to suppress 2023-11-10 16:06:01 -05:00
w-e-w 5432d93013 fix added accordion settings options 2023-11-11 05:30:35 +09:00
Alessandro de Oliveira Faria (A.K.A. CABELO) 6a86b3ad9b Compatibility with Debian 11, Fedora 34+ and openSUSE 15.4+ 2023-11-10 14:15:34 -03:00
missionfloyd 7ff54005fe Enable prompt hotkeys in style editor 2023-11-09 23:47:53 -07:00
Alessandro de Oliveira Faria (A.K.A. CABELO) 66767e3876 - opensuse compatibility 2023-11-10 03:45:44 -03:00
fuchen.ljl 6d77a6e1c6 Update README.md
Modify the stablediffusion dependency address
2023-11-10 14:40:39 +08:00
fuchen.ljl 42dbcad3ef Merge pull request #1 from kingljl/fix-dependency-address-patch-1
Update README.md
2023-11-10 14:38:26 +08:00
fuchen.ljl 98fc525a2c Update README.md
Modify the stablediffusion dependency address
2023-11-10 14:37:30 +08:00
Emily Zeng ff2952f105 multiline with statement for readability 2023-11-09 13:35:52 -05:00
Emily Zeng 9aa4d098f0 removed changes that weren't merged properly 2023-11-09 13:25:24 -05:00
Emily Zeng a625a7bb81 moved nested with to single line to remove extra tabs 2023-11-09 13:15:06 -05:00
ezxzeng f9c14a8c8c Merge branch 'dev' into ui_mobile_optimizations 2023-11-07 15:25:27 -05:00
AUTOMATIC1111 5e80d9ee99 fix pix2pix producing bad results 2023-11-07 11:33:33 +03:00
AUTOMATIC1111 47bccbebae Merge pull request #13884 from GerryDE/notification-sound-volume
Add option to set notification sound volume
2023-11-07 08:29:06 +03:00
GerryDE 9ba991cad8 Add option to set notification sound volume 2023-11-07 03:09:08 +01:00
AUTOMATIC1111 9c1c0da026 fix exception related to the pix2pix 2023-11-06 11:17:36 +03:00
AUTOMATIC1111 656437e0a5 fix img2img_tabs error 2023-11-06 10:32:21 +03:00
AUTOMATIC1111 6ad666e479 more changes for #13865: fix formatting, rename the function, add comment and add a readme entry 2023-11-05 19:46:20 +03:00
AUTOMATIC1111 80d639a440 linter 2023-11-05 19:32:21 +03:00
AUTOMATIC1111 96ee3eff6c Merge pull request #13865 from Gothos/master
Add support for SSD-1B
2023-11-05 19:31:44 +03:00
AUTOMATIC1111 ff805d8d0e Merge branch 'dev' into master 2023-11-05 19:30:57 +03:00
AUTOMATIC1111 c3699d4fd1 compact prompt option disabled by default 2023-11-05 19:23:48 +03:00
AUTOMATIC1111 4d4a9e7332 added compact prompt option 2023-11-05 19:19:55 +03:00
Ritesh Gangnani 44c5097375 Use devices.torch_gc() instead of empty_cache() 2023-11-05 20:31:57 +05:30
Ritesh Gangnani 44db35fb1a Added memory clearance after deletion 2023-11-05 19:15:38 +05:30
Ritesh Gangnani ff1609f91e Add SSD-1B as a supported model 2023-11-05 19:13:49 +05:30
AUTOMATIC1111 d9499f4301 properly apply sort order for extra network cards when selected from dropdown
allow selection of default sort order in settings
remove 'Default' sort order, replace with 'Name'
2023-11-05 10:12:50 +03:00
AUTOMATIC1111 16ab174290 eslint 2023-11-05 09:20:15 +03:00
AUTOMATIC1111 046c7b053a Merge pull request #13855 from gibiee/patch-1
Corrected a typo in `modules/cmd_args.py`
2023-11-05 08:57:59 +03:00
AUTOMATIC1111 6b8c661c49 add a visible checkbox to input accordion 2023-11-05 08:55:54 +03:00
gibiee 2b06cefe66 correct a typo
modify "defaul" to "default"
2023-11-05 11:37:23 +09:00
v0xie 7edd50f304 Merge pull request #2 from v0xie/network-oft-change-impl
Use same updown implementation for LyCORIS OFT as kohya-ss OFT
2023-11-04 15:06:04 -07:00
v0xie bbf00a96af refactor: remove unused function 2023-11-04 14:56:47 -07:00
v0xie 329c8bacce refactor: use same updown for both kohya OFT and LyCORIS diag-oft 2023-11-04 14:54:36 -07:00
Kohaku-Blueleaf c3facab495 Merge branch 'dev' into test-fp8 2023-11-04 12:56:58 +08:00
v0xie 1dd25be037 Merge pull request #1 from v0xie/oft-faster
Support LyCORIS diag-oft OFT implementation (minus MultiheadAttention layer), maintains support for kohya-ss OFT
2023-11-03 19:47:27 -07:00
v0xie f6c8201e56 refactor: move factorization to lyco_helpers, separate calc_updown for kohya and kb 2023-11-03 19:35:15 -07:00
v0xie fe1967a4c4 skip multihead attn for now 2023-11-03 17:52:55 -07:00
AUTOMATIC1111 452ab8fe72 Merge pull request #13718 from avantcontra/bugfix_gfpgan_custom_path
fix bug when using --gfpgan-models-path
2023-11-03 20:19:58 +03:00
AUTOMATIC1111 399baa54c2 Merge pull request #13733 from dben/patch-1
Update prompts_from_file script to allow concatenating entries with the general prompt.
2023-11-03 20:19:04 +03:00
AUTOMATIC1111 21d561885e Merge pull request #13762 from wkpark/nextjob
call state.jobnext() before postproces*()
2023-11-03 20:16:58 +03:00
AUTOMATIC1111 73c74baa6a Merge pull request #13797 from Meerkov/master
Fix #13796
2023-11-03 20:11:54 +03:00
AUTOMATIC1111 1f373a2baa Merge pull request #13829 from AUTOMATIC1111/paren-fix
Fix parenthesis auto selection
2023-11-03 19:59:01 +03:00
AUTOMATIC1111 4afaaf8a02 add changelog entry 2023-11-03 19:50:14 +03:00
AUTOMATIC1111 bda2ecdbf5 Merge pull request #13839 from AUTOMATIC1111/httpx==0.24.1
requirements_versions httpx==0.24.1
2023-11-03 19:46:07 +03:00
AUTOMATIC1111 4c423f6d37 Merge pull request #13839 from AUTOMATIC1111/httpx==0.24.1
requirements_versions httpx==0.24.1
2023-11-03 19:44:57 +03:00
w-e-w cc80a09d82 Update requirements_versions.txt 2023-11-04 00:50:30 +09:00
missionfloyd 8052a4971e Fix parenthesis auto selection
Fixes #13813
2023-11-03 00:59:19 -06:00
Emily Zeng 759515316e added accordion settings options 2023-11-02 21:54:48 -04:00
v0xie d727ddfccd no idea what i'm doing, trying to support both type of OFT, kblueleaf diag_oft has MultiheadAttn which kohya's doesn't?, attempt create new module based off network_lora.py, errors about tensor dim mismatch 2023-11-02 00:13:11 -07:00
v0xie 65ccd6305f detect diag_oft type 2023-11-02 00:11:32 -07:00
v0xie a2fad6ee05 test implementation based on kohaku diag-oft implementation 2023-11-01 22:34:27 -07:00
Meerkov fbc5c531b9 Fix #13796
Fix comment error that makes understanding scheduling more confusing.
2023-10-29 15:37:08 -07:00
Nick Harrison be31e7e71a Remove blank line whitespace 2023-10-29 16:05:01 +00:00
Nick Harrison 844c23975f Add assertions for checking additional settings freezing parameters 2023-10-29 15:40:58 +00:00
Nick Harrison f2b83517aa Add new arguments to known command prompts 2023-10-29 15:40:13 +00:00
KohakuBlueleaf ddc2a3499b Add MPS manual cast 2023-10-28 16:52:35 +08:00
Kohaku-Blueleaf d4d3134f6d ManualCast for 10/16 series gpu 2023-10-28 15:24:26 +08:00
Won-Kyu Park 5121846d34 call state.jobnext() before postproces*() 2023-10-25 21:57:41 +09:00
Kohaku-Blueleaf 0beb131c7f change torch version 2023-10-25 20:07:37 +08:00
Kohaku-Blueleaf dda067f64d ignore mps for fp8 2023-10-25 19:53:22 +08:00
Kohaku-Blueleaf bf5067f50c Fix alphas cumprod 2023-10-25 12:54:28 +08:00
Kohaku-Blueleaf 4830b25136 Fix alphas_cumprod dtype 2023-10-25 11:53:37 +08:00
Kohaku-Blueleaf 1df6c8bfec fp8 for TE 2023-10-25 11:36:43 +08:00
Kohaku-Blueleaf 9c1eba2af3 Fix lint 2023-10-24 02:11:27 +08:00
Kohaku-Blueleaf eaa9f5162f Add CPU fp8 support
Since norm layers need fp32, I only convert the linear operation layers (conv2d/linear).

And the TE has some PyTorch functions that don't support bf16 amp on CPU, so I add a condition to indicate whether the autocast is for the unet.
2023-10-24 01:49:05 +08:00
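The selective conversion described above can be sketched like this (a minimal illustration assuming torch >= 2.1 for the float8 dtype; the helper name is hypothetical):
```
import torch
import torch.nn as nn

def convert_linear_ops_to_fp8(model: nn.Module) -> None:
    """Cast only the linear-operation layers to fp8; norm layers stay in higher precision."""
    for module in model.modules():
        if isinstance(module, (nn.Linear, nn.Conv2d)):
            module.to(torch.float8_e4m3fn)  # weights stored in fp8, computed via manual cast
```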
David Benson dfc4c27b24 linting issue 2023-10-23 08:26:40 -04:00
David Benson 88b2ef3b04 Update prompts_from_file script to allow concatenating entries with the general prompt. 2023-10-23 08:16:26 -04:00
v0xie 6523edb8a4 style: conform style 2023-10-22 09:31:15 -07:00
v0xie 3b8515d2c9 fix: multiplier applied twice in finalize_updown 2023-10-22 09:27:48 -07:00
v0xie 4a50c9638c refactor: remove unused OFT functions 2023-10-22 08:54:24 -07:00
v0xie de8ee92ed8 fix: use merge_weight to cache value 2023-10-21 17:37:17 -07:00
v0xie 76f5abdbdb style: cleanup oft 2023-10-21 16:07:45 -07:00
v0xie fce86ab7d7 fix: support multiplier, no forward pass hook 2023-10-21 16:03:54 -07:00
v0xie 7683547728 fix: return orig weights during updown, merge weights before forward 2023-10-21 14:42:24 -07:00
v0xie 2d8c894b27 refactor: use forward hook instead of custom forward 2023-10-21 13:43:31 -07:00
avantcontra 236dd55dbe fix Blank line contains whitespace 2023-10-22 04:32:13 +08:00
avantcontra 443ca983ad fix bug when using --gfpgan-models-path 2023-10-22 03:21:23 +08:00
AUTOMATIC1111 464fbcd921 fix the situation with emphasis editing (aaaa:1.1) bbbb (cccc:1.1) 2023-10-21 09:09:32 +03:00
AUTOMATIC1111 384fab9627 rework some of changes for emphasis editing keys, force conversion of old-style emphasis 2023-10-21 08:45:51 +03:00
v0xie 0550659ce6 style: fix ambiguous variable name 2023-10-19 13:13:02 -07:00
v0xie d10c4db57e style: formatting 2023-10-19 12:52:14 -07:00
v0xie 321680ccd0 refactor: fix constraint, re-use get_weight 2023-10-19 12:41:17 -07:00
Kohaku-Blueleaf 5f9ddfa46f Add sdxl only arg 2023-10-19 23:57:22 +08:00
Kohaku-Blueleaf 7c128bbdac Add fp8 for sd unet 2023-10-19 13:56:17 +08:00
v0xie eb01d7f0e0 faster by calculating R in updown and using cached R in forward 2023-10-18 04:56:53 -07:00
v0xie 853e21d98e faster by using cached R in forward 2023-10-18 04:27:44 -07:00
v0xie 1c6efdbba7 inference working but SLOW 2023-10-18 04:16:01 -07:00
v0xie ec718f76b5 wip incorrect OFT implementation 2023-10-17 23:35:50 -07:00
Anthony Fu 3d15e58b0a feat: refactor 2023-10-16 15:00:17 +08:00
Anthony Fu 8aa13d5dce Interrupt after current generation 2023-10-16 14:12:18 +08:00
AUTOMATIC1111 861cbd5636 Merge pull request #13644 from XpucT/dev
Start / Restart generation by Ctrl (Alt) + Enter
2023-10-15 14:19:48 +03:00
Khachatur Avanesian d33cb2b812 Add files via upload
LF
2023-10-15 11:01:45 +03:00
Khachatur Avanesian 3e223523ce Update script.js 2023-10-15 10:48:50 +03:00
Khachatur Avanesian d295e97a0d Update script.js
LF instead of CRLF
2023-10-15 10:37:48 +03:00
Khachatur Avanesian 77bd953da2 Update script.js
Exclude lambda
2023-10-15 10:25:36 +03:00
AUTOMATIC1111 2f6ea8b103 respect keyedit_precision_attention setting when converting from old (((attention))) syntax 2023-10-15 10:12:38 +03:00
AUTOMATIC1111 a3d9b011a3 Merge pull request #13533 from missionfloyd/edit-attention-fix
Edit-attention fixes
2023-10-15 10:08:52 +03:00
AUTOMATIC1111 282903bb67 repair unload sd checkpoint button 2023-10-15 09:41:02 +03:00
AUTOMATIC1111 0d65d0eabd add an option to not print stack traces on ctrl+c. 2023-10-15 08:45:38 +03:00
Khachatur Avanesian f00eaa4d00 Start / Restart generation by Ctrl (Alt) + Enter
Add ability to interrupt current generation and start generation again by Ctrl (Alt) + Enter
2023-10-15 02:34:03 +03:00
AUTOMATIC1111 d4255506ff Merge pull request #13638 from wkpark/user-settings-2
webui.settings.bat support
2023-10-14 23:00:35 +03:00
Won-Kyu Park 117ec71994 support webui.settings.bat 2023-10-15 04:36:27 +09:00
AUTOMATIC1111 4be7b620c2 Merge pull request #13568 from AUTOMATIC1111/lora_emb_bundle
Add lora-embedding bundle system
2023-10-14 12:18:55 +03:00
AUTOMATIC1111 a8cbe50c9f remove duplicated code 2023-10-14 12:17:59 +03:00
AUTOMATIC1111 19f5795c27 Merge pull request #13463 from FluttyProger/patch-1
Ability for extensions to return custom data via api in response.images
2023-10-14 08:37:45 +03:00
AUTOMATIC1111 6fe16a9e1a Merge pull request #12991 from AUTOMATIC1111/but-report-template
Update bug_report.yml
2023-10-14 08:36:43 +03:00
AUTOMATIC1111 eadef35512 Merge pull request #13567 from LeonZhao28/bugfix_key_error_in_processing
fix the key error exception when processing override_settings keys
2023-10-14 08:34:41 +03:00
AUTOMATIC1111 771dac9c5f Merge pull request #13459 from wkpark/preview-fix
show the preview image in the modalview if available
2023-10-14 08:21:53 +03:00
AUTOMATIC1111 0619df9835 use shallow copy for #13535 2023-10-14 08:01:04 +03:00
AUTOMATIC1111 7cc96429f2 Merge pull request #13535 from chu8129/dev
fix: checkpoints_loaded:{checkpoint:state_dict}, model.load_state_dict issue in dict value empty
2023-10-14 08:00:04 +03:00
AUTOMATIC1111 26500b8c1b Merge pull request #13610 from v0xie/network-glora
Support inference with LyCORIS GLora networks
2023-10-14 07:52:52 +03:00
AUTOMATIC1111 a109c7aeb8 more general case of adding an infotext when no images have been generated 2023-10-14 07:49:03 +03:00
AUTOMATIC1111 27fdc26a74 Merge pull request #13630 from wkpark/indexerror-fix
fix IndexError
2023-10-14 07:46:34 +03:00
AUTOMATIC1111 3a66c3c9e1 put notification.mp3 option at the end of the page 2023-10-14 07:35:06 +03:00
AUTOMATIC1111 499543cf1d Merge pull request #13631 from galekseev/master
added option to play notification sound or not
2023-10-14 07:30:31 +03:00
AUTOMATIC1111 902afa6b4c Merge pull request #13364 from superhero-7/master
Add altdiffusion-m18 support
2023-10-14 07:29:01 +03:00
missionfloyd fff1a0c74f Make attention conversion optional
Fix square brackets multiplier
2023-10-13 17:18:02 -06:00
missionfloyd 954499a494 Convert (emphasis) to (emphasis:1.1)
per @SirVeggie's suggestion
2023-10-13 16:46:05 -06:00
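The conversion makes the implicit 1.1 weight of bare parentheses explicit; a rough sketch of the idea (illustrative only, not the webui's actual prompt parser):
```
import re

def convert_emphasis(prompt: str) -> str:
    # Turn bare (emphasis) into explicit (emphasis:1.1); already-weighted groups are left alone.
    return re.sub(r'\(([^():]+)\)', r'(\1:1.1)', prompt)

print(convert_emphasis("a (cat) on a (mat:1.3)"))  # a (cat:1.1) on a (mat:1.3)
```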
Gleb Alekseev 44d14bc32e added option to play notification sound or not 2023-10-13 15:08:59 -03:00
Won-Kyu Park fbc8d21354 fix IndexError: list index out of range error when interrupted during postprocess 2023-10-14 02:45:09 +09:00
v0xie 906d1179e9 support inference with LyCORIS GLora networks 2023-10-11 21:26:58 -07:00
Won-Kyu Park dbb10fbd8c show the preview image in the modalview if available 2023-10-11 21:56:17 +09:00
Leon 9821625a76 fix the key error exception when adding an overwriting key which is defined in the extensions 2023-10-09 18:36:48 +08:00
missionfloyd 3562b0dc74 Fix negative values 2023-10-07 15:52:16 -06:00
missionfloyd fd51b8501e Fix multi-line selections 2023-10-07 15:28:25 -06:00
missionfloyd 09a2da835e Add brackets, vertical bar to default delimiters 2023-10-07 14:48:43 -06:00
wangqiuwen 770ee23f18 revert 2023-10-07 15:38:50 +08:00
wangqiuwen 76010a51ef up 2023-10-07 15:36:01 +08:00
missionfloyd e34949be52 Edit-attention fixes 2023-10-06 22:49:33 -06:00
w-e-w 35fd24e857 Less placeholder bug_report template 2023-10-03 23:05:48 +09:00
FluttyProger f71e919ecb Ability for extensions to return custom data via api in response.images 2023-10-01 18:06:48 +03:00
superhero-7 2d947175b9 fix linter issues 2023-10-01 12:25:19 +08:00
superhero-7 f8f4ff2bb8 support altdiffusion-m18 2023-09-23 17:55:19 +08:00
superhero-7 702a1e1cc7 support m18 2023-09-23 17:51:41 +08:00
w-e-w 74b80e7211 add comment 2023-09-12 09:29:07 +09:00
w-e-w e785402b6a return nothing if not found 2023-09-11 19:37:55 +09:00
w-e-w f5959c1c30 thread safe extra network using list 2023-09-09 17:05:50 +09:00
w-e-w 25de9a785c Revert "thread safe extra network list_items"
This reverts commit aab385d01b.
2023-09-09 16:56:19 +09:00
w-e-w c3d51fc696 Update bug_report.yml 2023-09-07 19:35:55 +09:00
catboxanon 25189b29af Grammar fixes 2023-09-05 22:13:36 -04:00
w-e-w aab385d01b thread safe extra network list_items 2023-09-03 11:56:02 +09:00
w-e-w 061a4a295d Update bug_report.yml 2023-09-02 18:11:08 +09:00
Robert Barron c4ee6d9b73 xyz_grid: allow varying the seed along an axis along with the axis's other changes 2023-07-30 03:45:02 -07:00
163 changed files with 6828 additions and 6012 deletions
+52 -21
@@ -1,25 +1,45 @@
 name: Bug Report
-description: You think somethings is broken in the UI
+description: You think something is broken in the UI
 title: "[Bug]: "
 labels: ["bug-report"]
 body:
-  - type: checkboxes
-    attributes:
-      label: Is there an existing issue for this?
-      description: Please search to see if an issue already exists for the bug you encountered, and that it hasn't been fixed in a recent build/commit.
-      options:
-        - label: I have searched the existing issues and checked the recent builds/commits
-          required: true
   - type: markdown
     attributes:
       value: |
-        *Please fill this form with as much information as possible, don't forget to fill "What OS..." and "What browsers" and *provide screenshots if possible**
+        > The title of the bug report should be short and descriptive.
+        > Use relevant keywords for searchability.
+        > Do not leave it blank, but also do not put an entire error log in it.
+  - type: checkboxes
+    attributes:
+      label: Checklist
+      description: |
+        Please perform basic debugging to see if extensions or configuration is the cause of the issue.
+        Basic debug procedure
+         1. Disable all third-party extensions - check if extension is the cause
+         2. Update extensions and webui - sometimes things just need to be updated
+         3. Backup and remove your config.json and ui-config.json - check if the issue is caused by bad configuration
+         4. Delete venv with third-party extensions disabled - sometimes extensions might cause wrong libraries to be installed
+         5. Try a fresh installation webui in a different directory - see if a clean installation solves the issue
+        Before making a issue report please, check that the issue hasn't been reported recently.
+      options:
+        - label: The issue exists after disabling all extensions
+        - label: The issue exists on a clean installation of webui
+        - label: The issue is caused by an extension, but I believe it is caused by a bug in the webui
+        - label: The issue exists in the current version of the webui
+        - label: The issue has not been reported before recently
+        - label: The issue has been reported before but has not been fixed yet
+  - type: markdown
+    attributes:
+      value: |
+        > Please fill this form with as much information as possible. Don't forget to "Upload Sysinfo" and "What browsers" and provide screenshots if possible
   - type: textarea
     id: what-did
     attributes:
       label: What happened?
       description: Tell us what happened in a very clear and simple way
+      placeholder: |
+        txt2img is not working as intended.
     validations:
       required: true
   - type: textarea
@@ -27,9 +47,9 @@ body:
     attributes:
       label: Steps to reproduce the problem
       description: Please provide us with precise step by step instructions on how to reproduce the bug
-      value: |
-        1. Go to ....
-        2. Press ....
+      placeholder: |
+        1. Go to ...
+        2. Press ...
         3. ...
     validations:
       required: true
@@ -38,13 +58,8 @@ body:
     attributes:
       label: What should have happened?
       description: Tell us what you think the normal behavior should be
-    validations:
-      required: true
-  - type: textarea
-    id: sysinfo
-    attributes:
-      label: Sysinfo
-      description: System info file, generated by WebUI. You can generate it in settings, on the Sysinfo page. Drag the file into the field to upload it. If you submit your report without including the sysinfo file, the report will be closed. If needed, review the report to make sure it includes no personal information you don't want to share. If you can't start WebUI, you can use --dump-sysinfo commandline argument to generate the file.
+      placeholder: |
+        WebUI should ...
     validations:
       required: true
   - type: dropdown
@@ -58,12 +73,25 @@ body:
         - Brave
         - Apple Safari
         - Microsoft Edge
+        - Android
+        - iOS
         - Other
+  - type: textarea
+    id: sysinfo
+    attributes:
+      label: Sysinfo
+      description: System info file, generated by WebUI. You can generate it in settings, on the Sysinfo page. Drag the file into the field to upload it. If you submit your report without including the sysinfo file, the report will be closed. If needed, review the report to make sure it includes no personal information you don't want to share. If you can't start WebUI, you can use --dump-sysinfo commandline argument to generate the file.
+      placeholder: |
+        1. Go to WebUI Settings -> Sysinfo -> Download system info.
+        If WebUI fails to launch, use --dump-sysinfo commandline argument to generate the file
+        2. Upload the Sysinfo as a attached file, Do NOT paste it in as plain text.
+    validations:
+      required: true
   - type: textarea
     id: logs
     attributes:
       label: Console logs
-      description: Please provide **full** cmd/terminal logs from the moment you started UI to the end of it, after your bug happened. If it's very long, provide a link to pastebin or similar service.
+      description: Please provide **full** cmd/terminal logs from the moment you started UI to the end of it, after the bug occured. If it's very long, provide a link to pastebin or similar service.
       render: Shell
     validations:
       required: true
@@ -71,4 +99,7 @@ body:
     id: misc
     attributes:
       label: Additional information
-      description: Please provide us with any relevant additional info or context.
+      description: |
+        Please provide us with any relevant additional info or context.
+        Examples:
+         I have updated my GPU driver recently.
+1 -1
@@ -20,7 +20,7 @@ jobs:
       # not to have GHA download an (at the time of writing) 4 GB cache
       # of PyTorch and other dependencies.
       - name: Install Ruff
-        run: pip install ruff==0.0.272
+        run: pip install ruff==0.1.6
       - name: Run Ruff
         run: ruff .
   lint-js:
+9 -1
@@ -20,6 +20,12 @@ jobs:
           cache-dependency-path: |
             **/requirements*txt
             launch.py
+      - name: Cache models
+        id: cache-models
+        uses: actions/cache@v3
+        with:
+          path: models
+          key: "2023-12-30"
       - name: Install test dependencies
         run: pip install wait-for-it -r requirements-test.txt
         env:
@@ -33,6 +39,8 @@ jobs:
           TORCH_INDEX_URL: https://download.pytorch.org/whl/cpu
           WEBUI_LAUNCH_LIVE_OUTPUT: "1"
           PYTHONUNBUFFERED: "1"
+      - name: Print installed packages
+        run: pip freeze
       - name: Start test server
         run: >
           python -m coverage run
@@ -49,7 +57,7 @@ jobs:
           2>&1 | tee output.txt &
       - name: Run tests
         run: |
-          wait-for-it --service 127.0.0.1:7860 -t 600
+          wait-for-it --service 127.0.0.1:7860 -t 20
           python -m pytest -vv --junitxml=test/results.xml --cov . --cov-report=xml --verify-base-url test
       - name: Kill test server
         if: always()
+1
@@ -37,3 +37,4 @@ notification.mp3
 /node_modules
 /package-lock.json
 /.coverage*
+/test/test_outputs
+167
@@ -1,3 +1,170 @@
## 1.7.0
### Features:
* settings tab rework: add search field, add categories, split UI settings page into many
* add altdiffusion-m18 support ([#13364](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/13364))
* support inference with LyCORIS GLora networks ([#13610](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/13610))
* add lora-embedding bundle system ([#13568](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/13568))
* option to move prompt from top row into generation parameters
* add support for SSD-1B ([#13865](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/13865))
* support inference with OFT networks ([#13692](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/13692))
* script metadata and DAG sorting mechanism ([#13944](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/13944))
* support HyperTile optimization ([#13948](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/13948))
* add support for SD 2.1 Turbo ([#14170](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/14170))
* remove Train->Preprocessing tab and put all its functionality into Extras tab
* initial IPEX support for Intel Arc GPU ([#14171](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/14171))
### Minor:
* allow reading model hash from images in img2img batch mode ([#12767](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/12767))
* add option to align with sgm repo's sampling implementation ([#12818](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/12818))
* extra field for lora metadata viewer: `ss_output_name` ([#12838](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/12838))
* add action in settings page to calculate all SD checkpoint hashes ([#12909](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/12909))
* add button to copy prompt to style editor ([#12975](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/12975))
* add --skip-load-model-at-start option ([#13253](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/13253))
* write infotext to gif images
* read infotext from gif images ([#13068](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/13068))
* allow configuring the initial state of InputAccordion in ui-config.json ([#13189](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/13189))
* allow editing whitespace delimiters for ctrl+up/ctrl+down prompt editing ([#13444](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/13444))
* prevent accidentally closing popup dialogs ([#13480](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/13480))
* added option to play notification sound or not ([#13631](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/13631))
* show the preview image in the full screen image viewer if available ([#13459](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/13459))
* support for webui.settings.bat ([#13638](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/13638))
* add an option to not print stack traces on ctrl+c
* start/restart generation by Ctrl (Alt) + Enter ([#13644](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/13644))
* update prompts_from_file script to allow concatenating entries with the general prompt ([#13733](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/13733))
* added a visible checkbox to input accordion
* added an option to hide all txt2img/img2img parameters in an accordion ([#13826](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/13826))
* added 'Path' sorting option for Extra network cards ([#13968](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/13968))
* enable prompt hotkeys in style editor ([#13931](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/13931))
* option to show batch img2img results in UI ([#14009](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/14009))
* infotext updates: add option to disregard certain infotext fields, add option to not include VAE in infotext, add explanation to infotext settings page, move some options to infotext settings page
* add FP32 fallback support on sd_vae_approx ([#14046](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/14046))
* support XYZ scripts / split hires path from unet ([#14126](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/14126))
* allow use of multiple styles csv files ([#14125](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/14125))
### Extensions and API:
* update gradio to 3.41.2
* support installed extensions list api ([#12774](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/12774))
* update pnginfo API to return dict with parsed values
* add noisy latent to `ExtraNoiseParams` for callback ([#12856](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/12856))
* show extension datetime in UTC ([#12864](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/12864), [#12865](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/12865), [#13281](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/13281))
* add an option to choose how to combine hires fix and refiner
* include program version in info response. ([#13135](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/13135))
* sd_unet support for SDXL
* patch DDPM.register_betas so that users can put given_betas in model yaml ([#13276](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/13276))
* xyz_grid: add prepare ([#13266](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/13266))
* allow multiple localization files with same language in extensions ([#13077](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/13077))
* add onEdit function for js and rework token-counter.js to use it
* fix the key error exception when processing override_settings keys ([#13567](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/13567))
* ability for extensions to return custom data via api in response.images ([#13463](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/13463))
* call state.jobnext() before postproces*() ([#13762](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/13762))
* add option to set notification sound volume ([#13884](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/13884))
* update Ruff to 0.1.6 ([#14059](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/14059))
* add Block component creation callback ([#14119](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/14119))
* catch uncaught exception with ui creation scripts ([#14120](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/14120))
* use extension name for determining whether an extension is installed in the index ([#14063](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/14063))
* update is_installed() from launch_utils.py to fix reinstalling already installed packages ([#14192](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/14192))
### Bug Fixes:
* fix pix2pix producing bad results
* fix defaults settings page breaking when any of main UI tabs are hidden
* fix error that causes some extra networks to be disabled if both <lora:> and <lyco:> are present in the prompt
* fix for Reload UI function: if you reload UI on one tab, other opened tabs will no longer stop working
* prevent duplicate resize handler ([#12795](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/12795))
* small typo: vae resolve bug ([#12797](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/12797))
* hide broken image crop tool ([#12792](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/12792))
* don't show hidden samplers in dropdown for XYZ script ([#12780](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/12780))
* fix style editing dialog breaking if it's opened in both img2img and txt2img tabs
* hide --gradio-auth and --api-auth values from /internal/sysinfo report
* add missing infotext for RNG in options ([#12819](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/12819))
* fix notification not playing when built-in webui tab is inactive ([#12834](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/12834))
* honor `--skip-install` for extension installers ([#12832](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/12832))
* don't print blank stdout in extension installers ([#12833](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/12833), [#12855](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/12855))
* get progressbar to display correctly in extensions tab
* keep order in list of checkpoints when loading model that doesn't have a checksum
* fix inpainting models in txt2img creating black pictures
* fix generation params regex ([#12876](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/12876))
* fix batch img2img output dir with script ([#12926](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/12926))
* fix #13080 - Hypernetwork/TI preview generation ([#13084](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/13084))
* fix bug with sigma min/max overrides. ([#12995](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/12995))
* more accurate check for enabling cuDNN benchmark on 16XX cards ([#12924](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/12924))
* don't use multicond parser for negative prompt counter ([#13118](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/13118))
* fix data-sort-name containing spaces ([#13412](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/13412))
* update card on correct tab when editing metadata ([#13411](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/13411))
* fix viewing/editing metadata when filename contains an apostrophe ([#13395](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/13395))
* fix: --sd_model in "Prompts from file or textbox" script is not working ([#13302](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/13302))
* better Support for Portable Git ([#13231](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/13231))
* fix issues when webui_dir is not work_dir ([#13210](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/13210))
* fix: lora-bias-backup don't reset cache ([#13178](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/13178))
* account for customizable extra network separators when removing extra network text from the prompt ([#12877](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/12877))
* re fix batch img2img output dir with script ([#13170](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/13170))
* fix `--ckpt-dir` path separator and option use `short name` for checkpoint dropdown ([#13139](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/13139))
* consolidated allowed preview formats, fix extra network `.gif` not working as preview ([#13121](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/13121))
* fix venv_dir=- environment variable not working as expected on linux ([#13469](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/13469))
* repair unload sd checkpoint button
* edit-attention fixes ([#13533](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/13533))
* fix bug when using --gfpgan-models-path ([#13718](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/13718))
* properly apply sort order for extra network cards when selected from dropdown
* fixes generation restart not working for some users when 'Ctrl+Enter' is pressed ([#13962](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/13962))
* thread safe extra network list_items ([#13014](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/13014))
* fix not able to exit metadata popup when pop up is too big ([#14156](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/14156))
* fix auto focal point crop for opencv >= 4.8 ([#14121](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/14121))
* make 'use-cpu all' actually apply to 'all' ([#14131](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/14131))
* extras tab batch: actually use original filename
* make webui not crash when running with --disable-all-extensions option
### Other:
* non-local condition ([#12814](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/12814))
* fix minor typos ([#12827](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/12827))
* remove xformers Python version check ([#12842](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/12842))
* style: file-metadata word-break ([#12837](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/12837))
* revert SGM noise multiplier change for img2img because it breaks hires fix
* do not change quicksettings dropdown option when value returned is `None` ([#12854](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/12854))
* [RC 1.6.0 - zoom is partly hidden] Update style.css ([#12839](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/12839))
* chore: change extension time format ([#12851](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/12851))
* WEBUI.SH - Use torch 2.1.0 release candidate for Navi 3 ([#12929](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/12929))
* add Fallback at images.read_info_from_image if exif data was invalid ([#13028](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/13028))
* update cmd arg description ([#12986](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/12986))
* fix: update shared.opts.data when add_option ([#12957](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/12957), [#13213](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/13213))
* restore missing tooltips ([#12976](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/12976))
* use default dropdown padding on mobile ([#12880](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/12880))
* put enable console prompts option into settings from commandline args ([#13119](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/13119))
* fix some deprecated types ([#12846](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/12846))
* bump to torchsde==0.2.6 ([#13418](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/13418))
* update dragdrop.js ([#13372](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/13372))
* use orderdict as lru cache:opt/bug ([#13313](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/13313))
* XYZ if not include sub grids do not save sub grid ([#13282](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/13282))
* initialize state.time_start before state.job_count ([#13229](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/13229))
* fix fieldname regex ([#13458](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/13458))
* change denoising_strength default to None. ([#13466](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/13466))
* fix regression ([#13475](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/13475))
* fix IndexError ([#13630](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/13630))
* fix: checkpoints_loaded:{checkpoint:state_dict}, model.load_state_dict issue in dict value empty ([#13535](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/13535))
* update bug_report.yml ([#12991](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/12991))
* requirements_versions httpx==0.24.1 ([#13839](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/13839))
* fix parenthesis auto selection ([#13829](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/13829))
* fix #13796 ([#13797](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/13797))
* corrected a typo in `modules/cmd_args.py` ([#13855](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/13855))
* feat: fix randn found element of type float at pos 2 ([#14004](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/14004))
* adds tqdm handler to logging_config.py for progress bar integration ([#13996](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/13996))
* hotfix: call shared.state.end() after postprocessing done ([#13977](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/13977))
* fix dependency address patch 1 ([#13929](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/13929))
* save sysinfo as .json ([#14035](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/14035))
* move exception_records related methods to errors.py ([#14084](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/14084))
* compatibility ([#13936](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/13936))
* json.dump(ensure_ascii=False) ([#14108](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/14108))
* dir buttons start with / so only the correct dir will be shown and no… ([#13957](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/13957))
* alternate implementation for unet forward replacement that does not depend on hijack being applied
* re-add `keyedit_delimiters_whitespace` setting lost as part of commit e294e46 ([#14178](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/14178))
* fix `save_samples` being checked early when saving masked composite ([#14177](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/14177))
* slight optimization for mask and mask_composite ([#14181](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/14181))
* add import_hook hack to work around basicsr/torchvision incompatibility ([#14186](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/14186))
## 1.6.1
### Bug Fixes:
* fix an error causing the webui to fail to start ([#13839](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/13839))
## 1.6.0
### Features:
+13 -8
@@ -1,5 +1,5 @@
 # Stable Diffusion web UI
-A browser interface based on Gradio library for Stable Diffusion.
+A web interface for Stable Diffusion, implemented using Gradio library.
 ![](screenshot.png)
@@ -91,6 +91,7 @@ A browser interface based on Gradio library for Stable Diffusion.
 - Eased resolution restriction: generated image's dimensions must be a multiple of 8 rather than 64
 - Now with a license!
 - Reorder elements in the UI from settings screen
+- [Segmind Stable Diffusion](https://huggingface.co/segmind/SSD-1B) support
 ## Installation and Running
 Make sure the required [dependencies](https://github.com/AUTOMATIC1111/stable-diffusion-webui/wiki/Dependencies) are met and follow the instructions available for:
@@ -120,7 +121,9 @@ Alternatively, use online services (like Google Colab):
 # Debian-based:
 sudo apt install wget git python3 python3-venv libgl1 libglib2.0-0
 # Red Hat-based:
-sudo dnf install wget git python3
+sudo dnf install wget git python3 gperftools-libs libglvnd-glx
+# openSUSE-based:
+sudo zypper install wget git python3 libtcmalloc4 libglvnd
 # Arch-based:
 sudo pacman -S wget git python3
 ```
@@ -146,13 +149,14 @@ For the purposes of getting Google and other search engines to crawl the wiki, h
 ## Credits
 Licenses for borrowed code can be found in `Settings -> Licenses` screen, and also in `html/licenses.html` file.
-- Stable Diffusion - https://github.com/CompVis/stable-diffusion, https://github.com/CompVis/taming-transformers
+- Stable Diffusion - https://github.com/Stability-AI/stablediffusion, https://github.com/CompVis/taming-transformers
 - k-diffusion - https://github.com/crowsonkb/k-diffusion.git
-- GFPGAN - https://github.com/TencentARC/GFPGAN.git
-- CodeFormer - https://github.com/sczhou/CodeFormer
-- ESRGAN - https://github.com/xinntao/ESRGAN
-- SwinIR - https://github.com/JingyunLiang/SwinIR
-- Swin2SR - https://github.com/mv-lab/swin2sr
+- Spandrel - https://github.com/chaiNNer-org/spandrel implementing
+  - GFPGAN - https://github.com/TencentARC/GFPGAN.git
+  - CodeFormer - https://github.com/sczhou/CodeFormer
+  - ESRGAN - https://github.com/xinntao/ESRGAN
+  - SwinIR - https://github.com/JingyunLiang/SwinIR
+  - Swin2SR - https://github.com/mv-lab/swin2sr
 - LDSR - https://github.com/Hafiidz/latent-diffusion
 - MiDaS - https://github.com/isl-org/MiDaS
 - Ideas for optimizations - https://github.com/basujindal/stable-diffusion
@@ -173,5 +177,6 @@ Licenses for borrowed code can be found in `Settings -> Licenses` screen, and al
 - TAESD - Ollin Boer Bohan - https://github.com/madebyollin/taesd
 - LyCORIS - KohakuBlueleaf
 - Restart sampling - lambertae - https://github.com/Newbeeer/diffusion_restart_sampling
+- Hypertile - tfernd - https://github.com/tfernd/HyperTile
 - Initial Gradio script - posted on 4chan by an Anonymous user. Thank you Anonymous user.
 - (You)
+73
@@ -0,0 +1,73 @@
model:
  base_learning_rate: 1.0e-04
  target: ldm.models.diffusion.ddpm.LatentDiffusion
  params:
    linear_start: 0.00085
    linear_end: 0.0120
    num_timesteps_cond: 1
    log_every_t: 200
    timesteps: 1000
    first_stage_key: "jpg"
    cond_stage_key: "txt"
    image_size: 64
    channels: 4
    cond_stage_trainable: false # Note: different from the one we trained before
    conditioning_key: crossattn
    monitor: val/loss_simple_ema
    scale_factor: 0.18215
    use_ema: False

    scheduler_config: # 10000 warmup steps
      target: ldm.lr_scheduler.LambdaLinearScheduler
      params:
        warm_up_steps: [ 10000 ]
        cycle_lengths: [ 10000000000000 ] # incredibly large number to prevent corner cases
        f_start: [ 1.e-6 ]
        f_max: [ 1. ]
        f_min: [ 1. ]

    unet_config:
      target: ldm.modules.diffusionmodules.openaimodel.UNetModel
      params:
        image_size: 32 # unused
        in_channels: 4
        out_channels: 4
        model_channels: 320
        attention_resolutions: [ 4, 2, 1 ]
        num_res_blocks: 2
        channel_mult: [ 1, 2, 4, 4 ]
        num_head_channels: 64
        use_spatial_transformer: True
        use_linear_in_transformer: True
        transformer_depth: 1
        context_dim: 1024
        use_checkpoint: True
        legacy: False

    first_stage_config:
      target: ldm.models.autoencoder.AutoencoderKL
      params:
        embed_dim: 4
        monitor: val/rec_loss
        ddconfig:
          double_z: true
          z_channels: 4
          resolution: 256
          in_channels: 3
          out_ch: 3
          ch: 128
          ch_mult:
          - 1
          - 2
          - 4
          - 4
          num_res_blocks: 2
          attn_resolutions: []
          dropout: 0.0
        lossconfig:
          target: torch.nn.Identity

    cond_stage_config:
      target: modules.xlmr_m18.BertSeriesModelWithTransformation
      params:
        name: "XLMR-Large"
+98
@@ -0,0 +1,98 @@
model:
  target: sgm.models.diffusion.DiffusionEngine
  params:
    scale_factor: 0.13025
    disable_first_stage_autocast: True

    denoiser_config:
      target: sgm.modules.diffusionmodules.denoiser.DiscreteDenoiser
      params:
        num_idx: 1000

        weighting_config:
          target: sgm.modules.diffusionmodules.denoiser_weighting.EpsWeighting
        scaling_config:
          target: sgm.modules.diffusionmodules.denoiser_scaling.EpsScaling
        discretization_config:
          target: sgm.modules.diffusionmodules.discretizer.LegacyDDPMDiscretization

    network_config:
      target: sgm.modules.diffusionmodules.openaimodel.UNetModel
      params:
        adm_in_channels: 2816
        num_classes: sequential
        use_checkpoint: True
        in_channels: 9
        out_channels: 4
        model_channels: 320
        attention_resolutions: [4, 2]
        num_res_blocks: 2
        channel_mult: [1, 2, 4]
        num_head_channels: 64
        use_spatial_transformer: True
        use_linear_in_transformer: True
        transformer_depth: [1, 2, 10] # note: the first is unused (due to attn_res starting at 2) 32, 16, 8 --> 64, 32, 16
        context_dim: 2048
        spatial_transformer_attn_type: softmax-xformers
        legacy: False

    conditioner_config:
      target: sgm.modules.GeneralConditioner
      params:
        emb_models:
          # crossattn cond
          - is_trainable: False
            input_key: txt
            target: sgm.modules.encoders.modules.FrozenCLIPEmbedder
            params:
              layer: hidden
              layer_idx: 11
          # crossattn and vector cond
          - is_trainable: False
            input_key: txt
            target: sgm.modules.encoders.modules.FrozenOpenCLIPEmbedder2
            params:
              arch: ViT-bigG-14
              version: laion2b_s39b_b160k
              freeze: True
              layer: penultimate
              always_return_pooled: True
              legacy: False
          # vector cond
          - is_trainable: False
            input_key: original_size_as_tuple
            target: sgm.modules.encoders.modules.ConcatTimestepEmbedderND
            params:
              outdim: 256 # multiplied by two
          # vector cond
          - is_trainable: False
            input_key: crop_coords_top_left
            target: sgm.modules.encoders.modules.ConcatTimestepEmbedderND
            params:
              outdim: 256 # multiplied by two
          # vector cond
          - is_trainable: False
            input_key: target_size_as_tuple
            target: sgm.modules.encoders.modules.ConcatTimestepEmbedderND
            params:
              outdim: 256 # multiplied by two

    first_stage_config:
      target: sgm.models.autoencoder.AutoencoderKLInferenceWrapper
      params:
        embed_dim: 4
        monitor: val/rec_loss
        ddconfig:
          attn_type: vanilla-xformers
          double_z: true
          z_channels: 4
          resolution: 256
          in_channels: 3
          out_ch: 3
          ch: 128
          ch_mult: [1, 2, 4, 4]
          num_res_blocks: 2
          attn_resolutions: []
          dropout: 0.0
        lossconfig:
          target: torch.nn.Identity
+47
@@ -19,3 +19,50 @@ def rebuild_cp_decomposition(up, down, mid):
up = up.reshape(up.size(0), -1) up = up.reshape(up.size(0), -1)
down = down.reshape(down.size(0), -1) down = down.reshape(down.size(0), -1)
return torch.einsum('n m k l, i n, m j -> i j k l', mid, up, down) return torch.einsum('n m k l, i n, m j -> i j k l', mid, up, down)
# copied from https://github.com/KohakuBlueleaf/LyCORIS/blob/dev/lycoris/modules/lokr.py
def factorization(dimension: int, factor:int=-1) -> tuple[int, int]:
'''
    Return a tuple of two factors of the input dimension, decomposed by the number closest to factor.
    The second value is greater than or equal to the first.
    In LoRA with Kronecker product, the first value is the scale for the weight and the second value is the weight itself.
    Because the Kronecker product is non-commutative (A⊗B ≠ B⊗A), the meanings of the two matrices differ slightly.
    Examples:
factor
-1 2 4 8 16 ...
127 -> 1, 127 127 -> 1, 127 127 -> 1, 127 127 -> 1, 127 127 -> 1, 127
128 -> 8, 16 128 -> 2, 64 128 -> 4, 32 128 -> 8, 16 128 -> 8, 16
250 -> 10, 25 250 -> 2, 125 250 -> 2, 125 250 -> 5, 50 250 -> 10, 25
360 -> 8, 45 360 -> 2, 180 360 -> 4, 90 360 -> 8, 45 360 -> 12, 30
512 -> 16, 32 512 -> 2, 256 512 -> 4, 128 512 -> 8, 64 512 -> 16, 32
1024 -> 32, 32 1024 -> 2, 512 1024 -> 4, 256 1024 -> 8, 128 1024 -> 16, 64
'''
if factor > 0 and (dimension % factor) == 0:
m = factor
n = dimension // factor
if m > n:
n, m = m, n
return m, n
if factor < 0:
factor = dimension
m, n = 1, dimension
length = m + n
while m<n:
new_m = m + 1
while dimension%new_m != 0:
new_m += 1
new_n = dimension // new_m
if new_m + new_n > length or new_m>factor:
break
else:
m, n = new_m, new_n
if m > n:
n, m = m, n
return m, n
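
A quick sanity check against the docstring table above; hypothetical standalone usage, assuming factorization is importable as defined here:

for dim in (127, 128, 250, 360, 512, 1024):
    print(dim, factorization(dim), factorization(dim, 8))
# e.g. factorization(128) -> (8, 16) and factorization(360, 8) -> (8, 45), as in the table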
+33 -2
@@ -3,6 +3,9 @@ import os
from collections import namedtuple
import enum

+import torch.nn as nn
+import torch.nn.functional as F

from modules import sd_models, cache, errors, hashes, shared

NetworkWeights = namedtuple('NetworkWeights', ['network_key', 'sd_key', 'w', 'sd_module'])
@@ -115,6 +118,29 @@ class NetworkModule:
        if hasattr(self.sd_module, 'weight'):
            self.shape = self.sd_module.weight.shape

+        self.ops = None
+        self.extra_kwargs = {}
+        if isinstance(self.sd_module, nn.Conv2d):
+            self.ops = F.conv2d
+            self.extra_kwargs = {
+                'stride': self.sd_module.stride,
+                'padding': self.sd_module.padding
+            }
+        elif isinstance(self.sd_module, nn.Linear):
+            self.ops = F.linear
+        elif isinstance(self.sd_module, nn.LayerNorm):
+            self.ops = F.layer_norm
+            self.extra_kwargs = {
+                'normalized_shape': self.sd_module.normalized_shape,
+                'eps': self.sd_module.eps
+            }
+        elif isinstance(self.sd_module, nn.GroupNorm):
+            self.ops = F.group_norm
+            self.extra_kwargs = {
+                'num_groups': self.sd_module.num_groups,
+                'eps': self.sd_module.eps
+            }

        self.dim = None
        self.bias = weights.w.get("bias")
        self.alpha = weights.w["alpha"].item() if "alpha" in weights.w else None
@@ -137,7 +163,7 @@ class NetworkModule:
    def finalize_updown(self, updown, orig_weight, output_shape, ex_bias=None):
        if self.bias is not None:
            updown = updown.reshape(self.bias.shape)
-            updown += self.bias.to(orig_weight.device, dtype=orig_weight.dtype)
+            updown += self.bias.to(orig_weight.device, dtype=updown.dtype)
            updown = updown.reshape(output_shape)

        if len(output_shape) == 4:
@@ -155,5 +181,10 @@ class NetworkModule:
        raise NotImplementedError()

    def forward(self, x, y):
-        raise NotImplementedError()
+        """A general forward implementation for all modules"""
+        if self.ops is None:
+            raise NotImplementedError()
+        else:
+            updown, ex_bias = self.calc_updown(self.sd_module.weight)
+            return y + self.ops(x, weight=updown, bias=ex_bias, **self.extra_kwargs)
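
For a Linear layer, this generalized forward is equivalent to adding the low-rank delta's contribution to the base output; a minimal sketch with hypothetical shapes:

import torch
import torch.nn.functional as F

x = torch.randn(2, 32)
base = torch.nn.Linear(32, 64)
updown = torch.randn(64, 32) * 0.01          # stands in for calc_updown() output
y = base(x)                                   # original_forward result
y_patched = y + F.linear(x, weight=updown)    # same as running x through (W + updown)
assert torch.allclose(y_patched, F.linear(x, base.weight + updown, base.bias), atol=1e-5)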
+2 -2
@@ -18,9 +18,9 @@ class NetworkModuleFull(network.NetworkModule):
    def calc_updown(self, orig_weight):
        output_shape = self.weight.shape
-        updown = self.weight.to(orig_weight.device, dtype=orig_weight.dtype)
+        updown = self.weight.to(orig_weight.device)
        if self.ex_bias is not None:
-            ex_bias = self.ex_bias.to(orig_weight.device, dtype=orig_weight.dtype)
+            ex_bias = self.ex_bias.to(orig_weight.device)
        else:
            ex_bias = None
+33
@@ -0,0 +1,33 @@
import network
class ModuleTypeGLora(network.ModuleType):
def create_module(self, net: network.Network, weights: network.NetworkWeights):
if all(x in weights.w for x in ["a1.weight", "a2.weight", "alpha", "b1.weight", "b2.weight"]):
return NetworkModuleGLora(net, weights)
return None
# adapted from https://github.com/KohakuBlueleaf/LyCORIS
class NetworkModuleGLora(network.NetworkModule):
def __init__(self, net: network.Network, weights: network.NetworkWeights):
super().__init__(net, weights)
if hasattr(self.sd_module, 'weight'):
self.shape = self.sd_module.weight.shape
self.w1a = weights.w["a1.weight"]
self.w1b = weights.w["b1.weight"]
self.w2a = weights.w["a2.weight"]
self.w2b = weights.w["b2.weight"]
def calc_updown(self, orig_weight):
w1a = self.w1a.to(orig_weight.device)
w1b = self.w1b.to(orig_weight.device)
w2a = self.w2a.to(orig_weight.device)
w2b = self.w2b.to(orig_weight.device)
output_shape = [w1a.size(0), w1b.size(1)]
updown = ((w2b @ w1b) + ((orig_weight.to(dtype = w1a.dtype) @ w2a) @ w1a))
return self.finalize_updown(updown, orig_weight, output_shape)
+6 -6
@@ -27,16 +27,16 @@ class NetworkModuleHada(network.NetworkModule):
        self.t2 = weights.w.get("hada_t2")

    def calc_updown(self, orig_weight):
-        w1a = self.w1a.to(orig_weight.device, dtype=orig_weight.dtype)
-        w1b = self.w1b.to(orig_weight.device, dtype=orig_weight.dtype)
-        w2a = self.w2a.to(orig_weight.device, dtype=orig_weight.dtype)
-        w2b = self.w2b.to(orig_weight.device, dtype=orig_weight.dtype)
+        w1a = self.w1a.to(orig_weight.device)
+        w1b = self.w1b.to(orig_weight.device)
+        w2a = self.w2a.to(orig_weight.device)
+        w2b = self.w2b.to(orig_weight.device)

        output_shape = [w1a.size(0), w1b.size(1)]

        if self.t1 is not None:
            output_shape = [w1a.size(1), w1b.size(1)]
-            t1 = self.t1.to(orig_weight.device, dtype=orig_weight.dtype)
+            t1 = self.t1.to(orig_weight.device)
            updown1 = lyco_helpers.make_weight_cp(t1, w1a, w1b)
            output_shape += t1.shape[2:]
        else:
@@ -45,7 +45,7 @@ class NetworkModuleHada(network.NetworkModule):
            updown1 = lyco_helpers.rebuild_conventional(w1a, w1b, output_shape)

        if self.t2 is not None:
-            t2 = self.t2.to(orig_weight.device, dtype=orig_weight.dtype)
+            t2 = self.t2.to(orig_weight.device)
            updown2 = lyco_helpers.make_weight_cp(t2, w2a, w2b)
        else:
            updown2 = lyco_helpers.rebuild_conventional(w2a, w2b, output_shape)
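
In the LoHa scheme implemented here, the two reconstructions are combined elementwise further down in this file (beyond the hunk shown), so the delta is the Hadamard product of two low-rank matrices; a small sketch with hypothetical shapes:

import torch

w1a, w1b = torch.randn(16, 4), torch.randn(4, 32)
w2a, w2b = torch.randn(16, 4), torch.randn(4, 32)
updown = (w1a @ w1b) * (w2a @ w2b)   # elementwise product of two rank-4 (16, 32) matrices
assert updown.shape == (16, 32)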
+1 -1
@@ -17,7 +17,7 @@ class NetworkModuleIa3(network.NetworkModule):
        self.on_input = weights.w["on_input"].item()

    def calc_updown(self, orig_weight):
-        w = self.w.to(orig_weight.device, dtype=orig_weight.dtype)
+        w = self.w.to(orig_weight.device)
        output_shape = [w.size(0), orig_weight.size(1)]
        if self.on_input:
+9 -9
@@ -37,22 +37,22 @@ class NetworkModuleLokr(network.NetworkModule):
    def calc_updown(self, orig_weight):
        if self.w1 is not None:
-            w1 = self.w1.to(orig_weight.device, dtype=orig_weight.dtype)
+            w1 = self.w1.to(orig_weight.device)
        else:
-            w1a = self.w1a.to(orig_weight.device, dtype=orig_weight.dtype)
-            w1b = self.w1b.to(orig_weight.device, dtype=orig_weight.dtype)
+            w1a = self.w1a.to(orig_weight.device)
+            w1b = self.w1b.to(orig_weight.device)
            w1 = w1a @ w1b

        if self.w2 is not None:
-            w2 = self.w2.to(orig_weight.device, dtype=orig_weight.dtype)
+            w2 = self.w2.to(orig_weight.device)
        elif self.t2 is None:
-            w2a = self.w2a.to(orig_weight.device, dtype=orig_weight.dtype)
-            w2b = self.w2b.to(orig_weight.device, dtype=orig_weight.dtype)
+            w2a = self.w2a.to(orig_weight.device)
+            w2b = self.w2b.to(orig_weight.device)
            w2 = w2a @ w2b
        else:
-            t2 = self.t2.to(orig_weight.device, dtype=orig_weight.dtype)
-            w2a = self.w2a.to(orig_weight.device, dtype=orig_weight.dtype)
-            w2b = self.w2b.to(orig_weight.device, dtype=orig_weight.dtype)
+            t2 = self.t2.to(orig_weight.device)
+            w2a = self.w2a.to(orig_weight.device)
+            w2b = self.w2b.to(orig_weight.device)
            w2 = lyco_helpers.make_weight_cp(t2, w2a, w2b)

        output_shape = [w1.size(0) * w2.size(0), w1.size(1) * w2.size(1)]
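
The output_shape line reflects the Kronecker structure of LoKr: the final delta is (up to reshaping) the Kronecker product of the two factors, whose dimensions multiply. A sketch with hypothetical sizes:

import torch

w1 = torch.randn(4, 3)
w2 = torch.randn(8, 6)
delta = torch.kron(w1, w2)
assert delta.shape == (4 * 8, 3 * 6)   # matches [w1.size(0) * w2.size(0), w1.size(1) * w2.size(1)]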
+3 -3
@@ -61,13 +61,13 @@ class NetworkModuleLora(network.NetworkModule):
        return module

    def calc_updown(self, orig_weight):
-        up = self.up_model.weight.to(orig_weight.device, dtype=orig_weight.dtype)
-        down = self.down_model.weight.to(orig_weight.device, dtype=orig_weight.dtype)
+        up = self.up_model.weight.to(orig_weight.device)
+        down = self.down_model.weight.to(orig_weight.device)

        output_shape = [up.size(0), down.size(1)]
        if self.mid_model is not None:
            # cp-decomposition
-            mid = self.mid_model.weight.to(orig_weight.device, dtype=orig_weight.dtype)
+            mid = self.mid_model.weight.to(orig_weight.device)
            updown = lyco_helpers.rebuild_cp_decomposition(up, down, mid)
            output_shape += mid.shape[2:]
        else:
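
A small shape check of the cp-decomposition rebuild used in this branch (rebuild_cp_decomposition from the hunk near the top of this section), with hypothetical sizes:

import torch

up = torch.randn(64, 8)                # (out, rank) after reshape
down = torch.randn(8, 32)              # (rank, in) after reshape
mid = torch.randn(8, 8, 3, 3)          # (rank, rank, kh, kw) convolution core
updown = torch.einsum('n m k l, i n, m j -> i j k l', mid, up, down)
assert updown.shape == (64, 32, 3, 3)  # a full conv-weight delta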
+2 -2
@@ -18,10 +18,10 @@ class NetworkModuleNorm(network.NetworkModule):
    def calc_updown(self, orig_weight):
        output_shape = self.w_norm.shape
-        updown = self.w_norm.to(orig_weight.device, dtype=orig_weight.dtype)
+        updown = self.w_norm.to(orig_weight.device)

        if self.b_norm is not None:
-            ex_bias = self.b_norm.to(orig_weight.device, dtype=orig_weight.dtype)
+            ex_bias = self.b_norm.to(orig_weight.device)
        else:
            ex_bias = None
+82
@@ -0,0 +1,82 @@
import torch
import network
from lyco_helpers import factorization
from einops import rearrange
class ModuleTypeOFT(network.ModuleType):
def create_module(self, net: network.Network, weights: network.NetworkWeights):
if all(x in weights.w for x in ["oft_blocks"]) or all(x in weights.w for x in ["oft_diag"]):
return NetworkModuleOFT(net, weights)
return None
# Supports both kohya-ss' implementation of COFT https://github.com/kohya-ss/sd-scripts/blob/main/networks/oft.py
# and KohakuBlueleaf's implementation of OFT/COFT https://github.com/KohakuBlueleaf/LyCORIS/blob/dev/lycoris/modules/diag_oft.py
class NetworkModuleOFT(network.NetworkModule):
def __init__(self, net: network.Network, weights: network.NetworkWeights):
super().__init__(net, weights)
self.lin_module = None
        self.org_module: list[torch.nn.Module] = [self.sd_module]
self.scale = 1.0
# kohya-ss
if "oft_blocks" in weights.w.keys():
self.is_kohya = True
self.oft_blocks = weights.w["oft_blocks"] # (num_blocks, block_size, block_size)
self.alpha = weights.w["alpha"] # alpha is constraint
self.dim = self.oft_blocks.shape[0] # lora dim
# LyCORIS
elif "oft_diag" in weights.w.keys():
self.is_kohya = False
self.oft_blocks = weights.w["oft_diag"]
# self.alpha is unused
self.dim = self.oft_blocks.shape[1] # (num_blocks, block_size, block_size)
is_linear = type(self.sd_module) in [torch.nn.Linear, torch.nn.modules.linear.NonDynamicallyQuantizableLinear]
is_conv = type(self.sd_module) in [torch.nn.Conv2d]
is_other_linear = type(self.sd_module) in [torch.nn.MultiheadAttention] # unsupported
if is_linear:
self.out_dim = self.sd_module.out_features
elif is_conv:
self.out_dim = self.sd_module.out_channels
elif is_other_linear:
self.out_dim = self.sd_module.embed_dim
if self.is_kohya:
self.constraint = self.alpha * self.out_dim
self.num_blocks = self.dim
self.block_size = self.out_dim // self.dim
else:
self.constraint = None
self.block_size, self.num_blocks = factorization(self.out_dim, self.dim)
def calc_updown(self, orig_weight):
oft_blocks = self.oft_blocks.to(orig_weight.device)
eye = torch.eye(self.block_size, device=self.oft_blocks.device)
if self.is_kohya:
block_Q = oft_blocks - oft_blocks.transpose(1, 2) # ensure skew-symmetric orthogonal matrix
norm_Q = torch.norm(block_Q.flatten())
new_norm_Q = torch.clamp(norm_Q, max=self.constraint)
block_Q = block_Q * ((new_norm_Q + 1e-8) / (norm_Q + 1e-8))
oft_blocks = torch.matmul(eye + block_Q, (eye - block_Q).float().inverse())
R = oft_blocks.to(orig_weight.device)
# This errors out for MultiheadAttention, might need to be handled up-stream
merged_weight = rearrange(orig_weight, '(k n) ... -> k n ...', k=self.num_blocks, n=self.block_size)
merged_weight = torch.einsum(
'k n m, k n ... -> k m ...',
R,
merged_weight
)
merged_weight = rearrange(merged_weight, 'k m ... -> (k m) ...')
updown = merged_weight.to(orig_weight.device) - orig_weight.to(merged_weight.dtype)
output_shape = orig_weight.shape
return self.finalize_updown(updown, orig_weight, output_shape)
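
The kohya-ss branch above builds R via the Cayley transform: for any skew-symmetric Q, (I + Q)(I - Q)^-1 is orthogonal, which is what makes OFT a pure rotation of the weight blocks. A minimal check with a hypothetical 4x4 block:

import torch

A = torch.randn(4, 4)
Q = A - A.T                              # skew-symmetric, like block_Q above
I = torch.eye(4)
R = (I + Q) @ torch.inverse(I - Q)       # Cayley transform
assert torch.allclose(R @ R.T, I, atol=1e-4)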
+47 -44
@@ -1,3 +1,4 @@
+import gradio as gr
import logging
import os
import re
@@ -5,17 +6,19 @@ import re
import lora_patches
import network
import network_lora
+import network_glora
import network_hada
import network_ia3
import network_lokr
import network_full
import network_norm
+import network_oft
import torch
from typing import Union

from modules import shared, devices, sd_models, errors, scripts, sd_hijack
-from modules.textual_inversion.textual_inversion import Embedding
+import modules.textual_inversion.textual_inversion as textual_inversion

from lora_logger import logger
@@ -26,6 +29,8 @@ module_types = [
    network_lokr.ModuleTypeLokr(),
    network_full.ModuleTypeFull(),
    network_norm.ModuleTypeNorm(),
+    network_glora.ModuleTypeGLora(),
+    network_oft.ModuleTypeOFT(),
]
@@ -155,7 +160,8 @@ def load_network(name, network_on_disk):
    bundle_embeddings = {}

    for key_network, weight in sd.items():
-        key_network_without_network_parts, network_part = key_network.split(".", 1)
+        key_network_without_network_parts, _, network_part = key_network.partition(".")

        if key_network_without_network_parts == "bundle_emb":
            emb_name, vec_name = network_part.split(".", 1)
            emb_dict = bundle_embeddings.get(emb_name, {})
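
The switch from split to partition matters for keys that contain no dot: partition always returns a 3-tuple, while unpacking the result of split would raise. A hypothetical illustration:

key = "bundle_emb"                       # no "." present
name, _, rest = key.partition(".")       # ("bundle_emb", "", "") -- never raises
# name, rest = key.split(".", 1) would raise ValueError for the same key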
@@ -187,6 +193,17 @@ def load_network(name, network_on_disk):
            key = key_network_without_network_parts.replace("lora_te1_text_model", "transformer_text_model")
            sd_module = shared.sd_model.network_layer_mapping.get(key, None)

+        # kohya_ss OFT module
+        elif sd_module is None and "oft_unet" in key_network_without_network_parts:
+            key = key_network_without_network_parts.replace("oft_unet", "diffusion_model")
+            sd_module = shared.sd_model.network_layer_mapping.get(key, None)
+
+        # KohakuBlueLeaf OFT module
+        if sd_module is None and "oft_diag" in key:
+            key = key_network_without_network_parts.replace("lora_unet", "diffusion_model")
+            key = key_network_without_network_parts.replace("lora_te1_text_model", "0_transformer_text_model")
+            sd_module = shared.sd_model.network_layer_mapping.get(key, None)
+
        if sd_module is None:
            keys_failed_to_match[key_network] = key
            continue
@@ -210,34 +227,7 @@ def load_network(name, network_on_disk):
    embeddings = {}
    for emb_name, data in bundle_embeddings.items():
-        # textual inversion embeddings
-        if 'string_to_param' in data:
-            param_dict = data['string_to_param']
-            param_dict = getattr(param_dict, '_parameters', param_dict)  # fix for torch 1.12.1 loading saved file from torch 1.11
-            assert len(param_dict) == 1, 'embedding file has multiple terms in it'
-            emb = next(iter(param_dict.items()))[1]
-            vec = emb.detach().to(devices.device, dtype=torch.float32)
-            shape = vec.shape[-1]
-            vectors = vec.shape[0]
-        elif type(data) == dict and 'clip_g' in data and 'clip_l' in data:  # SDXL embedding
-            vec = {k: v.detach().to(devices.device, dtype=torch.float32) for k, v in data.items()}
-            shape = data['clip_g'].shape[-1] + data['clip_l'].shape[-1]
-            vectors = data['clip_g'].shape[0]
-        elif type(data) == dict and type(next(iter(data.values()))) == torch.Tensor:  # diffuser concepts
-            assert len(data.keys()) == 1, 'embedding file has multiple terms in it'
-            emb = next(iter(data.values()))
-            if len(emb.shape) == 1:
-                emb = emb.unsqueeze(0)
-            vec = emb.detach().to(devices.device, dtype=torch.float32)
-            shape = vec.shape[-1]
-            vectors = vec.shape[0]
-        else:
-            raise Exception(f"Couldn't identify {emb_name} in lora: {name} as neither textual inversion embedding nor diffuser concept.")
-
-        embedding = Embedding(vec, emb_name)
-        embedding.vectors = vectors
-        embedding.shape = shape
+        embedding = textual_inversion.create_embedding_from_data(data, emb_name, filename=network_on_disk.filename + "/" + emb_name)
        embedding.loaded = None
        embeddings[emb_name] = embedding
@@ -270,11 +260,11 @@ def load_networks(names, te_multipliers=None, unet_multipliers=None, dyn_dims=No
    loaded_networks.clear()

-    networks_on_disk = [available_network_aliases.get(name, None) for name in names]
+    networks_on_disk = [available_networks.get(name, None) if name.lower() in forbidden_network_aliases else available_network_aliases.get(name, None) for name in names]
    if any(x is None for x in networks_on_disk):
        list_available_networks()

-        networks_on_disk = [available_network_aliases.get(name, None) for name in names]
+        networks_on_disk = [available_networks.get(name, None) if name.lower() in forbidden_network_aliases else available_network_aliases.get(name, None) for name in names]

    failed_to_load_networks = []
@@ -325,7 +315,12 @@ def load_networks(names, te_multipliers=None, unet_multipliers=None, dyn_dims=No
                emb_db.skipped_embeddings[name] = embedding

    if failed_to_load_networks:
-        sd_hijack.model_hijack.comments.append("Networks not found: " + ", ".join(failed_to_load_networks))
+        lora_not_found_message = f'Lora not found: {", ".join(failed_to_load_networks)}'
+        sd_hijack.model_hijack.comments.append(lora_not_found_message)
+        if shared.opts.lora_not_found_warning_console:
+            print(f'\n{lora_not_found_message}\n')
+        if shared.opts.lora_not_found_gradio_warning:
+            gr.Warning(lora_not_found_message)

    purge_networks_from_memory()
@@ -400,18 +395,26 @@ def network_apply_weights(self: Union[torch.nn.Conv2d, torch.nn.Linear, torch.nn
    if module is not None and hasattr(self, 'weight'):
        try:
            with torch.no_grad():
-                updown, ex_bias = module.calc_updown(self.weight)
+                if getattr(self, 'fp16_weight', None) is None:
+                    weight = self.weight
+                    bias = self.bias
+                else:
+                    weight = self.fp16_weight.clone().to(self.weight.device)
+                    bias = getattr(self, 'fp16_bias', None)
+                    if bias is not None:
+                        bias = bias.clone().to(self.bias.device)
+                updown, ex_bias = module.calc_updown(weight)

-                if len(self.weight.shape) == 4 and self.weight.shape[1] == 9:
+                if len(weight.shape) == 4 and weight.shape[1] == 9:
                    # inpainting model. zero pad updown to make channel[1] 4 to 9
                    updown = torch.nn.functional.pad(updown, (0, 0, 0, 0, 0, 5))

-                self.weight += updown
+                self.weight.copy_((weight.to(dtype=updown.dtype) + updown).to(dtype=self.weight.dtype))
                if ex_bias is not None and hasattr(self, 'bias'):
                    if self.bias is None:
-                        self.bias = torch.nn.Parameter(ex_bias)
+                        self.bias = torch.nn.Parameter(ex_bias).to(self.weight.dtype)
                    else:
-                        self.bias += ex_bias
+                        self.bias.copy_((bias + ex_bias).to(dtype=self.bias.dtype))
        except RuntimeError as e:
            logging.debug(f"Network {net.name} layer {network_layer_name}: {e}")
            extra_network_lora.errors[net.name] = extra_network_lora.errors.get(net.name, 0) + 1
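
The fp16_weight/fp16_bias branch exists because accumulating small deltas directly into a float16 weight can round them away entirely; a minimal numeric sketch of the problem (hypothetical values):

import torch

w = torch.tensor([0.1], dtype=torch.float16)
d = torch.tensor([1e-5], dtype=torch.float16)
print(w + d == w)   # tensor([True]): the delta is below half an ulp at this magnitude and is lost
# keeping a higher-precision master copy and writing back with .copy_() avoids this drift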
@@ -455,23 +458,23 @@ def network_apply_weights(self: Union[torch.nn.Conv2d, torch.nn.Linear, torch.nn
    self.network_current_names = wanted_names


-def network_forward(module, input, original_forward):
+def network_forward(org_module, input, original_forward):
    """
    Old way of applying Lora by executing operations during layer's forward.
    Stacking many loras this way results in big performance degradation.
    """

    if len(loaded_networks) == 0:
-        return original_forward(module, input)
+        return original_forward(org_module, input)

    input = devices.cond_cast_unet(input)

-    network_restore_weights_from_backup(module)
-    network_reset_cached_weight(module)
+    network_restore_weights_from_backup(org_module)
+    network_reset_cached_weight(org_module)

-    y = original_forward(module, input)
+    y = original_forward(org_module, input)

-    network_layer_name = getattr(module, 'network_layer_name', None)
+    network_layer_name = getattr(org_module, 'network_layer_name', None)
    for lora in loaded_networks:
        module = lora.modules.get(network_layer_name, None)
        if module is None:
@@ -39,6 +39,8 @@ shared.options_templates.update(shared.options_section(('extra_networks', "Extra
"lora_show_all": shared.OptionInfo(False, "Always show all networks on the Lora page").info("otherwise, those detected as for incompatible version of Stable Diffusion will be hidden"), "lora_show_all": shared.OptionInfo(False, "Always show all networks on the Lora page").info("otherwise, those detected as for incompatible version of Stable Diffusion will be hidden"),
"lora_hide_unknown_for_versions": shared.OptionInfo([], "Hide networks of unknown versions for model versions", gr.CheckboxGroup, {"choices": ["SD1", "SD2", "SDXL"]}), "lora_hide_unknown_for_versions": shared.OptionInfo([], "Hide networks of unknown versions for model versions", gr.CheckboxGroup, {"choices": ["SD1", "SD2", "SDXL"]}),
"lora_in_memory_limit": shared.OptionInfo(0, "Number of Lora networks to keep cached in memory", gr.Number, {"precision": 0}), "lora_in_memory_limit": shared.OptionInfo(0, "Number of Lora networks to keep cached in memory", gr.Number, {"precision": 0}),
"lora_not_found_warning_console": shared.OptionInfo(False, "Lora not found warning in console"),
"lora_not_found_gradio_warning": shared.OptionInfo(False, "Lora not found warning popup in webui"),
})) }))
@@ -54,12 +54,13 @@ class LoraUserMetadataEditor(ui_extra_networks_user_metadata.UserMetadataEditor)
        self.slider_preferred_weight = None
        self.edit_notes = None

-    def save_lora_user_metadata(self, name, desc, sd_version, activation_text, preferred_weight, notes):
+    def save_lora_user_metadata(self, name, desc, sd_version, activation_text, preferred_weight, negative_text, notes):
        user_metadata = self.get_user_metadata(name)
        user_metadata["description"] = desc
        user_metadata["sd version"] = sd_version
        user_metadata["activation text"] = activation_text
        user_metadata["preferred weight"] = preferred_weight
+        user_metadata["negative text"] = negative_text
        user_metadata["notes"] = notes

        self.write_user_metadata(name, user_metadata)
@@ -127,6 +128,7 @@ class LoraUserMetadataEditor(ui_extra_networks_user_metadata.UserMetadataEditor)
            gr.HighlightedText.update(value=gradio_tags, visible=True if tags else False),
            user_metadata.get('activation text', ''),
            float(user_metadata.get('preferred weight', 0.0)),
+            user_metadata.get('negative text', ''),
            gr.update(visible=True if tags else False),
            gr.update(value=self.generate_random_prompt_from_tags(tags), visible=True if tags else False),
        ]
@@ -162,7 +164,7 @@ class LoraUserMetadataEditor(ui_extra_networks_user_metadata.UserMetadataEditor)
        self.taginfo = gr.HighlightedText(label="Training dataset tags")
        self.edit_activation_text = gr.Text(label='Activation text', info="Will be added to prompt along with Lora")
        self.slider_preferred_weight = gr.Slider(label='Preferred weight', info="Set to 0 to disable", minimum=0.0, maximum=2.0, step=0.01)
+        self.edit_negative_text = gr.Text(label='Negative prompt', info="Will be added to negative prompts")

        with gr.Row() as row_random_prompt:
            with gr.Column(scale=8):
                random_prompt = gr.Textbox(label='Random prompt', lines=4, max_lines=4, interactive=False)
@@ -198,6 +200,7 @@ class LoraUserMetadataEditor(ui_extra_networks_user_metadata.UserMetadataEditor)
            self.taginfo,
            self.edit_activation_text,
            self.slider_preferred_weight,
+            self.edit_negative_text,
            row_random_prompt,
            random_prompt,
        ]
@@ -211,7 +214,9 @@ class LoraUserMetadataEditor(ui_extra_networks_user_metadata.UserMetadataEditor)
            self.select_sd_version,
            self.edit_activation_text,
            self.slider_preferred_weight,
+            self.edit_negative_text,
            self.edit_notes,
        ]

        self.setup_save_handler(self.button_save, self.save_lora_user_metadata, edited_components)
@@ -17,6 +17,8 @@ class ExtraNetworksPageLora(ui_extra_networks.ExtraNetworksPage):
    def create_item(self, name, index=None, enable_filter=True):
        lora_on_disk = networks.available_networks.get(name)
+        if lora_on_disk is None:
+            return

        path, ext = os.path.splitext(lora_on_disk.filename)
@@ -43,6 +45,11 @@ class ExtraNetworksPageLora(ui_extra_networks.ExtraNetworksPage):
        if activation_text:
            item["prompt"] += " + " + quote_js(" " + activation_text)

+        negative_prompt = item["user_metadata"].get("negative text")
+        item["negative_prompt"] = quote_js("")
+        if negative_prompt:
+            item["negative_prompt"] = quote_js('(' + negative_prompt + ':1)')
+
        sd_version = item["user_metadata"].get("sd version")
        if sd_version in network.SdVersion.__members__:
            item["sd_version"] = sd_version
@@ -66,9 +73,10 @@ class ExtraNetworksPageLora(ui_extra_networks.ExtraNetworksPage):
        return item

    def list_items(self):
-        for index, name in enumerate(networks.available_networks):
+        # instantiate a list to protect against concurrent modification
+        names = list(networks.available_networks)
+        for index, name in enumerate(names):
            item = self.create_item(name, index)

            if item is not None:
                yield item
@@ -1,16 +1,9 @@
import sys

import PIL.Image
-import numpy as np
-import torch
-from tqdm import tqdm

import modules.upscaler
-from modules import devices, modelloader, script_callbacks, errors
-from scunet_model_arch import SCUNet
-from modules.modelloader import load_file_from_url
-from modules.shared import opts
+from modules import devices, errors, modelloader, script_callbacks, shared, upscaler_utils


class UpscalerScuNET(modules.upscaler.Upscaler):
@@ -42,100 +35,37 @@ class UpscalerScuNET(modules.upscaler.Upscaler):
            scalers.append(scaler_data2)
        self.scalers = scalers

-    @staticmethod
-    @torch.no_grad()
-    def tiled_inference(img, model):
-        # test the image tile by tile
-        h, w = img.shape[2:]
-        tile = opts.SCUNET_tile
-        tile_overlap = opts.SCUNET_tile_overlap
-        if tile == 0:
-            return model(img)
-
-        device = devices.get_device_for('scunet')
-        assert tile % 8 == 0, "tile size should be a multiple of window_size"
-        sf = 1
-
-        stride = tile - tile_overlap
-        h_idx_list = list(range(0, h - tile, stride)) + [h - tile]
-        w_idx_list = list(range(0, w - tile, stride)) + [w - tile]
-        E = torch.zeros(1, 3, h * sf, w * sf, dtype=img.dtype, device=device)
-        W = torch.zeros_like(E, dtype=devices.dtype, device=device)
-
-        with tqdm(total=len(h_idx_list) * len(w_idx_list), desc="ScuNET tiles") as pbar:
-            for h_idx in h_idx_list:
-                for w_idx in w_idx_list:
-                    in_patch = img[..., h_idx: h_idx + tile, w_idx: w_idx + tile]
-                    out_patch = model(in_patch)
-                    out_patch_mask = torch.ones_like(out_patch)
-                    E[..., h_idx * sf: (h_idx + tile) * sf, w_idx * sf: (w_idx + tile) * sf].add_(out_patch)
-                    W[..., h_idx * sf: (h_idx + tile) * sf, w_idx * sf: (w_idx + tile) * sf].add_(out_patch_mask)
-                    pbar.update(1)
-        output = E.div_(W)
-
-        return output
-
    def do_upscale(self, img: PIL.Image.Image, selected_file):
        devices.torch_gc()
        try:
            model = self.load_model(selected_file)
        except Exception as e:
            print(f"ScuNET: Unable to load model from {selected_file}: {e}", file=sys.stderr)
            return img

-        device = devices.get_device_for('scunet')
-        tile = opts.SCUNET_tile
-        h, w = img.height, img.width
-        np_img = np.array(img)
-        np_img = np_img[:, :, ::-1]  # RGB to BGR
-        np_img = np_img.transpose((2, 0, 1)) / 255  # HWC to CHW
-        torch_img = torch.from_numpy(np_img).float().unsqueeze(0).to(device)  # type: ignore
-
-        if tile > h or tile > w:
-            _img = torch.zeros(1, 3, max(h, tile), max(w, tile), dtype=torch_img.dtype, device=torch_img.device)
-            _img[:, :, :h, :w] = torch_img  # pad image
-            torch_img = _img
-
-        torch_output = self.tiled_inference(torch_img, model).squeeze(0)
-        torch_output = torch_output[:, :h * 1, :w * 1]  # remove padding, if any
-        np_output: np.ndarray = torch_output.float().cpu().clamp_(0, 1).numpy()
-        del torch_img, torch_output
+        img = upscaler_utils.upscale_2(
+            img,
+            model,
+            tile_size=shared.opts.SCUNET_tile,
+            tile_overlap=shared.opts.SCUNET_tile_overlap,
+            scale=1,  # ScuNET is a denoising model, not an upscaler
+            desc='ScuNET',
+        )
        devices.torch_gc()
-
-        output = np_output.transpose((1, 2, 0))  # CHW to HWC
-        output = output[:, :, ::-1]  # BGR to RGB
-        return PIL.Image.fromarray((output * 255).astype(np.uint8))
+        return img
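
The removed tiled_inference used the standard overlap-tile grid, and upscaler_utils.upscale_2 presumably computes a similar grid internally; a worked example of the index arithmetic with hypothetical sizes:

h, tile, tile_overlap = 300, 256, 8
stride = tile - tile_overlap                                # 248
h_idx_list = list(range(0, h - tile, stride)) + [h - tile]  # [0, 44]
# two overlapping bands of tiles; overlapped pixels are averaged via the E / W accumulators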
    def load_model(self, path: str):
        device = devices.get_device_for('scunet')
        if path.startswith("http"):
            # TODO: this doesn't use `path` at all?
-            filename = load_file_from_url(self.model_url, model_dir=self.model_download_path, file_name=f"{self.name}.pth")
+            filename = modelloader.load_file_from_url(self.model_url, model_dir=self.model_download_path, file_name=f"{self.name}.pth")
        else:
            filename = path

-        model = SCUNet(in_nc=3, config=[4, 4, 4, 4, 4, 4, 4], dim=64)
-        model.load_state_dict(torch.load(filename), strict=True)
-        model.eval()
-        for _, v in model.named_parameters():
-            v.requires_grad = False
-        model = model.to(device)
-
-        return model
+        return modelloader.load_spandrel_model(filename, device=device, expected_architecture='SCUNet')
def on_ui_settings():
    import gradio as gr
-    from modules import shared

    shared.opts.add_option("SCUNET_tile", shared.OptionInfo(256, "Tile size for SCUNET upscalers.", gr.Slider, {"minimum": 0, "maximum": 512, "step": 16}, section=('upscaling', "Upscaling")).info("0 = no tiling"))
    shared.opts.add_option("SCUNET_tile_overlap", shared.OptionInfo(8, "Tile overlap for SCUNET upscalers.", gr.Slider, {"minimum": 0, "maximum": 64, "step": 1}, section=('upscaling', "Upscaling")).info("Low values = visible seam"))
@@ -1,268 +0,0 @@
# -*- coding: utf-8 -*-
import numpy as np
import torch
import torch.nn as nn
from einops import rearrange
from einops.layers.torch import Rearrange
from timm.models.layers import trunc_normal_, DropPath
class WMSA(nn.Module):
""" Self-attention module in Swin Transformer
"""
def __init__(self, input_dim, output_dim, head_dim, window_size, type):
super(WMSA, self).__init__()
self.input_dim = input_dim
self.output_dim = output_dim
self.head_dim = head_dim
self.scale = self.head_dim ** -0.5
self.n_heads = input_dim // head_dim
self.window_size = window_size
self.type = type
self.embedding_layer = nn.Linear(self.input_dim, 3 * self.input_dim, bias=True)
self.relative_position_params = nn.Parameter(
torch.zeros((2 * window_size - 1) * (2 * window_size - 1), self.n_heads))
self.linear = nn.Linear(self.input_dim, self.output_dim)
trunc_normal_(self.relative_position_params, std=.02)
self.relative_position_params = torch.nn.Parameter(
self.relative_position_params.view(2 * window_size - 1, 2 * window_size - 1, self.n_heads).transpose(1,
2).transpose(
0, 1))
def generate_mask(self, h, w, p, shift):
""" generating the mask of SW-MSA
Args:
shift: shift parameters in CyclicShift.
Returns:
attn_mask: should be (1 1 w p p),
"""
# supporting square.
attn_mask = torch.zeros(h, w, p, p, p, p, dtype=torch.bool, device=self.relative_position_params.device)
if self.type == 'W':
return attn_mask
s = p - shift
attn_mask[-1, :, :s, :, s:, :] = True
attn_mask[-1, :, s:, :, :s, :] = True
attn_mask[:, -1, :, :s, :, s:] = True
attn_mask[:, -1, :, s:, :, :s] = True
attn_mask = rearrange(attn_mask, 'w1 w2 p1 p2 p3 p4 -> 1 1 (w1 w2) (p1 p2) (p3 p4)')
return attn_mask
def forward(self, x):
""" Forward pass of Window Multi-head Self-attention module.
Args:
x: input tensor with shape of [b h w c];
attn_mask: attention mask, fill -inf where the value is True;
Returns:
output: tensor shape [b h w c]
"""
if self.type != 'W':
x = torch.roll(x, shifts=(-(self.window_size // 2), -(self.window_size // 2)), dims=(1, 2))
x = rearrange(x, 'b (w1 p1) (w2 p2) c -> b w1 w2 p1 p2 c', p1=self.window_size, p2=self.window_size)
h_windows = x.size(1)
w_windows = x.size(2)
# square validation
# assert h_windows == w_windows
x = rearrange(x, 'b w1 w2 p1 p2 c -> b (w1 w2) (p1 p2) c', p1=self.window_size, p2=self.window_size)
qkv = self.embedding_layer(x)
q, k, v = rearrange(qkv, 'b nw np (threeh c) -> threeh b nw np c', c=self.head_dim).chunk(3, dim=0)
sim = torch.einsum('hbwpc,hbwqc->hbwpq', q, k) * self.scale
# Adding learnable relative embedding
sim = sim + rearrange(self.relative_embedding(), 'h p q -> h 1 1 p q')
# Using Attn Mask to distinguish different subwindows.
if self.type != 'W':
attn_mask = self.generate_mask(h_windows, w_windows, self.window_size, shift=self.window_size // 2)
sim = sim.masked_fill_(attn_mask, float("-inf"))
probs = nn.functional.softmax(sim, dim=-1)
output = torch.einsum('hbwij,hbwjc->hbwic', probs, v)
output = rearrange(output, 'h b w p c -> b w p (h c)')
output = self.linear(output)
output = rearrange(output, 'b (w1 w2) (p1 p2) c -> b (w1 p1) (w2 p2) c', w1=h_windows, p1=self.window_size)
if self.type != 'W':
output = torch.roll(output, shifts=(self.window_size // 2, self.window_size // 2), dims=(1, 2))
return output
def relative_embedding(self):
cord = torch.tensor(np.array([[i, j] for i in range(self.window_size) for j in range(self.window_size)]))
relation = cord[:, None, :] - cord[None, :, :] + self.window_size - 1
# negative is allowed
return self.relative_position_params[:, relation[:, :, 0].long(), relation[:, :, 1].long()]
class Block(nn.Module):
def __init__(self, input_dim, output_dim, head_dim, window_size, drop_path, type='W', input_resolution=None):
""" SwinTransformer Block
"""
super(Block, self).__init__()
self.input_dim = input_dim
self.output_dim = output_dim
assert type in ['W', 'SW']
self.type = type
if input_resolution <= window_size:
self.type = 'W'
self.ln1 = nn.LayerNorm(input_dim)
self.msa = WMSA(input_dim, input_dim, head_dim, window_size, self.type)
self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
self.ln2 = nn.LayerNorm(input_dim)
self.mlp = nn.Sequential(
nn.Linear(input_dim, 4 * input_dim),
nn.GELU(),
nn.Linear(4 * input_dim, output_dim),
)
def forward(self, x):
x = x + self.drop_path(self.msa(self.ln1(x)))
x = x + self.drop_path(self.mlp(self.ln2(x)))
return x
class ConvTransBlock(nn.Module):
def __init__(self, conv_dim, trans_dim, head_dim, window_size, drop_path, type='W', input_resolution=None):
""" SwinTransformer and Conv Block
"""
super(ConvTransBlock, self).__init__()
self.conv_dim = conv_dim
self.trans_dim = trans_dim
self.head_dim = head_dim
self.window_size = window_size
self.drop_path = drop_path
self.type = type
self.input_resolution = input_resolution
assert self.type in ['W', 'SW']
if self.input_resolution <= self.window_size:
self.type = 'W'
self.trans_block = Block(self.trans_dim, self.trans_dim, self.head_dim, self.window_size, self.drop_path,
self.type, self.input_resolution)
self.conv1_1 = nn.Conv2d(self.conv_dim + self.trans_dim, self.conv_dim + self.trans_dim, 1, 1, 0, bias=True)
self.conv1_2 = nn.Conv2d(self.conv_dim + self.trans_dim, self.conv_dim + self.trans_dim, 1, 1, 0, bias=True)
self.conv_block = nn.Sequential(
nn.Conv2d(self.conv_dim, self.conv_dim, 3, 1, 1, bias=False),
nn.ReLU(True),
nn.Conv2d(self.conv_dim, self.conv_dim, 3, 1, 1, bias=False)
)
def forward(self, x):
conv_x, trans_x = torch.split(self.conv1_1(x), (self.conv_dim, self.trans_dim), dim=1)
conv_x = self.conv_block(conv_x) + conv_x
trans_x = Rearrange('b c h w -> b h w c')(trans_x)
trans_x = self.trans_block(trans_x)
trans_x = Rearrange('b h w c -> b c h w')(trans_x)
res = self.conv1_2(torch.cat((conv_x, trans_x), dim=1))
x = x + res
return x
class SCUNet(nn.Module):
# def __init__(self, in_nc=3, config=[2, 2, 2, 2, 2, 2, 2], dim=64, drop_path_rate=0.0, input_resolution=256):
def __init__(self, in_nc=3, config=None, dim=64, drop_path_rate=0.0, input_resolution=256):
super(SCUNet, self).__init__()
if config is None:
config = [2, 2, 2, 2, 2, 2, 2]
self.config = config
self.dim = dim
self.head_dim = 32
self.window_size = 8
# drop path rate for each layer
dpr = [x.item() for x in torch.linspace(0, drop_path_rate, sum(config))]
self.m_head = [nn.Conv2d(in_nc, dim, 3, 1, 1, bias=False)]
begin = 0
self.m_down1 = [ConvTransBlock(dim // 2, dim // 2, self.head_dim, self.window_size, dpr[i + begin],
'W' if not i % 2 else 'SW', input_resolution)
for i in range(config[0])] + \
[nn.Conv2d(dim, 2 * dim, 2, 2, 0, bias=False)]
begin += config[0]
self.m_down2 = [ConvTransBlock(dim, dim, self.head_dim, self.window_size, dpr[i + begin],
'W' if not i % 2 else 'SW', input_resolution // 2)
for i in range(config[1])] + \
[nn.Conv2d(2 * dim, 4 * dim, 2, 2, 0, bias=False)]
begin += config[1]
self.m_down3 = [ConvTransBlock(2 * dim, 2 * dim, self.head_dim, self.window_size, dpr[i + begin],
'W' if not i % 2 else 'SW', input_resolution // 4)
for i in range(config[2])] + \
[nn.Conv2d(4 * dim, 8 * dim, 2, 2, 0, bias=False)]
begin += config[2]
self.m_body = [ConvTransBlock(4 * dim, 4 * dim, self.head_dim, self.window_size, dpr[i + begin],
'W' if not i % 2 else 'SW', input_resolution // 8)
for i in range(config[3])]
begin += config[3]
self.m_up3 = [nn.ConvTranspose2d(8 * dim, 4 * dim, 2, 2, 0, bias=False), ] + \
[ConvTransBlock(2 * dim, 2 * dim, self.head_dim, self.window_size, dpr[i + begin],
'W' if not i % 2 else 'SW', input_resolution // 4)
for i in range(config[4])]
begin += config[4]
self.m_up2 = [nn.ConvTranspose2d(4 * dim, 2 * dim, 2, 2, 0, bias=False), ] + \
[ConvTransBlock(dim, dim, self.head_dim, self.window_size, dpr[i + begin],
'W' if not i % 2 else 'SW', input_resolution // 2)
for i in range(config[5])]
begin += config[5]
self.m_up1 = [nn.ConvTranspose2d(2 * dim, dim, 2, 2, 0, bias=False), ] + \
[ConvTransBlock(dim // 2, dim // 2, self.head_dim, self.window_size, dpr[i + begin],
'W' if not i % 2 else 'SW', input_resolution)
for i in range(config[6])]
self.m_tail = [nn.Conv2d(dim, in_nc, 3, 1, 1, bias=False)]
self.m_head = nn.Sequential(*self.m_head)
self.m_down1 = nn.Sequential(*self.m_down1)
self.m_down2 = nn.Sequential(*self.m_down2)
self.m_down3 = nn.Sequential(*self.m_down3)
self.m_body = nn.Sequential(*self.m_body)
self.m_up3 = nn.Sequential(*self.m_up3)
self.m_up2 = nn.Sequential(*self.m_up2)
self.m_up1 = nn.Sequential(*self.m_up1)
self.m_tail = nn.Sequential(*self.m_tail)
# self.apply(self._init_weights)
def forward(self, x0):
h, w = x0.size()[-2:]
paddingBottom = int(np.ceil(h / 64) * 64 - h)
paddingRight = int(np.ceil(w / 64) * 64 - w)
x0 = nn.ReplicationPad2d((0, paddingRight, 0, paddingBottom))(x0)
x1 = self.m_head(x0)
x2 = self.m_down1(x1)
x3 = self.m_down2(x2)
x4 = self.m_down3(x3)
x = self.m_body(x4)
x = self.m_up3(x + x4)
x = self.m_up2(x + x3)
x = self.m_up1(x + x2)
x = self.m_tail(x + x1)
x = x[..., :h, :w]
return x
def _init_weights(self, m):
if isinstance(m, nn.Linear):
trunc_normal_(m.weight, std=.02)
if m.bias is not None:
nn.init.constant_(m.bias, 0)
elif isinstance(m, nn.LayerNorm):
nn.init.constant_(m.bias, 0)
nn.init.constant_(m.weight, 1.0)
+32 -129
@@ -1,20 +1,15 @@
+import logging
import sys
-import platform

-import numpy as np
import torch
from PIL import Image
-from tqdm import tqdm

-from modules import modelloader, devices, script_callbacks, shared
-from modules.shared import opts, state
-from swinir_model_arch import SwinIR
-from swinir_model_arch_v2 import Swin2SR
+from modules import devices, modelloader, script_callbacks, shared, upscaler_utils
from modules.upscaler import Upscaler, UpscalerData

SWINIR_MODEL_URL = "https://github.com/JingyunLiang/SwinIR/releases/download/v0.0/003_realSR_BSRGAN_DFOWMFC_s64w8_SwinIR-L_x4_GAN.pth"

-device_swinir = devices.get_device_for('swinir')
+logger = logging.getLogger(__name__)


class UpscalerSwinIR(Upscaler):
@@ -37,26 +32,28 @@ class UpscalerSwinIR(Upscaler):
            scalers.append(model_data)
        self.scalers = scalers

-    def do_upscale(self, img, model_file):
-        use_compile = hasattr(opts, 'SWIN_torch_compile') and opts.SWIN_torch_compile \
-            and int(torch.__version__.split('.')[0]) >= 2 and platform.system() != "Windows"
-        current_config = (model_file, opts.SWIN_tile)
+    def do_upscale(self, img: Image.Image, model_file: str) -> Image.Image:
+        current_config = (model_file, shared.opts.SWIN_tile)

-        if use_compile and self._cached_model_config == current_config:
+        if self._cached_model_config == current_config:
            model = self._cached_model
        else:
+            self._cached_model = None
            try:
                model = self.load_model(model_file)
            except Exception as e:
                print(f"Failed loading SwinIR model {model_file}: {e}", file=sys.stderr)
                return img
-            model = model.to(device_swinir, dtype=devices.dtype)
-            if use_compile:
-                model = torch.compile(model)
-                self._cached_model = model
-                self._cached_model_config = current_config
-        img = upscale(img, model)
+            self._cached_model = model
+            self._cached_model_config = current_config
+
+        img = upscaler_utils.upscale_2(
+            img,
+            model,
+            tile_size=shared.opts.SWIN_tile,
+            tile_overlap=shared.opts.SWIN_tile_overlap,
+            scale=model.scale,
+            desc="SwinIR",
+        )
        devices.torch_gc()
        return img
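
Keying the cache on (model_file, tile) rather than just the file path means the model is reloaded whenever the tile size changes, presumably because torch.compile specializes on one tile shape; a minimal sketch of the pattern (hypothetical loader):

_cached_model, _cached_cfg = None, None

def get_model(model_file, tile, load):
    global _cached_model, _cached_cfg
    cfg = (model_file, tile)
    if _cached_cfg != cfg:                 # reload when either part of the key changes
        _cached_model, _cached_cfg = load(model_file), cfg
    return _cached_model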
@@ -69,115 +66,22 @@ class UpscalerSwinIR(Upscaler):
            )
        else:
            filename = path

-        if filename.endswith(".v2.pth"):
-            model = Swin2SR(
-                upscale=scale,
-                in_chans=3,
-                img_size=64,
-                window_size=8,
-                img_range=1.0,
-                depths=[6, 6, 6, 6, 6, 6],
-                embed_dim=180,
-                num_heads=[6, 6, 6, 6, 6, 6],
-                mlp_ratio=2,
-                upsampler="nearest+conv",
-                resi_connection="1conv",
-            )
-            params = None
-        else:
-            model = SwinIR(
-                upscale=scale,
-                in_chans=3,
-                img_size=64,
-                window_size=8,
-                img_range=1.0,
-                depths=[6, 6, 6, 6, 6, 6, 6, 6, 6],
-                embed_dim=240,
-                num_heads=[8, 8, 8, 8, 8, 8, 8, 8, 8],
-                mlp_ratio=2,
-                upsampler="nearest+conv",
-                resi_connection="3conv",
-            )
-            params = "params_ema"
-
-        pretrained_model = torch.load(filename)
-        if params is not None:
-            model.load_state_dict(pretrained_model[params], strict=True)
-        else:
-            model.load_state_dict(pretrained_model, strict=True)
-        return model
+        model_descriptor = modelloader.load_spandrel_model(
+            filename,
+            device=self._get_device(),
+            prefer_half=(devices.dtype == torch.float16),
+            expected_architecture="SwinIR",
+        )
+        if getattr(shared.opts, 'SWIN_torch_compile', False):
+            try:
+                model_descriptor.model.compile()
+            except Exception:
+                logger.warning("Failed to compile SwinIR model, fallback to JIT", exc_info=True)
+        return model_descriptor
+
+    def _get_device(self):
+        return devices.get_device_for('swinir')

-def upscale(
-    img,
-    model,
-    tile=None,
-    tile_overlap=None,
-    window_size=8,
-    scale=4,
-):
-    tile = tile or opts.SWIN_tile
-    tile_overlap = tile_overlap or opts.SWIN_tile_overlap
-    img = np.array(img)
-    img = img[:, :, ::-1]
-    img = np.moveaxis(img, 2, 0) / 255
-    img = torch.from_numpy(img).float()
-    img = img.unsqueeze(0).to(device_swinir, dtype=devices.dtype)
-    with torch.no_grad(), devices.autocast():
-        _, _, h_old, w_old = img.size()
-        h_pad = (h_old // window_size + 1) * window_size - h_old
-        w_pad = (w_old // window_size + 1) * window_size - w_old
-        img = torch.cat([img, torch.flip(img, [2])], 2)[:, :, : h_old + h_pad, :]
-        img = torch.cat([img, torch.flip(img, [3])], 3)[:, :, :, : w_old + w_pad]
-        output = inference(img, model, tile, tile_overlap, window_size, scale)
-        output = output[..., : h_old * scale, : w_old * scale]
-        output = output.data.squeeze().float().cpu().clamp_(0, 1).numpy()
-        if output.ndim == 3:
-            output = np.transpose(output[[2, 1, 0], :, :], (1, 2, 0))  # CHW-RGB to HCW-BGR
-        output = (output * 255.0).round().astype(np.uint8)  # float32 to uint8
-        return Image.fromarray(output, "RGB")
-
-
-def inference(img, model, tile, tile_overlap, window_size, scale):
-    # test the image tile by tile
-    b, c, h, w = img.size()
-    tile = min(tile, h, w)
-    assert tile % window_size == 0, "tile size should be a multiple of window_size"
-    sf = scale
-
-    stride = tile - tile_overlap
-    h_idx_list = list(range(0, h - tile, stride)) + [h - tile]
-    w_idx_list = list(range(0, w - tile, stride)) + [w - tile]
-    E = torch.zeros(b, c, h * sf, w * sf, dtype=devices.dtype, device=device_swinir).type_as(img)
-    W = torch.zeros_like(E, dtype=devices.dtype, device=device_swinir)
-
-    with tqdm(total=len(h_idx_list) * len(w_idx_list), desc="SwinIR tiles") as pbar:
-        for h_idx in h_idx_list:
-            if state.interrupted or state.skipped:
-                break
-
-            for w_idx in w_idx_list:
-                if state.interrupted or state.skipped:
-                    break
-
-                in_patch = img[..., h_idx: h_idx + tile, w_idx: w_idx + tile]
-                out_patch = model(in_patch)
-                out_patch_mask = torch.ones_like(out_patch)
-
-                E[..., h_idx * sf: (h_idx + tile) * sf, w_idx * sf: (w_idx + tile) * sf].add_(out_patch)
-                W[..., h_idx * sf: (h_idx + tile) * sf, w_idx * sf: (w_idx + tile) * sf].add_(out_patch_mask)
-                pbar.update(1)
-    output = E.div_(W)
-
-    return output
def on_ui_settings():
@@ -185,8 +89,7 @@ def on_ui_settings():
shared.opts.add_option("SWIN_tile", shared.OptionInfo(192, "Tile size for all SwinIR.", gr.Slider, {"minimum": 16, "maximum": 512, "step": 16}, section=('upscaling', "Upscaling"))) shared.opts.add_option("SWIN_tile", shared.OptionInfo(192, "Tile size for all SwinIR.", gr.Slider, {"minimum": 16, "maximum": 512, "step": 16}, section=('upscaling', "Upscaling")))
shared.opts.add_option("SWIN_tile_overlap", shared.OptionInfo(8, "Tile overlap, in pixels for SwinIR. Low values = visible seam.", gr.Slider, {"minimum": 0, "maximum": 48, "step": 1}, section=('upscaling', "Upscaling"))) shared.opts.add_option("SWIN_tile_overlap", shared.OptionInfo(8, "Tile overlap, in pixels for SwinIR. Low values = visible seam.", gr.Slider, {"minimum": 0, "maximum": 48, "step": 1}, section=('upscaling', "Upscaling")))
if int(torch.__version__.split('.')[0]) >= 2 and platform.system() != "Windows": # torch.compile() require pytorch 2.0 or above, and not on Windows shared.opts.add_option("SWIN_torch_compile", shared.OptionInfo(False, "Use torch.compile to accelerate SwinIR.", gr.Checkbox, {"interactive": True}, section=('upscaling', "Upscaling")).info("Takes longer on first run"))
shared.opts.add_option("SWIN_torch_compile", shared.OptionInfo(False, "Use torch.compile to accelerate SwinIR.", gr.Checkbox, {"interactive": True}, section=('upscaling', "Upscaling")).info("Takes longer on first run"))
script_callbacks.on_ui_settings(on_ui_settings) script_callbacks.on_ui_settings(on_ui_settings)
@@ -1,867 +0,0 @@
# -----------------------------------------------------------------------------------
# SwinIR: Image Restoration Using Swin Transformer, https://arxiv.org/abs/2108.10257
# Originally Written by Ze Liu, Modified by Jingyun Liang.
# -----------------------------------------------------------------------------------
import math
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.utils.checkpoint as checkpoint
from timm.models.layers import DropPath, to_2tuple, trunc_normal_
class Mlp(nn.Module):
def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.):
super().__init__()
out_features = out_features or in_features
hidden_features = hidden_features or in_features
self.fc1 = nn.Linear(in_features, hidden_features)
self.act = act_layer()
self.fc2 = nn.Linear(hidden_features, out_features)
self.drop = nn.Dropout(drop)
def forward(self, x):
x = self.fc1(x)
x = self.act(x)
x = self.drop(x)
x = self.fc2(x)
x = self.drop(x)
return x
def window_partition(x, window_size):
"""
Args:
x: (B, H, W, C)
window_size (int): window size
Returns:
windows: (num_windows*B, window_size, window_size, C)
"""
B, H, W, C = x.shape
x = x.view(B, H // window_size, window_size, W // window_size, window_size, C)
windows = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, window_size, window_size, C)
return windows
def window_reverse(windows, window_size, H, W):
"""
Args:
windows: (num_windows*B, window_size, window_size, C)
window_size (int): Window size
H (int): Height of image
W (int): Width of image
Returns:
x: (B, H, W, C)
"""
B = int(windows.shape[0] / (H * W / window_size / window_size))
x = windows.view(B, H // window_size, W // window_size, window_size, window_size, -1)
x = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(B, H, W, -1)
return x
class WindowAttention(nn.Module):
r""" Window based multi-head self attention (W-MSA) module with relative position bias.
It supports both of shifted and non-shifted window.
Args:
dim (int): Number of input channels.
window_size (tuple[int]): The height and width of the window.
num_heads (int): Number of attention heads.
qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True
qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set
attn_drop (float, optional): Dropout ratio of attention weight. Default: 0.0
proj_drop (float, optional): Dropout ratio of output. Default: 0.0
"""
def __init__(self, dim, window_size, num_heads, qkv_bias=True, qk_scale=None, attn_drop=0., proj_drop=0.):
super().__init__()
self.dim = dim
self.window_size = window_size # Wh, Ww
self.num_heads = num_heads
head_dim = dim // num_heads
self.scale = qk_scale or head_dim ** -0.5
# define a parameter table of relative position bias
self.relative_position_bias_table = nn.Parameter(
torch.zeros((2 * window_size[0] - 1) * (2 * window_size[1] - 1), num_heads)) # 2*Wh-1 * 2*Ww-1, nH
# get pair-wise relative position index for each token inside the window
coords_h = torch.arange(self.window_size[0])
coords_w = torch.arange(self.window_size[1])
coords = torch.stack(torch.meshgrid([coords_h, coords_w])) # 2, Wh, Ww
coords_flatten = torch.flatten(coords, 1) # 2, Wh*Ww
relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :] # 2, Wh*Ww, Wh*Ww
relative_coords = relative_coords.permute(1, 2, 0).contiguous() # Wh*Ww, Wh*Ww, 2
relative_coords[:, :, 0] += self.window_size[0] - 1 # shift to start from 0
relative_coords[:, :, 1] += self.window_size[1] - 1
relative_coords[:, :, 0] *= 2 * self.window_size[1] - 1
relative_position_index = relative_coords.sum(-1) # Wh*Ww, Wh*Ww
self.register_buffer("relative_position_index", relative_position_index)
self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
self.attn_drop = nn.Dropout(attn_drop)
self.proj = nn.Linear(dim, dim)
self.proj_drop = nn.Dropout(proj_drop)
trunc_normal_(self.relative_position_bias_table, std=.02)
self.softmax = nn.Softmax(dim=-1)
def forward(self, x, mask=None):
"""
Args:
x: input features with shape of (num_windows*B, N, C)
mask: (0/-inf) mask with shape of (num_windows, Wh*Ww, Wh*Ww) or None
"""
B_, N, C = x.shape
qkv = self.qkv(x).reshape(B_, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
q, k, v = qkv[0], qkv[1], qkv[2] # make torchscript happy (cannot use tensor as tuple)
q = q * self.scale
attn = (q @ k.transpose(-2, -1))
relative_position_bias = self.relative_position_bias_table[self.relative_position_index.view(-1)].view(
self.window_size[0] * self.window_size[1], self.window_size[0] * self.window_size[1], -1) # Wh*Ww,Wh*Ww,nH
relative_position_bias = relative_position_bias.permute(2, 0, 1).contiguous() # nH, Wh*Ww, Wh*Ww
attn = attn + relative_position_bias.unsqueeze(0)
if mask is not None:
nW = mask.shape[0]
attn = attn.view(B_ // nW, nW, self.num_heads, N, N) + mask.unsqueeze(1).unsqueeze(0)
attn = attn.view(-1, self.num_heads, N, N)
attn = self.softmax(attn)
else:
attn = self.softmax(attn)
attn = self.attn_drop(attn)
x = (attn @ v).transpose(1, 2).reshape(B_, N, C)
x = self.proj(x)
x = self.proj_drop(x)
return x
def extra_repr(self) -> str:
return f'dim={self.dim}, window_size={self.window_size}, num_heads={self.num_heads}'
def flops(self, N):
# calculate flops for 1 window with token length of N
flops = 0
# qkv = self.qkv(x)
flops += N * self.dim * 3 * self.dim
# attn = (q @ k.transpose(-2, -1))
flops += self.num_heads * N * (self.dim // self.num_heads) * N
# x = (attn @ v)
flops += self.num_heads * N * N * (self.dim // self.num_heads)
# x = self.proj(x)
flops += N * self.dim * self.dim
return flops
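# A minimal usage sketch for WindowAttention (illustrative values, not from the
# original file): 7x7 windows give N = 49 tokens per window.
#     attn = WindowAttention(dim=96, window_size=(7, 7), num_heads=6)
#     x = torch.randn(4, 49, 96)                   # (num_windows*B, N, C)
#     y = attn(x)                                  # -> (4, 49, 96)
#     attn.flops(49)                               # FLOPs for one 49-token window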
class SwinTransformerBlock(nn.Module):
r""" Swin Transformer Block.
Args:
dim (int): Number of input channels.
input_resolution (tuple[int]): Input resolution.
num_heads (int): Number of attention heads.
window_size (int): Window size.
shift_size (int): Shift size for SW-MSA.
mlp_ratio (float): Ratio of mlp hidden dim to embedding dim.
qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True
qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set.
drop (float, optional): Dropout rate. Default: 0.0
attn_drop (float, optional): Attention dropout rate. Default: 0.0
drop_path (float, optional): Stochastic depth rate. Default: 0.0
act_layer (nn.Module, optional): Activation layer. Default: nn.GELU
norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm
"""
def __init__(self, dim, input_resolution, num_heads, window_size=7, shift_size=0,
mlp_ratio=4., qkv_bias=True, qk_scale=None, drop=0., attn_drop=0., drop_path=0.,
act_layer=nn.GELU, norm_layer=nn.LayerNorm):
super().__init__()
self.dim = dim
self.input_resolution = input_resolution
self.num_heads = num_heads
self.window_size = window_size
self.shift_size = shift_size
self.mlp_ratio = mlp_ratio
if min(self.input_resolution) <= self.window_size:
# if window size is larger than input resolution, we don't partition windows
self.shift_size = 0
self.window_size = min(self.input_resolution)
assert 0 <= self.shift_size < self.window_size, "shift_size must be in the range [0, window_size)"
self.norm1 = norm_layer(dim)
self.attn = WindowAttention(
dim, window_size=to_2tuple(self.window_size), num_heads=num_heads,
qkv_bias=qkv_bias, qk_scale=qk_scale, attn_drop=attn_drop, proj_drop=drop)
self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
self.norm2 = norm_layer(dim)
mlp_hidden_dim = int(dim * mlp_ratio)
self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop)
if self.shift_size > 0:
attn_mask = self.calculate_mask(self.input_resolution)
else:
attn_mask = None
self.register_buffer("attn_mask", attn_mask)
def calculate_mask(self, x_size):
# calculate attention mask for SW-MSA
H, W = x_size
img_mask = torch.zeros((1, H, W, 1)) # 1 H W 1
h_slices = (slice(0, -self.window_size),
slice(-self.window_size, -self.shift_size),
slice(-self.shift_size, None))
w_slices = (slice(0, -self.window_size),
slice(-self.window_size, -self.shift_size),
slice(-self.shift_size, None))
cnt = 0
for h in h_slices:
for w in w_slices:
img_mask[:, h, w, :] = cnt
cnt += 1
mask_windows = window_partition(img_mask, self.window_size) # nW, window_size, window_size, 1
mask_windows = mask_windows.view(-1, self.window_size * self.window_size)
attn_mask = mask_windows.unsqueeze(1) - mask_windows.unsqueeze(2)
attn_mask = attn_mask.masked_fill(attn_mask != 0, float(-100.0)).masked_fill(attn_mask == 0, float(0.0))
return attn_mask
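# Worked example for calculate_mask (hypothetical numbers): with
# x_size=(8, 8), window_size=4 and shift_size=2, img_mask is split into a
# 3x3 grid of regions (cnt = 0..8), window_partition yields nW = 4 windows,
# and the returned attn_mask has shape (4, 16, 16): 0 where two positions
# belong to the same region, -100 otherwise.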
def forward(self, x, x_size):
H, W = x_size
B, L, C = x.shape
# assert L == H * W, "input feature has wrong size"
shortcut = x
x = self.norm1(x)
x = x.view(B, H, W, C)
# cyclic shift
if self.shift_size > 0:
shifted_x = torch.roll(x, shifts=(-self.shift_size, -self.shift_size), dims=(1, 2))
else:
shifted_x = x
# partition windows
x_windows = window_partition(shifted_x, self.window_size) # nW*B, window_size, window_size, C
x_windows = x_windows.view(-1, self.window_size * self.window_size, C) # nW*B, window_size*window_size, C
# W-MSA/SW-MSA (recompute the attention mask when the test image size differs from the training resolution)
if self.input_resolution == x_size:
attn_windows = self.attn(x_windows, mask=self.attn_mask) # nW*B, window_size*window_size, C
else:
attn_windows = self.attn(x_windows, mask=self.calculate_mask(x_size).to(x.device))
# merge windows
attn_windows = attn_windows.view(-1, self.window_size, self.window_size, C)
shifted_x = window_reverse(attn_windows, self.window_size, H, W) # B H' W' C
# reverse cyclic shift
if self.shift_size > 0:
x = torch.roll(shifted_x, shifts=(self.shift_size, self.shift_size), dims=(1, 2))
else:
x = shifted_x
x = x.view(B, H * W, C)
# FFN
x = shortcut + self.drop_path(x)
x = x + self.drop_path(self.mlp(self.norm2(x)))
return x
def extra_repr(self) -> str:
return f"dim={self.dim}, input_resolution={self.input_resolution}, num_heads={self.num_heads}, " \
f"window_size={self.window_size}, shift_size={self.shift_size}, mlp_ratio={self.mlp_ratio}"
def flops(self):
flops = 0
H, W = self.input_resolution
# norm1
flops += self.dim * H * W
# W-MSA/SW-MSA
nW = H * W / self.window_size / self.window_size
flops += nW * self.attn.flops(self.window_size * self.window_size)
# mlp
flops += 2 * H * W * self.dim * self.dim * self.mlp_ratio
# norm2
flops += self.dim * H * W
return flops
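# Illustrative forward pass for one block (values are assumptions, not from
# the original file); the token shape is preserved end to end:
#     blk = SwinTransformerBlock(dim=96, input_resolution=(56, 56), num_heads=3, window_size=7)
#     x = torch.randn(1, 56 * 56, 96)              # B, H*W, C
#     y = blk(x, (56, 56))                         # -> (1, 3136, 96)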
class PatchMerging(nn.Module):
r""" Patch Merging Layer.
Args:
input_resolution (tuple[int]): Resolution of input feature.
dim (int): Number of input channels.
norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm
"""
def __init__(self, input_resolution, dim, norm_layer=nn.LayerNorm):
super().__init__()
self.input_resolution = input_resolution
self.dim = dim
self.reduction = nn.Linear(4 * dim, 2 * dim, bias=False)
self.norm = norm_layer(4 * dim)
def forward(self, x):
"""
x: B, H*W, C
"""
H, W = self.input_resolution
B, L, C = x.shape
assert L == H * W, "input feature has wrong size"
assert H % 2 == 0 and W % 2 == 0, f"x size ({H}*{W}) are not even."
x = x.view(B, H, W, C)
x0 = x[:, 0::2, 0::2, :] # B H/2 W/2 C
x1 = x[:, 1::2, 0::2, :] # B H/2 W/2 C
x2 = x[:, 0::2, 1::2, :] # B H/2 W/2 C
x3 = x[:, 1::2, 1::2, :] # B H/2 W/2 C
x = torch.cat([x0, x1, x2, x3], -1) # B H/2 W/2 4*C
x = x.view(B, -1, 4 * C) # B H/2*W/2 4*C
x = self.norm(x)
x = self.reduction(x)
return x
def extra_repr(self) -> str:
return f"input_resolution={self.input_resolution}, dim={self.dim}"
def flops(self):
H, W = self.input_resolution
flops = H * W * self.dim
flops += (H // 2) * (W // 2) * 4 * self.dim * 2 * self.dim
return flops
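# Shape sketch for PatchMerging (illustrative): resolution halves, channels double.
#     pm = PatchMerging(input_resolution=(8, 8), dim=96)
#     x = torch.randn(1, 64, 96)                   # B, H*W, C
#     y = pm(x)                                    # -> (1, 16, 192): B, H/2*W/2, 2C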
class BasicLayer(nn.Module):
""" A basic Swin Transformer layer for one stage.
Args:
dim (int): Number of input channels.
input_resolution (tuple[int]): Input resolution.
depth (int): Number of blocks.
num_heads (int): Number of attention heads.
window_size (int): Local window size.
mlp_ratio (float): Ratio of mlp hidden dim to embedding dim.
qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True
qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set.
drop (float, optional): Dropout rate. Default: 0.0
attn_drop (float, optional): Attention dropout rate. Default: 0.0
drop_path (float | tuple[float], optional): Stochastic depth rate. Default: 0.0
norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm
downsample (nn.Module | None, optional): Downsample layer at the end of the layer. Default: None
use_checkpoint (bool): Whether to use checkpointing to save memory. Default: False.
"""
def __init__(self, dim, input_resolution, depth, num_heads, window_size,
mlp_ratio=4., qkv_bias=True, qk_scale=None, drop=0., attn_drop=0.,
drop_path=0., norm_layer=nn.LayerNorm, downsample=None, use_checkpoint=False):
super().__init__()
self.dim = dim
self.input_resolution = input_resolution
self.depth = depth
self.use_checkpoint = use_checkpoint
# build blocks
self.blocks = nn.ModuleList([
SwinTransformerBlock(dim=dim, input_resolution=input_resolution,
num_heads=num_heads, window_size=window_size,
shift_size=0 if (i % 2 == 0) else window_size // 2,
mlp_ratio=mlp_ratio,
qkv_bias=qkv_bias, qk_scale=qk_scale,
drop=drop, attn_drop=attn_drop,
drop_path=drop_path[i] if isinstance(drop_path, list) else drop_path,
norm_layer=norm_layer)
for i in range(depth)])
# patch merging layer
if downsample is not None:
self.downsample = downsample(input_resolution, dim=dim, norm_layer=norm_layer)
else:
self.downsample = None
def forward(self, x, x_size):
for blk in self.blocks:
if self.use_checkpoint:
x = checkpoint.checkpoint(blk, x, x_size)
else:
x = blk(x, x_size)
if self.downsample is not None:
x = self.downsample(x)
return x
def extra_repr(self) -> str:
return f"dim={self.dim}, input_resolution={self.input_resolution}, depth={self.depth}"
def flops(self):
flops = 0
for blk in self.blocks:
flops += blk.flops()
if self.downsample is not None:
flops += self.downsample.flops()
return flops
class RSTB(nn.Module):
"""Residual Swin Transformer Block (RSTB).
Args:
dim (int): Number of input channels.
input_resolution (tuple[int]): Input resolution.
depth (int): Number of blocks.
num_heads (int): Number of attention heads.
window_size (int): Local window size.
mlp_ratio (float): Ratio of mlp hidden dim to embedding dim.
qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True
qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set.
drop (float, optional): Dropout rate. Default: 0.0
attn_drop (float, optional): Attention dropout rate. Default: 0.0
drop_path (float | tuple[float], optional): Stochastic depth rate. Default: 0.0
norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm
downsample (nn.Module | None, optional): Downsample layer at the end of the layer. Default: None
use_checkpoint (bool): Whether to use checkpointing to save memory. Default: False.
img_size: Input image size.
patch_size: Patch size.
resi_connection: The convolutional block before residual connection.
"""
def __init__(self, dim, input_resolution, depth, num_heads, window_size,
mlp_ratio=4., qkv_bias=True, qk_scale=None, drop=0., attn_drop=0.,
drop_path=0., norm_layer=nn.LayerNorm, downsample=None, use_checkpoint=False,
img_size=224, patch_size=4, resi_connection='1conv'):
super(RSTB, self).__init__()
self.dim = dim
self.input_resolution = input_resolution
self.residual_group = BasicLayer(dim=dim,
input_resolution=input_resolution,
depth=depth,
num_heads=num_heads,
window_size=window_size,
mlp_ratio=mlp_ratio,
qkv_bias=qkv_bias, qk_scale=qk_scale,
drop=drop, attn_drop=attn_drop,
drop_path=drop_path,
norm_layer=norm_layer,
downsample=downsample,
use_checkpoint=use_checkpoint)
if resi_connection == '1conv':
self.conv = nn.Conv2d(dim, dim, 3, 1, 1)
elif resi_connection == '3conv':
# to save parameters and memory
self.conv = nn.Sequential(nn.Conv2d(dim, dim // 4, 3, 1, 1), nn.LeakyReLU(negative_slope=0.2, inplace=True),
nn.Conv2d(dim // 4, dim // 4, 1, 1, 0),
nn.LeakyReLU(negative_slope=0.2, inplace=True),
nn.Conv2d(dim // 4, dim, 3, 1, 1))
self.patch_embed = PatchEmbed(
img_size=img_size, patch_size=patch_size, in_chans=0, embed_dim=dim,
norm_layer=None)
self.patch_unembed = PatchUnEmbed(
img_size=img_size, patch_size=patch_size, in_chans=0, embed_dim=dim,
norm_layer=None)
def forward(self, x, x_size):
return self.patch_embed(self.conv(self.patch_unembed(self.residual_group(x, x_size), x_size))) + x
def flops(self):
flops = 0
flops += self.residual_group.flops()
H, W = self.input_resolution
flops += H * W * self.dim * self.dim * 9
flops += self.patch_embed.flops()
flops += self.patch_unembed.flops()
return flops
class PatchEmbed(nn.Module):
r""" Image to Patch Embedding
Args:
img_size (int): Image size. Default: 224.
patch_size (int): Patch token size. Default: 4.
in_chans (int): Number of input image channels. Default: 3.
embed_dim (int): Number of linear projection output channels. Default: 96.
norm_layer (nn.Module, optional): Normalization layer. Default: None
"""
def __init__(self, img_size=224, patch_size=4, in_chans=3, embed_dim=96, norm_layer=None):
super().__init__()
img_size = to_2tuple(img_size)
patch_size = to_2tuple(patch_size)
patches_resolution = [img_size[0] // patch_size[0], img_size[1] // patch_size[1]]
self.img_size = img_size
self.patch_size = patch_size
self.patches_resolution = patches_resolution
self.num_patches = patches_resolution[0] * patches_resolution[1]
self.in_chans = in_chans
self.embed_dim = embed_dim
if norm_layer is not None:
self.norm = norm_layer(embed_dim)
else:
self.norm = None
def forward(self, x):
x = x.flatten(2).transpose(1, 2) # B Ph*Pw C
if self.norm is not None:
x = self.norm(x)
return x
def flops(self):
flops = 0
H, W = self.img_size
if self.norm is not None:
flops += H * W * self.embed_dim
return flops
class PatchUnEmbed(nn.Module):
r""" Image to Patch Unembedding
Args:
img_size (int): Image size. Default: 224.
patch_size (int): Patch token size. Default: 4.
in_chans (int): Number of input image channels. Default: 3.
embed_dim (int): Number of linear projection output channels. Default: 96.
norm_layer (nn.Module, optional): Normalization layer. Default: None
"""
def __init__(self, img_size=224, patch_size=4, in_chans=3, embed_dim=96, norm_layer=None):
super().__init__()
img_size = to_2tuple(img_size)
patch_size = to_2tuple(patch_size)
patches_resolution = [img_size[0] // patch_size[0], img_size[1] // patch_size[1]]
self.img_size = img_size
self.patch_size = patch_size
self.patches_resolution = patches_resolution
self.num_patches = patches_resolution[0] * patches_resolution[1]
self.in_chans = in_chans
self.embed_dim = embed_dim
def forward(self, x, x_size):
B, HW, C = x.shape
x = x.transpose(1, 2).view(B, self.embed_dim, x_size[0], x_size[1]) # B, C, Ph, Pw
return x
def flops(self):
flops = 0
return flops
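# PatchEmbed and PatchUnEmbed are inverse reshapes of each other
# (illustrative round trip, not from the original file):
#     pe = PatchEmbed(embed_dim=96)
#     pu = PatchUnEmbed(embed_dim=96)
#     x = torch.randn(1, 96, 16, 16)               # B, C, H, W
#     tokens = pe(x)                               # -> (1, 256, 96)
#     assert pu(tokens, (16, 16)).shape == x.shape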
class Upsample(nn.Sequential):
"""Upsample module.
Args:
scale (int): Scale factor. Supported scales: 2^n and 3.
num_feat (int): Channel number of intermediate features.
"""
def __init__(self, scale, num_feat):
m = []
if (scale & (scale - 1)) == 0: # scale = 2^n
for _ in range(int(math.log(scale, 2))):
m.append(nn.Conv2d(num_feat, 4 * num_feat, 3, 1, 1))
m.append(nn.PixelShuffle(2))
elif scale == 3:
m.append(nn.Conv2d(num_feat, 9 * num_feat, 3, 1, 1))
m.append(nn.PixelShuffle(3))
else:
raise ValueError(f'scale {scale} is not supported. Supported scales: 2^n and 3.')
super(Upsample, self).__init__(*m)
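# Example (assumed values): Upsample(4, 64) stacks two (Conv2d(64, 256) +
# PixelShuffle(2)) stages, so an input of shape (1, 64, H, W) comes out as
# (1, 64, 4H, 4W); Upsample(3, 64) uses a single Conv2d(64, 576) + PixelShuffle(3).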
class UpsampleOneStep(nn.Sequential):
"""UpsampleOneStep module (the difference with Upsample is that it always only has 1conv + 1pixelshuffle)
Used in lightweight SR to save parameters.
Args:
scale (int): Scale factor. Supported scales: 2^n and 3.
num_feat (int): Channel number of intermediate features.
"""
def __init__(self, scale, num_feat, num_out_ch, input_resolution=None):
self.num_feat = num_feat
self.input_resolution = input_resolution
m = []
m.append(nn.Conv2d(num_feat, (scale ** 2) * num_out_ch, 3, 1, 1))
m.append(nn.PixelShuffle(scale))
super(UpsampleOneStep, self).__init__(*m)
def flops(self):
H, W = self.input_resolution
flops = H * W * self.num_feat * 3 * 9
return flops
class SwinIR(nn.Module):
r""" SwinIR
A PyTorch implementation of `SwinIR: Image Restoration Using Swin Transformer`, based on the Swin Transformer.
Args:
img_size (int | tuple(int)): Input image size. Default 64
patch_size (int | tuple(int)): Patch size. Default: 1
in_chans (int): Number of input image channels. Default: 3
embed_dim (int): Patch embedding dimension. Default: 96
depths (tuple(int)): Depth of each Swin Transformer layer.
num_heads (tuple(int)): Number of attention heads in different layers.
window_size (int): Window size. Default: 7
mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. Default: 4
qkv_bias (bool): If True, add a learnable bias to query, key, value. Default: True
qk_scale (float): Override default qk scale of head_dim ** -0.5 if set. Default: None
drop_rate (float): Dropout rate. Default: 0
attn_drop_rate (float): Attention dropout rate. Default: 0
drop_path_rate (float): Stochastic depth rate. Default: 0.1
norm_layer (nn.Module): Normalization layer. Default: nn.LayerNorm.
ape (bool): If True, add absolute position embedding to the patch embedding. Default: False
patch_norm (bool): If True, add normalization after patch embedding. Default: True
use_checkpoint (bool): Whether to use checkpointing to save memory. Default: False
upscale: Upscale factor. 2/3/4/8 for image SR, 1 for denoising and compression artifact reduction
img_range: Image range. 1. or 255.
upsampler: The reconstruction module. 'pixelshuffle'/'pixelshuffledirect'/'nearest+conv'/None
resi_connection: The convolutional block before residual connection. '1conv'/'3conv'
"""
def __init__(self, img_size=64, patch_size=1, in_chans=3,
embed_dim=96, depths=(6, 6, 6, 6), num_heads=(6, 6, 6, 6),
window_size=7, mlp_ratio=4., qkv_bias=True, qk_scale=None,
drop_rate=0., attn_drop_rate=0., drop_path_rate=0.1,
norm_layer=nn.LayerNorm, ape=False, patch_norm=True,
use_checkpoint=False, upscale=2, img_range=1., upsampler='', resi_connection='1conv',
**kwargs):
super(SwinIR, self).__init__()
num_in_ch = in_chans
num_out_ch = in_chans
num_feat = 64
self.img_range = img_range
if in_chans == 3:
rgb_mean = (0.4488, 0.4371, 0.4040)
self.mean = torch.Tensor(rgb_mean).view(1, 3, 1, 1)
else:
self.mean = torch.zeros(1, 1, 1, 1)
self.upscale = upscale
self.upsampler = upsampler
self.window_size = window_size
#####################################################################################################
################################### 1, shallow feature extraction ###################################
self.conv_first = nn.Conv2d(num_in_ch, embed_dim, 3, 1, 1)
#####################################################################################################
################################### 2, deep feature extraction ######################################
self.num_layers = len(depths)
self.embed_dim = embed_dim
self.ape = ape
self.patch_norm = patch_norm
self.num_features = embed_dim
self.mlp_ratio = mlp_ratio
# split image into non-overlapping patches
self.patch_embed = PatchEmbed(
img_size=img_size, patch_size=patch_size, in_chans=embed_dim, embed_dim=embed_dim,
norm_layer=norm_layer if self.patch_norm else None)
num_patches = self.patch_embed.num_patches
patches_resolution = self.patch_embed.patches_resolution
self.patches_resolution = patches_resolution
# merge non-overlapping patches into image
self.patch_unembed = PatchUnEmbed(
img_size=img_size, patch_size=patch_size, in_chans=embed_dim, embed_dim=embed_dim,
norm_layer=norm_layer if self.patch_norm else None)
# absolute position embedding
if self.ape:
self.absolute_pos_embed = nn.Parameter(torch.zeros(1, num_patches, embed_dim))
trunc_normal_(self.absolute_pos_embed, std=.02)
self.pos_drop = nn.Dropout(p=drop_rate)
# stochastic depth
dpr = [x.item() for x in torch.linspace(0, drop_path_rate, sum(depths))] # stochastic depth decay rule
# build Residual Swin Transformer blocks (RSTB)
self.layers = nn.ModuleList()
for i_layer in range(self.num_layers):
layer = RSTB(dim=embed_dim,
input_resolution=(patches_resolution[0],
patches_resolution[1]),
depth=depths[i_layer],
num_heads=num_heads[i_layer],
window_size=window_size,
mlp_ratio=self.mlp_ratio,
qkv_bias=qkv_bias, qk_scale=qk_scale,
drop=drop_rate, attn_drop=attn_drop_rate,
drop_path=dpr[sum(depths[:i_layer]):sum(depths[:i_layer + 1])], # no impact on SR results
norm_layer=norm_layer,
downsample=None,
use_checkpoint=use_checkpoint,
img_size=img_size,
patch_size=patch_size,
resi_connection=resi_connection
)
self.layers.append(layer)
self.norm = norm_layer(self.num_features)
# build the last conv layer in deep feature extraction
if resi_connection == '1conv':
self.conv_after_body = nn.Conv2d(embed_dim, embed_dim, 3, 1, 1)
elif resi_connection == '3conv':
# to save parameters and memory
self.conv_after_body = nn.Sequential(nn.Conv2d(embed_dim, embed_dim // 4, 3, 1, 1),
nn.LeakyReLU(negative_slope=0.2, inplace=True),
nn.Conv2d(embed_dim // 4, embed_dim // 4, 1, 1, 0),
nn.LeakyReLU(negative_slope=0.2, inplace=True),
nn.Conv2d(embed_dim // 4, embed_dim, 3, 1, 1))
#####################################################################################################
################################ 3, high quality image reconstruction ################################
if self.upsampler == 'pixelshuffle':
# for classical SR
self.conv_before_upsample = nn.Sequential(nn.Conv2d(embed_dim, num_feat, 3, 1, 1),
nn.LeakyReLU(inplace=True))
self.upsample = Upsample(upscale, num_feat)
self.conv_last = nn.Conv2d(num_feat, num_out_ch, 3, 1, 1)
elif self.upsampler == 'pixelshuffledirect':
# for lightweight SR (to save parameters)
self.upsample = UpsampleOneStep(upscale, embed_dim, num_out_ch,
(patches_resolution[0], patches_resolution[1]))
elif self.upsampler == 'nearest+conv':
# for real-world SR (less artifacts)
self.conv_before_upsample = nn.Sequential(nn.Conv2d(embed_dim, num_feat, 3, 1, 1),
nn.LeakyReLU(inplace=True))
self.conv_up1 = nn.Conv2d(num_feat, num_feat, 3, 1, 1)
if self.upscale == 4:
self.conv_up2 = nn.Conv2d(num_feat, num_feat, 3, 1, 1)
self.conv_hr = nn.Conv2d(num_feat, num_feat, 3, 1, 1)
self.conv_last = nn.Conv2d(num_feat, num_out_ch, 3, 1, 1)
self.lrelu = nn.LeakyReLU(negative_slope=0.2, inplace=True)
else:
# for image denoising and JPEG compression artifact reduction
self.conv_last = nn.Conv2d(embed_dim, num_out_ch, 3, 1, 1)
self.apply(self._init_weights)
def _init_weights(self, m):
if isinstance(m, nn.Linear):
trunc_normal_(m.weight, std=.02)
if isinstance(m, nn.Linear) and m.bias is not None:
nn.init.constant_(m.bias, 0)
elif isinstance(m, nn.LayerNorm):
nn.init.constant_(m.bias, 0)
nn.init.constant_(m.weight, 1.0)
@torch.jit.ignore
def no_weight_decay(self):
return {'absolute_pos_embed'}
@torch.jit.ignore
def no_weight_decay_keywords(self):
return {'relative_position_bias_table'}
def check_image_size(self, x):
_, _, h, w = x.size()
mod_pad_h = (self.window_size - h % self.window_size) % self.window_size
mod_pad_w = (self.window_size - w % self.window_size) % self.window_size
x = F.pad(x, (0, mod_pad_w, 0, mod_pad_h), 'reflect')
return x
def forward_features(self, x):
x_size = (x.shape[2], x.shape[3])
x = self.patch_embed(x)
if self.ape:
x = x + self.absolute_pos_embed
x = self.pos_drop(x)
for layer in self.layers:
x = layer(x, x_size)
x = self.norm(x) # B L C
x = self.patch_unembed(x, x_size)
return x
def forward(self, x):
H, W = x.shape[2:]
x = self.check_image_size(x)
self.mean = self.mean.type_as(x)
x = (x - self.mean) * self.img_range
if self.upsampler == 'pixelshuffle':
# for classical SR
x = self.conv_first(x)
x = self.conv_after_body(self.forward_features(x)) + x
x = self.conv_before_upsample(x)
x = self.conv_last(self.upsample(x))
elif self.upsampler == 'pixelshuffledirect':
# for lightweight SR
x = self.conv_first(x)
x = self.conv_after_body(self.forward_features(x)) + x
x = self.upsample(x)
elif self.upsampler == 'nearest+conv':
# for real-world SR
x = self.conv_first(x)
x = self.conv_after_body(self.forward_features(x)) + x
x = self.conv_before_upsample(x)
x = self.lrelu(self.conv_up1(torch.nn.functional.interpolate(x, scale_factor=2, mode='nearest')))
if self.upscale == 4:
x = self.lrelu(self.conv_up2(torch.nn.functional.interpolate(x, scale_factor=2, mode='nearest')))
x = self.conv_last(self.lrelu(self.conv_hr(x)))
else:
# for image denoising and JPEG compression artifact reduction
x_first = self.conv_first(x)
res = self.conv_after_body(self.forward_features(x_first)) + x_first
x = x + self.conv_last(res)
x = x / self.img_range + self.mean
return x[:, :, :H*self.upscale, :W*self.upscale]
def flops(self):
flops = 0
H, W = self.patches_resolution
flops += H * W * 3 * self.embed_dim * 9
flops += self.patch_embed.flops()
for layer in self.layers:
flops += layer.flops()
flops += H * W * 3 * self.embed_dim * self.embed_dim
flops += self.upsample.flops()
return flops
if __name__ == '__main__':
upscale = 4
window_size = 8
height = (1024 // upscale // window_size + 1) * window_size
width = (720 // upscale // window_size + 1) * window_size
model = SwinIR(upscale=upscale, img_size=(height, width),
window_size=window_size, img_range=1., depths=[6, 6, 6, 6],
embed_dim=60, num_heads=[6, 6, 6, 6], mlp_ratio=2, upsampler='pixelshuffledirect')
print(model)
print(height, width, model.flops() / 1e9)
x = torch.randn((1, 3, height, width))
x = model(x)
print(x.shape)
File diff suppressed because it is too large
@@ -218,6 +218,8 @@ onUiLoaded(async() => {
canvas_hotkey_fullscreen: "KeyS",
canvas_hotkey_move: "KeyF",
canvas_hotkey_overlap: "KeyO",
+canvas_hotkey_shrink_brush: "KeyQ",
+canvas_hotkey_grow_brush: "KeyW",
canvas_disabled_functions: [],
canvas_show_tooltip: true,
canvas_auto_expand: true,
@@ -227,6 +229,8 @@ onUiLoaded(async() => {
const functionMap = {
"Zoom": "canvas_hotkey_zoom",
"Adjust brush size": "canvas_hotkey_adjust",
+"Hotkey shrink brush": "canvas_hotkey_shrink_brush",
+"Hotkey enlarge brush": "canvas_hotkey_grow_brush",
"Moving canvas": "canvas_hotkey_move",
"Fullscreen": "canvas_hotkey_fullscreen",
"Reset Zoom": "canvas_hotkey_reset",
@@ -686,7 +690,9 @@ onUiLoaded(async() => {
const hotkeyActions = {
[hotkeysConfig.canvas_hotkey_reset]: resetZoom,
[hotkeysConfig.canvas_hotkey_overlap]: toggleOverlap,
-[hotkeysConfig.canvas_hotkey_fullscreen]: fitToScreen
+[hotkeysConfig.canvas_hotkey_fullscreen]: fitToScreen,
+[hotkeysConfig.canvas_hotkey_shrink_brush]: () => adjustBrushSize(elemId, 10),
+[hotkeysConfig.canvas_hotkey_grow_brush]: () => adjustBrushSize(elemId, -10)
};
const action = hotkeyActions[event.code];
@@ -4,6 +4,8 @@ from modules import shared
shared.options_templates.update(shared.options_section(('canvas_hotkey', "Canvas Hotkeys"), {
"canvas_hotkey_zoom": shared.OptionInfo("Alt", "Zoom canvas", gr.Radio, {"choices": ["Shift","Ctrl", "Alt"]}).info("If you choose 'Shift' you cannot scroll horizontally, 'Alt' can cause a little trouble in firefox"),
"canvas_hotkey_adjust": shared.OptionInfo("Ctrl", "Adjust brush size", gr.Radio, {"choices": ["Shift","Ctrl", "Alt"]}).info("If you choose 'Shift' you cannot scroll horizontally, 'Alt' can cause a little trouble in firefox"),
+"canvas_hotkey_shrink_brush": shared.OptionInfo("Q", "Shrink the brush size"),
+"canvas_hotkey_grow_brush": shared.OptionInfo("W", "Enlarge the brush size"),
"canvas_hotkey_move": shared.OptionInfo("F", "Moving the canvas").info("To work correctly in firefox, turn off 'Automatically search the page text when typing' in the browser settings"),
"canvas_hotkey_fullscreen": shared.OptionInfo("S", "Fullscreen Mode, maximizes the picture so that it fits into the screen and stretches it to its full width "),
"canvas_hotkey_reset": shared.OptionInfo("R", "Reset zoom and canvas positon"),
@@ -11,5 +13,5 @@ shared.options_templates.update(shared.options_section(('canvas_hotkey', "Canvas
"canvas_show_tooltip": shared.OptionInfo(True, "Enable tooltip on the canvas"), "canvas_show_tooltip": shared.OptionInfo(True, "Enable tooltip on the canvas"),
"canvas_auto_expand": shared.OptionInfo(True, "Automatically expands an image that does not fit completely in the canvas area, similar to manually pressing the S and R buttons"), "canvas_auto_expand": shared.OptionInfo(True, "Automatically expands an image that does not fit completely in the canvas area, similar to manually pressing the S and R buttons"),
"canvas_blur_prompt": shared.OptionInfo(False, "Take the focus off the prompt when working with a canvas"), "canvas_blur_prompt": shared.OptionInfo(False, "Take the focus off the prompt when working with a canvas"),
"canvas_disabled_functions": shared.OptionInfo(["Overlap"], "Disable function that you don't use", gr.CheckboxGroup, {"choices": ["Zoom","Adjust brush size", "Moving canvas","Fullscreen","Reset Zoom","Overlap"]}), "canvas_disabled_functions": shared.OptionInfo(["Overlap"], "Disable function that you don't use", gr.CheckboxGroup, {"choices": ["Zoom","Adjust brush size","Hotkey enlarge brush","Hotkey shrink brush","Moving canvas","Fullscreen","Reset Zoom","Overlap"]}),
})) }))
@@ -1,7 +1,7 @@
import math
import gradio as gr
-from modules import scripts, shared, ui_components, ui_settings, generation_parameters_copypaste
+from modules import scripts, shared, ui_components, ui_settings, infotext_utils
from modules.ui_components import FormColumn
@@ -23,11 +23,12 @@ class ExtraOptionsSection(scripts.Script):
self.setting_names = []
self.infotext_fields = []
extra_options = shared.opts.extra_options_img2img if is_img2img else shared.opts.extra_options_txt2img
+elem_id_tabname = "extra_options_" + ("img2img" if is_img2img else "txt2img")
-mapping = {k: v for v, k in generation_parameters_copypaste.infotext_to_setting_name_mapping}
+mapping = {k: v for v, k in infotext_utils.infotext_to_setting_name_mapping}
with gr.Blocks() as interface:
-with gr.Accordion("Options", open=False) if shared.opts.extra_options_accordion and extra_options else gr.Group():
+with gr.Accordion("Options", open=False, elem_id=elem_id_tabname) if shared.opts.extra_options_accordion and extra_options else gr.Group(elem_id=elem_id_tabname):
row_count = math.ceil(len(extra_options) / shared.opts.extra_options_cols)
@@ -64,11 +65,14 @@ class ExtraOptionsSection(scripts.Script):
p.override_settings[name] = value
-shared.options_templates.update(shared.options_section(('ui', "User interface"), {
+shared.options_templates.update(shared.options_section(('settings_in_ui', "Settings in UI", "ui"), {
+"settings_in_ui": shared.OptionHTML("""
+This page allows you to add some settings to the main interface of txt2img and img2img tabs.
+"""),
-"extra_options_txt2img": shared.OptionInfo([], "Options in main UI - txt2img", ui_components.DropdownMulti, lambda: {"choices": list(shared.opts.data_labels.keys())}).js("info", "settingsHintsShowQuicksettings").info("setting entries that also appear in txt2img interfaces").needs_reload_ui(),
+"extra_options_txt2img": shared.OptionInfo([], "Settings for txt2img", ui_components.DropdownMulti, lambda: {"choices": list(shared.opts.data_labels.keys())}).js("info", "settingsHintsShowQuicksettings").info("setting entries that also appear in txt2img interfaces").needs_reload_ui(),
-"extra_options_img2img": shared.OptionInfo([], "Options in main UI - img2img", ui_components.DropdownMulti, lambda: {"choices": list(shared.opts.data_labels.keys())}).js("info", "settingsHintsShowQuicksettings").info("setting entries that also appear in img2img interfaces").needs_reload_ui(),
+"extra_options_img2img": shared.OptionInfo([], "Settings for img2img", ui_components.DropdownMulti, lambda: {"choices": list(shared.opts.data_labels.keys())}).js("info", "settingsHintsShowQuicksettings").info("setting entries that also appear in img2img interfaces").needs_reload_ui(),
-"extra_options_cols": shared.OptionInfo(1, "Options in main UI - number of columns", gr.Number, {"precision": 0}).needs_reload_ui(),
+"extra_options_cols": shared.OptionInfo(1, "Number of columns for added settings", gr.Slider, {"step": 1, "minimum": 1, "maximum": 20}).info("displayed amount will depend on the actual browser window width").needs_reload_ui(),
-"extra_options_accordion": shared.OptionInfo(False, "Options in main UI - place into an accordion").needs_reload_ui()
+"extra_options_accordion": shared.OptionInfo(False, "Place added settings into an accordion").needs_reload_ui()
}))
@@ -0,0 +1,351 @@
"""
Hypertile module for splitting attention layers in SD-1.5 U-Net and SD-1.5 VAE
Warning: the patch works well only if the input image has a width and height that are multiples of 128
Original author: @tfernd Github: https://github.com/tfernd/HyperTile
"""
from __future__ import annotations
from dataclasses import dataclass
from typing import Callable
from functools import wraps, cache
import math
import torch.nn as nn
import random
from einops import rearrange
@dataclass
class HypertileParams:
depth: int = 0
layer_name: str = ""
tile_size: int = 0
swap_size: int = 0
aspect_ratio: float = 1.0
forward: Callable | None = None
enabled: bool = False
# TODO add SD-XL layers
DEPTH_LAYERS = {
0: [
# SD 1.5 U-Net (diffusers)
"down_blocks.0.attentions.0.transformer_blocks.0.attn1",
"down_blocks.0.attentions.1.transformer_blocks.0.attn1",
"up_blocks.3.attentions.0.transformer_blocks.0.attn1",
"up_blocks.3.attentions.1.transformer_blocks.0.attn1",
"up_blocks.3.attentions.2.transformer_blocks.0.attn1",
# SD 1.5 U-Net (ldm)
"input_blocks.1.1.transformer_blocks.0.attn1",
"input_blocks.2.1.transformer_blocks.0.attn1",
"output_blocks.9.1.transformer_blocks.0.attn1",
"output_blocks.10.1.transformer_blocks.0.attn1",
"output_blocks.11.1.transformer_blocks.0.attn1",
# SD 1.5 VAE
"decoder.mid_block.attentions.0",
"decoder.mid.attn_1",
],
1: [
# SD 1.5 U-Net (diffusers)
"down_blocks.1.attentions.0.transformer_blocks.0.attn1",
"down_blocks.1.attentions.1.transformer_blocks.0.attn1",
"up_blocks.2.attentions.0.transformer_blocks.0.attn1",
"up_blocks.2.attentions.1.transformer_blocks.0.attn1",
"up_blocks.2.attentions.2.transformer_blocks.0.attn1",
# SD 1.5 U-Net (ldm)
"input_blocks.4.1.transformer_blocks.0.attn1",
"input_blocks.5.1.transformer_blocks.0.attn1",
"output_blocks.6.1.transformer_blocks.0.attn1",
"output_blocks.7.1.transformer_blocks.0.attn1",
"output_blocks.8.1.transformer_blocks.0.attn1",
],
2: [
# SD 1.5 U-Net (diffusers)
"down_blocks.2.attentions.0.transformer_blocks.0.attn1",
"down_blocks.2.attentions.1.transformer_blocks.0.attn1",
"up_blocks.1.attentions.0.transformer_blocks.0.attn1",
"up_blocks.1.attentions.1.transformer_blocks.0.attn1",
"up_blocks.1.attentions.2.transformer_blocks.0.attn1",
# SD 1.5 U-Net (ldm)
"input_blocks.7.1.transformer_blocks.0.attn1",
"input_blocks.8.1.transformer_blocks.0.attn1",
"output_blocks.3.1.transformer_blocks.0.attn1",
"output_blocks.4.1.transformer_blocks.0.attn1",
"output_blocks.5.1.transformer_blocks.0.attn1",
],
3: [
# SD 1.5 U-Net (diffusers)
"mid_block.attentions.0.transformer_blocks.0.attn1",
# SD 1.5 U-Net (ldm)
"middle_block.1.transformer_blocks.0.attn1",
],
}
# XL layers, thanks to GitHub user @gel-crabs for the help
DEPTH_LAYERS_XL = {
0: [
# SD 1.5 U-Net (diffusers)
"down_blocks.0.attentions.0.transformer_blocks.0.attn1",
"down_blocks.0.attentions.1.transformer_blocks.0.attn1",
"up_blocks.3.attentions.0.transformer_blocks.0.attn1",
"up_blocks.3.attentions.1.transformer_blocks.0.attn1",
"up_blocks.3.attentions.2.transformer_blocks.0.attn1",
# SD 1.5 U-Net (ldm)
"input_blocks.4.1.transformer_blocks.0.attn1",
"input_blocks.5.1.transformer_blocks.0.attn1",
"output_blocks.3.1.transformer_blocks.0.attn1",
"output_blocks.4.1.transformer_blocks.0.attn1",
"output_blocks.5.1.transformer_blocks.0.attn1",
# SD 1.5 VAE
"decoder.mid_block.attentions.0",
"decoder.mid.attn_1",
],
1: [
# SD 1.5 U-Net (diffusers)
#"down_blocks.1.attentions.0.transformer_blocks.0.attn1",
#"down_blocks.1.attentions.1.transformer_blocks.0.attn1",
#"up_blocks.2.attentions.0.transformer_blocks.0.attn1",
#"up_blocks.2.attentions.1.transformer_blocks.0.attn1",
#"up_blocks.2.attentions.2.transformer_blocks.0.attn1",
# SD 1.5 U-Net (ldm)
"input_blocks.4.1.transformer_blocks.1.attn1",
"input_blocks.5.1.transformer_blocks.1.attn1",
"output_blocks.3.1.transformer_blocks.1.attn1",
"output_blocks.4.1.transformer_blocks.1.attn1",
"output_blocks.5.1.transformer_blocks.1.attn1",
"input_blocks.7.1.transformer_blocks.0.attn1",
"input_blocks.8.1.transformer_blocks.0.attn1",
"output_blocks.0.1.transformer_blocks.0.attn1",
"output_blocks.1.1.transformer_blocks.0.attn1",
"output_blocks.2.1.transformer_blocks.0.attn1",
"input_blocks.7.1.transformer_blocks.1.attn1",
"input_blocks.8.1.transformer_blocks.1.attn1",
"output_blocks.0.1.transformer_blocks.1.attn1",
"output_blocks.1.1.transformer_blocks.1.attn1",
"output_blocks.2.1.transformer_blocks.1.attn1",
"input_blocks.7.1.transformer_blocks.2.attn1",
"input_blocks.8.1.transformer_blocks.2.attn1",
"output_blocks.0.1.transformer_blocks.2.attn1",
"output_blocks.1.1.transformer_blocks.2.attn1",
"output_blocks.2.1.transformer_blocks.2.attn1",
"input_blocks.7.1.transformer_blocks.3.attn1",
"input_blocks.8.1.transformer_blocks.3.attn1",
"output_blocks.0.1.transformer_blocks.3.attn1",
"output_blocks.1.1.transformer_blocks.3.attn1",
"output_blocks.2.1.transformer_blocks.3.attn1",
"input_blocks.7.1.transformer_blocks.4.attn1",
"input_blocks.8.1.transformer_blocks.4.attn1",
"output_blocks.0.1.transformer_blocks.4.attn1",
"output_blocks.1.1.transformer_blocks.4.attn1",
"output_blocks.2.1.transformer_blocks.4.attn1",
"input_blocks.7.1.transformer_blocks.5.attn1",
"input_blocks.8.1.transformer_blocks.5.attn1",
"output_blocks.0.1.transformer_blocks.5.attn1",
"output_blocks.1.1.transformer_blocks.5.attn1",
"output_blocks.2.1.transformer_blocks.5.attn1",
"input_blocks.7.1.transformer_blocks.6.attn1",
"input_blocks.8.1.transformer_blocks.6.attn1",
"output_blocks.0.1.transformer_blocks.6.attn1",
"output_blocks.1.1.transformer_blocks.6.attn1",
"output_blocks.2.1.transformer_blocks.6.attn1",
"input_blocks.7.1.transformer_blocks.7.attn1",
"input_blocks.8.1.transformer_blocks.7.attn1",
"output_blocks.0.1.transformer_blocks.7.attn1",
"output_blocks.1.1.transformer_blocks.7.attn1",
"output_blocks.2.1.transformer_blocks.7.attn1",
"input_blocks.7.1.transformer_blocks.8.attn1",
"input_blocks.8.1.transformer_blocks.8.attn1",
"output_blocks.0.1.transformer_blocks.8.attn1",
"output_blocks.1.1.transformer_blocks.8.attn1",
"output_blocks.2.1.transformer_blocks.8.attn1",
"input_blocks.7.1.transformer_blocks.9.attn1",
"input_blocks.8.1.transformer_blocks.9.attn1",
"output_blocks.0.1.transformer_blocks.9.attn1",
"output_blocks.1.1.transformer_blocks.9.attn1",
"output_blocks.2.1.transformer_blocks.9.attn1",
],
2: [
# SD 1.5 U-Net (diffusers)
"mid_block.attentions.0.transformer_blocks.0.attn1",
# SD 1.5 U-Net (ldm)
"middle_block.1.transformer_blocks.0.attn1",
"middle_block.1.transformer_blocks.1.attn1",
"middle_block.1.transformer_blocks.2.attn1",
"middle_block.1.transformer_blocks.3.attn1",
"middle_block.1.transformer_blocks.4.attn1",
"middle_block.1.transformer_blocks.5.attn1",
"middle_block.1.transformer_blocks.6.attn1",
"middle_block.1.transformer_blocks.7.attn1",
"middle_block.1.transformer_blocks.8.attn1",
"middle_block.1.transformer_blocks.9.attn1",
],
3 : [] # TODO - separate layers for SD-XL
}
RNG_INSTANCE = random.Random()
@cache
def get_divisors(value: int, min_value: int, /, max_options: int = 1) -> list[int]:
"""
Returns the divisors x of value for which
x * min_value <= value
holds, in big -> small order; the number of divisors returned is limited by max_options
"""
max_options = max(1, max_options) # at least 1 option should be returned
min_value = min(min_value, value)
divisors = [i for i in range(min_value, value + 1) if value % i == 0] # divisors in small -> big order
ns = [value // i for i in divisors[:max_options]] # has at least 1 element # big -> small order
return ns
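# Worked example (hypothetical numbers): get_divisors(64, 16, max_options=2)
# finds the divisors of 64 that are >= 16, i.e. [16, 32, 64], and returns the
# complementary factors of the first two: [64 // 16, 64 // 32] == [4, 2].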
def random_divisor(value: int, min_value: int, /, max_options: int = 1) -> int:
"""
Returns a random divisor x of value such that
x * min_value <= value
holds; if max_options is 1, the behavior is deterministic
"""
ns = get_divisors(value, min_value, max_options=max_options) # get cached divisors
idx = RNG_INSTANCE.randint(0, len(ns) - 1)
return ns[idx]
def set_hypertile_seed(seed: int) -> None:
RNG_INSTANCE.seed(seed)
@cache
def largest_tile_size_available(width: int, height: int) -> int:
"""
Calculates the largest tile size available for a given width and height
Tile size is always a power of 2
"""
gcd = math.gcd(width, height)
largest_tile_size_available = 1
while gcd % (largest_tile_size_available * 2) == 0:
largest_tile_size_available *= 2
return largest_tile_size_available
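# Worked example (hypothetical numbers): largest_tile_size_available(1280, 768)
# -> gcd(1280, 768) == 256, and the largest power of 2 dividing 256 is 256.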
def iterative_closest_divisors(hw:int, aspect_ratio:float) -> tuple[int, int]:
"""
Finds h and w such that h*w = hw and h/w = aspect_ratio
We check all possible divisors of hw and return the closest to the aspect ratio
"""
divisors = [i for i in range(2, hw + 1) if hw % i == 0] # all divisors of hw
pairs = [(i, hw // i) for i in divisors] # all pairs of divisors of hw
ratios = [w/h for h, w in pairs] # all ratios of pairs of divisors of hw
closest_ratio = min(ratios, key=lambda x: abs(x - aspect_ratio)) # closest ratio to aspect_ratio
closest_pair = pairs[ratios.index(closest_ratio)] # closest pair of divisors to aspect_ratio
return closest_pair
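# Worked example (hypothetical numbers): iterative_closest_divisors(36, 1.0)
# considers divisor pairs (2, 18), (3, 12), (4, 9), (6, 6), ... and returns
# (6, 6), whose ratio is exactly 1.0.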
@cache
def find_hw_candidates(hw:int, aspect_ratio:float) -> tuple[int, int]:
"""
Finds h and w such that h*w = hw and h/w = aspect_ratio
"""
h, w = round(math.sqrt(hw * aspect_ratio)), round(math.sqrt(hw / aspect_ratio))
# find h and w such that h*w = hw and h/w = aspect_ratio
if h * w != hw:
w_candidate = hw / h
# check if w is an integer
if not w_candidate.is_integer():
h_candidate = hw / w
# check if h is an integer
if not h_candidate.is_integer():
return iterative_closest_divisors(hw, aspect_ratio)
else:
h = int(h_candidate)
else:
w = int(w_candidate)
return h, w
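# Worked example (hypothetical numbers): find_hw_candidates(48, 0.75) ->
# h = round(sqrt(48 * 0.75)) = 6, w = round(sqrt(48 / 0.75)) = 8; since
# 6 * 8 == 48, (6, 8) is returned directly without the iterative fallback.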
def self_attn_forward(params: HypertileParams, scale_depth=True) -> Callable:
@wraps(params.forward)
def wrapper(*args, **kwargs):
if not params.enabled:
return params.forward(*args, **kwargs)
latent_tile_size = max(128, params.tile_size) // 8
x = args[0]
# VAE
if x.ndim == 4:
b, c, h, w = x.shape
nh = random_divisor(h, latent_tile_size, params.swap_size)
nw = random_divisor(w, latent_tile_size, params.swap_size)
if nh * nw > 1:
x = rearrange(x, "b c (nh h) (nw w) -> (b nh nw) c h w", nh=nh, nw=nw) # split into nh * nw tiles
out = params.forward(x, *args[1:], **kwargs)
if nh * nw > 1:
out = rearrange(out, "(b nh nw) c h w -> b c (nh h) (nw w)", nh=nh, nw=nw)
# U-Net
else:
hw: int = x.size(1)
h, w = find_hw_candidates(hw, params.aspect_ratio)
assert h * w == hw, f"Invalid aspect ratio {params.aspect_ratio} for input of shape {x.shape}, hw={hw}, h={h}, w={w}"
factor = 2 ** params.depth if scale_depth else 1
nh = random_divisor(h, latent_tile_size * factor, params.swap_size)
nw = random_divisor(w, latent_tile_size * factor, params.swap_size)
if nh * nw > 1:
x = rearrange(x, "b (nh h nw w) c -> (b nh nw) (h w) c", h=h // nh, w=w // nw, nh=nh, nw=nw)
out = params.forward(x, *args[1:], **kwargs)
if nh * nw > 1:
out = rearrange(out, "(b nh nw) hw c -> b nh nw hw c", nh=nh, nw=nw)
out = rearrange(out, "b nh nw (h w) c -> b (nh h nw w) c", h=h // nh, w=w // nw)
return out
return wrapper
def hypertile_hook_model(model: nn.Module, width, height, *, enable=False, tile_size_max=128, swap_size=1, max_depth=3, is_sdxl=False):
hypertile_layers = getattr(model, "__webui_hypertile_layers", None)
if hypertile_layers is None:
if not enable:
return
hypertile_layers = {}
layers = DEPTH_LAYERS_XL if is_sdxl else DEPTH_LAYERS
for depth in range(4):
for layer_name, module in model.named_modules():
if any(layer_name.endswith(try_name) for try_name in layers[depth]):
params = HypertileParams()
module.__webui_hypertile_params = params
params.forward = module.forward
params.depth = depth
params.layer_name = layer_name
module.forward = self_attn_forward(params)
hypertile_layers[layer_name] = 1
model.__webui_hypertile_layers = hypertile_layers
aspect_ratio = width / height
tile_size = min(largest_tile_size_available(width, height), tile_size_max)
for layer_name, module in model.named_modules():
if layer_name in hypertile_layers:
params = module.__webui_hypertile_params
params.tile_size = tile_size
params.swap_size = swap_size
params.aspect_ratio = aspect_ratio
params.enabled = enable and params.depth <= max_depth
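# A minimal usage sketch (assumed values, not from the original file): hook a
# U-Net-like module so 1024x1024 generations use hypertile on depth-0 layers
# only; calling again with enable=False deactivates the tiling without unhooking.
#     hypertile_hook_model(unet, 1024, 1024, enable=True, tile_size_max=128, swap_size=2, max_depth=0)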
@@ -0,0 +1,109 @@
import hypertile
from modules import scripts, script_callbacks, shared
from scripts.hypertile_xyz import add_axis_options
class ScriptHypertile(scripts.Script):
name = "Hypertile"
def title(self):
return self.name
def show(self, is_img2img):
return scripts.AlwaysVisible
def process(self, p, *args):
hypertile.set_hypertile_seed(p.all_seeds[0])
configure_hypertile(p.width, p.height, enable_unet=shared.opts.hypertile_enable_unet)
self.add_infotext(p)
def before_hr(self, p, *args):
enable = shared.opts.hypertile_enable_unet_secondpass or shared.opts.hypertile_enable_unet
# exclusive hypertile seed for the second pass
if enable:
hypertile.set_hypertile_seed(p.all_seeds[0])
configure_hypertile(p.hr_upscale_to_x, p.hr_upscale_to_y, enable_unet=enable)
if enable and not shared.opts.hypertile_enable_unet:
p.extra_generation_params["Hypertile U-Net second pass"] = True
self.add_infotext(p, add_unet_params=True)
def add_infotext(self, p, add_unet_params=False):
def option(name):
value = getattr(shared.opts, name)
default_value = shared.opts.get_default(name)
return None if value == default_value else value
if shared.opts.hypertile_enable_unet:
p.extra_generation_params["Hypertile U-Net"] = True
if shared.opts.hypertile_enable_unet or add_unet_params:
p.extra_generation_params["Hypertile U-Net max depth"] = option('hypertile_max_depth_unet')
p.extra_generation_params["Hypertile U-Net max tile size"] = option('hypertile_max_tile_unet')
p.extra_generation_params["Hypertile U-Net swap size"] = option('hypertile_swap_size_unet')
if shared.opts.hypertile_enable_vae:
p.extra_generation_params["Hypertile VAE"] = True
p.extra_generation_params["Hypertile VAE max depth"] = option('hypertile_max_depth_vae')
p.extra_generation_params["Hypertile VAE max tile size"] = option('hypertile_max_tile_vae')
p.extra_generation_params["Hypertile VAE swap size"] = option('hypertile_swap_size_vae')
def configure_hypertile(width, height, enable_unet=True):
hypertile.hypertile_hook_model(
shared.sd_model.first_stage_model,
width,
height,
swap_size=shared.opts.hypertile_swap_size_vae,
max_depth=shared.opts.hypertile_max_depth_vae,
tile_size_max=shared.opts.hypertile_max_tile_vae,
enable=shared.opts.hypertile_enable_vae,
)
hypertile.hypertile_hook_model(
shared.sd_model.model,
width,
height,
swap_size=shared.opts.hypertile_swap_size_unet,
max_depth=shared.opts.hypertile_max_depth_unet,
tile_size_max=shared.opts.hypertile_max_tile_unet,
enable=enable_unet,
is_sdxl=shared.sd_model.is_sdxl
)
def on_ui_settings():
import gradio as gr
options = {
"hypertile_explanation": shared.OptionHTML("""
<a href='https://github.com/tfernd/HyperTile'>Hypertile</a> optimizes the self-attention layer within U-Net and VAE models,
speeding up computation by a factor of 1 to 4. The larger the generated image is, the greater the
benefit.
"""),
"hypertile_enable_unet": shared.OptionInfo(False, "Enable Hypertile U-Net", infotext="Hypertile U-Net").info("enables hypertile for all modes, including hires fix second pass; noticeable change in details of the generated picture"),
"hypertile_enable_unet_secondpass": shared.OptionInfo(False, "Enable Hypertile U-Net for hires fix second pass", infotext="Hypertile U-Net second pass").info("enables hypertile just for hires fix second pass - regardless of whether the above setting is enabled"),
"hypertile_max_depth_unet": shared.OptionInfo(3, "Hypertile U-Net max depth", gr.Slider, {"minimum": 0, "maximum": 3, "step": 1}, infotext="Hypertile U-Net max depth").info("larger = more neural network layers affected; minor effect on performance"),
"hypertile_max_tile_unet": shared.OptionInfo(256, "Hypertile U-Net max tile size", gr.Slider, {"minimum": 0, "maximum": 512, "step": 16}, infotext="Hypertile U-Net max tile size").info("larger = worse performance"),
"hypertile_swap_size_unet": shared.OptionInfo(3, "Hypertile U-Net swap size", gr.Slider, {"minimum": 0, "maximum": 64, "step": 1}, infotext="Hypertile U-Net swap size"),
"hypertile_enable_vae": shared.OptionInfo(False, "Enable Hypertile VAE", infotext="Hypertile VAE").info("minimal change in the generated picture"),
"hypertile_max_depth_vae": shared.OptionInfo(3, "Hypertile VAE max depth", gr.Slider, {"minimum": 0, "maximum": 3, "step": 1}, infotext="Hypertile VAE max depth"),
"hypertile_max_tile_vae": shared.OptionInfo(128, "Hypertile VAE max tile size", gr.Slider, {"minimum": 0, "maximum": 512, "step": 16}, infotext="Hypertile VAE max tile size"),
"hypertile_swap_size_vae": shared.OptionInfo(3, "Hypertile VAE swap size ", gr.Slider, {"minimum": 0, "maximum": 64, "step": 1}, infotext="Hypertile VAE swap size"),
}
for name, opt in options.items():
opt.section = ('hypertile', "Hypertile")
shared.opts.add_option(name, opt)
script_callbacks.on_ui_settings(on_ui_settings)
script_callbacks.on_before_ui(add_axis_options)
@@ -0,0 +1,51 @@
from modules import scripts
from modules.shared import opts
xyz_grid = [x for x in scripts.scripts_data if x.script_class.__module__ == "xyz_grid.py"][0].module
def int_applier(value_name:str, min_range:int = -1, max_range:int = -1):
"""
Returns a function that applies the given value to the given value_name in opts.data.
"""
def validate(value_name:str, value:str):
value = int(value)
# validate value
if not min_range == -1:
assert value >= min_range, f"Value {value} for {value_name} must be greater than or equal to {min_range}"
if not max_range == -1:
assert value <= max_range, f"Value {value} for {value_name} must be less than or equal to {max_range}"
def apply_int(p, x, xs):
validate(value_name, x)
opts.data[value_name] = int(x)
return apply_int
def bool_applier(value_name:str):
"""
Returns a function that applies the given value to the given value_name in opts.data.
"""
def validate(value_name:str, value:str):
assert value.lower() in ["true", "false"], f"Value {value} for {value_name} must be either true or false"
def apply_bool(p, x, xs):
validate(value_name, x)
value_boolean = x.lower() == "true"
opts.data[value_name] = value_boolean
return apply_bool
def add_axis_options():
extra_axis_options = [
xyz_grid.AxisOption("[Hypertile] Unet First pass Enabled", str, bool_applier("hypertile_enable_unet"), choices=xyz_grid.boolean_choice(reverse=True)),
xyz_grid.AxisOption("[Hypertile] Unet Second pass Enabled", str, bool_applier("hypertile_enable_unet_secondpass"), choices=xyz_grid.boolean_choice(reverse=True)),
xyz_grid.AxisOption("[Hypertile] Unet Max Depth", int, int_applier("hypertile_max_depth_unet", 0, 3), choices=lambda: [str(x) for x in range(4)]),
xyz_grid.AxisOption("[Hypertile] Unet Max Tile Size", int, int_applier("hypertile_max_tile_unet", 0, 512)),
xyz_grid.AxisOption("[Hypertile] Unet Swap Size", int, int_applier("hypertile_swap_size_unet", 0, 64)),
xyz_grid.AxisOption("[Hypertile] VAE Enabled", str, bool_applier("hypertile_enable_vae"), choices=xyz_grid.boolean_choice(reverse=True)),
xyz_grid.AxisOption("[Hypertile] VAE Max Depth", int, int_applier("hypertile_max_depth_vae", 0, 3), choices=lambda: [str(x) for x in range(4)]),
xyz_grid.AxisOption("[Hypertile] VAE Max Tile Size", int, int_applier("hypertile_max_tile_vae", 0, 512)),
xyz_grid.AxisOption("[Hypertile] VAE Swap Size", int, int_applier("hypertile_swap_size_vae", 0, 64)),
]
set_a = {opt.label for opt in xyz_grid.axis_options}
set_b = {opt.label for opt in extra_axis_options}
if set_a.intersection(set_b):
return
xyz_grid.axis_options.extend(extra_axis_options)
@@ -12,6 +12,8 @@ function isMobile() {
}
function reportWindowSize() {
+if (gradioApp().querySelector('.toprow-compact-tools')) return; // not applicable for compact prompt layout
var currentlyMobile = isMobile();
if (currentlyMobile == isSetupForMobile) return;
isSetupForMobile = currentlyMobile;
@@ -0,0 +1,759 @@
import numpy as np
import gradio as gr
import math
from modules.ui_components import InputAccordion
import modules.scripts as scripts
from modules import infotext_utils
infotext_utils.register_info_json('Soft Inpainting')
class SoftInpaintingSettings:
def __init__(self,
mask_blend_power,
mask_blend_scale,
inpaint_detail_preservation,
composite_mask_influence,
composite_difference_threshold,
composite_difference_contrast):
self.mask_blend_power = mask_blend_power
self.mask_blend_scale = mask_blend_scale
self.inpaint_detail_preservation = inpaint_detail_preservation
self.composite_mask_influence = composite_mask_influence
self.composite_difference_threshold = composite_difference_threshold
self.composite_difference_contrast = composite_difference_contrast
def add_generation_params(self, dest):
dest['Soft Inpainting'] = {
'sb': self.mask_blend_power,
'ps': self.mask_blend_scale,
'tcb': self.inpaint_detail_preservation,
'mi': self.composite_mask_influence,
'dt': self.composite_difference_threshold,
'dc': self.composite_difference_contrast,
}
# ------------------- Methods -------------------
def processing_uses_inpainting(p):
# TODO: Figure out a better way to determine if inpainting is being used by p
if getattr(p, "image_mask", None) is not None:
return True
if getattr(p, "mask", None) is not None:
return True
if getattr(p, "nmask", None) is not None:
return True
return False
def latent_blend(settings, a, b, t):
"""
Interpolates two latent image representations according to the parameter t,
where the interpolated vectors' magnitudes are also interpolated separately.
The "detail_preservation" factor biases the magnitude interpolation towards
the larger of the two magnitudes.
"""
import torch
# NOTE: We use inplace operations wherever possible.
# [4][w][h] to [1][4][w][h]
t2 = t.unsqueeze(0)
# [4][w][h] to [1][1][w][h] - the [4] seem redundant.
t3 = t[0].unsqueeze(0).unsqueeze(0)
one_minus_t2 = 1 - t2
one_minus_t3 = 1 - t3
# Linearly interpolate the image vectors.
a_scaled = a * one_minus_t2
b_scaled = b * t2
image_interp = a_scaled
image_interp.add_(b_scaled)
result_type = image_interp.dtype
del a_scaled, b_scaled, t2, one_minus_t2
# Calculate the magnitude of the interpolated vectors. (We will remove this magnitude.)
# 64-bit operations are used here to allow large exponents.
current_magnitude = torch.norm(image_interp, p=2, dim=1, keepdim=True).to(torch.float64).add_(0.00001)
# Interpolate the powered magnitudes, then un-power them (bring them back to a power of 1).
a_magnitude = torch.norm(a, p=2, dim=1, keepdim=True).to(torch.float64).pow_(
settings.inpaint_detail_preservation) * one_minus_t3
b_magnitude = torch.norm(b, p=2, dim=1, keepdim=True).to(torch.float64).pow_(
settings.inpaint_detail_preservation) * t3
desired_magnitude = a_magnitude
desired_magnitude.add_(b_magnitude).pow_(1 / settings.inpaint_detail_preservation)
del a_magnitude, b_magnitude, t3, one_minus_t3
# Change the linearly interpolated image vectors' magnitudes to the value we want.
# This is the last 64-bit operation.
image_interp_scaling_factor = desired_magnitude
image_interp_scaling_factor.div_(current_magnitude)
image_interp_scaling_factor = image_interp_scaling_factor.to(result_type)
image_interp_scaled = image_interp
image_interp_scaled.mul_(image_interp_scaling_factor)
del current_magnitude
del desired_magnitude
del image_interp
del image_interp_scaling_factor
del result_type
return image_interp_scaled
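# Worked example (hypothetical numbers): for unit-norm a and b with t = 0.5 and
# inpaint_detail_preservation = 32, the desired magnitude is
# (0.5 * 1**32 + 0.5 * 1**32) ** (1 / 32) = 1, so the result is the midpoint
# (a + b) / 2 rescaled back to unit norm.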
def get_modified_nmask(settings, nmask, sigma):
"""
Converts a negative mask, representing the transparency of the original latent vectors being overlaid,
to a mask that is scaled according to the denoising strength for this step.
Where:
0 = fully opaque, infinite density, fully masked
1 = fully transparent, zero density, fully unmasked
We raise this transparency to a power, as this allows one to simulate N blending operations,
where N can be any positive real value. Using this, one can control the balance of influence between
the denoiser and the original latents according to the sigma value.
"""
import torch
return torch.pow(nmask, (sigma ** settings.mask_blend_power) * settings.mask_blend_scale)
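# --- Editor's sketch (illustration, not part of the commit) ---
# How get_modified_nmask behaves with the script's defaults (power=1, scale=0.5):
# at high sigma the exponent is large, the modified nmask approaches 0, and the
# original latents dominate; as sigma falls, the denoiser takes over.
import torch

_nmask = torch.tensor(0.8)  # a mostly-unmasked pixel
for _sigma in (14.6, 5.0, 1.0, 0.1):
    print(_sigma, round(float(torch.pow(_nmask, (_sigma ** 1.0) * 0.5)), 3))
# 14.6 -> ~0.196, 5.0 -> ~0.572, 1.0 -> ~0.894, 0.1 -> ~0.989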
def apply_adaptive_masks(
settings: SoftInpaintingSettings,
nmask,
latent_orig,
latent_processed,
overlay_images,
width, height,
paste_to):
import torch
import modules.processing as proc
import modules.images as images
from PIL import Image, ImageOps, ImageFilter
# TODO: Bias the blending according to the latent mask, add adjustable parameter for bias control.
latent_mask = nmask[0].float()
# convert the original mask into a form we use to scale distances for thresholding
mask_scalar = 1 - (torch.clamp(latent_mask, min=0, max=1) ** (settings.mask_blend_scale / 2))
mask_scalar = (0.5 * (1 - settings.composite_mask_influence)
+ mask_scalar * settings.composite_mask_influence)
mask_scalar = mask_scalar / (1.00001 - mask_scalar)
mask_scalar = mask_scalar.cpu().numpy()
latent_distance = torch.norm(latent_processed - latent_orig, p=2, dim=1)
kernel, kernel_center = get_gaussian_kernel(stddev_radius=1.5, max_radius=2)
masks_for_overlay = []
for i, (distance_map, overlay_image) in enumerate(zip(latent_distance, overlay_images)):
converted_mask = distance_map.float().cpu().numpy()
converted_mask = weighted_histogram_filter(converted_mask, kernel, kernel_center,
percentile_min=0.9, percentile_max=1, min_width=1)
converted_mask = weighted_histogram_filter(converted_mask, kernel, kernel_center,
percentile_min=0.25, percentile_max=0.75, min_width=1)
# The distance at which opacity of original decreases to 50%
half_weighted_distance = settings.composite_difference_threshold * mask_scalar
converted_mask = converted_mask / half_weighted_distance
converted_mask = 1 / (1 + converted_mask ** settings.composite_difference_contrast)
converted_mask = smootherstep(converted_mask)
converted_mask = 1 - converted_mask
converted_mask = 255. * converted_mask
converted_mask = converted_mask.astype(np.uint8)
converted_mask = Image.fromarray(converted_mask)
converted_mask = images.resize_image(2, converted_mask, width, height)
converted_mask = proc.create_binary_mask(converted_mask, round=False)
# Remove aliasing artifacts using a gaussian blur.
converted_mask = converted_mask.filter(ImageFilter.GaussianBlur(radius=4))
# Expand the mask to fit the whole image if needed.
if paste_to is not None:
converted_mask = proc.uncrop(converted_mask,
(overlay_image.width, overlay_image.height),
paste_to)
masks_for_overlay.append(converted_mask)
image_masked = Image.new('RGBa', (overlay_image.width, overlay_image.height))
image_masked.paste(overlay_image.convert("RGBA").convert("RGBa"),
mask=ImageOps.invert(converted_mask.convert('L')))
overlay_images[i] = image_masked.convert('RGBA')
return masks_for_overlay
def apply_masks(
settings,
nmask,
overlay_images,
width, height,
paste_to):
import torch
import modules.processing as proc
import modules.images as images
from PIL import Image, ImageOps, ImageFilter
converted_mask = nmask[0].float()
converted_mask = torch.clamp(converted_mask, min=0, max=1).pow_(settings.mask_blend_scale / 2)
converted_mask = 255. * converted_mask
converted_mask = converted_mask.cpu().numpy().astype(np.uint8)
converted_mask = Image.fromarray(converted_mask)
converted_mask = images.resize_image(2, converted_mask, width, height)
converted_mask = proc.create_binary_mask(converted_mask, round=False)
# Remove aliasing artifacts using a gaussian blur.
converted_mask = converted_mask.filter(ImageFilter.GaussianBlur(radius=4))
# Expand the mask to fit the whole image if needed.
if paste_to is not None:
converted_mask = proc.uncrop(converted_mask,
(width, height),
paste_to)
    masks_for_overlay = []
    for i, overlay_image in enumerate(overlay_images):
        # append instead of assigning into the empty list, which would raise IndexError
        masks_for_overlay.append(converted_mask)
image_masked = Image.new('RGBa', (overlay_image.width, overlay_image.height))
image_masked.paste(overlay_image.convert("RGBA").convert("RGBa"),
mask=ImageOps.invert(converted_mask.convert('L')))
overlay_images[i] = image_masked.convert('RGBA')
return masks_for_overlay
def weighted_histogram_filter(img, kernel, kernel_center, percentile_min=0.0, percentile_max=1.0, min_width=1.0):
"""
    Generalized convolution filter capable of applying
    weighted mean, median, maximum, and minimum filters
    parametrically using an arbitrary kernel.
Args:
img (nparray):
The image, a 2-D array of floats, to which the filter is being applied.
kernel (nparray):
The kernel, a 2-D array of floats.
kernel_center (nparray):
The kernel center coordinate, a 1-D array with two elements.
percentile_min (float):
The lower bound of the histogram window used by the filter,
from 0 to 1.
percentile_max (float):
The upper bound of the histogram window used by the filter,
from 0 to 1.
min_width (float):
The minimum size of the histogram window bounds, in weight units.
Must be greater than 0.
Returns:
(nparray): A filtered copy of the input image "img", a 2-D array of floats.
"""
# Converts an index tuple into a vector.
def vec(x):
return np.array(x)
kernel_min = -kernel_center
kernel_max = vec(kernel.shape) - kernel_center
def weighted_histogram_filter_single(idx):
idx = vec(idx)
min_index = np.maximum(0, idx + kernel_min)
max_index = np.minimum(vec(img.shape), idx + kernel_max)
window_shape = max_index - min_index
class WeightedElement:
"""
An element of the histogram, its weight
and bounds.
"""
def __init__(self, value, weight):
self.value: float = value
self.weight: float = weight
self.window_min: float = 0.0
self.window_max: float = 1.0
# Collect the values in the image as WeightedElements,
# weighted by their corresponding kernel values.
values = []
for window_tup in np.ndindex(tuple(window_shape)):
window_index = vec(window_tup)
image_index = window_index + min_index
centered_kernel_index = image_index - idx
kernel_index = centered_kernel_index + kernel_center
element = WeightedElement(img[tuple(image_index)], kernel[tuple(kernel_index)])
values.append(element)
def sort_key(x: WeightedElement):
return x.value
values.sort(key=sort_key)
        # Calculate the height of the stack (total weight)
        # and the range each sample occupies in the stack.
        total_weight = 0
        for i in range(len(values)):
            values[i].window_min = total_weight
            total_weight += values[i].weight
            values[i].window_max = total_weight
        # Calculate what range of this stack ("window")
        # we want to get the weighted average across.
        window_min = total_weight * percentile_min
        window_max = total_weight * percentile_max
        window_width = window_max - window_min
        # Ensure the window is within the stack and at least a certain size.
        if window_width < min_width:
            window_center = (window_min + window_max) / 2
            window_min = window_center - min_width / 2
            window_max = window_center + min_width / 2
            if window_max > total_weight:
                window_max = total_weight
                window_min = total_weight - min_width
if window_min < 0:
window_min = 0
window_max = min_width
value = 0
value_weight = 0
# Get the weighted average of all the samples
# that overlap with the window, weighted
# by the size of their overlap.
for i in range(len(values)):
if window_min >= values[i].window_max:
continue
if window_max <= values[i].window_min:
break
s = max(window_min, values[i].window_min)
e = min(window_max, values[i].window_max)
w = e - s
value += values[i].value * w
value_weight += w
return value / value_weight if value_weight != 0 else 0
img_out = img.copy()
# Apply the kernel operation over each pixel.
for index in np.ndindex(img.shape):
img_out[index] = weighted_histogram_filter_single(index)
return img_out
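# --- Editor's sketch (illustration, not part of the commit) ---
# The percentile window turns this one filter into several: a window around 0.5
# gives a weighted median, a window near 1.0 a soft maximum. The two calls below
# mirror the ones made on the latent distance map in apply_adaptive_masks above.
import numpy as np

_img = np.random.rand(16, 16)
_kernel, _center = get_gaussian_kernel(stddev_radius=1.5, max_radius=2)
_median_like = weighted_histogram_filter(_img, _kernel, _center,
                                         percentile_min=0.25, percentile_max=0.75, min_width=1)
_max_like = weighted_histogram_filter(_img, _kernel, _center,
                                      percentile_min=0.9, percentile_max=1, min_width=1)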
def smoothstep(x):
"""
The smoothstep function, input should be clamped to 0-1 range.
Turns a diagonal line (f(x) = x) into a sigmoid-like curve.
"""
return x * x * (3 - 2 * x)
def smootherstep(x):
"""
The smootherstep function, input should be clamped to 0-1 range.
Turns a diagonal line (f(x) = x) into a sigmoid-like curve.
"""
return x * x * x * (x * (6 * x - 15) + 10)
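# --- Editor's sketch (illustration, not part of the commit) ---
# Both easing curves keep the endpoints fixed at 0 and 1 but flatten the slope
# there; smootherstep also zeroes the second derivative at the endpoints, giving
# the softer shoulder used when shaping the composite masks above.
for _x in (0.0, 0.25, 0.5, 0.75, 1.0):
    print(_x, round(smoothstep(_x), 4), round(smootherstep(_x), 4))
# 0.25 -> 0.1562 vs 0.1035; 0.5 -> 0.5 vs 0.5; 0.75 -> 0.8438 vs 0.8965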
def get_gaussian_kernel(stddev_radius=1.0, max_radius=2):
"""
Creates a Gaussian kernel with thresholded edges.
Args:
stddev_radius (float):
Standard deviation of the gaussian kernel, in pixels.
max_radius (int):
The size of the filter kernel. The number of pixels is (max_radius*2+1) ** 2.
            The kernel is thresholded so that any values one pixel beyond this radius
            are weighted at 0.
Returns:
(nparray, nparray): A kernel array (shape: (N, N)), its center coordinate (shape: (2))
"""
# Evaluates a 0-1 normalized gaussian function for a given square distance from the mean.
def gaussian(sqr_mag):
return math.exp(-sqr_mag / (stddev_radius * stddev_radius))
# Helper function for converting a tuple to an array.
def vec(x):
return np.array(x)
"""
Since a gaussian is unbounded, we need to limit ourselves
to a finite range.
We taper the ends off at the end of that range so they equal zero
while preserving the maximum value of 1 at the mean.
"""
zero_radius = max_radius + 1.0
gauss_zero = gaussian(zero_radius * zero_radius)
gauss_kernel_scale = 1 / (1 - gauss_zero)
def gaussian_kernel_func(coordinate):
x = coordinate[0] ** 2.0 + coordinate[1] ** 2.0
x = gaussian(x)
x -= gauss_zero
x *= gauss_kernel_scale
x = max(0.0, x)
return x
size = max_radius * 2 + 1
kernel_center = max_radius
kernel = np.zeros((size, size))
for index in np.ndindex(kernel.shape):
kernel[index] = gaussian_kernel_func(vec(index) - kernel_center)
return kernel, kernel_center
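# --- Editor's sketch (illustration, not part of the commit) ---
# The default kernel requested by apply_adaptive_masks. The taper (subtracting
# gauss_zero and rescaling) forces the weights to hit exactly zero just outside
# max_radius while keeping the peak value of 1 at the mean.
_kernel, _center = get_gaussian_kernel(stddev_radius=1.5, max_radius=2)
print(_kernel.shape, _center)               # (5, 5) 2
print(round(_kernel[_center, _center], 3))  # 1.0 at the mean
print(round(_kernel[0, 0], 3))              # ~0.01 at the corner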
# ------------------- Constants -------------------
default = SoftInpaintingSettings(1, 0.5, 4, 0, 0.5, 2)
enabled_ui_label = "Soft inpainting"
enabled_gen_param_label = "Soft inpainting enabled"
enabled_el_id = "soft_inpainting_enabled"
ui_labels = SoftInpaintingSettings(
"Schedule bias",
"Preservation strength",
"Transition contrast boost",
"Mask influence",
"Difference threshold",
"Difference contrast")
ui_info = SoftInpaintingSettings(
"Shifts when preservation of original content occurs during denoising.",
"How strongly partially masked content should be preserved.",
"Amplifies the contrast that may be lost in partially masked regions.",
"How strongly the original mask should bias the difference threshold.",
"How much an image region can change before the original pixels are not blended in anymore.",
"How sharp the transition should be between blended and not blended.")
gen_param_labels = SoftInpaintingSettings(
"Soft inpainting schedule bias",
"Soft inpainting preservation strength",
"Soft inpainting transition contrast boost",
"Soft inpainting mask influence",
"Soft inpainting difference threshold",
"Soft inpainting difference contrast")
el_ids = SoftInpaintingSettings(
"mask_blend_power",
"mask_blend_scale",
"inpaint_detail_preservation",
"composite_mask_influence",
"composite_difference_threshold",
"composite_difference_contrast")
# ------------------- Script -------------------
class Script(scripts.Script):
def __init__(self):
self.section = "inpaint"
self.masks_for_overlay = None
self.overlay_images = None
def title(self):
return "Soft Inpainting"
def show(self, is_img2img):
return scripts.AlwaysVisible if is_img2img else False
def ui(self, is_img2img):
if not is_img2img:
return
with InputAccordion(False, label=enabled_ui_label, elem_id=enabled_el_id) as soft_inpainting_enabled:
with gr.Group():
gr.Markdown(
"""
Soft inpainting allows you to **seamlessly blend original content with inpainted content** according to the mask opacity.
**High _Mask blur_** values are recommended!
""")
power = \
gr.Slider(label=ui_labels.mask_blend_power,
info=ui_info.mask_blend_power,
minimum=0,
maximum=8,
step=0.1,
value=default.mask_blend_power,
elem_id=el_ids.mask_blend_power)
scale = \
gr.Slider(label=ui_labels.mask_blend_scale,
info=ui_info.mask_blend_scale,
minimum=0,
maximum=8,
step=0.05,
value=default.mask_blend_scale,
elem_id=el_ids.mask_blend_scale)
detail = \
gr.Slider(label=ui_labels.inpaint_detail_preservation,
info=ui_info.inpaint_detail_preservation,
minimum=1,
maximum=32,
step=0.5,
value=default.inpaint_detail_preservation,
elem_id=el_ids.inpaint_detail_preservation)
gr.Markdown(
"""
### Pixel Composite Settings
""")
mask_inf = \
gr.Slider(label=ui_labels.composite_mask_influence,
info=ui_info.composite_mask_influence,
minimum=0,
maximum=1,
step=0.05,
value=default.composite_mask_influence,
elem_id=el_ids.composite_mask_influence)
dif_thresh = \
gr.Slider(label=ui_labels.composite_difference_threshold,
info=ui_info.composite_difference_threshold,
minimum=0,
maximum=8,
step=0.25,
value=default.composite_difference_threshold,
elem_id=el_ids.composite_difference_threshold)
dif_contr = \
gr.Slider(label=ui_labels.composite_difference_contrast,
info=ui_info.composite_difference_contrast,
minimum=0,
maximum=8,
step=0.25,
value=default.composite_difference_contrast,
elem_id=el_ids.composite_difference_contrast)
with gr.Accordion("Help", open=False):
gr.Markdown(
f"""
### {ui_labels.mask_blend_power}
The blending strength of original content is scaled proportionally with the decreasing noise level values at each step (sigmas).
This ensures that the influence of the denoiser and original content preservation is roughly balanced at each step.
This balance can be shifted using this parameter, controlling whether earlier or later steps have stronger preservation.
- **Below 1**: Stronger preservation near the end (with low sigma)
- **1**: Balanced (proportional to sigma)
- **Above 1**: Stronger preservation in the beginning (with high sigma)
""")
gr.Markdown(
f"""
### {ui_labels.mask_blend_scale}
Skews whether partially masked image regions should be more likely to preserve the original content or favor inpainted content.
This may need to be adjusted depending on the {ui_labels.mask_blend_power}, CFG Scale, prompt and Denoising strength.
- **Low values**: Favors generated content.
- **High values**: Favors original content.
""")
gr.Markdown(
f"""
### {ui_labels.inpaint_detail_preservation}
This parameter controls how the original latent vectors and denoised latent vectors are interpolated.
With higher values, the magnitude of the resulting blended vector will be closer to the maximum of the two interpolated vectors.
This can prevent the loss of contrast that occurs with linear interpolation.
- **Low values**: Softer blending, details may fade.
- **High values**: Stronger contrast, may over-saturate colors.
""")
gr.Markdown(
"""
## Pixel Composite Settings
Masks are generated based on how much a part of the image changed after denoising.
These masks are used to blend the original and final images together.
If the difference is low, the original pixels are used instead of the pixels returned by the inpainting process.
""")
gr.Markdown(
f"""
### {ui_labels.composite_mask_influence}
This parameter controls how much the mask should bias this sensitivity to difference.
- **0**: Ignore the mask, only consider differences in image content.
- **1**: Follow the mask closely despite image content changes.
""")
gr.Markdown(
f"""
### {ui_labels.composite_difference_threshold}
This value represents the difference at which the original pixels will have less than 50% opacity.
                    - **Low values**: Two image patches must be almost the same in order to retain original pixels.
                    - **High values**: Two image patches can be very different and still retain original pixels.
""")
gr.Markdown(
f"""
### {ui_labels.composite_difference_contrast}
This value represents the contrast between the opacity of the original and inpainted content.
- **Low values**: The blend will be more gradual and have longer transitions, but may cause ghosting.
- **High values**: Ghosting will be less common, but transitions may be very sudden.
""")
def get_element_value(generation_params: dict, old_key, new_key):
if 'Soft Inpainting' in generation_params:
return generation_params['Soft Inpainting'].get(new_key, True)
else:
return generation_params.get(old_key)
self.infotext_fields = [
(soft_inpainting_enabled, lambda d: get_element_value(d, enabled_gen_param_label, None)),
(power, lambda d: get_element_value(d, gen_param_labels.mask_blend_power, 'sb')),
(scale, lambda d: get_element_value(d, gen_param_labels.mask_blend_scale, 'ps')),
(detail, lambda d: get_element_value(d, gen_param_labels.inpaint_detail_preservation, 'tcb')),
(mask_inf, lambda d: get_element_value(d, gen_param_labels.composite_mask_influence, 'mi')),
(dif_thresh, lambda d: get_element_value(d, gen_param_labels.composite_difference_threshold, 'dt')),
(dif_contr, lambda d: get_element_value(d, gen_param_labels.composite_difference_contrast, 'dc'))
]
self.paste_field_names = []
for _, field_name in self.infotext_fields:
self.paste_field_names.append(field_name)
return [soft_inpainting_enabled,
power,
scale,
detail,
mask_inf,
dif_thresh,
dif_contr]
def process(self, p, enabled, power, scale, detail_preservation, mask_inf, dif_thresh, dif_contr):
if not enabled:
return
if not processing_uses_inpainting(p):
return
# Shut off the rounding it normally does.
p.mask_round = False
settings = SoftInpaintingSettings(power, scale, detail_preservation, mask_inf, dif_thresh, dif_contr)
# p.extra_generation_params["Mask rounding"] = False
settings.add_generation_params(p.extra_generation_params)
def on_mask_blend(self, p, mba: scripts.MaskBlendArgs, enabled, power, scale, detail_preservation, mask_inf,
dif_thresh, dif_contr):
if not enabled:
return
if not processing_uses_inpainting(p):
return
if mba.is_final_blend:
mba.blended_latent = mba.current_latent
return
settings = SoftInpaintingSettings(power, scale, detail_preservation, mask_inf, dif_thresh, dif_contr)
# todo: Why is sigma 2D? Both values are the same.
mba.blended_latent = latent_blend(settings,
mba.init_latent,
mba.current_latent,
get_modified_nmask(settings, mba.nmask, mba.sigma[0]))
def post_sample(self, p, ps: scripts.PostSampleArgs, enabled, power, scale, detail_preservation, mask_inf,
dif_thresh, dif_contr):
if not enabled:
return
if not processing_uses_inpainting(p):
return
nmask = getattr(p, "nmask", None)
if nmask is None:
return
from modules import images
from modules.shared import opts
settings = SoftInpaintingSettings(power, scale, detail_preservation, mask_inf, dif_thresh, dif_contr)
# since the original code puts holes in the existing overlay images,
# we have to rebuild them.
self.overlay_images = []
for img in p.init_images:
image = images.flatten(img, opts.img2img_background_color)
if p.paste_to is None and p.resize_mode != 3:
image = images.resize_image(p.resize_mode, image, p.width, p.height)
self.overlay_images.append(image.convert('RGBA'))
if len(p.init_images) == 1:
self.overlay_images = self.overlay_images * p.batch_size
if getattr(ps.samples, 'already_decoded', False):
self.masks_for_overlay = apply_masks(settings=settings,
nmask=nmask,
overlay_images=self.overlay_images,
width=p.width,
height=p.height,
paste_to=p.paste_to)
else:
self.masks_for_overlay = apply_adaptive_masks(settings=settings,
nmask=nmask,
latent_orig=p.init_latent,
latent_processed=ps.samples,
overlay_images=self.overlay_images,
width=p.width,
height=p.height,
paste_to=p.paste_to)
def postprocess_maskoverlay(self, p, ppmo: scripts.PostProcessMaskOverlayArgs, enabled, power, scale,
detail_preservation, mask_inf, dif_thresh, dif_contr):
if not enabled:
return
if not processing_uses_inpainting(p):
return
if self.masks_for_overlay is None:
return
if self.overlay_images is None:
return
ppmo.mask_for_overlay = self.masks_for_overlay[ppmo.index]
ppmo.overlay_image = self.overlay_images[ppmo.index]
@@ -4,107 +4,6 @@
#licenses pre { margin: 1em 0 2em 0;}
</style>
<h2><a href="https://github.com/sczhou/CodeFormer/blob/master/LICENSE">CodeFormer</a></h2>
<small>Parts of CodeFormer code had to be copied to be compatible with GFPGAN.</small>
<pre>
S-Lab License 1.0
Copyright 2022 S-Lab
Redistribution and use for non-commercial purpose in source and
binary forms, with or without modification, are permitted provided
that the following conditions are met:
1. Redistributions of source code must retain the above copyright
notice, this list of conditions and the following disclaimer.
2. Redistributions in binary form must reproduce the above copyright
notice, this list of conditions and the following disclaimer in
the documentation and/or other materials provided with the
distribution.
3. Neither the name of the copyright holder nor the names of its
contributors may be used to endorse or promote products derived
from this software without specific prior written permission.
THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
"AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
(INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
In the event that redistribution and/or use for commercial purpose in
source or binary forms, with or without modification is required,
please contact the contributor(s) of the work.
</pre>
<h2><a href="https://github.com/victorca25/iNNfer/blob/main/LICENSE">ESRGAN</a></h2>
<small>Code for architecture and reading models copied.</small>
<pre>
MIT License
Copyright (c) 2021 victorca25
Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:
The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.
</pre>
<h2><a href="https://github.com/xinntao/Real-ESRGAN/blob/master/LICENSE">Real-ESRGAN</a></h2>
<small>Some code is copied to support ESRGAN models.</small>
<pre>
BSD 3-Clause License
Copyright (c) 2021, Xintao Wang
All rights reserved.
Redistribution and use in source and binary forms, with or without
modification, are permitted provided that the following conditions are met:
1. Redistributions of source code must retain the above copyright notice, this
list of conditions and the following disclaimer.
2. Redistributions in binary form must reproduce the above copyright notice,
this list of conditions and the following disclaimer in the documentation
and/or other materials provided with the distribution.
3. Neither the name of the copyright holder nor the names of its
contributors may be used to endorse or promote products derived from
this software without specific prior written permission.
THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
</pre>
<h2><a href="https://github.com/invoke-ai/InvokeAI/blob/main/LICENSE">InvokeAI</a></h2> <h2><a href="https://github.com/invoke-ai/InvokeAI/blob/main/LICENSE">InvokeAI</a></h2>
<small>Some code for compatibility with OSX is taken from lstein's repository.</small> <small>Some code for compatibility with OSX is taken from lstein's repository.</small>
<pre> <pre>
@@ -183,213 +82,6 @@ OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.
</pre>
<h2><a href="https://github.com/JingyunLiang/SwinIR/blob/main/LICENSE">SwinIR</a></h2>
<small>Code added by contributors, most likely copied from this repository.</small>
<pre>
Apache License
Version 2.0, January 2004
http://www.apache.org/licenses/
TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
1. Definitions.
"License" shall mean the terms and conditions for use, reproduction,
and distribution as defined by Sections 1 through 9 of this document.
"Licensor" shall mean the copyright owner or entity authorized by
the copyright owner that is granting the License.
"Legal Entity" shall mean the union of the acting entity and all
other entities that control, are controlled by, or are under common
control with that entity. For the purposes of this definition,
"control" means (i) the power, direct or indirect, to cause the
direction or management of such entity, whether by contract or
otherwise, or (ii) ownership of fifty percent (50%) or more of the
outstanding shares, or (iii) beneficial ownership of such entity.
"You" (or "Your") shall mean an individual or Legal Entity
exercising permissions granted by this License.
"Source" form shall mean the preferred form for making modifications,
including but not limited to software source code, documentation
source, and configuration files.
"Object" form shall mean any form resulting from mechanical
transformation or translation of a Source form, including but
not limited to compiled object code, generated documentation,
and conversions to other media types.
"Work" shall mean the work of authorship, whether in Source or
Object form, made available under the License, as indicated by a
copyright notice that is included in or attached to the work
(an example is provided in the Appendix below).
"Derivative Works" shall mean any work, whether in Source or Object
form, that is based on (or derived from) the Work and for which the
editorial revisions, annotations, elaborations, or other modifications
represent, as a whole, an original work of authorship. For the purposes
of this License, Derivative Works shall not include works that remain
separable from, or merely link (or bind by name) to the interfaces of,
the Work and Derivative Works thereof.
"Contribution" shall mean any work of authorship, including
the original version of the Work and any modifications or additions
to that Work or Derivative Works thereof, that is intentionally
submitted to Licensor for inclusion in the Work by the copyright owner
or by an individual or Legal Entity authorized to submit on behalf of
the copyright owner. For the purposes of this definition, "submitted"
means any form of electronic, verbal, or written communication sent
to the Licensor or its representatives, including but not limited to
communication on electronic mailing lists, source code control systems,
and issue tracking systems that are managed by, or on behalf of, the
Licensor for the purpose of discussing and improving the Work, but
excluding communication that is conspicuously marked or otherwise
designated in writing by the copyright owner as "Not a Contribution."
"Contributor" shall mean Licensor and any individual or Legal Entity
on behalf of whom a Contribution has been received by Licensor and
subsequently incorporated within the Work.
2. Grant of Copyright License. Subject to the terms and conditions of
this License, each Contributor hereby grants to You a perpetual,
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
copyright license to reproduce, prepare Derivative Works of,
publicly display, publicly perform, sublicense, and distribute the
Work and such Derivative Works in Source or Object form.
3. Grant of Patent License. Subject to the terms and conditions of
this License, each Contributor hereby grants to You a perpetual,
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
(except as stated in this section) patent license to make, have made,
use, offer to sell, sell, import, and otherwise transfer the Work,
where such license applies only to those patent claims licensable
by such Contributor that are necessarily infringed by their
Contribution(s) alone or by combination of their Contribution(s)
with the Work to which such Contribution(s) was submitted. If You
institute patent litigation against any entity (including a
cross-claim or counterclaim in a lawsuit) alleging that the Work
or a Contribution incorporated within the Work constitutes direct
or contributory patent infringement, then any patent licenses
granted to You under this License for that Work shall terminate
as of the date such litigation is filed.
4. Redistribution. You may reproduce and distribute copies of the
Work or Derivative Works thereof in any medium, with or without
modifications, and in Source or Object form, provided that You
meet the following conditions:
(a) You must give any other recipients of the Work or
Derivative Works a copy of this License; and
(b) You must cause any modified files to carry prominent notices
stating that You changed the files; and
(c) You must retain, in the Source form of any Derivative Works
that You distribute, all copyright, patent, trademark, and
attribution notices from the Source form of the Work,
excluding those notices that do not pertain to any part of
the Derivative Works; and
(d) If the Work includes a "NOTICE" text file as part of its
distribution, then any Derivative Works that You distribute must
include a readable copy of the attribution notices contained
within such NOTICE file, excluding those notices that do not
pertain to any part of the Derivative Works, in at least one
of the following places: within a NOTICE text file distributed
as part of the Derivative Works; within the Source form or
documentation, if provided along with the Derivative Works; or,
within a display generated by the Derivative Works, if and
wherever such third-party notices normally appear. The contents
of the NOTICE file are for informational purposes only and
do not modify the License. You may add Your own attribution
notices within Derivative Works that You distribute, alongside
or as an addendum to the NOTICE text from the Work, provided
that such additional attribution notices cannot be construed
as modifying the License.
You may add Your own copyright statement to Your modifications and
may provide additional or different license terms and conditions
for use, reproduction, or distribution of Your modifications, or
for any such Derivative Works as a whole, provided Your use,
reproduction, and distribution of the Work otherwise complies with
the conditions stated in this License.
5. Submission of Contributions. Unless You explicitly state otherwise,
any Contribution intentionally submitted for inclusion in the Work
by You to the Licensor shall be under the terms and conditions of
this License, without any additional terms or conditions.
Notwithstanding the above, nothing herein shall supersede or modify
the terms of any separate license agreement you may have executed
with Licensor regarding such Contributions.
6. Trademarks. This License does not grant permission to use the trade
names, trademarks, service marks, or product names of the Licensor,
except as required for reasonable and customary use in describing the
origin of the Work and reproducing the content of the NOTICE file.
7. Disclaimer of Warranty. Unless required by applicable law or
agreed to in writing, Licensor provides the Work (and each
Contributor provides its Contributions) on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
implied, including, without limitation, any warranties or conditions
of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
PARTICULAR PURPOSE. You are solely responsible for determining the
appropriateness of using or redistributing the Work and assume any
risks associated with Your exercise of permissions under this License.
8. Limitation of Liability. In no event and under no legal theory,
whether in tort (including negligence), contract, or otherwise,
unless required by applicable law (such as deliberate and grossly
negligent acts) or agreed to in writing, shall any Contributor be
liable to You for damages, including any direct, indirect, special,
incidental, or consequential damages of any character arising as a
result of this License or out of the use or inability to use the
Work (including but not limited to damages for loss of goodwill,
work stoppage, computer failure or malfunction, or any and all
other commercial damages or losses), even if such Contributor
has been advised of the possibility of such damages.
9. Accepting Warranty or Additional Liability. While redistributing
the Work or Derivative Works thereof, You may choose to offer,
and charge a fee for, acceptance of support, warranty, indemnity,
or other liability obligations and/or rights consistent with this
License. However, in accepting such obligations, You may act only
on Your own behalf and on Your sole responsibility, not on behalf
of any other Contributor, and only if You agree to indemnify,
defend, and hold each Contributor harmless for any liability
incurred by, or claims asserted against, such Contributor by reason
of your accepting any such warranty or additional liability.
END OF TERMS AND CONDITIONS
APPENDIX: How to apply the Apache License to your work.
To apply the Apache License to your work, attach the following
boilerplate notice, with the fields enclosed by brackets "[]"
replaced with your own identifying information. (Don't include
the brackets!) The text should be enclosed in the appropriate
comment syntax for the file format. We also recommend that a
file or class name and description of purpose be included on the
same "printed page" as the copyright notice for easier
identification within third-party archives.
Copyright [2021] [SwinIR Authors]
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
</pre>
<h2><a href="https://github.com/AminRezaei0x443/memory-efficient-attention/blob/main/LICENSE">Memory Efficient Attention</a></h2> <h2><a href="https://github.com/AminRezaei0x443/memory-efficient-attention/blob/main/LICENSE">Memory Efficient Attention</a></h2>
<small>The sub-quadratic cross attention optimization uses modified code from the Memory Efficient Attention package that Alex Birch optimized for 3D tensors. This license is updated to reflect that.</small> <small>The sub-quadratic cross attention optimization uses modified code from the Memory Efficient Attention package that Alex Birch optimized for 3D tensors. This license is updated to reflect that.</small>
<pre> <pre>
@@ -19,16 +19,28 @@ function keyupEditAttention(event) {
        let beforeParen = before.lastIndexOf(OPEN);
        if (beforeParen == -1) return false;
+       let beforeClosingParen = before.lastIndexOf(CLOSE);
+       if (beforeClosingParen != -1 && beforeClosingParen > beforeParen) return false;

        // Find closing parenthesis around current cursor
        const after = text.substring(selectionStart);
        let afterParen = after.indexOf(CLOSE);
        if (afterParen == -1) return false;
+       let afterOpeningParen = after.indexOf(OPEN);
+       if (afterOpeningParen != -1 && afterOpeningParen < afterParen) return false;

        // Set the selection to the text between the parenthesis
        const parenContent = text.substring(beforeParen + 1, selectionStart + afterParen);
-       const lastColon = parenContent.lastIndexOf(":");
-       selectionStart = beforeParen + 1;
-       selectionEnd = selectionStart + lastColon;
+       if (/.*:-?[\d.]+/s.test(parenContent)) {
+           const lastColon = parenContent.lastIndexOf(":");
+           selectionStart = beforeParen + 1;
+           selectionEnd = selectionStart + lastColon;
+       } else {
+           selectionStart = beforeParen + 1;
+           selectionEnd = selectionStart + parenContent.length;
+       }
        target.setSelectionRange(selectionStart, selectionEnd);
        return true;
    }
@@ -57,7 +69,7 @@ function keyupEditAttention(event) {
    }

    // If the user hasn't selected anything, let's select their current parenthesis block or word
-   if (!selectCurrentParenthesisBlock('<', '>') && !selectCurrentParenthesisBlock('(', ')')) {
+   if (!selectCurrentParenthesisBlock('<', '>') && !selectCurrentParenthesisBlock('(', ')') && !selectCurrentParenthesisBlock('[', ']')) {
        selectCurrentWord();
    }
@@ -65,33 +77,54 @@ function keyupEditAttention(event) {
    var closeCharacter = ')';
    var delta = opts.keyedit_precision_attention;
+   var start = selectionStart > 0 ? text[selectionStart - 1] : "";
+   var end = text[selectionEnd];

-   if (selectionStart > 0 && text[selectionStart - 1] == '<') {
+   if (start == '<') {
        closeCharacter = '>';
        delta = opts.keyedit_precision_extra;
-   } else if (selectionStart == 0 || text[selectionStart - 1] != "(") {
+   } else if (start == '(' && end == ')' || start == '[' && end == ']') { // convert old-style (((emphasis)))
+       let numParen = 0;
+
+       while (text[selectionStart - numParen - 1] == start && text[selectionEnd + numParen] == end) {
+           numParen++;
+       }
+
+       if (start == "[") {
+           weight = (1 / 1.1) ** numParen;
+       } else {
+           weight = 1.1 ** numParen;
+       }
+
+       weight = Math.round(weight / opts.keyedit_precision_attention) * opts.keyedit_precision_attention;
+
+       text = text.slice(0, selectionStart - numParen) + "(" + text.slice(selectionStart, selectionEnd) + ":" + weight + ")" + text.slice(selectionEnd + numParen);
+       selectionStart -= numParen - 1;
+       selectionEnd -= numParen - 1;
+   } else if (start != '(') {
        // do not include spaces at the end
        while (selectionEnd > selectionStart && text[selectionEnd - 1] == ' ') {
-           selectionEnd -= 1;
+           selectionEnd--;
        }

        if (selectionStart == selectionEnd) {
            return;
        }

        text = text.slice(0, selectionStart) + "(" + text.slice(selectionStart, selectionEnd) + ":1.0)" + text.slice(selectionEnd);
-       selectionStart += 1;
-       selectionEnd += 1;
+       selectionStart++;
+       selectionEnd++;
    }

-   var end = text.slice(selectionEnd + 1).indexOf(closeCharacter) + 1;
-   var weight = parseFloat(text.slice(selectionEnd + 1, selectionEnd + end));
+   if (text[selectionEnd] != ':') return;
+   var weightLength = text.slice(selectionEnd + 1).indexOf(closeCharacter) + 1;
+   var weight = parseFloat(text.slice(selectionEnd + 1, selectionEnd + weightLength));
    if (isNaN(weight)) return;

    weight += isPlus ? delta : -delta;
    weight = parseFloat(weight.toPrecision(12));
-   if (String(weight).length == 1) weight += ".0";
+   if (Number.isInteger(weight)) weight += ".0";

    if (closeCharacter == ')' && weight == 1) {
        var endParenPos = text.substring(selectionEnd).indexOf(')');
@@ -99,7 +132,7 @@ function keyupEditAttention(event) {
        selectionStart--;
        selectionEnd--;
    } else {
-       text = text.slice(0, selectionEnd + 1) + weight + text.slice(selectionEnd + end);
+       text = text.slice(0, selectionEnd + 1) + weight + text.slice(selectionEnd + weightLength);
    }

    target.focus();
@@ -26,8 +26,9 @@ function setupExtraNetworksForTab(tabname) {
    var refresh = gradioApp().getElementById(tabname + '_extra_refresh');
    var showDirsDiv = gradioApp().getElementById(tabname + '_extra_show_dirs');
    var showDirs = gradioApp().querySelector('#' + tabname + '_extra_show_dirs input');
+   var promptContainer = gradioApp().querySelector('.prompt-container-compact#' + tabname + '_prompt_container');
+   var negativePrompt = gradioApp().querySelector('#' + tabname + '_neg_prompt');

-   sort.dataset.sortkey = 'sortDefault';
    tabs.appendChild(searchDiv);
    tabs.appendChild(sort);
    tabs.appendChild(sortOrder);
@@ -49,20 +50,23 @@ function setupExtraNetworksForTab(tabname) {
            elem.style.display = visible ? "" : "none";
        });
+
+       applySort();
    };

    var applySort = function() {
+       var cards = gradioApp().querySelectorAll('#' + tabname + '_extra_tabs div.card');
+
        var reverse = sortOrder.classList.contains("sortReverse");
-       var sortKey = sort.querySelector("input").value.toLowerCase().replace("sort", "").replaceAll(" ", "_").replace(/_+$/, "").trim();
-       sortKey = sortKey ? "sort" + sortKey.charAt(0).toUpperCase() + sortKey.slice(1) : "";
-       var sortKeyStore = sortKey ? sortKey + (reverse ? "Reverse" : "") : "";
-       if (!sortKey || sortKeyStore == sort.dataset.sortkey) {
+       var sortKey = sort.querySelector("input").value.toLowerCase().replace("sort", "").replaceAll(" ", "_").replace(/_+$/, "").trim() || "name";
+       sortKey = "sort" + sortKey.charAt(0).toUpperCase() + sortKey.slice(1);
+       var sortKeyStore = sortKey + "-" + (reverse ? "Descending" : "Ascending") + "-" + cards.length;
+
+       if (sortKeyStore == sort.dataset.sortkey) {
            return;
        }
        sort.dataset.sortkey = sortKeyStore;

-       var cards = gradioApp().querySelectorAll('#' + tabname + '_extra_tabs div.card');
        cards.forEach(function(card) {
            card.originalParentElement = card.parentElement;
        });
@@ -88,15 +92,13 @@ function setupExtraNetworksForTab(tabname) {
    };

    search.addEventListener("input", applyFilter);
-   applyFilter();
-
-   ["change", "blur", "click"].forEach(function(evt) {
-       sort.querySelector("input").addEventListener(evt, applySort);
-   });
    sortOrder.addEventListener("click", function() {
        sortOrder.classList.toggle("sortReverse");
        applySort();
    });
+   applyFilter();

+   extraNetworksApplySort[tabname] = applySort;
    extraNetworksApplyFilter[tabname] = applyFilter;

    var showDirsUpdate = function() {
@@ -109,11 +111,51 @@ function setupExtraNetworksForTab(tabname) {
    showDirsUpdate();
}

+function extraNetworksMovePromptToTab(tabname, id, showPrompt, showNegativePrompt) {
+    if (!gradioApp().querySelector('.toprow-compact-tools')) return; // only applicable for compact prompt layout
+
+    var promptContainer = gradioApp().getElementById(tabname + '_prompt_container');
+    var prompt = gradioApp().getElementById(tabname + '_prompt_row');
+    var negPrompt = gradioApp().getElementById(tabname + '_neg_prompt_row');
+    var elem = id ? gradioApp().getElementById(id) : null;
+
+    if (showNegativePrompt && elem) {
+        elem.insertBefore(negPrompt, elem.firstChild);
+    } else {
+        promptContainer.insertBefore(negPrompt, promptContainer.firstChild);
+    }
+
+    if (showPrompt && elem) {
+        elem.insertBefore(prompt, elem.firstChild);
+    } else {
+        promptContainer.insertBefore(prompt, promptContainer.firstChild);
+    }
+
+    if (elem) {
+        elem.classList.toggle('extra-page-prompts-active', showNegativePrompt || showPrompt);
+    }
+}
+
+function extraNetworksUrelatedTabSelected(tabname) { // called from python when user selects an unrelated tab (generate)
+    extraNetworksMovePromptToTab(tabname, '', false, false);
+}
+
+function extraNetworksTabSelected(tabname, id, showPrompt, showNegativePrompt) { // called from python when user selects an extra networks tab
+    extraNetworksMovePromptToTab(tabname, id, showPrompt, showNegativePrompt);
+}
+
function applyExtraNetworkFilter(tabname) {
    setTimeout(extraNetworksApplyFilter[tabname], 1);
}

+function applyExtraNetworkSort(tabname) {
+    setTimeout(extraNetworksApplySort[tabname], 1);
+}
+
var extraNetworksApplyFilter = {};
+var extraNetworksApplySort = {};
var activePromptTextarea = {};

function setupExtraNetworks() {
@@ -143,8 +185,10 @@ onUiLoaded(setupExtraNetworks);
var re_extranet = /<([^:^>]+:[^:]+):[\d.]+>(.*)/;
var re_extranet_g = /<([^:^>]+:[^:]+):[\d.]+>/g;

-function tryToRemoveExtraNetworkFromPrompt(textarea, text) {
-    var m = text.match(re_extranet);
+var re_extranet_neg = /\(([^:^>]+:[\d.]+)\)/;
+var re_extranet_g_neg = /\(([^:^>]+:[\d.]+)\)/g;
+
+function tryToRemoveExtraNetworkFromPrompt(textarea, text, isNeg) {
+    var m = text.match(isNeg ? re_extranet_neg : re_extranet);
    var replaced = false;
    var newTextareaText;
    if (m) {
@@ -152,8 +196,8 @@ function tryToRemoveExtraNetworkFromPrompt(textarea, text) {
        var extraTextAfterNet = m[2];
        var partToSearch = m[1];
        var foundAtPosition = -1;
-       newTextareaText = textarea.value.replaceAll(re_extranet_g, function(found, net, pos) {
-           m = found.match(re_extranet);
+       newTextareaText = textarea.value.replaceAll(isNeg ? re_extranet_g_neg : re_extranet_g, function(found, net, pos) {
+           m = found.match(isNeg ? re_extranet_neg : re_extranet);
            if (m[1] == partToSearch) {
                replaced = true;
                foundAtPosition = pos;
@@ -163,7 +207,7 @@ function tryToRemoveExtraNetworkFromPrompt(textarea, text) {
        });

        if (foundAtPosition >= 0) {
-           if (newTextareaText.substr(foundAtPosition, extraTextAfterNet.length) == extraTextAfterNet) {
+           if (extraTextAfterNet && newTextareaText.substr(foundAtPosition, extraTextAfterNet.length) == extraTextAfterNet) {
                newTextareaText = newTextareaText.substr(0, foundAtPosition) + newTextareaText.substr(foundAtPosition + extraTextAfterNet.length);
            }
            if (newTextareaText.substr(foundAtPosition - extraTextBeforeNet.length, extraTextBeforeNet.length) == extraTextBeforeNet) {
@@ -188,14 +232,23 @@ function tryToRemoveExtraNetworkFromPrompt(textarea, text) {
    return false;
}

-function cardClicked(tabname, textToAdd, allowNegativePrompt) {
-    var textarea = allowNegativePrompt ? activePromptTextarea[tabname] : gradioApp().querySelector("#" + tabname + "_prompt > label > textarea");
-
-    if (!tryToRemoveExtraNetworkFromPrompt(textarea, textToAdd)) {
-        textarea.value = textarea.value + opts.extra_networks_add_text_separator + textToAdd;
+function updatePromptArea(text, textArea, isNeg) {
+    if (!tryToRemoveExtraNetworkFromPrompt(textArea, text, isNeg)) {
+        textArea.value = textArea.value + opts.extra_networks_add_text_separator + text;
    }

-    updateInput(textarea);
+    updateInput(textArea);
+}
+
+function cardClicked(tabname, textToAdd, textToAddNegative, allowNegativePrompt) {
+    if (textToAddNegative.length > 0) {
+        updatePromptArea(textToAdd, gradioApp().querySelector("#" + tabname + "_prompt > label > textarea"));
+        updatePromptArea(textToAddNegative, gradioApp().querySelector("#" + tabname + "_neg_prompt > label > textarea"), true);
+    } else {
+        var textarea = allowNegativePrompt ? activePromptTextarea[tabname] : gradioApp().querySelector("#" + tabname + "_prompt > label > textarea");
+        updatePromptArea(textToAdd, textarea);
+    }
}

function saveCardPreview(event, tabname, filename) {
@@ -350,3 +403,9 @@ function extraNetworksRefreshSingleCard(page, tabname, name) {
        }
    });
}
+
+window.addEventListener("keydown", function(event) {
+    if (event.key == "Escape") {
+        closePopup();
+    }
+});
@@ -33,8 +33,11 @@ function updateOnBackgroundChange() {
    const modalImage = gradioApp().getElementById("modalImage");
    if (modalImage && modalImage.offsetParent) {
        let currentButton = selected_gallery_button();
+       let preview = gradioApp().querySelectorAll('.livePreview > img');

-       if (currentButton?.children?.length > 0 && modalImage.src != currentButton.children[0].src) {
+       if (opts.js_live_preview_in_modal_lightbox && preview.length > 0) {
+           // show preview image if available
+           modalImage.src = preview[preview.length - 1].src;
+       } else if (currentButton?.children?.length > 0 && modalImage.src != currentButton.children[0].src) {
            modalImage.src = currentButton.children[0].src;
            if (modalImage.style.display === 'none') {
                const modal = gradioApp().getElementById("lightboxModal");
@@ -1,37 +1,68 @@
-var observerAccordionOpen = new MutationObserver(function(mutations) {
-    mutations.forEach(function(mutationRecord) {
-        var elem = mutationRecord.target;
-        var open = elem.classList.contains('open');
-
-        var accordion = elem.parentNode;
-        accordion.classList.toggle('input-accordion-open', open);
-
-        var checkbox = gradioApp().querySelector('#' + accordion.id + "-checkbox input");
-        checkbox.checked = open;
-        updateInput(checkbox);
-
-        var extra = gradioApp().querySelector('#' + accordion.id + "-extra");
-        if (extra) {
-            extra.style.display = open ? "" : "none";
-        }
-    });
-});
-
function inputAccordionChecked(id, checked) {
-    var label = gradioApp().querySelector('#' + id + " .label-wrap");
-    if (label.classList.contains('open') != checked) {
-        label.click();
-    }
+    var accordion = gradioApp().getElementById(id);
+    accordion.visibleCheckbox.checked = checked;
+    accordion.onVisibleCheckboxChange();
+}
+
+function setupAccordion(accordion) {
+    var labelWrap = accordion.querySelector('.label-wrap');
+    var gradioCheckbox = gradioApp().querySelector('#' + accordion.id + "-checkbox input");
+    var extra = gradioApp().querySelector('#' + accordion.id + "-extra");
+    var span = labelWrap.querySelector('span');
+    var linked = true;
+
+    var isOpen = function() {
+        return labelWrap.classList.contains('open');
+    };
+
+    var observerAccordionOpen = new MutationObserver(function(mutations) {
+        mutations.forEach(function(mutationRecord) {
+            accordion.classList.toggle('input-accordion-open', isOpen());
+
+            if (linked) {
+                accordion.visibleCheckbox.checked = isOpen();
+                accordion.onVisibleCheckboxChange();
+            }
+        });
+    });
+    observerAccordionOpen.observe(labelWrap, {attributes: true, attributeFilter: ['class']});
+
+    if (extra) {
+        labelWrap.insertBefore(extra, labelWrap.lastElementChild);
    }
+
+    accordion.onChecked = function(checked) {
+        if (isOpen() != checked) {
+            labelWrap.click();
+        }
+    };
+
+    var visibleCheckbox = document.createElement('INPUT');
+    visibleCheckbox.type = 'checkbox';
+    visibleCheckbox.checked = isOpen();
+    visibleCheckbox.id = accordion.id + "-visible-checkbox";
+    visibleCheckbox.className = gradioCheckbox.className + " input-accordion-checkbox";
+    span.insertBefore(visibleCheckbox, span.firstChild);
+    accordion.visibleCheckbox = visibleCheckbox;
+
+    accordion.onVisibleCheckboxChange = function() {
+        if (linked && isOpen() != visibleCheckbox.checked) {
+            labelWrap.click();
+        }
+
+        gradioCheckbox.checked = visibleCheckbox.checked;
+        updateInput(gradioCheckbox);
+    };
+
+    visibleCheckbox.addEventListener('click', function(event) {
+        linked = false;
+        event.stopPropagation();
+    });
+    visibleCheckbox.addEventListener('input', accordion.onVisibleCheckboxChange);
+}

onUiLoaded(function() {
    for (var accordion of gradioApp().querySelectorAll('.input-accordion')) {
-        var labelWrap = accordion.querySelector('.label-wrap');
-        observerAccordionOpen.observe(labelWrap, {attributes: true, attributeFilter: ['class']});
-
-        var extra = gradioApp().querySelector('#' + accordion.id + "-extra");
-        if (extra) {
-            labelWrap.insertBefore(extra, labelWrap.lastElementChild);
-        }
+        setupAccordion(accordion);
    }
});
@@ -26,7 +26,11 @@ onAfterUiUpdate(function() {
        lastHeadImg = headImg;

        // play notification sound if available
-       gradioApp().querySelector('#audio_notification audio')?.play();
+       const notificationAudio = gradioApp().querySelector('#audio_notification audio');
+       if (notificationAudio) {
+           notificationAudio.volume = opts.notification_volume / 100.0 || 1.0;
+           notificationAudio.play();
+       }

        if (document.hasFocus()) return;
@@ -44,3 +44,28 @@ onUiLoaded(function() {
    buttonShowAllPages.addEventListener("click", settingsShowAllTabs);
});

+onOptionsChanged(function() {
+    if (gradioApp().querySelector('#settings .settings-category')) return;
+
+    var sectionMap = {};
+    gradioApp().querySelectorAll('#settings > div > button').forEach(function(x) {
+        sectionMap[x.textContent.trim()] = x;
+    });
+
+    opts._categories.forEach(function(x) {
+        var section = x[0];
+        var category = x[1];
+
+        var span = document.createElement('SPAN');
+        span.textContent = category;
+        span.className = 'settings-category';
+
+        var sectionElem = sectionMap[section];
+        if (!sectionElem) return;
+
+        sectionElem.parentElement.insertBefore(span, sectionElem);
+    });
+});
@@ -150,6 +150,14 @@ function submit() {
    return res;
}

+function submit_txt2img_upscale() {
+    var res = submit(...arguments);
+
+    res[2] = selected_gallery_index();
+
+    return res;
+}
+
function submit_img2img() {
    showSubmitButtons('img2img', false);
@@ -170,6 +178,23 @@ function submit_img2img() {
    return res;
}

+function submit_extras() {
+    showSubmitButtons('extras', false);
+
+    var id = randomId();
+
+    requestProgress(id, gradioApp().getElementById('extras_gallery_container'), gradioApp().getElementById('extras_gallery'), function() {
+        showSubmitButtons('extras', true);
+    });
+
+    var res = create_submit_args(arguments);
+
+    res[0] = id;
+
+    console.log(res);
+
+    return res;
+}
+
function restoreProgressTxt2img() {
    showRestoreProgressButton("txt2img", false);
    var id = localGet("txt2img_task_id");
@@ -198,9 +223,33 @@ function restoreProgressImg2img() {
} }
/**
* Configure the width and height elements on `tabname` to accept
* pasting of resolutions in the form of "width x height".
*/
function setupResolutionPasting(tabname) {
var width = gradioApp().querySelector(`#${tabname}_width input[type=number]`);
var height = gradioApp().querySelector(`#${tabname}_height input[type=number]`);
for (const el of [width, height]) {
el.addEventListener('paste', function(event) {
var pasteData = event.clipboardData.getData('text/plain');
var parsed = pasteData.match(/^\s*(\d+)\D+(\d+)\s*$/);
if (parsed) {
width.value = parsed[1];
height.value = parsed[2];
updateInput(width);
updateInput(height);
event.preventDefault();
}
});
}
}
onUiLoaded(function() {
showRestoreProgressButton('txt2img', localGet("txt2img_task_id"));
showRestoreProgressButton('img2img', localGet("img2img_task_id"));
setupResolutionPasting('txt2img');
setupResolutionPasting('img2img');
});
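Because the separator in the pattern above is \D+ (any run of non-digits), pastes like "512x768", "1024 × 1536" or "800, 600" all parse to a width/height pair. The same pattern, checked in Python:

import re

resolution = re.compile(r"^\s*(\d+)\D+(\d+)\s*$")  # same pattern as the JS above
for text in ["512x768", "1024 × 1536", "800, 600", "not a size"]:
    match = resolution.match(text)
    print(text, "->", match.groups() if match else None)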
+136 -39
@@ -17,15 +17,13 @@ from fastapi.encoders import jsonable_encoder
from secrets import compare_digest
import modules.shared as shared
from modules import sd_samplers, deepbooru, sd_hijack, images, scripts, ui, postprocessing, errors, restart, shared_items, script_callbacks, generation_parameters_copypaste
from modules import sd_samplers, deepbooru, sd_hijack, images, scripts, ui, postprocessing, errors, restart, shared_items, script_callbacks, infotext_utils, sd_models
from modules.api import models
from modules.shared import opts
from modules.processing import StableDiffusionProcessingTxt2Img, StableDiffusionProcessingImg2Img, process_images
from modules.textual_inversion.textual_inversion import create_embedding, train_embedding
from modules.textual_inversion.preprocess import preprocess
from modules.hypernetworks.hypernetwork import create_hypernetwork, train_hypernetwork
from PIL import PngImagePlugin,Image
from PIL import PngImagePlugin, Image
from modules.sd_models import unload_model_weights, reload_model_weights, checkpoint_aliases
from modules.sd_models_config import find_checkpoint_config_near_filename
from modules.realesrgan_model import get_realesrgan_models
from modules import devices
@@ -33,7 +31,7 @@ from typing import Any
import piexif
import piexif.helper
from contextlib import closing
from modules.progress import create_task_id, add_task_to_queue, start_task, finish_task, current_task
def script_name_to_index(name, scripts):
try:
@@ -103,7 +101,8 @@ def decode_base64_to_image(encoding):
def encode_pil_to_base64(image):
with io.BytesIO() as output_bytes:
if isinstance(image, str):
return image
if opts.samples_format.lower() == 'png':
use_metadata = False
metadata = PngImagePlugin.PngInfo()
@@ -235,7 +234,6 @@ class Api:
self.add_api_route("/sdapi/v1/refresh-vae", self.refresh_vae, methods=["POST"])
self.add_api_route("/sdapi/v1/create/embedding", self.create_embedding, methods=["POST"], response_model=models.CreateResponse)
self.add_api_route("/sdapi/v1/create/hypernetwork", self.create_hypernetwork, methods=["POST"], response_model=models.CreateResponse)
self.add_api_route("/sdapi/v1/preprocess", self.preprocess, methods=["POST"], response_model=models.PreprocessResponse)
self.add_api_route("/sdapi/v1/train/embedding", self.train_embedding, methods=["POST"], response_model=models.TrainResponse)
self.add_api_route("/sdapi/v1/train/hypernetwork", self.train_hypernetwork, methods=["POST"], response_model=models.TrainResponse)
self.add_api_route("/sdapi/v1/memory", self.get_memory, methods=["GET"], response_model=models.MemoryResponse)
@@ -253,6 +251,24 @@ class Api:
self.default_script_arg_txt2img = []
self.default_script_arg_img2img = []
txt2img_script_runner = scripts.scripts_txt2img
img2img_script_runner = scripts.scripts_img2img
if not txt2img_script_runner.scripts or not img2img_script_runner.scripts:
ui.create_ui()
if not txt2img_script_runner.scripts:
txt2img_script_runner.initialize_scripts(False)
if not self.default_script_arg_txt2img:
self.default_script_arg_txt2img = self.init_default_script_args(txt2img_script_runner)
if not img2img_script_runner.scripts:
img2img_script_runner.initialize_scripts(True)
if not self.default_script_arg_img2img:
self.default_script_arg_img2img = self.init_default_script_args(img2img_script_runner)
def add_api_route(self, path: str, endpoint, **kwargs):
if shared.cmd_opts.api_auth:
return self.app.add_api_route(path, endpoint, dependencies=[Depends(self.auth)], **kwargs)
@@ -314,8 +330,13 @@ class Api:
script_args[script.args_from:script.args_to] = ui_default_values
return script_args
def init_script_args(self, request, default_script_args, selectable_scripts, selectable_idx, script_runner):
def init_script_args(self, request, default_script_args, selectable_scripts, selectable_idx, script_runner, *, input_script_args=None):
script_args = default_script_args.copy()
if input_script_args is not None:
for index, value in input_script_args.items():
script_args[index] = value
# position 0 in script_arg is the idx+1 of the selectable script that is going to be run when using scripts.scripts_*2img.run()
if selectable_scripts:
script_args[selectable_scripts.args_from:selectable_scripts.args_to] = request.script_args
@@ -337,13 +358,83 @@ class Api:
script_args[alwayson_script.args_from + idx] = request.alwayson_scripts[alwayson_script_name]["args"][idx]
return script_args
def apply_infotext(self, request, tabname, *, script_runner=None, mentioned_script_args=None):
"""Processes `infotext` field from the `request`, and sets other fields of the `request` according to what's in infotext.
If request already has a field set, and that field is encountered in infotext too, the value from infotext is ignored.
Additionally, fills `mentioned_script_args` dict with index: value pairs for script arguments read from infotext.
"""
if not request.infotext:
return {}
possible_fields = infotext_utils.paste_fields[tabname]["fields"]
set_fields = request.model_dump(exclude_unset=True) if hasattr(request, "request") else request.dict(exclude_unset=True) # pydantic v1/v2 have different names for this
params = infotext_utils.parse_generation_parameters(request.infotext)
def get_field_value(field, params):
value = field.function(params) if field.function else params.get(field.label)
if value is None:
return None
if field.api in request.__fields__:
target_type = request.__fields__[field.api].type_
else:
target_type = type(field.component.value)
if target_type == type(None):
return None
if isinstance(value, dict) and value.get('__type__') == 'generic_update': # this is a gradio.update rather than a value
value = value.get('value')
if value is not None and not isinstance(value, target_type):
value = target_type(value)
return value
for field in possible_fields:
if not field.api:
continue
if field.api in set_fields:
continue
value = get_field_value(field, params)
if value is not None:
setattr(request, field.api, value)
if request.override_settings is None:
request.override_settings = {}
overriden_settings = infotext_utils.get_override_settings(params)
for _, setting_name, value in overriden_settings:
if setting_name not in request.override_settings:
request.override_settings[setting_name] = value
if script_runner is not None and mentioned_script_args is not None:
indexes = {v: i for i, v in enumerate(script_runner.inputs)}
script_fields = ((field, indexes[field.component]) for field in possible_fields if field.component in indexes)
for field, index in script_fields:
value = get_field_value(field, params)
if value is None:
continue
mentioned_script_args[index] = value
return params
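In practice this lets an API client send a whole infotext blob instead of enumerating parameters, with explicitly set request fields taking precedence over values parsed from the infotext. A minimal request sketch; the endpoint and the infotext/steps fields appear in this diff, while the server URL and prompt values are illustrative:

import requests

payload = {
    "infotext": "a cat\nNegative prompt: blurry\nSteps: 20, Sampler: Euler a, CFG scale: 7, Seed: 1, Size: 512x512",
    "steps": 30,  # explicitly set fields win over the "Steps: 20" parsed from infotext
}
response = requests.post("http://127.0.0.1:7860/sdapi/v1/txt2img", json=payload)
print(response.json()["info"])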
def text2imgapi(self, txt2imgreq: models.StableDiffusionTxt2ImgProcessingAPI):
task_id = txt2imgreq.force_task_id or create_task_id("txt2img")
script_runner = scripts.scripts_txt2img
if not script_runner.scripts:
script_runner.initialize_scripts(False)
ui.create_ui()
if not self.default_script_arg_txt2img:
self.default_script_arg_txt2img = self.init_default_script_args(script_runner)
infotext_script_args = {}
self.apply_infotext(txt2imgreq, "txt2img", script_runner=script_runner, mentioned_script_args=infotext_script_args)
selectable_scripts, selectable_script_idx = self.get_selectable_script(txt2imgreq.script_name, script_runner)
populate = txt2imgreq.copy(update={ # Override __init__ params
@@ -358,12 +449,15 @@ class Api:
args.pop('script_name', None)
args.pop('script_args', None) # will refeed them to the pipeline directly after initializing them
args.pop('alwayson_scripts', None)
args.pop('infotext', None)
script_args = self.init_script_args(txt2imgreq, self.default_script_arg_txt2img, selectable_scripts, selectable_script_idx, script_runner)
script_args = self.init_script_args(txt2imgreq, self.default_script_arg_txt2img, selectable_scripts, selectable_script_idx, script_runner, input_script_args=infotext_script_args)
send_images = args.pop('send_images', True)
args.pop('save_images', None)
add_task_to_queue(task_id)
with self.queue_lock:
with closing(StableDiffusionProcessingTxt2Img(sd_model=shared.sd_model, **args)) as p:
p.is_api = True
@@ -373,12 +467,14 @@ class Api:
try:
shared.state.begin(job="scripts_txt2img")
start_task(task_id)
if selectable_scripts is not None:
p.script_args = script_args
processed = scripts.scripts_txt2img.run(p, *p.script_args) # Need to pass args as list here
else:
p.script_args = tuple(script_args) # Need to pass args as tuple here
processed = process_images(p)
finish_task(task_id)
finally:
shared.state.end()
shared.total_tqdm.clear()
@@ -388,6 +484,8 @@ class Api:
return models.TextToImageResponse(images=b64images, parameters=vars(txt2imgreq), info=processed.js())
def img2imgapi(self, img2imgreq: models.StableDiffusionImg2ImgProcessingAPI):
task_id = img2imgreq.force_task_id or create_task_id("img2img")
init_images = img2imgreq.init_images
if init_images is None:
raise HTTPException(status_code=404, detail="Init image not found")
@@ -397,11 +495,10 @@ class Api:
mask = decode_base64_to_image(mask)
script_runner = scripts.scripts_img2img
if not script_runner.scripts:
script_runner.initialize_scripts(True)
ui.create_ui()
if not self.default_script_arg_img2img:
self.default_script_arg_img2img = self.init_default_script_args(script_runner)
infotext_script_args = {}
self.apply_infotext(img2imgreq, "img2img", script_runner=script_runner, mentioned_script_args=infotext_script_args)
selectable_scripts, selectable_script_idx = self.get_selectable_script(img2imgreq.script_name, script_runner)
populate = img2imgreq.copy(update={ # Override __init__ params
@@ -418,12 +515,15 @@ class Api:
args.pop('script_name', None)
args.pop('script_args', None) # will refeed them to the pipeline directly after initializing them
args.pop('alwayson_scripts', None)
args.pop('infotext', None)
script_args = self.init_script_args(img2imgreq, self.default_script_arg_img2img, selectable_scripts, selectable_script_idx, script_runner)
script_args = self.init_script_args(img2imgreq, self.default_script_arg_img2img, selectable_scripts, selectable_script_idx, script_runner, input_script_args=infotext_script_args)
send_images = args.pop('send_images', True)
args.pop('save_images', None)
add_task_to_queue(task_id)
with self.queue_lock:
with closing(StableDiffusionProcessingImg2Img(sd_model=shared.sd_model, **args)) as p:
p.init_images = [decode_base64_to_image(x) for x in init_images]
@@ -434,12 +534,14 @@ class Api:
try:
shared.state.begin(job="scripts_img2img")
start_task(task_id)
if selectable_scripts is not None:
p.script_args = script_args
processed = scripts.scripts_img2img.run(p, *p.script_args) # Need to pass args as list here
else:
p.script_args = tuple(script_args) # Need to pass args as tuple here
processed = process_images(p)
finish_task(task_id)
finally:
shared.state.end()
shared.total_tqdm.clear()
@@ -482,7 +584,7 @@ class Api:
if geninfo is None:
geninfo = ""
params = generation_parameters_copypaste.parse_generation_parameters(geninfo)
params = infotext_utils.parse_generation_parameters(geninfo)
script_callbacks.infotext_pasted_callback(geninfo, params)
return models.PNGInfoResponse(info=geninfo, items=items, parameters=params)
@@ -513,7 +615,7 @@ class Api:
if shared.state.current_image and not req.skip_current_image:
current_image = encode_pil_to_base64(shared.state.current_image)
return models.ProgressResponse(progress=progress, eta_relative=eta_relative, state=shared.state.dict(), current_image=current_image, textinfo=shared.state.textinfo)
return models.ProgressResponse(progress=progress, eta_relative=eta_relative, state=shared.state.dict(), current_image=current_image, textinfo=shared.state.textinfo, current_task=current_task)
def interrogateapi(self, interrogatereq: models.InterrogateRequest):
image_b64 = interrogatereq.image
@@ -540,12 +642,12 @@ class Api:
return {}
def unloadapi(self):
unload_model_weights()
sd_models.unload_model_weights()
return {}
def reloadapi(self):
reload_model_weights()
sd_models.send_model_to_device(shared.sd_model)
return {}
@@ -565,7 +667,7 @@ class Api:
def set_config(self, req: dict[str, Any]):
checkpoint_name = req.get("sd_model_checkpoint", None)
if checkpoint_name is not None and checkpoint_name not in checkpoint_aliases:
if checkpoint_name is not None and checkpoint_name not in sd_models.checkpoint_aliases:
raise RuntimeError(f"model {checkpoint_name!r} not found")
for k, v in req.items():
@@ -675,19 +777,6 @@ class Api:
finally:
shared.state.end()
def preprocess(self, args: dict):
try:
shared.state.begin(job="preprocess")
preprocess(**args) # quick operation unless blip/booru interrogation is enabled
shared.state.end()
return models.PreprocessResponse(info='preprocess complete')
except KeyError as e:
return models.PreprocessResponse(info=f"preprocess error: invalid token: {e}")
except Exception as e:
return models.PreprocessResponse(info=f"preprocess error: {e}")
finally:
shared.state.end()
def train_embedding(self, args: dict):
try:
shared.state.begin(job="train_embedding")
@@ -790,7 +879,15 @@ class Api:
def launch(self, server_name, port, root_path):
self.app.include_router(self.router)
uvicorn.run(self.app, host=server_name, port=port, timeout_keep_alive=shared.cmd_opts.timeout_keep_alive, root_path=root_path)
uvicorn.run(
self.app,
host=server_name,
port=port,
timeout_keep_alive=shared.cmd_opts.timeout_keep_alive,
root_path=root_path,
ssl_keyfile=shared.cmd_opts.tls_keyfile,
ssl_certfile=shared.cmd_opts.tls_certfile
)
def kill_webui(self):
restart.stop_program()
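The task-queue plumbing added above (create_task_id, add_task_to_queue, start_task, finish_task) also lets a client pick its own id via force_task_id and recognize it in the progress endpoint, whose response now includes current_task. A rough sketch, assuming the generation request runs on another thread while this polls; the URL and id format are illustrative:

import requests

base = "http://127.0.0.1:7860"    # hypothetical local server
task_id = "task(my-client-123)"   # passed as force_task_id in the generation request

progress = requests.get(f"{base}/sdapi/v1/progress").json()
if progress.get("current_task") == task_id:  # current_task added in this diff
    print(f"{progress['progress'] * 100:.0f}% done")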
+4 -3
@@ -107,6 +107,8 @@ StableDiffusionTxt2ImgProcessingAPI = PydanticModelGenerator(
{"key": "send_images", "type": bool, "default": True},
{"key": "save_images", "type": bool, "default": False},
{"key": "alwayson_scripts", "type": dict, "default": {}},
{"key": "force_task_id", "type": str, "default": None},
{"key": "infotext", "type": str, "default": None},
]
).generate_model()
@@ -124,6 +126,8 @@ StableDiffusionImg2ImgProcessingAPI = PydanticModelGenerator(
{"key": "send_images", "type": bool, "default": True},
{"key": "save_images", "type": bool, "default": False},
{"key": "alwayson_scripts", "type": dict, "default": {}},
{"key": "force_task_id", "type": str, "default": None},
{"key": "infotext", "type": str, "default": None},
]
).generate_model()
@@ -202,9 +206,6 @@ class TrainResponse(BaseModel):
class CreateResponse(BaseModel):
info: str = Field(title="Create info", description="Response string from create embedding or hypernetwork task.")
class PreprocessResponse(BaseModel):
info: str = Field(title="Preprocess info", description="Response string from preprocessing task.")
fields = {}
for key, metadata in opts.data_labels.items():
value = opts.data.get(key)
+9 -10
@@ -32,7 +32,7 @@ def dump_cache():
with cache_lock:
cache_filename_tmp = cache_filename + "-"
with open(cache_filename_tmp, "w", encoding="utf8") as file:
json.dump(cache_data, file, indent=4)
json.dump(cache_data, file, indent=4, ensure_ascii=False)
os.replace(cache_filename_tmp, cache_filename)
@@ -62,16 +62,15 @@ def cache(subsection):
if cache_data is None:
with cache_lock:
if cache_data is None:
if not os.path.isfile(cache_filename):
cache_data = {}
else:
try:
with open(cache_filename, "r", encoding="utf8") as file:
cache_data = json.load(file)
except Exception:
os.replace(cache_filename, os.path.join(script_path, "tmp", "cache.json"))
print('[ERROR] issue occurred while trying to read cache.json, move current cache to tmp/cache.json and create new cache')
cache_data = {}
try:
with open(cache_filename, "r", encoding="utf8") as file:
cache_data = json.load(file)
except FileNotFoundError:
cache_data = {}
except Exception:
os.replace(cache_filename, os.path.join(script_path, "tmp", "cache.json"))
print('[ERROR] issue occurred while trying to read cache.json, move current cache to tmp/cache.json and create new cache')
cache_data = {}
s = cache_data.get(subsection, {})
cache_data[subsection] = s
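The ensure_ascii=False addition keeps non-ASCII model and file names human-readable in cache.json instead of escaping them. A quick comparison:

import json

data = {"模型.safetensors": {"sha256": "abc123"}}
print(json.dumps(data, indent=4))                      # name escaped as \u6a21\u578b...
print(json.dumps(data, indent=4, ensure_ascii=False))  # name written verbatim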
+1
@@ -78,6 +78,7 @@ def wrap_gradio_call(func, extra_outputs=None, add_stats=False):
shared.state.skipped = False
shared.state.interrupted = False
shared.state.stopping_generation = False
shared.state.job_count = 0
if not add_stats:
+5 -2
@@ -70,13 +70,16 @@ parser.add_argument("--opt-sdp-no-mem-attention", action='store_true', help="pre
parser.add_argument("--disable-opt-split-attention", action='store_true', help="prefer no cross-attention layer optimization for automatic choice of optimization")
parser.add_argument("--disable-nan-check", action='store_true', help="do not check if produced images/latent spaces have nans; useful for running without a checkpoint in CI")
parser.add_argument("--use-cpu", nargs='+', help="use CPU as torch device for specified modules", default=[], type=str.lower)
parser.add_argument("--use-ipex", action="store_true", help="use Intel XPU as torch device")
parser.add_argument("--disable-model-loading-ram-optimization", action='store_true', help="disable an optimization that reduces RAM use when loading a model")
parser.add_argument("--listen", action='store_true', help="launch gradio with 0.0.0.0 as server name, allowing to respond to network requests")
parser.add_argument("--port", type=int, help="launch gradio with given server port, you need root/admin rights for ports < 1024, defaults to 7860 if available", default=None)
parser.add_argument("--show-negative-prompt", action='store_true', help="does not do anything", default=False)
parser.add_argument("--ui-config-file", type=str, help="filename to use for ui configuration", default=os.path.join(data_path, 'ui-config.json'))
parser.add_argument("--hide-ui-dir-config", action='store_true', help="hide directory configuration from webui", default=False)
parser.add_argument("--freeze-settings", action='store_true', help="disable editing settings", default=False)
parser.add_argument("--freeze-settings", action='store_true', help="disable editing of all settings globally", default=False)
parser.add_argument("--freeze-settings-in-sections", type=str, help='disable editing settings in specific sections of the settings page by specifying a comma-delimited list such like "saving-images,upscaling". The list of setting names can be found in the modules/shared_options.py file', default=None)
parser.add_argument("--freeze-specific-settings", type=str, help='disable editing of individual settings by specifying a comma-delimited list like "samples_save,samples_format". The list of setting names can be found in the config.json file', default=None)
parser.add_argument("--ui-settings-file", type=str, help="filename to use for ui settings", default=os.path.join(data_path, 'config.json'))
parser.add_argument("--gradio-debug", action='store_true', help="launch gradio with --debug option")
parser.add_argument("--gradio-auth", type=str, help='set gradio authentication like "username:password"; or comma-delimit multiple like "u1:p1,u2:p2,u3:p3"', default=None)
@@ -107,7 +110,7 @@ parser.add_argument("--tls-certfile", type=str, help="Partially enables TLS, req
parser.add_argument("--disable-tls-verify", action="store_false", help="When passed, enables the use of self-signed certificates.", default=None)
parser.add_argument("--server-name", type=str, help="Sets hostname of server", default=None)
parser.add_argument("--gradio-queue", action='store_true', help="does not do anything", default=True)
parser.add_argument("--no-gradio-queue", action='store_true', help="Disables gradio queue; causes the webpage to use http requests instead of websockets; was the defaul in earlier versions")
parser.add_argument("--no-gradio-queue", action='store_true', help="Disables gradio queue; causes the webpage to use http requests instead of websockets; was the default in earlier versions")
parser.add_argument("--skip-version-check", action='store_true', help="Do not check versions of torch and xformers")
parser.add_argument("--no-hashing", action='store_true', help="disable sha256 hashing of checkpoints to help loading performance", default=False)
parser.add_argument("--no-download-sd-model", action='store_true', help="don't download SD1.5 model even if no model is found in --ckpt-dir", default=False)
-276
@@ -1,276 +0,0 @@
# this file is copied from CodeFormer repository. Please see comment in modules/codeformer_model.py
import math
import torch
from torch import nn, Tensor
import torch.nn.functional as F
from typing import Optional
from modules.codeformer.vqgan_arch import VQAutoEncoder, ResBlock
from basicsr.utils.registry import ARCH_REGISTRY
def calc_mean_std(feat, eps=1e-5):
"""Calculate mean and std for adaptive_instance_normalization.
Args:
feat (Tensor): 4D tensor.
eps (float): A small value added to the variance to avoid
divide-by-zero. Default: 1e-5.
"""
size = feat.size()
assert len(size) == 4, 'The input feature should be 4D tensor.'
b, c = size[:2]
feat_var = feat.view(b, c, -1).var(dim=2) + eps
feat_std = feat_var.sqrt().view(b, c, 1, 1)
feat_mean = feat.view(b, c, -1).mean(dim=2).view(b, c, 1, 1)
return feat_mean, feat_std
def adaptive_instance_normalization(content_feat, style_feat):
"""Adaptive instance normalization.
Adjust the reference features to have color and illumination similar
to those in the degraded features.
Args:
content_feat (Tensor): The reference feature.
style_feat (Tensor): The degraded features.
"""
size = content_feat.size()
style_mean, style_std = calc_mean_std(style_feat)
content_mean, content_std = calc_mean_std(content_feat)
normalized_feat = (content_feat - content_mean.expand(size)) / content_std.expand(size)
return normalized_feat * style_std.expand(size) + style_mean.expand(size)
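For reference, the two helpers above implement standard adaptive instance normalization: with per-channel statistics μ(·) and σ(·) computed by calc_mean_std over the spatial dimensions,

\mathrm{AdaIN}(x, y) = \sigma(y)\,\frac{x - \mu(x)}{\sigma(x)} + \mu(y)

where x is the content (reference) feature and y the style (degraded) feature, so x keeps its structure but adopts y's channel-wise mean and variance.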
class PositionEmbeddingSine(nn.Module):
"""
This is a more standard version of the position embedding, very similar to the one
used by the Attention is all you need paper, generalized to work on images.
"""
def __init__(self, num_pos_feats=64, temperature=10000, normalize=False, scale=None):
super().__init__()
self.num_pos_feats = num_pos_feats
self.temperature = temperature
self.normalize = normalize
if scale is not None and normalize is False:
raise ValueError("normalize should be True if scale is passed")
if scale is None:
scale = 2 * math.pi
self.scale = scale
def forward(self, x, mask=None):
if mask is None:
mask = torch.zeros((x.size(0), x.size(2), x.size(3)), device=x.device, dtype=torch.bool)
not_mask = ~mask
y_embed = not_mask.cumsum(1, dtype=torch.float32)
x_embed = not_mask.cumsum(2, dtype=torch.float32)
if self.normalize:
eps = 1e-6
y_embed = y_embed / (y_embed[:, -1:, :] + eps) * self.scale
x_embed = x_embed / (x_embed[:, :, -1:] + eps) * self.scale
dim_t = torch.arange(self.num_pos_feats, dtype=torch.float32, device=x.device)
dim_t = self.temperature ** (2 * (dim_t // 2) / self.num_pos_feats)
pos_x = x_embed[:, :, :, None] / dim_t
pos_y = y_embed[:, :, :, None] / dim_t
pos_x = torch.stack(
(pos_x[:, :, :, 0::2].sin(), pos_x[:, :, :, 1::2].cos()), dim=4
).flatten(3)
pos_y = torch.stack(
(pos_y[:, :, :, 0::2].sin(), pos_y[:, :, :, 1::2].cos()), dim=4
).flatten(3)
pos = torch.cat((pos_y, pos_x), dim=3).permute(0, 3, 1, 2)
return pos
def _get_activation_fn(activation):
"""Return an activation function given a string"""
if activation == "relu":
return F.relu
if activation == "gelu":
return F.gelu
if activation == "glu":
return F.glu
raise RuntimeError(F"activation should be relu/gelu, not {activation}.")
class TransformerSALayer(nn.Module):
def __init__(self, embed_dim, nhead=8, dim_mlp=2048, dropout=0.0, activation="gelu"):
super().__init__()
self.self_attn = nn.MultiheadAttention(embed_dim, nhead, dropout=dropout)
# Implementation of Feedforward model - MLP
self.linear1 = nn.Linear(embed_dim, dim_mlp)
self.dropout = nn.Dropout(dropout)
self.linear2 = nn.Linear(dim_mlp, embed_dim)
self.norm1 = nn.LayerNorm(embed_dim)
self.norm2 = nn.LayerNorm(embed_dim)
self.dropout1 = nn.Dropout(dropout)
self.dropout2 = nn.Dropout(dropout)
self.activation = _get_activation_fn(activation)
def with_pos_embed(self, tensor, pos: Optional[Tensor]):
return tensor if pos is None else tensor + pos
def forward(self, tgt,
tgt_mask: Optional[Tensor] = None,
tgt_key_padding_mask: Optional[Tensor] = None,
query_pos: Optional[Tensor] = None):
# self attention
tgt2 = self.norm1(tgt)
q = k = self.with_pos_embed(tgt2, query_pos)
tgt2 = self.self_attn(q, k, value=tgt2, attn_mask=tgt_mask,
key_padding_mask=tgt_key_padding_mask)[0]
tgt = tgt + self.dropout1(tgt2)
# ffn
tgt2 = self.norm2(tgt)
tgt2 = self.linear2(self.dropout(self.activation(self.linear1(tgt2))))
tgt = tgt + self.dropout2(tgt2)
return tgt
class Fuse_sft_block(nn.Module):
def __init__(self, in_ch, out_ch):
super().__init__()
self.encode_enc = ResBlock(2*in_ch, out_ch)
self.scale = nn.Sequential(
nn.Conv2d(in_ch, out_ch, kernel_size=3, padding=1),
nn.LeakyReLU(0.2, True),
nn.Conv2d(out_ch, out_ch, kernel_size=3, padding=1))
self.shift = nn.Sequential(
nn.Conv2d(in_ch, out_ch, kernel_size=3, padding=1),
nn.LeakyReLU(0.2, True),
nn.Conv2d(out_ch, out_ch, kernel_size=3, padding=1))
def forward(self, enc_feat, dec_feat, w=1):
enc_feat = self.encode_enc(torch.cat([enc_feat, dec_feat], dim=1))
scale = self.scale(enc_feat)
shift = self.shift(enc_feat)
residual = w * (dec_feat * scale + shift)
out = dec_feat + residual
return out
@ARCH_REGISTRY.register()
class CodeFormer(VQAutoEncoder):
def __init__(self, dim_embd=512, n_head=8, n_layers=9,
codebook_size=1024, latent_size=256,
connect_list=('32', '64', '128', '256'),
fix_modules=('quantize', 'generator')):
super(CodeFormer, self).__init__(512, 64, [1, 2, 2, 4, 4, 8], 'nearest',2, [16], codebook_size)
if fix_modules is not None:
for module in fix_modules:
for param in getattr(self, module).parameters():
param.requires_grad = False
self.connect_list = connect_list
self.n_layers = n_layers
self.dim_embd = dim_embd
self.dim_mlp = dim_embd*2
self.position_emb = nn.Parameter(torch.zeros(latent_size, self.dim_embd))
self.feat_emb = nn.Linear(256, self.dim_embd)
# transformer
self.ft_layers = nn.Sequential(*[TransformerSALayer(embed_dim=dim_embd, nhead=n_head, dim_mlp=self.dim_mlp, dropout=0.0)
for _ in range(self.n_layers)])
# logits_predict head
self.idx_pred_layer = nn.Sequential(
nn.LayerNorm(dim_embd),
nn.Linear(dim_embd, codebook_size, bias=False))
self.channels = {
'16': 512,
'32': 256,
'64': 256,
'128': 128,
'256': 128,
'512': 64,
}
# after second residual block for > 16, before attn layer for ==16
self.fuse_encoder_block = {'512':2, '256':5, '128':8, '64':11, '32':14, '16':18}
# after first residual block for > 16, before attn layer for ==16
self.fuse_generator_block = {'16':6, '32': 9, '64':12, '128':15, '256':18, '512':21}
# fuse_convs_dict
self.fuse_convs_dict = nn.ModuleDict()
for f_size in self.connect_list:
in_ch = self.channels[f_size]
self.fuse_convs_dict[f_size] = Fuse_sft_block(in_ch, in_ch)
def _init_weights(self, module):
if isinstance(module, (nn.Linear, nn.Embedding)):
module.weight.data.normal_(mean=0.0, std=0.02)
if isinstance(module, nn.Linear) and module.bias is not None:
module.bias.data.zero_()
elif isinstance(module, nn.LayerNorm):
module.bias.data.zero_()
module.weight.data.fill_(1.0)
def forward(self, x, w=0, detach_16=True, code_only=False, adain=False):
# ################### Encoder #####################
enc_feat_dict = {}
out_list = [self.fuse_encoder_block[f_size] for f_size in self.connect_list]
for i, block in enumerate(self.encoder.blocks):
x = block(x)
if i in out_list:
enc_feat_dict[str(x.shape[-1])] = x.clone()
lq_feat = x
# ################# Transformer ###################
# quant_feat, codebook_loss, quant_stats = self.quantize(lq_feat)
pos_emb = self.position_emb.unsqueeze(1).repeat(1,x.shape[0],1)
# BCHW -> BC(HW) -> (HW)BC
feat_emb = self.feat_emb(lq_feat.flatten(2).permute(2,0,1))
query_emb = feat_emb
# Transformer encoder
for layer in self.ft_layers:
query_emb = layer(query_emb, query_pos=pos_emb)
# output logits
logits = self.idx_pred_layer(query_emb) # (hw)bn
logits = logits.permute(1,0,2) # (hw)bn -> b(hw)n
if code_only: # for training stage II
# logits doesn't need softmax before cross_entropy loss
return logits, lq_feat
# ################# Quantization ###################
# if self.training:
# quant_feat = torch.einsum('btn,nc->btc', [soft_one_hot, self.quantize.embedding.weight])
# # b(hw)c -> bc(hw) -> bchw
# quant_feat = quant_feat.permute(0,2,1).view(lq_feat.shape)
# ------------
soft_one_hot = F.softmax(logits, dim=2)
_, top_idx = torch.topk(soft_one_hot, 1, dim=2)
quant_feat = self.quantize.get_codebook_feat(top_idx, shape=[x.shape[0],16,16,256])
# preserve gradients
# quant_feat = lq_feat + (quant_feat - lq_feat).detach()
if detach_16:
quant_feat = quant_feat.detach() # for training stage III
if adain:
quant_feat = adaptive_instance_normalization(quant_feat, lq_feat)
# ################## Generator ####################
x = quant_feat
fuse_list = [self.fuse_generator_block[f_size] for f_size in self.connect_list]
for i, block in enumerate(self.generator.blocks):
x = block(x)
if i in fuse_list: # fuse after i-th block
f_size = str(x.shape[-1])
if w>0:
x = self.fuse_convs_dict[f_size](enc_feat_dict[f_size].detach(), x, w)
out = x
# logits doesn't need softmax before cross_entropy loss
return out, logits, lq_feat
-435
@@ -1,435 +0,0 @@
# this file is copied from CodeFormer repository. Please see comment in modules/codeformer_model.py
'''
VQGAN code, adapted from the original created by the Unleashing Transformers authors:
https://github.com/samb-t/unleashing-transformers/blob/master/models/vqgan.py
'''
import torch
import torch.nn as nn
import torch.nn.functional as F
from basicsr.utils import get_root_logger
from basicsr.utils.registry import ARCH_REGISTRY
def normalize(in_channels):
return torch.nn.GroupNorm(num_groups=32, num_channels=in_channels, eps=1e-6, affine=True)
@torch.jit.script
def swish(x):
return x*torch.sigmoid(x)
# Define VQVAE classes
class VectorQuantizer(nn.Module):
def __init__(self, codebook_size, emb_dim, beta):
super(VectorQuantizer, self).__init__()
self.codebook_size = codebook_size # number of embeddings
self.emb_dim = emb_dim # dimension of embedding
self.beta = beta # commitment cost used in loss term, beta * ||z_e(x)-sg[e]||^2
self.embedding = nn.Embedding(self.codebook_size, self.emb_dim)
self.embedding.weight.data.uniform_(-1.0 / self.codebook_size, 1.0 / self.codebook_size)
def forward(self, z):
# reshape z -> (batch, height, width, channel) and flatten
z = z.permute(0, 2, 3, 1).contiguous()
z_flattened = z.view(-1, self.emb_dim)
# distances from z to embeddings e_j (z - e)^2 = z^2 + e^2 - 2 e * z
d = (z_flattened ** 2).sum(dim=1, keepdim=True) + (self.embedding.weight**2).sum(1) - \
2 * torch.matmul(z_flattened, self.embedding.weight.t())
mean_distance = torch.mean(d)
# find closest encodings
# min_encoding_indices = torch.argmin(d, dim=1).unsqueeze(1)
min_encoding_scores, min_encoding_indices = torch.topk(d, 1, dim=1, largest=False)
# [0-1], higher score, higher confidence
min_encoding_scores = torch.exp(-min_encoding_scores/10)
min_encodings = torch.zeros(min_encoding_indices.shape[0], self.codebook_size).to(z)
min_encodings.scatter_(1, min_encoding_indices, 1)
# get quantized latent vectors
z_q = torch.matmul(min_encodings, self.embedding.weight).view(z.shape)
# compute loss for embedding
loss = torch.mean((z_q.detach()-z)**2) + self.beta * torch.mean((z_q - z.detach()) ** 2)
# preserve gradients
z_q = z + (z_q - z).detach()
# perplexity
e_mean = torch.mean(min_encodings, dim=0)
perplexity = torch.exp(-torch.sum(e_mean * torch.log(e_mean + 1e-10)))
# reshape back to match original input shape
z_q = z_q.permute(0, 3, 1, 2).contiguous()
return z_q, loss, {
"perplexity": perplexity,
"min_encodings": min_encodings,
"min_encoding_indices": min_encoding_indices,
"min_encoding_scores": min_encoding_scores,
"mean_distance": mean_distance
}
def get_codebook_feat(self, indices, shape):
# input indices: batch*token_num -> (batch*token_num)*1
# shape: batch, height, width, channel
indices = indices.view(-1,1)
min_encodings = torch.zeros(indices.shape[0], self.codebook_size).to(indices)
min_encodings.scatter_(1, indices, 1)
# get quantized latent vectors
z_q = torch.matmul(min_encodings.float(), self.embedding.weight)
if shape is not None: # reshape back to match original input shape
z_q = z_q.view(shape).permute(0, 3, 1, 2).contiguous()
return z_q
class GumbelQuantizer(nn.Module):
def __init__(self, codebook_size, emb_dim, num_hiddens, straight_through=False, kl_weight=5e-4, temp_init=1.0):
super().__init__()
self.codebook_size = codebook_size # number of embeddings
self.emb_dim = emb_dim # dimension of embedding
self.straight_through = straight_through
self.temperature = temp_init
self.kl_weight = kl_weight
self.proj = nn.Conv2d(num_hiddens, codebook_size, 1) # projects last encoder layer to quantized logits
self.embed = nn.Embedding(codebook_size, emb_dim)
def forward(self, z):
hard = self.straight_through if self.training else True
logits = self.proj(z)
soft_one_hot = F.gumbel_softmax(logits, tau=self.temperature, dim=1, hard=hard)
z_q = torch.einsum("b n h w, n d -> b d h w", soft_one_hot, self.embed.weight)
# + kl divergence to the prior loss
qy = F.softmax(logits, dim=1)
diff = self.kl_weight * torch.sum(qy * torch.log(qy * self.codebook_size + 1e-10), dim=1).mean()
min_encoding_indices = soft_one_hot.argmax(dim=1)
return z_q, diff, {
"min_encoding_indices": min_encoding_indices
}
class Downsample(nn.Module):
def __init__(self, in_channels):
super().__init__()
self.conv = torch.nn.Conv2d(in_channels, in_channels, kernel_size=3, stride=2, padding=0)
def forward(self, x):
pad = (0, 1, 0, 1)
x = torch.nn.functional.pad(x, pad, mode="constant", value=0)
x = self.conv(x)
return x
class Upsample(nn.Module):
def __init__(self, in_channels):
super().__init__()
self.conv = nn.Conv2d(in_channels, in_channels, kernel_size=3, stride=1, padding=1)
def forward(self, x):
x = F.interpolate(x, scale_factor=2.0, mode="nearest")
x = self.conv(x)
return x
class ResBlock(nn.Module):
def __init__(self, in_channels, out_channels=None):
super(ResBlock, self).__init__()
self.in_channels = in_channels
self.out_channels = in_channels if out_channels is None else out_channels
self.norm1 = normalize(in_channels)
self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=1, padding=1)
self.norm2 = normalize(out_channels)
self.conv2 = nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=1, padding=1)
if self.in_channels != self.out_channels:
self.conv_out = nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=1, padding=0)
def forward(self, x_in):
x = x_in
x = self.norm1(x)
x = swish(x)
x = self.conv1(x)
x = self.norm2(x)
x = swish(x)
x = self.conv2(x)
if self.in_channels != self.out_channels:
x_in = self.conv_out(x_in)
return x + x_in
class AttnBlock(nn.Module):
def __init__(self, in_channels):
super().__init__()
self.in_channels = in_channels
self.norm = normalize(in_channels)
self.q = torch.nn.Conv2d(
in_channels,
in_channels,
kernel_size=1,
stride=1,
padding=0
)
self.k = torch.nn.Conv2d(
in_channels,
in_channels,
kernel_size=1,
stride=1,
padding=0
)
self.v = torch.nn.Conv2d(
in_channels,
in_channels,
kernel_size=1,
stride=1,
padding=0
)
self.proj_out = torch.nn.Conv2d(
in_channels,
in_channels,
kernel_size=1,
stride=1,
padding=0
)
def forward(self, x):
h_ = x
h_ = self.norm(h_)
q = self.q(h_)
k = self.k(h_)
v = self.v(h_)
# compute attention
b, c, h, w = q.shape
q = q.reshape(b, c, h*w)
q = q.permute(0, 2, 1)
k = k.reshape(b, c, h*w)
w_ = torch.bmm(q, k)
w_ = w_ * (int(c)**(-0.5))
w_ = F.softmax(w_, dim=2)
# attend to values
v = v.reshape(b, c, h*w)
w_ = w_.permute(0, 2, 1)
h_ = torch.bmm(v, w_)
h_ = h_.reshape(b, c, h, w)
h_ = self.proj_out(h_)
return x+h_
class Encoder(nn.Module):
def __init__(self, in_channels, nf, emb_dim, ch_mult, num_res_blocks, resolution, attn_resolutions):
super().__init__()
self.nf = nf
self.num_resolutions = len(ch_mult)
self.num_res_blocks = num_res_blocks
self.resolution = resolution
self.attn_resolutions = attn_resolutions
curr_res = self.resolution
in_ch_mult = (1,)+tuple(ch_mult)
blocks = []
# initial convolution
blocks.append(nn.Conv2d(in_channels, nf, kernel_size=3, stride=1, padding=1))
# residual and downsampling blocks, with attention on smaller res (16x16)
for i in range(self.num_resolutions):
block_in_ch = nf * in_ch_mult[i]
block_out_ch = nf * ch_mult[i]
for _ in range(self.num_res_blocks):
blocks.append(ResBlock(block_in_ch, block_out_ch))
block_in_ch = block_out_ch
if curr_res in attn_resolutions:
blocks.append(AttnBlock(block_in_ch))
if i != self.num_resolutions - 1:
blocks.append(Downsample(block_in_ch))
curr_res = curr_res // 2
# non-local attention block
blocks.append(ResBlock(block_in_ch, block_in_ch))
blocks.append(AttnBlock(block_in_ch))
blocks.append(ResBlock(block_in_ch, block_in_ch))
# normalise and convert to latent size
blocks.append(normalize(block_in_ch))
blocks.append(nn.Conv2d(block_in_ch, emb_dim, kernel_size=3, stride=1, padding=1))
self.blocks = nn.ModuleList(blocks)
def forward(self, x):
for block in self.blocks:
x = block(x)
return x
class Generator(nn.Module):
def __init__(self, nf, emb_dim, ch_mult, res_blocks, img_size, attn_resolutions):
super().__init__()
self.nf = nf
self.ch_mult = ch_mult
self.num_resolutions = len(self.ch_mult)
self.num_res_blocks = res_blocks
self.resolution = img_size
self.attn_resolutions = attn_resolutions
self.in_channels = emb_dim
self.out_channels = 3
block_in_ch = self.nf * self.ch_mult[-1]
curr_res = self.resolution // 2 ** (self.num_resolutions-1)
blocks = []
# initial conv
blocks.append(nn.Conv2d(self.in_channels, block_in_ch, kernel_size=3, stride=1, padding=1))
# non-local attention block
blocks.append(ResBlock(block_in_ch, block_in_ch))
blocks.append(AttnBlock(block_in_ch))
blocks.append(ResBlock(block_in_ch, block_in_ch))
for i in reversed(range(self.num_resolutions)):
block_out_ch = self.nf * self.ch_mult[i]
for _ in range(self.num_res_blocks):
blocks.append(ResBlock(block_in_ch, block_out_ch))
block_in_ch = block_out_ch
if curr_res in self.attn_resolutions:
blocks.append(AttnBlock(block_in_ch))
if i != 0:
blocks.append(Upsample(block_in_ch))
curr_res = curr_res * 2
blocks.append(normalize(block_in_ch))
blocks.append(nn.Conv2d(block_in_ch, self.out_channels, kernel_size=3, stride=1, padding=1))
self.blocks = nn.ModuleList(blocks)
def forward(self, x):
for block in self.blocks:
x = block(x)
return x
@ARCH_REGISTRY.register()
class VQAutoEncoder(nn.Module):
def __init__(self, img_size, nf, ch_mult, quantizer="nearest", res_blocks=2, attn_resolutions=None, codebook_size=1024, emb_dim=256,
beta=0.25, gumbel_straight_through=False, gumbel_kl_weight=1e-8, model_path=None):
super().__init__()
logger = get_root_logger()
self.in_channels = 3
self.nf = nf
self.n_blocks = res_blocks
self.codebook_size = codebook_size
self.embed_dim = emb_dim
self.ch_mult = ch_mult
self.resolution = img_size
self.attn_resolutions = attn_resolutions or [16]
self.quantizer_type = quantizer
self.encoder = Encoder(
self.in_channels,
self.nf,
self.embed_dim,
self.ch_mult,
self.n_blocks,
self.resolution,
self.attn_resolutions
)
if self.quantizer_type == "nearest":
self.beta = beta #0.25
self.quantize = VectorQuantizer(self.codebook_size, self.embed_dim, self.beta)
elif self.quantizer_type == "gumbel":
self.gumbel_num_hiddens = emb_dim
self.straight_through = gumbel_straight_through
self.kl_weight = gumbel_kl_weight
self.quantize = GumbelQuantizer(
self.codebook_size,
self.embed_dim,
self.gumbel_num_hiddens,
self.straight_through,
self.kl_weight
)
self.generator = Generator(
self.nf,
self.embed_dim,
self.ch_mult,
self.n_blocks,
self.resolution,
self.attn_resolutions
)
if model_path is not None:
chkpt = torch.load(model_path, map_location='cpu')
if 'params_ema' in chkpt:
self.load_state_dict(torch.load(model_path, map_location='cpu')['params_ema'])
logger.info(f'vqgan is loaded from: {model_path} [params_ema]')
elif 'params' in chkpt:
self.load_state_dict(torch.load(model_path, map_location='cpu')['params'])
logger.info(f'vqgan is loaded from: {model_path} [params]')
else:
raise ValueError('Wrong params!')
def forward(self, x):
x = self.encoder(x)
quant, codebook_loss, quant_stats = self.quantize(x)
x = self.generator(quant)
return x, codebook_loss, quant_stats
# patch based discriminator
@ARCH_REGISTRY.register()
class VQGANDiscriminator(nn.Module):
def __init__(self, nc=3, ndf=64, n_layers=4, model_path=None):
super().__init__()
layers = [nn.Conv2d(nc, ndf, kernel_size=4, stride=2, padding=1), nn.LeakyReLU(0.2, True)]
ndf_mult = 1
ndf_mult_prev = 1
for n in range(1, n_layers): # gradually increase the number of filters
ndf_mult_prev = ndf_mult
ndf_mult = min(2 ** n, 8)
layers += [
nn.Conv2d(ndf * ndf_mult_prev, ndf * ndf_mult, kernel_size=4, stride=2, padding=1, bias=False),
nn.BatchNorm2d(ndf * ndf_mult),
nn.LeakyReLU(0.2, True)
]
ndf_mult_prev = ndf_mult
ndf_mult = min(2 ** n_layers, 8)
layers += [
nn.Conv2d(ndf * ndf_mult_prev, ndf * ndf_mult, kernel_size=4, stride=1, padding=1, bias=False),
nn.BatchNorm2d(ndf * ndf_mult),
nn.LeakyReLU(0.2, True)
]
layers += [
nn.Conv2d(ndf * ndf_mult, 1, kernel_size=4, stride=1, padding=1)] # output 1 channel prediction map
self.main = nn.Sequential(*layers)
if model_path is not None:
chkpt = torch.load(model_path, map_location='cpu')
if 'params_d' in chkpt:
self.load_state_dict(torch.load(model_path, map_location='cpu')['params_d'])
elif 'params' in chkpt:
self.load_state_dict(torch.load(model_path, map_location='cpu')['params'])
else:
raise ValueError('Wrong params!')
def forward(self, x):
return self.main(x)
+49 -117
@@ -1,132 +1,64 @@
import os
import cv2
import torch
import modules.face_restoration
import modules.shared
from modules import shared, devices, modelloader, errors
from modules.paths import models_path
from __future__ import annotations
import logging
import torch
from modules import (
devices,
errors,
face_restoration,
face_restoration_utils,
modelloader,
shared,
)
logger = logging.getLogger(__name__)
# codeformer people made a choice to include modified basicsr library to their project which makes
# it utterly impossible to use it alongside with other libraries that also use basicsr, like GFPGAN.
# I am making a choice to include some files from codeformer to work around this issue.
model_dir = "Codeformer"
model_path = os.path.join(models_path, model_dir)
model_url = 'https://github.com/sczhou/CodeFormer/releases/download/v0.1.0/codeformer.pth'
model_download_name = 'codeformer-v0.1.0.pth'
codeformer = None
# used by e.g. postprocessing_codeformer.py
codeformer: face_restoration.FaceRestoration | None = None
def setup_model(dirname):
os.makedirs(model_path, exist_ok=True)
path = modules.paths.paths.get("CodeFormer", None)
if path is None:
return
class FaceRestorerCodeFormer(face_restoration_utils.CommonFaceRestoration):
def name(self):
return "CodeFormer"
def load_net(self) -> torch.Module:
for model_path in modelloader.load_models(
model_path=self.model_path,
model_url=model_url,
command_path=self.model_path,
download_name=model_download_name,
ext_filter=['.pth'],
):
return modelloader.load_spandrel_model(
model_path,
device=devices.device_codeformer,
expected_architecture='CodeFormer',
).model
raise ValueError("No codeformer model found")
def get_device(self):
return devices.device_codeformer
def restore(self, np_image, w: float | None = None):
if w is None:
w = getattr(shared.opts, "code_former_weight", 0.5)
def restore_face(cropped_face_t):
assert self.net is not None
return self.net(cropped_face_t, w=w, adain=True)[0]
return self.restore_with_helper(np_image, restore_face)
def setup_model(dirname: str) -> None:
global codeformer
try:
from torchvision.transforms.functional import normalize
from modules.codeformer.codeformer_arch import CodeFormer
from basicsr.utils import img2tensor, tensor2img
from facelib.utils.face_restoration_helper import FaceRestoreHelper
from facelib.detection.retinaface import retinaface
net_class = CodeFormer
class FaceRestorerCodeFormer(modules.face_restoration.FaceRestoration):
def name(self):
return "CodeFormer"
def __init__(self, dirname):
self.net = None
self.face_helper = None
self.cmd_dir = dirname
def create_models(self):
if self.net is not None and self.face_helper is not None:
self.net.to(devices.device_codeformer)
return self.net, self.face_helper
model_paths = modelloader.load_models(model_path, model_url, self.cmd_dir, download_name='codeformer-v0.1.0.pth', ext_filter=['.pth'])
if len(model_paths) != 0:
ckpt_path = model_paths[0]
else:
print("Unable to load codeformer model.")
return None, None
net = net_class(dim_embd=512, codebook_size=1024, n_head=8, n_layers=9, connect_list=['32', '64', '128', '256']).to(devices.device_codeformer)
checkpoint = torch.load(ckpt_path)['params_ema']
net.load_state_dict(checkpoint)
net.eval()
if hasattr(retinaface, 'device'):
retinaface.device = devices.device_codeformer
face_helper = FaceRestoreHelper(1, face_size=512, crop_ratio=(1, 1), det_model='retinaface_resnet50', save_ext='png', use_parse=True, device=devices.device_codeformer)
self.net = net
self.face_helper = face_helper
return net, face_helper
def send_model_to(self, device):
self.net.to(device)
self.face_helper.face_det.to(device)
self.face_helper.face_parse.to(device)
def restore(self, np_image, w=None):
np_image = np_image[:, :, ::-1]
original_resolution = np_image.shape[0:2]
self.create_models()
if self.net is None or self.face_helper is None:
return np_image
self.send_model_to(devices.device_codeformer)
self.face_helper.clean_all()
self.face_helper.read_image(np_image)
self.face_helper.get_face_landmarks_5(only_center_face=False, resize=640, eye_dist_threshold=5)
self.face_helper.align_warp_face()
for cropped_face in self.face_helper.cropped_faces:
cropped_face_t = img2tensor(cropped_face / 255., bgr2rgb=True, float32=True)
normalize(cropped_face_t, (0.5, 0.5, 0.5), (0.5, 0.5, 0.5), inplace=True)
cropped_face_t = cropped_face_t.unsqueeze(0).to(devices.device_codeformer)
try:
with torch.no_grad():
output = self.net(cropped_face_t, w=w if w is not None else shared.opts.code_former_weight, adain=True)[0]
restored_face = tensor2img(output, rgb2bgr=True, min_max=(-1, 1))
del output
devices.torch_gc()
except Exception:
errors.report('Failed inference for CodeFormer', exc_info=True)
restored_face = tensor2img(cropped_face_t, rgb2bgr=True, min_max=(-1, 1))
restored_face = restored_face.astype('uint8')
self.face_helper.add_restored_face(restored_face)
self.face_helper.get_inverse_affine(None)
restored_img = self.face_helper.paste_faces_to_input_image()
restored_img = restored_img[:, :, ::-1]
if original_resolution != restored_img.shape[0:2]:
restored_img = cv2.resize(restored_img, (0, 0), fx=original_resolution[1]/restored_img.shape[1], fy=original_resolution[0]/restored_img.shape[0], interpolation=cv2.INTER_LINEAR)
self.face_helper.clean_all()
if shared.opts.face_restoration_unload:
self.send_model_to(devices.cpu)
return restored_img
codeformer = FaceRestorerCodeFormer(dirname)
shared.face_restorers.append(codeformer)
except Exception:
errors.report("Error setting up CodeFormer", exc_info=True)
# sys.path = stored_sys_path
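A minimal usage sketch of the refactored restorer (not part of the diff; the models directory path is an assumption):

    from modules import codeformer_model, shared

    codeformer_model.setup_model("models/Codeformer")  # registers the restorer
    restorer = shared.face_restorers[-1]               # FaceRestorerCodeFormer
    # restored = restorer.restore(np_image)            # np_image: HxWx3 uint8 array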
+79
@@ -0,0 +1,79 @@
import os
from modules import modelloader, errors
from modules.shared import cmd_opts, opts
from modules.upscaler import Upscaler, UpscalerData
from modules.upscaler_utils import upscale_with_model
class UpscalerDAT(Upscaler):
def __init__(self, user_path):
self.name = "DAT"
self.user_path = user_path
self.scalers = []
super().__init__()
for file in self.find_models(ext_filter=[".pt", ".pth"]):
name = modelloader.friendly_name(file)
scaler_data = UpscalerData(name, file, upscaler=self, scale=None)
self.scalers.append(scaler_data)
for model in get_dat_models(self):
if model.name in opts.dat_enabled_models:
self.scalers.append(model)
def do_upscale(self, img, path):
try:
info = self.load_model(path)
except Exception:
errors.report(f"Unable to load DAT model {path}", exc_info=True)
return img
model_descriptor = modelloader.load_spandrel_model(
info.local_data_path,
device=self.device,
prefer_half=(not cmd_opts.no_half and not cmd_opts.upcast_sampling),
expected_architecture="DAT",
)
return upscale_with_model(
model_descriptor,
img,
tile_size=opts.DAT_tile,
tile_overlap=opts.DAT_tile_overlap,
)
def load_model(self, path):
for scaler in self.scalers:
if scaler.data_path == path:
if scaler.local_data_path.startswith("http"):
scaler.local_data_path = modelloader.load_file_from_url(
scaler.data_path,
model_dir=self.model_download_path,
)
if not os.path.exists(scaler.local_data_path):
raise FileNotFoundError(f"DAT data missing: {scaler.local_data_path}")
return scaler
raise ValueError(f"Unable to find model info: {path}")
def get_dat_models(scaler):
return [
UpscalerData(
name="DAT x2",
path="https://github.com/n0kovo/dat_upscaler_models/raw/main/DAT/DAT_x2.pth",
scale=2,
upscaler=scaler,
),
UpscalerData(
name="DAT x3",
path="https://github.com/n0kovo/dat_upscaler_models/raw/main/DAT/DAT_x3.pth",
scale=3,
upscaler=scaler,
),
UpscalerData(
name="DAT x4",
path="https://github.com/n0kovo/dat_upscaler_models/raw/main/DAT/DAT_x4.pth",
scale=4,
upscaler=scaler,
),
]
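An illustrative sketch of how the DAT scaler list above might be exercised (not part of the diff; the input filename and model choice are assumptions):

    from PIL import Image

    upscaler = UpscalerDAT(user_path=None)  # collects local models plus the three remote entries
    dat_x4 = next(s for s in upscaler.scalers if s.name == "DAT x4")
    img = Image.open("input.png")
    result = upscaler.do_upscale(img, dat_x4.data_path)  # weights are downloaded on first use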
+105 -4
@@ -4,10 +4,18 @@ from functools import lru_cache
import torch
from modules import errors, shared
from modules import torch_utils
if sys.platform == "darwin":
from modules import mac_specific
if shared.cmd_opts.use_ipex:
from modules import xpu_specific
def has_xpu() -> bool:
return shared.cmd_opts.use_ipex and xpu_specific.has_xpu
def has_mps() -> bool:
if sys.platform != "darwin":
@@ -16,6 +24,23 @@ def has_mps() -> bool:
return mac_specific.has_mps
def cuda_no_autocast(device_id=None) -> bool:
if device_id is None:
device_id = get_cuda_device_id()
return (
torch.cuda.get_device_capability(device_id) == (7, 5)
and torch.cuda.get_device_name(device_id).startswith("NVIDIA GeForce GTX 16")
)
def get_cuda_device_id():
return (
int(shared.cmd_opts.device_id)
if shared.cmd_opts.device_id is not None and shared.cmd_opts.device_id.isdigit()
else 0
) or torch.cuda.current_device()
def get_cuda_device_string():
if shared.cmd_opts.device_id is not None:
return f"cuda:{shared.cmd_opts.device_id}"
@@ -30,6 +55,9 @@ def get_optimal_device_name():
if has_mps():
return "mps"
if has_xpu():
return xpu_specific.get_xpu_device_string()
return "cpu"
@@ -38,7 +66,7 @@ def get_optimal_device():
def get_device_for(task):
-if task in shared.cmd_opts.use_cpu:
+if task in shared.cmd_opts.use_cpu or "all" in shared.cmd_opts.use_cpu:
return cpu
return get_optimal_device()
@@ -54,14 +82,16 @@ def torch_gc():
if has_mps():
mac_specific.torch_mps_gc()
if has_xpu():
xpu_specific.torch_xpu_gc()
def enable_tf32():
if torch.cuda.is_available():
# enabling benchmark option seems to enable a range of cards to do fp16 when they otherwise can't
# see https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/4407
-device_id = (int(shared.cmd_opts.device_id) if shared.cmd_opts.device_id is not None and shared.cmd_opts.device_id.isdigit() else 0) or torch.cuda.current_device()
-if torch.cuda.get_device_capability(device_id) == (7, 5) and torch.cuda.get_device_name(device_id).startswith("NVIDIA GeForce GTX 16"):
+if cuda_no_autocast():
torch.backends.cudnn.benchmark = True
torch.backends.cuda.matmul.allow_tf32 = True
@@ -71,6 +101,7 @@ def enable_tf32():
errors.run(enable_tf32, "Enabling TF32")
cpu: torch.device = torch.device("cpu")
fp8: bool = False
device: torch.device = None
device_interrogate: torch.device = None
device_gfpgan: torch.device = None
@@ -79,6 +110,7 @@ device_codeformer: torch.device = None
dtype: torch.dtype = torch.float16
dtype_vae: torch.dtype = torch.float16
dtype_unet: torch.dtype = torch.float16
dtype_inference: torch.dtype = torch.float16
unet_needs_upcast = False
@@ -91,15 +123,84 @@ def cond_cast_float(input):
nv_rng = None
patch_module_list = [
torch.nn.Linear,
torch.nn.Conv2d,
torch.nn.MultiheadAttention,
torch.nn.GroupNorm,
torch.nn.LayerNorm,
]
def manual_cast_forward(target_dtype):
def forward_wrapper(self, *args, **kwargs):
if any(
isinstance(arg, torch.Tensor) and arg.dtype != target_dtype
for arg in args
):
args = [arg.to(target_dtype) if isinstance(arg, torch.Tensor) else arg for arg in args]
kwargs = {k: v.to(target_dtype) if isinstance(v, torch.Tensor) else v for k, v in kwargs.items()}
org_dtype = torch_utils.get_param(self).dtype
if org_dtype != target_dtype:
self.to(target_dtype)
result = self.org_forward(*args, **kwargs)
if org_dtype != target_dtype:
self.to(org_dtype)
if target_dtype != dtype_inference:
if isinstance(result, tuple):
result = tuple(
i.to(dtype_inference)
if isinstance(i, torch.Tensor)
else i
for i in result
)
elif isinstance(result, torch.Tensor):
result = result.to(dtype_inference)
return result
return forward_wrapper
@contextlib.contextmanager
def manual_cast(target_dtype):
applied = False
for module_type in patch_module_list:
if hasattr(module_type, "org_forward"):
continue
applied = True
org_forward = module_type.forward
if module_type == torch.nn.MultiheadAttention and has_xpu():
module_type.forward = manual_cast_forward(torch.float32)
else:
module_type.forward = manual_cast_forward(target_dtype)
module_type.org_forward = org_forward
try:
yield None
finally:
if applied:
for module_type in patch_module_list:
if hasattr(module_type, "org_forward"):
module_type.forward = module_type.org_forward
delattr(module_type, "org_forward")
def autocast(disable=False):
if disable:
return contextlib.nullcontext()
-if dtype == torch.float32 or shared.cmd_opts.precision == "full":
+if fp8 and device == cpu:
+return torch.autocast("cpu", dtype=torch.bfloat16, enabled=True)
+if fp8 and dtype_inference == torch.float32:
+return manual_cast(dtype)
+if dtype == torch.float32 or dtype_inference == torch.float32:
return contextlib.nullcontext()
+if has_xpu() or has_mps() or cuda_no_autocast():
+return manual_cast(dtype)
return torch.autocast("cuda")
+18 -4
@@ -6,6 +6,21 @@ import traceback
exception_records = []
def format_traceback(tb):
return [[f"{x.filename}, line {x.lineno}, {x.name}", x.line] for x in traceback.extract_tb(tb)]
def format_exception(e, tb):
return {"exception": str(e), "traceback": format_traceback(tb)}
def get_exceptions():
try:
return list(reversed(exception_records))
except Exception as e:
return str(e)
def record_exception():
_, e, tb = sys.exc_info()
if e is None:
@@ -14,8 +29,7 @@ def record_exception():
if exception_records and exception_records[-1] == e:
return
-from modules import sysinfo
-exception_records.append(sysinfo.format_exception(e, tb))
+exception_records.append(format_exception(e, tb))
if len(exception_records) > 5:
exception_records.pop(0)
@@ -93,8 +107,8 @@ def check_versions():
import torch
import gradio
-expected_torch_version = "2.0.0"
+expected_torch_version = "2.1.2"
-expected_xformers_version = "0.0.20"
+expected_xformers_version = "0.0.23.post1"
expected_gradio_version = "3.41.2"
if version.parse(torch.__version__) < version.parse(expected_torch_version):
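A hedged illustration of the record shape produced by the new helpers (the ZeroDivisionError is just an example):

    try:
        1 / 0
    except Exception:
        record_exception()

    print(get_exceptions()[0])
    # {'exception': 'division by zero',
    #  'traceback': [['<file>, line <n>, <func>', '1 / 0']]}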
+16 -183
@@ -1,121 +1,7 @@
-import sys
-import numpy as np
-import torch
-from PIL import Image
-import modules.esrgan_model_arch as arch
-from modules import modelloader, images, devices
+from modules import modelloader, devices, errors
from modules.shared import opts
from modules.upscaler import Upscaler, UpscalerData
from modules.upscaler_utils import upscale_with_model
def mod2normal(state_dict):
# this code is copied from https://github.com/victorca25/iNNfer
if 'conv_first.weight' in state_dict:
crt_net = {}
items = list(state_dict)
crt_net['model.0.weight'] = state_dict['conv_first.weight']
crt_net['model.0.bias'] = state_dict['conv_first.bias']
for k in items.copy():
if 'RDB' in k:
ori_k = k.replace('RRDB_trunk.', 'model.1.sub.')
if '.weight' in k:
ori_k = ori_k.replace('.weight', '.0.weight')
elif '.bias' in k:
ori_k = ori_k.replace('.bias', '.0.bias')
crt_net[ori_k] = state_dict[k]
items.remove(k)
crt_net['model.1.sub.23.weight'] = state_dict['trunk_conv.weight']
crt_net['model.1.sub.23.bias'] = state_dict['trunk_conv.bias']
crt_net['model.3.weight'] = state_dict['upconv1.weight']
crt_net['model.3.bias'] = state_dict['upconv1.bias']
crt_net['model.6.weight'] = state_dict['upconv2.weight']
crt_net['model.6.bias'] = state_dict['upconv2.bias']
crt_net['model.8.weight'] = state_dict['HRconv.weight']
crt_net['model.8.bias'] = state_dict['HRconv.bias']
crt_net['model.10.weight'] = state_dict['conv_last.weight']
crt_net['model.10.bias'] = state_dict['conv_last.bias']
state_dict = crt_net
return state_dict
def resrgan2normal(state_dict, nb=23):
# this code is copied from https://github.com/victorca25/iNNfer
if "conv_first.weight" in state_dict and "body.0.rdb1.conv1.weight" in state_dict:
re8x = 0
crt_net = {}
items = list(state_dict)
crt_net['model.0.weight'] = state_dict['conv_first.weight']
crt_net['model.0.bias'] = state_dict['conv_first.bias']
for k in items.copy():
if "rdb" in k:
ori_k = k.replace('body.', 'model.1.sub.')
ori_k = ori_k.replace('.rdb', '.RDB')
if '.weight' in k:
ori_k = ori_k.replace('.weight', '.0.weight')
elif '.bias' in k:
ori_k = ori_k.replace('.bias', '.0.bias')
crt_net[ori_k] = state_dict[k]
items.remove(k)
crt_net[f'model.1.sub.{nb}.weight'] = state_dict['conv_body.weight']
crt_net[f'model.1.sub.{nb}.bias'] = state_dict['conv_body.bias']
crt_net['model.3.weight'] = state_dict['conv_up1.weight']
crt_net['model.3.bias'] = state_dict['conv_up1.bias']
crt_net['model.6.weight'] = state_dict['conv_up2.weight']
crt_net['model.6.bias'] = state_dict['conv_up2.bias']
if 'conv_up3.weight' in state_dict:
# modification supporting: https://github.com/ai-forever/Real-ESRGAN/blob/main/RealESRGAN/rrdbnet_arch.py
re8x = 3
crt_net['model.9.weight'] = state_dict['conv_up3.weight']
crt_net['model.9.bias'] = state_dict['conv_up3.bias']
crt_net[f'model.{8+re8x}.weight'] = state_dict['conv_hr.weight']
crt_net[f'model.{8+re8x}.bias'] = state_dict['conv_hr.bias']
crt_net[f'model.{10+re8x}.weight'] = state_dict['conv_last.weight']
crt_net[f'model.{10+re8x}.bias'] = state_dict['conv_last.bias']
state_dict = crt_net
return state_dict
def infer_params(state_dict):
# this code is copied from https://github.com/victorca25/iNNfer
scale2x = 0
scalemin = 6
n_uplayer = 0
plus = False
for block in list(state_dict):
parts = block.split(".")
n_parts = len(parts)
if n_parts == 5 and parts[2] == "sub":
nb = int(parts[3])
elif n_parts == 3:
part_num = int(parts[1])
if (part_num > scalemin
and parts[0] == "model"
and parts[2] == "weight"):
scale2x += 1
if part_num > n_uplayer:
n_uplayer = part_num
out_nc = state_dict[block].shape[0]
if not plus and "conv1x1" in block:
plus = True
nf = state_dict["model.0.weight"].shape[0]
in_nc = state_dict["model.0.weight"].shape[1]
out_nc = out_nc
scale = 2 ** scale2x
return in_nc, out_nc, nf, nb, plus, scale
class UpscalerESRGAN(Upscaler):
@@ -143,12 +29,11 @@ class UpscalerESRGAN(Upscaler):
def do_upscale(self, img, selected_model):
try:
model = self.load_model(selected_model)
-except Exception as e:
+except Exception:
-print(f"Unable to load ESRGAN model {selected_model}: {e}", file=sys.stderr)
+errors.report(f"Unable to load ESRGAN model {selected_model}", exc_info=True)
return img
model.to(devices.device_esrgan)
-img = esrgan_upscale(model, img)
-return img
+return esrgan_upscale(model, img)
def load_model(self, path: str):
if path.startswith("http"):
@@ -161,69 +46,17 @@ class UpscalerESRGAN(Upscaler):
else:
filename = path
-state_dict = torch.load(filename, map_location='cpu' if devices.device_esrgan.type == 'mps' else None)
-if "params_ema" in state_dict:
-state_dict = state_dict["params_ema"]
-elif "params" in state_dict:
-state_dict = state_dict["params"]
+return modelloader.load_spandrel_model(
+filename,
+device=('cpu' if devices.device_esrgan.type == 'mps' else None),
+expected_architecture='ESRGAN',
+)
num_conv = 16 if "realesr-animevideov3" in filename else 32
model = arch.SRVGGNetCompact(num_in_ch=3, num_out_ch=3, num_feat=64, num_conv=num_conv, upscale=4, act_type='prelu')
model.load_state_dict(state_dict)
model.eval()
return model
if "body.0.rdb1.conv1.weight" in state_dict and "conv_first.weight" in state_dict:
nb = 6 if "RealESRGAN_x4plus_anime_6B" in filename else 23
state_dict = resrgan2normal(state_dict, nb)
elif "conv_first.weight" in state_dict:
state_dict = mod2normal(state_dict)
elif "model.0.weight" not in state_dict:
raise Exception("The file is not a recognized ESRGAN model.")
in_nc, out_nc, nf, nb, plus, mscale = infer_params(state_dict)
model = arch.RRDBNet(in_nc=in_nc, out_nc=out_nc, nf=nf, nb=nb, upscale=mscale, plus=plus)
model.load_state_dict(state_dict)
model.eval()
return model
def upscale_without_tiling(model, img):
img = np.array(img)
img = img[:, :, ::-1]
img = np.ascontiguousarray(np.transpose(img, (2, 0, 1))) / 255
img = torch.from_numpy(img).float()
img = img.unsqueeze(0).to(devices.device_esrgan)
with torch.no_grad():
output = model(img)
output = output.squeeze().float().cpu().clamp_(0, 1).numpy()
output = 255. * np.moveaxis(output, 0, 2)
output = output.astype(np.uint8)
output = output[:, :, ::-1]
return Image.fromarray(output, 'RGB')
def esrgan_upscale(model, img):
-if opts.ESRGAN_tile == 0:
-return upscale_without_tiling(model, img)
-grid = images.split_grid(img, opts.ESRGAN_tile, opts.ESRGAN_tile, opts.ESRGAN_tile_overlap)
-newtiles = []
-scale_factor = 1
+return upscale_with_model(
+model,
+img,
+tile_size=opts.ESRGAN_tile,
+tile_overlap=opts.ESRGAN_tile_overlap,
+)
for y, h, row in grid.tiles:
newrow = []
for tiledata in row:
x, w, tile = tiledata
output = upscale_without_tiling(model, tile)
scale_factor = output.width // tile.width
newrow.append([x * scale_factor, w * scale_factor, output])
newtiles.append([y * scale_factor, h * scale_factor, newrow])
newgrid = images.Grid(newtiles, grid.tile_w * scale_factor, grid.tile_h * scale_factor, grid.image_w * scale_factor, grid.image_h * scale_factor, grid.overlap * scale_factor)
output = images.combine_grid(newgrid)
return output
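A sketch of the spandrel-based loading path the diff above switches to (illustrative; the model filename is an assumption):

    from modules import modelloader

    descriptor = modelloader.load_spandrel_model(
        "models/ESRGAN/4x_example.pth",
        device=None,
        expected_architecture="ESRGAN",
    )
    # the descriptor wraps the torch module plus metadata such as scale,
    # so callers no longer hand-map state_dict layouts per architecture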
-465
@@ -1,465 +0,0 @@
# this file is adapted from https://github.com/victorca25/iNNfer
from collections import OrderedDict
import math
import torch
import torch.nn as nn
import torch.nn.functional as F
####################
# RRDBNet Generator
####################
class RRDBNet(nn.Module):
def __init__(self, in_nc, out_nc, nf, nb, nr=3, gc=32, upscale=4, norm_type=None,
act_type='leakyrelu', mode='CNA', upsample_mode='upconv', convtype='Conv2D',
finalact=None, gaussian_noise=False, plus=False):
super(RRDBNet, self).__init__()
n_upscale = int(math.log(upscale, 2))
if upscale == 3:
n_upscale = 1
self.resrgan_scale = 0
if in_nc % 16 == 0:
self.resrgan_scale = 1
elif in_nc != 4 and in_nc % 4 == 0:
self.resrgan_scale = 2
fea_conv = conv_block(in_nc, nf, kernel_size=3, norm_type=None, act_type=None, convtype=convtype)
rb_blocks = [RRDB(nf, nr, kernel_size=3, gc=32, stride=1, bias=1, pad_type='zero',
norm_type=norm_type, act_type=act_type, mode='CNA', convtype=convtype,
gaussian_noise=gaussian_noise, plus=plus) for _ in range(nb)]
LR_conv = conv_block(nf, nf, kernel_size=3, norm_type=norm_type, act_type=None, mode=mode, convtype=convtype)
if upsample_mode == 'upconv':
upsample_block = upconv_block
elif upsample_mode == 'pixelshuffle':
upsample_block = pixelshuffle_block
else:
raise NotImplementedError(f'upsample mode [{upsample_mode}] is not found')
if upscale == 3:
upsampler = upsample_block(nf, nf, 3, act_type=act_type, convtype=convtype)
else:
upsampler = [upsample_block(nf, nf, act_type=act_type, convtype=convtype) for _ in range(n_upscale)]
HR_conv0 = conv_block(nf, nf, kernel_size=3, norm_type=None, act_type=act_type, convtype=convtype)
HR_conv1 = conv_block(nf, out_nc, kernel_size=3, norm_type=None, act_type=None, convtype=convtype)
outact = act(finalact) if finalact else None
self.model = sequential(fea_conv, ShortcutBlock(sequential(*rb_blocks, LR_conv)),
*upsampler, HR_conv0, HR_conv1, outact)
def forward(self, x, outm=None):
if self.resrgan_scale == 1:
feat = pixel_unshuffle(x, scale=4)
elif self.resrgan_scale == 2:
feat = pixel_unshuffle(x, scale=2)
else:
feat = x
return self.model(feat)
class RRDB(nn.Module):
"""
Residual in Residual Dense Block
(ESRGAN: Enhanced Super-Resolution Generative Adversarial Networks)
"""
def __init__(self, nf, nr=3, kernel_size=3, gc=32, stride=1, bias=1, pad_type='zero',
norm_type=None, act_type='leakyrelu', mode='CNA', convtype='Conv2D',
spectral_norm=False, gaussian_noise=False, plus=False):
super(RRDB, self).__init__()
# This is for backwards compatibility with existing models
if nr == 3:
self.RDB1 = ResidualDenseBlock_5C(nf, kernel_size, gc, stride, bias, pad_type,
norm_type, act_type, mode, convtype, spectral_norm=spectral_norm,
gaussian_noise=gaussian_noise, plus=plus)
self.RDB2 = ResidualDenseBlock_5C(nf, kernel_size, gc, stride, bias, pad_type,
norm_type, act_type, mode, convtype, spectral_norm=spectral_norm,
gaussian_noise=gaussian_noise, plus=plus)
self.RDB3 = ResidualDenseBlock_5C(nf, kernel_size, gc, stride, bias, pad_type,
norm_type, act_type, mode, convtype, spectral_norm=spectral_norm,
gaussian_noise=gaussian_noise, plus=plus)
else:
RDB_list = [ResidualDenseBlock_5C(nf, kernel_size, gc, stride, bias, pad_type,
norm_type, act_type, mode, convtype, spectral_norm=spectral_norm,
gaussian_noise=gaussian_noise, plus=plus) for _ in range(nr)]
self.RDBs = nn.Sequential(*RDB_list)
def forward(self, x):
if hasattr(self, 'RDB1'):
out = self.RDB1(x)
out = self.RDB2(out)
out = self.RDB3(out)
else:
out = self.RDBs(x)
return out * 0.2 + x
class ResidualDenseBlock_5C(nn.Module):
"""
Residual Dense Block
The core module of paper: (Residual Dense Network for Image Super-Resolution, CVPR 18)
Modified options that can be used:
- "Partial Convolution based Padding" arXiv:1811.11718
- "Spectral normalization" arXiv:1802.05957
- "ICASSP 2020 - ESRGAN+ : Further Improving ESRGAN" N. C.
{Rakotonirina} and A. {Rasoanaivo}
"""
def __init__(self, nf=64, kernel_size=3, gc=32, stride=1, bias=1, pad_type='zero',
norm_type=None, act_type='leakyrelu', mode='CNA', convtype='Conv2D',
spectral_norm=False, gaussian_noise=False, plus=False):
super(ResidualDenseBlock_5C, self).__init__()
self.noise = GaussianNoise() if gaussian_noise else None
self.conv1x1 = conv1x1(nf, gc) if plus else None
self.conv1 = conv_block(nf, gc, kernel_size, stride, bias=bias, pad_type=pad_type,
norm_type=norm_type, act_type=act_type, mode=mode, convtype=convtype,
spectral_norm=spectral_norm)
self.conv2 = conv_block(nf+gc, gc, kernel_size, stride, bias=bias, pad_type=pad_type,
norm_type=norm_type, act_type=act_type, mode=mode, convtype=convtype,
spectral_norm=spectral_norm)
self.conv3 = conv_block(nf+2*gc, gc, kernel_size, stride, bias=bias, pad_type=pad_type,
norm_type=norm_type, act_type=act_type, mode=mode, convtype=convtype,
spectral_norm=spectral_norm)
self.conv4 = conv_block(nf+3*gc, gc, kernel_size, stride, bias=bias, pad_type=pad_type,
norm_type=norm_type, act_type=act_type, mode=mode, convtype=convtype,
spectral_norm=spectral_norm)
if mode == 'CNA':
last_act = None
else:
last_act = act_type
self.conv5 = conv_block(nf+4*gc, nf, 3, stride, bias=bias, pad_type=pad_type,
norm_type=norm_type, act_type=last_act, mode=mode, convtype=convtype,
spectral_norm=spectral_norm)
def forward(self, x):
x1 = self.conv1(x)
x2 = self.conv2(torch.cat((x, x1), 1))
if self.conv1x1:
x2 = x2 + self.conv1x1(x)
x3 = self.conv3(torch.cat((x, x1, x2), 1))
x4 = self.conv4(torch.cat((x, x1, x2, x3), 1))
if self.conv1x1:
x4 = x4 + x2
x5 = self.conv5(torch.cat((x, x1, x2, x3, x4), 1))
if self.noise:
return self.noise(x5.mul(0.2) + x)
else:
return x5 * 0.2 + x
####################
# ESRGANplus
####################
class GaussianNoise(nn.Module):
def __init__(self, sigma=0.1, is_relative_detach=False):
super().__init__()
self.sigma = sigma
self.is_relative_detach = is_relative_detach
self.noise = torch.tensor(0, dtype=torch.float)
def forward(self, x):
if self.training and self.sigma != 0:
self.noise = self.noise.to(x.device)
scale = self.sigma * x.detach() if self.is_relative_detach else self.sigma * x
sampled_noise = self.noise.repeat(*x.size()).normal_() * scale
x = x + sampled_noise
return x
def conv1x1(in_planes, out_planes, stride=1):
return nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, bias=False)
####################
# SRVGGNetCompact
####################
class SRVGGNetCompact(nn.Module):
"""A compact VGG-style network structure for super-resolution.
This class is copied from https://github.com/xinntao/Real-ESRGAN
"""
def __init__(self, num_in_ch=3, num_out_ch=3, num_feat=64, num_conv=16, upscale=4, act_type='prelu'):
super(SRVGGNetCompact, self).__init__()
self.num_in_ch = num_in_ch
self.num_out_ch = num_out_ch
self.num_feat = num_feat
self.num_conv = num_conv
self.upscale = upscale
self.act_type = act_type
self.body = nn.ModuleList()
# the first conv
self.body.append(nn.Conv2d(num_in_ch, num_feat, 3, 1, 1))
# the first activation
if act_type == 'relu':
activation = nn.ReLU(inplace=True)
elif act_type == 'prelu':
activation = nn.PReLU(num_parameters=num_feat)
elif act_type == 'leakyrelu':
activation = nn.LeakyReLU(negative_slope=0.1, inplace=True)
self.body.append(activation)
# the body structure
for _ in range(num_conv):
self.body.append(nn.Conv2d(num_feat, num_feat, 3, 1, 1))
# activation
if act_type == 'relu':
activation = nn.ReLU(inplace=True)
elif act_type == 'prelu':
activation = nn.PReLU(num_parameters=num_feat)
elif act_type == 'leakyrelu':
activation = nn.LeakyReLU(negative_slope=0.1, inplace=True)
self.body.append(activation)
# the last conv
self.body.append(nn.Conv2d(num_feat, num_out_ch * upscale * upscale, 3, 1, 1))
# upsample
self.upsampler = nn.PixelShuffle(upscale)
def forward(self, x):
out = x
for i in range(0, len(self.body)):
out = self.body[i](out)
out = self.upsampler(out)
# add the nearest upsampled image, so that the network learns the residual
base = F.interpolate(x, scale_factor=self.upscale, mode='nearest')
out += base
return out
####################
# Upsampler
####################
class Upsample(nn.Module):
r"""Upsamples a given multi-channel 1D (temporal), 2D (spatial) or 3D (volumetric) data.
The input data is assumed to be of the form
`minibatch x channels x [optional depth] x [optional height] x width`.
"""
def __init__(self, size=None, scale_factor=None, mode="nearest", align_corners=None):
super(Upsample, self).__init__()
if isinstance(scale_factor, tuple):
self.scale_factor = tuple(float(factor) for factor in scale_factor)
else:
self.scale_factor = float(scale_factor) if scale_factor else None
self.mode = mode
self.size = size
self.align_corners = align_corners
def forward(self, x):
return nn.functional.interpolate(x, size=self.size, scale_factor=self.scale_factor, mode=self.mode, align_corners=self.align_corners)
def extra_repr(self):
if self.scale_factor is not None:
info = f'scale_factor={self.scale_factor}'
else:
info = f'size={self.size}'
info += f', mode={self.mode}'
return info
def pixel_unshuffle(x, scale):
""" Pixel unshuffle.
Args:
x (Tensor): Input feature with shape (b, c, hh, hw).
scale (int): Downsample ratio.
Returns:
Tensor: the pixel unshuffled feature.
"""
b, c, hh, hw = x.size()
out_channel = c * (scale**2)
assert hh % scale == 0 and hw % scale == 0
h = hh // scale
w = hw // scale
x_view = x.view(b, c, h, scale, w, scale)
return x_view.permute(0, 1, 3, 5, 2, 4).reshape(b, out_channel, h, w)
def pixelshuffle_block(in_nc, out_nc, upscale_factor=2, kernel_size=3, stride=1, bias=True,
pad_type='zero', norm_type=None, act_type='relu', convtype='Conv2D'):
"""
Pixel shuffle layer
(Real-Time Single Image and Video Super-Resolution Using an Efficient Sub-Pixel Convolutional
Neural Network, CVPR17)
"""
conv = conv_block(in_nc, out_nc * (upscale_factor ** 2), kernel_size, stride, bias=bias,
pad_type=pad_type, norm_type=None, act_type=None, convtype=convtype)
pixel_shuffle = nn.PixelShuffle(upscale_factor)
n = norm(norm_type, out_nc) if norm_type else None
a = act(act_type) if act_type else None
return sequential(conv, pixel_shuffle, n, a)
def upconv_block(in_nc, out_nc, upscale_factor=2, kernel_size=3, stride=1, bias=True,
pad_type='zero', norm_type=None, act_type='relu', mode='nearest', convtype='Conv2D'):
""" Upconv layer """
upscale_factor = (1, upscale_factor, upscale_factor) if convtype == 'Conv3D' else upscale_factor
upsample = Upsample(scale_factor=upscale_factor, mode=mode)
conv = conv_block(in_nc, out_nc, kernel_size, stride, bias=bias,
pad_type=pad_type, norm_type=norm_type, act_type=act_type, convtype=convtype)
return sequential(upsample, conv)
####################
# Basic blocks
####################
def make_layer(basic_block, num_basic_block, **kwarg):
"""Make layers by stacking the same blocks.
Args:
basic_block (nn.module): nn.module class for basic block. (block)
num_basic_block (int): number of blocks. (n_layers)
Returns:
nn.Sequential: Stacked blocks in nn.Sequential.
"""
layers = []
for _ in range(num_basic_block):
layers.append(basic_block(**kwarg))
return nn.Sequential(*layers)
def act(act_type, inplace=True, neg_slope=0.2, n_prelu=1, beta=1.0):
""" activation helper """
act_type = act_type.lower()
if act_type == 'relu':
layer = nn.ReLU(inplace)
elif act_type in ('leakyrelu', 'lrelu'):
layer = nn.LeakyReLU(neg_slope, inplace)
elif act_type == 'prelu':
layer = nn.PReLU(num_parameters=n_prelu, init=neg_slope)
elif act_type == 'tanh': # [-1, 1] range output
layer = nn.Tanh()
elif act_type == 'sigmoid': # [0, 1] range output
layer = nn.Sigmoid()
else:
raise NotImplementedError(f'activation layer [{act_type}] is not found')
return layer
class Identity(nn.Module):
def __init__(self, *kwargs):
super(Identity, self).__init__()
def forward(self, x, *kwargs):
return x
def norm(norm_type, nc):
""" Return a normalization layer """
norm_type = norm_type.lower()
if norm_type == 'batch':
layer = nn.BatchNorm2d(nc, affine=True)
elif norm_type == 'instance':
layer = nn.InstanceNorm2d(nc, affine=False)
elif norm_type == 'none':
def norm_layer(x): return Identity()
else:
raise NotImplementedError(f'normalization layer [{norm_type}] is not found')
return layer
def pad(pad_type, padding):
""" padding layer helper """
pad_type = pad_type.lower()
if padding == 0:
return None
if pad_type == 'reflect':
layer = nn.ReflectionPad2d(padding)
elif pad_type == 'replicate':
layer = nn.ReplicationPad2d(padding)
elif pad_type == 'zero':
layer = nn.ZeroPad2d(padding)
else:
raise NotImplementedError(f'padding layer [{pad_type}] is not implemented')
return layer
def get_valid_padding(kernel_size, dilation):
kernel_size = kernel_size + (kernel_size - 1) * (dilation - 1)
padding = (kernel_size - 1) // 2
return padding
class ShortcutBlock(nn.Module):
""" Elementwise sum the output of a submodule to its input """
def __init__(self, submodule):
super(ShortcutBlock, self).__init__()
self.sub = submodule
def forward(self, x):
output = x + self.sub(x)
return output
def __repr__(self):
return 'Identity + \n|' + self.sub.__repr__().replace('\n', '\n|')
def sequential(*args):
""" Flatten Sequential. It unwraps nn.Sequential. """
if len(args) == 1:
if isinstance(args[0], OrderedDict):
raise NotImplementedError('sequential does not support OrderedDict input.')
return args[0] # No sequential is needed.
modules = []
for module in args:
if isinstance(module, nn.Sequential):
for submodule in module.children():
modules.append(submodule)
elif isinstance(module, nn.Module):
modules.append(module)
return nn.Sequential(*modules)
def conv_block(in_nc, out_nc, kernel_size, stride=1, dilation=1, groups=1, bias=True,
pad_type='zero', norm_type=None, act_type='relu', mode='CNA', convtype='Conv2D',
spectral_norm=False):
""" Conv layer with padding, normalization, activation """
assert mode in ['CNA', 'NAC', 'CNAC'], f'Wrong conv mode [{mode}]'
padding = get_valid_padding(kernel_size, dilation)
p = pad(pad_type, padding) if pad_type and pad_type != 'zero' else None
padding = padding if pad_type == 'zero' else 0
if convtype=='PartialConv2D':
from torchvision.ops import PartialConv2d # this is definitely not going to work, but PartialConv2d doesn't work anyway and this shuts up static analyzer
c = PartialConv2d(in_nc, out_nc, kernel_size=kernel_size, stride=stride, padding=padding,
dilation=dilation, bias=bias, groups=groups)
elif convtype=='DeformConv2D':
from torchvision.ops import DeformConv2d # not tested
c = DeformConv2d(in_nc, out_nc, kernel_size=kernel_size, stride=stride, padding=padding,
dilation=dilation, bias=bias, groups=groups)
elif convtype=='Conv3D':
c = nn.Conv3d(in_nc, out_nc, kernel_size=kernel_size, stride=stride, padding=padding,
dilation=dilation, bias=bias, groups=groups)
else:
c = nn.Conv2d(in_nc, out_nc, kernel_size=kernel_size, stride=stride, padding=padding,
dilation=dilation, bias=bias, groups=groups)
if spectral_norm:
c = nn.utils.spectral_norm(c)
a = act(act_type) if act_type else None
if 'CNA' in mode:
n = norm(norm_type, out_nc) if norm_type else None
return sequential(p, c, n, a)
elif mode == 'NAC':
if norm_type is None and act_type is not None:
a = act(act_type, inplace=False)
n = norm(norm_type, in_nc) if norm_type else None
return sequential(n, a, p, c)
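A quick shape check for the pixel_unshuffle helper in the removed file above (illustrative only):

    import torch

    x = torch.randn(1, 3, 8, 8)
    y = pixel_unshuffle(x, scale=2)
    print(y.shape)  # torch.Size([1, 12, 4, 4]): channels x4, spatial halved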
+88 -12
@@ -1,11 +1,14 @@
from __future__ import annotations
import configparser
import os
import threading
import re
from modules import shared, errors, cache, scripts
from modules.gitpython_hack import Repo
from modules.paths_internal import extensions_dir, extensions_builtin_dir, script_path  # noqa: F401
extensions = []
os.makedirs(extensions_dir, exist_ok=True)
@@ -19,11 +22,56 @@ def active():
return [x for x in extensions if x.enabled]
class ExtensionMetadata:
filename = "metadata.ini"
config: configparser.ConfigParser
canonical_name: str
requires: list
def __init__(self, path, canonical_name):
self.config = configparser.ConfigParser()
filepath = os.path.join(path, self.filename)
# `self.config.read()` will quietly swallow OSErrors (which FileNotFoundError is),
# so no need to check whether the file exists beforehand.
try:
self.config.read(filepath)
except Exception:
errors.report(f"Error reading {self.filename} for extension {canonical_name}.", exc_info=True)
self.canonical_name = self.config.get("Extension", "Name", fallback=canonical_name)
self.canonical_name = self.canonical_name.lower().strip()
self.requires = self.get_script_requirements("Requires", "Extension")
def get_script_requirements(self, field, section, extra_section=None):
"""reads a list of requirements from the config; field is the name of the field in the ini file,
like Requires or Before, and section is the name of the [section] in the ini file; additionally,
reads more requirements from [extra_section] if specified."""
x = self.config.get(section, field, fallback='')
if extra_section:
x = x + ', ' + self.config.get(extra_section, field, fallback='')
return self.parse_list(x.lower())
def parse_list(self, text):
"""converts a line from config ("ext1 ext2, ext3 ") into a python list (["ext1", "ext2", "ext3"])"""
if not text:
return []
# both "," and " " are accepted as separator
return [x for x in re.split(r"[,\s]+", text.strip()) if x]
class Extension:
lock = threading.Lock()
cached_fields = ['remote', 'commit_date', 'branch', 'commit_hash', 'version']
metadata: ExtensionMetadata
-def __init__(self, name, path, enabled=True, is_builtin=False):
+def __init__(self, name, path, enabled=True, is_builtin=False, metadata=None):
self.name = name
self.path = path
self.enabled = enabled
@@ -36,6 +84,8 @@ class Extension:
self.branch = None
self.remote = None
self.have_info_from_repo = False
self.metadata = metadata if metadata else ExtensionMetadata(self.path, name.lower())
self.canonical_name = self.metadata.canonical_name
def to_dict(self):
return {x: getattr(self, x) for x in self.cached_fields}
@@ -56,6 +106,7 @@ class Extension:
self.do_read_info_from_repo()
return self.to_dict()
try:
d = cache.cached_data_for_file('extensions-git', self.name, os.path.join(self.path, ".git"), read_from_repo)
self.from_dict(d)
@@ -136,9 +187,6 @@ class Extension:
def list_extensions():
extensions.clear()
if not os.path.isdir(extensions_dir):
return
if shared.cmd_opts.disable_all_extensions:
print("*** \"--disable-all-extensions\" arg was used, will not load any extensions ***")
elif shared.opts.disable_all_extensions == "all":
@@ -148,18 +196,46 @@ def list_extensions():
elif shared.opts.disable_all_extensions == "extra": elif shared.opts.disable_all_extensions == "extra":
print("*** \"Disable all extensions\" option was set, will only load built-in extensions ***") print("*** \"Disable all extensions\" option was set, will only load built-in extensions ***")
-extension_paths = []
-for dirname in [extensions_dir, extensions_builtin_dir]:
+loaded_extensions = {}
+# scan through extensions directory and load metadata
+for dirname in [extensions_builtin_dir, extensions_dir]:
if not os.path.isdir(dirname):
-return
+continue
for extension_dirname in sorted(os.listdir(dirname)):
path = os.path.join(dirname, extension_dirname)
if not os.path.isdir(path):
continue
-extension_paths.append((extension_dirname, path, dirname == extensions_builtin_dir))
-for dirname, path, is_builtin in extension_paths:
-extension = Extension(name=dirname, path=path, enabled=dirname not in shared.opts.disabled_extensions, is_builtin=is_builtin)
-extensions.append(extension)
+canonical_name = extension_dirname
+metadata = ExtensionMetadata(path, canonical_name)
+# check for duplicated canonical names
+already_loaded_extension = loaded_extensions.get(metadata.canonical_name)
+if already_loaded_extension is not None:
+errors.report(f'Duplicate canonical name "{canonical_name}" found in extensions "{extension_dirname}" and "{already_loaded_extension.name}". Former will be discarded.', exc_info=False)
+continue
+is_builtin = dirname == extensions_builtin_dir
+extension = Extension(name=extension_dirname, path=path, enabled=extension_dirname not in shared.opts.disabled_extensions, is_builtin=is_builtin, metadata=metadata)
+extensions.append(extension)
+loaded_extensions[canonical_name] = extension
# check for requirements
for extension in extensions:
if not extension.enabled:
continue
for req in extension.metadata.requires:
required_extension = loaded_extensions.get(req)
if required_extension is None:
errors.report(f'Extension "{extension.name}" requires "{req}" which is not installed.', exc_info=False)
continue
if not required_extension.enabled:
errors.report(f'Extension "{extension.name}" requires "{required_extension.name}" which is disabled.', exc_info=False)
continue
extensions: list[Extension] = []
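A hypothetical metadata.ini matching the fields ExtensionMetadata reads above ("Name" and "Requires" in an [Extension] section; the extension names and file contents are invented for illustration):

    [Extension]
    Name = my-extension
    Requires = sd-webui-controlnet, some-other-extension

Note that parse_list accepts both commas and whitespace as separators, so "Requires = a b, c" yields ["a", "b", "c"].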
+3 -2
@@ -206,7 +206,7 @@ def parse_prompts(prompts):
return res, extra_data
-def get_user_metadata(filename):
+def get_user_metadata(filename, lister=None):
if filename is None:
return {}
@@ -215,7 +215,8 @@ def get_user_metadata(filename):
metadata = {}
try:
-if os.path.isfile(metadata_filename):
+exists = lister.exists(metadata_filename) if lister else os.path.exists(metadata_filename)
+if exists:
with open(metadata_filename, "r", encoding="utf8") as file:
metadata = json.load(file)
except Exception as e:
+180
@@ -0,0 +1,180 @@
from __future__ import annotations
import logging
import os
from functools import cached_property
from typing import TYPE_CHECKING, Callable
import cv2
import numpy as np
import torch
from modules import devices, errors, face_restoration, shared
if TYPE_CHECKING:
from facexlib.utils.face_restoration_helper import FaceRestoreHelper
logger = logging.getLogger(__name__)
def bgr_image_to_rgb_tensor(img: np.ndarray) -> torch.Tensor:
"""Convert a BGR NumPy image in [0..1] range to a PyTorch RGB float32 tensor."""
assert img.shape[2] == 3, "image must be a 3-channel (BGR) image"
if img.dtype == "float64":
img = img.astype("float32")
img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
return torch.from_numpy(img.transpose(2, 0, 1)).float()
def rgb_tensor_to_bgr_image(tensor: torch.Tensor, *, min_max=(0.0, 1.0)) -> np.ndarray:
"""
Convert a PyTorch RGB tensor in range `min_max` to a BGR NumPy image in [0..1] range.
"""
tensor = tensor.squeeze(0).float().detach().cpu().clamp_(*min_max)
tensor = (tensor - min_max[0]) / (min_max[1] - min_max[0])
assert tensor.dim() == 3, "tensor must be a 3-dimensional (C, H, W) image"
img_np = tensor.numpy().transpose(1, 2, 0)
if img_np.shape[2] == 1: # gray image, no RGB/BGR required
return np.squeeze(img_np, axis=2)
return cv2.cvtColor(img_np, cv2.COLOR_RGB2BGR)
def create_face_helper(device) -> FaceRestoreHelper:
from facexlib.detection import retinaface
from facexlib.utils.face_restoration_helper import FaceRestoreHelper
if hasattr(retinaface, 'device'):
retinaface.device = device
return FaceRestoreHelper(
upscale_factor=1,
face_size=512,
crop_ratio=(1, 1),
det_model='retinaface_resnet50',
save_ext='png',
use_parse=True,
device=device,
)
def restore_with_face_helper(
np_image: np.ndarray,
face_helper: FaceRestoreHelper,
restore_face: Callable[[torch.Tensor], torch.Tensor],
) -> np.ndarray:
"""
Find faces in the image using face_helper, restore them using restore_face, and paste them back into the image.
`restore_face` should take a cropped face image and return a restored face image.
"""
from torchvision.transforms.functional import normalize
np_image = np_image[:, :, ::-1]
original_resolution = np_image.shape[0:2]
try:
logger.debug("Detecting faces...")
face_helper.clean_all()
face_helper.read_image(np_image)
face_helper.get_face_landmarks_5(only_center_face=False, resize=640, eye_dist_threshold=5)
face_helper.align_warp_face()
logger.debug("Found %d faces, restoring", len(face_helper.cropped_faces))
for cropped_face in face_helper.cropped_faces:
cropped_face_t = bgr_image_to_rgb_tensor(cropped_face / 255.0)
normalize(cropped_face_t, (0.5, 0.5, 0.5), (0.5, 0.5, 0.5), inplace=True)
cropped_face_t = cropped_face_t.unsqueeze(0).to(devices.device_codeformer)
try:
with torch.no_grad():
cropped_face_t = restore_face(cropped_face_t)
devices.torch_gc()
except Exception:
errors.report('Failed face-restoration inference', exc_info=True)
restored_face = rgb_tensor_to_bgr_image(cropped_face_t, min_max=(-1, 1))
restored_face = (restored_face * 255.0).astype('uint8')
face_helper.add_restored_face(restored_face)
logger.debug("Merging restored faces into image")
face_helper.get_inverse_affine(None)
img = face_helper.paste_faces_to_input_image()
img = img[:, :, ::-1]
if original_resolution != img.shape[0:2]:
img = cv2.resize(
img,
(0, 0),
fx=original_resolution[1] / img.shape[1],
fy=original_resolution[0] / img.shape[0],
interpolation=cv2.INTER_LINEAR,
)
logger.debug("Face restoration complete")
finally:
face_helper.clean_all()
return img
class CommonFaceRestoration(face_restoration.FaceRestoration):
net: torch.nn.Module | None
model_url: str
model_download_name: str
def __init__(self, model_path: str):
super().__init__()
self.net = None
self.model_path = model_path
os.makedirs(model_path, exist_ok=True)
@cached_property
def face_helper(self) -> FaceRestoreHelper:
return create_face_helper(self.get_device())
def send_model_to(self, device):
if self.net:
logger.debug("Sending %s to %s", self.net, device)
self.net.to(device)
if self.face_helper:
logger.debug("Sending face helper to %s", device)
self.face_helper.face_det.to(device)
self.face_helper.face_parse.to(device)
def get_device(self):
raise NotImplementedError("get_device must be implemented by subclasses")
def load_net(self) -> torch.nn.Module:
raise NotImplementedError("load_net must be implemented by subclasses")
def restore_with_helper(
self,
np_image: np.ndarray,
restore_face: Callable[[torch.Tensor], torch.Tensor],
) -> np.ndarray:
try:
if self.net is None:
self.net = self.load_net()
except Exception:
logger.warning("Unable to load face-restoration model", exc_info=True)
return np_image
try:
self.send_model_to(self.get_device())
return restore_with_face_helper(np_image, self.face_helper, restore_face)
finally:
if shared.opts.face_restoration_unload:
self.send_model_to(devices.cpu)
def patch_facexlib(dirname: str) -> None:
import facexlib.detection
import facexlib.parsing
det_facex_load_file_from_url = facexlib.detection.load_file_from_url
par_facex_load_file_from_url = facexlib.parsing.load_file_from_url
def update_kwargs(kwargs):
return dict(kwargs, save_dir=dirname, model_dir=None)
def facex_load_file_from_url(**kwargs):
return det_facex_load_file_from_url(**update_kwargs(kwargs))
def facex_load_file_from_url2(**kwargs):
return par_facex_load_file_from_url(**update_kwargs(kwargs))
facexlib.detection.load_file_from_url = facex_load_file_from_url
facexlib.parsing.load_file_from_url = facex_load_file_from_url2
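A minimal sketch (not from the diff) of how a new restorer plugs into CommonFaceRestoration; load_my_net is a hypothetical loader:

    class MyRestorer(CommonFaceRestoration):
        def name(self):
            return "MyRestorer"

        def get_device(self):
            return devices.device_codeformer  # assumption: reuse an existing device slot

        def load_net(self):
            return load_my_net(self.model_path)  # hypothetical loader returning an nn.Module

        def restore(self, np_image):
            def restore_face(cropped_face_t):
                assert self.net is not None
                return self.net(cropped_face_t)[0]

            return self.restore_with_helper(np_image, restore_face)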
+50 -89
@@ -1,110 +1,71 @@
from __future__ import annotations
import logging
import os
-import facexlib
-import gfpgan
-import modules.face_restoration
-from modules import paths, shared, devices, modelloader, errors
+import torch
+from modules import (
+devices,
+errors,
+face_restoration,
+face_restoration_utils,
+modelloader,
+shared,
+)
model_dir = "GFPGAN" logger = logging.getLogger(__name__)
user_path = None
model_path = os.path.join(paths.models_path, model_dir)
model_url = "https://github.com/TencentARC/GFPGAN/releases/download/v1.3.0/GFPGANv1.4.pth" model_url = "https://github.com/TencentARC/GFPGAN/releases/download/v1.3.0/GFPGANv1.4.pth"
have_gfpgan = False model_download_name = "GFPGANv1.4.pth"
loaded_gfpgan_model = None gfpgan_face_restorer: face_restoration.FaceRestoration | None = None
-def gfpgann():
-global loaded_gfpgan_model
-global model_path
-if loaded_gfpgan_model is not None:
-loaded_gfpgan_model.gfpgan.to(devices.device_gfpgan)
-return loaded_gfpgan_model
-if gfpgan_constructor is None:
-return None
-models = modelloader.load_models(model_path, model_url, user_path, ext_filter="GFPGAN")
-if len(models) == 1 and models[0].startswith("http"):
-model_file = models[0]
-elif len(models) != 0:
-latest_file = max(models, key=os.path.getctime)
-model_file = latest_file
-else:
-print("Unable to load gfpgan model!")
-return None
-if hasattr(facexlib.detection.retinaface, 'device'):
-facexlib.detection.retinaface.device = devices.device_gfpgan
-model = gfpgan_constructor(model_path=model_file, upscale=1, arch='clean', channel_multiplier=2, bg_upsampler=None, device=devices.device_gfpgan)
-loaded_gfpgan_model = model
-return model
+class FaceRestorerGFPGAN(face_restoration_utils.CommonFaceRestoration):
+def name(self):
+return "GFPGAN"
+def get_device(self):
+return devices.device_gfpgan
+def load_net(self) -> torch.nn.Module:
+for model_path in modelloader.load_models(
+model_path=self.model_path,
+model_url=model_url,
+command_path=self.model_path,
+download_name=model_download_name,
+ext_filter=['.pth'],
+):
+if 'GFPGAN' in os.path.basename(model_path):
+model = modelloader.load_spandrel_model(
+model_path,
+device=self.get_device(),
+expected_architecture='GFPGAN',
+).model
+model.different_w = True  # see https://github.com/chaiNNer-org/spandrel/pull/81
+return model
+raise ValueError("No GFPGAN model found")
+def restore(self, np_image):
+def restore_face(cropped_face_t):
+assert self.net is not None
+return self.net(cropped_face_t, return_rgb=False)[0]
+return self.restore_with_helper(np_image, restore_face)
-def send_model_to(model, device):
-model.gfpgan.to(device)
-model.face_helper.face_det.to(device)
-model.face_helper.face_parse.to(device)
def gfpgan_fix_faces(np_image):
-model = gfpgann()
-if model is None:
-return np_image
-send_model_to(model, devices.device_gfpgan)
-np_image_bgr = np_image[:, :, ::-1]
-cropped_faces, restored_faces, gfpgan_output_bgr = model.enhance(np_image_bgr, has_aligned=False, only_center_face=False, paste_back=True)
-np_image = gfpgan_output_bgr[:, :, ::-1]
-model.face_helper.clean_all()
-if shared.opts.face_restoration_unload:
-send_model_to(model, devices.cpu)
+if gfpgan_face_restorer:
+return gfpgan_face_restorer.restore(np_image)
+logger.warning("GFPGAN face restorer not set up")
return np_image
-gfpgan_constructor = None
-def setup_model(dirname):
+def setup_model(dirname: str) -> None:
+global gfpgan_face_restorer
try:
-os.makedirs(model_path, exist_ok=True)
-from gfpgan import GFPGANer
-from facexlib import detection, parsing  # noqa: F401
-global user_path
-global have_gfpgan
-global gfpgan_constructor
-load_file_from_url_orig = gfpgan.utils.load_file_from_url
-facex_load_file_from_url_orig = facexlib.detection.load_file_from_url
-facex_load_file_from_url_orig2 = facexlib.parsing.load_file_from_url
-def my_load_file_from_url(**kwargs):
-return load_file_from_url_orig(**dict(kwargs, model_dir=model_path))
-def facex_load_file_from_url(**kwargs):
-return facex_load_file_from_url_orig(**dict(kwargs, save_dir=model_path, model_dir=None))
-def facex_load_file_from_url2(**kwargs):
-return facex_load_file_from_url_orig2(**dict(kwargs, save_dir=model_path, model_dir=None))
-gfpgan.utils.load_file_from_url = my_load_file_from_url
-facexlib.detection.load_file_from_url = facex_load_file_from_url
-facexlib.parsing.load_file_from_url = facex_load_file_from_url2
-user_path = dirname
-have_gfpgan = True
-gfpgan_constructor = GFPGANer
-class FaceRestorerGFPGAN(modules.face_restoration.FaceRestoration):
-def name(self):
-return "GFPGAN"
-def restore(self, np_image):
-return gfpgan_fix_faces(np_image)
-shared.face_restorers.append(FaceRestorerGFPGAN())
+face_restoration_utils.patch_facexlib(dirname)
+gfpgan_face_restorer = FaceRestorerGFPGAN(model_path=dirname)
+shared.face_restorers.append(gfpgan_face_restorer)
except Exception:
errors.report("Error setting up GFPGAN", exc_info=True)
+10
@@ -47,10 +47,20 @@ def Block_get_config(self):
def BlockContext_init(self, *args, **kwargs):
if scripts.scripts_current is not None:
scripts.scripts_current.before_component(self, **kwargs)
scripts.script_callbacks.before_component_callback(self, **kwargs)
res = original_BlockContext_init(self, *args, **kwargs)
add_classes_to_gradio_component(self)
scripts.script_callbacks.after_component_callback(self, **kwargs)
if scripts.scripts_current is not None:
scripts.scripts_current.after_component(self, **kwargs)
return res
+43
@@ -0,0 +1,43 @@
import os
import sys
from modules import modelloader, devices
from modules.shared import opts
from modules.upscaler import Upscaler, UpscalerData
from modules.upscaler_utils import upscale_with_model
class UpscalerHAT(Upscaler):
def __init__(self, dirname):
self.name = "HAT"
self.scalers = []
self.user_path = dirname
super().__init__()
for file in self.find_models(ext_filter=[".pt", ".pth"]):
name = modelloader.friendly_name(file)
scale = 4 # TODO: scale might not be 4, but we can't know without loading the model
scaler_data = UpscalerData(name, file, upscaler=self, scale=scale)
self.scalers.append(scaler_data)
def do_upscale(self, img, selected_model):
try:
model = self.load_model(selected_model)
except Exception as e:
print(f"Unable to load HAT model {selected_model}: {e}", file=sys.stderr)
return img
model.to(devices.device_esrgan) # TODO: should probably be device_hat
return upscale_with_model(
model,
img,
tile_size=opts.ESRGAN_tile, # TODO: should probably be HAT_tile
tile_overlap=opts.ESRGAN_tile_overlap, # TODO: should probably be HAT_tile_overlap
)
def load_model(self, path: str):
if not os.path.isfile(path):
raise FileNotFoundError(f"Model file {path} not found")
return modelloader.load_spandrel_model(
path,
device=devices.device_esrgan, # TODO: should probably be device_hat
expected_architecture='HAT',
)
+11 -5
@@ -61,12 +61,17 @@ def image_grid(imgs, batch_size=1, rows=None):
return grid
Grid = namedtuple("Grid", ["tiles", "tile_w", "tile_h", "image_w", "image_h", "overlap"]) class Grid(namedtuple("_Grid", ["tiles", "tile_w", "tile_h", "image_w", "image_h", "overlap"])):
@property
def tile_count(self) -> int:
"""
The total number of tiles in the grid.
"""
return sum(len(row[2]) for row in self.tiles)
def split_grid(image, tile_w=512, tile_h=512, overlap=64): def split_grid(image: Image.Image, tile_w: int = 512, tile_h: int = 512, overlap: int = 64) -> Grid:
w = image.width w, h = image.size
h = image.height
non_overlap_width = tile_w - overlap
non_overlap_height = tile_h - overlap
@@ -316,7 +321,7 @@ def resize_image(resize_mode, im, width, height, upscaler_name=None):
return res
-invalid_filename_chars = '<>:"/\\|?*\n\r\t'
+invalid_filename_chars = '#<>:"/\\|?*\n\r\t'
invalid_filename_prefix = ' '
invalid_filename_postfix = ' .'
re_nonletters = re.compile(r'[\s' + string.punctuation + ']+')
@@ -791,3 +796,4 @@ def flatten(img, bgcolor):
img = background
return img.convert('RGB')
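An illustrative check of the new tile_count property (the image size is an assumption):

    from PIL import Image

    grid = split_grid(Image.new("RGB", (1024, 1024)), tile_w=512, tile_h=512, overlap=64)
    print(grid.tile_count)  # 9: three 448px strides are needed in each direction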
+22 -9
@@ -7,7 +7,7 @@ from PIL import Image, ImageOps, ImageFilter, ImageEnhance, UnidentifiedImageErr
import gradio as gr
from modules import images as imgutil
-from modules.generation_parameters_copypaste import create_override_settings_dict, parse_generation_parameters
+from modules.infotext_utils import create_override_settings_dict, parse_generation_parameters
from modules.processing import Processed, StableDiffusionProcessingImg2Img, process_images
from modules.shared import opts, state
from modules.sd_models import get_closet_checkpoint_match
@@ -44,12 +44,14 @@ def process_batch(p, input_dir, output_dir, inpaint_mask_dir, args, to_scale=Fal
steps = p.steps
override_settings = p.override_settings
sd_model_checkpoint_override = get_closet_checkpoint_match(override_settings.get("sd_model_checkpoint", None))
batch_results = None
discard_further_results = False
for i, image in enumerate(images):
state.job = f"{i+1} out of {len(images)}"
if state.skipped:
state.skipped = False
-if state.interrupted:
+if state.interrupted or state.stopping_generation:
break
try:
@@ -127,7 +129,21 @@ def process_batch(p, input_dir, output_dir, inpaint_mask_dir, args, to_scale=Fal
if proc is None:
p.override_settings.pop('save_images_replace_action', None)
-process_images(p)
+proc = process_images(p)
if not discard_further_results and proc:
if batch_results:
batch_results.images.extend(proc.images)
batch_results.infotexts.extend(proc.infotexts)
else:
batch_results = proc
if 0 <= shared.opts.img2img_batch_show_results_limit < len(batch_results.images):
discard_further_results = True
batch_results.images = batch_results.images[:int(shared.opts.img2img_batch_show_results_limit)]
batch_results.infotexts = batch_results.infotexts[:int(shared.opts.img2img_batch_show_results_limit)]
return batch_results
def img2img(id_task: str, mode: int, prompt: str, negative_prompt: str, prompt_styles, init_img, sketch, init_img_with_mask, inpaint_color_sketch, inpaint_color_sketch_orig, init_img_inpaint, init_mask_inpaint, steps: int, sampler_name: str, mask_blur: int, mask_alpha: float, inpainting_fill: int, n_iter: int, batch_size: int, cfg_scale: float, image_cfg_scale: float, denoising_strength: float, selected_scale_tab: int, height: int, width: int, scale_by: float, resize_mode: int, inpaint_full_res: bool, inpaint_full_res_padding: int, inpainting_mask_invert: int, img2img_batch_input_dir: str, img2img_batch_output_dir: str, img2img_batch_inpaint_mask_dir: str, override_settings_texts, img2img_batch_use_png_info: bool, img2img_batch_png_info_props: list, img2img_batch_png_info_dir: str, request: gr.Request, *args):
@@ -206,16 +222,13 @@ def img2img(id_task: str, mode: int, prompt: str, negative_prompt: str, prompt_s
if shared.opts.enable_console_prompts: if shared.opts.enable_console_prompts:
print(f"\nimg2img: {prompt}", file=shared.progress_print_out) print(f"\nimg2img: {prompt}", file=shared.progress_print_out)
if mask:
p.extra_generation_params["Mask blur"] = mask_blur
with closing(p): with closing(p):
if is_batch: if is_batch:
assert not shared.cmd_opts.hide_ui_dir_config, "Launched with --hide-ui-dir-config, batch img2img disabled" assert not shared.cmd_opts.hide_ui_dir_config, "Launched with --hide-ui-dir-config, batch img2img disabled"
processed = process_batch(p, img2img_batch_input_dir, img2img_batch_output_dir, img2img_batch_inpaint_mask_dir, args, to_scale=selected_scale_tab == 1, scale_by=scale_by, use_png_info=img2img_batch_use_png_info, png_info_props=img2img_batch_png_info_props, png_info_dir=img2img_batch_png_info_dir)
process_batch(p, img2img_batch_input_dir, img2img_batch_output_dir, img2img_batch_inpaint_mask_dir, args, to_scale=selected_scale_tab == 1, scale_by=scale_by, use_png_info=img2img_batch_use_png_info, png_info_props=img2img_batch_png_info_props, png_info_dir=img2img_batch_png_info_dir) if processed is None:
processed = Processed(p, [], p.seed, "")
processed = Processed(p, [], p.seed, "")
else: else:
processed = modules.scripts.scripts_img2img.run(p, *args) processed = modules.scripts.scripts_img2img.run(p, *args)
if processed is None: if processed is None:
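The accumulation logic added to process_batch is self-contained enough to check in isolation. A minimal sketch of the same extend-then-cap pattern, with a stand-in Result class instead of modules.processing.Processed (the class and helper names here are illustrative; the limit semantics follow the diff, where -1 disables the cap):

    # Standalone sketch of the accumulation/cap pattern used in process_batch above.
    class Result:
        def __init__(self, images, infotexts):
            self.images = images
            self.infotexts = infotexts

    def accumulate(batches, limit=-1):
        combined = None
        discard_further = False
        for proc in batches:
            if discard_further or proc is None:
                continue
            if combined:
                combined.images.extend(proc.images)
                combined.infotexts.extend(proc.infotexts)
            else:
                combined = proc
            # same condition as the diff: 0 <= limit < len(images) triggers truncation
            if 0 <= limit < len(combined.images):
                discard_further = True
                combined.images = combined.images[:int(limit)]
                combined.infotexts = combined.infotexts[:int(limit)]
        return combined

    r = accumulate([Result(["a", "b"], ["ia", "ib"]), Result(["c"], ["ic"])], limit=2)
    assert r.images == ["a", "b"] and r.infotexts == ["ia", "ib"]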
modules/import_hook.py  +11
@@ -3,3 +3,14 @@ import sys
 # this will break any attempt to import xformers which will prevent stability diffusion repo from trying to use it
 if "--xformers" not in "".join(sys.argv):
     sys.modules["xformers"] = None
+
+# Hack to fix a changed import in torchvision 0.17+, which otherwise breaks
+# basicsr; see https://github.com/AUTOMATIC1111/stable-diffusion-webui/issues/13985
+try:
+    import torchvision.transforms.functional_tensor  # noqa: F401
+except ImportError:
+    try:
+        import torchvision.transforms.functional as functional
+        sys.modules["torchvision.transforms.functional_tensor"] = functional
+    except ImportError:
+        pass  # shrug...
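The fix relies on a general property of Python's import system: an entry placed in sys.modules short-circuits the normal file lookup, so later imports of the removed module name resolve to the substitute. The old-name alias in the next file uses the same trick. A minimal, self-contained demonstration with a hypothetical module name:

    import sys
    import types

    # Any later "import old_name" gets the substitute, no file lookup happens.
    substitute = types.ModuleType("old_name")
    substitute.rgb_to_grayscale = lambda img: img  # stand-in attribute

    sys.modules["old_name"] = substitute

    import old_name  # resolved straight from sys.modules
    assert old_name.rgb_to_grayscale("x") == "x"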
modules/infotext_utils.py

@@ -1,22 +1,51 @@
+from __future__ import annotations
 import base64
 import io
 import json
 import os
 import re
+import sys

 import gradio as gr
 from modules.paths import data_path
-from modules import shared, ui_tempdir, script_callbacks, processing
+from modules import shared, ui_tempdir, script_callbacks, processing, infotext_versions
 from PIL import Image

+sys.modules['modules.generation_parameters_copypaste'] = sys.modules[__name__]  # alias for old name

 re_param_code = r'\s*(\w[\w \-/]+):\s*("(?:\\.|[^\\"])+"|[^,]*)(?:,|$)'
 re_param = re.compile(re_param_code)
 re_imagesize = re.compile(r"^(\d+)x(\d+)$")
 re_hypernet_hash = re.compile("\(([0-9a-f]+)\)$")
 type_of_gr_update = type(gr.update())

-paste_fields = {}
-registered_param_bindings = []
+quote_swap = str.maketrans('\'"', '"\'')
+info_json_keys = set()
+
+
+def info_json_dumps(data):
+    """encode data into json string, but swap single and double quotes to reduce escaping issues"""
+    return json.dumps(data, ensure_ascii=False, separators=(',', ':')).translate(quote_swap)
+
+
+def info_json_loads(info_json):
+    """decode json string into info data, but swap single and double quotes to reduce escaping issues"""
+    return json.loads(info_json.translate(quote_swap))
+
+
+def build_infotext(info: dict):
+    for info_json_key in info_json_keys:
+        if info_json_key in info:
+            info[info_json_key] = info_json_dumps(info[info_json_key])
+    return ", ".join([k if k == v else f'{k}: {quote(v)}' for k, v in info.items() if v is not None])
+
+
+def register_info_json(key):
+    """register an infotext key as infojson
+    after a key is registered, a json compatible data structure like dict or list can be used as a value in
+    generation_parameters and extra_generation_parameters
+    """
+    global info_json_keys
+    info_json_keys.add(key)


 class ParamBinding:
@@ -30,6 +59,23 @@ class ParamBinding:
         self.paste_field_names = paste_field_names or []


+class PasteField(tuple):
+    def __new__(cls, component, target, *, api=None):
+        return super().__new__(cls, (component, target))
+
+    def __init__(self, component, target, *, api=None):
+        super().__init__()
+        self.api = api
+        self.component = component
+        self.label = target if isinstance(target, str) else None
+        self.function = target if callable(target) else None
+
+
+paste_fields: dict[str, dict] = {}
+registered_param_bindings: list[ParamBinding] = []
+
+
 def reset():
     paste_fields.clear()
     registered_param_bindings.clear()
@@ -82,6 +128,12 @@ def image_from_url_text(filedata):
 def add_paste_fields(tabname, init_img, fields, override_settings_component=None):
+    if fields:
+        for i in range(len(fields)):
+            if not isinstance(fields[i], PasteField):
+                fields[i] = PasteField(*fields[i])
+
     paste_fields[tabname] = {"init_img": init_img, "fields": fields, "override_settings_component": override_settings_component}

     # backwards compatibility for existing extensions
@@ -113,7 +165,6 @@ def register_paste_params_button(binding: ParamBinding):
 def connect_paste_params_buttons():
-    binding: ParamBinding
     for binding in registered_param_bindings:
         destination_image_component = paste_fields[binding.tabname]["init_img"]
         fields = paste_fields[binding.tabname]["fields"]
@@ -207,7 +258,7 @@ def restore_old_hires_fix_params(res):
     res['Hires resize-2'] = height

-def parse_generation_parameters(x: str):
+def parse_generation_parameters(x: str, skip_fields: list[str] | None = None):
     """parses generation parameters string, the one you see in text field under the picture in UI:
 ```
 girl with an artist's beret, determined, blue eyes, desert scene, computer monitors, heavy makeup, by Alphonse Mucha and Charlie Bowater, ((eyeshadow)), (coquettish), detailed, intricate
@@ -217,6 +268,8 @@ Steps: 20, Sampler: Euler a, CFG scale: 7, Seed: 965400086, Size: 512x512, Model
     returns a dict with field values
     """
+    if skip_fields is None:
+        skip_fields = shared.opts.infotext_skip_pasting

     res = {}
@@ -289,6 +342,18 @@ Steps: 20, Sampler: Euler a, CFG scale: 7, Seed: 965400086, Size: 512x512, Model
     if "Hires negative prompt" not in res:
         res["Hires negative prompt"] = ""

+    if "Mask mode" not in res:
+        res["Mask mode"] = "Inpaint masked"
+
+    if "Masked content" not in res:
+        res["Masked content"] = 'original'
+
+    if "Inpaint area" not in res:
+        res["Inpaint area"] = "Whole picture"
+
+    if "Masked area padding" not in res:
+        res["Masked area padding"] = 32
+
     restore_old_hires_fix_params(res)

     # Missing RNG means the default was set, which is GPU RNG
@@ -313,6 +378,24 @@ Steps: 20, Sampler: Euler a, CFG scale: 7, Seed: 965400086, Size: 512x512, Model
     if "VAE Decoder" not in res:
         res["VAE Decoder"] = "Full"

+    if "FP8 weight" not in res:
+        res["FP8 weight"] = "Disable"
+
+    if "Cache FP16 weight for LoRA" not in res and res["FP8 weight"] != "Disable":
+        res["Cache FP16 weight for LoRA"] = False
+
+    for key in info_json_keys:
+        if key in res:
+            try:
+                res[key] = info_json_loads(res[key])
+            except Exception:
+                print(f'Error parsing "{key}: {res[key]}"')
+
+    infotext_versions.backcompat(res)
+
+    for key in skip_fields:
+        res.pop(key, None)
+
     return res
@@ -361,13 +444,57 @@ def create_override_settings_dict(text_pairs):
     return res

+def get_override_settings(params, *, skip_fields=None):
+    """Returns a list of settings overrides from the infotext parameters dictionary.
+
+    This function checks the `params` dictionary for any keys that correspond to settings in `shared.opts` and returns
+    a list of tuples containing the parameter name, setting name, and new value cast to correct type.
+
+    It checks for conditions before adding an override:
+    - ignores settings that match the current value
+    - ignores parameter keys present in skip_fields argument.
+
+    Example input:
+        {"Clip skip": "2"}
+
+    Example output:
+        [("Clip skip", "CLIP_stop_at_last_layers", 2)]
+    """
+    res = []
+
+    mapping = [(info.infotext, k) for k, info in shared.opts.data_labels.items() if info.infotext]
+    for param_name, setting_name in mapping + infotext_to_setting_name_mapping:
+        if param_name in (skip_fields or {}):
+            continue
+
+        v = params.get(param_name, None)
+        if v is None:
+            continue
+
+        if setting_name == "sd_model_checkpoint" and shared.opts.disable_weights_auto_swap:
+            continue
+
+        v = shared.opts.cast_value(setting_name, v)
+        current_value = getattr(shared.opts, setting_name, None)
+
+        if v == current_value:
+            continue
+
+        res.append((param_name, setting_name, v))
+
+    return res
+
+
 def connect_paste(button, paste_fields, input_comp, override_settings_component, tabname):
     def paste_func(prompt):
         if not prompt and not shared.cmd_opts.hide_ui_dir_config:
             filename = os.path.join(data_path, "params.txt")
-            if os.path.exists(filename):
+            try:
                 with open(filename, "r", encoding="utf8") as file:
                     prompt = file.read()
+            except OSError:
+                pass

         params = parse_generation_parameters(prompt)
         script_callbacks.infotext_pasted_callback(prompt, params)
@@ -389,6 +516,8 @@ def connect_paste(button, paste_fields, input_comp, override_settings_component,
                 if valtype == bool and v == "False":
                     val = False
+                elif valtype == int:
+                    val = float(v)
                 else:
                     val = valtype(v)
@@ -402,29 +531,9 @@ def connect_paste(button, paste_fields, input_comp, override_settings_component,
     already_handled_fields = {key: 1 for _, key in paste_fields}

     def paste_settings(params):
-        vals = {}
+        vals = get_override_settings(params, skip_fields=already_handled_fields)

-        mapping = [(info.infotext, k) for k, info in shared.opts.data_labels.items() if info.infotext]
-        for param_name, setting_name in mapping + infotext_to_setting_name_mapping:
-            if param_name in already_handled_fields:
-                continue
-
-            v = params.get(param_name, None)
-            if v is None:
-                continue
-
-            if setting_name == "sd_model_checkpoint" and shared.opts.disable_weights_auto_swap:
-                continue
-
-            v = shared.opts.cast_value(setting_name, v)
-            current_value = getattr(shared.opts, setting_name, None)
-
-            if v == current_value:
-                continue
-
-            vals[param_name] = v
-
-        vals_pairs = [f"{k}: {v}" for k, v in vals.items()]
+        vals_pairs = [f"{infotext_text}: {value}" for infotext_text, setting_name, value in vals]

         return gr.Dropdown.update(value=vals_pairs, choices=vals_pairs, visible=bool(vals_pairs))
@@ -443,3 +552,4 @@ def connect_paste(button, paste_fields, input_comp, override_settings_component,
         outputs=[],
         show_progress=False,
     )
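The quote-swap helpers above let structured values (dicts, lists) ride inside the comma-separated infotext without heavy escaping: json.dumps emits double quotes, the swap turns them into single quotes so they coexist with the double-quote quoting used by re_param_code, and decoding swaps them back. A standalone round-trip sketch using the same two functions (the payload keys are made up for illustration):

    import json

    quote_swap = str.maketrans('\'"', '"\'')

    def info_json_dumps(data):
        return json.dumps(data, ensure_ascii=False, separators=(',', ':')).translate(quote_swap)

    def info_json_loads(s):
        return json.loads(s.translate(quote_swap))

    payload = {"schedule bias": 1.0, "mask influence": 0}
    encoded = info_json_dumps(payload)
    # double quotes became single quotes, so the value does not fight the
    # "Key: value, Key: value" quoting rules of the surrounding infotext
    assert encoded == "{'schedule bias':1.0,'mask influence':0}"
    assert info_json_loads(encoded) == payload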
modules/infotext_versions.py  +39
@@ -0,0 +1,39 @@
+from modules import shared
+from packaging import version
+import re
+
+
+v160 = version.parse("1.6.0")
+v170_tsnr = version.parse("v1.7.0-225")
+
+
+def parse_version(text):
+    if text is None:
+        return None
+
+    m = re.match(r'([^-]+-[^-]+)-.*', text)
+    if m:
+        text = m.group(1)
+
+    try:
+        return version.parse(text)
+    except Exception:
+        return None
+
+
+def backcompat(d):
+    """Checks infotext Version field, and enables backwards compatibility options according to it."""
+
+    if not shared.opts.auto_backcompat:
+        return
+
+    ver = parse_version(d.get("Version"))
+    if ver is None:
+        return
+
+    if ver < v160:
+        d["Old prompt editing timelines"] = True
+
+    if ver < v170_tsnr:
+        d["Downcast alphas_cumprod"] = True
modules/initialize.py  +2 -3
@@ -1,5 +1,6 @@
 import importlib
 import logging
+import os
 import sys
 import warnings
 from threading import Thread
@@ -18,6 +19,7 @@ def imports():
     warnings.filterwarnings(action="ignore", category=DeprecationWarning, module="pytorch_lightning")
     warnings.filterwarnings(action="ignore", category=UserWarning, module="torchvision")

+    os.environ.setdefault('GRADIO_ANALYTICS_ENABLED', 'False')
     import gradio  # noqa: F401
     startup_timer.record("import gradio")
@@ -54,9 +56,6 @@ def initialize():
     initialize_util.configure_sigint_handler()
     initialize_util.configure_opts_onchange()

-    from modules import modelloader
-    modelloader.cleanup_models()
-
     from modules import sd_models
     sd_models.setup_model()
     startup_timer.record("setup SD model")
modules/initialize_util.py  +7 -1
@@ -150,10 +150,14 @@ def dumpstacks():

 def configure_sigint_handler():
     # make the program just exit at ctrl+c without waiting for anything
+    from modules import shared
+
     def sigint_handler(sig, frame):
         print(f'Interrupted with signal {sig} in {frame}')

-        dumpstacks()
+        if shared.opts.dump_stacks_on_signal:
+            dumpstacks()

         os._exit(0)
@@ -173,6 +177,8 @@ def configure_opts_onchange():
     shared.opts.onchange("temp_dir", ui_tempdir.on_tmpdir_changed)
     shared.opts.onchange("gradio_theme", shared.reload_gradio_theme)
     shared.opts.onchange("cross_attention_optimization", wrap_queued_call(lambda: sd_hijack.model_hijack.redo_hijack(shared.sd_model)), call=False)
+    shared.opts.onchange("fp8_storage", wrap_queued_call(lambda: sd_models.reload_model_weights()), call=False)
+    shared.opts.onchange("cache_fp16_weight", wrap_queued_call(lambda: sd_models.reload_model_weights(forced_reload=True)), call=False)
     startup_timer.record("opts onchange")
modules/interrogate.py  +2 -2
@@ -10,7 +10,7 @@ import torch.hub
 from torchvision import transforms
 from torchvision.transforms.functional import InterpolationMode

-from modules import devices, paths, shared, lowvram, modelloader, errors
+from modules import devices, paths, shared, lowvram, modelloader, errors, torch_utils

 blip_image_eval_size = 384
 clip_model_name = 'ViT-L/14'
@@ -131,7 +131,7 @@ class InterrogateModels:
             self.clip_model = self.clip_model.to(devices.device_interrogate)

-        self.dtype = next(self.clip_model.parameters()).dtype
+        self.dtype = torch_utils.get_param(self.clip_model).dtype

     def send_clip_to_ram(self):
         if not shared.opts.interrogate_keep_models_in_memory:
modules/launch_utils.py  +47 -21
@@ -6,6 +6,7 @@ import os
 import shutil
 import sys
 import importlib.util
+import importlib.metadata
 import platform
 import json
 from functools import lru_cache
@@ -26,8 +27,7 @@ dir_repos = "repositories"
 # Whether to default to printing command output
 default_command_live = (os.environ.get('WEBUI_LAUNCH_LIVE_OUTPUT') == "1")

-if 'GRADIO_ANALYTICS_ENABLED' not in os.environ:
-    os.environ['GRADIO_ANALYTICS_ENABLED'] = 'False'
+os.environ.setdefault('GRADIO_ANALYTICS_ENABLED', 'False')


 def check_python_version():
@@ -119,11 +119,16 @@ def run(command, desc=None, errdesc=None, custom_env=None, live: bool = default_
 def is_installed(package):
     try:
-        spec = importlib.util.find_spec(package)
-    except ModuleNotFoundError:
-        return False
+        dist = importlib.metadata.distribution(package)
+    except importlib.metadata.PackageNotFoundError:
+        try:
+            spec = importlib.util.find_spec(package)
+        except ModuleNotFoundError:
+            return False

-    return spec is not None
+        return spec is not None
+
+    return dist is not None


 def repo_dir(name):
@@ -239,11 +244,14 @@ def list_extensions(settings_file):
     settings = {}

     try:
-        if os.path.isfile(settings_file):
-            with open(settings_file, "r", encoding="utf8") as file:
-                settings = json.load(file)
+        with open(settings_file, "r", encoding="utf8") as file:
+            settings = json.load(file)
+    except FileNotFoundError:
+        pass
     except Exception:
-        errors.report("Could not load settings", exc_info=True)
+        errors.report(f'\nCould not load settings\nThe config file "{settings_file}" is likely corrupted\nIt has been moved to the "tmp/config.json"\nReverting config to default\n\n''', exc_info=True)
+        os.replace(settings_file, os.path.join(script_path, "tmp", "config.json"))
+        settings = {}

     disabled_extensions = set(settings.get('disabled_extensions', []))
     disable_all_extensions = settings.get('disable_all_extensions', 'none')
@@ -308,24 +316,44 @@ def requirements_met(requirements_file):
 def prepare_environment():
-    torch_index_url = os.environ.get('TORCH_INDEX_URL', "https://download.pytorch.org/whl/cu118")
-    torch_command = os.environ.get('TORCH_COMMAND', f"pip install torch==2.0.1 torchvision==0.15.2 --extra-index-url {torch_index_url}")
+    torch_index_url = os.environ.get('TORCH_INDEX_URL', "https://download.pytorch.org/whl/cu121")
+    torch_command = os.environ.get('TORCH_COMMAND', f"pip install torch==2.1.2 torchvision==0.16.2 --extra-index-url {torch_index_url}")
+    if args.use_ipex:
+        if platform.system() == "Windows":
+            # The "Nuullll/intel-extension-for-pytorch" wheels were built from IPEX source for Intel Arc GPU: https://github.com/intel/intel-extension-for-pytorch/tree/xpu-main
+            # This is NOT an Intel official release so please use it at your own risk!!
+            # See https://github.com/Nuullll/intel-extension-for-pytorch/releases/tag/v2.0.110%2Bxpu-master%2Bdll-bundle for details.
+            #
+            # Strengths (over official IPEX 2.0.110 windows release):
+            #   - AOT build (for Arc GPU only) to eliminate JIT compilation overhead: https://github.com/intel/intel-extension-for-pytorch/issues/399
+            #   - Bundles minimal oneAPI 2023.2 dependencies into the python wheels, so users don't need to install oneAPI for the whole system.
+            #   - Provides a compatible torchvision wheel: https://github.com/intel/intel-extension-for-pytorch/issues/465
+            # Limitation:
+            #   - Only works for python 3.10
+            url_prefix = "https://github.com/Nuullll/intel-extension-for-pytorch/releases/download/v2.0.110%2Bxpu-master%2Bdll-bundle"
+            torch_command = os.environ.get('TORCH_COMMAND', f"pip install {url_prefix}/torch-2.0.0a0+gite9ebda2-cp310-cp310-win_amd64.whl {url_prefix}/torchvision-0.15.2a0+fa99a53-cp310-cp310-win_amd64.whl {url_prefix}/intel_extension_for_pytorch-2.0.110+gitc6ea20b-cp310-cp310-win_amd64.whl")
+        else:
+            # Using official IPEX release for linux since it's already an AOT build.
+            # However, users still have to install oneAPI toolkit and activate oneAPI environment manually.
+            # See https://intel.github.io/intel-extension-for-pytorch/index.html#installation for details.
+            torch_index_url = os.environ.get('TORCH_INDEX_URL', "https://pytorch-extension.intel.com/release-whl/stable/xpu/us/")
+            torch_command = os.environ.get('TORCH_COMMAND', f"pip install torch==2.0.0a0 intel-extension-for-pytorch==2.0.110+gitba7f6c1 --extra-index-url {torch_index_url}")
     requirements_file = os.environ.get('REQS_FILE', "requirements_versions.txt")

-    xformers_package = os.environ.get('XFORMERS_PACKAGE', 'xformers==0.0.20')
+    xformers_package = os.environ.get('XFORMERS_PACKAGE', 'xformers==0.0.23.post1')
     clip_package = os.environ.get('CLIP_PACKAGE', "https://github.com/openai/CLIP/archive/d50d76daa670286dd6cacf3bcd80b5e4823fc8e1.zip")
     openclip_package = os.environ.get('OPENCLIP_PACKAGE', "https://github.com/mlfoundations/open_clip/archive/bb6e834e9c70d9c27d0dc3ecedeebeaeb1ffad6b.zip")

+    assets_repo = os.environ.get('ASSETS_REPO', "https://github.com/AUTOMATIC1111/stable-diffusion-webui-assets.git")
     stable_diffusion_repo = os.environ.get('STABLE_DIFFUSION_REPO', "https://github.com/Stability-AI/stablediffusion.git")
     stable_diffusion_xl_repo = os.environ.get('STABLE_DIFFUSION_XL_REPO', "https://github.com/Stability-AI/generative-models.git")
     k_diffusion_repo = os.environ.get('K_DIFFUSION_REPO', 'https://github.com/crowsonkb/k-diffusion.git')
-    codeformer_repo = os.environ.get('CODEFORMER_REPO', 'https://github.com/sczhou/CodeFormer.git')
     blip_repo = os.environ.get('BLIP_REPO', 'https://github.com/salesforce/BLIP.git')

+    assets_commit_hash = os.environ.get('ASSETS_COMMIT_HASH', "6f7db241d2f8ba7457bac5ca9753331f0c266917")
     stable_diffusion_commit_hash = os.environ.get('STABLE_DIFFUSION_COMMIT_HASH', "cf1d67a6fd5ea1aa600c4df58e5b47da45f6bdbf")
     stable_diffusion_xl_commit_hash = os.environ.get('STABLE_DIFFUSION_XL_COMMIT_HASH', "45c443b316737a4ab6e40413d7794a7f5657c19f")
     k_diffusion_commit_hash = os.environ.get('K_DIFFUSION_COMMIT_HASH', "ab527a9a6d347f364e3d185ba6d714e22d80cb3c")
-    codeformer_commit_hash = os.environ.get('CODEFORMER_COMMIT_HASH', "c5b4593074ba6214284d6acd5f1719b6c5d739af")
     blip_commit_hash = os.environ.get('BLIP_COMMIT_HASH', "48211a1594f1321b00f14c9f7a5b4813144b2fb9")

     try:
@@ -352,6 +380,8 @@ def prepare_environment():
         run(f'"{python}" -m {torch_command}', "Installing torch and torchvision", "Couldn't install torch", live=True)
         startup_timer.record("install torch")

+    if args.use_ipex:
+        args.skip_torch_cuda_test = True
     if not args.skip_torch_cuda_test and not check_run_python("import torch; assert torch.cuda.is_available()"):
         raise RuntimeError(
             'Torch is not able to use GPU; '
@@ -377,18 +407,14 @@ def prepare_environment():
     os.makedirs(os.path.join(script_path, dir_repos), exist_ok=True)

+    git_clone(assets_repo, repo_dir('stable-diffusion-webui-assets'), "assets", assets_commit_hash)
     git_clone(stable_diffusion_repo, repo_dir('stable-diffusion-stability-ai'), "Stable Diffusion", stable_diffusion_commit_hash)
     git_clone(stable_diffusion_xl_repo, repo_dir('generative-models'), "Stable Diffusion XL", stable_diffusion_xl_commit_hash)
     git_clone(k_diffusion_repo, repo_dir('k-diffusion'), "K-diffusion", k_diffusion_commit_hash)
-    git_clone(codeformer_repo, repo_dir('CodeFormer'), "CodeFormer", codeformer_commit_hash)
     git_clone(blip_repo, repo_dir('BLIP'), "BLIP", blip_commit_hash)

     startup_timer.record("clone repositores")

-    if not is_installed("lpips"):
-        run_pip(f"install -r \"{os.path.join(repo_dir('CodeFormer'), 'requirements.txt')}\"", "requirements for CodeFormer")
-        startup_timer.record("install CodeFormer requirements")
-
     if not os.path.isfile(requirements_file):
         requirements_file = os.path.join(script_path, requirements_file)
@@ -441,7 +467,7 @@ def dump_sysinfo():
     import datetime

     text = sysinfo.get()
-    filename = f"sysinfo-{datetime.datetime.utcnow().strftime('%Y-%m-%d-%H-%M')}.txt"
+    filename = f"sysinfo-{datetime.datetime.utcnow().strftime('%Y-%m-%d-%H-%M')}.json"

     with open(filename, "w", encoding="utf8") as file:
         file.write(text)
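The reworked is_installed asks package metadata first (so a distribution counts as installed even when its import name differs) and only then falls back to the import machinery. A standalone copy with a few illustrative probes; stdlib modules have no distribution entry, so they exercise the fallback path:

    import importlib.metadata
    import importlib.util

    def is_installed(package):
        # same lookup order as the diff: distribution metadata first,
        # import machinery as a fallback
        try:
            dist = importlib.metadata.distribution(package)
        except importlib.metadata.PackageNotFoundError:
            try:
                spec = importlib.util.find_spec(package)
            except ModuleNotFoundError:
                return False
            return spec is not None
        return dist is not None

    assert is_installed("pip")    # found via metadata
    assert is_installed("json")   # stdlib: found via find_spec fallback
    assert not is_installed("definitely-not-a-real-package")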
modules/logging_config.py  +50 -8
@@ -1,16 +1,58 @@
-import os
 import logging
+import os
+
+try:
+    from tqdm import tqdm
+
+    class TqdmLoggingHandler(logging.Handler):
+        def __init__(self, fallback_handler: logging.Handler):
+            super().__init__()
+            self.fallback_handler = fallback_handler
+
+        def emit(self, record):
+            try:
+                # If there are active tqdm progress bars,
+                # attempt to not interfere with them.
+                if tqdm._instances:
+                    tqdm.write(self.format(record))
+                else:
+                    self.fallback_handler.emit(record)
+            except Exception:
+                self.fallback_handler.emit(record)
+
+except ImportError:
+    TqdmLoggingHandler = None


 def setup_logging(loglevel):
     if loglevel is None:
         loglevel = os.environ.get("SD_WEBUI_LOG_LEVEL")

-    if loglevel:
-        log_level = getattr(logging, loglevel.upper(), None) or logging.INFO
-        logging.basicConfig(
-            level=log_level,
-            format='%(asctime)s %(levelname)s [%(name)s] %(message)s',
-            datefmt='%Y-%m-%d %H:%M:%S',
-        )
+    if not loglevel:
+        return
+
+    if logging.root.handlers:
+        # Already configured, do not interfere
+        return
+
+    formatter = logging.Formatter(
+        '%(asctime)s %(levelname)s [%(name)s] %(message)s',
+        '%Y-%m-%d %H:%M:%S',
+    )
+
+    if os.environ.get("SD_WEBUI_RICH_LOG"):
+        from rich.logging import RichHandler
+        handler = RichHandler()
+    else:
+        handler = logging.StreamHandler()
+        handler.setFormatter(formatter)
+
+    if TqdmLoggingHandler:
+        handler = TqdmLoggingHandler(handler)
+        handler.setFormatter(formatter)
+
+    log_level = getattr(logging, loglevel.upper(), None) or logging.INFO
+    logging.root.setLevel(log_level)
+    logging.root.addHandler(handler)
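The point of the handler chain is that records emitted while a progress bar is live are routed through tqdm.write(), so the bar's line is not mangled. A minimal self-contained sketch reproducing the same idea outside the webui (requires tqdm; note it reads tqdm._instances, a tqdm internal the diff also relies on):

    import logging
    from tqdm import tqdm

    class TqdmLoggingHandler(logging.Handler):
        """Route records through tqdm.write() while any bar is active."""
        def __init__(self, fallback_handler):
            super().__init__()
            self.fallback_handler = fallback_handler

        def emit(self, record):
            try:
                if tqdm._instances:
                    tqdm.write(self.format(record))
                else:
                    self.fallback_handler.emit(record)
            except Exception:
                self.fallback_handler.emit(record)

    handler = TqdmLoggingHandler(logging.StreamHandler())
    handler.setFormatter(logging.Formatter('%(asctime)s %(levelname)s [%(name)s] %(message)s'))
    logging.basicConfig(level=logging.INFO, handlers=[handler])

    log = logging.getLogger("demo")
    for i in tqdm(range(3)):
        log.info("step %d", i)  # printed above the bar instead of breaking it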
modules/mac_specific.py  +15
@@ -1,6 +1,7 @@
 import logging

 import torch
+from torch import Tensor
 import platform
 from modules.sd_hijack_utils import CondFunc
 from packaging import version
@@ -51,6 +52,17 @@ def cumsum_fix(input, cumsum_func, *args, **kwargs):
     return cumsum_func(input, *args, **kwargs)


+# MPS workaround for https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/14046
+def interpolate_with_fp32_fallback(orig_func, *args, **kwargs) -> Tensor:
+    try:
+        return orig_func(*args, **kwargs)
+    except RuntimeError as e:
+        if "not implemented for" in str(e) and "Half" in str(e):
+            input_tensor = args[0]
+            return orig_func(input_tensor.to(torch.float32), *args[1:], **kwargs).to(input_tensor.dtype)
+        else:
+            print(f"An unexpected RuntimeError occurred: {str(e)}")
+
 if has_mps:
     if platform.mac_ver()[0].startswith("13.2."):
         # MPS workaround for https://github.com/pytorch/pytorch/issues/95188, thanks to danieldk (https://github.com/explosion/curated-transformers/pull/124)
@@ -77,6 +89,9 @@ if has_mps:
         # MPS workaround for https://github.com/pytorch/pytorch/issues/96113
         CondFunc('torch.nn.functional.layer_norm', lambda orig_func, x, normalized_shape, weight, bias, eps, **kwargs: orig_func(x.float(), normalized_shape, weight.float() if weight is not None else None, bias.float() if bias is not None else bias, eps).to(x.dtype), lambda _, input, *args, **kwargs: len(args) == 4 and input.device.type == 'mps')

+        # MPS workaround for https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/14046
+        CondFunc('torch.nn.functional.interpolate', interpolate_with_fp32_fallback, None)
+
         # MPS workaround for https://github.com/pytorch/pytorch/issues/92311
         if platform.processor() == 'i386':
             for funcName in ['torch.argmax', 'torch.Tensor.argmax']:
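The fallback retries the failing op in float32 and casts the result back to the tensor's original dtype. A standalone sketch of the same pattern on plain F.interpolate (CPU stand-in, since the missing-Half-kernel failure is backend dependent; unlike the diffed function, this sketch re-raises unexpected errors):

    import torch
    import torch.nn.functional as F

    def interpolate_fp32_fallback(x, **kwargs):
        try:
            return F.interpolate(x, **kwargs)
        except RuntimeError as e:
            # same detection as the diff: a backend without a Half kernel
            if "not implemented for" in str(e) and "Half" in str(e):
                return F.interpolate(x.to(torch.float32), **kwargs).to(x.dtype)
            raise

    x = torch.randn(1, 3, 8, 8, dtype=torch.float16)
    y = interpolate_fp32_fallback(x, scale_factor=2, mode="bicubic")
    assert y.dtype == torch.float16 and y.shape == (1, 3, 16, 16)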
modules/masking.py  +9 -34
@@ -3,40 +3,15 @@ from PIL import Image, ImageFilter, ImageOps

 def get_crop_region(mask, pad=0):
     """finds a rectangular region that contains all masked ares in an image. Returns (x1, y1, x2, y2) coordinates of the rectangle.
-    For example, if a user has painted the top-right part of a 512x512 image", the result may be (256, 0, 512, 256)"""
-
-    h, w = mask.shape
-
-    crop_left = 0
-    for i in range(w):
-        if not (mask[:, i] == 0).all():
-            break
-        crop_left += 1
-
-    crop_right = 0
-    for i in reversed(range(w)):
-        if not (mask[:, i] == 0).all():
-            break
-        crop_right += 1
-
-    crop_top = 0
-    for i in range(h):
-        if not (mask[i] == 0).all():
-            break
-        crop_top += 1
-
-    crop_bottom = 0
-    for i in reversed(range(h)):
-        if not (mask[i] == 0).all():
-            break
-        crop_bottom += 1
-
-    return (
-        int(max(crop_left-pad, 0)),
-        int(max(crop_top-pad, 0)),
-        int(min(w - crop_right + pad, w)),
-        int(min(h - crop_bottom + pad, h))
-    )
+    For example, if a user has painted the top-right part of a 512x512 image, the result may be (256, 0, 512, 256)"""
+    mask_img = mask if isinstance(mask, Image.Image) else Image.fromarray(mask)
+    box = mask_img.getbbox()
+    if box:
+        x1, y1, x2, y2 = box
+    else:  # when no box is found
+        x1, y1 = mask_img.size
+        x2 = y2 = 0
+    return max(x1 - pad, 0), max(y1 - pad, 0), min(x2 + pad, mask_img.size[0]), min(y2 + pad, mask_img.size[1])


 def expand_crop_region(crop_region, processing_width, processing_height, image_width, image_height):
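The rewrite delegates the four per-edge Python scan loops to PIL's C-implemented getbbox(). A worked check of the docstring's own example, including the pad clamping the new code applies:

    from PIL import Image, ImageDraw

    # paint the top-right quadrant of a 512x512 mask
    mask = Image.new("L", (512, 512), 0)
    ImageDraw.Draw(mask).rectangle([256, 0, 511, 255], fill=255)

    box = mask.getbbox()  # (left, upper, right, lower), right/lower exclusive
    assert box == (256, 0, 512, 256)

    # with padding, the region is clamped to the image bounds as in the new code
    pad = 32
    x1, y1, x2, y2 = box
    padded = (max(x1 - pad, 0), max(y1 - pad, 0), min(x2 + pad, 512), min(y2 + pad, 512))
    assert padded == (224, 0, 512, 288)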
modules/modelloader.py  +41 -51
@@ -1,13 +1,20 @@
 from __future__ import annotations

-import os
-import shutil
 import importlib
+import logging
+import os
+from typing import TYPE_CHECKING
 from urllib.parse import urlparse

+import torch
+
 from modules import shared
 from modules.upscaler import Upscaler, UpscalerLanczos, UpscalerNearest, UpscalerNone
-from modules.paths import script_path, models_path
+
+if TYPE_CHECKING:
+    import spandrel
+
+logger = logging.getLogger(__name__)


 def load_file_from_url(
@@ -90,54 +97,6 @@ def friendly_name(file: str):
     return model_name


-def cleanup_models():
-    # This code could probably be more efficient if we used a tuple list or something to store the src/destinations
-    # and then enumerate that, but this works for now. In the future, it'd be nice to just have every "model" scaler
-    # somehow auto-register and just do these things...
-    root_path = script_path
-    src_path = models_path
-    dest_path = os.path.join(models_path, "Stable-diffusion")
-    move_files(src_path, dest_path, ".ckpt")
-    move_files(src_path, dest_path, ".safetensors")
-    src_path = os.path.join(root_path, "ESRGAN")
-    dest_path = os.path.join(models_path, "ESRGAN")
-    move_files(src_path, dest_path)
-    src_path = os.path.join(models_path, "BSRGAN")
-    dest_path = os.path.join(models_path, "ESRGAN")
-    move_files(src_path, dest_path, ".pth")
-    src_path = os.path.join(root_path, "gfpgan")
-    dest_path = os.path.join(models_path, "GFPGAN")
-    move_files(src_path, dest_path)
-    src_path = os.path.join(root_path, "SwinIR")
-    dest_path = os.path.join(models_path, "SwinIR")
-    move_files(src_path, dest_path)
-    src_path = os.path.join(root_path, "repositories/latent-diffusion/experiments/pretrained_models/")
-    dest_path = os.path.join(models_path, "LDSR")
-    move_files(src_path, dest_path)
-
-
-def move_files(src_path: str, dest_path: str, ext_filter: str = None):
-    try:
-        os.makedirs(dest_path, exist_ok=True)
-        if os.path.exists(src_path):
-            for file in os.listdir(src_path):
-                fullpath = os.path.join(src_path, file)
-                if os.path.isfile(fullpath):
-                    if ext_filter is not None:
-                        if ext_filter not in file:
-                            continue
-                    print(f"Moving {file} from {src_path} to {dest_path}.")
-                    try:
-                        shutil.move(fullpath, dest_path)
-                    except Exception:
-                        pass
-            if len(os.listdir(src_path)) == 0:
-                print(f"Removing empty folder: {src_path}")
-                shutil.rmtree(src_path, True)
-    except Exception:
-        pass
-
-
 def load_upscalers():
     # We can only do this 'magic' method to dynamically load upscalers if they are referenced,
     # so we'll try to import any _model.py files before looking in __subclasses__
@@ -177,3 +136,34 @@ def load_upscalers():
         # Special case for UpscalerNone keeps it at the beginning of the list.
         key=lambda x: x.name.lower() if not isinstance(x.scaler, (UpscalerNone, UpscalerLanczos, UpscalerNearest)) else ""
     )
+
+
+def load_spandrel_model(
+    path: str | os.PathLike,
+    *,
+    device: str | torch.device | None,
+    prefer_half: bool = False,
+    dtype: str | torch.dtype | None = None,
+    expected_architecture: str | None = None,
+) -> spandrel.ModelDescriptor:
+    import spandrel
+    model_descriptor = spandrel.ModelLoader(device=device).load_from_file(str(path))
+    if expected_architecture and model_descriptor.architecture != expected_architecture:
+        logger.warning(
+            f"Model {path!r} is not a {expected_architecture!r} model (got {model_descriptor.architecture!r})",
+        )
+    half = False
+    if prefer_half:
+        if model_descriptor.supports_half:
+            model_descriptor.model.half()
+            half = True
+        else:
+            logger.info("Model %s does not support half precision, ignoring --half", path)
+    if dtype:
+        model_descriptor.model.to(dtype=dtype)
+    model_descriptor.model.eval()
+    logger.debug(
+        "Loaded %s from %s (device=%s, half=%s, dtype=%s)",
+        model_descriptor, path, device, half, dtype,
+    )
+    return model_descriptor
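A hedged usage sketch for the new loader. The model path and architecture string are illustrative rather than files shipped with the repo, and the callable-descriptor usage follows spandrel's documented API (descriptors wrap the model and are callable on image tensors):

    import torch
    from modules import modelloader  # assumes the webui package is importable

    device = "cuda" if torch.cuda.is_available() else "cpu"
    model = modelloader.load_spandrel_model(
        "models/ESRGAN/4x_example.pth",   # hypothetical model file
        device=device,
        prefer_half=False,                # keep fp32 so the demo input below matches
        expected_architecture="ESRGAN",   # a mismatch only logs a warning
    )

    with torch.no_grad():
        # per spandrel's API, descriptors take (N, C, H, W) image tensors in [0, 1]
        upscaled = model(torch.rand(1, 3, 64, 64, device=device))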
modules/models/diffusion/ddpm_edit.py  +6 -1
@@ -24,10 +24,15 @@ from pytorch_lightning.utilities.distributed import rank_zero_only
 from ldm.util import log_txt_as_img, exists, default, ismap, isimage, mean_flat, count_params, instantiate_from_config
 from ldm.modules.ema import LitEma
 from ldm.modules.distributions.distributions import normal_kl, DiagonalGaussianDistribution
-from ldm.models.autoencoder import VQModelInterface, IdentityFirstStage, AutoencoderKL
+from ldm.models.autoencoder import IdentityFirstStage, AutoencoderKL
 from ldm.modules.diffusionmodules.util import make_beta_schedule, extract_into_tensor, noise_like
 from ldm.models.diffusion.ddim import DDIMSampler

+try:
+    from ldm.models.autoencoder import VQModelInterface
+except Exception:
+    class VQModelInterface:
+        pass
+

 __conditioning_keys__ = {'concat': 'c_concat',
                          'crossattn': 'c_crossattn',
modules/options.py  +99 -15
@@ -1,20 +1,24 @@
+import os
 import json
 import sys
+from dataclasses import dataclass

 import gradio as gr

 from modules import errors
 from modules.shared_cmd_options import cmd_opts
+from modules.paths_internal import script_path


 class OptionInfo:
-    def __init__(self, default=None, label="", component=None, component_args=None, onchange=None, section=None, refresh=None, comment_before='', comment_after='', infotext=None, restrict_api=False):
+    def __init__(self, default=None, label="", component=None, component_args=None, onchange=None, section=None, refresh=None, comment_before='', comment_after='', infotext=None, restrict_api=False, category_id=None):
         self.default = default
         self.label = label
         self.component = component
         self.component_args = component_args
         self.onchange = onchange
         self.section = section
+        self.category_id = category_id
         self.refresh = refresh
         self.do_not_save = False
@@ -63,7 +67,11 @@ class OptionHTML(OptionInfo):
 def options_section(section_identifier, options_dict):
     for v in options_dict.values():
-        v.section = section_identifier
+        if len(section_identifier) == 2:
+            v.section = section_identifier
+        elif len(section_identifier) == 3:
+            v.section = section_identifier[0:2]
+            v.category_id = section_identifier[2]

     return options_dict
@@ -76,7 +84,7 @@ class Options:
     def __init__(self, data_labels: dict[str, OptionInfo], restricted_opts):
         self.data_labels = data_labels
-        self.data = {k: v.default for k, v in self.data_labels.items()}
+        self.data = {k: v.default for k, v in self.data_labels.items() if not v.do_not_save}
         self.restricted_opts = restricted_opts

     def __setattr__(self, key, value):
@@ -85,18 +93,35 @@ class Options:
         if self.data is not None:
             if key in self.data or key in self.data_labels:
+                # Check that settings aren't globally frozen
                 assert not cmd_opts.freeze_settings, "changing settings is disabled"

+                # Get the info related to the setting being changed
                 info = self.data_labels.get(key, None)
                 if info.do_not_save:
                     return

+                # Restrict component arguments
                 comp_args = info.component_args if info else None
                 if isinstance(comp_args, dict) and comp_args.get('visible', True) is False:
-                    raise RuntimeError(f"not possible to set {key} because it is restricted")
+                    raise RuntimeError(f"not possible to set '{key}' because it is restricted")

+                # Check that this section isn't frozen
+                if cmd_opts.freeze_settings_in_sections is not None:
+                    frozen_sections = list(map(str.strip, cmd_opts.freeze_settings_in_sections.split(',')))  # Trim whitespace from section names
+                    section_key = info.section[0]
+                    section_name = info.section[1]
+                    assert section_key not in frozen_sections, f"not possible to set '{key}' because settings in section '{section_name}' ({section_key}) are frozen with --freeze-settings-in-sections"
+
+                # Check that this section of the settings isn't frozen
+                if cmd_opts.freeze_specific_settings is not None:
+                    frozen_keys = list(map(str.strip, cmd_opts.freeze_specific_settings.split(',')))  # Trim whitespace from setting keys
+                    assert key not in frozen_keys, f"not possible to set '{key}' because this setting is frozen with --freeze-specific-settings"
+
+                # Check shorthand option which disables editing options in "saving-paths"
                 if cmd_opts.hide_ui_dir_config and key in self.restricted_opts:
-                    raise RuntimeError(f"not possible to set {key} because it is restricted")
+                    raise RuntimeError(f"not possible to set '{key}' because it is restricted with --hide_ui_dir_config")

                 self.data[key] = value
                 return
@@ -158,7 +183,7 @@ class Options:
         assert not cmd_opts.freeze_settings, "saving settings is disabled"

         with open(filename, "w", encoding="utf8") as file:
-            json.dump(self.data, file, indent=4)
+            json.dump(self.data, file, indent=4, ensure_ascii=False)

     def same_type(self, x, y):
         if x is None or y is None:
@@ -170,9 +195,13 @@ class Options:
         return type_x == type_y

     def load(self, filename):
-        with open(filename, "r", encoding="utf8") as file:
-            self.data = json.load(file)
+        try:
+            with open(filename, "r", encoding="utf8") as file:
+                self.data = json.load(file)
+        except Exception:
+            errors.report(f'\nCould not load settings\nThe config file "{filename}" is likely corrupted\nIt has been moved to the "tmp/config.json"\nReverting config to default\n\n''', exc_info=True)
+            os.replace(filename, os.path.join(script_path, "tmp", "config.json"))
+            self.data = {}

         # 1.6.0 VAE defaults
         if self.data.get('sd_vae_as_default') is not None and self.data.get('sd_vae_overrides_per_model_preferences') is None:
             self.data['sd_vae_overrides_per_model_preferences'] = not self.data.get('sd_vae_as_default')
@@ -206,23 +235,59 @@ class Options:
         d = {k: self.data.get(k, v.default) for k, v in self.data_labels.items()}
         d["_comments_before"] = {k: v.comment_before for k, v in self.data_labels.items() if v.comment_before is not None}
         d["_comments_after"] = {k: v.comment_after for k, v in self.data_labels.items() if v.comment_after is not None}

+        item_categories = {}
+        for item in self.data_labels.values():
+            category = categories.mapping.get(item.category_id)
+            category = "Uncategorized" if category is None else category.label
+            if category not in item_categories:
+                item_categories[category] = item.section[1]
+
+        # _categories is a list of pairs: [section, category]. Each section (a setting page) will get a special heading above it with the category as text.
+        d["_categories"] = [[v, k] for k, v in item_categories.items()] + [["Defaults", "Other"]]
+
         return json.dumps(d)

     def add_option(self, key, info):
         self.data_labels[key] = info
-        if key not in self.data:
+        if key not in self.data and not info.do_not_save:
             self.data[key] = info.default

     def reorder(self):
-        """reorder settings so that all items related to section always go together"""
+        """Reorder settings so that:
+            - all items related to section always go together
+            - all sections belonging to a category go together
+            - sections inside a category are ordered alphabetically
+            - categories are ordered by creation order
+
+        Category is a superset of sections: for category "postprocessing" there could be multiple sections: "face restoration", "upscaling".
+
+        This function also changes items' category_id so that all items belonging to a section have the same category_id.
+        """
+
+        category_ids = {}
+        section_categories = {}
+
         settings_items = self.data_labels.items()
         for _, item in settings_items:
-            if item.section not in section_ids:
-                section_ids[item.section] = len(section_ids)
+            if item.section not in section_categories:
+                section_categories[item.section] = item.category_id

-        self.data_labels = dict(sorted(settings_items, key=lambda x: section_ids[x[1].section]))
+        for _, item in settings_items:
+            item.category_id = section_categories.get(item.section)
+
+        for category_id in categories.mapping:
+            if category_id not in category_ids:
+                category_ids[category_id] = len(category_ids)
+
+        def sort_key(x):
+            item: OptionInfo = x[1]
+            category_order = category_ids.get(item.category_id, len(category_ids))
+            section_order = item.section[1]
+
+            return category_order, section_order
+
+        self.data_labels = dict(sorted(settings_items, key=sort_key))

     def cast_value(self, key, value):
         """casts an arbitrary to the same type as this setting's value with key
@@ -245,3 +310,22 @@ class Options:
             value = expected_type(value)

         return value
+
+
+@dataclass
+class OptionsCategory:
+    id: str
+    label: str
+
+
+class OptionsCategories:
+    def __init__(self):
+        self.mapping = {}
+
+    def register_category(self, category_id, label):
+        if category_id in self.mapping:
+            return category_id
+
+        self.mapping[category_id] = OptionsCategory(category_id, label)
+
+
+categories = OptionsCategories()
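How the category machinery is meant to be driven, as a sketch: register_category and the 3-tuple form of options_section come straight from the diff, while the category id, section names, and option key here are illustrative. Assumes the webui modules package is importable:

    from modules import options

    options.categories.register_category("postprocessing", "Postprocessing")

    templates = options.options_section(
        ("upscaling", "Upscaling", "postprocessing"),  # (section id, section label, category id)
        {
            "upscaler_example": options.OptionInfo("Lanczos", "Default upscaler for the example"),
        },
    )

    # every OptionInfo in the section now carries the category
    assert templates["upscaler_example"].category_id == "postprocessing"
    assert templates["upscaler_example"].section == ("upscaling", "Upscaling")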
modules/paths.py  -1
@@ -38,7 +38,6 @@ mute_sdxl_imports()
 path_dirs = [
     (sd_path, 'ldm', 'Stable Diffusion', []),
     (os.path.join(sd_path, '../generative-models'), 'sgm', 'Stable Diffusion XL', ["sgm"]),
-    (os.path.join(sd_path, '../CodeFormer'), 'inference_codeformer.py', 'CodeFormer', []),
     (os.path.join(sd_path, '../BLIP'), 'models/blip.py', 'BLIP', []),
     (os.path.join(sd_path, '../k-diffusion'), 'k_diffusion/sampling.py', 'k_diffusion', ["atstart"]),
 ]
modules/paths_internal.py  +1
@@ -28,5 +28,6 @@ models_path = os.path.join(data_path, "models")
 extensions_dir = os.path.join(data_path, "extensions")
 extensions_builtin_dir = os.path.join(script_path, "extensions-builtin")
 config_states_dir = os.path.join(script_path, "config_states")
+default_output_dir = os.path.join(data_path, "output")

 roboto_ttf_file = os.path.join(modules_path, 'Roboto-Regular.ttf')
modules/postprocessing.py  +75 -22
@@ -2,7 +2,7 @@ import os

 from PIL import Image

-from modules import shared, images, devices, scripts, scripts_postprocessing, ui_common, generation_parameters_copypaste
+from modules import shared, images, devices, scripts, scripts_postprocessing, ui_common, infotext_utils
 from modules.shared import opts
@@ -29,11 +29,7 @@ def run_postprocessing(extras_mode, image, image_folder, input_dir, output_dir,
         image_list = shared.listfiles(input_dir)
         for filename in image_list:
-            try:
-                image = Image.open(filename)
-            except Exception:
-                continue
-            yield image, filename
+            yield filename, filename
     else:
         assert image, 'image not selected'
         yield image, None
@@ -45,43 +41,98 @@ def run_postprocessing(extras_mode, image, image_folder, input_dir, output_dir,
     infotext = ''

-    for image_data, name in get_images(extras_mode, image, image_folder, input_dir):
+    data_to_process = list(get_images(extras_mode, image, image_folder, input_dir))
+    shared.state.job_count = len(data_to_process)
+
+    for image_placeholder, name in data_to_process:
         image_data: Image.Image
+
+        shared.state.nextjob()
         shared.state.textinfo = name
+        shared.state.skipped = False
+
+        if shared.state.interrupted:
+            break
+
+        if isinstance(image_placeholder, str):
+            try:
+                image_data = Image.open(image_placeholder)
+            except Exception:
+                continue
+        else:
+            image_data = image_placeholder

         parameters, existing_pnginfo = images.read_info_from_image(image_data)
         if parameters:
             existing_pnginfo["parameters"] = parameters

-        pp = scripts_postprocessing.PostprocessedImage(image_data.convert("RGB"))
+        initial_pp = scripts_postprocessing.PostprocessedImage(image_data.convert("RGB"))

-        scripts.scripts_postproc.run(pp, args)
+        scripts.scripts_postproc.run(initial_pp, args)

-        if opts.use_original_name_batch and name is not None:
-            basename = os.path.splitext(os.path.basename(name))[0]
-        else:
-            basename = ''
+        if shared.state.skipped:
+            continue

-        infotext = ", ".join([k if k == v else f'{k}: {generation_parameters_copypaste.quote(v)}' for k, v in pp.info.items() if v is not None])
+        used_suffixes = {}
+        for pp in [initial_pp, *initial_pp.extra_images]:
+            suffix = pp.get_suffix(used_suffixes)

-        if opts.enable_pnginfo:
-            pp.image.info = existing_pnginfo
-            pp.image.info["postprocessing"] = infotext
+            if opts.use_original_name_batch and name is not None:
+                basename = os.path.splitext(os.path.basename(name))[0]
+                forced_filename = basename + suffix
+            else:
+                basename = ''
+                forced_filename = None

-        if save_output:
-            images.save_image(pp.image, path=outpath, basename=basename, seed=None, prompt=None, extension=opts.samples_format, info=infotext, short_filename=True, no_prompt=True, grid=False, pnginfo_section_name="extras", existing_info=existing_pnginfo, forced_filename=None)
+            infotext = infotext_utils.build_infotext(pp.info)

-        if extras_mode != 2 or show_extras_results:
-            outputs.append(pp.image)
+            if opts.enable_pnginfo:
+                pp.image.info = existing_pnginfo
+                pp.image.info["postprocessing"] = infotext
+
+            shared.state.assign_current_image(pp.image)
+
+            if save_output:
+                fullfn, _ = images.save_image(pp.image, path=outpath, basename=basename, extension=opts.samples_format, info=infotext, short_filename=True, no_prompt=True, grid=False, pnginfo_section_name="extras", existing_info=existing_pnginfo, forced_filename=forced_filename, suffix=suffix)
+
+                if pp.caption:
+                    caption_filename = os.path.splitext(fullfn)[0] + ".txt"
+                    existing_caption = ""
+                    try:
+                        with open(caption_filename, encoding="utf8") as file:
+                            existing_caption = file.read().strip()
+                    except FileNotFoundError:
+                        pass
+
+                    action = shared.opts.postprocessing_existing_caption_action
+                    if action == 'Prepend' and existing_caption:
+                        caption = f"{existing_caption} {pp.caption}"
+                    elif action == 'Append' and existing_caption:
+                        caption = f"{pp.caption} {existing_caption}"
+                    elif action == 'Keep' and existing_caption:
+                        caption = existing_caption
+                    else:
+                        caption = pp.caption
+
+                    caption = caption.strip()
+                    if caption:
+                        with open(caption_filename, "w", encoding="utf8") as file:
+                            file.write(caption)
+
+            if extras_mode != 2 or show_extras_results:
+                outputs.append(pp.image)

         image_data.close()

     devices.torch_gc()
+    shared.state.end()
     return outputs, ui_common.plaintext_to_html(infotext), ''

+
+def run_postprocessing_webui(id_task, *args, **kwargs):
+    return run_postprocessing(*args, **kwargs)
+

 def run_extras(extras_mode, resize_mode, image, image_folder, input_dir, output_dir, show_extras_results, gfpgan_visibility, codeformer_visibility, codeformer_weight, upscaling_resize, upscaling_resize_w, upscaling_resize_h, upscaling_crop, extras_upscaler_1, extras_upscaler_2, extras_upscaler_2_visibility, upscale_first: bool, save_output: bool = True):
     """old handler for API"""
@@ -97,9 +148,11 @@ def run_extras(extras_mode, resize_mode, image, image_folder, input_dir, output_
         "upscaler_2_visibility": extras_upscaler_2_visibility,
     },
     "GFPGAN": {
+        "enable": True,
         "gfpgan_visibility": gfpgan_visibility,
     },
     "CodeFormer": {
+        "enable": True,
         "codeformer_visibility": codeformer_visibility,
         "codeformer_weight": codeformer_weight,
     },
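The new caption handling reduces to a small pure function over shared.opts.postprocessing_existing_caption_action. Extracted here as a sketch with all four actions exercised (function name and sample strings are illustrative):

    def merge_caption(new_caption, existing_caption, action):
        # mirrors the Prepend/Append/Keep/overwrite branches in the diff
        if action == 'Prepend' and existing_caption:
            caption = f"{existing_caption} {new_caption}"
        elif action == 'Append' and existing_caption:
            caption = f"{new_caption} {existing_caption}"
        elif action == 'Keep' and existing_caption:
            caption = existing_caption
        else:
            caption = new_caption
        return caption.strip()

    assert merge_caption("a cat", "fluffy", 'Prepend') == "fluffy a cat"
    assert merge_caption("a cat", "fluffy", 'Append') == "a cat fluffy"
    assert merge_caption("a cat", "fluffy", 'Keep') == "fluffy"
    assert merge_caption("a cat", "", 'Keep') == "a cat"  # nothing to keep, use the new one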
modules/processing.py  +228 -68
@@ -16,7 +16,7 @@ from skimage import exposure
from typing import Any from typing import Any
import modules.sd_hijack import modules.sd_hijack
from modules import devices, prompt_parser, masking, sd_samplers, lowvram, generation_parameters_copypaste, extra_networks, sd_vae_approx, scripts, sd_samplers_common, sd_unet, errors, rng from modules import devices, prompt_parser, masking, sd_samplers, lowvram, infotext_utils, extra_networks, sd_vae_approx, scripts, sd_samplers_common, sd_unet, errors, rng
from modules.rng import slerp # noqa: F401 from modules.rng import slerp # noqa: F401
from modules.sd_hijack import model_hijack from modules.sd_hijack import model_hijack
from modules.sd_samplers_common import images_tensor_to_samples, decode_first_stage, approximation_indexes from modules.sd_samplers_common import images_tensor_to_samples, decode_first_stage, approximation_indexes
@@ -62,18 +62,22 @@ def apply_color_correction(correction, original_image):
    return image.convert('RGB')


-def apply_overlay(image, paste_loc, index, overlays):
-    if overlays is None or index >= len(overlays):
-        return image
-
-    overlay = overlays[index]
def uncrop(image, dest_size, paste_loc):
    x, y, w, h = paste_loc
    base_image = Image.new('RGBA', dest_size)
    image = images.resize_image(1, image, w, h)
    base_image.paste(image, (x, y))
    image = base_image

    return image


def apply_overlay(image, paste_loc, overlay):
    if overlay is None:
        return image

    if paste_loc is not None:
-        x, y, w, h = paste_loc
-        base_image = Image.new('RGBA', (overlay.width, overlay.height))
-        image = images.resize_image(1, image, w, h)
-        base_image.paste(image, (x, y))
-        image = base_image
        image = uncrop(image, (overlay.width, overlay.height), paste_loc)

    image = image.convert('RGBA')
    image.alpha_composite(overlay)
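The refactor above extracts the paste-back logic into a reusable `uncrop()` and changes `apply_overlay()` to take the overlay image directly instead of an `(index, overlays)` pair. A standalone sketch of what `uncrop()` does, with plain Pillow `resize` standing in for `images.resize_image`:

```python
# Standalone sketch of the uncrop idea (requires Pillow); resize() stands in
# for images.resize_image(1, ...) from the webui codebase.
from PIL import Image

def uncrop_sketch(image, dest_size, paste_loc):
    x, y, w, h = paste_loc
    base = Image.new('RGBA', dest_size)        # transparent full-size canvas
    base.paste(image.resize((w, h)), (x, y))   # scale the crop, paste it back
    return base

crop = Image.new('RGB', (64, 64), 'red')       # hypothetical inpainted crop
full = uncrop_sketch(crop, (512, 512), (128, 128, 256, 256))
print(full.size)  # (512, 512); the red region now sits at (128, 128)
```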
@@ -81,9 +85,12 @@ def apply_overlay(image, paste_loc, index, overlays):
    return image


-def create_binary_mask(image):
def create_binary_mask(image, round=True):
    if image.mode == 'RGBA' and image.getextrema()[-1] != (255, 255):
-        image = image.split()[-1].convert("L").point(lambda x: 255 if x > 128 else 0)
        if round:
            image = image.split()[-1].convert("L").point(lambda x: 255 if x > 128 else 0)
        else:
            image = image.split()[-1].convert("L")
    else:
        image = image.convert('L')
    return image
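With `round=True` the alpha channel is thresholded at 128 into a hard 0/255 mask; `round=False` (what soft inpainting wants) keeps the alpha ramp as grayscale. A quick illustration of the two behaviors on a synthetic RGBA image:

```python
# Illustration of the round=True vs round=False behavior (requires Pillow).
from PIL import Image

img = Image.new('RGBA', (2, 1))
img.putpixel((0, 0), (255, 0, 0, 100))  # below the 128 threshold
img.putpixel((1, 0), (255, 0, 0, 200))  # above the 128 threshold

alpha = img.split()[-1].convert("L")                 # round=False keeps the ramp
hard = alpha.point(lambda x: 255 if x > 128 else 0)  # round=True snaps to 0/255
print(list(alpha.getdata()))  # [100, 200]
print(list(hard.getdata()))   # [0, 255]
```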
@@ -106,6 +113,21 @@ def txt2img_image_conditioning(sd_model, x, width, height):
        return x.new_zeros(x.shape[0], 2*sd_model.noise_augmentor.time_embed.dim, dtype=x.dtype, device=x.device)

    else:
        sd = sd_model.model.state_dict()
        diffusion_model_input = sd.get('diffusion_model.input_blocks.0.0.weight', None)
        if diffusion_model_input is not None:
            if diffusion_model_input.shape[1] == 9:
                # The "masked-image" in this case will just be all 0.5 since the entire image is masked.
                image_conditioning = torch.ones(x.shape[0], 3, height, width, device=x.device) * 0.5
                image_conditioning = images_tensor_to_samples(image_conditioning,
                                                              approximation_indexes.get(opts.sd_vae_encode_method))

                # Add the fake full 1s mask to the first dimension.
                image_conditioning = torch.nn.functional.pad(image_conditioning, (0, 0, 0, 0, 1, 0), value=1.0)
                image_conditioning = image_conditioning.to(x.dtype)

                return image_conditioning

        # Dummy zero conditioning if we're not using inpainting or unclip models.
        # Still takes up a bit of memory, but no encoder call.
        # Pretty sure we can just make this a 1x1 image since its not going to be used besides its batch size.
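This branch detects an inpainting UNet by inspecting its first convolution: `diffusion_model.input_blocks.0.0` takes 9 input channels (4 latent + 1 mask + 4 masked-image latent) instead of 4, so plain txt2img can synthesize an all-masked conditioning for such models. A shape-only sketch of the padding step:

```python
import torch

# Hypothetical encoded "masked image" latent: batch of 2 at 64x64 latent size.
image_conditioning = torch.zeros(2, 4, 64, 64)

# pad=(0, 0, 0, 0, 1, 0) leaves W and H alone and prepends one channel plane,
# filled with 1.0 -- the "everything is masked" mask plane.
padded = torch.nn.functional.pad(image_conditioning, (0, 0, 0, 0, 1, 0), value=1.0)
print(padded.shape)           # torch.Size([2, 5, 64, 64])
print(padded[:, 0].unique())  # tensor([1.])
```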
@@ -157,6 +179,7 @@ class StableDiffusionProcessing:
    token_merging_ratio = 0
    token_merging_ratio_hr = 0
    disable_extra_networks: bool = False
    firstpass_image: Image = None

    scripts_value: scripts.ScriptRunner = field(default=None, init=False)
    script_args_value: list = field(default=None, init=False)
@@ -296,7 +319,7 @@ class StableDiffusionProcessing:
        return conditioning

    def edit_image_conditioning(self, source_image):
-        conditioning_image = images_tensor_to_samples(source_image*0.5+0.5, approximation_indexes.get(opts.sd_vae_encode_method))
        conditioning_image = shared.sd_model.encode_first_stage(source_image).mode()

        return conditioning_image
@@ -308,7 +331,7 @@ class StableDiffusionProcessing:
            c_adm = torch.cat((c_adm, noise_level_emb), 1)
        return c_adm

-    def inpainting_image_conditioning(self, source_image, latent_image, image_mask=None):
    def inpainting_image_conditioning(self, source_image, latent_image, image_mask=None, round_image_mask=True):
        self.is_using_inpainting_conditioning = True

        # Handle the different mask inputs
@@ -320,8 +343,10 @@ class StableDiffusionProcessing:
            conditioning_mask = conditioning_mask.astype(np.float32) / 255.0
            conditioning_mask = torch.from_numpy(conditioning_mask[None, None])

-            # Inpainting model uses a discretized mask as input, so we round to either 1.0 or 0.0
-            conditioning_mask = torch.round(conditioning_mask)
            if round_image_mask:
                # Caller is requesting a discretized mask as input, so we round to either 1.0 or 0.0
                conditioning_mask = torch.round(conditioning_mask)

        else:
            conditioning_mask = source_image.new_ones(1, 1, *source_image.shape[-2:])
@@ -345,7 +370,7 @@ class StableDiffusionProcessing:

        return image_conditioning

-    def img2img_image_conditioning(self, source_image, latent_image, image_mask=None):
    def img2img_image_conditioning(self, source_image, latent_image, image_mask=None, round_image_mask=True):
        source_image = devices.cond_cast_float(source_image)

        # HACK: Using introspection as the Depth2Image model doesn't appear to uniquely
@@ -357,11 +382,17 @@ class StableDiffusionProcessing:
            return self.edit_image_conditioning(source_image)

        if self.sampler.conditioning_key in {'hybrid', 'concat'}:
-            return self.inpainting_image_conditioning(source_image, latent_image, image_mask=image_mask)
            return self.inpainting_image_conditioning(source_image, latent_image, image_mask=image_mask, round_image_mask=round_image_mask)

        if self.sampler.conditioning_key == "crossattn-adm":
            return self.unclip_image_conditioning(source_image)

        sd = self.sampler.model_wrap.inner_model.model.state_dict()
        diffusion_model_input = sd.get('diffusion_model.input_blocks.0.0.weight', None)
        if diffusion_model_input is not None:
            if diffusion_model_input.shape[1] == 9:
                return self.inpainting_image_conditioning(source_image, latent_image, image_mask=image_mask)

        # Dummy zero conditioning if we're not using inpainting or depth model.
        return latent_image.new_zeros(latent_image.shape[0], 5, 1, 1)
@@ -422,6 +453,8 @@ class StableDiffusionProcessing:
            opts.sdxl_crop_top,
            self.width,
            self.height,
            opts.fp8_storage,
            opts.cache_fp16_weight,
        )

    def get_conds_with_caching(self, function, required_prompts, steps, caches, extra_network_data, hires_steps=None):
@@ -596,20 +629,33 @@ def decode_latent_batch(model, batch, target_device=None, check_for_nans=False):
sample = decode_first_stage(model, batch[i:i + 1])[0] sample = decode_first_stage(model, batch[i:i + 1])[0]
if check_for_nans: if check_for_nans:
try: try:
devices.test_for_nans(sample, "vae") devices.test_for_nans(sample, "vae")
except devices.NansException as e: except devices.NansException as e:
if devices.dtype_vae == torch.float32 or not shared.opts.auto_vae_precision: if shared.opts.auto_vae_precision_bfloat16:
autofix_dtype = torch.bfloat16
autofix_dtype_text = "bfloat16"
autofix_dtype_setting = "Automatically convert VAE to bfloat16"
autofix_dtype_comment = ""
elif shared.opts.auto_vae_precision:
autofix_dtype = torch.float32
autofix_dtype_text = "32-bit float"
autofix_dtype_setting = "Automatically revert VAE to 32-bit floats"
autofix_dtype_comment = "\nTo always start with 32-bit VAE, use --no-half-vae commandline flag."
else:
raise e
if devices.dtype_vae == autofix_dtype:
raise e raise e
errors.print_error_explanation( errors.print_error_explanation(
"A tensor with all NaNs was produced in VAE.\n" "A tensor with all NaNs was produced in VAE.\n"
"Web UI will now convert VAE into 32-bit float and retry.\n" f"Web UI will now convert VAE into {autofix_dtype_text} and retry.\n"
"To disable this behavior, disable the 'Automatically revert VAE to 32-bit floats' setting.\n" f"To disable this behavior, disable the '{autofix_dtype_setting}' setting.{autofix_dtype_comment}"
"To always start with 32-bit VAE, use --no-half-vae commandline flag."
) )
devices.dtype_vae = torch.float32 devices.dtype_vae = autofix_dtype
model.first_stage_model.to(devices.dtype_vae) model.first_stage_model.to(devices.dtype_vae)
batch = batch.to(devices.dtype_vae) batch = batch.to(devices.dtype_vae)
@@ -679,12 +725,14 @@ def create_infotext(p, all_prompts, all_seeds, all_subseeds, comments=None, iter
        "Size": f"{p.width}x{p.height}",
        "Model hash": p.sd_model_hash if opts.add_model_hash_to_info else None,
        "Model": p.sd_model_name if opts.add_model_name_to_info else None,
-        "VAE hash": p.sd_vae_hash if opts.add_model_hash_to_info else None,
-        "VAE": p.sd_vae_name if opts.add_model_name_to_info else None,
        "FP8 weight": opts.fp8_storage if devices.fp8 else None,
        "Cache FP16 weight for LoRA": opts.cache_fp16_weight if devices.fp8 else None,
        "VAE hash": p.sd_vae_hash if opts.add_vae_hash_to_info else None,
        "VAE": p.sd_vae_name if opts.add_vae_name_to_info else None,
        "Variation seed": (None if p.subseed_strength == 0 else (p.all_subseeds[0] if use_main_prompt else all_subseeds[index])),
        "Variation seed strength": (None if p.subseed_strength == 0 else p.subseed_strength),
        "Seed resize from": (None if p.seed_resize_from_w <= 0 or p.seed_resize_from_h <= 0 else f"{p.seed_resize_from_w}x{p.seed_resize_from_h}"),
-        "Denoising strength": getattr(p, 'denoising_strength', None),
        "Denoising strength": p.extra_generation_params.get("Denoising strength"),
        "Conditional mask weight": getattr(p, "inpainting_mask_weight", shared.opts.inpainting_mask_weight) if p.is_using_inpainting_conditioning else None,
        "Clip skip": None if clip_skip <= 1 else clip_skip,
        "ENSD": opts.eta_noise_seed_delta if uses_ensd else None,
@@ -699,7 +747,7 @@ def create_infotext(p, all_prompts, all_seeds, all_subseeds, comments=None, iter
        "User": p.user if opts.add_user_name_to_info else None,
    }

-    generation_params_text = ", ".join([k if k == v else f'{k}: {generation_parameters_copypaste.quote(v)}' for k, v in generation_params.items() if v is not None])
    generation_params_text = infotext_utils.build_infotext(generation_params)

    prompt_text = p.main_prompt if use_main_prompt else all_prompts[index]
    negative_prompt_text = f"\nNegative prompt: {p.main_negative_prompt if use_main_prompt else all_negative_prompts[index]}" if all_negative_prompts[index] else ""
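For reference, the key/value serialization that `create_infotext` produces, sketched from the replaced join expression (quoting of values that contain commas, and the `k if k == v` shorthand, are omitted for brevity):

```python
# Simplified re-creation of the infotext serialization shown above.
generation_params = {
    "Steps": 20,
    "Sampler": "DPM++ 2M Karras",
    "CFG scale": 7,
    "Seed": 12345,
    "Size": "512x512",
    "Model": None,  # None entries are dropped
}

text = ", ".join(f"{k}: {v}" for k, v in generation_params.items() if v is not None)
print(text)
# Steps: 20, Sampler: DPM++ 2M Karras, CFG scale: 7, Seed: 12345, Size: 512x512
```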
@@ -711,7 +759,7 @@ def process_images(p: StableDiffusionProcessing) -> Processed:
    if p.scripts is not None:
        p.scripts.before_process(p)

-    stored_opts = {k: opts.data[k] for k in p.override_settings.keys() if k in opts.data}
    stored_opts = {k: opts.data[k] if k in opts.data else opts.get_default(k) for k in p.override_settings.keys() if k in opts.data}

    try:
        # if no checkpoint override or the override checkpoint can't be found, remove override entry and load opts checkpoint
@@ -799,7 +847,6 @@ def process_images_inner(p: StableDiffusionProcessing) -> Processed:
    infotexts = []
    output_images = []
-
    with torch.no_grad(), p.sd_model.ema_scope():
        with devices.autocast():
            p.init(p.all_prompts, p.all_seeds, p.all_subseeds)

@@ -819,7 +866,7 @@ def process_images_inner(p: StableDiffusionProcessing) -> Processed:
            if state.skipped:
                state.skipped = False

-            if state.interrupted:
            if state.interrupted or state.stopping_generation:
                break

            sd_models.reload_model_weights()  # model can be changed for example by refiner
@@ -865,15 +912,47 @@ def process_images_inner(p: StableDiffusionProcessing) -> Processed:
            if p.n_iter > 1:
                shared.state.job = f"Batch {n+1} out of {p.n_iter}"

            def rescale_zero_terminal_snr_abar(alphas_cumprod):
                alphas_bar_sqrt = alphas_cumprod.sqrt()

                # Store old values.
                alphas_bar_sqrt_0 = alphas_bar_sqrt[0].clone()
                alphas_bar_sqrt_T = alphas_bar_sqrt[-1].clone()

                # Shift so the last timestep is zero.
                alphas_bar_sqrt -= (alphas_bar_sqrt_T)

                # Scale so the first timestep is back to the old value.
                alphas_bar_sqrt *= alphas_bar_sqrt_0 / (alphas_bar_sqrt_0 - alphas_bar_sqrt_T)

                # Convert alphas_bar_sqrt to betas
                alphas_bar = alphas_bar_sqrt**2  # Revert sqrt
                alphas_bar[-1] = 4.8973451890853435e-08
                return alphas_bar

            if hasattr(p.sd_model, 'alphas_cumprod') and hasattr(p.sd_model, 'alphas_cumprod_original'):
                p.sd_model.alphas_cumprod = p.sd_model.alphas_cumprod_original.to(shared.device)

                if opts.use_downcasted_alpha_bar:
                    p.extra_generation_params['Downcast alphas_cumprod'] = opts.use_downcasted_alpha_bar
                    p.sd_model.alphas_cumprod = p.sd_model.alphas_cumprod.half().to(shared.device)

                if opts.sd_noise_schedule == "Zero Terminal SNR":
                    p.extra_generation_params['Noise Schedule'] = opts.sd_noise_schedule
                    p.sd_model.alphas_cumprod = rescale_zero_terminal_snr_abar(p.sd_model.alphas_cumprod).to(shared.device)

            with devices.without_autocast() if devices.unet_needs_upcast else devices.autocast():
                samples_ddim = p.sample(conditioning=p.c, unconditional_conditioning=p.uc, seeds=p.seeds, subseeds=p.subseeds, subseed_strength=p.subseed_strength, prompts=p.prompts)

            if p.scripts is not None:
                ps = scripts.PostSampleArgs(samples_ddim)
                p.scripts.post_sample(p, ps)
                samples_ddim = ps.samples

            if getattr(samples_ddim, 'already_decoded', False):
                x_samples_ddim = samples_ddim
            else:
                if opts.sd_vae_decode_method != 'Full':
                    p.extra_generation_params['VAE Decoder'] = opts.sd_vae_decode_method
                x_samples_ddim = decode_latent_batch(p.sd_model, samples_ddim, target_device=devices.cpu, check_for_nans=True)

            x_samples_ddim = torch.stack(x_samples_ddim).float()
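`rescale_zero_terminal_snr_abar` shifts and rescales the ᾱ (alphas_cumprod) schedule so the final ᾱ becomes effectively zero while the first is preserved (the zero-terminal-SNR rescaling). A toy check of those two properties:

```python
import torch

def rescale_zero_terminal_snr_abar(alphas_cumprod):
    # Copy of the helper added above, reproduced for a standalone check.
    alphas_bar_sqrt = alphas_cumprod.sqrt()
    alphas_bar_sqrt_0 = alphas_bar_sqrt[0].clone()
    alphas_bar_sqrt_T = alphas_bar_sqrt[-1].clone()
    alphas_bar_sqrt -= alphas_bar_sqrt_T
    alphas_bar_sqrt *= alphas_bar_sqrt_0 / (alphas_bar_sqrt_0 - alphas_bar_sqrt_T)
    alphas_bar = alphas_bar_sqrt ** 2
    alphas_bar[-1] = 4.8973451890853435e-08
    return alphas_bar

abar = torch.linspace(0.9991, 0.0047, 1000)  # a toy, roughly SD-like schedule
rescaled = rescale_zero_terminal_snr_abar(abar.clone())
print(float(abar[0]), float(rescaled[0]))    # first value is preserved
print(float(abar[-1]), float(rescaled[-1]))  # last value becomes ~0 (4.9e-08)
```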
@@ -886,6 +965,8 @@ def process_images_inner(p: StableDiffusionProcessing) -> Processed:

            devices.torch_gc()

            state.nextjob()

            if p.scripts is not None:
                p.scripts.postprocess_batch(p, x_samples_ddim, batch_number=n)
@@ -922,13 +1003,36 @@ def process_images_inner(p: StableDiffusionProcessing) -> Processed:
                    pp = scripts.PostprocessImageArgs(image)
                    p.scripts.postprocess_image(p, pp)
                    image = pp.image

                mask_for_overlay = getattr(p, "mask_for_overlay", None)
                overlay_image = p.overlay_images[i] if getattr(p, "overlay_images", None) is not None and i < len(p.overlay_images) else None

                if p.scripts is not None:
                    ppmo = scripts.PostProcessMaskOverlayArgs(i, mask_for_overlay, overlay_image)
                    p.scripts.postprocess_maskoverlay(p, ppmo)
                    mask_for_overlay, overlay_image = ppmo.mask_for_overlay, ppmo.overlay_image

                if p.color_corrections is not None and i < len(p.color_corrections):
                    if save_samples and opts.save_images_before_color_correction:
-                        image_without_cc = apply_overlay(image, p.paste_to, i, p.overlay_images)
                        image_without_cc = apply_overlay(image, p.paste_to, overlay_image)
                        images.save_image(image_without_cc, p.outpath_samples, "", p.seeds[i], p.prompts[i], opts.samples_format, info=infotext(i), p=p, suffix="-before-color-correction")
                    image = apply_color_correction(p.color_corrections[i], image)

-                image = apply_overlay(image, p.paste_to, i, p.overlay_images)
                # If the intention is to show the output from the model
                # that is being composited over the original image,
                # we need to keep the original image around
                # and use it in the composite step.
                original_denoised_image = image.copy()

                if p.paste_to is not None:
                    original_denoised_image = uncrop(original_denoised_image, (overlay_image.width, overlay_image.height), p.paste_to)

                image = apply_overlay(image, p.paste_to, overlay_image)

                if p.scripts is not None:
                    pp = scripts.PostprocessImageArgs(image)
                    p.scripts.postprocess_image_after_composite(p, pp)
                    image = pp.image

                if save_samples:
                    images.save_image(image, p.outpath_samples, "", p.seeds[i], p.prompts[i], opts.samples_format, info=infotext(i), p=p)
@@ -938,27 +1042,28 @@ def process_images_inner(p: StableDiffusionProcessing) -> Processed:
                if opts.enable_pnginfo:
                    image.info["parameters"] = text
                output_images.append(image)

-                if save_samples and hasattr(p, 'mask_for_overlay') and p.mask_for_overlay and any([opts.save_mask, opts.save_mask_composite, opts.return_mask, opts.return_mask_composite]):
-                    image_mask = p.mask_for_overlay.convert('RGB')
-                    image_mask_composite = Image.composite(image.convert('RGBA').convert('RGBa'), Image.new('RGBa', image.size), images.resize_image(2, p.mask_for_overlay, image.width, image.height).convert('L')).convert('RGBA')
-
-                    if opts.save_mask:
-                        images.save_image(image_mask, p.outpath_samples, "", p.seeds[i], p.prompts[i], opts.samples_format, info=infotext(i), p=p, suffix="-mask")
-
-                    if opts.save_mask_composite:
-                        images.save_image(image_mask_composite, p.outpath_samples, "", p.seeds[i], p.prompts[i], opts.samples_format, info=infotext(i), p=p, suffix="-mask-composite")
-
-                    if opts.return_mask:
-                        output_images.append(image_mask)
-
-                    if opts.return_mask_composite:
-                        output_images.append(image_mask_composite)
                if mask_for_overlay is not None:
                    if opts.return_mask or opts.save_mask:
                        image_mask = mask_for_overlay.convert('RGB')
                        if save_samples and opts.save_mask:
                            images.save_image(image_mask, p.outpath_samples, "", p.seeds[i], p.prompts[i], opts.samples_format, info=infotext(i), p=p, suffix="-mask")
                        if opts.return_mask:
                            output_images.append(image_mask)

                    if opts.return_mask_composite or opts.save_mask_composite:
                        image_mask_composite = Image.composite(original_denoised_image.convert('RGBA').convert('RGBa'), Image.new('RGBa', image.size), images.resize_image(2, mask_for_overlay, image.width, image.height).convert('L')).convert('RGBA')
                        if save_samples and opts.save_mask_composite:
                            images.save_image(image_mask_composite, p.outpath_samples, "", p.seeds[i], p.prompts[i], opts.samples_format, info=infotext(i), p=p, suffix="-mask-composite")
                        if opts.return_mask_composite:
                            output_images.append(image_mask_composite)

            del x_samples_ddim

            devices.torch_gc()

-        state.nextjob()
        if not infotexts:
            infotexts.append(Processed(p, []).infotext(p, 0))

    p.color_corrections = None
@@ -1025,6 +1130,7 @@ class StableDiffusionProcessingTxt2Img(StableDiffusionProcessing):
    hr_sampler_name: str = None
    hr_prompt: str = ''
    hr_negative_prompt: str = ''
    force_task_id: str = None

    cached_hr_uc = [None, None]
    cached_hr_c = [None, None]
@@ -1097,7 +1203,9 @@ class StableDiffusionProcessingTxt2Img(StableDiffusionProcessing):
    def init(self, all_prompts, all_seeds, all_subseeds):
        if self.enable_hr:
-            if self.hr_checkpoint_name:
            self.extra_generation_params["Denoising strength"] = self.denoising_strength

            if self.hr_checkpoint_name and self.hr_checkpoint_name != 'Use same checkpoint':
                self.hr_checkpoint_info = sd_models.get_closet_checkpoint_match(self.hr_checkpoint_name)

                if self.hr_checkpoint_info is None:
@@ -1124,8 +1232,11 @@ class StableDiffusionProcessingTxt2Img(StableDiffusionProcessing):
            if not state.processing_has_refined_job_count:
                if state.job_count == -1:
                    state.job_count = self.n_iter

-                shared.total_tqdm.updateTotal((self.steps + (self.hr_second_pass_steps or self.steps)) * state.job_count)
                if getattr(self, 'txt2img_upscale', False):
                    total_steps = (self.hr_second_pass_steps or self.steps) * state.job_count
                else:
                    total_steps = (self.steps + (self.hr_second_pass_steps or self.steps)) * state.job_count

                shared.total_tqdm.updateTotal(total_steps)
                state.job_count = state.job_count * 2
                state.processing_has_refined_job_count = True
@@ -1138,23 +1249,49 @@ class StableDiffusionProcessingTxt2Img(StableDiffusionProcessing):
    def sample(self, conditioning, unconditional_conditioning, seeds, subseeds, subseed_strength, prompts):
        self.sampler = sd_samplers.create_sampler(self.sampler_name, self.sd_model)

-        x = self.rng.next()
-        samples = self.sampler.sample(self, x, conditioning, unconditional_conditioning, image_conditioning=self.txt2img_image_conditioning(x))
-        del x
-
-        if not self.enable_hr:
-            return samples
-
-        if self.latent_scale_mode is None:
-            decoded_samples = torch.stack(decode_latent_batch(self.sd_model, samples, target_device=devices.cpu, check_for_nans=True)).to(dtype=torch.float32)
-        else:
-            decoded_samples = None
        if self.firstpass_image is not None and self.enable_hr:
            # here we don't need to generate image, we just take self.firstpass_image and prepare it for hires fix

            if self.latent_scale_mode is None:
                image = np.array(self.firstpass_image).astype(np.float32) / 255.0 * 2.0 - 1.0
                image = np.moveaxis(image, 2, 0)

                samples = None
                decoded_samples = torch.asarray(np.expand_dims(image, 0))

            else:
                image = np.array(self.firstpass_image).astype(np.float32) / 255.0
                image = np.moveaxis(image, 2, 0)
                image = torch.from_numpy(np.expand_dims(image, axis=0))
                image = image.to(shared.device, dtype=devices.dtype_vae)

                if opts.sd_vae_encode_method != 'Full':
                    self.extra_generation_params['VAE Encoder'] = opts.sd_vae_encode_method

                samples = images_tensor_to_samples(image, approximation_indexes.get(opts.sd_vae_encode_method), self.sd_model)
                decoded_samples = None
                devices.torch_gc()

        else:
            # here we generate an image normally

            x = self.rng.next()
            samples = self.sampler.sample(self, x, conditioning, unconditional_conditioning, image_conditioning=self.txt2img_image_conditioning(x))
            del x

            if not self.enable_hr:
                return samples

            devices.torch_gc()

            if self.latent_scale_mode is None:
                decoded_samples = torch.stack(decode_latent_batch(self.sd_model, samples, target_device=devices.cpu, check_for_nans=True)).to(dtype=torch.float32)
            else:
                decoded_samples = None

        with sd_models.SkipWritingToConfig():
            sd_models.reload_model_weights(info=self.hr_checkpoint_info)

        devices.torch_gc()

        return self.sample_hr_pass(samples, decoded_samples, seeds, subseeds, subseed_strength, prompts)
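The `firstpass_image` branch lets a caller skip the first pass and feed an existing image straight into the hires pass (this is what the txt2img hires "upscale" button uses). A hedged sketch of such a caller; it assumes a fully initialized webui environment and a hypothetical `firstpass.png`:

```python
# Sketch only: assumes a running webui environment with a loaded checkpoint.
from PIL import Image
from modules import processing

p = processing.StableDiffusionProcessingTxt2Img(
    prompt="a castle on a hill",
    enable_hr=True,                  # the hires pass is what we actually want
    denoising_strength=0.45,
    hr_upscaler="Latent",
    firstpass_image=Image.open("firstpass.png"),  # hypothetical existing image
)
# p.sample() now encodes firstpass_image instead of sampling a first pass,
# then hands it to sample_hr_pass().
res = processing.process_images(p)
```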
    def sample_hr_pass(self, samples, decoded_samples, seeds, subseeds, subseed_strength, prompts):
@@ -1162,7 +1299,6 @@ class StableDiffusionProcessingTxt2Img(StableDiffusionProcessing):
            return samples

        self.is_hr_pass = True
-
        target_width = self.hr_upscale_to_x
        target_height = self.hr_upscale_to_y

@@ -1251,7 +1387,6 @@ class StableDiffusionProcessingTxt2Img(StableDiffusionProcessing):
        decoded_samples = decode_latent_batch(self.sd_model, samples, target_device=devices.cpu, check_for_nans=True)

        self.is_hr_pass = False
-
        return decoded_samples

    def close(self):
@@ -1354,12 +1489,14 @@ class StableDiffusionProcessingImg2Img(StableDiffusionProcessing):
    mask_blur_x: int = 4
    mask_blur_y: int = 4
    mask_blur: int = None
    mask_round: bool = True
    inpainting_fill: int = 0
    inpaint_full_res: bool = True
    inpaint_full_res_padding: int = 0
    inpainting_mask_invert: int = 0
    initial_noise_multiplier: float = None
    latent_mask: Image = None
    force_task_id: str = None

    image_mask: Any = field(default=None, init=False)
@@ -1389,6 +1526,8 @@ class StableDiffusionProcessingImg2Img(StableDiffusionProcessing):
        self.mask_blur_y = value

    def init(self, all_prompts, all_seeds, all_subseeds):
        self.extra_generation_params["Denoising strength"] = self.denoising_strength

        self.image_cfg_scale: float = self.image_cfg_scale if shared.sd_model.cond_stage_key == "edit" else None

        self.sampler = sd_samplers.create_sampler(self.sampler_name, self.sd_model)
@@ -1399,10 +1538,11 @@ class StableDiffusionProcessingImg2Img(StableDiffusionProcessing):
        if image_mask is not None:
            # image_mask is passed in as RGBA by Gradio to support alpha masks,
            # but we still want to support binary masks.
-            image_mask = create_binary_mask(image_mask)
            image_mask = create_binary_mask(image_mask, round=self.mask_round)

            if self.inpainting_mask_invert:
                image_mask = ImageOps.invert(image_mask)
                self.extra_generation_params["Mask mode"] = "Inpaint not masked"

            if self.mask_blur_x > 0:
                np_mask = np.array(image_mask)
@@ -1416,16 +1556,22 @@ class StableDiffusionProcessingImg2Img(StableDiffusionProcessing):
                np_mask = cv2.GaussianBlur(np_mask, (1, kernel_size), self.mask_blur_y)
                image_mask = Image.fromarray(np_mask)

            if self.mask_blur_x > 0 or self.mask_blur_y > 0:
                self.extra_generation_params["Mask blur"] = self.mask_blur

            if self.inpaint_full_res:
                self.mask_for_overlay = image_mask
                mask = image_mask.convert('L')
-                crop_region = masking.get_crop_region(np.array(mask), self.inpaint_full_res_padding)
                crop_region = masking.get_crop_region(mask, self.inpaint_full_res_padding)
                crop_region = masking.expand_crop_region(crop_region, self.width, self.height, mask.width, mask.height)
                x1, y1, x2, y2 = crop_region

                mask = mask.crop(crop_region)
                image_mask = images.resize_image(2, mask, self.width, self.height)
                self.paste_to = (x1, y1, x2-x1, y2-y1)

                self.extra_generation_params["Inpaint area"] = "Only masked"
                self.extra_generation_params["Masked area padding"] = self.inpaint_full_res_padding
            else:
                image_mask = images.resize_image(self.resize_mode, image_mask, self.width, self.height)
                np_mask = np.array(image_mask)
@@ -1445,7 +1591,7 @@ class StableDiffusionProcessingImg2Img(StableDiffusionProcessing):
            # Save init image
            if opts.save_init_img:
                self.init_img_hash = hashlib.md5(img.tobytes()).hexdigest()
-                images.save_image(img, path=opts.outdir_init_images, basename=None, forced_filename=self.init_img_hash, save_to_dirs=False)
                images.save_image(img, path=opts.outdir_init_images, basename=None, forced_filename=self.init_img_hash, save_to_dirs=False, existing_info=img.info)

            image = images.flatten(img, opts.img2img_background_color)
@@ -1467,6 +1613,9 @@ class StableDiffusionProcessingImg2Img(StableDiffusionProcessing):
            if self.inpainting_fill != 1:
                image = masking.fill(image, latent_mask)

                if self.inpainting_fill == 0:
                    self.extra_generation_params["Masked content"] = 'fill'

            if add_color_corrections:
                self.color_corrections.append(setup_color_correction(image))
@@ -1506,7 +1655,8 @@ class StableDiffusionProcessingImg2Img(StableDiffusionProcessing):
            latmask = init_mask.convert('RGB').resize((self.init_latent.shape[3], self.init_latent.shape[2]))
            latmask = np.moveaxis(np.array(latmask, dtype=np.float32), 2, 0) / 255
            latmask = latmask[0]
-            latmask = np.around(latmask)
            if self.mask_round:
                latmask = np.around(latmask)
            latmask = np.tile(latmask[None], (4, 1, 1))

            self.mask = torch.asarray(1.0 - latmask).to(shared.device).type(self.sd_model.dtype)
@@ -1515,10 +1665,13 @@ class StableDiffusionProcessingImg2Img(StableDiffusionProcessing):
            # this needs to be fixed to be done in sample() using actual seeds for batches
            if self.inpainting_fill == 2:
                self.init_latent = self.init_latent * self.mask + create_random_tensors(self.init_latent.shape[1:], all_seeds[0:self.init_latent.shape[0]]) * self.nmask
                self.extra_generation_params["Masked content"] = 'latent noise'

            elif self.inpainting_fill == 3:
                self.init_latent = self.init_latent * self.mask
                self.extra_generation_params["Masked content"] = 'latent nothing'

-        self.image_conditioning = self.img2img_image_conditioning(image * 2 - 1, self.init_latent, image_mask)
        self.image_conditioning = self.img2img_image_conditioning(image * 2 - 1, self.init_latent, image_mask, self.mask_round)
    def sample(self, conditioning, unconditional_conditioning, seeds, subseeds, subseed_strength, prompts):
        x = self.rng.next()
@@ -1530,7 +1683,14 @@ class StableDiffusionProcessingImg2Img(StableDiffusionProcessing):
        samples = self.sampler.sample_img2img(self, self.init_latent, x, conditioning, unconditional_conditioning, image_conditioning=self.image_conditioning)

        if self.mask is not None:
-            samples = samples * self.nmask + self.init_latent * self.mask
            blended_samples = samples * self.nmask + self.init_latent * self.mask

            if self.scripts is not None:
                mba = scripts.MaskBlendArgs(samples, self.nmask, self.init_latent, self.mask, blended_samples)
                self.scripts.on_mask_blend(self, mba)
                blended_samples = mba.blended_latent

            samples = blended_samples

        del x
        devices.torch_gc()
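Together with `MaskBlendArgs` (added in modules/scripts.py below), the `on_mask_blend` hook lets an extension replace the default latent blend during inpainting. A minimal, illustrative script; the softened mask here is an arbitrary example, not the Soft Inpainting algorithm:

```python
# Minimal sketch of a script using the new hook (would live in an extension).
from modules import scripts

class SoftBlendExample(scripts.Script):
    def title(self):
        return "Soft blend example"

    def show(self, is_img2img):
        return scripts.AlwaysVisible

    def on_mask_blend(self, p, mba: scripts.MaskBlendArgs, *args):
        # Default behavior is: blended = current * nmask + init * mask.
        # Here we soften the (binary) mask before blending -- illustrative only.
        soft_mask = mba.mask ** 0.5
        mba.blended_latent = mba.current_latent * (1 - soft_mask) + mba.init_latent * soft_mask
```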
modules/processing_scripts/refiner.py (+4 -3)
@@ -1,6 +1,7 @@
import gradio as gr

from modules import scripts, sd_models
from modules.infotext_utils import PasteField
from modules.ui_common import create_refresh_button
from modules.ui_components import InputAccordion

@@ -31,9 +32,9 @@ class ScriptRefiner(scripts.ScriptBuiltinUI):
            return None if info is None else info.title

        self.infotext_fields = [
-            (enable_refiner, lambda d: 'Refiner' in d),
-            (refiner_checkpoint, lambda d: lookup_checkpoint(d.get('Refiner'))),
-            (refiner_switch_at, 'Refiner switch at'),
            PasteField(enable_refiner, lambda d: 'Refiner' in d),
            PasteField(refiner_checkpoint, lambda d: lookup_checkpoint(d.get('Refiner')), api="refiner_checkpoint"),
            PasteField(refiner_switch_at, 'Refiner switch at', api="refiner_switch_at"),
        ]

        return enable_refiner, refiner_checkpoint, refiner_switch_at
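`PasteField` wraps the old `(component, infotext_key)` tuples, and its new `api=` argument ties an infotext key to an API field name. A sketch of registering fields with it; stand-in objects replace real gradio components, and a webui environment is assumed for the import:

```python
# Sketch: assumes a webui environment for the real import.
from modules.infotext_utils import PasteField

slider = object()    # stand-in for a gr.Slider
checkbox = object()  # stand-in for a gr.Checkbox

infotext_fields = [
    PasteField(slider, "My setting", api="my_setting"),  # plain key lookup
    PasteField(checkbox, lambda d: "My setting" in d),   # callable over the parsed infotext dict
]
```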
modules/processing_scripts/seed.py (+13 -19)
@@ -3,8 +3,10 @@ import json

import gradio as gr

from modules import scripts, ui, errors
from modules.infotext_utils import PasteField
from modules.shared import cmd_opts
from modules.ui_components import ToolButton
from modules import infotext_utils


class ScriptSeed(scripts.ScriptBuiltinUI):
@@ -51,12 +53,12 @@ class ScriptSeed(scripts.ScriptBuiltinUI):
        seed_checkbox.change(lambda x: gr.update(visible=x), show_progress=False, inputs=[seed_checkbox], outputs=[seed_extras])

        self.infotext_fields = [
-            (self.seed, "Seed"),
-            (seed_checkbox, lambda d: "Variation seed" in d or "Seed resize from-1" in d),
-            (subseed, "Variation seed"),
-            (subseed_strength, "Variation seed strength"),
-            (seed_resize_from_w, "Seed resize from-1"),
-            (seed_resize_from_h, "Seed resize from-2"),
            PasteField(self.seed, "Seed", api="seed"),
            PasteField(seed_checkbox, lambda d: "Variation seed" in d or "Seed resize from-1" in d),
            PasteField(subseed, "Variation seed", api="subseed"),
            PasteField(subseed_strength, "Variation seed strength", api="subseed_strength"),
            PasteField(seed_resize_from_w, "Seed resize from-1", api="seed_resize_from_h"),
            PasteField(seed_resize_from_h, "Seed resize from-2", api="seed_resize_from_w"),
        ]

        self.on_after_component(lambda x: connect_reuse_seed(self.seed, reuse_seed, x.component, False), elem_id=f'generation_info_{self.tabname}')
@@ -76,7 +78,6 @@ class ScriptSeed(scripts.ScriptBuiltinUI):
        p.seed_resize_from_h = seed_resize_from_h

-
def connect_reuse_seed(seed: gr.Number, reuse_seed: gr.Button, generation_info: gr.Textbox, is_subseed):
    """ Connects a 'reuse (sub)seed' button's click event so that it copies last used
    (sub)seed value from generation info the to the seed field. If copying subseed and subseed strength
@@ -84,21 +85,14 @@ def connect_reuse_seed(seed: gr.Number, reuse_seed: gr.Button, generation_info:
    def copy_seed(gen_info_string: str, index):
        res = -1

        try:
            gen_info = json.loads(gen_info_string)
-            index -= gen_info.get('index_of_first_image', 0)
-
-            if is_subseed and gen_info.get('subseed_strength', 0) > 0:
-                all_subseeds = gen_info.get('all_subseeds', [-1])
-                res = all_subseeds[index if 0 <= index < len(all_subseeds) else 0]
-            else:
-                all_seeds = gen_info.get('all_seeds', [-1])
-                res = all_seeds[index if 0 <= index < len(all_seeds) else 0]
-
-        except json.decoder.JSONDecodeError:
            infotext = gen_info.get('infotexts')[index]
            gen_parameters = infotext_utils.parse_generation_parameters(infotext, [])
            res = int(gen_parameters.get('Variation seed' if is_subseed else 'Seed', -1))
        except Exception:
            if gen_info_string:
-                errors.report(f"Error parsing JSON generation info: {gen_info_string}")
                errors.report(f"Error retrieving seed from generation info: {gen_info_string}", exc_info=True)

        return [res, gr.update()]
modules/progress.py (+20 -2)
@@ -8,10 +8,13 @@ from pydantic import BaseModel, Field

from modules.shared import opts
import modules.shared as shared
from collections import OrderedDict
import string
import random
from typing import List

current_task = None
-pending_tasks = {}
pending_tasks = OrderedDict()
finished_tasks = []
recorded_results = []
recorded_results_limit = 2
@@ -34,6 +37,11 @@ def finish_task(id_task):
    if len(finished_tasks) > 16:
        finished_tasks.pop(0)

def create_task_id(task_type):
    N = 7
    res = ''.join(random.choices(string.ascii_uppercase +
                  string.digits, k=N))
    return f"task({task_type}-{res})"

def record_results(id_task, res):
    recorded_results.append((id_task, res))
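`create_task_id` gives server-generated task ids the same shape the client builds, e.g. `task(txt2img-K7Q2XW9)`. A standalone equivalent:

```python
import random
import string

def create_task_id(task_type):
    # 7 random characters from A-Z and 0-9, wrapped in the task(...) format.
    N = 7
    res = ''.join(random.choices(string.ascii_uppercase + string.digits, k=N))
    return f"task({task_type}-{res})"

print(create_task_id("txt2img"))  # e.g. task(txt2img-K7Q2XW9)
```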
@@ -44,6 +52,9 @@ def record_results(id_task, res):
def add_task_to_queue(id_job):
    pending_tasks[id_job] = time.time()
class PendingTasksResponse(BaseModel):
    size: int = Field(title="Pending task size")
    tasks: List[str] = Field(title="Pending task ids")
class ProgressRequest(BaseModel):
    id_task: str = Field(default=None, title="Task ID", description="id of the task to get progress for")
@@ -63,9 +74,16 @@ class ProgressResponse(BaseModel):
def setup_progress_api(app):
    app.add_api_route("/internal/pending-tasks", get_pending_tasks, methods=["GET"])
    return app.add_api_route("/internal/progress", progressapi, methods=["POST"], response_model=ProgressResponse)
def get_pending_tasks():
    pending_tasks_ids = list(pending_tasks)
    pending_len = len(pending_tasks_ids)
    return PendingTasksResponse(size=pending_len, tasks=pending_tasks_ids)
def progressapi(req: ProgressRequest):
    active = req.id_task == current_task
    queued = req.id_task in pending_tasks
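The new `GET /internal/pending-tasks` route exposes the now-ordered queue, so a client can show queue size and position. A hedged polling example against a local webui instance (default port and the `requests` package assumed):

```python
import requests

# Assumes a webui instance on the default port with this route available.
resp = requests.get("http://127.0.0.1:7860/internal/pending-tasks", timeout=5)
data = resp.json()
print(data["size"], data["tasks"])  # e.g. 2 ['task(txt2img-ABC1234)', ...]
```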
modules/prompt_parser.py (+1 -1)
@@ -4,7 +4,7 @@ import re
from collections import namedtuple
import lark

-# a prompt like this: "fantasy landscape with a [mountain:lake:0.25] and [an oak:a christmas tree:0.75][ in foreground::0.6][ in background:0.25] [shoddy:masterful:0.5]"
# a prompt like this: "fantasy landscape with a [mountain:lake:0.25] and [an oak:a christmas tree:0.75][ in foreground::0.6][: in background:0.25] [shoddy:masterful:0.5]"
# will be represented with prompt_schedule like this (assuming steps=100):
# [25, 'fantasy landscape with a mountain and an oak in foreground shoddy']
# [50, 'fantasy landscape with a lake and an oak in foreground in background shoddy']
modules/realesrgan_model.py (+65 -93)
@@ -1,12 +1,9 @@
import os

-import numpy as np
-from PIL import Image
-from realesrgan import RealESRGANer
-
-from modules.upscaler import Upscaler, UpscalerData
-from modules.shared import cmd_opts, opts
from modules import modelloader, errors
from modules.shared import cmd_opts, opts
from modules.upscaler import Upscaler, UpscalerData
from modules.upscaler_utils import upscale_with_model


class UpscalerRealESRGAN(Upscaler):
@@ -14,29 +11,20 @@ class UpscalerRealESRGAN(Upscaler):
        self.name = "RealESRGAN"
        self.user_path = path
        super().__init__()
-        try:
-            from basicsr.archs.rrdbnet_arch import RRDBNet  # noqa: F401
-            from realesrgan import RealESRGANer  # noqa: F401
-            from realesrgan.archs.srvgg_arch import SRVGGNetCompact  # noqa: F401
-            self.enable = True
-            self.scalers = []
-            scalers = self.load_models(path)
-
-            local_model_paths = self.find_models(ext_filter=[".pth"])
-            for scaler in scalers:
-                if scaler.local_data_path.startswith("http"):
-                    filename = modelloader.friendly_name(scaler.local_data_path)
-                    local_model_candidates = [local_model for local_model in local_model_paths if local_model.endswith(f"{filename}.pth")]
-                    if local_model_candidates:
-                        scaler.local_data_path = local_model_candidates[0]
-
-                if scaler.name in opts.realesrgan_enabled_models:
-                    self.scalers.append(scaler)
-
-        except Exception:
-            errors.report("Error importing Real-ESRGAN", exc_info=True)
-            self.enable = False
-            self.scalers = []
        self.enable = True
        self.scalers = []
        scalers = get_realesrgan_models(self)

        local_model_paths = self.find_models(ext_filter=[".pth"])
        for scaler in scalers:
            if scaler.local_data_path.startswith("http"):
                filename = modelloader.friendly_name(scaler.local_data_path)
                local_model_candidates = [local_model for local_model in local_model_paths if local_model.endswith(f"{filename}.pth")]
                if local_model_candidates:
                    scaler.local_data_path = local_model_candidates[0]

            if scaler.name in opts.realesrgan_enabled_models:
                self.scalers.append(scaler)
    def do_upscale(self, img, path):
        if not self.enable:
@@ -48,20 +36,19 @@ class UpscalerRealESRGAN(Upscaler):
            errors.report(f"Unable to load RealESRGAN model {path}", exc_info=True)
            return img

-        upsampler = RealESRGANer(
-            scale=info.scale,
-            model_path=info.local_data_path,
-            model=info.model(),
-            half=not cmd_opts.no_half and not cmd_opts.upcast_sampling,
-            tile=opts.ESRGAN_tile,
-            tile_pad=opts.ESRGAN_tile_overlap,
-            device=self.device,
-        )
-
-        upsampled = upsampler.enhance(np.array(img), outscale=info.scale)[0]
-
-        image = Image.fromarray(upsampled)
-        return image
        model_descriptor = modelloader.load_spandrel_model(
            info.local_data_path,
            device=self.device,
            prefer_half=(not cmd_opts.no_half and not cmd_opts.upcast_sampling),
            expected_architecture="ESRGAN",  # "RealESRGAN" isn't a specific thing for Spandrel
        )
        return upscale_with_model(
            model_descriptor,
            img,
            tile_size=opts.ESRGAN_tile,
            tile_overlap=opts.ESRGAN_tile_overlap,
            # TODO: `outscale`?
        )
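`do_upscale` now loads the checkpoint through Spandrel and runs the shared tiled helper instead of the `realesrgan` package. A hedged sketch of the same pair used directly; it assumes a webui environment and a hypothetical local `.pth` path:

```python
# Sketch: assumes a webui environment and a local ESRGAN-family checkpoint.
from PIL import Image
from modules import modelloader
from modules.upscaler_utils import upscale_with_model

model = modelloader.load_spandrel_model(
    "models/RealESRGAN/RealESRGAN_x4plus.pth",  # hypothetical local path
    device="cuda",
    prefer_half=True,
    expected_architecture="ESRGAN",
)
img = Image.open("input.png")
out = upscale_with_model(model, img, tile_size=192, tile_overlap=8)
out.save("output.png")
```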
    def load_model(self, path):
        for scaler in self.scalers:
@@ -76,58 +63,43 @@ class UpscalerRealESRGAN(Upscaler):
                return scaler
        raise ValueError(f"Unable to find model info: {path}")

-    def load_models(self, _):
-        return get_realesrgan_models(self)
-
-
-def get_realesrgan_models(scaler):
-    try:
-        from basicsr.archs.rrdbnet_arch import RRDBNet
-        from realesrgan.archs.srvgg_arch import SRVGGNetCompact
-        models = [
-            UpscalerData(
-                name="R-ESRGAN General 4xV3",
-                path="https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.5.0/realesr-general-x4v3.pth",
-                scale=4,
-                upscaler=scaler,
-                model=lambda: SRVGGNetCompact(num_in_ch=3, num_out_ch=3, num_feat=64, num_conv=32, upscale=4, act_type='prelu')
-            ),
-            UpscalerData(
-                name="R-ESRGAN General WDN 4xV3",
-                path="https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.5.0/realesr-general-wdn-x4v3.pth",
-                scale=4,
-                upscaler=scaler,
-                model=lambda: SRVGGNetCompact(num_in_ch=3, num_out_ch=3, num_feat=64, num_conv=32, upscale=4, act_type='prelu')
-            ),
-            UpscalerData(
-                name="R-ESRGAN AnimeVideo",
-                path="https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.5.0/realesr-animevideov3.pth",
-                scale=4,
-                upscaler=scaler,
-                model=lambda: SRVGGNetCompact(num_in_ch=3, num_out_ch=3, num_feat=64, num_conv=16, upscale=4, act_type='prelu')
-            ),
-            UpscalerData(
-                name="R-ESRGAN 4x+",
-                path="https://github.com/xinntao/Real-ESRGAN/releases/download/v0.1.0/RealESRGAN_x4plus.pth",
-                scale=4,
-                upscaler=scaler,
-                model=lambda: RRDBNet(num_in_ch=3, num_out_ch=3, num_feat=64, num_block=23, num_grow_ch=32, scale=4)
-            ),
-            UpscalerData(
-                name="R-ESRGAN 4x+ Anime6B",
-                path="https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.2.4/RealESRGAN_x4plus_anime_6B.pth",
-                scale=4,
-                upscaler=scaler,
-                model=lambda: RRDBNet(num_in_ch=3, num_out_ch=3, num_feat=64, num_block=6, num_grow_ch=32, scale=4)
-            ),
-            UpscalerData(
-                name="R-ESRGAN 2x+",
-                path="https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.1/RealESRGAN_x2plus.pth",
-                scale=2,
-                upscaler=scaler,
-                model=lambda: RRDBNet(num_in_ch=3, num_out_ch=3, num_feat=64, num_block=23, num_grow_ch=32, scale=2)
-            ),
-        ]
-        return models
-    except Exception:
-        errors.report("Error making Real-ESRGAN models list", exc_info=True)
def get_realesrgan_models(scaler: UpscalerRealESRGAN):
    return [
        UpscalerData(
            name="R-ESRGAN General 4xV3",
            path="https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.5.0/realesr-general-x4v3.pth",
            scale=4,
            upscaler=scaler,
        ),
        UpscalerData(
            name="R-ESRGAN General WDN 4xV3",
            path="https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.5.0/realesr-general-wdn-x4v3.pth",
            scale=4,
            upscaler=scaler,
        ),
        UpscalerData(
            name="R-ESRGAN AnimeVideo",
            path="https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.5.0/realesr-animevideov3.pth",
            scale=4,
            upscaler=scaler,
        ),
        UpscalerData(
            name="R-ESRGAN 4x+",
            path="https://github.com/xinntao/Real-ESRGAN/releases/download/v0.1.0/RealESRGAN_x4plus.pth",
            scale=4,
            upscaler=scaler,
        ),
        UpscalerData(
            name="R-ESRGAN 4x+ Anime6B",
            path="https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.2.4/RealESRGAN_x4plus_anime_6B.pth",
            scale=4,
            upscaler=scaler,
        ),
        UpscalerData(
            name="R-ESRGAN 2x+",
            path="https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.1/RealESRGAN_x2plus.pth",
            scale=2,
            upscaler=scaler,
        ),
    ]
View File
@@ -110,7 +110,7 @@ class ImageRNG:
self.is_first = True self.is_first = True
def first(self): def first(self):
noise_shape = self.shape if self.seed_resize_from_h <= 0 or self.seed_resize_from_w <= 0 else (self.shape[0], self.seed_resize_from_h // 8, self.seed_resize_from_w // 8) noise_shape = self.shape if self.seed_resize_from_h <= 0 or self.seed_resize_from_w <= 0 else (self.shape[0], int(self.seed_resize_from_h) // 8, int(self.seed_resize_from_w // 8))
xs = [] xs = []
modules/script_callbacks.py (+4 -1)
@@ -41,7 +41,7 @@ class ExtraNoiseParams:

class CFGDenoiserParams:
-    def __init__(self, x, image_cond, sigma, sampling_step, total_sampling_steps, text_cond, text_uncond):
    def __init__(self, x, image_cond, sigma, sampling_step, total_sampling_steps, text_cond, text_uncond, denoiser=None):
        self.x = x
        """Latent image representation in the process of being denoised"""

@@ -63,6 +63,9 @@ class CFGDenoiserParams:
        self.text_uncond = text_uncond
        """ Encoder hidden states of text conditioning from negative prompt"""

        self.denoiser = denoiser
        """Current CFGDenoiser object with processing parameters"""


class CFGDenoisedParams:
    def __init__(self, x, sampling_step, total_sampling_steps, inner_model):
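Callbacks registered with `script_callbacks.on_cfg_denoiser` now also receive the live denoiser via `params.denoiser`. A minimal sketch; what `CFGDenoiser` exposes beyond the attribute itself is not shown in this diff:

```python
# Minimal sketch of a callback consuming the new field (webui environment assumed).
from modules import script_callbacks

def my_cfg_denoiser_callback(params: script_callbacks.CFGDenoiserParams):
    if params.denoiser is not None:
        # denoiser is the live CFGDenoiser; its further attributes are not
        # part of this diff, so only its presence is used here.
        print(type(params.denoiser).__name__, params.sampling_step)

script_callbacks.on_cfg_denoiser(my_cfg_denoiser_callback)
```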
modules/scripts.py (+229 -19)
@@ -11,11 +11,31 @@ from modules import shared, paths, script_callbacks, extensions, script_loading,
AlwaysVisible = object()
class MaskBlendArgs:
    def __init__(self, current_latent, nmask, init_latent, mask, blended_latent, denoiser=None, sigma=None):
        self.current_latent = current_latent
        self.nmask = nmask
        self.init_latent = init_latent
        self.mask = mask
        self.blended_latent = blended_latent

        self.denoiser = denoiser
        self.is_final_blend = denoiser is None
        self.sigma = sigma

class PostSampleArgs:
    def __init__(self, samples):
        self.samples = samples
class PostprocessImageArgs:
    def __init__(self, image):
        self.image = image
class PostProcessMaskOverlayArgs:
    def __init__(self, index, mask_for_overlay, overlay_image):
        self.index = index
        self.mask_for_overlay = mask_for_overlay
        self.overlay_image = overlay_image
class PostprocessBatchListArgs:
    def __init__(self, images):
@@ -71,6 +91,9 @@ class Script:
    setup_for_ui_only = False
    """If true, the script setup will only be run in Gradio UI, not in API"""
    controls = None
    """A list of controls retured by the ui()."""
    def title(self):
        """this function should return the title of the script. This is what will be displayed in the dropdown menu."""
@@ -206,6 +229,25 @@ class Script:
        pass
    def on_mask_blend(self, p, mba: MaskBlendArgs, *args):
        """
        Called in inpainting mode when the original content is blended with the inpainted content.
        This is called at every step in the denoising process and once at the end.
        If is_final_blend is true, this is called for the final blending stage.
        Otherwise, denoiser and sigma are defined and may be used to inform the procedure.
        """

        pass

    def post_sample(self, p, ps: PostSampleArgs, *args):
        """
        Called after the samples have been generated,
        but before they have been decoded by the VAE, if applicable.
        Check getattr(samples, 'already_decoded', False) to test if the images are decoded.
        """

        pass
    def postprocess_image(self, p, pp: PostprocessImageArgs, *args):
        """
        Called for every image after it has been generated.
@@ -213,6 +255,22 @@ class Script:
        pass
    def postprocess_maskoverlay(self, p, ppmo: PostProcessMaskOverlayArgs, *args):
        """
        Called for every image after it has been generated.
        """

        pass

    def postprocess_image_after_composite(self, p, pp: PostprocessImageArgs, *args):
        """
        Called for every image after it has been generated.
        Same as postprocess_image but after inpaint_full_res composite
        So that it operates on the full image instead of the inpaint_full_res crop region.
        """

        pass
    def postprocess(self, p, processed, *args):
        """
        This function is called after processing ends for AlwaysVisible scripts.
@@ -311,20 +369,113 @@ scripts_data = []
postprocessing_scripts_data = []
ScriptClassData = namedtuple("ScriptClassData", ["script_class", "path", "basedir", "module"])
def topological_sort(dependencies):
    """Accepts a dictionary mapping name to its dependencies, returns a list of names ordered according to dependencies.
    Ignores errors relating to missing dependeencies or circular dependencies
    """

    visited = {}
    result = []

    def inner(name):
        visited[name] = True

        for dep in dependencies.get(name, []):
            if dep in dependencies and dep not in visited:
                inner(dep)

        result.append(name)

    for depname in dependencies:
        if depname not in visited:
            inner(depname)

    return result
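`topological_sort` orders names so each comes after its dependencies, and silently skips unknown or already-visited (including circular) entries. A worked example; the helper is reproduced so the snippet runs standalone:

```python
def topological_sort(dependencies):
    # Same algorithm as above, reproduced for a standalone check.
    visited = {}
    result = []

    def inner(name):
        visited[name] = True
        for dep in dependencies.get(name, []):
            if dep in dependencies and dep not in visited:
                inner(dep)
        result.append(name)

    for depname in dependencies:
        if depname not in visited:
            inner(depname)

    return result

print(topological_sort({"c": ["b"], "b": ["a"], "a": [], "d": ["missing"]}))
# ['a', 'b', 'c', 'd'] -- 'missing' is ignored because it has no entry
```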
@dataclass
class ScriptWithDependencies:
    script_canonical_name: str
    file: ScriptFile
    requires: list
    load_before: list
    load_after: list
def list_scripts(scriptdirname, extension, *, include_extensions=True):
-    scripts_list = []
-
-    basedir = os.path.join(paths.script_path, scriptdirname)
-    if os.path.exists(basedir):
-        for filename in sorted(os.listdir(basedir)):
-            scripts_list.append(ScriptFile(paths.script_path, filename, os.path.join(basedir, filename)))
    scripts = {}

    loaded_extensions = {ext.canonical_name: ext for ext in extensions.active()}
    loaded_extensions_scripts = {ext.canonical_name: [] for ext in extensions.active()}

    # build script dependency map
    root_script_basedir = os.path.join(paths.script_path, scriptdirname)
    if os.path.exists(root_script_basedir):
        for filename in sorted(os.listdir(root_script_basedir)):
            if not os.path.isfile(os.path.join(root_script_basedir, filename)):
                continue

            if os.path.splitext(filename)[1].lower() != extension:
                continue

            script_file = ScriptFile(paths.script_path, filename, os.path.join(root_script_basedir, filename))
            scripts[filename] = ScriptWithDependencies(filename, script_file, [], [], [])

    if include_extensions:
        for ext in extensions.active():
-            scripts_list += ext.list_files(scriptdirname, extension)
-
-    scripts_list = [x for x in scripts_list if os.path.splitext(x.path)[1].lower() == extension and os.path.isfile(x.path)]
            extension_scripts_list = ext.list_files(scriptdirname, extension)
            for extension_script in extension_scripts_list:
                if not os.path.isfile(extension_script.path):
                    continue

                script_canonical_name = ("builtin/" if ext.is_builtin else "") + ext.canonical_name + "/" + extension_script.filename
                relative_path = scriptdirname + "/" + extension_script.filename

                script = ScriptWithDependencies(
                    script_canonical_name=script_canonical_name,
                    file=extension_script,
                    requires=ext.metadata.get_script_requirements("Requires", relative_path, scriptdirname),
                    load_before=ext.metadata.get_script_requirements("Before", relative_path, scriptdirname),
                    load_after=ext.metadata.get_script_requirements("After", relative_path, scriptdirname),
                )

                scripts[script_canonical_name] = script
                loaded_extensions_scripts[ext.canonical_name].append(script)
for script_canonical_name, script in scripts.items():
# load before requires inverse dependency
# in this case, append the script name into the load_after list of the specified script
for load_before in script.load_before:
# if this requires an individual script to be loaded before
other_script = scripts.get(load_before)
if other_script:
other_script.load_after.append(script_canonical_name)
# if this requires an extension
other_extension_scripts = loaded_extensions_scripts.get(load_before)
if other_extension_scripts:
for other_script in other_extension_scripts:
other_script.load_after.append(script_canonical_name)
# if After mentions an extension, remove it and instead add all of its scripts
for load_after in list(script.load_after):
if load_after not in scripts and load_after in loaded_extensions_scripts:
script.load_after.remove(load_after)
for other_script in loaded_extensions_scripts.get(load_after, []):
script.load_after.append(other_script.script_canonical_name)
dependencies = {}
for script_canonical_name, script in scripts.items():
for required_script in script.requires:
if required_script not in scripts and required_script not in loaded_extensions:
errors.report(f'Script "{script_canonical_name}" requires "{required_script}" to be loaded, but it is not.', exc_info=False)
dependencies[script_canonical_name] = script.load_after
ordered_scripts = topological_sort(dependencies)
scripts_list = [scripts[script_canonical_name].file for script_canonical_name in ordered_scripts]
return scripts_list return scripts_list
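The Before/After handling above boils down to one inversion: a script's Before entries become After entries on the target, so only load_after feeds topological_sort. A standalone sketch with made-up names:

# Standalone sketch of the Before -> After inversion (names invented).
scripts = {
    "a.py": {"before": ["b.py"], "after": []},
    "b.py": {"before": [], "after": []},
}
for name, script in scripts.items():
    for target in script["before"]:
        if target in scripts:
            # "a.py loads before b.py" becomes "b.py loads after a.py"
            scripts[target]["after"].append(name)
dependencies = {name: script["after"] for name, script in scripts.items()}
print(dependencies)  # -> {'a.py': [], 'b.py': ['a.py']}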
@@ -365,15 +516,9 @@ def load_scripts():
elif issubclass(script_class, scripts_postprocessing.ScriptPostprocessing):
postprocessing_scripts_data.append(ScriptClassData(script_class, scriptfile.path, scriptfile.basedir, module))
-def orderby(basedir):
-# 1st webui, 2nd extensions-builtin, 3rd extensions
-priority = {os.path.join(paths.script_path, "extensions-builtin"):1, paths.script_path:0}
-for key in priority:
-if basedir.startswith(key):
-return priority[key]
-return 9999
-for scriptfile in sorted(scripts_list, key=lambda x: [orderby(x.basedir), x]):
# here the scripts_list is already ordered
# processing_script is not considered though
for scriptfile in scripts_list:
try:
if scriptfile.basedir != paths.script_path:
sys.path = [scriptfile.basedir] + sys.path
@@ -433,7 +578,12 @@ class ScriptRunner:
auto_processing_scripts = scripts_auto_postprocessing.create_auto_preprocessing_script_data()
for script_data in auto_processing_scripts + scripts_data:
-script = script_data.script_class()
try:
script = script_data.script_class()
except Exception:
errors.report(f"Error initializing script {script_data.module}", exc_info=True)
continue
script.filename = script_data.path
script.is_txt2img = not is_img2img
script.is_img2img = is_img2img
@@ -473,17 +623,26 @@ class ScriptRunner:
on_after.clear()
def create_script_ui(self, script):
-import modules.api.models as api_models
script.args_from = len(self.inputs)
script.args_to = len(self.inputs)
try:
self.create_script_ui_inner(script)
except Exception:
errors.report(f"Error creating UI for {script.name}: ", exc_info=True)
def create_script_ui_inner(self, script):
import modules.api.models as api_models
controls = wrap_call(script.ui, script.filename, "ui", script.is_img2img)
script.controls = controls
if controls is None:
return
script.name = wrap_call(script.title, script.filename, "title", default=script.filename).lower()
api_args = []
for control in controls:
@@ -550,6 +709,8 @@ class ScriptRunner:
self.setup_ui_for_section(None, self.selectable_scripts)
def select_script(script_index):
if script_index is None:
script_index = 0
selected_script = self.selectable_scripts[script_index - 1] if script_index>0 else None
return [gr.update(visible=selected_script == s) for s in self.selectable_scripts]
@@ -593,7 +754,7 @@ class ScriptRunner:
def run(self, p, *args):
script_index = args[0]
-if script_index == 0:
if script_index == 0 or script_index is None:
return None
script = self.selectable_scripts[script_index-1]
@@ -672,6 +833,22 @@ class ScriptRunner:
except Exception:
errors.report(f"Error running postprocess_batch_list: {script.filename}", exc_info=True)
def post_sample(self, p, ps: PostSampleArgs):
for script in self.alwayson_scripts:
try:
script_args = p.script_args[script.args_from:script.args_to]
script.post_sample(p, ps, *script_args)
except Exception:
errors.report(f"Error running post_sample: {script.filename}", exc_info=True)
def on_mask_blend(self, p, mba: MaskBlendArgs):
for script in self.alwayson_scripts:
try:
script_args = p.script_args[script.args_from:script.args_to]
script.on_mask_blend(p, mba, *script_args)
except Exception:
errors.report(f"Error running post_sample: {script.filename}", exc_info=True)
def postprocess_image(self, p, pp: PostprocessImageArgs):
for script in self.alwayson_scripts:
try:
@@ -680,6 +857,22 @@ class ScriptRunner:
except Exception:
errors.report(f"Error running postprocess_image: {script.filename}", exc_info=True)
def postprocess_maskoverlay(self, p, ppmo: PostProcessMaskOverlayArgs):
for script in self.alwayson_scripts:
try:
script_args = p.script_args[script.args_from:script.args_to]
script.postprocess_maskoverlay(p, ppmo, *script_args)
except Exception:
errors.report(f"Error running postprocess_image: {script.filename}", exc_info=True)
def postprocess_image_after_composite(self, p, pp: PostprocessImageArgs):
for script in self.alwayson_scripts:
try:
script_args = p.script_args[script.args_from:script.args_to]
script.postprocess_image_after_composite(p, pp, *script_args)
except Exception:
errors.report(f"Error running postprocess_image_after_composite: {script.filename}", exc_info=True)
def before_component(self, component, **kwargs):
for callback, script in self.on_before_component_elem_id.get(kwargs.get("elem_id"), []):
try:
@@ -746,6 +939,23 @@ class ScriptRunner:
except Exception:
errors.report(f"Error running setup: {script.filename}", exc_info=True)
def set_named_arg(self, args, script_type, arg_elem_id, value):
script = next((x for x in self.scripts if type(x).__name__ == script_type), None)
if script is None:
return
for i, control in enumerate(script.controls):
if arg_elem_id in control.elem_id:
index = script.args_from + i
if isinstance(args, list):
args[index] = value
return args
elif isinstance(args, tuple):
return args[:index] + (value,) + args[index+1:]
else:
return None
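The lookup in set_named_arg relies on one invariant: control i of a script occupies flat argument slot script.args_from + i. A standalone illustration with stand-in values:

# Stand-in illustration of the args_from + i slot arithmetic (not webui code).
args_from = 2
elem_ids = ["demo_a", "demo_b", "demo_c"]    # stand-ins for the controls' elem_id
args = [0, 1, "a", "b", "c", 5]              # flat argument list for all scripts

index = args_from + elem_ids.index("demo_b")
args[index] = "patched"
print(args)  # -> [0, 1, 'a', 'patched', 'c', 5]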
scripts_txt2img: ScriptRunner = None
scripts_img2img: ScriptRunner = None
+81 -5
@@ -1,13 +1,56 @@
import dataclasses
import os
import gradio as gr
from modules import errors, shared
@dataclasses.dataclass
class PostprocessedImageSharedInfo:
target_width: int = None
target_height: int = None
class PostprocessedImage:
def __init__(self, image):
self.image = image
self.info = {}
self.shared = PostprocessedImageSharedInfo()
self.extra_images = []
self.nametags = []
self.disable_processing = False
self.caption = None
def get_suffix(self, used_suffixes=None):
used_suffixes = {} if used_suffixes is None else used_suffixes
suffix = "-".join(self.nametags)
if suffix:
suffix = "-" + suffix
if suffix not in used_suffixes:
used_suffixes[suffix] = 1
return suffix
for i in range(1, 100):
proposed_suffix = suffix + "-" + str(i)
if proposed_suffix not in used_suffixes:
used_suffixes[proposed_suffix] = 1
return proposed_suffix
return suffix
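A standalone re-implementation of the suffix rule above, to show the dedup counter in action (assumes the same semantics as get_suffix):

# Standalone sketch mirroring get_suffix: join nametags, then append a
# counter if that suffix was already handed out in this batch.
def get_suffix(nametags, used_suffixes):
    suffix = "-".join(nametags)
    if suffix:
        suffix = "-" + suffix
    if suffix not in used_suffixes:
        used_suffixes[suffix] = 1
        return suffix
    for i in range(1, 100):
        proposed = f"{suffix}-{i}"
        if proposed not in used_suffixes:
            used_suffixes[proposed] = 1
            return proposed
    return suffix

used = {}
print(get_suffix(["upscaled", "x2"], used))  # -> "-upscaled-x2"
print(get_suffix(["upscaled", "x2"], used))  # -> "-upscaled-x2-1"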
def create_copy(self, new_image, *, nametags=None, disable_processing=False):
pp = PostprocessedImage(new_image)
pp.shared = self.shared
pp.nametags = self.nametags.copy()
pp.info = self.info.copy()
pp.disable_processing = disable_processing
if nametags is not None:
pp.nametags += nametags
return pp
class ScriptPostprocessing:
@@ -42,10 +85,17 @@ class ScriptPostprocessing:
pass
def process_firstpass(self, pp: PostprocessedImage, **args):
"""
Called for all scripts before calling process(). Scripts can examine the image here and set fields
of the pp object to communicate things to other scripts.
args contains a dictionary with all values returned by components from ui()
"""
pass
def image_changed(self):
pass
def wrap_call(func, filename, funcname, *args, default=None, **kwargs):
@@ -118,16 +168,42 @@ class ScriptPostprocessingRunner:
return inputs
def run(self, pp: PostprocessedImage, args):
-for script in self.scripts_in_preferred_order():
-shared.state.job = script.name
scripts = []
for script in self.scripts_in_preferred_order():
script_args = args[script.args_from:script.args_to]
process_args = {}
for (name, _component), value in zip(script.controls.items(), script_args):
process_args[name] = value
-script.process(pp, **process_args)
scripts.append((script, process_args))
for script, process_args in scripts:
script.process_firstpass(pp, **process_args)
all_images = [pp]
for script, process_args in scripts:
if shared.state.skipped:
break
shared.state.job = script.name
for single_image in all_images.copy():
if not single_image.disable_processing:
script.process(single_image, **process_args)
for extra_image in single_image.extra_images:
if not isinstance(extra_image, PostprocessedImage):
extra_image = single_image.create_copy(extra_image)
all_images.append(extra_image)
single_image.extra_images.clear()
pp.extra_images = all_images[1:]
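The rewritten run() takes two passes: process_firstpass for every script, then process over a growing image list, so an extra image emitted by script N is only seen by scripts after N. A toy trace (invented script names):

# Toy trace of the processing order in the rewritten run() above.
order = []
all_images = ["img0"]
for script in ["upscale", "gfpgan"]:
    for image in all_images.copy():      # snapshot, like all_images.copy() above
        order.append((script, image))
        if script == "upscale":
            all_images.append(image + "-extra")  # an extra image from "upscale"
print(order)
# -> [('upscale', 'img0'), ('gfpgan', 'img0'), ('gfpgan', 'img0-extra')]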
def create_args_for_run(self, scripts_args):
if not self.ui_created:
+1 -1
@@ -215,7 +215,7 @@ class LoadStateDictOnMeta(ReplaceHelper):
would be on the meta device.
"""
-if state_dict == sd:
if state_dict is sd:
state_dict = {k: v.to(device="meta", dtype=v.dtype) for k, v in state_dict.items()}
original(module, state_dict, strict=strict)
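The `is` fix matters: `==` on dicts of tensors compares values element-wise, and truth-testing a multi-element tensor raises, while object identity is what this check means. A small demonstration:

# Why identity, not equality, for the state-dict check above.
import torch

a = {"w": torch.zeros(2)}
b = {"w": torch.zeros(2)}
try:
    a == b  # dict equality truth-tests `a["w"] == b["w"]`, a 2-element tensor
except RuntimeError as err:
    print("== raised:", err)
print(a is a)  # -> True, cheap and unambiguous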
+26 -6
@@ -5,7 +5,7 @@ from types import MethodType
from modules import devices, sd_hijack_optimizations, shared, script_callbacks, errors, sd_unet, patches
from modules.hypernetworks import hypernetwork
from modules.shared import cmd_opts
-from modules import sd_hijack_clip, sd_hijack_open_clip, sd_hijack_unet, sd_hijack_xlmr, xlmr
from modules import sd_hijack_clip, sd_hijack_open_clip, sd_hijack_unet, sd_hijack_xlmr, xlmr, xlmr_m18
import ldm.modules.attention
import ldm.modules.diffusionmodules.model
@@ -38,8 +38,12 @@ ldm.models.diffusion.ddpm.print = shared.ldm_print
optimizers = []
current_optimizer: sd_hijack_optimizations.SdOptimization = None
-ldm_original_forward = patches.patch(__file__, ldm.modules.diffusionmodules.openaimodel.UNetModel, "forward", sd_unet.UNetModel_forward)
-sgm_original_forward = patches.patch(__file__, sgm.modules.diffusionmodules.openaimodel.UNetModel, "forward", sd_unet.UNetModel_forward)
ldm_patched_forward = sd_unet.create_unet_forward(ldm.modules.diffusionmodules.openaimodel.UNetModel.forward)
ldm_original_forward = patches.patch(__file__, ldm.modules.diffusionmodules.openaimodel.UNetModel, "forward", ldm_patched_forward)
sgm_patched_forward = sd_unet.create_unet_forward(sgm.modules.diffusionmodules.openaimodel.UNetModel.forward)
sgm_original_forward = patches.patch(__file__, sgm.modules.diffusionmodules.openaimodel.UNetModel, "forward", sgm_patched_forward)
def list_optimizers():
new_optimizers = script_callbacks.list_optimizers_callback()
@@ -184,6 +188,20 @@ class StableDiffusionModelHijack:
errors.display(e, "applying cross attention optimization")
undo_optimizations()
def convert_sdxl_to_ssd(self, m):
"""Converts an SDXL model to a Segmind Stable Diffusion model (see https://huggingface.co/segmind/SSD-1B)"""
delattr(m.model.diffusion_model.middle_block, '1')
delattr(m.model.diffusion_model.middle_block, '2')
for i in ['9', '8', '7', '6', '5', '4']:
delattr(m.model.diffusion_model.input_blocks[7][1].transformer_blocks, i)
delattr(m.model.diffusion_model.input_blocks[8][1].transformer_blocks, i)
delattr(m.model.diffusion_model.output_blocks[0][1].transformer_blocks, i)
delattr(m.model.diffusion_model.output_blocks[1][1].transformer_blocks, i)
delattr(m.model.diffusion_model.output_blocks[4][1].transformer_blocks, '1')
delattr(m.model.diffusion_model.output_blocks[5][1].transformer_blocks, '1')
devices.torch_gc()
def hijack(self, m):
conditioner = getattr(m, 'conditioner', None)
if conditioner:
@@ -211,7 +229,7 @@ class StableDiffusionModelHijack:
else:
m.cond_stage_model = conditioner
-if type(m.cond_stage_model) == xlmr.BertSeriesModelWithTransformation:
if type(m.cond_stage_model) == xlmr.BertSeriesModelWithTransformation or type(m.cond_stage_model) == xlmr_m18.BertSeriesModelWithTransformation:
model_embeddings = m.cond_stage_model.roberta.embeddings
model_embeddings.token_embedding = EmbeddingsWithFixes(model_embeddings.word_embeddings, self)
m.cond_stage_model = sd_hijack_xlmr.FrozenXLMREmbedderWithCustomWords(m.cond_stage_model, self)
@@ -242,8 +260,12 @@ class StableDiffusionModelHijack:
self.layers = flatten(m)
import modules.models.diffusion.ddpm_edit
if isinstance(m, ldm.models.diffusion.ddpm.LatentDiffusion):
sd_unet.original_forward = ldm_original_forward
elif isinstance(m, modules.models.diffusion.ddpm_edit.LatentDiffusion):
sd_unet.original_forward = ldm_original_forward
elif isinstance(m, sgm.models.diffusion.DiffusionEngine):
sd_unet.original_forward = sgm_original_forward
else:
@@ -285,8 +307,6 @@ class StableDiffusionModelHijack:
self.layers = None
self.clip = None
-sd_unet.original_forward = None
def apply_circular(self, enable):
if self.circular_enabled == enable:
+8 -4
@@ -11,10 +11,14 @@ class CondFunc:
break
except ImportError:
pass
-for attr_name in func_path[i:-1]:
-resolved_obj = getattr(resolved_obj, attr_name)
-orig_func = getattr(resolved_obj, func_path[-1])
-setattr(resolved_obj, func_path[-1], lambda *args, **kwargs: self(*args, **kwargs))
try:
for attr_name in func_path[i:-1]:
resolved_obj = getattr(resolved_obj, attr_name)
orig_func = getattr(resolved_obj, func_path[-1])
setattr(resolved_obj, func_path[-1], lambda *args, **kwargs: self(*args, **kwargs))
except AttributeError:
print(f"Warning: Failed to resolve {func_path} for CondFunc hijack")
pass
self.__init__(orig_func, sub_func, cond_func)
return lambda *args, **kwargs: self(*args, **kwargs)
def __init__(self, orig_func, sub_func, cond_func):
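For context, CondFunc (only partially shown in this hunk) patches a dotted function path so that sub_func(orig_func, ...) runs whenever cond_func(orig_func, ...) is true. A hypothetical standalone use:

# Hypothetical CondFunc usage: clamp negative inputs to math.sqrt.
import math

CondFunc(
    'math.sqrt',
    lambda orig_func, x: orig_func(abs(x)),  # sub_func: runs when cond matches
    lambda orig_func, x: x < 0,              # cond_func: only negative inputs
)
print(math.sqrt(-4))  # -> 2.0 via the hijack; math.sqrt(4) is untouched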
+76 -26
@@ -1,7 +1,6 @@
import collections
import os.path
import sys
-import gc
import threading
import torch
@@ -231,15 +230,19 @@ def select_checkpoint():
return checkpoint_info
-checkpoint_dict_replacements = {
checkpoint_dict_replacements_sd1 = {
'cond_stage_model.transformer.embeddings.': 'cond_stage_model.transformer.text_model.embeddings.',
'cond_stage_model.transformer.encoder.': 'cond_stage_model.transformer.text_model.encoder.',
'cond_stage_model.transformer.final_layer_norm.': 'cond_stage_model.transformer.text_model.final_layer_norm.',
}
checkpoint_dict_replacements_sd2_turbo = { # Converts SD 2.1 Turbo from SGM to LDM format.
'conditioner.embedders.0.': 'cond_stage_model.',
}
-def transform_checkpoint_dict_key(k):
-for text, replacement in checkpoint_dict_replacements.items():
def transform_checkpoint_dict_key(k, replacements):
for text, replacement in replacements.items():
if k.startswith(text):
k = replacement + k[len(text):]
@@ -250,9 +253,14 @@ def get_state_dict_from_checkpoint(pl_sd):
pl_sd = pl_sd.pop("state_dict", pl_sd)
pl_sd.pop("state_dict", None)
is_sd2_turbo = 'conditioner.embedders.0.model.ln_final.weight' in pl_sd and pl_sd['conditioner.embedders.0.model.ln_final.weight'].size()[0] == 1024
sd = {}
for k, v in pl_sd.items():
-new_key = transform_checkpoint_dict_key(k)
if is_sd2_turbo:
new_key = transform_checkpoint_dict_key(k, checkpoint_dict_replacements_sd2_turbo)
else:
new_key = transform_checkpoint_dict_key(k, checkpoint_dict_replacements_sd1)
if new_key is not None:
sd[new_key] = v
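A quick sketch of the prefix rewrite those replacement tables drive (the sample key is invented; the helper mirrors the visible lines above):

# Standalone sketch of transform_checkpoint_dict_key's prefix rewrite.
def transform_key(k, replacements):
    for text, replacement in replacements.items():
        if k.startswith(text):
            k = replacement + k[len(text):]
    return k

sd2_turbo = {'conditioner.embedders.0.': 'cond_stage_model.'}
print(transform_key('conditioner.embedders.0.model.ln_final.weight', sd2_turbo))
# -> cond_stage_model.model.ln_final.weight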
@@ -340,10 +348,28 @@ class SkipWritingToConfig:
SkipWritingToConfig.skip = self.previous
def check_fp8(model):
if model is None:
return None
if devices.get_optimal_device_name() == "mps":
enable_fp8 = False
elif shared.opts.fp8_storage == "Enable":
enable_fp8 = True
elif getattr(model, "is_sdxl", False) and shared.opts.fp8_storage == "Enable for SDXL":
enable_fp8 = True
else:
enable_fp8 = False
return enable_fp8
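check_fp8 reduces to a short decision chain; a pure-function sketch of the same branches (arguments passed explicitly here for illustration):

# Pure-function sketch of check_fp8's branches (argument names invented).
def check_fp8_sketch(is_sdxl, device_name, fp8_storage):
    if device_name == "mps":                       # fp8 unsupported on MPS
        return False
    if fp8_storage == "Enable":
        return True
    if is_sdxl and fp8_storage == "Enable for SDXL":
        return True
    return False

print(check_fp8_sketch(True, "cuda", "Enable for SDXL"))   # -> True
print(check_fp8_sketch(False, "cuda", "Enable for SDXL"))  # -> False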
def load_model_weights(model, checkpoint_info: CheckpointInfo, state_dict, timer):
sd_model_hash = checkpoint_info.calculate_shorthash()
timer.record("calculate hash")
if devices.fp8:
# prevent the model from loading the state dict in fp8
model.half()
if not SkipWritingToConfig.skip:
shared.opts.data["sd_model_checkpoint"] = checkpoint_info.title
@@ -353,16 +379,19 @@ def load_model_weights(model, checkpoint_info: CheckpointInfo, state_dict, timer
model.is_sdxl = hasattr(model, 'conditioner')
model.is_sd2 = not model.is_sdxl and hasattr(model.cond_stage_model, 'model')
model.is_sd1 = not model.is_sdxl and not model.is_sd2
model.is_ssd = model.is_sdxl and 'model.diffusion_model.middle_block.1.transformer_blocks.0.attn1.to_q.weight' not in state_dict.keys()
if model.is_sdxl:
sd_models_xl.extend_sdxl(model)
-model.load_state_dict(state_dict, strict=False)
-timer.record("apply weights to model")
if model.is_ssd:
sd_hijack.model_hijack.convert_sdxl_to_ssd(model)
if shared.opts.sd_checkpoint_cache > 0:
# cache newly loaded model
-checkpoints_loaded[checkpoint_info] = state_dict
checkpoints_loaded[checkpoint_info] = state_dict.copy()
model.load_state_dict(state_dict, strict=False)
timer.record("apply weights to model")
del state_dict
@@ -372,6 +401,7 @@ def load_model_weights(model, checkpoint_info: CheckpointInfo, state_dict, timer
if shared.cmd_opts.no_half:
model.float()
model.alphas_cumprod_original = model.alphas_cumprod
devices.dtype_unet = torch.float32
timer.record("apply float()")
else:
@@ -385,7 +415,11 @@ def load_model_weights(model, checkpoint_info: CheckpointInfo, state_dict, timer
if shared.cmd_opts.upcast_sampling and depth_model:
model.depth_model = None
alphas_cumprod = model.alphas_cumprod
model.alphas_cumprod = None
model.half()
model.alphas_cumprod = alphas_cumprod
model.alphas_cumprod_original = alphas_cumprod
model.first_stage_model = vae
if depth_model:
model.depth_model = depth_model
@@ -393,6 +427,28 @@ def load_model_weights(model, checkpoint_info: CheckpointInfo, state_dict, timer
devices.dtype_unet = torch.float16
timer.record("apply half()")
for module in model.modules():
if hasattr(module, 'fp16_weight'):
del module.fp16_weight
if hasattr(module, 'fp16_bias'):
del module.fp16_bias
if check_fp8(model):
devices.fp8 = True
first_stage = model.first_stage_model
model.first_stage_model = None
for module in model.modules():
if isinstance(module, (torch.nn.Conv2d, torch.nn.Linear)):
if shared.opts.cache_fp16_weight:
module.fp16_weight = module.weight.data.clone().cpu().half()
if module.bias is not None:
module.fp16_bias = module.bias.data.clone().cpu().half()
module.to(torch.float8_e4m3fn)
model.first_stage_model = first_stage
timer.record("apply fp8")
else:
devices.fp8 = False
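The fp8 loop above caches an fp16 CPU copy of each Linear/Conv weight, then stores the live weights as float8_e4m3fn. A minimal sketch of the idea, assuming a torch build with float8 support:

# Minimal fp8-storage sketch: cache fp16 copies, store weights in float8.
import torch

linear = torch.nn.Linear(4, 4)
linear.fp16_weight = linear.weight.data.clone().cpu().half()   # restore copy
if linear.bias is not None:
    linear.fp16_bias = linear.bias.data.clone().cpu().half()
linear.to(torch.float8_e4m3fn)
print(linear.weight.dtype)  # -> torch.float8_e4m3fn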
devices.unet_needs_upcast = shared.cmd_opts.upcast_sampling and devices.dtype == torch.float16 and devices.dtype_unet == torch.float16
model.first_stage_model.to(devices.dtype_vae)
@@ -640,6 +696,7 @@ def load_model(checkpoint_info=None, already_loaded_state_dict=None):
else:
weight_dtype_conversion = {
'first_stage_model': None,
'alphas_cumprod': None,
'': torch.float16,
}
@@ -735,7 +792,7 @@ def reuse_model_from_already_loaded(sd_model, checkpoint_info, timer):
return None
-def reload_model_weights(sd_model=None, info=None):
def reload_model_weights(sd_model=None, info=None, forced_reload=False):
checkpoint_info = info or select_checkpoint()
timer = Timer()
@@ -747,11 +804,14 @@ def reload_model_weights(sd_model=None, info=None):
current_checkpoint_info = None
else:
current_checkpoint_info = sd_model.sd_checkpoint_info
-if sd_model.sd_model_checkpoint == checkpoint_info.filename:
if check_fp8(sd_model) != devices.fp8:
# load from state dict again to prevent extra numerical errors
forced_reload = True
elif sd_model.sd_model_checkpoint == checkpoint_info.filename and not forced_reload:
return sd_model
sd_model = reuse_model_from_already_loaded(sd_model, checkpoint_info, timer)
-if sd_model is not None and sd_model.sd_checkpoint_info.filename == checkpoint_info.filename:
if not forced_reload and sd_model is not None and sd_model.sd_checkpoint_info.filename == checkpoint_info.filename:
return sd_model
if sd_model is not None:
@@ -782,13 +842,13 @@ def reload_model_weights(sd_model=None, info=None):
sd_hijack.model_hijack.hijack(sd_model)
timer.record("hijack")
-script_callbacks.model_loaded_callback(sd_model)
-timer.record("script callbacks")
if not sd_model.lowvram:
sd_model.to(devices.device)
timer.record("move model to device")
script_callbacks.model_loaded_callback(sd_model)
timer.record("script callbacks")
print(f"Weights loaded in {timer.summary()}.")
model_data.set_sd_model(sd_model)
@@ -798,17 +858,7 @@ def reload_model_weights(sd_model=None, info=None):
def unload_model_weights(sd_model=None, info=None):
-timer = Timer()
-if model_data.sd_model:
-model_data.sd_model.to(devices.cpu)
-sd_hijack.model_hijack.undo_hijack(model_data.sd_model)
-model_data.sd_model = None
-sd_model = None
-gc.collect()
-devices.torch_gc()
-print(f"Unloaded weights {timer.summary()}.")
send_model_to_cpu(sd_model or shared.sd_model)
return sd_model
+9 -2
@@ -15,13 +15,14 @@ config_sd2v = os.path.join(sd_repo_configs_path, "v2-inference-v.yaml")
config_sd2_inpainting = os.path.join(sd_repo_configs_path, "v2-inpainting-inference.yaml")
config_sdxl = os.path.join(sd_xl_repo_configs_path, "sd_xl_base.yaml")
config_sdxl_refiner = os.path.join(sd_xl_repo_configs_path, "sd_xl_refiner.yaml")
config_sdxl_inpainting = os.path.join(sd_configs_path, "sd_xl_inpaint.yaml")
config_depth_model = os.path.join(sd_repo_configs_path, "v2-midas-inference.yaml")
config_unclip = os.path.join(sd_repo_configs_path, "v2-1-stable-unclip-l-inference.yaml")
config_unopenclip = os.path.join(sd_repo_configs_path, "v2-1-stable-unclip-h-inference.yaml")
config_inpainting = os.path.join(sd_configs_path, "v1-inpainting-inference.yaml")
config_instruct_pix2pix = os.path.join(sd_configs_path, "instruct-pix2pix.yaml")
config_alt_diffusion = os.path.join(sd_configs_path, "alt-diffusion-inference.yaml")
config_alt_diffusion_m18 = os.path.join(sd_configs_path, "alt-diffusion-m18-inference.yaml")
def is_using_v_parameterization_for_sd2(state_dict):
"""
@@ -71,7 +72,10 @@ def guess_model_config_from_state_dict(sd, filename):
sd2_variations_weight = sd.get('embedder.model.ln_final.weight', None)
if sd.get('conditioner.embedders.1.model.ln_final.weight', None) is not None:
-return config_sdxl
if diffusion_model_input.shape[1] == 9:
return config_sdxl_inpainting
else:
return config_sdxl
if sd.get('conditioner.embedders.0.model.ln_final.weight', None) is not None:
return config_sdxl_refiner
elif sd.get('depth_model.model.pretrained.act_postprocess3.0.project.0.bias', None) is not None:
@@ -95,7 +99,10 @@ def guess_model_config_from_state_dict(sd, filename):
if diffusion_model_input.shape[1] == 8:
return config_instruct_pix2pix
if sd.get('cond_stage_model.roberta.embeddings.word_embeddings.weight', None) is not None:
if sd.get('cond_stage_model.transformation.weight').size()[0] == 1024:
return config_alt_diffusion_m18
return config_alt_diffusion
return config_default
+4 -1
@@ -22,7 +22,10 @@ class WebuiSdModel(LatentDiffusion):
"""structure with additional information about the file with model's weights""" """structure with additional information about the file with model's weights"""
is_sdxl: bool is_sdxl: bool
"""True if the model's architecture is SDXL""" """True if the model's architecture is SDXL or SSD"""
is_ssd: bool
"""True if the model is SSD"""
is_sd2: bool is_sd2: bool
"""True if the model's architecture is SD 2.x""" """True if the model's architecture is SD 2.x"""
+9 -2
@@ -6,6 +6,7 @@ import sgm.models.diffusion
import sgm.modules.diffusionmodules.denoiser_scaling
import sgm.modules.diffusionmodules.discretizer
from modules import devices, shared, prompt_parser
from modules import torch_utils
def get_learned_conditioning(self: sgm.models.diffusion.DiffusionEngine, batch: prompt_parser.SdConditioning | list[str]):
@@ -34,6 +35,12 @@ def get_learned_conditioning(self: sgm.models.diffusion.DiffusionEngine, batch:
def apply_model(self: sgm.models.diffusion.DiffusionEngine, x, t, cond):
sd = self.model.state_dict()
diffusion_model_input = sd.get('diffusion_model.input_blocks.0.0.weight', None)
if diffusion_model_input is not None:
if diffusion_model_input.shape[1] == 9:
x = torch.cat([x] + cond['c_concat'], dim=1)
return self.model(x, t, cond)
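The shape[1] == 9 test above identifies an inpainting UNet: 4 latent channels plus 5 conditioning channels (mask and masked-image latent). A toy shape check:

# Toy shape check for the inpainting concat above.
import torch

x = torch.randn(1, 4, 64, 64)           # latent being denoised
c_concat = [torch.randn(1, 5, 64, 64)]  # 1 mask + 4 masked-image latent channels
x_in = torch.cat([x] + c_concat, dim=1)
print(x_in.shape)  # -> torch.Size([1, 9, 64, 64])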
@@ -84,7 +91,7 @@ sgm.modules.GeneralConditioner.get_target_prompt_token_count = get_target_prompt
def extend_sdxl(model):
"""this adds a bunch of parameters to make SDXL model look a bit more like SD1.5 to the rest of the codebase."""
-dtype = next(model.model.diffusion_model.parameters()).dtype
dtype = torch_utils.get_param(model.model.diffusion_model).dtype
model.model.diffusion_model.dtype = dtype
model.model.conditioning_key = 'crossattn'
model.cond_stage_key = 'txt'
@@ -93,7 +100,7 @@ def extend_sdxl(model):
model.parameterization = "v" if isinstance(model.denoiser.scaling, sgm.modules.diffusionmodules.denoiser_scaling.VScaling) else "eps"
discretization = sgm.modules.diffusionmodules.discretizer.LegacyDDPMDiscretization()
-model.alphas_cumprod = torch.asarray(discretization.alphas_cumprod, device=devices.device, dtype=dtype)
model.alphas_cumprod = torch.asarray(discretization.alphas_cumprod, device=devices.device, dtype=torch.float32)
model.conditioner.wrapped = torch.nn.Module()
+2 -1
@@ -1,4 +1,4 @@
-from modules import sd_samplers_kdiffusion, sd_samplers_timesteps, shared
from modules import sd_samplers_kdiffusion, sd_samplers_timesteps, sd_samplers_lcm, shared
# imports for functions that previously were here and are used by other modules
from modules.sd_samplers_common import samples_to_image_grid, sample_to_image # noqa: F401
@@ -6,6 +6,7 @@ from modules.sd_samplers_common import samples_to_image_grid, sample_to_image #
all_samplers = [
*sd_samplers_kdiffusion.samplers_data_k_diffusion,
*sd_samplers_timesteps.samplers_data_timesteps,
*sd_samplers_lcm.samplers_data_lcm,
]
all_samplers_map = {x.name: x for x in all_samplers}
+20 -3
@@ -56,6 +56,9 @@ class CFGDenoiser(torch.nn.Module):
self.sampler = sampler
self.model_wrap = None
self.p = None
# NOTE: masking before denoising can cause the original latents to be oversmoothed
# as the original latents do not have noise
self.mask_before_denoising = False
@property
@@ -105,8 +108,21 @@ class CFGDenoiser(torch.nn.Module):
assert not is_edit_model or all(len(conds) == 1 for conds in conds_list), "AND is not supported for InstructPix2Pix checkpoint (unless using Image CFG scale = 1.0)"
# If we use masks, blending between the denoised and original latent images occurs here.
def apply_blend(current_latent):
blended_latent = current_latent * self.nmask + self.init_latent * self.mask
if self.p.scripts is not None:
from modules import scripts
mba = scripts.MaskBlendArgs(current_latent, self.nmask, self.init_latent, self.mask, blended_latent, denoiser=self, sigma=sigma)
self.p.scripts.on_mask_blend(self.p, mba)
blended_latent = mba.blended_latent
return blended_latent
# Blend in the original latents (before)
if self.mask_before_denoising and self.mask is not None:
-x = self.init_latent * self.mask + self.nmask * x
x = apply_blend(x)
batch_size = len(conds_list)
repeats = [len(conds_list[i]) for i in range(batch_size)]
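apply_blend is a plain linear interpolation in latent space, with the script hook allowed to replace the result. The core math with toy tensors (nmask = 1 - mask):

# Toy version of the latent blend: keep init_latent where mask=1,
# keep the current (denoised) latent where nmask=1.
import torch

init_latent = torch.zeros(1, 4, 2, 2)
current = torch.ones(1, 4, 2, 2)
mask = torch.tensor([[1.0, 0.0], [0.0, 0.0]]).expand(1, 4, 2, 2)
nmask = 1.0 - mask
blended = current * nmask + init_latent * mask
print(blended[0, 0])  # -> tensor([[0., 1.], [1., 1.]])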
@@ -130,7 +146,7 @@ class CFGDenoiser(torch.nn.Module):
sigma_in = torch.cat([torch.stack([sigma[i] for _ in range(n)]) for i, n in enumerate(repeats)] + [sigma] + [sigma])
image_cond_in = torch.cat([torch.stack([image_cond[i] for _ in range(n)]) for i, n in enumerate(repeats)] + [image_uncond] + [torch.zeros_like(self.init_latent)])
-denoiser_params = CFGDenoiserParams(x_in, image_cond_in, sigma_in, state.sampling_step, state.sampling_steps, tensor, uncond)
denoiser_params = CFGDenoiserParams(x_in, image_cond_in, sigma_in, state.sampling_step, state.sampling_steps, tensor, uncond, self)
cfg_denoiser_callback(denoiser_params)
x_in = denoiser_params.x
image_cond_in = denoiser_params.image_cond
@@ -207,8 +223,9 @@ class CFGDenoiser(torch.nn.Module):
else:
denoised = self.combine_denoised(x_out, conds_list, uncond, cond_scale)
# Blend in the original latents (after)
if not self.mask_before_denoising and self.mask is not None:
-denoised = self.init_latent * self.mask + self.nmask * denoised
denoised = apply_blend(denoised)
self.sampler.last_latent = self.get_pred_x0(torch.cat([x_in[i:i + 1] for i in denoised_image_indexes]), torch.cat([x_out[i:i + 1] for i in denoised_image_indexes]), sigma)
+1 -1
@@ -60,7 +60,7 @@ def restart_sampler(model, x, sigmas, extra_args=None, callback=None, disable=No
sigma_restart = get_sigmas_karras(restart_steps, sigmas[min_idx].item(), sigmas[max_idx].item(), device=sigmas.device)[:-1]
while restart_times > 0:
restart_times -= 1
-step_list.extend([(old_sigma, new_sigma) for (old_sigma, new_sigma) in zip(sigma_restart[:-1], sigma_restart[1:])])
step_list.extend(zip(sigma_restart[:-1], sigma_restart[1:]))
last_sigma = None
for old_sigma, new_sigma in tqdm.tqdm(step_list, disable=disable):
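The simplified line builds the same consecutive (old_sigma, new_sigma) pairs without the redundant list comprehension:

# Consecutive sigma pairs, as produced by the simplified zip above.
sigma_restart = [10.0, 5.0, 2.0, 1.0]
print(list(zip(sigma_restart[:-1], sigma_restart[1:])))
# -> [(10.0, 5.0), (5.0, 2.0), (2.0, 1.0)]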