Skip to content

Commit

Permalink
Change global variables of default
Browse files Browse the repository at this point in the history
Change global variables of default read/write opts to members
 of trdsql structure.
  • Loading branch information
noborus committed Jun 6, 2019
1 parent af8c0d4 commit 7760062
Show file tree
Hide file tree
Showing 20 changed files with 194 additions and 152 deletions.
4 changes: 2 additions & 2 deletions cmd/cmd.go
Original file line number Diff line number Diff line change
Expand Up @@ -91,8 +91,8 @@ func Run(args []string) int {
flags := flag.NewFlagSet("trdsql", flag.ExitOnError)

tr := trdsql.NewTRDSQL()
ro := trdsql.DefaultReadOpts
wo := trdsql.DefaultWriteOpts
ro := &tr.ReadOpts
wo := &tr.WriteOpts

flags.Usage = func() {
fmt.Fprintf(os.Stderr, `Usage: %s [OPTIONS] [SQL(SELECT...)]
Expand Down
57 changes: 33 additions & 24 deletions csv_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@ func TestCsvInputNew(t *testing.T) {
if err != nil {
t.Error(err)
}
_, err = NewCSVReader(file, DefaultReadOpts)
_, err = NewCSVReader(file, NewReadOpts())
if err != nil {
t.Error(`NewCSVReader error`)
}
Expand All @@ -39,7 +39,7 @@ func TestCsvInputNew(t *testing.T) {
func TestCsvEmptyNew(t *testing.T) {
const csvStream = ``
s := strings.NewReader(csvStream)
r, err := NewCSVReader(s, DefaultReadOpts)
r, err := NewCSVReader(s, NewReadOpts())
if err != nil {
t.Error(err)
}
Expand All @@ -50,39 +50,41 @@ func TestCsvEmptyNew(t *testing.T) {
}

func TestCsvHeaderNew(t *testing.T) {
DefaultReadOpts.InHeader = true
DefaultReadOpts.InDelimiter = ","
ro := NewReadOpts()
ro.InHeader = true
ro.InDelimiter = ","
csvStream := `h1,h2
v1,v2`
s := strings.NewReader(csvStream)
r, _ := NewCSVReader(s, DefaultReadOpts)
r, _ := NewCSVReader(s, ro)
header, _ := r.GetColumn(1)
if header[0] != "h1" || header[1] != "h2" {
t.Error("invalid header")
}
DefaultReadOpts.InHeader = false
}

func TestCsvEmptyColumnHeaderNew(t *testing.T) {
DefaultReadOpts.InHeader = true
DefaultReadOpts.InDelimiter = ","
ro := NewReadOpts()
ro.InHeader = true
ro.InDelimiter = ","
csvStream := `h1,
v1,v2`
s := strings.NewReader(csvStream)
r, _ := NewCSVReader(s, DefaultReadOpts)
r, _ := NewCSVReader(s, ro)
header, _ := r.GetColumn(1)
if header[0] != "h1" || header[1] != "c2" {
t.Error("invalid header")
}
}

func TestCsvEmptyColumnRowNew(t *testing.T) {
DefaultReadOpts.InHeader = true
DefaultReadOpts.InDelimiter = ","
ro := NewReadOpts()
ro.InHeader = true
ro.InDelimiter = ","
csvStream := `h1,h2
,v2`
s := strings.NewReader(csvStream)
r, _ := NewCSVReader(s, DefaultReadOpts)
r, _ := NewCSVReader(s, ro)
_, err := r.GetColumn(0)
if err != nil {
t.Error(err)
Expand All @@ -95,14 +97,15 @@ func TestCsvEmptyColumnRowNew(t *testing.T) {
}

func TestCsvColumnDifferenceNew(t *testing.T) {
DefaultReadOpts.InHeader = true
DefaultReadOpts.InDelimiter = ","
ro := NewReadOpts()
ro.InHeader = true
ro.InDelimiter = ","
csvStream := `h1,h2,h3
v1,v2,v3
x1,x2
z1`
s := strings.NewReader(csvStream)
r, _ := NewCSVReader(s, DefaultReadOpts)
r, _ := NewCSVReader(s, ro)
_, err := r.GetColumn(1)
if err != nil {
t.Error(err)
Expand All @@ -126,20 +129,22 @@ func TestCsvNoInputNew(t *testing.T) {
if err == nil {
t.Error(`Should error`)
}
_, err = NewCSVReader(file, DefaultReadOpts)
_, err = NewCSVReader(file, NewReadOpts())
if err != nil {
t.Error(`NewCSVReader error`)
}
}

func TestCsvIndefiniteInputFile(t *testing.T) {
ro := NewReadOpts()
ro.InHeader = false
ro.InDelimiter = ","

file, err := tableFileOpen("testdata/test_indefinite.csv")
if err != nil {
t.Error(err)
}
DefaultReadOpts.InHeader = false
DefaultReadOpts.InDelimiter = ","
cr, err := NewCSVReader(file, DefaultReadOpts)
cr, err := NewCSVReader(file, ro)
if err != nil {
t.Error(`NewCSVReader error`)
}
Expand All @@ -158,8 +163,10 @@ func TestCsvIndefiniteInputFile2(t *testing.T) {
if err != nil {
t.Error(err)
}
DefaultReadOpts.InDelimiter = ","
cr, err := NewCSVReader(file, DefaultReadOpts)
ro := NewReadOpts()
ro.InHeader = false
ro.InDelimiter = ","
cr, err := NewCSVReader(file, ro)
if err != nil {
t.Error(`NewCSVReader error`)
}
Expand All @@ -177,8 +184,10 @@ func TestCsvIndefiniteInputFile3(t *testing.T) {
if err != nil {
t.Error(err)
}
DefaultReadOpts.InDelimiter = ","
cr, err := NewCSVReader(file, DefaultReadOpts)
ro := NewReadOpts()
ro.InHeader = false
ro.InDelimiter = ","
cr, err := NewCSVReader(file, ro)
if err != nil {
t.Error(`NewCSVReader error`)
}
Expand All @@ -193,7 +202,7 @@ func TestCsvIndefiniteInputFile3(t *testing.T) {
}

func TestCsvOutNew(t *testing.T) {
out := NewCSVWrite(",", false)
out := NewCSVWrite(NewWriteOpts())
if out == nil {
t.Error(`csvOut error`)
}
Expand Down
48 changes: 19 additions & 29 deletions input.go
Original file line number Diff line number Diff line change
Expand Up @@ -22,22 +22,22 @@ type Reader interface {
// Import is parses the SQL statement and imports one or more tables.
// Return the rewritten SQL and error.
// No error is returned if there is no table to import.
func (trdsql *TRDSQL) Import(db *DDB, sqlstr string) (string, error) {
func (trd *TRDSQL) Import(db *DDB, sqlstr string) (string, error) {
tables := listTable(sqlstr)

if len(tables) == 0 {
// without FROM clause. ex. SELECT 1+1;
debug.Printf("table not found\n")
return sqlstr, nil
}

opts := trd.ReadOpts
created := make(map[string]bool)
for _, fileName := range tables {
if created[fileName] {
debug.Printf("already created \"%s\"\n", fileName)
continue
}
tableName, err := trdsql.ImportFile(db, fileName)
tableName, err := ImportFile(db, fileName, opts)
if err != nil {
return sqlstr, err
}
Expand Down Expand Up @@ -133,38 +133,32 @@ func isSQLKeyWords(str string) bool {
// ImportFile is imports a file.
// Return the escaped table name and error.
// Do not import if file not found (no error)
func (trdsql *TRDSQL) ImportFile(db *DDB, fileName string) (string, error) {
file, err := trdsql.importFileOpen(fileName)
func ImportFile(db *DDB, fileName string, opts ReadOpts) (string, error) {
file, err := importFileOpen(fileName)
if err != nil {
debug.Printf("%s\n", err)
return "", nil
}
defer file.Close()
reader, err := trdsql.NewReader(file, fileName)

if opts.InFormat == GUESS {
opts.InFormat = guessExtension(fileName)
}
reader, err := NewReader(file, opts)
if err != nil {
return "", err
}

if DefaultReadOpts.InSkip > 0 {
skip := make([]interface{}, 1)
for i := 0; i < DefaultReadOpts.InSkip; i++ {
r, e := reader.ReadRow(skip)
if e != nil {
log.Printf("ERROR: skip error %s", e)
break
}
debug.Printf("Skip row:%s\n", r)
}
}
tableName := db.EscapeTable(fileName)
columnNames, err := reader.GetColumn(DefaultReadOpts.InPreRead)
columnNames, err := reader.GetColumn(opts.InPreRead)
if err != nil {
if err != io.EOF {
return tableName, err
}
debug.Printf("EOF reached before argument number of rows")
}
columnTypes, err := reader.GetTypes()

if err != nil {
if err != io.EOF {
return tableName, err
Expand All @@ -178,28 +172,24 @@ func (trdsql *TRDSQL) ImportFile(db *DDB, fileName string) (string, error) {
if err != nil {
return tableName, err
}
err = db.Import(tableName, columnNames, reader, DefaultReadOpts.InPreRead)
err = db.Import(tableName, columnNames, reader, opts.InPreRead)
return tableName, err
}

// NewReader returns an Reader interface
// depending on the file to be imported.
func (trdsql *TRDSQL) NewReader(reader io.Reader, fileName string) (Reader, error) {
if DefaultReadOpts.InFormat == GUESS {
DefaultReadOpts.InFormat = guessExtension(fileName)
}

switch DefaultReadOpts.InFormat {
func NewReader(reader io.Reader, opts ReadOpts) (Reader, error) {
switch opts.InFormat {
case CSV:
return NewCSVReader(reader, DefaultReadOpts)
return NewCSVReader(reader, opts)
case LTSV:
return NewLTSVReader(reader)
return NewLTSVReader(reader, opts)
case JSON:
return NewJSONReader(reader)
case TBLN:
return NewTBLNReader(reader)
default:
return nil, fmt.Errorf("unknown formatt")
return nil, fmt.Errorf("unknown format")
}
}

Expand Down Expand Up @@ -232,7 +222,7 @@ func guessExtension(tableName string) Format {
}
}

func (trdsql *TRDSQL) importFileOpen(tableName string) (io.ReadCloser, error) {
func importFileOpen(tableName string) (io.ReadCloser, error) {
r := regexp.MustCompile(`\*|\?|\[`)
if r.MatchString(tableName) {
return globFileOpen(tableName)
Expand Down
19 changes: 17 additions & 2 deletions input_csv.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ import (
"encoding/csv"
"fmt"
"io"
"log"
"strconv"
)

Expand All @@ -16,8 +17,9 @@ type CSVRead struct {
inHeader bool
}

func NewCSVReader(r io.Reader, opts *ReadOpts) (Reader, error) {
func NewCSVReader(r io.Reader, opts ReadOpts) (Reader, error) {
var err error

if opts.InHeader {
opts.InPreRead--
}
Expand All @@ -28,6 +30,19 @@ func NewCSVReader(r io.Reader, opts *ReadOpts) (Reader, error) {
cr.reader.TrimLeadingSpace = true
cr.inHeader = opts.InHeader
cr.reader.Comma, err = delimiter(opts.InDelimiter)

if opts.InSkip > 0 {
skip := make([]interface{}, 1)
for i := 0; i < opts.InSkip; i++ {
r, e := cr.ReadRow(skip)
if e != nil {
log.Printf("ERROR: skip error %s", e)
break
}
debug.Printf("Skip row:%s\n", r)
}
}

return cr, err
}

Expand Down Expand Up @@ -83,7 +98,7 @@ func (cr *CSVRead) GetColumn(rowNum int) ([]string, error) {
func (cr *CSVRead) GetTypes() ([]string, error) {
cr.types = make([]string, len(cr.names))
for i := 0; i < len(cr.names); i++ {
cr.types[i] = "text"
cr.types[i] = DefaultDBType
}
return cr.types, nil
}
Expand Down
14 changes: 7 additions & 7 deletions input_json.go
Original file line number Diff line number Diff line change
Expand Up @@ -53,16 +53,16 @@ func (jr *JSONRead) GetColumn(rowNum int) ([]string, error) {
func (jr *JSONRead) GetTypes() ([]string, error) {
jr.types = make([]string, len(jr.names))
for i := 0; i < len(jr.names); i++ {
jr.types[i] = "text"
jr.types[i] = DefaultDBType
}
return jr.types, nil
}

func (jr *JSONRead) readAhead(top interface{}, rcount int) (map[string]string, []string, error) {
func (jr *JSONRead) readAhead(top interface{}, count int) (map[string]string, []string, error) {
if jr.inArray != nil {
if len(jr.inArray) > rcount {
if len(jr.inArray) > count {
jr.count++
return jr.secondLevel(top, jr.inArray[rcount])
return jr.secondLevel(top, jr.inArray[count])
}
return nil, nil, io.EOF
}
Expand Down Expand Up @@ -173,18 +173,18 @@ func (jr *JSONRead) ReadRow(row []interface{}) ([]interface{}, error) {
var data interface{}
err := jr.reader.Decode(&data)
if err != nil {
return nil, err
return nil, fmt.Errorf("json format error:%s", err)
}
row = jr.rowParse(row, data)
}
return row, nil
}

func (jr *JSONRead) rowParse(row []interface{}, jsonRow interface{}) []interface{} {
switch dmap := jsonRow.(type) {
switch m := jsonRow.(type) {
case map[string]interface{}:
for i := range jr.names {
row[i] = jsonString(dmap[jr.names[i]])
row[i] = jsonString(m[jr.names[i]])
}
default:
for i := range jr.names {
Expand Down
Loading

0 comments on commit 7760062

Please sign in to comment.