From 0bfbc41b7739ec6e80e93d316ccb786bad40a84a Mon Sep 17 00:00:00 2001 From: Aron Carroll Date: Mon, 4 Nov 2024 13:39:35 +0000 Subject: [PATCH] Fix `cog predict` running file outputs This commit backports some of the changes to pkg/cli/predict.go from main as of 3e0dc79 to fix a nil pointer exception that occurs when trying to run `cog predict` on a model that outputs a `File` type. The panic occurs in predictIndividualInputs() trying to decode the data URL returned by the model. Giving the error: panic: runtime error: invalid memory address or nil pointer dereference --- go.mod | 12 ++-- go.sum | 31 ++++----- pkg/cli/predict.go | 166 +++++++++++++++++++++++++++++++-------------- 3 files changed, 134 insertions(+), 75 deletions(-) diff --git a/go.mod b/go.mod index 5167988584..8456453cbf 100644 --- a/go.mod +++ b/go.mod @@ -6,7 +6,7 @@ require ( github.com/anaskhan96/soup v1.2.5 github.com/docker/cli v24.0.6+incompatible github.com/docker/docker v24.0.6+incompatible - github.com/getkin/kin-openapi v0.120.0 + github.com/getkin/kin-openapi v0.127.0 github.com/golangci/golangci-lint v1.55.2 github.com/hashicorp/go-version v1.6.0 github.com/logrusorgru/aurora v2.0.3+incompatible @@ -15,7 +15,7 @@ require ( github.com/moby/term v0.5.0 github.com/spf13/cobra v1.8.0 github.com/spf13/pflag v1.0.5 - github.com/stretchr/testify v1.8.4 + github.com/stretchr/testify v1.9.0 github.com/vincent-petithory/dataurl v1.0.0 github.com/xeipuuv/gojsonschema v1.2.0 github.com/xeonx/timeago v1.0.0-rc5 @@ -77,8 +77,8 @@ require ( github.com/fzipp/gocyclo v0.6.0 // indirect github.com/ghostiam/protogetter v0.2.3 // indirect github.com/go-critic/go-critic v0.9.0 // indirect - github.com/go-openapi/jsonpointer v0.19.6 // indirect - github.com/go-openapi/swag v0.22.4 // indirect + github.com/go-openapi/jsonpointer v0.21.0 // indirect + github.com/go-openapi/swag v0.23.0 // indirect github.com/go-toolsmith/astcast v1.1.0 // indirect github.com/go-toolsmith/astcopy v1.1.0 // indirect github.com/go-toolsmith/astequal v1.1.0 // indirect @@ -112,7 +112,7 @@ require ( github.com/hashicorp/hcl v1.0.0 // indirect github.com/hexops/gotextdiff v1.0.3 // indirect github.com/inconshreveable/mousetrap v1.1.0 // indirect - github.com/invopop/yaml v0.2.0 // indirect + github.com/invopop/yaml v0.3.1 // indirect github.com/jgautheron/goconst v1.6.0 // indirect github.com/jingyugao/rowserrcheck v1.1.1 // indirect github.com/jirfag/go-printf-func-name v0.0.0-20200119135958-7558a9eaa5af // indirect @@ -183,7 +183,7 @@ require ( github.com/spf13/viper v1.13.0 // indirect github.com/ssgreg/nlreturn/v2 v2.2.1 // indirect github.com/stbenjam/no-sprintf-host-port v0.1.1 // indirect - github.com/stretchr/objx v0.5.0 // indirect + github.com/stretchr/objx v0.5.2 // indirect github.com/subosito/gotenv v1.4.1 // indirect github.com/t-yuki/gocover-cobertura v0.0.0-20180217150009-aaee18c8195c // indirect github.com/tdakkota/asciicheck v0.2.0 // indirect diff --git a/go.sum b/go.sum index af15fe66db..1132bc8b67 100644 --- a/go.sum +++ b/go.sum @@ -126,7 +126,6 @@ github.com/cncf/udpa/go v0.0.0-20191209042840-269d4d468f6f/go.mod h1:M8M6+tZqaGX github.com/cncf/udpa/go v0.0.0-20200629203442-efcf912fb354/go.mod h1:WmhPx2Nbnhtbo57+VJT5O0JRkEi1Wbu0z5j0R8u5Hbk= github.com/cncf/udpa/go v0.0.0-20201120205902-5459f2c99403/go.mod h1:WmhPx2Nbnhtbo57+VJT5O0JRkEi1Wbu0z5j0R8u5Hbk= github.com/cpuguy83/go-md2man/v2 v2.0.3/go.mod h1:tgQtvFlXSQOSOSIRvRPT7W67SCa46tRHOmNcaadrF8o= -github.com/creack/pty v1.1.9/go.mod h1:oKZEueFk5CKHvIhNR5MUki03XCEU+Q6VDXinZuGJ33E= github.com/creack/pty v1.1.18 h1:n56/Zwd5o6whRC5PMGretI4IdRLlmBXYNjScPaBgsbY= github.com/curioswitch/go-reassign v0.2.0 h1:G9UZyOcpk/d7Gd6mqYgd8XYWFMw/znxwGDUstnC9DIo= github.com/curioswitch/go-reassign v0.2.0/go.mod h1:x6OpXuWvgfQaMGks2BZybTngWjT84hqJfKoO8Tt/Roc= @@ -172,8 +171,8 @@ github.com/fsnotify/fsnotify v1.7.0 h1:8JEhPFa5W2WU7YfeZzPNqzMP6Lwt7L2715Ggo0nos github.com/fsnotify/fsnotify v1.7.0/go.mod h1:40Bi/Hjc2AVfZrqy+aj+yEI+/bRxZnMJyTJwOpGvigM= github.com/fzipp/gocyclo v0.6.0 h1:lsblElZG7d3ALtGMx9fmxeTKZaLLpU8mET09yN4BBLo= github.com/fzipp/gocyclo v0.6.0/go.mod h1:rXPyn8fnlpa0R2csP/31uerbiVBugk5whMdlyaLkLoA= -github.com/getkin/kin-openapi v0.120.0 h1:MqJcNJFrMDFNc07iwE8iFC5eT2k/NPUFDIpNeiZv8Jg= -github.com/getkin/kin-openapi v0.120.0/go.mod h1:PCWw/lfBrJY4HcdqE3jj+QFkaFK8ABoqo7PvqVhXXqw= +github.com/getkin/kin-openapi v0.127.0 h1:Mghqi3Dhryf3F8vR370nN67pAERW+3a95vomb3MAREY= +github.com/getkin/kin-openapi v0.127.0/go.mod h1:OZrfXzUfGrNbsKj+xmFBx6E5c6yH3At/tAKSc2UszXM= github.com/ghostiam/protogetter v0.2.3 h1:qdv2pzo3BpLqezwqfGDLZ+nHEYmc5bUpIdsMbBVwMjw= github.com/ghostiam/protogetter v0.2.3/go.mod h1:KmNLOsy1v04hKbvZs8EfGI1fk39AgTdRDxWNYPfXVc4= github.com/go-critic/go-critic v0.9.0 h1:Pmys9qvU3pSML/3GEQ2Xd9RZ/ip+aXHKILuxczKGV/U= @@ -188,11 +187,10 @@ github.com/go-logfmt/logfmt v0.3.0/go.mod h1:Qt1PoO58o5twSAckw1HlFXLmHsOX5/0LbT9 github.com/go-logfmt/logfmt v0.4.0/go.mod h1:3RMwSq7FuexP4Kalkev3ejPJsZTpXXBr9+V4qmtdjCk= github.com/go-logfmt/logfmt v0.5.0/go.mod h1:wCYkCAKZfumFQihp8CzCvQ3paCTfi41vtzG1KdI/P7A= github.com/go-logr/logr v1.2.4 h1:g01GSCwiDw2xSZfjJ2/T9M+S6pFdcNtFYsp+Y43HYDQ= -github.com/go-openapi/jsonpointer v0.19.6 h1:eCs3fxoIi3Wh6vtgmLTOjdhSpiqphQ+DaPn38N2ZdrE= -github.com/go-openapi/jsonpointer v0.19.6/go.mod h1:osyAmYz/mB/C3I+WsTTSgw1ONzaLJoLCyoi6/zppojs= -github.com/go-openapi/swag v0.22.3/go.mod h1:UzaqsxGiab7freDnrUUra0MwWfN/q7tE4j+VcZ0yl14= -github.com/go-openapi/swag v0.22.4 h1:QLMzNJnMGPRNDCbySlcj1x01tzU8/9LTTL9hZZZogBU= -github.com/go-openapi/swag v0.22.4/go.mod h1:UzaqsxGiab7freDnrUUra0MwWfN/q7tE4j+VcZ0yl14= +github.com/go-openapi/jsonpointer v0.21.0 h1:YgdVicSA9vH5RiHs9TZW5oyafXZFc6+2Vc1rr/O9oNQ= +github.com/go-openapi/jsonpointer v0.21.0/go.mod h1:IUyH9l/+uyhIYQ/PXVA41Rexl+kOkAPDdXEYns6fzUY= +github.com/go-openapi/swag v0.23.0 h1:vsEVJDUo2hPJ2tu0/Xc+4noaxyEffXNIs3cOULZ+GrE= +github.com/go-openapi/swag v0.23.0/go.mod h1:esZ8ITTYEsH1V2trKHjAN8Ai7xHb8RV+YSZ577vPjgQ= github.com/go-stack/stack v1.8.0/go.mod h1:v0f6uXyyMGvRgIKkXu+yp6POWl0qKG85gN/melR3HDY= github.com/go-task/slim-sprig v0.0.0-20230315185526-52ccab3ef572 h1:tfuBGBXKqDEevZMzYi5KSi8KkcZtzBcTgAUUtapy0OI= github.com/go-test/deep v1.0.8 h1:TDsG77qcSprGbC6vTN8OuXp5g+J+b5Pcguhf7Zt61VM= @@ -339,8 +337,8 @@ github.com/ianlancetaylor/demangle v0.0.0-20181102032728-5e5cf60278f6/go.mod h1: github.com/ianlancetaylor/demangle v0.0.0-20200824232613-28f6c0f3b639/go.mod h1:aSSvb/t6k1mPoxDqO4vJh6VOCGPwU4O0C2/Eqndh1Sc= github.com/inconshreveable/mousetrap v1.1.0 h1:wN+x4NVGpMsO7ErUn/mUI3vEoE6Jt13X2s0bqwp9tc8= github.com/inconshreveable/mousetrap v1.1.0/go.mod h1:vpF70FUmC8bwa3OWnCshd2FqLfsEA9PFc4w1p2J65bw= -github.com/invopop/yaml v0.2.0 h1:7zky/qH+O0DwAyoobXUqvVBwgBFRxKoQ/3FjcVpjTMY= -github.com/invopop/yaml v0.2.0/go.mod h1:2XuRLgs/ouIrW3XNzuNj7J3Nvu/Dig5MXvbCEdiBN3Q= +github.com/invopop/yaml v0.3.1 h1:f0+ZpmhfBSS4MhG+4HYseMdJhoeeopbSKbq5Rpeelso= +github.com/invopop/yaml v0.3.1/go.mod h1:PMOp3nn4/12yEZUFfmOuNHJsZToEEOwoWsT+D81KkeA= github.com/jgautheron/goconst v1.6.0 h1:gbMLWKRMkzAc6kYsQL6/TxaoBUg3Jm9LSF/Ih1ADWGA= github.com/jgautheron/goconst v1.6.0/go.mod h1:aAosetZ5zaeC/2EfMeRswtxUFBpe2Hr7HzkgX4fanO4= github.com/jingyugao/rowserrcheck v1.1.1 h1:zibz55j/MJtLsjP1OF4bSdgXxwL1b+Vn7Tjzq7gFzUs= @@ -372,12 +370,10 @@ github.com/konsorten/go-windows-terminal-sequences v1.0.3/go.mod h1:T0+1ngSBFLxv github.com/kr/fs v0.1.0/go.mod h1:FFnZGqtBN9Gxj7eW1uZ42v5BccTP0vu6NEaFoC2HwRg= github.com/kr/logfmt v0.0.0-20140226030751-b84e30acd515/go.mod h1:+0opPa2QZZtGFBFZlji/RkVcI2GknAs/DXo4wKdlNEc= github.com/kr/pretty v0.1.0/go.mod h1:dAy3ld7l9f0ibDNOQOHHMYYIIbhfbHSm3C4ZsoJORNo= -github.com/kr/pretty v0.2.1/go.mod h1:ipq/a2n7PKx3OHsz4KJII5eveXtPO4qwEXGdVfWzfnI= github.com/kr/pretty v0.3.1 h1:flRD4NNwYAUpkphVc1HcthR4KEIFJ65n8Mw5qdRn3LE= github.com/kr/pty v1.1.1/go.mod h1:pFQYn66WHrOpPYNljwOMqo10TkYh1fy3cYio2l3bCsQ= github.com/kr/text v0.1.0/go.mod h1:4Jbv+DJW3UT/LiOwJeYQe1efqtUx/iVham/4vfdArNI= github.com/kr/text v0.2.0 h1:5Nx0Ya0ZqY2ygV366QzturHI13Jq95ApcVaJBhpS+AY= -github.com/kr/text v0.2.0/go.mod h1:eLer722TekiGuMkidMxC/pM04lWEeraHUUmBw8l2grE= github.com/kulti/thelper v0.6.3 h1:ElhKf+AlItIu+xGnI990no4cE2+XaSu1ULymV2Yulxs= github.com/kulti/thelper v0.6.3/go.mod h1:DsqKShOvP40epevkFrvIwkCMNYxMeTNjdWL4dqWHZ6I= github.com/kunwardeep/paralleltest v1.0.8 h1:Ul2KsqtzFxTlSU7IP0JusWlLiNqQaloB9vguyjbE558= @@ -511,8 +507,8 @@ github.com/quasilyte/stdinfo v0.0.0-20220114132959-f7386bf02567/go.mod h1:DWNGW8 github.com/rivo/uniseg v0.2.0 h1:S1pD9weZBuJdFmowNwbpi7BJ8TNftyUImj/0WQi72jY= github.com/rivo/uniseg v0.2.0/go.mod h1:J6wj4VEh+S6ZtnVlnTBMWIodfgj8LQOQFoIToxlJtxc= github.com/rogpeppe/go-internal v1.3.0/go.mod h1:M8bDsm7K2OlrFYOpmOWEs/qY81heoFRclV5y23lUDJ4= -github.com/rogpeppe/go-internal v1.11.0 h1:cWPaGQEPrBb5/AsnsZesgZZ9yb1OQ+GOISoDNXVBh4M= github.com/rogpeppe/go-internal v1.11.0/go.mod h1:ddIwULY96R17DhadqLgMfk9H9tvdUzkipdSkR5nkCZA= +github.com/rogpeppe/go-internal v1.12.0 h1:exVL4IDcn6na9z1rAb56Vxr+CgyK3nn3O+epU5NdKM8= github.com/russross/blackfriday/v2 v2.1.0/go.mod h1:+Rmxgy9KzJVeS9/2gXHxylqXiyQDYRxCVz55jmeOWTM= github.com/ryancurrah/gomodguard v1.3.0 h1:q15RT/pd6UggBXVBuLps8BXRvl5GPBcwVA7BJHMLuTw= github.com/ryancurrah/gomodguard v1.3.0/go.mod h1:ggBxb3luypPEzqVtq33ee7YSN35V28XeGnid8dnni50= @@ -564,8 +560,9 @@ github.com/stbenjam/no-sprintf-host-port v0.1.1/go.mod h1:TLhvtIvONRzdmkFiio4O8L github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= github.com/stretchr/objx v0.1.1/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= github.com/stretchr/objx v0.4.0/go.mod h1:YvHI0jy2hoMjB+UWwv71VJQ9isScKT/TqJzVSSt89Yw= -github.com/stretchr/objx v0.5.0 h1:1zr/of2m5FGMsad5YfcqgdqdWrIhu+EBEJRhR1U7z/c= github.com/stretchr/objx v0.5.0/go.mod h1:Yh+to48EsGEfYuaHDzXPcE3xhTkx73EhmCGUpEOglKo= +github.com/stretchr/objx v0.5.2 h1:xuMeJ0Sdp5ZMRXx/aWO6RZxdr3beISkG5/G/aIRr3pY= +github.com/stretchr/objx v0.5.2/go.mod h1:FRsXN1f5AsAjCGJKqEizvkpNtU+EGNCLh3NxZ/8L+MA= github.com/stretchr/testify v1.2.2/go.mod h1:a8OnRcib4nhh0OaRAV+Yts87kKdq0PP7pXfy6kDkUVs= github.com/stretchr/testify v1.3.0/go.mod h1:M5WIy9Dh21IEIfnGCwXGc5bZfKNJtfHm1UVUgZn+9EI= github.com/stretchr/testify v1.4.0/go.mod h1:j7eGeouHqKxXV5pUuKE4zz7dFj8WfuZ+81PSLYec5m4= @@ -574,9 +571,9 @@ github.com/stretchr/testify v1.6.1/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/ github.com/stretchr/testify v1.7.0/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg= github.com/stretchr/testify v1.7.1/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg= github.com/stretchr/testify v1.8.0/go.mod h1:yNjHg4UonilssWZ8iaSj1OCr/vHnekPRkoO+kdMU+MU= -github.com/stretchr/testify v1.8.1/go.mod h1:w2LPCIKwWwSfY2zedu0+kehJoqGctiVI29o6fzry7u4= -github.com/stretchr/testify v1.8.4 h1:CcVxjf3Q8PM0mHUKJCdn+eZZtm5yQwehR5yeSVQQcUk= github.com/stretchr/testify v1.8.4/go.mod h1:sz/lmYIOXD/1dqDmKjjqLyZ2RngseejIcXlSw2iwfAo= +github.com/stretchr/testify v1.9.0 h1:HtqpIVDClZ4nwg75+f6Lvsy/wHu+3BoSGCbBAcpTsTg= +github.com/stretchr/testify v1.9.0/go.mod h1:r2ic/lqez/lEtzL7wO/rwa5dbSLXVDPFyf8C91i36aY= github.com/subosito/gotenv v1.4.1 h1:jyEFiXpy21Wm81FBN71l9VoMMV8H8jG+qIK3GCpY6Qs= github.com/subosito/gotenv v1.4.1/go.mod h1:ayKnFf/c6rvx/2iiLrJUk1e6plDbT3edrFNGqEflhK0= github.com/t-yuki/gocover-cobertura v0.0.0-20180217150009-aaee18c8195c h1:+aPplBwWcHBo6q9xrfWdMrT9o4kltkmmvpemgIjep/8= @@ -1060,7 +1057,6 @@ gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8 gopkg.in/check.v1 v1.0.0-20180628173108-788fd7840127/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= gopkg.in/check.v1 v1.0.0-20190902080502-41f04d3bba15/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c h1:Hei/4ADfdWqJk1ZMxUNpqntNwaWcugrBjAiHlqqRiVk= -gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c/go.mod h1:JHkPIbrfpd72SG/EVd6muEfDQjcINNoR0C8j2r3qZ4Q= gopkg.in/errgo.v2 v2.1.0/go.mod h1:hNsd1EY+bozCKY1Ytp96fpM3vjJbqLJn88ws8XvfDNI= gopkg.in/ini.v1 v1.67.0 h1:Dgnx+6+nfE+IfzjUEISNeydPJh9AXNNsWbGP9KzCsOA= gopkg.in/ini.v1 v1.67.0/go.mod h1:pNLf8WUiyNEtQjuu5G5vTm06TEv9tsIgeAvK8hOrP4k= @@ -1072,7 +1068,6 @@ gopkg.in/yaml.v2 v2.3.0/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI= gopkg.in/yaml.v2 v2.4.0 h1:D8xgwECY7CYvx+Y2n4sBz93Jn9JRvxdiyyo8CTfuKaY= gopkg.in/yaml.v2 v2.4.0/go.mod h1:RDklbk79AGWmwhnvt/jBztapEOGDOx6ZbXqjP6csGnQ= gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= -gopkg.in/yaml.v3 v3.0.0/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA= gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= gotest.tools/gotestsum v1.12.0 h1:CmwtaGDkHxrZm4Ib0Vob89MTfpc3GrEFMJKovliPwGk= diff --git a/pkg/cli/predict.go b/pkg/cli/predict.go index f079083fe3..086ae05412 100644 --- a/pkg/cli/predict.go +++ b/pkg/cli/predict.go @@ -7,6 +7,7 @@ import ( "fmt" "os" "os/signal" + "path/filepath" "strings" "syscall" @@ -14,6 +15,7 @@ import ( "github.com/mitchellh/go-homedir" "github.com/spf13/cobra" "github.com/vincent-petithory/dataurl" + "golang.org/x/sys/unix" "github.com/replicate/cog/pkg/config" "github.com/replicate/cog/pkg/docker" @@ -165,6 +167,10 @@ func cmdPredict(cmd *cobra.Command, args []string) error { return predictIndividualInputs(predictor, inputFlags, outPath) } +func isURI(ref *openapi3.Schema) bool { + return ref != nil && ref.Type.Is("string") && ref.Format == "uri" +} + func predictIndividualInputs(predictor predict.Predictor, inputFlags []string, outputPath string) error { console.Info("Running prediction...") schema, err := predictor.GetSchema() @@ -177,39 +183,87 @@ func predictIndividualInputs(predictor predict.Predictor, inputFlags []string, o return err } - prediction, err := predictor.Predict(inputs) - if err != nil { - return err + // If outputPath != "", then we now know the output path for sure + if outputPath != "" { + // Ignore @, to make it behave the same as -i + outputPath = strings.TrimPrefix(outputPath, "@") + + if err := checkOutputWritable(outputPath); err != nil { + return fmt.Errorf("Output path is not writable: %w", err) + } } // Generate output depending on type in schema - var out []byte - responseSchema := schema.Paths["/predictions"].Post.Responses["200"].Value.Content["application/json"].Schema.Value + url := "/predictions" + responseSchema := schema.Paths.Value(url).Post.Responses.Value("200").Value.Content["application/json"].Schema.Value outputSchema := responseSchema.Properties["output"].Value - // Multiple outputs! - if outputSchema.Type == "array" && outputSchema.Items.Value != nil && outputSchema.Items.Value.Type == "string" && outputSchema.Items.Value.Format == "uri" { - return handleMultipleFileOutput(prediction, outputSchema) + prediction, err := predictor.Predict(inputs) + if err != nil { + return fmt.Errorf("Failed to predict: %w", err) } - if outputSchema.Type == "string" && outputSchema.Format == "uri" { - dataurlObj, err := dataurl.DecodeString((*prediction.Output).(string)) - if err != nil { - return fmt.Errorf("Failed to decode dataurl: %w", err) - } - out = dataurlObj.Data + if prediction.Output == nil { + console.Warn("No output generated") + return nil + } + + switch { + case isURI(outputSchema): + addExtension := false if outputPath == "" { outputPath = "output" - extension := mime.ExtensionByType(dataurlObj.ContentType()) - if extension != "" { - outputPath += extension + addExtension = true + } + + outputStr, ok := (*prediction.Output).(string) + if !ok { + return fmt.Errorf("Failed to convert prediction output to string") + } + + if err := writeDataURLOutput(outputStr, outputPath, addExtension); err != nil { + return fmt.Errorf("Failed to write output: %w", err) + } + + return nil + case outputSchema.Type.Is("array") && isURI(outputSchema.Items.Value): + outputs, ok := (*prediction.Output).([]interface{}) + if !ok { + return fmt.Errorf("Failed to decode output") + } + + for i, output := range outputs { + outputPath := fmt.Sprintf("output.%d", i) + addExtension := true + + outputStr, ok := output.(string) + if !ok { + return fmt.Errorf("Failed to convert prediction output to string") + } + + if err := writeDataURLOutput(outputStr, outputPath, addExtension); err != nil { + return fmt.Errorf("Failed to write output %d: %w", i, err) } } - } else if outputSchema.Type == "string" { - // Handle strings separately because if we encode it to JSON it will be surrounded by quotes. - s := (*prediction.Output).(string) - out = []byte(s) - } else { + + return nil + case outputSchema.Type.Is("string"): + s, ok := (*prediction.Output).(string) + if !ok { + return fmt.Errorf("Failed to convert prediction output to string") + } + + if outputPath == "" { + console.Output(s) + } else { + err := writeOutput(outputPath, []byte(s)) + if err != nil { + return fmt.Errorf("Failed to write output: %w", err) + } + } + + return nil + default: // Treat everything else as JSON -- ints, floats, bools will all convert correctly. rawJSON, err := json.Marshal(prediction.Output) if err != nil { @@ -219,27 +273,39 @@ func predictIndividualInputs(predictor predict.Predictor, inputFlags []string, o if err := json.Indent(&indentedJSON, rawJSON, "", " "); err != nil { return err } - out = indentedJSON.Bytes() - // FIXME: this stopped working - // f := colorjson.NewFormatter() - // f.Indent = 2 - // s, _ := f.Marshal(obj) - - } + if outputPath == "" { + console.Output(indentedJSON.String()) + } else { + err := writeOutput(outputPath, indentedJSON.Bytes()) + if err != nil { + return fmt.Errorf("Failed to write output: %w", err) + } + } - // Write to stdout - if outputPath == "" { - console.Output(string(out)) return nil } +} - // Fall back to writing file +func checkOutputWritable(outputPath string) error { + outputPath, err := homedir.Expand(outputPath) + if err != nil { + return err + } - // Ignore @, to make it behave the same as -i - outputPath = strings.TrimPrefix(outputPath, "@") + // Check if the file exists + _, err = os.Stat(outputPath) + if err == nil { + // File exists, check if it's writable + return unix.Access(outputPath, unix.W_OK) + } else if os.IsNotExist(err) { + // File doesn't exist, check if the directory is writable + dir := filepath.Dir(outputPath) + return unix.Access(dir, unix.W_OK) + } - return writeOutput(outputPath, out) + // Some other error occurred + return err } func writeOutput(outputPath string, output []byte) error { @@ -249,7 +315,7 @@ func writeOutput(outputPath string, output []byte) error { } // Write to file - outFile, err := os.OpenFile(outputPath, os.O_WRONLY|os.O_CREATE, 0o755) + outFile, err := os.OpenFile(outputPath, os.O_WRONLY|os.O_CREATE|os.O_TRUNC, 0o755) if err != nil { return err } @@ -264,26 +330,24 @@ func writeOutput(outputPath string, output []byte) error { return nil } -func handleMultipleFileOutput(prediction *predict.Response, outputSchema *openapi3.Schema) error { - outputs, ok := (*prediction.Output).([]interface{}) - if !ok { - return fmt.Errorf("Failed to decode output") +func writeDataURLOutput(outputString string, outputPath string, addExtension bool) error { + dataurlObj, err := dataurl.DecodeString(outputString) + if err != nil { + return fmt.Errorf("Failed to decode dataurl: %w", err) } + output := dataurlObj.Data - for i, output := range outputs { - outputString := output.(string) - dataurlObj, err := dataurl.DecodeString(outputString) - if err != nil { - return fmt.Errorf("Failed to decode dataurl: %w", err) - } - out := dataurlObj.Data + if addExtension { extension := mime.ExtensionByType(dataurlObj.ContentType()) - outputPath := fmt.Sprintf("output.%d%s", i, extension) - if err := writeOutput(outputPath, out); err != nil { - return err + if extension != "" { + outputPath += extension } } + if err := writeOutput(outputPath, output); err != nil { + return err + } + return nil }