diff --git a/src/executor/query.rs b/src/executor/query.rs index 01c3a1d0eb..712e33f580 100644 --- a/src/executor/query.rs +++ b/src/executor/query.rs @@ -115,6 +115,41 @@ impl QueryResult { { Ok(T::try_get_many_by_index(self)?) } + + /// Retrieves the names of the columns in the result set + pub fn column_names(&self) -> Vec { + #[cfg(feature = "sqlx-dep")] + use sqlx::Column; + + match &self.row { + #[cfg(feature = "sqlx-mysql")] + QueryResultRow::SqlxMySql(row) => { + row.columns().iter().map(|c| c.name().to_string()).collect() + } + #[cfg(feature = "sqlx-postgres")] + QueryResultRow::SqlxPostgres(row) => { + row.columns().iter().map(|c| c.name().to_string()).collect() + } + #[cfg(feature = "sqlx-sqlite")] + QueryResultRow::SqlxSqlite(row) => { + row.columns().iter().map(|c| c.name().to_string()).collect() + } + #[cfg(feature = "mock")] + QueryResultRow::Mock(row) => row + .clone() + .into_column_value_tuples() + .map(|(c, _)| c.to_string()) + .collect(), + #[cfg(feature = "proxy")] + QueryResultRow::Proxy(row) => row + .clone() + .into_column_value_tuples() + .map(|(c, _)| c.to_string()) + .collect(), + #[allow(unreachable_patterns)] + _ => unreachable!(), + } + } } #[allow(unused_variables)] @@ -1258,7 +1293,11 @@ try_from_u64_err!(uuid::Uuid); #[cfg(test)] mod tests { - use super::TryGetError; + use std::collections::BTreeMap; + + use sea_query::Value; + + use super::*; use crate::error::*; #[test] @@ -1347,4 +1386,21 @@ mod tests { ) ); } + + #[test] + fn column_names_from_query_result() { + let mut values = BTreeMap::new(); + values.insert("id".to_string(), Value::Int(Some(1))); + values.insert( + "name".to_string(), + Value::String(Some(Box::new("Abc".to_owned()))), + ); + let query_result = QueryResult { + row: QueryResultRow::Mock(crate::MockRow { values }), + }; + assert_eq!( + query_result.column_names(), + vec!["id".to_owned(), "name".to_owned()] + ); + } }