Skip to content

Commit

Permalink
address comments
Browse files Browse the repository at this point in the history
  • Loading branch information
RaynorChavez committed Nov 11, 2024
1 parent 591d4a2 commit 601c1fb
Show file tree
Hide file tree
Showing 3 changed files with 209 additions and 58 deletions.
31 changes: 24 additions & 7 deletions go_marqo/client.go
Original file line number Diff line number Diff line change
Expand Up @@ -61,13 +61,30 @@ type AllFieldInput struct {
}

type ModelProperties struct {
Name string `json:"name"`
Dimensions int64 `json:"dimensions"`
Type string `json:"type"`
Tokens int64 `json:"tokens"`
ModelLocation string `json:"model_location"`
Url string `json:"url"`
TrustRemoteCode bool `json:"trustRemoteCode"`
Name string `json:"name"`
Dimensions int64 `json:"dimensions"`
Type string `json:"type"`
Tokens int64 `json:"tokens"`
ModelLocation ModelLocation `json:"modelLocation"`
Url string `json:"url"`
TrustRemoteCode bool `json:"trustRemoteCode"`
IsMarqtunedModel bool `json:"isMarqtunedModel"`
}

type ModelLocation struct {
S3 *S3Location `json:"s3,omitempty"`
Hf *HfLocation `json:"hf,omitempty"`
AuthRequired bool `json:"authRequired"`
}

type S3Location struct {
Bucket string `json:"bucket"`
Key string `json:"key"`
}

type HfLocation struct {
RepoId string `json:"repoId"`
Filename string `json:"filename"`
}

type ImagePreprocessingModel struct {
Expand Down
104 changes: 83 additions & 21 deletions internal/provider/indices_datasource.go
Original file line number Diff line number Diff line change
Expand Up @@ -62,13 +62,30 @@ type indexModel struct {
}

type ModelPropertiesModel struct {
Name types.String `tfsdk:"name"`
Dimensions types.String `tfsdk:"dimensions"`
Type types.String `tfsdk:"type"`
Tokens types.String `tfsdk:"tokens"`
ModelLocation types.String `tfsdk:"model_location"`
Url types.String `tfsdk:"url"`
TrustRemoteCode types.String `tfsdk:"trust_remote_code"`
Name types.String `tfsdk:"name"`
Dimensions types.String `tfsdk:"dimensions"`
Type types.String `tfsdk:"type"`
Tokens types.String `tfsdk:"tokens"`
ModelLocation ModelLocationModel `tfsdk:"model_location"`
Url types.String `tfsdk:"url"`
TrustRemoteCode types.String `tfsdk:"trust_remote_code"`
IsMarqtunedModel types.Bool `tfsdk:"is_marqtuned_model"`
}

type ModelLocationModel struct {
S3 *S3LocationModel `tfsdk:"s3"`
Hf *HfLocationModel `tfsdk:"hf"`
AuthRequired types.Bool `tfsdk:"auth_required"`
}

type S3LocationModel struct {
Bucket types.String `tfsdk:"bucket"`
Key types.String `tfsdk:"key"`
}

type HfLocationModel struct {
RepoId types.String `tfsdk:"repo_id"`
Filename types.String `tfsdk:"filename"`
}

type TextPreprocessingModel struct {
Expand Down Expand Up @@ -236,13 +253,35 @@ func (d *indicesDataSource) Schema(_ context.Context, _ datasource.SchemaRequest
"model_properties": schema.SingleNestedAttribute{
Computed: true,
Attributes: map[string]schema.Attribute{
"name": schema.StringAttribute{Computed: true},
"dimensions": schema.StringAttribute{Computed: true},
"type": schema.StringAttribute{Computed: true},
"tokens": schema.StringAttribute{Computed: true},
"model_location": schema.StringAttribute{Computed: true},
"url": schema.StringAttribute{Computed: true},
"trust_remote_code": schema.StringAttribute{Computed: true},
"name": schema.StringAttribute{Computed: true},
"dimensions": schema.StringAttribute{Computed: true},
"type": schema.StringAttribute{Computed: true},
"tokens": schema.StringAttribute{Computed: true},
"model_location": schema.SingleNestedAttribute{
Computed: true,
Attributes: map[string]schema.Attribute{
"s3": schema.SingleNestedAttribute{
Computed: true,
Optional: true,
Attributes: map[string]schema.Attribute{
"bucket": schema.StringAttribute{Computed: true},
"key": schema.StringAttribute{Computed: true},
},
},
"hf": schema.SingleNestedAttribute{
Computed: true,
Optional: true,
Attributes: map[string]schema.Attribute{
"repo_id": schema.StringAttribute{Computed: true},
"filename": schema.StringAttribute{Computed: true},
},
},
"auth_required": schema.BoolAttribute{Computed: true},
},
},
"url": schema.StringAttribute{Computed: true},
"trust_remote_code": schema.StringAttribute{Computed: true},
"is_marqtuned_model": schema.BoolAttribute{Computed: true},
},
},
"normalize_embeddings": schema.BoolAttribute{
Expand Down Expand Up @@ -326,6 +365,28 @@ func ConvertMarqoAllFieldInputs(marqoFields []go_marqo.AllFieldInput) []AllField
return allFieldsConverted
}

func convertModelLocation(location go_marqo.ModelLocation) ModelLocationModel {
modelLocation := ModelLocationModel{
AuthRequired: types.BoolValue(location.AuthRequired),
}

if location.S3 != nil {
modelLocation.S3 = &S3LocationModel{
Bucket: types.StringValue(location.S3.Bucket),
Key: types.StringValue(location.S3.Key),
}
}

if location.Hf != nil {
modelLocation.Hf = &HfLocationModel{
RepoId: types.StringValue(location.Hf.RepoId),
Filename: types.StringValue(location.Hf.Filename),
}
}

return modelLocation
}

// Read refreshes the Terraform state with the latest data.
func (d *indicesDataSource) Read(ctx context.Context, req datasource.ReadRequest, resp *datasource.ReadResponse) {
tflog.Debug(context.TODO(), "Calling marqo client ListIndices")
Expand Down Expand Up @@ -394,13 +455,14 @@ func (d *indicesDataSource) Read(ctx context.Context, req datasource.ReadRequest
VectorNumericType: types.StringValue(indexDetail.VectorNumericType),
Model: types.StringValue(indexDetail.Model),
ModelProperties: ModelPropertiesModel{
Name: types.StringValue(indexDetail.ModelProperties.Name),
Dimensions: types.StringValue(fmt.Sprintf("%d", indexDetail.ModelProperties.Dimensions)),
Type: types.StringValue(indexDetail.ModelProperties.Type),
Tokens: types.StringValue(fmt.Sprintf("%d", indexDetail.ModelProperties.Tokens)),
ModelLocation: types.StringValue(indexDetail.ModelProperties.ModelLocation),
Url: types.StringValue(indexDetail.ModelProperties.Url),
TrustRemoteCode: types.StringValue(fmt.Sprintf("%t", indexDetail.ModelProperties.TrustRemoteCode)),
Name: types.StringValue(indexDetail.ModelProperties.Name),
Dimensions: types.StringValue(fmt.Sprintf("%d", indexDetail.ModelProperties.Dimensions)),
Type: types.StringValue(indexDetail.ModelProperties.Type),
Tokens: types.StringValue(fmt.Sprintf("%d", indexDetail.ModelProperties.Tokens)),
ModelLocation: convertModelLocation(indexDetail.ModelProperties.ModelLocation),
Url: types.StringValue(indexDetail.ModelProperties.Url),
TrustRemoteCode: types.StringValue(fmt.Sprintf("%t", indexDetail.ModelProperties.TrustRemoteCode)),
IsMarqtunedModel: types.BoolValue(indexDetail.ModelProperties.IsMarqtunedModel),
},
NormalizeEmbeddings: types.BoolValue(indexDetail.NormalizeEmbeddings),
TextPreprocessing: TextPreprocessingModel{
Expand Down
132 changes: 102 additions & 30 deletions internal/provider/indices_resource.go
Original file line number Diff line number Diff line change
Expand Up @@ -66,13 +66,14 @@ type AllFieldInput struct {
}

type ModelPropertiesModelCreate struct {
Name types.String `tfsdk:"name"`
Dimensions types.Int64 `tfsdk:"dimensions"`
Type types.String `tfsdk:"type"`
Tokens types.Int64 `tfsdk:"tokens"`
ModelLocation types.String `tfsdk:"model_location"`
Url types.String `tfsdk:"url"`
TrustRemoteCode types.Bool `tfsdk:"trust_remote_code"`
Name types.String `tfsdk:"name"`
Dimensions types.Int64 `tfsdk:"dimensions"`
Type types.String `tfsdk:"type"`
Tokens types.Int64 `tfsdk:"tokens"`
ModelLocation *ModelLocationModel `tfsdk:"model_location"`
Url types.String `tfsdk:"url"`
TrustRemoteCode types.Bool `tfsdk:"trust_remote_code"`
IsMarqtunedModel types.Bool `tfsdk:"is_marqtuned_model"`
}

type TextPreprocessingModelCreate struct {
Expand Down Expand Up @@ -197,13 +198,33 @@ func (r *indicesResource) Schema(_ context.Context, _ resource.SchemaRequest, re
"model_properties": schema.SingleNestedAttribute{
Optional: true,
Attributes: map[string]schema.Attribute{
"name": schema.StringAttribute{Optional: true},
"dimensions": schema.Int64Attribute{Optional: true},
"type": schema.StringAttribute{Optional: true},
"tokens": schema.Int64Attribute{Optional: true},
"model_location": schema.StringAttribute{Optional: true},
"url": schema.StringAttribute{Optional: true},
"trust_remote_code": schema.BoolAttribute{Optional: true},
"name": schema.StringAttribute{Optional: true},
"dimensions": schema.Int64Attribute{Optional: true},
"type": schema.StringAttribute{Optional: true},
"tokens": schema.Int64Attribute{Optional: true},
"model_location": schema.SingleNestedAttribute{
Optional: true,
Attributes: map[string]schema.Attribute{
"s3": schema.SingleNestedAttribute{
Optional: true,
Attributes: map[string]schema.Attribute{
"bucket": schema.StringAttribute{Optional: true},
"key": schema.StringAttribute{Optional: true},
},
},
"hf": schema.SingleNestedAttribute{
Optional: true,
Attributes: map[string]schema.Attribute{
"repo_id": schema.StringAttribute{Optional: true},
"filename": schema.StringAttribute{Optional: true},
},
},
"auth_required": schema.BoolAttribute{Optional: true},
},
},
"url": schema.StringAttribute{Optional: true},
"trust_remote_code": schema.BoolAttribute{Optional: true},
"is_marqtuned_model": schema.BoolAttribute{Optional: true},
},
},
"normalize_embeddings": schema.BoolAttribute{
Expand Down Expand Up @@ -336,9 +357,53 @@ func convertAllFieldsToMap(allFieldsInput []AllFieldInput) []map[string]interfac
return allFields
}

func convertModelLocationToAPI(modelLocation *ModelLocationModel) map[string]interface{} {
if modelLocation == nil {
return nil
}

result := map[string]interface{}{
"authRequired": modelLocation.AuthRequired.ValueBool(),
}

if modelLocation.S3 != nil {
result["s3"] = map[string]interface{}{
"bucket": modelLocation.S3.Bucket.ValueString(),
"key": modelLocation.S3.Key.ValueString(),
}
}

if modelLocation.Hf != nil {
result["hf"] = map[string]interface{}{
"repoId": modelLocation.Hf.RepoId.ValueString(),
"filename": modelLocation.Hf.Filename.ValueString(),
}
}

return result
}

func (r *indicesResource) findAndCreateState(indices []go_marqo.IndexDetail, indexName string) (*IndexResourceModel, bool) {
for _, indexDetail := range indices {
if indexDetail.IndexName == indexName {
modelLocation := &ModelLocationModel{
AuthRequired: types.BoolValue(indexDetail.ModelProperties.ModelLocation.AuthRequired),
}

if indexDetail.ModelProperties.ModelLocation.S3 != nil {
modelLocation.S3 = &S3LocationModel{
Bucket: types.StringValue(indexDetail.ModelProperties.ModelLocation.S3.Bucket),
Key: types.StringValue(indexDetail.ModelProperties.ModelLocation.S3.Key),
}
}

if indexDetail.ModelProperties.ModelLocation.Hf != nil {
modelLocation.Hf = &HfLocationModel{
RepoId: types.StringValue(indexDetail.ModelProperties.ModelLocation.Hf.RepoId),
Filename: types.StringValue(indexDetail.ModelProperties.ModelLocation.Hf.Filename),
}
}

return &IndexResourceModel{
//ID: types.StringValue(indexDetail.IndexName),
IndexName: types.StringValue(indexDetail.IndexName),
Expand All @@ -349,13 +414,14 @@ func (r *indicesResource) findAndCreateState(indices []go_marqo.IndexDetail, ind
TreatUrlsAndPointersAsMedia: types.BoolValue(indexDetail.TreatUrlsAndPointersAsMedia),
Model: types.StringValue(indexDetail.Model),
ModelProperties: &ModelPropertiesModelCreate{
Name: types.StringValue(indexDetail.ModelProperties.Name),
Dimensions: types.Int64Value(indexDetail.ModelProperties.Dimensions),
Type: types.StringValue(indexDetail.ModelProperties.Type),
Tokens: types.Int64Value(indexDetail.ModelProperties.Tokens),
ModelLocation: types.StringValue(indexDetail.ModelProperties.ModelLocation),
Url: types.StringValue(indexDetail.ModelProperties.Url),
TrustRemoteCode: types.BoolValue(indexDetail.ModelProperties.TrustRemoteCode),
Name: types.StringValue(indexDetail.ModelProperties.Name),
Dimensions: types.Int64Value(indexDetail.ModelProperties.Dimensions),
Type: types.StringValue(indexDetail.ModelProperties.Type),
Tokens: types.Int64Value(indexDetail.ModelProperties.Tokens),
ModelLocation: modelLocation,
Url: types.StringValue(indexDetail.ModelProperties.Url),
TrustRemoteCode: types.BoolValue(indexDetail.ModelProperties.TrustRemoteCode),
IsMarqtunedModel: types.BoolValue(indexDetail.ModelProperties.IsMarqtunedModel),
},
AllFields: ConvertMarqoAllFieldInputs(indexDetail.AllFields),
TensorFields: indexDetail.TensorFields,
Expand Down Expand Up @@ -522,15 +588,21 @@ func (r *indicesResource) Create(ctx context.Context, req resource.CreateRequest
}
// Optional dictionary fields
if model.Settings.ModelProperties != nil {
settings["modelProperties"] = map[string]interface{}{
"url": model.Settings.ModelProperties.Url.ValueString(),
"dimensions": model.Settings.ModelProperties.Dimensions.ValueInt64(),
"type": model.Settings.ModelProperties.Type.ValueString(),
"tokens": model.Settings.ModelProperties.Tokens.ValueInt64(),
"modelLocation": model.Settings.ModelProperties.ModelLocation.ValueString(),
"name": model.Settings.ModelProperties.Name.ValueString(),
"trustRemoteCode": model.Settings.ModelProperties.TrustRemoteCode.ValueBool(),
modelPropertiesMap := map[string]interface{}{
"name": model.Settings.ModelProperties.Name.ValueString(),
"dimensions": model.Settings.ModelProperties.Dimensions.ValueInt64(),
"type": model.Settings.ModelProperties.Type.ValueString(),
"tokens": model.Settings.ModelProperties.Tokens.ValueInt64(),
"url": model.Settings.ModelProperties.Url.ValueString(),
"trustRemoteCode": model.Settings.ModelProperties.TrustRemoteCode.ValueBool(),
"isMarqtunedModel": model.Settings.ModelProperties.IsMarqtunedModel.ValueBool(),
}

if model.Settings.ModelProperties.ModelLocation != nil {
modelPropertiesMap["modelLocation"] = convertModelLocationToAPI(model.Settings.ModelProperties.ModelLocation)
}

settings["modelProperties"] = modelPropertiesMap
}
if model.Settings.TextPreprocessing != nil {
settings["textPreprocessing"] = map[string]interface{}{
Expand Down Expand Up @@ -587,7 +659,7 @@ func (r *indicesResource) Create(ctx context.Context, req resource.CreateRequest
model.Settings.ModelProperties.Dimensions.IsNull() &&
model.Settings.ModelProperties.Type.IsNull() &&
model.Settings.ModelProperties.Tokens.IsNull() &&
model.Settings.ModelProperties.ModelLocation.IsNull() &&
model.Settings.ModelProperties.ModelLocation == nil &&
model.Settings.ModelProperties.Url.IsNull() &&
model.Settings.ModelProperties.TrustRemoteCode.IsNull()) {
delete(settings, "modelProperties")
Expand Down

0 comments on commit 601c1fb

Please sign in to comment.