-
Notifications
You must be signed in to change notification settings - Fork 7
/
sqlfile.go
181 lines (151 loc) · 3.04 KB
/
sqlfile.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
// Package sqlfile provides a way to execute sql file easily
//
// For more usage see https://github.com/tanimutomo/sqlfile
package sqlfile
import (
"database/sql"
"fmt"
"io/ioutil"
"strings"
)
// SqlFile collects SQL statements loaded from one or more files so they
// can be executed together.
type SqlFile struct {
	// files holds the paths that were loaded; queries holds the statements
	// extracted from them, in load order.
	files, queries []string
}
// New returns a pointer to an empty SqlFile with no files or queries loaded.
func New() *SqlFile {
	return new(SqlFile)
}
// File loads the queries contained in file and appends them to s.
// The path is recorded only when loading succeeds; on failure s is left
// unchanged and the load error is returned.
func (s *SqlFile) File(file string) error {
	loaded, err := load(file)
	if err == nil {
		s.queries = append(s.queries, loaded...)
		s.files = append(s.files, file)
	}
	return err
}
// Files loads each of the given files in argument order by calling File.
// It stops at the first file that fails and returns that error; files
// loaded before the failure stay in s.
func (s *SqlFile) Files(files ...string) error {
	for i := 0; i < len(files); i++ {
		err := s.File(files[i])
		if err != nil {
			return err
		}
	}
	return nil
}
// Directory loads the queries of every *.sql file found directly inside dir.
//
// Sub-directories are skipped (the walk is not recursive) and entries whose
// name does not end in ".sql" are ignored. Files are loaded in the order
// returned by ioutil.ReadDir. It returns the first error encountered, either
// from reading the directory or from loading one of the files; files loaded
// before the failure stay in s.
func (s *SqlFile) Directory(dir string) error {
	files, err := ioutil.ReadDir(dir)
	if err != nil {
		return err
	}

	for _, file := range files {
		if file.IsDir() {
			continue
		}

		// HasSuffix is safe for names shorter than the suffix (slicing the
		// last three bytes would panic) and requires the dot, so a file
		// called e.g. "notsql" is not picked up.
		name := file.Name()
		if !strings.HasSuffix(name, ".sql") {
			continue
		}

		if err := s.File(dir + "/" + name); err != nil {
			return err
		}
	}

	return nil
}
// Exec executes every loaded query, in load order, inside a single
// transaction on db.
//
// On success it returns one sql.Result per query. If starting the
// transaction or running any query fails, it returns nil results and an
// error; a query failure wraps the driver error and includes the text of
// the offending query. The deferred saveTx call finishes the transaction:
// it rolls back when err is non-nil (or on panic) and commits otherwise.
func (s *SqlFile) Exec(db *sql.DB) (res []sql.Result, err error) {
	tx, err := db.Begin()
	if err != nil {
		return nil, err
	}
	// saveTx receives a pointer to the named result so it sees the final
	// error value at the time Exec returns.
	defer saveTx(tx, &err)

	var rs []sql.Result
	for _, q := range s.queries {
		r, execErr := tx.Exec(q)
		if execErr != nil {
			// The query is passed as an argument, never as the format
			// string: SQL frequently contains '%' (e.g. LIKE '%x%').
			return nil, fmt.Errorf("%w : when executing > %s", execErr, q)
		}
		rs = append(rs, r)
	}

	return rs, nil
}
// load reads the SQL file at path and returns the statements it contains.
//
// Each line is passed through excludeComment, the results are re-joined and
// the text is split on ';'. Every statement is trimmed of surrounding
// whitespace and empty statements are discarded, so a trailing ';' is
// optional and a final statement without one is still returned. The read
// error, if any, is returned with a nil slice.
//
// NOTE(review): the split is purely textual, so a ';' inside a quoted
// literal still ends the statement there.
func load(path string) (qs []string, err error) {
	ls, err := readFileByLine(path)
	if err != nil {
		return nil, err
	}

	ncls := make([]string, 0, len(ls))
	for _, l := range ls {
		ncls = append(ncls, excludeComment(l))
	}

	// Join with "\n" rather than "": the line break is the only separator
	// between the last token of one line and the first token of the next.
	for _, q := range strings.Split(strings.Join(ncls, "\n"), ";") {
		if q = strings.TrimSpace(q); q != "" {
			qs = append(qs, q)
		}
	}

	return qs, nil
}
// readFileByLine reads the whole file at path and returns its content split
// on "\n". A trailing newline in the file therefore yields a final empty
// element. On a read error it returns a nil slice and that error.
func readFileByLine(path string) (ls []string, err error) {
	raw, readErr := ioutil.ReadFile(path)
	if readErr == nil {
		return strings.Split(string(raw), "\n"), nil
	}
	return nil, readErr
}
// excludeComment returns line with a trailing "--" comment removed.
//
// The line is scanned left to right. Text between a pair of matching single
// or double quotes is skipped, so a "--" inside a quoted literal is kept.
// A quote with no matching closing quote on the same line is treated as an
// ordinary character and scanning resumes right after it. The first "--"
// found outside quotes cuts the line; everything before it is returned.
func excludeComment(line string) string {
	pos := 0
	for pos < len(line) {
		switch ch := line[pos]; {
		case ch == '"' || ch == '\'':
			// Jump past the closing quote when there is one.
			if end := strings.IndexByte(line[pos+1:], ch); end >= 0 {
				pos += end + 2
			} else {
				pos++
			}
		case ch == '-' && pos+1 < len(line) && line[pos+1] == '-':
			return line[:pos]
		default:
			pos++
		}
	}
	return line
}
// saveTx finishes tx and is meant to be invoked directly by a defer
// statement (recover only has an effect in that position).
//
//   - If the surrounding function is panicking, tx is rolled back and the
//     panic is re-raised.
//   - If *err is non-nil, tx is rolled back and *err is left untouched.
//   - Otherwise tx is committed and the result of Commit is stored in *err,
//     so a failed commit is reported to the caller.
//
// err must be non-nil; it normally points at the caller's named error result.
func saveTx(tx *sql.Tx, err *error) {
	if p := recover(); p != nil {
		// Best effort: the original panic matters more than a rollback error.
		_ = tx.Rollback()
		panic(p)
	}

	if *err != nil {
		// Best effort: keep the error that caused the rollback.
		_ = tx.Rollback()
		return
	}

	// Write through the pointer; rebinding the local pointer would lose
	// the commit error.
	*err = tx.Commit()
}