"""WebSocket service for real-time communication with user rooms."""

import logging

import socketio

from app.core.config import settings
from app.utils.auth import JWTUtils
from app.utils.cookies import extract_access_token_from_cookies

logger = logging.getLogger(__name__)


class SocketManager:
    """Manages WebSocket connections and user rooms.

    Every authenticated socket joins the room ``user_<user_id>``. A user may
    hold several sockets at once (e.g. multiple browser tabs); all of them
    share that room. The user's room entry is kept until their LAST socket
    disconnects, so messages keep reaching the remaining connections.
    """

    def __init__(self) -> None:
        """Initialize the SocketManager with a Socket.IO server."""
        self.sio = socketio.AsyncServer(
            cors_allowed_origins=settings.CORS_ORIGINS,
            logger=True,
            engineio_logger=True,
            async_mode="asgi",
        )
        # Track user rooms: user_id -> room_id
        # (an entry exists while the user has at least one open socket)
        self.user_rooms: dict[str, str] = {}
        # Track socket sessions: socket_id -> user_id
        self.socket_users: dict[str, str] = {}
        # Track open sockets per user: user_id -> {socket_id, ...}
        self.user_sockets: dict[str, set[str]] = {}
        self._setup_handlers()

    def _register_socket(self, sid: str, user_id: str) -> str:
        """Record that socket ``sid`` belongs to ``user_id``.

        Returns:
            The id of the user's personal room.
        """
        room_id = f"user_{user_id}"
        self.socket_users[sid] = user_id
        self.user_sockets.setdefault(user_id, set()).add(sid)
        self.user_rooms[user_id] = room_id
        return room_id

    def _unregister_socket(self, sid: str, user_id: str) -> None:
        """Forget socket ``sid`` of ``user_id``.

        The user's room entry is removed only when no other socket of that
        user remains open.
        """
        self.socket_users.pop(sid, None)
        sockets = self.user_sockets.get(user_id)
        if sockets is not None:
            sockets.discard(sid)
            if sockets:
                # Other connections of this user are still in the room.
                return
            del self.user_sockets[user_id]
        self.user_rooms.pop(user_id, None)

    def _setup_handlers(self) -> None:
        """Set up socket event handlers."""

        @self.sio.event
        async def connect(sid: str, environ: dict) -> bool:
            """Authenticate a connecting client via its access-token cookie.

            Returns:
                False to reject the connection (python-socketio's documented
                rejection mechanism for connect handlers), True otherwise.
            """
            logger.info("Client %s attempting to connect", sid)

            # Extract access token from cookies
            cookie_header = environ.get("HTTP_COOKIE", "")
            access_token = extract_access_token_from_cookies(cookie_header)

            if not access_token:
                logger.warning("Client %s connecting without access token", sid)
                return False

            try:
                # Validate JWT token and extract user info
                payload = JWTUtils.decode_access_token(access_token)
                user_id = payload.get("sub")
            except Exception:
                logger.exception("Client %s invalid token", sid)
                return False

            if not user_id:
                logger.warning("Client %s token missing user ID", sid)
                return False

            logger.info("User %s connected with socket %s", user_id, sid)

            # Store socket-user mapping and join the user's personal room
            room_id = self._register_socket(sid, user_id)
            await self.sio.enter_room(sid, room_id)
            logger.info("User %s joined room %s", user_id, room_id)

            # Send welcome message to this socket only
            await self.sio.emit(
                "user_connected",
                {
                    "user_id": user_id,
                    "room_id": room_id,
                    "message": "Successfully connected to your personal room",
                },
                room=sid,
            )
            return True

        @self.sio.event
        async def disconnect(sid: str) -> None:
            """Handle client disconnection and clean up its bookkeeping."""
            user_id = self.socket_users.get(sid)
            if not user_id:
                logger.info("Unknown client %s disconnected", sid)
                return

            logger.info("User %s disconnected (socket %s)", user_id, sid)
            self._unregister_socket(sid, user_id)

        @self.sio.event
        async def play_sound(sid: str, data: dict) -> None:
            """Handle play sound event from client.

            Expects ``data`` to be a mapping with a ``sound_id`` entry; any
            other payload is treated as a request missing ``sound_id``.
            """
            user_id = self.socket_users.get(sid)
            if not user_id:
                logger.warning("Play sound request from unknown client %s", sid)
                return

            # Client payloads are untrusted: guard against non-dict data.
            sound_id = data.get("sound_id") if isinstance(data, dict) else None
            if not sound_id:
                logger.warning(
                    "Play sound request missing sound_id from user %s",
                    user_id,
                )
                return

            try:
                # Import here to avoid circular imports
                from app.api.v1.sounds import play_sound_internal

                # Call the internal play sound function
                await play_sound_internal(int(sound_id), user_id)
                logger.info("User %s played sound %s via WebSocket", user_id, sound_id)
            except Exception as e:
                logger.exception(
                    "Error playing sound %s for user %s: %s",
                    sound_id,
                    user_id,
                    e,
                )
                # Emit error back to user
                await self.sio.emit(
                    "sound_play_error",
                    {"sound_id": sound_id, "error": str(e)},
                    room=sid,
                )

    async def send_to_user(self, user_id: str, event: str, data: dict) -> bool:
        """Send a message to a specific user's room.

        Returns:
            True if the user has at least one open socket and the event was
            emitted to their room, False otherwise.
        """
        room_id = self.user_rooms.get(user_id)
        if room_id:
            await self.sio.emit(event, data, room=room_id)
            logger.debug("Sent %s to user %s in room %s", event, user_id, room_id)
            return True

        logger.warning("User %s not found in any room", user_id)
        return False

    async def broadcast_to_all(self, event: str, data: dict) -> None:
        """Broadcast a message to all connected users."""
        await self.sio.emit(event, data)
        logger.info("Broadcasted %s to all users", event)

    def get_connected_users(self) -> list[str]:
        """Get list of currently connected user IDs."""
        return list(self.user_rooms.keys())

    def get_room_info(self) -> dict:
        """Get information about connected users."""
        return {
            "total_users": len(self.user_rooms),
            "connected_users": list(self.user_rooms.keys()),
        }


# Global socket manager instance
socket_manager = SocketManager()