Python low-level AOP using AST rewriting – part II

In the previous post we started modifying function bodies for AOP purposes. So far we’ve transformed a function into an AST node with extended code. In this post we will compile that node to bytecode and replace the original function’s code with the new one.

At this point, we have two objects:

  1. original_function – the decorated function, without any modification.
  2. modified_node – the modified node, with extended code and fixed locations.

Compiling the node into a Python module

Given the node, we can use Python’s built-in compile function, which accepts either a source string or an AST object. Compiling the node produces a code object for the whole module, which we can later execute.

def compile_node(modified_node):
    compiled_method = compile(modified_node, '<string>', 'exec')
    return compiled_method

Besides the node itself, we pass compile two more arguments. The first is the name of the source file the compiled code will be linked to. In our case we are not going to keep the compiled module, only its bytecode, so we can use the common placeholder ‘<string>’. The second is the mode ‘exec’, which tells the compiler that the input may contain multiple statements and that we don’t care about a returned value (unlike ‘eval’, which compiles a single expression and returns its value). This is a simplified explanation, but sufficient for this context.
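To see what compile gives back in each mode, here is a small sketch in Python 3 (the rest of this post targets Python 2.7, but compile behaves the same way for this purpose). The source string and the internal_method name are illustrative stand-ins for the node produced in part I.

```python
import ast
import types

# Illustrative source: roughly what the rewritten function looks like
source = (
    "def internal_method(x):\n"
    "    return x is not None and x < 1\n"
)
tree = ast.parse(source)

# 'exec' mode: the input is a module that may contain any number of statements
module_code = compile(tree, '<string>', 'exec')
print(isinstance(module_code, types.CodeType))  # True

# Executing the module code defines internal_method but yields no value
namespace = {}
print(exec(module_code, namespace))  # None
print(namespace['internal_method'](None))  # False

# 'eval' mode: a single expression whose value we get back
expr_code = compile('2 * 21', '<string>', 'eval')
print(eval(expr_code))  # 42
```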

Getting replacement bytecode

Python code objects are immutable, so we cannot set the bytecode directly. Instead, we create a new code object that is identical to the original one except for the bytecode, the stack size, and the constant and name tables that the new bytecode refers to by index:

from types import CodeType

def create_replacement_code(original_function, compiled_function):
    # code object field names, in the order the Python 2.7 CodeType constructor expects them
    code_arg_names = ['co_argcount', 'co_nlocals', 'co_stacksize',
                      'co_flags', 'co_code', 'co_consts', 'co_names',
                      'co_varnames', 'co_filename', 'co_name',
                      'co_firstlineno', 'co_lnotab']
    # fill all the args based on the original code
    code_args = [getattr(original_function.func_code, key, None)
                 for key in code_arg_names]
    # get the code object of the modified function from the compiled module
    code = extract_code(compiled_function)
    # take the bytecode, the stack size, and the constant and name tables the
    # bytecode indexes into from the modified code; keep the rest of the original
    for key in ('co_code', 'co_stacksize', 'co_consts', 'co_names'):
        code_args[code_arg_names.index(key)] = getattr(code, key)
    return CodeType(*code_args)

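The code object constructor is not documented, although some examples are available online. Note that this signature matches Python 2.7: Python 3 added a co_kwonlyargcount argument, and Python 3.8 added co_posonlyargcount. The extract_code helper (listed in the full code below) executes the compiled module and returns the code object of the generated internal_method function. At this point we have turned the original function into a new code object that contains the modified bytecode.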
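In Python 3.8 and later there is no need to call the undocumented constructor at all: code objects have a replace() method that returns a copy with selected fields changed. The sketch below assumes Python 3.8+ and uses two hand-written functions, original and modified (both illustrative), in place of the parsed and compiled ones. It also shows that code object fields really are read-only. Note the design choice: rather than patching the original’s bytecode, it starts from the modified code object and carries over the original’s name, file name and first line number, so the bytecode always stays consistent with the tables it indexes into.

```python
import types


def original(x):
    return x < 1


def modified(x):
    # hand-written stand-in for the rewritten function
    return x is not None and x < 1


# Code objects are immutable: assigning to one of their fields fails
try:
    original.__code__.co_code = modified.__code__.co_code
except AttributeError as exc:
    print('read-only:', exc)


def create_replacement_code(original_code, modified_code):
    # Keep the identity of the original (name, file, first line) and take
    # everything the new bytecode depends on from the modified code object.
    return modified_code.replace(
        co_name=original_code.co_name,
        co_filename=original_code.co_filename,
        co_firstlineno=original_code.co_firstlineno,
    )


new_code = create_replacement_code(original.__code__, modified.__code__)
print(new_code.co_name)  # original

# Wrap the new code in a function object to check its behavior
patched = types.FunctionType(new_code, globals(), 'original')
print(patched(None))  # False
```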

The last two function calls can be combined into a single helper:

def create_func_code(original_function, modified_node):
    compiled_method = compile_node(modified_node)
    return create_replacement_code(original_function, compiled_method)
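create_replacement_code relies on extract_code (shown in the full code below), which executes the compiled module just to pull out the generated function. An alternative that avoids executing anything is to read the function’s code object straight out of the module code object’s constants, where the compiler stores the compiled body of every function defined at module level. The sketch below is Python 3, and the helper name find_function_code is hypothetical.

```python
import ast
import types


def find_function_code(module_code, name):
    # Function bodies are compiled into separate code objects that are
    # stored among the constants of the enclosing (module) code object.
    for const in module_code.co_consts:
        if isinstance(const, types.CodeType) and const.co_name == name:
            return const
    raise LookupError('no function named %r in module code' % name)


tree = ast.parse("def internal_method(x):\n    return x is not None and x < 1\n")
module_code = compile(tree, '<string>', 'exec')
func_code = find_function_code(module_code, 'internal_method')
print(func_code.co_name)  # internal_method
print(func_code.co_varnames)  # ('x',)
```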

Replacing the code

The last step is to swap the function’s code for the modified one. Unlike code objects, functions let us reassign their code, so this step is very simple:

def replace_code(original_function, code):
    original_function.func_code = code
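In Python 3 the attribute is called __code__ (Python 2.6 and later also accept __code__ as an alias of func_code). The assignment is checked: the new code must have as many free variables as the function has closure cells, otherwise CPython raises ValueError. A small illustrative sketch in Python 3, with made-up function names:

```python
def greet():
    return 'original'


def patched():
    return 'patched'


# Swapping code between functions without closures is allowed
greet.__code__ = patched.__code__
print(greet())  # patched
print(greet.__name__)  # greet (the name lives on the function, not the code)


def make_closure():
    y = 1

    def inner():
        return y  # y is a free variable of inner

    return inner


# inner's code expects one closure cell but greet has none, so this is rejected
try:
    greet.__code__ = make_closure().__code__
except ValueError as exc:
    print('rejected:', exc)
```

This is also why the approach in this post only works as-is for functions that are not closures.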

Now, when the decorated function is called, it runs the modified bytecode, which checks that the variables used in numeric comparisons are not None before comparing them.
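The check matters because comparing None with a number behaves differently across versions: in Python 2.7 None sorts before any number, so None < 1 is True, while Python 3 raises TypeError. The following Python 3 sketch contrasts a plain comparison with a hand-written equivalent of what the rewriter produces for x < 1 (the function names are illustrative):

```python
def plain(x):
    return x < 1


def guarded(x):
    # hand-written equivalent of the rewritten 'x < 1'
    return x is not None and x < 1


try:
    plain(None)
except TypeError as exc:
    print('plain comparison failed:', exc)

print(guarded(None))  # False
print(guarded(0))  # True
print(guarded(5))  # False
```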

The full code

from ast import (And, BoolOp, Compare, IsNot, Load, Name, NodeTransformer,
                 Num, fix_missing_locations)
import ast
import inspect
from types import FunctionType, CodeType


class ComparisonTransformer(NodeTransformer):
    def visit_Compare(self, node):
        parts = [node.left] + node.comparators

        # check if any constant number is involved in the comparison
        if not any(isinstance(part, Num) for part in parts):
            return node

        # get all the "variables" involved in the comparison
        names = [element for element in parts if isinstance(element, Name)]
        if len(names) == 0:
            return node

        # create a reference to None
        none = Name(id='None', ctx=Load())
        # create for each variable a node that represents 'var is not None'
        node_verifiers = [Compare(left=name, ops=[IsNot()], comparators=[none]) for name in names]
        # combine the None checks with the original comparison
        # e.g. 'a < b < 1' --> 'a is not None and b is not None and a < b < 1'
        return BoolOp(op=And(), values=node_verifiers + [node])


def rewrite_comparisons(original_function):
    assert isinstance(original_function, FunctionType)

    node = parse_method(original_function)
    rewrite_method(node)
    code = create_func_code(original_function, node)
    replace_code(original_function, code)
    return original_function


def replace_code(original_function, code):
    original_function.func_code = code


def parse_method(original_function):
    return ast.parse(inspect.getsource(original_function))


def rewrite_method(node):
    # assuming the method has single decorator (which is the rewriter) - remove it
    node.body[0].decorator_list.pop()
    # rename the method to a fixed name, so that extract_code can look it up
    # after executing the compiled module
    node.body[0].name = 'internal_method'
    # transform Compare nodes to fit the 'is not None' requirement
    ComparisonTransformer().visit(node)
    # let python try and fill code locations for the new elements
    fix_missing_locations(node)


def create_func_code(original_function, modified_node):
    compiled_method = compile_node(modified_node)
    return create_replacement_code(original_function, compiled_method)


def compile_node(modified_node):
    compiled_method = compile(modified_node, '<string>', 'exec')
    return compiled_method


def extract_code(compiled_method):
    # execute the compiled module in a fresh namespace and pick up the function it defines
    namespace = {}
    exec compiled_method in namespace
    return namespace['internal_method'].func_code


def create_replacement_code(original_function, compiled_function):
    # code object field names, in the order the Python 2.7 CodeType constructor expects them
    code_arg_names = ['co_argcount', 'co_nlocals', 'co_stacksize',
                      'co_flags', 'co_code', 'co_consts', 'co_names',
                      'co_varnames', 'co_filename', 'co_name',
                      'co_firstlineno', 'co_lnotab']
    # fill all the args based on the original code
    code_args = [getattr(original_function.func_code, key, None)
                 for key in code_arg_names]
    # get the code object of the modified function from the compiled module
    code = extract_code(compiled_function)
    # take the bytecode, the stack size, and the constant and name tables the
    # bytecode indexes into from the modified code; keep the rest of the original
    for key in ('co_code', 'co_stacksize', 'co_consts', 'co_names'):
        code_args[code_arg_names.index(key)] = getattr(code, key)
    return CodeType(*code_args)


@rewrite_comparisons
def foo(x):
    return x < 1

print foo(None)

This code outputs the following, as expected:

False

To show that the rewrite works in other contexts too, the following code:

@rewrite_comparisons
def bar(x, y, z, w):
    if x < 1 or y < 2 or z < 3:
        return 'Default behavior'
    if w < 1:
        return 'Expected behavior'
    return 'Failure'


print bar(None, None, None, 0)

Prints the following result:

Expected behavior
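The listing above is Python 2.7 code (print and exec statements, func_code, and the 2.7 CodeType signature). Below is a rough Python 3.8+ port, offered as a sketch rather than a drop-in replacement. It uses ast.Constant instead of ast.Num, keeps the function’s own name instead of renaming it, finds the function’s code object among the compiled module’s constants instead of executing the module, and assigns the compiled code object as a whole through __code__. Like the original, it assumes the rewriter is the only decorator, that the function is not a closure, and that its source is available to inspect.

```python
import ast
import inspect
import textwrap
import types


def _is_number(node):
    # numeric literal, excluding True/False (bool is a subclass of int)
    return (isinstance(node, ast.Constant)
            and isinstance(node.value, (int, float))
            and not isinstance(node.value, bool))


class ComparisonTransformer(ast.NodeTransformer):
    def visit_Compare(self, node):
        parts = [node.left] + node.comparators
        # only rewrite comparisons that involve a numeric constant
        if not any(_is_number(part) for part in parts):
            return node
        names = [part for part in parts if isinstance(part, ast.Name)]
        if not names:
            return node
        # build 'name is not None' for every variable in the comparison
        checks = [ast.Compare(left=ast.Name(id=name.id, ctx=ast.Load()),
                              ops=[ast.IsNot()],
                              comparators=[ast.Constant(value=None)])
                  for name in names]
        # e.g. 'a < b < 1' --> 'a is not None and b is not None and a < b < 1'
        return ast.BoolOp(op=ast.And(), values=checks + [node])


def rewrite_comparisons(func):
    # parse the function's own source (dedented in case it is a method)
    tree = ast.parse(textwrap.dedent(inspect.getsource(func)))
    # shift line numbers so tracebacks point at the real source lines
    ast.increment_lineno(tree, func.__code__.co_firstlineno - 1)
    func_def = tree.body[0]
    func_def.decorator_list = []  # assumes this is the only decorator
    ComparisonTransformer().visit(tree)
    ast.fix_missing_locations(tree)
    module_code = compile(tree, func.__code__.co_filename, 'exec')
    # the function body is compiled into a code object stored in the module's constants
    new_code = next(const for const in module_code.co_consts
                    if isinstance(const, types.CodeType)
                    and const.co_name == func_def.name)
    func.__code__ = new_code
    return func


@rewrite_comparisons
def foo(x):
    return x < 1


@rewrite_comparisons
def bar(x, y, z, w):
    if x < 1 or y < 2 or z < 3:
        return 'Default behavior'
    if w < 1:
        return 'Expected behavior'
    return 'Failure'


print(foo(None))  # False
print(bar(None, None, None, 0))  # Expected behavior
```

Without the decorator, both calls would raise TypeError in Python 3, since None can no longer be ordered against numbers.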

Summary

Python is amazingly flexible: with about 50 lines of code we created a micro-framework that manipulates a method’s behavior. This flexibility also allows changes at many levels, down to individual bytecode instructions.

The vast majority of use cases work at the method level rather than the instruction level. Still, these facilities can be useful in some cases, and we should be able to take advantage of them.
