diff --git a/tests/api/v1/test_socket_endpoints.py b/tests/api/v1/test_socket_endpoints.py new file mode 100644 index 0000000..a7d25c8 --- /dev/null +++ b/tests/api/v1/test_socket_endpoints.py @@ -0,0 +1,185 @@ +"""Tests for socket API endpoints.""" + +import pytest +from httpx import AsyncClient +from unittest.mock import AsyncMock, patch + +from app.models.user import User + + +@pytest.fixture +def mock_socket_manager(): + """Mock socket manager for testing.""" + with patch('app.api.v1.socket.socket_manager') as mock: + mock.get_connected_users.return_value = ["1", "2", "3"] + mock.send_to_user = AsyncMock(return_value=True) + mock.broadcast_to_all = AsyncMock() + yield mock + + +class TestSocketEndpoints: + """Test socket API endpoints.""" + + @pytest.mark.asyncio + async def test_get_socket_status_authenticated(self, authenticated_client: AsyncClient, authenticated_user: User, mock_socket_manager): + """Test getting socket status for authenticated user.""" + response = await authenticated_client.get("/api/v1/socket/status") + + assert response.status_code == 200 + data = response.json() + + assert "connected" in data + assert "user_id" in data + assert "total_connected" in data + assert data["user_id"] == authenticated_user.id + assert data["total_connected"] == 3 + assert isinstance(data["connected"], bool) + + @pytest.mark.asyncio + async def test_get_socket_status_unauthenticated(self, client: AsyncClient): + """Test getting socket status without authentication.""" + response = await client.get("/api/v1/socket/status") + assert response.status_code == 401 + + @pytest.mark.asyncio + async def test_send_message_to_user_success(self, authenticated_client: AsyncClient, authenticated_user: User, mock_socket_manager): + """Test sending message to specific user successfully.""" + target_user_id = 2 + message = "Hello there!" + + response = await authenticated_client.post( + "/api/v1/socket/send-message", + params={"target_user_id": target_user_id, "message": message} + ) + + assert response.status_code == 200 + data = response.json() + + assert data["success"] is True + assert data["target_user_id"] == target_user_id + assert data["message"] == "Message sent" + + # Verify socket manager was called correctly + mock_socket_manager.send_to_user.assert_called_once_with( + str(target_user_id), + "user_message", + { + "from_user_id": authenticated_user.id, + "from_user_name": authenticated_user.name, + "message": message, + } + ) + + @pytest.mark.asyncio + async def test_send_message_to_user_not_connected(self, authenticated_client: AsyncClient, authenticated_user: User, mock_socket_manager): + """Test sending message to user who is not connected.""" + target_user_id = 999 + message = "Hello there!" + + # Mock user not connected + mock_socket_manager.send_to_user.return_value = False + + response = await authenticated_client.post( + "/api/v1/socket/send-message", + params={"target_user_id": target_user_id, "message": message} + ) + + assert response.status_code == 200 + data = response.json() + + assert data["success"] is False + assert data["target_user_id"] == target_user_id + assert data["message"] == "User not connected" + + @pytest.mark.asyncio + async def test_send_message_unauthenticated(self, client: AsyncClient): + """Test sending message without authentication.""" + response = await client.post( + "/api/v1/socket/send-message", + params={"target_user_id": 1, "message": "test"} + ) + assert response.status_code == 401 + + @pytest.mark.asyncio + async def test_broadcast_message_success(self, authenticated_client: AsyncClient, authenticated_user: User, mock_socket_manager): + """Test broadcasting message to all users successfully.""" + message = "Important announcement!" + + response = await authenticated_client.post( + "/api/v1/socket/broadcast", + params={"message": message} + ) + + assert response.status_code == 200 + data = response.json() + + assert data["success"] is True + assert data["message"] == "Message broadcasted to all users" + + # Verify socket manager was called correctly + mock_socket_manager.broadcast_to_all.assert_called_once_with( + "broadcast_message", + { + "from_user_id": authenticated_user.id, + "from_user_name": authenticated_user.name, + "message": message, + } + ) + + @pytest.mark.asyncio + async def test_broadcast_message_unauthenticated(self, client: AsyncClient): + """Test broadcasting message without authentication.""" + response = await client.post( + "/api/v1/socket/broadcast", + params={"message": "test"} + ) + assert response.status_code == 401 + + @pytest.mark.asyncio + async def test_send_message_missing_parameters(self, authenticated_client: AsyncClient, authenticated_user: User): + """Test sending message with missing parameters.""" + # Missing target_user_id + response = await authenticated_client.post( + "/api/v1/socket/send-message", + params={"message": "test"} + ) + assert response.status_code == 422 + + # Missing message + response = await authenticated_client.post( + "/api/v1/socket/send-message", + params={"target_user_id": 1} + ) + assert response.status_code == 422 + + @pytest.mark.asyncio + async def test_broadcast_message_missing_parameters(self, authenticated_client: AsyncClient, authenticated_user: User): + """Test broadcasting message with missing parameters.""" + response = await authenticated_client.post("/api/v1/socket/broadcast") + assert response.status_code == 422 + + @pytest.mark.asyncio + async def test_send_message_invalid_user_id(self, authenticated_client: AsyncClient, authenticated_user: User): + """Test sending message with invalid user ID.""" + response = await authenticated_client.post( + "/api/v1/socket/send-message", + params={"target_user_id": "invalid", "message": "test"} + ) + assert response.status_code == 422 + + @pytest.mark.asyncio + async def test_socket_status_shows_user_connection(self, authenticated_client: AsyncClient, authenticated_user: User, mock_socket_manager): + """Test that socket status correctly shows if user is connected.""" + # Test when user is connected + mock_socket_manager.get_connected_users.return_value = [str(authenticated_user.id), "2", "3"] + + response = await authenticated_client.get("/api/v1/socket/status") + data = response.json() + assert data["connected"] is True + + # Test when user is not connected + mock_socket_manager.get_connected_users.return_value = ["2", "3", "4"] + + response = await authenticated_client.get("/api/v1/socket/status") + data = response.json() + assert data["connected"] is False \ No newline at end of file diff --git a/tests/conftest.py b/tests/conftest.py index e44f00e..58bcd6f 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -102,6 +102,19 @@ async def test_client(test_app) -> AsyncGenerator[AsyncClient, None]: yield client +@pytest_asyncio.fixture +async def authenticated_client( + test_app: FastAPI, auth_cookies: dict[str, str], +) -> AsyncGenerator[AsyncClient, None]: + """Create a test HTTP client with authentication cookies.""" + async with AsyncClient( + transport=ASGITransport(app=test_app), + base_url="http://test", + cookies=auth_cookies, + ) as client: + yield client + + @pytest_asyncio.fixture async def test_plan(test_session: AsyncSession) -> Plan: """Create a test plan.""" @@ -269,3 +282,29 @@ async def admin_headers(admin_user: User) -> dict[str, str]: access_token = JWTUtils.create_access_token(token_data) return {"Authorization": f"Bearer {access_token}"} + + +@pytest.fixture +def client(test_client: AsyncClient) -> AsyncClient: + """Alias for test_client to match test expectations.""" + return test_client + + +@pytest.fixture +def authenticated_user(test_user: User) -> User: + """Alias for test_user to match test expectations.""" + return test_user + + +@pytest_asyncio.fixture +async def auth_cookies(test_user: User) -> dict[str, str]: + """Create authentication cookies with JWT token.""" + token_data = { + "sub": str(test_user.id), + "email": test_user.email, + "role": test_user.role, + } + + access_token = JWTUtils.create_access_token(token_data) + + return {"access_token": access_token} diff --git a/tests/services/test_socket_service.py b/tests/services/test_socket_service.py new file mode 100644 index 0000000..9a090cc --- /dev/null +++ b/tests/services/test_socket_service.py @@ -0,0 +1,271 @@ +"""Tests for socket service.""" + +import pytest +from unittest.mock import AsyncMock, MagicMock, patch, call +import socketio + +from app.services.socket import SocketManager + + +class TestSocketManager: + """Test socket manager service.""" + + @pytest.fixture + def socket_manager(self): + """Create a fresh socket manager for testing.""" + return SocketManager() + + @pytest.fixture + def mock_sio(self, socket_manager): + """Mock the socket.io server.""" + socket_manager.sio = AsyncMock(spec=socketio.AsyncServer) + return socket_manager.sio + + def test_init_creates_socket_server(self): + """Test that socket manager initializes with proper configuration.""" + manager = SocketManager() + + assert manager.sio is not None + assert isinstance(manager.user_rooms, dict) + assert isinstance(manager.socket_users, dict) + assert len(manager.user_rooms) == 0 + assert len(manager.socket_users) == 0 + + @pytest.mark.asyncio + async def test_send_to_user_success(self, socket_manager, mock_sio): + """Test sending message to connected user.""" + user_id = "123" + room_id = "user_123" + socket_manager.user_rooms[user_id] = room_id + + event = "test_event" + data = {"message": "hello"} + + result = await socket_manager.send_to_user(user_id, event, data) + + assert result is True + mock_sio.emit.assert_called_once_with(event, data, room=room_id) + + @pytest.mark.asyncio + async def test_send_to_user_not_connected(self, socket_manager, mock_sio): + """Test sending message to user who is not connected.""" + user_id = "999" + event = "test_event" + data = {"message": "hello"} + + result = await socket_manager.send_to_user(user_id, event, data) + + assert result is False + mock_sio.emit.assert_not_called() + + @pytest.mark.asyncio + async def test_broadcast_to_all(self, socket_manager, mock_sio): + """Test broadcasting message to all users.""" + event = "broadcast_event" + data = {"message": "announcement"} + + await socket_manager.broadcast_to_all(event, data) + + mock_sio.emit.assert_called_once_with(event, data) + + def test_get_connected_users(self, socket_manager): + """Test getting list of connected users.""" + # Add some users + socket_manager.user_rooms["1"] = "user_1" + socket_manager.user_rooms["2"] = "user_2" + socket_manager.user_rooms["3"] = "user_3" + + connected_users = socket_manager.get_connected_users() + + assert len(connected_users) == 3 + assert "1" in connected_users + assert "2" in connected_users + assert "3" in connected_users + + def test_get_room_info(self, socket_manager): + """Test getting room information.""" + # Add some users + socket_manager.user_rooms["1"] = "user_1" + socket_manager.user_rooms["2"] = "user_2" + + room_info = socket_manager.get_room_info() + + assert room_info["total_users"] == 2 + assert room_info["connected_users"] == ["1", "2"] + + @pytest.mark.asyncio + @patch('app.services.socket.extract_access_token_from_cookies') + @patch('app.services.socket.JWTUtils.decode_access_token') + async def test_connect_handler_success(self, mock_decode, mock_extract_token, socket_manager, mock_sio): + """Test successful connection with valid token.""" + # Setup mocks + mock_extract_token.return_value = "valid_token" + mock_decode.return_value = {"sub": "123"} + + # Mock environment + environ = {"HTTP_COOKIE": "access_token=valid_token"} + + # Access the connect handler directly + handlers = {} + original_event = socket_manager.sio.event + + def mock_event(func): + handlers[func.__name__] = func + return func + + socket_manager.sio.event = mock_event + socket_manager._setup_handlers() + + # Call the connect handler + await handlers['connect']("test_sid", environ) + + # Verify token extraction and validation + mock_extract_token.assert_called_once_with("access_token=valid_token") + mock_decode.assert_called_once_with("valid_token") + + # Verify user tracking + assert socket_manager.socket_users["test_sid"] == "123" + assert socket_manager.user_rooms["123"] == "user_123" + + @pytest.mark.asyncio + @patch('app.services.socket.extract_access_token_from_cookies') + async def test_connect_handler_no_token(self, mock_extract_token, socket_manager, mock_sio): + """Test connection with no access token.""" + # Setup mocks + mock_extract_token.return_value = None + + # Mock environment + environ = {"HTTP_COOKIE": ""} + + # Access the connect handler directly + handlers = {} + original_event = socket_manager.sio.event + + def mock_event(func): + handlers[func.__name__] = func + return func + + socket_manager.sio.event = mock_event + socket_manager._setup_handlers() + + # Call the connect handler + await handlers['connect']("test_sid", environ) + + # Verify disconnection + mock_sio.disconnect.assert_called_once_with("test_sid") + + # Verify no user tracking + assert "test_sid" not in socket_manager.socket_users + assert len(socket_manager.user_rooms) == 0 + + @pytest.mark.asyncio + @patch('app.services.socket.extract_access_token_from_cookies') + @patch('app.services.socket.JWTUtils.decode_access_token') + async def test_connect_handler_invalid_token(self, mock_decode, mock_extract_token, socket_manager, mock_sio): + """Test connection with invalid token.""" + # Setup mocks + mock_extract_token.return_value = "invalid_token" + mock_decode.side_effect = Exception("Invalid token") + + # Mock environment + environ = {"HTTP_COOKIE": "access_token=invalid_token"} + + # Access the connect handler directly + handlers = {} + original_event = socket_manager.sio.event + + def mock_event(func): + handlers[func.__name__] = func + return func + + socket_manager.sio.event = mock_event + socket_manager._setup_handlers() + + # Call the connect handler + await handlers['connect']("test_sid", environ) + + # Verify disconnection + mock_sio.disconnect.assert_called_once_with("test_sid") + + # Verify no user tracking + assert "test_sid" not in socket_manager.socket_users + assert len(socket_manager.user_rooms) == 0 + + @pytest.mark.asyncio + @patch('app.services.socket.extract_access_token_from_cookies') + @patch('app.services.socket.JWTUtils.decode_access_token') + async def test_connect_handler_missing_user_id(self, mock_decode, mock_extract_token, socket_manager, mock_sio): + """Test connection with token missing user ID.""" + # Setup mocks + mock_extract_token.return_value = "token_without_user_id" + mock_decode.return_value = {"other_field": "value"} # Missing 'sub' + + # Mock environment + environ = {"HTTP_COOKIE": "access_token=token_without_user_id"} + + # Access the connect handler directly + handlers = {} + original_event = socket_manager.sio.event + + def mock_event(func): + handlers[func.__name__] = func + return func + + socket_manager.sio.event = mock_event + socket_manager._setup_handlers() + + # Call the connect handler + await handlers['connect']("test_sid", environ) + + # Verify disconnection + mock_sio.disconnect.assert_called_once_with("test_sid") + + # Verify no user tracking + assert "test_sid" not in socket_manager.socket_users + assert len(socket_manager.user_rooms) == 0 + + @pytest.mark.asyncio + async def test_disconnect_handler(self, socket_manager, mock_sio): + """Test disconnect handler.""" + # Setup initial state + socket_manager.socket_users["test_sid"] = "123" + socket_manager.user_rooms["123"] = "user_123" + + # Access the disconnect handler directly + handlers = {} + original_event = socket_manager.sio.event + + def mock_event(func): + handlers[func.__name__] = func + return func + + socket_manager.sio.event = mock_event + socket_manager._setup_handlers() + + # Call the disconnect handler + await handlers['disconnect']("test_sid") + + # Verify cleanup + assert "test_sid" not in socket_manager.socket_users + assert "123" not in socket_manager.user_rooms + + @pytest.mark.asyncio + async def test_disconnect_handler_unknown_socket(self, socket_manager, mock_sio): + """Test disconnect handler with unknown socket.""" + # Access the disconnect handler directly + handlers = {} + original_event = socket_manager.sio.event + + def mock_event(func): + handlers[func.__name__] = func + return func + + socket_manager.sio.event = mock_event + socket_manager._setup_handlers() + + # Call the disconnect handler with unknown socket + await handlers['disconnect']("unknown_sid") + + # Should not raise any errors and state should remain clean + assert len(socket_manager.socket_users) == 0 + assert len(socket_manager.user_rooms) == 0 \ No newline at end of file diff --git a/tests/utils/test_cookies.py b/tests/utils/test_cookies.py new file mode 100644 index 0000000..f7f3243 --- /dev/null +++ b/tests/utils/test_cookies.py @@ -0,0 +1,149 @@ +"""Tests for cookie utilities.""" + +import pytest + +from app.utils.cookies import parse_cookies, extract_access_token_from_cookies + + +class TestParseCookies: + """Test cookie parsing functionality.""" + + def test_parse_empty_cookie_header(self): + """Test parsing empty cookie header.""" + result = parse_cookies("") + assert result == {} + + def test_parse_none_cookie_header(self): + """Test parsing None cookie header.""" + result = parse_cookies("") + assert result == {} + + def test_parse_single_cookie(self): + """Test parsing single cookie.""" + cookie_header = "session_id=abc123" + result = parse_cookies(cookie_header) + + assert result == {"session_id": "abc123"} + + def test_parse_multiple_cookies(self): + """Test parsing multiple cookies.""" + cookie_header = "session_id=abc123; user_pref=dark_mode; lang=en" + result = parse_cookies(cookie_header) + + expected = { + "session_id": "abc123", + "user_pref": "dark_mode", + "lang": "en" + } + assert result == expected + + def test_parse_cookies_with_spaces(self): + """Test parsing cookies with extra spaces.""" + cookie_header = " session_id = abc123 ; user_pref = dark_mode " + result = parse_cookies(cookie_header) + + expected = { + "session_id": "abc123", + "user_pref": "dark_mode" + } + assert result == expected + + def test_parse_cookies_with_equals_in_value(self): + """Test parsing cookies where value contains equals sign.""" + cookie_header = "encoded_data=key=value&other=data; session=123" + result = parse_cookies(cookie_header) + + expected = { + "encoded_data": "key=value&other=data", + "session": "123" + } + assert result == expected + + def test_parse_cookies_malformed(self): + """Test parsing malformed cookies (no equals sign).""" + cookie_header = "session_id=abc123; malformed_cookie; user_pref=dark" + result = parse_cookies(cookie_header) + + # Should skip malformed cookie and parse valid ones + expected = { + "session_id": "abc123", + "user_pref": "dark" + } + assert result == expected + + def test_parse_cookies_empty_values(self): + """Test parsing cookies with empty values.""" + cookie_header = "empty_value=; session_id=abc123" + result = parse_cookies(cookie_header) + + expected = { + "empty_value": "", + "session_id": "abc123" + } + assert result == expected + + def test_parse_cookies_duplicate_names(self): + """Test parsing cookies with duplicate names (last one wins).""" + cookie_header = "session_id=first; session_id=second" + result = parse_cookies(cookie_header) + + assert result == {"session_id": "second"} + + +class TestExtractAccessTokenFromCookies: + """Test access token extraction from cookies.""" + + def test_extract_access_token_present(self): + """Test extracting access token when present.""" + cookie_header = "session_id=abc123; access_token=jwt_token_here; user_pref=dark" + result = extract_access_token_from_cookies(cookie_header) + + assert result == "jwt_token_here" + + def test_extract_access_token_not_present(self): + """Test extracting access token when not present.""" + cookie_header = "session_id=abc123; user_pref=dark" + result = extract_access_token_from_cookies(cookie_header) + + assert result is None + + def test_extract_access_token_empty_header(self): + """Test extracting access token from empty header.""" + result = extract_access_token_from_cookies("") + assert result is None + + def test_extract_access_token_only_token(self): + """Test extracting access token when it's the only cookie.""" + cookie_header = "access_token=my_jwt_token" + result = extract_access_token_from_cookies(cookie_header) + + assert result == "my_jwt_token" + + def test_extract_access_token_with_spaces(self): + """Test extracting access token with spaces around values.""" + cookie_header = " access_token = jwt_token_with_spaces ; other=value " + result = extract_access_token_from_cookies(cookie_header) + + assert result == "jwt_token_with_spaces" + + def test_extract_access_token_empty_value(self): + """Test extracting access token with empty value.""" + cookie_header = "access_token=; other=value" + result = extract_access_token_from_cookies(cookie_header) + + assert result == "" + + def test_extract_access_token_complex_value(self): + """Test extracting access token with complex JWT-like value.""" + jwt_token = "eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJzdWIiOiIxMjMiLCJleHAiOjE2MzM5NjY0MDB9.signature" + cookie_header = f"session=abc; access_token={jwt_token}; csrf=token" + result = extract_access_token_from_cookies(cookie_header) + + assert result == jwt_token + + def test_extract_access_token_multiple_equals(self): + """Test extracting access token when value contains equals signs.""" + cookie_header = "access_token=encoded=data=here; other=simple" + result = extract_access_token_from_cookies(cookie_header) + + assert result == "encoded=data=here" \ No newline at end of file