diff --git a/apps/next/middleware.ts b/apps/next/middleware.ts
new file mode 100644
index 000000000..cc469942e
--- /dev/null
+++ b/apps/next/middleware.ts
@@ -0,0 +1,30 @@
+import { NextResponse } from 'next/server'
+import type { NextRequest } from 'next/server'
+
+// This function can be marked `async` if using `await` inside
+export async function middleware(request: NextRequest) {
+ if (request.method === 'POST' && request.body) {
+ try {
+ const cloned = request.clone()
+ const requestHeaders = new Headers(request.headers)
+ const formData = await cloned.formData()
+ const userJson = formData.get('user')
+ if (typeof userJson === 'string') {
+ requestHeaders.set('x-apple-user', userJson)
+ }
+ return NextResponse.next({
+ request: {
+ // New request headers
+ headers: requestHeaders,
+ },
+ })
+ } catch (e: unknown) {
+ console.error('error parsing oauth post', e)
+ }
+ }
+}
+
+// See "Matching Paths" below to learn more
+export const config = {
+ matcher: ['/oauth/apple'],
+}
diff --git a/apps/next/pages/oauth/[provider].tsx b/apps/next/pages/oauth/[provider].tsx
index c1e184bb4..92fc295bb 100644
--- a/apps/next/pages/oauth/[provider].tsx
+++ b/apps/next/pages/oauth/[provider].tsx
@@ -1,13 +1,15 @@
-import { OAuthSignInScreen } from 'app/features/oauth/screen'
+import { OAuthSignInScreen, OAuthSignInScreenProps } from 'app/features/oauth/screen'
import Head from 'next/head'
-export default function Page() {
+export { getServerSideProps } from 'app/features/oauth/screen'
+
+export default function Page(props: OAuthSignInScreenProps) {
return (
<>
OAuth Sign In
-
+
>
)
}
diff --git a/packages/api/migrations/meta/0004_snapshot.json b/packages/api/migrations/meta/0004_snapshot.json
index 2093d15a2..1da15e641 100644
--- a/packages/api/migrations/meta/0004_snapshot.json
+++ b/packages/api/migrations/meta/0004_snapshot.json
@@ -42,6 +42,34 @@
"notNull": false,
"autoincrement": false,
"default": "CURRENT_TIMESTAMP"
+ },
+ "totp_secret": {
+ "name": "totp_secret",
+ "type": "text",
+ "primaryKey": false,
+ "notNull": false,
+ "autoincrement": false
+ },
+ "totp_expires": {
+ "name": "totp_expires",
+ "type": "integer",
+ "primaryKey": false,
+ "notNull": false,
+ "autoincrement": false
+ },
+ "timeout_until": {
+ "name": "timeout_until",
+ "type": "integer",
+ "primaryKey": false,
+ "notNull": false,
+ "autoincrement": false
+ },
+ "timeout_seconds": {
+ "name": "timeout_seconds",
+ "type": "integer",
+ "primaryKey": false,
+ "notNull": false,
+ "autoincrement": false
}
},
"indexes": {
@@ -205,79 +233,6 @@
"foreignKeys": {},
"compositePrimaryKeys": {},
"uniqueConstraints": {}
- },
- "VerificationCode": {
- "name": "VerificationCode",
- "columns": {
- "id": {
- "name": "id",
- "type": "text",
- "primaryKey": true,
- "notNull": true,
- "autoincrement": false
- },
- "user_id": {
- "name": "user_id",
- "type": "text",
- "primaryKey": false,
- "notNull": true,
- "autoincrement": false
- },
- "code": {
- "name": "code",
- "type": "text",
- "primaryKey": false,
- "notNull": true,
- "autoincrement": false
- },
- "expires": {
- "name": "expires",
- "type": "integer",
- "primaryKey": false,
- "notNull": true,
- "autoincrement": false
- },
- "timeout_until": {
- "name": "timeout_until",
- "type": "integer",
- "primaryKey": false,
- "notNull": false,
- "autoincrement": false
- },
- "timeout_seconds": {
- "name": "timeout_seconds",
- "type": "integer",
- "primaryKey": false,
- "notNull": true,
- "autoincrement": false,
- "default": 0
- }
- },
- "indexes": {
- "VerificationCode_user_id_unique": {
- "name": "VerificationCode_user_id_unique",
- "columns": ["user_id"],
- "isUnique": true
- },
- "idx_verificationCode_userId": {
- "name": "idx_verificationCode_userId",
- "columns": ["user_id"],
- "isUnique": false
- }
- },
- "foreignKeys": {
- "VerificationCode_user_id_User_id_fk": {
- "name": "VerificationCode_user_id_User_id_fk",
- "tableFrom": "VerificationCode",
- "tableTo": "User",
- "columnsFrom": ["user_id"],
- "columnsTo": ["id"],
- "onDelete": "no action",
- "onUpdate": "no action"
- }
- },
- "compositePrimaryKeys": {},
- "uniqueConstraints": {}
}
},
"enums": {},
diff --git a/packages/api/package.json b/packages/api/package.json
index 184200beb..1b7e7f493 100644
--- a/packages/api/package.json
+++ b/packages/api/package.json
@@ -20,15 +20,15 @@
"@cloudflare/workers-wasi": "^0.0.5",
"@hono/trpc-server": "^0.1.0",
"@libsql/client": "^0.3.5",
- "@lucia-auth/adapter-sqlite": "3.0.0-beta.1",
+ "@lucia-auth/adapter-drizzle": "1.0.0-beta.2",
"@trpc/server": "^10.43.2",
- "arctic": "0.3.1",
+ "arctic": "0.10.0",
"drizzle-orm": "^0.29.0",
"drizzle-valibot": "beta",
"hono": "^3.9.2",
- "lucia": "3.0.0-beta.6",
+ "lucia": "3.0.0-beta.12",
"miniflare": "3.20231025.1",
- "oslo": "0.22.0",
+ "oslo": "0.24.0",
"superjson": "1.13.3",
"ts-pattern": "^5.0.5",
"valibot": "^0.20.1"
diff --git a/packages/api/src/auth/hono.ts b/packages/api/src/auth/hono.ts
index 38bc0905b..03ef9b4f0 100644
--- a/packages/api/src/auth/hono.ts
+++ b/packages/api/src/auth/hono.ts
@@ -1,26 +1,22 @@
-import { getAuthOptions } from './shared'
-import { D1Adapter } from '@lucia-auth/adapter-sqlite'
-import type { Context as HonoContext, HonoRequest } from 'hono'
-import { Lucia } from 'lucia'
-import type { Middleware } from 'lucia'
+import { getAllowedOriginHost } from '.'
+import type { Context as HonoContext, Next } from 'hono'
+import { Bindings } from '../worker'
+import { verifyRequestOrigin } from 'oslo/request'
-export const hono = (): Middleware<[HonoContext]> => {
- return ({ args }) => {
- const [context] = args
- return {
- request: context.req,
- setCookie: (cookie) => {
- context.res.headers.append('set-cookie', cookie.serialize())
- },
- }
+export const csrfMiddleware = async (c: HonoContext<{ Bindings: Bindings }>, next: Next) => {
+ // CSRF middleware
+ if (c.req.method === 'GET') {
+ return next()
}
+ const originHeader = c.req.header('origin')
+ const hostHeader = c.req.header('host')
+ const allowedOrigin = getAllowedOriginHost(c.env.APP_URL, c.req.raw)
+ if (
+ !originHeader ||
+ !hostHeader ||
+ !verifyRequestOrigin(originHeader, [hostHeader, ...(allowedOrigin ? [allowedOrigin] : [])])
+ ) {
+ return c.body(null, 403)
+ }
+ return next()
}
-
-export const createHonoAuth = (db: D1Database, appUrl: string, request?: HonoRequest) => {
- return new Lucia(new D1Adapter(db, { session: 'session', user: 'user' }), {
- ...getAuthOptions(db, appUrl, request),
- middleware: hono(),
- })
-}
-
-export type HonoLucia = ReturnType
diff --git a/packages/api/src/auth/index.ts b/packages/api/src/auth/index.ts
new file mode 100644
index 000000000..47cbed5b3
--- /dev/null
+++ b/packages/api/src/auth/index.ts
@@ -0,0 +1,68 @@
+import { Adapter, DatabaseSessionAttributes, DatabaseUserAttributes, Lucia, TimeSpan } from 'lucia'
+import { DrizzleSQLiteAdapter } from '@lucia-auth/adapter-drizzle'
+import { SessionTable, UserTable } from '../db/schema'
+import { DB } from '../db/client'
+
+/**
+ * Lucia's isValidRequestOrigin method will compare the
+ * origin of the request to the configured host.
+ * We want to allow cross-domain requests from our APP_URL so return that
+ * if the request origin host matches the APP_URL host.
+ * @link https://github.com/lucia-auth/lucia/blob/main/packages/lucia/src/utils/url.ts
+ */
+export const getAllowedOriginHost = (app_url: string, request?: Request) => {
+ if (!app_url || !request) return undefined
+ const requestOrigin = request.headers.get('Origin')
+ const requestHost = requestOrigin ? new URL(requestOrigin).host : undefined
+ const appHost = new URL(app_url).host
+ return requestHost === appHost ? appHost : undefined
+}
+
+export const createAuth = (db: DB, appUrl: string) => {
+ // @ts-ignore Expect type errors because this is D1 and not SQLite... but it works
+ const adapter = new DrizzleSQLiteAdapter(db, SessionTable, UserTable)
+ // cast probably only needed until adapter-drizzle is updated
+ return new Lucia(adapter as Adapter, {
+ ...getAuthOptions(appUrl),
+ })
+}
+
+export const getAuthOptions = (appUrl: string) => {
+ const env = !appUrl || appUrl.startsWith('http:') ? 'DEV' : 'PROD'
+ return {
+ getUserAttributes: (data: DatabaseUserAttributes) => {
+ return {
+ email: data.email || '',
+ }
+ },
+ // Optional additional session attributes to expose
+ // If updated, also update createSession() in packages/api/src/auth/user.ts
+ getSessionAttributes: (databaseSession: DatabaseSessionAttributes) => {
+ return {}
+ },
+ sessionExpiresIn: new TimeSpan(365, 'd'),
+ sessionCookie: {
+ name: 'auth_session',
+ expires: false,
+ attributes: {
+ secure: env === 'PROD',
+ sameSite: 'lax' as const,
+ },
+ },
+
+ // If you want more debugging, uncomment this
+ // experimental: {
+ // debugMode: true,
+ // },
+ }
+}
+
+declare module 'lucia' {
+ interface Register {
+ Lucia: ReturnType
+ }
+ interface DatabaseSessionAttributes {}
+ interface DatabaseUserAttributes {
+ email: string | null
+ }
+}
diff --git a/packages/api/src/auth/nextjs.ts b/packages/api/src/auth/nextjs.ts
deleted file mode 100644
index 1f7efa9a9..000000000
--- a/packages/api/src/auth/nextjs.ts
+++ /dev/null
@@ -1,22 +0,0 @@
-// This is currently not used because the T4 API server can be on a different
-// domain and the NextJS server won't be able to read the session cookie.
-// If you move the hono routes to an /api prefix and proxy the API server
-// through the NextJS server using the
-// next.config.js rewrites feature, you could try using this to load the session
-// from the database and set up a context for tRPC SSR.
-// However, there are caching benefits to not utilizing SSR and auth for
-// html requests and only utilizing the API server to fetch data.
-
-import { getAuthOptions } from './shared'
-import { D1Adapter } from '@lucia-auth/adapter-sqlite'
-import { Lucia } from 'lucia'
-import { nextjs } from 'lucia/middleware'
-
-export const createNextJSAuth = (db: D1Database, appUrl: string, apiUrl: string) => {
- return new Lucia(new D1Adapter(db, { session: 'session', user: 'user' }), {
- ...getAuthOptions(db, appUrl),
- middleware: nextjs(),
- })
-}
-
-export type NextJSLucia = ReturnType
diff --git a/packages/api/src/auth/oauth.ts b/packages/api/src/auth/oauth.ts
new file mode 100644
index 000000000..20063830f
--- /dev/null
+++ b/packages/api/src/auth/oauth.ts
@@ -0,0 +1,336 @@
+import { ApiContextProps } from '../context'
+import { User } from '../db/schema'
+import {
+ Apple,
+ AppleRefreshedTokens,
+ AppleTokens,
+ Discord,
+ DiscordTokens,
+ GitHub,
+ Google,
+ GoogleTokens,
+ generateCodeVerifier,
+ generateState,
+} from 'arctic'
+import {
+ AuthProvider,
+ AuthProviderName,
+ AuthTokens,
+ isOAuth2ProviderWithPKCE,
+ providers,
+} from './providers'
+import { isWithinExpirationDate } from 'oslo'
+import { parseJWT } from 'oslo/jwt'
+import { createAuthMethodId, createUser, getAuthMethod, getUserById } from './user'
+
+import { P, match } from 'ts-pattern'
+import { getCookie } from 'hono/cookie'
+import { TRPCError } from '@trpc/server'
+
+export interface AppleIdTokenClaims {
+ iss: 'https://appleid.apple.com'
+ sub: string
+ aud: string
+ iat: number
+ exp: number
+ email?: string
+ email_verified?: boolean
+ is_private_email?: boolean
+ nonce?: string
+ nonce_supported?: boolean
+ real_user_status: 0 | 1 | 2
+ transfer_sub?: string
+}
+
+export const getAuthProvider = (ctx: ApiContextProps, name: AuthProviderName): AuthProvider => {
+ const origin = ctx.env.APP_URL ? new URL(ctx.env.APP_URL).origin : ''
+ if (!providers[name]) {
+ if (name === 'apple') {
+ providers[name] = new Apple(
+ {
+ clientId: ctx.env.APPLE_CLIENT_ID,
+ certificate: ctx.env.APPLE_CERTIFICATE,
+ keyId: ctx.env.APPLE_KEY_ID,
+ teamId: ctx.env.APPLE_TEAM_ID,
+ },
+ `${origin}/oauth/${name}`
+ )
+ }
+ if (name === 'discord') {
+ providers[name] = new Discord(
+ ctx.env.DISCORD_CLIENT_ID,
+ ctx.env.DISCORD_CLIENT_SECRET,
+ `${origin}/oauth/${name}`
+ )
+ }
+ if (name === 'github') {
+ providers[name] = new GitHub(ctx.env.GITHUB_CLIENT_ID, ctx.env.GITHUB_CLIENT_SECRET, {
+ redirectURI: `${origin}/oauth/${name}`,
+ })
+ }
+ if (name === 'google') {
+ providers[name] = new Google(
+ ctx.env.GOOGLE_CLIENT_ID,
+ ctx.env.GOOGLE_CLIENT_SECRET,
+ `${origin}/oauth/${name}`
+ )
+ }
+ }
+ const service = providers[name]
+ if (service === null) {
+ throw new Error(`Unable to configure oauth for ${name}`)
+ }
+ return service
+}
+
+export function getAppleClaims(idToken?: string): AppleIdTokenClaims | undefined {
+ if (!idToken) return undefined
+ const payload = parseJWT(idToken)?.payload
+ return payload &&
+ 'iss' in payload &&
+ payload.iss === 'https://appleid.apple.com' &&
+ 'sub' in payload &&
+ 'aud' in payload &&
+ 'iat' in payload &&
+ 'exp' in payload
+ ? (payload as AppleIdTokenClaims)
+ : undefined
+}
+
+export const getAuthorizationUrl = async (ctx: ApiContextProps, service: AuthProviderName) => {
+ const provider = getAuthProvider(ctx, service)
+ const secure = ctx.req?.url.startsWith('https:') ? 'Secure; ' : ''
+ const state = generateState()
+ ctx.setCookie(
+ `${service}_oauth_state=${state}; Path=/; ${secure}HttpOnly; SameSite=Lax; Max-Age=600`
+ )
+ return await match({ provider, service })
+ .with({ service: 'google', provider: P.instanceOf(Google) }, async ({ provider }) => {
+ // Google requires PKCE
+ const codeVerifier = generateCodeVerifier()
+ ctx.setCookie(
+ `${service}_oauth_verifier=${codeVerifier}; Path=/; ${secure}HttpOnly; SameSite=Lax; Max-Age=600`
+ )
+ const url = await provider.createAuthorizationURL(state, codeVerifier, {
+ scopes: ['https://www.googleapis.com/auth/userinfo.email'],
+ })
+ // Uncomment if you need to get and store a refresh token
+ // Currently, OAuth is only used for the initial sign in
+ // so we don't need to persist the access or refresh tokens
+ // url.searchParams.set('access_type', 'offline')
+ return url
+ })
+ .with({ service: 'apple', provider: P.instanceOf(Apple) }, async ({ provider }) => {
+ const url = await provider.createAuthorizationURL(state, { scopes: ['email'] })
+ url.searchParams.set('response_mode', 'form_post')
+ return url
+ })
+ .with({ service: 'discord', provider: P.instanceOf(Discord) }, async ({ provider }) => {
+ return await provider.createAuthorizationURL(state)
+ })
+ .with({ service: 'github', provider: P.instanceOf(GitHub) }, async ({ provider }) => {
+ return await provider.createAuthorizationURL(state, { scopes: ['email'] })
+ })
+ .otherwise(() => {
+ throw new Error('Unknown auth provider')
+ })
+}
+
+const checkAuthTokens = async (tokens: Partial, authProvider: AuthProvider) => {
+ let accessToken: string | undefined = tokens.accessToken
+ let accessTokenExpiresAt: Date | undefined
+ let refreshToken: string | null | undefined
+ let idTokenClaims: AppleIdTokenClaims | undefined
+
+ if ('idToken' in tokens && authProvider instanceof Apple) {
+ idTokenClaims = getAppleClaims(tokens.idToken)
+ }
+ if ('refreshToken' in tokens) {
+ refreshToken = (tokens as Partial).refreshToken
+ }
+ if ('accessTokenExpiresAt' in tokens) {
+ accessTokenExpiresAt = (tokens as Partial)
+ .accessTokenExpiresAt
+ if (!accessTokenExpiresAt || !isWithinExpirationDate(accessTokenExpiresAt)) {
+ if (refreshToken && 'refreshAccessToken' in authProvider) {
+ const refreshedTokens = await authProvider.refreshAccessToken(refreshToken)
+ if (refreshedTokens?.accessToken) {
+ accessToken = refreshedTokens.accessToken
+ }
+ if (refreshedTokens?.accessTokenExpiresAt) {
+ accessTokenExpiresAt = refreshedTokens.accessTokenExpiresAt
+ }
+ if (refreshedTokens && 'idToken' in refreshedTokens) {
+ idTokenClaims = getAppleClaims((refreshedTokens as AppleRefreshedTokens).idToken)
+ }
+ }
+ }
+ if (!accessTokenExpiresAt || !isWithinExpirationDate(accessTokenExpiresAt)) {
+ throw new Error('Access token is expired')
+ }
+ }
+ return {
+ accessToken,
+ accessTokenExpiresAt,
+ refreshToken,
+ idTokenClaims,
+ }
+}
+
+type GetOAuthUserResult = {
+ attributes: Partial
+ providerUserId: string
+ authMethodId: string
+}
+
+// https://arctic.pages.dev/providers/apple
+export async function getAppleUser({
+ idTokenClaims,
+}: { idTokenClaims: AppleIdTokenClaims }): Promise {
+ return {
+ attributes: {
+ email: idTokenClaims?.email || undefined,
+ },
+ providerUserId: idTokenClaims.sub,
+ authMethodId: createAuthMethodId('apple', idTokenClaims.sub),
+ }
+}
+
+// https://arctic.pages.dev/providers/discord
+// https://discord.com/developers/docs/resources/user#user-object
+export async function getDiscordUser({
+ accessToken,
+}: { accessToken: string }): Promise {
+ const res = await (
+ await fetch('https://discord.com/api/users/@me', {
+ headers: {
+ Authorization: `Bearer ${accessToken}`,
+ },
+ })
+ ).json<{ email: string; id: string }>()
+ return {
+ attributes: {
+ email: res.email,
+ },
+ authMethodId: res.id,
+ providerUserId: createAuthMethodId('discord', res.id),
+ }
+}
+
+export async function getGitHubUser({
+ accessToken,
+}: { accessToken: string }): Promise {
+ const user = await (
+ await fetch('https://api.github.com/user', {
+ headers: {
+ Authorization: `Bearer ${accessToken}`,
+ },
+ })
+ ).json<{ id: string; email: string }>()
+ const emails = await (
+ await fetch('https://api.github.com/user/emails', {
+ headers: {
+ Authorization: `Bearer ${accessToken}`,
+ },
+ })
+ ).json<{ primary: boolean; email: string }[]>()
+ const primaryEmail = emails.find((e: { primary: boolean }) => e.primary)?.email
+ return {
+ attributes: {
+ email: user.email || primaryEmail,
+ },
+ providerUserId: user.id,
+ authMethodId: createAuthMethodId('github', user.id),
+ }
+}
+
+export async function getGoogleUser({
+ accessToken,
+}: { accessToken: string }): Promise {
+ const res = await (
+ await fetch('https://openidconnect.googleapis.com/v1/userinfo', {
+ headers: {
+ Authorization: `Bearer ${accessToken}`,
+ },
+ })
+ ).json<{ email: string; sub: string }>()
+ return {
+ attributes: {
+ email: res.email || undefined,
+ },
+ providerUserId: res.sub,
+ authMethodId: createAuthMethodId('google', res.sub),
+ }
+}
+
+export const getUserFromAuthProvider = async <_AuthTokens extends AuthTokens>(
+ ctx: ApiContextProps,
+ service: AuthProviderName,
+ authProvider: AuthProvider,
+ tokens: Partial<_AuthTokens>,
+ userData: Partial = {}
+): Promise => {
+ const { accessToken, idTokenClaims } = await checkAuthTokens(tokens, authProvider)
+ const { authMethodId, providerUserId, attributes } = await match({
+ authProvider,
+ accessToken,
+ idTokenClaims,
+ })
+ .with({ authProvider: P.instanceOf(Apple), idTokenClaims: P.nullish }, async () => {
+ throw new Error('Apple idToken is required')
+ })
+ .with({ authProvider: P.instanceOf(Apple), idTokenClaims: { sub: P.nullish } }, async () => {
+ throw new Error('Missing subject in Apple idToken')
+ })
+ .with({ authProvider: P.instanceOf(Apple), idTokenClaims: P.not(P.nullish) }, getAppleUser)
+ .with({ accessToken: P.nullish }, async () => {
+ throw new Error('Access token is required')
+ })
+ .with({ authProvider: P.instanceOf(Discord), accessToken: P.not(P.nullish) }, getDiscordUser)
+ .with({ authProvider: P.instanceOf(GitHub), accessToken: P.not(P.nullish) }, getGitHubUser)
+ .with({ authProvider: P.instanceOf(Google), accessToken: P.not(P.nullish) }, getGoogleUser)
+ .otherwise(() => {
+ throw new Error('Unknown auth provider')
+ })
+
+ if (!authMethodId || !providerUserId) {
+ throw new Error('Unknown auth provider')
+ }
+ const existingAuthMethod = await getAuthMethod(ctx, authMethodId)
+ if (existingAuthMethod) {
+ const existingUser = await getUserById(ctx, existingAuthMethod.userId)
+ if (existingUser) {
+ return existingUser
+ }
+ }
+
+ // TODO if the email is used by another user without oauth, you will wind up with two
+ // users with the same email. If you'd rather merge the accounts, you could create
+ // a unique index on the email column in the user table and then check for that first
+ // and either return an error or automatically link the accounts.
+ // See https://lucia-auth.com/guidebook/oauth-account-linking/ for an example on
+ // how to automatically link the accounts.
+ const user = await createUser(ctx, service, providerUserId, null, {
+ ...attributes,
+ ...userData,
+ })
+ return user
+}
+
+export const getOAuthUser = async (
+ service: AuthProviderName,
+ ctx: ApiContextProps,
+ { code, userData }: { code: string; userData?: Partial }
+): Promise => {
+ const authService = getAuthProvider(ctx, service)
+ if (isOAuth2ProviderWithPKCE(authService)) {
+ if (!ctx.c) throw new TRPCError({ code: 'INTERNAL_SERVER_ERROR', message: 'Missing context' })
+ const codeVerifier = getCookie(ctx.c, `${service}_oauth_verifier`)
+ if (!codeVerifier)
+ throw new TRPCError({ code: 'INTERNAL_SERVER_ERROR', message: 'Missing code verifier' })
+ const validateResult = await authService.validateAuthorizationCode(code, codeVerifier)
+ return getUserFromAuthProvider(ctx, service, authService, validateResult, userData)
+ }
+ const validateResult = await authService.validateAuthorizationCode(code)
+ return getUserFromAuthProvider(ctx, service, authService, validateResult)
+}
diff --git a/packages/api/src/auth/providers.ts b/packages/api/src/auth/providers.ts
index 856d944df..131e699a6 100644
--- a/packages/api/src/auth/providers.ts
+++ b/packages/api/src/auth/providers.ts
@@ -6,9 +6,9 @@ import type {
DiscordTokens,
GitHub,
GitHubTokens,
- Google,
GoogleTokens,
} from 'arctic'
+import { Google } from 'arctic'
export const providers: {
apple: Apple | null
@@ -24,7 +24,15 @@ export const providers: {
export type AuthProviderName = keyof typeof providers
export type AuthProvider = Apple | Discord | GitHub | Google
export type AuthTokens = AppleTokens | DiscordTokens | GitHubTokens | GoogleTokens
+export type AuthProviderWithPKCE = Google
+export type AuthProviderWithoutPKCE = Exclude
export const isAuthProviderName = (name: string): name is AuthProviderName => {
return name in providers
}
+
+export const isOAuth2ProviderWithPKCE = (
+ provider: AuthProvider
+): provider is AuthProviderWithPKCE => {
+ return provider instanceof Google
+}
diff --git a/packages/api/src/auth/shared.ts b/packages/api/src/auth/shared.ts
deleted file mode 100644
index 3b0a4f233..000000000
--- a/packages/api/src/auth/shared.ts
+++ /dev/null
@@ -1,256 +0,0 @@
-import { ApiContextProps } from '../context'
-import { User } from '../db/schema'
-import {
- Apple,
- AppleIdTokenClaims,
- AppleTokens,
- Discord,
- DiscordTokens,
- GitHub,
- GitHubTokens,
- Google,
- GoogleTokens,
-} from 'arctic'
-import type { HonoRequest } from 'hono'
-import { DatabaseSessionAttributes, DatabaseUserAttributes, TimeSpan } from 'lucia'
-import { AuthProvider, AuthProviderName, AuthTokens, providers } from './providers'
-import { isWithinExpirationDate } from 'oslo'
-import { createAuthMethodId, createUser, getAuthMethod, getUserById } from './user'
-import type { HonoLucia } from './hono'
-
-export const getAuthProvider = (ctx: ApiContextProps, name: AuthProviderName): AuthProvider => {
- const origin = ctx.env.APP_URL ? new URL(ctx.env.APP_URL).origin : ''
- if (!providers[name]) {
- if (name === 'apple') {
- providers[name] = new Apple(
- {
- clientId: ctx.env.APPLE_CLIENT_ID,
- certificate: ctx.env.APPLE_CERTIFICATE,
- keyId: ctx.env.APPLE_KEY_ID,
- teamId: ctx.env.APPLE_TEAM_ID,
- },
- `${origin}/oauth/${name}`,
- {
- responseMode: 'form_post',
- scope: ['email'],
- }
- )
- }
- if (name === 'discord') {
- providers[name] = new Discord(
- ctx.env.DISCORD_CLIENT_ID,
- ctx.env.DISCORD_CLIENT_SECRET,
- `${origin}/oauth/${name}`,
- {
- scope: ['email'],
- }
- )
- }
- if (name === 'github') {
- providers[name] = new GitHub(ctx.env.GITHUB_CLIENT_ID, ctx.env.GITHUB_CLIENT_SECRET, {
- redirectURI: `${origin}/oauth/${name}`,
- scope: ['email'],
- })
- }
- if (name === 'google') {
- providers[name] = new Google(
- ctx.env.GOOGLE_CLIENT_ID,
- ctx.env.GOOGLE_CLIENT_SECRET,
- `${origin}/oauth/${name}`,
- {
- scope: ['https://www.googleapis.com/auth/userinfo.email'],
- }
- )
- }
- }
- const service = providers[name]
- if (service === null) {
- throw new Error(`Unable to configure oauth for ${name}`)
- }
- return service
-}
-
-/**
- * Lucia's isValidRequestOrigin method will compare the
- * origin of the request to the configured host.
- * We want to allow cross-domain requests from our APP_URL so return that
- * if the request origin host matches the APP_URL host.
- * @link https://github.com/lucia-auth/lucia/blob/main/packages/lucia/src/utils/url.ts
- */
-export const getAllowedOriginHost = (app_url: string, request?: HonoRequest) => {
- if (!app_url || !request) return undefined
- const requestOrigin = request.header('Origin')
- const requestHost = requestOrigin ? new URL(requestOrigin).host : undefined
- const appHost = new URL(app_url).host
- return requestHost === appHost ? appHost : undefined
-}
-
-export const getAuthOptions = (db: D1Database, appUrl: string, request?: HonoRequest) => {
- const env = !appUrl || appUrl.startsWith('http:') ? 'DEV' : 'PROD'
- const allowedHost = getAllowedOriginHost(appUrl, request)
- return {
- // Lucia's d1 adapter makes queries for sessions and users directly from the database
- // Does drizzle provide a constructor we could use here to automatically perform the transforms?
- getUserAttributes: (data: DatabaseUserAttributes) => {
- if ('attributes' in data) {
- // biome-ignore lint/style/noParameterAssign: this will be fixed in the next lucia v3 beta
- data = data.attributes as DatabaseUserAttributes
- }
- return {
- email: data.email || '',
- }
- },
- // Optional additional session attributes to expose
- // If updated, also update createSession() in packages/api/src/auth/user.ts
- getSessionAttributes: (databaseSession: DatabaseSessionAttributes) => {
- return {}
- },
- sessionExpiresIn: new TimeSpan(365, 'd'),
- sessionCookie: {
- name: 'auth_session',
- expires: false,
- attributes: {
- secure: env === 'PROD',
- sameSite: 'lax' as const,
- },
- },
-
- // https://lucia-auth.com/basics/configuration/#csrfprotection
- csrfProtection: {
- allowedSubDomains: '*',
- allowedDomains: allowedHost ? [allowedHost] : undefined,
- },
-
- // If you want more debugging, uncomment this
- // experimental: {
- // debugMode: true,
- // },
- }
-}
-
-export const getUserFromAuthProvider = async <_AuthTokens extends AuthTokens>(
- ctx: ApiContextProps,
- service: AuthProviderName,
- authProvider: AuthProvider,
- tokens: Partial<_AuthTokens>
-): Promise => {
- // ts-pattern would make this a little cleaner
- let accessToken: string | undefined = tokens.accessToken
- let accessTokenExpiresAt: Date | undefined
- let refreshToken: string | null | undefined
- let idTokenClaims: AppleIdTokenClaims | undefined
- const isApple = authProvider instanceof Apple
-
- if ('refreshToken' in tokens) {
- refreshToken = (tokens as Partial).refreshToken
- }
- if ('accessTokenExpiresAt' in tokens) {
- accessTokenExpiresAt = (tokens as Partial)
- .accessTokenExpiresAt
- if (!accessTokenExpiresAt || !isWithinExpirationDate(accessTokenExpiresAt)) {
- if (refreshToken && 'refreshAccessToken' in authProvider) {
- const refreshedTokens = await authProvider.refreshAccessToken(refreshToken)
- if (refreshedTokens?.accessToken) {
- accessToken = refreshedTokens.accessToken
- }
- if (refreshedTokens?.accessTokenExpiresAt) {
- accessTokenExpiresAt = refreshedTokens.accessTokenExpiresAt
- }
- if (refreshedTokens as Partial) {
- idTokenClaims = (refreshedTokens as Partial).idTokenClaims
- }
- }
- }
- if (!accessTokenExpiresAt || !isWithinExpirationDate(accessTokenExpiresAt)) {
- throw new Error('Access token is expired')
- }
- }
- if (isApple) {
- idTokenClaims = (tokens as Partial).idTokenClaims
- if (!idTokenClaims) {
- throw new Error('Apple idToken is required')
- }
- }
-
- let attributes: Partial = {}
- let providerUserId: string | undefined
- let authMethodId: string | undefined
- if (isApple) {
- attributes = {
- email: idTokenClaims?.email || undefined,
- }
- providerUserId = idTokenClaims?.sub
- if (!providerUserId) {
- throw new Error('Missing subject in Apple idToken')
- }
- authMethodId = createAuthMethodId('apple', providerUserId)
- } else {
- if (!accessToken) {
- throw new Error('Access token is required')
- }
- if (authProvider instanceof Discord) {
- const discordUser = await authProvider.getUser(accessToken)
- providerUserId = discordUser.id
- authMethodId = createAuthMethodId('discord', providerUserId)
- attributes = {
- email: discordUser.email || undefined,
- }
- }
- if (authProvider instanceof GitHub) {
- const githubUser = await (authProvider as GitHub).getUser(accessToken)
- providerUserId = githubUser.id.toString()
- authMethodId = createAuthMethodId('github', providerUserId)
- attributes = {
- email: githubUser.email || undefined,
- }
- }
- if (authProvider instanceof Google) {
- const googleUser = await authProvider.getUser(accessToken)
- providerUserId = googleUser.sub
- authMethodId = createAuthMethodId('google', googleUser.sub)
- attributes = {
- email: googleUser.email,
- }
- }
- }
- if (!authMethodId || !providerUserId) {
- throw new Error('Unknown auth provider')
- }
- const existingAuthMethod = await getAuthMethod(ctx, authMethodId)
- if (existingAuthMethod) {
- const existingUser = await getUserById(ctx, existingAuthMethod.userId)
- if (existingUser) {
- return existingUser
- }
- }
-
- // TODO if the email is used by another user without oauth, you will wind up with two
- // users with the same email. If you'd rather merge the accounts, you could create
- // a unique index on the email column in the user table and then check for that first
- // and either return an error or automatically link the accounts.
- // See https://lucia-auth.com/guidebook/oauth-account-linking/ for an example on
- // how to automatically link the accounts.
- const user = await createUser(ctx, service, providerUserId, null, attributes)
- return user
-}
-
-export const getOAuthUser = async (
- service: AuthProviderName,
- ctx: ApiContextProps,
- { code }: { code: string }
-): Promise => {
- const authService = getAuthProvider(ctx, service)
- const validateResult = await authService.validateAuthorizationCode(code)
- return getUserFromAuthProvider(ctx, service, authService, validateResult)
-}
-
-declare module 'lucia' {
- interface Register {
- Lucia: HonoLucia
- DatabaseUserAttributes: {
- email: string | null
- }
- // biome-ignore lint/complexity/noBannedTypes: Need to define this even if empty
- DatabaseSessionAttributes: {}
- }
-}
diff --git a/packages/api/src/auth/user.ts b/packages/api/src/auth/user.ts
index 6e98a7485..fb2cdc785 100644
--- a/packages/api/src/auth/user.ts
+++ b/packages/api/src/auth/user.ts
@@ -10,7 +10,7 @@ import { isWithinExpirationDate } from 'oslo'
import { createCode, createTotpSecret, verifyCode } from '../utils/crypto'
import { AuthProviderName } from './providers'
import { OAuth2RequestError } from 'arctic'
-import { getOAuthUser } from './shared'
+import { getOAuthUser } from './oauth'
export const createAuthMethodId = (providerId: string, providerUserId: string) => {
if (providerId.includes(':')) {
@@ -315,17 +315,21 @@ export const signInWithOAuthCode = async (
code: string,
state?: string,
storedState?: string,
- redirectTo?: string
+ redirectTo?: string,
+ userData?: Partial
) => {
if (!storedState || !state || storedState !== state || typeof code !== 'string') {
throw new TRPCError({ message: 'Invalid state', code: 'BAD_REQUEST' })
}
try {
- const user = await getOAuthUser(service, ctx, { code })
+ const user = await getOAuthUser(service, ctx, { code, userData })
const session = await createSession(ctx.auth, user.id)
- ctx.authRequest?.setSessionCookie(session.id)
+ if (ctx.setCookie) {
+ ctx.setCookie(ctx.auth.createSessionCookie(session.id).serialize())
+ }
return { redirectTo: `${redirectTo ? redirectTo : ctx.env.APP_URL}#token=${session.id}` }
} catch (e) {
+ console.error(e)
if (e instanceof OAuth2RequestError) {
throw new TRPCError({ message: 'Invalid code', code: 'BAD_REQUEST' })
}
@@ -382,7 +386,7 @@ export const cleanup = async (context: ApiContextProps, userId?: string) => {
.where(
and(
userId ? eq(SessionTable.userId, userId) : undefined,
- lt(SessionTable.expiresAt, new Date())
+ lt(SessionTable.expiresAt, new Date().getTime())
)
)
}
diff --git a/packages/api/src/context.ts b/packages/api/src/context.ts
index 806bd77ac..b58a54215 100644
--- a/packages/api/src/context.ts
+++ b/packages/api/src/context.ts
@@ -1,4 +1,3 @@
-import { createHonoAuth, HonoLucia } from './auth/hono'
import { type Session } from './auth/user'
import { createDb } from './db/client'
import type { DB } from './db/client'
@@ -6,14 +5,15 @@ import type { User } from './db/schema'
import { Bindings } from './worker'
import type { inferAsyncReturnType } from '@trpc/server'
import type { Context as HonoContext, HonoRequest } from 'hono'
-import type { AuthRequest, Lucia } from 'lucia'
+import type { Lucia } from 'lucia'
import { verifyToken } from './utils/crypto'
+import { createAuth } from './auth'
+import { getCookie } from 'hono/cookie'
export interface ApiContextProps {
session?: Session
user?: User
- auth: HonoLucia
- authRequest?: AuthRequest
+ auth: Lucia
req?: HonoRequest
c?: HonoContext
setCookie: (value: string) => void
@@ -62,37 +62,48 @@ export const createContext = async (
// const user = await getUser()
- const auth = createHonoAuth(env.DB, env.APP_URL, context.req)
+ const auth = createAuth(db, env.APP_URL)
async function getSession() {
let user: User | undefined
let session: Session | undefined
- let authRequest: AuthRequest | undefined
+ const res = {
+ user,
+ session,
+ }
+
+ if (!context.req) return res
+
+ const cookieSessionId = getCookie(context, auth.sessionCookieName)
+ const bearerSessionId =
+ !cookieSessionId &&
+ context.req.header('x-enable-tokens') &&
+ context.req.header('authorization')?.split(' ')[1]
- if (context.req) {
- authRequest = auth.handleRequest(context)
- const authResult = await authRequest.validate()
- if (authResult.session) {
- session = authResult.session
- user = authResult.user || undefined
+ if (!cookieSessionId && !bearerSessionId) return res
+
+ const authResult = await auth.validateSession(cookieSessionId || bearerSessionId || '')
+ if (cookieSessionId) {
+ if (authResult.session?.fresh) {
+ context.header('Set-Cookie', auth.createSessionCookie(authResult.session.id).serialize(), {
+ append: true,
+ })
}
- // console.log('cookie session and auth request', session, authRequest)
- if (!session && context.req.header('x-enable-tokens')) {
- const tokenAuthResult = await authRequest.validateBearerToken()
- if (tokenAuthResult.session) {
- session = tokenAuthResult.session
- user = tokenAuthResult.user || undefined
- }
+ if (!session) {
+ context.header('Set-Cookie', auth.createBlankSessionCookie().serialize(), {
+ append: true,
+ })
}
}
- return { session, user, authRequest }
+ res.session = authResult.session || undefined
+ res.user = authResult.user || undefined
+ return res
}
- const { session, user, authRequest } = await getSession()
+ const { session, user } = await getSession()
return {
db,
auth,
- authRequest,
req: context.req,
c: context,
session,
diff --git a/packages/api/src/db/schema.ts b/packages/api/src/db/schema.ts
index 3d570300a..cf0745d6c 100644
--- a/packages/api/src/db/schema.ts
+++ b/packages/api/src/db/schema.ts
@@ -5,7 +5,7 @@ import { HASH_METHODS } from '../utils/password/hash-methods'
// User
export const UserTable = sqliteTable('User', {
- id: text('id').primaryKey(),
+ id: text('id').notNull().primaryKey(),
email: text('email').notNull(),
})
export const userRelations = relations(UserTable, ({ many }) => ({
@@ -58,11 +58,12 @@ export const AuthMethodSchema = createInsertSchema(AuthMethodTable)
export const SessionTable = sqliteTable(
'Session',
{
- id: text('id').primaryKey(),
+ id: text('id').notNull().primaryKey(),
userId: text('user_id')
.notNull()
.references(() => UserTable.id),
- expiresAt: integer('expires_at', { mode: 'timestamp' }).notNull(),
+ // DrizzleSQLiteAdapter currently expects this to be an integer and not use { mode: 'timestamp' }
+ expiresAt: integer('expires_at').notNull(),
},
(t) => ({
userIdIdx: index('idx_session_userId').on(t.userId),
diff --git a/packages/api/src/routes/user.ts b/packages/api/src/routes/user.ts
index 465fd76ff..6e6744937 100644
--- a/packages/api/src/routes/user.ts
+++ b/packages/api/src/routes/user.ts
@@ -1,5 +1,5 @@
import { desc, eq } from 'drizzle-orm'
-import { UserTable, type User, SessionTable, AuthMethodTable } from '../db/schema'
+import { UserTable, SessionTable, AuthMethodTable } from '../db/schema'
import { router, protectedProcedure, publicProcedure, valibotParser } from '../trpc'
import { Input } from 'valibot'
import { ApiContextProps } from '../context'
@@ -21,11 +21,15 @@ import {
import { TRPCError } from '@trpc/server'
import { GetByIdSchema, GetSessionsSchema } from '../schema/shared'
import { CreateUserSchema, SignInSchema } from '../schema/user'
-import { AppleIdTokenClaims, generateCodeVerifier, generateState } from 'arctic'
-import { getAuthProvider, getUserFromAuthProvider } from '../auth/shared'
-import { verifyToken, isJWTExpired, sha256 } from '../utils/crypto'
+import {
+ AppleIdTokenClaims,
+ getAuthProvider,
+ getAuthorizationUrl,
+ getUserFromAuthProvider,
+} from '../auth/oauth'
+import { isJWTExpired, sha256 } from '../utils/crypto'
import { getCookie } from 'hono/cookie'
-import { JWT, parseJWT } from '../utils/jwt'
+import { parseJWT } from 'oslo/jwt'
import { P, match } from 'ts-pattern'
import { AuthProviderName } from '../auth/providers'
@@ -144,7 +148,7 @@ const signInWithAppleIdTokenHandler =
// throw new Error('Unable to fetch Apple public key')
// }
// return key
- // })) as AppleIdTokenClaims & { nonce?: string; nonce_supported?: boolean }
+ // })) as unknown as AppleIdTokenClaims
// Since we can't verify the JWT, check that it hasn't expired
const parsedJWT = parseJWT(input.idToken)
if (parsedJWT && isJWTExpired(parsedJWT)) {
@@ -153,10 +157,7 @@ const signInWithAppleIdTokenHandler =
message: 'The Apple ID token has expired.',
})
}
- const payload = parsedJWT?.payload as AppleIdTokenClaims & {
- nonce?: string
- nonce_supported?: boolean
- }
+ const payload = parsedJWT?.payload as AppleIdTokenClaims
if (!payload) {
console.error('Apple ID token could not be verified.', {
payload,
@@ -185,7 +186,9 @@ const signInWithAppleIdTokenHandler =
idTokenClaims: payload,
})
const session = await createSession(ctx.auth, user.id)
- ctx.authRequest?.setSessionCookie(session.id)
+ if (ctx.setCookie) {
+ ctx.setCookie(ctx.auth.createSessionCookie(session.id).serialize())
+ }
return { session }
}
@@ -201,6 +204,7 @@ const signInWithOAuthCodeHandler =
}
const storedState = getCookie(ctx.c, `${input.provider}_oauth_state`)
+ const storedVerifier = getCookie(ctx.c, `${input.provider}_oauth_verifier`)
const storedRedirect = getCookie(ctx.c, `${input.provider}_oauth_redirect`)
return await signInWithOAuthCode(
ctx,
@@ -208,28 +212,26 @@ const signInWithOAuthCodeHandler =
input.code,
input.state,
storedState,
- storedRedirect
+ storedRedirect,
+ input.appleUser
)
}
const authorizationUrlHandler =
(ctx: ApiContextProps) =>
async (input: Input & { provider: AuthProviderName }) => {
- // Get the authorization URL and store the state in a cookie
- const provider = getAuthProvider(ctx, input.provider)
- const state = generateState()
- // TODO Mentioned in docs but seem to be used yet... circle back with future arctic release
- // const codeVerifier = generateCodeVerifier()
- const url = await provider.createAuthorizationURL(state)
- ctx.setCookie(`${input.provider}_oauth_state=${state}; Path=/; HttpOnly; SameSite=Lax`)
+ const url = await getAuthorizationUrl(ctx, input.provider)
if (!validateRedirectDomain(ctx, input.redirectTo)) {
throw new TRPCError({
code: 'FORBIDDEN',
message: `The redirect URL is invalid: ${input.redirectTo}`,
})
}
+ const secure = ctx.req?.url.startsWith('https:') ? 'Secure; ' : ''
ctx.setCookie(
- `${input.provider}_oauth_redirect=${input.redirectTo || ''}; Path=/; HttpOnly; SameSite=Lax`
+ `${input.provider}_oauth_redirect=${
+ input.redirectTo || ''
+ }; Path=/; ${secure}HttpOnly; SameSite=Lax`
)
return { redirectTo: url.toString() }
}
@@ -245,7 +247,9 @@ const signInWithEmailCodeHandler =
console.log('calling update passing and invalidate sessions')
await ctx.auth.invalidateUserSessions(res.session?.userId)
const session = await createSession(ctx.auth, res.session?.userId)
- ctx.authRequest?.setSessionCookie(session.id)
+ if (ctx.setCookie) {
+ ctx.setCookie(ctx.auth.createSessionCookie(session.id).serialize())
+ }
res.session = session
}
return res
@@ -338,7 +342,9 @@ export const userRouter = router({
email: input.email,
})
const session = await createSession(ctx.auth, user.id)
- ctx.authRequest?.setSessionCookie(session.id)
+ if (ctx.setCookie) {
+ ctx.setCookie(ctx.auth.createSessionCookie(session.id).serialize())
+ }
return { session }
}),
})
diff --git a/packages/api/src/schema/user.ts b/packages/api/src/schema/user.ts
index f2e2749ad..cda67eaf3 100644
--- a/packages/api/src/schema/user.ts
+++ b/packages/api/src/schema/user.ts
@@ -44,6 +44,7 @@ export const SignInSchema = object({
idToken: optional(string()),
refreshToken: optional(string()),
nonce: optional(string()),
+ appleUser: optional(object({ email: optionalEmail })),
})
export type SignInInput = Input
diff --git a/packages/api/src/utils/crypto.ts b/packages/api/src/utils/crypto.ts
index 38c130259..fcd1e588c 100644
--- a/packages/api/src/utils/crypto.ts
+++ b/packages/api/src/utils/crypto.ts
@@ -1,7 +1,7 @@
import { TimeSpan, isWithinExpirationDate } from 'oslo'
import { TOTPController } from 'oslo/otp'
import { decodeBase64, encodeBase64, encodeHex } from 'oslo/encoding'
-import { type JWT, parseJWT, validateJWT } from './jwt'
+import { type JWT, parseJWT, validateJWT } from 'oslo/jwt'
import { match, P } from 'ts-pattern'
import { HMAC, sha256 as sha256AB } from 'oslo/crypto'
diff --git a/packages/api/src/utils/jwt.ts b/packages/api/src/utils/jwt.ts
deleted file mode 100644
index b1166f4ea..000000000
--- a/packages/api/src/utils/jwt.ts
+++ /dev/null
@@ -1,309 +0,0 @@
-// Temp copied from https://raw.githubusercontent.com/rmarscher/oslo/optional-typ-header/src/jwt/index.ts
-// https://github.com/pilcrowOnPaper/oslo/pull/9
-import { ECDSA, HMAC, RSASSAPKCS1v1_5, RSASSAPSS } from 'oslo/crypto'
-import { decodeBase64url, encodeBase64url } from 'oslo/encoding'
-import { isWithinExpirationDate } from 'oslo'
-import type { TimeSpan } from 'oslo'
-
-export type JWTAlgorithm =
- | 'HS256'
- | 'HS384'
- | 'HS512'
- | 'RS256'
- | 'RS384'
- | 'RS512'
- | 'ES256'
- | 'ES384'
- | 'ES512'
- | 'PS256'
- | 'PS384'
- | 'PS512'
-
-export async function createJWT(
- algorithm: JWTAlgorithm,
- key: ArrayBuffer,
- payloadClaims: Record,
- options?: {
- headers?: Record
- expiresIn?: TimeSpan
- issuer?: string
- subject?: string
- audience?: string
- notBefore?: Date
- includeIssuedTimestamp?: boolean
- jwtId?: string
- }
-): Promise {
- const header: JWTHeader = {
- alg: algorithm,
- typ: 'JWT',
- ...options?.headers,
- }
- const payload: JWTPayload = {
- ...payloadClaims,
- }
- if (options?.audience !== undefined) {
- payload.aud = options.audience
- }
- if (options?.subject !== undefined) {
- payload.sub = options.subject
- }
- if (options?.issuer !== undefined) {
- payload.iss = options.issuer
- }
- if (options?.jwtId !== undefined) {
- payload.jti = options.jwtId
- }
- if (options?.expiresIn !== undefined) {
- payload.exp = Math.floor(Date.now() / 1000) + options.expiresIn.seconds()
- }
- if (options?.notBefore !== undefined) {
- payload.nbf = Math.floor(options.notBefore.getTime() / 1000)
- }
- if (options?.includeIssuedTimestamp === true) {
- payload.iat = Math.floor(Date.now() / 1000)
- }
- const textEncoder = new TextEncoder()
- const headerPart = encodeBase64url(
- textEncoder.encode(JSON.stringify(header)).buffer as ArrayBuffer
- )
- const payloadPart = encodeBase64url(
- textEncoder.encode(JSON.stringify(payload)).buffer as ArrayBuffer
- )
- const data = textEncoder.encode([headerPart, payloadPart].join('.')).buffer as ArrayBuffer
- const signature = await getAlgorithm(algorithm).sign(key, data)
- const signaturePart = encodeBase64url(signature)
- const value = [headerPart, payloadPart, signaturePart].join('.')
- return value
-}
-
-export async function validateJWT(
- algorithm: JWTAlgorithm,
- key: ArrayBuffer,
- jwt: string
-): Promise {
- const parsedJWT = parseJWT(jwt)
- if (!parsedJWT) {
- throw new Error('Invalid JWT')
- }
- if (parsedJWT.algorithm !== algorithm) {
- throw new Error('Invalid algorithm')
- }
- if (parsedJWT.expiresAt && !isWithinExpirationDate(parsedJWT.expiresAt)) {
- throw new Error('Expired JWT')
- }
- if (parsedJWT.notBefore && Date.now() < parsedJWT.notBefore.getTime()) {
- throw new Error('Inactive JWT')
- }
- const signature = decodeBase64url(parsedJWT.parts[2])
- const data = new TextEncoder().encode(`${parsedJWT.parts[0]}.${parsedJWT.parts[1]}`)
- .buffer as ArrayBuffer
- const validSignature = await getAlgorithm(parsedJWT.algorithm).verify(
- key,
- signature.buffer as ArrayBuffer,
- data
- )
- if (!validSignature) {
- throw new Error('Invalid signature')
- }
- return parsedJWT
-}
-
-function getJWTParts(jwt: string): [header: string, payload: string, signature: string] | null {
- const jwtParts = jwt.split('.')
- if (jwtParts.length !== 3) {
- return null
- }
- return jwtParts as [string, string, string]
-}
-
-export function parseJWT(jwt: string): JWT | null {
- const jwtParts = getJWTParts(jwt)
- if (!jwtParts) {
- return null
- }
- const textDecoder = new TextDecoder()
- const rawHeader = decodeBase64url(jwtParts[0])
- const rawPayload = decodeBase64url(jwtParts[1])
- const header: unknown = JSON.parse(textDecoder.decode(rawHeader))
- if (typeof header !== 'object' || header === null) {
- return null
- }
- if (!('alg' in header) || !isValidAlgorithm(header.alg)) {
- return null
- }
- if ('typ' in header && header.typ !== 'JWT') {
- return null
- }
- const payload: unknown = JSON.parse(textDecoder.decode(rawPayload))
- if (typeof payload !== 'object' || payload === null) {
- return null
- }
- const properties: JWTProperties = {
- algorithm: header.alg,
- expiresAt: null,
- subject: null,
- issuedAt: null,
- issuer: null,
- jwtId: null,
- audience: null,
- notBefore: null,
- }
- if ('exp' in payload) {
- if (typeof payload.exp !== 'number') {
- return null
- }
- properties.expiresAt = new Date(payload.exp * 1000)
- }
- if ('iss' in payload) {
- if (typeof payload.iss !== 'string') {
- return null
- }
- properties.issuer = payload.iss
- }
- if ('sub' in payload) {
- if (typeof payload.sub !== 'string') {
- return null
- }
- properties.subject = payload.sub
- }
- if ('aud' in payload) {
- if (typeof payload.aud !== 'string') {
- return null
- }
- properties.audience = payload.aud
- }
- if ('nbf' in payload) {
- if (typeof payload.nbf !== 'number') {
- return null
- }
- properties.notBefore = new Date(payload.nbf * 1000)
- }
- if ('iat' in payload) {
- if (typeof payload.iat !== 'number') {
- return null
- }
- properties.issuedAt = new Date(payload.iat * 1000)
- }
- if ('jti' in payload) {
- if (typeof payload.jti !== 'string') {
- return null
- }
- properties.jwtId = payload.jti
- }
- return {
- value: jwt,
- header: {
- ...header,
- typ: 'JWT',
- alg: header.alg,
- },
- payload: {
- ...payload,
- },
- parts: jwtParts,
- ...properties,
- }
-}
-
-interface JWTProperties {
- algorithm: JWTAlgorithm
- expiresAt: Date | null
- issuer: string | null
- subject: string | null
- audience: string | null
- notBefore: Date | null
- issuedAt: Date | null
- jwtId: string | null
-}
-
-export interface JWT extends JWTProperties {
- value: string
- header: object
- payload: object
- parts: [header: string, payload: string, signature: string]
-}
-
-function getAlgorithm(algorithm: JWTAlgorithm): ECDSA | HMAC | RSASSAPKCS1v1_5 | RSASSAPSS {
- if (algorithm === 'ES256' || algorithm === 'ES384' || algorithm === 'ES512') {
- return new ECDSA(ecdsaDictionary[algorithm].hash, ecdsaDictionary[algorithm].curve)
- }
- if (algorithm === 'HS256' || algorithm === 'HS384' || algorithm === 'HS512') {
- return new HMAC(hmacDictionary[algorithm])
- }
- if (algorithm === 'RS256' || algorithm === 'RS384' || algorithm === 'RS512') {
- return new RSASSAPKCS1v1_5(rsassapkcs1v1_5Dictionary[algorithm])
- }
- if (algorithm === 'PS256' || algorithm === 'PS384' || algorithm === 'PS512') {
- return new RSASSAPSS(rsassapssDictionary[algorithm])
- }
- throw new TypeError('Invalid algorithm')
-}
-
-function isValidAlgorithm(maybeValidAlgorithm: unknown): maybeValidAlgorithm is JWTAlgorithm {
- if (typeof maybeValidAlgorithm !== 'string') return false
- return [
- 'HS256',
- 'HS384',
- 'HS512',
- 'RS256',
- 'RS384',
- 'RS512',
- 'ES256',
- 'ES384',
- 'ES512',
- 'PS256',
- 'PS384',
- 'PS512',
- ].includes(maybeValidAlgorithm)
-}
-
-interface JWTHeader {
- typ: 'JWT'
- alg: JWTAlgorithm
- [header: string]: any
-}
-
-interface JWTPayload {
- exp?: number
- iss?: string
- aud?: string
- jti?: string
- nbf?: number
- sub?: string
- iat?: number
- [claim: string]: any
-}
-
-const ecdsaDictionary = {
- ES256: {
- hash: 'SHA-256',
- curve: 'P-256',
- },
- ES384: {
- hash: 'SHA-384',
- curve: 'P-384',
- },
- ES512: {
- hash: 'SHA-512',
- curve: 'P-521',
- },
-} as const
-
-const hmacDictionary = {
- HS256: 'SHA-256',
- HS384: 'SHA-384',
- HS512: 'SHA-512',
-} as const
-
-const rsassapkcs1v1_5Dictionary = {
- RS256: 'SHA-256',
- RS384: 'SHA-384',
- RS512: 'SHA-512',
-} as const
-
-const rsassapssDictionary = {
- PS256: 'SHA-256',
- PS384: 'SHA-384',
- PS512: 'SHA-512',
-} as const
diff --git a/packages/api/src/utils/password.ts b/packages/api/src/utils/password.ts
index e8c3bf79f..008cd5af7 100644
--- a/packages/api/src/utils/password.ts
+++ b/packages/api/src/utils/password.ts
@@ -1,6 +1,6 @@
import { argon2Hash, argon2Verify } from './password/argon2'
import { HashMethod } from './password/hash-methods'
-import { verifyLegacyLuciaPasswordHash } from 'lucia'
+import { LegacyScrypt } from 'lucia'
export async function hashPassword(password: string) {
const hashedPassword = await argon2Hash(password)
@@ -19,7 +19,7 @@ export async function verifyPassword(
case 'scrypt':
case null:
case undefined:
- return await verifyLegacyLuciaPasswordHash(password, hashedPassword)
+ return await new LegacyScrypt().verify(hashedPassword, password)
case 'argon2':
return await argon2Verify(password, hashedPassword)
default:
diff --git a/packages/api/src/worker.ts b/packages/api/src/worker.ts
index 8575d44cc..2f0d71031 100644
--- a/packages/api/src/worker.ts
+++ b/packages/api/src/worker.ts
@@ -3,6 +3,7 @@ import { appRouter } from '@t4/api/src/router'
import { cors } from 'hono/cors'
import { createContext } from '@t4/api/src/context'
import { trpcServer } from '@hono/trpc-server'
+import { csrfMiddleware } from './auth/hono'
export type Bindings = Env & {
JWT_VERIFICATION_KEY: string
@@ -41,6 +42,8 @@ const corsHandler = async (c: Context<{ Bindings: Bindings }>, next: Next) => {
})(c, next)
}
+app.use('*', csrfMiddleware)
+
// Setup CORS for the frontend
app.use('/trpc/*', corsHandler)
diff --git a/packages/app/features/oauth/screen.tsx b/packages/app/features/oauth/screen.tsx
index a7e58e110..98c972ff2 100644
--- a/packages/app/features/oauth/screen.tsx
+++ b/packages/app/features/oauth/screen.tsx
@@ -1,15 +1,45 @@
+import type { GetServerSideProps } from 'next'
import type { AuthProviderName } from '@t4/api/src/auth/providers'
import { Paragraph, isServer } from '@t4/ui'
-import { useSignIn } from 'app/utils/auth'
+import { type SignInWithOAuth, useSignIn } from 'app/utils/auth'
import { useCallback, useEffect, useRef, useState } from 'react'
import { createParam } from 'solito'
import { P, match } from 'ts-pattern'
-type Params = { provider: AuthProviderName; redirectTo: string; code?: string; state?: string }
+type Params = {
+ provider: AuthProviderName
+ redirectTo: string
+ code?: string
+ state?: string
+}
const { useParam } = createParam()
-export const OAuthSignInScreen = (): React.ReactNode => {
+// Apple will POST form data to the redirect URI
+export const getServerSideProps = (async (context) => {
+ // Fetch data from external API
+ let appleUser = null
+ if (context.req.method !== 'POST') {
+ return { props: { appleUser } }
+ }
+ try {
+ const userJSON = context.req.headers['x-apple-user'] as string | undefined
+ if (typeof userJSON === 'string') {
+ appleUser = JSON.parse(userJSON)
+ }
+ } catch (e: unknown) {
+ console.error(e)
+ }
+ // Pass data to the page via props
+ return { props: { appleUser } }
+}) satisfies GetServerSideProps
+
+export interface OAuthSignInScreenProps {
+ appleUser?: { email?: string | null } | null
+}
+
+
+export const OAuthSignInScreen = ({ appleUser }: OAuthSignInScreenProps): React.ReactNode => {
const sent = useRef(false)
const { signIn } = useSignIn()
const [provider] = useParam('provider')
@@ -19,7 +49,7 @@ export const OAuthSignInScreen = (): React.ReactNode => {
const [error, setError] = useState(undefined)
const sendApiRequestOnLoad = useCallback(
- async (params: Params) => {
+ async (params: SignInWithOAuth) => {
if (sent.current) return
sent.current = true
try {
@@ -46,8 +76,18 @@ export const OAuthSignInScreen = (): React.ReactNode => {
useEffect(() => {
if (sent.current) return
if (!provider) return
- sendApiRequestOnLoad({ provider, redirectTo: redirectTo || '', state, code })
- }, [provider, redirectTo, state, code, sendApiRequestOnLoad])
+ sendApiRequestOnLoad({
+ provider,
+ redirectTo: redirectTo || '',
+ state,
+ code,
+ // undefined vs null is a result of passing via JSON with getServerSideProps
+ // Maybe there's a superjson plugin or another way to handle it.
+ appleUser: appleUser ? {
+ email: appleUser.email || undefined,
+ } : undefined,
+ })
+ }, [provider, redirectTo, state, code, sendApiRequestOnLoad, appleUser])
const message = match([error, code])
.with([undefined, P._], () => 'Signing in...')
diff --git a/packages/app/utils/auth/index.ts b/packages/app/utils/auth/index.ts
index ba7163d83..0ec14e608 100644
--- a/packages/app/utils/auth/index.ts
+++ b/packages/app/utils/auth/index.ts
@@ -129,6 +129,7 @@ export type SignInWithOAuth = {
redirectTo?: string
code?: string
state?: string
+ appleUser?: { email?: string }
}
export type SignInProps =