diff --git a/frontend/src/components/socket/Socket.jsx b/frontend/src/components/socket/Socket.jsx index 07e7a32..506a585 100644 --- a/frontend/src/components/socket/Socket.jsx +++ b/frontend/src/components/socket/Socket.jsx @@ -1,11 +1,38 @@ import socketIO from "socket.io-client"; import { SOCKET_URL } from "../../config"; +function getSocketToken() { + try { + return localStorage.getItem("token") || ""; + } catch { + return ""; + } +} + +function getSocketOptions() { + const token = getSocketToken(); + return { + transports: ["websocket", "polling"], + withCredentials: true, + autoConnect: Boolean(token), + auth: { token }, + }; +} + const socket = SOCKET_URL - ? socketIO(SOCKET_URL, { - transports: ["websocket", "polling"], - withCredentials: true, - }) - : socketIO(); + ? socketIO(SOCKET_URL, getSocketOptions()) + : socketIO(getSocketOptions()); + +window.addEventListener("piperchat:auth-token", () => { + const token = getSocketToken(); + socket.auth = { token }; + + if (token) { + socket.connect(); + return; + } + + socket.disconnect(); +}); export default socket; diff --git a/server/package.json b/server/package.json index 2f856a1..f0cf475 100644 --- a/server/package.json +++ b/server/package.json @@ -7,6 +7,7 @@ "scripts": { "start": "node src/index.js", "dev": "nodemon src/index.js", + "test:socket:auth": "node scripts/run-socket-auth-unit.mjs", "test:auth:unit": "node scripts/run-auth-jwt-unit.mjs", "test:auth": "node scripts/run-auth-integration.mjs", "gmail:oauth-setup": "node scripts/gmail-oauth-setup.mjs" diff --git a/server/scripts/run-socket-auth-unit.mjs b/server/scripts/run-socket-auth-unit.mjs new file mode 100644 index 0000000..05b5f34 --- /dev/null +++ b/server/scripts/run-socket-auth-unit.mjs @@ -0,0 +1,59 @@ +process.env.ACCESS_TOKEN = "socket-auth-test-secret"; + +const assert = await import("node:assert/strict"); +const jwt = await import("jsonwebtoken"); +const { + authenticateSocketHandshake, + isAuthenticatedUserClaim, +} = await import("../src/socket/index.js"); + +function runMiddleware(socket) { + return new Promise((resolve) => { + authenticateSocketHandshake(socket, (error) => resolve(error || null)); + }); +} + +const token = jwt.default.sign( + { id: "user-123", username: "tester" }, + process.env.ACCESS_TOKEN, +); + +const socket = { + handshake: { + auth: { token }, + headers: {}, + }, + data: {}, +}; + +assert.equal(await runMiddleware(socket), null); +assert.equal(socket.data.authenticated_user_id, "user-123"); +assert.equal(isAuthenticatedUserClaim(socket, "user-123"), true); +assert.equal(isAuthenticatedUserClaim(socket, "user-456"), false); + +const missingTokenSocket = { + handshake: { auth: {}, headers: {} }, + data: {}, +}; +assert.match((await runMiddleware(missingTokenSocket)).message, /required/); + +const invalidTokenSocket = { + handshake: { + auth: { token: "not-a-valid-token" }, + headers: {}, + }, + data: {}, +}; +assert.match((await runMiddleware(invalidTokenSocket)).message, /required/); + +const bearerSocket = { + handshake: { + auth: {}, + headers: { authorization: `Bearer ${token}` }, + }, + data: {}, +}; +assert.equal(await runMiddleware(bearerSocket), null); +assert.equal(bearerSocket.data.authenticated_user_id, "user-123"); + +console.log("socket auth unit checks passed"); diff --git a/server/src/socket/index.js b/server/src/socket/index.js index 8561f10..eab10bf 100644 --- a/server/src/socket/index.js +++ b/server/src/socket/index.js @@ -1,8 +1,62 @@ import User from "../models/User.js"; import { buildServerTypingEvent } from "../lib/typingEvents.js"; +import config from "../config/index.js"; +import jwt from "jsonwebtoken"; const onlineUsers = new Map(); +function getHandshakeToken(socket) { + const authToken = socket.handshake?.auth?.token; + if (typeof authToken === "string" && authToken.trim()) { + return authToken.trim(); + } + + const headerToken = socket.handshake?.headers?.["x-auth-token"]; + if (typeof headerToken === "string" && headerToken.trim()) { + return headerToken.trim(); + } + + const authorization = socket.handshake?.headers?.authorization; + if (typeof authorization === "string") { + const match = authorization.match(/^Bearer\s+(.+)$/i); + if (match?.[1]?.trim()) { + return match[1].trim(); + } + } + + return ""; +} + +function getPayloadUserId(payload) { + const value = payload?.id ?? payload?._id ?? payload?.user_id ?? payload?.userId; + return value == null ? "" : String(value); +} + +function authenticateSocketHandshake(socket, next) { + const token = getHandshakeToken(socket); + if (!token || !config.ACCESS_TOKEN) { + return next(new Error("Socket authentication required")); + } + + try { + const payload = jwt.verify(token, config.ACCESS_TOKEN); + const userId = getPayloadUserId(payload); + if (!userId) { + return next(new Error("Socket authentication required")); + } + + socket.data.authenticated_user_id = userId; + socket.data.authenticated_user = payload; + return next(); + } catch { + return next(new Error("Socket authentication required")); + } +} + +function isAuthenticatedUserClaim(socket, userId) { + return String(socket.data.authenticated_user_id || "") === String(userId || ""); +} + async function shouldSendNotification(userId, preferenceKey) { try { const user = await User.findById(userId).lean(); @@ -63,13 +117,13 @@ function setUserOffline(io, userId, socketId) { } function attachSocketHandlers(io) { + io.use(authenticateSocketHandshake); + io.on("connection", (socket) => { socket.on("channelCreated", (data) => { io.emit("newChannel", data); }); - }); - io.on("connection", (socket) => { socket.on("get_userid", (user_id) => { const normalizedUserId = String(user_id); @@ -79,6 +133,13 @@ function attachSocketHandlers(io) { return; } + if (!isAuthenticatedUserClaim(socket, normalizedUserId)) { + socket.emit("socket_auth_error", { + message: "Authenticated user mismatch", + }); + return; + } + if (socket.data.user_id) { setUserOffline(io, socket.data.user_id, socket.id); } @@ -92,6 +153,13 @@ function attachSocketHandlers(io) { socket.on( "send_req", async (receiver_id, sender_id, sender_profile_pic, sender_name) => { + if (!isAuthenticatedUserClaim(socket, sender_id)) { + socket.emit("socket_auth_error", { + message: "Authenticated user mismatch", + }); + return; + } + const shouldNotify = await shouldSendNotification(receiver_id, "friend_requests"); if (shouldNotify) { socket.to(receiver_id).emit("recieve_req", { @@ -237,4 +305,4 @@ function attachSocketHandlers(io) { }); } -export { attachSocketHandlers }; +export { attachSocketHandlers, authenticateSocketHandshake, isAuthenticatedUserClaim };