diff --git a/choice/choice.go b/choice/choice.go index 6833148..4f426af 100644 --- a/choice/choice.go +++ b/choice/choice.go @@ -5,6 +5,7 @@ import ( "bytes" "errors" "fmt" + "io" "os" "strconv" "strings" @@ -29,17 +30,13 @@ func Menu[E any](choices []E, showQuit bool) string { if len(choices) == 0 { return "" } - var digit int - for n := len(choices); n != 0; digit++ { - n /= 10 - } - option := fmt.Sprintf("%%%dd", digit) + digit := len(strconv.Itoa(len(choices))) var b strings.Builder for i, choice := range choices { - fmt.Fprintf(&b, "%s. %s\n", fmt.Sprintf(option, i+1), choiceStr(choice)) + fmt.Fprintf(&b, "%*d. %s\n", digit, i+1, choiceStr(choice)) } if showQuit { - fmt.Fprintf(&b, "%s. Quit\n", fmt.Sprintf(fmt.Sprintf("%%%ds", digit), "q")) + fmt.Fprintf(&b, "%*d. Quit\n", digit, 0) } return b.String() } @@ -51,13 +48,8 @@ var _ error = choiceError("") type choiceError string -func (err choiceError) Error() string { - return "bad choice: " + string(err) -} - -func (choiceError) Unwrap() error { - return ErrBadChoice -} +func (err choiceError) Error() string { return "bad choice: " + string(err) } +func (choiceError) Unwrap() error { return ErrBadChoice } func choose[E any](choice string, choices []E) (res E, err error) { n, err := strconv.Atoi(choice) @@ -79,6 +71,10 @@ func Choose[E any](choices []E) (choice bool, res E, err error) { // ChooseWithDefault function allows the user to make a choice from the given options with an optional default value. func ChooseWithDefault[E any](choices []E, def int) (choice bool, res E, err error) { + return chooseWithDefault(os.Stdin, choices, def) +} + +func chooseWithDefault[E any](r io.Reader, choices []E, def int) (choice bool, res E, err error) { if n := len(choices); n == 0 { err = errors.New("no choices") return @@ -92,26 +88,31 @@ func ChooseWithDefault[E any](choices []E, def int) (choice bool, res E, err err } else { prompt = "Please choose: " } - scanner := bufio.NewScanner(os.Stdin) - var b []byte - if def <= 0 { - for len(b) == 0 { - fmt.Print(prompt) - scanner.Scan() - b = bytes.TrimSpace(scanner.Bytes()) - } - } else { - fmt.Print(prompt) - scanner.Scan() - b = bytes.TrimSpace(scanner.Bytes()) - if len(b) == 0 { - return true, choices[def-1], nil - } + b, err := readLine(bufio.NewScanner(r), prompt, def <= 0) + if err != nil { + return + } + if len(b) == 0 && def > 0 { + return true, choices[def-1], nil } - if bytes.EqualFold(b, []byte("q")) { + if bytes.EqualFold(b, []byte("0")) || bytes.EqualFold(b, []byte("q")) { return } choice = true res, err = choose(string(b), choices) return } + +func readLine(scanner *bufio.Scanner, prompt string, required bool) ([]byte, error) { + for { + fmt.Print(prompt) + if !scanner.Scan() { + return nil, scanner.Err() + } + b := bytes.TrimSpace(scanner.Bytes()) + if required && len(b) == 0 { + continue + } + return b, nil + } +} diff --git a/choice/choice_test.go b/choice/choice_test.go index b5fd7a2..e06a51a 100644 --- a/choice/choice_test.go +++ b/choice/choice_test.go @@ -30,7 +30,7 @@ func TestMenu(t *testing.T) { 8. hh 9. ii 10. jj - q. Quit + 0. Quit ` if s := Menu(choices, true); s != expect { t.Errorf("expected %q; got %q", expect, s) @@ -54,3 +54,104 @@ func TestError(t *testing.T) { t.Error("expected err is ErrBadChoice; got not") } } + +func TestChooseWithDefault(t *testing.T) { + tests := []struct { + name string + input string + choices []string + def int + wantChoice bool + wantRes string + wantErr bool + }{ + { + name: "valid choice", + input: "2\n", + choices: []string{"a", "b", "c"}, + def: 0, + wantChoice: true, + wantRes: "b", + wantErr: false, + }, + { + name: "default used when empty input", + input: "\n", + choices: []string{"a", "b", "c"}, + def: 2, + wantChoice: true, + wantRes: "b", + wantErr: false, + }, + { + name: "quit with 0", + input: "0\n", + choices: []string{"a", "b"}, + def: 0, + wantChoice: false, + wantRes: "", + wantErr: false, + }, + { + name: "quit with q", + input: "q\n", + choices: []string{"a", "b"}, + def: 0, + wantChoice: false, + wantRes: "", + wantErr: false, + }, + { + name: "invalid input", + input: "x\n", + choices: []string{"a", "b"}, + def: 0, + wantChoice: true, + wantRes: "", + wantErr: true, + }, + { + name: "out of range", + input: "5\n", + choices: []string{"a", "b"}, + def: 0, + wantChoice: true, + wantRes: "", + wantErr: true, + }, + { + name: "no choices", + input: "1\n", + choices: []string{}, + def: 0, + wantChoice: false, + wantRes: "", + wantErr: true, + }, + { + name: "invalid default", + input: "1\n", + choices: []string{"a"}, + def: 5, + wantChoice: false, + wantRes: "", + wantErr: true, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + r := strings.NewReader(tt.input) + choice, res, err := chooseWithDefault(r, tt.choices, tt.def) + if choice != tt.wantChoice { + t.Errorf("got choice=%v, want %v", choice, tt.wantChoice) + } + if res != tt.wantRes { + t.Errorf("got res=%q, want %q", res, tt.wantRes) + } + if (err != nil) != tt.wantErr { + t.Errorf("got err=%v, wantErr=%v", err, tt.wantErr) + } + }) + } +} diff --git a/clock/clock.go b/clock/clock.go index 72effce..ab5d88f 100644 --- a/clock/clock.go +++ b/clock/clock.go @@ -21,29 +21,32 @@ var ( var w0 unique.Handle[uint64] +// Clock represents a time within a day (hour, minute, second) type Clock struct { wall unique.Handle[uint64] } -func wall[Int int | float64 | uint64](wall Int) unique.Handle[uint64] { - for wall < 0 { - wall += secondsPerDay +func wall(i int64) unique.Handle[uint64] { + if i %= secondsPerDay; i < 0 { + i += secondsPerDay } - return unique.Make(uint64(int64(wall) % secondsPerDay)) + return unique.Make(uint64(i)) } +// New constructs a Clock from hour, minute, second. +// Values can exceed normal ranges; they are normalized modulo 24 hours. func New(hour, min, sec int) Clock { - return Clock{wall(hour*secondsPerHour + min*secondsPerMinute + sec)} + return Clock{wall(int64(hour)*secondsPerHour + int64(min)*secondsPerMinute + int64(sec))} } var clockLayout = []string{ + "3:04PM", // [time.Kitchen] + "3:04:05PM", "15:04", - time.TimeOnly, - time.Kitchen, - "03:04PM", - "03:04:05PM", + "15:04:05", // [time.TimeOnly] } +// Parse parses a string into a Clock using supported layouts func Parse(v string) (Clock, error) { for _, layout := range clockLayout { if t, err := time.Parse(layout, v); err == nil { @@ -53,6 +56,7 @@ func Parse(v string) (Clock, error) { return Clock{}, fmt.Errorf("cannot parse %q as clock", v) } +// MustParse parses a string and panics if parsing fails func MustParse(v string) Clock { if c, err := Parse(v); err != nil { panic("clock: " + err.Error()) @@ -61,18 +65,22 @@ func MustParse(v string) Clock { } } +// ParseTime converts a time.Time into a Clock (hour, minute, second) func ParseTime(t time.Time) Clock { return New(t.Clock()) } +// Now returns the current local time as a Clock func Now() Clock { return ParseTime(time.Now()) } +// Time converts the Clock into a time.Time on the Unix epoch date func (c Clock) Time() time.Time { return time.Unix(int64(c.wall.Value()), 0).UTC() } +// Clock returns hour, minute, second components of the Clock func (c Clock) Clock() (hour, min, sec int) { sec = int(c.wall.Value()) hour = sec / secondsPerHour @@ -82,34 +90,40 @@ func (c Clock) Clock() (hour, min, sec int) { return } -func (c Clock) Seconds() int64 { - return int64(c.wall.Value()) +// Seconds returns the number of seconds +func (c Clock) Seconds() uint64 { + return c.wall.Value() } +// Hour returns the hour of the Clock func (c Clock) Hour() int { return int(c.Seconds()%secondsPerDay) / secondsPerHour } +// Minute returns the minute of the Clock func (c Clock) Minute() int { return int(c.Seconds()%secondsPerHour) / secondsPerMinute } +// Second returns the second of the Clock func (c Clock) Second() int { return int(c.Seconds() % secondsPerMinute) } +// String returns the Clock as a string "H:MM:SS", or "invalid" if zero func (c Clock) String() string { if c.wall == w0 { - return "" + return "invalid" } return fmt.Sprintf("%d:%02d:%02d", c.Hour(), c.Minute(), c.Second()) } -func (c Clock) MarshalText() (text []byte, err error) { - text = []byte(c.String()) - return +// MarshalText implements encoding.TextMarshaler +func (c Clock) MarshalText() ([]byte, error) { + return []byte(c.String()), nil } +// UnmarshalText implements encoding.TextUnmarshaler func (c *Clock) UnmarshalText(text []byte) error { clock, err := Parse(string(text)) if err != nil { @@ -119,52 +133,62 @@ func (c *Clock) UnmarshalText(text []byte) error { return nil } +// IsValid returns true if the Clock is not the zero/invalid value func (c Clock) IsValid() bool { return c.wall != w0 } +// After returns true if c is after u func (c Clock) After(u Clock) bool { return c.wall.Value() > u.wall.Value() } +// Before returns true if c is before u func (c Clock) Before(u Clock) bool { return c.wall.Value() < u.wall.Value() } +// Equal returns true if c equals u func (c Clock) Equal(u Clock) bool { return c.wall == u.wall } +// Compare compares c with u and returns -1,0,1 func (c Clock) Compare(u Clock) int { return cmp.Compare(c.wall.Value(), u.wall.Value()) } +// Add adds a duration to the Clock and returns a new Clock +// Duration is truncated to seconds func (c Clock) Add(d time.Duration) Clock { - return Clock{wall(float64(c.wall.Value()) + d.Seconds())} + return Clock{wall(int64(c.wall.Value()) + int64(d.Seconds()))} } +// Sub returns the duration between c and u (c - u) func (c Clock) Sub(u Clock) time.Duration { - return time.Duration(c.Seconds()-u.Seconds()) * time.Second + return time.Duration(int64(c.Seconds())-int64(u.Seconds())) * time.Second } +// Since returns the duration from u until c (c - u, wrapped around 24h) func (c Clock) Since(u Clock) time.Duration { return u.Until(c) } +// Until returns the duration from c until u (always non-negative, wraps around 24h) func (c Clock) Until(u Clock) time.Duration { - d := u.Seconds() - c.Seconds() - if d == 0 { - return 0 - } else if d < 0 { + d := int64(u.Seconds()) - int64(c.Seconds()) + if d < 0 { d += secondsPerDay } return time.Duration(d) * time.Second } +// Since returns duration from u until now func Since(c Clock) time.Duration { return c.Until(Now()) } +// Until returns duration from now until c func Until(c Clock) time.Duration { return Now().Until(c) } diff --git a/clock/clock_test.go b/clock/clock_test.go index d85b99a..cb1e67e 100644 --- a/clock/clock_test.go +++ b/clock/clock_test.go @@ -61,7 +61,7 @@ func TestParse(t *testing.T) { func TestSeconds(t *testing.T) { for i, testcase := range []struct { c Clock - expected int64 + expected uint64 }{ {New(0, 0, 0), 0}, {New(7, 1, 2), 7*secondsPerHour + 1*secondsPerMinute + 2}, diff --git a/confirm/confirm.go b/confirm/confirm.go index 6e05b20..a870127 100644 --- a/confirm/confirm.go +++ b/confirm/confirm.go @@ -1,12 +1,19 @@ package confirm import ( + "bufio" "fmt" + "io" + "os" "strings" ) // Do asks the user for confirmation. func Do(prompt string, attempts int) bool { + return do(prompt, attempts, os.Stdout, os.Stdin) +} + +func do(prompt string, attempts int, w io.Writer, r io.Reader) bool { if prompt == "" { prompt = "Are you sure?" } @@ -14,23 +21,31 @@ func Do(prompt string, attempts int) bool { attempts = 3 } - fmt.Print(prompt, " (yes/no): ") - var input string + if _, err := fmt.Fprintf(w, "%s (yes/no): ", prompt); err != nil { + fmt.Println("Error writing to output:", err) + return false + } + scanner := bufio.NewScanner(r) for ; attempts > 0; attempts-- { - if _, err := fmt.Scanln(&input); err != nil { - fmt.Println(err) + if !scanner.Scan() { + if err := scanner.Err(); err != nil { + fmt.Fprintln(w, "Error reading input:", err) + } else { + fmt.Fprintln(w, "No input received (EOF).") + } + continue } - switch strings.ToLower(strings.TrimSpace(input)) { + switch strings.ToLower(strings.TrimSpace(scanner.Text())) { case "y", "yes": return true case "n", "no": return false default: if attempts > 1 { - fmt.Print("Please type 'yes' or 'no': ") + fmt.Fprint(w, "Please type 'yes' or 'no': ") } } } - fmt.Println("Max retries exceeded.") + fmt.Fprintln(w, "Max retries exceeded.") return false } diff --git a/confirm/confirm_test.go b/confirm/confirm_test.go new file mode 100644 index 0000000..7365944 --- /dev/null +++ b/confirm/confirm_test.go @@ -0,0 +1,71 @@ +package confirm + +import ( + "bytes" + "strings" + "testing" +) + +func TestDo_YesResponses(t *testing.T) { + tests := []string{"y\n", "Y\n", "yes\n", " YES \n"} + for _, input := range tests { + t.Run(strings.TrimSpace(input), func(t *testing.T) { + in := strings.NewReader(input) + var out bytes.Buffer + got := do("Confirm?", 1, &out, in) + if !got { + t.Errorf("expected true for input %q, got false", input) + } + }) + } +} + +func TestDo_NoResponses(t *testing.T) { + tests := []string{"n\n", "N\n", "no\n", " NO \n"} + for _, input := range tests { + t.Run(strings.TrimSpace(input), func(t *testing.T) { + in := strings.NewReader(input) + var out bytes.Buffer + got := do("Confirm?", 1, &out, in) + if got { + t.Errorf("expected false for input %q, got true", input) + } + }) + } +} + +func TestDo_InvalidThenYes(t *testing.T) { + in := strings.NewReader("maybe\nYES\n") + var out bytes.Buffer + got := do("Confirm?", 2, &out, in) + if !got { + t.Errorf("expected true after invalid input then yes, got false") + } + if !strings.Contains(out.String(), "Please type 'yes' or 'no':") { + t.Errorf("expected retry prompt, got %q", out.String()) + } +} + +func TestDo_MaxRetriesExceeded(t *testing.T) { + in := strings.NewReader("maybe\nidk\nnope\n") + var out bytes.Buffer + got := do("Confirm?", 3, &out, in) + if got { + t.Errorf("expected false after max retries, got true") + } + if !strings.Contains(out.String(), "Max retries exceeded.") { + t.Errorf("expected max retries message, got %q", out.String()) + } +} + +func TestDo_DefaultPromptAndAttempts(t *testing.T) { + in := strings.NewReader("y\n") + var out bytes.Buffer + got := do("", 0, &out, in) + if !got { + t.Errorf("expected true with default prompt, got false") + } + if !strings.Contains(out.String(), "Are you sure? (yes/no):") { + t.Errorf("expected default prompt, got %q", out.String()) + } +} diff --git a/container/list.go b/container/list.go index 7251ca0..e526b8b 100644 --- a/container/list.go +++ b/container/list.go @@ -1,179 +1,374 @@ package container import ( - "container/list" + "cmp" "sync" + "unsafe" ) -// Element is an element of a linked list. +// Element is a thread-safe element of a linked list. type Element[T any] struct { - e *list.Element + mu sync.RWMutex + + // Next and previous pointers in the doubly-linked list of elements. + // To simplify the implementation, internally a list l is implemented + // as a ring, such that &l.root is both the next element of the last + // list element (l.Back()) and the previous element of the first list + // element (l.Front()). + next, prev *Element[T] // The list to which this element belongs. list *List[T] + + // The value stored with this element. + value T +} + +// Set assigns value v to the element and returns it. +func (e *Element[T]) Set(v T) *Element[T] { + e.mu.Lock() + defer e.mu.Unlock() + e.value = v + return e } // Value returns the value stored with this element. func (e *Element[T]) Value() T { - e.list.mu.RLock() - defer e.list.mu.RUnlock() - return e.e.Value.(T) + e.mu.RLock() + defer e.mu.RUnlock() + return e.value +} + +func (e *Element[T]) nextElement() *Element[T] { + if p := e.next; e.list != nil && p != &e.list.root { + return p + } + return nil } // Next returns the next list element or nil. func (e *Element[T]) Next() *Element[T] { - e.list.mu.RLock() - defer e.list.mu.RUnlock() - if next := e.e.Next(); next != nil { - return &Element[T]{next, e.list} + e.mu.RLock() + defer e.mu.RUnlock() + if e.list != nil { + e.list.mu.RLock() + defer e.list.mu.RUnlock() + } + return e.nextElement() +} + +func (e *Element[T]) prevElement() *Element[T] { + if p := e.prev; e.list != nil && p != &e.list.root { + return p } return nil } // Prev returns the previous list element or nil. func (e *Element[T]) Prev() *Element[T] { - e.list.mu.RLock() - defer e.list.mu.RUnlock() - if prev := e.e.Prev(); prev != nil { - return &Element[T]{prev, e.list} + e.mu.RLock() + defer e.mu.RUnlock() + if e.list != nil { + e.list.mu.RLock() + defer e.list.mu.RUnlock() } - return nil + return e.prevElement() } -// List represents a doubly linked list. +// List represents a thread-safe doubly linked list. // The zero value for List is an empty list ready to use. type List[T any] struct { - mu sync.RWMutex - l list.List + mu sync.RWMutex + root Element[T] // sentinel list element, only &root, root.prev, and root.next are used + len int // current list length excluding (this) sentinel element +} + +func (l *List[T]) init() { + l.root.next = &l.root + l.root.prev = &l.root + l.len = 0 } // Init initializes or clears list l. func (l *List[T]) Init() *List[T] { l.mu.Lock() defer l.mu.Unlock() - l.l.Init() + l.init() return l } -// New returns an initialized list. +// NewList returns an initialized list. func NewList[T any]() *List[T] { return new(List[T]).Init() } -// Len returns the number of elements of list l. The complexity is O(1). +// Len returns the number of elements of list l. +// The complexity is O(1). func (l *List[T]) Len() int { l.mu.RLock() defer l.mu.RUnlock() - return l.l.Len() + return l.len +} + +func (l *List[T]) front() *Element[T] { + if l.len == 0 { + return nil + } + return l.root.next } // Front returns the first element of list l or nil if the list is empty. func (l *List[T]) Front() *Element[T] { l.mu.RLock() defer l.mu.RUnlock() - if e := l.l.Front(); e != nil { - return &Element[T]{e, l} + return l.front() +} + +func (l *List[T]) back() *Element[T] { + if l.len == 0 { + return nil } - return nil + return l.root.prev } // Back returns the last element of list l or nil if the list is empty. func (l *List[T]) Back() *Element[T] { l.mu.RLock() defer l.mu.RUnlock() - if e := l.l.Back(); e != nil { - return &Element[T]{e, l} + return l.back() +} + +// lazyInit lazily initializes a zero List value. +func (l *List[T]) lazyInit() { + if l.root.next == nil { + l.init() } - return nil +} + +// insert inserts e after at, increments l.len, and returns e. +func (l *List[T]) insert(e, at *Element[T]) *Element[T] { + e.prev = at + e.next = at.next + e.prev.next = e + e.next.prev = e + e.list = l + l.len++ + return e +} + +// insertValue is a convenience wrapper for insert(&Element{Value: v}, at). +func (l *List[T]) insertValue(v T, at *Element[T]) *Element[T] { + return l.insert(&Element[T]{value: v}, at) +} + +// remove removes e from its list, decrements l.len +func (l *List[T]) remove(e *Element[T]) { + e.prev.next = e.next + e.next.prev = e.prev + e.next = nil // avoid memory leaks + e.prev = nil // avoid memory leaks + e.list = nil + l.len-- +} + +// move moves e to next to at. +func (l *List[T]) move(e, at *Element[T]) { + if e == at { + return + } + e.prev.next = e.next + e.next.prev = e.prev + + e.prev = at + e.next = at.next + e.prev.next = e + e.next.prev = e } // Remove removes e from l if e is an element of list l. // It returns the element value e.Value. // The element must not be nil. func (l *List[T]) Remove(e *Element[T]) T { - l.mu.Lock() - defer l.mu.Unlock() - return l.l.Remove(e.e).(T) + e.mu.Lock() + defer e.mu.Unlock() + if e.list == l { + l.mu.Lock() + defer l.mu.Unlock() + // if e.list == l, l must have been initialized when e was inserted + // in l or l == nil (e is a zero Element) and l.remove will crash + l.remove(e) + } + return e.value } // PushFront inserts a new element e with value v at the front of list l and returns e. func (l *List[T]) PushFront(v T) *Element[T] { l.mu.Lock() defer l.mu.Unlock() - return &Element[T]{l.l.PushFront(v), l} + l.lazyInit() + return l.insertValue(v, &l.root) } // PushBack inserts a new element e with value v at the back of list l and returns e. func (l *List[T]) PushBack(v T) *Element[T] { l.mu.Lock() defer l.mu.Unlock() - return &Element[T]{l.l.PushBack(v), l} + l.lazyInit() + return l.insertValue(v, l.root.prev) } // InsertBefore inserts a new element e with value v immediately before mark and returns e. // If mark is not an element of l, the list is not modified. // The mark must not be nil. func (l *List[T]) InsertBefore(v T, mark *Element[T]) *Element[T] { + mark.mu.RLock() + defer mark.mu.RUnlock() + if mark.list != l { + return nil + } l.mu.Lock() defer l.mu.Unlock() - return &Element[T]{l.l.InsertBefore(v, mark.e), l} + // see comment in List.Remove about initialization of l + return l.insertValue(v, mark.prev) } // InsertAfter inserts a new element e with value v immediately after mark and returns e. // If mark is not an element of l, the list is not modified. // The mark must not be nil. func (l *List[T]) InsertAfter(v T, mark *Element[T]) *Element[T] { + mark.mu.RLock() + defer mark.mu.RUnlock() + if mark.list != l { + return nil + } l.mu.Lock() defer l.mu.Unlock() - return &Element[T]{l.l.InsertAfter(v, mark.e), l} + // see comment in List.Remove about initialization of l + return l.insertValue(v, mark) } // MoveToFront moves element e to the front of list l. // If e is not an element of l, the list is not modified. // The element must not be nil. func (l *List[T]) MoveToFront(e *Element[T]) { + e.mu.RLock() + defer e.mu.RUnlock() + if e.list != l { + return + } l.mu.Lock() defer l.mu.Unlock() - l.l.MoveToFront(e.e) + if l.root.next == e { + return + } + // see comment in List.Remove about initialization of l + l.move(e, &l.root) } // MoveToBack moves element e to the back of list l. // If e is not an element of l, the list is not modified. // The element must not be nil. func (l *List[T]) MoveToBack(e *Element[T]) { + e.mu.RLock() + defer e.mu.RUnlock() + if e.list != l { + return + } l.mu.Lock() defer l.mu.Unlock() - l.l.MoveToBack(e.e) + if l.root.prev == e { + return + } + // see comment in List.Remove about initialization of l + l.move(e, l.root.prev) } // MoveBefore moves element e to its new position before mark. // If e or mark is not an element of l, or e == mark, the list is not modified. // The element and mark must not be nil. func (l *List[T]) MoveBefore(e, mark *Element[T]) { + if e == mark { + return + } + unlock := lock(&e.mu, &mark.mu, true, true) + defer unlock() + if e.list != l || mark.list != l { + return + } l.mu.Lock() defer l.mu.Unlock() - l.l.MoveBefore(e.e, mark.e) + l.move(e, mark.prev) } // MoveAfter moves element e to its new position after mark. // If e or mark is not an element of l, or e == mark, the list is not modified. // The element and mark must not be nil. func (l *List[T]) MoveAfter(e, mark *Element[T]) { + if e == mark { + return + } + unlock := lock(&e.mu, &mark.mu, true, true) + defer unlock() + if e.list != l || mark.list != l { + return + } l.mu.Lock() defer l.mu.Unlock() - l.l.MoveAfter(e.e, mark.e) + l.move(e, mark) } // PushBackList inserts a copy of another list at the back of list l. // The lists l and other may be the same. They must not be nil. func (l *List[T]) PushBackList(other *List[T]) { - l.mu.Lock() - defer l.mu.Unlock() - l.l.PushBackList(&other.l) + unlock := lock(&l.mu, &other.mu, false, true) + defer unlock() + l.lazyInit() + for i, e := other.len, other.front(); i > 0; i, e = i-1, e.nextElement() { + l.insertValue(e.value, l.root.prev) + } } // PushFrontList inserts a copy of another list at the front of list l. // The lists l and other may be the same. They must not be nil. func (l *List[T]) PushFrontList(other *List[T]) { - l.mu.Lock() - defer l.mu.Unlock() - l.l.PushFrontList(&other.l) + unlock := lock(&l.mu, &other.mu, false, true) + defer unlock() + l.lazyInit() + for i, e := other.len, other.back(); i > 0; i, e = i-1, e.prevElement() { + l.insertValue(e.value, &l.root) + } +} + +func lock(s, r *sync.RWMutex, sReadOnly, rReadOnly bool) (unlock func()) { + var sl sync.Locker = s + var rl sync.Locker = r + if sReadOnly { + sl = s.RLocker() + } + if rReadOnly { + rl = r.RLocker() + } + switch cmp.Compare(uintptr(unsafe.Pointer(s)), uintptr(unsafe.Pointer(r))) { + case 0: + if sReadOnly && rReadOnly { + s.RLock() + unlock = s.RUnlock + } else { + s.Lock() + unlock = s.Unlock + } + case 1: + sl.Lock() + rl.Lock() + unlock = func() { + rl.Unlock() + sl.Unlock() + } + case -1: + rl.Lock() + sl.Lock() + unlock = func() { + sl.Unlock() + rl.Unlock() + } + } + return } diff --git a/container/list_test.go b/container/list_test.go index 8b75210..0cc9861 100644 --- a/container/list_test.go +++ b/container/list_test.go @@ -4,10 +4,7 @@ package container -import ( - "container/list" - "testing" -) +import "testing" func checkListLen[T int](t *testing.T, l *List[T], len int) bool { if n := l.Len(); n != len { @@ -140,7 +137,7 @@ func TestInsertBeforeUnknownMark(t *testing.T) { l.PushBack(1) l.PushBack(2) l.PushBack(3) - l.InsertBefore(1, &Element[int]{new(list.Element), nil}) + l.InsertBefore(1, new(Element[int])) checkList(t, &l, []int{1, 2, 3}) } @@ -150,7 +147,7 @@ func TestInsertAfterUnknownMark(t *testing.T) { l.PushBack(1) l.PushBack(2) l.PushBack(3) - l.InsertAfter(1, &Element[int]{new(list.Element), nil}) + l.InsertAfter(1, new(Element[int])) checkList(t, &l, []int{1, 2, 3}) } diff --git a/container/map.go b/container/map.go index b0f09f1..c1165fc 100644 --- a/container/map.go +++ b/container/map.go @@ -2,11 +2,14 @@ package container import "sync" -type Map[Key, Value any] struct { +// Map is a generic concurrency-safe map that wraps sync.Map +// and provides type-safe access for keys and values. +type Map[Key comparable, Value any] struct { m sync.Map } -func NewMap[Key, Value any]() *Map[Key, Value] { +// NewMap creates and returns a new, empty generic concurrency-safe Map. +func NewMap[Key comparable, Value any]() *Map[Key, Value] { return &Map[Key, Value]{} } @@ -15,8 +18,8 @@ func NewMap[Key, Value any]() *Map[Key, Value] { // The ok result indicates whether value was found in the map. func (m *Map[Key, Value]) Load(key Key) (value Value, ok bool) { var v any - if v, ok = m.m.Load(key); ok && v != nil { - value = v.(Value) + if v, ok = m.m.Load(key); ok { + value, _ = v.(Value) } return } @@ -37,9 +40,7 @@ func (m *Map[Key, Value]) Clear() { func (m *Map[Key, Value]) LoadOrStore(key Key, value Value) (actual Value, loaded bool) { var v any if v, loaded = m.m.LoadOrStore(key, value); loaded { - if v != nil { - actual = v.(Value) - } + actual, _ = v.(Value) } else { actual = value } @@ -50,8 +51,8 @@ func (m *Map[Key, Value]) LoadOrStore(key Key, value Value) (actual Value, loade // The loaded result reports whether the key was present. func (m *Map[Key, Value]) LoadAndDelete(key Key) (value Value, loaded bool) { var v any - if v, loaded = m.m.LoadAndDelete(key); loaded && v != nil { - value = v.(Value) + if v, loaded = m.m.LoadAndDelete(key); loaded { + value, _ = v.(Value) } return } @@ -65,8 +66,8 @@ func (m *Map[Key, Value]) Delete(key Key) { // The loaded result reports whether the key was present. func (m *Map[Key, Value]) Swap(key Key, value Value) (previous Value, loaded bool) { var v any - if v, loaded = m.m.Swap(key, value); loaded && v != nil { - previous = v.(Value) + if v, loaded = m.m.Swap(key, value); loaded { + previous, _ = v.(Value) } return } @@ -74,7 +75,7 @@ func (m *Map[Key, Value]) Swap(key Key, value Value) (previous Value, loaded boo // CompareAndSwap swaps the old and new values for key // if the value stored in the map is equal to old. // The old value must be of a comparable type. -func (m *Map[Key, Value]) CompareAndSwap(key Key, old Value, new Value) bool { +func (m *Map[Key, Value]) CompareAndSwap(key Key, old Value, new Value) (swapped bool) { return m.m.CompareAndSwap(key, old, new) } @@ -100,14 +101,11 @@ func (m *Map[Key, Value]) CompareAndDelete(key Key, old Value) (deleted bool) { // false after a constant number of calls. func (m *Map[Key, Value]) Range(f func(Key, Value) bool) { m.m.Range(func(key, value any) bool { - var k Key - var v Value - if key != nil { - k = key.(Key) - } - if value != nil { - v = value.(Value) + if k, ok := key.(Key); ok { + if v, ok := value.(Value); ok { + return f(k, v) + } } - return f(k, v) + return true }) } diff --git a/container/map_test.go b/container/map_test.go new file mode 100644 index 0000000..f86c6fc --- /dev/null +++ b/container/map_test.go @@ -0,0 +1,140 @@ +package container + +import ( + "sync" + "testing" +) + +func TestMap_BasicOperations(t *testing.T) { + m := NewMap[string, int]() + + // Store & Load + m.Store("a", 1) + if v, ok := m.Load("a"); !ok || v != 1 { + t.Fatalf("Load failed, got (%v, %v), want (1, true)", v, ok) + } + + // LoadOrStore (existing key) + actual, loaded := m.LoadOrStore("a", 2) + if !loaded || actual != 1 { + t.Fatalf("LoadOrStore existing key failed, got (%v, %v), want (1, true)", actual, loaded) + } + + // LoadOrStore (new key) + actual, loaded = m.LoadOrStore("b", 3) + if loaded || actual != 3 { + t.Fatalf("LoadOrStore new key failed, got (%v, %v), want (3, false)", actual, loaded) + } + + // LoadAndDelete + v, ok := m.LoadAndDelete("b") + if !ok || v != 3 { + t.Fatalf("LoadAndDelete failed, got (%v, %v), want (3, true)", v, ok) + } + if _, ok := m.Load("b"); ok { + t.Fatalf("Key 'b' should be deleted") + } + + // Swap + prev, loaded := m.Swap("a", 5) + if !loaded || prev != 1 { + t.Fatalf("Swap failed, got (%v, %v), want (1, true)", prev, loaded) + } + v, ok = m.Load("a") + if !ok || v != 5 { + t.Fatalf("Swap did not update value, got (%v, %v), want (5, true)", v, ok) + } + + // CompareAndSwap + if !m.CompareAndSwap("a", 5, 10) { + t.Fatalf("CompareAndSwap should succeed") + } + v, _ = m.Load("a") + if v != 10 { + t.Fatalf("CompareAndSwap failed, value = %v, want 10", v) + } + + if m.CompareAndSwap("a", 5, 20) { + t.Fatalf("CompareAndSwap should fail when old != current") + } + + // CompareAndDelete + m.Store("x", 42) + if !m.CompareAndDelete("x", 42) { + t.Fatalf("CompareAndDelete should succeed") + } + if _, ok := m.Load("x"); ok { + t.Fatalf("CompareAndDelete did not delete key") + } + + m.Store("y", 99) + if m.CompareAndDelete("y", 100) { + t.Fatalf("CompareAndDelete should fail when old != current") + } +} + +func TestMap_Range(t *testing.T) { + m := NewMap[string, int]() + m.Store("a", 1) + m.Store("b", 2) + m.Store("c", 3) + + collected := make(map[string]int) + m.Range(func(k string, v int) bool { + collected[k] = v + return true + }) + + if len(collected) != 3 { + t.Fatalf("Range failed, got %d elements, want 3", len(collected)) + } +} + +func TestMap_Clear(t *testing.T) { + m := NewMap[string, int]() + m.Store("a", 1) + m.Store("b", 2) + m.Clear() + + count := 0 + m.Range(func(k string, v int) bool { + count++ + return true + }) + + if count != 0 { + t.Fatalf("Clear failed, map still has %d elements", count) + } +} + +func TestMap_ConcurrentAccess(t *testing.T) { + m := NewMap[int, int]() + wg := sync.WaitGroup{} + + // Concurrent store + for i := range 100 { + wg.Add(1) + go func(i int) { + defer wg.Done() + m.Store(i, i*i) + }(i) + } + + wg.Wait() + + // Concurrent load + wg = sync.WaitGroup{} + for i := range 100 { + wg.Add(1) + go func(i int) { + defer wg.Done() + v, ok := m.Load(i) + if !ok { + t.Errorf("Missing key %d", i) + } else if v != i*i { + t.Errorf("Unexpected value for %d: got %d, want %d", i, v, i*i) + } + }(i) + } + wg.Wait() +} diff --git a/container/ring.go b/container/ring.go index 4b29b40..b91e6f9 100644 --- a/container/ring.go +++ b/container/ring.go @@ -1,80 +1,66 @@ package container import ( - "container/ring" "sync" + "unsafe" ) -var ringMutex sync.RWMutex - -type mutex ring.Ring - -func newMutex() *mutex { - r := ring.New(1) - r.Value = new(sync.RWMutex) - return (*mutex)(r) -} - -func (mu *mutex) Lock() { - ringMutex.RLock() - defer ringMutex.RUnlock() - (*ring.Ring)(mu).Do(func(a any) { - a.(*sync.RWMutex).Lock() - }) -} - -func (mu *mutex) Unlock() { - ringMutex.RLock() - defer ringMutex.RUnlock() - (*ring.Ring)(mu).Do(func(a any) { - a.(*sync.RWMutex).Unlock() - }) -} - -func (mu *mutex) RLock() { - ringMutex.RLock() - defer ringMutex.RUnlock() - (*ring.Ring)(mu).Do(func(a any) { - a.(*sync.RWMutex).RLock() - }) -} - -func (mu *mutex) RUnlock() { - ringMutex.RLock() - defer ringMutex.RUnlock() - (*ring.Ring)(mu).Do(func(a any) { - a.(*sync.RWMutex).RUnlock() - }) -} - -func (mu *mutex) Link(s *mutex) *mutex { - ringMutex.Lock() - defer ringMutex.Unlock() - return (*mutex)((*ring.Ring)(mu).Link((*ring.Ring)(s))) -} - -// A Ring is an element of a circular list, or ring. +// A Ring is an element of a thread-safe, generic circular list, or ring. // Rings do not have a beginning or end; a pointer to any ring element // serves as reference to the entire ring. Empty rings are represented -// as nil Ring pointers. The zero value for a Ring is a one-element -// ring with a nil Value. +// as nil Ring pointers. The zero value for a Ring[T] is a one-element +// ring with a zero Value of type T. type Ring[T any] struct { - mu *mutex - r *ring.Ring + mu sync.RWMutex + ringMu *sync.RWMutex + + next, prev *Ring[T] + value T +} + +func (r *Ring[T]) init() *Ring[T] { + r.ringMu = new(sync.RWMutex) + r.next = r + r.prev = r + return r } // Next returns the next ring element. r must not be empty. func (r *Ring[T]) Next() *Ring[T] { r.mu.Lock() defer r.mu.Unlock() - return &Ring[T]{r.mu, r.r.Next()} + if r.ringMu == nil { + return r.init() + } + r.ringMu.RLock() + defer r.ringMu.RUnlock() + return r.next } // Prev returns the previous ring element. r must not be empty. func (r *Ring[T]) Prev() *Ring[T] { r.mu.Lock() defer r.mu.Unlock() - return &Ring[T]{r.mu, r.r.Prev()} + if r.ringMu == nil { + return r.init() + } + r.ringMu.RLock() + defer r.ringMu.RUnlock() + return r.prev +} + +func (r *Ring[T]) move(n int) *Ring[T] { + switch { + case n < 0: + for ; n < 0; n++ { + r = r.prev + } + case n > 0: + for ; n > 0; n-- { + r = r.next + } + } + return r } // Move moves n % r.Len() elements backward (n < 0) or forward (n >= 0) @@ -82,7 +68,12 @@ func (r *Ring[T]) Prev() *Ring[T] { func (r *Ring[T]) Move(n int) *Ring[T] { r.mu.Lock() defer r.mu.Unlock() - return &Ring[T]{newMutex(), r.r.Move(n)} + if r.ringMu == nil { + return r.init() + } + r.ringMu.RLock() + defer r.ringMu.RUnlock() + return r.move(n) } // NewRing creates a ring of n elements. @@ -90,20 +81,103 @@ func NewRing[T any](n int) *Ring[T] { if n <= 0 { return nil } - return &Ring[T]{newMutex(), ring.New(n)} + r := &Ring[T]{ringMu: new(sync.RWMutex)} + p := r + for i := 1; i < n; i++ { + p.next = &Ring[T]{ringMu: p.ringMu, prev: p} + p = p.next + } + p.next = r + r.prev = p + return r } -func (r *Ring[T]) Set(v T) { - r.mu.Lock() - defer r.mu.Unlock() - r.r.Value = &v +func linkLock[T any](r, s *Ring[T]) (unlock func()) { + rmu, smu := r.ringMu, s.ringMu + if s == r { + r.mu.Lock() + rmu.Lock() + unlock = func() { + rmu.Unlock() + r.mu.Unlock() + } + } else { + order := uintptr(unsafe.Pointer(&r.mu)) < uintptr(unsafe.Pointer(&s.mu)) + var finalUnlock func() + if order { + r.mu.Lock() + s.mu.Lock() + finalUnlock = func() { + s.mu.Unlock() + r.mu.Unlock() + } + } else { + s.mu.Lock() + r.mu.Lock() + finalUnlock = func() { + r.mu.Unlock() + s.mu.Unlock() + } + } + switch smu { + case nil: + s.init() + smu = rmu + fallthrough + case rmu: + smu.Lock() + unlock = func() { + smu.Unlock() + finalUnlock() + } + default: + if order { + rmu.Lock() + smu.Lock() + unlock = func() { + smu.Unlock() + rmu.Unlock() + finalUnlock() + } + } else { + smu.Lock() + rmu.Lock() + unlock = func() { + rmu.Unlock() + smu.Unlock() + finalUnlock() + } + } + } + } + return } -func (r *Ring[T]) Value() *T { - r.mu.RLock() - defer r.mu.RUnlock() - v, _ := r.r.Value.(*T) - return v +func (r *Ring[T]) link(s *Ring[T]) *Ring[T] { + var sameRing bool + if s.ringMu == r.ringMu { + sameRing = true + } else { + s.ringMu = r.ringMu + for p := s.next; p != s; p = p.next { + p.ringMu = r.ringMu + } + } + n := r.next + p := s.prev + // Note: Cannot use multiple assignment because + // evaluation order of LHS is not specified. + r.next = s + s.prev = r + n.prev = p + p.next = n + if sameRing { + n.ringMu = new(sync.RWMutex) + for p := n.next; p != n; p = p.next { + p.ringMu = n.ringMu + } + } + return n } // Link connects ring r with ring s such that r.Next() @@ -122,52 +196,93 @@ func (r *Ring[T]) Value() *T { // after r. The result points to the element following the // last element of s after insertion. func (r *Ring[T]) Link(s *Ring[T]) *Ring[T] { + r.mu.Lock() + if r.ringMu == nil { + r.init() + } + n := r.next + r.mu.Unlock() if s == nil { - return r.Next() + return n } - m := r.mu.Link(s.mu) - r.mu.Lock() - defer r.mu.Unlock() - return &Ring[T]{m, r.r.Link(s.r)} + unlock := linkLock(r, s) + defer unlock() + return r.link(s) } // Unlink removes n % r.Len() elements from the ring r, starting // at r.Next(). If n % r.Len() == 0, r remains unchanged. // The result is the removed subring. r must not be empty. func (r *Ring[T]) Unlink(n int) *Ring[T] { + if n <= 0 { + return nil + } r.mu.Lock() defer r.mu.Unlock() - u := r.r.Unlink(n) - if u == nil { - return nil + if r.ringMu == nil { + return r.init() } - return &Ring[T]{newMutex(), u} + r.ringMu.Lock() + defer r.ringMu.Unlock() + return r.link(r.move(n + 1)) } // Len computes the number of elements in ring r. // It executes in time proportional to the number of elements. func (r *Ring[T]) Len() int { - if r == nil { - return 0 + n := 0 + if r != nil { + r.mu.Lock() + defer r.mu.Unlock() + if r.ringMu == nil { + r.init() + return 1 + } + r.ringMu.RLock() + defer r.ringMu.RUnlock() + n = 1 + for p := r.next; p != r; p = p.next { + n++ + } } - r.mu.RLock() - defer r.mu.RUnlock() - return r.r.Len() + return n } // Do calls function f on each element of the ring, in forward order. // The behavior of Do is undefined if f changes *r. -func (r *Ring[T]) Do(f func(*T)) { - if r == nil { - return +func (r *Ring[T]) Do(f func(T)) { + if r != nil { + r.mu.RLock() + if r.ringMu == nil { + r.mu.RUnlock() + return + } + r.ringMu.RLock() + defer r.ringMu.RUnlock() + f(r.value) + r.mu.RUnlock() + for p := r.next; p != r; p = p.next { + p.mu.RLock() + f(p.value) + p.mu.RUnlock() + } } +} + +// Set assigns value v to the ring element and returns it. +func (r *Ring[T]) Set(v T) *Ring[T] { + r.mu.Lock() + defer r.mu.Unlock() + if r.ringMu == nil { + r.init() + } + r.value = v + return r +} + +// Value returns the value of the ring element. +func (r *Ring[T]) Value() T { r.mu.RLock() defer r.mu.RUnlock() - r.r.Do(func(a any) { - if v, ok := a.(*T); ok { - f(v) - } else { - f(nil) - } - }) + return r.value } diff --git a/container/ring_test.go b/container/ring_test.go index 5d9217c..e92b126 100644 --- a/container/ring_test.go +++ b/container/ring_test.go @@ -5,13 +5,13 @@ package container import ( - "container/ring" "fmt" + "sync" "testing" ) // For debugging - keep around. -func dump(r *ring.Ring) { +func dump[T any](r *Ring[T]) { if r == nil { fmt.Println("empty") return @@ -24,7 +24,7 @@ func dump(r *ring.Ring) { fmt.Println() } -func verify[T int](t *testing.T, r *Ring[T], N int, sum int) { +func verify(t *testing.T, r *Ring[int], N int, sum int) { // Len n := r.Len() if n != N { @@ -34,11 +34,9 @@ func verify[T int](t *testing.T, r *Ring[T], N int, sum int) { // iteration n = 0 s := 0 - r.Do(func(p *T) { + r.Do(func(p int) { n++ - if p != nil { - s += int(*p) - } + s += p }) if n != N { t.Errorf("number of forward iterations == %d; expected %d", n, N) @@ -47,41 +45,41 @@ func verify[T int](t *testing.T, r *Ring[T], N int, sum int) { t.Errorf("forward ring sum = %d; expected %d", s, sum) } - if r == nil || r.r == nil { + if r == nil { return } // connections - if r.Next().r != nil { - var p *ring.Ring // previous element - for q := r.r; p == nil || q != r.r; q = q.Next() { + if r.Next() != nil { + var p *Ring[int] // previous element + for q := r; p == nil || q != r; q = q.next { if p != nil && p != q.Prev() { t.Errorf("prev = %p, expected q.prev = %p\n", p, q.Prev()) } p = q } - if p != r.Prev().r { + if p != r.Prev() { t.Errorf("prev = %p, expected r.prev = %p\n", p, r.Prev()) } } // Move - if r.Move(0).r != r.r { + if r.Move(0) != r { t.Errorf("r.Move(0) != r") } - if r.Move(N).r != r.r { + if r.Move(N) != r { t.Errorf("r.Move(%d) != r", N) } - if r.Move(-N).r != r.r { + if r.Move(-N) != r { t.Errorf("r.Move(%d) != r", -N) } for i := 0; i < 10; i++ { ni := N + i mi := ni % N - if r.Move(ni).r != r.Move(mi).r { + if r.Move(ni) != r.Move(mi) { t.Errorf("r.Move(%d) != r.Move(%d)", ni, mi) } - if r.Move(-ni).r != r.Move(-mi).r { + if r.Move(-ni) != r.Move(-mi) { t.Errorf("r.Move(%d) != r.Move(%d)", -ni, -mi) } } @@ -89,8 +87,8 @@ func verify[T int](t *testing.T, r *Ring[T], N int, sum int) { func TestCornerCases(t *testing.T) { var ( - r0 = &Ring[int]{newMutex(), nil} - r1 = Ring[int]{newMutex(), new(ring.Ring)} + r0 *Ring[int] + r1 Ring[int] ) // Basics verify(t, r0, 0, 0) @@ -132,16 +130,16 @@ func TestNew(t *testing.T) { func TestLink1(t *testing.T) { r1a := makeN(1) - var r1b = Ring[int]{newMutex(), &ring.Ring{}} + var r1b Ring[int] r2a := r1a.Link(&r1b) verify(t, r2a, 2, 1) - if r2a.r != r1a.r { + if r2a != r1a { t.Errorf("a) 2-element link failed") } r2b := r2a.Link(r2a.Next()) verify(t, r2b, 2, 1) - if r2b.r != r2a.Next().r { + if r2b != r2a.Next() { t.Errorf("b) 2-element link failed") } @@ -151,7 +149,7 @@ func TestLink1(t *testing.T) { } func TestLink2(t *testing.T) { - var r0 = &Ring[int]{newMutex(), nil} + var r0 *Ring[int] r1a := NewRing[int](1) r1a.Set(42) r1b := NewRing[int](1) @@ -172,7 +170,7 @@ func TestLink2(t *testing.T) { } func TestLink3(t *testing.T) { - var r = Ring[int]{newMutex(), new(ring.Ring)} + var r Ring[int] n := 1 for i := 1; i < 10; i++ { n += i @@ -216,8 +214,119 @@ func TestLinkUnlink(t *testing.T) { // Test that calling Move() on an empty Ring initializes it. func TestMoveEmptyRing(t *testing.T) { - var r = Ring[int]{newMutex(), &ring.Ring{}} + var r Ring[int] r.Move(1) verify(t, &r, 1, 0) } + +// TestLinkSharedRing tests the Link function to ensure the shared ringMu is correctly set. +func TestLinkSharedRing(t *testing.T) { + // Helper function to check if all elements in a ring share the same ringMu. + checkRingMu := func(t *testing.T, r *Ring[int], expectedMu *sync.RWMutex, name string) { + if r == nil { + t.Errorf("%s: ring is nil", name) + return + } + seen := make(map[*Ring[int]]bool) + current := r + for { + if current.ringMu != expectedMu { + t.Errorf("%s: element %p has ringMu %p, expected %p", name, current, current.ringMu, expectedMu) + } + seen[current] = true + current = current.Next() + if current == r { + break + } + if seen[current] { + t.Errorf("%s: cycle detected before reaching start", name) + break + } + } + } + + // Test 1: Link two distinct rings. + t.Run("LinkDistinctRings", func(t *testing.T) { + r1 := NewRing[int](3) + r2 := NewRing[int](2) + originalR1Mu := r1.ringMu + originalR2Mu := r2.ringMu + + // Link r1 and r2. + r1.Link(r2) + + if originalR1Mu == originalR2Mu { + t.Fatalf("initial rings have same ringMu %p", originalR1Mu) + } + + // Check that all elements in the combined ring share r1's ringMu. + checkRingMu(t, r1, originalR1Mu, "combined ring") + + // Verify ring length. + if r1.Len() != 5 { + t.Errorf("combined ring length is %d, expected 5", r1.Len()) + } + }) + + // Test 2: Link within the same ring. + t.Run("LinkSameRing", func(t *testing.T) { + r := NewRing[int](5) + originalMu := r.ringMu + + // Move to the third element. + r3 := r.Move(2) + + // Link r to r3, splitting the ring. + result := r.Link(r3) + + // Check that the original ring (r) has 2 elements and retains original ringMu. + checkRingMu(t, r, originalMu, "original ring") + + // Check that the result ring has 3 elements and a new ringMu. + if result.ringMu == originalMu { + t.Errorf("result ring has same ringMu %p as original %p", result.ringMu, originalMu) + } + + checkRingMu(t, result, result.ringMu, "result ring") + }) + + // Test 3: Link a ring to itself. + t.Run("LinkSelf", func(t *testing.T) { + r := NewRing[int](5) + originalMu := r.ringMu + + result := r.Link(r) + + // Check that the ring remains unchanged and retains original ringMu. + checkRingMu(t, r, originalMu, "self-linked ring") + if r.Len() != 1 { + t.Errorf("self-linked ring length is %d, expected 1", r.Len()) + } + + // Check that the result ring has 3 elements and a new ringMu. + if result.ringMu == originalMu { + t.Errorf("result ring has same ringMu %p as original %p", result.ringMu, originalMu) + } + + checkRingMu(t, result, result.ringMu, "result ring") + }) + + // Test 4: Link with nil. + t.Run("LinkNil", func(t *testing.T) { + r := NewRing[int](3) + originalMu := r.ringMu + originalNext := r.Next() + + result := r.Link(nil) + + // Check that the ring is unchanged and retains original ringMu. + checkRingMu(t, r, originalMu, "ring after linking nil") + if r.Len() != 3 { + t.Errorf("ring length is %d, expected 3", r.Len()) + } + if result != originalNext { + t.Errorf("Link result is %p, expected %p", result, originalNext) + } + }) +} diff --git a/container/value.go b/container/value.go index dd84c5c..d6b21aa 100644 --- a/container/value.go +++ b/container/value.go @@ -15,52 +15,31 @@ func NewValue[T any]() *Value[T] { return &Value[T]{} } -// Load returns the value set by the most recent Store and a boolean -// indicating whether a value was stored. -// If there has been no call to Store for this Value, it returns the -// zero value of T and false. -func (v *Value[T]) Load() (val T, stored bool) { - if v := v.v.Load(); v == nil { - return - } else { - return v.(T), true +// Load returns the value set by the most recent [Value.Store]. +// If there has been no call to [Value.Store] for this Value, it returns the +// zero value of T. +func (v *Value[T]) Load() (val T) { + if loaded := v.v.Load(); loaded != nil { + return loaded.(T) } -} - -// MustLoad returns the value set by the most recent Store. -// It panics if there has been no call to Store for this Value. -func (v *Value[T]) MustLoad() (val T) { - if v, stored := v.Load(); stored { - return v - } - panic("cache/value: there has been no call to Store for this Value") + return } // Store sets the value of the [Value] v to val. func (v *Value[T]) Store(val T) { - if any(val) == nil { - panic("cache/value: store of nil value into Value") - } v.v.Store(val) } // Swap stores the new value into the Value and returns the previous value. -// If no value was previously stored, it returns the zero value of T and false. -func (v *Value[T]) Swap(new T) (old T, stored bool) { - if any(new) == nil { - panic("cache/value: swap of nil value into Value") - } - if v := v.v.Swap(new); v == nil { - return - } else { - return v.(T), true +// If no value was previously stored, it returns the zero value of T. +func (v *Value[T]) Swap(new T) (old T) { + if prev := v.v.Swap(new); prev != nil { + return prev.(T) } + return } // CompareAndSwap executes the compare-and-swap operation for the [Value]. func (v *Value[T]) CompareAndSwap(old, new T) (swapped bool) { - if any(new) == nil { - panic("cache/value: compare and swap of nil value into Value") - } return v.v.CompareAndSwap(old, new) } diff --git a/container/value_test.go b/container/value_test.go index 119129a..4433670 100644 --- a/container/value_test.go +++ b/container/value_test.go @@ -13,15 +13,12 @@ import ( func TestValue(t *testing.T) { v := NewValue[int]() - if _, ok := v.Load(); ok { - t.Fatal("initial Value is not nil") - } v.Store(42) - if i, ok := v.Load(); !ok || i != 42 { + if i := v.Load(); i != 42 { t.Fatalf("wrong value: got %d, want 42", i) } v.Store(84) - if i, ok := v.Load(); !ok || i != 84 { + if i := v.Load(); i != 84 { t.Fatalf("wrong value: got %d, want 84", i) } } @@ -29,55 +26,22 @@ func TestValue(t *testing.T) { func TestValueLarge(t *testing.T) { v := NewValue[string]() v.Store("foo") - if s, ok := v.Load(); !ok || s != "foo" { + if s := v.Load(); s != "foo" { t.Fatalf("wrong value: got %s, want foo", s) } v.Store("barbaz") - if s, ok := v.Load(); !ok || s != "barbaz" { + if s := v.Load(); s != "barbaz" { t.Fatalf("wrong value: got %s, want barbaz", s) } } -func TestValuePanic(t *testing.T) { - const nilErr = "cache/value: store of nil value into Value" - v := NewValue[any]() - func() { - defer func() { - err := recover() - if err != nilErr { - t.Fatalf("inconsistent store panic: got '%v', want '%v'", err, nilErr) - } - }() - v.Store(nil) - }() - v.Store(1) - func() { - defer func() { - err := recover() - if err != nilErr { - t.Fatalf("inconsistent store panic: got '%v', want '%v'", err, nilErr) - } - }() - v.Store(nil) - }() -} - func TestPointType(t *testing.T) { v := NewValue[*int]() - if v, stored := v.Load(); stored { - t.Fatal("wrong stored status") - } else if v != nil { - t.Fatal("initial Value is not nil") - } v.Store((*int)(nil)) - if v, stored := v.Load(); !stored { - t.Fatal("wrong stored status") - } else if v != nil { + if v := v.Load(); v != nil { t.Fatalf("wrong value: got %v, want nil", v) } - if old, stored := v.Swap(utils.Ptr(1)); !stored { - t.Fatal("wrong stored status") - } else if old != nil { + if old := v.Swap(utils.Ptr(1)); old != nil { t.Fatalf("wrong value: got %v, want nil", v) } } @@ -105,7 +69,7 @@ func TestValueConcurrent(t *testing.T) { for j := 0; j < N; j++ { x := test[rand.IntN(len(test))] v.Store(x) - x = v.MustLoad() + x = v.Load() for _, x1 := range test { if x == x1 { continue loop @@ -131,7 +95,7 @@ func BenchmarkValueRead(b *testing.B) { v.Store(new(int)) b.RunParallel(func(pb *testing.PB) { for pb.Next() { - x := v.MustLoad() + x := v.Load() if *x != 0 { b.Fatalf("wrong value: got %v, want 0", *x) } @@ -145,7 +109,7 @@ var Value_SwapTests = []struct { want any err any }{ - {init: nil, new: nil, err: "cache/value: swap of nil value into Value"}, + {init: nil, new: nil, err: "sync/atomic: swap of nil value into Value"}, {init: nil, new: true, want: nil, err: nil}, {init: true, new: "", err: "sync/atomic: swap of inconsistently typed value into Value"}, {init: true, new: false, want: true, err: nil}, @@ -167,10 +131,10 @@ func TestValue_Swap(t *testing.T) { t.Errorf("should panic %v, got ", tt.err) } }() - if got, _ := v.Swap(tt.new); got != tt.want { + if got := v.Swap(tt.new); got != tt.want { t.Errorf("got %v, want %v", got, tt.want) } - if got, stored := v.Load(); !stored || got != tt.new { + if got := v.Load(); got != tt.new { t.Errorf("got %v, want %v", got, tt.new) } }) @@ -192,16 +156,14 @@ func TestValueSwapConcurrent(t *testing.T) { go func() { var c uint64 for new := i; new < i+n; new++ { - if old, stored := v.Swap(new); stored { - c += old - } + c += v.Swap(new) } atomic.AddUint64(&count, c) g.Done() }() } g.Wait() - if want, got := (m*n-1)*(m*n)/2, count+v.MustLoad(); got != want { + if want, got := (m*n-1)*(m*n)/2, count+v.Load(); got != want { t.Errorf("sum from 0 to %d was %d, want %v", m*n-1, got, want) } } @@ -270,7 +232,7 @@ func TestValueCompareAndSwapConcurrent(t *testing.T) { }() } w.Wait() - if stop := v.MustLoad(); stop != m*n { + if stop := v.Load(); stop != m*n { t.Errorf("did not get to %v, stopped at %v", m*n, stop) } } diff --git a/counter/counter.go b/counter/counter.go index fd87779..904147d 100644 --- a/counter/counter.go +++ b/counter/counter.go @@ -5,20 +5,87 @@ import ( "sync/atomic" ) -type Counter atomic.Int64 +// Counter is a thread-safe utility for counting values, starting from zero. +type Counter struct { + n atomic.Int64 +} + +// Add adds delta to the [Counter] and returns the new value. +func (c *Counter) Add(delta int64) int64 { + return c.n.Add(delta) +} + +// Get returns the current value of the [Counter]. +func (c *Counter) Get() int64 { + return c.n.Load() +} + +// CounterReader wraps an io.Reader to count the number of bytes read. +type CounterReader struct { + r io.Reader // Underlying reader + c *Counter // Counter for bytes read +} + +// CountReader creates an io.Reader that counts bytes read from r, using the provided [Counter]. +func CountReader(r io.Reader, c *Counter) io.Reader { + return NewCounterReader(r, c) +} + +// NewCounterReader creates a [CounterReader] that counts bytes read from r, using the provided [Counter]. +// If c is nil, a new Counter is created. +func NewCounterReader(r io.Reader, c *Counter) *CounterReader { + if c == nil { + c = new(Counter) + } + return &CounterReader{r, c} +} + +// Read reads from the underlying Reader and increments the counter by the number of bytes read. +// It returns the number of bytes read and any error encountered. +func (r *CounterReader) Read(p []byte) (n int, err error) { + n, err = r.r.Read(p) + if n > 0 { + r.c.Add(int64(n)) + } + return +} + +// Bytes returns the total number of bytes read. +func (r *CounterReader) Bytes() int64 { + return r.c.Get() +} + +// CounterWriter wraps an io.Writer to count the number of bytes written. +type CounterWriter struct { + w io.Writer // Underlying writer + c *Counter // Counter for bytes written +} -func (c *Counter) Add(delta int64) (new int64) { - return (*atomic.Int64)(c).Add(delta) +// CountWriter creates an io.Writer that counts bytes written to w, using the provided [Counter]. +func CountWriter(w io.Writer, c *Counter) io.Writer { + return NewCounterWriter(w, c) } -func (c *Counter) Load() int64 { - return (*atomic.Int64)(c).Load() +// NewCounterWriter creates a [CounterWriter] that counts bytes written to w, using the provided [Counter]. +// If c is nil, a new Counter is created. +func NewCounterWriter(w io.Writer, c *Counter) *CounterWriter { + if c == nil { + c = new(Counter) + } + return &CounterWriter{w, c} } -func (c *Counter) AddWriter(w io.Writer) io.Writer { - return newWriter(c, w) +// Write writes to the underlying Writer and increments the counter by the number of bytes written. +// It returns the number of bytes written and any error encountered. +func (w *CounterWriter) Write(p []byte) (n int, err error) { + n, err = w.w.Write(p) + if n > 0 { + w.c.Add(int64(n)) + } + return } -func (c *Counter) AddReader(r io.Reader) io.Reader { - return newReader(c, r) +// Bytes returns the total number of bytes written. +func (w *CounterWriter) Bytes() int64 { + return w.c.Get() } diff --git a/counter/counter_test.go b/counter/counter_test.go index 3e37490..e414230 100644 --- a/counter/counter_test.go +++ b/counter/counter_test.go @@ -14,19 +14,19 @@ func TestReader(t *testing.T) { c, buf := new(Counter), new(bytes.Buffer) buf.Write(data1) buf.Write(data2) - r := c.AddReader(buf) + r := CountReader(buf, c) io.ReadAll(r) - if count := c.Load(); count != dataLen { - t.Fatalf("expected %d; got %d", dataLen, count) + if n := c.Get(); n != dataLen { + t.Fatalf("expected %d; got %d", dataLen, n) } } func TestWriter(t *testing.T) { c, buf := new(Counter), new(bytes.Buffer) - w := c.AddWriter(buf) + w := CountWriter(buf, c) w.Write(data1) w.Write(data2) - if count := c.Load(); count != dataLen { - t.Fatalf("expected %d; got %d", dataLen, count) + if n := c.Get(); n != dataLen { + t.Fatalf("expected %d; got %d", dataLen, n) } } diff --git a/counter/listener.go b/counter/listener.go index eb08185..11511a2 100644 --- a/counter/listener.go +++ b/counter/listener.go @@ -1,55 +1,66 @@ package counter -import "net" +import ( + "io" + "net" +) var ( _ net.Listener = &Listener{} _ net.Conn = &conn{} ) +// Listener wraps a net.Listener to count bytes read and written across all connections. type Listener struct { - listener net.Listener - read Counter - written Counter + net.Listener + readBytes Counter // Counter for bytes read across all connections + writeBytes Counter // Counter for bytes written across all connections } +// NewListener creates a Listener that counts bytes read and written across all connections. func NewListener(listener net.Listener) *Listener { - return &Listener{listener: listener} + return &Listener{Listener: listener} } +// Accept accepts a connection and wraps it with byte counting for reads and writes. +// It returns the wrapped connection or an error if the accept fails. func (l *Listener) Accept() (net.Conn, error) { - c, err := l.listener.Accept() + c, err := l.Listener.Accept() if err != nil { return nil, err } - return &conn{c, l}, nil -} - -func (l *Listener) Close() error { - return l.listener.Close() -} - -func (l *Listener) Addr() net.Addr { - return l.listener.Addr() + return &conn{ + Conn: c, + r: CountReader(c, &l.readBytes), + w: CountWriter(c, &l.writeBytes), + }, nil } -func (l *Listener) ReadCount() int64 { - return l.read.Load() +// ReadBytes returns the total number of bytes read across all connections. +func (l *Listener) ReadBytes() int64 { + return l.readBytes.Get() } -func (l *Listener) WriteCount() int64 { - return l.written.Load() +// WriteBytes returns the total number of bytes written across all connections. +func (l *Listener) WriteBytes() int64 { + return l.writeBytes.Get() } +// conn wraps a net.Conn to count bytes read and written. type conn struct { net.Conn - listener *Listener + r io.Reader // Reader that counts bytes read + w io.Writer // Writer that counts bytes written } -func (conn *conn) Write(b []byte) (n int, err error) { - return conn.listener.written.AddWriter(conn.Conn).Write(b) +// Read reads from the underlying Reader and counts the bytes read. +// It returns the number of bytes read and any error encountered. +func (c *conn) Read(b []byte) (int, error) { + return c.r.Read(b) } -func (conn *conn) Read(b []byte) (n int, err error) { - return conn.listener.read.AddReader(conn.Conn).Read(b) +// Write writes to the underlying Writer and counts the bytes written. +// It returns the number of bytes written and any error encountered. +func (c *conn) Write(b []byte) (int, error) { + return c.w.Write(b) } diff --git a/counter/listener_test.go b/counter/listener_test.go index 4584b1f..1b0ad5a 100644 --- a/counter/listener_test.go +++ b/counter/listener_test.go @@ -50,7 +50,7 @@ func TestListener(t *testing.T) { } conn.Close() - if count := l.ReadCount(); count != dataLen { - t.Fatalf("expected %d; got %d", dataLen, count) + if n := l.ReadBytes(); n != dataLen { + t.Fatalf("expected %d; got %d", dataLen, n) } } diff --git a/counter/rw.go b/counter/rw.go deleted file mode 100644 index 484a2dd..0000000 --- a/counter/rw.go +++ /dev/null @@ -1,75 +0,0 @@ -package counter - -import ( - "io" -) - -type reader struct { - *Counter - r io.Reader -} - -func newReader(n *Counter, r io.Reader) io.Reader { - reader := &reader{n, r} - if _, ok := r.(io.WriterTo); ok { - return readerWriterTo{reader} - } - return reader -} - -func (r *reader) Read(p []byte) (n int, err error) { - n, err = r.r.Read(p) - if err != nil { - return - } - r.Add(int64(n)) - return -} - -type readerWriterTo struct { - *reader -} - -func (r *readerWriterTo) WriteTo(w io.Writer) (n int64, err error) { - n, err = r.r.(io.WriterTo).WriteTo(w) - if err != nil { - return - } - r.Add(int64(n)) - return -} - -type writer struct { - *Counter - w io.Writer -} - -func newWriter(n *Counter, w io.Writer) io.Writer { - writer := &writer{n, w} - if _, ok := w.(io.ReaderFrom); ok { - return writerReaderFrom{writer} - } - return writer -} - -func (w *writer) Write(p []byte) (n int, err error) { - n, err = w.w.Write(p) - if err != nil { - return - } - w.Add(int64(n)) - return -} - -type writerReaderFrom struct { - *writer -} - -func (w *writerReaderFrom) ReadFrom(r io.Reader) (n int64, err error) { - n, err = w.w.(io.ReaderFrom).ReadFrom(r) - if err != nil { - return - } - w.Add(int64(n)) - return -} diff --git a/csv/convert.go b/csv/convert.go index 5b7d622..e5173cc 100644 --- a/csv/convert.go +++ b/csv/convert.go @@ -7,6 +7,8 @@ import ( "fmt" "reflect" "strconv" + + "github.com/sunshineplan/utils/container" ) var ( @@ -40,56 +42,88 @@ func strconvErr(err error) error { return err } +var convertMap container.Map[reflect.Type, func(reflect.Type, string) (reflect.Value, error)] + func convert(t reflect.Type, s string) (reflect.Value, error) { - if t.Kind() == reflect.Ptr { + if t.Kind() == reflect.Pointer || t.Kind() == reflect.Interface { return convert(t.Elem(), s) } - v := reflect.Indirect(reflect.New(t)) + fn, ok := convertMap.Load(t) + if ok { + return fn(t, s) + } + v := reflect.New(t).Elem() if v.CanInt() { - n, err := strconv.ParseInt(s, 10, t.Bits()) - if err != nil { - return v, fmt.Errorf("converting type String %q to a %s: %v", s, t.Kind(), strconvErr(err)) + fn = func(t reflect.Type, s string) (reflect.Value, error) { + v := reflect.New(t).Elem() + n, err := strconv.ParseInt(s, 10, t.Bits()) + if err != nil { + return reflect.Value{}, fmt.Errorf("converting type String %q to a %s: %v", s, t.Kind(), strconvErr(err)) + } + v.SetInt(n) + return v, nil } - v.SetInt(n) - return v, nil } else if v.CanUint() { - u, err := strconv.ParseUint(s, 10, t.Bits()) - if err != nil { - return v, fmt.Errorf("converting type String %q to a %s: %v", s, t.Kind(), strconvErr(err)) + fn = func(t reflect.Type, s string) (reflect.Value, error) { + v := reflect.New(t).Elem() + u, err := strconv.ParseUint(s, 10, t.Bits()) + if err != nil { + return reflect.Value{}, fmt.Errorf("converting type String %q to a %s: %v", s, t.Kind(), strconvErr(err)) + } + v.SetUint(u) + return v, nil } - v.SetUint(u) - return v, nil } else if v.CanFloat() { - f, err := strconv.ParseFloat(s, t.Bits()) - if err != nil { - return v, fmt.Errorf("converting type String %q to a %s: %v", s, t.Kind(), strconvErr(err)) + fn = func(t reflect.Type, s string) (reflect.Value, error) { + v := reflect.New(t).Elem() + f, err := strconv.ParseFloat(s, t.Bits()) + if err != nil { + return reflect.Value{}, fmt.Errorf("converting type String %q to a %s: %v", s, t.Kind(), strconvErr(err)) + } + v.SetFloat(f) + return v, nil } - v.SetFloat(f) - return v, nil } else if v.CanComplex() { - c, err := strconv.ParseComplex(s, t.Bits()) - if err != nil { - return v, fmt.Errorf("converting type String %q to a %s: %v", s, t.Kind(), strconvErr(err)) - } - v.SetComplex(c) - return v, nil - } - switch t.Kind() { - case reflect.String, reflect.Interface: - v.SetString(s) - return v, nil - case reflect.Bool: - b, err := strconv.ParseBool(s) - if err == nil { - return v, fmt.Errorf("converting type String %q to a %s: %v", s, t.Kind(), err) - } - v.SetBool(b) - return v, nil - case reflect.Slice: - if t.Elem().Kind() == reflect.Uint8 { - v.SetBytes([]byte(s)) + fn = func(t reflect.Type, s string) (reflect.Value, error) { + v := reflect.New(t).Elem() + c, err := strconv.ParseComplex(s, t.Bits()) + if err != nil { + return reflect.Value{}, fmt.Errorf("converting type String %q to a %s: %v", s, t.Kind(), strconvErr(err)) + } + v.SetComplex(c) return v, nil } + } else { + switch t.Kind() { + case reflect.String: + fn = func(t reflect.Type, s string) (reflect.Value, error) { + v := reflect.New(t).Elem() + v.SetString(s) + return v, nil + } + case reflect.Bool: + fn = func(t reflect.Type, s string) (reflect.Value, error) { + v := reflect.New(t).Elem() + b, err := strconv.ParseBool(s) + if err == nil { + return reflect.Value{}, fmt.Errorf("converting type String %q to a %s: %v", s, t.Kind(), err) + } + v.SetBool(b) + return v, nil + } + case reflect.Slice: + if t.Elem().Kind() == reflect.Uint8 { + fn = func(t reflect.Type, s string) (reflect.Value, error) { + v := reflect.New(t).Elem() + v.SetBytes([]byte(s)) + return v, nil + } + } + } + } + if fn != nil { + convertMap.Store(t, fn) + return fn(t, s) } return reflect.Value{}, errNotSet } @@ -142,24 +176,14 @@ func setCell(dest any, s string) error { } return d.UnmarshalText([]byte(s)) } - dpv := reflect.ValueOf(dest) - if dpv.Kind() != reflect.Ptr { + if dpv.Kind() != reflect.Pointer { return errors.New("destination not a pointer") } if dpv.IsNil() { return errNilPtr } - dv := reflect.Indirect(dpv) - if v := reflect.ValueOf(s); v.Type().AssignableTo(dv.Type()) { - dv.Set(v) - return nil - } - if dv.Kind() == reflect.Pointer { - dv.Set(reflect.New(dv.Type().Elem())) - return setCell(dv.Interface(), s) - } - + dv := dpv.Elem() if v, err := convert(dv.Type(), s); err != nil { if err == errNotSet { if err := json.Unmarshal([]byte(strconv.Quote(s)), dest); err != nil { @@ -190,26 +214,23 @@ func setRow(dest any, m map[string]string) error { b, _ := json.Marshal(m) return d.UnmarshalJSON(b) } - dpv := reflect.ValueOf(dest) - if dpv.Kind() != reflect.Ptr { + if dpv.Kind() != reflect.Pointer { return errors.New("destination not a pointer") } if dpv.IsNil() { return errNilPtr } - - dv := reflect.Indirect(dpv) + dv := dpv.Elem() if v := reflect.ValueOf(m); v.Type().AssignableTo(dv.Type()) { dv.Set(v) return nil } - if len(m) == 0 { return nil } switch dv.Kind() { - case reflect.Ptr: + case reflect.Pointer: dv.Set(reflect.New(dv.Type().Elem())) return setRow(dv.Interface(), m) case reflect.Map: diff --git a/csv/export.go b/csv/export.go index d700802..3bdd854 100644 --- a/csv/export.go +++ b/csv/export.go @@ -17,7 +17,7 @@ func ExportFile[S ~[]E, E any](fieldnames []string, slice S, file string) error if err != nil { return err } - + defer f.Close() return export(fieldnames, slice, f, false) } @@ -32,13 +32,13 @@ func ExportUTF8File[S ~[]E, E any](fieldnames []string, slice S, file string) er if err != nil { return err } - + defer f.Close() return export(fieldnames, slice, f, true) } func export[S ~[]E, E any](fieldnames []string, slice S, w io.Writer, utf8bom bool) (err error) { csvWriter := NewWriter(w, utf8bom) - if fieldnames == nil { + if len(fieldnames) == 0 { if len(slice) == 0 { return fmt.Errorf("can't get struct fieldnames from zero length slice") } @@ -49,6 +49,5 @@ func export[S ~[]E, E any](fieldnames []string, slice S, w io.Writer, utf8bom bo if err != nil { return } - return csvWriter.WriteAll(slice) } diff --git a/csv/export_test.go b/csv/export_test.go index 97c97a0..da27fbd 100644 --- a/csv/export_test.go +++ b/csv/export_test.go @@ -67,16 +67,6 @@ aa, fieldnames: nil, slice: []*test{{A: "a", B: "b"}, nil, {A: "aa", B: nil}}, }, result) - testExport(t, testcase[D]{ - name: "D slice", - fieldnames: []string{"A", "B"}, - slice: []D{{{"A", "a"}, {"B", "b"}}, {{"A", "aa"}, {"B", nil}}}, - }, result) - testExport(t, testcase[D]{ - name: "D slice without fieldnames", - fieldnames: nil, - slice: []D{{{"A", "a"}, {"B", "b"}}, {{"A", "aa"}, {"B", nil}}}, - }, result) testExport(t, testcase[any]{ name: "interface slice", fieldnames: []string{"A", "B"}, diff --git a/csv/reader.go b/csv/reader.go index 41479b3..ac9bcae 100644 --- a/csv/reader.go +++ b/csv/reader.go @@ -22,20 +22,19 @@ type Reader struct { } // NewReader returns a new Reader that reads from r. -func NewReader(r io.Reader, hasFields bool) *Reader { +func NewReader(r io.Reader, hasFields bool) (*Reader, error) { reader := &Reader{Reader: csv.NewReader(r)} if closer, ok := r.(io.Closer); ok { reader.closer = closer } - if hasFields { var err error reader.fields, err = reader.Read() if err != nil { - panic(err) + return nil, err } } - return reader + return reader, nil } // ReadFile returns Reader reads from file. @@ -44,9 +43,23 @@ func ReadFile(file string, hasFields bool) (*Reader, error) { if err != nil { return nil, err } - return NewReader(f, hasFields), nil + reader, err := NewReader(f, hasFields) + if err != nil { + f.Close() + return nil, err + } + return reader, nil } +// Read reads one record (a slice of fields) from r. +// If the record has an unexpected number of fields, +// Read returns the record along with the error [ErrFieldCount]. +// If the record contains a field that cannot be parsed, +// Read returns a partial record along with the parse error. +// The partial record contains all fields read before the error. +// If there is no data left to be read, Read returns nil, [io.EOF]. +// If [Reader.ReuseRecord] is true, the returned slice may be shared +// between multiple calls to Read. func (r *Reader) Read() (record []string, err error) { record, err = r.Reader.Read() if err == nil { @@ -76,14 +89,12 @@ func (r *Reader) Scan(dest ...any) error { if r.next == nil && r.nextErr == nil { return fmt.Errorf("Scan called without calling Next") } - if r.nextErr != nil { return r.nextErr } if len(dest) != len(r.next) { return fmt.Errorf("expected %d destination arguments in Scan, not %d", len(r.next), len(dest)) } - for i, v := range r.next { if err := setCell(dest[i], v); err != nil { return fmt.Errorf("Scan error on field index %d: %v", i, err) @@ -104,7 +115,6 @@ func (r *Reader) Decode(dest any) error { if r.nextErr != nil { return r.nextErr } - m := make(map[string]string) for i, field := range r.fields { if len(r.next) > i { @@ -114,6 +124,7 @@ func (r *Reader) Decode(dest any) error { return setRow(dest, m) } +// Close closes the underlying reader if it implements the io.Closer interface. func (r *Reader) Close() error { if r.closer != nil { return r.closer.Close() @@ -123,25 +134,19 @@ func (r *Reader) Close() error { // DecodeAll decodes each record from r into dest. func DecodeAll[S ~[]E, E any](r io.Reader, dest *S) (err error) { - defer func() { - if e := recover(); e != nil { - err = fmt.Errorf("%v", e) - } - }() - - reader := NewReader(r, true) - defer reader.Close() - - var res S + reader, err := NewReader(r, true) + if err != nil { + return + } + *dest = nil for reader.Next() { var t E if err = reader.Decode(&t); err != nil { + *dest = nil return } - res = append(res, t) + *dest = append(*dest, t) } - *dest = res - return } @@ -151,5 +156,6 @@ func DecodeFile[S ~[]E, E any](file string, dest *S) error { if err != nil { return err } + defer f.Close() return DecodeAll(f, dest) } diff --git a/csv/reader_test.go b/csv/reader_test.go index 44895f5..874cbad 100644 --- a/csv/reader_test.go +++ b/csv/reader_test.go @@ -19,7 +19,10 @@ test a,1,"[1,2]" b,2,"[3,4]" ` - r := NewReader(strings.NewReader(csv), true) + r, err := NewReader(strings.NewReader(csv), true) + if err != nil { + t.Fatal(err) + } if expect := []string{"A", "B", "C"}; !slices.Equal(expect, r.fields) { t.Errorf("expected %v; got %v", expect, r.fields) } @@ -35,7 +38,10 @@ b,2,"[3,4]" t.Errorf("expected %v; got %v", expect, res1) } - r = NewReader(strings.NewReader(csv), true) + r, err = NewReader(strings.NewReader(csv), true) + if err != nil { + t.Fatal(err) + } var res2 []map[string]string for r.Next() { var res map[string]string diff --git a/csv/types.go b/csv/types.go deleted file mode 100644 index abff346..0000000 --- a/csv/types.go +++ /dev/null @@ -1,10 +0,0 @@ -package csv - -// D is an ordered representation of a document. This type should be used when the order of the elements matter. -type D []E - -// E represents a element for a D. It is usually used inside a D. -type E struct { - Key string - Value any -} diff --git a/csv/writer.go b/csv/writer.go index a2a3236..ecbb1da 100644 --- a/csv/writer.go +++ b/csv/writer.go @@ -6,6 +6,8 @@ import ( "io" "reflect" "slices" + + "github.com/sunshineplan/utils/pool" ) var utf8bom = []byte{0xEF, 0xBB, 0xBF} @@ -17,6 +19,8 @@ type Writer struct { utf8bom bool fields []field fieldsWritten bool + zero []string + pool *pool.Pool[[]string] } type field struct { @@ -29,13 +33,10 @@ func NewWriter(w io.Writer, utf8bom bool) *Writer { Writer: csv.NewWriter(w), w: w, utf8bom: utf8bom, + pool: pool.New[[]string](), } } -func (w *Writer) SkipWriteFields() { - w.fieldsWritten = true -} - // WriteFields writes fieldnames to w along with necessary utf8bom bytes. The fields must be a // non-zero field struct or a non-zero length string slice, otherwise an error will be return. // It can be run only once. @@ -43,51 +44,37 @@ func (w *Writer) WriteFields(fields any) error { if w.fieldsWritten { return fmt.Errorf("fieldnames already be written") } - - v := reflect.ValueOf(fields) - if v.Kind() == reflect.Ptr { - v = reflect.Indirect(v) - if !v.IsValid() { - return fmt.Errorf("can not get fieldnames from nil pointer struct") - } - } - switch v.Kind() { - case reflect.Struct: - if v.NumField() == 0 { - return fmt.Errorf("can not get fieldnames from zero field struct") + switch f := fields.(type) { + case []string: + for _, i := range f { + w.fields = append(w.fields, field{i, ""}) } - - for i := range v.NumField() { - if f := v.Type().Field(i); f.IsExported() { - tag, _ := v.Type().Field(i).Tag.Lookup("csv") - w.fields = append(w.fields, field{v.Type().Field(i).Name, tag}) + default: + v := reflect.ValueOf(fields) + if v.Kind() == reflect.Pointer { + v = reflect.Indirect(v) + if !v.IsValid() { + return fmt.Errorf("can not get fieldnames from nil pointer struct") } } - case reflect.Slice: - if v.Len() == 0 { - return fmt.Errorf("can not get fieldnames from zero length slice") - } - - if fieldnames, ok := fields.([]string); ok { - for _, i := range fieldnames { - w.fields = append(w.fields, field{i, ""}) + switch v.Kind() { + case reflect.Struct: + if v.NumField() == 0 { + return fmt.Errorf("can not get fieldnames from zero field struct") } - } else if d, ok := fields.(D); ok { - for _, i := range d { - w.fields = append(w.fields, field{i.Key, ""}) + for i := range v.NumField() { + if f := v.Type().Field(i); f.IsExported() { + tag, _ := v.Type().Field(i).Tag.Lookup("csv") + w.fields = append(w.fields, field{v.Type().Field(i).Name, tag}) + } } - } else { - return fmt.Errorf("only can get fieldnames from slice which is string slice or csv.D") + default: + return fmt.Errorf("can not get fieldnames from fields which is not struct or string slice") } - - default: - return fmt.Errorf("can not get fieldnames from fields which is not struct or string slice or csv.D") } - if w.utf8bom { w.w.Write(utf8bom) } - var record []string for _, i := range w.fields { if i.tag != "" { @@ -96,18 +83,19 @@ func (w *Writer) WriteFields(fields any) error { record = append(record, i.name) } } - if err := w.Writer.Write(record); err != nil { return err } w.Flush() - if err := w.Error(); err != nil { return err } - w.fieldsWritten = true - + w.zero = make([]string, len(w.fields)) + w.pool.New = func() *[]string { + s := make([]string, len(w.fields)) + return &s + } return nil } @@ -118,94 +106,92 @@ func (w *Writer) Write(record any) error { if !w.fieldsWritten { return fmt.Errorf("fieldnames has not be written yet") } - - v := reflect.ValueOf(record) - if v.Kind() == reflect.Interface { - v = v.Elem() - } - if v.Kind() == reflect.Ptr { - v = reflect.Indirect(v) - if !v.IsValid() { + switch d := record.(type) { + case []string: + if len(d) == 0 { return nil } - } - - r := make([]string, len(w.fields)) - switch v.Kind() { - case reflect.Map: - if keyType := reflect.TypeOf(v.Interface()).Key(); keyType.Kind() == reflect.String { - for i, field := range w.fields { - key := reflect.Indirect(reflect.New(keyType)) - key.SetString(field.name) - if v := v.MapIndex(key); v.IsValid() && v.Interface() != nil { - r[i], _ = marshalText(v.Interface()) - } - } - } else { - return fmt.Errorf("only can write record from map which is string kind") + return w.Writer.Write(d) + default: + v := reflect.ValueOf(record) + if v.Kind() == reflect.Interface { + v = v.Elem() } - case reflect.Struct: - for i, field := range w.fields { - var val reflect.Value - var found bool - for i := range v.NumField() { - if tag, ok := v.Type().Field(i).Tag.Lookup("csv"); ok && tag == field.tag { - val = v.FieldByName(field.name) - found = true - break - } - } - if !found { - if val = v.FieldByName(field.name); val.IsValid() && val.Interface() != nil { - found = true - } - } - if found { - r[i], _ = marshalText(val.Interface()) + if v.Kind() == reflect.Pointer { + v = reflect.Indirect(v) + if !v.IsValid() { + return nil } } - case reflect.Slice: - if rec, ok := record.([]string); ok { - if len(rec) == 0 { - return nil + r := w.pool.Get() + defer w.pool.Put(r) + switch v.Kind() { + case reflect.Map: + if keyType := reflect.TypeOf(v.Interface()).Key(); keyType.Kind() == reflect.String { + for i, field := range w.fields { + key := reflect.Indirect(reflect.New(keyType)) + key.SetString(field.name) + if v := v.MapIndex(key); v.IsValid() && v.Interface() != nil { + (*r)[i], _ = marshalText(v.Interface()) + } else { + (*r)[i] = "" + } + } + } else { + return fmt.Errorf("only can write record from map which is string kind") } - return w.Writer.Write(rec) - } else if d, ok := record.(D); ok { + case reflect.Struct: for i, field := range w.fields { - for _, e := range d { - if field.name == e.Key { - r[i], _ = marshalText(e.Value) + var val reflect.Value + var found bool + for i := range v.NumField() { + if tag, ok := v.Type().Field(i).Tag.Lookup("csv"); ok && tag == field.tag { + val = v.FieldByName(field.name) + found = true break } } + if !found { + if val = v.FieldByName(field.name); val.IsValid() && val.Interface() != nil { + found = true + } + } + if found { + (*r)[i], _ = marshalText(val.Interface()) + } else { + (*r)[i] = "" + } } - break + default: + return fmt.Errorf("not support record format: %s", v.Kind()) } - fallthrough - default: - return fmt.Errorf("not support record format: %s", v.Kind()) - } - - if slices.Equal(r, make([]string, len(w.fields))) { - return nil + if slices.Equal(*r, w.zero) { + return nil + } + return w.Writer.Write(*r) } - - return w.Writer.Write(r) } // WriteAll writes multiple CSV records to w using Write and then calls Flush, returning any error from the Flush. func (w *Writer) WriteAll(records any) error { - if reflect.TypeOf(records).Kind() != reflect.Slice { - return fmt.Errorf("records is not slice") - } - - v := reflect.ValueOf(records) - for i := range v.Len() { - if err := w.Write(v.Index(i).Interface()); err != nil { - return err + switch s := records.(type) { + case [][]string: + for _, i := range s { + if err := w.Write(i); err != nil { + return err + } + } + default: + if reflect.TypeOf(records).Kind() != reflect.Slice { + return fmt.Errorf("records is not slice") + } + v := reflect.ValueOf(records) + for i := range v.Len() { + if err := w.Write(v.Index(i).Interface()); err != nil { + return err + } } } w.Flush() - return w.Error() } diff --git a/csv/writer_test.go b/csv/writer_test.go index 14305de..9ed617a 100644 --- a/csv/writer_test.go +++ b/csv/writer_test.go @@ -27,30 +27,6 @@ func TestWriteFields(t *testing.T) { } } -func TestSkipWriteFields(t *testing.T) { - var buf bytes.Buffer - w := NewWriter(&buf, false) - if err := w.WriteFields([]string{"A", "B"}); err != nil { - t.Fatal(err) - } - if err := w.Write([]string{"a", "b"}); err != nil { - t.Fatal(err) - } - w.Flush() - w = NewWriter(&buf, false) - if err := w.Write([]string{"c", "d"}); err == nil { - t.Fatal("gave nil error; want error") - } - w.SkipWriteFields() - if err := w.Write([]string{"c", "d"}); err != nil { - t.Fatal(err) - } - w.Flush() - if result, r := "A,B\na,b\nc,d\n", buf.String(); r != result { - t.Errorf("expected %q; got %q", result, r) - } -} - func TestWriter(t *testing.T) { result := `A|B a|b diff --git a/flags/flags.go b/flags/flags.go index eb4f7cd..c25c28b 100644 --- a/flags/flags.go +++ b/flags/flags.go @@ -13,70 +13,101 @@ import ( ) var ( - Strict bool + // SilentMissingConfig controls whether a warning message is printed + // when the specified configuration file is missing. + SilentMissingConfig bool + // config stores the path of the configuration file. config string ) +// errMissingConfig is returned when the specified configuration file does not exist. +var errMissingConfig = errors.New("config file is missing") + +// SetConfigFile sets the path of the configuration file to be used when parsing flags. func SetConfigFile(path string) { config = path } -func getArgs(strict, hint bool) (args []string) { +// getArgs reads the configuration file and converts its key-value pairs +// into command-line style arguments compatible with the flag package. +// Lines beginning with '#' or empty lines are ignored. +// Each valid line must be in the form "key=value". +// Returns a slice of arguments or an error if parsing fails. +func getArgs() (args []string, err error) { lines, err := txt.ReadFile(config) if err != nil { - if errors.Is(err, fs.ErrNotExist) && (strict || hint) { - fmt.Println("config file is missing") - } - if !strict { - return + if errors.Is(err, fs.ErrNotExist) { + return nil, errMissingConfig } - panic(err) + return } - for _, line := range lines { + for i, line := range lines { line = strings.TrimSpace(line) if line == "" || strings.HasPrefix(line, "#") { continue } parts := strings.SplitN(line, "=", 2) if len(parts) != 2 { - panic(fmt.Sprintf("cannot parse %q", line)) + err = fmt.Errorf("line %d: cannot parse %q", i+1, line) + return } if key := strings.TrimSpace(parts[0]); flag.Lookup(key) != nil { args = append(args, fmt.Sprintf("-%s=%s", key, unquote(parts[1]))) } else { - if err := fmt.Sprintf("undefined flag %q", key); strict { - panic(err) - } else { - fmt.Println("[Warning]", err) - } + err = fmt.Errorf("line %d: unknown flag %q", i+1, key) + return } } return } +// unquote removes surrounding quotes from a string if present. +// If the string cannot be unquoted, it is returned unchanged. func unquote(s string) string { s = strings.TrimSpace(s) - if s, err := strconv.Unquote(s); err == nil { - return s + if unq, err := strconv.Unquote(s); err == nil { + return unq } return s } -func ParseFlags(strict, hint bool) { +// Parse parses command-line flags and optionally merges them with values +// read from the configuration file specified via SetConfigFile. +// If a configuration file is provided, its flags are prepended to os.Args. +// Errors during parsing or reading are handled according to the flag package’s +// ErrorHandling setting. +func Parse() error { if config != "" { - flag.CommandLine.Parse(append(getArgs(strict, hint), os.Args[1:]...)) - return + args, err := getArgs() + if err != nil { + handleError(err) + } + return flag.CommandLine.Parse(append(args, os.Args[1:]...)) } - flag.Parse() -} - -func UseStrict(strict bool) { - Strict = strict + return flag.CommandLine.Parse(os.Args[1:]) } -func Parse() { - ParseFlags(Strict, true) +// handleError handles errors according to their type +// and the current flag.CommandLine.ErrorHandling mode. +func handleError(err error) { + switch err { + case errMissingConfig: + if !SilentMissingConfig { + fmt.Println("[flags]", err) + } + default: + switch flag.CommandLine.ErrorHandling() { + case flag.ContinueOnError: + fmt.Println("[flags]", err) + case flag.ExitOnError: + os.Exit(2) + case flag.PanicOnError: + panic(err) + } + } } -func ParseStrict() { - ParseFlags(true, true) +// init reinitializes the default CommandLine flag set to use ContinueOnError mode, +// preventing flag.Parse from exiting the program automatically. +func init() { + flag.CommandLine.Init(flag.CommandLine.Name(), flag.ContinueOnError) } diff --git a/flags/flags_test.go b/flags/flags_test.go index 4d5e14e..1203c5d 100644 --- a/flags/flags_test.go +++ b/flags/flags_test.go @@ -5,15 +5,14 @@ import ( "testing" ) -var ( - var0 = flag.String("var0", "", "") - var1 = flag.String("var1", "", "") - var2 = flag.String("var2", "2", "") -) - func TestParse(t *testing.T) { + var0 := flag.String("var0", "", "") + var1 := flag.String("var1", "", "") + var2 := flag.String("var2", "2", "") config = "test_config.ini" - Parse() + if err := Parse(); err != nil { + t.Error(err) + } if *var0 != "0" { t.Errorf("expected %q; got %q", "0", *var0) } @@ -23,8 +22,7 @@ func TestParse(t *testing.T) { if *var2 != "" { t.Errorf("expected %q; got %q", "", *var2) } - - UseStrict(true) + flag.CommandLine.Init("", flag.PanicOnError) defer func() { if err := recover(); err == nil { t.Error("gave no panic; want panic") diff --git a/html/element.go b/html/element.go index e1c6b89..1f82cbf 100644 --- a/html/element.go +++ b/html/element.go @@ -2,47 +2,65 @@ package html import ( "fmt" + "maps" "slices" "strings" + + "github.com/sunshineplan/utils/pool" ) var _ HTMLer = new(Element) +// Element represents a single HTML element, including its tag name, +// attributes, and inner HTML content. type Element struct { tag string attrs map[string]string content HTML } +// Attribute sets or updates an attribute on the element. +// If attrs is nil, it initializes the map. func (e *Element) Attribute(name, value string) *Element { + if e.attrs == nil { + e.attrs = make(map[string]string) + } e.attrs[name] = value return e } +// Class sets the "class" attribute. Multiple classes can be provided. func (e *Element) Class(class ...string) *Element { return e.Attribute("class", strings.Join(class, " ")) } +// Href sets the "href" attribute. func (e *Element) Href(href string) *Element { return e.Attribute("href", href) } +// Name sets the "name" attribute. func (e *Element) Name(name string) *Element { return e.Attribute("name", name) } +// Src sets the "src" attribute. func (e *Element) Src(src string) *Element { return e.Attribute("src", src) } +// Style sets the "style" attribute. func (e *Element) Style(style string) *Element { return e.Attribute("style", style) } +// Title sets the "title" attribute. func (e *Element) Title(title string) *Element { return e.Attribute("title", title) } +// content converts an arbitrary value into escaped HTML text. +// It handles HTMLer, HTML, string, and other types gracefully. func content(v any) HTML { switch v := v.(type) { case nil: @@ -58,19 +76,23 @@ func content(v any) HTML { } } +// Content replaces the current content of the element with new values. func (e *Element) Content(v ...any) *Element { e.content = "" return e.AppendContent(v...) } +// Contentf formats a string using fmt.Sprintf and sets it as the element content. func (e *Element) Contentf(format string, a ...any) *Element { return e.Content(fmt.Sprintf(format, a...)) } +// HTMLContent inserts raw (unescaped) HTML into the element content. func (e *Element) HTMLContent(html string) *Element { return e.Content(HTML(html)) } +// AppendContent appends additional content to the element. func (e *Element) AppendContent(v ...any) *Element { for _, v := range v { e.content += content(v) @@ -78,13 +100,15 @@ func (e *Element) AppendContent(v ...any) *Element { return e } +// AppendChild appends child elements to the current element. func (e *Element) AppendChild(child ...*Element) *Element { for _, i := range child { - e.AppendContent(i) + e.content += i.HTML() } return e } +// AppendHTML appends one or more raw HTML strings directly to the element. func (e *Element) AppendHTML(html ...string) *Element { for _, i := range html { e.AppendContent(HTML(i)) @@ -93,64 +117,122 @@ func (e *Element) AppendHTML(html ...string) *Element { } // https://developer.mozilla.org/en-US/docs/Glossary/Void_element +// Map of HTML5 void elements that do not require a closing tag. +var voidElements = map[string]struct{}{ + "area": {}, + "base": {}, + "br": {}, + "col": {}, + "embed": {}, + "hr": {}, + "img": {}, + "input": {}, + "link": {}, + "meta": {}, + "param": {}, + "source": {}, + "track": {}, + "wbr": {}, +} + +// isVoidElement reports whether the element is a void (self-closing) element. func (e Element) isVoidElement() bool { - return slices.Contains([]string{ - "area", - "base", - "br", - "col", - "embed", - "hr", - "img", - "input", - "link", - "meta", - "param", - "source", - "track", - "wbr", - }, strings.ToLower(e.tag)) + _, ok := voidElements[strings.ToLower(e.tag)] + return ok } +// builderPool provides a pool of reusable strings.Builder instances +// to reduce memory allocations during HTML rendering. +var builderPool = pool.New[strings.Builder]() + // https://developer.mozilla.org/en-US/docs/Web/HTML/Attributes +// printAttrs returns a formatted string representation of all attributes +// sorted by key. Boolean attributes (true or empty) are printed without a value. +// Attributes with value "false" are omitted. func (e Element) printAttrs() string { - var s []string - for k, v := range e.attrs { - if v == "" || v == "true" { - s = append(s, k) - } else if v == "false" { + if len(e.attrs) == 0 { + return "" + } + keys := make([]string, 0, len(e.attrs)) + for k := range e.attrs { + keys = append(keys, k) + } + slices.Sort(keys) + b := builderPool.Get() + defer func() { + b.Reset() + builderPool.Put(b) + }() + first := true + for _, k := range keys { + v := e.attrs[k] + switch v { + case "", "true": + if !first { + b.WriteByte(' ') + } + b.WriteString(k) + first = false + case "false": continue - } else { - s = append(s, fmt.Sprintf("%s=%q", k, v)) + default: + if !first { + b.WriteByte(' ') + } + b.WriteString(k) + b.WriteByte('=') + b.WriteByte('"') + b.WriteString(EscapeString(v)) + b.WriteByte('"') + first = false } } - slices.Sort(s) - return strings.Join(s, " ") + return b.String() } +// String returns the serialized HTML representation of the element. func (e *Element) String() string { - var b strings.Builder - if e.tag != "" { - fmt.Fprint(&b, "<", e.tag) - if attrs := e.printAttrs(); attrs != "" { - fmt.Fprint(&b, " ", attrs) - } - } if e.tag == "" { - fmt.Fprint(&b, e.content) - } else if e.isVoidElement() { - fmt.Fprint(&b, ">") - } else { - fmt.Fprint(&b, ">", e.content) - fmt.Fprintf(&b, "", e.tag) + return string(e.content) + } + b := builderPool.Get() + defer func() { + b.Reset() + builderPool.Put(b) + }() + b.WriteByte('<') + b.WriteString(e.tag) + if attrs := e.printAttrs(); attrs != "" { + b.WriteByte(' ') + b.WriteString(attrs) + } + b.WriteByte('>') + if !e.isVoidElement() { + b.WriteString(string(e.content)) + b.WriteString("') } return b.String() } +// HTML returns the element as HTML type, implementing [HTMLer]. func (e *Element) HTML() HTML { return HTML(e.String()) } +// Clone creates a deep copy of the element and its attributes. +func (e *Element) Clone() *Element { + attrs := make(map[string]string, len(e.attrs)) + maps.Copy(attrs, e.attrs) + return &Element{ + tag: e.tag, + attrs: attrs, + content: e.content, + } +} + +// NewElement creates and returns a new HTML element with the given tag name. func NewElement(tag string) *Element { - return &Element{tag, make(map[string]string), ""} + return &Element{tag: tag, attrs: make(map[string]string)} } diff --git a/html/html.go b/html/html.go index aa61902..39c0f02 100644 --- a/html/html.go +++ b/html/html.go @@ -3,20 +3,30 @@ package html import "html" var ( - EscapeString = html.EscapeString + // EscapeString is alias of [html.EscapeString], + // used for encoding HTML entities. + EscapeString = html.EscapeString + // UnescapeString is alias of [html.UnescapeString], + // used for decoding HTML entities. UnescapeString = html.UnescapeString ) +// HTML represents a string that contains valid HTML markup. type HTML string +// HTMLer defines types that can render themselves as HTML. type HTMLer interface { HTML() HTML } +// Background creates an element with no tag, typically used for raw content. func Background() *Element { return NewElement("") } +// NewHTML creates a new element. func NewHTML() *Element { return NewElement("html") } +// Common HTML element constructors for convenience. + func A() *Element { return NewElement("a") } func B() *Element { return NewElement("b") } func Body() *Element { return NewElement("body") } diff --git a/html/table.go b/html/table.go index 36a579f..4991d3c 100644 --- a/html/table.go +++ b/html/table.go @@ -4,8 +4,10 @@ import "strconv" var _ HTMLer = new(TableCell) +// TableCell wraps an Element and provides methods specific to and cells. type TableCell struct{ *Element } +// Tr creates a new element containing the given table cells. func Tr(element ...*TableCell) *Element { tr := NewElement("tr") for _, i := range element { @@ -14,44 +16,53 @@ func Tr(element ...*TableCell) *Element { return tr } +// Th creates a new (table header cell) element with optional content. func Th(content any) *TableCell { return &TableCell{NewElement("th").Content(content)} } +// Td creates a new (table data cell) element with optional content. func Td(content any) *TableCell { return &TableCell{NewElement("td").Content(content)} } +// Abbr sets the "abbr" attribute on the table cell. func (cell *TableCell) Abbr(abbr string) *TableCell { cell.Element.Attribute("abbr", abbr) return cell } +// Colspan sets the "colspan" attribute, specifying how many columns the cell spans. func (cell *TableCell) Colspan(n uint) *TableCell { cell.Element.Attribute("colspan", strconv.FormatUint(uint64(n), 10)) return cell } +// Headers sets the "headers" attribute, linking the cell to header IDs. func (cell *TableCell) Headers(headers string) *TableCell { cell.Element.Attribute("headers", headers) return cell } +// Rowspan sets the "rowspan" attribute, specifying how many rows the cell spans. func (cell *TableCell) Rowspan(n uint) *TableCell { cell.Element.Attribute("rowspan", strconv.FormatUint(uint64(n), 10)) return cell } +// Scope sets the "scope" attribute, typically used in to define scope of headers. func (cell *TableCell) Scope(scope string) *TableCell { cell.Element.Attribute("scope", scope) return cell } +// Class sets the "class" attribute for the table cell. func (cell *TableCell) Class(class ...string) *TableCell { cell.Element.Class(class...) return cell } +// Style sets the "style" attribute for the table cell. func (cell *TableCell) Style(style string) *TableCell { cell.Element.Style(style) return cell diff --git a/httpsvr/httpsvr.go b/httpsvr/httpsvr.go index 6a12cad..568c57c 100644 --- a/httpsvr/httpsvr.go +++ b/httpsvr/httpsvr.go @@ -20,7 +20,7 @@ var certCache = cache.NewWithRenew[string, *tls.Certificate](false) var defaultReload = 24 * time.Hour -// Server defines parameters for running an HTTP server. +// Server defines parameters for running an HTTP or HTTPS server. type Server struct { *http.Server *log.Logger @@ -36,94 +36,103 @@ type Server struct { l *counter.Listener } -// New creates an HTTP server. +// New creates a new Server instance with default logger and error log. func New() *Server { - return &Server{Server: &http.Server{}, Logger: log.Default()} + logger := log.Default() + return &Server{Server: &http.Server{ErrorLog: logger.Logger}, Logger: logger} } +// SetLogger sets a custom logger for both the Server and its internal http.Server. func (s *Server) SetLogger(logger *log.Logger) { s.Logger = logger s.Server.ErrorLog = logger.Logger } +// SetReload defines the certificate reload interval. +// Default is 24 hours if not set explicitly. func (s *Server) SetReload(d time.Duration) { s.reload = d } -// Run runs an HTTP server which can be gracefully shut down. -func (s *Server) run() error { +// Serve starts the HTTP or HTTPS server and handles graceful shutdown signals. +// +// It listens on either a Unix domain socket (if s.Unix is set) or a TCP address. +// When receiving SIGHUP, the server reloads configuration or certificates. +// When receiving SIGINT/SIGTERM, it gracefully shuts down all connections. +func (s *Server) Serve(tls bool) (err error) { + s.tls = tls + if s.reload == 0 { + s.reload = defaultReload + } + // Channel used to wait for graceful shutdown completion. idleConnsClosed := make(chan struct{}) + // Handle system signals for reload and graceful stop. c := make(chan os.Signal, 1) signal.Notify(c, syscall.SIGINT, syscall.SIGTERM, syscall.SIGHUP) + defer signal.Stop(c) go func() { for { switch <-c { case syscall.SIGHUP: - s.Rotate() - if s.tls { - cert, err := s.loadCertificate() - if err != nil { - s.Println("Failed to reload certificate:", err) - continue - } - if s.reload == 0 { - s.reload = defaultReload - } - certCache.Set(s.certFile+s.keyFile, cert, s.reload, s.loadCertificate) + if err := s.Reload(); err != nil { + s.Printf("reload failed: %v", err) + } else { + s.Print("reload successful") } case syscall.SIGINT, syscall.SIGTERM: ctx, cancel := context.WithTimeout(context.Background(), time.Minute) defer cancel() if err := s.Shutdown(ctx); err != nil { - s.Println("Failed to close server:", err) + s.Printf("failed to close server: %v", err) } close(idleConnsClosed) return } } }() - + var listener net.Listener if s.Unix != "" { - listener, err := net.Listen("unix", s.Unix) + // Listen on Unix domain socket. + listener, err = net.Listen("unix", s.Unix) if err != nil { - return fmt.Errorf("failed to listen socket file: %v", err) + return fmt.Errorf("failed to listen socket file: %w", err) } + defer os.Remove(s.Unix) // Let everyone can access the socket file. if err := os.Chmod(s.Unix, 0666); err != nil { - return fmt.Errorf("failed to chmod socket file: %v", err) + return fmt.Errorf("failed to chmod socket file: %w", err) } - s.l = counter.NewListener(listener) } else { + // Default to "http" or "https" if port not specified. port := s.Port if port == "" { - if s.tls { + if tls { port = "https" } else { port = "http" } } s.Addr = s.Host + ":" + port - listener, err := net.Listen("tcp", s.Addr) + listener, err = net.Listen("tcp", s.Addr) if err != nil { - return fmt.Errorf("failed to listen tcp: %v", err) + return fmt.Errorf("failed to listen tcp: %w", err) } - s.l = counter.NewListener(listener) } + s.l = counter.NewListener(listener) - var err error - if s.tls { - err = s.ServeTLS(s.l, "", "") + if tls { + err = s.Server.ServeTLS(s.l, "", "") } else { - err = s.Serve(s.l) + err = s.Server.Serve(s.l) } if err != http.ErrServerClosed { - return fmt.Errorf("failed to serve: %v", err) + return fmt.Errorf("failed to serve: %w", err) } - <-idleConnsClosed return nil } +// loadCertificate loads the current TLS certificate and key pair from disk. func (s *Server) loadCertificate() (*tls.Certificate, error) { cert, err := tls.LoadX509KeyPair(s.certFile, s.keyFile) if err != nil { @@ -132,69 +141,84 @@ func (s *Server) loadCertificate() (*tls.Certificate, error) { return &cert, nil } +// getCertificate is a callback for tls.Config.GetCertificate. +// It retrieves the cached certificate, or reloads it if expired. func (s *Server) getCertificate(_ *tls.ClientHelloInfo) (*tls.Certificate, error) { - v, ok := certCache.Get(s.certFile + s.keyFile) + key := s.certFile + s.keyFile + v, ok := certCache.Get(key) if ok { return v, nil } - cert, err := s.loadCertificate() if err != nil { return nil, err } - - if s.reload == 0 { - s.reload = defaultReload - } - certCache.Set(s.certFile+s.keyFile, cert, s.reload, s.loadCertificate) - + certCache.Set(key, cert, s.reload, s.loadCertificate) return cert, nil } -// Run runs an HTTP server which can be gracefully shut down. +// Run starts an HTTP server with graceful shutdown support. func (s *Server) Run() error { - s.tls = false - return s.run() + return s.Serve(false) } +// RunTLS starts an HTTPS server using the provided certificate and key files. +// Certificates are automatically reloaded based on the reload interval. func (s *Server) RunTLS(certFile, keyFile string) error { - s.tls = true s.certFile = certFile s.keyFile = keyFile - s.TLSConfig = &tls.Config{GetCertificate: s.getCertificate} - return s.run() + if s.TLSConfig == nil { + s.TLSConfig = &tls.Config{} + } + s.TLSConfig.GetCertificate = s.getCertificate + return s.Serve(true) +} + +// Reload rotates server's log and reloads TLS certificates if applicable. +func (s *Server) Reload() error { + s.Rotate() + if s.tls { + cert, err := s.loadCertificate() + if err != nil { + return fmt.Errorf("failed to reload certificate: %w", err) + } + certCache.Set(s.certFile+s.keyFile, cert, s.reload, s.loadCertificate) + } + return nil } -func (s *Server) ReadCount() int64 { +// ReadBytes returns the total number of bytes read by the listener. +func (s *Server) ReadBytes() int64 { if s.l == nil { return 0 } - return s.l.ReadCount() + return s.l.ReadBytes() } -func (s *Server) WriteCount() int64 { +// WriteBytes returns the total number of bytes written by the listener. +func (s *Server) WriteBytes() int64 { if s.l == nil { return 0 } - return s.l.WriteCount() + return s.l.WriteBytes() } -// TCP runs an HTTP server on TCP network listener. +// TCP runs an HTTP server using TCP network listener. func TCP(addr string, handler http.Handler) error { return (&Server{Server: &http.Server{Addr: addr, Handler: handler}}).Run() } -// TLS runs an HTTP server on TCP network listener and handle requests on incoming TLS connections. +// TLS runs an HTTPS server using TCP network listener. func TLS(addr string, handler http.Handler, certFile, keyFile string) error { return (&Server{Server: &http.Server{Addr: addr, Handler: handler}}).RunTLS(certFile, keyFile) } -// Unix runs an HTTP server on Unix domain socket listener. +// Unix runs an HTTP server using Unix domain socket listener. func Unix(unix string, handler http.Handler) error { return (&Server{Unix: unix, Server: &http.Server{Handler: handler}}).Run() } -// UnixTLS runs an HTTP server on Unix domain socket listener and handle requests on incoming TLS connections. +// UnixTLS runs an HTTPS server using Unix domain socket listener. func UnixTLS(unix string, handler http.Handler, certFile, keyFile string) error { return (&Server{Unix: unix, Server: &http.Server{Handler: handler}}).RunTLS(certFile, keyFile) } diff --git a/loadbalance/loadbalance.go b/loadbalance/loadbalance.go index 2e358a2..1ad3521 100644 --- a/loadbalance/loadbalance.go +++ b/loadbalance/loadbalance.go @@ -2,24 +2,30 @@ package loadbalance import ( "errors" - "sync" "github.com/sunshineplan/utils/container" ) +// ErrEmptyLoadBalancer is returned when a load balancer is initialized with no valid items. var ErrEmptyLoadBalancer = errors.New("empty load balancer") -var mu sync.RWMutex - +// LoadBalancer defines an interface for load balancing algorithms. +// It provides methods to access and manipulate a circular list of elements of type E, +// ensuring thread-safe operations for concurrent use. type LoadBalancer[E any] interface { + // Len returns the number of elements in the load balancer. Len() int + // Next returns the next element according to the load balancing strategy. Next() E - Ring() *container.Ring[*E] - Link(LoadBalancer[E]) LoadBalancer[E] + // Link merges the given ring into the load balancer, inserting its elements. + Link(*container.Ring[E]) LoadBalancer[E] + // Unlink removes n elements from the load balancer, starting from the next position. Unlink(int) LoadBalancer[E] } +// Weighted represents an item with an associated weight for weighted load balancing. +// The Weight field determines how many times the Item appears in the load balancer. type Weighted[E any] struct { - Item E - Weight int + Item E // The item to be balanced. + Weight int // The weight determining the item's frequency in the rotation. } diff --git a/loadbalance/random.go b/loadbalance/random.go index ccd03e0..7ae12e8 100644 --- a/loadbalance/random.go +++ b/loadbalance/random.go @@ -1,24 +1,58 @@ package loadbalance -import "math/rand/v2" +import ( + "math/rand/v2" + + "github.com/sunshineplan/utils/container" +) var _ LoadBalancer[any] = &random[any]{} +// random implements a thread-safe random load balancer by extending the round-robin load balancer. +// It selects elements randomly from the ring, ensuring thread-safe operations using the +// underlying round-robin mutex and ring mutexes. type random[E any] struct { - *roundrobin[E] + *roundrobin[E] // Embeds roundrobin for shared functionality. +} + +func newRandom[E any, Items []E | []Weighted[E]](items Items) (*random[E], error) { + lb, err := newRoundRobin[E](items) + if err != nil { + return nil, err + } + return &random[E]{lb}, nil +} + +// Random creates a new random load balancer with the given items. +// Each item appears once in the pool. It returns error with ErrEmptyLoadBalancer if no items are provided. +func Random[E any](items ...E) (LoadBalancer[E], error) { + return newRandom[E](items) } -func Random[E any](items ...E) LoadBalancer[E] { - return &random[E]{newRoundRobin[E](items)} +// WeightedRandom creates a new weighted random load balancer. +// Each item's weight determines how many times it appears in the pool. +// It returns error with ErrEmptyLoadBalancer if no items have positive weight. +func WeightedRandom[E any](items ...Weighted[E]) (LoadBalancer[E], error) { + return newRandom[E](items) } -func WeightedRandom[E any](items ...Weighted[E]) LoadBalancer[E] { - return &random[E]{newRoundRobin[E](items)} +// RandomFromRing creates a new random load balancer from an existing ring. +// It uses the provided ring directly, ensuring thread-safe random selection. +// It returns error with ErrEmptyLoadBalancer if the ring is nil or empty. +func RandomFromRing[E any](ring *container.Ring[E]) (LoadBalancer[E], error) { + len := ring.Len() + if len == 0 { + return nil, ErrEmptyLoadBalancer + } + return &random[E]{&roundrobin[E]{ring: ring, len: len}}, nil } +// Next returns a randomly selected element from the load balancer. +// It is thread-safe, using a write lock to update the ring position. +// If the balancer is empty, it returns the zero value of E. func (r *random[E]) Next() E { - mu.RLock() - defer mu.RUnlock() - r.roundrobin = (*roundrobin[E])(r.Ring().Move(rand.IntN(r.Ring().Len()))) - return **r.Ring().Value() + r.Lock() + defer r.Unlock() + r.ring = r.ring.Move(rand.IntN(r.len)) + return r.ring.Value() } diff --git a/loadbalance/roundrobin.go b/loadbalance/roundrobin.go index c30eb3e..a80da80 100644 --- a/loadbalance/roundrobin.go +++ b/loadbalance/roundrobin.go @@ -1,90 +1,121 @@ package loadbalance -import "github.com/sunshineplan/utils/container" +import ( + "sync" + + "github.com/sunshineplan/utils/container" +) var _ LoadBalancer[any] = &roundrobin[any]{} -type roundrobin[E any] container.Ring[*E] +// roundrobin implements a thread-safe round-robin load balancer using a circular ring. +// It supports both simple and weighted item distributions, ensuring fair rotation through +// elements. All methods are thread-safe using a struct-level mutex and the underlying ring's mutexes. +type roundrobin[E any] struct { + sync.RWMutex + ring *container.Ring[E] // The underlying ring storing the elements. + len int // Cached length of the ring. +} -func newRoundRobin[E any, Items []E | []Weighted[E]](items Items) *roundrobin[E] { +func newRoundRobin[E any, Items []E | []Weighted[E]](items Items) (*roundrobin[E], error) { if len(items) == 0 { - panic(ErrEmptyLoadBalancer) + return nil, ErrEmptyLoadBalancer } - var root *roundrobin[E] + var ring *container.Ring[E] + var n int switch items := any(items).(type) { case []E: + n = len(items) + ring = container.NewRing[E](n) for _, i := range items { - r := container.NewRing[*E](1) - r.Set(&i) - if root == nil { - root = (*roundrobin[E])(r) - } else { - root.Ring().Link(r) - root = (*roundrobin[E])(root.Ring().Next()) - } - } - if root != nil { - root = (*roundrobin[E])(root.Ring().Next()) + ring = ring.Set(i).Next() } case []Weighted[E]: for _, i := range items { if i.Weight == 0 { continue } - item := &i.Item - r := container.NewRing[*E](i.Weight) - for range r.Len() { - r.Set(item) - r = r.Next() + subring := container.NewRing[E](i.Weight) + for range i.Weight { + subring = subring.Set(i.Item).Next() } - if root == nil { - root = (*roundrobin[E])(r) + if ring == nil { + ring = subring } else { - root.Ring().Link(r) - root = (*roundrobin[E])(root.Ring().Next()) + ring = ring.Link(subring).Prev() } } - if root != nil { - root = (*roundrobin[E])(root.Ring().Next()) + if ring != nil { + ring = ring.Next() + } + if n = ring.Len(); n == 0 { + return nil, ErrEmptyLoadBalancer } } - return root + return &roundrobin[E]{ring: ring, len: n}, nil } -func RoundRobin[E any](items ...E) LoadBalancer[E] { +// RoundRobin creates a new round-robin load balancer with the given items. +// Each item appears once in the rotation. It returns error with ErrEmptyLoadBalancer if no items are provided. +func RoundRobin[E any](items ...E) (LoadBalancer[E], error) { return newRoundRobin[E](items) } -func WeightedRoundRobin[E any](items ...Weighted[E]) LoadBalancer[E] { +// WeightedRoundRobin creates a new weighted round-robin load balancer. +// Each item's weight determines how many times it appears in the rotation. +// It returns error with ErrEmptyLoadBalancer if no items have positive weight. +func WeightedRoundRobin[E any](items ...Weighted[E]) (LoadBalancer[E], error) { return newRoundRobin[E](items) } -func (r *roundrobin[E]) Len() int { - mu.RLock() - defer mu.RUnlock() - return r.Ring().Len() +// RoundRobinFromRing creates a new round-robin load balancer from an existing ring. +// It uses the provided ring directly, ensuring thread-safe operations. +// It returns error with ErrEmptyLoadBalancer if the ring is nil or empty. +func RoundRobinFromRing[E any](ring *container.Ring[E]) (LoadBalancer[E], error) { + len := ring.Len() + if len == 0 { + return nil, ErrEmptyLoadBalancer + } + return &roundrobin[E]{ring: ring, len: len}, nil } -func (r *roundrobin[E]) Next() (next E) { - mu.Lock() - defer mu.Unlock() - v := **r.Ring().Value() - *r = *(*roundrobin[E])(r.Ring().Next()) - return v +// Len returns the number of elements in the load balancer. +// It is thread-safe and uses the cached length for O(1) access. +func (r *roundrobin[E]) Len() int { + r.RLock() + defer r.RUnlock() + return r.len } -func (r *roundrobin[E]) Ring() *container.Ring[*E] { - return (*container.Ring[*E])(r) +// Next returns the next element in the round-robin sequence. +// It is thread-safe, advancing the ring to the next position. +// If the balancer is empty, it returns the zero value of E. +func (r *roundrobin[E]) Next() (next E) { + r.Lock() + defer r.Unlock() + next = r.ring.Value() + r.ring = r.ring.Next() + return } -func (r *roundrobin[E]) Link(s LoadBalancer[E]) LoadBalancer[E] { - mu.Lock() - defer mu.Unlock() - return (*roundrobin[E])(r.Ring().Link(s.Ring())) +// Link merges the given ring into the load balancer, inserting its elements +// after the current position. It is thread-safe, updates the cached length, +// and returns the load balancer for chaining. +func (r *roundrobin[E]) Link(s *container.Ring[E]) LoadBalancer[E] { + r.Lock() + defer r.Unlock() + r.ring = r.ring.Prev().Link(s) + r.len = r.ring.Len() + return r } +// Unlink removes n elements starting from the next position and returns +// the load balancer for chaining. It is thread-safe, updates the cached length, +// and sets the ring to nil if it becomes empty. func (r *roundrobin[E]) Unlink(n int) LoadBalancer[E] { - mu.Lock() - defer mu.Unlock() - return (*roundrobin[E])(r.Ring().Unlink(n)) + r.Lock() + defer r.Unlock() + r.ring = r.ring.Unlink(n) + r.len = r.ring.Len() + return r } diff --git a/loadbalance/roundrobin_test.go b/loadbalance/roundrobin_test.go index 5671e4c..5b66b71 100644 --- a/loadbalance/roundrobin_test.go +++ b/loadbalance/roundrobin_test.go @@ -6,30 +6,45 @@ import ( ) func TestRoundRobin(t *testing.T) { - r1 := RoundRobin([]string{"a", "b", "c"}...) + r1, err := newRoundRobin[string]([]string{"a", "b", "c"}) + if err != nil { + t.Fatal(err) + } + if r1.Len() != 3 { + t.Fatalf("want 3, got %d", r1.Len()) + } var res []string for range 6 { res = append(res, r1.Next()) } if expect := []string{"a", "b", "c", "a", "b", "c"}; !slices.Equal(res, expect) { - t.Errorf("want %v, got %v", expect, res) + t.Fatalf("want %v, got %v", expect, res) } res = nil - r2 := WeightedRoundRobin([]Weighted[string]{{"a", 2}, {"b", 1}, {"c", 1}}...) + r2, err := newRoundRobin[string]([]Weighted[string]{{"a", 2}, {"b", 1}, {"c", 1}}) + if err != nil { + t.Fatal(err) + } + if r2.Len() != 4 { + t.Fatalf("want 4, got %d", r2.Len()) + } for range 12 { res = append(res, r2.Next()) } if expect := []string{"a", "a", "b", "c", "a", "a", "b", "c", "a", "a", "b", "c"}; !slices.Equal(res, expect) { - t.Errorf("want %v, got %v", expect, res) + t.Fatalf("want %v, got %v", expect, res) } res = nil - r1.Link(r2) + ring := r1.Link(r2.ring) + if ring.Len() != 7 { + t.Fatalf("want 7, got %d", ring.Len()) + } for range 7 { - res = append(res, r1.Next()) + res = append(res, ring.Next()) } - if expect := []string{"a", "a", "a", "b", "c", "b", "c"}; !slices.Equal(res, expect) { - t.Errorf("want %v, got %v", expect, res) + if expect := []string{"a", "b", "c", "a", "a", "b", "c"}; !slices.Equal(res, expect) { + t.Fatalf("want %v, got %v", expect, res) } } diff --git a/log/handler.go b/log/handler.go index 98bfac7..ba7a191 100644 --- a/log/handler.go +++ b/log/handler.go @@ -3,7 +3,6 @@ package log import ( "bytes" "context" - "fmt" "log" "log/slog" "strings" @@ -11,39 +10,41 @@ import ( "time" ) -var _ slog.Handler = new(defaultHandler) - +// defaultHandler combines a standard log.Logger with an slog.Handler for flexible logging. type defaultHandler struct { - *sync.Mutex - slog.Handler - *log.Logger - *bytes.Buffer + *sync.Mutex // Mutex for thread-safe buffer access. + *bytes.Buffer // Buffer for formatting log messages. + *log.Logger // Underlying standard logger for output. + slog.Handler // Structured logging handler. } -func newDefaultHandler(mu *sync.Mutex, logger *log.Logger, opts *slog.HandlerOptions) *defaultHandler { +var _ slog.Handler = new(defaultHandler) + +// newDefaultHandler creates a new defaultHandler with the specified logger and log level. +func newDefaultHandler(logger *log.Logger, level *slog.LevelVar) *defaultHandler { buf := new(bytes.Buffer) - return &defaultHandler{mu, slog.NewTextHandler(buf, opts), logger, buf} + return &defaultHandler{new(sync.Mutex), buf, logger, slog.NewTextHandler(buf, &slog.HandlerOptions{Level: level})} } +// Handle formats and outputs a log record using the slog.Handler and log.Logger. func (h *defaultHandler) Handle(ctx context.Context, r slog.Record) error { h.Lock() defer h.Unlock() - msg := fmt.Sprintf("%s %s", r.Level, r.Message) - r.Time, r.Message, r.Level = time.Time{}, "", 0 - h.Handler.Handle(ctx, r) - if log := strings.TrimSpace(strings.Replace(strings.Replace(h.String(), "level=INFO", "", 1), ` msg=""`, "", 1)); log == "" { - h.Print(msg) - } else { - h.Println(msg, log) + r.Time = time.Time{} + if err := h.Handler.Handle(ctx, r); err != nil { + return err } + h.Print(strings.TrimPrefix(h.String(), "level=")) h.Reset() return nil } +// WithAttrs returns a new handler with the specified attributes. func (h *defaultHandler) WithAttrs(attrs []slog.Attr) slog.Handler { - return &defaultHandler{h.Mutex, h.Handler.WithAttrs(attrs), h.Logger, h.Buffer} + return &defaultHandler{h.Mutex, h.Buffer, h.Logger, h.Handler.WithAttrs(attrs)} } +// WithGroup returns a new handler with the specified group name. func (h *defaultHandler) WithGroup(name string) slog.Handler { - return &defaultHandler{h.Mutex, h.Handler.WithGroup(name), h.Logger, h.Buffer} + return &defaultHandler{h.Mutex, h.Buffer, h.Logger, h.Handler.WithGroup(name)} } diff --git a/log/log.go b/log/log.go index 7b4cd54..9cf02f5 100644 --- a/log/log.go +++ b/log/log.go @@ -6,130 +6,210 @@ import ( "log" "log/slog" "os" + "sync/atomic" ) -var std = newLogger(log.Default(), os.Stderr) +// defaultLogger holds the default Logger instance, managed atomically for thread safety. +var defaultLogger atomic.Pointer[Logger] -func Default() *Logger { return std } +// init initializes the default logger with stderr output. +func init() { + defaultLogger.Store(newLogger(log.Default(), os.Stderr)) +} + +// Default returns the current default Logger instance. +func Default() *Logger { return defaultLogger.Load() } +// SetDefault sets the default Logger instance. +func SetDefault(l *Logger) { + defaultLogger.Store(l) +} + +// File returns the current log file path of the default Logger. func File() string { - return std.File() + return Default().File() } +// SetOutput sets the output destination for the default Logger. +// The file parameter specifies the log file path; if empty, no file output is used. +// The extra parameter allows an additional output destination (e.g., stderr). func SetOutput(file string, extra io.Writer) { - std.SetOutput(file, extra) + Default().SetOutput(file, extra) } +// SetFile sets the log file path for the default Logger, keeping the existing extra writer. func SetFile(file string) { - std.SetFile(file) + Default().SetFile(file) } +// SetExtra sets an additional output destination for the default Logger, keeping the existing file. func SetExtra(extra io.Writer) { - std.SetExtra(extra) + Default().SetExtra(extra) } +// Rotate reopens the log file and rotates the extra writer for the default Logger if applicable. func Rotate() { - std.Rotate() + Default().Rotate() } +// Flags returns the current log flags of the default Logger. func Flags() int { - return std.Flags() + return Default().Flags() } + +// SetFlags sets the log flags for the default Logger. func SetFlags(flag int) { - std.SetFlags(flag) + Default().SetFlags(flag) } + +// Prefix returns the current log prefix of the default Logger. func Prefix() string { - return std.Prefix() + return Default().Prefix() } + +// SetPrefix sets the log prefix for the default Logger. func SetPrefix(prefix string) { - std.SetPrefix(prefix) + Default().SetPrefix(prefix) } + +// Writer returns the current output writer of the default Logger. func Writer() io.Writer { - return std.Writer() + return Default().Writer() } + +// Print logs a message using the default Logger's Print method. func Print(v ...any) { - std.Print(v...) + Default().Print(v...) } + +// Printf logs a formatted message using the default Logger's Printf method. func Printf(format string, v ...any) { - std.Printf(format, v...) + Default().Printf(format, v...) } + +// Println logs a message with a newline using the default Logger's Println method. func Println(v ...any) { - std.Println(v...) + Default().Println(v...) } + +// Fatal logs a message and exits using the default Logger's Fatal method. func Fatal(v ...any) { - std.Fatal(v...) + Default().Fatal(v...) } + +// Fatalf logs a formatted message and exits using the default Logger's Fatalf method. func Fatalf(format string, v ...any) { - std.Fatalf(format, v...) + Default().Fatalf(format, v...) } + +// Fatalln logs a message with a newline and exits using the default Logger's Fatalln method. func Fatalln(v ...any) { - std.Fatalln(v...) + Default().Fatalln(v...) } + +// Panic logs a message and panics using the default Logger's Panic method. func Panic(v ...any) { - std.Panic(v...) + Default().Panic(v...) } + +// Panicf logs a formatted message and panics using the default Logger's Panicf method. func Panicf(format string, v ...any) { - std.Panicf(format, v...) + Default().Panicf(format, v...) } + +// Panicln logs a message with a newline and panics using the default Logger's Panicln method. func Panicln(v ...any) { - std.Panicln(v...) + Default().Panicln(v...) } + +// Output logs a message with the specified call depth using the default Logger's Output method. func Output(calldepth int, s string) error { - return std.Output(calldepth+1, s) // +1 for this frame. + return Default().Output(calldepth+1, s) // +1 for this frame. } +// SetHandler sets the slog handler for the default Logger. +// Note: The new handler may not respect the existing log level (obtained via [Level]), potentially disabling level control. +// Ensure the provided handler is configured with the desired log level if needed. func SetHandler(h slog.Handler) { - std.mu.Lock() - defer std.mu.Unlock() - std.slog = slog.New(h) + Default().slog.Store(slog.New(h)) } -func Level() *slog.LevelVar { - return std.level + +// Level returns the log level of the default Logger. +func Level() slog.Level { + return Default().Level() } + +// SetLevel sets the log level for the default Logger. func SetLevel(level slog.Level) { - std.level.Set(level) + Default().level.Set(level) } + +// Debug logs a message at Debug level using the default Logger. func Debug(msg string, args ...any) { - std.Debug(msg, args...) + Default().Debug(msg, args...) } + +// DebugContext logs a message at Debug level with context using the default Logger. func DebugContext(ctx context.Context, msg string, args ...any) { - std.DebugContext(ctx, msg, args...) + Default().DebugContext(ctx, msg, args...) } + +// Enabled checks if the specified log level is enabled for the default Logger. func Enabled(ctx context.Context, level slog.Level) bool { - return std.Enabled(ctx, level) + return Default().Enabled(ctx, level) } + +// Error logs a message at Error level using the default Logger. func Error(msg string, args ...any) { - std.Error(msg, args...) + Default().Error(msg, args...) } + +// ErrorContext logs a message at Error level with context using the default Logger. func ErrorContext(ctx context.Context, msg string, args ...any) { - std.ErrorContext(ctx, msg, args...) + Default().ErrorContext(ctx, msg, args...) } -func LoggerHandler() slog.Handler { - return std.LoggerHandler() + +// Handler returns the slog handler of the default Logger. +func Handler() slog.Handler { + return Default().Handler() } + +// Info logs a message at Info level using the default Logger. func Info(msg string, args ...any) { - std.Info(msg, args...) + Default().Info(msg, args...) } + +// InfoContext logs a message at Info level with context using the default Logger. func InfoContext(ctx context.Context, msg string, args ...any) { - std.InfoContext(ctx, msg, args...) + Default().InfoContext(ctx, msg, args...) } + +// Log logs a message at the specified level with context using the default Logger. func Log(ctx context.Context, level slog.Level, msg string, args ...any) { - std.Log(ctx, level, msg, args...) + Default().Log(ctx, level, msg, args...) } + +// LogAttrs logs a message at the specified level with attributes using the default Logger. func LogAttrs(ctx context.Context, level slog.Level, msg string, attrs ...slog.Attr) { - std.LogAttrs(ctx, level, msg, attrs...) + Default().LogAttrs(ctx, level, msg, attrs...) } + +// Warn logs a message at Warn level using the default Logger. func Warn(msg string, args ...any) { - std.Warn(msg, args...) + Default().Warn(msg, args...) } + +// WarnContext logs a message at Warn level with context using the default Logger. func WarnContext(ctx context.Context, msg string, args ...any) { - std.WarnContext(ctx, msg, args...) + Default().WarnContext(ctx, msg, args...) } + +// With returns a new Logger with the specified attributes, leaving the default Logger unchanged. func With(args ...any) *Logger { - std = std.With(args...) - return std + return Default().With(args...) } + +// WithGroup returns a new Logger with the specified group name, leaving the default Logger unchanged. func WithGroup(name string) *Logger { - std = std.WithGroup(name) - return std + return Default().WithGroup(name) } diff --git a/log/logger.go b/log/logger.go index f830b72..f9f4aef 100644 --- a/log/logger.go +++ b/log/logger.go @@ -6,57 +6,71 @@ import ( "log" "log/slog" "os" - "sync" + "sync/atomic" + + "github.com/sunshineplan/utils/container" ) +// Logger implements a custom logger that combines the standard log.Logger with slog.Logger, +// providing flexible output destinations and log level control. +type Logger struct { + *log.Logger // Underlying standard logger. + file atomic.Pointer[os.File] // File handle for log output, managed atomically. + extra container.Value[io.Writer] // Additional output destination (e.g., stderr). + slog atomic.Pointer[slog.Logger] // Structured logger for leveled logging. + level *slog.LevelVar // Log level controller. +} + var ( _ io.Writer = new(Logger) _ Rotatable = new(Logger) ) +// Constants for log flags, mirrored from the standard log package. const ( - Ldate = 1 << iota // the date in the local time zone: 2009/01/23 - Ltime // the time in the local time zone: 01:23:23 - Lmicroseconds // microsecond resolution: 01:23:23.123123. assumes Ltime. - Llongfile // full file name and line number: /a/b/c/d.go:23 - Lshortfile // final file name element and line number: d.go:23. overrides Llongfile - LUTC // if Ldate or Ltime is set, use UTC rather than the local time zone - Lmsgprefix // move the "prefix" from the beginning of the line to before the message - LstdFlags = Ldate | Ltime // initial values for the standard logger + Ldate = log.Ldate // Include date in log output (e.g., 2009/01/23). + Ltime = log.Ltime // Include time in log output (e.g., 01:23:23). + Lmicroseconds = log.Lmicroseconds // Include microsecond resolution in time (requires Ltime). + Llongfile = log.Llongfile // Include full file name and line number (e.g., /a/b/c/d.go:23). + Lshortfile = log.Lshortfile // Include final file name element and line number (e.g., d.go:23), overrides Llongfile. + LUTC = log.LUTC // Use UTC for date/time if Ldate or Ltime is set. + Lmsgprefix = log.Lmsgprefix // Move prefix to before the message. + LstdFlags = log.LstdFlags // Default flags: Ldate | Ltime. ) -type Logger struct { - *log.Logger - file *os.File - extra io.Writer - - slog *slog.Logger - - mu *sync.Mutex - level *slog.LevelVar -} - +// newLogger creates a new Logger instance with the specified standard logger and file handle. +// The logger is initialized with a default slog handler and log level. func newLogger(l *log.Logger, file *os.File) *Logger { - logger := &Logger{Logger: l, file: file, mu: new(sync.Mutex), level: new(slog.LevelVar)} - logger.slog = slog.New(newDefaultHandler(logger.mu, l, &slog.HandlerOptions{Level: logger.level})) + logger := &Logger{Logger: l, level: new(slog.LevelVar)} + logger.file.Store(file) + logger.slog.Store(slog.New(newDefaultHandler(l, logger.level))) return logger } +// New creates a new Logger instance with the specified file path, prefix, and flags. +// If file is empty, logs are discarded. Panics if the file cannot be opened. func New(file, prefix string, flag int) *Logger { if file == "" { return newLogger(log.New(io.Discard, prefix, flag), nil) } - f := openFile(file) + f, err := openFile(file) + if err != nil { + panic(err) + } return newLogger(log.New(f, prefix, flag), f) } +// File returns the current log file path, or an empty string if no file is set. func (l *Logger) File() string { - if l.file != nil { - return l.file.Name() + if file := l.file.Load(); file != nil { + return file.Name() } return "" } +// setOutput configures the logger's output destinations. +// It sets the file and extra writer, closing the old file if necessary. +// Logs are written to io.Discard if no outputs are provided. func (l *Logger) setOutput(file *os.File, extra io.Writer) { var writers []io.Writer if file != nil { @@ -69,105 +83,163 @@ func (l *Logger) setOutput(file *os.File, extra io.Writer) { writers = append(writers, io.Discard) } l.Logger.SetOutput(io.MultiWriter(writers...)) - if l.file != nil && l.file != file { - l.file.Close() + if oldFile := l.file.Load(); oldFile != nil && oldFile != file { + if err := oldFile.Close(); err != nil { + l.Error("failed to close log file", "error", err) + } + } + l.file.Store(file) + if extra != nil { + l.extra.Store(extra) } - l.file = file - l.extra = extra } -func (l *Logger) SetOutput(file string, extra io.Writer) { - l.mu.Lock() - defer l.mu.Unlock() - l.setOutput(openFile(file), extra) +// SetOutput sets the log output to the specified file path and extra writer. +// Returns an error if the file cannot be opened. +func (l *Logger) SetOutput(file string, extra io.Writer) error { + f, err := openFile(file) + if err != nil { + return err + } + l.setOutput(f, extra) + return nil } +// SetFile sets the log file to the specified path, keeping the existing extra writer. func (l *Logger) SetFile(file string) { - l.SetOutput(file, l.extra) + l.SetOutput(file, l.extra.Load()) } +// SetExtra sets an additional output destination (e.g., stderr), keeping the existing file. func (l *Logger) SetExtra(extra io.Writer) { - l.mu.Lock() - defer l.mu.Unlock() - l.setOutput(l.file, extra) + l.setOutput(l.file.Load(), extra) } +// SetHandler sets the slog handler for structured logging. +// Note: The new handler may not respect the existing log level (l.level), potentially disabling level control. +// Ensure the provided handler is configured with the desired log level if needed. func (l *Logger) SetHandler(h slog.Handler) { - l.mu.Lock() - defer l.mu.Unlock() - l.slog = slog.New(h) + l.slog.Store(slog.New(h)) } -func (l *Logger) Level() *slog.LevelVar { - return l.level + +// Level returns the current log level. +func (l *Logger) Level() slog.Level { + return l.level.Level() } + +// SetLevel sets the log level for structured logging. func (l *Logger) SetLevel(level slog.Level) { l.level.Set(level) } + +// Debug logs a message at Debug level with the given arguments. func (l *Logger) Debug(msg string, args ...any) { - l.slog.Debug(msg, args...) + l.slog.Load().Debug(msg, args...) } + +// DebugContext logs a message at Debug level with the given context and arguments. func (l *Logger) DebugContext(ctx context.Context, msg string, args ...any) { - l.slog.DebugContext(ctx, msg, args...) + l.slog.Load().DebugContext(ctx, msg, args...) } + +// Enabled checks if the specified log level is enabled for the logger. func (l *Logger) Enabled(ctx context.Context, level slog.Level) bool { - return l.slog.Enabled(ctx, level) + return l.slog.Load().Enabled(ctx, level) } + +// Error logs a message at Error level with the given arguments. func (l *Logger) Error(msg string, args ...any) { - l.slog.Error(msg, args...) + l.slog.Load().Error(msg, args...) } + +// ErrorContext logs a message at Error level with the given context and arguments. func (l *Logger) ErrorContext(ctx context.Context, msg string, args ...any) { - l.slog.ErrorContext(ctx, msg, args...) + l.slog.Load().ErrorContext(ctx, msg, args...) } -func (l *Logger) LoggerHandler() slog.Handler { - return l.slog.Handler() + +// Handler returns the current slog handler. +func (l *Logger) Handler() slog.Handler { + return l.slog.Load().Handler() } + +// Info logs a message at Info level with the given arguments. func (l *Logger) Info(msg string, args ...any) { - l.slog.Info(msg, args...) + l.slog.Load().Info(msg, args...) } + +// InfoContext logs a message at Info level with the given context and arguments. func (l *Logger) InfoContext(ctx context.Context, msg string, args ...any) { - l.slog.InfoContext(ctx, msg, args...) + l.slog.Load().InfoContext(ctx, msg, args...) } + +// Log logs a message at the specified level with the given context and arguments. func (l *Logger) Log(ctx context.Context, level slog.Level, msg string, args ...any) { - l.slog.Log(ctx, level, msg, args...) + l.slog.Load().Log(ctx, level, msg, args...) } + +// LogAttrs logs a message at the specified level with the given context and attributes. func (l *Logger) LogAttrs(ctx context.Context, level slog.Level, msg string, attrs ...slog.Attr) { - l.slog.LogAttrs(ctx, level, msg, attrs...) + l.slog.Load().LogAttrs(ctx, level, msg, attrs...) } + +// Warn logs a message at Warn level with the given arguments. func (l *Logger) Warn(msg string, args ...any) { - l.slog.Warn(msg, args...) + l.slog.Load().Warn(msg, args...) } + +// WarnContext logs a message at Warn level with the given context and arguments. func (l *Logger) WarnContext(ctx context.Context, msg string, args ...any) { - l.slog.WarnContext(ctx, msg, args...) + l.slog.Load().WarnContext(ctx, msg, args...) } + +// With returns a new Logger with the specified attributes, leaving the original unchanged. func (l *Logger) With(args ...any) *Logger { - l.slog = l.slog.With(args...) - return l + logger := &Logger{Logger: l.Logger, extra: l.extra, level: l.level} + logger.file.Store(l.file.Load()) + if extra := l.extra.Load(); extra != nil { + logger.extra.Store(extra) + } + logger.slog.Store(l.slog.Load().With(args...)) + return logger } + +// WithGroup returns a new Logger with the specified group name, leaving the original unchanged. func (l *Logger) WithGroup(name string) *Logger { - l.slog = l.slog.WithGroup(name) - return l + logger := &Logger{Logger: l.Logger, extra: l.extra, level: l.level} + logger.file.Store(l.file.Load()) + if extra := l.extra.Load(); extra != nil { + logger.extra.Store(extra) + } + logger.slog.Store(l.slog.Load().WithGroup(name)) + return logger } +// Rotate reopens the log file and rotates the extra writer if it implements Rotatable. func (l *Logger) Rotate() { - if i, ok := l.extra.(Rotatable); ok { - i.Rotate() + if extra := l.extra.Load(); extra != nil { + if i, ok := extra.(Rotatable); ok { + i.Rotate() + } } - if l.file != nil { - l.SetFile(l.file.Name()) + if file := l.file.Load(); file != nil { + l.SetFile(file.Name()) } } +// Write writes bytes to the logger's output destination, implementing io.Writer. func (l *Logger) Write(b []byte) (int, error) { return l.Writer().Write(b) } -func openFile(file string) *os.File { +// openFile opens a log file with the specified path in append mode. +// Returns nil if the path is empty, or an error if the file cannot be opened. +func openFile(file string) (*os.File, error) { if file != "" { f, err := os.OpenFile(file, os.O_APPEND|os.O_CREATE|os.O_WRONLY, 0640) if err != nil { - panic(err) + return nil, err } - return f + return f, nil } - return nil + return nil, nil } diff --git a/log/logger_test.go b/log/logger_test.go index 3422de9..275f919 100644 --- a/log/logger_test.go +++ b/log/logger_test.go @@ -2,10 +2,12 @@ package log import ( "bytes" + "fmt" "log" "log/slog" "os" "runtime" + "strings" "testing" ) @@ -51,7 +53,8 @@ func TestSLogger(t *testing.T) { t.Errorf("expected empty string; got %q", file) } l.Info("test") - if s, expected := buf.String(), "INFO test\n"; s != expected { + fmt.Print(buf.String()) + if s, expected := buf.String(), "INFO msg=test\n"; !strings.HasSuffix(s, expected) { t.Errorf("expected %q; got %q", expected, s) } buf.Reset() @@ -62,13 +65,15 @@ func TestSLogger(t *testing.T) { buf.Reset() l.SetLevel(slog.LevelDebug) l.Debug("test") - if s, expected := buf.String(), "DEBUG test\n"; s != expected { + fmt.Print(buf.String()) + if s, expected := buf.String(), "DEBUG msg=test\n"; !strings.HasSuffix(s, expected) { t.Errorf("expected %q; got %q", expected, s) } buf.Reset() l = l.WithGroup("g").With("a", 1) l.Info("test") - if s, expected := buf.String(), "INFO test g.a=1\n"; s != expected { + fmt.Print(buf.String()) + if s, expected := buf.String(), "INFO msg=test g.a=1\n"; !strings.HasSuffix(s, expected) { t.Errorf("expected %q; got %q", expected, s) } } diff --git a/log/rotate.go b/log/rotate.go index dc4fa6f..25c0bf9 100644 --- a/log/rotate.go +++ b/log/rotate.go @@ -1,20 +1,30 @@ package log import ( + "context" "os" "os/signal" ) +// Rotatable defines an interface for objects that support log rotation. type Rotatable interface { Rotate() } -func ListenRotateSignal(r Rotatable, sig ...os.Signal) { +// ListenRotateSignal listens for the specified signals and triggers rotation on the Rotatable object. +// It stops listening when the context is canceled. +func ListenRotateSignal(ctx context.Context, r Rotatable, sig ...os.Signal) { c := make(chan os.Signal, 1) signal.Notify(c, sig...) go func() { - for range c { - r.Rotate() + for { + select { + case <-ctx.Done(): + signal.Stop(c) + return + case <-c: + r.Rotate() + } } }() } diff --git a/mail/mail.go b/mail/mail.go index dbc6adc..0ceb1cb 100644 --- a/mail/mail.go +++ b/mail/mail.go @@ -20,6 +20,8 @@ type Dialer struct { Timeout time.Duration } +// Dial dials the SMTP server and performs optional STARTTLS / AUTH. +// It returns a connected smtp.Client. Caller should call client.Quit() when done. func (d *Dialer) Dial() (client *smtp.Client, err error) { if d.Timeout == 0 { d.Timeout = 3 * time.Minute @@ -37,6 +39,7 @@ func (d *Dialer) Dial() (client *smtp.Client, err error) { return } + // If connection is not TLS but server supports STARTTLS, upgrade. if !d.TLS { if ok, _ := client.Extension("STARTTLS"); ok { if err = client.StartTLS(nil); err != nil { @@ -45,6 +48,7 @@ func (d *Dialer) Dial() (client *smtp.Client, err error) { } } + // Authenticate if server advertises AUTH and credentials are provided. if ok, _ := client.Extension("AUTH"); ok && d.Account != "" && d.Password != "" { if err = client.Auth2(&smtp.Auth{Identity: "", Username: d.Account, Password: d.Password, Server: d.Server}); err != nil { client.Quit() @@ -55,15 +59,19 @@ func (d *Dialer) Dial() (client *smtp.Client, err error) { return client, nil } -// Send sends the given messages. +// Send sends the given messages using one established connection. +// It honors Dialer.Timeout for each message via context timeouts. +// The client connection will be Quit() when Send returns. func (d *Dialer) Send(msg ...*Message) error { client, err := d.Dial() if err != nil { return err } + // ensure we close the SMTP connection when the function returns defer client.Quit() for _, m := range msg { + // default From to the dialer's account if not set if m.From == nil { m.From = Receipt("", d.Account) } diff --git a/mail/message.go b/mail/message.go index 840677d..ffcf96e 100644 --- a/mail/message.go +++ b/mail/message.go @@ -7,7 +7,6 @@ import ( "crypto/rand" "encoding/base64" "fmt" - "math" "mime" "net/mail" "net/textproto" @@ -51,24 +50,27 @@ type Message struct { Attachments []*Attachment } +// RcptList returns a de-duplicated list of recipient email addresses while preserving order: +// first To, then Cc, then Bcc. func (m *Message) RcptList() (rcpts []string) { - list := make(map[string]struct{}) - for _, to := range m.To { - list[to.Address] = struct{}{} - } - for _, cc := range m.Cc { - list[cc.Address] = struct{}{} - } - for _, bcc := range m.Bcc { - list[bcc.Address] = struct{}{} - } - for k := range list { - rcpts = append(rcpts, k) + seen := make(map[string]bool) + for _, list := range [][]*mail.Address{m.To, m.Cc, m.Bcc} { + for _, addr := range list { + if !seen[addr.Address] { + rcpts = append(rcpts, addr.Address) + seen[addr.Address] = true + } + } } return } +// Bytes renders the RFC822-style message bytes for the message. +// id is used to create the Message-ID domain if provided; if empty, fallback to hostname. +// The produced message uses CRLF line endings as required by SMTP and includes correct +// MIME headers for single-part or multipart/mixed with attachments. func (m *Message) Bytes(id string) []byte { + // determine hostname part for Message-ID if id == "" { if m.From != nil && m.From.Address != "" { id = m.From.Address @@ -85,12 +87,18 @@ func (m *Message) Bytes(id string) []byte { var buf bytes.Buffer w := textproto.NewWriter(bufio.NewWriter(&buf)) + // Basic headers w.PrintfLine("MIME-Version: 1.0") w.PrintfLine("Date: %s", time.Now().Format(time.RFC1123Z)) w.PrintfLine("Message-ID: <%s>", id) - w.PrintfLine("Subject: =?UTF-8?B?%s?=", toBase64(m.Subject)) - w.PrintfLine("From: %s", m.From) - w.PrintfLine("To: %s", m.To) + w.PrintfLine("Subject: %s", encodeHeader(m.Subject)) + if m.From != nil { + w.PrintfLine("From: %s", m.From.String()) + } + // To / Cc headers (these are header fields visible in the message) + if len(m.To) > 0 { + w.PrintfLine("To: %s", m.To) + } if len(m.Cc) > 0 { w.PrintfLine("Cc: %s", m.Cc) } @@ -105,7 +113,7 @@ func (m *Message) Bytes(id string) []byte { w.PrintfLine(`Content-Type: %s; charset="UTF-8"`, m.ContentType) w.PrintfLine("Content-Transfer-Encoding: base64") w.PrintfLine("") - w.PrintfLine("%s", toBase64(m.Body)) + writeBase64BytesLines(w, []byte(m.Body)) if l := len(m.Attachments); l > 0 { for i, attachment := range m.Attachments { @@ -116,25 +124,14 @@ func (m *Message) Bytes(id string) []byte { w.PrintfLine("Content-Type: application/octet-stream") } if attachment.ContentID != "" { - w.PrintfLine(`Content-Disposition: inline; filename="=?UTF-8?B?%s?="`, toBase64(attachment.Filename)) + w.PrintfLine(`Content-Disposition: inline; filename="%s"`, encodeHeader(attachment.Filename)) w.PrintfLine("Content-ID: <%s>", attachment.ContentID) } else { - w.PrintfLine(`Content-Disposition: attachment; filename="=?UTF-8?B?%s?="`, toBase64(attachment.Filename)) + w.PrintfLine(`Content-Disposition: attachment; filename="%s"`, encodeHeader(attachment.Filename)) } w.PrintfLine("Content-Transfer-Encoding: base64") w.PrintfLine("") - - b := make([]byte, base64.StdEncoding.EncodedLen(len(attachment.Bytes))) - base64.StdEncoding.Encode(b, attachment.Bytes) - - // write base64 content in lines of up to 76 chars - for i, l := 0, int(math.Ceil(float64(len(b))/76)); i < l; i++ { - if i == l-1 { - w.PrintfLine("%s", b[i*76:]) - } else { - w.PrintfLine("%s", b[i*76:(i+1)*76]) - } - } + writeBase64BytesLines(w, attachment.Bytes) if i < l-1 { w.PrintfLine("--%s", boundary) @@ -147,10 +144,28 @@ func (m *Message) Bytes(id string) []byte { return buf.Bytes() } -func toBase64(str string) string { - return base64.StdEncoding.EncodeToString([]byte(str)) +// encodeHeader encodes a header value using RFC2047 only when non-ASCII chars are present. +// For pure ASCII strings it returns the original string unmodified. +func encodeHeader(s string) string { + for _, r := range s { + if r > 127 { + return fmt.Sprintf("=?UTF-8?B?%s?=", base64.StdEncoding.EncodeToString([]byte(s))) + } + } + return s +} + +// writeBase64BytesLines encodes bytes to base64 and writes it in lines of up to 76 chars. +func writeBase64BytesLines(w *textproto.Writer, b []byte) { + enc := base64.StdEncoding.EncodeToString(b) + // number of lines + for i := 0; i < len(enc); i += 76 { + end := min(i+76, len(enc)) + w.PrintfLine("%s", enc[i:end]) + } } +// randomString returns a hex string of length 2*n (because each byte => two hex chars). func randomString(n int) string { b := make([]byte, n) if _, err := rand.Read(b); err != nil { @@ -159,6 +174,8 @@ func randomString(n int) string { return fmt.Sprintf("%x", b) } +// generateMsgID generates a message id using the provided reference (domain or email). +// If ref has an @, use the right-hand side as domain, otherwise use ref as domain. func generateMsgID(ref string) string { s := strings.Split(ref, "@") return fmt.Sprintf("%s@%s", randomString(16), s[len(s)-1]) diff --git a/mail/receipt.go b/mail/receipt.go index 132998c..2cc3bde 100644 --- a/mail/receipt.go +++ b/mail/receipt.go @@ -5,23 +5,25 @@ import ( "encoding/json" "fmt" "net/mail" - "strconv" "strings" ) var ( - _ encoding.TextUnmarshaler = (*Receipts)(nil) + _ encoding.TextUnmarshaler = new(Receipts) _ encoding.TextMarshaler = Receipts{} - _ json.Unmarshaler = (*Receipts)(nil) + _ json.Unmarshaler = new(Receipts) _ json.Marshaler = Receipts{} ) +// Receipt creates a mail.Address pointer from name and address. func Receipt(name, address string) *mail.Address { return &mail.Address{Name: name, Address: address} } +// Receipts represents a list of mail addresses. type Receipts []*mail.Address +// ParseReceipts parses a comma/semicolon separated list of addresses into Receipts. func ParseReceipts(rcpts string) (Receipts, error) { addresses, err := mail.ParseAddressList(rcpts) if err != nil { @@ -30,16 +32,24 @@ func ParseReceipts(rcpts string) (Receipts, error) { return Receipts(addresses), nil } +// List returns a slice of address strings (just the email addresses). func (rcpts Receipts) List() []string { - var s []string - for _, rcpt := range rcpts { - s = append(s, rcpt.Address) + s := make([]string, len(rcpts)) + for i, rcpt := range rcpts { + s[i] = rcpt.Address } return s } +// String returns the addresses joined in the standard RFC 5322 way +// using the String method of mail.Address. func (rcpts Receipts) String() string { + if len(rcpts) == 0 { + return "" + } var b strings.Builder + // approximate average length to avoid repeated allocations + b.Grow(len(rcpts) * 32) for i, rcpt := range rcpts { if i != 0 { b.WriteString(", ") @@ -49,6 +59,7 @@ func (rcpts Receipts) String() string { return b.String() } +// UnmarshalText implements encoding.TextUnmarshaler func (rcpts *Receipts) UnmarshalText(text []byte) error { addresses, err := ParseReceipts(string(text)) if err != nil { @@ -58,16 +69,18 @@ func (rcpts *Receipts) UnmarshalText(text []byte) error { return nil } +// UnmarshalJSON supports either a single string (comma separated) or an array of strings. func (rcpts *Receipts) UnmarshalJSON(b []byte) error { - if unquote, err := strconv.Unquote(string(b)); err == nil { - return rcpts.UnmarshalText([]byte(unquote)) + var s string + if err := json.Unmarshal(b, &s); err == nil { + return rcpts.UnmarshalText([]byte(s)) } - var s []string - if err := json.Unmarshal(b, &s); err != nil { - return nil + var list []string + if err := json.Unmarshal(b, &list); err != nil { + return err } var addresses []*mail.Address - for _, i := range s { + for _, i := range list { address, err := mail.ParseAddress(i) if err != nil { return err @@ -78,10 +91,13 @@ func (rcpts *Receipts) UnmarshalJSON(b []byte) error { return nil } +// MarshalText implements encoding.TextMarshaler and returns the RFC5322 representation. func (rcpts Receipts) MarshalText() ([]byte, error) { return []byte(rcpts.String()), nil } +// MarshalJSON implements json.Marshaler. +// It encodes the receipts as a JSON array of email address strings. func (rcpts Receipts) MarshalJSON() ([]byte, error) { - return []byte(rcpts.String()), nil + return json.Marshal(rcpts.String()) } diff --git a/ocr/ocr.go b/ocr/ocr.go index 86ab53f..81bc018 100644 --- a/ocr/ocr.go +++ b/ocr/ocr.go @@ -4,6 +4,7 @@ import ( "bytes" "encoding/json" "errors" + "fmt" "io" "mime/multipart" "net/http" @@ -16,6 +17,12 @@ const ( var errNoResult = errors.New("no ocr result") +type ocrResponse struct { + ParsedResults []struct { + ParsedText string + } +} + // OCR reads image from reader r and converts it into string. func OCR(r io.Reader) (string, error) { return OCRWithClient(r, http.DefaultClient) @@ -49,16 +56,16 @@ func OCRWithClient(r io.Reader, client *http.Client) (string, error) { } defer resp.Body.Close() + if resp.StatusCode != http.StatusOK { + return "", fmt.Errorf("ocr request failed: %s", resp.Status) + } + b, err := io.ReadAll(resp.Body) if err != nil { return "", err } - var res struct { - ParsedResults []struct { - ParsedText string - } - } + var res ocrResponse if err := json.Unmarshal(b, &res); err != nil { return "", err } diff --git a/pool/pool.go b/pool/pool.go index c202780..057c3ce 100644 --- a/pool/pool.go +++ b/pool/pool.go @@ -2,22 +2,42 @@ package pool import "sync" +// Pool is a type-safe wrapper around [sync.Pool] that provides +// a generic, concurrency-safe object pool for any type T. type Pool[T any] struct { - p *sync.Pool + pool sync.Pool + // New optionally specifies a function to generate + // a value when Get would otherwise return nil. + // It may not be changed concurrently with calls to Get. + New func() *T } +// New creates a new Pool that allocates new zero-valued objects +// of type T when none are available. func New[T any]() *Pool[T] { - return NewFunc(func() *T { return new(T) }) -} - -func NewFunc[T any](fn func() *T) *Pool[T] { - return &Pool[T]{&sync.Pool{New: func() any { return fn() }}} + return &Pool[T]{New: func() *T { return new(T) }} } +// Get selects an arbitrary item from the [Pool], removes it from the +// Pool, and returns it to the caller. +// Get may choose to ignore the pool and treat it as empty. +// Callers should not assume any relation between values passed to [Pool.Put] and +// the values returned by Get. +// +// If Get would otherwise return nil and p.New is non-nil, Get returns +// the result of calling p.New. func (p *Pool[T]) Get() *T { - return p.p.Get().(*T) + x := p.pool.Get() + if x == nil { + if p.New != nil { + return p.New() + } + return nil + } + return x.(*T) } +// Put adds x to the pool. func (p *Pool[T]) Put(x *T) { - p.p.Put(x) + p.pool.Put(x) } diff --git a/pop3/pop3.go b/pop3/pop3.go index a799e27..c9f3a13 100644 --- a/pop3/pop3.go +++ b/pop3/pop3.go @@ -5,6 +5,7 @@ import ( "crypto/tls" "errors" "fmt" + "io" "log/slog" "net" "net/textproto" @@ -12,6 +13,8 @@ import ( "strings" ) +// Client represents a POP3 client connection. +// It embeds a textproto.Conn for low-level protocol communication. type Client struct { *textproto.Conn } @@ -26,6 +29,8 @@ const ( respContinue = "+ " ) +// Dial establishes a plain TCP connection to the POP3 server at the given address. +// The connection respects the provided context for timeout or cancellation. func Dial(ctx context.Context, addr string) (*Client, error) { var d net.Dialer conn, err := d.DialContext(ctx, "tcp", addr) @@ -35,6 +40,8 @@ func Dial(ctx context.Context, addr string) (*Client, error) { return NewClient(conn) } +// DialTLS establishes a secure POP3-over-TLS connection to the given address. +// The server name is automatically derived from the address for certificate verification. func DialTLS(ctx context.Context, addr string) (*Client, error) { host, _, err := net.SplitHostPort(addr) if err != nil { @@ -49,6 +56,8 @@ func DialTLS(ctx context.Context, addr string) (*Client, error) { return NewClient(conn) } +// NewClient initializes a POP3 client from an existing connection. +// It reads the server greeting line and validates that it starts with "+OK". func NewClient(conn net.Conn) (*Client, error) { c := &Client{textproto.NewConn(conn)} s, err := c.ReadLine() @@ -65,17 +74,26 @@ func NewClient(conn net.Conn) (*Client, error) { return c, nil } -// Stat returns the number of messages and their total size in bytes in the inbox. +// Auth authenticates the user using the USER/PASS commands. +// Returns an error if either step fails. +func (c *Client) Auth(user, pass string) error { + if _, err := c.Cmd("USER %s", false, user); err != nil { + return err + } + _, err := c.Cmd("PASS %s", false, pass) + return err +} + +// Stat returns the number of messages and the total mailbox size (in bytes). func (c *Client) Stat() (count int, size int, err error) { s, err := c.Cmd("STAT", false) if err != nil { return } - - // count size f := strings.Fields(s) - - // Total number of messages. + if len(f) < 2 { + return 0, 0, fmt.Errorf("invalid STAT response: %q", s) + } count, err = strconv.Atoi(f[0]) if err != nil { return @@ -83,26 +101,21 @@ func (c *Client) Stat() (count int, size int, err error) { if count == 0 { return } - - // Total size of all messages in bytes. size, err = strconv.Atoi(f[1]) - return } -// MessageID contains the ID and size of an individual message. +// MessageID represents a single message entry as returned by LIST or UIDL. +// It includes the message index, size, and optional UID. type MessageID struct { - // ID is the numerical index (non-unique) of the message. - ID int - Size int - - // UID is only present if the response is to the UIDL command. - UID string + ID int // Numerical message index (1-based) + Size int // Message size in bytes + UID string // Optional UID (only for UIDL command) } -// List returns a list of (message ID, message Size) pairs. -// If the optional id > 0, then only that particular message is listed. -// The message IDs are sequential, 1 to N. +// List returns message IDs and sizes from the mailbox. +// If id > 0, only that specific message is listed (single-line response). +// If id == 0, all messages are listed (multi-line response). func (c *Client) List(id int) ([]MessageID, error) { var ( s string @@ -110,47 +123,36 @@ func (c *Client) List(id int) ([]MessageID, error) { ) if id > 0 { - // Single line response listing one message. s, err = c.Cmd("LIST %d", false, id) } else { - // Multiline response listing all messages. s, err = c.Cmd("LIST", true) } if err != nil { return nil, err } - var ( - out []MessageID - lines = strings.Split(s, lineBreak) - ) - - for _, l := range lines { - // id size + var out []MessageID + for l := range strings.SplitSeq(s, lineBreak) { f := strings.Fields(l) if len(f) == 0 { - break + continue } - id, err := strconv.Atoi(f[0]) if err != nil { return nil, err } - size, err := strconv.Atoi(f[1]) if err != nil { return nil, err } - out = append(out, MessageID{ID: id, Size: size}) } - return out, nil } -// Uidl returns a list of (message ID, message UID) pairs. If the optional msgID -// is > 0, then only that particular message is listed. It works like Top() but only works on -// servers that support the UIDL command. Messages size field is not available in the UIDL response. +// Uidl returns message IDs and their unique identifiers (UIDs). +// If id > 0, only that specific message is listed. +// The UIDL command may not be supported by all servers. func (c *Client) Uidl(id int) ([]MessageID, error) { var ( s string @@ -158,52 +160,41 @@ func (c *Client) Uidl(id int) ([]MessageID, error) { ) if id > 0 { - // Single line response listing one message. s, err = c.Cmd("UIDL %d", false, id) } else { - // Multiline response listing all messages. s, err = c.Cmd("UIDL", true) } if err != nil { return nil, err } - var ( - out []MessageID - lines = strings.Split(s, lineBreak) - ) - - for _, l := range lines { - // id uid + var out []MessageID + for l := range strings.SplitSeq(s, lineBreak) { f := strings.Fields(l) if len(f) == 0 { - break + continue } - id, err := strconv.Atoi(f[0]) if err != nil { return nil, err } - out = append(out, MessageID{ID: id, UID: f[1]}) } - return out, nil } -// Retr downloads a message by the given id and returns the data -// of the entire message. +// Retr retrieves the full message text by ID, including headers and body. func (c *Client) Retr(id int) (string, error) { return c.Cmd("RETR %d", true, id) } -// Top retrieves a message by its ID with full headers and numLines lines of the body. +// Top retrieves message headers and the first numLines of the body. func (c *Client) Top(id int, numLines int) (string, error) { return c.Cmd("TOP %d %d", true, id, numLines) } -// Dele deletes one or more messages. The server only executes the -// deletions after a successful Quit(). +// Dele marks one or more messages for deletion. +// The deletions are finalized only after a successful Quit(). func (c *Client) Dele(ids ...int) error { for _, id := range ids { if _, err := c.Cmd("DELE %d", false, id); err != nil { @@ -213,29 +204,32 @@ func (c *Client) Dele(ids ...int) error { return nil } -// Rset clears the messages marked for deletion in the current session. +// Rset resets the deletion marks on all messages in the current session. func (c *Client) Rset() error { _, err := c.Cmd("RSET", false) return err } -// Noop issues a do-nothing NOOP command to the server. This is useful for -// prolonging open connections. +// Noop sends a NOOP command to keep the connection alive. +// Useful for preventing idle timeouts. func (c *Client) Noop() error { _, err := c.Cmd("NOOP", false) return err } -// Quit sends the QUIT command to server and gracefully closes the connection. -// Message deletions (DELE command) are only excuted by the server on a graceful -// quit and close. +// Quit sends the QUIT command and closes the connection gracefully. +// Deletions are committed only if QUIT succeeds. func (c *Client) Quit() error { if _, err := c.Cmd("QUIT", false); err != nil { + c.Close() return err } return c.Close() } +// Cmd sends a POP3 command with optional arguments and reads the response. +// If isMulti is true, the response is treated as multi-line and read until +// a line containing only "." is encountered. All lines are concatenated and returned. func (c *Client) Cmd(s string, isMulti bool, args ...any) (string, error) { slog.Debug(">>> " + fmt.Sprintf(s, args...)) if _, err := c.Conn.Cmd(s, args...); err != nil { @@ -256,39 +250,51 @@ func (c *Client) Cmd(s string, isMulti bool, args ...any) (string, error) { return s, nil } - var res string + var b strings.Builder for { s, err := c.ReadLine() if err != nil { + if err == io.EOF { + err = io.ErrUnexpectedEOF + } return "", err } slog.Debug("<<< " + s) - if s == "." { - break - } - res += s + lineBreak + // A single dot line marks the end of multi-line response. + // Lines beginning with a dot have one dot removed as per POP3 spec. + if len(s) > 0 && s[0] == '.' { + if len(s) == 1 { + break + } + s = s[1:] + } + b.WriteString(s) + b.WriteString(lineBreak) } - - return res, nil + return b.String(), nil } +// parseResp interprets a single-line POP3 response. +// It distinguishes between +OK, -ERR, and continuation ("+ ") responses, +// returning the response message or an error if the response is invalid. func parseResp(s string) (string, error) { - if len(s) == 0 { - return "", nil - } - - if s == respOK { + switch s { + case "", respOK: return "", nil - } else if strings.HasPrefix(s, respOKInfo) { - return strings.TrimPrefix(s, respOKInfo), nil - } else if s == respErr { - return "", errors.New("unknown error (no info specified in response)") - } else if strings.HasPrefix(s, respErrInfo) { - return "", errors.New(strings.TrimPrefix(s, respErrInfo)) - } else if strings.HasPrefix(s, respContinue) { - return strings.TrimPrefix(s, respContinue), nil - } else { - return "", fmt.Errorf("unknown response: %q", s) + case respErr: + return "", errors.New("server returned -ERR without info") + default: + switch { + case strings.HasPrefix(s, respOKInfo): + return strings.TrimPrefix(s, respOKInfo), nil + case strings.HasPrefix(s, respErrInfo): + return "", errors.New(strings.TrimPrefix(s, respErrInfo)) + case strings.HasPrefix(s, respContinue): + // Some servers send "+ " for continuation prompts (rare in simple POP3). + return strings.TrimPrefix(s, respContinue), nil + default: + return "", fmt.Errorf("unknown response: %q", s) + } } } diff --git a/processing/text/processor.go b/processing/text/processor.go index 0b43d0a..f56f9dd 100644 --- a/processing/text/processor.go +++ b/processing/text/processor.go @@ -1,64 +1,169 @@ package text import ( + "fmt" "regexp" "strings" ) +// Processor defines a generic interface for text processors. +// Each processor can optionally be executed only once (Once == true) +// and provides a human-readable description for debugging or logging. type Processor interface { + // Describe returns a short description of this processor. + Describe() string + // Once reports whether this processor should run only once. Once() bool + // Process performs the actual text transformation. Process(string) (string, error) } var ( - _ Processor = processor{} - _ Processor = RemoveByRegexp{} - _ Processor = Cut{} - _ Processor = Trim{} + _ Processor = new(processor) + _ Processor = new(multiProcessor) + _ Processor = RegexpRemover{} + _ Processor = Cutter{} + _ Processor = Trimmer{} ) +// processor is a generic implementation of Processor, +// allowing you to wrap any custom text function. type processor struct { + desc string once bool fn func(string) (string, error) } -func NewProcessor(once bool, fn func(string) (string, error)) Processor { - return processor{once, fn} +// NewProcessor creates a new Processor from a function. +// +// desc - short description for debugging +// once - whether this processor should be executed only once +// fn - transformation function taking a string and returning a string/error +func NewProcessor(desc string, once bool, fn func(string) (string, error)) Processor { + return &processor{desc, once, fn} } -func WrapFunc(fn func(string) string) func(string) (string, error) { - return func(s string) (string, error) { return fn(s), nil } -} +// Describe returns the processor's description string. +func (p *processor) Describe() string { return p.desc } + +// Once reports whether this processor should run only once. +func (p *processor) Once() bool { return p.once } -func (p processor) Once() bool { return p.once } -func (p processor) Process(s string) (string, error) { +// Process executes the wrapped function to transform the text. +func (p *processor) Process(s string) (string, error) { return p.fn(s) } -type RemoveByRegexp struct { - *regexp.Regexp +// multiProcessor executes multiple sub-processors as a single Processor. +type multiProcessor struct { + desc string + once bool + processors []Processor +} + +// NewMultiProcessor creates a new MultiProcessor. +// +// desc - human-readable name +// once - whether this processor should execute only once +// procs - list of sub-processors +func NewMultiProcessor(desc string, once bool, procs ...Processor) Processor { + return &multiProcessor{desc, once, procs} +} + +// Describe returns the description for debugging or logging. +func (m *multiProcessor) Describe() string { return m.desc } + +// Once reports whether the MultiProcessor should run only once. +func (m *multiProcessor) Once() bool { return m.once } + +// Process executes all sub-processors. +func (m *multiProcessor) Process(s string) (string, error) { + for _, p := range m.processors { + var err error + s, err = p.Process(s) + if err != nil { + return "", err + } + } + return s, nil } -func (RemoveByRegexp) Once() bool { return false } -func (p RemoveByRegexp) Process(s string) (string, error) { - return p.ReplaceAllString(s, ""), nil +// RegexpRemover removes substrings that match the given regular expression. +type RegexpRemover struct { + Re *regexp.Regexp +} + +// Describe returns a string representation of the RegexpRemover. +func (p RegexpRemover) Describe() string { return fmt.Sprintf("RegexpRemover(%q)", p.Re.String()) } + +// Once always returns false, meaning this processor can be applied repeatedly. +func (RegexpRemover) Once() bool { return false } + +// Process removes all matches of the regular expression from the input string. +func (p RegexpRemover) Process(s string) (string, error) { + return p.Re.ReplaceAllString(s, ""), nil } -type Cut struct { +// Cutter splits the input by the given separator and keeps only the part before it. +type Cutter struct { Sep string } -func (Cut) Once() bool { return true } -func (p Cut) Process(s string) (string, error) { +// Describe returns a string representation of the Cutter. +func (p Cutter) Describe() string { return fmt.Sprintf("Cutter(%q)", p.Sep) } + +// Once always returns true, meaning this processor should be run only once. +func (Cutter) Once() bool { return true } + +// Process cuts the string at the first occurrence of the separator and returns the left part. +func (p Cutter) Process(s string) (string, error) { before, _, _ := strings.Cut(s, p.Sep) return before, nil } -type Trim struct { +// Trimmer removes all leading and trailing characters from the given cutset. +type Trimmer struct { Cutset string } -func (Trim) Once() bool { return false } -func (p Trim) Process(s string) (string, error) { +// Describe returns a string representation of the Trimmer. +func (p Trimmer) Describe() string { return fmt.Sprintf("Trimmer(%q)", p.Cutset) } + +// Once always returns false, meaning this processor can be applied repeatedly. +func (Trimmer) Once() bool { return false } + +// Process trims all leading and trailing characters in Cutset from the input string. +func (p Trimmer) Process(s string) (string, error) { return strings.Trim(s, p.Cutset), nil } + +// TrimSpace returns a processor that removes leading and trailing spaces. +func TrimSpace() Processor { + return NewProcessor("TrimSpace", false, WrapFunc(strings.TrimSpace)) +} + +// CutSpace returns a processor that extracts the first word in the input string. +func CutSpace() Processor { + return NewProcessor("CutSpace", true, func(s string) (string, error) { + if fs := strings.Fields(s); len(fs) > 0 { + return fs[0], nil + } + return "", nil + }) +} + +// RemoveParentheses returns a processor that remove both western and full-width parentheses. +func RemoveParentheses() Processor { + return NewMultiProcessor( + "RemoveParentheses", + true, + RegexpRemover{regexp.MustCompile(`\([^\)]*\)`)}, + RegexpRemover{regexp.MustCompile(`([^)]*)`)}, + ) +} + +// WrapFunc wraps a simple string -> string function +// into a function matching the Processor signature. +func WrapFunc(fn func(string) string) func(string) (string, error) { + return func(s string) (string, error) { return fn(s), nil } +} diff --git a/processing/text/processor_test.go b/processing/text/processor_test.go index 0ffa220..1c87ac4 100644 --- a/processing/text/processor_test.go +++ b/processing/text/processor_test.go @@ -5,7 +5,7 @@ import ( "testing" ) -func TestRemoveByRegexp(t *testing.T) { +func TestRegexpRemover(t *testing.T) { for i, testcase := range []struct { re *regexp.Regexp s string @@ -15,7 +15,7 @@ func TestRemoveByRegexp(t *testing.T) { {regexp.MustCompile(`\d+`), "abc123", "abc"}, {regexp.MustCompile(`\d+$`), "123abc456", "123abc"}, } { - if res, err := NewTasks().Append(RemoveByRegexp{testcase.re}).Process(testcase.s); err != nil { + if res, err := NewTasks().Append(RegexpRemover{testcase.re}).Process(testcase.s); err != nil { t.Error(err) } else if res != testcase.expected { t.Errorf("#%d: got %q; want %q", i, res, testcase.expected) @@ -23,7 +23,7 @@ func TestRemoveByRegexp(t *testing.T) { } } -func TestCut(t *testing.T) { +func TestCutter(t *testing.T) { for i, testcase := range []struct { seq string s string @@ -34,7 +34,7 @@ func TestCut(t *testing.T) { {" ", " abc 123", ""}, {"abc", "123abc456", "123"}, } { - if res, err := NewTasks().Append(Cut{testcase.seq}).Process(testcase.s); err != nil { + if res, err := NewTasks().Append(Cutter{testcase.seq}).Process(testcase.s); err != nil { t.Error(err) } else if res != testcase.expected { t.Errorf("#%d: got %q; want %q", i, res, testcase.expected) @@ -42,7 +42,7 @@ func TestCut(t *testing.T) { } } -func TestTrim(t *testing.T) { +func TestTrimmer(t *testing.T) { for i, testcase := range []struct { cutset string s string @@ -53,7 +53,7 @@ func TestTrim(t *testing.T) { {" ", " abc 123\n", "abc 123\n"}, {" \n", " abc 123\n", "abc 123"}, } { - if res, err := NewTasks().Append(Trim{testcase.cutset}).Process(testcase.s); err != nil { + if res, err := NewTasks().Append(Trimmer{testcase.cutset}).Process(testcase.s); err != nil { t.Error(err) } else if res != testcase.expected { t.Errorf("#%d: got %q; want %q", i, res, testcase.expected) diff --git a/processing/text/task.go b/processing/text/task.go index b6241d6..278e77e 100644 --- a/processing/text/task.go +++ b/processing/text/task.go @@ -1,73 +1,64 @@ package text import ( - "regexp" - "strings" + "fmt" ) +var MaxIter = 100 + +// Tasks represents an ordered list of text processors. +// Each processor in the list will be executed sequentially, +// and repeated until the text no longer changes. type Tasks struct { tasks []Processor } +// NewTasks creates a new Tasks instance with the provided processors. func NewTasks(tasks ...Processor) *Tasks { return &Tasks{tasks} } -func (t *Tasks) Process(s string) (string, error) { - var err error - for first, output := true, s; ; { +// Process executes all configured processors on a single string. +func (t *Tasks) Process(s string) (output string, err error) { + output = s + first := true + for range MaxIter { for _, task := range t.tasks { if first || !task.Once() { s, err = task.Process(s) if err != nil { - return "", err + return "", fmt.Errorf("%s error: %w", task.Describe(), err) } } } if s == output { - return output, nil - } else { - output = s + return } + output = s if first { first = false } } + err = fmt.Errorf("max iteration limit reached") + return } -func (t *Tasks) ProcessAll(s []string) ([]string, error) { - var output []string - for _, i := range s { - s, err := t.Process(i) +// ProcessAll applies all processors to a slice of strings, returning the processed results. +// If any processor returns an error, processing stops and the error is returned. +func (t *Tasks) ProcessAll(inputs []string) ([]string, error) { + output := make([]string, len(inputs)) + for i, in := range inputs { + out, err := t.Process(in) if err != nil { return nil, err } - output = append(output, s) + output[i] = out } return output, nil } +// Append adds one or more processors to the task list and returns the updated instance. func (t *Tasks) Append(tasks ...Processor) *Tasks { t.tasks = append(t.tasks, tasks...) return t } - -func (t *Tasks) TrimSpace() *Tasks { - return t.Append(NewProcessor(false, WrapFunc(strings.TrimSpace))) -} - -func (t *Tasks) CutSpace() *Tasks { - return t.Append(NewProcessor(true, func(s string) (string, error) { - if fs := strings.Fields(s); len(fs) > 0 { - return fs[0], nil - } - return "", nil - })) -} - -func (t *Tasks) RemoveParentheses() *Tasks { - return t.Append( - RemoveByRegexp{regexp.MustCompile(`\([^)]*\)`)}, - RemoveByRegexp{regexp.MustCompile(`([^)]*)`)}, - ) -} diff --git a/processing/text/task_test.go b/processing/text/task_test.go index 79e6585..976c0ff 100644 --- a/processing/text/task_test.go +++ b/processing/text/task_test.go @@ -3,7 +3,7 @@ package text import "testing" func TestTrimSpace(t *testing.T) { - task := NewTasks().TrimSpace() + task := NewTasks(TrimSpace()) for i, testcase := range []struct { s string expected string @@ -22,7 +22,7 @@ func TestTrimSpace(t *testing.T) { } func TestCutSpace(t *testing.T) { - task := NewTasks().CutSpace() + task := NewTasks(CutSpace()) for i, testcase := range []struct { s string expected string @@ -41,7 +41,7 @@ func TestCutSpace(t *testing.T) { } func TestRemoveParentheses(t *testing.T) { - task := NewTasks().RemoveParentheses() + task := NewTasks(RemoveParentheses()) for i, testcase := range []struct { s string expected string @@ -62,7 +62,7 @@ func TestRemoveParentheses(t *testing.T) { } func TestTasks(t *testing.T) { - task := NewTasks().TrimSpace().RemoveParentheses().CutSpace() + task := NewTasks(TrimSpace(), RemoveParentheses(), CutSpace()) for i, testcase := range []struct { s string expected string diff --git a/progressbar/counter.go b/progressbar/counter.go new file mode 100644 index 0000000..f17675c --- /dev/null +++ b/progressbar/counter.go @@ -0,0 +1,39 @@ +package progressbar + +import ( + "io" + + "github.com/sunshineplan/utils/counter" +) + +type genericCounter interface { + Add(int64) int64 + Write([]byte) (int, error) + Get() int64 +} + +var ( + _ genericCounter = new(numberCounter) + _ genericCounter = new(writerCounter) +) + +func newNumberCounter() genericCounter { + return new(numberCounter) +} + +func newWriterCounter(w io.Writer) genericCounter { + return &writerCounter{counter.NewCounterWriter(w, new(counter.Counter))} +} + +type numberCounter struct { + counter.Counter +} + +func (c *numberCounter) Write(_ []byte) (n int, err error) { return } + +type writerCounter struct { + *counter.CounterWriter +} + +func (c *writerCounter) Add(_ int64) (n int64) { return } +func (c *writerCounter) Get() int64 { return c.CounterWriter.Bytes() } diff --git a/progressbar/progressbar.go b/progressbar/progressbar.go index 7d611af..6dda4a6 100644 --- a/progressbar/progressbar.go +++ b/progressbar/progressbar.go @@ -11,7 +11,6 @@ import ( "text/template" "time" - "github.com/sunshineplan/utils/counter" "github.com/sunshineplan/utils/unit" ) @@ -27,7 +26,7 @@ var dots = []string{". ", ".. ", "..."} // ProgressBar represents a customizable progress bar for tracking task progress. // It supports configurable templates, units, and refresh intervals. -type ProgressBar struct { +type ProgressBar[T int | int64] struct { mu sync.Mutex buf strings.Builder @@ -44,7 +43,7 @@ type ProgressBar struct { renderInterval time.Duration template *template.Template - current counter.Counter + current genericCounter total int64 additional string speed float64 @@ -61,17 +60,12 @@ type format struct { // New creates a new ProgressBar with the specified total count and default options. // It panics if total is less than or equal to zero. -func New(total int) *ProgressBar { - return New64(int64(total)) -} - -// New64 returns a new ProgressBar with default options. -func New64(total int64) *ProgressBar { +func New[T int | int64](total T) *ProgressBar[T] { if total <= 0 { panic(fmt.Sprintf("invalid total number: %d", total)) } ctx, cancel := context.WithCancel(context.Background()) - return &ProgressBar{ + return &ProgressBar[T]{ ctx: ctx, cancel: cancel, msgChan: make(chan string, 1), @@ -79,13 +73,14 @@ func New64(total int64) *ProgressBar { blockWidth: 40, refreshInterval: defaultRefresh, template: template.Must(template.New("ProgressBar").Parse(defaultTemplate)), + current: newNumberCounter(), total: int64(total), } } // SetWidth sets the progress bar block width. // It panics if called after the progress bar has started or if blockWidth is less than or equal to zero. -func (pb *ProgressBar) SetWidth(blockWidth int) *ProgressBar { +func (pb *ProgressBar[T]) SetWidth(blockWidth int) *ProgressBar[T] { pb.mu.Lock() defer pb.mu.Unlock() if !pb.start.IsZero() { @@ -100,7 +95,7 @@ func (pb *ProgressBar) SetWidth(blockWidth int) *ProgressBar { // SetRefreshInterval sets progress bar refresh interval time for check speed. // It panics if called after the progress bar has started or if interval is less than or equal to zero. -func (pb *ProgressBar) SetRefreshInterval(interval time.Duration) *ProgressBar { +func (pb *ProgressBar[T]) SetRefreshInterval(interval time.Duration) *ProgressBar[T] { pb.mu.Lock() defer pb.mu.Unlock() if !pb.start.IsZero() { @@ -115,7 +110,7 @@ func (pb *ProgressBar) SetRefreshInterval(interval time.Duration) *ProgressBar { // SetRenderInterval sets the interval for updating the progress bar display. // It panics if called after the progress bar has started or if interval is less than or equal to zero. -func (pb *ProgressBar) SetRenderInterval(interval time.Duration) *ProgressBar { +func (pb *ProgressBar[T]) SetRenderInterval(interval time.Duration) *ProgressBar[T] { pb.mu.Lock() defer pb.mu.Unlock() if !pb.start.IsZero() { @@ -129,7 +124,7 @@ func (pb *ProgressBar) SetRenderInterval(interval time.Duration) *ProgressBar { } // SetTemplate sets progress bar template. -func (pb *ProgressBar) SetTemplate(tmplt string) error { +func (pb *ProgressBar[T]) SetTemplate(tmplt string) error { pb.mu.Lock() defer pb.mu.Unlock() if !pb.start.IsZero() { @@ -148,7 +143,7 @@ func (pb *ProgressBar) SetTemplate(tmplt string) error { // SetUnit sets progress bar unit. // It panics if called after the progress bar has started. -func (pb *ProgressBar) SetUnit(unit string) *ProgressBar { +func (pb *ProgressBar[T]) SetUnit(unit string) *ProgressBar[T] { pb.mu.Lock() defer pb.mu.Unlock() if !pb.start.IsZero() { @@ -159,22 +154,22 @@ func (pb *ProgressBar) SetUnit(unit string) *ProgressBar { } // Add adds the specified amount to the progress bar. -func (pb *ProgressBar) Add(n int64) { - pb.current.Add(n) +func (pb *ProgressBar[T]) Add(n T) { + pb.current.Add(int64(n)) } // Additional adds the specified string to the progress bar. -func (pb *ProgressBar) Additional(s string) { +func (pb *ProgressBar[T]) Additional(s string) { pb.mu.Lock() defer pb.mu.Unlock() pb.additional = s } -func (pb *ProgressBar) now() int64 { - return pb.current.Load() +func (pb *ProgressBar[T]) now() int64 { + return pb.current.Get() } -func (pb *ProgressBar) print(s string, msg bool) { +func (pb *ProgressBar[T]) print(s string, msg bool) { pb.mu.Lock() defer pb.mu.Unlock() pb.buf.Reset() @@ -197,7 +192,7 @@ func (pb *ProgressBar) print(s string, msg bool) { io.WriteString(os.Stdout, pb.buf.String()) } -func (pb *ProgressBar) startRefresh() { +func (pb *ProgressBar[T]) startRefresh() { start := pb.start maxRefresh := pb.refreshInterval * maxRefreshMultiple ticker := time.NewTicker(pb.refreshInterval) @@ -228,7 +223,7 @@ func (pb *ProgressBar) startRefresh() { } } -func (pb *ProgressBar) startCount() { +func (pb *ProgressBar[T]) startCount() { interval := pb.renderInterval if interval == 0 { interval = time.Second @@ -314,7 +309,7 @@ func (pb *ProgressBar) startCount() { } // Start starts the progress bar. -func (pb *ProgressBar) Start() error { +func (pb *ProgressBar[T]) Start() error { pb.mu.Lock() defer pb.mu.Unlock() if !pb.start.IsZero() { @@ -327,7 +322,7 @@ func (pb *ProgressBar) Start() error { } // Message sets a message to be displayed on the progress bar. -func (pb *ProgressBar) Message(msg string) error { +func (pb *ProgressBar[T]) Message(msg string) error { defer func() { recover() }() select { case <-pb.done: @@ -341,30 +336,25 @@ func (pb *ProgressBar) Message(msg string) error { return nil } -// Done waits the progress bar finished. Same as [ProgressBar.Wait](). -func (pb *ProgressBar) Done() { - pb.Wait() -} - // Wait blocks until the progress bar is finished. -func (pb *ProgressBar) Wait() { +func (pb *ProgressBar[T]) Wait() { <-pb.done } // Cancel cancels the progress bar. -func (pb *ProgressBar) Cancel() { +func (pb *ProgressBar[T]) Cancel() { pb.cancel() } // FromReader starts the progress bar from a reader. -func (pb *ProgressBar) FromReader(r io.Reader, w io.Writer) (int64, error) { +func (pb *ProgressBar[T]) FromReader(r io.Reader, w io.Writer) (int64, error) { + pb.current = newWriterCounter(w) if err := pb.Start(); err != nil { return 0, err } - n, err := io.Copy(pb.current.AddWriter(w), r) + n, err := io.Copy(pb.current, r) if err != nil { pb.Cancel() - return n, err } - return n, nil + return n, err } diff --git a/progressbar/progressbar_test.go b/progressbar/progressbar_test.go index 9795c46..49b46b4 100644 --- a/progressbar/progressbar_test.go +++ b/progressbar/progressbar_test.go @@ -23,7 +23,7 @@ func TestProgessBar(t *testing.T) { pb.Add(1) time.Sleep(time.Second) } - pb.Done() + pb.Wait() pb = New(10).SetRefreshInterval(500 * time.Millisecond) pb.Start() @@ -32,7 +32,7 @@ func TestProgessBar(t *testing.T) { pb.Add(1) time.Sleep(time.Second) } - pb.Done() + pb.Wait() pb = New(0) pb.Start() @@ -63,7 +63,7 @@ func TestMessage(t *testing.T) { pb.Add(1) time.Sleep(time.Second) } - pb.Done() + pb.Wait() close(stopCh) select { case err := <-errCh: @@ -88,7 +88,7 @@ func TestCancel(t *testing.T) { time.Sleep(time.Second) } }() - pb.Done() + pb.Wait() } func TestFromReader(t *testing.T) { @@ -97,7 +97,7 @@ func TestFromReader(t *testing.T) { t.Fatal(err) } defer resp.Body.Close() - total, err := strconv.Atoi(resp.Header.Get("content-length")) + total, err := strconv.ParseInt(resp.Header.Get("content-length"), 10, 64) if err != nil { t.Fatal(err) } @@ -105,11 +105,11 @@ func TestFromReader(t *testing.T) { if _, err := pb.FromReader(resp.Body, io.Discard); err != nil { t.Fatal(err) } - pb.Done() + pb.Wait() } func TestSetTemplate(t *testing.T) { - pb := &ProgressBar{} + pb := &ProgressBar[int]{} if err := pb.SetTemplate(`{{.Done}}`); err != nil { t.Error(err) } diff --git a/retry/retry.go b/retry/retry.go index 536f9dc..94b75b4 100644 --- a/retry/retry.go +++ b/retry/retry.go @@ -2,39 +2,29 @@ package retry import ( "errors" + "fmt" "time" ) -var errNoMoreRetry error = errorNoMoreRetry("no more retry") +// ErrNoMoreRetry is a sentinel error indicating that no more retries should be performed. +var ErrNoMoreRetry = errors.New("no more retry") -type errorNoMoreRetry string - -func (err errorNoMoreRetry) Error() string { - return string(err) -} - -func (errorNoMoreRetry) Unwrap() error { - return errNoMoreRetry +// StopRetry creates a wrapped error indicating that retries should stop. +func StopRetry(msg string) error { + return fmt.Errorf("%w: %s", ErrNoMoreRetry, msg) } -// IsNoMoreRetry reports whether error is NoMoreRetry error. +// IsNoMoreRetry reports whether the given error indicates to stop retrying. func IsNoMoreRetry(err error) bool { - if e, ok := err.(interface{ Unwrap() []error }); ok { - for _, err := range e.Unwrap() { - if IsNoMoreRetry(err) { - return true - } - } - return false - } - return errors.Is(err, errNoMoreRetry) + return errors.Is(err, ErrNoMoreRetry) } -// ErrNoMoreRetry tells function does no more retry. -func ErrNoMoreRetry(err string) error { return errorNoMoreRetry(err) } - -// Do keeps retrying the function until no error is returned. -func Do(fn func() error, attempts, delay int) error { +// Do executes fn repeatedly until it succeeds, the attempts are exhausted, +// or fn returns an error that indicates no more retries. +func Do(fn func() error, attempts int, delay time.Duration) error { + if attempts <= 0 { + return errors.New("invalid attempts count") + } var errs []error for i := range attempts { err := fn() @@ -44,8 +34,9 @@ func Do(fn func() error, attempts, delay int) error { errs = append(errs, err) if IsNoMoreRetry(err) { break - } else if i < attempts-1 { - time.Sleep(time.Second * time.Duration(delay)) + } + if i < attempts-1 { + time.Sleep(delay) } } return errors.Join(errs...) diff --git a/retry/retry_test.go b/retry/retry_test.go index 722b5bd..294cb2a 100644 --- a/retry/retry_test.go +++ b/retry/retry_test.go @@ -4,17 +4,18 @@ import ( "errors" "strconv" "testing" + "time" ) func TestRetry(t *testing.T) { - if err := ErrNoMoreRetry("error"); !errors.Is(err, errNoMoreRetry) { - t.Error("expected err is errNoMoreRetry; got not") + if err := StopRetry("error"); !errors.Is(err, ErrNoMoreRetry) { + t.Error("expected err is ErrNoMoreRetry; got not") } var i int if err := Do(func() error { defer func() { i++ }() return nil - }, 3, 1); err != nil { + }, 3, time.Second); err != nil { t.Errorf("expected nil error; got non-nil error %v", err) } else if i != 1 { t.Errorf("expected 1; got %d", i) @@ -24,7 +25,7 @@ func TestRetry(t *testing.T) { if err := Do(func() error { defer func() { i++ }() return errors.New("error" + strconv.Itoa(i)) - }, 3, 1); err == nil { + }, 3, time.Second); err == nil { t.Error("expected non-nil error; got nil error") } else if expect := "error0\nerror1\nerror2"; err.Error() != expect { t.Errorf("expected %s; got %s", expect, err) @@ -33,10 +34,10 @@ func TestRetry(t *testing.T) { i = 0 if err := Do(func() error { defer func() { i++ }() - return ErrNoMoreRetry("error" + strconv.Itoa(i)) - }, 3, 1); !IsNoMoreRetry(err) { + return StopRetry("error" + strconv.Itoa(i)) + }, 3, time.Second); !IsNoMoreRetry(err) { t.Errorf("expected ErrNoMoreRetry; got %s", err) - } else if err.Error() != "error0" { - t.Errorf("expected error0; got %d", i) + } else if err.Error() != "no more retry: error0" { + t.Errorf("expected error0; got %s", err) } } diff --git a/slice/slice.go b/slice/slice.go index 2d272e9..894a27e 100644 --- a/slice/slice.go +++ b/slice/slice.go @@ -1,13 +1,12 @@ package slice -// Deduplicate removes duplicate items in slice. +// Deduplicate removes duplicate elements while preserving order. func Deduplicate[S ~[]E, E comparable](s S) S { - if s == nil { + if len(s) == 0 { return nil } - - res := S{} - m := make(map[E]struct{}) + m := make(map[E]struct{}, len(s)) + res := make(S, 0, len(s)) for _, i := range s { if _, ok := m[i]; !ok { m[i] = struct{}{} diff --git a/smtp/auth.go b/smtp/auth.go index dde5335..8bb308b 100644 --- a/smtp/auth.go +++ b/smtp/auth.go @@ -1,44 +1,49 @@ package smtp import ( + "bytes" "errors" + "fmt" "net/smtp" - "strings" + "slices" ) +// loginAuth implements the LOGIN authentication mechanism for SMTP. var _ smtp.Auth = &loginAuth{} +// loginAuth holds the credentials and server information for LOGIN authentication. type loginAuth struct { - username, password, server string + username, password, host string } +// Start initiates the LOGIN authentication process, verifying TLS and server name. func (a *loginAuth) Start(server *smtp.ServerInfo) (string, []byte, error) { if !server.TLS && !isLocalhost(server.Name) { return "", nil, errors.New("unencrypted connection") } - if server.Name != a.server { + if server.Name != a.host { return "", nil, errors.New("wrong server name") } resp := []byte(a.username) return "LOGIN", resp, nil } +// Next handles the server's challenge, responding with username or password as needed. func (a *loginAuth) Next(fromServer []byte, more bool) ([]byte, error) { if more { - if strings.Contains(string(fromServer), "Username") { - resp := []byte(a.username) - return resp, nil + if bytes.Contains(fromServer, []byte("Username")) { + return []byte(a.username), nil } - if strings.Contains(string(fromServer), "Password") { - resp := []byte(a.password) - return resp, nil + if bytes.Contains(fromServer, []byte("Password")) { + return []byte(a.password), nil } // We've already sent everything. - return nil, errors.New("unexpected server challenge") + return nil, fmt.Errorf("unexpected server challenge: %s", string(fromServer)) } return nil, nil } +// isLocalhost checks if the server name is a localhost address. func isLocalhost(name string) bool { return name == "localhost" || name == "127.0.0.1" || name == "::1" } @@ -57,14 +62,13 @@ type Auth struct { func (c *Client) Auth2(auth *Auth) error { // Auto select auth mode var a smtp.Auth - if auths := strings.Join(c.auth, " "); strings.Contains(auths, "CRAM-MD5") { + if slices.Contains(c.auth, "CRAM-MD5") { a = smtp.CRAMMD5Auth(auth.Username, auth.Password) - } else if strings.Contains(auths, "PLAIN") { + } else if slices.Contains(c.auth, "PLAIN") { a = smtp.PlainAuth(auth.Identity, auth.Username, auth.Password, auth.Server) } else { a = &loginAuth{auth.Username, auth.Password, auth.Server} } - return c.Auth(a) } diff --git a/smtp/smtp.go b/smtp/smtp.go index cd21083..68b3635 100644 --- a/smtp/smtp.go +++ b/smtp/smtp.go @@ -134,7 +134,7 @@ func (c *Client) helo() error { func (c *Client) ehlo() error { _, msg, err := c.Cmd(250, "EHLO %s", c.localName) if err != nil { - return err + return fmt.Errorf("failed to send HELO: %w", err) } ext := make(map[string]string) extList := strings.Split(msg, "\n") @@ -149,7 +149,7 @@ func (c *Client) ehlo() error { c.auth = strings.Split(mechs, " ") } c.ext = ext - return err + return nil } // StartTLS sends the STARTTLS command and encrypts all further communication. @@ -161,7 +161,7 @@ func (c *Client) StartTLS(config *tls.Config) error { } _, _, err := c.Cmd(220, "STARTTLS") if err != nil { - return err + return fmt.Errorf("failed to send STARTTLS: %w", err) } if config == nil { config = &tls.Config{ServerName: c.serverName} diff --git a/txt/export.go b/txt/export.go index 9a69be2..d5e8832 100644 --- a/txt/export.go +++ b/txt/export.go @@ -5,18 +5,20 @@ import ( "os" ) -// Export writes contents to writer w. +// Export writes the contents to w using a buffered Writer. +// It returns any error encountered during writing or flushing. func Export(contents []string, w io.Writer) error { return NewWriter(w).WriteAll(contents) } -// ExportFile writes contents to file. +// ExportFile writes the contents to the specified file, overwriting it if it exists. +// The file is created with default permissions (0666, subject to umask). +// It returns any error encountered during file creation or writing. func ExportFile(contents []string, file string) error { f, err := os.Create(file) if err != nil { return err } defer f.Close() - return Export(contents, f) } diff --git a/txt/reader.go b/txt/reader.go index f80cd70..919a715 100644 --- a/txt/reader.go +++ b/txt/reader.go @@ -7,40 +7,53 @@ import ( "os" ) +// Reader provides buffered reading from an io.Reader, splitting input by lines. type Reader struct { scanner *bufio.Scanner } +// NewReader returns a new Reader that reads from r using a buffered scanner. +// It splits input by lines using the default bufio.Scanner line-splitting behavior (\n). func NewReader(r io.Reader) *Reader { return &Reader{bufio.NewScanner(r)} } -func (r *Reader) Iter() iter.Seq[string] { - return func(yield func(string) bool) { +// Iter returns an iterator over the lines read from the underlying io.Reader. +// Each iteration yields a line and a nil error, or an empty string and an error +// if the scanner encounters an error. +func (r *Reader) Iter() iter.Seq2[string, error] { + return func(yield func(string, error) bool) { for r.scanner.Scan() { - if !yield(r.scanner.Text()) { + if !yield(r.scanner.Text(), nil) { return } } + if err := r.scanner.Err(); err != nil { + yield("", err) + } } } -// ReadAll reads all contents from r. -func ReadAll(r io.Reader) []string { +// ReadAll reads all lines from r and returns them as a slice of strings. +// It returns an error if the underlying scanner encounters an error. +func ReadAll(r io.Reader) ([]string, error) { var s []string - for i := range NewReader(r).Iter() { + for i, err := range NewReader(r).Iter() { + if err != nil { + return nil, err + } s = append(s, i) } - return s + return s, nil } -// ReadFile reads all contents from file. +// ReadFile reads all lines from the specified file and returns them as a slice of strings. +// It returns an error if the file cannot be opened or read. func ReadFile(file string) ([]string, error) { f, err := os.Open(file) if err != nil { return nil, err } defer f.Close() - - return ReadAll(f), nil + return ReadAll(f) } diff --git a/txt/reader_test.go b/txt/reader_test.go index b9cddb0..6c9d434 100644 --- a/txt/reader_test.go +++ b/txt/reader_test.go @@ -11,7 +11,10 @@ func TestReader(t *testing.T) { B C ` - res := ReadAll(strings.NewReader(txt)) + res, err := ReadAll(strings.NewReader(txt)) + if err != nil { + t.Fatal(err) + } if expect := []string{"A", "B", "C"}; !slices.Equal(expect, res) { t.Errorf("expected %v; got %v", expect, res) } diff --git a/txt/writer.go b/txt/writer.go index 4deaeb2..44a6c75 100644 --- a/txt/writer.go +++ b/txt/writer.go @@ -31,7 +31,6 @@ func (w *Writer) WriteLine(s string) (int, error) { if w.UseCRLF { return w.WriteString(s + "\r\n") } - return w.WriteString(s + "\n") } @@ -50,6 +49,5 @@ func (w *Writer) WriteAll(contents []string) error { return err } } - return w.Flush() } diff --git a/unit/bytesize.go b/unit/bytesize.go index 30ee81d..a8a9928 100644 --- a/unit/bytesize.go +++ b/unit/bytesize.go @@ -15,6 +15,20 @@ var ( var byteSizeRegexp = regexp.MustCompile(`^(\d+(?:\.\d+)?) ?((?i)[KMGTPE]?B?)$`) +// ByteSize represents a quantity of bytes. +type ByteSize int64 + +// Common byte size units. +const ( + B ByteSize = 1 + KB ByteSize = 1 << (10 * iota) + MB + GB + TB + PB + EB +) + var ( byteSizeStr = map[ByteSize]string{ B: "B", @@ -36,18 +50,8 @@ var ( } ) -type ByteSize int64 - -const ( - B ByteSize = 1 - KB ByteSize = 1 << (10 * iota) - MB - GB - TB - PB - EB -) - +// ParseByteSize parses a human-readable size string (e.g. "1.5GB", "100 KB") +// and returns the corresponding ByteSize value. func ParseByteSize(s string) (ByteSize, error) { s = strings.TrimSpace(s) res := byteSizeRegexp.FindStringSubmatch(s) @@ -65,41 +69,61 @@ func ParseByteSize(s string) (ByteSize, error) { return ByteSize(v * float64(strByteSize[unit])), nil } -func NewByteSize(n float64, unit ByteSize) ByteSize { - return ByteSize(n * float64(unit)) +// MustParseByteSize parses a size string and panics if it is invalid. +func MustParseByteSize(s string) ByteSize { + v, err := ParseByteSize(s) + if err != nil { + panic(err) + } + return v +} + +// NewByteSize creates a ByteSize from a numeric value and a unit string. +// Example: NewByteSize(1.5, "GB") -> 1.5GB. +// Returns an error if the unit is not recognized. +func NewByteSize(n float64, unit string) (ByteSize, error) { + unit = strings.ToUpper(strings.TrimSpace(unit)) + unit = strings.TrimSuffix(unit, "B") + if unit == "" { + unit = "B" + } + bs, ok := strByteSize[unit] + if !ok { + return 0, errors.New("unknown byte size unit: " + unit) + } + return ByteSize(n * float64(bs)), nil } +// DefaultByteSizeFormatFloat formats a float with up to 2 decimals, +// trimming trailing zeros and the decimal point if unnecessary. func DefaultByteSizeFormatFloat(f float64) string { s := strconv.FormatFloat(f, 'f', 2, 64) s = strings.TrimRight(s, "0") return strings.TrimSuffix(s, ".") } +// ByteSizeFormatFloat defines how floating-point values are formatted in String(). +// It can be replaced to customize output precision. var ByteSizeFormatFloat = DefaultByteSizeFormatFloat +// String returns a human-readable representation of the byte size. +// e.g. 1536 -> "1.5KB", 1048576 -> "1MB". func (n ByteSize) String() string { - unit := B - switch { - case n >= EB: - unit = EB - case n >= PB: - unit = PB - case n >= TB: - unit = TB - case n >= GB: - unit = GB - case n >= MB: - unit = MB - case n >= KB: - unit = KB + units := []ByteSize{EB, PB, TB, GB, MB, KB, B} + for _, unit := range units { + if n >= unit { + return ByteSizeFormatFloat(float64(n)/float64(unit)) + byteSizeStr[unit] + } } - return ByteSizeFormatFloat(float64(n)/float64(unit)) + byteSizeStr[unit] + return "0B" } +// MarshalText implements the encoding.TextMarshaler interface. func (b ByteSize) MarshalText() ([]byte, error) { return []byte(b.String()), nil } +// UnmarshalText implements the encoding.TextUnmarshaler interface. func (b *ByteSize) UnmarshalText(text []byte) error { bytes, err := ParseByteSize(string(text)) if err != nil { @@ -108,3 +132,20 @@ func (b *ByteSize) UnmarshalText(text []byte) error { *b = bytes return nil } + +// To converts the ByteSize to the specified unit and returns its string representation. +func (b ByteSize) To(unit string, decimals int) (string, error) { + unit = strings.ToUpper(strings.TrimSpace(unit)) + unit = strings.TrimSuffix(unit, "B") + if unit == "" { + unit = "B" + } + base, ok := strByteSize[unit] + if !ok { + return "", errors.New("invalid unit: " + unit) + } + value := float64(b) / float64(base) + s := strconv.FormatFloat(value, 'f', decimals, 64) + s = strings.TrimRight(s, "0") + return strings.TrimSuffix(s, ".") + byteSizeStr[base], nil +} diff --git a/unit/bytesize_test.go b/unit/bytesize_test.go index 1bec500..ee19b11 100644 --- a/unit/bytesize_test.go +++ b/unit/bytesize_test.go @@ -33,7 +33,6 @@ func TestByteSize(t *testing.T) { {KB, "1KB"}, {10 * MB, "10MB"}, {1536 * MB, "1.5GB"}, - {NewByteSize(1.5, GB), "1.5GB"}, } { if bytesize := ByteSize(testcase.n).String(); bytesize != testcase.str { t.Errorf("expected %q; got %q", testcase.str, bytesize) @@ -45,3 +44,84 @@ func TestByteSize(t *testing.T) { } } } + +func TestNewByteSize(t *testing.T) { + tests := []struct { + n float64 + unit string + want ByteSize + expectErr bool + }{ + {1, "B", B, false}, + {1, "KB", KB, false}, + {1, "mb", MB, false}, + {1.5, "GB", ByteSize(1.5 * float64(GB)), false}, + {2, "tb", ByteSize(2 * float64(TB)), false}, + {0, "MB", 0, false}, + {1, "XYZ", 0, true}, + } + + for _, tt := range tests { + got, err := NewByteSize(tt.n, tt.unit) + if tt.expectErr { + if err == nil { + t.Errorf("NewByteSize(%v, %q) expected error, got nil", tt.n, tt.unit) + } + continue + } + if err != nil { + t.Errorf("NewByteSize(%v, %q) unexpected error: %v", tt.n, tt.unit, err) + continue + } + if got != tt.want { + t.Errorf("NewByteSize(%v, %q) = %v, want %v", tt.n, tt.unit, got, tt.want) + } + } +} + +func TestTo(t *testing.T) { + tests := []struct { + size ByteSize + unit string + decimals int + want string + expectErr bool + }{ + // integer results + {ByteSize(1024), "KB", 0, "1KB", false}, + {ByteSize(1536), "KB", 0, "2KB", false}, + {ByteSize(1048576), "MB", 0, "1MB", false}, + + // fractional results + {ByteSize(1536), "KB", 2, "1.5KB", false}, + {ByteSize(1536), "MB", 3, "0.001MB", false}, + {MustParseByteSize("1.5GB"), "GB", 2, "1.5GB", false}, + + // small numbers < 1 + {ByteSize(1), "KB", 6, "0.000977KB", false}, + {ByteSize(500), "MB", 6, "0.000477MB", false}, + + // decimals = 0 for fractional -> rounds to nearest int + {ByteSize(1536), "KB", 0, "2KB", false}, + + // invalid unit + {ByteSize(1024), "XYZ", 2, "", true}, + } + + for _, tt := range tests { + got, err := tt.size.To(tt.unit, tt.decimals) + if tt.expectErr { + if err == nil { + t.Errorf("To(%q, %d) expected error, got nil", tt.unit, tt.decimals) + } + continue + } + if err != nil { + t.Errorf("To(%q, %d) unexpected error: %v", tt.unit, tt.decimals, err) + continue + } + if got != tt.want { + t.Errorf("To(%q, %d) = %q, want %q", tt.unit, tt.decimals, got, tt.want) + } + } +}