Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add DTLS handshake hooks to SettingEngine #2882

Merged
merged 1 commit into from
Aug 19, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions dtlstransport.go
Original file line number Diff line number Diff line change
Expand Up @@ -355,6 +355,9 @@ func (t *DTLSTransport) Start(remoteParameters DTLSParameters) error { //nolint:
dtlsConfig.ClientCAs = t.api.settingEngine.dtls.clientCAs
dtlsConfig.RootCAs = t.api.settingEngine.dtls.rootCAs
dtlsConfig.KeyLogWriter = t.api.settingEngine.dtls.keyLogWriter
dtlsConfig.ClientHelloMessageHook = t.api.settingEngine.dtls.clientHelloMessageHook
dtlsConfig.ServerHelloMessageHook = t.api.settingEngine.dtls.serverHelloMessageHook
dtlsConfig.CertificateRequestMessageHook = t.api.settingEngine.dtls.certificateRequestMessageHook

// Connect as DTLS Client/Server, function is blocking and we
// must not hold the DTLSTransport lock
Expand Down
44 changes: 33 additions & 11 deletions settingengine.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@ import (

"github.com/pion/dtls/v3"
dtlsElliptic "github.com/pion/dtls/v3/pkg/crypto/elliptic"
"github.com/pion/dtls/v3/pkg/protocol/handshake"
"github.com/pion/ice/v4"
"github.com/pion/logging"
"github.com/pion/stun/v3"
Expand Down Expand Up @@ -63,17 +64,20 @@ type SettingEngine struct {
SRTCP *uint
}
dtls struct {
insecureSkipHelloVerify bool
disableInsecureSkipVerify bool
retransmissionInterval time.Duration
ellipticCurves []dtlsElliptic.Curve
connectContextMaker func() (context.Context, func())
extendedMasterSecret dtls.ExtendedMasterSecretType
clientAuth *dtls.ClientAuthType
clientCAs *x509.CertPool
rootCAs *x509.CertPool
keyLogWriter io.Writer
customCipherSuites func() []dtls.CipherSuite
insecureSkipHelloVerify bool
disableInsecureSkipVerify bool
retransmissionInterval time.Duration
ellipticCurves []dtlsElliptic.Curve
connectContextMaker func() (context.Context, func())
extendedMasterSecret dtls.ExtendedMasterSecretType
clientAuth *dtls.ClientAuthType
clientCAs *x509.CertPool
rootCAs *x509.CertPool
keyLogWriter io.Writer
customCipherSuites func() []dtls.CipherSuite
clientHelloMessageHook func(handshake.MessageClientHello) handshake.Message
serverHelloMessageHook func(handshake.MessageServerHello) handshake.Message
certificateRequestMessageHook func(handshake.MessageCertificateRequest) handshake.Message
}
sctp struct {
maxReceiveBufferSize uint32
Expand Down Expand Up @@ -455,6 +459,24 @@ func (e *SettingEngine) SetDTLSCustomerCipherSuites(customCipherSuites func() []
e.dtls.customCipherSuites = customCipherSuites
}

// SetDTLSClientHelloMessageHook if not nil, is called when a DTLS Client Hello message is sent
// from a client. The returned handshake message replaces the original message.
func (e *SettingEngine) SetDTLSClientHelloMessageHook(hook func(handshake.MessageClientHello) handshake.Message) {
e.dtls.clientHelloMessageHook = hook
}

// SetDTLSServerHelloMessageHook if not nil, is called when a DTLS Server Hello message is sent
// from a client. The returned handshake message replaces the original message.
func (e *SettingEngine) SetDTLSServerHelloMessageHook(hook func(handshake.MessageServerHello) handshake.Message) {
e.dtls.serverHelloMessageHook = hook
}

// SetDTLSCertificateRequestMessageHook if not nil, is called when a DTLS Certificate Request message is sent
// from a client. The returned handshake message replaces the original message.
func (e *SettingEngine) SetDTLSCertificateRequestMessageHook(hook func(handshake.MessageCertificateRequest) handshake.Message) {
e.dtls.certificateRequestMessageHook = hook
}

// SetSCTPRTOMax sets the maximum retransmission timeout.
// Leave this 0 for the default timeout.
func (e *SettingEngine) SetSCTPRTOMax(rtoMax time.Duration) {
Expand Down
31 changes: 31 additions & 0 deletions settingengine_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@ import (
"time"

"github.com/pion/dtls/v3/pkg/crypto/elliptic"
"github.com/pion/dtls/v3/pkg/protocol/handshake"
"github.com/pion/ice/v4"
"github.com/pion/stun/v3"
"github.com/pion/transport/v3/test"
Expand Down Expand Up @@ -309,3 +310,33 @@ func TestSetICEBindingRequestHandler(t *testing.T) {
<-seenICEControlling.Done()
closePairNow(t, pcOffer, pcAnswer)
}

func TestSetHooks(t *testing.T) {
s := SettingEngine{}

if s.dtls.clientHelloMessageHook != nil ||
s.dtls.serverHelloMessageHook != nil ||
s.dtls.certificateRequestMessageHook != nil {
t.Fatalf("SettingEngine defaults aren't as expected.")
}

s.SetDTLSClientHelloMessageHook(func(msg handshake.MessageClientHello) handshake.Message {
return &msg
})
s.SetDTLSServerHelloMessageHook(func(msg handshake.MessageServerHello) handshake.Message {
return &msg
})
s.SetDTLSCertificateRequestMessageHook(func(msg handshake.MessageCertificateRequest) handshake.Message {
return &msg
})

if s.dtls.clientHelloMessageHook == nil {
t.Errorf("Failed to set DTLS Client Hello Hook")
}
if s.dtls.serverHelloMessageHook == nil {
t.Errorf("Failed to set DTLS Server Hello Hook")
}
if s.dtls.certificateRequestMessageHook == nil {
t.Errorf("Failed to set DTLS Certificate Request Hook")
}
}
Loading