From 0641baeca2a91778ebc7145d8ca71554f4eb9abe Mon Sep 17 00:00:00 2001 From: sunshineplan Date: Wed, 3 Dec 2025 15:06:10 +0800 Subject: [PATCH] csv: Add support for exporting and decoding CSV slices Enhanced CSV export and decode functions to handle slices of primitive types (e.g., []string, []int) without requiring field names. Updated writer and reader logic to properly marshal and unmarshal single-column CSV data. Added corresponding tests to verify slice handling. --- csv/export.go | 29 +++++++++++++++++++-------- csv/export_test.go | 22 ++++++++++++++++++++ csv/reader.go | 50 ++++++++++++++++++++++++++++++---------------- csv/reader_test.go | 21 +++++++++++++++++++ csv/writer.go | 24 ++++++++++++++++------ 5 files changed, 115 insertions(+), 31 deletions(-) diff --git a/csv/export.go b/csv/export.go index 3bdd854..7861755 100644 --- a/csv/export.go +++ b/csv/export.go @@ -1,9 +1,9 @@ package csv import ( - "fmt" "io" "os" + "reflect" ) // Export writes slice as csv format with fieldnames to writer w. @@ -37,17 +37,30 @@ func ExportUTF8File[S ~[]E, E any](fieldnames []string, slice S, file string) er } func export[S ~[]E, E any](fieldnames []string, slice S, w io.Writer, utf8bom bool) (err error) { - csvWriter := NewWriter(w, utf8bom) + var fields any if len(fieldnames) == 0 { - if len(slice) == 0 { - return fmt.Errorf("can't get struct fieldnames from zero length slice") + t := reflect.TypeFor[E]() + for t.Kind() == reflect.Pointer { + t = t.Elem() + } + if kind := t.Kind(); kind == reflect.Struct || kind == reflect.Map { + fields = reflect.Zero(t).Interface() } - err = csvWriter.WriteFields(slice[0]) } else { - err = csvWriter.WriteFields(fieldnames) + fields = fieldnames } - if err != nil { - return + csvWriter := NewWriter(w, utf8bom) + if fields != nil { + if err = csvWriter.WriteFields(fields); err != nil { + return + } + } else { + csvWriter.fieldsWritten = true + csvWriter.zero = make([]string, 1) + csvWriter.pool.New = func() *[]string { + s := make([]string, 1) + return &s + } } return csvWriter.WriteAll(slice) } diff --git a/csv/export_test.go b/csv/export_test.go index da27fbd..1cc5bb4 100644 --- a/csv/export_test.go +++ b/csv/export_test.go @@ -15,6 +15,7 @@ func testExport[E any](t *testing.T, tc testcase[E], result string) { var b bytes.Buffer if err := Export(tc.fieldnames, tc.slice, &b); err != nil { t.Error(tc.name, err) + return } if r := b.String(); r != result { t.Errorf("%s expected %q; got %q", tc.name, result, r) @@ -111,3 +112,24 @@ a,b t.Errorf("expected %q; got %q", result, r) } } + +func TestExportSlice(t *testing.T) { + result := `1 +2 +3 +` + var b bytes.Buffer + if err := Export(nil, []string{"1", "2", "3"}, &b); err != nil { + t.Fatal(err) + } + if r := b.String(); r != result { + t.Errorf("expected %q; got %q", result, r) + } + b.Reset() + if err := Export(nil, []int{1, 2, 3}, &b); err != nil { + t.Fatal(err) + } + if r := b.String(); r != result { + t.Errorf("expected %q; got %q", result, r) + } +} diff --git a/csv/reader.go b/csv/reader.go index ac9bcae..06445c0 100644 --- a/csv/reader.go +++ b/csv/reader.go @@ -5,6 +5,7 @@ import ( "fmt" "io" "os" + "reflect" "strings" "sync" ) @@ -14,8 +15,9 @@ type Reader struct { *csv.Reader closer io.Closer - once sync.Once - fields []string + once sync.Once + fields []string + hasFields bool next []string nextErr error @@ -33,6 +35,7 @@ func NewReader(r io.Reader, hasFields bool) (*Reader, error) { if err != nil { return nil, err } + reader.hasFields = true } return reader, nil } @@ -75,6 +78,7 @@ func (r *Reader) Read() (record []string, err error) { // SetFields sets csv fields. func (r *Reader) SetFields(fields []string) { r.fields = fields + r.hasFields = true } // Next prepares the next record for reading with the Scan or Decode method. @@ -106,22 +110,25 @@ func (r *Reader) Scan(dest ...any) error { // Decode will unmarshal the current record into dest. // If column's value is like "[...]", it will be treated as slice. func (r *Reader) Decode(dest any) error { - if len(r.fields) == 0 { - return fmt.Errorf("csv fields is not parsed") - } - if r.next == nil && r.nextErr == nil { - return fmt.Errorf("Decode called without calling Next") - } - if r.nextErr != nil { - return r.nextErr - } - m := make(map[string]string) - for i, field := range r.fields { - if len(r.next) > i { - m[field] = r.next[i] + if r.hasFields { + if len(r.fields) == 0 { + return fmt.Errorf("csv fields is not parsed") } + if r.next == nil && r.nextErr == nil { + return fmt.Errorf("Decode called without calling Next") + } + if r.nextErr != nil { + return r.nextErr + } + m := make(map[string]string) + for i, field := range r.fields { + if len(r.next) > i { + m[field] = r.next[i] + } + } + return setRow(dest, m) } - return setRow(dest, m) + return setCell(dest, r.next[0]) } // Close closes the underlying reader if it implements the io.Closer interface. @@ -134,7 +141,16 @@ func (r *Reader) Close() error { // DecodeAll decodes each record from r into dest. func DecodeAll[S ~[]E, E any](r io.Reader, dest *S) (err error) { - reader, err := NewReader(r, true) + t := reflect.TypeFor[E]() + for t.Kind() == reflect.Pointer { + t = t.Elem() + } + var reader *Reader + if kind := t.Kind(); kind == reflect.Struct || kind == reflect.Map { + reader, err = NewReader(r, true) + } else { + reader, err = NewReader(r, false) + } if err != nil { return } diff --git a/csv/reader_test.go b/csv/reader_test.go index 874cbad..0a2e16a 100644 --- a/csv/reader_test.go +++ b/csv/reader_test.go @@ -73,3 +73,24 @@ b,2,"[3,4]" t.Errorf("expected %q; got %q", expect, res) } } + +func TestDecodeSlice(t *testing.T) { + csv := `1 +2 +3 +` + var s1 []string + if err := DecodeAll(strings.NewReader(csv), &s1); err != nil { + t.Fatal(err) + } + if expect := []string{"1", "2", "3"}; !reflect.DeepEqual(expect, s1) { + t.Errorf("expected %v; got %v", expect, s1) + } + var s2 []int + if err := DecodeAll(strings.NewReader(csv), &s2); err != nil { + t.Fatal(err) + } + if expect := []int{1, 2, 3}; !reflect.DeepEqual(expect, s2) { + t.Errorf("expected %v; got %v", expect, s2) + } +} diff --git a/csv/writer.go b/csv/writer.go index ecbb1da..0f3565b 100644 --- a/csv/writer.go +++ b/csv/writer.go @@ -51,8 +51,8 @@ func (w *Writer) WriteFields(fields any) error { } default: v := reflect.ValueOf(fields) - if v.Kind() == reflect.Pointer { - v = reflect.Indirect(v) + for v.Kind() == reflect.Pointer { + v = v.Elem() if !v.IsValid() { return fmt.Errorf("can not get fieldnames from nil pointer struct") } @@ -107,6 +107,8 @@ func (w *Writer) Write(record any) error { return fmt.Errorf("fieldnames has not be written yet") } switch d := record.(type) { + case string: + return w.Writer.Write([]string{d}) case []string: if len(d) == 0 { return nil @@ -117,8 +119,8 @@ func (w *Writer) Write(record any) error { if v.Kind() == reflect.Interface { v = v.Elem() } - if v.Kind() == reflect.Pointer { - v = reflect.Indirect(v) + for v.Kind() == reflect.Pointer { + v = v.Elem() if !v.IsValid() { return nil } @@ -126,6 +128,10 @@ func (w *Writer) Write(record any) error { r := w.pool.Get() defer w.pool.Put(r) switch v.Kind() { + case reflect.Slice: + for i := range v.Len() { + (*r)[i], _ = marshalText(v.Index(i).Interface()) + } case reflect.Map: if keyType := reflect.TypeOf(v.Interface()).Key(); keyType.Kind() == reflect.String { for i, field := range w.fields { @@ -163,7 +169,7 @@ func (w *Writer) Write(record any) error { } } default: - return fmt.Errorf("not support record format: %s", v.Kind()) + (*r)[0], _ = marshalText(v.Interface()) } if slices.Equal(*r, w.zero) { return nil @@ -175,9 +181,15 @@ func (w *Writer) Write(record any) error { // WriteAll writes multiple CSV records to w using Write and then calls Flush, returning any error from the Flush. func (w *Writer) WriteAll(records any) error { switch s := records.(type) { + case []string: + for _, i := range s { + if err := w.Writer.Write([]string{i}); err != nil { + return err + } + } case [][]string: for _, i := range s { - if err := w.Write(i); err != nil { + if err := w.Writer.Write(i); err != nil { return err } }