Coverage for backend/hello.py: 74%

103 statements  

« prev     ^ index     » next       coverage.py v7.6.12, created at 2025-05-02 01:42 +0000

1import asyncio 

2import time 

3import typing 

4from collections.abc import AsyncGenerator, Awaitable, Callable 

5from contextlib import asynccontextmanager 

6 

7import structlog 

8from asgi_correlation_id import CorrelationIdMiddleware, correlation_id 

9from fastapi import Depends, FastAPI, HTTPException, Request, Response, WebSocket 

10from fastapi.middleware.cors import CORSMiddleware 

11from prometheus_fastapi_instrumentator import Instrumentator 

12from pydantic import ValidationError 

13from uvicorn.protocols.utils import get_path_with_query_string 

14 

15from .dependencies import Channel, LobbyManager, connection_manager, lobby_manager, settings 

16from .game_state import LobbyClosedError, LobbyFullError, LobbyNotFoundError, TaggedMessage 

17from .logging_config import setup_logging 

18from .models import Annotated, Initializer, LobbyJoinRequest, Message 

19 

access_logger = structlog.stdlib.get_logger("api.access")


@asynccontextmanager
async def lifespan(_: FastAPI) -> AsyncGenerator:
    """Application startup hook: prepare demo data, then configure logging.

    When the environment is ``demo`` a lobby is registered before the app
    starts serving. Logging is then configured, with JSON output only when the
    environment is ``prod``. Nothing is torn down after ``yield``.
    """
    config = settings()
    if config.env == "demo":
        lobby_manager().register_lobby()

    use_json = config.env == "prod"
    setup_logging(json_logs=use_json, log_level=config.log_level)

    yield


app = FastAPI(lifespan=lifespan)

# Only the configured frontend origin may make (credentialed) cross-origin calls.
# NOTE: middleware registration order matters; keep these statements in order.
app.add_middleware(
    CORSMiddleware,
    allow_origins=[settings().frontend_url],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)

# Collect Prometheus request metrics under the "fastapi" namespace and expose
# them on the app, hidden from the OpenAPI schema.
Instrumentator().instrument(app, metric_namespace="fastapi").expose(app, include_in_schema=False)

52 

53 

@app.middleware("http")
async def logging_middleware(request: Request, call_next: Callable[[Request], Awaitable[Response]]) -> Response:
    """Emit a structured access log for every HTTP request and time it.

    The request's correlation ID is bound into structlog's context vars, so
    every log entry emitted while handling the request carries it. The
    elapsed wall time is logged (in nanoseconds, as ``duration``) and returned
    to the client in seconds via the ``X-Process-Time`` header.

    If the downstream handler raises an ``Exception``, its traceback is logged
    to ``api.error`` and a bare 500 response is returned instead, so the access
    log and timing header still apply. ``BaseException``s (e.g.
    ``asyncio.CancelledError``) are not intercepted and propagate unchanged.
    """
    structlog.contextvars.clear_contextvars()
    # These context vars will be added to all log entries emitted during the request
    request_id = correlation_id.get()
    structlog.contextvars.bind_contextvars(request_id=request_id)

    start_time = time.perf_counter_ns()
    try:
        response = await call_next(request)
    except Exception:
        structlog.stdlib.get_logger("api.error").exception("Uncaught exception")
        # Substitute our own 500 so headers and the access log can still be
        # added. This is done here rather than with `return` inside `finally`,
        # which would also swallow BaseExceptions such as cancellation.
        response = Response(status_code=500)
    process_time = time.perf_counter_ns() - start_time

    status_code = response.status_code
    url = get_path_with_query_string(request.scope)  # pyright: ignore[reportArgumentType]
    client_host = request.client.host if request.client else ""
    client_port = request.client.port if request.client else ""
    http_method = request.method
    http_version = request.scope["http_version"]
    # Recreate the Uvicorn access log format, but add all parameters as
    # structured information
    access_logger.info(
        f"""{client_host}:{client_port} - "{http_method} {url} HTTP/{http_version}" {status_code}""",
        http={
            "url": str(request.url),
            "status_code": status_code,
            "method": http_method,
            "request_id": request_id,
            "version": http_version,
        },
        network={"client": {"ip": client_host, "port": client_port}},
        duration=process_time,
    )
    response.headers["X-Process-Time"] = str(process_time / 10**9)
    return response


# Registered after logging_middleware on purpose: Starlette makes the most
# recently added middleware the outermost layer, so the correlation ID should
# already be set when logging_middleware calls correlation_id.get().
# Keep this order if middleware registration is changed.
app.add_middleware(CorrelationIdMiddleware)

logger = structlog.stdlib.get_logger()

101 

102 

@app.post("/lobby")
def create_lobby(lm: Annotated[LobbyManager, Depends(lobby_manager)]) -> dict[str, tuple[str, ...]]:
    """Register a fresh lobby and return its join code under the ``"code"`` key."""
    return {"code": lm.register_lobby()}

109 

110 

@app.post("/lobby/join")
def join_lobby(req: LobbyJoinRequest, lm: Annotated[LobbyManager, Depends(lobby_manager)]) -> dict[str, str]:
    """Join a player to a lobby using a list of ingredients as the code.

    Returns:
        ``{"id": <player id>}`` as returned by ``LobbyManager.register_player``.

    Raises:
        HTTPException: 404 if no lobby matches ``req.code``; 403 if the lobby
            is full or has already started.
    """
    try:
        player_id = lm.register_player(req.code)
    # Chain with `from err` so tracebacks show the domain error as the
    # explicit cause rather than as an unrelated "during handling" exception.
    except LobbyFullError as err:
        raise HTTPException(status_code=403, detail="Lobby is full") from err
    except LobbyNotFoundError as err:
        raise HTTPException(status_code=404, detail="Lobby not found") from err
    except LobbyClosedError as err:
        raise HTTPException(status_code=403, detail="Lobby has already started") from err

    return {"id": player_id}

124 

125 

@app.get("/")
async def read_root() -> dict:
    """Respond to ``GET /`` with a fixed ``{"Hello": "World"}`` payload."""
    greeting = {"Hello": "World"}
    return greeting

130 

131 

132@app.websocket("/ws") 

133async def websocket_endpoint(websocket: WebSocket, lm: Annotated[LobbyManager, Depends(lobby_manager)]) -> None: 

134 """Handles a WebSocket connection for receiving and responding to messages.""" 

135 async with connection_manager(websocket): 

136 try: 

137 init = Message.model_validate(await websocket.receive_json()) 

138 except ValidationError: 

139 logger.debug("Received invalid websocket message.") 

140 return 

141 

142 match init: 

143 case Message(data=Initializer(code=code, id=id)): 

144 channel = lm.channel(code, id) 

145 logger.info("WebSocket connection initialized.", player=id) 

146 case _: 

147 logger.info("WebSocket connection failed to initialize.", message=init) 

148 return 

149 

150 await asyncio.gather(_recv_handler(id, websocket, channel), _send_handler(id, websocket, channel)) 

151 

152 

async def _recv_handler(id: str, websocket: WebSocket, channel: Channel[TaggedMessage, Message]) -> typing.Never:
    """Pump client text frames into the lobby channel, tagged with the sender ``id``.

    Frames that fail ``Message`` validation are logged and skipped. The loop
    never returns; it ends only when receiving from the socket raises.
    """
    while True:
        text = await websocket.receive_text()
        try:
            incoming = Message.model_validate_json(text)
        except ValidationError:
            logger.warning("Received invalid websocket message.", data=text)
            continue
        logger.debug("Received WebSocket message.", message=incoming, player=id)
        channel.send(TaggedMessage(data=incoming.data, id=id))

163 

164 

async def _send_handler(id: str, websocket: WebSocket, channel: Channel[TaggedMessage, Message]) -> typing.Never:
    """Relay each message the lobby channel yields for player ``id`` to the socket as JSON text.

    The loop never returns; it ends only when receiving from the channel or
    sending on the socket raises.
    """
    while True:
        outgoing = await channel.arecv()
        logger.debug("Sending WebSocket message.", message=outgoing, player=id)
        payload = outgoing.model_dump_json()
        await websocket.send_text(payload)