diff --git a/modules/initialize.py b/modules/initialize.py index 0365bbb30..a3f5f6b3f 100644 --- a/modules/initialize.py +++ b/modules/initialize.py @@ -50,6 +50,7 @@ def check_versions(): def initialize(): from modules import initialize_util + initialize_util.allow_add_middleware_after_start() initialize_util.fix_torch_version() initialize_util.fix_pytorch_lightning() initialize_util.fix_asyncio_event_loop_policy() diff --git a/modules/initialize_util.py b/modules/initialize_util.py index 79a72cb3a..0800a8690 100644 --- a/modules/initialize_util.py +++ b/modules/initialize_util.py @@ -5,6 +5,8 @@ import sys import re from modules.timer import startup_timer +from modules import patches +from functools import wraps def gradio_server_name(): @@ -191,11 +193,8 @@ def configure_opts_onchange(): def setup_middleware(app): from starlette.middleware.gzip import GZipMiddleware - - app.middleware_stack = None # reset current middleware to allow modifying user provided list app.add_middleware(GZipMiddleware, minimum_size=1000) configure_cors_middleware(app) - app.build_middleware_stack() # rebuild middleware stack on-the-fly def configure_cors_middleware(app): @@ -213,3 +212,38 @@ def configure_cors_middleware(app): cors_options["allow_origin_regex"] = cmd_opts.cors_allow_origins_regex app.add_middleware(CORSMiddleware, **cors_options) + +def allow_add_middleware_after_start(): + from starlette.applications import Starlette + + def add_middleware_wrapper(func): + """Patch Starlette.add_middleware to allow for middleware to be added after the app has started + + Starlette.add_middleware raises RuntimeError("Cannot add middleware after an application has started") if middleware_stack is not None. + We can force add new middleware by first setting middleware_stack to None, then adding the middleware. + When middleware_stack is None, it will rebuild the middleware_stack on the next request (Lazily build middleware stack). + + If packages are updated in the future, things may break, so we have two ways to add middleware after the app has started: + the first way is to just set middleware_stack to None and then retry + the second manually insert the middleware into the user_middleware list without calling add_middleware + """ + + @wraps(func) + def wrapper(self, *args, **kwargs): + res = None + try: + res = func(self, *args, **kwargs) + except RuntimeError as _: + try: + self.middleware_stack = None + res = func(self, *args, **kwargs) + except RuntimeError as e: + print(f'Warning: "{e}", Retrying...') + from starlette.middleware import Middleware + self.user_middleware.insert(0, Middleware(*args, **kwargs)) + self.middleware_stack = None # ensure middleware_stack in the event of concurrent requests + return res + + return wrapper + + patches.patch(__name__, obj=Starlette, field="add_middleware", replacement=add_middleware_wrapper(Starlette.add_middleware))