Skip to content

Commit 524737d

Browse files
committed
Progress bar with more info
More info than previous progress bar Signed-off-by: Eric Curtin <[email protected]>
1 parent 08a8afe commit 524737d

File tree

1 file changed

+151
-2
lines changed

1 file changed

+151
-2
lines changed

desktop/desktop.go

Lines changed: 151 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -9,9 +9,13 @@ import (
99
"html"
1010
"io"
1111
"net/http"
12+
"os"
13+
"runtime"
1214
"strconv"
1315
"strings"
16+
"syscall"
1417
"time"
18+
"unsafe"
1519

1620
"github.com/docker/go-units"
1721
"github.com/docker/model-distribution/distribution"
@@ -106,6 +110,147 @@ func (c *Client) Status() Status {
106110
}
107111
}
108112

113+
func humanReadableSize(size float64) string {
114+
return units.CustomSize("%.2f%s", float64(size), 1000.0, []string{"B", "kB", "MB", "GB", "TB", "PB", "EB", "ZB", "YB"})
115+
}
116+
117+
func humanReadableSizePad(size float64, width int) string {
118+
return fmt.Sprintf("%*s", width, humanReadableSize(size))
119+
}
120+
121+
func humanReadableTimePad(seconds int64, width int) string {
122+
var s string
123+
if seconds < 60 {
124+
s = fmt.Sprintf("%ds", seconds)
125+
} else if seconds < 3600 {
126+
s = fmt.Sprintf("%dm %02ds", seconds/60, seconds%60)
127+
} else {
128+
s = fmt.Sprintf("%dh %02dm %02ds", seconds/3600, (seconds%3600)/60, seconds%60)
129+
}
130+
return fmt.Sprintf("%*s", width, s)
131+
}
132+
133+
// ProgressBarState tracks the running totals and timing for speed/ETA
134+
type ProgressBarState struct {
135+
LastDownloaded uint64
136+
LastTime time.Time
137+
StartTime time.Time
138+
UpdateInterval time.Duration // New: interval between updates
139+
lastPrint time.Time // New: last time the progress bar was printed
140+
}
141+
142+
// formatBar calculates the bar width and filled bar string.
143+
func (pbs *ProgressBarState) formatBar(percent float64, termWidth int, prefix, suffix string) string {
144+
barWidth := termWidth - len(prefix) - len(suffix) - 4
145+
if barWidth < 10 {
146+
barWidth = 10
147+
}
148+
filled := int(percent / 100 * float64(barWidth))
149+
if filled > barWidth {
150+
filled = barWidth
151+
}
152+
bar := strings.Repeat("█", filled) + strings.Repeat(" ", barWidth-filled)
153+
return bar
154+
}
155+
156+
// calcSpeed calculates the current download speed.
157+
func (pbs *ProgressBarState) calcSpeed(current uint64, now time.Time) float64 {
158+
elapsed := now.Sub(pbs.LastTime).Seconds()
159+
if elapsed <= 0 {
160+
return 0
161+
}
162+
163+
speed := float64(current-pbs.LastDownloaded) / elapsed
164+
pbs.LastTime = now
165+
pbs.LastDownloaded = current
166+
167+
return speed
168+
}
169+
170+
// formatSuffix returns the suffix string showing human readable sizes, speed, and ETA.
171+
func (pbs *ProgressBarState) fmtSuffix(current, total uint64, speed float64, eta int64) string {
172+
return fmt.Sprintf("%s/%s %s/s %s",
173+
humanReadableSizePad(float64(current), 10),
174+
humanReadableSize(float64(total)),
175+
humanReadableSizePad(speed, 10),
176+
humanReadableTimePad(eta, 16),
177+
)
178+
}
179+
180+
// calcETA calculates the estimated time remaining.
181+
func (pbs *ProgressBarState) calcETA(current, total uint64, speed float64) int64 {
182+
if speed <= 0 {
183+
return 0
184+
}
185+
return int64(float64(total-current) / speed)
186+
}
187+
188+
// printProgressBar prints/updates a progress bar in the terminal
189+
// Only prints if UpdateInterval has passed since last print, or always if interval=0
190+
func (pbs *ProgressBarState) printProgressBar(current, total uint64) {
191+
if pbs.StartTime.IsZero() {
192+
pbs.StartTime = time.Now()
193+
pbs.LastTime = pbs.StartTime
194+
pbs.LastDownloaded = current
195+
pbs.lastPrint = pbs.StartTime
196+
}
197+
198+
now := time.Now()
199+
// Only update display if enough time passed,
200+
// unless interval is 0 (always print)
201+
if pbs.UpdateInterval > 0 && now.Sub(pbs.lastPrint) < pbs.UpdateInterval && current != total {
202+
return
203+
}
204+
205+
pbs.lastPrint = now
206+
termWidth := getTerminalWidth()
207+
percent := float64(current) / float64(total) * 100
208+
prefix := fmt.Sprintf("%3.0f%% |", percent)
209+
speed := pbs.calcSpeed(current, now)
210+
eta := pbs.calcETA(current, total, speed)
211+
suffix := pbs.fmtSuffix(current, total, speed, eta)
212+
bar := pbs.formatBar(percent, termWidth, prefix, suffix)
213+
fmt.Fprintf(os.Stderr, "\r%s%s| %s", prefix, bar, suffix)
214+
}
215+
216+
func getTerminalWidthUnix() (int, error) {
217+
type winsize struct {
218+
Row uint16
219+
Col uint16
220+
Xpixel uint16
221+
Ypixel uint16
222+
}
223+
ws := &winsize{}
224+
retCode, _, errno := syscall.Syscall6(
225+
syscall.SYS_IOCTL,
226+
uintptr(os.Stdout.Fd()),
227+
uintptr(syscall.TIOCGWINSZ),
228+
uintptr(unsafe.Pointer(ws)),
229+
0, 0, 0,
230+
)
231+
if int(retCode) == -1 {
232+
return 0, errno
233+
}
234+
return int(ws.Col), nil
235+
}
236+
237+
// getTerminalSize tries to get the terminal width (default 80 if fails)
238+
func getTerminalWidth() int {
239+
var width int
240+
var err error
241+
default_width := 80
242+
if runtime.GOOS == "windows" { // to be implemented
243+
return default_width
244+
}
245+
246+
width, err = getTerminalWidthUnix()
247+
if width == 0 || err != nil {
248+
return default_width
249+
}
250+
251+
return width
252+
}
253+
109254
func (c *Client) Pull(model string, ignoreRuntimeMemoryCheck bool, progress func(string)) (string, bool, error) {
110255
model = normalizeHuggingFaceModelName(model)
111256
jsonData, err := json.Marshal(dmrm.ModelCreateRequest{From: model, IgnoreRuntimeMemoryCheck: ignoreRuntimeMemoryCheck})
@@ -130,10 +275,14 @@ func (c *Client) Pull(model string, ignoreRuntimeMemoryCheck bool, progress func
130275
}
131276

132277
progressShown := false
133-
current := uint64(0) // Track cumulative progress across all layers
278+
// Track cumulative progress across all layers
279+
current := uint64(0)
134280
layerProgress := make(map[string]uint64) // Track progress per layer ID
135281

136282
scanner := bufio.NewScanner(resp.Body)
283+
pbs := &ProgressBarState{
284+
UpdateInterval: time.Millisecond * 100,
285+
}
137286
for scanner.Scan() {
138287
progressLine := scanner.Text()
139288
if progressLine == "" {
@@ -159,7 +308,7 @@ func (c *Client) Pull(model string, ignoreRuntimeMemoryCheck bool, progress func
159308
current += layerCurrent
160309
}
161310

162-
progress(fmt.Sprintf("Downloaded %s of %s", units.CustomSize("%.2f%s", float64(current), 1000.0, []string{"B", "kB", "MB", "GB", "TB", "PB", "EB", "ZB", "YB"}), units.CustomSize("%.2f%s", float64(progressMsg.Total), 1000.0, []string{"B", "kB", "MB", "GB", "TB", "PB", "EB", "ZB", "YB"})))
311+
pbs.printProgressBar(current, progressMsg.Total)
163312
progressShown = true
164313
case "error":
165314
return "", progressShown, fmt.Errorf("error pulling model: %s", progressMsg.Message)

0 commit comments

Comments
 (0)