mirror of
https://github.com/w-okada/voice-changer.git
synced 2025-01-07 03:16:48 +08:00
Refactor and add origin check to SIO
This commit is contained in:
parent
ce9b599501
commit
8dd8d7127d
@ -140,8 +140,8 @@ if __name__ == "MMVCServerSIO":
|
||||
mp.freeze_support()
|
||||
|
||||
voiceChangerManager = VoiceChangerManager.get_instance(voiceChangerParams)
|
||||
app_fastapi = MMVC_Rest.get_instance(voiceChangerManager, voiceChangerParams, PORT, args.allowed_origins)
|
||||
app_socketio = MMVC_SocketIOApp.get_instance(app_fastapi, voiceChangerManager)
|
||||
app_fastapi = MMVC_Rest.get_instance(voiceChangerManager, voiceChangerParams, args.allowed_origins, PORT)
|
||||
app_socketio = MMVC_SocketIOApp.get_instance(app_fastapi, voiceChangerManager, args.allowed_origins, PORT)
|
||||
|
||||
|
||||
if __name__ == "__mp_main__":
|
||||
|
24
server/mods/origins.py
Normal file
24
server/mods/origins.py
Normal file
@ -0,0 +1,24 @@
|
||||
from typing import Optional, Sequence
|
||||
from urllib.parse import urlparse
|
||||
|
||||
ENFORCE_URL_ORIGIN_FORMAT = "Input origins must be well-formed URLs, i.e. https://google.com or https://www.google.com."
|
||||
SCHEMAS = ('http', 'https')
|
||||
LOCAL_ORIGINS = ('127.0.0.1', 'localhost')
|
||||
|
||||
def compute_local_origins(port: Optional[int] = None) -> list[str]:
|
||||
local_origins = [f'{schema}://{origin}' for schema in SCHEMAS for origin in LOCAL_ORIGINS]
|
||||
if port is not None:
|
||||
local_origins = [f'{origin}:{port}' for origin in local_origins]
|
||||
return local_origins
|
||||
|
||||
|
||||
def normalize_origins(origins: Sequence[str]) -> set[str]:
|
||||
allowed_origins = set()
|
||||
for origin in origins:
|
||||
url = urlparse(origin)
|
||||
assert url.scheme, ENFORCE_URL_ORIGIN_FORMAT
|
||||
valid_origin = f'{url.scheme}://{url.hostname}'
|
||||
if url.port:
|
||||
valid_origin += f':{url.port}'
|
||||
allowed_origins.add(valid_origin)
|
||||
return allowed_origins
|
@ -6,7 +6,7 @@ from fastapi import FastAPI, Request, Response, HTTPException
|
||||
from fastapi.routing import APIRoute
|
||||
from fastapi.staticfiles import StaticFiles
|
||||
from fastapi.exceptions import RequestValidationError
|
||||
from typing import Callable
|
||||
from typing import Callable, Optional, Sequence, Literal
|
||||
from mods.log_control import VoiceChangaerLogger
|
||||
from voice_changer.VoiceChangerManager import VoiceChangerManager
|
||||
|
||||
@ -43,8 +43,8 @@ class MMVC_Rest:
|
||||
cls,
|
||||
voiceChangerManager: VoiceChangerManager,
|
||||
voiceChangerParams: VoiceChangerParams,
|
||||
port: int,
|
||||
allowedOrigins: list[str],
|
||||
allowedOrigins: Optional[Sequence[str]] = None,
|
||||
port: Optional[int] = None,
|
||||
):
|
||||
if cls._instance is None:
|
||||
logger.info("[Voice Changer] MMVC_Rest initializing...")
|
||||
|
@ -1,35 +1,27 @@
|
||||
import typing
|
||||
from typing import Optional, Sequence, Literal
|
||||
|
||||
from urllib.parse import urlparse
|
||||
from mods.origins import compute_local_origins, normalize_origins
|
||||
from starlette.datastructures import Headers
|
||||
from starlette.responses import PlainTextResponse
|
||||
from starlette.types import ASGIApp, Receive, Scope, Send
|
||||
|
||||
ENFORCE_URL_ORIGIN_FORMAT = "Input origins must be well-formed URLs, i.e. https://google.com or https://www.google.com."
|
||||
|
||||
|
||||
class TrustedOriginMiddleware:
|
||||
def __init__(
|
||||
self,
|
||||
app: ASGIApp,
|
||||
allowed_origins: typing.Optional[typing.Sequence[str]] = None,
|
||||
port: typing.Optional[int] = None,
|
||||
allowed_origins: Optional[Sequence[str]] = None,
|
||||
port: Optional[int] = None,
|
||||
) -> None:
|
||||
schemas = ['http', 'https']
|
||||
local_origins = [f'{schema}://{origin}' for schema in schemas for origin in ['127.0.0.1', 'localhost']]
|
||||
if port is not None:
|
||||
local_origins = [f'{origin}:{port}' for origin in local_origins]
|
||||
|
||||
self.allowed_origins: set[str] = set()
|
||||
if allowed_origins is not None:
|
||||
for origin in allowed_origins:
|
||||
url = urlparse(origin)
|
||||
assert url.scheme, ENFORCE_URL_ORIGIN_FORMAT
|
||||
valid_origin = f'{url.scheme}://{url.hostname}'
|
||||
if url.port:
|
||||
valid_origin += f':{url.port}'
|
||||
self.allowed_origins.add(valid_origin)
|
||||
|
||||
local_origins = compute_local_origins(port)
|
||||
self.allowed_origins.update(local_origins)
|
||||
|
||||
if allowed_origins is not None:
|
||||
normalized_origins = normalize_origins(allowed_origins)
|
||||
self.allowed_origins.update(normalized_origins)
|
||||
|
||||
self.app = app
|
||||
|
||||
async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None:
|
||||
|
@ -1,6 +1,8 @@
|
||||
import socketio
|
||||
from mods.log_control import VoiceChangaerLogger
|
||||
from mods.origins import compute_local_origins, normalize_origins
|
||||
|
||||
from typing import Sequence, Optional
|
||||
from sio.MMVC_SocketIOServer import MMVC_SocketIOServer
|
||||
from voice_changer.VoiceChangerManager import VoiceChangerManager
|
||||
from const import getFrontendPath
|
||||
@ -12,10 +14,24 @@ class MMVC_SocketIOApp:
|
||||
_instance: socketio.ASGIApp | None = None
|
||||
|
||||
@classmethod
|
||||
def get_instance(cls, app_fastapi, voiceChangerManager: VoiceChangerManager):
|
||||
def get_instance(
|
||||
cls,
|
||||
app_fastapi,
|
||||
voiceChangerManager: VoiceChangerManager,
|
||||
allowedOrigins: Optional[Sequence[str]] = None,
|
||||
port: Optional[int] = None,
|
||||
):
|
||||
if cls._instance is None:
|
||||
logger.info("[Voice Changer] MMVC_SocketIOApp initializing...")
|
||||
sio = MMVC_SocketIOServer.get_instance(voiceChangerManager)
|
||||
|
||||
allowed_origins: set[str] = set()
|
||||
local_origins = compute_local_origins(port)
|
||||
allowed_origins.update(local_origins)
|
||||
if allowedOrigins is not None:
|
||||
normalized_origins = normalize_origins(allowedOrigins)
|
||||
allowed_origins.update(normalized_origins)
|
||||
sio = MMVC_SocketIOServer.get_instance(voiceChangerManager, list(allowed_origins))
|
||||
|
||||
app_socketio = socketio.ASGIApp(
|
||||
sio,
|
||||
other_asgi_app=app_fastapi,
|
||||
|
@ -8,9 +8,13 @@ class MMVC_SocketIOServer:
|
||||
_instance: socketio.AsyncServer | None = None
|
||||
|
||||
@classmethod
|
||||
def get_instance(cls, voiceChangerManager: VoiceChangerManager):
|
||||
def get_instance(
|
||||
cls,
|
||||
voiceChangerManager: VoiceChangerManager,
|
||||
allowedOrigins: list[str],
|
||||
):
|
||||
if cls._instance is None:
|
||||
sio = socketio.AsyncServer(async_mode="asgi", cors_allowed_origins="*")
|
||||
sio = socketio.AsyncServer(async_mode="asgi", cors_allowed_origins=allowedOrigins)
|
||||
namespace = MMVC_Namespace.get_instance(voiceChangerManager)
|
||||
sio.register_namespace(namespace)
|
||||
cls._instance = sio
|
||||
|
Loading…
Reference in New Issue
Block a user