Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: [GoSDK] fp32 <-> fp16/bf16 vector conversion #37978

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
30 changes: 28 additions & 2 deletions client/column/vector.go
Original file line number Diff line number Diff line change
Expand Up @@ -137,14 +137,26 @@
}
}

func NewColumnFloat16VectorFromFp32Vector(fieldName string, dim int, data [][]float32) *ColumnFloat16Vector {
vectors := lo.Map(data, func(row []float32, _ int) entity.Float16Vector { return entity.FloatVector(row).ToFloat16Vector() })
return &ColumnFloat16Vector{
vectorBase: newVectorBase(fieldName, dim, vectors, entity.FieldTypeFloat16Vector),
}
}

// AppendValue appends vector value into values.
// override default type constrains, add `[]byte` conversion
// Override default type constrains, add `[]byte`, `entity.FloatVector` and
// `[]float32` conversion.
func (c *ColumnFloat16Vector) AppendValue(i interface{}) error {
switch vector := i.(type) {
case entity.Float16Vector:
c.values = append(c.values, vector)
case []byte:
c.values = append(c.values, vector)
case entity.FloatVector:
c.values = append(c.values, vector.ToFloat16Vector())

Check warning on line 157 in client/column/vector.go

View check run for this annotation

Codecov / codecov/patch

client/column/vector.go#L156-L157

Added lines #L156 - L157 were not covered by tests
case []float32:
c.values = append(c.values, entity.FloatVector(vector).ToFloat16Vector())
default:
return errors.Newf("unexpected append value type %T, field type %v", vector, c.fieldType)
}
Expand All @@ -157,6 +169,8 @@
}
}

/* bf16 vector */

type ColumnBFloat16Vector struct {
*vectorBase[entity.BFloat16Vector]
}
Expand All @@ -168,14 +182,26 @@
}
}

func NewColumnBFloat16VectorFromFp32Vector(fieldName string, dim int, data [][]float32) *ColumnBFloat16Vector {
vectors := lo.Map(data, func(row []float32, _ int) entity.BFloat16Vector { return entity.FloatVector(row).ToBFloat16Vector() })
return &ColumnBFloat16Vector{
vectorBase: newVectorBase(fieldName, dim, vectors, entity.FieldTypeBFloat16Vector),
}
}

// AppendValue appends vector value into values.
// override default type constrains, add `[]byte` conversion
// Override default type constrains, add `[]byte`, `entity.FloatVector` and
// `[]float32` conversion.
func (c *ColumnBFloat16Vector) AppendValue(i interface{}) error {
switch vector := i.(type) {
case entity.BFloat16Vector:
c.values = append(c.values, vector)
case []byte:
c.values = append(c.values, vector)
case entity.FloatVector:
c.values = append(c.values, vector.ToBFloat16Vector())

Check warning on line 202 in client/column/vector.go

View check run for this annotation

Codecov / codecov/patch

client/column/vector.go#L201-L202

Added lines #L201 - L202 were not covered by tests
case []float32:
c.values = append(c.values, entity.FloatVector(vector).ToBFloat16Vector())
default:
return errors.Newf("unexpected append value type %T, field type %v", vector, c.fieldType)
}
Expand Down
27 changes: 27 additions & 0 deletions client/column/vector_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -160,6 +160,33 @@ func (s *VectorSuite) TestBasic() {
s.Equal(dim, parsed.Dim())
}
})

s.Run("fp32 <-> fp16/bf16 vector conversion", func() {
dim := 3
data := [][]float32{{0.1, 0.2, 0.3}, {0.4, 0.5, 0.6}, {0.7, 0.8, 0.9}, {1.0, 1.1, 1.2}}

fp16Vector := NewColumnFloat16VectorFromFp32Vector("fp16_vector", dim, data[:2])
fp16Vector.AppendValue(data[2])
fp16Vector.AppendValue(data[3])
for i, vec := range fp16Vector.Data() {
fp32Vector := vec.ToFloat32Vector()
s.Equal(dim, len(fp32Vector))
for j := 0; j < dim; j++ {
s.InDelta(data[i][j], fp32Vector[j], 7e-3)
}
}

bf16Vector := NewColumnBFloat16VectorFromFp32Vector("bf16_vector", dim, data[:2])
bf16Vector.AppendValue(data[2])
bf16Vector.AppendValue(data[3])
for i, vec := range bf16Vector.Data() {
fp32Vector := vec.ToFloat32Vector()
s.Equal(dim, len(fp32Vector))
for j := 0; j < dim; j++ {
s.InDelta(data[i][j], fp32Vector[j], 7e-3)
}
}
})
}

func (s *VectorSuite) TestSlice() {
Expand Down
29 changes: 20 additions & 9 deletions client/entity/vectors.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,8 +17,7 @@
package entity

import (
"encoding/binary"
"math"
"github.com/milvus-io/milvus/pkg/util/typeutil"
)

// Vector interface vector used int search
Expand All @@ -44,13 +43,17 @@ func (fv FloatVector) FieldType() FieldType {
// Serialize serializes vector into byte slice, used in search placeholder
// LittleEndian is used for convention
func (fv FloatVector) Serialize() []byte {
data := make([]byte, 0, 4*len(fv)) // float32 occupies 4 bytes
buf := make([]byte, 4)
for _, f := range fv {
binary.LittleEndian.PutUint32(buf, math.Float32bits(f))
data = append(data, buf...)
}
return data
return typeutil.Float32ArrayToBytes(fv)
}

func (fv FloatVector) ToFloat16Vector() Float16Vector {
return typeutil.Float32ArrayToFloat16Bytes(fv)
}

// SerializeToBFloat16Bytes serializes vector into bfloat16 byte slice,
// used in search placeholder
func (fv FloatVector) ToBFloat16Vector() BFloat16Vector {
return typeutil.Float32ArrayToBFloat16Bytes(fv)
}

// FloatVector float32 vector wrapper.
Expand All @@ -70,6 +73,10 @@ func (fv Float16Vector) Serialize() []byte {
return fv
}

func (fv Float16Vector) ToFloat32Vector() FloatVector {
return typeutil.Float16BytesToFloat32Vector(fv)
}

// FloatVector float32 vector wrapper.
type BFloat16Vector []byte

Expand All @@ -87,6 +94,10 @@ func (fv BFloat16Vector) Serialize() []byte {
return fv
}

func (fv BFloat16Vector) ToFloat32Vector() FloatVector {
return typeutil.BFloat16BytesToFloat32Vector(fv)
}

// BinaryVector []byte vector wrapper
type BinaryVector []byte

Expand Down
43 changes: 42 additions & 1 deletion client/entity/vectors_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -33,9 +33,50 @@ func TestVectors(t *testing.T) {
}

fv := FloatVector(raw)

assert.Equal(t, dim, fv.Dim())
assert.Equal(t, dim*4, len(fv.Serialize()))

var fvConverted FloatVector

fp16v := fv.ToFloat16Vector()
assert.Equal(t, dim, fp16v.Dim())
assert.Equal(t, dim*2, len(fp16v.Serialize()))
fvConverted = fp16v.ToFloat32Vector()
assert.Equal(t, dim, fvConverted.Dim())
assert.Equal(t, dim*4, len(fvConverted.Serialize()))

bf16v := fv.ToBFloat16Vector()
assert.Equal(t, dim, bf16v.Dim())
assert.Equal(t, dim*2, len(bf16v.Serialize()))
fvConverted = bf16v.ToFloat32Vector()
assert.Equal(t, dim, fvConverted.Dim())
assert.Equal(t, dim*4, len(fvConverted.Serialize()))
})

t.Run("test fp32 <-> fp16/bf16 vector conversion", func(t *testing.T) {
raw := make([]float32, dim)
for i := 0; i < dim; i++ {
raw[i] = float32(i) * 0.1
}

fv := FloatVector(raw)
fp16v := fv.ToFloat16Vector()
bf16v := fv.ToBFloat16Vector()

assert.Equal(t, dim, fp16v.Dim())
assert.Equal(t, dim*2, len(fp16v.Serialize()))
assert.Equal(t, dim, bf16v.Dim())
assert.Equal(t, dim*2, len(bf16v.Serialize()))

fp32vFromfp16v := fp16v.ToFloat32Vector()
for i := 0; i < dim; i++ {
assert.InDelta(t, fv[i], fp32vFromfp16v[i], 0.04)
}

fp32vFrombf16v := bf16v.ToFloat32Vector()
for i := 0; i < dim; i++ {
assert.InDelta(t, fp32vFromfp16v[i], fp32vFrombf16v[i], 0.04)
}
})

t.Run("test binary vector", func(t *testing.T) {
Expand Down
5 changes: 3 additions & 2 deletions client/go.mod
Original file line number Diff line number Diff line change
Expand Up @@ -7,13 +7,12 @@ require (
github.com/cockroachdb/errors v1.9.1
github.com/grpc-ecosystem/go-grpc-middleware v1.3.0
github.com/milvus-io/milvus-proto/go-api/v2 v2.3.4-0.20241120015424-93892e628c69
github.com/milvus-io/milvus/pkg v0.0.2-0.20241111021426-5e90f348fcbb
github.com/milvus-io/milvus/pkg v0.0.2-0.20241126032235-cb6542339e84
github.com/quasilyte/go-ruleguard/dsl v0.3.22
github.com/samber/lo v1.27.0
github.com/stretchr/testify v1.9.0
github.com/tidwall/gjson v1.17.1
go.uber.org/atomic v1.10.0
golang.org/x/exp v0.0.0-20230224173230-c95f2b4c22f2
google.golang.org/grpc v1.65.0
google.golang.org/protobuf v1.34.2
)
Expand Down Expand Up @@ -67,6 +66,7 @@ require (
github.com/sirupsen/logrus v1.9.0 // indirect
github.com/soheilhy/cmux v0.1.5 // indirect
github.com/spaolacci/murmur3 v1.1.0 // indirect
github.com/spf13/cast v1.3.1 // indirect
github.com/spf13/pflag v1.0.5 // indirect
github.com/stretchr/objx v0.5.2 // indirect
github.com/tidwall/match v1.1.1 // indirect
Expand Down Expand Up @@ -98,6 +98,7 @@ require (
go.uber.org/multierr v1.11.0 // indirect
go.uber.org/zap v1.27.0 // indirect
golang.org/x/crypto v0.25.0 // indirect
golang.org/x/exp v0.0.0-20230224173230-c95f2b4c22f2 // indirect
golang.org/x/net v0.27.0 // indirect
golang.org/x/sync v0.7.0 // indirect
golang.org/x/sys v0.22.0 // indirect
Expand Down
6 changes: 4 additions & 2 deletions client/go.sum
Original file line number Diff line number Diff line change
Expand Up @@ -320,8 +320,8 @@ github.com/microcosm-cc/bluemonday v1.0.2/go.mod h1:iVP4YcDBq+n/5fb23BhYFvIMq/le
github.com/miekg/dns v1.0.14/go.mod h1:W1PPwlIAgtquWBMBEV9nkV9Cazfe8ScdGz/Lj7v3Nrg=
github.com/milvus-io/milvus-proto/go-api/v2 v2.3.4-0.20241120015424-93892e628c69 h1:Qt0Bv2Fum3EX3OlkuQYHJINBzeU4oEuHy2lXSfB/gZw=
github.com/milvus-io/milvus-proto/go-api/v2 v2.3.4-0.20241120015424-93892e628c69/go.mod h1:/6UT4zZl6awVeXLeE7UGDWZvXj3IWkRsh3mqsn0DiAs=
github.com/milvus-io/milvus/pkg v0.0.2-0.20241111021426-5e90f348fcbb h1:lMyIrG03agASB88AAwnk+NOU9V33lcBdtub/ZEv6IQU=
github.com/milvus-io/milvus/pkg v0.0.2-0.20241111021426-5e90f348fcbb/go.mod h1:w5nu1Z318AvgWQrGUYXaqLeVLu4JvCS/oYhxqctOZvU=
github.com/milvus-io/milvus/pkg v0.0.2-0.20241126032235-cb6542339e84 h1:EAFxmxUVp5yYFDCrX1MQoSxkTO+ycy8NXEqEDEB3cRM=
github.com/milvus-io/milvus/pkg v0.0.2-0.20241126032235-cb6542339e84/go.mod h1:RATa0GS4jhkPpsYOvQ/QvcNz8rd+TlRPDiSyXQnMMxs=
github.com/mitchellh/cli v1.0.0/go.mod h1:hNIlj7HEI86fIcpObd7a0FcrxTWetlwJDGcceTlRvqc=
github.com/mitchellh/go-homedir v1.0.0/go.mod h1:SfyaCUpYCn1Vlf4IUYiD9fPX4A5wJrkLzIz1N1q0pr0=
github.com/mitchellh/go-homedir v1.1.0/go.mod h1:SfyaCUpYCn1Vlf4IUYiD9fPX4A5wJrkLzIz1N1q0pr0=
Expand Down Expand Up @@ -436,6 +436,8 @@ github.com/spaolacci/murmur3 v1.1.0 h1:7c1g84S4BPRrfL5Xrdp6fOJ206sU9y293DDHaoy0b
github.com/spaolacci/murmur3 v1.1.0/go.mod h1:JwIasOWyU6f++ZhiEuf87xNszmSA2myDM2Kzu9HwQUA=
github.com/spf13/afero v1.1.2/go.mod h1:j4pytiNVoe2o6bmDsKpLACNPDBIoEAkihy7loJ1B0CQ=
github.com/spf13/cast v1.3.0/go.mod h1:Qx5cxh0v+4UWYiBimWS+eyWzqEqokIECu5etghLkUJE=
github.com/spf13/cast v1.3.1 h1:nFm6S0SMdyzrzcmThSipiEubIDy8WEXKNZ0UOgiRpng=
github.com/spf13/cast v1.3.1/go.mod h1:Qx5cxh0v+4UWYiBimWS+eyWzqEqokIECu5etghLkUJE=
github.com/spf13/cobra v0.0.5/go.mod h1:3K3wKZymM7VvHMDS9+Akkh4K60UwM26emMESw8tLCHU=
github.com/spf13/cobra v1.1.3/go.mod h1:pGADOWyqRD/YMrPZigI/zbliZ2wVD/23d+is3pSWzOo=
github.com/spf13/jwalterweatherman v1.0.0/go.mod h1:cQK4TGJAtQXfYWX+Ddv3mKDzgVb68N+wFjFa4jdeBTo=
Expand Down
36 changes: 27 additions & 9 deletions client/milvusclient/results_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -31,12 +31,16 @@ type ResultSetSuite struct {

func (s *ResultSetSuite) TestResultsetUnmarshal() {
type MyData struct {
A int64 `milvus:"name:id"`
V []float32 `milvus:"name:vector"`
A int64 `milvus:"name:id"`
V []float32 `milvus:"name:vector"`
Fp16V []byte `milvus:"name:fp16_vector"`
Bf16V []byte `milvus:"name:bf16_vector"`
}
type OtherData struct {
A string `milvus:"name:id"`
V []float32 `milvus:"name:vector"`
A string `milvus:"name:id"`
V []float32 `milvus:"name:vector"`
Fp16V []byte `milvus:"name:fp16_vector"`
Bf16V []byte `milvus:"name:bf16_vector"`
}

var (
Expand All @@ -51,6 +55,8 @@ func (s *ResultSetSuite) TestResultsetUnmarshal() {
rs := DataSet([]column.Column{
column.NewColumnInt64("id", idData),
column.NewColumnFloatVector("vector", 2, vectorData),
column.NewColumnFloat16VectorFromFp32Vector("fp16_vector", 2, vectorData),
column.NewColumnBFloat16VectorFromFp32Vector("bf16_vector", 2, vectorData),
})
err := rs.Unmarshal([]MyData{})
s.Error(err)
Expand All @@ -66,6 +72,8 @@ func (s *ResultSetSuite) TestResultsetUnmarshal() {
for idx, row := range ptrReceiver {
s.Equal(row.A, idData[idx])
s.Equal(row.V, vectorData[idx])
s.Equal(entity.Float16Vector(row.Fp16V), entity.FloatVector(vectorData[idx]).ToFloat16Vector())
s.Equal(entity.BFloat16Vector(row.Bf16V), entity.FloatVector(vectorData[idx]).ToBFloat16Vector())
}

var otherReceiver []*OtherData
Expand All @@ -75,12 +83,16 @@ func (s *ResultSetSuite) TestResultsetUnmarshal() {

func (s *ResultSetSuite) TestSearchResultUnmarshal() {
type MyData struct {
A int64 `milvus:"name:id"`
V []float32 `milvus:"name:vector"`
A int64 `milvus:"name:id"`
V []float32 `milvus:"name:vector"`
Fp16V []byte `milvus:"name:fp16_vector"`
Bf16V []byte `milvus:"name:bf16_vector"`
}
type OtherData struct {
A string `milvus:"name:id"`
V []float32 `milvus:"name:vector"`
A string `milvus:"name:id"`
V []float32 `milvus:"name:vector"`
Fp16V []byte `milvus:"name:fp16_vector"`
Bf16V []byte `milvus:"name:bf16_vector"`
}

var (
Expand All @@ -95,10 +107,14 @@ func (s *ResultSetSuite) TestSearchResultUnmarshal() {
sr := ResultSet{
sch: entity.NewSchema().
WithField(entity.NewField().WithName("id").WithIsPrimaryKey(true).WithDataType(entity.FieldTypeInt64)).
WithField(entity.NewField().WithName("vector").WithDim(2).WithDataType(entity.FieldTypeFloatVector)),
WithField(entity.NewField().WithName("vector").WithDim(2).WithDataType(entity.FieldTypeFloatVector)).
WithField(entity.NewField().WithName("fp16_vector").WithDim(2).WithDataType(entity.FieldTypeFloat16Vector)).
WithField(entity.NewField().WithName("bf16_vector").WithDim(2).WithDataType(entity.FieldTypeBFloat16Vector)),
IDs: column.NewColumnInt64("id", idData),
Fields: DataSet([]column.Column{
column.NewColumnFloatVector("vector", 2, vectorData),
column.NewColumnFloat16VectorFromFp32Vector("fp16_vector", 2, vectorData),
column.NewColumnBFloat16VectorFromFp32Vector("bf16_vector", 2, vectorData),
}),
}
err := sr.Unmarshal([]MyData{})
Expand All @@ -115,6 +131,8 @@ func (s *ResultSetSuite) TestSearchResultUnmarshal() {
for idx, row := range ptrReceiver {
s.Equal(row.A, idData[idx])
s.Equal(row.V, vectorData[idx])
s.Equal(entity.Float16Vector(row.Fp16V), entity.FloatVector(vectorData[idx]).ToFloat16Vector())
s.Equal(entity.BFloat16Vector(row.Bf16V), entity.FloatVector(vectorData[idx]).ToBFloat16Vector())
}

var otherReceiver []*OtherData
Expand Down
26 changes: 25 additions & 1 deletion client/milvusclient/write_options.go
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@ import (
"github.com/milvus-io/milvus/client/v2/column"
"github.com/milvus-io/milvus/client/v2/entity"
"github.com/milvus-io/milvus/client/v2/row"
"github.com/milvus-io/milvus/pkg/util/typeutil"
)

type InsertOption interface {
Expand Down Expand Up @@ -95,13 +96,18 @@ func (opt *columnBasedDataOption) processInsertColumns(colSchema *entity.Schema,
if col.Type() != field.DataType {
return nil, 0, fmt.Errorf("param column %s has type %v but collection field definition is %v", col.Name(), col.Type(), field.DataType)
}
if field.DataType == entity.FieldTypeFloatVector || field.DataType == entity.FieldTypeBinaryVector {
if field.DataType == entity.FieldTypeFloatVector || field.DataType == entity.FieldTypeBinaryVector ||
field.DataType == entity.FieldTypeFloat16Vector || field.DataType == entity.FieldTypeBFloat16Vector {
dim := 0
switch column := col.(type) {
case *column.ColumnFloatVector:
dim = column.Dim()
case *column.ColumnBinaryVector:
dim = column.Dim()
case *column.ColumnFloat16Vector:
dim = column.Dim()
case *column.ColumnBFloat16Vector:
dim = column.Dim()
}
if fmt.Sprintf("%d", dim) != field.TypeParams[entity.TypeParamDim] {
return nil, 0, fmt.Errorf("params column %s vector dim %d not match collection definition, which has dim of %s", field.Name, dim, field.TypeParams[entity.TypeParamDim])
Expand Down Expand Up @@ -205,6 +211,24 @@ func (opt *columnBasedDataOption) WithFloatVectorColumn(colName string, dim int,
return opt.WithColumns(column)
}

func (opt *columnBasedDataOption) WithFloat16VectorColumn(colName string, dim int, data [][]float32) *columnBasedDataOption {
f16v := make([][]byte, 0, len(data))
for i := 0; i < len(data); i++ {
f16v = append(f16v, typeutil.Float32ArrayToFloat16Bytes(data[i]))
}
column := column.NewColumnFloat16Vector(colName, dim, f16v)
return opt.WithColumns(column)
}

func (opt *columnBasedDataOption) WithBFloat16VectorColumn(colName string, dim int, data [][]float32) *columnBasedDataOption {
bf16v := make([][]byte, 0, len(data))
for i := 0; i < len(data); i++ {
bf16v = append(bf16v, typeutil.Float32ArrayToBFloat16Bytes(data[i]))
}
column := column.NewColumnBFloat16Vector(colName, dim, bf16v)
return opt.WithColumns(column)
}

func (opt *columnBasedDataOption) WithBinaryVectorColumn(colName string, dim int, data [][]byte) *columnBasedDataOption {
column := column.NewColumnBinaryVector(colName, dim, data)
return opt.WithColumns(column)
Expand Down
Loading
Loading