diff --git a/apps/client/src/screens/connect/index.tsx b/apps/client/src/screens/connect/index.tsx index 4dd1b0e1b..fff094638 100644 --- a/apps/client/src/screens/connect/index.tsx +++ b/apps/client/src/screens/connect/index.tsx @@ -6,6 +6,7 @@ import { getFileUrl, getUrlFromServer } from '@/helpers/get-file-url'; import { getLocalStorageItem, getLocalStorageItemBool, + getSessionStorageItem, LocalStorageKey, removeLocalStorageItem, SessionStorageKey, @@ -14,6 +15,7 @@ import { setSessionStorageItem } from '@/helpers/storage'; import { useForm } from '@/hooks/use-form'; +import { useStrictEffect } from '@/hooks/use-strict-effect'; import { PluginSlot, TestId } from '@sharkord/shared'; import { Alert, @@ -58,6 +60,77 @@ const Connect = memo(() => { return invite || undefined; }, []); + const onRememberCredentialsChange = useCallback( + (checked: boolean) => { + onChange('rememberCredentials', checked); + + if (checked) { + setLocalStorageItem(LocalStorageKey.REMEMBER_CREDENTIALS, 'true'); + } else { + removeLocalStorageItem(LocalStorageKey.REMEMBER_CREDENTIALS); + } + }, + [onChange] + ); + + const onOidcLoginClick = useCallback(() => { + const url = getUrlFromServer(); + window.location.href = `${url}/auth/login`; + }, []); + + const handleOidcSuccess = useCallback(async () => { + setLoading(true); + try { + const cookies = document.cookie.split('; ').reduce( + (acc, current) => { + const [name, value] = current.split('='); + acc[name] = value; + return acc; + }, + {} as Record + ); + + const token = cookies['sharkord_token']; + + if (token) { + setSessionStorageItem(SessionStorageKey.TOKEN, token); + + document.cookie = + 'sharkord_token=; Max-Age=0; path=/; SameSite=Lax; Secure'; + } else { + throw new Error('No authentication token found in cookies.'); + } + + await connect(); + toast.success('Logged in with OIDC'); + } catch (error) { + const errorMessage = + error instanceof Error ? error.message : String(error); + toast.error(`Could not connect with OIDC: ${errorMessage}`); + } finally { + setLoading(false); + } + }, []); + + useStrictEffect(() => { + const urlParams = new URLSearchParams(window.location.search); + const oidcStatus = urlParams.get('oidc_status'); + + if (oidcStatus === 'success') { + const newUrl = window.location.pathname; + window.history.replaceState({}, document.title, newUrl); + + handleOidcSuccess(); + } + }, [handleOidcSuccess]); + + useStrictEffect(() => { + const token = getSessionStorageItem(SessionStorageKey.TOKEN); + if (info && info.oidcEnabled && !info.allowNewUsers && !token) { + onOidcLoginClick(); + } + }, [info, onOidcLoginClick]); + const onConnectClick = useCallback(async () => { setLoading(true); @@ -121,10 +194,7 @@ const Connect = memo(() => { }, [info]); return ( -
-
- -
+
@@ -148,30 +218,32 @@ const Connect = memo(() => { )} -
{ - e.preventDefault(); - onConnectClick(); - }} - > - - - - - - -
+ {!(info?.oidcEnabled && !info?.allowNewUsers) && ( +
{ + e.preventDefault(); + onConnectClick(); + }} + > + + + + + + +
+ )}
{ )} - - - {!info?.allowNewUsers && ( - <> - {!inviteCode && ( - - {t('registrationDisabled')} - - )} - + {!(info?.oidcEnabled && !info?.allowNewUsers) && ( + + )} + + {info?.oidcEnabled && ( + + )} + + {!info?.allowNewUsers && !inviteCode && ( + + {t('registrationDisabled')} + )} {inviteCode && ( @@ -228,8 +309,8 @@ const Connect = memo(() => { -
- v{VITE_APP_VERSION} +
+ v{import.meta.env.VITE_APP_VERSION} { > Sharkord + +
); diff --git a/apps/server/package.json b/apps/server/package.json index 0b5a0b7c1..4f5a46ebb 100644 --- a/apps/server/package.json +++ b/apps/server/package.json @@ -52,6 +52,7 @@ "jsonwebtoken": "^9.0.2", "link-preview-js": "^3.1.0", "mediasoup": "^3.19.17", + "openid-client": "^6.8.2", "queue": "^7.0.0", "sanitize-html": "^2.17.0", "semver": "^7.7.3", diff --git a/apps/server/src/config.ts b/apps/server/src/config.ts index a5d078a3c..bd84a2a63 100644 --- a/apps/server/src/config.ts +++ b/apps/server/src/config.ts @@ -14,11 +14,49 @@ const [SERVER_PUBLIC_IP, SERVER_PRIVATE_IP] = await Promise.all([ getPrivateIp() ]); +const jsonTransform = (fallback: T) => + z + .preprocess((val) => { + if (typeof val !== 'string') return val; + try { + return JSON.parse(val); + } catch { + return fallback; + } + }, z.any()) + .transform((val) => val as T); + +const commaSeparatedTransform = (fallback: string[]) => + z.preprocess((val) => { + if (typeof val !== 'string') return val; + return val + .split(',') + .map((s) => s.trim()) + .filter(Boolean); + }, z.string().array()); + const zConfig = z.object({ server: z.object({ port: z.coerce.number().int().positive(), debug: z.coerce.boolean(), - autoupdate: z.coerce.boolean() + autoupdate: z.coerce.boolean(), + disableLocalSignup: z.coerce.boolean() + }), + oidc: z.object({ + oidcEnabled: z.coerce.boolean(), + enforceOidcRoles: z.coerce.boolean(), + issuer: z.string(), + clientId: z.string(), + clientSecret: z.string(), + rolesMapping: jsonTransform>({}), + requiredGroups: commaSeparatedTransform([]), + allowedOrigins: commaSeparatedTransform([]), + caCertPath: z.string().optional(), + groupsClaim: z.string(), + usernameClaim: z.string(), + displayNameClaim: z.string(), + enforceOidcDisplayName: z.coerce.boolean(), + additionalScopes: commaSeparatedTransform([]) }), webRtc: z.object({ port: z.coerce.number().int().positive(), @@ -45,13 +83,30 @@ const zConfig = z.object({ }) }); -type TConfig = z.infer; +type TConfig = z.output; const defaultConfig: TConfig = { server: { port: 4991, debug: IS_DEVELOPMENT, - autoupdate: false + autoupdate: false, + disableLocalSignup: false + }, + oidc: { + oidcEnabled: false, + enforceOidcRoles: true, + issuer: 'https://auth.example.com/.well-known/openid-configuration', + clientId: '', + clientSecret: '', + rolesMapping: {}, + requiredGroups: [], + allowedOrigins: [], + caCertPath: '', + groupsClaim: 'groups', + usernameClaim: 'preferred_username', + displayNameClaim: '', + enforceOidcDisplayName: true, + additionalScopes: [] }, webRtc: { port: 40000, @@ -78,48 +133,74 @@ const defaultConfig: TConfig = { } }; -let config: TConfig = structuredClone(defaultConfig); +const prepareForSave = (data: TConfig) => { + const { oidc, ...rest } = data; + const { allowedOrigins, rolesMapping, ...oidcRest } = oidc; + + return { + ...rest, + oidc: { + ...oidcRest, + rolesMapping: JSON.stringify(rolesMapping), + allowedOrigins: allowedOrigins.join(',') + } + }; +}; + +let config: TConfig = zConfig.parse(defaultConfig); await ensureServerDirs(); const configExists = await fs.exists(CONFIG_INI_PATH); if (!configExists) { - // config does not exist, create it with the default config - await fs.writeFile(CONFIG_INI_PATH, stringify(config)); + await fs.writeFile(CONFIG_INI_PATH, stringify(prepareForSave(config))); } else { try { - // config exists, we need to make sure it is up to date with the schema - // to make this easy, we will read the existing config, merge it with the default config, and write it back to the file - // this way we don't have to worry about migrating old config files when we add/remove config options const existingConfigText = await fs.readFile(CONFIG_INI_PATH, { encoding: 'utf-8' }); + const existingConfig = parse(existingConfigText); - const existingConfig = parse(existingConfigText) as Partial; - const mergedConfig = deepMerge(config, existingConfig); - + const mergedConfig = deepMerge(defaultConfig, existingConfig); config = zConfig.parse(mergedConfig); - await fs.writeFile(CONFIG_INI_PATH, stringify(config)); + await fs.writeFile(CONFIG_INI_PATH, stringify(prepareForSave(config))); } catch (error) { - // something went wrong, just log the error and overwrite the config file with the default config console.error( - `Error reading or parsing config.ini. Overwriting with default config. Error: ${getErrorMessage(error)}` + `Error parsing config.ini. Resetting to defaults. Error: ${getErrorMessage(error)}` ); - - await fs.writeFile(CONFIG_INI_PATH, stringify(config)); + await fs.writeFile(CONFIG_INI_PATH, stringify(prepareForSave(config))); } } -config = applyEnvOverrides(config, { - 'server.port': 'SHARKORD_PORT', - 'server.debug': 'SHARKORD_DEBUG', - 'server.autoupdate': 'SHARKORD_AUTOUPDATE', - 'webRtc.port': 'SHARKORD_WEBRTC_PORT', - 'webRtc.announcedAddress': 'SHARKORD_WEBRTC_ANNOUNCED_ADDRESS', - 'webRtc.maxBitrate': 'SHARKORD_WEBRTC_MAX_BITRATE' -}); +config = zConfig.parse( + applyEnvOverrides(config, { + 'server.port': 'SHARKORD_PORT', + 'server.debug': 'SHARKORD_DEBUG', + 'server.autoupdate': 'SHARKORD_AUTOUPDATE', + 'server.disableLocalSignup': 'SHARKORD_DISABLE_LOCAL_SIGNUP', + + 'oidc.oidcEnabled': 'OIDC_ENABLED', + 'oidc.enforceOidcRoles': 'OIDC_ENFORCE_ROLES', + 'oidc.issuer': 'OIDC_ISSUER', + 'oidc.clientId': 'OIDC_CLIENT_ID', + 'oidc.clientSecret': 'OIDC_CLIENT_SECRET', + 'oidc.rolesMapping': 'OIDC_ROLES_MAPPING', + 'oidc.requiredGroups': 'OIDC_REQUIRED_GROUPS', + 'oidc.allowedOrigins': 'OIDC_ALLOWED_ORIGINS', + 'oidc.caCertPath': 'OIDC_CA_CERT_PATH', + 'oidc.groupsClaim': 'OIDC_GROUPS_CLAIM', + 'oidc.usernameClaim': 'OIDC_USERNAME_CLAIM', + 'oidc.displayNameClaim': 'OIDC_DISPLAY_NAME_CLAIM', + 'oidc.enforceOidcDisplayName': 'OIDC_ENFORCE_DISPLAY_NAME', + 'oidc.additionalScopes': 'OIDC_ADDITIONAL_SCOPES', + + 'webRtc.port': 'SHARKORD_WEBRTC_PORT', + 'webRtc.announcedAddress': 'SHARKORD_WEBRTC_ANNOUNCED_ADDRESS', + 'webRtc.maxBitrate': 'SHARKORD_WEBRTC_MAX_BITRATE' + }) +); config = Object.freeze(config); diff --git a/apps/server/src/db/index.ts b/apps/server/src/db/index.ts index a0e2e8d9e..38d6a4854 100644 --- a/apps/server/src/db/index.ts +++ b/apps/server/src/db/index.ts @@ -2,16 +2,17 @@ import { Database } from 'bun:sqlite'; import { migrate } from 'drizzle-orm/better-sqlite3/migrator'; import { BunSQLiteDatabase, drizzle } from 'drizzle-orm/bun-sqlite'; import { DB_PATH, DRIZZLE_PATH } from '../helpers/paths'; +import * as schema from './schema'; import { seedDatabase } from './seed'; -let db: BunSQLiteDatabase; +let db: BunSQLiteDatabase; const loadDb = async () => { const sqlite = new Database(DB_PATH, { create: true, strict: true }); sqlite.run('PRAGMA foreign_keys = ON;'); - db = drizzle({ client: sqlite }); + db = drizzle(sqlite, { schema }); await migrate(db, { migrationsFolder: DRIZZLE_PATH }); await seedDatabase(); diff --git a/apps/server/src/db/migrations/0008_oidc_sub.sql b/apps/server/src/db/migrations/0008_oidc_sub.sql new file mode 100644 index 000000000..cb429133c --- /dev/null +++ b/apps/server/src/db/migrations/0008_oidc_sub.sql @@ -0,0 +1,3 @@ +ALTER TABLE `users` ADD `oidc_sub` text; +--> statement-breakpoint +CREATE UNIQUE INDEX `users_oidc_sub_idx` ON `users` (`oidc_sub`); diff --git a/apps/server/src/db/migrations/0009_roles_added_by.sql b/apps/server/src/db/migrations/0009_roles_added_by.sql new file mode 100644 index 000000000..501addca6 --- /dev/null +++ b/apps/server/src/db/migrations/0009_roles_added_by.sql @@ -0,0 +1 @@ +ALTER TABLE `user_roles` ADD `added_by` text NOT NULL DEFAULT 'manual'; diff --git a/apps/server/src/db/migrations/meta/_journal.json b/apps/server/src/db/migrations/meta/_journal.json index 38e90616c..92b6f6a7f 100644 --- a/apps/server/src/db/migrations/meta/_journal.json +++ b/apps/server/src/db/migrations/meta/_journal.json @@ -57,6 +57,20 @@ "when": 1772897866100, "tag": "0007_overrated_carnage", "breakpoints": true + }, + { + "idx": 8, + "version": "6", + "when": 1772897870000, + "tag": "0008_oidc_sub", + "breakpoints": true + }, + { + "idx": 9, + "version": "6", + "when": 1772897874000, + "tag": "0009_roles_added_by", + "breakpoints": true } ] } \ No newline at end of file diff --git a/apps/server/src/db/queries/users.ts b/apps/server/src/db/queries/users.ts index bc57af9e7..e57a07830 100644 --- a/apps/server/src/db/queries/users.ts +++ b/apps/server/src/db/queries/users.ts @@ -210,6 +210,7 @@ const getUserById = async ( banned: users.banned, banReason: users.banReason, bannedAt: users.bannedAt, + oidcSub: users.oidcSub, avatar: avatarFiles, banner: bannerFiles }) @@ -235,6 +236,54 @@ const getUserById = async ( }; }; +const getUserByOidcSub = async ( + oidcSub: string +): Promise => { + const avatarFiles = alias(files, 'avatarFiles'); + const bannerFiles = alias(files, 'bannerFiles'); + + const user = await db + .select({ + id: users.id, + identity: users.identity, + name: users.name, + avatarId: users.avatarId, + bannerId: users.bannerId, + bio: users.bio, + bannerColor: users.bannerColor, + createdAt: users.createdAt, + updatedAt: users.updatedAt, + password: users.password, + lastLoginAt: users.lastLoginAt, + banned: users.banned, + banReason: users.banReason, + bannedAt: users.bannedAt, + oidcSub: users.oidcSub, + avatar: avatarFiles, + banner: bannerFiles + }) + .from(users) + .leftJoin(avatarFiles, eq(users.avatarId, avatarFiles.id)) + .leftJoin(bannerFiles, eq(users.bannerId, bannerFiles.id)) + .where(eq(users.oidcSub, oidcSub)) + .get(); + + if (!user) return undefined; + + const roles = await db + .select({ roleId: userRoles.roleId }) + .from(userRoles) + .where(eq(userRoles.userId, user.id)) + .all(); + + return { + ...user, + avatar: user.avatar, + banner: user.banner, + roleIds: roles.map((r) => r.roleId) + }; +}; + const getUserByIdentity = async ( identity: string ): Promise => { @@ -257,6 +306,7 @@ const getUserByIdentity = async ( banned: users.banned, banReason: users.banReason, bannedAt: users.bannedAt, + oidcSub: users.oidcSub, avatar: avatarFiles, banner: bannerFiles }) @@ -316,6 +366,7 @@ const getUsers = async (): Promise => { banned: users.banned, banReason: users.banReason, bannedAt: users.bannedAt, + oidcSub: users.oidcSub, avatar: avatarFiles, banner: bannerFiles }) @@ -359,6 +410,7 @@ const getUsers = async (): Promise => { banned: result.banned, banReason: result.banReason, bannedAt: result.bannedAt, + oidcSub: result.oidcSub, roleIds: rolesMap[result.id] || [] })); }; @@ -375,6 +427,7 @@ export { getStorageUsageByUserId, getUserById, getUserByIdentity, + getUserByOidcSub, getUserByToken, getUsers }; diff --git a/apps/server/src/db/schema.ts b/apps/server/src/db/schema.ts index 281631620..9ac0cd2e1 100644 --- a/apps/server/src/db/schema.ts +++ b/apps/server/src/db/schema.ts @@ -150,6 +150,7 @@ const users = sqliteTable( banReason: text('ban_reason'), bannedAt: integer('banned_at'), bannerColor: text('banner_color'), + oidcSub: text('oidc_sub').unique(), lastLoginAt: integer('last_login_at') .notNull() .$defaultFn(() => Date.now()), @@ -158,6 +159,7 @@ const users = sqliteTable( }, (t) => [ uniqueIndex('users_identity_idx').on(t.identity), + uniqueIndex('users_oidc_sub_idx').on(t.oidcSub), index('users_name_idx').on(t.name), index('users_banned_idx').on(t.banned), index('users_last_login_idx').on(t.lastLoginAt) @@ -173,7 +175,10 @@ const userRoles = sqliteTable( roleId: integer('role_id') .notNull() .references(() => roles.id, { onDelete: 'cascade' }), - createdAt: integer('created_at').notNull() + createdAt: integer('created_at').notNull(), + addedBy: text('added_by', { enum: ['manual', 'oidc', 'bot'] }) + .notNull() + .default('manual') }, (t) => [ primaryKey({ columns: [t.userId, t.roleId] }), diff --git a/apps/server/src/http/index.ts b/apps/server/src/http/index.ts index d6b7a9308..c7ace6897 100644 --- a/apps/server/src/http/index.ts +++ b/apps/server/src/http/index.ts @@ -14,6 +14,7 @@ import { import { infoRouteHandler } from './info'; import { interfaceRouteHandler } from './interface'; import { loginRouteHandler } from './login'; +import { oidcCallback, oidcLogin } from './oidc'; import { pluginBundleRouteHandler } from './plugin-bundle'; import { pluginsComponentsRouteHandler } from './plugins-components'; import { publicRouteHandler } from './public'; @@ -38,7 +39,9 @@ const routeHandlers: Partial< GET: { exact: { '/healthz': (req, res) => healthRouteHandler(req, res), - '/info': (req, res) => infoRouteHandler(req, res) + '/info': (req, res) => infoRouteHandler(req, res), + '/auth/login': (req, res) => oidcLogin(req, res), + '/auth/callback': (req, res) => oidcCallback(req, res) }, prefix: { '/public': (req, res) => publicRouteHandler(req, res), diff --git a/apps/server/src/http/info.ts b/apps/server/src/http/info.ts index e78328c8b..0b87d8ec7 100644 --- a/apps/server/src/http/info.ts +++ b/apps/server/src/http/info.ts @@ -1,5 +1,6 @@ import type { TServerInfo } from '@sharkord/shared'; import http from 'http'; +import { config } from '../config'; import { getSettings } from '../db/queries/server'; import { SERVER_VERSION } from '../utils/env'; @@ -15,7 +16,10 @@ const infoRouteHandler = async ( name: settings.name, description: settings.description, logo: settings.logo, - allowNewUsers: settings.allowNewUsers + allowNewUsers: config.server.disableLocalSignup + ? false + : settings.allowNewUsers, + oidcEnabled: config.oidc.oidcEnabled }; res.writeHead(200, { 'Content-Type': 'application/json' }); diff --git a/apps/server/src/http/login.ts b/apps/server/src/http/login.ts index 064cc72cd..2175bef45 100644 --- a/apps/server/src/http/login.ts +++ b/apps/server/src/http/login.ts @@ -159,6 +159,13 @@ const loginRouteHandler = async ( } if (!existingUser) { + if (config.server.disableLocalSignup && config.oidc.oidcEnabled) { + throw new HttpValidationError( + 'identity', + 'Registration is only allowed via OIDC' + ); + } + let inviteRoleId: number | null = null; const result = await isInviteValid(data.invite); diff --git a/apps/server/src/http/oidc.ts b/apps/server/src/http/oidc.ts new file mode 100644 index 000000000..c71353de7 --- /dev/null +++ b/apps/server/src/http/oidc.ts @@ -0,0 +1,455 @@ +import { createHash, randomBytes, timingSafeEqual } from 'crypto'; +import { and, eq, inArray } from 'drizzle-orm'; +import fs from 'fs/promises'; +import http from 'http'; +import jwt from 'jsonwebtoken'; +import * as client from 'openid-client'; +import { config } from '../config'; +import { db } from '../db'; +import { publishUser } from '../db/publishers'; +import { getDefaultRole, getRoles } from '../db/queries/roles'; +import { getServerToken } from '../db/queries/server'; +import { getUserByIdentity, getUserByOidcSub } from '../db/queries/users'; +import { userRoles, users } from '../db/schema'; + +const getBaseUrl = (req: http.IncomingMessage) => { + const protocol = (req.headers['x-forwarded-proto'] as string) || 'http'; + const host = req.headers.host; + return `${protocol}://${host}`; +}; + +const safeCompare = (a: string, b: string) => { + const bufA = Buffer.from(a); + const bufB = Buffer.from(b); + return bufA.length === bufB.length && timingSafeEqual(bufA, bufB); +}; + +// Cache the OIDC discovery document for 5 minutes to avoid hitting the +// IdP well-known endpoint on every login and every callback. +let discoveryCache: { + value: Awaited>; + issuer: string; + expiresAt: number; +} | null = null; +const DISCOVERY_CACHE_TTL_MS = 5 * 60 * 1000; + +export const getOidcConfig = async () => { + const issuerUrl = new URL(config.oidc.issuer); + + const isLocal = + issuerUrl.hostname === 'localhost' || issuerUrl.hostname === '127.0.0.1'; + if (!isLocal && issuerUrl.protocol !== 'https:') { + throw new Error( + `Security Error: OIDC Issuer must use HTTPS for non-local host: ${issuerUrl.hostname}` + ); + } + + if ( + discoveryCache && + discoveryCache.issuer === config.oidc.issuer && + Date.now() < discoveryCache.expiresAt + ) { + return discoveryCache.value; + } + + const discoveryOptions: any = {}; + + if (config.oidc.caCertPath) { + try { + const ca = await fs.readFile(config.oidc.caCertPath); + + discoveryOptions[client.customFetch] = (url: string, options: any) => { + return fetch(url, { + ...options, + ca: ca + }); + }; + } catch (err) { + console.error( + `OIDC Config Error: Failed to read CA file at ${config.oidc.caCertPath}.` + ); + } + } + + const result = await client.discovery( + issuerUrl, + config.oidc.clientId, + config.oidc.clientSecret, + undefined, + discoveryOptions + ); + + discoveryCache = { + value: result, + issuer: config.oidc.issuer, + expiresAt: Date.now() + DISCOVERY_CACHE_TTL_MS + }; + return result; +}; + +export const oidcLogin = async ( + req: http.IncomingMessage, + res: http.ServerResponse +) => { + if (config.oidc.oidcEnabled === false) { + return res.writeHead(404); + } + + try { + const as = await getOidcConfig(); + + const referer = req.headers.referer; + if (!referer) { + return res.writeHead(400, 'Referer header is missing').end(); + } + + const refererOrigin = new URL(referer).origin; + if (!config.oidc.allowedOrigins.includes(refererOrigin)) { + return res.writeHead(400, 'Invalid origin').end(); + } + + const code_verifier = client.randomPKCECodeVerifier(); + const code_challenge = + await client.calculatePKCECodeChallenge(code_verifier); + const state = client.randomState(); + const nonce = client.randomNonce(); + + const sessionData = JSON.stringify({ + code_verifier, + state, + nonce, + redirectOrigin: refererOrigin + }); + + // Set OIDC session cookie + res.setHeader( + 'Set-Cookie', + `__Host-oidc_session=${encodeURIComponent(sessionData)}; HttpOnly; Secure; SameSite=Lax; Max-Age=300; Path=/` + ); + + const baseUrl = getBaseUrl(req); + const redirectUri = `${baseUrl}/auth/callback`; + + const parameters: Record = { + redirect_uri: redirectUri, + scope: [ + 'openid', + 'profile', + 'email', + config.oidc.groupsClaim, + ...config.oidc.additionalScopes + ] + .filter(Boolean) + .join(' '), + code_challenge, + code_challenge_method: 'S256', + state, + nonce + }; + + const redirectTo = client.buildAuthorizationUrl(as, parameters); + res.writeHead(302, { Location: redirectTo.href }).end(); + } catch (error) { + console.error('OIDC Login Error:', error); + res.writeHead(500).end('Internal Server Error'); + } +}; + +export const oidcCallback = async ( + req: http.IncomingMessage, + res: http.ServerResponse +) => { + if (config.oidc.oidcEnabled === false) { + return res.writeHead(404); + } + + try { + const as = await getOidcConfig(); + + const rawCookies = req.headers.cookie || ''; + const cookieMap = Object.fromEntries( + rawCookies.split('; ').map((v) => { + const idx = v.indexOf('='); + return [v.slice(0, idx), v.slice(idx + 1)]; + }) + ); + const sessionCookie = cookieMap['__Host-oidc_session']; + + if (!sessionCookie) throw new Error('Missing OIDC session cookie'); + + let sessionData; + try { + sessionData = JSON.parse(decodeURIComponent(sessionCookie)); + } catch (e) { + throw new Error('Invalid session cookie format'); + } + const { + code_verifier, + state: expectedState, + nonce: expectedNonce, + redirectOrigin + } = sessionData; + + if ( + !redirectOrigin || + !config.oidc.allowedOrigins.includes(redirectOrigin) + ) { + throw new Error('Invalid redirect origin in session'); + } + + const baseUrl = getBaseUrl(req); + const safeUrl = (req.url || '').startsWith('/') ? req.url : '/'; + const url = new URL(safeUrl || '', baseUrl); + + const params = Object.fromEntries(url.searchParams); + + if (!params.state || !safeCompare(params.state, expectedState)) { + throw new Error('CSRF token mismatch'); + } + + const tokenResponse = await client.authorizationCodeGrant(as, url, { + pkceCodeVerifier: code_verifier, + expectedState, + expectedNonce + }); + + const idTokenClaims = tokenResponse.claims(); + if (!idTokenClaims?.sub) { + throw new Error('Invalid claims: missing sub'); + } + + let mergedClaims: Record = { ...idTokenClaims }; + const needsUserInfo = + !idTokenClaims.email || + !idTokenClaims[config.oidc.groupsClaim] || + !idTokenClaims[config.oidc.usernameClaim] || + (!!config.oidc.displayNameClaim && + !idTokenClaims[config.oidc.displayNameClaim]); + + if (needsUserInfo) { + try { + const userInfo = await client.fetchUserInfo( + as, + tokenResponse.access_token, + idTokenClaims.sub + ); + // ID token claims take precedence over UserInfo (ID token is signed) + mergedClaims = { ...userInfo, ...idTokenClaims }; + } catch (err) { + console.warn('OIDC: Could not fetch UserInfo endpoint:', err); + } + } + + if (config.oidc.requiredGroups.length > 0) { + const userGroups = ( + (mergedClaims[config.oidc.groupsClaim] as string[]) || [] + ).map((g) => g.toLowerCase()); + const hasRequired = config.oidc.requiredGroups.some((r) => + userGroups.includes(r.toLowerCase()) + ); + if (!hasRequired) { + return res.writeHead(403).end('Forbidden'); + } + } + + const identity = ((mergedClaims.email as string) || + idTokenClaims.sub) as string; + const user = await syncUserWithDatabase(identity, mergedClaims); + + const appToken = jwt.sign({ userId: user.id }, await getServerToken(), { + expiresIn: '1d' + }); + + const target = new URL(redirectOrigin); + + // Set success flag so frontend knows to initiate connection + target.searchParams.set('oidc_status', 'success'); + + // Set App Token as HttpOnly Cookie AND Clear OIDC Session in one header + const authCookie = `sharkord_token=${appToken}; Path=/; SameSite=Lax; Secure; Max-Age=86400`; + const clearSession = + '__Host-oidc_session=; Max-Age=0; Path=/; HttpOnly; Secure; SameSite=Lax'; + + res.setHeader('Set-Cookie', [authCookie, clearSession]); + res.writeHead(302, { Location: target.toString() }).end(); + } catch (error) { + console.error('OIDC Callback Error:', error); + res.writeHead(401).end('Authentication Failed'); + } +}; + +function resolveDisplayName(claims: Record): string { + if (config.oidc.displayNameClaim) { + const val = claims[config.oidc.displayNameClaim] as string | undefined; + if (val) return val; + } + return ( + (claims[config.oidc.usernameClaim] as string) ?? (claims.sub as string) + ); +} + +async function syncUserWithDatabase( + identity: string, + claims: Record +) { + const sub = claims.sub as string; + + // Look up by stable IdP subject first, fall back to identity for users + // created before oidcSub was introduced. + let user = + (await getUserByOidcSub(sub)) ?? (await getUserByIdentity(identity)); + + if (!user) { + const defaultRole = await getDefaultRole(); + if (!defaultRole) throw new Error('Default role missing'); + + const randomPassword = createHash('sha256') + .update(randomBytes(32).toString('hex')) + .digest('hex'); + + const [insertedUser] = await db + .insert(users) + .values({ + identity, + password: randomPassword, + name: resolveDisplayName(claims), + oidcSub: sub, + createdAt: Date.now(), + lastLoginAt: Date.now(), + banned: false + }) + .returning(); + + if (!insertedUser) { + throw new Error( + 'Failed to create user: Database insert returned no data.' + ); + } + + await db.insert(userRoles).values({ + roleId: defaultRole.id, + userId: insertedUser.id, + createdAt: Date.now() + }); + + publishUser(insertedUser.id, 'create'); + + user = await getUserByOidcSub(sub); + } else { + // Sync mutable fields that may have changed on the IdP. + const updates: Partial = {}; + + // Always update lastLoginAt. + updates.lastLoginAt = Date.now(); + + // Always backfill oidcSub for users created before this feature. + if (!user.oidcSub) updates.oidcSub = sub; + + // Sync identity (e.g. email) if it changed on the IdP. + if (user.identity !== identity) updates.identity = identity; + + // Sync display name: always when enforceOidcDisplayName, otherwise only + // on first login (oidcSub was null, covered by the backfill path above). + if (config.oidc.enforceOidcDisplayName) { + const idpName = resolveDisplayName(claims); + if (user.name !== idpName) updates.name = idpName; + } + + if (Object.keys(updates).length > 0) { + await db + .update(users) + .set({ ...updates, updatedAt: Date.now() }) + .where(eq(users.id, user.id)); + publishUser(user.id, 'update'); + user = await getUserByOidcSub(sub); + } + } + + if (!user) throw new Error('User synchronization failed'); + + await applyRoleMappings(user.id, claims); + return user; +} + +async function applyRoleMappings(userId: number, claims: any) { + const rolesMapping = config.oidc.rolesMapping; + if (Object.keys(rolesMapping).length === 0) return; + + const oidcGroups = ((claims[config.oidc.groupsClaim] as string[]) || []).map( + (g: string) => g.toLowerCase() + ); + const allDbRoles = await getRoles(); + const targetRoleIds: number[] = []; + + for (const [oidcRole, localRole] of Object.entries(rolesMapping)) { + if (oidcGroups.includes(oidcRole.toLowerCase())) { + const dbRole = allDbRoles.find( + (r: { id: number; name: string }) => + r.name.toLowerCase() === localRole.toLowerCase() + ); + if (dbRole) targetRoleIds.push(dbRole.id); + } + } + const uniqueTargetRoleIds = [...new Set(targetRoleIds)]; + const userCurrentRoles = await db.query.userRoles.findMany({ + where: eq(userRoles.userId, userId) + }); + + if (config.oidc.enforceOidcRoles) { + const mappedRoleNames = Object.values(rolesMapping).map((name) => + name.toLowerCase() + ); + const mappedDbRoles = allDbRoles.filter((r: { id: number; name: string }) => + mappedRoleNames.includes(r.name.toLowerCase()) + ); + const mappedDbRoleIds = mappedDbRoles.map((r) => r.id); + const userCurrentRoleIds = userCurrentRoles.map((r) => r.roleId); + + const rolesToRemove = userCurrentRoleIds.filter( + (id) => mappedDbRoleIds.includes(id) && !uniqueTargetRoleIds.includes(id) + ); + if (rolesToRemove.length > 0) { + await db + .delete(userRoles) + .where( + and( + eq(userRoles.userId, userId), + inArray(userRoles.roleId, rolesToRemove) + ) + ); + } + } else { + const oidcManagedRoleIds = userCurrentRoles + .filter((r) => r.addedBy === 'oidc') + .map((r) => r.roleId); + + const rolesToRemove = oidcManagedRoleIds.filter( + (id) => !uniqueTargetRoleIds.includes(id) + ); + if (rolesToRemove.length > 0) { + await db + .delete(userRoles) + .where( + and( + eq(userRoles.userId, userId), + inArray(userRoles.roleId, rolesToRemove), + eq(userRoles.addedBy, 'oidc') + ) + ); + } + } + + const rolesToAdd = uniqueTargetRoleIds.filter( + (id) => !userCurrentRoles.some((r) => r.roleId === id) + ); + if (rolesToAdd.length > 0) { + await db.insert(userRoles).values( + rolesToAdd.map((roleId) => ({ + userId, + roleId, + createdAt: Date.now(), + addedBy: 'oidc' as const + })) + ); + } +} diff --git a/bun.lock b/bun.lock index d87edaa7e..8a211ecfc 100644 --- a/bun.lock +++ b/bun.lock @@ -4,6 +4,9 @@ "workspaces": { "": { "name": "sharkord", + "dependencies": { + "openid-client": "^6.8.2", + }, "devDependencies": { "knip": "^5.80.0", }, @@ -88,6 +91,7 @@ "jsonwebtoken": "^9.0.2", "link-preview-js": "^3.1.0", "mediasoup": "^3.19.17", + "openid-client": "^6.8.2", "queue": "^7.0.0", "sanitize-html": "^2.17.0", "semver": "^7.7.3", @@ -1044,6 +1048,8 @@ "jiti": ["jiti@2.6.1", "", { "bin": { "jiti": "lib/jiti-cli.mjs" } }, "sha512-ekilCSN1jwRvIbgeg/57YFh8qQDNbwDb9xT/qu2DAHbFFZUicIl4ygVaAvzveMhMVr3LnpSKTNnwt8PoOfmKhQ=="], + "jose": ["jose@6.1.3", "", {}, "sha512-0TpaTfihd4QMNwrz/ob2Bp7X04yuxJkjRGi4aKmOqwhov54i6u79oCv7T+C7lo70MKH6BesI3vscD1yb/yzKXQ=="], + "js-sha256": ["js-sha256@0.11.1", "", {}, "sha512-o6WSo/LUvY2uC4j7mO50a2ms7E/EAdbP0swigLV+nzHKTTaYnaLIWJ02VdXrsJX0vGedDESQnLsOekr94ryfjg=="], "js-yaml": ["js-yaml@4.1.1", "", { "dependencies": { "argparse": "2.0.1" }, "bin": { "js-yaml": "bin/js-yaml.js" } }, "sha512-qQKT4zQxXl8lLwBtHMWwaTcGfFOZviOJet3Oy/xmGk2gZH677CJM9EvtfdSkgWcATZhj/55JZ0rmy3myCT5lsA=="], @@ -1160,10 +1166,14 @@ "nth-check": ["nth-check@2.1.1", "", { "dependencies": { "boolbase": "^1.0.0" } }, "sha512-lqjrjmaOoAnWfMmBPL+XNnynZh2+swxiX3WUE0s4yEHI6m+AwrK2UZOimIRl3X/4QctVqS8AiZjFqyOGrMXb/w=="], + "oauth4webapi": ["oauth4webapi@3.8.4", "", {}, "sha512-EKlVEgav8zH31IXxvhCqjEgQws6S9QmnmJyLXmeV5REf59g7VmqRVa5l/rhGWtUqGm2rLVTNwukn9hla5kJ2WQ=="], + "once": ["once@1.4.0", "", { "dependencies": { "wrappy": "1" } }, "sha512-lNaJgI+2Q5URQBkccEKHTQOPaXdUxnZZElQTZY0MFUAuaEqe1E+Nyvgdz/aIyNi6Z9MzO5dv1H8n58/GELp3+w=="], "one-time": ["one-time@1.0.0", "", { "dependencies": { "fn.name": "1.x.x" } }, "sha512-5DXOiRKwuSEcQ/l0kGCF6Q3jcADFv5tSmRaJck/OqkVFcOzutB134KRSfF0xDrL39MNnqxbHBbUUcjZIhTgb2g=="], + "openid-client": ["openid-client@6.8.2", "", { "dependencies": { "jose": "^6.1.3", "oauth4webapi": "^3.8.4" } }, "sha512-uOvTCndr4udZsKihJ68H9bUICrriHdUVJ6Az+4Ns6cW55rwM5h0bjVIzDz2SxgOI84LKjFyjOFvERLzdTUROGA=="], + "optionator": ["optionator@0.9.4", "", { "dependencies": { "deep-is": "^0.1.3", "fast-levenshtein": "^2.0.6", "levn": "^0.4.1", "prelude-ls": "^1.2.1", "type-check": "^0.4.0", "word-wrap": "^1.2.5" } }, "sha512-6IpQ7mKUxRcZNLIObR0hz7lxsapSSIYNZJwXPGeF0mTVqGKFIXj1DQcMoT22S3ROcLyY/rz0PWaWZ9ayWmad9g=="], "orderedmap": ["orderedmap@2.1.1", "", {}, "sha512-TvAWxi0nDe1j/rtMcWcIj94+Ffe6n7zhow33h40SKxmsmozs6dz/e+EajymfoFcHd7sxNn8yHM8839uixMOV6g=="], diff --git a/package.json b/package.json index e356a5550..7b053aee6 100644 --- a/package.json +++ b/package.json @@ -21,5 +21,8 @@ }, "devDependencies": { "knip": "^5.80.0" + }, + "dependencies": { + "openid-client": "^6.8.2" } } \ No newline at end of file diff --git a/packages/shared/src/types.ts b/packages/shared/src/types.ts index 06afecac1..e03394eb9 100644 --- a/packages/shared/src/types.ts +++ b/packages/shared/src/types.ts @@ -100,6 +100,7 @@ export type TServerInfo = Pick< > & { logo: TFile | null; version: string; + oidcEnabled: boolean; }; export type TArtifact = {