@@ -43,6 +43,12 @@ var jsonnetCache, _ = ristretto.NewCache(&ristretto.Config[[]byte, []byte]{
4343
4444type MetadataType string
4545
46+ type OIDCProviderData struct {
47+ Provider string `json:"provider"`
48+ Tokens * identity.CredentialsOIDCEncryptedTokens `json:"tokens"`
49+ Claims Claims `json:"claims"`
50+ }
51+
4652type VerifiedAddress struct {
4753 Value string `json:"value"`
4854 Via identity.VerifiableAddressType `json:"via"`
@@ -53,6 +59,8 @@ const (
5359
5460 PublicMetadata MetadataType = "identity.metadata_public"
5561 AdminMetadata MetadataType = "identity.metadata_admin"
62+
63+ InternalContextKeyProviderData = "provider_data"
5664)
5765
5866func (s * Strategy ) RegisterRegistrationRoutes (r * x.RouterPublic ) {
@@ -216,6 +224,26 @@ func (s *Strategy) Register(w http.ResponseWriter, r *http.Request, f *registrat
216224 return errors .WithStack (flow .ErrCompletedByStrategy )
217225 }
218226
227+ providerDataKey := flow .PrefixInternalContextKey (s .ID (), InternalContextKeyProviderData )
228+ if oidcProviderData := gjson .GetBytes (f .InternalContext , providerDataKey ); oidcProviderData .IsObject () {
229+ var providerData OIDCProviderData
230+ if err = json .Unmarshal ([]byte (oidcProviderData .Raw ), & providerData ); err != nil {
231+ return s .handleError (ctx , w , r , f , pid , nil , errors .WithStack (herodot .ErrInternalServerError .WithReasonf ("Expected OIDC provider data in internal context to be an object but got: %w" , err )))
232+ }
233+ if pid != providerData .Provider {
234+ return s .handleError (ctx , w , r , f , pid , nil , errors .WithStack (herodot .ErrInternalServerError .WithReasonf ("Expected OIDC provider data in internal context to have matching provider but got: %s" , providerData .Provider )))
235+ }
236+ _ , err = s .processRegistration (ctx , w , r , f , providerData .Tokens , & providerData .Claims , provider , & AuthCodeContainer {
237+ FlowID : f .ID .String (),
238+ Traits : p .Traits ,
239+ TransientPayload : f .TransientPayload ,
240+ })
241+ if err != nil {
242+ return s .handleError (ctx , w , r , f , pid , nil , err )
243+ }
244+ return errors .WithStack (flow .ErrCompletedByStrategy )
245+ }
246+
219247 state , pkce , err := s .GenerateState (ctx , provider , f .ID )
220248 if err != nil {
221249 return s .handleError (ctx , w , r , f , pid , nil , err )
@@ -313,6 +341,13 @@ func (s *Strategy) processRegistration(ctx context.Context, w http.ResponseWrite
313341 return nil , nil
314342 }
315343
344+ providerDataKey := flow .PrefixInternalContextKey (s .ID (), InternalContextKeyProviderData )
345+ if hasOIDCProviderData := gjson .GetBytes (rf .InternalContext , providerDataKey ).IsObject (); ! hasOIDCProviderData {
346+ if internalContext , err := sjson .SetBytes (rf .InternalContext , providerDataKey , & OIDCProviderData {Provider : provider .Config ().ID , Tokens : token , Claims : * claims }); err == nil {
347+ rf .InternalContext = internalContext
348+ }
349+ }
350+
316351 fetch := fetcher .NewFetcher (fetcher .WithClient (s .d .HTTPClient (ctx )), fetcher .WithCache (jsonnetCache , 60 * time .Minute ))
317352 jsonnetMapperSnippet , err := fetch .FetchContext (ctx , provider .Config ().Mapper )
318353 if err != nil {
@@ -351,6 +386,10 @@ func (s *Strategy) processRegistration(ctx context.Context, w http.ResponseWrite
351386 return nil , s .handleError (ctx , w , r , rf , provider .Config ().ID , i .Traits , err )
352387 }
353388
389+ if internalContext , err := sjson .DeleteBytes (rf .InternalContext , providerDataKey ); err == nil {
390+ rf .InternalContext = internalContext
391+ }
392+
354393 return nil , nil
355394}
356395
0 commit comments