From 747da03e4cf1478937a434792fdb52dd10ae8cda Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Fr=C3=A9d=C3=A9ric=20Guillot?= Date: Fri, 24 Nov 2017 16:09:10 -0800 Subject: [PATCH] Improve OAuth2 integration --- helper/crypto.go | 2 +- locale/translations.go | 8 +++--- locale/translations/fr_FR.json | 4 ++- server/oauth2/google.go | 6 ++++- server/oauth2/provider.go | 1 + server/routes.go | 1 + server/static/bin.go | 2 +- server/static/css.go | 2 +- server/static/js.go | 2 +- server/template/common.go | 2 +- server/template/html/settings.html | 10 +++++++ server/template/template.go | 7 +++++ server/template/views.go | 14 ++++++++-- server/ui/controller/oauth2.go | 42 ++++++++++++++++++++++++++++++ sql/sql.go | 2 +- storage/user.go | 30 +++++++++++++++++++-- 16 files changed, 120 insertions(+), 15 deletions(-) diff --git a/helper/crypto.go b/helper/crypto.go index ed18bb68..6a204165 100644 --- a/helper/crypto.go +++ b/helper/crypto.go @@ -26,7 +26,7 @@ func Hash(value string) string { func GenerateRandomBytes(size int) []byte { b := make([]byte, size) if _, err := rand.Read(b); err != nil { - panic(fmt.Errorf("Unable to generate random string: %v", err)) + panic(err) } return b diff --git a/locale/translations.go b/locale/translations.go index 5da99b47..dffced48 100644 --- a/locale/translations.go +++ b/locale/translations.go @@ -1,5 +1,5 @@ // Code generated by go generate; DO NOT EDIT. -// 2017-11-22 22:11:44.610818223 -0800 PST m=+0.024503556 +// 2017-11-24 16:04:49.318661623 -0800 PST m=+0.006828741 package locale @@ -146,12 +146,14 @@ var Translations = map[string]string{ "This special link allows you to subscribe to a website directly by using a bookmark in your web browser.": "Ce lien spécial vous permet de vous abonner à un site web directement en utilisant un marque page dans votre navigateur web.", "Add to Miniflux": "Ajouter à Miniflux", "Refresh all feeds in the background": "Actualiser tous les abonnements en arrière-plan", - "Sign in with Google": "Se connecter avec Google" + "Sign in with Google": "Se connecter avec Google", + "Unlink my Google account": "Dissocier mon compte Google", + "Link my Google account": "Associer mon compte Google" } `, } var TranslationsChecksums = map[string]string{ "en_US": "6fe95384260941e8a5a3c695a655a932e0a8a6a572c1e45cb2b1ae8baa01b897", - "fr_FR": "f413b0bc103b2ab689d52da2e17c5e718a91f5dc4138dc601beaae4ec0cfc1af", + "fr_FR": "f438ed9116ecc7b71412581255dd9b1332cacd9e2876615b03ec65e4b500bf02", } diff --git a/locale/translations/fr_FR.json b/locale/translations/fr_FR.json index c99bf3e5..e5b7eb15 100644 --- a/locale/translations/fr_FR.json +++ b/locale/translations/fr_FR.json @@ -130,5 +130,7 @@ "This special link allows you to subscribe to a website directly by using a bookmark in your web browser.": "Ce lien spécial vous permet de vous abonner à un site web directement en utilisant un marque page dans votre navigateur web.", "Add to Miniflux": "Ajouter à Miniflux", "Refresh all feeds in the background": "Actualiser tous les abonnements en arrière-plan", - "Sign in with Google": "Se connecter avec Google" + "Sign in with Google": "Se connecter avec Google", + "Unlink my Google account": "Dissocier mon compte Google", + "Link my Google account": "Associer mon compte Google" } diff --git a/server/oauth2/google.go b/server/oauth2/google.go index 5c63c75f..e57e027f 100644 --- a/server/oauth2/google.go +++ b/server/oauth2/google.go @@ -23,6 +23,10 @@ type googleProvider struct { redirectURL string } +func (g googleProvider) GetUserExtraKey() string { + return "google_id" +} + func (g googleProvider) GetRedirectURL(state string) string { return g.config().AuthCodeURL(state) } @@ -48,7 +52,7 @@ func (g googleProvider) GetProfile(code string) (*Profile, error) { return nil, fmt.Errorf("unable to unserialize google profile: %v", err) } - profile := &Profile{Key: "google_id", ID: user.Sub, Username: user.Email} + profile := &Profile{Key: g.GetUserExtraKey(), ID: user.Sub, Username: user.Email} return profile, nil } diff --git a/server/oauth2/provider.go b/server/oauth2/provider.go index 27ab22ac..c43931c0 100644 --- a/server/oauth2/provider.go +++ b/server/oauth2/provider.go @@ -6,6 +6,7 @@ package oauth2 // Provider is an interface for OAuth2 providers. type Provider interface { + GetUserExtraKey() string GetRedirectURL(state string) string GetProfile(code string) (*Profile, error) } diff --git a/server/routes.go b/server/routes.go index 7c227ebd..b6bcbc25 100644 --- a/server/routes.go +++ b/server/routes.go @@ -124,6 +124,7 @@ func getRoutes(cfg *config.Config, store *storage.Storage, feedHandler *feed.Han router.Handle("/import", uiHandler.Use(uiController.Import)).Name("import").Methods("GET") router.Handle("/upload", uiHandler.Use(uiController.UploadOPML)).Name("uploadOPML").Methods("POST") + router.Handle("/oauth2/{provider}/unlink", uiHandler.Use(uiController.OAuth2Unlink)).Name("oauth2Unlink").Methods("GET") router.Handle("/oauth2/{provider}/redirect", uiHandler.Use(uiController.OAuth2Redirect)).Name("oauth2Redirect").Methods("GET") router.Handle("/oauth2/{provider}/callback", uiHandler.Use(uiController.OAuth2Callback)).Name("oauth2Callback").Methods("GET") diff --git a/server/static/bin.go b/server/static/bin.go index f06a8d90..8769862d 100644 --- a/server/static/bin.go +++ b/server/static/bin.go @@ -1,5 +1,5 @@ // Code generated by go generate; DO NOT EDIT. -// 2017-11-22 22:11:44.595540312 -0800 PST m=+0.009225645 +// 2017-11-24 16:04:49.314940117 -0800 PST m=+0.003107235 package static diff --git a/server/static/css.go b/server/static/css.go index 17ddc99a..096fc5c9 100644 --- a/server/static/css.go +++ b/server/static/css.go @@ -1,5 +1,5 @@ // Code generated by go generate; DO NOT EDIT. -// 2017-11-22 22:11:44.596955262 -0800 PST m=+0.010640595 +// 2017-11-24 16:04:49.315340301 -0800 PST m=+0.003507419 package static diff --git a/server/static/js.go b/server/static/js.go index f302bb92..28eed6a1 100644 --- a/server/static/js.go +++ b/server/static/js.go @@ -1,5 +1,5 @@ // Code generated by go generate; DO NOT EDIT. -// 2017-11-22 22:11:44.598697812 -0800 PST m=+0.012383145 +// 2017-11-24 16:04:49.316027642 -0800 PST m=+0.004194760 package static diff --git a/server/template/common.go b/server/template/common.go index e65d6111..40fb76ae 100644 --- a/server/template/common.go +++ b/server/template/common.go @@ -1,5 +1,5 @@ // Code generated by go generate; DO NOT EDIT. -// 2017-11-22 22:11:44.609659332 -0800 PST m=+0.023344665 +// 2017-11-24 16:04:49.318279667 -0800 PST m=+0.006446785 package template diff --git a/server/template/html/settings.html b/server/template/html/settings.html index dcad092b..23a5f42e 100644 --- a/server/template/html/settings.html +++ b/server/template/html/settings.html @@ -63,4 +63,14 @@ +{{ if hasOAuth2Provider "google" }} +
+ {{ if hasKey .user.Extra "google_id" }} + {{ t "Unlink my Google account" }} + {{ else }} + {{ t "Link my Google account" }} + {{ end }} +
+{{ end }} + {{ end }} diff --git a/server/template/template.go b/server/template/template.go index 5ef68d64..627fd9fc 100644 --- a/server/template/template.go +++ b/server/template/template.go @@ -40,6 +40,13 @@ func (e *Engine) parseAll() { "hasOAuth2Provider": func(provider string) bool { return e.cfg.Get("OAUTH2_PROVIDER", "") == provider }, + "hasKey": func(dict map[string]string, key string) bool { + log.Println(dict) + if value, found := dict[key]; found { + return value != "" + } + return false + }, "route": func(name string, args ...interface{}) string { return route.GetRoute(e.router, name, args...) }, diff --git a/server/template/views.go b/server/template/views.go index 44121495..d82ab0a8 100644 --- a/server/template/views.go +++ b/server/template/views.go @@ -1,5 +1,5 @@ // Code generated by go generate; DO NOT EDIT. -// 2017-11-22 22:11:44.601583424 -0800 PST m=+0.015268757 +// 2017-11-24 16:04:49.316644027 -0800 PST m=+0.004811145 package template @@ -921,6 +921,16 @@ var templateViewsMap = map[string]string{ +{{ if hasOAuth2Provider "google" }} +
+ {{ if hasKey .user.Extra "google_id" }} + {{ t "Unlink my Google account" }} + {{ else }} + {{ t "Link my Google account" }} + {{ end }} +
+{{ end }} + {{ end }} `, "unread": `{{ define "title"}}{{ t "Unread Items" }} {{ if gt .countUnread 0 }}({{ .countUnread }}){{ end }} {{ end }} @@ -1052,7 +1062,7 @@ var templateViewsMapChecksums = map[string]string{ "integrations": "c485d6d9ed996635e55e73320610e6bcb01a41c1153e8e739ae2294b0b14b243", "login": "04f3ce79bfa5753f69e0d956c2a8999c0da549c7925634a3e8134975da0b0e0f", "sessions": "878dbe8f8ea783b44130c495814179519fa5c3aa2666ac87508f94d58dd008bf", - "settings": "a972fb5767fd32522648149880e40607ed8bbed7a389038bbab6b08539ac2893", + "settings": "1e2df11f5436eb2d05ae1fae30dd6f1362613011edbfcc79ae8b23854fa348b4", "unread": "b6f9be1a72188947c75a6fdcac6ff7878db7745f9efa46318e0433102892a722", "users": "44677e28bb5347799ed0020c90ec785aadec4b1454446d92411cfdaf6e32110b", } diff --git a/server/ui/controller/oauth2.go b/server/ui/controller/oauth2.go index c43d7070..c80ec71e 100644 --- a/server/ui/controller/oauth2.go +++ b/server/ui/controller/oauth2.go @@ -71,6 +71,17 @@ func (c *Controller) OAuth2Callback(ctx *core.Context, request *core.Request, re return } + if ctx.IsAuthenticated() { + user := ctx.LoggedUser() + if err := c.store.UpdateExtraField(user.ID, profile.Key, profile.ID); err != nil { + response.HTML().ServerError(err) + return + } + + response.Redirect(ctx.Route("settings")) + return + } + user, err := c.store.GetUserByExtraField(profile.Key, profile.ID) if err != nil { response.HTML().ServerError(err) @@ -78,6 +89,11 @@ func (c *Controller) OAuth2Callback(ctx *core.Context, request *core.Request, re } if user == nil { + if c.cfg.GetInt("OAUTH2_USER_CREATION", 0) == 0 { + response.HTML().Forbidden() + return + } + user = model.NewUser() user.Username = profile.Username user.IsAdmin = false @@ -114,6 +130,32 @@ func (c *Controller) OAuth2Callback(ctx *core.Context, request *core.Request, re response.Redirect(ctx.Route("unread")) } +// OAuth2Unlink unlink an account from the external provider. +func (c *Controller) OAuth2Unlink(ctx *core.Context, request *core.Request, response *core.Response) { + provider := request.StringParam("provider", "") + if provider == "" { + log.Println("[OAuth2] Invalid or missing provider") + response.Redirect(ctx.Route("login")) + return + } + + authProvider, err := getOAuth2Manager(c.cfg).Provider(provider) + if err != nil { + log.Println("[OAuth2]", err) + response.Redirect(ctx.Route("settings")) + return + } + + user := ctx.LoggedUser() + if err := c.store.RemoveExtraField(user.ID, authProvider.GetUserExtraKey()); err != nil { + response.HTML().ServerError(err) + return + } + + response.Redirect(ctx.Route("settings")) + return +} + func getOAuth2Manager(cfg *config.Config) *oauth2.Manager { return oauth2.NewManager( cfg.Get("OAUTH2_CLIENT_ID", ""), diff --git a/sql/sql.go b/sql/sql.go index 27deabfc..b8c0419d 100644 --- a/sql/sql.go +++ b/sql/sql.go @@ -1,5 +1,5 @@ // Code generated by go generate; DO NOT EDIT. -// 2017-11-22 22:11:44.590706207 -0800 PST m=+0.004391540 +// 2017-11-24 16:04:49.314265268 -0800 PST m=+0.002432386 package sql diff --git a/storage/user.go b/storage/user.go index caffc3b2..3b905a55 100644 --- a/storage/user.go +++ b/storage/user.go @@ -74,6 +74,24 @@ func (s *Storage) CreateUser(user *model.User) (err error) { return nil } +func (s *Storage) UpdateExtraField(userID int64, field, value string) error { + query := fmt.Sprintf(`UPDATE users SET extra = hstore('%s', $1) WHERE id=$2`, field) + _, err := s.db.Exec(query, value, userID) + if err != nil { + return fmt.Errorf("unable to update user extra field: %v", err) + } + return nil +} + +func (s *Storage) RemoveExtraField(userID int64, field string) error { + query := `UPDATE users SET extra = delete(extra, $1) WHERE id=$2` + _, err := s.db.Exec(query, field, userID) + if err != nil { + return fmt.Errorf("unable to remove user extra field: %v", err) + } + return nil +} + func (s *Storage) UpdateUser(user *model.User) error { defer helper.ExecutionTime(time.Now(), fmt.Sprintf("[Storage:UpdateUser] username=%s", user.Username)) user.Username = strings.ToLower(user.Username) @@ -104,14 +122,22 @@ func (s *Storage) GetUserById(userID int64) (*model.User, error) { defer helper.ExecutionTime(time.Now(), fmt.Sprintf("[Storage:GetUserById] userID=%d", userID)) var user model.User - row := s.db.QueryRow("SELECT id, username, is_admin, theme, language, timezone FROM users WHERE id = $1", userID) - err := row.Scan(&user.ID, &user.Username, &user.IsAdmin, &user.Theme, &user.Language, &user.Timezone) + var extra hstore.Hstore + row := s.db.QueryRow("SELECT id, username, is_admin, theme, language, timezone, extra FROM users WHERE id = $1", userID) + err := row.Scan(&user.ID, &user.Username, &user.IsAdmin, &user.Theme, &user.Language, &user.Timezone, &extra) if err == sql.ErrNoRows { return nil, nil } else if err != nil { return nil, fmt.Errorf("unable to fetch user: %v", err) } + user.Extra = make(map[string]string) + for key, value := range extra.Map { + if value.Valid { + user.Extra[key] = value.String + } + } + return &user, nil }