diff --git a/dtlstransport.go b/dtlstransport.go index cc8c6b159bc..3e442923b55 100644 --- a/dtlstransport.go +++ b/dtlstransport.go @@ -355,6 +355,9 @@ func (t *DTLSTransport) Start(remoteParameters DTLSParameters) error { //nolint: dtlsConfig.ClientCAs = t.api.settingEngine.dtls.clientCAs dtlsConfig.RootCAs = t.api.settingEngine.dtls.rootCAs dtlsConfig.KeyLogWriter = t.api.settingEngine.dtls.keyLogWriter + dtlsConfig.ClientHelloMessageHook = t.api.settingEngine.dtls.clientHelloMessageHook + dtlsConfig.ServerHelloMessageHook = t.api.settingEngine.dtls.serverHelloMessageHook + dtlsConfig.CertificateRequestMessageHook = t.api.settingEngine.dtls.certificateRequestMessageHook // Connect as DTLS Client/Server, function is blocking and we // must not hold the DTLSTransport lock diff --git a/settingengine.go b/settingengine.go index 1c2475b238a..f3c70707c9a 100644 --- a/settingengine.go +++ b/settingengine.go @@ -15,6 +15,7 @@ import ( "github.com/pion/dtls/v3" dtlsElliptic "github.com/pion/dtls/v3/pkg/crypto/elliptic" + "github.com/pion/dtls/v3/pkg/protocol/handshake" "github.com/pion/ice/v4" "github.com/pion/logging" "github.com/pion/stun/v3" @@ -63,17 +64,20 @@ type SettingEngine struct { SRTCP *uint } dtls struct { - insecureSkipHelloVerify bool - disableInsecureSkipVerify bool - retransmissionInterval time.Duration - ellipticCurves []dtlsElliptic.Curve - connectContextMaker func() (context.Context, func()) - extendedMasterSecret dtls.ExtendedMasterSecretType - clientAuth *dtls.ClientAuthType - clientCAs *x509.CertPool - rootCAs *x509.CertPool - keyLogWriter io.Writer - customCipherSuites func() []dtls.CipherSuite + insecureSkipHelloVerify bool + disableInsecureSkipVerify bool + retransmissionInterval time.Duration + ellipticCurves []dtlsElliptic.Curve + connectContextMaker func() (context.Context, func()) + extendedMasterSecret dtls.ExtendedMasterSecretType + clientAuth *dtls.ClientAuthType + clientCAs *x509.CertPool + rootCAs *x509.CertPool + keyLogWriter io.Writer + customCipherSuites func() []dtls.CipherSuite + clientHelloMessageHook func(handshake.MessageClientHello) handshake.Message + serverHelloMessageHook func(handshake.MessageServerHello) handshake.Message + certificateRequestMessageHook func(handshake.MessageCertificateRequest) handshake.Message } sctp struct { maxReceiveBufferSize uint32 @@ -455,6 +459,24 @@ func (e *SettingEngine) SetDTLSCustomerCipherSuites(customCipherSuites func() [] e.dtls.customCipherSuites = customCipherSuites } +// SetDTLSClientHelloMessageHook if not nil, is called when a DTLS Client Hello message is sent +// from a client. The returned handshake message replaces the original message. +func (e *SettingEngine) SetDTLSClientHelloMessageHook(hook func(handshake.MessageClientHello) handshake.Message) { + e.dtls.clientHelloMessageHook = hook +} + +// SetDTLSServerHelloMessageHook if not nil, is called when a DTLS Server Hello message is sent +// from a client. The returned handshake message replaces the original message. +func (e *SettingEngine) SetDTLSServerHelloMessageHook(hook func(handshake.MessageServerHello) handshake.Message) { + e.dtls.serverHelloMessageHook = hook +} + +// SetDTLSCertificateRequestMessageHook if not nil, is called when a DTLS Certificate Request message is sent +// from a client. The returned handshake message replaces the original message. +func (e *SettingEngine) SetDTLSCertificateRequestMessageHook(hook func(handshake.MessageCertificateRequest) handshake.Message) { + e.dtls.certificateRequestMessageHook = hook +} + // SetSCTPRTOMax sets the maximum retransmission timeout. // Leave this 0 for the default timeout. func (e *SettingEngine) SetSCTPRTOMax(rtoMax time.Duration) { diff --git a/settingengine_test.go b/settingengine_test.go index bb4dfa537b2..4a5714e1f69 100644 --- a/settingengine_test.go +++ b/settingengine_test.go @@ -13,6 +13,7 @@ import ( "time" "github.com/pion/dtls/v3/pkg/crypto/elliptic" + "github.com/pion/dtls/v3/pkg/protocol/handshake" "github.com/pion/ice/v4" "github.com/pion/stun/v3" "github.com/pion/transport/v3/test" @@ -309,3 +310,33 @@ func TestSetICEBindingRequestHandler(t *testing.T) { <-seenICEControlling.Done() closePairNow(t, pcOffer, pcAnswer) } + +func TestSetHooks(t *testing.T) { + s := SettingEngine{} + + if s.dtls.clientHelloMessageHook != nil || + s.dtls.serverHelloMessageHook != nil || + s.dtls.certificateRequestMessageHook != nil { + t.Fatalf("SettingEngine defaults aren't as expected.") + } + + s.SetDTLSClientHelloMessageHook(func(msg handshake.MessageClientHello) handshake.Message { + return &msg + }) + s.SetDTLSServerHelloMessageHook(func(msg handshake.MessageServerHello) handshake.Message { + return &msg + }) + s.SetDTLSCertificateRequestMessageHook(func(msg handshake.MessageCertificateRequest) handshake.Message { + return &msg + }) + + if s.dtls.clientHelloMessageHook == nil { + t.Errorf("Failed to set DTLS Client Hello Hook") + } + if s.dtls.serverHelloMessageHook == nil { + t.Errorf("Failed to set DTLS Server Hello Hook") + } + if s.dtls.certificateRequestMessageHook == nil { + t.Errorf("Failed to set DTLS Certificate Request Hook") + } +}