From 601c1fbd814d7416420e06f10931cfcb9288afb5 Mon Sep 17 00:00:00 2001 From: Raynor Chavez Date: Mon, 11 Nov 2024 15:24:17 +1100 Subject: [PATCH] address comments --- go_marqo/client.go | 31 ++++-- internal/provider/indices_datasource.go | 104 +++++++++++++++---- internal/provider/indices_resource.go | 132 ++++++++++++++++++------ 3 files changed, 209 insertions(+), 58 deletions(-) diff --git a/go_marqo/client.go b/go_marqo/client.go index f67a5af..9a93732 100644 --- a/go_marqo/client.go +++ b/go_marqo/client.go @@ -61,13 +61,30 @@ type AllFieldInput struct { } type ModelProperties struct { - Name string `json:"name"` - Dimensions int64 `json:"dimensions"` - Type string `json:"type"` - Tokens int64 `json:"tokens"` - ModelLocation string `json:"model_location"` - Url string `json:"url"` - TrustRemoteCode bool `json:"trustRemoteCode"` + Name string `json:"name"` + Dimensions int64 `json:"dimensions"` + Type string `json:"type"` + Tokens int64 `json:"tokens"` + ModelLocation ModelLocation `json:"modelLocation"` + Url string `json:"url"` + TrustRemoteCode bool `json:"trustRemoteCode"` + IsMarqtunedModel bool `json:"isMarqtunedModel"` +} + +type ModelLocation struct { + S3 *S3Location `json:"s3,omitempty"` + Hf *HfLocation `json:"hf,omitempty"` + AuthRequired bool `json:"authRequired"` +} + +type S3Location struct { + Bucket string `json:"bucket"` + Key string `json:"key"` +} + +type HfLocation struct { + RepoId string `json:"repoId"` + Filename string `json:"filename"` } type ImagePreprocessingModel struct { diff --git a/internal/provider/indices_datasource.go b/internal/provider/indices_datasource.go index 9dc4357..46784dd 100644 --- a/internal/provider/indices_datasource.go +++ b/internal/provider/indices_datasource.go @@ -62,13 +62,30 @@ type indexModel struct { } type ModelPropertiesModel struct { - Name types.String `tfsdk:"name"` - Dimensions types.String `tfsdk:"dimensions"` - Type types.String `tfsdk:"type"` - Tokens types.String `tfsdk:"tokens"` - ModelLocation types.String `tfsdk:"model_location"` - Url types.String `tfsdk:"url"` - TrustRemoteCode types.String `tfsdk:"trust_remote_code"` + Name types.String `tfsdk:"name"` + Dimensions types.String `tfsdk:"dimensions"` + Type types.String `tfsdk:"type"` + Tokens types.String `tfsdk:"tokens"` + ModelLocation ModelLocationModel `tfsdk:"model_location"` + Url types.String `tfsdk:"url"` + TrustRemoteCode types.String `tfsdk:"trust_remote_code"` + IsMarqtunedModel types.Bool `tfsdk:"is_marqtuned_model"` +} + +type ModelLocationModel struct { + S3 *S3LocationModel `tfsdk:"s3"` + Hf *HfLocationModel `tfsdk:"hf"` + AuthRequired types.Bool `tfsdk:"auth_required"` +} + +type S3LocationModel struct { + Bucket types.String `tfsdk:"bucket"` + Key types.String `tfsdk:"key"` +} + +type HfLocationModel struct { + RepoId types.String `tfsdk:"repo_id"` + Filename types.String `tfsdk:"filename"` } type TextPreprocessingModel struct { @@ -236,13 +253,35 @@ func (d *indicesDataSource) Schema(_ context.Context, _ datasource.SchemaRequest "model_properties": schema.SingleNestedAttribute{ Computed: true, Attributes: map[string]schema.Attribute{ - "name": schema.StringAttribute{Computed: true}, - "dimensions": schema.StringAttribute{Computed: true}, - "type": schema.StringAttribute{Computed: true}, - "tokens": schema.StringAttribute{Computed: true}, - "model_location": schema.StringAttribute{Computed: true}, - "url": schema.StringAttribute{Computed: true}, - "trust_remote_code": schema.StringAttribute{Computed: true}, + "name": schema.StringAttribute{Computed: true}, + "dimensions": schema.StringAttribute{Computed: true}, + "type": schema.StringAttribute{Computed: true}, + "tokens": schema.StringAttribute{Computed: true}, + "model_location": schema.SingleNestedAttribute{ + Computed: true, + Attributes: map[string]schema.Attribute{ + "s3": schema.SingleNestedAttribute{ + Computed: true, + Optional: true, + Attributes: map[string]schema.Attribute{ + "bucket": schema.StringAttribute{Computed: true}, + "key": schema.StringAttribute{Computed: true}, + }, + }, + "hf": schema.SingleNestedAttribute{ + Computed: true, + Optional: true, + Attributes: map[string]schema.Attribute{ + "repo_id": schema.StringAttribute{Computed: true}, + "filename": schema.StringAttribute{Computed: true}, + }, + }, + "auth_required": schema.BoolAttribute{Computed: true}, + }, + }, + "url": schema.StringAttribute{Computed: true}, + "trust_remote_code": schema.StringAttribute{Computed: true}, + "is_marqtuned_model": schema.BoolAttribute{Computed: true}, }, }, "normalize_embeddings": schema.BoolAttribute{ @@ -326,6 +365,28 @@ func ConvertMarqoAllFieldInputs(marqoFields []go_marqo.AllFieldInput) []AllField return allFieldsConverted } +func convertModelLocation(location go_marqo.ModelLocation) ModelLocationModel { + modelLocation := ModelLocationModel{ + AuthRequired: types.BoolValue(location.AuthRequired), + } + + if location.S3 != nil { + modelLocation.S3 = &S3LocationModel{ + Bucket: types.StringValue(location.S3.Bucket), + Key: types.StringValue(location.S3.Key), + } + } + + if location.Hf != nil { + modelLocation.Hf = &HfLocationModel{ + RepoId: types.StringValue(location.Hf.RepoId), + Filename: types.StringValue(location.Hf.Filename), + } + } + + return modelLocation +} + // Read refreshes the Terraform state with the latest data. func (d *indicesDataSource) Read(ctx context.Context, req datasource.ReadRequest, resp *datasource.ReadResponse) { tflog.Debug(context.TODO(), "Calling marqo client ListIndices") @@ -394,13 +455,14 @@ func (d *indicesDataSource) Read(ctx context.Context, req datasource.ReadRequest VectorNumericType: types.StringValue(indexDetail.VectorNumericType), Model: types.StringValue(indexDetail.Model), ModelProperties: ModelPropertiesModel{ - Name: types.StringValue(indexDetail.ModelProperties.Name), - Dimensions: types.StringValue(fmt.Sprintf("%d", indexDetail.ModelProperties.Dimensions)), - Type: types.StringValue(indexDetail.ModelProperties.Type), - Tokens: types.StringValue(fmt.Sprintf("%d", indexDetail.ModelProperties.Tokens)), - ModelLocation: types.StringValue(indexDetail.ModelProperties.ModelLocation), - Url: types.StringValue(indexDetail.ModelProperties.Url), - TrustRemoteCode: types.StringValue(fmt.Sprintf("%t", indexDetail.ModelProperties.TrustRemoteCode)), + Name: types.StringValue(indexDetail.ModelProperties.Name), + Dimensions: types.StringValue(fmt.Sprintf("%d", indexDetail.ModelProperties.Dimensions)), + Type: types.StringValue(indexDetail.ModelProperties.Type), + Tokens: types.StringValue(fmt.Sprintf("%d", indexDetail.ModelProperties.Tokens)), + ModelLocation: convertModelLocation(indexDetail.ModelProperties.ModelLocation), + Url: types.StringValue(indexDetail.ModelProperties.Url), + TrustRemoteCode: types.StringValue(fmt.Sprintf("%t", indexDetail.ModelProperties.TrustRemoteCode)), + IsMarqtunedModel: types.BoolValue(indexDetail.ModelProperties.IsMarqtunedModel), }, NormalizeEmbeddings: types.BoolValue(indexDetail.NormalizeEmbeddings), TextPreprocessing: TextPreprocessingModel{ diff --git a/internal/provider/indices_resource.go b/internal/provider/indices_resource.go index 8acaac5..41ce681 100644 --- a/internal/provider/indices_resource.go +++ b/internal/provider/indices_resource.go @@ -66,13 +66,14 @@ type AllFieldInput struct { } type ModelPropertiesModelCreate struct { - Name types.String `tfsdk:"name"` - Dimensions types.Int64 `tfsdk:"dimensions"` - Type types.String `tfsdk:"type"` - Tokens types.Int64 `tfsdk:"tokens"` - ModelLocation types.String `tfsdk:"model_location"` - Url types.String `tfsdk:"url"` - TrustRemoteCode types.Bool `tfsdk:"trust_remote_code"` + Name types.String `tfsdk:"name"` + Dimensions types.Int64 `tfsdk:"dimensions"` + Type types.String `tfsdk:"type"` + Tokens types.Int64 `tfsdk:"tokens"` + ModelLocation *ModelLocationModel `tfsdk:"model_location"` + Url types.String `tfsdk:"url"` + TrustRemoteCode types.Bool `tfsdk:"trust_remote_code"` + IsMarqtunedModel types.Bool `tfsdk:"is_marqtuned_model"` } type TextPreprocessingModelCreate struct { @@ -197,13 +198,33 @@ func (r *indicesResource) Schema(_ context.Context, _ resource.SchemaRequest, re "model_properties": schema.SingleNestedAttribute{ Optional: true, Attributes: map[string]schema.Attribute{ - "name": schema.StringAttribute{Optional: true}, - "dimensions": schema.Int64Attribute{Optional: true}, - "type": schema.StringAttribute{Optional: true}, - "tokens": schema.Int64Attribute{Optional: true}, - "model_location": schema.StringAttribute{Optional: true}, - "url": schema.StringAttribute{Optional: true}, - "trust_remote_code": schema.BoolAttribute{Optional: true}, + "name": schema.StringAttribute{Optional: true}, + "dimensions": schema.Int64Attribute{Optional: true}, + "type": schema.StringAttribute{Optional: true}, + "tokens": schema.Int64Attribute{Optional: true}, + "model_location": schema.SingleNestedAttribute{ + Optional: true, + Attributes: map[string]schema.Attribute{ + "s3": schema.SingleNestedAttribute{ + Optional: true, + Attributes: map[string]schema.Attribute{ + "bucket": schema.StringAttribute{Optional: true}, + "key": schema.StringAttribute{Optional: true}, + }, + }, + "hf": schema.SingleNestedAttribute{ + Optional: true, + Attributes: map[string]schema.Attribute{ + "repo_id": schema.StringAttribute{Optional: true}, + "filename": schema.StringAttribute{Optional: true}, + }, + }, + "auth_required": schema.BoolAttribute{Optional: true}, + }, + }, + "url": schema.StringAttribute{Optional: true}, + "trust_remote_code": schema.BoolAttribute{Optional: true}, + "is_marqtuned_model": schema.BoolAttribute{Optional: true}, }, }, "normalize_embeddings": schema.BoolAttribute{ @@ -336,9 +357,53 @@ func convertAllFieldsToMap(allFieldsInput []AllFieldInput) []map[string]interfac return allFields } +func convertModelLocationToAPI(modelLocation *ModelLocationModel) map[string]interface{} { + if modelLocation == nil { + return nil + } + + result := map[string]interface{}{ + "authRequired": modelLocation.AuthRequired.ValueBool(), + } + + if modelLocation.S3 != nil { + result["s3"] = map[string]interface{}{ + "bucket": modelLocation.S3.Bucket.ValueString(), + "key": modelLocation.S3.Key.ValueString(), + } + } + + if modelLocation.Hf != nil { + result["hf"] = map[string]interface{}{ + "repoId": modelLocation.Hf.RepoId.ValueString(), + "filename": modelLocation.Hf.Filename.ValueString(), + } + } + + return result +} + func (r *indicesResource) findAndCreateState(indices []go_marqo.IndexDetail, indexName string) (*IndexResourceModel, bool) { for _, indexDetail := range indices { if indexDetail.IndexName == indexName { + modelLocation := &ModelLocationModel{ + AuthRequired: types.BoolValue(indexDetail.ModelProperties.ModelLocation.AuthRequired), + } + + if indexDetail.ModelProperties.ModelLocation.S3 != nil { + modelLocation.S3 = &S3LocationModel{ + Bucket: types.StringValue(indexDetail.ModelProperties.ModelLocation.S3.Bucket), + Key: types.StringValue(indexDetail.ModelProperties.ModelLocation.S3.Key), + } + } + + if indexDetail.ModelProperties.ModelLocation.Hf != nil { + modelLocation.Hf = &HfLocationModel{ + RepoId: types.StringValue(indexDetail.ModelProperties.ModelLocation.Hf.RepoId), + Filename: types.StringValue(indexDetail.ModelProperties.ModelLocation.Hf.Filename), + } + } + return &IndexResourceModel{ //ID: types.StringValue(indexDetail.IndexName), IndexName: types.StringValue(indexDetail.IndexName), @@ -349,13 +414,14 @@ func (r *indicesResource) findAndCreateState(indices []go_marqo.IndexDetail, ind TreatUrlsAndPointersAsMedia: types.BoolValue(indexDetail.TreatUrlsAndPointersAsMedia), Model: types.StringValue(indexDetail.Model), ModelProperties: &ModelPropertiesModelCreate{ - Name: types.StringValue(indexDetail.ModelProperties.Name), - Dimensions: types.Int64Value(indexDetail.ModelProperties.Dimensions), - Type: types.StringValue(indexDetail.ModelProperties.Type), - Tokens: types.Int64Value(indexDetail.ModelProperties.Tokens), - ModelLocation: types.StringValue(indexDetail.ModelProperties.ModelLocation), - Url: types.StringValue(indexDetail.ModelProperties.Url), - TrustRemoteCode: types.BoolValue(indexDetail.ModelProperties.TrustRemoteCode), + Name: types.StringValue(indexDetail.ModelProperties.Name), + Dimensions: types.Int64Value(indexDetail.ModelProperties.Dimensions), + Type: types.StringValue(indexDetail.ModelProperties.Type), + Tokens: types.Int64Value(indexDetail.ModelProperties.Tokens), + ModelLocation: modelLocation, + Url: types.StringValue(indexDetail.ModelProperties.Url), + TrustRemoteCode: types.BoolValue(indexDetail.ModelProperties.TrustRemoteCode), + IsMarqtunedModel: types.BoolValue(indexDetail.ModelProperties.IsMarqtunedModel), }, AllFields: ConvertMarqoAllFieldInputs(indexDetail.AllFields), TensorFields: indexDetail.TensorFields, @@ -522,15 +588,21 @@ func (r *indicesResource) Create(ctx context.Context, req resource.CreateRequest } // Optional dictionary fields if model.Settings.ModelProperties != nil { - settings["modelProperties"] = map[string]interface{}{ - "url": model.Settings.ModelProperties.Url.ValueString(), - "dimensions": model.Settings.ModelProperties.Dimensions.ValueInt64(), - "type": model.Settings.ModelProperties.Type.ValueString(), - "tokens": model.Settings.ModelProperties.Tokens.ValueInt64(), - "modelLocation": model.Settings.ModelProperties.ModelLocation.ValueString(), - "name": model.Settings.ModelProperties.Name.ValueString(), - "trustRemoteCode": model.Settings.ModelProperties.TrustRemoteCode.ValueBool(), + modelPropertiesMap := map[string]interface{}{ + "name": model.Settings.ModelProperties.Name.ValueString(), + "dimensions": model.Settings.ModelProperties.Dimensions.ValueInt64(), + "type": model.Settings.ModelProperties.Type.ValueString(), + "tokens": model.Settings.ModelProperties.Tokens.ValueInt64(), + "url": model.Settings.ModelProperties.Url.ValueString(), + "trustRemoteCode": model.Settings.ModelProperties.TrustRemoteCode.ValueBool(), + "isMarqtunedModel": model.Settings.ModelProperties.IsMarqtunedModel.ValueBool(), } + + if model.Settings.ModelProperties.ModelLocation != nil { + modelPropertiesMap["modelLocation"] = convertModelLocationToAPI(model.Settings.ModelProperties.ModelLocation) + } + + settings["modelProperties"] = modelPropertiesMap } if model.Settings.TextPreprocessing != nil { settings["textPreprocessing"] = map[string]interface{}{ @@ -587,7 +659,7 @@ func (r *indicesResource) Create(ctx context.Context, req resource.CreateRequest model.Settings.ModelProperties.Dimensions.IsNull() && model.Settings.ModelProperties.Type.IsNull() && model.Settings.ModelProperties.Tokens.IsNull() && - model.Settings.ModelProperties.ModelLocation.IsNull() && + model.Settings.ModelProperties.ModelLocation == nil && model.Settings.ModelProperties.Url.IsNull() && model.Settings.ModelProperties.TrustRemoteCode.IsNull()) { delete(settings, "modelProperties")