Skip to content
Merged

merge #112

Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
66 changes: 42 additions & 24 deletions crypto/hash/algorithms/argon2/argon2.go
Original file line number Diff line number Diff line change
Expand Up @@ -108,37 +108,49 @@ type Argon2 struct {
codec interfaces.Codec
}

func (c *Argon2) Type() string {
return types.TypeArgon2.String()
}

// ConfigValidator implements the config validator for Argon2
type ConfigValidator struct{}
type ConfigValidator struct {
params *Params
}

// Validate validates the Argon2 configuration
func (v *ConfigValidator) Validate(config *types.Config) error {
if config.TimeCost < 1 {
return fmt.Errorf("invalid time cost: %d", config.TimeCost)
if config.SaltLength < 8 {
return core.ErrSaltLengthTooShort
}
if config.MemoryCost < 1 {
return fmt.Errorf("invalid memory cost: %d", config.MemoryCost)
if v.params.TimeCost < 1 {
return fmt.Errorf("invalid time cost: %d", v.params.TimeCost)
}
if config.Threads < 1 {
return fmt.Errorf("invalid threads: %d", config.Threads)
if v.params.MemoryCost < 1 {
return fmt.Errorf("invalid memory cost: %d", v.params.MemoryCost)
}
if config.SaltLength < 8 {
return core.ErrSaltLengthTooShort
if v.params.Threads < 1 {
return fmt.Errorf("invalid threads: %d", v.params.Threads)
}
if config.KeyLength < 4 || config.KeyLength > 1024 {
return fmt.Errorf("invalid key length: %d, must be between 4 and 1024", config.KeyLength)
if v.params.KeyLength < 4 || v.params.KeyLength > 1024 {
return fmt.Errorf("invalid key length: %d, must be between 4 and 1024", v.params.KeyLength)
}
return nil
}

// DefaultConfig returns the default configuration for Argon2
func DefaultConfig() *types.Config {
return &types.Config{
TimeCost: 3, // Default time cost
MemoryCost: 64 * 1024, // Default memory cost (64MB)
Threads: 4, // Default threads
SaltLength: 16, // Default salt length
KeyLength: 32, // Default key length
SaltLength: 16, // Default salt length
ParamConfig: DefaultParams().String(),
}
}

func DefaultParams() *Params {
return &Params{
TimeCost: 3, // Default time cost
MemoryCost: 65536, // Default memory cost (64MB)
Threads: 4, // Default threads
KeyLength: 32, // Default key length
}
}

Expand All @@ -148,16 +160,22 @@ func NewArgon2Crypto(config *types.Config) (interfaces.Cryptographic, error) {
if config == nil {
config = DefaultConfig()
}
validator := &ConfigValidator{}

if config.ParamConfig == "" {
config.ParamConfig = DefaultParams().String()
}
params, err := parseParams(config.ParamConfig)
if err != nil {
return nil, fmt.Errorf("invalid argon2 param config: %v", err)
}

validator := &ConfigValidator{
params: params,
}
if err := validator.Validate(config); err != nil {
return nil, fmt.Errorf("invalid argon2 config: %v", err)
}
params := &Params{
TimeCost: config.TimeCost,
MemoryCost: config.MemoryCost,
Threads: config.Threads,
KeyLength: config.KeyLength,
}

return &Argon2{
params: params,
config: config,
Expand Down Expand Up @@ -195,7 +213,7 @@ func (c *Argon2) Verify(hashed, password string) error {
return err
}
if parts.Algorithm != types.TypeArgon2 {
return fmt.Errorf("algorithm mismatch")
return core.ErrAlgorithmMismatch
}
// Parse parameters
params, err := parseParams(parts.Params)
Expand Down
22 changes: 13 additions & 9 deletions crypto/hash/algorithms/argon2/argon2_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -197,7 +197,7 @@ func TestParams_String(t *testing.T) {
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
if got := tt.params.String(); got != tt.want {
t.Errorf("Params.String() = %v, want %v", got, tt.want)
t.Errorf("params.String() = %v, want %v", got, tt.want)
}
})
}
Expand All @@ -217,22 +217,26 @@ func TestNewArgon2Crypto(t *testing.T) {
{
name: "Custom config",
config: &types.Config{
TimeCost: 3,
MemoryCost: 64 * 1024,
Threads: 4,
SaltLength: 16,
KeyLength: 32,
ParamConfig: (&Params{
TimeCost: 3,
MemoryCost: 64 * 1024,
Threads: 4,
KeyLength: 32,
}).String(),
},
wantErr: false,
},
{
name: "Invalid config - zero time cost",
config: &types.Config{
TimeCost: 0,
MemoryCost: 64 * 1024,
Threads: 4,
SaltLength: 16,
KeyLength: 32,
ParamConfig: (&Params{
TimeCost: 0,
MemoryCost: 64 * 1024,
Threads: 4,
KeyLength: 32,
}).String(),
},
wantErr: true,
},
Expand Down
70 changes: 63 additions & 7 deletions crypto/hash/algorithms/bcrypt/bcrypt.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,8 @@ package bcrypt

import (
"fmt"
"strconv"
"strings"

"golang.org/x/crypto/bcrypt"

Expand All @@ -17,18 +19,55 @@ import (

// Bcrypt implements the Bcrypt hashing algorithm
type Bcrypt struct {
params *Params
config *types.Config
codec interfaces.Codec
}

func (c *Bcrypt) Type() string {
return types.TypeBcrypt.String()
}

type Params struct {
Cost int
}

func (p *Params) String() string {
return fmt.Sprintf("c:%d", p.Cost)
}

func parseParams(params string) (*Params, error) {
result := &Params{}

if params == "" {
return result, nil
}
for _, param := range strings.Split(params, ",") {
parts := strings.Split(param, ":")
if len(parts) != 2 {
return nil, fmt.Errorf("invalid bcrypt param format: %s", param)
}
switch parts[0] {
case "c":
cost, err := strconv.ParseInt(parts[1], 10, 32)
if err != nil {
return nil, fmt.Errorf("invalid cost: %v", err)
}
result.Cost = int(cost)
}
}
return result, nil
}

type ConfigValidator struct {
params *Params
}

func (v ConfigValidator) Validate(config *types.Config) interface{} {
if config.SaltLength < 8 {
return core.ErrSaltLengthTooShort
}
if config.Cost < 4 || config.Cost > 31 {
if v.params.Cost < 4 || v.params.Cost > 31 {
return core.ErrCostOutOfRange
}
return nil
Expand All @@ -39,20 +78,37 @@ func NewBcryptCrypto(config *types.Config) (interfaces.Cryptographic, error) {
if config == nil {
config = DefaultConfig()
}
validator := &ConfigValidator{}
if config.ParamConfig == "" {
config.ParamConfig = DefaultParams().String()
}
params, err := parseParams(config.ParamConfig)
if err != nil {
return nil, fmt.Errorf("invalid bcrypt param config: %v", err)
}

validator := &ConfigValidator{
params: params,
}
if err := validator.Validate(config); err != nil {
return nil, fmt.Errorf("invalid bcrypt config: %v", err)
}
return &Bcrypt{
config: config,
params: params,
codec: core.NewCodec(types.TypeBcrypt),
}, nil
}

func DefaultConfig() *types.Config {
return &types.Config{
SaltLength: 16,
Cost: 10,
SaltLength: 16,
ParamConfig: DefaultParams().String(),
}
}

func DefaultParams() *Params {
return &Params{
Cost: bcrypt.DefaultCost,
}
}

Expand All @@ -67,7 +123,7 @@ func (c *Bcrypt) Hash(password string) (string, error) {

// HashWithSalt implements the hash with salt method
func (c *Bcrypt) HashWithSalt(password, salt string) (string, error) {
hash, err := bcrypt.GenerateFromPassword([]byte(password+salt), c.config.Cost)
hash, err := bcrypt.GenerateFromPassword([]byte(password+salt), c.params.Cost)
if err != nil {
return "", err
}
Expand All @@ -81,11 +137,11 @@ func (c *Bcrypt) Verify(hashed, password string) error {
return err
}
if parts.Algorithm != types.TypeBcrypt {
return fmt.Errorf("algorithm mismatch")
return core.ErrAlgorithmMismatch
}
err = bcrypt.CompareHashAndPassword(parts.Hash, []byte(password+string(parts.Salt)))
if err != nil {
return fmt.Errorf("password not match")
return core.ErrPasswordNotMatch
}
return nil
}
8 changes: 6 additions & 2 deletions crypto/hash/algorithms/bcrypt/bcrypt_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -20,15 +20,19 @@ func TestNewBcryptCrypto(t *testing.T) {
{
name: "Custom config",
config: &types.Config{
Cost: 10,
SaltLength: 16,
ParamConfig: (&Params{
Cost: 10,
}).String(),
},
wantErr: false,
},
{
name: "Invalid config - zero cost",
config: &types.Config{
Cost: 0,
ParamConfig: (&Params{
Cost: 0,
}).String(),
SaltLength: 16,
},
wantErr: true,
Expand Down
14 changes: 8 additions & 6 deletions crypto/hash/algorithms/dummy/dummy.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
package dummy

import (
"errors"
"fmt"

"github.com/origadmin/toolkits/crypto/hash/interfaces"
Expand All @@ -13,19 +14,20 @@ import (

// Crypto implements a dummy hashing algorithm
type Crypto struct {
config *types.Config
}

func (c *Crypto) Type() string {
return "dummy"
}

// NewDummyCrypto creates a new dummy crypto instance
func NewDummyCrypto(config *types.Config) (interfaces.Cryptographic, error) {
return &Crypto{
config: config,
}, nil
return nil, errors.New("algorithm not implemented")
}

// Hash implements the hash method
func (c *Crypto) Hash(password string) (string, error) {
return "", fmt.Errorf("dummy algorithm not implemented")
return "", fmt.Errorf("algorithm not implemented")
}

// HashWithSalt implements the hash with salt method
Expand All @@ -35,7 +37,7 @@ func (c *Crypto) HashWithSalt(password, salt string) (string, error) {

// Verify implements the verify method
func (c *Crypto) Verify(hashed, password string) error {
return fmt.Errorf("dummy algorithm not implemented")
return fmt.Errorf("algorithm not implemented")
}

func DefaultConfig() *types.Config {
Expand Down
11 changes: 8 additions & 3 deletions crypto/hash/algorithms/hmac256/hmac256.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ package hmac256
import (
"crypto/hmac"
"crypto/sha256"
"crypto/subtle"
"fmt"

"github.com/origadmin/toolkits/crypto/hash/core"
Expand All @@ -21,6 +22,10 @@ type HMAC256 struct {
codec interfaces.Codec
}

func (c *HMAC256) Type() string {
return types.TypeHMAC256.String()
}

type ConfigValidator struct {
SaltLength int
}
Expand Down Expand Up @@ -78,14 +83,14 @@ func (c *HMAC256) Verify(hashed, password string) error {
}

if parts.Algorithm != types.TypeHMAC256 {
return fmt.Errorf("algorithm mismatch")
return core.ErrAlgorithmMismatch
}

h := hmac.New(sha256.New, parts.Salt)
h.Write([]byte(password))
newHash := h.Sum(nil)
if string(newHash) != string(parts.Hash) {
return fmt.Errorf("password not match")
if subtle.ConstantTimeCompare(newHash, parts.Hash) != 1 {
return core.ErrPasswordNotMatch
}

return nil
Expand Down
Loading