Merge branch 'dev' into Branch_AddNewFilenameGen
This commit is contained in:
+88
-18
@@ -3,11 +3,14 @@ import io
|
||||
import time
|
||||
import datetime
|
||||
import uvicorn
|
||||
import gradio as gr
|
||||
from threading import Lock
|
||||
from io import BytesIO
|
||||
from gradio.processing_utils import decode_base64_to_file
|
||||
from fastapi import APIRouter, Depends, FastAPI, HTTPException, Request, Response
|
||||
from fastapi import APIRouter, Depends, FastAPI, Request, Response
|
||||
from fastapi.security import HTTPBasic, HTTPBasicCredentials
|
||||
from fastapi.exceptions import HTTPException
|
||||
from fastapi.responses import JSONResponse
|
||||
from fastapi.encoders import jsonable_encoder
|
||||
from secrets import compare_digest
|
||||
|
||||
import modules.shared as shared
|
||||
@@ -18,7 +21,7 @@ from modules.textual_inversion.textual_inversion import create_embedding, train_
|
||||
from modules.textual_inversion.preprocess import preprocess
|
||||
from modules.hypernetworks.hypernetwork import create_hypernetwork, train_hypernetwork
|
||||
from PIL import PngImagePlugin,Image
|
||||
from modules.sd_models import checkpoints_list
|
||||
from modules.sd_models import checkpoints_list, unload_model_weights, reload_model_weights
|
||||
from modules.sd_models_config import find_checkpoint_config_near_filename
|
||||
from modules.realesrgan_model import get_realesrgan_models
|
||||
from modules import devices
|
||||
@@ -90,6 +93,16 @@ def encode_pil_to_base64(image):
|
||||
return base64.b64encode(bytes_data)
|
||||
|
||||
def api_middleware(app: FastAPI):
|
||||
rich_available = True
|
||||
try:
|
||||
import anyio # importing just so it can be placed on silent list
|
||||
import starlette # importing just so it can be placed on silent list
|
||||
from rich.console import Console
|
||||
console = Console()
|
||||
except:
|
||||
import traceback
|
||||
rich_available = False
|
||||
|
||||
@app.middleware("http")
|
||||
async def log_and_time(req: Request, call_next):
|
||||
ts = time.time()
|
||||
@@ -110,6 +123,36 @@ def api_middleware(app: FastAPI):
|
||||
))
|
||||
return res
|
||||
|
||||
def handle_exception(request: Request, e: Exception):
|
||||
err = {
|
||||
"error": type(e).__name__,
|
||||
"detail": vars(e).get('detail', ''),
|
||||
"body": vars(e).get('body', ''),
|
||||
"errors": str(e),
|
||||
}
|
||||
print(f"API error: {request.method}: {request.url} {err}")
|
||||
if not isinstance(e, HTTPException): # do not print backtrace on known httpexceptions
|
||||
if rich_available:
|
||||
console.print_exception(show_locals=True, max_frames=2, extra_lines=1, suppress=[anyio, starlette], word_wrap=False, width=min([console.width, 200]))
|
||||
else:
|
||||
traceback.print_exc()
|
||||
return JSONResponse(status_code=vars(e).get('status_code', 500), content=jsonable_encoder(err))
|
||||
|
||||
@app.middleware("http")
|
||||
async def exception_handling(request: Request, call_next):
|
||||
try:
|
||||
return await call_next(request)
|
||||
except Exception as e:
|
||||
return handle_exception(request, e)
|
||||
|
||||
@app.exception_handler(Exception)
|
||||
async def fastapi_exception_handler(request: Request, e: Exception):
|
||||
return handle_exception(request, e)
|
||||
|
||||
@app.exception_handler(HTTPException)
|
||||
async def http_exception_handler(request: Request, e: HTTPException):
|
||||
return handle_exception(request, e)
|
||||
|
||||
|
||||
class Api:
|
||||
def __init__(self, app: FastAPI, queue_lock: Lock):
|
||||
@@ -150,8 +193,13 @@ class Api:
|
||||
self.add_api_route("/sdapi/v1/train/embedding", self.train_embedding, methods=["POST"], response_model=TrainResponse)
|
||||
self.add_api_route("/sdapi/v1/train/hypernetwork", self.train_hypernetwork, methods=["POST"], response_model=TrainResponse)
|
||||
self.add_api_route("/sdapi/v1/memory", self.get_memory, methods=["GET"], response_model=MemoryResponse)
|
||||
self.add_api_route("/sdapi/v1/unload-checkpoint", self.unloadapi, methods=["POST"])
|
||||
self.add_api_route("/sdapi/v1/reload-checkpoint", self.reloadapi, methods=["POST"])
|
||||
self.add_api_route("/sdapi/v1/scripts", self.get_scripts_list, methods=["GET"], response_model=ScriptsList)
|
||||
|
||||
self.default_script_arg_txt2img = []
|
||||
self.default_script_arg_img2img = []
|
||||
|
||||
def add_api_route(self, path: str, endpoint, **kwargs):
|
||||
if shared.cmd_opts.api_auth:
|
||||
return self.app.add_api_route(path, endpoint, dependencies=[Depends(self.auth)], **kwargs)
|
||||
@@ -185,7 +233,7 @@ class Api:
|
||||
script_idx = script_name_to_index(script_name, script_runner.scripts)
|
||||
return script_runner.scripts[script_idx]
|
||||
|
||||
def init_script_args(self, request, selectable_scripts, selectable_idx, script_runner):
|
||||
def init_default_script_args(self, script_runner):
|
||||
#find max idx from the scripts in runner and generate a none array to init script_args
|
||||
last_arg_index = 1
|
||||
for script in script_runner.scripts:
|
||||
@@ -193,13 +241,24 @@ class Api:
|
||||
last_arg_index = script.args_to
|
||||
# None everywhere except position 0 to initialize script args
|
||||
script_args = [None]*last_arg_index
|
||||
script_args[0] = 0
|
||||
|
||||
# get default values
|
||||
with gr.Blocks(): # will throw errors calling ui function without this
|
||||
for script in script_runner.scripts:
|
||||
if script.ui(script.is_img2img):
|
||||
ui_default_values = []
|
||||
for elem in script.ui(script.is_img2img):
|
||||
ui_default_values.append(elem.value)
|
||||
script_args[script.args_from:script.args_to] = ui_default_values
|
||||
return script_args
|
||||
|
||||
def init_script_args(self, request, default_script_args, selectable_scripts, selectable_idx, script_runner):
|
||||
script_args = default_script_args.copy()
|
||||
# position 0 in script_arg is the idx+1 of the selectable script that is going to be run when using scripts.scripts_*2img.run()
|
||||
if selectable_scripts:
|
||||
script_args[selectable_scripts.args_from:selectable_scripts.args_to] = request.script_args
|
||||
script_args[0] = selectable_idx + 1
|
||||
else:
|
||||
# when [0] = 0 no selectable script to run
|
||||
script_args[0] = 0
|
||||
|
||||
# Now check for always on scripts
|
||||
if request.alwayson_scripts and (len(request.alwayson_scripts) > 0):
|
||||
@@ -212,7 +271,9 @@ class Api:
|
||||
raise HTTPException(status_code=422, detail=f"Cannot have a selectable script in the always on scripts params")
|
||||
# always on script with no arg should always run so you don't really need to add them to the requests
|
||||
if "args" in request.alwayson_scripts[alwayson_script_name]:
|
||||
script_args[alwayson_script.args_from:alwayson_script.args_to] = request.alwayson_scripts[alwayson_script_name]["args"]
|
||||
# min between arg length in scriptrunner and arg length in the request
|
||||
for idx in range(0, min((alwayson_script.args_to - alwayson_script.args_from), len(request.alwayson_scripts[alwayson_script_name]["args"]))):
|
||||
script_args[alwayson_script.args_from + idx] = request.alwayson_scripts[alwayson_script_name]["args"][idx]
|
||||
return script_args
|
||||
|
||||
def text2imgapi(self, txt2imgreq: StableDiffusionTxt2ImgProcessingAPI):
|
||||
@@ -220,6 +281,8 @@ class Api:
|
||||
if not script_runner.scripts:
|
||||
script_runner.initialize_scripts(False)
|
||||
ui.create_ui()
|
||||
if not self.default_script_arg_txt2img:
|
||||
self.default_script_arg_txt2img = self.init_default_script_args(script_runner)
|
||||
selectable_scripts, selectable_script_idx = self.get_selectable_script(txt2imgreq.script_name, script_runner)
|
||||
|
||||
populate = txt2imgreq.copy(update={ # Override __init__ params
|
||||
@@ -235,7 +298,7 @@ class Api:
|
||||
args.pop('script_args', None) # will refeed them to the pipeline directly after initializing them
|
||||
args.pop('alwayson_scripts', None)
|
||||
|
||||
script_args = self.init_script_args(txt2imgreq, selectable_scripts, selectable_script_idx, script_runner)
|
||||
script_args = self.init_script_args(txt2imgreq, self.default_script_arg_txt2img, selectable_scripts, selectable_script_idx, script_runner)
|
||||
|
||||
send_images = args.pop('send_images', True)
|
||||
args.pop('save_images', None)
|
||||
@@ -272,6 +335,8 @@ class Api:
|
||||
if not script_runner.scripts:
|
||||
script_runner.initialize_scripts(True)
|
||||
ui.create_ui()
|
||||
if not self.default_script_arg_img2img:
|
||||
self.default_script_arg_img2img = self.init_default_script_args(script_runner)
|
||||
selectable_scripts, selectable_script_idx = self.get_selectable_script(img2imgreq.script_name, script_runner)
|
||||
|
||||
populate = img2imgreq.copy(update={ # Override __init__ params
|
||||
@@ -289,7 +354,7 @@ class Api:
|
||||
args.pop('script_args', None) # will refeed them to the pipeline directly after initializing them
|
||||
args.pop('alwayson_scripts', None)
|
||||
|
||||
script_args = self.init_script_args(img2imgreq, selectable_scripts, selectable_script_idx, script_runner)
|
||||
script_args = self.init_script_args(img2imgreq, self.default_script_arg_img2img, selectable_scripts, selectable_script_idx, script_runner)
|
||||
|
||||
send_images = args.pop('send_images', True)
|
||||
args.pop('save_images', None)
|
||||
@@ -331,16 +396,11 @@ class Api:
|
||||
def extras_batch_images_api(self, req: ExtrasBatchImagesRequest):
|
||||
reqDict = setUpscalers(req)
|
||||
|
||||
def prepareFiles(file):
|
||||
file = decode_base64_to_file(file.data, file_path=file.name)
|
||||
file.orig_name = file.name
|
||||
return file
|
||||
|
||||
reqDict['image_folder'] = list(map(prepareFiles, reqDict['imageList']))
|
||||
reqDict.pop('imageList')
|
||||
image_list = reqDict.pop('imageList', [])
|
||||
image_folder = [decode_base64_to_image(x.data) for x in image_list]
|
||||
|
||||
with self.queue_lock:
|
||||
result = postprocessing.run_extras(extras_mode=1, image="", input_dir="", output_dir="", save_output=False, **reqDict)
|
||||
result = postprocessing.run_extras(extras_mode=1, image_folder=image_folder, image="", input_dir="", output_dir="", save_output=False, **reqDict)
|
||||
|
||||
return ExtrasBatchImagesResponse(images=list(map(encode_pil_to_base64, result[0])), html_info=result[1])
|
||||
|
||||
@@ -412,6 +472,16 @@ class Api:
|
||||
|
||||
return {}
|
||||
|
||||
def unloadapi(self):
|
||||
unload_model_weights()
|
||||
|
||||
return {}
|
||||
|
||||
def reloadapi(self):
|
||||
reload_model_weights()
|
||||
|
||||
return {}
|
||||
|
||||
def skip(self):
|
||||
shared.state.skip()
|
||||
|
||||
|
||||
@@ -0,0 +1,103 @@
|
||||
import argparse
|
||||
import os
|
||||
from modules.paths_internal import models_path, script_path, data_path, extensions_dir, extensions_builtin_dir, sd_default_config, sd_model_file
|
||||
|
||||
parser = argparse.ArgumentParser()
|
||||
|
||||
parser.add_argument("-f", action='store_true', help=argparse.SUPPRESS) # allows running as root; implemented outside of webui
|
||||
parser.add_argument("--update-all-extensions", action='store_true', help="launch.py argument: download updates for all extensions when starting the program")
|
||||
parser.add_argument("--skip-python-version-check", action='store_true', help="launch.py argument: do not check python version")
|
||||
parser.add_argument("--skip-torch-cuda-test", action='store_true', help="launch.py argument: do not check if CUDA is able to work properly")
|
||||
parser.add_argument("--reinstall-xformers", action='store_true', help="launch.py argument: install the appropriate version of xformers even if you have some version already installed")
|
||||
parser.add_argument("--reinstall-torch", action='store_true', help="launch.py argument: install the appropriate version of torch even if you have some version already installed")
|
||||
parser.add_argument("--update-check", action='store_true', help="launch.py argument: chck for updates at startup")
|
||||
parser.add_argument("--tests", type=str, default=None, help="launch.py argument: run tests in the specified directory")
|
||||
parser.add_argument("--no-tests", action='store_true', help="launch.py argument: do not run tests even if --tests option is specified")
|
||||
parser.add_argument("--skip-install", action='store_true', help="launch.py argument: skip installation of packages")
|
||||
parser.add_argument("--data-dir", type=str, default=os.path.dirname(os.path.dirname(os.path.realpath(__file__))), help="base path where all user data is stored")
|
||||
parser.add_argument("--config", type=str, default=sd_default_config, help="path to config which constructs model",)
|
||||
parser.add_argument("--ckpt", type=str, default=sd_model_file, help="path to checkpoint of stable diffusion model; if specified, this checkpoint will be added to the list of checkpoints and loaded",)
|
||||
parser.add_argument("--ckpt-dir", type=str, default=None, help="Path to directory with stable diffusion checkpoints")
|
||||
parser.add_argument("--vae-dir", type=str, default=None, help="Path to directory with VAE files")
|
||||
parser.add_argument("--gfpgan-dir", type=str, help="GFPGAN directory", default=('./src/gfpgan' if os.path.exists('./src/gfpgan') else './GFPGAN'))
|
||||
parser.add_argument("--gfpgan-model", type=str, help="GFPGAN model file name", default=None)
|
||||
parser.add_argument("--no-half", action='store_true', help="do not switch the model to 16-bit floats")
|
||||
parser.add_argument("--no-half-vae", action='store_true', help="do not switch the VAE model to 16-bit floats")
|
||||
parser.add_argument("--no-progressbar-hiding", action='store_true', help="do not hide progressbar in gradio UI (we hide it because it slows down ML if you have hardware acceleration in browser)")
|
||||
parser.add_argument("--max-batch-count", type=int, default=16, help="maximum batch count value for the UI")
|
||||
parser.add_argument("--embeddings-dir", type=str, default=os.path.join(data_path, 'embeddings'), help="embeddings directory for textual inversion (default: embeddings)")
|
||||
parser.add_argument("--textual-inversion-templates-dir", type=str, default=os.path.join(script_path, 'textual_inversion_templates'), help="directory with textual inversion templates")
|
||||
parser.add_argument("--hypernetwork-dir", type=str, default=os.path.join(models_path, 'hypernetworks'), help="hypernetwork directory")
|
||||
parser.add_argument("--localizations-dir", type=str, default=os.path.join(script_path, 'localizations'), help="localizations directory")
|
||||
parser.add_argument("--allow-code", action='store_true', help="allow custom script execution from webui")
|
||||
parser.add_argument("--medvram", action='store_true', help="enable stable diffusion model optimizations for sacrificing a little speed for low VRM usage")
|
||||
parser.add_argument("--lowvram", action='store_true', help="enable stable diffusion model optimizations for sacrificing a lot of speed for very low VRM usage")
|
||||
parser.add_argument("--lowram", action='store_true', help="load stable diffusion checkpoint weights to VRAM instead of RAM")
|
||||
parser.add_argument("--always-batch-cond-uncond", action='store_true', help="disables cond/uncond batching that is enabled to save memory with --medvram or --lowvram")
|
||||
parser.add_argument("--unload-gfpgan", action='store_true', help="does not do anything.")
|
||||
parser.add_argument("--precision", type=str, help="evaluate at this precision", choices=["full", "autocast"], default="autocast")
|
||||
parser.add_argument("--upcast-sampling", action='store_true', help="upcast sampling. No effect with --no-half. Usually produces similar results to --no-half with better performance while using less memory.")
|
||||
parser.add_argument("--share", action='store_true', help="use share=True for gradio and make the UI accessible through their site")
|
||||
parser.add_argument("--ngrok", type=str, help="ngrok authtoken, alternative to gradio --share", default=None)
|
||||
parser.add_argument("--ngrok-region", type=str, help="The region in which ngrok should start.", default="us")
|
||||
parser.add_argument("--enable-insecure-extension-access", action='store_true', help="enable extensions tab regardless of other options")
|
||||
parser.add_argument("--codeformer-models-path", type=str, help="Path to directory with codeformer model file(s).", default=os.path.join(models_path, 'Codeformer'))
|
||||
parser.add_argument("--gfpgan-models-path", type=str, help="Path to directory with GFPGAN model file(s).", default=os.path.join(models_path, 'GFPGAN'))
|
||||
parser.add_argument("--esrgan-models-path", type=str, help="Path to directory with ESRGAN model file(s).", default=os.path.join(models_path, 'ESRGAN'))
|
||||
parser.add_argument("--bsrgan-models-path", type=str, help="Path to directory with BSRGAN model file(s).", default=os.path.join(models_path, 'BSRGAN'))
|
||||
parser.add_argument("--realesrgan-models-path", type=str, help="Path to directory with RealESRGAN model file(s).", default=os.path.join(models_path, 'RealESRGAN'))
|
||||
parser.add_argument("--clip-models-path", type=str, help="Path to directory with CLIP model file(s).", default=None)
|
||||
parser.add_argument("--xformers", action='store_true', help="enable xformers for cross attention layers")
|
||||
parser.add_argument("--force-enable-xformers", action='store_true', help="enable xformers for cross attention layers regardless of whether the checking code thinks you can run it; do not make bug reports if this fails to work")
|
||||
parser.add_argument("--xformers-flash-attention", action='store_true', help="enable xformers with Flash Attention to improve reproducibility (supported for SD2.x or variant only)")
|
||||
parser.add_argument("--deepdanbooru", action='store_true', help="does not do anything")
|
||||
parser.add_argument("--opt-split-attention", action='store_true', help="force-enables Doggettx's cross-attention layer optimization. By default, it's on for torch cuda.")
|
||||
parser.add_argument("--opt-sub-quad-attention", action='store_true', help="enable memory efficient sub-quadratic cross-attention layer optimization")
|
||||
parser.add_argument("--sub-quad-q-chunk-size", type=int, help="query chunk size for the sub-quadratic cross-attention layer optimization to use", default=1024)
|
||||
parser.add_argument("--sub-quad-kv-chunk-size", type=int, help="kv chunk size for the sub-quadratic cross-attention layer optimization to use", default=None)
|
||||
parser.add_argument("--sub-quad-chunk-threshold", type=int, help="the percentage of VRAM threshold for the sub-quadratic cross-attention layer optimization to use chunking", default=None)
|
||||
parser.add_argument("--opt-split-attention-invokeai", action='store_true', help="force-enables InvokeAI's cross-attention layer optimization. By default, it's on when cuda is unavailable.")
|
||||
parser.add_argument("--opt-split-attention-v1", action='store_true', help="enable older version of split attention optimization that does not consume all the VRAM it can find")
|
||||
parser.add_argument("--opt-sdp-attention", action='store_true', help="enable scaled dot product cross-attention layer optimization; requires PyTorch 2.*")
|
||||
parser.add_argument("--opt-sdp-no-mem-attention", action='store_true', help="enable scaled dot product cross-attention layer optimization without memory efficient attention, makes image generation deterministic; requires PyTorch 2.*")
|
||||
parser.add_argument("--disable-opt-split-attention", action='store_true', help="force-disables cross-attention layer optimization")
|
||||
parser.add_argument("--disable-nan-check", action='store_true', help="do not check if produced images/latent spaces have nans; useful for running without a checkpoint in CI")
|
||||
parser.add_argument("--use-cpu", nargs='+', help="use CPU as torch device for specified modules", default=[], type=str.lower)
|
||||
parser.add_argument("--listen", action='store_true', help="launch gradio with 0.0.0.0 as server name, allowing to respond to network requests")
|
||||
parser.add_argument("--port", type=int, help="launch gradio with given server port, you need root/admin rights for ports < 1024, defaults to 7860 if available", default=None)
|
||||
parser.add_argument("--show-negative-prompt", action='store_true', help="does not do anything", default=False)
|
||||
parser.add_argument("--ui-config-file", type=str, help="filename to use for ui configuration", default=os.path.join(data_path, 'ui-config.json'))
|
||||
parser.add_argument("--hide-ui-dir-config", action='store_true', help="hide directory configuration from webui", default=False)
|
||||
parser.add_argument("--freeze-settings", action='store_true', help="disable editing settings", default=False)
|
||||
parser.add_argument("--ui-settings-file", type=str, help="filename to use for ui settings", default=os.path.join(data_path, 'config.json'))
|
||||
parser.add_argument("--gradio-debug", action='store_true', help="launch gradio with --debug option")
|
||||
parser.add_argument("--gradio-auth", type=str, help='set gradio authentication like "username:password"; or comma-delimit multiple like "u1:p1,u2:p2,u3:p3"', default=None)
|
||||
parser.add_argument("--gradio-auth-path", type=str, help='set gradio authentication file path ex. "/path/to/auth/file" same auth format as --gradio-auth', default=None)
|
||||
parser.add_argument("--gradio-img2img-tool", type=str, help='does not do anything')
|
||||
parser.add_argument("--gradio-inpaint-tool", type=str, help="does not do anything")
|
||||
parser.add_argument("--opt-channelslast", action='store_true', help="change memory type for stable diffusion to channels last")
|
||||
parser.add_argument("--styles-file", type=str, help="filename to use for styles", default=os.path.join(data_path, 'styles.csv'))
|
||||
parser.add_argument("--autolaunch", action='store_true', help="open the webui URL in the system's default browser upon launch", default=False)
|
||||
parser.add_argument("--theme", type=str, help="launches the UI with light or dark theme", default=None)
|
||||
parser.add_argument("--use-textbox-seed", action='store_true', help="use textbox for seeds in UI (no up/down, but possible to input long seeds)", default=False)
|
||||
parser.add_argument("--disable-console-progressbars", action='store_true', help="do not output progressbars to console", default=False)
|
||||
parser.add_argument("--enable-console-prompts", action='store_true', help="print prompts to console when generating with txt2img and img2img", default=False)
|
||||
parser.add_argument('--vae-path', type=str, help='Checkpoint to use as VAE; setting this argument disables all settings related to VAE', default=None)
|
||||
parser.add_argument("--disable-safe-unpickle", action='store_true', help="disable checking pytorch models for malicious code", default=False)
|
||||
parser.add_argument("--api", action='store_true', help="use api=True to launch the API together with the webui (use --nowebui instead for only the API)")
|
||||
parser.add_argument("--api-auth", type=str, help='Set authentication for API like "username:password"; or comma-delimit multiple like "u1:p1,u2:p2,u3:p3"', default=None)
|
||||
parser.add_argument("--api-log", action='store_true', help="use api-log=True to enable logging of all API requests")
|
||||
parser.add_argument("--nowebui", action='store_true', help="use api=True to launch the API instead of the webui")
|
||||
parser.add_argument("--ui-debug-mode", action='store_true', help="Don't load model to quickly launch UI")
|
||||
parser.add_argument("--device-id", type=str, help="Select the default CUDA device to use (export CUDA_VISIBLE_DEVICES=0,1,etc might be needed before)", default=None)
|
||||
parser.add_argument("--administrator", action='store_true', help="Administrator rights", default=False)
|
||||
parser.add_argument("--cors-allow-origins", type=str, help="Allowed CORS origin(s) in the form of a comma-separated list (no spaces)", default=None)
|
||||
parser.add_argument("--cors-allow-origins-regex", type=str, help="Allowed CORS origin(s) in the form of a single regular expression", default=None)
|
||||
parser.add_argument("--tls-keyfile", type=str, help="Partially enables TLS, requires --tls-certfile to fully function", default=None)
|
||||
parser.add_argument("--tls-certfile", type=str, help="Partially enables TLS, requires --tls-keyfile to fully function", default=None)
|
||||
parser.add_argument("--server-name", type=str, help="Sets hostname of server", default=None)
|
||||
parser.add_argument("--gradio-queue", action='store_true', help="does not do anything", default=True)
|
||||
parser.add_argument("--no-gradio-queue", action='store_true', help="Disables gradio queue; causes the webpage to use http requests instead of websockets; was the defaul in earlier versions")
|
||||
parser.add_argument("--skip-version-check", action='store_true', help="Do not check versions of torch and xformers")
|
||||
parser.add_argument("--no-hashing", action='store_true', help="disable sha256 hashing of checkpoints to help loading performance", default=False)
|
||||
parser.add_argument("--no-download-sd-model", action='store_true', help="don't download SD1.5 model even if no model is found in --ckpt-dir", default=False)
|
||||
+6
-2
@@ -92,14 +92,18 @@ def cond_cast_float(input):
|
||||
|
||||
|
||||
def randn(seed, shape):
|
||||
from modules.shared import opts
|
||||
|
||||
torch.manual_seed(seed)
|
||||
if device.type == 'mps':
|
||||
if opts.randn_source == "CPU" or device.type == 'mps':
|
||||
return torch.randn(shape, device=cpu).to(device)
|
||||
return torch.randn(shape, device=device)
|
||||
|
||||
|
||||
def randn_without_seed(shape):
|
||||
if device.type == 'mps':
|
||||
from modules.shared import opts
|
||||
|
||||
if opts.randn_source == "CPU" or device.type == 'mps':
|
||||
return torch.randn(shape, device=cpu).to(device)
|
||||
return torch.randn(shape, device=device)
|
||||
|
||||
|
||||
+29
-12
@@ -5,17 +5,22 @@ import traceback
|
||||
import time
|
||||
import git
|
||||
|
||||
from modules import paths, shared
|
||||
from modules import shared
|
||||
from modules.paths_internal import extensions_dir, extensions_builtin_dir
|
||||
|
||||
extensions = []
|
||||
extensions_dir = os.path.join(paths.data_path, "extensions")
|
||||
extensions_builtin_dir = os.path.join(paths.script_path, "extensions-builtin")
|
||||
|
||||
if not os.path.exists(extensions_dir):
|
||||
os.makedirs(extensions_dir)
|
||||
|
||||
|
||||
def active():
|
||||
return [x for x in extensions if x.enabled]
|
||||
if shared.opts.disable_all_extensions == "all":
|
||||
return []
|
||||
elif shared.opts.disable_all_extensions == "extra":
|
||||
return [x for x in extensions if x.enabled and x.is_builtin]
|
||||
else:
|
||||
return [x for x in extensions if x.enabled]
|
||||
|
||||
|
||||
class Extension:
|
||||
@@ -27,21 +32,29 @@ class Extension:
|
||||
self.can_update = False
|
||||
self.is_builtin = is_builtin
|
||||
self.version = ''
|
||||
self.remote = None
|
||||
self.have_info_from_repo = False
|
||||
|
||||
def read_info_from_repo(self):
|
||||
if self.have_info_from_repo:
|
||||
return
|
||||
|
||||
self.have_info_from_repo = True
|
||||
|
||||
repo = None
|
||||
try:
|
||||
if os.path.exists(os.path.join(path, ".git")):
|
||||
repo = git.Repo(path)
|
||||
if os.path.exists(os.path.join(self.path, ".git")):
|
||||
repo = git.Repo(self.path)
|
||||
except Exception:
|
||||
print(f"Error reading github repository info from {path}:", file=sys.stderr)
|
||||
print(f"Error reading github repository info from {self.path}:", file=sys.stderr)
|
||||
print(traceback.format_exc(), file=sys.stderr)
|
||||
|
||||
if repo is None or repo.bare:
|
||||
self.remote = None
|
||||
else:
|
||||
try:
|
||||
self.remote = next(repo.remote().urls, None)
|
||||
self.status = 'unknown'
|
||||
self.remote = next(repo.remote().urls, None)
|
||||
head = repo.head.commit
|
||||
ts = time.asctime(time.gmtime(repo.head.commit.committed_date))
|
||||
self.version = f'{head.hexsha[:8]} ({ts})'
|
||||
@@ -89,7 +102,12 @@ def list_extensions():
|
||||
if not os.path.isdir(extensions_dir):
|
||||
return
|
||||
|
||||
paths = []
|
||||
if shared.opts.disable_all_extensions == "all":
|
||||
print("*** \"Disable all extensions\" option was set, will not load any extensions ***")
|
||||
elif shared.opts.disable_all_extensions == "extra":
|
||||
print("*** \"Disable all extensions\" option was set, will only load built-in extensions ***")
|
||||
|
||||
extension_paths = []
|
||||
for dirname in [extensions_dir, extensions_builtin_dir]:
|
||||
if not os.path.isdir(dirname):
|
||||
return
|
||||
@@ -99,9 +117,8 @@ def list_extensions():
|
||||
if not os.path.isdir(path):
|
||||
continue
|
||||
|
||||
paths.append((extension_dirname, path, dirname == extensions_builtin_dir))
|
||||
extension_paths.append((extension_dirname, path, dirname == extensions_builtin_dir))
|
||||
|
||||
for dirname, path, is_builtin in paths:
|
||||
for dirname, path, is_builtin in extension_paths:
|
||||
extension = Extension(name=dirname, path=path, enabled=dirname not in shared.opts.disabled_extensions, is_builtin=is_builtin)
|
||||
extensions.append(extension)
|
||||
|
||||
|
||||
@@ -9,7 +9,7 @@ class ExtraNetworkHypernet(extra_networks.ExtraNetwork):
|
||||
def activate(self, p, params_list):
|
||||
additional = shared.opts.sd_hypernetwork
|
||||
|
||||
if additional != "" and additional in shared.hypernetworks and len([x for x in params_list if x.items[0] == additional]) == 0:
|
||||
if additional != "None" and additional in shared.hypernetworks and len([x for x in params_list if x.items[0] == additional]) == 0:
|
||||
p.all_prompts = [x + f"<hypernet:{additional}:{shared.opts.extra_networks_default_multiplier}>" for x in p.all_prompts]
|
||||
params_list.append(extra_networks.ExtraNetworkParams(items=[additional, shared.opts.extra_networks_default_multiplier]))
|
||||
|
||||
|
||||
@@ -284,6 +284,10 @@ Steps: 20, Sampler: Euler a, CFG scale: 7, Seed: 965400086, Size: 512x512, Model
|
||||
|
||||
restore_old_hires_fix_params(res)
|
||||
|
||||
# Missing RNG means the default was set, which is GPU RNG
|
||||
if "RNG" not in res:
|
||||
res["RNG"] = "GPU"
|
||||
|
||||
return res
|
||||
|
||||
|
||||
@@ -304,6 +308,7 @@ infotext_to_setting_name_mapping = [
|
||||
('UniPC skip type', 'uni_pc_skip_type'),
|
||||
('UniPC order', 'uni_pc_order'),
|
||||
('UniPC lower order final', 'uni_pc_lower_order_final'),
|
||||
('RNG', 'randn_source'),
|
||||
]
|
||||
|
||||
|
||||
@@ -401,9 +406,14 @@ def connect_paste(button, paste_fields, input_comp, override_settings_component,
|
||||
|
||||
button.click(
|
||||
fn=paste_func,
|
||||
_js=f"recalculate_prompts_{tabname}",
|
||||
inputs=[input_comp],
|
||||
outputs=[x[0] for x in paste_fields],
|
||||
)
|
||||
button.click(
|
||||
fn=None,
|
||||
_js=f"recalculate_prompts_{tabname}",
|
||||
inputs=[],
|
||||
outputs=[],
|
||||
)
|
||||
|
||||
|
||||
|
||||
@@ -312,7 +312,7 @@ class Hypernetwork:
|
||||
|
||||
def list_hypernetworks(path):
|
||||
res = {}
|
||||
for filename in sorted(glob.iglob(os.path.join(path, '**/*.pt'), recursive=True)):
|
||||
for filename in sorted(glob.iglob(os.path.join(path, '**/*.pt'), recursive=True), key=str.lower):
|
||||
name = os.path.splitext(os.path.basename(filename))[0]
|
||||
# Prevent a hypothetical "None.pt" from being listed.
|
||||
if name != "None":
|
||||
|
||||
+9
-3
@@ -261,9 +261,12 @@ def resize_image(resize_mode, im, width, height, upscaler_name=None):
|
||||
|
||||
if scale > 1.0:
|
||||
upscalers = [x for x in shared.sd_upscalers if x.name == upscaler_name]
|
||||
assert len(upscalers) > 0, f"could not find upscaler named {upscaler_name}"
|
||||
if len(upscalers) == 0:
|
||||
upscaler = shared.sd_upscalers[0]
|
||||
print(f"could not find upscaler named {upscaler_name or '<empty string>'}, using {upscaler.name} as a fallback")
|
||||
else:
|
||||
upscaler = upscalers[0]
|
||||
|
||||
upscaler = upscalers[0]
|
||||
im = upscaler.scaler.upscale(im, scale, upscaler.data_path)
|
||||
|
||||
if im.width != w or im.height != h:
|
||||
@@ -350,6 +353,7 @@ class FilenameGenerator:
|
||||
'prompt_spaces': lambda self: sanitize_filename_part(self.prompt, replace_spaces=False),
|
||||
'prompt_words': lambda self: self.prompt_words(),
|
||||
'hasprompt': lambda self, *args: self.hasprompt(*args), #accept formats:[hasprompt<prompt1|default><prompt2>..]
|
||||
'clip_skip': lambda self: opts.data["CLIP_stop_at_last_layers"],
|
||||
}
|
||||
default_time_format = '%Y%m%d%H%M%S'
|
||||
|
||||
@@ -662,6 +666,8 @@ Steps: {json_info["steps"]}, Sampler: {sampler}, CFG scale: {json_info["scale"]}
|
||||
|
||||
|
||||
def image_data(data):
|
||||
import gradio as gr
|
||||
|
||||
try:
|
||||
image = Image.open(io.BytesIO(data))
|
||||
textinfo, _ = read_info_from_image(image)
|
||||
@@ -677,7 +683,7 @@ def image_data(data):
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
return '', None
|
||||
return gr.update(), None
|
||||
|
||||
|
||||
def flatten(img, bgcolor):
|
||||
|
||||
+3
-2
@@ -151,13 +151,14 @@ def img2img(id_task: str, mode: int, prompt: str, negative_prompt: str, prompt_s
|
||||
override_settings=override_settings,
|
||||
)
|
||||
|
||||
p.scripts = modules.scripts.scripts_txt2img
|
||||
p.scripts = modules.scripts.scripts_img2img
|
||||
p.script_args = args
|
||||
|
||||
if shared.cmd_opts.enable_console_prompts:
|
||||
print(f"\nimg2img: {prompt}", file=shared.progress_print_out)
|
||||
|
||||
p.extra_generation_params["Mask blur"] = mask_blur
|
||||
if mask:
|
||||
p.extra_generation_params["Mask blur"] = mask_blur
|
||||
|
||||
if is_batch:
|
||||
assert not shared.cmd_opts.hide_ui_dir_config, "Launched with --hide-ui-dir-config, batch img2img disabled"
|
||||
|
||||
@@ -32,7 +32,7 @@ def download_default_clip_interrogate_categories(content_dir):
|
||||
category_types = ["artists", "flavors", "mediums", "movements"]
|
||||
|
||||
try:
|
||||
os.makedirs(tmpdir)
|
||||
os.makedirs(tmpdir, exist_ok=True)
|
||||
for category_type in category_types:
|
||||
torch.hub.download_url_to_file(f"https://raw.githubusercontent.com/pharmapsychotic/clip-interrogator/main/clip_interrogator/data/{category_type}.txt", os.path.join(tmpdir, f"{category_type}.txt"))
|
||||
os.rename(tmpdir, content_dir)
|
||||
@@ -41,7 +41,7 @@ def download_default_clip_interrogate_categories(content_dir):
|
||||
errors.display(e, "downloading default CLIP interrogate categories")
|
||||
finally:
|
||||
if os.path.exists(tmpdir):
|
||||
os.remove(tmpdir)
|
||||
os.removedirs(tmpdir)
|
||||
|
||||
|
||||
class InterrogateModels:
|
||||
|
||||
+6
-4
@@ -55,12 +55,12 @@ def setup_for_low_vram(sd_model, use_medvram):
|
||||
if hasattr(sd_model.cond_stage_model, 'model'):
|
||||
sd_model.cond_stage_model.transformer = sd_model.cond_stage_model.model
|
||||
|
||||
# remove four big modules, cond, first_stage, depth (if applicable), and unet from the model and then
|
||||
# remove several big modules: cond, first_stage, depth/embedder (if applicable), and unet from the model and then
|
||||
# send the model to GPU. Then put modules back. the modules will be in CPU.
|
||||
stored = sd_model.cond_stage_model.transformer, sd_model.first_stage_model, getattr(sd_model, 'depth_model', None), sd_model.model
|
||||
sd_model.cond_stage_model.transformer, sd_model.first_stage_model, sd_model.depth_model, sd_model.model = None, None, None, None
|
||||
stored = sd_model.cond_stage_model.transformer, sd_model.first_stage_model, getattr(sd_model, 'depth_model', None), getattr(sd_model, 'embedder', None), sd_model.model
|
||||
sd_model.cond_stage_model.transformer, sd_model.first_stage_model, sd_model.depth_model, sd_model.embedder, sd_model.model = None, None, None, None, None
|
||||
sd_model.to(devices.device)
|
||||
sd_model.cond_stage_model.transformer, sd_model.first_stage_model, sd_model.depth_model, sd_model.model = stored
|
||||
sd_model.cond_stage_model.transformer, sd_model.first_stage_model, sd_model.depth_model, sd_model.embedder, sd_model.model = stored
|
||||
|
||||
# register hooks for those the first three models
|
||||
sd_model.cond_stage_model.transformer.register_forward_pre_hook(send_me_to_gpu)
|
||||
@@ -69,6 +69,8 @@ def setup_for_low_vram(sd_model, use_medvram):
|
||||
sd_model.first_stage_model.decode = first_stage_model_decode_wrap
|
||||
if sd_model.depth_model:
|
||||
sd_model.depth_model.register_forward_pre_hook(send_me_to_gpu)
|
||||
if sd_model.embedder:
|
||||
sd_model.embedder.register_forward_pre_hook(send_me_to_gpu)
|
||||
parents[sd_model.cond_stage_model.transformer] = sd_model.cond_stage_model
|
||||
|
||||
if hasattr(sd_model.cond_stage_model, 'model'):
|
||||
|
||||
@@ -1,4 +1,5 @@
|
||||
import torch
|
||||
import platform
|
||||
from modules import paths
|
||||
from modules.sd_hijack_utils import CondFunc
|
||||
from packaging import version
|
||||
@@ -32,6 +33,10 @@ if has_mps:
|
||||
# MPS fix for randn in torchsde
|
||||
CondFunc('torchsde._brownian.brownian_interval._randn', lambda _, size, dtype, device, seed: torch.randn(size, dtype=dtype, device=torch.device("cpu"), generator=torch.Generator(torch.device("cpu")).manual_seed(int(seed))).to(device), lambda _, size, dtype, device, seed: device.type == 'mps')
|
||||
|
||||
if platform.mac_ver()[0].startswith("13.2."):
|
||||
# MPS workaround for https://github.com/pytorch/pytorch/issues/95188, thanks to danieldk (https://github.com/explosion/curated-transformers/pull/124)
|
||||
CondFunc('torch.nn.functional.linear', lambda _, input, weight, bias: (torch.matmul(input, weight.t()) + bias) if bias is not None else torch.matmul(input, weight.t()), lambda _, input, weight, bias: input.numel() > 10485760)
|
||||
|
||||
if version.parse(torch.__version__) < version.parse("1.13"):
|
||||
# PyTorch 1.13 doesn't need these fixes but unfortunately is slower and has regressions that prevent training from working
|
||||
|
||||
@@ -49,4 +54,6 @@ if has_mps:
|
||||
CondFunc('torch.cumsum', cumsum_fix_func, None)
|
||||
CondFunc('torch.Tensor.cumsum', cumsum_fix_func, None)
|
||||
CondFunc('torch.narrow', lambda orig_func, *args, **kwargs: orig_func(*args, **kwargs).clone(), None)
|
||||
|
||||
if version.parse(torch.__version__) == version.parse("2.0"):
|
||||
# MPS workaround for https://github.com/pytorch/pytorch/issues/96113
|
||||
CondFunc('torch.nn.functional.layer_norm', lambda orig_func, x, normalized_shape, weight, bias, eps, **kwargs: orig_func(x.float(), normalized_shape, weight.float() if weight is not None else None, bias.float() if bias is not None else bias, eps).to(x.dtype), lambda *args, **kwargs: len(args) == 6)
|
||||
|
||||
@@ -4,7 +4,6 @@ import shutil
|
||||
import importlib
|
||||
from urllib.parse import urlparse
|
||||
|
||||
from basicsr.utils.download_util import load_file_from_url
|
||||
from modules import shared
|
||||
from modules.upscaler import Upscaler, UpscalerLanczos, UpscalerNearest, UpscalerNone
|
||||
from modules.paths import script_path, models_path
|
||||
@@ -59,6 +58,7 @@ def load_models(model_path: str, model_url: str = None, command_path: str = None
|
||||
|
||||
if model_url is not None and len(output) == 0:
|
||||
if download_name is not None:
|
||||
from basicsr.utils.download_util import load_file_from_url
|
||||
dl = load_file_from_url(model_url, model_path, True, download_name)
|
||||
output.append(dl)
|
||||
else:
|
||||
|
||||
+2
-9
@@ -1,16 +1,9 @@
|
||||
import argparse
|
||||
import os
|
||||
import sys
|
||||
from modules.paths_internal import models_path, script_path, data_path, extensions_dir, extensions_builtin_dir
|
||||
|
||||
import modules.safe
|
||||
|
||||
script_path = os.path.dirname(os.path.dirname(os.path.realpath(__file__)))
|
||||
|
||||
# Parse the --data-dir flag first so we can use it as a base for our other argument default values
|
||||
parser = argparse.ArgumentParser(add_help=False)
|
||||
parser.add_argument("--data-dir", type=str, default=os.path.dirname(os.path.dirname(os.path.realpath(__file__))), help="base path where all user data is stored",)
|
||||
cmd_opts_pre = parser.parse_known_args()[0]
|
||||
data_path = cmd_opts_pre.data_dir
|
||||
models_path = os.path.join(data_path, "models")
|
||||
|
||||
# data_path = cmd_opts_pre.data
|
||||
sys.path.insert(0, script_path)
|
||||
|
||||
@@ -0,0 +1,22 @@
|
||||
"""this module defines internal paths used by program and is safe to import before dependencies are installed in launch.py"""
|
||||
|
||||
import argparse
|
||||
import os
|
||||
|
||||
script_path = os.path.dirname(os.path.dirname(os.path.realpath(__file__)))
|
||||
|
||||
sd_configs_path = os.path.join(script_path, "configs")
|
||||
sd_default_config = os.path.join(sd_configs_path, "v1-inference.yaml")
|
||||
sd_model_file = os.path.join(script_path, 'model.ckpt')
|
||||
default_sd_model_file = sd_model_file
|
||||
|
||||
# Parse the --data-dir flag first so we can use it as a base for our other argument default values
|
||||
parser_pre = argparse.ArgumentParser(add_help=False)
|
||||
parser_pre.add_argument("--data-dir", type=str, default=os.path.dirname(os.path.dirname(os.path.realpath(__file__))), help="base path where all user data is stored",)
|
||||
cmd_opts_pre = parser_pre.parse_known_args()[0]
|
||||
|
||||
data_path = cmd_opts_pre.data_dir
|
||||
|
||||
models_path = os.path.join(data_path, "models")
|
||||
extensions_dir = os.path.join(data_path, "extensions")
|
||||
extensions_builtin_dir = os.path.join(script_path, "extensions-builtin")
|
||||
@@ -18,9 +18,15 @@ def run_postprocessing(extras_mode, image, image_folder, input_dir, output_dir,
|
||||
|
||||
if extras_mode == 1:
|
||||
for img in image_folder:
|
||||
image = Image.open(img)
|
||||
if isinstance(img, Image.Image):
|
||||
image = img
|
||||
fn = ''
|
||||
else:
|
||||
image = Image.open(os.path.abspath(img.name))
|
||||
fn = os.path.splitext(img.orig_name)[0]
|
||||
|
||||
image_data.append(image)
|
||||
image_names.append(os.path.splitext(img.orig_name)[0])
|
||||
image_names.append(fn)
|
||||
elif extras_mode == 2:
|
||||
assert not shared.cmd_opts.hide_ui_dir_config, '--hide-ui-dir-config option must be disabled'
|
||||
assert input_dir, 'input directory not selected'
|
||||
|
||||
+62
-14
@@ -3,6 +3,7 @@ import math
|
||||
import os
|
||||
import sys
|
||||
import warnings
|
||||
import hashlib
|
||||
|
||||
import torch
|
||||
import numpy as np
|
||||
@@ -78,22 +79,28 @@ def apply_overlay(image, paste_loc, index, overlays):
|
||||
|
||||
|
||||
def txt2img_image_conditioning(sd_model, x, width, height):
|
||||
if sd_model.model.conditioning_key not in {'hybrid', 'concat'}:
|
||||
# Dummy zero conditioning if we're not using inpainting model.
|
||||
if sd_model.model.conditioning_key in {'hybrid', 'concat'}: # Inpainting models
|
||||
|
||||
# The "masked-image" in this case will just be all zeros since the entire image is masked.
|
||||
image_conditioning = torch.zeros(x.shape[0], 3, height, width, device=x.device)
|
||||
image_conditioning = sd_model.get_first_stage_encoding(sd_model.encode_first_stage(image_conditioning))
|
||||
|
||||
# Add the fake full 1s mask to the first dimension.
|
||||
image_conditioning = torch.nn.functional.pad(image_conditioning, (0, 0, 0, 0, 1, 0), value=1.0)
|
||||
image_conditioning = image_conditioning.to(x.dtype)
|
||||
|
||||
return image_conditioning
|
||||
|
||||
elif sd_model.model.conditioning_key == "crossattn-adm": # UnCLIP models
|
||||
|
||||
return x.new_zeros(x.shape[0], 2*sd_model.noise_augmentor.time_embed.dim, dtype=x.dtype, device=x.device)
|
||||
|
||||
else:
|
||||
# Dummy zero conditioning if we're not using inpainting or unclip models.
|
||||
# Still takes up a bit of memory, but no encoder call.
|
||||
# Pretty sure we can just make this a 1x1 image since its not going to be used besides its batch size.
|
||||
return x.new_zeros(x.shape[0], 5, 1, 1, dtype=x.dtype, device=x.device)
|
||||
|
||||
# The "masked-image" in this case will just be all zeros since the entire image is masked.
|
||||
image_conditioning = torch.zeros(x.shape[0], 3, height, width, device=x.device)
|
||||
image_conditioning = sd_model.get_first_stage_encoding(sd_model.encode_first_stage(image_conditioning))
|
||||
|
||||
# Add the fake full 1s mask to the first dimension.
|
||||
image_conditioning = torch.nn.functional.pad(image_conditioning, (0, 0, 0, 0, 1, 0), value=1.0)
|
||||
image_conditioning = image_conditioning.to(x.dtype)
|
||||
|
||||
return image_conditioning
|
||||
|
||||
|
||||
class StableDiffusionProcessing:
|
||||
"""
|
||||
@@ -190,6 +197,14 @@ class StableDiffusionProcessing:
|
||||
|
||||
return conditioning_image
|
||||
|
||||
def unclip_image_conditioning(self, source_image):
|
||||
c_adm = self.sd_model.embedder(source_image)
|
||||
if self.sd_model.noise_augmentor is not None:
|
||||
noise_level = 0 # TODO: Allow other noise levels?
|
||||
c_adm, noise_level_emb = self.sd_model.noise_augmentor(c_adm, noise_level=repeat(torch.tensor([noise_level]).to(c_adm.device), '1 -> b', b=c_adm.shape[0]))
|
||||
c_adm = torch.cat((c_adm, noise_level_emb), 1)
|
||||
return c_adm
|
||||
|
||||
def inpainting_image_conditioning(self, source_image, latent_image, image_mask=None):
|
||||
self.is_using_inpainting_conditioning = True
|
||||
|
||||
@@ -241,6 +256,9 @@ class StableDiffusionProcessing:
|
||||
if self.sampler.conditioning_key in {'hybrid', 'concat'}:
|
||||
return self.inpainting_image_conditioning(source_image, latent_image, image_mask=image_mask)
|
||||
|
||||
if self.sampler.conditioning_key == "crossattn-adm":
|
||||
return self.unclip_image_conditioning(source_image)
|
||||
|
||||
# Dummy zero conditioning if we're not using inpainting or depth model.
|
||||
return latent_image.new_zeros(latent_image.shape[0], 5, 1, 1)
|
||||
|
||||
@@ -459,6 +477,8 @@ def create_infotext(p, all_prompts, all_seeds, all_subseeds, comments=None, iter
|
||||
"Conditional mask weight": getattr(p, "inpainting_mask_weight", shared.opts.inpainting_mask_weight) if p.is_using_inpainting_conditioning else None,
|
||||
"Clip skip": None if clip_skip <= 1 else clip_skip,
|
||||
"ENSD": None if opts.eta_noise_seed_delta == 0 else opts.eta_noise_seed_delta,
|
||||
"Init image hash": getattr(p, 'init_img_hash', None),
|
||||
"RNG": (opts.randn_source if opts.randn_source != "GPU" else None)
|
||||
}
|
||||
|
||||
generation_params.update(p.extra_generation_params)
|
||||
@@ -622,8 +642,14 @@ def process_images_inner(p: StableDiffusionProcessing) -> Processed:
|
||||
processed = Processed(p, [], p.seed, "")
|
||||
file.write(processed.infotext(p, 0))
|
||||
|
||||
uc = get_conds_with_caching(prompt_parser.get_learned_conditioning, negative_prompts, p.steps, cached_uc)
|
||||
c = get_conds_with_caching(prompt_parser.get_multicond_learned_conditioning, prompts, p.steps, cached_c)
|
||||
step_multiplier = 1
|
||||
if not shared.opts.dont_fix_second_order_samplers_schedule:
|
||||
try:
|
||||
step_multiplier = 2 if sd_samplers.all_samplers_map.get(p.sampler_name).aliases[0] in ['k_dpmpp_2s_a', 'k_dpmpp_2s_a_ka', 'k_dpmpp_sde', 'k_dpmpp_sde_ka', 'k_dpm_2', 'k_dpm_2_a', 'k_heun'] else 1
|
||||
except:
|
||||
pass
|
||||
uc = get_conds_with_caching(prompt_parser.get_learned_conditioning, negative_prompts, p.steps * step_multiplier, cached_uc)
|
||||
c = get_conds_with_caching(prompt_parser.get_multicond_learned_conditioning, prompts, p.steps * step_multiplier, cached_c)
|
||||
|
||||
if len(model_hijack.comments) > 0:
|
||||
for comment in model_hijack.comments:
|
||||
@@ -689,6 +715,22 @@ def process_images_inner(p: StableDiffusionProcessing) -> Processed:
|
||||
image.info["parameters"] = text
|
||||
output_images.append(image)
|
||||
|
||||
if hasattr(p, 'mask_for_overlay') and p.mask_for_overlay:
|
||||
image_mask = p.mask_for_overlay.convert('RGB')
|
||||
image_mask_composite = Image.composite(image.convert('RGBA').convert('RGBa'), Image.new('RGBa', image.size), p.mask_for_overlay.convert('L')).convert('RGBA')
|
||||
|
||||
if opts.save_mask:
|
||||
images.save_image(image_mask, p.outpath_samples, "", seeds[i], prompts[i], opts.samples_format, info=infotext(n, i), p=p, suffix="-mask")
|
||||
|
||||
if opts.save_mask_composite:
|
||||
images.save_image(image_mask_composite, p.outpath_samples, "", seeds[i], prompts[i], opts.samples_format, info=infotext(n, i), p=p, suffix="-mask-composite")
|
||||
|
||||
if opts.return_mask:
|
||||
output_images.append(image_mask)
|
||||
|
||||
if opts.return_mask_composite:
|
||||
output_images.append(image_mask_composite)
|
||||
|
||||
del x_samples_ddim
|
||||
|
||||
devices.torch_gc()
|
||||
@@ -974,6 +1016,12 @@ class StableDiffusionProcessingImg2Img(StableDiffusionProcessing):
|
||||
self.color_corrections = []
|
||||
imgs = []
|
||||
for img in self.init_images:
|
||||
|
||||
# Save init image
|
||||
if opts.save_init_img:
|
||||
self.init_img_hash = hashlib.md5(img.tobytes()).hexdigest()
|
||||
images.save_image(img, path=opts.outdir_init_images, basename=None, forced_filename=self.init_img_hash, save_to_dirs=False)
|
||||
|
||||
image = images.flatten(img, opts.img2img_background_color)
|
||||
|
||||
if crop_region is None and self.resize_mode != 3:
|
||||
|
||||
+1
-4
@@ -1,6 +1,5 @@
|
||||
# this code is adapted from the script contributed by anon from /h/
|
||||
|
||||
import io
|
||||
import pickle
|
||||
import collections
|
||||
import sys
|
||||
@@ -12,11 +11,9 @@ import _codecs
|
||||
import zipfile
|
||||
import re
|
||||
|
||||
|
||||
# PyTorch 1.13 and later have _TypedStorage renamed to TypedStorage
|
||||
TypedStorage = torch.storage.TypedStorage if hasattr(torch.storage, 'TypedStorage') else torch.storage._TypedStorage
|
||||
|
||||
|
||||
def encode(*args):
|
||||
out = _codecs.encode(*args)
|
||||
return out
|
||||
@@ -27,7 +24,7 @@ class RestrictedUnpickler(pickle.Unpickler):
|
||||
|
||||
def persistent_load(self, saved_id):
|
||||
assert saved_id[0] == 'storage'
|
||||
return TypedStorage()
|
||||
return TypedStorage(_internal=True)
|
||||
|
||||
def find_class(self, module, name):
|
||||
if self.extra_handler is not None:
|
||||
|
||||
+35
-1
@@ -239,7 +239,15 @@ def load_scripts():
|
||||
elif issubclass(script_class, scripts_postprocessing.ScriptPostprocessing):
|
||||
postprocessing_scripts_data.append(ScriptClassData(script_class, scriptfile.path, scriptfile.basedir, module))
|
||||
|
||||
for scriptfile in sorted(scripts_list):
|
||||
def orderby(basedir):
|
||||
# 1st webui, 2nd extensions-builtin, 3rd extensions
|
||||
priority = {os.path.join(paths.script_path, "extensions-builtin"):1, paths.script_path:0}
|
||||
for key in priority:
|
||||
if basedir.startswith(key):
|
||||
return priority[key]
|
||||
return 9999
|
||||
|
||||
for scriptfile in sorted(scripts_list, key=lambda x: [orderby(x.basedir), x]):
|
||||
try:
|
||||
if scriptfile.basedir != paths.script_path:
|
||||
sys.path = [scriptfile.basedir] + sys.path
|
||||
@@ -513,6 +521,18 @@ def reload_scripts():
|
||||
scripts_postproc = scripts_postprocessing.ScriptPostprocessingRunner()
|
||||
|
||||
|
||||
def add_classes_to_gradio_component(comp):
|
||||
"""
|
||||
this adds gradio-* to the component for css styling (ie gradio-button to gr.Button), as well as some others
|
||||
"""
|
||||
|
||||
comp.elem_classes = ["gradio-" + comp.get_block_name(), *(comp.elem_classes or [])]
|
||||
|
||||
if getattr(comp, 'multiselect', False):
|
||||
comp.elem_classes.append('multiselect')
|
||||
|
||||
|
||||
|
||||
def IOComponent_init(self, *args, **kwargs):
|
||||
if scripts_current is not None:
|
||||
scripts_current.before_component(self, **kwargs)
|
||||
@@ -521,6 +541,8 @@ def IOComponent_init(self, *args, **kwargs):
|
||||
|
||||
res = original_IOComponent_init(self, *args, **kwargs)
|
||||
|
||||
add_classes_to_gradio_component(self)
|
||||
|
||||
script_callbacks.after_component_callback(self, **kwargs)
|
||||
|
||||
if scripts_current is not None:
|
||||
@@ -531,3 +553,15 @@ def IOComponent_init(self, *args, **kwargs):
|
||||
|
||||
original_IOComponent_init = gr.components.IOComponent.__init__
|
||||
gr.components.IOComponent.__init__ = IOComponent_init
|
||||
|
||||
|
||||
def BlockContext_init(self, *args, **kwargs):
|
||||
res = original_BlockContext_init(self, *args, **kwargs)
|
||||
|
||||
add_classes_to_gradio_component(self)
|
||||
|
||||
return res
|
||||
|
||||
|
||||
original_BlockContext_init = gr.blocks.BlockContext.__init__
|
||||
gr.blocks.BlockContext.__init__ = BlockContext_init
|
||||
|
||||
@@ -109,7 +109,7 @@ class ScriptPostprocessingRunner:
|
||||
inputs = []
|
||||
|
||||
for script in self.scripts_in_preferred_order():
|
||||
with gr.Box() as group:
|
||||
with gr.Row() as group:
|
||||
self.create_script_ui(script, inputs)
|
||||
|
||||
script.group = group
|
||||
|
||||
@@ -337,7 +337,7 @@ def xformers_attention_forward(self, x, context=None, mask=None):
|
||||
|
||||
dtype = q.dtype
|
||||
if shared.opts.upcast_attn:
|
||||
q, k = q.float(), k.float()
|
||||
q, k, v = q.float(), k.float(), v.float()
|
||||
|
||||
out = xformers.ops.memory_efficient_attention(q, k, v, attn_bias=None, op=get_xformers_flash_attention_op(q, k, v))
|
||||
|
||||
@@ -372,7 +372,7 @@ def scaled_dot_product_attention_forward(self, x, context=None, mask=None):
|
||||
|
||||
dtype = q.dtype
|
||||
if shared.opts.upcast_attn:
|
||||
q, k = q.float(), k.float()
|
||||
q, k, v = q.float(), k.float(), v.float()
|
||||
|
||||
# the output of sdp = (batch, num_heads, seq_len, head_dim)
|
||||
hidden_states = torch.nn.functional.scaled_dot_product_attention(
|
||||
|
||||
@@ -67,7 +67,7 @@ def hijack_ddpm_edit():
|
||||
unet_needs_upcast = lambda *args, **kwargs: devices.unet_needs_upcast
|
||||
CondFunc('ldm.models.diffusion.ddpm.LatentDiffusion.apply_model', apply_model, unet_needs_upcast)
|
||||
CondFunc('ldm.modules.diffusionmodules.openaimodel.timestep_embedding', lambda orig_func, timesteps, *args, **kwargs: orig_func(timesteps, *args, **kwargs).to(torch.float32 if timesteps.dtype == torch.int64 else devices.dtype_unet), unet_needs_upcast)
|
||||
if version.parse(torch.__version__) <= version.parse("1.13.1"):
|
||||
if version.parse(torch.__version__) <= version.parse("1.13.2") or torch.cuda.is_available():
|
||||
CondFunc('ldm.modules.diffusionmodules.util.GroupNorm32.forward', lambda orig_func, self, *args, **kwargs: orig_func(self.float(), *args, **kwargs), unet_needs_upcast)
|
||||
CondFunc('ldm.modules.attention.GEGLU.forward', lambda orig_func, self, x: orig_func(self.float(), x.float()).to(devices.dtype_unet), unet_needs_upcast)
|
||||
CondFunc('open_clip.transformer.ResidualAttentionBlock.__init__', lambda orig_func, *args, **kwargs: kwargs.update({'act_layer': GELUHijack}) and False or orig_func(*args, **kwargs), lambda _, *args, **kwargs: kwargs.get('act_layer') is None or kwargs['act_layer'] == torch.nn.GELU)
|
||||
|
||||
+32
-4
@@ -122,7 +122,7 @@ def list_models():
|
||||
elif cmd_ckpt is not None and cmd_ckpt != shared.default_sd_model_file:
|
||||
print(f"Checkpoint in --ckpt argument not found (Possible it was moved to {model_path}: {cmd_ckpt}", file=sys.stderr)
|
||||
|
||||
for filename in model_list:
|
||||
for filename in sorted(model_list, key=str.lower):
|
||||
checkpoint_info = CheckpointInfo(filename)
|
||||
checkpoint_info.register()
|
||||
|
||||
@@ -178,7 +178,7 @@ def select_checkpoint():
|
||||
return checkpoint_info
|
||||
|
||||
|
||||
chckpoint_dict_replacements = {
|
||||
checkpoint_dict_replacements = {
|
||||
'cond_stage_model.transformer.embeddings.': 'cond_stage_model.transformer.text_model.embeddings.',
|
||||
'cond_stage_model.transformer.encoder.': 'cond_stage_model.transformer.text_model.encoder.',
|
||||
'cond_stage_model.transformer.final_layer_norm.': 'cond_stage_model.transformer.text_model.final_layer_norm.',
|
||||
@@ -186,7 +186,7 @@ chckpoint_dict_replacements = {
|
||||
|
||||
|
||||
def transform_checkpoint_dict_key(k):
|
||||
for text, replacement in chckpoint_dict_replacements.items():
|
||||
for text, replacement in checkpoint_dict_replacements.items():
|
||||
if k.startswith(text):
|
||||
k = replacement + k[len(text):]
|
||||
|
||||
@@ -383,6 +383,14 @@ def repair_config(sd_config):
|
||||
elif shared.cmd_opts.upcast_sampling:
|
||||
sd_config.model.params.unet_config.params.use_fp16 = True
|
||||
|
||||
if getattr(sd_config.model.params.first_stage_config.params.ddconfig, "attn_type", None) == "vanilla-xformers" and not shared.xformers_available:
|
||||
sd_config.model.params.first_stage_config.params.ddconfig.attn_type = "vanilla"
|
||||
|
||||
# For UnCLIP-L, override the hardcoded karlo directory
|
||||
if hasattr(sd_config.model.params, "noise_aug_config") and hasattr(sd_config.model.params.noise_aug_config.params, "clip_stats_path"):
|
||||
karlo_path = os.path.join(paths.models_path, 'karlo')
|
||||
sd_config.model.params.noise_aug_config.params.clip_stats_path = sd_config.model.params.noise_aug_config.params.clip_stats_path.replace("checkpoints/karlo_models", karlo_path)
|
||||
|
||||
|
||||
sd1_clip_weight = 'cond_stage_model.transformer.text_model.embeddings.token_embedding.weight'
|
||||
sd2_clip_weight = 'cond_stage_model.model.transformer.resblocks.0.attn.in_proj_weight'
|
||||
@@ -494,7 +502,7 @@ def reload_model_weights(sd_model=None, info=None):
|
||||
if sd_model is None or checkpoint_config != sd_model.used_config:
|
||||
del sd_model
|
||||
checkpoints_loaded.clear()
|
||||
load_model(checkpoint_info, already_loaded_state_dict=state_dict, time_taken_to_load_state_dict=timer.records["load weights from disk"])
|
||||
load_model(checkpoint_info, already_loaded_state_dict=state_dict)
|
||||
return shared.sd_model
|
||||
|
||||
try:
|
||||
@@ -517,3 +525,23 @@ def reload_model_weights(sd_model=None, info=None):
|
||||
print(f"Weights loaded in {timer.summary()}.")
|
||||
|
||||
return sd_model
|
||||
|
||||
def unload_model_weights(sd_model=None, info=None):
|
||||
from modules import lowvram, devices, sd_hijack
|
||||
timer = Timer()
|
||||
|
||||
if shared.sd_model:
|
||||
|
||||
# shared.sd_model.cond_stage_model.to(devices.cpu)
|
||||
# shared.sd_model.first_stage_model.to(devices.cpu)
|
||||
shared.sd_model.to(devices.cpu)
|
||||
sd_hijack.model_hijack.undo_hijack(shared.sd_model)
|
||||
shared.sd_model = None
|
||||
sd_model = None
|
||||
gc.collect()
|
||||
devices.torch_gc()
|
||||
torch.cuda.empty_cache()
|
||||
|
||||
print(f"Unloaded weights {timer.summary()}.")
|
||||
|
||||
return sd_model
|
||||
@@ -14,6 +14,8 @@ config_sd2 = os.path.join(sd_repo_configs_path, "v2-inference.yaml")
|
||||
config_sd2v = os.path.join(sd_repo_configs_path, "v2-inference-v.yaml")
|
||||
config_sd2_inpainting = os.path.join(sd_repo_configs_path, "v2-inpainting-inference.yaml")
|
||||
config_depth_model = os.path.join(sd_repo_configs_path, "v2-midas-inference.yaml")
|
||||
config_unclip = os.path.join(sd_repo_configs_path, "v2-1-stable-unclip-l-inference.yaml")
|
||||
config_unopenclip = os.path.join(sd_repo_configs_path, "v2-1-stable-unclip-h-inference.yaml")
|
||||
config_inpainting = os.path.join(sd_configs_path, "v1-inpainting-inference.yaml")
|
||||
config_instruct_pix2pix = os.path.join(sd_configs_path, "instruct-pix2pix.yaml")
|
||||
config_alt_diffusion = os.path.join(sd_configs_path, "alt-diffusion-inference.yaml")
|
||||
@@ -65,9 +67,14 @@ def is_using_v_parameterization_for_sd2(state_dict):
|
||||
def guess_model_config_from_state_dict(sd, filename):
|
||||
sd2_cond_proj_weight = sd.get('cond_stage_model.model.transformer.resblocks.0.attn.in_proj_weight', None)
|
||||
diffusion_model_input = sd.get('model.diffusion_model.input_blocks.0.0.weight', None)
|
||||
sd2_variations_weight = sd.get('embedder.model.ln_final.weight', None)
|
||||
|
||||
if sd.get('depth_model.model.pretrained.act_postprocess3.0.project.0.bias', None) is not None:
|
||||
return config_depth_model
|
||||
elif sd2_variations_weight is not None and sd2_variations_weight.shape[0] == 768:
|
||||
return config_unclip
|
||||
elif sd2_variations_weight is not None and sd2_variations_weight.shape[0] == 1024:
|
||||
return config_unopenclip
|
||||
|
||||
if sd2_cond_proj_weight is not None and sd2_cond_proj_weight.shape[1] == 1024:
|
||||
if diffusion_model_input.shape[1] == 9:
|
||||
|
||||
@@ -60,3 +60,13 @@ def store_latent(decoded):
|
||||
|
||||
class InterruptedException(BaseException):
|
||||
pass
|
||||
|
||||
|
||||
if opts.randn_source == "CPU":
|
||||
import torchsde._brownian.brownian_interval
|
||||
|
||||
def torchsde_randn(size, dtype, device, seed):
|
||||
generator = torch.Generator(devices.cpu).manual_seed(int(seed))
|
||||
return torch.randn(size, dtype=dtype, device=devices.cpu, generator=generator).to(device)
|
||||
|
||||
torchsde._brownian.brownian_interval._randn = torchsde_randn
|
||||
|
||||
@@ -70,8 +70,13 @@ class VanillaStableDiffusionSampler:
|
||||
|
||||
# Have to unwrap the inpainting conditioning here to perform pre-processing
|
||||
image_conditioning = None
|
||||
uc_image_conditioning = None
|
||||
if isinstance(cond, dict):
|
||||
image_conditioning = cond["c_concat"][0]
|
||||
if self.conditioning_key == "crossattn-adm":
|
||||
image_conditioning = cond["c_adm"]
|
||||
uc_image_conditioning = unconditional_conditioning["c_adm"]
|
||||
else:
|
||||
image_conditioning = cond["c_concat"][0]
|
||||
cond = cond["c_crossattn"][0]
|
||||
unconditional_conditioning = unconditional_conditioning["c_crossattn"][0]
|
||||
|
||||
@@ -98,8 +103,12 @@ class VanillaStableDiffusionSampler:
|
||||
# Wrap the image conditioning back up since the DDIM code can accept the dict directly.
|
||||
# Note that they need to be lists because it just concatenates them later.
|
||||
if image_conditioning is not None:
|
||||
cond = {"c_concat": [image_conditioning], "c_crossattn": [cond]}
|
||||
unconditional_conditioning = {"c_concat": [image_conditioning], "c_crossattn": [unconditional_conditioning]}
|
||||
if self.conditioning_key == "crossattn-adm":
|
||||
cond = {"c_adm": image_conditioning, "c_crossattn": [cond]}
|
||||
unconditional_conditioning = {"c_adm": uc_image_conditioning, "c_crossattn": [unconditional_conditioning]}
|
||||
else:
|
||||
cond = {"c_concat": [image_conditioning], "c_crossattn": [cond]}
|
||||
unconditional_conditioning = {"c_concat": [image_conditioning], "c_crossattn": [unconditional_conditioning]}
|
||||
|
||||
return x, ts, cond, unconditional_conditioning
|
||||
|
||||
@@ -176,8 +185,12 @@ class VanillaStableDiffusionSampler:
|
||||
|
||||
# Wrap the conditioning models with additional image conditioning for inpainting model
|
||||
if image_conditioning is not None:
|
||||
conditioning = {"c_concat": [image_conditioning], "c_crossattn": [conditioning]}
|
||||
unconditional_conditioning = {"c_concat": [image_conditioning], "c_crossattn": [unconditional_conditioning]}
|
||||
if self.conditioning_key == "crossattn-adm":
|
||||
conditioning = {"c_adm": image_conditioning, "c_crossattn": [conditioning]}
|
||||
unconditional_conditioning = {"c_adm": torch.zeros_like(image_conditioning), "c_crossattn": [unconditional_conditioning]}
|
||||
else:
|
||||
conditioning = {"c_concat": [image_conditioning], "c_crossattn": [conditioning]}
|
||||
unconditional_conditioning = {"c_concat": [image_conditioning], "c_crossattn": [unconditional_conditioning]}
|
||||
|
||||
samples = self.launch_sampling(t_enc + 1, lambda: self.sampler.decode(x1, conditioning, t_enc, unconditional_guidance_scale=p.cfg_scale, unconditional_conditioning=unconditional_conditioning))
|
||||
|
||||
@@ -195,8 +208,12 @@ class VanillaStableDiffusionSampler:
|
||||
# Wrap the conditioning models with additional image conditioning for inpainting model
|
||||
# dummy_for_plms is needed because PLMS code checks the first item in the dict to have the right shape
|
||||
if image_conditioning is not None:
|
||||
conditioning = {"dummy_for_plms": np.zeros((conditioning.shape[0],)), "c_crossattn": [conditioning], "c_concat": [image_conditioning]}
|
||||
unconditional_conditioning = {"c_crossattn": [unconditional_conditioning], "c_concat": [image_conditioning]}
|
||||
if self.conditioning_key == "crossattn-adm":
|
||||
conditioning = {"dummy_for_plms": np.zeros((conditioning.shape[0],)), "c_crossattn": [conditioning], "c_adm": image_conditioning}
|
||||
unconditional_conditioning = {"c_crossattn": [unconditional_conditioning], "c_adm": torch.zeros_like(image_conditioning)}
|
||||
else:
|
||||
conditioning = {"dummy_for_plms": np.zeros((conditioning.shape[0],)), "c_crossattn": [conditioning], "c_concat": [image_conditioning]}
|
||||
unconditional_conditioning = {"c_crossattn": [unconditional_conditioning], "c_concat": [image_conditioning]}
|
||||
|
||||
samples_ddim = self.launch_sampling(steps, lambda: self.sampler.sample(S=steps, conditioning=conditioning, batch_size=int(x.shape[0]), shape=x[0].shape, verbose=False, unconditional_guidance_scale=p.cfg_scale, unconditional_conditioning=unconditional_conditioning, x_T=x, eta=self.eta)[0])
|
||||
|
||||
|
||||
@@ -92,14 +92,21 @@ class CFGDenoiser(torch.nn.Module):
|
||||
batch_size = len(conds_list)
|
||||
repeats = [len(conds_list[i]) for i in range(batch_size)]
|
||||
|
||||
if shared.sd_model.model.conditioning_key == "crossattn-adm":
|
||||
image_uncond = torch.zeros_like(image_cond)
|
||||
make_condition_dict = lambda c_crossattn, c_adm: {"c_crossattn": c_crossattn, "c_adm": c_adm}
|
||||
else:
|
||||
image_uncond = image_cond
|
||||
make_condition_dict = lambda c_crossattn, c_concat: {"c_crossattn": c_crossattn, "c_concat": [c_concat]}
|
||||
|
||||
if not is_edit_model:
|
||||
x_in = torch.cat([torch.stack([x[i] for _ in range(n)]) for i, n in enumerate(repeats)] + [x])
|
||||
sigma_in = torch.cat([torch.stack([sigma[i] for _ in range(n)]) for i, n in enumerate(repeats)] + [sigma])
|
||||
image_cond_in = torch.cat([torch.stack([image_cond[i] for _ in range(n)]) for i, n in enumerate(repeats)] + [image_cond])
|
||||
image_cond_in = torch.cat([torch.stack([image_cond[i] for _ in range(n)]) for i, n in enumerate(repeats)] + [image_uncond])
|
||||
else:
|
||||
x_in = torch.cat([torch.stack([x[i] for _ in range(n)]) for i, n in enumerate(repeats)] + [x] + [x])
|
||||
sigma_in = torch.cat([torch.stack([sigma[i] for _ in range(n)]) for i, n in enumerate(repeats)] + [sigma] + [sigma])
|
||||
image_cond_in = torch.cat([torch.stack([image_cond[i] for _ in range(n)]) for i, n in enumerate(repeats)] + [image_cond] + [torch.zeros_like(self.init_latent)])
|
||||
image_cond_in = torch.cat([torch.stack([image_cond[i] for _ in range(n)]) for i, n in enumerate(repeats)] + [image_uncond] + [torch.zeros_like(self.init_latent)])
|
||||
|
||||
denoiser_params = CFGDenoiserParams(x_in, image_cond_in, sigma_in, state.sampling_step, state.sampling_steps, tensor, uncond)
|
||||
cfg_denoiser_callback(denoiser_params)
|
||||
@@ -116,13 +123,13 @@ class CFGDenoiser(torch.nn.Module):
|
||||
cond_in = torch.cat([tensor, uncond, uncond])
|
||||
|
||||
if shared.batch_cond_uncond:
|
||||
x_out = self.inner_model(x_in, sigma_in, cond={"c_crossattn": [cond_in], "c_concat": [image_cond_in]})
|
||||
x_out = self.inner_model(x_in, sigma_in, cond=make_condition_dict([cond_in], image_cond_in))
|
||||
else:
|
||||
x_out = torch.zeros_like(x_in)
|
||||
for batch_offset in range(0, x_out.shape[0], batch_size):
|
||||
a = batch_offset
|
||||
b = a + batch_size
|
||||
x_out[a:b] = self.inner_model(x_in[a:b], sigma_in[a:b], cond={"c_crossattn": [cond_in[a:b]], "c_concat": [image_cond_in[a:b]]})
|
||||
x_out[a:b] = self.inner_model(x_in[a:b], sigma_in[a:b], cond=make_condition_dict([cond_in[a:b]], image_cond_in[a:b]))
|
||||
else:
|
||||
x_out = torch.zeros_like(x_in)
|
||||
batch_size = batch_size*2 if shared.batch_cond_uncond else batch_size
|
||||
@@ -135,9 +142,9 @@ class CFGDenoiser(torch.nn.Module):
|
||||
else:
|
||||
c_crossattn = torch.cat([tensor[a:b]], uncond)
|
||||
|
||||
x_out[a:b] = self.inner_model(x_in[a:b], sigma_in[a:b], cond={"c_crossattn": c_crossattn, "c_concat": [image_cond_in[a:b]]})
|
||||
x_out[a:b] = self.inner_model(x_in[a:b], sigma_in[a:b], cond=make_condition_dict(c_crossattn, image_cond_in[a:b]))
|
||||
|
||||
x_out[-uncond.shape[0]:] = self.inner_model(x_in[-uncond.shape[0]:], sigma_in[-uncond.shape[0]:], cond={"c_crossattn": [uncond], "c_concat": [image_cond_in[-uncond.shape[0]:]]})
|
||||
x_out[-uncond.shape[0]:] = self.inner_model(x_in[-uncond.shape[0]:], sigma_in[-uncond.shape[0]:], cond=make_condition_dict([uncond], image_cond_in[-uncond.shape[0]:]))
|
||||
|
||||
denoised_params = CFGDenoisedParams(x_out, state.sampling_step, state.sampling_steps)
|
||||
cfg_denoised_callback(denoised_params)
|
||||
@@ -183,7 +190,7 @@ class TorchHijack:
|
||||
if noise.shape == x.shape:
|
||||
return noise
|
||||
|
||||
if x.device.type == 'mps':
|
||||
if opts.randn_source == "CPU" or x.device.type == 'mps':
|
||||
return torch.randn_like(x, device=devices.cpu).to(x.device)
|
||||
else:
|
||||
return torch.randn_like(x)
|
||||
|
||||
+59
-101
@@ -4,6 +4,7 @@ import json
|
||||
import os
|
||||
import sys
|
||||
import time
|
||||
import requests
|
||||
|
||||
from PIL import Image
|
||||
import gradio as gr
|
||||
@@ -13,114 +14,22 @@ import modules.interrogate
|
||||
import modules.memmon
|
||||
import modules.styles
|
||||
import modules.devices as devices
|
||||
from modules import localization, extensions, script_loading, errors, ui_components, shared_items
|
||||
from modules.paths import models_path, script_path, data_path
|
||||
|
||||
from modules import localization, script_loading, errors, ui_components, shared_items, cmd_args
|
||||
from modules.paths_internal import models_path, script_path, data_path, sd_configs_path, sd_default_config, sd_model_file, default_sd_model_file, extensions_dir, extensions_builtin_dir
|
||||
|
||||
demo = None
|
||||
|
||||
sd_configs_path = os.path.join(script_path, "configs")
|
||||
sd_default_config = os.path.join(sd_configs_path, "v1-inference.yaml")
|
||||
sd_model_file = os.path.join(script_path, 'model.ckpt')
|
||||
default_sd_model_file = sd_model_file
|
||||
parser = cmd_args.parser
|
||||
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument("--data-dir", type=str, default=os.path.dirname(os.path.dirname(os.path.realpath(__file__))), help="base path where all user data is stored",)
|
||||
parser.add_argument("--config", type=str, default=sd_default_config, help="path to config which constructs model",)
|
||||
parser.add_argument("--ckpt", type=str, default=sd_model_file, help="path to checkpoint of stable diffusion model; if specified, this checkpoint will be added to the list of checkpoints and loaded",)
|
||||
parser.add_argument("--ckpt-dir", type=str, default=None, help="Path to directory with stable diffusion checkpoints")
|
||||
parser.add_argument("--vae-dir", type=str, default=None, help="Path to directory with VAE files")
|
||||
parser.add_argument("--gfpgan-dir", type=str, help="GFPGAN directory", default=('./src/gfpgan' if os.path.exists('./src/gfpgan') else './GFPGAN'))
|
||||
parser.add_argument("--gfpgan-model", type=str, help="GFPGAN model file name", default=None)
|
||||
parser.add_argument("--no-half", action='store_true', help="do not switch the model to 16-bit floats")
|
||||
parser.add_argument("--no-half-vae", action='store_true', help="do not switch the VAE model to 16-bit floats")
|
||||
parser.add_argument("--no-progressbar-hiding", action='store_true', help="do not hide progressbar in gradio UI (we hide it because it slows down ML if you have hardware acceleration in browser)")
|
||||
parser.add_argument("--max-batch-count", type=int, default=16, help="maximum batch count value for the UI")
|
||||
parser.add_argument("--embeddings-dir", type=str, default=os.path.join(data_path, 'embeddings'), help="embeddings directory for textual inversion (default: embeddings)")
|
||||
parser.add_argument("--textual-inversion-templates-dir", type=str, default=os.path.join(script_path, 'textual_inversion_templates'), help="directory with textual inversion templates")
|
||||
parser.add_argument("--hypernetwork-dir", type=str, default=os.path.join(models_path, 'hypernetworks'), help="hypernetwork directory")
|
||||
parser.add_argument("--localizations-dir", type=str, default=os.path.join(script_path, 'localizations'), help="localizations directory")
|
||||
parser.add_argument("--allow-code", action='store_true', help="allow custom script execution from webui")
|
||||
parser.add_argument("--medvram", action='store_true', help="enable stable diffusion model optimizations for sacrificing a little speed for low VRM usage")
|
||||
parser.add_argument("--lowvram", action='store_true', help="enable stable diffusion model optimizations for sacrificing a lot of speed for very low VRM usage")
|
||||
parser.add_argument("--lowram", action='store_true', help="load stable diffusion checkpoint weights to VRAM instead of RAM")
|
||||
parser.add_argument("--always-batch-cond-uncond", action='store_true', help="disables cond/uncond batching that is enabled to save memory with --medvram or --lowvram")
|
||||
parser.add_argument("--unload-gfpgan", action='store_true', help="does not do anything.")
|
||||
parser.add_argument("--precision", type=str, help="evaluate at this precision", choices=["full", "autocast"], default="autocast")
|
||||
parser.add_argument("--upcast-sampling", action='store_true', help="upcast sampling. No effect with --no-half. Usually produces similar results to --no-half with better performance while using less memory.")
|
||||
parser.add_argument("--share", action='store_true', help="use share=True for gradio and make the UI accessible through their site")
|
||||
parser.add_argument("--ngrok", type=str, help="ngrok authtoken, alternative to gradio --share", default=None)
|
||||
parser.add_argument("--ngrok-region", type=str, help="The region in which ngrok should start.", default="us")
|
||||
parser.add_argument("--enable-insecure-extension-access", action='store_true', help="enable extensions tab regardless of other options")
|
||||
parser.add_argument("--codeformer-models-path", type=str, help="Path to directory with codeformer model file(s).", default=os.path.join(models_path, 'Codeformer'))
|
||||
parser.add_argument("--gfpgan-models-path", type=str, help="Path to directory with GFPGAN model file(s).", default=os.path.join(models_path, 'GFPGAN'))
|
||||
parser.add_argument("--esrgan-models-path", type=str, help="Path to directory with ESRGAN model file(s).", default=os.path.join(models_path, 'ESRGAN'))
|
||||
parser.add_argument("--bsrgan-models-path", type=str, help="Path to directory with BSRGAN model file(s).", default=os.path.join(models_path, 'BSRGAN'))
|
||||
parser.add_argument("--realesrgan-models-path", type=str, help="Path to directory with RealESRGAN model file(s).", default=os.path.join(models_path, 'RealESRGAN'))
|
||||
parser.add_argument("--clip-models-path", type=str, help="Path to directory with CLIP model file(s).", default=None)
|
||||
parser.add_argument("--xformers", action='store_true', help="enable xformers for cross attention layers")
|
||||
parser.add_argument("--force-enable-xformers", action='store_true', help="enable xformers for cross attention layers regardless of whether the checking code thinks you can run it; do not make bug reports if this fails to work")
|
||||
parser.add_argument("--xformers-flash-attention", action='store_true', help="enable xformers with Flash Attention to improve reproducibility (supported for SD2.x or variant only)")
|
||||
parser.add_argument("--deepdanbooru", action='store_true', help="does not do anything")
|
||||
parser.add_argument("--opt-split-attention", action='store_true', help="force-enables Doggettx's cross-attention layer optimization. By default, it's on for torch cuda.")
|
||||
parser.add_argument("--opt-sub-quad-attention", action='store_true', help="enable memory efficient sub-quadratic cross-attention layer optimization")
|
||||
parser.add_argument("--sub-quad-q-chunk-size", type=int, help="query chunk size for the sub-quadratic cross-attention layer optimization to use", default=1024)
|
||||
parser.add_argument("--sub-quad-kv-chunk-size", type=int, help="kv chunk size for the sub-quadratic cross-attention layer optimization to use", default=None)
|
||||
parser.add_argument("--sub-quad-chunk-threshold", type=int, help="the percentage of VRAM threshold for the sub-quadratic cross-attention layer optimization to use chunking", default=None)
|
||||
parser.add_argument("--opt-split-attention-invokeai", action='store_true', help="force-enables InvokeAI's cross-attention layer optimization. By default, it's on when cuda is unavailable.")
|
||||
parser.add_argument("--opt-split-attention-v1", action='store_true', help="enable older version of split attention optimization that does not consume all the VRAM it can find")
|
||||
parser.add_argument("--opt-sdp-attention", action='store_true', help="enable scaled dot product cross-attention layer optimization; requires PyTorch 2.*")
|
||||
parser.add_argument("--opt-sdp-no-mem-attention", action='store_true', help="enable scaled dot product cross-attention layer optimization without memory efficient attention, makes image generation deterministic; requires PyTorch 2.*")
|
||||
parser.add_argument("--disable-opt-split-attention", action='store_true', help="force-disables cross-attention layer optimization")
|
||||
parser.add_argument("--disable-nan-check", action='store_true', help="do not check if produced images/latent spaces have nans; useful for running without a checkpoint in CI")
|
||||
parser.add_argument("--use-cpu", nargs='+', help="use CPU as torch device for specified modules", default=[], type=str.lower)
|
||||
parser.add_argument("--listen", action='store_true', help="launch gradio with 0.0.0.0 as server name, allowing to respond to network requests")
|
||||
parser.add_argument("--port", type=int, help="launch gradio with given server port, you need root/admin rights for ports < 1024, defaults to 7860 if available", default=None)
|
||||
parser.add_argument("--show-negative-prompt", action='store_true', help="does not do anything", default=False)
|
||||
parser.add_argument("--ui-config-file", type=str, help="filename to use for ui configuration", default=os.path.join(data_path, 'ui-config.json'))
|
||||
parser.add_argument("--hide-ui-dir-config", action='store_true', help="hide directory configuration from webui", default=False)
|
||||
parser.add_argument("--freeze-settings", action='store_true', help="disable editing settings", default=False)
|
||||
parser.add_argument("--ui-settings-file", type=str, help="filename to use for ui settings", default=os.path.join(data_path, 'config.json'))
|
||||
parser.add_argument("--gradio-debug", action='store_true', help="launch gradio with --debug option")
|
||||
parser.add_argument("--gradio-auth", type=str, help='set gradio authentication like "username:password"; or comma-delimit multiple like "u1:p1,u2:p2,u3:p3"', default=None)
|
||||
parser.add_argument("--gradio-auth-path", type=str, help='set gradio authentication file path ex. "/path/to/auth/file" same auth format as --gradio-auth', default=None)
|
||||
parser.add_argument("--gradio-img2img-tool", type=str, help='does not do anything')
|
||||
parser.add_argument("--gradio-inpaint-tool", type=str, help="does not do anything")
|
||||
parser.add_argument("--opt-channelslast", action='store_true', help="change memory type for stable diffusion to channels last")
|
||||
parser.add_argument("--styles-file", type=str, help="filename to use for styles", default=os.path.join(data_path, 'styles.csv'))
|
||||
parser.add_argument("--autolaunch", action='store_true', help="open the webui URL in the system's default browser upon launch", default=False)
|
||||
parser.add_argument("--theme", type=str, help="launches the UI with light or dark theme", default=None)
|
||||
parser.add_argument("--use-textbox-seed", action='store_true', help="use textbox for seeds in UI (no up/down, but possible to input long seeds)", default=False)
|
||||
parser.add_argument("--disable-console-progressbars", action='store_true', help="do not output progressbars to console", default=False)
|
||||
parser.add_argument("--enable-console-prompts", action='store_true', help="print prompts to console when generating with txt2img and img2img", default=False)
|
||||
parser.add_argument('--vae-path', type=str, help='Checkpoint to use as VAE; setting this argument disables all settings related to VAE', default=None)
|
||||
parser.add_argument("--disable-safe-unpickle", action='store_true', help="disable checking pytorch models for malicious code", default=False)
|
||||
parser.add_argument("--api", action='store_true', help="use api=True to launch the API together with the webui (use --nowebui instead for only the API)")
|
||||
parser.add_argument("--api-auth", type=str, help='Set authentication for API like "username:password"; or comma-delimit multiple like "u1:p1,u2:p2,u3:p3"', default=None)
|
||||
parser.add_argument("--api-log", action='store_true', help="use api-log=True to enable logging of all API requests")
|
||||
parser.add_argument("--nowebui", action='store_true', help="use api=True to launch the API instead of the webui")
|
||||
parser.add_argument("--ui-debug-mode", action='store_true', help="Don't load model to quickly launch UI")
|
||||
parser.add_argument("--device-id", type=str, help="Select the default CUDA device to use (export CUDA_VISIBLE_DEVICES=0,1,etc might be needed before)", default=None)
|
||||
parser.add_argument("--administrator", action='store_true', help="Administrator rights", default=False)
|
||||
parser.add_argument("--cors-allow-origins", type=str, help="Allowed CORS origin(s) in the form of a comma-separated list (no spaces)", default=None)
|
||||
parser.add_argument("--cors-allow-origins-regex", type=str, help="Allowed CORS origin(s) in the form of a single regular expression", default=None)
|
||||
parser.add_argument("--tls-keyfile", type=str, help="Partially enables TLS, requires --tls-certfile to fully function", default=None)
|
||||
parser.add_argument("--tls-certfile", type=str, help="Partially enables TLS, requires --tls-keyfile to fully function", default=None)
|
||||
parser.add_argument("--server-name", type=str, help="Sets hostname of server", default=None)
|
||||
parser.add_argument("--gradio-queue", action='store_true', help="Uses gradio queue; experimental option; breaks restart UI button")
|
||||
parser.add_argument("--skip-version-check", action='store_true', help="Do not check versions of torch and xformers")
|
||||
parser.add_argument("--no-hashing", action='store_true', help="disable sha256 hashing of checkpoints to help loading performance", default=False)
|
||||
parser.add_argument("--no-download-sd-model", action='store_true', help="don't download SD1.5 model even if no model is found in --ckpt-dir", default=False)
|
||||
|
||||
|
||||
script_loading.preload_extensions(extensions.extensions_dir, parser)
|
||||
script_loading.preload_extensions(extensions.extensions_builtin_dir, parser)
|
||||
script_loading.preload_extensions(extensions_dir, parser)
|
||||
script_loading.preload_extensions(extensions_builtin_dir, parser)
|
||||
|
||||
if os.environ.get('IGNORE_CMD_ARGS_ERRORS', None) is None:
|
||||
cmd_opts = parser.parse_args()
|
||||
else:
|
||||
cmd_opts, _ = parser.parse_known_args()
|
||||
|
||||
|
||||
restricted_opts = {
|
||||
"samples_filename_pattern",
|
||||
"directories_filename_pattern",
|
||||
@@ -131,6 +40,7 @@ restricted_opts = {
|
||||
"outdir_grids",
|
||||
"outdir_txt2img_grids",
|
||||
"outdir_save",
|
||||
"outdir_init_images"
|
||||
}
|
||||
|
||||
ui_reorder_categories = [
|
||||
@@ -146,6 +56,21 @@ ui_reorder_categories = [
|
||||
"scripts",
|
||||
]
|
||||
|
||||
# https://huggingface.co/datasets/freddyaboulton/gradio-theme-subdomains/resolve/main/subdomains.json
|
||||
gradio_hf_hub_themes = [
|
||||
"gradio/glass",
|
||||
"gradio/monochrome",
|
||||
"gradio/seafoam",
|
||||
"gradio/soft",
|
||||
"freddyaboulton/dracula_revamped",
|
||||
"gradio/dracula_test",
|
||||
"abidlabs/dracula_test",
|
||||
"abidlabs/pakistan",
|
||||
"dawood/microsoft_windows",
|
||||
"ysharma/steampunk"
|
||||
]
|
||||
|
||||
|
||||
cmd_opts.disable_extension_access = (cmd_opts.share or cmd_opts.listen or cmd_opts.server_name) and not cmd_opts.enable_insecure_extension_access
|
||||
|
||||
devices.device, devices.device_interrogate, devices.device_gfpgan, devices.device_esrgan, devices.device_codeformer = \
|
||||
@@ -332,6 +257,8 @@ options_templates.update(options_section(('saving-images', "Saving images/grids"
|
||||
"save_images_before_face_restoration": OptionInfo(False, "Save a copy of image before doing face restoration."),
|
||||
"save_images_before_highres_fix": OptionInfo(False, "Save a copy of image before applying highres fix."),
|
||||
"save_images_before_color_correction": OptionInfo(False, "Save a copy of image before applying color correction to img2img results"),
|
||||
"save_mask": OptionInfo(False, "For inpainting, save a copy of the greyscale mask"),
|
||||
"save_mask_composite": OptionInfo(False, "For inpainting, save a masked composite"),
|
||||
"jpeg_quality": OptionInfo(80, "Quality for saved jpeg images", gr.Slider, {"minimum": 1, "maximum": 100, "step": 1}),
|
||||
"webp_lossless": OptionInfo(False, "Use lossless compression for webp images"),
|
||||
"export_for_4chan": OptionInfo(True, "If the saved image file size is above the limit, or its either width or height are above the limit, save a downscaled copy as JPG"),
|
||||
@@ -343,6 +270,7 @@ options_templates.update(options_section(('saving-images', "Saving images/grids"
|
||||
"use_upscaler_name_as_suffix": OptionInfo(False, "Use upscaler name as filename suffix in the extras tab"),
|
||||
"save_selected_only": OptionInfo(True, "When using 'Save' button, only save a single selected image"),
|
||||
"do_not_add_watermark": OptionInfo(False, "Do not add watermark to images"),
|
||||
"save_init_img": OptionInfo(False, "Save init images when using img2img"),
|
||||
|
||||
"temp_dir": OptionInfo("", "Directory for temporary images; leave empty for default"),
|
||||
"clean_temp_dir_at_start": OptionInfo(False, "Cleanup non-default temporary directory when starting webui"),
|
||||
@@ -358,6 +286,7 @@ options_templates.update(options_section(('saving-paths', "Paths for saving"), {
|
||||
"outdir_txt2img_grids": OptionInfo("outputs/txt2img-grids", 'Output directory for txt2img grids', component_args=hide_dirs),
|
||||
"outdir_img2img_grids": OptionInfo("outputs/img2img-grids", 'Output directory for img2img grids', component_args=hide_dirs),
|
||||
"outdir_save": OptionInfo("log/images", "Directory for saving images using the Save button", component_args=hide_dirs),
|
||||
"outdir_init_images": OptionInfo("outputs/init-images", "Directory for saving init images when using img2img", component_args=hide_dirs),
|
||||
}))
|
||||
|
||||
options_templates.update(options_section(('saving-to-dirs', "Saving to a directory"), {
|
||||
@@ -373,6 +302,8 @@ options_templates.update(options_section(('upscaling', "Upscaling"), {
|
||||
"ESRGAN_tile_overlap": OptionInfo(8, "Tile overlap, in pixels for ESRGAN upscalers. Low values = visible seam.", gr.Slider, {"minimum": 0, "maximum": 48, "step": 1}),
|
||||
"realesrgan_enabled_models": OptionInfo(["R-ESRGAN 4x+", "R-ESRGAN 4x+ Anime6B"], "Select which Real-ESRGAN models to show in the web UI. (Requires restart)", gr.CheckboxGroup, lambda: {"choices": shared_items.realesrgan_models_names()}),
|
||||
"upscaler_for_img2img": OptionInfo(None, "Upscaler for img2img", gr.Dropdown, lambda: {"choices": [x.name for x in sd_upscalers]}),
|
||||
"SCUNET_tile": OptionInfo(256, "Tile size for SCUNET upscalers. 0 = no tiling.", gr.Slider, {"minimum": 0, "maximum": 512, "step": 16}),
|
||||
"SCUNET_tile_overlap": OptionInfo(8, "Tile overlap, in pixels for SCUNET upscalers. Low values = visible seam.", gr.Slider, {"minimum": 0, "maximum": 64, "step": 1}),
|
||||
}))
|
||||
|
||||
options_templates.update(options_section(('face-restoration', "Face restoration"), {
|
||||
@@ -421,6 +352,7 @@ options_templates.update(options_section(('sd', "Stable Diffusion"), {
|
||||
"comma_padding_backtrack": OptionInfo(20, "Increase coherency by padding from the last comma within n tokens when using more than 75 tokens", gr.Slider, {"minimum": 0, "maximum": 74, "step": 1 }),
|
||||
"CLIP_stop_at_last_layers": OptionInfo(1, "Clip skip", gr.Slider, {"minimum": 1, "maximum": 12, "step": 1}),
|
||||
"upcast_attn": OptionInfo(False, "Upcast cross attention layer to float32"),
|
||||
"randn_source": OptionInfo("GPU", "Random number generator source. Changes seeds drastically. Use CPU to produce the same picture across different vidocard vendors.", gr.Radio, {"choices": ["GPU", "CPU"]}),
|
||||
}))
|
||||
|
||||
options_templates.update(options_section(('compatibility', "Compatibility"), {
|
||||
@@ -428,6 +360,7 @@ options_templates.update(options_section(('compatibility', "Compatibility"), {
|
||||
"use_old_karras_scheduler_sigmas": OptionInfo(False, "Use old karras scheduler sigmas (0.1 to 10)."),
|
||||
"no_dpmpp_sde_batch_determinism": OptionInfo(False, "Do not make DPM++ SDE deterministic across different batch sizes."),
|
||||
"use_old_hires_fix_width_height": OptionInfo(False, "For hires fix, use width/height sliders to set final resolution rather than first pass (disables Upscale by, Resize width/height to)."),
|
||||
"dont_fix_second_order_samplers_schedule": OptionInfo(False, "Do not fix prompt schedule for second order samplers."),
|
||||
}))
|
||||
|
||||
options_templates.update(options_section(('interrogate', "Interrogate Options"), {
|
||||
@@ -448,12 +381,16 @@ options_templates.update(options_section(('interrogate', "Interrogate Options"),
|
||||
options_templates.update(options_section(('extra_networks', "Extra Networks"), {
|
||||
"extra_networks_default_view": OptionInfo("cards", "Default view for Extra Networks", gr.Dropdown, {"choices": ["cards", "thumbs"]}),
|
||||
"extra_networks_default_multiplier": OptionInfo(1.0, "Multiplier for extra networks", gr.Slider, {"minimum": 0.0, "maximum": 1.0, "step": 0.01}),
|
||||
"extra_networks_card_width": OptionInfo(0, "Card width for Extra Networks (px)"),
|
||||
"extra_networks_card_height": OptionInfo(0, "Card height for Extra Networks (px)"),
|
||||
"extra_networks_add_text_separator": OptionInfo(" ", "Extra text to add before <...> when adding extra network to prompt"),
|
||||
"sd_hypernetwork": OptionInfo("None", "Add hypernetwork to prompt", gr.Dropdown, lambda: {"choices": [""] + [x for x in hypernetworks.keys()]}, refresh=reload_hypernetworks),
|
||||
"sd_hypernetwork": OptionInfo("None", "Add hypernetwork to prompt", gr.Dropdown, lambda: {"choices": ["None"] + [x for x in hypernetworks.keys()]}, refresh=reload_hypernetworks),
|
||||
}))
|
||||
|
||||
options_templates.update(options_section(('ui', "User interface"), {
|
||||
"return_grid": OptionInfo(True, "Show grid in results for web"),
|
||||
"return_mask": OptionInfo(False, "For inpainting, include the greyscale mask in results for web"),
|
||||
"return_mask_composite": OptionInfo(False, "For inpainting, include masked composite in results for web"),
|
||||
"do_not_show_images": OptionInfo(False, "Do not show any images in results for web"),
|
||||
"add_model_hash_to_info": OptionInfo(True, "Add model hash to generation information"),
|
||||
"add_model_name_to_info": OptionInfo(True, "Add model name to generation information"),
|
||||
@@ -468,11 +405,13 @@ options_templates.update(options_section(('ui', "User interface"), {
|
||||
"dimensions_and_batch_together": OptionInfo(True, "Show Width/Height and Batch sliders in same row"),
|
||||
"keyedit_precision_attention": OptionInfo(0.1, "Ctrl+up/down precision when editing (attention:1.1)", gr.Slider, {"minimum": 0.01, "maximum": 0.2, "step": 0.001}),
|
||||
"keyedit_precision_extra": OptionInfo(0.05, "Ctrl+up/down precision when editing <extra networks:0.9>", gr.Slider, {"minimum": 0.01, "maximum": 0.2, "step": 0.001}),
|
||||
"keyedit_delimiters": OptionInfo(".,\/!?%^*;:{}=`~()", "Ctrl+up/down word delimiters"),
|
||||
"quicksettings": OptionInfo("sd_model_checkpoint", "Quicksettings list"),
|
||||
"hidden_tabs": OptionInfo([], "Hidden UI tabs (requires restart)", ui_components.DropdownMulti, lambda: {"choices": [x for x in tab_names]}),
|
||||
"ui_reorder": OptionInfo(", ".join(ui_reorder_categories), "txt2img/img2img UI item order"),
|
||||
"ui_extra_networks_tab_reorder": OptionInfo("", "Extra networks tab order"),
|
||||
"localization": OptionInfo("None", "Localization (requires restart)", gr.Dropdown, lambda: {"choices": ["None"] + list(localization.localizations.keys())}, refresh=lambda: localization.list_localizations(cmd_opts.localizations_dir)),
|
||||
"gradio_theme": OptionInfo("Default", "Gradio theme (requires restart)", ui_components.DropdownEditable, lambda: {"choices": ["Default"] + gradio_hf_hub_themes})
|
||||
}))
|
||||
|
||||
options_templates.update(options_section(('ui', "Live previews"), {
|
||||
@@ -508,7 +447,8 @@ options_templates.update(options_section(('postprocessing', "Postprocessing"), {
|
||||
}))
|
||||
|
||||
options_templates.update(options_section((None, "Hidden options"), {
|
||||
"disabled_extensions": OptionInfo([], "Disable those extensions"),
|
||||
"disabled_extensions": OptionInfo([], "Disable these extensions"),
|
||||
"disable_all_extensions": OptionInfo("none", "Disable all extensions (preserves the list of disabled extensions)", gr.Radio, {"choices": ["none", "extra", "all"]}),
|
||||
"sd_checkpoint_hash": OptionInfo("", "SHA256 hash of the current checkpoint"),
|
||||
}))
|
||||
|
||||
@@ -685,6 +625,24 @@ clip_model = None
|
||||
|
||||
progress_print_out = sys.stdout
|
||||
|
||||
gradio_theme = gr.themes.Base()
|
||||
|
||||
|
||||
def reload_gradio_theme(theme_name=None):
|
||||
global gradio_theme
|
||||
if not theme_name:
|
||||
theme_name = opts.gradio_theme
|
||||
|
||||
if theme_name == "Default":
|
||||
gradio_theme = gr.themes.Default()
|
||||
else:
|
||||
try:
|
||||
gradio_theme = gr.themes.ThemeClass.from_hub(theme_name)
|
||||
except requests.exceptions.ConnectionError:
|
||||
print("Can't access HuggingFace Hub, falling back to default Gradio theme")
|
||||
gradio_theme = gr.themes.Default()
|
||||
|
||||
|
||||
|
||||
class TotalTQDM:
|
||||
def __init__(self):
|
||||
@@ -726,7 +684,7 @@ mem_mon.start()
|
||||
|
||||
|
||||
def listfiles(dirname):
|
||||
filenames = [os.path.join(dirname, x) for x in sorted(os.listdir(dirname)) if not x.startswith(".")]
|
||||
filenames = [os.path.join(dirname, x) for x in sorted(os.listdir(dirname), key=str.lower) if not x.startswith(".")]
|
||||
return [file for file in filenames if os.path.isfile(file)]
|
||||
|
||||
|
||||
|
||||
@@ -152,7 +152,11 @@ class EmbeddingDatabase:
|
||||
name = data.get('name', name)
|
||||
else:
|
||||
data = extract_image_data_embed(embed_image)
|
||||
name = data.get('name', name)
|
||||
if data:
|
||||
name = data.get('name', name)
|
||||
else:
|
||||
# if data is None, means this is not an embeding, just a preview image
|
||||
return
|
||||
elif ext in ['.BIN', '.PT']:
|
||||
data = torch.load(path, map_location="cpu")
|
||||
elif ext in ['.SAFETENSORS']:
|
||||
|
||||
+93
-57
@@ -20,7 +20,7 @@ from PIL import Image, PngImagePlugin
|
||||
from modules.call_queue import wrap_gradio_gpu_call, wrap_queued_call, wrap_gradio_call
|
||||
|
||||
from modules import sd_hijack, sd_models, localization, script_callbacks, ui_extensions, deepbooru, sd_vae, extra_networks, postprocessing, ui_components, ui_common, ui_postprocessing
|
||||
from modules.ui_components import FormRow, FormGroup, ToolButton, FormHTML
|
||||
from modules.ui_components import FormRow, FormColumn, FormGroup, ToolButton, FormHTML
|
||||
from modules.paths import script_path, data_path
|
||||
|
||||
from modules.shared import opts, cmd_opts, restricted_opts
|
||||
@@ -70,17 +70,6 @@ def gr_show(visible=True):
|
||||
sample_img2img = "assets/stable-samples/img2img/sketch-mountains-input.jpg"
|
||||
sample_img2img = sample_img2img if os.path.exists(sample_img2img) else None
|
||||
|
||||
css_hide_progressbar = """
|
||||
.wrap .m-12 svg { display:none!important; }
|
||||
.wrap .m-12::before { content:"Loading..." }
|
||||
.wrap .z-20 svg { display:none!important; }
|
||||
.wrap .z-20::before { content:"Loading..." }
|
||||
.wrap.cover-bg .z-20::before { content:"" }
|
||||
.progress-bar { display:none!important; }
|
||||
.meta-text { display:none!important; }
|
||||
.meta-text-center { display:none!important; }
|
||||
"""
|
||||
|
||||
# Using constants for these since the variation selector isn't visible.
|
||||
# Important that they exactly match script.js for tooltip to work.
|
||||
random_symbol = '\U0001f3b2\ufe0f' # 🎲️
|
||||
@@ -89,7 +78,7 @@ paste_symbol = '\u2199\ufe0f' # ↙
|
||||
refresh_symbol = '\U0001f504' # 🔄
|
||||
save_style_symbol = '\U0001f4be' # 💾
|
||||
apply_style_symbol = '\U0001f4cb' # 📋
|
||||
clear_prompt_symbol = '\U0001F5D1' # 🗑️
|
||||
clear_prompt_symbol = '\U0001f5d1\ufe0f' # 🗑️
|
||||
extra_networks_symbol = '\U0001F3B4' # 🎴
|
||||
switch_values_symbol = '\U000021C5' # ⇅
|
||||
|
||||
@@ -179,14 +168,13 @@ def interrogate_deepbooru(image):
|
||||
|
||||
|
||||
def create_seed_inputs(target_interface):
|
||||
with FormRow(elem_id=target_interface + '_seed_row'):
|
||||
with FormRow(elem_id=target_interface + '_seed_row', variant="compact"):
|
||||
seed = (gr.Textbox if cmd_opts.use_textbox_seed else gr.Number)(label='Seed', value=-1, elem_id=target_interface + '_seed')
|
||||
seed.style(container=False)
|
||||
random_seed = gr.Button(random_symbol, elem_id=target_interface + '_random_seed')
|
||||
reuse_seed = gr.Button(reuse_symbol, elem_id=target_interface + '_reuse_seed')
|
||||
random_seed = ToolButton(random_symbol, elem_id=target_interface + '_random_seed', label='Random seed')
|
||||
reuse_seed = ToolButton(reuse_symbol, elem_id=target_interface + '_reuse_seed', label='Reuse seed')
|
||||
|
||||
with gr.Group(elem_id=target_interface + '_subseed_show_box'):
|
||||
seed_checkbox = gr.Checkbox(label='Extra', elem_id=target_interface + '_subseed_show', value=False)
|
||||
seed_checkbox = gr.Checkbox(label='Extra', elem_id=target_interface + '_subseed_show', value=False)
|
||||
|
||||
# Components to show/hide based on the 'Extra' checkbox
|
||||
seed_extras = []
|
||||
@@ -195,8 +183,8 @@ def create_seed_inputs(target_interface):
|
||||
seed_extras.append(seed_extra_row_1)
|
||||
subseed = gr.Number(label='Variation seed', value=-1, elem_id=target_interface + '_subseed')
|
||||
subseed.style(container=False)
|
||||
random_subseed = gr.Button(random_symbol, elem_id=target_interface + '_random_subseed')
|
||||
reuse_subseed = gr.Button(reuse_symbol, elem_id=target_interface + '_reuse_subseed')
|
||||
random_subseed = ToolButton(random_symbol, elem_id=target_interface + '_random_subseed')
|
||||
reuse_subseed = ToolButton(reuse_symbol, elem_id=target_interface + '_reuse_subseed')
|
||||
subseed_strength = gr.Slider(label='Variation strength', value=0.0, minimum=0, maximum=1, step=0.01, elem_id=target_interface + '_subseed_strength')
|
||||
|
||||
with FormRow(visible=False) as seed_extra_row_2:
|
||||
@@ -291,19 +279,19 @@ def create_toprow(is_img2img):
|
||||
with gr.Row():
|
||||
with gr.Column(scale=80):
|
||||
with gr.Row():
|
||||
negative_prompt = gr.Textbox(label="Negative prompt", elem_id=f"{id_part}_neg_prompt", show_label=False, lines=2, placeholder="Negative prompt (press Ctrl+Enter or Alt+Enter to generate)")
|
||||
negative_prompt = gr.Textbox(label="Negative prompt", elem_id=f"{id_part}_neg_prompt", show_label=False, lines=3, placeholder="Negative prompt (press Ctrl+Enter or Alt+Enter to generate)")
|
||||
|
||||
button_interrogate = None
|
||||
button_deepbooru = None
|
||||
if is_img2img:
|
||||
with gr.Column(scale=1, elem_id="interrogate_col"):
|
||||
with gr.Column(scale=1, elem_classes="interrogate-col"):
|
||||
button_interrogate = gr.Button('Interrogate\nCLIP', elem_id="interrogate")
|
||||
button_deepbooru = gr.Button('Interrogate\nDeepBooru', elem_id="deepbooru")
|
||||
|
||||
with gr.Column(scale=1, elem_id=f"{id_part}_actions_column"):
|
||||
with gr.Row(elem_id=f"{id_part}_generate_box"):
|
||||
interrupt = gr.Button('Interrupt', elem_id=f"{id_part}_interrupt")
|
||||
skip = gr.Button('Skip', elem_id=f"{id_part}_skip")
|
||||
with gr.Row(elem_id=f"{id_part}_generate_box", elem_classes="generate-box"):
|
||||
interrupt = gr.Button('Interrupt', elem_id=f"{id_part}_interrupt", elem_classes="generate-box-interrupt")
|
||||
skip = gr.Button('Skip', elem_id=f"{id_part}_skip", elem_classes="generate-box-skip")
|
||||
submit = gr.Button('Generate', elem_id=f"{id_part}_generate", variant='primary')
|
||||
|
||||
skip.click(
|
||||
@@ -325,9 +313,9 @@ def create_toprow(is_img2img):
|
||||
prompt_style_apply = ToolButton(value=apply_style_symbol, elem_id=f"{id_part}_style_apply")
|
||||
save_style = ToolButton(value=save_style_symbol, elem_id=f"{id_part}_style_create")
|
||||
|
||||
token_counter = gr.HTML(value="<span></span>", elem_id=f"{id_part}_token_counter")
|
||||
token_counter = gr.HTML(value="<span>0/75</span>", elem_id=f"{id_part}_token_counter", elem_classes=["token-counter"])
|
||||
token_button = gr.Button(visible=False, elem_id=f"{id_part}_token_button")
|
||||
negative_token_counter = gr.HTML(value="<span></span>", elem_id=f"{id_part}_negative_token_counter")
|
||||
negative_token_counter = gr.HTML(value="<span>0/75</span>", elem_id=f"{id_part}_negative_token_counter", elem_classes=["token-counter"])
|
||||
negative_token_button = gr.Button(visible=False, elem_id=f"{id_part}_negative_token_button")
|
||||
|
||||
clear_prompt_button.click(
|
||||
@@ -479,7 +467,9 @@ def create_ui():
|
||||
width = gr.Slider(minimum=64, maximum=2048, step=8, label="Width", value=512, elem_id="txt2img_width")
|
||||
height = gr.Slider(minimum=64, maximum=2048, step=8, label="Height", value=512, elem_id="txt2img_height")
|
||||
|
||||
res_switch_btn = ToolButton(value=switch_values_symbol, elem_id="txt2img_res_switch_btn")
|
||||
with gr.Column(elem_id="txt2img_dimensions_row", scale=1, elem_classes="dimensions-tools"):
|
||||
res_switch_btn = ToolButton(value=switch_values_symbol, elem_id="txt2img_res_switch_btn", label="Switch dims")
|
||||
|
||||
if opts.dimensions_and_batch_together:
|
||||
with gr.Column(elem_id="txt2img_column_batch"):
|
||||
batch_count = gr.Slider(minimum=1, step=1, label='Batch count', value=1, elem_id="txt2img_batch_count")
|
||||
@@ -492,7 +482,7 @@ def create_ui():
|
||||
seed, reuse_seed, subseed, reuse_subseed, subseed_strength, seed_resize_from_h, seed_resize_from_w, seed_checkbox = create_seed_inputs('txt2img')
|
||||
|
||||
elif category == "checkboxes":
|
||||
with FormRow(elem_id="txt2img_checkboxes", variant="compact"):
|
||||
with FormRow(elem_classes="checkboxes-row", variant="compact"):
|
||||
restore_faces = gr.Checkbox(label='Restore faces', value=False, visible=len(shared.face_restorers) > 1, elem_id="txt2img_restore_faces")
|
||||
tiling = gr.Checkbox(label='Tiling', value=False, elem_id="txt2img_tiling")
|
||||
enable_hr = gr.Checkbox(label='Hires. fix', value=False, elem_id="txt2img_enable_hr")
|
||||
@@ -586,7 +576,7 @@ def create_ui():
|
||||
txt2img_prompt.submit(**txt2img_args)
|
||||
submit.click(**txt2img_args)
|
||||
|
||||
res_switch_btn.click(lambda w, h: (h, w), inputs=[width, height], outputs=[width, height])
|
||||
res_switch_btn.click(lambda w, h: (h, w), inputs=[width, height], outputs=[width, height], show_progress=False)
|
||||
|
||||
txt_prompt_img.change(
|
||||
fn=modules.images.image_data,
|
||||
@@ -757,7 +747,9 @@ def create_ui():
|
||||
width = gr.Slider(minimum=64, maximum=2048, step=8, label="Width", value=512, elem_id="img2img_width")
|
||||
height = gr.Slider(minimum=64, maximum=2048, step=8, label="Height", value=512, elem_id="img2img_height")
|
||||
|
||||
res_switch_btn = ToolButton(value=switch_values_symbol, elem_id="img2img_res_switch_btn")
|
||||
with gr.Column(elem_id="img2img_dimensions_row", scale=1, elem_classes="dimensions-tools"):
|
||||
res_switch_btn = ToolButton(value=switch_values_symbol, elem_id="img2img_res_switch_btn")
|
||||
|
||||
if opts.dimensions_and_batch_together:
|
||||
with gr.Column(elem_id="img2img_column_batch"):
|
||||
batch_count = gr.Slider(minimum=1, step=1, label='Batch count', value=1, elem_id="img2img_batch_count")
|
||||
@@ -774,7 +766,7 @@ def create_ui():
|
||||
seed, reuse_seed, subseed, reuse_subseed, subseed_strength, seed_resize_from_h, seed_resize_from_w, seed_checkbox = create_seed_inputs('img2img')
|
||||
|
||||
elif category == "checkboxes":
|
||||
with FormRow(elem_id="img2img_checkboxes", variant="compact"):
|
||||
with FormRow(elem_classes="checkboxes-row", variant="compact"):
|
||||
restore_faces = gr.Checkbox(label='Restore faces', value=False, visible=len(shared.face_restorers) > 1, elem_id="img2img_restore_faces")
|
||||
tiling = gr.Checkbox(label='Tiling', value=False, elem_id="img2img_tiling")
|
||||
|
||||
@@ -904,7 +896,7 @@ def create_ui():
|
||||
|
||||
img2img_prompt.submit(**img2img_args)
|
||||
submit.click(**img2img_args)
|
||||
res_switch_btn.click(lambda w, h: (h, w), inputs=[width, height], outputs=[width, height])
|
||||
res_switch_btn.click(lambda w, h: (h, w), inputs=[width, height], outputs=[width, height], show_progress=False)
|
||||
|
||||
img2img_interrogate.click(
|
||||
fn=lambda *args: process_interrogate(interrogate, *args),
|
||||
@@ -1212,7 +1204,7 @@ def create_ui():
|
||||
|
||||
with gr.Column(elem_id='ti_gallery_container'):
|
||||
ti_output = gr.Text(elem_id="ti_output", value="", show_label=False)
|
||||
ti_gallery = gr.Gallery(label='Output', show_label=False, elem_id='ti_gallery').style(grid=4)
|
||||
ti_gallery = gr.Gallery(label='Output', show_label=False, elem_id='ti_gallery').style(columns=4)
|
||||
ti_progress = gr.HTML(elem_id="ti_progress", value="")
|
||||
ti_outcome = gr.HTML(elem_id="ti_error", value="")
|
||||
|
||||
@@ -1491,11 +1483,33 @@ def create_ui():
|
||||
request_notifications = gr.Button(value='Request browser notifications', elem_id="request_notifications")
|
||||
download_localization = gr.Button(value='Download localization template', elem_id="download_localization")
|
||||
reload_script_bodies = gr.Button(value='Reload custom script bodies (No ui updates, No restart)', variant='secondary', elem_id="settings_reload_script_bodies")
|
||||
with gr.Row():
|
||||
unload_sd_model = gr.Button(value='Unload SD checkpoint to free VRAM', elem_id="sett_unload_sd_model")
|
||||
reload_sd_model = gr.Button(value='Reload the last SD checkpoint back into VRAM', elem_id="sett_reload_sd_model")
|
||||
|
||||
with gr.TabItem("Licenses"):
|
||||
gr.HTML(shared.html("licenses.html"), elem_id="licenses")
|
||||
|
||||
gr.Button(value="Show all pages", elem_id="settings_show_all_pages")
|
||||
|
||||
|
||||
def unload_sd_weights():
|
||||
modules.sd_models.unload_model_weights()
|
||||
|
||||
def reload_sd_weights():
|
||||
modules.sd_models.reload_model_weights()
|
||||
|
||||
unload_sd_model.click(
|
||||
fn=unload_sd_weights,
|
||||
inputs=[],
|
||||
outputs=[]
|
||||
)
|
||||
|
||||
reload_sd_model.click(
|
||||
fn=reload_sd_weights,
|
||||
inputs=[],
|
||||
outputs=[]
|
||||
)
|
||||
|
||||
request_notifications.click(
|
||||
fn=lambda: None,
|
||||
@@ -1541,22 +1555,6 @@ def create_ui():
|
||||
(train_interface, "Train", "ti"),
|
||||
]
|
||||
|
||||
css = ""
|
||||
|
||||
for cssfile in modules.scripts.list_files_with_name("style.css"):
|
||||
if not os.path.isfile(cssfile):
|
||||
continue
|
||||
|
||||
with open(cssfile, "r", encoding="utf8") as file:
|
||||
css += file.read() + "\n"
|
||||
|
||||
if os.path.exists(os.path.join(data_path, "user.css")):
|
||||
with open(os.path.join(data_path, "user.css"), "r", encoding="utf8") as file:
|
||||
css += file.read() + "\n"
|
||||
|
||||
if not cmd_opts.no_progressbar_hiding:
|
||||
css += css_hide_progressbar
|
||||
|
||||
interfaces += script_callbacks.ui_tabs_callback()
|
||||
interfaces += [(settings_interface, "Settings", "settings")]
|
||||
|
||||
@@ -1567,7 +1565,7 @@ def create_ui():
|
||||
for _interface, label, _ifid in interfaces:
|
||||
shared.tab_names.append(label)
|
||||
|
||||
with gr.Blocks(css=css, analytics_enabled=False, title="Stable Diffusion") as demo:
|
||||
with gr.Blocks(theme=shared.gradio_theme, analytics_enabled=False, title="Stable Diffusion") as demo:
|
||||
with gr.Row(elem_id="quicksettings", variant="compact"):
|
||||
for i, k, item in sorted(quicksettings_list, key=lambda x: quicksettings_names.get(x[1], x[0])):
|
||||
component = create_setting_component(k, is_quicksettings=True)
|
||||
@@ -1598,11 +1596,13 @@ def create_ui():
|
||||
|
||||
for i, k, item in quicksettings_list:
|
||||
component = component_dict[k]
|
||||
info = opts.data_labels[k]
|
||||
|
||||
component.change(
|
||||
fn=lambda value, k=k: run_settings_single(value, key=k),
|
||||
inputs=[component],
|
||||
outputs=[component, text_settings],
|
||||
show_progress=info.refresh is not None,
|
||||
)
|
||||
|
||||
text_settings.change(
|
||||
@@ -1628,6 +1628,7 @@ def create_ui():
|
||||
fn=get_settings_values,
|
||||
inputs=[],
|
||||
outputs=[component_dict[k] for k in component_keys],
|
||||
queue=False,
|
||||
)
|
||||
|
||||
def modelmerger(*args):
|
||||
@@ -1704,7 +1705,7 @@ def create_ui():
|
||||
if init_field is not None:
|
||||
init_field(saved_value)
|
||||
|
||||
if type(x) in [gr.Slider, gr.Radio, gr.Checkbox, gr.Textbox, gr.Number, gr.Dropdown] and x.visible:
|
||||
if type(x) in [gr.Slider, gr.Radio, gr.Checkbox, gr.Textbox, gr.Number, gr.Dropdown, ToolButton] and x.visible:
|
||||
apply_field(x, 'visible')
|
||||
|
||||
if type(x) == gr.Slider:
|
||||
@@ -1750,25 +1751,60 @@ def create_ui():
|
||||
return demo
|
||||
|
||||
|
||||
def reload_javascript():
|
||||
def webpath(fn):
|
||||
if fn.startswith(script_path):
|
||||
web_path = os.path.relpath(fn, script_path).replace('\\', '/')
|
||||
else:
|
||||
web_path = os.path.abspath(fn)
|
||||
|
||||
return f'file={web_path}?{os.path.getmtime(fn)}'
|
||||
|
||||
|
||||
def javascript_html():
|
||||
script_js = os.path.join(script_path, "script.js")
|
||||
head = f'<script type="text/javascript" src="file={os.path.abspath(script_js)}?{os.path.getmtime(script_js)}"></script>\n'
|
||||
head = f'<script type="text/javascript" src="{webpath(script_js)}"></script>\n'
|
||||
|
||||
inline = f"{localization.localization_js(shared.opts.localization)};"
|
||||
if cmd_opts.theme is not None:
|
||||
inline += f"set_theme('{cmd_opts.theme}');"
|
||||
|
||||
for script in modules.scripts.list_scripts("javascript", ".js"):
|
||||
head += f'<script type="text/javascript" src="file={script.path}?{os.path.getmtime(script.path)}"></script>\n'
|
||||
head += f'<script type="text/javascript" src="{webpath(script.path)}"></script>\n'
|
||||
|
||||
for script in modules.scripts.list_scripts("javascript", ".mjs"):
|
||||
head += f'<script type="module" src="file={script.path}?{os.path.getmtime(script.path)}"></script>\n'
|
||||
head += f'<script type="module" src="{webpath(script.path)}"></script>\n'
|
||||
|
||||
head += f'<script type="text/javascript">{inline}</script>\n'
|
||||
|
||||
return head
|
||||
|
||||
|
||||
def css_html():
|
||||
head = ""
|
||||
|
||||
def stylesheet(fn):
|
||||
return f'<link rel="stylesheet" property="stylesheet" href="{webpath(fn)}">'
|
||||
|
||||
for cssfile in modules.scripts.list_files_with_name("style.css"):
|
||||
if not os.path.isfile(cssfile):
|
||||
continue
|
||||
|
||||
head += stylesheet(cssfile)
|
||||
|
||||
if os.path.exists(os.path.join(data_path, "user.css")):
|
||||
head += stylesheet(os.path.join(data_path, "user.css"))
|
||||
|
||||
return head
|
||||
|
||||
|
||||
def reload_javascript():
|
||||
js = javascript_html()
|
||||
css = css_html()
|
||||
|
||||
def template_response(*args, **kwargs):
|
||||
res = shared.GradioTemplateResponseOriginal(*args, **kwargs)
|
||||
res.body = res.body.replace(b'</head>', f'{head}</head>'.encode("utf8"))
|
||||
res.body = res.body.replace(b'</head>', f'{js}</head>'.encode("utf8"))
|
||||
res.body = res.body.replace(b'</body>', f'{css}</body>'.encode("utf8"))
|
||||
res.init_headers()
|
||||
return res
|
||||
|
||||
|
||||
@@ -125,12 +125,12 @@ Requested path was: {f}
|
||||
|
||||
with gr.Column(variant='panel', elem_id=f"{tabname}_results"):
|
||||
with gr.Group(elem_id=f"{tabname}_gallery_container"):
|
||||
result_gallery = gr.Gallery(label='Output', show_label=False, elem_id=f"{tabname}_gallery").style(grid=4)
|
||||
result_gallery = gr.Gallery(label='Output', show_label=False, elem_id=f"{tabname}_gallery").style(columns=4)
|
||||
|
||||
generation_info = None
|
||||
with gr.Column():
|
||||
with gr.Row(elem_id=f"image_buttons_{tabname}"):
|
||||
open_folder_button = gr.Button(folder_symbol, elem_id="hidden_element" if shared.cmd_opts.hide_ui_dir_config else f'open_folder_{tabname}')
|
||||
with gr.Row(elem_id=f"image_buttons_{tabname}", elem_classes="image-buttons"):
|
||||
open_folder_button = gr.Button(folder_symbol, visible=not shared.cmd_opts.hide_ui_dir_config)
|
||||
|
||||
if tabname != "extras":
|
||||
save = gr.Button('Save', elem_id=f'save_{tabname}')
|
||||
@@ -145,11 +145,10 @@ Requested path was: {f}
|
||||
)
|
||||
|
||||
if tabname != "extras":
|
||||
with gr.Row():
|
||||
download_files = gr.File(None, file_count="multiple", interactive=False, show_label=False, visible=False, elem_id=f'download_files_{tabname}')
|
||||
download_files = gr.File(None, file_count="multiple", interactive=False, show_label=False, visible=False, elem_id=f'download_files_{tabname}')
|
||||
|
||||
with gr.Group():
|
||||
html_info = gr.HTML(elem_id=f'html_info_{tabname}')
|
||||
html_info = gr.HTML(elem_id=f'html_info_{tabname}', elem_classes="infotext")
|
||||
html_log = gr.HTML(elem_id=f'html_log_{tabname}')
|
||||
|
||||
generation_info = gr.Textbox(visible=False, elem_id=f'generation_info_{tabname}')
|
||||
@@ -160,6 +159,7 @@ Requested path was: {f}
|
||||
_js="function(x, y, z){ return [x, y, selected_gallery_index()] }",
|
||||
inputs=[generation_info, html_info, html_info],
|
||||
outputs=[html_info, html_info],
|
||||
show_progress=False,
|
||||
)
|
||||
|
||||
save.click(
|
||||
@@ -195,7 +195,7 @@ Requested path was: {f}
|
||||
|
||||
else:
|
||||
html_info_x = gr.HTML(elem_id=f'html_info_x_{tabname}')
|
||||
html_info = gr.HTML(elem_id=f'html_info_{tabname}')
|
||||
html_info = gr.HTML(elem_id=f'html_info_{tabname}', elem_classes="infotext")
|
||||
html_log = gr.HTML(elem_id=f'html_log_{tabname}')
|
||||
|
||||
paste_field_names = []
|
||||
|
||||
+34
-18
@@ -1,58 +1,74 @@
|
||||
import gradio as gr
|
||||
|
||||
|
||||
class ToolButton(gr.Button, gr.components.FormComponent):
|
||||
class FormComponent:
|
||||
def get_expected_parent(self):
|
||||
return gr.components.Form
|
||||
|
||||
|
||||
gr.Dropdown.get_expected_parent = FormComponent.get_expected_parent
|
||||
|
||||
|
||||
class ToolButton(FormComponent, gr.Button):
|
||||
"""Small button with single emoji as text, fits inside gradio forms"""
|
||||
|
||||
def __init__(self, **kwargs):
|
||||
super().__init__(variant="tool", **kwargs)
|
||||
def __init__(self, *args, **kwargs):
|
||||
classes = kwargs.pop("elem_classes", [])
|
||||
super().__init__(*args, elem_classes=["tool", *classes], **kwargs)
|
||||
|
||||
def get_block_name(self):
|
||||
return "button"
|
||||
|
||||
|
||||
class ToolButtonTop(gr.Button, gr.components.FormComponent):
|
||||
"""Small button with single emoji as text, with extra margin at top, fits inside gradio forms"""
|
||||
|
||||
def __init__(self, **kwargs):
|
||||
super().__init__(variant="tool-top", **kwargs)
|
||||
|
||||
def get_block_name(self):
|
||||
return "button"
|
||||
|
||||
|
||||
class FormRow(gr.Row, gr.components.FormComponent):
|
||||
class FormRow(FormComponent, gr.Row):
|
||||
"""Same as gr.Row but fits inside gradio forms"""
|
||||
|
||||
def get_block_name(self):
|
||||
return "row"
|
||||
|
||||
|
||||
class FormGroup(gr.Group, gr.components.FormComponent):
|
||||
class FormColumn(FormComponent, gr.Column):
|
||||
"""Same as gr.Column but fits inside gradio forms"""
|
||||
|
||||
def get_block_name(self):
|
||||
return "column"
|
||||
|
||||
|
||||
class FormGroup(FormComponent, gr.Group):
|
||||
"""Same as gr.Row but fits inside gradio forms"""
|
||||
|
||||
def get_block_name(self):
|
||||
return "group"
|
||||
|
||||
|
||||
class FormHTML(gr.HTML, gr.components.FormComponent):
|
||||
class FormHTML(FormComponent, gr.HTML):
|
||||
"""Same as gr.HTML but fits inside gradio forms"""
|
||||
|
||||
def get_block_name(self):
|
||||
return "html"
|
||||
|
||||
|
||||
class FormColorPicker(gr.ColorPicker, gr.components.FormComponent):
|
||||
class FormColorPicker(FormComponent, gr.ColorPicker):
|
||||
"""Same as gr.ColorPicker but fits inside gradio forms"""
|
||||
|
||||
def get_block_name(self):
|
||||
return "colorpicker"
|
||||
|
||||
|
||||
class DropdownMulti(gr.Dropdown):
|
||||
class DropdownMulti(FormComponent, gr.Dropdown):
|
||||
"""Same as gr.Dropdown but always multiselect"""
|
||||
def __init__(self, **kwargs):
|
||||
super().__init__(multiselect=True, **kwargs)
|
||||
|
||||
def get_block_name(self):
|
||||
return "dropdown"
|
||||
|
||||
|
||||
class DropdownEditable(FormComponent, gr.Dropdown):
|
||||
"""Same as gr.Dropdown but allows editing value"""
|
||||
def __init__(self, **kwargs):
|
||||
super().__init__(allow_custom_value=True, **kwargs)
|
||||
|
||||
def get_block_name(self):
|
||||
return "dropdown"
|
||||
|
||||
|
||||
+69
-26
@@ -1,6 +1,5 @@
|
||||
import json
|
||||
import os.path
|
||||
import shutil
|
||||
import sys
|
||||
import time
|
||||
import traceback
|
||||
@@ -22,7 +21,7 @@ def check_access():
|
||||
assert not shared.cmd_opts.disable_extension_access, "extension access disabled because of command line flags"
|
||||
|
||||
|
||||
def apply_and_restart(disable_list, update_list):
|
||||
def apply_and_restart(disable_list, update_list, disable_all):
|
||||
check_access()
|
||||
|
||||
disabled = json.loads(disable_list)
|
||||
@@ -44,6 +43,7 @@ def apply_and_restart(disable_list, update_list):
|
||||
print(traceback.format_exc(), file=sys.stderr)
|
||||
|
||||
shared.opts.disabled_extensions = disabled
|
||||
shared.opts.disable_all_extensions = disable_all
|
||||
shared.opts.save(shared.config_filename)
|
||||
|
||||
shared.state.interrupt()
|
||||
@@ -64,6 +64,9 @@ def check_updates(id_task, disable_list):
|
||||
|
||||
try:
|
||||
ext.check_updates()
|
||||
except FileNotFoundError as e:
|
||||
if 'FETCH_HEAD' not in str(e):
|
||||
raise
|
||||
except Exception:
|
||||
print(f"Error checking updates for {ext.name}:", file=sys.stderr)
|
||||
print(traceback.format_exc(), file=sys.stderr)
|
||||
@@ -88,6 +91,8 @@ def extension_table():
|
||||
"""
|
||||
|
||||
for ext in extensions.extensions:
|
||||
ext.read_info_from_repo()
|
||||
|
||||
remote = f"""<a href="{html.escape(ext.remote or '')}" target="_blank">{html.escape("built-in" if ext.is_builtin else ext.remote or '')}</a>"""
|
||||
|
||||
if ext.can_update:
|
||||
@@ -95,9 +100,13 @@ def extension_table():
|
||||
else:
|
||||
ext_status = ext.status
|
||||
|
||||
style = ""
|
||||
if shared.opts.disable_all_extensions == "extra" and not ext.is_builtin or shared.opts.disable_all_extensions == "all":
|
||||
style = ' style="color: var(--primary-400)"'
|
||||
|
||||
code += f"""
|
||||
<tr>
|
||||
<td><label><input class="gr-check-radio gr-checkbox" name="enable_{html.escape(ext.name)}" type="checkbox" {'checked="checked"' if ext.enabled else ''}>{html.escape(ext.name)}</label></td>
|
||||
<td><label{style}><input class="gr-check-radio gr-checkbox" name="enable_{html.escape(ext.name)}" type="checkbox" {'checked="checked"' if ext.enabled else ''}>{html.escape(ext.name)}</label></td>
|
||||
<td>{remote}</td>
|
||||
<td>{ext.version}</td>
|
||||
<td{' class="extension_status"' if ext.remote is not None else ''}>{ext_status}</td>
|
||||
@@ -120,7 +129,7 @@ def normalize_git_url(url):
|
||||
return url
|
||||
|
||||
|
||||
def install_extension_from_url(dirname, url):
|
||||
def install_extension_from_url(dirname, branch_name, url):
|
||||
check_access()
|
||||
|
||||
assert url, 'No URL specified'
|
||||
@@ -141,22 +150,27 @@ def install_extension_from_url(dirname, url):
|
||||
|
||||
try:
|
||||
shutil.rmtree(tmpdir, True)
|
||||
|
||||
repo = git.Repo.clone_from(url, tmpdir)
|
||||
repo.remote().fetch()
|
||||
|
||||
if branch_name == '':
|
||||
# if no branch is specified, use the default branch
|
||||
with git.Repo.clone_from(url, tmpdir) as repo:
|
||||
repo.remote().fetch()
|
||||
for submodule in repo.submodules:
|
||||
submodule.update()
|
||||
else:
|
||||
with git.Repo.clone_from(url, tmpdir, branch=branch_name) as repo:
|
||||
repo.remote().fetch()
|
||||
for submodule in repo.submodules:
|
||||
submodule.update()
|
||||
try:
|
||||
os.rename(tmpdir, target_dir)
|
||||
except OSError as err:
|
||||
# TODO what does this do on windows? I think it'll be a different error code but I don't have a system to check it
|
||||
# Shouldn't cause any new issues at least but we probably want to handle it there too.
|
||||
if err.errno == errno.EXDEV:
|
||||
# Cross device link, typical in docker or when tmp/ and extensions/ are on different file systems
|
||||
# Since we can't use a rename, do the slower but more versitile shutil.move()
|
||||
shutil.move(tmpdir, target_dir)
|
||||
else:
|
||||
# Something else, not enough free space, permissions, etc. rethrow it so that it gets handled.
|
||||
raise(err)
|
||||
raise err
|
||||
|
||||
import launch
|
||||
launch.run_extension_installer(target_dir)
|
||||
@@ -167,12 +181,12 @@ def install_extension_from_url(dirname, url):
|
||||
shutil.rmtree(tmpdir, True)
|
||||
|
||||
|
||||
def install_extension_from_index(url, hide_tags, sort_column):
|
||||
def install_extension_from_index(url, hide_tags, sort_column, filter_text):
|
||||
ext_table, message = install_extension_from_url(None, url)
|
||||
|
||||
code, _ = refresh_available_extensions_from_data(hide_tags, sort_column)
|
||||
code, _ = refresh_available_extensions_from_data(hide_tags, sort_column, filter_text)
|
||||
|
||||
return code, ext_table, message
|
||||
return code, ext_table, message, ''
|
||||
|
||||
|
||||
def refresh_available_extensions(url, hide_tags, sort_column):
|
||||
@@ -186,11 +200,17 @@ def refresh_available_extensions(url, hide_tags, sort_column):
|
||||
|
||||
code, tags = refresh_available_extensions_from_data(hide_tags, sort_column)
|
||||
|
||||
return url, code, gr.CheckboxGroup.update(choices=tags), ''
|
||||
return url, code, gr.CheckboxGroup.update(choices=tags), '', ''
|
||||
|
||||
|
||||
def refresh_available_extensions_for_tags(hide_tags, sort_column):
|
||||
code, _ = refresh_available_extensions_from_data(hide_tags, sort_column)
|
||||
def refresh_available_extensions_for_tags(hide_tags, sort_column, filter_text):
|
||||
code, _ = refresh_available_extensions_from_data(hide_tags, sort_column, filter_text)
|
||||
|
||||
return code, ''
|
||||
|
||||
|
||||
def search_extensions(filter_text, hide_tags, sort_column):
|
||||
code, _ = refresh_available_extensions_from_data(hide_tags, sort_column, filter_text)
|
||||
|
||||
return code, ''
|
||||
|
||||
@@ -205,7 +225,7 @@ sort_ordering = [
|
||||
]
|
||||
|
||||
|
||||
def refresh_available_extensions_from_data(hide_tags, sort_column):
|
||||
def refresh_available_extensions_from_data(hide_tags, sort_column, filter_text=""):
|
||||
extlist = available_extensions["extensions"]
|
||||
installed_extension_urls = {normalize_git_url(extension.remote): extension.name for extension in extensions.extensions}
|
||||
|
||||
@@ -244,7 +264,12 @@ def refresh_available_extensions_from_data(hide_tags, sort_column):
|
||||
hidden += 1
|
||||
continue
|
||||
|
||||
install_code = f"""<input onclick="install_extension_from_index(this, '{html.escape(url)}')" type="button" value="{"Install" if not existing else "Installed"}" {"disabled=disabled" if existing else ""} class="gr-button gr-button-lg gr-button-secondary">"""
|
||||
if filter_text and filter_text.strip():
|
||||
if filter_text.lower() not in html.escape(name).lower() and filter_text.lower() not in html.escape(description).lower():
|
||||
hidden += 1
|
||||
continue
|
||||
|
||||
install_code = f"""<button onclick="install_extension_from_index(this, '{html.escape(url)}')" {"disabled=disabled" if existing else ""} class="lg secondary gradio-button custom-button">{"Install" if not existing else "Installed"}</button>"""
|
||||
|
||||
tags_text = ", ".join([f"<span class='extension-tag' title='{tags.get(x, '')}'>{x}</span>" for x in extension_tags])
|
||||
|
||||
@@ -281,16 +306,24 @@ def create_ui():
|
||||
with gr.Row(elem_id="extensions_installed_top"):
|
||||
apply = gr.Button(value="Apply and restart UI", variant="primary")
|
||||
check = gr.Button(value="Check for updates")
|
||||
extensions_disable_all = gr.Radio(label="Disable all extensions", choices=["none", "extra", "all"], value=shared.opts.disable_all_extensions, elem_id="extensions_disable_all")
|
||||
extensions_disabled_list = gr.Text(elem_id="extensions_disabled_list", visible=False).style(container=False)
|
||||
extensions_update_list = gr.Text(elem_id="extensions_update_list", visible=False).style(container=False)
|
||||
|
||||
info = gr.HTML()
|
||||
html = ""
|
||||
if shared.opts.disable_all_extensions != "none":
|
||||
html = """
|
||||
<span style="color: var(--primary-400);">
|
||||
"Disable all extensions" was set, change it to "none" to load all extensions again
|
||||
</span>
|
||||
"""
|
||||
info = gr.HTML(html)
|
||||
extensions_table = gr.HTML(lambda: extension_table())
|
||||
|
||||
apply.click(
|
||||
fn=apply_and_restart,
|
||||
_js="extensions_apply",
|
||||
inputs=[extensions_disabled_list, extensions_update_list],
|
||||
inputs=[extensions_disabled_list, extensions_update_list, extensions_disable_all],
|
||||
outputs=[],
|
||||
)
|
||||
|
||||
@@ -312,42 +345,52 @@ def create_ui():
|
||||
hide_tags = gr.CheckboxGroup(value=["ads", "localization", "installed"], label="Hide extensions with tags", choices=["script", "ads", "localization", "installed"])
|
||||
sort_column = gr.Radio(value="newest first", label="Order", choices=["newest first", "oldest first", "a-z", "z-a", "internal order", ], type="index")
|
||||
|
||||
with gr.Row():
|
||||
search_extensions_text = gr.Text(label="Search").style(container=False)
|
||||
|
||||
install_result = gr.HTML()
|
||||
available_extensions_table = gr.HTML()
|
||||
|
||||
refresh_available_extensions_button.click(
|
||||
fn=modules.ui.wrap_gradio_call(refresh_available_extensions, extra_outputs=[gr.update(), gr.update(), gr.update()]),
|
||||
inputs=[available_extensions_index, hide_tags, sort_column],
|
||||
outputs=[available_extensions_index, available_extensions_table, hide_tags, install_result],
|
||||
outputs=[available_extensions_index, available_extensions_table, hide_tags, install_result, search_extensions_text],
|
||||
)
|
||||
|
||||
install_extension_button.click(
|
||||
fn=modules.ui.wrap_gradio_call(install_extension_from_index, extra_outputs=[gr.update(), gr.update()]),
|
||||
inputs=[extension_to_install, hide_tags, sort_column],
|
||||
inputs=[extension_to_install, hide_tags, sort_column, search_extensions_text],
|
||||
outputs=[available_extensions_table, extensions_table, install_result],
|
||||
)
|
||||
|
||||
search_extensions_text.change(
|
||||
fn=modules.ui.wrap_gradio_call(search_extensions, extra_outputs=[gr.update()]),
|
||||
inputs=[search_extensions_text, hide_tags, sort_column],
|
||||
outputs=[available_extensions_table, install_result],
|
||||
)
|
||||
|
||||
hide_tags.change(
|
||||
fn=modules.ui.wrap_gradio_call(refresh_available_extensions_for_tags, extra_outputs=[gr.update()]),
|
||||
inputs=[hide_tags, sort_column],
|
||||
inputs=[hide_tags, sort_column, search_extensions_text],
|
||||
outputs=[available_extensions_table, install_result]
|
||||
)
|
||||
|
||||
sort_column.change(
|
||||
fn=modules.ui.wrap_gradio_call(refresh_available_extensions_for_tags, extra_outputs=[gr.update()]),
|
||||
inputs=[hide_tags, sort_column],
|
||||
inputs=[hide_tags, sort_column, search_extensions_text],
|
||||
outputs=[available_extensions_table, install_result]
|
||||
)
|
||||
|
||||
with gr.TabItem("Install from URL"):
|
||||
install_url = gr.Text(label="URL for extension's git repository")
|
||||
install_branch = gr.Text(label="Specific branch name", placeholder="Leave empty for default main branch")
|
||||
install_dirname = gr.Text(label="Local directory name", placeholder="Leave empty for auto")
|
||||
install_button = gr.Button(value="Install", variant="primary")
|
||||
install_result = gr.HTML(elem_id="extension_install_result")
|
||||
|
||||
install_button.click(
|
||||
fn=modules.ui.wrap_gradio_call(install_extension_from_url, extra_outputs=[gr.update()]),
|
||||
inputs=[install_dirname, install_url],
|
||||
inputs=[install_dirname, install_branch, install_url],
|
||||
outputs=[extensions_table, install_result],
|
||||
)
|
||||
|
||||
|
||||
@@ -2,8 +2,10 @@ import glob
|
||||
import os.path
|
||||
import urllib.parse
|
||||
from pathlib import Path
|
||||
from PIL import PngImagePlugin
|
||||
|
||||
from modules import shared
|
||||
from modules.images import read_info_from_image
|
||||
import gradio as gr
|
||||
import json
|
||||
import html
|
||||
@@ -22,21 +24,37 @@ def register_page(page):
|
||||
allowed_dirs.update(set(sum([x.allowed_directories_for_previews() for x in extra_pages], [])))
|
||||
|
||||
|
||||
def fetch_file(filename: str = ""):
|
||||
from starlette.responses import FileResponse
|
||||
|
||||
if not any([Path(x).absolute() in Path(filename).absolute().parents for x in allowed_dirs]):
|
||||
raise ValueError(f"File cannot be fetched: {filename}. Must be in one of directories registered by extra pages.")
|
||||
|
||||
ext = os.path.splitext(filename)[1].lower()
|
||||
if ext not in (".png", ".jpg", ".webp"):
|
||||
raise ValueError(f"File cannot be fetched: {filename}. Only png and jpg and webp.")
|
||||
|
||||
# would profit from returning 304
|
||||
return FileResponse(filename, headers={"Accept-Ranges": "bytes"})
|
||||
|
||||
|
||||
def get_metadata(page: str = "", item: str = ""):
|
||||
from starlette.responses import JSONResponse
|
||||
|
||||
page = next(iter([x for x in extra_pages if x.name == page]), None)
|
||||
if page is None:
|
||||
return JSONResponse({})
|
||||
|
||||
metadata = page.metadata.get(item)
|
||||
if metadata is None:
|
||||
return JSONResponse({})
|
||||
|
||||
return JSONResponse({"metadata": metadata})
|
||||
|
||||
|
||||
def add_pages_to_demo(app):
|
||||
def fetch_file(filename: str = ""):
|
||||
from starlette.responses import FileResponse
|
||||
|
||||
if not any([Path(x).absolute() in Path(filename).absolute().parents for x in allowed_dirs]):
|
||||
raise ValueError(f"File cannot be fetched: {filename}. Must be in one of directories registered by extra pages.")
|
||||
|
||||
ext = os.path.splitext(filename)[1].lower()
|
||||
if ext not in (".png", ".jpg", ".webp"):
|
||||
raise ValueError(f"File cannot be fetched: {filename}. Only png and jpg and webp.")
|
||||
|
||||
# would profit from returning 304
|
||||
return FileResponse(filename, headers={"Accept-Ranges": "bytes"})
|
||||
|
||||
app.add_api_route("/sd_extra_networks/thumb", fetch_file, methods=["GET"])
|
||||
app.add_api_route("/sd_extra_networks/metadata", get_metadata, methods=["GET"])
|
||||
|
||||
|
||||
class ExtraNetworksPage:
|
||||
@@ -45,6 +63,7 @@ class ExtraNetworksPage:
|
||||
self.name = title.lower()
|
||||
self.card_page = shared.html("extra-networks-card.html")
|
||||
self.allow_negative_prompt = False
|
||||
self.metadata = {}
|
||||
|
||||
def refresh(self):
|
||||
pass
|
||||
@@ -66,6 +85,8 @@ class ExtraNetworksPage:
|
||||
view = shared.opts.extra_networks_default_view
|
||||
items_html = ''
|
||||
|
||||
self.metadata = {}
|
||||
|
||||
subdirs = {}
|
||||
for parentdir in [os.path.abspath(x) for x in self.allowed_directories_for_previews()]:
|
||||
for x in glob.glob(os.path.join(parentdir, '**/*'), recursive=True):
|
||||
@@ -86,12 +107,16 @@ class ExtraNetworksPage:
|
||||
subdirs = {"": 1, **subdirs}
|
||||
|
||||
subdirs_html = "".join([f"""
|
||||
<button class='gr-button gr-button-lg gr-button-secondary{" search-all" if subdir=="" else ""}' onclick='extraNetworksSearchButton("{tabname}_extra_tabs", event)'>
|
||||
<button class='lg secondary gradio-button custom-button{" search-all" if subdir=="" else ""}' onclick='extraNetworksSearchButton("{tabname}_extra_tabs", event)'>
|
||||
{html.escape(subdir if subdir!="" else "all")}
|
||||
</button>
|
||||
""" for subdir in subdirs])
|
||||
|
||||
for item in self.list_items():
|
||||
metadata = item.get("metadata")
|
||||
if metadata:
|
||||
self.metadata[item["name"]] = metadata
|
||||
|
||||
items_html += self.create_html_for_item(item, tabname)
|
||||
|
||||
if items_html == '':
|
||||
@@ -124,14 +149,16 @@ class ExtraNetworksPage:
|
||||
if onclick is None:
|
||||
onclick = '"' + html.escape(f"""return cardClicked({json.dumps(tabname)}, {item["prompt"]}, {"true" if self.allow_negative_prompt else "false"})""") + '"'
|
||||
|
||||
height = f"height: {shared.opts.extra_networks_card_height}px;" if shared.opts.extra_networks_card_height else ''
|
||||
width = f"width: {shared.opts.extra_networks_card_width}px;" if shared.opts.extra_networks_card_width else ''
|
||||
background_image = f"background-image: url(\"{html.escape(preview)}\");" if preview else ''
|
||||
metadata_button = ""
|
||||
metadata = item.get("metadata")
|
||||
if metadata:
|
||||
metadata_onclick = '"' + html.escape(f"""extraNetworksShowMetadata({json.dumps(metadata)}); return false;""") + '"'
|
||||
metadata_button = f"<div class='metadata-button' title='Show metadata' onclick={metadata_onclick}></div>"
|
||||
metadata_button = f"<div class='metadata-button' title='Show metadata' onclick='extraNetworksRequestMetadata(event, {json.dumps(self.name)}, {json.dumps(item['name'])})'></div>"
|
||||
|
||||
args = {
|
||||
"preview_html": "style='background-image: url(\"" + html.escape(preview) + "\")'" if preview else '',
|
||||
"style": f"'{height}{width}{background_image}'",
|
||||
"prompt": item.get("prompt", None),
|
||||
"tabname": json.dumps(tabname),
|
||||
"local_preview": json.dumps(item["local_preview"]),
|
||||
@@ -215,6 +242,7 @@ def create_ui(container, button, tabname):
|
||||
with gr.Tabs(elem_id=tabname+"_extra_tabs") as tabs:
|
||||
for page in ui.stored_extra_pages:
|
||||
with gr.Tab(page.title):
|
||||
|
||||
page_elem = gr.HTML(page.create_html(ui.tabname))
|
||||
ui.pages.append(page_elem)
|
||||
|
||||
@@ -226,10 +254,10 @@ def create_ui(container, button, tabname):
|
||||
|
||||
def toggle_visibility(is_visible):
|
||||
is_visible = not is_visible
|
||||
return is_visible, gr.update(visible=is_visible)
|
||||
return is_visible, gr.update(visible=is_visible), gr.update(variant=("secondary-down" if is_visible else "secondary"))
|
||||
|
||||
state_visible = gr.State(value=False)
|
||||
button.click(fn=toggle_visibility, inputs=[state_visible], outputs=[state_visible, container])
|
||||
button.click(fn=toggle_visibility, inputs=[state_visible], outputs=[state_visible, container, button])
|
||||
|
||||
def refresh():
|
||||
res = []
|
||||
@@ -264,6 +292,7 @@ def setup_ui(ui, gallery):
|
||||
|
||||
img_info = images[index if index >= 0 else 0]
|
||||
image = image_from_url_text(img_info)
|
||||
geninfo, items = read_info_from_image(image)
|
||||
|
||||
is_allowed = False
|
||||
for extra_page in ui.stored_extra_pages:
|
||||
@@ -273,7 +302,12 @@ def setup_ui(ui, gallery):
|
||||
|
||||
assert is_allowed, f'writing to {filename} is not allowed'
|
||||
|
||||
image.save(filename)
|
||||
if geninfo:
|
||||
pnginfo_data = PngImagePlugin.PngInfo()
|
||||
pnginfo_data.add_text('parameters', geninfo)
|
||||
image.save(filename, pnginfo=pnginfo_data)
|
||||
else:
|
||||
image.save(filename)
|
||||
|
||||
return [page.create_html(ui.tabname) for page in ui.stored_extra_pages]
|
||||
|
||||
|
||||
@@ -13,7 +13,7 @@ def create_ui():
|
||||
extras_image = gr.Image(label="Source", source="upload", interactive=True, type="pil", elem_id="extras_image")
|
||||
|
||||
with gr.TabItem('Batch Process', elem_id="extras_batch_process_tab") as tab_batch:
|
||||
image_batch = gr.File(label="Batch Process", file_count="multiple", interactive=True, type="file", elem_id="extras_image_batch")
|
||||
image_batch = gr.Files(label="Batch Process", interactive=True, elem_id="extras_image_batch")
|
||||
|
||||
with gr.TabItem('Batch from Directory', elem_id="extras_batch_directory_tab") as tab_batch_dir:
|
||||
extras_batch_input_dir = gr.Textbox(label="Input directory", **shared.hide_dirs, placeholder="A directory on the same machine where the server is running.", elem_id="extras_batch_input_dir")
|
||||
|
||||
Reference in New Issue
Block a user