diff --git a/ssh/knownhosts/knownhosts.go b/ssh/knownhosts/knownhosts.go index 7376a8dff2..9bf7595e71 100644 --- a/ssh/knownhosts/knownhosts.go +++ b/ssh/knownhosts/knownhosts.go @@ -481,17 +481,17 @@ func decodeHash(encoded string) (hashType string, salt, hash []byte, err error) err = errors.New("knownhosts: hashed host must start with '|'") return } - components := strings.Split(encoded, "|") - if len(components) != 4 { + components := strings.Split(encoded[1:], "|") + if len(components) != 3 { err = fmt.Errorf("knownhosts: got %d components, want 3", len(components)) return } - hashType = components[1] - if salt, err = base64.StdEncoding.DecodeString(components[2]); err != nil { + hashType = components[0] + if salt, err = base64.StdEncoding.DecodeString(components[1]); err != nil { return } - if hash, err = base64.StdEncoding.DecodeString(components[3]); err != nil { + if hash, err = base64.StdEncoding.DecodeString(components[2]); err != nil { return } return diff --git a/ssh/knownhosts/knownhosts_test.go b/ssh/knownhosts/knownhosts_test.go index 464dd59249..4a398b18e5 100644 --- a/ssh/knownhosts/knownhosts_test.go +++ b/ssh/knownhosts/knownhosts_test.go @@ -9,6 +9,7 @@ import ( "fmt" "net" "reflect" + "strings" "testing" "golang.org/x/crypto/ssh" @@ -292,6 +293,7 @@ const encodedTestHostnameHash = "|1|IHXZvQMvTcZTUU29+2vXFgx8Frs=|UGccIWfRVDwilMB func TestHostHash(t *testing.T) { testHostHash(t, testHostname, encodedTestHostnameHash) + testHostHashDecode(t) } func TestHashList(t *testing.T) { @@ -299,6 +301,19 @@ func TestHashList(t *testing.T) { testHostHash(t, testHostname, encoded) } +func testHostHashDecode(t *testing.T) { + for in, want := range map[string]string{ + "1": "must start with '|'", + "|typ|salt": "got 2 components", + "|typ|salt|hash|extra": "got 4 components", + } { + _, _, _, err := decodeHash(in) + if err == nil || !strings.Contains(err.Error(), want) { + t.Fatalf("decodeHash: expected error to match %q, got %v", want, err) + } + } +} + func testHostHash(t *testing.T, hostname, encoded string) { typ, salt, hash, err := decodeHash(encoded) if err != nil {