diff --git a/crypto/hash/algorithms/argon2/argon2.go b/crypto/hash/algorithms/argon2/argon2.go index 877a0e91..67b567ea 100644 --- a/crypto/hash/algorithms/argon2/argon2.go +++ b/crypto/hash/algorithms/argon2/argon2.go @@ -108,25 +108,31 @@ type Argon2 struct { codec interfaces.Codec } +func (c *Argon2) Type() string { + return types.TypeArgon2.String() +} + // ConfigValidator implements the config validator for Argon2 -type ConfigValidator struct{} +type ConfigValidator struct { + params *Params +} // Validate validates the Argon2 configuration func (v *ConfigValidator) Validate(config *types.Config) error { - if config.TimeCost < 1 { - return fmt.Errorf("invalid time cost: %d", config.TimeCost) + if config.SaltLength < 8 { + return core.ErrSaltLengthTooShort } - if config.MemoryCost < 1 { - return fmt.Errorf("invalid memory cost: %d", config.MemoryCost) + if v.params.TimeCost < 1 { + return fmt.Errorf("invalid time cost: %d", v.params.TimeCost) } - if config.Threads < 1 { - return fmt.Errorf("invalid threads: %d", config.Threads) + if v.params.MemoryCost < 1 { + return fmt.Errorf("invalid memory cost: %d", v.params.MemoryCost) } - if config.SaltLength < 8 { - return core.ErrSaltLengthTooShort + if v.params.Threads < 1 { + return fmt.Errorf("invalid threads: %d", v.params.Threads) } - if config.KeyLength < 4 || config.KeyLength > 1024 { - return fmt.Errorf("invalid key length: %d, must be between 4 and 1024", config.KeyLength) + if v.params.KeyLength < 4 || v.params.KeyLength > 1024 { + return fmt.Errorf("invalid key length: %d, must be between 4 and 1024", v.params.KeyLength) } return nil } @@ -134,11 +140,17 @@ func (v *ConfigValidator) Validate(config *types.Config) error { // DefaultConfig returns the default configuration for Argon2 func DefaultConfig() *types.Config { return &types.Config{ - TimeCost: 3, // Default time cost - MemoryCost: 64 * 1024, // Default memory cost (64MB) - Threads: 4, // Default threads - SaltLength: 16, // Default salt length - KeyLength: 32, // Default key length + SaltLength: 16, // Default salt length + ParamConfig: DefaultParams().String(), + } +} + +func DefaultParams() *Params { + return &Params{ + TimeCost: 3, // Default time cost + MemoryCost: 65536, // Default memory cost (64MB) + Threads: 4, // Default threads + KeyLength: 32, // Default key length } } @@ -148,16 +160,22 @@ func NewArgon2Crypto(config *types.Config) (interfaces.Cryptographic, error) { if config == nil { config = DefaultConfig() } - validator := &ConfigValidator{} + + if config.ParamConfig == "" { + config.ParamConfig = DefaultParams().String() + } + params, err := parseParams(config.ParamConfig) + if err != nil { + return nil, fmt.Errorf("invalid argon2 param config: %v", err) + } + + validator := &ConfigValidator{ + params: params, + } if err := validator.Validate(config); err != nil { return nil, fmt.Errorf("invalid argon2 config: %v", err) } - params := &Params{ - TimeCost: config.TimeCost, - MemoryCost: config.MemoryCost, - Threads: config.Threads, - KeyLength: config.KeyLength, - } + return &Argon2{ params: params, config: config, @@ -195,7 +213,7 @@ func (c *Argon2) Verify(hashed, password string) error { return err } if parts.Algorithm != types.TypeArgon2 { - return fmt.Errorf("algorithm mismatch") + return core.ErrAlgorithmMismatch } // Parse parameters params, err := parseParams(parts.Params) diff --git a/crypto/hash/algorithms/argon2/argon2_test.go b/crypto/hash/algorithms/argon2/argon2_test.go index 9f0df8fd..c7958c47 100644 --- a/crypto/hash/algorithms/argon2/argon2_test.go +++ b/crypto/hash/algorithms/argon2/argon2_test.go @@ -197,7 +197,7 @@ func TestParams_String(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { if got := tt.params.String(); got != tt.want { - t.Errorf("Params.String() = %v, want %v", got, tt.want) + t.Errorf("params.String() = %v, want %v", got, tt.want) } }) } @@ -217,22 +217,26 @@ func TestNewArgon2Crypto(t *testing.T) { { name: "Custom config", config: &types.Config{ - TimeCost: 3, - MemoryCost: 64 * 1024, - Threads: 4, SaltLength: 16, - KeyLength: 32, + ParamConfig: (&Params{ + TimeCost: 3, + MemoryCost: 64 * 1024, + Threads: 4, + KeyLength: 32, + }).String(), }, wantErr: false, }, { name: "Invalid config - zero time cost", config: &types.Config{ - TimeCost: 0, - MemoryCost: 64 * 1024, - Threads: 4, SaltLength: 16, - KeyLength: 32, + ParamConfig: (&Params{ + TimeCost: 0, + MemoryCost: 64 * 1024, + Threads: 4, + KeyLength: 32, + }).String(), }, wantErr: true, }, diff --git a/crypto/hash/algorithms/bcrypt/bcrypt.go b/crypto/hash/algorithms/bcrypt/bcrypt.go index 06c7834e..7f80fcb5 100644 --- a/crypto/hash/algorithms/bcrypt/bcrypt.go +++ b/crypto/hash/algorithms/bcrypt/bcrypt.go @@ -6,6 +6,8 @@ package bcrypt import ( "fmt" + "strconv" + "strings" "golang.org/x/crypto/bcrypt" @@ -17,18 +19,55 @@ import ( // Bcrypt implements the Bcrypt hashing algorithm type Bcrypt struct { + params *Params config *types.Config codec interfaces.Codec } +func (c *Bcrypt) Type() string { + return types.TypeBcrypt.String() +} + +type Params struct { + Cost int +} + +func (p *Params) String() string { + return fmt.Sprintf("c:%d", p.Cost) +} + +func parseParams(params string) (*Params, error) { + result := &Params{} + + if params == "" { + return result, nil + } + for _, param := range strings.Split(params, ",") { + parts := strings.Split(param, ":") + if len(parts) != 2 { + return nil, fmt.Errorf("invalid bcrypt param format: %s", param) + } + switch parts[0] { + case "c": + cost, err := strconv.ParseInt(parts[1], 10, 32) + if err != nil { + return nil, fmt.Errorf("invalid cost: %v", err) + } + result.Cost = int(cost) + } + } + return result, nil +} + type ConfigValidator struct { + params *Params } func (v ConfigValidator) Validate(config *types.Config) interface{} { if config.SaltLength < 8 { return core.ErrSaltLengthTooShort } - if config.Cost < 4 || config.Cost > 31 { + if v.params.Cost < 4 || v.params.Cost > 31 { return core.ErrCostOutOfRange } return nil @@ -39,20 +78,37 @@ func NewBcryptCrypto(config *types.Config) (interfaces.Cryptographic, error) { if config == nil { config = DefaultConfig() } - validator := &ConfigValidator{} + if config.ParamConfig == "" { + config.ParamConfig = DefaultParams().String() + } + params, err := parseParams(config.ParamConfig) + if err != nil { + return nil, fmt.Errorf("invalid bcrypt param config: %v", err) + } + + validator := &ConfigValidator{ + params: params, + } if err := validator.Validate(config); err != nil { return nil, fmt.Errorf("invalid bcrypt config: %v", err) } return &Bcrypt{ config: config, + params: params, codec: core.NewCodec(types.TypeBcrypt), }, nil } func DefaultConfig() *types.Config { return &types.Config{ - SaltLength: 16, - Cost: 10, + SaltLength: 16, + ParamConfig: DefaultParams().String(), + } +} + +func DefaultParams() *Params { + return &Params{ + Cost: bcrypt.DefaultCost, } } @@ -67,7 +123,7 @@ func (c *Bcrypt) Hash(password string) (string, error) { // HashWithSalt implements the hash with salt method func (c *Bcrypt) HashWithSalt(password, salt string) (string, error) { - hash, err := bcrypt.GenerateFromPassword([]byte(password+salt), c.config.Cost) + hash, err := bcrypt.GenerateFromPassword([]byte(password+salt), c.params.Cost) if err != nil { return "", err } @@ -81,11 +137,11 @@ func (c *Bcrypt) Verify(hashed, password string) error { return err } if parts.Algorithm != types.TypeBcrypt { - return fmt.Errorf("algorithm mismatch") + return core.ErrAlgorithmMismatch } err = bcrypt.CompareHashAndPassword(parts.Hash, []byte(password+string(parts.Salt))) if err != nil { - return fmt.Errorf("password not match") + return core.ErrPasswordNotMatch } return nil } diff --git a/crypto/hash/algorithms/bcrypt/bcrypt_test.go b/crypto/hash/algorithms/bcrypt/bcrypt_test.go index 8f73266a..c0fc1baf 100644 --- a/crypto/hash/algorithms/bcrypt/bcrypt_test.go +++ b/crypto/hash/algorithms/bcrypt/bcrypt_test.go @@ -20,15 +20,19 @@ func TestNewBcryptCrypto(t *testing.T) { { name: "Custom config", config: &types.Config{ - Cost: 10, SaltLength: 16, + ParamConfig: (&Params{ + Cost: 10, + }).String(), }, wantErr: false, }, { name: "Invalid config - zero cost", config: &types.Config{ - Cost: 0, + ParamConfig: (&Params{ + Cost: 0, + }).String(), SaltLength: 16, }, wantErr: true, diff --git a/crypto/hash/algorithms/dummy/dummy.go b/crypto/hash/algorithms/dummy/dummy.go index 7cf8f481..35ea46bd 100644 --- a/crypto/hash/algorithms/dummy/dummy.go +++ b/crypto/hash/algorithms/dummy/dummy.go @@ -5,6 +5,7 @@ package dummy import ( + "errors" "fmt" "github.com/origadmin/toolkits/crypto/hash/interfaces" @@ -13,19 +14,20 @@ import ( // Crypto implements a dummy hashing algorithm type Crypto struct { - config *types.Config +} + +func (c *Crypto) Type() string { + return "dummy" } // NewDummyCrypto creates a new dummy crypto instance func NewDummyCrypto(config *types.Config) (interfaces.Cryptographic, error) { - return &Crypto{ - config: config, - }, nil + return nil, errors.New("algorithm not implemented") } // Hash implements the hash method func (c *Crypto) Hash(password string) (string, error) { - return "", fmt.Errorf("dummy algorithm not implemented") + return "", fmt.Errorf("algorithm not implemented") } // HashWithSalt implements the hash with salt method @@ -35,7 +37,7 @@ func (c *Crypto) HashWithSalt(password, salt string) (string, error) { // Verify implements the verify method func (c *Crypto) Verify(hashed, password string) error { - return fmt.Errorf("dummy algorithm not implemented") + return fmt.Errorf("algorithm not implemented") } func DefaultConfig() *types.Config { diff --git a/crypto/hash/algorithms/hmac256/hmac256.go b/crypto/hash/algorithms/hmac256/hmac256.go index 0e31f37b..d9da74cc 100644 --- a/crypto/hash/algorithms/hmac256/hmac256.go +++ b/crypto/hash/algorithms/hmac256/hmac256.go @@ -7,6 +7,7 @@ package hmac256 import ( "crypto/hmac" "crypto/sha256" + "crypto/subtle" "fmt" "github.com/origadmin/toolkits/crypto/hash/core" @@ -21,6 +22,10 @@ type HMAC256 struct { codec interfaces.Codec } +func (c *HMAC256) Type() string { + return types.TypeHMAC256.String() +} + type ConfigValidator struct { SaltLength int } @@ -78,14 +83,14 @@ func (c *HMAC256) Verify(hashed, password string) error { } if parts.Algorithm != types.TypeHMAC256 { - return fmt.Errorf("algorithm mismatch") + return core.ErrAlgorithmMismatch } h := hmac.New(sha256.New, parts.Salt) h.Write([]byte(password)) newHash := h.Sum(nil) - if string(newHash) != string(parts.Hash) { - return fmt.Errorf("password not match") + if subtle.ConstantTimeCompare(newHash, parts.Hash) != 1 { + return core.ErrPasswordNotMatch } return nil diff --git a/crypto/hash/algorithms/md5/md5.go b/crypto/hash/algorithms/md5/md5.go index 3ce82bd3..863ff287 100644 --- a/crypto/hash/algorithms/md5/md5.go +++ b/crypto/hash/algorithms/md5/md5.go @@ -21,6 +21,10 @@ type MD5 struct { codec interfaces.Codec } +func (c *MD5) Type() string { + return types.TypeMD5.String() +} + type ConfigValidator struct { SaltLength int } @@ -35,7 +39,7 @@ func (v ConfigValidator) Validate(config *types.Config) interface{} { // NewMD5Crypto creates a new MD5 crypto instance func NewMD5Crypto(config *types.Config) (interfaces.Cryptographic, error) { if config == nil { - config = DefaultConfig() + config = types.DefaultConfig() } validator := &ConfigValidator{} if err := validator.Validate(config); err != nil { @@ -70,18 +74,12 @@ func (c *MD5) Verify(hashed, password string) error { } if parts.Algorithm != types.TypeMD5 { - return fmt.Errorf("algorithm mismatch") + return core.ErrAlgorithmMismatch } newHash := md5.Sum([]byte(password + string(parts.Salt))) if subtle.ConstantTimeCompare(newHash[:], parts.Hash) != 1 { - return fmt.Errorf("password not match") + return core.ErrPasswordNotMatch } return nil } - -func DefaultConfig() *types.Config { - return &types.Config{ - SaltLength: 16, - } -} diff --git a/crypto/hash/algorithms/pbkdf2/pbkdf2.go b/crypto/hash/algorithms/pbkdf2/pbkdf2.go new file mode 100644 index 00000000..040e19e2 --- /dev/null +++ b/crypto/hash/algorithms/pbkdf2/pbkdf2.go @@ -0,0 +1,223 @@ +/* + * Copyright (c) 2024 OrigAdmin. All rights reserved. + */ + +package pbkdf2 + +import ( + "crypto/subtle" + "fmt" + "hash" + "strconv" + "strings" + + "golang.org/x/crypto/pbkdf2" + + "github.com/origadmin/toolkits/crypto/hash/core" + "github.com/origadmin/toolkits/crypto/hash/interfaces" + "github.com/origadmin/toolkits/crypto/hash/types" + "github.com/origadmin/toolkits/crypto/hash/utils" +) + +// PBKDF2 implements the PBKDF2 hashing algorithm +type PBKDF2 struct { + params *Params + config *types.Config + codec interfaces.Codec +} + +func (c *PBKDF2) Type() string { + return types.TypePBKDF2.String() +} + +type ConfigValidator struct { + params *Params +} + +func (v ConfigValidator) Validate(config *types.Config) error { + if config.SaltLength < 8 { + return fmt.Errorf("salt length must be at least 8 bytes") + } + if v.params.Iterations < 1000 { + return fmt.Errorf("iterations must be at least 1000") + } + if v.params.KeyLength < 8 { + return fmt.Errorf("key length must be at least 8 bytes") + } + if v.params.HashType == "" { + return fmt.Errorf("hash type must be specified") + } + _, err := core.ParseHash(v.params.HashType) + if err != nil { + return err + } + return nil +} + +// NewPBKDF2Crypto creates a new PBKDF2 crypto instance +func NewPBKDF2Crypto(config *types.Config) (interfaces.Cryptographic, error) { + if config == nil { + config = DefaultConfig() + } + + if config.ParamConfig == "" { + config.ParamConfig = DefaultParams().String() + } + params, err := parseParams(config.ParamConfig) + if err != nil { + return nil, fmt.Errorf("invalid pbkdf2 param config: %v", err) + } + + validator := &ConfigValidator{ + params: params, + } + if err := validator.Validate(config); err != nil { + return nil, fmt.Errorf("invalid pbkdf2 config: %v", err) + } + return &PBKDF2{ + params: params, + config: config, + codec: core.NewCodec(types.TypePBKDF2), + }, nil +} + +func DefaultParams() *Params { + return &Params{ + Iterations: 10000, + KeyLength: 32, + HashType: core.SHA256.String(), + } +} + +func DefaultConfig() *types.Config { + return &types.Config{ + SaltLength: 16, + ParamConfig: DefaultParams().String(), + } +} + +// Hash implements the hash method +func (c *PBKDF2) Hash(password string) (string, error) { + salt, err := utils.GenerateSalt(c.config.SaltLength) + if err != nil { + return "", err + } + return c.HashWithSalt(password, string(salt)) +} + +func (c *PBKDF2) hashFromName(name string) (func() hash.Hash, error) { + parseHash, err := core.ParseHash(name) + if err != nil { + return nil, err + } + return parseHash.New, nil +} + +// HashWithSalt implements the hash with salt method +func (c *PBKDF2) HashWithSalt(password, salt string) (string, error) { + hashHash, err := c.hashFromName(c.params.HashType) + if err != nil { + return "", err + } + newHash := pbkdf2.Key([]byte(password), []byte(salt), c.params.Iterations, int(c.params.KeyLength), hashHash) + return c.codec.Encode([]byte(salt), newHash, c.params.String()), nil +} + +// Verify implements the verify method +func (c *PBKDF2) Verify(hashed, password string) error { + parts, err := c.codec.Decode(hashed) + if err != nil { + return err + } + if parts.Algorithm != types.TypePBKDF2 { + return core.ErrAlgorithmMismatch + } + + // Parse parameters + params, err := parseParams(parts.Params) + if err != nil { + return err + } + + // The hash function is recreated based on the hash type being parsed + hashHash, err := c.hashFromName(params.HashType) + if err != nil { + return err + } + + newHash := pbkdf2.Key([]byte(password), parts.Salt, params.Iterations, int(params.KeyLength), hashHash) + if subtle.ConstantTimeCompare(newHash, parts.Hash) != 1 { + return core.ErrPasswordNotMatch + } + return nil +} + +// Params represents parameters for PBKDF2 algorithm +type Params struct { + Iterations int + KeyLength uint32 + HashType string +} + +// String returns the string representation of parameters +func (p *Params) String() string { + var parts []string + if p.Iterations > 0 { + parts = append(parts, fmt.Sprintf("i:%d", p.Iterations)) + } + if p.KeyLength > 0 { + parts = append(parts, fmt.Sprintf("k:%d", p.KeyLength)) + } + _, err := core.ParseHash(p.HashType) + if err == nil { + parts = append(parts, fmt.Sprintf("h:%s", p.HashType)) + } + return strings.Join(parts, ",") +} + +// parseParams parses PBKDF2 parameters from string +func parseParams(params string) (*Params, error) { + result := &Params{} + + // Handle empty string case + if params == "" { + return result, nil + } + + kv := make(map[string]string) + for _, param := range strings.Split(params, ",") { + parts := strings.Split(param, ":") + if len(parts) != 2 { + return nil, fmt.Errorf("invalid pbkdf2 param format: %s", param) + } + kv[parts[0]] = parts[1] + } + + // Parse iterations + if v, ok := kv["i"]; ok { + iterations, err := strconv.Atoi(v) + if err != nil { + return nil, fmt.Errorf("invalid iterations: %v", err) + } + result.Iterations = iterations + } + + // Parse key length + if v, ok := kv["k"]; ok { + keyLength, err := strconv.Atoi(v) + if err != nil { + return nil, fmt.Errorf("invalid key length: %v", err) + } + result.KeyLength = uint32(keyLength) + } + + // Parse hash type + if v, ok := kv["h"]; ok { + _, err := core.ParseHash(v) + if err == nil { + result.HashType = v + } + } + + return result, nil +} diff --git a/crypto/hash/algorithms/scrypt/scrypt.go b/crypto/hash/algorithms/scrypt/scrypt.go index a222e243..76cf35b1 100644 --- a/crypto/hash/algorithms/scrypt/scrypt.go +++ b/crypto/hash/algorithms/scrypt/scrypt.go @@ -20,19 +20,25 @@ import ( // Scrypt implements the Scrypt hashing algorithm type Scrypt struct { + params *Params config *types.Config codec interfaces.Codec } +func (c *Scrypt) Type() string { + return types.TypeScrypt.String() +} + type ConfigValidator struct { + params *Params } -func (v ConfigValidator) Validate(config *types.Config) interface{} { +func (v ConfigValidator) Validate(config *types.Config) error { if config.SaltLength < 8 { return fmt.Errorf("salt length must be at least 8 bytes") } // N must be > 1 and a power of 2 - if config.Scrypt.N <= 1 || config.Scrypt.N&(config.Scrypt.N-1) != 0 { + if v.params.N <= 1 || v.params.N&(v.params.N-1) != 0 { return fmt.Errorf("N must be > 1 and a power of 2") } @@ -44,25 +50,41 @@ func NewScryptCrypto(config *types.Config) (interfaces.Cryptographic, error) { if config == nil { config = DefaultConfig() } - validator := &ConfigValidator{} + + if config.ParamConfig == "" { + config.ParamConfig = DefaultParams().String() + } + params, err := parseParams(config.ParamConfig) + if err != nil { + return nil, fmt.Errorf("invalid scrypt param config: %v", err) + } + + validator := &ConfigValidator{ + params: params, + } if err := validator.Validate(config); err != nil { return nil, fmt.Errorf("invalid scrypt config: %v", err) } return &Scrypt{ + params: params, config: config, codec: core.NewCodec(types.TypeScrypt), }, nil } +func DefaultParams() *Params { + return &Params{ + N: 16384, + R: 8, + P: 1, + KeyLen: 32, + } +} + func DefaultConfig() *types.Config { return &types.Config{ - SaltLength: 16, - KeyLength: 32, - Scrypt: types.ScryptConfig{ - N: 16384, - R: 8, - P: 1, - }, + SaltLength: 16, + ParamConfig: DefaultParams().String(), } } @@ -160,17 +182,11 @@ func (c *Scrypt) Hash(password string) (string, error) { // HashWithSalt implements the hash with salt method func (c *Scrypt) HashWithSalt(password, salt string) (string, error) { - params := &Params{ - N: c.config.Scrypt.N, - R: c.config.Scrypt.R, - P: c.config.Scrypt.P, - KeyLen: int(c.config.KeyLength), - } - hash, err := scrypt.Key([]byte(password), []byte(salt), params.N, params.R, params.P, params.KeyLen) + hash, err := scrypt.Key([]byte(password), []byte(salt), c.params.N, c.params.R, c.params.P, c.params.KeyLen) if err != nil { return "", err } - return c.codec.Encode([]byte(salt), hash, params.String()), nil + return c.codec.Encode([]byte(salt), hash, c.params.String()), nil } // Verify implements the verify method diff --git a/crypto/hash/algorithms/sha/sha.go b/crypto/hash/algorithms/sha/sha.go new file mode 100644 index 00000000..f99a0b30 --- /dev/null +++ b/crypto/hash/algorithms/sha/sha.go @@ -0,0 +1,108 @@ +/* + * Copyright (c) 2024 OrigAdmin. All rights reserved. + */ + +package sha + +import ( + "crypto/subtle" + "fmt" + + "github.com/origadmin/toolkits/crypto/hash/core" + "github.com/origadmin/toolkits/crypto/hash/interfaces" + "github.com/origadmin/toolkits/crypto/hash/types" + "github.com/origadmin/toolkits/crypto/hash/utils" +) + +// SHA implements the SHA hashing algorithm +type SHA struct { + config *types.Config + codec interfaces.Codec + hashHash core.Hash +} + +func (c *SHA) Type() string { + return c.hashHash.String() +} + +type ConfigValidator struct { +} + +func (v ConfigValidator) Validate(config *types.Config) interface{} { + if config.SaltLength < 8 { + return core.ErrSaltLengthTooShort + } + return nil +} + +// NewSHACrypto creates a new SHA crypto instance +func NewSHACrypto(hashType types.Type, config *types.Config) (interfaces.Cryptographic, error) { + if config == nil { + config = DefaultConfig() + } + validator := &ConfigValidator{} + if err := validator.Validate(config); err != nil { + return nil, fmt.Errorf("invalid sha config: %v", err) + } + hashHash, err := core.ParseHash(hashType.String()) + if err != nil { + return nil, err + } + + return &SHA{ + config: config, + codec: core.NewCodec(hashType), + hashHash: hashHash, + }, nil +} + +func NewSha1Crypto(config *types.Config) (interfaces.Cryptographic, error) { + return NewSHACrypto(types.TypeSha1, config) +} + +func NewSha256Crypto(config *types.Config) (interfaces.Cryptographic, error) { + return NewSHACrypto(types.TypeSha256, config) +} + +func NewSha512Crypto(config *types.Config) (interfaces.Cryptographic, error) { + return NewSHACrypto(types.TypeSha512, config) +} + +func DefaultConfig() *types.Config { + return &types.Config{ + SaltLength: 16, + } +} + +// Hash implements the hash method +func (c *SHA) Hash(password string) (string, error) { + salt, err := utils.GenerateSalt(c.config.SaltLength) + if err != nil { + return "", err + } + return c.HashWithSalt(password, string(salt)) +} + +// HashWithSalt implements the hash with salt method +func (c *SHA) HashWithSalt(password, salt string) (string, error) { + newHash := c.hashHash.New().Sum([]byte(password + salt)) + return c.codec.Encode([]byte(salt), newHash[:]), nil +} + +// Verify implements the verify method +func (c *SHA) Verify(hashed, password string) error { + parts, err := c.codec.Decode(hashed) + if err != nil { + return err + } + + if parts.Algorithm.String() != c.hashHash.String() { + return core.ErrAlgorithmMismatch + } + newHash := c.hashHash.New().Sum([]byte(password + string(parts.Salt))) + if subtle.ConstantTimeCompare(newHash, parts.Hash) != 1 { + return core.ErrPasswordNotMatch + } + + return nil +} diff --git a/crypto/hash/algorithms/sha1/sha1.go b/crypto/hash/algorithms/sha1/sha1.go deleted file mode 100644 index 33f4faf2..00000000 --- a/crypto/hash/algorithms/sha1/sha1.go +++ /dev/null @@ -1,87 +0,0 @@ -/* - * Copyright (c) 2024 OrigAdmin. All rights reserved. - */ - -package sha1 - -import ( - "crypto/sha1" - "crypto/subtle" - "fmt" - - "github.com/origadmin/toolkits/crypto/hash/core" - "github.com/origadmin/toolkits/crypto/hash/interfaces" - "github.com/origadmin/toolkits/crypto/hash/types" - "github.com/origadmin/toolkits/crypto/hash/utils" -) - -// SHA1Crypto implements the SHA1 hashing algorithm -type SHA1Crypto struct { - config *types.Config - codec interfaces.Codec -} - -type ConfigValidator struct { -} - -func (v ConfigValidator) Validate(config *types.Config) interface{} { - if config.SaltLength < 8 { - return core.ErrSaltLengthTooShort - } - return nil -} - -// NewSHA1Crypto creates a new SHA1 crypto instance -func NewSHA1Crypto(config *types.Config) (interfaces.Cryptographic, error) { - if config == nil { - config = DefaultConfig() - } - validator := &ConfigValidator{} - if err := validator.Validate(config); err != nil { - return nil, fmt.Errorf("invalid sha1 config: %v", err) - } - return &SHA1Crypto{ - config: config, - codec: core.NewCodec(types.TypeSha1), - }, nil -} - -func DefaultConfig() *types.Config { - return &types.Config{ - SaltLength: 16, - } -} - -// Hash implements the hash method -func (c *SHA1Crypto) Hash(password string) (string, error) { - salt, err := utils.GenerateSalt(c.config.SaltLength) - if err != nil { - return "", err - } - return c.HashWithSalt(password, string(salt)) -} - -// HashWithSalt implements the hash with salt method -func (c *SHA1Crypto) HashWithSalt(password, salt string) (string, error) { - hash := sha1.Sum([]byte(password + salt)) - return c.codec.Encode([]byte(salt), hash[:]), nil -} - -// Verify implements the verify method -func (c *SHA1Crypto) Verify(hashed, password string) error { - parts, err := c.codec.Decode(hashed) - if err != nil { - return err - } - - if parts.Algorithm != types.TypeSha1 { - return core.ErrAlgorithmMismatch - } - - newHash := sha1.Sum([]byte(password + string(parts.Salt))) - if subtle.ConstantTimeCompare(newHash[:], parts.Hash) != 1 { - return core.ErrPasswordNotMatch - } - - return nil -} diff --git a/crypto/hash/algorithms/sha256/sha256.go b/crypto/hash/algorithms/sha256/sha256.go deleted file mode 100644 index 1d22e248..00000000 --- a/crypto/hash/algorithms/sha256/sha256.go +++ /dev/null @@ -1,63 +0,0 @@ -/* - * Copyright (c) 2024 OrigAdmin. All rights reserved. - */ - -package sha256 - -import ( - "crypto/sha256" - "fmt" - - "github.com/origadmin/toolkits/crypto/hash/core" - "github.com/origadmin/toolkits/crypto/hash/interfaces" - "github.com/origadmin/toolkits/crypto/hash/types" - "github.com/origadmin/toolkits/crypto/hash/utils" -) - -// Sha256 implements the Sha256 hashing algorithm -type Sha256 struct { - config *types.Config - codec interfaces.Codec -} - -// NewSHA256Crypto creates a new Sha256 crypto instance -func NewSHA256Crypto(config *types.Config) (interfaces.Cryptographic, error) { - return &Sha256{ - config: config, - codec: core.NewCodec(types.TypeSha256), - }, nil -} - -// Hash implements the hash method -func (c *Sha256) Hash(password string) (string, error) { - salt, err := utils.GenerateSalt(c.config.SaltLength) - if err != nil { - return "", err - } - return c.HashWithSalt(password, string(salt)) -} - -// HashWithSalt implements the hash with salt method -func (c *Sha256) HashWithSalt(password, salt string) (string, error) { - hash := sha256.Sum256([]byte(password + salt)) - return c.codec.Encode([]byte(salt), hash[:]), nil -} - -// Verify implements the verify method -func (c *Sha256) Verify(hashed, password string) error { - parts, err := c.codec.Decode(hashed) - if err != nil { - return err - } - - if parts.Algorithm != types.TypeSha256 { - return fmt.Errorf("algorithm mismatch") - } - - newHash := sha256.Sum256([]byte(password + string(parts.Salt))) - if string(newHash[:]) != string(parts.Hash) { - return fmt.Errorf("password not match") - } - - return nil -} diff --git a/crypto/hash/cache.go b/crypto/hash/cache.go index ecd19dfe..a6b9f863 100644 --- a/crypto/hash/cache.go +++ b/crypto/hash/cache.go @@ -18,6 +18,10 @@ type cachedCrypto struct { cache sync.Map } +func (c *cachedCrypto) Type() string { + return c.crypto.Type() +} + type cacheItem struct { hash string expiresAt time.Time diff --git a/crypto/hash/core/hash.go b/crypto/hash/core/hash.go new file mode 100644 index 00000000..d0e5b438 --- /dev/null +++ b/crypto/hash/core/hash.go @@ -0,0 +1,358 @@ +/* + * Copyright (c) 2024 OrigAdmin. All rights reserved. + */ + +// Package core implements the functions, types, and interfaces for the module. +package core + +import ( + "crypto" + "crypto/md5" + "crypto/sha1" + "crypto/sha256" + "crypto/sha512" + "fmt" + "hash" + "hash/adler32" + "hash/crc32" + "hash/crc64" + "hash/fnv" + "hash/maphash" + "strings" + + "golang.org/x/crypto/blake2b" + "golang.org/x/crypto/blake2s" + "golang.org/x/crypto/md4" + "golang.org/x/crypto/ripemd160" + "golang.org/x/crypto/sha3" +) + +type Hash uint32 + +const ( + MD4 Hash = 1 + iota // import golang.org/x/crypto/md4 + MD5 // import crypto/md5 + SHA1 // import crypto/sha1 + SHA224 // import crypto/sha256 + SHA256 // import crypto/sha256 + SHA384 // import crypto/sha512 + SHA512 // import crypto/sha512 + MD5SHA1 // no implementation; MD5+SHA1 used for TLS RSA + RIPEMD160 // import golang.org/x/crypto/ripemd160 + SHA3_224 // import golang.org/x/crypto/sha3 + SHA3_256 // import golang.org/x/crypto/sha3 + SHA3_384 // import golang.org/x/crypto/sha3 + SHA3_512 // import golang.org/x/crypto/sha3 + SHA512_224 // import crypto/sha512 + SHA512_256 // import crypto/sha512 + BLAKE2s_256 // import golang.org/x/crypto/blake2s + BLAKE2b_256 // import golang.org/x/crypto/blake2b + BLAKE2b_384 // import golang.org/x/crypto/blake2b + BLAKE2b_512 // import golang.org/x/crypto/blake2b + + ADLER32 + CRC32 + CRC32_ISO + CRC32_CAST + CRC32_KOOP + CRC64_ISO + CRC64_ECMA + FNV32 + FNV32A + FNV64 + FNV64A + FNV128 + FNV128A + MAPHASH + maxHash +) + +var hashEnd = maxHash +var hashes = make([]func() hash.Hash, maxHash) +var customHashNames = make(map[string]Hash) +var customNameHashes = make(map[Hash]string) + +func init() { + // Register all hash.Hash functions + UpdateHashFunc(MD4, md4.New) + UpdateHashFunc(MD5, md5.New) + UpdateHashFunc(SHA1, sha1.New) + UpdateHashFunc(SHA224, sha256.New224) + UpdateHashFunc(SHA256, sha256.New) + UpdateHashFunc(SHA384, sha512.New384) + UpdateHashFunc(SHA512, sha512.New) + UpdateHashFunc(MD5SHA1, crypto.MD5SHA1.New) + UpdateHashFunc(RIPEMD160, ripemd160.New) + UpdateHashFunc(SHA3_224, sha3.New224) + UpdateHashFunc(SHA3_256, sha3.New256) + UpdateHashFunc(SHA3_384, sha3.New384) + UpdateHashFunc(SHA3_512, sha3.New512) + UpdateHashFunc(SHA512_224, sha512.New512_224) + UpdateHashFunc(SHA512_256, sha512.New512_256) + newBlake2sHash256 := func() hash.Hash { + h, _ := blake2s.New256(nil) + return h + } + UpdateHashFunc(BLAKE2s_256, newBlake2sHash256) + newBlake2bHash256 := func() hash.Hash { + h, _ := blake2b.New256(nil) + return h + } + newBlake2bHash384 := func() hash.Hash { + h, _ := blake2b.New384(nil) + return h + } + + newBlake2bHash512 := func() hash.Hash { + h, _ := blake2b.New512(nil) + return h + } + UpdateHashFunc(BLAKE2b_256, newBlake2bHash256) + UpdateHashFunc(BLAKE2b_384, newBlake2bHash384) + UpdateHashFunc(BLAKE2b_512, newBlake2bHash512) + + UpdateHashFunc(ADLER32, func() hash.Hash { return adler32.New() }) + UpdateHashFunc(CRC32, func() hash.Hash { return crc32.NewIEEE() }) + UpdateHashFunc(CRC32_ISO, func() hash.Hash { return crc32.New(crc32.MakeTable(crc32.IEEE)) }) + UpdateHashFunc(CRC32_CAST, func() hash.Hash { return crc32.New(crc32.MakeTable(crc32.Castagnoli)) }) + UpdateHashFunc(CRC32_KOOP, func() hash.Hash { return crc32.New(crc32.MakeTable(crc32.Koopman)) }) + UpdateHashFunc(CRC64_ISO, func() hash.Hash { return crc64.New(crc64.MakeTable(crc64.ISO)) }) + UpdateHashFunc(CRC64_ECMA, func() hash.Hash { return crc64.New(crc64.MakeTable(crc64.ECMA)) }) + UpdateHashFunc(FNV32, func() hash.Hash { return fnv.New32() }) + UpdateHashFunc(FNV32A, func() hash.Hash { return fnv.New32a() }) + UpdateHashFunc(FNV64, func() hash.Hash { return fnv.New64() }) + UpdateHashFunc(FNV64A, func() hash.Hash { return fnv.New64a() }) + UpdateHashFunc(FNV128, func() hash.Hash { return fnv.New128() }) + UpdateHashFunc(FNV128A, func() hash.Hash { return fnv.New128a() }) + newMapHash := func() hash.Hash { + var mh maphash.Hash + mh.SetSeed(maphash.MakeSeed()) + return &mh + } + UpdateHashFunc(MAPHASH, newMapHash) +} + +func (h Hash) New() hash.Hash { + if h > 0 && h < hashEnd { + f := hashes[h] + if f != nil { + return f() + } + } + panic(fmt.Sprintf("hash function %d not registered", h)) +} + +func (h Hash) String() string { + if h < maxHash { + switch h { + case MD4: + return "md4" + case MD5: + return "md5" + case SHA1: + return "sha1" + case SHA224: + return "sha224" + case SHA256: + return "sha256" + case SHA384: + return "sha384" + case SHA512: + return "sha512" + case MD5SHA1: + return "md5sha1" + case RIPEMD160: + return "ripemd160" + case SHA3_224: + return "sha3-224" + case SHA3_256: + return "sha3-256" + case SHA3_384: + return "sha3-384" + case SHA3_512: + return "sha3-512" + case SHA512_224: + return "sha512-224" + case SHA512_256: + return "sha512-256" + case BLAKE2s_256: + return "blake2s-256" + case BLAKE2b_256: + return "blake2b-256" + case BLAKE2b_384: + return "blake2b-384" + case BLAKE2b_512: + return "blake2b-512" + case ADLER32: + return "adler32" + case CRC32: + return "crc32" + case CRC32_ISO: + return "crc32-iso" + case CRC32_CAST: + return "crc32-cast" + case CRC64_ISO: + return "crc64-iso" + case CRC64_ECMA: + return "crc64-ecma" + case FNV32: + return "fnv32" + case FNV32A: + return "fnv32a" + case FNV64: + return "fnv64" + case FNV64A: + return "fnv64a" + case FNV128: + return "fnv128" + case FNV128A: + return "fnv128a" + case MAPHASH: + return "maphash" + default: + if name, ok := customNameHashes[h]; ok { + return name + } + } + } + return fmt.Sprintf("Hash(%d)", h) +} + +// RegisterHashFunc registers a new hash.Hash function +func RegisterHashFunc(name string, hashFunc func() hash.Hash) { + if _, err := ParseHash(name); err == nil { + panic(fmt.Sprintf("hash function %s already registered", name)) + } + name = strings.ToLower(name) + old := hashEnd + hashEnd++ + customHashNames[name] = old + customNameHashes[old] = name + hashes = append(hashes, hashFunc) +} + +// UpdateHashFunc updates a hash.Hash function +func UpdateHashFunc(hash Hash, hashFunc func() hash.Hash) { + if hash >= hashEnd { + return + } + hashes[hash] = hashFunc +} + +// RegisterOrUpdateHashFunc registers a new hash.Hash function if it does not exist, +// otherwise updates it +func RegisterOrUpdateHashFunc(name string, hashFunc func() hash.Hash) { + if h, err := ParseHash(name); err == nil { + UpdateHashFunc(h, hashFunc) + } else { + RegisterHashFunc(name, hashFunc) + } +} + +func ParseHash(s string) (Hash, error) { + s = strings.ToLower(s) + if h, ok := ParseCryptoHash(s); ok { + return h, nil + } + if h, ok := ParseInternalHash(s); ok { + return h, nil + } + if h, ok := ParseCustomHash(s); ok { + return h, nil + } + return 0, fmt.Errorf("unknown hash function: %s", s) +} + +// ParseCryptoHash only deals with crypto.Hash supported algorithms +func ParseCryptoHash(s string) (Hash, bool) { + switch s { + case "md4": + return MD4, true + case "md5": + return MD5, true + case "sha1": + return SHA1, true + case "sha224": + return SHA224, true + case "sha256": + return SHA256, true + case "sha384": + return SHA384, true + case "sha512": + return SHA512, true + case "md5sha1": + return MD5SHA1, true + case "ripemd160": + return RIPEMD160, true + case "sha3-224": + return SHA3_224, true + case "sha3-256": + return SHA3_256, true + case "sha3-384": + return SHA3_384, true + case "sha3-512": + return SHA3_512, true + case "sha512-224": + return SHA512_224, true + case "sha512-256": + return SHA512_256, true + case "blake2s-256": + return BLAKE2s_256, true + case "blake2b-256": + return BLAKE2b_256, true + case "blake2b-384": + return BLAKE2b_384, true + case "blake2b-512": + return BLAKE2b_512, true + default: + return Hash(0), false + } +} + +// ParseInternalHash algorithms that handle internal extensions +func ParseInternalHash(s string) (Hash, bool) { + switch s { + case "adler32": + return ADLER32, true + case "crc32": + return CRC32, true + case "crc32-iso": + return CRC32_ISO, true + case "crc32-cast": + return CRC32_CAST, true + case "crc32-koop": + return CRC32_KOOP, true + case "crc64-iso": + return CRC64_ISO, true + case "crc64-ecma": + return CRC64_ECMA, true + case "fnv-32": + return FNV32, true + case "fnv-32a": + return FNV32A, true + case "fnv-64": + return FNV64, true + case "fnv-64a": + return FNV64A, true + case "fnv-128": + return FNV128, true + case "fnv-128a": + return FNV128A, true + case "maphash": + return MAPHASH, true + default: + return 0, false + } +} + +// ParseCustomHash algorithms that handle custom registrations +func ParseCustomHash(s string) (Hash, bool) { + if h, ok := customHashNames[s]; ok { + return h, ok + } + return 0, false +} + +func IsCustomHash(h Hash) bool { + return h >= maxHash && h < hashEnd +} diff --git a/crypto/hash/core/hash_test.go b/crypto/hash/core/hash_test.go new file mode 100644 index 00000000..5b7d924d --- /dev/null +++ b/crypto/hash/core/hash_test.go @@ -0,0 +1,183 @@ +package core + +import ( + "fmt" + "hash" + "testing" +) + +type CustomHash struct { +} + +func (c *CustomHash) Write(p []byte) (n int, err error) { + return len(p), nil +} + +func (c *CustomHash) Reset() { +} + +func (c *CustomHash) Size() int { + return 0 +} + +func (c *CustomHash) BlockSize() int { + return 0 +} + +func (c *CustomHash) Sum(b []byte) []byte { + return []byte("mock_custom_hash") +} + +type CustomHash2 struct { +} + +func (c *CustomHash2) Write(p []byte) (n int, err error) { + return len(p), nil +} + +func (c *CustomHash2) Reset() { + +} + +func (c *CustomHash2) Size() int { + return 0 +} + +func (c *CustomHash2) BlockSize() int { + return 0 +} + +func (c *CustomHash2) Sum(b []byte) []byte { + return []byte("custom_hash_2") +} + +type CustomHash3 struct { + counter int + sum []byte +} + +func (c *CustomHash3) Write(p []byte) (n int, err error) { + c.counter += len(p) + return len(p), nil +} + +func (c *CustomHash3) Sum(b []byte) []byte { + return c.sum +} + +func (c *CustomHash3) Reset() { + c.counter = 0 +} + +func (c *CustomHash3) Size() int { + return 16 +} + +func (c *CustomHash3) BlockSize() int { + return 64 +} + +// 注册自定义哈希算法 +func init() { + RegisterHashFunc("CUSTOM", func() hash.Hash { + return &CustomHash{} + }) + + RegisterHashFunc("CUSTOM2", func() hash.Hash { + return &CustomHash2{} + }) + + RegisterOrUpdateHashFunc("CONFLICT", func() hash.Hash { + return &CustomHash3{sum: []byte("version1")} + }) + + RegisterOrUpdateHashFunc("CONFLICT", func() hash.Hash { + return &CustomHash3{sum: []byte("version2")} + }) +} + +func TestHashAlgorithms(t *testing.T) { + tests := []struct { + name string + hash Hash + input string + expect string + shouldError bool + }{ + {"MD5", MD5, "hello", "5d41402abc4b2a76b9719d911017c592", false}, + {"SHA1", SHA1, "hello", "aaf4c61ddcc5e8a2dabede0f3b482cd9aea9434d", false}, + {"SHA256", SHA256, "hello", "2cf24dba5fb0a30e26e83b2ac5b9e29e1b161e5c1fa7425e73043362938b9824", false}, + {"SHA512", SHA512, "hello", "9b71d224bd62f3785d96d46ad3ea3d73319bfbc2890caadae2dff72519673ca72323c3d99ba5c11d7c7acc6e14b8c5da0c4663475c2e5c3adef46f73bcdec043", false}, + {"CRC32", CRC32, "hello", "3610a686", false}, + {"FNV32", FNV32, "hello", "b6fa7167", false}, + {"FNV32A", FNV32A, "hello", "4f9f2cab", false}, + {"FNV64", FNV64, "hello", "7b495389bdbdd4c7", false}, + {"FNV64A", FNV64A, "hello", "a430d84680aabd0b", false}, + {"FNV128", FNV128, "hello", "f14b58486483d94f708038798c29697f", false}, + {"FNV128A", FNV128A, "hello", "e3e1efd54283d94f7081314b599d31b3", false}, + {"CUSTOM", 0, "hello", "mock_custom_hash", false}, + { + name: "CUSTOM", + input: "test123", + expect: "mock_custom_hash", + shouldError: false, + }, + { + name: "CUSTOM2", + input: "hello", + expect: "custom_hash_2", + shouldError: false, + }, + { + name: "CUSTOM3", + input: "data", + expect: "custom3-4", + shouldError: true, + }, + + // 异常case + { + name: "NOT_EXIST", + input: "test", + shouldError: true, + }, + { + name: "CONFLICT", + input: "test", + expect: "version2", + shouldError: false, + }, + { + name: "INVALID!NAME", + input: "test", + shouldError: true, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + var err error + if tt.hash == 0 { + tt.hash, err = ParseHash(tt.name) + if tt.shouldError { + if err == nil { + t.Fatalf("ParseHash(%s) error = %v, wantErr %v", tt.name, err, tt.shouldError) + } + return + } + } else { + + } + h := tt.hash.New() + h.Write([]byte(tt.input)) + got := h.Sum(nil) + gotHex := string(got) + if !IsCustomHash(tt.hash) { + gotHex = fmt.Sprintf("%x", got) + } + if gotHex != tt.expect { + t.Errorf("%s() = %v, want %v", tt.name, gotHex, tt.expect) + } + }) + } +} diff --git a/crypto/hash/crypto.go b/crypto/hash/crypto.go index a2494fc5..06730d6a 100644 --- a/crypto/hash/crypto.go +++ b/crypto/hash/crypto.go @@ -12,12 +12,11 @@ import ( "github.com/origadmin/toolkits/crypto/hash/algorithms/argon2" "github.com/origadmin/toolkits/crypto/hash/algorithms/bcrypt" - "github.com/origadmin/toolkits/crypto/hash/algorithms/dummy" "github.com/origadmin/toolkits/crypto/hash/algorithms/hmac256" "github.com/origadmin/toolkits/crypto/hash/algorithms/md5" + "github.com/origadmin/toolkits/crypto/hash/algorithms/pbkdf2" "github.com/origadmin/toolkits/crypto/hash/algorithms/scrypt" - "github.com/origadmin/toolkits/crypto/hash/algorithms/sha1" - "github.com/origadmin/toolkits/crypto/hash/algorithms/sha256" + "github.com/origadmin/toolkits/crypto/hash/algorithms/sha" "github.com/origadmin/toolkits/crypto/hash/core" "github.com/origadmin/toolkits/crypto/hash/interfaces" "github.com/origadmin/toolkits/crypto/hash/types" @@ -50,93 +49,89 @@ var ( }, types.TypeMD5: { creator: md5.NewMD5Crypto, - defaultConfig: md5.DefaultConfig, + defaultConfig: types.DefaultConfig, }, types.TypeScrypt: { creator: scrypt.NewScryptCrypto, defaultConfig: scrypt.DefaultConfig, }, types.TypeSha1: { - creator: sha1.NewSHA1Crypto, - defaultConfig: func() *types.Config { - return &types.Config{ - SaltLength: 16, - } - }, + creator: sha.NewSha1Crypto, + defaultConfig: sha.DefaultConfig, }, types.TypeSha256: { - creator: sha256.NewSHA256Crypto, - defaultConfig: func() *types.Config { - return &types.Config{ - SaltLength: 16, - } - }, - }, - // Unimplemented cryptos use dummy implementation - types.TypeCustom: { - creator: dummy.NewDummyCrypto, - defaultConfig: dummy.DefaultConfig, + creator: sha.NewSha256Crypto, + defaultConfig: sha.DefaultConfig, }, types.TypeSha512: { - creator: dummy.NewDummyCrypto, - defaultConfig: dummy.DefaultConfig, - }, - types.TypeSha384: { - creator: dummy.NewDummyCrypto, - defaultConfig: dummy.DefaultConfig, - }, - types.TypeSha3256: { - creator: dummy.NewDummyCrypto, - defaultConfig: dummy.DefaultConfig, - }, - types.TypeHMAC512: { - creator: dummy.NewDummyCrypto, - defaultConfig: dummy.DefaultConfig, + creator: sha.NewSha512Crypto, + defaultConfig: sha.DefaultConfig, }, types.TypePBKDF2: { - creator: dummy.NewDummyCrypto, - defaultConfig: dummy.DefaultConfig, - }, - types.TypePBKDF2SHA256: { - creator: dummy.NewDummyCrypto, - defaultConfig: dummy.DefaultConfig, - }, - types.TypePBKDF2SHA512: { - creator: dummy.NewDummyCrypto, - defaultConfig: dummy.DefaultConfig, - }, - types.TypePBKDF2SHA384: { - creator: dummy.NewDummyCrypto, - defaultConfig: dummy.DefaultConfig, - }, - types.TypePBKDF2SHA3256: { - creator: dummy.NewDummyCrypto, - defaultConfig: dummy.DefaultConfig, - }, - types.TypePBKDF2SHA3224: { - creator: dummy.NewDummyCrypto, - defaultConfig: dummy.DefaultConfig, - }, - types.TypePBKDF2SHA3384: { - creator: dummy.NewDummyCrypto, - defaultConfig: dummy.DefaultConfig, - }, - types.TypePBKDF2SHA3512224: { - creator: dummy.NewDummyCrypto, - defaultConfig: dummy.DefaultConfig, - }, - types.TypePBKDF2SHA3512256: { - creator: dummy.NewDummyCrypto, - defaultConfig: dummy.DefaultConfig, - }, - types.TypePBKDF2SHA3512384: { - creator: dummy.NewDummyCrypto, - defaultConfig: dummy.DefaultConfig, - }, - types.TypePBKDF2SHA3512512: { - creator: dummy.NewDummyCrypto, - defaultConfig: dummy.DefaultConfig, + creator: pbkdf2.NewPBKDF2Crypto, + defaultConfig: pbkdf2.DefaultConfig, }, + // Unimplemented cryptos use dummy implementation + //types.TypeCustom: { + // creator: dummy.NewDummyCrypto, + // defaultConfig: dummy.DefaultConfig, + //}, + //types.TypeSha512: { + // creator: dummy.NewDummyCrypto, + // defaultConfig: dummy.DefaultConfig, + //}, + //types.TypeSha384: { + // creator: dummy.NewDummyCrypto, + // defaultConfig: dummy.DefaultConfig, + //}, + //types.TypeSha3256: { + // creator: dummy.NewDummyCrypto, + // defaultConfig: dummy.DefaultConfig, + //}, + //types.TypeHMAC512: { + // creator: dummy.NewDummyCrypto, + // defaultConfig: dummy.DefaultConfig, + //}, + //types.TypePBKDF2SHA256: { + // creator: dummy.NewDummyCrypto, + // defaultConfig: dummy.DefaultConfig, + //}, + //types.TypePBKDF2SHA512: { + // creator: dummy.NewDummyCrypto, + // defaultConfig: dummy.DefaultConfig, + //}, + //types.TypePBKDF2SHA384: { + // creator: dummy.NewDummyCrypto, + // defaultConfig: dummy.DefaultConfig, + //}, + //types.TypePBKDF2SHA3256: { + // creator: dummy.NewDummyCrypto, + // defaultConfig: dummy.DefaultConfig, + //}, + //types.TypePBKDF2SHA3224: { + // creator: dummy.NewDummyCrypto, + // defaultConfig: dummy.DefaultConfig, + //}, + //types.TypePBKDF2SHA3384: { + // creator: dummy.NewDummyCrypto, + // defaultConfig: dummy.DefaultConfig, + //}, + //types.TypePBKDF2SHA3512224: { + // creator: dummy.NewDummyCrypto, + // defaultConfig: dummy.DefaultConfig, + //}, + //types.TypePBKDF2SHA3512256: { + // creator: dummy.NewDummyCrypto, + // defaultConfig: dummy.DefaultConfig, + //}, + //types.TypePBKDF2SHA3512384: { + // creator: dummy.NewDummyCrypto, + // defaultConfig: dummy.DefaultConfig, + //}, + //types.TypePBKDF2SHA3512512: { + // creator: dummy.NewDummyCrypto, + // defaultConfig: dummy.DefaultConfig, + //}, } ) @@ -147,6 +142,10 @@ type crypto struct { cryptos map[types.Type]interfaces.Cryptographic } +func (c crypto) Type() string { + return string(c.algorithm) +} + func (c crypto) Hash(password string) (string, error) { return c.crypto.Hash(password) } diff --git a/crypto/hash/examples/hash/main.go b/crypto/hash/examples/hash/main.go index 780696b7..9ea7b90c 100644 --- a/crypto/hash/examples/hash/main.go +++ b/crypto/hash/examples/hash/main.go @@ -9,40 +9,39 @@ import ( "log" "github.com/origadmin/toolkits/crypto/hash" + "github.com/origadmin/toolkits/crypto/hash/algorithms/argon2" "github.com/origadmin/toolkits/crypto/hash/types" ) func main() { - // 创建加密实例 + // reate cryptographic instance crypto, err := hash.NewCrypto(types.TypeArgon2, func(config *types.Config) { - config.TimeCost = 3 - config.MemoryCost = 64 * 1024 - config.Threads = 4 config.SaltLength = 16 + config.ParamConfig = argon2.DefaultParams().String() }) if err != nil { log.Fatal(err) } - // 测试密码 + // Test password password := "test123" - // 生成哈希 + // Generate hash hashed, err := crypto.Hash(password) if err != nil { log.Fatal(err) } fmt.Printf("Generated hash: %s\n", hashed) - // 验证密码 + // Verify password err = crypto.Verify(hashed, password) if err != nil { log.Fatal(err) } fmt.Println("Password verified successfully!") - // 测试错误的密码 + // Test wrong password wrongPassword := "wrong123" err = crypto.Verify(hashed, wrongPassword) if err != nil { diff --git a/crypto/hash/interfaces/cryptographic.go b/crypto/hash/interfaces/cryptographic.go index 39eaad08..d9993299 100644 --- a/crypto/hash/interfaces/cryptographic.go +++ b/crypto/hash/interfaces/cryptographic.go @@ -6,6 +6,7 @@ package interfaces // Cryptographic defines the interface for cryptographic operations type Cryptographic interface { + Type() string // Hash generates a hash for the given password Hash(password string) (string, error) // HashWithSalt generates a hash for the given password with the specified salt diff --git a/crypto/hash/types/config.go b/crypto/hash/types/config.go index 42108c8a..d77c5da8 100644 --- a/crypto/hash/types/config.go +++ b/crypto/hash/types/config.go @@ -4,49 +4,19 @@ package types -// ScryptConfig represents the configuration for Scrypt algorithm -type ScryptConfig struct { - N int `env:"HASH_SCRYPT_N"` - R int `env:"HASH_SCRYPT_R"` - P int `env:"HASH_SCRYPT_P"` +type ParamConfig interface { + String() string } // Config represents the configuration for hash algorithms type Config struct { - TimeCost uint32 `env:"HASH_TIMECOST"` - MemoryCost uint32 `env:"HASH_MEMORYCOST"` - Threads uint8 `env:"HASH_THREADS"` - SaltLength int `env:"HASH_SALTLENGTH"` - KeyLength uint32 `env:"HASH_KEYLENGTH"` - Cost int `env:"HASH_COST"` // Cost parameter (for bcrypt) - Salt string `env:"HASH_SALT"` // Salt for HMAC - Scrypt ScryptConfig + SaltLength int `env:"HASH_SALTLENGTH"` + ParamConfig string `env:"HASH_PARAM_CONFIG"` } // ConfigOption is a function that modifies a Config type ConfigOption func(*Config) -// WithTimeCost sets the time cost -func WithTimeCost(cost uint32) ConfigOption { - return func(cfg *Config) { - cfg.TimeCost = cost - } -} - -// WithMemoryCost sets the memory cost -func WithMemoryCost(cost uint32) ConfigOption { - return func(cfg *Config) { - cfg.MemoryCost = cost - } -} - -// WithThreads sets the number of threads -func WithThreads(threads uint8) ConfigOption { - return func(cfg *Config) { - cfg.Threads = threads - } -} - // WithSaltLength sets the salt length func WithSaltLength(length int) ConfigOption { return func(cfg *Config) { @@ -54,34 +24,18 @@ func WithSaltLength(length int) ConfigOption { } } -// WithCost sets the cost parameter -func WithCost(cost int) ConfigOption { - return func(cfg *Config) { - cfg.Cost = cost - } -} - -// WithSalt sets the salt -func WithSalt(salt string) ConfigOption { - return func(cfg *Config) { - cfg.Salt = salt - } -} - -// WithScryptConfig sets the scrypt configuration -func WithScryptConfig(scrypt ScryptConfig) ConfigOption { +func WithParams(paramConfig ParamConfig) ConfigOption { return func(cfg *Config) { - cfg.Scrypt = scrypt + if paramConfig == nil { + return + } + cfg.ParamConfig = paramConfig.String() } } // DefaultConfig 返回默认配置 func DefaultConfig() *Config { return &Config{ - TimeCost: 3, // 默认时间成本 - MemoryCost: 65536, // 默认内存成本 (64MB) - Threads: 4, // 默认线程数 - SaltLength: 16, // 默认盐长度 - KeyLength: 32, // 默认密钥长度 + SaltLength: 16, // Default salt length } }