diff --git a/apps/next/middleware.ts b/apps/next/middleware.ts new file mode 100644 index 000000000..cc469942e --- /dev/null +++ b/apps/next/middleware.ts @@ -0,0 +1,30 @@ +import { NextResponse } from 'next/server' +import type { NextRequest } from 'next/server' + +// This function can be marked `async` if using `await` inside +export async function middleware(request: NextRequest) { + if (request.method === 'POST' && request.body) { + try { + const cloned = request.clone() + const requestHeaders = new Headers(request.headers) + const formData = await cloned.formData() + const userJson = formData.get('user') + if (typeof userJson === 'string') { + requestHeaders.set('x-apple-user', userJson) + } + return NextResponse.next({ + request: { + // New request headers + headers: requestHeaders, + }, + }) + } catch (e: unknown) { + console.error('error parsing oauth post', e) + } + } +} + +// See "Matching Paths" below to learn more +export const config = { + matcher: ['/oauth/apple'], +} diff --git a/apps/next/pages/oauth/[provider].tsx b/apps/next/pages/oauth/[provider].tsx index c1e184bb4..92fc295bb 100644 --- a/apps/next/pages/oauth/[provider].tsx +++ b/apps/next/pages/oauth/[provider].tsx @@ -1,13 +1,15 @@ -import { OAuthSignInScreen } from 'app/features/oauth/screen' +import { OAuthSignInScreen, OAuthSignInScreenProps } from 'app/features/oauth/screen' import Head from 'next/head' -export default function Page() { +export { getServerSideProps } from 'app/features/oauth/screen' + +export default function Page(props: OAuthSignInScreenProps) { return ( <> OAuth Sign In - + ) } diff --git a/packages/api/migrations/meta/0004_snapshot.json b/packages/api/migrations/meta/0004_snapshot.json index 2093d15a2..1da15e641 100644 --- a/packages/api/migrations/meta/0004_snapshot.json +++ b/packages/api/migrations/meta/0004_snapshot.json @@ -42,6 +42,34 @@ "notNull": false, "autoincrement": false, "default": "CURRENT_TIMESTAMP" + }, + "totp_secret": { + "name": "totp_secret", + "type": "text", + "primaryKey": false, + "notNull": false, + "autoincrement": false + }, + "totp_expires": { + "name": "totp_expires", + "type": "integer", + "primaryKey": false, + "notNull": false, + "autoincrement": false + }, + "timeout_until": { + "name": "timeout_until", + "type": "integer", + "primaryKey": false, + "notNull": false, + "autoincrement": false + }, + "timeout_seconds": { + "name": "timeout_seconds", + "type": "integer", + "primaryKey": false, + "notNull": false, + "autoincrement": false } }, "indexes": { @@ -205,79 +233,6 @@ "foreignKeys": {}, "compositePrimaryKeys": {}, "uniqueConstraints": {} - }, - "VerificationCode": { - "name": "VerificationCode", - "columns": { - "id": { - "name": "id", - "type": "text", - "primaryKey": true, - "notNull": true, - "autoincrement": false - }, - "user_id": { - "name": "user_id", - "type": "text", - "primaryKey": false, - "notNull": true, - "autoincrement": false - }, - "code": { - "name": "code", - "type": "text", - "primaryKey": false, - "notNull": true, - "autoincrement": false - }, - "expires": { - "name": "expires", - "type": "integer", - "primaryKey": false, - "notNull": true, - "autoincrement": false - }, - "timeout_until": { - "name": "timeout_until", - "type": "integer", - "primaryKey": false, - "notNull": false, - "autoincrement": false - }, - "timeout_seconds": { - "name": "timeout_seconds", - "type": "integer", - "primaryKey": false, - "notNull": true, - "autoincrement": false, - "default": 0 - } - }, - "indexes": { - "VerificationCode_user_id_unique": { - "name": "VerificationCode_user_id_unique", - "columns": ["user_id"], - "isUnique": true - }, - "idx_verificationCode_userId": { - "name": "idx_verificationCode_userId", - "columns": ["user_id"], - "isUnique": false - } - }, - "foreignKeys": { - "VerificationCode_user_id_User_id_fk": { - "name": "VerificationCode_user_id_User_id_fk", - "tableFrom": "VerificationCode", - "tableTo": "User", - "columnsFrom": ["user_id"], - "columnsTo": ["id"], - "onDelete": "no action", - "onUpdate": "no action" - } - }, - "compositePrimaryKeys": {}, - "uniqueConstraints": {} } }, "enums": {}, diff --git a/packages/api/package.json b/packages/api/package.json index 184200beb..1b7e7f493 100644 --- a/packages/api/package.json +++ b/packages/api/package.json @@ -20,15 +20,15 @@ "@cloudflare/workers-wasi": "^0.0.5", "@hono/trpc-server": "^0.1.0", "@libsql/client": "^0.3.5", - "@lucia-auth/adapter-sqlite": "3.0.0-beta.1", + "@lucia-auth/adapter-drizzle": "1.0.0-beta.2", "@trpc/server": "^10.43.2", - "arctic": "0.3.1", + "arctic": "0.10.0", "drizzle-orm": "^0.29.0", "drizzle-valibot": "beta", "hono": "^3.9.2", - "lucia": "3.0.0-beta.6", + "lucia": "3.0.0-beta.12", "miniflare": "3.20231025.1", - "oslo": "0.22.0", + "oslo": "0.24.0", "superjson": "1.13.3", "ts-pattern": "^5.0.5", "valibot": "^0.20.1" diff --git a/packages/api/src/auth/hono.ts b/packages/api/src/auth/hono.ts index 38bc0905b..03ef9b4f0 100644 --- a/packages/api/src/auth/hono.ts +++ b/packages/api/src/auth/hono.ts @@ -1,26 +1,22 @@ -import { getAuthOptions } from './shared' -import { D1Adapter } from '@lucia-auth/adapter-sqlite' -import type { Context as HonoContext, HonoRequest } from 'hono' -import { Lucia } from 'lucia' -import type { Middleware } from 'lucia' +import { getAllowedOriginHost } from '.' +import type { Context as HonoContext, Next } from 'hono' +import { Bindings } from '../worker' +import { verifyRequestOrigin } from 'oslo/request' -export const hono = (): Middleware<[HonoContext]> => { - return ({ args }) => { - const [context] = args - return { - request: context.req, - setCookie: (cookie) => { - context.res.headers.append('set-cookie', cookie.serialize()) - }, - } +export const csrfMiddleware = async (c: HonoContext<{ Bindings: Bindings }>, next: Next) => { + // CSRF middleware + if (c.req.method === 'GET') { + return next() } + const originHeader = c.req.header('origin') + const hostHeader = c.req.header('host') + const allowedOrigin = getAllowedOriginHost(c.env.APP_URL, c.req.raw) + if ( + !originHeader || + !hostHeader || + !verifyRequestOrigin(originHeader, [hostHeader, ...(allowedOrigin ? [allowedOrigin] : [])]) + ) { + return c.body(null, 403) + } + return next() } - -export const createHonoAuth = (db: D1Database, appUrl: string, request?: HonoRequest) => { - return new Lucia(new D1Adapter(db, { session: 'session', user: 'user' }), { - ...getAuthOptions(db, appUrl, request), - middleware: hono(), - }) -} - -export type HonoLucia = ReturnType diff --git a/packages/api/src/auth/index.ts b/packages/api/src/auth/index.ts new file mode 100644 index 000000000..47cbed5b3 --- /dev/null +++ b/packages/api/src/auth/index.ts @@ -0,0 +1,68 @@ +import { Adapter, DatabaseSessionAttributes, DatabaseUserAttributes, Lucia, TimeSpan } from 'lucia' +import { DrizzleSQLiteAdapter } from '@lucia-auth/adapter-drizzle' +import { SessionTable, UserTable } from '../db/schema' +import { DB } from '../db/client' + +/** + * Lucia's isValidRequestOrigin method will compare the + * origin of the request to the configured host. + * We want to allow cross-domain requests from our APP_URL so return that + * if the request origin host matches the APP_URL host. + * @link https://github.com/lucia-auth/lucia/blob/main/packages/lucia/src/utils/url.ts + */ +export const getAllowedOriginHost = (app_url: string, request?: Request) => { + if (!app_url || !request) return undefined + const requestOrigin = request.headers.get('Origin') + const requestHost = requestOrigin ? new URL(requestOrigin).host : undefined + const appHost = new URL(app_url).host + return requestHost === appHost ? appHost : undefined +} + +export const createAuth = (db: DB, appUrl: string) => { + // @ts-ignore Expect type errors because this is D1 and not SQLite... but it works + const adapter = new DrizzleSQLiteAdapter(db, SessionTable, UserTable) + // cast probably only needed until adapter-drizzle is updated + return new Lucia(adapter as Adapter, { + ...getAuthOptions(appUrl), + }) +} + +export const getAuthOptions = (appUrl: string) => { + const env = !appUrl || appUrl.startsWith('http:') ? 'DEV' : 'PROD' + return { + getUserAttributes: (data: DatabaseUserAttributes) => { + return { + email: data.email || '', + } + }, + // Optional additional session attributes to expose + // If updated, also update createSession() in packages/api/src/auth/user.ts + getSessionAttributes: (databaseSession: DatabaseSessionAttributes) => { + return {} + }, + sessionExpiresIn: new TimeSpan(365, 'd'), + sessionCookie: { + name: 'auth_session', + expires: false, + attributes: { + secure: env === 'PROD', + sameSite: 'lax' as const, + }, + }, + + // If you want more debugging, uncomment this + // experimental: { + // debugMode: true, + // }, + } +} + +declare module 'lucia' { + interface Register { + Lucia: ReturnType + } + interface DatabaseSessionAttributes {} + interface DatabaseUserAttributes { + email: string | null + } +} diff --git a/packages/api/src/auth/nextjs.ts b/packages/api/src/auth/nextjs.ts deleted file mode 100644 index 1f7efa9a9..000000000 --- a/packages/api/src/auth/nextjs.ts +++ /dev/null @@ -1,22 +0,0 @@ -// This is currently not used because the T4 API server can be on a different -// domain and the NextJS server won't be able to read the session cookie. -// If you move the hono routes to an /api prefix and proxy the API server -// through the NextJS server using the -// next.config.js rewrites feature, you could try using this to load the session -// from the database and set up a context for tRPC SSR. -// However, there are caching benefits to not utilizing SSR and auth for -// html requests and only utilizing the API server to fetch data. - -import { getAuthOptions } from './shared' -import { D1Adapter } from '@lucia-auth/adapter-sqlite' -import { Lucia } from 'lucia' -import { nextjs } from 'lucia/middleware' - -export const createNextJSAuth = (db: D1Database, appUrl: string, apiUrl: string) => { - return new Lucia(new D1Adapter(db, { session: 'session', user: 'user' }), { - ...getAuthOptions(db, appUrl), - middleware: nextjs(), - }) -} - -export type NextJSLucia = ReturnType diff --git a/packages/api/src/auth/oauth.ts b/packages/api/src/auth/oauth.ts new file mode 100644 index 000000000..20063830f --- /dev/null +++ b/packages/api/src/auth/oauth.ts @@ -0,0 +1,336 @@ +import { ApiContextProps } from '../context' +import { User } from '../db/schema' +import { + Apple, + AppleRefreshedTokens, + AppleTokens, + Discord, + DiscordTokens, + GitHub, + Google, + GoogleTokens, + generateCodeVerifier, + generateState, +} from 'arctic' +import { + AuthProvider, + AuthProviderName, + AuthTokens, + isOAuth2ProviderWithPKCE, + providers, +} from './providers' +import { isWithinExpirationDate } from 'oslo' +import { parseJWT } from 'oslo/jwt' +import { createAuthMethodId, createUser, getAuthMethod, getUserById } from './user' + +import { P, match } from 'ts-pattern' +import { getCookie } from 'hono/cookie' +import { TRPCError } from '@trpc/server' + +export interface AppleIdTokenClaims { + iss: 'https://appleid.apple.com' + sub: string + aud: string + iat: number + exp: number + email?: string + email_verified?: boolean + is_private_email?: boolean + nonce?: string + nonce_supported?: boolean + real_user_status: 0 | 1 | 2 + transfer_sub?: string +} + +export const getAuthProvider = (ctx: ApiContextProps, name: AuthProviderName): AuthProvider => { + const origin = ctx.env.APP_URL ? new URL(ctx.env.APP_URL).origin : '' + if (!providers[name]) { + if (name === 'apple') { + providers[name] = new Apple( + { + clientId: ctx.env.APPLE_CLIENT_ID, + certificate: ctx.env.APPLE_CERTIFICATE, + keyId: ctx.env.APPLE_KEY_ID, + teamId: ctx.env.APPLE_TEAM_ID, + }, + `${origin}/oauth/${name}` + ) + } + if (name === 'discord') { + providers[name] = new Discord( + ctx.env.DISCORD_CLIENT_ID, + ctx.env.DISCORD_CLIENT_SECRET, + `${origin}/oauth/${name}` + ) + } + if (name === 'github') { + providers[name] = new GitHub(ctx.env.GITHUB_CLIENT_ID, ctx.env.GITHUB_CLIENT_SECRET, { + redirectURI: `${origin}/oauth/${name}`, + }) + } + if (name === 'google') { + providers[name] = new Google( + ctx.env.GOOGLE_CLIENT_ID, + ctx.env.GOOGLE_CLIENT_SECRET, + `${origin}/oauth/${name}` + ) + } + } + const service = providers[name] + if (service === null) { + throw new Error(`Unable to configure oauth for ${name}`) + } + return service +} + +export function getAppleClaims(idToken?: string): AppleIdTokenClaims | undefined { + if (!idToken) return undefined + const payload = parseJWT(idToken)?.payload + return payload && + 'iss' in payload && + payload.iss === 'https://appleid.apple.com' && + 'sub' in payload && + 'aud' in payload && + 'iat' in payload && + 'exp' in payload + ? (payload as AppleIdTokenClaims) + : undefined +} + +export const getAuthorizationUrl = async (ctx: ApiContextProps, service: AuthProviderName) => { + const provider = getAuthProvider(ctx, service) + const secure = ctx.req?.url.startsWith('https:') ? 'Secure; ' : '' + const state = generateState() + ctx.setCookie( + `${service}_oauth_state=${state}; Path=/; ${secure}HttpOnly; SameSite=Lax; Max-Age=600` + ) + return await match({ provider, service }) + .with({ service: 'google', provider: P.instanceOf(Google) }, async ({ provider }) => { + // Google requires PKCE + const codeVerifier = generateCodeVerifier() + ctx.setCookie( + `${service}_oauth_verifier=${codeVerifier}; Path=/; ${secure}HttpOnly; SameSite=Lax; Max-Age=600` + ) + const url = await provider.createAuthorizationURL(state, codeVerifier, { + scopes: ['https://www.googleapis.com/auth/userinfo.email'], + }) + // Uncomment if you need to get and store a refresh token + // Currently, OAuth is only used for the initial sign in + // so we don't need to persist the access or refresh tokens + // url.searchParams.set('access_type', 'offline') + return url + }) + .with({ service: 'apple', provider: P.instanceOf(Apple) }, async ({ provider }) => { + const url = await provider.createAuthorizationURL(state, { scopes: ['email'] }) + url.searchParams.set('response_mode', 'form_post') + return url + }) + .with({ service: 'discord', provider: P.instanceOf(Discord) }, async ({ provider }) => { + return await provider.createAuthorizationURL(state) + }) + .with({ service: 'github', provider: P.instanceOf(GitHub) }, async ({ provider }) => { + return await provider.createAuthorizationURL(state, { scopes: ['email'] }) + }) + .otherwise(() => { + throw new Error('Unknown auth provider') + }) +} + +const checkAuthTokens = async (tokens: Partial, authProvider: AuthProvider) => { + let accessToken: string | undefined = tokens.accessToken + let accessTokenExpiresAt: Date | undefined + let refreshToken: string | null | undefined + let idTokenClaims: AppleIdTokenClaims | undefined + + if ('idToken' in tokens && authProvider instanceof Apple) { + idTokenClaims = getAppleClaims(tokens.idToken) + } + if ('refreshToken' in tokens) { + refreshToken = (tokens as Partial).refreshToken + } + if ('accessTokenExpiresAt' in tokens) { + accessTokenExpiresAt = (tokens as Partial) + .accessTokenExpiresAt + if (!accessTokenExpiresAt || !isWithinExpirationDate(accessTokenExpiresAt)) { + if (refreshToken && 'refreshAccessToken' in authProvider) { + const refreshedTokens = await authProvider.refreshAccessToken(refreshToken) + if (refreshedTokens?.accessToken) { + accessToken = refreshedTokens.accessToken + } + if (refreshedTokens?.accessTokenExpiresAt) { + accessTokenExpiresAt = refreshedTokens.accessTokenExpiresAt + } + if (refreshedTokens && 'idToken' in refreshedTokens) { + idTokenClaims = getAppleClaims((refreshedTokens as AppleRefreshedTokens).idToken) + } + } + } + if (!accessTokenExpiresAt || !isWithinExpirationDate(accessTokenExpiresAt)) { + throw new Error('Access token is expired') + } + } + return { + accessToken, + accessTokenExpiresAt, + refreshToken, + idTokenClaims, + } +} + +type GetOAuthUserResult = { + attributes: Partial + providerUserId: string + authMethodId: string +} + +// https://arctic.pages.dev/providers/apple +export async function getAppleUser({ + idTokenClaims, +}: { idTokenClaims: AppleIdTokenClaims }): Promise { + return { + attributes: { + email: idTokenClaims?.email || undefined, + }, + providerUserId: idTokenClaims.sub, + authMethodId: createAuthMethodId('apple', idTokenClaims.sub), + } +} + +// https://arctic.pages.dev/providers/discord +// https://discord.com/developers/docs/resources/user#user-object +export async function getDiscordUser({ + accessToken, +}: { accessToken: string }): Promise { + const res = await ( + await fetch('https://discord.com/api/users/@me', { + headers: { + Authorization: `Bearer ${accessToken}`, + }, + }) + ).json<{ email: string; id: string }>() + return { + attributes: { + email: res.email, + }, + authMethodId: res.id, + providerUserId: createAuthMethodId('discord', res.id), + } +} + +export async function getGitHubUser({ + accessToken, +}: { accessToken: string }): Promise { + const user = await ( + await fetch('https://api.github.com/user', { + headers: { + Authorization: `Bearer ${accessToken}`, + }, + }) + ).json<{ id: string; email: string }>() + const emails = await ( + await fetch('https://api.github.com/user/emails', { + headers: { + Authorization: `Bearer ${accessToken}`, + }, + }) + ).json<{ primary: boolean; email: string }[]>() + const primaryEmail = emails.find((e: { primary: boolean }) => e.primary)?.email + return { + attributes: { + email: user.email || primaryEmail, + }, + providerUserId: user.id, + authMethodId: createAuthMethodId('github', user.id), + } +} + +export async function getGoogleUser({ + accessToken, +}: { accessToken: string }): Promise { + const res = await ( + await fetch('https://openidconnect.googleapis.com/v1/userinfo', { + headers: { + Authorization: `Bearer ${accessToken}`, + }, + }) + ).json<{ email: string; sub: string }>() + return { + attributes: { + email: res.email || undefined, + }, + providerUserId: res.sub, + authMethodId: createAuthMethodId('google', res.sub), + } +} + +export const getUserFromAuthProvider = async <_AuthTokens extends AuthTokens>( + ctx: ApiContextProps, + service: AuthProviderName, + authProvider: AuthProvider, + tokens: Partial<_AuthTokens>, + userData: Partial = {} +): Promise => { + const { accessToken, idTokenClaims } = await checkAuthTokens(tokens, authProvider) + const { authMethodId, providerUserId, attributes } = await match({ + authProvider, + accessToken, + idTokenClaims, + }) + .with({ authProvider: P.instanceOf(Apple), idTokenClaims: P.nullish }, async () => { + throw new Error('Apple idToken is required') + }) + .with({ authProvider: P.instanceOf(Apple), idTokenClaims: { sub: P.nullish } }, async () => { + throw new Error('Missing subject in Apple idToken') + }) + .with({ authProvider: P.instanceOf(Apple), idTokenClaims: P.not(P.nullish) }, getAppleUser) + .with({ accessToken: P.nullish }, async () => { + throw new Error('Access token is required') + }) + .with({ authProvider: P.instanceOf(Discord), accessToken: P.not(P.nullish) }, getDiscordUser) + .with({ authProvider: P.instanceOf(GitHub), accessToken: P.not(P.nullish) }, getGitHubUser) + .with({ authProvider: P.instanceOf(Google), accessToken: P.not(P.nullish) }, getGoogleUser) + .otherwise(() => { + throw new Error('Unknown auth provider') + }) + + if (!authMethodId || !providerUserId) { + throw new Error('Unknown auth provider') + } + const existingAuthMethod = await getAuthMethod(ctx, authMethodId) + if (existingAuthMethod) { + const existingUser = await getUserById(ctx, existingAuthMethod.userId) + if (existingUser) { + return existingUser + } + } + + // TODO if the email is used by another user without oauth, you will wind up with two + // users with the same email. If you'd rather merge the accounts, you could create + // a unique index on the email column in the user table and then check for that first + // and either return an error or automatically link the accounts. + // See https://lucia-auth.com/guidebook/oauth-account-linking/ for an example on + // how to automatically link the accounts. + const user = await createUser(ctx, service, providerUserId, null, { + ...attributes, + ...userData, + }) + return user +} + +export const getOAuthUser = async ( + service: AuthProviderName, + ctx: ApiContextProps, + { code, userData }: { code: string; userData?: Partial } +): Promise => { + const authService = getAuthProvider(ctx, service) + if (isOAuth2ProviderWithPKCE(authService)) { + if (!ctx.c) throw new TRPCError({ code: 'INTERNAL_SERVER_ERROR', message: 'Missing context' }) + const codeVerifier = getCookie(ctx.c, `${service}_oauth_verifier`) + if (!codeVerifier) + throw new TRPCError({ code: 'INTERNAL_SERVER_ERROR', message: 'Missing code verifier' }) + const validateResult = await authService.validateAuthorizationCode(code, codeVerifier) + return getUserFromAuthProvider(ctx, service, authService, validateResult, userData) + } + const validateResult = await authService.validateAuthorizationCode(code) + return getUserFromAuthProvider(ctx, service, authService, validateResult) +} diff --git a/packages/api/src/auth/providers.ts b/packages/api/src/auth/providers.ts index 856d944df..131e699a6 100644 --- a/packages/api/src/auth/providers.ts +++ b/packages/api/src/auth/providers.ts @@ -6,9 +6,9 @@ import type { DiscordTokens, GitHub, GitHubTokens, - Google, GoogleTokens, } from 'arctic' +import { Google } from 'arctic' export const providers: { apple: Apple | null @@ -24,7 +24,15 @@ export const providers: { export type AuthProviderName = keyof typeof providers export type AuthProvider = Apple | Discord | GitHub | Google export type AuthTokens = AppleTokens | DiscordTokens | GitHubTokens | GoogleTokens +export type AuthProviderWithPKCE = Google +export type AuthProviderWithoutPKCE = Exclude export const isAuthProviderName = (name: string): name is AuthProviderName => { return name in providers } + +export const isOAuth2ProviderWithPKCE = ( + provider: AuthProvider +): provider is AuthProviderWithPKCE => { + return provider instanceof Google +} diff --git a/packages/api/src/auth/shared.ts b/packages/api/src/auth/shared.ts deleted file mode 100644 index 3b0a4f233..000000000 --- a/packages/api/src/auth/shared.ts +++ /dev/null @@ -1,256 +0,0 @@ -import { ApiContextProps } from '../context' -import { User } from '../db/schema' -import { - Apple, - AppleIdTokenClaims, - AppleTokens, - Discord, - DiscordTokens, - GitHub, - GitHubTokens, - Google, - GoogleTokens, -} from 'arctic' -import type { HonoRequest } from 'hono' -import { DatabaseSessionAttributes, DatabaseUserAttributes, TimeSpan } from 'lucia' -import { AuthProvider, AuthProviderName, AuthTokens, providers } from './providers' -import { isWithinExpirationDate } from 'oslo' -import { createAuthMethodId, createUser, getAuthMethod, getUserById } from './user' -import type { HonoLucia } from './hono' - -export const getAuthProvider = (ctx: ApiContextProps, name: AuthProviderName): AuthProvider => { - const origin = ctx.env.APP_URL ? new URL(ctx.env.APP_URL).origin : '' - if (!providers[name]) { - if (name === 'apple') { - providers[name] = new Apple( - { - clientId: ctx.env.APPLE_CLIENT_ID, - certificate: ctx.env.APPLE_CERTIFICATE, - keyId: ctx.env.APPLE_KEY_ID, - teamId: ctx.env.APPLE_TEAM_ID, - }, - `${origin}/oauth/${name}`, - { - responseMode: 'form_post', - scope: ['email'], - } - ) - } - if (name === 'discord') { - providers[name] = new Discord( - ctx.env.DISCORD_CLIENT_ID, - ctx.env.DISCORD_CLIENT_SECRET, - `${origin}/oauth/${name}`, - { - scope: ['email'], - } - ) - } - if (name === 'github') { - providers[name] = new GitHub(ctx.env.GITHUB_CLIENT_ID, ctx.env.GITHUB_CLIENT_SECRET, { - redirectURI: `${origin}/oauth/${name}`, - scope: ['email'], - }) - } - if (name === 'google') { - providers[name] = new Google( - ctx.env.GOOGLE_CLIENT_ID, - ctx.env.GOOGLE_CLIENT_SECRET, - `${origin}/oauth/${name}`, - { - scope: ['https://www.googleapis.com/auth/userinfo.email'], - } - ) - } - } - const service = providers[name] - if (service === null) { - throw new Error(`Unable to configure oauth for ${name}`) - } - return service -} - -/** - * Lucia's isValidRequestOrigin method will compare the - * origin of the request to the configured host. - * We want to allow cross-domain requests from our APP_URL so return that - * if the request origin host matches the APP_URL host. - * @link https://github.com/lucia-auth/lucia/blob/main/packages/lucia/src/utils/url.ts - */ -export const getAllowedOriginHost = (app_url: string, request?: HonoRequest) => { - if (!app_url || !request) return undefined - const requestOrigin = request.header('Origin') - const requestHost = requestOrigin ? new URL(requestOrigin).host : undefined - const appHost = new URL(app_url).host - return requestHost === appHost ? appHost : undefined -} - -export const getAuthOptions = (db: D1Database, appUrl: string, request?: HonoRequest) => { - const env = !appUrl || appUrl.startsWith('http:') ? 'DEV' : 'PROD' - const allowedHost = getAllowedOriginHost(appUrl, request) - return { - // Lucia's d1 adapter makes queries for sessions and users directly from the database - // Does drizzle provide a constructor we could use here to automatically perform the transforms? - getUserAttributes: (data: DatabaseUserAttributes) => { - if ('attributes' in data) { - // biome-ignore lint/style/noParameterAssign: this will be fixed in the next lucia v3 beta - data = data.attributes as DatabaseUserAttributes - } - return { - email: data.email || '', - } - }, - // Optional additional session attributes to expose - // If updated, also update createSession() in packages/api/src/auth/user.ts - getSessionAttributes: (databaseSession: DatabaseSessionAttributes) => { - return {} - }, - sessionExpiresIn: new TimeSpan(365, 'd'), - sessionCookie: { - name: 'auth_session', - expires: false, - attributes: { - secure: env === 'PROD', - sameSite: 'lax' as const, - }, - }, - - // https://lucia-auth.com/basics/configuration/#csrfprotection - csrfProtection: { - allowedSubDomains: '*', - allowedDomains: allowedHost ? [allowedHost] : undefined, - }, - - // If you want more debugging, uncomment this - // experimental: { - // debugMode: true, - // }, - } -} - -export const getUserFromAuthProvider = async <_AuthTokens extends AuthTokens>( - ctx: ApiContextProps, - service: AuthProviderName, - authProvider: AuthProvider, - tokens: Partial<_AuthTokens> -): Promise => { - // ts-pattern would make this a little cleaner - let accessToken: string | undefined = tokens.accessToken - let accessTokenExpiresAt: Date | undefined - let refreshToken: string | null | undefined - let idTokenClaims: AppleIdTokenClaims | undefined - const isApple = authProvider instanceof Apple - - if ('refreshToken' in tokens) { - refreshToken = (tokens as Partial).refreshToken - } - if ('accessTokenExpiresAt' in tokens) { - accessTokenExpiresAt = (tokens as Partial) - .accessTokenExpiresAt - if (!accessTokenExpiresAt || !isWithinExpirationDate(accessTokenExpiresAt)) { - if (refreshToken && 'refreshAccessToken' in authProvider) { - const refreshedTokens = await authProvider.refreshAccessToken(refreshToken) - if (refreshedTokens?.accessToken) { - accessToken = refreshedTokens.accessToken - } - if (refreshedTokens?.accessTokenExpiresAt) { - accessTokenExpiresAt = refreshedTokens.accessTokenExpiresAt - } - if (refreshedTokens as Partial) { - idTokenClaims = (refreshedTokens as Partial).idTokenClaims - } - } - } - if (!accessTokenExpiresAt || !isWithinExpirationDate(accessTokenExpiresAt)) { - throw new Error('Access token is expired') - } - } - if (isApple) { - idTokenClaims = (tokens as Partial).idTokenClaims - if (!idTokenClaims) { - throw new Error('Apple idToken is required') - } - } - - let attributes: Partial = {} - let providerUserId: string | undefined - let authMethodId: string | undefined - if (isApple) { - attributes = { - email: idTokenClaims?.email || undefined, - } - providerUserId = idTokenClaims?.sub - if (!providerUserId) { - throw new Error('Missing subject in Apple idToken') - } - authMethodId = createAuthMethodId('apple', providerUserId) - } else { - if (!accessToken) { - throw new Error('Access token is required') - } - if (authProvider instanceof Discord) { - const discordUser = await authProvider.getUser(accessToken) - providerUserId = discordUser.id - authMethodId = createAuthMethodId('discord', providerUserId) - attributes = { - email: discordUser.email || undefined, - } - } - if (authProvider instanceof GitHub) { - const githubUser = await (authProvider as GitHub).getUser(accessToken) - providerUserId = githubUser.id.toString() - authMethodId = createAuthMethodId('github', providerUserId) - attributes = { - email: githubUser.email || undefined, - } - } - if (authProvider instanceof Google) { - const googleUser = await authProvider.getUser(accessToken) - providerUserId = googleUser.sub - authMethodId = createAuthMethodId('google', googleUser.sub) - attributes = { - email: googleUser.email, - } - } - } - if (!authMethodId || !providerUserId) { - throw new Error('Unknown auth provider') - } - const existingAuthMethod = await getAuthMethod(ctx, authMethodId) - if (existingAuthMethod) { - const existingUser = await getUserById(ctx, existingAuthMethod.userId) - if (existingUser) { - return existingUser - } - } - - // TODO if the email is used by another user without oauth, you will wind up with two - // users with the same email. If you'd rather merge the accounts, you could create - // a unique index on the email column in the user table and then check for that first - // and either return an error or automatically link the accounts. - // See https://lucia-auth.com/guidebook/oauth-account-linking/ for an example on - // how to automatically link the accounts. - const user = await createUser(ctx, service, providerUserId, null, attributes) - return user -} - -export const getOAuthUser = async ( - service: AuthProviderName, - ctx: ApiContextProps, - { code }: { code: string } -): Promise => { - const authService = getAuthProvider(ctx, service) - const validateResult = await authService.validateAuthorizationCode(code) - return getUserFromAuthProvider(ctx, service, authService, validateResult) -} - -declare module 'lucia' { - interface Register { - Lucia: HonoLucia - DatabaseUserAttributes: { - email: string | null - } - // biome-ignore lint/complexity/noBannedTypes: Need to define this even if empty - DatabaseSessionAttributes: {} - } -} diff --git a/packages/api/src/auth/user.ts b/packages/api/src/auth/user.ts index 6e98a7485..fb2cdc785 100644 --- a/packages/api/src/auth/user.ts +++ b/packages/api/src/auth/user.ts @@ -10,7 +10,7 @@ import { isWithinExpirationDate } from 'oslo' import { createCode, createTotpSecret, verifyCode } from '../utils/crypto' import { AuthProviderName } from './providers' import { OAuth2RequestError } from 'arctic' -import { getOAuthUser } from './shared' +import { getOAuthUser } from './oauth' export const createAuthMethodId = (providerId: string, providerUserId: string) => { if (providerId.includes(':')) { @@ -315,17 +315,21 @@ export const signInWithOAuthCode = async ( code: string, state?: string, storedState?: string, - redirectTo?: string + redirectTo?: string, + userData?: Partial ) => { if (!storedState || !state || storedState !== state || typeof code !== 'string') { throw new TRPCError({ message: 'Invalid state', code: 'BAD_REQUEST' }) } try { - const user = await getOAuthUser(service, ctx, { code }) + const user = await getOAuthUser(service, ctx, { code, userData }) const session = await createSession(ctx.auth, user.id) - ctx.authRequest?.setSessionCookie(session.id) + if (ctx.setCookie) { + ctx.setCookie(ctx.auth.createSessionCookie(session.id).serialize()) + } return { redirectTo: `${redirectTo ? redirectTo : ctx.env.APP_URL}#token=${session.id}` } } catch (e) { + console.error(e) if (e instanceof OAuth2RequestError) { throw new TRPCError({ message: 'Invalid code', code: 'BAD_REQUEST' }) } @@ -382,7 +386,7 @@ export const cleanup = async (context: ApiContextProps, userId?: string) => { .where( and( userId ? eq(SessionTable.userId, userId) : undefined, - lt(SessionTable.expiresAt, new Date()) + lt(SessionTable.expiresAt, new Date().getTime()) ) ) } diff --git a/packages/api/src/context.ts b/packages/api/src/context.ts index 806bd77ac..b58a54215 100644 --- a/packages/api/src/context.ts +++ b/packages/api/src/context.ts @@ -1,4 +1,3 @@ -import { createHonoAuth, HonoLucia } from './auth/hono' import { type Session } from './auth/user' import { createDb } from './db/client' import type { DB } from './db/client' @@ -6,14 +5,15 @@ import type { User } from './db/schema' import { Bindings } from './worker' import type { inferAsyncReturnType } from '@trpc/server' import type { Context as HonoContext, HonoRequest } from 'hono' -import type { AuthRequest, Lucia } from 'lucia' +import type { Lucia } from 'lucia' import { verifyToken } from './utils/crypto' +import { createAuth } from './auth' +import { getCookie } from 'hono/cookie' export interface ApiContextProps { session?: Session user?: User - auth: HonoLucia - authRequest?: AuthRequest + auth: Lucia req?: HonoRequest c?: HonoContext setCookie: (value: string) => void @@ -62,37 +62,48 @@ export const createContext = async ( // const user = await getUser() - const auth = createHonoAuth(env.DB, env.APP_URL, context.req) + const auth = createAuth(db, env.APP_URL) async function getSession() { let user: User | undefined let session: Session | undefined - let authRequest: AuthRequest | undefined + const res = { + user, + session, + } + + if (!context.req) return res + + const cookieSessionId = getCookie(context, auth.sessionCookieName) + const bearerSessionId = + !cookieSessionId && + context.req.header('x-enable-tokens') && + context.req.header('authorization')?.split(' ')[1] - if (context.req) { - authRequest = auth.handleRequest(context) - const authResult = await authRequest.validate() - if (authResult.session) { - session = authResult.session - user = authResult.user || undefined + if (!cookieSessionId && !bearerSessionId) return res + + const authResult = await auth.validateSession(cookieSessionId || bearerSessionId || '') + if (cookieSessionId) { + if (authResult.session?.fresh) { + context.header('Set-Cookie', auth.createSessionCookie(authResult.session.id).serialize(), { + append: true, + }) } - // console.log('cookie session and auth request', session, authRequest) - if (!session && context.req.header('x-enable-tokens')) { - const tokenAuthResult = await authRequest.validateBearerToken() - if (tokenAuthResult.session) { - session = tokenAuthResult.session - user = tokenAuthResult.user || undefined - } + if (!session) { + context.header('Set-Cookie', auth.createBlankSessionCookie().serialize(), { + append: true, + }) } } - return { session, user, authRequest } + res.session = authResult.session || undefined + res.user = authResult.user || undefined + return res } - const { session, user, authRequest } = await getSession() + const { session, user } = await getSession() return { db, auth, - authRequest, req: context.req, c: context, session, diff --git a/packages/api/src/db/schema.ts b/packages/api/src/db/schema.ts index 3d570300a..cf0745d6c 100644 --- a/packages/api/src/db/schema.ts +++ b/packages/api/src/db/schema.ts @@ -5,7 +5,7 @@ import { HASH_METHODS } from '../utils/password/hash-methods' // User export const UserTable = sqliteTable('User', { - id: text('id').primaryKey(), + id: text('id').notNull().primaryKey(), email: text('email').notNull(), }) export const userRelations = relations(UserTable, ({ many }) => ({ @@ -58,11 +58,12 @@ export const AuthMethodSchema = createInsertSchema(AuthMethodTable) export const SessionTable = sqliteTable( 'Session', { - id: text('id').primaryKey(), + id: text('id').notNull().primaryKey(), userId: text('user_id') .notNull() .references(() => UserTable.id), - expiresAt: integer('expires_at', { mode: 'timestamp' }).notNull(), + // DrizzleSQLiteAdapter currently expects this to be an integer and not use { mode: 'timestamp' } + expiresAt: integer('expires_at').notNull(), }, (t) => ({ userIdIdx: index('idx_session_userId').on(t.userId), diff --git a/packages/api/src/routes/user.ts b/packages/api/src/routes/user.ts index 465fd76ff..6e6744937 100644 --- a/packages/api/src/routes/user.ts +++ b/packages/api/src/routes/user.ts @@ -1,5 +1,5 @@ import { desc, eq } from 'drizzle-orm' -import { UserTable, type User, SessionTable, AuthMethodTable } from '../db/schema' +import { UserTable, SessionTable, AuthMethodTable } from '../db/schema' import { router, protectedProcedure, publicProcedure, valibotParser } from '../trpc' import { Input } from 'valibot' import { ApiContextProps } from '../context' @@ -21,11 +21,15 @@ import { import { TRPCError } from '@trpc/server' import { GetByIdSchema, GetSessionsSchema } from '../schema/shared' import { CreateUserSchema, SignInSchema } from '../schema/user' -import { AppleIdTokenClaims, generateCodeVerifier, generateState } from 'arctic' -import { getAuthProvider, getUserFromAuthProvider } from '../auth/shared' -import { verifyToken, isJWTExpired, sha256 } from '../utils/crypto' +import { + AppleIdTokenClaims, + getAuthProvider, + getAuthorizationUrl, + getUserFromAuthProvider, +} from '../auth/oauth' +import { isJWTExpired, sha256 } from '../utils/crypto' import { getCookie } from 'hono/cookie' -import { JWT, parseJWT } from '../utils/jwt' +import { parseJWT } from 'oslo/jwt' import { P, match } from 'ts-pattern' import { AuthProviderName } from '../auth/providers' @@ -144,7 +148,7 @@ const signInWithAppleIdTokenHandler = // throw new Error('Unable to fetch Apple public key') // } // return key - // })) as AppleIdTokenClaims & { nonce?: string; nonce_supported?: boolean } + // })) as unknown as AppleIdTokenClaims // Since we can't verify the JWT, check that it hasn't expired const parsedJWT = parseJWT(input.idToken) if (parsedJWT && isJWTExpired(parsedJWT)) { @@ -153,10 +157,7 @@ const signInWithAppleIdTokenHandler = message: 'The Apple ID token has expired.', }) } - const payload = parsedJWT?.payload as AppleIdTokenClaims & { - nonce?: string - nonce_supported?: boolean - } + const payload = parsedJWT?.payload as AppleIdTokenClaims if (!payload) { console.error('Apple ID token could not be verified.', { payload, @@ -185,7 +186,9 @@ const signInWithAppleIdTokenHandler = idTokenClaims: payload, }) const session = await createSession(ctx.auth, user.id) - ctx.authRequest?.setSessionCookie(session.id) + if (ctx.setCookie) { + ctx.setCookie(ctx.auth.createSessionCookie(session.id).serialize()) + } return { session } } @@ -201,6 +204,7 @@ const signInWithOAuthCodeHandler = } const storedState = getCookie(ctx.c, `${input.provider}_oauth_state`) + const storedVerifier = getCookie(ctx.c, `${input.provider}_oauth_verifier`) const storedRedirect = getCookie(ctx.c, `${input.provider}_oauth_redirect`) return await signInWithOAuthCode( ctx, @@ -208,28 +212,26 @@ const signInWithOAuthCodeHandler = input.code, input.state, storedState, - storedRedirect + storedRedirect, + input.appleUser ) } const authorizationUrlHandler = (ctx: ApiContextProps) => async (input: Input & { provider: AuthProviderName }) => { - // Get the authorization URL and store the state in a cookie - const provider = getAuthProvider(ctx, input.provider) - const state = generateState() - // TODO Mentioned in docs but seem to be used yet... circle back with future arctic release - // const codeVerifier = generateCodeVerifier() - const url = await provider.createAuthorizationURL(state) - ctx.setCookie(`${input.provider}_oauth_state=${state}; Path=/; HttpOnly; SameSite=Lax`) + const url = await getAuthorizationUrl(ctx, input.provider) if (!validateRedirectDomain(ctx, input.redirectTo)) { throw new TRPCError({ code: 'FORBIDDEN', message: `The redirect URL is invalid: ${input.redirectTo}`, }) } + const secure = ctx.req?.url.startsWith('https:') ? 'Secure; ' : '' ctx.setCookie( - `${input.provider}_oauth_redirect=${input.redirectTo || ''}; Path=/; HttpOnly; SameSite=Lax` + `${input.provider}_oauth_redirect=${ + input.redirectTo || '' + }; Path=/; ${secure}HttpOnly; SameSite=Lax` ) return { redirectTo: url.toString() } } @@ -245,7 +247,9 @@ const signInWithEmailCodeHandler = console.log('calling update passing and invalidate sessions') await ctx.auth.invalidateUserSessions(res.session?.userId) const session = await createSession(ctx.auth, res.session?.userId) - ctx.authRequest?.setSessionCookie(session.id) + if (ctx.setCookie) { + ctx.setCookie(ctx.auth.createSessionCookie(session.id).serialize()) + } res.session = session } return res @@ -338,7 +342,9 @@ export const userRouter = router({ email: input.email, }) const session = await createSession(ctx.auth, user.id) - ctx.authRequest?.setSessionCookie(session.id) + if (ctx.setCookie) { + ctx.setCookie(ctx.auth.createSessionCookie(session.id).serialize()) + } return { session } }), }) diff --git a/packages/api/src/schema/user.ts b/packages/api/src/schema/user.ts index f2e2749ad..cda67eaf3 100644 --- a/packages/api/src/schema/user.ts +++ b/packages/api/src/schema/user.ts @@ -44,6 +44,7 @@ export const SignInSchema = object({ idToken: optional(string()), refreshToken: optional(string()), nonce: optional(string()), + appleUser: optional(object({ email: optionalEmail })), }) export type SignInInput = Input diff --git a/packages/api/src/utils/crypto.ts b/packages/api/src/utils/crypto.ts index 38c130259..fcd1e588c 100644 --- a/packages/api/src/utils/crypto.ts +++ b/packages/api/src/utils/crypto.ts @@ -1,7 +1,7 @@ import { TimeSpan, isWithinExpirationDate } from 'oslo' import { TOTPController } from 'oslo/otp' import { decodeBase64, encodeBase64, encodeHex } from 'oslo/encoding' -import { type JWT, parseJWT, validateJWT } from './jwt' +import { type JWT, parseJWT, validateJWT } from 'oslo/jwt' import { match, P } from 'ts-pattern' import { HMAC, sha256 as sha256AB } from 'oslo/crypto' diff --git a/packages/api/src/utils/jwt.ts b/packages/api/src/utils/jwt.ts deleted file mode 100644 index b1166f4ea..000000000 --- a/packages/api/src/utils/jwt.ts +++ /dev/null @@ -1,309 +0,0 @@ -// Temp copied from https://raw.githubusercontent.com/rmarscher/oslo/optional-typ-header/src/jwt/index.ts -// https://github.com/pilcrowOnPaper/oslo/pull/9 -import { ECDSA, HMAC, RSASSAPKCS1v1_5, RSASSAPSS } from 'oslo/crypto' -import { decodeBase64url, encodeBase64url } from 'oslo/encoding' -import { isWithinExpirationDate } from 'oslo' -import type { TimeSpan } from 'oslo' - -export type JWTAlgorithm = - | 'HS256' - | 'HS384' - | 'HS512' - | 'RS256' - | 'RS384' - | 'RS512' - | 'ES256' - | 'ES384' - | 'ES512' - | 'PS256' - | 'PS384' - | 'PS512' - -export async function createJWT( - algorithm: JWTAlgorithm, - key: ArrayBuffer, - payloadClaims: Record, - options?: { - headers?: Record - expiresIn?: TimeSpan - issuer?: string - subject?: string - audience?: string - notBefore?: Date - includeIssuedTimestamp?: boolean - jwtId?: string - } -): Promise { - const header: JWTHeader = { - alg: algorithm, - typ: 'JWT', - ...options?.headers, - } - const payload: JWTPayload = { - ...payloadClaims, - } - if (options?.audience !== undefined) { - payload.aud = options.audience - } - if (options?.subject !== undefined) { - payload.sub = options.subject - } - if (options?.issuer !== undefined) { - payload.iss = options.issuer - } - if (options?.jwtId !== undefined) { - payload.jti = options.jwtId - } - if (options?.expiresIn !== undefined) { - payload.exp = Math.floor(Date.now() / 1000) + options.expiresIn.seconds() - } - if (options?.notBefore !== undefined) { - payload.nbf = Math.floor(options.notBefore.getTime() / 1000) - } - if (options?.includeIssuedTimestamp === true) { - payload.iat = Math.floor(Date.now() / 1000) - } - const textEncoder = new TextEncoder() - const headerPart = encodeBase64url( - textEncoder.encode(JSON.stringify(header)).buffer as ArrayBuffer - ) - const payloadPart = encodeBase64url( - textEncoder.encode(JSON.stringify(payload)).buffer as ArrayBuffer - ) - const data = textEncoder.encode([headerPart, payloadPart].join('.')).buffer as ArrayBuffer - const signature = await getAlgorithm(algorithm).sign(key, data) - const signaturePart = encodeBase64url(signature) - const value = [headerPart, payloadPart, signaturePart].join('.') - return value -} - -export async function validateJWT( - algorithm: JWTAlgorithm, - key: ArrayBuffer, - jwt: string -): Promise { - const parsedJWT = parseJWT(jwt) - if (!parsedJWT) { - throw new Error('Invalid JWT') - } - if (parsedJWT.algorithm !== algorithm) { - throw new Error('Invalid algorithm') - } - if (parsedJWT.expiresAt && !isWithinExpirationDate(parsedJWT.expiresAt)) { - throw new Error('Expired JWT') - } - if (parsedJWT.notBefore && Date.now() < parsedJWT.notBefore.getTime()) { - throw new Error('Inactive JWT') - } - const signature = decodeBase64url(parsedJWT.parts[2]) - const data = new TextEncoder().encode(`${parsedJWT.parts[0]}.${parsedJWT.parts[1]}`) - .buffer as ArrayBuffer - const validSignature = await getAlgorithm(parsedJWT.algorithm).verify( - key, - signature.buffer as ArrayBuffer, - data - ) - if (!validSignature) { - throw new Error('Invalid signature') - } - return parsedJWT -} - -function getJWTParts(jwt: string): [header: string, payload: string, signature: string] | null { - const jwtParts = jwt.split('.') - if (jwtParts.length !== 3) { - return null - } - return jwtParts as [string, string, string] -} - -export function parseJWT(jwt: string): JWT | null { - const jwtParts = getJWTParts(jwt) - if (!jwtParts) { - return null - } - const textDecoder = new TextDecoder() - const rawHeader = decodeBase64url(jwtParts[0]) - const rawPayload = decodeBase64url(jwtParts[1]) - const header: unknown = JSON.parse(textDecoder.decode(rawHeader)) - if (typeof header !== 'object' || header === null) { - return null - } - if (!('alg' in header) || !isValidAlgorithm(header.alg)) { - return null - } - if ('typ' in header && header.typ !== 'JWT') { - return null - } - const payload: unknown = JSON.parse(textDecoder.decode(rawPayload)) - if (typeof payload !== 'object' || payload === null) { - return null - } - const properties: JWTProperties = { - algorithm: header.alg, - expiresAt: null, - subject: null, - issuedAt: null, - issuer: null, - jwtId: null, - audience: null, - notBefore: null, - } - if ('exp' in payload) { - if (typeof payload.exp !== 'number') { - return null - } - properties.expiresAt = new Date(payload.exp * 1000) - } - if ('iss' in payload) { - if (typeof payload.iss !== 'string') { - return null - } - properties.issuer = payload.iss - } - if ('sub' in payload) { - if (typeof payload.sub !== 'string') { - return null - } - properties.subject = payload.sub - } - if ('aud' in payload) { - if (typeof payload.aud !== 'string') { - return null - } - properties.audience = payload.aud - } - if ('nbf' in payload) { - if (typeof payload.nbf !== 'number') { - return null - } - properties.notBefore = new Date(payload.nbf * 1000) - } - if ('iat' in payload) { - if (typeof payload.iat !== 'number') { - return null - } - properties.issuedAt = new Date(payload.iat * 1000) - } - if ('jti' in payload) { - if (typeof payload.jti !== 'string') { - return null - } - properties.jwtId = payload.jti - } - return { - value: jwt, - header: { - ...header, - typ: 'JWT', - alg: header.alg, - }, - payload: { - ...payload, - }, - parts: jwtParts, - ...properties, - } -} - -interface JWTProperties { - algorithm: JWTAlgorithm - expiresAt: Date | null - issuer: string | null - subject: string | null - audience: string | null - notBefore: Date | null - issuedAt: Date | null - jwtId: string | null -} - -export interface JWT extends JWTProperties { - value: string - header: object - payload: object - parts: [header: string, payload: string, signature: string] -} - -function getAlgorithm(algorithm: JWTAlgorithm): ECDSA | HMAC | RSASSAPKCS1v1_5 | RSASSAPSS { - if (algorithm === 'ES256' || algorithm === 'ES384' || algorithm === 'ES512') { - return new ECDSA(ecdsaDictionary[algorithm].hash, ecdsaDictionary[algorithm].curve) - } - if (algorithm === 'HS256' || algorithm === 'HS384' || algorithm === 'HS512') { - return new HMAC(hmacDictionary[algorithm]) - } - if (algorithm === 'RS256' || algorithm === 'RS384' || algorithm === 'RS512') { - return new RSASSAPKCS1v1_5(rsassapkcs1v1_5Dictionary[algorithm]) - } - if (algorithm === 'PS256' || algorithm === 'PS384' || algorithm === 'PS512') { - return new RSASSAPSS(rsassapssDictionary[algorithm]) - } - throw new TypeError('Invalid algorithm') -} - -function isValidAlgorithm(maybeValidAlgorithm: unknown): maybeValidAlgorithm is JWTAlgorithm { - if (typeof maybeValidAlgorithm !== 'string') return false - return [ - 'HS256', - 'HS384', - 'HS512', - 'RS256', - 'RS384', - 'RS512', - 'ES256', - 'ES384', - 'ES512', - 'PS256', - 'PS384', - 'PS512', - ].includes(maybeValidAlgorithm) -} - -interface JWTHeader { - typ: 'JWT' - alg: JWTAlgorithm - [header: string]: any -} - -interface JWTPayload { - exp?: number - iss?: string - aud?: string - jti?: string - nbf?: number - sub?: string - iat?: number - [claim: string]: any -} - -const ecdsaDictionary = { - ES256: { - hash: 'SHA-256', - curve: 'P-256', - }, - ES384: { - hash: 'SHA-384', - curve: 'P-384', - }, - ES512: { - hash: 'SHA-512', - curve: 'P-521', - }, -} as const - -const hmacDictionary = { - HS256: 'SHA-256', - HS384: 'SHA-384', - HS512: 'SHA-512', -} as const - -const rsassapkcs1v1_5Dictionary = { - RS256: 'SHA-256', - RS384: 'SHA-384', - RS512: 'SHA-512', -} as const - -const rsassapssDictionary = { - PS256: 'SHA-256', - PS384: 'SHA-384', - PS512: 'SHA-512', -} as const diff --git a/packages/api/src/utils/password.ts b/packages/api/src/utils/password.ts index e8c3bf79f..008cd5af7 100644 --- a/packages/api/src/utils/password.ts +++ b/packages/api/src/utils/password.ts @@ -1,6 +1,6 @@ import { argon2Hash, argon2Verify } from './password/argon2' import { HashMethod } from './password/hash-methods' -import { verifyLegacyLuciaPasswordHash } from 'lucia' +import { LegacyScrypt } from 'lucia' export async function hashPassword(password: string) { const hashedPassword = await argon2Hash(password) @@ -19,7 +19,7 @@ export async function verifyPassword( case 'scrypt': case null: case undefined: - return await verifyLegacyLuciaPasswordHash(password, hashedPassword) + return await new LegacyScrypt().verify(hashedPassword, password) case 'argon2': return await argon2Verify(password, hashedPassword) default: diff --git a/packages/api/src/worker.ts b/packages/api/src/worker.ts index 8575d44cc..2f0d71031 100644 --- a/packages/api/src/worker.ts +++ b/packages/api/src/worker.ts @@ -3,6 +3,7 @@ import { appRouter } from '@t4/api/src/router' import { cors } from 'hono/cors' import { createContext } from '@t4/api/src/context' import { trpcServer } from '@hono/trpc-server' +import { csrfMiddleware } from './auth/hono' export type Bindings = Env & { JWT_VERIFICATION_KEY: string @@ -41,6 +42,8 @@ const corsHandler = async (c: Context<{ Bindings: Bindings }>, next: Next) => { })(c, next) } +app.use('*', csrfMiddleware) + // Setup CORS for the frontend app.use('/trpc/*', corsHandler) diff --git a/packages/app/features/oauth/screen.tsx b/packages/app/features/oauth/screen.tsx index a7e58e110..98c972ff2 100644 --- a/packages/app/features/oauth/screen.tsx +++ b/packages/app/features/oauth/screen.tsx @@ -1,15 +1,45 @@ +import type { GetServerSideProps } from 'next' import type { AuthProviderName } from '@t4/api/src/auth/providers' import { Paragraph, isServer } from '@t4/ui' -import { useSignIn } from 'app/utils/auth' +import { type SignInWithOAuth, useSignIn } from 'app/utils/auth' import { useCallback, useEffect, useRef, useState } from 'react' import { createParam } from 'solito' import { P, match } from 'ts-pattern' -type Params = { provider: AuthProviderName; redirectTo: string; code?: string; state?: string } +type Params = { + provider: AuthProviderName + redirectTo: string + code?: string + state?: string +} const { useParam } = createParam() -export const OAuthSignInScreen = (): React.ReactNode => { +// Apple will POST form data to the redirect URI +export const getServerSideProps = (async (context) => { + // Fetch data from external API + let appleUser = null + if (context.req.method !== 'POST') { + return { props: { appleUser } } + } + try { + const userJSON = context.req.headers['x-apple-user'] as string | undefined + if (typeof userJSON === 'string') { + appleUser = JSON.parse(userJSON) + } + } catch (e: unknown) { + console.error(e) + } + // Pass data to the page via props + return { props: { appleUser } } +}) satisfies GetServerSideProps + +export interface OAuthSignInScreenProps { + appleUser?: { email?: string | null } | null +} + + +export const OAuthSignInScreen = ({ appleUser }: OAuthSignInScreenProps): React.ReactNode => { const sent = useRef(false) const { signIn } = useSignIn() const [provider] = useParam('provider') @@ -19,7 +49,7 @@ export const OAuthSignInScreen = (): React.ReactNode => { const [error, setError] = useState(undefined) const sendApiRequestOnLoad = useCallback( - async (params: Params) => { + async (params: SignInWithOAuth) => { if (sent.current) return sent.current = true try { @@ -46,8 +76,18 @@ export const OAuthSignInScreen = (): React.ReactNode => { useEffect(() => { if (sent.current) return if (!provider) return - sendApiRequestOnLoad({ provider, redirectTo: redirectTo || '', state, code }) - }, [provider, redirectTo, state, code, sendApiRequestOnLoad]) + sendApiRequestOnLoad({ + provider, + redirectTo: redirectTo || '', + state, + code, + // undefined vs null is a result of passing via JSON with getServerSideProps + // Maybe there's a superjson plugin or another way to handle it. + appleUser: appleUser ? { + email: appleUser.email || undefined, + } : undefined, + }) + }, [provider, redirectTo, state, code, sendApiRequestOnLoad, appleUser]) const message = match([error, code]) .with([undefined, P._], () => 'Signing in...') diff --git a/packages/app/utils/auth/index.ts b/packages/app/utils/auth/index.ts index ba7163d83..0ec14e608 100644 --- a/packages/app/utils/auth/index.ts +++ b/packages/app/utils/auth/index.ts @@ -129,6 +129,7 @@ export type SignInWithOAuth = { redirectTo?: string code?: string state?: string + appleUser?: { email?: string } } export type SignInProps =