@@ -9,9 +9,13 @@ import (
9
9
"html"
10
10
"io"
11
11
"net/http"
12
+ "os"
13
+ "runtime"
12
14
"strconv"
13
15
"strings"
16
+ "syscall"
14
17
"time"
18
+ "unsafe"
15
19
16
20
"github.com/docker/go-units"
17
21
"github.com/docker/model-distribution/distribution"
@@ -106,6 +110,147 @@ func (c *Client) Status() Status {
106
110
}
107
111
}
108
112
113
+ func humanReadableSize (size float64 ) string {
114
+ return units .CustomSize ("%.2f%s" , float64 (size ), 1000.0 , []string {"B" , "kB" , "MB" , "GB" , "TB" , "PB" , "EB" , "ZB" , "YB" })
115
+ }
116
+
117
+ func humanReadableSizePad (size float64 , width int ) string {
118
+ return fmt .Sprintf ("%*s" , width , humanReadableSize (size ))
119
+ }
120
+
121
+ func humanReadableTimePad (seconds int64 , width int ) string {
122
+ var s string
123
+ if seconds < 60 {
124
+ s = fmt .Sprintf ("%ds" , seconds )
125
+ } else if seconds < 3600 {
126
+ s = fmt .Sprintf ("%dm %02ds" , seconds / 60 , seconds % 60 )
127
+ } else {
128
+ s = fmt .Sprintf ("%dh %02dm %02ds" , seconds / 3600 , (seconds % 3600 )/ 60 , seconds % 60 )
129
+ }
130
+ return fmt .Sprintf ("%*s" , width , s )
131
+ }
132
+
133
+ // ProgressBarState tracks the running totals and timing for speed/ETA
134
+ type ProgressBarState struct {
135
+ LastDownloaded uint64
136
+ LastTime time.Time
137
+ StartTime time.Time
138
+ UpdateInterval time.Duration // New: interval between updates
139
+ lastPrint time.Time // New: last time the progress bar was printed
140
+ }
141
+
142
+ // formatBar calculates the bar width and filled bar string.
143
+ func (pbs * ProgressBarState ) formatBar (percent float64 , termWidth int , prefix , suffix string ) string {
144
+ barWidth := termWidth - len (prefix ) - len (suffix ) - 4
145
+ if barWidth < 10 {
146
+ barWidth = 10
147
+ }
148
+ filled := int (percent / 100 * float64 (barWidth ))
149
+ if filled > barWidth {
150
+ filled = barWidth
151
+ }
152
+ bar := strings .Repeat ("█" , filled ) + strings .Repeat (" " , barWidth - filled )
153
+ return bar
154
+ }
155
+
156
+ // calcSpeed calculates the current download speed.
157
+ func (pbs * ProgressBarState ) calcSpeed (current uint64 , now time.Time ) float64 {
158
+ elapsed := now .Sub (pbs .LastTime ).Seconds ()
159
+ if elapsed <= 0 {
160
+ return 0
161
+ }
162
+
163
+ speed := float64 (current - pbs .LastDownloaded ) / elapsed
164
+ pbs .LastTime = now
165
+ pbs .LastDownloaded = current
166
+
167
+ return speed
168
+ }
169
+
170
+ // formatSuffix returns the suffix string showing human readable sizes, speed, and ETA.
171
+ func (pbs * ProgressBarState ) fmtSuffix (current , total uint64 , speed float64 , eta int64 ) string {
172
+ return fmt .Sprintf ("%s/%s %s/s %s" ,
173
+ humanReadableSizePad (float64 (current ), 10 ),
174
+ humanReadableSize (float64 (total )),
175
+ humanReadableSizePad (speed , 10 ),
176
+ humanReadableTimePad (eta , 16 ),
177
+ )
178
+ }
179
+
180
+ // calcETA calculates the estimated time remaining.
181
+ func (pbs * ProgressBarState ) calcETA (current , total uint64 , speed float64 ) int64 {
182
+ if speed <= 0 {
183
+ return 0
184
+ }
185
+ return int64 (float64 (total - current ) / speed )
186
+ }
187
+
188
+ // printProgressBar prints/updates a progress bar in the terminal
189
+ // Only prints if UpdateInterval has passed since last print, or always if interval=0
190
+ func (pbs * ProgressBarState ) printProgressBar (current , total uint64 ) {
191
+ if pbs .StartTime .IsZero () {
192
+ pbs .StartTime = time .Now ()
193
+ pbs .LastTime = pbs .StartTime
194
+ pbs .LastDownloaded = current
195
+ pbs .lastPrint = pbs .StartTime
196
+ }
197
+
198
+ now := time .Now ()
199
+ // Only update display if enough time passed,
200
+ // unless interval is 0 (always print)
201
+ if pbs .UpdateInterval > 0 && now .Sub (pbs .lastPrint ) < pbs .UpdateInterval && current != total {
202
+ return
203
+ }
204
+
205
+ pbs .lastPrint = now
206
+ termWidth := getTerminalWidth ()
207
+ percent := float64 (current ) / float64 (total ) * 100
208
+ prefix := fmt .Sprintf ("%3.0f%% |" , percent )
209
+ speed := pbs .calcSpeed (current , now )
210
+ eta := pbs .calcETA (current , total , speed )
211
+ suffix := pbs .fmtSuffix (current , total , speed , eta )
212
+ bar := pbs .formatBar (percent , termWidth , prefix , suffix )
213
+ fmt .Fprintf (os .Stderr , "\r %s%s| %s" , prefix , bar , suffix )
214
+ }
215
+
216
+ func getTerminalWidthUnix () (int , error ) {
217
+ type winsize struct {
218
+ Row uint16
219
+ Col uint16
220
+ Xpixel uint16
221
+ Ypixel uint16
222
+ }
223
+ ws := & winsize {}
224
+ retCode , _ , errno := syscall .Syscall6 (
225
+ syscall .SYS_IOCTL ,
226
+ uintptr (os .Stdout .Fd ()),
227
+ uintptr (syscall .TIOCGWINSZ ),
228
+ uintptr (unsafe .Pointer (ws )),
229
+ 0 , 0 , 0 ,
230
+ )
231
+ if int (retCode ) == - 1 {
232
+ return 0 , errno
233
+ }
234
+ return int (ws .Col ), nil
235
+ }
236
+
237
+ // getTerminalSize tries to get the terminal width (default 80 if fails)
238
+ func getTerminalWidth () int {
239
+ var width int
240
+ var err error
241
+ default_width := 80
242
+ if runtime .GOOS == "windows" { // to be implemented
243
+ return default_width
244
+ }
245
+
246
+ width , err = getTerminalWidthUnix ()
247
+ if width == 0 || err != nil {
248
+ return default_width
249
+ }
250
+
251
+ return width
252
+ }
253
+
109
254
func (c * Client ) Pull (model string , ignoreRuntimeMemoryCheck bool , progress func (string )) (string , bool , error ) {
110
255
model = normalizeHuggingFaceModelName (model )
111
256
jsonData , err := json .Marshal (dmrm.ModelCreateRequest {From : model , IgnoreRuntimeMemoryCheck : ignoreRuntimeMemoryCheck })
@@ -130,10 +275,14 @@ func (c *Client) Pull(model string, ignoreRuntimeMemoryCheck bool, progress func
130
275
}
131
276
132
277
progressShown := false
133
- current := uint64 (0 ) // Track cumulative progress across all layers
278
+ // Track cumulative progress across all layers
279
+ current := uint64 (0 )
134
280
layerProgress := make (map [string ]uint64 ) // Track progress per layer ID
135
281
136
282
scanner := bufio .NewScanner (resp .Body )
283
+ pbs := & ProgressBarState {
284
+ UpdateInterval : time .Millisecond * 100 ,
285
+ }
137
286
for scanner .Scan () {
138
287
progressLine := scanner .Text ()
139
288
if progressLine == "" {
@@ -159,7 +308,7 @@ func (c *Client) Pull(model string, ignoreRuntimeMemoryCheck bool, progress func
159
308
current += layerCurrent
160
309
}
161
310
162
- progress ( fmt . Sprintf ( "Downloaded %s of %s" , units . CustomSize ( "%.2f%s" , float64 ( current ), 1000.0 , [] string { "B" , "kB" , "MB" , "GB" , "TB" , "PB" , "EB" , "ZB" , "YB" }), units . CustomSize ( "%.2f%s" , float64 ( progressMsg .Total ), 1000.0 , [] string { "B" , "kB" , "MB" , "GB" , "TB" , "PB" , "EB" , "ZB" , "YB" })) )
311
+ pbs . printProgressBar ( current , progressMsg .Total )
163
312
progressShown = true
164
313
case "error" :
165
314
return "" , progressShown , fmt .Errorf ("error pulling model: %s" , progressMsg .Message )
0 commit comments