feat: add FastAPI app with config delivery endpoint
This commit is contained in:
+132
@@ -0,0 +1,132 @@
|
||||
import logging
|
||||
import os
|
||||
from contextlib import asynccontextmanager
|
||||
from datetime import datetime
|
||||
from typing import AsyncGenerator
|
||||
|
||||
import yaml
|
||||
from fastapi import FastAPI, Depends, HTTPException
|
||||
from fastapi.responses import Response
|
||||
from fastapi.templating import Jinja2Templates
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
from sqlalchemy import select
|
||||
|
||||
from models import Base, Config, Subscription, ExportLog, make_engine, make_session_factory
|
||||
from mihomo import MihomoClient
|
||||
from expander import expand_config, build_mihomo_config
|
||||
|
||||
logging.basicConfig(
|
||||
level=logging.INFO,
|
||||
format="%(asctime)s %(levelname)s %(message)s",
|
||||
)
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
DATABASE_URL = os.environ.get("DATABASE_URL", "sqlite+aiosqlite:////data/db/app.db")
|
||||
MIHOMO_API = os.environ.get("MIHOMO_API", "http://mihomo:9090")
|
||||
MIHOMO_SECRET = os.environ.get("MIHOMO_SECRET", "")
|
||||
MIHOMO_CONFIG_DIR = os.environ.get("MIHOMO_CONFIG_DIR", "/data/mihomo")
|
||||
|
||||
engine = make_engine(DATABASE_URL)
|
||||
SessionLocal = make_session_factory(engine)
|
||||
mihomo_client = MihomoClient(MIHOMO_API, MIHOMO_SECRET)
|
||||
templates = Jinja2Templates(
|
||||
directory=os.path.join(os.path.dirname(__file__), "templates")
|
||||
)
|
||||
|
||||
|
||||
async def get_db() -> AsyncGenerator[AsyncSession, None]:
|
||||
async with SessionLocal() as session:
|
||||
yield session
|
||||
|
||||
|
||||
async def write_and_reload_mihomo(db: AsyncSession) -> None:
|
||||
result = await db.execute(select(Config))
|
||||
configs = result.scalars().all()
|
||||
config_yaml = build_mihomo_config([c.base_yaml for c in configs], MIHOMO_SECRET)
|
||||
config_path = os.path.join(MIHOMO_CONFIG_DIR, "config.yaml")
|
||||
tmp_path = config_path + ".tmp"
|
||||
os.makedirs(MIHOMO_CONFIG_DIR, exist_ok=True)
|
||||
with open(tmp_path, "w") as f:
|
||||
f.write(config_yaml)
|
||||
os.replace(tmp_path, config_path)
|
||||
logger.info("Wrote Mihomo config to %s", config_path)
|
||||
await mihomo_client.reload_config()
|
||||
logger.info("Mihomo config reloaded")
|
||||
|
||||
|
||||
@asynccontextmanager
|
||||
async def lifespan(app: FastAPI): # type: ignore[type-arg]
|
||||
os.makedirs(MIHOMO_CONFIG_DIR, exist_ok=True)
|
||||
db_path = DATABASE_URL.split("///")[-1]
|
||||
if db_path:
|
||||
os.makedirs(os.path.dirname(db_path), exist_ok=True)
|
||||
|
||||
async with engine.begin() as conn:
|
||||
await conn.run_sync(Base.metadata.create_all)
|
||||
|
||||
await mihomo_client.wait_ready()
|
||||
|
||||
async with SessionLocal() as db:
|
||||
await write_and_reload_mihomo(db)
|
||||
|
||||
yield
|
||||
|
||||
|
||||
app = FastAPI(lifespan=lifespan)
|
||||
|
||||
|
||||
@app.get("/config/{token}.yaml")
|
||||
async def get_config(
|
||||
token: str,
|
||||
db: AsyncSession = Depends(get_db),
|
||||
) -> Response:
|
||||
result = await db.execute(select(Config).where(Config.token == token))
|
||||
config = result.scalar_one_or_none()
|
||||
if not config:
|
||||
raise HTTPException(status_code=404, detail="Config not found")
|
||||
|
||||
result = await db.execute(
|
||||
select(Subscription).where(Subscription.config_id == config.id)
|
||||
)
|
||||
subscriptions = result.scalars().all()
|
||||
|
||||
provider_proxies: dict[str, list[dict]] = {}
|
||||
errors: list[str] = []
|
||||
|
||||
for sub in subscriptions:
|
||||
try:
|
||||
proxies = await mihomo_client.refresh_and_collect(sub.name, timeout=30)
|
||||
provider_proxies[sub.name] = proxies
|
||||
sub.last_fetched_at = datetime.utcnow()
|
||||
except Exception as exc:
|
||||
logger.error("Failed to refresh provider %s: %s", sub.name, exc)
|
||||
errors.append(f"{sub.name}: {exc}")
|
||||
|
||||
try:
|
||||
expanded = expand_config(config.base_yaml, provider_proxies)
|
||||
except Exception as exc:
|
||||
logger.error("Config expansion failed for token %s: %s", token, exc)
|
||||
db.add(
|
||||
ExportLog(
|
||||
config_id=config.id,
|
||||
node_count=0,
|
||||
success=False,
|
||||
error_message=str(exc),
|
||||
)
|
||||
)
|
||||
await db.commit()
|
||||
raise HTTPException(status_code=500, detail=f"Config expansion failed: {exc}")
|
||||
|
||||
node_count = sum(len(p) for p in provider_proxies.values())
|
||||
error_msg = "; ".join(errors) if errors else None
|
||||
db.add(
|
||||
ExportLog(
|
||||
config_id=config.id,
|
||||
node_count=node_count,
|
||||
success=not bool(errors),
|
||||
error_message=error_msg,
|
||||
)
|
||||
)
|
||||
await db.commit()
|
||||
|
||||
return Response(content=expanded, media_type="application/x-yaml")
|
||||
@@ -0,0 +1,130 @@
|
||||
import uuid
|
||||
import pytest
|
||||
import pytest_asyncio
|
||||
from unittest.mock import AsyncMock, patch
|
||||
from httpx import AsyncClient, ASGITransport
|
||||
from sqlalchemy.ext.asyncio import create_async_engine, async_sessionmaker
|
||||
|
||||
from models import Base, Config, Subscription, ExportLog
|
||||
from main import app, get_db
|
||||
|
||||
|
||||
@pytest_asyncio.fixture
|
||||
async def db_engine():
|
||||
engine = create_async_engine("sqlite+aiosqlite:///:memory:")
|
||||
async with engine.begin() as conn:
|
||||
await conn.run_sync(Base.metadata.create_all)
|
||||
yield engine
|
||||
await engine.dispose()
|
||||
|
||||
|
||||
@pytest_asyncio.fixture
|
||||
async def db_session(db_engine):
|
||||
Session = async_sessionmaker(db_engine, expire_on_commit=False)
|
||||
async with Session() as session:
|
||||
yield session
|
||||
|
||||
|
||||
@pytest_asyncio.fixture
|
||||
async def http_client(db_session):
|
||||
async def override_get_db():
|
||||
yield db_session
|
||||
|
||||
app.dependency_overrides[get_db] = override_get_db
|
||||
|
||||
async with AsyncClient(
|
||||
transport=ASGITransport(app=app), base_url="http://test"
|
||||
) as client:
|
||||
yield client
|
||||
|
||||
app.dependency_overrides.clear()
|
||||
|
||||
|
||||
async def test_get_config_not_found(http_client):
|
||||
resp = await http_client.get("/config/nonexistent-token.yaml")
|
||||
assert resp.status_code == 404
|
||||
|
||||
|
||||
async def test_get_config_returns_yaml(http_client, db_session):
|
||||
token = str(uuid.uuid4())
|
||||
config = Config(
|
||||
name="test",
|
||||
token=token,
|
||||
base_yaml="proxies: []\nproxy-groups: []\nrules:\n - MATCH,DIRECT\n",
|
||||
)
|
||||
db_session.add(config)
|
||||
await db_session.commit()
|
||||
|
||||
with patch("main.mihomo_client") as mock_mc:
|
||||
mock_mc.refresh_and_collect = AsyncMock(return_value=[])
|
||||
resp = await http_client.get(f"/config/{token}.yaml")
|
||||
|
||||
assert resp.status_code == 200
|
||||
assert "proxies" in resp.text
|
||||
|
||||
|
||||
async def test_get_config_writes_export_log(http_client, db_session):
|
||||
from sqlalchemy import select
|
||||
|
||||
token = str(uuid.uuid4())
|
||||
config = Config(
|
||||
name="test",
|
||||
token=token,
|
||||
base_yaml="proxies: []\nproxy-groups: []\nrules: []\n",
|
||||
)
|
||||
db_session.add(config)
|
||||
await db_session.commit()
|
||||
|
||||
with patch("main.mihomo_client") as mock_mc:
|
||||
mock_mc.refresh_and_collect = AsyncMock(return_value=[])
|
||||
await http_client.get(f"/config/{token}.yaml")
|
||||
|
||||
result = await db_session.execute(
|
||||
select(ExportLog).where(ExportLog.config_id == config.id)
|
||||
)
|
||||
logs = result.scalars().all()
|
||||
assert len(logs) == 1
|
||||
assert logs[0].success is True
|
||||
|
||||
|
||||
async def test_get_config_with_subscription_expands_nodes(http_client, db_session):
|
||||
token = str(uuid.uuid4())
|
||||
config = Config(
|
||||
name="test",
|
||||
token=token,
|
||||
base_yaml=(
|
||||
"proxies: []\n"
|
||||
"proxy-providers:\n"
|
||||
" myprovider:\n"
|
||||
" type: http\n"
|
||||
" url: https://example.com/sub\n"
|
||||
" interval: 3600\n"
|
||||
"proxy-groups:\n"
|
||||
" - name: Proxy\n"
|
||||
" type: select\n"
|
||||
" use:\n"
|
||||
" - myprovider\n"
|
||||
"rules:\n"
|
||||
" - MATCH,DIRECT\n"
|
||||
),
|
||||
)
|
||||
db_session.add(config)
|
||||
await db_session.flush()
|
||||
|
||||
sub = Subscription(config_id=config.id, name="myprovider", url="https://example.com/sub")
|
||||
db_session.add(sub)
|
||||
await db_session.commit()
|
||||
|
||||
fake_proxies = [
|
||||
{"name": "node1", "type": "ss", "server": "1.2.3.4", "port": 443,
|
||||
"password": "pwd", "cipher": "aes-256-gcm", "alive": True},
|
||||
]
|
||||
|
||||
with patch("main.mihomo_client") as mock_mc:
|
||||
mock_mc.refresh_and_collect = AsyncMock(return_value=fake_proxies)
|
||||
resp = await http_client.get(f"/config/{token}.yaml")
|
||||
|
||||
assert resp.status_code == 200
|
||||
assert "node1" in resp.text
|
||||
assert "proxy-providers" not in resp.text
|
||||
assert "alive" not in resp.text
|
||||
Reference in New Issue
Block a user