Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
29 changes: 21 additions & 8 deletions csv/export.go
Original file line number Diff line number Diff line change
@@ -1,9 +1,9 @@
package csv

import (
"fmt"
"io"
"os"
"reflect"
)

// Export writes slice as csv format with fieldnames to writer w.
Expand Down Expand Up @@ -37,17 +37,30 @@ func ExportUTF8File[S ~[]E, E any](fieldnames []string, slice S, file string) er
}

func export[S ~[]E, E any](fieldnames []string, slice S, w io.Writer, utf8bom bool) (err error) {
csvWriter := NewWriter(w, utf8bom)
var fields any
if len(fieldnames) == 0 {
if len(slice) == 0 {
return fmt.Errorf("can't get struct fieldnames from zero length slice")
t := reflect.TypeFor[E]()
for t.Kind() == reflect.Pointer {
t = t.Elem()
}
if kind := t.Kind(); kind == reflect.Struct || kind == reflect.Map {
fields = reflect.Zero(t).Interface()
}
err = csvWriter.WriteFields(slice[0])
} else {
err = csvWriter.WriteFields(fieldnames)
fields = fieldnames
}
if err != nil {
return
csvWriter := NewWriter(w, utf8bom)
if fields != nil {
if err = csvWriter.WriteFields(fields); err != nil {
return
}
} else {
csvWriter.fieldsWritten = true
csvWriter.zero = make([]string, 1)
csvWriter.pool.New = func() *[]string {
s := make([]string, 1)
return &s
}
}
return csvWriter.WriteAll(slice)
}
22 changes: 22 additions & 0 deletions csv/export_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@ func testExport[E any](t *testing.T, tc testcase[E], result string) {
var b bytes.Buffer
if err := Export(tc.fieldnames, tc.slice, &b); err != nil {
t.Error(tc.name, err)
return
}
if r := b.String(); r != result {
t.Errorf("%s expected %q; got %q", tc.name, result, r)
Expand Down Expand Up @@ -111,3 +112,24 @@ a,b
t.Errorf("expected %q; got %q", result, r)
}
}

func TestExportSlice(t *testing.T) {
result := `1
2
3
`
var b bytes.Buffer
if err := Export(nil, []string{"1", "2", "3"}, &b); err != nil {
t.Fatal(err)
}
if r := b.String(); r != result {
t.Errorf("expected %q; got %q", result, r)
}
b.Reset()
if err := Export(nil, []int{1, 2, 3}, &b); err != nil {
t.Fatal(err)
}
if r := b.String(); r != result {
t.Errorf("expected %q; got %q", result, r)
}
}
50 changes: 33 additions & 17 deletions csv/reader.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ import (
"fmt"
"io"
"os"
"reflect"
"strings"
"sync"
)
Expand All @@ -14,8 +15,9 @@ type Reader struct {
*csv.Reader
closer io.Closer

once sync.Once
fields []string
once sync.Once
fields []string
hasFields bool

next []string
nextErr error
Expand All @@ -33,6 +35,7 @@ func NewReader(r io.Reader, hasFields bool) (*Reader, error) {
if err != nil {
return nil, err
}
reader.hasFields = true
}
return reader, nil
}
Expand Down Expand Up @@ -75,6 +78,7 @@ func (r *Reader) Read() (record []string, err error) {
// SetFields sets csv fields.
func (r *Reader) SetFields(fields []string) {
r.fields = fields
r.hasFields = true
}

// Next prepares the next record for reading with the Scan or Decode method.
Expand Down Expand Up @@ -106,22 +110,25 @@ func (r *Reader) Scan(dest ...any) error {
// Decode will unmarshal the current record into dest.
// If column's value is like "[...]", it will be treated as slice.
func (r *Reader) Decode(dest any) error {
if len(r.fields) == 0 {
return fmt.Errorf("csv fields is not parsed")
}
if r.next == nil && r.nextErr == nil {
return fmt.Errorf("Decode called without calling Next")
}
if r.nextErr != nil {
return r.nextErr
}
m := make(map[string]string)
for i, field := range r.fields {
if len(r.next) > i {
m[field] = r.next[i]
if r.hasFields {
if len(r.fields) == 0 {
return fmt.Errorf("csv fields is not parsed")
}
if r.next == nil && r.nextErr == nil {
return fmt.Errorf("Decode called without calling Next")
}
if r.nextErr != nil {
return r.nextErr
}
m := make(map[string]string)
for i, field := range r.fields {
if len(r.next) > i {
m[field] = r.next[i]
}
}
return setRow(dest, m)
}
return setRow(dest, m)
return setCell(dest, r.next[0])
}

// Close closes the underlying reader if it implements the io.Closer interface.
Expand All @@ -134,7 +141,16 @@ func (r *Reader) Close() error {

// DecodeAll decodes each record from r into dest.
func DecodeAll[S ~[]E, E any](r io.Reader, dest *S) (err error) {
reader, err := NewReader(r, true)
t := reflect.TypeFor[E]()
for t.Kind() == reflect.Pointer {
t = t.Elem()
}
var reader *Reader
if kind := t.Kind(); kind == reflect.Struct || kind == reflect.Map {
reader, err = NewReader(r, true)
} else {
reader, err = NewReader(r, false)
}
if err != nil {
return
}
Expand Down
21 changes: 21 additions & 0 deletions csv/reader_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -73,3 +73,24 @@ b,2,"[3,4]"
t.Errorf("expected %q; got %q", expect, res)
}
}

func TestDecodeSlice(t *testing.T) {
csv := `1
2
3
`
var s1 []string
if err := DecodeAll(strings.NewReader(csv), &s1); err != nil {
t.Fatal(err)
}
if expect := []string{"1", "2", "3"}; !reflect.DeepEqual(expect, s1) {
t.Errorf("expected %v; got %v", expect, s1)
}
var s2 []int
if err := DecodeAll(strings.NewReader(csv), &s2); err != nil {
t.Fatal(err)
}
if expect := []int{1, 2, 3}; !reflect.DeepEqual(expect, s2) {
t.Errorf("expected %v; got %v", expect, s2)
}
}
24 changes: 18 additions & 6 deletions csv/writer.go
Original file line number Diff line number Diff line change
Expand Up @@ -51,8 +51,8 @@ func (w *Writer) WriteFields(fields any) error {
}
default:
v := reflect.ValueOf(fields)
if v.Kind() == reflect.Pointer {
v = reflect.Indirect(v)
for v.Kind() == reflect.Pointer {
v = v.Elem()
if !v.IsValid() {
return fmt.Errorf("can not get fieldnames from nil pointer struct")
}
Expand Down Expand Up @@ -107,6 +107,8 @@ func (w *Writer) Write(record any) error {
return fmt.Errorf("fieldnames has not be written yet")
}
switch d := record.(type) {
case string:
return w.Writer.Write([]string{d})
case []string:
if len(d) == 0 {
return nil
Expand All @@ -117,15 +119,19 @@ func (w *Writer) Write(record any) error {
if v.Kind() == reflect.Interface {
v = v.Elem()
}
if v.Kind() == reflect.Pointer {
v = reflect.Indirect(v)
for v.Kind() == reflect.Pointer {
v = v.Elem()
if !v.IsValid() {
return nil
}
}
r := w.pool.Get()
defer w.pool.Put(r)
switch v.Kind() {
case reflect.Slice:
for i := range v.Len() {
(*r)[i], _ = marshalText(v.Index(i).Interface())
}
case reflect.Map:
if keyType := reflect.TypeOf(v.Interface()).Key(); keyType.Kind() == reflect.String {
for i, field := range w.fields {
Expand Down Expand Up @@ -163,7 +169,7 @@ func (w *Writer) Write(record any) error {
}
}
default:
return fmt.Errorf("not support record format: %s", v.Kind())
(*r)[0], _ = marshalText(v.Interface())
}
if slices.Equal(*r, w.zero) {
return nil
Expand All @@ -175,9 +181,15 @@ func (w *Writer) Write(record any) error {
// WriteAll writes multiple CSV records to w using Write and then calls Flush, returning any error from the Flush.
func (w *Writer) WriteAll(records any) error {
switch s := records.(type) {
case []string:
for _, i := range s {
if err := w.Writer.Write([]string{i}); err != nil {
return err
}
}
case [][]string:
for _, i := range s {
if err := w.Write(i); err != nil {
if err := w.Writer.Write(i); err != nil {
return err
}
}
Expand Down
Loading