Coverage for backend/hello.py: 33%

103 statements  

« prev     ^ index     » next       coverage.py v7.6.12, created at 2025-04-17 17:55 +0000

1import asyncio 

2import time 

3import typing 

4from collections.abc import AsyncGenerator, Awaitable, Callable 

5from contextlib import asynccontextmanager 

6 

7import structlog 

8from asgi_correlation_id import CorrelationIdMiddleware, correlation_id 

9from fastapi import Depends, FastAPI, HTTPException, Request, Response, WebSocket 

10from fastapi.middleware.cors import CORSMiddleware 

11from pydantic import ValidationError 

12from uvicorn.protocols.utils import get_path_with_query_string 

13 

14from .dependencies import Channel, LobbyManager, lobby_manager, settings 

15from .game_state import LobbyClosedError, LobbyFullError, LobbyNotFoundError, TaggedMessage 

16from .logging_config import setup_logging 

17from .models import Annotated, Initializer, LobbyJoinRequest, Message 

18 

# Dedicated logger for the per-request access lines emitted by logging_middleware.
access_logger = structlog.stdlib.get_logger("api.access")

20 

21 

@asynccontextmanager
async def lifespan(_: FastAPI) -> AsyncGenerator[None, None]:
    """Configure logging and, in the demo environment, pre-register a lobby.

    Everything before ``yield`` runs once at application startup; nothing is
    done at shutdown.
    """
    s = settings()

    # Configure logging first, so anything logged while preparing the demo
    # lobby already goes through the configured handlers. JSON output is only
    # enabled in the "prod" environment.
    setup_logging(json_logs=s.env == "prod", log_level=s.log_level)

    if s.env == "demo":
        # The returned lobby code is not needed here; the lobby just has to exist.
        lobby_manager().register_lobby()

    yield

38 

39 

app = FastAPI(lifespan=lifespan)

# Allow the configured frontend origin to call the API from a browser,
# including credentialed requests, with any HTTP method and any header.
app.add_middleware(
    CORSMiddleware,
    allow_headers=["*"],
    allow_methods=["*"],
    allow_credentials=True,
    allow_origins=[settings().frontend_url],
)

49 

50 

@app.middleware("http")
async def logging_middleware(request: Request, call_next: Callable[[Request], Awaitable[Response]]) -> Response:
    """Bind the request ID to the log context, time the request and emit an access log.

    Every response gets an ``X-Process-Time`` header (seconds). If a downstream
    handler raises an ``Exception``, it is logged and replaced by a bare 500
    response, so the header and access log line are still produced.
    ``BaseException``s such as ``asyncio.CancelledError`` are not intercepted
    and propagate unchanged (no access line is logged for them).
    """
    structlog.contextvars.clear_contextvars()
    # These context vars will be added to all log entries emitted during the request
    request_id = correlation_id.get()
    structlog.contextvars.bind_contextvars(request_id=request_id)

    start_time = time.perf_counter_ns()
    try:
        response = await call_next(request)
    except Exception:
        structlog.stdlib.get_logger("api.error").exception("Uncaught exception")
        # Deliberately answer with our own 500 so headers (process time, request ID...)
        # can still be added. This is done here rather than by returning from a
        # `finally` block, which would also swallow cancellation.
        response = Response(status_code=500)
    process_time = time.perf_counter_ns() - start_time  # nanoseconds

    status_code = response.status_code
    url = get_path_with_query_string(request.scope)  # pyright: ignore[reportArgumentType]
    client_host = request.client.host if request.client else ""
    client_port = request.client.port if request.client else ""
    http_method = request.method
    http_version = request.scope["http_version"]
    # Recreate the Uvicorn access log format, but add all parameters as
    # structured information
    access_logger.info(
        f"""{client_host}:{client_port} - "{http_method} {url} HTTP/{http_version}" {status_code}""",
        http={
            "url": str(request.url),
            "status_code": status_code,
            "method": http_method,
            "request_id": request_id,
            "version": http_version,
        },
        network={"client": {"ip": client_host, "port": client_port}},
        duration=process_time,
    )
    response.headers["X-Process-Time"] = str(process_time / 10**9)
    return response

93 

94 

# Registered after logging_middleware on purpose: Starlette wraps later-added
# middleware around earlier ones, so this one runs first and logging_middleware
# can read the request's correlation_id.
app.add_middleware(CorrelationIdMiddleware)

# Module-level logger used by the WebSocket endpoint and its handlers below.
logger = structlog.stdlib.get_logger()

98 

99 

@app.post("/lobby")
def create_lobby(lm: Annotated[LobbyManager, Depends(lobby_manager)]) -> dict[str, tuple[str, ...]]:
    """Register a fresh lobby and return its join code under the ``code`` key."""
    return {"code": lm.register_lobby()}

106 

107 

@app.post("/lobby/join")
def join_lobby(req: LobbyJoinRequest, lm: Annotated[LobbyManager, Depends(lobby_manager)]) -> dict[str, str]:
    """Joins a player to a lobby using a list of ingredients as the code.

    Returns:
        ``{"id": player_id}`` for the newly registered player.

    Raises:
        HTTPException: 404 if no lobby matches ``req.code``; 403 if the lobby
            is full or has already started.
    """
    try:
        player_id = lm.register_player(req.code)
    except LobbyFullError as err:
        raise HTTPException(status_code=403, detail="Lobby is full") from err
    except LobbyNotFoundError as err:
        raise HTTPException(status_code=404, detail="Lobby not found") from err
    except LobbyClosedError as err:
        raise HTTPException(status_code=403, detail="Lobby has already started") from err

    return {"id": player_id}

121 

122 

@app.get("/")
async def read_root() -> dict:
    """Respond to the root path with a fixed greeting payload."""
    greeting = {"Hello": "World"}
    return greeting

127 

128 

129@app.websocket("/ws") 

130async def websocket_endpoint(websocket: WebSocket, lm: Annotated[LobbyManager, Depends(lobby_manager)]) -> None: 

131 """Handles a WebSocket connection for receiving and responding to messages.""" 

132 await websocket.accept() 

133 

134 try: 

135 init = Message.model_validate(await websocket.receive_json()) 

136 except ValidationError: 

137 logger.debug("Received invalid websocket message.") 

138 await websocket.close() 

139 return 

140 

141 match init: 

142 case Message(data=Initializer(code=code, id=id)): 

143 channel = lm.channel(code, id) 

144 logger.info("WebSocket connection initialized.", player=id) 

145 case _: 

146 logger.info("WebSocket connection failed to initialize.", message=init) 

147 await websocket.close() 

148 return 

149 

150 await asyncio.gather(_recv_handler(id, websocket, channel), _send_handler(id, websocket, channel)) 

151 

152 

async def _recv_handler(id: str, websocket: WebSocket, channel: Channel[TaggedMessage, Message]) -> typing.Never:
    """Forward client text frames into the lobby channel, tagged with the sender's id.

    Frames that do not parse as a ``Message`` are logged and skipped. The loop
    never returns; it ends only when receiving from the socket raises.
    """
    while True:
        payload = await websocket.receive_text()
        try:
            incoming = Message.model_validate_json(payload)
        except ValidationError:
            logger.warning("Received invalid websocket message.", data=payload)
            continue
        logger.debug("Received WebSocket message.", message=incoming, player=id)
        channel.send(TaggedMessage(data=incoming.data, id=id))

163 

164 

async def _send_handler(id: str, websocket: WebSocket, channel: Channel[TaggedMessage, Message]) -> typing.Never:
    """Relay every message received from the lobby channel to the client as JSON text.

    The loop never returns; it ends only when receiving or sending raises.
    """
    while True:
        outgoing = await channel.arecv()
        logger.debug("Sending WebSocket message.", message=outgoing, player=id)
        serialized = outgoing.model_dump_json()
        await websocket.send_text(serialized)