Skip to content

Commit

Permalink
feat: 将v2版本的bot filter 转换成v3版本的middleware
Browse files Browse the repository at this point in the history
  • Loading branch information
indes committed Feb 26, 2022
1 parent 88950ac commit 4b5993b
Show file tree
Hide file tree
Showing 3 changed files with 32 additions and 96 deletions.
28 changes: 6 additions & 22 deletions internal/bot/bot.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,6 @@ package bot
import (
"time"

"github.com/indes/flowerss-bot/internal/bot/fsm"
"github.com/indes/flowerss-bot/internal/bot/handler"
"github.com/indes/flowerss-bot/internal/bot/middleware"
"github.com/indes/flowerss-bot/internal/config"
Expand All @@ -14,9 +13,6 @@ import (
)

var (
// UserState 用户状态,用于标示当前用户操作所在状态
UserState map[int64]fsm.UserStatus = make(map[int64]fsm.UserStatus)

// B bot
B *tb.Bot
)
Expand All @@ -25,38 +21,27 @@ func init() {
if config.RunMode == config.TestMode {
return
}
poller := &tb.LongPoller{Timeout: 10 * time.Second}
spamProtected := tb.NewMiddlewarePoller(
poller, func(upd *tb.Update) bool {
if !isUserAllowed(upd) {
// 检查用户是否可以使用bot
return false
}

if !CheckAdmin(upd) {
return false
}
return true
},
)
zap.S().Infow(
"init telegram bot",
"token", config.BotToken,
"endpoint", config.TelegramEndpoint,
)

// create bot
var err error
B, err = tb.NewBot(
tb.Settings{
URL: config.TelegramEndpoint,
Token: config.BotToken,
Poller: spamProtected,
Poller: &tb.LongPoller{Timeout: 10 * time.Second},
Client: util.HttpClient,
Verbose: true,
},
)
B.Use(middleware.PreLoadMentionChat(), middleware.IsChatAdmin())
B.Use(
middleware.UserFilter(),
middleware.PreLoadMentionChat(),
middleware.IsChatAdmin(),
)
if err != nil {
zap.S().Fatal(err)
return
Expand Down Expand Up @@ -124,5 +109,4 @@ func setCommands() {
if err := B.SetCommands(commands); err != nil {
zap.S().Errorw("set bot commands failed", "error", err.Error())
}

}
26 changes: 26 additions & 0 deletions internal/bot/middleware/user_filter.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,26 @@
package middleware

import (
"fmt"

tb "gopkg.in/telebot.v3"

"github.com/indes/flowerss-bot/internal/config"
)

func UserFilter() tb.MiddlewareFunc {
return func(next tb.HandlerFunc) tb.HandlerFunc {
return func(c tb.Context) error {
if len(config.AllowUsers) == 0 {
return next(c)
}
userID := c.Sender().ID
for _, allowUserID := range config.AllowUsers {
if allowUserID == userID {
return next(c)
}
}
return fmt.Errorf("deny user %d", userID)
}
}
}
74 changes: 0 additions & 74 deletions internal/bot/service.go
Original file line number Diff line number Diff line change
Expand Up @@ -102,80 +102,6 @@ func BroadcastSourceError(source *model.Source) {
}
}

// CheckAdmin check user is admin of group/channel
func CheckAdmin(upd *tb.Update) bool {

if upd.Message != nil {
if HasAdminType(upd.Message.Chat.Type) {
adminList, _ := B.AdminsOf(upd.Message.Chat)
for _, admin := range adminList {
if admin.User.ID == upd.Message.Sender.ID {
return true
}
}

return false
}

return true
} else if upd.Callback != nil {
if HasAdminType(upd.Callback.Message.Chat.Type) {
adminList, _ := B.AdminsOf(upd.Callback.Message.Chat)
for _, admin := range adminList {
if admin.User.ID == upd.Callback.Sender.ID {
return true
}
}

return false
}

return true
}
return false
}

// IsUserAllowed check user is allowed to use bot
func isUserAllowed(upd *tb.Update) bool {
if upd == nil {
return false
}

var userID int64

if upd.Message != nil {
userID = int64(upd.Message.Sender.ID)
} else if upd.Callback != nil {
userID = int64(upd.Callback.Sender.ID)
} else {
return false
}

if len(config.AllowUsers) == 0 {
return true
}

for _, allowUserID := range config.AllowUsers {
if allowUserID == userID {
return true
}
}

zap.S().Infow("user not allowed", "userID", userID)
return false
}

// HasAdminType check if the message is sent in the group/channel environment
func HasAdminType(t tb.ChatType) bool {
hasAdmin := []tb.ChatType{tb.ChatGroup, tb.ChatSuperGroup, tb.ChatChannel, tb.ChatChannelPrivate}
for _, n := range hasAdmin {
if t == n {
return true
}
}
return false
}

// GetMentionFromMessage get message mention
func GetMentionFromMessage(m *tb.Message) (mention string) {
if m.Text != "" {
Expand Down

0 comments on commit 4b5993b

Please sign in to comment.