From d7b0757daf0067e6fa09ab5163e59fbf5174e8f3 Mon Sep 17 00:00:00 2001 From: sunshineplan Date: Thu, 25 Sep 2025 15:08:36 +0800 Subject: [PATCH 01/40] progressbar --- progressbar/progressbar.go | 50 +++++++++++++-------------------- progressbar/progressbar_test.go | 14 ++++----- 2 files changed, 27 insertions(+), 37 deletions(-) diff --git a/progressbar/progressbar.go b/progressbar/progressbar.go index 7d611af..99482b5 100644 --- a/progressbar/progressbar.go +++ b/progressbar/progressbar.go @@ -27,7 +27,7 @@ var dots = []string{". ", ".. ", "..."} // ProgressBar represents a customizable progress bar for tracking task progress. // It supports configurable templates, units, and refresh intervals. -type ProgressBar struct { +type ProgressBar[T int | int64] struct { mu sync.Mutex buf strings.Builder @@ -61,17 +61,12 @@ type format struct { // New creates a new ProgressBar with the specified total count and default options. // It panics if total is less than or equal to zero. -func New(total int) *ProgressBar { - return New64(int64(total)) -} - -// New64 returns a new ProgressBar with default options. -func New64(total int64) *ProgressBar { +func New[T int | int64](total T) *ProgressBar[T] { if total <= 0 { panic(fmt.Sprintf("invalid total number: %d", total)) } ctx, cancel := context.WithCancel(context.Background()) - return &ProgressBar{ + return &ProgressBar[T]{ ctx: ctx, cancel: cancel, msgChan: make(chan string, 1), @@ -85,7 +80,7 @@ func New64(total int64) *ProgressBar { // SetWidth sets the progress bar block width. // It panics if called after the progress bar has started or if blockWidth is less than or equal to zero. -func (pb *ProgressBar) SetWidth(blockWidth int) *ProgressBar { +func (pb *ProgressBar[T]) SetWidth(blockWidth int) *ProgressBar[T] { pb.mu.Lock() defer pb.mu.Unlock() if !pb.start.IsZero() { @@ -100,7 +95,7 @@ func (pb *ProgressBar) SetWidth(blockWidth int) *ProgressBar { // SetRefreshInterval sets progress bar refresh interval time for check speed. // It panics if called after the progress bar has started or if interval is less than or equal to zero. -func (pb *ProgressBar) SetRefreshInterval(interval time.Duration) *ProgressBar { +func (pb *ProgressBar[T]) SetRefreshInterval(interval time.Duration) *ProgressBar[T] { pb.mu.Lock() defer pb.mu.Unlock() if !pb.start.IsZero() { @@ -115,7 +110,7 @@ func (pb *ProgressBar) SetRefreshInterval(interval time.Duration) *ProgressBar { // SetRenderInterval sets the interval for updating the progress bar display. // It panics if called after the progress bar has started or if interval is less than or equal to zero. -func (pb *ProgressBar) SetRenderInterval(interval time.Duration) *ProgressBar { +func (pb *ProgressBar[T]) SetRenderInterval(interval time.Duration) *ProgressBar[T] { pb.mu.Lock() defer pb.mu.Unlock() if !pb.start.IsZero() { @@ -129,7 +124,7 @@ func (pb *ProgressBar) SetRenderInterval(interval time.Duration) *ProgressBar { } // SetTemplate sets progress bar template. -func (pb *ProgressBar) SetTemplate(tmplt string) error { +func (pb *ProgressBar[T]) SetTemplate(tmplt string) error { pb.mu.Lock() defer pb.mu.Unlock() if !pb.start.IsZero() { @@ -148,7 +143,7 @@ func (pb *ProgressBar) SetTemplate(tmplt string) error { // SetUnit sets progress bar unit. // It panics if called after the progress bar has started. -func (pb *ProgressBar) SetUnit(unit string) *ProgressBar { +func (pb *ProgressBar[T]) SetUnit(unit string) *ProgressBar[T] { pb.mu.Lock() defer pb.mu.Unlock() if !pb.start.IsZero() { @@ -159,22 +154,22 @@ func (pb *ProgressBar) SetUnit(unit string) *ProgressBar { } // Add adds the specified amount to the progress bar. -func (pb *ProgressBar) Add(n int64) { - pb.current.Add(n) +func (pb *ProgressBar[T]) Add(n T) { + pb.current.Add(int64(n)) } // Additional adds the specified string to the progress bar. -func (pb *ProgressBar) Additional(s string) { +func (pb *ProgressBar[T]) Additional(s string) { pb.mu.Lock() defer pb.mu.Unlock() pb.additional = s } -func (pb *ProgressBar) now() int64 { +func (pb *ProgressBar[T]) now() int64 { return pb.current.Load() } -func (pb *ProgressBar) print(s string, msg bool) { +func (pb *ProgressBar[T]) print(s string, msg bool) { pb.mu.Lock() defer pb.mu.Unlock() pb.buf.Reset() @@ -197,7 +192,7 @@ func (pb *ProgressBar) print(s string, msg bool) { io.WriteString(os.Stdout, pb.buf.String()) } -func (pb *ProgressBar) startRefresh() { +func (pb *ProgressBar[T]) startRefresh() { start := pb.start maxRefresh := pb.refreshInterval * maxRefreshMultiple ticker := time.NewTicker(pb.refreshInterval) @@ -228,7 +223,7 @@ func (pb *ProgressBar) startRefresh() { } } -func (pb *ProgressBar) startCount() { +func (pb *ProgressBar[T]) startCount() { interval := pb.renderInterval if interval == 0 { interval = time.Second @@ -314,7 +309,7 @@ func (pb *ProgressBar) startCount() { } // Start starts the progress bar. -func (pb *ProgressBar) Start() error { +func (pb *ProgressBar[T]) Start() error { pb.mu.Lock() defer pb.mu.Unlock() if !pb.start.IsZero() { @@ -327,7 +322,7 @@ func (pb *ProgressBar) Start() error { } // Message sets a message to be displayed on the progress bar. -func (pb *ProgressBar) Message(msg string) error { +func (pb *ProgressBar[T]) Message(msg string) error { defer func() { recover() }() select { case <-pb.done: @@ -341,23 +336,18 @@ func (pb *ProgressBar) Message(msg string) error { return nil } -// Done waits the progress bar finished. Same as [ProgressBar.Wait](). -func (pb *ProgressBar) Done() { - pb.Wait() -} - // Wait blocks until the progress bar is finished. -func (pb *ProgressBar) Wait() { +func (pb *ProgressBar[T]) Wait() { <-pb.done } // Cancel cancels the progress bar. -func (pb *ProgressBar) Cancel() { +func (pb *ProgressBar[T]) Cancel() { pb.cancel() } // FromReader starts the progress bar from a reader. -func (pb *ProgressBar) FromReader(r io.Reader, w io.Writer) (int64, error) { +func (pb *ProgressBar[T]) FromReader(r io.Reader, w io.Writer) (int64, error) { if err := pb.Start(); err != nil { return 0, err } diff --git a/progressbar/progressbar_test.go b/progressbar/progressbar_test.go index 9795c46..49b46b4 100644 --- a/progressbar/progressbar_test.go +++ b/progressbar/progressbar_test.go @@ -23,7 +23,7 @@ func TestProgessBar(t *testing.T) { pb.Add(1) time.Sleep(time.Second) } - pb.Done() + pb.Wait() pb = New(10).SetRefreshInterval(500 * time.Millisecond) pb.Start() @@ -32,7 +32,7 @@ func TestProgessBar(t *testing.T) { pb.Add(1) time.Sleep(time.Second) } - pb.Done() + pb.Wait() pb = New(0) pb.Start() @@ -63,7 +63,7 @@ func TestMessage(t *testing.T) { pb.Add(1) time.Sleep(time.Second) } - pb.Done() + pb.Wait() close(stopCh) select { case err := <-errCh: @@ -88,7 +88,7 @@ func TestCancel(t *testing.T) { time.Sleep(time.Second) } }() - pb.Done() + pb.Wait() } func TestFromReader(t *testing.T) { @@ -97,7 +97,7 @@ func TestFromReader(t *testing.T) { t.Fatal(err) } defer resp.Body.Close() - total, err := strconv.Atoi(resp.Header.Get("content-length")) + total, err := strconv.ParseInt(resp.Header.Get("content-length"), 10, 64) if err != nil { t.Fatal(err) } @@ -105,11 +105,11 @@ func TestFromReader(t *testing.T) { if _, err := pb.FromReader(resp.Body, io.Discard); err != nil { t.Fatal(err) } - pb.Done() + pb.Wait() } func TestSetTemplate(t *testing.T) { - pb := &ProgressBar{} + pb := &ProgressBar[int]{} if err := pb.SetTemplate(`{{.Done}}`); err != nil { t.Error(err) } From 3f9deedaa6ca80f6310edcfa39f9c9f7a51bab09 Mon Sep 17 00:00:00 2001 From: sunshineplan Date: Thu, 25 Sep 2025 15:51:24 +0800 Subject: [PATCH 02/40] choice --- choice/choice.go | 56 ++++++++++++++++++++----------------------- choice/choice_test.go | 2 +- 2 files changed, 27 insertions(+), 31 deletions(-) diff --git a/choice/choice.go b/choice/choice.go index 6833148..004ddff 100644 --- a/choice/choice.go +++ b/choice/choice.go @@ -29,17 +29,13 @@ func Menu[E any](choices []E, showQuit bool) string { if len(choices) == 0 { return "" } - var digit int - for n := len(choices); n != 0; digit++ { - n /= 10 - } - option := fmt.Sprintf("%%%dd", digit) + digit := len(strconv.Itoa(len(choices))) var b strings.Builder for i, choice := range choices { - fmt.Fprintf(&b, "%s. %s\n", fmt.Sprintf(option, i+1), choiceStr(choice)) + fmt.Fprintf(&b, "%*d. %s\n", digit, i+1, choiceStr(choice)) } if showQuit { - fmt.Fprintf(&b, "%s. Quit\n", fmt.Sprintf(fmt.Sprintf("%%%ds", digit), "q")) + fmt.Fprintf(&b, "%*d. Quit\n", digit, 0) } return b.String() } @@ -51,13 +47,8 @@ var _ error = choiceError("") type choiceError string -func (err choiceError) Error() string { - return "bad choice: " + string(err) -} - -func (choiceError) Unwrap() error { - return ErrBadChoice -} +func (err choiceError) Error() string { return "bad choice: " + string(err) } +func (choiceError) Unwrap() error { return ErrBadChoice } func choose[E any](choice string, choices []E) (res E, err error) { n, err := strconv.Atoi(choice) @@ -92,26 +83,31 @@ func ChooseWithDefault[E any](choices []E, def int) (choice bool, res E, err err } else { prompt = "Please choose: " } - scanner := bufio.NewScanner(os.Stdin) - var b []byte - if def <= 0 { - for len(b) == 0 { - fmt.Print(prompt) - scanner.Scan() - b = bytes.TrimSpace(scanner.Bytes()) - } - } else { - fmt.Print(prompt) - scanner.Scan() - b = bytes.TrimSpace(scanner.Bytes()) - if len(b) == 0 { - return true, choices[def-1], nil - } + b, err := readLine(bufio.NewScanner(os.Stdin), prompt, def <= 0) + if err != nil { + return + } + if len(b) == 0 && def > 0 { + return true, choices[def-1], nil } - if bytes.EqualFold(b, []byte("q")) { + if bytes.EqualFold(b, []byte("0")) || bytes.EqualFold(b, []byte("q")) { return } choice = true res, err = choose(string(b), choices) return } + +func readLine(scanner *bufio.Scanner, prompt string, required bool) ([]byte, error) { + for { + fmt.Print(prompt) + if !scanner.Scan() { + return nil, scanner.Err() + } + b := bytes.TrimSpace(scanner.Bytes()) + if required && len(b) == 0 { + continue + } + return b, nil + } +} diff --git a/choice/choice_test.go b/choice/choice_test.go index b5fd7a2..206b1d8 100644 --- a/choice/choice_test.go +++ b/choice/choice_test.go @@ -30,7 +30,7 @@ func TestMenu(t *testing.T) { 8. hh 9. ii 10. jj - q. Quit + 0. Quit ` if s := Menu(choices, true); s != expect { t.Errorf("expected %q; got %q", expect, s) From e5cce6eb81846df21ab5ae26d152a0ff0b003e8a Mon Sep 17 00:00:00 2001 From: sunshineplan Date: Thu, 25 Sep 2025 16:54:14 +0800 Subject: [PATCH 03/40] clock --- clock/clock.go | 38 +++++++++++++++++--------------------- clock/clock_test.go | 2 +- 2 files changed, 18 insertions(+), 22 deletions(-) diff --git a/clock/clock.go b/clock/clock.go index 72effce..e92ca7c 100644 --- a/clock/clock.go +++ b/clock/clock.go @@ -25,23 +25,22 @@ type Clock struct { wall unique.Handle[uint64] } -func wall[Int int | float64 | uint64](wall Int) unique.Handle[uint64] { - for wall < 0 { - wall += secondsPerDay +func wall(i int64) unique.Handle[uint64] { + if i %= secondsPerDay; i < 0 { + i += secondsPerDay } - return unique.Make(uint64(int64(wall) % secondsPerDay)) + return unique.Make(uint64(i)) } func New(hour, min, sec int) Clock { - return Clock{wall(hour*secondsPerHour + min*secondsPerMinute + sec)} + return Clock{wall(int64(hour)*secondsPerHour + int64(min)*secondsPerMinute + int64(sec))} } var clockLayout = []string{ + "3:04PM", // [time.Kitchen] + "3:04:05PM", "15:04", - time.TimeOnly, - time.Kitchen, - "03:04PM", - "03:04:05PM", + "15:04:05", // [time.TimeOnly] } func Parse(v string) (Clock, error) { @@ -82,8 +81,8 @@ func (c Clock) Clock() (hour, min, sec int) { return } -func (c Clock) Seconds() int64 { - return int64(c.wall.Value()) +func (c Clock) Seconds() uint64 { + return c.wall.Value() } func (c Clock) Hour() int { @@ -100,14 +99,13 @@ func (c Clock) Second() int { func (c Clock) String() string { if c.wall == w0 { - return "" + return "invalid" } return fmt.Sprintf("%d:%02d:%02d", c.Hour(), c.Minute(), c.Second()) } -func (c Clock) MarshalText() (text []byte, err error) { - text = []byte(c.String()) - return +func (c Clock) MarshalText() ([]byte, error) { + return []byte(c.String()), nil } func (c *Clock) UnmarshalText(text []byte) error { @@ -140,11 +138,11 @@ func (c Clock) Compare(u Clock) int { } func (c Clock) Add(d time.Duration) Clock { - return Clock{wall(float64(c.wall.Value()) + d.Seconds())} + return Clock{wall(int64(c.wall.Value()) + int64(d.Seconds()))} } func (c Clock) Sub(u Clock) time.Duration { - return time.Duration(c.Seconds()-u.Seconds()) * time.Second + return time.Duration(int64(c.Seconds())-int64(u.Seconds())) * time.Second } func (c Clock) Since(u Clock) time.Duration { @@ -152,10 +150,8 @@ func (c Clock) Since(u Clock) time.Duration { } func (c Clock) Until(u Clock) time.Duration { - d := u.Seconds() - c.Seconds() - if d == 0 { - return 0 - } else if d < 0 { + d := int64(u.Seconds()) - int64(c.Seconds()) + if d < 0 { d += secondsPerDay } return time.Duration(d) * time.Second diff --git a/clock/clock_test.go b/clock/clock_test.go index d85b99a..cb1e67e 100644 --- a/clock/clock_test.go +++ b/clock/clock_test.go @@ -61,7 +61,7 @@ func TestParse(t *testing.T) { func TestSeconds(t *testing.T) { for i, testcase := range []struct { c Clock - expected int64 + expected uint64 }{ {New(0, 0, 0), 0}, {New(7, 1, 2), 7*secondsPerHour + 1*secondsPerMinute + 2}, From ab95d248c0e9037679d074c4cfa415f17060f2ed Mon Sep 17 00:00:00 2001 From: sunshineplan Date: Thu, 25 Sep 2025 17:01:13 +0800 Subject: [PATCH 04/40] clock comment --- clock/clock.go | 28 ++++++++++++++++++++++++++++ 1 file changed, 28 insertions(+) diff --git a/clock/clock.go b/clock/clock.go index e92ca7c..ab5d88f 100644 --- a/clock/clock.go +++ b/clock/clock.go @@ -21,6 +21,7 @@ var ( var w0 unique.Handle[uint64] +// Clock represents a time within a day (hour, minute, second) type Clock struct { wall unique.Handle[uint64] } @@ -32,6 +33,8 @@ func wall(i int64) unique.Handle[uint64] { return unique.Make(uint64(i)) } +// New constructs a Clock from hour, minute, second. +// Values can exceed normal ranges; they are normalized modulo 24 hours. func New(hour, min, sec int) Clock { return Clock{wall(int64(hour)*secondsPerHour + int64(min)*secondsPerMinute + int64(sec))} } @@ -43,6 +46,7 @@ var clockLayout = []string{ "15:04:05", // [time.TimeOnly] } +// Parse parses a string into a Clock using supported layouts func Parse(v string) (Clock, error) { for _, layout := range clockLayout { if t, err := time.Parse(layout, v); err == nil { @@ -52,6 +56,7 @@ func Parse(v string) (Clock, error) { return Clock{}, fmt.Errorf("cannot parse %q as clock", v) } +// MustParse parses a string and panics if parsing fails func MustParse(v string) Clock { if c, err := Parse(v); err != nil { panic("clock: " + err.Error()) @@ -60,18 +65,22 @@ func MustParse(v string) Clock { } } +// ParseTime converts a time.Time into a Clock (hour, minute, second) func ParseTime(t time.Time) Clock { return New(t.Clock()) } +// Now returns the current local time as a Clock func Now() Clock { return ParseTime(time.Now()) } +// Time converts the Clock into a time.Time on the Unix epoch date func (c Clock) Time() time.Time { return time.Unix(int64(c.wall.Value()), 0).UTC() } +// Clock returns hour, minute, second components of the Clock func (c Clock) Clock() (hour, min, sec int) { sec = int(c.wall.Value()) hour = sec / secondsPerHour @@ -81,22 +90,27 @@ func (c Clock) Clock() (hour, min, sec int) { return } +// Seconds returns the number of seconds func (c Clock) Seconds() uint64 { return c.wall.Value() } +// Hour returns the hour of the Clock func (c Clock) Hour() int { return int(c.Seconds()%secondsPerDay) / secondsPerHour } +// Minute returns the minute of the Clock func (c Clock) Minute() int { return int(c.Seconds()%secondsPerHour) / secondsPerMinute } +// Second returns the second of the Clock func (c Clock) Second() int { return int(c.Seconds() % secondsPerMinute) } +// String returns the Clock as a string "H:MM:SS", or "invalid" if zero func (c Clock) String() string { if c.wall == w0 { return "invalid" @@ -104,10 +118,12 @@ func (c Clock) String() string { return fmt.Sprintf("%d:%02d:%02d", c.Hour(), c.Minute(), c.Second()) } +// MarshalText implements encoding.TextMarshaler func (c Clock) MarshalText() ([]byte, error) { return []byte(c.String()), nil } +// UnmarshalText implements encoding.TextUnmarshaler func (c *Clock) UnmarshalText(text []byte) error { clock, err := Parse(string(text)) if err != nil { @@ -117,38 +133,48 @@ func (c *Clock) UnmarshalText(text []byte) error { return nil } +// IsValid returns true if the Clock is not the zero/invalid value func (c Clock) IsValid() bool { return c.wall != w0 } +// After returns true if c is after u func (c Clock) After(u Clock) bool { return c.wall.Value() > u.wall.Value() } +// Before returns true if c is before u func (c Clock) Before(u Clock) bool { return c.wall.Value() < u.wall.Value() } +// Equal returns true if c equals u func (c Clock) Equal(u Clock) bool { return c.wall == u.wall } +// Compare compares c with u and returns -1,0,1 func (c Clock) Compare(u Clock) int { return cmp.Compare(c.wall.Value(), u.wall.Value()) } +// Add adds a duration to the Clock and returns a new Clock +// Duration is truncated to seconds func (c Clock) Add(d time.Duration) Clock { return Clock{wall(int64(c.wall.Value()) + int64(d.Seconds()))} } +// Sub returns the duration between c and u (c - u) func (c Clock) Sub(u Clock) time.Duration { return time.Duration(int64(c.Seconds())-int64(u.Seconds())) * time.Second } +// Since returns the duration from u until c (c - u, wrapped around 24h) func (c Clock) Since(u Clock) time.Duration { return u.Until(c) } +// Until returns the duration from c until u (always non-negative, wraps around 24h) func (c Clock) Until(u Clock) time.Duration { d := int64(u.Seconds()) - int64(c.Seconds()) if d < 0 { @@ -157,10 +183,12 @@ func (c Clock) Until(u Clock) time.Duration { return time.Duration(d) * time.Second } +// Since returns duration from u until now func Since(c Clock) time.Duration { return c.Until(Now()) } +// Until returns duration from now until c func Until(c Clock) time.Duration { return Now().Until(c) } From 378ab709659a3045c8dba8502999db93edeaf035 Mon Sep 17 00:00:00 2001 From: sunshineplan Date: Fri, 26 Sep 2025 10:52:34 +0800 Subject: [PATCH 05/40] confirm --- confirm/confirm.go | 24 ++++++++++---- confirm/confirm_test.go | 71 +++++++++++++++++++++++++++++++++++++++++ 2 files changed, 89 insertions(+), 6 deletions(-) create mode 100644 confirm/confirm_test.go diff --git a/confirm/confirm.go b/confirm/confirm.go index 6e05b20..169fdad 100644 --- a/confirm/confirm.go +++ b/confirm/confirm.go @@ -1,12 +1,19 @@ package confirm import ( + "bufio" "fmt" + "io" + "os" "strings" ) // Do asks the user for confirmation. func Do(prompt string, attempts int) bool { + return do(prompt, attempts, os.Stdout, os.Stdin) +} + +func do(prompt string, attempts int, w io.Writer, r io.Reader) bool { if prompt == "" { prompt = "Are you sure?" } @@ -14,11 +21,16 @@ func Do(prompt string, attempts int) bool { attempts = 3 } - fmt.Print(prompt, " (yes/no): ") - var input string + if _, err := fmt.Fprintf(w, "%s (yes/no): ", prompt); err != nil { + fmt.Println("Error writing to output:", err) + return false + } + br := bufio.NewReader(r) for ; attempts > 0; attempts-- { - if _, err := fmt.Scanln(&input); err != nil { - fmt.Println(err) + input, err := br.ReadString('\n') + if err != nil { + fmt.Fprintln(w, "Error reading input:", err) + continue } switch strings.ToLower(strings.TrimSpace(input)) { case "y", "yes": @@ -27,10 +39,10 @@ func Do(prompt string, attempts int) bool { return false default: if attempts > 1 { - fmt.Print("Please type 'yes' or 'no': ") + fmt.Fprint(w, "Please type 'yes' or 'no': ") } } } - fmt.Println("Max retries exceeded.") + fmt.Fprintln(w, "Max retries exceeded.") return false } diff --git a/confirm/confirm_test.go b/confirm/confirm_test.go new file mode 100644 index 0000000..7365944 --- /dev/null +++ b/confirm/confirm_test.go @@ -0,0 +1,71 @@ +package confirm + +import ( + "bytes" + "strings" + "testing" +) + +func TestDo_YesResponses(t *testing.T) { + tests := []string{"y\n", "Y\n", "yes\n", " YES \n"} + for _, input := range tests { + t.Run(strings.TrimSpace(input), func(t *testing.T) { + in := strings.NewReader(input) + var out bytes.Buffer + got := do("Confirm?", 1, &out, in) + if !got { + t.Errorf("expected true for input %q, got false", input) + } + }) + } +} + +func TestDo_NoResponses(t *testing.T) { + tests := []string{"n\n", "N\n", "no\n", " NO \n"} + for _, input := range tests { + t.Run(strings.TrimSpace(input), func(t *testing.T) { + in := strings.NewReader(input) + var out bytes.Buffer + got := do("Confirm?", 1, &out, in) + if got { + t.Errorf("expected false for input %q, got true", input) + } + }) + } +} + +func TestDo_InvalidThenYes(t *testing.T) { + in := strings.NewReader("maybe\nYES\n") + var out bytes.Buffer + got := do("Confirm?", 2, &out, in) + if !got { + t.Errorf("expected true after invalid input then yes, got false") + } + if !strings.Contains(out.String(), "Please type 'yes' or 'no':") { + t.Errorf("expected retry prompt, got %q", out.String()) + } +} + +func TestDo_MaxRetriesExceeded(t *testing.T) { + in := strings.NewReader("maybe\nidk\nnope\n") + var out bytes.Buffer + got := do("Confirm?", 3, &out, in) + if got { + t.Errorf("expected false after max retries, got true") + } + if !strings.Contains(out.String(), "Max retries exceeded.") { + t.Errorf("expected max retries message, got %q", out.String()) + } +} + +func TestDo_DefaultPromptAndAttempts(t *testing.T) { + in := strings.NewReader("y\n") + var out bytes.Buffer + got := do("", 0, &out, in) + if !got { + t.Errorf("expected true with default prompt, got false") + } + if !strings.Contains(out.String(), "Are you sure? (yes/no):") { + t.Errorf("expected default prompt, got %q", out.String()) + } +} From e482c4b0db1e1176f5cc73416f7a0c997d788ba8 Mon Sep 17 00:00:00 2001 From: sunshineplan Date: Fri, 26 Sep 2025 11:04:59 +0800 Subject: [PATCH 06/40] Update confirm.go --- confirm/confirm.go | 13 ++++++++----- 1 file changed, 8 insertions(+), 5 deletions(-) diff --git a/confirm/confirm.go b/confirm/confirm.go index 169fdad..a870127 100644 --- a/confirm/confirm.go +++ b/confirm/confirm.go @@ -25,14 +25,17 @@ func do(prompt string, attempts int, w io.Writer, r io.Reader) bool { fmt.Println("Error writing to output:", err) return false } - br := bufio.NewReader(r) + scanner := bufio.NewScanner(r) for ; attempts > 0; attempts-- { - input, err := br.ReadString('\n') - if err != nil { - fmt.Fprintln(w, "Error reading input:", err) + if !scanner.Scan() { + if err := scanner.Err(); err != nil { + fmt.Fprintln(w, "Error reading input:", err) + } else { + fmt.Fprintln(w, "No input received (EOF).") + } continue } - switch strings.ToLower(strings.TrimSpace(input)) { + switch strings.ToLower(strings.TrimSpace(scanner.Text())) { case "y", "yes": return true case "n", "no": From 4a10d0e1603e6afedd5ec06af754615b69547281 Mon Sep 17 00:00:00 2001 From: sunshineplan Date: Fri, 26 Sep 2025 11:14:33 +0800 Subject: [PATCH 07/40] Update choice --- choice/choice.go | 7 ++- choice/choice_test.go | 101 ++++++++++++++++++++++++++++++++++++++++++ 2 files changed, 107 insertions(+), 1 deletion(-) diff --git a/choice/choice.go b/choice/choice.go index 004ddff..4f426af 100644 --- a/choice/choice.go +++ b/choice/choice.go @@ -5,6 +5,7 @@ import ( "bytes" "errors" "fmt" + "io" "os" "strconv" "strings" @@ -70,6 +71,10 @@ func Choose[E any](choices []E) (choice bool, res E, err error) { // ChooseWithDefault function allows the user to make a choice from the given options with an optional default value. func ChooseWithDefault[E any](choices []E, def int) (choice bool, res E, err error) { + return chooseWithDefault(os.Stdin, choices, def) +} + +func chooseWithDefault[E any](r io.Reader, choices []E, def int) (choice bool, res E, err error) { if n := len(choices); n == 0 { err = errors.New("no choices") return @@ -83,7 +88,7 @@ func ChooseWithDefault[E any](choices []E, def int) (choice bool, res E, err err } else { prompt = "Please choose: " } - b, err := readLine(bufio.NewScanner(os.Stdin), prompt, def <= 0) + b, err := readLine(bufio.NewScanner(r), prompt, def <= 0) if err != nil { return } diff --git a/choice/choice_test.go b/choice/choice_test.go index 206b1d8..e06a51a 100644 --- a/choice/choice_test.go +++ b/choice/choice_test.go @@ -54,3 +54,104 @@ func TestError(t *testing.T) { t.Error("expected err is ErrBadChoice; got not") } } + +func TestChooseWithDefault(t *testing.T) { + tests := []struct { + name string + input string + choices []string + def int + wantChoice bool + wantRes string + wantErr bool + }{ + { + name: "valid choice", + input: "2\n", + choices: []string{"a", "b", "c"}, + def: 0, + wantChoice: true, + wantRes: "b", + wantErr: false, + }, + { + name: "default used when empty input", + input: "\n", + choices: []string{"a", "b", "c"}, + def: 2, + wantChoice: true, + wantRes: "b", + wantErr: false, + }, + { + name: "quit with 0", + input: "0\n", + choices: []string{"a", "b"}, + def: 0, + wantChoice: false, + wantRes: "", + wantErr: false, + }, + { + name: "quit with q", + input: "q\n", + choices: []string{"a", "b"}, + def: 0, + wantChoice: false, + wantRes: "", + wantErr: false, + }, + { + name: "invalid input", + input: "x\n", + choices: []string{"a", "b"}, + def: 0, + wantChoice: true, + wantRes: "", + wantErr: true, + }, + { + name: "out of range", + input: "5\n", + choices: []string{"a", "b"}, + def: 0, + wantChoice: true, + wantRes: "", + wantErr: true, + }, + { + name: "no choices", + input: "1\n", + choices: []string{}, + def: 0, + wantChoice: false, + wantRes: "", + wantErr: true, + }, + { + name: "invalid default", + input: "1\n", + choices: []string{"a"}, + def: 5, + wantChoice: false, + wantRes: "", + wantErr: true, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + r := strings.NewReader(tt.input) + choice, res, err := chooseWithDefault(r, tt.choices, tt.def) + if choice != tt.wantChoice { + t.Errorf("got choice=%v, want %v", choice, tt.wantChoice) + } + if res != tt.wantRes { + t.Errorf("got res=%q, want %q", res, tt.wantRes) + } + if (err != nil) != tt.wantErr { + t.Errorf("got err=%v, wantErr=%v", err, tt.wantErr) + } + }) + } +} From 8ef243340144c60d9696ccbc263dbc002fc1d17b Mon Sep 17 00:00:00 2001 From: sunshineplan Date: Sun, 28 Sep 2025 17:05:31 +0800 Subject: [PATCH 08/40] ring and loadbalance --- container/ring.go | 268 ++++++++++++++++++++------------- container/ring_test.go | 47 +++--- loadbalance/loadbalance.go | 5 +- loadbalance/random.go | 8 +- loadbalance/roundrobin.go | 86 +++++------ loadbalance/roundrobin_test.go | 21 ++- 6 files changed, 252 insertions(+), 183 deletions(-) diff --git a/container/ring.go b/container/ring.go index 4b29b40..f163fe6 100644 --- a/container/ring.go +++ b/container/ring.go @@ -1,57 +1,6 @@ package container -import ( - "container/ring" - "sync" -) - -var ringMutex sync.RWMutex - -type mutex ring.Ring - -func newMutex() *mutex { - r := ring.New(1) - r.Value = new(sync.RWMutex) - return (*mutex)(r) -} - -func (mu *mutex) Lock() { - ringMutex.RLock() - defer ringMutex.RUnlock() - (*ring.Ring)(mu).Do(func(a any) { - a.(*sync.RWMutex).Lock() - }) -} - -func (mu *mutex) Unlock() { - ringMutex.RLock() - defer ringMutex.RUnlock() - (*ring.Ring)(mu).Do(func(a any) { - a.(*sync.RWMutex).Unlock() - }) -} - -func (mu *mutex) RLock() { - ringMutex.RLock() - defer ringMutex.RUnlock() - (*ring.Ring)(mu).Do(func(a any) { - a.(*sync.RWMutex).RLock() - }) -} - -func (mu *mutex) RUnlock() { - ringMutex.RLock() - defer ringMutex.RUnlock() - (*ring.Ring)(mu).Do(func(a any) { - a.(*sync.RWMutex).RUnlock() - }) -} - -func (mu *mutex) Link(s *mutex) *mutex { - ringMutex.Lock() - defer ringMutex.Unlock() - return (*mutex)((*ring.Ring)(mu).Link((*ring.Ring)(s))) -} +import "sync" // A Ring is an element of a circular list, or ring. // Rings do not have a beginning or end; a pointer to any ring element @@ -59,30 +8,69 @@ func (mu *mutex) Link(s *mutex) *mutex { // as nil Ring pointers. The zero value for a Ring is a one-element // ring with a nil Value. type Ring[T any] struct { - mu *mutex - r *ring.Ring + mu sync.RWMutex + ringMu *sync.RWMutex + + next, prev *Ring[T] + value T // for use by client; untouched by this library +} + +func (r *Ring[T]) init() *Ring[T] { + r.ringMu = new(sync.RWMutex) + r.next = r + r.prev = r + return r } // Next returns the next ring element. r must not be empty. func (r *Ring[T]) Next() *Ring[T] { - r.mu.Lock() - defer r.mu.Unlock() - return &Ring[T]{r.mu, r.r.Next()} + r.mu.RLock() + defer r.mu.RUnlock() + if r.ringMu == nil { + return r.init() + } + r.ringMu.RLock() + defer r.ringMu.RUnlock() + return r.next } // Prev returns the previous ring element. r must not be empty. func (r *Ring[T]) Prev() *Ring[T] { - r.mu.Lock() - defer r.mu.Unlock() - return &Ring[T]{r.mu, r.r.Prev()} + r.mu.RLock() + defer r.mu.RUnlock() + if r.ringMu == nil { + return r.init() + } + r.ringMu.RLock() + defer r.ringMu.RUnlock() + return r.prev +} + +func (r *Ring[T]) move(n int) *Ring[T] { + switch { + case n < 0: + for ; n < 0; n++ { + r = r.prev + } + case n > 0: + for ; n > 0; n-- { + r = r.next + } + } + return r } // Move moves n % r.Len() elements backward (n < 0) or forward (n >= 0) // in the ring and returns that ring element. r must not be empty. func (r *Ring[T]) Move(n int) *Ring[T] { - r.mu.Lock() - defer r.mu.Unlock() - return &Ring[T]{newMutex(), r.r.Move(n)} + r.mu.RLock() + defer r.mu.RUnlock() + if r.ringMu == nil { + return r.init() + } + r.ringMu.Lock() + defer r.ringMu.Unlock() + return r.move(n) } // NewRing creates a ring of n elements. @@ -90,20 +78,61 @@ func NewRing[T any](n int) *Ring[T] { if n <= 0 { return nil } - return &Ring[T]{newMutex(), ring.New(n)} -} - -func (r *Ring[T]) Set(v T) { - r.mu.Lock() - defer r.mu.Unlock() - r.r.Value = &v -} - -func (r *Ring[T]) Value() *T { - r.mu.RLock() - defer r.mu.RUnlock() - v, _ := r.r.Value.(*T) - return v + r := &Ring[T]{ringMu: new(sync.RWMutex)} + p := r + for i := 1; i < n; i++ { + p.next = &Ring[T]{ringMu: p.ringMu, prev: p} + p = p.next + } + p.next = r + r.prev = p + return r +} + +func linkLock[T any](r, s *Ring[T]) (unlock func()) { + r.ringMu.Lock() + if s.ringMu == r.ringMu { + unlock = r.ringMu.Unlock + } else { + s.mu.Lock() + if s.ringMu == nil { + s.init() + unlock = func() { + s.mu.Unlock() + r.ringMu.Unlock() + } + } else { + s.ringMu.Lock() + unlock = func() { + s.ringMu.Unlock() + s.mu.Unlock() + r.ringMu.Unlock() + } + } + } + return +} + +func (r *Ring[T]) link(s *Ring[T]) *Ring[T] { + n := r.next + p := s.prev + // Note: Cannot use multiple assignment because + // evaluation order of LHS is not specified. + r.next = s + s.prev = r + n.prev = p + p.next = n + if s.ringMu == r.ringMu { + n.ringMu = new(sync.RWMutex) + for p := n.next; p != n; p = p.next { + p.ringMu = n.ringMu + } + } else { + for p := s.next; p != s; p = p.next { + p.ringMu = r.ringMu + } + } + return n } // Link connects ring r with ring s such that r.Next() @@ -122,52 +151,87 @@ func (r *Ring[T]) Value() *T { // after r. The result points to the element following the // last element of s after insertion. func (r *Ring[T]) Link(s *Ring[T]) *Ring[T] { + r.mu.RLock() + defer r.mu.RUnlock() + if r.ringMu == nil { + r.init() + } + n := r.next if s == nil { - return r.Next() + return n } - m := r.mu.Link(s.mu) - r.mu.Lock() - defer r.mu.Unlock() - return &Ring[T]{m, r.r.Link(s.r)} + unlock := linkLock(r, s) + defer unlock() + return r.link(s) } // Unlink removes n % r.Len() elements from the ring r, starting // at r.Next(). If n % r.Len() == 0, r remains unchanged. // The result is the removed subring. r must not be empty. func (r *Ring[T]) Unlink(n int) *Ring[T] { - r.mu.Lock() - defer r.mu.Unlock() - u := r.r.Unlink(n) - if u == nil { + if n <= 0 { return nil } - return &Ring[T]{newMutex(), u} + r.mu.RLock() + defer r.mu.RUnlock() + if r.ringMu == nil { + return r.init() + } + r.ringMu.Lock() + defer r.ringMu.Unlock() + return r.link(r.move(n + 1)) } // Len computes the number of elements in ring r. // It executes in time proportional to the number of elements. func (r *Ring[T]) Len() int { - if r == nil { - return 0 + n := 0 + if r != nil { + r.mu.RLock() + defer r.mu.RUnlock() + if r.ringMu == nil { + r.init() + return 1 + } + r.ringMu.RLock() + defer r.ringMu.RUnlock() + n = 1 + for p := r.next; p != r; p = p.next { + n++ + } } - r.mu.RLock() - defer r.mu.RUnlock() - return r.r.Len() + return n } // Do calls function f on each element of the ring, in forward order. // The behavior of Do is undefined if f changes *r. -func (r *Ring[T]) Do(f func(*T)) { - if r == nil { - return +func (r *Ring[T]) Do(f func(T)) { + if r != nil { + r.mu.RLock() + if r.ringMu == nil { + r.mu.RUnlock() + return + } + r.ringMu.RLock() + defer r.ringMu.RUnlock() + f(r.value) + r.mu.RUnlock() + for p := r.next; p != r; p = p.next { + p.mu.RLock() + f(p.value) + p.mu.RUnlock() + } } +} + +func (r *Ring[T]) Set(v T) { + r.mu.Lock() + defer r.mu.Unlock() + r.value = v +} + +func (r *Ring[T]) Value() T { r.mu.RLock() defer r.mu.RUnlock() - r.r.Do(func(a any) { - if v, ok := a.(*T); ok { - f(v) - } else { - f(nil) - } - }) + return r.value } diff --git a/container/ring_test.go b/container/ring_test.go index 5d9217c..cfc5d88 100644 --- a/container/ring_test.go +++ b/container/ring_test.go @@ -5,13 +5,12 @@ package container import ( - "container/ring" "fmt" "testing" ) // For debugging - keep around. -func dump(r *ring.Ring) { +func dump[T any](r *Ring[T]) { if r == nil { fmt.Println("empty") return @@ -24,7 +23,7 @@ func dump(r *ring.Ring) { fmt.Println() } -func verify[T int](t *testing.T, r *Ring[T], N int, sum int) { +func verify(t *testing.T, r *Ring[int], N int, sum int) { // Len n := r.Len() if n != N { @@ -34,11 +33,9 @@ func verify[T int](t *testing.T, r *Ring[T], N int, sum int) { // iteration n = 0 s := 0 - r.Do(func(p *T) { + r.Do(func(p int) { n++ - if p != nil { - s += int(*p) - } + s += p }) if n != N { t.Errorf("number of forward iterations == %d; expected %d", n, N) @@ -47,41 +44,41 @@ func verify[T int](t *testing.T, r *Ring[T], N int, sum int) { t.Errorf("forward ring sum = %d; expected %d", s, sum) } - if r == nil || r.r == nil { + if r == nil { return } // connections - if r.Next().r != nil { - var p *ring.Ring // previous element - for q := r.r; p == nil || q != r.r; q = q.Next() { + if r.Next() != nil { + var p *Ring[int] // previous element + for q := r; p == nil || q != r; q = q.next { if p != nil && p != q.Prev() { t.Errorf("prev = %p, expected q.prev = %p\n", p, q.Prev()) } p = q } - if p != r.Prev().r { + if p != r.Prev() { t.Errorf("prev = %p, expected r.prev = %p\n", p, r.Prev()) } } // Move - if r.Move(0).r != r.r { + if r.Move(0) != r { t.Errorf("r.Move(0) != r") } - if r.Move(N).r != r.r { + if r.Move(N) != r { t.Errorf("r.Move(%d) != r", N) } - if r.Move(-N).r != r.r { + if r.Move(-N) != r { t.Errorf("r.Move(%d) != r", -N) } for i := 0; i < 10; i++ { ni := N + i mi := ni % N - if r.Move(ni).r != r.Move(mi).r { + if r.Move(ni) != r.Move(mi) { t.Errorf("r.Move(%d) != r.Move(%d)", ni, mi) } - if r.Move(-ni).r != r.Move(-mi).r { + if r.Move(-ni) != r.Move(-mi) { t.Errorf("r.Move(%d) != r.Move(%d)", -ni, -mi) } } @@ -89,8 +86,8 @@ func verify[T int](t *testing.T, r *Ring[T], N int, sum int) { func TestCornerCases(t *testing.T) { var ( - r0 = &Ring[int]{newMutex(), nil} - r1 = Ring[int]{newMutex(), new(ring.Ring)} + r0 *Ring[int] + r1 Ring[int] ) // Basics verify(t, r0, 0, 0) @@ -132,16 +129,16 @@ func TestNew(t *testing.T) { func TestLink1(t *testing.T) { r1a := makeN(1) - var r1b = Ring[int]{newMutex(), &ring.Ring{}} + var r1b Ring[int] r2a := r1a.Link(&r1b) verify(t, r2a, 2, 1) - if r2a.r != r1a.r { + if r2a != r1a { t.Errorf("a) 2-element link failed") } r2b := r2a.Link(r2a.Next()) verify(t, r2b, 2, 1) - if r2b.r != r2a.Next().r { + if r2b != r2a.Next() { t.Errorf("b) 2-element link failed") } @@ -151,7 +148,7 @@ func TestLink1(t *testing.T) { } func TestLink2(t *testing.T) { - var r0 = &Ring[int]{newMutex(), nil} + var r0 *Ring[int] r1a := NewRing[int](1) r1a.Set(42) r1b := NewRing[int](1) @@ -172,7 +169,7 @@ func TestLink2(t *testing.T) { } func TestLink3(t *testing.T) { - var r = Ring[int]{newMutex(), new(ring.Ring)} + var r Ring[int] n := 1 for i := 1; i < 10; i++ { n += i @@ -216,7 +213,7 @@ func TestLinkUnlink(t *testing.T) { // Test that calling Move() on an empty Ring initializes it. func TestMoveEmptyRing(t *testing.T) { - var r = Ring[int]{newMutex(), &ring.Ring{}} + var r Ring[int] r.Move(1) verify(t, &r, 1, 0) diff --git a/loadbalance/loadbalance.go b/loadbalance/loadbalance.go index 2e358a2..3423ee1 100644 --- a/loadbalance/loadbalance.go +++ b/loadbalance/loadbalance.go @@ -2,19 +2,16 @@ package loadbalance import ( "errors" - "sync" "github.com/sunshineplan/utils/container" ) var ErrEmptyLoadBalancer = errors.New("empty load balancer") -var mu sync.RWMutex - type LoadBalancer[E any] interface { Len() int Next() E - Ring() *container.Ring[*E] + Ring() *container.Ring[E] Link(LoadBalancer[E]) LoadBalancer[E] Unlink(int) LoadBalancer[E] } diff --git a/loadbalance/random.go b/loadbalance/random.go index ccd03e0..5407956 100644 --- a/loadbalance/random.go +++ b/loadbalance/random.go @@ -17,8 +17,8 @@ func WeightedRandom[E any](items ...Weighted[E]) LoadBalancer[E] { } func (r *random[E]) Next() E { - mu.RLock() - defer mu.RUnlock() - r.roundrobin = (*roundrobin[E])(r.Ring().Move(rand.IntN(r.Ring().Len()))) - return **r.Ring().Value() + r.RLock() + defer r.RUnlock() + r.roundrobin.ring = r.ring.Move(rand.IntN(r.Len())) + return r.Ring().Value() } diff --git a/loadbalance/roundrobin.go b/loadbalance/roundrobin.go index c30eb3e..b4b4b79 100644 --- a/loadbalance/roundrobin.go +++ b/loadbalance/roundrobin.go @@ -1,54 +1,51 @@ package loadbalance -import "github.com/sunshineplan/utils/container" +import ( + "sync" + + "github.com/sunshineplan/utils/container" +) var _ LoadBalancer[any] = &roundrobin[any]{} -type roundrobin[E any] container.Ring[*E] +type roundrobin[E any] struct { + sync.RWMutex + ring *container.Ring[E] +} func newRoundRobin[E any, Items []E | []Weighted[E]](items Items) *roundrobin[E] { if len(items) == 0 { panic(ErrEmptyLoadBalancer) } - var root *roundrobin[E] + var ring *container.Ring[E] switch items := any(items).(type) { case []E: + ring = container.NewRing[E](len(items)) for _, i := range items { - r := container.NewRing[*E](1) - r.Set(&i) - if root == nil { - root = (*roundrobin[E])(r) - } else { - root.Ring().Link(r) - root = (*roundrobin[E])(root.Ring().Next()) - } - } - if root != nil { - root = (*roundrobin[E])(root.Ring().Next()) + ring.Set(i) + ring = ring.Next() } case []Weighted[E]: for _, i := range items { if i.Weight == 0 { continue } - item := &i.Item - r := container.NewRing[*E](i.Weight) - for range r.Len() { - r.Set(item) - r = r.Next() + subring := container.NewRing[E](i.Weight) + for range i.Weight { + subring.Set(i.Item) + subring = subring.Next() } - if root == nil { - root = (*roundrobin[E])(r) + if ring == nil { + ring = subring } else { - root.Ring().Link(r) - root = (*roundrobin[E])(root.Ring().Next()) + ring = ring.Link(subring).Prev() } } - if root != nil { - root = (*roundrobin[E])(root.Ring().Next()) + if ring != nil { + ring = ring.Next() } } - return root + return &roundrobin[E]{ring: ring} } func RoundRobin[E any](items ...E) LoadBalancer[E] { @@ -60,31 +57,36 @@ func WeightedRoundRobin[E any](items ...Weighted[E]) LoadBalancer[E] { } func (r *roundrobin[E]) Len() int { - mu.RLock() - defer mu.RUnlock() - return r.Ring().Len() + r.RLock() + defer r.RUnlock() + return r.ring.Len() } func (r *roundrobin[E]) Next() (next E) { - mu.Lock() - defer mu.Unlock() - v := **r.Ring().Value() - *r = *(*roundrobin[E])(r.Ring().Next()) - return v + r.Lock() + defer r.Unlock() + next = r.ring.Value() + r.ring = r.ring.Next() + return } -func (r *roundrobin[E]) Ring() *container.Ring[*E] { - return (*container.Ring[*E])(r) +func (r *roundrobin[E]) Ring() *container.Ring[E] { + r.RLock() + defer r.RUnlock() + return r.ring } func (r *roundrobin[E]) Link(s LoadBalancer[E]) LoadBalancer[E] { - mu.Lock() - defer mu.Unlock() - return (*roundrobin[E])(r.Ring().Link(s.Ring())) + sr := s.Ring() + r.Lock() + defer r.Unlock() + r.ring = r.ring.Prev().Link(sr) + return r } func (r *roundrobin[E]) Unlink(n int) LoadBalancer[E] { - mu.Lock() - defer mu.Unlock() - return (*roundrobin[E])(r.Ring().Unlink(n)) + r.Lock() + defer r.Unlock() + r.ring = r.ring.Unlink(n) + return r } diff --git a/loadbalance/roundrobin_test.go b/loadbalance/roundrobin_test.go index 5671e4c..605908c 100644 --- a/loadbalance/roundrobin_test.go +++ b/loadbalance/roundrobin_test.go @@ -7,29 +7,38 @@ import ( func TestRoundRobin(t *testing.T) { r1 := RoundRobin([]string{"a", "b", "c"}...) + if r1.Len() != 3 { + t.Fatalf("want 3, got %d", r1.Len()) + } var res []string for range 6 { res = append(res, r1.Next()) } if expect := []string{"a", "b", "c", "a", "b", "c"}; !slices.Equal(res, expect) { - t.Errorf("want %v, got %v", expect, res) + t.Fatalf("want %v, got %v", expect, res) } res = nil r2 := WeightedRoundRobin([]Weighted[string]{{"a", 2}, {"b", 1}, {"c", 1}}...) + if r2.Len() != 4 { + t.Fatalf("want 4, got %d", r2.Len()) + } for range 12 { res = append(res, r2.Next()) } if expect := []string{"a", "a", "b", "c", "a", "a", "b", "c", "a", "a", "b", "c"}; !slices.Equal(res, expect) { - t.Errorf("want %v, got %v", expect, res) + t.Fatalf("want %v, got %v", expect, res) } res = nil - r1.Link(r2) + ring := r1.Link(r2) + if ring.Len() != 7 { + t.Fatalf("want 7, got %d", ring.Len()) + } for range 7 { - res = append(res, r1.Next()) + res = append(res, ring.Next()) } - if expect := []string{"a", "a", "a", "b", "c", "b", "c"}; !slices.Equal(res, expect) { - t.Errorf("want %v, got %v", expect, res) + if expect := []string{"a", "b", "c", "a", "a", "b", "c"}; !slices.Equal(res, expect) { + t.Fatalf("want %v, got %v", expect, res) } } From 17fad2765fbd38499fb3ea22f7a6a0c732e36634 Mon Sep 17 00:00:00 2001 From: sunshineplan Date: Mon, 29 Sep 2025 13:56:54 +0800 Subject: [PATCH 09/40] Update ring --- container/ring.go | 112 +++++++++++++++++++++++++++++------------ container/ring_test.go | 112 +++++++++++++++++++++++++++++++++++++++++ 2 files changed, 192 insertions(+), 32 deletions(-) diff --git a/container/ring.go b/container/ring.go index f163fe6..a1b7375 100644 --- a/container/ring.go +++ b/container/ring.go @@ -1,6 +1,9 @@ package container -import "sync" +import ( + "sync" + "unsafe" +) // A Ring is an element of a circular list, or ring. // Rings do not have a beginning or end; a pointer to any ring element @@ -24,8 +27,8 @@ func (r *Ring[T]) init() *Ring[T] { // Next returns the next ring element. r must not be empty. func (r *Ring[T]) Next() *Ring[T] { - r.mu.RLock() - defer r.mu.RUnlock() + r.mu.Lock() + defer r.mu.Unlock() if r.ringMu == nil { return r.init() } @@ -36,8 +39,8 @@ func (r *Ring[T]) Next() *Ring[T] { // Prev returns the previous ring element. r must not be empty. func (r *Ring[T]) Prev() *Ring[T] { - r.mu.RLock() - defer r.mu.RUnlock() + r.mu.Lock() + defer r.mu.Unlock() if r.ringMu == nil { return r.init() } @@ -63,13 +66,13 @@ func (r *Ring[T]) move(n int) *Ring[T] { // Move moves n % r.Len() elements backward (n < 0) or forward (n >= 0) // in the ring and returns that ring element. r must not be empty. func (r *Ring[T]) Move(n int) *Ring[T] { - r.mu.RLock() - defer r.mu.RUnlock() + r.mu.Lock() + defer r.mu.Unlock() if r.ringMu == nil { return r.init() } - r.ringMu.Lock() - defer r.ringMu.Unlock() + r.ringMu.RLock() + defer r.ringMu.RUnlock() return r.move(n) } @@ -90,23 +93,60 @@ func NewRing[T any](n int) *Ring[T] { } func linkLock[T any](r, s *Ring[T]) (unlock func()) { - r.ringMu.Lock() - if s.ringMu == r.ringMu { - unlock = r.ringMu.Unlock + rmu, smu := r.ringMu, s.ringMu + if s == r { + r.mu.Lock() + rmu.Lock() + unlock = func() { + rmu.Unlock() + r.mu.Unlock() + } } else { - s.mu.Lock() - if s.ringMu == nil { - s.init() - unlock = func() { + order := uintptr(unsafe.Pointer(&r.mu)) < uintptr(unsafe.Pointer(&s.mu)) + var finalUnlock func() + if order { + r.mu.Lock() + s.mu.Lock() + finalUnlock = func() { s.mu.Unlock() - r.ringMu.Unlock() + r.mu.Unlock() } } else { - s.ringMu.Lock() - unlock = func() { - s.ringMu.Unlock() + s.mu.Lock() + r.mu.Lock() + finalUnlock = func() { + r.mu.Unlock() s.mu.Unlock() - r.ringMu.Unlock() + } + } + switch smu { + case nil: + s.init() + smu = rmu + fallthrough + case rmu: + smu.Lock() + unlock = func() { + smu.Unlock() + finalUnlock() + } + default: + if order { + rmu.Lock() + smu.Lock() + unlock = func() { + smu.Unlock() + rmu.Unlock() + finalUnlock() + } + } else { + smu.Lock() + rmu.Lock() + unlock = func() { + rmu.Unlock() + smu.Unlock() + finalUnlock() + } } } } @@ -114,6 +154,15 @@ func linkLock[T any](r, s *Ring[T]) (unlock func()) { } func (r *Ring[T]) link(s *Ring[T]) *Ring[T] { + var sameRing bool + if s.ringMu == r.ringMu { + sameRing = true + } else { + s.ringMu = r.ringMu + for p := s.next; p != s; p = p.next { + p.ringMu = r.ringMu + } + } n := r.next p := s.prev // Note: Cannot use multiple assignment because @@ -122,15 +171,11 @@ func (r *Ring[T]) link(s *Ring[T]) *Ring[T] { s.prev = r n.prev = p p.next = n - if s.ringMu == r.ringMu { + if sameRing { n.ringMu = new(sync.RWMutex) for p := n.next; p != n; p = p.next { p.ringMu = n.ringMu } - } else { - for p := s.next; p != s; p = p.next { - p.ringMu = r.ringMu - } } return n } @@ -151,12 +196,12 @@ func (r *Ring[T]) link(s *Ring[T]) *Ring[T] { // after r. The result points to the element following the // last element of s after insertion. func (r *Ring[T]) Link(s *Ring[T]) *Ring[T] { - r.mu.RLock() - defer r.mu.RUnlock() + r.mu.Lock() if r.ringMu == nil { r.init() } n := r.next + r.mu.Unlock() if s == nil { return n } @@ -172,8 +217,8 @@ func (r *Ring[T]) Unlink(n int) *Ring[T] { if n <= 0 { return nil } - r.mu.RLock() - defer r.mu.RUnlock() + r.mu.Lock() + defer r.mu.Unlock() if r.ringMu == nil { return r.init() } @@ -187,8 +232,8 @@ func (r *Ring[T]) Unlink(n int) *Ring[T] { func (r *Ring[T]) Len() int { n := 0 if r != nil { - r.mu.RLock() - defer r.mu.RUnlock() + r.mu.Lock() + defer r.mu.Unlock() if r.ringMu == nil { r.init() return 1 @@ -227,6 +272,9 @@ func (r *Ring[T]) Do(f func(T)) { func (r *Ring[T]) Set(v T) { r.mu.Lock() defer r.mu.Unlock() + if r.ringMu == nil { + r.init() + } r.value = v } diff --git a/container/ring_test.go b/container/ring_test.go index cfc5d88..e92b126 100644 --- a/container/ring_test.go +++ b/container/ring_test.go @@ -6,6 +6,7 @@ package container import ( "fmt" + "sync" "testing" ) @@ -218,3 +219,114 @@ func TestMoveEmptyRing(t *testing.T) { r.Move(1) verify(t, &r, 1, 0) } + +// TestLinkSharedRing tests the Link function to ensure the shared ringMu is correctly set. +func TestLinkSharedRing(t *testing.T) { + // Helper function to check if all elements in a ring share the same ringMu. + checkRingMu := func(t *testing.T, r *Ring[int], expectedMu *sync.RWMutex, name string) { + if r == nil { + t.Errorf("%s: ring is nil", name) + return + } + seen := make(map[*Ring[int]]bool) + current := r + for { + if current.ringMu != expectedMu { + t.Errorf("%s: element %p has ringMu %p, expected %p", name, current, current.ringMu, expectedMu) + } + seen[current] = true + current = current.Next() + if current == r { + break + } + if seen[current] { + t.Errorf("%s: cycle detected before reaching start", name) + break + } + } + } + + // Test 1: Link two distinct rings. + t.Run("LinkDistinctRings", func(t *testing.T) { + r1 := NewRing[int](3) + r2 := NewRing[int](2) + originalR1Mu := r1.ringMu + originalR2Mu := r2.ringMu + + // Link r1 and r2. + r1.Link(r2) + + if originalR1Mu == originalR2Mu { + t.Fatalf("initial rings have same ringMu %p", originalR1Mu) + } + + // Check that all elements in the combined ring share r1's ringMu. + checkRingMu(t, r1, originalR1Mu, "combined ring") + + // Verify ring length. + if r1.Len() != 5 { + t.Errorf("combined ring length is %d, expected 5", r1.Len()) + } + }) + + // Test 2: Link within the same ring. + t.Run("LinkSameRing", func(t *testing.T) { + r := NewRing[int](5) + originalMu := r.ringMu + + // Move to the third element. + r3 := r.Move(2) + + // Link r to r3, splitting the ring. + result := r.Link(r3) + + // Check that the original ring (r) has 2 elements and retains original ringMu. + checkRingMu(t, r, originalMu, "original ring") + + // Check that the result ring has 3 elements and a new ringMu. + if result.ringMu == originalMu { + t.Errorf("result ring has same ringMu %p as original %p", result.ringMu, originalMu) + } + + checkRingMu(t, result, result.ringMu, "result ring") + }) + + // Test 3: Link a ring to itself. + t.Run("LinkSelf", func(t *testing.T) { + r := NewRing[int](5) + originalMu := r.ringMu + + result := r.Link(r) + + // Check that the ring remains unchanged and retains original ringMu. + checkRingMu(t, r, originalMu, "self-linked ring") + if r.Len() != 1 { + t.Errorf("self-linked ring length is %d, expected 1", r.Len()) + } + + // Check that the result ring has 3 elements and a new ringMu. + if result.ringMu == originalMu { + t.Errorf("result ring has same ringMu %p as original %p", result.ringMu, originalMu) + } + + checkRingMu(t, result, result.ringMu, "result ring") + }) + + // Test 4: Link with nil. + t.Run("LinkNil", func(t *testing.T) { + r := NewRing[int](3) + originalMu := r.ringMu + originalNext := r.Next() + + result := r.Link(nil) + + // Check that the ring is unchanged and retains original ringMu. + checkRingMu(t, r, originalMu, "ring after linking nil") + if r.Len() != 3 { + t.Errorf("ring length is %d, expected 3", r.Len()) + } + if result != originalNext { + t.Errorf("Link result is %p, expected %p", result, originalNext) + } + }) +} From 36c495a4f32c7739d639a9ae592a53c020b8407f Mon Sep 17 00:00:00 2001 From: sunshineplan Date: Mon, 29 Sep 2025 14:17:34 +0800 Subject: [PATCH 10/40] ring comment --- container/ring.go | 13 ++++++++----- 1 file changed, 8 insertions(+), 5 deletions(-) diff --git a/container/ring.go b/container/ring.go index a1b7375..b91e6f9 100644 --- a/container/ring.go +++ b/container/ring.go @@ -5,17 +5,17 @@ import ( "unsafe" ) -// A Ring is an element of a circular list, or ring. +// A Ring is an element of a thread-safe, generic circular list, or ring. // Rings do not have a beginning or end; a pointer to any ring element // serves as reference to the entire ring. Empty rings are represented -// as nil Ring pointers. The zero value for a Ring is a one-element -// ring with a nil Value. +// as nil Ring pointers. The zero value for a Ring[T] is a one-element +// ring with a zero Value of type T. type Ring[T any] struct { mu sync.RWMutex ringMu *sync.RWMutex next, prev *Ring[T] - value T // for use by client; untouched by this library + value T } func (r *Ring[T]) init() *Ring[T] { @@ -269,15 +269,18 @@ func (r *Ring[T]) Do(f func(T)) { } } -func (r *Ring[T]) Set(v T) { +// Set assigns value v to the ring element and returns it. +func (r *Ring[T]) Set(v T) *Ring[T] { r.mu.Lock() defer r.mu.Unlock() if r.ringMu == nil { r.init() } r.value = v + return r } +// Value returns the value of the ring element. func (r *Ring[T]) Value() T { r.mu.RLock() defer r.mu.RUnlock() From 2a15a5efc370378384a00c1fcd7396cbc91617b1 Mon Sep 17 00:00:00 2001 From: sunshineplan Date: Mon, 29 Sep 2025 15:25:17 +0800 Subject: [PATCH 11/40] loadbalance --- loadbalance/loadbalance.go | 3 +-- loadbalance/random.go | 8 ++++---- loadbalance/roundrobin.go | 27 ++++++++++++--------------- loadbalance/roundrobin_test.go | 6 +++--- 4 files changed, 20 insertions(+), 24 deletions(-) diff --git a/loadbalance/loadbalance.go b/loadbalance/loadbalance.go index 3423ee1..40ad0a0 100644 --- a/loadbalance/loadbalance.go +++ b/loadbalance/loadbalance.go @@ -11,8 +11,7 @@ var ErrEmptyLoadBalancer = errors.New("empty load balancer") type LoadBalancer[E any] interface { Len() int Next() E - Ring() *container.Ring[E] - Link(LoadBalancer[E]) LoadBalancer[E] + Link(*container.Ring[E]) LoadBalancer[E] Unlink(int) LoadBalancer[E] } diff --git a/loadbalance/random.go b/loadbalance/random.go index 5407956..d72381a 100644 --- a/loadbalance/random.go +++ b/loadbalance/random.go @@ -17,8 +17,8 @@ func WeightedRandom[E any](items ...Weighted[E]) LoadBalancer[E] { } func (r *random[E]) Next() E { - r.RLock() - defer r.RUnlock() - r.roundrobin.ring = r.ring.Move(rand.IntN(r.Len())) - return r.Ring().Value() + r.Lock() + defer r.Unlock() + r.ring = r.ring.Move(rand.IntN(r.len)) + return r.ring.Value() } diff --git a/loadbalance/roundrobin.go b/loadbalance/roundrobin.go index b4b4b79..0812261 100644 --- a/loadbalance/roundrobin.go +++ b/loadbalance/roundrobin.go @@ -11,6 +11,7 @@ var _ LoadBalancer[any] = &roundrobin[any]{} type roundrobin[E any] struct { sync.RWMutex ring *container.Ring[E] + len int } func newRoundRobin[E any, Items []E | []Weighted[E]](items Items) *roundrobin[E] { @@ -22,8 +23,7 @@ func newRoundRobin[E any, Items []E | []Weighted[E]](items Items) *roundrobin[E] case []E: ring = container.NewRing[E](len(items)) for _, i := range items { - ring.Set(i) - ring = ring.Next() + ring = ring.Set(i).Next() } case []Weighted[E]: for _, i := range items { @@ -32,8 +32,7 @@ func newRoundRobin[E any, Items []E | []Weighted[E]](items Items) *roundrobin[E] } subring := container.NewRing[E](i.Weight) for range i.Weight { - subring.Set(i.Item) - subring = subring.Next() + subring = subring.Set(i.Item).Next() } if ring == nil { ring = subring @@ -44,8 +43,11 @@ func newRoundRobin[E any, Items []E | []Weighted[E]](items Items) *roundrobin[E] if ring != nil { ring = ring.Next() } + if ring.Len() == 0 { + panic(ErrEmptyLoadBalancer) + } } - return &roundrobin[E]{ring: ring} + return &roundrobin[E]{ring: ring, len: ring.Len()} } func RoundRobin[E any](items ...E) LoadBalancer[E] { @@ -59,7 +61,7 @@ func WeightedRoundRobin[E any](items ...Weighted[E]) LoadBalancer[E] { func (r *roundrobin[E]) Len() int { r.RLock() defer r.RUnlock() - return r.ring.Len() + return r.len } func (r *roundrobin[E]) Next() (next E) { @@ -70,17 +72,11 @@ func (r *roundrobin[E]) Next() (next E) { return } -func (r *roundrobin[E]) Ring() *container.Ring[E] { - r.RLock() - defer r.RUnlock() - return r.ring -} - -func (r *roundrobin[E]) Link(s LoadBalancer[E]) LoadBalancer[E] { - sr := s.Ring() +func (r *roundrobin[E]) Link(s *container.Ring[E]) LoadBalancer[E] { r.Lock() defer r.Unlock() - r.ring = r.ring.Prev().Link(sr) + r.ring = r.ring.Prev().Link(s) + r.len = r.ring.Len() return r } @@ -88,5 +84,6 @@ func (r *roundrobin[E]) Unlink(n int) LoadBalancer[E] { r.Lock() defer r.Unlock() r.ring = r.ring.Unlink(n) + r.len = r.ring.Len() return r } diff --git a/loadbalance/roundrobin_test.go b/loadbalance/roundrobin_test.go index 605908c..b1ecc23 100644 --- a/loadbalance/roundrobin_test.go +++ b/loadbalance/roundrobin_test.go @@ -6,7 +6,7 @@ import ( ) func TestRoundRobin(t *testing.T) { - r1 := RoundRobin([]string{"a", "b", "c"}...) + r1 := newRoundRobin[string]([]string{"a", "b", "c"}) if r1.Len() != 3 { t.Fatalf("want 3, got %d", r1.Len()) } @@ -19,7 +19,7 @@ func TestRoundRobin(t *testing.T) { } res = nil - r2 := WeightedRoundRobin([]Weighted[string]{{"a", 2}, {"b", 1}, {"c", 1}}...) + r2 := newRoundRobin[string]([]Weighted[string]{{"a", 2}, {"b", 1}, {"c", 1}}) if r2.Len() != 4 { t.Fatalf("want 4, got %d", r2.Len()) } @@ -31,7 +31,7 @@ func TestRoundRobin(t *testing.T) { } res = nil - ring := r1.Link(r2) + ring := r1.Link(r2.ring) if ring.Len() != 7 { t.Fatalf("want 7, got %d", ring.Len()) } From d4438aa6b612c74631ec0c836538b968c5051eba Mon Sep 17 00:00:00 2001 From: sunshineplan Date: Mon, 29 Sep 2025 15:54:09 +0800 Subject: [PATCH 12/40] loadbalance comment --- loadbalance/loadbalance.go | 14 +++++++-- loadbalance/random.go | 46 ++++++++++++++++++++++++++---- loadbalance/roundrobin.go | 52 +++++++++++++++++++++++++++------- loadbalance/roundrobin_test.go | 10 +++++-- 4 files changed, 102 insertions(+), 20 deletions(-) diff --git a/loadbalance/loadbalance.go b/loadbalance/loadbalance.go index 40ad0a0..1ad3521 100644 --- a/loadbalance/loadbalance.go +++ b/loadbalance/loadbalance.go @@ -6,16 +6,26 @@ import ( "github.com/sunshineplan/utils/container" ) +// ErrEmptyLoadBalancer is returned when a load balancer is initialized with no valid items. var ErrEmptyLoadBalancer = errors.New("empty load balancer") +// LoadBalancer defines an interface for load balancing algorithms. +// It provides methods to access and manipulate a circular list of elements of type E, +// ensuring thread-safe operations for concurrent use. type LoadBalancer[E any] interface { + // Len returns the number of elements in the load balancer. Len() int + // Next returns the next element according to the load balancing strategy. Next() E + // Link merges the given ring into the load balancer, inserting its elements. Link(*container.Ring[E]) LoadBalancer[E] + // Unlink removes n elements from the load balancer, starting from the next position. Unlink(int) LoadBalancer[E] } +// Weighted represents an item with an associated weight for weighted load balancing. +// The Weight field determines how many times the Item appears in the load balancer. type Weighted[E any] struct { - Item E - Weight int + Item E // The item to be balanced. + Weight int // The weight determining the item's frequency in the rotation. } diff --git a/loadbalance/random.go b/loadbalance/random.go index d72381a..7ae12e8 100644 --- a/loadbalance/random.go +++ b/loadbalance/random.go @@ -1,21 +1,55 @@ package loadbalance -import "math/rand/v2" +import ( + "math/rand/v2" + + "github.com/sunshineplan/utils/container" +) var _ LoadBalancer[any] = &random[any]{} +// random implements a thread-safe random load balancer by extending the round-robin load balancer. +// It selects elements randomly from the ring, ensuring thread-safe operations using the +// underlying round-robin mutex and ring mutexes. type random[E any] struct { - *roundrobin[E] + *roundrobin[E] // Embeds roundrobin for shared functionality. +} + +func newRandom[E any, Items []E | []Weighted[E]](items Items) (*random[E], error) { + lb, err := newRoundRobin[E](items) + if err != nil { + return nil, err + } + return &random[E]{lb}, nil +} + +// Random creates a new random load balancer with the given items. +// Each item appears once in the pool. It returns error with ErrEmptyLoadBalancer if no items are provided. +func Random[E any](items ...E) (LoadBalancer[E], error) { + return newRandom[E](items) } -func Random[E any](items ...E) LoadBalancer[E] { - return &random[E]{newRoundRobin[E](items)} +// WeightedRandom creates a new weighted random load balancer. +// Each item's weight determines how many times it appears in the pool. +// It returns error with ErrEmptyLoadBalancer if no items have positive weight. +func WeightedRandom[E any](items ...Weighted[E]) (LoadBalancer[E], error) { + return newRandom[E](items) } -func WeightedRandom[E any](items ...Weighted[E]) LoadBalancer[E] { - return &random[E]{newRoundRobin[E](items)} +// RandomFromRing creates a new random load balancer from an existing ring. +// It uses the provided ring directly, ensuring thread-safe random selection. +// It returns error with ErrEmptyLoadBalancer if the ring is nil or empty. +func RandomFromRing[E any](ring *container.Ring[E]) (LoadBalancer[E], error) { + len := ring.Len() + if len == 0 { + return nil, ErrEmptyLoadBalancer + } + return &random[E]{&roundrobin[E]{ring: ring, len: len}}, nil } +// Next returns a randomly selected element from the load balancer. +// It is thread-safe, using a write lock to update the ring position. +// If the balancer is empty, it returns the zero value of E. func (r *random[E]) Next() E { r.Lock() defer r.Unlock() diff --git a/loadbalance/roundrobin.go b/loadbalance/roundrobin.go index 0812261..a80da80 100644 --- a/loadbalance/roundrobin.go +++ b/loadbalance/roundrobin.go @@ -8,20 +8,25 @@ import ( var _ LoadBalancer[any] = &roundrobin[any]{} +// roundrobin implements a thread-safe round-robin load balancer using a circular ring. +// It supports both simple and weighted item distributions, ensuring fair rotation through +// elements. All methods are thread-safe using a struct-level mutex and the underlying ring's mutexes. type roundrobin[E any] struct { sync.RWMutex - ring *container.Ring[E] - len int + ring *container.Ring[E] // The underlying ring storing the elements. + len int // Cached length of the ring. } -func newRoundRobin[E any, Items []E | []Weighted[E]](items Items) *roundrobin[E] { +func newRoundRobin[E any, Items []E | []Weighted[E]](items Items) (*roundrobin[E], error) { if len(items) == 0 { - panic(ErrEmptyLoadBalancer) + return nil, ErrEmptyLoadBalancer } var ring *container.Ring[E] + var n int switch items := any(items).(type) { case []E: - ring = container.NewRing[E](len(items)) + n = len(items) + ring = container.NewRing[E](n) for _, i := range items { ring = ring.Set(i).Next() } @@ -43,27 +48,48 @@ func newRoundRobin[E any, Items []E | []Weighted[E]](items Items) *roundrobin[E] if ring != nil { ring = ring.Next() } - if ring.Len() == 0 { - panic(ErrEmptyLoadBalancer) + if n = ring.Len(); n == 0 { + return nil, ErrEmptyLoadBalancer } } - return &roundrobin[E]{ring: ring, len: ring.Len()} + return &roundrobin[E]{ring: ring, len: n}, nil } -func RoundRobin[E any](items ...E) LoadBalancer[E] { +// RoundRobin creates a new round-robin load balancer with the given items. +// Each item appears once in the rotation. It returns error with ErrEmptyLoadBalancer if no items are provided. +func RoundRobin[E any](items ...E) (LoadBalancer[E], error) { return newRoundRobin[E](items) } -func WeightedRoundRobin[E any](items ...Weighted[E]) LoadBalancer[E] { +// WeightedRoundRobin creates a new weighted round-robin load balancer. +// Each item's weight determines how many times it appears in the rotation. +// It returns error with ErrEmptyLoadBalancer if no items have positive weight. +func WeightedRoundRobin[E any](items ...Weighted[E]) (LoadBalancer[E], error) { return newRoundRobin[E](items) } +// RoundRobinFromRing creates a new round-robin load balancer from an existing ring. +// It uses the provided ring directly, ensuring thread-safe operations. +// It returns error with ErrEmptyLoadBalancer if the ring is nil or empty. +func RoundRobinFromRing[E any](ring *container.Ring[E]) (LoadBalancer[E], error) { + len := ring.Len() + if len == 0 { + return nil, ErrEmptyLoadBalancer + } + return &roundrobin[E]{ring: ring, len: len}, nil +} + +// Len returns the number of elements in the load balancer. +// It is thread-safe and uses the cached length for O(1) access. func (r *roundrobin[E]) Len() int { r.RLock() defer r.RUnlock() return r.len } +// Next returns the next element in the round-robin sequence. +// It is thread-safe, advancing the ring to the next position. +// If the balancer is empty, it returns the zero value of E. func (r *roundrobin[E]) Next() (next E) { r.Lock() defer r.Unlock() @@ -72,6 +98,9 @@ func (r *roundrobin[E]) Next() (next E) { return } +// Link merges the given ring into the load balancer, inserting its elements +// after the current position. It is thread-safe, updates the cached length, +// and returns the load balancer for chaining. func (r *roundrobin[E]) Link(s *container.Ring[E]) LoadBalancer[E] { r.Lock() defer r.Unlock() @@ -80,6 +109,9 @@ func (r *roundrobin[E]) Link(s *container.Ring[E]) LoadBalancer[E] { return r } +// Unlink removes n elements starting from the next position and returns +// the load balancer for chaining. It is thread-safe, updates the cached length, +// and sets the ring to nil if it becomes empty. func (r *roundrobin[E]) Unlink(n int) LoadBalancer[E] { r.Lock() defer r.Unlock() diff --git a/loadbalance/roundrobin_test.go b/loadbalance/roundrobin_test.go index b1ecc23..5b66b71 100644 --- a/loadbalance/roundrobin_test.go +++ b/loadbalance/roundrobin_test.go @@ -6,7 +6,10 @@ import ( ) func TestRoundRobin(t *testing.T) { - r1 := newRoundRobin[string]([]string{"a", "b", "c"}) + r1, err := newRoundRobin[string]([]string{"a", "b", "c"}) + if err != nil { + t.Fatal(err) + } if r1.Len() != 3 { t.Fatalf("want 3, got %d", r1.Len()) } @@ -19,7 +22,10 @@ func TestRoundRobin(t *testing.T) { } res = nil - r2 := newRoundRobin[string]([]Weighted[string]{{"a", 2}, {"b", 1}, {"c", 1}}) + r2, err := newRoundRobin[string]([]Weighted[string]{{"a", 2}, {"b", 1}, {"c", 1}}) + if err != nil { + t.Fatal(err) + } if r2.Len() != 4 { t.Fatalf("want 4, got %d", r2.Len()) } From 4dad95c84b46757c54114a7bd5d7828a3b8bdf95 Mon Sep 17 00:00:00 2001 From: sunshineplan Date: Mon, 29 Sep 2025 17:09:18 +0800 Subject: [PATCH 13/40] list --- container/list.go | 238 ++++++++++++++++++++++++++++++++++------- container/list_test.go | 9 +- 2 files changed, 202 insertions(+), 45 deletions(-) diff --git a/container/list.go b/container/list.go index 7251ca0..a152a05 100644 --- a/container/list.go +++ b/container/list.go @@ -1,88 +1,185 @@ package container import ( - "container/list" "sync" + "unsafe" ) // Element is an element of a linked list. type Element[T any] struct { - e *list.Element + // Next and previous pointers in the doubly-linked list of elements. + // To simplify the implementation, internally a list l is implemented + // as a ring, such that &l.root is both the next element of the last + // list element (l.Back()) and the previous element of the first list + // element (l.Front()). + next, prev *Element[T] // The list to which this element belongs. list *List[T] + + // The value stored with this element. + value T +} + +// Set assigns value v to the element and returns it. +func (e *Element[T]) Set(v T) *Element[T] { + if e.list != nil { + e.list.mu.RLock() + defer e.list.mu.RUnlock() + } + e.value = v + return e } // Value returns the value stored with this element. func (e *Element[T]) Value() T { - e.list.mu.RLock() - defer e.list.mu.RUnlock() - return e.e.Value.(T) + if e.list != nil { + e.list.mu.RLock() + defer e.list.mu.RUnlock() + } + return e.value +} + +func (e *Element[T]) nextElement() *Element[T] { + if p := e.next; e.list != nil && p != &e.list.root { + return p + } + return nil } // Next returns the next list element or nil. func (e *Element[T]) Next() *Element[T] { + if e.list == nil { + return nil + } e.list.mu.RLock() defer e.list.mu.RUnlock() - if next := e.e.Next(); next != nil { - return &Element[T]{next, e.list} + return e.nextElement() +} + +func (e *Element[T]) prevElement() *Element[T] { + if p := e.prev; e.list != nil && p != &e.list.root { + return p } return nil } // Prev returns the previous list element or nil. func (e *Element[T]) Prev() *Element[T] { + if e.list == nil { + return nil + } e.list.mu.RLock() defer e.list.mu.RUnlock() - if prev := e.e.Prev(); prev != nil { - return &Element[T]{prev, e.list} - } - return nil + return e.prevElement() } -// List represents a doubly linked list. +// List represents a doubly linked list like [list.List]. // The zero value for List is an empty list ready to use. type List[T any] struct { - mu sync.RWMutex - l list.List + mu sync.RWMutex + root Element[T] // sentinel list element, only &root, root.prev, and root.next are used + len int // current list length excluding (this) sentinel element +} + +func (l *List[T]) init() { + l.root.next = &l.root + l.root.prev = &l.root + l.len = 0 } // Init initializes or clears list l. func (l *List[T]) Init() *List[T] { l.mu.Lock() defer l.mu.Unlock() - l.l.Init() + l.init() return l } -// New returns an initialized list. +// NewList returns an initialized list. func NewList[T any]() *List[T] { return new(List[T]).Init() } -// Len returns the number of elements of list l. The complexity is O(1). +// Len returns the number of elements of list l. +// The complexity is O(1). func (l *List[T]) Len() int { l.mu.RLock() defer l.mu.RUnlock() - return l.l.Len() + return l.len +} + +func (l *List[T]) front() *Element[T] { + if l.len == 0 { + return nil + } + return l.root.next } // Front returns the first element of list l or nil if the list is empty. func (l *List[T]) Front() *Element[T] { l.mu.RLock() defer l.mu.RUnlock() - if e := l.l.Front(); e != nil { - return &Element[T]{e, l} + return l.front() +} + +func (l *List[T]) back() *Element[T] { + if l.len == 0 { + return nil } - return nil + return l.root.prev } // Back returns the last element of list l or nil if the list is empty. func (l *List[T]) Back() *Element[T] { l.mu.RLock() defer l.mu.RUnlock() - if e := l.l.Back(); e != nil { - return &Element[T]{e, l} + return l.back() +} + +// lazyInit lazily initializes a zero List value. +func (l *List[T]) lazyInit() { + if l.root.next == nil { + l.init() } - return nil +} + +// insert inserts e after at, increments l.len, and returns e. +func (l *List[T]) insert(e, at *Element[T]) *Element[T] { + e.prev = at + e.next = at.next + e.prev.next = e + e.next.prev = e + e.list = l + l.len++ + return e +} + +// insertValue is a convenience wrapper for insert(&Element{Value: v}, at). +func (l *List[T]) insertValue(v T, at *Element[T]) *Element[T] { + return l.insert(&Element[T]{value: v}, at) +} + +// remove removes e from its list, decrements l.len +func (l *List[T]) remove(e *Element[T]) { + e.prev.next = e.next + e.next.prev = e.prev + e.next = nil // avoid memory leaks + e.prev = nil // avoid memory leaks + e.list = nil + l.len-- +} + +// move moves e to next to at. +func (l *List[T]) move(e, at *Element[T]) { + if e == at { + return + } + e.prev.next = e.next + e.next.prev = e.prev + + e.prev = at + e.next = at.next + e.prev.next = e + e.next.prev = e } // Remove removes e from l if e is an element of list l. @@ -91,89 +188,152 @@ func (l *List[T]) Back() *Element[T] { func (l *List[T]) Remove(e *Element[T]) T { l.mu.Lock() defer l.mu.Unlock() - return l.l.Remove(e.e).(T) + if e.list == l { + // if e.list == l, l must have been initialized when e was inserted + // in l or l == nil (e is a zero Element) and l.remove will crash + l.remove(e) + } + return e.value } // PushFront inserts a new element e with value v at the front of list l and returns e. func (l *List[T]) PushFront(v T) *Element[T] { l.mu.Lock() defer l.mu.Unlock() - return &Element[T]{l.l.PushFront(v), l} + l.lazyInit() + return l.insertValue(v, &l.root) } // PushBack inserts a new element e with value v at the back of list l and returns e. func (l *List[T]) PushBack(v T) *Element[T] { l.mu.Lock() defer l.mu.Unlock() - return &Element[T]{l.l.PushBack(v), l} + l.lazyInit() + return l.insertValue(v, l.root.prev) } // InsertBefore inserts a new element e with value v immediately before mark and returns e. // If mark is not an element of l, the list is not modified. // The mark must not be nil. func (l *List[T]) InsertBefore(v T, mark *Element[T]) *Element[T] { + if mark.list != l { + return nil + } l.mu.Lock() defer l.mu.Unlock() - return &Element[T]{l.l.InsertBefore(v, mark.e), l} + // see comment in List.Remove about initialization of l + return l.insertValue(v, mark.prev) } // InsertAfter inserts a new element e with value v immediately after mark and returns e. // If mark is not an element of l, the list is not modified. // The mark must not be nil. func (l *List[T]) InsertAfter(v T, mark *Element[T]) *Element[T] { + if mark.list != l { + return nil + } l.mu.Lock() defer l.mu.Unlock() - return &Element[T]{l.l.InsertAfter(v, mark.e), l} + // see comment in List.Remove about initialization of l + return l.insertValue(v, mark) } // MoveToFront moves element e to the front of list l. // If e is not an element of l, the list is not modified. // The element must not be nil. func (l *List[T]) MoveToFront(e *Element[T]) { + if e.list != l { + return + } l.mu.Lock() defer l.mu.Unlock() - l.l.MoveToFront(e.e) + if l.root.next == e { + return + } + // see comment in List.Remove about initialization of l + l.move(e, &l.root) } // MoveToBack moves element e to the back of list l. // If e is not an element of l, the list is not modified. // The element must not be nil. func (l *List[T]) MoveToBack(e *Element[T]) { + if e.list != l { + return + } l.mu.Lock() defer l.mu.Unlock() - l.l.MoveToBack(e.e) + if l.root.prev == e { + return + } + // see comment in List.Remove about initialization of l + l.move(e, l.root.prev) } // MoveBefore moves element e to its new position before mark. // If e or mark is not an element of l, or e == mark, the list is not modified. // The element and mark must not be nil. func (l *List[T]) MoveBefore(e, mark *Element[T]) { + if e.list != l || e == mark || mark.list != l { + return + } l.mu.Lock() defer l.mu.Unlock() - l.l.MoveBefore(e.e, mark.e) + l.move(e, mark.prev) } // MoveAfter moves element e to its new position after mark. // If e or mark is not an element of l, or e == mark, the list is not modified. // The element and mark must not be nil. func (l *List[T]) MoveAfter(e, mark *Element[T]) { + if e.list != l || e == mark || mark.list != l { + return + } l.mu.Lock() defer l.mu.Unlock() - l.l.MoveAfter(e.e, mark.e) + l.move(e, mark) } // PushBackList inserts a copy of another list at the back of list l. // The lists l and other may be the same. They must not be nil. func (l *List[T]) PushBackList(other *List[T]) { - l.mu.Lock() - defer l.mu.Unlock() - l.l.PushBackList(&other.l) + unlock := pushLock(l, other) + defer unlock() + l.lazyInit() + for i, e := other.len, other.front(); i > 0; i, e = i-1, e.nextElement() { + l.insertValue(e.value, l.root.prev) + } } // PushFrontList inserts a copy of another list at the front of list l. // The lists l and other may be the same. They must not be nil. func (l *List[T]) PushFrontList(other *List[T]) { - l.mu.Lock() - defer l.mu.Unlock() - l.l.PushFrontList(&other.l) + unlock := pushLock(l, other) + defer unlock() + l.lazyInit() + for i, e := other.len, other.back(); i > 0; i, e = i-1, e.prevElement() { + l.insertValue(e.value, &l.root) + } +} + +func pushLock[T any](l, other *List[T]) (unlock func()) { + if l == other { + l.mu.Lock() + unlock = l.mu.Unlock + } else if uintptr(unsafe.Pointer(l)) < uintptr(unsafe.Pointer(other)) { + l.mu.Lock() + other.mu.RLock() + unlock = func() { + other.mu.RUnlock() + l.mu.Unlock() + } + } else { + other.mu.RLock() + l.mu.Lock() + unlock = func() { + l.mu.Unlock() + other.mu.RUnlock() + } + } + return } diff --git a/container/list_test.go b/container/list_test.go index 8b75210..0cc9861 100644 --- a/container/list_test.go +++ b/container/list_test.go @@ -4,10 +4,7 @@ package container -import ( - "container/list" - "testing" -) +import "testing" func checkListLen[T int](t *testing.T, l *List[T], len int) bool { if n := l.Len(); n != len { @@ -140,7 +137,7 @@ func TestInsertBeforeUnknownMark(t *testing.T) { l.PushBack(1) l.PushBack(2) l.PushBack(3) - l.InsertBefore(1, &Element[int]{new(list.Element), nil}) + l.InsertBefore(1, new(Element[int])) checkList(t, &l, []int{1, 2, 3}) } @@ -150,7 +147,7 @@ func TestInsertAfterUnknownMark(t *testing.T) { l.PushBack(1) l.PushBack(2) l.PushBack(3) - l.InsertAfter(1, &Element[int]{new(list.Element), nil}) + l.InsertAfter(1, new(Element[int])) checkList(t, &l, []int{1, 2, 3}) } From 237a6db37e4ad30f07f22f2a76334c9bf8b8dbb8 Mon Sep 17 00:00:00 2001 From: sunshineplan Date: Tue, 30 Sep 2025 15:09:32 +0800 Subject: [PATCH 14/40] Update list.go --- container/list.go | 111 ++++++++++++++++++++++++++++++---------------- 1 file changed, 73 insertions(+), 38 deletions(-) diff --git a/container/list.go b/container/list.go index a152a05..e526b8b 100644 --- a/container/list.go +++ b/container/list.go @@ -1,12 +1,15 @@ package container import ( + "cmp" "sync" "unsafe" ) -// Element is an element of a linked list. +// Element is a thread-safe element of a linked list. type Element[T any] struct { + mu sync.RWMutex + // Next and previous pointers in the doubly-linked list of elements. // To simplify the implementation, internally a list l is implemented // as a ring, such that &l.root is both the next element of the last @@ -23,20 +26,16 @@ type Element[T any] struct { // Set assigns value v to the element and returns it. func (e *Element[T]) Set(v T) *Element[T] { - if e.list != nil { - e.list.mu.RLock() - defer e.list.mu.RUnlock() - } + e.mu.Lock() + defer e.mu.Unlock() e.value = v return e } // Value returns the value stored with this element. func (e *Element[T]) Value() T { - if e.list != nil { - e.list.mu.RLock() - defer e.list.mu.RUnlock() - } + e.mu.RLock() + defer e.mu.RUnlock() return e.value } @@ -49,11 +48,12 @@ func (e *Element[T]) nextElement() *Element[T] { // Next returns the next list element or nil. func (e *Element[T]) Next() *Element[T] { - if e.list == nil { - return nil + e.mu.RLock() + defer e.mu.RUnlock() + if e.list != nil { + e.list.mu.RLock() + defer e.list.mu.RUnlock() } - e.list.mu.RLock() - defer e.list.mu.RUnlock() return e.nextElement() } @@ -66,15 +66,16 @@ func (e *Element[T]) prevElement() *Element[T] { // Prev returns the previous list element or nil. func (e *Element[T]) Prev() *Element[T] { - if e.list == nil { - return nil + e.mu.RLock() + defer e.mu.RUnlock() + if e.list != nil { + e.list.mu.RLock() + defer e.list.mu.RUnlock() } - e.list.mu.RLock() - defer e.list.mu.RUnlock() return e.prevElement() } -// List represents a doubly linked list like [list.List]. +// List represents a thread-safe doubly linked list. // The zero value for List is an empty list ready to use. type List[T any] struct { mu sync.RWMutex @@ -186,9 +187,11 @@ func (l *List[T]) move(e, at *Element[T]) { // It returns the element value e.Value. // The element must not be nil. func (l *List[T]) Remove(e *Element[T]) T { - l.mu.Lock() - defer l.mu.Unlock() + e.mu.Lock() + defer e.mu.Unlock() if e.list == l { + l.mu.Lock() + defer l.mu.Unlock() // if e.list == l, l must have been initialized when e was inserted // in l or l == nil (e is a zero Element) and l.remove will crash l.remove(e) @@ -216,6 +219,8 @@ func (l *List[T]) PushBack(v T) *Element[T] { // If mark is not an element of l, the list is not modified. // The mark must not be nil. func (l *List[T]) InsertBefore(v T, mark *Element[T]) *Element[T] { + mark.mu.RLock() + defer mark.mu.RUnlock() if mark.list != l { return nil } @@ -229,6 +234,8 @@ func (l *List[T]) InsertBefore(v T, mark *Element[T]) *Element[T] { // If mark is not an element of l, the list is not modified. // The mark must not be nil. func (l *List[T]) InsertAfter(v T, mark *Element[T]) *Element[T] { + mark.mu.RLock() + defer mark.mu.RUnlock() if mark.list != l { return nil } @@ -242,6 +249,8 @@ func (l *List[T]) InsertAfter(v T, mark *Element[T]) *Element[T] { // If e is not an element of l, the list is not modified. // The element must not be nil. func (l *List[T]) MoveToFront(e *Element[T]) { + e.mu.RLock() + defer e.mu.RUnlock() if e.list != l { return } @@ -258,6 +267,8 @@ func (l *List[T]) MoveToFront(e *Element[T]) { // If e is not an element of l, the list is not modified. // The element must not be nil. func (l *List[T]) MoveToBack(e *Element[T]) { + e.mu.RLock() + defer e.mu.RUnlock() if e.list != l { return } @@ -274,7 +285,12 @@ func (l *List[T]) MoveToBack(e *Element[T]) { // If e or mark is not an element of l, or e == mark, the list is not modified. // The element and mark must not be nil. func (l *List[T]) MoveBefore(e, mark *Element[T]) { - if e.list != l || e == mark || mark.list != l { + if e == mark { + return + } + unlock := lock(&e.mu, &mark.mu, true, true) + defer unlock() + if e.list != l || mark.list != l { return } l.mu.Lock() @@ -286,7 +302,12 @@ func (l *List[T]) MoveBefore(e, mark *Element[T]) { // If e or mark is not an element of l, or e == mark, the list is not modified. // The element and mark must not be nil. func (l *List[T]) MoveAfter(e, mark *Element[T]) { - if e.list != l || e == mark || mark.list != l { + if e == mark { + return + } + unlock := lock(&e.mu, &mark.mu, true, true) + defer unlock() + if e.list != l || mark.list != l { return } l.mu.Lock() @@ -297,7 +318,7 @@ func (l *List[T]) MoveAfter(e, mark *Element[T]) { // PushBackList inserts a copy of another list at the back of list l. // The lists l and other may be the same. They must not be nil. func (l *List[T]) PushBackList(other *List[T]) { - unlock := pushLock(l, other) + unlock := lock(&l.mu, &other.mu, false, true) defer unlock() l.lazyInit() for i, e := other.len, other.front(); i > 0; i, e = i-1, e.nextElement() { @@ -308,7 +329,7 @@ func (l *List[T]) PushBackList(other *List[T]) { // PushFrontList inserts a copy of another list at the front of list l. // The lists l and other may be the same. They must not be nil. func (l *List[T]) PushFrontList(other *List[T]) { - unlock := pushLock(l, other) + unlock := lock(&l.mu, &other.mu, false, true) defer unlock() l.lazyInit() for i, e := other.len, other.back(); i > 0; i, e = i-1, e.prevElement() { @@ -316,23 +337,37 @@ func (l *List[T]) PushFrontList(other *List[T]) { } } -func pushLock[T any](l, other *List[T]) (unlock func()) { - if l == other { - l.mu.Lock() - unlock = l.mu.Unlock - } else if uintptr(unsafe.Pointer(l)) < uintptr(unsafe.Pointer(other)) { - l.mu.Lock() - other.mu.RLock() +func lock(s, r *sync.RWMutex, sReadOnly, rReadOnly bool) (unlock func()) { + var sl sync.Locker = s + var rl sync.Locker = r + if sReadOnly { + sl = s.RLocker() + } + if rReadOnly { + rl = r.RLocker() + } + switch cmp.Compare(uintptr(unsafe.Pointer(s)), uintptr(unsafe.Pointer(r))) { + case 0: + if sReadOnly && rReadOnly { + s.RLock() + unlock = s.RUnlock + } else { + s.Lock() + unlock = s.Unlock + } + case 1: + sl.Lock() + rl.Lock() unlock = func() { - other.mu.RUnlock() - l.mu.Unlock() + rl.Unlock() + sl.Unlock() } - } else { - other.mu.RLock() - l.mu.Lock() + case -1: + rl.Lock() + sl.Lock() unlock = func() { - l.mu.Unlock() - other.mu.RUnlock() + sl.Unlock() + rl.Unlock() } } return From abf21004bbe42be0e6ead4c173a9625abe49fb85 Mon Sep 17 00:00:00 2001 From: sunshineplan Date: Thu, 9 Oct 2025 15:26:12 +0800 Subject: [PATCH 15/40] map --- container/map.go | 36 +++++------ container/map_test.go | 140 ++++++++++++++++++++++++++++++++++++++++++ 2 files changed, 157 insertions(+), 19 deletions(-) create mode 100644 container/map_test.go diff --git a/container/map.go b/container/map.go index b0f09f1..22e029c 100644 --- a/container/map.go +++ b/container/map.go @@ -2,11 +2,14 @@ package container import "sync" -type Map[Key, Value any] struct { +// Map is a generic concurrency-safe map that wraps sync.Map +// and provides type-safe access for keys and values. +type Map[Key comparable, Value any] struct { m sync.Map } -func NewMap[Key, Value any]() *Map[Key, Value] { +// NewMap creates and returns a new, empty generic concurrency-safe Map. +func NewMap[Key comparable, Value any]() *Map[Key, Value] { return &Map[Key, Value]{} } @@ -15,8 +18,8 @@ func NewMap[Key, Value any]() *Map[Key, Value] { // The ok result indicates whether value was found in the map. func (m *Map[Key, Value]) Load(key Key) (value Value, ok bool) { var v any - if v, ok = m.m.Load(key); ok && v != nil { - value = v.(Value) + if v, ok = m.m.Load(key); ok { + value, _ = v.(Value) } return } @@ -37,9 +40,7 @@ func (m *Map[Key, Value]) Clear() { func (m *Map[Key, Value]) LoadOrStore(key Key, value Value) (actual Value, loaded bool) { var v any if v, loaded = m.m.LoadOrStore(key, value); loaded { - if v != nil { - actual = v.(Value) - } + actual, _ = v.(Value) } else { actual = value } @@ -50,8 +51,8 @@ func (m *Map[Key, Value]) LoadOrStore(key Key, value Value) (actual Value, loade // The loaded result reports whether the key was present. func (m *Map[Key, Value]) LoadAndDelete(key Key) (value Value, loaded bool) { var v any - if v, loaded = m.m.LoadAndDelete(key); loaded && v != nil { - value = v.(Value) + if v, loaded = m.m.LoadAndDelete(key); loaded { + value, _ = v.(Value) } return } @@ -65,8 +66,8 @@ func (m *Map[Key, Value]) Delete(key Key) { // The loaded result reports whether the key was present. func (m *Map[Key, Value]) Swap(key Key, value Value) (previous Value, loaded bool) { var v any - if v, loaded = m.m.Swap(key, value); loaded && v != nil { - previous = v.(Value) + if v, loaded = m.m.Swap(key, value); loaded { + previous, _ = v.(Value) } return } @@ -100,14 +101,11 @@ func (m *Map[Key, Value]) CompareAndDelete(key Key, old Value) (deleted bool) { // false after a constant number of calls. func (m *Map[Key, Value]) Range(f func(Key, Value) bool) { m.m.Range(func(key, value any) bool { - var k Key - var v Value - if key != nil { - k = key.(Key) - } - if value != nil { - v = value.(Value) + if k, ok := key.(Key); ok { + if v, ok := value.(Value); ok { + return f(k, v) + } } - return f(k, v) + return true }) } diff --git a/container/map_test.go b/container/map_test.go new file mode 100644 index 0000000..aeb4b45 --- /dev/null +++ b/container/map_test.go @@ -0,0 +1,140 @@ +package container + +import ( + "sync" + "testing" +) + +func TestMap_BasicOperations(t *testing.T) { + m := NewMap[string, int]() + + // Store & Load + m.Store("a", 1) + if v, ok := m.Load("a"); !ok || v != 1 { + t.Fatalf("Load failed, got (%v, %v), want (1, true)", v, ok) + } + + // LoadOrStore (existing key) + actual, loaded := m.LoadOrStore("a", 2) + if !loaded || actual != 1 { + t.Fatalf("LoadOrStore existing key failed, got (%v, %v), want (1, true)", actual, loaded) + } + + // LoadOrStore (new key) + actual, loaded = m.LoadOrStore("b", 3) + if loaded || actual != 3 { + t.Fatalf("LoadOrStore new key failed, got (%v, %v), want (3, false)", actual, loaded) + } + + // LoadAndDelete + v, ok := m.LoadAndDelete("b") + if !ok || v != 3 { + t.Fatalf("LoadAndDelete failed, got (%v, %v), want (3, true)", v, ok) + } + if _, ok := m.Load("b"); ok { + t.Fatalf("Key 'b' should be deleted") + } + + // Swap + prev, loaded := m.Swap("a", 5) + if !loaded || prev != 1 { + t.Fatalf("Swap failed, got (%v, %v), want (1, true)", prev, loaded) + } + v, ok = m.Load("a") + if !ok || v != 5 { + t.Fatalf("Swap did not update value, got (%v, %v), want (5, true)", v, ok) + } + + // CompareAndSwap + if !m.CompareAndSwap("a", 5, 10) { + t.Fatalf("CompareAndSwap should succeed") + } + v, _ = m.Load("a") + if v != 10 { + t.Fatalf("CompareAndSwap failed, value = %v, want 10", v) + } + + if m.CompareAndSwap("a", 5, 20) { + t.Fatalf("CompareAndSwap should fail when old != current") + } + + // CompareAndDelete + m.Store("x", 42) + if !m.CompareAndDelete("x", 42) { + t.Fatalf("CompareAndDelete should succeed") + } + if _, ok := m.Load("x"); ok { + t.Fatalf("CompareAndDelete did not delete key") + } + + m.Store("y", 99) + if m.CompareAndDelete("y", 100) { + t.Fatalf("CompareAndDelete should fail when old != current") + } +} + +func TestMap_Range(t *testing.T) { + m := NewMap[string, int]() + m.Store("a", 1) + m.Store("b", 2) + m.Store("c", 3) + + collected := make(map[string]int) + m.Range(func(k string, v int) bool { + collected[k] = v + return true + }) + + if len(collected) != 3 { + t.Fatalf("Range failed, got %d elements, want 3", len(collected)) + } +} + +func TestMap_Clear(t *testing.T) { + m := NewMap[string, int]() + m.Store("a", 1) + m.Store("b", 2) + m.Clear() + + count := 0 + m.Range(func(k string, v int) bool { + count++ + return true + }) + + if count != 0 { + t.Fatalf("Clear failed, map still has %d elements", count) + } +} + +func TestMap_ConcurrentAccess(t *testing.T) { + m := NewMap[int, int]() + wg := sync.WaitGroup{} + + // Concurrent store + for i := 0; i < 100; i++ { + wg.Add(1) + go func(i int) { + defer wg.Done() + m.Store(i, i*i) + }(i) + } + + wg.Wait() + + // Concurrent load + wg = sync.WaitGroup{} + for i := 0; i < 100; i++ { + wg.Add(1) + go func(i int) { + defer wg.Done() + v, ok := m.Load(i) + if !ok { + t.Errorf("Missing key %d", i) + } else if v != i*i { + t.Errorf("Unexpected value for %d: got %d, want %d", i, v, i*i) + } + }(i) + } + wg.Wait() +} From 0ef2013966e6ba0626b3017fc7b7e8c988e1b673 Mon Sep 17 00:00:00 2001 From: sunshineplan Date: Thu, 9 Oct 2025 16:04:36 +0800 Subject: [PATCH 16/40] value --- container/map.go | 2 +- container/map_test.go | 4 ++-- container/value.go | 31 +++++++++++++++---------------- 3 files changed, 18 insertions(+), 19 deletions(-) diff --git a/container/map.go b/container/map.go index 22e029c..c1165fc 100644 --- a/container/map.go +++ b/container/map.go @@ -75,7 +75,7 @@ func (m *Map[Key, Value]) Swap(key Key, value Value) (previous Value, loaded boo // CompareAndSwap swaps the old and new values for key // if the value stored in the map is equal to old. // The old value must be of a comparable type. -func (m *Map[Key, Value]) CompareAndSwap(key Key, old Value, new Value) bool { +func (m *Map[Key, Value]) CompareAndSwap(key Key, old Value, new Value) (swapped bool) { return m.m.CompareAndSwap(key, old, new) } diff --git a/container/map_test.go b/container/map_test.go index aeb4b45..f86c6fc 100644 --- a/container/map_test.go +++ b/container/map_test.go @@ -112,7 +112,7 @@ func TestMap_ConcurrentAccess(t *testing.T) { wg := sync.WaitGroup{} // Concurrent store - for i := 0; i < 100; i++ { + for i := range 100 { wg.Add(1) go func(i int) { defer wg.Done() @@ -124,7 +124,7 @@ func TestMap_ConcurrentAccess(t *testing.T) { // Concurrent load wg = sync.WaitGroup{} - for i := 0; i < 100; i++ { + for i := range 100 { wg.Add(1) go func(i int) { defer wg.Done() diff --git a/container/value.go b/container/value.go index dd84c5c..03a0366 100644 --- a/container/value.go +++ b/container/value.go @@ -19,48 +19,47 @@ func NewValue[T any]() *Value[T] { // indicating whether a value was stored. // If there has been no call to Store for this Value, it returns the // zero value of T and false. -func (v *Value[T]) Load() (val T, stored bool) { - if v := v.v.Load(); v == nil { - return - } else { - return v.(T), true +func (v *Value[T]) Load() (val T, ok bool) { + if loaded := v.v.Load(); loaded != nil { + return loaded.(T), true } + return } // MustLoad returns the value set by the most recent Store. // It panics if there has been no call to Store for this Value. func (v *Value[T]) MustLoad() (val T) { - if v, stored := v.Load(); stored { - return v + val, ok := v.Load() + if !ok { + panic("container/value: there has been no call to Store for this Value") } - panic("cache/value: there has been no call to Store for this Value") + return } // Store sets the value of the [Value] v to val. func (v *Value[T]) Store(val T) { if any(val) == nil { - panic("cache/value: store of nil value into Value") + panic("container/value: store of nil value into Value") } v.v.Store(val) } // Swap stores the new value into the Value and returns the previous value. // If no value was previously stored, it returns the zero value of T and false. -func (v *Value[T]) Swap(new T) (old T, stored bool) { +func (v *Value[T]) Swap(new T) (old T, loaded bool) { if any(new) == nil { - panic("cache/value: swap of nil value into Value") + panic("container/value: swap of nil value into Value") } - if v := v.v.Swap(new); v == nil { - return - } else { - return v.(T), true + if prev := v.v.Swap(new); prev != nil { + return prev.(T), true } + return } // CompareAndSwap executes the compare-and-swap operation for the [Value]. func (v *Value[T]) CompareAndSwap(old, new T) (swapped bool) { if any(new) == nil { - panic("cache/value: compare and swap of nil value into Value") + panic("container/value: compare and swap of nil value into Value") } return v.v.CompareAndSwap(old, new) } From e5090c7676094c9c64d963e4bde2df4235206695 Mon Sep 17 00:00:00 2001 From: sunshineplan Date: Thu, 9 Oct 2025 16:37:54 +0800 Subject: [PATCH 17/40] counter --- counter/counter.go | 31 ++++++++++++++++++++++++++++--- counter/listener.go | 23 +++++++++++++---------- counter/rw.go | 24 +++++++++--------------- 3 files changed, 50 insertions(+), 28 deletions(-) diff --git a/counter/counter.go b/counter/counter.go index fd87779..82b7b3d 100644 --- a/counter/counter.go +++ b/counter/counter.go @@ -3,16 +3,19 @@ package counter import ( "io" "sync/atomic" + "time" ) -type Counter atomic.Int64 +type Counter struct { + v atomic.Int64 +} func (c *Counter) Add(delta int64) (new int64) { - return (*atomic.Int64)(c).Add(delta) + return c.v.Add(delta) } func (c *Counter) Load() int64 { - return (*atomic.Int64)(c).Load() + return c.v.Load() } func (c *Counter) AddWriter(w io.Writer) io.Writer { @@ -22,3 +25,25 @@ func (c *Counter) AddWriter(w io.Writer) io.Writer { func (c *Counter) AddReader(r io.Reader) io.Reader { return newReader(c, r) } + +type RateCounter struct { + Counter + start time.Time +} + +func NewRateCounter() *RateCounter { + return &RateCounter{start: time.Now()} +} + +func (rc *RateCounter) Reset() { + rc.Counter.v.Store(0) + rc.start = time.Now() +} + +func (rc *RateCounter) Rate() float64 { + duration := time.Since(rc.start).Seconds() + if duration == 0 { + return 0 + } + return float64(rc.Load()) / duration +} diff --git a/counter/listener.go b/counter/listener.go index eb08185..e3f0358 100644 --- a/counter/listener.go +++ b/counter/listener.go @@ -1,6 +1,9 @@ package counter -import "net" +import ( + "io" + "net" +) var ( _ net.Listener = &Listener{} @@ -22,7 +25,11 @@ func (l *Listener) Accept() (net.Conn, error) { if err != nil { return nil, err } - return &conn{c, l}, nil + return &conn{ + Conn: c, + r: l.read.AddReader(c), + w: l.written.AddWriter(c), + }, nil } func (l *Listener) Close() error { @@ -43,13 +50,9 @@ func (l *Listener) WriteCount() int64 { type conn struct { net.Conn - listener *Listener -} - -func (conn *conn) Write(b []byte) (n int, err error) { - return conn.listener.written.AddWriter(conn.Conn).Write(b) + r io.Reader + w io.Writer } -func (conn *conn) Read(b []byte) (n int, err error) { - return conn.listener.read.AddReader(conn.Conn).Read(b) -} +func (c *conn) Read(b []byte) (int, error) { return c.r.Read(b) } +func (c *conn) Write(b []byte) (int, error) { return c.w.Write(b) } diff --git a/counter/rw.go b/counter/rw.go index 484a2dd..e21bc90 100644 --- a/counter/rw.go +++ b/counter/rw.go @@ -1,8 +1,6 @@ package counter -import ( - "io" -) +import "io" type reader struct { *Counter @@ -19,10 +17,9 @@ func newReader(n *Counter, r io.Reader) io.Reader { func (r *reader) Read(p []byte) (n int, err error) { n, err = r.r.Read(p) - if err != nil { - return + if n > 0 { + r.Add(int64(n)) } - r.Add(int64(n)) return } @@ -32,10 +29,9 @@ type readerWriterTo struct { func (r *readerWriterTo) WriteTo(w io.Writer) (n int64, err error) { n, err = r.r.(io.WriterTo).WriteTo(w) - if err != nil { - return + if n > 0 { + r.Add(n) } - r.Add(int64(n)) return } @@ -54,10 +50,9 @@ func newWriter(n *Counter, w io.Writer) io.Writer { func (w *writer) Write(p []byte) (n int, err error) { n, err = w.w.Write(p) - if err != nil { - return + if n > 0 { + w.Add(int64(n)) } - w.Add(int64(n)) return } @@ -67,9 +62,8 @@ type writerReaderFrom struct { func (w *writerReaderFrom) ReadFrom(r io.Reader) (n int64, err error) { n, err = w.w.(io.ReaderFrom).ReadFrom(r) - if err != nil { - return + if n > 0 { + w.Add(n) } - w.Add(int64(n)) return } From 6a512c2609a1858b51203426c4066e20d24b7e1a Mon Sep 17 00:00:00 2001 From: sunshineplan Date: Thu, 9 Oct 2025 16:56:58 +0800 Subject: [PATCH 18/40] Update counter --- counter/counter.go | 4 ++-- counter/rw.go | 36 ++++++++++++++++++------------------ 2 files changed, 20 insertions(+), 20 deletions(-) diff --git a/counter/counter.go b/counter/counter.go index 82b7b3d..709742a 100644 --- a/counter/counter.go +++ b/counter/counter.go @@ -19,11 +19,11 @@ func (c *Counter) Load() int64 { } func (c *Counter) AddWriter(w io.Writer) io.Writer { - return newWriter(c, w) + return newWriterCounter(c, w) } func (c *Counter) AddReader(r io.Reader) io.Reader { - return newReader(c, r) + return newReaderCounter(c, r) } type RateCounter struct { diff --git a/counter/rw.go b/counter/rw.go index e21bc90..cf8587d 100644 --- a/counter/rw.go +++ b/counter/rw.go @@ -2,20 +2,20 @@ package counter import "io" -type reader struct { +type readerCounter struct { *Counter r io.Reader } -func newReader(n *Counter, r io.Reader) io.Reader { - reader := &reader{n, r} +func newReaderCounter(n *Counter, r io.Reader) io.Reader { + reader := &readerCounter{n, r} if _, ok := r.(io.WriterTo); ok { - return readerWriterTo{reader} + return writerTo{reader} } return reader } -func (r *reader) Read(p []byte) (n int, err error) { +func (r *readerCounter) Read(p []byte) (n int, err error) { n, err = r.r.Read(p) if n > 0 { r.Add(int64(n)) @@ -23,32 +23,32 @@ func (r *reader) Read(p []byte) (n int, err error) { return } -type readerWriterTo struct { - *reader +type writerTo struct { + *readerCounter } -func (r *readerWriterTo) WriteTo(w io.Writer) (n int64, err error) { - n, err = r.r.(io.WriterTo).WriteTo(w) +func (r *writerTo) WriteTo(w io.Writer) (n int64, err error) { + n, err = io.Copy(w, r.r) if n > 0 { r.Add(n) } return } -type writer struct { +type writerCounter struct { *Counter w io.Writer } -func newWriter(n *Counter, w io.Writer) io.Writer { - writer := &writer{n, w} +func newWriterCounter(n *Counter, w io.Writer) io.Writer { + writer := &writerCounter{n, w} if _, ok := w.(io.ReaderFrom); ok { - return writerReaderFrom{writer} + return readerFrom{writer} } return writer } -func (w *writer) Write(p []byte) (n int, err error) { +func (w *writerCounter) Write(p []byte) (n int, err error) { n, err = w.w.Write(p) if n > 0 { w.Add(int64(n)) @@ -56,12 +56,12 @@ func (w *writer) Write(p []byte) (n int, err error) { return } -type writerReaderFrom struct { - *writer +type readerFrom struct { + *writerCounter } -func (w *writerReaderFrom) ReadFrom(r io.Reader) (n int64, err error) { - n, err = w.w.(io.ReaderFrom).ReadFrom(r) +func (w *readerFrom) ReadFrom(r io.Reader) (n int64, err error) { + n, err = io.Copy(w.w, r) if n > 0 { w.Add(n) } From df624c813c1f0ae1f014f7121386e2d20761fac1 Mon Sep 17 00:00:00 2001 From: sunshineplan Date: Fri, 10 Oct 2025 15:29:36 +0800 Subject: [PATCH 19/40] Update counter --- counter/counter.go | 69 ++++++++++++++++++++++++++------------ counter/counter_test.go | 12 +++---- counter/listener.go | 30 ++++++----------- counter/listener_test.go | 4 +-- counter/rw.go | 69 -------------------------------------- httpsvr/httpsvr.go | 8 ++--- progressbar/counter.go | 39 +++++++++++++++++++++ progressbar/progressbar.go | 12 +++---- 8 files changed, 115 insertions(+), 128 deletions(-) delete mode 100644 counter/rw.go create mode 100644 progressbar/counter.go diff --git a/counter/counter.go b/counter/counter.go index 709742a..0dbdff6 100644 --- a/counter/counter.go +++ b/counter/counter.go @@ -3,47 +3,72 @@ package counter import ( "io" "sync/atomic" - "time" ) type Counter struct { - v atomic.Int64 + n atomic.Int64 } func (c *Counter) Add(delta int64) (new int64) { - return c.v.Add(delta) + return c.n.Add(delta) } -func (c *Counter) Load() int64 { - return c.v.Load() +func (c *Counter) Get() int64 { + return c.n.Load() } -func (c *Counter) AddWriter(w io.Writer) io.Writer { - return newWriterCounter(c, w) +type CounterReader struct { + r io.Reader + c *Counter } -func (c *Counter) AddReader(r io.Reader) io.Reader { - return newReaderCounter(c, r) +func CountReader(r io.Reader, c *Counter) io.Reader { + return NewCounterReader(r, c) } -type RateCounter struct { - Counter - start time.Time +func NewCounterReader(r io.Reader, c *Counter) *CounterReader { + if c == nil { + c = new(Counter) + } + return &CounterReader{r, c} +} + +func (r *CounterReader) Read(p []byte) (n int, err error) { + n, err = r.r.Read(p) + if n > 0 { + r.c.Add(int64(n)) + } + return +} + +func (r *CounterReader) Bytes() int64 { + return r.c.Get() } -func NewRateCounter() *RateCounter { - return &RateCounter{start: time.Now()} +type CounterWriter struct { + w io.Writer + c *Counter } -func (rc *RateCounter) Reset() { - rc.Counter.v.Store(0) - rc.start = time.Now() +func CountWriter(w io.Writer, c *Counter) io.Writer { + return NewCounterWriter(w, c) } -func (rc *RateCounter) Rate() float64 { - duration := time.Since(rc.start).Seconds() - if duration == 0 { - return 0 +func NewCounterWriter(w io.Writer, c *Counter) *CounterWriter { + if c == nil { + c = new(Counter) } - return float64(rc.Load()) / duration + return &CounterWriter{w, c} +} + +func (w *CounterWriter) Write(p []byte) (n int, err error) { + n, err = w.w.Write(p) + if n > 0 { + w.c.Add(int64(n)) + } + return +} + +func (w *CounterWriter) Bytes() int64 { + return w.c.Get() } diff --git a/counter/counter_test.go b/counter/counter_test.go index 3e37490..e414230 100644 --- a/counter/counter_test.go +++ b/counter/counter_test.go @@ -14,19 +14,19 @@ func TestReader(t *testing.T) { c, buf := new(Counter), new(bytes.Buffer) buf.Write(data1) buf.Write(data2) - r := c.AddReader(buf) + r := CountReader(buf, c) io.ReadAll(r) - if count := c.Load(); count != dataLen { - t.Fatalf("expected %d; got %d", dataLen, count) + if n := c.Get(); n != dataLen { + t.Fatalf("expected %d; got %d", dataLen, n) } } func TestWriter(t *testing.T) { c, buf := new(Counter), new(bytes.Buffer) - w := c.AddWriter(buf) + w := CountWriter(buf, c) w.Write(data1) w.Write(data2) - if count := c.Load(); count != dataLen { - t.Fatalf("expected %d; got %d", dataLen, count) + if n := c.Get(); n != dataLen { + t.Fatalf("expected %d; got %d", dataLen, n) } } diff --git a/counter/listener.go b/counter/listener.go index e3f0358..673c1ab 100644 --- a/counter/listener.go +++ b/counter/listener.go @@ -11,41 +11,33 @@ var ( ) type Listener struct { - listener net.Listener - read Counter - written Counter + net.Listener + readBytes Counter + writeBytes Counter } func NewListener(listener net.Listener) *Listener { - return &Listener{listener: listener} + return &Listener{Listener: listener} } func (l *Listener) Accept() (net.Conn, error) { - c, err := l.listener.Accept() + c, err := l.Listener.Accept() if err != nil { return nil, err } return &conn{ Conn: c, - r: l.read.AddReader(c), - w: l.written.AddWriter(c), + r: CountReader(c, &l.readBytes), + w: CountWriter(c, &l.writeBytes), }, nil } -func (l *Listener) Close() error { - return l.listener.Close() +func (l *Listener) ReadBytes() int64 { + return l.readBytes.Get() } -func (l *Listener) Addr() net.Addr { - return l.listener.Addr() -} - -func (l *Listener) ReadCount() int64 { - return l.read.Load() -} - -func (l *Listener) WriteCount() int64 { - return l.written.Load() +func (l *Listener) WriteBytes() int64 { + return l.writeBytes.Get() } type conn struct { diff --git a/counter/listener_test.go b/counter/listener_test.go index 4584b1f..1b0ad5a 100644 --- a/counter/listener_test.go +++ b/counter/listener_test.go @@ -50,7 +50,7 @@ func TestListener(t *testing.T) { } conn.Close() - if count := l.ReadCount(); count != dataLen { - t.Fatalf("expected %d; got %d", dataLen, count) + if n := l.ReadBytes(); n != dataLen { + t.Fatalf("expected %d; got %d", dataLen, n) } } diff --git a/counter/rw.go b/counter/rw.go deleted file mode 100644 index cf8587d..0000000 --- a/counter/rw.go +++ /dev/null @@ -1,69 +0,0 @@ -package counter - -import "io" - -type readerCounter struct { - *Counter - r io.Reader -} - -func newReaderCounter(n *Counter, r io.Reader) io.Reader { - reader := &readerCounter{n, r} - if _, ok := r.(io.WriterTo); ok { - return writerTo{reader} - } - return reader -} - -func (r *readerCounter) Read(p []byte) (n int, err error) { - n, err = r.r.Read(p) - if n > 0 { - r.Add(int64(n)) - } - return -} - -type writerTo struct { - *readerCounter -} - -func (r *writerTo) WriteTo(w io.Writer) (n int64, err error) { - n, err = io.Copy(w, r.r) - if n > 0 { - r.Add(n) - } - return -} - -type writerCounter struct { - *Counter - w io.Writer -} - -func newWriterCounter(n *Counter, w io.Writer) io.Writer { - writer := &writerCounter{n, w} - if _, ok := w.(io.ReaderFrom); ok { - return readerFrom{writer} - } - return writer -} - -func (w *writerCounter) Write(p []byte) (n int, err error) { - n, err = w.w.Write(p) - if n > 0 { - w.Add(int64(n)) - } - return -} - -type readerFrom struct { - *writerCounter -} - -func (w *readerFrom) ReadFrom(r io.Reader) (n int64, err error) { - n, err = io.Copy(w.w, r) - if n > 0 { - w.Add(n) - } - return -} diff --git a/httpsvr/httpsvr.go b/httpsvr/httpsvr.go index 6a12cad..2372e19 100644 --- a/httpsvr/httpsvr.go +++ b/httpsvr/httpsvr.go @@ -165,18 +165,18 @@ func (s *Server) RunTLS(certFile, keyFile string) error { return s.run() } -func (s *Server) ReadCount() int64 { +func (s *Server) ReadBytes() int64 { if s.l == nil { return 0 } - return s.l.ReadCount() + return s.l.ReadBytes() } -func (s *Server) WriteCount() int64 { +func (s *Server) WriteBytes() int64 { if s.l == nil { return 0 } - return s.l.WriteCount() + return s.l.WriteBytes() } // TCP runs an HTTP server on TCP network listener. diff --git a/progressbar/counter.go b/progressbar/counter.go new file mode 100644 index 0000000..f17675c --- /dev/null +++ b/progressbar/counter.go @@ -0,0 +1,39 @@ +package progressbar + +import ( + "io" + + "github.com/sunshineplan/utils/counter" +) + +type genericCounter interface { + Add(int64) int64 + Write([]byte) (int, error) + Get() int64 +} + +var ( + _ genericCounter = new(numberCounter) + _ genericCounter = new(writerCounter) +) + +func newNumberCounter() genericCounter { + return new(numberCounter) +} + +func newWriterCounter(w io.Writer) genericCounter { + return &writerCounter{counter.NewCounterWriter(w, new(counter.Counter))} +} + +type numberCounter struct { + counter.Counter +} + +func (c *numberCounter) Write(_ []byte) (n int, err error) { return } + +type writerCounter struct { + *counter.CounterWriter +} + +func (c *writerCounter) Add(_ int64) (n int64) { return } +func (c *writerCounter) Get() int64 { return c.CounterWriter.Bytes() } diff --git a/progressbar/progressbar.go b/progressbar/progressbar.go index 99482b5..6dda4a6 100644 --- a/progressbar/progressbar.go +++ b/progressbar/progressbar.go @@ -11,7 +11,6 @@ import ( "text/template" "time" - "github.com/sunshineplan/utils/counter" "github.com/sunshineplan/utils/unit" ) @@ -44,7 +43,7 @@ type ProgressBar[T int | int64] struct { renderInterval time.Duration template *template.Template - current counter.Counter + current genericCounter total int64 additional string speed float64 @@ -74,6 +73,7 @@ func New[T int | int64](total T) *ProgressBar[T] { blockWidth: 40, refreshInterval: defaultRefresh, template: template.Must(template.New("ProgressBar").Parse(defaultTemplate)), + current: newNumberCounter(), total: int64(total), } } @@ -166,7 +166,7 @@ func (pb *ProgressBar[T]) Additional(s string) { } func (pb *ProgressBar[T]) now() int64 { - return pb.current.Load() + return pb.current.Get() } func (pb *ProgressBar[T]) print(s string, msg bool) { @@ -348,13 +348,13 @@ func (pb *ProgressBar[T]) Cancel() { // FromReader starts the progress bar from a reader. func (pb *ProgressBar[T]) FromReader(r io.Reader, w io.Writer) (int64, error) { + pb.current = newWriterCounter(w) if err := pb.Start(); err != nil { return 0, err } - n, err := io.Copy(pb.current.AddWriter(w), r) + n, err := io.Copy(pb.current, r) if err != nil { pb.Cancel() - return n, err } - return n, nil + return n, err } From 7463394ddcc4366b894d40eba9c8c9f62b26abb7 Mon Sep 17 00:00:00 2001 From: sunshineplan Date: Fri, 10 Oct 2025 15:52:16 +0800 Subject: [PATCH 20/40] Update counter --- counter/counter.go | 27 ++++++++++++++++++++++----- counter/listener.go | 28 ++++++++++++++++++++++------ 2 files changed, 44 insertions(+), 11 deletions(-) diff --git a/counter/counter.go b/counter/counter.go index 0dbdff6..904147d 100644 --- a/counter/counter.go +++ b/counter/counter.go @@ -5,27 +5,34 @@ import ( "sync/atomic" ) +// Counter is a thread-safe utility for counting values, starting from zero. type Counter struct { n atomic.Int64 } -func (c *Counter) Add(delta int64) (new int64) { +// Add adds delta to the [Counter] and returns the new value. +func (c *Counter) Add(delta int64) int64 { return c.n.Add(delta) } +// Get returns the current value of the [Counter]. func (c *Counter) Get() int64 { return c.n.Load() } +// CounterReader wraps an io.Reader to count the number of bytes read. type CounterReader struct { - r io.Reader - c *Counter + r io.Reader // Underlying reader + c *Counter // Counter for bytes read } +// CountReader creates an io.Reader that counts bytes read from r, using the provided [Counter]. func CountReader(r io.Reader, c *Counter) io.Reader { return NewCounterReader(r, c) } +// NewCounterReader creates a [CounterReader] that counts bytes read from r, using the provided [Counter]. +// If c is nil, a new Counter is created. func NewCounterReader(r io.Reader, c *Counter) *CounterReader { if c == nil { c = new(Counter) @@ -33,6 +40,8 @@ func NewCounterReader(r io.Reader, c *Counter) *CounterReader { return &CounterReader{r, c} } +// Read reads from the underlying Reader and increments the counter by the number of bytes read. +// It returns the number of bytes read and any error encountered. func (r *CounterReader) Read(p []byte) (n int, err error) { n, err = r.r.Read(p) if n > 0 { @@ -41,19 +50,24 @@ func (r *CounterReader) Read(p []byte) (n int, err error) { return } +// Bytes returns the total number of bytes read. func (r *CounterReader) Bytes() int64 { return r.c.Get() } +// CounterWriter wraps an io.Writer to count the number of bytes written. type CounterWriter struct { - w io.Writer - c *Counter + w io.Writer // Underlying writer + c *Counter // Counter for bytes written } +// CountWriter creates an io.Writer that counts bytes written to w, using the provided [Counter]. func CountWriter(w io.Writer, c *Counter) io.Writer { return NewCounterWriter(w, c) } +// NewCounterWriter creates a [CounterWriter] that counts bytes written to w, using the provided [Counter]. +// If c is nil, a new Counter is created. func NewCounterWriter(w io.Writer, c *Counter) *CounterWriter { if c == nil { c = new(Counter) @@ -61,6 +75,8 @@ func NewCounterWriter(w io.Writer, c *Counter) *CounterWriter { return &CounterWriter{w, c} } +// Write writes to the underlying Writer and increments the counter by the number of bytes written. +// It returns the number of bytes written and any error encountered. func (w *CounterWriter) Write(p []byte) (n int, err error) { n, err = w.w.Write(p) if n > 0 { @@ -69,6 +85,7 @@ func (w *CounterWriter) Write(p []byte) (n int, err error) { return } +// Bytes returns the total number of bytes written. func (w *CounterWriter) Bytes() int64 { return w.c.Get() } diff --git a/counter/listener.go b/counter/listener.go index 673c1ab..11511a2 100644 --- a/counter/listener.go +++ b/counter/listener.go @@ -10,16 +10,20 @@ var ( _ net.Conn = &conn{} ) +// Listener wraps a net.Listener to count bytes read and written across all connections. type Listener struct { net.Listener - readBytes Counter - writeBytes Counter + readBytes Counter // Counter for bytes read across all connections + writeBytes Counter // Counter for bytes written across all connections } +// NewListener creates a Listener that counts bytes read and written across all connections. func NewListener(listener net.Listener) *Listener { return &Listener{Listener: listener} } +// Accept accepts a connection and wraps it with byte counting for reads and writes. +// It returns the wrapped connection or an error if the accept fails. func (l *Listener) Accept() (net.Conn, error) { c, err := l.Listener.Accept() if err != nil { @@ -32,19 +36,31 @@ func (l *Listener) Accept() (net.Conn, error) { }, nil } +// ReadBytes returns the total number of bytes read across all connections. func (l *Listener) ReadBytes() int64 { return l.readBytes.Get() } +// WriteBytes returns the total number of bytes written across all connections. func (l *Listener) WriteBytes() int64 { return l.writeBytes.Get() } +// conn wraps a net.Conn to count bytes read and written. type conn struct { net.Conn - r io.Reader - w io.Writer + r io.Reader // Reader that counts bytes read + w io.Writer // Writer that counts bytes written } -func (c *conn) Read(b []byte) (int, error) { return c.r.Read(b) } -func (c *conn) Write(b []byte) (int, error) { return c.w.Write(b) } +// Read reads from the underlying Reader and counts the bytes read. +// It returns the number of bytes read and any error encountered. +func (c *conn) Read(b []byte) (int, error) { + return c.r.Read(b) +} + +// Write writes to the underlying Writer and counts the bytes written. +// It returns the number of bytes written and any error encountered. +func (c *conn) Write(b []byte) (int, error) { + return c.w.Write(b) +} From 6140a0c132691226cc07b7e190962fbb44114604 Mon Sep 17 00:00:00 2001 From: sunshineplan Date: Fri, 10 Oct 2025 16:56:32 +0800 Subject: [PATCH 21/40] pool --- pool/pool.go | 36 ++++++++++++++++++++++++++++-------- 1 file changed, 28 insertions(+), 8 deletions(-) diff --git a/pool/pool.go b/pool/pool.go index c202780..057c3ce 100644 --- a/pool/pool.go +++ b/pool/pool.go @@ -2,22 +2,42 @@ package pool import "sync" +// Pool is a type-safe wrapper around [sync.Pool] that provides +// a generic, concurrency-safe object pool for any type T. type Pool[T any] struct { - p *sync.Pool + pool sync.Pool + // New optionally specifies a function to generate + // a value when Get would otherwise return nil. + // It may not be changed concurrently with calls to Get. + New func() *T } +// New creates a new Pool that allocates new zero-valued objects +// of type T when none are available. func New[T any]() *Pool[T] { - return NewFunc(func() *T { return new(T) }) -} - -func NewFunc[T any](fn func() *T) *Pool[T] { - return &Pool[T]{&sync.Pool{New: func() any { return fn() }}} + return &Pool[T]{New: func() *T { return new(T) }} } +// Get selects an arbitrary item from the [Pool], removes it from the +// Pool, and returns it to the caller. +// Get may choose to ignore the pool and treat it as empty. +// Callers should not assume any relation between values passed to [Pool.Put] and +// the values returned by Get. +// +// If Get would otherwise return nil and p.New is non-nil, Get returns +// the result of calling p.New. func (p *Pool[T]) Get() *T { - return p.p.Get().(*T) + x := p.pool.Get() + if x == nil { + if p.New != nil { + return p.New() + } + return nil + } + return x.(*T) } +// Put adds x to the pool. func (p *Pool[T]) Put(x *T) { - p.p.Put(x) + p.pool.Put(x) } From 0be3286d77e0f20beac808b9e7afc5f9cd6d5cbc Mon Sep 17 00:00:00 2001 From: sunshineplan Date: Sat, 11 Oct 2025 11:17:53 +0800 Subject: [PATCH 22/40] pop3 --- pop3/pop3.go | 110 +++++++++++++++++++++++++++++++-------------------- 1 file changed, 67 insertions(+), 43 deletions(-) diff --git a/pop3/pop3.go b/pop3/pop3.go index a799e27..9029584 100644 --- a/pop3/pop3.go +++ b/pop3/pop3.go @@ -5,6 +5,7 @@ import ( "crypto/tls" "errors" "fmt" + "io" "log/slog" "net" "net/textproto" @@ -65,6 +66,14 @@ func NewClient(conn net.Conn) (*Client, error) { return c, nil } +func (c *Client) Auth(user, pass string) error { + if _, err := c.Cmd("USER %s", false, user); err != nil { + return err + } + _, err := c.Cmd("PASS %s", false, pass) + return err +} + // Stat returns the number of messages and their total size in bytes in the inbox. func (c *Client) Stat() (count int, size int, err error) { s, err := c.Cmd("STAT", false) @@ -74,6 +83,9 @@ func (c *Client) Stat() (count int, size int, err error) { // count size f := strings.Fields(s) + if len(f) < 2 { + return 0, 0, fmt.Errorf("invalid STAT response: %q", s) + } // Total number of messages. count, err = strconv.Atoi(f[0]) @@ -100,6 +112,26 @@ type MessageID struct { UID string } +func (c *Client) multiList(cmd string, parse func([]string) (MessageID, error)) ([]MessageID, error) { + s, err := c.Cmd(cmd, true) + if err != nil { + return nil, err + } + var out []MessageID + for _, line := range strings.Split(s, lineBreak) { + f := strings.Fields(line) + if len(f) == 0 { + continue + } + id, err := parse(f) + if err != nil { + return nil, err + } + out = append(out, id) + } + return out, nil +} + // List returns a list of (message ID, message Size) pairs. // If the optional id > 0, then only that particular message is listed. // The message IDs are sequential, 1 to N. @@ -120,31 +152,23 @@ func (c *Client) List(id int) ([]MessageID, error) { return nil, err } - var ( - out []MessageID - lines = strings.Split(s, lineBreak) - ) - - for _, l := range lines { + var out []MessageID + for l := range strings.SplitSeq(s, lineBreak) { // id size f := strings.Fields(l) if len(f) == 0 { - break + continue } - id, err := strconv.Atoi(f[0]) if err != nil { return nil, err } - size, err := strconv.Atoi(f[1]) if err != nil { return nil, err } - out = append(out, MessageID{ID: id, Size: size}) } - return out, nil } @@ -168,26 +192,19 @@ func (c *Client) Uidl(id int) ([]MessageID, error) { return nil, err } - var ( - out []MessageID - lines = strings.Split(s, lineBreak) - ) - - for _, l := range lines { + var out []MessageID + for l := range strings.SplitSeq(s, lineBreak) { // id uid f := strings.Fields(l) if len(f) == 0 { - break + continue } - id, err := strconv.Atoi(f[0]) if err != nil { return nil, err } - out = append(out, MessageID{ID: id, UID: f[1]}) } - return out, nil } @@ -231,6 +248,7 @@ func (c *Client) Noop() error { // quit and close. func (c *Client) Quit() error { if _, err := c.Cmd("QUIT", false); err != nil { + c.Close() return err } return c.Close() @@ -256,39 +274,45 @@ func (c *Client) Cmd(s string, isMulti bool, args ...any) (string, error) { return s, nil } - var res string + var b strings.Builder for { s, err := c.ReadLine() if err != nil { + if err == io.EOF { + err = io.ErrUnexpectedEOF + } return "", err } slog.Debug("<<< " + s) - if s == "." { - break + // Dot by itself marks end; otherwise cut one dot. + if len(s) > 0 && s[0] == '.' { + if len(s) == 1 { + break + } + s = s[1:] } - - res += s + lineBreak + b.WriteString(s) + b.WriteString(lineBreak) } - - return res, nil + return b.String(), nil } func parseResp(s string) (string, error) { - if len(s) == 0 { + switch s { + case "", respOK: return "", nil - } - - if s == respOK { - return "", nil - } else if strings.HasPrefix(s, respOKInfo) { - return strings.TrimPrefix(s, respOKInfo), nil - } else if s == respErr { - return "", errors.New("unknown error (no info specified in response)") - } else if strings.HasPrefix(s, respErrInfo) { - return "", errors.New(strings.TrimPrefix(s, respErrInfo)) - } else if strings.HasPrefix(s, respContinue) { - return strings.TrimPrefix(s, respContinue), nil - } else { - return "", fmt.Errorf("unknown response: %q", s) + case respErr: + return "", errors.New("server returned -ERR without info") + default: + switch { + case strings.HasPrefix(s, respOKInfo): + return strings.TrimPrefix(s, respOKInfo), nil + case strings.HasPrefix(s, respErrInfo): + return "", errors.New(strings.TrimPrefix(s, respErrInfo)) + case strings.HasPrefix(s, respContinue): + return strings.TrimPrefix(s, respContinue), nil + default: + return "", fmt.Errorf("unknown response: %q", s) + } } } From a77926828fc544449ef45f7ae1c47fbb925ed485 Mon Sep 17 00:00:00 2001 From: sunshineplan Date: Sat, 11 Oct 2025 11:25:29 +0800 Subject: [PATCH 23/40] Update pop3.go --- pop3/pop3.go | 100 +++++++++++++++++++++------------------------------ 1 file changed, 41 insertions(+), 59 deletions(-) diff --git a/pop3/pop3.go b/pop3/pop3.go index 9029584..c9f3a13 100644 --- a/pop3/pop3.go +++ b/pop3/pop3.go @@ -13,6 +13,8 @@ import ( "strings" ) +// Client represents a POP3 client connection. +// It embeds a textproto.Conn for low-level protocol communication. type Client struct { *textproto.Conn } @@ -27,6 +29,8 @@ const ( respContinue = "+ " ) +// Dial establishes a plain TCP connection to the POP3 server at the given address. +// The connection respects the provided context for timeout or cancellation. func Dial(ctx context.Context, addr string) (*Client, error) { var d net.Dialer conn, err := d.DialContext(ctx, "tcp", addr) @@ -36,6 +40,8 @@ func Dial(ctx context.Context, addr string) (*Client, error) { return NewClient(conn) } +// DialTLS establishes a secure POP3-over-TLS connection to the given address. +// The server name is automatically derived from the address for certificate verification. func DialTLS(ctx context.Context, addr string) (*Client, error) { host, _, err := net.SplitHostPort(addr) if err != nil { @@ -50,6 +56,8 @@ func DialTLS(ctx context.Context, addr string) (*Client, error) { return NewClient(conn) } +// NewClient initializes a POP3 client from an existing connection. +// It reads the server greeting line and validates that it starts with "+OK". func NewClient(conn net.Conn) (*Client, error) { c := &Client{textproto.NewConn(conn)} s, err := c.ReadLine() @@ -66,6 +74,8 @@ func NewClient(conn net.Conn) (*Client, error) { return c, nil } +// Auth authenticates the user using the USER/PASS commands. +// Returns an error if either step fails. func (c *Client) Auth(user, pass string) error { if _, err := c.Cmd("USER %s", false, user); err != nil { return err @@ -74,20 +84,16 @@ func (c *Client) Auth(user, pass string) error { return err } -// Stat returns the number of messages and their total size in bytes in the inbox. +// Stat returns the number of messages and the total mailbox size (in bytes). func (c *Client) Stat() (count int, size int, err error) { s, err := c.Cmd("STAT", false) if err != nil { return } - - // count size f := strings.Fields(s) if len(f) < 2 { return 0, 0, fmt.Errorf("invalid STAT response: %q", s) } - - // Total number of messages. count, err = strconv.Atoi(f[0]) if err != nil { return @@ -95,46 +101,21 @@ func (c *Client) Stat() (count int, size int, err error) { if count == 0 { return } - - // Total size of all messages in bytes. size, err = strconv.Atoi(f[1]) - return } -// MessageID contains the ID and size of an individual message. +// MessageID represents a single message entry as returned by LIST or UIDL. +// It includes the message index, size, and optional UID. type MessageID struct { - // ID is the numerical index (non-unique) of the message. - ID int - Size int - - // UID is only present if the response is to the UIDL command. - UID string -} - -func (c *Client) multiList(cmd string, parse func([]string) (MessageID, error)) ([]MessageID, error) { - s, err := c.Cmd(cmd, true) - if err != nil { - return nil, err - } - var out []MessageID - for _, line := range strings.Split(s, lineBreak) { - f := strings.Fields(line) - if len(f) == 0 { - continue - } - id, err := parse(f) - if err != nil { - return nil, err - } - out = append(out, id) - } - return out, nil + ID int // Numerical message index (1-based) + Size int // Message size in bytes + UID string // Optional UID (only for UIDL command) } -// List returns a list of (message ID, message Size) pairs. -// If the optional id > 0, then only that particular message is listed. -// The message IDs are sequential, 1 to N. +// List returns message IDs and sizes from the mailbox. +// If id > 0, only that specific message is listed (single-line response). +// If id == 0, all messages are listed (multi-line response). func (c *Client) List(id int) ([]MessageID, error) { var ( s string @@ -142,10 +123,8 @@ func (c *Client) List(id int) ([]MessageID, error) { ) if id > 0 { - // Single line response listing one message. s, err = c.Cmd("LIST %d", false, id) } else { - // Multiline response listing all messages. s, err = c.Cmd("LIST", true) } if err != nil { @@ -154,7 +133,6 @@ func (c *Client) List(id int) ([]MessageID, error) { var out []MessageID for l := range strings.SplitSeq(s, lineBreak) { - // id size f := strings.Fields(l) if len(f) == 0 { continue @@ -172,9 +150,9 @@ func (c *Client) List(id int) ([]MessageID, error) { return out, nil } -// Uidl returns a list of (message ID, message UID) pairs. If the optional msgID -// is > 0, then only that particular message is listed. It works like Top() but only works on -// servers that support the UIDL command. Messages size field is not available in the UIDL response. +// Uidl returns message IDs and their unique identifiers (UIDs). +// If id > 0, only that specific message is listed. +// The UIDL command may not be supported by all servers. func (c *Client) Uidl(id int) ([]MessageID, error) { var ( s string @@ -182,10 +160,8 @@ func (c *Client) Uidl(id int) ([]MessageID, error) { ) if id > 0 { - // Single line response listing one message. s, err = c.Cmd("UIDL %d", false, id) } else { - // Multiline response listing all messages. s, err = c.Cmd("UIDL", true) } if err != nil { @@ -194,7 +170,6 @@ func (c *Client) Uidl(id int) ([]MessageID, error) { var out []MessageID for l := range strings.SplitSeq(s, lineBreak) { - // id uid f := strings.Fields(l) if len(f) == 0 { continue @@ -208,19 +183,18 @@ func (c *Client) Uidl(id int) ([]MessageID, error) { return out, nil } -// Retr downloads a message by the given id and returns the data -// of the entire message. +// Retr retrieves the full message text by ID, including headers and body. func (c *Client) Retr(id int) (string, error) { return c.Cmd("RETR %d", true, id) } -// Top retrieves a message by its ID with full headers and numLines lines of the body. +// Top retrieves message headers and the first numLines of the body. func (c *Client) Top(id int, numLines int) (string, error) { return c.Cmd("TOP %d %d", true, id, numLines) } -// Dele deletes one or more messages. The server only executes the -// deletions after a successful Quit(). +// Dele marks one or more messages for deletion. +// The deletions are finalized only after a successful Quit(). func (c *Client) Dele(ids ...int) error { for _, id := range ids { if _, err := c.Cmd("DELE %d", false, id); err != nil { @@ -230,22 +204,21 @@ func (c *Client) Dele(ids ...int) error { return nil } -// Rset clears the messages marked for deletion in the current session. +// Rset resets the deletion marks on all messages in the current session. func (c *Client) Rset() error { _, err := c.Cmd("RSET", false) return err } -// Noop issues a do-nothing NOOP command to the server. This is useful for -// prolonging open connections. +// Noop sends a NOOP command to keep the connection alive. +// Useful for preventing idle timeouts. func (c *Client) Noop() error { _, err := c.Cmd("NOOP", false) return err } -// Quit sends the QUIT command to server and gracefully closes the connection. -// Message deletions (DELE command) are only excuted by the server on a graceful -// quit and close. +// Quit sends the QUIT command and closes the connection gracefully. +// Deletions are committed only if QUIT succeeds. func (c *Client) Quit() error { if _, err := c.Cmd("QUIT", false); err != nil { c.Close() @@ -254,6 +227,9 @@ func (c *Client) Quit() error { return c.Close() } +// Cmd sends a POP3 command with optional arguments and reads the response. +// If isMulti is true, the response is treated as multi-line and read until +// a line containing only "." is encountered. All lines are concatenated and returned. func (c *Client) Cmd(s string, isMulti bool, args ...any) (string, error) { slog.Debug(">>> " + fmt.Sprintf(s, args...)) if _, err := c.Conn.Cmd(s, args...); err != nil { @@ -284,7 +260,9 @@ func (c *Client) Cmd(s string, isMulti bool, args ...any) (string, error) { return "", err } slog.Debug("<<< " + s) - // Dot by itself marks end; otherwise cut one dot. + + // A single dot line marks the end of multi-line response. + // Lines beginning with a dot have one dot removed as per POP3 spec. if len(s) > 0 && s[0] == '.' { if len(s) == 1 { break @@ -297,6 +275,9 @@ func (c *Client) Cmd(s string, isMulti bool, args ...any) (string, error) { return b.String(), nil } +// parseResp interprets a single-line POP3 response. +// It distinguishes between +OK, -ERR, and continuation ("+ ") responses, +// returning the response message or an error if the response is invalid. func parseResp(s string) (string, error) { switch s { case "", respOK: @@ -310,6 +291,7 @@ func parseResp(s string) (string, error) { case strings.HasPrefix(s, respErrInfo): return "", errors.New(strings.TrimPrefix(s, respErrInfo)) case strings.HasPrefix(s, respContinue): + // Some servers send "+ " for continuation prompts (rare in simple POP3). return strings.TrimPrefix(s, respContinue), nil default: return "", fmt.Errorf("unknown response: %q", s) From a5d69326630193a8921aa43aa3dd28da38f6f0e3 Mon Sep 17 00:00:00 2001 From: sunshineplan Date: Sat, 11 Oct 2025 13:02:54 +0800 Subject: [PATCH 24/40] retry --- retry/retry.go | 43 +++++++++++++++++-------------------------- retry/retry_test.go | 17 +++++++++-------- 2 files changed, 26 insertions(+), 34 deletions(-) diff --git a/retry/retry.go b/retry/retry.go index 536f9dc..94b75b4 100644 --- a/retry/retry.go +++ b/retry/retry.go @@ -2,39 +2,29 @@ package retry import ( "errors" + "fmt" "time" ) -var errNoMoreRetry error = errorNoMoreRetry("no more retry") +// ErrNoMoreRetry is a sentinel error indicating that no more retries should be performed. +var ErrNoMoreRetry = errors.New("no more retry") -type errorNoMoreRetry string - -func (err errorNoMoreRetry) Error() string { - return string(err) -} - -func (errorNoMoreRetry) Unwrap() error { - return errNoMoreRetry +// StopRetry creates a wrapped error indicating that retries should stop. +func StopRetry(msg string) error { + return fmt.Errorf("%w: %s", ErrNoMoreRetry, msg) } -// IsNoMoreRetry reports whether error is NoMoreRetry error. +// IsNoMoreRetry reports whether the given error indicates to stop retrying. func IsNoMoreRetry(err error) bool { - if e, ok := err.(interface{ Unwrap() []error }); ok { - for _, err := range e.Unwrap() { - if IsNoMoreRetry(err) { - return true - } - } - return false - } - return errors.Is(err, errNoMoreRetry) + return errors.Is(err, ErrNoMoreRetry) } -// ErrNoMoreRetry tells function does no more retry. -func ErrNoMoreRetry(err string) error { return errorNoMoreRetry(err) } - -// Do keeps retrying the function until no error is returned. -func Do(fn func() error, attempts, delay int) error { +// Do executes fn repeatedly until it succeeds, the attempts are exhausted, +// or fn returns an error that indicates no more retries. +func Do(fn func() error, attempts int, delay time.Duration) error { + if attempts <= 0 { + return errors.New("invalid attempts count") + } var errs []error for i := range attempts { err := fn() @@ -44,8 +34,9 @@ func Do(fn func() error, attempts, delay int) error { errs = append(errs, err) if IsNoMoreRetry(err) { break - } else if i < attempts-1 { - time.Sleep(time.Second * time.Duration(delay)) + } + if i < attempts-1 { + time.Sleep(delay) } } return errors.Join(errs...) diff --git a/retry/retry_test.go b/retry/retry_test.go index 722b5bd..294cb2a 100644 --- a/retry/retry_test.go +++ b/retry/retry_test.go @@ -4,17 +4,18 @@ import ( "errors" "strconv" "testing" + "time" ) func TestRetry(t *testing.T) { - if err := ErrNoMoreRetry("error"); !errors.Is(err, errNoMoreRetry) { - t.Error("expected err is errNoMoreRetry; got not") + if err := StopRetry("error"); !errors.Is(err, ErrNoMoreRetry) { + t.Error("expected err is ErrNoMoreRetry; got not") } var i int if err := Do(func() error { defer func() { i++ }() return nil - }, 3, 1); err != nil { + }, 3, time.Second); err != nil { t.Errorf("expected nil error; got non-nil error %v", err) } else if i != 1 { t.Errorf("expected 1; got %d", i) @@ -24,7 +25,7 @@ func TestRetry(t *testing.T) { if err := Do(func() error { defer func() { i++ }() return errors.New("error" + strconv.Itoa(i)) - }, 3, 1); err == nil { + }, 3, time.Second); err == nil { t.Error("expected non-nil error; got nil error") } else if expect := "error0\nerror1\nerror2"; err.Error() != expect { t.Errorf("expected %s; got %s", expect, err) @@ -33,10 +34,10 @@ func TestRetry(t *testing.T) { i = 0 if err := Do(func() error { defer func() { i++ }() - return ErrNoMoreRetry("error" + strconv.Itoa(i)) - }, 3, 1); !IsNoMoreRetry(err) { + return StopRetry("error" + strconv.Itoa(i)) + }, 3, time.Second); !IsNoMoreRetry(err) { t.Errorf("expected ErrNoMoreRetry; got %s", err) - } else if err.Error() != "error0" { - t.Errorf("expected error0; got %d", i) + } else if err.Error() != "no more retry: error0" { + t.Errorf("expected error0; got %s", err) } } From 92b27ae499e3cf36e4f72926694d271b78eb196f Mon Sep 17 00:00:00 2001 From: sunshineplan Date: Mon, 13 Oct 2025 14:41:46 +0800 Subject: [PATCH 25/40] csv --- csv/convert.go | 133 ++++++++++++++++------------ csv/export.go | 7 +- csv/export_test.go | 10 --- csv/reader.go | 38 ++++---- csv/reader_test.go | 10 ++- csv/types.go | 10 --- csv/writer.go | 210 +++++++++++++++++++++------------------------ csv/writer_test.go | 24 ------ 8 files changed, 203 insertions(+), 239 deletions(-) delete mode 100644 csv/types.go diff --git a/csv/convert.go b/csv/convert.go index 5b7d622..e5173cc 100644 --- a/csv/convert.go +++ b/csv/convert.go @@ -7,6 +7,8 @@ import ( "fmt" "reflect" "strconv" + + "github.com/sunshineplan/utils/container" ) var ( @@ -40,56 +42,88 @@ func strconvErr(err error) error { return err } +var convertMap container.Map[reflect.Type, func(reflect.Type, string) (reflect.Value, error)] + func convert(t reflect.Type, s string) (reflect.Value, error) { - if t.Kind() == reflect.Ptr { + if t.Kind() == reflect.Pointer || t.Kind() == reflect.Interface { return convert(t.Elem(), s) } - v := reflect.Indirect(reflect.New(t)) + fn, ok := convertMap.Load(t) + if ok { + return fn(t, s) + } + v := reflect.New(t).Elem() if v.CanInt() { - n, err := strconv.ParseInt(s, 10, t.Bits()) - if err != nil { - return v, fmt.Errorf("converting type String %q to a %s: %v", s, t.Kind(), strconvErr(err)) + fn = func(t reflect.Type, s string) (reflect.Value, error) { + v := reflect.New(t).Elem() + n, err := strconv.ParseInt(s, 10, t.Bits()) + if err != nil { + return reflect.Value{}, fmt.Errorf("converting type String %q to a %s: %v", s, t.Kind(), strconvErr(err)) + } + v.SetInt(n) + return v, nil } - v.SetInt(n) - return v, nil } else if v.CanUint() { - u, err := strconv.ParseUint(s, 10, t.Bits()) - if err != nil { - return v, fmt.Errorf("converting type String %q to a %s: %v", s, t.Kind(), strconvErr(err)) + fn = func(t reflect.Type, s string) (reflect.Value, error) { + v := reflect.New(t).Elem() + u, err := strconv.ParseUint(s, 10, t.Bits()) + if err != nil { + return reflect.Value{}, fmt.Errorf("converting type String %q to a %s: %v", s, t.Kind(), strconvErr(err)) + } + v.SetUint(u) + return v, nil } - v.SetUint(u) - return v, nil } else if v.CanFloat() { - f, err := strconv.ParseFloat(s, t.Bits()) - if err != nil { - return v, fmt.Errorf("converting type String %q to a %s: %v", s, t.Kind(), strconvErr(err)) + fn = func(t reflect.Type, s string) (reflect.Value, error) { + v := reflect.New(t).Elem() + f, err := strconv.ParseFloat(s, t.Bits()) + if err != nil { + return reflect.Value{}, fmt.Errorf("converting type String %q to a %s: %v", s, t.Kind(), strconvErr(err)) + } + v.SetFloat(f) + return v, nil } - v.SetFloat(f) - return v, nil } else if v.CanComplex() { - c, err := strconv.ParseComplex(s, t.Bits()) - if err != nil { - return v, fmt.Errorf("converting type String %q to a %s: %v", s, t.Kind(), strconvErr(err)) - } - v.SetComplex(c) - return v, nil - } - switch t.Kind() { - case reflect.String, reflect.Interface: - v.SetString(s) - return v, nil - case reflect.Bool: - b, err := strconv.ParseBool(s) - if err == nil { - return v, fmt.Errorf("converting type String %q to a %s: %v", s, t.Kind(), err) - } - v.SetBool(b) - return v, nil - case reflect.Slice: - if t.Elem().Kind() == reflect.Uint8 { - v.SetBytes([]byte(s)) + fn = func(t reflect.Type, s string) (reflect.Value, error) { + v := reflect.New(t).Elem() + c, err := strconv.ParseComplex(s, t.Bits()) + if err != nil { + return reflect.Value{}, fmt.Errorf("converting type String %q to a %s: %v", s, t.Kind(), strconvErr(err)) + } + v.SetComplex(c) return v, nil } + } else { + switch t.Kind() { + case reflect.String: + fn = func(t reflect.Type, s string) (reflect.Value, error) { + v := reflect.New(t).Elem() + v.SetString(s) + return v, nil + } + case reflect.Bool: + fn = func(t reflect.Type, s string) (reflect.Value, error) { + v := reflect.New(t).Elem() + b, err := strconv.ParseBool(s) + if err == nil { + return reflect.Value{}, fmt.Errorf("converting type String %q to a %s: %v", s, t.Kind(), err) + } + v.SetBool(b) + return v, nil + } + case reflect.Slice: + if t.Elem().Kind() == reflect.Uint8 { + fn = func(t reflect.Type, s string) (reflect.Value, error) { + v := reflect.New(t).Elem() + v.SetBytes([]byte(s)) + return v, nil + } + } + } + } + if fn != nil { + convertMap.Store(t, fn) + return fn(t, s) } return reflect.Value{}, errNotSet } @@ -142,24 +176,14 @@ func setCell(dest any, s string) error { } return d.UnmarshalText([]byte(s)) } - dpv := reflect.ValueOf(dest) - if dpv.Kind() != reflect.Ptr { + if dpv.Kind() != reflect.Pointer { return errors.New("destination not a pointer") } if dpv.IsNil() { return errNilPtr } - dv := reflect.Indirect(dpv) - if v := reflect.ValueOf(s); v.Type().AssignableTo(dv.Type()) { - dv.Set(v) - return nil - } - if dv.Kind() == reflect.Pointer { - dv.Set(reflect.New(dv.Type().Elem())) - return setCell(dv.Interface(), s) - } - + dv := dpv.Elem() if v, err := convert(dv.Type(), s); err != nil { if err == errNotSet { if err := json.Unmarshal([]byte(strconv.Quote(s)), dest); err != nil { @@ -190,26 +214,23 @@ func setRow(dest any, m map[string]string) error { b, _ := json.Marshal(m) return d.UnmarshalJSON(b) } - dpv := reflect.ValueOf(dest) - if dpv.Kind() != reflect.Ptr { + if dpv.Kind() != reflect.Pointer { return errors.New("destination not a pointer") } if dpv.IsNil() { return errNilPtr } - - dv := reflect.Indirect(dpv) + dv := dpv.Elem() if v := reflect.ValueOf(m); v.Type().AssignableTo(dv.Type()) { dv.Set(v) return nil } - if len(m) == 0 { return nil } switch dv.Kind() { - case reflect.Ptr: + case reflect.Pointer: dv.Set(reflect.New(dv.Type().Elem())) return setRow(dv.Interface(), m) case reflect.Map: diff --git a/csv/export.go b/csv/export.go index d700802..3bdd854 100644 --- a/csv/export.go +++ b/csv/export.go @@ -17,7 +17,7 @@ func ExportFile[S ~[]E, E any](fieldnames []string, slice S, file string) error if err != nil { return err } - + defer f.Close() return export(fieldnames, slice, f, false) } @@ -32,13 +32,13 @@ func ExportUTF8File[S ~[]E, E any](fieldnames []string, slice S, file string) er if err != nil { return err } - + defer f.Close() return export(fieldnames, slice, f, true) } func export[S ~[]E, E any](fieldnames []string, slice S, w io.Writer, utf8bom bool) (err error) { csvWriter := NewWriter(w, utf8bom) - if fieldnames == nil { + if len(fieldnames) == 0 { if len(slice) == 0 { return fmt.Errorf("can't get struct fieldnames from zero length slice") } @@ -49,6 +49,5 @@ func export[S ~[]E, E any](fieldnames []string, slice S, w io.Writer, utf8bom bo if err != nil { return } - return csvWriter.WriteAll(slice) } diff --git a/csv/export_test.go b/csv/export_test.go index 97c97a0..da27fbd 100644 --- a/csv/export_test.go +++ b/csv/export_test.go @@ -67,16 +67,6 @@ aa, fieldnames: nil, slice: []*test{{A: "a", B: "b"}, nil, {A: "aa", B: nil}}, }, result) - testExport(t, testcase[D]{ - name: "D slice", - fieldnames: []string{"A", "B"}, - slice: []D{{{"A", "a"}, {"B", "b"}}, {{"A", "aa"}, {"B", nil}}}, - }, result) - testExport(t, testcase[D]{ - name: "D slice without fieldnames", - fieldnames: nil, - slice: []D{{{"A", "a"}, {"B", "b"}}, {{"A", "aa"}, {"B", nil}}}, - }, result) testExport(t, testcase[any]{ name: "interface slice", fieldnames: []string{"A", "B"}, diff --git a/csv/reader.go b/csv/reader.go index 41479b3..94bb7f3 100644 --- a/csv/reader.go +++ b/csv/reader.go @@ -22,20 +22,19 @@ type Reader struct { } // NewReader returns a new Reader that reads from r. -func NewReader(r io.Reader, hasFields bool) *Reader { +func NewReader(r io.Reader, hasFields bool) (*Reader, error) { reader := &Reader{Reader: csv.NewReader(r)} if closer, ok := r.(io.Closer); ok { reader.closer = closer } - if hasFields { var err error reader.fields, err = reader.Read() if err != nil { - panic(err) + return nil, err } } - return reader + return reader, nil } // ReadFile returns Reader reads from file. @@ -44,7 +43,12 @@ func ReadFile(file string, hasFields bool) (*Reader, error) { if err != nil { return nil, err } - return NewReader(f, hasFields), nil + reader, err := NewReader(f, hasFields) + if err != nil { + f.Close() + return nil, err + } + return reader, nil } func (r *Reader) Read() (record []string, err error) { @@ -76,14 +80,12 @@ func (r *Reader) Scan(dest ...any) error { if r.next == nil && r.nextErr == nil { return fmt.Errorf("Scan called without calling Next") } - if r.nextErr != nil { return r.nextErr } if len(dest) != len(r.next) { return fmt.Errorf("expected %d destination arguments in Scan, not %d", len(r.next), len(dest)) } - for i, v := range r.next { if err := setCell(dest[i], v); err != nil { return fmt.Errorf("Scan error on field index %d: %v", i, err) @@ -104,7 +106,6 @@ func (r *Reader) Decode(dest any) error { if r.nextErr != nil { return r.nextErr } - m := make(map[string]string) for i, field := range r.fields { if len(r.next) > i { @@ -123,25 +124,19 @@ func (r *Reader) Close() error { // DecodeAll decodes each record from r into dest. func DecodeAll[S ~[]E, E any](r io.Reader, dest *S) (err error) { - defer func() { - if e := recover(); e != nil { - err = fmt.Errorf("%v", e) - } - }() - - reader := NewReader(r, true) - defer reader.Close() - - var res S + reader, err := NewReader(r, true) + if err != nil { + return + } + *dest = nil for reader.Next() { var t E if err = reader.Decode(&t); err != nil { + *dest = nil return } - res = append(res, t) + *dest = append(*dest, t) } - *dest = res - return } @@ -151,5 +146,6 @@ func DecodeFile[S ~[]E, E any](file string, dest *S) error { if err != nil { return err } + defer f.Close() return DecodeAll(f, dest) } diff --git a/csv/reader_test.go b/csv/reader_test.go index 44895f5..874cbad 100644 --- a/csv/reader_test.go +++ b/csv/reader_test.go @@ -19,7 +19,10 @@ test a,1,"[1,2]" b,2,"[3,4]" ` - r := NewReader(strings.NewReader(csv), true) + r, err := NewReader(strings.NewReader(csv), true) + if err != nil { + t.Fatal(err) + } if expect := []string{"A", "B", "C"}; !slices.Equal(expect, r.fields) { t.Errorf("expected %v; got %v", expect, r.fields) } @@ -35,7 +38,10 @@ b,2,"[3,4]" t.Errorf("expected %v; got %v", expect, res1) } - r = NewReader(strings.NewReader(csv), true) + r, err = NewReader(strings.NewReader(csv), true) + if err != nil { + t.Fatal(err) + } var res2 []map[string]string for r.Next() { var res map[string]string diff --git a/csv/types.go b/csv/types.go deleted file mode 100644 index abff346..0000000 --- a/csv/types.go +++ /dev/null @@ -1,10 +0,0 @@ -package csv - -// D is an ordered representation of a document. This type should be used when the order of the elements matter. -type D []E - -// E represents a element for a D. It is usually used inside a D. -type E struct { - Key string - Value any -} diff --git a/csv/writer.go b/csv/writer.go index a2a3236..ecbb1da 100644 --- a/csv/writer.go +++ b/csv/writer.go @@ -6,6 +6,8 @@ import ( "io" "reflect" "slices" + + "github.com/sunshineplan/utils/pool" ) var utf8bom = []byte{0xEF, 0xBB, 0xBF} @@ -17,6 +19,8 @@ type Writer struct { utf8bom bool fields []field fieldsWritten bool + zero []string + pool *pool.Pool[[]string] } type field struct { @@ -29,13 +33,10 @@ func NewWriter(w io.Writer, utf8bom bool) *Writer { Writer: csv.NewWriter(w), w: w, utf8bom: utf8bom, + pool: pool.New[[]string](), } } -func (w *Writer) SkipWriteFields() { - w.fieldsWritten = true -} - // WriteFields writes fieldnames to w along with necessary utf8bom bytes. The fields must be a // non-zero field struct or a non-zero length string slice, otherwise an error will be return. // It can be run only once. @@ -43,51 +44,37 @@ func (w *Writer) WriteFields(fields any) error { if w.fieldsWritten { return fmt.Errorf("fieldnames already be written") } - - v := reflect.ValueOf(fields) - if v.Kind() == reflect.Ptr { - v = reflect.Indirect(v) - if !v.IsValid() { - return fmt.Errorf("can not get fieldnames from nil pointer struct") - } - } - switch v.Kind() { - case reflect.Struct: - if v.NumField() == 0 { - return fmt.Errorf("can not get fieldnames from zero field struct") + switch f := fields.(type) { + case []string: + for _, i := range f { + w.fields = append(w.fields, field{i, ""}) } - - for i := range v.NumField() { - if f := v.Type().Field(i); f.IsExported() { - tag, _ := v.Type().Field(i).Tag.Lookup("csv") - w.fields = append(w.fields, field{v.Type().Field(i).Name, tag}) + default: + v := reflect.ValueOf(fields) + if v.Kind() == reflect.Pointer { + v = reflect.Indirect(v) + if !v.IsValid() { + return fmt.Errorf("can not get fieldnames from nil pointer struct") } } - case reflect.Slice: - if v.Len() == 0 { - return fmt.Errorf("can not get fieldnames from zero length slice") - } - - if fieldnames, ok := fields.([]string); ok { - for _, i := range fieldnames { - w.fields = append(w.fields, field{i, ""}) + switch v.Kind() { + case reflect.Struct: + if v.NumField() == 0 { + return fmt.Errorf("can not get fieldnames from zero field struct") } - } else if d, ok := fields.(D); ok { - for _, i := range d { - w.fields = append(w.fields, field{i.Key, ""}) + for i := range v.NumField() { + if f := v.Type().Field(i); f.IsExported() { + tag, _ := v.Type().Field(i).Tag.Lookup("csv") + w.fields = append(w.fields, field{v.Type().Field(i).Name, tag}) + } } - } else { - return fmt.Errorf("only can get fieldnames from slice which is string slice or csv.D") + default: + return fmt.Errorf("can not get fieldnames from fields which is not struct or string slice") } - - default: - return fmt.Errorf("can not get fieldnames from fields which is not struct or string slice or csv.D") } - if w.utf8bom { w.w.Write(utf8bom) } - var record []string for _, i := range w.fields { if i.tag != "" { @@ -96,18 +83,19 @@ func (w *Writer) WriteFields(fields any) error { record = append(record, i.name) } } - if err := w.Writer.Write(record); err != nil { return err } w.Flush() - if err := w.Error(); err != nil { return err } - w.fieldsWritten = true - + w.zero = make([]string, len(w.fields)) + w.pool.New = func() *[]string { + s := make([]string, len(w.fields)) + return &s + } return nil } @@ -118,94 +106,92 @@ func (w *Writer) Write(record any) error { if !w.fieldsWritten { return fmt.Errorf("fieldnames has not be written yet") } - - v := reflect.ValueOf(record) - if v.Kind() == reflect.Interface { - v = v.Elem() - } - if v.Kind() == reflect.Ptr { - v = reflect.Indirect(v) - if !v.IsValid() { + switch d := record.(type) { + case []string: + if len(d) == 0 { return nil } - } - - r := make([]string, len(w.fields)) - switch v.Kind() { - case reflect.Map: - if keyType := reflect.TypeOf(v.Interface()).Key(); keyType.Kind() == reflect.String { - for i, field := range w.fields { - key := reflect.Indirect(reflect.New(keyType)) - key.SetString(field.name) - if v := v.MapIndex(key); v.IsValid() && v.Interface() != nil { - r[i], _ = marshalText(v.Interface()) - } - } - } else { - return fmt.Errorf("only can write record from map which is string kind") + return w.Writer.Write(d) + default: + v := reflect.ValueOf(record) + if v.Kind() == reflect.Interface { + v = v.Elem() } - case reflect.Struct: - for i, field := range w.fields { - var val reflect.Value - var found bool - for i := range v.NumField() { - if tag, ok := v.Type().Field(i).Tag.Lookup("csv"); ok && tag == field.tag { - val = v.FieldByName(field.name) - found = true - break - } - } - if !found { - if val = v.FieldByName(field.name); val.IsValid() && val.Interface() != nil { - found = true - } - } - if found { - r[i], _ = marshalText(val.Interface()) + if v.Kind() == reflect.Pointer { + v = reflect.Indirect(v) + if !v.IsValid() { + return nil } } - case reflect.Slice: - if rec, ok := record.([]string); ok { - if len(rec) == 0 { - return nil + r := w.pool.Get() + defer w.pool.Put(r) + switch v.Kind() { + case reflect.Map: + if keyType := reflect.TypeOf(v.Interface()).Key(); keyType.Kind() == reflect.String { + for i, field := range w.fields { + key := reflect.Indirect(reflect.New(keyType)) + key.SetString(field.name) + if v := v.MapIndex(key); v.IsValid() && v.Interface() != nil { + (*r)[i], _ = marshalText(v.Interface()) + } else { + (*r)[i] = "" + } + } + } else { + return fmt.Errorf("only can write record from map which is string kind") } - return w.Writer.Write(rec) - } else if d, ok := record.(D); ok { + case reflect.Struct: for i, field := range w.fields { - for _, e := range d { - if field.name == e.Key { - r[i], _ = marshalText(e.Value) + var val reflect.Value + var found bool + for i := range v.NumField() { + if tag, ok := v.Type().Field(i).Tag.Lookup("csv"); ok && tag == field.tag { + val = v.FieldByName(field.name) + found = true break } } + if !found { + if val = v.FieldByName(field.name); val.IsValid() && val.Interface() != nil { + found = true + } + } + if found { + (*r)[i], _ = marshalText(val.Interface()) + } else { + (*r)[i] = "" + } } - break + default: + return fmt.Errorf("not support record format: %s", v.Kind()) } - fallthrough - default: - return fmt.Errorf("not support record format: %s", v.Kind()) - } - - if slices.Equal(r, make([]string, len(w.fields))) { - return nil + if slices.Equal(*r, w.zero) { + return nil + } + return w.Writer.Write(*r) } - - return w.Writer.Write(r) } // WriteAll writes multiple CSV records to w using Write and then calls Flush, returning any error from the Flush. func (w *Writer) WriteAll(records any) error { - if reflect.TypeOf(records).Kind() != reflect.Slice { - return fmt.Errorf("records is not slice") - } - - v := reflect.ValueOf(records) - for i := range v.Len() { - if err := w.Write(v.Index(i).Interface()); err != nil { - return err + switch s := records.(type) { + case [][]string: + for _, i := range s { + if err := w.Write(i); err != nil { + return err + } + } + default: + if reflect.TypeOf(records).Kind() != reflect.Slice { + return fmt.Errorf("records is not slice") + } + v := reflect.ValueOf(records) + for i := range v.Len() { + if err := w.Write(v.Index(i).Interface()); err != nil { + return err + } } } w.Flush() - return w.Error() } diff --git a/csv/writer_test.go b/csv/writer_test.go index 14305de..9ed617a 100644 --- a/csv/writer_test.go +++ b/csv/writer_test.go @@ -27,30 +27,6 @@ func TestWriteFields(t *testing.T) { } } -func TestSkipWriteFields(t *testing.T) { - var buf bytes.Buffer - w := NewWriter(&buf, false) - if err := w.WriteFields([]string{"A", "B"}); err != nil { - t.Fatal(err) - } - if err := w.Write([]string{"a", "b"}); err != nil { - t.Fatal(err) - } - w.Flush() - w = NewWriter(&buf, false) - if err := w.Write([]string{"c", "d"}); err == nil { - t.Fatal("gave nil error; want error") - } - w.SkipWriteFields() - if err := w.Write([]string{"c", "d"}); err != nil { - t.Fatal(err) - } - w.Flush() - if result, r := "A,B\na,b\nc,d\n", buf.String(); r != result { - t.Errorf("expected %q; got %q", result, r) - } -} - func TestWriter(t *testing.T) { result := `A|B a|b From 7110ae059cc4bafb0c3eda7108df0e803979ebbb Mon Sep 17 00:00:00 2001 From: sunshineplan Date: Tue, 14 Oct 2025 13:25:32 +0800 Subject: [PATCH 26/40] flags --- csv/reader.go | 10 +++++ flags/flags.go | 89 ++++++++++++++++++++++++++++++--------------- flags/flags_test.go | 16 ++++---- 3 files changed, 77 insertions(+), 38 deletions(-) diff --git a/csv/reader.go b/csv/reader.go index 94bb7f3..ac9bcae 100644 --- a/csv/reader.go +++ b/csv/reader.go @@ -51,6 +51,15 @@ func ReadFile(file string, hasFields bool) (*Reader, error) { return reader, nil } +// Read reads one record (a slice of fields) from r. +// If the record has an unexpected number of fields, +// Read returns the record along with the error [ErrFieldCount]. +// If the record contains a field that cannot be parsed, +// Read returns a partial record along with the parse error. +// The partial record contains all fields read before the error. +// If there is no data left to be read, Read returns nil, [io.EOF]. +// If [Reader.ReuseRecord] is true, the returned slice may be shared +// between multiple calls to Read. func (r *Reader) Read() (record []string, err error) { record, err = r.Reader.Read() if err == nil { @@ -115,6 +124,7 @@ func (r *Reader) Decode(dest any) error { return setRow(dest, m) } +// Close closes the underlying reader if it implements the io.Closer interface. func (r *Reader) Close() error { if r.closer != nil { return r.closer.Close() diff --git a/flags/flags.go b/flags/flags.go index eb4f7cd..c25c28b 100644 --- a/flags/flags.go +++ b/flags/flags.go @@ -13,70 +13,101 @@ import ( ) var ( - Strict bool + // SilentMissingConfig controls whether a warning message is printed + // when the specified configuration file is missing. + SilentMissingConfig bool + // config stores the path of the configuration file. config string ) +// errMissingConfig is returned when the specified configuration file does not exist. +var errMissingConfig = errors.New("config file is missing") + +// SetConfigFile sets the path of the configuration file to be used when parsing flags. func SetConfigFile(path string) { config = path } -func getArgs(strict, hint bool) (args []string) { +// getArgs reads the configuration file and converts its key-value pairs +// into command-line style arguments compatible with the flag package. +// Lines beginning with '#' or empty lines are ignored. +// Each valid line must be in the form "key=value". +// Returns a slice of arguments or an error if parsing fails. +func getArgs() (args []string, err error) { lines, err := txt.ReadFile(config) if err != nil { - if errors.Is(err, fs.ErrNotExist) && (strict || hint) { - fmt.Println("config file is missing") - } - if !strict { - return + if errors.Is(err, fs.ErrNotExist) { + return nil, errMissingConfig } - panic(err) + return } - for _, line := range lines { + for i, line := range lines { line = strings.TrimSpace(line) if line == "" || strings.HasPrefix(line, "#") { continue } parts := strings.SplitN(line, "=", 2) if len(parts) != 2 { - panic(fmt.Sprintf("cannot parse %q", line)) + err = fmt.Errorf("line %d: cannot parse %q", i+1, line) + return } if key := strings.TrimSpace(parts[0]); flag.Lookup(key) != nil { args = append(args, fmt.Sprintf("-%s=%s", key, unquote(parts[1]))) } else { - if err := fmt.Sprintf("undefined flag %q", key); strict { - panic(err) - } else { - fmt.Println("[Warning]", err) - } + err = fmt.Errorf("line %d: unknown flag %q", i+1, key) + return } } return } +// unquote removes surrounding quotes from a string if present. +// If the string cannot be unquoted, it is returned unchanged. func unquote(s string) string { s = strings.TrimSpace(s) - if s, err := strconv.Unquote(s); err == nil { - return s + if unq, err := strconv.Unquote(s); err == nil { + return unq } return s } -func ParseFlags(strict, hint bool) { +// Parse parses command-line flags and optionally merges them with values +// read from the configuration file specified via SetConfigFile. +// If a configuration file is provided, its flags are prepended to os.Args. +// Errors during parsing or reading are handled according to the flag package’s +// ErrorHandling setting. +func Parse() error { if config != "" { - flag.CommandLine.Parse(append(getArgs(strict, hint), os.Args[1:]...)) - return + args, err := getArgs() + if err != nil { + handleError(err) + } + return flag.CommandLine.Parse(append(args, os.Args[1:]...)) } - flag.Parse() -} - -func UseStrict(strict bool) { - Strict = strict + return flag.CommandLine.Parse(os.Args[1:]) } -func Parse() { - ParseFlags(Strict, true) +// handleError handles errors according to their type +// and the current flag.CommandLine.ErrorHandling mode. +func handleError(err error) { + switch err { + case errMissingConfig: + if !SilentMissingConfig { + fmt.Println("[flags]", err) + } + default: + switch flag.CommandLine.ErrorHandling() { + case flag.ContinueOnError: + fmt.Println("[flags]", err) + case flag.ExitOnError: + os.Exit(2) + case flag.PanicOnError: + panic(err) + } + } } -func ParseStrict() { - ParseFlags(true, true) +// init reinitializes the default CommandLine flag set to use ContinueOnError mode, +// preventing flag.Parse from exiting the program automatically. +func init() { + flag.CommandLine.Init(flag.CommandLine.Name(), flag.ContinueOnError) } diff --git a/flags/flags_test.go b/flags/flags_test.go index 4d5e14e..1203c5d 100644 --- a/flags/flags_test.go +++ b/flags/flags_test.go @@ -5,15 +5,14 @@ import ( "testing" ) -var ( - var0 = flag.String("var0", "", "") - var1 = flag.String("var1", "", "") - var2 = flag.String("var2", "2", "") -) - func TestParse(t *testing.T) { + var0 := flag.String("var0", "", "") + var1 := flag.String("var1", "", "") + var2 := flag.String("var2", "2", "") config = "test_config.ini" - Parse() + if err := Parse(); err != nil { + t.Error(err) + } if *var0 != "0" { t.Errorf("expected %q; got %q", "0", *var0) } @@ -23,8 +22,7 @@ func TestParse(t *testing.T) { if *var2 != "" { t.Errorf("expected %q; got %q", "", *var2) } - - UseStrict(true) + flag.CommandLine.Init("", flag.PanicOnError) defer func() { if err := recover(); err == nil { t.Error("gave no panic; want panic") From e64cbab44422167cd79e6ab639866bd818feb7a2 Mon Sep 17 00:00:00 2001 From: sunshineplan Date: Tue, 14 Oct 2025 17:22:24 +0800 Subject: [PATCH 27/40] html --- html/element.go | 133 +++++++++++++++++++++++++++++++++--------------- 1 file changed, 93 insertions(+), 40 deletions(-) diff --git a/html/element.go b/html/element.go index e1c6b89..8c83d7e 100644 --- a/html/element.go +++ b/html/element.go @@ -2,8 +2,11 @@ package html import ( "fmt" + "maps" "slices" "strings" + + "github.com/sunshineplan/utils/pool" ) var _ HTMLer = new(Element) @@ -15,6 +18,9 @@ type Element struct { } func (e *Element) Attribute(name, value string) *Element { + if e.attrs == nil { + e.attrs = make(map[string]string) + } e.attrs[name] = value return e } @@ -80,7 +86,7 @@ func (e *Element) AppendContent(v ...any) *Element { func (e *Element) AppendChild(child ...*Element) *Element { for _, i := range child { - e.AppendContent(i) + e.content += i.HTML() } return e } @@ -93,56 +99,93 @@ func (e *Element) AppendHTML(html ...string) *Element { } // https://developer.mozilla.org/en-US/docs/Glossary/Void_element +var voidElements = map[string]struct{}{ + "area": {}, + "base": {}, + "br": {}, + "col": {}, + "embed": {}, + "hr": {}, + "img": {}, + "input": {}, + "link": {}, + "meta": {}, + "param": {}, + "source": {}, + "track": {}, + "wbr": {}, +} + func (e Element) isVoidElement() bool { - return slices.Contains([]string{ - "area", - "base", - "br", - "col", - "embed", - "hr", - "img", - "input", - "link", - "meta", - "param", - "source", - "track", - "wbr", - }, strings.ToLower(e.tag)) + _, ok := voidElements[strings.ToLower(e.tag)] + return ok } +var builderPool = pool.New[strings.Builder]() + // https://developer.mozilla.org/en-US/docs/Web/HTML/Attributes func (e Element) printAttrs() string { - var s []string - for k, v := range e.attrs { - if v == "" || v == "true" { - s = append(s, k) - } else if v == "false" { + if len(e.attrs) == 0 { + return "" + } + keys := make([]string, 0, len(e.attrs)) + for k := range e.attrs { + keys = append(keys, k) + } + slices.Sort(keys) + b := builderPool.Get() + defer func() { + b.Reset() + builderPool.Put(b) + }() + first := true + for _, k := range keys { + v := e.attrs[k] + switch v { + case "", "true": + if !first { + b.WriteByte(' ') + } + b.WriteString(k) + first = false + case "false": continue - } else { - s = append(s, fmt.Sprintf("%s=%q", k, v)) + default: + if !first { + b.WriteByte(' ') + } + b.WriteString(k) + b.WriteByte('=') + b.WriteByte('"') + b.WriteString(EscapeString(v)) + b.WriteByte('"') + first = false } } - slices.Sort(s) - return strings.Join(s, " ") + return b.String() } func (e *Element) String() string { - var b strings.Builder - if e.tag != "" { - fmt.Fprint(&b, "<", e.tag) - if attrs := e.printAttrs(); attrs != "" { - fmt.Fprint(&b, " ", attrs) - } - } if e.tag == "" { - fmt.Fprint(&b, e.content) - } else if e.isVoidElement() { - fmt.Fprint(&b, ">") - } else { - fmt.Fprint(&b, ">", e.content) - fmt.Fprintf(&b, "", e.tag) + return string(e.content) + } + b := builderPool.Get() + defer func() { + b.Reset() + builderPool.Put(b) + }() + b.WriteByte('<') + b.WriteString(e.tag) + if attrs := e.printAttrs(); attrs != "" { + b.WriteByte(' ') + b.WriteString(attrs) + } + b.WriteByte('>') + if !e.isVoidElement() { + b.WriteString(string(e.content)) + b.WriteString("') } return b.String() } @@ -151,6 +194,16 @@ func (e *Element) HTML() HTML { return HTML(e.String()) } +func (e *Element) Clone() *Element { + attrs := make(map[string]string, len(e.attrs)) + maps.Copy(attrs, e.attrs) + return &Element{ + tag: e.tag, + attrs: attrs, + content: e.content, + } +} + func NewElement(tag string) *Element { - return &Element{tag, make(map[string]string), ""} + return &Element{tag: tag, attrs: make(map[string]string)} } From b450898f6842e33b06dcf0a7ded581b30e281782 Mon Sep 17 00:00:00 2001 From: sunshineplan Date: Wed, 15 Oct 2025 09:07:09 +0800 Subject: [PATCH 28/40] html comment --- html/element.go | 29 +++++++++++++++++++++++++++++ html/html.go | 12 +++++++++++- html/table.go | 11 +++++++++++ 3 files changed, 51 insertions(+), 1 deletion(-) diff --git a/html/element.go b/html/element.go index 8c83d7e..1f82cbf 100644 --- a/html/element.go +++ b/html/element.go @@ -11,12 +11,16 @@ import ( var _ HTMLer = new(Element) +// Element represents a single HTML element, including its tag name, +// attributes, and inner HTML content. type Element struct { tag string attrs map[string]string content HTML } +// Attribute sets or updates an attribute on the element. +// If attrs is nil, it initializes the map. func (e *Element) Attribute(name, value string) *Element { if e.attrs == nil { e.attrs = make(map[string]string) @@ -25,30 +29,38 @@ func (e *Element) Attribute(name, value string) *Element { return e } +// Class sets the "class" attribute. Multiple classes can be provided. func (e *Element) Class(class ...string) *Element { return e.Attribute("class", strings.Join(class, " ")) } +// Href sets the "href" attribute. func (e *Element) Href(href string) *Element { return e.Attribute("href", href) } +// Name sets the "name" attribute. func (e *Element) Name(name string) *Element { return e.Attribute("name", name) } +// Src sets the "src" attribute. func (e *Element) Src(src string) *Element { return e.Attribute("src", src) } +// Style sets the "style" attribute. func (e *Element) Style(style string) *Element { return e.Attribute("style", style) } +// Title sets the "title" attribute. func (e *Element) Title(title string) *Element { return e.Attribute("title", title) } +// content converts an arbitrary value into escaped HTML text. +// It handles HTMLer, HTML, string, and other types gracefully. func content(v any) HTML { switch v := v.(type) { case nil: @@ -64,19 +76,23 @@ func content(v any) HTML { } } +// Content replaces the current content of the element with new values. func (e *Element) Content(v ...any) *Element { e.content = "" return e.AppendContent(v...) } +// Contentf formats a string using fmt.Sprintf and sets it as the element content. func (e *Element) Contentf(format string, a ...any) *Element { return e.Content(fmt.Sprintf(format, a...)) } +// HTMLContent inserts raw (unescaped) HTML into the element content. func (e *Element) HTMLContent(html string) *Element { return e.Content(HTML(html)) } +// AppendContent appends additional content to the element. func (e *Element) AppendContent(v ...any) *Element { for _, v := range v { e.content += content(v) @@ -84,6 +100,7 @@ func (e *Element) AppendContent(v ...any) *Element { return e } +// AppendChild appends child elements to the current element. func (e *Element) AppendChild(child ...*Element) *Element { for _, i := range child { e.content += i.HTML() @@ -91,6 +108,7 @@ func (e *Element) AppendChild(child ...*Element) *Element { return e } +// AppendHTML appends one or more raw HTML strings directly to the element. func (e *Element) AppendHTML(html ...string) *Element { for _, i := range html { e.AppendContent(HTML(i)) @@ -99,6 +117,7 @@ func (e *Element) AppendHTML(html ...string) *Element { } // https://developer.mozilla.org/en-US/docs/Glossary/Void_element +// Map of HTML5 void elements that do not require a closing tag. var voidElements = map[string]struct{}{ "area": {}, "base": {}, @@ -116,14 +135,20 @@ var voidElements = map[string]struct{}{ "wbr": {}, } +// isVoidElement reports whether the element is a void (self-closing) element. func (e Element) isVoidElement() bool { _, ok := voidElements[strings.ToLower(e.tag)] return ok } +// builderPool provides a pool of reusable strings.Builder instances +// to reduce memory allocations during HTML rendering. var builderPool = pool.New[strings.Builder]() // https://developer.mozilla.org/en-US/docs/Web/HTML/Attributes +// printAttrs returns a formatted string representation of all attributes +// sorted by key. Boolean attributes (true or empty) are printed without a value. +// Attributes with value "false" are omitted. func (e Element) printAttrs() string { if len(e.attrs) == 0 { return "" @@ -165,6 +190,7 @@ func (e Element) printAttrs() string { return b.String() } +// String returns the serialized HTML representation of the element. func (e *Element) String() string { if e.tag == "" { return string(e.content) @@ -190,10 +216,12 @@ func (e *Element) String() string { return b.String() } +// HTML returns the element as HTML type, implementing [HTMLer]. func (e *Element) HTML() HTML { return HTML(e.String()) } +// Clone creates a deep copy of the element and its attributes. func (e *Element) Clone() *Element { attrs := make(map[string]string, len(e.attrs)) maps.Copy(attrs, e.attrs) @@ -204,6 +232,7 @@ func (e *Element) Clone() *Element { } } +// NewElement creates and returns a new HTML element with the given tag name. func NewElement(tag string) *Element { return &Element{tag: tag, attrs: make(map[string]string)} } diff --git a/html/html.go b/html/html.go index aa61902..39c0f02 100644 --- a/html/html.go +++ b/html/html.go @@ -3,20 +3,30 @@ package html import "html" var ( - EscapeString = html.EscapeString + // EscapeString is alias of [html.EscapeString], + // used for encoding HTML entities. + EscapeString = html.EscapeString + // UnescapeString is alias of [html.UnescapeString], + // used for decoding HTML entities. UnescapeString = html.UnescapeString ) +// HTML represents a string that contains valid HTML markup. type HTML string +// HTMLer defines types that can render themselves as HTML. type HTMLer interface { HTML() HTML } +// Background creates an element with no tag, typically used for raw content. func Background() *Element { return NewElement("") } +// NewHTML creates a new element. func NewHTML() *Element { return NewElement("html") } +// Common HTML element constructors for convenience. + func A() *Element { return NewElement("a") } func B() *Element { return NewElement("b") } func Body() *Element { return NewElement("body") } diff --git a/html/table.go b/html/table.go index 36a579f..4991d3c 100644 --- a/html/table.go +++ b/html/table.go @@ -4,8 +4,10 @@ import "strconv" var _ HTMLer = new(TableCell) +// TableCell wraps an Element and provides methods specific to and cells. type TableCell struct{ *Element } +// Tr creates a new element containing the given table cells. func Tr(element ...*TableCell) *Element { tr := NewElement("tr") for _, i := range element { @@ -14,44 +16,53 @@ func Tr(element ...*TableCell) *Element { return tr } +// Th creates a new (table header cell) element with optional content. func Th(content any) *TableCell { return &TableCell{NewElement("th").Content(content)} } +// Td creates a new (table data cell) element with optional content. func Td(content any) *TableCell { return &TableCell{NewElement("td").Content(content)} } +// Abbr sets the "abbr" attribute on the table cell. func (cell *TableCell) Abbr(abbr string) *TableCell { cell.Element.Attribute("abbr", abbr) return cell } +// Colspan sets the "colspan" attribute, specifying how many columns the cell spans. func (cell *TableCell) Colspan(n uint) *TableCell { cell.Element.Attribute("colspan", strconv.FormatUint(uint64(n), 10)) return cell } +// Headers sets the "headers" attribute, linking the cell to header IDs. func (cell *TableCell) Headers(headers string) *TableCell { cell.Element.Attribute("headers", headers) return cell } +// Rowspan sets the "rowspan" attribute, specifying how many rows the cell spans. func (cell *TableCell) Rowspan(n uint) *TableCell { cell.Element.Attribute("rowspan", strconv.FormatUint(uint64(n), 10)) return cell } +// Scope sets the "scope" attribute, typically used in to define scope of headers. func (cell *TableCell) Scope(scope string) *TableCell { cell.Element.Attribute("scope", scope) return cell } +// Class sets the "class" attribute for the table cell. func (cell *TableCell) Class(class ...string) *TableCell { cell.Element.Class(class...) return cell } +// Style sets the "style" attribute for the table cell. func (cell *TableCell) Style(style string) *TableCell { cell.Element.Style(style) return cell From 1f4d98bdced949bf390c5284a13cd527d2138a80 Mon Sep 17 00:00:00 2001 From: sunshineplan Date: Wed, 15 Oct 2025 16:20:58 +0800 Subject: [PATCH 29/40] httpsvr --- httpsvr/httpsvr.go | 124 +++++++++++++++++++++++++++------------------ 1 file changed, 74 insertions(+), 50 deletions(-) diff --git a/httpsvr/httpsvr.go b/httpsvr/httpsvr.go index 2372e19..568c57c 100644 --- a/httpsvr/httpsvr.go +++ b/httpsvr/httpsvr.go @@ -20,7 +20,7 @@ var certCache = cache.NewWithRenew[string, *tls.Certificate](false) var defaultReload = 24 * time.Hour -// Server defines parameters for running an HTTP server. +// Server defines parameters for running an HTTP or HTTPS server. type Server struct { *http.Server *log.Logger @@ -36,94 +36,103 @@ type Server struct { l *counter.Listener } -// New creates an HTTP server. +// New creates a new Server instance with default logger and error log. func New() *Server { - return &Server{Server: &http.Server{}, Logger: log.Default()} + logger := log.Default() + return &Server{Server: &http.Server{ErrorLog: logger.Logger}, Logger: logger} } +// SetLogger sets a custom logger for both the Server and its internal http.Server. func (s *Server) SetLogger(logger *log.Logger) { s.Logger = logger s.Server.ErrorLog = logger.Logger } +// SetReload defines the certificate reload interval. +// Default is 24 hours if not set explicitly. func (s *Server) SetReload(d time.Duration) { s.reload = d } -// Run runs an HTTP server which can be gracefully shut down. -func (s *Server) run() error { +// Serve starts the HTTP or HTTPS server and handles graceful shutdown signals. +// +// It listens on either a Unix domain socket (if s.Unix is set) or a TCP address. +// When receiving SIGHUP, the server reloads configuration or certificates. +// When receiving SIGINT/SIGTERM, it gracefully shuts down all connections. +func (s *Server) Serve(tls bool) (err error) { + s.tls = tls + if s.reload == 0 { + s.reload = defaultReload + } + // Channel used to wait for graceful shutdown completion. idleConnsClosed := make(chan struct{}) + // Handle system signals for reload and graceful stop. c := make(chan os.Signal, 1) signal.Notify(c, syscall.SIGINT, syscall.SIGTERM, syscall.SIGHUP) + defer signal.Stop(c) go func() { for { switch <-c { case syscall.SIGHUP: - s.Rotate() - if s.tls { - cert, err := s.loadCertificate() - if err != nil { - s.Println("Failed to reload certificate:", err) - continue - } - if s.reload == 0 { - s.reload = defaultReload - } - certCache.Set(s.certFile+s.keyFile, cert, s.reload, s.loadCertificate) + if err := s.Reload(); err != nil { + s.Printf("reload failed: %v", err) + } else { + s.Print("reload successful") } case syscall.SIGINT, syscall.SIGTERM: ctx, cancel := context.WithTimeout(context.Background(), time.Minute) defer cancel() if err := s.Shutdown(ctx); err != nil { - s.Println("Failed to close server:", err) + s.Printf("failed to close server: %v", err) } close(idleConnsClosed) return } } }() - + var listener net.Listener if s.Unix != "" { - listener, err := net.Listen("unix", s.Unix) + // Listen on Unix domain socket. + listener, err = net.Listen("unix", s.Unix) if err != nil { - return fmt.Errorf("failed to listen socket file: %v", err) + return fmt.Errorf("failed to listen socket file: %w", err) } + defer os.Remove(s.Unix) // Let everyone can access the socket file. if err := os.Chmod(s.Unix, 0666); err != nil { - return fmt.Errorf("failed to chmod socket file: %v", err) + return fmt.Errorf("failed to chmod socket file: %w", err) } - s.l = counter.NewListener(listener) } else { + // Default to "http" or "https" if port not specified. port := s.Port if port == "" { - if s.tls { + if tls { port = "https" } else { port = "http" } } s.Addr = s.Host + ":" + port - listener, err := net.Listen("tcp", s.Addr) + listener, err = net.Listen("tcp", s.Addr) if err != nil { - return fmt.Errorf("failed to listen tcp: %v", err) + return fmt.Errorf("failed to listen tcp: %w", err) } - s.l = counter.NewListener(listener) } + s.l = counter.NewListener(listener) - var err error - if s.tls { - err = s.ServeTLS(s.l, "", "") + if tls { + err = s.Server.ServeTLS(s.l, "", "") } else { - err = s.Serve(s.l) + err = s.Server.Serve(s.l) } if err != http.ErrServerClosed { - return fmt.Errorf("failed to serve: %v", err) + return fmt.Errorf("failed to serve: %w", err) } - <-idleConnsClosed return nil } +// loadCertificate loads the current TLS certificate and key pair from disk. func (s *Server) loadCertificate() (*tls.Certificate, error) { cert, err := tls.LoadX509KeyPair(s.certFile, s.keyFile) if err != nil { @@ -132,39 +141,53 @@ func (s *Server) loadCertificate() (*tls.Certificate, error) { return &cert, nil } +// getCertificate is a callback for tls.Config.GetCertificate. +// It retrieves the cached certificate, or reloads it if expired. func (s *Server) getCertificate(_ *tls.ClientHelloInfo) (*tls.Certificate, error) { - v, ok := certCache.Get(s.certFile + s.keyFile) + key := s.certFile + s.keyFile + v, ok := certCache.Get(key) if ok { return v, nil } - cert, err := s.loadCertificate() if err != nil { return nil, err } - - if s.reload == 0 { - s.reload = defaultReload - } - certCache.Set(s.certFile+s.keyFile, cert, s.reload, s.loadCertificate) - + certCache.Set(key, cert, s.reload, s.loadCertificate) return cert, nil } -// Run runs an HTTP server which can be gracefully shut down. +// Run starts an HTTP server with graceful shutdown support. func (s *Server) Run() error { - s.tls = false - return s.run() + return s.Serve(false) } +// RunTLS starts an HTTPS server using the provided certificate and key files. +// Certificates are automatically reloaded based on the reload interval. func (s *Server) RunTLS(certFile, keyFile string) error { - s.tls = true s.certFile = certFile s.keyFile = keyFile - s.TLSConfig = &tls.Config{GetCertificate: s.getCertificate} - return s.run() + if s.TLSConfig == nil { + s.TLSConfig = &tls.Config{} + } + s.TLSConfig.GetCertificate = s.getCertificate + return s.Serve(true) +} + +// Reload rotates server's log and reloads TLS certificates if applicable. +func (s *Server) Reload() error { + s.Rotate() + if s.tls { + cert, err := s.loadCertificate() + if err != nil { + return fmt.Errorf("failed to reload certificate: %w", err) + } + certCache.Set(s.certFile+s.keyFile, cert, s.reload, s.loadCertificate) + } + return nil } +// ReadBytes returns the total number of bytes read by the listener. func (s *Server) ReadBytes() int64 { if s.l == nil { return 0 @@ -172,6 +195,7 @@ func (s *Server) ReadBytes() int64 { return s.l.ReadBytes() } +// WriteBytes returns the total number of bytes written by the listener. func (s *Server) WriteBytes() int64 { if s.l == nil { return 0 @@ -179,22 +203,22 @@ func (s *Server) WriteBytes() int64 { return s.l.WriteBytes() } -// TCP runs an HTTP server on TCP network listener. +// TCP runs an HTTP server using TCP network listener. func TCP(addr string, handler http.Handler) error { return (&Server{Server: &http.Server{Addr: addr, Handler: handler}}).Run() } -// TLS runs an HTTP server on TCP network listener and handle requests on incoming TLS connections. +// TLS runs an HTTPS server using TCP network listener. func TLS(addr string, handler http.Handler, certFile, keyFile string) error { return (&Server{Server: &http.Server{Addr: addr, Handler: handler}}).RunTLS(certFile, keyFile) } -// Unix runs an HTTP server on Unix domain socket listener. +// Unix runs an HTTP server using Unix domain socket listener. func Unix(unix string, handler http.Handler) error { return (&Server{Unix: unix, Server: &http.Server{Handler: handler}}).Run() } -// UnixTLS runs an HTTP server on Unix domain socket listener and handle requests on incoming TLS connections. +// UnixTLS runs an HTTPS server using Unix domain socket listener. func UnixTLS(unix string, handler http.Handler, certFile, keyFile string) error { return (&Server{Unix: unix, Server: &http.Server{Handler: handler}}).RunTLS(certFile, keyFile) } From 64628d6dc2c2969c51b1dff5ea20508bd906eb21 Mon Sep 17 00:00:00 2001 From: sunshineplan Date: Fri, 17 Oct 2025 16:18:24 +0800 Subject: [PATCH 30/40] log --- log/handler.go | 27 +++++++++++---------------- log/log.go | 10 ++++------ log/logger.go | 34 ++++++++++++++++------------------ log/logger_test.go | 7 ++++--- log/rotate.go | 13 ++++++++++--- 5 files changed, 45 insertions(+), 46 deletions(-) diff --git a/log/handler.go b/log/handler.go index 98bfac7..47286cc 100644 --- a/log/handler.go +++ b/log/handler.go @@ -3,47 +3,42 @@ package log import ( "bytes" "context" - "fmt" "log" "log/slog" - "strings" "sync" - "time" ) var _ slog.Handler = new(defaultHandler) type defaultHandler struct { *sync.Mutex - slog.Handler - *log.Logger *bytes.Buffer + *log.Logger + slog.Handler } -func newDefaultHandler(mu *sync.Mutex, logger *log.Logger, opts *slog.HandlerOptions) *defaultHandler { +func newDefaultHandler(mu *sync.Mutex, logger *log.Logger, level *slog.LevelVar) *defaultHandler { buf := new(bytes.Buffer) - return &defaultHandler{mu, slog.NewTextHandler(buf, opts), logger, buf} + return &defaultHandler{mu, buf, logger, slog.NewTextHandler(buf, &slog.HandlerOptions{Level: level})} } func (h *defaultHandler) Handle(ctx context.Context, r slog.Record) error { h.Lock() defer h.Unlock() - msg := fmt.Sprintf("%s %s", r.Level, r.Message) - r.Time, r.Message, r.Level = time.Time{}, "", 0 - h.Handler.Handle(ctx, r) - if log := strings.TrimSpace(strings.Replace(strings.Replace(h.String(), "level=INFO", "", 1), ` msg=""`, "", 1)); log == "" { - h.Print(msg) - } else { - h.Println(msg, log) + if err := h.Handler.Handle(ctx, r); err != nil { + return err + } + if _, err := h.Writer().Write(h.Bytes()); err != nil { + return err } h.Reset() return nil } func (h *defaultHandler) WithAttrs(attrs []slog.Attr) slog.Handler { - return &defaultHandler{h.Mutex, h.Handler.WithAttrs(attrs), h.Logger, h.Buffer} + return &defaultHandler{h.Mutex, h.Buffer, h.Logger, h.Handler.WithAttrs(attrs)} } func (h *defaultHandler) WithGroup(name string) slog.Handler { - return &defaultHandler{h.Mutex, h.Handler.WithGroup(name), h.Logger, h.Buffer} + return &defaultHandler{h.Mutex, h.Buffer, h.Logger, h.Handler.WithGroup(name)} } diff --git a/log/log.go b/log/log.go index 7b4cd54..727c5a5 100644 --- a/log/log.go +++ b/log/log.go @@ -104,8 +104,8 @@ func Error(msg string, args ...any) { func ErrorContext(ctx context.Context, msg string, args ...any) { std.ErrorContext(ctx, msg, args...) } -func LoggerHandler() slog.Handler { - return std.LoggerHandler() +func Handler() slog.Handler { + return std.Handler() } func Info(msg string, args ...any) { std.Info(msg, args...) @@ -126,10 +126,8 @@ func WarnContext(ctx context.Context, msg string, args ...any) { std.WarnContext(ctx, msg, args...) } func With(args ...any) *Logger { - std = std.With(args...) - return std + return std.With(args...) } func WithGroup(name string) *Logger { - std = std.WithGroup(name) - return std + return std.WithGroup(name) } diff --git a/log/logger.go b/log/logger.go index f830b72..db5a30e 100644 --- a/log/logger.go +++ b/log/logger.go @@ -15,30 +15,30 @@ var ( ) const ( - Ldate = 1 << iota // the date in the local time zone: 2009/01/23 - Ltime // the time in the local time zone: 01:23:23 - Lmicroseconds // microsecond resolution: 01:23:23.123123. assumes Ltime. - Llongfile // full file name and line number: /a/b/c/d.go:23 - Lshortfile // final file name element and line number: d.go:23. overrides Llongfile - LUTC // if Ldate or Ltime is set, use UTC rather than the local time zone - Lmsgprefix // move the "prefix" from the beginning of the line to before the message - LstdFlags = Ldate | Ltime // initial values for the standard logger + Ldate = log.Ldate // the date in the local time zone: 2009/01/23 + Ltime = log.Ltime // the time in the local time zone: 01:23:23 + Lmicroseconds = log.Lmicroseconds // microsecond resolution: 01:23:23.123123. assumes Ltime. + Llongfile = log.Llongfile // full file name and line number: /a/b/c/d.go:23 + Lshortfile = log.Lshortfile // final file name element and line number: d.go:23. overrides Llongfile + LUTC = log.LUTC // if Ldate or Ltime is set, use UTC rather than the local time zone + Lmsgprefix = log.Lmsgprefix // move the "prefix" from the beginning of the line to before the message + LstdFlags = log.LstdFlags // initial values for the standard logger ) type Logger struct { + mu sync.Mutex *log.Logger + file *os.File extra io.Writer - slog *slog.Logger - - mu *sync.Mutex + slog *slog.Logger level *slog.LevelVar } func newLogger(l *log.Logger, file *os.File) *Logger { - logger := &Logger{Logger: l, file: file, mu: new(sync.Mutex), level: new(slog.LevelVar)} - logger.slog = slog.New(newDefaultHandler(logger.mu, l, &slog.HandlerOptions{Level: logger.level})) + logger := &Logger{Logger: l, file: file, level: new(slog.LevelVar)} + logger.slog = slog.New(newDefaultHandler(&logger.mu, l, logger.level)) return logger } @@ -118,7 +118,7 @@ func (l *Logger) Error(msg string, args ...any) { func (l *Logger) ErrorContext(ctx context.Context, msg string, args ...any) { l.slog.ErrorContext(ctx, msg, args...) } -func (l *Logger) LoggerHandler() slog.Handler { +func (l *Logger) Handler() slog.Handler { return l.slog.Handler() } func (l *Logger) Info(msg string, args ...any) { @@ -140,12 +140,10 @@ func (l *Logger) WarnContext(ctx context.Context, msg string, args ...any) { l.slog.WarnContext(ctx, msg, args...) } func (l *Logger) With(args ...any) *Logger { - l.slog = l.slog.With(args...) - return l + return &Logger{Logger: l.Logger, file: l.file, extra: l.extra, slog: l.slog.With(args...), level: l.level} } func (l *Logger) WithGroup(name string) *Logger { - l.slog = l.slog.WithGroup(name) - return l + return &Logger{Logger: l.Logger, file: l.file, extra: l.extra, slog: l.slog.WithGroup(name), level: l.level} } func (l *Logger) Rotate() { diff --git a/log/logger_test.go b/log/logger_test.go index 3422de9..6ec7dce 100644 --- a/log/logger_test.go +++ b/log/logger_test.go @@ -6,6 +6,7 @@ import ( "log/slog" "os" "runtime" + "strings" "testing" ) @@ -51,7 +52,7 @@ func TestSLogger(t *testing.T) { t.Errorf("expected empty string; got %q", file) } l.Info("test") - if s, expected := buf.String(), "INFO test\n"; s != expected { + if s, expected := buf.String(), "INFO msg=test\n"; !strings.HasSuffix(s, expected) { t.Errorf("expected %q; got %q", expected, s) } buf.Reset() @@ -62,13 +63,13 @@ func TestSLogger(t *testing.T) { buf.Reset() l.SetLevel(slog.LevelDebug) l.Debug("test") - if s, expected := buf.String(), "DEBUG test\n"; s != expected { + if s, expected := buf.String(), "DEBUG msg=test\n"; !strings.HasSuffix(s, expected) { t.Errorf("expected %q; got %q", expected, s) } buf.Reset() l = l.WithGroup("g").With("a", 1) l.Info("test") - if s, expected := buf.String(), "INFO test g.a=1\n"; s != expected { + if s, expected := buf.String(), "INFO msg=test g.a=1\n"; !strings.HasSuffix(s, expected) { t.Errorf("expected %q; got %q", expected, s) } } diff --git a/log/rotate.go b/log/rotate.go index dc4fa6f..6f433ec 100644 --- a/log/rotate.go +++ b/log/rotate.go @@ -1,6 +1,7 @@ package log import ( + "context" "os" "os/signal" ) @@ -9,12 +10,18 @@ type Rotatable interface { Rotate() } -func ListenRotateSignal(r Rotatable, sig ...os.Signal) { +func ListenRotateSignal(ctx context.Context, r Rotatable, sig ...os.Signal) { c := make(chan os.Signal, 1) signal.Notify(c, sig...) go func() { - for range c { - r.Rotate() + for { + select { + case <-ctx.Done(): + signal.Stop(c) + return + case <-c: + r.Rotate() + } } }() } From 35ef404b48fdcee82894d3a3fc107eeac27570b4 Mon Sep 17 00:00:00 2001 From: sunshineplan Date: Sat, 18 Oct 2025 12:42:49 +0800 Subject: [PATCH 31/40] container and log --- container/value.go | 36 ++++------------ container/value_test.go | 66 ++++++----------------------- log/handler.go | 4 +- log/log.go | 89 ++++++++++++++++++++------------------ log/logger.go | 94 +++++++++++++++++++++++------------------ log/logger_test.go | 6 +-- 6 files changed, 127 insertions(+), 168 deletions(-) diff --git a/container/value.go b/container/value.go index 03a0366..d6b21aa 100644 --- a/container/value.go +++ b/container/value.go @@ -15,51 +15,31 @@ func NewValue[T any]() *Value[T] { return &Value[T]{} } -// Load returns the value set by the most recent Store and a boolean -// indicating whether a value was stored. -// If there has been no call to Store for this Value, it returns the -// zero value of T and false. -func (v *Value[T]) Load() (val T, ok bool) { +// Load returns the value set by the most recent [Value.Store]. +// If there has been no call to [Value.Store] for this Value, it returns the +// zero value of T. +func (v *Value[T]) Load() (val T) { if loaded := v.v.Load(); loaded != nil { - return loaded.(T), true - } - return -} - -// MustLoad returns the value set by the most recent Store. -// It panics if there has been no call to Store for this Value. -func (v *Value[T]) MustLoad() (val T) { - val, ok := v.Load() - if !ok { - panic("container/value: there has been no call to Store for this Value") + return loaded.(T) } return } // Store sets the value of the [Value] v to val. func (v *Value[T]) Store(val T) { - if any(val) == nil { - panic("container/value: store of nil value into Value") - } v.v.Store(val) } // Swap stores the new value into the Value and returns the previous value. -// If no value was previously stored, it returns the zero value of T and false. -func (v *Value[T]) Swap(new T) (old T, loaded bool) { - if any(new) == nil { - panic("container/value: swap of nil value into Value") - } +// If no value was previously stored, it returns the zero value of T. +func (v *Value[T]) Swap(new T) (old T) { if prev := v.v.Swap(new); prev != nil { - return prev.(T), true + return prev.(T) } return } // CompareAndSwap executes the compare-and-swap operation for the [Value]. func (v *Value[T]) CompareAndSwap(old, new T) (swapped bool) { - if any(new) == nil { - panic("container/value: compare and swap of nil value into Value") - } return v.v.CompareAndSwap(old, new) } diff --git a/container/value_test.go b/container/value_test.go index 119129a..4433670 100644 --- a/container/value_test.go +++ b/container/value_test.go @@ -13,15 +13,12 @@ import ( func TestValue(t *testing.T) { v := NewValue[int]() - if _, ok := v.Load(); ok { - t.Fatal("initial Value is not nil") - } v.Store(42) - if i, ok := v.Load(); !ok || i != 42 { + if i := v.Load(); i != 42 { t.Fatalf("wrong value: got %d, want 42", i) } v.Store(84) - if i, ok := v.Load(); !ok || i != 84 { + if i := v.Load(); i != 84 { t.Fatalf("wrong value: got %d, want 84", i) } } @@ -29,55 +26,22 @@ func TestValue(t *testing.T) { func TestValueLarge(t *testing.T) { v := NewValue[string]() v.Store("foo") - if s, ok := v.Load(); !ok || s != "foo" { + if s := v.Load(); s != "foo" { t.Fatalf("wrong value: got %s, want foo", s) } v.Store("barbaz") - if s, ok := v.Load(); !ok || s != "barbaz" { + if s := v.Load(); s != "barbaz" { t.Fatalf("wrong value: got %s, want barbaz", s) } } -func TestValuePanic(t *testing.T) { - const nilErr = "cache/value: store of nil value into Value" - v := NewValue[any]() - func() { - defer func() { - err := recover() - if err != nilErr { - t.Fatalf("inconsistent store panic: got '%v', want '%v'", err, nilErr) - } - }() - v.Store(nil) - }() - v.Store(1) - func() { - defer func() { - err := recover() - if err != nilErr { - t.Fatalf("inconsistent store panic: got '%v', want '%v'", err, nilErr) - } - }() - v.Store(nil) - }() -} - func TestPointType(t *testing.T) { v := NewValue[*int]() - if v, stored := v.Load(); stored { - t.Fatal("wrong stored status") - } else if v != nil { - t.Fatal("initial Value is not nil") - } v.Store((*int)(nil)) - if v, stored := v.Load(); !stored { - t.Fatal("wrong stored status") - } else if v != nil { + if v := v.Load(); v != nil { t.Fatalf("wrong value: got %v, want nil", v) } - if old, stored := v.Swap(utils.Ptr(1)); !stored { - t.Fatal("wrong stored status") - } else if old != nil { + if old := v.Swap(utils.Ptr(1)); old != nil { t.Fatalf("wrong value: got %v, want nil", v) } } @@ -105,7 +69,7 @@ func TestValueConcurrent(t *testing.T) { for j := 0; j < N; j++ { x := test[rand.IntN(len(test))] v.Store(x) - x = v.MustLoad() + x = v.Load() for _, x1 := range test { if x == x1 { continue loop @@ -131,7 +95,7 @@ func BenchmarkValueRead(b *testing.B) { v.Store(new(int)) b.RunParallel(func(pb *testing.PB) { for pb.Next() { - x := v.MustLoad() + x := v.Load() if *x != 0 { b.Fatalf("wrong value: got %v, want 0", *x) } @@ -145,7 +109,7 @@ var Value_SwapTests = []struct { want any err any }{ - {init: nil, new: nil, err: "cache/value: swap of nil value into Value"}, + {init: nil, new: nil, err: "sync/atomic: swap of nil value into Value"}, {init: nil, new: true, want: nil, err: nil}, {init: true, new: "", err: "sync/atomic: swap of inconsistently typed value into Value"}, {init: true, new: false, want: true, err: nil}, @@ -167,10 +131,10 @@ func TestValue_Swap(t *testing.T) { t.Errorf("should panic %v, got ", tt.err) } }() - if got, _ := v.Swap(tt.new); got != tt.want { + if got := v.Swap(tt.new); got != tt.want { t.Errorf("got %v, want %v", got, tt.want) } - if got, stored := v.Load(); !stored || got != tt.new { + if got := v.Load(); got != tt.new { t.Errorf("got %v, want %v", got, tt.new) } }) @@ -192,16 +156,14 @@ func TestValueSwapConcurrent(t *testing.T) { go func() { var c uint64 for new := i; new < i+n; new++ { - if old, stored := v.Swap(new); stored { - c += old - } + c += v.Swap(new) } atomic.AddUint64(&count, c) g.Done() }() } g.Wait() - if want, got := (m*n-1)*(m*n)/2, count+v.MustLoad(); got != want { + if want, got := (m*n-1)*(m*n)/2, count+v.Load(); got != want { t.Errorf("sum from 0 to %d was %d, want %v", m*n-1, got, want) } } @@ -270,7 +232,7 @@ func TestValueCompareAndSwapConcurrent(t *testing.T) { }() } w.Wait() - if stop := v.MustLoad(); stop != m*n { + if stop := v.Load(); stop != m*n { t.Errorf("did not get to %v, stopped at %v", m*n, stop) } } diff --git a/log/handler.go b/log/handler.go index 47286cc..49ad456 100644 --- a/log/handler.go +++ b/log/handler.go @@ -17,9 +17,9 @@ type defaultHandler struct { slog.Handler } -func newDefaultHandler(mu *sync.Mutex, logger *log.Logger, level *slog.LevelVar) *defaultHandler { +func newDefaultHandler(logger *log.Logger, level *slog.LevelVar) *defaultHandler { buf := new(bytes.Buffer) - return &defaultHandler{mu, buf, logger, slog.NewTextHandler(buf, &slog.HandlerOptions{Level: level})} + return &defaultHandler{new(sync.Mutex), buf, logger, slog.NewTextHandler(buf, &slog.HandlerOptions{Level: level})} } func (h *defaultHandler) Handle(ctx context.Context, r slog.Record) error { diff --git a/log/log.go b/log/log.go index 727c5a5..a11e5a7 100644 --- a/log/log.go +++ b/log/log.go @@ -6,128 +6,135 @@ import ( "log" "log/slog" "os" + "sync/atomic" ) -var std = newLogger(log.Default(), os.Stderr) +var defaultLogger atomic.Pointer[Logger] -func Default() *Logger { return std } +func init() { + defaultLogger.Store(newLogger(log.Default(), os.Stderr)) +} + +func Default() *Logger { return defaultLogger.Load() } + +func SetDefault(l *Logger) { + defaultLogger.Store(l) +} func File() string { - return std.File() + return Default().File() } func SetOutput(file string, extra io.Writer) { - std.SetOutput(file, extra) + Default().SetOutput(file, extra) } func SetFile(file string) { - std.SetFile(file) + Default().SetFile(file) } func SetExtra(extra io.Writer) { - std.SetExtra(extra) + Default().SetExtra(extra) } func Rotate() { - std.Rotate() + Default().Rotate() } func Flags() int { - return std.Flags() + return Default().Flags() } func SetFlags(flag int) { - std.SetFlags(flag) + Default().SetFlags(flag) } func Prefix() string { - return std.Prefix() + return Default().Prefix() } func SetPrefix(prefix string) { - std.SetPrefix(prefix) + Default().SetPrefix(prefix) } func Writer() io.Writer { - return std.Writer() + return Default().Writer() } func Print(v ...any) { - std.Print(v...) + Default().Print(v...) } func Printf(format string, v ...any) { - std.Printf(format, v...) + Default().Printf(format, v...) } func Println(v ...any) { - std.Println(v...) + Default().Println(v...) } func Fatal(v ...any) { - std.Fatal(v...) + Default().Fatal(v...) } func Fatalf(format string, v ...any) { - std.Fatalf(format, v...) + Default().Fatalf(format, v...) } func Fatalln(v ...any) { - std.Fatalln(v...) + Default().Fatalln(v...) } func Panic(v ...any) { - std.Panic(v...) + Default().Panic(v...) } func Panicf(format string, v ...any) { - std.Panicf(format, v...) + Default().Panicf(format, v...) } func Panicln(v ...any) { - std.Panicln(v...) + Default().Panicln(v...) } func Output(calldepth int, s string) error { - return std.Output(calldepth+1, s) // +1 for this frame. + return Default().Output(calldepth+1, s) // +1 for this frame. } func SetHandler(h slog.Handler) { - std.mu.Lock() - defer std.mu.Unlock() - std.slog = slog.New(h) + Default().slog.Store(slog.New(h)) } func Level() *slog.LevelVar { - return std.level + return Default().level } func SetLevel(level slog.Level) { - std.level.Set(level) + Default().level.Set(level) } func Debug(msg string, args ...any) { - std.Debug(msg, args...) + Default().Debug(msg, args...) } func DebugContext(ctx context.Context, msg string, args ...any) { - std.DebugContext(ctx, msg, args...) + Default().DebugContext(ctx, msg, args...) } func Enabled(ctx context.Context, level slog.Level) bool { - return std.Enabled(ctx, level) + return Default().Enabled(ctx, level) } func Error(msg string, args ...any) { - std.Error(msg, args...) + Default().Error(msg, args...) } func ErrorContext(ctx context.Context, msg string, args ...any) { - std.ErrorContext(ctx, msg, args...) + Default().ErrorContext(ctx, msg, args...) } func Handler() slog.Handler { - return std.Handler() + return Default().Handler() } func Info(msg string, args ...any) { - std.Info(msg, args...) + Default().Info(msg, args...) } func InfoContext(ctx context.Context, msg string, args ...any) { - std.InfoContext(ctx, msg, args...) + Default().InfoContext(ctx, msg, args...) } func Log(ctx context.Context, level slog.Level, msg string, args ...any) { - std.Log(ctx, level, msg, args...) + Default().Log(ctx, level, msg, args...) } func LogAttrs(ctx context.Context, level slog.Level, msg string, attrs ...slog.Attr) { - std.LogAttrs(ctx, level, msg, attrs...) + Default().LogAttrs(ctx, level, msg, attrs...) } func Warn(msg string, args ...any) { - std.Warn(msg, args...) + Default().Warn(msg, args...) } func WarnContext(ctx context.Context, msg string, args ...any) { - std.WarnContext(ctx, msg, args...) + Default().WarnContext(ctx, msg, args...) } func With(args ...any) *Logger { - return std.With(args...) + return Default().With(args...) } func WithGroup(name string) *Logger { - return std.WithGroup(name) + return Default().WithGroup(name) } diff --git a/log/logger.go b/log/logger.go index db5a30e..26627d7 100644 --- a/log/logger.go +++ b/log/logger.go @@ -6,7 +6,9 @@ import ( "log" "log/slog" "os" - "sync" + "sync/atomic" + + "github.com/sunshineplan/utils/container" ) var ( @@ -26,19 +28,17 @@ const ( ) type Logger struct { - mu sync.Mutex *log.Logger - - file *os.File - extra io.Writer - - slog *slog.Logger + file atomic.Pointer[os.File] + extra container.Value[io.Writer] + slog atomic.Pointer[slog.Logger] level *slog.LevelVar } func newLogger(l *log.Logger, file *os.File) *Logger { - logger := &Logger{Logger: l, file: file, level: new(slog.LevelVar)} - logger.slog = slog.New(newDefaultHandler(&logger.mu, l, logger.level)) + logger := &Logger{Logger: l, level: new(slog.LevelVar)} + logger.file.Store(file) + logger.slog.Store(slog.New(newDefaultHandler(l, logger.level))) return logger } @@ -51,8 +51,8 @@ func New(file, prefix string, flag int) *Logger { } func (l *Logger) File() string { - if l.file != nil { - return l.file.Name() + if file := l.file.Load(); file != nil { + return file.Name() } return "" } @@ -69,33 +69,29 @@ func (l *Logger) setOutput(file *os.File, extra io.Writer) { writers = append(writers, io.Discard) } l.Logger.SetOutput(io.MultiWriter(writers...)) - if l.file != nil && l.file != file { - l.file.Close() + if oldFile := l.file.Load(); oldFile != nil && oldFile != file { + oldFile.Close() + } + l.file.Store(file) + if extra != nil { + l.extra.Store(extra) } - l.file = file - l.extra = extra } func (l *Logger) SetOutput(file string, extra io.Writer) { - l.mu.Lock() - defer l.mu.Unlock() l.setOutput(openFile(file), extra) } func (l *Logger) SetFile(file string) { - l.SetOutput(file, l.extra) + l.SetOutput(file, l.extra.Load()) } func (l *Logger) SetExtra(extra io.Writer) { - l.mu.Lock() - defer l.mu.Unlock() - l.setOutput(l.file, extra) + l.setOutput(l.file.Load(), extra) } func (l *Logger) SetHandler(h slog.Handler) { - l.mu.Lock() - defer l.mu.Unlock() - l.slog = slog.New(h) + l.slog.Store(slog.New(h)) } func (l *Logger) Level() *slog.LevelVar { return l.level @@ -104,54 +100,68 @@ func (l *Logger) SetLevel(level slog.Level) { l.level.Set(level) } func (l *Logger) Debug(msg string, args ...any) { - l.slog.Debug(msg, args...) + l.slog.Load().Debug(msg, args...) } func (l *Logger) DebugContext(ctx context.Context, msg string, args ...any) { - l.slog.DebugContext(ctx, msg, args...) + l.slog.Load().DebugContext(ctx, msg, args...) } func (l *Logger) Enabled(ctx context.Context, level slog.Level) bool { - return l.slog.Enabled(ctx, level) + return l.slog.Load().Enabled(ctx, level) } func (l *Logger) Error(msg string, args ...any) { - l.slog.Error(msg, args...) + l.slog.Load().Error(msg, args...) } func (l *Logger) ErrorContext(ctx context.Context, msg string, args ...any) { - l.slog.ErrorContext(ctx, msg, args...) + l.slog.Load().ErrorContext(ctx, msg, args...) } func (l *Logger) Handler() slog.Handler { - return l.slog.Handler() + return l.slog.Load().Handler() } func (l *Logger) Info(msg string, args ...any) { - l.slog.Info(msg, args...) + l.slog.Load().Info(msg, args...) } func (l *Logger) InfoContext(ctx context.Context, msg string, args ...any) { - l.slog.InfoContext(ctx, msg, args...) + l.slog.Load().InfoContext(ctx, msg, args...) } func (l *Logger) Log(ctx context.Context, level slog.Level, msg string, args ...any) { - l.slog.Log(ctx, level, msg, args...) + l.slog.Load().Log(ctx, level, msg, args...) } func (l *Logger) LogAttrs(ctx context.Context, level slog.Level, msg string, attrs ...slog.Attr) { - l.slog.LogAttrs(ctx, level, msg, attrs...) + l.slog.Load().LogAttrs(ctx, level, msg, attrs...) } func (l *Logger) Warn(msg string, args ...any) { - l.slog.Warn(msg, args...) + l.slog.Load().Warn(msg, args...) } func (l *Logger) WarnContext(ctx context.Context, msg string, args ...any) { - l.slog.WarnContext(ctx, msg, args...) + l.slog.Load().WarnContext(ctx, msg, args...) } func (l *Logger) With(args ...any) *Logger { - return &Logger{Logger: l.Logger, file: l.file, extra: l.extra, slog: l.slog.With(args...), level: l.level} + logger := &Logger{Logger: l.Logger, extra: l.extra, level: l.level} + logger.file.Store(l.file.Load()) + if extra := l.extra.Load(); extra != nil { + logger.extra.Store(extra) + } + logger.slog.Store(l.slog.Load().With(args...)) + return logger } func (l *Logger) WithGroup(name string) *Logger { - return &Logger{Logger: l.Logger, file: l.file, extra: l.extra, slog: l.slog.WithGroup(name), level: l.level} + logger := &Logger{Logger: l.Logger, extra: l.extra, level: l.level} + logger.file.Store(l.file.Load()) + if extra := l.extra.Load(); extra != nil { + logger.extra.Store(extra) + } + logger.slog.Store(l.slog.Load().WithGroup(name)) + return logger } func (l *Logger) Rotate() { - if i, ok := l.extra.(Rotatable); ok { - i.Rotate() + if extra := l.extra.Load(); extra != nil { + if i, ok := extra.(Rotatable); ok { + i.Rotate() + } } - if l.file != nil { - l.SetFile(l.file.Name()) + if file := l.file.Load(); file != nil { + l.SetFile(file.Name()) } } diff --git a/log/logger_test.go b/log/logger_test.go index 6ec7dce..60e424c 100644 --- a/log/logger_test.go +++ b/log/logger_test.go @@ -52,7 +52,7 @@ func TestSLogger(t *testing.T) { t.Errorf("expected empty string; got %q", file) } l.Info("test") - if s, expected := buf.String(), "INFO msg=test\n"; !strings.HasSuffix(s, expected) { + if s, expected := buf.String(), "level=INFO msg=test\n"; !strings.HasSuffix(s, expected) { t.Errorf("expected %q; got %q", expected, s) } buf.Reset() @@ -63,13 +63,13 @@ func TestSLogger(t *testing.T) { buf.Reset() l.SetLevel(slog.LevelDebug) l.Debug("test") - if s, expected := buf.String(), "DEBUG msg=test\n"; !strings.HasSuffix(s, expected) { + if s, expected := buf.String(), "level=DEBUG msg=test\n"; !strings.HasSuffix(s, expected) { t.Errorf("expected %q; got %q", expected, s) } buf.Reset() l = l.WithGroup("g").With("a", 1) l.Info("test") - if s, expected := buf.String(), "INFO msg=test g.a=1\n"; !strings.HasSuffix(s, expected) { + if s, expected := buf.String(), "level=INFO msg=test g.a=1\n"; !strings.HasSuffix(s, expected) { t.Errorf("expected %q; got %q", expected, s) } } From b5f09e8bd5b976f220dca48099cdf98f3114d4fa Mon Sep 17 00:00:00 2001 From: sunshineplan Date: Mon, 20 Oct 2025 14:47:46 +0800 Subject: [PATCH 32/40] log --- log/handler.go | 7 ++++--- log/logger.go | 26 ++++++++++++++++++-------- log/logger_test.go | 10 +++++++--- 3 files changed, 29 insertions(+), 14 deletions(-) diff --git a/log/handler.go b/log/handler.go index 49ad456..dc3365e 100644 --- a/log/handler.go +++ b/log/handler.go @@ -5,7 +5,9 @@ import ( "context" "log" "log/slog" + "strings" "sync" + "time" ) var _ slog.Handler = new(defaultHandler) @@ -25,12 +27,11 @@ func newDefaultHandler(logger *log.Logger, level *slog.LevelVar) *defaultHandler func (h *defaultHandler) Handle(ctx context.Context, r slog.Record) error { h.Lock() defer h.Unlock() + r.Time = time.Time{} if err := h.Handler.Handle(ctx, r); err != nil { return err } - if _, err := h.Writer().Write(h.Bytes()); err != nil { - return err - } + h.Print(strings.TrimPrefix(h.String(), "level=")) h.Reset() return nil } diff --git a/log/logger.go b/log/logger.go index 26627d7..eab56af 100644 --- a/log/logger.go +++ b/log/logger.go @@ -46,7 +46,10 @@ func New(file, prefix string, flag int) *Logger { if file == "" { return newLogger(log.New(io.Discard, prefix, flag), nil) } - f := openFile(file) + f, err := openFile(file) + if err != nil { + panic(err) + } return newLogger(log.New(f, prefix, flag), f) } @@ -70,7 +73,9 @@ func (l *Logger) setOutput(file *os.File, extra io.Writer) { } l.Logger.SetOutput(io.MultiWriter(writers...)) if oldFile := l.file.Load(); oldFile != nil && oldFile != file { - oldFile.Close() + if err := oldFile.Close(); err != nil { + l.Error("failed to close log file", "error", err) + } } l.file.Store(file) if extra != nil { @@ -78,8 +83,13 @@ func (l *Logger) setOutput(file *os.File, extra io.Writer) { } } -func (l *Logger) SetOutput(file string, extra io.Writer) { - l.setOutput(openFile(file), extra) +func (l *Logger) SetOutput(file string, extra io.Writer) error { + f, err := openFile(file) + if err != nil { + return err + } + l.setOutput(f, extra) + return nil } func (l *Logger) SetFile(file string) { @@ -169,13 +179,13 @@ func (l *Logger) Write(b []byte) (int, error) { return l.Writer().Write(b) } -func openFile(file string) *os.File { +func openFile(file string) (*os.File, error) { if file != "" { f, err := os.OpenFile(file, os.O_APPEND|os.O_CREATE|os.O_WRONLY, 0640) if err != nil { - panic(err) + return nil, err } - return f + return f, nil } - return nil + return nil, nil } diff --git a/log/logger_test.go b/log/logger_test.go index 60e424c..275f919 100644 --- a/log/logger_test.go +++ b/log/logger_test.go @@ -2,6 +2,7 @@ package log import ( "bytes" + "fmt" "log" "log/slog" "os" @@ -52,7 +53,8 @@ func TestSLogger(t *testing.T) { t.Errorf("expected empty string; got %q", file) } l.Info("test") - if s, expected := buf.String(), "level=INFO msg=test\n"; !strings.HasSuffix(s, expected) { + fmt.Print(buf.String()) + if s, expected := buf.String(), "INFO msg=test\n"; !strings.HasSuffix(s, expected) { t.Errorf("expected %q; got %q", expected, s) } buf.Reset() @@ -63,13 +65,15 @@ func TestSLogger(t *testing.T) { buf.Reset() l.SetLevel(slog.LevelDebug) l.Debug("test") - if s, expected := buf.String(), "level=DEBUG msg=test\n"; !strings.HasSuffix(s, expected) { + fmt.Print(buf.String()) + if s, expected := buf.String(), "DEBUG msg=test\n"; !strings.HasSuffix(s, expected) { t.Errorf("expected %q; got %q", expected, s) } buf.Reset() l = l.WithGroup("g").With("a", 1) l.Info("test") - if s, expected := buf.String(), "level=INFO msg=test g.a=1\n"; !strings.HasSuffix(s, expected) { + fmt.Print(buf.String()) + if s, expected := buf.String(), "INFO msg=test g.a=1\n"; !strings.HasSuffix(s, expected) { t.Errorf("expected %q; got %q", expected, s) } } From cf19079b8835208cfcd4e26b9b9ca1599a35c314 Mon Sep 17 00:00:00 2001 From: sunshineplan Date: Mon, 20 Oct 2025 15:33:22 +0800 Subject: [PATCH 33/40] log comment --- log/handler.go | 17 ++++++---- log/log.go | 79 ++++++++++++++++++++++++++++++++++++++++++-- log/logger.go | 90 ++++++++++++++++++++++++++++++++++++++++---------- log/rotate.go | 3 ++ 4 files changed, 163 insertions(+), 26 deletions(-) diff --git a/log/handler.go b/log/handler.go index dc3365e..ba7a191 100644 --- a/log/handler.go +++ b/log/handler.go @@ -10,20 +10,23 @@ import ( "time" ) -var _ slog.Handler = new(defaultHandler) - +// defaultHandler combines a standard log.Logger with an slog.Handler for flexible logging. type defaultHandler struct { - *sync.Mutex - *bytes.Buffer - *log.Logger - slog.Handler + *sync.Mutex // Mutex for thread-safe buffer access. + *bytes.Buffer // Buffer for formatting log messages. + *log.Logger // Underlying standard logger for output. + slog.Handler // Structured logging handler. } +var _ slog.Handler = new(defaultHandler) + +// newDefaultHandler creates a new defaultHandler with the specified logger and log level. func newDefaultHandler(logger *log.Logger, level *slog.LevelVar) *defaultHandler { buf := new(bytes.Buffer) return &defaultHandler{new(sync.Mutex), buf, logger, slog.NewTextHandler(buf, &slog.HandlerOptions{Level: level})} } +// Handle formats and outputs a log record using the slog.Handler and log.Logger. func (h *defaultHandler) Handle(ctx context.Context, r slog.Record) error { h.Lock() defer h.Unlock() @@ -36,10 +39,12 @@ func (h *defaultHandler) Handle(ctx context.Context, r slog.Record) error { return nil } +// WithAttrs returns a new handler with the specified attributes. func (h *defaultHandler) WithAttrs(attrs []slog.Attr) slog.Handler { return &defaultHandler{h.Mutex, h.Buffer, h.Logger, h.Handler.WithAttrs(attrs)} } +// WithGroup returns a new handler with the specified group name. func (h *defaultHandler) WithGroup(name string) slog.Handler { return &defaultHandler{h.Mutex, h.Buffer, h.Logger, h.Handler.WithGroup(name)} } diff --git a/log/log.go b/log/log.go index a11e5a7..9cf02f5 100644 --- a/log/log.go +++ b/log/log.go @@ -9,132 +9,207 @@ import ( "sync/atomic" ) +// defaultLogger holds the default Logger instance, managed atomically for thread safety. var defaultLogger atomic.Pointer[Logger] +// init initializes the default logger with stderr output. func init() { defaultLogger.Store(newLogger(log.Default(), os.Stderr)) } +// Default returns the current default Logger instance. func Default() *Logger { return defaultLogger.Load() } +// SetDefault sets the default Logger instance. func SetDefault(l *Logger) { defaultLogger.Store(l) } +// File returns the current log file path of the default Logger. func File() string { return Default().File() } +// SetOutput sets the output destination for the default Logger. +// The file parameter specifies the log file path; if empty, no file output is used. +// The extra parameter allows an additional output destination (e.g., stderr). func SetOutput(file string, extra io.Writer) { Default().SetOutput(file, extra) } +// SetFile sets the log file path for the default Logger, keeping the existing extra writer. func SetFile(file string) { Default().SetFile(file) } +// SetExtra sets an additional output destination for the default Logger, keeping the existing file. func SetExtra(extra io.Writer) { Default().SetExtra(extra) } +// Rotate reopens the log file and rotates the extra writer for the default Logger if applicable. func Rotate() { Default().Rotate() } +// Flags returns the current log flags of the default Logger. func Flags() int { return Default().Flags() } + +// SetFlags sets the log flags for the default Logger. func SetFlags(flag int) { Default().SetFlags(flag) } + +// Prefix returns the current log prefix of the default Logger. func Prefix() string { return Default().Prefix() } + +// SetPrefix sets the log prefix for the default Logger. func SetPrefix(prefix string) { Default().SetPrefix(prefix) } + +// Writer returns the current output writer of the default Logger. func Writer() io.Writer { return Default().Writer() } + +// Print logs a message using the default Logger's Print method. func Print(v ...any) { Default().Print(v...) } + +// Printf logs a formatted message using the default Logger's Printf method. func Printf(format string, v ...any) { Default().Printf(format, v...) } + +// Println logs a message with a newline using the default Logger's Println method. func Println(v ...any) { Default().Println(v...) } + +// Fatal logs a message and exits using the default Logger's Fatal method. func Fatal(v ...any) { Default().Fatal(v...) } + +// Fatalf logs a formatted message and exits using the default Logger's Fatalf method. func Fatalf(format string, v ...any) { Default().Fatalf(format, v...) } + +// Fatalln logs a message with a newline and exits using the default Logger's Fatalln method. func Fatalln(v ...any) { Default().Fatalln(v...) } + +// Panic logs a message and panics using the default Logger's Panic method. func Panic(v ...any) { Default().Panic(v...) } + +// Panicf logs a formatted message and panics using the default Logger's Panicf method. func Panicf(format string, v ...any) { Default().Panicf(format, v...) } + +// Panicln logs a message with a newline and panics using the default Logger's Panicln method. func Panicln(v ...any) { Default().Panicln(v...) } + +// Output logs a message with the specified call depth using the default Logger's Output method. func Output(calldepth int, s string) error { return Default().Output(calldepth+1, s) // +1 for this frame. } +// SetHandler sets the slog handler for the default Logger. +// Note: The new handler may not respect the existing log level (obtained via [Level]), potentially disabling level control. +// Ensure the provided handler is configured with the desired log level if needed. func SetHandler(h slog.Handler) { Default().slog.Store(slog.New(h)) } -func Level() *slog.LevelVar { - return Default().level + +// Level returns the log level of the default Logger. +func Level() slog.Level { + return Default().Level() } + +// SetLevel sets the log level for the default Logger. func SetLevel(level slog.Level) { Default().level.Set(level) } + +// Debug logs a message at Debug level using the default Logger. func Debug(msg string, args ...any) { Default().Debug(msg, args...) } + +// DebugContext logs a message at Debug level with context using the default Logger. func DebugContext(ctx context.Context, msg string, args ...any) { Default().DebugContext(ctx, msg, args...) } + +// Enabled checks if the specified log level is enabled for the default Logger. func Enabled(ctx context.Context, level slog.Level) bool { return Default().Enabled(ctx, level) } + +// Error logs a message at Error level using the default Logger. func Error(msg string, args ...any) { Default().Error(msg, args...) } + +// ErrorContext logs a message at Error level with context using the default Logger. func ErrorContext(ctx context.Context, msg string, args ...any) { Default().ErrorContext(ctx, msg, args...) } + +// Handler returns the slog handler of the default Logger. func Handler() slog.Handler { return Default().Handler() } + +// Info logs a message at Info level using the default Logger. func Info(msg string, args ...any) { Default().Info(msg, args...) } + +// InfoContext logs a message at Info level with context using the default Logger. func InfoContext(ctx context.Context, msg string, args ...any) { Default().InfoContext(ctx, msg, args...) } + +// Log logs a message at the specified level with context using the default Logger. func Log(ctx context.Context, level slog.Level, msg string, args ...any) { Default().Log(ctx, level, msg, args...) } + +// LogAttrs logs a message at the specified level with attributes using the default Logger. func LogAttrs(ctx context.Context, level slog.Level, msg string, attrs ...slog.Attr) { Default().LogAttrs(ctx, level, msg, attrs...) } + +// Warn logs a message at Warn level using the default Logger. func Warn(msg string, args ...any) { Default().Warn(msg, args...) } + +// WarnContext logs a message at Warn level with context using the default Logger. func WarnContext(ctx context.Context, msg string, args ...any) { Default().WarnContext(ctx, msg, args...) } + +// With returns a new Logger with the specified attributes, leaving the default Logger unchanged. func With(args ...any) *Logger { return Default().With(args...) } + +// WithGroup returns a new Logger with the specified group name, leaving the default Logger unchanged. func WithGroup(name string) *Logger { return Default().WithGroup(name) } diff --git a/log/logger.go b/log/logger.go index eab56af..f9f4aef 100644 --- a/log/logger.go +++ b/log/logger.go @@ -11,30 +11,35 @@ import ( "github.com/sunshineplan/utils/container" ) +// Logger implements a custom logger that combines the standard log.Logger with slog.Logger, +// providing flexible output destinations and log level control. +type Logger struct { + *log.Logger // Underlying standard logger. + file atomic.Pointer[os.File] // File handle for log output, managed atomically. + extra container.Value[io.Writer] // Additional output destination (e.g., stderr). + slog atomic.Pointer[slog.Logger] // Structured logger for leveled logging. + level *slog.LevelVar // Log level controller. +} + var ( _ io.Writer = new(Logger) _ Rotatable = new(Logger) ) +// Constants for log flags, mirrored from the standard log package. const ( - Ldate = log.Ldate // the date in the local time zone: 2009/01/23 - Ltime = log.Ltime // the time in the local time zone: 01:23:23 - Lmicroseconds = log.Lmicroseconds // microsecond resolution: 01:23:23.123123. assumes Ltime. - Llongfile = log.Llongfile // full file name and line number: /a/b/c/d.go:23 - Lshortfile = log.Lshortfile // final file name element and line number: d.go:23. overrides Llongfile - LUTC = log.LUTC // if Ldate or Ltime is set, use UTC rather than the local time zone - Lmsgprefix = log.Lmsgprefix // move the "prefix" from the beginning of the line to before the message - LstdFlags = log.LstdFlags // initial values for the standard logger + Ldate = log.Ldate // Include date in log output (e.g., 2009/01/23). + Ltime = log.Ltime // Include time in log output (e.g., 01:23:23). + Lmicroseconds = log.Lmicroseconds // Include microsecond resolution in time (requires Ltime). + Llongfile = log.Llongfile // Include full file name and line number (e.g., /a/b/c/d.go:23). + Lshortfile = log.Lshortfile // Include final file name element and line number (e.g., d.go:23), overrides Llongfile. + LUTC = log.LUTC // Use UTC for date/time if Ldate or Ltime is set. + Lmsgprefix = log.Lmsgprefix // Move prefix to before the message. + LstdFlags = log.LstdFlags // Default flags: Ldate | Ltime. ) -type Logger struct { - *log.Logger - file atomic.Pointer[os.File] - extra container.Value[io.Writer] - slog atomic.Pointer[slog.Logger] - level *slog.LevelVar -} - +// newLogger creates a new Logger instance with the specified standard logger and file handle. +// The logger is initialized with a default slog handler and log level. func newLogger(l *log.Logger, file *os.File) *Logger { logger := &Logger{Logger: l, level: new(slog.LevelVar)} logger.file.Store(file) @@ -42,6 +47,8 @@ func newLogger(l *log.Logger, file *os.File) *Logger { return logger } +// New creates a new Logger instance with the specified file path, prefix, and flags. +// If file is empty, logs are discarded. Panics if the file cannot be opened. func New(file, prefix string, flag int) *Logger { if file == "" { return newLogger(log.New(io.Discard, prefix, flag), nil) @@ -53,6 +60,7 @@ func New(file, prefix string, flag int) *Logger { return newLogger(log.New(f, prefix, flag), f) } +// File returns the current log file path, or an empty string if no file is set. func (l *Logger) File() string { if file := l.file.Load(); file != nil { return file.Name() @@ -60,6 +68,9 @@ func (l *Logger) File() string { return "" } +// setOutput configures the logger's output destinations. +// It sets the file and extra writer, closing the old file if necessary. +// Logs are written to io.Discard if no outputs are provided. func (l *Logger) setOutput(file *os.File, extra io.Writer) { var writers []io.Writer if file != nil { @@ -83,6 +94,8 @@ func (l *Logger) setOutput(file *os.File, extra io.Writer) { } } +// SetOutput sets the log output to the specified file path and extra writer. +// Returns an error if the file cannot be opened. func (l *Logger) SetOutput(file string, extra io.Writer) error { f, err := openFile(file) if err != nil { @@ -92,59 +105,94 @@ func (l *Logger) SetOutput(file string, extra io.Writer) error { return nil } +// SetFile sets the log file to the specified path, keeping the existing extra writer. func (l *Logger) SetFile(file string) { l.SetOutput(file, l.extra.Load()) } +// SetExtra sets an additional output destination (e.g., stderr), keeping the existing file. func (l *Logger) SetExtra(extra io.Writer) { l.setOutput(l.file.Load(), extra) } +// SetHandler sets the slog handler for structured logging. +// Note: The new handler may not respect the existing log level (l.level), potentially disabling level control. +// Ensure the provided handler is configured with the desired log level if needed. func (l *Logger) SetHandler(h slog.Handler) { l.slog.Store(slog.New(h)) } -func (l *Logger) Level() *slog.LevelVar { - return l.level + +// Level returns the current log level. +func (l *Logger) Level() slog.Level { + return l.level.Level() } + +// SetLevel sets the log level for structured logging. func (l *Logger) SetLevel(level slog.Level) { l.level.Set(level) } + +// Debug logs a message at Debug level with the given arguments. func (l *Logger) Debug(msg string, args ...any) { l.slog.Load().Debug(msg, args...) } + +// DebugContext logs a message at Debug level with the given context and arguments. func (l *Logger) DebugContext(ctx context.Context, msg string, args ...any) { l.slog.Load().DebugContext(ctx, msg, args...) } + +// Enabled checks if the specified log level is enabled for the logger. func (l *Logger) Enabled(ctx context.Context, level slog.Level) bool { return l.slog.Load().Enabled(ctx, level) } + +// Error logs a message at Error level with the given arguments. func (l *Logger) Error(msg string, args ...any) { l.slog.Load().Error(msg, args...) } + +// ErrorContext logs a message at Error level with the given context and arguments. func (l *Logger) ErrorContext(ctx context.Context, msg string, args ...any) { l.slog.Load().ErrorContext(ctx, msg, args...) } + +// Handler returns the current slog handler. func (l *Logger) Handler() slog.Handler { return l.slog.Load().Handler() } + +// Info logs a message at Info level with the given arguments. func (l *Logger) Info(msg string, args ...any) { l.slog.Load().Info(msg, args...) } + +// InfoContext logs a message at Info level with the given context and arguments. func (l *Logger) InfoContext(ctx context.Context, msg string, args ...any) { l.slog.Load().InfoContext(ctx, msg, args...) } + +// Log logs a message at the specified level with the given context and arguments. func (l *Logger) Log(ctx context.Context, level slog.Level, msg string, args ...any) { l.slog.Load().Log(ctx, level, msg, args...) } + +// LogAttrs logs a message at the specified level with the given context and attributes. func (l *Logger) LogAttrs(ctx context.Context, level slog.Level, msg string, attrs ...slog.Attr) { l.slog.Load().LogAttrs(ctx, level, msg, attrs...) } + +// Warn logs a message at Warn level with the given arguments. func (l *Logger) Warn(msg string, args ...any) { l.slog.Load().Warn(msg, args...) } + +// WarnContext logs a message at Warn level with the given context and arguments. func (l *Logger) WarnContext(ctx context.Context, msg string, args ...any) { l.slog.Load().WarnContext(ctx, msg, args...) } + +// With returns a new Logger with the specified attributes, leaving the original unchanged. func (l *Logger) With(args ...any) *Logger { logger := &Logger{Logger: l.Logger, extra: l.extra, level: l.level} logger.file.Store(l.file.Load()) @@ -154,6 +202,8 @@ func (l *Logger) With(args ...any) *Logger { logger.slog.Store(l.slog.Load().With(args...)) return logger } + +// WithGroup returns a new Logger with the specified group name, leaving the original unchanged. func (l *Logger) WithGroup(name string) *Logger { logger := &Logger{Logger: l.Logger, extra: l.extra, level: l.level} logger.file.Store(l.file.Load()) @@ -164,6 +214,7 @@ func (l *Logger) WithGroup(name string) *Logger { return logger } +// Rotate reopens the log file and rotates the extra writer if it implements Rotatable. func (l *Logger) Rotate() { if extra := l.extra.Load(); extra != nil { if i, ok := extra.(Rotatable); ok { @@ -175,10 +226,13 @@ func (l *Logger) Rotate() { } } +// Write writes bytes to the logger's output destination, implementing io.Writer. func (l *Logger) Write(b []byte) (int, error) { return l.Writer().Write(b) } +// openFile opens a log file with the specified path in append mode. +// Returns nil if the path is empty, or an error if the file cannot be opened. func openFile(file string) (*os.File, error) { if file != "" { f, err := os.OpenFile(file, os.O_APPEND|os.O_CREATE|os.O_WRONLY, 0640) diff --git a/log/rotate.go b/log/rotate.go index 6f433ec..25c0bf9 100644 --- a/log/rotate.go +++ b/log/rotate.go @@ -6,10 +6,13 @@ import ( "os/signal" ) +// Rotatable defines an interface for objects that support log rotation. type Rotatable interface { Rotate() } +// ListenRotateSignal listens for the specified signals and triggers rotation on the Rotatable object. +// It stops listening when the context is canceled. func ListenRotateSignal(ctx context.Context, r Rotatable, sig ...os.Signal) { c := make(chan os.Signal, 1) signal.Notify(c, sig...) From d8be22fa65443e9d4e12e0d57cfedf727bc1ad93 Mon Sep 17 00:00:00 2001 From: sunshineplan Date: Tue, 21 Oct 2025 15:05:57 +0800 Subject: [PATCH 34/40] mail --- mail/mail.go | 10 +++++- mail/message.go | 83 +++++++++++++++++++++++++++++-------------------- mail/receipt.go | 42 +++++++++++++++++-------- 3 files changed, 88 insertions(+), 47 deletions(-) diff --git a/mail/mail.go b/mail/mail.go index dbc6adc..0ceb1cb 100644 --- a/mail/mail.go +++ b/mail/mail.go @@ -20,6 +20,8 @@ type Dialer struct { Timeout time.Duration } +// Dial dials the SMTP server and performs optional STARTTLS / AUTH. +// It returns a connected smtp.Client. Caller should call client.Quit() when done. func (d *Dialer) Dial() (client *smtp.Client, err error) { if d.Timeout == 0 { d.Timeout = 3 * time.Minute @@ -37,6 +39,7 @@ func (d *Dialer) Dial() (client *smtp.Client, err error) { return } + // If connection is not TLS but server supports STARTTLS, upgrade. if !d.TLS { if ok, _ := client.Extension("STARTTLS"); ok { if err = client.StartTLS(nil); err != nil { @@ -45,6 +48,7 @@ func (d *Dialer) Dial() (client *smtp.Client, err error) { } } + // Authenticate if server advertises AUTH and credentials are provided. if ok, _ := client.Extension("AUTH"); ok && d.Account != "" && d.Password != "" { if err = client.Auth2(&smtp.Auth{Identity: "", Username: d.Account, Password: d.Password, Server: d.Server}); err != nil { client.Quit() @@ -55,15 +59,19 @@ func (d *Dialer) Dial() (client *smtp.Client, err error) { return client, nil } -// Send sends the given messages. +// Send sends the given messages using one established connection. +// It honors Dialer.Timeout for each message via context timeouts. +// The client connection will be Quit() when Send returns. func (d *Dialer) Send(msg ...*Message) error { client, err := d.Dial() if err != nil { return err } + // ensure we close the SMTP connection when the function returns defer client.Quit() for _, m := range msg { + // default From to the dialer's account if not set if m.From == nil { m.From = Receipt("", d.Account) } diff --git a/mail/message.go b/mail/message.go index 840677d..ffcf96e 100644 --- a/mail/message.go +++ b/mail/message.go @@ -7,7 +7,6 @@ import ( "crypto/rand" "encoding/base64" "fmt" - "math" "mime" "net/mail" "net/textproto" @@ -51,24 +50,27 @@ type Message struct { Attachments []*Attachment } +// RcptList returns a de-duplicated list of recipient email addresses while preserving order: +// first To, then Cc, then Bcc. func (m *Message) RcptList() (rcpts []string) { - list := make(map[string]struct{}) - for _, to := range m.To { - list[to.Address] = struct{}{} - } - for _, cc := range m.Cc { - list[cc.Address] = struct{}{} - } - for _, bcc := range m.Bcc { - list[bcc.Address] = struct{}{} - } - for k := range list { - rcpts = append(rcpts, k) + seen := make(map[string]bool) + for _, list := range [][]*mail.Address{m.To, m.Cc, m.Bcc} { + for _, addr := range list { + if !seen[addr.Address] { + rcpts = append(rcpts, addr.Address) + seen[addr.Address] = true + } + } } return } +// Bytes renders the RFC822-style message bytes for the message. +// id is used to create the Message-ID domain if provided; if empty, fallback to hostname. +// The produced message uses CRLF line endings as required by SMTP and includes correct +// MIME headers for single-part or multipart/mixed with attachments. func (m *Message) Bytes(id string) []byte { + // determine hostname part for Message-ID if id == "" { if m.From != nil && m.From.Address != "" { id = m.From.Address @@ -85,12 +87,18 @@ func (m *Message) Bytes(id string) []byte { var buf bytes.Buffer w := textproto.NewWriter(bufio.NewWriter(&buf)) + // Basic headers w.PrintfLine("MIME-Version: 1.0") w.PrintfLine("Date: %s", time.Now().Format(time.RFC1123Z)) w.PrintfLine("Message-ID: <%s>", id) - w.PrintfLine("Subject: =?UTF-8?B?%s?=", toBase64(m.Subject)) - w.PrintfLine("From: %s", m.From) - w.PrintfLine("To: %s", m.To) + w.PrintfLine("Subject: %s", encodeHeader(m.Subject)) + if m.From != nil { + w.PrintfLine("From: %s", m.From.String()) + } + // To / Cc headers (these are header fields visible in the message) + if len(m.To) > 0 { + w.PrintfLine("To: %s", m.To) + } if len(m.Cc) > 0 { w.PrintfLine("Cc: %s", m.Cc) } @@ -105,7 +113,7 @@ func (m *Message) Bytes(id string) []byte { w.PrintfLine(`Content-Type: %s; charset="UTF-8"`, m.ContentType) w.PrintfLine("Content-Transfer-Encoding: base64") w.PrintfLine("") - w.PrintfLine("%s", toBase64(m.Body)) + writeBase64BytesLines(w, []byte(m.Body)) if l := len(m.Attachments); l > 0 { for i, attachment := range m.Attachments { @@ -116,25 +124,14 @@ func (m *Message) Bytes(id string) []byte { w.PrintfLine("Content-Type: application/octet-stream") } if attachment.ContentID != "" { - w.PrintfLine(`Content-Disposition: inline; filename="=?UTF-8?B?%s?="`, toBase64(attachment.Filename)) + w.PrintfLine(`Content-Disposition: inline; filename="%s"`, encodeHeader(attachment.Filename)) w.PrintfLine("Content-ID: <%s>", attachment.ContentID) } else { - w.PrintfLine(`Content-Disposition: attachment; filename="=?UTF-8?B?%s?="`, toBase64(attachment.Filename)) + w.PrintfLine(`Content-Disposition: attachment; filename="%s"`, encodeHeader(attachment.Filename)) } w.PrintfLine("Content-Transfer-Encoding: base64") w.PrintfLine("") - - b := make([]byte, base64.StdEncoding.EncodedLen(len(attachment.Bytes))) - base64.StdEncoding.Encode(b, attachment.Bytes) - - // write base64 content in lines of up to 76 chars - for i, l := 0, int(math.Ceil(float64(len(b))/76)); i < l; i++ { - if i == l-1 { - w.PrintfLine("%s", b[i*76:]) - } else { - w.PrintfLine("%s", b[i*76:(i+1)*76]) - } - } + writeBase64BytesLines(w, attachment.Bytes) if i < l-1 { w.PrintfLine("--%s", boundary) @@ -147,10 +144,28 @@ func (m *Message) Bytes(id string) []byte { return buf.Bytes() } -func toBase64(str string) string { - return base64.StdEncoding.EncodeToString([]byte(str)) +// encodeHeader encodes a header value using RFC2047 only when non-ASCII chars are present. +// For pure ASCII strings it returns the original string unmodified. +func encodeHeader(s string) string { + for _, r := range s { + if r > 127 { + return fmt.Sprintf("=?UTF-8?B?%s?=", base64.StdEncoding.EncodeToString([]byte(s))) + } + } + return s +} + +// writeBase64BytesLines encodes bytes to base64 and writes it in lines of up to 76 chars. +func writeBase64BytesLines(w *textproto.Writer, b []byte) { + enc := base64.StdEncoding.EncodeToString(b) + // number of lines + for i := 0; i < len(enc); i += 76 { + end := min(i+76, len(enc)) + w.PrintfLine("%s", enc[i:end]) + } } +// randomString returns a hex string of length 2*n (because each byte => two hex chars). func randomString(n int) string { b := make([]byte, n) if _, err := rand.Read(b); err != nil { @@ -159,6 +174,8 @@ func randomString(n int) string { return fmt.Sprintf("%x", b) } +// generateMsgID generates a message id using the provided reference (domain or email). +// If ref has an @, use the right-hand side as domain, otherwise use ref as domain. func generateMsgID(ref string) string { s := strings.Split(ref, "@") return fmt.Sprintf("%s@%s", randomString(16), s[len(s)-1]) diff --git a/mail/receipt.go b/mail/receipt.go index 132998c..2cc3bde 100644 --- a/mail/receipt.go +++ b/mail/receipt.go @@ -5,23 +5,25 @@ import ( "encoding/json" "fmt" "net/mail" - "strconv" "strings" ) var ( - _ encoding.TextUnmarshaler = (*Receipts)(nil) + _ encoding.TextUnmarshaler = new(Receipts) _ encoding.TextMarshaler = Receipts{} - _ json.Unmarshaler = (*Receipts)(nil) + _ json.Unmarshaler = new(Receipts) _ json.Marshaler = Receipts{} ) +// Receipt creates a mail.Address pointer from name and address. func Receipt(name, address string) *mail.Address { return &mail.Address{Name: name, Address: address} } +// Receipts represents a list of mail addresses. type Receipts []*mail.Address +// ParseReceipts parses a comma/semicolon separated list of addresses into Receipts. func ParseReceipts(rcpts string) (Receipts, error) { addresses, err := mail.ParseAddressList(rcpts) if err != nil { @@ -30,16 +32,24 @@ func ParseReceipts(rcpts string) (Receipts, error) { return Receipts(addresses), nil } +// List returns a slice of address strings (just the email addresses). func (rcpts Receipts) List() []string { - var s []string - for _, rcpt := range rcpts { - s = append(s, rcpt.Address) + s := make([]string, len(rcpts)) + for i, rcpt := range rcpts { + s[i] = rcpt.Address } return s } +// String returns the addresses joined in the standard RFC 5322 way +// using the String method of mail.Address. func (rcpts Receipts) String() string { + if len(rcpts) == 0 { + return "" + } var b strings.Builder + // approximate average length to avoid repeated allocations + b.Grow(len(rcpts) * 32) for i, rcpt := range rcpts { if i != 0 { b.WriteString(", ") @@ -49,6 +59,7 @@ func (rcpts Receipts) String() string { return b.String() } +// UnmarshalText implements encoding.TextUnmarshaler func (rcpts *Receipts) UnmarshalText(text []byte) error { addresses, err := ParseReceipts(string(text)) if err != nil { @@ -58,16 +69,18 @@ func (rcpts *Receipts) UnmarshalText(text []byte) error { return nil } +// UnmarshalJSON supports either a single string (comma separated) or an array of strings. func (rcpts *Receipts) UnmarshalJSON(b []byte) error { - if unquote, err := strconv.Unquote(string(b)); err == nil { - return rcpts.UnmarshalText([]byte(unquote)) + var s string + if err := json.Unmarshal(b, &s); err == nil { + return rcpts.UnmarshalText([]byte(s)) } - var s []string - if err := json.Unmarshal(b, &s); err != nil { - return nil + var list []string + if err := json.Unmarshal(b, &list); err != nil { + return err } var addresses []*mail.Address - for _, i := range s { + for _, i := range list { address, err := mail.ParseAddress(i) if err != nil { return err @@ -78,10 +91,13 @@ func (rcpts *Receipts) UnmarshalJSON(b []byte) error { return nil } +// MarshalText implements encoding.TextMarshaler and returns the RFC5322 representation. func (rcpts Receipts) MarshalText() ([]byte, error) { return []byte(rcpts.String()), nil } +// MarshalJSON implements json.Marshaler. +// It encodes the receipts as a JSON array of email address strings. func (rcpts Receipts) MarshalJSON() ([]byte, error) { - return []byte(rcpts.String()), nil + return json.Marshal(rcpts.String()) } From 11f7b66b3c394f1c114e200b469855359fe8adff Mon Sep 17 00:00:00 2001 From: sunshineplan Date: Tue, 21 Oct 2025 15:29:48 +0800 Subject: [PATCH 35/40] ocr --- ocr/ocr.go | 17 ++++++++++++----- 1 file changed, 12 insertions(+), 5 deletions(-) diff --git a/ocr/ocr.go b/ocr/ocr.go index 86ab53f..81bc018 100644 --- a/ocr/ocr.go +++ b/ocr/ocr.go @@ -4,6 +4,7 @@ import ( "bytes" "encoding/json" "errors" + "fmt" "io" "mime/multipart" "net/http" @@ -16,6 +17,12 @@ const ( var errNoResult = errors.New("no ocr result") +type ocrResponse struct { + ParsedResults []struct { + ParsedText string + } +} + // OCR reads image from reader r and converts it into string. func OCR(r io.Reader) (string, error) { return OCRWithClient(r, http.DefaultClient) @@ -49,16 +56,16 @@ func OCRWithClient(r io.Reader, client *http.Client) (string, error) { } defer resp.Body.Close() + if resp.StatusCode != http.StatusOK { + return "", fmt.Errorf("ocr request failed: %s", resp.Status) + } + b, err := io.ReadAll(resp.Body) if err != nil { return "", err } - var res struct { - ParsedResults []struct { - ParsedText string - } - } + var res ocrResponse if err := json.Unmarshal(b, &res); err != nil { return "", err } From a6a5ce99fe6d2973a95a68024788e2b245b0c731 Mon Sep 17 00:00:00 2001 From: sunshineplan Date: Tue, 21 Oct 2025 15:43:21 +0800 Subject: [PATCH 36/40] slice --- slice/slice.go | 9 ++++----- 1 file changed, 4 insertions(+), 5 deletions(-) diff --git a/slice/slice.go b/slice/slice.go index 2d272e9..894a27e 100644 --- a/slice/slice.go +++ b/slice/slice.go @@ -1,13 +1,12 @@ package slice -// Deduplicate removes duplicate items in slice. +// Deduplicate removes duplicate elements while preserving order. func Deduplicate[S ~[]E, E comparable](s S) S { - if s == nil { + if len(s) == 0 { return nil } - - res := S{} - m := make(map[E]struct{}) + m := make(map[E]struct{}, len(s)) + res := make(S, 0, len(s)) for _, i := range s { if _, ok := m[i]; !ok { m[i] = struct{}{} From 38e4e9c275ccf3c118e431d1dba46391dbe6ccc9 Mon Sep 17 00:00:00 2001 From: sunshineplan Date: Tue, 21 Oct 2025 16:50:28 +0800 Subject: [PATCH 37/40] unit --- unit/bytesize.go | 99 ++++++++++++++++++++++++++++++------------- unit/bytesize_test.go | 82 ++++++++++++++++++++++++++++++++++- 2 files changed, 151 insertions(+), 30 deletions(-) diff --git a/unit/bytesize.go b/unit/bytesize.go index 30ee81d..a8a9928 100644 --- a/unit/bytesize.go +++ b/unit/bytesize.go @@ -15,6 +15,20 @@ var ( var byteSizeRegexp = regexp.MustCompile(`^(\d+(?:\.\d+)?) ?((?i)[KMGTPE]?B?)$`) +// ByteSize represents a quantity of bytes. +type ByteSize int64 + +// Common byte size units. +const ( + B ByteSize = 1 + KB ByteSize = 1 << (10 * iota) + MB + GB + TB + PB + EB +) + var ( byteSizeStr = map[ByteSize]string{ B: "B", @@ -36,18 +50,8 @@ var ( } ) -type ByteSize int64 - -const ( - B ByteSize = 1 - KB ByteSize = 1 << (10 * iota) - MB - GB - TB - PB - EB -) - +// ParseByteSize parses a human-readable size string (e.g. "1.5GB", "100 KB") +// and returns the corresponding ByteSize value. func ParseByteSize(s string) (ByteSize, error) { s = strings.TrimSpace(s) res := byteSizeRegexp.FindStringSubmatch(s) @@ -65,41 +69,61 @@ func ParseByteSize(s string) (ByteSize, error) { return ByteSize(v * float64(strByteSize[unit])), nil } -func NewByteSize(n float64, unit ByteSize) ByteSize { - return ByteSize(n * float64(unit)) +// MustParseByteSize parses a size string and panics if it is invalid. +func MustParseByteSize(s string) ByteSize { + v, err := ParseByteSize(s) + if err != nil { + panic(err) + } + return v +} + +// NewByteSize creates a ByteSize from a numeric value and a unit string. +// Example: NewByteSize(1.5, "GB") -> 1.5GB. +// Returns an error if the unit is not recognized. +func NewByteSize(n float64, unit string) (ByteSize, error) { + unit = strings.ToUpper(strings.TrimSpace(unit)) + unit = strings.TrimSuffix(unit, "B") + if unit == "" { + unit = "B" + } + bs, ok := strByteSize[unit] + if !ok { + return 0, errors.New("unknown byte size unit: " + unit) + } + return ByteSize(n * float64(bs)), nil } +// DefaultByteSizeFormatFloat formats a float with up to 2 decimals, +// trimming trailing zeros and the decimal point if unnecessary. func DefaultByteSizeFormatFloat(f float64) string { s := strconv.FormatFloat(f, 'f', 2, 64) s = strings.TrimRight(s, "0") return strings.TrimSuffix(s, ".") } +// ByteSizeFormatFloat defines how floating-point values are formatted in String(). +// It can be replaced to customize output precision. var ByteSizeFormatFloat = DefaultByteSizeFormatFloat +// String returns a human-readable representation of the byte size. +// e.g. 1536 -> "1.5KB", 1048576 -> "1MB". func (n ByteSize) String() string { - unit := B - switch { - case n >= EB: - unit = EB - case n >= PB: - unit = PB - case n >= TB: - unit = TB - case n >= GB: - unit = GB - case n >= MB: - unit = MB - case n >= KB: - unit = KB + units := []ByteSize{EB, PB, TB, GB, MB, KB, B} + for _, unit := range units { + if n >= unit { + return ByteSizeFormatFloat(float64(n)/float64(unit)) + byteSizeStr[unit] + } } - return ByteSizeFormatFloat(float64(n)/float64(unit)) + byteSizeStr[unit] + return "0B" } +// MarshalText implements the encoding.TextMarshaler interface. func (b ByteSize) MarshalText() ([]byte, error) { return []byte(b.String()), nil } +// UnmarshalText implements the encoding.TextUnmarshaler interface. func (b *ByteSize) UnmarshalText(text []byte) error { bytes, err := ParseByteSize(string(text)) if err != nil { @@ -108,3 +132,20 @@ func (b *ByteSize) UnmarshalText(text []byte) error { *b = bytes return nil } + +// To converts the ByteSize to the specified unit and returns its string representation. +func (b ByteSize) To(unit string, decimals int) (string, error) { + unit = strings.ToUpper(strings.TrimSpace(unit)) + unit = strings.TrimSuffix(unit, "B") + if unit == "" { + unit = "B" + } + base, ok := strByteSize[unit] + if !ok { + return "", errors.New("invalid unit: " + unit) + } + value := float64(b) / float64(base) + s := strconv.FormatFloat(value, 'f', decimals, 64) + s = strings.TrimRight(s, "0") + return strings.TrimSuffix(s, ".") + byteSizeStr[base], nil +} diff --git a/unit/bytesize_test.go b/unit/bytesize_test.go index 1bec500..ee19b11 100644 --- a/unit/bytesize_test.go +++ b/unit/bytesize_test.go @@ -33,7 +33,6 @@ func TestByteSize(t *testing.T) { {KB, "1KB"}, {10 * MB, "10MB"}, {1536 * MB, "1.5GB"}, - {NewByteSize(1.5, GB), "1.5GB"}, } { if bytesize := ByteSize(testcase.n).String(); bytesize != testcase.str { t.Errorf("expected %q; got %q", testcase.str, bytesize) @@ -45,3 +44,84 @@ func TestByteSize(t *testing.T) { } } } + +func TestNewByteSize(t *testing.T) { + tests := []struct { + n float64 + unit string + want ByteSize + expectErr bool + }{ + {1, "B", B, false}, + {1, "KB", KB, false}, + {1, "mb", MB, false}, + {1.5, "GB", ByteSize(1.5 * float64(GB)), false}, + {2, "tb", ByteSize(2 * float64(TB)), false}, + {0, "MB", 0, false}, + {1, "XYZ", 0, true}, + } + + for _, tt := range tests { + got, err := NewByteSize(tt.n, tt.unit) + if tt.expectErr { + if err == nil { + t.Errorf("NewByteSize(%v, %q) expected error, got nil", tt.n, tt.unit) + } + continue + } + if err != nil { + t.Errorf("NewByteSize(%v, %q) unexpected error: %v", tt.n, tt.unit, err) + continue + } + if got != tt.want { + t.Errorf("NewByteSize(%v, %q) = %v, want %v", tt.n, tt.unit, got, tt.want) + } + } +} + +func TestTo(t *testing.T) { + tests := []struct { + size ByteSize + unit string + decimals int + want string + expectErr bool + }{ + // integer results + {ByteSize(1024), "KB", 0, "1KB", false}, + {ByteSize(1536), "KB", 0, "2KB", false}, + {ByteSize(1048576), "MB", 0, "1MB", false}, + + // fractional results + {ByteSize(1536), "KB", 2, "1.5KB", false}, + {ByteSize(1536), "MB", 3, "0.001MB", false}, + {MustParseByteSize("1.5GB"), "GB", 2, "1.5GB", false}, + + // small numbers < 1 + {ByteSize(1), "KB", 6, "0.000977KB", false}, + {ByteSize(500), "MB", 6, "0.000477MB", false}, + + // decimals = 0 for fractional -> rounds to nearest int + {ByteSize(1536), "KB", 0, "2KB", false}, + + // invalid unit + {ByteSize(1024), "XYZ", 2, "", true}, + } + + for _, tt := range tests { + got, err := tt.size.To(tt.unit, tt.decimals) + if tt.expectErr { + if err == nil { + t.Errorf("To(%q, %d) expected error, got nil", tt.unit, tt.decimals) + } + continue + } + if err != nil { + t.Errorf("To(%q, %d) unexpected error: %v", tt.unit, tt.decimals, err) + continue + } + if got != tt.want { + t.Errorf("To(%q, %d) = %q, want %q", tt.unit, tt.decimals, got, tt.want) + } + } +} From dcf9cba3128132dcde242b613dff41ed156cefd7 Mon Sep 17 00:00:00 2001 From: sunshineplan Date: Wed, 22 Oct 2025 13:50:43 +0800 Subject: [PATCH 38/40] processing/text --- processing/text/processor.go | 149 +++++++++++++++++++++++++----- processing/text/processor_test.go | 12 +-- processing/text/task.go | 59 +++++------- processing/text/task_test.go | 8 +- 4 files changed, 162 insertions(+), 66 deletions(-) diff --git a/processing/text/processor.go b/processing/text/processor.go index 0b43d0a..f56f9dd 100644 --- a/processing/text/processor.go +++ b/processing/text/processor.go @@ -1,64 +1,169 @@ package text import ( + "fmt" "regexp" "strings" ) +// Processor defines a generic interface for text processors. +// Each processor can optionally be executed only once (Once == true) +// and provides a human-readable description for debugging or logging. type Processor interface { + // Describe returns a short description of this processor. + Describe() string + // Once reports whether this processor should run only once. Once() bool + // Process performs the actual text transformation. Process(string) (string, error) } var ( - _ Processor = processor{} - _ Processor = RemoveByRegexp{} - _ Processor = Cut{} - _ Processor = Trim{} + _ Processor = new(processor) + _ Processor = new(multiProcessor) + _ Processor = RegexpRemover{} + _ Processor = Cutter{} + _ Processor = Trimmer{} ) +// processor is a generic implementation of Processor, +// allowing you to wrap any custom text function. type processor struct { + desc string once bool fn func(string) (string, error) } -func NewProcessor(once bool, fn func(string) (string, error)) Processor { - return processor{once, fn} +// NewProcessor creates a new Processor from a function. +// +// desc - short description for debugging +// once - whether this processor should be executed only once +// fn - transformation function taking a string and returning a string/error +func NewProcessor(desc string, once bool, fn func(string) (string, error)) Processor { + return &processor{desc, once, fn} } -func WrapFunc(fn func(string) string) func(string) (string, error) { - return func(s string) (string, error) { return fn(s), nil } -} +// Describe returns the processor's description string. +func (p *processor) Describe() string { return p.desc } + +// Once reports whether this processor should run only once. +func (p *processor) Once() bool { return p.once } -func (p processor) Once() bool { return p.once } -func (p processor) Process(s string) (string, error) { +// Process executes the wrapped function to transform the text. +func (p *processor) Process(s string) (string, error) { return p.fn(s) } -type RemoveByRegexp struct { - *regexp.Regexp +// multiProcessor executes multiple sub-processors as a single Processor. +type multiProcessor struct { + desc string + once bool + processors []Processor +} + +// NewMultiProcessor creates a new MultiProcessor. +// +// desc - human-readable name +// once - whether this processor should execute only once +// procs - list of sub-processors +func NewMultiProcessor(desc string, once bool, procs ...Processor) Processor { + return &multiProcessor{desc, once, procs} +} + +// Describe returns the description for debugging or logging. +func (m *multiProcessor) Describe() string { return m.desc } + +// Once reports whether the MultiProcessor should run only once. +func (m *multiProcessor) Once() bool { return m.once } + +// Process executes all sub-processors. +func (m *multiProcessor) Process(s string) (string, error) { + for _, p := range m.processors { + var err error + s, err = p.Process(s) + if err != nil { + return "", err + } + } + return s, nil } -func (RemoveByRegexp) Once() bool { return false } -func (p RemoveByRegexp) Process(s string) (string, error) { - return p.ReplaceAllString(s, ""), nil +// RegexpRemover removes substrings that match the given regular expression. +type RegexpRemover struct { + Re *regexp.Regexp +} + +// Describe returns a string representation of the RegexpRemover. +func (p RegexpRemover) Describe() string { return fmt.Sprintf("RegexpRemover(%q)", p.Re.String()) } + +// Once always returns false, meaning this processor can be applied repeatedly. +func (RegexpRemover) Once() bool { return false } + +// Process removes all matches of the regular expression from the input string. +func (p RegexpRemover) Process(s string) (string, error) { + return p.Re.ReplaceAllString(s, ""), nil } -type Cut struct { +// Cutter splits the input by the given separator and keeps only the part before it. +type Cutter struct { Sep string } -func (Cut) Once() bool { return true } -func (p Cut) Process(s string) (string, error) { +// Describe returns a string representation of the Cutter. +func (p Cutter) Describe() string { return fmt.Sprintf("Cutter(%q)", p.Sep) } + +// Once always returns true, meaning this processor should be run only once. +func (Cutter) Once() bool { return true } + +// Process cuts the string at the first occurrence of the separator and returns the left part. +func (p Cutter) Process(s string) (string, error) { before, _, _ := strings.Cut(s, p.Sep) return before, nil } -type Trim struct { +// Trimmer removes all leading and trailing characters from the given cutset. +type Trimmer struct { Cutset string } -func (Trim) Once() bool { return false } -func (p Trim) Process(s string) (string, error) { +// Describe returns a string representation of the Trimmer. +func (p Trimmer) Describe() string { return fmt.Sprintf("Trimmer(%q)", p.Cutset) } + +// Once always returns false, meaning this processor can be applied repeatedly. +func (Trimmer) Once() bool { return false } + +// Process trims all leading and trailing characters in Cutset from the input string. +func (p Trimmer) Process(s string) (string, error) { return strings.Trim(s, p.Cutset), nil } + +// TrimSpace returns a processor that removes leading and trailing spaces. +func TrimSpace() Processor { + return NewProcessor("TrimSpace", false, WrapFunc(strings.TrimSpace)) +} + +// CutSpace returns a processor that extracts the first word in the input string. +func CutSpace() Processor { + return NewProcessor("CutSpace", true, func(s string) (string, error) { + if fs := strings.Fields(s); len(fs) > 0 { + return fs[0], nil + } + return "", nil + }) +} + +// RemoveParentheses returns a processor that remove both western and full-width parentheses. +func RemoveParentheses() Processor { + return NewMultiProcessor( + "RemoveParentheses", + true, + RegexpRemover{regexp.MustCompile(`\([^\)]*\)`)}, + RegexpRemover{regexp.MustCompile(`([^)]*)`)}, + ) +} + +// WrapFunc wraps a simple string -> string function +// into a function matching the Processor signature. +func WrapFunc(fn func(string) string) func(string) (string, error) { + return func(s string) (string, error) { return fn(s), nil } +} diff --git a/processing/text/processor_test.go b/processing/text/processor_test.go index 0ffa220..1c87ac4 100644 --- a/processing/text/processor_test.go +++ b/processing/text/processor_test.go @@ -5,7 +5,7 @@ import ( "testing" ) -func TestRemoveByRegexp(t *testing.T) { +func TestRegexpRemover(t *testing.T) { for i, testcase := range []struct { re *regexp.Regexp s string @@ -15,7 +15,7 @@ func TestRemoveByRegexp(t *testing.T) { {regexp.MustCompile(`\d+`), "abc123", "abc"}, {regexp.MustCompile(`\d+$`), "123abc456", "123abc"}, } { - if res, err := NewTasks().Append(RemoveByRegexp{testcase.re}).Process(testcase.s); err != nil { + if res, err := NewTasks().Append(RegexpRemover{testcase.re}).Process(testcase.s); err != nil { t.Error(err) } else if res != testcase.expected { t.Errorf("#%d: got %q; want %q", i, res, testcase.expected) @@ -23,7 +23,7 @@ func TestRemoveByRegexp(t *testing.T) { } } -func TestCut(t *testing.T) { +func TestCutter(t *testing.T) { for i, testcase := range []struct { seq string s string @@ -34,7 +34,7 @@ func TestCut(t *testing.T) { {" ", " abc 123", ""}, {"abc", "123abc456", "123"}, } { - if res, err := NewTasks().Append(Cut{testcase.seq}).Process(testcase.s); err != nil { + if res, err := NewTasks().Append(Cutter{testcase.seq}).Process(testcase.s); err != nil { t.Error(err) } else if res != testcase.expected { t.Errorf("#%d: got %q; want %q", i, res, testcase.expected) @@ -42,7 +42,7 @@ func TestCut(t *testing.T) { } } -func TestTrim(t *testing.T) { +func TestTrimmer(t *testing.T) { for i, testcase := range []struct { cutset string s string @@ -53,7 +53,7 @@ func TestTrim(t *testing.T) { {" ", " abc 123\n", "abc 123\n"}, {" \n", " abc 123\n", "abc 123"}, } { - if res, err := NewTasks().Append(Trim{testcase.cutset}).Process(testcase.s); err != nil { + if res, err := NewTasks().Append(Trimmer{testcase.cutset}).Process(testcase.s); err != nil { t.Error(err) } else if res != testcase.expected { t.Errorf("#%d: got %q; want %q", i, res, testcase.expected) diff --git a/processing/text/task.go b/processing/text/task.go index b6241d6..278e77e 100644 --- a/processing/text/task.go +++ b/processing/text/task.go @@ -1,73 +1,64 @@ package text import ( - "regexp" - "strings" + "fmt" ) +var MaxIter = 100 + +// Tasks represents an ordered list of text processors. +// Each processor in the list will be executed sequentially, +// and repeated until the text no longer changes. type Tasks struct { tasks []Processor } +// NewTasks creates a new Tasks instance with the provided processors. func NewTasks(tasks ...Processor) *Tasks { return &Tasks{tasks} } -func (t *Tasks) Process(s string) (string, error) { - var err error - for first, output := true, s; ; { +// Process executes all configured processors on a single string. +func (t *Tasks) Process(s string) (output string, err error) { + output = s + first := true + for range MaxIter { for _, task := range t.tasks { if first || !task.Once() { s, err = task.Process(s) if err != nil { - return "", err + return "", fmt.Errorf("%s error: %w", task.Describe(), err) } } } if s == output { - return output, nil - } else { - output = s + return } + output = s if first { first = false } } + err = fmt.Errorf("max iteration limit reached") + return } -func (t *Tasks) ProcessAll(s []string) ([]string, error) { - var output []string - for _, i := range s { - s, err := t.Process(i) +// ProcessAll applies all processors to a slice of strings, returning the processed results. +// If any processor returns an error, processing stops and the error is returned. +func (t *Tasks) ProcessAll(inputs []string) ([]string, error) { + output := make([]string, len(inputs)) + for i, in := range inputs { + out, err := t.Process(in) if err != nil { return nil, err } - output = append(output, s) + output[i] = out } return output, nil } +// Append adds one or more processors to the task list and returns the updated instance. func (t *Tasks) Append(tasks ...Processor) *Tasks { t.tasks = append(t.tasks, tasks...) return t } - -func (t *Tasks) TrimSpace() *Tasks { - return t.Append(NewProcessor(false, WrapFunc(strings.TrimSpace))) -} - -func (t *Tasks) CutSpace() *Tasks { - return t.Append(NewProcessor(true, func(s string) (string, error) { - if fs := strings.Fields(s); len(fs) > 0 { - return fs[0], nil - } - return "", nil - })) -} - -func (t *Tasks) RemoveParentheses() *Tasks { - return t.Append( - RemoveByRegexp{regexp.MustCompile(`\([^)]*\)`)}, - RemoveByRegexp{regexp.MustCompile(`([^)]*)`)}, - ) -} diff --git a/processing/text/task_test.go b/processing/text/task_test.go index 79e6585..976c0ff 100644 --- a/processing/text/task_test.go +++ b/processing/text/task_test.go @@ -3,7 +3,7 @@ package text import "testing" func TestTrimSpace(t *testing.T) { - task := NewTasks().TrimSpace() + task := NewTasks(TrimSpace()) for i, testcase := range []struct { s string expected string @@ -22,7 +22,7 @@ func TestTrimSpace(t *testing.T) { } func TestCutSpace(t *testing.T) { - task := NewTasks().CutSpace() + task := NewTasks(CutSpace()) for i, testcase := range []struct { s string expected string @@ -41,7 +41,7 @@ func TestCutSpace(t *testing.T) { } func TestRemoveParentheses(t *testing.T) { - task := NewTasks().RemoveParentheses() + task := NewTasks(RemoveParentheses()) for i, testcase := range []struct { s string expected string @@ -62,7 +62,7 @@ func TestRemoveParentheses(t *testing.T) { } func TestTasks(t *testing.T) { - task := NewTasks().TrimSpace().RemoveParentheses().CutSpace() + task := NewTasks(TrimSpace(), RemoveParentheses(), CutSpace()) for i, testcase := range []struct { s string expected string From 180f9c4c4cca9bdea9ad27642fb6dbb0bd1791b3 Mon Sep 17 00:00:00 2001 From: sunshineplan Date: Wed, 22 Oct 2025 15:27:25 +0800 Subject: [PATCH 39/40] smtp --- smtp/auth.go | 30 +++++++++++++++++------------- smtp/smtp.go | 6 +++--- 2 files changed, 20 insertions(+), 16 deletions(-) diff --git a/smtp/auth.go b/smtp/auth.go index dde5335..8bb308b 100644 --- a/smtp/auth.go +++ b/smtp/auth.go @@ -1,44 +1,49 @@ package smtp import ( + "bytes" "errors" + "fmt" "net/smtp" - "strings" + "slices" ) +// loginAuth implements the LOGIN authentication mechanism for SMTP. var _ smtp.Auth = &loginAuth{} +// loginAuth holds the credentials and server information for LOGIN authentication. type loginAuth struct { - username, password, server string + username, password, host string } +// Start initiates the LOGIN authentication process, verifying TLS and server name. func (a *loginAuth) Start(server *smtp.ServerInfo) (string, []byte, error) { if !server.TLS && !isLocalhost(server.Name) { return "", nil, errors.New("unencrypted connection") } - if server.Name != a.server { + if server.Name != a.host { return "", nil, errors.New("wrong server name") } resp := []byte(a.username) return "LOGIN", resp, nil } +// Next handles the server's challenge, responding with username or password as needed. func (a *loginAuth) Next(fromServer []byte, more bool) ([]byte, error) { if more { - if strings.Contains(string(fromServer), "Username") { - resp := []byte(a.username) - return resp, nil + if bytes.Contains(fromServer, []byte("Username")) { + return []byte(a.username), nil } - if strings.Contains(string(fromServer), "Password") { - resp := []byte(a.password) - return resp, nil + if bytes.Contains(fromServer, []byte("Password")) { + return []byte(a.password), nil } // We've already sent everything. - return nil, errors.New("unexpected server challenge") + return nil, fmt.Errorf("unexpected server challenge: %s", string(fromServer)) } return nil, nil } +// isLocalhost checks if the server name is a localhost address. func isLocalhost(name string) bool { return name == "localhost" || name == "127.0.0.1" || name == "::1" } @@ -57,14 +62,13 @@ type Auth struct { func (c *Client) Auth2(auth *Auth) error { // Auto select auth mode var a smtp.Auth - if auths := strings.Join(c.auth, " "); strings.Contains(auths, "CRAM-MD5") { + if slices.Contains(c.auth, "CRAM-MD5") { a = smtp.CRAMMD5Auth(auth.Username, auth.Password) - } else if strings.Contains(auths, "PLAIN") { + } else if slices.Contains(c.auth, "PLAIN") { a = smtp.PlainAuth(auth.Identity, auth.Username, auth.Password, auth.Server) } else { a = &loginAuth{auth.Username, auth.Password, auth.Server} } - return c.Auth(a) } diff --git a/smtp/smtp.go b/smtp/smtp.go index cd21083..68b3635 100644 --- a/smtp/smtp.go +++ b/smtp/smtp.go @@ -134,7 +134,7 @@ func (c *Client) helo() error { func (c *Client) ehlo() error { _, msg, err := c.Cmd(250, "EHLO %s", c.localName) if err != nil { - return err + return fmt.Errorf("failed to send HELO: %w", err) } ext := make(map[string]string) extList := strings.Split(msg, "\n") @@ -149,7 +149,7 @@ func (c *Client) ehlo() error { c.auth = strings.Split(mechs, " ") } c.ext = ext - return err + return nil } // StartTLS sends the STARTTLS command and encrypts all further communication. @@ -161,7 +161,7 @@ func (c *Client) StartTLS(config *tls.Config) error { } _, _, err := c.Cmd(220, "STARTTLS") if err != nil { - return err + return fmt.Errorf("failed to send STARTTLS: %w", err) } if config == nil { config = &tls.Config{ServerName: c.serverName} From a7b42b32caa2c4bcc6a1d6f8b6046849ab821fd2 Mon Sep 17 00:00:00 2001 From: sunshineplan Date: Wed, 22 Oct 2025 16:14:20 +0800 Subject: [PATCH 40/40] txt --- txt/export.go | 8 +++++--- txt/reader.go | 33 +++++++++++++++++++++++---------- txt/reader_test.go | 5 ++++- txt/writer.go | 2 -- 4 files changed, 32 insertions(+), 16 deletions(-) diff --git a/txt/export.go b/txt/export.go index 9a69be2..d5e8832 100644 --- a/txt/export.go +++ b/txt/export.go @@ -5,18 +5,20 @@ import ( "os" ) -// Export writes contents to writer w. +// Export writes the contents to w using a buffered Writer. +// It returns any error encountered during writing or flushing. func Export(contents []string, w io.Writer) error { return NewWriter(w).WriteAll(contents) } -// ExportFile writes contents to file. +// ExportFile writes the contents to the specified file, overwriting it if it exists. +// The file is created with default permissions (0666, subject to umask). +// It returns any error encountered during file creation or writing. func ExportFile(contents []string, file string) error { f, err := os.Create(file) if err != nil { return err } defer f.Close() - return Export(contents, f) } diff --git a/txt/reader.go b/txt/reader.go index f80cd70..919a715 100644 --- a/txt/reader.go +++ b/txt/reader.go @@ -7,40 +7,53 @@ import ( "os" ) +// Reader provides buffered reading from an io.Reader, splitting input by lines. type Reader struct { scanner *bufio.Scanner } +// NewReader returns a new Reader that reads from r using a buffered scanner. +// It splits input by lines using the default bufio.Scanner line-splitting behavior (\n). func NewReader(r io.Reader) *Reader { return &Reader{bufio.NewScanner(r)} } -func (r *Reader) Iter() iter.Seq[string] { - return func(yield func(string) bool) { +// Iter returns an iterator over the lines read from the underlying io.Reader. +// Each iteration yields a line and a nil error, or an empty string and an error +// if the scanner encounters an error. +func (r *Reader) Iter() iter.Seq2[string, error] { + return func(yield func(string, error) bool) { for r.scanner.Scan() { - if !yield(r.scanner.Text()) { + if !yield(r.scanner.Text(), nil) { return } } + if err := r.scanner.Err(); err != nil { + yield("", err) + } } } -// ReadAll reads all contents from r. -func ReadAll(r io.Reader) []string { +// ReadAll reads all lines from r and returns them as a slice of strings. +// It returns an error if the underlying scanner encounters an error. +func ReadAll(r io.Reader) ([]string, error) { var s []string - for i := range NewReader(r).Iter() { + for i, err := range NewReader(r).Iter() { + if err != nil { + return nil, err + } s = append(s, i) } - return s + return s, nil } -// ReadFile reads all contents from file. +// ReadFile reads all lines from the specified file and returns them as a slice of strings. +// It returns an error if the file cannot be opened or read. func ReadFile(file string) ([]string, error) { f, err := os.Open(file) if err != nil { return nil, err } defer f.Close() - - return ReadAll(f), nil + return ReadAll(f) } diff --git a/txt/reader_test.go b/txt/reader_test.go index b9cddb0..6c9d434 100644 --- a/txt/reader_test.go +++ b/txt/reader_test.go @@ -11,7 +11,10 @@ func TestReader(t *testing.T) { B C ` - res := ReadAll(strings.NewReader(txt)) + res, err := ReadAll(strings.NewReader(txt)) + if err != nil { + t.Fatal(err) + } if expect := []string{"A", "B", "C"}; !slices.Equal(expect, res) { t.Errorf("expected %v; got %v", expect, res) } diff --git a/txt/writer.go b/txt/writer.go index 4deaeb2..44a6c75 100644 --- a/txt/writer.go +++ b/txt/writer.go @@ -31,7 +31,6 @@ func (w *Writer) WriteLine(s string) (int, error) { if w.UseCRLF { return w.WriteString(s + "\r\n") } - return w.WriteString(s + "\n") } @@ -50,6 +49,5 @@ func (w *Writer) WriteAll(contents []string) error { return err } } - return w.Flush() }