From 3a507c6f7bd27b692329870773ffe9503f51a90b Mon Sep 17 00:00:00 2001 From: indes Date: Wed, 14 Sep 2022 23:06:34 +0800 Subject: [PATCH] feat: set buttom (#200) --- internal/bot/bot.go | 6 +-- internal/bot/handler/set.go | 16 ++++---- internal/bot/handler/set_feed_tag.go | 33 +++++---------- .../handler/set_subscription_tag_button.go | 12 ++---- internal/bot/handler/set_update_interval.go | 33 +++++++-------- internal/core/core.go | 35 ++++++++++++++++ internal/core/core_test.go | 41 +++++++++++++++++++ internal/log/log.go | 4 ++ internal/model/subscribe.go | 25 ----------- internal/storage/mock/storage_mock.go | 29 +++++++++++++ internal/storage/storage.go | 4 ++ internal/storage/subscription.go | 30 ++++++++++++++ internal/storage/subscription_test.go | 33 ++++++++++++++- 13 files changed, 216 insertions(+), 85 deletions(-) diff --git a/internal/bot/bot.go b/internal/bot/bot.go index d3da3b7c..d0a20286 100644 --- a/internal/bot/bot.go +++ b/internal/bot/bot.go @@ -83,8 +83,8 @@ func setCommands() { handler.NewRemoveAllSubscription(), handler.NewOnDocument(B, Core), handler.NewSet(B, Core), - handler.NewSetFeedTag(), - handler.NewSetUpdateInterval(), + handler.NewSetFeedTag(Core), + handler.NewSetUpdateInterval(Core), handler.NewExport(Core), handler.NewImport(), handler.NewPauseAll(), @@ -100,7 +100,7 @@ func setCommands() { ButtonHandlers := []handler.ButtonHandler{ handler.NewRemoveAllSubscriptionButton(Core), handler.NewCancelRemoveAllSubscriptionButton(), - handler.NewSetFeedItemButton(B), + handler.NewSetFeedItemButton(B, Core), handler.NewRemoveSubscriptionItemButton(Core), handler.NewNotificationSwitchButton(B), handler.NewSetSubscriptionTagButton(B), diff --git a/internal/bot/handler/set.go b/internal/bot/handler/set.go index ce3ab2e5..feecc54a 100644 --- a/internal/bot/handler/set.go +++ b/internal/bot/handler/set.go @@ -8,6 +8,7 @@ import ( "strings" "text/template" + "github.com/spf13/cast" tb "gopkg.in/telebot.v3" "github.com/indes/flowerss-bot/internal/bot/chat" @@ -90,7 +91,7 @@ const ( SetFeedItemButtonUnique = "set_feed_item_btn" feedSettingTmpl = ` 订阅设置 -[id] {{ .sub.ID }} +[id] {{ .source.ID }} [标题] {{ .source.Title }} [Link] {{.source.Link }} [抓取更新] {{if ge .source.ErrorCount .Count }}暂停{{else if lt .source.ErrorCount .Count }}抓取中{{end}} @@ -102,11 +103,12 @@ const ( ) type SetFeedItemButton struct { - bot *tb.Bot + bot *tb.Bot + core *core.Core } -func NewSetFeedItemButton(bot *tb.Bot) *SetFeedItemButton { - return &SetFeedItemButton{bot: bot} +func NewSetFeedItemButton(bot *tb.Bot, core *core.Core) *SetFeedItemButton { + return &SetFeedItemButton{bot: bot, core: core} } func (r *SetFeedItemButton) CallbackUnique() string { @@ -135,13 +137,13 @@ func (r *SetFeedItemButton) Handle(ctx tb.Context) error { } } - sourceID, _ := strconv.Atoi(data[1]) - source, err := model.GetSourceById(uint(sourceID)) + sourceID := cast.ToUint(data[1]) + source, err := r.core.GetSource(context.Background(), sourceID) if err != nil { return ctx.Edit("找不到该订阅源") } - sub, err := model.GetSubscribeByUserIDAndSourceID(subscriberID, source.ID) + sub, err := r.core.GetSubscription(context.Background(), subscriberID, source.ID) if err != nil { return ctx.Edit("用户未订阅该rss") } diff --git a/internal/bot/handler/set_feed_tag.go b/internal/bot/handler/set_feed_tag.go index 71f3e5ac..cc2eea0f 100644 --- a/internal/bot/handler/set_feed_tag.go +++ b/internal/bot/handler/set_feed_tag.go @@ -1,21 +1,23 @@ package handler import ( - "strconv" + "context" "strings" + "github.com/spf13/cast" + tb "gopkg.in/telebot.v3" + "github.com/indes/flowerss-bot/internal/bot/message" "github.com/indes/flowerss-bot/internal/bot/session" - "github.com/indes/flowerss-bot/internal/model" - - tb "gopkg.in/telebot.v3" + "github.com/indes/flowerss-bot/internal/core" ) type SetFeedTag struct { + core *core.Core } -func NewSetFeedTag() *SetFeedTag { - return &SetFeedTag{} +func NewSetFeedTag(core *core.Core) *SetFeedTag { + return &SetFeedTag{core: core} } func (s *SetFeedTag) Command() string { @@ -38,36 +40,23 @@ func (s *SetFeedTag) Handle(ctx tb.Context) error { msg := s.getMessageWithoutMention(ctx) args := strings.Split(strings.TrimSpace(msg), " ") if len(args) < 1 { - return ctx.Reply("/setfeedtag [sub id] [tag1] [tag2] 设置订阅标签(最多设置三个Tag,以空格分割)") + return ctx.Reply("/setfeedtag [sourceID] [tag1] [tag2] 设置订阅标签(最多设置三个Tag,以空格分割)") } // 截短参数 if len(args) > 4 { args = args[:4] } - subID, err := strconv.Atoi(args[0]) - if err != nil { - return ctx.Reply("请输入正确的订阅id!") - } - - sub, err := model.GetSubscribeByID(subID) - if err != nil || sub == nil { - return ctx.Reply("请输入正确的订阅id!") - } + sourceID := cast.ToUint(args[0]) mentionChat, _ := session.GetMentionChatFromCtxStore(ctx) subscribeUserID := ctx.Chat().ID if mentionChat != nil { subscribeUserID = mentionChat.ID } - if subscribeUserID != sub.UserID { - return ctx.Reply("订阅记录与操作者id不一致") - } - - if err := sub.SetTag(args[1:]); err != nil { + if err := s.core.SetSubscriptionTag(context.Background(), subscribeUserID, sourceID, args[1:]); err != nil { return ctx.Reply("订阅标签设置失败!") - } return ctx.Reply("订阅标签设置成功!") } diff --git a/internal/bot/handler/set_subscription_tag_button.go b/internal/bot/handler/set_subscription_tag_button.go index c0ff89c9..bb84c7fd 100644 --- a/internal/bot/handler/set_subscription_tag_button.go +++ b/internal/bot/handler/set_subscription_tag_button.go @@ -5,10 +5,10 @@ import ( "strconv" "strings" + "github.com/spf13/cast" tb "gopkg.in/telebot.v3" "github.com/indes/flowerss-bot/internal/bot/chat" - "github.com/indes/flowerss-bot/internal/model" ) const ( @@ -55,17 +55,11 @@ func (b *SetSubscriptionTagButton) Handle(ctx tb.Context) error { return ctx.Send("无权限") } data := strings.Split(c.Data, ":") - ownID, _ := strconv.Atoi(data[0]) - sourceID, _ := strconv.Atoi(data[1]) - - sub, err := model.GetSubscribeByUserIDAndSourceID(int64(ownID), uint(sourceID)) - if err != nil { - return ctx.Send("系统错误,代码04") - } + sourceID := cast.ToUint(data[1]) msg := fmt.Sprintf( "请使用`/setfeedtag %d tags`命令为该订阅设置标签,tags为需要设置的标签,以空格分隔。(最多设置三个标签) \n"+ "例如:`/setfeedtag %d 科技 苹果`", - sub.ID, sub.ID, + sourceID, sourceID, ) return ctx.Edit(msg, &tb.SendOptions{ParseMode: tb.ModeMarkdown}) } diff --git a/internal/bot/handler/set_update_interval.go b/internal/bot/handler/set_update_interval.go index 1264d08d..c509cbe3 100644 --- a/internal/bot/handler/set_update_interval.go +++ b/internal/bot/handler/set_update_interval.go @@ -1,21 +1,25 @@ package handler import ( + "context" "strconv" "strings" + "github.com/spf13/cast" tb "gopkg.in/telebot.v3" "github.com/indes/flowerss-bot/internal/bot/message" "github.com/indes/flowerss-bot/internal/bot/session" - "github.com/indes/flowerss-bot/internal/model" + "github.com/indes/flowerss-bot/internal/core" + "github.com/indes/flowerss-bot/internal/log" ) type SetUpdateInterval struct { + core *core.Core } -func NewSetUpdateInterval() *SetUpdateInterval { - return &SetUpdateInterval{} +func NewSetUpdateInterval(core *core.Core) *SetUpdateInterval { + return &SetUpdateInterval{core: core} } func (s *SetUpdateInterval) Command() string { @@ -38,7 +42,7 @@ func (s *SetUpdateInterval) Handle(ctx tb.Context) error { msg := s.getMessageWithoutMention(ctx) args := strings.Split(strings.TrimSpace(msg), " ") if len(args) < 2 { - return ctx.Reply("/setinterval [interval] [sub id] 设置订阅刷新频率(可设置多个sub id,以空格分割)") + return ctx.Reply("/setinterval [interval] [sourceID] 设置订阅刷新频率(可设置多个sub id,以空格分割)") } interval, err := strconv.Atoi(args[0]) @@ -51,22 +55,15 @@ func (s *SetUpdateInterval) Handle(ctx tb.Context) error { if mentionChat != nil { subscribeUserID = mentionChat.ID } - for _, id := range args[1:] { - subID, err := strconv.Atoi(id) - if err != nil { - return ctx.Reply("请输入正确的订阅id!") - } - - sub, err := model.GetSubscribeByID(subID) - if err != nil || sub == nil { - return ctx.Reply("请输入正确的订阅id!") - } - if sub.UserID != subscribeUserID { - return ctx.Reply("订阅id与订阅者id不匹配!") + for _, id := range args[1:] { + sourceID := cast.ToUint(id) + if err := s.core.SetSubscriptionInterval( + context.Background(), subscribeUserID, sourceID, interval, + ); err != nil { + log.Errorf("SetSubscriptionInterval failed, %v", err) + return ctx.Reply("抓取频率设置失败!") } - - _ = sub.SetInterval(interval) } return ctx.Reply("抓取频率设置成功!") } diff --git a/internal/core/core.go b/internal/core/core.go index 5567b9c4..8980c837 100644 --- a/internal/core/core.go +++ b/internal/core/core.go @@ -3,6 +3,7 @@ package core import ( "context" "errors" + "strings" "sync" "gorm.io/driver/mysql" @@ -213,3 +214,37 @@ func (c *Core) UnsubscribeAllSource(ctx context.Context, userID int64) error { wg.Wait() return nil } + +// GetSubscription 获取订阅 +func (c *Core) GetSubscription(ctx context.Context, userID int64, sourceID uint) (*model.Subscribe, error) { + subscription, err := c.subscriptionStorage.GetSubscription(ctx, userID, sourceID) + if err != nil { + if err == storage.ErrRecordNotFound { + return nil, ErrSubscriptionNotExist + } + return nil, err + } + return subscription, nil +} + +// SetSubscriptionTag 设置订阅标签 +func (c *Core) SetSubscriptionTag(ctx context.Context, userID int64, sourceID uint, tags []string) error { + subscription, err := c.GetSubscription(ctx, userID, sourceID) + if err != nil { + return err + } + + subscription.Tag = "#" + strings.Join(tags, " #") + return c.subscriptionStorage.UpdateSubscription(ctx, userID, sourceID, subscription) +} + +// SetSubscriptionInterval +func (c *Core) SetSubscriptionInterval(ctx context.Context, userID int64, sourceID uint, interval int) error { + subscription, err := c.GetSubscription(ctx, userID, sourceID) + if err != nil { + return err + } + + subscription.Interval = interval + return c.subscriptionStorage.UpdateSubscription(ctx, userID, sourceID, subscription) +} diff --git a/internal/core/core_test.go b/internal/core/core_test.go index a619e41e..27083a31 100644 --- a/internal/core/core_test.go +++ b/internal/core/core_test.go @@ -308,3 +308,44 @@ func TestCore_GetSource(t *testing.T) { }, ) } + +func TestCore_GetSubscription(t *testing.T) { + c, s := getTestCore(t) + defer s.Ctrl.Finish() + ctx := context.Background() + userID := int64(101) + sourceID := uint(1) + + t.Run( + "subscription err", func(t *testing.T) { + s.Subscription.EXPECT().GetSubscription(ctx, userID, sourceID).Return( + nil, errors.New("err"), + ).Times(1) + got, err := c.GetSubscription(ctx, userID, sourceID) + assert.Error(t, err) + assert.Nil(t, got) + }, + ) + + t.Run( + "subscription not exist", func(t *testing.T) { + s.Subscription.EXPECT().GetSubscription(ctx, userID, sourceID).Return( + nil, storage.ErrRecordNotFound, + ).Times(1) + got, err := c.GetSubscription(ctx, userID, sourceID) + assert.Equal(t, ErrSubscriptionNotExist, err) + assert.Nil(t, got) + }, + ) + + t.Run( + "ok", func(t *testing.T) { + s.Subscription.EXPECT().GetSubscription(ctx, userID, sourceID).Return( + &model.Subscribe{}, nil, + ).Times(1) + got, err := c.GetSubscription(ctx, userID, sourceID) + assert.Nil(t, err) + assert.NotNil(t, got) + }, + ) +} diff --git a/internal/log/log.go b/internal/log/log.go index 472f5d8c..e666f024 100644 --- a/internal/log/log.go +++ b/internal/log/log.go @@ -71,3 +71,7 @@ func Fatal(args ...interface{}) { func Fatalf(template string, args ...interface{}) { globalLogger.Sugar().Fatalf(template, args...) } + +func Debugf(template string, args ...interface{}) { + globalLogger.Sugar().Debugf(template, args...) +} diff --git a/internal/model/subscribe.go b/internal/model/subscribe.go index 9809c65c..07cbe67a 100644 --- a/internal/model/subscribe.go +++ b/internal/model/subscribe.go @@ -2,7 +2,6 @@ package model import ( "errors" - "strings" "github.com/indes/flowerss-bot/internal/config" ) @@ -23,15 +22,6 @@ type Subscribe struct { EditTime } -func GetSubscribeByUserIDAndSourceID(userID int64, sourceID uint) (*Subscribe, error) { - var sub Subscribe - db.Where("user_id=? and source_id=?", userID, sourceID).First(&sub) - if sub.UserID != int64(userID) { - return nil, errors.New("未订阅该RSS源") - } - return &sub, nil -} - func GetSubscriberBySource(s *Source) []*Subscribe { if s == nil { return []*Subscribe{} @@ -86,21 +76,6 @@ func (s *Source) ToggleEnabled() error { return nil } -func (s *Subscribe) SetTag(tags []string) error { - defer s.Save() - - tagStr := strings.Join(tags, " #") - - s.Tag = "#" + tagStr - return nil -} - -func (s *Subscribe) SetInterval(interval int) error { - defer s.Save() - s.Interval = interval - return nil -} - func (s *Subscribe) Unsub() error { if s.ID == 0 { return errors.New("can't delete 0 subscribe") diff --git a/internal/storage/mock/storage_mock.go b/internal/storage/mock/storage_mock.go index c2eb8f74..21debd30 100644 --- a/internal/storage/mock/storage_mock.go +++ b/internal/storage/mock/storage_mock.go @@ -293,6 +293,21 @@ func (mr *MockSubscriptionMockRecorder) DeleteSubscription(ctx, userID, sourceID return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "DeleteSubscription", reflect.TypeOf((*MockSubscription)(nil).DeleteSubscription), ctx, userID, sourceID) } +// GetSubscription mocks base method. +func (m *MockSubscription) GetSubscription(ctx context.Context, userID int64, sourceID uint) (*model.Subscribe, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "GetSubscription", ctx, userID, sourceID) + ret0, _ := ret[0].(*model.Subscribe) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// GetSubscription indicates an expected call of GetSubscription. +func (mr *MockSubscriptionMockRecorder) GetSubscription(ctx, userID, sourceID interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetSubscription", reflect.TypeOf((*MockSubscription)(nil).GetSubscription), ctx, userID, sourceID) +} + // GetSubscriptionsBySourceID mocks base method. func (m *MockSubscription) GetSubscriptionsBySourceID(ctx context.Context, sourceID uint, opts *storage.GetSubscriptionsOptions) (*storage.GetSubscriptionsResult, error) { m.ctrl.T.Helper() @@ -352,6 +367,20 @@ func (mr *MockSubscriptionMockRecorder) SubscriptionExist(ctx, userID, sourceID return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SubscriptionExist", reflect.TypeOf((*MockSubscription)(nil).SubscriptionExist), ctx, userID, sourceID) } +// UpdateSubscription mocks base method. +func (m *MockSubscription) UpdateSubscription(ctx context.Context, userID int64, sourceID uint, newSubscription *model.Subscribe) error { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "UpdateSubscription", ctx, userID, sourceID, newSubscription) + ret0, _ := ret[0].(error) + return ret0 +} + +// UpdateSubscription indicates an expected call of UpdateSubscription. +func (mr *MockSubscriptionMockRecorder) UpdateSubscription(ctx, userID, sourceID, newSubscription interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "UpdateSubscription", reflect.TypeOf((*MockSubscription)(nil).UpdateSubscription), ctx, userID, sourceID, newSubscription) +} + // MockContent is a mock of Content interface. type MockContent struct { ctrl *gomock.Controller diff --git a/internal/storage/storage.go b/internal/storage/storage.go index cb30c41e..e698ed48 100644 --- a/internal/storage/storage.go +++ b/internal/storage/storage.go @@ -55,6 +55,7 @@ type Subscription interface { Storage AddSubscription(ctx context.Context, subscription *model.Subscribe) error SubscriptionExist(ctx context.Context, userID int64, sourceID uint) (bool, error) + GetSubscription(ctx context.Context, userID int64, sourceID uint) (*model.Subscribe, error) GetSubscriptionsByUserID( ctx context.Context, userID int64, opts *GetSubscriptionsOptions, ) (*GetSubscriptionsResult, error) @@ -64,6 +65,9 @@ type Subscription interface { CountSubscriptions(ctx context.Context) (int64, error) DeleteSubscription(ctx context.Context, userID int64, sourceID uint) (int64, error) CountSourceSubscriptions(ctx context.Context, sourceID uint) (int64, error) + UpdateSubscription( + ctx context.Context, userID int64, sourceID uint, newSubscription *model.Subscribe, + ) error } type Content interface { diff --git a/internal/storage/subscription.go b/internal/storage/subscription.go index 5016931a..78722db2 100644 --- a/internal/storage/subscription.go +++ b/internal/storage/subscription.go @@ -6,6 +6,7 @@ import ( "gorm.io/gorm" + "github.com/indes/flowerss-bot/internal/log" "github.com/indes/flowerss-bot/internal/model" ) @@ -38,6 +39,20 @@ func (s *SubscriptionStorageImpl) SubscriptionExist(ctx context.Context, userID return (count > 0), nil } +func (s *SubscriptionStorageImpl) GetSubscription(ctx context.Context, userID int64, sourceID uint) ( + *model.Subscribe, error, +) { + subscription := &model.Subscribe{} + result := s.db.WithContext(ctx).Where("user_id = ? and source_id = ?", userID, sourceID).First(subscription) + if result.Error != nil { + if errors.Is(result.Error, gorm.ErrRecordNotFound) { + return nil, ErrRecordNotFound + } + return nil, result.Error + } + return subscription, nil +} + func (s *SubscriptionStorageImpl) GetSubscriptionsByUserID( ctx context.Context, userID int64, opts *GetSubscriptionsOptions, ) (*GetSubscriptionsResult, error) { @@ -135,3 +150,18 @@ func (s *SubscriptionStorageImpl) CountSourceSubscriptions(ctx context.Context, } return count, nil } + +func (s *SubscriptionStorageImpl) UpdateSubscription( + ctx context.Context, userID int64, sourceID uint, newSubscription *model.Subscribe, +) error { + result := s.db.WithContext(ctx).Where( + "user_id = ? and source_id = ?", userID, sourceID, + ).Updates(newSubscription) + if result.Error != nil { + return result.Error + } + log.Debugf( + "update %d row, userID %d sourceID %d new %#v", result.RowsAffected, userID, sourceID, newSubscription, + ) + return nil +} diff --git a/internal/storage/subscription_test.go b/internal/storage/subscription_test.go index 0255475b..1c72e791 100644 --- a/internal/storage/subscription_test.go +++ b/internal/storage/subscription_test.go @@ -26,7 +26,6 @@ func TestSubscriptionStorageImpl(t *testing.T) { UserID: 101, EnableNotification: 1, }, - &model.Subscribe{ SourceID: 2, UserID: 100, @@ -58,6 +57,10 @@ func TestSubscriptionStorageImpl(t *testing.T) { assert.Nil(t, err) assert.True(t, exist) + subscription, err := s.GetSubscription(ctx, 101, 1) + assert.Nil(t, err) + assert.NotNil(t, subscription) + opt := &GetSubscriptionsOptions{ Count: 2, } @@ -100,6 +103,10 @@ func TestSubscriptionStorageImpl(t *testing.T) { assert.Nil(t, err) assert.False(t, exist) + subscription, err = s.GetSubscription(ctx, 101, 1) + assert.Error(t, err) + assert.Nil(t, subscription) + got, err = s.CountSubscriptions(ctx) assert.Nil(t, err) assert.Equal(t, int64(4), got) @@ -109,4 +116,28 @@ func TestSubscriptionStorageImpl(t *testing.T) { assert.Equal(t, int64(2), got) }, ) + + t.Run( + "update subscription", func(t *testing.T) { + sub := &model.Subscribe{ + ID: 10001, + SourceID: 1000, + UserID: 1002, + EnableNotification: 1, + } + err := s.UpdateSubscription(ctx, sub.UserID, sub.SourceID, sub) + assert.Nil(t, err) + + err = s.AddSubscription(ctx, sub) + assert.Nil(t, err) + + sub.Tag = "tag" + err = s.UpdateSubscription(ctx, sub.UserID, sub.SourceID, sub) + assert.Nil(t, err) + + subscription, err := s.GetSubscription(ctx, sub.UserID, sub.SourceID) + assert.Nil(t, err) + assert.Equal(t, sub.Tag, subscription.Tag) + }, + ) }