diff --git a/doc/userguide/output/eve/eve-json-format.rst b/doc/userguide/output/eve/eve-json-format.rst index 952945dffc98..a7667e2bbbcf 100644 --- a/doc/userguide/output/eve/eve-json-format.rst +++ b/doc/userguide/output/eve/eve-json-format.rst @@ -3057,3 +3057,32 @@ Example of ARP logging: request and response "dest_mac": "00:1d:09:f0:92:ab", "dest_ip": "10.10.10.1" } + +Event type: MySQL +----------------- + +Fields +~~~~~~ + +* "version": the MySQL protocol version offered by the server. +* "tls": protocol need to be upgrade to tls. +* "command": sql query statement or utility command like ping. +* "rows": zero or multi results from executing sql query statement, one row is splited by comma. + +Examples +~~~~~~~~ + +Example of MySQL logging: + +:: + +{ + "mysql": { + "version": "8.0.32", + "tls": false, + "command": "SELECT VERSION()", + "rows": [ + "8.0.32" + ] + } +} diff --git a/doc/userguide/rules/index.rst b/doc/userguide/rules/index.rst index c8b586fecaa7..a3b8108bdcf4 100644 --- a/doc/userguide/rules/index.rst +++ b/doc/userguide/rules/index.rst @@ -37,6 +37,7 @@ Suricata Rules nfs-keywords smtp-keywords websocket-keywords + mysql-keywords app-layer xbits noalert diff --git a/doc/userguide/rules/mysql-keywords.rst b/doc/userguide/rules/mysql-keywords.rst new file mode 100644 index 000000000000..39ed949c3db9 --- /dev/null +++ b/doc/userguide/rules/mysql-keywords.rst @@ -0,0 +1,50 @@ +MySQL Keywords +============ + +The MySQL keywords are implemented and can be used to match on fields in MySQL messages. + +============================== ================== +Keyword Direction +============================== ================== +mysql.command Request +mysql.rows Response +============================== ================== + +mysql.command +---------- + +This keyword matches on the query statement like `select * from xxx where yyy = zzz` found in a MySQL request. + +Syntax +~~~~~~ + +:: + + mysql.command; content:; + +Examples +~~~~~~~~ + +:: + + mysql.commands; content:"select"; + +mysql.rows +------- + +This keyword matches on the rows which come from query statement result found in a Mysql response. +row format: 1,foo,bar + +Syntax +~~~~~~ + +:: + + mysql.rows; content:; + +Examples +~~~~~~~~ + +:: + + mysql.rows; content:"foo,bar"; diff --git a/doc/userguide/upgrade.rst b/doc/userguide/upgrade.rst index 2b16dd31a35c..f9306b32a2f0 100644 --- a/doc/userguide/upgrade.rst +++ b/doc/userguide/upgrade.rst @@ -82,6 +82,8 @@ Major changes - Unknown requirements in the ``requires`` keyword will now be treated as unmet requirements, causing the rule to not be loaded. See :ref:`keyword_requires`. +- MySQL parser and logger have been introduced. +- The MySQL keywords ``mysql.command`` and ``mysql.command`` have been introduced. Removals ~~~~~~~~ diff --git a/etc/schema.json b/etc/schema.json index b335dc5c2104..cccda2dfc998 100644 --- a/etc/schema.json +++ b/etc/schema.json @@ -2394,6 +2394,35 @@ }, "additionalProperties": false }, + "mysql": { + "type": "object", + "optional": true, + "properties": { + "version": { + "type": "string", + "description": "Mysql server version" + }, + "tls": { + "type": "boolean" + }, + "command": { + "type": "string", + "description": "sql query statement or some utility commands like ping." + }, + "affected_rows": { + "type": "integer" + }, + "rows": { + "type": "array", + "optional": true, + "minItems": 1, + "items": { + "type": "string" + }, + "description": "Comma separated result from sql statement" + } + } + }, "ldap": { "type": "object", "optional": true, @@ -4674,6 +4703,10 @@ "description": "Errors encountered parsing MQTT protocol", "$ref": "#/$defs/stats_applayer_error" }, + "mysql": { + "description": "Errors encountered parsing MySQL protocol", + "$ref": "#/$defs/stats_applayer_error" + }, "nfs_tcp": { "description": "Errors encountered parsing NFS/TCP protocol", "$ref": "#/$defs/stats_applayer_error" @@ -4849,6 +4882,10 @@ "description": "Number of flows for MQTT protocol", "type": "integer" }, + "mysql": { + "description": "Number of flows for MySQL protocol", + "type": "integer" + }, "nfs_tcp": { "description": "Number of flows for NFS/TCP protocol", "type": "integer" @@ -5019,6 +5056,10 @@ "description": "Number of transactions for MQTT protocol", "type": "integer" }, + "mysql": { + "description": "Number of flows for MySQL protocol", + "type": "integer" + }, "nfs_tcp": { "description": "Number of transactions for NFS/TCP protocol", "type": "integer" diff --git a/rust/src/lib.rs b/rust/src/lib.rs index bea7854f107e..a57f261adaa9 100644 --- a/rust/src/lib.rs +++ b/rust/src/lib.rs @@ -139,6 +139,7 @@ pub mod ffi; pub mod feature; pub mod sdp; pub mod ldap; +pub mod mysql; #[allow(unused_imports)] pub use suricata_lua_sys; diff --git a/rust/src/mysql/detect.rs b/rust/src/mysql/detect.rs new file mode 100644 index 000000000000..1f42b0f84bc6 --- /dev/null +++ b/rust/src/mysql/detect.rs @@ -0,0 +1,166 @@ +/* Copyright (C) 2024 Open Information Security Foundation + * + * You can copy, redistribute or modify this Program under the terms of + * the GNU General Public License version 2 as published by the Free + * Software Foundation. + * + * This program is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + * GNU General Public License for more details. + * + * You should have received a copy of the GNU General Public License + * version 2 along with this program; if not, write to the Free Software + * Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA + * 02110-1301, USA. + */ + +// Author: QianKaiLin + +/// Detect +/// Get the mysql query +use super::mysql::{MysqlTransaction, ALPROTO_MYSQL}; +use crate::detect::{ + DetectBufferSetActiveList, DetectHelperBufferMpmRegister, DetectHelperGetData, + DetectHelperGetMultiData, DetectHelperKeywordRegister, DetectHelperMultiBufferMpmRegister, + DetectSignatureSetAppProto, SCSigTableElmt, SIGMATCH_NOOPT, +}; +use std::os::raw::{c_int, c_void}; + +static mut G_MYSQL_COMMAND_BUFFER_ID: c_int = 0; +static mut G_MYSQL_ROWS_BUFFER_ID: c_int = 0; + +#[no_mangle] +unsafe extern "C" fn SCMysqlCommandSetup( + de: *mut c_void, s: *mut c_void, _raw: *const std::os::raw::c_char, +) -> c_int { + if DetectSignatureSetAppProto(s, ALPROTO_MYSQL) != 0 { + return -1; + } + if DetectBufferSetActiveList(de, s, G_MYSQL_COMMAND_BUFFER_ID) < 0 { + return -1; + } + return 0; +} + +#[no_mangle] +unsafe extern "C" fn SCMysqlGetCommand( + de: *mut c_void, transforms: *const c_void, flow: *const c_void, flow_flags: u8, + tx: *const c_void, list_id: c_int, +) -> *mut c_void { + return DetectHelperGetData( + de, + transforms, + flow, + flow_flags, + tx, + list_id, + SCMysqlGetCommandData, + ); +} + +#[no_mangle] +unsafe extern "C" fn SCMysqlGetCommandData( + tx: *const c_void, _flags: u8, buf: *mut *const u8, len: *mut u32, +) -> bool { + let tx = cast_pointer!(tx, MysqlTransaction); + if let Some(command) = &tx.command { + if !command.is_empty() { + *buf = command.as_ptr(); + *len = command.len() as u32; + return true; + } + } + + false +} + +#[no_mangle] +unsafe extern "C" fn SCMysqlRowsSetup( + de: *mut c_void, s: *mut c_void, _raw: *const std::os::raw::c_char, +) -> c_int { + if DetectSignatureSetAppProto(s, ALPROTO_MYSQL) != 0 { + return -1; + } + if DetectBufferSetActiveList(de, s, G_MYSQL_ROWS_BUFFER_ID) < 0 { + return -1; + } + return 0; +} + +#[no_mangle] +unsafe extern "C" fn SCMysqlGetRows( + de: *mut c_void, transforms: *const c_void, flow: *const c_void, flow_flags: u8, + tx: *const c_void, list_id: c_int, local_id: u32, +) -> *mut c_void { + return DetectHelperGetMultiData( + de, + transforms, + flow, + flow_flags, + tx, + list_id, + local_id, + SCMysqlGetRowsData, + ); +} + +/// Get the mysql rows at index i +#[no_mangle] +pub unsafe extern "C" fn SCMysqlGetRowsData( + tx: *const c_void, _flow_flags: u8, local_id: u32, buf: *mut *const u8, len: *mut u32, +) -> bool { + let tx = cast_pointer!(tx, MysqlTransaction); + if let Some(rows) = &tx.rows { + if !rows.is_empty() { + let index = local_id as usize; + if let Some(row) = rows.get(index) { + *buf = row.as_ptr(); + *len = row.len() as u32; + return true; + } + } + } + + false +} + +#[no_mangle] +pub unsafe extern "C" fn ScDetectMysqlRegister() { + let kw = SCSigTableElmt { + name: b"mysql.command\0".as_ptr() as *const libc::c_char, + desc: b"sticky buffer to match on the MySQL command\0".as_ptr() as *const libc::c_char, + url: b"/rules/mysql-keywords.html#mysql-command\0".as_ptr() as *const libc::c_char, + Setup: SCMysqlCommandSetup, + flags: SIGMATCH_NOOPT, + AppLayerTxMatch: None, + Free: None, + }; + let _g_mysql_command_kw_id = DetectHelperKeywordRegister(&kw); + G_MYSQL_COMMAND_BUFFER_ID = DetectHelperBufferMpmRegister( + b"mysql.command\0".as_ptr() as *const libc::c_char, + b"mysql.command\0".as_ptr() as *const libc::c_char, + ALPROTO_MYSQL, + false, + true, + SCMysqlGetCommand, + ); + let kw = SCSigTableElmt { + name: b"mysql.rows\0".as_ptr() as *const libc::c_char, + desc: b"sticky buffer to match on the MySQL Rows\0".as_ptr() as *const libc::c_char, + url: b"/rules/mysql-keywords.html#mysql-rows\0".as_ptr() as *const libc::c_char, + Setup: SCMysqlRowsSetup, + flags: SIGMATCH_NOOPT, + AppLayerTxMatch: None, + Free: None, + }; + let _g_mysql_rows_kw_id = DetectHelperKeywordRegister(&kw); + G_MYSQL_ROWS_BUFFER_ID = DetectHelperMultiBufferMpmRegister( + b"mysql.rows\0".as_ptr() as *const libc::c_char, + b"mysql select statement resultset\0".as_ptr() as *const libc::c_char, + ALPROTO_MYSQL, + true, + false, + SCMysqlGetRows, + ); +} diff --git a/rust/src/mysql/logger.rs b/rust/src/mysql/logger.rs new file mode 100644 index 000000000000..38cd85ebb461 --- /dev/null +++ b/rust/src/mysql/logger.rs @@ -0,0 +1,59 @@ +/* Copyright (C) 2024 Open Information Security Foundation + * + * You can copy, redistribute or modify this Program under the terms of + * the GNU General Public License version 2 as published by the Free + * Software Foundation. + * + * This program is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + * GNU General Public License for more details. + * + * You should have received a copy of the GNU General Public License + * version 2 along with this program; if not, write to the Free Software + * Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA + * 02110-1301, USA. + */ + +// Author: QianKaiLin + +use crate::jsonbuilder::{JsonBuilder, JsonError}; +use crate::mysql::mysql::*; + +fn log_mysql(tx: &MysqlTransaction, js: &mut JsonBuilder) -> Result<(), JsonError> { + js.open_object("mysql")?; + js.set_string("version", tx.version.as_str())?; + js.set_bool("tls", tx.tls)?; + + if let Some(command) = &tx.command { + js.set_string("command", command)?; + } + + if let Some(affected_rows) = tx.affected_rows { + js.set_uint("affected_rows", affected_rows)?; + } + + if let Some(rows) = &tx.rows { + js.open_array("rows")?; + for row in rows { + js.append_string(row)?; + } + js.close()?; + } + + js.close()?; + + Ok(()) +} + +#[no_mangle] +pub unsafe extern "C" fn SCMysqlLogger( + tx: *mut std::os::raw::c_void, js: &mut JsonBuilder, +) -> bool { + let tx_mysql = cast_pointer!(tx, MysqlTransaction); + let result = log_mysql(tx_mysql, js); + if let Err(ref _err) = result { + return false; + } + return result.is_ok(); +} diff --git a/rust/src/mysql/mod.rs b/rust/src/mysql/mod.rs new file mode 100644 index 000000000000..4bc3005880c4 --- /dev/null +++ b/rust/src/mysql/mod.rs @@ -0,0 +1,23 @@ +/* Copyright (C) 2024 Open Information Security Foundation + * + * You can copy, redistribute or modify this Program under the terms of + * the GNU General Public License version 2 as published by the Free + * Software Foundation. + * + * This program is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + * GNU General Public License for more details. + * + * You should have received a copy of the GNU General Public License + * version 2 along with this program; if not, write to the Free Software + * Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA + * 02110-1301, USA. + */ + +// Author: QianKaiLin + +pub mod detect; +pub mod logger; +pub mod mysql; +pub mod parser; diff --git a/rust/src/mysql/mysql.rs b/rust/src/mysql/mysql.rs new file mode 100644 index 000000000000..c6b5054cb900 --- /dev/null +++ b/rust/src/mysql/mysql.rs @@ -0,0 +1,1255 @@ +/* Copyright (C) 2024 Open Information Security Foundation + * + * You can copy, redistribute or modify this Program under the terms of + * the GNU General Public License version 2 as published by the Free + * Software Foundation. + * + * This program is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + * GNU General Public License for more details. + * + * You should have received a copy of the GNU General Public License + * version 2 along with this program; if not, write to the Free Software + * Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA + * 02110-1301, USA. + */ + +// Author: QianKaiLin +// +use std::collections::VecDeque; +use std::ffi::CString; + +use nom7::IResult; + +use crate::applayer::*; +use crate::conf::{conf_get, get_memval}; +use crate::core::*; + +use super::parser::*; + +pub const MYSQL_CONFIG_DEFAULT_STREAM_DEPTH: u32 = 0; + +static mut MYSQL_MAX_TX: usize = 1024; + +pub static mut ALPROTO_MYSQL: AppProto = ALPROTO_UNKNOWN; + +#[derive(FromPrimitive, Debug, AppLayerEvent)] +pub enum MysqlEvent { + TooManyTransactions, +} + +#[derive(Debug)] +pub struct MysqlTransaction { + pub tx_id: u64, + + /// Required + pub version: String, + /// Optional when tls is true + pub command: Option, + /// Optional when tls is true + pub affected_rows: Option, + /// Optional when tls is true + pub rows: Option>, + pub tls: bool, + + pub complete: bool, + pub tx_data: AppLayerTxData, +} + +impl Transaction for MysqlTransaction { + fn id(&self) -> u64 { + self.tx_id + } +} + +impl MysqlTransaction { + pub fn new(version: String) -> Self { + Self { + tx_id: 0, + version, + command: None, + affected_rows: None, + tls: false, + tx_data: AppLayerTxData::new(), + complete: false, + rows: None, + } + } +} + +#[derive(Clone, Copy, Debug, PartialEq, Eq)] +pub enum MysqlStateProgress { + // Default State + Init, + + // Connection Phase + // Server send HandshakeRequest to Client + // https://dev.mysql.com/doc/dev/mysql-server/latest/page_protocol_connection_phase.html + Handshake, + // Client send HandshakeResponse to Server + // Server send AuthSwitchRequest to Client + Auth, + // Server send OkResponse to Client + AuthFinished, + + // Command Phase + // Client send QueryRequest to Server + CommandReceived, + // Server send EOF with 0x0A to Client + TextResulsetContinue, + // Server send QueryResponse to Client or Ok to Client + CommandResponseReceived, + // Server send LocalFileRequest with zero length to Client + LocalFileRequestReceived, + // Client send empty packet to Server + LocalFileContentFinished, + // Client send StatisticsRequest to Server + StatisticsReceived, + // Server send StatisticsResponse to client + StatisticsResponseReceived, + // Client send FieldList to Server + FieldListReceived, + // Server send FieldListResponse to client + FieldListResponseReceived, + // Client send ChangeUserRequest to Server + ChangeUserReceived, + // Server send OkResponse to client + ChangeUserResponseReceived, + // Client send Unknown to Server + UnknownCommandReceived, + // Client send StmtPrepareRequest to Server + StmtPrepareReceived, + // Server send StmtPrepareResponse to Client + StmtPrepareResponseReceived, + // Client send StmtExecRequest to Server + StmtExecReceived, + // Server send StmtExecResponse with EOF status equal 0x0a to Client + StmtExecResponseContinue, + // Server send ResultSetResponse to Client or Ok Response to Client + StmtExecResponseReceived, + // Client send StmtFetchRequest to Server + StmtFetchReceived, + // Server send StmtFetchResponse with EOF status equal 0x0a to Client + StmtFetchResponseContinue, + // Server send StmtFetchResponse to Client + StmtFetchResponseReceived, + // Client send StmtFetchRequest to Server + StmtResetReceived, + // Server send Ok or Err Response to Client + StmtResetResponseReceived, + // Client send StmtCloseRequest to Server + StmtCloseReceived, + + // Client send QueryRequest and command equal Quit to Server + // Client send ChangeUserRequest to server and server send ErrResponse to Client + // Transport Layer EOF + // Transport Layer Upgrade to TLS + Finished, +} + +#[derive(Debug)] +struct MysqlStatement { + statement_id: Option, + prepare_stmt: String, + param_cnt: Option, + param_types: Option>, + stmt_long_datas: Option>, + rows: Option>, +} + +impl MysqlStatement { + fn new(prepare_stmt: String) -> Self { + MysqlStatement { + statement_id: None, + prepare_stmt, + param_cnt: None, + param_types: None, + stmt_long_datas: None, + rows: None, + } + } + + fn set_statement_id(&mut self, statement_id: u32) { + self.statement_id = Some(statement_id); + } + + fn set_param_cnt(&mut self, param_cnt: u16) { + self.param_cnt = Some(param_cnt); + } + + fn set_param_types(&mut self, cols: Vec) { + self.param_types = Some(cols); + } + + fn add_stmt_long_datas(&mut self, long_data: StmtLongData) { + if let Some(stmt_long_datas) = &mut self.stmt_long_datas { + stmt_long_datas.push(long_data); + } else { + self.stmt_long_datas = Some(vec![long_data]); + } + } + + fn reset_stmt_long_datas(&mut self) { + self.stmt_long_datas.take(); + } + + fn add_rows(&mut self, rows: Vec) { + if let Some(old_rows) = &mut self.rows { + old_rows.extend(rows); + } else { + self.rows = Some(rows); + } + } + + fn execute(&self, params: Vec) -> Option { + let prepare_stmt = self.prepare_stmt.clone(); + let mut query = String::new(); + if !params.is_empty() + && self.param_cnt.is_some() + && self.param_cnt.unwrap() as usize == params.len() + { + let mut params = params.iter(); + for part in prepare_stmt.split('?') { + query.push_str(part); + if let Some(param) = params.next() { + query.push_str(param); + } + } + Some(query) + } else { + None + } + } +} + +#[derive(Debug)] +pub struct MysqlState { + pub state_data: AppLayerStateData, + pub tx_id: u64, + transactions: VecDeque, + request_gap: bool, + response_gap: bool, + state_progress: MysqlStateProgress, + tx_index_completed: usize, + + client_flags: u32, + version: Option, + tls: bool, + /// stmt prepare + prepare_stmt: Option, +} + +impl State for MysqlState { + fn get_transaction_count(&self) -> usize { + self.transactions.len() + } + + fn get_transaction_by_index(&self, index: usize) -> Option<&MysqlTransaction> { + self.transactions.get(index) + } +} + +impl Default for MysqlState { + fn default() -> Self { + Self::new() + } +} + +impl MysqlState { + pub fn new() -> Self { + let state = Self { + state_data: AppLayerStateData::new(), + tx_id: 0, + transactions: VecDeque::new(), + request_gap: false, + response_gap: false, + state_progress: MysqlStateProgress::Init, + tx_index_completed: 0, + + client_flags: 0, + version: None, + tls: false, + prepare_stmt: None, + }; + state + } + + pub fn free_tx(&mut self, tx_id: u64) { + let len = self.transactions.len(); + let mut found = false; + let mut index = 0; + for i in 0..len { + let tx = &self.transactions[i]; + if tx.tx_id == tx_id + 1 { + found = true; + index = i; + break; + } + } + if found { + self.tx_index_completed = 0; + self.transactions.remove(index); + } + } + + pub fn get_tx(&mut self, tx_id: u64) -> Option<&MysqlTransaction> { + self.transactions.iter().find(|tx| tx.tx_id == tx_id + 1) + } + + fn get_tx_mut(&mut self, tx_id: u64) -> Option<&mut MysqlTransaction> { + self.transactions + .iter_mut() + .find(|tx| tx.tx_id == tx_id + 1) + } + + fn set_event(tx: &mut MysqlTransaction, event: MysqlEvent) { + tx.tx_data.set_event(event as u8); + } + + fn new_tx(&mut self, command: String) -> MysqlTransaction { + let mut tx = MysqlTransaction::new(self.version.clone().unwrap_or_default()); + self.tx_id += 1; + tx.tx_id = self.tx_id; + tx.tls = self.tls; + if tx.tls { + tx.complete = true; + } + tx.command = Some(command); + SCLogDebug!("Creating new transaction.tx_id: {}", tx.tx_id); + if self.transactions.len() > unsafe { MYSQL_MAX_TX } + self.tx_index_completed { + let mut index = self.tx_index_completed; + for tx_old in &mut self.transactions.range_mut(self.tx_index_completed..) { + index += 1; + if !tx_old.complete { + tx_old.complete = true; + MysqlState::set_event(tx_old, MysqlEvent::TooManyTransactions); + break; + } + } + self.tx_index_completed = index; + } + tx + } + + /// Find or create a new transaction + /// + /// If a new transaction is created, push that into state.transactions before returning &mut to last tx + /// If we can't find a transaction and we should not create one, we return None + /// The moment when this is called will may impact the logic of transaction tracking (e.g. when a tx is considered completed) + fn create_tx(&mut self, command: String) -> Option<&mut MysqlTransaction> { + let tx = self.new_tx(command); + SCLogDebug!("create state is {:?}", &self.state_progress); + self.transactions.push_back(tx); + self.transactions.back_mut() + } + + fn request_next_state( + &mut self, request: MysqlFEMessage, f: *const Flow, + ) -> Option { + match request { + MysqlFEMessage::HandshakeResponse(resp) => { + // for now, don't support compress + if resp.zstd_compression_level.is_some() { + return Some(MysqlStateProgress::Finished); + } + if resp.client_flags & CLIENT_DEPRECATE_EOF != 0 + || resp.client_flags & CLIENT_OPTIONAL_RESULTSET_METADATA != 0 + { + return Some(MysqlStateProgress::Finished); + } + self.client_flags = resp.client_flags; + Some(MysqlStateProgress::Auth) + } + MysqlFEMessage::SSLRequest(_) => { + unsafe { + AppLayerRequestProtocolTLSUpgrade(f); + } + self.tls = true; + self.create_tx("".to_string()); + Some(MysqlStateProgress::Finished) + } + MysqlFEMessage::AuthRequest => None, + MysqlFEMessage::LocalFileData(length) => { + if length == 0 { + let tx = self.get_tx_mut(self.tx_id - 1); + if let Some(tx) = tx { + tx.complete = true; + } + return Some(MysqlStateProgress::LocalFileContentFinished); + } + None + } + MysqlFEMessage::Request(req) => match req.command { + MysqlCommand::Query { query: _ } => { + self.create_tx(req.command.to_string()); + return Some(MysqlStateProgress::CommandReceived); + } + MysqlCommand::StmtPrepare { query } => { + self.prepare_stmt = Some(MysqlStatement::new(query)); + return Some(MysqlStateProgress::StmtPrepareReceived); + } + + MysqlCommand::StmtExecute { + statement_id: expected_statement_id, + params, + } => { + if let Some(prepare_stmt) = &self.prepare_stmt { + if let Some(statement_id) = prepare_stmt.statement_id { + if statement_id == expected_statement_id { + let command = prepare_stmt.execute(params.unwrap_or_default()); + self.create_tx(command.unwrap_or_default()); + } else { + SCLogWarning!( + "Receive stmt exec statement_id {} not equal we need {}", + expected_statement_id, + statement_id + ); + return Some(MysqlStateProgress::Finished); + } + } + } else { + return Some(MysqlStateProgress::Finished); + } + return Some(MysqlStateProgress::StmtExecReceived); + } + MysqlCommand::StmtFetch { + statement_id: _, + number_rows: _, + } => { + return Some(MysqlStateProgress::StmtFetchReceived); + } + MysqlCommand::StmtSendLongData(stmt_long_data) => { + if let Some(prepare_stmt) = &mut self.prepare_stmt { + if let Some(statement_id) = prepare_stmt.statement_id { + if statement_id == stmt_long_data.statement_id { + prepare_stmt.add_stmt_long_datas(stmt_long_data); + } + } + } + None + } + MysqlCommand::StmtReset { statement_id } => { + if let Some(prepare_stmt) = &mut self.prepare_stmt { + if let Some(expected_statement_id) = prepare_stmt.statement_id { + if statement_id == expected_statement_id { + prepare_stmt.reset_stmt_long_datas(); + } + } + } + return Some(MysqlStateProgress::StmtResetReceived); + } + MysqlCommand::StmtClose { statement_id } => { + if let Some(prepare_stmt) = &self.prepare_stmt { + if let Some(expected_statement_id) = prepare_stmt.statement_id { + if statement_id == expected_statement_id { + self.prepare_stmt.take(); + } else { + SCLogWarning!( + "Receive stmt close statement_id {} not equal we need {}", + expected_statement_id, + statement_id + ); + } + } else { + SCLogWarning!("Receive stmt close without stmt prepare response"); + } + } else { + SCLogWarning!("Receive stmt close without stmt prepare response"); + } + + return Some(MysqlStateProgress::StmtCloseReceived); + } + MysqlCommand::Quit => { + self.create_tx(req.command.to_string()); + return Some(MysqlStateProgress::Finished); + } + MysqlCommand::Ping + | MysqlCommand::Debug + | MysqlCommand::ResetConnection + | MysqlCommand::SetOption => { + self.create_tx(req.command.to_string()); + Some(MysqlStateProgress::CommandReceived) + } + MysqlCommand::Statistics => Some(MysqlStateProgress::StatisticsReceived), + MysqlCommand::FieldList { table: _ } => { + self.create_tx(req.command.to_string()); + return Some(MysqlStateProgress::FieldListReceived); + } + MysqlCommand::ChangeUser => { + self.create_tx(req.command.to_string()); + return Some(MysqlStateProgress::ChangeUserReceived); + } + _ => { + SCLogWarning!("Unknown command {}", req.command_code); + return Some(MysqlStateProgress::UnknownCommandReceived); + } + }, + } + } + + fn state_based_req_parsing( + state: MysqlStateProgress, i: &[u8], param_cnt: Option, + param_types: Option>, + stmt_long_datas: Option>, client_flags: u32, + ) -> IResult<&[u8], MysqlFEMessage> { + match state { + MysqlStateProgress::Handshake => { + let old = i; + let (_, client_flags) = parse_handshake_capabilities(i)?; + if client_flags & CLIENT_SSL != 0 { + let (i, req) = parse_handshake_ssl_request(old)?; + return Ok((i, MysqlFEMessage::SSLRequest(req))); + } + let (i, req) = parse_handshake_response(old)?; + Ok((i, MysqlFEMessage::HandshakeResponse(req))) + } + MysqlStateProgress::Auth => { + let (i, _) = parse_auth_request(i)?; + Ok((i, MysqlFEMessage::AuthRequest)) + } + MysqlStateProgress::LocalFileRequestReceived => { + let (i, length) = parse_local_file_data_content(i)?; + Ok((i, MysqlFEMessage::LocalFileData(length))) + } + _ => { + let (i, req) = + parse_request(i, param_cnt, param_types, stmt_long_datas, client_flags)?; + Ok((i, MysqlFEMessage::Request(req))) + } + } + } + + pub fn parse_request(&mut self, flow: *const Flow, i: &[u8]) -> AppLayerResult { + // We're not interested in empty requests. + if i.is_empty() { + return AppLayerResult::ok(); + } + + // If there was gap, check we can sync up again. + if self.request_gap { + if probe(i).is_err() { + SCLogDebug!("Suricata interprets there's a gap in the request"); + return AppLayerResult::ok(); + } + + // It looks like we're in sync with the message header + // clear gap state and keep parsing. + self.request_gap = false; + } + if self.state_progress == MysqlStateProgress::Finished { + return AppLayerResult::ok(); + } + + let mut start = i; + while !start.is_empty() { + SCLogDebug!( + "In 'parse_request' State Progress is: {:?}", + &self.state_progress + ); + let mut stmt_long_datas = None; + let mut param_cnt = None; + let mut param_types = None; + if let Some(prepare_stmt) = &self.prepare_stmt { + stmt_long_datas = prepare_stmt.stmt_long_datas.clone(); + param_cnt = prepare_stmt.param_cnt; + param_types = prepare_stmt.param_types.clone(); + } + + match MysqlState::state_based_req_parsing( + self.state_progress, + start, + param_cnt, + param_types.clone(), + stmt_long_datas, + self.client_flags, + ) { + Ok((rem, request)) => { + SCLogDebug!("Request is {:?}", &request); + start = rem; + if let Some(state) = self.request_next_state(request, flow) { + self.state_progress = state; + } + } + Err(nom7::Err::Incomplete(_needed)) => { + let consumed = i.len() - start.len(); + let needed_estimation = start.len() + 1; + SCLogDebug!( + "Needed: {:?}, estimated needed: {:?}", + _needed, + needed_estimation + ); + return AppLayerResult::incomplete(consumed as u32, needed_estimation as u32); + } + Err(err) => { + SCLogError!( + "Error while parsing MySQL request, state: {:?} err: {:?}", + self.state_progress, + err + ); + return AppLayerResult::err(); + } + } + } + + // All Input was fully consumed. + AppLayerResult::ok() + } + + /// When the state changes based on a specific response, there are other actions we may need to perform + /// + /// If there is data from the backend message that Suri should store separately in the State or + /// Transaction, that is also done here + fn response_next_state(&mut self, response: MysqlBEMessage) -> Option { + match response { + MysqlBEMessage::HandshakeRequest(req) => { + self.version = Some(req.version.clone()); + Some(MysqlStateProgress::Handshake) + } + + MysqlBEMessage::Response(resp) => match resp.item { + MysqlResponsePacket::LocalInFileRequest => { + Some(MysqlStateProgress::LocalFileRequestReceived) + } + MysqlResponsePacket::FieldsList { columns: _ } => { + let tx = if self.tx_id > 0 { + self.get_tx_mut(self.tx_id - 1) + } else { + None + }; + if let Some(tx) = tx { + tx.complete = true; + } + Some(MysqlStateProgress::FieldListResponseReceived) + } + MysqlResponsePacket::Statistics => { + let tx = if self.tx_id > 0 { + self.get_tx_mut(self.tx_id - 1) + } else { + None + }; + if let Some(tx) = tx { + tx.complete = true; + } + Some(MysqlStateProgress::StatisticsResponseReceived) + } + MysqlResponsePacket::AuthSwithRequest => Some(MysqlStateProgress::Auth), + MysqlResponsePacket::AuthData => None, + MysqlResponsePacket::Err { .. } => match self.state_progress { + MysqlStateProgress::CommandReceived + | MysqlStateProgress::TextResulsetContinue => { + let tx = if self.tx_id > 0 { + self.get_tx_mut(self.tx_id - 1) + } else { + None + }; + if let Some(tx) = tx { + tx.complete = true; + } + Some(MysqlStateProgress::CommandResponseReceived) + } + MysqlStateProgress::FieldListReceived => { + let tx = if self.tx_id > 0 { + self.get_tx_mut(self.tx_id - 1) + } else { + None + }; + if let Some(tx) = tx { + tx.complete = true; + } + Some(MysqlStateProgress::FieldListResponseReceived) + } + MysqlStateProgress::StmtExecReceived + | MysqlStateProgress::StmtExecResponseContinue => { + let tx = if self.tx_id > 0 { + self.get_tx_mut(self.tx_id - 1) + } else { + None + }; + if let Some(tx) = tx { + tx.complete = true; + } + Some(MysqlStateProgress::StmtExecResponseReceived) + } + MysqlStateProgress::StmtResetReceived => { + Some(MysqlStateProgress::StmtResetResponseReceived) + } + MysqlStateProgress::ChangeUserReceived => { + let tx = if self.tx_id > 0 { + self.get_tx_mut(self.tx_id - 1) + } else { + None + }; + if let Some(tx) = tx { + tx.complete = true; + } + Some(MysqlStateProgress::Finished) + } + MysqlStateProgress::StmtFetchReceived + | MysqlStateProgress::StmtFetchResponseContinue => { + let tx = if self.tx_id > 0 { + self.get_tx_mut(self.tx_id - 1) + } else { + None + }; + if let Some(tx) = tx { + tx.complete = true; + } + Some(MysqlStateProgress::StmtFetchResponseReceived) + } + _ => None, + }, + MysqlResponsePacket::Ok { + rows, + flags: _, + warnings: _, + } => match self.state_progress { + MysqlStateProgress::Auth => Some(MysqlStateProgress::AuthFinished), + MysqlStateProgress::CommandReceived => { + let tx = if self.tx_id > 0 { + self.get_tx_mut(self.tx_id - 1) + } else { + None + }; + if let Some(tx) = tx { + tx.affected_rows = Some(rows); + tx.complete = true; + } + Some(MysqlStateProgress::CommandResponseReceived) + } + MysqlStateProgress::StmtExecReceived => { + let tx = if self.tx_id > 0 { + self.get_tx_mut(self.tx_id - 1) + } else { + None + }; + if let Some(tx) = tx { + tx.affected_rows = Some(rows); + tx.complete = true; + } + Some(MysqlStateProgress::StmtExecResponseReceived) + } + MysqlStateProgress::ChangeUserReceived => { + Some(MysqlStateProgress::ChangeUserResponseReceived) + } + MysqlStateProgress::StmtResetReceived => { + Some(MysqlStateProgress::StmtResetResponseReceived) + } + MysqlStateProgress::TextResulsetContinue => { + let tx = if self.tx_id > 0 { + self.get_tx_mut(self.tx_id - 1) + } else { + None + }; + if let Some(tx) = tx { + tx.complete = true; + } + Some(MysqlStateProgress::CommandResponseReceived) + } + MysqlStateProgress::StmtExecResponseContinue => { + let prepare_stmt = self.prepare_stmt.take(); + if self.tx_id > 0 { + let tx = self.get_tx_mut(self.tx_id - 1); + if let Some(tx) = tx { + if let Some(mut prepare_stmt) = prepare_stmt { + let rows = prepare_stmt.rows.take(); + if let Some(rows) = rows { + tx.rows = Some( + rows.into_iter() + .map(|row| match row { + MysqlResultBinarySetRow::Err => String::new(), + MysqlResultBinarySetRow::Text(text) => text, + }) + .collect::>(), + ); + } + + tx.complete = true; + } + } + } + Some(MysqlStateProgress::StmtExecResponseReceived) + } + MysqlStateProgress::StmtFetchResponseContinue => { + Some(MysqlStateProgress::StmtFetchResponseReceived) + } + _ => None, + }, + MysqlResponsePacket::ResultSet { + n_cols: _, + columns: _, + eof, + rows, + } => { + let tx = if self.tx_id > 0 { + self.get_tx_mut(self.tx_id - 1) + } else { + None + }; + if !rows.is_empty() { + let mut rows = rows.into_iter().map(|row| row.texts.join(",")).collect(); + if let Some(tx) = tx { + if eof.status_flags != 0x0A { + tx.rows = Some(rows); + Some(MysqlStateProgress::CommandResponseReceived) + } else { + // MultiStatement + if let Some(state_rows) = tx.rows.as_mut() { + state_rows.append(&mut rows); + } else { + tx.rows = Some(rows); + } + + Some(MysqlStateProgress::TextResulsetContinue) + } + } else { + Some(MysqlStateProgress::Finished) + } + } else { + Some(MysqlStateProgress::CommandResponseReceived) + } + } + MysqlResponsePacket::StmtPrepare { + statement_id, + num_params, + params, + .. + } => { + if let Some(prepare_stmt) = &mut self.prepare_stmt { + prepare_stmt.set_statement_id(statement_id); + prepare_stmt.set_param_cnt(num_params); + if let Some(params) = params { + prepare_stmt.set_param_types(params); + } + } + + Some(MysqlStateProgress::StmtPrepareResponseReceived) + } + MysqlResponsePacket::StmtFetch => { + Some(MysqlStateProgress::StmtFetchResponseReceived) + } + MysqlResponsePacket::BinaryResultSet { + n_cols: _, + eof, + rows, + } => { + if self.state_progress == MysqlStateProgress::StmtFetchReceived + || self.state_progress == MysqlStateProgress::StmtFetchResponseContinue + { + return Some(MysqlStateProgress::StmtFetchResponseContinue); + } + + if !rows.is_empty() { + if eof.status_flags != 0x0A { + let tx = if self.tx_id > 0 { + self.get_tx_mut(self.tx_id - 1) + } else { + None + }; + if let Some(tx) = tx { + tx.rows = Some( + rows.into_iter() + .map(|row| match row { + MysqlResultBinarySetRow::Err => String::new(), + MysqlResultBinarySetRow::Text(text) => text, + }) + .collect::>(), + ); + tx.complete = true; + } + + Some(MysqlStateProgress::StmtExecResponseReceived) + } else { + // MultiResulset + if let Some(prepare_stmt) = &mut self.prepare_stmt { + prepare_stmt.add_rows(rows); + } + + Some(MysqlStateProgress::StmtExecResponseContinue) + } + } else { + Some(MysqlStateProgress::StmtExecResponseReceived) + } + } + _ => None, + }, + } + } + + fn state_based_resp_parsing( + state: MysqlStateProgress, i: &[u8], client_flags: u32, + ) -> IResult<&[u8], MysqlBEMessage> { + match state { + MysqlStateProgress::Init => { + let (i, resp) = parse_handshake_request(i)?; + Ok((i, MysqlBEMessage::HandshakeRequest(resp))) + } + + MysqlStateProgress::Auth => { + let (i, resp) = parse_auth_responsev2(i)?; + Ok((i, MysqlBEMessage::Response(resp))) + } + + MysqlStateProgress::StmtPrepareReceived => { + let (i, resp) = parse_stmt_prepare_response(i, client_flags)?; + Ok((i, MysqlBEMessage::Response(resp))) + } + + MysqlStateProgress::StmtExecReceived | MysqlStateProgress::StmtExecResponseContinue => { + let (i, resp) = parse_stmt_execute_response(i)?; + Ok((i, MysqlBEMessage::Response(resp))) + } + + MysqlStateProgress::StmtFetchReceived + | MysqlStateProgress::StmtFetchResponseContinue => { + let (i, resp) = parse_stmt_fetch_response(i)?; + Ok((i, MysqlBEMessage::Response(resp))) + } + + MysqlStateProgress::FieldListReceived => { + let (i, resp) = parse_field_list_response(i)?; + Ok((i, MysqlBEMessage::Response(resp))) + } + MysqlStateProgress::StatisticsReceived => { + let (i, resp) = parse_statistics_response(i)?; + Ok((i, MysqlBEMessage::Response(resp))) + } + MysqlStateProgress::ChangeUserReceived => { + let (i, resp) = parse_change_user_response(i)?; + Ok((i, MysqlBEMessage::Response(resp))) + } + + _ => { + let (i, resp) = parse_response(i)?; + Ok((i, MysqlBEMessage::Response(resp))) + } + } + } + + fn invalid_state_resp(&self) -> bool { + use MysqlStateProgress::*; + self.state_progress == CommandResponseReceived + || self.state_progress == StmtCloseReceived + || self.state_progress == StmtPrepareResponseReceived + || self.state_progress == StmtExecResponseReceived + } + + pub fn parse_response(&mut self, i: &[u8]) -> AppLayerResult { + // We're not interested in empty responses. + if i.is_empty() { + return AppLayerResult::ok(); + } + + if self.response_gap { + if probe(i).is_err() { + SCLogDebug!("Suricata interprets there's a gap in the response"); + return AppLayerResult::ok(); + } + + // It seems we're in sync with a message header, clear gap state and keep parsing. + self.response_gap = false; + } + + let mut start = i; + + while !start.is_empty() { + if self.state_progress == MysqlStateProgress::Finished || self.invalid_state_resp() { + return AppLayerResult::ok(); + } + match MysqlState::state_based_resp_parsing( + self.state_progress, + start, + self.client_flags, + ) { + Ok((rem, response)) => { + start = rem; + + SCLogDebug!("Response is {:?}", &response); + if let Some(state) = self.response_next_state(response) { + self.state_progress = state; + } + } + Err(nom7::Err::Incomplete(_needed)) => { + let consumed = i.len() - start.len(); + let needed_estimation = start.len() + 1; + SCLogDebug!( + "Needed: {:?}, estimated needed: {:?}, start is {:?}", + _needed, + needed_estimation, + &start + ); + return AppLayerResult::incomplete(consumed as u32, needed_estimation as u32); + } + Err(_err) => { + SCLogDebug!( + "Error while parsing MySQL response, state: {:?} err: {:?}", + self.state_progress, + _err, + ); + return AppLayerResult::err(); + } + } + } + + // All Input was fully consumed. + AppLayerResult::ok() + } + + pub fn on_request_gap(&mut self, _size: u32) { + self.request_gap = true; + } + + pub fn on_response_gap(&mut self, _size: u32) { + self.response_gap = true; + } +} + +/// Probe for a valid mysql message +fn probe(i: &[u8]) -> IResult<&[u8], ()> { + let (i, _) = parse_packet_header(i)?; + Ok((i, ())) +} + +// C exports + +/// C entry point for a probing parser. +#[no_mangle] +pub unsafe extern "C" fn rs_mysql_probing_ts( + _flow: *const Flow, _direction: u8, input: *const u8, input_len: u32, _rdir: *mut u8, +) -> AppProto { + if input_len >= 1 && !input.is_null() { + let slice: &[u8] = build_slice!(input, input_len as usize); + match parse_handshake_capabilities(slice) { + Ok((_, client_flags)) => { + if client_flags & CLIENT_SSL != 0 { + match parse_handshake_ssl_request(slice) { + Ok(_) => return ALPROTO_MYSQL, + Err(nom7::Err::Incomplete(_)) => return ALPROTO_UNKNOWN, + Err(err) => { + SCLogError!("failed to probe ssl request {:?}", err); + return ALPROTO_FAILED; + } + } + } else { + match parse_handshake_response(slice) { + Ok(_) => return ALPROTO_MYSQL, + Err(nom7::Err::Incomplete(_)) => return ALPROTO_UNKNOWN, + Err(err) => { + SCLogError!("failed to probe handshake response {:?}", err); + return ALPROTO_FAILED; + } + } + } + } + Err(nom7::Err::Incomplete(_)) => return ALPROTO_UNKNOWN, + Err(_err) => { + SCLogDebug!("failed to probe request {:?}", _err); + return ALPROTO_FAILED; + } + } + } + + return ALPROTO_UNKNOWN; +} + +#[no_mangle] +pub unsafe extern "C" fn rs_mysql_probing_tc( + _flow: *const Flow, _direction: u8, input: *const u8, input_len: u32, _rdir: *mut u8, +) -> AppProto { + if input_len >= 1 && !input.is_null() { + let slice: &[u8] = build_slice!(input, input_len as usize); + match parse_handshake_request(slice) { + Ok(_) => return ALPROTO_MYSQL, + Err(nom7::Err::Incomplete(_)) => return ALPROTO_UNKNOWN, + Err(_err) => { + SCLogDebug!("failed to probe response {:?}", _err); + return ALPROTO_FAILED; + } + } + } + + return ALPROTO_UNKNOWN; +} + +#[no_mangle] +pub extern "C" fn rs_mysql_state_new( + _orig_state: *mut std::os::raw::c_void, _orig_proto: AppProto, +) -> *mut std::os::raw::c_void { + let state = MysqlState::new(); + let boxed = Box::new(state); + return Box::into_raw(boxed) as *mut _; +} + +#[no_mangle] +pub extern "C" fn rs_mysql_state_free(state: *mut std::os::raw::c_void) { + std::mem::drop(unsafe { Box::from_raw(state as *mut MysqlState) }); +} + +#[no_mangle] +pub extern "C" fn rs_mysql_state_tx_free(state: *mut std::os::raw::c_void, tx_id: u64) { + let state_safe: &mut MysqlState; + unsafe { + state_safe = cast_pointer!(state, MysqlState); + } + state_safe.free_tx(tx_id); +} + +#[no_mangle] +pub unsafe extern "C" fn rs_mysql_parse_request( + flow: *const Flow, state: *mut std::os::raw::c_void, pstate: *mut std::os::raw::c_void, + stream_slice: StreamSlice, _data: *const std::os::raw::c_void, +) -> AppLayerResult { + if stream_slice.is_empty() { + if AppLayerParserStateIssetFlag(pstate, APP_LAYER_PARSER_EOF_TS) > 0 { + SCLogDebug!(" Caracal reached `eof`"); + return AppLayerResult::ok(); + } else { + return AppLayerResult::err(); + } + } + + let state_safe: &mut MysqlState = cast_pointer!(state, MysqlState); + + if stream_slice.is_gap() { + state_safe.on_request_gap(stream_slice.gap_size()); + } else { + return state_safe.parse_request(flow, stream_slice.as_slice()); + } + AppLayerResult::ok() +} + +#[no_mangle] +pub unsafe extern "C" fn rs_mysql_parse_response( + _flow: *const Flow, state: *mut std::os::raw::c_void, pstate: *mut std::os::raw::c_void, + stream_slice: StreamSlice, _data: *const std::os::raw::c_void, +) -> AppLayerResult { + if stream_slice.is_empty() { + if AppLayerParserStateIssetFlag(pstate, APP_LAYER_PARSER_EOF_TC) > 0 { + return AppLayerResult::ok(); + } else { + return AppLayerResult::err(); + } + } + + let state_safe: &mut MysqlState = cast_pointer!(state, MysqlState); + + if stream_slice.is_gap() { + state_safe.on_response_gap(stream_slice.gap_size()); + } else { + return state_safe.parse_response(stream_slice.as_slice()); + } + AppLayerResult::ok() +} + +#[no_mangle] +pub extern "C" fn rs_mysql_state_get_tx_count(state: *mut std::os::raw::c_void) -> u64 { + let state_safe: &mut MysqlState; + unsafe { + state_safe = cast_pointer!(state, MysqlState); + } + return state_safe.tx_id; +} + +#[no_mangle] +pub unsafe extern "C" fn rs_mysql_state_get_tx( + state: *mut std::os::raw::c_void, tx_id: u64, +) -> *mut std::os::raw::c_void { + let state_safe: &mut MysqlState = cast_pointer!(state, MysqlState); + match state_safe.get_tx(tx_id) { + Some(tx) => { + return tx as *const _ as *mut _; + } + None => { + return std::ptr::null_mut(); + } + } +} + +#[no_mangle] +pub unsafe extern "C" fn rs_mysql_tx_get_alstate_progress( + tx: *mut std::os::raw::c_void, _direction: u8, +) -> std::os::raw::c_int { + let tx = cast_pointer!(tx, MysqlTransaction); + if tx.complete { + return 1; + } + return 0; +} + +export_tx_data_get!(rs_mysql_get_tx_data, MysqlTransaction); +export_state_data_get!(rs_mysql_get_state_data, MysqlState); + +// Parser name as a C style string. +const PARSER_NAME: &[u8] = b"mysql\0"; + +#[no_mangle] +pub unsafe extern "C" fn rs_mysql_register_parser() { + let default_port = CString::new("[3306]").unwrap(); + let mut stream_depth = MYSQL_CONFIG_DEFAULT_STREAM_DEPTH; + let parser = RustParser { + name: PARSER_NAME.as_ptr() as *const std::os::raw::c_char, + default_port: default_port.as_ptr(), + ipproto: IPPROTO_TCP, + probe_ts: Some(rs_mysql_probing_ts), + probe_tc: Some(rs_mysql_probing_tc), + min_depth: 0, + max_depth: 16, + state_new: rs_mysql_state_new, + state_free: rs_mysql_state_free, + tx_free: rs_mysql_state_tx_free, + parse_ts: rs_mysql_parse_request, + parse_tc: rs_mysql_parse_response, + get_tx_count: rs_mysql_state_get_tx_count, + get_tx: rs_mysql_state_get_tx, + tx_comp_st_ts: 1, + tx_comp_st_tc: 1, + tx_get_progress: rs_mysql_tx_get_alstate_progress, + get_eventinfo: Some(MysqlEvent::get_event_info), + get_eventinfo_byid: Some(MysqlEvent::get_event_info_by_id), + localstorage_new: None, + localstorage_free: None, + get_tx_files: None, + get_tx_iterator: Some( + crate::applayer::state_get_tx_iterator::, + ), + get_tx_data: rs_mysql_get_tx_data, + get_state_data: rs_mysql_get_state_data, + apply_tx_config: None, + flags: APP_LAYER_PARSER_OPT_ACCEPT_GAPS, + get_frame_id_by_name: None, + get_frame_name_by_id: None, + }; + + let ip_proto_str = CString::new("tcp").unwrap(); + + if AppLayerProtoDetectConfProtoDetectionEnabled(ip_proto_str.as_ptr(), parser.name) != 0 { + let alproto = AppLayerRegisterProtocolDetection(&parser, 1); + ALPROTO_MYSQL = alproto; + if AppLayerParserConfParserEnabled(ip_proto_str.as_ptr(), parser.name) != 0 { + let _ = AppLayerRegisterParser(&parser, alproto); + } + SCLogDebug!("Rust mysql parser registered."); + let retval = conf_get("app-layer.protocols.mysql.stream-depth"); + if let Some(val) = retval { + match get_memval(val) { + Ok(retval) => { + stream_depth = retval as u32; + } + Err(_) => { + SCLogError!("Invalid depth value"); + } + } + AppLayerParserSetStreamDepth(IPPROTO_TCP, ALPROTO_MYSQL, stream_depth) + } + if let Some(val) = conf_get("app-layer.protocols.mysql.max-tx") { + if let Ok(v) = val.parse::() { + MYSQL_MAX_TX = v; + } else { + SCLogError!("Invalid value for mysql.max-tx"); + } + } + AppLayerParserRegisterLogger(IPPROTO_TCP, ALPROTO_MYSQL); + } else { + SCLogDebug!("Protocol detector and parser disabled for MYSQL."); + } +} diff --git a/rust/src/mysql/parser.rs b/rust/src/mysql/parser.rs new file mode 100644 index 000000000000..d74bbcaed5eb --- /dev/null +++ b/rust/src/mysql/parser.rs @@ -0,0 +1,2354 @@ +/* Copyright (C) 2024 Open Information Security Foundation + * + * You can copy, redistribute or modify this Program under the terms of + * the GNU General Public License version 2 as published by the Free + * Software Foundation. + * + * This program is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + * GNU General Public License for more details. + * + * You should have received a copy of the GNU General Public License + * version 2 along with this program; if not, write to the Free Software + * Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA + * 02110-1301, USA. + */ + +// Author: QianKaiLin + +//! MySQL nom parsers + +use nom7::{ + bytes::streaming::{take, take_till}, + combinator::{cond, map, verify}, + multi::{many_m_n, many_till}, + number::streaming::{ + be_i8, be_u32, be_u8, le_f32, le_f64, le_i16, le_i32, le_i64, le_u16, le_u24, le_u32, + le_u64, + }, + IResult, +}; +use num::{FromPrimitive, ToPrimitive}; +use suricata_derive::EnumStringU8; + +#[allow(dead_code)] +pub const CLIENT_LONG_PASSWORD: u32 = BIT_U32!(0); +#[allow(dead_code)] +pub const CLIENT_FOUND_ROWS: u32 = BIT_U32!(1); +#[allow(dead_code)] +pub const CLIENT_LONG_FLAG: u32 = BIT_U32!(2); +const CLIENT_CONNECT_WITH_DB: u32 = BIT_U32!(3); +#[allow(dead_code)] +const CLIENT_NO_SCHEMA: u32 = BIT_U32!(4); +#[allow(dead_code)] +const CLIENT_COMPRESS: u32 = BIT_U32!(5); +#[allow(dead_code)] +const CLIENT_ODBC: u32 = BIT_U32!(6); +#[allow(dead_code)] +const CLIENT_LOCAL_FILES: u32 = BIT_U32!(7); +#[allow(dead_code)] +const CLIENT_IGNORE_SPACE: u32 = BIT_U32!(8); +const CLIENT_PROTOCOL_41: u32 = BIT_U32!(9); +#[allow(dead_code)] +const CLIENT_INTERACTIVE: u32 = BIT_U32!(10); +pub const CLIENT_SSL: u32 = BIT_U32!(11); +#[allow(dead_code)] +pub const CLIENT_IGNORE_SIGPIPE: u32 = BIT_U32!(12); +#[allow(dead_code)] +pub const CLIENT_TRANSACTIONS: u32 = BIT_U32!(13); +#[allow(dead_code)] +pub const CLIENT_RESERVED: u32 = BIT_U32!(14); +#[allow(dead_code)] +pub const CLIENT_RESERVED2: u32 = BIT_U32!(15); +#[allow(dead_code)] +pub const CLIENT_MULTI_STATEMENTS: u32 = BIT_U32!(16); +#[allow(dead_code)] +pub const CLIENT_MULTI_RESULTS: u32 = BIT_U32!(17); +#[allow(dead_code)] +pub const CLIENT_PS_MULTI_RESULTS: u32 = BIT_U32!(18); +#[allow(dead_code)] +pub const CLIENT_PLUGIN_AUTH: u32 = BIT_U32!(19); +#[allow(dead_code)] +pub const CLIENT_CONNECT_ATTRS: u32 = BIT_U32!(20); +#[allow(dead_code)] +pub const CLIENT_PLUGIN_AUTH_LENENC_CLIENT_DATA: u32 = BIT_U32!(21); +#[allow(dead_code)] +pub const CLIENT_CAN_HANDLE_EXPIRED_PASSWORDS: u32 = BIT_U32!(22); +#[allow(dead_code)] +pub const CLIENT_SESSION_TRACK: u32 = BIT_U32!(23); +pub const CLIENT_DEPRECATE_EOF: u32 = BIT_U32!(24); +#[allow(dead_code)] +pub const CLIENT_OPTIONAL_RESULTSET_METADATA: u32 = BIT_U32!(25); +#[allow(dead_code)] +pub const CLIENT_ZSTD_COMPRESSION_ALGORITHM: u32 = BIT_U32!(26); +#[allow(dead_code)] +pub const CLIENT_QUERY_ATTRIBUTES: u32 = BIT_U32!(27); +#[allow(dead_code)] +pub const MULTI_FACTOR_AUTHENTICATION: u32 = BIT_U32!(28); +#[allow(dead_code)] +pub const CLIENT_CAPABILITY_EXTENSION: u32 = BIT_U32!(29); +#[allow(dead_code)] +pub const CLIENT_SSL_VERIFY_SERVER_CERT: u32 = BIT_U32!(30); +#[allow(dead_code)] +pub const CLIENT_REMEMBER_OPTIONS: u32 = BIT_U32!(31); + +#[allow(dead_code)] +pub const FIELD_FLAGS_UNSIGNED: u32 = BIT_U32!(5); + +const PAYLOAD_MAX_LEN: u32 = 0xffffff; + +#[repr(u8)] +#[derive(Debug, Clone, Copy, EnumStringU8, FromPrimitive)] +pub enum FieldType { + Decimal = 0, + Tiny = 1, + Short = 2, + Long = 3, + Float = 4, + Double = 5, + NULL = 6, + Timestamp = 7, + LongLong = 8, + Int24 = 9, + Date = 10, + Time = 11, + Datetime = 12, + Year = 13, + NewDate = 14, + Varchar = 15, + Bit = 16, + Timestamp2 = 17, + Datetime2 = 18, + Time2 = 19, + Array = 20, + Unknown = 241, + Vector = 242, + Invalid = 243, + Bool = 244, + Json = 245, + NewDecimal = 246, + Enum = 247, + Set = 248, + TinyBlob = 249, + MediumBlob = 250, + LongBlob = 251, + Blob = 252, + VarString = 253, + String = 254, + Geometry = 255, +} + +#[inline] +fn parse_field_type(field_type: u8) -> FieldType { + if let Some(f) = FromPrimitive::from_u8(field_type) { + f + } else { + FieldType::Invalid + } +} + +#[derive(Debug)] +pub struct MysqlPacket<'a> { + pub pkt_len: usize, + pub pkt_num: u8, + + payload: &'a mut [u8], +} + +impl<'a> Drop for MysqlPacket<'a> { + fn drop(&mut self) { + unsafe { + std::mem::drop(Box::from_raw(self.payload as *mut [u8])); + } + } +} + +#[derive(Debug)] +pub struct MysqlEofPacket { + pub warnings: u16, + pub status_flags: u16, +} + +#[derive(Debug, PartialEq, Eq, Clone)] +pub struct StmtLongData { + pub statement_id: u32, + pub param_id: u16, + pub payload: String, +} + +#[derive(Debug, PartialEq, Eq, Clone)] +pub enum MysqlCommand { + Unknown, + Quit, + Ping, + Statistics, + Debug, + ChangeUser, + ResetConnection, + SetOption, + InitDb { + schema: String, + }, + Query { + query: String, + }, + FieldList { + table: String, + }, + StmtPrepare { + query: String, + }, + StmtSendLongData(StmtLongData), + StmtExecute { + statement_id: u32, + params: Option>, + }, + StmtFetch { + statement_id: u32, + number_rows: u32, + }, + StmtReset { + statement_id: u32, + }, + StmtClose { + statement_id: u32, + }, +} + +impl std::fmt::Display for MysqlCommand { + fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result { + match self { + MysqlCommand::Quit => write!(f, "quit"), + MysqlCommand::Query { query } => write!(f, "{}", query), + MysqlCommand::Ping => write!(f, "ping"), + _ => write!(f, ""), + } + } +} + +#[derive(Debug, Clone)] +pub struct MysqlColumnDefinition { + pub catalog: String, + pub schema: String, + pub table: String, + pub orig_table: String, + pub name: String, + pub character_set: u16, + pub column_length: u32, + pub field_type: FieldType, + pub flags: u16, + pub decimals: u8, +} + +#[derive(Debug)] +pub struct MysqlResultSetRow { + pub texts: Vec, +} + +#[derive(Debug)] +pub enum MysqlResultBinarySetRow { + Err, + Text(String), +} + +#[derive(Debug)] +pub struct MysqlHandshakeRequest { + // pub header: MysqlPacket, + pub protocol: u8, + pub version: String, + pub conn_id: u32, + pub salt1: String, + pub capability_flags1: u16, + pub character_set: u8, + pub status_flags: u16, + pub capability_flags2: u16, + pub auth_plugin_len: u8, + pub salt2: String, + pub auth_plugin_data: Option, +} + +#[derive(Debug)] +pub struct MysqlHandshakeResponseAttribute { + pub key: String, + pub value: String, +} + +#[derive(Debug)] +pub struct MysqlSSLRequest { + pub filter: Option, +} + +#[derive(Debug)] +pub struct MysqlHandshakeResponse { + pub username: String, + pub auth_response_len: u8, + pub auth_response: String, + pub database: Option, + pub client_flags: u32, + pub client_plugin_name: Option, + pub attributes: Option>, + pub zstd_compression_level: Option, +} + +#[derive(Debug)] +pub struct MysqlAuthSwtichRequest { + pub plugin_name: String, + pub plugin_data: String, +} + +#[derive(Debug)] +pub struct MysqlRequest { + // pub header: MysqlPacket, + pub command_code: u8, + pub command: MysqlCommand, +} + +#[derive(Debug)] +pub enum MysqlResponsePacket { + Unknown, + AuthMoreData { + data: u8, + }, + LocalInFileRequest, + AuthData, + Statistics, + AuthSwithRequest, + EOF, + Ok { + rows: u64, + flags: u16, + warnings: u16, + }, + Err { + error_code: u16, + error_message: String, + }, + FieldsList { + columns: Option>, + }, + ResultSet { + n_cols: u64, + columns: Vec, + eof: MysqlEofPacket, + rows: Vec, + }, + BinaryResultSet { + n_cols: u64, + eof: MysqlEofPacket, + rows: Vec, + }, + + StmtPrepare { + statement_id: u32, + num_params: u16, + params: Option>, + fields: Option>, + }, + StmtFetch, +} + +#[derive(Debug)] +pub struct MysqlResponse { + pub item: MysqlResponsePacket, +} + +#[derive(Debug)] +pub enum MysqlBEMessage { + HandshakeRequest(MysqlHandshakeRequest), + Response(MysqlResponse), +} + +#[derive(Debug)] +pub enum MysqlFEMessage { + SSLRequest(MysqlSSLRequest), + AuthRequest, + Request(MysqlRequest), + LocalFileData(usize), + HandshakeResponse(MysqlHandshakeResponse), +} + +fn parse_varint(i: &[u8]) -> IResult<&[u8], u64> { + let (i, length) = be_u8(i)?; + match length { + // 251: NULL + 0xfb => Ok((i, 0)), + // 252: value of following 2 + 0xfc => { + let (i, v0) = be_u8(i)?; + let (i, v1) = be_u8(i)?; + let v0 = v0 as u64; + let v1 = (v1 as u64) << 8; + Ok((i, v0 | v1)) + } + // 253: value of following 3 + 0xfd => { + let (i, v0) = be_u8(i)?; + let (i, v1) = be_u8(i)?; + let (i, v2) = be_u8(i)?; + let v0 = v0 as u64; + let v1 = (v1 as u64) << 8; + let v2 = (v2 as u64) << 16; + Ok((i, v0 | v1 | v2)) + } + // 254: value of following 8 + 0xfe => { + let (i, v0) = be_u8(i)?; + let (i, v1) = be_u8(i)?; + let (i, v2) = be_u8(i)?; + let (i, v3) = be_u8(i)?; + let (i, v4) = be_u8(i)?; + let (i, v5) = be_u8(i)?; + let (i, v6) = be_u8(i)?; + let (i, v7) = be_u8(i)?; + let v0 = v0 as u64; + let v1 = (v1 as u64) << 8; + let v2 = (v2 as u64) << 16; + let v3 = (v3 as u64) << 24; + let v4 = (v4 as u64) << 32; + let v5 = (v5 as u64) << 40; + let v6 = (v6 as u64) << 48; + let v7 = (v7 as u64) << 56; + Ok((i, v0 | v1 | v2 | v3 | v4 | v5 | v6 | v7)) + } + _ => Ok((i, length as u64)), + } +} + +pub fn parse_packet_header(i: &[u8]) -> IResult<&[u8], MysqlPacket> { + let mut payload = Vec::new(); + let mut payload_len: usize = 0; + let mut rem = i; + let mut pkt_num = None; + // Loop until payload length is less than 0xffffff + // https://dev.mysql.com/doc/dev/mysql-server/latest/page_protocol_basic_packets.html#sect_protocol_basic_packets_sending_mt_16mb + loop { + let (i, pkt_len) = verify(le_u24, |&pkt_len| -> bool { pkt_len <= PAYLOAD_MAX_LEN })(rem)?; + payload_len += pkt_len as usize; + let (i, num) = be_u8(i)?; + if pkt_num.is_none() { + pkt_num = Some(num); + } + let (i, rem_payload) = take(pkt_len)(i)?; + rem = i; + // payload extend rem_payload + payload.extend_from_slice(rem_payload); + + if pkt_len < PAYLOAD_MAX_LEN { + break; + } + } + + let pkt_len = payload_len; + let pkt_num = pkt_num.unwrap_or_default(); + // payload extend rem for next parse + let payload = Box::leak(payload.into_boxed_slice()); + Ok(( + rem, + MysqlPacket { + pkt_len, + pkt_num, + payload, + }, + )) +} + +fn parse_eof_packet(i: &[u8]) -> IResult<&[u8], MysqlEofPacket> { + let (rem, header) = parse_packet_header(i)?; + let payload = + unsafe { std::slice::from_raw_parts(header.payload.as_ptr(), header.payload.len()) }; + let (i, _tag) = verify(be_u8, |&x| x == 0xfe)(payload)?; + let (i, warnings) = le_u16(i)?; + let (_, status_flags) = le_u16(i)?; + + Ok(( + rem, + MysqlEofPacket { + warnings, + status_flags, + }, + )) +} + +fn parse_init_db_cmd(i: &[u8]) -> IResult<&[u8], MysqlCommand> { + let (i, schema) = map(take(i.len()), |s: &[u8]| { + String::from_utf8_lossy(s).to_string() + })(i)?; + Ok((i, MysqlCommand::InitDb { schema })) +} + +fn parse_query_cmd(i: &[u8], client_flags: u32) -> IResult<&[u8], MysqlCommand> { + let length = i.len(); + let old = i; + let (i, param_cnt) = cond(client_flags & CLIENT_QUERY_ATTRIBUTES != 0, parse_varint)(i)?; + let (i, _param_set_cnt) = cond( + client_flags & CLIENT_QUERY_ATTRIBUTES != 0, + verify(be_u8, |¶m_set_cnt| param_set_cnt == 1), + )(i)?; + let param_cnt = param_cnt.unwrap_or_default(); + let (i, null_mask) = cond(param_cnt > 0, take((param_cnt + 7) / 8))(i)?; + let (i, new_params_bind_flag) = cond( + param_cnt > 0, + verify(be_u8, |&new_params_bind_flag| new_params_bind_flag == 1), + )(i)?; + let new_params_bind_flag = new_params_bind_flag.unwrap_or_default(); + + let (i, param_types) = cond( + param_cnt > 0 && new_params_bind_flag != 0, + many_m_n( + param_cnt as usize, + param_cnt as usize, + |i| -> IResult<&[u8], (FieldType, bool)> { + let (i, field_type) = be_u8(i)?; + let (i, flags) = be_u8(i)?; + let (i, _param_name) = map(take(length), |s: &[u8]| { + String::from_utf8_lossy(s).to_string() + })(i)?; + + Ok((i, (parse_field_type(field_type), flags != 0))) + }, + ), + )(i)?; + + let mut data = i; + if param_cnt > 0 { + let null_mask = null_mask.unwrap_or_default(); + if let Some(param_types) = param_types { + for i in 0..param_cnt as usize { + if !null_mask.is_empty() && ((null_mask[i >> 3] >> (i & 7)) & 1) == 1 { + continue; + } + let (field_type, unsigned) = param_types.get(i).unwrap(); + + let ch = data; + // Normal + let (ch, _res) = match *field_type { + FieldType::NULL => (ch, "NULL".to_string()), + FieldType::Tiny | FieldType::Bool => { + if *unsigned { + let (ch, v) = be_u8(ch)?; + (ch, v.to_string()) + } else { + let (ch, v) = be_i8(ch)?; + (ch, v.to_string()) + } + } + FieldType::Short | FieldType::Year => { + if *unsigned { + let (ch, v) = le_u16(ch)?; + (ch, v.to_string()) + } else { + let (ch, v) = le_i16(ch)?; + (ch, v.to_string()) + } + } + FieldType::Int24 | FieldType::Long => { + if *unsigned { + let (ch, v) = le_u32(ch)?; + (ch, v.to_string()) + } else { + let (ch, v) = le_i32(ch)?; + (ch, v.to_string()) + } + } + FieldType::LongLong => { + if *unsigned { + let (ch, v) = le_u64(ch)?; + (ch, v.to_string()) + } else { + let (ch, v) = le_i64(ch)?; + (ch, v.to_string()) + } + } + FieldType::Float => { + let (ch, v) = le_f32(ch)?; + (ch, v.to_string()) + } + FieldType::Double => { + let (ch, v) = le_f64(ch)?; + (ch, v.to_string()) + } + FieldType::Decimal + | FieldType::NewDecimal + | FieldType::Varchar + | FieldType::Bit + | FieldType::Enum + | FieldType::Set + | FieldType::TinyBlob + | FieldType::MediumBlob + | FieldType::LongBlob + | FieldType::Blob + | FieldType::VarString + | FieldType::String + | FieldType::Geometry + | FieldType::Json + | FieldType::Vector => { + let (ch, len) = parse_varint(ch)?; + let (ch, data) = map(take(len), |ch: &[u8]| { + String::from_utf8_lossy(ch).to_string() + })(ch)?; + (ch, data) + } + FieldType::Date + | FieldType::NewDate + | FieldType::Datetime + | FieldType::Datetime2 + | FieldType::Timestamp + | FieldType::Timestamp2 + | FieldType::Time + | FieldType::Time2 => { + let (ch, len) = parse_varint(ch)?; + match len { + 0 => (ch, "datetime 0000-00-00 00:00:00.000000".to_string()), + 4 => { + let (ch, year) = le_u16(ch)?; + let (ch, month) = be_u8(ch)?; + let (ch, day) = be_u8(ch)?; + (ch, format!("datetime {:04}-{:02}-{:02}", year, month, day)) + } + 7 => { + let (ch, year) = le_u16(ch)?; + let (ch, month) = be_u8(ch)?; + let (ch, day) = be_u8(ch)?; + let (ch, hour) = be_u8(ch)?; + let (ch, minute) = be_u8(ch)?; + let (ch, second) = be_u8(ch)?; + ( + ch, + format!( + "datetime {:04}-{:02}-{:02} {:02}:{:02}:{:02}", + year, month, day, hour, minute, second + ), + ) + } + 11 => { + let (ch, year) = le_u16(ch)?; + let (ch, month) = be_u8(ch)?; + let (ch, day) = be_u8(ch)?; + let (ch, hour) = be_u8(ch)?; + let (ch, minute) = be_u8(ch)?; + let (ch, second) = be_u8(ch)?; + let (ch, microsecond) = le_u32(ch)?; + ( + ch, + format!( + "datetime {:04}-{:02}-{:02} {:02}:{:02}:{:02}.{:06}", + year, month, day, hour, minute, second, microsecond, + ), + ) + } + _ => { + let (ch, _) = take(len)(ch)?; + (ch, "".to_string()) + } + } + } + _ => (ch, "".to_string()), + }; + data = ch; + } + } + } + let i = data; + + let consumed = old.len() - i.len(); + + // Should never happen + if consumed > length { + return Ok(( + &[], + MysqlCommand::Query { + query: "".to_string(), + }, + )); + } + let length = length - consumed; + + let (i, query) = map(take(length), |s: &[u8]| { + String::from_utf8_lossy(s).to_string() + })(i)?; + + Ok((i, MysqlCommand::Query { query })) +} + +fn parse_stmt_prepare_cmd(i: &[u8]) -> IResult<&[u8], MysqlCommand> { + let length = i.len(); + let (i, query) = map(take(length), |s: &[u8]| { + String::from_utf8_lossy(s).to_string() + })(i)?; + Ok((i, MysqlCommand::StmtPrepare { query })) +} + +fn parse_stmt_send_long_data_cmd(i: &[u8]) -> IResult<&[u8], MysqlCommand> { + let (i, statement_id) = le_u32(i)?; + let (i, param_id) = le_u16(i)?; + let (i, length) = parse_varint(i)?; + let (i, payload) = map(take(length), |s: &[u8]| { + String::from_utf8_lossy(s).to_string() + })(i)?; + Ok(( + i, + MysqlCommand::StmtSendLongData(StmtLongData { + statement_id, + param_id, + payload, + }), + )) +} + +fn parse_stmt_execute_cmd( + i: &[u8], param_cnt: Option, param_types: Option>, + stmt_long_datas: Option>, client_flags: u32, +) -> IResult<&[u8], MysqlCommand> { + let length = i.len(); + let old = i; + let (i, statement_id) = le_u32(i)?; + let (i, flags) = be_u8(i)?; + let (i, _iteration_count) = le_u32(i)?; + + if let Some(param_cnt) = param_cnt { + let mut param_cnt = param_cnt; + if param_cnt > 0 || ((client_flags & CLIENT_QUERY_ATTRIBUTES != 0) && (flags & 8 != 0)) { + let (i, override_param_cnts) = + cond(client_flags & CLIENT_QUERY_ATTRIBUTES != 0, parse_varint)(i)?; + if let Some(override_param_cnts) = override_param_cnts { + param_cnt = override_param_cnts as u16; + } + if param_cnt > 0 { + // NULL-bitmap, [(column-count + 7) / 8 bytes] + let null_bitmap_size = (param_cnt + 7) / 8; + let (i, null_mask) = take(null_bitmap_size)(i)?; + let (i, new_params_bind_flags) = be_u8(i)?; + + let (i, new_param_types) = cond( + new_params_bind_flags != 0, + many_m_n( + param_cnt as usize, + param_cnt as usize, + |ch| -> IResult<&[u8], (FieldType, bool)> { + let (ch, field_type) = be_u8(ch)?; + let (ch, flags) = be_u8(ch)?; + let (ch, _param_names) = + cond(client_flags & CLIENT_QUERY_ATTRIBUTES != 0, |ch| { + let (ch, length) = parse_varint(ch)?; + let (ch, name) = map(take(length), |s| { + String::from_utf8_lossy(s).to_string() + })(ch)?; + Ok((ch, name)) + })(ch)?; + + Ok((ch, (parse_field_type(field_type), flags != 0))) + }, + ), + )(i)?; + let param_types = if let Some(new_param_types) = new_param_types { + Some(new_param_types) + } else { + param_types.map(|param_types| { + param_types + .iter() + .map(|param_type| (param_type.field_type, param_type.flags != 0)) + .collect() + }) + }; + + let consumed = old.len() - i.len(); + // Should never happen + if consumed > length { + return Ok(( + &[], + MysqlCommand::StmtExecute { + statement_id, + params: None, + }, + )); + } + let (i, data) = take(length - consumed)(i)?; + if param_types.is_none() { + return Ok(( + i, + MysqlCommand::StmtExecute { + statement_id, + params: None, + }, + )); + } + + let param_types = param_types.unwrap(); + + let mut data = data; + let mut params = Vec::new(); + for i in 0..param_cnt as usize { + // Field is NULL + // (byte >> bit-pos) % 2 == 1 + if !null_mask.is_empty() && ((null_mask[i >> 3] >> (i & 7)) & 1) == 1 { + params.push("NULL".to_string()); + continue; + } + // Field is LongData + if let Some(stmt_long_datas) = &stmt_long_datas { + for stmt_long_data in stmt_long_datas { + if stmt_long_data.param_id as usize == i { + params.push(stmt_long_data.payload.clone()); + continue; + } + } + } + let (field_type, unsigned) = param_types.get(i).unwrap(); + + let ch = data; + // Normal + let (ch, res) = match *field_type { + FieldType::NULL => (ch, "NULL".to_string()), + FieldType::Tiny | FieldType::Bool => { + if *unsigned { + let (ch, v) = be_u8(ch)?; + (ch, v.to_string()) + } else { + let (ch, v) = be_i8(ch)?; + (ch, v.to_string()) + } + } + FieldType::Short | FieldType::Year => { + if *unsigned { + let (ch, v) = le_u16(ch)?; + (ch, v.to_string()) + } else { + let (ch, v) = le_i16(ch)?; + (ch, v.to_string()) + } + } + FieldType::Int24 | FieldType::Long => { + if *unsigned { + let (ch, v) = le_u32(ch)?; + (ch, v.to_string()) + } else { + let (ch, v) = le_i32(ch)?; + (ch, v.to_string()) + } + } + FieldType::LongLong => { + if *unsigned { + let (ch, v) = le_u64(ch)?; + (ch, v.to_string()) + } else { + let (ch, v) = le_i64(ch)?; + (ch, v.to_string()) + } + } + FieldType::Float => { + let (ch, v) = le_f32(ch)?; + (ch, v.to_string()) + } + FieldType::Double => { + let (ch, v) = le_f64(ch)?; + (ch, v.to_string()) + } + FieldType::Decimal + | FieldType::NewDecimal + | FieldType::Varchar + | FieldType::Bit + | FieldType::Enum + | FieldType::Set + | FieldType::TinyBlob + | FieldType::MediumBlob + | FieldType::LongBlob + | FieldType::Blob + | FieldType::VarString + | FieldType::String + | FieldType::Geometry + | FieldType::Json + | FieldType::Vector => { + let (ch, len) = parse_varint(ch)?; + let (ch, data) = map(take(len), |ch: &[u8]| { + String::from_utf8_lossy(ch).to_string() + })(ch)?; + (ch, data) + } + FieldType::Date + | FieldType::NewDate + | FieldType::Datetime + | FieldType::Datetime2 + | FieldType::Timestamp + | FieldType::Timestamp2 + | FieldType::Time + | FieldType::Time2 => { + let (ch, len) = parse_varint(ch)?; + match len { + 0 => (ch, "datetime 0000-00-00 00:00:00.000000".to_string()), + 4 => { + let (ch, year) = le_u16(ch)?; + let (ch, month) = be_u8(ch)?; + let (ch, day) = be_u8(ch)?; + (ch, format!("datetime {:04}-{:02}-{:02}", year, month, day)) + } + 7 => { + let (ch, year) = le_u16(ch)?; + let (ch, month) = be_u8(ch)?; + let (ch, day) = be_u8(ch)?; + let (ch, hour) = be_u8(ch)?; + let (ch, minute) = be_u8(ch)?; + let (ch, second) = be_u8(ch)?; + ( + ch, + format!( + "datetime {:04}-{:02}-{:02} {:02}:{:02}:{:02}", + year, month, day, hour, minute, second + ), + ) + } + 11 => { + let (ch, year) = le_u16(ch)?; + let (ch, month) = be_u8(ch)?; + let (ch, day) = be_u8(ch)?; + let (ch, hour) = be_u8(ch)?; + let (ch, minute) = be_u8(ch)?; + let (ch, second) = be_u8(ch)?; + let (ch, microsecond) = le_u32(ch)?; + ( + ch, + format!( + "datetime {:04}-{:02}-{:02} {:02}:{:02}:{:02}.{:06}", + year, month, day, hour, minute, second, microsecond, + ), + ) + } + _ => { + let (ch, _) = take(len)(ch)?; + (ch, "".to_string()) + } + } + } + _ => (ch, "".to_string()), + }; + params.push(res); + data = ch; + } + Ok(( + i, + MysqlCommand::StmtExecute { + statement_id, + params: Some(params), + }, + )) + } else { + Ok(( + i, + MysqlCommand::StmtExecute { + statement_id, + params: None, + }, + )) + } + } else { + Ok(( + i, + MysqlCommand::StmtExecute { + statement_id, + params: None, + }, + )) + } + } else { + let consumed = old.len() - i.len(); + // Should never happen + if consumed > length { + return Ok(( + &[], + MysqlCommand::StmtExecute { + statement_id, + params: None, + }, + )); + } + let (i, _) = take(length - consumed)(i)?; + Ok(( + i, + MysqlCommand::StmtExecute { + statement_id, + params: None, + }, + )) + } +} + +fn parse_field_list_cmd(i: &[u8]) -> IResult<&[u8], MysqlCommand> { + let length = i.len(); + let old = i; + let (i, table) = map(take_till(|ch| ch == 0x00), |s: &[u8]| { + String::from_utf8_lossy(s).to_string() + })(i)?; + let consumed = old.len() - i.len(); + // Should never happen + if consumed > length { + return Ok(( + &[], + MysqlCommand::FieldList { + table: "".to_string(), + }, + )); + } + let (i, _) = take(length - consumed)(i)?; + Ok((i, MysqlCommand::FieldList { table })) +} + +fn parse_stmt_fetch_cmd(i: &[u8]) -> IResult<&[u8], MysqlCommand> { + let (i, statement_id) = le_u32(i)?; + let (i, number_rows) = le_u32(i)?; + Ok(( + i, + MysqlCommand::StmtFetch { + statement_id, + number_rows, + }, + )) +} + +fn parse_stmt_close_cmd(i: &[u8]) -> IResult<&[u8], MysqlCommand> { + let (i, statement_id) = le_u32(i)?; + Ok((i, MysqlCommand::StmtClose { statement_id })) +} + +fn parse_column_definition(i: &[u8]) -> IResult<&[u8], MysqlColumnDefinition> { + let (rem, header) = parse_packet_header(i)?; + let payload = + unsafe { std::slice::from_raw_parts(header.payload.as_ptr(), header.payload.len()) }; + let (i, _len) = parse_varint(payload)?; + let (i, _catalog) = map(take(_len as u32), |s: &[u8]| { + String::from_utf8_lossy(s).to_string() + })(i)?; + + let (i, _len) = parse_varint(i)?; + let (i, schema) = map(take(_len as u32), |s: &[u8]| { + String::from_utf8_lossy(s).to_string() + })(i)?; + + let (i, _len) = parse_varint(i)?; + let (i, table) = map(take(_len as u32), |s: &[u8]| { + String::from_utf8_lossy(s).to_string() + })(i)?; + + let (i, _len) = parse_varint(i)?; + let (i, orig_table) = map(take(_len as u32), |s: &[u8]| { + String::from_utf8_lossy(s).to_string() + })(i)?; + + let (i, _len) = parse_varint(i)?; + let (i, name) = map(take(_len as u32), |s: &[u8]| { + String::from_utf8_lossy(s).to_string() + })(i)?; + + let (i, _len) = parse_varint(i)?; + let (i, _orig_name) = map(take(_len as u32), |s: &[u8]| { + String::from_utf8_lossy(s).to_string() + })(i)?; + + let (i, _) = parse_varint(i)?; + let (i, character_set) = le_u16(i)?; + let (i, column_length) = le_u32(i)?; + let (i, field_type) = be_u8(i)?; + let (i, flags) = le_u16(i)?; + let (i, decimals) = be_u8(i)?; + let (_, _filter) = take(2_u32)(i)?; + + let field_type = parse_field_type(field_type); + + Ok(( + rem, + MysqlColumnDefinition { + catalog: "def".to_string(), + schema, + table, + orig_table, + name, + character_set, + column_length, + field_type, + flags, + decimals, + }, + )) +} + +fn parse_resultset_row_texts(i: &[u8]) -> IResult<&[u8], Vec> { + let mut rem = i; + let mut length = i.len(); + let mut texts = Vec::new(); + while length > 0 { + let (i, len) = parse_varint(rem)?; + let mut consumed = rem.len() - i.len(); + if len == 0xFB { + texts.push("NULL".to_string()); + rem = i; + } else { + let (i, text) = map(take(len), |s: &[u8]| String::from_utf8_lossy(s).to_string())(i)?; + texts.push(text); + consumed += len as usize; + rem = i; + } + // Should never happen + if consumed > length { + return Ok((&[], texts)); + } + length -= consumed; + } + + Ok((&[], texts)) +} + +fn parse_resultset_row(i: &[u8]) -> IResult<&[u8], MysqlResultSetRow> { + let (rem, header) = parse_packet_header(i)?; + let payload = + unsafe { std::slice::from_raw_parts(header.payload.as_ptr(), header.payload.len()) }; + let (_, texts) = parse_resultset_row_texts(payload)?; + + Ok((rem, MysqlResultSetRow { texts })) +} + +fn parse_binary_resultset_row( + columns: Vec, +) -> impl FnMut(&[u8]) -> IResult<&[u8], MysqlResultBinarySetRow> { + move |i| { + let (rem, header) = parse_packet_header(i)?; + let payload = + unsafe { std::slice::from_raw_parts(header.payload.as_ptr(), header.payload.len()) }; + let (i, response_code) = verify(be_u8, |&x| x == 0x00 || x == 0xFF)(payload)?; + // ERR + if response_code == 0xFF { + let (_, _resp) = parse_response_err(i)?; + return Ok((rem, MysqlResultBinarySetRow::Err)); + } + let (_, data) = take(header.pkt_len - 1)(i)?; + + // NULL-bitmap, [(column-count + 7 + 2) / 8 bytes] + let mut texts = Vec::new(); + let mut pos = (columns.len() + 7 + 2) >> 3; + let null_mask = &data[..pos]; + for i in 0..columns.len() { + // Field is NULL + // byte = ((field-pos + 2) / 8) + // bit-pos = ((field-pos + 2) % 8) + // (byte >> bit-pos) % 2 == 1 + if ((null_mask[(i + 2) >> 3] >> ((i + 2) & 7)) & 1) == 1 { + continue; + } + + match columns[i].field_type { + FieldType::NULL => texts.push("NULL".to_string()), + FieldType::Tiny => { + if columns[i].flags & (FIELD_FLAGS_UNSIGNED as u16) != 0 { + texts.push(format!("{}", data[pos].to_u8().unwrap_or_default())); + } else { + texts.push(format!("{}", data[pos].to_i8().unwrap_or_default())); + } + pos += 1; + } + FieldType::Short | FieldType::Year => { + if columns[i].flags & (FIELD_FLAGS_UNSIGNED as u16) != 0 { + texts.push(format!( + "{}", + u16::from_le_bytes(data[pos..pos + 2].try_into().unwrap_or_default()) + )); + } else { + texts.push(format!( + "{}", + i16::from_le_bytes(data[pos..pos + 2].try_into().unwrap_or_default()) + )); + } + pos += 2; + } + FieldType::Int24 | FieldType::Long => { + if columns[i].flags & (FIELD_FLAGS_UNSIGNED as u16) != 0 { + texts.push(format!( + "{}", + u32::from_le_bytes(data[pos..pos + 4].try_into().unwrap_or_default()) + )); + } else { + texts.push(format!( + "{}", + i32::from_le_bytes(data[pos..pos + 4].try_into().unwrap_or_default()) + )); + } + pos += 4; + } + FieldType::LongLong => { + if columns[i].flags & (FIELD_FLAGS_UNSIGNED as u16) != 0 { + texts.push(format!( + "{}", + u64::from_le_bytes(data[pos..pos + 8].try_into().unwrap_or_default()) + )); + } else { + texts.push(format!( + "{}", + i64::from_le_bytes(data[pos..pos + 8].try_into().unwrap_or_default()) + )); + } + pos += 8; + } + FieldType::Float => { + texts.push(format!( + "{}", + f32::from_le_bytes(data[pos..pos + 4].try_into().unwrap_or_default()) + )); + pos += 4; + } + FieldType::Double => { + texts.push(format!( + "{}", + f64::from_le_bytes(data[pos..pos + 8].try_into().unwrap_or_default()) + )); + pos += 8; + } + FieldType::Decimal + | FieldType::NewDecimal + | FieldType::Varchar + | FieldType::Bit + | FieldType::Enum + | FieldType::Set + | FieldType::TinyBlob + | FieldType::MediumBlob + | FieldType::LongBlob + | FieldType::Blob + | FieldType::VarString + | FieldType::String + | FieldType::Geometry + | FieldType::Json + | FieldType::Vector => { + let length_string = &data[pos..]; + let (not_readed, length) = parse_varint(length_string)?; + if length_string.len() < length as usize { + break; + } + pos += length_string.len() - not_readed.len(); + if length > 0 { + let (_, string) = + map(take(length), |s| String::from_utf8_lossy(s).to_string())( + not_readed, + )?; + texts.push(string); + pos += length as usize; + } + } + FieldType::Date + | FieldType::NewDate + | FieldType::Datetime + | FieldType::Datetime2 + | FieldType::Timestamp + | FieldType::Timestamp2 + | FieldType::Time + | FieldType::Time2 => { + let length_string = &data[pos..]; + let (not_readed, length) = parse_varint(length_string)?; + if length_string.len() < length as usize { + break; + } + pos += length_string.len() - not_readed.len(); + let string = match length { + 0 => "datetime 0000-00-00 00:00:00.000000".to_string(), + 4 => { + let (ch, year) = le_u16(not_readed)?; + let (ch, month) = be_u8(ch)?; + let (_, day) = be_u8(ch)?; + format!("datetime {:04}-{:02}-{:02}", year, month, day) + } + 7 => { + let (ch, year) = le_u16(not_readed)?; + let (ch, month) = be_u8(ch)?; + let (ch, day) = be_u8(ch)?; + let (ch, hour) = be_u8(ch)?; + let (ch, minute) = be_u8(ch)?; + let (_, second) = be_u8(ch)?; + format!( + "datetime {:04}-{:02}-{:02} {:02}:{:02}:{:02}", + year, month, day, hour, minute, second + ) + } + 11 => { + let (ch, year) = le_u16(not_readed)?; + let (ch, month) = be_u8(ch)?; + let (ch, day) = be_u8(ch)?; + let (ch, hour) = be_u8(ch)?; + let (ch, minute) = be_u8(ch)?; + let (ch, second) = be_u8(ch)?; + let (_, microsecond) = le_u32(ch)?; + format!( + "datetime {:04}-{:02}-{:02} {:02}:{:02}:{:02}.{:06}", + year, month, day, hour, minute, second, microsecond, + ) + } + _ => "".to_string(), + }; + pos += length as usize; + texts.push(string); + } + _ => { + break; + } + } + } + let texts = texts.join(","); + + Ok((rem, MysqlResultBinarySetRow::Text(texts))) + } +} + +fn parse_response_resultset(i: &[u8], n_cols: u64) -> IResult<&[u8], MysqlResponse> { + let (i, columns) = many_m_n(n_cols as usize, n_cols as usize, parse_column_definition)(i)?; + let (i, eof) = parse_eof_packet(i)?; + let (i, (rows, _)) = many_till(parse_resultset_row, |i| { + let (rem, header) = parse_packet_header(i)?; + let payload = + unsafe { std::slice::from_raw_parts(header.payload.as_ptr(), header.payload.len()) }; + let (i, response_code) = verify(be_u8, |&x| x == 0xFE || x == 0xFF)(payload)?; + match response_code { + // EOF + 0xFE => Ok(( + rem, + MysqlResponse { + item: MysqlResponsePacket::EOF, + }, + )), + // ERR + 0xFF => { + let (_, response) = parse_response_err(i)?; + Ok((rem, response)) + } + _ => Ok(( + rem, + MysqlResponse { + item: MysqlResponsePacket::Unknown, + }, + )), + } + })(i)?; + Ok(( + i, + MysqlResponse { + item: MysqlResponsePacket::ResultSet { + n_cols, + columns, + eof, + rows, + }, + }, + )) +} + +fn parse_response_binary_resultset(i: &[u8], n_cols: u64) -> IResult<&[u8], MysqlResponse> { + let (i, columns) = many_m_n(n_cols as usize, n_cols as usize, parse_column_definition)(i)?; + let (i, eof) = parse_eof_packet(i)?; + let (i, (rows, _)) = many_till(parse_binary_resultset_row(columns), |i| { + // eof + let (rem, header) = parse_packet_header(i)?; + let payload = + unsafe { std::slice::from_raw_parts(header.payload.as_ptr(), header.payload.len()) }; + let (i, response_code) = verify(be_u8, |&x| x == 0xFE || x == 0xFF)(payload)?; + match response_code { + // EOF + 0xFE => Ok(( + rem, + MysqlResponse { + item: MysqlResponsePacket::EOF, + }, + )), + // ERR + 0xFF => { + let (_, response) = parse_response_err(i)?; + Ok((rem, response)) + } + _ => Ok(( + rem, + MysqlResponse { + item: MysqlResponsePacket::Unknown, + }, + )), + } + })(i)?; + Ok(( + i, + MysqlResponse { + item: MysqlResponsePacket::BinaryResultSet { n_cols, eof, rows }, + }, + )) +} + +fn parse_response_ok(i: &[u8]) -> IResult<&[u8], MysqlResponse> { + let length = i.len(); + let old = i; + let (i, rows) = parse_varint(i)?; + let (i, _last_insert_id) = parse_varint(i)?; + let (i, flags) = le_u16(i)?; + let (i, warnings) = le_u16(i)?; + let consumed = old.len() - i.len(); + // Should never happen + if consumed > length { + return Ok(( + &[], + MysqlResponse { + item: MysqlResponsePacket::Ok { + rows, + flags, + warnings, + }, + }, + )); + } + let (i, _) = take(length - consumed)(i)?; + + Ok(( + i, + MysqlResponse { + item: MysqlResponsePacket::Ok { + rows, + flags, + warnings, + }, + }, + )) +} + +fn parse_response_err(i: &[u8]) -> IResult<&[u8], MysqlResponse> { + let length = i.len(); + let (i, error_code) = le_u16(i)?; + let (i, _) = take(6_u32)(i)?; + // sql state maker & sql state + let (i, _) = take(6_u32)(i)?; + let length = length - 2 - 12; + let (i, error_message) = map(take(length), |s: &[u8]| { + String::from_utf8_lossy(s).to_string() + })(i)?; + Ok(( + i, + MysqlResponse { + item: MysqlResponsePacket::Err { + error_code, + error_message, + }, + }, + )) +} + +pub fn parse_handshake_request(i: &[u8]) -> IResult<&[u8], MysqlHandshakeRequest> { + let (rem, header) = parse_packet_header(i)?; + let payload = + unsafe { std::slice::from_raw_parts(header.payload.as_ptr(), header.payload.len()) }; + let (i, protocol) = verify(be_u8, |&x| x == 0x0a_u8)(payload)?; + let (i, version) = map(take_till(|ch| ch == 0x00), |s: &[u8]| { + String::from_utf8_lossy(s).to_string() + })(i)?; + let (i, _) = take(1_u32)(i)?; + let (i, conn_id) = le_u32(i)?; + let (i, salt1) = map(take(8_u32), |s: &[u8]| { + String::from_utf8_lossy(s).to_string() + })(i)?; + let (i, _) = take(1_u32)(i)?; + let (i, capability_flags1) = le_u16(i)?; + let (i, character_set) = be_u8(i)?; + let (i, status_flags) = le_u16(i)?; + let (i, capability_flags2) = le_u16(i)?; + let (i, auth_plugin_len) = be_u8(i)?; + let (i, _) = take(10_u32)(i)?; + let (i, salt2) = map(take_till(|ch| ch == 0x00), |s: &[u8]| { + String::from_utf8_lossy(s).to_string() + })(i)?; + let (i, _) = take(1_u32)(i)?; + let (i, auth_plugin_data) = cond( + auth_plugin_len > 0, + map(take(auth_plugin_len as usize), |s: &[u8]| { + String::from_utf8_lossy(s).to_string() + }), + )(i)?; + let (_, _) = take(1_u32)(i)?; + Ok(( + rem, + MysqlHandshakeRequest { + protocol, + version, + conn_id, + salt1, + capability_flags1, + character_set, + status_flags, + capability_flags2, + auth_plugin_len, + salt2, + auth_plugin_data, + }, + )) +} + +pub fn parse_handshake_capabilities(i: &[u8]) -> IResult<&[u8], u32> { + let (rem, header) = parse_packet_header(i)?; + let payload = + unsafe { std::slice::from_raw_parts(header.payload.as_ptr(), header.payload.len()) }; + let (i, client_flags) = verify(le_u32, |&client_flags| { + client_flags & CLIENT_PROTOCOL_41 != 0 + })(payload)?; + let (i, _max_packet_size) = be_u32(i)?; + let (_, _character_set) = be_u8(i)?; + + // fk this code + Ok((rem, client_flags)) +} + +pub fn parse_handshake_ssl_request(i: &[u8]) -> IResult<&[u8], MysqlSSLRequest> { + let (rem, header) = parse_packet_header(i)?; + let payload = + unsafe { std::slice::from_raw_parts(header.payload.as_ptr(), header.payload.len()) }; + let (i, _client_flags) = verify(le_u32, |&client_flags| { + client_flags & CLIENT_PROTOCOL_41 != 0 + })(payload)?; + let (i, _max_packet_size) = be_u32(i)?; + let (i, _character_set) = be_u8(i)?; + let (_, filter) = map(take(23_u32), |s: &[u8]| { + String::from_utf8_lossy(s).to_string() + })(i)?; + Ok(( + rem, + MysqlSSLRequest { + filter: Some(filter), + }, + )) +} + +pub fn parse_handshake_response(i: &[u8]) -> IResult<&[u8], MysqlHandshakeResponse> { + let (rem, header) = parse_packet_header(i)?; + let payload = + unsafe { std::slice::from_raw_parts(header.payload.as_ptr(), header.payload.len()) }; + let (i, client_flags) = verify(le_u32, |&client_flags| { + client_flags & CLIENT_PROTOCOL_41 != 0 + })(payload)?; + let (i, _max_packet_size) = be_u32(i)?; + let (i, _character_set) = be_u8(i)?; + + let (i, _filter) = map(take(23_u32), |s: &[u8]| { + String::from_utf8_lossy(s).to_string() + })(i)?; + let (i, username) = map(take_till(|ch| ch == 0x00), |s: &[u8]| { + String::from_utf8_lossy(s).to_string() + })(i)?; + let (i, _) = take(1_u32)(i)?; + let (i, auth_response_len) = be_u8(i)?; + let (i, auth_response) = map(take(auth_response_len as usize), |s: &[u8]| { + String::from_utf8_lossy(s).to_string() + })(i)?; + + let (i, database) = cond( + client_flags & CLIENT_CONNECT_WITH_DB != 0, + map(take_till(|ch| ch == 0x00), |s: &[u8]| { + String::from_utf8_lossy(s).to_string() + }), + )(i)?; + let (i, _) = cond(database.is_some(), take(1_u32))(i)?; + + let (i, client_plugin_name) = cond( + client_flags & CLIENT_PLUGIN_AUTH != 0, + map(take_till(|ch| ch == 0x00), |s: &[u8]| { + String::from_utf8_lossy(s).to_string() + }), + )(i)?; + let (i, _) = cond(client_plugin_name.is_some(), take(1_u32))(i)?; + + let (i, length) = cond(client_flags & CLIENT_CONNECT_ATTRS != 0, be_u8)(i)?; + + let (i, attributes) = cond( + length.is_some(), + parse_handshake_response_attributes(length), + )(i)?; + + let (_, zstd_compression_level) = + cond(client_flags & CLIENT_ZSTD_COMPRESSION_ALGORITHM != 0, be_u8)(i)?; + Ok(( + rem, + MysqlHandshakeResponse { + username, + auth_response_len, + auth_response, + database, + client_plugin_name, + attributes, + zstd_compression_level, + client_flags, + }, + )) +} + +fn parse_handshake_response_attributes( + length: Option, +) -> impl FnMut(&[u8]) -> IResult<&[u8], Vec> { + move |i| { + if length.is_none() { + return Ok((i, Vec::new())); + } + let mut length = length.unwrap(); + let mut res = vec![]; + let mut rem = i; + while length > 0 { + let (i, key_len) = be_u8(rem)?; + // length contains key_len + length -= 1; + let (i, key) = map(take(key_len as usize), |s: &[u8]| { + String::from_utf8_lossy(s).to_string() + })(i)?; + let (i, value_len) = be_u8(i)?; + // length contains value_len + length -= 1; + let (i, value) = map(take(value_len as usize), |s: &[u8]| { + String::from_utf8_lossy(s).to_string() + })(i)?; + res.push(MysqlHandshakeResponseAttribute { key, value }); + length -= key_len + value_len; + rem = i; + } + + Ok((rem, res)) + } +} + +pub fn parse_auth_response(i: &[u8]) -> IResult<&[u8], MysqlResponse> { + let (rem, header) = parse_packet_header(i)?; + let payload = + unsafe { std::slice::from_raw_parts(header.payload.as_ptr(), header.payload.len()) }; + let (i, status) = be_u8(payload)?; + match status { + 0x00 => { + let (_, response) = parse_response_ok(i)?; + Ok((rem, response)) + } + // AuthMoreData + 0x01 => { + let (_i, data) = be_u8(i)?; + Ok(( + rem, + MysqlResponse { + item: MysqlResponsePacket::AuthMoreData { data }, + }, + )) + } + 0xEF => Ok(( + rem, + MysqlResponse { + item: MysqlResponsePacket::EOF, + }, + )), + _ => Ok(( + rem, + MysqlResponse { + item: MysqlResponsePacket::Unknown, + }, + )), + } +} + +pub fn parse_auth_switch_request(i: &[u8]) -> IResult<&[u8], MysqlAuthSwtichRequest> { + let (rem, header) = parse_packet_header(i)?; + let payload = + unsafe { std::slice::from_raw_parts(header.payload.as_ptr(), header.payload.len()) }; + let plugin_length = payload.len(); + let (i, plugin_name) = map(take_till(|ch| ch == 0x00), |s: &[u8]| { + String::from_utf8_lossy(s).to_string() + })(payload)?; + let plugin_length = plugin_length - i.len(); + let (_, plugin_data) = map( + cond( + header.pkt_len - (plugin_length) > 0, + take(header.pkt_len - plugin_length), + ), + |ch: Option<&[u8]>| { + if let Some(ch) = ch { + String::from_utf8_lossy(ch).to_string() + } else { + String::new() + } + }, + )(i)?; + + Ok(( + rem, + MysqlAuthSwtichRequest { + plugin_name, + plugin_data, + }, + )) +} + +pub fn parse_local_file_data_content(i: &[u8]) -> IResult<&[u8], usize> { + let (rem, header) = parse_packet_header(i)?; + Ok((rem, header.pkt_len)) +} + +pub fn parse_request( + i: &[u8], params: Option, param_types: Option>, + stmt_long_datas: Option>, client_flags: u32, +) -> IResult<&[u8], MysqlRequest> { + let (rem, header) = parse_packet_header(i)?; + let payload = + unsafe { std::slice::from_raw_parts(header.payload.as_ptr(), header.payload.len()) }; + let (i, command_code) = be_u8(payload)?; + match command_code { + 0x01 => Ok(( + rem, + MysqlRequest { + command_code, + command: MysqlCommand::Quit, + }, + )), + + 0x02 => { + let (_, command) = parse_init_db_cmd(i)?; + Ok(( + rem, + MysqlRequest { + command_code, + command, + }, + )) + } + + 0x03 => { + let (_, command) = parse_query_cmd(i, client_flags)?; + Ok(( + rem, + MysqlRequest { + command_code, + command, + }, + )) + } + + 0x04 => { + let (_, command) = parse_field_list_cmd(i)?; + Ok(( + rem, + MysqlRequest { + command_code, + command, + }, + )) + } + + 0x08 => Ok(( + rem, + MysqlRequest { + command_code, + command: MysqlCommand::Statistics, + }, + )), + + 0x0D => Ok(( + rem, + MysqlRequest { + command_code, + command: MysqlCommand::Debug, + }, + )), + + 0x0e => Ok(( + rem, + MysqlRequest { + command_code, + command: MysqlCommand::Ping, + }, + )), + + 0x11 => Ok(( + rem, + MysqlRequest { + command_code, + command: MysqlCommand::ChangeUser, + }, + )), + 0x1A => { + let length = header.pkt_len - 1; + if length == 2 { + let (_, _) = le_u16(i)?; + Ok(( + rem, + MysqlRequest { + command_code, + command: MysqlCommand::SetOption, + }, + )) + } else if length == 4 { + let (_, statement_id) = le_u32(i)?; + Ok(( + rem, + MysqlRequest { + command_code, + command: MysqlCommand::StmtReset { statement_id }, + }, + )) + } else { + Ok(( + rem, + MysqlRequest { + command_code, + command: MysqlCommand::Unknown, + }, + )) + } + } + + 0x1F => Ok(( + rem, + MysqlRequest { + command_code, + command: MysqlCommand::ResetConnection, + }, + )), + + 0x16 => { + let (_, command) = parse_stmt_prepare_cmd(i)?; + Ok(( + rem, + MysqlRequest { + command_code, + command, + }, + )) + } + + 0x17 => { + let (_, command) = + parse_stmt_execute_cmd(i, params, param_types, stmt_long_datas, client_flags)?; + Ok(( + rem, + MysqlRequest { + command_code, + command, + }, + )) + } + + 0x18 => { + let (_, command) = parse_stmt_send_long_data_cmd(i)?; + Ok(( + rem, + MysqlRequest { + command_code, + command, + }, + )) + } + + 0x19 => { + // + if header.pkt_len - 1 == 8 { + let (_, command) = parse_stmt_fetch_cmd(i)?; + Ok(( + rem, + MysqlRequest { + command_code, + command, + }, + )) + } else { + let (_, command) = parse_stmt_close_cmd(i)?; + Ok(( + rem, + MysqlRequest { + command_code, + command, + }, + )) + } + } + + _ => { + SCLogDebug!( + "Unknown request, header: {:?}, command_code: {}", + header, + command_code + ); + let (_, _) = cond(header.pkt_len - 1 > 0, take(header.pkt_len - 1))(i)?; + Ok(( + rem, + MysqlRequest { + command_code, + command: MysqlCommand::Unknown, + }, + )) + } + } +} + +pub fn parse_response(i: &[u8]) -> IResult<&[u8], MysqlResponse> { + let (rem, header) = parse_packet_header(i)?; + let payload = + unsafe { std::slice::from_raw_parts(header.payload.as_ptr(), header.payload.len()) }; + let (i, response_code) = be_u8(payload)?; + match response_code { + // OK + 0x00 => { + let (_, response) = parse_response_ok(i)?; + Ok((rem, response)) + } + // LOCAL INFILE Request + 0xFB => Ok(( + rem, + MysqlResponse { + item: MysqlResponsePacket::LocalInFileRequest, + }, + )), + // EOF + 0xFE => Ok(( + rem, + MysqlResponse { + item: MysqlResponsePacket::EOF, + }, + )), + // ERR + 0xFF => { + let (_, response) = parse_response_err(i)?; + Ok((rem, response)) + } + // Text Resultset + _ => parse_response_resultset(rem, response_code as u64), + } +} + +pub fn parse_change_user_response(i: &[u8]) -> IResult<&[u8], MysqlResponse> { + let (rem, header) = parse_packet_header(i)?; + let payload = + unsafe { std::slice::from_raw_parts(header.payload.as_ptr(), header.payload.len()) }; + let (i, response_code) = be_u8(payload)?; + match response_code { + // OK + 0x00 => { + let (_, response) = parse_response_ok(i)?; + Ok((rem, response)) + } + // AuthSwitch + 0xFE => Ok(( + rem, + MysqlResponse { + item: MysqlResponsePacket::AuthSwithRequest, + }, + )), + // ERR + 0xFF => { + let (_, response) = parse_response_err(i)?; + Ok((rem, response)) + } + _ => Ok(( + rem, + MysqlResponse { + item: MysqlResponsePacket::Unknown, + }, + )), + } +} + +pub fn parse_statistics_response(i: &[u8]) -> IResult<&[u8], MysqlResponse> { + let (rem, header) = parse_packet_header(i)?; + let payload = + unsafe { std::slice::from_raw_parts(header.payload.as_ptr(), header.payload.len()) }; + let (_, _) = take(header.pkt_len)(payload)?; + Ok(( + rem, + MysqlResponse { + item: MysqlResponsePacket::Statistics, + }, + )) +} + +pub fn parse_field_list_response(i: &[u8]) -> IResult<&[u8], MysqlResponse> { + let (rem, header) = parse_packet_header(i)?; + let payload = + unsafe { std::slice::from_raw_parts(header.payload.as_ptr(), header.payload.len()) }; + let (i, response_code) = be_u8(payload)?; + match response_code { + // ERR + 0xFF => { + let (_, response) = parse_response_err(i)?; + Ok((rem, response)) + } + 0x00 => Ok(( + rem, + MysqlResponse { + item: MysqlResponsePacket::FieldsList { columns: None }, + }, + )), + _ => { + let n_cols = response_code; + let (i, columns) = + many_m_n(n_cols as usize, n_cols as usize, parse_column_definition)(rem)?; + let (_, _) = parse_eof_packet(i)?; + Ok(( + rem, + MysqlResponse { + item: MysqlResponsePacket::FieldsList { + columns: Some(columns), + }, + }, + )) + } + } +} + +pub fn parse_stmt_prepare_response(i: &[u8], _client_flags: u32) -> IResult<&[u8], MysqlResponse> { + let (rem, header) = parse_packet_header(i)?; + let payload = + unsafe { std::slice::from_raw_parts(header.payload.as_ptr(), header.payload.len()) }; + let (i, response_code) = be_u8(payload)?; + match response_code { + 0x00 => { + let (i, statement_id) = le_u32(i)?; + let (i, num_columns) = le_u16(i)?; + let (i, num_params) = le_u16(i)?; + let (i, _filter) = be_u8(i)?; + //TODO: why? + // let (i, _warning_cnt) = cond(header.pkt_len > 12, take(2_u32))(i)?; + let (_, _warning_cnt) = take(2_u32)(i)?; + // should use remain + let (i, params) = cond( + num_params > 0, + many_till(parse_column_definition, parse_eof_packet), + )(rem) + .map(|(i, params)| { + if let Some(params) = params { + (i, Some(params.0)) + } else { + (i, None) + } + })?; + // should use remain + let (_, fields) = cond( + num_columns > 0, + many_till(parse_column_definition, parse_eof_packet), + )(i) + .map(|(i, fields)| { + if let Some(fields) = fields { + (i, Some(fields.0)) + } else { + (i, None) + } + })?; + + Ok(( + i, + MysqlResponse { + item: MysqlResponsePacket::StmtPrepare { + statement_id, + num_params, + params, + fields, + }, + }, + )) + } + // ERR + 0xFF => { + let (_, response) = parse_response_err(i)?; + Ok((rem, response)) + } + _ => Ok(( + rem, + MysqlResponse { + item: MysqlResponsePacket::Unknown, + }, + )), + } +} + +pub fn parse_stmt_execute_response(i: &[u8]) -> IResult<&[u8], MysqlResponse> { + let (rem, header) = parse_packet_header(i)?; + let payload = + unsafe { std::slice::from_raw_parts(header.payload.as_ptr(), header.payload.len()) }; + let (i, response_code) = be_u8(payload)?; + match response_code { + // OK + 0x00 => { + let (_, response) = parse_response_ok(i)?; + Ok((rem, response)) + } + // ERR + 0xFF => { + let (_, response) = parse_response_err(i)?; + Ok((rem, response)) + } + _ => parse_response_binary_resultset(rem, response_code as u64), + } +} + +pub fn parse_stmt_fetch_response(i: &[u8]) -> IResult<&[u8], MysqlResponse> { + let (rem, header) = parse_packet_header(i)?; + let payload = + unsafe { std::slice::from_raw_parts(header.payload.as_ptr(), header.payload.len()) }; + let (i, response_code) = be_u8(payload)?; + match response_code { + // OK + 0x00 => { + let (_, response) = parse_response_ok(i)?; + Ok((rem, response)) + } + // ERR + 0xFF => { + let (_, response) = parse_response_err(i)?; + Ok((rem, response)) + } + _ => parse_response_binary_resultset(rem, response_code as u64), + } +} + +pub fn parse_auth_request(i: &[u8]) -> IResult<&[u8], ()> { + let (rem, _header) = parse_packet_header(i)?; + Ok((rem, ())) +} + +pub fn parse_auth_responsev2(i: &[u8]) -> IResult<&[u8], MysqlResponse> { + let (rem, header) = parse_packet_header(i)?; + let payload = + unsafe { std::slice::from_raw_parts(header.payload.as_ptr(), header.payload.len()) }; + let (i, response_code) = be_u8(payload)?; + match response_code { + // OK + 0x00 => { + let (_, response) = parse_response_ok(i)?; + Ok((rem, response)) + } + // auth data + _ => Ok(( + rem, + MysqlResponse { + item: MysqlResponsePacket::AuthData, + }, + )), + } +} + +#[cfg(test)] +mod test { + + use super::*; + + #[test] + fn test_parse_handshake_request() { + let pkt: &[u8] = &[ + 0x49, 0x00, 0x00, 0x00, 0x0a, 0x38, 0x2e, 0x34, 0x2e, 0x30, 0x00, 0x51, 0x00, 0x00, + 0x00, 0x3e, 0x7d, 0x6a, 0x6a, 0x1a, 0x2d, 0x2b, 0x6b, 0x00, 0xff, 0xff, 0xff, 0x02, + 0x00, 0xff, 0xdf, 0x15, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, + 0x61, 0x74, 0x54, 0x07, 0x62, 0x28, 0x5d, 0x21, 0x06, 0x44, 0x06, 0x62, 0x00, 0x63, + 0x61, 0x63, 0x68, 0x69, 0x6e, 0x67, 0x5f, 0x73, 0x68, 0x61, 0x32, 0x5f, 0x70, 0x61, + 0x73, 0x73, 0x77, 0x6f, 0x72, 0x64, 0x00, + ]; + let (rem, handshake_request) = parse_handshake_request(pkt).unwrap(); + + assert!(rem.is_empty()); + assert_eq!(handshake_request.protocol, 10); + assert_eq!(handshake_request.version, "8.4.0"); + assert_eq!(handshake_request.conn_id, 81); + assert_eq!(handshake_request.capability_flags1, 0xffff); + assert_eq!(handshake_request.status_flags, 0x0002); + assert_eq!(handshake_request.capability_flags2, 0xdfff); + assert_eq!(handshake_request.auth_plugin_len, 21); + assert_eq!( + handshake_request.auth_plugin_data, + Some("caching_sha2_password".to_string()), + ); + let pkt: &[u8] = &[ + 0x49, 0x00, 0x00, 0x00, 0x0a, 0x39, 0x2e, 0x30, 0x2e, 0x31, 0x00, 0x08, 0x00, 0x00, + 0x00, 0x5e, 0x09, 0x7c, 0x41, 0x76, 0x5d, 0x66, 0x17, 0x00, 0xff, 0xff, 0xff, 0x02, + 0x00, 0xff, 0xdf, 0x15, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, + 0x47, 0x4c, 0x7a, 0x03, 0x13, 0x35, 0x71, 0x0a, 0x4e, 0x2f, 0x45, 0x34, 0x00, 0x63, + 0x61, 0x63, 0x68, 0x69, 0x6e, 0x67, 0x5f, 0x73, 0x68, 0x61, 0x32, 0x5f, 0x70, 0x61, + 0x73, 0x73, 0x77, 0x6f, 0x72, 0x64, 0x00, + ]; + let (rem, _) = parse_handshake_request(pkt).unwrap(); + + assert!(rem.is_empty()); + } + + #[test] + fn test_parse_handshake_response() { + let pkt: &[u8] = &[ + 0xc6, 0x00, 0x00, 0x01, 0x8d, 0xa2, 0x1a, 0x00, 0x00, 0x00, 0x00, 0x00, 0x2d, 0x00, + 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, + 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x72, 0x6f, 0x6f, 0x74, 0x00, 0x20, + 0xbd, 0xb9, 0xfd, 0xe3, 0x22, 0xce, 0x86, 0x7d, 0x6c, 0x1d, 0x0e, 0xad, 0x22, 0x92, + 0xde, 0x56, 0xe5, 0xf2, 0x3d, 0xf8, 0xe0, 0x1f, 0x6f, 0x59, 0x5e, 0x62, 0xa6, 0x6b, + 0x7e, 0x54, 0x61, 0xfc, 0x73, 0x65, 0x6e, 0x74, 0x69, 0x6e, 0x65, 0x6c, 0x2d, 0x66, + 0x6c, 0x6f, 0x77, 0x00, 0x63, 0x61, 0x63, 0x68, 0x69, 0x6e, 0x67, 0x5f, 0x73, 0x68, + 0x61, 0x32, 0x5f, 0x70, 0x61, 0x73, 0x73, 0x77, 0x6f, 0x72, 0x64, 0x00, 0x5b, 0x0c, + 0x5f, 0x63, 0x6c, 0x69, 0x65, 0x6e, 0x74, 0x5f, 0x6e, 0x61, 0x6d, 0x65, 0x0f, 0x47, + 0x6f, 0x2d, 0x4d, 0x79, 0x53, 0x51, 0x4c, 0x2d, 0x44, 0x72, 0x69, 0x76, 0x65, 0x72, + 0x03, 0x5f, 0x6f, 0x73, 0x05, 0x6c, 0x69, 0x6e, 0x75, 0x78, 0x09, 0x5f, 0x70, 0x6c, + 0x61, 0x74, 0x66, 0x6f, 0x72, 0x6d, 0x05, 0x61, 0x6d, 0x64, 0x36, 0x34, 0x04, 0x5f, + 0x70, 0x69, 0x64, 0x06, 0x34, 0x35, 0x30, 0x39, 0x37, 0x36, 0x0c, 0x5f, 0x73, 0x65, + 0x72, 0x76, 0x65, 0x72, 0x5f, 0x68, 0x6f, 0x73, 0x74, 0x0a, 0x31, 0x37, 0x32, 0x2e, + 0x31, 0x37, 0x2e, 0x30, 0x2e, 0x32, + ]; + + let (rem, handshake_response) = parse_handshake_response(pkt).unwrap(); + + assert!(rem.is_empty()); + assert_eq!(handshake_response.username, "root"); + assert_eq!( + handshake_response.database, + Some("sentinel-flow".to_string()) + ); + assert_eq!( + handshake_response.client_plugin_name, + Some("caching_sha2_password".to_string()) + ); + let pkt: &[u8] = &[ + 0x5c, 0x00, 0x00, 0x01, 0x85, 0xa2, 0x0a, 0x00, 0x00, 0x00, 0x00, 0x00, 0x2d, 0x00, + 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, + 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x72, 0x6f, 0x6f, 0x74, 0x00, 0x20, + 0x9f, 0xbd, 0x98, 0xd7, 0x8f, 0x7b, 0x74, 0xfe, 0x9e, 0x4e, 0x99, 0x64, 0xc0, 0xd0, + 0x6a, 0x1d, 0x56, 0xbf, 0x36, 0xb1, 0xcd, 0x10, 0x6d, 0x3a, 0x37, 0xaf, 0x25, 0x22, + 0x06, 0xb6, 0xe5, 0x13, 0x63, 0x61, 0x63, 0x68, 0x69, 0x6e, 0x67, 0x5f, 0x73, 0x68, + 0x61, 0x32, 0x5f, 0x70, 0x61, 0x73, 0x73, 0x77, 0x6f, 0x72, 0x64, 0x00, + ]; + let (rem, _) = parse_handshake_response(pkt).unwrap(); + + assert!(rem.is_empty()); + } + + #[test] + fn test_parse_query_request() { + let pkt: &[u8] = &[ + 0x12, 0x00, 0x00, 0x00, 0x03, 0x53, 0x45, 0x54, 0x20, 0x4e, 0x41, 0x4d, 0x45, 0x53, + 0x20, 0x75, 0x74, 0x66, 0x38, 0x6d, 0x62, 0x34, + ]; + + let (rem, request) = parse_request(pkt, None, None, None, 0).unwrap(); + + assert!(rem.is_empty()); + assert_eq!(request.command_code, 0x03); + + let command = request.command; + if let MysqlCommand::Query { query } = command { + assert_eq!(query, "SET NAMES utf8mb4".to_string()); + } else { + unreachable!(); + } + } + + #[test] + fn test_parse_ok_response() { + let pkt: &[u8] = &[ + 0x07, 0x00, 0x00, 0x01, 0x00, 0x00, 0x00, 0x02, 0x00, 0x00, 0x00, + ]; + + let (rem, response) = parse_response(pkt).unwrap(); + assert!(rem.is_empty()); + + let item = response.item; + if let MysqlResponsePacket::Ok { + rows, + flags, + warnings, + } = item + { + assert_eq!(rows, 0); + assert_eq!(flags, 0x0002); + assert_eq!(warnings, 0); + } else { + unreachable!(); + } + } + + #[test] + fn test_parse_text_resultset_response() { + let pkt: &[u8] = &[ + 0x01, 0x00, 0x00, 0x01, 0x01, // Column count + 0x1f, 0x00, 0x00, 0x02, 0x03, 0x64, 0x65, 0x66, 0x00, 0x00, 0x00, 0x09, 0x56, 0x45, + 0x52, 0x53, 0x49, 0x4f, 0x4e, 0x28, 0x29, 0x00, 0x0c, 0xff, 0x00, 0x14, 0x00, 0x00, + 0x00, 0xfd, 0x01, 0x00, 0x1f, 0x00, 0x00, // Field packet + 0x05, 0x00, 0x00, 0x03, 0xfe, 0x00, 0x00, 0x02, 0x00, // EOF + 0x06, 0x00, 0x00, 0x04, 0x05, 0x39, 0x2e, 0x30, 0x2e, 0x31, // Row packet + 0x05, 0x00, 0x00, 0x05, 0xfe, 0x00, 0x00, 0x02, 0x00, // EOF + ]; + + let (rem, response) = parse_response(pkt).unwrap(); + assert!(rem.is_empty()); + + let item = response.item; + if let MysqlResponsePacket::ResultSet { + n_cols, + columns: _, + eof: _, + rows: _, + } = item + { + assert_eq!(n_cols, 1); + } + } + + #[test] + fn test_parse_quit_request() { + let pkt: &[u8] = &[0x01, 0x00, 0x00, 0x00, 0x01]; + + let (rem, request) = parse_request(pkt, None, None, None, 0).unwrap(); + + assert!(rem.is_empty()); + assert_eq!(request.command_code, 0x01); + + let command = request.command; + if let MysqlCommand::Quit = command { + } else { + unreachable!(); + } + } + + #[test] + fn test_parse_prepare_stmt_request() { + let pkt: &[u8] = &[ + 0x2b, 0x00, 0x00, 0x00, 0x16, 0x73, 0x65, 0x6c, 0x65, 0x63, 0x74, 0x20, 0x2a, 0x20, + 0x66, 0x72, 0x6f, 0x6d, 0x20, 0x72, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x73, 0x20, + 0x57, 0x48, 0x45, 0x52, 0x45, 0x20, 0x69, 0x64, 0x20, 0x3d, 0x3f, 0x20, 0x6c, 0x69, + 0x6d, 0x69, 0x74, 0x20, 0x31, + ]; + + let (rem, request) = parse_request(pkt, None, None, None, 0).unwrap(); + + assert!(rem.is_empty()); + assert_eq!(request.command_code, 0x16); + + let command = request.command; + + if let MysqlCommand::StmtPrepare { query } = command { + assert_eq!(query, "select * from requests WHERE id =? limit 1"); + } else { + unreachable!(); + } + let pkt: &[u8] = &[ + 64, 0, 0, 0, 22, 83, 69, 76, 69, 67, 84, 32, 96, 114, 101, 115, 111, 117, 114, 99, 101, + 96, 32, 70, 82, 79, 77, 32, 96, 115, 121, 115, 95, 97, 117, 116, 104, 111, 114, 105, + 116, 105, 101, 115, 96, 32, 87, 72, 69, 82, 69, 32, 97, 117, 116, 104, 111, 114, 105, + 116, 121, 95, 105, 100, 32, 61, 32, 63, + ]; + let (rem, request) = parse_request(pkt, None, None, None, 0).unwrap(); + + assert!(rem.is_empty()); + assert_eq!(request.command_code, 0x16); + } + + #[test] + fn test_parse_close_stmt_request() { + let pkt: &[u8] = &[0x05, 0x00, 0x00, 0x00, 0x19, 0x01, 0x00, 0x00, 0x00]; + + let (rem, request) = parse_request(pkt, Some(1), None, None, 0).unwrap(); + + assert!(rem.is_empty()); + assert_eq!(request.command_code, 0x19); + + let command = request.command; + + if let MysqlCommand::StmtClose { statement_id } = command { + assert_eq!(statement_id, 1); + } else { + unreachable!(); + } + } + + #[test] + fn test_parse_ping_request() { + let pkt: &[u8] = &[0x01, 0x00, 0x00, 0x00, 0x0e]; + let (rem, request) = parse_request(pkt, None, None, None, 0).unwrap(); + + assert!(rem.is_empty()); + assert_eq!(request.command, MysqlCommand::Ping); + } +} diff --git a/src/app-layer-parser.c b/src/app-layer-parser.c index 045fcb086e1f..ebc207f2d53e 100644 --- a/src/app-layer-parser.c +++ b/src/app-layer-parser.c @@ -1739,6 +1739,7 @@ void AppLayerParserRegisterProtocolParsers(void) RegisterHTTP2Parsers(); rs_telnet_register_parser(); RegisterIMAPParsers(); + rs_mysql_register_parser(); /** POP3 */ AppLayerProtoDetectRegisterProtocol(ALPROTO_POP3, "pop3"); diff --git a/src/app-layer-protos.c b/src/app-layer-protos.c index 03736554c7b6..10d03de1f53f 100644 --- a/src/app-layer-protos.c +++ b/src/app-layer-protos.c @@ -69,6 +69,7 @@ const AppProtoStringTuple AppProtoStrings[ALPROTO_MAX] = { { ALPROTO_BITTORRENT_DHT, "bittorrent-dht" }, { ALPROTO_POP3, "pop3" }, { ALPROTO_HTTP, "http" }, + { ALPROTO_MYSQL, "mysql" }, { ALPROTO_FAILED, "failed" }, }; diff --git a/src/app-layer-protos.h b/src/app-layer-protos.h index 10b8959772c4..f538264db901 100644 --- a/src/app-layer-protos.h +++ b/src/app-layer-protos.h @@ -68,6 +68,7 @@ enum AppProtoEnum { // signature-only (ie not seen in flow) // HTTP for any version (ALPROTO_HTTP1 (version 1) or ALPROTO_HTTP2) ALPROTO_HTTP, + ALPROTO_MYSQL, /* used by the probing parser when alproto detection fails * permanently for that particular stream */ diff --git a/src/detect-engine-register.c b/src/detect-engine-register.c index 9bddf0fd8437..7078bce3a786 100644 --- a/src/detect-engine-register.c +++ b/src/detect-engine-register.c @@ -707,6 +707,7 @@ void SigTableSetup(void) ScDetectRfbRegister(); ScDetectSipRegister(); ScDetectTemplateRegister(); + ScDetectMysqlRegister(); /* close keyword registration */ DetectBufferTypeCloseRegistration(); diff --git a/src/detect-engine-register.h b/src/detect-engine-register.h index b7a029998555..827d3727842c 100644 --- a/src/detect-engine-register.h +++ b/src/detect-engine-register.h @@ -331,6 +331,9 @@ enum DetectKeywordId { DETECT_AL_JA4_HASH, + DETECT_AL_MYSQL_COMMAND, + DETECT_AL_MYSQL_ROWS, + /* make sure this stays last */ DETECT_TBLSIZE_STATIC, }; diff --git a/src/output.c b/src/output.c index b99897509c0f..4c83ad443817 100644 --- a/src/output.c +++ b/src/output.c @@ -901,6 +901,7 @@ void OutputRegisterRootLoggers(void) // underscore instead of dash for bittorrent_dht RegisterSimpleJsonApplayerLogger( ALPROTO_BITTORRENT_DHT, rs_bittorrent_dht_logger_log, "bittorrent_dht"); + RegisterSimpleJsonApplayerLogger(ALPROTO_MYSQL, SCMysqlLogger, "mysql"); OutputPacketLoggerRegister(); OutputFiledataLoggerRegister(); @@ -1081,6 +1082,10 @@ void OutputRegisterLoggers(void) OutputRegisterTxSubModule(LOGGER_JSON_TX, "eve-log", "JsonLdapLog", "eve-log.ldap", OutputJsonLogInitSub, ALPROTO_LDAP, JsonGenericDirFlowLogger, JsonLogThreadInit, JsonLogThreadDeinit); + /* MySQL JSON logger. */ + OutputRegisterTxSubModule(LOGGER_JSON_TX, "eve-log", "JsonMySQLLog", "eve-log.mysql", + OutputJsonLogInitSub, ALPROTO_MYSQL, JsonGenericDirFlowLogger, JsonLogThreadInit, + JsonLogThreadDeinit); /* DoH2 JSON logger. */ JsonDoh2LogRegister(); /* Template JSON logger. */ diff --git a/suricata.yaml.in b/suricata.yaml.in index 672429e403b8..3668d0cb6abd 100644 --- a/suricata.yaml.in +++ b/suricata.yaml.in @@ -324,6 +324,7 @@ outputs: - sip - quic - ldap + - mysql - arp: enabled: no # Many events can be logged. Disabled by default - dhcp: @@ -928,6 +929,12 @@ app-layer: stream-depth: 0 # Maximum number of live PostgreSQL transactions per flow # max-tx: 1024 + mysql: + enabled: no + # Stream reassembly size for MySQL. By default, track it completely. + stream-depth: 0 + # Maximum number of live MySQL transactions per flow + # max-tx: 1024 dcerpc: enabled: yes # Maximum number of live DCERPC transactions per flow