How can I call a method as a callback?


I'm trying to rewrite my driver in an OOP architecture, and I've run into the problem of using a method as a callback when I don't have any std::XXXX facilities available and the callback's arguments are fixed by a predefined signature.

So I have the following class:

class Log {
public:
    /*static*/ VOID LogNotifyUsermodeCallback(PKDPC Dpc, PVOID DeferredContext, PVOID SystemArgument1, PVOID SystemArgument2);

    NTSTATUS LogRegisterIrpBasedNotification(PDEVICE_OBJECT DeviceObject, PIRP Irp);
    NTSTATUS LogRegisterEventBasedNotification(PDEVICE_OBJECT DeviceObject, PIRP Irp);
};

Also I have the following code:

NTSTATUS Log::LogRegisterIrpBasedNotification(PDEVICE_OBJECT DeviceObject, PIRP Irp)
{
//Some line of code....
        KeInitializeDpc(&NotifyRecord->Dpc, // Dpc
            LogNotifyUsermodeCallback, // DeferredRoutine
            NotifyRecord // DeferredContext
        );
//Some line of code....
}

The DeferredRoutine callback must have the following signature:

VOID Log::LogNotifyUsermodeCallback(PKDPC Dpc, PVOID DeferredContext, PVOID SystemArgument1, PVOID SystemArgument2)

I can't make LogNotifyUsermodeCallback static: its arguments are fixed by that signature, and the method itself calls other methods of the Log class. I have no other ideas. So the main question is: given all these constraints, how can I use LogNotifyUsermodeCallback as a callback?

There are 3 answers

Barbosso On BEST ANSWER

As I said earlier, the arguments are passed to my function by Windows itself, so I can't add my own arguments: that would break the expected argument order.

My solution:

I create a C-style callback plus a global variable:

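namespace Cwrapper {
    // Global pointer to the Log instance that should receive the callback.
    Log* GLogPtr = nullptr;

    // Free function with the signature the system expects; it forwards to the method.
    void Callback(PKDPC Dpc, PVOID DeferredContext, PVOID SystemArgument1, PVOID SystemArgument2)
    {
        if (GLogPtr) { GLogPtr->LogNotifyUsermodeCallback(Dpc, DeferredContext, SystemArgument1, SystemArgument2); }
    }
}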

Next, I assign the this pointer of my class instance to GLogPtr:

Cwrapper::GLogPtr = this;
KeInitializeDpc(&NotifyRecord->Dpc, // Dpc
    Cwrapper::Callback, // DeferredRoutine
    NotifyRecord // DeferredContext
);

From what I have read, this technique is also called a "trampoline".

Some information: https://web.archive.org/web/20230822094216/https://codeyarns.com/tech/2015-09-01-how-to-register-class-method-as-c-callback.html
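
One limitation of a single global pointer is that only one Log instance can receive the callback at a time. Since the DeferredContext value given to KeInitializeDpc is handed back to the DeferredRoutine, the object pointer can travel inside the context record instead. Below is a minimal user-mode sketch of that idea, not the driver code itself. FakeDpc, InitDpc and FireDpc are hypothetical stand-ins for KDPC, KeInitializeDpc and the system running the DPC, and the owner and payload fields of NotifyRecord, as well as Register and OnNotify, are assumptions made for illustration.

```cpp
#include <cassert>

// Hypothetical stand-ins so the sketch compiles outside the kernel.
// In a real driver these roles are played by KDPC and KeInitializeDpc.
struct FakeDpc;
using DeferredRoutine = void (*)(FakeDpc* dpc, void* deferredContext,
                                 void* arg1, void* arg2);

struct FakeDpc {
    DeferredRoutine routine = nullptr;
    void* context = nullptr;
};

// Plays the role of KeInitializeDpc: remembers the routine and its context.
inline void InitDpc(FakeDpc* dpc, DeferredRoutine routine, void* context) {
    dpc->routine = routine;
    dpc->context = context;
}

// Plays the role of the system invoking the deferred routine later.
inline void FireDpc(FakeDpc* dpc, void* arg1 = nullptr, void* arg2 = nullptr) {
    dpc->routine(dpc, dpc->context, arg1, arg2);
}

class Log;

// Hypothetical record: carries the owning Log next to the DPC itself.
struct NotifyRecord {
    FakeDpc dpc;
    Log* owner = nullptr;
    int payload = 0;
};

class Log {
public:
    int notified = 0;  // sum of payloads seen, only for demonstration

    // Static trampoline: it has an ordinary function signature, so it can be
    // passed as the deferred routine. It recovers the object from the context.
    static void NotifyTrampoline(FakeDpc* dpc, void* deferredContext,
                                 void* arg1, void* arg2) {
        auto* record = static_cast<NotifyRecord*>(deferredContext);
        assert(record != nullptr && record->owner != nullptr);
        record->owner->OnNotify(dpc, record, arg1, arg2);
    }

    // Stores `this` in the record, then registers the trampoline.
    void Register(NotifyRecord* record) {
        record->owner = this;
        InitDpc(&record->dpc, &Log::NotifyTrampoline, record);
    }

private:
    // Non-static method: free to use members and call other Log methods.
    void OnNotify(FakeDpc*, NotifyRecord* record, void*, void*) {
        notified += record->payload;
    }
};
```

With this layout, two Log objects can each call Register on their own record, and each fired DPC reaches the object that registered it, with no global state involved.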

Jarod42 On

Whereas the C++ way to register a function would be std::function, the C way to handle registration is a function pointer plus a userData pointer.

// From library
using callBackType = Ret (Arg1, Arg2, void* /* userData */);
void registerCallback(callBackType* f, void* userData);

// Your code
struct Obj
{
    static Ret MyStaticCallBack(Arg1 arg1, Arg2 arg2, void* userData) {
        Obj* self = static_cast<Obj*>(userData);

        return self->MyMethod(arg1, arg2);
    }
    Ret MyMethod(Arg1 arg1, Arg2 arg2);
};

Then your code to register the callback would be:

Obj my_object;

registerCallback(&Obj::MyStaticCallBack, &my_object);
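
For a version that can be compiled and tried, here is a small sketch of the same pattern with concrete types. It assumes Ret, Arg1 and Arg2 are all int. The "library" side (the stored globals and fireCallback) is a hypothetical stand-in for whatever C API actually invokes the callback, and the offset member is added only so the result visibly depends on the object.

```cpp
#include <cassert>

// --- Hypothetical C-style library (stand-in for a real API) ---
using CallbackType = int(int, int, void* /* userData */);

static CallbackType* g_callback = nullptr;
static void* g_userData = nullptr;

// Stores the function pointer together with the opaque user data pointer.
void registerCallback(CallbackType* f, void* userData) {
    g_callback = f;
    g_userData = userData;
}

// Simulates the library calling back, passing the stored user data through.
int fireCallback(int a, int b) {
    assert(g_callback != nullptr);
    return g_callback(a, b, g_userData);
}

// --- User code ---
struct Obj {
    int offset;  // illustrative state, so the call depends on the instance

    // Static member: no hidden `this`, so it matches the C callback type.
    static int MyStaticCallBack(int a, int b, void* userData) {
        Obj* self = static_cast<Obj*>(userData);
        return self->MyMethod(a, b);
    }

    int MyMethod(int a, int b) { return a + b + offset; }
};
```

Registering with registerCallback(&Obj::MyStaticCallBack, &my_object) and then calling fireCallback(1, 2) forwards to my_object.MyMethod(1, 2). The object must outlive the registration, since the library only keeps a raw pointer to it.
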
paulsm4 On

SUGGESTION:

  1. Define an abstract base class with your "callback method".
  2. Define a subclass for each different implementation of that method.
  3. Pass the desired subclass to whatever function or method needs to use it.

EXAMPLE:

#include <iostream>

using namespace std;

// Define an abstract base class
class LogNotifier {
public:    
    virtual int Register() = 0;
};

// One implementation...
class IrpBasedNotificationLog : public LogNotifier {
public:
    virtual int Register() {
       cout << "Greetings from IrpBasedNotificationLog.Register...\n";
       return 0;
    }
};

// A second, different implementation...
class EventBasedBasedNotificationLog : public LogNotifier {
public:
    virtual int Register() {
       cout << "Greetings from EventBasedBasedNotificationLog .Register...\n";
       return 0;
    }
};

// Invoke the "callback"
void Client(LogNotifier& theLog)
{
  theLog.Register();
}

int main()
{
  EventBasedBasedNotificationLog myLog;
  Client(myLog);
}