Refactor and add origin check to SIO

This commit is contained in:
Yury 2024-03-18 22:52:22 +02:00
parent ce9b599501
commit 8dd8d7127d
6 changed files with 64 additions and 28 deletions

View File

@ -140,8 +140,8 @@ if __name__ == "MMVCServerSIO":
mp.freeze_support()
voiceChangerManager = VoiceChangerManager.get_instance(voiceChangerParams)
app_fastapi = MMVC_Rest.get_instance(voiceChangerManager, voiceChangerParams, PORT, args.allowed_origins)
app_socketio = MMVC_SocketIOApp.get_instance(app_fastapi, voiceChangerManager)
app_fastapi = MMVC_Rest.get_instance(voiceChangerManager, voiceChangerParams, args.allowed_origins, PORT)
app_socketio = MMVC_SocketIOApp.get_instance(app_fastapi, voiceChangerManager, args.allowed_origins, PORT)
if __name__ == "__mp_main__":

24
server/mods/origins.py Normal file
View File

@ -0,0 +1,24 @@
from typing import Optional, Sequence
from urllib.parse import urlparse
ENFORCE_URL_ORIGIN_FORMAT = "Input origins must be well-formed URLs, i.e. https://google.com or https://www.google.com."
SCHEMAS = ('http', 'https')
LOCAL_ORIGINS = ('127.0.0.1', 'localhost')
def compute_local_origins(port: Optional[int] = None) -> list[str]:
local_origins = [f'{schema}://{origin}' for schema in SCHEMAS for origin in LOCAL_ORIGINS]
if port is not None:
local_origins = [f'{origin}:{port}' for origin in local_origins]
return local_origins
def normalize_origins(origins: Sequence[str]) -> set[str]:
allowed_origins = set()
for origin in origins:
url = urlparse(origin)
assert url.scheme, ENFORCE_URL_ORIGIN_FORMAT
valid_origin = f'{url.scheme}://{url.hostname}'
if url.port:
valid_origin += f':{url.port}'
allowed_origins.add(valid_origin)
return allowed_origins

View File

@ -6,7 +6,7 @@ from fastapi import FastAPI, Request, Response, HTTPException
from fastapi.routing import APIRoute
from fastapi.staticfiles import StaticFiles
from fastapi.exceptions import RequestValidationError
from typing import Callable
from typing import Callable, Optional, Sequence, Literal
from mods.log_control import VoiceChangaerLogger
from voice_changer.VoiceChangerManager import VoiceChangerManager
@ -43,8 +43,8 @@ class MMVC_Rest:
cls,
voiceChangerManager: VoiceChangerManager,
voiceChangerParams: VoiceChangerParams,
port: int,
allowedOrigins: list[str],
allowedOrigins: Optional[Sequence[str]] = None,
port: Optional[int] = None,
):
if cls._instance is None:
logger.info("[Voice Changer] MMVC_Rest initializing...")

View File

@ -1,35 +1,27 @@
import typing
from typing import Optional, Sequence, Literal
from urllib.parse import urlparse
from mods.origins import compute_local_origins, normalize_origins
from starlette.datastructures import Headers
from starlette.responses import PlainTextResponse
from starlette.types import ASGIApp, Receive, Scope, Send
ENFORCE_URL_ORIGIN_FORMAT = "Input origins must be well-formed URLs, i.e. https://google.com or https://www.google.com."
class TrustedOriginMiddleware:
def __init__(
self,
app: ASGIApp,
allowed_origins: typing.Optional[typing.Sequence[str]] = None,
port: typing.Optional[int] = None,
allowed_origins: Optional[Sequence[str]] = None,
port: Optional[int] = None,
) -> None:
schemas = ['http', 'https']
local_origins = [f'{schema}://{origin}' for schema in schemas for origin in ['127.0.0.1', 'localhost']]
if port is not None:
local_origins = [f'{origin}:{port}' for origin in local_origins]
self.allowed_origins: set[str] = set()
if allowed_origins is not None:
for origin in allowed_origins:
url = urlparse(origin)
assert url.scheme, ENFORCE_URL_ORIGIN_FORMAT
valid_origin = f'{url.scheme}://{url.hostname}'
if url.port:
valid_origin += f':{url.port}'
self.allowed_origins.add(valid_origin)
local_origins = compute_local_origins(port)
self.allowed_origins.update(local_origins)
if allowed_origins is not None:
normalized_origins = normalize_origins(allowed_origins)
self.allowed_origins.update(normalized_origins)
self.app = app
async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None:

View File

@ -1,6 +1,8 @@
import socketio
from mods.log_control import VoiceChangaerLogger
from mods.origins import compute_local_origins, normalize_origins
from typing import Sequence, Optional
from sio.MMVC_SocketIOServer import MMVC_SocketIOServer
from voice_changer.VoiceChangerManager import VoiceChangerManager
from const import getFrontendPath
@ -12,10 +14,24 @@ class MMVC_SocketIOApp:
_instance: socketio.ASGIApp | None = None
@classmethod
def get_instance(cls, app_fastapi, voiceChangerManager: VoiceChangerManager):
def get_instance(
cls,
app_fastapi,
voiceChangerManager: VoiceChangerManager,
allowedOrigins: Optional[Sequence[str]] = None,
port: Optional[int] = None,
):
if cls._instance is None:
logger.info("[Voice Changer] MMVC_SocketIOApp initializing...")
sio = MMVC_SocketIOServer.get_instance(voiceChangerManager)
allowed_origins: set[str] = set()
local_origins = compute_local_origins(port)
allowed_origins.update(local_origins)
if allowedOrigins is not None:
normalized_origins = normalize_origins(allowedOrigins)
allowed_origins.update(normalized_origins)
sio = MMVC_SocketIOServer.get_instance(voiceChangerManager, list(allowed_origins))
app_socketio = socketio.ASGIApp(
sio,
other_asgi_app=app_fastapi,

View File

@ -8,9 +8,13 @@ class MMVC_SocketIOServer:
_instance: socketio.AsyncServer | None = None
@classmethod
def get_instance(cls, voiceChangerManager: VoiceChangerManager):
def get_instance(
cls,
voiceChangerManager: VoiceChangerManager,
allowedOrigins: list[str],
):
if cls._instance is None:
sio = socketio.AsyncServer(async_mode="asgi", cors_allowed_origins="*")
sio = socketio.AsyncServer(async_mode="asgi", cors_allowed_origins=allowedOrigins)
namespace = MMVC_Namespace.get_instance(voiceChangerManager)
sio.register_namespace(namespace)
cls._instance = sio