Python low-level AOP using AST rewriting – part II

In the previous post we started modifying function bodies for AOP purposes. So far we’ve transformed a function into an AST node with extended code. In this post we will compile that node to bytecode and replace the original function’s code with the new one.

At this point, we have these two:

  1. original_function – the decorated function, without any modification.
  2. modified_node – the modified node, with extended code and fixed locations.

Compiling the node into a python module

Given the node, we can use Python’s built-in compile, which takes either a string or an AST node. The node is compiled into a module-level code object which we can later execute.

def compile_node(modified_node):
    compiled_method = compile(modified_node, '<string>', 'exec')
    return compiled_method

The two extra arguments we send to compile are the file name and the mode. The first is the source file name to which the compiled code will be linked. In our case, we’re not going to keep the compiled method, and will only use its bytecode. Therefore, we can use the common default value ‘<string>’. The second argument is ‘exec’, which signals the compiler that multiple statements may appear in the input code and that we don’t care about the returned value. This is a simplified explanation, but sufficient for this context.
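As a minimal sketch of this step, the snippet below parses a small source string, compiles the resulting node and executes the module to obtain the function. It is written for Python 3 (where exec is a function rather than a statement); the source string and the namespace dictionary are illustrative assumptions, not part of the original code:

```python
import ast

# illustrative source, standing in for the already rewritten function
source = "def internal_method(x):\n    return x is not None and x < 1\n"

node = ast.parse(source)
# 'exec' mode compiles the node as a module made of statements
module_code = compile(node, '<string>', 'exec')

# running the module executes the 'def' statement and defines the function
namespace = {}
exec(module_code, namespace)
internal_method = namespace['internal_method']

print(internal_method(None))  # False
print(internal_method(0))     # True
```

Executing into an explicit dictionary keeps the generated function out of the caller’s own namespace.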

Getting replacement bytecode

A Python code object does not allow setting its bytecode. Therefore we need to create a new code object, which is identical to the original one except for the bytecode and stack size:

def create_replacement_code(original_function, compiled_function):
    # these are the code object field names, ordered by the expected order of code object constructor
    code_arg_names = ['co_argcount', 'co_nlocals', 'co_stacksize',
                      'co_flags', 'co_code', 'co_consts', 'co_names',
                      'co_varnames', 'co_filename', 'co_name',
                      'co_firstlineno', 'co_lnotab']
    # fill all the args based on the original code
    code_args = [getattr(original_function.func_code, key, None)
                 for key in code_arg_names]
    # replace the bytecode and stacksize with the compiled bytecode and stacksize from the modified function
    code = extract_code(compiled_function)
    # replace bytecode and stacksize with the compiled method values
    code_args[code_arg_names.index('co_code')] = code.co_code
    code_args[code_arg_names.index('co_stacksize')] = code.co_stacksize
    return CodeType(*code_args)

The instantiation of the code object is undocumented, yet some examples are available online. Please note that this signature fits Python 2.7 and was changed in 3.0. The extract_code helper, shown in the full code below, executes the compiled module and returns the code object of the generated function. We have now finished transforming Python code (the original function) into a new code object containing the modified bytecode.
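For readers on Python 3.8 or later, code objects provide a replace method that returns a copy with selected fields changed, so the constructor does not have to be called directly. The sketch below is an assumption about how this step might be adapted, not the original 2.7 approach. It works in the opposite direction: it takes the newly compiled code object and copies the identifying fields from the original, which keeps the new bytecode together with its own constants and names. Both functions are illustrative stand-ins:

```python
def original(x):
    return x < 1


# stand-in for the function produced by compiling the modified node
def internal_method(x):
    return x is not None and x < 1


old_code = original.__code__
new_code = internal_method.__code__

# copy the identifying fields of the original onto the new code object
replacement = new_code.replace(co_name=old_code.co_name,
                               co_filename=old_code.co_filename)

print(replacement.co_name)  # original
```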

The last two method calls can be combined into:

def create_func_code(original_function, modified_node):
    compiled_method = compile_node(modified_node)
    return create_replacement_code(original_function, compiled_method)

Replacing the code

The last step is to switch the function’s code to the modified one. Luckily, this step is very simple:

def replace_code(original_function, code):
    original_function.func_code = code

Now, when the decorated function is executed, it runs the modified bytecode, which checks that variables used in comparisons are not None.
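In Python 3 the same attribute is called __code__ instead of func_code. A minimal sketch, assuming a hand-written internal_method stands in for the compiled result:

```python
def foo(x):
    return x < 1


# stand-in for the function generated from the modified node
def internal_method(x):
    return x is not None and x < 1


def replace_code(original_function, code):
    # Python 3 name of the 2.7 'func_code' attribute
    original_function.__code__ = code


alias = foo  # an existing reference to the same function object
replace_code(foo, internal_method.__code__)

print(foo(None))    # False
print(alias(None))  # False
```

Because the function object itself is mutated, every existing reference to it (such as alias above) picks up the new behavior. Note that the assignment is rejected with a ValueError when the two code objects have a different number of free variables, so closures need extra care.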

The full code

from _ast import BoolOp, And, Name, Compare, IsNot, Load, Num
from ast import NodeTransformer, fix_missing_locations
import ast
import inspect
from types import FunctionType, CodeType


class ComparisonTransformer(NodeTransformer):
    def visit_Compare(self, node):
        parts = [node.left] + node.comparators

        # check if any constant number is involved in the comparison
        if not any(isinstance(part, Num) for part in parts):
            return node

        # get all the "variables" involved in the comparison
        names = [element for element in parts if isinstance(element, Name)]
        if len(names) == 0:
            return node

        # create a reference to None
        none = Name(id='None', ctx=Load())
        # create for each variable a node that represents 'var is not None'
        node_verifiers = [Compare(left=name, ops=[IsNot()], comparators=[none]) for name in names]
        # combine the None checks with the original comparison
        # e.g. 'a < b < 1' --> 'a is not None and b is not None and a < b < 1'
        return BoolOp(op=And(), values=node_verifiers + [node])


def rewrite_comparisons(original_function):
    assert isinstance(original_function, FunctionType)

    node = parse_method(original_function)
    rewrite_method(node)
    code = create_func_code(original_function, node)
    replace_code(original_function, code)
    return original_function


def replace_code(original_function, code):
    original_function.func_code = code


def parse_method(original_function):
    return ast.parse(inspect.getsource(original_function))


def rewrite_method(node):
    # assuming the method has single decorator (which is the rewriter) - remove it
    node.body[0].decorator_list.pop()
    # we rename the method to ensure separation from the original one.
    # this step has no real meaning and not really required.
    node.body[0].name = 'internal_method'
    # transform Compare nodes to fit the 'is not None' requirement
    ComparisonTransformer().visit(node)
    # let python try and fill code locations for the new elements
    fix_missing_locations(node)


def create_func_code(original_function, modified_node):
    compiled_method = compile_node(modified_node)
    return create_replacement_code(original_function, compiled_method)


def compile_node(modified_node):
    compiled_method = compile(modified_node, '<string>', 'exec')
    return compiled_method


def extract_code(compiled_method):
    exec compiled_method
    generated_func = locals()['internal_method']
    return generated_func.func_code


def create_replacement_code(original_function, compiled_function):
    # these are the code object field names, ordered by the expected order of code object constructor
    code_arg_names = ['co_argcount', 'co_nlocals', 'co_stacksize',
                      'co_flags', 'co_code', 'co_consts', 'co_names',
                      'co_varnames', 'co_filename', 'co_name',
                      'co_firstlineno', 'co_lnotab']
    # fill all the args based on the original code
    code_args = [getattr(original_function.func_code, key, None)
                 for key in code_arg_names]
    # replace the bytecode and stacksize with the compiled bytecode and stacksize from the modified function
    code = extract_code(compiled_function)
    # replace bytecode and stacksize with the compiled method values
    code_args[code_arg_names.index('co_code')] = code.co_code
    code_args[code_arg_names.index('co_stacksize')] = code.co_stacksize
    return CodeType(*code_args)


@rewrite_comparisons
def foo(x):
    return x < 1

print foo(None)

This code will output the following, as expected:

False

To show that this code works in other contexts, the following code:

@rewrite_comparisons
def bar(x, y, z, w):
    if x < 1 or y < 2 or z < 3:
        return 'Default behavior'
    if w < 1:
        return 'Expected behavior'
    return 'Failure'


print bar(None, None, None, 0)

Prints the following result:

Expected behavior

Summary

Python is amazingly flexible, and using ~50 lines of code we can create a micro-framework to manipulate method behavior. In addition, this flexibility allows code changes at many levels, including instruction-level changes.

The vast majority of use cases are at the method level and not the instruction level. Yet, these facilities can be useful in some cases and we should be able to take advantage of them.

Python low-level AOP using AST rewriting – part I

This post and the next one will address AOP in Python. In general, AOP in Python is very simple thanks to Python’s decorators. The aspects we would like to apply in this post are low-level, meaning they’ll be applied to in-body instructions and not just at the method level. We’re going to implement this using code weaving and rewriting.
I previously blogged about a similar concept in .NET using Mono Cecil, where we tracked IL instructions.

The topic will be covered by two posts, where the first one will address rewriting code and the second one will deal with replacing the original code.

Background

Motivation

The general motivation for AOP is to separate the business logic from other functional logic, like logging, security or error handling. Most of the common examples fit the pattern of wrapping the function with a new one, and then performing logic before/after the method is executed. This is very useful, yet it limits our ability to change the behavior of specific instructions inside the method which are relevant to the aspect.

Example

Throughout the post we will use a simple concrete example. Consider the following (Python 2.7):

def foo(x):
    return x < 1

print foo(None)

As you probably know, this will print:

True

This is common Python (2.7) behavior but might not be intuitive. In general, assuming we had many variables and many comparisons, we’d like to change them all to the pattern: VAR is not None and VAR < CONST

The goal of our process will be to transform the method to:

def foo(x):
    return x is not None and x < 1

Where the aspect we’re applying is Update Comparison of None and Constants.
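For comparison, Python 3 no longer orders None against numbers, so the unguarded version raises a TypeError instead of returning True, while the guarded pattern behaves the same way in both versions. A small sketch (the function names are illustrative):

```python
def unguarded(x):
    return x < 1


def guarded(x):
    return x is not None and x < 1


try:
    unguarded(None)
    raised = False
except TypeError:
    # Python 3 refuses to compare None with an int
    raised = True

print(raised)         # True on Python 3
print(guarded(None))  # False
print(guarded(0))     # True
```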

The required steps

The steps required by this solution are the following:

  1. Decorate the method – create an entry point for the mechanism which’ll apply the aspect.
  2. Create an AST from the method –  prepare a modifiable syntax tree from the original method.
  3. Rewrite the AST – find the instructions influenced by the aspect and modify them.
  4. Create bytecode – create identical code to the original one other than newly generated bytecode.
  5. Update the method – replace the original method code with the new one.

Decorating the method

Like the common approach, we will use a decorator to modify the function. We will start from this simple decorator and build over it:

def rewrite_comparisons(original_function):
    assert isinstance(original_function, FunctionType)
    return original_function

@rewrite_comparisons
def foo(x):
    return x < 1

This decorator does nothing, so far.

Getting code from function

The first challenge is getting the code of a function in a modifiable form. Since Python keeps only bytecode for a function by default, we will use the built-in inspect module to extract the original source code:

function_source_code = inspect.getsource(original_function)

Inspect uses the code locations linked to the function and reads the corresponding lines from the source file. The return value is a string containing the function’s source. This is different from disassembling the method’s bytecode.

We can assume that for our functions the source code is available. Otherwise, this first step will fail, and the processing will need to be done at the bytecode level (which might be covered in another post). In addition, this constrains us to ensure our decorator is applied before any other decorator. Otherwise, previously applied decorators might be ignored, since their effect is not reflected in the original source code.
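Because inspect reads from the source file, the function has to come from a real file on disk. The sketch below illustrates this under an assumed setup: it writes a throwaway module (the sample_module name and its contents are hypothetical), imports it, and asks inspect for the function’s source. It targets Python 3:

```python
import importlib.util
import inspect
import os
import tempfile

MODULE_SOURCE = "def foo(x):\n    return x < 1\n"

with tempfile.TemporaryDirectory() as tmp:
    # write a throwaway module so the function is backed by a real file
    path = os.path.join(tmp, 'sample_module.py')
    with open(path, 'w') as handle:
        handle.write(MODULE_SOURCE)

    # import the module from its file path
    spec = importlib.util.spec_from_file_location('sample_module', path)
    module = importlib.util.module_from_spec(spec)
    spec.loader.exec_module(module)

    # inspect locates the file and line range through the code object
    function_source_code = inspect.getsource(module.foo)

print(function_source_code)
```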

Building an AST (abstract syntax tree)

After the previous line of code has extracted the source, we can parse it into an AST. The motivation for building an abstract syntax tree is that it’s modifiable and we can compile it back to bytecode.

function_source_code = inspect.getsource(original_function)
node = ast.parse(function_source_code)

The node we get is the root one of the parsed code. It links to all the elements in the hierarchy and represents a simplified module code.

Taking for example the foo function, the tree is:

Module
  # the method declaration (foo)
  FunctionDef
    # the arguments list (x)
    arguments
      Name
        Param
    # return instruction
    Return
      # comparison of two elements
      Compare
        # load variable (x)
        Name
          Load
        # comparison operator (<)
        Lt
        # load constant (1)
        Num

The AST represents the function, while the decorator is omitted for simplicity. As can easily be seen, the tree represents all the content of the method, including its declaration, other methods in the context (if there are any) and more. Given the AST, we’d like to modify it to fit the needs of our aspect.
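To inspect such a tree yourself, the standard ast module can list the node types and dump individual nodes. The sketch below targets Python 3, where some node names differ from the 2.7 tree above (for example, arguments are arg nodes and numbers are Constant nodes); the source string is illustrative:

```python
import ast

source = "def foo(x):\n    return x < 1\n"
tree = ast.parse(source)

# collect the type name of every node reachable from the module root
node_types = [type(node).__name__ for node in ast.walk(tree)]
print(node_types)

# dump only the return statement for a closer look
return_statement = tree.body[0].body[0]
print(ast.dump(return_statement))
```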

Transforming the AST

AST visitors

We will use the AST visitors as an introduction to syntax tree traversal. The node visitors follow a convention where callback names follow the pattern visit_NODETYPE(self, node), where NODETYPE can be any of the AST node types. For example, if we want a callback on method calls, we can define one for the Call node and name it visit_Call(self, node).

In our example, we can visit the compare nodes, and print all the operands:

from ast import NodeVisitor


class ComparisonVisitor(NodeVisitor):
    def visit_Compare(self, node):
        operands = [node.left] + node.comparators
        print '; '.join(type(operand).__name__ for operand in operands)

In every callback, we are assured that the node is of the Compare node type. Given the type, we can investigate its members. A comparison in Python is composed of operators (one or more) and operands (two or more). In the case of the Compare node, the first operand is called left, and the rest are called comparators. One of the reasons for this complicated structure is to support expressions like:

0 < x < 100

Using the visitor we can query the nodes, but not modify them. If we visit the original foo function:

node = ast.parse(inspect.getsource(foo))
ComparisonVisitor().visit(node)

The result we expect is:

Name; Num

This is because the comparison is x < 1, where x is a Name loaded from the context and 1 is a constant number (Num).
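A Python 3 variant of the visitor might look like the sketch below. It collects the results in a list instead of printing them, and it reports Constant rather than Num because Python 3.8+ represents number literals as Constant nodes. The seen attribute and the sample source are assumptions made for illustration:

```python
import ast


class ComparisonVisitor(ast.NodeVisitor):
    def __init__(self):
        self.seen = []

    def visit_Compare(self, node):
        operands = [node.left] + node.comparators
        self.seen.append('; '.join(type(operand).__name__
                                   for operand in operands))
        # keep walking in case comparisons are nested inside this one
        self.generic_visit(node)


source = (
    "def foo(x):\n"
    "    return x < 1\n"
    "\n"
    "def in_range(x):\n"
    "    return 0 < x < 100\n"
)
visitor = ComparisonVisitor()
visitor.visit(ast.parse(source))
print(visitor.seen)  # ['Name; Constant', 'Constant; Name; Constant']
```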

AST transformers

Python provides transformers, which are a special type of AST visitor. Transformers, in contrast to node visitors, modify the nodes they visit. In our example, we’ll look for nodes that represent a comparison between variables and numbers, and then extend them to comply with the aspect.

from ast import NodeTransformer
from _ast import BoolOp, And, Name, Compare, IsNot, Load, Num


class ComparisonTransformer(NodeTransformer):
    def visit_Compare(self, node):
        parts = [node.left] + node.comparators

        # check if any constant number is involved in the comparison
        if not any(isinstance(part, Num) for part in parts):
            return node

        # get all the "variables" involved in the comparison
        names = [element for element in parts if isinstance(element, Name)]
        if len(names) == 0:
            return node

        # create a reference to None
        none = Name(id='None', ctx=Load())
        # create for each variable a node that represents 'var is not None'
        node_verifiers = [Compare(left=name, ops=[IsNot()], comparators=[none]) for name in names]
        # combine the None checks with the original comparison
        # e.g. 'a < b < 1' --> 'a is not None and b is not None and a < b < 1'
        return BoolOp(op=And(), values=node_verifiers + [node])

This chunk of code is a simplified version (with relaxed input type checks and no attempt to fix code locations) of a transformer that visits all nodes of type Compare. The transformer method names use the same convention as the visitors.

Based on the original comparison, a new node is built. This node is a new Boolean expression, which requires all the variables in use to be not None and the original comparison to be satisfied.

If we look at the output, the AST is modified to verify that variables are not None before they’re compared to numbers. The output tree for the modified foo is:

Module
  # the method declaration (foo)
  FunctionDef
    # the arguments list (x)
    arguments
      Name
        Param
    # return instruction
    Return
      # the bool expression that combines with And:
      # 1. the original comparison
      # 2. the new check 'VAR is not None'
      BoolOp
        And
        # the 'x is not None' comparison
        Compare
          Name
            Load
          IsNot
          Name
            Load
        # the original comparison 'x < 1'
        Compare
          Name
            Load
          Lt
          Num
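The same transformation can be sketched for Python 3, where number literals and None are both Constant nodes. This is an adaptation under that assumption, not the original 2.7 code; the is_number helper and the sample function are illustrative. The snippet compiles and runs the rewritten tree to confirm the new behavior:

```python
import ast


def is_number(part):
    # Python 3.8+ parses numeric literals as Constant nodes
    return (isinstance(part, ast.Constant)
            and isinstance(part.value, (int, float))
            and not isinstance(part.value, bool))


class ComparisonTransformer(ast.NodeTransformer):
    def visit_Compare(self, node):
        parts = [node.left] + node.comparators
        if not any(is_number(part) for part in parts):
            return node
        names = [part for part in parts if isinstance(part, ast.Name)]
        if not names:
            return node
        # build one 'var is not None' check per variable
        checks = [ast.Compare(left=ast.Name(id=name.id, ctx=ast.Load()),
                              ops=[ast.IsNot()],
                              comparators=[ast.Constant(value=None)])
                  for name in names]
        # 'a < b < 1' becomes 'a is not None and b is not None and a < b < 1'
        return ast.BoolOp(op=ast.And(), values=checks + [node])


tree = ast.parse("def foo(a, b):\n    return a < b < 1\n")
tree = ComparisonTransformer().visit(tree)
ast.fix_missing_locations(tree)

namespace = {}
exec(compile(tree, '<string>', 'exec'), namespace)
foo = namespace['foo']

print(foo(None, 0))  # False
print(foo(0, None))  # False
print(foo(-2, -1))   # True
```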

Prepare the node for recompilation

In the next phase, we’re going to import the new code as a temporary module, which will cause the declaration of the new method to be executed again. In order to do so, we’d like to remove the rewriter decorator, since we don’t want it to process the modified function. In addition, we rename the function for safety, to avoid collisions between the declared function and other locals. Lastly, we ask Python to fix code locations for the new nodes so they can be compiled later on. This is done using fix_missing_locations.

from ast import fix_missing_locations


def rewrite_method(node):
    # assuming the method has single decorator (which is the rewriter) - remove it
    node.body[0].decorator_list.pop()
    # we rename the method to ensure separation from the original one.
    # this step has no real meaning and not really required.
    node.body[0].name = 'internal_method'
    # transform Compare nodes to fit the 'is not None' requirement
    ComparisonTransformer().visit(node)
    # let python try and fill code locations for the new elements
    fix_missing_locations(node)
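The effect of removing the decorator and renaming the function can be seen in isolation in the sketch below (Python 3, with the transformer step left out). The source string is illustrative. Note that rewrite_comparisons is never defined here, so the module can only be executed because the decorator was removed from the tree:

```python
import ast

source = (
    "@rewrite_comparisons\n"
    "def foo(x):\n"
    "    return x < 1\n"
)
tree = ast.parse(source)
function_def = tree.body[0]

# drop the single (rewriter) decorator so it is not applied again
function_def.decorator_list.pop()
# rename the function to keep it apart from the original one
function_def.name = 'internal_method'
# fill in any missing line/column information before compiling
ast.fix_missing_locations(tree)

namespace = {}
exec(compile(tree, '<string>', 'exec'), namespace)

print('internal_method' in namespace)  # True
print('foo' in namespace)              # False
```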

Summary

During the first phase we got a function as an input (through a decorator), then modified its body by visiting it at the syntactic level. Lastly, we modified its declaration and source locations so it can be safely imported as a new function.

As you have probably noticed, the only part of this code which is concerned with the aspect is the transformer. This means that if we’d like to apply a different aspect, the only part which will change is the transformer. In our example the ComparisonTransformer is hard-coded for simplicity, but in a real solution we’d provide it as an argument to the decorator.

Next phase

In the next phase we’ll use the modified function to generate replacement bytecode.

AOP without weaving

In this post I’ll present a use of the runtime method replacer in an AOP context. The idea behind it is to change the behavior of an application without changing the IL of its methods. As an example, I’ll show how to log an exception from a method.

This post is based on the work of Ziad Elmalki who posted the original method replacer. It is also based on the updated code for the method replacer by Chung Sung which is compatible with the new .NET framework versions. Lastly, thanks to Roy Osherove who mentioned those recently.

Replacing methods

The method replacer uses the following concept – after a method is jitted it receives a pointer to the jitted code. You can see how to extract that address in the original post. After extracting the addresses, we can simply replace one method with another:

public static void ReplaceMethod(IntPtr srcAdr, IntPtr destAdr)
{
    unsafe
    {
        if (IntPtr.Size == 8)
        {
            ulong* d = (ulong*)destAdr.ToPointer();
            *d = (ulong)srcAdr.ToInt64();
        }
        else
        {
            uint* d = (uint*)destAdr.ToPointer();
            *d = (uint)srcAdr.ToInt32();
        }
    }
}

As a simple example, if we have these two methods:

public class MyClass
{
    public static void Foo()
    {
        Console.WriteLine("In Foo");
        throw new Exception("I am done here!");
    }

    public static void Bar()
    {
        Console.WriteLine("In Bar");
    }
}

Then executing Foo in the following context:

MethodInfo barMethod = typeof (MyClass).GetMethod("Bar");
MethodInfo fooMethod = typeof (MyClass).GetMethod("Foo");
MethodUtil.ReplaceMethod(barMethod, fooMethod);
MyClass.Foo();

This actually leads to the following result:

[image: console output]

Which is… Cool!

Catching exceptions in Foo

What I’d like to present is a simplified example of how to catch an exception in business code without modifying it, similar in functionality to PostSharp exception handling. What we’re about to do is to hijack the original calls to Foo and redirect them to our new wrapper method. Our new wrapper method will call the original one inside a try/catch block.

Storing the original Foo

Since we’re about to intercept calls to Foo based on its address, we’d like to store a “way” to call the original method later. The “way” to do it is simple: we’ll extract the method address before starting the interception and create a delegate to it using marshaling. The delegate will be stored in a field:

MethodInfo fooMethod = typeof (MyClass).GetMethod("Foo");
IntPtr fooAdress = MethodUtil.GetMethodAddress(fooMethod);
OriginalFoo = Marshal.GetDelegateForFunctionPointer(fooAdress, typeof (Action));

Creating the wrapper

For the purpose of this example we could prepare a stub in the project itself. But, in order to show that a more general solution is likely possible, we will generate the wrapper at runtime.

Since the wrapper is going to receive the calls instead of Foo it must have the same signature. Besides, our wrapper will retrieve the original Foo delegate from a static field named OriginalFoo. The delegate will be called from the method inside a try/catch block.

We will generate a dynamic method that replaces the original method:

// The field holding the delegate to the original Foo
FieldInfo originalFooDelegateField = typeof (FooProtector).GetField("OriginalFoo");

MethodInfo invokeDelegateMethod = OriginalFoo.GetType().GetMethod("DynamicInvoke");
MethodInfo innerExceptionGetter = typeof(Exception).GetProperty("InnerException").GetGetMethod();
MethodInfo exceptionMessageGetter = typeof(Exception).GetProperty("Message").GetGetMethod();

var dynamicMethod = new DynamicMethod("FooProtector", typeof(void), new Type[0]);
ILGenerator ilGenerator = dynamicMethod.GetILGenerator();

Label beginExceptionBlock = ilGenerator.BeginExceptionBlock();

// Preparing the call to the original Foo -
// Load the original Foo
ilGenerator.Emit(OpCodes.Ldsfld, originalFooDelegateField);
// Load "no arguments" to invoke the delegate
ilGenerator.Emit(OpCodes.Ldnull);
// Invoke the delegate and call original Foo
ilGenerator.Emit(OpCodes.Callvirt, invokeDelegateMethod);
ilGenerator.Emit(OpCodes.Pop);

ilGenerator.Emit(OpCodes.Leave, beginExceptionBlock);
ilGenerator.BeginCatchBlock(typeof (Exception));

// Extract the exception message
ilGenerator.Emit(OpCodes.Callvirt, innerExceptionGetter);
ilGenerator.Emit(OpCodes.Callvirt, exceptionMessageGetter);

// Print the exception message
MethodInfo info = typeof (Console).GetMethod("WriteLine", new[] {typeof (string)});
ilGenerator.Emit(OpCodes.Call, info);

ilGenerator.Emit(OpCodes.Leave, beginExceptionBlock);
ilGenerator.EndExceptionBlock();
ilGenerator.Emit(OpCodes.Ret);

// Trigger method compilation
dynamicMethod.CreateDelegate(typeof (Action));

This wrapper calls the original method through a delegate. If an exception is thrown, it extracts the original exception and prints its message to the console.

Is it working?

Let’s revisit the original code and update it to the protecting code:

FooProtector.ProtectFoo();
MyClass.Foo();

The expected result is two messages printed, where the second one is the exception message “I am done here!”. As we can happily see, this is the exact result:

[image: console output]

Conclusion

The concept of replacing methods using their jitted versions can be useful. It can be used for AOP, where it enables logging, exception handling and basically any custom aspect. It can also be used to modify the behavior of 3rd party code for which we have no source code. Additionally, as Roy says, it can be used as an engine for mocking frameworks.

But there are some disadvantages too. Firstly, it is very dependent on the compilation outcome which makes it quite fragile. Secondly, it is sensitive to optimizations, for example inlined methods cannot be handled. Thirdly, when it is used extensively it requires generation and JIT of many dynamic methods which might lead to a performance hit.