痛点:数据库测试,全靠手敲 SQL?
在测试工作中,数据库相关的操作非常常见:
- 测试前准备数据
- 接口测试依赖
- Bug 复现
- 回归验证:修改了一个存储过程,要验证它对各种输入的处理是否正确
- 数据清理
大多数人的做法是:打开 Navicat / DBeaver,复制粘贴 SQL,一条一条执行,运气不好输错一个字符,又要重来。
我就想:能不能把这些都自动化?于是有了这个 SQL 批量执行与结果对比工具。
工具能做什么
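| 功能 | 说明 |
|---|---|
| 批量执行 | 指定单个 .sql 文件或整个目录,自动拆分并逐条执行 |
| 干跑模式 | `--dry-run` 只拆分、识别 SQL 类型,不实际执行 |
| 自动备份 | 执行 INSERT / UPDATE / DELETE 前,自动把目标表导出为 JSON |
| 结果对比 | 对比执行前后的数据快照,列出新增、删除、变更的行 |
| 数据脱敏 | 导出结果时自动遮盖手机号、身份证、银行卡、邮箱、密码等字段 |
| 多数据库 | MySQL / PostgreSQL / SQLite,覆盖主流测试场景 |
| 容错执行 | 单条 SQL 失败后可继续执行后续语句(`continue_on_error`) |
| 执行日志 | 生成 Markdown 执行日志和 JSON 结果文件,便于审计 |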
完整代码
① 数据库配置 db_config.yaml
# Database connection settings
databases:
  test_db:
    type: mysql
    host: "192.168.1.100"
    port: 3306
    user: "test_user"
    password: "test_password"
    database: "test_db"
    charset: "utf8mb4"
  prod_like:
    type: mysql
    host: "192.168.1.200"
    port: 3306
    # presumably a read-only account (see user name) -- verify before
    # pointing modifying scripts at this entry
    user: "prod_reader"
    password: "read_only_password"
    database: "production_db"
    charset: "utf8mb4"

# Database used when --db is not given on the command line
default_db: "test_db"

# SQL execution settings
execution:
  # Wrap execution in a transaction so it can be rolled back.
  # NOTE(review): sql_executor.py does not read this key; each successful
  # non-SELECT statement is committed immediately.
  auto_transaction: true
  # Keep executing the remaining statements after one fails
  continue_on_error: true
  # Seconds to pause between statements (handy for watching progress)
  delay_between: 0
  # Back up the affected table before each modifying statement
  backup_before: true
  # Number of days backup files are kept
  backup_keep_days: 7
② 主脚本 sql_executor.py
# -*- coding: utf-8 -*-
"""SQL batch executor and result-comparison tool.

Supports MySQL / PostgreSQL / SQLite. Intended for test-data preparation,
database regression checks and SQL execution auditing.

Dependencies:
    pip install pymysql psycopg2-binary pyyaml
    SQLite needs nothing extra (sqlite3 ships with Python).
"""
import os
import re
import sys
import json
import time
import shutil
import datetime
import argparse
import sqlite3
import tempfile
from pathlib import Path
from contextlib import contextmanager

try:
    import yaml
except ImportError:
    # PyYAML is only needed by main() to read db_config.yaml. Keeping the
    # import optional lets the helpers below be imported without it.
    yaml = None


# ─────────────────────────────────────────────
# Part 1: database connection management
# ─────────────────────────────────────────────
class DatabaseConnection:
    """Uniform wrapper around pymysql / psycopg2 / sqlite3 connections.

    ``config`` is one entry of the ``databases`` mapping in db_config.yaml
    (``type``, ``host``, ``port``, ``user``, ``password``, ``database``,
    ``charset``); ``type`` defaults to ``mysql``. Usable as a context
    manager: ``with DatabaseConnection(cfg) as db: ...`` connects on entry
    and closes on exit.
    """

    def __init__(self, config):
        self.config = config
        self.conn = None
        self.db_type = config.get('type', 'mysql').lower()

    def connect(self):
        """Open the connection and return ``self``.

        When the driver is missing or the connection fails, prints a hint
        and terminates the process with ``sys.exit(1)``.
        """
        try:
            if self.db_type == 'mysql':
                import pymysql
                self.conn = pymysql.connect(
                    host=self.config['host'],
                    port=self.config.get('port', 3306),
                    user=self.config['user'],
                    password=self.config['password'],
                    database=self.config.get('database', ''),
                    charset=self.config.get('charset', 'utf8mb4'),
                    cursorclass=pymysql.cursors.DictCursor)
            elif self.db_type == 'postgresql':
                import psycopg2
                # `import psycopg2` alone does not load the `extras`
                # submodule, so RealDictCursor must be imported explicitly.
                import psycopg2.extras
                self.conn = psycopg2.connect(
                    host=self.config['host'],
                    port=self.config.get('port', 5432),
                    user=self.config['user'],
                    password=self.config['password'],
                    dbname=self.config.get('database', ''),
                    cursor_factory=psycopg2.extras.RealDictCursor)
            elif self.db_type == 'sqlite':
                db_path = self.config.get('database', ':memory:')
                self.conn = sqlite3.connect(db_path)
                self.conn.row_factory = sqlite3.Row
            else:
                raise ValueError(f"不支持的数据库类型: {self.db_type}")
            # SQLite configs have no host; avoid printing "None/...".
            where = f"{self.config['host']}/" if self.config.get('host') else ''
            print(f" ✅ 已连接 {where}{self.config.get('database', 'SQLite')}")
            return self
        except ImportError as e:
            print(f" 缺少数据库驱动: {e}")
            print(" 请运行以下命令安装:")
            print(" MySQL: pip install pymysql")
            print(" PostgreSQL: pip install psycopg2-binary")
            sys.exit(1)
        except Exception as e:
            print(f" 连接失败: {e}")
            sys.exit(1)

    def close(self):
        """Close the connection if it is open (safe to call repeatedly)."""
        if self.conn:
            self.conn.close()
            self.conn = None

    def execute(self, sql, params=None, commit=False):
        """Execute one statement and describe the outcome as a dict.

        Args:
            sql: the statement text.
            params: optional driver parameters for placeholders in ``sql``.
            commit: commit after executing (SQLite is always committed).

        Returns:
            ``{'status': 'ok', 'rows': [dict, ...], 'affected': len(rows)}``
            when the statement produced a result set;
            ``{'status': 'ok', 'affected': rowcount, 'last_insert_id': id}``
            otherwise (``last_insert_id`` is only filled for MySQL/SQLite);
            ``{'status': 'error', 'message': str, 'sql': first 100 chars}``
            on failure, after rolling the transaction back.
        """
        cursor = self.conn.cursor()
        try:
            if params:
                cursor.execute(sql, params)
            else:
                cursor.execute(sql)
            if commit or self.db_type == 'sqlite':
                self.conn.commit()

            if cursor.description:
                rows = cursor.fetchall()
                if rows and isinstance(rows[0], dict):
                    rows = list(rows)
                else:
                    # sqlite3.Row / plain tuples -> dicts keyed by column name
                    columns = [d[0] for d in cursor.description]
                    rows = [dict(zip(columns, row)) for row in rows]
                return {'status': 'ok', 'rows': rows, 'affected': len(rows)}

            last_id = cursor.lastrowid if self.db_type in ('mysql', 'sqlite') else None
            return {'status': 'ok', 'affected': cursor.rowcount, 'last_insert_id': last_id}
        except Exception as e:
            try:
                self.conn.rollback()
            except Exception:
                # Best effort: a failed rollback (e.g. dropped connection)
                # must not hide the original error reported below.
                pass
            return {'status': 'error', 'message': str(e), 'sql': sql[:100]}
        finally:
            cursor.close()

    def backup_table(self, table_name, backup_dir, limit=10000):
        """Dump up to ``limit`` rows of ``table_name`` into a JSON file.

        The file is ``<table>_backup_<timestamp>.json`` in ``backup_dir``;
        the timestamp includes microseconds so repeated backups of the same
        table within one second do not overwrite the earliest snapshot.

        Returns:
            The backup file path, or ``None`` when the SELECT failed.
        """
        timestamp = datetime.datetime.now().strftime('%Y%m%d_%H%M%S_%f')
        # Quoted / schema-qualified names (`db`.`t`) are not valid filenames.
        safe_name = re.sub(r'[^\w.-]', '_', table_name)
        backup_file = os.path.join(backup_dir, f"{safe_name}_backup_{timestamp}.json")
        result = self.execute(f"SELECT * FROM {table_name} LIMIT {int(limit)}")
        if result['status'] != 'ok':
            print(f" 表 {table_name} 备份失败: {result['message']}")
            return None
        rows = result.get('rows', [])
        with open(backup_file, 'w', encoding='utf-8') as f:
            json.dump(rows, f, ensure_ascii=False, indent=2, default=str)
        print(f" 表 {table_name} 已备份至: {backup_file}")
        if len(rows) >= limit:
            print(f" ⚠️ 表 {table_name} 的备份已截断为前 {limit} 行")
        return backup_file

    def __enter__(self):
        return self.connect()

    def __exit__(self, exc_type, exc_val, exc_tb):
        self.close()


# ─────────────────────────────────────────────
# Part 2: SQL parsing and pre-processing
# ─────────────────────────────────────────────
def parse_sql_file(sql_content, backslash_escapes=True):
    """Split SQL text into individual statements.

    Statements are separated by ``;``. ``-- ...`` line comments and
    ``/* ... */`` block comments are removed and empty statements dropped.
    Semicolons and comment markers inside quoted text ('...', "..." or
    `...`) are left alone, so ``VALUES ('a;b')`` stays one statement.

    Args:
        sql_content: the raw SQL text.
        backslash_escapes: treat a backslash inside '...' / "..." as
            escaping the next character (MySQL's default string syntax).
            Pass False for standard-conforming strings (PostgreSQL, SQLite).

    Returns:
        A list of stripped statements without their trailing ``;``.

    Client-side ``DELIMITER`` directives (procedure bodies that contain
    ``;``) are not supported.
    """
    statements = []
    current = []
    quote = None  # the quote character we are currently inside, if any
    n = len(sql_content)
    i = 0

    def flush():
        stmt = ''.join(current).strip()
        if stmt:
            statements.append(stmt)
        del current[:]

    while i < n:
        ch = sql_content[i]
        nxt = sql_content[i + 1] if i + 1 < n else ''
        if quote is not None:
            current.append(ch)
            if ch == '\\' and backslash_escapes and quote != '`' and nxt:
                current.append(nxt)
                i += 2
                continue
            if ch == quote:
                # A doubled quote ('') closes and immediately reopens,
                # which keeps us inside the literal as intended.
                quote = None
            i += 1
        elif ch in ("'", '"', '`'):
            quote = ch
            current.append(ch)
            i += 1
        elif ch == '-' and nxt == '-':
            newline = sql_content.find('\n', i)
            i = n if newline == -1 else newline  # keep the newline itself
        elif ch == '/' and nxt == '*':
            end = sql_content.find('*/', i + 2)
            i = n if end == -1 else end + 2
            current.append(' ')  # keep the tokens around the comment apart
        elif ch == ';':
            flush()
            i += 1
        else:
            current.append(ch)
            i += 1
    flush()
    return statements


# Leading keywords detect_sql_type() reports by name; anything else is 'OTHER'.
_KNOWN_SQL_TYPES = ('SELECT', 'INSERT', 'UPDATE', 'DELETE', 'CREATE', 'DROP', 'ALTER')


def detect_sql_type(sql):
    """Classify a statement by its leading keyword.

    Returns one of 'SELECT', 'INSERT', 'UPDATE', 'DELETE', 'CREATE', 'DROP',
    'ALTER', or 'OTHER' (e.g. CALL, WITH, SHOW, TRUNCATE).
    """
    sql_upper = sql.strip().upper()
    for keyword in _KNOWN_SQL_TYPES:
        if sql_upper.startswith(keyword):
            return keyword
    return 'OTHER'


# A table identifier: bare word or `backtick` / "double" / [bracket] quoted.
_IDENT = r'(?:`[^`]+`|"[^"]+"|\[[^\]]+\]|[\w$]+)'
# Target table of a modifying statement, matched at the statement start
# (optionally schema-qualified, e.g. `db`.`users`).
_TARGET_TABLE_RE = re.compile(
    r'\s*(?:'
    r'INSERT(?:\s+\w+)*?\s+INTO'
    r'|REPLACE(?:\s+\w+)*?\s+INTO'
    r'|UPDATE(?:\s+(?:LOW_PRIORITY|IGNORE|ONLY|OR\s+\w+))*'
    r'|DELETE(?:\s+\w+)*?\s+FROM'
    r'|TRUNCATE(?:\s+TABLE)?'
    r'|DROP\s+TABLE(?:\s+IF\s+EXISTS)?'
    r'|ALTER\s+TABLE'
    r')\s+(' + _IDENT + r'(?:\s*\.\s*' + _IDENT + r')?)',
    re.IGNORECASE)

# Fallback table-name patterns for get_sql_summary(), tried in order.
_SUMMARY_TABLE_PATTERNS = (
    re.compile(r'\bTABLE\s+(?:IF\s+(?:NOT\s+)?EXISTS\s+)?([\w$.`"]+)', re.IGNORECASE),
    re.compile(r'\bFROM\s+([\w$.`"]+)', re.IGNORECASE),
    re.compile(r'\bINTO\s+([\w$.`"]+)', re.IGNORECASE),
    re.compile(r'\bUPDATE\s+([\w$.`"]+)', re.IGNORECASE),
)


def _extract_target_table(sql):
    """Return the table a modifying statement writes to, or None.

    Recognises INSERT/REPLACE ... INTO, UPDATE, DELETE ... FROM, TRUNCATE,
    DROP TABLE and ALTER TABLE at the start of the statement. The name is
    returned as written (quotes and schema prefix kept) so it can be reused
    verbatim in a SELECT.
    """
    match = _TARGET_TABLE_RE.match(sql)
    return match.group(1) if match else None


def get_sql_summary(sql, max_len=80):
    """Return a one-line description: ``[TYPE] table: <sql prefix>``.

    Whitespace in the SQL is collapsed; ``...`` is appended only when the
    text was cut at ``max_len`` characters. ``table`` is ``unknown`` when no
    table name can be found.
    """
    sql_type = detect_sql_type(sql)
    # Prefer the written-to table so `INSERT INTO a SELECT ... FROM b`
    # reports `a`, not the source table `b`.
    table = _extract_target_table(sql)
    if table is None:
        table = 'unknown'
        for pattern in _SUMMARY_TABLE_PATTERNS:
            found = pattern.search(sql)
            if found:
                table = found.group(1)
                break
    flat = ' '.join(sql.split())
    short_sql = flat[:max_len] + ('...' if len(flat) > max_len else '')
    return f"[{sql_type}] {table}: {short_sql}"


# ─────────────────────────────────────────────
# Part 3: core execution logic
# ─────────────────────────────────────────────
def execute_sql_file(db, sql_file, config, backup_dir, dry_run=False):
    """Execute every statement of one .sql file; return one result dict each.

    Args:
        db: a connected DatabaseConnection.
        sql_file: path of a UTF-8 encoded .sql file.
        config: the ``execution`` section of db_config.yaml; reads
            ``backup_before`` (default True), ``continue_on_error``
            (default True) and ``delay_between`` (seconds, default 0).
        backup_dir: directory that receives table backups.
        dry_run: only parse and classify the statements; execute nothing.

    Each result has ``sql`` (first 100 chars), ``type`` and ``status``
    ('ok' / 'error' / 'dry_run'). Statements that return rows also carry
    ``rows`` (first 20) and ``total_rows``; the others carry ``affected``
    and, for INSERTs that report one, ``last_insert_id``. When
    ``continue_on_error`` is false, execution stops at the first failure,
    and that failure is still included in the returned list.
    """
    filename = os.path.basename(sql_file)
    print(f"\n 📄 文件: {filename}")
    with open(sql_file, 'r', encoding='utf-8') as f:
        statements = parse_sql_file(f.read())
    total = len(statements)
    print(f" 解析到 {total} 条 SQL 语句")

    backup_before = config.get('backup_before', True)
    continue_on_error = config.get('continue_on_error', True)
    delay = config.get('delay_between', 0)

    results = []
    for i, stmt in enumerate(statements, 1):
        sql_type = detect_sql_type(stmt)
        print(f"\n [{i}/{total}] {get_sql_summary(stmt)}")

        if dry_run:
            print(" 模拟执行(dry-run 模式,不实际执行)")
            results.append({'sql': stmt[:100], 'type': sql_type, 'status': 'dry_run',
                            'message': '模拟执行,未实际运行'})
            continue

        if sql_type == 'SELECT':
            result = db.execute(stmt)
        else:
            # Back up the table this statement modifies before running it.
            table_name = _extract_target_table(stmt)
            if table_name and backup_before:
                db.backup_table(table_name, backup_dir)
            result = db.execute(stmt, commit=True)

        if result['status'] != 'ok':
            print(f" 执行失败: {result['message']}")
            # Record the failure *before* deciding whether to stop;
            # otherwise a fatal error vanishes from the log and exit code.
            results.append({**result, 'type': sql_type})
            if not continue_on_error:
                print(" 终止执行(continue_on_error=False)")
                break
        elif 'rows' in result:
            # SELECT, or any other statement that returned a result set
            # (CALL, SHOW, WITH ...).
            rows = result['rows']
            print(f" {sql_type} 返回 {len(rows)} 行")
            results.append({'sql': stmt[:100], 'type': sql_type, 'status': 'ok',
                            'rows': rows[:20],  # keep at most 20 rows
                            'total_rows': len(rows)})
        else:
            affected = result.get('affected', 0)
            msg = f"✅ 影响 {affected} 行" if affected >= 0 else "✅ 执行成功"
            print(f" {msg}")
            entry = {'sql': stmt[:100], 'type': sql_type, 'status': 'ok',
                     'affected': affected}
            if sql_type == 'INSERT' and result.get('last_insert_id'):
                entry['last_insert_id'] = result['last_insert_id']
            results.append(entry)

        if delay > 0:
            time.sleep(delay)
    return results


def compare_results(before_file, after_file, table_name, key_column='id'):
    """Diff two JSON row snapshots keyed by ``key_column``.

    Used to verify how data changed across a SQL run.

    Returns:
        A dict with ``table``, ``before_count``, ``after_count``,
        ``added`` / ``deleted`` (lists of rows) and ``changed`` (list of
        ``{'key', 'before', 'after'}``). Keys are compared as strings; rows
        sharing a key value (or all lacking the column) collapse onto one
        entry, the last one winning.
    """
    with open(before_file, 'r', encoding='utf-8') as f:
        before_data = json.load(f)
    with open(after_file, 'r', encoding='utf-8') as f:
        after_data = json.load(f)

    before_map = {str(row.get(key_column, '')): row for row in before_data}
    after_map = {str(row.get(key_column, '')): row for row in after_data}

    added = [row for key, row in after_map.items() if key not in before_map]
    deleted = [row for key, row in before_map.items() if key not in after_map]
    changed = [
        {'key': key, 'before': row, 'after': after_map[key]}
        for key, row in before_map.items()
        if key in after_map and row != after_map[key]
    ]
    return {
        'table': table_name,
        'before_count': len(before_data),
        'after_count': len(after_data),
        'added': added,
        'deleted': deleted,
        'changed': changed,
    }


# ─────────────────────────────────────────────
# Part 4: masking of exported data
# ─────────────────────────────────────────────
def mask_sensitive_data(data, rules=None):
    """Mask sensitive values in query results.

    Args:
        data: a list/tuple of row dicts (returned as a list) or a single row
            dict; any other value is returned unchanged.
        rules: maps a keyword to a rule name ('phone_mob', 'id_card',
            'bank_card', 'email', 'password'). A column is checked against
            every keyword contained in its lower-cased name, in order; the
            first rule that applies masks the value, otherwise it is kept.
    """
    if rules is None:
        rules = {
            'phone': 'phone_mob',
            'mobile': 'phone_mob',
            'tel': 'phone_mob',
            'id_card': 'id_card',
            'id_number': 'id_card',
            'bank_card': 'bank_card',
            'password': 'password',
            'email': 'email',
        }

    def mask_value(field_name, value):
        if value is None:
            return None
        field_lower = str(field_name).lower()
        for keyword, rule in rules.items():
            if keyword not in field_lower:
                continue
            if rule == 'phone_mob' and value:
                s = str(value)
                return s[:3] + '****' + s[-4:] if len(s) >= 11 else '***'
            elif rule == 'id_card' and value:
                s = str(value)
                return s[:6] + '********' + s[-4:] if len(s) >= 14 else '****'
            elif rule == 'bank_card' and value:
                s = str(value)
                return s[:4] + '****' + s[-4:] if len(s) >= 8 else '****'
            elif rule == 'email' and value:
                if '@' in str(value):
                    parts = str(value).split('@')
                    return parts[0][:2] + '***@' + parts[1]
            elif rule == 'password':
                return '******'
        return value

    if isinstance(data, (list, tuple)):
        return [
            {k: mask_value(k, v) for k, v in row.items()} if isinstance(row, dict) else row
            for row in data
        ]
    if isinstance(data, dict):
        return {k: mask_value(k, v) for k, v in data.items()}
    return data


# ─────────────────────────────────────────────
# Part 5: execution log and report
# ─────────────────────────────────────────────
def generate_execution_log(results, output_file):
    """Write a Markdown execution log for ``results`` to ``output_file``.

    Returns:
        ``(ok_count, error_count)``.
    """
    ok = sum(1 for r in results if r['status'] == 'ok')
    errors = [r for r in results if r['status'] == 'error']
    status_icons = {'ok': '✅', 'error': '❌', 'dry_run': '⏭️'}

    lines = [
        "# SQL 执行日志",
        f"**时间:** {datetime.datetime.now().strftime('%Y-%m-%d %H:%M:%S')}",
        f"**总数:** {len(results)} 条 | 成功 {ok} 条 | 失败 {len(errors)} 条",
        '',
        '---',
        '',
    ]
    for i, r in enumerate(results, 1):
        status_icon = status_icons.get(r['status'], '')
        lines.append(f"### {i}. {status_icon} [{r.get('type', '?')}]")
        if r['status'] == 'ok':
            if r.get('affected') is not None:
                lines.append(f"- 影响行数: {r['affected']}")
            if r.get('rows') is not None:
                lines.append(f"- 返回行数: {r.get('total_rows', len(r['rows']))}")
            if r.get('last_insert_id') is not None:
                lines.append(f"- 插入ID: {r['last_insert_id']}")
        elif r['status'] == 'error':
            lines.append(f"- 错误: {r.get('message', '未知错误')}")
        elif r.get('message'):
            # e.g. dry-run entries: informational, not an error
            lines.append(f"- 说明: {r['message']}")
        if r.get('sql'):
            flat = ' '.join(r['sql'].split())
            snippet = flat[:80] + ('...' if len(flat) > 80 else '')
            # Space-padded double backticks so MySQL `identifiers` inside
            # the SQL do not terminate the inline code span.
            lines.append(f"- SQL: `` {snippet} ``")
        lines.append('')

    with open(output_file, 'w', encoding='utf-8') as f:
        f.write('\n'.join(lines))
    print(f"\n 执行日志: {output_file}")
    return ok, len(errors)


def _cleanup_old_backups(backup_dir, keep_days):
    """Delete ``*_backup_*.json`` files in ``backup_dir`` older than ``keep_days``.

    Implements ``execution.backup_keep_days``; does nothing when the value
    is missing, non-numeric or <= 0. Returns the number of files removed.
    """
    try:
        keep_days = float(keep_days)
    except (TypeError, ValueError):
        return 0
    if keep_days <= 0:
        return 0
    cutoff = time.time() - keep_days * 86400
    removed = 0
    for path in Path(backup_dir).glob('*_backup_*.json'):
        try:
            if path.stat().st_mtime < cutoff:
                path.unlink()
                removed += 1
        except OSError as e:
            print(f" 清理备份失败 {path}: {e}")
    return removed


# ─────────────────────────────────────────────
# Entry point
# ─────────────────────────────────────────────
def main():
    """Command-line entry point; exits 1 when any statement failed."""
    parser = argparse.ArgumentParser(description='SQL 批量执行与结果对比工具')
    parser.add_argument('--config', '-c', default='db_config.yaml', help='数据库配置文件')
    parser.add_argument('--sql', '-s', required=True, help='SQL 文件或包含 SQL 文件的目录')
    parser.add_argument('--db', '-d', help='数据库名称(覆盖配置文件中的 default_db)')
    parser.add_argument('--dry-run', action='store_true', help='模拟执行,不实际运行')
    parser.add_argument('--output', '-o', default='output', help='输出目录')
    parser.add_argument('--backup', '-b', default='backups', help='备份目录')
    # `--mask` defaults to True, so on its own it could never be turned
    # off; `--no-mask` clears it while `--mask` keeps working.
    parser.add_argument('--mask', action='store_true', default=True, help='导出时自动脱敏')
    parser.add_argument('--no-mask', dest='mask', action='store_false', help='导出时不脱敏')
    args = parser.parse_args()

    # Load configuration
    if yaml is None:
        print(" 缺少依赖 PyYAML,请运行: pip install pyyaml")
        sys.exit(1)
    try:
        with open(args.config, 'r', encoding='utf-8') as f:
            full_config = yaml.safe_load(f) or {}
    except (OSError, yaml.YAMLError) as e:
        print(f" 无法读取配置文件 {args.config}: {e}")
        sys.exit(1)

    exec_config = full_config.get('execution') or {}
    db_name = args.db or full_config.get('default_db', 'test_db')
    db_configs = full_config.get('databases') or {}
    if db_name not in db_configs:
        print(f" 数据库 '{db_name}' 不在配置中!可用: {list(db_configs.keys())}")
        sys.exit(1)
    db_config = db_configs[db_name]

    os.makedirs(args.output, exist_ok=True)
    os.makedirs(args.backup, exist_ok=True)

    # Collect SQL files
    sql_path = os.path.abspath(args.sql)
    if os.path.isfile(sql_path):
        sql_files = [sql_path]
    elif os.path.isdir(sql_path):
        # glob() order depends on the filesystem; sort for reproducible runs.
        sql_files = sorted(str(p) for p in Path(sql_path).glob('*.sql'))
    else:
        print(f" 路径不存在: {sql_path}")
        sys.exit(1)

    if not sql_files:
        print(" 未找到 .sql 文件")
        sys.exit(0)

    print(f"\n{'=' * 60}")
    print("🔧 SQL 批量执行工具")
    print(f" 数据库: {db_name} ({db_config.get('type', 'mysql')})")
    print(f" SQL 文件: {len(sql_files)} 个")
    if args.dry_run:
        print(" 模式: DRY-RUN(不实际执行)")
    print(f"{'=' * 60}")

    # NOTE(review): `execution.auto_transaction` is not read; every
    # successful non-SELECT statement is committed immediately.
    stop_on_error = not exec_config.get('continue_on_error', True)
    all_results = []
    with DatabaseConnection(db_config) as db:
        for sql_file in sql_files:
            results = execute_sql_file(
                db, sql_file, exec_config, args.backup, dry_run=args.dry_run)
            all_results.extend(results)
            # continue_on_error=False must also stop the remaining files.
            if stop_on_error and any(r['status'] == 'error' for r in results):
                break

    timestamp = datetime.datetime.now().strftime('%Y%m%d_%H%M%S')
    log_file = os.path.join(args.output, f'execution_log_{timestamp}.md')
    ok_count, error_count = generate_execution_log(all_results, log_file)

    # Save JSON results, masking returned rows (without mutating all_results).
    json_file = os.path.join(args.output, f'execution_results_{timestamp}.json')
    if args.mask:
        processed_results = [
            {**r, 'rows': mask_sensitive_data(r['rows'])} if r.get('rows') else r
            for r in all_results
        ]
    else:
        processed_results = all_results
    with open(json_file, 'w', encoding='utf-8') as f:
        json.dump(processed_results, f, ensure_ascii=False, indent=2, default=str)

    if not args.dry_run:
        removed = _cleanup_old_backups(args.backup, exec_config.get('backup_keep_days'))
        if removed:
            print(f" 已清理 {removed} 个过期备份文件")

    print(f"\n{'=' * 60}")
    print(" 执行完成!")
    print(f" 日志: {log_file}")
    print(f" JSON: {json_file}")
    print(f" 备份: {args.backup}/")
    print(f" 成功: {ok_count} | 失败: {error_count}")
    print(f"{'=' * 60}")
    sys.exit(0 if error_count == 0 else 1)


if __name__ == '__main__':
    main()
使用示例
示例 1:批量准备测试数据
# Put the test-data SQL files in one directory, e.g.:
#   data_prepare/user_data.sql
#   data_prepare/product_data.sql
#   data_prepare/order_data.sql
python sql_executor.py --config db_config.yaml \
    --sql ./data_prepare/ \
    --db test_db \
    --output ./results
示例 2:干跑验证 SQL 正确性
# Dry run: statements are split and classified but not executed
# (the database connection is still opened).
python sql_executor.py --config db_config.yaml \
    --sql ./scripts/test.sql \
    --dry-run
示例 3:对比执行前后数据变化
# Export data snapshots before and after running the SQL, then diff them.
import json

import yaml

from sql_executor import DatabaseConnection, compare_results, mask_sensitive_data


def take_snapshot(db, sql, path):
    """Run a SELECT, mask sensitive columns and save the rows to ``path``."""
    result = db.execute(sql)
    if result['status'] != 'ok':
        raise RuntimeError(f"snapshot query failed: {result['message']}")
    rows = mask_sensitive_data(result['rows'])
    # Write UTF-8 explicitly: ensure_ascii=False keeps non-ASCII text as-is,
    # and compare_results() reads the files back as UTF-8.
    with open(path, 'w', encoding='utf-8') as f:
        json.dump(rows, f, ensure_ascii=False, default=str)


with open('db_config.yaml', encoding='utf-8') as f:
    config = yaml.safe_load(f)

query = "SELECT * FROM users WHERE id IN (1,2,3)"
with DatabaseConnection(config['databases']['test_db']) as db:
    # Snapshot before
    take_snapshot(db, query, 'before_snapshot.json')
    # Run your SQL
    db.execute("UPDATE users SET name='测试用户' WHERE id IN (1,2,3)", commit=True)
    # Snapshot after
    take_snapshot(db, query, 'after_snapshot.json')

diff = compare_results('before_snapshot.json', 'after_snapshot.json', 'users')
print(f"changed: {len(diff['changed'])}, added: {len(diff['added'])}, "
      f"deleted: {len(diff['deleted'])}")
进阶用法:数据验证场景
这个工具特别适合做存储过程/触发器的回归测试:
# 1. Create the test SQL file (test_sp_logic.sql)
cat > test_sp_logic.sql <<'SQL'
SELECT '=== 执行前 ===' AS stage;
SELECT COUNT(*) AS user_count FROM users WHERE status='active';
SELECT COUNT(*) AS order_count FROM orders WHERE DATE(created_at) = CURDATE();

-- Run the stored procedure
CALL update_daily_statistics();

SELECT '=== 执行后 ===' AS stage;
SELECT COUNT(*) AS user_count FROM users WHERE status='active';
SELECT COUNT(*) AS order_count FROM orders WHERE DATE(created_at) = CURDATE();
SQL

# 2. Run it
python sql_executor.py -s test_sp_logic.sql -c db_config.yaml -d test_db

# 3. Compare the results of the two rounds of SELECTs to check the logic.
#    The Markdown log only records row counts; the returned values
#    (user_count, order_count) are in output/execution_results_<timestamp>.json.
完整使用流程
Step 1:安装依赖
pip install \
    pymysql \
    psycopg2-binary \
    pyyaml
Step 2:配置数据库连接
编辑 db_config.yaml,填入测试数据库的连接信息。
Step 3:准备 SQL 文件
把要执行的 SQL 保存为 .sql 文件,放到一个目录下。
Step 4:执行
python sql_executor.py --config db_config.yaml --sql ./data_prepare/ --db test_db
Step 5:查看结果
打开 output/execution_log_xxx.md,里面有每条 SQL 的执行状态以及影响行数/返回行数;查询返回的具体数据保存在同目录的 execution_results_xxx.json 中(默认已脱敏)。
工具设计亮点
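| 亮点 | 说明 |
|---|---|
| 事务保护 | 单条 SQL 出错时自动回滚;修改数据前可先自动备份相关表(注意:执行成功的修改会立即提交) |
| 数据脱敏 | 导出结果时自动遮盖手机号、身份证、银行卡、邮箱、密码等敏感字段 |
| 容错执行 | continue_on_error=True 时,单条失败不影响后续语句 |
| 日志完整 | 每次执行都生成 Markdown 日志与 JSON 结果文件 |
| 多数据库支持 | MySQL / PostgreSQL / SQLite 一套代码全部覆盖 |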
总结
这个工具解决的核心问题是:把数据库操作从“手工作坊”变成“工业化流水线”。