Skip to content

Commit

Permalink
Make cog train call trainings endpoint (#2013)
Browse files Browse the repository at this point in the history
* Support different response types in _predict

* Currently _predict assumes all response types are
a PredictionResponse type.
* In the case where _predict is called by a train
endpoint this assumption is not true and we have
to use a TrainingResponse type.
* Feed in the response type to the _predict
function so it can properly deserialise training
responses.

* Make cog train CLI call /trainings

* Instead of relying on the /predictions endpoint
to be overridden directly call the /trainings
endpoint.
* Report the correct command to the user in the
event of a failure.

* Explicitly check return code in test train

* Allows better debugging of stderr et al

* Correctly set TrainingResponse type in training idempotent

* This call also requires the appropriate response
type.

* Use explicit argument index syntax in cog predict
  • Loading branch information
8W9aG authored Oct 23, 2024
1 parent a86adcd commit fe9b1f5
Show file tree
Hide file tree
Showing 5 changed files with 85 additions and 24 deletions.
14 changes: 9 additions & 5 deletions pkg/cli/predict.go
Original file line number Diff line number Diff line change
Expand Up @@ -126,7 +126,7 @@ func cmdPredict(cmd *cobra.Command, args []string) error {
Image: imageName,
Volumes: volumes,
Env: envFlags,
})
}, false)

go func() {
captureSignal := make(chan os.Signal, 1)
Expand All @@ -152,7 +152,7 @@ func cmdPredict(cmd *cobra.Command, args []string) error {
Image: imageName,
Volumes: volumes,
Env: envFlags,
})
}, false)

if err := predictor.Start(os.Stderr, timeout); err != nil {
return err
Expand All @@ -170,14 +170,14 @@ func cmdPredict(cmd *cobra.Command, args []string) error {
}
}()

return predictIndividualInputs(predictor, inputFlags, outPath)
return predictIndividualInputs(predictor, inputFlags, outPath, false)
}

func isURI(ref *openapi3.Schema) bool {
return ref != nil && ref.Type.Is("string") && ref.Format == "uri"
}

func predictIndividualInputs(predictor predict.Predictor, inputFlags []string, outputPath string) error {
func predictIndividualInputs(predictor predict.Predictor, inputFlags []string, outputPath string, isTrain bool) error {
console.Info("Running prediction...")
schema, err := predictor.GetSchema()
if err != nil {
Expand All @@ -200,7 +200,11 @@ func predictIndividualInputs(predictor predict.Predictor, inputFlags []string, o
}

// Generate output depending on type in schema
responseSchema := schema.Paths.Value("/predictions").Post.Responses.Value("200").Value.Content["application/json"].Schema.Value
url := "/predictions"
if isTrain {
url = "/trainings"
}
responseSchema := schema.Paths.Value(url).Post.Responses.Value("200").Value.Content["application/json"].Schema.Value
outputSchema := responseSchema.Properties["output"].Value

prediction, err := predictor.Predict(inputs)
Expand Down
4 changes: 2 additions & 2 deletions pkg/cli/train.go
Original file line number Diff line number Diff line change
Expand Up @@ -108,7 +108,7 @@ func cmdTrain(cmd *cobra.Command, args []string) error {
Volumes: volumes,
Env: trainEnvFlags,
Args: []string{"python", "-m", "cog.server.http", "--x-mode", "train"},
})
}, true)

go func() {
captureSignal := make(chan os.Signal, 1)
Expand All @@ -134,5 +134,5 @@ func cmdTrain(cmd *cobra.Command, args []string) error {
}
}()

return predictIndividualInputs(predictor, trainInputFlags, trainOutPath)
return predictIndividualInputs(predictor, trainInputFlags, trainOutPath, true)
}
42 changes: 30 additions & 12 deletions pkg/predict/predictor.go
Original file line number Diff line number Diff line change
Expand Up @@ -43,19 +43,20 @@ type ValidationErrorResponse struct {

type Predictor struct {
runOptions docker.RunOptions
isTrain bool

// Running state
containerID string
port int
}

func NewPredictor(runOptions docker.RunOptions) Predictor {
func NewPredictor(runOptions docker.RunOptions, isTrain bool) Predictor {
if global.Debug {
runOptions.Env = append(runOptions.Env, "COG_LOG_LEVEL=debug")
} else {
runOptions.Env = append(runOptions.Env, "COG_LOG_LEVEL=warning")
}
return Predictor{runOptions: runOptions}
return Predictor{runOptions: runOptions, isTrain: isTrain}
}

func (p *Predictor) Start(logsWriter io.Writer, timeout time.Duration) error {
Expand Down Expand Up @@ -146,7 +147,7 @@ func (p *Predictor) Predict(inputs Inputs) (*Response, error) {
return nil, err
}

url := fmt.Sprintf("http://localhost:%d/predictions", p.port)
url := p.url()
req, err := http.NewRequest(http.MethodPost, url, bytes.NewBuffer(requestBody))
if err != nil {
return nil, fmt.Errorf("Failed to create HTTP request to %s: %w", url, err)
Expand All @@ -164,14 +165,14 @@ func (p *Predictor) Predict(inputs Inputs) (*Response, error) {
if resp.StatusCode == http.StatusUnprocessableEntity {
errorResponse := &ValidationErrorResponse{}
if err := json.NewDecoder(resp.Body).Decode(errorResponse); err != nil {
return nil, fmt.Errorf("/predictions call returned status 422, and the response body failed to decode: %w", err)
return nil, fmt.Errorf("/%s call returned status 422, and the response body failed to decode: %w", p.endpoint(), err)
}

return nil, buildInputValidationErrorMessage(errorResponse)
return nil, p.buildInputValidationErrorMessage(errorResponse)
}

if resp.StatusCode != http.StatusOK {
return nil, fmt.Errorf("/predictions call returned status %d", resp.StatusCode)
return nil, fmt.Errorf("/%s call returned status %d", p.endpoint(), resp.StatusCode)
}

prediction := &Response{}
Expand All @@ -197,30 +198,47 @@ func (p *Predictor) GetSchema() (*openapi3.T, error) {
return openapi3.NewLoader().LoadFromData(body)
}

func buildInputValidationErrorMessage(errorResponse *ValidationErrorResponse) error {
func (p *Predictor) endpoint() string {
if p.isTrain {
return "trainings"
}
return "predictions"
}

func (p *Predictor) url() string {
return fmt.Sprintf("http://localhost:%d/%s", p.port, p.endpoint())
}

func (p *Predictor) buildInputValidationErrorMessage(errorResponse *ValidationErrorResponse) error {
errorMessages := []string{}

for _, validationError := range errorResponse.Detail {
if len(validationError.Location) != 3 || validationError.Location[0] != "body" || validationError.Location[1] != "input" {
responseBody, _ := json.MarshalIndent(errorResponse, "", "\t")
return fmt.Errorf("/predictions call returned status 422, and there was an unexpected message in response:\n\n%s", responseBody)
return fmt.Errorf("/%s call returned status 422, and there was an unexpected message in response:\n\n%s", p.endpoint(), responseBody)
}

errorMessages = append(errorMessages, fmt.Sprintf("- %s: %s", validationError.Location[2], validationError.Message))
}

command := "predict"
if p.isTrain {
command = "train"
}

return fmt.Errorf(
`The inputs you passed to cog predict could not be validated:
`The inputs you passed to cog %[1]s could not be validated:
%s
%[2]s
You can provide an input with -i. For example:
cog predict -i blur=3.5
cog %[1]s -i blur=3.5
If your input is a local file, you need to prefix the path with @ to tell Cog to read the file contents. For example:
cog predict -i path=@image.jpg`,
cog %[1]s -i path=@image.jpg`,
command,
strings.Join(errorMessages, "\n"),
)
}
46 changes: 42 additions & 4 deletions python/cog/server/http.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
import traceback
from datetime import datetime, timezone
from enum import Enum, auto, unique
from typing import TYPE_CHECKING, Any, Awaitable, Callable, Dict, Optional
from typing import TYPE_CHECKING, Any, Awaitable, Callable, Dict, Optional, Type

import structlog
import uvicorn
Expand Down Expand Up @@ -218,8 +218,14 @@ def train(
default=None, include_in_schema=False
),
) -> Any: # type: ignore
respond_async = prefer == "respond-async"

with trace_context(make_trace_context(traceparent, tracestate)):
return predict(request, prefer)
return _predict(
request=request,
response_type=TrainingResponse,
respond_async=respond_async,
)

@app.put(
"/trainings/{training_id}",
Expand All @@ -237,8 +243,37 @@ def train_idempotent(
default=None, include_in_schema=False
),
) -> Any:
if request.id is not None and request.id != training_id:
body = {
"loc": ("body", "id"),
"msg": "training ID must match the ID supplied in the URL",
"type": "value_error",
}
raise HTTPException(422, [body])

# We've already checked that the IDs match, now ensure that an ID is
# set on the prediction object
request.id = training_id

# If the prediction service is already running a prediction with a
# matching ID, return its current state.
if runner.is_busy():
task = runner.get_predict_task(request.id)
if task:
return JSONResponse(
jsonable_encoder(task.result),
status_code=202,
)

# TODO: spec-compliant parsing of Prefer header.
respond_async = prefer == "respond-async"

with trace_context(make_trace_context(traceparent, tracestate)):
return predict_idempotent(training_id, request, prefer)
return _predict(
request=request,
response_type=TrainingResponse,
respond_async=respond_async,
)

@app.post("/trainings/{training_id}/cancel")
def cancel_training(
Expand Down Expand Up @@ -311,6 +346,7 @@ async def predict(
with trace_context(make_trace_context(traceparent, tracestate)):
return _predict(
request=request,
response_type=PredictionResponse,
respond_async=respond_async,
)

Expand Down Expand Up @@ -358,12 +394,14 @@ async def predict_idempotent(
with trace_context(make_trace_context(traceparent, tracestate)):
return _predict(
request=request,
response_type=PredictionResponse,
respond_async=respond_async,
)

def _predict(
*,
request: Optional[PredictionRequest],
response_type: Type[schema.PredictionResponse],
respond_async: bool = False,
) -> Response:
# [compat] If no body is supplied, assume that this model can be run
Expand Down Expand Up @@ -412,7 +450,7 @@ def _predict(
else:
response_object = predict_task.result.dict()
try:
_ = PredictionResponse(**response_object)
_ = response_type(**response_object)
except ValidationError as e:
_log_invalid_output(e)
raise HTTPException(status_code=500, detail=str(e)) from e
Expand Down
3 changes: 2 additions & 1 deletion test-integration/test_integration/test_train.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,9 +11,10 @@ def test_train_takes_input_and_produces_weights(tmpdir_factory):
result = subprocess.run(
["cog", "train", "--debug", "-i", "n=42"],
cwd=out_dir,
check=True,
check=False,
capture_output=True,
)
assert result.returncode == 0
assert result.stdout == b""
with open(out_dir / "weights.bin", "rb") as f:
assert len(f.read()) == 42
Expand Down

0 comments on commit fe9b1f5

Please sign in to comment.