Skip to content

Commit

Permalink
optimize the tcp healthcheck to reduce the thread usage
Browse files Browse the repository at this point in the history
  • Loading branch information
zhangbo1882 committed Mar 21, 2022
1 parent 6ee61ce commit 779973b
Show file tree
Hide file tree
Showing 2 changed files with 82 additions and 144 deletions.
131 changes: 19 additions & 112 deletions healthcheck/dial.go
Original file line number Diff line number Diff line change
Expand Up @@ -20,143 +20,50 @@ package healthcheck

import (
"errors"
"fmt"
"net"
"os"
"strconv"
"syscall"
"time"
)

type conn struct {
fd int
f *os.File
net.Conn
mark int
}

func (c *conn) Close() error {
if c.Conn != nil {
c.Conn.Close()
}
if c.f != nil {
err := c.f.Close()
c.fd, c.f = -1, nil
return err
}
if c.fd != -1 {
err := syscall.Close(c.fd)
c.fd = -1
return err
return c.Conn.Close()
}
return nil
}

func sockaddrToString(sa syscall.Sockaddr) string {
switch sa := sa.(type) {
case *syscall.SockaddrInet4:
return net.JoinHostPort(net.IP(sa.Addr[:]).String(), strconv.Itoa(sa.Port))
case *syscall.SockaddrInet6:
return net.JoinHostPort(net.IP(sa.Addr[:]).String(), strconv.Itoa(sa.Port))
default:
return fmt.Sprintf("(unknown - %T)", sa)
func (c *conn) control(network, address string, rawc syscall.RawConn) error {
var fdErr error
ctl := func(fd uintptr) {
if c.mark != 0 {
fdErr = setSocketMark(int(fd), c.mark)
}
}
if err := rawc.Control(ctl); err != nil {
return err
}
return fdErr
}

// dialTCP dials a TCP connection to the specified host and sets marking on the
// socket. The host must be given as an IP address. A mark of zero results in a
// normal (non-marked) connection.
func dialTCP(network, addr string, timeout time.Duration, mark int) (nc net.Conn, err error) {
host, port, err := net.SplitHostPort(addr)
if err != nil {
return nil, err
}
ip := net.ParseIP(host)
if ip == nil {
return nil, fmt.Errorf("invalid IP address %q", host)
c := &conn{
mark: mark,
}
p, err := strconv.ParseUint(port, 10, 16)
if err != nil {
return nil, fmt.Errorf("invalid port number %q", port)
}

var domain int
var rsa syscall.Sockaddr
switch network {
case "tcp4":
domain = syscall.AF_INET
if ip.To4() == nil {
return nil, fmt.Errorf("invalid IPv4 address %q", host)
}
sa := &syscall.SockaddrInet4{Port: int(p)}
copy(sa.Addr[:], ip.To4())
rsa = sa

case "tcp6":
domain = syscall.AF_INET6
if ip.To4() != nil {
return nil, fmt.Errorf("invalid IPv6 address %q", host)
}
sa := &syscall.SockaddrInet6{Port: int(p)}
copy(sa.Addr[:], ip.To16())
rsa = sa

default:
return nil, fmt.Errorf("unsupported network %q", network)
dial := net.Dialer{
Timeout: timeout,
Control: c.control,
}

c := &conn{}

defer func() {
if err != nil {
c.Close()
}
}()

c.fd, err = syscall.Socket(domain, syscall.SOCK_STREAM|syscall.SOCK_CLOEXEC, 0)
if err != nil {
return nil, os.NewSyscallError("socket", err)
}

if mark != 0 {
if err := setSocketMark(c.fd, mark); err != nil {
return nil, err
}
}

if err := setSocketTimeout(c.fd, timeout); err != nil {
return nil, err
}
for {
err := syscall.Connect(c.fd, rsa)
if err == nil {
break
}
// Blocking socket connect may be interrupted with EINTR
if err != syscall.EINTR {
return nil, os.NewSyscallError("connect", err)
}
}
if err := setSocketTimeout(c.fd, 0); err != nil {
return nil, err
}

lsa, _ := syscall.Getsockname(c.fd)
rsa, _ = syscall.Getpeername(c.fd)
name := fmt.Sprintf("%s %s -> %s", network, sockaddrToString(lsa), sockaddrToString(rsa))
c.f = os.NewFile(uintptr(c.fd), name)

// When we call net.FileConn the socket will be made non-blocking and
// we will get a *net.TCPConn in return. The *os.File needs to be
// closed in addition to the *net.TCPConn when we're done (conn.Close
// takes care of that for us).
if c.Conn, err = net.FileConn(c.f); err != nil {
return nil, err
}
if _, ok := c.Conn.(*net.TCPConn); !ok {
return nil, fmt.Errorf("%T is not a *net.TCPConn", c.Conn)
}

return c, nil
c.Conn, err = dial.Dial(network, addr)
return c, err
}

// dialUDP dials a UDP connection to the specified host and sets marking on the
Expand Down
95 changes: 63 additions & 32 deletions test_tools/healthcheck_test_tool/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,9 @@
package main

import (
"errors"
"flag"
"fmt"
"log"
"net"
"strconv"
Expand All @@ -42,6 +44,7 @@ var (
response = flag.String("response", "", "expected HTTP(S) response")
responseCode = flag.Int("response_code", 200, "expected HTTP(S) response code")
tlsVerify = flag.Bool("tls_verify", true, "enable TLS verification for HTTPS and TCP TLS")
parallel = flag.Int("parallel", 1, "concurrent goroutines")

dnsAnswer = flag.String("answer", "", "DNS answer expected from query")
dnsQuery = flag.String("query", "", "DNS query to perform")
Expand All @@ -54,13 +57,15 @@ var (
timeout = flag.Duration("timeout", 0, "healthcheck timeout")
)

func check(hc healthcheck.Checker) {
func check(hc healthcheck.Checker) error {
r := hc.Check(*timeout)
s := "success"
if !r.Success {
s = "failure"
return fmt.Errorf("%v - %v (healthcheck %s)", hc, r, s)
}
log.Printf("%v - %v (healthcheck %s)", hc, r, s)
return nil
}

func unquote(s string) string {
Expand All @@ -74,7 +79,7 @@ func unquote(s string) string {
return us
}

func doDNSCheck(target net.IP) {
func doDNSCheck(target net.IP) error {
qt, err := healthcheck.DNSType(*dnsQueryType)
if err != nil {
log.Fatal(err)
Expand All @@ -84,10 +89,10 @@ func doDNSCheck(target net.IP) {
hc.Answer = *dnsAnswer
hc.Question.Name = *dnsQuery
hc.Question.Qtype = qt
check(hc)
return check(hc)
}

func doHTTPCheck(target net.IP, secure bool) {
func doHTTPCheck(target net.IP, secure bool) error {
hc := healthcheck.NewHTTPChecker(target, *port)
hc.Mark = *mark
hc.Secure = secure
Expand All @@ -97,10 +102,10 @@ func doHTTPCheck(target net.IP, secure bool) {
hc.Method = *method
hc.Proxy = *proxy
hc.TLSVerify = *tlsVerify
check(hc)
return check(hc)
}

func doPingCheck(target net.IP) {
func doPingCheck(target net.IP) error {
pc := healthcheck.NewPingChecker(target)
pc.Mark = *mark
received := 0
Expand All @@ -113,35 +118,40 @@ func doPingCheck(target net.IP) {
received++
log.Printf("Received reply from %v in %v", target, r.Duration)
}
log.Printf("Sent %d packets, received %d replies", *count, received)
msg := fmt.Sprintf("Sent %d packets, received %d replies", *count, received)
if *count != received {
return errors.New(msg)
}
log.Print(msg)
return nil
}

func doRADIUSCheck(target net.IP) {
func doRADIUSCheck(target net.IP) error {
hc := healthcheck.NewRADIUSChecker(target, *port)
hc.Mark = *mark
hc.Username = *radiusUser
hc.Password = *radiusPasswd
hc.Response = *radiusResponse
hc.Secret = *radiusSecret
check(hc)
return check(hc)
}

func doTCPCheck(target net.IP, secure bool) {
func doTCPCheck(target net.IP, secure bool) error {
hc := healthcheck.NewTCPChecker(target, *port)
hc.Mark = *mark
hc.Receive = unquote(*receive)
hc.Send = unquote(*send)
hc.Secure = secure
hc.TLSVerify = *tlsVerify
check(hc)
return check(hc)
}

func doUDPCheck(target net.IP) {
func doUDPCheck(target net.IP) error {
hc := healthcheck.NewUDPChecker(target, *port)
hc.Mark = *mark
hc.Receive = unquote(*receive)
hc.Send = unquote(*send)
check(hc)
return check(hc)
}

func main() {
Expand All @@ -150,25 +160,46 @@ func main() {
if target == nil {
log.Fatalf("Invalid IP address: %v", *ip)
}
if *parallel < 1 {
log.Fatalf("Invalid value for parallel: %v", *parallel)
}

switch *hcType {
case "dns":
doDNSCheck(target)
case "http":
doHTTPCheck(target, false)
case "https":
doHTTPCheck(target, true)
case "ping":
doPingCheck(target)
case "radius":
doRADIUSCheck(target)
case "tcp":
doTCPCheck(target, false)
case "tcp_tls":
doTCPCheck(target, true)
case "udp":
doUDPCheck(target)
default:
log.Fatalf("Unsupported healthcheck type: %q", *hcType)
errs := make(chan error, *parallel)

for i := 0; i < *parallel; i++ {
go func(err chan error) {
switch *hcType {
case "dns":
err <- doDNSCheck(target)
case "http":
err <- doHTTPCheck(target, false)
case "https":
err <- doHTTPCheck(target, true)
case "ping":
err <- doPingCheck(target)
case "radius":
err <- doRADIUSCheck(target)
case "tcp":
err <- doTCPCheck(target, false)
case "tcp_tls":
err <- doTCPCheck(target, true)
case "udp":
err <- doUDPCheck(target)
default:
log.Fatalf("Unsupported healthcheck type: %q", *hcType)
}
}(errs)
}
fail := 0
for i := 0; i < *parallel; i++ {
select {
case err := <-errs:
if err != nil {
log.Printf("Error message: %v", err)
fail++
}
}
}
log.Printf("Test done. %d goroutines. success: %d, fail: %d", *parallel, *parallel-fail, fail)
return
}

0 comments on commit 779973b

Please sign in to comment.