From 85fabbd7bbe2ecad2ea57e0c5de28d334a14b6bd Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E4=B8=96=E7=95=8C?= Date: Wed, 29 Nov 2023 11:56:04 +0800 Subject: [PATCH] Draft rule set support --- .gitignore | 1 + adapter/experimental.go | 28 +- adapter/router.go | 15 +- cmd/sing-box/cmd_format.go | 39 -- cmd/sing-box/cmd_geosite.go | 41 ++ cmd/sing-box/cmd_geosite_contains.go | 97 ++++ cmd/sing-box/cmd_geosite_export.go | 80 +++ cmd/sing-box/cmd_geosite_list.go | 50 ++ cmd/sing-box/cmd_geosite_matcher.go | 56 ++ cmd/sing-box/cmd_merge.go | 2 +- cmd/sing-box/cmd_rule_set.go | 14 + cmd/sing-box/cmd_rule_set_compile.go | 80 +++ cmd/sing-box/cmd_rule_set_format.go | 87 ++++ cmd/sing-box/cmd_tools.go | 6 +- cmd/sing-box/cmd_tools_connect.go | 2 +- common/dialer/router.go | 12 +- common/srs/binary.go | 483 ++++++++++++++++++ common/srs/ip_set.go | 116 +++++ constant/rule.go | 8 + experimental/cachefile/cache.go | 35 ++ experimental/cachefile/fakeip.go | 2 +- experimental/clashapi/proxies.go | 6 +- experimental/clashapi/server_resources.go | 6 +- .../clashapi/trafficontrol/tracker.go | 8 +- go.mod | 2 +- go.sum | 4 +- option/route.go | 1 + option/rule.go | 1 + option/rule_dns.go | 1 + option/rule_set.go | 230 +++++++++ option/types.go | 8 + route/router.go | 34 +- route/rule_abstract.go | 22 +- route/rule_default.go | 7 +- route/rule_dns.go | 7 +- route/rule_headless.go | 173 +++++++ route/rule_item_cidr.go | 17 +- route/rule_item_domain.go | 7 + route/rule_item_rule_set.go | 52 ++ route/rule_set.go | 22 + route/rule_set_local.go | 69 +++ route/rule_set_remote.go | 191 +++++++ 42 files changed, 2050 insertions(+), 72 deletions(-) create mode 100644 cmd/sing-box/cmd_geosite.go create mode 100644 cmd/sing-box/cmd_geosite_contains.go create mode 100644 cmd/sing-box/cmd_geosite_export.go create mode 100644 cmd/sing-box/cmd_geosite_list.go create mode 100644 cmd/sing-box/cmd_geosite_matcher.go create mode 100644 cmd/sing-box/cmd_rule_set.go create mode 100644 cmd/sing-box/cmd_rule_set_compile.go create mode 100644 cmd/sing-box/cmd_rule_set_format.go create mode 100644 common/srs/binary.go create mode 100644 common/srs/ip_set.go create mode 100644 option/rule_set.go create mode 100644 route/rule_headless.go create mode 100644 route/rule_item_rule_set.go create mode 100644 route/rule_set.go create mode 100644 route/rule_set_local.go create mode 100644 route/rule_set_remote.go diff --git a/.gitignore b/.gitignore index 6630f428ce..55bdab3a0c 100644 --- a/.gitignore +++ b/.gitignore @@ -1,6 +1,7 @@ /.idea/ /vendor/ /*.json +/*.srs /*.db /site/ /bin/ diff --git a/adapter/experimental.go b/adapter/experimental.go index ac523e4e74..1c62f23e47 100644 --- a/adapter/experimental.go +++ b/adapter/experimental.go @@ -1,8 +1,11 @@ package adapter import ( + "bytes" "context" + "encoding/binary" "net" + "time" "github.com/sagernet/sing-box/common/urltest" N "github.com/sagernet/sing/common/network" @@ -23,6 +26,7 @@ type CacheFile interface { PreStarter StoreFakeIP() bool + FakeIPStorage LoadMode() string StoreMode(mode string) error @@ -30,7 +34,29 @@ type CacheFile interface { StoreSelected(group string, selected string) error LoadGroupExpand(group string) (isExpand bool, loaded bool) StoreGroupExpand(group string, expand bool) error - FakeIPStorage + LoadRuleSet(tag string) *SavedRuleSet + SaveRuleSet(tag string, set *SavedRuleSet) error +} + +type SavedRuleSet struct { + Content []byte + LastUpdated time.Time +} + +func (s *SavedRuleSet) MarshalBinary() ([]byte, error) { + var buffer bytes.Buffer + err := binary.Write(&buffer, binary.BigEndian, s.LastUpdated.Unix()) + if err != nil { + return nil, err + } + buffer.Write(s.Content) + return buffer.Bytes(), nil +} + +func (s *SavedRuleSet) UnmarshalBinary(data []byte) error { + s.LastUpdated = time.Unix(int64(binary.BigEndian.Uint64(data)), 0) + s.Content = data[8:] + return nil } type Tracker interface { diff --git a/adapter/router.go b/adapter/router.go index e4c3904d51..ca4d6547c4 100644 --- a/adapter/router.go +++ b/adapter/router.go @@ -18,7 +18,7 @@ type Router interface { Outbounds() []Outbound Outbound(tag string) (Outbound, bool) - DefaultOutbound(network string) Outbound + DefaultOutbound(network string) (Outbound, error) FakeIPStore() FakeIPStore @@ -27,6 +27,8 @@ type Router interface { GeoIPReader() *geoip.Reader LoadGeosite(code string) (Rule, error) + RuleSet(tag string) (RuleSet, bool) + Exchange(ctx context.Context, message *mdns.Msg) (*mdns.Msg, error) Lookup(ctx context.Context, domain string, strategy dns.DomainStrategy) ([]netip.Addr, error) LookupDefault(ctx context.Context, domain string) ([]netip.Addr, error) @@ -61,11 +63,15 @@ func RouterFromContext(ctx context.Context) Router { return service.FromContext[Router](ctx) } +type HeadlessRule interface { + Match(metadata *InboundContext) bool +} + type Rule interface { + HeadlessRule Service Type() string UpdateGeosite() error - Match(metadata *InboundContext) bool Outbound() string String() string } @@ -76,6 +82,11 @@ type DNSRule interface { RewriteTTL() *uint32 } +type RuleSet interface { + HeadlessRule + Service +} + type InterfaceUpdateListener interface { InterfaceUpdated() } diff --git a/cmd/sing-box/cmd_format.go b/cmd/sing-box/cmd_format.go index 10a5497cd4..c5e939e4a9 100644 --- a/cmd/sing-box/cmd_format.go +++ b/cmd/sing-box/cmd_format.go @@ -7,7 +7,6 @@ import ( "github.com/sagernet/sing-box/common/json" "github.com/sagernet/sing-box/log" - "github.com/sagernet/sing-box/option" E "github.com/sagernet/sing/common/exceptions" "github.com/spf13/cobra" @@ -69,41 +68,3 @@ func format() error { } return nil } - -func formatOne(configPath string) error { - configContent, err := os.ReadFile(configPath) - if err != nil { - return E.Cause(err, "read config") - } - var options option.Options - err = options.UnmarshalJSON(configContent) - if err != nil { - return E.Cause(err, "decode config") - } - buffer := new(bytes.Buffer) - encoder := json.NewEncoder(buffer) - encoder.SetIndent("", " ") - err = encoder.Encode(options) - if err != nil { - return E.Cause(err, "encode config") - } - if !commandFormatFlagWrite { - os.Stdout.WriteString(buffer.String() + "\n") - return nil - } - if bytes.Equal(configContent, buffer.Bytes()) { - return nil - } - output, err := os.Create(configPath) - if err != nil { - return E.Cause(err, "open output") - } - _, err = output.Write(buffer.Bytes()) - output.Close() - if err != nil { - return E.Cause(err, "write output") - } - outputPath, _ := filepath.Abs(configPath) - os.Stderr.WriteString(outputPath + "\n") - return nil -} diff --git a/cmd/sing-box/cmd_geosite.go b/cmd/sing-box/cmd_geosite.go new file mode 100644 index 0000000000..95db935797 --- /dev/null +++ b/cmd/sing-box/cmd_geosite.go @@ -0,0 +1,41 @@ +package main + +import ( + "github.com/sagernet/sing-box/common/geosite" + "github.com/sagernet/sing-box/log" + E "github.com/sagernet/sing/common/exceptions" + + "github.com/spf13/cobra" +) + +var ( + commandGeoSiteFlagFile string + geositeReader *geosite.Reader + geositeCodeList []string +) + +var commandGeoSite = &cobra.Command{ + Use: "geosite", + Short: "Geosite tools", + PersistentPreRun: func(cmd *cobra.Command, args []string) { + err := geositePreRun() + if err != nil { + log.Fatal(err) + } + }, +} + +func init() { + commandGeoSite.PersistentFlags().StringVarP(&commandGeoSiteFlagFile, "file", "f", "geosite.db", "geosite file") + mainCommand.AddCommand(commandGeoSite) +} + +func geositePreRun() error { + reader, codeList, err := geosite.Open(commandGeoSiteFlagFile) + if err != nil { + return E.Cause(err, "open geosite file") + } + geositeReader = reader + geositeCodeList = codeList + return nil +} diff --git a/cmd/sing-box/cmd_geosite_contains.go b/cmd/sing-box/cmd_geosite_contains.go new file mode 100644 index 0000000000..ca5bf7b9ac --- /dev/null +++ b/cmd/sing-box/cmd_geosite_contains.go @@ -0,0 +1,97 @@ +package main + +import ( + "os" + "sort" + + "github.com/sagernet/sing-box/log" + E "github.com/sagernet/sing/common/exceptions" + + "github.com/spf13/cobra" +) + +var commandGeositeContains = &cobra.Command{ + Use: "contains [category] ", + Short: "Check if a domain is in the geosite", + Args: cobra.RangeArgs(1, 2), + Run: func(cmd *cobra.Command, args []string) { + var ( + source string + target string + ) + switch len(args) { + case 1: + target = args[0] + case 2: + source = args[0] + target = args[1] + } + err := geositeContains(source, target) + if err != nil { + log.Fatal(err) + } + }, +} + +func init() { + commandGeoSite.AddCommand(commandGeositeContains) +} + +func geositeContains(source string, target string) error { + var sourceMatcherList []struct { + code string + matcher *searchGeositeMatcher + } + if source != "" { + sourceSet, err := geositeReader.Read(source) + if err != nil { + return err + } + sourceMatcher, err := newSearchGeositeMatcher(sourceSet) + if err != nil { + return E.Cause(err, "compile code: "+source) + } + sourceMatcherList = []struct { + code string + matcher *searchGeositeMatcher + }{ + { + code: source, + matcher: sourceMatcher, + }, + } + + } else { + for _, code := range geositeCodeList { + sourceSet, err := geositeReader.Read(code) + if err != nil { + return err + } + sourceMatcher, err := newSearchGeositeMatcher(sourceSet) + if err != nil { + return E.Cause(err, "compile code: "+code) + } + sourceMatcherList = append(sourceMatcherList, struct { + code string + matcher *searchGeositeMatcher + }{ + code: code, + matcher: sourceMatcher, + }) + } + } + sort.SliceStable(sourceMatcherList, func(i, j int) bool { + return sourceMatcherList[i].code < sourceMatcherList[j].code + }) + + for _, matcherItem := range sourceMatcherList { + if matchRule := matcherItem.matcher.Match(target); matchRule != "" { + os.Stdout.WriteString("Match code (") + os.Stdout.WriteString(matcherItem.code) + os.Stdout.WriteString(") ") + os.Stdout.WriteString(matchRule) + os.Stdout.WriteString("\n") + } + } + return nil +} diff --git a/cmd/sing-box/cmd_geosite_export.go b/cmd/sing-box/cmd_geosite_export.go new file mode 100644 index 0000000000..82915906de --- /dev/null +++ b/cmd/sing-box/cmd_geosite_export.go @@ -0,0 +1,80 @@ +package main + +import ( + "encoding/json" + "io" + "os" + + "github.com/sagernet/sing-box/common/geosite" + C "github.com/sagernet/sing-box/constant" + "github.com/sagernet/sing-box/log" + "github.com/sagernet/sing-box/option" + + "github.com/spf13/cobra" +) + +var commandGeositeExportOutput string + +const commandGeositeExportDefaultOutput = "geosite-.json" + +var commandGeositeExport = &cobra.Command{ + Use: "export ", + Short: "Export geosite category as rule-set", + Run: func(cmd *cobra.Command, args []string) { + err := geositeExport(args[0]) + if err != nil { + log.Fatal(err) + } + }, +} + +func init() { + commandGeositeExport.Flags().StringVarP(&commandGeositeExportOutput, "output", "o", commandGeositeExportDefaultOutput, "Output path") + commandGeoSite.AddCommand(commandGeositeExport) +} + +func geositeExport(category string) error { + sourceSet, err := geositeReader.Read(category) + if err != nil { + return err + } + var ( + outputFile *os.File + outputWriter io.Writer + ) + if commandGeositeExportOutput == "stdout" { + outputWriter = os.Stdout + } else if commandGeositeExportOutput == commandGeositeExportDefaultOutput { + outputFile, err = os.Create("geosite-" + category + ".json") + if err != nil { + return err + } + defer outputFile.Close() + outputWriter = outputFile + } else { + outputFile, err = os.Create(commandGeositeExportOutput) + if err != nil { + return err + } + defer outputFile.Close() + outputWriter = outputFile + } + + encoder := json.NewEncoder(outputWriter) + encoder.SetIndent("", " ") + var headlessRule option.DefaultHeadlessRule + defaultRule := geosite.Compile(sourceSet) + headlessRule.Domain = defaultRule.Domain + headlessRule.DomainSuffix = defaultRule.DomainSuffix + headlessRule.DomainKeyword = defaultRule.DomainKeyword + headlessRule.DomainRegex = defaultRule.DomainRegex + var plainRuleSet option.PlainRuleSetCompat + plainRuleSet.Version = C.RuleSetVersion1 + plainRuleSet.Options.Rules = []option.HeadlessRule{ + { + Type: C.RuleTypeDefault, + DefaultOptions: headlessRule, + }, + } + return encoder.Encode(plainRuleSet) +} diff --git a/cmd/sing-box/cmd_geosite_list.go b/cmd/sing-box/cmd_geosite_list.go new file mode 100644 index 0000000000..cedb7adfd2 --- /dev/null +++ b/cmd/sing-box/cmd_geosite_list.go @@ -0,0 +1,50 @@ +package main + +import ( + "os" + "sort" + + "github.com/sagernet/sing-box/log" + F "github.com/sagernet/sing/common/format" + + "github.com/spf13/cobra" +) + +var commandGeositeList = &cobra.Command{ + Use: "list ", + Short: "List geosite categories", + Run: func(cmd *cobra.Command, args []string) { + err := geositeList() + if err != nil { + log.Fatal(err) + } + }, +} + +func init() { + commandGeoSite.AddCommand(commandGeositeList) +} + +func geositeList() error { + var geositeEntry []struct { + category string + items int + } + for _, category := range geositeCodeList { + sourceSet, err := geositeReader.Read(category) + if err != nil { + return err + } + geositeEntry = append(geositeEntry, struct { + category string + items int + }{category, len(sourceSet)}) + } + sort.SliceStable(geositeEntry, func(i, j int) bool { + return geositeEntry[i].items < geositeEntry[j].items + }) + for _, entry := range geositeEntry { + os.Stdout.WriteString(F.ToString(entry.category, " (", entry.items, ")\n")) + } + return nil +} diff --git a/cmd/sing-box/cmd_geosite_matcher.go b/cmd/sing-box/cmd_geosite_matcher.go new file mode 100644 index 0000000000..791dba2499 --- /dev/null +++ b/cmd/sing-box/cmd_geosite_matcher.go @@ -0,0 +1,56 @@ +package main + +import ( + "regexp" + "strings" + + "github.com/sagernet/sing-box/common/geosite" +) + +type searchGeositeMatcher struct { + domainMap map[string]bool + suffixList []string + keywordList []string + regexList []string +} + +func newSearchGeositeMatcher(items []geosite.Item) (*searchGeositeMatcher, error) { + options := geosite.Compile(items) + domainMap := make(map[string]bool) + for _, domain := range options.Domain { + domainMap[domain] = true + } + rule := &searchGeositeMatcher{ + domainMap: domainMap, + suffixList: options.DomainSuffix, + keywordList: options.DomainKeyword, + regexList: options.DomainRegex, + } + return rule, nil +} + +func (r *searchGeositeMatcher) Match(domain string) string { + if r.domainMap[domain] { + return "domain=" + domain + } + for _, suffix := range r.suffixList { + if strings.HasSuffix(domain, suffix) { + return "domain_suffix=" + suffix + } + } + for _, keyword := range r.keywordList { + if strings.Contains(domain, keyword) { + return "domain_keyword=" + keyword + } + } + for _, regexStr := range r.regexList { + regex, err := regexp.Compile(regexStr) + if err != nil { + continue + } + if regex.MatchString(domain) { + return "domain_regex=" + regexStr + } + } + return "" +} diff --git a/cmd/sing-box/cmd_merge.go b/cmd/sing-box/cmd_merge.go index 0aff750182..4fb07b8688 100644 --- a/cmd/sing-box/cmd_merge.go +++ b/cmd/sing-box/cmd_merge.go @@ -18,7 +18,7 @@ import ( ) var commandMerge = &cobra.Command{ - Use: "merge [output]", + Use: "merge ", Short: "Merge configurations", Run: func(cmd *cobra.Command, args []string) { err := merge(args[0]) diff --git a/cmd/sing-box/cmd_rule_set.go b/cmd/sing-box/cmd_rule_set.go new file mode 100644 index 0000000000..f4112a087b --- /dev/null +++ b/cmd/sing-box/cmd_rule_set.go @@ -0,0 +1,14 @@ +package main + +import ( + "github.com/spf13/cobra" +) + +var commandRuleSet = &cobra.Command{ + Use: "rule-set", + Short: "Manage rule sets", +} + +func init() { + mainCommand.AddCommand(commandRuleSet) +} diff --git a/cmd/sing-box/cmd_rule_set_compile.go b/cmd/sing-box/cmd_rule_set_compile.go new file mode 100644 index 0000000000..de318095ac --- /dev/null +++ b/cmd/sing-box/cmd_rule_set_compile.go @@ -0,0 +1,80 @@ +package main + +import ( + "io" + "os" + "strings" + + "github.com/sagernet/sing-box/common/json" + "github.com/sagernet/sing-box/common/srs" + "github.com/sagernet/sing-box/log" + "github.com/sagernet/sing-box/option" + + "github.com/spf13/cobra" +) + +var flagRuleSetCompileOutput string + +const flagRuleSetCompileDefaultOutput = ".srs" + +var commandRuleSetCompile = &cobra.Command{ + Use: "compile [source-path]", + Short: "Compile rule-set json to binary", + Args: cobra.ExactArgs(1), + Run: func(cmd *cobra.Command, args []string) { + err := compileRuleSet(args[0]) + if err != nil { + log.Fatal(err) + } + }, +} + +func init() { + commandRuleSet.AddCommand(commandRuleSetCompile) + commandRuleSetCompile.Flags().StringVarP(&flagRuleSetCompileOutput, "output", "o", flagRuleSetCompileDefaultOutput, "Output file") +} + +func compileRuleSet(sourcePath string) error { + var ( + reader io.Reader + err error + ) + if sourcePath == "stdin" { + reader = os.Stdin + } else { + reader, err = os.Open(sourcePath) + if err != nil { + return err + } + } + decoder := json.NewDecoder(json.NewCommentFilter(reader)) + decoder.DisallowUnknownFields() + var plainRuleSet option.PlainRuleSetCompat + err = decoder.Decode(&plainRuleSet) + if err != nil { + return err + } + ruleSet := plainRuleSet.Upgrade() + var outputPath string + if flagRuleSetCompileOutput == flagRuleSetCompileDefaultOutput { + if strings.HasSuffix(sourcePath, ".json") { + outputPath = sourcePath[:len(sourcePath)-5] + ".srs" + } else { + outputPath = sourcePath + ".srs" + } + } else { + outputPath = flagRuleSetCompileOutput + } + outputFile, err := os.Create(outputPath) + if err != nil { + return err + } + err = srs.Write(outputFile, ruleSet) + if err != nil { + outputFile.Close() + os.Remove(outputPath) + return err + } + outputFile.Close() + return nil +} diff --git a/cmd/sing-box/cmd_rule_set_format.go b/cmd/sing-box/cmd_rule_set_format.go new file mode 100644 index 0000000000..dc3ee6aabd --- /dev/null +++ b/cmd/sing-box/cmd_rule_set_format.go @@ -0,0 +1,87 @@ +package main + +import ( + "bytes" + "io" + "os" + "path/filepath" + + "github.com/sagernet/sing-box/common/json" + "github.com/sagernet/sing-box/log" + "github.com/sagernet/sing-box/option" + E "github.com/sagernet/sing/common/exceptions" + + "github.com/spf13/cobra" +) + +var commandRuleSetFormatFlagWrite bool + +var commandRuleSetFormat = &cobra.Command{ + Use: "format ", + Short: "Format rule-set json", + Args: cobra.ExactArgs(1), + Run: func(cmd *cobra.Command, args []string) { + err := formatRuleSet(args[0]) + if err != nil { + log.Fatal(err) + } + }, +} + +func init() { + commandRuleSetFormat.Flags().BoolVarP(&commandRuleSetFormatFlagWrite, "write", "w", false, "write result to (source) file instead of stdout") + commandRuleSet.AddCommand(commandRuleSetFormat) +} + +func formatRuleSet(sourcePath string) error { + var ( + reader io.Reader + err error + ) + if sourcePath == "stdin" { + reader = os.Stdin + } else { + reader, err = os.Open(sourcePath) + if err != nil { + return err + } + } + content, err := io.ReadAll(reader) + if err != nil { + return err + } + decoder := json.NewDecoder(json.NewCommentFilter(bytes.NewReader(content))) + decoder.DisallowUnknownFields() + var plainRuleSet option.PlainRuleSetCompat + err = decoder.Decode(&plainRuleSet) + if err != nil { + return err + } + ruleSet := plainRuleSet.Upgrade() + buffer := new(bytes.Buffer) + encoder := json.NewEncoder(buffer) + encoder.SetIndent("", " ") + err = encoder.Encode(ruleSet) + if err != nil { + return E.Cause(err, "encode config") + } + outputPath, _ := filepath.Abs(sourcePath) + if !commandRuleSetFormatFlagWrite || sourcePath == "stdin" { + os.Stdout.WriteString(buffer.String() + "\n") + return nil + } + if bytes.Equal(content, buffer.Bytes()) { + return nil + } + output, err := os.Create(sourcePath) + if err != nil { + return E.Cause(err, "open output") + } + _, err = output.Write(buffer.Bytes()) + output.Close() + if err != nil { + return E.Cause(err, "write output") + } + os.Stderr.WriteString(outputPath + "\n") + return nil +} diff --git a/cmd/sing-box/cmd_tools.go b/cmd/sing-box/cmd_tools.go index 460a50cd1c..c45f585576 100644 --- a/cmd/sing-box/cmd_tools.go +++ b/cmd/sing-box/cmd_tools.go @@ -38,11 +38,7 @@ func createPreStartedClient() (*box.Box, error) { func createDialer(instance *box.Box, network string, outboundTag string) (N.Dialer, error) { if outboundTag == "" { - outbound := instance.Router().DefaultOutbound(N.NetworkName(network)) - if outbound == nil { - return nil, E.New("missing default outbound") - } - return outbound, nil + return instance.Router().DefaultOutbound(N.NetworkName(network)) } else { outbound, loaded := instance.Router().Outbound(outboundTag) if !loaded { diff --git a/cmd/sing-box/cmd_tools_connect.go b/cmd/sing-box/cmd_tools_connect.go index b904ebc9f5..3ea04bcd40 100644 --- a/cmd/sing-box/cmd_tools_connect.go +++ b/cmd/sing-box/cmd_tools_connect.go @@ -18,7 +18,7 @@ import ( var commandConnectFlagNetwork string var commandConnect = &cobra.Command{ - Use: "connect [address]", + Use: "connect
", Short: "Connect to an address", Args: cobra.ExactArgs(1), Run: func(cmd *cobra.Command, args []string) { diff --git a/common/dialer/router.go b/common/dialer/router.go index 1d5586546e..2531607753 100644 --- a/common/dialer/router.go +++ b/common/dialer/router.go @@ -18,11 +18,19 @@ func NewRouter(router adapter.Router) N.Dialer { } func (d *RouterDialer) DialContext(ctx context.Context, network string, destination M.Socksaddr) (net.Conn, error) { - return d.router.DefaultOutbound(network).DialContext(ctx, network, destination) + dialer, err := d.router.DefaultOutbound(network) + if err != nil { + return nil, err + } + return dialer.DialContext(ctx, network, destination) } func (d *RouterDialer) ListenPacket(ctx context.Context, destination M.Socksaddr) (net.PacketConn, error) { - return d.router.DefaultOutbound(N.NetworkUDP).ListenPacket(ctx, destination) + dialer, err := d.router.DefaultOutbound(N.NetworkUDP) + if err != nil { + return nil, err + } + return dialer.ListenPacket(ctx, destination) } func (d *RouterDialer) Upstream() any { diff --git a/common/srs/binary.go b/common/srs/binary.go new file mode 100644 index 0000000000..6b51bb4d3b --- /dev/null +++ b/common/srs/binary.go @@ -0,0 +1,483 @@ +package srs + +import ( + "compress/zlib" + "encoding/binary" + "io" + "net/netip" + + C "github.com/sagernet/sing-box/constant" + "github.com/sagernet/sing-box/option" + "github.com/sagernet/sing/common" + "github.com/sagernet/sing/common/domain" + E "github.com/sagernet/sing/common/exceptions" + "github.com/sagernet/sing/common/rw" + + "go4.org/netipx" +) + +var MagicBytes = [3]byte{0x53, 0x52, 0x53} // SRS + +const ( + ruleItemQueryType uint8 = iota + ruleItemNetwork + ruleItemDomain + ruleItemDomainKeyword + ruleItemDomainRegex + ruleItemSourceIPCIDR + ruleItemIPCIDR + ruleItemSourcePort + ruleItemSourcePortRange + ruleItemPort + ruleItemPortRange + ruleItemProcessName + ruleItemProcessPath + ruleItemPackageName + ruleItemWIFISSID + ruleItemWIFIBSSID + ruleItemFinal uint8 = 0xFF +) + +func Read(reader io.Reader, recovery bool) (ruleSet option.PlainRuleSet, err error) { + var magicBytes [3]byte + _, err = io.ReadFull(reader, magicBytes[:]) + if err != nil { + return + } + if magicBytes != MagicBytes { + err = E.New("invalid sing-box rule set file") + return + } + var version uint8 + err = binary.Read(reader, binary.BigEndian, &version) + if err != nil { + return ruleSet, err + } + if version != 1 { + return ruleSet, E.New("unsupported version: ", version) + } + zReader, err := zlib.NewReader(reader) + if err != nil { + return + } + length, err := rw.ReadUVariant(zReader) + if err != nil { + return + } + ruleSet.Rules = make([]option.HeadlessRule, length) + for i := uint64(0); i < length; i++ { + ruleSet.Rules[i], err = readRule(zReader, recovery) + if err != nil { + err = E.Cause(err, "read rule [", i, "]") + return + } + } + return +} + +func Write(writer io.Writer, ruleSet option.PlainRuleSet) error { + _, err := writer.Write(MagicBytes[:]) + if err != nil { + return err + } + err = binary.Write(writer, binary.BigEndian, uint8(1)) + if err != nil { + return err + } + zWriter, err := zlib.NewWriterLevel(writer, zlib.BestCompression) + if err != nil { + return err + } + err = rw.WriteUVariant(zWriter, uint64(len(ruleSet.Rules))) + if err != nil { + return err + } + for _, rule := range ruleSet.Rules { + err = writeRule(zWriter, rule) + if err != nil { + return err + } + } + return zWriter.Close() +} + +func readRule(reader io.Reader, recovery bool) (rule option.HeadlessRule, err error) { + var ruleType uint8 + err = binary.Read(reader, binary.BigEndian, &ruleType) + if err != nil { + return + } + switch ruleType { + case 0: + rule.DefaultOptions, err = readDefaultRule(reader, recovery) + case 1: + rule.LogicalOptions, err = readLogicalRule(reader, recovery) + default: + err = E.New("unknown rule type: ", ruleType) + } + return +} + +func writeRule(writer io.Writer, rule option.HeadlessRule) error { + switch rule.Type { + case C.RuleTypeDefault: + return writeDefaultRule(writer, rule.DefaultOptions) + case C.RuleTypeLogical: + return writeLogicalRule(writer, rule.LogicalOptions) + default: + panic("unknown rule type: " + rule.Type) + } +} + +func readDefaultRule(reader io.Reader, recovery bool) (rule option.DefaultHeadlessRule, err error) { + for { + var itemType uint8 + err = binary.Read(reader, binary.BigEndian, &itemType) + if err != nil { + return + } + switch itemType { + case ruleItemQueryType: + var rawQueryType []uint16 + rawQueryType, err = readRuleItemUint16(reader) + if err != nil { + return + } + rule.QueryType = common.Map(rawQueryType, func(it uint16) option.DNSQueryType { + return option.DNSQueryType(it) + }) + case ruleItemNetwork: + rule.Network, err = readRuleItemString(reader) + case ruleItemDomain: + var matcher *domain.Matcher + matcher, err = domain.ReadMatcher(reader) + if err != nil { + return + } + rule.DomainMatcher = matcher + case ruleItemDomainKeyword: + rule.DomainKeyword, err = readRuleItemString(reader) + case ruleItemDomainRegex: + rule.DomainRegex, err = readRuleItemString(reader) + case ruleItemSourceIPCIDR: + rule.SourceIPSet, err = readIPSet(reader) + if err != nil { + return + } + if recovery { + rule.SourceIPCIDR = common.Map(rule.SourceIPSet.Prefixes(), netip.Prefix.String) + } + case ruleItemIPCIDR: + rule.IPSet, err = readIPSet(reader) + if err != nil { + return + } + if recovery { + rule.IPCIDR = common.Map(rule.IPSet.Prefixes(), netip.Prefix.String) + } + case ruleItemSourcePort: + rule.SourcePort, err = readRuleItemUint16(reader) + case ruleItemSourcePortRange: + rule.SourcePortRange, err = readRuleItemString(reader) + case ruleItemPort: + rule.Port, err = readRuleItemUint16(reader) + case ruleItemPortRange: + rule.PortRange, err = readRuleItemString(reader) + case ruleItemProcessName: + rule.ProcessName, err = readRuleItemString(reader) + case ruleItemProcessPath: + rule.ProcessPath, err = readRuleItemString(reader) + case ruleItemPackageName: + rule.PackageName, err = readRuleItemString(reader) + case ruleItemWIFISSID: + rule.WIFISSID, err = readRuleItemString(reader) + case ruleItemWIFIBSSID: + rule.WIFIBSSID, err = readRuleItemString(reader) + case ruleItemFinal: + err = binary.Read(reader, binary.BigEndian, &rule.Invert) + return + default: + err = E.New("unknown rule item type: ", itemType) + } + if err != nil { + return + } + } +} + +func writeDefaultRule(writer io.Writer, rule option.DefaultHeadlessRule) error { + err := binary.Write(writer, binary.BigEndian, uint8(0)) + if err != nil { + return err + } + if len(rule.QueryType) > 0 { + err = writeRuleItemUint16(writer, ruleItemQueryType, common.Map(rule.QueryType, func(it option.DNSQueryType) uint16 { + return uint16(it) + })) + if err != nil { + return err + } + } + if len(rule.Network) > 0 { + err = writeRuleItemString(writer, ruleItemNetwork, rule.Network) + if err != nil { + return err + } + } + if len(rule.Domain) > 0 || len(rule.DomainSuffix) > 0 { + err = binary.Write(writer, binary.BigEndian, ruleItemDomain) + if err != nil { + return err + } + err = domain.NewMatcher(rule.Domain, rule.DomainSuffix).Write(writer) + if err != nil { + return err + } + } + if len(rule.DomainKeyword) > 0 { + err = writeRuleItemString(writer, ruleItemDomainKeyword, rule.DomainKeyword) + if err != nil { + return err + } + } + if len(rule.DomainRegex) > 0 { + err = writeRuleItemString(writer, ruleItemDomainRegex, rule.DomainRegex) + if err != nil { + return err + } + } + if len(rule.SourceIPCIDR) > 0 { + err = writeRuleItemCIDR(writer, ruleItemSourceIPCIDR, rule.SourceIPCIDR) + if err != nil { + return E.Cause(err, "source_ipcidr") + } + } + if len(rule.IPCIDR) > 0 { + err = writeRuleItemCIDR(writer, ruleItemIPCIDR, rule.IPCIDR) + if err != nil { + return E.Cause(err, "ipcidr") + } + } + if len(rule.SourcePort) > 0 { + err = writeRuleItemUint16(writer, ruleItemSourcePort, rule.SourcePort) + if err != nil { + return err + } + } + if len(rule.SourcePortRange) > 0 { + err = writeRuleItemString(writer, ruleItemSourcePortRange, rule.SourcePortRange) + if err != nil { + return err + } + } + if len(rule.Port) > 0 { + err = writeRuleItemUint16(writer, ruleItemPort, rule.Port) + if err != nil { + return err + } + } + if len(rule.PortRange) > 0 { + err = writeRuleItemString(writer, ruleItemPortRange, rule.PortRange) + if err != nil { + return err + } + } + if len(rule.ProcessName) > 0 { + err = writeRuleItemString(writer, ruleItemProcessName, rule.ProcessName) + if err != nil { + return err + } + } + if len(rule.ProcessPath) > 0 { + err = writeRuleItemString(writer, ruleItemProcessPath, rule.ProcessPath) + if err != nil { + return err + } + } + if len(rule.PackageName) > 0 { + err = writeRuleItemString(writer, ruleItemPackageName, rule.PackageName) + if err != nil { + return err + } + } + if len(rule.WIFISSID) > 0 { + err = writeRuleItemString(writer, ruleItemWIFISSID, rule.WIFISSID) + if err != nil { + return err + } + } + if len(rule.WIFIBSSID) > 0 { + err = writeRuleItemString(writer, ruleItemWIFIBSSID, rule.WIFIBSSID) + if err != nil { + return err + } + } + err = binary.Write(writer, binary.BigEndian, ruleItemFinal) + if err != nil { + return err + } + err = binary.Write(writer, binary.BigEndian, rule.Invert) + if err != nil { + return err + } + return nil +} + +func readRuleItemString(reader io.Reader) ([]string, error) { + length, err := rw.ReadUVariant(reader) + if err != nil { + return nil, err + } + value := make([]string, length) + for i := uint64(0); i < length; i++ { + value[i], err = rw.ReadVString(reader) + if err != nil { + return nil, err + } + } + return value, nil +} + +func writeRuleItemString(writer io.Writer, itemType uint8, value []string) error { + err := binary.Write(writer, binary.BigEndian, itemType) + if err != nil { + return err + } + err = rw.WriteUVariant(writer, uint64(len(value))) + if err != nil { + return err + } + for _, item := range value { + err = rw.WriteVString(writer, item) + if err != nil { + return err + } + } + return nil +} + +func readRuleItemUint16(reader io.Reader) ([]uint16, error) { + length, err := rw.ReadUVariant(reader) + if err != nil { + return nil, err + } + value := make([]uint16, length) + for i := uint64(0); i < length; i++ { + err = binary.Read(reader, binary.BigEndian, &value[i]) + if err != nil { + return nil, err + } + } + return value, nil +} + +func writeRuleItemUint16(writer io.Writer, itemType uint8, value []uint16) error { + err := binary.Write(writer, binary.BigEndian, itemType) + if err != nil { + return err + } + err = rw.WriteUVariant(writer, uint64(len(value))) + if err != nil { + return err + } + for _, item := range value { + err = binary.Write(writer, binary.BigEndian, item) + if err != nil { + return err + } + } + return nil +} + +func writeRuleItemCIDR(writer io.Writer, itemType uint8, value []string) error { + var builder netipx.IPSetBuilder + for i, prefixString := range value { + prefix, err := netip.ParsePrefix(prefixString) + if err == nil { + builder.AddPrefix(prefix) + continue + } + addr, addrErr := netip.ParseAddr(prefixString) + if addrErr == nil { + builder.Add(addr) + continue + } + return E.Cause(err, "parse [", i, "]") + } + ipSet, err := builder.IPSet() + if err != nil { + return err + } + err = binary.Write(writer, binary.BigEndian, itemType) + if err != nil { + return err + } + return writeIPSet(writer, ipSet) +} + +func readLogicalRule(reader io.Reader, recovery bool) (logicalRule option.LogicalHeadlessRule, err error) { + var mode uint8 + err = binary.Read(reader, binary.BigEndian, &mode) + if err != nil { + return + } + switch mode { + case 0: + logicalRule.Mode = C.LogicalTypeAnd + case 1: + logicalRule.Mode = C.LogicalTypeOr + default: + err = E.New("unknown logical mode: ", mode) + return + } + length, err := rw.ReadUVariant(reader) + if err != nil { + return + } + logicalRule.Rules = make([]option.HeadlessRule, length) + for i := uint64(0); i < length; i++ { + logicalRule.Rules[i], err = readRule(reader, recovery) + if err != nil { + err = E.Cause(err, "read logical rule [", i, "]") + return + } + } + err = binary.Read(reader, binary.BigEndian, &logicalRule.Invert) + if err != nil { + return + } + return +} + +func writeLogicalRule(writer io.Writer, logicalRule option.LogicalHeadlessRule) error { + err := binary.Write(writer, binary.BigEndian, uint8(1)) + if err != nil { + return err + } + switch logicalRule.Mode { + case C.LogicalTypeAnd: + err = binary.Write(writer, binary.BigEndian, uint8(0)) + case C.LogicalTypeOr: + err = binary.Write(writer, binary.BigEndian, uint8(1)) + default: + panic("unknown logical mode: " + logicalRule.Mode) + } + if err != nil { + return err + } + err = rw.WriteUVariant(writer, uint64(len(logicalRule.Rules))) + if err != nil { + return err + } + for _, rule := range logicalRule.Rules { + err = writeRule(writer, rule) + if err != nil { + return err + } + } + err = binary.Write(writer, binary.BigEndian, logicalRule.Invert) + if err != nil { + return err + } + return nil +} diff --git a/common/srs/ip_set.go b/common/srs/ip_set.go new file mode 100644 index 0000000000..b346da26f6 --- /dev/null +++ b/common/srs/ip_set.go @@ -0,0 +1,116 @@ +package srs + +import ( + "encoding/binary" + "io" + "net/netip" + "unsafe" + + "github.com/sagernet/sing/common/rw" + + "go4.org/netipx" +) + +type myIPSet struct { + rr []myIPRange +} + +type myIPRange struct { + from netip.Addr + to netip.Addr +} + +func readIPSet(reader io.Reader) (*netipx.IPSet, error) { + var version uint8 + err := binary.Read(reader, binary.BigEndian, &version) + if err != nil { + return nil, err + } + var length uint64 + err = binary.Read(reader, binary.BigEndian, &length) + if err != nil { + return nil, err + } + mySet := &myIPSet{ + rr: make([]myIPRange, length), + } + for i := uint64(0); i < length; i++ { + var ( + fromLen uint64 + toLen uint64 + fromAddr netip.Addr + toAddr netip.Addr + ) + fromLen, err = rw.ReadUVariant(reader) + if err != nil { + return nil, err + } + fromBytes := make([]byte, fromLen) + _, err = io.ReadFull(reader, fromBytes) + if err != nil { + return nil, err + } + err = fromAddr.UnmarshalBinary(fromBytes) + if err != nil { + return nil, err + } + toLen, err = rw.ReadUVariant(reader) + if err != nil { + return nil, err + } + toBytes := make([]byte, toLen) + _, err = io.ReadFull(reader, toBytes) + if err != nil { + return nil, err + } + err = toAddr.UnmarshalBinary(toBytes) + if err != nil { + return nil, err + } + mySet.rr[i] = myIPRange{fromAddr, toAddr} + } + return (*netipx.IPSet)(unsafe.Pointer(mySet)), nil +} + +func writeIPSet(writer io.Writer, set *netipx.IPSet) error { + err := binary.Write(writer, binary.BigEndian, uint8(1)) + if err != nil { + return err + } + mySet := (*myIPSet)(unsafe.Pointer(set)) + err = binary.Write(writer, binary.BigEndian, uint64(len(mySet.rr))) + if err != nil { + return err + } + for _, rr := range mySet.rr { + var ( + fromBinary []byte + toBinary []byte + ) + fromBinary, err = rr.from.MarshalBinary() + if err != nil { + return err + } + err = rw.WriteUVariant(writer, uint64(len(fromBinary))) + if err != nil { + return err + } + _, err = writer.Write(fromBinary) + if err != nil { + return err + } + toBinary, err = rr.to.MarshalBinary() + if err != nil { + return err + } + err = rw.WriteUVariant(writer, uint64(len(toBinary))) + if err != nil { + return err + } + _, err = writer.Write(toBinary) + if err != nil { + return err + } + } + return nil +} diff --git a/constant/rule.go b/constant/rule.go index 3c741995f8..5a8eaf127f 100644 --- a/constant/rule.go +++ b/constant/rule.go @@ -9,3 +9,11 @@ const ( LogicalTypeAnd = "and" LogicalTypeOr = "or" ) + +const ( + RuleSetTypeLocal = "local" + RuleSetTypeRemote = "remote" + RuleSetVersion1 = 1 + RuleSetFormatSource = "source" + RuleSetFormatBinary = "binary" +) diff --git a/experimental/cachefile/cache.go b/experimental/cachefile/cache.go index 962e884f4a..614d3b2b04 100644 --- a/experimental/cachefile/cache.go +++ b/experimental/cachefile/cache.go @@ -22,11 +22,13 @@ var ( bucketSelected = []byte("selected") bucketExpand = []byte("group_expand") bucketMode = []byte("clash_mode") + bucketRuleSet = []byte("rule_set") bucketNameList = []string{ string(bucketSelected), string(bucketExpand), string(bucketMode), + string(bucketRuleSet), } cacheIDDefault = []byte("default") @@ -248,3 +250,36 @@ func (c *CacheFile) StoreGroupExpand(group string, isExpand bool) error { } }) } + +func (c *CacheFile) LoadRuleSet(tag string) *adapter.SavedRuleSet { + var savedSet adapter.SavedRuleSet + err := c.DB.View(func(t *bbolt.Tx) error { + bucket := c.bucket(t, bucketRuleSet) + if bucket == nil { + return os.ErrNotExist + } + setBinary := bucket.Get([]byte(tag)) + if len(setBinary) == 0 { + return os.ErrInvalid + } + return savedSet.UnmarshalBinary(setBinary) + }) + if err != nil { + return nil + } + return &savedSet +} + +func (c *CacheFile) SaveRuleSet(tag string, set *adapter.SavedRuleSet) error { + return c.DB.Batch(func(t *bbolt.Tx) error { + bucket, err := c.createBucket(t, bucketRuleSet) + if err != nil { + return err + } + setBinary, err := set.MarshalBinary() + if err != nil { + return err + } + return bucket.Put([]byte(tag), setBinary) + }) +} diff --git a/experimental/cachefile/fakeip.go b/experimental/cachefile/fakeip.go index 2242342a36..e998ebb859 100644 --- a/experimental/cachefile/fakeip.go +++ b/experimental/cachefile/fakeip.go @@ -25,7 +25,7 @@ func (c *CacheFile) FakeIPMetadata() *adapter.FakeIPMetadata { err := c.DB.Batch(func(tx *bbolt.Tx) error { bucket := tx.Bucket(bucketFakeIP) if bucket == nil { - return nil + return os.ErrNotExist } metadataBinary := bucket.Get(keyMetadata) if len(metadataBinary) == 0 { diff --git a/experimental/clashapi/proxies.go b/experimental/clashapi/proxies.go index 050efd8d17..cf96931a85 100644 --- a/experimental/clashapi/proxies.go +++ b/experimental/clashapi/proxies.go @@ -100,8 +100,10 @@ func getProxies(server *Server, router adapter.Router) func(w http.ResponseWrite allProxies = append(allProxies, detour.Tag()) } - defaultTag := router.DefaultOutbound(N.NetworkTCP).Tag() - if defaultTag == "" { + var defaultTag string + if defaultOutbound, err := router.DefaultOutbound(N.NetworkTCP); err == nil { + defaultTag = defaultOutbound.Tag() + } else { defaultTag = allProxies[0] } diff --git a/experimental/clashapi/server_resources.go b/experimental/clashapi/server_resources.go index ad36641e05..d6d22b5390 100644 --- a/experimental/clashapi/server_resources.go +++ b/experimental/clashapi/server_resources.go @@ -51,7 +51,11 @@ func (s *Server) downloadExternalUI() error { } detour = outbound } else { - detour = s.router.DefaultOutbound(N.NetworkTCP) + outbound, err := s.router.DefaultOutbound(N.NetworkTCP) + if err != nil { + return err + } + detour = outbound } httpClient := &http.Client{ Transport: &http.Transport{ diff --git a/experimental/clashapi/trafficontrol/tracker.go b/experimental/clashapi/trafficontrol/tracker.go index 3dc5a367e7..b7c20eb075 100644 --- a/experimental/clashapi/trafficontrol/tracker.go +++ b/experimental/clashapi/trafficontrol/tracker.go @@ -94,7 +94,9 @@ func NewTCPTracker(conn net.Conn, manager *Manager, metadata Metadata, router ad var chain []string var next string if rule == nil { - next = router.DefaultOutbound(N.NetworkTCP).Tag() + if defaultOutbound, err := router.DefaultOutbound(N.NetworkTCP); err == nil { + next = defaultOutbound.Tag() + } } else { next = rule.Outbound() } @@ -181,7 +183,9 @@ func NewUDPTracker(conn N.PacketConn, manager *Manager, metadata Metadata, route var chain []string var next string if rule == nil { - next = router.DefaultOutbound(N.NetworkUDP).Tag() + if defaultOutbound, err := router.DefaultOutbound(N.NetworkUDP); err == nil { + next = defaultOutbound.Tag() + } } else { next = rule.Outbound() } diff --git a/go.mod b/go.mod index f11ae050b9..3acb98db74 100644 --- a/go.mod +++ b/go.mod @@ -26,7 +26,7 @@ require ( github.com/sagernet/gvisor v0.0.0-20231119034329-07cfb6aaf930 github.com/sagernet/quic-go v0.40.0 github.com/sagernet/reality v0.0.0-20230406110435-ee17307e7691 - github.com/sagernet/sing v0.2.18-0.20231124125253-2dcabf4bfcbc + github.com/sagernet/sing v0.2.18-0.20231129035309-b2983d65bc58 github.com/sagernet/sing-dns v0.1.11 github.com/sagernet/sing-mux v0.1.5-0.20231109075101-6b086ed6bb07 github.com/sagernet/sing-quic v0.1.5-0.20231123150216-00957d136203 diff --git a/go.sum b/go.sum index 2b9a369ea0..f8969b0378 100644 --- a/go.sum +++ b/go.sum @@ -110,8 +110,8 @@ github.com/sagernet/reality v0.0.0-20230406110435-ee17307e7691 h1:5Th31OC6yj8byL github.com/sagernet/reality v0.0.0-20230406110435-ee17307e7691/go.mod h1:B8lp4WkQ1PwNnrVMM6KyuFR20pU8jYBD+A4EhJovEXU= github.com/sagernet/sing v0.0.0-20220817130738-ce854cda8522/go.mod h1:QVsS5L/ZA2Q5UhQwLrn0Trw+msNd/NPGEhBKR/ioWiY= github.com/sagernet/sing v0.1.8/go.mod h1:jt1w2u7lJQFFSGLiRrRIs5YWmx4kAPfWuOejuDW9qMk= -github.com/sagernet/sing v0.2.18-0.20231124125253-2dcabf4bfcbc h1:vESVuxHgbd2EzHxd+TYTpNACIEGBOhp5n3KG7bgbcws= -github.com/sagernet/sing v0.2.18-0.20231124125253-2dcabf4bfcbc/go.mod h1:OL6k2F0vHmEzXz2KW19qQzu172FDgSbUSODylighuVo= +github.com/sagernet/sing v0.2.18-0.20231129035309-b2983d65bc58 h1:nA1OoozU/j6T+DB1ooDNjdtuXpc64SmcEoUCf6p95NU= +github.com/sagernet/sing v0.2.18-0.20231129035309-b2983d65bc58/go.mod h1:OL6k2F0vHmEzXz2KW19qQzu172FDgSbUSODylighuVo= github.com/sagernet/sing-dns v0.1.11 h1:PPrMCVVrAeR3f5X23I+cmvacXJ+kzuyAsBiWyUKhGSE= github.com/sagernet/sing-dns v0.1.11/go.mod h1:zJ/YjnYB61SYE+ubMcMqVdpaSvsyQ2iShQGO3vuLvvE= github.com/sagernet/sing-mux v0.1.5-0.20231109075101-6b086ed6bb07 h1:ncKb5tVOsCQgCsv6UpsA0jinbNb5OQ5GMPJlyQP3EHM= diff --git a/option/route.go b/option/route.go index 43150576e2..e313fcf242 100644 --- a/option/route.go +++ b/option/route.go @@ -4,6 +4,7 @@ type RouteOptions struct { GeoIP *GeoIPOptions `json:"geoip,omitempty"` Geosite *GeositeOptions `json:"geosite,omitempty"` Rules []Rule `json:"rules,omitempty"` + RuleSet []RuleSet `json:"rule_set,omitempty"` Final string `json:"final,omitempty"` FindProcess bool `json:"find_process,omitempty"` AutoDetectInterface bool `json:"auto_detect_interface,omitempty"` diff --git a/option/rule.go b/option/rule.go index 4f4042025f..b9639a2cbb 100644 --- a/option/rule.go +++ b/option/rule.go @@ -91,6 +91,7 @@ type DefaultRule struct { ClashMode string `json:"clash_mode,omitempty"` WIFISSID Listable[string] `json:"wifi_ssid,omitempty"` WIFIBSSID Listable[string] `json:"wifi_bssid,omitempty"` + RuleSet Listable[string] `json:"rule_set,omitempty"` Invert bool `json:"invert,omitempty"` Outbound string `json:"outbound,omitempty"` } diff --git a/option/rule_dns.go b/option/rule_dns.go index fca3432293..c02d09f761 100644 --- a/option/rule_dns.go +++ b/option/rule_dns.go @@ -91,6 +91,7 @@ type DefaultDNSRule struct { ClashMode string `json:"clash_mode,omitempty"` WIFISSID Listable[string] `json:"wifi_ssid,omitempty"` WIFIBSSID Listable[string] `json:"wifi_bssid,omitempty"` + RuleSet Listable[string] `json:"rule_set,omitempty"` Invert bool `json:"invert,omitempty"` Server string `json:"server,omitempty"` DisableCache bool `json:"disable_cache,omitempty"` diff --git a/option/rule_set.go b/option/rule_set.go new file mode 100644 index 0000000000..0f5f175da1 --- /dev/null +++ b/option/rule_set.go @@ -0,0 +1,230 @@ +package option + +import ( + "reflect" + + "github.com/sagernet/sing-box/common/json" + C "github.com/sagernet/sing-box/constant" + "github.com/sagernet/sing/common" + "github.com/sagernet/sing/common/domain" + E "github.com/sagernet/sing/common/exceptions" + F "github.com/sagernet/sing/common/format" + + "go4.org/netipx" +) + +type _RuleSet struct { + Tag string `json:"tag"` + Type string `json:"type"` + Format string `json:"format"` + LocalOptions LocalRuleSet `json:"-"` + RemoteOptions RemoteRuleSet `json:"-"` +} + +type RuleSet _RuleSet + +func (r RuleSet) MarshalJSON() ([]byte, error) { + var v any + switch r.Type { + case C.RuleSetTypeLocal: + v = r.LocalOptions + case C.RuleSetTypeRemote: + v = r.RemoteOptions + default: + return nil, E.New("unknown rule set type: " + r.Type) + } + return MarshallObjects((_RuleSet)(r), v) +} + +func (r *RuleSet) UnmarshalJSON(bytes []byte) error { + err := json.Unmarshal(bytes, (*_RuleSet)(r)) + if err != nil { + return err + } + if r.Tag == "" { + return E.New("missing rule_set.[].tag") + } + switch r.Format { + case "": + return E.New("missing rule_set.[].format") + case C.RuleSetFormatSource, C.RuleSetFormatBinary: + default: + return E.New("unknown rule set format: " + r.Format) + } + var v any + switch r.Type { + case C.RuleSetTypeLocal: + v = &r.LocalOptions + case C.RuleSetTypeRemote: + v = &r.RemoteOptions + case "": + return E.New("missing rule_set.[].type") + default: + return E.New("unknown rule set type: " + r.Type) + } + err = UnmarshallExcluded(bytes, (*_RuleSet)(r), v) + if err != nil { + return E.Cause(err, "rule set") + } + return nil +} + +type LocalRuleSet struct { + Path string `json:"path,omitempty"` +} + +type RemoteRuleSet struct { + URL string `json:"url,omitempty"` + DownloadDetour string `json:"download_detour,omitempty"` + UpdateInterval Duration `json:"update_interval,omitempty"` +} + +type _HeadlessRule struct { + Type string `json:"type,omitempty"` + DefaultOptions DefaultHeadlessRule `json:"-"` + LogicalOptions LogicalHeadlessRule `json:"-"` +} + +type HeadlessRule _HeadlessRule + +func (r HeadlessRule) MarshalJSON() ([]byte, error) { + var v any + switch r.Type { + case C.RuleTypeDefault: + r.Type = "" + v = r.DefaultOptions + case C.RuleTypeLogical: + v = r.LogicalOptions + default: + return nil, E.New("unknown rule type: " + r.Type) + } + return MarshallObjects((_HeadlessRule)(r), v) +} + +func (r *HeadlessRule) UnmarshalJSON(bytes []byte) error { + err := json.Unmarshal(bytes, (*_HeadlessRule)(r)) + if err != nil { + return err + } + var v any + switch r.Type { + case "", C.RuleTypeDefault: + r.Type = C.RuleTypeDefault + v = &r.DefaultOptions + case C.RuleTypeLogical: + v = &r.LogicalOptions + default: + return E.New("unknown rule type: " + r.Type) + } + err = UnmarshallExcluded(bytes, (*_HeadlessRule)(r), v) + if err != nil { + return E.Cause(err, "route rule-set rule") + } + return nil +} + +func (r HeadlessRule) IsValid() bool { + switch r.Type { + case C.RuleTypeDefault, "": + return r.DefaultOptions.IsValid() + case C.RuleTypeLogical: + return r.LogicalOptions.IsValid() + default: + panic("unknown rule type: " + r.Type) + } +} + +type DefaultHeadlessRule struct { + QueryType Listable[DNSQueryType] `json:"query_type,omitempty"` + Network Listable[string] `json:"network,omitempty"` + Domain Listable[string] `json:"domain,omitempty"` + DomainSuffix Listable[string] `json:"domain_suffix,omitempty"` + DomainKeyword Listable[string] `json:"domain_keyword,omitempty"` + DomainRegex Listable[string] `json:"domain_regex,omitempty"` + SourceIPCIDR Listable[string] `json:"source_ip_cidr,omitempty"` + IPCIDR Listable[string] `json:"ip_cidr,omitempty"` + SourcePort Listable[uint16] `json:"source_port,omitempty"` + SourcePortRange Listable[string] `json:"source_port_range,omitempty"` + Port Listable[uint16] `json:"port,omitempty"` + PortRange Listable[string] `json:"port_range,omitempty"` + ProcessName Listable[string] `json:"process_name,omitempty"` + ProcessPath Listable[string] `json:"process_path,omitempty"` + PackageName Listable[string] `json:"package_name,omitempty"` + WIFISSID Listable[string] `json:"wifi_ssid,omitempty"` + WIFIBSSID Listable[string] `json:"wifi_bssid,omitempty"` + Invert bool `json:"invert,omitempty"` + + DomainMatcher *domain.Matcher `json:"-"` + SourceIPSet *netipx.IPSet `json:"-"` + IPSet *netipx.IPSet `json:"-"` +} + +func (r DefaultHeadlessRule) IsValid() bool { + var defaultValue DefaultHeadlessRule + defaultValue.Invert = r.Invert + return !reflect.DeepEqual(r, defaultValue) +} + +type LogicalHeadlessRule struct { + Mode string `json:"mode"` + Rules []HeadlessRule `json:"rules,omitempty"` + Invert bool `json:"invert,omitempty"` +} + +func (r LogicalHeadlessRule) IsValid() bool { + return len(r.Rules) > 0 && common.All(r.Rules, HeadlessRule.IsValid) +} + +type _PlainRuleSetCompat struct { + Version int `json:"version"` + Options PlainRuleSet `json:"-"` +} + +type PlainRuleSetCompat _PlainRuleSetCompat + +func (r PlainRuleSetCompat) MarshalJSON() ([]byte, error) { + var v any + switch r.Version { + case C.RuleSetVersion1: + v = r.Options + default: + return nil, E.New("unknown rule set version: ", r.Version) + } + return MarshallObjects((_PlainRuleSetCompat)(r), v) +} + +func (r *PlainRuleSetCompat) UnmarshalJSON(bytes []byte) error { + err := json.Unmarshal(bytes, (*_PlainRuleSetCompat)(r)) + if err != nil { + return err + } + var v any + switch r.Version { + case C.RuleSetVersion1: + v = &r.Options + case 0: + return E.New("missing rule set version") + default: + return E.New("unknown rule set version: ", r.Version) + } + err = UnmarshallExcluded(bytes, (*_PlainRuleSetCompat)(r), v) + if err != nil { + return E.Cause(err, "rule set") + } + return nil +} + +func (r PlainRuleSetCompat) Upgrade() PlainRuleSet { + var result PlainRuleSet + switch r.Version { + case C.RuleSetVersion1: + result = r.Options + default: + panic("unknown rule set version: " + F.ToString(r.Version)) + } + return result +} + +type PlainRuleSet struct { + Rules []HeadlessRule `json:"rules,omitempty"` +} diff --git a/option/types.go b/option/types.go index f2fed66309..520c3503aa 100644 --- a/option/types.go +++ b/option/types.go @@ -174,6 +174,14 @@ func (d *Duration) UnmarshalJSON(bytes []byte) error { type DNSQueryType uint16 +func (t DNSQueryType) String() string { + typeName, loaded := mDNS.TypeToString[uint16(t)] + if loaded { + return typeName + } + return F.ToString(uint16(t)) +} + func (t DNSQueryType) MarshalJSON() ([]byte, error) { typeName, loaded := mDNS.TypeToString[uint16(t)] if loaded { diff --git a/route/router.go b/route/router.go index e50a51a955..3f44975d4a 100644 --- a/route/router.go +++ b/route/router.go @@ -67,6 +67,8 @@ type Router struct { dnsClient *dns.Client defaultDomainStrategy dns.DomainStrategy dnsRules []adapter.DNSRule + ruleSets []adapter.RuleSet + ruleSetMap map[string]adapter.RuleSet defaultTransport dns.Transport transports []dns.Transport transportMap map[string]dns.Transport @@ -106,6 +108,7 @@ func NewRouter( outboundByTag: make(map[string]adapter.Outbound), rules: make([]adapter.Rule, 0, len(options.Rules)), dnsRules: make([]adapter.DNSRule, 0, len(dnsOptions.Rules)), + ruleSetMap: make(map[string]adapter.RuleSet), needGeoIPDatabase: hasRule(options.Rules, isGeoIPRule) || hasDNSRule(dnsOptions.Rules, isGeoIPDNSRule), needGeositeDatabase: hasRule(options.Rules, isGeositeRule) || hasDNSRule(dnsOptions.Rules, isGeositeDNSRule), geoIPOptions: common.PtrValueOrDefault(options.GeoIP), @@ -140,6 +143,14 @@ func NewRouter( } router.dnsRules = append(router.dnsRules, dnsRule) } + for i, ruleSetOptions := range options.RuleSet { + ruleSet, err := NewRuleSet(ctx, router, router.logger, ruleSetOptions) + if err != nil { + return nil, E.Cause(err, "parse rule-set[", i, "]") + } + router.ruleSets = append(router.ruleSets, ruleSet) + router.ruleSetMap[ruleSetOptions.Tag] = ruleSet + } transports := make([]dns.Transport, len(dnsOptions.Servers)) dummyTransportMap := make(map[string]dns.Transport) @@ -479,6 +490,12 @@ func (r *Router) Start() error { if r.needWIFIState { r.updateWIFIState() } + for i, ruleSet := range r.ruleSets { + err := ruleSet.Start() + if err != nil { + return E.Cause(err, "initialize rule-set[", i, "]") + } + } for i, rule := range r.rules { err := rule.Start() if err != nil { @@ -576,11 +593,17 @@ func (r *Router) Outbound(tag string) (adapter.Outbound, bool) { return outbound, loaded } -func (r *Router) DefaultOutbound(network string) adapter.Outbound { +func (r *Router) DefaultOutbound(network string) (adapter.Outbound, error) { if network == N.NetworkTCP { - return r.defaultOutboundForConnection + if r.defaultOutboundForConnection == nil { + return nil, E.New("missing default outbound for TCP connections") + } + return r.defaultOutboundForConnection, nil } else { - return r.defaultOutboundForPacketConnection + if r.defaultOutboundForPacketConnection == nil { + return nil, E.New("missing default outbound for UDP connections") + } + return r.defaultOutboundForPacketConnection, nil } } @@ -588,6 +611,11 @@ func (r *Router) FakeIPStore() adapter.FakeIPStore { return r.fakeIPStore } +func (r *Router) RuleSet(tag string) (adapter.RuleSet, bool) { + ruleSet, loaded := r.ruleSetMap[tag] + return ruleSet, loaded +} + func (r *Router) RouteConnection(ctx context.Context, conn net.Conn, metadata adapter.InboundContext) error { if metadata.InboundDetour != "" { if metadata.LastInbound == metadata.InboundDetour { diff --git a/route/rule_abstract.go b/route/rule_abstract.go index 38d4d57d41..312caaee73 100644 --- a/route/rule_abstract.go +++ b/route/rule_abstract.go @@ -1,6 +1,7 @@ package route import ( + "io" "strings" "github.com/sagernet/sing-box/adapter" @@ -135,7 +136,7 @@ func (r *abstractDefaultRule) String() string { } type abstractLogicalRule struct { - rules []adapter.Rule + rules []adapter.HeadlessRule mode string invert bool outbound string @@ -146,7 +147,10 @@ func (r *abstractLogicalRule) Type() string { } func (r *abstractLogicalRule) UpdateGeosite() error { - for _, rule := range r.rules { + for _, rule := range common.FilterIsInstance(r.rules, func(it adapter.HeadlessRule) (adapter.Rule, bool) { + rule, loaded := it.(adapter.Rule) + return rule, loaded + }) { err := rule.UpdateGeosite() if err != nil { return err @@ -156,7 +160,10 @@ func (r *abstractLogicalRule) UpdateGeosite() error { } func (r *abstractLogicalRule) Start() error { - for _, rule := range r.rules { + for _, rule := range common.FilterIsInstance(r.rules, func(it adapter.HeadlessRule) (common.Starter, bool) { + rule, loaded := it.(common.Starter) + return rule, loaded + }) { err := rule.Start() if err != nil { return err @@ -166,7 +173,10 @@ func (r *abstractLogicalRule) Start() error { } func (r *abstractLogicalRule) Close() error { - for _, rule := range r.rules { + for _, rule := range common.FilterIsInstance(r.rules, func(it adapter.HeadlessRule) (io.Closer, bool) { + rule, loaded := it.(io.Closer) + return rule, loaded + }) { err := rule.Close() if err != nil { return err @@ -177,11 +187,11 @@ func (r *abstractLogicalRule) Close() error { func (r *abstractLogicalRule) Match(metadata *adapter.InboundContext) bool { if r.mode == C.LogicalTypeAnd { - return common.All(r.rules, func(it adapter.Rule) bool { + return common.All(r.rules, func(it adapter.HeadlessRule) bool { return it.Match(metadata) }) != r.invert } else { - return common.Any(r.rules, func(it adapter.Rule) bool { + return common.Any(r.rules, func(it adapter.HeadlessRule) bool { return it.Match(metadata) }) != r.invert } diff --git a/route/rule_default.go b/route/rule_default.go index 758aa34704..19174fcf8d 100644 --- a/route/rule_default.go +++ b/route/rule_default.go @@ -194,6 +194,11 @@ func NewDefaultRule(router adapter.Router, logger log.ContextLogger, options opt rule.items = append(rule.items, item) rule.allItems = append(rule.allItems, item) } + if len(options.RuleSet) > 0 { + item := NewRuleSetItem(router, options.RuleSet) + rule.items = append(rule.items, item) + rule.allItems = append(rule.allItems, item) + } return rule, nil } @@ -206,7 +211,7 @@ type LogicalRule struct { func NewLogicalRule(router adapter.Router, logger log.ContextLogger, options option.LogicalRule) (*LogicalRule, error) { r := &LogicalRule{ abstractLogicalRule{ - rules: make([]adapter.Rule, len(options.Rules)), + rules: make([]adapter.HeadlessRule, len(options.Rules)), invert: options.Invert, outbound: options.Outbound, }, diff --git a/route/rule_dns.go b/route/rule_dns.go index 934c45c622..d12dbdff5c 100644 --- a/route/rule_dns.go +++ b/route/rule_dns.go @@ -190,6 +190,11 @@ func NewDefaultDNSRule(router adapter.Router, logger log.ContextLogger, options rule.items = append(rule.items, item) rule.allItems = append(rule.allItems, item) } + if len(options.RuleSet) > 0 { + item := NewRuleSetItem(router, options.RuleSet) + rule.items = append(rule.items, item) + rule.allItems = append(rule.allItems, item) + } return rule, nil } @@ -212,7 +217,7 @@ type LogicalDNSRule struct { func NewLogicalDNSRule(router adapter.Router, logger log.ContextLogger, options option.LogicalDNSRule) (*LogicalDNSRule, error) { r := &LogicalDNSRule{ abstractLogicalRule: abstractLogicalRule{ - rules: make([]adapter.Rule, len(options.Rules)), + rules: make([]adapter.HeadlessRule, len(options.Rules)), invert: options.Invert, outbound: options.Server, }, diff --git a/route/rule_headless.go b/route/rule_headless.go new file mode 100644 index 0000000000..9df2ee3036 --- /dev/null +++ b/route/rule_headless.go @@ -0,0 +1,173 @@ +package route + +import ( + "github.com/sagernet/sing-box/adapter" + C "github.com/sagernet/sing-box/constant" + "github.com/sagernet/sing-box/option" + E "github.com/sagernet/sing/common/exceptions" +) + +func NewHeadlessRule(router adapter.Router, options option.HeadlessRule) (adapter.HeadlessRule, error) { + switch options.Type { + case "", C.RuleTypeDefault: + if !options.DefaultOptions.IsValid() { + return nil, E.New("missing conditions") + } + return NewDefaultHeadlessRule(router, options.DefaultOptions) + case C.RuleTypeLogical: + if !options.LogicalOptions.IsValid() { + return nil, E.New("missing conditions") + } + return NewLogicalHeadlessRule(router, options.LogicalOptions) + default: + return nil, E.New("unknown rule type: ", options.Type) + } +} + +var _ adapter.HeadlessRule = (*DefaultHeadlessRule)(nil) + +type DefaultHeadlessRule struct { + abstractDefaultRule +} + +func NewDefaultHeadlessRule(router adapter.Router, options option.DefaultHeadlessRule) (*DefaultHeadlessRule, error) { + rule := &DefaultHeadlessRule{ + abstractDefaultRule{ + invert: options.Invert, + }, + } + if len(options.Network) > 0 { + item := NewNetworkItem(options.Network) + rule.items = append(rule.items, item) + rule.allItems = append(rule.allItems, item) + } + if len(options.Domain) > 0 || len(options.DomainSuffix) > 0 { + item := NewDomainItem(options.Domain, options.DomainSuffix) + rule.destinationAddressItems = append(rule.destinationAddressItems, item) + rule.allItems = append(rule.allItems, item) + } else if options.DomainMatcher != nil { + item := NewRawDomainItem(options.DomainMatcher) + rule.destinationAddressItems = append(rule.destinationAddressItems, item) + rule.allItems = append(rule.allItems, item) + } + if len(options.DomainKeyword) > 0 { + item := NewDomainKeywordItem(options.DomainKeyword) + rule.destinationAddressItems = append(rule.destinationAddressItems, item) + rule.allItems = append(rule.allItems, item) + } + if len(options.DomainRegex) > 0 { + item, err := NewDomainRegexItem(options.DomainRegex) + if err != nil { + return nil, E.Cause(err, "domain_regex") + } + rule.destinationAddressItems = append(rule.destinationAddressItems, item) + rule.allItems = append(rule.allItems, item) + } + if len(options.SourceIPCIDR) > 0 { + item, err := NewIPCIDRItem(true, options.SourceIPCIDR) + if err != nil { + return nil, E.Cause(err, "source_ipcidr") + } + rule.sourceAddressItems = append(rule.sourceAddressItems, item) + rule.allItems = append(rule.allItems, item) + } else if options.SourceIPSet != nil { + item := NewRawIPCIDRItem(true, options.SourceIPSet) + rule.sourceAddressItems = append(rule.sourceAddressItems, item) + rule.allItems = append(rule.allItems, item) + } + if len(options.IPCIDR) > 0 { + item, err := NewIPCIDRItem(false, options.IPCIDR) + if err != nil { + return nil, E.Cause(err, "ipcidr") + } + rule.destinationAddressItems = append(rule.destinationAddressItems, item) + rule.allItems = append(rule.allItems, item) + } else if options.IPSet != nil { + item := NewRawIPCIDRItem(false, options.IPSet) + rule.destinationAddressItems = append(rule.destinationAddressItems, item) + rule.allItems = append(rule.allItems, item) + } + if len(options.SourcePort) > 0 { + item := NewPortItem(true, options.SourcePort) + rule.sourcePortItems = append(rule.sourcePortItems, item) + rule.allItems = append(rule.allItems, item) + } + if len(options.SourcePortRange) > 0 { + item, err := NewPortRangeItem(true, options.SourcePortRange) + if err != nil { + return nil, E.Cause(err, "source_port_range") + } + rule.sourcePortItems = append(rule.sourcePortItems, item) + rule.allItems = append(rule.allItems, item) + } + if len(options.Port) > 0 { + item := NewPortItem(false, options.Port) + rule.destinationPortItems = append(rule.destinationPortItems, item) + rule.allItems = append(rule.allItems, item) + } + if len(options.PortRange) > 0 { + item, err := NewPortRangeItem(false, options.PortRange) + if err != nil { + return nil, E.Cause(err, "port_range") + } + rule.destinationPortItems = append(rule.destinationPortItems, item) + rule.allItems = append(rule.allItems, item) + } + if len(options.ProcessName) > 0 { + item := NewProcessItem(options.ProcessName) + rule.items = append(rule.items, item) + rule.allItems = append(rule.allItems, item) + } + if len(options.ProcessPath) > 0 { + item := NewProcessPathItem(options.ProcessPath) + rule.items = append(rule.items, item) + rule.allItems = append(rule.allItems, item) + } + if len(options.PackageName) > 0 { + item := NewPackageNameItem(options.PackageName) + rule.items = append(rule.items, item) + rule.allItems = append(rule.allItems, item) + } + if len(options.WIFISSID) > 0 { + item := NewWIFISSIDItem(router, options.WIFISSID) + rule.items = append(rule.items, item) + rule.allItems = append(rule.allItems, item) + } + if len(options.WIFIBSSID) > 0 { + item := NewWIFIBSSIDItem(router, options.WIFIBSSID) + rule.items = append(rule.items, item) + rule.allItems = append(rule.allItems, item) + } + return rule, nil +} + +var _ adapter.HeadlessRule = (*LogicalHeadlessRule)(nil) + +type LogicalHeadlessRule struct { + abstractLogicalRule +} + +func NewLogicalHeadlessRule(router adapter.Router, options option.LogicalHeadlessRule) (*LogicalHeadlessRule, error) { + r := &LogicalHeadlessRule{ + abstractLogicalRule{ + rules: make([]adapter.HeadlessRule, len(options.Rules)), + invert: options.Invert, + }, + } + switch options.Mode { + case C.LogicalTypeAnd: + r.mode = C.LogicalTypeAnd + case C.LogicalTypeOr: + r.mode = C.LogicalTypeOr + default: + return nil, E.New("unknown logical mode: ", options.Mode) + } + for i, subRule := range options.Rules { + rule, err := NewHeadlessRule(router, subRule) + if err != nil { + return nil, E.Cause(err, "sub rule[", i, "]") + } + r.rules[i] = rule + } + return r, nil +} diff --git a/route/rule_item_cidr.go b/route/rule_item_cidr.go index b72d1e10b1..97eb9cefab 100644 --- a/route/rule_item_cidr.go +++ b/route/rule_item_cidr.go @@ -31,7 +31,7 @@ func NewIPCIDRItem(isSource bool, prefixStrings []string) (*IPCIDRItem, error) { builder.Add(addr) continue } - return nil, E.Cause(err, "parse ip_cidr [", i, "]") + return nil, E.Cause(err, "parse [", i, "]") } var description string if isSource { @@ -57,6 +57,21 @@ func NewIPCIDRItem(isSource bool, prefixStrings []string) (*IPCIDRItem, error) { }, nil } +func NewRawIPCIDRItem(isSource bool, ipSet *netipx.IPSet) *IPCIDRItem { + var description string + if isSource { + description = "source_ipcidr=" + } else { + description = "ipcidr=" + } + description += "" + return &IPCIDRItem{ + ipSet: ipSet, + isSource: isSource, + description: description, + } +} + func (r *IPCIDRItem) Match(metadata *adapter.InboundContext) bool { if r.isSource { return r.ipSet.Contains(metadata.Source.Addr) diff --git a/route/rule_item_domain.go b/route/rule_item_domain.go index 6602441deb..d2a11181b0 100644 --- a/route/rule_item_domain.go +++ b/route/rule_item_domain.go @@ -43,6 +43,13 @@ func NewDomainItem(domains []string, domainSuffixes []string) *DomainItem { } } +func NewRawDomainItem(matcher *domain.Matcher) *DomainItem { + return &DomainItem{ + matcher, + "domain/domain_suffix=", + } +} + func (r *DomainItem) Match(metadata *adapter.InboundContext) bool { var domainHost string if metadata.Domain != "" { diff --git a/route/rule_item_rule_set.go b/route/rule_item_rule_set.go new file mode 100644 index 0000000000..43c03c4436 --- /dev/null +++ b/route/rule_item_rule_set.go @@ -0,0 +1,52 @@ +package route + +import ( + "strings" + + "github.com/sagernet/sing-box/adapter" + E "github.com/sagernet/sing/common/exceptions" + F "github.com/sagernet/sing/common/format" +) + +var _ RuleItem = (*RuleSetItem)(nil) + +type RuleSetItem struct { + router adapter.Router + tagList []string + setList []adapter.HeadlessRule +} + +func NewRuleSetItem(router adapter.Router, tagList []string) *RuleSetItem { + return &RuleSetItem{ + router: router, + tagList: tagList, + } +} + +func (r *RuleSetItem) Start() error { + for _, tag := range r.tagList { + ruleSet, loaded := r.router.RuleSet(tag) + if !loaded { + return E.New("rule-set not found: ", tag) + } + r.setList = append(r.setList, ruleSet) + } + return nil +} + +func (r *RuleSetItem) Match(metadata *adapter.InboundContext) bool { + for _, ruleSet := range r.setList { + if ruleSet.Match(metadata) { + return true + } + } + return false +} + +func (r *RuleSetItem) String() string { + if len(r.tagList) == 1 { + return F.ToString("rule_set=", r.tagList[0]) + } else { + return F.ToString("rule_set=[", strings.Join(r.tagList, " "), "]") + } +} diff --git a/route/rule_set.go b/route/rule_set.go new file mode 100644 index 0000000000..76c78c6214 --- /dev/null +++ b/route/rule_set.go @@ -0,0 +1,22 @@ +package route + +import ( + "context" + + "github.com/sagernet/sing-box/adapter" + C "github.com/sagernet/sing-box/constant" + "github.com/sagernet/sing-box/option" + E "github.com/sagernet/sing/common/exceptions" + "github.com/sagernet/sing/common/logger" +) + +func NewRuleSet(ctx context.Context, router adapter.Router, logger logger.ContextLogger, options option.RuleSet) (adapter.RuleSet, error) { + switch options.Type { + case C.RuleSetTypeLocal: + return NewLocalRuleSet(router, options) + case C.RuleSetTypeRemote: + return NewRemoteRuleSet(ctx, router, logger, options), nil + default: + return nil, E.New("unknown rule set type: ", options.Type) + } +} diff --git a/route/rule_set_local.go b/route/rule_set_local.go new file mode 100644 index 0000000000..ccdb1704a8 --- /dev/null +++ b/route/rule_set_local.go @@ -0,0 +1,69 @@ +package route + +import ( + "os" + + "github.com/sagernet/sing-box/adapter" + "github.com/sagernet/sing-box/common/json" + "github.com/sagernet/sing-box/common/srs" + C "github.com/sagernet/sing-box/constant" + "github.com/sagernet/sing-box/option" + E "github.com/sagernet/sing/common/exceptions" +) + +var _ adapter.RuleSet = (*LocalRuleSet)(nil) + +type LocalRuleSet struct { + rules []adapter.HeadlessRule +} + +func NewLocalRuleSet(router adapter.Router, options option.RuleSet) (*LocalRuleSet, error) { + setFile, err := os.Open(options.LocalOptions.Path) + if err != nil { + return nil, err + } + var plainRuleSet option.PlainRuleSet + switch options.Format { + case C.RuleSetFormatSource, "": + var compat option.PlainRuleSetCompat + decoder := json.NewDecoder(json.NewCommentFilter(setFile)) + decoder.DisallowUnknownFields() + err = decoder.Decode(&compat) + if err != nil { + return nil, err + } + plainRuleSet = compat.Upgrade() + case C.RuleSetFormatBinary: + plainRuleSet, err = srs.Read(setFile, false) + if err != nil { + return nil, err + } + default: + return nil, E.New("unknown rule set format: ", options.Format) + } + rules := make([]adapter.HeadlessRule, len(plainRuleSet.Rules)) + for i, ruleOptions := range plainRuleSet.Rules { + rules[i], err = NewHeadlessRule(router, ruleOptions) + if err != nil { + return nil, E.Cause(err, "parse rule_set.rules.[", i, "]") + } + } + return &LocalRuleSet{rules}, nil +} + +func (s *LocalRuleSet) Match(metadata *adapter.InboundContext) bool { + for _, rule := range s.rules { + if rule.Match(metadata) { + return true + } + } + return false +} + +func (s *LocalRuleSet) Start() error { + return nil +} + +func (s *LocalRuleSet) Close() error { + return nil +} diff --git a/route/rule_set_remote.go b/route/rule_set_remote.go new file mode 100644 index 0000000000..3a38f715bb --- /dev/null +++ b/route/rule_set_remote.go @@ -0,0 +1,191 @@ +package route + +import ( + "bytes" + "context" + "io" + "net" + "net/http" + "time" + + "github.com/sagernet/sing-box/adapter" + "github.com/sagernet/sing-box/common/json" + "github.com/sagernet/sing-box/common/srs" + C "github.com/sagernet/sing-box/constant" + "github.com/sagernet/sing-box/option" + E "github.com/sagernet/sing/common/exceptions" + "github.com/sagernet/sing/common/logger" + M "github.com/sagernet/sing/common/metadata" + N "github.com/sagernet/sing/common/network" + "github.com/sagernet/sing/service" +) + +var _ adapter.RuleSet = (*RemoteRuleSet)(nil) + +type RemoteRuleSet struct { + ctx context.Context + cancel context.CancelFunc + router adapter.Router + logger logger.ContextLogger + options option.RuleSet + dialer N.Dialer + rules []adapter.HeadlessRule + lastUpdated time.Time + updateTicker *time.Ticker +} + +func NewRemoteRuleSet(ctx context.Context, router adapter.Router, logger logger.ContextLogger, options option.RuleSet) *RemoteRuleSet { + ctx, cancel := context.WithCancel(ctx) + return &RemoteRuleSet{ + ctx: ctx, + cancel: cancel, + router: router, + logger: logger, + options: options, + } +} + +func (s *RemoteRuleSet) Match(metadata *adapter.InboundContext) bool { + for _, rule := range s.rules { + if rule.Match(metadata) { + return true + } + } + return false +} + +func (s *RemoteRuleSet) Start() error { + var dialer N.Dialer + if s.options.RemoteOptions.DownloadDetour != "" { + outbound, loaded := s.router.Outbound(s.options.RemoteOptions.DownloadDetour) + if !loaded { + return E.New("download_detour not found: ", s.options.RemoteOptions.DownloadDetour) + } + dialer = outbound + } else { + outbound, err := s.router.DefaultOutbound(N.NetworkTCP) + if err != nil { + return err + } + dialer = outbound + } + s.dialer = dialer + cacheFile := service.FromContext[adapter.CacheFile](s.ctx) + if cacheFile != nil { + if savedSet := cacheFile.LoadRuleSet(s.options.Tag); savedSet != nil { + err := s.loadBytes(savedSet.Content) + if err != nil { + return E.Cause(err, "restore cached rule-set") + } + s.lastUpdated = savedSet.LastUpdated + } + } + if s.lastUpdated.IsZero() || time.Since(s.lastUpdated) > time.Duration(s.options.RemoteOptions.UpdateInterval) { + err := s.fetchOnce() + if err != nil { + return E.Cause(err, "fetch rule-set ", s.options.Tag) + } + } + s.updateTicker = time.NewTicker(time.Duration(s.options.RemoteOptions.UpdateInterval)) + go s.loopUpdate() + return nil +} + +func (s *RemoteRuleSet) loadBytes(content []byte) error { + var ( + plainRuleSet option.PlainRuleSet + err error + ) + switch s.options.Format { + case C.RuleSetFormatSource, "": + var compat option.PlainRuleSetCompat + decoder := json.NewDecoder(json.NewCommentFilter(bytes.NewReader(content))) + decoder.DisallowUnknownFields() + err = decoder.Decode(&compat) + if err != nil { + return err + } + plainRuleSet = compat.Upgrade() + case C.RuleSetFormatBinary: + plainRuleSet, err = srs.Read(bytes.NewReader(content), false) + if err != nil { + return err + } + default: + return E.New("unknown rule set format: ", s.options.Format) + } + rules := make([]adapter.HeadlessRule, len(plainRuleSet.Rules)) + for i, ruleOptions := range plainRuleSet.Rules { + rules[i], err = NewHeadlessRule(s.router, ruleOptions) + if err != nil { + return E.Cause(err, "parse rule_set.rules.[", i, "]") + } + } + s.rules = rules + return nil +} + +func (s *RemoteRuleSet) loopUpdate() { + for { + select { + case <-s.ctx.Done(): + return + case <-s.updateTicker.C: + err := s.fetchOnce() + if err != nil { + s.logger.Error("fetch rule-set ", s.options.Tag, ": ", err) + } + } + } +} + +func (s *RemoteRuleSet) fetchOnce() error { + s.logger.Info("fetching rule-set ", s.options.Tag, " from URL: ", s.options.RemoteOptions.URL) + httpClient := &http.Client{ + Transport: &http.Transport{ + ForceAttemptHTTP2: true, + TLSHandshakeTimeout: C.TCPTimeout, + DialContext: func(ctx context.Context, network, addr string) (net.Conn, error) { + return s.dialer.DialContext(ctx, network, M.ParseSocksaddr(addr)) + }, + }, + } + defer httpClient.CloseIdleConnections() + request, err := http.NewRequest("GET", s.options.RemoteOptions.URL, nil) + if err != nil { + return err + } + response, err := httpClient.Do(request.WithContext(s.ctx)) + if err != nil { + return err + } + content, err := io.ReadAll(response.Body) + if err != nil { + response.Body.Close() + return err + } + err = s.loadBytes(content) + if err != nil { + response.Body.Close() + return err + } + response.Body.Close() + s.lastUpdated = time.Now() + cacheFile := service.FromContext[adapter.CacheFile](s.ctx) + if cacheFile != nil { + err = cacheFile.SaveRuleSet(s.options.Tag, &adapter.SavedRuleSet{ + LastUpdated: s.lastUpdated, + Content: content, + }) + if err != nil { + s.logger.Error("save rule-set cache: ", err) + } + } + return nil +} + +func (s *RemoteRuleSet) Close() error { + s.updateTicker.Stop() + s.cancel() + return nil +}