Coverage for tests/services/test_suggestion_service.py: 100%

317 statements  

« prev     ^ index     » next       coverage.py v7.6.12, created at 2025-05-02 02:49 +0000

1import pytest 

2import requests 

3import json 

4from unittest.mock import MagicMock, patch, ANY 

5 

6from app.services import suggestion_service 

7from app.controllers.ai import ( 

8 vendors, 

9 default_openai_parameters, 

10 OLLAMA_URL, 

11 good_command, 

12 bad_command, 

13) 

14from app.models.errors import ModelError, DatabaseError 

15from app.models.user import Suggestion 

16from flask import Flask 

17import logging 

18 

19 

@pytest.fixture
def mock_openai_client(mocker):
    """Patch ``suggestion_service.openai_client`` with a configured MagicMock.

    ``chat.completions.create`` returns a completion whose single choice has
    message content ``"openai_suggestion"``. ``models.list`` returns a
    listing whose ``data`` holds models with ids ``"gpt-4o-mini"`` and
    ``"gpt-3.5-turbo"``.

    Returns:
        The MagicMock installed in place of the OpenAI client.
    """
    message = MagicMock()
    message.content = "openai_suggestion"
    choice = MagicMock()
    choice.message = message
    completion = MagicMock()
    completion.choices = [choice]

    listing = MagicMock()
    listing.data = [
        MagicMock(id=model_id) for model_id in ("gpt-4o-mini", "gpt-3.5-turbo")
    ]

    # Intermediate attributes (chat, completions, models) are auto-created
    # child mocks; only the leaves need configuring.
    client = MagicMock()
    client.chat.completions.create.return_value = completion
    client.models.list.return_value = listing

    mocker.patch("app.services.suggestion_service.openai_client", client)
    return client

48 

49 

@pytest.fixture
def mock_gemini_client(mocker):
    """Patch ``suggestion_service.gemini_client`` with a configured MagicMock.

    ``chat_session.send_message`` returns a response whose ``text`` is the
    JSON list ``["+ b;", " - b;"]`` and whose ``usage_metadata`` reports
    10 prompt tokens, 5 candidate tokens and 15 total tokens.

    ``models.list`` returns three model mocks. Only the first
    (``models/gemini-1.5-flash``) and third (``models/gemini-2.0-flash``)
    include ``"generateContent"`` in their ``supported_actions``.

    Returns:
        The MagicMock installed in place of the Gemini client.
    """
    mock_client = MagicMock()
    mock_session = MagicMock()
    mock_response = MagicMock()

    mock_response.text = json.dumps(["+ b;", " - b;"])
    mock_usage = MagicMock()
    mock_usage.prompt_token_count = 10
    mock_usage.candidates_token_count = 5
    mock_usage.total_token_count = 15
    mock_response.usage_metadata = mock_usage

    mock_session.send_message.return_value = mock_response
    mock_client.chat_session = mock_session

    # ``name`` is a reserved MagicMock constructor argument: it only sets the
    # mock's repr, so ``MagicMock(name="x").name`` is a child mock, not "x".
    # Assign the ``.name`` attribute after construction so it is a real str.
    model1 = MagicMock(supported_actions=["generateContent"])
    model1.name = "models/gemini-1.5-flash"
    model2 = MagicMock(supported_actions=["embedContent"])
    model2.name = "models/gemini-pro"
    model3 = MagicMock(supported_actions=["generateContent", "other"])
    model3.name = "models/gemini-2.0-flash"

    mock_models = MagicMock()
    mock_models.list.return_value = [model1, model2, model3]
    mock_client.models = mock_models

    mocker.patch("app.services.suggestion_service.gemini_client", mock_client)
    return mock_client

80 

81 

@pytest.fixture
def mock_db_client(mocker):
    """Patch ``suggestion_service.client`` with a MagicMock Supabase client.

    Two query chains are wired up:

    * ``table(...).select(...).eq(...).single().execute()`` returns a
      dedicated mock, which is the second element of the returned tuple, so
      tests can set its ``.data``.
    * ``table(...).insert(...).execute()`` returns a separate mock.

    Returns:
        A ``(client, select_execute_result)`` tuple.
    """
    select_result = MagicMock()
    insert_result = MagicMock()

    single_query = MagicMock()
    single_query.execute.return_value = select_result
    eq_query = MagicMock()
    eq_query.single.return_value = single_query
    select_query = MagicMock()
    select_query.eq.return_value = eq_query

    insert_query = MagicMock()
    insert_query.execute.return_value = insert_result

    table = MagicMock()
    table.select.return_value = select_query
    table.insert.return_value = insert_query

    client = MagicMock()
    client.table.return_value = table

    mocker.patch("app.services.suggestion_service.client", client)
    return client, select_result

105 

106 

@pytest.fixture
def mock_requests(mocker):
    """Patch ``requests.post`` and ``requests.get`` with canned responses.

    ``post`` returns a response whose ``json()`` is
    ``{"response": "ollama_suggestion"}``. ``get`` returns a response whose
    ``json()`` lists the models ``codellama:latest`` and ``mistral:latest``.
    ``raise_for_status`` returns None on both responses.

    Returns:
        ``(post_mock, get_mock, post_response, get_response)``.
    """
    post_response = MagicMock()
    post_response.raise_for_status.return_value = None
    post_response.json.return_value = {"response": "ollama_suggestion"}

    get_response = MagicMock()
    get_response.raise_for_status.return_value = None
    get_response.json.return_value = {
        "models": [{"name": "codellama:latest"}, {"name": "mistral:latest"}]
    }

    post = mocker.patch("requests.post", return_value=post_response)
    get = mocker.patch("requests.get", return_value=get_response)
    return post, get, post_response, get_response

126 

127 

def test_getSuggestion_routes_to_openai(mocker):
    """The OpenAI vendor is dispatched to getSuggestionFromOpenAI."""
    openai_handler = mocker.patch.object(
        suggestion_service, "getSuggestionFromOpenAI", return_value="test"
    )

    suggestion_service.getSuggestion(
        "prompt", vendor=vendors.OpenAI, model_name="gpt-test"
    )

    openai_handler.assert_called_once_with(prompt="prompt", model="gpt-test")

136 

137 

def test_getSuggestion_routes_to_ollama(mocker):
    """The Ollama vendor is dispatched to getSuggestionFromOllama."""
    ollama_handler = mocker.patch.object(
        suggestion_service, "getSuggestionFromOllama", return_value="test"
    )

    suggestion_service.getSuggestion(
        "prompt", vendor=vendors.Ollama, model_name="codellama"
    )

    ollama_handler.assert_called_once_with(prompt="prompt", model_name="codellama")

146 

147 

def test_getSuggestion_routes_to_google(mocker):
    """The Google vendor is dispatched to getSuggestionFromGoogle."""
    google_handler = mocker.patch.object(
        suggestion_service, "getSuggestionFromGoogle", return_value=["test"]
    )

    suggestion_service.getSuggestion("prompt", vendor=vendors.Google)

    google_handler.assert_called_once_with(prompt="prompt")

154 

155 

def test_getSuggestion_defaults_to_google_on_invalid_vendor(mocker):
    """An unrecognised vendor string falls back to the Google handler."""
    google_handler = mocker.patch.object(
        suggestion_service, "getSuggestionFromGoogle", return_value=["test"]
    )

    suggestion_service.getSuggestion("prompt", vendor="InvalidVendor")

    google_handler.assert_called_once_with(prompt="prompt")

162 

163 

def test_getSuggestionFromOpenAI_success_correct(mock_openai_client, app):
    """Correct mode sends a system + user message and returns the reply text."""
    prompt = "def func("

    result = suggestion_service.getSuggestionFromOpenAI(
        prompt, model="gpt-4o-mini", is_correct=True
    )

    assert result == "openai_suggestion"
    sent = mock_openai_client.chat.completions.create.call_args.kwargs
    assert sent["model"] == "gpt-4o-mini"
    system_msg = sent["messages"][0]
    assert system_msg["role"] == "system"
    # Correct mode must not ask the model for a deliberately wrong answer.
    assert "slightly incorrect" not in system_msg["content"]
    user_msg = sent["messages"][1]
    assert user_msg["role"] == "user"
    assert user_msg["content"] == prompt

177 

178 

def test_getSuggestionFromOpenAI_success_incorrect(mock_openai_client, app):
    """Incorrect mode asks for a 'slightly incorrect' answer in the system prompt."""
    create = mock_openai_client.chat.completions.create
    create.return_value.choices[0].message.content = "openai_incorrect_suggestion"

    result = suggestion_service.getSuggestionFromOpenAI(
        "def func(", model="gpt-4o-mini", is_correct=False
    )

    assert result == "openai_incorrect_suggestion"
    system_content = create.call_args.kwargs["messages"][0]["content"]
    assert "slightly incorrect" in system_content

191 

192 

def test_getSuggestionFromOpenAI_api_error(mock_openai_client, app):
    """Any exception from the OpenAI client surfaces as a ModelError."""
    mock_openai_client.chat.completions.create.side_effect = Exception("API Down")
    expected_message = "Error generating suggestion using OpenAI's API"

    with pytest.raises(ModelError, match=expected_message):
        suggestion_service.getSuggestionFromOpenAI("prompt")

200 

201 

def test_getSuggestionFromOllama_success_correct(mock_requests):
    """Correct mode prefixes the prompt with ``good_command`` in the payload."""
    post = mock_requests[0]
    user_prompt, model = "import os", "codellama"

    result = suggestion_service.getSuggestionFromOllama(
        user_prompt, model_name=model, is_correct=True
    )

    assert result == "ollama_suggestion"
    post.assert_called_once_with(
        OLLAMA_URL,
        json={
            "model": model,
            "prompt": good_command + user_prompt,
            "keep_alive": "1h",
            "stream": False,
        },
    )

218 

219 

def test_getSuggestionFromOllama_success_incorrect(mock_requests):
    """Incorrect mode prefixes the prompt with ``bad_command`` in the payload."""
    post = mock_requests[0]
    user_prompt, model = "import os", "codellama"

    result = suggestion_service.getSuggestionFromOllama(
        user_prompt, model_name=model, is_correct=False
    )

    assert result == "ollama_suggestion"
    post.assert_called_once_with(
        OLLAMA_URL,
        json={
            "model": model,
            "prompt": bad_command + user_prompt,
            "keep_alive": "1h",
            "stream": False,
        },
    )

236 

237 

def test_getSuggestionFromOllama_request_error(mock_requests):
    """An HTTP error status is re-raised as a ModelError carrying the cause."""
    post_response = mock_requests[2]
    failure = requests.exceptions.RequestException("Connection Failed")
    post_response.raise_for_status.side_effect = failure

    with pytest.raises(
        ModelError, match="Error fetching Ollama suggestion: Connection Failed"
    ):
        suggestion_service.getSuggestionFromOllama("prompt", "model")

248 

249 

def test_getSuggestionFromOllama_other_error(mock_requests):
    """A non-HTTP exception from ``requests.post`` is also wrapped in ModelError."""
    post = mock_requests[0]
    post.side_effect = Exception("Unexpected error")

    with pytest.raises(
        ModelError, match="Error fetching Ollama suggestion: Unexpected error"
    ):
        suggestion_service.getSuggestionFromOllama("prompt", "model")

258 

259 

def test_getSuggestionFromGoogle_success(mock_gemini_client, mock_db_client, mocker):
    """A successful Gemini call returns the parsed list and records usage.

    ``time.time`` is patched to return 100.0 and then 100.5, so the
    recorded latency is exactly 0.5 seconds.
    """
    # NOTE(review): a two-item side_effect on the global ``time.time``
    # assumes nothing else calls it during the service call; confirm.
    mocker.patch("time.time", side_effect=[100.0, 100.5])

    mock_client, _ = mock_db_client
    prompt = "function add(a, b) {\n return a"
    result = suggestion_service.getSuggestionFromGoogle(prompt)

    assert result == ["+ b;", " - b;"]
    send_message = mock_gemini_client.chat_session.send_message
    send_message.assert_called_once()
    assert prompt in send_message.call_args[0][0]

    # Reach the chained mocks through ``return_value`` instead of calling
    # them. Calling ``mock_client.table()`` inside an assertion records an
    # extra call, which would break any later call-count assertion.
    table = mock_client.table.return_value
    mock_client.table.assert_called_once_with("ai_usage")
    table.insert.assert_called_once_with(
        {
            "provider": "google",
            "model": "gemini-2.0-flash",
            "input_tokens": 10,
            "output_tokens": 5,
            "total_tokens": 15,
            "latency_seconds": 0.5,
        }
    )
    table.insert.return_value.execute.assert_called_once()

283 

284 

def test_getSuggestionFromGoogle_gemini_error(
    mock_gemini_client, mock_db_client, caplog
):
    """A Gemini exception is logged, yields [] and records no usage row."""
    db, _ = mock_db_client
    mock_gemini_client.chat_session.send_message.side_effect = Exception(
        "Gemini API error"
    )

    assert suggestion_service.getSuggestionFromGoogle("prompt") == []

    expected_log = (
        "Error communicating with Gemini (Type: Exception): Gemini API error"
    )
    assert expected_log in caplog.text
    db.table.assert_not_called()

301 

302 

def test_getSuggestionFromGoogle_json_decode_error(
    mock_gemini_client, mock_db_client, caplog
):
    """Unparseable Gemini output yields [] and logs an error, but usage is still recorded."""
    mock_gemini_client.chat_session.send_message.return_value.text = "This is not JSON"
    mock_client, _ = mock_db_client
    caplog.set_level(logging.ERROR)

    result = suggestion_service.getSuggestionFromGoogle("prompt")

    assert result == []
    assert "Final JSON parse failed" in caplog.text

    # Use ``return_value`` rather than calling the mocks, so the assertion
    # itself does not record extra ``table()``/``insert()`` calls.
    mock_client.table.assert_called_once_with("ai_usage")
    mock_client.table.return_value.insert.return_value.execute.assert_called_once()

318 

319 

def test_getSuggestionFromGoogle_empty_response(
    mock_gemini_client, mock_db_client, caplog
):
    """An empty Gemini reply yields [] and records no usage row."""
    db, _ = mock_db_client
    mock_gemini_client.chat_session.send_message.return_value.text = ""

    assert suggestion_service.getSuggestionFromGoogle("prompt") == []
    db.table.assert_not_called()

330 

331 

def test_getSuggestionFromGoogle_no_usage_metadata(mock_gemini_client, mock_db_client):
    """Missing usage metadata is recorded with -1 for every token count."""
    mock_gemini_client.chat_session.send_message.return_value.usage_metadata = None
    mock_client, _ = mock_db_client

    result = suggestion_service.getSuggestionFromGoogle("prompt")

    assert result == ["+ b;", " - b;"]

    # Navigate via ``return_value`` so the assertions record no extra calls
    # on the mocks they inspect.
    table = mock_client.table.return_value
    mock_client.table.assert_called_once_with("ai_usage")
    table.insert.assert_called_once_with(
        {
            "provider": "google",
            "model": "gemini-2.0-flash",
            "input_tokens": -1,
            "output_tokens": -1,
            "total_tokens": -1,
            # Real wall-clock time is used here, so only presence is checked.
            "latency_seconds": ANY,
        }
    )
    table.insert.return_value.execute.assert_called_once()

351 

352 

def test_getAvailableModels_routes_openai(mocker):
    """The OpenAI vendor is served by getModelsFromOpenAI."""
    lister = mocker.patch.object(
        suggestion_service, "getModelsFromOpenAI", return_value=[]
    )

    suggestion_service.getAvailableModels(vendors.OpenAI)

    lister.assert_called_once()

359 

360 

def test_getAvailableModels_routes_ollama(mocker):
    """The Ollama vendor is served by getModelsFromOllama."""
    lister = mocker.patch.object(
        suggestion_service, "getModelsFromOllama", return_value=[]
    )

    suggestion_service.getAvailableModels(vendors.Ollama)

    lister.assert_called_once()

367 

368 

def test_getAvailableModels_routes_google(mocker):
    """The Google vendor is served by getModelsFromGoogle."""
    lister = mocker.patch.object(
        suggestion_service, "getModelsFromGoogle", return_value=[]
    )

    suggestion_service.getAvailableModels(vendors.Google)

    lister.assert_called_once()

375 

376 

def test_getAvailableModels_invalid_vendor(mocker):
    """An unsupported vendor is rejected with ValueError."""
    unsupported = "UnsupportedVendor"
    with pytest.raises(ValueError):
        suggestion_service.getAvailableModels(unsupported)

380 

381 

def test_getModelsFromOpenAI_success(mock_openai_client, app):
    """Model ids from the OpenAI listing are returned in listing order."""
    listed = suggestion_service.getModelsFromOpenAI()

    assert listed == ["gpt-4o-mini", "gpt-3.5-turbo"]
    mock_openai_client.models.list.assert_called_once()

386 

387 

def test_getModelsFromOpenAI_api_error(mock_openai_client, app):
    """An exception from the OpenAI model listing propagates to the caller."""
    failure = Exception("API List Error")
    mock_openai_client.models.list.side_effect = failure

    with pytest.raises(Exception, match="API List Error"):
        suggestion_service.getModelsFromOpenAI()

392 

393 

def test_getModelsFromOllama_success(mock_requests):
    """Models are fetched from Ollama's ``/api/tags`` endpoint."""
    get = mock_requests[1]

    models = suggestion_service.getModelsFromOllama()

    expected_names = ["codellama:latest", "mistral:latest"]
    assert len(models) == len(expected_names)
    for model in models:
        assert model["name"] in expected_names
    get.assert_called_once_with("http://localhost:11434/api/tags")

401 

402 

def test_getModelsFromOllama_request_error(mock_requests):
    """An HTTP failure while listing Ollama models is reported with its cause."""
    get_response = mock_requests[3]
    failure = requests.exceptions.RequestException("Ollama Down")
    get_response.raise_for_status.side_effect = failure

    with pytest.raises(
        Exception, match="Error fetching models from Ollama: Ollama Down"
    ):
        suggestion_service.getModelsFromOllama()

412 

413 

def test_getModelsFromOllama_missing_models_key(mock_requests):
    """A tags payload without a "models" key yields an empty list."""
    get_response = mock_requests[3]
    get_response.json.return_value = {"some_other_key": []}

    assert suggestion_service.getModelsFromOllama() == []

419 

420 

def test_getModelsFromGoogle_success(mock_gemini_client):
    """Only models that support ``generateContent`` are returned.

    The original test computed ``expected_models`` but never asserted on
    the result, so the filtering was unchecked.
    """
    # Build the listing here with real ``.name`` strings. ``MagicMock(name=...)``
    # only sets the repr, so ``.name`` would otherwise be another mock.
    listed = []
    for model_name, actions in (
        ("models/gemini-1.5-flash", ["generateContent"]),
        ("models/gemini-pro", ["embedContent"]),
        ("models/gemini-2.0-flash", ["generateContent", "other"]),
    ):
        model = MagicMock(supported_actions=actions)
        model.name = model_name
        listed.append(model)
    mock_gemini_client.models.list.return_value = listed

    models = suggestion_service.getModelsFromGoogle()

    expected_models = ["models/gemini-1.5-flash", "models/gemini-2.0-flash"]
    # NOTE(review): assumes getModelsFromGoogle returns the raw ``name``
    # strings in listing order; adjust if the service reformats them.
    assert models == expected_models
    mock_gemini_client.models.list.assert_called_once()

425 

426 

def test_getModelsFromGoogle_api_error(mock_gemini_client):
    """An exception from the Gemini model listing propagates to the caller."""
    failure = Exception("Gemini List Error")
    mock_gemini_client.models.list.side_effect = failure

    with pytest.raises(Exception, match="Gemini List Error"):
        suggestion_service.getModelsFromGoogle()

431 

432 

def test_generate_refined_prompt_success(mock_gemini_client):
    """The refined prompt is Gemini's reply with surrounding whitespace removed."""
    send_message = mock_gemini_client.chat_session.send_message
    send_message.return_value.text = (
        " Language: python\n Context: Complete the function. "
    )
    raw_prompt = "def my_func(x):\n return"

    refined = suggestion_service.generate_refined_prompt(raw_prompt)

    assert refined == "Language: python\n Context: Complete the function."
    send_message.assert_called_once()
    # The raw prompt must be embedded in the message sent to Gemini.
    assert raw_prompt in send_message.call_args[0][0]

444 

445 

def test_generate_refined_prompt_api_error(mock_gemini_client):
    """A Gemini failure during refinement is re-raised as ValueError."""
    send_message = mock_gemini_client.chat_session.send_message
    send_message.side_effect = Exception("Refine Error")

    with pytest.raises(ValueError, match="AI prompt refinement failed: Refine Error"):
        suggestion_service.generate_refined_prompt("raw")

450 

451 

def test_generate_hint_from_gemini_success(mock_gemini_client):
    """The hint is Gemini's stripped reply; the prompt includes both versions."""
    send_message = mock_gemini_client.chat_session.send_message
    send_message.return_value.text = " Check the operator used. "

    hint = suggestion_service.generate_hint_from_gemini("prompt", "wrong", "right")

    assert hint == "Check the operator used."
    send_message.assert_called_once()
    sent_prompt = send_message.call_args[0][0]
    for fragment in (
        "Context: prompt",
        "Incorrect Version:",
        "wrong",
        "Correct Version:",
        "right",
    ):
        assert fragment in sent_prompt

468 

469 

def test_generate_hint_from_gemini_api_error(mock_gemini_client):
    """A Gemini failure is returned as a message string instead of raising."""
    send_message = mock_gemini_client.chat_session.send_message
    send_message.side_effect = Exception("Hint Error")

    result = suggestion_service.generate_hint_from_gemini("prompt", "wrong", "right")

    assert result == "Could not generate hint: Hint Error"

474 

475 

def test_generate_explanation_from_gemini_success(mock_gemini_client):
    """The explanation is Gemini's stripped reply; the prompt asks for one."""
    send_message = mock_gemini_client.chat_session.send_message
    send_message.return_value.text = " The first version used addition incorrectly. "

    explanation = suggestion_service.generate_explanation_from_gemini(
        "prompt", "wrong", "right"
    )

    assert explanation == "The first version used addition incorrectly."
    send_message.assert_called_once()
    sent_prompt = send_message.call_args[0][0]
    for fragment in (
        "Context: prompt",
        "Incorrect Version:",
        "wrong",
        "Correct Version:",
        "right",
        "Generate an explanation",
    ):
        assert fragment in sent_prompt

495 

496 

# These cases pin how Gemini's free-text verdict is interpreted: "true"
# matches regardless of case and surrounding whitespace, and any other
# reply (including an empty one) maps to False.
@pytest.mark.parametrize(
    "api_response, expected_result",
    [
        ("true", True),
        ("false", False),
        (" True ", True),
        ("FALSE", False),
        ("maybe", False),
        ("", False),
    ],
)
def test_check_code_correctness(mock_gemini_client, api_response, expected_result):
    """check_code_correctness parses the verdict and sends a complete prompt."""
    send_message = mock_gemini_client.chat_session.send_message
    send_message.return_value.text = api_response

    verdict = suggestion_service.check_code_correctness("prompt", "wrong", "fixed")

    assert verdict == expected_result
    send_message.assert_called_once()
    sent_prompt = send_message.call_args[0][0]
    for fragment in (
        "Context: prompt",
        "Incorrect Version:",
        "wrong",
        "Fixed Version:",
        "fixed",
        "Respond with only 'true' or 'false'.",
    ):
        assert fragment in sent_prompt

524 

525 

def test_get_suggestion_by_id_found(mock_db_client):
    """An existing row is returned as a Suggestion via a single-row id query."""
    mock_client, mock_execute_select = mock_db_client
    suggestion_data = {
        "id": "123",
        "user_id": "abc",
        "prompt": "test prompt",
        "correct_suggestion": "correct",
        "incorrect_suggestion": "incorrect",
        "status": "pending",
    }
    mock_execute_select.data = suggestion_data

    suggestion = suggestion_service.get_suggestion_by_id("123")

    assert isinstance(suggestion, Suggestion)
    assert suggestion.id == "123"
    assert suggestion.prompt == "test prompt"

    # Walk the query chain through ``return_value`` instead of calling the
    # mocks. Calls made inside assertions would be recorded, so these checks
    # would only pass in one specific order.
    table = mock_client.table.return_value
    select = table.select.return_value
    eq = select.eq.return_value
    mock_client.table.assert_called_once_with("suggestions")
    table.select.assert_called_once_with("*")
    select.eq.assert_called_once_with("id", "123")
    eq.single.assert_called_once()
    eq.single.return_value.execute.assert_called_once()

548 

549 

def test_get_suggestion_by_id_not_found(mock_db_client):
    """A query returning no data yields None."""
    mock_client, mock_execute_select = mock_db_client
    mock_execute_select.data = None

    suggestion = suggestion_service.get_suggestion_by_id("404")

    assert suggestion is None

    # Inspect via ``return_value`` so the assertions record no extra calls.
    single = (
        mock_client.table.return_value.select.return_value.eq.return_value
        .single.return_value
    )
    mock_client.table.assert_called_once_with("suggestions")
    single.execute.assert_called_once()

560 

561 

def test_get_suggestion_by_id_db_error(mock_db_client):
    """A failure while executing the query is re-raised as DatabaseError."""
    mock_client, _ = mock_db_client

    # Reach the ``single()`` query mock through ``return_value`` rather than
    # calling table()/select()/eq()/single(), which would record spurious
    # calls before the code under test even runs.
    single = (
        mock_client.table.return_value.select.return_value.eq.return_value
        .single.return_value
    )
    single.execute.side_effect = Exception("DB Connection Error")

    with pytest.raises(
        DatabaseError, match="Failed to retrieve suggestion: DB Connection Error"
    ):
        suggestion_service.get_suggestion_by_id("123")