diff --git a/progressbar/progressbar.go b/progressbar/progressbar.go index ab642e2..7d611af 100644 --- a/progressbar/progressbar.go +++ b/progressbar/progressbar.go @@ -1,7 +1,6 @@ package progressbar import ( - "bytes" "context" "fmt" "io" @@ -9,7 +8,6 @@ import ( "strconv" "strings" "sync" - "sync/atomic" "text/template" "time" @@ -17,25 +15,38 @@ import ( "github.com/sunshineplan/utils/unit" ) +const ( + defaultRefresh = 5 * time.Second + maxRefreshMultiple = 3 +) + const defaultTemplate = `[{{.Done}}{{.Undone}}] {{.Speed}} {{.Current -}} ({{.Percent}}) of {{.Total}}{{if .Additional}} [{{.Additional}}]{{end}} {{.Elapsed}} {{.Left}} ` -// ProgressBar is a simple progress bar. -type ProgressBar struct { - mu sync.Mutex +var dots = []string{". ", ".. ", "..."} - ctx context.Context - cancel context.CancelFunc - done chan struct{} +// ProgressBar represents a customizable progress bar for tracking task progress. +// It supports configurable templates, units, and refresh intervals. +type ProgressBar struct { + mu sync.Mutex + buf strings.Builder + + ctx context.Context + cancel context.CancelFunc + msgChan chan string + done chan struct{} + start time.Time + last string + lastWidth int + + blockWidth int + refreshInterval time.Duration + renderInterval time.Duration + template *template.Template - start time.Time - blockWidth int - refresh time.Duration - template *template.Template current counter.Counter total int64 - additional atomic.Value - lastWidth int + additional string speed float64 unit string } @@ -48,7 +59,8 @@ type format struct { Elapsed, Left string } -// New returns a new ProgressBar with default options. +// New creates a new ProgressBar with the specified total count and default options. +// It panics if total is less than or equal to zero. func New(total int) *ProgressBar { return New64(int64(total)) } @@ -58,54 +70,91 @@ func New64(total int64) *ProgressBar { if total <= 0 { panic(fmt.Sprintf("invalid total number: %d", total)) } - ctx, cancel := context.WithCancel(context.Background()) return &ProgressBar{ - ctx: ctx, - cancel: cancel, - done: make(chan struct{}, 1), - - blockWidth: 40, - refresh: 5 * time.Second, - template: template.Must(template.New("ProgressBar").Parse(defaultTemplate)), - total: int64(total), + ctx: ctx, + cancel: cancel, + msgChan: make(chan string, 1), + done: make(chan struct{}), + blockWidth: 40, + refreshInterval: defaultRefresh, + template: template.Must(template.New("ProgressBar").Parse(defaultTemplate)), + total: int64(total), } } -// SetWidth sets progress bar block width. +// SetWidth sets the progress bar block width. +// It panics if called after the progress bar has started or if blockWidth is less than or equal to zero. func (pb *ProgressBar) SetWidth(blockWidth int) *ProgressBar { + pb.mu.Lock() + defer pb.mu.Unlock() + if !pb.start.IsZero() { + panic("progress bar is already started") + } + if blockWidth <= 0 { + panic(fmt.Sprintf("invalid block width: %d", blockWidth)) + } pb.blockWidth = blockWidth - return pb } -// SetRefresh sets progress bar refresh time for check speed. -func (pb *ProgressBar) SetRefresh(refresh time.Duration) *ProgressBar { - pb.refresh = refresh +// SetRefreshInterval sets progress bar refresh interval time for check speed. +// It panics if called after the progress bar has started or if interval is less than or equal to zero. +func (pb *ProgressBar) SetRefreshInterval(interval time.Duration) *ProgressBar { + pb.mu.Lock() + defer pb.mu.Unlock() + if !pb.start.IsZero() { + panic("progress bar is already started") + } + if interval <= 0 { + panic(fmt.Sprintf("invalid refresh interval: %v", interval)) + } + pb.refreshInterval = interval + return pb +} +// SetRenderInterval sets the interval for updating the progress bar display. +// It panics if called after the progress bar has started or if interval is less than or equal to zero. +func (pb *ProgressBar) SetRenderInterval(interval time.Duration) *ProgressBar { + pb.mu.Lock() + defer pb.mu.Unlock() + if !pb.start.IsZero() { + panic("progress bar is already started") + } + if interval <= 0 { + panic(fmt.Sprintf("invalid render interval: %v", interval)) + } + pb.renderInterval = interval return pb } // SetTemplate sets progress bar template. -func (pb *ProgressBar) SetTemplate(tmplt string) (err error) { +func (pb *ProgressBar) SetTemplate(tmplt string) error { + pb.mu.Lock() + defer pb.mu.Unlock() + if !pb.start.IsZero() { + return fmt.Errorf("progress bar is already started") + } t := template.New("ProgressBar") - if _, err = t.Parse(tmplt); err != nil { - return + if _, err := t.Parse(tmplt); err != nil { + return fmt.Errorf("failed to parse template: %w", err) } - - if err = t.Execute(io.Discard, format{}); err != nil { - return + if err := t.Execute(io.Discard, format{}); err != nil { + return fmt.Errorf("failed to execute template: %w", err) } - pb.template = t - - return + return nil } // SetUnit sets progress bar unit. +// It panics if called after the progress bar has started. func (pb *ProgressBar) SetUnit(unit string) *ProgressBar { + pb.mu.Lock() + defer pb.mu.Unlock() + if !pb.start.IsZero() { + panic("progress bar is already started") + } pb.unit = unit - return pb } @@ -116,42 +165,50 @@ func (pb *ProgressBar) Add(n int64) { // Additional adds the specified string to the progress bar. func (pb *ProgressBar) Additional(s string) { - pb.additional.Store(s) + pb.mu.Lock() + defer pb.mu.Unlock() + pb.additional = s } func (pb *ProgressBar) now() int64 { return pb.current.Load() } -func (pb *ProgressBar) print(f format) { - var buf bytes.Buffer - pb.template.Execute(&buf, f) - - width := buf.Len() - if width < pb.lastWidth { - io.WriteString(os.Stdout, - fmt.Sprintf("\r%s\r%s", strings.Repeat(" ", pb.lastWidth), buf.Bytes())) +func (pb *ProgressBar) print(s string, msg bool) { + pb.mu.Lock() + defer pb.mu.Unlock() + pb.buf.Reset() + if len(s) < pb.lastWidth { + pb.buf.WriteRune('\r') + pb.buf.WriteString(strings.Repeat(" ", pb.lastWidth)) + pb.buf.WriteRune('\r') + pb.buf.WriteString(s) } else { - io.WriteString(os.Stdout, "\r\r"+buf.String()) + pb.buf.WriteRune('\r') + pb.buf.WriteString(s) } - - pb.lastWidth = width + if msg { + pb.buf.WriteRune('\n') + pb.buf.WriteString(pb.last) + } else { + pb.last = s + pb.lastWidth = len(s) + } + io.WriteString(os.Stdout, pb.buf.String()) } func (pb *ProgressBar) startRefresh() { - start := time.Now() - maxRefresh := pb.refresh * 3 - - ticker := time.NewTicker(pb.refresh) + start := pb.start + maxRefresh := pb.refreshInterval * maxRefreshMultiple + ticker := time.NewTicker(pb.refreshInterval) defer ticker.Stop() - for { last := pb.now() select { case <-ticker.C: now := pb.now() totalSpeed := float64(now) / (float64(time.Since(start)) / float64(time.Second)) - intervalSpeed := float64(now-last) / (float64(pb.refresh) / float64(time.Second)) + intervalSpeed := float64(now-last) / (float64(pb.refreshInterval) / float64(time.Second)) pb.mu.Lock() if intervalSpeed == 0 { pb.speed = totalSpeed @@ -159,9 +216,9 @@ func (pb *ProgressBar) startRefresh() { pb.speed = intervalSpeed } pb.mu.Unlock() - if intervalSpeed == 0 && pb.refresh < maxRefresh { - pb.refresh += time.Second - ticker.Reset(pb.refresh) + if intervalSpeed == 0 && pb.refreshInterval < maxRefresh { + pb.refreshInterval += time.Second + ticker.Reset(pb.refreshInterval) } case <-pb.ctx.Done(): return @@ -172,67 +229,65 @@ func (pb *ProgressBar) startRefresh() { } func (pb *ProgressBar) startCount() { - ticker := time.NewTicker(time.Second) - defer ticker.Stop() - + interval := pb.renderInterval + if interval == 0 { + interval = time.Second + } + ticker := time.NewTicker(interval) + defer func() { + ticker.Stop() + close(pb.done) + close(pb.msgChan) + }() + var lastNow int64 + var f format + if pb.unit == "bytes" { + f.Total = unit.ByteSize(pb.total).String() + } else { + f.Total = strconv.FormatInt(pb.total, 10) + } + var buf strings.Builder + var dot int for { select { case <-ticker.C: now := min(pb.now(), pb.total) - done := int(int64(pb.blockWidth) * now / pb.total) - percent := float64(now) * 100 / float64(pb.total) - - var progressed string - if now < pb.total && done != 0 { - progressed = strings.Repeat("=", done-1) + ">" - } else { - progressed = strings.Repeat("=", done) - } - pb.mu.Lock() - var left time.Duration - if pb.speed != 0 { - left = time.Duration(float64(pb.total-now)/pb.speed) * time.Second - } - - var f format - if pb.unit == "bytes" { - f = format{ - Done: progressed, - Undone: strings.Repeat(" ", pb.blockWidth-done), - Speed: unit.ByteSize(pb.speed).String() + "/s", - Current: unit.ByteSize(now).String(), - Percent: fmt.Sprintf("%.2f%%", percent), - Total: unit.ByteSize(pb.total).String(), - Additional: pb.additional.Load().(string), - Elapsed: fmt.Sprintf("Elapsed: %s", time.Since(pb.start).Truncate(time.Second)), - Left: fmt.Sprintf("Left: %s", left.Truncate(time.Second)), + if now != lastNow || f.Done == "" { + lastNow = now + done := int(int64(pb.blockWidth) * now / pb.total) + percent := float64(now) * 100 / float64(pb.total) + if now < pb.total && done != 0 { + f.Done = strings.Repeat("=", done-1) + ">" + } else { + f.Done = strings.Repeat("=", done) } - } else { - f = format{ - Done: progressed, - Undone: strings.Repeat(" ", pb.blockWidth-done), - Speed: fmt.Sprintf("%.2f/s", pb.speed), - Current: strconv.FormatInt(now, 10), - Percent: fmt.Sprintf("%.2f%%", percent), - Total: strconv.FormatInt(pb.total, 10), - Additional: pb.additional.Load().(string), - Elapsed: fmt.Sprintf("Elapsed: %s", time.Since(pb.start).Truncate(time.Second)), - Left: fmt.Sprintf("Left: %s", left.Truncate(time.Second)), + f.Undone = strings.Repeat(" ", pb.blockWidth-done) + f.Percent = fmt.Sprintf("%.2f%%", percent) + if pb.unit == "bytes" { + f.Current = unit.ByteSize(now).String() + } else { + f.Current = strconv.FormatInt(now, 10) } } - + f.Additional = pb.additional + f.Elapsed = fmt.Sprintf("Elapsed: %s", time.Since(pb.start).Truncate(time.Second)) if pb.speed == 0 { f.Speed = "--/s" - f.Left = fmt.Sprintf("Left: calculating%s%s", - strings.Repeat(".", time.Now().Second()%3+1), - strings.Repeat(" ", 2-time.Now().Second()%3), - ) + f.Left = "Left: calculating" + dots[dot%3] + dot++ + } else { + if pb.unit == "bytes" { + f.Speed = unit.ByteSize(pb.speed).String() + "/s" + } else { + f.Speed = fmt.Sprintf("%.2f/s", pb.speed) + } + f.Left = fmt.Sprintf("Left: %s", (time.Duration(float64(pb.total-now)/pb.speed) * time.Second).Truncate(time.Second)) } pb.mu.Unlock() - - pb.print(f) - + buf.Reset() + pb.template.Execute(&buf, f) + pb.print(buf.String(), false) if now == pb.total { totalSpeed := float64(pb.total) / (float64(time.Since(pb.start)) / float64(time.Second)) if pb.unit == "bytes" { @@ -241,14 +296,17 @@ func (pb *ProgressBar) startCount() { f.Speed = fmt.Sprintf("%.2f/s", totalSpeed) } f.Left = "Complete" - - pb.print(f) + buf.Reset() + pb.template.Execute(&buf, f) + pb.print(buf.String(), false) io.WriteString(os.Stdout, "\n") - - close(pb.done) return } + case msg := <-pb.msgChan: + pb.print(msg, true) case <-pb.ctx.Done(): + pb.mu.Lock() + defer pb.mu.Unlock() io.WriteString(os.Stdout, "\nCancelled\n") return } @@ -257,32 +315,56 @@ func (pb *ProgressBar) startCount() { // Start starts the progress bar. func (pb *ProgressBar) Start() error { + pb.mu.Lock() + defer pb.mu.Unlock() if !pb.start.IsZero() { return fmt.Errorf("progress bar is already started") } - pb.start = time.Now() - pb.additional.Store("") - go pb.startRefresh() go pb.startCount() + return nil +} +// Message sets a message to be displayed on the progress bar. +func (pb *ProgressBar) Message(msg string) error { + defer func() { recover() }() + select { + case <-pb.done: + return fmt.Errorf("progress bar is already finished") + default: + } + select { + case pb.msgChan <- msg: + default: + } return nil } -// Done waits the progress bar finished. +// Done waits the progress bar finished. Same as [ProgressBar.Wait](). func (pb *ProgressBar) Done() { + pb.Wait() +} + +// Wait blocks until the progress bar is finished. +func (pb *ProgressBar) Wait() { <-pb.done } // Cancel cancels the progress bar. func (pb *ProgressBar) Cancel() { pb.cancel() - close(pb.done) } // FromReader starts the progress bar from a reader. func (pb *ProgressBar) FromReader(r io.Reader, w io.Writer) (int64, error) { - pb.Start() - return io.Copy(pb.current.AddWriter(w), r) + if err := pb.Start(); err != nil { + return 0, err + } + n, err := io.Copy(pb.current.AddWriter(w), r) + if err != nil { + pb.Cancel() + return n, err + } + return n, nil } diff --git a/progressbar/progressbar_test.go b/progressbar/progressbar_test.go index 0226394..9795c46 100644 --- a/progressbar/progressbar_test.go +++ b/progressbar/progressbar_test.go @@ -1,6 +1,7 @@ package progressbar import ( + "fmt" "io" "net/http" "strconv" @@ -15,7 +16,7 @@ func TestProgessBar(t *testing.T) { } }() - pb := New(15).SetRefresh(4 * time.Second) + pb := New(15).SetRefreshInterval(4 * time.Second) pb.Start() pb.Additional("refreshes in 4s") for range pb.total { @@ -24,7 +25,7 @@ func TestProgessBar(t *testing.T) { } pb.Done() - pb = New(10).SetRefresh(500 * time.Millisecond) + pb = New(10).SetRefreshInterval(500 * time.Millisecond) pb.Start() pb.Additional("refreshes in 500ms") for range pb.total { @@ -37,8 +38,45 @@ func TestProgessBar(t *testing.T) { pb.Start() } +func TestMessage(t *testing.T) { + pb := New(15) + pb.Start() + errCh := make(chan error, 1) + stopCh := make(chan struct{}) + go func() { + i := 0 + for { + select { + case <-stopCh: + return + default: + time.Sleep(500 * time.Millisecond) + if err := pb.Message(fmt.Sprintf("test messages (%d)", i)); err != nil { + errCh <- err + return + } + } + i++ + } + }() + for range pb.total { + pb.Add(1) + time.Sleep(time.Second) + } + pb.Done() + close(stopCh) + select { + case err := <-errCh: + t.Fatal(err) + default: + } + if err := pb.Message("test messages"); err == nil { + t.Fatal("expected non-nil error; got nil error") + } +} + func TestCancel(t *testing.T) { - pb := New(15).SetRefresh(4 * time.Second) + pb := New(15).SetRefreshInterval(4 * time.Second) pb.Start() go func() { time.Sleep(3 * time.Second)