diff --git a/storage/integration.go b/storage/integration.go index dab8b1ee..48a6fdeb 100644 --- a/storage/integration.go +++ b/storage/integration.go @@ -175,17 +175,6 @@ func (s *Storage) UpdateIntegration(integration *model.Integration) error { return nil } -// CreateIntegration creates initial user integration settings. -func (s *Storage) CreateIntegration(userID int64) error { - query := `INSERT INTO integrations (user_id) VALUES ($1)` - _, err := s.db.Exec(query, userID) - if err != nil { - return fmt.Errorf(`store: unable to create integration row: %v`, err) - } - - return nil -} - // HasSaveEntry returns true if the given user can save articles to third-parties. func (s *Storage) HasSaveEntry(userID int64) (result bool) { query := ` diff --git a/storage/user.go b/storage/user.go index 757a1b1e..d967de4e 100644 --- a/storage/user.go +++ b/storage/user.go @@ -84,7 +84,12 @@ func (s *Storage) CreateUser(user *model.User) (err error) { openid_connect_id ` - err = s.db.QueryRow(query, user.Username, hashedPassword, user.IsAdmin, user.GoogleID, user.OpenIDConnectID).Scan( + tx, err := s.db.Begin() + if err != nil { + return fmt.Errorf(`store: unable to start transaction: %v`, err) + } + + err = tx.QueryRow(query, user.Username, hashedPassword, user.IsAdmin, user.GoogleID, user.OpenIDConnectID).Scan( &user.ID, &user.Username, &user.IsAdmin, @@ -101,11 +106,26 @@ func (s *Storage) CreateUser(user *model.User) (err error) { &user.OpenIDConnectID, ) if err != nil { + tx.Rollback() return fmt.Errorf(`store: unable to create user: %v`, err) } - s.CreateCategory(&model.Category{Title: "All", UserID: user.ID}) - s.CreateIntegration(user.ID) + _, err = tx.Exec(`INSERT INTO categories (user_id, title) VALUES ($1, $2)`, user.ID, "All") + if err != nil { + tx.Rollback() + return fmt.Errorf(`store: unable to create user default category: %v`, err) + } + + _, err = tx.Exec(`INSERT INTO integrations (user_id) VALUES ($1)`, user.ID) + if err != nil { + tx.Rollback() + return fmt.Errorf(`store: unable to create integration row: %v`, err) + } + + if err := tx.Commit(); err != nil { + return fmt.Errorf(`store: unable to commit transaction: %v`, err) + } + return nil } @@ -362,22 +382,22 @@ func (s *Storage) fetchUser(query string, args ...interface{}) (*model.User, err // RemoveUser deletes a user. func (s *Storage) RemoveUser(userID int64) error { - ts, err := s.db.Begin() + tx, err := s.db.Begin() if err != nil { return fmt.Errorf(`store: unable to start transaction: %v`, err) } - if _, err := ts.Exec(`DELETE FROM users WHERE id=$1`, userID); err != nil { - ts.Rollback() + if _, err := tx.Exec(`DELETE FROM users WHERE id=$1`, userID); err != nil { + tx.Rollback() return fmt.Errorf(`store: unable to remove user #%d: %v`, userID, err) } - if _, err := ts.Exec(`DELETE FROM integrations WHERE user_id=$1`, userID); err != nil { - ts.Rollback() + if _, err := tx.Exec(`DELETE FROM integrations WHERE user_id=$1`, userID); err != nil { + tx.Rollback() return fmt.Errorf(`store: unable to remove integration settings for user #%d: %v`, userID, err) } - if err := ts.Commit(); err != nil { + if err := tx.Commit(); err != nil { return fmt.Errorf(`store: unable to commit transaction: %v`, err) }