From 8acc7aab4c254c4819f45e512b86cf5a4255091f Mon Sep 17 00:00:00 2001 From: Lunny Xiao Date: Fri, 29 Mar 2024 11:38:16 +0800 Subject: [PATCH] Refactor topic Find functions and add more tests for pagination (#30127) This also fixed #22238 --- models/repo/topic.go | 31 ++++++++++-------------- models/repo/topic_test.go | 14 +++++------ routers/api/v1/repo/topic.go | 7 +++--- routers/web/explore/topic.go | 2 +- routers/web/repo/view.go | 2 +- tests/integration/api_repo_topic_test.go | 22 ++++++++++++++++- tests/integration/repo_topic_test.go | 24 +++++++++++++++++- 7 files changed, 70 insertions(+), 32 deletions(-) diff --git a/models/repo/topic.go b/models/repo/topic.go index 79b13e320d..430a60f603 100644 --- a/models/repo/topic.go +++ b/models/repo/topic.go @@ -178,7 +178,7 @@ type FindTopicOptions struct { Keyword string } -func (opts *FindTopicOptions) toConds() builder.Cond { +func (opts *FindTopicOptions) ToConds() builder.Cond { cond := builder.NewCond() if opts.RepoID > 0 { cond = cond.And(builder.Eq{"repo_topic.repo_id": opts.RepoID}) @@ -191,29 +191,24 @@ func (opts *FindTopicOptions) toConds() builder.Cond { return cond } -// FindTopics retrieves the topics via FindTopicOptions -func FindTopics(ctx context.Context, opts *FindTopicOptions) ([]*Topic, int64, error) { - sess := db.GetEngine(ctx).Select("topic.*").Where(opts.toConds()) +func (opts *FindTopicOptions) ToOrders() string { orderBy := "topic.repo_count DESC" if opts.RepoID > 0 { - sess.Join("INNER", "repo_topic", "repo_topic.topic_id = topic.id") orderBy = "topic.name" // when render topics for a repo, it's better to sort them by name, to get consistent result } - if opts.PageSize != 0 && opts.Page != 0 { - sess = db.SetSessionPagination(sess, opts) - } - topics := make([]*Topic, 0, 10) - total, err := sess.OrderBy(orderBy).FindAndCount(&topics) - return topics, total, err + return orderBy } -// CountTopics counts the number of topics matching the FindTopicOptions -func CountTopics(ctx context.Context, opts *FindTopicOptions) (int64, error) { - sess := db.GetEngine(ctx).Where(opts.toConds()) - if opts.RepoID > 0 { - sess.Join("INNER", "repo_topic", "repo_topic.topic_id = topic.id") +func (opts *FindTopicOptions) ToJoins() []db.JoinFunc { + if opts.RepoID <= 0 { + return nil + } + return []db.JoinFunc{ + func(e db.Engine) error { + e.Join("INNER", "repo_topic", "repo_topic.topic_id = topic.id") + return nil + }, } - return sess.Count(new(Topic)) } // GetRepoTopicByName retrieves topic from name for a repo if it exist @@ -283,7 +278,7 @@ func DeleteTopic(ctx context.Context, repoID int64, topicName string) (*Topic, e // SaveTopics save topics to a repository func SaveTopics(ctx context.Context, repoID int64, topicNames ...string) error { - topics, _, err := FindTopics(ctx, &FindTopicOptions{ + topics, err := db.Find[Topic](ctx, &FindTopicOptions{ RepoID: repoID, }) if err != nil { diff --git a/models/repo/topic_test.go b/models/repo/topic_test.go index 2b609e6d66..1600896b6e 100644 --- a/models/repo/topic_test.go +++ b/models/repo/topic_test.go @@ -19,18 +19,18 @@ func TestAddTopic(t *testing.T) { assert.NoError(t, unittest.PrepareTestDatabase()) - topics, _, err := repo_model.FindTopics(db.DefaultContext, &repo_model.FindTopicOptions{}) + topics, err := db.Find[repo_model.Topic](db.DefaultContext, &repo_model.FindTopicOptions{}) assert.NoError(t, err) assert.Len(t, topics, totalNrOfTopics) - topics, total, err := repo_model.FindTopics(db.DefaultContext, &repo_model.FindTopicOptions{ + topics, total, err := db.FindAndCount[repo_model.Topic](db.DefaultContext, &repo_model.FindTopicOptions{ ListOptions: db.ListOptions{Page: 1, PageSize: 2}, }) assert.NoError(t, err) assert.Len(t, topics, 2) assert.EqualValues(t, 6, total) - topics, _, err = repo_model.FindTopics(db.DefaultContext, &repo_model.FindTopicOptions{ + topics, err = db.Find[repo_model.Topic](db.DefaultContext, &repo_model.FindTopicOptions{ RepoID: 1, }) assert.NoError(t, err) @@ -38,11 +38,11 @@ func TestAddTopic(t *testing.T) { assert.NoError(t, repo_model.SaveTopics(db.DefaultContext, 2, "golang")) repo2NrOfTopics := 1 - topics, _, err = repo_model.FindTopics(db.DefaultContext, &repo_model.FindTopicOptions{}) + topics, err = db.Find[repo_model.Topic](db.DefaultContext, &repo_model.FindTopicOptions{}) assert.NoError(t, err) assert.Len(t, topics, totalNrOfTopics) - topics, _, err = repo_model.FindTopics(db.DefaultContext, &repo_model.FindTopicOptions{ + topics, err = db.Find[repo_model.Topic](db.DefaultContext, &repo_model.FindTopicOptions{ RepoID: 2, }) assert.NoError(t, err) @@ -55,11 +55,11 @@ func TestAddTopic(t *testing.T) { assert.NoError(t, err) assert.EqualValues(t, 1, topic.RepoCount) - topics, _, err = repo_model.FindTopics(db.DefaultContext, &repo_model.FindTopicOptions{}) + topics, err = db.Find[repo_model.Topic](db.DefaultContext, &repo_model.FindTopicOptions{}) assert.NoError(t, err) assert.Len(t, topics, totalNrOfTopics) - topics, _, err = repo_model.FindTopics(db.DefaultContext, &repo_model.FindTopicOptions{ + topics, err = db.Find[repo_model.Topic](db.DefaultContext, &repo_model.FindTopicOptions{ RepoID: 2, }) assert.NoError(t, err) diff --git a/routers/api/v1/repo/topic.go b/routers/api/v1/repo/topic.go index 1d8e675bde..9852caa989 100644 --- a/routers/api/v1/repo/topic.go +++ b/routers/api/v1/repo/topic.go @@ -7,6 +7,7 @@ import ( "net/http" "strings" + "code.gitea.io/gitea/models/db" repo_model "code.gitea.io/gitea/models/repo" "code.gitea.io/gitea/modules/log" api "code.gitea.io/gitea/modules/structs" @@ -53,7 +54,7 @@ func ListTopics(ctx *context.APIContext) { RepoID: ctx.Repo.Repository.ID, } - topics, total, err := repo_model.FindTopics(ctx, opts) + topics, total, err := db.FindAndCount[repo_model.Topic](ctx, opts) if err != nil { ctx.InternalServerError(err) return @@ -172,7 +173,7 @@ func AddTopic(ctx *context.APIContext) { } // Prevent adding more topics than allowed to repo - count, err := repo_model.CountTopics(ctx, &repo_model.FindTopicOptions{ + count, err := db.Count[repo_model.Topic](ctx, &repo_model.FindTopicOptions{ RepoID: ctx.Repo.Repository.ID, }) if err != nil { @@ -287,7 +288,7 @@ func TopicSearch(ctx *context.APIContext) { ListOptions: utils.GetListOptions(ctx), } - topics, total, err := repo_model.FindTopics(ctx, opts) + topics, total, err := db.FindAndCount[repo_model.Topic](ctx, opts) if err != nil { ctx.InternalServerError(err) return diff --git a/routers/web/explore/topic.go b/routers/web/explore/topic.go index 95fecfe2b8..b4507ba28d 100644 --- a/routers/web/explore/topic.go +++ b/routers/web/explore/topic.go @@ -23,7 +23,7 @@ func TopicSearch(ctx *context.Context) { }, } - topics, total, err := repo_model.FindTopics(ctx, opts) + topics, total, err := db.FindAndCount[repo_model.Topic](ctx, opts) if err != nil { ctx.Error(http.StatusInternalServerError) return diff --git a/routers/web/repo/view.go b/routers/web/repo/view.go index 73a7be4e89..93e0f5bcbd 100644 --- a/routers/web/repo/view.go +++ b/routers/web/repo/view.go @@ -899,7 +899,7 @@ func renderLanguageStats(ctx *context.Context) { } func renderRepoTopics(ctx *context.Context) { - topics, _, err := repo_model.FindTopics(ctx, &repo_model.FindTopicOptions{ + topics, err := db.Find[repo_model.Topic](ctx, &repo_model.FindTopicOptions{ RepoID: ctx.Repo.Repository.ID, }) if err != nil { diff --git a/tests/integration/api_repo_topic_test.go b/tests/integration/api_repo_topic_test.go index c41bc4abb6..a10e159b78 100644 --- a/tests/integration/api_repo_topic_test.go +++ b/tests/integration/api_repo_topic_test.go @@ -26,14 +26,34 @@ func TestAPITopicSearch(t *testing.T) { TopicNames []*api.TopicResponse `json:"topics"` } + // search all topics + res := MakeRequest(t, NewRequest(t, "GET", searchURL.String()), http.StatusOK) + DecodeJSON(t, res, &topics) + assert.Len(t, topics.TopicNames, 6) + assert.EqualValues(t, "6", res.Header().Get("x-total-count")) + + // pagination search topics first page + topics.TopicNames = nil query := url.Values{"page": []string{"1"}, "limit": []string{"4"}} searchURL.RawQuery = query.Encode() - res := MakeRequest(t, NewRequest(t, "GET", searchURL.String()), http.StatusOK) + res = MakeRequest(t, NewRequest(t, "GET", searchURL.String()), http.StatusOK) DecodeJSON(t, res, &topics) assert.Len(t, topics.TopicNames, 4) assert.EqualValues(t, "6", res.Header().Get("x-total-count")) + // pagination search topics second page + topics.TopicNames = nil + query = url.Values{"page": []string{"2"}, "limit": []string{"4"}} + + searchURL.RawQuery = query.Encode() + res = MakeRequest(t, NewRequest(t, "GET", searchURL.String()), http.StatusOK) + DecodeJSON(t, res, &topics) + assert.Len(t, topics.TopicNames, 2) + assert.EqualValues(t, "6", res.Header().Get("x-total-count")) + + // add keyword search + query = url.Values{"page": []string{"1"}, "limit": []string{"4"}} query.Add("q", "topic") searchURL.RawQuery = query.Encode() res = MakeRequest(t, NewRequest(t, "GET", searchURL.String()), http.StatusOK) diff --git a/tests/integration/repo_topic_test.go b/tests/integration/repo_topic_test.go index 58fee8418f..f198397007 100644 --- a/tests/integration/repo_topic_test.go +++ b/tests/integration/repo_topic_test.go @@ -21,20 +21,42 @@ func TestTopicSearch(t *testing.T) { TopicNames []*api.TopicResponse `json:"topics"` } + // search all topics + res := MakeRequest(t, NewRequest(t, "GET", searchURL.String()), http.StatusOK) + DecodeJSON(t, res, &topics) + assert.Len(t, topics.TopicNames, 6) + assert.EqualValues(t, "6", res.Header().Get("x-total-count")) + + // pagination search topics + topics.TopicNames = nil query := url.Values{"page": []string{"1"}, "limit": []string{"4"}} searchURL.RawQuery = query.Encode() - res := MakeRequest(t, NewRequest(t, "GET", searchURL.String()), http.StatusOK) + res = MakeRequest(t, NewRequest(t, "GET", searchURL.String()), http.StatusOK) DecodeJSON(t, res, &topics) assert.Len(t, topics.TopicNames, 4) assert.EqualValues(t, "6", res.Header().Get("x-total-count")) + // second page + topics.TopicNames = nil + query = url.Values{"page": []string{"2"}, "limit": []string{"4"}} + + searchURL.RawQuery = query.Encode() + res = MakeRequest(t, NewRequest(t, "GET", searchURL.String()), http.StatusOK) + DecodeJSON(t, res, &topics) + assert.Len(t, topics.TopicNames, 2) + assert.EqualValues(t, "6", res.Header().Get("x-total-count")) + + // add keyword search + topics.TopicNames = nil + query = url.Values{"page": []string{"1"}, "limit": []string{"4"}} query.Add("q", "topic") searchURL.RawQuery = query.Encode() res = MakeRequest(t, NewRequest(t, "GET", searchURL.String()), http.StatusOK) DecodeJSON(t, res, &topics) assert.Len(t, topics.TopicNames, 2) + topics.TopicNames = nil query.Set("q", "database") searchURL.RawQuery = query.Encode() res = MakeRequest(t, NewRequest(t, "GET", searchURL.String()), http.StatusOK)