Skip to content

Commit

Permalink
rafact: enable all source update (indes#204)
Browse files Browse the repository at this point in the history
  • Loading branch information
indes committed Sep 21, 2022
1 parent 4b4327d commit d5a6e61
Show file tree
Hide file tree
Showing 11 changed files with 210 additions and 67 deletions.
4 changes: 2 additions & 2 deletions internal/bot/bot.go
Original file line number Diff line number Diff line change
Expand Up @@ -87,8 +87,8 @@ func setCommands() {
handler.NewSetUpdateInterval(Core),
handler.NewExport(Core),
handler.NewImport(),
handler.NewPauseAll(),
handler.NewActiveAll(),
handler.NewPauseAll(Core),
handler.NewActiveAll(Core),
handler.NewHelp(),
handler.NewVersion(),
}
Expand Down
33 changes: 23 additions & 10 deletions internal/bot/handler/active_all.go
Original file line number Diff line number Diff line change
@@ -1,18 +1,21 @@
package handler

import (
"context"
"fmt"

"github.com/indes/flowerss-bot/internal/bot/session"
"github.com/indes/flowerss-bot/internal/model"
tb "gopkg.in/telebot.v3"

"github.com/indes/flowerss-bot/internal/bot/session"
"github.com/indes/flowerss-bot/internal/core"
)

type ActiveAll struct {
core *core.Core
}

func NewActiveAll() *ActiveAll {
return &ActiveAll{}
func NewActiveAll(core *core.Core) *ActiveAll {
return &ActiveAll{core: core}
}

func (a *ActiveAll) Command() string {
Expand All @@ -30,19 +33,29 @@ func (a *ActiveAll) Handle(ctx tb.Context) error {
subscribeUserID = mentionChat.ID
}

if err := model.ActiveSourcesByUserID(subscribeUserID); err != nil {
return ctx.Reply("激活失败")
source, err := a.core.GetUserSubscribedSources(context.Background(), subscribeUserID)
if err != nil {
return ctx.Reply("系统错误")
}

for _, s := range source {
err := a.core.EnableSourceUpdate(context.Background(), s.ID)
if err != nil {
return ctx.Reply("激活失败")
}
}

reply := "订阅已全部开启"
if mentionChat != nil {
reply = fmt.Sprintf("频道 [%s](https://t.me/%s) 订阅已全部开启", mentionChat.Title, mentionChat.Username)
}

return ctx.Reply(reply, &tb.SendOptions{
DisableWebPagePreview: true,
ParseMode: tb.ModeMarkdown,
})
return ctx.Reply(
reply, &tb.SendOptions{
DisableWebPagePreview: true,
ParseMode: tb.ModeMarkdown,
},
)
}

func (a *ActiveAll) Middlewares() []tb.MiddlewareFunc {
Expand Down
34 changes: 23 additions & 11 deletions internal/bot/handler/pause_all.go
Original file line number Diff line number Diff line change
@@ -1,19 +1,21 @@
package handler

import (
"context"
"fmt"

"github.com/indes/flowerss-bot/internal/bot/session"
"github.com/indes/flowerss-bot/internal/model"

tb "gopkg.in/telebot.v3"

"github.com/indes/flowerss-bot/internal/bot/session"
"github.com/indes/flowerss-bot/internal/core"
)

type PauseAll struct {
core *core.Core
}

func NewPauseAll() *PauseAll {
return &PauseAll{}
func NewPauseAll(core *core.Core) *PauseAll {
return &PauseAll{core: core}
}

func (p *PauseAll) Command() string {
Expand All @@ -36,18 +38,28 @@ func (p *PauseAll) Handle(ctx tb.Context) error {
}
}

if err := model.PauseSourcesByUserID(subscribeUserID); err != nil {
return ctx.Reply("暂停失败")
source, err := p.core.GetUserSubscribedSources(context.Background(), subscribeUserID)
if err != nil {
return ctx.Reply("系统错误")
}

for _, s := range source {
err := p.core.DisableSourceUpdate(context.Background(), s.ID)
if err != nil {
return ctx.Reply("暂停失败")
}
}

reply := "订阅已全部暂停"
if channelChat != nil {
reply = fmt.Sprintf("频道 [%s](https://t.me/%s) 订阅已全部暂停", channelChat.Title, channelChat.Username)
}
return ctx.Send(reply, &tb.SendOptions{
DisableWebPagePreview: true,
ParseMode: tb.ModeMarkdown,
})
return ctx.Send(
reply, &tb.SendOptions{
DisableWebPagePreview: true,
ParseMode: tb.ModeMarkdown,
},
)
}

func (p *PauseAll) Middlewares() []tb.MiddlewareFunc {
Expand Down
27 changes: 27 additions & 0 deletions internal/core/core.go
Original file line number Diff line number Diff line change
Expand Up @@ -248,3 +248,30 @@ func (c *Core) SetSubscriptionInterval(ctx context.Context, userID int64, source
subscription.Interval = interval
return c.subscriptionStorage.UpdateSubscription(ctx, userID, sourceID, subscription)
}

// EnableSourceUpdate 开启订阅源更新
func (c *Core) EnableSourceUpdate(ctx context.Context, sourceID uint) error {
return c.ClearSourceErrorCount(ctx, sourceID)
}

// DisableSourceUpdate 关闭订阅源更新
func (c *Core) DisableSourceUpdate(ctx context.Context, sourceID uint) error {
source, err := c.GetSource(ctx, sourceID)
if err != nil {
return err
}

source.ErrorCount = config.ErrorThreshold + 1
return c.sourceStorage.UpdateSource(ctx, sourceID, source)
}

// ClearSourceErrorCount 清空订阅源错误计数
func (c *Core) ClearSourceErrorCount(ctx context.Context, sourceID uint) error {
source, err := c.GetSource(ctx, sourceID)
if err != nil {
return err
}

source.ErrorCount = 0
return c.sourceStorage.UpdateSource(ctx, sourceID, source)
}
90 changes: 90 additions & 0 deletions internal/core/core_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -349,3 +349,93 @@ func TestCore_GetSubscription(t *testing.T) {
},
)
}

func TestCore_DisableSourceUpdate(t *testing.T) {
c, s := getTestCore(t)
defer s.Ctrl.Finish()
ctx := context.Background()
sourceID := uint(1)

t.Run(
"get source err", func(t *testing.T) {
s.Source.EXPECT().GetSource(ctx, sourceID).Return(
nil, errors.New("err"),
).Times(1)
err := c.DisableSourceUpdate(ctx, sourceID)
assert.Error(t, err)
},
)

t.Run(
"update source err", func(t *testing.T) {
s.Source.EXPECT().GetSource(ctx, sourceID).Return(
&model.Source{}, nil,
).Times(1)

s.Source.EXPECT().UpdateSource(ctx, sourceID, gomock.Any()).Return(
errors.New("err"),
).Times(1)
err := c.DisableSourceUpdate(ctx, sourceID)
assert.Error(t, err)
},
)

t.Run(
"update source err", func(t *testing.T) {
s.Source.EXPECT().GetSource(ctx, sourceID).Return(
&model.Source{}, nil,
).Times(1)

s.Source.EXPECT().UpdateSource(ctx, sourceID, gomock.Any()).Return(
nil,
).Times(1)
err := c.DisableSourceUpdate(ctx, sourceID)
assert.Nil(t, err)
},
)
}

func TestCore_ClearSourceErrorCount(t *testing.T) {
c, s := getTestCore(t)
defer s.Ctrl.Finish()
ctx := context.Background()
sourceID := uint(1)

t.Run(
"get source err", func(t *testing.T) {
s.Source.EXPECT().GetSource(ctx, sourceID).Return(
nil, errors.New("err"),
).Times(1)
err := c.ClearSourceErrorCount(ctx, sourceID)
assert.Error(t, err)
},
)

t.Run(
"update source err", func(t *testing.T) {
s.Source.EXPECT().GetSource(ctx, sourceID).Return(
&model.Source{}, nil,
).Times(1)

s.Source.EXPECT().UpdateSource(ctx, sourceID, gomock.Any()).Return(
errors.New("err"),
).Times(1)
err := c.ClearSourceErrorCount(ctx, sourceID)
assert.Error(t, err)
},
)

t.Run(
"update source err", func(t *testing.T) {
s.Source.EXPECT().GetSource(ctx, sourceID).Return(
&model.Source{}, nil,
).Times(1)

s.Source.EXPECT().UpdateSource(ctx, sourceID, gomock.Any()).Return(
nil,
).Times(1)
err := c.ClearSourceErrorCount(ctx, sourceID)
assert.Nil(t, err)
},
)
}
38 changes: 0 additions & 38 deletions internal/model/source.go
Original file line number Diff line number Diff line change
Expand Up @@ -102,44 +102,6 @@ func (s *Source) NeedUpdate() bool {
}
}

func ActiveSourcesByUserID(userID int64) error {
subs, err := GetSubsByUserID(userID)

if err != nil {
return err
}

for _, sub := range subs {
var source Source
db.Where("id=?", sub.SourceID).First(&source)
if source.ID == sub.SourceID {
source.ErrorCount = 0
db.Save(&source)
}
}

return nil
}

func PauseSourcesByUserID(userID int64) error {
subs, err := GetSubsByUserID(userID)

if err != nil {
return err
}

for _, sub := range subs {
var source Source
db.Where("id=?", sub.SourceID).First(&source)
if source.ID == sub.SourceID {
source.ErrorCount = config.ErrorThreshold + 1
db.Save(&source)
}
}

return nil
}

func (s *Source) AddErrorCount() {
s.ErrorCount++
s.Save()
Expand Down
6 changes: 0 additions & 6 deletions internal/model/subscribe.go
Original file line number Diff line number Diff line change
Expand Up @@ -33,12 +33,6 @@ func GetSubscriberBySource(s *Source) []*Subscribe {
return subs
}

func GetSubsByUserID(userID int64) ([]Subscribe, error) {
var subs []Subscribe
db.Where("user_id=?", userID).Order("id").Find(&subs)
return subs, nil
}

func (s *Subscribe) ToggleNotification() error {
if s.EnableNotification != 1 {
s.EnableNotification = 1
Expand Down
14 changes: 14 additions & 0 deletions internal/storage/mock/storage_mock.go

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

11 changes: 11 additions & 0 deletions internal/storage/source.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ import (

"gorm.io/gorm"

"github.com/indes/flowerss-bot/internal/log"
"github.com/indes/flowerss-bot/internal/model"
)

Expand Down Expand Up @@ -61,3 +62,13 @@ func (s *SourceStorageImpl) Delete(ctx context.Context, id uint) error {
}
return nil
}

func (s *SourceStorageImpl) UpdateSource(ctx context.Context, sourceID uint, newSource *model.Source) error {
newSource.ID = sourceID
result := s.db.Save(newSource)
if result.Error != nil {
return result.Error
}
log.Debugf("update %d row, sourceID %d new %#v", result.RowsAffected, sourceID, newSource)
return nil
}
19 changes: 19 additions & 0 deletions internal/storage/source_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -43,4 +43,23 @@ func TestSourceStorageImpl(t *testing.T) {
},
)

t.Run(
"update source", func(t *testing.T) {
source := &model.Source{
ID: 1,
Link: "http://google.com",
Title: "title",
}
err := s.UpdateSource(ctx, source.ID, source)
assert.Nil(t, err)

source.Title = "title2"
err = s.UpdateSource(ctx, source.ID, source)
assert.Nil(t, err)

got, err := s.GetSource(ctx, source.ID)
assert.Nil(t, err)
assert.Equal(t, source.Title, got.Title)
},
)
}
1 change: 1 addition & 0 deletions internal/storage/storage.go
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@ type Source interface {
GetSource(ctx context.Context, id uint) (*model.Source, error)
GetSourceByURL(ctx context.Context, url string) (*model.Source, error)
Delete(ctx context.Context, id uint) error
UpdateSource(ctx context.Context, sourceID uint, newSource *model.Source) error
}

type SubscriptionSortType = int
Expand Down

0 comments on commit d5a6e61

Please sign in to comment.