mirror of
https://github.com/w-okada/voice-changer.git
synced 2025-01-08 11:57:48 +08:00
Refactor and add origin check to SIO
This commit is contained in:
parent
ce9b599501
commit
8dd8d7127d
@ -140,8 +140,8 @@ if __name__ == "MMVCServerSIO":
|
|||||||
mp.freeze_support()
|
mp.freeze_support()
|
||||||
|
|
||||||
voiceChangerManager = VoiceChangerManager.get_instance(voiceChangerParams)
|
voiceChangerManager = VoiceChangerManager.get_instance(voiceChangerParams)
|
||||||
app_fastapi = MMVC_Rest.get_instance(voiceChangerManager, voiceChangerParams, PORT, args.allowed_origins)
|
app_fastapi = MMVC_Rest.get_instance(voiceChangerManager, voiceChangerParams, args.allowed_origins, PORT)
|
||||||
app_socketio = MMVC_SocketIOApp.get_instance(app_fastapi, voiceChangerManager)
|
app_socketio = MMVC_SocketIOApp.get_instance(app_fastapi, voiceChangerManager, args.allowed_origins, PORT)
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__mp_main__":
|
if __name__ == "__mp_main__":
|
||||||
|
24
server/mods/origins.py
Normal file
24
server/mods/origins.py
Normal file
@ -0,0 +1,24 @@
|
|||||||
|
from typing import Optional, Sequence
|
||||||
|
from urllib.parse import urlparse
|
||||||
|
|
||||||
|
ENFORCE_URL_ORIGIN_FORMAT = "Input origins must be well-formed URLs, i.e. https://google.com or https://www.google.com."
|
||||||
|
SCHEMAS = ('http', 'https')
|
||||||
|
LOCAL_ORIGINS = ('127.0.0.1', 'localhost')
|
||||||
|
|
||||||
|
def compute_local_origins(port: Optional[int] = None) -> list[str]:
|
||||||
|
local_origins = [f'{schema}://{origin}' for schema in SCHEMAS for origin in LOCAL_ORIGINS]
|
||||||
|
if port is not None:
|
||||||
|
local_origins = [f'{origin}:{port}' for origin in local_origins]
|
||||||
|
return local_origins
|
||||||
|
|
||||||
|
|
||||||
|
def normalize_origins(origins: Sequence[str]) -> set[str]:
|
||||||
|
allowed_origins = set()
|
||||||
|
for origin in origins:
|
||||||
|
url = urlparse(origin)
|
||||||
|
assert url.scheme, ENFORCE_URL_ORIGIN_FORMAT
|
||||||
|
valid_origin = f'{url.scheme}://{url.hostname}'
|
||||||
|
if url.port:
|
||||||
|
valid_origin += f':{url.port}'
|
||||||
|
allowed_origins.add(valid_origin)
|
||||||
|
return allowed_origins
|
@ -6,7 +6,7 @@ from fastapi import FastAPI, Request, Response, HTTPException
|
|||||||
from fastapi.routing import APIRoute
|
from fastapi.routing import APIRoute
|
||||||
from fastapi.staticfiles import StaticFiles
|
from fastapi.staticfiles import StaticFiles
|
||||||
from fastapi.exceptions import RequestValidationError
|
from fastapi.exceptions import RequestValidationError
|
||||||
from typing import Callable
|
from typing import Callable, Optional, Sequence, Literal
|
||||||
from mods.log_control import VoiceChangaerLogger
|
from mods.log_control import VoiceChangaerLogger
|
||||||
from voice_changer.VoiceChangerManager import VoiceChangerManager
|
from voice_changer.VoiceChangerManager import VoiceChangerManager
|
||||||
|
|
||||||
@ -43,8 +43,8 @@ class MMVC_Rest:
|
|||||||
cls,
|
cls,
|
||||||
voiceChangerManager: VoiceChangerManager,
|
voiceChangerManager: VoiceChangerManager,
|
||||||
voiceChangerParams: VoiceChangerParams,
|
voiceChangerParams: VoiceChangerParams,
|
||||||
port: int,
|
allowedOrigins: Optional[Sequence[str]] = None,
|
||||||
allowedOrigins: list[str],
|
port: Optional[int] = None,
|
||||||
):
|
):
|
||||||
if cls._instance is None:
|
if cls._instance is None:
|
||||||
logger.info("[Voice Changer] MMVC_Rest initializing...")
|
logger.info("[Voice Changer] MMVC_Rest initializing...")
|
||||||
|
@ -1,35 +1,27 @@
|
|||||||
import typing
|
from typing import Optional, Sequence, Literal
|
||||||
|
|
||||||
from urllib.parse import urlparse
|
from mods.origins import compute_local_origins, normalize_origins
|
||||||
from starlette.datastructures import Headers
|
from starlette.datastructures import Headers
|
||||||
from starlette.responses import PlainTextResponse
|
from starlette.responses import PlainTextResponse
|
||||||
from starlette.types import ASGIApp, Receive, Scope, Send
|
from starlette.types import ASGIApp, Receive, Scope, Send
|
||||||
|
|
||||||
ENFORCE_URL_ORIGIN_FORMAT = "Input origins must be well-formed URLs, i.e. https://google.com or https://www.google.com."
|
|
||||||
|
|
||||||
|
|
||||||
class TrustedOriginMiddleware:
|
class TrustedOriginMiddleware:
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
app: ASGIApp,
|
app: ASGIApp,
|
||||||
allowed_origins: typing.Optional[typing.Sequence[str]] = None,
|
allowed_origins: Optional[Sequence[str]] = None,
|
||||||
port: typing.Optional[int] = None,
|
port: Optional[int] = None,
|
||||||
) -> None:
|
) -> None:
|
||||||
schemas = ['http', 'https']
|
|
||||||
local_origins = [f'{schema}://{origin}' for schema in schemas for origin in ['127.0.0.1', 'localhost']]
|
|
||||||
if port is not None:
|
|
||||||
local_origins = [f'{origin}:{port}' for origin in local_origins]
|
|
||||||
|
|
||||||
self.allowed_origins: set[str] = set()
|
self.allowed_origins: set[str] = set()
|
||||||
if allowed_origins is not None:
|
|
||||||
for origin in allowed_origins:
|
local_origins = compute_local_origins(port)
|
||||||
url = urlparse(origin)
|
|
||||||
assert url.scheme, ENFORCE_URL_ORIGIN_FORMAT
|
|
||||||
valid_origin = f'{url.scheme}://{url.hostname}'
|
|
||||||
if url.port:
|
|
||||||
valid_origin += f':{url.port}'
|
|
||||||
self.allowed_origins.add(valid_origin)
|
|
||||||
self.allowed_origins.update(local_origins)
|
self.allowed_origins.update(local_origins)
|
||||||
|
|
||||||
|
if allowed_origins is not None:
|
||||||
|
normalized_origins = normalize_origins(allowed_origins)
|
||||||
|
self.allowed_origins.update(normalized_origins)
|
||||||
|
|
||||||
self.app = app
|
self.app = app
|
||||||
|
|
||||||
async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None:
|
async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None:
|
||||||
|
@ -1,6 +1,8 @@
|
|||||||
import socketio
|
import socketio
|
||||||
from mods.log_control import VoiceChangaerLogger
|
from mods.log_control import VoiceChangaerLogger
|
||||||
|
from mods.origins import compute_local_origins, normalize_origins
|
||||||
|
|
||||||
|
from typing import Sequence, Optional
|
||||||
from sio.MMVC_SocketIOServer import MMVC_SocketIOServer
|
from sio.MMVC_SocketIOServer import MMVC_SocketIOServer
|
||||||
from voice_changer.VoiceChangerManager import VoiceChangerManager
|
from voice_changer.VoiceChangerManager import VoiceChangerManager
|
||||||
from const import getFrontendPath
|
from const import getFrontendPath
|
||||||
@ -12,10 +14,24 @@ class MMVC_SocketIOApp:
|
|||||||
_instance: socketio.ASGIApp | None = None
|
_instance: socketio.ASGIApp | None = None
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def get_instance(cls, app_fastapi, voiceChangerManager: VoiceChangerManager):
|
def get_instance(
|
||||||
|
cls,
|
||||||
|
app_fastapi,
|
||||||
|
voiceChangerManager: VoiceChangerManager,
|
||||||
|
allowedOrigins: Optional[Sequence[str]] = None,
|
||||||
|
port: Optional[int] = None,
|
||||||
|
):
|
||||||
if cls._instance is None:
|
if cls._instance is None:
|
||||||
logger.info("[Voice Changer] MMVC_SocketIOApp initializing...")
|
logger.info("[Voice Changer] MMVC_SocketIOApp initializing...")
|
||||||
sio = MMVC_SocketIOServer.get_instance(voiceChangerManager)
|
|
||||||
|
allowed_origins: set[str] = set()
|
||||||
|
local_origins = compute_local_origins(port)
|
||||||
|
allowed_origins.update(local_origins)
|
||||||
|
if allowedOrigins is not None:
|
||||||
|
normalized_origins = normalize_origins(allowedOrigins)
|
||||||
|
allowed_origins.update(normalized_origins)
|
||||||
|
sio = MMVC_SocketIOServer.get_instance(voiceChangerManager, list(allowed_origins))
|
||||||
|
|
||||||
app_socketio = socketio.ASGIApp(
|
app_socketio = socketio.ASGIApp(
|
||||||
sio,
|
sio,
|
||||||
other_asgi_app=app_fastapi,
|
other_asgi_app=app_fastapi,
|
||||||
|
@ -8,9 +8,13 @@ class MMVC_SocketIOServer:
|
|||||||
_instance: socketio.AsyncServer | None = None
|
_instance: socketio.AsyncServer | None = None
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def get_instance(cls, voiceChangerManager: VoiceChangerManager):
|
def get_instance(
|
||||||
|
cls,
|
||||||
|
voiceChangerManager: VoiceChangerManager,
|
||||||
|
allowedOrigins: list[str],
|
||||||
|
):
|
||||||
if cls._instance is None:
|
if cls._instance is None:
|
||||||
sio = socketio.AsyncServer(async_mode="asgi", cors_allowed_origins="*")
|
sio = socketio.AsyncServer(async_mode="asgi", cors_allowed_origins=allowedOrigins)
|
||||||
namespace = MMVC_Namespace.get_instance(voiceChangerManager)
|
namespace = MMVC_Namespace.get_instance(voiceChangerManager)
|
||||||
sio.register_namespace(namespace)
|
sio.register_namespace(namespace)
|
||||||
cls._instance = sio
|
cls._instance = sio
|
||||||
|
Loading…
Reference in New Issue
Block a user