Coverage for backend/hello.py: 74%
103 statements
« prev ^ index » next coverage.py v7.6.12, created at 2025-05-02 01:42 +0000
1import asyncio
2import time
3import typing
4from collections.abc import AsyncGenerator, Awaitable, Callable
5from contextlib import asynccontextmanager
7import structlog
8from asgi_correlation_id import CorrelationIdMiddleware, correlation_id
9from fastapi import Depends, FastAPI, HTTPException, Request, Response, WebSocket
10from fastapi.middleware.cors import CORSMiddleware
11from prometheus_fastapi_instrumentator import Instrumentator
12from pydantic import ValidationError
13from uvicorn.protocols.utils import get_path_with_query_string
15from .dependencies import Channel, LobbyManager, connection_manager, lobby_manager, settings
16from .game_state import LobbyClosedError, LobbyFullError, LobbyNotFoundError, TaggedMessage
17from .logging_config import setup_logging
18from .models import Annotated, Initializer, LobbyJoinRequest, Message
# Dedicated "api.access" logger; logging_middleware emits one entry per HTTP request on it.
access_logger = structlog.stdlib.get_logger("api.access")
@asynccontextmanager
async def lifespan(_: FastAPI) -> AsyncGenerator[None, None]:
    """Configure logging and, in the demo environment, seed a lobby.

    Used as the FastAPI ``lifespan`` hook: everything before ``yield`` runs at
    startup. Logging is configured first so that anything logged while the
    demo lobby is being registered already goes through the configured
    logging setup.
    """
    s = settings()

    # JSON logs only in the "prod" environment; every other env gets json_logs=False.
    setup_logging(json_logs=s.env == "prod", log_level=s.log_level)

    if s.env == "demo":
        # Pre-register a lobby for the demo environment.
        lobby_manager().register_lobby()

    yield
# Application instance; `lifespan` performs startup configuration (logging and,
# in the demo environment, a pre-registered lobby).
app = FastAPI(lifespan=lifespan)

# CORS: only the configured frontend origin may call the API from a browser,
# with credentials allowed and any method/header accepted.
app.add_middleware(
    CORSMiddleware,
    allow_origins=[settings().frontend_url],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)

# Prometheus instrumentation: record request metrics under the "fastapi"
# namespace and expose an endpoint serving them, kept out of the OpenAPI schema.
Instrumentator().instrument(app, metric_namespace="fastapi").expose(app, include_in_schema=False)
@app.middleware("http")
async def logging_middleware(request: Request, call_next: Callable[[Request], Awaitable[Response]]) -> Response:
    """Add log info to all requests and emit a structured access log entry.

    Binds the request's correlation ID into structlog's context vars, so every
    log entry emitted while handling the request carries it. Times the
    downstream handler and adds an ``X-Process-Time`` header (in seconds).

    An ``Exception`` raised downstream is logged and replaced by a bare 500
    response, so the access log entry and timing header are still produced.
    Other ``BaseException``s (notably ``asyncio.CancelledError``) propagate
    instead of being swallowed.
    """
    structlog.contextvars.clear_contextvars()
    # These context vars will be added to all log entries emitted during the request
    request_id = correlation_id.get()
    structlog.contextvars.bind_contextvars(request_id=request_id)

    start_time = time.perf_counter_ns()
    try:
        response = await call_next(request)
    except Exception:
        structlog.stdlib.get_logger("api.error").exception("Uncaught exception")
        # Return our own 500 so we can still add headers (process time,
        # request ID...) and write the access log entry. Deliberately not a
        # return-in-finally: that would also swallow task cancellation.
        response = Response(status_code=500)
    process_time = time.perf_counter_ns() - start_time  # nanoseconds

    status_code = response.status_code
    url = get_path_with_query_string(request.scope)  # pyright: ignore[reportArgumentType]
    client_host = request.client.host if request.client else ""
    client_port = request.client.port if request.client else ""
    http_method = request.method
    http_version = request.scope["http_version"]
    # Recreate the Uvicorn access log format, but add all parameters as
    # structured information
    access_logger.info(
        f"""{client_host}:{client_port} - "{http_method} {url} HTTP/{http_version}" {status_code}""",
        http={
            "url": str(request.url),
            "status_code": status_code,
            "method": http_method,
            "request_id": request_id,
            "version": http_version,
        },
        network={"client": {"ip": client_host, "port": client_port}},
        duration=process_time,
    )
    response.headers["X-Process-Time"] = str(process_time / 10**9)
    return response
# Registered after logging_middleware on purpose: logging_middleware reads
# correlation_id.get(). NOTE(review): Starlette treats the most recently added
# middleware as the outermost one, so this presumably runs first and sets the ID
# before logging_middleware reads it — keep this ordering; confirm.
app.add_middleware(CorrelationIdMiddleware)

# Module-level logger used by the WebSocket handlers below.
logger = structlog.stdlib.get_logger()
@app.post("/lobby")
def create_lobby(lm: Annotated[LobbyManager, Depends(lobby_manager)]) -> dict[str, tuple[str, ...]]:
    """Register a brand-new lobby and return its join code under the "code" key."""
    return {"code": lm.register_lobby()}
@app.post("/lobby/join")
def join_lobby(req: LobbyJoinRequest, lm: Annotated[LobbyManager, Depends(lobby_manager)]) -> dict[str, str]:
    """Joins a player to a lobby using a list of ingredients as the code.

    Returns:
        ``{"id": <player id>}`` as returned by ``LobbyManager.register_player``.

    Raises:
        HTTPException: 403 if the lobby is full or has already started,
            404 if no lobby matches ``req.code``.
    """
    try:
        player_id = lm.register_player(req.code)
    except LobbyFullError as err:
        # Chain the domain error so server-side tracebacks show the real cause.
        raise HTTPException(status_code=403, detail="Lobby is full") from err
    except LobbyNotFoundError as err:
        raise HTTPException(status_code=404, detail="Lobby not found") from err
    except LobbyClosedError as err:
        raise HTTPException(status_code=403, detail="Lobby has already started") from err

    return {"id": player_id}
@app.get("/")
async def read_root() -> dict:
    """Answer requests to ``/`` with a fixed greeting payload."""
    greeting = {"Hello": "World"}
    return greeting
132@app.websocket("/ws")
133async def websocket_endpoint(websocket: WebSocket, lm: Annotated[LobbyManager, Depends(lobby_manager)]) -> None:
134 """Handles a WebSocket connection for receiving and responding to messages."""
135 async with connection_manager(websocket):
136 try:
137 init = Message.model_validate(await websocket.receive_json())
138 except ValidationError:
139 logger.debug("Received invalid websocket message.")
140 return
142 match init:
143 case Message(data=Initializer(code=code, id=id)):
144 channel = lm.channel(code, id)
145 logger.info("WebSocket connection initialized.", player=id)
146 case _:
147 logger.info("WebSocket connection failed to initialize.", message=init)
148 return
150 await asyncio.gather(_recv_handler(id, websocket, channel), _send_handler(id, websocket, channel))
async def _recv_handler(id: str, websocket: WebSocket, channel: Channel[TaggedMessage, Message]) -> typing.Never:
    """Forward every valid text frame from the socket into the player's channel.

    Each frame is parsed as a ``Message``. Frames that fail validation are
    logged and skipped. Valid ones are re-wrapped as a ``TaggedMessage``
    stamped with the sender's ``id``. Loops forever and only exits by raising.
    """
    while True:
        text = await websocket.receive_text()
        try:
            incoming = Message.model_validate_json(text)
        except ValidationError:
            logger.warning("Received invalid websocket message.", data=text)
            continue
        logger.debug("Received WebSocket message.", message=incoming, player=id)
        channel.send(TaggedMessage(data=incoming.data, id=id))
async def _send_handler(id: str, websocket: WebSocket, channel: Channel[TaggedMessage, Message]) -> typing.Never:
    """Relay messages from the player's channel out over the socket as JSON text.

    Loops forever and only exits by raising.
    """
    while True:
        outgoing = await channel.arecv()
        logger.debug("Sending WebSocket message.", message=outgoing, player=id)
        payload = outgoing.model_dump_json()
        await websocket.send_text(payload)