Skip to content

Commit

Permalink
fix ModelProperties conversion issue
Browse files Browse the repository at this point in the history
  • Loading branch information
RaynorChavez committed Nov 12, 2024
1 parent 3ecaaf7 commit 809c9a9
Show file tree
Hide file tree
Showing 3 changed files with 254 additions and 33 deletions.
69 changes: 56 additions & 13 deletions internal/provider/indices_datasource.go
Original file line number Diff line number Diff line change
Expand Up @@ -66,9 +66,9 @@ type indexModel struct {

type ModelPropertiesModel struct {
Name types.String `tfsdk:"name"`
Dimensions types.Int64 `tfsdk:"dimensions"`
Dimensions types.String `tfsdk:"dimensions"`
Type types.String `tfsdk:"type"`
Tokens types.Int64 `tfsdk:"tokens"`
Tokens types.String `tfsdk:"tokens"`
ModelLocation *ModelLocationModel `tfsdk:"model_location"`
Url types.String `tfsdk:"url"`
TrustRemoteCode types.Bool `tfsdk:"trust_remote_code"`
Expand Down Expand Up @@ -421,6 +421,58 @@ func convertModelLocation(location go_marqo.ModelLocation) *ModelLocationModel {
return modelLocation
}

func (m *ModelPropertiesModel) IsEmpty() bool {
if m == nil {
return true
}
return m.Name.IsNull() &&
m.Dimensions.IsNull() &&
m.Type.IsNull() &&
m.Tokens.IsNull() &&
m.Url.IsNull() &&
!m.TrustRemoteCode.ValueBool() &&
!m.IsMarqtunedModel.ValueBool() &&
(m.ModelLocation == nil || m.ModelLocation.IsEmpty())
}

// Convert ModelProperties from the API response to our schema model
func convertModelProperties(props *go_marqo.ModelProperties) *ModelPropertiesModel {
if props == nil {
return nil
}

model := &ModelPropertiesModel{}

// Convert only non-empty values
if props.Name != "" {
model.Name = types.StringValue(props.Name)
}
if props.Dimensions != 0 {
model.Dimensions = types.StringValue(fmt.Sprintf("%d", props.Dimensions))
}
if props.Type != "" {
model.Type = types.StringValue(props.Type)
}
if props.Tokens != 0 {
model.Tokens = types.StringValue(fmt.Sprintf("%d", props.Tokens))
}
if props.Url != "" {
model.Url = types.StringValue(props.Url)
}

model.TrustRemoteCode = types.BoolValue(props.TrustRemoteCode)
model.IsMarqtunedModel = types.BoolValue(props.IsMarqtunedModel)

model.ModelLocation = convertModelLocation(props.ModelLocation)

// Only return the model if it's not empty
if model.IsEmpty() {
return nil
}

return model
}

// Read refreshes the Terraform state with the latest data.
func (d *indicesDataSource) Read(ctx context.Context, req datasource.ReadRequest, resp *datasource.ReadResponse) {
tflog.Debug(context.TODO(), "Calling marqo client ListIndices")
Expand Down Expand Up @@ -488,17 +540,8 @@ func (d *indicesDataSource) Read(ctx context.Context, req datasource.ReadRequest
Type: types.StringValue(indexDetail.Type),
VectorNumericType: types.StringValue(indexDetail.VectorNumericType),
Model: types.StringValue(indexDetail.Model),
ModelProperties: &ModelPropertiesModel{
Name: types.StringValue(indexDetail.ModelProperties.Name),
Dimensions: types.Int64Value(indexDetail.ModelProperties.Dimensions),
Type: types.StringValue(indexDetail.ModelProperties.Type),
Tokens: types.Int64Value(indexDetail.ModelProperties.Tokens),
ModelLocation: convertModelLocation(indexDetail.ModelProperties.ModelLocation),
Url: types.StringValue(indexDetail.ModelProperties.Url),
TrustRemoteCode: types.BoolValue(indexDetail.ModelProperties.TrustRemoteCode),
IsMarqtunedModel: types.BoolValue(indexDetail.ModelProperties.IsMarqtunedModel),
},
NormalizeEmbeddings: types.BoolValue(indexDetail.NormalizeEmbeddings),
ModelProperties: convertModelProperties(&indexDetail.ModelProperties),
NormalizeEmbeddings: types.BoolValue(indexDetail.NormalizeEmbeddings),
TextPreprocessing: &TextPreprocessingModel{
SplitLength: types.StringValue(fmt.Sprintf("%d", indexDetail.TextPreprocessing.SplitLength)),
SplitMethod: types.StringValue(indexDetail.TextPreprocessing.SplitMethod),
Expand Down
79 changes: 59 additions & 20 deletions internal/provider/indices_resource.go
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,7 @@ type IndexSettingsModel struct {
TreatUrlsAndPointersAsImages types.Bool `tfsdk:"treat_urls_and_pointers_as_images"`
TreatUrlsAndPointersAsMedia types.Bool `tfsdk:"treat_urls_and_pointers_as_media"`
Model types.String `tfsdk:"model"`
ModelProperties *ModelPropertiesModel `tfsdk:"model_properties"`
ModelProperties *ModelPropertiesModelCreate `tfsdk:"model_properties"`
NormalizeEmbeddings types.Bool `tfsdk:"normalize_embeddings"`
TextPreprocessing *TextPreprocessingModelCreate `tfsdk:"text_preprocessing"`
ImagePreprocessing *ImagePreprocessingModel `tfsdk:"image_preprocessing"`
Expand All @@ -58,6 +58,17 @@ type IndexSettingsModel struct {
FilterStringMaxLength types.Int64 `tfsdk:"filter_string_max_length"`
}

type ModelPropertiesModelCreate struct {
Name types.String `tfsdk:"name"`
Dimensions types.Int64 `tfsdk:"dimensions"`
Type types.String `tfsdk:"type"`
Tokens types.Int64 `tfsdk:"tokens"`
ModelLocation *ModelLocationModel `tfsdk:"model_location"`
Url types.String `tfsdk:"url"`
TrustRemoteCode types.Bool `tfsdk:"trust_remote_code"`
IsMarqtunedModel types.Bool `tfsdk:"is_marqtuned_model"`
}

type AllFieldInput struct {
Name types.String `tfsdk:"name"`
Type types.String `tfsdk:"type"`
Expand Down Expand Up @@ -372,7 +383,7 @@ func convertModelLocationToAPI(modelLocation *ModelLocationModel) map[string]int
return result
}

func (m *ModelPropertiesModel) IsEmpty() bool {
func (m *ModelPropertiesModelCreate) IsEmpty() bool {
if m == nil {
return true
}
Expand All @@ -395,6 +406,43 @@ func (m *ModelLocationModel) IsEmpty() bool {
(m.Hf == nil || (m.Hf.RepoId.IsNull() && m.Hf.Filename.IsNull()))
}

func convertModelPropertiesToResource(props *go_marqo.ModelProperties) *ModelPropertiesModelCreate {
if props == nil {
return nil
}

model := &ModelPropertiesModelCreate{}

// Convert only non-empty values
if props.Name != "" {
model.Name = types.StringValue(props.Name)
}
if props.Dimensions != 0 {
model.Dimensions = types.Int64Value(props.Dimensions)
}
if props.Type != "" {
model.Type = types.StringValue(props.Type)
}
if props.Tokens != 0 {
model.Tokens = types.Int64Value(props.Tokens)
}
if props.Url != "" {
model.Url = types.StringValue(props.Url)
}

model.TrustRemoteCode = types.BoolValue(props.TrustRemoteCode)
model.IsMarqtunedModel = types.BoolValue(props.IsMarqtunedModel)

model.ModelLocation = convertModelLocation(props.ModelLocation)

// Only return the model if it's not empty
if model.IsEmpty() {
return nil
}

return model
}

func (r *indicesResource) findAndCreateState(indices []go_marqo.IndexDetail, indexName string) (*IndexResourceModel, bool) {
for _, indexDetail := range indices {
if indexDetail.IndexName == indexName {
Expand All @@ -407,24 +455,15 @@ func (r *indicesResource) findAndCreateState(indices []go_marqo.IndexDetail, ind
TreatUrlsAndPointersAsImages: types.BoolValue(indexDetail.TreatUrlsAndPointersAsImages),
TreatUrlsAndPointersAsMedia: types.BoolValue(indexDetail.TreatUrlsAndPointersAsMedia),
Model: types.StringValue(indexDetail.Model),
ModelProperties: &ModelPropertiesModel{
Name: types.StringValue(indexDetail.ModelProperties.Name),
Dimensions: types.Int64Value(indexDetail.ModelProperties.Dimensions),
Type: types.StringValue(indexDetail.ModelProperties.Type),
Tokens: types.Int64Value(indexDetail.ModelProperties.Tokens),
ModelLocation: convertModelLocation(indexDetail.ModelProperties.ModelLocation),
Url: types.StringValue(indexDetail.ModelProperties.Url),
TrustRemoteCode: types.BoolValue(indexDetail.ModelProperties.TrustRemoteCode),
IsMarqtunedModel: types.BoolValue(indexDetail.ModelProperties.IsMarqtunedModel),
},
AllFields: ConvertMarqoAllFieldInputs(indexDetail.AllFields),
TensorFields: indexDetail.TensorFields,
NormalizeEmbeddings: types.BoolValue(indexDetail.NormalizeEmbeddings),
InferenceType: types.StringValue(indexDetail.InferenceType),
NumberOfInferences: types.Int64Value(indexDetail.NumberOfInferences),
StorageClass: types.StringValue(indexDetail.StorageClass),
NumberOfShards: types.Int64Value(indexDetail.NumberOfShards),
NumberOfReplicas: types.Int64Value(indexDetail.NumberOfReplicas),
ModelProperties: convertModelPropertiesToResource(&indexDetail.ModelProperties),
AllFields: ConvertMarqoAllFieldInputs(indexDetail.AllFields),
TensorFields: indexDetail.TensorFields,
NormalizeEmbeddings: types.BoolValue(indexDetail.NormalizeEmbeddings),
InferenceType: types.StringValue(indexDetail.InferenceType),
NumberOfInferences: types.Int64Value(indexDetail.NumberOfInferences),
StorageClass: types.StringValue(indexDetail.StorageClass),
NumberOfShards: types.Int64Value(indexDetail.NumberOfShards),
NumberOfReplicas: types.Int64Value(indexDetail.NumberOfReplicas),
TextPreprocessing: &TextPreprocessingModelCreate{
SplitLength: types.Int64Value(indexDetail.TextPreprocessing.SplitLength),
SplitMethod: types.StringValue(indexDetail.TextPreprocessing.SplitMethod),
Expand Down
139 changes: 139 additions & 0 deletions internal/provider/models_unit_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,139 @@
package provider

import (
"testing"

"github.com/hashicorp/terraform-plugin-framework/types"
)

func TestModelPropertiesModel_IsEmpty(t *testing.T) {
tests := []struct {
name string
model *ModelPropertiesModel
expected bool
}{
{
name: "nil model",
model: nil,
expected: true,
},
{
name: "empty model",
model: &ModelPropertiesModel{
Name: types.StringNull(),
Dimensions: types.StringValue(""),
Type: types.StringNull(),
Tokens: types.StringValue(""),
Url: types.StringNull(),
TrustRemoteCode: types.BoolValue(false),
IsMarqtunedModel: types.BoolValue(false),
ModelLocation: nil,
},
expected: true,
},
{
name: "model with values",
model: &ModelPropertiesModel{
Name: types.StringValue("test"),
Dimensions: types.StringValue("384"),
Type: types.StringValue("hf"),
Tokens: types.StringValue("0"),
Url: types.StringValue("https://example.com"),
TrustRemoteCode: types.BoolValue(false),
IsMarqtunedModel: types.BoolValue(false),
ModelLocation: nil,
},
expected: false,
},
{
name: "model with only model location",
model: &ModelPropertiesModel{
Name: types.StringNull(),
Dimensions: types.StringValue(""),
Type: types.StringNull(),
Tokens: types.StringValue(""),
Url: types.StringNull(),
TrustRemoteCode: types.BoolValue(false),
IsMarqtunedModel: types.BoolValue(false),
ModelLocation: &ModelLocationModel{
AuthRequired: types.BoolValue(true),
},
},
expected: false,
},
}

for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
got := tt.model.IsEmpty()
if got != tt.expected {
t.Errorf("IsEmpty() = %v, want %v", got, tt.expected)
}
})
}
}

func TestModelLocationModel_IsEmpty(t *testing.T) {
tests := []struct {
name string
location *ModelLocationModel
expected bool
}{
{
name: "nil location",
location: nil,
expected: true,
},
{
name: "empty location",
location: &ModelLocationModel{
AuthRequired: types.BoolValue(false),
S3: nil,
Hf: nil,
},
expected: true,
},
{
name: "location with S3",
location: &ModelLocationModel{
AuthRequired: types.BoolValue(false),
S3: &S3LocationModel{
Bucket: types.StringValue("test-bucket"),
Key: types.StringValue("test-key"),
},
Hf: nil,
},
expected: false,
},
{
name: "location with HF",
location: &ModelLocationModel{
AuthRequired: types.BoolValue(false),
S3: nil,
Hf: &HfLocationModel{
RepoId: types.StringValue("test-repo"),
Filename: types.StringValue("test-file"),
},
},
expected: false,
},
{
name: "location with auth required only",
location: &ModelLocationModel{
AuthRequired: types.BoolValue(true),
S3: nil,
Hf: nil,
},
expected: false,
},
}

for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
got := tt.location.IsEmpty()
if got != tt.expected {
t.Errorf("IsEmpty() = %v, want %v", got, tt.expected)
}
})
}
}

0 comments on commit 809c9a9

Please sign in to comment.