Skip to content

Commit

Permalink
fix bugs and tests
Browse files Browse the repository at this point in the history
  • Loading branch information
RaynorChavez committed Nov 12, 2024
1 parent 99be041 commit c2f1383
Show file tree
Hide file tree
Showing 3 changed files with 53 additions and 26 deletions.
33 changes: 21 additions & 12 deletions internal/provider/indices_datasource.go
Original file line number Diff line number Diff line change
Expand Up @@ -298,7 +298,7 @@ func (d *indicesDataSource) Schema(_ context.Context, _ datasource.SchemaRequest
},
},
"url": schema.StringAttribute{Computed: true},
"trust_remote_code": schema.StringAttribute{Computed: true},
"trust_remote_code": schema.BoolAttribute{Computed: true},
"is_marqtuned_model": schema.BoolAttribute{Computed: true},
},
},
Expand Down Expand Up @@ -400,25 +400,32 @@ func ConvertMarqoAllFieldInputs(marqoFields []go_marqo.AllFieldInput) []AllField
}

func convertModelLocation(location go_marqo.ModelLocation) *ModelLocationModel {
modelLocation := &ModelLocationModel{
AuthRequired: types.BoolValue(location.AuthRequired),
model := &ModelLocationModel{}

if location.AuthRequired {
model.AuthRequired = types.BoolValue(true)
}

if location.S3 != nil {
modelLocation.S3 = &S3LocationModel{
if location.S3 != nil && (location.S3.Bucket != "" || location.S3.Key != "") {
model.S3 = &S3LocationModel{
Bucket: types.StringValue(location.S3.Bucket),
Key: types.StringValue(location.S3.Key),
}
}

if location.Hf != nil {
modelLocation.Hf = &HfLocationModel{
if location.Hf != nil && (location.Hf.RepoId != "" || location.Hf.Filename != "") {
model.Hf = &HfLocationModel{
RepoId: types.StringValue(location.Hf.RepoId),
Filename: types.StringValue(location.Hf.Filename),
}
}

return modelLocation
// Return nil if no fields were set
if model.AuthRequired.IsNull() && model.S3 == nil && model.Hf == nil {
return nil
}

return model
}

func (m *ModelPropertiesModel) IsEmpty() bool {
Expand All @@ -430,8 +437,6 @@ func (m *ModelPropertiesModel) IsEmpty() bool {
m.Type.IsNull() &&
m.Tokens.IsNull() &&
m.Url.IsNull() &&
!m.TrustRemoteCode.ValueBool() &&
!m.IsMarqtunedModel.ValueBool() &&
(m.ModelLocation == nil || m.ModelLocation.IsEmpty())
}

Expand Down Expand Up @@ -460,8 +465,12 @@ func convertModelProperties(props *go_marqo.ModelProperties) *ModelPropertiesMod
model.Url = types.StringValue(props.Url)
}

model.TrustRemoteCode = types.BoolValue(props.TrustRemoteCode)
model.IsMarqtunedModel = types.BoolValue(props.IsMarqtunedModel)
if props.TrustRemoteCode {
model.TrustRemoteCode = types.BoolValue(true)
}
if props.IsMarqtunedModel {
model.IsMarqtunedModel = types.BoolValue(true)
}

model.ModelLocation = convertModelLocation(props.ModelLocation)

Expand Down
21 changes: 12 additions & 9 deletions internal/provider/indices_resource.go
Original file line number Diff line number Diff line change
Expand Up @@ -392,16 +392,14 @@ func (m *ModelPropertiesModelCreate) IsEmpty() bool {
m.Type.IsNull() &&
m.Tokens.IsNull() &&
m.Url.IsNull() &&
!m.TrustRemoteCode.ValueBool() &&
!m.IsMarqtunedModel.ValueBool() &&
(m.ModelLocation == nil || m.ModelLocation.IsEmpty())
}

func (m *ModelLocationModel) IsEmpty() bool {
if m == nil {
return true
}
return !m.AuthRequired.ValueBool() &&
return m.AuthRequired.IsNull() &&
(m.S3 == nil || (m.S3.Bucket.IsNull() && m.S3.Key.IsNull())) &&
(m.Hf == nil || (m.Hf.RepoId.IsNull() && m.Hf.Filename.IsNull()))
}
Expand Down Expand Up @@ -429,13 +427,18 @@ func convertModelPropertiesToResource(props *go_marqo.ModelProperties) *ModelPro
if props.Url != "" {
model.Url = types.StringValue(props.Url)
}
if props.TrustRemoteCode {
model.TrustRemoteCode = types.BoolValue(true)
}
if props.IsMarqtunedModel {
model.IsMarqtunedModel = types.BoolValue(true)
}
// Only convert ModelLocation if it has non-null values
if loc := convertModelLocation(props.ModelLocation); loc != nil {
model.ModelLocation = loc
}

model.TrustRemoteCode = types.BoolValue(props.TrustRemoteCode)
model.IsMarqtunedModel = types.BoolValue(props.IsMarqtunedModel)

model.ModelLocation = convertModelLocation(props.ModelLocation)

// Only return the model if it's not empty
// Only return the model if it's not empty.
if model.IsEmpty() {
return nil
}
Expand Down
25 changes: 20 additions & 5 deletions internal/provider/indices_resource_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,8 +8,7 @@ import (
"github.com/hashicorp/terraform-plugin-testing/terraform"
)

func TestAccResourceIndex(t *testing.T) {
unstructured_langbind_index_name := fmt.Sprintf("donotdelete_unstr_resrc_%s", randomString(6))
func TestAccResourceCustomModelIndex(t *testing.T) {
unstructured_custom_model_index_name := fmt.Sprintf("donotdelete_unstr_resrc_%s", randomString(6))
resource.Test(t, resource.TestCase{
PreCheck: func() { testAccPreCheck(t) },
Expand All @@ -19,7 +18,6 @@ func TestAccResourceIndex(t *testing.T) {
{
Config: testAccEmptyConfig(),
Check: resource.ComposeTestCheckFunc(
testAccCheckIndexExistsAndDelete(unstructured_langbind_index_name),
testAccCheckIndexExistsAndDelete(unstructured_custom_model_index_name),
),
},
Expand All @@ -37,14 +35,31 @@ func TestAccResourceIndex(t *testing.T) {
resource.TestCheckResourceAttr("marqo_index.test", "settings.model_properties.url", "https://marqo-ecs-50-audio-test-dataset.s3.us-east-1.amazonaws.com/test-hf.zip"),
resource.TestCheckResourceAttr("marqo_index.test", "settings.model_properties.dimensions", "384"),
resource.TestCheckResourceAttr("marqo_index.test", "settings.model_properties.type", "hf"),
resource.TestCheckResourceAttr("marqo_index.test", "settings.model_properties.trust_remote_code", "false"),
testAccCheckIndexIsReady(unstructured_custom_model_index_name),
func(s *terraform.State) error {
fmt.Println("Custom Model testing completed")
return nil
},
),
},
// Delete testing automatically occurs in TestCase
},
})
}

func TestAccResourceLangBindIndex(t *testing.T) {
unstructured_langbind_index_name := fmt.Sprintf("donotdelete_unstr_resrc_%s", randomString(6))
resource.Test(t, resource.TestCase{
PreCheck: func() { testAccPreCheck(t) },
ProtoV6ProviderFactories: testAccProtoV6ProviderFactories,
Steps: []resource.TestStep{
// Check if index exists and delete if it does
{
Config: testAccEmptyConfig(),
Check: resource.ComposeTestCheckFunc(
testAccCheckIndexExistsAndDelete(unstructured_langbind_index_name),
),
},
// Create and Read testing
{
Config: testAccResourceIndexConfig(unstructured_langbind_index_name),
Expand Down Expand Up @@ -307,12 +322,12 @@ func testAccResourceIndexConfigCustomModel(name string) string {
type = "unstructured"
vector_numeric_type = "float"
treat_urls_and_pointers_as_images = true
treat_urls_and_pointers_as_media = true
model = "custom-model"
model_properties = {
url = "https://marqo-ecs-50-audio-test-dataset.s3.us-east-1.amazonaws.com/test-hf.zip"
dimensions = 384
type = "hf"
trust_remote_code = false
}
normalize_embeddings = true
inference_type = "marqo.CPU.small"
Expand Down

0 comments on commit c2f1383

Please sign in to comment.