@@ -7,6 +7,7 @@ import "C"
77import (
88 "errors"
99 "runtime"
10+ "slices"
1011 "unsafe"
1112)
1213
@@ -20,49 +21,44 @@ func (k *PublicKeyECDH) finalize() {
2021}
2122
2223type PrivateKeyECDH struct {
23- _pkey C.GO_EVP_PKEY_PTR
24- curve string
25- hasPublicKey bool
24+ _pkey C.GO_EVP_PKEY_PTR
25+ curve string
2626}
2727
2828func (k * PrivateKeyECDH ) finalize () {
2929 C .go_openssl_EVP_PKEY_free (k ._pkey )
3030}
3131
3232func NewPublicKeyECDH (curve string , bytes []byte ) (* PublicKeyECDH , error ) {
33- if len (bytes ) < 1 {
34- return nil , errors .New ("NewPublicKeyECDH: missing key" )
33+ if len (bytes ) != 1 + 2 * curveSize ( curve ) {
34+ return nil , errors .New ("NewPublicKeyECDH: wrong key length " )
3535 }
3636 pkey , err := newECDHPkey (curve , bytes , false )
3737 if err != nil {
3838 return nil , err
3939 }
40- k := & PublicKeyECDH {pkey , append ([] byte ( nil ), bytes ... )}
40+ k := & PublicKeyECDH {pkey , slices . Clone ( bytes )}
4141 runtime .SetFinalizer (k , (* PublicKeyECDH ).finalize )
4242 return k , nil
4343}
4444
4545func (k * PublicKeyECDH ) Bytes () []byte { return k .bytes }
4646
4747func NewPrivateKeyECDH (curve string , bytes []byte ) (* PrivateKeyECDH , error ) {
48+ if len (bytes ) != curveSize (curve ) {
49+ return nil , errors .New ("NewPrivateKeyECDH: wrong key length" )
50+ }
4851 pkey , err := newECDHPkey (curve , bytes , true )
4952 if err != nil {
5053 return nil , err
5154 }
52- k := & PrivateKeyECDH {pkey , curve , false }
55+ k := & PrivateKeyECDH {pkey , curve }
5356 runtime .SetFinalizer (k , (* PrivateKeyECDH ).finalize )
5457 return k , nil
5558}
5659
5760func (k * PrivateKeyECDH ) PublicKey () (* PublicKeyECDH , error ) {
5861 defer runtime .KeepAlive (k )
59- if ! k .hasPublicKey {
60- err := deriveEcdhPublicKey (k ._pkey , k .curve )
61- if err != nil {
62- return nil , err
63- }
64- k .hasPublicKey = true
65- }
6662 var pkey C.GO_EVP_PKEY_PTR
6763 defer func () {
6864 C .go_openssl_EVP_PKEY_free (pkey )
@@ -112,10 +108,7 @@ func (k *PrivateKeyECDH) PublicKey() (*PublicKeyECDH, error) {
112108}
113109
114110func newECDHPkey (curve string , bytes []byte , isPrivate bool ) (C.GO_EVP_PKEY_PTR , error ) {
115- nid , err := curveNID (curve )
116- if err != nil {
117- return nil , err
118- }
111+ nid := curveNID (curve )
119112 switch vMajor {
120113 case 1 :
121114 return newECDHPkey1 (nid , bytes , isPrivate )
@@ -138,6 +131,7 @@ func newECDHPkey1(nid C.int, bytes []byte, isPrivate bool) (pkey C.GO_EVP_PKEY_P
138131 C .go_openssl_EC_KEY_free (key )
139132 }
140133 }()
134+ group := C .go_openssl_EC_KEY_get0_group (key )
141135 if isPrivate {
142136 priv := C .go_openssl_BN_bin2bn (base (bytes ), C .int (len (bytes )), nil )
143137 if priv == nil {
@@ -147,8 +141,15 @@ func newECDHPkey1(nid C.int, bytes []byte, isPrivate bool) (pkey C.GO_EVP_PKEY_P
147141 if C .go_openssl_EC_KEY_set_private_key (key , priv ) != 1 {
148142 return nil , newOpenSSLError ("EC_KEY_set_private_key" )
149143 }
144+ pub , err := pointMult (group , priv )
145+ if err != nil {
146+ return nil , err
147+ }
148+ defer C .go_openssl_EC_POINT_free (pub )
149+ if C .go_openssl_EC_KEY_set_public_key (key , pub ) != 1 {
150+ return nil , newOpenSSLError ("EC_KEY_set_public_key" )
151+ }
150152 } else {
151- group := C .go_openssl_EC_KEY_get0_group (key )
152153 pub := C .go_openssl_EC_POINT_new (group )
153154 if pub == nil {
154155 return nil , newOpenSSLError ("EC_POINT_new" )
@@ -161,6 +162,14 @@ func newECDHPkey1(nid C.int, bytes []byte, isPrivate bool) (pkey C.GO_EVP_PKEY_P
161162 return nil , newOpenSSLError ("EC_KEY_set_public_key" )
162163 }
163164 }
165+ if C .go_openssl_EC_KEY_check_key (key ) != 1 {
166+ // Match upstream error message.
167+ if isPrivate {
168+ return nil , errors .New ("crypto/ecdh: invalid private key" )
169+ } else {
170+ return nil , errors .New ("crypto/ecdh: invalid public key" )
171+ }
172+ }
164173 return newEVPPKEY (key )
165174}
166175
@@ -175,7 +184,19 @@ func newECDHPkey3(nid C.int, bytes []byte, isPrivate bool) (C.GO_EVP_PKEY_PTR, e
175184 bld .addUTF8String (_OSSL_PKEY_PARAM_GROUP_NAME , C .go_openssl_OBJ_nid2sn (nid ), 0 )
176185 var selection C.int
177186 if isPrivate {
178- bld .addBin (_OSSL_PKEY_PARAM_PRIV_KEY , bytes , true )
187+ priv := C .go_openssl_BN_bin2bn (base (bytes ), C .int (len (bytes )), nil )
188+ if priv == nil {
189+ return nil , newOpenSSLError ("BN_bin2bn" )
190+ }
191+ defer C .go_openssl_BN_clear_free (priv )
192+ pubBytes , err := generateAndEncodeEcPublicKey (nid , func (group C.GO_EC_GROUP_PTR ) (C.GO_EC_POINT_PTR , error ) {
193+ return pointMult (group , priv )
194+ })
195+ if err != nil {
196+ return nil , err
197+ }
198+ bld .addOctetString (_OSSL_PKEY_PARAM_PUB_KEY , pubBytes )
199+ bld .addBN (_OSSL_PKEY_PARAM_PRIV_KEY , priv )
179200 selection = C .GO_EVP_PKEY_KEYPAIR
180201 } else {
181202 bld .addOctetString (_OSSL_PKEY_PARAM_PUB_KEY , bytes )
@@ -187,62 +208,31 @@ func newECDHPkey3(nid C.int, bytes []byte, isPrivate bool) (C.GO_EVP_PKEY_PTR, e
187208 return nil , err
188209 }
189210 defer C .go_openssl_OSSL_PARAM_free (params )
190- return newEvpFromParams (C .GO_EVP_PKEY_EC , selection , params )
211+ pkey , err := newEvpFromParams (C .GO_EVP_PKEY_EC , selection , params )
212+ if err != nil {
213+ return nil , err
214+ }
215+
216+ if err := checkPkey (pkey , isPrivate ); err != nil {
217+ C .go_openssl_EVP_PKEY_free (pkey )
218+ return nil , errors .New ("crypto/ecdh: " + err .Error ())
219+ }
220+ return pkey , nil
191221}
192222
193- // deriveEcdhPublicKey sets the raw public key of pkey by deriving it from
194- // the raw private key.
195- func deriveEcdhPublicKey (pkey C.GO_EVP_PKEY_PTR , curve string ) error {
196- derive := func (group C.GO_EC_GROUP_PTR , priv C.GO_BIGNUM_PTR ) (C.GO_EC_POINT_PTR , error ) {
197- // OpenSSL does not expose any method to generate the public
198- // key from the private key [1], so we have to calculate it here.
199- // [1] https://github.com/openssl/openssl/issues/18437#issuecomment-1144717206
200- pt := C .go_openssl_EC_POINT_new (group )
201- if pt == nil {
202- return nil , newOpenSSLError ("EC_POINT_new" )
203- }
204- if C .go_openssl_EC_POINT_mul (group , pt , priv , nil , nil , nil ) == 0 {
205- C .go_openssl_EC_POINT_free (pt )
206- return nil , newOpenSSLError ("EC_POINT_mul" )
207- }
208- return pt , nil
223+ func pointMult (group C.GO_EC_GROUP_PTR , priv C.GO_BIGNUM_PTR ) (C.GO_EC_POINT_PTR , error ) {
224+ // OpenSSL does not expose any method to generate the public
225+ // key from the private key [1], so we have to calculate it here.
226+ // [1] https://github.com/openssl/openssl/issues/18437#issuecomment-1144717206
227+ pt := C .go_openssl_EC_POINT_new (group )
228+ if pt == nil {
229+ return nil , newOpenSSLError ("EC_POINT_new" )
209230 }
210- switch vMajor {
211- case 1 :
212- key := getECKey (pkey )
213- priv := C .go_openssl_EC_KEY_get0_private_key (key )
214- if priv == nil {
215- return newOpenSSLError ("EC_KEY_get0_private_key" )
216- }
217- group := C .go_openssl_EC_KEY_get0_group (key )
218- pub , err := derive (group , priv )
219- if err != nil {
220- return err
221- }
222- defer C .go_openssl_EC_POINT_free (pub )
223- if C .go_openssl_EC_KEY_set_public_key (key , pub ) != 1 {
224- return newOpenSSLError ("EC_KEY_set_public_key" )
225- }
226- case 3 :
227- var priv C.GO_BIGNUM_PTR
228- if C .go_openssl_EVP_PKEY_get_bn_param (pkey , _OSSL_PKEY_PARAM_PRIV_KEY , & priv ) != 1 {
229- return newOpenSSLError ("EVP_PKEY_get_bn_param" )
230- }
231- defer C .go_openssl_BN_clear_free (priv )
232- nid , _ := curveNID (curve )
233- pubBytes , err := generateAndEncodeEcPublicKey (nid , func (group C.GO_EC_GROUP_PTR ) (C.GO_EC_POINT_PTR , error ) {
234- return derive (group , priv )
235- })
236- if err != nil {
237- return err
238- }
239- if C .go_openssl_EVP_PKEY_set1_encoded_public_key (pkey , base (pubBytes ), C .size_t (len (pubBytes ))) != 1 {
240- return newOpenSSLError ("EVP_PKEY_set1_encoded_public_key" )
241- }
242- default :
243- panic (errUnsupportedVersion ())
231+ if C .go_openssl_EC_POINT_mul (group , pt , priv , nil , nil , nil ) == 0 {
232+ C .go_openssl_EC_POINT_free (pt )
233+ return nil , newOpenSSLError ("EC_POINT_mul" )
244234 }
245- return nil
235+ return pt , nil
246236}
247237
248238func ECDH (priv * PrivateKeyECDH , pub * PublicKeyECDH ) ([]byte , error ) {
@@ -307,7 +297,7 @@ func GenerateKeyECDH(curve string) (*PrivateKeyECDH, []byte, error) {
307297 if err := bnToBinPad (priv , bytes ); err != nil {
308298 return nil , nil , err
309299 }
310- k = & PrivateKeyECDH {pkey , curve , true }
300+ k = & PrivateKeyECDH {pkey , curve }
311301 runtime .SetFinalizer (k , (* PrivateKeyECDH ).finalize )
312302 return k , bytes , nil
313303}
0 commit comments