From ebec95a52280980caa88b9c8cd92d69c1a7dc164 Mon Sep 17 00:00:00 2001 From: tobi <31960611+tsmethurst@users.noreply.github.com> Date: Thu, 2 May 2024 14:43:00 +0200 Subject: [PATCH] [bugfix] Lock when checking/creating notifs to avoid race (#2890) * [bugfix] Lock when checking/creating notifs to avoid race * test notif spam --- internal/processing/account/move.go | 2 +- internal/processing/admin/accountapprove.go | 2 +- internal/processing/admin/accountreject.go | 2 +- internal/processing/status/pin.go | 4 +- internal/processing/workers/fromclientapi.go | 2 +- internal/processing/workers/fromfediapi.go | 2 +- internal/processing/workers/surface.go | 14 +-- internal/processing/workers/surfaceemail.go | 52 ++++---- internal/processing/workers/surfacenotify.go | 110 +++++++++++------ .../processing/workers/surfacenotify_test.go | 115 ++++++++++++++++++ .../processing/workers/surfacetimeline.go | 68 +++++------ internal/processing/workers/util.go | 18 +-- internal/processing/workers/workers.go | 12 +- internal/processing/workers/workers_test.go | 3 + internal/state/state.go | 17 +-- 15 files changed, 290 insertions(+), 133 deletions(-) create mode 100644 internal/processing/workers/surfacenotify_test.go diff --git a/internal/processing/account/move.go b/internal/processing/account/move.go index 63187dfd1..21e4f887b 100644 --- a/internal/processing/account/move.go +++ b/internal/processing/account/move.go @@ -113,7 +113,7 @@ func (p *Processor) MoveSelf( // in quick succession, so get a lock on // this account. lockKey := originAcct.URI - unlock := p.state.AccountLocks.Lock(lockKey) + unlock := p.state.ProcessingLocks.Lock(lockKey) defer unlock() // Ensure we have a valid, up-to-date representation of the target account. diff --git a/internal/processing/admin/accountapprove.go b/internal/processing/admin/accountapprove.go index ebc91ee0c..c3f6409c3 100644 --- a/internal/processing/admin/accountapprove.go +++ b/internal/processing/admin/accountapprove.go @@ -49,7 +49,7 @@ func (p *Processor) AccountApprove( // Get a lock on the account URI, // to ensure it's not also being // rejected at the same time! - unlock := p.state.AccountLocks.Lock(user.Account.URI) + unlock := p.state.ProcessingLocks.Lock(user.Account.URI) defer unlock() if !*user.Approved { diff --git a/internal/processing/admin/accountreject.go b/internal/processing/admin/accountreject.go index e7d54be41..8cb54cad6 100644 --- a/internal/processing/admin/accountreject.go +++ b/internal/processing/admin/accountreject.go @@ -52,7 +52,7 @@ func (p *Processor) AccountReject( // Get a lock on the account URI, // since we're going to be deleting // it and its associated user. - unlock := p.state.AccountLocks.Lock(user.Account.URI) + unlock := p.state.ProcessingLocks.Lock(user.Account.URI) defer unlock() // Can't reject an account with a diff --git a/internal/processing/status/pin.go b/internal/processing/status/pin.go index d0688331b..c4a6fc3b8 100644 --- a/internal/processing/status/pin.go +++ b/internal/processing/status/pin.go @@ -83,7 +83,7 @@ func (p *Processor) PinCreate(ctx context.Context, requestingAccount *gtsmodel.A } // Get a lock on this account. - unlock := p.state.AccountLocks.Lock(requestingAccount.URI) + unlock := p.state.ProcessingLocks.Lock(requestingAccount.URI) defer unlock() if !targetStatus.PinnedAt.IsZero() { @@ -148,7 +148,7 @@ func (p *Processor) PinRemove(ctx context.Context, requestingAccount *gtsmodel.A } // Get a lock on this account. - unlock := p.state.AccountLocks.Lock(requestingAccount.URI) + unlock := p.state.ProcessingLocks.Lock(requestingAccount.URI) defer unlock() if targetStatus.PinnedAt.IsZero() { diff --git a/internal/processing/workers/fromclientapi.go b/internal/processing/workers/fromclientapi.go index 4564afbb9..a9e33892f 100644 --- a/internal/processing/workers/fromclientapi.go +++ b/internal/processing/workers/fromclientapi.go @@ -41,7 +41,7 @@ import ( type clientAPI struct { state *state.State converter *typeutils.Converter - surface *surface + surface *Surface federate *federate account *account.Processor utils *utils diff --git a/internal/processing/workers/fromfediapi.go b/internal/processing/workers/fromfediapi.go index fcd5b38f2..49756a47a 100644 --- a/internal/processing/workers/fromfediapi.go +++ b/internal/processing/workers/fromfediapi.go @@ -41,7 +41,7 @@ import ( // from the federation/ActivityPub API. type fediAPI struct { state *state.State - surface *surface + surface *Surface federate *federate account *account.Processor utils *utils diff --git a/internal/processing/workers/surface.go b/internal/processing/workers/surface.go index 09162b131..5ec905ae8 100644 --- a/internal/processing/workers/surface.go +++ b/internal/processing/workers/surface.go @@ -25,16 +25,16 @@ import ( "github.com/superseriousbusiness/gotosocial/internal/typeutils" ) -// surface wraps functions for 'surfacing' the result +// Surface wraps functions for 'surfacing' the result // of processing a message, eg: // - timelining a status // - removing a status from timelines // - sending a notification to a user // - sending an email -type surface struct { - state *state.State - converter *typeutils.Converter - stream *stream.Processor - filter *visibility.Filter - emailSender email.Sender +type Surface struct { + State *state.State + Converter *typeutils.Converter + Stream *stream.Processor + Filter *visibility.Filter + EmailSender email.Sender } diff --git a/internal/processing/workers/surfaceemail.go b/internal/processing/workers/surfaceemail.go index 3a5b5e7f4..9d46ff45e 100644 --- a/internal/processing/workers/surfaceemail.go +++ b/internal/processing/workers/surfaceemail.go @@ -33,8 +33,8 @@ import ( // emailUserReportClosed emails the user who created the // given report, to inform them the report has been closed. -func (s *surface) emailUserReportClosed(ctx context.Context, report *gtsmodel.Report) error { - user, err := s.state.DB.GetUserByAccountID(ctx, report.Account.ID) +func (s *Surface) emailUserReportClosed(ctx context.Context, report *gtsmodel.Report) error { + user, err := s.State.DB.GetUserByAccountID(ctx, report.Account.ID) if err != nil { return gtserror.Newf("db error getting user: %w", err) } @@ -51,12 +51,12 @@ func (s *surface) emailUserReportClosed(ctx context.Context, report *gtsmodel.Re return nil } - instance, err := s.state.DB.GetInstance(ctx, config.GetHost()) + instance, err := s.State.DB.GetInstance(ctx, config.GetHost()) if err != nil { return gtserror.Newf("db error getting instance: %w", err) } - if err := s.state.DB.PopulateReport(ctx, report); err != nil { + if err := s.State.DB.PopulateReport(ctx, report); err != nil { return gtserror.Newf("error populating report: %w", err) } @@ -69,12 +69,12 @@ func (s *surface) emailUserReportClosed(ctx context.Context, report *gtsmodel.Re ActionTakenComment: report.ActionTaken, } - return s.emailSender.SendReportClosedEmail(user.Email, reportClosedData) + return s.EmailSender.SendReportClosedEmail(user.Email, reportClosedData) } // emailUserPleaseConfirm emails the given user // to ask them to confirm their email address. -func (s *surface) emailUserPleaseConfirm(ctx context.Context, user *gtsmodel.User) error { +func (s *Surface) emailUserPleaseConfirm(ctx context.Context, user *gtsmodel.User) error { if user.UnconfirmedEmail == "" || user.UnconfirmedEmail == user.Email { // User has already confirmed this @@ -82,7 +82,7 @@ func (s *surface) emailUserPleaseConfirm(ctx context.Context, user *gtsmodel.Use return nil } - instance, err := s.state.DB.GetInstance(ctx, config.GetHost()) + instance, err := s.State.DB.GetInstance(ctx, config.GetHost()) if err != nil { return gtserror.Newf("db error getting instance: %w", err) } @@ -97,7 +97,7 @@ func (s *surface) emailUserPleaseConfirm(ctx context.Context, user *gtsmodel.Use ) // Assemble email contents and send the email. - if err := s.emailSender.SendConfirmEmail( + if err := s.EmailSender.SendConfirmEmail( user.UnconfirmedEmail, email.ConfirmData{ Username: user.Account.Username, @@ -116,7 +116,7 @@ func (s *surface) emailUserPleaseConfirm(ctx context.Context, user *gtsmodel.Use user.ConfirmationSentAt = now user.LastEmailedAt = now - if err := s.state.DB.UpdateUser( + if err := s.State.DB.UpdateUser( ctx, user, "confirmation_token", @@ -131,7 +131,7 @@ func (s *surface) emailUserPleaseConfirm(ctx context.Context, user *gtsmodel.Use // emailUserSignupApproved emails the given user // to inform them their sign-up has been approved. -func (s *surface) emailUserSignupApproved(ctx context.Context, user *gtsmodel.User) error { +func (s *Surface) emailUserSignupApproved(ctx context.Context, user *gtsmodel.User) error { // User may have been approved without // their email address being confirmed // yet. Just send to whatever we have. @@ -140,13 +140,13 @@ func (s *surface) emailUserSignupApproved(ctx context.Context, user *gtsmodel.Us emailAddr = user.UnconfirmedEmail } - instance, err := s.state.DB.GetInstance(ctx, config.GetHost()) + instance, err := s.State.DB.GetInstance(ctx, config.GetHost()) if err != nil { return gtserror.Newf("db error getting instance: %w", err) } // Assemble email contents and send the email. - if err := s.emailSender.SendSignupApprovedEmail( + if err := s.EmailSender.SendSignupApprovedEmail( emailAddr, email.SignupApprovedData{ Username: user.Account.Username, @@ -162,7 +162,7 @@ func (s *surface) emailUserSignupApproved(ctx context.Context, user *gtsmodel.Us now := time.Now() user.LastEmailedAt = now - if err := s.state.DB.UpdateUser( + if err := s.State.DB.UpdateUser( ctx, user, "last_emailed_at", @@ -175,14 +175,14 @@ func (s *surface) emailUserSignupApproved(ctx context.Context, user *gtsmodel.Us // emailUserSignupApproved emails the given user // to inform them their sign-up has been approved. -func (s *surface) emailUserSignupRejected(ctx context.Context, deniedUser *gtsmodel.DeniedUser) error { - instance, err := s.state.DB.GetInstance(ctx, config.GetHost()) +func (s *Surface) emailUserSignupRejected(ctx context.Context, deniedUser *gtsmodel.DeniedUser) error { + instance, err := s.State.DB.GetInstance(ctx, config.GetHost()) if err != nil { return gtserror.Newf("db error getting instance: %w", err) } // Assemble email contents and send the email. - return s.emailSender.SendSignupRejectedEmail( + return s.EmailSender.SendSignupRejectedEmail( deniedUser.Email, email.SignupRejectedData{ Message: deniedUser.Message, @@ -194,13 +194,13 @@ func (s *surface) emailUserSignupRejected(ctx context.Context, deniedUser *gtsmo // emailAdminReportOpened emails all active moderators/admins // of this instance that a new report has been created. -func (s *surface) emailAdminReportOpened(ctx context.Context, report *gtsmodel.Report) error { - instance, err := s.state.DB.GetInstance(ctx, config.GetHost()) +func (s *Surface) emailAdminReportOpened(ctx context.Context, report *gtsmodel.Report) error { + instance, err := s.State.DB.GetInstance(ctx, config.GetHost()) if err != nil { return gtserror.Newf("error getting instance: %w", err) } - toAddresses, err := s.state.DB.GetInstanceModeratorAddresses(ctx) + toAddresses, err := s.State.DB.GetInstanceModeratorAddresses(ctx) if err != nil { if errors.Is(err, db.ErrNoEntries) { // No registered moderator addresses. @@ -209,7 +209,7 @@ func (s *surface) emailAdminReportOpened(ctx context.Context, report *gtsmodel.R return gtserror.Newf("error getting instance moderator addresses: %w", err) } - if err := s.state.DB.PopulateReport(ctx, report); err != nil { + if err := s.State.DB.PopulateReport(ctx, report); err != nil { return gtserror.Newf("error populating report: %w", err) } @@ -221,7 +221,7 @@ func (s *surface) emailAdminReportOpened(ctx context.Context, report *gtsmodel.R ReportTargetDomain: report.TargetAccount.Domain, } - if err := s.emailSender.SendNewReportEmail(toAddresses, reportData); err != nil { + if err := s.EmailSender.SendNewReportEmail(toAddresses, reportData); err != nil { return gtserror.Newf("error emailing instance moderators: %w", err) } @@ -230,13 +230,13 @@ func (s *surface) emailAdminReportOpened(ctx context.Context, report *gtsmodel.R // emailAdminNewSignup emails all active moderators/admins of this // instance that a new account sign-up has been submitted to the instance. -func (s *surface) emailAdminNewSignup(ctx context.Context, newUser *gtsmodel.User) error { - instance, err := s.state.DB.GetInstance(ctx, config.GetHost()) +func (s *Surface) emailAdminNewSignup(ctx context.Context, newUser *gtsmodel.User) error { + instance, err := s.State.DB.GetInstance(ctx, config.GetHost()) if err != nil { return gtserror.Newf("error getting instance: %w", err) } - toAddresses, err := s.state.DB.GetInstanceModeratorAddresses(ctx) + toAddresses, err := s.State.DB.GetInstanceModeratorAddresses(ctx) if err != nil { if errors.Is(err, db.ErrNoEntries) { // No registered moderator addresses. @@ -246,7 +246,7 @@ func (s *surface) emailAdminNewSignup(ctx context.Context, newUser *gtsmodel.Use } // Ensure user populated. - if err := s.state.DB.PopulateUser(ctx, newUser); err != nil { + if err := s.State.DB.PopulateUser(ctx, newUser); err != nil { return gtserror.Newf("error populating user: %w", err) } @@ -259,7 +259,7 @@ func (s *surface) emailAdminNewSignup(ctx context.Context, newUser *gtsmodel.Use SignupURL: instance.URI + "/settings/admin/accounts/" + newUser.AccountID, } - if err := s.emailSender.SendNewSignupEmail(toAddresses, newSignupData); err != nil { + if err := s.EmailSender.SendNewSignupEmail(toAddresses, newSignupData); err != nil { return gtserror.Newf("error emailing instance moderators: %w", err) } diff --git a/internal/processing/workers/surfacenotify.go b/internal/processing/workers/surfacenotify.go index 9c82712f2..be729fa7e 100644 --- a/internal/processing/workers/surfacenotify.go +++ b/internal/processing/workers/surfacenotify.go @@ -20,18 +20,20 @@ package workers import ( "context" "errors" + "strings" "github.com/superseriousbusiness/gotosocial/internal/db" "github.com/superseriousbusiness/gotosocial/internal/gtscontext" "github.com/superseriousbusiness/gotosocial/internal/gtserror" "github.com/superseriousbusiness/gotosocial/internal/gtsmodel" "github.com/superseriousbusiness/gotosocial/internal/id" + "github.com/superseriousbusiness/gotosocial/internal/util" ) // notifyMentions iterates through mentions on the // given status, and notifies each mentioned account // that they have a new mention. -func (s *surface) notifyMentions( +func (s *Surface) notifyMentions( ctx context.Context, status *gtsmodel.Status, ) error { @@ -43,7 +45,7 @@ func (s *surface) notifyMentions( mention.Status = status // Beforehand, ensure the passed mention is fully populated. - if err := s.state.DB.PopulateMention(ctx, mention); err != nil { + if err := s.State.DB.PopulateMention(ctx, mention); err != nil { errs.Appendf("error populating mention %s: %w", mention.ID, err) continue } @@ -56,7 +58,7 @@ func (s *surface) notifyMentions( // Ensure thread not muted // by mentioned account. - muted, err := s.state.DB.IsThreadMutedByAccount( + muted, err := s.State.DB.IsThreadMutedByAccount( ctx, status.ThreadID, mention.TargetAccountID, @@ -75,7 +77,7 @@ func (s *surface) notifyMentions( // notify mentioned // by status author. - if err := s.notify(ctx, + if err := s.Notify(ctx, gtsmodel.NotificationMention, mention.TargetAccount, mention.OriginAccount, @@ -91,12 +93,12 @@ func (s *surface) notifyMentions( // notifyFollowRequest notifies the target of the given // follow request that they have a new follow request. -func (s *surface) notifyFollowRequest( +func (s *Surface) notifyFollowRequest( ctx context.Context, followReq *gtsmodel.FollowRequest, ) error { // Beforehand, ensure the passed follow request is fully populated. - if err := s.state.DB.PopulateFollowRequest(ctx, followReq); err != nil { + if err := s.State.DB.PopulateFollowRequest(ctx, followReq); err != nil { return gtserror.Newf("error populating follow request %s: %w", followReq.ID, err) } @@ -107,7 +109,7 @@ func (s *surface) notifyFollowRequest( } // Now notify the follow request itself. - if err := s.notify(ctx, + if err := s.Notify(ctx, gtsmodel.NotificationFollowRequest, followReq.TargetAccount, followReq.Account, @@ -123,12 +125,12 @@ func (s *surface) notifyFollowRequest( // they have a new follow. It will also remove any previous // notification of a follow request, essentially replacing // that notification. -func (s *surface) notifyFollow( +func (s *Surface) notifyFollow( ctx context.Context, follow *gtsmodel.Follow, ) error { // Beforehand, ensure the passed follow is fully populated. - if err := s.state.DB.PopulateFollow(ctx, follow); err != nil { + if err := s.State.DB.PopulateFollow(ctx, follow); err != nil { return gtserror.Newf("error populating follow %s: %w", follow.ID, err) } @@ -139,7 +141,7 @@ func (s *surface) notifyFollow( } // Check if previous follow req notif exists. - prevNotif, err := s.state.DB.GetNotification( + prevNotif, err := s.State.DB.GetNotification( gtscontext.SetBarebones(ctx), gtsmodel.NotificationFollowRequest, follow.TargetAccountID, @@ -152,14 +154,14 @@ func (s *surface) notifyFollow( if prevNotif != nil { // Previous follow request notif existed, delete it before creating new. - if err := s.state.DB.DeleteNotificationByID(ctx, prevNotif.ID); // nocollapse + if err := s.State.DB.DeleteNotificationByID(ctx, prevNotif.ID); // nocollapse err != nil && !errors.Is(err, db.ErrNoEntries) { return gtserror.Newf("error deleting notification %s: %w", prevNotif.ID, err) } } // Now notify the follow itself. - if err := s.notify(ctx, + if err := s.Notify(ctx, gtsmodel.NotificationFollow, follow.TargetAccount, follow.Account, @@ -173,7 +175,7 @@ func (s *surface) notifyFollow( // notifyFave notifies the target of the given // fave that their status has been liked/faved. -func (s *surface) notifyFave( +func (s *Surface) notifyFave( ctx context.Context, fave *gtsmodel.StatusFave, ) error { @@ -183,7 +185,7 @@ func (s *surface) notifyFave( } // Beforehand, ensure the passed status fave is fully populated. - if err := s.state.DB.PopulateStatusFave(ctx, fave); err != nil { + if err := s.State.DB.PopulateStatusFave(ctx, fave); err != nil { return gtserror.Newf("error populating fave %s: %w", fave.ID, err) } @@ -195,7 +197,7 @@ func (s *surface) notifyFave( // Ensure favee hasn't // muted the thread. - muted, err := s.state.DB.IsThreadMutedByAccount( + muted, err := s.State.DB.IsThreadMutedByAccount( ctx, fave.Status.ThreadID, fave.TargetAccountID, @@ -212,7 +214,7 @@ func (s *surface) notifyFave( // notify status author // of fave by account. - if err := s.notify(ctx, + if err := s.Notify(ctx, gtsmodel.NotificationFave, fave.TargetAccount, fave.Account, @@ -226,7 +228,7 @@ func (s *surface) notifyFave( // notifyAnnounce notifies the status boost target // account that their status has been boosted. -func (s *surface) notifyAnnounce( +func (s *Surface) notifyAnnounce( ctx context.Context, status *gtsmodel.Status, ) error { @@ -241,7 +243,7 @@ func (s *surface) notifyAnnounce( } // Beforehand, ensure the passed status is fully populated. - if err := s.state.DB.PopulateStatus(ctx, status); err != nil { + if err := s.State.DB.PopulateStatus(ctx, status); err != nil { return gtserror.Newf("error populating status %s: %w", status.ID, err) } @@ -253,7 +255,7 @@ func (s *surface) notifyAnnounce( // Ensure boostee hasn't // muted the thread. - muted, err := s.state.DB.IsThreadMutedByAccount( + muted, err := s.State.DB.IsThreadMutedByAccount( ctx, status.BoostOf.ThreadID, status.BoostOfAccountID, @@ -271,7 +273,7 @@ func (s *surface) notifyAnnounce( // notify status author // of boost by account. - if err := s.notify(ctx, + if err := s.Notify(ctx, gtsmodel.NotificationReblog, status.BoostOfAccount, status.Account, @@ -283,14 +285,14 @@ func (s *surface) notifyAnnounce( return nil } -func (s *surface) notifyPollClose(ctx context.Context, status *gtsmodel.Status) error { +func (s *Surface) notifyPollClose(ctx context.Context, status *gtsmodel.Status) error { // Beforehand, ensure the passed status is fully populated. - if err := s.state.DB.PopulateStatus(ctx, status); err != nil { + if err := s.State.DB.PopulateStatus(ctx, status); err != nil { return gtserror.Newf("error populating status %s: %w", status.ID, err) } // Fetch all votes in the attached status poll. - votes, err := s.state.DB.GetPollVotes(ctx, status.PollID) + votes, err := s.State.DB.GetPollVotes(ctx, status.PollID) if err != nil { return gtserror.Newf("error getting poll %s votes: %w", status.PollID, err) } @@ -300,7 +302,7 @@ func (s *surface) notifyPollClose(ctx context.Context, status *gtsmodel.Status) if status.Account.IsLocal() { // Send a notification to the status // author that their poll has closed! - if err := s.notify(ctx, + if err := s.Notify(ctx, gtsmodel.NotificationPoll, status.Account, status.Account, @@ -319,7 +321,7 @@ func (s *surface) notifyPollClose(ctx context.Context, status *gtsmodel.Status) // notify voter that // poll has been closed. - if err := s.notify(ctx, + if err := s.Notify(ctx, gtsmodel.NotificationPoll, vote.Account, status.Account, @@ -333,8 +335,8 @@ func (s *surface) notifyPollClose(ctx context.Context, status *gtsmodel.Status) return errs.Combine() } -func (s *surface) notifySignup(ctx context.Context, newUser *gtsmodel.User) error { - modAccounts, err := s.state.DB.GetInstanceModerators(ctx) +func (s *Surface) notifySignup(ctx context.Context, newUser *gtsmodel.User) error { + modAccounts, err := s.State.DB.GetInstanceModerators(ctx) if err != nil { if errors.Is(err, db.ErrNoEntries) { // No registered @@ -347,18 +349,18 @@ func (s *surface) notifySignup(ctx context.Context, newUser *gtsmodel.User) erro } // Ensure user + account populated. - if err := s.state.DB.PopulateUser(ctx, newUser); err != nil { + if err := s.State.DB.PopulateUser(ctx, newUser); err != nil { return gtserror.Newf("db error populating new user: %w", err) } - if err := s.state.DB.PopulateAccount(ctx, newUser.Account); err != nil { + if err := s.State.DB.PopulateAccount(ctx, newUser.Account); err != nil { return gtserror.Newf("db error populating new user's account: %w", err) } // Notify each moderator. var errs gtserror.MultiError for _, mod := range modAccounts { - if err := s.notify(ctx, + if err := s.Notify(ctx, gtsmodel.NotificationSignup, mod, newUser.Account, @@ -372,7 +374,24 @@ func (s *surface) notifySignup(ctx context.Context, newUser *gtsmodel.User) erro return errs.Combine() } -// notify creates, inserts, and streams a new +func getNotifyLockURI( + notificationType gtsmodel.NotificationType, + targetAccount *gtsmodel.Account, + originAccount *gtsmodel.Account, + statusID string, +) string { + builder := strings.Builder{} + builder.WriteString("notification:?") + builder.WriteString("type=" + string(notificationType)) + builder.WriteString("&target=" + targetAccount.URI) + builder.WriteString("&origin=" + originAccount.URI) + if statusID != "" { + builder.WriteString("&statusID=" + statusID) + } + return builder.String() +} + +// Notify creates, inserts, and streams a new // notification to the target account if it // doesn't yet exist with the given parameters. // @@ -383,7 +402,7 @@ func (s *surface) notifySignup(ctx context.Context, newUser *gtsmodel.User) erro // // targetAccount and originAccount must be // set, but statusID can be an empty string. -func (s *surface) notify( +func (s *Surface) Notify( ctx context.Context, notificationType gtsmodel.NotificationType, targetAccount *gtsmodel.Account, @@ -395,9 +414,24 @@ func (s *surface) notify( return nil } + // We're doing state-y stuff so get a + // lock on this combo of notif params. + lockURI := getNotifyLockURI( + notificationType, + targetAccount, + originAccount, + statusID, + ) + unlock := s.State.ProcessingLocks.Lock(lockURI) + + // Wrap the unlock so we + // can do granular unlocking. + unlock = util.DoOnce(unlock) + defer unlock() + // Make sure a notification doesn't // already exist with these params. - if _, err := s.state.DB.GetNotification( + if _, err := s.State.DB.GetNotification( gtscontext.SetBarebones(ctx), notificationType, targetAccount.ID, @@ -424,16 +458,20 @@ func (s *surface) notify( StatusID: statusID, } - if err := s.state.DB.PutNotification(ctx, notif); err != nil { + if err := s.State.DB.PutNotification(ctx, notif); err != nil { return gtserror.Newf("error putting notification in database: %w", err) } + // Unlock already, we're done + // with the state-y stuff. + unlock() + // Stream notification to the user. - apiNotif, err := s.converter.NotificationToAPINotification(ctx, notif) + apiNotif, err := s.Converter.NotificationToAPINotification(ctx, notif) if err != nil { return gtserror.Newf("error converting notification to api representation: %w", err) } - s.stream.Notify(ctx, targetAccount, apiNotif) + s.Stream.Notify(ctx, targetAccount, apiNotif) return nil } diff --git a/internal/processing/workers/surfacenotify_test.go b/internal/processing/workers/surfacenotify_test.go new file mode 100644 index 000000000..7b448781d --- /dev/null +++ b/internal/processing/workers/surfacenotify_test.go @@ -0,0 +1,115 @@ +// GoToSocial +// Copyright (C) GoToSocial Authors admin@gotosocial.org +// SPDX-License-Identifier: AGPL-3.0-or-later +// +// This program is free software: you can redistribute it and/or modify +// it under the terms of the GNU Affero General Public License as published by +// the Free Software Foundation, either version 3 of the License, or +// (at your option) any later version. +// +// This program is distributed in the hope that it will be useful, +// but WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +// GNU Affero General Public License for more details. +// +// You should have received a copy of the GNU Affero General Public License +// along with this program. If not, see . + +package workers_test + +import ( + "context" + "sync" + "testing" + "time" + + "github.com/stretchr/testify/suite" + "github.com/superseriousbusiness/gotosocial/internal/filter/visibility" + "github.com/superseriousbusiness/gotosocial/internal/gtscontext" + "github.com/superseriousbusiness/gotosocial/internal/gtsmodel" + "github.com/superseriousbusiness/gotosocial/internal/processing/workers" +) + +type SurfaceNotifyTestSuite struct { + WorkersTestSuite +} + +func (suite *SurfaceNotifyTestSuite) TestSpamNotifs() { + testStructs := suite.SetupTestStructs() + defer suite.TearDownTestStructs(testStructs) + + surface := &workers.Surface{ + State: testStructs.State, + Converter: testStructs.TypeConverter, + Stream: testStructs.Processor.Stream(), + Filter: visibility.NewFilter(testStructs.State), + EmailSender: testStructs.EmailSender, + } + + var ( + ctx = context.Background() + notificationType = gtsmodel.NotificationFollow + targetAccount = suite.testAccounts["local_account_1"] + originAccount = suite.testAccounts["local_account_2"] + ) + + // Set up a bunch of goroutines to surface + // a notification at exactly the same time. + wg := sync.WaitGroup{} + wg.Add(20) + startAt := time.Now().Add(2 * time.Second) + + for i := 0; i < 20; i++ { + go func() { + defer wg.Done() + + // Wait for it.... + untilTick := time.Until(startAt) + <-time.Tick(untilTick) + + // ...Go! + if err := surface.Notify(ctx, + notificationType, + targetAccount, + originAccount, + "", + ); err != nil { + suite.FailNow(err.Error()) + } + }() + } + + // Wait for all notif creation + // attempts to complete. + wg.Wait() + + // Get all notifs for this account. + notifs, err := testStructs.State.DB.GetAccountNotifications( + gtscontext.SetBarebones(ctx), + targetAccount.ID, + "", "", "", 0, nil, + ) + if err != nil { + suite.FailNow(err.Error()) + } + + var gotOne bool + for _, notif := range notifs { + if notif.NotificationType == notificationType && + notif.TargetAccountID == targetAccount.ID && + notif.OriginAccountID == originAccount.ID { + // This is the notif... + if gotOne { + // We already had + // the notif, d'oh! + suite.FailNow("already had notif") + } else { + gotOne = true + } + } + } +} + +func TestSurfaceNotifyTestSuite(t *testing.T) { + suite.Run(t, new(SurfaceNotifyTestSuite)) +} diff --git a/internal/processing/workers/surfacetimeline.go b/internal/processing/workers/surfacetimeline.go index 14634f846..65b039939 100644 --- a/internal/processing/workers/surfacetimeline.go +++ b/internal/processing/workers/surfacetimeline.go @@ -36,14 +36,14 @@ import ( // It will also handle notifications for any mentions attached to // the account, and notifications for any local accounts that want // to know when this account posts. -func (s *surface) timelineAndNotifyStatus(ctx context.Context, status *gtsmodel.Status) error { +func (s *Surface) timelineAndNotifyStatus(ctx context.Context, status *gtsmodel.Status) error { // Ensure status fully populated; including account, mentions, etc. - if err := s.state.DB.PopulateStatus(ctx, status); err != nil { + if err := s.State.DB.PopulateStatus(ctx, status); err != nil { return gtserror.Newf("error populating status with id %s: %w", status.ID, err) } // Get all local followers of the account that posted the status. - follows, err := s.state.DB.GetAccountLocalFollowers(ctx, status.AccountID) + follows, err := s.State.DB.GetAccountLocalFollowers(ctx, status.AccountID) if err != nil { return gtserror.Newf("error getting local followers of account %s: %w", status.AccountID, err) } @@ -79,7 +79,7 @@ func (s *surface) timelineAndNotifyStatus(ctx context.Context, status *gtsmodel. // adding the status to list timelines + home timelines of each // follower, as appropriate, and notifying each follower of the // new status, if the status is eligible for notification. -func (s *surface) timelineAndNotifyStatusForFollowers( +func (s *Surface) timelineAndNotifyStatusForFollowers( ctx context.Context, status *gtsmodel.Status, follows []*gtsmodel.Follow, @@ -98,7 +98,7 @@ func (s *surface) timelineAndNotifyStatusForFollowers( // If it's not timelineable, we can just stop early, since lists // are prettymuch subsets of the home timeline, so if it shouldn't // appear there, it shouldn't appear in lists either. - timelineable, err := s.filter.StatusHomeTimelineable( + timelineable, err := s.Filter.StatusHomeTimelineable( ctx, follow.Account, status, ) if err != nil { @@ -124,7 +124,7 @@ func (s *surface) timelineAndNotifyStatusForFollowers( // of this follow, if applicable. homeTimelined, err := s.timelineStatus( ctx, - s.state.Timelines.Home.IngestOne, + s.State.Timelines.Home.IngestOne, follow.AccountID, // home timelines are keyed by account ID follow.Account, status, @@ -160,7 +160,7 @@ func (s *surface) timelineAndNotifyStatusForFollowers( // - This is a top-level post (not a reply or boost). // // That means we can officially notify this one. - if err := s.notify(ctx, + if err := s.Notify(ctx, gtsmodel.NotificationStatus, follow.Account, status.Account, @@ -175,7 +175,7 @@ func (s *surface) timelineAndNotifyStatusForFollowers( // listTimelineStatusForFollow puts the given status // in any eligible lists owned by the given follower. -func (s *surface) listTimelineStatusForFollow( +func (s *Surface) listTimelineStatusForFollow( ctx context.Context, status *gtsmodel.Status, follow *gtsmodel.Follow, @@ -189,7 +189,7 @@ func (s *surface) listTimelineStatusForFollow( // inclusion in the list. // Get every list entry that targets this follow's ID. - listEntries, err := s.state.DB.GetListEntriesForFollowID( + listEntries, err := s.State.DB.GetListEntriesForFollowID( // We only need the list IDs. gtscontext.SetBarebones(ctx), follow.ID, @@ -217,7 +217,7 @@ func (s *surface) listTimelineStatusForFollow( // list that this list entry belongs to. if _, err := s.timelineStatus( ctx, - s.state.Timelines.List.IngestOne, + s.State.Timelines.List.IngestOne, listEntry.ListID, // list timelines are keyed by list ID follow.Account, status, @@ -232,7 +232,7 @@ func (s *surface) listTimelineStatusForFollow( // listEligible checks if the given status is eligible // for inclusion in the list that that the given listEntry // belongs to, based on the replies policy of the list. -func (s *surface) listEligible( +func (s *Surface) listEligible( ctx context.Context, listEntry *gtsmodel.ListEntry, status *gtsmodel.Status, @@ -253,7 +253,7 @@ func (s *surface) listEligible( // We need to fetch the list that this // entry belongs to, in order to check // the list's replies policy. - list, err := s.state.DB.GetListByID( + list, err := s.State.DB.GetListByID( ctx, listEntry.ListID, ) if err != nil { @@ -273,7 +273,7 @@ func (s *surface) listEligible( // // Check if replied-to account is // also included in this list. - includes, err := s.state.DB.ListIncludesAccount( + includes, err := s.State.DB.ListIncludesAccount( ctx, list.ID, status.InReplyToAccountID, @@ -295,7 +295,7 @@ func (s *surface) listEligible( // // Check if replied-to account is // followed by list owner account. - follows, err := s.state.DB.IsFollowing( + follows, err := s.State.DB.IsFollowing( ctx, list.AccountID, status.InReplyToAccountID, @@ -325,7 +325,7 @@ func (s *surface) listEligible( // // If the status was inserted into the timeline, true will be returned // + it will also be streamed to the user using the given streamType. -func (s *surface) timelineStatus( +func (s *Surface) timelineStatus( ctx context.Context, ingest func(context.Context, string, timeline.Timelineable) (bool, error), timelineID string, @@ -343,26 +343,26 @@ func (s *surface) timelineStatus( } // The status was inserted so stream it to the user. - apiStatus, err := s.converter.StatusToAPIStatus(ctx, status, account) + apiStatus, err := s.Converter.StatusToAPIStatus(ctx, status, account) if err != nil { err = gtserror.Newf("error converting status %s to frontend representation: %w", status.ID, err) return true, err } - s.stream.Update(ctx, account, apiStatus, streamType) + s.Stream.Update(ctx, account, apiStatus, streamType) return true, nil } // deleteStatusFromTimelines completely removes the given status from all timelines. // It will also stream deletion of the status to all open streams. -func (s *surface) deleteStatusFromTimelines(ctx context.Context, statusID string) error { - if err := s.state.Timelines.Home.WipeItemFromAllTimelines(ctx, statusID); err != nil { +func (s *Surface) deleteStatusFromTimelines(ctx context.Context, statusID string) error { + if err := s.State.Timelines.Home.WipeItemFromAllTimelines(ctx, statusID); err != nil { return err } - if err := s.state.Timelines.List.WipeItemFromAllTimelines(ctx, statusID); err != nil { + if err := s.State.Timelines.List.WipeItemFromAllTimelines(ctx, statusID); err != nil { return err } - s.stream.Delete(ctx, statusID) + s.Stream.Delete(ctx, statusID) return nil } @@ -370,15 +370,15 @@ func (s *surface) deleteStatusFromTimelines(ctx context.Context, statusID string // unpreparing it from all timelines, forcing it to be prepared again (with updated // stats, boost counts, etc) next time it's fetched by the timeline owner. This goes // both for the status itself, and for any boosts of the status. -func (s *surface) invalidateStatusFromTimelines(ctx context.Context, statusID string) { - if err := s.state.Timelines.Home.UnprepareItemFromAllTimelines(ctx, statusID); err != nil { +func (s *Surface) invalidateStatusFromTimelines(ctx context.Context, statusID string) { + if err := s.State.Timelines.Home.UnprepareItemFromAllTimelines(ctx, statusID); err != nil { log. WithContext(ctx). WithField("statusID", statusID). Errorf("error unpreparing status from home timelines: %v", err) } - if err := s.state.Timelines.List.UnprepareItemFromAllTimelines(ctx, statusID); err != nil { + if err := s.State.Timelines.List.UnprepareItemFromAllTimelines(ctx, statusID); err != nil { log. WithContext(ctx). WithField("statusID", statusID). @@ -392,14 +392,14 @@ func (s *surface) invalidateStatusFromTimelines(ctx context.Context, statusID st // Note that calling invalidateStatusFromTimelines takes care of the // state in general, we just need to do this for any streams that are // open right now. -func (s *surface) timelineStatusUpdate(ctx context.Context, status *gtsmodel.Status) error { +func (s *Surface) timelineStatusUpdate(ctx context.Context, status *gtsmodel.Status) error { // Ensure status fully populated; including account, mentions, etc. - if err := s.state.DB.PopulateStatus(ctx, status); err != nil { + if err := s.State.DB.PopulateStatus(ctx, status); err != nil { return gtserror.Newf("error populating status with id %s: %w", status.ID, err) } // Get all local followers of the account that posted the status. - follows, err := s.state.DB.GetAccountLocalFollowers(ctx, status.AccountID) + follows, err := s.State.DB.GetAccountLocalFollowers(ctx, status.AccountID) if err != nil { return gtserror.Newf("error getting local followers of account %s: %w", status.AccountID, err) } @@ -427,7 +427,7 @@ func (s *surface) timelineStatusUpdate(ctx context.Context, status *gtsmodel.Sta // slice of followers of the account that posted the given status, // pushing update messages into open list/home streams of each // follower. -func (s *surface) timelineStatusUpdateForFollowers( +func (s *Surface) timelineStatusUpdateForFollowers( ctx context.Context, status *gtsmodel.Status, follows []*gtsmodel.Follow, @@ -444,7 +444,7 @@ func (s *surface) timelineStatusUpdateForFollowers( // If it's not timelineable, we can just stop early, since lists // are prettymuch subsets of the home timeline, so if it shouldn't // appear there, it shouldn't appear in lists either. - timelineable, err := s.filter.StatusHomeTimelineable( + timelineable, err := s.Filter.StatusHomeTimelineable( ctx, follow.Account, status, ) if err != nil { @@ -485,7 +485,7 @@ func (s *surface) timelineStatusUpdateForFollowers( // listTimelineStatusUpdateForFollow pushes edits of the given status // into any eligible lists streams opened by the given follower. -func (s *surface) listTimelineStatusUpdateForFollow( +func (s *Surface) listTimelineStatusUpdateForFollow( ctx context.Context, status *gtsmodel.Status, follow *gtsmodel.Follow, @@ -499,7 +499,7 @@ func (s *surface) listTimelineStatusUpdateForFollow( // inclusion in the list. // Get every list entry that targets this follow's ID. - listEntries, err := s.state.DB.GetListEntriesForFollowID( + listEntries, err := s.State.DB.GetListEntriesForFollowID( // We only need the list IDs. gtscontext.SetBarebones(ctx), follow.ID, @@ -539,17 +539,17 @@ func (s *surface) listTimelineStatusUpdateForFollow( // timelineStatusUpdate streams the edited status to the user using the // given streamType. -func (s *surface) timelineStreamStatusUpdate( +func (s *Surface) timelineStreamStatusUpdate( ctx context.Context, account *gtsmodel.Account, status *gtsmodel.Status, streamType string, ) error { - apiStatus, err := s.converter.StatusToAPIStatus(ctx, status, account) + apiStatus, err := s.Converter.StatusToAPIStatus(ctx, status, account) if err != nil { err = gtserror.Newf("error converting status %s to frontend representation: %w", status.ID, err) return err } - s.stream.StatusUpdate(ctx, account, apiStatus, streamType) + s.Stream.StatusUpdate(ctx, account, apiStatus, streamType) return nil } diff --git a/internal/processing/workers/util.go b/internal/processing/workers/util.go index a01982e1a..780e5ca14 100644 --- a/internal/processing/workers/util.go +++ b/internal/processing/workers/util.go @@ -38,7 +38,7 @@ type utils struct { state *state.State media *media.Processor account *account.Processor - surface *surface + surface *Surface } // wipeStatus encapsulates common logic @@ -245,7 +245,7 @@ func (u *utils) incrementStatusesCount( status *gtsmodel.Status, ) error { // Lock on this account since we're changing stats. - unlock := u.state.AccountLocks.Lock(account.URI) + unlock := u.state.ProcessingLocks.Lock(account.URI) defer unlock() // Populate stats. @@ -276,7 +276,7 @@ func (u *utils) decrementStatusesCount( account *gtsmodel.Account, ) error { // Lock on this account since we're changing stats. - unlock := u.state.AccountLocks.Lock(account.URI) + unlock := u.state.ProcessingLocks.Lock(account.URI) defer unlock() // Populate stats. @@ -310,7 +310,7 @@ func (u *utils) incrementFollowersCount( account *gtsmodel.Account, ) error { // Lock on this account since we're changing stats. - unlock := u.state.AccountLocks.Lock(account.URI) + unlock := u.state.ProcessingLocks.Lock(account.URI) defer unlock() // Populate stats. @@ -339,7 +339,7 @@ func (u *utils) decrementFollowersCount( account *gtsmodel.Account, ) error { // Lock on this account since we're changing stats. - unlock := u.state.AccountLocks.Lock(account.URI) + unlock := u.state.ProcessingLocks.Lock(account.URI) defer unlock() // Populate stats. @@ -373,7 +373,7 @@ func (u *utils) incrementFollowingCount( account *gtsmodel.Account, ) error { // Lock on this account since we're changing stats. - unlock := u.state.AccountLocks.Lock(account.URI) + unlock := u.state.ProcessingLocks.Lock(account.URI) defer unlock() // Populate stats. @@ -402,7 +402,7 @@ func (u *utils) decrementFollowingCount( account *gtsmodel.Account, ) error { // Lock on this account since we're changing stats. - unlock := u.state.AccountLocks.Lock(account.URI) + unlock := u.state.ProcessingLocks.Lock(account.URI) defer unlock() // Populate stats. @@ -436,7 +436,7 @@ func (u *utils) incrementFollowRequestsCount( account *gtsmodel.Account, ) error { // Lock on this account since we're changing stats. - unlock := u.state.AccountLocks.Lock(account.URI) + unlock := u.state.ProcessingLocks.Lock(account.URI) defer unlock() // Populate stats. @@ -465,7 +465,7 @@ func (u *utils) decrementFollowRequestsCount( account *gtsmodel.Account, ) error { // Lock on this account since we're changing stats. - unlock := u.state.AccountLocks.Lock(account.URI) + unlock := u.state.ProcessingLocks.Lock(account.URI) defer unlock() // Populate stats. diff --git a/internal/processing/workers/workers.go b/internal/processing/workers/workers.go index 1159b61a5..6b4cc07a6 100644 --- a/internal/processing/workers/workers.go +++ b/internal/processing/workers/workers.go @@ -55,12 +55,12 @@ func New( // Init surface logic // wrapper struct. - surface := &surface{ - state: state, - converter: converter, - stream: stream, - filter: filter, - emailSender: emailSender, + surface := &Surface{ + State: state, + Converter: converter, + Stream: stream, + Filter: filter, + EmailSender: emailSender, } // Init shared util funcs. diff --git a/internal/processing/workers/workers_test.go b/internal/processing/workers/workers_test.go index 5e2a78bb6..f66190d75 100644 --- a/internal/processing/workers/workers_test.go +++ b/internal/processing/workers/workers_test.go @@ -22,6 +22,7 @@ import ( "github.com/stretchr/testify/suite" "github.com/superseriousbusiness/gotosocial/internal/cleaner" + "github.com/superseriousbusiness/gotosocial/internal/email" "github.com/superseriousbusiness/gotosocial/internal/filter/visibility" "github.com/superseriousbusiness/gotosocial/internal/gtsmodel" "github.com/superseriousbusiness/gotosocial/internal/oauth" @@ -68,6 +69,7 @@ type TestStructs struct { Processor *processing.Processor HTTPClient *testrig.MockHTTPClient TypeConverter *typeutils.Converter + EmailSender email.Sender } func (suite *WorkersTestSuite) SetupSuite() { @@ -168,6 +170,7 @@ func (suite *WorkersTestSuite) SetupTestStructs() *TestStructs { Processor: processor, HTTPClient: httpClient, TypeConverter: typeconverter, + EmailSender: emailSender, } } diff --git a/internal/state/state.go b/internal/state/state.go index f1eb5a9da..90683acd4 100644 --- a/internal/state/state.go +++ b/internal/state/state.go @@ -42,20 +42,21 @@ type State struct { // DB provides access to the database. DB db.DB - // FedLocks provides access to this state's - // mutex map of per URI federation locks. + // FedLocks provides access to this state's mutex + // map of per URI federation locks, intended for + // use in internal/federation functions. // // Used during account and status dereferencing, - // message processing in the FromFediAPI worker - // functions, and by the go-fed/activity library. + // and by the go-fed/activity library. FedLocks mutexes.MutexMap - // AccountLocks provides access to this state's + // ProcessingLocks provides access to this state's // mutex map of per URI locks, intended for use + // in internal/processing functions, for example // when updating accounts, migrating, approving - // or rejecting an account, changing stats, - // pinned statuses, etc. - AccountLocks mutexes.MutexMap + // or rejecting an account, changing stats or + // pinned statuses, creating notifs, etc. + ProcessingLocks mutexes.MutexMap // Storage provides access to the storage driver. Storage *storage.Driver