diff --git a/sql/engines/odps.py b/sql/engines/odps.py
index 4c0e0c6de0..b1c20d7138 100644
--- a/sql/engines/odps.py
+++ b/sql/engines/odps.py
@@ -2,9 +2,10 @@
 
 import re
 import logging
+import sqlparse
 
 from . import EngineBase
-from .models import ResultSet, ReviewSet, ReviewResult
+from .models import ResultSet
 from odps import ODPS
 
 
@@ -37,16 +38,29 @@ def info(self):
 
     def get_all_databases(self):
-        """获取数据库列表, 返回一个ResultSet
-        ODPS只有project概念, 直接返回project名称
+        """Return the list of databases as a ResultSet.
+
+        ODPS only has the concept of a project, so the project of the
+        current connection is returned as the single entry.
+
+        TODO: the ODPS API for listing all projects is slow, so only one
+        project is returned for now; optimize later.
+
+        If the configured project does not exist, or any other exception is
+        raised, the error text is stored in ``result.error`` instead.
         """
         result = ResultSet()
 
         try:
-            conn = self.get_connection(self.get_connection())
+            conn = self.get_connection()
+
+            # Fail fast when the configured project does not exist.
+            if not conn.exist_project(self.instance.db_name):
+                raise ValueError(f"[{self.instance.db_name}]项目不存在")
+
             result.rows = [conn.project]
         except Exception as e:
             logger.warning(f"ODPS执行异常, {e}")
-            result.rows = [self.instance.db_name]
+            result.error = str(e)
         return result
 
     def get_all_tables(self, db_name, **kwargs):
@@ -126,3 +140,34 @@ def query(self, db_name=None, sql='', limit_num=0, close_conn=True, **kwargs):
             result_set.error = str(e)
         return result_set
 
+    def query_check(self, db_name=None, sql=''):
+        """Check a query statement before it is executed.
+
+        Comments are stripped, the input is split into statements and only
+        the first non-empty statement is kept. That statement must start
+        with a whitelisted keyword (currently only ``select``).
+
+        :param db_name: not used by this check.
+        :param sql: raw SQL text to validate.
+        :return: dict with the keys ``msg``, ``bad_query``, ``filtered_sql``
+            and ``has_star``; ``has_star`` is never set by this engine.
+        """
+        result = {'msg': '', 'bad_query': False, 'filtered_sql': sql, 'has_star': False}
+        sql_whitelist = ['select']
+        # Anchor at the start and require a word boundary after the keyword,
+        # so that e.g. "selectfoo" is not accepted as a SELECT statement.
+        whitelist_pattern = re.compile(
+            r"^(?:" + "|".join(re.escape(word) for word in sql_whitelist) + r")\b",
+            re.IGNORECASE)
+        # Strip comments, then keep the first statement that is not blank.
+        stripped_sql = sqlparse.format(sql, strip_comments=True)
+        statements = [stmt.strip() for stmt in sqlparse.split(stripped_sql) if stmt.strip()]
+        if not statements:
+            result['bad_query'] = True
+            result['msg'] = '没有有效的SQL语句'
+            return result
+        result['filtered_sql'] = statements[0]
+        if whitelist_pattern.match(statements[0]) is None:
+            result['bad_query'] = True
+            result['msg'] = '仅支持{}语法!'.format(','.join(sql_whitelist))
+        return result
diff --git 
a/sql/engines/tests.py b/sql/engines/tests.py
index 7678431105..4806c7d4d1 100644
--- a/sql/engines/tests.py
+++ b/sql/engines/tests.py
@@ -18,6 +18,7 @@
 from sql.engines.oracle import OracleEngine
 from sql.engines.mongo import MongoEngine
 from sql.engines.clickhouse import ClickHouseEngine
+from sql.engines.odps import ODPSEngine
 from sql.models import Instance, SqlWorkflow, SqlWorkflowContent
 
 User = get_user_model()
@@ -1882,3 +1883,110 @@ def test_execute_workflow_success(self, _conn, _cursor, _execute):
         execute_result = new_engine.execute_workflow(workflow=wf)
         self.assertIsInstance(execute_result, ReviewSet)
         self.assertEqual(execute_result.rows[0].__dict__.keys(), row.__dict__.keys())
+
+
+class ODPSTest(TestCase):
+    """Tests for ODPSEngine, run against a throwaway ``odps`` Instance row."""
+
+    def setUp(self) -> None:
+        self.ins = Instance.objects.create(
+            instance_name='some_ins',
+            type='slave',
+            db_type='odps',
+            host='some_host',
+            port=9200,
+            user='ins_user',
+            db_name='some_db',
+        )
+        self.engine = ODPSEngine(instance=self.ins)
+
+    def tearDown(self) -> None:
+        self.ins.delete()
+
+    @patch('sql.engines.odps.ODPSEngine.get_connection')
+    def test_get_connection(self, mock_odps):
+        """get_connection itself is patched, so this only checks it is invoked once."""
+        self.engine.get_connection()
+        mock_odps.assert_called_once()
+
+    @patch('sql.engines.odps.ODPSEngine.get_connection')
+    def test_query(self, mock_get_connection):
+        """query() hands back a ResultSet when the connection is mocked."""
+        query_result = self.engine.query('some_db', """select 123""")
+        self.assertIsInstance(query_result, ResultSet)
+
+    def test_query_check(self):
+        """Only the first statement is kept, with the comment removed."""
+        test_sql = """select 123; -- this is comment
+        select 456;"""
+
+        check_result = self.engine.query_check(sql=test_sql)
+
+        self.assertIsInstance(check_result, dict)
+        self.assertEqual(False, check_result.get("bad_query"))
+        self.assertEqual("select 123;", check_result.get("filtered_sql"))
+
+    def test_query_check_error(self):
+        """A non-select statement is flagged as a bad query."""
+        check_result = self.engine.query_check(sql="""drop table table_a""")
+
+        self.assertIsInstance(check_result, dict)
+        self.assertEqual(True, check_result.get("bad_query"))
+
+    @patch('sql.engines.odps.ODPSEngine.get_connection')
+    def test_get_all_databases(self, mock_get_connection):
+        """The mocked connection's project is reported as the only database."""
+        fake_conn = Mock()
+        fake_conn.exist_project.return_value = True
+        fake_conn.project = 'some_db'
+        mock_get_connection.return_value = fake_conn
+
+        result = self.engine.get_all_databases()
+
+        self.assertIsInstance(result, ResultSet)
+        self.assertEqual(result.rows, ['some_db'])
+
+    @patch('sql.engines.odps.ODPSEngine.get_connection')
+    def test_get_all_tables(self, mock_get_connection):
+        """Rows are the ``name`` of each object returned by list_tables()."""
+
+        class FakeTable:
+            # Sample shape of a listed table: only ``name`` is set.
+            def __init__(self, name):
+                self.name = name
+
+        fake_conn = Mock()
+        fake_conn.list_tables.return_value = [FakeTable(name) for name in ('u', 'v', 'w')]
+        mock_get_connection.return_value = fake_conn
+
+        table_list = self.engine.get_all_tables('some_db')
+
+        self.assertEqual(table_list.rows, ['u', 'v', 'w'])
+
+    @patch('sql.engines.odps.ODPSEngine.get_all_columns_by_tb')
+    def test_describe_table(self, mock_get_all_columns_by_tb):
+        """describe_table() calls get_all_columns_by_tb() exactly once."""
+        self.engine.describe_table('some_db', 'some_table')
+        mock_get_all_columns_by_tb.assert_called_once()
+
+    @patch('sql.engines.odps.ODPSEngine.get_connection')
+    def test_get_all_columns_by_tb(self, mock_get_connection):
+        """Each row is the name, type and comment of a mocked schema column."""
+        fake_column = Mock()
+        fake_column.name = 'XiaoMing'
+        fake_column.type = 'string'
+        fake_column.comment = 'name'
+
+        fake_table = Mock()
+        fake_table.schema.columns = [fake_column]
+
+        fake_conn = Mock()
+        fake_conn.get_table.return_value = fake_table
+        mock_get_connection.return_value = fake_conn
+
+        result = self.engine.get_all_columns_by_tb('some_db', 'some_table')
+
+        mock_get_connection.assert_called_once()
+        fake_conn.get_table.assert_called_once()
+        self.assertEqual(result.rows, [['XiaoMing', 'string', 'name']])
+        self.assertEqual(result.column_list, ['COLUMN_NAME', 'COLUMN_TYPE', 'COLUMN_COMMENT'])