diff --git a/sqli.go b/sqli.go index 8cbf645..6ddee53 100644 --- a/sqli.go +++ b/sqli.go @@ -1,7 +1,6 @@ package libinjection import ( - "bytes" "strings" ) @@ -24,7 +23,7 @@ type sqliState struct { current *sqliToken // fingerprint pattern c-string, +1 form ending null - fingerprint []byte + fingerprint string // |----------------------------------------| // | |/**/ |--[start] |# | @@ -84,7 +83,7 @@ func sqliInit(s *sqliState, input string, flags int) { // single quote. // ByteDouble ("), process pretending input started with a // double quote. -func (s *sqliState) sqliFingerprint(flags int) []byte { +func (s *sqliState) sqliFingerprint(flags int) string { s.reset(flags) length := s.fold() @@ -103,25 +102,28 @@ func (s *sqliState) sqliFingerprint(flags int) []byte { s.tokenVec[length-1].category = sqliTokenTypeComment } + fp := strings.Builder{} + for i := 0; i < length; i++ { - s.fingerprint = append(s.fingerprint, s.tokenVec[i].category) - } + c := s.tokenVec[i].category + // check for 'X' in pattern, and then + // clear out all tokens + // + // this means parsing could not be done + // accurately due to pgsql's double comments + // or other syntax that isn't consistent. + // Should be very rare false positive + if c == sqliTokenTypeEvil { + s.fingerprint = string(sqliTokenTypeEvil) + s.tokenVec[0].category = sqliTokenTypeEvil + s.tokenVec[0].val = string(sqliTokenTypeEvil) + return s.fingerprint + } - // check for 'X' in pattern, and then - // clear out all tokens - // - // this means parsing could not be done - // accurately due to pgsql's double comments - // or other syntax that isn't consistent. - // Should be very rare false positive - if bytes.ContainsAny(s.fingerprint, string(sqliTokenTypeEvil)) { - s.fingerprint = s.fingerprint[:0] - s.fingerprint = append(s.fingerprint, sqliTokenTypeEvil) - - s.tokenVec[0].category = sqliTokenTypeEvil - s.tokenVec[0].val = [32]byte{sqliTokenTypeEvil} + fp.WriteByte(c) } + s.fingerprint = fp.String() return s.fingerprint } @@ -163,16 +165,10 @@ func (s *sqliState) merge(tokenA, tokenB *sqliToken) bool { return false } - // oddly annoying last.val + ' ' + current.val - var tmp [tokenSize]byte - copy(tmp[:], tokenA.val[:tokenA.len]) - tmp[tokenA.len] = ' ' - copy(tmp[tokenA.len+1:], tokenB.val[:tokenB.len]) - - length := tokenA.len + tokenB.len + 1 - ch := s.lookupWord(sqliLookupWord, tmp[:length]) + tmp := tokenA.val[:tokenA.len] + " " + tokenB.val[:tokenB.len] + ch := s.lookupWord(sqliLookupWord, tmp) if ch != byteNull { - tokenA.assign(ch, tokenA.pos, length, string(tmp[:length])) + tokenA.assign(ch, tokenA.pos, len(tmp), tmp) return true } return false @@ -316,32 +312,32 @@ func (s *sqliState) fold() int { case (s.tokenVec[left].category == sqliTokenTypeBareWord || s.tokenVec[left].category == sqliTokenTypeVariable) && s.tokenVec[left+1].category == sqliTokenTypeLeftParenthesis && ( // TSQL functions but common enough to be column names - toUpperCmp("USER_ID", string(s.tokenVec[left].val[:s.tokenVec[left].len])) || - toUpperCmp("USER_NAME", string(s.tokenVec[left].val[:s.tokenVec[left].len])) || + toUpperCmp("USER_ID", s.tokenVec[left].val[:s.tokenVec[left].len]) || + toUpperCmp("USER_NAME", s.tokenVec[left].val[:s.tokenVec[left].len]) || // Function in MySQL - toUpperCmp("DATABASE", string(s.tokenVec[left].val[:s.tokenVec[left].len])) || - toUpperCmp("PASSWORD", string(s.tokenVec[left].val[:s.tokenVec[left].len])) || - toUpperCmp("USER", string(s.tokenVec[left].val[:s.tokenVec[left].len])) || + toUpperCmp("DATABASE", s.tokenVec[left].val[:s.tokenVec[left].len]) || + toUpperCmp("PASSWORD", s.tokenVec[left].val[:s.tokenVec[left].len]) || + toUpperCmp("USER", s.tokenVec[left].val[:s.tokenVec[left].len]) || // MySQL words that act as a variable and are a function // TSQL current_users is fake_variable // http://msdn.microsoft.com/en-us/library/ms176050.aspx - toUpperCmp("CURRENT_USER", string(s.tokenVec[left].val[:s.tokenVec[left].len])) || - toUpperCmp("CURRENT_DATE", string(s.tokenVec[left].val[:s.tokenVec[left].len])) || - toUpperCmp("CURRENT_TIME", string(s.tokenVec[left].val[:s.tokenVec[left].len])) || - toUpperCmp("CURRENT_TIMESTAMP", string(s.tokenVec[left].val[:s.tokenVec[left].len])) || - toUpperCmp("LOCALTIME", string(s.tokenVec[left].val[:s.tokenVec[left].len])) || - toUpperCmp("LOCALTIMESTAMP", string(s.tokenVec[left].val[:s.tokenVec[left].len]))): + toUpperCmp("CURRENT_USER", s.tokenVec[left].val[:s.tokenVec[left].len]) || + toUpperCmp("CURRENT_DATE", s.tokenVec[left].val[:s.tokenVec[left].len]) || + toUpperCmp("CURRENT_TIME", s.tokenVec[left].val[:s.tokenVec[left].len]) || + toUpperCmp("CURRENT_TIMESTAMP", s.tokenVec[left].val[:s.tokenVec[left].len]) || + toUpperCmp("LOCALTIME", s.tokenVec[left].val[:s.tokenVec[left].len]) || + toUpperCmp("LOCALTIMESTAMP", s.tokenVec[left].val[:s.tokenVec[left].len])): // pos is the same // other conversions need to go here... for instance // password CAN be a function, coalesce CAN be a funtion s.tokenVec[left].category = sqliTokenTypeFunction continue case s.tokenVec[left].category == sqliTokenTypeKeyword && - (toUpperCmp("IN", string(s.tokenVec[left].val[:s.tokenVec[left].len])) || - toUpperCmp("NOT IN", string(s.tokenVec[left].val[:s.tokenVec[left].len]))): + (toUpperCmp("IN", s.tokenVec[left].val[:s.tokenVec[left].len]) || + toUpperCmp("NOT IN", s.tokenVec[left].val[:s.tokenVec[left].len])): if s.tokenVec[left+1].category == sqliTokenTypeLeftParenthesis { // got ... IN ( ... (or 'NOT IN') // it's an operator @@ -362,8 +358,8 @@ func (s *sqliState) fold() int { // "foo" = LIKE(1,2) continue case s.tokenVec[left].category == sqliTokenTypeOperator && - (toUpperCmp("LIKE", string(s.tokenVec[left].val[:s.tokenVec[left].len])) || - toUpperCmp("NOT LIKE", string(s.tokenVec[left].val[:s.tokenVec[left].len]))): + (toUpperCmp("LIKE", s.tokenVec[left].val[:s.tokenVec[left].len]) || + toUpperCmp("NOT LIKE", s.tokenVec[left].val[:s.tokenVec[left].len])): if s.tokenVec[left+1].category == sqliTokenTypeLeftParenthesis { // SELECT LIKE(... // it's a function @@ -385,7 +381,7 @@ func (s *sqliState) fold() int { case s.tokenVec[left].category == sqliTokenTypeCollate && s.tokenVec[left+1].category == sqliTokenTypeBareWord: // there are too many collation types.. so if the bareword has a "_" // then it's TYPE_SQLTYPE - if bytes.ContainsRune(s.tokenVec[left+1].val[:], '_') { + if strings.ContainsRune(s.tokenVec[left+1].val[:], '_') { s.tokenVec[left+1].category = sqliTokenTypeSQLType left = 0 } @@ -514,7 +510,7 @@ func (s *sqliState) fold() int { s.tokenVec[left].category == sqliTokenTypeVariable || s.tokenVec[left].category == sqliTokenTypeString) && s.tokenVec[left+1].category == sqliTokenTypeOperator && - string(s.tokenVec[left+1].val[:s.tokenVec[left+1].len]) == "::" && + s.tokenVec[left+1].val[:s.tokenVec[left+1].len] == "::" && s.tokenVec[left+2].category == sqliTokenTypeSQLType: pos -= 2 left = 0 @@ -607,7 +603,7 @@ func (s *sqliState) fold() int { // if we get User(foo), then User is not a function // This should be expanded since it eliminated a lot of false // positives. - if toUpperCmp("USER", string(s.tokenVec[left].val[:s.tokenVec[left].len])) { + if toUpperCmp("USER", s.tokenVec[left].val[:s.tokenVec[left].len]) { s.tokenVec[left].category = sqliTokenTypeBareWord } } @@ -668,23 +664,25 @@ func (s *sqliState) tokenize() bool { // // return TRUE if SQLi, false otherwise func (s *sqliState) blacklist() bool { - var fp []byte length := len(s.fingerprint) if length < 1 { return false } - fp = append(fp, '0') + fp := strings.Builder{} + fp.Grow(length + 1) + + fp.WriteByte('0') for i := 0; i < length; i++ { ch := s.fingerprint[i] if ch >= 'a' && ch <= 'z' { ch -= 0x20 } - fp = append(fp, ch) + fp.WriteByte(ch) } - return isKeyword(fp) == sqliTokenTypeFingerprint + return isKeyword(fp.String()) == sqliTokenTypeFingerprint } // Given a positive match for a pattern (i.e. pattern is SQLi), this function @@ -789,8 +787,8 @@ func (s *sqliState) notWhitelist() bool { // no opening quote, no closing quote // and each string has data // sos || s&s are string and operator || logic operator and string - switch { - case string(s.fingerprint) == "sos" || string(s.fingerprint) == "s&s": + switch s.fingerprint { + case "sos", "s&s": if s.tokenVec[0].strOpen == byteNull && s.tokenVec[2].strClose == byteNull && s.tokenVec[0].strClose == s.tokenVec[2].strOpen { @@ -803,22 +801,17 @@ func (s *sqliState) notWhitelist() bool { } return false - case string(s.fingerprint) == "s&n" || - string(s.fingerprint) == "n&1" || - string(s.fingerprint) == "1&1" || - string(s.fingerprint) == "1&v" || - string(s.fingerprint) == "1&s": + case "s&n", "n&1", "1&1", "1&v", "1&s": // 'sexy and 17' not SQLi // 'sexy and 17<18' SQLi if s.statsTokens == 3 { return false } - case s.tokenVec[1].category == sqliTokenTypeKeyword: - if s.tokenVec[1].len < 5 || !toUpperCmp("INTO", string(s.tokenVec[1].val[:4])) { - // if it's not "INTO OUTFILE", or "INTO DUMPFILE" (MySQL) - // then treat as safe - return false - } + } + if s.tokenVec[1].category == sqliTokenTypeKeyword && (s.tokenVec[1].len < 5 || !toUpperCmp("INTO", s.tokenVec[1].val[:4])) { + // if it's not "INTO OUTFILE", or "INTO DUMPFILE" (MySQL) + // then treat as safe + return false } } @@ -829,7 +822,7 @@ func (s *sqliState) checkFingerprint() bool { return s.blacklist() && s.notWhitelist() } -func (s *sqliState) lookupWord(lookupType int, word []byte) byte { +func (s *sqliState) lookupWord(lookupType int, word string) byte { if lookupType == sqliLookupFingerprint { if s.checkFingerprint() { return 'X' @@ -897,12 +890,12 @@ func (s *sqliState) check() bool { // IsSQLi returns true if the input is SQLi // It also returns the fingerprint of the SQL Injection as []byte -func IsSQLi(input string) (bool, []byte) { +func IsSQLi(input string) (bool, string) { state := new(sqliState) sqliInit(state, input, 0) result := state.check() if result { return result, state.fingerprint } - return result, []byte{} + return result, "" } diff --git a/sqli_helpers.go b/sqli_helpers.go index aa3d197..1657b64 100644 --- a/sqli_helpers.go +++ b/sqli_helpers.go @@ -111,12 +111,12 @@ func toUpperCmp(a, b string) bool { return a == strings.ToUpper(b) } -func isKeyword(key []byte) byte { +func isKeyword(key string) byte { return searchKeyword(key, sqlKeywords) } -func searchKeyword(key []byte, keywords map[string]byte) byte { - upperKey := strings.ToUpper(string(key)) +func searchKeyword(key string, keywords map[string]byte) byte { + upperKey := strings.ToUpper(key) if val, ok := keywords[upperKey]; ok { return val diff --git a/sqli_parse.go b/sqli_parse.go index 633e4f1..77b8ec3 100644 --- a/sqli_parse.go +++ b/sqli_parse.go @@ -21,7 +21,7 @@ func parseEolComment(s *sqliState) int { func parseMoney(s *sqliState) int { if s.pos+1 == s.length { - s.current.assignByte(sqliTokenTypeBareWord, s.pos, 1, '$') + s.current.assign(sqliTokenTypeBareWord, s.pos, 1, "$") return s.length } @@ -48,14 +48,14 @@ func parseMoney(s *sqliState) int { xlen := strLenSpn(s.input[s.pos+1:], s.length-s.pos-1, "abcdefghjiklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ") if xlen == 0 { // hmm, it's "$" _something_ .. just add $ and keep going - s.current.assignByte(sqliTokenTypeBareWord, s.pos, 1, '$') + s.current.assign(sqliTokenTypeBareWord, s.pos, 1, "$") return s.pos + 1 } // we have $foobar????? if s.pos+xlen+1 == s.length || s.input[s.pos+xlen+1] != '$' { // not $foobar$, or fell off edge - s.current.assignByte(sqliTokenTypeBareWord, s.pos, 1, '$') + s.current.assign(sqliTokenTypeBareWord, s.pos, 1, "$") return s.pos + 1 } @@ -81,7 +81,7 @@ func parseMoney(s *sqliState) int { } func parseOther(s *sqliState) int { - s.current.assignByte(sqliTokenTypeUnknown, s.pos, 1, s.input[s.pos]) + s.current.assign(sqliTokenTypeUnknown, s.pos, 1, s.input[s.pos:]) return s.pos + 1 } @@ -90,12 +90,12 @@ func parseWhite(s *sqliState) int { } func parseOperator1(s *sqliState) int { - s.current.assignByte(sqliTokenTypeOperator, s.pos, 1, s.input[s.pos]) + s.current.assign(sqliTokenTypeOperator, s.pos, 1, s.input[s.pos:]) return s.pos + 1 } func parseByte(s *sqliState) int { - s.current.assignByte(s.input[s.pos], s.pos, 1, s.input[s.pos]) + s.current.assign(s.input[s.pos], s.pos, 1, s.input[s.pos:]) return s.pos + 1 } @@ -107,7 +107,7 @@ func parseHash(s *sqliState) int { s.statsCommentHash++ return parseEolComment(s) } - s.current.assignByte(sqliTokenTypeOperator, s.pos, 1, '#') + s.current.assign(sqliTokenTypeOperator, s.pos, 1, "#") return s.pos + 1 } @@ -128,7 +128,7 @@ func parseDash(s *sqliState) int { s.statsCommentDDX++ return parseEolComment(s) default: - s.current.assignByte(sqliTokenTypeOperator, s.pos, 1, '-') + s.current.assign(sqliTokenTypeOperator, s.pos, 1, "-") return s.pos + 1 } } @@ -174,7 +174,7 @@ func parseBackSlash(s *sqliState) int { s.current.assign(sqliTokenTypeNumber, s.pos, 2, s.input[s.pos:]) return s.pos + 2 } - s.current.assignByte(sqliTokenTypeBackslash, s.pos, 1, s.input[s.pos]) + s.current.assign(sqliTokenTypeBackslash, s.pos, 1, s.input[s.pos:]) return s.pos + 1 } @@ -189,7 +189,7 @@ func parseOperator2(s *sqliState) int { return s.pos + 3 } - ch := s.lookupWord(sqliLookupOperator, []byte(s.input[s.pos:s.pos+2])) + ch := s.lookupWord(sqliLookupOperator, s.input[s.pos:s.pos+2]) if ch != byteNull { s.current.assign(ch, s.pos, 2, s.input[s.pos:]) return s.pos + 2 @@ -321,7 +321,7 @@ func parseNumber(s *sqliState) int { if pos-start == 1 { // only one character read so far - s.current.assignByte(sqliTokenTypeDot, start, 1, '.') + s.current.assign(sqliTokenTypeDot, start, 1, ".") return pos } } diff --git a/sqli_test.go b/sqli_test.go index 19a8ef2..89ae860 100644 --- a/sqli_test.go +++ b/sqli_test.go @@ -195,6 +195,7 @@ func BenchmarkSQLiDriver(b *testing.B) { } b.Run("sqli", func(b *testing.B) { + b.ReportAllocs() for i := 0; i < b.N; i++ { for _, tc := range cases.sqli { tt := tc @@ -204,6 +205,7 @@ func BenchmarkSQLiDriver(b *testing.B) { }) b.Run("folding", func(b *testing.B) { + b.ReportAllocs() for i := 0; i < b.N; i++ { for _, tc := range cases.folding { tt := tc @@ -213,6 +215,7 @@ func BenchmarkSQLiDriver(b *testing.B) { }) b.Run("tokens", func(b *testing.B) { + b.ReportAllocs() for i := 0; i < b.N; i++ { for _, tc := range cases.tokens { tt := tc diff --git a/sqli_token.go b/sqli_token.go index 2b5da28..20504af 100644 --- a/sqli_token.go +++ b/sqli_token.go @@ -13,7 +13,7 @@ type sqliToken struct { category byte strOpen byte strClose byte - val [32]byte + val string } const ( @@ -82,16 +82,7 @@ func (t *sqliToken) assign(tokenType byte, pos, length int, value string) { t.category = tokenType t.pos = pos t.len = last - for i := 0; i < last; i++ { - t.val[i] = value[i] - } -} - -func (t *sqliToken) assignByte(tokenType byte, pos, len int, value byte) { - t.category = tokenType - t.pos = pos - t.len = 1 - t.val[0] = value + t.val = value[:last] } func (t *sqliToken) isUnaryOp() bool { @@ -105,7 +96,7 @@ func (t *sqliToken) isUnaryOp() bool { case 2: return t.val[0] == '!' && t.val[1] == '!' case 3: - return toUpperCmp("NOT", string(t.val[:3])) + return toUpperCmp("NOT", t.val[:3]) default: return false }