From ff5d39170153294b90c5dd32529630190eb37737 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Fr=C3=A9d=C3=A9ric=20Guillot?= Date: Sat, 2 Sep 2023 21:35:10 -0700 Subject: [PATCH] Add OAuth2 PKCE support --- internal/http/request/context.go | 5 +++ internal/model/app_session.go | 5 +-- internal/oauth2/authorization.go | 54 ++++++++++++++++++++++++++++++++ internal/oauth2/google.go | 44 ++++++++++++-------------- internal/oauth2/manager.go | 8 ++--- internal/oauth2/oidc.go | 44 ++++++++++++-------------- internal/oauth2/provider.go | 6 ++-- internal/storage/session.go | 2 +- internal/ui/middleware.go | 3 +- internal/ui/oauth2_callback.go | 5 +-- internal/ui/oauth2_redirect.go | 8 ++++- internal/ui/session/session.go | 10 +++--- 12 files changed, 126 insertions(+), 68 deletions(-) create mode 100644 internal/oauth2/authorization.go diff --git a/internal/http/request/context.go b/internal/http/request/context.go index caa6f543..7fb50685 100644 --- a/internal/http/request/context.go +++ b/internal/http/request/context.go @@ -20,6 +20,7 @@ const ( SessionIDContextKey CSRFContextKey OAuth2StateContextKey + OAuth2CodeVerifierContextKey FlashMessageContextKey FlashErrorMessageContextKey PocketRequestTokenContextKey @@ -94,6 +95,10 @@ func OAuth2State(r *http.Request) string { return getContextStringValue(r, OAuth2StateContextKey) } +func OAuth2CodeVerifier(r *http.Request) string { + return getContextStringValue(r, OAuth2CodeVerifierContextKey) +} + // FlashMessage returns the message message if any. func FlashMessage(r *http.Request) string { return getContextStringValue(r, FlashMessageContextKey) diff --git a/internal/model/app_session.go b/internal/model/app_session.go index dee4d09d..44c1e251 100644 --- a/internal/model/app_session.go +++ b/internal/model/app_session.go @@ -14,6 +14,7 @@ import ( type SessionData struct { CSRF string `json:"csrf"` OAuth2State string `json:"oauth2_state"` + OAuth2CodeVerifier string `json:"oauth2_code_verifier"` FlashMessage string `json:"flash_message"` FlashErrorMessage string `json:"flash_error_message"` Language string `json:"language"` @@ -22,8 +23,8 @@ type SessionData struct { } func (s SessionData) String() string { - return fmt.Sprintf(`CSRF=%q, OAuth2State=%q, FlashMsg=%q, FlashErrMsg=%q, Lang=%q, Theme=%q, PocketTkn=%q`, - s.CSRF, s.OAuth2State, s.FlashMessage, s.FlashErrorMessage, s.Language, s.Theme, s.PocketRequestToken) + return fmt.Sprintf(`CSRF=%q, OAuth2State=%q, OAuth2CodeVerifier=%q, FlashMsg=%q, FlashErrMsg=%q, Lang=%q, Theme=%q, PocketTkn=%q`, + s.CSRF, s.OAuth2State, s.OAuth2CodeVerifier, s.FlashMessage, s.FlashErrorMessage, s.Language, s.Theme, s.PocketRequestToken) } // Value converts the session data to JSON. diff --git a/internal/oauth2/authorization.go b/internal/oauth2/authorization.go new file mode 100644 index 00000000..5854cb8c --- /dev/null +++ b/internal/oauth2/authorization.go @@ -0,0 +1,54 @@ +// SPDX-FileCopyrightText: Copyright The Miniflux Authors. All rights reserved. +// SPDX-License-Identifier: Apache-2.0 + +package oauth2 // import "miniflux.app/v2/internal/oauth2" + +import ( + "crypto/sha256" + "encoding/base64" + "io" + + "golang.org/x/oauth2" + + "miniflux.app/v2/internal/crypto" +) + +type Authorization struct { + url string + state string + codeVerifier string +} + +func (u *Authorization) RedirectURL() string { + return u.url +} + +func (u *Authorization) State() string { + return u.state +} + +func (u *Authorization) CodeVerifier() string { + return u.codeVerifier +} + +func GenerateAuthorization(config *oauth2.Config) *Authorization { + codeVerifier := crypto.GenerateRandomStringHex(32) + + sha2 := sha256.New() + io.WriteString(sha2, codeVerifier) + codeChallenge := base64.RawURLEncoding.EncodeToString(sha2.Sum(nil)) + + state := crypto.GenerateRandomStringHex(24) + + authUrl := config.AuthCodeURL( + state, + oauth2.SetAuthURLParam("code_challenge_method", "S256"), + oauth2.SetAuthURLParam("code_challenge", codeChallenge), + ) + + return &Authorization{ + url: authUrl, + state: state, + codeVerifier: codeVerifier, + } +} diff --git a/internal/oauth2/google.go b/internal/oauth2/google.go index 790d31a1..495a69b0 100644 --- a/internal/oauth2/google.go +++ b/internal/oauth2/google.go @@ -24,17 +24,30 @@ type googleProvider struct { redirectURL string } +func NewGoogleProvider(clientID, clientSecret, redirectURL string) *googleProvider { + return &googleProvider{clientID: clientID, clientSecret: clientSecret, redirectURL: redirectURL} +} + +func (g *googleProvider) GetConfig() *oauth2.Config { + return &oauth2.Config{ + RedirectURL: g.redirectURL, + ClientID: g.clientID, + ClientSecret: g.clientSecret, + Scopes: []string{"email"}, + Endpoint: oauth2.Endpoint{ + AuthURL: "https://accounts.google.com/o/oauth2/auth", + TokenURL: "https://accounts.google.com/o/oauth2/token", + }, + } +} + func (g *googleProvider) GetUserExtraKey() string { return "google_id" } -func (g *googleProvider) GetRedirectURL(state string) string { - return g.config().AuthCodeURL(state) -} - -func (g *googleProvider) GetProfile(ctx context.Context, code string) (*Profile, error) { - conf := g.config() - token, err := conf.Exchange(ctx, code) +func (g *googleProvider) GetProfile(ctx context.Context, code, codeVerifier string) (*Profile, error) { + conf := g.GetConfig() + token, err := conf.Exchange(ctx, code, oauth2.SetAuthURLParam("code_verifier", codeVerifier)) if err != nil { return nil, err } @@ -67,20 +80,3 @@ func (g *googleProvider) PopulateUserWithProfileID(user *model.User, profile *Pr func (g *googleProvider) UnsetUserProfileID(user *model.User) { user.GoogleID = "" } - -func (g *googleProvider) config() *oauth2.Config { - return &oauth2.Config{ - RedirectURL: g.redirectURL, - ClientID: g.clientID, - ClientSecret: g.clientSecret, - Scopes: []string{"email"}, - Endpoint: oauth2.Endpoint{ - AuthURL: "https://accounts.google.com/o/oauth2/auth", - TokenURL: "https://accounts.google.com/o/oauth2/token", - }, - } -} - -func newGoogleProvider(clientID, clientSecret, redirectURL string) *googleProvider { - return &googleProvider{clientID: clientID, clientSecret: clientSecret, redirectURL: redirectURL} -} diff --git a/internal/oauth2/manager.go b/internal/oauth2/manager.go index 6ffbb952..04c6ef2a 100644 --- a/internal/oauth2/manager.go +++ b/internal/oauth2/manager.go @@ -10,12 +10,10 @@ import ( "miniflux.app/v2/internal/logger" ) -// Manager handles OAuth2 providers. type Manager struct { providers map[string]Provider } -// FindProvider returns the given provider. func (m *Manager) FindProvider(name string) (Provider, error) { if provider, found := m.providers[name]; found { return provider, nil @@ -24,18 +22,16 @@ func (m *Manager) FindProvider(name string) (Provider, error) { return nil, errors.New("oauth2 provider not found") } -// AddProvider add a new OAuth2 provider. func (m *Manager) AddProvider(name string, provider Provider) { m.providers[name] = provider } -// NewManager returns a new Manager. func NewManager(ctx context.Context, clientID, clientSecret, redirectURL, oidcDiscoveryEndpoint string) *Manager { m := &Manager{providers: make(map[string]Provider)} - m.AddProvider("google", newGoogleProvider(clientID, clientSecret, redirectURL)) + m.AddProvider("google", NewGoogleProvider(clientID, clientSecret, redirectURL)) if oidcDiscoveryEndpoint != "" { - if genericOidcProvider, err := newOidcProvider(ctx, clientID, clientSecret, redirectURL, oidcDiscoveryEndpoint); err != nil { + if genericOidcProvider, err := NewOidcProvider(ctx, clientID, clientSecret, redirectURL, oidcDiscoveryEndpoint); err != nil { logger.Error("[OAuth2] failed to initialize OIDC provider: %v", err) } else { m.AddProvider("oidc", genericOidcProvider) diff --git a/internal/oauth2/oidc.go b/internal/oauth2/oidc.go index 6ea7063d..c65f11ae 100644 --- a/internal/oauth2/oidc.go +++ b/internal/oauth2/oidc.go @@ -19,17 +19,32 @@ type oidcProvider struct { provider *oidc.Provider } +func NewOidcProvider(ctx context.Context, clientID, clientSecret, redirectURL, discoveryEndpoint string) (*oidcProvider, error) { + provider, err := oidc.NewProvider(ctx, discoveryEndpoint) + if err != nil { + return nil, err + } + + return &oidcProvider{clientID: clientID, clientSecret: clientSecret, redirectURL: redirectURL, provider: provider}, nil +} + func (o *oidcProvider) GetUserExtraKey() string { return "openid_connect_id" } -func (o *oidcProvider) GetRedirectURL(state string) string { - return o.config().AuthCodeURL(state) +func (o *oidcProvider) GetConfig() *oauth2.Config { + return &oauth2.Config{ + RedirectURL: o.redirectURL, + ClientID: o.clientID, + ClientSecret: o.clientSecret, + Scopes: []string{"openid", "email"}, + Endpoint: o.provider.Endpoint(), + } } -func (o *oidcProvider) GetProfile(ctx context.Context, code string) (*Profile, error) { - conf := o.config() - token, err := conf.Exchange(ctx, code) +func (o *oidcProvider) GetProfile(ctx context.Context, code, codeVerifier string) (*Profile, error) { + conf := o.GetConfig() + token, err := conf.Exchange(ctx, code, oauth2.SetAuthURLParam("code_verifier", codeVerifier)) if err != nil { return nil, err } @@ -54,22 +69,3 @@ func (o *oidcProvider) PopulateUserWithProfileID(user *model.User, profile *Prof func (o *oidcProvider) UnsetUserProfileID(user *model.User) { user.OpenIDConnectID = "" } - -func (o *oidcProvider) config() *oauth2.Config { - return &oauth2.Config{ - RedirectURL: o.redirectURL, - ClientID: o.clientID, - ClientSecret: o.clientSecret, - Scopes: []string{"openid", "email"}, - Endpoint: o.provider.Endpoint(), - } -} - -func newOidcProvider(ctx context.Context, clientID, clientSecret, redirectURL, discoveryEndpoint string) (*oidcProvider, error) { - provider, err := oidc.NewProvider(ctx, discoveryEndpoint) - if err != nil { - return nil, err - } - - return &oidcProvider{clientID: clientID, clientSecret: clientSecret, redirectURL: redirectURL, provider: provider}, nil -} diff --git a/internal/oauth2/provider.go b/internal/oauth2/provider.go index 8243dc80..63a4bf72 100644 --- a/internal/oauth2/provider.go +++ b/internal/oauth2/provider.go @@ -6,14 +6,16 @@ package oauth2 // import "miniflux.app/v2/internal/oauth2" import ( "context" + "golang.org/x/oauth2" + "miniflux.app/v2/internal/model" ) // Provider is an interface for OAuth2 providers. type Provider interface { + GetConfig() *oauth2.Config GetUserExtraKey() string - GetRedirectURL(state string) string - GetProfile(ctx context.Context, code string) (*Profile, error) + GetProfile(ctx context.Context, code, codeVerifier string) (*Profile, error) PopulateUserCreationWithProfileID(user *model.UserCreationRequest, profile *Profile) PopulateUserWithProfileID(user *model.User, profile *Profile) UnsetUserProfileID(user *model.User) diff --git a/internal/storage/session.go b/internal/storage/session.go index 77e20c35..fb78ea02 100644 --- a/internal/storage/session.go +++ b/internal/storage/session.go @@ -53,7 +53,7 @@ func (s *Storage) createAppSession(session *model.Session) (*model.Session, erro } // UpdateAppSessionField updates only one session field. -func (s *Storage) UpdateAppSessionField(sessionID, field string, value interface{}) error { +func (s *Storage) UpdateAppSessionField(sessionID, field string, value any) error { query := ` UPDATE sessions diff --git a/internal/ui/middleware.go b/internal/ui/middleware.go index 93cd9491..e06fb98e 100644 --- a/internal/ui/middleware.go +++ b/internal/ui/middleware.go @@ -94,7 +94,7 @@ func (m *middleware) handleAppSession(next http.Handler) http.Handler { return } - html.BadRequest(w, r, errors.New("Invalid or missing CSRF")) + html.BadRequest(w, r, errors.New("invalid or missing CSRF")) return } } @@ -103,6 +103,7 @@ func (m *middleware) handleAppSession(next http.Handler) http.Handler { ctx = context.WithValue(ctx, request.SessionIDContextKey, session.ID) ctx = context.WithValue(ctx, request.CSRFContextKey, session.Data.CSRF) ctx = context.WithValue(ctx, request.OAuth2StateContextKey, session.Data.OAuth2State) + ctx = context.WithValue(ctx, request.OAuth2CodeVerifierContextKey, session.Data.OAuth2CodeVerifier) ctx = context.WithValue(ctx, request.FlashMessageContextKey, session.Data.FlashMessage) ctx = context.WithValue(ctx, request.FlashErrorMessageContextKey, session.Data.FlashErrorMessage) ctx = context.WithValue(ctx, request.UserLanguageContextKey, session.Data.Language) diff --git a/internal/ui/oauth2_callback.go b/internal/ui/oauth2_callback.go index f3d7a193..01e7abdc 100644 --- a/internal/ui/oauth2_callback.go +++ b/internal/ui/oauth2_callback.go @@ -4,6 +4,7 @@ package ui // import "miniflux.app/v2/internal/ui" import ( + "crypto/subtle" "errors" "net/http" @@ -38,7 +39,7 @@ func (h *handler) oauth2Callback(w http.ResponseWriter, r *http.Request) { } state := request.QueryStringParam(r, "state", "") - if state == "" || state != request.OAuth2State(r) { + if subtle.ConstantTimeCompare([]byte(state), []byte(request.OAuth2State(r))) == 0 { logger.Error(`[OAuth2] Invalid state value: got "%s" instead of "%s"`, state, request.OAuth2State(r)) html.Redirect(w, r, route.Path(h.router, "login")) return @@ -51,7 +52,7 @@ func (h *handler) oauth2Callback(w http.ResponseWriter, r *http.Request) { return } - profile, err := authProvider.GetProfile(r.Context(), code) + profile, err := authProvider.GetProfile(r.Context(), code, request.OAuth2CodeVerifier(r)) if err != nil { logger.Error("[OAuth2] %v", err) html.Redirect(w, r, route.Path(h.router, "login")) diff --git a/internal/ui/oauth2_redirect.go b/internal/ui/oauth2_redirect.go index 842b71e5..622b544c 100644 --- a/internal/ui/oauth2_redirect.go +++ b/internal/ui/oauth2_redirect.go @@ -10,6 +10,7 @@ import ( "miniflux.app/v2/internal/http/response/html" "miniflux.app/v2/internal/http/route" "miniflux.app/v2/internal/logger" + "miniflux.app/v2/internal/oauth2" "miniflux.app/v2/internal/ui/session" ) @@ -30,5 +31,10 @@ func (h *handler) oauth2Redirect(w http.ResponseWriter, r *http.Request) { return } - html.Redirect(w, r, authProvider.GetRedirectURL(sess.NewOAuth2State())) + auth := oauth2.GenerateAuthorization(authProvider.GetConfig()) + + sess.SetOAuth2State(auth.State()) + sess.SetOAuth2CodeVerifier(auth.CodeVerifier()) + + html.Redirect(w, r, auth.RedirectURL()) } diff --git a/internal/ui/session/session.go b/internal/ui/session/session.go index 9b656682..619c383a 100644 --- a/internal/ui/session/session.go +++ b/internal/ui/session/session.go @@ -4,7 +4,6 @@ package session // import "miniflux.app/v2/internal/ui/session" import ( - "miniflux.app/v2/internal/crypto" "miniflux.app/v2/internal/storage" ) @@ -14,11 +13,12 @@ type Session struct { sessionID string } -// NewOAuth2State generates a new OAuth2 state and stores the value into the database. -func (s *Session) NewOAuth2State() string { - state := crypto.GenerateRandomString(32) +func (s *Session) SetOAuth2State(state string) { s.store.UpdateAppSessionField(s.sessionID, "oauth2_state", state) - return state +} + +func (s *Session) SetOAuth2CodeVerifier(codeVerfier string) { + s.store.UpdateAppSessionField(s.sessionID, "oauth2_code_verifier", codeVerfier) } // NewFlashMessage creates a new flash message.