diff --git a/api/api.go b/api/api.go index d512bace..46921d71 100644 --- a/api/api.go +++ b/api/api.go @@ -29,12 +29,14 @@ func Serve(router *mux.Router, store *storage.Storage, pool *worker.Pool, feedHa sr.HandleFunc("/users/{userID:[0-9]+}", handler.userByID).Methods(http.MethodGet) sr.HandleFunc("/users/{userID:[0-9]+}", handler.updateUser).Methods(http.MethodPut) sr.HandleFunc("/users/{userID:[0-9]+}", handler.removeUser).Methods(http.MethodDelete) + sr.HandleFunc("/users/{userID:[0-9]+}/mark-all-as-read", handler.markUserAsRead).Methods(http.MethodPut) sr.HandleFunc("/users/{username}", handler.userByUsername).Methods(http.MethodGet) sr.HandleFunc("/me", handler.currentUser).Methods(http.MethodGet) sr.HandleFunc("/categories", handler.createCategory).Methods(http.MethodPost) sr.HandleFunc("/categories", handler.getCategories).Methods(http.MethodGet) sr.HandleFunc("/categories/{categoryID}", handler.updateCategory).Methods(http.MethodPut) sr.HandleFunc("/categories/{categoryID}", handler.removeCategory).Methods(http.MethodDelete) + sr.HandleFunc("/categories/{categoryID}/mark-all-as-read", handler.markCategoryAsRead).Methods(http.MethodPut) sr.HandleFunc("/discover", handler.getSubscriptions).Methods(http.MethodPost) sr.HandleFunc("/feeds", handler.createFeed).Methods(http.MethodPost) sr.HandleFunc("/feeds", handler.getFeeds).Methods(http.MethodGet) @@ -44,6 +46,7 @@ func Serve(router *mux.Router, store *storage.Storage, pool *worker.Pool, feedHa sr.HandleFunc("/feeds/{feedID}", handler.updateFeed).Methods(http.MethodPut) sr.HandleFunc("/feeds/{feedID}", handler.removeFeed).Methods(http.MethodDelete) sr.HandleFunc("/feeds/{feedID}/icon", handler.feedIcon).Methods(http.MethodGet) + sr.HandleFunc("/feeds/{feedID}/mark-all-as-read", handler.markFeedAsRead).Methods(http.MethodPut) sr.HandleFunc("/export", handler.exportFeeds).Methods(http.MethodGet) sr.HandleFunc("/import", handler.importFeeds).Methods(http.MethodPost) sr.HandleFunc("/feeds/{feedID}/entries", handler.getFeedEntries).Methods(http.MethodGet) diff --git a/api/category.go b/api/category.go index 2222369c..9a298e62 100644 --- a/api/category.go +++ b/api/category.go @@ -7,6 +7,7 @@ package api // import "miniflux.app/api" import ( "errors" "net/http" + "time" "miniflux.app/http/request" "miniflux.app/http/response/json" @@ -64,6 +65,29 @@ func (h *handler) updateCategory(w http.ResponseWriter, r *http.Request) { json.Created(w, r, category) } +func (h *handler) markCategoryAsRead(w http.ResponseWriter, r *http.Request) { + userID := request.UserID(r) + categoryID := request.RouteInt64Param(r, "categoryID") + + category, err := h.store.Category(userID, categoryID) + if err != nil { + json.ServerError(w, r, err) + return + } + + if category == nil { + json.NotFound(w, r) + return + } + + if err = h.store.MarkCategoryAsRead(userID, categoryID, time.Now()); err != nil { + json.ServerError(w, r, err) + return + } + + json.NoContent(w, r) +} + func (h *handler) getCategories(w http.ResponseWriter, r *http.Request) { categories, err := h.store.Categories(request.UserID(r)) if err != nil { diff --git a/api/feed.go b/api/feed.go index 31a33a0e..5bca6fc4 100644 --- a/api/feed.go +++ b/api/feed.go @@ -7,6 +7,7 @@ package api // import "miniflux.app/api" import ( "errors" "net/http" + "time" "miniflux.app/http/request" "miniflux.app/http/response/json" @@ -142,6 +143,29 @@ func (h *handler) updateFeed(w http.ResponseWriter, r *http.Request) { json.Created(w, r, originalFeed) } +func (h *handler) markFeedAsRead(w http.ResponseWriter, r *http.Request) { + feedID := request.RouteInt64Param(r, "feedID") + userID := request.UserID(r) + + feed, err := h.store.FeedByID(userID, feedID) + if err != nil { + json.NotFound(w, r) + return + } + + if feed == nil { + json.NotFound(w, r) + return + } + + if err := h.store.MarkFeedAsRead(userID, feedID, time.Now()); err != nil { + json.ServerError(w, r, err) + return + } + + json.NoContent(w, r) +} + func (h *handler) getFeeds(w http.ResponseWriter, r *http.Request) { feeds, err := h.store.Feeds(request.UserID(r)) if err != nil { diff --git a/api/user.go b/api/user.go index 01b32a6e..f0593f22 100644 --- a/api/user.go +++ b/api/user.go @@ -92,6 +92,26 @@ func (h *handler) updateUser(w http.ResponseWriter, r *http.Request) { json.Created(w, r, originalUser) } +func (h *handler) markUserAsRead(w http.ResponseWriter, r *http.Request) { + userID := request.RouteInt64Param(r, "userID") + if userID != request.UserID(r) { + json.Forbidden(w, r) + return + } + + if _, err := h.store.UserByID(userID); err != nil { + json.NotFound(w, r) + return + } + + if err := h.store.MarkAllAsRead(userID); err != nil { + json.ServerError(w, r, err) + return + } + + json.NoContent(w, r) +} + func (h *handler) users(w http.ResponseWriter, r *http.Request) { if !request.IsAdminUser(r) { json.Forbidden(w, r) diff --git a/client/client.go b/client/client.go index e8b59f7a..60797f5c 100644 --- a/client/client.go +++ b/client/client.go @@ -133,6 +133,12 @@ func (c *Client) DeleteUser(userID int64) error { return c.request.Delete(fmt.Sprintf("/v1/users/%d", userID)) } +// MarkAllAsRead marks all unread entries as read for a given user. +func (c *Client) MarkAllAsRead(userID int64) error { + _, err := c.request.Put(fmt.Sprintf("/v1/users/%d/mark-all-as-read", userID), nil) + return err +} + // Discover try to find subscriptions from a website. func (c *Client) Discover(url string) (Subscriptions, error) { body, err := c.request.Post("/v1/discover", map[string]string{"url": url}) @@ -207,6 +213,12 @@ func (c *Client) UpdateCategory(categoryID int64, title string) (*Category, erro return category, nil } +// MarkCategoryAsRead marks all unread entries in a category as read. +func (c *Client) MarkCategoryAsRead(categoryID int64) error { + _, err := c.request.Put(fmt.Sprintf("/v1/categories/%d/mark-all-as-read", categoryID), nil) + return err +} + // DeleteCategory removes a category. func (c *Client) DeleteCategory(categoryID int64) error { return c.request.Delete(fmt.Sprintf("/v1/categories/%d", categoryID)) @@ -309,6 +321,12 @@ func (c *Client) UpdateFeed(feedID int64, feedChanges *FeedModification) (*Feed, return f, nil } +// MarkFeedAsRead marks all unread entries of the feed as read. +func (c *Client) MarkFeedAsRead(feedID int64) error { + _, err := c.request.Put(fmt.Sprintf("/v1/feeds/%d/mark-all-as-read", feedID), nil) + return err +} + // RefreshAllFeeds refreshes all feeds. func (c *Client) RefreshAllFeeds() error { _, err := c.request.Put(fmt.Sprintf("/v1/feeds/refresh"), nil) diff --git a/tests/category_test.go b/tests/category_test.go index 8aa3d6b1..9345cacc 100644 --- a/tests/category_test.go +++ b/tests/category_test.go @@ -8,6 +8,8 @@ package tests import ( "testing" + + miniflux "miniflux.app/client" ) func TestCreateCategory(t *testing.T) { @@ -81,6 +83,38 @@ func TestUpdateCategory(t *testing.T) { } } +func TestMarkCategoryAsRead(t *testing.T) { + client := createClient(t) + + feed, category := createFeed(t, client) + + results, err := client.FeedEntries(feed.ID, nil) + if err != nil { + t.Fatalf(`Failed to get entries: %v`, err) + } + if results.Total == 0 { + t.Fatalf(`Invalid number of entries: %d`, results.Total) + } + if results.Entries[0].Status != miniflux.EntryStatusUnread { + t.Fatalf(`Invalid entry status, got %q instead of %q`, results.Entries[0].Status, miniflux.EntryStatusUnread) + } + + if err := client.MarkCategoryAsRead(category.ID); err != nil { + t.Fatalf(`Failed to mark category as read: %v`, err) + } + + results, err = client.FeedEntries(feed.ID, nil) + if err != nil { + t.Fatalf(`Failed to get updated entries: %v`, err) + } + + for _, entry := range results.Entries { + if entry.Status != miniflux.EntryStatusRead { + t.Errorf(`Status for entry %d was %q instead of %q`, entry.ID, entry.Status, miniflux.EntryStatusRead) + } + } +} + func TestListCategories(t *testing.T) { categoryName := "My category" client := createClient(t) diff --git a/tests/feed_test.go b/tests/feed_test.go index e7841a4a..e4905df2 100644 --- a/tests/feed_test.go +++ b/tests/feed_test.go @@ -324,6 +324,38 @@ func TestUpdateFeedCategory(t *testing.T) { } } +func TestMarkFeedAsRead(t *testing.T) { + client := createClient(t) + + feed, _ := createFeed(t, client) + + results, err := client.FeedEntries(feed.ID, nil) + if err != nil { + t.Fatalf(`Failed to get entries: %v`, err) + } + if results.Total == 0 { + t.Fatalf(`Invalid number of entries: %d`, results.Total) + } + if results.Entries[0].Status != miniflux.EntryStatusUnread { + t.Fatalf(`Invalid entry status, got %q instead of %q`, results.Entries[0].Status, miniflux.EntryStatusUnread) + } + + if err := client.MarkFeedAsRead(feed.ID); err != nil { + t.Fatalf(`Failed to mark feed as read: %v`, err) + } + + results, err = client.FeedEntries(feed.ID, nil) + if err != nil { + t.Fatalf(`Failed to get updated entries: %v`, err) + } + + for _, entry := range results.Entries { + if entry.Status != miniflux.EntryStatusRead { + t.Errorf(`Status for entry %d was %q instead of %q`, entry.ID, entry.Status, miniflux.EntryStatusRead) + } + } +} + func TestDeleteFeed(t *testing.T) { client := createClient(t) feed, _ := createFeed(t, client) diff --git a/tests/user_test.go b/tests/user_test.go index d2ff3477..26f4c855 100644 --- a/tests/user_test.go +++ b/tests/user_test.go @@ -394,3 +394,65 @@ func TestCannotDeleteUserAsNonAdmin(t *testing.T) { t.Fatal(`A "Forbidden" error should be raised`) } } + +func TestMarkUserAsReadAsUser(t *testing.T) { + username := getRandomUsername() + adminClient := miniflux.New(testBaseURL, testAdminUsername, testAdminPassword) + user, err := adminClient.CreateUser(username, testStandardPassword, false) + if err != nil { + t.Fatal(err) + } + + client := miniflux.New(testBaseURL, username, testStandardPassword) + feed, _ := createFeed(t, client) + + results, err := client.FeedEntries(feed.ID, nil) + if err != nil { + t.Fatalf(`Failed to get entries: %v`, err) + } + if results.Total == 0 { + t.Fatalf(`Invalid number of entries: %d`, results.Total) + } + if results.Entries[0].Status != miniflux.EntryStatusUnread { + t.Fatalf(`Invalid entry status, got %q instead of %q`, results.Entries[0].Status, miniflux.EntryStatusUnread) + } + + if err := client.MarkAllAsRead(user.ID); err != nil { + t.Fatalf(`Failed to mark user's unread entries as read: %v`, err) + } + + results, err = client.FeedEntries(feed.ID, nil) + if err != nil { + t.Fatalf(`Failed to get updated entries: %v`, err) + } + + for _, entry := range results.Entries { + if entry.Status != miniflux.EntryStatusRead { + t.Errorf(`Status for entry %d was %q instead of %q`, entry.ID, entry.Status, miniflux.EntryStatusRead) + } + } +} + +func TestCannotMarkUserAsReadAsOtherUser(t *testing.T) { + username := getRandomUsername() + adminClient := miniflux.New(testBaseURL, testAdminUsername, testAdminPassword) + user1, err := adminClient.CreateUser(username, testStandardPassword, false) + if err != nil { + t.Fatal(err) + } + createFeed(t, miniflux.New(testBaseURL, username, testStandardPassword)) + + username2 := getRandomUsername() + if _, err = adminClient.CreateUser(username2, testStandardPassword, false); err != nil { + t.Fatal(err) + } + + client := miniflux.New(testBaseURL, username2, testStandardPassword) + err = client.MarkAllAsRead(user1.ID) + if err == nil { + t.Fatalf(`Non-admin users should not be able to mark another user as read`) + } + if err != miniflux.ErrForbidden { + t.Errorf(`A "Forbidden" error should be raised, got %q`, err) + } +}