Skip to content

Commit

Permalink
Make rule-set initialization parallel
Browse files Browse the repository at this point in the history
  • Loading branch information
nekohasekai committed Nov 30, 2023
1 parent cb28aba commit 889f426
Show file tree
Hide file tree
Showing 7 changed files with 93 additions and 22 deletions.
10 changes: 9 additions & 1 deletion adapter/router.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,12 +2,14 @@ package adapter

import (
"context"
"net/http"
"net/netip"

"github.com/sagernet/sing-box/common/geoip"
"github.com/sagernet/sing-dns"
"github.com/sagernet/sing-tun"
"github.com/sagernet/sing/common/control"
N "github.com/sagernet/sing/common/network"
"github.com/sagernet/sing/service"

mdns "github.com/miekg/dns"
Expand Down Expand Up @@ -83,8 +85,14 @@ type DNSRule interface {
}

type RuleSet interface {
StartContext(ctx context.Context, startContext RuleSetStartContext) error
Close() error
HeadlessRule
Service
}

type RuleSetStartContext interface {
HTTPClient(detour string, dialer N.Dialer) *http.Client
Close()
}

type InterfaceUpdateListener interface {
Expand Down
2 changes: 1 addition & 1 deletion go.mod
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ require (
github.com/sagernet/gvisor v0.0.0-20231119034329-07cfb6aaf930
github.com/sagernet/quic-go v0.40.0
github.com/sagernet/reality v0.0.0-20230406110435-ee17307e7691
github.com/sagernet/sing v0.2.18-0.20231129075305-eb56a60214be
github.com/sagernet/sing v0.2.18-0.20231130082037-cac27afa2a18
github.com/sagernet/sing-dns v0.1.11
github.com/sagernet/sing-mux v0.1.5-0.20231109075101-6b086ed6bb07
github.com/sagernet/sing-quic v0.1.5-0.20231123150216-00957d136203
Expand Down
4 changes: 2 additions & 2 deletions go.sum
Original file line number Diff line number Diff line change
Expand Up @@ -110,8 +110,8 @@ github.com/sagernet/reality v0.0.0-20230406110435-ee17307e7691 h1:5Th31OC6yj8byL
github.com/sagernet/reality v0.0.0-20230406110435-ee17307e7691/go.mod h1:B8lp4WkQ1PwNnrVMM6KyuFR20pU8jYBD+A4EhJovEXU=
github.com/sagernet/sing v0.0.0-20220817130738-ce854cda8522/go.mod h1:QVsS5L/ZA2Q5UhQwLrn0Trw+msNd/NPGEhBKR/ioWiY=
github.com/sagernet/sing v0.1.8/go.mod h1:jt1w2u7lJQFFSGLiRrRIs5YWmx4kAPfWuOejuDW9qMk=
github.com/sagernet/sing v0.2.18-0.20231129075305-eb56a60214be h1:FigAM9kq7RRXmHvgn8w2a8tqCY5CMV5GIk0id84dI0o=
github.com/sagernet/sing v0.2.18-0.20231129075305-eb56a60214be/go.mod h1:OL6k2F0vHmEzXz2KW19qQzu172FDgSbUSODylighuVo=
github.com/sagernet/sing v0.2.18-0.20231130082037-cac27afa2a18 h1:R9A7AV+YKh/uVQkfjFZ9xJ7vH2hxYHoOc4FnIReONY8=
github.com/sagernet/sing v0.2.18-0.20231130082037-cac27afa2a18/go.mod h1:OL6k2F0vHmEzXz2KW19qQzu172FDgSbUSODylighuVo=
github.com/sagernet/sing-dns v0.1.11 h1:PPrMCVVrAeR3f5X23I+cmvacXJ+kzuyAsBiWyUKhGSE=
github.com/sagernet/sing-dns v0.1.11/go.mod h1:zJ/YjnYB61SYE+ubMcMqVdpaSvsyQ2iShQGO3vuLvvE=
github.com/sagernet/sing-mux v0.1.5-0.20231109075101-6b086ed6bb07 h1:ncKb5tVOsCQgCsv6UpsA0jinbNb5OQ5GMPJlyQP3EHM=
Expand Down
21 changes: 17 additions & 4 deletions route/router.go
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,7 @@ import (
M "github.com/sagernet/sing/common/metadata"
N "github.com/sagernet/sing/common/network"
serviceNTP "github.com/sagernet/sing/common/ntp"
"github.com/sagernet/sing/common/task"
"github.com/sagernet/sing/common/uot"
"github.com/sagernet/sing/service"
"github.com/sagernet/sing/service/pause"
Expand Down Expand Up @@ -490,12 +491,24 @@ func (r *Router) Start() error {
if r.needWIFIState {
r.updateWIFIState()
}
ruleSetStartContext := NewRuleSetStartContext()
var ruleSetStartGroup task.Group
for i, ruleSet := range r.ruleSets {
err := ruleSet.Start()
if err != nil {
return E.Cause(err, "initialize rule-set[", i, "]")
}
ruleSetStartGroup.Append0(func(ctx context.Context) error {
err := ruleSet.StartContext(ctx, ruleSetStartContext)
if err != nil {
return E.Cause(err, "initialize rule-set[", i, "]")
}
return nil
})
}
ruleSetStartGroup.Concurrency(5)
ruleSetStartGroup.FastFail()
err := ruleSetStartGroup.Run(r.ctx)
if err != nil {
return err
}
ruleSetStartContext.Close()
for i, rule := range r.rules {
err := rule.Start()
if err != nil {
Expand Down
45 changes: 45 additions & 0 deletions route/rule_set.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,12 +2,17 @@ package route

import (
"context"
"net"
"net/http"
"sync"

"github.com/sagernet/sing-box/adapter"
C "github.com/sagernet/sing-box/constant"
"github.com/sagernet/sing-box/option"
E "github.com/sagernet/sing/common/exceptions"
"github.com/sagernet/sing/common/logger"
M "github.com/sagernet/sing/common/metadata"
N "github.com/sagernet/sing/common/network"
)

func NewRuleSet(ctx context.Context, router adapter.Router, logger logger.ContextLogger, options option.RuleSet) (adapter.RuleSet, error) {
Expand All @@ -20,3 +25,43 @@ func NewRuleSet(ctx context.Context, router adapter.Router, logger logger.Contex
return nil, E.New("unknown rule set type: ", options.Type)
}
}

var _ adapter.RuleSetStartContext = (*RuleSetStartContext)(nil)

type RuleSetStartContext struct {
access sync.Mutex
httpClientCache map[string]*http.Client
}

func NewRuleSetStartContext() *RuleSetStartContext {
return &RuleSetStartContext{
httpClientCache: make(map[string]*http.Client),
}
}

func (c *RuleSetStartContext) HTTPClient(detour string, dialer N.Dialer) *http.Client {
c.access.Lock()
defer c.access.Unlock()
if httpClient, loaded := c.httpClientCache[detour]; loaded {
return httpClient
}
httpClient := &http.Client{
Transport: &http.Transport{
ForceAttemptHTTP2: true,
TLSHandshakeTimeout: C.TCPTimeout,
DialContext: func(ctx context.Context, network, addr string) (net.Conn, error) {
return dialer.DialContext(ctx, network, M.ParseSocksaddr(addr))
},
},
}
c.httpClientCache[detour] = httpClient
return httpClient
}

func (c *RuleSetStartContext) Close() {
c.access.Lock()
defer c.access.Unlock()
for _, client := range c.httpClientCache {
client.CloseIdleConnections()
}
}
3 changes: 2 additions & 1 deletion route/rule_set_local.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package route

import (
"context"
"os"

"github.com/sagernet/sing-box/adapter"
Expand Down Expand Up @@ -60,7 +61,7 @@ func (s *LocalRuleSet) Match(metadata *adapter.InboundContext) bool {
return false
}

func (s *LocalRuleSet) Start() error {
func (s *LocalRuleSet) StartContext(ctx context.Context, startContext adapter.RuleSetStartContext) error {
return nil
}

Expand Down
30 changes: 17 additions & 13 deletions route/rule_set_remote.go
Original file line number Diff line number Diff line change
Expand Up @@ -63,7 +63,7 @@ func (s *RemoteRuleSet) Match(metadata *adapter.InboundContext) bool {
return false
}

func (s *RemoteRuleSet) Start() error {
func (s *RemoteRuleSet) StartContext(ctx context.Context, startContext adapter.RuleSetStartContext) error {
var dialer N.Dialer
if s.options.RemoteOptions.DownloadDetour != "" {
outbound, loaded := s.router.Outbound(s.options.RemoteOptions.DownloadDetour)
Expand Down Expand Up @@ -91,7 +91,7 @@ func (s *RemoteRuleSet) Start() error {
}
}
if s.lastUpdated.IsZero() || time.Since(s.lastUpdated) > s.updateInterval {
err := s.fetchOnce()
err := s.fetchOnce(ctx, startContext)
if err != nil {
return E.Cause(err, "fetch rule-set ", s.options.Tag)
}
Expand Down Expand Up @@ -141,34 +141,38 @@ func (s *RemoteRuleSet) loopUpdate() {
case <-s.ctx.Done():
return
case <-s.updateTicker.C:
err := s.fetchOnce()
err := s.fetchOnce(s.ctx, nil)
if err != nil {
s.logger.Error("fetch rule-set ", s.options.Tag, ": ", err)
}
}
}
}

func (s *RemoteRuleSet) fetchOnce() error {
func (s *RemoteRuleSet) fetchOnce(ctx context.Context, startContext adapter.RuleSetStartContext) error {
s.logger.Debug("updating rule-set ", s.options.Tag, " from URL: ", s.options.RemoteOptions.URL)
httpClient := &http.Client{
Transport: &http.Transport{
ForceAttemptHTTP2: true,
TLSHandshakeTimeout: C.TCPTimeout,
DialContext: func(ctx context.Context, network, addr string) (net.Conn, error) {
return s.dialer.DialContext(ctx, network, M.ParseSocksaddr(addr))
var httpClient *http.Client
if startContext != nil {
httpClient = startContext.HTTPClient(s.options.RemoteOptions.DownloadDetour, s.dialer)
} else {
httpClient = &http.Client{
Transport: &http.Transport{
ForceAttemptHTTP2: true,
TLSHandshakeTimeout: C.TCPTimeout,
DialContext: func(ctx context.Context, network, addr string) (net.Conn, error) {
return s.dialer.DialContext(ctx, network, M.ParseSocksaddr(addr))
},
},
},
}
}
defer httpClient.CloseIdleConnections()
request, err := http.NewRequest("GET", s.options.RemoteOptions.URL, nil)
if err != nil {
return err
}
if s.lastEtag != "" {
request.Header.Set("If-None-Match", s.lastEtag)
}
response, err := httpClient.Do(request.WithContext(s.ctx))
response, err := httpClient.Do(request.WithContext(ctx))
if err != nil {
return err
}
Expand Down

0 comments on commit 889f426

Please sign in to comment.