compile() - 代码编译函数
将字符串形式的 Python 代码编译为代码对象
compile() - 代码编译函数
📝 概述
compile() 是 Python 中的内置函数,用于将字符串形式的 Python 代码编译为代码对象。编译后的代码对象可以被 exec() 或 eval() 函数执行,也可以直接执行。这个函数在动态代码生成、性能优化、代码分析等场景中非常有用。
🎯 学习目标
- 掌握 compile()函数的基本用法和语法
- 理解代码对象的概念和特性
- 学会不同编译模式的使用场景
- 掌握编译优化和缓存技术
- 了解字节码分析和调试技术
📋 前置知识
- Python 基本语法
- 表达式和语句的区别
- 异常处理的基本知识
- 文件操作基础
- 面向对象编程基础
🔍 详细内容
基本概念
compile() 函数将源代码字符串编译成代码对象,这个过程包括:
- 词法分析(Lexical Analysis)
- 语法分析(Syntax Analysis)
- 抽象语法树生成(AST Generation)
- 字节码生成(Bytecode Generation)
编译后的代码对象可以被重复执行,避免了重复编译的开销。
语法格式
compile(source, filename, mode, flags=0, dont_inherit=False, optimize=-1)
参数说明
| 参数名 | 类型 | 必需 | 默认值 | 说明 |
|---|---|---|---|---|
| source | 字符串、字节串或 AST 对象 | 是 | 无 | 要编译的源代码 |
| filename | 字符串 | 是 | 无 | 文件名(用于错误报告) |
| mode | 字符串 | 是 | 无 | 编译模式:’exec’、’eval’、’single’ |
| flags | 整数 | 否 | 0 | 编译标志 |
| dont_inherit | 布尔值 | 否 | False | 是否继承编译标志 |
| optimize | 整数 | 否 | -1 | 优化级别:-1、0、1、2 |
编译模式
- ‘exec’: 编译语句序列,返回 None
- ‘eval’: 编译单个表达式,返回表达式的值
- ‘single’: 编译单个交互式语句
返回值
- 类型: 代码对象(code object)
- 内容: 可执行的字节码对象
💡 代码示例
基本用法
## 基本编译示例
print("基本编译示例:")
## 编译表达式(eval 模式)
expr_source = "2 + 3 * 4"
expr_code = compile(expr_source, '<string>', 'eval')
result = eval(expr_code)
print(f"表达式 '{expr_source}' = {result}") # 14
## 编译语句(exec 模式)
stmt_source = """
x = 10
y = 20
z = x + y
print(f'x + y = {z}')
"""
stmt_code = compile(stmt_source, '<string>', 'exec')
exec(stmt_code)
## 编译单个交互式语句(single 模式)
single_source = "print('Hello from single mode!')"
single_code = compile(single_source, '<string>', 'single')
exec(single_code)
## 查看代码对象属性
print(f"\n 代码对象属性:")
print(f" 文件名: {expr_code.co_filename}")
print(f" 名称: {expr_code.co_name}")
print(f" 参数数量: {expr_code.co_argcount}")
print(f" 局部变量数量: {expr_code.co_nlocals}")
print(f" 字节码: {expr_code.co_code[:20]}...") # 显示前 20 个字节
## 比较编译性能
import time
print(f"\n 性能比较:")
## 重复执行字符串代码(每次都编译)
code_string = "sum(range(1000))"
start_time = time.time()
for _ in range(1000):
result = eval(code_string)
string_time = time.time() - start_time
## 重复执行编译后的代码
compiled_code = compile(code_string, '<string>', 'eval')
start_time = time.time()
for _ in range(1000):
result = eval(compiled_code)
compiled_time = time.time() - start_time
print(f"字符串执行时间: {string_time:.4f}秒")
print(f"编译后执行时间: {compiled_time:.4f}秒")
print(f"性能提升: {string_time/compiled_time:.1f}倍")
编译标志和优化
## 编译标志和优化示例
print("\n 编译标志和优化示例:")
## 导入编译标志
from __future__ import annotations
import ast
## 不同优化级别的编译
code_source = """
def calculate(x):
"""计算函数"""
assert x > 0, "x must be positive" # 断言在优化时可能被移除
result = x * 2 + 1
return result
value = calculate(10)
"""
print("不同优化级别编译:")
## 无优化(optimize=0)
code_no_opt = compile(code_source, '<string>', 'exec', optimize=0)
print(f" 无优化字节码长度: {len(code_no_opt.co_code)}")
## 基本优化(optimize=1)
code_opt1 = compile(code_source, '<string>', 'exec', optimize=1)
print(f" 基本优化字节码长度: {len(code_opt1.co_code)}")
## 高级优化(optimize=2)
code_opt2 = compile(code_source, '<string>', 'exec', optimize=2)
print(f" 高级优化字节码长度: {len(code_opt2.co_code)}")
## 使用编译标志
print(f"\n 编译标志示例:")
## 启用除法特性
from __future__ import division
division_code = "result = 3 / 2" # 在 Python 3 中默认是真除法
## 编译时指定标志
flags = 0
try:
compiled_with_flags = compile(division_code, '<string>', 'exec', flags=flags)
namespace = {}
exec(compiled_with_flags, namespace)
print(f" 除法结果: {namespace['result']}")
except Exception as e:
print(f" 编译失败: {e}")
## 语法检查编译
print(f"\n 语法检查:")
valid_codes = [
"x = 1 + 2",
"def func(): return 42",
"[x for x in range(10) if x % 2 == 0]"
]
invalid_codes = [
"x = 1 +", # 语法错误
"def func( return 42", # 括号不匹配
"if x == 1 print('hello')", # 缺少冒号
]
print("有效代码编译:")
for code in valid_codes:
try:
compile(code, '<string>', 'exec')
print(f" ✓ '{code}' - 编译成功")
except SyntaxError as e:
print(f" ✗ '{code}' - 语法错误: {e}")
print("\n 无效代码编译:")
for code in invalid_codes:
try:
compile(code, '<string>', 'exec')
print(f" ✓ '{code}' - 编译成功(不应该成功)")
except SyntaxError as e:
print(f" ✗ '{code}' - 语法错误: {e.msg}")
代码缓存系统
import hashlib
import pickle
import os
from typing import Dict, Any, Optional
## 代码缓存系统
class CodeCache:
"""代码编译缓存系统"""
def __init__(self, cache_dir: str = "code_cache"):
self.cache_dir = cache_dir
self.memory_cache: Dict[str, Any] = {}
# # 创建缓存目录
if not os.path.exists(cache_dir):
os.makedirs(cache_dir)
def _get_cache_key(self, source: str, mode: str, optimize: int) -> str:
"""生成缓存键"""
content = f"{source}|{mode}|{optimize}"
return hashlib.md5(content.encode()).hexdigest()
def _get_cache_path(self, cache_key: str) -> str:
"""获取缓存文件路径"""
return os.path.join(self.cache_dir, f"{cache_key}.pyc")
def get_from_memory(self, cache_key: str) -> Optional[Any]:
"""从内存缓存获取"""
return self.memory_cache.get(cache_key)
def save_to_memory(self, cache_key: str, code_obj: Any):
"""保存到内存缓存"""
self.memory_cache[cache_key] = code_obj
def get_from_disk(self, cache_key: str) -> Optional[Any]:
"""从磁盘缓存获取"""
cache_path = self._get_cache_path(cache_key)
if os.path.exists(cache_path):
try:
with open(cache_path, 'rb') as f:
return pickle.load(f)
except Exception:
# # 缓存文件损坏,删除它
os.remove(cache_path)
return None
def save_to_disk(self, cache_key: str, code_obj: Any):
"""保存到磁盘缓存"""
cache_path = self._get_cache_path(cache_key)
try:
with open(cache_path, 'wb') as f:
pickle.dump(code_obj, f)
except Exception as e:
print(f"保存缓存失败: {e}")
def compile_with_cache(self, source: str, filename: str = '<string>',
mode: str = 'exec', optimize: int = -1) -> Any:
"""带缓存的编译"""
# # 生成缓存键
cache_key = self._get_cache_key(source, mode, optimize)
# # 尝试从内存缓存获取
code_obj = self.get_from_memory(cache_key)
if code_obj is not None:
return code_obj
# # 尝试从磁盘缓存获取
code_obj = self.get_from_disk(cache_key)
if code_obj is not None:
# # 保存到内存缓存
self.save_to_memory(cache_key, code_obj)
return code_obj
# # 编译代码
try:
code_obj = compile(source, filename, mode, optimize=optimize)
# # 保存到缓存
self.save_to_memory(cache_key, code_obj)
self.save_to_disk(cache_key, code_obj)
return code_obj
except Exception as e:
raise RuntimeError(f"编译失败: {e}")
def clear_cache(self):
"""清空缓存"""
# # 清空内存缓存
self.memory_cache.clear()
# # 清空磁盘缓存
for filename in os.listdir(self.cache_dir):
if filename.endswith('.pyc'):
os.remove(os.path.join(self.cache_dir, filename))
def get_cache_stats(self) -> Dict[str, Any]:
"""获取缓存统计"""
disk_files = [f for f in os.listdir(self.cache_dir) if f.endswith('.pyc')]
return {
'memory_cache_size': len(self.memory_cache),
'disk_cache_size': len(disk_files),
'cache_dir': self.cache_dir
}
## 测试代码缓存
print("\n 代码缓存系统示例:")
cache = CodeCache()
## 测试代码
test_codes = [
"result = sum(range(100))",
"def factorial(n): return 1 if n <= 1 else n * factorial(n-1)",
"squares = [x**2 for x in range(10)]",
"import math; result = math.sqrt(16)"
]
print("首次编译(无缓存):")
start_time = time.time()
for i, code in enumerate(test_codes):
compiled_code = cache.compile_with_cache(code)
print(f" 代码 {i+1}: 编译完成")
first_time = time.time() - start_time
print(f"\n 第二次编译(使用缓存):")
start_time = time.time()
for i, code in enumerate(test_codes):
compiled_code = cache.compile_with_cache(code)
print(f" 代码 {i+1}: 从缓存获取")
second_time = time.time() - start_time
print(f"\n 性能对比:")
print(f"首次编译时间: {first_time:.4f}秒")
print(f"缓存编译时间: {second_time:.4f}秒")
print(f"性能提升: {first_time/second_time:.1f}倍")
## 缓存统计
stats = cache.get_cache_stats()
print(f"\n 缓存统计:")
print(f" 内存缓存大小: {stats['memory_cache_size']}")
print(f" 磁盘缓存大小: {stats['disk_cache_size']}")
print(f" 缓存目录: {stats['cache_dir']}")
## 清理缓存
cache.clear_cache()
print(f"\n 缓存已清理")
动态代码生成器
from typing import List, Dict, Any, Callable
import inspect
## 动态代码生成器
class DynamicCodeGenerator:
"""动态代码生成器"""
def __init__(self):
self.code_cache = CodeCache()
self.generated_functions: Dict[str, Callable] = {}
def generate_getter_setter(self, class_name: str, fields: List[str]) -> str:
"""生成 getter/setter 类代码"""
lines = [f"class {class_name}:"]
lines.append(" def __init__(self):")
# # 初始化字段
for field in fields:
lines.append(f" self._{field} = None")
lines.append("")
# # 生成 getter 和 setter
for field in fields:
# # Getter
lines.append(f" def get_{field}(self):")
lines.append(f" return self._{field}")
lines.append("")
# # Setter
lines.append(f" def set_{field}(self, value):")
lines.append(f" self._{field} = value")
lines.append(f" return self")
lines.append("")
# # __str__ 方法
lines.append(" def __str__(self):")
field_strs = [f"'{field}: {{self._{field}}}'".replace('{field}', field) for field in fields]
lines.append(f" return f'{class_name}(' + ', '.join([{', '.join(field_strs)}]) + ')'")
return "\n".join(lines)
def generate_validator_function(self, func_name: str, rules: Dict[str, Any]) -> str:
"""生成验证函数代码"""
lines = [f"def {func_name}(data):"]
lines.append(" errors = []")
lines.append("")
for field, rule in rules.items():
lines.append(f" # 验证 {field}")
lines.append(f" if '{field}' not in data:")
lines.append(f" errors.append('{field} is required')")
lines.append(f" else:")
lines.append(f" value = data['{field}']")
if 'type' in rule:
type_name = rule['type'].__name__
lines.append(f" if not isinstance(value, {rule['type'].__name__}):")
lines.append(f" errors.append('{field} must be {type_name}')")
if 'min_length' in rule:
lines.append(f" elif len(value) < {rule['min_length']}:")
lines.append(f" errors.append('{field} must be at least {rule['min_length']} characters')")
if 'max_length' in rule:
lines.append(f" elif len(value) > {rule['max_length']}:")
lines.append(f" errors.append('{field} must be at most {rule['max_length']} characters')")
if 'min_value' in rule:
lines.append(f" elif value < {rule['min_value']}:")
lines.append(f" errors.append('{field} must be at least {rule['min_value']}')")
if 'max_value' in rule:
lines.append(f" elif value > {rule['max_value']}:")
lines.append(f" errors.append('{field} must be at most {rule['max_value']}')")
lines.append("")
lines.append(" return {'valid': len(errors) == 0, 'errors': errors}")
return "\n".join(lines)
def generate_math_function(self, func_name: str, expression: str, variables: List[str]) -> str:
"""生成数学函数代码"""
params = ', '.join(variables)
lines = [
f"def {func_name}({params}):",
f" return {expression}"
]
return "\n".join(lines)
def compile_and_execute(self, code: str, func_name: str = None) -> Any:
"""编译并执行代码"""
try:
# # 编译代码
compiled_code = self.code_cache.compile_with_cache(code)
# # 执行代码
namespace = {}
exec(compiled_code, namespace)
# # 如果指定了函数名,返回函数
if func_name and func_name in namespace:
func = namespace[func_name]
self.generated_functions[func_name] = func
return func
return namespace
except Exception as e:
raise RuntimeError(f"代码编译或执行失败: {e}")
def get_function(self, func_name: str) -> Optional[Callable]:
"""获取生成的函数"""
return self.generated_functions.get(func_name)
def list_functions(self) -> List[str]:
"""列出所有生成的函数"""
return list(self.generated_functions.keys())
## 测试动态代码生成器
print("\n 动态代码生成器示例:")
generator = DynamicCodeGenerator()
## 生成 getter/setter 类
print("1. 生成 getter/setter 类:")
user_class_code = generator.generate_getter_setter('User', ['name', 'age', 'email'])
print("生成的类代码:")
print(user_class_code[:200] + "...")
## 编译并使用类
namespace = generator.compile_and_execute(user_class_code)
User = namespace['User']
## 测试生成的类
user = User()
user.set_name('Alice').set_age(30).set_email('alice@example.com')
print(f"\n 用户对象: {user}")
print(f"姓名: {user.get_name()}")
print(f"年龄: {user.get_age()}")
print(f"邮箱: {user.get_email()}")
## 生成验证函数
print(f"\n2. 生成验证函数:")
validation_rules = {
'name': {'type': str, 'min_length': 2, 'max_length': 50},
'age': {'type': int, 'min_value': 0, 'max_value': 150},
'email': {'type': str, 'min_length': 5, 'max_length': 100}
}
validator_code = generator.generate_validator_function('validate_user', validation_rules)
validate_user = generator.compile_and_execute(validator_code, 'validate_user')
## 测试验证函数
test_data = [
{'name': 'Alice', 'age': 30, 'email': 'alice@example.com'},
{'name': 'B', 'age': 30, 'email': 'bob@example.com'}, # 名字太短
{'name': 'Charlie', 'age': -5, 'email': 'charlie@example.com'}, # 年龄无效
{'name': 'David', 'email': 'david@example.com'}, # 缺少年龄
]
for i, data in enumerate(test_data, 1):
result = validate_user(data)
print(f" 测试 {i}: {'有效' if result['valid'] else '无效'}")
if not result['valid']:
print(f" 错误: {result['errors']}")
## 生成数学函数
print(f"\n3. 生成数学函数:")
math_functions = [
('quadratic', 'a * x**2 + b * x + c', ['a', 'b', 'c', 'x']),
('distance', '((x2 - x1)**2 + (y2 - y1)**2)**0.5', ['x1', 'y1', 'x2', 'y2']),
('compound_interest', 'principal * (1 + rate)**time', ['principal', 'rate', 'time'])
]
for func_name, expression, variables in math_functions:
math_code = generator.generate_math_function(func_name, expression, variables)
func = generator.compile_and_execute(math_code, func_name)
print(f" 生成函数: {func_name}")
## 测试数学函数
print(f"\n 数学函数测试:")
quadratic = generator.get_function('quadratic')
distance = generator.get_function('distance')
compound_interest = generator.get_function('compound_interest')
print(f" 二次函数 f(x) = 2x² + 3x + 1, x=5: {quadratic(2, 3, 1, 5)}")
print(f" 距离 (0,0) 到 (3,4): {distance(0, 0, 3, 4)}")
print(f" 复利 本金 1000, 利率 5%, 时间 10 年: {compound_interest(1000, 0.05, 10):.2f}")
## 列出所有生成的函数
print(f"\n 生成的函数列表: {generator.list_functions()}")
🚀 高级应用
字节码分析工具
import dis
import types
from typing import List, Dict, Any
## 字节码分析工具
class BytecodeAnalyzer:
"""字节码分析工具"""
def __init__(self):
self.analysis_cache = {}
def analyze_code(self, code_obj: types.CodeType) -> Dict[str, Any]:
"""分析代码对象"""
analysis = {
'basic_info': self._get_basic_info(code_obj),
'bytecode': self._get_bytecode_info(code_obj),
'constants': self._get_constants_info(code_obj),
'variables': self._get_variables_info(code_obj),
'complexity': self._calculate_complexity(code_obj)
}
return analysis
def _get_basic_info(self, code_obj: types.CodeType) -> Dict[str, Any]:
"""获取基本信息"""
return {
'name': code_obj.co_name,
'filename': code_obj.co_filename,
'first_line': code_obj.co_firstlineno,
'arg_count': code_obj.co_argcount,
'local_count': code_obj.co_nlocals,
'stack_size': code_obj.co_stacksize,
'flags': code_obj.co_flags
}
def _get_bytecode_info(self, code_obj: types.CodeType) -> Dict[str, Any]:
"""获取字节码信息"""
instructions = list(dis.get_instructions(code_obj))
# # 统计指令类型
instruction_counts = {}
for instr in instructions:
instruction_counts[instr.opname] = instruction_counts.get(instr.opname, 0) + 1
return {
'instruction_count': len(instructions),
'instruction_types': len(instruction_counts),
'instruction_distribution': instruction_counts,
'bytecode_size': len(code_obj.co_code)
}
def _get_constants_info(self, code_obj: types.CodeType) -> Dict[str, Any]:
"""获取常量信息"""
constants = code_obj.co_consts or ()
# # 按类型分类常量
constant_types = {}
for const in constants:
const_type = type(const).__name__
if const_type not in constant_types:
constant_types[const_type] = []
constant_types[const_type].append(const)
return {
'constant_count': len(constants),
'constant_types': constant_types,
'has_none': None in constants,
'has_strings': any(isinstance(c, str) for c in constants),
'has_numbers': any(isinstance(c, (int, float)) for c in constants)
}
def _get_variables_info(self, code_obj: types.CodeType) -> Dict[str, Any]:
"""获取变量信息"""
return {
'var_names': list(code_obj.co_varnames or ()),
'free_vars': list(code_obj.co_freevars or ()),
'cell_vars': list(code_obj.co_cellvars or ()),
'global_names': list(code_obj.co_names or ())
}
def _calculate_complexity(self, code_obj: types.CodeType) -> Dict[str, Any]:
"""计算复杂度"""
instructions = list(dis.get_instructions(code_obj))
# # 计算循环复杂度(简化版)
branch_instructions = {
'POP_JUMP_IF_FALSE', 'POP_JUMP_IF_TRUE', 'JUMP_IF_FALSE_OR_POP',
'JUMP_IF_TRUE_OR_POP', 'JUMP_FORWARD', 'JUMP_ABSOLUTE'
}
branch_count = sum(1 for instr in instructions if instr.opname in branch_instructions)
cyclomatic_complexity = branch_count + 1
return {
'cyclomatic_complexity': cyclomatic_complexity,
'branch_count': branch_count,
'instruction_density': len(instructions) / max(1, len(code_obj.co_code))
}
def compare_codes(self, code1: types.CodeType, code2: types.CodeType) -> Dict[str, Any]:
"""比较两个代码对象"""
analysis1 = self.analyze_code(code1)
analysis2 = self.analyze_code(code2)
comparison = {
'bytecode_size_diff': analysis2['bytecode']['bytecode_size'] - analysis1['bytecode']['bytecode_size'],
'instruction_count_diff': analysis2['bytecode']['instruction_count'] - analysis1['bytecode']['instruction_count'],
'complexity_diff': analysis2['complexity']['cyclomatic_complexity'] - analysis1['complexity']['cyclomatic_complexity'],
'constant_count_diff': analysis2['constants']['constant_count'] - analysis1['constants']['constant_count']
}
return {
'code1_analysis': analysis1,
'code2_analysis': analysis2,
'differences': comparison
}
def print_analysis(self, analysis: Dict[str, Any]):
"""打印分析结果"""
print(f"代码分析报告:")
print(f"=" * 40)
# # 基本信息
basic = analysis['basic_info']
print(f"名称: {basic['name']}")
print(f"文件: {basic['filename']}")
print(f"起始行: {basic['first_line']}")
print(f"参数数量: {basic['arg_count']}")
print(f"局部变量数量: {basic['local_count']}")
print(f"栈大小: {basic['stack_size']}")
# # 字节码信息
bytecode = analysis['bytecode']
print(f"\n 字节码信息:")
print(f" 指令数量: {bytecode['instruction_count']}")
print(f" 指令类型: {bytecode['instruction_types']}")
print(f" 字节码大小: {bytecode['bytecode_size']} 字节")
# # 常量信息
constants = analysis['constants']
print(f"\n 常量信息:")
print(f" 常量数量: {constants['constant_count']}")
print(f" 包含 None: {constants['has_none']}")
print(f" 包含字符串: {constants['has_strings']}")
print(f" 包含数字: {constants['has_numbers']}")
# # 复杂度信息
complexity = analysis['complexity']
print(f"\n 复杂度信息:")
print(f" 圈复杂度: {complexity['cyclomatic_complexity']}")
print(f" 分支数量: {complexity['branch_count']}")
print(f" 指令密度: {complexity['instruction_density']:.2f}")
## 测试字节码分析
print("\n 字节码分析示例:")
analyzer = BytecodeAnalyzer()
## 分析不同复杂度的代码
simple_code = compile("result = 2 + 3", '<string>', 'exec')
complex_code = compile("""
def fibonacci(n):
if n <= 1:
return n
else:
return fibonacci(n-1) + fibonacci(n-2)
result = fibonacci(10)
""", '<string>', 'exec')
print("简单代码分析:")
simple_analysis = analyzer.analyze_code(simple_code)
analyzer.print_analysis(simple_analysis)
print(f"\n 复杂代码分析:")
complex_analysis = analyzer.analyze_code(complex_code)
analyzer.print_analysis(complex_analysis)
## 比较代码
print(f"\n 代码比较:")
comparison = analyzer.compare_codes(simple_code, complex_code)
differences = comparison['differences']
print(f"字节码大小差异: {differences['bytecode_size_diff']} 字节")
print(f"指令数量差异: {differences['instruction_count_diff']}")
print(f"复杂度差异: {differences['complexity_diff']}")
print(f"常量数量差异: {differences['constant_count_diff']}")
## 显示字节码
print(f"\n 简单代码字节码:")
dis.dis(simple_code)
print(f"\n 复杂代码字节码(前 20 行):")
instructions = list(dis.get_instructions(complex_code))
for i, instr in enumerate(instructions[:20]):
print(f" {i:2d} {instr.opname:20} {instr.arg or '':10} {instr.argval or ''}")
if len(instructions) > 20:
print(f" ... 还有 {len(instructions) - 20} 条指令")
代码优化分析器
import ast
import time
from typing import List, Tuple, Dict, Any
## 代码优化分析器
class CodeOptimizer:
"""代码优化分析器"""
def __init__(self):
self.optimization_rules = {
'constant_folding': self._suggest_constant_folding,
'loop_optimization': self._suggest_loop_optimization,
'function_calls': self._suggest_function_optimization,
'data_structures': self._suggest_data_structure_optimization
}
def analyze_and_optimize(self, source_code: str) -> Dict[str, Any]:
"""分析并提供优化建议"""
try:
# # 解析 AST
tree = ast.parse(source_code)
# # 编译原始代码
original_code = compile(source_code, '<string>', 'exec')
# # 分析代码
analysis = {
'original_code': source_code,
'ast_analysis': self._analyze_ast(tree),
'optimization_suggestions': self._get_optimization_suggestions(tree),
'performance_analysis': self._analyze_performance(original_code)
}
return analysis
except Exception as e:
return {'error': f"分析失败: {e}"}
def _analyze_ast(self, tree: ast.AST) -> Dict[str, Any]:
"""分析 AST"""
node_counts = {}
for node in ast.walk(tree):
node_type = type(node).__name__
node_counts[node_type] = node_counts.get(node_type, 0) + 1
return {
'total_nodes': sum(node_counts.values()),
'node_types': len(node_counts),
'node_distribution': node_counts
}
def _get_optimization_suggestions(self, tree: ast.AST) -> List[Dict[str, Any]]:
"""获取优化建议"""
suggestions = []
for rule_name, rule_func in self.optimization_rules.items():
rule_suggestions = rule_func(tree)
suggestions.extend(rule_suggestions)
return suggestions
def _suggest_constant_folding(self, tree: ast.AST) -> List[Dict[str, Any]]:
"""建议常量折叠优化"""
suggestions = []
for node in ast.walk(tree):
if isinstance(node, ast.BinOp):
# # 检查是否为常量运算
if (isinstance(node.left, ast.Constant) and
isinstance(node.right, ast.Constant)):
suggestions.append({
'type': 'constant_folding',
'description': '可以预计算常量表达式',
'line': getattr(node, 'lineno', 'unknown'),
'suggestion': '将常量运算预先计算'
})
return suggestions
def _suggest_loop_optimization(self, tree: ast.AST) -> List[Dict[str, Any]]:
"""建议循环优化"""
suggestions = []
for node in ast.walk(tree):
if isinstance(node, (ast.For, ast.While)):
# # 检查嵌套循环
nested_loops = [n for n in ast.walk(node)
if isinstance(n, (ast.For, ast.While)) and n != node]
if nested_loops:
suggestions.append({
'type': 'loop_optimization',
'description': '发现嵌套循环,考虑优化',
'line': getattr(node, 'lineno', 'unknown'),
'suggestion': '考虑循环展开或算法优化'
})
# # 检查循环中的函数调用
function_calls = [n for n in ast.walk(node) if isinstance(n, ast.Call)]
if len(function_calls) > 3:
suggestions.append({
'type': 'loop_optimization',
'description': '循环中有多个函数调用',
'line': getattr(node, 'lineno', 'unknown'),
'suggestion': '考虑将不变的函数调用移出循环'
})
return suggestions
def _suggest_function_optimization(self, tree: ast.AST) -> List[Dict[str, Any]]:
"""建议函数优化"""
suggestions = []
for node in ast.walk(tree):
if isinstance(node, ast.FunctionDef):
# # 检查函数长度
function_lines = getattr(node, 'end_lineno', 0) - getattr(node, 'lineno', 0)
if function_lines > 50:
suggestions.append({
'type': 'function_optimization',
'description': f'函数 {node.name} 过长 ({function_lines} 行)',
'line': getattr(node, 'lineno', 'unknown'),
'suggestion': '考虑拆分为更小的函数'
})
# # 检查参数数量
if len(node.args.args) > 7:
suggestions.append({
'type': 'function_optimization',
'description': f'函数 {node.name} 参数过多 ({len(node.args.args)} 个)',
'line': getattr(node, 'lineno', 'unknown'),
'suggestion': '考虑使用对象或字典传递参数'
})
return suggestions
def _suggest_data_structure_optimization(self, tree: ast.AST) -> List[Dict[str, Any]]:
"""建议数据结构优化"""
suggestions = []
for node in ast.walk(tree):
# # 检查列表推导式 vs 循环
if isinstance(node, ast.For):
# # 简单检查是否可以用列表推导式
if (len(node.body) == 1 and
isinstance(node.body[0], ast.Expr) and
isinstance(node.body[0].value, ast.Call)):
call_node = node.body[0].value
if (isinstance(call_node.func, ast.Attribute) and
call_node.func.attr == 'append'):
suggestions.append({
'type': 'data_structure_optimization',
'description': '可以使用列表推导式替代循环',
'line': getattr(node, 'lineno', 'unknown'),
'suggestion': '使用列表推导式提高性能'
})
return suggestions
def _analyze_performance(self, code_obj: types.CodeType) -> Dict[str, Any]:
"""分析性能"""
# # 使用字节码分析器
analyzer = BytecodeAnalyzer()
bytecode_analysis = analyzer.analyze_code(code_obj)
# # 性能评分(简化版)
complexity = bytecode_analysis['complexity']['cyclomatic_complexity']
instruction_count = bytecode_analysis['bytecode']['instruction_count']
performance_score = max(0, 100 - complexity * 5 - instruction_count * 0.1)
return {
'performance_score': performance_score,
'complexity_rating': 'Low' if complexity <= 5 else 'Medium' if complexity <= 10 else 'High',
'instruction_efficiency': instruction_count / max(1, len(code_obj.co_code)),
'bytecode_analysis': bytecode_analysis
}
def generate_optimized_code(self, original_code: str, suggestions: List[Dict[str, Any]]) -> str:
"""生成优化后的代码(简化版)"""
optimized_code = original_code
# # 这里只是示例,实际的代码优化需要更复杂的 AST 转换
for suggestion in suggestions:
if suggestion['type'] == 'constant_folding':
# # 简单的常量折叠示例
optimized_code = optimized_code.replace('2 + 3', '5')
optimized_code = optimized_code.replace('10 * 2', '20')
return optimized_code
## 测试代码优化分析器
print("\n 代码优化分析器示例:")
optimizer = CodeOptimizer()
## 测试代码
test_code = """
def process_data(data_list, threshold, multiplier, offset, debug_mode, log_file, config):
"""处理数据的函数(故意写得不够优化)"""
result = []
# # 常量运算(可以优化)
magic_number = 2 + 3 * 4
# # 嵌套循环(可能需要优化)
for i in range(len(data_list)):
for j in range(len(data_list)):
if i != j:
# # 循环中的重复计算(可以优化)
value = data_list[i] * multiplier + offset
if value > threshold:
# # 可以用列表推导式优化的模式
result.append(value)
# # 更多处理逻辑...
for item in data_list:
processed = item * 10 + 5 # 另一个常量运算
if processed > 100:
result.append(processed)
return result
## 调用函数
data = [1, 2, 3, 4, 5]
result = process_data(data, 50, 2, 10, True, 'log.txt', {'debug': True})
"""
## 分析代码
print("分析测试代码:")
analysis = optimizer.analyze_and_optimize(test_code)
if 'error' in analysis:
print(f"分析失败: {analysis['error']}")
else:
# # 显示 AST 分析
ast_analysis = analysis['ast_analysis']
print(f"\nAST 分析:")
print(f" 总节点数: {ast_analysis['total_nodes']}")
print(f" 节点类型数: {ast_analysis['node_types']}")
print(f" 主要节点类型: {list(ast_analysis['node_distribution'].keys())[:5]}")
# # 显示优化建议
suggestions = analysis['optimization_suggestions']
print(f"\n 优化建议 ({len(suggestions)} 条):")
for i, suggestion in enumerate(suggestions, 1):
print(f" {i}. {suggestion['description']} (行 {suggestion['line']})")
print(f" 建议: {suggestion['suggestion']}")
# # 显示性能分析
performance = analysis['performance_analysis']
print(f"\n 性能分析:")
print(f" 性能评分: {performance['performance_score']:.1f}/100")
print(f" 复杂度等级: {performance['complexity_rating']}")
print(f" 指令效率: {performance['instruction_efficiency']:.2f}")
# # 生成优化后的代码
optimized_code = optimizer.generate_optimized_code(test_code, suggestions)
if optimized_code != test_code:
print(f"\n 已生成优化后的代码(部分优化)")
else:
print(f"\n 未生成优化代码(需要手动优化)")
⚠️ 常见陷阱与最佳实践
编译错误处理
## 编译错误处理示例
print("\n 编译错误处理示例:")
## 错误处理包装器
class CompileErrorHandler:
"""编译错误处理器"""
def __init__(self):
self.error_patterns = {
'SyntaxError': self._handle_syntax_error,
'IndentationError': self._handle_indentation_error,
'TabError': self._handle_tab_error
}
def safe_compile(self, source: str, filename: str = '<string>',
mode: str = 'exec') -> tuple[Any, List[str]]:
"""安全编译,返回代码对象和错误信息"""
errors = []
try:
code_obj = compile(source, filename, mode)
return code_obj, errors
except Exception as e:
error_type = type(e).__name__
# # 使用专门的错误处理器
if error_type in self.error_patterns:
error_info = self.error_patterns[error_type](e, source)
else:
error_info = self._handle_generic_error(e, source)
errors.append(error_info)
return None, errors
def _handle_syntax_error(self, error: SyntaxError, source: str) -> Dict[str, Any]:
"""处理语法错误"""
return {
'type': 'SyntaxError',
'message': str(error),
'line': error.lineno,
'column': error.offset,
'text': error.text,
'suggestion': self._suggest_syntax_fix(error)
}
def _handle_indentation_error(self, error: IndentationError, source: str) -> Dict[str, Any]:
"""处理缩进错误"""
return {
'type': 'IndentationError',
'message': str(error),
'line': error.lineno,
'suggestion': '检查缩进是否一致,建议使用 4 个空格'
}
def _handle_tab_error(self, error: TabError, source: str) -> Dict[str, Any]:
"""处理制表符错误"""
return {
'type': 'TabError',
'message': str(error),
'line': error.lineno,
'suggestion': '不要混用制表符和空格,建议只使用空格'
}
def _handle_generic_error(self, error: Exception, source: str) -> Dict[str, Any]:
"""处理通用错误"""
return {
'type': type(error).__name__,
'message': str(error),
'suggestion': '检查代码语法和结构'
}
def _suggest_syntax_fix(self, error: SyntaxError) -> str:
"""建议语法修复"""
message = str(error).lower()
if 'invalid syntax' in message:
if error.text and ':' not in error.text:
return '可能缺少冒号(:)'
elif error.text and '(' in error.text and ')' not in error.text:
return '可能缺少右括号(])'
elif error.text and '[' in error.text and ']' not in error.text:
return '可能缺少右方括号(])'
elif 'unexpected eof' in message:
return '代码不完整,可能缺少结束符'
elif 'unmatched' in message:
return '括号或引号不匹配'
return '检查语法错误'
## 测试错误处理
print("错误处理测试:")
error_handler = CompileErrorHandler()
## 测试各种错误代码
error_codes = [
("x = 1 +", "语法错误:表达式不完整"),
("if x == 1\n print('hello')", "缩进错误"),
("def func(\n return 42", "括号不匹配"),
("x = [1, 2, 3\nprint(x)", "方括号不匹配"),
("\tif True:\n print('mixed')", "制表符和空格混用"),
("x = 'unclosed string", "字符串未闭合")
]
for i, (code, description) in enumerate(error_codes, 1):
print(f"\n 测试 {i}: {description}")
print(f"代码: {repr(code)}")
code_obj, errors = error_handler.safe_compile(code)
if code_obj:
print(f" ✓ 编译成功")
else:
for error in errors:
print(f" ✗ {error['type']}: {error['message']}")
if 'line' in error:
print(f" 位置: 第 {error['line']} 行")
print(f" 建议: {error['suggestion']}")
## 编译最佳实践
print(f"\n 编译最佳实践示例:")
class BestPracticeCompiler:
"""最佳实践编译器"""
def __init__(self):
self.cache = CodeCache()
self.error_handler = CompileErrorHandler()
def compile_with_validation(self, source: str, filename: str = '<string>',
mode: str = 'exec') -> tuple[Any, List[str]]:
"""带验证的编译"""
warnings = []
# # 预处理检查
warnings.extend(self._check_code_quality(source))
# # 编译代码
code_obj, errors = self.error_handler.safe_compile(source, filename, mode)
if code_obj:
# # 后处理检查
warnings.extend(self._check_compiled_code(code_obj))
return code_obj, errors + warnings
def _check_code_quality(self, source: str) -> List[Dict[str, Any]]:
"""检查代码质量"""
warnings = []
lines = source.split('\n')
for i, line in enumerate(lines, 1):
# # 检查行长度
if len(line) > 100:
warnings.append({
'type': 'Warning',
'message': f'第 {i} 行过长 ({len(line)} 字符)',
'suggestion': '建议将长行拆分'
})
# # 检查制表符
if '\t' in line:
warnings.append({
'type': 'Warning',
'message': f'第 {i} 行包含制表符',
'suggestion': '建议使用空格代替制表符'
})
# # 检查尾随空格
if line.endswith(' '):
warnings.append({
'type': 'Warning',
'message': f'第 {i} 行有尾随空格',
'suggestion': '删除尾随空格'
})
return warnings
def _check_compiled_code(self, code_obj: types.CodeType) -> List[Dict[str, Any]]:
"""检查编译后的代码"""
warnings = []
# # 检查复杂度
analyzer = BytecodeAnalyzer()
analysis = analyzer.analyze_code(code_obj)
complexity = analysis['complexity']['cyclomatic_complexity']
if complexity > 10:
warnings.append({
'type': 'Warning',
'message': f'代码复杂度过高 ({complexity})',
'suggestion': '考虑重构以降低复杂度'
})
# # 检查字节码大小
bytecode_size = analysis['bytecode']['bytecode_size']
if bytecode_size > 1000:
warnings.append({
'type': 'Warning',
'message': f'字节码过大 ({bytecode_size} 字节)',
'suggestion': '考虑拆分代码'
})
return warnings
## 测试最佳实践编译器
print("最佳实践编译器测试:")
best_compiler = BestPracticeCompiler()
## 测试代码(故意包含一些问题)
test_code_with_issues = """
def very_long_function_name_that_exceeds_normal_length_recommendations(param1, param2, param3, param4, param5):
# # 这行有尾随空格
x = 1
if x == 1:
if x == 1:
if x == 1:
if x == 1:
if x == 1:
print('deeply nested') # 高复杂度
return x
"""
code_obj, issues = best_compiler.compile_with_validation(test_code_with_issues)
if code_obj:
print(" ✓ 编译成功,但有以下问题:")
else:
print(" ✗ 编译失败:")
for issue in issues:
print(f" {issue['type']}: {issue['message']}")
print(f" 建议: {issue['suggestion']}")
🔗 相关函数和模块
内置函数
eval()- 执行表达式exec()- 执行语句globals()- 获取全局命名空间locals()- 获取局部命名空间vars()- 获取对象的属性字典
标准库模块
ast- 抽象语法树操作dis- 字节码反汇编code- 代码对象工具types- 动态类型创建inspect- 对象检查py_compile- Python 文件编译compileall- 批量编译
第三方库
astunparse- AST 反解析astor- AST 操作工具black- 代码格式化pylint- 代码质量检查
📚 扩展阅读
🏷️ 标签
代码编译 字节码 代码对象 动态编译 AST 性能优化 代码分析 编译器
讨论与反馈
欢迎在下方留言讨论,分享你的学习心得或提出问题。评论基于GitHub Issues,需要GitHub账号。