Skip to content

Commit

Permalink
use noise prologue mechanism
Browse files Browse the repository at this point in the history
  • Loading branch information
ckousik committed Aug 15, 2022
1 parent f8025fb commit 71f607f
Show file tree
Hide file tree
Showing 4 changed files with 67 additions and 77 deletions.
3 changes: 2 additions & 1 deletion p2p/transport/webrtc/connection.go
Original file line number Diff line number Diff line change
Expand Up @@ -71,8 +71,8 @@ func newConnection(

accept: accept,
}

pc.OnDataChannel(func(dc *webrtc.DataChannel) {
log.Debugf("[%s] incoming datachannel: %s", localPeer, dc.Label())
dc.OnOpen(func() {
dcrwc, err := dc.Detach()
if err != nil {
Expand Down Expand Up @@ -131,6 +131,7 @@ func (c *connection) OpenStream(ctx context.Context) (network.MuxedStream, error
error
}, 1)
dc.OnOpen(func() {
log.Debugf("[%s] opened new datachannel: %s", c.localPeer, dc.Label())
rwc, err := dc.Detach()
if err != nil {
result <- struct {
Expand Down
44 changes: 2 additions & 42 deletions p2p/transport/webrtc/listener.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@ package libp2pwebrtc

import (
"context"
"crypto/x509"
"encoding/hex"
"errors"
"fmt"
Expand All @@ -20,7 +19,6 @@ import (
"github.com/multiformats/go-multibase"
"github.com/multiformats/go-multihash"

"github.com/pion/dtls/v2/pkg/crypto/fingerprint"
"github.com/pion/ice/v2"
"github.com/pion/webrtc/v3"
)
Expand Down Expand Up @@ -55,7 +53,7 @@ func init() {

}

/// implement net.Listener
// / implement net.Listener
type listener struct {
transport *WebRTCTransport
config webrtc.Configuration
Expand Down Expand Up @@ -83,7 +81,6 @@ func newListener(transport *WebRTCTransport, laddr ma.Multiaddr, socket net.Pack
localMhBuf, _ := multihash.EncodeName(localMh, sdpHashToMh(localFingerprints[0].Algorithm))
localFpMultibase, _ := multibase.Encode(multibase.Base64, localMhBuf)

var wg sync.WaitGroup
ctx, cancel := context.WithCancel(context.Background())

l := &listener{
Expand All @@ -96,10 +93,9 @@ func newListener(transport *WebRTCTransport, laddr ma.Multiaddr, socket net.Pack
ctx: ctx,
cancel: cancel,
connChan: make(chan tpt.CapableConn, 20),
wg: wg,
}

wg.Add(1)
l.wg.Add(1)
go l.startAcceptLoop()
return l, err
}
Expand Down Expand Up @@ -302,39 +298,3 @@ func (l *listener) accept(ctx context.Context, addr candidateAddr) (tpt.CapableC

return conn, nil
}

func verifyRemoteFingerprint(raw []byte, remoteMultibaseMultihash string) bool {
cert, err := x509.ParseCertificate(raw)
if err != nil {
log.Debugf("could not parse certificate: %v", err)
return false
}

_, remoteData, err := multibase.Decode(remoteMultibaseMultihash)
if err != nil {
log.Debugf("could not decode multibase: %v", err)
return false
}
decoded, err := multihash.Decode(remoteData)
if err != nil {
log.Debugf("could not decode multihash: %v", err)
return false
}
remoteFingerprint := hex.EncodeToString(decoded.Digest)
remoteFingerprint = maFingerprintToSdp(remoteFingerprint)

// create fingerprint for remote certificate
hashAlgoName := mhToSdpHash(decoded.Name)
if hashAlgoName == "" {
hashAlgoName = decoded.Name
}

hashAlgo, err := fingerprint.HashFromString(hashAlgoName)
if err != nil {
log.Debugf("could not find hash algo: %s %v", hashAlgoName, err)
return false
}
fp, err := fingerprint.Fingerprint(cert, hashAlgo)

return strings.EqualFold(fp, remoteFingerprint)
}
92 changes: 58 additions & 34 deletions p2p/transport/webrtc/transport.go
Original file line number Diff line number Diff line change
@@ -1,13 +1,17 @@
package libp2pwebrtc

import (
"bytes"
"context"
"crypto/ecdsa"
"crypto/elliptic"
"crypto/rand"
"crypto/x509"
"encoding/hex"
"fmt"
"net"
"sort"
"strings"

"github.com/google/uuid"
ic "github.com/libp2p/go-libp2p-core/crypto"
Expand All @@ -21,7 +25,9 @@ import (
ma "github.com/multiformats/go-multiaddr"
mafmt "github.com/multiformats/go-multiaddr-fmt"
manet "github.com/multiformats/go-multiaddr/net"
"github.com/multiformats/go-multihash"

"github.com/pion/dtls/v2/pkg/crypto/fingerprint"
"github.com/pion/webrtc/v3"
)

Expand Down Expand Up @@ -293,54 +299,72 @@ func (t *WebRTCTransport) getCertificateFingerprint() (webrtc.DTLSFingerprint, e
return fps[0], nil
}

func (t *WebRTCTransport) noiseHandshake(ctx context.Context, pc *webrtc.PeerConnection, datachannel *dataChannel, peer peer.ID, inbound bool) (secureConn sec.SecureConn, err error) {
if inbound {
secureConn, err = t.noiseTpt.SecureInbound(ctx, datachannel, peer)
if err != nil {
return
}
} else {
secureConn, err = t.noiseTpt.SecureOutbound(ctx, datachannel, peer)
if err != nil {
return
}
func (t *WebRTCTransport) generateNoisePrologue(pc *webrtc.PeerConnection) ([]byte, error) {
raw := pc.SCTP().Transport().GetRemoteCertificate()
cert, err := x509.ParseCertificate(raw)
if err != nil {
return nil, err
}
localFingerprint, err := t.getCertificateFingerprint()
// guess hash algorithm
localFp, err := t.getCertificateFingerprint()
if err != nil {
return
return nil, err
}
encodedMultibase, err := encodeDTLSFingerprint(localFingerprint)

hashAlgo, err := fingerprint.HashFromString(localFp.Algorithm)
if err != nil {
return
log.Debugf("could not find hash algo: %s %v", localFp.Algorithm, err)
return nil, err
}
remoteFp, err := fingerprint.Fingerprint(cert, hashAlgo)
remoteFp = strings.ToLower(remoteFp)

_, err = secureConn.Write([]byte(encodedMultibase))
mhAlgoName := sdpHashToMh(localFp.Algorithm)
if mhAlgoName == "" {
mhAlgoName = localFp.Algorithm
}

local := strings.ReplaceAll(localFp.Value, ":", "")
remote := strings.ReplaceAll(remoteFp, ":", "")
localEncoded, err := multihash.EncodeName([]byte(local), mhAlgoName)
if err != nil {
return
log.Debugf("could not encode multihash for local fingerprint")
return nil, err
}
remoteEncoded, err := multihash.EncodeName([]byte(remote), mhAlgoName)
if err != nil {
log.Debugf("could not encode multihash for remote fingerprint")
return nil, err
}

b := [][]byte{localEncoded, remoteEncoded}
sort.Slice(b, func(i, j int) bool {
return bytes.Compare(b[i], b[j]) < 0
})
result := append([]byte("libp2p-webrtc-noise:"), b[0]...)
result = append(result, b[1]...)
return result, nil
}

done := make(chan error, 1)
go func() {
buf := make([]byte, 2048)
n, err := secureConn.Read(buf)
func (t *WebRTCTransport) noiseHandshake(ctx context.Context, pc *webrtc.PeerConnection, datachannel *dataChannel, peer peer.ID, inbound bool) (secureConn sec.SecureConn, err error) {
prologue, err := t.generateNoisePrologue(pc)
if err != nil {
return nil, fmt.Errorf("could not generate noise prologue: %v", err)
}
sessionTransport, err := t.noiseTpt.WithSessionOptions(noise.Prologue(prologue))
if err != nil {
return nil, fmt.Errorf("could not instantiate noise session transport: %v", err)
}
if inbound {
secureConn, err = sessionTransport.SecureInbound(ctx, datachannel, peer)
if err != nil {
done <- err
}
remoteFpMultibase := string(buf[:n])
if !verifyRemoteFingerprint(pc.SCTP().Transport().GetRemoteCertificate(), remoteFpMultibase) {
done <- fmt.Errorf("could not verify remote fingerprint")
return
}
close(done)
}()

select {
case err = <-done:
} else {
secureConn, err = sessionTransport.SecureOutbound(ctx, datachannel, peer)
if err != nil {
err = fmt.Errorf("dialed: read failed: %v", err)
return
}
case <-ctx.Done():
return nil, ErrNoiseHandshakeTimeout
}
return secureConn, nil
}
5 changes: 5 additions & 0 deletions p2p/transport/webrtc/transport_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ package libp2pwebrtc
import (
"context"
"testing"
"time"

"github.com/libp2p/go-libp2p-core/crypto"
"github.com/libp2p/go-libp2p-core/network"
Expand Down Expand Up @@ -136,7 +137,10 @@ func TestTransportDialerCanCreateStreams(t *testing.T) {
go func() {
conn, err := tr1.Dial(context.Background(), listener.Multiaddr(), listeningPeer)
require.NoError(t, err)
t.Log("dialer creating stream")
time.Sleep(100 * time.Millisecond)
stream, err := conn.OpenStream(context.Background())
t.Log("dialer created stream")
require.NoError(t, err)
_, err = stream.Write([]byte("test"))
require.NoError(t, err)
Expand All @@ -145,6 +149,7 @@ func TestTransportDialerCanCreateStreams(t *testing.T) {
lconn, err := listener.Accept()
require.NoError(t, err)
require.Equal(t, connectingPeer, lconn.RemotePeer())
t.Log("accepted connection")

stream, err := lconn.AcceptStream()
require.NoError(t, err)
Expand Down

0 comments on commit 71f607f

Please sign in to comment.