Skip to content

Commit

Permalink
feat: [GoSDK] fp32 <-> fp16/bf16 vector conversion
Browse files Browse the repository at this point in the history
Add the following methods for convenient fp32 vector <-> fp16/bf16
vector conversion

fp32 <-> fp16/bf16 vector conversion:

- `func (fv FloatVector) ToFloat16Vector() Float16Vector`
- `func (fv FloatVector) ToBFloat16Vector() BFloat16Vector`
- `func (fv Float16Vector) ToFloat32Vector() FloatVector`
- `func (fv BFloat16Vector) ToFloat32Vector() FloatVector`

`columnBasedDataOption`:

- `func (opt *columnBasedDataOption) WithFloat16VectorColumn(colName string, dim int, data [][]float32) *columnBasedDataOption`
- `func (opt *columnBasedDataOption) WithBFloat16VectorColumn(colName string, dim int, data [][]float32) *columnBasedDataOption`

`ColumnFloat16Vector`/`ColumnBFloat16Vector`:

- `func NewColumnFloat16VectorFromFp32Vector(fieldName string, dim int, data [][]float32) *ColumnFloat16Vector`
- `func NewColumnBFloat16VectorFromFp32Vector(fieldName string, dim int, data [][]float32) *ColumnBFloat16Vector`
- support []float32 or `entity.FloatVector` in
    - `func (c *ColumnFloat16Vector) AppendValue(i interface{}) error`
    - `func (c *ColumnFloat16Vector) AppendValue(i interface{}) error`

issue: #37448

Signed-off-by: Yinzuo Jiang <[email protected]>
Signed-off-by: Yinzuo Jiang <[email protected]>
  • Loading branch information
jiangyinzuo committed Nov 27, 2024
1 parent 302650a commit 9a3e2f5
Show file tree
Hide file tree
Showing 17 changed files with 401 additions and 73 deletions.
30 changes: 28 additions & 2 deletions client/column/vector.go
Original file line number Diff line number Diff line change
Expand Up @@ -137,14 +137,26 @@ func NewColumnFloat16Vector(fieldName string, dim int, data [][]byte) *ColumnFlo
}
}

func NewColumnFloat16VectorFromFp32Vector(fieldName string, dim int, data [][]float32) *ColumnFloat16Vector {
vectors := lo.Map(data, func(row []float32, _ int) entity.Float16Vector { return entity.FloatVector(row).ToFloat16Vector() })
return &ColumnFloat16Vector{
vectorBase: newVectorBase(fieldName, dim, vectors, entity.FieldTypeFloat16Vector),
}
}

// AppendValue appends vector value into values.
// override default type constrains, add `[]byte` conversion
// Override default type constrains, add `[]byte`, `entity.FloatVector` and
// `[]float32` conversion.
func (c *ColumnFloat16Vector) AppendValue(i interface{}) error {
switch vector := i.(type) {
case entity.Float16Vector:
c.values = append(c.values, vector)
case []byte:
c.values = append(c.values, vector)
case entity.FloatVector:
c.values = append(c.values, vector.ToFloat16Vector())
case []float32:
c.values = append(c.values, entity.FloatVector(vector).ToFloat16Vector())
default:
return errors.Newf("unexpected append value type %T, field type %v", vector, c.fieldType)
}
Expand All @@ -157,6 +169,8 @@ func (c *ColumnFloat16Vector) Slice(start, end int) Column {
}
}

/* bf16 vector */

type ColumnBFloat16Vector struct {
*vectorBase[entity.BFloat16Vector]
}
Expand All @@ -168,14 +182,26 @@ func NewColumnBFloat16Vector(fieldName string, dim int, data [][]byte) *ColumnBF
}
}

func NewColumnBFloat16VectorFromFp32Vector(fieldName string, dim int, data [][]float32) *ColumnBFloat16Vector {
vectors := lo.Map(data, func(row []float32, _ int) entity.BFloat16Vector { return entity.FloatVector(row).ToBFloat16Vector() })
return &ColumnBFloat16Vector{
vectorBase: newVectorBase(fieldName, dim, vectors, entity.FieldTypeBFloat16Vector),
}
}

// AppendValue appends vector value into values.
// override default type constrains, add `[]byte` conversion
// Override default type constrains, add `[]byte`, `entity.FloatVector` and
// `[]float32` conversion.
func (c *ColumnBFloat16Vector) AppendValue(i interface{}) error {
switch vector := i.(type) {
case entity.BFloat16Vector:
c.values = append(c.values, vector)
case []byte:
c.values = append(c.values, vector)
case entity.FloatVector:
c.values = append(c.values, vector.ToBFloat16Vector())
case []float32:
c.values = append(c.values, entity.FloatVector(vector).ToBFloat16Vector())
default:
return errors.Newf("unexpected append value type %T, field type %v", vector, c.fieldType)
}
Expand Down
27 changes: 27 additions & 0 deletions client/column/vector_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -160,6 +160,33 @@ func (s *VectorSuite) TestBasic() {
s.Equal(dim, parsed.Dim())
}
})

s.Run("fp32 <-> fp16/bf16 vector conversion", func() {
dim := 3
data := [][]float32{{0.1, 0.2, 0.3}, {0.4, 0.5, 0.6}}
dataToAppend1 := []float32{0.7, 0.8, 0.9}
dataToAppend2 := entity.FloatVector([]float32{1.0, 1.1, 1.2})

fp16Vector := NewColumnFloat16VectorFromFp32Vector("fp16_vector", dim, data)
fp16Vector.AppendValue(dataToAppend1)
fp16Vector.AppendValue(dataToAppend2)
for _, vec := range fp16Vector.Data() {
fp32Vector := vec.ToFloat32Vector()
for i := 0; i < dim; i++ {
s.InDelta(data[i], fp32Vector[i], 1e-3)
}
}

bf16Vector := NewColumnBFloat16VectorFromFp32Vector("bf16_vector", dim, data)
bf16Vector.AppendValue(dataToAppend1)
bf16Vector.AppendValue(dataToAppend2)
for _, vec := range bf16Vector.Data() {
fp32Vector := vec.ToFloat32Vector()
for i := 0; i < dim; i++ {
s.InDelta(data[i], fp32Vector[i], 1e-3)
}
}
})
}

func (s *VectorSuite) TestSlice() {
Expand Down
29 changes: 20 additions & 9 deletions client/entity/vectors.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,8 +17,7 @@
package entity

import (
"encoding/binary"
"math"
"github.com/milvus-io/milvus/pkg/util/typeutil"
)

// Vector interface vector used int search
Expand All @@ -44,13 +43,17 @@ func (fv FloatVector) FieldType() FieldType {
// Serialize serializes vector into byte slice, used in search placeholder
// LittleEndian is used for convention
func (fv FloatVector) Serialize() []byte {
data := make([]byte, 0, 4*len(fv)) // float32 occupies 4 bytes
buf := make([]byte, 4)
for _, f := range fv {
binary.LittleEndian.PutUint32(buf, math.Float32bits(f))
data = append(data, buf...)
}
return data
return typeutil.Float32ArrayToBytes(fv)
}

func (fv FloatVector) ToFloat16Vector() Float16Vector {
return typeutil.Float32ArrayToFloat16Bytes(fv)
}

// SerializeToBFloat16Bytes serializes vector into bfloat16 byte slice,
// used in search placeholder
func (fv FloatVector) ToBFloat16Vector() BFloat16Vector {
return typeutil.Float32ArrayToBFloat16Bytes(fv)
}

// FloatVector float32 vector wrapper.
Expand All @@ -70,6 +73,10 @@ func (fv Float16Vector) Serialize() []byte {
return fv
}

func (fv Float16Vector) ToFloat32Vector() FloatVector {
return typeutil.Float16BytesToFloat32Vector(fv)
}

// FloatVector float32 vector wrapper.
type BFloat16Vector []byte

Expand All @@ -87,6 +94,10 @@ func (fv BFloat16Vector) Serialize() []byte {
return fv
}

func (fv BFloat16Vector) ToFloat32Vector() FloatVector {
return typeutil.BFloat16BytesToFloat32Vector(fv)
}

// BinaryVector []byte vector wrapper
type BinaryVector []byte

Expand Down
43 changes: 42 additions & 1 deletion client/entity/vectors_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -33,9 +33,50 @@ func TestVectors(t *testing.T) {
}

fv := FloatVector(raw)

assert.Equal(t, dim, fv.Dim())
assert.Equal(t, dim*4, len(fv.Serialize()))

var fvConverted FloatVector

fp16v := fv.ToFloat16Vector()
assert.Equal(t, dim, fp16v.Dim())
assert.Equal(t, dim*2, len(fp16v.Serialize()))
fvConverted = fp16v.ToFloat32Vector()
assert.Equal(t, dim, fvConverted.Dim())
assert.Equal(t, dim*4, len(fvConverted.Serialize()))

bf16v := fv.ToBFloat16Vector()
assert.Equal(t, dim, bf16v.Dim())
assert.Equal(t, dim*2, len(bf16v.Serialize()))
fvConverted = bf16v.ToFloat32Vector()
assert.Equal(t, dim, fvConverted.Dim())
assert.Equal(t, dim*4, len(fvConverted.Serialize()))
})

t.Run("test fp32 <-> fp16/bf16 vector conversion", func(t *testing.T) {
raw := make([]float32, dim)
for i := 0; i < dim; i++ {
raw[i] = float32(i) * 0.1
}

fv := FloatVector(raw)
fp16v := fv.ToFloat16Vector()
bf16v := fv.ToBFloat16Vector()

assert.Equal(t, dim, fp16v.Dim())
assert.Equal(t, dim*2, len(fp16v.Serialize()))
assert.Equal(t, dim, bf16v.Dim())
assert.Equal(t, dim*2, len(bf16v.Serialize()))

fp32vFromfp16v := fp16v.ToFloat32Vector()
for i := 0; i < dim; i++ {
assert.InDelta(t, fv[i], fp32vFromfp16v[i], 0.04)
}

fp32vFrombf16v := bf16v.ToFloat32Vector()
for i := 0; i < dim; i++ {
assert.InDelta(t, fp32vFromfp16v[i], fp32vFrombf16v[i], 0.04)
}
})

t.Run("test binary vector", func(t *testing.T) {
Expand Down
5 changes: 3 additions & 2 deletions client/go.mod
Original file line number Diff line number Diff line change
Expand Up @@ -7,13 +7,12 @@ require (
github.com/cockroachdb/errors v1.9.1
github.com/grpc-ecosystem/go-grpc-middleware v1.3.0
github.com/milvus-io/milvus-proto/go-api/v2 v2.3.4-0.20241120015424-93892e628c69
github.com/milvus-io/milvus/pkg v0.0.2-0.20241111021426-5e90f348fcbb
github.com/milvus-io/milvus/pkg v0.0.2-0.20241126032235-cb6542339e84
github.com/quasilyte/go-ruleguard/dsl v0.3.22
github.com/samber/lo v1.27.0
github.com/stretchr/testify v1.9.0
github.com/tidwall/gjson v1.17.1
go.uber.org/atomic v1.10.0
golang.org/x/exp v0.0.0-20230224173230-c95f2b4c22f2
google.golang.org/grpc v1.65.0
google.golang.org/protobuf v1.34.2
)
Expand Down Expand Up @@ -67,6 +66,7 @@ require (
github.com/sirupsen/logrus v1.9.0 // indirect
github.com/soheilhy/cmux v0.1.5 // indirect
github.com/spaolacci/murmur3 v1.1.0 // indirect
github.com/spf13/cast v1.3.1 // indirect
github.com/spf13/pflag v1.0.5 // indirect
github.com/stretchr/objx v0.5.2 // indirect
github.com/tidwall/match v1.1.1 // indirect
Expand Down Expand Up @@ -98,6 +98,7 @@ require (
go.uber.org/multierr v1.11.0 // indirect
go.uber.org/zap v1.27.0 // indirect
golang.org/x/crypto v0.25.0 // indirect
golang.org/x/exp v0.0.0-20230224173230-c95f2b4c22f2 // indirect
golang.org/x/net v0.27.0 // indirect
golang.org/x/sync v0.7.0 // indirect
golang.org/x/sys v0.22.0 // indirect
Expand Down
6 changes: 4 additions & 2 deletions client/go.sum
Original file line number Diff line number Diff line change
Expand Up @@ -320,8 +320,8 @@ github.com/microcosm-cc/bluemonday v1.0.2/go.mod h1:iVP4YcDBq+n/5fb23BhYFvIMq/le
github.com/miekg/dns v1.0.14/go.mod h1:W1PPwlIAgtquWBMBEV9nkV9Cazfe8ScdGz/Lj7v3Nrg=
github.com/milvus-io/milvus-proto/go-api/v2 v2.3.4-0.20241120015424-93892e628c69 h1:Qt0Bv2Fum3EX3OlkuQYHJINBzeU4oEuHy2lXSfB/gZw=
github.com/milvus-io/milvus-proto/go-api/v2 v2.3.4-0.20241120015424-93892e628c69/go.mod h1:/6UT4zZl6awVeXLeE7UGDWZvXj3IWkRsh3mqsn0DiAs=
github.com/milvus-io/milvus/pkg v0.0.2-0.20241111021426-5e90f348fcbb h1:lMyIrG03agASB88AAwnk+NOU9V33lcBdtub/ZEv6IQU=
github.com/milvus-io/milvus/pkg v0.0.2-0.20241111021426-5e90f348fcbb/go.mod h1:w5nu1Z318AvgWQrGUYXaqLeVLu4JvCS/oYhxqctOZvU=
github.com/milvus-io/milvus/pkg v0.0.2-0.20241126032235-cb6542339e84 h1:EAFxmxUVp5yYFDCrX1MQoSxkTO+ycy8NXEqEDEB3cRM=
github.com/milvus-io/milvus/pkg v0.0.2-0.20241126032235-cb6542339e84/go.mod h1:RATa0GS4jhkPpsYOvQ/QvcNz8rd+TlRPDiSyXQnMMxs=
github.com/mitchellh/cli v1.0.0/go.mod h1:hNIlj7HEI86fIcpObd7a0FcrxTWetlwJDGcceTlRvqc=
github.com/mitchellh/go-homedir v1.0.0/go.mod h1:SfyaCUpYCn1Vlf4IUYiD9fPX4A5wJrkLzIz1N1q0pr0=
github.com/mitchellh/go-homedir v1.1.0/go.mod h1:SfyaCUpYCn1Vlf4IUYiD9fPX4A5wJrkLzIz1N1q0pr0=
Expand Down Expand Up @@ -436,6 +436,8 @@ github.com/spaolacci/murmur3 v1.1.0 h1:7c1g84S4BPRrfL5Xrdp6fOJ206sU9y293DDHaoy0b
github.com/spaolacci/murmur3 v1.1.0/go.mod h1:JwIasOWyU6f++ZhiEuf87xNszmSA2myDM2Kzu9HwQUA=
github.com/spf13/afero v1.1.2/go.mod h1:j4pytiNVoe2o6bmDsKpLACNPDBIoEAkihy7loJ1B0CQ=
github.com/spf13/cast v1.3.0/go.mod h1:Qx5cxh0v+4UWYiBimWS+eyWzqEqokIECu5etghLkUJE=
github.com/spf13/cast v1.3.1 h1:nFm6S0SMdyzrzcmThSipiEubIDy8WEXKNZ0UOgiRpng=
github.com/spf13/cast v1.3.1/go.mod h1:Qx5cxh0v+4UWYiBimWS+eyWzqEqokIECu5etghLkUJE=
github.com/spf13/cobra v0.0.5/go.mod h1:3K3wKZymM7VvHMDS9+Akkh4K60UwM26emMESw8tLCHU=
github.com/spf13/cobra v1.1.3/go.mod h1:pGADOWyqRD/YMrPZigI/zbliZ2wVD/23d+is3pSWzOo=
github.com/spf13/jwalterweatherman v1.0.0/go.mod h1:cQK4TGJAtQXfYWX+Ddv3mKDzgVb68N+wFjFa4jdeBTo=
Expand Down
36 changes: 27 additions & 9 deletions client/milvusclient/results_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -31,12 +31,16 @@ type ResultSetSuite struct {

func (s *ResultSetSuite) TestResultsetUnmarshal() {
type MyData struct {
A int64 `milvus:"name:id"`
V []float32 `milvus:"name:vector"`
A int64 `milvus:"name:id"`
V []float32 `milvus:"name:vector"`
Fp16V []byte `milvus:"name:fp16_vector"`
Bf16V []byte `milvus:"name:bf16_vector"`
}
type OtherData struct {
A string `milvus:"name:id"`
V []float32 `milvus:"name:vector"`
A string `milvus:"name:id"`
V []float32 `milvus:"name:vector"`
Fp16V []byte `milvus:"name:fp16_vector"`
Bf16V []byte `milvus:"name:bf16_vector"`
}

var (
Expand All @@ -51,6 +55,8 @@ func (s *ResultSetSuite) TestResultsetUnmarshal() {
rs := DataSet([]column.Column{
column.NewColumnInt64("id", idData),
column.NewColumnFloatVector("vector", 2, vectorData),
column.NewColumnFloat16VectorFromFp32Vector("fp16_vector", 2, vectorData),
column.NewColumnBFloat16VectorFromFp32Vector("bf16_vector", 2, vectorData),
})
err := rs.Unmarshal([]MyData{})
s.Error(err)
Expand All @@ -66,6 +72,8 @@ func (s *ResultSetSuite) TestResultsetUnmarshal() {
for idx, row := range ptrReceiver {
s.Equal(row.A, idData[idx])
s.Equal(row.V, vectorData[idx])
s.Equal(entity.Float16Vector(row.Fp16V), entity.FloatVector(vectorData[idx]).ToFloat16Vector())
s.Equal(entity.BFloat16Vector(row.Bf16V), entity.FloatVector(vectorData[idx]).ToBFloat16Vector())
}

var otherReceiver []*OtherData
Expand All @@ -75,12 +83,16 @@ func (s *ResultSetSuite) TestResultsetUnmarshal() {

func (s *ResultSetSuite) TestSearchResultUnmarshal() {
type MyData struct {
A int64 `milvus:"name:id"`
V []float32 `milvus:"name:vector"`
A int64 `milvus:"name:id"`
V []float32 `milvus:"name:vector"`
Fp16V []byte `milvus:"name:fp16_vector"`
Bf16V []byte `milvus:"name:bf16_vector"`
}
type OtherData struct {
A string `milvus:"name:id"`
V []float32 `milvus:"name:vector"`
A string `milvus:"name:id"`
V []float32 `milvus:"name:vector"`
Fp16V []byte `milvus:"name:fp16_vector"`
Bf16V []byte `milvus:"name:bf16_vector"`
}

var (
Expand All @@ -95,10 +107,14 @@ func (s *ResultSetSuite) TestSearchResultUnmarshal() {
sr := ResultSet{
sch: entity.NewSchema().
WithField(entity.NewField().WithName("id").WithIsPrimaryKey(true).WithDataType(entity.FieldTypeInt64)).
WithField(entity.NewField().WithName("vector").WithDim(2).WithDataType(entity.FieldTypeFloatVector)),
WithField(entity.NewField().WithName("vector").WithDim(2).WithDataType(entity.FieldTypeFloatVector)).
WithField(entity.NewField().WithName("fp16_vector").WithDim(2).WithDataType(entity.FieldTypeFloat16Vector)).
WithField(entity.NewField().WithName("bf16_vector").WithDim(2).WithDataType(entity.FieldTypeBFloat16Vector)),
IDs: column.NewColumnInt64("id", idData),
Fields: DataSet([]column.Column{
column.NewColumnFloatVector("vector", 2, vectorData),
column.NewColumnFloat16VectorFromFp32Vector("fp16_vector", 2, vectorData),
column.NewColumnBFloat16VectorFromFp32Vector("bf16_vector", 2, vectorData),
}),
}
err := sr.Unmarshal([]MyData{})
Expand All @@ -115,6 +131,8 @@ func (s *ResultSetSuite) TestSearchResultUnmarshal() {
for idx, row := range ptrReceiver {
s.Equal(row.A, idData[idx])
s.Equal(row.V, vectorData[idx])
s.Equal(entity.Float16Vector(row.Fp16V), entity.FloatVector(vectorData[idx]).ToFloat16Vector())
s.Equal(entity.BFloat16Vector(row.Bf16V), entity.FloatVector(vectorData[idx]).ToBFloat16Vector())
}

var otherReceiver []*OtherData
Expand Down
29 changes: 28 additions & 1 deletion client/milvusclient/write_options.go
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@ import (
"github.com/milvus-io/milvus/client/v2/column"
"github.com/milvus-io/milvus/client/v2/entity"
"github.com/milvus-io/milvus/client/v2/row"
"github.com/milvus-io/milvus/pkg/util/typeutil"
)

type InsertOption interface {
Expand Down Expand Up @@ -95,13 +96,21 @@ func (opt *columnBasedDataOption) processInsertColumns(colSchema *entity.Schema,
if col.Type() != field.DataType {
return nil, 0, fmt.Errorf("param column %s has type %v but collection field definition is %v", col.Name(), col.Type(), field.DataType)
}
if field.DataType == entity.FieldTypeFloatVector || field.DataType == entity.FieldTypeBinaryVector {
if field.DataType == entity.FieldTypeFloatVector || field.DataType == entity.FieldTypeBinaryVector ||
field.DataType == entity.FieldTypeFloat16Vector || field.DataType == entity.FieldTypeBFloat16Vector ||
field.DataType == entity.FieldTypeSparseVector {
dim := 0
switch column := col.(type) {
case *column.ColumnFloatVector:
dim = column.Dim()
case *column.ColumnBinaryVector:
dim = column.Dim()
case *column.ColumnFloat16Vector:
dim = column.Dim()
case *column.ColumnBFloat16Vector:
dim = column.Dim()
case *column.ColumnSparseFloatVector:
dim = column.Dim()
}
if fmt.Sprintf("%d", dim) != field.TypeParams[entity.TypeParamDim] {
return nil, 0, fmt.Errorf("params column %s vector dim %d not match collection definition, which has dim of %s", field.Name, dim, field.TypeParams[entity.TypeParamDim])
Expand Down Expand Up @@ -205,6 +214,24 @@ func (opt *columnBasedDataOption) WithFloatVectorColumn(colName string, dim int,
return opt.WithColumns(column)
}

func (opt *columnBasedDataOption) WithFloat16VectorColumn(colName string, dim int, data [][]float32) *columnBasedDataOption {
f16v := make([][]byte, 0, len(data))
for i := 0; i < len(data); i++ {
f16v = append(f16v, typeutil.Float32ArrayToFloat16Bytes(data[i]))
}
column := column.NewColumnFloat16Vector(colName, dim, f16v)
return opt.WithColumns(column)
}

func (opt *columnBasedDataOption) WithBFloat16VectorColumn(colName string, dim int, data [][]float32) *columnBasedDataOption {
bf16v := make([][]byte, 0, len(data))
for i := 0; i < len(data); i++ {
bf16v = append(bf16v, typeutil.Float32ArrayToBFloat16Bytes(data[i]))
}
column := column.NewColumnBFloat16Vector(colName, dim, bf16v)
return opt.WithColumns(column)
}

func (opt *columnBasedDataOption) WithBinaryVectorColumn(colName string, dim int, data [][]byte) *columnBasedDataOption {
column := column.NewColumnBinaryVector(colName, dim, data)
return opt.WithColumns(column)
Expand Down
Loading

0 comments on commit 9a3e2f5

Please sign in to comment.