From 94c01fe532128703d3a122694fc20ab3986f7cb8 Mon Sep 17 00:00:00 2001 From: urbnywrt Date: Fri, 15 May 2026 00:08:00 +0300 Subject: [PATCH] feat: add FastAPI app with config delivery endpoint --- app/main.py | 132 +++++++++++++++++++++++++++++++++++++++ app/tests/test_routes.py | 130 ++++++++++++++++++++++++++++++++++++++ 2 files changed, 262 insertions(+) create mode 100644 app/main.py create mode 100644 app/tests/test_routes.py diff --git a/app/main.py b/app/main.py new file mode 100644 index 0000000..c3d7b00 --- /dev/null +++ b/app/main.py @@ -0,0 +1,132 @@ +import logging +import os +from contextlib import asynccontextmanager +from datetime import datetime +from typing import AsyncGenerator + +import yaml +from fastapi import FastAPI, Depends, HTTPException +from fastapi.responses import Response +from fastapi.templating import Jinja2Templates +from sqlalchemy.ext.asyncio import AsyncSession +from sqlalchemy import select + +from models import Base, Config, Subscription, ExportLog, make_engine, make_session_factory +from mihomo import MihomoClient +from expander import expand_config, build_mihomo_config + +logging.basicConfig( + level=logging.INFO, + format="%(asctime)s %(levelname)s %(message)s", +) +logger = logging.getLogger(__name__) + +DATABASE_URL = os.environ.get("DATABASE_URL", "sqlite+aiosqlite:////data/db/app.db") +MIHOMO_API = os.environ.get("MIHOMO_API", "http://mihomo:9090") +MIHOMO_SECRET = os.environ.get("MIHOMO_SECRET", "") +MIHOMO_CONFIG_DIR = os.environ.get("MIHOMO_CONFIG_DIR", "/data/mihomo") + +engine = make_engine(DATABASE_URL) +SessionLocal = make_session_factory(engine) +mihomo_client = MihomoClient(MIHOMO_API, MIHOMO_SECRET) +templates = Jinja2Templates( + directory=os.path.join(os.path.dirname(__file__), "templates") +) + + +async def get_db() -> AsyncGenerator[AsyncSession, None]: + async with SessionLocal() as session: + yield session + + +async def write_and_reload_mihomo(db: AsyncSession) -> None: + result = await db.execute(select(Config)) + configs = result.scalars().all() + config_yaml = build_mihomo_config([c.base_yaml for c in configs], MIHOMO_SECRET) + config_path = os.path.join(MIHOMO_CONFIG_DIR, "config.yaml") + tmp_path = config_path + ".tmp" + os.makedirs(MIHOMO_CONFIG_DIR, exist_ok=True) + with open(tmp_path, "w") as f: + f.write(config_yaml) + os.replace(tmp_path, config_path) + logger.info("Wrote Mihomo config to %s", config_path) + await mihomo_client.reload_config() + logger.info("Mihomo config reloaded") + + +@asynccontextmanager +async def lifespan(app: FastAPI): # type: ignore[type-arg] + os.makedirs(MIHOMO_CONFIG_DIR, exist_ok=True) + db_path = DATABASE_URL.split("///")[-1] + if db_path: + os.makedirs(os.path.dirname(db_path), exist_ok=True) + + async with engine.begin() as conn: + await conn.run_sync(Base.metadata.create_all) + + await mihomo_client.wait_ready() + + async with SessionLocal() as db: + await write_and_reload_mihomo(db) + + yield + + +app = FastAPI(lifespan=lifespan) + + +@app.get("/config/{token}.yaml") +async def get_config( + token: str, + db: AsyncSession = Depends(get_db), +) -> Response: + result = await db.execute(select(Config).where(Config.token == token)) + config = result.scalar_one_or_none() + if not config: + raise HTTPException(status_code=404, detail="Config not found") + + result = await db.execute( + select(Subscription).where(Subscription.config_id == config.id) + ) + subscriptions = result.scalars().all() + + provider_proxies: dict[str, list[dict]] = {} + errors: list[str] = [] + + for sub in subscriptions: + try: + proxies = await mihomo_client.refresh_and_collect(sub.name, timeout=30) + provider_proxies[sub.name] = proxies + sub.last_fetched_at = datetime.utcnow() + except Exception as exc: + logger.error("Failed to refresh provider %s: %s", sub.name, exc) + errors.append(f"{sub.name}: {exc}") + + try: + expanded = expand_config(config.base_yaml, provider_proxies) + except Exception as exc: + logger.error("Config expansion failed for token %s: %s", token, exc) + db.add( + ExportLog( + config_id=config.id, + node_count=0, + success=False, + error_message=str(exc), + ) + ) + await db.commit() + raise HTTPException(status_code=500, detail=f"Config expansion failed: {exc}") + + node_count = sum(len(p) for p in provider_proxies.values()) + error_msg = "; ".join(errors) if errors else None + db.add( + ExportLog( + config_id=config.id, + node_count=node_count, + success=not bool(errors), + error_message=error_msg, + ) + ) + await db.commit() + + return Response(content=expanded, media_type="application/x-yaml") diff --git a/app/tests/test_routes.py b/app/tests/test_routes.py new file mode 100644 index 0000000..f83d7e8 --- /dev/null +++ b/app/tests/test_routes.py @@ -0,0 +1,130 @@ +import uuid +import pytest +import pytest_asyncio +from unittest.mock import AsyncMock, patch +from httpx import AsyncClient, ASGITransport +from sqlalchemy.ext.asyncio import create_async_engine, async_sessionmaker + +from models import Base, Config, Subscription, ExportLog +from main import app, get_db + + +@pytest_asyncio.fixture +async def db_engine(): + engine = create_async_engine("sqlite+aiosqlite:///:memory:") + async with engine.begin() as conn: + await conn.run_sync(Base.metadata.create_all) + yield engine + await engine.dispose() + + +@pytest_asyncio.fixture +async def db_session(db_engine): + Session = async_sessionmaker(db_engine, expire_on_commit=False) + async with Session() as session: + yield session + + +@pytest_asyncio.fixture +async def http_client(db_session): + async def override_get_db(): + yield db_session + + app.dependency_overrides[get_db] = override_get_db + + async with AsyncClient( + transport=ASGITransport(app=app), base_url="http://test" + ) as client: + yield client + + app.dependency_overrides.clear() + + +async def test_get_config_not_found(http_client): + resp = await http_client.get("/config/nonexistent-token.yaml") + assert resp.status_code == 404 + + +async def test_get_config_returns_yaml(http_client, db_session): + token = str(uuid.uuid4()) + config = Config( + name="test", + token=token, + base_yaml="proxies: []\nproxy-groups: []\nrules:\n - MATCH,DIRECT\n", + ) + db_session.add(config) + await db_session.commit() + + with patch("main.mihomo_client") as mock_mc: + mock_mc.refresh_and_collect = AsyncMock(return_value=[]) + resp = await http_client.get(f"/config/{token}.yaml") + + assert resp.status_code == 200 + assert "proxies" in resp.text + + +async def test_get_config_writes_export_log(http_client, db_session): + from sqlalchemy import select + + token = str(uuid.uuid4()) + config = Config( + name="test", + token=token, + base_yaml="proxies: []\nproxy-groups: []\nrules: []\n", + ) + db_session.add(config) + await db_session.commit() + + with patch("main.mihomo_client") as mock_mc: + mock_mc.refresh_and_collect = AsyncMock(return_value=[]) + await http_client.get(f"/config/{token}.yaml") + + result = await db_session.execute( + select(ExportLog).where(ExportLog.config_id == config.id) + ) + logs = result.scalars().all() + assert len(logs) == 1 + assert logs[0].success is True + + +async def test_get_config_with_subscription_expands_nodes(http_client, db_session): + token = str(uuid.uuid4()) + config = Config( + name="test", + token=token, + base_yaml=( + "proxies: []\n" + "proxy-providers:\n" + " myprovider:\n" + " type: http\n" + " url: https://example.com/sub\n" + " interval: 3600\n" + "proxy-groups:\n" + " - name: Proxy\n" + " type: select\n" + " use:\n" + " - myprovider\n" + "rules:\n" + " - MATCH,DIRECT\n" + ), + ) + db_session.add(config) + await db_session.flush() + + sub = Subscription(config_id=config.id, name="myprovider", url="https://example.com/sub") + db_session.add(sub) + await db_session.commit() + + fake_proxies = [ + {"name": "node1", "type": "ss", "server": "1.2.3.4", "port": 443, + "password": "pwd", "cipher": "aes-256-gcm", "alive": True}, + ] + + with patch("main.mihomo_client") as mock_mc: + mock_mc.refresh_and_collect = AsyncMock(return_value=fake_proxies) + resp = await http_client.get(f"/config/{token}.yaml") + + assert resp.status_code == 200 + assert "node1" in resp.text + assert "proxy-providers" not in resp.text + assert "alive" not in resp.text