Skip to content

Feat: add WithRequest as a type and the callback #26

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Draft
wants to merge 1 commit into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
73 changes: 59 additions & 14 deletions src/index.ts
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ import passport from 'passport'
import OAuth2Strategy, { VerifyCallback } from 'passport-oauth2'
import debug from 'debug'
import payload from 'payload'
import { Request } from 'express'
import { Config } from 'payload/config'
import {
Field,
Expand All @@ -14,13 +15,13 @@ import {
} from 'payload/dist/fields/config/types'
import { PaginatedDocs } from 'payload/dist/database/types'
import getCookieExpiration from 'payload/dist/utilities/getCookieExpiration'
import { TextField } from 'payload/types'
import { PayloadRequest, TextField } from 'payload/types'

import OAuthButton from './OAuthButton'
import type { oAuthPluginOptions } from './types'
import type { oAuthPluginOptions, oAuthPluginOptionsWithRequest } from './types'
import { createElement } from 'react'

export { OAuthButton, oAuthPluginOptions }
export { OAuthButton, oAuthPluginOptions, oAuthPluginOptionsWithRequest }

interface User {}

Expand Down Expand Up @@ -60,7 +61,7 @@ const CLIENTSIDE = typeof session !== 'function'
* ```
*/
export const oAuthPlugin =
(options: oAuthPluginOptions) =>
(options: oAuthPluginOptions | oAuthPluginOptionsWithRequest) =>
(incoming: Config): Config => {
// Shorthands
const collectionSlug = options.userCollection?.slug || 'users'
Expand Down Expand Up @@ -93,7 +94,7 @@ export const oAuthPlugin =

function oAuthPluginClient(
incoming: Config,
options: oAuthPluginOptions
options: oAuthPluginOptions | oAuthPluginOptionsWithRequest
): Config {
const button = options.components?.Button ?? OAuthButton
return button
Expand All @@ -118,7 +119,7 @@ function oAuthPluginClient(

function oAuthPluginServer(
incoming: Config,
options: oAuthPluginOptions
options: oAuthPluginOptions | oAuthPluginOptionsWithRequest
): Config {
// Shorthands
const callbackPath =
Expand All @@ -143,14 +144,14 @@ function oAuthPluginServer(
throw new Error(
`Choose a unique callbackPath for oAuth strategy ${oAuthStrategyCount} (not ${options.callbackPath})`
)

// Passport strategy
const strategy = new OAuth2Strategy(options, async function (
async function verifyCallback(
accessToken: string,
refreshToken: string,
profile: {},
cb: VerifyCallback
) {
cb: VerifyCallback,
options: any, // Adjust type of options as needed
req?: Request // Optional request parameter
): Promise<any> {
let info: {
sub: string
email?: string
Expand All @@ -159,9 +160,10 @@ function oAuthPluginServer(
}
let user: User & { collection?: any; _strategy?: any }
let users: PaginatedDocs<User>

try {
// Get the userinfo
info = await options.userinfo?.(accessToken, refreshToken)
info = await options.userinfo?.(accessToken, refreshToken, req)
if (!info) throw new Error('Failed to get userinfo')

// Match existing user
Expand All @@ -171,7 +173,20 @@ function oAuthPluginServer(
showHiddenFields: true,
})

if (users.docs && users.docs.length) {
// Connect user to current login profile if already authenticated
if (req?.user && (req as PayloadRequest).user.id) {
await payload.update({
collection: collectionSlug,
id: (req as PayloadRequest).user.id,
data: {
...info,
showHiddenFields: true,
},
})
log('connect.user', req.user)
return cb(null, req.user)
} else if (users.docs && users.docs.length) {
// User exists
user = users.docs[0]
user.collection = collectionSlug
user._strategy = strategyName
Expand All @@ -196,8 +211,38 @@ function oAuthPluginServer(
log('signin.fail', error.message, error.trace)
cb(error)
}
})
}

// Passport strategy
let strategy: OAuth2Strategy

if ('passReqToCallback' in options && options.passReqToCallback) {
strategy = new OAuth2Strategy(options, async function (
req: Request,
accessToken: string,
refreshToken: string,
profile: {},
cb: VerifyCallback
) {
await verifyCallback(
accessToken,
refreshToken,
profile,
cb,
options,
req
)
})
} else {
strategy = new OAuth2Strategy(options, async function (
accessToken: string,
refreshToken: string,
profile: {},
cb: VerifyCallback
) {
await verifyCallback(accessToken, refreshToken, profile, cb, options)
})
}
// Alternative?
// strategy.userProfile = async (accessToken, cb) => {
// const user = await options.userinfo?.(accessToken)
Expand Down
24 changes: 21 additions & 3 deletions src/types.ts
Original file line number Diff line number Diff line change
@@ -1,8 +1,12 @@
import { type SessionOptions } from 'express-session'
import type { StrategyOptions } from 'passport-oauth2'
import type {
StrategyOptions,
StrategyOptionsWithRequest,
} from 'passport-oauth2'
import { Request } from 'express'
import type { ComponentType } from 'react'

export interface oAuthPluginOptions extends StrategyOptions {
interface BaseOAuthPluginOptions {
/** Database connection URI in case the lib needs access to database */
databaseUri: string

Expand Down Expand Up @@ -34,7 +38,8 @@ export interface oAuthPluginOptions extends StrategyOptions {
/** Map an authentication result to a user */
userinfo: (
accessToken: string,
refreshToken?: string
refreshToken?: string,
req?: Request
) => Promise<{
/** Unique identifier for the linked account */
sub: string
Expand Down Expand Up @@ -73,6 +78,19 @@ export interface oAuthPluginOptions extends StrategyOptions {
successRedirect?: string
}

export interface oAuthPluginOptions
extends BaseOAuthPluginOptions,
StrategyOptions {}

export interface oAuthPluginOptionsWithRequest
extends BaseOAuthPluginOptions,
StrategyOptionsWithRequest {
/**
* With this option enabled, req will be passed as the first argument to the verify callback.
Copy link
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

How about always passing the request as last argument to the verify callback?

* @default true
*/
passReqToCallback: true
}
export type ButtonProps = {
/** Path that initiates the oAuth flow */
authorizePath: string
Expand Down