From e27bc5728aa7c4d8145ec46164e899b3a10669ec Mon Sep 17 00:00:00 2001 From: Dhiraj Bokde Date: Fri, 24 Jan 2025 13:06:03 -0800 Subject: [PATCH] feat: add endpoint to list all models across catalog sources Signed-off-by: Dhiraj Bokde --- api/openapi/model-registry.yaml | 30 ++- internal/server/openapi/api.go | 1 + .../openapi/api_model_catalog_service.go | 24 +++ .../api_model_catalog_service_service.go | 22 ++ pkg/openapi/api_model_catalog_service.go | 203 ++++++++++++++++++ 5 files changed, 278 insertions(+), 2 deletions(-) diff --git a/api/openapi/model-registry.yaml b/api/openapi/model-registry.yaml index 04d599ca..d4eb9c08 100644 --- a/api/openapi/model-registry.yaml +++ b/api/openapi/model-registry.yaml @@ -100,6 +100,34 @@ paths: type: string in: path required: true + /api/model_catalog/v1alpha3/catalog_sources/models: + summary: Path used to get the list of catalog models from all catalog sources. + description: >- + The REST endpoint/path used to list zero or more `CatalogModel` entities from all `CatalogSources`. + get: + tags: + - ModelCatalogService + parameters: + - $ref: "#/components/parameters/name" + - $ref: "#/components/parameters/externalId" + - $ref: "#/components/parameters/pageSize" + - $ref: "#/components/parameters/orderBy" + - $ref: "#/components/parameters/sortOrder" + - $ref: "#/components/parameters/offset" + responses: + "200": + $ref: "#/components/responses/CatalogModelListResponse" + "400": + $ref: "#/components/responses/BadRequest" + "401": + $ref: "#/components/responses/Unauthorized" + "404": + $ref: "#/components/responses/NotFound" + "500": + $ref: "#/components/responses/InternalServerError" + operationId: getAllCatalogModels + summary: List All CatalogModels from All CatalogSources + description: Gets a list of all `CatalogModel` entities. /api/model_catalog/v1alpha3/catalog_sources/{id}/models/{model_id}: summary: Path used to get a single CatalogModel. description: >- @@ -1249,8 +1277,6 @@ components: type within a Model Registry instance and cannot be changed once set. type: string artifact: - description: |- - The catalog model artifact. $ref: "#/components/schemas/ModelArtifact" ArtifactState: description: |2- diff --git a/internal/server/openapi/api.go b/internal/server/openapi/api.go index 6a703431..65fadbb5 100644 --- a/internal/server/openapi/api.go +++ b/internal/server/openapi/api.go @@ -39,6 +39,7 @@ type ModelCatalogServiceAPIServicer interface { GetCatalogModels(context.Context, string, string, string, string, model.OrderByField, model.SortOrder, string) (ImplResponse, error) GetCatalogSource(context.Context, string) (ImplResponse, error) GetCatalogSources(context.Context, string, string, model.OrderByField, model.SortOrder, string) (ImplResponse, error) + GetAllCatalogModels(context.Context, string, string, string, model.OrderByField, model.SortOrder, string) (ImplResponse, error) } // ModelRegistryServiceAPIRouter defines the required methods for binding the api requests to a responses for the ModelRegistryServiceAPI diff --git a/internal/server/openapi/api_model_catalog_service.go b/internal/server/openapi/api_model_catalog_service.go index cfe89db4..713775c3 100644 --- a/internal/server/openapi/api_model_catalog_service.go +++ b/internal/server/openapi/api_model_catalog_service.go @@ -51,6 +51,11 @@ func NewModelCatalogServiceAPIController(s ModelCatalogServiceAPIServicer, opts // Routes returns all the api routes for the ModelCatalogServiceAPIController func (c *ModelCatalogServiceAPIController) Routes() Routes { return Routes{ + "GetAllCatalogModels": Route{ + strings.ToUpper("Get"), + "/api/model_catalog/v1alpha3/catalog_sources/models", + c.GetAllCatalogModels, + }, "GetCatalogModel": Route{ strings.ToUpper("Get"), "/api/model_catalog/v1alpha3/catalog_sources/{id}/models/{model_id}", @@ -84,6 +89,25 @@ func (c *ModelCatalogServiceAPIController) Routes() Routes { } } +// GetAllCatalogModels - List All CatalogModels from CatalogSources +func (c *ModelCatalogServiceAPIController) GetAllCatalogModels(w http.ResponseWriter, r *http.Request) { + query := r.URL.Query() + nameParam := query.Get("name") + externalIdParam := query.Get("externalId") + pageSizeParam := query.Get("pageSize") + orderByParam := query.Get("orderBy") + sortOrderParam := query.Get("sortOrder") + offsetParam := query.Get("offset") + result, err := c.service.GetAllCatalogModels(r.Context(), nameParam, externalIdParam, pageSizeParam, model.OrderByField(orderByParam), model.SortOrder(sortOrderParam), offsetParam) + // If an error occurred, encode the error with the status code + if err != nil { + c.errorHandler(w, r, err, &result) + return + } + // If no error, encode the body and the result code + EncodeJSONResponse(result.Body, &result.Code, w) +} + // GetCatalogModel - Get a CatalogModel func (c *ModelCatalogServiceAPIController) GetCatalogModel(w http.ResponseWriter, r *http.Request) { idParam := chi.URLParam(r, "id") diff --git a/internal/server/openapi/api_model_catalog_service_service.go b/internal/server/openapi/api_model_catalog_service_service.go index 1deab2be..f63fa8dd 100644 --- a/internal/server/openapi/api_model_catalog_service_service.go +++ b/internal/server/openapi/api_model_catalog_service_service.go @@ -24,6 +24,26 @@ type ModelCatalogServiceAPIService struct { modelCatalogs map[string]catalog.ModelCatalogApi } +func (m ModelCatalogServiceAPIService) GetAllCatalogModels(ctx context.Context, name string, externalId string, pageSize string, orderBy openapi.OrderByField, sortOrder openapi.SortOrder, offset string) (ImplResponse, error) { + var lastError error + var allModels openapi.CatalogModelList + for _, modelCatalog := range m.modelCatalogs { + models, err := modelCatalog.GetCatalogModels(ctx, name, externalId, pageSize, orderBy, sortOrder, offset) + if err != nil { + lastError = err + } + allModels.Items = append(allModels.Items, models.Items...) + allModels.PageSize = models.PageSize + allModels.Size += models.Size + } + if lastError != nil && allModels.Size == 0 { + // only return an error if there are no models from any catalogs + // NOTE: catalog access errors are silently ignored if at least one catalog is functioning + return Response(http.StatusInternalServerError, lastError), lastError + } + return Response(http.StatusOK, allModels), lastError +} + func (m ModelCatalogServiceAPIService) GetCatalogModel(ctx context.Context, id string, modelId string) (ImplResponse, error) { catalog, ok := m.modelCatalogs[id] if !ok { @@ -104,6 +124,8 @@ func missingCatalogError(id string) (ImplResponse, error) { return ErrorResponse(http.StatusNotFound, err), err } +var _ ModelCatalogServiceAPIServicer = &ModelCatalogServiceAPIService{} + // NewModelCatalogServiceAPIService creates a default api service func NewModelCatalogServiceAPIService(modelCatalogs map[string]catalog.ModelCatalogApi) ModelCatalogServiceAPIServicer { return &ModelCatalogServiceAPIService{ diff --git a/pkg/openapi/api_model_catalog_service.go b/pkg/openapi/api_model_catalog_service.go index c4731557..499e42a9 100644 --- a/pkg/openapi/api_model_catalog_service.go +++ b/pkg/openapi/api_model_catalog_service.go @@ -22,6 +22,209 @@ import ( // ModelCatalogServiceAPIService ModelCatalogServiceAPI service type ModelCatalogServiceAPIService service +type ApiGetAllCatalogModelsRequest struct { + ctx context.Context + ApiService *ModelCatalogServiceAPIService + name *string + externalId *string + pageSize *string + orderBy *OrderByField + sortOrder *SortOrder + offset *string +} + +// Name of entity to search. +func (r ApiGetAllCatalogModelsRequest) Name(name string) ApiGetAllCatalogModelsRequest { + r.name = &name + return r +} + +// External ID of entity to search. +func (r ApiGetAllCatalogModelsRequest) ExternalId(externalId string) ApiGetAllCatalogModelsRequest { + r.externalId = &externalId + return r +} + +// Number of entities in each page. +func (r ApiGetAllCatalogModelsRequest) PageSize(pageSize string) ApiGetAllCatalogModelsRequest { + r.pageSize = &pageSize + return r +} + +// Specifies the order by criteria for listing entities. +func (r ApiGetAllCatalogModelsRequest) OrderBy(orderBy OrderByField) ApiGetAllCatalogModelsRequest { + r.orderBy = &orderBy + return r +} + +// Specifies the sort order for listing entities, defaults to ASC. +func (r ApiGetAllCatalogModelsRequest) SortOrder(sortOrder SortOrder) ApiGetAllCatalogModelsRequest { + r.sortOrder = &sortOrder + return r +} + +// Number of entities to skip before page. +func (r ApiGetAllCatalogModelsRequest) Offset(offset string) ApiGetAllCatalogModelsRequest { + r.offset = &offset + return r +} + +func (r ApiGetAllCatalogModelsRequest) Execute() (*CatalogModelList, *http.Response, error) { + return r.ApiService.GetAllCatalogModelsExecute(r) +} + +/* +GetAllCatalogModels List All CatalogModels from CatalogSources + +Gets a list of all `CatalogModel` entities. + + @param ctx context.Context - for authentication, logging, cancellation, deadlines, tracing, etc. Passed from http.Request or context.Background(). + @return ApiGetAllCatalogModelsRequest +*/ +func (a *ModelCatalogServiceAPIService) GetAllCatalogModels(ctx context.Context) ApiGetAllCatalogModelsRequest { + return ApiGetAllCatalogModelsRequest{ + ApiService: a, + ctx: ctx, + } +} + +// Execute executes the request +// +// @return CatalogModelList +func (a *ModelCatalogServiceAPIService) GetAllCatalogModelsExecute(r ApiGetAllCatalogModelsRequest) (*CatalogModelList, *http.Response, error) { + var ( + localVarHTTPMethod = http.MethodGet + localVarPostBody interface{} + formFiles []formFile + localVarReturnValue *CatalogModelList + ) + + localBasePath, err := a.client.cfg.ServerURLWithContext(r.ctx, "ModelCatalogServiceAPIService.GetAllCatalogModels") + if err != nil { + return localVarReturnValue, nil, &GenericOpenAPIError{error: err.Error()} + } + + localVarPath := localBasePath + "/api/model_catalog/v1alpha3/catalog_sources/models" + + localVarHeaderParams := make(map[string]string) + localVarQueryParams := url.Values{} + localVarFormParams := url.Values{} + + if r.name != nil { + parameterAddToHeaderOrQuery(localVarQueryParams, "name", r.name, "") + } + if r.externalId != nil { + parameterAddToHeaderOrQuery(localVarQueryParams, "externalId", r.externalId, "") + } + if r.pageSize != nil { + parameterAddToHeaderOrQuery(localVarQueryParams, "pageSize", r.pageSize, "") + } + if r.orderBy != nil { + parameterAddToHeaderOrQuery(localVarQueryParams, "orderBy", r.orderBy, "") + } + if r.sortOrder != nil { + parameterAddToHeaderOrQuery(localVarQueryParams, "sortOrder", r.sortOrder, "") + } + if r.offset != nil { + parameterAddToHeaderOrQuery(localVarQueryParams, "offset", r.offset, "") + } + // to determine the Content-Type header + localVarHTTPContentTypes := []string{} + + // set Content-Type header + localVarHTTPContentType := selectHeaderContentType(localVarHTTPContentTypes) + if localVarHTTPContentType != "" { + localVarHeaderParams["Content-Type"] = localVarHTTPContentType + } + + // to determine the Accept header + localVarHTTPHeaderAccepts := []string{"application/json"} + + // set Accept header + localVarHTTPHeaderAccept := selectHeaderAccept(localVarHTTPHeaderAccepts) + if localVarHTTPHeaderAccept != "" { + localVarHeaderParams["Accept"] = localVarHTTPHeaderAccept + } + req, err := a.client.prepareRequest(r.ctx, localVarPath, localVarHTTPMethod, localVarPostBody, localVarHeaderParams, localVarQueryParams, localVarFormParams, formFiles) + if err != nil { + return localVarReturnValue, nil, err + } + + localVarHTTPResponse, err := a.client.callAPI(req) + if err != nil || localVarHTTPResponse == nil { + return localVarReturnValue, localVarHTTPResponse, err + } + + localVarBody, err := io.ReadAll(localVarHTTPResponse.Body) + localVarHTTPResponse.Body.Close() + localVarHTTPResponse.Body = io.NopCloser(bytes.NewBuffer(localVarBody)) + if err != nil { + return localVarReturnValue, localVarHTTPResponse, err + } + + if localVarHTTPResponse.StatusCode >= 300 { + newErr := &GenericOpenAPIError{ + body: localVarBody, + error: localVarHTTPResponse.Status, + } + if localVarHTTPResponse.StatusCode == 400 { + var v Error + err = a.client.decode(&v, localVarBody, localVarHTTPResponse.Header.Get("Content-Type")) + if err != nil { + newErr.error = err.Error() + return localVarReturnValue, localVarHTTPResponse, newErr + } + newErr.error = formatErrorMessage(localVarHTTPResponse.Status, &v) + newErr.model = v + return localVarReturnValue, localVarHTTPResponse, newErr + } + if localVarHTTPResponse.StatusCode == 401 { + var v Error + err = a.client.decode(&v, localVarBody, localVarHTTPResponse.Header.Get("Content-Type")) + if err != nil { + newErr.error = err.Error() + return localVarReturnValue, localVarHTTPResponse, newErr + } + newErr.error = formatErrorMessage(localVarHTTPResponse.Status, &v) + newErr.model = v + return localVarReturnValue, localVarHTTPResponse, newErr + } + if localVarHTTPResponse.StatusCode == 404 { + var v Error + err = a.client.decode(&v, localVarBody, localVarHTTPResponse.Header.Get("Content-Type")) + if err != nil { + newErr.error = err.Error() + return localVarReturnValue, localVarHTTPResponse, newErr + } + newErr.error = formatErrorMessage(localVarHTTPResponse.Status, &v) + newErr.model = v + return localVarReturnValue, localVarHTTPResponse, newErr + } + if localVarHTTPResponse.StatusCode == 500 { + var v Error + err = a.client.decode(&v, localVarBody, localVarHTTPResponse.Header.Get("Content-Type")) + if err != nil { + newErr.error = err.Error() + return localVarReturnValue, localVarHTTPResponse, newErr + } + newErr.error = formatErrorMessage(localVarHTTPResponse.Status, &v) + newErr.model = v + } + return localVarReturnValue, localVarHTTPResponse, newErr + } + + err = a.client.decode(&localVarReturnValue, localVarBody, localVarHTTPResponse.Header.Get("Content-Type")) + if err != nil { + newErr := &GenericOpenAPIError{ + body: localVarBody, + error: err.Error(), + } + return localVarReturnValue, localVarHTTPResponse, newErr + } + + return localVarReturnValue, localVarHTTPResponse, nil +} + type ApiGetCatalogModelRequest struct { ctx context.Context ApiService *ModelCatalogServiceAPIService