Coverage for tests/services/test_suggestion_service.py: 100%
317 statements
« prev ^ index » next coverage.py v7.6.12, created at 2025-05-02 02:49 +0000
« prev ^ index » next coverage.py v7.6.12, created at 2025-05-02 02:49 +0000
1import pytest
2import requests
3import json
4from unittest.mock import MagicMock, patch, ANY
6from app.services import suggestion_service
7from app.controllers.ai import (
8 vendors,
9 default_openai_parameters,
10 OLLAMA_URL,
11 good_command,
12 bad_command,
13)
14from app.models.errors import ModelError, DatabaseError
15from app.models.user import Suggestion
16from flask import Flask
17import logging
@pytest.fixture
def mock_openai_client(mocker):
    """Patch the module-level OpenAI client in suggestion_service with a MagicMock.

    The returned mock is pre-wired so that:
      * ``chat.completions.create(...)`` returns a response whose first choice's
        ``message.content`` is ``"openai_suggestion"``;
      * ``models.list()`` returns a response whose ``data`` holds two entries
        with ids ``"gpt-4o-mini"`` and ``"gpt-3.5-turbo"`` (in that order).
    """
    message = MagicMock()
    message.content = "openai_suggestion"

    choice = MagicMock()
    choice.message = message

    completion_response = MagicMock()
    completion_response.choices = [choice]

    completions = MagicMock()
    completions.create.return_value = completion_response

    chat = MagicMock()
    chat.completions = completions

    listing = MagicMock()
    listing.data = [
        MagicMock(id=model_id) for model_id in ("gpt-4o-mini", "gpt-3.5-turbo")
    ]

    models = MagicMock()
    models.list.return_value = listing

    client = MagicMock()
    client.chat = chat
    client.models = models

    mocker.patch("app.services.suggestion_service.openai_client", client)
    return client
@pytest.fixture
def mock_gemini_client(mocker):
    """Patch the module-level Gemini client in suggestion_service with a MagicMock.

    The returned mock is pre-wired so that:
      * ``chat_session.send_message(...)`` returns a response whose ``text`` is
        ``json.dumps(["+ b;", " - b;"])`` and whose ``usage_metadata`` reports
        10 prompt, 5 candidate and 15 total tokens;
      * ``models.list()`` returns three model mocks whose ``name`` attributes are
        real strings and whose ``supported_actions`` differ (two of them list
        ``"generateContent"``).
    """

    def _model(model_name, actions):
        # ``name`` is a reserved MagicMock constructor argument: passing it only
        # sets the mock's repr, and ``model.name`` would remain a child mock.
        # Assign it after construction so it is the actual string.
        model = MagicMock(supported_actions=actions)
        model.name = model_name
        return model

    mock_client = MagicMock()
    mock_session = MagicMock()
    mock_response = MagicMock()

    mock_response.text = json.dumps(["+ b;", " - b;"])
    mock_usage = MagicMock()
    mock_usage.prompt_token_count = 10
    mock_usage.candidates_token_count = 5
    mock_usage.total_token_count = 15
    mock_response.usage_metadata = mock_usage

    mock_session.send_message.return_value = mock_response
    mock_client.chat_session = mock_session

    mock_models = MagicMock()
    mock_models.list.return_value = [
        _model("models/gemini-1.5-flash", ["generateContent"]),
        _model("models/gemini-pro", ["embedContent"]),
        _model("models/gemini-2.0-flash", ["generateContent", "other"]),
    ]
    mock_client.models = mock_models

    mocker.patch("app.services.suggestion_service.gemini_client", mock_client)
    return mock_client
@pytest.fixture
def mock_db_client(mocker):
    """Patch the database client used by suggestion_service with a MagicMock.

    Two call chains are wired:
      * ``table(...).select(...).eq(...).single().execute()`` returns a dedicated
        result mock, exposed so tests can set its ``data`` attribute;
      * ``table(...).insert(...).execute()`` returns a separate mock.

    Returns:
        tuple: ``(client_mock, select_result_mock)``.
    """
    select_result = MagicMock()
    single_query = MagicMock()
    single_query.execute.return_value = select_result

    eq_query = MagicMock()
    eq_query.single.return_value = single_query

    select_query = MagicMock()
    select_query.eq.return_value = eq_query

    insert_result = MagicMock()
    insert_query = MagicMock()
    insert_query.execute.return_value = insert_result

    table = MagicMock()
    table.select.return_value = select_query
    table.insert.return_value = insert_query

    client = MagicMock()
    client.table.return_value = table

    mocker.patch("app.services.suggestion_service.client", client)
    return client, select_result
@pytest.fixture
def mock_requests(mocker):
    """Patch ``requests.post`` and ``requests.get`` with canned successful responses.

    * The POST response's ``json()`` returns ``{"response": "ollama_suggestion"}``.
    * The GET response's ``json()`` returns two Ollama model entries.
    * Both responses' ``raise_for_status()`` return None.

    Returns:
        tuple: ``(post_mock, get_mock, post_response, get_response)``.
    """
    post = mocker.patch("requests.post")
    get = mocker.patch("requests.get")

    post_response = MagicMock()
    post_response.json.return_value = {"response": "ollama_suggestion"}
    post_response.raise_for_status.return_value = None
    post.return_value = post_response

    get_response = MagicMock()
    get_response.json.return_value = {
        "models": [{"name": "codellama:latest"}, {"name": "mistral:latest"}]
    }
    get_response.raise_for_status.return_value = None
    get.return_value = get_response

    return post, get, post_response, get_response
def test_getSuggestion_routes_to_openai(mocker):
    """getSuggestion forwards OpenAI requests to getSuggestionFromOpenAI."""
    openai_stub = mocker.patch.object(
        suggestion_service, "getSuggestionFromOpenAI", return_value="test"
    )

    suggestion_service.getSuggestion(
        "prompt", vendor=vendors.OpenAI, model_name="gpt-test"
    )

    # The model name is passed through under the ``model`` keyword.
    openai_stub.assert_called_once_with(prompt="prompt", model="gpt-test")
def test_getSuggestion_routes_to_ollama(mocker):
    """getSuggestion forwards Ollama requests to getSuggestionFromOllama."""
    ollama_stub = mocker.patch.object(
        suggestion_service, "getSuggestionFromOllama", return_value="test"
    )

    suggestion_service.getSuggestion(
        "prompt", vendor=vendors.Ollama, model_name="codellama"
    )

    # Unlike the OpenAI path, the name is passed as ``model_name``.
    ollama_stub.assert_called_once_with(prompt="prompt", model_name="codellama")
def test_getSuggestion_routes_to_google(mocker):
    """getSuggestion forwards Google requests to getSuggestionFromGoogle."""
    google_stub = mocker.patch.object(
        suggestion_service, "getSuggestionFromGoogle", return_value=["test"]
    )

    suggestion_service.getSuggestion("prompt", vendor=vendors.Google)

    google_stub.assert_called_once_with(prompt="prompt")
def test_getSuggestion_defaults_to_google_on_invalid_vendor(mocker):
    """An unrecognised vendor falls back to the Google implementation."""
    google_stub = mocker.patch.object(
        suggestion_service, "getSuggestionFromGoogle", return_value=["test"]
    )
    unknown_vendor = "InvalidVendor"

    suggestion_service.getSuggestion("prompt", vendor=unknown_vendor)

    google_stub.assert_called_once_with(prompt="prompt")
def test_getSuggestionFromOpenAI_success_correct(mock_openai_client, app):
    """A correct-suggestion request returns the model text and sends a system
    message that does not ask for a "slightly incorrect" answer."""
    prompt = "def func("

    result = suggestion_service.getSuggestionFromOpenAI(
        prompt, model="gpt-4o-mini", is_correct=True
    )

    assert result == "openai_suggestion"
    sent = mock_openai_client.chat.completions.create.call_args.kwargs
    system_msg = sent["messages"][0]
    user_msg = sent["messages"][1]
    assert sent["model"] == "gpt-4o-mini"
    assert system_msg["role"] == "system"
    assert "slightly incorrect" not in system_msg["content"]
    assert user_msg["role"] == "user"
    assert user_msg["content"] == prompt
def test_getSuggestionFromOpenAI_success_incorrect(mock_openai_client, app):
    """An incorrect-suggestion request puts "slightly incorrect" in the system message."""
    prompt = "def func("
    completion = mock_openai_client.chat.completions.create.return_value
    completion.choices[0].message.content = "openai_incorrect_suggestion"

    result = suggestion_service.getSuggestionFromOpenAI(
        prompt, model="gpt-4o-mini", is_correct=False
    )

    assert result == "openai_incorrect_suggestion"
    messages = mock_openai_client.chat.completions.create.call_args.kwargs["messages"]
    assert "slightly incorrect" in messages[0]["content"]
def test_getSuggestionFromOpenAI_api_error(mock_openai_client, app):
    """An exception from the OpenAI client surfaces as a ModelError."""
    create = mock_openai_client.chat.completions.create
    create.side_effect = Exception("API Down")
    expected_message = "Error generating suggestion using OpenAI's API"

    with pytest.raises(ModelError, match=expected_message):
        suggestion_service.getSuggestionFromOpenAI("prompt")
def test_getSuggestionFromOllama_success_correct(mock_requests):
    """A correct-suggestion request prefixes the prompt with ``good_command``."""
    post = mock_requests[0]
    prompt, model_name = "import os", "codellama"

    result = suggestion_service.getSuggestionFromOllama(
        prompt, model_name=model_name, is_correct=True
    )

    assert result == "ollama_suggestion"
    post.assert_called_once_with(
        OLLAMA_URL,
        json={
            "model": model_name,
            "prompt": good_command + prompt,
            "keep_alive": "1h",
            "stream": False,
        },
    )
def test_getSuggestionFromOllama_success_incorrect(mock_requests):
    """An incorrect-suggestion request prefixes the prompt with ``bad_command``."""
    post = mock_requests[0]
    prompt, model_name = "import os", "codellama"

    result = suggestion_service.getSuggestionFromOllama(
        prompt, model_name=model_name, is_correct=False
    )

    assert result == "ollama_suggestion"
    post.assert_called_once_with(
        OLLAMA_URL,
        json={
            "model": model_name,
            "prompt": bad_command + prompt,
            "keep_alive": "1h",
            "stream": False,
        },
    )
def test_getSuggestionFromOllama_request_error(mock_requests):
    """An HTTP error from Ollama is re-raised as ModelError carrying the original text."""
    post_response = mock_requests[2]
    failure = requests.exceptions.RequestException("Connection Failed")
    post_response.raise_for_status.side_effect = failure

    with pytest.raises(
        ModelError, match="Error fetching Ollama suggestion: Connection Failed"
    ):
        suggestion_service.getSuggestionFromOllama("prompt", "model")
def test_getSuggestionFromOllama_other_error(mock_requests):
    """A non-HTTP exception raised by ``requests.post`` is also wrapped in ModelError."""
    post = mock_requests[0]
    post.side_effect = Exception("Unexpected error")

    with pytest.raises(
        ModelError, match="Error fetching Ollama suggestion: Unexpected error"
    ):
        suggestion_service.getSuggestionFromOllama("prompt", "model")
def test_getSuggestionFromGoogle_success(mock_gemini_client, mock_db_client, mocker):
    """A successful Gemini call returns the parsed list and records usage in ai_usage.

    ``time.time`` is patched to return 100.0 then 100.5, so the recorded latency
    is expected to be exactly 0.5 seconds.
    """
    mocker.patch("time.time", side_effect=[100.0, 100.5])
    mock_client, _ = mock_db_client
    prompt = "function add(a, b) {\n return a"

    result = suggestion_service.getSuggestionFromGoogle(prompt)

    assert result == ["+ b;", " - b;"]
    send_message = mock_gemini_client.chat_session.send_message
    send_message.assert_called_once()
    assert prompt in send_message.call_args[0][0]

    # Walk the chain via ``return_value`` instead of calling the mocks: calling
    # them inside assertions records extra calls and makes *_called_once checks
    # depend on statement order.
    table = mock_client.table.return_value
    mock_client.table.assert_called_once_with("ai_usage")
    table.insert.assert_called_once_with(
        {
            "provider": "google",
            "model": "gemini-2.0-flash",
            "input_tokens": 10,
            "output_tokens": 5,
            "total_tokens": 15,
            "latency_seconds": 0.5,
        }
    )
    table.insert.return_value.execute.assert_called_once()
def test_getSuggestionFromGoogle_gemini_error(
    mock_gemini_client, mock_db_client, caplog
):
    """A Gemini exception is logged, yields [], and no usage row is written."""
    mock_client, _ = mock_db_client
    send_message = mock_gemini_client.chat_session.send_message
    send_message.side_effect = Exception("Gemini API error")

    result = suggestion_service.getSuggestionFromGoogle("prompt")

    assert result == []
    expected_log = (
        "Error communicating with Gemini (Type: Exception): Gemini API error"
    )
    assert expected_log in caplog.text
    mock_client.table.assert_not_called()
def test_getSuggestionFromGoogle_json_decode_error(
    mock_gemini_client, mock_db_client, caplog
):
    """Non-JSON Gemini output yields [] and an error log; usage is still recorded."""
    mock_gemini_client.chat_session.send_message.return_value.text = "This is not JSON"
    mock_client, _ = mock_db_client
    caplog.set_level(logging.ERROR)

    result = suggestion_service.getSuggestionFromGoogle("prompt")

    assert result == []
    assert "Final JSON parse failed" in caplog.text

    # Navigate via ``return_value`` so the assertions do not add calls themselves.
    mock_client.table.assert_called_once_with("ai_usage")
    mock_client.table.return_value.insert.return_value.execute.assert_called_once()
def test_getSuggestionFromGoogle_empty_response(
    mock_gemini_client, mock_db_client, caplog
):
    """An empty Gemini reply yields [] and no usage row is written."""
    mock_client, _ = mock_db_client
    mock_gemini_client.chat_session.send_message.return_value.text = ""

    assert suggestion_service.getSuggestionFromGoogle("prompt") == []
    mock_client.table.assert_not_called()
def test_getSuggestionFromGoogle_no_usage_metadata(mock_gemini_client, mock_db_client):
    """Without usage metadata the usage row is expected to carry -1 token counts.

    Latency is matched with ``ANY`` because ``time.time`` is not patched here.
    """
    mock_gemini_client.chat_session.send_message.return_value.usage_metadata = None
    mock_client, _ = mock_db_client

    result = suggestion_service.getSuggestionFromGoogle("prompt")

    assert result == ["+ b;", " - b;"]

    # Navigate via ``return_value`` so the assertions do not add calls themselves.
    table = mock_client.table.return_value
    mock_client.table.assert_called_once_with("ai_usage")
    table.insert.assert_called_once_with(
        {
            "provider": "google",
            "model": "gemini-2.0-flash",
            "input_tokens": -1,
            "output_tokens": -1,
            "total_tokens": -1,
            "latency_seconds": ANY,
        }
    )
    table.insert.return_value.execute.assert_called_once()
def test_getAvailableModels_routes_openai(mocker):
    """getAvailableModels delegates the OpenAI vendor to getModelsFromOpenAI."""
    fetch_models = mocker.patch(
        "app.services.suggestion_service.getModelsFromOpenAI", return_value=[]
    )

    suggestion_service.getAvailableModels(vendors.OpenAI)

    fetch_models.assert_called_once()
def test_getAvailableModels_routes_ollama(mocker):
    """getAvailableModels delegates the Ollama vendor to getModelsFromOllama."""
    fetch_models = mocker.patch(
        "app.services.suggestion_service.getModelsFromOllama", return_value=[]
    )

    suggestion_service.getAvailableModels(vendors.Ollama)

    fetch_models.assert_called_once()
def test_getAvailableModels_routes_google(mocker):
    """getAvailableModels delegates the Google vendor to getModelsFromGoogle."""
    fetch_models = mocker.patch(
        "app.services.suggestion_service.getModelsFromGoogle", return_value=[]
    )

    suggestion_service.getAvailableModels(vendors.Google)

    fetch_models.assert_called_once()
def test_getAvailableModels_invalid_vendor(mocker):
    """An unsupported vendor raises ValueError."""
    unsupported = "UnsupportedVendor"
    with pytest.raises(ValueError):
        suggestion_service.getAvailableModels(unsupported)
def test_getModelsFromOpenAI_success(mock_openai_client, app):
    """getModelsFromOpenAI returns the ids listed by the client, in order."""
    result = suggestion_service.getModelsFromOpenAI()

    assert result == ["gpt-4o-mini", "gpt-3.5-turbo"]
    mock_openai_client.models.list.assert_called_once()
def test_getModelsFromOpenAI_api_error(mock_openai_client, app):
    """A failure while listing OpenAI models surfaces as an exception with that message."""
    models_api = mock_openai_client.models
    models_api.list.side_effect = Exception("API List Error")

    with pytest.raises(Exception, match="API List Error"):
        suggestion_service.getModelsFromOpenAI()
def test_getModelsFromOllama_success(mock_requests):
    """getModelsFromOllama queries the tags endpoint and returns one entry per model."""
    _, mock_get, _, _ = mock_requests
    models = suggestion_service.getModelsFromOllama()

    # Compare the sorted names exactly: a length check plus "all in expected"
    # would accept duplicated entries (e.g. two "codellama:latest") while still
    # tolerating any ordering the service applies.
    expected_models = ["codellama:latest", "mistral:latest"]
    assert sorted(m["name"] for m in models) == sorted(expected_models)
    mock_get.assert_called_once_with("http://localhost:11434/api/tags")
def test_getModelsFromOllama_request_error(mock_requests):
    """An HTTP error from the tags endpoint surfaces with a descriptive message."""
    get_response = mock_requests[3]
    failure = requests.exceptions.RequestException("Ollama Down")
    get_response.raise_for_status.side_effect = failure

    with pytest.raises(
        Exception, match="Error fetching models from Ollama: Ollama Down"
    ):
        suggestion_service.getModelsFromOllama()
def test_getModelsFromOllama_missing_models_key(mock_requests):
    """A tags response without a "models" key yields an empty list."""
    get_response = mock_requests[3]
    get_response.json.return_value = {"some_other_key": []}

    assert suggestion_service.getModelsFromOllama() == []
def test_getModelsFromGoogle_success(mock_gemini_client):
    """getModelsFromGoogle keeps only the models that support generateContent."""
    # Configure named models locally. ``MagicMock(name=...)`` does not create a
    # ``.name`` attribute (``name`` only sets the mock's repr), so each name is
    # assigned after construction to be a real string.
    listed = []
    for model_name, actions in (
        ("models/gemini-1.5-flash", ["generateContent"]),
        ("models/gemini-pro", ["embedContent"]),
        ("models/gemini-2.0-flash", ["generateContent", "other"]),
    ):
        model = MagicMock(supported_actions=actions)
        model.name = model_name
        listed.append(model)
    mock_gemini_client.models.list.return_value = listed

    models = suggestion_service.getModelsFromGoogle()

    expected_models = ["models/gemini-1.5-flash", "models/gemini-2.0-flash"]
    mock_gemini_client.models.list.assert_called_once()
    # NOTE(review): assumes the service returns the full ``name`` strings of the
    # generateContent-capable models unchanged, as the previously unused
    # ``expected_models`` list implied — confirm against the service.
    assert models == expected_models
def test_getModelsFromGoogle_api_error(mock_gemini_client):
    """A failure while listing Gemini models surfaces as an exception with that message."""
    list_models = mock_gemini_client.models.list
    list_models.side_effect = Exception("Gemini List Error")

    with pytest.raises(Exception, match="Gemini List Error"):
        suggestion_service.getModelsFromGoogle()
def test_generate_refined_prompt_success(mock_gemini_client):
    """The refined prompt is expected to be the Gemini reply without surrounding
    whitespace, and the raw prompt must be included in what is sent."""
    send_message = mock_gemini_client.chat_session.send_message
    send_message.return_value.text = (
        " Language: python\n Context: Complete the function. "
    )
    raw_prompt = "def my_func(x):\n return"

    refined = suggestion_service.generate_refined_prompt(raw_prompt)

    assert refined == "Language: python\n Context: Complete the function."
    send_message.assert_called_once()
    first_positional = send_message.call_args[0][0]
    assert raw_prompt in first_positional
def test_generate_refined_prompt_api_error(mock_gemini_client):
    """A Gemini failure during refinement is raised as ValueError."""
    send_message = mock_gemini_client.chat_session.send_message
    send_message.side_effect = Exception("Refine Error")

    with pytest.raises(ValueError, match="AI prompt refinement failed: Refine Error"):
        suggestion_service.generate_refined_prompt("raw")
def test_generate_hint_from_gemini_success(mock_gemini_client):
    """The hint is the stripped Gemini reply; the prompt embeds both code versions."""
    send_message = mock_gemini_client.chat_session.send_message
    send_message.return_value.text = " Check the operator used. "

    hint = suggestion_service.generate_hint_from_gemini("prompt", "wrong", "right")

    assert hint == "Check the operator used."
    send_message.assert_called_once()
    sent_prompt = send_message.call_args.args[0]
    for fragment in (
        "Context: prompt",
        "Incorrect Version:",
        "wrong",
        "Correct Version:",
        "right",
    ):
        assert fragment in sent_prompt
def test_generate_hint_from_gemini_api_error(mock_gemini_client):
    """A Gemini failure is reported back as a hint string instead of being raised."""
    mock_gemini_client.chat_session.send_message.side_effect = Exception("Hint Error")

    assert (
        suggestion_service.generate_hint_from_gemini("prompt", "wrong", "right")
        == "Could not generate hint: Hint Error"
    )
def test_generate_explanation_from_gemini_success(mock_gemini_client):
    """The explanation is the stripped Gemini reply; the prompt embeds both code
    versions and an explicit request to generate an explanation."""
    send_message = mock_gemini_client.chat_session.send_message
    send_message.return_value.text = " The first version used addition incorrectly. "

    explanation = suggestion_service.generate_explanation_from_gemini(
        "prompt", "wrong", "right"
    )

    assert explanation == "The first version used addition incorrectly."
    send_message.assert_called_once()
    sent_prompt = send_message.call_args.args[0]
    for fragment in (
        "Context: prompt",
        "Incorrect Version:",
        "wrong",
        "Correct Version:",
        "right",
        "Generate an explanation",
    ):
        assert fragment in sent_prompt
# Parametrized over raw Gemini replies and the boolean each should map to.
@pytest.mark.parametrize(
    "api_response, expected_result",
    [
        ("true", True),
        ("false", False),
        (" True ", True),
        ("FALSE", False),
        ("maybe", False),
        ("", False),
    ],
)
def test_check_code_correctness(mock_gemini_client, api_response, expected_result):
    """check_code_correctness maps the Gemini reply to a bool and builds a yes/no prompt.

    The cases expect surrounding whitespace and letter case to be ignored, and
    any reply other than "true"/"false" (including an empty one) to map to False.
    """
    send_message = mock_gemini_client.chat_session.send_message
    send_message.return_value.text = api_response

    is_correct = suggestion_service.check_code_correctness("prompt", "wrong", "fixed")

    assert is_correct == expected_result
    send_message.assert_called_once()
    sent_prompt = send_message.call_args.args[0]
    for fragment in (
        "Context: prompt",
        "Incorrect Version:",
        "wrong",
        "Fixed Version:",
        "fixed",
        "Respond with only 'true' or 'false'.",
    ):
        assert fragment in sent_prompt
def test_get_suggestion_by_id_found(mock_db_client):
    """An existing row is returned as a Suggestion built from the query result."""
    mock_client, mock_execute_select = mock_db_client
    suggestion_data = {
        "id": "123",
        "user_id": "abc",
        "prompt": "test prompt",
        "correct_suggestion": "correct",
        "incorrect_suggestion": "incorrect",
        "status": "pending",
    }
    mock_execute_select.data = suggestion_data

    suggestion = suggestion_service.get_suggestion_by_id("123")

    assert isinstance(suggestion, Suggestion)
    assert suggestion.id == "123"
    assert suggestion.prompt == "test prompt"

    # Walk the query chain through ``return_value`` instead of calling each mock.
    # Calling them inside the assertions records extra calls, so the original
    # checks passed only because of their exact statement order.
    table = mock_client.table.return_value
    select = table.select.return_value
    eq = select.eq.return_value
    mock_client.table.assert_called_once_with("suggestions")
    table.select.assert_called_once_with("*")
    select.eq.assert_called_once_with("id", "123")
    eq.single.assert_called_once()
    eq.single.return_value.execute.assert_called_once()
def test_get_suggestion_by_id_not_found(mock_db_client):
    """get_suggestion_by_id returns None when the query result carries no data."""
    mock_client, mock_execute_select = mock_db_client
    mock_execute_select.data = None

    suggestion = suggestion_service.get_suggestion_by_id("404")

    assert suggestion is None
    # Navigate via ``return_value`` so the assertion does not add calls itself.
    eq = mock_client.table.return_value.select.return_value.eq.return_value
    mock_client.table.assert_called_once_with("suggestions")
    eq.single.return_value.execute.assert_called_once()
def test_get_suggestion_by_id_db_error(mock_db_client):
    """A database exception is wrapped in DatabaseError with the original message."""
    db, _ = mock_db_client
    # Reach the execute() mock through return_value (same objects the fixture wired).
    eq_query = db.table.return_value.select.return_value.eq.return_value
    eq_query.single.return_value.execute.side_effect = Exception("DB Connection Error")

    with pytest.raises(
        DatabaseError, match="Failed to retrieve suggestion: DB Connection Error"
    ):
        suggestion_service.get_suggestion_by_id("123")