Skip to content

Commit

Permalink
Add DTLS Handshake hooks to SettingEngine
Browse files Browse the repository at this point in the history
  • Loading branch information
theodorsm authored and Sean-Der committed Aug 19, 2024
1 parent 4a97b7d commit 18e934e
Show file tree
Hide file tree
Showing 3 changed files with 67 additions and 11 deletions.
3 changes: 3 additions & 0 deletions dtlstransport.go
Original file line number Diff line number Diff line change
Expand Up @@ -355,6 +355,9 @@ func (t *DTLSTransport) Start(remoteParameters DTLSParameters) error { //nolint:
dtlsConfig.ClientCAs = t.api.settingEngine.dtls.clientCAs
dtlsConfig.RootCAs = t.api.settingEngine.dtls.rootCAs
dtlsConfig.KeyLogWriter = t.api.settingEngine.dtls.keyLogWriter
dtlsConfig.ClientHelloMessageHook = t.api.settingEngine.dtls.clientHelloMessageHook
dtlsConfig.ServerHelloMessageHook = t.api.settingEngine.dtls.serverHelloMessageHook
dtlsConfig.CertificateRequestMessageHook = t.api.settingEngine.dtls.certificateRequestMessageHook

// Connect as DTLS Client/Server, function is blocking and we
// must not hold the DTLSTransport lock
Expand Down
44 changes: 33 additions & 11 deletions settingengine.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@ import (

"github.com/pion/dtls/v3"
dtlsElliptic "github.com/pion/dtls/v3/pkg/crypto/elliptic"
"github.com/pion/dtls/v3/pkg/protocol/handshake"
"github.com/pion/ice/v4"
"github.com/pion/logging"
"github.com/pion/stun/v3"
Expand Down Expand Up @@ -63,17 +64,20 @@ type SettingEngine struct {
SRTCP *uint
}
dtls struct {
insecureSkipHelloVerify bool
disableInsecureSkipVerify bool
retransmissionInterval time.Duration
ellipticCurves []dtlsElliptic.Curve
connectContextMaker func() (context.Context, func())
extendedMasterSecret dtls.ExtendedMasterSecretType
clientAuth *dtls.ClientAuthType
clientCAs *x509.CertPool
rootCAs *x509.CertPool
keyLogWriter io.Writer
customCipherSuites func() []dtls.CipherSuite
insecureSkipHelloVerify bool
disableInsecureSkipVerify bool
retransmissionInterval time.Duration
ellipticCurves []dtlsElliptic.Curve
connectContextMaker func() (context.Context, func())
extendedMasterSecret dtls.ExtendedMasterSecretType
clientAuth *dtls.ClientAuthType
clientCAs *x509.CertPool
rootCAs *x509.CertPool
keyLogWriter io.Writer
customCipherSuites func() []dtls.CipherSuite
clientHelloMessageHook func(handshake.MessageClientHello) handshake.Message
serverHelloMessageHook func(handshake.MessageServerHello) handshake.Message
certificateRequestMessageHook func(handshake.MessageCertificateRequest) handshake.Message
}
sctp struct {
maxReceiveBufferSize uint32
Expand Down Expand Up @@ -455,6 +459,24 @@ func (e *SettingEngine) SetDTLSCustomerCipherSuites(customCipherSuites func() []
e.dtls.customCipherSuites = customCipherSuites
}

// SetDTLSClientHelloMessageHook if not nil, is called when a DTLS Client Hello message is sent
// from a client. The returned handshake message replaces the original message.
func (e *SettingEngine) SetDTLSClientHelloMessageHook(hook func(handshake.MessageClientHello) handshake.Message) {
e.dtls.clientHelloMessageHook = hook
}

// SetDTLSServerHelloMessageHook if not nil, is called when a DTLS Server Hello message is sent
// from a client. The returned handshake message replaces the original message.
func (e *SettingEngine) SetDTLSServerHelloMessageHook(hook func(handshake.MessageServerHello) handshake.Message) {
e.dtls.serverHelloMessageHook = hook
}

// SetDTLSCertificateRequestMessageHook if not nil, is called when a DTLS Certificate Request message is sent
// from a client. The returned handshake message replaces the original message.
func (e *SettingEngine) SetDTLSCertificateRequestMessageHook(hook func(handshake.MessageCertificateRequest) handshake.Message) {
e.dtls.certificateRequestMessageHook = hook
}

// SetSCTPRTOMax sets the maximum retransmission timeout.
// Leave this 0 for the default timeout.
func (e *SettingEngine) SetSCTPRTOMax(rtoMax time.Duration) {
Expand Down
31 changes: 31 additions & 0 deletions settingengine_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@ import (
"time"

"github.com/pion/dtls/v3/pkg/crypto/elliptic"
"github.com/pion/dtls/v3/pkg/protocol/handshake"
"github.com/pion/ice/v4"
"github.com/pion/stun/v3"
"github.com/pion/transport/v3/test"
Expand Down Expand Up @@ -309,3 +310,33 @@ func TestSetICEBindingRequestHandler(t *testing.T) {
<-seenICEControlling.Done()
closePairNow(t, pcOffer, pcAnswer)
}

func TestSetHooks(t *testing.T) {
s := SettingEngine{}

if s.dtls.clientHelloMessageHook != nil ||
s.dtls.serverHelloMessageHook != nil ||
s.dtls.certificateRequestMessageHook != nil {
t.Fatalf("SettingEngine defaults aren't as expected.")
}

s.SetDTLSClientHelloMessageHook(func(msg handshake.MessageClientHello) handshake.Message {
return &msg
})
s.SetDTLSServerHelloMessageHook(func(msg handshake.MessageServerHello) handshake.Message {
return &msg
})
s.SetDTLSCertificateRequestMessageHook(func(msg handshake.MessageCertificateRequest) handshake.Message {
return &msg
})

if s.dtls.clientHelloMessageHook == nil {
t.Errorf("Failed to set DTLS Client Hello Hook")
}
if s.dtls.serverHelloMessageHook == nil {
t.Errorf("Failed to set DTLS Server Hello Hook")
}
if s.dtls.certificateRequestMessageHook == nil {
t.Errorf("Failed to set DTLS Certificate Request Hook")
}
}

0 comments on commit 18e934e

Please sign in to comment.