Compare commits
253 Commits
v1.5.2-RC
...
refiner_alt
| Author | SHA1 | Date | |
|---|---|---|---|
| f1d7c07a5a | |||
| 686598387f | |||
| 3f82820612 | |||
| 6c7b6ecb81 | |||
| 57e8a11d17 | |||
| f9950da3e3 | |||
| aa42c0ff8e | |||
| 06da34d47a | |||
| 5cae08f2c3 | |||
| 8f31b139b8 | |||
| ce4be668fe | |||
| 2e8b40004e | |||
| 1e8482356c | |||
| e9c591b101 | |||
| ee96a6a588 | |||
| 92b99f3273 | |||
| ee75416e3e | |||
| d86d12e911 | |||
| 2844d9597b | |||
| dd1e2726f3 | |||
| f18a032190 | |||
| 9cbde6c9fd | |||
| f4e4992a4a | |||
| 31506f0771 | |||
| 85c2c138d2 | |||
| c11104fed5 | |||
| dfc01c68cd | |||
| 496cef956b | |||
| b315c20756 | |||
| c6278c15a8 | |||
| 0a0a6b2a4d | |||
| 1f7fc4d7a3 | |||
| 8ece321df3 | |||
| 1d7dcdb6c3 | |||
| 60183eebc3 | |||
| 36ca80d004 | |||
| 3f451f3042 | |||
| c980dca234 | |||
| f879cac1e7 | |||
| ad510b2cd3 | |||
| c74c708ed8 | |||
| e053e21af6 | |||
| 7a64601428 | |||
| b85ec2b9b6 | |||
| d56a9cfe6a | |||
| a32f270a47 | |||
| 8197f24dbc | |||
| ef1698fd6d | |||
| c613416af3 | |||
| 22ecb78b51 | |||
| a6b245e46f | |||
| 0ae2767ae6 | |||
| e64263653a | |||
| d2b842ce07 | |||
| d8371d0b3c | |||
| e7140a36c0 | |||
| aa744cadc8 | |||
| 63cac3c3cc | |||
| bcff763b6e | |||
| 9ac2989edd | |||
| 1d60a609a9 | |||
| 4560176640 | |||
| 31a9966b9d | |||
| c57cb6e89c | |||
| b6596cdb19 | |||
| 9213d5cb3b | |||
| 682ff8936d | |||
| f08a69e629 | |||
| fadbab3781 | |||
| daee41e0d6 | |||
| 21000f13a1 | |||
| a0e74c4db4 | |||
| 073342c887 | |||
| 6346d8eeaa | |||
| 094c416a80 | |||
| 99f5f8e76b | |||
| cd4e053e5e | |||
| 2dc2bc4ab5 | |||
| e219211ff6 | |||
| df9fd1d3ae | |||
| 2e613a6ffc | |||
| f5994e84a2 | |||
| c93857922a | |||
| 6391128b41 | |||
| 7c5480eb96 | |||
| 67312653d7 | |||
| e81b431701 | |||
| 695300929a | |||
| 82b415c9c1 | |||
| d89a915b74 | |||
| ac8dfd9386 | |||
| 1f6bfdea80 | |||
| 70e66e81e5 | |||
| f0c1063a70 | |||
| 09165916fa | |||
| c134a48016 | |||
| 75336dfc84 | |||
| 3f9e09a615 | |||
| 01486f6896 | |||
| 56c3f94ba3 | |||
| 073c0ebba3 | |||
| 362789a379 | |||
| 7f1d087cba | |||
| 3bd2c68eb4 | |||
| 71efc5bda8 | |||
| f4d9297127 | |||
| 220e298417 | |||
| f7813fad1c | |||
| 8b37734244 | |||
| bbfff771d7 | |||
| 24f21583cd | |||
| 09c1be9674 | |||
| af528552d6 | |||
| 20549a50cb | |||
| 8e840e1519 | |||
| f56a309432 | |||
| 0904df84e2 | |||
| fca42949a3 | |||
| 84b6fcd02c | |||
| ccb9233934 | |||
| 10ff071e33 | |||
| 390bffa81b | |||
| 0c9b1e7969 | |||
| 6a0d498c8e | |||
| 401ba1b879 | |||
| 07be13caa3 | |||
| 6d3a0c9506 | |||
| 0042954490 | |||
| 8a4149accc | |||
| b98fa1c397 | |||
| c6b826d796 | |||
| 2860c3be3e | |||
| 4b43480fe8 | |||
| 151b8ed3a6 | |||
| b235022c61 | |||
| c10633f93a | |||
| 0d577aba26 | |||
| c09bc2c608 | |||
| fb87a05fe8 | |||
| 4d9b096663 | |||
| 29d7e31d89 | |||
| dca121e903 | |||
| 0af4127fd1 | |||
| a1eb49627a | |||
| 02038036ff | |||
| f60d9fbe29 | |||
| cc53db6652 | |||
| a64fbe8928 | |||
| eec540b227 | |||
| 77761e7bad | |||
| 40cd59207b | |||
| 3bca90b249 | |||
| 085c903229 | |||
| 8a40e30d08 | |||
| 63a8861c19 | |||
| fb44838176 | |||
| 53ccdefc01 | |||
| 9857537053 | |||
| b95a41ad72 | |||
| 6f0abbb71a | |||
| 4ca9f70b59 | |||
| e18fc29bbf | |||
| 79d6e9cd32 | |||
| aefe1325df | |||
| 11dc92dc0a | |||
| bdeb44aeb2 | |||
| e1323fc1b7 | |||
| 3ac950248d | |||
| bef40851af | |||
| 9a52a30d2f | |||
| fc163218c4 | |||
| 19ac0adf03 | |||
| ac81c1dd1f | |||
| 6cc5a886ae | |||
| 9cbf3461f7 | |||
| 25004d4eee | |||
| 91a131aa6c | |||
| 0cb9711a15 | |||
| 89e6dfff71 | |||
| 8284ebd94c | |||
| 187323a606 | |||
| deed8439d5 | |||
| 6305632493 | |||
| 246d1f1f70 | |||
| ca6f90dc6d | |||
| 835a7dbf0e | |||
| 225eb1b1a0 | |||
| b8a903efbe | |||
| 7c22bbd3ad | |||
| 13e371af73 | |||
| ae36e0899f | |||
| b73c405013 | |||
| 8de6d3ff77 | |||
| fd43558586 | |||
| d0bf509fa1 | |||
| d6ec08ba89 | |||
| 65bf3ba260 | |||
| bed598ce7f | |||
| b1a16a298c | |||
| fee593a07f | |||
| fc8e23dec5 | |||
| a68f469030 | |||
| f7c0a963f1 | |||
| 5b06607476 | |||
| 6b68b59032 | |||
| 0a89cd1a58 | |||
| ca45ff1ae6 | |||
| 1cbfafafd2 | |||
| f451994053 | |||
| ec83db8978 | |||
| a8d4213317 | |||
| 0615b3c532 | |||
| 2d635c0192 | |||
| 88a3e1d306 | |||
| 0674fabd0d | |||
| c76a30af41 | |||
| 3c26734d60 | |||
| 2a7e34fe79 | |||
| 90eb731ff1 | |||
| 491d42bb1c | |||
| 45c0f58dc6 | |||
| 1fe2dcaa2a | |||
| 075934a944 | |||
| ed4d7912c7 | |||
| 16eddc622e | |||
| bc91f15ed3 | |||
| 118529a6dc | |||
| 33694baea1 | |||
| f873890298 | |||
| 128d59c9cc | |||
| 2f57a559ac | |||
| 2f98f7c924 | |||
| 6233268964 | |||
| ddbf4a73f5 | |||
| 4bf64976c1 | |||
| 5677296d1b | |||
| cb75734896 | |||
| fc3bdf8c11 | |||
| 0fae47e974 | |||
| c278e60131 | |||
| 3c570421d3 | |||
| 7bb0fbed13 | |||
| 37e048a7e2 | |||
| 15a94d6cf7 | |||
| 40a18d38a8 | |||
| 952effa8b1 | |||
| 0dcf6436a8 | |||
| 95c5c4d64e | |||
| 543ea5730b | |||
| 643836007f | |||
| 24bad5dc7b | |||
| 57d61de25c | |||
| 5ef7590324 |
@@ -87,5 +87,9 @@ module.exports = {
|
||||
modalNextImage: "readonly",
|
||||
// token-counters.js
|
||||
setupTokenCounters: "readonly",
|
||||
// localStorage.js
|
||||
localSet: "readonly",
|
||||
localGet: "readonly",
|
||||
localRemove: "readonly"
|
||||
}
|
||||
};
|
||||
|
||||
@@ -88,7 +88,7 @@ A browser interface based on Gradio library for Stable Diffusion.
|
||||
- [Alt-Diffusion](https://arxiv.org/abs/2211.06679) support - see [wiki](https://github.com/AUTOMATIC1111/stable-diffusion-webui/wiki/Features#alt-diffusion) for instructions
|
||||
- Now without any bad letters!
|
||||
- Load checkpoints in safetensors format
|
||||
- Eased resolution restriction: generated image's domension must be a multiple of 8 rather than 64
|
||||
- Eased resolution restriction: generated image's dimension must be a multiple of 8 rather than 64
|
||||
- Now with a license!
|
||||
- Reorder elements in the UI from settings screen
|
||||
|
||||
@@ -115,7 +115,7 @@ Alternatively, use online services (like Google Colab):
|
||||
1. Install the dependencies:
|
||||
```bash
|
||||
# Debian-based:
|
||||
sudo apt install wget git python3 python3-venv
|
||||
sudo apt install wget git python3 python3-venv libgl1 libglib2.0-0
|
||||
# Red Hat-based:
|
||||
sudo dnf install wget git python3
|
||||
# Arch-based:
|
||||
@@ -123,7 +123,7 @@ sudo pacman -S wget git python3
|
||||
```
|
||||
2. Navigate to the directory you would like the webui to be installed and execute the following command:
|
||||
```bash
|
||||
bash <(wget -qO- https://raw.githubusercontent.com/AUTOMATIC1111/stable-diffusion-webui/master/webui.sh)
|
||||
wget -q https://raw.githubusercontent.com/AUTOMATIC1111/stable-diffusion-webui/master/webui.sh
|
||||
```
|
||||
3. Run `webui.sh`.
|
||||
4. Check `webui-user.sh` for options.
|
||||
@@ -169,5 +169,6 @@ Licenses for borrowed code can be found in `Settings -> Licenses` screen, and al
|
||||
- UniPC sampler - Wenliang Zhao - https://github.com/wl-zhao/UniPC
|
||||
- TAESD - Ollin Boer Bohan - https://github.com/madebyollin/taesd
|
||||
- LyCORIS - KohakuBlueleaf
|
||||
- Restart sampling - lambertae - https://github.com/Newbeeer/diffusion_restart_sampling
|
||||
- Initial Gradio script - posted on 4chan by an Anonymous user. Thank you Anonymous user.
|
||||
- (You)
|
||||
|
||||
@@ -167,7 +167,7 @@ class LoraUserMetadataEditor(ui_extra_networks_user_metadata.UserMetadataEditor)
|
||||
random_prompt = gr.Textbox(label='Random prompt', lines=4, max_lines=4, interactive=False)
|
||||
|
||||
with gr.Column(scale=1, min_width=120):
|
||||
generate_random_prompt = gr.Button('Generate').style(full_width=True, size="lg")
|
||||
generate_random_prompt = gr.Button('Generate', size="lg", scale=1)
|
||||
|
||||
self.edit_notes = gr.TextArea(label='Notes', lines=4)
|
||||
|
||||
|
||||
@@ -43,6 +43,6 @@ class ExtraOptionsSection(scripts.Script):
|
||||
|
||||
|
||||
shared.options_templates.update(shared.options_section(('ui', "User interface"), {
|
||||
"extra_options": shared.OptionInfo([], "Options in main UI", ui_components.DropdownMulti, lambda: {"choices": list(shared.opts.data_labels.keys())}).js("info", "settingsHintsShowQuicksettings").info("setting entries that also appear in txt2img/img2img interfaces").needs_restart(),
|
||||
"extra_options_accordion": shared.OptionInfo(False, "Place options in main UI into an accordion")
|
||||
"extra_options": shared.OptionInfo([], "Options in main UI", ui_components.DropdownMulti, lambda: {"choices": list(shared.opts.data_labels.keys())}).js("info", "settingsHintsShowQuicksettings").info("setting entries that also appear in txt2img/img2img interfaces").needs_reload_ui(),
|
||||
"extra_options_accordion": shared.OptionInfo(False, "Place options in main UI into an accordion").needs_restart()
|
||||
}))
|
||||
|
||||
@@ -1,20 +1,38 @@
|
||||
function toggleCss(key, css, enable) {
|
||||
var style = document.getElementById(key);
|
||||
if (enable && !style) {
|
||||
style = document.createElement('style');
|
||||
style.id = key;
|
||||
style.type = 'text/css';
|
||||
document.head.appendChild(style);
|
||||
}
|
||||
if (style && !enable) {
|
||||
document.head.removeChild(style);
|
||||
}
|
||||
if (style) {
|
||||
style.innerHTML == '';
|
||||
style.appendChild(document.createTextNode(css));
|
||||
}
|
||||
}
|
||||
|
||||
function setupExtraNetworksForTab(tabname) {
|
||||
gradioApp().querySelector('#' + tabname + '_extra_tabs').classList.add('extra-networks');
|
||||
|
||||
var tabs = gradioApp().querySelector('#' + tabname + '_extra_tabs > div');
|
||||
var search = gradioApp().querySelector('#' + tabname + '_extra_search textarea');
|
||||
var searchDiv = gradioApp().getElementById(tabname + '_extra_search');
|
||||
var search = searchDiv.querySelector('textarea');
|
||||
var sort = gradioApp().getElementById(tabname + '_extra_sort');
|
||||
var sortOrder = gradioApp().getElementById(tabname + '_extra_sortorder');
|
||||
var refresh = gradioApp().getElementById(tabname + '_extra_refresh');
|
||||
var showDirsDiv = gradioApp().getElementById(tabname + '_extra_show_dirs');
|
||||
var showDirs = gradioApp().querySelector('#' + tabname + '_extra_show_dirs input');
|
||||
|
||||
search.classList.add('search');
|
||||
sort.classList.add('sort');
|
||||
sortOrder.classList.add('sortorder');
|
||||
sort.dataset.sortkey = 'sortDefault';
|
||||
tabs.appendChild(search);
|
||||
tabs.appendChild(searchDiv);
|
||||
tabs.appendChild(sort);
|
||||
tabs.appendChild(sortOrder);
|
||||
tabs.appendChild(refresh);
|
||||
tabs.appendChild(showDirsDiv);
|
||||
|
||||
var applyFilter = function() {
|
||||
var searchTerm = search.value.toLowerCase();
|
||||
@@ -80,6 +98,15 @@ function setupExtraNetworksForTab(tabname) {
|
||||
});
|
||||
|
||||
extraNetworksApplyFilter[tabname] = applyFilter;
|
||||
|
||||
var showDirsUpdate = function() {
|
||||
var css = '#' + tabname + '_extra_tabs .extra-network-subdirs { display: none; }';
|
||||
toggleCss(tabname + '_extra_show_dirs_style', css, !showDirs.checked);
|
||||
localSet('extra-networks-show-dirs', showDirs.checked ? 1 : 0);
|
||||
};
|
||||
showDirs.checked = localGet('extra-networks-show-dirs', 1) == 1;
|
||||
showDirs.addEventListener("change", showDirsUpdate);
|
||||
showDirsUpdate();
|
||||
}
|
||||
|
||||
function applyExtraNetworkFilter(tabname) {
|
||||
@@ -179,7 +206,7 @@ function saveCardPreview(event, tabname, filename) {
|
||||
}
|
||||
|
||||
function extraNetworksSearchButton(tabs_id, event) {
|
||||
var searchTextarea = gradioApp().querySelector("#" + tabs_id + ' > div > textarea');
|
||||
var searchTextarea = gradioApp().querySelector("#" + tabs_id + ' > label > textarea');
|
||||
var button = event.target;
|
||||
var text = button.classList.contains("search-all") ? "" : button.textContent.trim();
|
||||
|
||||
|
||||
@@ -190,3 +190,14 @@ onUiUpdate(function(mutationRecords) {
|
||||
tooltipCheckTimer = setTimeout(processTooltipCheckNodes, 1000);
|
||||
}
|
||||
});
|
||||
|
||||
onUiLoaded(function() {
|
||||
for (var comp of window.gradio_config.components) {
|
||||
if (comp.props.webui_tooltip && comp.props.elem_id) {
|
||||
var elem = gradioApp().getElementById(comp.props.elem_id);
|
||||
if (elem) {
|
||||
elem.title = comp.props.webui_tooltip;
|
||||
}
|
||||
}
|
||||
}
|
||||
});
|
||||
|
||||
@@ -0,0 +1,26 @@
|
||||
|
||||
function localSet(k, v) {
|
||||
try {
|
||||
localStorage.setItem(k, v);
|
||||
} catch (e) {
|
||||
console.warn(`Failed to save ${k} to localStorage: ${e}`);
|
||||
}
|
||||
}
|
||||
|
||||
function localGet(k, def) {
|
||||
try {
|
||||
return localStorage.getItem(k);
|
||||
} catch (e) {
|
||||
console.warn(`Failed to load ${k} from localStorage: ${e}`);
|
||||
}
|
||||
|
||||
return def;
|
||||
}
|
||||
|
||||
function localRemove(k) {
|
||||
try {
|
||||
return localStorage.removeItem(k);
|
||||
} catch (e) {
|
||||
console.warn(`Failed to remove ${k} from localStorage: ${e}`);
|
||||
}
|
||||
}
|
||||
@@ -11,11 +11,11 @@ var ignore_ids_for_localization = {
|
||||
train_hypernetwork: 'OPTION',
|
||||
txt2img_styles: 'OPTION',
|
||||
img2img_styles: 'OPTION',
|
||||
setting_random_artist_categories: 'SPAN',
|
||||
setting_face_restoration_model: 'SPAN',
|
||||
setting_realesrgan_enabled_models: 'SPAN',
|
||||
extras_upscaler_1: 'SPAN',
|
||||
extras_upscaler_2: 'SPAN',
|
||||
setting_random_artist_categories: 'OPTION',
|
||||
setting_face_restoration_model: 'OPTION',
|
||||
setting_realesrgan_enabled_models: 'OPTION',
|
||||
extras_upscaler_1: 'OPTION',
|
||||
extras_upscaler_2: 'OPTION',
|
||||
};
|
||||
|
||||
var re_num = /^[.\d]+$/;
|
||||
|
||||
+8
-10
@@ -152,11 +152,11 @@ function submit() {
|
||||
showSubmitButtons('txt2img', false);
|
||||
|
||||
var id = randomId();
|
||||
localStorage.setItem("txt2img_task_id", id);
|
||||
localSet("txt2img_task_id", id);
|
||||
|
||||
requestProgress(id, gradioApp().getElementById('txt2img_gallery_container'), gradioApp().getElementById('txt2img_gallery'), function() {
|
||||
showSubmitButtons('txt2img', true);
|
||||
localStorage.removeItem("txt2img_task_id");
|
||||
localRemove("txt2img_task_id");
|
||||
showRestoreProgressButton('txt2img', false);
|
||||
});
|
||||
|
||||
@@ -171,11 +171,11 @@ function submit_img2img() {
|
||||
showSubmitButtons('img2img', false);
|
||||
|
||||
var id = randomId();
|
||||
localStorage.setItem("img2img_task_id", id);
|
||||
localSet("img2img_task_id", id);
|
||||
|
||||
requestProgress(id, gradioApp().getElementById('img2img_gallery_container'), gradioApp().getElementById('img2img_gallery'), function() {
|
||||
showSubmitButtons('img2img', true);
|
||||
localStorage.removeItem("img2img_task_id");
|
||||
localRemove("img2img_task_id");
|
||||
showRestoreProgressButton('img2img', false);
|
||||
});
|
||||
|
||||
@@ -189,9 +189,7 @@ function submit_img2img() {
|
||||
|
||||
function restoreProgressTxt2img() {
|
||||
showRestoreProgressButton("txt2img", false);
|
||||
var id = localStorage.getItem("txt2img_task_id");
|
||||
|
||||
id = localStorage.getItem("txt2img_task_id");
|
||||
var id = localGet("txt2img_task_id");
|
||||
|
||||
if (id) {
|
||||
requestProgress(id, gradioApp().getElementById('txt2img_gallery_container'), gradioApp().getElementById('txt2img_gallery'), function() {
|
||||
@@ -205,7 +203,7 @@ function restoreProgressTxt2img() {
|
||||
function restoreProgressImg2img() {
|
||||
showRestoreProgressButton("img2img", false);
|
||||
|
||||
var id = localStorage.getItem("img2img_task_id");
|
||||
var id = localGet("img2img_task_id");
|
||||
|
||||
if (id) {
|
||||
requestProgress(id, gradioApp().getElementById('img2img_gallery_container'), gradioApp().getElementById('img2img_gallery'), function() {
|
||||
@@ -218,8 +216,8 @@ function restoreProgressImg2img() {
|
||||
|
||||
|
||||
onUiLoaded(function() {
|
||||
showRestoreProgressButton('txt2img', localStorage.getItem("txt2img_task_id"));
|
||||
showRestoreProgressButton('img2img', localStorage.getItem("img2img_task_id"));
|
||||
showRestoreProgressButton('txt2img', localGet("txt2img_task_id"));
|
||||
showRestoreProgressButton('img2img', localGet("img2img_task_id"));
|
||||
});
|
||||
|
||||
|
||||
|
||||
@@ -1,6 +1,5 @@
|
||||
from modules import launch_utils
|
||||
|
||||
|
||||
args = launch_utils.args
|
||||
python = launch_utils.python
|
||||
git = launch_utils.git
|
||||
@@ -26,8 +25,11 @@ start = launch_utils.start
|
||||
|
||||
|
||||
def main():
|
||||
if not args.skip_prepare_environment:
|
||||
prepare_environment()
|
||||
launch_utils.startup_timer.record("initial startup")
|
||||
|
||||
with launch_utils.startup_timer.subcategory("prepare environment"):
|
||||
if not args.skip_prepare_environment:
|
||||
prepare_environment()
|
||||
|
||||
if args.test_server:
|
||||
configure_for_tests()
|
||||
|
||||
+8
-1
@@ -15,7 +15,7 @@ from fastapi.encoders import jsonable_encoder
|
||||
from secrets import compare_digest
|
||||
|
||||
import modules.shared as shared
|
||||
from modules import sd_samplers, deepbooru, sd_hijack, images, scripts, ui, postprocessing, errors, restart
|
||||
from modules import sd_samplers, deepbooru, sd_hijack, images, scripts, ui, postprocessing, errors, restart, shared_items
|
||||
from modules.api import models
|
||||
from modules.shared import opts
|
||||
from modules.processing import StableDiffusionProcessingTxt2Img, StableDiffusionProcessingImg2Img, process_images
|
||||
@@ -197,6 +197,7 @@ class Api:
|
||||
self.add_api_route("/sdapi/v1/prompt-styles", self.get_prompt_styles, methods=["GET"], response_model=List[models.PromptStyleItem])
|
||||
self.add_api_route("/sdapi/v1/embeddings", self.get_embeddings, methods=["GET"], response_model=models.EmbeddingsResponse)
|
||||
self.add_api_route("/sdapi/v1/refresh-checkpoints", self.refresh_checkpoints, methods=["POST"])
|
||||
self.add_api_route("/sdapi/v1/refresh-vae", self.refresh_vae, methods=["POST"])
|
||||
self.add_api_route("/sdapi/v1/create/embedding", self.create_embedding, methods=["POST"], response_model=models.CreateResponse)
|
||||
self.add_api_route("/sdapi/v1/create/hypernetwork", self.create_hypernetwork, methods=["POST"], response_model=models.CreateResponse)
|
||||
self.add_api_route("/sdapi/v1/preprocess", self.preprocess, methods=["POST"], response_model=models.PreprocessResponse)
|
||||
@@ -343,6 +344,7 @@ class Api:
|
||||
processed = process_images(p)
|
||||
finally:
|
||||
shared.state.end()
|
||||
shared.total_tqdm.clear()
|
||||
|
||||
b64images = list(map(encode_pil_to_base64, processed.images)) if send_images else []
|
||||
|
||||
@@ -402,6 +404,7 @@ class Api:
|
||||
processed = process_images(p)
|
||||
finally:
|
||||
shared.state.end()
|
||||
shared.total_tqdm.clear()
|
||||
|
||||
b64images = list(map(encode_pil_to_base64, processed.images)) if send_images else []
|
||||
|
||||
@@ -608,6 +611,10 @@ class Api:
|
||||
with self.queue_lock:
|
||||
shared.refresh_checkpoints()
|
||||
|
||||
def refresh_vae(self):
|
||||
with self.queue_lock:
|
||||
shared_items.refresh_vae_list()
|
||||
|
||||
def create_embedding(self, args: dict):
|
||||
try:
|
||||
shared.state.begin(job="create_embedding")
|
||||
|
||||
@@ -3,7 +3,7 @@ import html
|
||||
import threading
|
||||
import time
|
||||
|
||||
from modules import shared, progress, errors
|
||||
from modules import shared, progress, errors, devices
|
||||
|
||||
queue_lock = threading.Lock()
|
||||
|
||||
@@ -75,6 +75,8 @@ def wrap_gradio_call(func, extra_outputs=None, add_stats=False):
|
||||
error_message = f'{type(e).__name__}: {e}'
|
||||
res = extra_outputs_array + [f"<div class='error'>{html.escape(error_message)}</div>"]
|
||||
|
||||
devices.torch_gc()
|
||||
|
||||
shared.state.skipped = False
|
||||
shared.state.interrupted = False
|
||||
shared.state.job_count = 0
|
||||
|
||||
@@ -13,6 +13,7 @@ parser.add_argument("--reinstall-xformers", action='store_true', help="launch.py
|
||||
parser.add_argument("--reinstall-torch", action='store_true', help="launch.py argument: install the appropriate version of torch even if you have some version already installed")
|
||||
parser.add_argument("--update-check", action='store_true', help="launch.py argument: check for updates at startup")
|
||||
parser.add_argument("--test-server", action='store_true', help="launch.py argument: configure server for testing")
|
||||
parser.add_argument("--log-startup", action='store_true', help="launch.py argument: print a detailed log of what's happening at startup")
|
||||
parser.add_argument("--skip-prepare-environment", action='store_true', help="launch.py argument: skip all environment preparation")
|
||||
parser.add_argument("--skip-install", action='store_true', help="launch.py argument: skip installation of packages")
|
||||
parser.add_argument("--do-not-download-clip", action='store_true', help="do not download CLIP model even if it's not included in the checkpoint")
|
||||
@@ -66,6 +67,7 @@ parser.add_argument("--opt-sdp-no-mem-attention", action='store_true', help="pre
|
||||
parser.add_argument("--disable-opt-split-attention", action='store_true', help="prefer no cross-attention layer optimization for automatic choice of optimization")
|
||||
parser.add_argument("--disable-nan-check", action='store_true', help="do not check if produced images/latent spaces have nans; useful for running without a checkpoint in CI")
|
||||
parser.add_argument("--use-cpu", nargs='+', help="use CPU as torch device for specified modules", default=[], type=str.lower)
|
||||
parser.add_argument("--disable-model-loading-ram-optimization", action='store_true', help="disable an optimization that reduces RAM use when loading a model")
|
||||
parser.add_argument("--listen", action='store_true', help="launch gradio with 0.0.0.0 as server name, allowing to respond to network requests")
|
||||
parser.add_argument("--port", type=int, help="launch gradio with given server port, you need root/admin rights for ports < 1024, defaults to 7860 if available", default=None)
|
||||
parser.add_argument("--show-negative-prompt", action='store_true', help="does not do anything", default=False)
|
||||
@@ -110,3 +112,5 @@ parser.add_argument('--subpath', type=str, help='customize the subpath for gradi
|
||||
parser.add_argument('--add-stop-route', action='store_true', help='add /_stop route to stop server')
|
||||
parser.add_argument('--api-server-stop', action='store_true', help='enable server stop/restart/kill via api')
|
||||
parser.add_argument('--timeout-keep-alive', type=int, default=30, help='set timeout_keep_alive for uvicorn')
|
||||
parser.add_argument("--disable-all-extensions", action='store_true', help="prevent all extensions from running regardless of any other settings", default=False)
|
||||
parser.add_argument("--disable-extra-extensions", action='store_true', help=" prevent all extensions except built-in from running regardless of any other settings", default=False)
|
||||
|
||||
+75
-8
@@ -3,7 +3,7 @@ import contextlib
|
||||
from functools import lru_cache
|
||||
|
||||
import torch
|
||||
from modules import errors
|
||||
from modules import errors, rng_philox
|
||||
|
||||
if sys.platform == "darwin":
|
||||
from modules import mac_specific
|
||||
@@ -71,14 +71,17 @@ def enable_tf32():
|
||||
torch.backends.cudnn.allow_tf32 = True
|
||||
|
||||
|
||||
|
||||
errors.run(enable_tf32, "Enabling TF32")
|
||||
|
||||
cpu = torch.device("cpu")
|
||||
device = device_interrogate = device_gfpgan = device_esrgan = device_codeformer = None
|
||||
dtype = torch.float16
|
||||
dtype_vae = torch.float16
|
||||
dtype_unet = torch.float16
|
||||
cpu: torch.device = torch.device("cpu")
|
||||
device: torch.device = None
|
||||
device_interrogate: torch.device = None
|
||||
device_gfpgan: torch.device = None
|
||||
device_esrgan: torch.device = None
|
||||
device_codeformer: torch.device = None
|
||||
dtype: torch.dtype = torch.float16
|
||||
dtype_vae: torch.dtype = torch.float16
|
||||
dtype_unet: torch.dtype = torch.float16
|
||||
unet_needs_upcast = False
|
||||
|
||||
|
||||
@@ -90,23 +93,87 @@ def cond_cast_float(input):
|
||||
return input.float() if unet_needs_upcast else input
|
||||
|
||||
|
||||
nv_rng = None
|
||||
|
||||
|
||||
def randn(seed, shape):
|
||||
"""Generate a tensor with random numbers from a normal distribution using seed.
|
||||
|
||||
Uses the seed parameter to set the global torch seed; to generate more with that seed, use randn_like/randn_without_seed."""
|
||||
|
||||
from modules.shared import opts
|
||||
|
||||
torch.manual_seed(seed)
|
||||
manual_seed(seed)
|
||||
|
||||
if opts.randn_source == "NV":
|
||||
return torch.asarray(nv_rng.randn(shape), device=device)
|
||||
|
||||
if opts.randn_source == "CPU" or device.type == 'mps':
|
||||
return torch.randn(shape, device=cpu).to(device)
|
||||
|
||||
return torch.randn(shape, device=device)
|
||||
|
||||
|
||||
def randn_local(seed, shape):
|
||||
"""Generate a tensor with random numbers from a normal distribution using seed.
|
||||
|
||||
Does not change the global random number generator. You can only generate the seed's first tensor using this function."""
|
||||
|
||||
from modules.shared import opts
|
||||
|
||||
if opts.randn_source == "NV":
|
||||
rng = rng_philox.Generator(seed)
|
||||
return torch.asarray(rng.randn(shape), device=device)
|
||||
|
||||
local_device = cpu if opts.randn_source == "CPU" or device.type == 'mps' else device
|
||||
local_generator = torch.Generator(local_device).manual_seed(int(seed))
|
||||
return torch.randn(shape, device=local_device, generator=local_generator).to(device)
|
||||
|
||||
|
||||
def randn_like(x):
|
||||
"""Generate a tensor with random numbers from a normal distribution using the previously initialized genrator.
|
||||
|
||||
Use either randn() or manual_seed() to initialize the generator."""
|
||||
|
||||
from modules.shared import opts
|
||||
|
||||
if opts.randn_source == "NV":
|
||||
return torch.asarray(nv_rng.randn(x.shape), device=x.device, dtype=x.dtype)
|
||||
|
||||
if opts.randn_source == "CPU" or x.device.type == 'mps':
|
||||
return torch.randn_like(x, device=cpu).to(x.device)
|
||||
|
||||
return torch.randn_like(x)
|
||||
|
||||
|
||||
def randn_without_seed(shape):
|
||||
"""Generate a tensor with random numbers from a normal distribution using the previously initialized genrator.
|
||||
|
||||
Use either randn() or manual_seed() to initialize the generator."""
|
||||
|
||||
from modules.shared import opts
|
||||
|
||||
if opts.randn_source == "NV":
|
||||
return torch.asarray(nv_rng.randn(shape), device=device)
|
||||
|
||||
if opts.randn_source == "CPU" or device.type == 'mps':
|
||||
return torch.randn(shape, device=cpu).to(device)
|
||||
|
||||
return torch.randn(shape, device=device)
|
||||
|
||||
|
||||
def manual_seed(seed):
|
||||
"""Set up a global random number generator using the specified seed."""
|
||||
from modules.shared import opts
|
||||
|
||||
if opts.randn_source == "NV":
|
||||
global nv_rng
|
||||
nv_rng = rng_philox.Generator(seed)
|
||||
return
|
||||
|
||||
torch.manual_seed(seed)
|
||||
|
||||
|
||||
def autocast(disable=False):
|
||||
from modules import shared
|
||||
|
||||
|
||||
+52
-1
@@ -14,7 +14,8 @@ def record_exception():
|
||||
if exception_records and exception_records[-1] == e:
|
||||
return
|
||||
|
||||
exception_records.append((e, tb))
|
||||
from modules import sysinfo
|
||||
exception_records.append(sysinfo.format_exception(e, tb))
|
||||
|
||||
if len(exception_records) > 5:
|
||||
exception_records.pop(0)
|
||||
@@ -83,3 +84,53 @@ def run(code, task):
|
||||
code()
|
||||
except Exception as e:
|
||||
display(task, e)
|
||||
|
||||
|
||||
def check_versions():
|
||||
from packaging import version
|
||||
from modules import shared
|
||||
|
||||
import torch
|
||||
import gradio
|
||||
|
||||
expected_torch_version = "2.0.0"
|
||||
expected_xformers_version = "0.0.20"
|
||||
expected_gradio_version = "3.39.0"
|
||||
|
||||
if version.parse(torch.__version__) < version.parse(expected_torch_version):
|
||||
print_error_explanation(f"""
|
||||
You are running torch {torch.__version__}.
|
||||
The program is tested to work with torch {expected_torch_version}.
|
||||
To reinstall the desired version, run with commandline flag --reinstall-torch.
|
||||
Beware that this will cause a lot of large files to be downloaded, as well as
|
||||
there are reports of issues with training tab on the latest version.
|
||||
|
||||
Use --skip-version-check commandline argument to disable this check.
|
||||
""".strip())
|
||||
|
||||
if shared.xformers_available:
|
||||
import xformers
|
||||
|
||||
if version.parse(xformers.__version__) < version.parse(expected_xformers_version):
|
||||
print_error_explanation(f"""
|
||||
You are running xformers {xformers.__version__}.
|
||||
The program is tested to work with xformers {expected_xformers_version}.
|
||||
To reinstall the desired version, run with commandline flag --reinstall-xformers.
|
||||
|
||||
Use --skip-version-check commandline argument to disable this check.
|
||||
""".strip())
|
||||
|
||||
if gradio.__version__ != expected_gradio_version:
|
||||
print_error_explanation(f"""
|
||||
You are running gradio {gradio.__version__}.
|
||||
The program is designed to work with gradio {expected_gradio_version}.
|
||||
Using a different version of gradio is extremely likely to break the program.
|
||||
|
||||
Reasons why you have the mismatched gradio version can be:
|
||||
- you use --skip-install flag.
|
||||
- you use webui.py to start the program instead of launch.py.
|
||||
- an extension installs the incompatible gradio version.
|
||||
|
||||
Use --skip-version-check commandline argument to disable this check.
|
||||
""".strip())
|
||||
|
||||
|
||||
@@ -11,9 +11,9 @@ os.makedirs(extensions_dir, exist_ok=True)
|
||||
|
||||
|
||||
def active():
|
||||
if shared.opts.disable_all_extensions == "all":
|
||||
if shared.cmd_opts.disable_all_extensions or shared.opts.disable_all_extensions == "all":
|
||||
return []
|
||||
elif shared.opts.disable_all_extensions == "extra":
|
||||
elif shared.cmd_opts.disable_extra_extensions or shared.opts.disable_all_extensions == "extra":
|
||||
return [x for x in extensions if x.enabled and x.is_builtin]
|
||||
else:
|
||||
return [x for x in extensions if x.enabled]
|
||||
@@ -141,8 +141,12 @@ def list_extensions():
|
||||
if not os.path.isdir(extensions_dir):
|
||||
return
|
||||
|
||||
if shared.opts.disable_all_extensions == "all":
|
||||
if shared.cmd_opts.disable_all_extensions:
|
||||
print("*** \"--disable-all-extensions\" arg was used, will not load any extensions ***")
|
||||
elif shared.opts.disable_all_extensions == "all":
|
||||
print("*** \"Disable all extensions\" option was set, will not load any extensions ***")
|
||||
elif shared.cmd_opts.disable_extra_extensions:
|
||||
print("*** \"--disable-extra-extensions\" arg was used, will only load built-in extensions ***")
|
||||
elif shared.opts.disable_all_extensions == "extra":
|
||||
print("*** \"Disable all extensions\" option was set, will only load built-in extensions ***")
|
||||
|
||||
|
||||
@@ -1,3 +1,5 @@
|
||||
import json
|
||||
import os
|
||||
import re
|
||||
from collections import defaultdict
|
||||
|
||||
@@ -177,3 +179,20 @@ def parse_prompts(prompts):
|
||||
|
||||
return res, extra_data
|
||||
|
||||
|
||||
def get_user_metadata(filename):
|
||||
if filename is None:
|
||||
return {}
|
||||
|
||||
basename, ext = os.path.splitext(filename)
|
||||
metadata_filename = basename + '.json'
|
||||
|
||||
metadata = {}
|
||||
try:
|
||||
if os.path.isfile(metadata_filename):
|
||||
with open(metadata_filename, "r", encoding="utf8") as file:
|
||||
metadata = json.load(file)
|
||||
except Exception as e:
|
||||
errors.display(e, f"reading extra network user metadata from {metadata_filename}")
|
||||
|
||||
return metadata
|
||||
|
||||
+33
-6
@@ -7,7 +7,7 @@ import json
|
||||
import torch
|
||||
import tqdm
|
||||
|
||||
from modules import shared, images, sd_models, sd_vae, sd_models_config
|
||||
from modules import shared, images, sd_models, sd_vae, sd_models_config, errors
|
||||
from modules.ui_common import plaintext_to_html
|
||||
import gradio as gr
|
||||
import safetensors.torch
|
||||
@@ -72,7 +72,20 @@ def to_half(tensor, enable):
|
||||
return tensor
|
||||
|
||||
|
||||
def run_modelmerger(id_task, primary_model_name, secondary_model_name, tertiary_model_name, interp_method, multiplier, save_as_half, custom_name, checkpoint_format, config_source, bake_in_vae, discard_weights, save_metadata):
|
||||
def read_metadata(primary_model_name, secondary_model_name, tertiary_model_name):
|
||||
metadata = {}
|
||||
|
||||
for checkpoint_name in [primary_model_name, secondary_model_name, tertiary_model_name]:
|
||||
checkpoint_info = sd_models.checkpoints_list.get(checkpoint_name, None)
|
||||
if checkpoint_info is None:
|
||||
continue
|
||||
|
||||
metadata.update(checkpoint_info.metadata)
|
||||
|
||||
return json.dumps(metadata, indent=4, ensure_ascii=False)
|
||||
|
||||
|
||||
def run_modelmerger(id_task, primary_model_name, secondary_model_name, tertiary_model_name, interp_method, multiplier, save_as_half, custom_name, checkpoint_format, config_source, bake_in_vae, discard_weights, save_metadata, add_merge_recipe, copy_metadata_fields, metadata_json):
|
||||
shared.state.begin(job="model-merge")
|
||||
|
||||
def fail(message):
|
||||
@@ -241,11 +254,25 @@ def run_modelmerger(id_task, primary_model_name, secondary_model_name, tertiary_
|
||||
shared.state.textinfo = "Saving"
|
||||
print(f"Saving to {output_modelname}...")
|
||||
|
||||
metadata = None
|
||||
metadata = {}
|
||||
|
||||
if save_metadata and copy_metadata_fields:
|
||||
if primary_model_info:
|
||||
metadata.update(primary_model_info.metadata)
|
||||
if secondary_model_info:
|
||||
metadata.update(secondary_model_info.metadata)
|
||||
if tertiary_model_info:
|
||||
metadata.update(tertiary_model_info.metadata)
|
||||
|
||||
if save_metadata:
|
||||
metadata = {"format": "pt"}
|
||||
try:
|
||||
metadata.update(json.loads(metadata_json))
|
||||
except Exception as e:
|
||||
errors.display(e, "readin metadata from json")
|
||||
|
||||
metadata["format"] = "pt"
|
||||
|
||||
if save_metadata and add_merge_recipe:
|
||||
merge_recipe = {
|
||||
"type": "webui", # indicate this model was merged with webui's built-in merger
|
||||
"primary_model_hash": primary_model_info.sha256,
|
||||
@@ -261,7 +288,6 @@ def run_modelmerger(id_task, primary_model_name, secondary_model_name, tertiary_
|
||||
"is_inpainting": result_is_inpainting_model,
|
||||
"is_instruct_pix2pix": result_is_instruct_pix2pix_model
|
||||
}
|
||||
metadata["sd_merge_recipe"] = json.dumps(merge_recipe)
|
||||
|
||||
sd_merge_models = {}
|
||||
|
||||
@@ -281,11 +307,12 @@ def run_modelmerger(id_task, primary_model_name, secondary_model_name, tertiary_
|
||||
if tertiary_model_info:
|
||||
add_model_metadata(tertiary_model_info)
|
||||
|
||||
metadata["sd_merge_recipe"] = json.dumps(merge_recipe)
|
||||
metadata["sd_merge_models"] = json.dumps(sd_merge_models)
|
||||
|
||||
_, extension = os.path.splitext(output_modelname)
|
||||
if extension.lower() == ".safetensors":
|
||||
safetensors.torch.save_file(theta_0, output_modelname, metadata=metadata)
|
||||
safetensors.torch.save_file(theta_0, output_modelname, metadata=metadata if len(metadata)>0 else None)
|
||||
else:
|
||||
torch.save(theta_0, output_modelname)
|
||||
|
||||
|
||||
@@ -280,6 +280,9 @@ Steps: 20, Sampler: Euler a, CFG scale: 7, Seed: 965400086, Size: 512x512, Model
|
||||
if "Hires sampler" not in res:
|
||||
res["Hires sampler"] = "Use same sampler"
|
||||
|
||||
if "Hires checkpoint" not in res:
|
||||
res["Hires checkpoint"] = "Use same checkpoint"
|
||||
|
||||
if "Hires prompt" not in res:
|
||||
res["Hires prompt"] = ""
|
||||
|
||||
@@ -304,6 +307,12 @@ Steps: 20, Sampler: Euler a, CFG scale: 7, Seed: 965400086, Size: 512x512, Model
|
||||
if "Schedule rho" not in res:
|
||||
res["Schedule rho"] = 0
|
||||
|
||||
if "VAE Encoder" not in res:
|
||||
res["VAE Encoder"] = "Full"
|
||||
|
||||
if "VAE Decoder" not in res:
|
||||
res["VAE Decoder"] = "Full"
|
||||
|
||||
return res
|
||||
|
||||
|
||||
@@ -319,6 +328,10 @@ infotext_to_setting_name_mapping = [
|
||||
('Noise multiplier', 'initial_noise_multiplier'),
|
||||
('Eta', 'eta_ancestral'),
|
||||
('Eta DDIM', 'eta_ddim'),
|
||||
('Sigma churn', 's_churn'),
|
||||
('Sigma tmin', 's_tmin'),
|
||||
('Sigma tmax', 's_tmax'),
|
||||
('Sigma noise', 's_noise'),
|
||||
('Discard penultimate sigma', 'always_discard_next_to_last_sigma'),
|
||||
('UniPC variant', 'uni_pc_variant'),
|
||||
('UniPC skip type', 'uni_pc_skip_type'),
|
||||
@@ -329,6 +342,10 @@ infotext_to_setting_name_mapping = [
|
||||
('RNG', 'randn_source'),
|
||||
('NGMS', 's_min_uncond'),
|
||||
('Pad conds', 'pad_cond_uncond'),
|
||||
('VAE Encoder', 'sd_vae_encode_method'),
|
||||
('VAE Decoder', 'sd_vae_decode_method'),
|
||||
('Refiner', 'sd_refiner_checkpoint'),
|
||||
('Refiner switch at', 'sd_refiner_switch_at'),
|
||||
]
|
||||
|
||||
|
||||
|
||||
@@ -0,0 +1,60 @@
|
||||
import gradio as gr
|
||||
|
||||
from modules import scripts
|
||||
|
||||
def add_classes_to_gradio_component(comp):
|
||||
"""
|
||||
this adds gradio-* to the component for css styling (ie gradio-button to gr.Button), as well as some others
|
||||
"""
|
||||
|
||||
comp.elem_classes = [f"gradio-{comp.get_block_name()}", *(comp.elem_classes or [])]
|
||||
|
||||
if getattr(comp, 'multiselect', False):
|
||||
comp.elem_classes.append('multiselect')
|
||||
|
||||
|
||||
def IOComponent_init(self, *args, **kwargs):
|
||||
self.webui_tooltip = kwargs.pop('tooltip', None)
|
||||
|
||||
if scripts.scripts_current is not None:
|
||||
scripts.scripts_current.before_component(self, **kwargs)
|
||||
|
||||
scripts.script_callbacks.before_component_callback(self, **kwargs)
|
||||
|
||||
res = original_IOComponent_init(self, *args, **kwargs)
|
||||
|
||||
add_classes_to_gradio_component(self)
|
||||
|
||||
scripts.script_callbacks.after_component_callback(self, **kwargs)
|
||||
|
||||
if scripts.scripts_current is not None:
|
||||
scripts.scripts_current.after_component(self, **kwargs)
|
||||
|
||||
return res
|
||||
|
||||
|
||||
def Block_get_config(self):
|
||||
config = original_Block_get_config(self)
|
||||
|
||||
webui_tooltip = getattr(self, 'webui_tooltip', None)
|
||||
if webui_tooltip:
|
||||
config["webui_tooltip"] = webui_tooltip
|
||||
|
||||
return config
|
||||
|
||||
|
||||
def BlockContext_init(self, *args, **kwargs):
|
||||
res = original_BlockContext_init(self, *args, **kwargs)
|
||||
|
||||
add_classes_to_gradio_component(self)
|
||||
|
||||
return res
|
||||
|
||||
|
||||
original_IOComponent_init = gr.components.IOComponent.__init__
|
||||
original_Block_get_config = gr.blocks.Block.get_config
|
||||
original_BlockContext_init = gr.blocks.BlockContext.__init__
|
||||
|
||||
gr.components.IOComponent.__init__ = IOComponent_init
|
||||
gr.blocks.Block.get_config = Block_get_config
|
||||
gr.blocks.BlockContext.__init__ = BlockContext_init
|
||||
@@ -10,7 +10,7 @@ import torch
|
||||
import tqdm
|
||||
from einops import rearrange, repeat
|
||||
from ldm.util import default
|
||||
from modules import devices, processing, sd_models, shared, sd_samplers, hashes, sd_hijack_checkpoint, errors
|
||||
from modules import devices, sd_models, shared, sd_samplers, hashes, sd_hijack_checkpoint, errors
|
||||
from modules.textual_inversion import textual_inversion, logging
|
||||
from modules.textual_inversion.learn_schedule import LearnRateScheduler
|
||||
from torch import einsum
|
||||
@@ -469,8 +469,7 @@ def create_hypernetwork(name, enable_sizes, overwrite_old, layer_structure=None,
|
||||
|
||||
|
||||
def train_hypernetwork(id_task, hypernetwork_name, learn_rate, batch_size, gradient_step, data_root, log_directory, training_width, training_height, varsize, steps, clip_grad_mode, clip_grad_value, shuffle_tags, tag_drop_out, latent_sampling_method, use_weight, create_image_every, save_hypernetwork_every, template_filename, preview_from_txt2img, preview_prompt, preview_negative_prompt, preview_steps, preview_sampler_index, preview_cfg_scale, preview_seed, preview_width, preview_height):
|
||||
# images allows training previews to have infotext. Importing it at the top causes a circular import problem.
|
||||
from modules import images
|
||||
from modules import images, processing
|
||||
|
||||
save_hypernetwork_every = save_hypernetwork_every or 0
|
||||
create_image_every = create_image_every or 0
|
||||
|
||||
+1
-1
@@ -318,7 +318,7 @@ def resize_image(resize_mode, im, width, height, upscaler_name=None):
|
||||
return res
|
||||
|
||||
|
||||
invalid_filename_chars = '<>:"/\\|?*\n'
|
||||
invalid_filename_chars = '<>:"/\\|?*\n\r\t'
|
||||
invalid_filename_prefix = ' '
|
||||
invalid_filename_postfix = ' .'
|
||||
re_nonletters = re.compile(r'[\s' + string.punctuation + ']+')
|
||||
|
||||
+12
-26
@@ -3,14 +3,13 @@ from contextlib import closing
|
||||
from pathlib import Path
|
||||
|
||||
import numpy as np
|
||||
from PIL import Image, ImageOps, ImageFilter, ImageEnhance, ImageChops, UnidentifiedImageError
|
||||
from PIL import Image, ImageOps, ImageFilter, ImageEnhance, UnidentifiedImageError
|
||||
import gradio as gr
|
||||
|
||||
from modules import sd_samplers, images as imgutil
|
||||
from modules.generation_parameters_copypaste import create_override_settings_dict, parse_generation_parameters
|
||||
from modules.processing import Processed, StableDiffusionProcessingImg2Img, process_images
|
||||
from modules.shared import opts, state
|
||||
from modules.images import save_image
|
||||
import modules.shared as shared
|
||||
import modules.processing as processing
|
||||
from modules.ui import plaintext_to_html
|
||||
@@ -18,9 +17,10 @@ import modules.scripts
|
||||
|
||||
|
||||
def process_batch(p, input_dir, output_dir, inpaint_mask_dir, args, to_scale=False, scale_by=1.0, use_png_info=False, png_info_props=None, png_info_dir=None):
|
||||
output_dir = output_dir.strip()
|
||||
processing.fix_seed(p)
|
||||
|
||||
images = list(shared.walk_files(input_dir, allowed_extensions=(".png", ".jpg", ".jpeg", ".webp")))
|
||||
images = list(shared.walk_files(input_dir, allowed_extensions=(".png", ".jpg", ".jpeg", ".webp", ".tif", ".tiff")))
|
||||
|
||||
is_inpaint_batch = False
|
||||
if inpaint_mask_dir:
|
||||
@@ -32,11 +32,6 @@ def process_batch(p, input_dir, output_dir, inpaint_mask_dir, args, to_scale=Fal
|
||||
|
||||
print(f"Will process {len(images)} images, creating {p.n_iter * p.batch_size} new images for each.")
|
||||
|
||||
save_normally = output_dir == ''
|
||||
|
||||
p.do_not_save_grid = True
|
||||
p.do_not_save_samples = not save_normally
|
||||
|
||||
state.job_count = len(images) * p.n_iter
|
||||
|
||||
# extract "default" params to use in case getting png info fails
|
||||
@@ -111,21 +106,14 @@ def process_batch(p, input_dir, output_dir, inpaint_mask_dir, args, to_scale=Fal
|
||||
|
||||
proc = modules.scripts.scripts_img2img.run(p, *args)
|
||||
if proc is None:
|
||||
proc = process_images(p)
|
||||
|
||||
for n, processed_image in enumerate(proc.images):
|
||||
filename = image_path.stem
|
||||
infotext = proc.infotext(p, n)
|
||||
relpath = os.path.dirname(os.path.relpath(image, input_dir))
|
||||
|
||||
if n > 0:
|
||||
filename += f"-{n}"
|
||||
|
||||
if not save_normally:
|
||||
os.makedirs(os.path.join(output_dir, relpath), exist_ok=True)
|
||||
if processed_image.mode == 'RGBA':
|
||||
processed_image = processed_image.convert("RGB")
|
||||
save_image(processed_image, os.path.join(output_dir, relpath), None, extension=opts.samples_format, info=infotext, forced_filename=filename, save_to_dirs=False)
|
||||
if output_dir:
|
||||
p.outpath_samples = output_dir
|
||||
p.override_settings['save_to_dirs'] = False
|
||||
if p.n_iter > 1 or p.batch_size > 1:
|
||||
p.override_settings['samples_filename_pattern'] = f'{image_path.stem}-[generation_number]'
|
||||
else:
|
||||
p.override_settings['samples_filename_pattern'] = f'{image_path.stem}'
|
||||
process_images(p)
|
||||
|
||||
|
||||
def img2img(id_task: str, mode: int, prompt: str, negative_prompt: str, prompt_styles, init_img, sketch, init_img_with_mask, inpaint_color_sketch, inpaint_color_sketch_orig, init_img_inpaint, init_mask_inpaint, steps: int, sampler_index: int, mask_blur: int, mask_alpha: float, inpainting_fill: int, restore_faces: bool, tiling: bool, n_iter: int, batch_size: int, cfg_scale: float, image_cfg_scale: float, denoising_strength: float, seed: int, subseed: int, subseed_strength: float, seed_resize_from_h: int, seed_resize_from_w: int, seed_enable_extras: bool, selected_scale_tab: int, height: int, width: int, scale_by: float, resize_mode: int, inpaint_full_res: bool, inpaint_full_res_padding: int, inpainting_mask_invert: int, img2img_batch_input_dir: str, img2img_batch_output_dir: str, img2img_batch_inpaint_mask_dir: str, override_settings_texts, img2img_batch_use_png_info: bool, img2img_batch_png_info_props: list, img2img_batch_png_info_dir: str, request: gr.Request, *args):
|
||||
@@ -141,9 +129,7 @@ def img2img(id_task: str, mode: int, prompt: str, negative_prompt: str, prompt_s
|
||||
mask = None
|
||||
elif mode == 2: # inpaint
|
||||
image, mask = init_img_with_mask["image"], init_img_with_mask["mask"]
|
||||
alpha_mask = ImageOps.invert(image.split()[-1]).convert('L').point(lambda x: 255 if x > 0 else 0, mode='1')
|
||||
mask = mask.convert('L').point(lambda x: 255 if x > 128 else 0, mode='1')
|
||||
mask = ImageChops.lighter(alpha_mask, mask).convert('L')
|
||||
mask = mask.split()[-1].convert("L").point(lambda x: 255 if x > 128 else 0)
|
||||
image = image.convert("RGB")
|
||||
elif mode == 3: # inpaint sketch
|
||||
image = inpaint_color_sketch
|
||||
|
||||
+26
-5
@@ -10,9 +10,7 @@ from functools import lru_cache
|
||||
|
||||
from modules import cmd_args, errors
|
||||
from modules.paths_internal import script_path, extensions_dir
|
||||
from modules import timer
|
||||
|
||||
timer.startup_timer.record("start")
|
||||
from modules.timer import startup_timer
|
||||
|
||||
args, _ = cmd_args.parser.parse_known_args()
|
||||
|
||||
@@ -226,8 +224,13 @@ def run_extensions_installers(settings_file):
|
||||
if not os.path.isdir(extensions_dir):
|
||||
return
|
||||
|
||||
for dirname_extension in list_extensions(settings_file):
|
||||
run_extension_installer(os.path.join(extensions_dir, dirname_extension))
|
||||
with startup_timer.subcategory("run extensions installers"):
|
||||
for dirname_extension in list_extensions(settings_file):
|
||||
path = os.path.join(extensions_dir, dirname_extension)
|
||||
|
||||
if os.path.isdir(path):
|
||||
run_extension_installer(path)
|
||||
startup_timer.record(dirname_extension)
|
||||
|
||||
|
||||
re_requirement = re.compile(r"\s*([-_a-zA-Z0-9]+)\s*(?:==\s*([-+_.a-zA-Z0-9]+))?\s*")
|
||||
@@ -300,8 +303,11 @@ def prepare_environment():
|
||||
if not args.skip_python_version_check:
|
||||
check_python_version()
|
||||
|
||||
startup_timer.record("checks")
|
||||
|
||||
commit = commit_hash()
|
||||
tag = git_tag()
|
||||
startup_timer.record("git version info")
|
||||
|
||||
print(f"Python {sys.version}")
|
||||
print(f"Version: {tag}")
|
||||
@@ -309,21 +315,27 @@ def prepare_environment():
|
||||
|
||||
if args.reinstall_torch or not is_installed("torch") or not is_installed("torchvision"):
|
||||
run(f'"{python}" -m {torch_command}', "Installing torch and torchvision", "Couldn't install torch", live=True)
|
||||
startup_timer.record("install torch")
|
||||
|
||||
if not args.skip_torch_cuda_test and not check_run_python("import torch; assert torch.cuda.is_available()"):
|
||||
raise RuntimeError(
|
||||
'Torch is not able to use GPU; '
|
||||
'add --skip-torch-cuda-test to COMMANDLINE_ARGS variable to disable this check'
|
||||
)
|
||||
startup_timer.record("torch GPU test")
|
||||
|
||||
|
||||
if not is_installed("gfpgan"):
|
||||
run_pip(f"install {gfpgan_package}", "gfpgan")
|
||||
startup_timer.record("install gfpgan")
|
||||
|
||||
if not is_installed("clip"):
|
||||
run_pip(f"install {clip_package}", "clip")
|
||||
startup_timer.record("install clip")
|
||||
|
||||
if not is_installed("open_clip"):
|
||||
run_pip(f"install {openclip_package}", "open_clip")
|
||||
startup_timer.record("install open_clip")
|
||||
|
||||
if (not is_installed("xformers") or args.reinstall_xformers) and args.xformers:
|
||||
if platform.system() == "Windows":
|
||||
@@ -337,8 +349,11 @@ def prepare_environment():
|
||||
elif platform.system() == "Linux":
|
||||
run_pip(f"install -U -I --no-deps {xformers_package}", "xformers")
|
||||
|
||||
startup_timer.record("install xformers")
|
||||
|
||||
if not is_installed("ngrok") and args.ngrok:
|
||||
run_pip("install ngrok", "ngrok")
|
||||
startup_timer.record("install ngrok")
|
||||
|
||||
os.makedirs(os.path.join(script_path, dir_repos), exist_ok=True)
|
||||
|
||||
@@ -348,22 +363,28 @@ def prepare_environment():
|
||||
git_clone(codeformer_repo, repo_dir('CodeFormer'), "CodeFormer", codeformer_commit_hash)
|
||||
git_clone(blip_repo, repo_dir('BLIP'), "BLIP", blip_commit_hash)
|
||||
|
||||
startup_timer.record("clone repositores")
|
||||
|
||||
if not is_installed("lpips"):
|
||||
run_pip(f"install -r \"{os.path.join(repo_dir('CodeFormer'), 'requirements.txt')}\"", "requirements for CodeFormer")
|
||||
startup_timer.record("install CodeFormer requirements")
|
||||
|
||||
if not os.path.isfile(requirements_file):
|
||||
requirements_file = os.path.join(script_path, requirements_file)
|
||||
|
||||
if not requirements_met(requirements_file):
|
||||
run_pip(f"install -r \"{requirements_file}\"", "requirements")
|
||||
startup_timer.record("install requirements")
|
||||
|
||||
run_extensions_installers(settings_file=args.ui_settings_file)
|
||||
|
||||
if args.update_check:
|
||||
version_check(commit)
|
||||
startup_timer.record("check version")
|
||||
|
||||
if args.update_all_extensions:
|
||||
git_pull_recursive(extensions_dir)
|
||||
startup_timer.record("update extensions")
|
||||
|
||||
if "--exit" in sys.argv:
|
||||
print("Exiting because of --exit argument")
|
||||
|
||||
@@ -15,6 +15,9 @@ def send_everything_to_cpu():
|
||||
|
||||
|
||||
def setup_for_low_vram(sd_model, use_medvram):
|
||||
if getattr(sd_model, 'lowvram', False):
|
||||
return
|
||||
|
||||
sd_model.lowvram = True
|
||||
|
||||
parents = {}
|
||||
|
||||
+190
-78
@@ -16,6 +16,7 @@ from typing import Any, Dict, List
|
||||
import modules.sd_hijack
|
||||
from modules import devices, prompt_parser, masking, sd_samplers, lowvram, generation_parameters_copypaste, extra_networks, sd_vae_approx, scripts, sd_samplers_common, sd_unet, errors
|
||||
from modules.sd_hijack import model_hijack
|
||||
from modules.sd_samplers_common import images_tensor_to_samples, decode_first_stage, approximation_indexes
|
||||
from modules.shared import opts, cmd_opts, state
|
||||
import modules.shared as shared
|
||||
import modules.paths as paths
|
||||
@@ -83,7 +84,7 @@ def txt2img_image_conditioning(sd_model, x, width, height):
|
||||
|
||||
# The "masked-image" in this case will just be all zeros since the entire image is masked.
|
||||
image_conditioning = torch.zeros(x.shape[0], 3, height, width, device=x.device)
|
||||
image_conditioning = sd_model.get_first_stage_encoding(sd_model.encode_first_stage(image_conditioning))
|
||||
image_conditioning = images_tensor_to_samples(image_conditioning, approximation_indexes.get(opts.sd_vae_encode_method))
|
||||
|
||||
# Add the fake full 1s mask to the first dimension.
|
||||
image_conditioning = torch.nn.functional.pad(image_conditioning, (0, 0, 0, 0, 1, 0), value=1.0)
|
||||
@@ -109,7 +110,7 @@ class StableDiffusionProcessing:
|
||||
cached_uc = [None, None]
|
||||
cached_c = [None, None]
|
||||
|
||||
def __init__(self, sd_model=None, outpath_samples=None, outpath_grids=None, prompt: str = "", styles: List[str] = None, seed: int = -1, subseed: int = -1, subseed_strength: float = 0, seed_resize_from_h: int = -1, seed_resize_from_w: int = -1, seed_enable_extras: bool = True, sampler_name: str = None, batch_size: int = 1, n_iter: int = 1, steps: int = 50, cfg_scale: float = 7.0, width: int = 512, height: int = 512, restore_faces: bool = False, tiling: bool = False, do_not_save_samples: bool = False, do_not_save_grid: bool = False, extra_generation_params: Dict[Any, Any] = None, overlay_images: Any = None, negative_prompt: str = None, eta: float = None, do_not_reload_embeddings: bool = False, denoising_strength: float = 0, ddim_discretize: str = None, s_min_uncond: float = 0.0, s_churn: float = 0.0, s_tmax: float = None, s_tmin: float = 0.0, s_noise: float = 1.0, override_settings: Dict[str, Any] = None, override_settings_restore_afterwards: bool = True, sampler_index: int = None, script_args: list = None):
|
||||
def __init__(self, sd_model=None, outpath_samples=None, outpath_grids=None, prompt: str = "", styles: List[str] = None, seed: int = -1, subseed: int = -1, subseed_strength: float = 0, seed_resize_from_h: int = -1, seed_resize_from_w: int = -1, seed_enable_extras: bool = True, sampler_name: str = None, batch_size: int = 1, n_iter: int = 1, steps: int = 50, cfg_scale: float = 7.0, width: int = 512, height: int = 512, restore_faces: bool = False, tiling: bool = False, do_not_save_samples: bool = False, do_not_save_grid: bool = False, extra_generation_params: Dict[Any, Any] = None, overlay_images: Any = None, negative_prompt: str = None, eta: float = None, do_not_reload_embeddings: bool = False, denoising_strength: float = 0, ddim_discretize: str = None, s_min_uncond: float = 0.0, s_churn: float = 0.0, s_tmax: float = None, s_tmin: float = 0.0, s_noise: float = None, override_settings: Dict[str, Any] = None, override_settings_restore_afterwards: bool = True, sampler_index: int = None, script_args: list = None):
|
||||
if sampler_index is not None:
|
||||
print("sampler_index argument for StableDiffusionProcessing does not do anything; use sampler_name", file=sys.stderr)
|
||||
|
||||
@@ -147,8 +148,8 @@ class StableDiffusionProcessing:
|
||||
self.s_min_uncond = s_min_uncond or opts.s_min_uncond
|
||||
self.s_churn = s_churn or opts.s_churn
|
||||
self.s_tmin = s_tmin or opts.s_tmin
|
||||
self.s_tmax = s_tmax or float('inf') # not representable as a standard ui option
|
||||
self.s_noise = s_noise or opts.s_noise
|
||||
self.s_tmax = (s_tmax if s_tmax is not None else opts.s_tmax) or float('inf')
|
||||
self.s_noise = s_noise if s_noise is not None else opts.s_noise
|
||||
self.override_settings = {k: v for k, v in (override_settings or {}).items() if k not in shared.restricted_opts}
|
||||
self.override_settings_restore_afterwards = override_settings_restore_afterwards
|
||||
self.is_using_inpainting_conditioning = False
|
||||
@@ -177,6 +178,8 @@ class StableDiffusionProcessing:
|
||||
self.extra_network_data = None
|
||||
self.seeds = None
|
||||
self.subseeds = None
|
||||
self.recorded_checkpoint = None
|
||||
self.recorded_checkpoint_hash = None
|
||||
|
||||
self.step_multiplier = 1
|
||||
self.cached_uc = StableDiffusionProcessing.cached_uc
|
||||
@@ -185,6 +188,7 @@ class StableDiffusionProcessing:
|
||||
self.c = None
|
||||
|
||||
self.user = None
|
||||
self.image_conditioning = None
|
||||
|
||||
@property
|
||||
def sd_model(self):
|
||||
@@ -202,7 +206,7 @@ class StableDiffusionProcessing:
|
||||
midas_in = torch.from_numpy(transformed["midas_in"][None, ...]).to(device=shared.device)
|
||||
midas_in = repeat(midas_in, "1 ... -> n ...", n=self.batch_size)
|
||||
|
||||
conditioning_image = self.sd_model.get_first_stage_encoding(self.sd_model.encode_first_stage(source_image))
|
||||
conditioning_image = images_tensor_to_samples(source_image*0.5+0.5, approximation_indexes.get(opts.sd_vae_encode_method))
|
||||
conditioning = torch.nn.functional.interpolate(
|
||||
self.sd_model.depth_model(midas_in),
|
||||
size=conditioning_image.shape[2:],
|
||||
@@ -215,7 +219,7 @@ class StableDiffusionProcessing:
|
||||
return conditioning
|
||||
|
||||
def edit_image_conditioning(self, source_image):
|
||||
conditioning_image = self.sd_model.encode_first_stage(source_image).mode()
|
||||
conditioning_image = images_tensor_to_samples(source_image*0.5+0.5, approximation_indexes.get(opts.sd_vae_encode_method))
|
||||
|
||||
return conditioning_image
|
||||
|
||||
@@ -275,10 +279,10 @@ class StableDiffusionProcessing:
|
||||
if self.sd_model.cond_stage_key == "edit":
|
||||
return self.edit_image_conditioning(source_image)
|
||||
|
||||
if self.sampler.conditioning_key in {'hybrid', 'concat'}:
|
||||
if self.sd_model.model.conditioning_key in {'hybrid', 'concat'}:
|
||||
return self.inpainting_image_conditioning(source_image, latent_image, image_mask=image_mask)
|
||||
|
||||
if self.sampler.conditioning_key == "crossattn-adm":
|
||||
if self.sd_model.model.conditioning_key == "crossattn-adm":
|
||||
return self.unclip_image_conditioning(source_image)
|
||||
|
||||
# Dummy zero conditioning if we're not using inpainting or depth model.
|
||||
@@ -294,7 +298,7 @@ class StableDiffusionProcessing:
|
||||
self.sampler = None
|
||||
self.c = None
|
||||
self.uc = None
|
||||
if not opts.experimental_persistent_cond_cache:
|
||||
if not opts.persistent_cond_cache:
|
||||
StableDiffusionProcessing.cached_c = [None, None]
|
||||
StableDiffusionProcessing.cached_uc = [None, None]
|
||||
|
||||
@@ -318,6 +322,21 @@ class StableDiffusionProcessing:
|
||||
self.all_prompts = [shared.prompt_styles.apply_styles_to_prompt(x, self.styles) for x in self.all_prompts]
|
||||
self.all_negative_prompts = [shared.prompt_styles.apply_negative_styles_to_prompt(x, self.styles) for x in self.all_negative_prompts]
|
||||
|
||||
def cached_params(self, required_prompts, steps, extra_network_data):
|
||||
"""Returns parameters that invalidate the cond cache if changed"""
|
||||
|
||||
return (
|
||||
required_prompts,
|
||||
steps,
|
||||
opts.CLIP_stop_at_last_layers,
|
||||
shared.sd_model.sd_checkpoint_info,
|
||||
extra_network_data,
|
||||
opts.sdxl_crop_left,
|
||||
opts.sdxl_crop_top,
|
||||
self.width,
|
||||
self.height,
|
||||
)
|
||||
|
||||
def get_conds_with_caching(self, function, required_prompts, steps, caches, extra_network_data):
|
||||
"""
|
||||
Returns the result of calling function(shared.sd_model, required_prompts, steps)
|
||||
@@ -331,17 +350,7 @@ class StableDiffusionProcessing:
|
||||
caches is a list with items described above.
|
||||
"""
|
||||
|
||||
cached_params = (
|
||||
required_prompts,
|
||||
steps,
|
||||
opts.CLIP_stop_at_last_layers,
|
||||
shared.sd_model.sd_checkpoint_info,
|
||||
extra_network_data,
|
||||
opts.sdxl_crop_left,
|
||||
opts.sdxl_crop_top,
|
||||
self.width,
|
||||
self.height,
|
||||
)
|
||||
cached_params = self.cached_params(required_prompts, steps, extra_network_data)
|
||||
|
||||
for cache in caches:
|
||||
if cache[0] is not None and cached_params == cache[0]:
|
||||
@@ -367,6 +376,58 @@ class StableDiffusionProcessing:
|
||||
def parse_extra_network_prompts(self):
|
||||
self.prompts, self.extra_network_data = extra_networks.parse_prompts(self.prompts)
|
||||
|
||||
def save_samples(self) -> bool:
|
||||
"""Returns whether generated images need to be written to disk"""
|
||||
return opts.samples_save and not self.do_not_save_samples and (opts.save_incomplete_images or not state.interrupted and not state.skipped)
|
||||
|
||||
def run_refiner(self, samples):
|
||||
shared.state.nextjob()
|
||||
|
||||
stopped_at = self.sampler.stop_at
|
||||
noisy_output = self.sampler.noisy_output
|
||||
self.sampler = None
|
||||
|
||||
a_is_sdxl = shared.sd_model.is_sdxl
|
||||
decoded_noisy = decode_latent_batch(shared.sd_model, noisy_output, target_device=devices.cpu, check_for_nans=True)
|
||||
|
||||
refiner_checkpoint_info = sd_models.get_closet_checkpoint_match(shared.opts.sd_refiner_checkpoint)
|
||||
if refiner_checkpoint_info is None:
|
||||
raise Exception(f'Could not find checkpoint with name {shared.opts.sd_refiner_checkpoint}')
|
||||
|
||||
self.recorded_checkpoint = shared.sd_model.sd_checkpoint_info.name_for_extra
|
||||
self.recorded_checkpoint_hash = shared.sd_model.sd_model_hash
|
||||
self.extra_generation_params['Refiner'] = refiner_checkpoint_info.short_title
|
||||
self.extra_generation_params['Refiner switch at'] = shared.opts.sd_refiner_switch_at
|
||||
|
||||
with sd_models.SkipWritingToConfig():
|
||||
sd_models.reload_model_weights(info=refiner_checkpoint_info)
|
||||
|
||||
devices.torch_gc()
|
||||
self.setup_conds()
|
||||
|
||||
b_is_sdxl = shared.sd_model.is_sdxl
|
||||
|
||||
if a_is_sdxl != b_is_sdxl:
|
||||
decoded_noisy = torch.stack(decoded_noisy).float()
|
||||
decoded_noisy = torch.clamp((decoded_noisy + 1.0) / 2.0, min=0.0, max=1.0)
|
||||
noisy_latent = images_tensor_to_samples(decoded_noisy, approximation_indexes.get(opts.sd_vae_encode_method), shared.sd_model)
|
||||
else:
|
||||
noisy_latent = noisy_output
|
||||
|
||||
x = torch.zeros_like(noisy_latent)
|
||||
|
||||
with devices.without_autocast() if devices.unet_needs_upcast else devices.autocast():
|
||||
denoising_strength = self.denoising_strength
|
||||
|
||||
self.denoising_strength = 1.0 - (stopped_at + 1) / self.steps
|
||||
self.image_conditioning = txt2img_image_conditioning(shared.sd_model, noisy_latent, self.width, self.height)
|
||||
self.sampler = sd_samplers.create_sampler(self.sampler_name, shared.sd_model)
|
||||
samples = self.sampler.sample_img2img(self, noisy_latent, x, self.c, self.uc, image_conditioning=self.image_conditioning, steps=max(1, self.steps - stopped_at - 1))
|
||||
|
||||
self.denoising_strength = denoising_strength
|
||||
|
||||
return samples
|
||||
|
||||
|
||||
class Processed:
|
||||
def __init__(self, p: StableDiffusionProcessing, images_list, seed=-1, info="", subseed=None, all_prompts=None, all_negative_prompts=None, all_seeds=None, all_subseeds=None, index_of_first_image=0, infotexts=None, comments=""):
|
||||
@@ -492,7 +553,7 @@ def create_random_tensors(shape, seeds, subseeds=None, subseed_strength=0.0, see
|
||||
noise_shape = shape if seed_resize_from_h <= 0 or seed_resize_from_w <= 0 else (shape[0], seed_resize_from_h//8, seed_resize_from_w//8)
|
||||
|
||||
subnoise = None
|
||||
if subseeds is not None:
|
||||
if subseeds is not None and subseed_strength != 0:
|
||||
subseed = 0 if i >= len(subseeds) else subseeds[i]
|
||||
|
||||
subnoise = devices.randn(subseed, noise_shape)
|
||||
@@ -524,7 +585,7 @@ def create_random_tensors(shape, seeds, subseeds=None, subseed_strength=0.0, see
|
||||
cnt = p.sampler.number_of_needed_noises(p)
|
||||
|
||||
if eta_noise_seed_delta > 0:
|
||||
torch.manual_seed(seed + eta_noise_seed_delta)
|
||||
devices.manual_seed(seed + eta_noise_seed_delta)
|
||||
|
||||
for j in range(cnt):
|
||||
sampler_noises[j].append(devices.randn_without_seed(tuple(noise_shape)))
|
||||
@@ -538,8 +599,15 @@ def create_random_tensors(shape, seeds, subseeds=None, subseed_strength=0.0, see
|
||||
return x
|
||||
|
||||
|
||||
class DecodedSamples(list):
|
||||
already_decoded = True
|
||||
|
||||
|
||||
def decode_latent_batch(model, batch, target_device=None, check_for_nans=False):
|
||||
samples = []
|
||||
if getattr(batch, 'already_decoded', False):
|
||||
return batch
|
||||
|
||||
samples = DecodedSamples()
|
||||
|
||||
for i in range(batch.shape[0]):
|
||||
sample = decode_first_stage(model, batch[i:i + 1])[0]
|
||||
@@ -572,12 +640,6 @@ def decode_latent_batch(model, batch, target_device=None, check_for_nans=False):
|
||||
return samples
|
||||
|
||||
|
||||
def decode_first_stage(model, x):
|
||||
x = model.decode_first_stage(x.to(devices.dtype_vae))
|
||||
|
||||
return x
|
||||
|
||||
|
||||
def get_fixed_seed(seed):
|
||||
if seed is None or seed == '' or seed == -1:
|
||||
return int(random.randrange(4294967294))
|
||||
@@ -624,8 +686,8 @@ def create_infotext(p, all_prompts, all_seeds, all_subseeds, comments=None, iter
|
||||
"Seed": p.all_seeds[0] if use_main_prompt else all_seeds[index],
|
||||
"Face restoration": (opts.face_restoration_model if p.restore_faces else None),
|
||||
"Size": f"{p.width}x{p.height}",
|
||||
"Model hash": getattr(p, 'sd_model_hash', None if not opts.add_model_hash_to_info or not shared.sd_model.sd_model_hash else shared.sd_model.sd_model_hash),
|
||||
"Model": (None if not opts.add_model_name_to_info else shared.sd_model.sd_checkpoint_info.name_for_extra),
|
||||
"Model hash": getattr(p, 'sd_model_hash', None if not opts.add_model_hash_to_info or not shared.sd_model.sd_model_hash else p.recorded_checkpoint_hash or shared.sd_model.sd_model_hash),
|
||||
"Model": (None if not opts.add_model_name_to_info else p.recorded_checkpoint or shared.sd_model.sd_checkpoint_info.name_for_extra),
|
||||
"Variation seed": (None if p.subseed_strength == 0 else (p.all_subseeds[0] if use_main_prompt else all_subseeds[index])),
|
||||
"Variation seed strength": (None if p.subseed_strength == 0 else p.subseed_strength),
|
||||
"Seed resize from": (None if p.seed_resize_from_w <= 0 or p.seed_resize_from_h <= 0 else f"{p.seed_resize_from_w}x{p.seed_resize_from_h}"),
|
||||
@@ -636,7 +698,7 @@ def create_infotext(p, all_prompts, all_seeds, all_subseeds, comments=None, iter
|
||||
"Token merging ratio": None if token_merging_ratio == 0 else token_merging_ratio,
|
||||
"Token merging ratio hr": None if not enable_hr or token_merging_ratio_hr == 0 else token_merging_ratio_hr,
|
||||
"Init image hash": getattr(p, 'init_img_hash', None),
|
||||
"RNG": opts.randn_source if opts.randn_source != "GPU" else None,
|
||||
"RNG": opts.randn_source if opts.randn_source != "GPU" and opts.randn_source != "NV" else None,
|
||||
"NGMS": None if p.s_min_uncond == 0 else p.s_min_uncond,
|
||||
**p.extra_generation_params,
|
||||
"Version": program_version() if opts.add_version_to_infotext else None,
|
||||
@@ -658,6 +720,10 @@ def process_images(p: StableDiffusionProcessing) -> Processed:
|
||||
stored_opts = {k: opts.data[k] for k in p.override_settings.keys()}
|
||||
|
||||
try:
|
||||
# after running refiner, the refiner model is not unloaded - webui swaps back to main model here
|
||||
if shared.sd_model.sd_checkpoint_info.title != opts.sd_model_checkpoint:
|
||||
sd_models.reload_model_weights()
|
||||
|
||||
# if no checkpoint override or the override checkpoint can't be found, remove override entry and load opts checkpoint
|
||||
if sd_models.checkpoint_aliases.get(p.override_settings.get('sd_model_checkpoint')) is None:
|
||||
p.override_settings.pop('sd_model_checkpoint', None)
|
||||
@@ -729,6 +795,8 @@ def process_images_inner(p: StableDiffusionProcessing) -> Processed:
|
||||
infotexts = []
|
||||
output_images = []
|
||||
|
||||
have_refiner = shared.opts.sd_refiner_switch_at < 1.0 and shared.sd_model.sd_checkpoint_info.title != shared.opts.sd_refiner_checkpoint
|
||||
|
||||
with torch.no_grad(), p.sd_model.ema_scope():
|
||||
with devices.autocast():
|
||||
p.init(p.all_prompts, p.all_seeds, p.all_subseeds)
|
||||
@@ -742,6 +810,10 @@ def process_images_inner(p: StableDiffusionProcessing) -> Processed:
|
||||
if state.job_count == -1:
|
||||
state.job_count = p.n_iter
|
||||
|
||||
if have_refiner:
|
||||
state.job_count *= 2
|
||||
shared.total_tqdm.updateTotal(p.steps * state.job_count // 2)
|
||||
|
||||
for n in range(p.n_iter):
|
||||
p.iteration = n
|
||||
|
||||
@@ -751,6 +823,8 @@ def process_images_inner(p: StableDiffusionProcessing) -> Processed:
|
||||
if state.interrupted:
|
||||
break
|
||||
|
||||
sd_models.reload_model_weights() # model can be changed for example by refiner
|
||||
|
||||
p.prompts = p.all_prompts[n * p.batch_size:(n + 1) * p.batch_size]
|
||||
p.negative_prompts = p.all_negative_prompts[n * p.batch_size:(n + 1) * p.batch_size]
|
||||
p.seeds = p.all_seeds[n * p.batch_size:(n + 1) * p.batch_size]
|
||||
@@ -791,9 +865,21 @@ def process_images_inner(p: StableDiffusionProcessing) -> Processed:
|
||||
shared.state.job = f"Batch {n+1} out of {p.n_iter}"
|
||||
|
||||
with devices.without_autocast() if devices.unet_needs_upcast else devices.autocast():
|
||||
p.sampler = sd_samplers.create_sampler(p.sampler_name, p.sd_model)
|
||||
|
||||
if have_refiner:
|
||||
p.sampler.stop_at = max(1, int(shared.opts.sd_refiner_switch_at * p.steps - 1))
|
||||
|
||||
samples_ddim = p.sample(conditioning=p.c, unconditional_conditioning=p.uc, seeds=p.seeds, subseeds=p.subseeds, subseed_strength=p.subseed_strength, prompts=p.prompts)
|
||||
|
||||
if opts.sd_vae_decode_method != 'Full':
|
||||
p.extra_generation_params['VAE Decoder'] = opts.sd_vae_decode_method
|
||||
|
||||
if have_refiner:
|
||||
samples_ddim = p.run_refiner(samples_ddim)
|
||||
|
||||
x_samples_ddim = decode_latent_batch(p.sd_model, samples_ddim, target_device=devices.cpu, check_for_nans=True)
|
||||
|
||||
x_samples_ddim = torch.stack(x_samples_ddim).float()
|
||||
x_samples_ddim = torch.clamp((x_samples_ddim + 1.0) / 2.0, min=0.0, max=1.0)
|
||||
|
||||
@@ -817,6 +903,8 @@ def process_images_inner(p: StableDiffusionProcessing) -> Processed:
|
||||
def infotext(index=0, use_main_prompt=False):
|
||||
return create_infotext(p, p.prompts, p.seeds, p.subseeds, use_main_prompt=use_main_prompt, index=index, all_negative_prompts=p.negative_prompts)
|
||||
|
||||
save_samples = p.save_samples()
|
||||
|
||||
for i, x_sample in enumerate(x_samples_ddim):
|
||||
p.batch_index = i
|
||||
|
||||
@@ -824,7 +912,7 @@ def process_images_inner(p: StableDiffusionProcessing) -> Processed:
|
||||
x_sample = x_sample.astype(np.uint8)
|
||||
|
||||
if p.restore_faces:
|
||||
if opts.save and not p.do_not_save_samples and opts.save_images_before_face_restoration:
|
||||
if save_samples and opts.save_images_before_face_restoration:
|
||||
images.save_image(Image.fromarray(x_sample), p.outpath_samples, "", p.seeds[i], p.prompts[i], opts.samples_format, info=infotext(i), p=p, suffix="-before-face-restoration")
|
||||
|
||||
devices.torch_gc()
|
||||
@@ -838,16 +926,15 @@ def process_images_inner(p: StableDiffusionProcessing) -> Processed:
|
||||
pp = scripts.PostprocessImageArgs(image)
|
||||
p.scripts.postprocess_image(p, pp)
|
||||
image = pp.image
|
||||
|
||||
if p.color_corrections is not None and i < len(p.color_corrections):
|
||||
if opts.save and not p.do_not_save_samples and opts.save_images_before_color_correction:
|
||||
if save_samples and opts.save_images_before_color_correction:
|
||||
image_without_cc = apply_overlay(image, p.paste_to, i, p.overlay_images)
|
||||
images.save_image(image_without_cc, p.outpath_samples, "", p.seeds[i], p.prompts[i], opts.samples_format, info=infotext(i), p=p, suffix="-before-color-correction")
|
||||
image = apply_color_correction(p.color_corrections[i], image)
|
||||
|
||||
image = apply_overlay(image, p.paste_to, i, p.overlay_images)
|
||||
|
||||
if opts.samples_save and not p.do_not_save_samples:
|
||||
if save_samples:
|
||||
images.save_image(image, p.outpath_samples, "", p.seeds[i], p.prompts[i], opts.samples_format, info=infotext(i), p=p)
|
||||
|
||||
text = infotext(i)
|
||||
@@ -855,8 +942,7 @@ def process_images_inner(p: StableDiffusionProcessing) -> Processed:
|
||||
if opts.enable_pnginfo:
|
||||
image.info["parameters"] = text
|
||||
output_images.append(image)
|
||||
|
||||
if hasattr(p, 'mask_for_overlay') and p.mask_for_overlay and any([opts.save_mask, opts.save_mask_composite, opts.return_mask, opts.return_mask_composite]):
|
||||
if save_samples and hasattr(p, 'mask_for_overlay') and p.mask_for_overlay and any([opts.save_mask, opts.save_mask_composite, opts.return_mask, opts.return_mask_composite]):
|
||||
image_mask = p.mask_for_overlay.convert('RGB')
|
||||
image_mask_composite = Image.composite(image.convert('RGBA').convert('RGBa'), Image.new('RGBa', image.size), images.resize_image(2, p.mask_for_overlay, image.width, image.height).convert('L')).convert('RGBA')
|
||||
|
||||
@@ -892,7 +978,6 @@ def process_images_inner(p: StableDiffusionProcessing) -> Processed:
|
||||
grid.info["parameters"] = text
|
||||
output_images.insert(0, grid)
|
||||
index_of_first_image = 1
|
||||
|
||||
if opts.grid_save:
|
||||
images.save_image(grid, p.outpath_grids, "grid", p.all_seeds[0], p.all_prompts[0], opts.grid_format, info=infotext(use_main_prompt=True), short_filename=not opts.grid_extended_filename, p=p, grid=True)
|
||||
|
||||
@@ -935,7 +1020,7 @@ class StableDiffusionProcessingTxt2Img(StableDiffusionProcessing):
|
||||
cached_hr_uc = [None, None]
|
||||
cached_hr_c = [None, None]
|
||||
|
||||
def __init__(self, enable_hr: bool = False, denoising_strength: float = 0.75, firstphase_width: int = 0, firstphase_height: int = 0, hr_scale: float = 2.0, hr_upscaler: str = None, hr_second_pass_steps: int = 0, hr_resize_x: int = 0, hr_resize_y: int = 0, hr_sampler_name: str = None, hr_prompt: str = '', hr_negative_prompt: str = '', **kwargs):
|
||||
def __init__(self, enable_hr: bool = False, denoising_strength: float = 0.75, firstphase_width: int = 0, firstphase_height: int = 0, hr_scale: float = 2.0, hr_upscaler: str = None, hr_second_pass_steps: int = 0, hr_resize_x: int = 0, hr_resize_y: int = 0, hr_checkpoint_name: str = None, hr_sampler_name: str = None, hr_prompt: str = '', hr_negative_prompt: str = '', **kwargs):
|
||||
super().__init__(**kwargs)
|
||||
self.enable_hr = enable_hr
|
||||
self.denoising_strength = denoising_strength
|
||||
@@ -946,11 +1031,14 @@ class StableDiffusionProcessingTxt2Img(StableDiffusionProcessing):
|
||||
self.hr_resize_y = hr_resize_y
|
||||
self.hr_upscale_to_x = hr_resize_x
|
||||
self.hr_upscale_to_y = hr_resize_y
|
||||
self.hr_checkpoint_name = hr_checkpoint_name
|
||||
self.hr_checkpoint_info = None
|
||||
self.hr_sampler_name = hr_sampler_name
|
||||
self.hr_prompt = hr_prompt
|
||||
self.hr_negative_prompt = hr_negative_prompt
|
||||
self.all_hr_prompts = None
|
||||
self.all_hr_negative_prompts = None
|
||||
self.latent_scale_mode = None
|
||||
|
||||
if firstphase_width != 0 or firstphase_height != 0:
|
||||
self.hr_upscale_to_x = self.width
|
||||
@@ -973,6 +1061,14 @@ class StableDiffusionProcessingTxt2Img(StableDiffusionProcessing):
|
||||
|
||||
def init(self, all_prompts, all_seeds, all_subseeds):
|
||||
if self.enable_hr:
|
||||
if self.hr_checkpoint_name:
|
||||
self.hr_checkpoint_info = sd_models.get_closet_checkpoint_match(self.hr_checkpoint_name)
|
||||
|
||||
if self.hr_checkpoint_info is None:
|
||||
raise Exception(f'Could not find checkpoint with name {self.hr_checkpoint_name}')
|
||||
|
||||
self.extra_generation_params["Hires checkpoint"] = self.hr_checkpoint_info.short_title
|
||||
|
||||
if self.hr_sampler_name is not None and self.hr_sampler_name != self.sampler_name:
|
||||
self.extra_generation_params["Hires sampler"] = self.hr_sampler_name
|
||||
|
||||
@@ -982,6 +1078,11 @@ class StableDiffusionProcessingTxt2Img(StableDiffusionProcessing):
|
||||
if tuple(self.hr_negative_prompt) != tuple(self.negative_prompt):
|
||||
self.extra_generation_params["Hires negative prompt"] = self.hr_negative_prompt
|
||||
|
||||
self.latent_scale_mode = shared.latent_upscale_modes.get(self.hr_upscaler, None) if self.hr_upscaler is not None else shared.latent_upscale_modes.get(shared.latent_upscale_default_mode, "nearest")
|
||||
if self.enable_hr and self.latent_scale_mode is None:
|
||||
if not any(x.name == self.hr_upscaler for x in shared.sd_upscalers):
|
||||
raise Exception(f"could not find upscaler named {self.hr_upscaler}")
|
||||
|
||||
if opts.use_old_hires_fix_width_height and self.applied_old_hires_behavior_to != (self.width, self.height):
|
||||
self.hr_resize_x = self.width
|
||||
self.hr_resize_y = self.height
|
||||
@@ -1020,14 +1121,6 @@ class StableDiffusionProcessingTxt2Img(StableDiffusionProcessing):
|
||||
self.truncate_x = (self.hr_upscale_to_x - target_w) // opt_f
|
||||
self.truncate_y = (self.hr_upscale_to_y - target_h) // opt_f
|
||||
|
||||
# special case: the user has chosen to do nothing
|
||||
if self.hr_upscale_to_x == self.width and self.hr_upscale_to_y == self.height:
|
||||
self.enable_hr = False
|
||||
self.denoising_strength = None
|
||||
self.extra_generation_params.pop("Hires upscale", None)
|
||||
self.extra_generation_params.pop("Hires resize", None)
|
||||
return
|
||||
|
||||
if not state.processing_has_refined_job_count:
|
||||
if state.job_count == -1:
|
||||
state.job_count = self.n_iter
|
||||
@@ -1043,19 +1136,32 @@ class StableDiffusionProcessingTxt2Img(StableDiffusionProcessing):
|
||||
self.extra_generation_params["Hires upscaler"] = self.hr_upscaler
|
||||
|
||||
def sample(self, conditioning, unconditional_conditioning, seeds, subseeds, subseed_strength, prompts):
|
||||
self.sampler = sd_samplers.create_sampler(self.sampler_name, self.sd_model)
|
||||
|
||||
latent_scale_mode = shared.latent_upscale_modes.get(self.hr_upscaler, None) if self.hr_upscaler is not None else shared.latent_upscale_modes.get(shared.latent_upscale_default_mode, "nearest")
|
||||
if self.enable_hr and latent_scale_mode is None:
|
||||
if not any(x.name == self.hr_upscaler for x in shared.sd_upscalers):
|
||||
raise Exception(f"could not find upscaler named {self.hr_upscaler}")
|
||||
|
||||
x = create_random_tensors([opt_C, self.height // opt_f, self.width // opt_f], seeds=seeds, subseeds=subseeds, subseed_strength=self.subseed_strength, seed_resize_from_h=self.seed_resize_from_h, seed_resize_from_w=self.seed_resize_from_w, p=self)
|
||||
samples = self.sampler.sample(self, x, conditioning, unconditional_conditioning, image_conditioning=self.txt2img_image_conditioning(x))
|
||||
del x
|
||||
|
||||
if not self.enable_hr:
|
||||
return samples
|
||||
|
||||
if self.latent_scale_mode is None:
|
||||
decoded_samples = torch.stack(decode_latent_batch(self.sd_model, samples, target_device=devices.cpu, check_for_nans=True)).to(dtype=torch.float32)
|
||||
else:
|
||||
decoded_samples = None
|
||||
|
||||
current = shared.sd_model.sd_checkpoint_info
|
||||
try:
|
||||
if self.hr_checkpoint_info is not None:
|
||||
self.sampler = None
|
||||
sd_models.reload_model_weights(info=self.hr_checkpoint_info)
|
||||
devices.torch_gc()
|
||||
|
||||
return self.sample_hr_pass(samples, decoded_samples, seeds, subseeds, subseed_strength, prompts)
|
||||
finally:
|
||||
self.sampler = None
|
||||
sd_models.reload_model_weights(info=current)
|
||||
devices.torch_gc()
|
||||
|
||||
def sample_hr_pass(self, samples, decoded_samples, seeds, subseeds, subseed_strength, prompts):
|
||||
self.is_hr_pass = True
|
||||
|
||||
target_width = self.hr_upscale_to_x
|
||||
@@ -1064,7 +1170,7 @@ class StableDiffusionProcessingTxt2Img(StableDiffusionProcessing):
|
||||
def save_intermediate(image, index):
|
||||
"""saves image before applying hires fix, if enabled in options; takes as an argument either an image or batch with latent space images"""
|
||||
|
||||
if not opts.save or self.do_not_save_samples or not opts.save_images_before_highres_fix:
|
||||
if not self.save_samples() or not opts.save_images_before_highres_fix:
|
||||
return
|
||||
|
||||
if not isinstance(image, Image.Image):
|
||||
@@ -1073,11 +1179,18 @@ class StableDiffusionProcessingTxt2Img(StableDiffusionProcessing):
|
||||
info = create_infotext(self, self.all_prompts, self.all_seeds, self.all_subseeds, [], iteration=self.iteration, position_in_batch=index)
|
||||
images.save_image(image, self.outpath_samples, "", seeds[index], prompts[index], opts.samples_format, info=info, p=self, suffix="-before-highres-fix")
|
||||
|
||||
if latent_scale_mode is not None:
|
||||
img2img_sampler_name = self.hr_sampler_name or self.sampler_name
|
||||
|
||||
if self.sampler_name in ['PLMS', 'UniPC']: # PLMS/UniPC do not support img2img so we just silently switch to DDIM
|
||||
img2img_sampler_name = 'DDIM'
|
||||
|
||||
self.sampler = sd_samplers.create_sampler(img2img_sampler_name, self.sd_model)
|
||||
|
||||
if self.latent_scale_mode is not None:
|
||||
for i in range(samples.shape[0]):
|
||||
save_intermediate(samples, i)
|
||||
|
||||
samples = torch.nn.functional.interpolate(samples, size=(target_height // opt_f, target_width // opt_f), mode=latent_scale_mode["mode"], antialias=latent_scale_mode["antialias"])
|
||||
samples = torch.nn.functional.interpolate(samples, size=(target_height // opt_f, target_width // opt_f), mode=self.latent_scale_mode["mode"], antialias=self.latent_scale_mode["antialias"])
|
||||
|
||||
# Avoid making the inpainting conditioning unless necessary as
|
||||
# this does need some extra compute to decode / encode the image again.
|
||||
@@ -1086,7 +1199,6 @@ class StableDiffusionProcessingTxt2Img(StableDiffusionProcessing):
|
||||
else:
|
||||
image_conditioning = self.txt2img_image_conditioning(samples)
|
||||
else:
|
||||
decoded_samples = decode_first_stage(self.sd_model, samples)
|
||||
lowres_samples = torch.clamp((decoded_samples + 1.0) / 2.0, min=0.0, max=1.0)
|
||||
|
||||
batch_images = []
|
||||
@@ -1103,28 +1215,21 @@ class StableDiffusionProcessingTxt2Img(StableDiffusionProcessing):
|
||||
batch_images.append(image)
|
||||
|
||||
decoded_samples = torch.from_numpy(np.array(batch_images))
|
||||
decoded_samples = decoded_samples.to(shared.device)
|
||||
decoded_samples = 2. * decoded_samples - 1.
|
||||
decoded_samples = decoded_samples.to(shared.device, dtype=devices.dtype_vae)
|
||||
|
||||
samples = self.sd_model.get_first_stage_encoding(self.sd_model.encode_first_stage(decoded_samples))
|
||||
if opts.sd_vae_encode_method != 'Full':
|
||||
self.extra_generation_params['VAE Encoder'] = opts.sd_vae_encode_method
|
||||
samples = images_tensor_to_samples(decoded_samples, approximation_indexes.get(opts.sd_vae_encode_method))
|
||||
|
||||
image_conditioning = self.img2img_image_conditioning(decoded_samples, samples)
|
||||
|
||||
shared.state.nextjob()
|
||||
|
||||
img2img_sampler_name = self.hr_sampler_name or self.sampler_name
|
||||
|
||||
if self.sampler_name in ['PLMS', 'UniPC']: # PLMS/UniPC do not support img2img so we just silently switch to DDIM
|
||||
img2img_sampler_name = 'DDIM'
|
||||
|
||||
self.sampler = sd_samplers.create_sampler(img2img_sampler_name, self.sd_model)
|
||||
|
||||
samples = samples[:, :, self.truncate_y//2:samples.shape[2]-(self.truncate_y+1)//2, self.truncate_x//2:samples.shape[3]-(self.truncate_x+1)//2]
|
||||
|
||||
noise = create_random_tensors(samples.shape[1:], seeds=seeds, subseeds=subseeds, subseed_strength=subseed_strength, p=self)
|
||||
|
||||
# GC now before running the next img2img to prevent running out of memory
|
||||
x = None
|
||||
devices.torch_gc()
|
||||
|
||||
if not self.disable_extra_networks:
|
||||
@@ -1143,15 +1248,17 @@ class StableDiffusionProcessingTxt2Img(StableDiffusionProcessing):
|
||||
|
||||
sd_models.apply_token_merging(self.sd_model, self.get_token_merging_ratio())
|
||||
|
||||
decoded_samples = decode_latent_batch(self.sd_model, samples, target_device=devices.cpu, check_for_nans=True)
|
||||
|
||||
self.is_hr_pass = False
|
||||
|
||||
return samples
|
||||
return decoded_samples
|
||||
|
||||
def close(self):
|
||||
super().close()
|
||||
self.hr_c = None
|
||||
self.hr_uc = None
|
||||
if not opts.experimental_persistent_cond_cache:
|
||||
if not opts.persistent_cond_cache:
|
||||
StableDiffusionProcessingTxt2Img.cached_hr_uc = [None, None]
|
||||
StableDiffusionProcessingTxt2Img.cached_hr_c = [None, None]
|
||||
|
||||
@@ -1184,8 +1291,11 @@ class StableDiffusionProcessingTxt2Img(StableDiffusionProcessing):
|
||||
if self.hr_c is not None:
|
||||
return
|
||||
|
||||
self.hr_uc = self.get_conds_with_caching(prompt_parser.get_learned_conditioning, self.hr_negative_prompts, self.steps * self.step_multiplier, [self.cached_hr_uc, self.cached_uc], self.hr_extra_network_data)
|
||||
self.hr_c = self.get_conds_with_caching(prompt_parser.get_multicond_learned_conditioning, self.hr_prompts, self.steps * self.step_multiplier, [self.cached_hr_c, self.cached_c], self.hr_extra_network_data)
|
||||
hr_prompts = prompt_parser.SdConditioning(self.hr_prompts, width=self.hr_upscale_to_x, height=self.hr_upscale_to_y)
|
||||
hr_negative_prompts = prompt_parser.SdConditioning(self.hr_negative_prompts, width=self.hr_upscale_to_x, height=self.hr_upscale_to_y, is_negative_prompt=True)
|
||||
|
||||
self.hr_uc = self.get_conds_with_caching(prompt_parser.get_learned_conditioning, hr_negative_prompts, self.steps * self.step_multiplier, [self.cached_hr_uc, self.cached_uc], self.hr_extra_network_data)
|
||||
self.hr_c = self.get_conds_with_caching(prompt_parser.get_multicond_learned_conditioning, hr_prompts, self.steps * self.step_multiplier, [self.cached_hr_c, self.cached_c], self.hr_extra_network_data)
|
||||
|
||||
def setup_conds(self):
|
||||
super().setup_conds()
|
||||
@@ -1193,7 +1303,7 @@ class StableDiffusionProcessingTxt2Img(StableDiffusionProcessing):
|
||||
self.hr_uc = None
|
||||
self.hr_c = None
|
||||
|
||||
if self.enable_hr:
|
||||
if self.enable_hr and self.hr_checkpoint_info is None:
|
||||
if shared.opts.hires_fix_use_firstpass_conds:
|
||||
self.calculate_hr_conds()
|
||||
|
||||
@@ -1247,7 +1357,6 @@ class StableDiffusionProcessingImg2Img(StableDiffusionProcessing):
|
||||
self.image_conditioning = None
|
||||
|
||||
def init(self, all_prompts, all_seeds, all_subseeds):
|
||||
self.sampler = sd_samplers.create_sampler(self.sampler_name, self.sd_model)
|
||||
crop_region = None
|
||||
|
||||
image_mask = self.image_mask
|
||||
@@ -1344,10 +1453,13 @@ class StableDiffusionProcessingImg2Img(StableDiffusionProcessing):
|
||||
raise RuntimeError(f"bad number of images passed: {len(imgs)}; expecting {self.batch_size} or less")
|
||||
|
||||
image = torch.from_numpy(batch_images)
|
||||
image = 2. * image - 1.
|
||||
image = image.to(shared.device, dtype=devices.dtype_vae)
|
||||
|
||||
self.init_latent = self.sd_model.get_first_stage_encoding(self.sd_model.encode_first_stage(image))
|
||||
if opts.sd_vae_encode_method != 'Full':
|
||||
self.extra_generation_params['VAE Encoder'] = opts.sd_vae_encode_method
|
||||
|
||||
self.init_latent = images_tensor_to_samples(image, approximation_indexes.get(opts.sd_vae_encode_method), self.sd_model)
|
||||
devices.torch_gc()
|
||||
|
||||
if self.resize_mode == 3:
|
||||
self.init_latent = torch.nn.functional.interpolate(self.init_latent, size=(self.height // opt_f, self.width // opt_f), mode="bilinear")
|
||||
|
||||
+15
-10
@@ -19,8 +19,8 @@ prompt: (emphasized | scheduled | alternate | plain | WHITESPACE)*
|
||||
!emphasized: "(" prompt ")"
|
||||
| "(" prompt ":" prompt ")"
|
||||
| "[" prompt "]"
|
||||
scheduled: "[" [prompt ":"] prompt ":" [WHITESPACE] NUMBER "]"
|
||||
alternate: "[" prompt ("|" prompt)+ "]"
|
||||
scheduled: "[" [prompt ":"] prompt ":" [WHITESPACE] NUMBER [WHITESPACE] "]"
|
||||
alternate: "[" prompt ("|" [prompt])+ "]"
|
||||
WHITESPACE: /\s+/
|
||||
plain: /([^\\\[\]():|]|\\.)+/
|
||||
%import common.SIGNED_NUMBER -> NUMBER
|
||||
@@ -53,6 +53,10 @@ def get_learned_conditioning_prompt_schedules(prompts, steps):
|
||||
[[3, '((a][:b:c '], [10, '((a][:b:c d']]
|
||||
>>> g("[a|(b:1.1)]")
|
||||
[[1, 'a'], [2, '(b:1.1)'], [3, 'a'], [4, '(b:1.1)'], [5, 'a'], [6, '(b:1.1)'], [7, 'a'], [8, '(b:1.1)'], [9, 'a'], [10, '(b:1.1)']]
|
||||
>>> g("[fe|]male")
|
||||
[[1, 'female'], [2, 'male'], [3, 'female'], [4, 'male'], [5, 'female'], [6, 'male'], [7, 'female'], [8, 'male'], [9, 'female'], [10, 'male']]
|
||||
>>> g("[fe|||]male")
|
||||
[[1, 'female'], [2, 'male'], [3, 'male'], [4, 'male'], [5, 'female'], [6, 'male'], [7, 'male'], [8, 'male'], [9, 'female'], [10, 'male']]
|
||||
"""
|
||||
|
||||
def collect_steps(steps, tree):
|
||||
@@ -60,11 +64,11 @@ def get_learned_conditioning_prompt_schedules(prompts, steps):
|
||||
|
||||
class CollectSteps(lark.Visitor):
|
||||
def scheduled(self, tree):
|
||||
tree.children[-1] = float(tree.children[-1])
|
||||
if tree.children[-1] < 1:
|
||||
tree.children[-1] *= steps
|
||||
tree.children[-1] = min(steps, int(tree.children[-1]))
|
||||
res.append(tree.children[-1])
|
||||
tree.children[-2] = float(tree.children[-2])
|
||||
if tree.children[-2] < 1:
|
||||
tree.children[-2] *= steps
|
||||
tree.children[-2] = min(steps, int(tree.children[-2]))
|
||||
res.append(tree.children[-2])
|
||||
|
||||
def alternate(self, tree):
|
||||
res.extend(range(1, steps+1))
|
||||
@@ -75,10 +79,11 @@ def get_learned_conditioning_prompt_schedules(prompts, steps):
|
||||
def at_step(step, tree):
|
||||
class AtStep(lark.Transformer):
|
||||
def scheduled(self, args):
|
||||
before, after, _, when = args
|
||||
before, after, _, when, _ = args
|
||||
yield before or () if step <= when else after
|
||||
def alternate(self, args):
|
||||
yield next(args[(step - 1)%len(args)])
|
||||
args = ["" if not arg else arg for arg in args]
|
||||
yield args[(step - 1) % len(args)]
|
||||
def start(self, args):
|
||||
def flatten(x):
|
||||
if type(x) == str:
|
||||
@@ -333,7 +338,7 @@ re_attention = re.compile(r"""
|
||||
\\|
|
||||
\(|
|
||||
\[|
|
||||
:([+-]?[.\d]+)\)|
|
||||
:\s*([+-]?[.\d]+)\s*\)|
|
||||
\)|
|
||||
]|
|
||||
[^\\()\[\]:]+|
|
||||
|
||||
@@ -0,0 +1,102 @@
|
||||
"""RNG imitiating torch cuda randn on CPU. You are welcome.
|
||||
|
||||
Usage:
|
||||
|
||||
```
|
||||
g = Generator(seed=0)
|
||||
print(g.randn(shape=(3, 4)))
|
||||
```
|
||||
|
||||
Expected output:
|
||||
```
|
||||
[[-0.92466259 -0.42534415 -2.6438457 0.14518388]
|
||||
[-0.12086647 -0.57972564 -0.62285122 -0.32838709]
|
||||
[-1.07454231 -0.36314407 -1.67105067 2.26550497]]
|
||||
```
|
||||
"""
|
||||
|
||||
import numpy as np
|
||||
|
||||
philox_m = [0xD2511F53, 0xCD9E8D57]
|
||||
philox_w = [0x9E3779B9, 0xBB67AE85]
|
||||
|
||||
two_pow32_inv = np.array([2.3283064e-10], dtype=np.float32)
|
||||
two_pow32_inv_2pi = np.array([2.3283064e-10 * 6.2831855], dtype=np.float32)
|
||||
|
||||
|
||||
def uint32(x):
|
||||
"""Converts (N,) np.uint64 array into (2, N) np.unit32 array."""
|
||||
return x.view(np.uint32).reshape(-1, 2).transpose(1, 0)
|
||||
|
||||
|
||||
def philox4_round(counter, key):
|
||||
"""A single round of the Philox 4x32 random number generator."""
|
||||
|
||||
v1 = uint32(counter[0].astype(np.uint64) * philox_m[0])
|
||||
v2 = uint32(counter[2].astype(np.uint64) * philox_m[1])
|
||||
|
||||
counter[0] = v2[1] ^ counter[1] ^ key[0]
|
||||
counter[1] = v2[0]
|
||||
counter[2] = v1[1] ^ counter[3] ^ key[1]
|
||||
counter[3] = v1[0]
|
||||
|
||||
|
||||
def philox4_32(counter, key, rounds=10):
|
||||
"""Generates 32-bit random numbers using the Philox 4x32 random number generator.
|
||||
|
||||
Parameters:
|
||||
counter (numpy.ndarray): A 4xN array of 32-bit integers representing the counter values (offset into generation).
|
||||
key (numpy.ndarray): A 2xN array of 32-bit integers representing the key values (seed).
|
||||
rounds (int): The number of rounds to perform.
|
||||
|
||||
Returns:
|
||||
numpy.ndarray: A 4xN array of 32-bit integers containing the generated random numbers.
|
||||
"""
|
||||
|
||||
for _ in range(rounds - 1):
|
||||
philox4_round(counter, key)
|
||||
|
||||
key[0] = key[0] + philox_w[0]
|
||||
key[1] = key[1] + philox_w[1]
|
||||
|
||||
philox4_round(counter, key)
|
||||
return counter
|
||||
|
||||
|
||||
def box_muller(x, y):
|
||||
"""Returns just the first out of two numbers generated by Box–Muller transform algorithm."""
|
||||
u = x * two_pow32_inv + two_pow32_inv / 2
|
||||
v = y * two_pow32_inv_2pi + two_pow32_inv_2pi / 2
|
||||
|
||||
s = np.sqrt(-2.0 * np.log(u))
|
||||
|
||||
r1 = s * np.sin(v)
|
||||
return r1.astype(np.float32)
|
||||
|
||||
|
||||
class Generator:
|
||||
"""RNG that produces same outputs as torch.randn(..., device='cuda') on CPU"""
|
||||
|
||||
def __init__(self, seed):
|
||||
self.seed = seed
|
||||
self.offset = 0
|
||||
|
||||
def randn(self, shape):
|
||||
"""Generate a sequence of n standard normal random variables using the Philox 4x32 random number generator and the Box-Muller transform."""
|
||||
|
||||
n = 1
|
||||
for x in shape:
|
||||
n *= x
|
||||
|
||||
counter = np.zeros((4, n), dtype=np.uint32)
|
||||
counter[0] = self.offset
|
||||
counter[2] = np.arange(n, dtype=np.uint32) # up to 2^32 numbers can be generated - if you want more you'd need to spill into counter[3]
|
||||
self.offset += 1
|
||||
|
||||
key = np.empty(n, dtype=np.uint64)
|
||||
key.fill(self.seed)
|
||||
key = uint32(key)
|
||||
|
||||
g = philox4_32(counter, key)
|
||||
|
||||
return box_muller(g[0], g[1]).reshape(shape) # discard g[2] and g[3]
|
||||
@@ -631,49 +631,3 @@ def reload_script_body_only():
|
||||
|
||||
|
||||
reload_scripts = load_scripts # compatibility alias
|
||||
|
||||
|
||||
def add_classes_to_gradio_component(comp):
|
||||
"""
|
||||
this adds gradio-* to the component for css styling (ie gradio-button to gr.Button), as well as some others
|
||||
"""
|
||||
|
||||
comp.elem_classes = [f"gradio-{comp.get_block_name()}", *(comp.elem_classes or [])]
|
||||
|
||||
if getattr(comp, 'multiselect', False):
|
||||
comp.elem_classes.append('multiselect')
|
||||
|
||||
|
||||
|
||||
def IOComponent_init(self, *args, **kwargs):
|
||||
if scripts_current is not None:
|
||||
scripts_current.before_component(self, **kwargs)
|
||||
|
||||
script_callbacks.before_component_callback(self, **kwargs)
|
||||
|
||||
res = original_IOComponent_init(self, *args, **kwargs)
|
||||
|
||||
add_classes_to_gradio_component(self)
|
||||
|
||||
script_callbacks.after_component_callback(self, **kwargs)
|
||||
|
||||
if scripts_current is not None:
|
||||
scripts_current.after_component(self, **kwargs)
|
||||
|
||||
return res
|
||||
|
||||
|
||||
original_IOComponent_init = gr.components.IOComponent.__init__
|
||||
gr.components.IOComponent.__init__ = IOComponent_init
|
||||
|
||||
|
||||
def BlockContext_init(self, *args, **kwargs):
|
||||
res = original_BlockContext_init(self, *args, **kwargs)
|
||||
|
||||
add_classes_to_gradio_component(self)
|
||||
|
||||
return res
|
||||
|
||||
|
||||
original_BlockContext_init = gr.blocks.BlockContext.__init__
|
||||
gr.blocks.BlockContext.__init__ = BlockContext_init
|
||||
|
||||
@@ -3,8 +3,31 @@ import open_clip
|
||||
import torch
|
||||
import transformers.utils.hub
|
||||
|
||||
from modules import shared
|
||||
|
||||
class DisableInitialization:
|
||||
|
||||
class ReplaceHelper:
|
||||
def __init__(self):
|
||||
self.replaced = []
|
||||
|
||||
def replace(self, obj, field, func):
|
||||
original = getattr(obj, field, None)
|
||||
if original is None:
|
||||
return None
|
||||
|
||||
self.replaced.append((obj, field, original))
|
||||
setattr(obj, field, func)
|
||||
|
||||
return original
|
||||
|
||||
def restore(self):
|
||||
for obj, field, original in self.replaced:
|
||||
setattr(obj, field, original)
|
||||
|
||||
self.replaced.clear()
|
||||
|
||||
|
||||
class DisableInitialization(ReplaceHelper):
|
||||
"""
|
||||
When an object of this class enters a `with` block, it starts:
|
||||
- preventing torch's layer initialization functions from working
|
||||
@@ -21,7 +44,7 @@ class DisableInitialization:
|
||||
"""
|
||||
|
||||
def __init__(self, disable_clip=True):
|
||||
self.replaced = []
|
||||
super().__init__()
|
||||
self.disable_clip = disable_clip
|
||||
|
||||
def replace(self, obj, field, func):
|
||||
@@ -86,8 +109,81 @@ class DisableInitialization:
|
||||
self.transformers_utils_hub_get_from_cache = self.replace(transformers.utils.hub, 'get_from_cache', transformers_utils_hub_get_from_cache)
|
||||
|
||||
def __exit__(self, exc_type, exc_val, exc_tb):
|
||||
for obj, field, original in self.replaced:
|
||||
setattr(obj, field, original)
|
||||
self.restore()
|
||||
|
||||
self.replaced.clear()
|
||||
|
||||
class InitializeOnMeta(ReplaceHelper):
|
||||
"""
|
||||
Context manager that causes all parameters for linear/conv2d/mha layers to be allocated on meta device,
|
||||
which results in those parameters having no values and taking no memory. model.to() will be broken and
|
||||
will need to be repaired by using LoadStateDictOnMeta below when loading params from state dict.
|
||||
|
||||
Usage:
|
||||
```
|
||||
with sd_disable_initialization.InitializeOnMeta():
|
||||
sd_model = instantiate_from_config(sd_config.model)
|
||||
```
|
||||
"""
|
||||
|
||||
def __enter__(self):
|
||||
if shared.cmd_opts.disable_model_loading_ram_optimization:
|
||||
return
|
||||
|
||||
def set_device(x):
|
||||
x["device"] = "meta"
|
||||
return x
|
||||
|
||||
linear_init = self.replace(torch.nn.Linear, '__init__', lambda *args, **kwargs: linear_init(*args, **set_device(kwargs)))
|
||||
conv2d_init = self.replace(torch.nn.Conv2d, '__init__', lambda *args, **kwargs: conv2d_init(*args, **set_device(kwargs)))
|
||||
mha_init = self.replace(torch.nn.MultiheadAttention, '__init__', lambda *args, **kwargs: mha_init(*args, **set_device(kwargs)))
|
||||
self.replace(torch.nn.Module, 'to', lambda *args, **kwargs: None)
|
||||
|
||||
def __exit__(self, exc_type, exc_val, exc_tb):
|
||||
self.restore()
|
||||
|
||||
|
||||
class LoadStateDictOnMeta(ReplaceHelper):
|
||||
"""
|
||||
Context manager that allows to read parameters from state_dict into a model that has some of its parameters in the meta device.
|
||||
As those parameters are read from state_dict, they will be deleted from it, so by the end state_dict will be mostly empty, to save memory.
|
||||
Meant to be used together with InitializeOnMeta above.
|
||||
|
||||
Usage:
|
||||
```
|
||||
with sd_disable_initialization.LoadStateDictOnMeta(state_dict):
|
||||
model.load_state_dict(state_dict, strict=False)
|
||||
```
|
||||
"""
|
||||
|
||||
def __init__(self, state_dict, device):
|
||||
super().__init__()
|
||||
self.state_dict = state_dict
|
||||
self.device = device
|
||||
|
||||
def __enter__(self):
|
||||
if shared.cmd_opts.disable_model_loading_ram_optimization:
|
||||
return
|
||||
|
||||
sd = self.state_dict
|
||||
device = self.device
|
||||
|
||||
def load_from_state_dict(original, self, state_dict, prefix, *args, **kwargs):
|
||||
params = [(name, param) for name, param in self._parameters.items() if param is not None and param.is_meta]
|
||||
|
||||
for name, param in params:
|
||||
if param.is_meta:
|
||||
self._parameters[name] = torch.nn.parameter.Parameter(torch.zeros_like(param, device=device), requires_grad=param.requires_grad)
|
||||
|
||||
original(self, state_dict, prefix, *args, **kwargs)
|
||||
|
||||
for name, _ in params:
|
||||
key = prefix + name
|
||||
if key in sd:
|
||||
del sd[key]
|
||||
|
||||
linear_load_from_state_dict = self.replace(torch.nn.Linear, '_load_from_state_dict', lambda *args, **kwargs: load_from_state_dict(linear_load_from_state_dict, *args, **kwargs))
|
||||
conv2d_load_from_state_dict = self.replace(torch.nn.Conv2d, '_load_from_state_dict', lambda *args, **kwargs: load_from_state_dict(conv2d_load_from_state_dict, *args, **kwargs))
|
||||
mha_load_from_state_dict = self.replace(torch.nn.MultiheadAttention, '_load_from_state_dict', lambda *args, **kwargs: load_from_state_dict(mha_load_from_state_dict, *args, **kwargs))
|
||||
|
||||
def __exit__(self, exc_type, exc_val, exc_tb):
|
||||
self.restore()
|
||||
|
||||
+15
-9
@@ -2,11 +2,10 @@ import torch
|
||||
from torch.nn.functional import silu
|
||||
from types import MethodType
|
||||
|
||||
import modules.textual_inversion.textual_inversion
|
||||
from modules import devices, sd_hijack_optimizations, shared, script_callbacks, errors, sd_unet
|
||||
from modules.hypernetworks import hypernetwork
|
||||
from modules.shared import cmd_opts
|
||||
from modules import sd_hijack_clip, sd_hijack_open_clip, sd_hijack_unet, sd_hijack_xlmr, xlmr
|
||||
from modules import sd_hijack_clip, sd_hijack_open_clip, sd_hijack_unet, sd_hijack_xlmr, xlmr, sd_hijack_inpainting
|
||||
|
||||
import ldm.modules.attention
|
||||
import ldm.modules.diffusionmodules.model
|
||||
@@ -30,8 +29,12 @@ ldm.modules.attention.MemoryEfficientCrossAttention = ldm.modules.attention.Cros
|
||||
ldm.modules.attention.BasicTransformerBlock.ATTENTION_MODES["softmax-xformers"] = ldm.modules.attention.CrossAttention
|
||||
|
||||
# silence new console spam from SD2
|
||||
ldm.modules.attention.print = lambda *args: None
|
||||
ldm.modules.diffusionmodules.model.print = lambda *args: None
|
||||
ldm.modules.attention.print = shared.ldm_print
|
||||
ldm.modules.diffusionmodules.model.print = shared.ldm_print
|
||||
ldm.util.print = shared.ldm_print
|
||||
ldm.models.diffusion.ddpm.print = shared.ldm_print
|
||||
|
||||
sd_hijack_inpainting.do_inpainting_hijack()
|
||||
|
||||
optimizers = []
|
||||
current_optimizer: sd_hijack_optimizations.SdOptimization = None
|
||||
@@ -164,12 +167,13 @@ class StableDiffusionModelHijack:
|
||||
clip = None
|
||||
optimization_method = None
|
||||
|
||||
embedding_db = modules.textual_inversion.textual_inversion.EmbeddingDatabase()
|
||||
|
||||
def __init__(self):
|
||||
import modules.textual_inversion.textual_inversion
|
||||
|
||||
self.extra_generation_params = {}
|
||||
self.comments = []
|
||||
|
||||
self.embedding_db = modules.textual_inversion.textual_inversion.EmbeddingDatabase()
|
||||
self.embedding_db.add_embedding_dir(cmd_opts.embeddings_dir)
|
||||
|
||||
def apply_optimizations(self, option=None):
|
||||
@@ -197,7 +201,7 @@ class StableDiffusionModelHijack:
|
||||
conditioner.embedders[i] = sd_hijack_clip.FrozenCLIPEmbedderForSDXLWithCustomWords(embedder, self)
|
||||
text_cond_models.append(conditioner.embedders[i])
|
||||
if typename == 'FrozenOpenCLIPEmbedder2':
|
||||
embedder.model.token_embedding = EmbeddingsWithFixes(embedder.model.token_embedding, self)
|
||||
embedder.model.token_embedding = EmbeddingsWithFixes(embedder.model.token_embedding, self, textual_inversion_key='clip_g')
|
||||
conditioner.embedders[i] = sd_hijack_open_clip.FrozenOpenCLIPEmbedder2WithCustomWords(embedder, self)
|
||||
text_cond_models.append(conditioner.embedders[i])
|
||||
|
||||
@@ -292,10 +296,11 @@ class StableDiffusionModelHijack:
|
||||
|
||||
|
||||
class EmbeddingsWithFixes(torch.nn.Module):
|
||||
def __init__(self, wrapped, embeddings):
|
||||
def __init__(self, wrapped, embeddings, textual_inversion_key='clip_l'):
|
||||
super().__init__()
|
||||
self.wrapped = wrapped
|
||||
self.embeddings = embeddings
|
||||
self.textual_inversion_key = textual_inversion_key
|
||||
|
||||
def forward(self, input_ids):
|
||||
batch_fixes = self.embeddings.fixes
|
||||
@@ -309,7 +314,8 @@ class EmbeddingsWithFixes(torch.nn.Module):
|
||||
vecs = []
|
||||
for fixes, tensor in zip(batch_fixes, inputs_embeds):
|
||||
for offset, embedding in fixes:
|
||||
emb = devices.cond_cast_unet(embedding.vec)
|
||||
vec = embedding.vec[self.textual_inversion_key] if isinstance(embedding.vec, dict) else embedding.vec
|
||||
emb = devices.cond_cast_unet(vec)
|
||||
emb_len = min(tensor.shape[0] - offset - 1, emb.shape[0])
|
||||
tensor = torch.cat([tensor[0:offset + 1], emb[0:emb_len], tensor[offset + 1 + emb_len:]])
|
||||
|
||||
|
||||
@@ -161,7 +161,7 @@ class FrozenCLIPEmbedderWithCustomWordsBase(torch.nn.Module):
|
||||
position += 1
|
||||
continue
|
||||
|
||||
emb_len = int(embedding.vec.shape[0])
|
||||
emb_len = int(embedding.vectors)
|
||||
if len(chunk.tokens) + emb_len > self.chunk_length:
|
||||
next_chunk()
|
||||
|
||||
@@ -245,6 +245,8 @@ class FrozenCLIPEmbedderWithCustomWordsBase(torch.nn.Module):
|
||||
hashes.append(f"{name}: {shorthash}")
|
||||
|
||||
if hashes:
|
||||
if self.hijack.extra_generation_params.get("TI hashes"):
|
||||
hashes.append(self.hijack.extra_generation_params.get("TI hashes"))
|
||||
self.hijack.extra_generation_params["TI hashes"] = ", ".join(hashes)
|
||||
|
||||
if getattr(self.wrapped, 'return_pooled', False):
|
||||
|
||||
@@ -92,6 +92,4 @@ def p_sample_plms(self, x, c, t, index, repeat_noise=False, use_original_steps=F
|
||||
|
||||
|
||||
def do_inpainting_hijack():
|
||||
# p_sample_plms is needed because PLMS can't work with dicts as conditionings
|
||||
|
||||
ldm.models.diffusion.plms.PLMSSampler.p_sample_plms = p_sample_plms
|
||||
|
||||
@@ -256,9 +256,9 @@ def split_cross_attention_forward(self, x, context=None, mask=None, **kwargs):
|
||||
raise RuntimeError(f'Not enough memory, use lower resolution (max approx. {max_res}x{max_res}). '
|
||||
f'Need: {mem_required / 64 / gb:0.1f}GB free, Have:{mem_free_total / gb:0.1f}GB free')
|
||||
|
||||
slice_size = q.shape[1] // steps if (q.shape[1] % steps) == 0 else q.shape[1]
|
||||
slice_size = q.shape[1] // steps
|
||||
for i in range(0, q.shape[1], slice_size):
|
||||
end = i + slice_size
|
||||
end = min(i + slice_size, q.shape[1])
|
||||
s1 = einsum('b i d, b j d -> b i j', q[:, i:end], k)
|
||||
|
||||
s2 = s1.softmax(dim=-1, dtype=q.dtype)
|
||||
|
||||
+173
-53
@@ -14,8 +14,7 @@ import ldm.modules.midas as midas
|
||||
|
||||
from ldm.util import instantiate_from_config
|
||||
|
||||
from modules import paths, shared, modelloader, devices, script_callbacks, sd_vae, sd_disable_initialization, errors, hashes, sd_models_config, sd_unet, sd_models_xl
|
||||
from modules.sd_hijack_inpainting import do_inpainting_hijack
|
||||
from modules import paths, shared, modelloader, devices, script_callbacks, sd_vae, sd_disable_initialization, errors, hashes, sd_models_config, sd_unet, sd_models_xl, cache
|
||||
from modules.timer import Timer
|
||||
import tomesd
|
||||
|
||||
@@ -33,6 +32,8 @@ class CheckpointInfo:
|
||||
self.filename = filename
|
||||
abspath = os.path.abspath(filename)
|
||||
|
||||
self.is_safetensors = os.path.splitext(filename)[1].lower() == ".safetensors"
|
||||
|
||||
if shared.cmd_opts.ckpt_dir is not None and abspath.startswith(shared.cmd_opts.ckpt_dir):
|
||||
name = abspath.replace(shared.cmd_opts.ckpt_dir, '')
|
||||
elif abspath.startswith(model_path):
|
||||
@@ -43,6 +44,19 @@ class CheckpointInfo:
|
||||
if name.startswith("\\") or name.startswith("/"):
|
||||
name = name[1:]
|
||||
|
||||
def read_metadata():
|
||||
metadata = read_metadata_from_safetensors(filename)
|
||||
self.modelspec_thumbnail = metadata.pop('modelspec.thumbnail', None)
|
||||
|
||||
return metadata
|
||||
|
||||
self.metadata = {}
|
||||
if self.is_safetensors:
|
||||
try:
|
||||
self.metadata = cache.cached_data_for_file('safetensors-metadata', "checkpoint/" + name, filename, read_metadata)
|
||||
except Exception as e:
|
||||
errors.display(e, f"reading metadata for {filename}")
|
||||
|
||||
self.name = name
|
||||
self.name_for_extra = os.path.splitext(os.path.basename(filename))[0]
|
||||
self.model_name = os.path.splitext(name.replace("/", "_").replace("\\", "_"))[0]
|
||||
@@ -52,17 +66,9 @@ class CheckpointInfo:
|
||||
self.shorthash = self.sha256[0:10] if self.sha256 else None
|
||||
|
||||
self.title = name if self.shorthash is None else f'{name} [{self.shorthash}]'
|
||||
self.short_title = self.name_for_extra if self.shorthash is None else f'{self.name_for_extra} [{self.shorthash}]'
|
||||
|
||||
self.ids = [self.hash, self.model_name, self.title, name, f'{name} [{self.hash}]'] + ([self.shorthash, self.sha256, f'{self.name} [{self.shorthash}]'] if self.shorthash else [])
|
||||
|
||||
self.metadata = {}
|
||||
|
||||
_, ext = os.path.splitext(self.filename)
|
||||
if ext.lower() == ".safetensors":
|
||||
try:
|
||||
self.metadata = read_metadata_from_safetensors(filename)
|
||||
except Exception as e:
|
||||
errors.display(e, f"reading checkpoint metadata: {filename}")
|
||||
self.ids = [self.hash, self.model_name, self.title, name, self.name_for_extra, f'{name} [{self.hash}]'] + ([self.shorthash, self.sha256, f'{self.name} [{self.shorthash}]'] if self.shorthash else [])
|
||||
|
||||
def register(self):
|
||||
checkpoints_list[self.title] = self
|
||||
@@ -79,8 +85,9 @@ class CheckpointInfo:
|
||||
if self.shorthash not in self.ids:
|
||||
self.ids += [self.shorthash, self.sha256, f'{self.name} [{self.shorthash}]']
|
||||
|
||||
checkpoints_list.pop(self.title)
|
||||
checkpoints_list.pop(self.title, None)
|
||||
self.title = f'{self.name} [{self.shorthash}]'
|
||||
self.short_title = f'{self.name_for_extra} [{self.shorthash}]'
|
||||
self.register()
|
||||
|
||||
return self.shorthash
|
||||
@@ -101,14 +108,8 @@ def setup_model():
|
||||
enable_midas_autodownload()
|
||||
|
||||
|
||||
def checkpoint_tiles():
|
||||
def convert(name):
|
||||
return int(name) if name.isdigit() else name.lower()
|
||||
|
||||
def alphanumeric_key(key):
|
||||
return [convert(c) for c in re.split('([0-9]+)', key)]
|
||||
|
||||
return sorted([x.title for x in checkpoints_list.values()], key=alphanumeric_key)
|
||||
def checkpoint_tiles(use_short=False):
|
||||
return [x.short_title if use_short else x.title for x in checkpoints_list.values()]
|
||||
|
||||
|
||||
def list_models():
|
||||
@@ -131,11 +132,14 @@ def list_models():
|
||||
elif cmd_ckpt is not None and cmd_ckpt != shared.default_sd_model_file:
|
||||
print(f"Checkpoint in --ckpt argument not found (Possible it was moved to {model_path}: {cmd_ckpt}", file=sys.stderr)
|
||||
|
||||
for filename in sorted(model_list, key=str.lower):
|
||||
for filename in model_list:
|
||||
checkpoint_info = CheckpointInfo(filename)
|
||||
checkpoint_info.register()
|
||||
|
||||
|
||||
re_strip_checksum = re.compile(r"\s*\[[^]]+]\s*$")
|
||||
|
||||
|
||||
def get_closet_checkpoint_match(search_string):
|
||||
checkpoint_info = checkpoint_aliases.get(search_string, None)
|
||||
if checkpoint_info is not None:
|
||||
@@ -145,6 +149,11 @@ def get_closet_checkpoint_match(search_string):
|
||||
if found:
|
||||
return found[0]
|
||||
|
||||
search_string_without_checksum = re.sub(re_strip_checksum, '', search_string)
|
||||
found = sorted([info for info in checkpoints_list.values() if search_string_without_checksum in info.title], key=lambda x: len(x.title))
|
||||
if found:
|
||||
return found[0]
|
||||
|
||||
return None
|
||||
|
||||
|
||||
@@ -280,11 +289,27 @@ def get_checkpoint_state_dict(checkpoint_info: CheckpointInfo, timer):
|
||||
return res
|
||||
|
||||
|
||||
class SkipWritingToConfig:
|
||||
"""This context manager prevents load_model_weights from writing checkpoint name to the config when it loads weight."""
|
||||
|
||||
skip = False
|
||||
previous = None
|
||||
|
||||
def __enter__(self):
|
||||
self.previous = SkipWritingToConfig.skip
|
||||
SkipWritingToConfig.skip = True
|
||||
return self
|
||||
|
||||
def __exit__(self, exc_type, exc_value, exc_traceback):
|
||||
SkipWritingToConfig.skip = self.previous
|
||||
|
||||
|
||||
def load_model_weights(model, checkpoint_info: CheckpointInfo, state_dict, timer):
|
||||
sd_model_hash = checkpoint_info.calculate_shorthash()
|
||||
timer.record("calculate hash")
|
||||
|
||||
shared.opts.data["sd_model_checkpoint"] = checkpoint_info.title
|
||||
if not SkipWritingToConfig.skip:
|
||||
shared.opts.data["sd_model_checkpoint"] = checkpoint_info.title
|
||||
|
||||
if state_dict is None:
|
||||
state_dict = get_checkpoint_state_dict(checkpoint_info, timer)
|
||||
@@ -297,12 +322,13 @@ def load_model_weights(model, checkpoint_info: CheckpointInfo, state_dict, timer
|
||||
sd_models_xl.extend_sdxl(model)
|
||||
|
||||
model.load_state_dict(state_dict, strict=False)
|
||||
del state_dict
|
||||
timer.record("apply weights to model")
|
||||
|
||||
if shared.opts.sd_checkpoint_cache > 0:
|
||||
# cache newly loaded model
|
||||
checkpoints_loaded[checkpoint_info] = model.state_dict().copy()
|
||||
checkpoints_loaded[checkpoint_info] = state_dict
|
||||
|
||||
del state_dict
|
||||
|
||||
if shared.cmd_opts.opt_channelslast:
|
||||
model.to(memory_format=torch.channels_last)
|
||||
@@ -423,6 +449,7 @@ sdxl_refiner_clip_weight = 'conditioner.embedders.0.model.ln_final.weight'
|
||||
class SdModelData:
|
||||
def __init__(self):
|
||||
self.sd_model = None
|
||||
self.loaded_sd_models = []
|
||||
self.was_loaded_at_least_once = False
|
||||
self.lock = threading.Lock()
|
||||
|
||||
@@ -437,6 +464,7 @@ class SdModelData:
|
||||
|
||||
try:
|
||||
load_model()
|
||||
|
||||
except Exception as e:
|
||||
errors.display(e, "loading stable diffusion model", full_traceback=True)
|
||||
print("", file=sys.stderr)
|
||||
@@ -448,11 +476,24 @@ class SdModelData:
|
||||
def set_sd_model(self, v):
|
||||
self.sd_model = v
|
||||
|
||||
try:
|
||||
self.loaded_sd_models.remove(v)
|
||||
except ValueError:
|
||||
pass
|
||||
|
||||
if v is not None:
|
||||
self.loaded_sd_models.insert(0, v)
|
||||
|
||||
|
||||
model_data = SdModelData()
|
||||
|
||||
|
||||
def get_empty_cond(sd_model):
|
||||
from modules import extra_networks, processing
|
||||
|
||||
p = processing.StableDiffusionProcessingTxt2Img()
|
||||
extra_networks.activate(p, {})
|
||||
|
||||
if hasattr(sd_model, 'conditioner'):
|
||||
d = sd_model.get_learned_conditioning([""])
|
||||
return d['crossattn']
|
||||
@@ -460,20 +501,43 @@ def get_empty_cond(sd_model):
|
||||
return sd_model.cond_stage_model([""])
|
||||
|
||||
|
||||
def send_model_to_cpu(m):
|
||||
from modules import lowvram
|
||||
|
||||
if shared.cmd_opts.lowvram or shared.cmd_opts.medvram:
|
||||
lowvram.send_everything_to_cpu()
|
||||
else:
|
||||
m.to(devices.cpu)
|
||||
|
||||
devices.torch_gc()
|
||||
|
||||
|
||||
def send_model_to_device(m):
|
||||
from modules import lowvram
|
||||
|
||||
if shared.cmd_opts.lowvram or shared.cmd_opts.medvram:
|
||||
lowvram.setup_for_low_vram(m, shared.cmd_opts.medvram)
|
||||
else:
|
||||
m.to(shared.device)
|
||||
|
||||
|
||||
def send_model_to_trash(m):
|
||||
m.to(device="meta")
|
||||
devices.torch_gc()
|
||||
|
||||
|
||||
def load_model(checkpoint_info=None, already_loaded_state_dict=None):
|
||||
from modules import lowvram, sd_hijack
|
||||
from modules import sd_hijack
|
||||
checkpoint_info = checkpoint_info or select_checkpoint()
|
||||
|
||||
timer = Timer()
|
||||
|
||||
if model_data.sd_model:
|
||||
sd_hijack.model_hijack.undo_hijack(model_data.sd_model)
|
||||
send_model_to_trash(model_data.sd_model)
|
||||
model_data.sd_model = None
|
||||
gc.collect()
|
||||
devices.torch_gc()
|
||||
|
||||
do_inpainting_hijack()
|
||||
|
||||
timer = Timer()
|
||||
timer.record("unload existing model")
|
||||
|
||||
if already_loaded_state_dict is not None:
|
||||
state_dict = already_loaded_state_dict
|
||||
@@ -495,25 +559,27 @@ def load_model(checkpoint_info=None, already_loaded_state_dict=None):
|
||||
sd_model = None
|
||||
try:
|
||||
with sd_disable_initialization.DisableInitialization(disable_clip=clip_is_included_into_sd or shared.cmd_opts.do_not_download_clip):
|
||||
sd_model = instantiate_from_config(sd_config.model)
|
||||
except Exception:
|
||||
pass
|
||||
with sd_disable_initialization.InitializeOnMeta():
|
||||
sd_model = instantiate_from_config(sd_config.model)
|
||||
|
||||
except Exception as e:
|
||||
errors.display(e, "creating model quickly", full_traceback=True)
|
||||
|
||||
if sd_model is None:
|
||||
print('Failed to create model quickly; will retry using slow method.', file=sys.stderr)
|
||||
sd_model = instantiate_from_config(sd_config.model)
|
||||
|
||||
with sd_disable_initialization.InitializeOnMeta():
|
||||
sd_model = instantiate_from_config(sd_config.model)
|
||||
|
||||
sd_model.used_config = checkpoint_config
|
||||
|
||||
timer.record("create model")
|
||||
|
||||
load_model_weights(sd_model, checkpoint_info, state_dict, timer)
|
||||
|
||||
if shared.cmd_opts.lowvram or shared.cmd_opts.medvram:
|
||||
lowvram.setup_for_low_vram(sd_model, shared.cmd_opts.medvram)
|
||||
else:
|
||||
sd_model.to(shared.device)
|
||||
with sd_disable_initialization.LoadStateDictOnMeta(state_dict, devices.cpu):
|
||||
load_model_weights(sd_model, checkpoint_info, state_dict, timer)
|
||||
timer.record("load weights from state dict")
|
||||
|
||||
send_model_to_device(sd_model)
|
||||
timer.record("move model to device")
|
||||
|
||||
sd_hijack.model_hijack.hijack(sd_model)
|
||||
@@ -521,7 +587,7 @@ def load_model(checkpoint_info=None, already_loaded_state_dict=None):
|
||||
timer.record("hijack")
|
||||
|
||||
sd_model.eval()
|
||||
model_data.sd_model = sd_model
|
||||
model_data.set_sd_model(sd_model)
|
||||
model_data.was_loaded_at_least_once = True
|
||||
|
||||
sd_hijack.model_hijack.embedding_db.load_textual_inversion_embeddings(force_reload=True) # Reload embeddings after model load as they may or may not fit the model
|
||||
@@ -542,10 +608,61 @@ def load_model(checkpoint_info=None, already_loaded_state_dict=None):
|
||||
return sd_model
|
||||
|
||||
|
||||
def reuse_model_from_already_loaded(sd_model, checkpoint_info, timer):
|
||||
"""
|
||||
Checks if the desired checkpoint from checkpoint_info is not already loaded in model_data.loaded_sd_models.
|
||||
If it is loaded, returns that (moving it to GPU if necessary, and moving the currently loadded model to CPU if necessary).
|
||||
If not, returns the model that can be used to load weights from checkpoint_info's file.
|
||||
If no such model exists, returns None.
|
||||
Additionaly deletes loaded models that are over the limit set in settings (sd_checkpoints_limit).
|
||||
"""
|
||||
|
||||
already_loaded = None
|
||||
for i in reversed(range(len(model_data.loaded_sd_models))):
|
||||
loaded_model = model_data.loaded_sd_models[i]
|
||||
if loaded_model.sd_checkpoint_info.filename == checkpoint_info.filename:
|
||||
already_loaded = loaded_model
|
||||
continue
|
||||
|
||||
if len(model_data.loaded_sd_models) > shared.opts.sd_checkpoints_limit > 0:
|
||||
print(f"Unloading model {len(model_data.loaded_sd_models)} over the limit of {shared.opts.sd_checkpoints_limit}: {loaded_model.sd_checkpoint_info.title}")
|
||||
model_data.loaded_sd_models.pop()
|
||||
send_model_to_trash(loaded_model)
|
||||
timer.record("send model to trash")
|
||||
|
||||
if shared.opts.sd_checkpoints_keep_in_cpu:
|
||||
send_model_to_cpu(sd_model)
|
||||
timer.record("send model to cpu")
|
||||
|
||||
if already_loaded is not None:
|
||||
send_model_to_device(already_loaded)
|
||||
timer.record("send model to device")
|
||||
|
||||
model_data.set_sd_model(already_loaded)
|
||||
print(f"Using already loaded model {already_loaded.sd_checkpoint_info.title}: done in {timer.summary()}")
|
||||
return model_data.sd_model
|
||||
elif shared.opts.sd_checkpoints_limit > 1 and len(model_data.loaded_sd_models) < shared.opts.sd_checkpoints_limit:
|
||||
print(f"Loading model {checkpoint_info.title} ({len(model_data.loaded_sd_models) + 1} out of {shared.opts.sd_checkpoints_limit})")
|
||||
|
||||
model_data.sd_model = None
|
||||
load_model(checkpoint_info)
|
||||
return model_data.sd_model
|
||||
elif len(model_data.loaded_sd_models) > 0:
|
||||
sd_model = model_data.loaded_sd_models.pop()
|
||||
model_data.sd_model = sd_model
|
||||
|
||||
print(f"Reusing loaded model {sd_model.sd_checkpoint_info.title} to load {checkpoint_info.title}")
|
||||
return sd_model
|
||||
else:
|
||||
return None
|
||||
|
||||
|
||||
def reload_model_weights(sd_model=None, info=None):
|
||||
from modules import lowvram, devices, sd_hijack
|
||||
from modules import devices, sd_hijack
|
||||
checkpoint_info = info or select_checkpoint()
|
||||
|
||||
timer = Timer()
|
||||
|
||||
if not sd_model:
|
||||
sd_model = model_data.sd_model
|
||||
|
||||
@@ -554,19 +671,17 @@ def reload_model_weights(sd_model=None, info=None):
|
||||
else:
|
||||
current_checkpoint_info = sd_model.sd_checkpoint_info
|
||||
if sd_model.sd_model_checkpoint == checkpoint_info.filename:
|
||||
return
|
||||
return sd_model
|
||||
|
||||
sd_model = reuse_model_from_already_loaded(sd_model, checkpoint_info, timer)
|
||||
if sd_model is not None and sd_model.sd_checkpoint_info.filename == checkpoint_info.filename:
|
||||
return sd_model
|
||||
|
||||
if sd_model is not None:
|
||||
sd_unet.apply_unet("None")
|
||||
|
||||
if shared.cmd_opts.lowvram or shared.cmd_opts.medvram:
|
||||
lowvram.send_everything_to_cpu()
|
||||
else:
|
||||
sd_model.to(devices.cpu)
|
||||
|
||||
send_model_to_cpu(sd_model)
|
||||
sd_hijack.model_hijack.undo_hijack(sd_model)
|
||||
|
||||
timer = Timer()
|
||||
|
||||
state_dict = get_checkpoint_state_dict(checkpoint_info, timer)
|
||||
|
||||
checkpoint_config = sd_models_config.find_checkpoint_config(state_dict, checkpoint_info)
|
||||
@@ -574,7 +689,9 @@ def reload_model_weights(sd_model=None, info=None):
|
||||
timer.record("find config")
|
||||
|
||||
if sd_model is None or checkpoint_config != sd_model.used_config:
|
||||
del sd_model
|
||||
if sd_model is not None:
|
||||
send_model_to_trash(sd_model)
|
||||
|
||||
load_model(checkpoint_info, already_loaded_state_dict=state_dict)
|
||||
return model_data.sd_model
|
||||
|
||||
@@ -597,6 +714,9 @@ def reload_model_weights(sd_model=None, info=None):
|
||||
|
||||
print(f"Weights loaded in {timer.summary()}.")
|
||||
|
||||
model_data.set_sd_model(sd_model)
|
||||
sd_unet.apply_unet()
|
||||
|
||||
return sd_model
|
||||
|
||||
|
||||
|
||||
+13
-4
@@ -56,6 +56,14 @@ def encode_embedding_init_text(self: sgm.modules.GeneralConditioner, init_text,
|
||||
return torch.cat(res, dim=1)
|
||||
|
||||
|
||||
def tokenize(self: sgm.modules.GeneralConditioner, texts):
|
||||
for embedder in [embedder for embedder in self.embedders if hasattr(embedder, 'tokenize')]:
|
||||
return embedder.tokenize(texts)
|
||||
|
||||
raise AssertionError('no tokenizer available')
|
||||
|
||||
|
||||
|
||||
def process_texts(self, texts):
|
||||
for embedder in [embedder for embedder in self.embedders if hasattr(embedder, 'process_texts')]:
|
||||
return embedder.process_texts(texts)
|
||||
@@ -68,6 +76,7 @@ def get_target_prompt_token_count(self, token_count):
|
||||
|
||||
# those additions to GeneralConditioner make it possible to use it as model.cond_stage_model from SD1.5 in exist
|
||||
sgm.modules.GeneralConditioner.encode_embedding_init_text = encode_embedding_init_text
|
||||
sgm.modules.GeneralConditioner.tokenize = tokenize
|
||||
sgm.modules.GeneralConditioner.process_texts = process_texts
|
||||
sgm.modules.GeneralConditioner.get_target_prompt_token_count = get_target_prompt_token_count
|
||||
|
||||
@@ -89,10 +98,10 @@ def extend_sdxl(model):
|
||||
model.conditioner.wrapped = torch.nn.Module()
|
||||
|
||||
|
||||
sgm.modules.attention.print = lambda *args: None
|
||||
sgm.modules.diffusionmodules.model.print = lambda *args: None
|
||||
sgm.modules.diffusionmodules.openaimodel.print = lambda *args: None
|
||||
sgm.modules.encoders.modules.print = lambda *args: None
|
||||
sgm.modules.attention.print = shared.ldm_print
|
||||
sgm.modules.diffusionmodules.model.print = shared.ldm_print
|
||||
sgm.modules.diffusionmodules.openaimodel.print = shared.ldm_print
|
||||
sgm.modules.encoders.modules.print = shared.ldm_print
|
||||
|
||||
# this gets the code to load the vanilla attention that we override
|
||||
sgm.modules.attention.SDP_IS_AVAILABLE = True
|
||||
|
||||
@@ -2,10 +2,8 @@ from collections import namedtuple
|
||||
import numpy as np
|
||||
import torch
|
||||
from PIL import Image
|
||||
from modules import devices, processing, images, sd_vae_approx, sd_samplers, sd_vae_taesd
|
||||
|
||||
from modules import devices, images, sd_vae_approx, sd_samplers, sd_vae_taesd, shared
|
||||
from modules.shared import opts, state
|
||||
import modules.shared as shared
|
||||
|
||||
SamplerData = namedtuple('SamplerData', ['name', 'constructor', 'aliases', 'options'])
|
||||
|
||||
@@ -25,19 +23,29 @@ def setup_img2img_steps(p, steps=None):
|
||||
approximation_indexes = {"Full": 0, "Approx NN": 1, "Approx cheap": 2, "TAESD": 3}
|
||||
|
||||
|
||||
def single_sample_to_image(sample, approximation=None):
|
||||
def samples_to_images_tensor(sample, approximation=None, model=None):
|
||||
'''latents -> images [-1, 1]'''
|
||||
if approximation is None:
|
||||
approximation = approximation_indexes.get(opts.show_progress_type, 0)
|
||||
|
||||
if approximation == 2:
|
||||
x_sample = sd_vae_approx.cheap_approximation(sample) * 0.5 + 0.5
|
||||
x_sample = sd_vae_approx.cheap_approximation(sample)
|
||||
elif approximation == 1:
|
||||
x_sample = sd_vae_approx.model()(sample.to(devices.device, devices.dtype).unsqueeze(0))[0].detach() * 0.5 + 0.5
|
||||
x_sample = sd_vae_approx.model()(sample.to(devices.device, devices.dtype)).detach()
|
||||
elif approximation == 3:
|
||||
x_sample = sample * 1.5
|
||||
x_sample = sd_vae_taesd.model()(x_sample.to(devices.device, devices.dtype).unsqueeze(0))[0].detach()
|
||||
x_sample = sd_vae_taesd.decoder_model()(x_sample.to(devices.device, devices.dtype)).detach()
|
||||
x_sample = x_sample * 2 - 1
|
||||
else:
|
||||
x_sample = processing.decode_first_stage(shared.sd_model, sample.unsqueeze(0))[0] * 0.5 + 0.5
|
||||
if model is None:
|
||||
model = shared.sd_model
|
||||
x_sample = model.decode_first_stage(sample.to(model.first_stage_model.dtype))
|
||||
|
||||
return x_sample
|
||||
|
||||
|
||||
def single_sample_to_image(sample, approximation=None):
|
||||
x_sample = samples_to_images_tensor(sample.unsqueeze(0), approximation)[0] * 0.5 + 0.5
|
||||
|
||||
x_sample = torch.clamp(x_sample, min=0.0, max=1.0)
|
||||
x_sample = 255. * np.moveaxis(x_sample.cpu().numpy(), 0, 2)
|
||||
@@ -46,6 +54,12 @@ def single_sample_to_image(sample, approximation=None):
|
||||
return Image.fromarray(x_sample)
|
||||
|
||||
|
||||
def decode_first_stage(model, x):
|
||||
x = x.to(devices.dtype_vae)
|
||||
approx_index = approximation_indexes.get(opts.sd_vae_decode_method, 0)
|
||||
return samples_to_images_tensor(x, approx_index, model)
|
||||
|
||||
|
||||
def sample_to_image(samples, index=0, approximation=None):
|
||||
return single_sample_to_image(samples[index], approximation)
|
||||
|
||||
@@ -54,6 +68,24 @@ def samples_to_image_grid(samples, approximation=None):
|
||||
return images.image_grid([single_sample_to_image(sample, approximation) for sample in samples])
|
||||
|
||||
|
||||
def images_tensor_to_samples(image, approximation=None, model=None):
|
||||
'''image[0, 1] -> latent'''
|
||||
if approximation is None:
|
||||
approximation = approximation_indexes.get(opts.sd_vae_encode_method, 0)
|
||||
|
||||
if approximation == 3:
|
||||
image = image.to(devices.device, devices.dtype)
|
||||
x_latent = sd_vae_taesd.encoder_model()(image)
|
||||
else:
|
||||
if model is None:
|
||||
model = shared.sd_model
|
||||
image = image.to(shared.device, dtype=devices.dtype_vae)
|
||||
image = image * 2 - 1
|
||||
x_latent = model.get_first_stage_encoding(model.encode_first_stage(image))
|
||||
|
||||
return x_latent
|
||||
|
||||
|
||||
def store_latent(decoded):
|
||||
state.current_latent = decoded
|
||||
|
||||
@@ -85,11 +117,13 @@ class InterruptedException(BaseException):
|
||||
pass
|
||||
|
||||
|
||||
if opts.randn_source == "CPU":
|
||||
def replace_torchsde_browinan():
|
||||
import torchsde._brownian.brownian_interval
|
||||
|
||||
def torchsde_randn(size, dtype, device, seed):
|
||||
generator = torch.Generator(devices.cpu).manual_seed(int(seed))
|
||||
return torch.randn(size, dtype=dtype, device=devices.cpu, generator=generator).to(device)
|
||||
return devices.randn_local(seed, size).to(device=device, dtype=dtype)
|
||||
|
||||
torchsde._brownian.brownian_interval._randn = torchsde_randn
|
||||
|
||||
|
||||
replace_torchsde_browinan()
|
||||
|
||||
@@ -44,7 +44,7 @@ class VanillaStableDiffusionSampler:
|
||||
return 0
|
||||
|
||||
def launch_sampling(self, steps, func):
|
||||
state.sampling_steps = steps
|
||||
state.sampling_steps = self.stop_at if self.stop_at is not None else steps
|
||||
state.sampling_step = 0
|
||||
|
||||
try:
|
||||
|
||||
@@ -0,0 +1,74 @@
|
||||
import torch
|
||||
import tqdm
|
||||
import k_diffusion.sampling
|
||||
|
||||
|
||||
@torch.no_grad()
|
||||
def restart_sampler(model, x, sigmas, extra_args=None, callback=None, disable=None, s_noise=1., restart_list=None):
|
||||
"""Implements restart sampling in Restart Sampling for Improving Generative Processes (2023)
|
||||
Restart_list format: {min_sigma: [ restart_steps, restart_times, max_sigma]}
|
||||
If restart_list is None: will choose restart_list automatically, otherwise will use the given restart_list
|
||||
"""
|
||||
extra_args = {} if extra_args is None else extra_args
|
||||
s_in = x.new_ones([x.shape[0]])
|
||||
step_id = 0
|
||||
from k_diffusion.sampling import to_d, get_sigmas_karras
|
||||
|
||||
def heun_step(x, old_sigma, new_sigma, second_order=True):
|
||||
nonlocal step_id
|
||||
denoised = model(x, old_sigma * s_in, **extra_args)
|
||||
d = to_d(x, old_sigma, denoised)
|
||||
if callback is not None:
|
||||
callback({'x': x, 'i': step_id, 'sigma': new_sigma, 'sigma_hat': old_sigma, 'denoised': denoised})
|
||||
dt = new_sigma - old_sigma
|
||||
if new_sigma == 0 or not second_order:
|
||||
# Euler method
|
||||
x = x + d * dt
|
||||
else:
|
||||
# Heun's method
|
||||
x_2 = x + d * dt
|
||||
denoised_2 = model(x_2, new_sigma * s_in, **extra_args)
|
||||
d_2 = to_d(x_2, new_sigma, denoised_2)
|
||||
d_prime = (d + d_2) / 2
|
||||
x = x + d_prime * dt
|
||||
step_id += 1
|
||||
return x
|
||||
|
||||
steps = sigmas.shape[0] - 1
|
||||
if restart_list is None:
|
||||
if steps >= 20:
|
||||
restart_steps = 9
|
||||
restart_times = 1
|
||||
if steps >= 36:
|
||||
restart_steps = steps // 4
|
||||
restart_times = 2
|
||||
sigmas = get_sigmas_karras(steps - restart_steps * restart_times, sigmas[-2].item(), sigmas[0].item(), device=sigmas.device)
|
||||
restart_list = {0.1: [restart_steps + 1, restart_times, 2]}
|
||||
else:
|
||||
restart_list = {}
|
||||
|
||||
restart_list = {int(torch.argmin(abs(sigmas - key), dim=0)): value for key, value in restart_list.items()}
|
||||
|
||||
step_list = []
|
||||
for i in range(len(sigmas) - 1):
|
||||
step_list.append((sigmas[i], sigmas[i + 1]))
|
||||
if i + 1 in restart_list:
|
||||
restart_steps, restart_times, restart_max = restart_list[i + 1]
|
||||
min_idx = i + 1
|
||||
max_idx = int(torch.argmin(abs(sigmas - restart_max), dim=0))
|
||||
if max_idx < min_idx:
|
||||
sigma_restart = get_sigmas_karras(restart_steps, sigmas[min_idx].item(), sigmas[max_idx].item(), device=sigmas.device)[:-1]
|
||||
while restart_times > 0:
|
||||
restart_times -= 1
|
||||
step_list.extend([(old_sigma, new_sigma) for (old_sigma, new_sigma) in zip(sigma_restart[:-1], sigma_restart[1:])])
|
||||
|
||||
last_sigma = None
|
||||
for old_sigma, new_sigma in tqdm.tqdm(step_list, disable=disable):
|
||||
if last_sigma is None:
|
||||
last_sigma = old_sigma
|
||||
elif last_sigma < old_sigma:
|
||||
x = x + k_diffusion.sampling.torch.randn_like(x) * s_noise * (old_sigma ** 2 - last_sigma ** 2) ** 0.5
|
||||
x = heun_step(x, old_sigma, new_sigma)
|
||||
last_sigma = new_sigma
|
||||
|
||||
return x
|
||||
@@ -2,8 +2,9 @@ from collections import deque
|
||||
import torch
|
||||
import inspect
|
||||
import k_diffusion.sampling
|
||||
from modules import prompt_parser, devices, sd_samplers_common
|
||||
from modules import prompt_parser, devices, sd_samplers_common, sd_samplers_extra
|
||||
|
||||
from modules.processing import StableDiffusionProcessing
|
||||
from modules.shared import opts, state
|
||||
import modules.shared as shared
|
||||
from modules.script_callbacks import CFGDenoiserParams, cfg_denoiser_callback
|
||||
@@ -30,12 +31,15 @@ samplers_k_diffusion = [
|
||||
('DPM++ 2M Karras', 'sample_dpmpp_2m', ['k_dpmpp_2m_ka'], {'scheduler': 'karras'}),
|
||||
('DPM++ SDE Karras', 'sample_dpmpp_sde', ['k_dpmpp_sde_ka'], {'scheduler': 'karras', "second_order": True, "brownian_noise": True}),
|
||||
('DPM++ 2M SDE Karras', 'sample_dpmpp_2m_sde', ['k_dpmpp_2m_sde_ka'], {'scheduler': 'karras', "brownian_noise": True}),
|
||||
('DPM++ 2M SDE Exponential', 'sample_dpmpp_2m_sde', ['k_dpmpp_2m_sde_exp'], {'scheduler': 'exponential', "brownian_noise": True}),
|
||||
('Restart', sd_samplers_extra.restart_sampler, ['restart'], {'scheduler': 'karras'}),
|
||||
]
|
||||
|
||||
|
||||
samplers_data_k_diffusion = [
|
||||
sd_samplers_common.SamplerData(label, lambda model, funcname=funcname: KDiffusionSampler(funcname, model), aliases, options)
|
||||
for label, funcname, aliases, options in samplers_k_diffusion
|
||||
if hasattr(k_diffusion.sampling, funcname)
|
||||
if callable(funcname) or hasattr(k_diffusion.sampling, funcname)
|
||||
]
|
||||
|
||||
sampler_extra_params = {
|
||||
@@ -258,10 +262,7 @@ class TorchHijack:
|
||||
if noise.shape == x.shape:
|
||||
return noise
|
||||
|
||||
if opts.randn_source == "CPU" or x.device.type == 'mps':
|
||||
return torch.randn_like(x, device=devices.cpu).to(x.device)
|
||||
else:
|
||||
return torch.randn_like(x)
|
||||
return devices.randn_like(x)
|
||||
|
||||
|
||||
class KDiffusionSampler:
|
||||
@@ -270,16 +271,25 @@ class KDiffusionSampler:
|
||||
|
||||
self.model_wrap = denoiser(sd_model, quantize=shared.opts.enable_quantization)
|
||||
self.funcname = funcname
|
||||
self.func = getattr(k_diffusion.sampling, self.funcname)
|
||||
self.func = funcname if callable(funcname) else getattr(k_diffusion.sampling, self.funcname)
|
||||
self.extra_params = sampler_extra_params.get(funcname, [])
|
||||
self.model_wrap_cfg = CFGDenoiser(self.model_wrap)
|
||||
self.sampler_noises = None
|
||||
self.stop_at = None
|
||||
self.noisy_output = None
|
||||
self.eta = None
|
||||
self.config = None # set by the function calling the constructor
|
||||
self.last_latent = None
|
||||
self.s_min_uncond = None
|
||||
|
||||
# NOTE: These are also defined in the StableDiffusionProcessing class.
|
||||
# They should have been here to begin with but we're going to
|
||||
# leave that class __init__ signature alone.
|
||||
self.s_churn = 0.0
|
||||
self.s_tmin = 0.0
|
||||
self.s_tmax = float('inf')
|
||||
self.s_noise = 1.0
|
||||
|
||||
self.conditioning_key = sd_model.model.conditioning_key
|
||||
|
||||
def callback_state(self, d):
|
||||
@@ -288,6 +298,7 @@ class KDiffusionSampler:
|
||||
if opts.live_preview_content == "Combined":
|
||||
sd_samplers_common.store_latent(latent)
|
||||
self.last_latent = latent
|
||||
self.noisy_output = d['x']
|
||||
|
||||
if self.stop_at is not None and step > self.stop_at:
|
||||
raise sd_samplers_common.InterruptedException
|
||||
@@ -296,7 +307,7 @@ class KDiffusionSampler:
|
||||
shared.total_tqdm.update()
|
||||
|
||||
def launch_sampling(self, steps, func):
|
||||
state.sampling_steps = steps
|
||||
state.sampling_steps = self.stop_at if self.stop_at is not None else steps
|
||||
state.sampling_step = 0
|
||||
|
||||
try:
|
||||
@@ -314,7 +325,7 @@ class KDiffusionSampler:
|
||||
def number_of_needed_noises(self, p):
|
||||
return p.steps
|
||||
|
||||
def initialize(self, p):
|
||||
def initialize(self, p: StableDiffusionProcessing):
|
||||
self.model_wrap_cfg.mask = p.mask if hasattr(p, 'mask') else None
|
||||
self.model_wrap_cfg.nmask = p.nmask if hasattr(p, 'nmask') else None
|
||||
self.model_wrap_cfg.step = 0
|
||||
@@ -335,6 +346,29 @@ class KDiffusionSampler:
|
||||
|
||||
extra_params_kwargs['eta'] = self.eta
|
||||
|
||||
if len(self.extra_params) > 0:
|
||||
s_churn = getattr(opts, 's_churn', p.s_churn)
|
||||
s_tmin = getattr(opts, 's_tmin', p.s_tmin)
|
||||
s_tmax = getattr(opts, 's_tmax', p.s_tmax) or self.s_tmax # 0 = inf
|
||||
s_noise = getattr(opts, 's_noise', p.s_noise)
|
||||
|
||||
if s_churn != self.s_churn:
|
||||
extra_params_kwargs['s_churn'] = s_churn
|
||||
p.s_churn = s_churn
|
||||
p.extra_generation_params['Sigma churn'] = s_churn
|
||||
if s_tmin != self.s_tmin:
|
||||
extra_params_kwargs['s_tmin'] = s_tmin
|
||||
p.s_tmin = s_tmin
|
||||
p.extra_generation_params['Sigma tmin'] = s_tmin
|
||||
if s_tmax != self.s_tmax:
|
||||
extra_params_kwargs['s_tmax'] = s_tmax
|
||||
p.s_tmax = s_tmax
|
||||
p.extra_generation_params['Sigma tmax'] = s_tmax
|
||||
if s_noise != self.s_noise:
|
||||
extra_params_kwargs['s_noise'] = s_noise
|
||||
p.s_noise = s_noise
|
||||
p.extra_generation_params['Sigma noise'] = s_noise
|
||||
|
||||
return extra_params_kwargs
|
||||
|
||||
def get_sigmas(self, p, steps):
|
||||
@@ -376,6 +410,9 @@ class KDiffusionSampler:
|
||||
sigma_min, sigma_max = (0.1, 10) if opts.use_old_karras_scheduler_sigmas else (self.model_wrap.sigmas[0].item(), self.model_wrap.sigmas[-1].item())
|
||||
|
||||
sigmas = k_diffusion.sampling.get_sigmas_karras(n=steps, sigma_min=sigma_min, sigma_max=sigma_max, device=shared.device)
|
||||
elif self.config is not None and self.config.options.get('scheduler', None) == 'exponential':
|
||||
m_sigma_min, m_sigma_max = (self.model_wrap.sigmas[0].item(), self.model_wrap.sigmas[-1].item())
|
||||
sigmas = k_diffusion.sampling.get_sigmas_exponential(n=steps, sigma_min=m_sigma_min, sigma_max=m_sigma_max, device=shared.device)
|
||||
else:
|
||||
sigmas = self.model_wrap.get_sigmas(steps)
|
||||
|
||||
|
||||
+15
-1
@@ -1,6 +1,6 @@
|
||||
import os
|
||||
import collections
|
||||
from modules import paths, shared, devices, script_callbacks, sd_models
|
||||
from modules import paths, shared, devices, script_callbacks, sd_models, extra_networks
|
||||
import glob
|
||||
from copy import deepcopy
|
||||
|
||||
@@ -16,6 +16,7 @@ checkpoint_info = None
|
||||
|
||||
checkpoints_loaded = collections.OrderedDict()
|
||||
|
||||
|
||||
def get_base_vae(model):
|
||||
if base_vae is not None and checkpoint_info == model.sd_checkpoint_info and model:
|
||||
return base_vae
|
||||
@@ -50,6 +51,7 @@ def get_filename(filepath):
|
||||
|
||||
|
||||
def refresh_vae_list():
|
||||
global vae_dict
|
||||
vae_dict.clear()
|
||||
|
||||
paths = [
|
||||
@@ -83,6 +85,8 @@ def refresh_vae_list():
|
||||
name = get_filename(filepath)
|
||||
vae_dict[name] = filepath
|
||||
|
||||
vae_dict = dict(sorted(vae_dict.items(), key=lambda item: shared.natural_sort_key(item[0])))
|
||||
|
||||
|
||||
def find_vae_near_checkpoint(checkpoint_file):
|
||||
checkpoint_path = os.path.basename(checkpoint_file).rsplit('.', 1)[0]
|
||||
@@ -97,6 +101,16 @@ def resolve_vae(checkpoint_file):
|
||||
if shared.cmd_opts.vae_path is not None:
|
||||
return shared.cmd_opts.vae_path, 'from commandline argument'
|
||||
|
||||
metadata = extra_networks.get_user_metadata(checkpoint_file)
|
||||
vae_metadata = metadata.get("vae", None)
|
||||
if vae_metadata is not None and vae_metadata != "Automatic":
|
||||
if vae_metadata == "None":
|
||||
return None, None
|
||||
|
||||
vae_from_metadata = vae_dict.get(vae_metadata, None)
|
||||
if vae_from_metadata is not None:
|
||||
return vae_from_metadata, "from user metadata"
|
||||
|
||||
is_automatic = shared.opts.sd_vae in {"Automatic", "auto"} # "auto" for people with old config
|
||||
|
||||
vae_near_checkpoint = find_vae_near_checkpoint(checkpoint_file)
|
||||
|
||||
@@ -81,6 +81,6 @@ def cheap_approximation(sample):
|
||||
|
||||
coefs = torch.tensor(coeffs).to(sample.device)
|
||||
|
||||
x_sample = torch.einsum("lxy,lr -> rxy", sample, coefs)
|
||||
x_sample = torch.einsum("...lxy,lr -> ...rxy", sample, coefs)
|
||||
|
||||
return x_sample
|
||||
|
||||
+44
-8
@@ -44,7 +44,17 @@ def decoder():
|
||||
)
|
||||
|
||||
|
||||
class TAESD(nn.Module):
|
||||
def encoder():
|
||||
return nn.Sequential(
|
||||
conv(3, 64), Block(64, 64),
|
||||
conv(64, 64, stride=2, bias=False), Block(64, 64), Block(64, 64), Block(64, 64),
|
||||
conv(64, 64, stride=2, bias=False), Block(64, 64), Block(64, 64), Block(64, 64),
|
||||
conv(64, 64, stride=2, bias=False), Block(64, 64), Block(64, 64), Block(64, 64),
|
||||
conv(64, 4),
|
||||
)
|
||||
|
||||
|
||||
class TAESDDecoder(nn.Module):
|
||||
latent_magnitude = 3
|
||||
latent_shift = 0.5
|
||||
|
||||
@@ -55,21 +65,28 @@ class TAESD(nn.Module):
|
||||
self.decoder.load_state_dict(
|
||||
torch.load(decoder_path, map_location='cpu' if devices.device.type != 'cuda' else None))
|
||||
|
||||
@staticmethod
|
||||
def unscale_latents(x):
|
||||
"""[0, 1] -> raw latents"""
|
||||
return x.sub(TAESD.latent_shift).mul(2 * TAESD.latent_magnitude)
|
||||
|
||||
class TAESDEncoder(nn.Module):
|
||||
latent_magnitude = 3
|
||||
latent_shift = 0.5
|
||||
|
||||
def __init__(self, encoder_path="taesd_encoder.pth"):
|
||||
"""Initialize pretrained TAESD on the given device from the given checkpoints."""
|
||||
super().__init__()
|
||||
self.encoder = encoder()
|
||||
self.encoder.load_state_dict(
|
||||
torch.load(encoder_path, map_location='cpu' if devices.device.type != 'cuda' else None))
|
||||
|
||||
|
||||
def download_model(model_path, model_url):
|
||||
if not os.path.exists(model_path):
|
||||
os.makedirs(os.path.dirname(model_path), exist_ok=True)
|
||||
|
||||
print(f'Downloading TAESD decoder to: {model_path}')
|
||||
print(f'Downloading TAESD model to: {model_path}')
|
||||
torch.hub.download_url_to_file(model_url, model_path)
|
||||
|
||||
|
||||
def model():
|
||||
def decoder_model():
|
||||
model_name = "taesdxl_decoder.pth" if getattr(shared.sd_model, 'is_sdxl', False) else "taesd_decoder.pth"
|
||||
loaded_model = sd_vae_taesd_models.get(model_name)
|
||||
|
||||
@@ -78,7 +95,7 @@ def model():
|
||||
download_model(model_path, 'https://github.com/madebyollin/taesd/raw/main/' + model_name)
|
||||
|
||||
if os.path.exists(model_path):
|
||||
loaded_model = TAESD(model_path)
|
||||
loaded_model = TAESDDecoder(model_path)
|
||||
loaded_model.eval()
|
||||
loaded_model.to(devices.device, devices.dtype)
|
||||
sd_vae_taesd_models[model_name] = loaded_model
|
||||
@@ -86,3 +103,22 @@ def model():
|
||||
raise FileNotFoundError('TAESD model not found')
|
||||
|
||||
return loaded_model.decoder
|
||||
|
||||
|
||||
def encoder_model():
|
||||
model_name = "taesdxl_encoder.pth" if getattr(shared.sd_model, 'is_sdxl', False) else "taesd_encoder.pth"
|
||||
loaded_model = sd_vae_taesd_models.get(model_name)
|
||||
|
||||
if loaded_model is None:
|
||||
model_path = os.path.join(paths_internal.models_path, "VAE-taesd", model_name)
|
||||
download_model(model_path, 'https://github.com/madebyollin/taesd/raw/main/' + model_name)
|
||||
|
||||
if os.path.exists(model_path):
|
||||
loaded_model = TAESDEncoder(model_path)
|
||||
loaded_model.eval()
|
||||
loaded_model.to(devices.device, devices.dtype)
|
||||
sd_vae_taesd_models[model_name] = loaded_model
|
||||
else:
|
||||
raise FileNotFoundError('TAESD model not found')
|
||||
|
||||
return loaded_model.encoder
|
||||
|
||||
+122
-40
@@ -51,15 +51,34 @@ restricted_opts = {
|
||||
|
||||
# https://huggingface.co/datasets/freddyaboulton/gradio-theme-subdomains/resolve/main/subdomains.json
|
||||
gradio_hf_hub_themes = [
|
||||
"gradio/base",
|
||||
"gradio/glass",
|
||||
"gradio/monochrome",
|
||||
"gradio/seafoam",
|
||||
"gradio/soft",
|
||||
"freddyaboulton/dracula_revamped",
|
||||
"gradio/dracula_test",
|
||||
"abidlabs/dracula_test",
|
||||
"abidlabs/Lime",
|
||||
"abidlabs/pakistan",
|
||||
"Ama434/neutral-barlow",
|
||||
"dawood/microsoft_windows",
|
||||
"finlaymacklon/smooth_slate",
|
||||
"Franklisi/darkmode",
|
||||
"freddyaboulton/dracula_revamped",
|
||||
"freddyaboulton/test-blue",
|
||||
"gstaff/xkcd",
|
||||
"Insuz/Mocha",
|
||||
"Insuz/SimpleIndigo",
|
||||
"JohnSmith9982/small_and_pretty",
|
||||
"nota-ai/theme",
|
||||
"nuttea/Softblue",
|
||||
"ParityError/Anime",
|
||||
"reilnuud/polite",
|
||||
"remilia/Ghostly",
|
||||
"rottenlittlecreature/Moon_Goblin",
|
||||
"step-3-profit/Midnight-Deep",
|
||||
"Taithrah/Minimal",
|
||||
"ysharma/huggingface",
|
||||
"ysharma/steampunk"
|
||||
]
|
||||
|
||||
@@ -220,12 +239,19 @@ class State:
|
||||
return
|
||||
|
||||
import modules.sd_samplers
|
||||
if opts.show_progress_grid:
|
||||
self.assign_current_image(modules.sd_samplers.samples_to_image_grid(self.current_latent))
|
||||
else:
|
||||
self.assign_current_image(modules.sd_samplers.sample_to_image(self.current_latent))
|
||||
|
||||
self.current_image_sampling_step = self.sampling_step
|
||||
try:
|
||||
if opts.show_progress_grid:
|
||||
self.assign_current_image(modules.sd_samplers.samples_to_image_grid(self.current_latent))
|
||||
else:
|
||||
self.assign_current_image(modules.sd_samplers.sample_to_image(self.current_latent))
|
||||
|
||||
self.current_image_sampling_step = self.sampling_step
|
||||
|
||||
except Exception:
|
||||
# when switching models during genration, VAE would be on CPU, so creating an image will fail.
|
||||
# we silently ignore this error
|
||||
errors.record_exception()
|
||||
|
||||
def assign_current_image(self, image):
|
||||
self.current_image = image
|
||||
@@ -252,6 +278,7 @@ class OptionInfo:
|
||||
self.onchange = onchange
|
||||
self.section = section
|
||||
self.refresh = refresh
|
||||
self.do_not_save = False
|
||||
|
||||
self.comment_before = comment_before
|
||||
"""HTML text that will be added after label in UI"""
|
||||
@@ -279,8 +306,17 @@ class OptionInfo:
|
||||
self.comment_after += " <span class='info'>(requires restart)</span>"
|
||||
return self
|
||||
|
||||
def needs_reload_ui(self):
|
||||
self.comment_after += " <span class='info'>(requires Reload UI)</span>"
|
||||
return self
|
||||
|
||||
|
||||
class OptionHTML(OptionInfo):
|
||||
def __init__(self, text):
|
||||
super().__init__(str(text).strip(), label='', component=lambda **kwargs: gr.HTML(elem_classes="settings-info", **kwargs))
|
||||
|
||||
self.do_not_save = True
|
||||
|
||||
|
||||
def options_section(section_identifier, options_dict):
|
||||
for v in options_dict.values():
|
||||
@@ -349,6 +385,7 @@ options_templates.update(options_section(('saving-images', "Saving images/grids"
|
||||
"temp_dir": OptionInfo("", "Directory for temporary images; leave empty for default"),
|
||||
"clean_temp_dir_at_start": OptionInfo(False, "Cleanup non-default temporary directory when starting webui"),
|
||||
|
||||
"save_incomplete_images": OptionInfo(False, "Save incomplete images").info("save images that has been interrupted in mid-generation; even if not saved, they will still show up in webui output."),
|
||||
}))
|
||||
|
||||
options_templates.update(options_section(('saving-paths', "Paths for saving"), {
|
||||
@@ -385,13 +422,15 @@ options_templates.update(options_section(('face-restoration', "Face restoration"
|
||||
}))
|
||||
|
||||
options_templates.update(options_section(('system', "System"), {
|
||||
"show_warnings": OptionInfo(False, "Show warnings in console."),
|
||||
"show_warnings": OptionInfo(False, "Show warnings in console.").needs_reload_ui(),
|
||||
"show_gradio_deprecation_warnings": OptionInfo(True, "Show gradio deprecation warnings in console.").needs_reload_ui(),
|
||||
"memmon_poll_rate": OptionInfo(8, "VRAM usage polls per second during generation.", gr.Slider, {"minimum": 0, "maximum": 40, "step": 1}).info("0 = disable"),
|
||||
"samples_log_stdout": OptionInfo(False, "Always print all generation info to standard output"),
|
||||
"multiple_tqdm": OptionInfo(True, "Add a second progress bar to the console that shows progress for an entire job."),
|
||||
"print_hypernet_extra": OptionInfo(False, "Print extra hypernetwork information to console."),
|
||||
"list_hidden_files": OptionInfo(True, "Load models/files in hidden directories").info("directory is hidden if its name starts with \".\""),
|
||||
"disable_mmap_load_safetensors": OptionInfo(False, "Disable memmapping for loading .safetensors files.").info("fixes very slow loading speed in some cases"),
|
||||
"hide_ldm_prints": OptionInfo(True, "Prevent Stability-AI's ldm/sgm modules from printing noise to console."),
|
||||
}))
|
||||
|
||||
options_templates.update(options_section(('training', "Training"), {
|
||||
@@ -411,24 +450,19 @@ options_templates.update(options_section(('training', "Training"), {
|
||||
|
||||
options_templates.update(options_section(('sd', "Stable Diffusion"), {
|
||||
"sd_model_checkpoint": OptionInfo(None, "Stable Diffusion checkpoint", gr.Dropdown, lambda: {"choices": list_checkpoint_tiles()}, refresh=refresh_checkpoints),
|
||||
"sd_checkpoint_cache": OptionInfo(0, "Checkpoints to cache in RAM", gr.Slider, {"minimum": 0, "maximum": 10, "step": 1}),
|
||||
"sd_vae_checkpoint_cache": OptionInfo(0, "VAE Checkpoints to cache in RAM", gr.Slider, {"minimum": 0, "maximum": 10, "step": 1}),
|
||||
"sd_vae": OptionInfo("Automatic", "SD VAE", gr.Dropdown, lambda: {"choices": shared_items.sd_vae_items()}, refresh=shared_items.refresh_vae_list).info("choose VAE model: Automatic = use one with same filename as checkpoint; None = use VAE from checkpoint"),
|
||||
"sd_vae_as_default": OptionInfo(True, "Ignore selected VAE for stable diffusion checkpoints that have their own .vae.pt next to them"),
|
||||
"sd_checkpoints_limit": OptionInfo(1, "Maximum number of checkpoints loaded at the same time", gr.Slider, {"minimum": 1, "maximum": 10, "step": 1}),
|
||||
"sd_checkpoints_keep_in_cpu": OptionInfo(True, "Only keep one model on device").info("will keep models other than the currently used one in RAM rather than VRAM"),
|
||||
"sd_checkpoint_cache": OptionInfo(0, "Checkpoints to cache in RAM", gr.Slider, {"minimum": 0, "maximum": 10, "step": 1}).info("obsolete; set to 0 and use the two settings above instead"),
|
||||
"sd_unet": OptionInfo("Automatic", "SD Unet", gr.Dropdown, lambda: {"choices": shared_items.sd_unet_items()}, refresh=shared_items.refresh_unet_list).info("choose Unet model: Automatic = use one with same filename as checkpoint; None = use Unet from checkpoint"),
|
||||
"inpainting_mask_weight": OptionInfo(1.0, "Inpainting conditioning mask strength", gr.Slider, {"minimum": 0.0, "maximum": 1.0, "step": 0.01}),
|
||||
"initial_noise_multiplier": OptionInfo(1.0, "Noise multiplier for img2img", gr.Slider, {"minimum": 0.5, "maximum": 1.5, "step": 0.01}),
|
||||
"img2img_color_correction": OptionInfo(False, "Apply color correction to img2img results to match original colors."),
|
||||
"img2img_fix_steps": OptionInfo(False, "With img2img, do exactly the amount of steps the slider specifies.").info("normally you'd do less with less denoising"),
|
||||
"img2img_background_color": OptionInfo("#ffffff", "With img2img, fill image's transparent parts with this color.", ui_components.FormColorPicker, {}),
|
||||
"enable_quantization": OptionInfo(False, "Enable quantization in K samplers for sharper and cleaner results. This may change existing seeds. Requires restart to apply."),
|
||||
"enable_quantization": OptionInfo(False, "Enable quantization in K samplers for sharper and cleaner results. This may change existing seeds").needs_reload_ui(),
|
||||
"enable_emphasis": OptionInfo(True, "Enable emphasis").info("use (text) to make model pay more attention to text and [text] to make it pay less attention"),
|
||||
"enable_batch_seeds": OptionInfo(True, "Make K-diffusion samplers produce same images in a batch as when making a single image"),
|
||||
"comma_padding_backtrack": OptionInfo(20, "Prompt word wrap length limit", gr.Slider, {"minimum": 0, "maximum": 74, "step": 1}).info("in tokens - for texts shorter than specified, if they don't fit into 75 token limit, move them to the next 75 token chunk"),
|
||||
"CLIP_stop_at_last_layers": OptionInfo(1, "Clip skip", gr.Slider, {"minimum": 1, "maximum": 12, "step": 1}).link("wiki", "https://github.com/AUTOMATIC1111/stable-diffusion-webui/wiki/Features#clip-skip").info("ignore last layers of CLIP network; 1 ignores none, 2 ignores one layer"),
|
||||
"upcast_attn": OptionInfo(False, "Upcast cross attention layer to float32"),
|
||||
"auto_vae_precision": OptionInfo(True, "Automaticlly revert VAE to 32-bit floats").info("triggers when a tensor with NaNs is produced in VAE; disabling the option in this case will result in a black square image"),
|
||||
"randn_source": OptionInfo("GPU", "Random number generator source.", gr.Radio, {"choices": ["GPU", "CPU"]}).info("changes seeds drastically; use CPU to produce the same picture across different videocard vendors"),
|
||||
"randn_source": OptionInfo("GPU", "Random number generator source.", gr.Radio, {"choices": ["GPU", "CPU", "NV"]}).info("changes seeds drastically; use CPU to produce the same picture across different videocard vendors; use NV to produce same picture as on NVidia videocards"),
|
||||
"sd_refiner_checkpoint": OptionInfo(None, "Refiner checkpoint", gr.Dropdown, lambda: {"choices": list_checkpoint_tiles()}, refresh=refresh_checkpoints).info("switch to another model in the middle of generation"),
|
||||
"sd_refiner_switch_at": OptionInfo(1.0, "Refiner switch at", gr.Slider, {"minimum": 0.01, "maximum": 1.0, "step": 0.01}).info("fraction of sampling steps when the swtch to refiner model should happen; 1=never, 0.5=switch in the middle of generation"),
|
||||
}))
|
||||
|
||||
options_templates.update(options_section(('sdxl', "Stable Diffusion XL"), {
|
||||
@@ -438,6 +472,35 @@ options_templates.update(options_section(('sdxl', "Stable Diffusion XL"), {
|
||||
"sdxl_refiner_high_aesthetic_score": OptionInfo(6.0, "SDXL high aesthetic score", gr.Number).info("used for refiner model prompt"),
|
||||
}))
|
||||
|
||||
options_templates.update(options_section(('vae', "VAE"), {
|
||||
"sd_vae_explanation": OptionHTML("""
|
||||
<abbr title='Variational autoencoder'>VAE</abbr> is a neural network that transforms a standard <abbr title='red/green/blue'>RGB</abbr>
|
||||
image into latent space representation and back. Latent space representation is what stable diffusion is working on during sampling
|
||||
(i.e. when the progress bar is between empty and full). For txt2img, VAE is used to create a resulting image after the sampling is finished.
|
||||
For img2img, VAE is used to process user's input image before the sampling, and to create an image after sampling.
|
||||
"""),
|
||||
"sd_vae_checkpoint_cache": OptionInfo(0, "VAE Checkpoints to cache in RAM", gr.Slider, {"minimum": 0, "maximum": 10, "step": 1}),
|
||||
"sd_vae": OptionInfo("Automatic", "SD VAE", gr.Dropdown, lambda: {"choices": shared_items.sd_vae_items()}, refresh=shared_items.refresh_vae_list).info("choose VAE model: Automatic = use one with same filename as checkpoint; None = use VAE from checkpoint"),
|
||||
"sd_vae_as_default": OptionInfo(True, "Ignore selected VAE for stable diffusion checkpoints that have their own .vae.pt next to them"),
|
||||
"auto_vae_precision": OptionInfo(True, "Automaticlly revert VAE to 32-bit floats").info("triggers when a tensor with NaNs is produced in VAE; disabling the option in this case will result in a black square image"),
|
||||
"sd_vae_encode_method": OptionInfo("Full", "VAE type for encode", gr.Radio, {"choices": ["Full", "TAESD"]}).info("method to encode image to latent (use in img2img, hires-fix or inpaint mask)"),
|
||||
"sd_vae_decode_method": OptionInfo("Full", "VAE type for decode", gr.Radio, {"choices": ["Full", "TAESD"]}).info("method to decode latent to image"),
|
||||
}))
|
||||
|
||||
options_templates.update(options_section(('img2img', "img2img"), {
|
||||
"inpainting_mask_weight": OptionInfo(1.0, "Inpainting conditioning mask strength", gr.Slider, {"minimum": 0.0, "maximum": 1.0, "step": 0.01}),
|
||||
"initial_noise_multiplier": OptionInfo(1.0, "Noise multiplier for img2img", gr.Slider, {"minimum": 0.5, "maximum": 1.5, "step": 0.01}),
|
||||
"img2img_color_correction": OptionInfo(False, "Apply color correction to img2img results to match original colors."),
|
||||
"img2img_fix_steps": OptionInfo(False, "With img2img, do exactly the amount of steps the slider specifies.").info("normally you'd do less with less denoising"),
|
||||
"img2img_background_color": OptionInfo("#ffffff", "With img2img, fill transparent parts of the input image with this color.", ui_components.FormColorPicker, {}),
|
||||
"img2img_editor_height": OptionInfo(720, "Height of the image editor", gr.Slider, {"minimum": 80, "maximum": 1600, "step": 1}).info("in pixels").needs_reload_ui(),
|
||||
"img2img_sketch_default_brush_color": OptionInfo("#ffffff", "Sketch initial brush color", ui_components.FormColorPicker, {}).info("default brush color of img2img sketch").needs_reload_ui(),
|
||||
"img2img_inpaint_mask_brush_color": OptionInfo("#ffffff", "Inpaint mask brush color", ui_components.FormColorPicker, {}).info("brush color of inpaint mask").needs_reload_ui(),
|
||||
"img2img_inpaint_sketch_default_brush_color": OptionInfo("#ffffff", "Inpaint sketch initial brush color", ui_components.FormColorPicker, {}).info("default brush color of img2img inpaint sketch").needs_reload_ui(),
|
||||
"return_mask": OptionInfo(False, "For inpainting, include the greyscale mask in results for web"),
|
||||
"return_mask_composite": OptionInfo(False, "For inpainting, include masked composite in results for web"),
|
||||
}))
|
||||
|
||||
options_templates.update(options_section(('optimizations', "Optimizations"), {
|
||||
"cross_attention_optimization": OptionInfo("Automatic", "Cross attention optimization", gr.Dropdown, lambda: {"choices": shared_items.cross_attention_optimizations()}),
|
||||
"s_min_uncond": OptionInfo(0.0, "Negative Guidance minimum sigma", gr.Slider, {"minimum": 0.0, "maximum": 15.0, "step": 0.01}).link("PR", "https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/9177").info("skip negative prompt for some steps when the image is almost ready; 0=disable, higher=faster"),
|
||||
@@ -445,7 +508,7 @@ options_templates.update(options_section(('optimizations', "Optimizations"), {
|
||||
"token_merging_ratio_img2img": OptionInfo(0.0, "Token merging ratio for img2img", gr.Slider, {"minimum": 0.0, "maximum": 0.9, "step": 0.1}).info("only applies if non-zero and overrides above"),
|
||||
"token_merging_ratio_hr": OptionInfo(0.0, "Token merging ratio for high-res pass", gr.Slider, {"minimum": 0.0, "maximum": 0.9, "step": 0.1}).info("only applies if non-zero and overrides above"),
|
||||
"pad_cond_uncond": OptionInfo(False, "Pad prompt/negative prompt to be same length").info("improves performance when prompt and negative prompt have different lengths; changes seeds"),
|
||||
"experimental_persistent_cond_cache": OptionInfo(False, "persistent cond cache").info("Experimental, keep cond caches across jobs, reduce overhead."),
|
||||
"persistent_cond_cache": OptionInfo(True, "Persistent cond cache").info("Do not recalculate conds from prompts if prompts have not changed since previous calculation"),
|
||||
}))
|
||||
|
||||
options_templates.update(options_section(('compatibility', "Compatibility"), {
|
||||
@@ -457,7 +520,7 @@ options_templates.update(options_section(('compatibility', "Compatibility"), {
|
||||
"hires_fix_use_firstpass_conds": OptionInfo(False, "For hires fix, calculate conds of second pass using extra networks of first pass."),
|
||||
}))
|
||||
|
||||
options_templates.update(options_section(('interrogate', "Interrogate Options"), {
|
||||
options_templates.update(options_section(('interrogate', "Interrogate"), {
|
||||
"interrogate_keep_models_in_memory": OptionInfo(False, "Keep models in VRAM"),
|
||||
"interrogate_return_ranks": OptionInfo(False, "Include ranks of model tags matches in results.").info("booru only"),
|
||||
"interrogate_clip_num_beams": OptionInfo(1, "BLIP: num_beams", gr.Slider, {"minimum": 1, "maximum": 16, "step": 1}),
|
||||
@@ -481,19 +544,17 @@ options_templates.update(options_section(('extra_networks', "Extra Networks"), {
|
||||
"extra_networks_card_text_scale": OptionInfo(1.0, "Card text scale", gr.Slider, {"minimum": 0.0, "maximum": 2.0, "step": 0.01}).info("1 = original size"),
|
||||
"extra_networks_card_show_desc": OptionInfo(True, "Show description on card"),
|
||||
"extra_networks_add_text_separator": OptionInfo(" ", "Extra networks separator").info("extra text to add before <...> when adding extra network to prompt"),
|
||||
"ui_extra_networks_tab_reorder": OptionInfo("", "Extra networks tab order").needs_restart(),
|
||||
"ui_extra_networks_tab_reorder": OptionInfo("", "Extra networks tab order").needs_reload_ui(),
|
||||
"textual_inversion_print_at_load": OptionInfo(False, "Print a list of Textual Inversion embeddings when loading model"),
|
||||
"textual_inversion_add_hashes_to_infotext": OptionInfo(True, "Add Textual Inversion hashes to infotext"),
|
||||
"sd_hypernetwork": OptionInfo("None", "Add hypernetwork to prompt", gr.Dropdown, lambda: {"choices": ["None", *hypernetworks]}, refresh=reload_hypernetworks),
|
||||
}))
|
||||
|
||||
options_templates.update(options_section(('ui', "User interface"), {
|
||||
"localization": OptionInfo("None", "Localization", gr.Dropdown, lambda: {"choices": ["None"] + list(localization.localizations.keys())}, refresh=lambda: localization.list_localizations(cmd_opts.localizations_dir)).needs_restart(),
|
||||
"gradio_theme": OptionInfo("Default", "Gradio theme", ui_components.DropdownEditable, lambda: {"choices": ["Default"] + gradio_hf_hub_themes}).needs_restart(),
|
||||
"img2img_editor_height": OptionInfo(720, "img2img: height of image editor", gr.Slider, {"minimum": 80, "maximum": 1600, "step": 1}).info("in pixels").needs_restart(),
|
||||
"localization": OptionInfo("None", "Localization", gr.Dropdown, lambda: {"choices": ["None"] + list(localization.localizations.keys())}, refresh=lambda: localization.list_localizations(cmd_opts.localizations_dir)).needs_reload_ui(),
|
||||
"gradio_theme": OptionInfo("Default", "Gradio theme", ui_components.DropdownEditable, lambda: {"choices": ["Default"] + gradio_hf_hub_themes}).info("you can also manually enter any of themes from the <a href='https://huggingface.co/spaces/gradio/theme-gallery'>gallery</a>.").needs_reload_ui(),
|
||||
"gradio_themes_cache": OptionInfo(True, "Cache gradio themes locally").info("disable to update the selected Gradio theme"),
|
||||
"return_grid": OptionInfo(True, "Show grid in results for web"),
|
||||
"return_mask": OptionInfo(False, "For inpainting, include the greyscale mask in results for web"),
|
||||
"return_mask_composite": OptionInfo(False, "For inpainting, include masked composite in results for web"),
|
||||
"do_not_show_images": OptionInfo(False, "Do not show any images in results for web"),
|
||||
"send_seed": OptionInfo(True, "Send seed when sending prompt or image to other interface"),
|
||||
"send_size": OptionInfo(True, "Send size when sending prompt or image to another interface"),
|
||||
@@ -502,21 +563,22 @@ options_templates.update(options_section(('ui', "User interface"), {
|
||||
"js_modal_lightbox_gamepad": OptionInfo(False, "Navigate image viewer with gamepad"),
|
||||
"js_modal_lightbox_gamepad_repeat": OptionInfo(250, "Gamepad repeat period, in milliseconds"),
|
||||
"show_progress_in_title": OptionInfo(True, "Show generation progress in window title."),
|
||||
"samplers_in_dropdown": OptionInfo(True, "Use dropdown for sampler selection instead of radio group").needs_restart(),
|
||||
"dimensions_and_batch_together": OptionInfo(True, "Show Width/Height and Batch sliders in same row").needs_restart(),
|
||||
"samplers_in_dropdown": OptionInfo(True, "Use dropdown for sampler selection instead of radio group").needs_reload_ui(),
|
||||
"dimensions_and_batch_together": OptionInfo(True, "Show Width/Height and Batch sliders in same row").needs_reload_ui(),
|
||||
"keyedit_precision_attention": OptionInfo(0.1, "Ctrl+up/down precision when editing (attention:1.1)", gr.Slider, {"minimum": 0.01, "maximum": 0.2, "step": 0.001}),
|
||||
"keyedit_precision_extra": OptionInfo(0.05, "Ctrl+up/down precision when editing <extra networks:0.9>", gr.Slider, {"minimum": 0.01, "maximum": 0.2, "step": 0.001}),
|
||||
"keyedit_delimiters": OptionInfo(".,\\/!?%^*;:{}=`~()", "Ctrl+up/down word delimiters"),
|
||||
"keyedit_move": OptionInfo(True, "Alt+left/right moves prompt elements"),
|
||||
"quicksettings_list": OptionInfo(["sd_model_checkpoint"], "Quicksettings list", ui_components.DropdownMulti, lambda: {"choices": list(opts.data_labels.keys())}).js("info", "settingsHintsShowQuicksettings").info("setting entries that appear at the top of page rather than in settings tab").needs_restart(),
|
||||
"ui_tab_order": OptionInfo([], "UI tab order", ui_components.DropdownMulti, lambda: {"choices": list(tab_names)}).needs_restart(),
|
||||
"hidden_tabs": OptionInfo([], "Hidden UI tabs", ui_components.DropdownMulti, lambda: {"choices": list(tab_names)}).needs_restart(),
|
||||
"ui_reorder_list": OptionInfo([], "txt2img/img2img UI item order", ui_components.DropdownMulti, lambda: {"choices": list(shared_items.ui_reorder_categories())}).info("selected items appear first").needs_restart(),
|
||||
"hires_fix_show_sampler": OptionInfo(False, "Hires fix: show hires sampler selection").needs_restart(),
|
||||
"hires_fix_show_prompts": OptionInfo(False, "Hires fix: show hires prompt and negative prompt").needs_restart(),
|
||||
"disable_token_counters": OptionInfo(False, "Disable prompt token counters").needs_restart(),
|
||||
"quicksettings_list": OptionInfo(["sd_model_checkpoint"], "Quicksettings list", ui_components.DropdownMulti, lambda: {"choices": list(opts.data_labels.keys())}).js("info", "settingsHintsShowQuicksettings").info("setting entries that appear at the top of page rather than in settings tab").needs_reload_ui(),
|
||||
"ui_tab_order": OptionInfo([], "UI tab order", ui_components.DropdownMulti, lambda: {"choices": list(tab_names)}).needs_reload_ui(),
|
||||
"hidden_tabs": OptionInfo([], "Hidden UI tabs", ui_components.DropdownMulti, lambda: {"choices": list(tab_names)}).needs_reload_ui(),
|
||||
"ui_reorder_list": OptionInfo([], "txt2img/img2img UI item order", ui_components.DropdownMulti, lambda: {"choices": list(shared_items.ui_reorder_categories())}).info("selected items appear first").needs_reload_ui(),
|
||||
"hires_fix_show_sampler": OptionInfo(False, "Hires fix: show hires checkpoint and sampler selection").needs_reload_ui(),
|
||||
"hires_fix_show_prompts": OptionInfo(False, "Hires fix: show hires prompt and negative prompt").needs_reload_ui(),
|
||||
"disable_token_counters": OptionInfo(False, "Disable prompt token counters").needs_reload_ui(),
|
||||
}))
|
||||
|
||||
|
||||
options_templates.update(options_section(('infotext', "Infotext"), {
|
||||
"add_model_hash_to_info": OptionInfo(True, "Add model hash to generation information"),
|
||||
"add_model_name_to_info": OptionInfo(True, "Add model name to generation information"),
|
||||
@@ -544,12 +606,13 @@ options_templates.update(options_section(('ui', "Live previews"), {
|
||||
}))
|
||||
|
||||
options_templates.update(options_section(('sampler-params', "Sampler parameters"), {
|
||||
"hide_samplers": OptionInfo([], "Hide samplers in user interface", gr.CheckboxGroup, lambda: {"choices": [x.name for x in list_samplers()]}).needs_restart(),
|
||||
"hide_samplers": OptionInfo([], "Hide samplers in user interface", gr.CheckboxGroup, lambda: {"choices": [x.name for x in list_samplers()]}).needs_reload_ui(),
|
||||
"eta_ddim": OptionInfo(0.0, "Eta for DDIM", gr.Slider, {"minimum": 0.0, "maximum": 1.0, "step": 0.01}).info("noise multiplier; higher = more unperdictable results"),
|
||||
"eta_ancestral": OptionInfo(1.0, "Eta for ancestral samplers", gr.Slider, {"minimum": 0.0, "maximum": 1.0, "step": 0.01}).info("noise multiplier; applies to Euler a and other samplers that have a in them"),
|
||||
"ddim_discretize": OptionInfo('uniform', "img2img DDIM discretize", gr.Radio, {"choices": ['uniform', 'quad']}),
|
||||
's_churn': OptionInfo(0.0, "sigma churn", gr.Slider, {"minimum": 0.0, "maximum": 1.0, "step": 0.01}),
|
||||
's_churn': OptionInfo(0.0, "sigma churn", gr.Slider, {"minimum": 0.0, "maximum": 100.0, "step": 0.01}),
|
||||
's_tmin': OptionInfo(0.0, "sigma tmin", gr.Slider, {"minimum": 0.0, "maximum": 1.0, "step": 0.01}),
|
||||
's_tmax': OptionInfo(0.0, "sigma tmax", gr.Slider, {"minimum": 0.0, "maximum": 999.0, "step": 0.01}).info("0 = inf"),
|
||||
's_noise': OptionInfo(1.0, "sigma noise", gr.Slider, {"minimum": 0.0, "maximum": 1.0, "step": 0.01}),
|
||||
'k_sched_type': OptionInfo("Automatic", "scheduler type", gr.Dropdown, {"choices": ["Automatic", "karras", "exponential", "polyexponential"]}).info("lets you override the noise schedule for k-diffusion samplers; choosing Automatic disables the three parameters below"),
|
||||
'sigma_min': OptionInfo(0.0, "sigma min", gr.Number).info("0 = default (~0.03); minimum noise strength for k-diffusion noise scheduler"),
|
||||
@@ -594,6 +657,9 @@ class Options:
|
||||
assert not cmd_opts.freeze_settings, "changing settings is disabled"
|
||||
|
||||
info = opts.data_labels.get(key, None)
|
||||
if info.do_not_save:
|
||||
return
|
||||
|
||||
comp_args = info.component_args if info else None
|
||||
if isinstance(comp_args, dict) and comp_args.get('visible', True) is False:
|
||||
raise RuntimeError(f"not possible to set {key} because it is restricted")
|
||||
@@ -623,6 +689,9 @@ class Options:
|
||||
if oldval == value:
|
||||
return False
|
||||
|
||||
if self.data_labels[key].do_not_save:
|
||||
return False
|
||||
|
||||
try:
|
||||
setattr(self, key, value)
|
||||
except RuntimeError:
|
||||
@@ -799,13 +868,19 @@ def reload_gradio_theme(theme_name=None):
|
||||
gradio_theme = gr.themes.Default(**default_theme_args)
|
||||
else:
|
||||
try:
|
||||
gradio_theme = gr.themes.ThemeClass.from_hub(theme_name)
|
||||
theme_cache_dir = os.path.join(script_path, 'tmp', 'gradio_themes')
|
||||
theme_cache_path = os.path.join(theme_cache_dir, f'{theme_name.replace("/", "_")}.json')
|
||||
if opts.gradio_themes_cache and os.path.exists(theme_cache_path):
|
||||
gradio_theme = gr.themes.ThemeClass.load(theme_cache_path)
|
||||
else:
|
||||
os.makedirs(theme_cache_dir, exist_ok=True)
|
||||
gradio_theme = gr.themes.ThemeClass.from_hub(theme_name)
|
||||
gradio_theme.dump(theme_cache_path)
|
||||
except Exception as e:
|
||||
errors.display(e, "changing gradio theme")
|
||||
gradio_theme = gr.themes.Default(**default_theme_args)
|
||||
|
||||
|
||||
|
||||
class TotalTQDM:
|
||||
def __init__(self):
|
||||
self._tqdm = None
|
||||
@@ -889,3 +964,10 @@ def walk_files(path, allowed_extensions=None):
|
||||
continue
|
||||
|
||||
yield os.path.join(root, filename)
|
||||
|
||||
|
||||
def ldm_print(*args, **kwargs):
|
||||
if opts.hide_ldm_prints:
|
||||
return
|
||||
|
||||
print(*args, **kwargs)
|
||||
|
||||
+1
-4
@@ -106,10 +106,7 @@ class StyleDatabase:
|
||||
if os.path.exists(path):
|
||||
shutil.copy(path, f"{path}.bak")
|
||||
|
||||
fd = os.open(path, os.O_RDWR | os.O_CREAT)
|
||||
with os.fdopen(fd, "w", encoding="utf-8-sig", newline='') as file:
|
||||
# _fields is actually part of the public API: typing.NamedTuple is a replacement for collections.NamedTuple,
|
||||
# and collections.NamedTuple has explicit documentation for accessing _fields. Same goes for _asdict()
|
||||
with open(path, "w", encoding="utf-8-sig", newline='') as file:
|
||||
writer = csv.DictWriter(file, fieldnames=PromptStyle._fields)
|
||||
writer.writeheader()
|
||||
writer.writerows(style._asdict() for k, style in self.styles.items())
|
||||
|
||||
+5
-1
@@ -109,11 +109,15 @@ def format_traceback(tb):
|
||||
return [[f"{x.filename}, line {x.lineno}, {x.name}", x.line] for x in traceback.extract_tb(tb)]
|
||||
|
||||
|
||||
def format_exception(e, tb):
|
||||
return {"exception": str(e), "traceback": format_traceback(tb)}
|
||||
|
||||
|
||||
def get_exceptions():
|
||||
try:
|
||||
from modules import errors
|
||||
|
||||
return [{"exception": str(e), "traceback": format_traceback(tb)} for e, tb in reversed(errors.exception_records)]
|
||||
return list(reversed(errors.exception_records))
|
||||
except Exception as e:
|
||||
return str(e)
|
||||
|
||||
|
||||
@@ -13,7 +13,7 @@ import numpy as np
|
||||
from PIL import Image, PngImagePlugin
|
||||
from torch.utils.tensorboard import SummaryWriter
|
||||
|
||||
from modules import shared, devices, sd_hijack, processing, sd_models, images, sd_samplers, sd_hijack_checkpoint, errors, hashes
|
||||
from modules import shared, devices, sd_hijack, sd_models, images, sd_samplers, sd_hijack_checkpoint, errors, hashes
|
||||
import modules.textual_inversion.dataset
|
||||
from modules.textual_inversion.learn_schedule import LearnRateScheduler
|
||||
|
||||
@@ -181,29 +181,38 @@ class EmbeddingDatabase:
|
||||
else:
|
||||
return
|
||||
|
||||
|
||||
# textual inversion embeddings
|
||||
if 'string_to_param' in data:
|
||||
param_dict = data['string_to_param']
|
||||
param_dict = getattr(param_dict, '_parameters', param_dict) # fix for torch 1.12.1 loading saved file from torch 1.11
|
||||
assert len(param_dict) == 1, 'embedding file has multiple terms in it'
|
||||
emb = next(iter(param_dict.items()))[1]
|
||||
# diffuser concepts
|
||||
elif type(data) == dict and type(next(iter(data.values()))) == torch.Tensor:
|
||||
vec = emb.detach().to(devices.device, dtype=torch.float32)
|
||||
shape = vec.shape[-1]
|
||||
vectors = vec.shape[0]
|
||||
elif type(data) == dict and 'clip_g' in data and 'clip_l' in data: # SDXL embedding
|
||||
vec = {k: v.detach().to(devices.device, dtype=torch.float32) for k, v in data.items()}
|
||||
shape = data['clip_g'].shape[-1] + data['clip_l'].shape[-1]
|
||||
vectors = data['clip_g'].shape[0]
|
||||
elif type(data) == dict and type(next(iter(data.values()))) == torch.Tensor: # diffuser concepts
|
||||
assert len(data.keys()) == 1, 'embedding file has multiple terms in it'
|
||||
|
||||
emb = next(iter(data.values()))
|
||||
if len(emb.shape) == 1:
|
||||
emb = emb.unsqueeze(0)
|
||||
vec = emb.detach().to(devices.device, dtype=torch.float32)
|
||||
shape = vec.shape[-1]
|
||||
vectors = vec.shape[0]
|
||||
else:
|
||||
raise Exception(f"Couldn't identify {filename} as neither textual inversion embedding nor diffuser concept.")
|
||||
|
||||
vec = emb.detach().to(devices.device, dtype=torch.float32)
|
||||
embedding = Embedding(vec, name)
|
||||
embedding.step = data.get('step', None)
|
||||
embedding.sd_checkpoint = data.get('sd_checkpoint', None)
|
||||
embedding.sd_checkpoint_name = data.get('sd_checkpoint_name', None)
|
||||
embedding.vectors = vec.shape[0]
|
||||
embedding.shape = vec.shape[-1]
|
||||
embedding.vectors = vectors
|
||||
embedding.shape = shape
|
||||
embedding.filename = path
|
||||
embedding.set_hash(hashes.sha256(embedding.filename, "textual_inversion/" + name) or '')
|
||||
|
||||
@@ -378,6 +387,8 @@ def validate_train_inputs(model_name, learn_rate, batch_size, gradient_step, dat
|
||||
|
||||
|
||||
def train_embedding(id_task, embedding_name, learn_rate, batch_size, gradient_step, data_root, log_directory, training_width, training_height, varsize, steps, clip_grad_mode, clip_grad_value, shuffle_tags, tag_drop_out, latent_sampling_method, use_weight, create_image_every, save_embedding_every, template_filename, save_image_with_stored_embedding, preview_from_txt2img, preview_prompt, preview_negative_prompt, preview_steps, preview_sampler_index, preview_cfg_scale, preview_seed, preview_width, preview_height):
|
||||
from modules import processing
|
||||
|
||||
save_embedding_every = save_embedding_every or 0
|
||||
create_image_every = create_image_every or 0
|
||||
template_file = textual_inversion_templates.get(template_filename, None)
|
||||
|
||||
+19
-4
@@ -1,4 +1,5 @@
|
||||
import time
|
||||
import argparse
|
||||
|
||||
|
||||
class TimerSubcategory:
|
||||
@@ -11,20 +12,27 @@ class TimerSubcategory:
|
||||
def __enter__(self):
|
||||
self.start = time.time()
|
||||
self.timer.base_category = self.original_base_category + self.category + "/"
|
||||
self.timer.subcategory_level += 1
|
||||
|
||||
if self.timer.print_log:
|
||||
print(f"{' ' * self.timer.subcategory_level}{self.category}:")
|
||||
|
||||
def __exit__(self, exc_type, exc_val, exc_tb):
|
||||
elapsed_for_subcategroy = time.time() - self.start
|
||||
self.timer.base_category = self.original_base_category
|
||||
self.timer.add_time_to_record(self.original_base_category + self.category, elapsed_for_subcategroy)
|
||||
self.timer.record(self.category)
|
||||
self.timer.subcategory_level -= 1
|
||||
self.timer.record(self.category, disable_log=True)
|
||||
|
||||
|
||||
class Timer:
|
||||
def __init__(self):
|
||||
def __init__(self, print_log=False):
|
||||
self.start = time.time()
|
||||
self.records = {}
|
||||
self.total = 0
|
||||
self.base_category = ''
|
||||
self.print_log = print_log
|
||||
self.subcategory_level = 0
|
||||
|
||||
def elapsed(self):
|
||||
end = time.time()
|
||||
@@ -38,13 +46,16 @@ class Timer:
|
||||
|
||||
self.records[category] += amount
|
||||
|
||||
def record(self, category, extra_time=0):
|
||||
def record(self, category, extra_time=0, disable_log=False):
|
||||
e = self.elapsed()
|
||||
|
||||
self.add_time_to_record(self.base_category + category, e + extra_time)
|
||||
|
||||
self.total += e + extra_time
|
||||
|
||||
if self.print_log and not disable_log:
|
||||
print(f"{' ' * self.subcategory_level}{category}: done in {e + extra_time:.3f}s")
|
||||
|
||||
def subcategory(self, name):
|
||||
self.elapsed()
|
||||
|
||||
@@ -71,6 +82,10 @@ class Timer:
|
||||
self.__init__()
|
||||
|
||||
|
||||
startup_timer = Timer()
|
||||
parser = argparse.ArgumentParser(add_help=False)
|
||||
parser.add_argument("--log-startup", action='store_true', help="print a detailed log of what's happening at startup")
|
||||
args = parser.parse_known_args()[0]
|
||||
|
||||
startup_timer = Timer(print_log=args.log_startup)
|
||||
|
||||
startup_record = None
|
||||
|
||||
+2
-1
@@ -9,7 +9,7 @@ from modules.ui import plaintext_to_html
|
||||
import gradio as gr
|
||||
|
||||
|
||||
def txt2img(id_task: str, prompt: str, negative_prompt: str, prompt_styles, steps: int, sampler_index: int, restore_faces: bool, tiling: bool, n_iter: int, batch_size: int, cfg_scale: float, seed: int, subseed: int, subseed_strength: float, seed_resize_from_h: int, seed_resize_from_w: int, seed_enable_extras: bool, height: int, width: int, enable_hr: bool, denoising_strength: float, hr_scale: float, hr_upscaler: str, hr_second_pass_steps: int, hr_resize_x: int, hr_resize_y: int, hr_sampler_index: int, hr_prompt: str, hr_negative_prompt, override_settings_texts, request: gr.Request, *args):
|
||||
def txt2img(id_task: str, prompt: str, negative_prompt: str, prompt_styles, steps: int, sampler_index: int, restore_faces: bool, tiling: bool, n_iter: int, batch_size: int, cfg_scale: float, seed: int, subseed: int, subseed_strength: float, seed_resize_from_h: int, seed_resize_from_w: int, seed_enable_extras: bool, height: int, width: int, enable_hr: bool, denoising_strength: float, hr_scale: float, hr_upscaler: str, hr_second_pass_steps: int, hr_resize_x: int, hr_resize_y: int, hr_checkpoint_name: str, hr_sampler_index: int, hr_prompt: str, hr_negative_prompt, override_settings_texts, request: gr.Request, *args):
|
||||
override_settings = create_override_settings_dict(override_settings_texts)
|
||||
|
||||
p = processing.StableDiffusionProcessingTxt2Img(
|
||||
@@ -41,6 +41,7 @@ def txt2img(id_task: str, prompt: str, negative_prompt: str, prompt_styles, step
|
||||
hr_second_pass_steps=hr_second_pass_steps,
|
||||
hr_resize_x=hr_resize_x,
|
||||
hr_resize_y=hr_resize_y,
|
||||
hr_checkpoint_name=None if hr_checkpoint_name == 'Use same checkpoint' else hr_checkpoint_name,
|
||||
hr_sampler_name=sd_samplers.samplers_for_img2img[hr_sampler_index - 1].name if hr_sampler_index != 0 else None,
|
||||
hr_prompt=hr_prompt,
|
||||
hr_negative_prompt=hr_negative_prompt,
|
||||
|
||||
+155
-300
@@ -12,34 +12,30 @@ import numpy as np
|
||||
from PIL import Image, PngImagePlugin # noqa: F401
|
||||
from modules.call_queue import wrap_gradio_gpu_call, wrap_queued_call, wrap_gradio_call
|
||||
|
||||
from modules import sd_hijack, sd_models, script_callbacks, ui_extensions, deepbooru, sd_vae, extra_networks, ui_common, ui_postprocessing, progress, ui_loadsave, errors, shared_items, ui_settings, timer, sysinfo
|
||||
from modules import gradio_extensons # noqa: F401
|
||||
from modules import sd_hijack, sd_models, script_callbacks, ui_extensions, deepbooru, extra_networks, ui_common, ui_postprocessing, progress, ui_loadsave, errors, shared_items, ui_settings, timer, sysinfo, ui_checkpoint_merger, ui_prompt_styles, scripts
|
||||
from modules.ui_components import FormRow, FormGroup, ToolButton, FormHTML
|
||||
from modules.paths import script_path
|
||||
from modules.ui_common import create_refresh_button
|
||||
from modules.ui_gradio_extensions import reload_javascript
|
||||
|
||||
|
||||
from modules.shared import opts, cmd_opts
|
||||
|
||||
import modules.codeformer_model
|
||||
import modules.generation_parameters_copypaste as parameters_copypaste
|
||||
import modules.gfpgan_model
|
||||
import modules.hypernetworks.ui
|
||||
import modules.scripts
|
||||
import modules.hypernetworks.ui as hypernetworks_ui
|
||||
import modules.textual_inversion.ui as textual_inversion_ui
|
||||
import modules.textual_inversion.textual_inversion as textual_inversion
|
||||
import modules.shared as shared
|
||||
import modules.styles
|
||||
import modules.textual_inversion.ui
|
||||
import modules.images
|
||||
from modules import prompt_parser
|
||||
from modules.sd_hijack import model_hijack
|
||||
from modules.sd_samplers import samplers, samplers_for_img2img
|
||||
from modules.textual_inversion import textual_inversion
|
||||
import modules.hypernetworks.ui
|
||||
from modules.generation_parameters_copypaste import image_from_url_text
|
||||
import modules.extras
|
||||
|
||||
create_setting_component = ui_settings.create_setting_component
|
||||
|
||||
warnings.filterwarnings("default" if opts.show_warnings else "ignore", category=UserWarning)
|
||||
warnings.filterwarnings("default" if opts.show_gradio_deprecation_warnings else "ignore", category=gr.deprecation.GradioDeprecationWarning)
|
||||
|
||||
# this is a fix for Windows users. Without it, javascript files will be served with text/html content-type and the browser will not show any UI
|
||||
mimetypes.init()
|
||||
@@ -92,19 +88,6 @@ def send_gradio_gallery_to_image(x):
|
||||
return image_from_url_text(x[0])
|
||||
|
||||
|
||||
def add_style(name: str, prompt: str, negative_prompt: str):
|
||||
if name is None:
|
||||
return [gr_show() for x in range(4)]
|
||||
|
||||
style = modules.styles.PromptStyle(name, prompt, negative_prompt)
|
||||
shared.prompt_styles.styles[style.name] = style
|
||||
# Save all loaded prompt styles: this allows us to update the storage format in the future more easily, because we
|
||||
# reserialize all styles every time we save them
|
||||
shared.prompt_styles.save_styles(shared.styles_filename)
|
||||
|
||||
return [gr.Dropdown.update(visible=True, choices=list(shared.prompt_styles.styles)) for _ in range(2)]
|
||||
|
||||
|
||||
def calc_resolution_hires(enable, width, height, hr_scale, hr_resize_x, hr_resize_y):
|
||||
from modules import processing, devices
|
||||
|
||||
@@ -129,13 +112,6 @@ def resize_from_to_html(width, height, scale_by):
|
||||
return f"resize: from <span class='resolution'>{width}x{height}</span> to <span class='resolution'>{target_width}x{target_height}</span>"
|
||||
|
||||
|
||||
def apply_styles(prompt, prompt_neg, styles):
|
||||
prompt = shared.prompt_styles.apply_styles_to_prompt(prompt, styles)
|
||||
prompt_neg = shared.prompt_styles.apply_negative_styles_to_prompt(prompt_neg, styles)
|
||||
|
||||
return [gr.Textbox.update(value=prompt), gr.Textbox.update(value=prompt_neg), gr.Dropdown.update(value=[])]
|
||||
|
||||
|
||||
def process_interrogate(interrogation_function, mode, ii_input_dir, ii_output_dir, *ii_singles):
|
||||
if mode in {0, 1, 3, 4}:
|
||||
return [interrogation_function(ii_singles[mode]), None]
|
||||
@@ -172,7 +148,6 @@ def interrogate_deepbooru(image):
|
||||
def create_seed_inputs(target_interface):
|
||||
with FormRow(elem_id=f"{target_interface}_seed_row", variant="compact"):
|
||||
seed = (gr.Textbox if cmd_opts.use_textbox_seed else gr.Number)(label='Seed', value=-1, elem_id=f"{target_interface}_seed")
|
||||
seed.style(container=False)
|
||||
random_seed = ToolButton(random_symbol, elem_id=f"{target_interface}_random_seed", label='Random seed')
|
||||
reuse_seed = ToolButton(reuse_symbol, elem_id=f"{target_interface}_reuse_seed", label='Reuse seed')
|
||||
|
||||
@@ -184,7 +159,6 @@ def create_seed_inputs(target_interface):
|
||||
with FormRow(visible=False, elem_id=f"{target_interface}_subseed_row") as seed_extra_row_1:
|
||||
seed_extras.append(seed_extra_row_1)
|
||||
subseed = gr.Number(label='Variation seed', value=-1, elem_id=f"{target_interface}_subseed")
|
||||
subseed.style(container=False)
|
||||
random_subseed = ToolButton(random_symbol, elem_id=f"{target_interface}_random_subseed")
|
||||
reuse_subseed = ToolButton(reuse_symbol, elem_id=f"{target_interface}_reuse_subseed")
|
||||
subseed_strength = gr.Slider(label='Variation strength', value=0.0, minimum=0, maximum=1, step=0.01, elem_id=f"{target_interface}_subseed_strength")
|
||||
@@ -267,71 +241,76 @@ def update_token_counter(text, steps):
|
||||
return f"<span class='gr-box gr-text-input'>{token_count}/{max_length}</span>"
|
||||
|
||||
|
||||
def create_toprow(is_img2img):
|
||||
id_part = "img2img" if is_img2img else "txt2img"
|
||||
class Toprow:
|
||||
"""Creates a top row UI with prompts, generate button, styles, extra little buttons for things, and enables some functionality related to their operation"""
|
||||
|
||||
with gr.Row(elem_id=f"{id_part}_toprow", variant="compact"):
|
||||
with gr.Column(elem_id=f"{id_part}_prompt_container", scale=6):
|
||||
with gr.Row():
|
||||
with gr.Column(scale=80):
|
||||
with gr.Row():
|
||||
prompt = gr.Textbox(label="Prompt", elem_id=f"{id_part}_prompt", show_label=False, lines=3, placeholder="Prompt (press Ctrl+Enter or Alt+Enter to generate)", elem_classes=["prompt"])
|
||||
def __init__(self, is_img2img):
|
||||
id_part = "img2img" if is_img2img else "txt2img"
|
||||
self.id_part = id_part
|
||||
|
||||
with gr.Row():
|
||||
with gr.Column(scale=80):
|
||||
with gr.Row():
|
||||
negative_prompt = gr.Textbox(label="Negative prompt", elem_id=f"{id_part}_neg_prompt", show_label=False, lines=3, placeholder="Negative prompt (press Ctrl+Enter or Alt+Enter to generate)", elem_classes=["prompt"])
|
||||
with gr.Row(elem_id=f"{id_part}_toprow", variant="compact"):
|
||||
with gr.Column(elem_id=f"{id_part}_prompt_container", scale=6):
|
||||
with gr.Row():
|
||||
with gr.Column(scale=80):
|
||||
with gr.Row():
|
||||
self.prompt = gr.Textbox(label="Prompt", elem_id=f"{id_part}_prompt", show_label=False, lines=3, placeholder="Prompt (press Ctrl+Enter or Alt+Enter to generate)", elem_classes=["prompt"])
|
||||
self.prompt_img = gr.File(label="", elem_id=f"{id_part}_prompt_image", file_count="single", type="binary", visible=False)
|
||||
|
||||
button_interrogate = None
|
||||
button_deepbooru = None
|
||||
if is_img2img:
|
||||
with gr.Column(scale=1, elem_classes="interrogate-col"):
|
||||
button_interrogate = gr.Button('Interrogate\nCLIP', elem_id="interrogate")
|
||||
button_deepbooru = gr.Button('Interrogate\nDeepBooru', elem_id="deepbooru")
|
||||
with gr.Row():
|
||||
with gr.Column(scale=80):
|
||||
with gr.Row():
|
||||
self.negative_prompt = gr.Textbox(label="Negative prompt", elem_id=f"{id_part}_neg_prompt", show_label=False, lines=3, placeholder="Negative prompt (press Ctrl+Enter or Alt+Enter to generate)", elem_classes=["prompt"])
|
||||
|
||||
with gr.Column(scale=1, elem_id=f"{id_part}_actions_column"):
|
||||
with gr.Row(elem_id=f"{id_part}_generate_box", elem_classes="generate-box"):
|
||||
interrupt = gr.Button('Interrupt', elem_id=f"{id_part}_interrupt", elem_classes="generate-box-interrupt")
|
||||
skip = gr.Button('Skip', elem_id=f"{id_part}_skip", elem_classes="generate-box-skip")
|
||||
submit = gr.Button('Generate', elem_id=f"{id_part}_generate", variant='primary')
|
||||
self.button_interrogate = None
|
||||
self.button_deepbooru = None
|
||||
if is_img2img:
|
||||
with gr.Column(scale=1, elem_classes="interrogate-col"):
|
||||
self.button_interrogate = gr.Button('Interrogate\nCLIP', elem_id="interrogate")
|
||||
self.button_deepbooru = gr.Button('Interrogate\nDeepBooru', elem_id="deepbooru")
|
||||
|
||||
skip.click(
|
||||
fn=lambda: shared.state.skip(),
|
||||
inputs=[],
|
||||
outputs=[],
|
||||
)
|
||||
with gr.Column(scale=1, elem_id=f"{id_part}_actions_column"):
|
||||
with gr.Row(elem_id=f"{id_part}_generate_box", elem_classes="generate-box"):
|
||||
self.interrupt = gr.Button('Interrupt', elem_id=f"{id_part}_interrupt", elem_classes="generate-box-interrupt")
|
||||
self.skip = gr.Button('Skip', elem_id=f"{id_part}_skip", elem_classes="generate-box-skip")
|
||||
self.submit = gr.Button('Generate', elem_id=f"{id_part}_generate", variant='primary')
|
||||
|
||||
interrupt.click(
|
||||
fn=lambda: shared.state.interrupt(),
|
||||
inputs=[],
|
||||
outputs=[],
|
||||
)
|
||||
self.skip.click(
|
||||
fn=lambda: shared.state.skip(),
|
||||
inputs=[],
|
||||
outputs=[],
|
||||
)
|
||||
|
||||
with gr.Row(elem_id=f"{id_part}_tools"):
|
||||
paste = ToolButton(value=paste_symbol, elem_id="paste")
|
||||
clear_prompt_button = ToolButton(value=clear_prompt_symbol, elem_id=f"{id_part}_clear_prompt")
|
||||
extra_networks_button = ToolButton(value=extra_networks_symbol, elem_id=f"{id_part}_extra_networks")
|
||||
prompt_style_apply = ToolButton(value=apply_style_symbol, elem_id=f"{id_part}_style_apply")
|
||||
save_style = ToolButton(value=save_style_symbol, elem_id=f"{id_part}_style_create")
|
||||
restore_progress_button = ToolButton(value=restore_progress_symbol, elem_id=f"{id_part}_restore_progress", visible=False)
|
||||
self.interrupt.click(
|
||||
fn=lambda: shared.state.interrupt(),
|
||||
inputs=[],
|
||||
outputs=[],
|
||||
)
|
||||
|
||||
token_counter = gr.HTML(value="<span>0/75</span>", elem_id=f"{id_part}_token_counter", elem_classes=["token-counter"])
|
||||
token_button = gr.Button(visible=False, elem_id=f"{id_part}_token_button")
|
||||
negative_token_counter = gr.HTML(value="<span>0/75</span>", elem_id=f"{id_part}_negative_token_counter", elem_classes=["token-counter"])
|
||||
negative_token_button = gr.Button(visible=False, elem_id=f"{id_part}_negative_token_button")
|
||||
with gr.Row(elem_id=f"{id_part}_tools"):
|
||||
self.paste = ToolButton(value=paste_symbol, elem_id="paste")
|
||||
self.clear_prompt_button = ToolButton(value=clear_prompt_symbol, elem_id=f"{id_part}_clear_prompt")
|
||||
self.restore_progress_button = ToolButton(value=restore_progress_symbol, elem_id=f"{id_part}_restore_progress", visible=False)
|
||||
|
||||
clear_prompt_button.click(
|
||||
fn=lambda *x: x,
|
||||
_js="confirm_clear_prompt",
|
||||
inputs=[prompt, negative_prompt],
|
||||
outputs=[prompt, negative_prompt],
|
||||
)
|
||||
self.token_counter = gr.HTML(value="<span>0/75</span>", elem_id=f"{id_part}_token_counter", elem_classes=["token-counter"])
|
||||
self.token_button = gr.Button(visible=False, elem_id=f"{id_part}_token_button")
|
||||
self.negative_token_counter = gr.HTML(value="<span>0/75</span>", elem_id=f"{id_part}_negative_token_counter", elem_classes=["token-counter"])
|
||||
self.negative_token_button = gr.Button(visible=False, elem_id=f"{id_part}_negative_token_button")
|
||||
|
||||
with gr.Row(elem_id=f"{id_part}_styles_row"):
|
||||
prompt_styles = gr.Dropdown(label="Styles", elem_id=f"{id_part}_styles", choices=[k for k, v in shared.prompt_styles.styles.items()], value=[], multiselect=True)
|
||||
create_refresh_button(prompt_styles, shared.prompt_styles.reload, lambda: {"choices": [k for k, v in shared.prompt_styles.styles.items()]}, f"refresh_{id_part}_styles")
|
||||
self.clear_prompt_button.click(
|
||||
fn=lambda *x: x,
|
||||
_js="confirm_clear_prompt",
|
||||
inputs=[self.prompt, self.negative_prompt],
|
||||
outputs=[self.prompt, self.negative_prompt],
|
||||
)
|
||||
|
||||
return prompt, prompt_styles, negative_prompt, submit, button_interrogate, button_deepbooru, prompt_style_apply, save_style, paste, extra_networks_button, token_counter, token_button, negative_token_counter, negative_token_button, restore_progress_button
|
||||
self.ui_styles = ui_prompt_styles.UiPromptStyles(id_part, self.prompt, self.negative_prompt)
|
||||
|
||||
self.prompt_img.change(
|
||||
fn=modules.images.image_data,
|
||||
inputs=[self.prompt_img],
|
||||
outputs=[self.prompt, self.prompt_img],
|
||||
show_progress=False,
|
||||
)
|
||||
|
||||
|
||||
def setup_progressbar(*args, **kwargs):
|
||||
@@ -415,22 +394,20 @@ def create_ui():
|
||||
|
||||
parameters_copypaste.reset()
|
||||
|
||||
modules.scripts.scripts_current = modules.scripts.scripts_txt2img
|
||||
modules.scripts.scripts_txt2img.initialize_scripts(is_img2img=False)
|
||||
scripts.scripts_current = scripts.scripts_txt2img
|
||||
scripts.scripts_txt2img.initialize_scripts(is_img2img=False)
|
||||
|
||||
with gr.Blocks(analytics_enabled=False) as txt2img_interface:
|
||||
txt2img_prompt, txt2img_prompt_styles, txt2img_negative_prompt, submit, _, _, txt2img_prompt_style_apply, txt2img_save_style, txt2img_paste, extra_networks_button, token_counter, token_button, negative_token_counter, negative_token_button, restore_progress_button = create_toprow(is_img2img=False)
|
||||
toprow = Toprow(is_img2img=False)
|
||||
|
||||
dummy_component = gr.Label(visible=False)
|
||||
txt_prompt_img = gr.File(label="", elem_id="txt2img_prompt_image", file_count="single", type="binary", visible=False)
|
||||
|
||||
with FormRow(variant='compact', elem_id="txt2img_extra_networks", visible=False) as extra_networks:
|
||||
from modules import ui_extra_networks
|
||||
extra_networks_ui = ui_extra_networks.create_ui(extra_networks, extra_networks_button, 'txt2img')
|
||||
extra_tabs = gr.Tabs(elem_id="txt2img_extra_tabs")
|
||||
extra_tabs.__enter__()
|
||||
|
||||
with gr.Row().style(equal_height=False):
|
||||
with gr.Tab("Generation", id="txt2img_generation") as txt2img_generation_tab, gr.Row().style(equal_height=False):
|
||||
with gr.Column(variant='compact', elem_id="txt2img_settings"):
|
||||
modules.scripts.scripts_txt2img.prepare_ui()
|
||||
scripts.scripts_txt2img.prepare_ui()
|
||||
|
||||
for category in ordered_ui_categories():
|
||||
if category == "sampler":
|
||||
@@ -476,6 +453,10 @@ def create_ui():
|
||||
hr_resize_y = gr.Slider(minimum=0, maximum=2048, step=8, label="Resize height to", value=0, elem_id="txt2img_hr_resize_y")
|
||||
|
||||
with FormRow(elem_id="txt2img_hires_fix_row3", variant="compact", visible=opts.hires_fix_show_sampler) as hr_sampler_container:
|
||||
|
||||
hr_checkpoint_name = gr.Dropdown(label='Hires checkpoint', elem_id="hr_checkpoint", choices=["Use same checkpoint"] + modules.sd_models.checkpoint_tiles(use_short=True), value="Use same checkpoint")
|
||||
create_refresh_button(hr_checkpoint_name, modules.sd_models.list_models, lambda: {"choices": ["Use same checkpoint"] + modules.sd_models.checkpoint_tiles(use_short=True)}, "hr_checkpoint_refresh")
|
||||
|
||||
hr_sampler_index = gr.Dropdown(label='Hires sampling method', elem_id="hr_sampler", choices=["Use same sampler"] + [x.name for x in samplers_for_img2img], value="Use same sampler", type="index")
|
||||
|
||||
with FormRow(elem_id="txt2img_hires_fix_row4", variant="compact", visible=opts.hires_fix_show_prompts) as hr_prompts_container:
|
||||
@@ -498,10 +479,10 @@ def create_ui():
|
||||
|
||||
elif category == "scripts":
|
||||
with FormGroup(elem_id="txt2img_script_container"):
|
||||
custom_inputs = modules.scripts.scripts_txt2img.setup_ui()
|
||||
custom_inputs = scripts.scripts_txt2img.setup_ui()
|
||||
|
||||
else:
|
||||
modules.scripts.scripts_txt2img.setup_ui_for_section(category)
|
||||
scripts.scripts_txt2img.setup_ui_for_section(category)
|
||||
|
||||
hr_resolution_preview_inputs = [enable_hr, width, height, hr_scale, hr_resize_x, hr_resize_y]
|
||||
|
||||
@@ -532,9 +513,9 @@ def create_ui():
|
||||
_js="submit",
|
||||
inputs=[
|
||||
dummy_component,
|
||||
txt2img_prompt,
|
||||
txt2img_negative_prompt,
|
||||
txt2img_prompt_styles,
|
||||
toprow.prompt,
|
||||
toprow.negative_prompt,
|
||||
toprow.ui_styles.dropdown,
|
||||
steps,
|
||||
sampler_index,
|
||||
restore_faces,
|
||||
@@ -553,6 +534,7 @@ def create_ui():
|
||||
hr_second_pass_steps,
|
||||
hr_resize_x,
|
||||
hr_resize_y,
|
||||
hr_checkpoint_name,
|
||||
hr_sampler_index,
|
||||
hr_prompt,
|
||||
hr_negative_prompt,
|
||||
@@ -569,12 +551,12 @@ def create_ui():
|
||||
show_progress=False,
|
||||
)
|
||||
|
||||
txt2img_prompt.submit(**txt2img_args)
|
||||
submit.click(**txt2img_args)
|
||||
toprow.prompt.submit(**txt2img_args)
|
||||
toprow.submit.click(**txt2img_args)
|
||||
|
||||
res_switch_btn.click(fn=None, _js="function(){switchWidthHeight('txt2img')}", inputs=None, outputs=None, show_progress=False)
|
||||
|
||||
restore_progress_button.click(
|
||||
toprow.restore_progress_button.click(
|
||||
fn=progress.restore_progress,
|
||||
_js="restoreProgressTxt2img",
|
||||
inputs=[dummy_component],
|
||||
@@ -587,18 +569,6 @@ def create_ui():
|
||||
show_progress=False,
|
||||
)
|
||||
|
||||
txt_prompt_img.change(
|
||||
fn=modules.images.image_data,
|
||||
inputs=[
|
||||
txt_prompt_img
|
||||
],
|
||||
outputs=[
|
||||
txt2img_prompt,
|
||||
txt_prompt_img
|
||||
],
|
||||
show_progress=False,
|
||||
)
|
||||
|
||||
enable_hr.change(
|
||||
fn=lambda x: gr_show(x),
|
||||
inputs=[enable_hr],
|
||||
@@ -607,8 +577,8 @@ def create_ui():
|
||||
)
|
||||
|
||||
txt2img_paste_fields = [
|
||||
(txt2img_prompt, "Prompt"),
|
||||
(txt2img_negative_prompt, "Negative prompt"),
|
||||
(toprow.prompt, "Prompt"),
|
||||
(toprow.negative_prompt, "Negative prompt"),
|
||||
(steps, "Steps"),
|
||||
(sampler_index, "Sampler"),
|
||||
(restore_faces, "Face restoration"),
|
||||
@@ -617,34 +587,36 @@ def create_ui():
|
||||
(width, "Size-1"),
|
||||
(height, "Size-2"),
|
||||
(batch_size, "Batch size"),
|
||||
(seed_checkbox, lambda d: "Variation seed" in d or "Seed resize from-1" in d),
|
||||
(subseed, "Variation seed"),
|
||||
(subseed_strength, "Variation seed strength"),
|
||||
(seed_resize_from_w, "Seed resize from-1"),
|
||||
(seed_resize_from_h, "Seed resize from-2"),
|
||||
(txt2img_prompt_styles, lambda d: d["Styles array"] if isinstance(d.get("Styles array"), list) else gr.update()),
|
||||
(toprow.ui_styles.dropdown, lambda d: d["Styles array"] if isinstance(d.get("Styles array"), list) else gr.update()),
|
||||
(denoising_strength, "Denoising strength"),
|
||||
(enable_hr, lambda d: "Denoising strength" in d),
|
||||
(hr_options, lambda d: gr.Row.update(visible="Denoising strength" in d)),
|
||||
(enable_hr, lambda d: "Denoising strength" in d and ("Hires upscale" in d or "Hires upscaler" in d or "Hires resize-1" in d)),
|
||||
(hr_options, lambda d: gr.Row.update(visible="Denoising strength" in d and ("Hires upscale" in d or "Hires upscaler" in d or "Hires resize-1" in d))),
|
||||
(hr_scale, "Hires upscale"),
|
||||
(hr_upscaler, "Hires upscaler"),
|
||||
(hr_second_pass_steps, "Hires steps"),
|
||||
(hr_resize_x, "Hires resize-1"),
|
||||
(hr_resize_y, "Hires resize-2"),
|
||||
(hr_checkpoint_name, "Hires checkpoint"),
|
||||
(hr_sampler_index, "Hires sampler"),
|
||||
(hr_sampler_container, lambda d: gr.update(visible=True) if d.get("Hires sampler", "Use same sampler") != "Use same sampler" else gr.update()),
|
||||
(hr_sampler_container, lambda d: gr.update(visible=True) if d.get("Hires sampler", "Use same sampler") != "Use same sampler" or d.get("Hires checkpoint", "Use same checkpoint") != "Use same checkpoint" else gr.update()),
|
||||
(hr_prompt, "Hires prompt"),
|
||||
(hr_negative_prompt, "Hires negative prompt"),
|
||||
(hr_prompts_container, lambda d: gr.update(visible=True) if d.get("Hires prompt", "") != "" or d.get("Hires negative prompt", "") != "" else gr.update()),
|
||||
*modules.scripts.scripts_txt2img.infotext_fields
|
||||
*scripts.scripts_txt2img.infotext_fields
|
||||
]
|
||||
parameters_copypaste.add_paste_fields("txt2img", None, txt2img_paste_fields, override_settings)
|
||||
parameters_copypaste.register_paste_params_button(parameters_copypaste.ParamBinding(
|
||||
paste_button=txt2img_paste, tabname="txt2img", source_text_component=txt2img_prompt, source_image_component=None,
|
||||
paste_button=toprow.paste, tabname="txt2img", source_text_component=toprow.prompt, source_image_component=None,
|
||||
))
|
||||
|
||||
txt2img_preview_params = [
|
||||
txt2img_prompt,
|
||||
txt2img_negative_prompt,
|
||||
toprow.prompt,
|
||||
toprow.negative_prompt,
|
||||
steps,
|
||||
sampler_index,
|
||||
cfg_scale,
|
||||
@@ -653,24 +625,25 @@ def create_ui():
|
||||
height,
|
||||
]
|
||||
|
||||
token_button.click(fn=wrap_queued_call(update_token_counter), inputs=[txt2img_prompt, steps], outputs=[token_counter])
|
||||
negative_token_button.click(fn=wrap_queued_call(update_token_counter), inputs=[txt2img_negative_prompt, steps], outputs=[negative_token_counter])
|
||||
toprow.token_button.click(fn=wrap_queued_call(update_token_counter), inputs=[toprow.prompt, steps], outputs=[toprow.token_counter])
|
||||
toprow.negative_token_button.click(fn=wrap_queued_call(update_token_counter), inputs=[toprow.negative_prompt, steps], outputs=[toprow.negative_token_counter])
|
||||
|
||||
ui_extra_networks.setup_ui(extra_networks_ui, txt2img_gallery)
|
||||
from modules import ui_extra_networks
|
||||
extra_networks_ui = ui_extra_networks.create_ui(txt2img_interface, [txt2img_generation_tab], 'txt2img')
|
||||
ui_extra_networks.setup_ui(extra_networks_ui, txt2img_gallery)
|
||||
|
||||
modules.scripts.scripts_current = modules.scripts.scripts_img2img
|
||||
modules.scripts.scripts_img2img.initialize_scripts(is_img2img=True)
|
||||
extra_tabs.__exit__()
|
||||
|
||||
scripts.scripts_current = scripts.scripts_img2img
|
||||
scripts.scripts_img2img.initialize_scripts(is_img2img=True)
|
||||
|
||||
with gr.Blocks(analytics_enabled=False) as img2img_interface:
|
||||
img2img_prompt, img2img_prompt_styles, img2img_negative_prompt, submit, img2img_interrogate, img2img_deepbooru, img2img_prompt_style_apply, img2img_save_style, img2img_paste, extra_networks_button, token_counter, token_button, negative_token_counter, negative_token_button, restore_progress_button = create_toprow(is_img2img=True)
|
||||
toprow = Toprow(is_img2img=True)
|
||||
|
||||
img2img_prompt_img = gr.File(label="", elem_id="img2img_prompt_image", file_count="single", type="binary", visible=False)
|
||||
extra_tabs = gr.Tabs(elem_id="img2img_extra_tabs")
|
||||
extra_tabs.__enter__()
|
||||
|
||||
with FormRow(variant='compact', elem_id="img2img_extra_networks", visible=False) as extra_networks:
|
||||
from modules import ui_extra_networks
|
||||
extra_networks_ui_img2img = ui_extra_networks.create_ui(extra_networks, extra_networks_button, 'img2img')
|
||||
|
||||
with FormRow().style(equal_height=False):
|
||||
with gr.Tab("Generation", id="img2img_generation") as img2img_generation_tab, FormRow().style(equal_height=False):
|
||||
with gr.Column(variant='compact', elem_id="img2img_settings"):
|
||||
copy_image_buttons = []
|
||||
copy_image_destinations = {}
|
||||
@@ -692,19 +665,19 @@ def create_ui():
|
||||
img2img_selected_tab = gr.State(0)
|
||||
|
||||
with gr.TabItem('img2img', id='img2img', elem_id="img2img_img2img_tab") as tab_img2img:
|
||||
init_img = gr.Image(label="Image for img2img", elem_id="img2img_image", show_label=False, source="upload", interactive=True, type="pil", tool="editor", image_mode="RGBA").style(height=opts.img2img_editor_height)
|
||||
init_img = gr.Image(label="Image for img2img", elem_id="img2img_image", show_label=False, source="upload", interactive=True, type="pil", tool="editor", image_mode="RGBA", height=opts.img2img_editor_height)
|
||||
add_copy_image_controls('img2img', init_img)
|
||||
|
||||
with gr.TabItem('Sketch', id='img2img_sketch', elem_id="img2img_img2img_sketch_tab") as tab_sketch:
|
||||
sketch = gr.Image(label="Image for img2img", elem_id="img2img_sketch", show_label=False, source="upload", interactive=True, type="pil", tool="color-sketch", image_mode="RGBA").style(height=opts.img2img_editor_height)
|
||||
sketch = gr.Image(label="Image for img2img", elem_id="img2img_sketch", show_label=False, source="upload", interactive=True, type="pil", tool="color-sketch", image_mode="RGBA", height=opts.img2img_editor_height, brush_color=opts.img2img_sketch_default_brush_color)
|
||||
add_copy_image_controls('sketch', sketch)
|
||||
|
||||
with gr.TabItem('Inpaint', id='inpaint', elem_id="img2img_inpaint_tab") as tab_inpaint:
|
||||
init_img_with_mask = gr.Image(label="Image for inpainting with mask", show_label=False, elem_id="img2maskimg", source="upload", interactive=True, type="pil", tool="sketch", image_mode="RGBA").style(height=opts.img2img_editor_height)
|
||||
init_img_with_mask = gr.Image(label="Image for inpainting with mask", show_label=False, elem_id="img2maskimg", source="upload", interactive=True, type="pil", tool="sketch", image_mode="RGBA", height=opts.img2img_editor_height, brush_color=opts.img2img_inpaint_mask_brush_color)
|
||||
add_copy_image_controls('inpaint', init_img_with_mask)
|
||||
|
||||
with gr.TabItem('Inpaint sketch', id='inpaint_sketch', elem_id="img2img_inpaint_sketch_tab") as tab_inpaint_color:
|
||||
inpaint_color_sketch = gr.Image(label="Color sketch inpainting", show_label=False, elem_id="inpaint_sketch", source="upload", interactive=True, type="pil", tool="color-sketch", image_mode="RGBA").style(height=opts.img2img_editor_height)
|
||||
inpaint_color_sketch = gr.Image(label="Color sketch inpainting", show_label=False, elem_id="inpaint_sketch", source="upload", interactive=True, type="pil", tool="color-sketch", image_mode="RGBA", height=opts.img2img_editor_height, brush_color=opts.img2img_inpaint_sketch_default_brush_color)
|
||||
inpaint_color_sketch_orig = gr.State(None)
|
||||
add_copy_image_controls('inpaint_sketch', inpaint_color_sketch)
|
||||
|
||||
@@ -764,7 +737,7 @@ def create_ui():
|
||||
with FormRow():
|
||||
resize_mode = gr.Radio(label="Resize mode", elem_id="resize_mode", choices=["Just resize", "Crop and resize", "Resize and fill", "Just resize (latent upscale)"], type="index", value="Just resize")
|
||||
|
||||
modules.scripts.scripts_img2img.prepare_ui()
|
||||
scripts.scripts_img2img.prepare_ui()
|
||||
|
||||
for category in ordered_ui_categories():
|
||||
if category == "sampler":
|
||||
@@ -845,7 +818,7 @@ def create_ui():
|
||||
|
||||
elif category == "scripts":
|
||||
with FormGroup(elem_id="img2img_script_container"):
|
||||
custom_inputs = modules.scripts.scripts_img2img.setup_ui()
|
||||
custom_inputs = scripts.scripts_img2img.setup_ui()
|
||||
|
||||
elif category == "inpaint":
|
||||
with FormGroup(elem_id="inpaint_controls", visible=False) as inpaint_controls:
|
||||
@@ -876,34 +849,22 @@ def create_ui():
|
||||
outputs=[inpaint_controls, mask_alpha],
|
||||
)
|
||||
else:
|
||||
modules.scripts.scripts_img2img.setup_ui_for_section(category)
|
||||
scripts.scripts_img2img.setup_ui_for_section(category)
|
||||
|
||||
img2img_gallery, generation_info, html_info, html_log = create_output_panel("img2img", opts.outdir_img2img_samples)
|
||||
|
||||
connect_reuse_seed(seed, reuse_seed, generation_info, dummy_component, is_subseed=False)
|
||||
connect_reuse_seed(subseed, reuse_subseed, generation_info, dummy_component, is_subseed=True)
|
||||
|
||||
img2img_prompt_img.change(
|
||||
fn=modules.images.image_data,
|
||||
inputs=[
|
||||
img2img_prompt_img
|
||||
],
|
||||
outputs=[
|
||||
img2img_prompt,
|
||||
img2img_prompt_img
|
||||
],
|
||||
show_progress=False,
|
||||
)
|
||||
|
||||
img2img_args = dict(
|
||||
fn=wrap_gradio_gpu_call(modules.img2img.img2img, extra_outputs=[None, '', '']),
|
||||
_js="submit_img2img",
|
||||
inputs=[
|
||||
dummy_component,
|
||||
dummy_component,
|
||||
img2img_prompt,
|
||||
img2img_negative_prompt,
|
||||
img2img_prompt_styles,
|
||||
toprow.prompt,
|
||||
toprow.negative_prompt,
|
||||
toprow.ui_styles.dropdown,
|
||||
init_img,
|
||||
sketch,
|
||||
init_img_with_mask,
|
||||
@@ -962,11 +923,11 @@ def create_ui():
|
||||
inpaint_color_sketch,
|
||||
init_img_inpaint,
|
||||
],
|
||||
outputs=[img2img_prompt, dummy_component],
|
||||
outputs=[toprow.prompt, dummy_component],
|
||||
)
|
||||
|
||||
img2img_prompt.submit(**img2img_args)
|
||||
submit.click(**img2img_args)
|
||||
toprow.prompt.submit(**img2img_args)
|
||||
toprow.submit.click(**img2img_args)
|
||||
|
||||
res_switch_btn.click(fn=None, _js="function(){switchWidthHeight('img2img')}", inputs=None, outputs=None, show_progress=False)
|
||||
|
||||
@@ -978,7 +939,7 @@ def create_ui():
|
||||
show_progress=False,
|
||||
)
|
||||
|
||||
restore_progress_button.click(
|
||||
toprow.restore_progress_button.click(
|
||||
fn=progress.restore_progress,
|
||||
_js="restoreProgressImg2img",
|
||||
inputs=[dummy_component],
|
||||
@@ -991,46 +952,22 @@ def create_ui():
|
||||
show_progress=False,
|
||||
)
|
||||
|
||||
img2img_interrogate.click(
|
||||
toprow.button_interrogate.click(
|
||||
fn=lambda *args: process_interrogate(interrogate, *args),
|
||||
**interrogate_args,
|
||||
)
|
||||
|
||||
img2img_deepbooru.click(
|
||||
toprow.button_deepbooru.click(
|
||||
fn=lambda *args: process_interrogate(interrogate_deepbooru, *args),
|
||||
**interrogate_args,
|
||||
)
|
||||
|
||||
prompts = [(txt2img_prompt, txt2img_negative_prompt), (img2img_prompt, img2img_negative_prompt)]
|
||||
style_dropdowns = [txt2img_prompt_styles, img2img_prompt_styles]
|
||||
style_js_funcs = ["update_txt2img_tokens", "update_img2img_tokens"]
|
||||
|
||||
for button, (prompt, negative_prompt) in zip([txt2img_save_style, img2img_save_style], prompts):
|
||||
button.click(
|
||||
fn=add_style,
|
||||
_js="ask_for_style_name",
|
||||
# Have to pass empty dummy component here, because the JavaScript and Python function have to accept
|
||||
# the same number of parameters, but we only know the style-name after the JavaScript prompt
|
||||
inputs=[dummy_component, prompt, negative_prompt],
|
||||
outputs=[txt2img_prompt_styles, img2img_prompt_styles],
|
||||
)
|
||||
|
||||
for button, (prompt, negative_prompt), styles, js_func in zip([txt2img_prompt_style_apply, img2img_prompt_style_apply], prompts, style_dropdowns, style_js_funcs):
|
||||
button.click(
|
||||
fn=apply_styles,
|
||||
_js=js_func,
|
||||
inputs=[prompt, negative_prompt, styles],
|
||||
outputs=[prompt, negative_prompt, styles],
|
||||
)
|
||||
|
||||
token_button.click(fn=update_token_counter, inputs=[img2img_prompt, steps], outputs=[token_counter])
|
||||
negative_token_button.click(fn=wrap_queued_call(update_token_counter), inputs=[img2img_negative_prompt, steps], outputs=[negative_token_counter])
|
||||
|
||||
ui_extra_networks.setup_ui(extra_networks_ui_img2img, img2img_gallery)
|
||||
toprow.token_button.click(fn=update_token_counter, inputs=[toprow.prompt, steps], outputs=[toprow.token_counter])
|
||||
toprow.negative_token_button.click(fn=wrap_queued_call(update_token_counter), inputs=[toprow.negative_prompt, steps], outputs=[toprow.negative_token_counter])
|
||||
|
||||
img2img_paste_fields = [
|
||||
(img2img_prompt, "Prompt"),
|
||||
(img2img_negative_prompt, "Negative prompt"),
|
||||
(toprow.prompt, "Prompt"),
|
||||
(toprow.negative_prompt, "Negative prompt"),
|
||||
(steps, "Steps"),
|
||||
(sampler_index, "Sampler"),
|
||||
(restore_faces, "Face restoration"),
|
||||
@@ -1040,28 +977,35 @@ def create_ui():
|
||||
(width, "Size-1"),
|
||||
(height, "Size-2"),
|
||||
(batch_size, "Batch size"),
|
||||
(seed_checkbox, lambda d: "Variation seed" in d or "Seed resize from-1" in d),
|
||||
(subseed, "Variation seed"),
|
||||
(subseed_strength, "Variation seed strength"),
|
||||
(seed_resize_from_w, "Seed resize from-1"),
|
||||
(seed_resize_from_h, "Seed resize from-2"),
|
||||
(img2img_prompt_styles, lambda d: d["Styles array"] if isinstance(d.get("Styles array"), list) else gr.update()),
|
||||
(toprow.ui_styles.dropdown, lambda d: d["Styles array"] if isinstance(d.get("Styles array"), list) else gr.update()),
|
||||
(denoising_strength, "Denoising strength"),
|
||||
(mask_blur, "Mask blur"),
|
||||
*modules.scripts.scripts_img2img.infotext_fields
|
||||
*scripts.scripts_img2img.infotext_fields
|
||||
]
|
||||
parameters_copypaste.add_paste_fields("img2img", init_img, img2img_paste_fields, override_settings)
|
||||
parameters_copypaste.add_paste_fields("inpaint", init_img_with_mask, img2img_paste_fields, override_settings)
|
||||
parameters_copypaste.register_paste_params_button(parameters_copypaste.ParamBinding(
|
||||
paste_button=img2img_paste, tabname="img2img", source_text_component=img2img_prompt, source_image_component=None,
|
||||
paste_button=toprow.paste, tabname="img2img", source_text_component=toprow.prompt, source_image_component=None,
|
||||
))
|
||||
|
||||
modules.scripts.scripts_current = None
|
||||
from modules import ui_extra_networks
|
||||
extra_networks_ui_img2img = ui_extra_networks.create_ui(img2img_interface, [img2img_generation_tab], 'img2img')
|
||||
ui_extra_networks.setup_ui(extra_networks_ui_img2img, img2img_gallery)
|
||||
|
||||
extra_tabs.__exit__()
|
||||
|
||||
scripts.scripts_current = None
|
||||
|
||||
with gr.Blocks(analytics_enabled=False) as extras_interface:
|
||||
ui_postprocessing.create_ui()
|
||||
|
||||
with gr.Blocks(analytics_enabled=False) as pnginfo_interface:
|
||||
with gr.Row().style(equal_height=False):
|
||||
with gr.Row(equal_height=False):
|
||||
with gr.Column(variant='panel'):
|
||||
image = gr.Image(elem_id="pnginfo_image", label="Source", source="upload", interactive=True, type="pil")
|
||||
|
||||
@@ -1083,64 +1027,13 @@ def create_ui():
|
||||
outputs=[html, generation_info, html2],
|
||||
)
|
||||
|
||||
def update_interp_description(value):
|
||||
interp_description_css = "<p style='margin-bottom: 2.5em'>{}</p>"
|
||||
interp_descriptions = {
|
||||
"No interpolation": interp_description_css.format("No interpolation will be used. Requires one model; A. Allows for format conversion and VAE baking."),
|
||||
"Weighted sum": interp_description_css.format("A weighted sum will be used for interpolation. Requires two models; A and B. The result is calculated as A * (1 - M) + B * M"),
|
||||
"Add difference": interp_description_css.format("The difference between the last two models will be added to the first. Requires three models; A, B and C. The result is calculated as A + (B - C) * M")
|
||||
}
|
||||
return interp_descriptions[value]
|
||||
|
||||
with gr.Blocks(analytics_enabled=False) as modelmerger_interface:
|
||||
with gr.Row().style(equal_height=False):
|
||||
with gr.Column(variant='compact'):
|
||||
interp_description = gr.HTML(value=update_interp_description("Weighted sum"), elem_id="modelmerger_interp_description")
|
||||
|
||||
with FormRow(elem_id="modelmerger_models"):
|
||||
primary_model_name = gr.Dropdown(modules.sd_models.checkpoint_tiles(), elem_id="modelmerger_primary_model_name", label="Primary model (A)")
|
||||
create_refresh_button(primary_model_name, modules.sd_models.list_models, lambda: {"choices": modules.sd_models.checkpoint_tiles()}, "refresh_checkpoint_A")
|
||||
|
||||
secondary_model_name = gr.Dropdown(modules.sd_models.checkpoint_tiles(), elem_id="modelmerger_secondary_model_name", label="Secondary model (B)")
|
||||
create_refresh_button(secondary_model_name, modules.sd_models.list_models, lambda: {"choices": modules.sd_models.checkpoint_tiles()}, "refresh_checkpoint_B")
|
||||
|
||||
tertiary_model_name = gr.Dropdown(modules.sd_models.checkpoint_tiles(), elem_id="modelmerger_tertiary_model_name", label="Tertiary model (C)")
|
||||
create_refresh_button(tertiary_model_name, modules.sd_models.list_models, lambda: {"choices": modules.sd_models.checkpoint_tiles()}, "refresh_checkpoint_C")
|
||||
|
||||
custom_name = gr.Textbox(label="Custom Name (Optional)", elem_id="modelmerger_custom_name")
|
||||
interp_amount = gr.Slider(minimum=0.0, maximum=1.0, step=0.05, label='Multiplier (M) - set to 0 to get model A', value=0.3, elem_id="modelmerger_interp_amount")
|
||||
interp_method = gr.Radio(choices=["No interpolation", "Weighted sum", "Add difference"], value="Weighted sum", label="Interpolation Method", elem_id="modelmerger_interp_method")
|
||||
interp_method.change(fn=update_interp_description, inputs=[interp_method], outputs=[interp_description])
|
||||
|
||||
with FormRow():
|
||||
checkpoint_format = gr.Radio(choices=["ckpt", "safetensors"], value="safetensors", label="Checkpoint format", elem_id="modelmerger_checkpoint_format")
|
||||
save_as_half = gr.Checkbox(value=False, label="Save as float16", elem_id="modelmerger_save_as_half")
|
||||
save_metadata = gr.Checkbox(value=True, label="Save metadata (.safetensors only)", elem_id="modelmerger_save_metadata")
|
||||
|
||||
with FormRow():
|
||||
with gr.Column():
|
||||
config_source = gr.Radio(choices=["A, B or C", "B", "C", "Don't"], value="A, B or C", label="Copy config from", type="index", elem_id="modelmerger_config_method")
|
||||
|
||||
with gr.Column():
|
||||
with FormRow():
|
||||
bake_in_vae = gr.Dropdown(choices=["None"] + list(sd_vae.vae_dict), value="None", label="Bake in VAE", elem_id="modelmerger_bake_in_vae")
|
||||
create_refresh_button(bake_in_vae, sd_vae.refresh_vae_list, lambda: {"choices": ["None"] + list(sd_vae.vae_dict)}, "modelmerger_refresh_bake_in_vae")
|
||||
|
||||
with FormRow():
|
||||
discard_weights = gr.Textbox(value="", label="Discard weights with matching name", elem_id="modelmerger_discard_weights")
|
||||
|
||||
with gr.Row():
|
||||
modelmerger_merge = gr.Button(elem_id="modelmerger_merge", value="Merge", variant='primary')
|
||||
|
||||
with gr.Column(variant='compact', elem_id="modelmerger_results_container"):
|
||||
with gr.Group(elem_id="modelmerger_results_panel"):
|
||||
modelmerger_result = gr.HTML(elem_id="modelmerger_result", show_label=False)
|
||||
modelmerger_ui = ui_checkpoint_merger.UiCheckpointMerger()
|
||||
|
||||
with gr.Blocks(analytics_enabled=False) as train_interface:
|
||||
with gr.Row().style(equal_height=False):
|
||||
with gr.Row(equal_height=False):
|
||||
gr.HTML(value="<p style='margin-bottom: 0.7em'>See <b><a href=\"https://github.com/AUTOMATIC1111/stable-diffusion-webui/wiki/Textual-Inversion\">wiki</a></b> for detailed explanation.</p>")
|
||||
|
||||
with gr.Row(variant="compact").style(equal_height=False):
|
||||
with gr.Row(variant="compact", equal_height=False):
|
||||
with gr.Tabs(elem_id="train_tabs"):
|
||||
|
||||
with gr.Tab(label="Create embedding", id="create_embedding"):
|
||||
@@ -1160,7 +1053,7 @@ def create_ui():
|
||||
new_hypernetwork_name = gr.Textbox(label="Name", elem_id="train_new_hypernetwork_name")
|
||||
new_hypernetwork_sizes = gr.CheckboxGroup(label="Modules", value=["768", "320", "640", "1280"], choices=["768", "1024", "320", "640", "1280"], elem_id="train_new_hypernetwork_sizes")
|
||||
new_hypernetwork_layer_structure = gr.Textbox("1, 2, 1", label="Enter hypernetwork layer structure", placeholder="1st and last digit must be 1. ex:'1, 2, 1'", elem_id="train_new_hypernetwork_layer_structure")
|
||||
new_hypernetwork_activation_func = gr.Dropdown(value="linear", label="Select activation function of hypernetwork. Recommended : Swish / Linear(none)", choices=modules.hypernetworks.ui.keys, elem_id="train_new_hypernetwork_activation_func")
|
||||
new_hypernetwork_activation_func = gr.Dropdown(value="linear", label="Select activation function of hypernetwork. Recommended : Swish / Linear(none)", choices=hypernetworks_ui.keys, elem_id="train_new_hypernetwork_activation_func")
|
||||
new_hypernetwork_initialization_option = gr.Dropdown(value = "Normal", label="Select Layer weights initialization. Recommended: Kaiming for relu-like, Xavier for sigmoid-like, Normal otherwise", choices=["Normal", "KaimingUniform", "KaimingNormal", "XavierUniform", "XavierNormal"], elem_id="train_new_hypernetwork_initialization_option")
|
||||
new_hypernetwork_add_layer_norm = gr.Checkbox(label="Add layer normalization", elem_id="train_new_hypernetwork_add_layer_norm")
|
||||
new_hypernetwork_use_dropout = gr.Checkbox(label="Use dropout", elem_id="train_new_hypernetwork_use_dropout")
|
||||
@@ -1300,12 +1193,12 @@ def create_ui():
|
||||
|
||||
with gr.Column(elem_id='ti_gallery_container'):
|
||||
ti_output = gr.Text(elem_id="ti_output", value="", show_label=False)
|
||||
gr.Gallery(label='Output', show_label=False, elem_id='ti_gallery').style(columns=4)
|
||||
gr.Gallery(label='Output', show_label=False, elem_id='ti_gallery', columns=4)
|
||||
gr.HTML(elem_id="ti_progress", value="")
|
||||
ti_outcome = gr.HTML(elem_id="ti_error", value="")
|
||||
|
||||
create_embedding.click(
|
||||
fn=modules.textual_inversion.ui.create_embedding,
|
||||
fn=textual_inversion_ui.create_embedding,
|
||||
inputs=[
|
||||
new_embedding_name,
|
||||
initialization_text,
|
||||
@@ -1320,7 +1213,7 @@ def create_ui():
|
||||
)
|
||||
|
||||
create_hypernetwork.click(
|
||||
fn=modules.hypernetworks.ui.create_hypernetwork,
|
||||
fn=hypernetworks_ui.create_hypernetwork,
|
||||
inputs=[
|
||||
new_hypernetwork_name,
|
||||
new_hypernetwork_sizes,
|
||||
@@ -1340,7 +1233,7 @@ def create_ui():
|
||||
)
|
||||
|
||||
run_preprocess.click(
|
||||
fn=wrap_gradio_gpu_call(modules.textual_inversion.ui.preprocess, extra_outputs=[gr.update()]),
|
||||
fn=wrap_gradio_gpu_call(textual_inversion_ui.preprocess, extra_outputs=[gr.update()]),
|
||||
_js="start_training_textual_inversion",
|
||||
inputs=[
|
||||
dummy_component,
|
||||
@@ -1376,7 +1269,7 @@ def create_ui():
|
||||
)
|
||||
|
||||
train_embedding.click(
|
||||
fn=wrap_gradio_gpu_call(modules.textual_inversion.ui.train_embedding, extra_outputs=[gr.update()]),
|
||||
fn=wrap_gradio_gpu_call(textual_inversion_ui.train_embedding, extra_outputs=[gr.update()]),
|
||||
_js="start_training_textual_inversion",
|
||||
inputs=[
|
||||
dummy_component,
|
||||
@@ -1410,7 +1303,7 @@ def create_ui():
|
||||
)
|
||||
|
||||
train_hypernetwork.click(
|
||||
fn=wrap_gradio_gpu_call(modules.hypernetworks.ui.train_hypernetwork, extra_outputs=[gr.update()]),
|
||||
fn=wrap_gradio_gpu_call(hypernetworks_ui.train_hypernetwork, extra_outputs=[gr.update()]),
|
||||
_js="start_training_textual_inversion",
|
||||
inputs=[
|
||||
dummy_component,
|
||||
@@ -1464,7 +1357,7 @@ def create_ui():
|
||||
(img2img_interface, "img2img", "img2img"),
|
||||
(extras_interface, "Extras", "extras"),
|
||||
(pnginfo_interface, "PNG Info", "pnginfo"),
|
||||
(modelmerger_interface, "Checkpoint Merger", "modelmerger"),
|
||||
(modelmerger_ui.blocks, "Checkpoint Merger", "modelmerger"),
|
||||
(train_interface, "Train", "train"),
|
||||
]
|
||||
|
||||
@@ -1516,49 +1409,11 @@ def create_ui():
|
||||
settings.text_settings.change(fn=update_image_cfg_scale_visibility, inputs=[], outputs=[image_cfg_scale])
|
||||
demo.load(fn=update_image_cfg_scale_visibility, inputs=[], outputs=[image_cfg_scale])
|
||||
|
||||
def modelmerger(*args):
|
||||
try:
|
||||
results = modules.extras.run_modelmerger(*args)
|
||||
except Exception as e:
|
||||
errors.report("Error loading/saving model file", exc_info=True)
|
||||
modules.sd_models.list_models() # to remove the potentially missing models from the list
|
||||
return [*[gr.Dropdown.update(choices=modules.sd_models.checkpoint_tiles()) for _ in range(4)], f"Error merging checkpoints: {e}"]
|
||||
return results
|
||||
|
||||
modelmerger_merge.click(fn=lambda: '', inputs=[], outputs=[modelmerger_result])
|
||||
modelmerger_merge.click(
|
||||
fn=wrap_gradio_gpu_call(modelmerger, extra_outputs=lambda: [gr.update() for _ in range(4)]),
|
||||
_js='modelmerger',
|
||||
inputs=[
|
||||
dummy_component,
|
||||
primary_model_name,
|
||||
secondary_model_name,
|
||||
tertiary_model_name,
|
||||
interp_method,
|
||||
interp_amount,
|
||||
save_as_half,
|
||||
custom_name,
|
||||
checkpoint_format,
|
||||
config_source,
|
||||
bake_in_vae,
|
||||
discard_weights,
|
||||
save_metadata,
|
||||
],
|
||||
outputs=[
|
||||
primary_model_name,
|
||||
secondary_model_name,
|
||||
tertiary_model_name,
|
||||
settings.component_dict['sd_model_checkpoint'],
|
||||
modelmerger_result,
|
||||
]
|
||||
)
|
||||
modelmerger_ui.setup_ui(dummy_component=dummy_component, sd_model_checkpoint_component=settings.component_dict['sd_model_checkpoint'])
|
||||
|
||||
loadsave.dump_defaults()
|
||||
demo.ui_loadsave = loadsave
|
||||
|
||||
# Required as a workaround for change() event not triggering when loading values from ui-config.json
|
||||
interp_description.value = update_interp_description(interp_method.value)
|
||||
|
||||
return demo
|
||||
|
||||
|
||||
|
||||
@@ -0,0 +1,124 @@
|
||||
|
||||
import gradio as gr
|
||||
|
||||
from modules import sd_models, sd_vae, errors, extras, call_queue
|
||||
from modules.ui_components import FormRow
|
||||
from modules.ui_common import create_refresh_button
|
||||
|
||||
|
||||
def update_interp_description(value):
|
||||
interp_description_css = "<p style='margin-bottom: 2.5em'>{}</p>"
|
||||
interp_descriptions = {
|
||||
"No interpolation": interp_description_css.format("No interpolation will be used. Requires one model; A. Allows for format conversion and VAE baking."),
|
||||
"Weighted sum": interp_description_css.format("A weighted sum will be used for interpolation. Requires two models; A and B. The result is calculated as A * (1 - M) + B * M"),
|
||||
"Add difference": interp_description_css.format("The difference between the last two models will be added to the first. Requires three models; A, B and C. The result is calculated as A + (B - C) * M")
|
||||
}
|
||||
return interp_descriptions[value]
|
||||
|
||||
|
||||
def modelmerger(*args):
|
||||
try:
|
||||
results = extras.run_modelmerger(*args)
|
||||
except Exception as e:
|
||||
errors.report("Error loading/saving model file", exc_info=True)
|
||||
sd_models.list_models() # to remove the potentially missing models from the list
|
||||
return [*[gr.Dropdown.update(choices=sd_models.checkpoint_tiles()) for _ in range(4)], f"Error merging checkpoints: {e}"]
|
||||
return results
|
||||
|
||||
|
||||
class UiCheckpointMerger:
|
||||
def __init__(self):
|
||||
with gr.Blocks(analytics_enabled=False) as modelmerger_interface:
|
||||
with gr.Row(equal_height=False):
|
||||
with gr.Column(variant='compact'):
|
||||
self.interp_description = gr.HTML(value=update_interp_description("Weighted sum"), elem_id="modelmerger_interp_description")
|
||||
|
||||
with FormRow(elem_id="modelmerger_models"):
|
||||
self.primary_model_name = gr.Dropdown(sd_models.checkpoint_tiles(), elem_id="modelmerger_primary_model_name", label="Primary model (A)")
|
||||
create_refresh_button(self.primary_model_name, sd_models.list_models, lambda: {"choices": sd_models.checkpoint_tiles()}, "refresh_checkpoint_A")
|
||||
|
||||
self.secondary_model_name = gr.Dropdown(sd_models.checkpoint_tiles(), elem_id="modelmerger_secondary_model_name", label="Secondary model (B)")
|
||||
create_refresh_button(self.secondary_model_name, sd_models.list_models, lambda: {"choices": sd_models.checkpoint_tiles()}, "refresh_checkpoint_B")
|
||||
|
||||
self.tertiary_model_name = gr.Dropdown(sd_models.checkpoint_tiles(), elem_id="modelmerger_tertiary_model_name", label="Tertiary model (C)")
|
||||
create_refresh_button(self.tertiary_model_name, sd_models.list_models, lambda: {"choices": sd_models.checkpoint_tiles()}, "refresh_checkpoint_C")
|
||||
|
||||
self.custom_name = gr.Textbox(label="Custom Name (Optional)", elem_id="modelmerger_custom_name")
|
||||
self.interp_amount = gr.Slider(minimum=0.0, maximum=1.0, step=0.05, label='Multiplier (M) - set to 0 to get model A', value=0.3, elem_id="modelmerger_interp_amount")
|
||||
self.interp_method = gr.Radio(choices=["No interpolation", "Weighted sum", "Add difference"], value="Weighted sum", label="Interpolation Method", elem_id="modelmerger_interp_method")
|
||||
self.interp_method.change(fn=update_interp_description, inputs=[self.interp_method], outputs=[self.interp_description])
|
||||
|
||||
with FormRow():
|
||||
self.checkpoint_format = gr.Radio(choices=["ckpt", "safetensors"], value="safetensors", label="Checkpoint format", elem_id="modelmerger_checkpoint_format")
|
||||
self.save_as_half = gr.Checkbox(value=False, label="Save as float16", elem_id="modelmerger_save_as_half")
|
||||
|
||||
with FormRow():
|
||||
with gr.Column():
|
||||
self.config_source = gr.Radio(choices=["A, B or C", "B", "C", "Don't"], value="A, B or C", label="Copy config from", type="index", elem_id="modelmerger_config_method")
|
||||
|
||||
with gr.Column():
|
||||
with FormRow():
|
||||
self.bake_in_vae = gr.Dropdown(choices=["None"] + list(sd_vae.vae_dict), value="None", label="Bake in VAE", elem_id="modelmerger_bake_in_vae")
|
||||
create_refresh_button(self.bake_in_vae, sd_vae.refresh_vae_list, lambda: {"choices": ["None"] + list(sd_vae.vae_dict)}, "modelmerger_refresh_bake_in_vae")
|
||||
|
||||
with FormRow():
|
||||
self.discard_weights = gr.Textbox(value="", label="Discard weights with matching name", elem_id="modelmerger_discard_weights")
|
||||
|
||||
with gr.Accordion("Metadata", open=False) as metadata_editor:
|
||||
with FormRow():
|
||||
self.save_metadata = gr.Checkbox(value=True, label="Save metadata", elem_id="modelmerger_save_metadata")
|
||||
self.add_merge_recipe = gr.Checkbox(value=True, label="Add merge recipe metadata", elem_id="modelmerger_add_recipe")
|
||||
self.copy_metadata_fields = gr.Checkbox(value=True, label="Copy metadata from merged models", elem_id="modelmerger_copy_metadata")
|
||||
|
||||
self.metadata_json = gr.TextArea('{}', label="Metadata in JSON format")
|
||||
self.read_metadata = gr.Button("Read metadata from selected checkpoints")
|
||||
|
||||
with FormRow():
|
||||
self.modelmerger_merge = gr.Button(elem_id="modelmerger_merge", value="Merge", variant='primary')
|
||||
|
||||
with gr.Column(variant='compact', elem_id="modelmerger_results_container"):
|
||||
with gr.Group(elem_id="modelmerger_results_panel"):
|
||||
self.modelmerger_result = gr.HTML(elem_id="modelmerger_result", show_label=False)
|
||||
|
||||
self.metadata_editor = metadata_editor
|
||||
self.blocks = modelmerger_interface
|
||||
|
||||
def setup_ui(self, dummy_component, sd_model_checkpoint_component):
|
||||
self.checkpoint_format.change(lambda fmt: gr.update(visible=fmt == 'safetensors'), inputs=[self.checkpoint_format], outputs=[self.metadata_editor], show_progress=False)
|
||||
|
||||
self.read_metadata.click(extras.read_metadata, inputs=[self.primary_model_name, self.secondary_model_name, self.tertiary_model_name], outputs=[self.metadata_json])
|
||||
|
||||
self.modelmerger_merge.click(fn=lambda: '', inputs=[], outputs=[self.modelmerger_result])
|
||||
self.modelmerger_merge.click(
|
||||
fn=call_queue.wrap_gradio_gpu_call(modelmerger, extra_outputs=lambda: [gr.update() for _ in range(4)]),
|
||||
_js='modelmerger',
|
||||
inputs=[
|
||||
dummy_component,
|
||||
self.primary_model_name,
|
||||
self.secondary_model_name,
|
||||
self.tertiary_model_name,
|
||||
self.interp_method,
|
||||
self.interp_amount,
|
||||
self.save_as_half,
|
||||
self.custom_name,
|
||||
self.checkpoint_format,
|
||||
self.config_source,
|
||||
self.bake_in_vae,
|
||||
self.discard_weights,
|
||||
self.save_metadata,
|
||||
self.add_merge_recipe,
|
||||
self.copy_metadata_fields,
|
||||
self.metadata_json,
|
||||
],
|
||||
outputs=[
|
||||
self.primary_model_name,
|
||||
self.secondary_model_name,
|
||||
self.tertiary_model_name,
|
||||
sd_model_checkpoint_component,
|
||||
self.modelmerger_result,
|
||||
]
|
||||
)
|
||||
|
||||
# Required as a workaround for change() event not triggering when loading values from ui-config.json
|
||||
self.interp_description.value = update_interp_description(self.interp_method.value)
|
||||
|
||||
+29
-5
@@ -134,7 +134,7 @@ Requested path was: {f}
|
||||
|
||||
with gr.Column(variant='panel', elem_id=f"{tabname}_results"):
|
||||
with gr.Group(elem_id=f"{tabname}_gallery_container"):
|
||||
result_gallery = gr.Gallery(label='Output', show_label=False, elem_id=f"{tabname}_gallery").style(columns=4)
|
||||
result_gallery = gr.Gallery(label='Output', show_label=False, elem_id=f"{tabname}_gallery", columns=4)
|
||||
|
||||
generation_info = None
|
||||
with gr.Column():
|
||||
@@ -223,20 +223,44 @@ Requested path was: {f}
|
||||
|
||||
|
||||
def create_refresh_button(refresh_component, refresh_method, refreshed_args, elem_id):
|
||||
refresh_components = refresh_component if isinstance(refresh_component, list) else [refresh_component]
|
||||
|
||||
label = None
|
||||
for comp in refresh_components:
|
||||
label = getattr(comp, 'label', None)
|
||||
if label is not None:
|
||||
break
|
||||
|
||||
def refresh():
|
||||
refresh_method()
|
||||
args = refreshed_args() if callable(refreshed_args) else refreshed_args
|
||||
|
||||
for k, v in args.items():
|
||||
setattr(refresh_component, k, v)
|
||||
for comp in refresh_components:
|
||||
setattr(comp, k, v)
|
||||
|
||||
return gr.update(**(args or {}))
|
||||
return [gr.update(**(args or {})) for _ in refresh_components] if len(refresh_components) > 1 else gr.update(**(args or {}))
|
||||
|
||||
refresh_button = ToolButton(value=refresh_symbol, elem_id=elem_id)
|
||||
refresh_button = ToolButton(value=refresh_symbol, elem_id=elem_id, tooltip=f"{label}: refresh" if label else "Refresh")
|
||||
refresh_button.click(
|
||||
fn=refresh,
|
||||
inputs=[],
|
||||
outputs=[refresh_component]
|
||||
outputs=refresh_components
|
||||
)
|
||||
return refresh_button
|
||||
|
||||
|
||||
def setup_dialog(button_show, dialog, *, button_close=None):
|
||||
"""Sets up the UI so that the dialog (gr.Box) is invisible, and is only shown when buttons_show is clicked, in a fullscreen modal window."""
|
||||
|
||||
dialog.visible = False
|
||||
|
||||
button_show.click(
|
||||
fn=lambda: gr.update(visible=True),
|
||||
inputs=[],
|
||||
outputs=[dialog],
|
||||
).then(fn=None, _js="function(){ popup(gradioApp().getElementById('" + dialog.elem_id + "')); }")
|
||||
|
||||
if button_close:
|
||||
button_close.click(fn=None, _js="closePopup")
|
||||
|
||||
|
||||
@@ -35,7 +35,7 @@ class FormColumn(FormComponent, gr.Column):
|
||||
|
||||
|
||||
class FormGroup(FormComponent, gr.Group):
|
||||
"""Same as gr.Row but fits inside gradio forms"""
|
||||
"""Same as gr.Group but fits inside gradio forms"""
|
||||
|
||||
def get_block_name(self):
|
||||
return "group"
|
||||
|
||||
+15
-11
@@ -164,7 +164,7 @@ def extension_table():
|
||||
ext_status = ext.status
|
||||
|
||||
style = ""
|
||||
if shared.opts.disable_all_extensions == "extra" and not ext.is_builtin or shared.opts.disable_all_extensions == "all":
|
||||
if shared.cmd_opts.disable_extra_extensions and not ext.is_builtin or shared.opts.disable_all_extensions == "extra" and not ext.is_builtin or shared.cmd_opts.disable_all_extensions or shared.opts.disable_all_extensions == "all":
|
||||
style = STYLE_PRIMARY
|
||||
|
||||
version_link = ext.version
|
||||
@@ -533,16 +533,20 @@ def create_ui():
|
||||
apply = gr.Button(value=apply_label, variant="primary")
|
||||
check = gr.Button(value="Check for updates")
|
||||
extensions_disable_all = gr.Radio(label="Disable all extensions", choices=["none", "extra", "all"], value=shared.opts.disable_all_extensions, elem_id="extensions_disable_all")
|
||||
extensions_disabled_list = gr.Text(elem_id="extensions_disabled_list", visible=False).style(container=False)
|
||||
extensions_update_list = gr.Text(elem_id="extensions_update_list", visible=False).style(container=False)
|
||||
extensions_disabled_list = gr.Text(elem_id="extensions_disabled_list", visible=False, container=False)
|
||||
extensions_update_list = gr.Text(elem_id="extensions_update_list", visible=False, container=False)
|
||||
|
||||
html = ""
|
||||
if shared.opts.disable_all_extensions != "none":
|
||||
html = """
|
||||
<span style="color: var(--primary-400);">
|
||||
"Disable all extensions" was set, change it to "none" to load all extensions again
|
||||
</span>
|
||||
"""
|
||||
|
||||
if shared.cmd_opts.disable_all_extensions or shared.cmd_opts.disable_extra_extensions or shared.opts.disable_all_extensions != "none":
|
||||
if shared.cmd_opts.disable_all_extensions:
|
||||
msg = '"--disable-all-extensions" was used, remove it to load all extensions again'
|
||||
elif shared.opts.disable_all_extensions != "none":
|
||||
msg = '"Disable all extensions" was set, change it to "none" to load all extensions again'
|
||||
elif shared.cmd_opts.disable_extra_extensions:
|
||||
msg = '"--disable-extra-extensions" was used, remove it to load all extensions again'
|
||||
html = f'<span style="color: var(--primary-400);">{msg}</span>'
|
||||
|
||||
info = gr.HTML(html)
|
||||
extensions_table = gr.HTML('Loading...')
|
||||
ui.load(fn=extension_table, inputs=[], outputs=[extensions_table])
|
||||
@@ -565,7 +569,7 @@ def create_ui():
|
||||
with gr.Row():
|
||||
refresh_available_extensions_button = gr.Button(value="Load from:", variant="primary")
|
||||
extensions_index_url = os.environ.get('WEBUI_EXTENSIONS_INDEX', "https://raw.githubusercontent.com/AUTOMATIC1111/stable-diffusion-webui-extensions/master/index.json")
|
||||
available_extensions_index = gr.Text(value=extensions_index_url, label="Extension index URL").style(container=False)
|
||||
available_extensions_index = gr.Text(value=extensions_index_url, label="Extension index URL", container=False)
|
||||
extension_to_install = gr.Text(elem_id="extension_to_install", visible=False)
|
||||
install_extension_button = gr.Button(elem_id="install_extension_button", visible=False)
|
||||
|
||||
@@ -574,7 +578,7 @@ def create_ui():
|
||||
sort_column = gr.Radio(value="newest first", label="Order", choices=["newest first", "oldest first", "a-z", "z-a", "internal order",'update time', 'create time', "stars"], type="index")
|
||||
|
||||
with gr.Row():
|
||||
search_extensions_text = gr.Text(label="Search").style(container=False)
|
||||
search_extensions_text = gr.Text(label="Search", container=False)
|
||||
|
||||
install_result = gr.HTML()
|
||||
available_extensions_table = gr.HTML()
|
||||
|
||||
@@ -2,7 +2,7 @@ import os.path
|
||||
import urllib.parse
|
||||
from pathlib import Path
|
||||
|
||||
from modules import shared, ui_extra_networks_user_metadata, errors
|
||||
from modules import shared, ui_extra_networks_user_metadata, errors, extra_networks
|
||||
from modules.images import read_info_from_image, save_image_with_geninfo
|
||||
from modules.ui import up_down_symbol
|
||||
import gradio as gr
|
||||
@@ -101,16 +101,7 @@ class ExtraNetworksPage:
|
||||
|
||||
def read_user_metadata(self, item):
|
||||
filename = item.get("filename", None)
|
||||
basename, ext = os.path.splitext(filename)
|
||||
metadata_filename = basename + '.json'
|
||||
|
||||
metadata = {}
|
||||
try:
|
||||
if os.path.isfile(metadata_filename):
|
||||
with open(metadata_filename, "r", encoding="utf8") as file:
|
||||
metadata = json.load(file)
|
||||
except Exception as e:
|
||||
errors.display(e, f"reading extra network user metadata from {metadata_filename}")
|
||||
metadata = extra_networks.get_user_metadata(filename)
|
||||
|
||||
desc = metadata.get("description", None)
|
||||
if desc is not None:
|
||||
@@ -164,7 +155,7 @@ class ExtraNetworksPage:
|
||||
subdirs = {"": 1, **subdirs}
|
||||
|
||||
subdirs_html = "".join([f"""
|
||||
<button class='lg secondary gradio-button custom-button{" search-all" if subdir=="" else ""}' onclick='extraNetworksSearchButton("{tabname}_extra_tabs", event)'>
|
||||
<button class='lg secondary gradio-button custom-button{" search-all" if subdir=="" else ""}' onclick='extraNetworksSearchButton("{tabname}_extra_search", event)'>
|
||||
{html.escape(subdir if subdir!="" else "all")}
|
||||
</button>
|
||||
""" for subdir in subdirs])
|
||||
@@ -356,7 +347,7 @@ def pages_in_preferred_order(pages):
|
||||
return sorted(pages, key=lambda x: tab_scores[x.name])
|
||||
|
||||
|
||||
def create_ui(container, button, tabname):
|
||||
def create_ui(interface: gr.Blocks, unrelated_tabs, tabname):
|
||||
ui = ExtraNetworksUi()
|
||||
ui.pages = []
|
||||
ui.pages_contents = []
|
||||
@@ -364,48 +355,42 @@ def create_ui(container, button, tabname):
|
||||
ui.stored_extra_pages = pages_in_preferred_order(extra_pages.copy())
|
||||
ui.tabname = tabname
|
||||
|
||||
with gr.Tabs(elem_id=tabname+"_extra_tabs"):
|
||||
for page in ui.stored_extra_pages:
|
||||
with gr.Tab(page.title, id=page.id_page):
|
||||
elem_id = f"{tabname}_{page.id_page}_cards_html"
|
||||
page_elem = gr.HTML('Loading...', elem_id=elem_id)
|
||||
ui.pages.append(page_elem)
|
||||
related_tabs = []
|
||||
|
||||
page_elem.change(fn=lambda: None, _js='function(){applyExtraNetworkFilter(' + quote_js(tabname) + '); return []}', inputs=[], outputs=[])
|
||||
for page in ui.stored_extra_pages:
|
||||
with gr.Tab(page.title, id=page.id_page) as tab:
|
||||
elem_id = f"{tabname}_{page.id_page}_cards_html"
|
||||
page_elem = gr.HTML('Loading...', elem_id=elem_id)
|
||||
ui.pages.append(page_elem)
|
||||
|
||||
editor = page.create_user_metadata_editor(ui, tabname)
|
||||
editor.create_ui()
|
||||
ui.user_metadata_editors.append(editor)
|
||||
page_elem.change(fn=lambda: None, _js='function(){applyExtraNetworkFilter(' + quote_js(tabname) + '); return []}', inputs=[], outputs=[])
|
||||
|
||||
gr.Textbox('', show_label=False, elem_id=tabname+"_extra_search", placeholder="Search...", visible=False)
|
||||
gr.Dropdown(choices=['Default Sort', 'Date Created', 'Date Modified', 'Name'], value='Default Sort', elem_id=tabname+"_extra_sort", multiselect=False, visible=False, show_label=False, interactive=True)
|
||||
ToolButton(up_down_symbol, elem_id=tabname+"_extra_sortorder")
|
||||
button_refresh = gr.Button('Refresh', elem_id=tabname+"_extra_refresh")
|
||||
editor = page.create_user_metadata_editor(ui, tabname)
|
||||
editor.create_ui()
|
||||
ui.user_metadata_editors.append(editor)
|
||||
|
||||
related_tabs.append(tab)
|
||||
|
||||
edit_search = gr.Textbox('', show_label=False, elem_id=tabname+"_extra_search", elem_classes="search", placeholder="Search...", visible=False, interactive=True)
|
||||
dropdown_sort = gr.Dropdown(choices=['Default Sort', 'Date Created', 'Date Modified', 'Name'], value='Default Sort', elem_id=tabname+"_extra_sort", elem_classes="sort", multiselect=False, visible=False, show_label=False, interactive=True, label=tabname+"_extra_sort_order")
|
||||
button_sortorder = ToolButton(up_down_symbol, elem_id=tabname+"_extra_sortorder", elem_classes="sortorder", visible=False)
|
||||
button_refresh = gr.Button('Refresh', elem_id=tabname+"_extra_refresh", visible=False)
|
||||
checkbox_show_dirs = gr.Checkbox(True, label='Show dirs', elem_id=tabname+"_extra_show_dirs", elem_classes="show-dirs", visible=False)
|
||||
|
||||
ui.button_save_preview = gr.Button('Save preview', elem_id=tabname+"_save_preview", visible=False)
|
||||
ui.preview_target_filename = gr.Textbox('Preview save filename', elem_id=tabname+"_preview_filename", visible=False)
|
||||
|
||||
def toggle_visibility(is_visible):
|
||||
is_visible = not is_visible
|
||||
for tab in unrelated_tabs:
|
||||
tab.select(fn=lambda: [gr.update(visible=False) for _ in range(5)], inputs=[], outputs=[edit_search, dropdown_sort, button_sortorder, button_refresh, checkbox_show_dirs], show_progress=False)
|
||||
|
||||
return is_visible, gr.update(visible=is_visible), gr.update(variant=("secondary-down" if is_visible else "secondary"))
|
||||
|
||||
def fill_tabs(is_empty):
|
||||
"""Creates HTML for extra networks' tabs when the extra networks button is clicked for the first time."""
|
||||
for tab in related_tabs:
|
||||
tab.select(fn=lambda: [gr.update(visible=True) for _ in range(5)], inputs=[], outputs=[edit_search, dropdown_sort, button_sortorder, button_refresh, checkbox_show_dirs], show_progress=False)
|
||||
|
||||
def pages_html():
|
||||
if not ui.pages_contents:
|
||||
refresh()
|
||||
return refresh()
|
||||
|
||||
if is_empty:
|
||||
return True, *ui.pages_contents
|
||||
|
||||
return True, *[gr.update() for _ in ui.pages_contents]
|
||||
|
||||
state_visible = gr.State(value=False)
|
||||
button.click(fn=toggle_visibility, inputs=[state_visible], outputs=[state_visible, container, button], show_progress=False)
|
||||
|
||||
state_empty = gr.State(value=True)
|
||||
button.click(fn=fill_tabs, inputs=[state_empty], outputs=[state_empty, *ui.pages], show_progress=False)
|
||||
return ui.pages_contents
|
||||
|
||||
def refresh():
|
||||
for pg in ui.stored_extra_pages:
|
||||
@@ -415,6 +400,7 @@ def create_ui(container, button, tabname):
|
||||
|
||||
return ui.pages_contents
|
||||
|
||||
interface.load(fn=pages_html, inputs=[], outputs=[*ui.pages])
|
||||
button_refresh.click(fn=refresh, inputs=[], outputs=ui.pages)
|
||||
|
||||
return ui
|
||||
|
||||
@@ -3,6 +3,7 @@ import os
|
||||
|
||||
from modules import shared, ui_extra_networks, sd_models
|
||||
from modules.ui_extra_networks import quote_js
|
||||
from modules.ui_extra_networks_checkpoints_user_metadata import CheckpointUserMetadataEditor
|
||||
|
||||
|
||||
class ExtraNetworksPageCheckpoints(ui_extra_networks.ExtraNetworksPage):
|
||||
@@ -12,7 +13,7 @@ class ExtraNetworksPageCheckpoints(ui_extra_networks.ExtraNetworksPage):
|
||||
def refresh(self):
|
||||
shared.refresh_checkpoints()
|
||||
|
||||
def create_item(self, name, index=None):
|
||||
def create_item(self, name, index=None, enable_filter=True):
|
||||
checkpoint: sd_models.CheckpointInfo = sd_models.checkpoint_aliases.get(name)
|
||||
path, ext = os.path.splitext(checkpoint.filename)
|
||||
return {
|
||||
@@ -23,6 +24,7 @@ class ExtraNetworksPageCheckpoints(ui_extra_networks.ExtraNetworksPage):
|
||||
"search_term": self.search_terms_from_path(checkpoint.filename) + " " + (checkpoint.sha256 or ""),
|
||||
"onclick": '"' + html.escape(f"""return selectCheckpoint({quote_js(name)})""") + '"',
|
||||
"local_preview": f"{path}.{shared.opts.samples_format}",
|
||||
"metadata": checkpoint.metadata,
|
||||
"sort_keys": {'default': index, **self.get_sort_keys(checkpoint.filename)},
|
||||
}
|
||||
|
||||
@@ -33,3 +35,5 @@ class ExtraNetworksPageCheckpoints(ui_extra_networks.ExtraNetworksPage):
|
||||
def allowed_directories_for_previews(self):
|
||||
return [v for v in [shared.cmd_opts.ckpt_dir, sd_models.model_path] if v is not None]
|
||||
|
||||
def create_user_metadata_editor(self, ui, tabname):
|
||||
return CheckpointUserMetadataEditor(ui, tabname, self)
|
||||
|
||||
@@ -0,0 +1,60 @@
|
||||
import gradio as gr
|
||||
|
||||
from modules import ui_extra_networks_user_metadata, sd_vae
|
||||
from modules.ui_common import create_refresh_button
|
||||
|
||||
|
||||
class CheckpointUserMetadataEditor(ui_extra_networks_user_metadata.UserMetadataEditor):
|
||||
def __init__(self, ui, tabname, page):
|
||||
super().__init__(ui, tabname, page)
|
||||
|
||||
self.select_vae = None
|
||||
|
||||
def save_user_metadata(self, name, desc, notes, vae):
|
||||
user_metadata = self.get_user_metadata(name)
|
||||
user_metadata["description"] = desc
|
||||
user_metadata["notes"] = notes
|
||||
user_metadata["vae"] = vae
|
||||
|
||||
self.write_user_metadata(name, user_metadata)
|
||||
|
||||
def put_values_into_components(self, name):
|
||||
user_metadata = self.get_user_metadata(name)
|
||||
values = super().put_values_into_components(name)
|
||||
|
||||
return [
|
||||
*values[0:5],
|
||||
user_metadata.get('vae', ''),
|
||||
]
|
||||
|
||||
def create_editor(self):
|
||||
self.create_default_editor_elems()
|
||||
|
||||
with gr.Row():
|
||||
self.select_vae = gr.Dropdown(choices=["Automatic", "None"] + list(sd_vae.vae_dict), value="None", label="Preferred VAE", elem_id="checpoint_edit_user_metadata_preferred_vae")
|
||||
create_refresh_button(self.select_vae, sd_vae.refresh_vae_list, lambda: {"choices": ["Automatic", "None"] + list(sd_vae.vae_dict)}, "checpoint_edit_user_metadata_refresh_preferred_vae")
|
||||
|
||||
self.edit_notes = gr.TextArea(label='Notes', lines=4)
|
||||
|
||||
self.create_default_buttons()
|
||||
|
||||
viewed_components = [
|
||||
self.edit_name,
|
||||
self.edit_description,
|
||||
self.html_filedata,
|
||||
self.html_preview,
|
||||
self.edit_notes,
|
||||
self.select_vae,
|
||||
]
|
||||
|
||||
self.button_edit\
|
||||
.click(fn=self.put_values_into_components, inputs=[self.edit_name_input], outputs=viewed_components)\
|
||||
.then(fn=lambda: gr.update(visible=True), inputs=[], outputs=[self.box])
|
||||
|
||||
edited_components = [
|
||||
self.edit_description,
|
||||
self.edit_notes,
|
||||
self.select_vae,
|
||||
]
|
||||
|
||||
self.setup_save_handler(self.button_save, self.save_user_metadata, edited_components)
|
||||
@@ -11,7 +11,7 @@ class ExtraNetworksPageHypernetworks(ui_extra_networks.ExtraNetworksPage):
|
||||
def refresh(self):
|
||||
shared.reload_hypernetworks()
|
||||
|
||||
def create_item(self, name, index=None):
|
||||
def create_item(self, name, index=None, enable_filter=True):
|
||||
full_path = shared.hypernetworks[name]
|
||||
path, ext = os.path.splitext(full_path)
|
||||
|
||||
|
||||
@@ -12,7 +12,7 @@ class ExtraNetworksPageTextualInversion(ui_extra_networks.ExtraNetworksPage):
|
||||
def refresh(self):
|
||||
sd_hijack.model_hijack.embedding_db.load_textual_inversion_embeddings(force_reload=True)
|
||||
|
||||
def create_item(self, name, index=None):
|
||||
def create_item(self, name, index=None, enable_filter=True):
|
||||
embedding = sd_hijack.model_hijack.embedding_db.word_embeddings.get(name)
|
||||
|
||||
path, ext = os.path.splitext(embedding.filename)
|
||||
|
||||
@@ -96,6 +96,7 @@ class UserMetadataEditor:
|
||||
|
||||
stats = os.stat(filename)
|
||||
params = [
|
||||
('Filename: ', os.path.basename(filename)),
|
||||
('File size: ', sysinfo.pretty_bytes(stats.st_size)),
|
||||
('Modified: ', datetime.datetime.fromtimestamp(stats.st_mtime).strftime('%Y-%m-%d %H:%M')),
|
||||
]
|
||||
|
||||
@@ -6,7 +6,7 @@ import modules.generation_parameters_copypaste as parameters_copypaste
|
||||
def create_ui():
|
||||
tab_index = gr.State(value=0)
|
||||
|
||||
with gr.Row().style(equal_height=False, variant='compact'):
|
||||
with gr.Row(equal_height=False, variant='compact'):
|
||||
with gr.Column(variant='compact'):
|
||||
with gr.Tabs(elem_id="mode_extras"):
|
||||
with gr.TabItem('Single Image', id="single_image", elem_id="extras_single_tab") as tab_single:
|
||||
|
||||
@@ -0,0 +1,110 @@
|
||||
import gradio as gr
|
||||
|
||||
from modules import shared, ui_common, ui_components, styles
|
||||
|
||||
styles_edit_symbol = '\U0001f58c\uFE0F' # 🖌️
|
||||
styles_materialize_symbol = '\U0001f4cb' # 📋
|
||||
|
||||
|
||||
def select_style(name):
|
||||
style = shared.prompt_styles.styles.get(name)
|
||||
existing = style is not None
|
||||
empty = not name
|
||||
|
||||
prompt = style.prompt if style else gr.update()
|
||||
negative_prompt = style.negative_prompt if style else gr.update()
|
||||
|
||||
return prompt, negative_prompt, gr.update(visible=existing), gr.update(visible=not empty)
|
||||
|
||||
|
||||
def save_style(name, prompt, negative_prompt):
|
||||
if not name:
|
||||
return gr.update(visible=False)
|
||||
|
||||
style = styles.PromptStyle(name, prompt, negative_prompt)
|
||||
shared.prompt_styles.styles[style.name] = style
|
||||
shared.prompt_styles.save_styles(shared.styles_filename)
|
||||
|
||||
return gr.update(visible=True)
|
||||
|
||||
|
||||
def delete_style(name):
|
||||
if name == "":
|
||||
return
|
||||
|
||||
shared.prompt_styles.styles.pop(name, None)
|
||||
shared.prompt_styles.save_styles(shared.styles_filename)
|
||||
|
||||
return '', '', ''
|
||||
|
||||
|
||||
def materialize_styles(prompt, negative_prompt, styles):
|
||||
prompt = shared.prompt_styles.apply_styles_to_prompt(prompt, styles)
|
||||
negative_prompt = shared.prompt_styles.apply_negative_styles_to_prompt(negative_prompt, styles)
|
||||
|
||||
return [gr.Textbox.update(value=prompt), gr.Textbox.update(value=negative_prompt), gr.Dropdown.update(value=[])]
|
||||
|
||||
|
||||
def refresh_styles():
|
||||
return gr.update(choices=list(shared.prompt_styles.styles)), gr.update(choices=list(shared.prompt_styles.styles))
|
||||
|
||||
|
||||
class UiPromptStyles:
|
||||
def __init__(self, tabname, main_ui_prompt, main_ui_negative_prompt):
|
||||
self.tabname = tabname
|
||||
|
||||
with gr.Row(elem_id=f"{tabname}_styles_row"):
|
||||
self.dropdown = gr.Dropdown(label="Styles", show_label=False, elem_id=f"{tabname}_styles", choices=list(shared.prompt_styles.styles), value=[], multiselect=True, tooltip="Styles")
|
||||
edit_button = ui_components.ToolButton(value=styles_edit_symbol, elem_id=f"{tabname}_styles_edit_button", tooltip="Edit styles")
|
||||
|
||||
with gr.Box(elem_id=f"{tabname}_styles_dialog", elem_classes="popup-dialog") as styles_dialog:
|
||||
with gr.Row():
|
||||
self.selection = gr.Dropdown(label="Styles", elem_id=f"{tabname}_styles_edit_select", choices=list(shared.prompt_styles.styles), value=[], allow_custom_value=True, info="Styles allow you to add custom text to prompt. Use the {prompt} token in style text, and it will be replaced with user's prompt when applying style. Otherwise, style's text will be added to the end of the prompt.")
|
||||
ui_common.create_refresh_button([self.dropdown, self.selection], shared.prompt_styles.reload, lambda: {"choices": list(shared.prompt_styles.styles)}, f"refresh_{tabname}_styles")
|
||||
self.materialize = ui_components.ToolButton(value=styles_materialize_symbol, elem_id=f"{tabname}_style_apply", tooltip="Apply all selected styles from the style selction dropdown in main UI to the prompt.")
|
||||
|
||||
with gr.Row():
|
||||
self.prompt = gr.Textbox(label="Prompt", show_label=True, elem_id=f"{tabname}_edit_style_prompt", lines=3)
|
||||
|
||||
with gr.Row():
|
||||
self.neg_prompt = gr.Textbox(label="Negative prompt", show_label=True, elem_id=f"{tabname}_edit_style_neg_prompt", lines=3)
|
||||
|
||||
with gr.Row():
|
||||
self.save = gr.Button('Save', variant='primary', elem_id=f'{tabname}_edit_style_save', visible=False)
|
||||
self.delete = gr.Button('Delete', variant='primary', elem_id=f'{tabname}_edit_style_delete', visible=False)
|
||||
self.close = gr.Button('Close', variant='secondary', elem_id=f'{tabname}_edit_style_close')
|
||||
|
||||
self.selection.change(
|
||||
fn=select_style,
|
||||
inputs=[self.selection],
|
||||
outputs=[self.prompt, self.neg_prompt, self.delete, self.save],
|
||||
show_progress=False,
|
||||
)
|
||||
|
||||
self.save.click(
|
||||
fn=save_style,
|
||||
inputs=[self.selection, self.prompt, self.neg_prompt],
|
||||
outputs=[self.delete],
|
||||
show_progress=False,
|
||||
).then(refresh_styles, outputs=[self.dropdown, self.selection], show_progress=False)
|
||||
|
||||
self.delete.click(
|
||||
fn=delete_style,
|
||||
_js='function(name){ if(name == "") return ""; return confirm("Delete style " + name + "?") ? name : ""; }',
|
||||
inputs=[self.selection],
|
||||
outputs=[self.selection, self.prompt, self.neg_prompt],
|
||||
show_progress=False,
|
||||
).then(refresh_styles, outputs=[self.dropdown, self.selection], show_progress=False)
|
||||
|
||||
self.materialize.click(
|
||||
fn=materialize_styles,
|
||||
inputs=[main_ui_prompt, main_ui_negative_prompt, self.dropdown],
|
||||
outputs=[main_ui_prompt, main_ui_negative_prompt, self.dropdown],
|
||||
show_progress=False,
|
||||
).then(fn=None, _js="function(){update_"+tabname+"_tokens(); closePopup();}", show_progress=False)
|
||||
|
||||
ui_common.setup_dialog(button_show=edit_button, dialog=styles_dialog, button_close=self.close)
|
||||
|
||||
|
||||
|
||||
|
||||
@@ -158,7 +158,7 @@ class UiSettings:
|
||||
loadsave.create_ui()
|
||||
|
||||
with gr.TabItem("Sysinfo", id="sysinfo", elem_id="settings_tab_sysinfo"):
|
||||
gr.HTML('<a href="./internal/sysinfo-download" class="sysinfo_big_link" download>Download system info</a><br /><a href="./internal/sysinfo">(or open as text in a new page)</a>', elem_id="sysinfo_download")
|
||||
gr.HTML('<a href="./internal/sysinfo-download" class="sysinfo_big_link" download>Download system info</a><br /><a href="./internal/sysinfo" target="_blank">(or open as text in a new page)</a>', elem_id="sysinfo_download")
|
||||
|
||||
with gr.Row():
|
||||
with gr.Column(scale=1):
|
||||
|
||||
+2
-2
@@ -7,7 +7,7 @@ blendmodes
|
||||
clean-fid
|
||||
einops
|
||||
gfpgan
|
||||
gradio==3.32.0
|
||||
gradio==3.39.0
|
||||
inflection
|
||||
jsonmerge
|
||||
kornia
|
||||
@@ -30,4 +30,4 @@ tomesd
|
||||
torch
|
||||
torchdiffeq
|
||||
torchsde
|
||||
transformers==4.25.1
|
||||
transformers==4.30.2
|
||||
|
||||
@@ -1,13 +1,13 @@
|
||||
GitPython==3.1.30
|
||||
GitPython==3.1.32
|
||||
Pillow==9.5.0
|
||||
accelerate==0.18.0
|
||||
accelerate==0.21.0
|
||||
basicsr==1.4.2
|
||||
blendmodes==2022
|
||||
clean-fid==0.1.35
|
||||
einops==0.4.1
|
||||
fastapi==0.94.0
|
||||
gfpgan==1.3.8
|
||||
gradio==3.32.0
|
||||
gradio==3.39.0
|
||||
httpcore==0.15
|
||||
inflection==0.5.1
|
||||
jsonmerge==1.8.0
|
||||
@@ -22,10 +22,10 @@ pytorch_lightning==1.9.4
|
||||
realesrgan==0.3.0
|
||||
resize-right==0.0.2
|
||||
safetensors==0.3.1
|
||||
scikit-image==0.20.0
|
||||
timm==0.6.7
|
||||
tomesd==0.1.2
|
||||
scikit-image==0.21.0
|
||||
timm==0.9.2
|
||||
tomesd==0.1.3
|
||||
torch
|
||||
torchdiffeq==0.2.3
|
||||
torchsde==0.2.5
|
||||
transformers==4.25.1
|
||||
transformers==4.30.2
|
||||
|
||||
+14
-13
@@ -3,6 +3,7 @@ from copy import copy
|
||||
from itertools import permutations, chain
|
||||
import random
|
||||
import csv
|
||||
import os.path
|
||||
from io import StringIO
|
||||
from PIL import Image
|
||||
import numpy as np
|
||||
@@ -10,7 +11,7 @@ import numpy as np
|
||||
import modules.scripts as scripts
|
||||
import gradio as gr
|
||||
|
||||
from modules import images, sd_samplers, processing, sd_models, sd_vae, sd_samplers_kdiffusion
|
||||
from modules import images, sd_samplers, processing, sd_models, sd_vae, sd_samplers_kdiffusion, errors
|
||||
from modules.processing import process_images, Processed, StableDiffusionProcessingTxt2Img
|
||||
from modules.shared import opts, state
|
||||
import modules.shared as shared
|
||||
@@ -66,14 +67,6 @@ def apply_order(p, x, xs):
|
||||
p.prompt = prompt_tmp + p.prompt
|
||||
|
||||
|
||||
def apply_sampler(p, x, xs):
|
||||
sampler_name = sd_samplers.samplers_map.get(x.lower(), None)
|
||||
if sampler_name is None:
|
||||
raise RuntimeError(f"Unknown sampler: {x}")
|
||||
|
||||
p.sampler_name = sampler_name
|
||||
|
||||
|
||||
def confirm_samplers(p, xs):
|
||||
for x in xs:
|
||||
if x.lower() not in sd_samplers.samplers_map:
|
||||
@@ -182,6 +175,8 @@ def do_nothing(p, x, xs):
|
||||
def format_nothing(p, opt, x):
|
||||
return ""
|
||||
|
||||
def format_remove_path(p, opt, x):
|
||||
return os.path.basename(x)
|
||||
|
||||
def str_permutations(x):
|
||||
"""dummy function for specifying it in AxisOption's type when you want to get a list of permutations"""
|
||||
@@ -221,9 +216,10 @@ axis_options = [
|
||||
AxisOptionImg2Img("Image CFG Scale", float, apply_field("image_cfg_scale")),
|
||||
AxisOption("Prompt S/R", str, apply_prompt, format_value=format_value),
|
||||
AxisOption("Prompt order", str_permutations, apply_order, format_value=format_value_join_list),
|
||||
AxisOptionTxt2Img("Sampler", str, apply_sampler, format_value=format_value, confirm=confirm_samplers, choices=lambda: [x.name for x in sd_samplers.samplers]),
|
||||
AxisOptionImg2Img("Sampler", str, apply_sampler, format_value=format_value, confirm=confirm_samplers, choices=lambda: [x.name for x in sd_samplers.samplers_for_img2img]),
|
||||
AxisOption("Checkpoint name", str, apply_checkpoint, format_value=format_value, confirm=confirm_checkpoints, cost=1.0, choices=lambda: sorted(sd_models.checkpoints_list, key=str.casefold)),
|
||||
AxisOptionTxt2Img("Sampler", str, apply_field("sampler_name"), format_value=format_value, confirm=confirm_samplers, choices=lambda: [x.name for x in sd_samplers.samplers]),
|
||||
AxisOptionTxt2Img("Hires sampler", str, apply_field("hr_sampler_name"), confirm=confirm_samplers, choices=lambda: [x.name for x in sd_samplers.samplers_for_img2img]),
|
||||
AxisOptionImg2Img("Sampler", str, apply_field("sampler_name"), format_value=format_value, confirm=confirm_samplers, choices=lambda: [x.name for x in sd_samplers.samplers_for_img2img]),
|
||||
AxisOption("Checkpoint name", str, apply_checkpoint, format_value=format_remove_path, confirm=confirm_checkpoints, cost=1.0, choices=lambda: sorted(sd_models.checkpoints_list, key=str.casefold)),
|
||||
AxisOption("Negative Guidance minimum sigma", float, apply_field("s_min_uncond")),
|
||||
AxisOption("Sigma Churn", float, apply_field("s_churn")),
|
||||
AxisOption("Sigma min", float, apply_field("s_tmin")),
|
||||
@@ -648,7 +644,12 @@ class Script(scripts.Script):
|
||||
y_opt.apply(pc, y, ys)
|
||||
z_opt.apply(pc, z, zs)
|
||||
|
||||
res = process_images(pc)
|
||||
try:
|
||||
res = process_images(pc)
|
||||
except Exception as e:
|
||||
errors.display(e, "generating image for xyz plot")
|
||||
|
||||
res = Processed(p, [], p.seed, "")
|
||||
|
||||
# Sets subgrid infotexts
|
||||
subgrid_index = 1 + iz
|
||||
|
||||
@@ -8,6 +8,7 @@
|
||||
--checkbox-label-gap: 0.25em 0.1em;
|
||||
--section-header-text-size: 12pt;
|
||||
--block-background-fill: transparent;
|
||||
|
||||
}
|
||||
|
||||
.block.padded:not(.gradio-accordion) {
|
||||
@@ -42,7 +43,8 @@ div.form{
|
||||
.block.gradio-radio,
|
||||
.block.gradio-checkboxgroup,
|
||||
.block.gradio-number,
|
||||
.block.gradio-colorpicker
|
||||
.block.gradio-colorpicker,
|
||||
div.gradio-group
|
||||
{
|
||||
border-width: 0 !important;
|
||||
box-shadow: none !important;
|
||||
@@ -133,6 +135,15 @@ a{
|
||||
cursor: pointer;
|
||||
}
|
||||
|
||||
div.styler{
|
||||
border: none;
|
||||
background: var(--background-fill-primary);
|
||||
}
|
||||
|
||||
.block.gradio-textbox{
|
||||
overflow: visible !important;
|
||||
}
|
||||
|
||||
|
||||
/* general styled components */
|
||||
|
||||
@@ -164,7 +175,7 @@ a{
|
||||
.checkboxes-row > div{
|
||||
flex: 0;
|
||||
white-space: nowrap;
|
||||
min-width: auto;
|
||||
min-width: auto !important;
|
||||
}
|
||||
|
||||
button.custom-button{
|
||||
@@ -388,6 +399,7 @@ div#extras_scale_to_tab div.form{
|
||||
#quicksettings > div, #quicksettings > fieldset{
|
||||
max-width: 24em;
|
||||
min-width: 24em;
|
||||
width: 24em;
|
||||
padding: 0;
|
||||
border: none;
|
||||
box-shadow: none;
|
||||
@@ -482,6 +494,13 @@ table.popup-table .link{
|
||||
font-size: 18pt;
|
||||
}
|
||||
|
||||
#settings .settings-info{
|
||||
max-width: 48em;
|
||||
border: 1px dotted #777;
|
||||
margin: 0;
|
||||
padding: 1em;
|
||||
}
|
||||
|
||||
|
||||
/* live preview */
|
||||
.progressDiv{
|
||||
@@ -767,9 +786,14 @@ footer {
|
||||
/* extra networks UI */
|
||||
|
||||
.extra-network-cards{
|
||||
height: 725px;
|
||||
overflow: scroll;
|
||||
height: calc(100vh - 24rem);
|
||||
overflow: clip scroll;
|
||||
resize: vertical;
|
||||
min-height: 52rem;
|
||||
}
|
||||
|
||||
.extra-networks > div.tab-nav{
|
||||
min-height: 3.4rem;
|
||||
}
|
||||
|
||||
.extra-networks > div > [id *= '_extra_']{
|
||||
@@ -784,10 +808,12 @@ footer {
|
||||
margin: 0 0.15em;
|
||||
}
|
||||
.extra-networks .tab-nav .search,
|
||||
.extra-networks .tab-nav .sort{
|
||||
display: inline-block;
|
||||
.extra-networks .tab-nav .sort,
|
||||
.extra-networks .tab-nav .show-dirs
|
||||
{
|
||||
margin: 0.3em;
|
||||
align-self: center;
|
||||
width: auto;
|
||||
}
|
||||
|
||||
.extra-networks .tab-nav .search {
|
||||
@@ -972,3 +998,16 @@ div.block.gradio-box.edit-user-metadata {
|
||||
.edit-user-metadata-buttons{
|
||||
margin-top: 1.5em;
|
||||
}
|
||||
|
||||
|
||||
|
||||
|
||||
div.block.gradio-box.popup-dialog, .popup-dialog {
|
||||
width: 56em;
|
||||
background: var(--body-background-fill);
|
||||
padding: 2em !important;
|
||||
}
|
||||
|
||||
div.block.gradio-box.popup-dialog > div:last-child, .popup-dialog > div:last-child{
|
||||
margin-top: 1em;
|
||||
}
|
||||
|
||||
@@ -14,7 +14,6 @@ from typing import Iterable
|
||||
from fastapi import FastAPI
|
||||
from fastapi.middleware.cors import CORSMiddleware
|
||||
from fastapi.middleware.gzip import GZipMiddleware
|
||||
from packaging import version
|
||||
|
||||
import logging
|
||||
|
||||
@@ -50,6 +49,7 @@ startup_timer.record("setup paths")
|
||||
import ldm.modules.encoders.modules # noqa: F401
|
||||
startup_timer.record("import ldm")
|
||||
|
||||
|
||||
from modules import extra_networks
|
||||
from modules.call_queue import wrap_gradio_gpu_call, wrap_queued_call, queue_lock # noqa: F401
|
||||
|
||||
@@ -58,10 +58,15 @@ if ".dev" in torch.__version__ or "+git" in torch.__version__:
|
||||
torch.__long_version__ = torch.__version__
|
||||
torch.__version__ = re.search(r'[\d.]+[\d]', torch.__version__).group(0)
|
||||
|
||||
from modules import shared, sd_samplers, upscaler, extensions, localization, ui_tempdir, ui_extra_networks, config_states
|
||||
from modules import shared
|
||||
|
||||
if not shared.cmd_opts.skip_version_check:
|
||||
errors.check_versions()
|
||||
|
||||
import modules.codeformer_model as codeformer
|
||||
import modules.face_restoration
|
||||
import modules.gfpgan_model as gfpgan
|
||||
from modules import sd_samplers, upscaler, extensions, localization, ui_tempdir, ui_extra_networks, config_states
|
||||
import modules.face_restoration
|
||||
import modules.img2img
|
||||
|
||||
import modules.lowvram
|
||||
@@ -130,37 +135,6 @@ def fix_asyncio_event_loop_policy():
|
||||
asyncio.set_event_loop_policy(AnyThreadEventLoopPolicy())
|
||||
|
||||
|
||||
def check_versions():
|
||||
if shared.cmd_opts.skip_version_check:
|
||||
return
|
||||
|
||||
expected_torch_version = "2.0.0"
|
||||
|
||||
if version.parse(torch.__version__) < version.parse(expected_torch_version):
|
||||
errors.print_error_explanation(f"""
|
||||
You are running torch {torch.__version__}.
|
||||
The program is tested to work with torch {expected_torch_version}.
|
||||
To reinstall the desired version, run with commandline flag --reinstall-torch.
|
||||
Beware that this will cause a lot of large files to be downloaded, as well as
|
||||
there are reports of issues with training tab on the latest version.
|
||||
|
||||
Use --skip-version-check commandline argument to disable this check.
|
||||
""".strip())
|
||||
|
||||
expected_xformers_version = "0.0.20"
|
||||
if shared.xformers_available:
|
||||
import xformers
|
||||
|
||||
if version.parse(xformers.__version__) < version.parse(expected_xformers_version):
|
||||
errors.print_error_explanation(f"""
|
||||
You are running xformers {xformers.__version__}.
|
||||
The program is tested to work with xformers {expected_xformers_version}.
|
||||
To reinstall the desired version, run with commandline flag --reinstall-xformers.
|
||||
|
||||
Use --skip-version-check commandline argument to disable this check.
|
||||
""".strip())
|
||||
|
||||
|
||||
def restore_config_state_file():
|
||||
config_state_file = shared.opts.restore_config_state_file
|
||||
if config_state_file == "":
|
||||
@@ -248,7 +222,6 @@ def initialize():
|
||||
fix_asyncio_event_loop_policy()
|
||||
validate_tls_options()
|
||||
configure_sigint_handler()
|
||||
check_versions()
|
||||
modelloader.cleanup_models()
|
||||
configure_opts_onchange()
|
||||
|
||||
@@ -320,9 +293,9 @@ def initialize_rest(*, reload_script_modules=False):
|
||||
if modules.sd_hijack.current_optimizer is None:
|
||||
modules.sd_hijack.apply_optimizations()
|
||||
|
||||
Thread(target=load_model).start()
|
||||
devices.first_time_calculation()
|
||||
|
||||
Thread(target=devices.first_time_calculation).start()
|
||||
Thread(target=load_model).start()
|
||||
|
||||
shared.reload_hypernetworks()
|
||||
startup_timer.record("reload hypernetworks")
|
||||
|
||||
Reference in New Issue
Block a user