Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
37 changes: 32 additions & 5 deletions frontend/src/components/socket/Socket.jsx
Original file line number Diff line number Diff line change
@@ -1,11 +1,38 @@
import socketIO from "socket.io-client";
import { SOCKET_URL } from "../../config";

function getSocketToken() {
try {
return localStorage.getItem("token") || "";
} catch {
return "";
}
}

function getSocketOptions() {
const token = getSocketToken();
return {
transports: ["websocket", "polling"],
withCredentials: true,
autoConnect: Boolean(token),
auth: { token },
};
}

const socket = SOCKET_URL
? socketIO(SOCKET_URL, {
transports: ["websocket", "polling"],
withCredentials: true,
})
: socketIO();
? socketIO(SOCKET_URL, getSocketOptions())
: socketIO(getSocketOptions());

window.addEventListener("piperchat:auth-token", () => {
const token = getSocketToken();
socket.auth = { token };

if (token) {
socket.connect();
return;
}

socket.disconnect();
});

export default socket;
1 change: 1 addition & 0 deletions server/package.json
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
"scripts": {
"start": "node src/index.js",
"dev": "nodemon src/index.js",
"test:socket:auth": "node scripts/run-socket-auth-unit.mjs",
"test:auth:unit": "node scripts/run-auth-jwt-unit.mjs",
"test:auth": "node scripts/run-auth-integration.mjs",
"gmail:oauth-setup": "node scripts/gmail-oauth-setup.mjs"
Expand Down
59 changes: 59 additions & 0 deletions server/scripts/run-socket-auth-unit.mjs
Original file line number Diff line number Diff line change
@@ -0,0 +1,59 @@
process.env.ACCESS_TOKEN = "socket-auth-test-secret";

const assert = await import("node:assert/strict");
const jwt = await import("jsonwebtoken");
const {
authenticateSocketHandshake,
isAuthenticatedUserClaim,
} = await import("../src/socket/index.js");

function runMiddleware(socket) {
return new Promise((resolve) => {
authenticateSocketHandshake(socket, (error) => resolve(error || null));
});
}

const token = jwt.default.sign(
{ id: "user-123", username: "tester" },
process.env.ACCESS_TOKEN,
);

const socket = {
handshake: {
auth: { token },
headers: {},
},
data: {},
};

assert.equal(await runMiddleware(socket), null);
assert.equal(socket.data.authenticated_user_id, "user-123");
assert.equal(isAuthenticatedUserClaim(socket, "user-123"), true);
assert.equal(isAuthenticatedUserClaim(socket, "user-456"), false);

const missingTokenSocket = {
handshake: { auth: {}, headers: {} },
data: {},
};
assert.match((await runMiddleware(missingTokenSocket)).message, /required/);

const invalidTokenSocket = {
handshake: {
auth: { token: "not-a-valid-token" },
headers: {},
},
data: {},
};
assert.match((await runMiddleware(invalidTokenSocket)).message, /required/);

const bearerSocket = {
handshake: {
auth: {},
headers: { authorization: `Bearer ${token}` },
},
data: {},
};
assert.equal(await runMiddleware(bearerSocket), null);
assert.equal(bearerSocket.data.authenticated_user_id, "user-123");

console.log("socket auth unit checks passed");
74 changes: 71 additions & 3 deletions server/src/socket/index.js
Original file line number Diff line number Diff line change
@@ -1,8 +1,62 @@
import User from "../models/User.js";
import { buildServerTypingEvent } from "../lib/typingEvents.js";
import config from "../config/index.js";
import jwt from "jsonwebtoken";

const onlineUsers = new Map();

function getHandshakeToken(socket) {
const authToken = socket.handshake?.auth?.token;
if (typeof authToken === "string" && authToken.trim()) {
return authToken.trim();
}

const headerToken = socket.handshake?.headers?.["x-auth-token"];
if (typeof headerToken === "string" && headerToken.trim()) {
return headerToken.trim();
}

const authorization = socket.handshake?.headers?.authorization;
if (typeof authorization === "string") {
const match = authorization.match(/^Bearer\s+(.+)$/i);
if (match?.[1]?.trim()) {
return match[1].trim();
}
}

return "";
}

function getPayloadUserId(payload) {
const value = payload?.id ?? payload?._id ?? payload?.user_id ?? payload?.userId;
return value == null ? "" : String(value);
}

function authenticateSocketHandshake(socket, next) {
const token = getHandshakeToken(socket);
if (!token || !config.ACCESS_TOKEN) {
return next(new Error("Socket authentication required"));
}

try {
const payload = jwt.verify(token, config.ACCESS_TOKEN);
const userId = getPayloadUserId(payload);
if (!userId) {
return next(new Error("Socket authentication required"));
}

socket.data.authenticated_user_id = userId;
socket.data.authenticated_user = payload;
return next();
} catch {
return next(new Error("Socket authentication required"));
}
}

function isAuthenticatedUserClaim(socket, userId) {
return String(socket.data.authenticated_user_id || "") === String(userId || "");
}

async function shouldSendNotification(userId, preferenceKey) {
try {
const user = await User.findById(userId).lean();
Expand Down Expand Up @@ -63,13 +117,13 @@ function setUserOffline(io, userId, socketId) {
}

function attachSocketHandlers(io) {
io.use(authenticateSocketHandshake);

io.on("connection", (socket) => {
socket.on("channelCreated", (data) => {
io.emit("newChannel", data);
});
});

io.on("connection", (socket) => {
socket.on("get_userid", (user_id) => {
const normalizedUserId = String(user_id);

Expand All @@ -79,6 +133,13 @@ function attachSocketHandlers(io) {
return;
}

if (!isAuthenticatedUserClaim(socket, normalizedUserId)) {
socket.emit("socket_auth_error", {
message: "Authenticated user mismatch",
});
return;
}

if (socket.data.user_id) {
setUserOffline(io, socket.data.user_id, socket.id);
}
Expand All @@ -92,6 +153,13 @@ function attachSocketHandlers(io) {
socket.on(
"send_req",
async (receiver_id, sender_id, sender_profile_pic, sender_name) => {
if (!isAuthenticatedUserClaim(socket, sender_id)) {
socket.emit("socket_auth_error", {
message: "Authenticated user mismatch",
});
return;
}

const shouldNotify = await shouldSendNotification(receiver_id, "friend_requests");
if (shouldNotify) {
socket.to(receiver_id).emit("recieve_req", {
Expand Down Expand Up @@ -237,4 +305,4 @@ function attachSocketHandlers(io) {
});
}

export { attachSocketHandlers };
export { attachSocketHandlers, authenticateSocketHandshake, isAuthenticatedUserClaim };
Loading