Calling Numba-generated PyCFunctionWithKeywords from Python

69 views Asked by At

I serialized a jitted Numba function to a byte array and now want to deserialize and call it. This works fine for primitive data types with llvm_cfunc_wrapper_name:

import numba, ctypes
import llvmlite.binding as llvm

@numba.njit("f8(f8)")
def foo(x):
    return x + 0.5

# serialize function to byte array
sig = foo.signatures[0]
lib = foo.overloads[sig].library
cfunc_name = foo.overloads[sig].fndesc.llvm_cfunc_wrapper_name
function_bytes = lib._get_compiled_object()

# deserialize function_bytes to func
llvm.initialize()
llvm.initialize_native_target()
llvm.initialize_native_asmprinter()
target = llvm.Target.from_default_triple()
target_machine = target.create_target_machine()
backing_mod = llvm.parse_assembly("")
engine = llvm.create_mcjit_compiler(backing_mod, target_machine)
engine.add_object_file(llvm.ObjectFileRef.from_data(function_bytes))
func_ptr = engine.get_function_address(cfunc_name)

func = ctypes.CFUNCTYPE(ctypes.c_double, ctypes.c_double)(func_ptr)

print(func(0.25))

But I want to call functions with NumPy arguments. There is a llvm_cpython_wrapper_name for that which uses PyCFunctionWithKeywords, but unfortunately my best guess segfaults:

import numba, ctypes
import llvmlite.binding as llvm
import numpy as np

@numba.njit("f8[:](f8[:])")
def foo(x):
    return x + 0.5

# serialize function to byte array
sig = foo.signatures[0]
lib = foo.overloads[sig].library
cpython_name = foo.overloads[sig].fndesc.llvm_cpython_wrapper_name
function_bytes = lib._get_compiled_object()

# deserialize function_bytes to func
llvm.initialize()
llvm.initialize_native_target()
llvm.initialize_native_asmprinter()
target = llvm.Target.from_default_triple()
target_machine = target.create_target_machine()
backing_mod = llvm.parse_assembly("")
engine = llvm.create_mcjit_compiler(backing_mod, target_machine)
engine.add_object_file(llvm.ObjectFileRef.from_data(function_bytes))
func_ptr = engine.get_function_address(cpython_name)

def func(*args, **kwargs):
    py_obj_ptr = ctypes.POINTER(ctypes.py_object)
    return ctypes.CFUNCTYPE(py_obj_ptr, py_obj_ptr, py_obj_ptr, py_obj_ptr)(func_ptr)(
        ctypes.cast(id(None), py_obj_ptr),
        ctypes.cast(id(args), py_obj_ptr),
        ctypes.cast(id(kwargs), py_obj_ptr))

# segfaults here
print(func(np.ones(3)))

Here are some links to Numba source code (unfortunately very hard to follow), which might be helpful to figure this out.

0

There are 0 answers