-
Notifications
You must be signed in to change notification settings - Fork 12
/
join.go
352 lines (303 loc) · 8.3 KB
/
join.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
package datatable
import (
"regexp"
"strings"
"github.com/pkg/errors"
)
// InnerJoin selects records that have matching values in both tables.
// The receiver (left) is the first table of the join; the result is named
// after it.
// <!> InnerJoin transforms an expr column to a raw column
func (left *DataTable) InnerJoin(right *DataTable, on []JoinOn) (*DataTable, error) {
	return InnerJoin([]*DataTable{left, right}, on)
}

// InnerJoin selects records that have matching values in all the tables.
// The tables are joined pairwise from left to right (tables[0] with
// tables[1], that result with tables[2], ...), and the result is named after
// tables[0]. At least two tables and one ON clause are required.
func InnerJoin(tables []*DataTable, on []JoinOn) (*DataTable, error) {
	impl := newJoinImpl(innerJoin, tables, on)
	return impl.Compute()
}
// LeftJoin returns all records from the left table (the receiver) and the
// matched records from the right table.
// Right-side values are NULL (unset) when there is no match.
// <!> LeftJoin transforms an expr column to a raw column
func (left *DataTable) LeftJoin(right *DataTable, on []JoinOn) (*DataTable, error) {
	return LeftJoin([]*DataTable{left, right}, on)
}

// LeftJoin left-joins the tables pairwise from left to right (tables[0] with
// tables[1], that result with tables[2], ...). The result is named after
// tables[0]. At least two tables and one ON clause are required.
func LeftJoin(tables []*DataTable, on []JoinOn) (*DataTable, error) {
	impl := newJoinImpl(leftJoin, tables, on)
	return impl.Compute()
}
// RightJoin returns all records from the right table and the matched records
// from the left table (the receiver).
// Left-side values are NULL (unset) when there is no match.
// <!> RightJoin transforms an expr column to a raw column
func (left *DataTable) RightJoin(right *DataTable, on []JoinOn) (*DataTable, error) {
	return RightJoin([]*DataTable{left, right}, on)
}

// RightJoin right-joins the tables pairwise from left to right (tables[0]
// with tables[1], that result with tables[2], ...). The result is named after
// tables[0]. At least two tables and one ON clause are required.
func RightJoin(tables []*DataTable, on []JoinOn) (*DataTable, error) {
	impl := newJoinImpl(rightJoin, tables, on)
	return impl.Compute()
}
// OuterJoin returns all records from both tables, combining rows that match
// and keeping unmatched rows from either side.
// <!> OuterJoin transforms an expr column to a raw column
func (left *DataTable) OuterJoin(right *DataTable, on []JoinOn) (*DataTable, error) {
	return OuterJoin([]*DataTable{left, right}, on)
}

// OuterJoin outer-joins the tables pairwise from left to right (tables[0]
// with tables[1], that result with tables[2], ...). The result is named after
// tables[0]. At least two tables and one ON clause are required.
func OuterJoin(tables []*DataTable, on []JoinOn) (*DataTable, error) {
	impl := newJoinImpl(outerJoin, tables, on)
	return impl.Compute()
}
// JoinOn names a field used to match rows between joined tables.
// Table restricts the field to the table with that name; "*" (or an empty
// string) applies it to every joined table.
type JoinOn struct {
	Table string
	Field string
}

// rgOn parses "[table].[field]" or "[field]".
var rgOn = regexp.MustCompile(`^(?:\[([^]]+)\]\.)?(?:\[([^]]+)\])$`)

// On builds "join on" clauses, as in SQL
// SELECT * FROM A INNER JOIN B ON B.id = A.user_id.
//
// Each argument is either "[table].[field]", which restricts the field to the
// table of that name, or "[field]" / a bare field name, which applies to every
// joined table. A string that does not match the bracketed syntax is used
// verbatim as the field name.
func On(fields ...string) []JoinOn {
	var clauses []JoinOn
	for _, f := range fields {
		m := rgOn.FindStringSubmatch(f)
		if m == nil {
			clauses = append(clauses, JoinOn{Table: "*", Field: f})
			continue
		}
		if len(m) != 3 {
			return nil
		}
		table := m[1]
		if table == "" {
			table = "*"
		}
		clauses = append(clauses, JoinOn{Table: table, Field: m[2]})
	}
	return clauses
}

// Using builds "join using" clauses, as in SQL
// SELECT * FROM A INNER JOIN B USING 'field'.
// Every field applies to all joined tables. It returns nil when no field is
// given.
func Using(fields ...string) []JoinOn {
	if len(fields) == 0 {
		return nil
	}
	clauses := make([]JoinOn, len(fields))
	for i, f := range fields {
		clauses[i] = JoinOn{Table: "*", Field: f}
	}
	return clauses
}
// joinType selects which rows joinImpl.join emits.
type joinType uint8

// Join kinds; values match the original iota sequence.
const (
	innerJoin joinType = 0 // only rows matching on both sides
	leftJoin  joinType = 1 // matches plus unmatched left rows
	rightJoin joinType = 2 // matches plus unmatched right rows
	outerJoin joinType = 3 // matches plus unmatched rows from both sides
)
// colname returns the table-qualified column name "<table name>.<col>",
// used for column names shared between joined tables.
func colname(dt *DataTable, col string) string {
	return strings.Join([]string{dt.Name(), col}, ".")
}
// joinClause describes one side of a pairwise join step (see joinImpl.join).
type joinClause struct {
	// table provides this side's rows and columns.
	table *DataTable
	// mcols maps a column name to the names of all joined tables owning a
	// column with that name (built by joinImpl.initColMapper); names owned by
	// more than one table are prefixed with the table name in the output.
	mcols map[string][]string
	// on holds this side's ON field names, in the order of the JoinOn list.
	on []string
	// includeOnCols tells copyColumnsTo whether this side's ON columns are
	// copied to the output (join sets it only for the left side).
	includeOnCols bool
	// cmapper pairs each copied column's name in table with its name in the
	// output.
	cmapper [][2]string // [initial, output]
	// hashtable maps the hash of a row's ON values to the indexes of the rows
	// of table having that hash; built by initHashTable.
	hashtable map[uint64][]int
	// consumed records the row indexes of table that matched at least one
	// reference row during the join.
	consumed map[int]bool
}
// copyColumnsTo adds to out an empty copy of every column of jc.table that
// takes part in the join, and records each copied column's
// [source name, output name] pair in jc.cmapper.
//
// ON columns are copied (under their own name) only when jc.includeOnCols is
// set. Any other column whose name is owned by more than one joined table,
// according to jc.mcols, is renamed "<table>.<column>" when jc.table is among
// those owners.
//
// It returns ErrNilOutputDatatable if out is nil, or the first error from
// out.addColumn.
func (jc *joinClause) copyColumnsTo(out *DataTable) error {
	if out == nil {
		return ErrNilOutputDatatable
	}

	isOn := make(map[string]bool, len(jc.on))
	for _, field := range jc.on {
		isOn[field] = true
	}

	for _, col := range jc.table.cols {
		src, dst := col.name, col.name
		switch {
		case isOn[src]:
			if !jc.includeOnCols {
				continue
			}
		case len(jc.mcols[src]) > 1:
			// name shared between tables: qualify it with this table's name
			for _, owner := range jc.mcols[src] {
				if owner == jc.table.name {
					dst = colname(jc.table, src)
					break
				}
			}
		}

		cp := col.emptyCopy()
		cp.name = dst
		if err := out.addColumn(cp); err != nil {
			return err
		}
		jc.cmapper = append(jc.cmapper, [2]string{src, dst})
	}
	return nil
}
// initHashTable indexes jc.table's rows by the hash of their ON fields and
// resets the set of consumed (matched) row indexes.
func (jc *joinClause) initHashTable() {
	index := hasher.Table(jc.table, jc.on)
	jc.hashtable = index
	rows := jc.table.NumRows()
	jc.consumed = make(map[int]bool, rows)
}
// joinImpl holds the parameters and state of one join request, built by
// newJoinImpl and run by Compute.
type joinImpl struct {
	// mode selects inner, left, right or outer join semantics.
	mode joinType
	// tables are joined pairwise from left to right; tables[0] starts the
	// accumulated result (see Compute).
	tables []*DataTable
	// on lists the fields rows are matched on.
	on []JoinOn
	// NOTE(review): clauses is not read or written by the code visible in
	// this file; confirm whether it is still needed.
	clauses []*joinClause
	// mcols maps each column name to the names of the tables owning such a
	// column; filled by initColMapper.
	mcols map[string][]string
}
// newJoinImpl prepares a join of the given kind over tables, matched on the
// given clauses. Nothing is validated until Compute is called.
func newJoinImpl(mode joinType, tables []*DataTable, on []JoinOn) *joinImpl {
	impl := new(joinImpl)
	impl.mode = mode
	impl.tables = tables
	impl.on = on
	return impl
}
// Compute validates the request, then folds the tables pairwise from left to
// right: tables[0] is joined with tables[1], that result with tables[2], and
// so on. It returns the final table, or the first error encountered.
func (j *joinImpl) Compute() (*DataTable, error) {
	if err := j.checkInput(); err != nil {
		return nil, err
	}
	j.initColMapper()

	acc := j.tables[0]
	for _, next := range j.tables[1:] {
		joined, err := j.join(acc, next)
		if err != nil {
			return nil, err
		}
		acc = joined
	}

	if acc == nil {
		return nil, ErrNoOutput
	}
	return acc, nil
}
// checkInput validates the join request. It fails when:
//   - fewer than two tables are given (ErrNotEnoughDatatables);
//   - a table is nil, unnamed or has no columns (wrapped ErrNilTable);
//   - there are no ON clauses (ErrNoOnClauses);
//   - an ON clause has an empty field (wrapped ErrOnClauseIsNil).
func (j *joinImpl) checkInput() error {
	if len(j.tables) < 2 {
		return ErrNotEnoughDatatables
	}
	for idx := range j.tables {
		if t := j.tables[idx]; t != nil && len(t.Name()) != 0 && t.NumCols() != 0 {
			continue
		}
		return errors.Wrap(errors.Errorf("table #%d is nil", idx), ErrNilTable.Error())
	}

	if len(j.on) == 0 {
		return ErrNoOnClauses
	}
	for idx, clause := range j.on {
		if clause.Field == "" {
			return errors.Wrap(errors.Errorf("on #%d is nil", idx), ErrOnClauseIsNil.Error())
		}
	}
	return nil
}
// initColMapper records, for every column name, the names of the input
// tables owning a column with that name (one entry per occurrence). It is
// used to detect names shared between tables.
func (j *joinImpl) initColMapper() {
	owners := map[string][]string{}
	for ti := range j.tables {
		t := j.tables[ti]
		for ci := range t.cols {
			name := t.cols[ci].name
			owners[name] = append(owners[name], t.Name())
		}
	}
	j.mcols = owners
}
// join performs one pairwise join step between left and right according to
// j.mode and j.on, and returns a new DataTable named after left.
//
// ON clauses are dispatched per side: a JoinOn whose Table equals left.Name()
// (resp. right.Name()) applies to that side only, while "*" or an empty Table
// applies to both. Rows are matched by comparing hashes of their ON values
// (hasher.Row against hasher.Table). NOTE(review): pairing the i-th ON field
// of one side with the i-th of the other assumes the hasher combines fields
// in list order; confirm.
//
// Output columns are left's columns (ON columns included) followed by right's
// columns minus its ON columns; see copyColumnsTo for the renaming of names
// shared between tables.
//
// Rows emitted per mode:
//   - innerJoin: one row per matching (left, right) pair;
//   - leftJoin:  matching pairs, plus every unmatched left row;
//   - rightJoin: matching pairs (driven by right), plus every unmatched right row;
//   - outerJoin: leftJoin's rows, plus every right row that matched nothing.
//
// Because right's ON columns are not emitted, rows coming from right alone
// get left's ON output columns filled from their own ON values, so the join
// key is not lost. Hidden columns are exported on every path.
func (j *joinImpl) join(left, right *DataTable) (*DataTable, error) {
	if left == nil {
		err := errors.New("left is nil datatable")
		return nil, errors.Wrap(err, ErrNilDatatable.Error())
	}
	if right == nil {
		err := errors.New("right is nil datatable")
		return nil, errors.Wrap(err, ErrNilDatatable.Error())
	}

	// Only the left side emits its ON columns, so each key appears once.
	clauses := [2]*joinClause{
		{table: left, mcols: j.mcols, includeOnCols: true},
		{table: right, mcols: j.mcols},
	}

	// Dispatch each ON clause to the side(s) it applies to.
	for _, o := range j.on {
		switch o.Table {
		case left.Name():
			clauses[0].on = append(clauses[0].on, o.Field)
		case right.Name():
			clauses[1].on = append(clauses[1].on, o.Field)
		case "*", "":
			clauses[0].on = append(clauses[0].on, o.Field)
			clauses[1].on = append(clauses[1].on, o.Field)
		}
	}

	out := New(left.Name())
	for _, clause := range clauses {
		if err := clause.copyColumnsTo(out); err != nil {
			return nil, err
		}
	}

	// ref drives the iteration; join is probed through its hash table.
	var ref, join *joinClause
	switch j.mode {
	case innerJoin, leftJoin, outerJoin:
		ref, join = clauses[0], clauses[1]
	case rightJoin:
		ref, join = clauses[1], clauses[0]
	default:
		err := errors.Errorf("unknown mode '%v'", j.mode)
		return nil, errors.Wrap(err, ErrUnknownMode.Error())
	}
	join.initHashTable()

	// [right ON field, output column] pairs used to restore the join key on
	// rows that come from right alone.
	keys := joinKeyMapper(clauses[0], clauses[1])

	for _, refrow := range ref.table.Rows(ExportHidden(true)) {
		hash := hasher.Row(refrow, ref.on)
		if indexes, ok := join.hashtable[hash]; ok {
			for _, idx := range indexes {
				joinrow := join.table.Row(idx, ExportHidden(true))
				row := out.NewRow()
				for _, cm := range ref.cmapper {
					row[cm[1]] = refrow.Get(cm[0])
				}
				for _, cm := range join.cmapper {
					row[cm[1]] = joinrow.Get(cm[0])
				}
				join.consumed[idx] = true
				out.Append(row)
			}
			continue
		}
		if j.mode == innerJoin {
			continue
		}
		// Unmatched reference row: the other side's columns stay unset.
		row := make(Row, len(refrow))
		for _, cm := range ref.cmapper {
			row[cm[1]] = refrow.Get(cm[0])
		}
		if ref == clauses[1] {
			// rightJoin: right's ON columns were not emitted, carry its keys over.
			for _, km := range keys {
				row[km[1]] = refrow.Get(km[0])
			}
		}
		out.Append(row)
	}

	// Outer: also copy the rows of the right (join) table that were not consumed.
	if j.mode == outerJoin {
		for i, joinrow := range join.table.Rows(ExportHidden(true)) {
			if join.consumed[i] {
				continue
			}
			row := make(Row, len(joinrow))
			for _, cm := range join.cmapper {
				row[cm[1]] = joinrow.Get(cm[0])
			}
			for _, km := range keys {
				row[km[1]] = joinrow.Get(km[0])
			}
			out.Append(row)
		}
	}
	return out, nil
}

// joinKeyMapper pairs each ON field of right with the output column under
// which left emitted the ON field at the same position, returning
// [right ON field, output column] pairs. Positions past the shorter ON list,
// and ON fields left did not emit (e.g. its table lacks that column), are
// skipped.
func joinKeyMapper(left, right *joinClause) [][2]string {
	emitted := make(map[string]string, len(left.cmapper))
	for _, cm := range left.cmapper {
		emitted[cm[0]] = cm[1]
	}
	n := len(left.on)
	if len(right.on) < n {
		n = len(right.on)
	}
	keys := make([][2]string, 0, n)
	for i := 0; i < n; i++ {
		if dst, ok := emitted[left.on[i]]; ok {
			keys = append(keys, [2]string{right.on[i], dst})
		}
	}
	return keys
}