In the previous post we started modifying a function's body for AOP purposes. So far we've transformed a function into an AST node with extended code. In this post we will compile that node to bytecode and replace the original function's code with the new one.
At this point we have two objects:
- original_function – the decorated function, without any modification.
- modified_node – the modified node, with extended code and fixed locations.
Compiling the node into a python module
Given the node, we can use Python's built-in compile, which accepts either a source string or an AST node. The node is compiled into a module-level code object which we can later execute.
def compile_node(modified_node):
    compiled_method = compile(modified_node, '<string>', 'exec')
    return compiled_method
Besides the node, we pass two more arguments to compile. The first is the source file name to which the compiled code will be linked. In our case we are not going to keep the compiled method, only its bytecode, so we can use the commonly used placeholder '<string>'. The second argument is 'exec', which signals the compiler that multiple statements may appear in the input code and that we don't care about the returned value. This is a simplified explanation, but sufficient for this context.
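As a small illustration of the call above, here is a sketch that parses a throwaway function, compiles the node in 'exec' mode, and executes the result to get the function back. It is written for Python 3 rather than the 2.7 used in this post, and the source string and the namespace dictionary are made up for the example.

```python
import ast

# A throwaway function definition, used only for this illustration.
source = "def internal_method(x):\n    return x < 1\n"

# Parse the source into a Module node, as inspect.getsource + ast.parse would.
node = ast.parse(source)

# Compile the node. '<string>' is just a placeholder file name, and 'exec'
# tells the compiler to expect a sequence of statements.
module_code = compile(node, '<string>', 'exec')

# Executing the module-level code object defines the function in `namespace`.
namespace = {}
exec(module_code, namespace)
func = namespace['internal_method']

print(func(0))
print(module_code.co_filename)
```

Note that the object returned by compile is the code of the whole module, not of the function. The function only exists after that code has been executed.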
Getting replacement bytecode
A Python code object does not allow setting its bytecode. Therefore we need to create a new code object which is identical to the original one, except for the bytecode and the stack size:
def create_replacement_code(original_function, compiled_function):
    # these are the code object field names, ordered by the expected order of code object constructor
    code_arg_names = ['co_argcount', 'co_nlocals', 'co_stacksize',
                      'co_flags', 'co_code', 'co_consts', 'co_names',
                      'co_varnames', 'co_filename', 'co_name',
                      'co_firstlineno', 'co_lnotab']
    # fill all the args based on the original code
    code_args = [getattr(original_function.func_code, key, None)
                 for key in code_arg_names]
    # get the code object of the modified function
    # (extract_code is defined in the full code below)
    code = extract_code(compiled_function)
    # replace bytecode and stacksize with the compiled method values
    code_args[code_arg_names.index('co_code')] = code.co_code
    code_args[code_arg_names.index('co_stacksize')] = code.co_stacksize
    return CodeType(*code_args)
The instantiation of the code object is undocumented, yet some examples are available online. Please note that this signature fits Python 2.7 and was changed in 3.0. We have now finished transforming Python code (the original function) into a new code object containing the modified bytecode.
These two method calls can be wrapped in:
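For readers on a newer interpreter, the same idea of "copy the code object, change a few fields" can be sketched with the replace method that code objects have in Python 3.8 and later. This is not the constructor call used in this post. It only illustrates that code objects are read-only and that a modified copy has to be created instead. The function named original is made up for the example.

```python
def original():
    return 42

code = original.__code__

# Code object fields are read-only, so direct assignment is rejected.
try:
    code.co_code = b''
    assignment_allowed = True
except (AttributeError, TypeError):
    assignment_allowed = False

# Assumes Python 3.8+: replace() returns a copy with the given fields changed.
new_code = code.replace(co_name='renamed')

print(assignment_allowed)
print(code.co_name, new_code.co_name)
```

The original code object is left untouched, and every field that was not passed to replace is carried over to the copy.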
def create_func_code(original_function, modified_node):
    compiled_method = compile_node(modified_node)
    return create_replacement_code(original_function, compiled_method)
Replacing the code
The last step is to swap the function's code for the modified one. Luckily, this step is very simple:
def replace_code(original_function, code):
    original_function.func_code = code
Now, when the decorated method is executed, it will run the modified bytecode, which checks that the variables used in comparisons are not None.
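The effect of swapping a function's code can be seen in isolation with two ordinary functions. The sketch below uses the __code__ attribute, which is the Python 3 name for func_code (it is also available as an alias in Python 2.6 and later). Both functions are made up for the example.

```python
def original(x):
    return x + 1

def replacement(x):
    return x * 10

# Swap the code object. The function object itself, including its name,
# stays the same; only the code it runs changes.
original.__code__ = replacement.__code__

print(original(3))
print(original.__name__)
```

Every existing reference to the function sees the new behavior, because the function object was modified in place rather than replaced.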
The full code
from _ast import BoolOp, And, Name, Compare, IsNot, Load, Num
from ast import NodeTransformer, fix_missing_locations
import ast
import inspect
from types import FunctionType, CodeType


class ComparisonTransformer(NodeTransformer):
    def visit_Compare(self, node):
        parts = [node.left] + node.comparators
        # check if any constant number is involved in the comparison
        if not any(isinstance(part, Num) for part in parts):
            return node
        # get all the "variables" involved in the comparison
        names = [element for element in parts if isinstance(element, Name)]
        if len(names) == 0:
            return node
        # create a reference to None
        none = Name(id='None', ctx=Load())
        # create for each variable a node that represents 'var is not None'
        node_verifiers = [Compare(left=name, ops=[IsNot()], comparators=[none]) for name in names]
        # combine the None checks with the original comparison
        # e.g. 'a < b < 1' --> 'a is not None and b is not None and a < b < 1'
        return BoolOp(op=And(), values=node_verifiers + [node])


def rewrite_comparisons(original_function):
    assert isinstance(original_function, FunctionType)
    node = parse_method(original_function)
    rewrite_method(node)
    code = create_func_code(original_function, node)
    replace_code(original_function, code)
    return original_function


def replace_code(original_function, code):
    original_function.func_code = code


def parse_method(original_function):
    return ast.parse(inspect.getsource(original_function))


def rewrite_method(node):
    # assuming the method has a single decorator (which is the rewriter) - remove it
    node.body[0].decorator_list.pop()
    # we rename the method to ensure separation from the original one.
    # this step has no real meaning and is not really required.
    node.body[0].name = 'internal_method'
    # transform Compare nodes to fit the 'is not None' requirement
    ComparisonTransformer().visit(node)
    # let python try and fill code locations for the new elements
    fix_missing_locations(node)


def create_func_code(original_function, modified_node):
    compiled_method = compile_node(modified_node)
    return create_replacement_code(original_function, compiled_method)


def compile_node(modified_node):
    compiled_method = compile(modified_node, '<string>', 'exec')
    return compiled_method


def extract_code(compiled_method):
    # executing the module code defines 'internal_method' in the local scope
    exec compiled_method
    generated_func = locals()['internal_method']
    return generated_func.func_code


def create_replacement_code(original_function, compiled_function):
    # these are the code object field names, ordered by the expected order of code object constructor
    code_arg_names = ['co_argcount', 'co_nlocals', 'co_stacksize',
                      'co_flags', 'co_code', 'co_consts', 'co_names',
                      'co_varnames', 'co_filename', 'co_name',
                      'co_firstlineno', 'co_lnotab']
    # fill all the args based on the original code
    code_args = [getattr(original_function.func_code, key, None)
                 for key in code_arg_names]
    # get the code object of the modified function
    code = extract_code(compiled_function)
    # replace bytecode and stacksize with the compiled method values
    code_args[code_arg_names.index('co_code')] = code.co_code
    code_args[code_arg_names.index('co_stacksize')] = code.co_stacksize
    return CodeType(*code_args)


@rewrite_comparisons
def foo(x):
    return x < 1


print foo(None)
This code outputs the following, as expected:
False
To show that this code also works in other contexts, the following code:
@rewrite_comparisons
def bar(x, y, z, w):
    if x < 1 or y < 2 or z < 3:
        return 'Default behavior'
    if w < 1:
        return 'Expected behavior'
    return 'Failure'


print bar(None, None, None, 0)
Prints the following result:
Expected behavior
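To look at the rewrite itself, without the bytecode replacement, the transformer can be applied to a single statement and the result turned back into source. The following is a rough sketch for Python 3.9 and later, where ast.Constant takes the place of Num and ast.unparse is available. The class name NoneGuardTransformer and the sample statement are made up for the example and are not the code from this post.

```python
import ast


class NoneGuardTransformer(ast.NodeTransformer):
    """Python 3 sketch of the comparison rewrite described in this post."""

    def visit_Compare(self, node):
        parts = [node.left] + node.comparators
        # only rewrite comparisons that involve a numeric constant
        has_number = any(
            isinstance(part, ast.Constant)
            and isinstance(part.value, (int, float))
            and not isinstance(part.value, bool)
            for part in parts
        )
        # collect the variables that take part in the comparison
        names = [part for part in parts if isinstance(part, ast.Name)]
        if not has_number or not names:
            return node
        # build one 'var is not None' check per variable
        guards = [
            ast.Compare(
                left=ast.Name(id=name.id, ctx=ast.Load()),
                ops=[ast.IsNot()],
                comparators=[ast.Constant(value=None)],
            )
            for name in names
        ]
        # chain the checks in front of the original comparison
        return ast.BoolOp(op=ast.And(), values=guards + [node])


tree = ast.parse("result = a < b < 1")
tree = ast.fix_missing_locations(NoneGuardTransformer().visit(tree))

# Turn the modified tree back into source to inspect the rewrite.
rewritten = ast.unparse(tree)
print(rewritten)
```

Printing the rewritten source is a convenient way to check a transformer before compiling its output and swapping it into a live function.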
Summary
Python is amazingly flexible, and using ~50 lines of code we can create a micro-framework that manipulates method behavior. In addition, this flexibility allows code changes at many levels, including changes at the instruction level.
The vast majority of use cases are at the method level and not at the instruction level. Yet these facilities can be useful in some cases, and we should be able to take advantage of them.