In the previous post we started modifying a function's body for AOP purposes. So far we've transformed a function into an AST node with extended code. In this post we will compile that node to bytecode and replace the original function's code with the new one.
At this point, we have these two objects:
- original_function – the decorated function, without any modification.
- modified_node – the modified node, with extended code and fixed locations.
Compiling the node into a Python module
Given the node, we can use Python's built-in compile, which accepts either a source string or an AST node. The node is compiled into a code object for a module, which we can later execute.
def compile_node(modified_node):
    compiled_method = compile(modified_node, '<string>', 'exec')
    return compiled_method
Besides the node, we send two more arguments to compile. The first is the source file name to which the compiled code will be linked. In our case, we're not going to keep the compiled method itself and will only use its bytecode, so we can pass the commonly used default value '<string>'. The second argument is the mode, 'exec', which signals the compiler that multiple statements may appear in the input code and that we don't care about the returned value. This is a simplified explanation, but sufficient for this context.
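To make the role of the mode argument concrete, here is a minimal sketch. It assumes Python 3 (unlike the 2.7 code in this post), and the names double, module_code and expr_code are made up for illustration.

```python
import ast

# Parse a small module and compile the resulting AST node in 'exec' mode.
source = "def double(x):\n    return x * 2\n"
tree = ast.parse(source)
module_code = compile(tree, "<string>", "exec")

# Executing the module-level code object defines the function in the namespace.
namespace = {}
exec(module_code, namespace)
double = namespace["double"]

# For contrast: 'eval' mode compiles a single expression whose value is returned.
expr_code = compile(ast.parse("1 + 2", mode="eval"), "<string>", "eval")
expr_value = eval(expr_code)
```

Note that 'exec' mode yields a code object for the whole module, not for the function inside it; the function only exists after that code object is executed.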
Getting replacement bytecode
A Python code object is read-only, so its bytecode cannot be set directly. Instead, we need to create a new code object that is identical to the original one except for the bytecode and the stack size:
def create_replacement_code(original_function, compiled_function):
    # these are the code object field names, ordered by the expected order of code object constructor
    code_arg_names = ['co_argcount', 'co_nlocals', 'co_stacksize', 'co_flags', 'co_code', 'co_consts', 'co_names',
                      'co_varnames', 'co_filename', 'co_name', 'co_firstlineno', 'co_lnotab']
    # fill all the args based on the original code
    code_args = [getattr(original_function.func_code, key, None) for key in code_arg_names]
    # get the code object of the modified function
    code = extract_code(compiled_function)
    # replace bytecode and stacksize with the compiled method values
    code_args[code_arg_names.index('co_code')] = code.co_code
    code_args[code_arg_names.index('co_stacksize')] = code.co_stacksize
    return CodeType(*code_args)
Instantiating a code object directly is undocumented, yet some examples are available online. Please note that this signature fits Python 2.7 and was changed in Python 3.0. We have now transformed the original function's code into a new code object that contains the modified bytecode.
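For readers on newer interpreters, the same idea can be sketched without calling the constructor. This assumes Python 3.8 or later, where code objects offer a replace method that returns a modified copy. It is not the approach used in this post: instead of patching only co_code and co_stacksize, the sketch takes the whole compiled code object and copies the original's identifying metadata onto it. The functions original and modified are hypothetical stand-ins.

```python
def original(x):
    return x < 1

def modified(x):
    return x is not None and x < 1

old_code = original.__code__
new_code = modified.__code__

# Code objects are read-only: assigning to a field fails.
try:
    old_code.co_code = new_code.co_code
    assignment_failed = False
except (AttributeError, TypeError):
    assignment_failed = True

# Build a copy of the new code that keeps the original's name and location.
replacement = new_code.replace(
    co_name=old_code.co_name,
    co_filename=old_code.co_filename,
    co_firstlineno=old_code.co_firstlineno,
)
```

Taking the whole code object keeps the bytecode consistent with its constants and names, which matters when the rewritten body introduces values the original did not have.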
The two steps, compile_node and create_replacement_code, can be wrapped in a single function:
def create_func_code(original_function, modified_node):
    compiled_method = compile_node(modified_node)
    return create_replacement_code(original_function, compiled_method)
Replacing the code
The last step is to swap the function's code for the modified one. Luckily, this step is very simple:
def replace_code(original_function, code):
    original_function.func_code = code
Now, when the decorated method is executed, it runs the modified bytecode, which first checks that the variables used in comparisons are not None.
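A small sketch of what this swap means in practice, assuming Python 3, where the attribute is spelled __code__ (Python 2.6+ accepts that spelling as well as func_code). The functions compare and guarded are hypothetical; guarded stands in for the compiled, rewritten version.

```python
def compare(x):
    return x < 1

def guarded(x):
    return x is not None and x < 1

# Another reference to the same function object, taken before the swap.
alias = compare

# Swap the code in place; the function object itself stays the same.
compare.__code__ = guarded.__code__
```

Because only the code attribute changes, every existing reference to the function (such as alias here) picks up the new behavior, and the function keeps its own name.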
The full code
from _ast import BoolOp, And, Name, Compare, IsNot, Load, Num
from ast import NodeTransformer, fix_missing_locations
import ast
import inspect
from types import FunctionType, CodeType


class ComparisonTransformer(NodeTransformer):
    def visit_Compare(self, node):
        parts = [node.left] + node.comparators
        # check if any constant number is involved in the comparison
        if not any(isinstance(part, Num) for part in parts):
            return node
        # get all the "variables" involved in the comparison
        names = [element for element in parts if isinstance(element, Name)]
        if len(names) == 0:
            return node
        # create a reference to None
        none = Name(id='None', ctx=Load())
        # create for each variable a node that represents 'var is not None'
        node_verifiers = [Compare(left=name, ops=[IsNot()], comparators=[none]) for name in names]
        # combine the None checks with the original comparison
        # e.g. 'a < b < 1' --> 'a is not None and b is not None and a < b < 1'
        return BoolOp(op=And(), values=node_verifiers + [node])


def rewrite_comparisons(original_function):
    assert isinstance(original_function, FunctionType)
    node = parse_method(original_function)
    rewrite_method(node)
    code = create_func_code(original_function, node)
    replace_code(original_function, code)
    return original_function


def replace_code(original_function, code):
    original_function.func_code = code


def parse_method(original_function):
    return ast.parse(inspect.getsource(original_function))


def rewrite_method(node):
    # assuming the method has single decorator (which is the rewriter) - remove it
    node.body[0].decorator_list.pop()
    # we rename the method to ensure separation from the original one.
    # extract_code later looks the compiled function up by this name.
    node.body[0].name = 'internal_method'
    # transform Compare nodes to fit the 'is not None' requirement
    ComparisonTransformer().visit(node)
    # let python try and fill code locations for the new elements
    fix_missing_locations(node)


def create_func_code(original_function, modified_node):
    compiled_method = compile_node(modified_node)
    return create_replacement_code(original_function, compiled_method)


def compile_node(modified_node):
    compiled_method = compile(modified_node, '<string>', 'exec')
    return compiled_method


def extract_code(compiled_method):
    exec compiled_method
    generated_func = locals()['internal_method']
    return generated_func.func_code


def create_replacement_code(original_function, compiled_function):
    # these are the code object field names, ordered by the expected order of code object constructor
    code_arg_names = ['co_argcount', 'co_nlocals', 'co_stacksize', 'co_flags', 'co_code', 'co_consts', 'co_names',
                      'co_varnames', 'co_filename', 'co_name', 'co_firstlineno', 'co_lnotab']
    # fill all the args based on the original code
    code_args = [getattr(original_function.func_code, key, None) for key in code_arg_names]
    # get the code object of the modified function
    code = extract_code(compiled_function)
    # replace bytecode and stacksize with the compiled method values
    code_args[code_arg_names.index('co_code')] = code.co_code
    code_args[code_arg_names.index('co_stacksize')] = code.co_stacksize
    return CodeType(*code_args)


@rewrite_comparisons
def foo(x):
    return x < 1


print foo(None)
This code outputs the following, as expected:
False
To show that this code works in other contexts as well, take the following code:
@rewrite_comparisons
def bar(x, y, z, w):
    if x < 1 or y < 2 or z < 3:
        return 'Default behavior'
    if w < 1:
        return 'Expected behavior'
    return 'Failure'


print bar(None, None, None, 0)
It prints the following result:
Expected behavior
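The code above targets Python 2.7. As a rough sketch of how the transformation itself could look on Python 3.8 or later, the block below rewrites a source string rather than a decorated function, and rebuilds the function instead of patching its bytecode. It assumes ast.Constant in place of Num and of the Name node for None; SOURCE and the build helper are hypothetical names.

```python
import ast


class ComparisonTransformer(ast.NodeTransformer):
    """Prefix numeric comparisons with 'is not None' checks on their variables."""

    def visit_Compare(self, node):
        parts = [node.left] + node.comparators
        # only rewrite comparisons that involve a numeric constant
        has_number = any(
            isinstance(part, ast.Constant)
            and isinstance(part.value, (int, float))
            and not isinstance(part.value, bool)
            for part in parts
        )
        names = [part for part in parts if isinstance(part, ast.Name)]
        if not has_number or not names:
            return node
        # one 'var is not None' node per variable
        checks = [
            ast.Compare(
                left=ast.Name(id=name.id, ctx=ast.Load()),
                ops=[ast.IsNot()],
                comparators=[ast.Constant(value=None)],
            )
            for name in names
        ]
        return ast.BoolOp(op=ast.And(), values=checks + [node])


SOURCE = "def foo(x):\n    return x < 1\n"


def build(source, transform):
    """Compile the source, optionally rewriting comparisons, and return foo."""
    tree = ast.parse(source)
    if transform:
        tree = ComparisonTransformer().visit(tree)
        ast.fix_missing_locations(tree)
    namespace = {}
    exec(compile(tree, "<string>", "exec"), namespace)
    return namespace["foo"]


plain_foo = build(SOURCE, transform=False)
safe_foo = build(SOURCE, transform=True)
```

On Python 3 the untransformed comparison of None with a number raises TypeError rather than returning a value, so the guard changes an exception into a plain False.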
Summary
Python is amazingly flexible: with about 50 lines of code we can create a micro-framework that manipulates the behavior of methods. This flexibility also allows code changes at many levels, down to the level of individual instructions.
The vast majority of use cases are at the method level rather than the instruction level. Yet these facilities can be useful in some cases, and we should be able to take advantage of them.