"""WebSocket service for real-time communication with user rooms.""" import logging import socketio from app.utils.auth import JWTUtils from app.utils.cookies import extract_access_token_from_cookies logger = logging.getLogger(__name__) class SocketManager: """Manages WebSocket connections and user rooms.""" def __init__(self): self.sio = socketio.AsyncServer( cors_allowed_origins=["http://localhost:8001"], logger=True, engineio_logger=True, async_mode="asgi", ) # Track user rooms: user_id -> room_id self.user_rooms: dict[str, str] = {} # Track socket sessions: socket_id -> user_id self.socket_users: dict[str, str] = {} self._setup_handlers() def _setup_handlers(self): """Set up socket event handlers.""" @self.sio.event async def connect(sid, environ, auth=None): """Handle client connection.""" logger.info(f"Client {sid} attempting to connect") # Extract access token from cookies cookie_header = environ.get("HTTP_COOKIE", "") access_token = extract_access_token_from_cookies(cookie_header) if not access_token: logger.warning(f"Client {sid} connecting without access token") await self.sio.disconnect(sid) return try: # Validate JWT token and extract user info payload = JWTUtils.decode_access_token(access_token) user_id = payload.get("sub") if not user_id: logger.warning(f"Client {sid} token missing user ID") await self.sio.disconnect(sid) return logger.info(f"User {user_id} connected with socket {sid}") except Exception as e: logger.warning(f"Client {sid} invalid token: {e}") await self.sio.disconnect(sid) return # Store socket-user mapping self.socket_users[sid] = user_id # Create or join user's personal room room_id = f"user_{user_id}" await self.sio.enter_room(sid, room_id) # Update room tracking self.user_rooms[user_id] = room_id logger.info(f"User {user_id} joined room {room_id}") # Send welcome message to user await self.sio.emit( "user_connected", { "user_id": user_id, "room_id": room_id, "message": "Successfully connected to your personal room", }, room=sid, ) @self.sio.event async def disconnect(sid): """Handle client disconnection.""" user_id = self.socket_users.get(sid) if user_id: logger.info(f"User {user_id} disconnected (socket {sid})") # Clean up mappings del self.socket_users[sid] if user_id in self.user_rooms: del self.user_rooms[user_id] else: logger.info(f"Unknown client {sid} disconnected") async def send_to_user(self, user_id: str, event: str, data: dict): """Send a message to a specific user's room.""" room_id = self.user_rooms.get(user_id) if room_id: await self.sio.emit(event, data, room=room_id) logger.debug(f"Sent {event} to user {user_id} in room {room_id}") return True logger.warning(f"User {user_id} not found in any room") return False async def broadcast_to_all(self, event: str, data: dict): """Broadcast a message to all connected users.""" await self.sio.emit(event, data) logger.info(f"Broadcasted {event} to all users") def get_connected_users(self) -> list: """Get list of currently connected user IDs.""" return list(self.user_rooms.keys()) def get_room_info(self) -> dict: """Get information about connected users.""" return { "total_users": len(self.user_rooms), "connected_users": list(self.user_rooms.keys()), } # Global socket manager instance socket_manager = SocketManager()