Skip to content

Commit

Permalink
Fix cog predict running file outputs
Browse files Browse the repository at this point in the history
This commit backports some of the changes to pkg/cli/predict.go from
main as of 3e0dc79 to fix a nil pointer exception that occurs when
trying to run `cog predict` on a model that outputs a `File` type.

The panic occurs in predictIndividualInputs() trying to decode the
data URL returned by the model. Giving the error:

    panic: runtime error: invalid memory address or nil pointer dereference
  • Loading branch information
aron committed Nov 4, 2024
1 parent ffc097d commit 0bfbc41
Show file tree
Hide file tree
Showing 3 changed files with 134 additions and 75 deletions.
12 changes: 6 additions & 6 deletions go.mod
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@ require (
github.com/anaskhan96/soup v1.2.5
github.com/docker/cli v24.0.6+incompatible
github.com/docker/docker v24.0.6+incompatible
github.com/getkin/kin-openapi v0.120.0
github.com/getkin/kin-openapi v0.127.0
github.com/golangci/golangci-lint v1.55.2
github.com/hashicorp/go-version v1.6.0
github.com/logrusorgru/aurora v2.0.3+incompatible
Expand All @@ -15,7 +15,7 @@ require (
github.com/moby/term v0.5.0
github.com/spf13/cobra v1.8.0
github.com/spf13/pflag v1.0.5
github.com/stretchr/testify v1.8.4
github.com/stretchr/testify v1.9.0
github.com/vincent-petithory/dataurl v1.0.0
github.com/xeipuuv/gojsonschema v1.2.0
github.com/xeonx/timeago v1.0.0-rc5
Expand Down Expand Up @@ -77,8 +77,8 @@ require (
github.com/fzipp/gocyclo v0.6.0 // indirect
github.com/ghostiam/protogetter v0.2.3 // indirect
github.com/go-critic/go-critic v0.9.0 // indirect
github.com/go-openapi/jsonpointer v0.19.6 // indirect
github.com/go-openapi/swag v0.22.4 // indirect
github.com/go-openapi/jsonpointer v0.21.0 // indirect
github.com/go-openapi/swag v0.23.0 // indirect
github.com/go-toolsmith/astcast v1.1.0 // indirect
github.com/go-toolsmith/astcopy v1.1.0 // indirect
github.com/go-toolsmith/astequal v1.1.0 // indirect
Expand Down Expand Up @@ -112,7 +112,7 @@ require (
github.com/hashicorp/hcl v1.0.0 // indirect
github.com/hexops/gotextdiff v1.0.3 // indirect
github.com/inconshreveable/mousetrap v1.1.0 // indirect
github.com/invopop/yaml v0.2.0 // indirect
github.com/invopop/yaml v0.3.1 // indirect
github.com/jgautheron/goconst v1.6.0 // indirect
github.com/jingyugao/rowserrcheck v1.1.1 // indirect
github.com/jirfag/go-printf-func-name v0.0.0-20200119135958-7558a9eaa5af // indirect
Expand Down Expand Up @@ -183,7 +183,7 @@ require (
github.com/spf13/viper v1.13.0 // indirect
github.com/ssgreg/nlreturn/v2 v2.2.1 // indirect
github.com/stbenjam/no-sprintf-host-port v0.1.1 // indirect
github.com/stretchr/objx v0.5.0 // indirect
github.com/stretchr/objx v0.5.2 // indirect
github.com/subosito/gotenv v1.4.1 // indirect
github.com/t-yuki/gocover-cobertura v0.0.0-20180217150009-aaee18c8195c // indirect
github.com/tdakkota/asciicheck v0.2.0 // indirect
Expand Down
31 changes: 13 additions & 18 deletions go.sum
Original file line number Diff line number Diff line change
Expand Up @@ -126,7 +126,6 @@ github.com/cncf/udpa/go v0.0.0-20191209042840-269d4d468f6f/go.mod h1:M8M6+tZqaGX
github.com/cncf/udpa/go v0.0.0-20200629203442-efcf912fb354/go.mod h1:WmhPx2Nbnhtbo57+VJT5O0JRkEi1Wbu0z5j0R8u5Hbk=
github.com/cncf/udpa/go v0.0.0-20201120205902-5459f2c99403/go.mod h1:WmhPx2Nbnhtbo57+VJT5O0JRkEi1Wbu0z5j0R8u5Hbk=
github.com/cpuguy83/go-md2man/v2 v2.0.3/go.mod h1:tgQtvFlXSQOSOSIRvRPT7W67SCa46tRHOmNcaadrF8o=
github.com/creack/pty v1.1.9/go.mod h1:oKZEueFk5CKHvIhNR5MUki03XCEU+Q6VDXinZuGJ33E=
github.com/creack/pty v1.1.18 h1:n56/Zwd5o6whRC5PMGretI4IdRLlmBXYNjScPaBgsbY=
github.com/curioswitch/go-reassign v0.2.0 h1:G9UZyOcpk/d7Gd6mqYgd8XYWFMw/znxwGDUstnC9DIo=
github.com/curioswitch/go-reassign v0.2.0/go.mod h1:x6OpXuWvgfQaMGks2BZybTngWjT84hqJfKoO8Tt/Roc=
Expand Down Expand Up @@ -172,8 +171,8 @@ github.com/fsnotify/fsnotify v1.7.0 h1:8JEhPFa5W2WU7YfeZzPNqzMP6Lwt7L2715Ggo0nos
github.com/fsnotify/fsnotify v1.7.0/go.mod h1:40Bi/Hjc2AVfZrqy+aj+yEI+/bRxZnMJyTJwOpGvigM=
github.com/fzipp/gocyclo v0.6.0 h1:lsblElZG7d3ALtGMx9fmxeTKZaLLpU8mET09yN4BBLo=
github.com/fzipp/gocyclo v0.6.0/go.mod h1:rXPyn8fnlpa0R2csP/31uerbiVBugk5whMdlyaLkLoA=
github.com/getkin/kin-openapi v0.120.0 h1:MqJcNJFrMDFNc07iwE8iFC5eT2k/NPUFDIpNeiZv8Jg=
github.com/getkin/kin-openapi v0.120.0/go.mod h1:PCWw/lfBrJY4HcdqE3jj+QFkaFK8ABoqo7PvqVhXXqw=
github.com/getkin/kin-openapi v0.127.0 h1:Mghqi3Dhryf3F8vR370nN67pAERW+3a95vomb3MAREY=
github.com/getkin/kin-openapi v0.127.0/go.mod h1:OZrfXzUfGrNbsKj+xmFBx6E5c6yH3At/tAKSc2UszXM=
github.com/ghostiam/protogetter v0.2.3 h1:qdv2pzo3BpLqezwqfGDLZ+nHEYmc5bUpIdsMbBVwMjw=
github.com/ghostiam/protogetter v0.2.3/go.mod h1:KmNLOsy1v04hKbvZs8EfGI1fk39AgTdRDxWNYPfXVc4=
github.com/go-critic/go-critic v0.9.0 h1:Pmys9qvU3pSML/3GEQ2Xd9RZ/ip+aXHKILuxczKGV/U=
Expand All @@ -188,11 +187,10 @@ github.com/go-logfmt/logfmt v0.3.0/go.mod h1:Qt1PoO58o5twSAckw1HlFXLmHsOX5/0LbT9
github.com/go-logfmt/logfmt v0.4.0/go.mod h1:3RMwSq7FuexP4Kalkev3ejPJsZTpXXBr9+V4qmtdjCk=
github.com/go-logfmt/logfmt v0.5.0/go.mod h1:wCYkCAKZfumFQihp8CzCvQ3paCTfi41vtzG1KdI/P7A=
github.com/go-logr/logr v1.2.4 h1:g01GSCwiDw2xSZfjJ2/T9M+S6pFdcNtFYsp+Y43HYDQ=
github.com/go-openapi/jsonpointer v0.19.6 h1:eCs3fxoIi3Wh6vtgmLTOjdhSpiqphQ+DaPn38N2ZdrE=
github.com/go-openapi/jsonpointer v0.19.6/go.mod h1:osyAmYz/mB/C3I+WsTTSgw1ONzaLJoLCyoi6/zppojs=
github.com/go-openapi/swag v0.22.3/go.mod h1:UzaqsxGiab7freDnrUUra0MwWfN/q7tE4j+VcZ0yl14=
github.com/go-openapi/swag v0.22.4 h1:QLMzNJnMGPRNDCbySlcj1x01tzU8/9LTTL9hZZZogBU=
github.com/go-openapi/swag v0.22.4/go.mod h1:UzaqsxGiab7freDnrUUra0MwWfN/q7tE4j+VcZ0yl14=
github.com/go-openapi/jsonpointer v0.21.0 h1:YgdVicSA9vH5RiHs9TZW5oyafXZFc6+2Vc1rr/O9oNQ=
github.com/go-openapi/jsonpointer v0.21.0/go.mod h1:IUyH9l/+uyhIYQ/PXVA41Rexl+kOkAPDdXEYns6fzUY=
github.com/go-openapi/swag v0.23.0 h1:vsEVJDUo2hPJ2tu0/Xc+4noaxyEffXNIs3cOULZ+GrE=
github.com/go-openapi/swag v0.23.0/go.mod h1:esZ8ITTYEsH1V2trKHjAN8Ai7xHb8RV+YSZ577vPjgQ=
github.com/go-stack/stack v1.8.0/go.mod h1:v0f6uXyyMGvRgIKkXu+yp6POWl0qKG85gN/melR3HDY=
github.com/go-task/slim-sprig v0.0.0-20230315185526-52ccab3ef572 h1:tfuBGBXKqDEevZMzYi5KSi8KkcZtzBcTgAUUtapy0OI=
github.com/go-test/deep v1.0.8 h1:TDsG77qcSprGbC6vTN8OuXp5g+J+b5Pcguhf7Zt61VM=
Expand Down Expand Up @@ -339,8 +337,8 @@ github.com/ianlancetaylor/demangle v0.0.0-20181102032728-5e5cf60278f6/go.mod h1:
github.com/ianlancetaylor/demangle v0.0.0-20200824232613-28f6c0f3b639/go.mod h1:aSSvb/t6k1mPoxDqO4vJh6VOCGPwU4O0C2/Eqndh1Sc=
github.com/inconshreveable/mousetrap v1.1.0 h1:wN+x4NVGpMsO7ErUn/mUI3vEoE6Jt13X2s0bqwp9tc8=
github.com/inconshreveable/mousetrap v1.1.0/go.mod h1:vpF70FUmC8bwa3OWnCshd2FqLfsEA9PFc4w1p2J65bw=
github.com/invopop/yaml v0.2.0 h1:7zky/qH+O0DwAyoobXUqvVBwgBFRxKoQ/3FjcVpjTMY=
github.com/invopop/yaml v0.2.0/go.mod h1:2XuRLgs/ouIrW3XNzuNj7J3Nvu/Dig5MXvbCEdiBN3Q=
github.com/invopop/yaml v0.3.1 h1:f0+ZpmhfBSS4MhG+4HYseMdJhoeeopbSKbq5Rpeelso=
github.com/invopop/yaml v0.3.1/go.mod h1:PMOp3nn4/12yEZUFfmOuNHJsZToEEOwoWsT+D81KkeA=
github.com/jgautheron/goconst v1.6.0 h1:gbMLWKRMkzAc6kYsQL6/TxaoBUg3Jm9LSF/Ih1ADWGA=
github.com/jgautheron/goconst v1.6.0/go.mod h1:aAosetZ5zaeC/2EfMeRswtxUFBpe2Hr7HzkgX4fanO4=
github.com/jingyugao/rowserrcheck v1.1.1 h1:zibz55j/MJtLsjP1OF4bSdgXxwL1b+Vn7Tjzq7gFzUs=
Expand Down Expand Up @@ -372,12 +370,10 @@ github.com/konsorten/go-windows-terminal-sequences v1.0.3/go.mod h1:T0+1ngSBFLxv
github.com/kr/fs v0.1.0/go.mod h1:FFnZGqtBN9Gxj7eW1uZ42v5BccTP0vu6NEaFoC2HwRg=
github.com/kr/logfmt v0.0.0-20140226030751-b84e30acd515/go.mod h1:+0opPa2QZZtGFBFZlji/RkVcI2GknAs/DXo4wKdlNEc=
github.com/kr/pretty v0.1.0/go.mod h1:dAy3ld7l9f0ibDNOQOHHMYYIIbhfbHSm3C4ZsoJORNo=
github.com/kr/pretty v0.2.1/go.mod h1:ipq/a2n7PKx3OHsz4KJII5eveXtPO4qwEXGdVfWzfnI=
github.com/kr/pretty v0.3.1 h1:flRD4NNwYAUpkphVc1HcthR4KEIFJ65n8Mw5qdRn3LE=
github.com/kr/pty v1.1.1/go.mod h1:pFQYn66WHrOpPYNljwOMqo10TkYh1fy3cYio2l3bCsQ=
github.com/kr/text v0.1.0/go.mod h1:4Jbv+DJW3UT/LiOwJeYQe1efqtUx/iVham/4vfdArNI=
github.com/kr/text v0.2.0 h1:5Nx0Ya0ZqY2ygV366QzturHI13Jq95ApcVaJBhpS+AY=
github.com/kr/text v0.2.0/go.mod h1:eLer722TekiGuMkidMxC/pM04lWEeraHUUmBw8l2grE=
github.com/kulti/thelper v0.6.3 h1:ElhKf+AlItIu+xGnI990no4cE2+XaSu1ULymV2Yulxs=
github.com/kulti/thelper v0.6.3/go.mod h1:DsqKShOvP40epevkFrvIwkCMNYxMeTNjdWL4dqWHZ6I=
github.com/kunwardeep/paralleltest v1.0.8 h1:Ul2KsqtzFxTlSU7IP0JusWlLiNqQaloB9vguyjbE558=
Expand Down Expand Up @@ -511,8 +507,8 @@ github.com/quasilyte/stdinfo v0.0.0-20220114132959-f7386bf02567/go.mod h1:DWNGW8
github.com/rivo/uniseg v0.2.0 h1:S1pD9weZBuJdFmowNwbpi7BJ8TNftyUImj/0WQi72jY=
github.com/rivo/uniseg v0.2.0/go.mod h1:J6wj4VEh+S6ZtnVlnTBMWIodfgj8LQOQFoIToxlJtxc=
github.com/rogpeppe/go-internal v1.3.0/go.mod h1:M8bDsm7K2OlrFYOpmOWEs/qY81heoFRclV5y23lUDJ4=
github.com/rogpeppe/go-internal v1.11.0 h1:cWPaGQEPrBb5/AsnsZesgZZ9yb1OQ+GOISoDNXVBh4M=
github.com/rogpeppe/go-internal v1.11.0/go.mod h1:ddIwULY96R17DhadqLgMfk9H9tvdUzkipdSkR5nkCZA=
github.com/rogpeppe/go-internal v1.12.0 h1:exVL4IDcn6na9z1rAb56Vxr+CgyK3nn3O+epU5NdKM8=
github.com/russross/blackfriday/v2 v2.1.0/go.mod h1:+Rmxgy9KzJVeS9/2gXHxylqXiyQDYRxCVz55jmeOWTM=
github.com/ryancurrah/gomodguard v1.3.0 h1:q15RT/pd6UggBXVBuLps8BXRvl5GPBcwVA7BJHMLuTw=
github.com/ryancurrah/gomodguard v1.3.0/go.mod h1:ggBxb3luypPEzqVtq33ee7YSN35V28XeGnid8dnni50=
Expand Down Expand Up @@ -564,8 +560,9 @@ github.com/stbenjam/no-sprintf-host-port v0.1.1/go.mod h1:TLhvtIvONRzdmkFiio4O8L
github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME=
github.com/stretchr/objx v0.1.1/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME=
github.com/stretchr/objx v0.4.0/go.mod h1:YvHI0jy2hoMjB+UWwv71VJQ9isScKT/TqJzVSSt89Yw=
github.com/stretchr/objx v0.5.0 h1:1zr/of2m5FGMsad5YfcqgdqdWrIhu+EBEJRhR1U7z/c=
github.com/stretchr/objx v0.5.0/go.mod h1:Yh+to48EsGEfYuaHDzXPcE3xhTkx73EhmCGUpEOglKo=
github.com/stretchr/objx v0.5.2 h1:xuMeJ0Sdp5ZMRXx/aWO6RZxdr3beISkG5/G/aIRr3pY=
github.com/stretchr/objx v0.5.2/go.mod h1:FRsXN1f5AsAjCGJKqEizvkpNtU+EGNCLh3NxZ/8L+MA=
github.com/stretchr/testify v1.2.2/go.mod h1:a8OnRcib4nhh0OaRAV+Yts87kKdq0PP7pXfy6kDkUVs=
github.com/stretchr/testify v1.3.0/go.mod h1:M5WIy9Dh21IEIfnGCwXGc5bZfKNJtfHm1UVUgZn+9EI=
github.com/stretchr/testify v1.4.0/go.mod h1:j7eGeouHqKxXV5pUuKE4zz7dFj8WfuZ+81PSLYec5m4=
Expand All @@ -574,9 +571,9 @@ github.com/stretchr/testify v1.6.1/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/
github.com/stretchr/testify v1.7.0/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg=
github.com/stretchr/testify v1.7.1/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg=
github.com/stretchr/testify v1.8.0/go.mod h1:yNjHg4UonilssWZ8iaSj1OCr/vHnekPRkoO+kdMU+MU=
github.com/stretchr/testify v1.8.1/go.mod h1:w2LPCIKwWwSfY2zedu0+kehJoqGctiVI29o6fzry7u4=
github.com/stretchr/testify v1.8.4 h1:CcVxjf3Q8PM0mHUKJCdn+eZZtm5yQwehR5yeSVQQcUk=
github.com/stretchr/testify v1.8.4/go.mod h1:sz/lmYIOXD/1dqDmKjjqLyZ2RngseejIcXlSw2iwfAo=
github.com/stretchr/testify v1.9.0 h1:HtqpIVDClZ4nwg75+f6Lvsy/wHu+3BoSGCbBAcpTsTg=
github.com/stretchr/testify v1.9.0/go.mod h1:r2ic/lqez/lEtzL7wO/rwa5dbSLXVDPFyf8C91i36aY=
github.com/subosito/gotenv v1.4.1 h1:jyEFiXpy21Wm81FBN71l9VoMMV8H8jG+qIK3GCpY6Qs=
github.com/subosito/gotenv v1.4.1/go.mod h1:ayKnFf/c6rvx/2iiLrJUk1e6plDbT3edrFNGqEflhK0=
github.com/t-yuki/gocover-cobertura v0.0.0-20180217150009-aaee18c8195c h1:+aPplBwWcHBo6q9xrfWdMrT9o4kltkmmvpemgIjep/8=
Expand Down Expand Up @@ -1060,7 +1057,6 @@ gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8
gopkg.in/check.v1 v1.0.0-20180628173108-788fd7840127/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
gopkg.in/check.v1 v1.0.0-20190902080502-41f04d3bba15/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c h1:Hei/4ADfdWqJk1ZMxUNpqntNwaWcugrBjAiHlqqRiVk=
gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c/go.mod h1:JHkPIbrfpd72SG/EVd6muEfDQjcINNoR0C8j2r3qZ4Q=
gopkg.in/errgo.v2 v2.1.0/go.mod h1:hNsd1EY+bozCKY1Ytp96fpM3vjJbqLJn88ws8XvfDNI=
gopkg.in/ini.v1 v1.67.0 h1:Dgnx+6+nfE+IfzjUEISNeydPJh9AXNNsWbGP9KzCsOA=
gopkg.in/ini.v1 v1.67.0/go.mod h1:pNLf8WUiyNEtQjuu5G5vTm06TEv9tsIgeAvK8hOrP4k=
Expand All @@ -1072,7 +1068,6 @@ gopkg.in/yaml.v2 v2.3.0/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI=
gopkg.in/yaml.v2 v2.4.0 h1:D8xgwECY7CYvx+Y2n4sBz93Jn9JRvxdiyyo8CTfuKaY=
gopkg.in/yaml.v2 v2.4.0/go.mod h1:RDklbk79AGWmwhnvt/jBztapEOGDOx6ZbXqjP6csGnQ=
gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=
gopkg.in/yaml.v3 v3.0.0/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=
gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA=
gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=
gotest.tools/gotestsum v1.12.0 h1:CmwtaGDkHxrZm4Ib0Vob89MTfpc3GrEFMJKovliPwGk=
Expand Down
166 changes: 115 additions & 51 deletions pkg/cli/predict.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,13 +7,15 @@ import (
"fmt"
"os"
"os/signal"
"path/filepath"
"strings"
"syscall"

"github.com/getkin/kin-openapi/openapi3"
"github.com/mitchellh/go-homedir"
"github.com/spf13/cobra"
"github.com/vincent-petithory/dataurl"
"golang.org/x/sys/unix"

"github.com/replicate/cog/pkg/config"
"github.com/replicate/cog/pkg/docker"
Expand Down Expand Up @@ -165,6 +167,10 @@ func cmdPredict(cmd *cobra.Command, args []string) error {
return predictIndividualInputs(predictor, inputFlags, outPath)
}

func isURI(ref *openapi3.Schema) bool {
return ref != nil && ref.Type.Is("string") && ref.Format == "uri"
}

func predictIndividualInputs(predictor predict.Predictor, inputFlags []string, outputPath string) error {
console.Info("Running prediction...")
schema, err := predictor.GetSchema()
Expand All @@ -177,39 +183,87 @@ func predictIndividualInputs(predictor predict.Predictor, inputFlags []string, o
return err
}

prediction, err := predictor.Predict(inputs)
if err != nil {
return err
// If outputPath != "", then we now know the output path for sure
if outputPath != "" {
// Ignore @, to make it behave the same as -i
outputPath = strings.TrimPrefix(outputPath, "@")

if err := checkOutputWritable(outputPath); err != nil {
return fmt.Errorf("Output path is not writable: %w", err)
}
}

// Generate output depending on type in schema
var out []byte
responseSchema := schema.Paths["/predictions"].Post.Responses["200"].Value.Content["application/json"].Schema.Value
url := "/predictions"
responseSchema := schema.Paths.Value(url).Post.Responses.Value("200").Value.Content["application/json"].Schema.Value
outputSchema := responseSchema.Properties["output"].Value

// Multiple outputs!
if outputSchema.Type == "array" && outputSchema.Items.Value != nil && outputSchema.Items.Value.Type == "string" && outputSchema.Items.Value.Format == "uri" {
return handleMultipleFileOutput(prediction, outputSchema)
prediction, err := predictor.Predict(inputs)
if err != nil {
return fmt.Errorf("Failed to predict: %w", err)
}

if outputSchema.Type == "string" && outputSchema.Format == "uri" {
dataurlObj, err := dataurl.DecodeString((*prediction.Output).(string))
if err != nil {
return fmt.Errorf("Failed to decode dataurl: %w", err)
}
out = dataurlObj.Data
if prediction.Output == nil {
console.Warn("No output generated")
return nil
}

switch {
case isURI(outputSchema):
addExtension := false
if outputPath == "" {
outputPath = "output"
extension := mime.ExtensionByType(dataurlObj.ContentType())
if extension != "" {
outputPath += extension
addExtension = true
}

outputStr, ok := (*prediction.Output).(string)
if !ok {
return fmt.Errorf("Failed to convert prediction output to string")
}

if err := writeDataURLOutput(outputStr, outputPath, addExtension); err != nil {
return fmt.Errorf("Failed to write output: %w", err)
}

return nil
case outputSchema.Type.Is("array") && isURI(outputSchema.Items.Value):
outputs, ok := (*prediction.Output).([]interface{})
if !ok {
return fmt.Errorf("Failed to decode output")
}

for i, output := range outputs {
outputPath := fmt.Sprintf("output.%d", i)
addExtension := true

outputStr, ok := output.(string)
if !ok {
return fmt.Errorf("Failed to convert prediction output to string")
}

if err := writeDataURLOutput(outputStr, outputPath, addExtension); err != nil {
return fmt.Errorf("Failed to write output %d: %w", i, err)
}
}
} else if outputSchema.Type == "string" {
// Handle strings separately because if we encode it to JSON it will be surrounded by quotes.
s := (*prediction.Output).(string)
out = []byte(s)
} else {

return nil
case outputSchema.Type.Is("string"):
s, ok := (*prediction.Output).(string)
if !ok {
return fmt.Errorf("Failed to convert prediction output to string")
}

if outputPath == "" {
console.Output(s)
} else {
err := writeOutput(outputPath, []byte(s))
if err != nil {
return fmt.Errorf("Failed to write output: %w", err)
}
}

return nil
default:
// Treat everything else as JSON -- ints, floats, bools will all convert correctly.
rawJSON, err := json.Marshal(prediction.Output)
if err != nil {
Expand All @@ -219,27 +273,39 @@ func predictIndividualInputs(predictor predict.Predictor, inputFlags []string, o
if err := json.Indent(&indentedJSON, rawJSON, "", " "); err != nil {
return err
}
out = indentedJSON.Bytes()

// FIXME: this stopped working
// f := colorjson.NewFormatter()
// f.Indent = 2
// s, _ := f.Marshal(obj)

}
if outputPath == "" {
console.Output(indentedJSON.String())
} else {
err := writeOutput(outputPath, indentedJSON.Bytes())
if err != nil {
return fmt.Errorf("Failed to write output: %w", err)
}
}

// Write to stdout
if outputPath == "" {
console.Output(string(out))
return nil
}
}

// Fall back to writing file
func checkOutputWritable(outputPath string) error {
outputPath, err := homedir.Expand(outputPath)
if err != nil {
return err
}

// Ignore @, to make it behave the same as -i
outputPath = strings.TrimPrefix(outputPath, "@")
// Check if the file exists
_, err = os.Stat(outputPath)
if err == nil {
// File exists, check if it's writable
return unix.Access(outputPath, unix.W_OK)
} else if os.IsNotExist(err) {
// File doesn't exist, check if the directory is writable
dir := filepath.Dir(outputPath)
return unix.Access(dir, unix.W_OK)
}

return writeOutput(outputPath, out)
// Some other error occurred
return err
}

func writeOutput(outputPath string, output []byte) error {
Expand All @@ -249,7 +315,7 @@ func writeOutput(outputPath string, output []byte) error {
}

// Write to file
outFile, err := os.OpenFile(outputPath, os.O_WRONLY|os.O_CREATE, 0o755)
outFile, err := os.OpenFile(outputPath, os.O_WRONLY|os.O_CREATE|os.O_TRUNC, 0o755)
if err != nil {
return err
}
Expand All @@ -264,26 +330,24 @@ func writeOutput(outputPath string, output []byte) error {
return nil
}

func handleMultipleFileOutput(prediction *predict.Response, outputSchema *openapi3.Schema) error {
outputs, ok := (*prediction.Output).([]interface{})
if !ok {
return fmt.Errorf("Failed to decode output")
func writeDataURLOutput(outputString string, outputPath string, addExtension bool) error {
dataurlObj, err := dataurl.DecodeString(outputString)
if err != nil {
return fmt.Errorf("Failed to decode dataurl: %w", err)
}
output := dataurlObj.Data

for i, output := range outputs {
outputString := output.(string)
dataurlObj, err := dataurl.DecodeString(outputString)
if err != nil {
return fmt.Errorf("Failed to decode dataurl: %w", err)
}
out := dataurlObj.Data
if addExtension {
extension := mime.ExtensionByType(dataurlObj.ContentType())
outputPath := fmt.Sprintf("output.%d%s", i, extension)
if err := writeOutput(outputPath, out); err != nil {
return err
if extension != "" {
outputPath += extension
}
}

if err := writeOutput(outputPath, output); err != nil {
return err
}

return nil
}

Expand Down

0 comments on commit 0bfbc41

Please sign in to comment.