How to propagate context through StructuredTaskScope with ScopedValue, and how does the MDC ThreadContextMap work in StructuredTaskScope?


I need to propagate some per-request state, such as a tracer/span or a request context. The JDK environment is 21.0.1 with preview features enabled.

I am trying to find a common way to propagate any state between platform threads and virtual threads.

The full implementation should also be compatible with ThreadLocal-based solutions such as
context-propagation (https://github.com/micrometer-metrics/context-propagation) and TransmittableThreadLocal (https://github.com/alibaba/transmittable-thread-local/blob/master/README-EN.md).

In JDK 21, a traditional Thread is called a platform thread. The following three scenarios are my targets this time around:

  1. virtual threads -> platform threads
  2. platform threads -> virtual threads
  3. virtual threads -> virtual threads

Below is my code. It shows the key ideas but does not yet cover the platform-thread cases.

Like:

@SuppressWarnings({"unchecked", "rawtypes"})
static class PropagatingTaskDecorator implements TaskDecorator {


    public record Tracer(String traceId, String spanId, String parentId) {

    }

    private final ScopedValue[] keys;

    public PropagatingTaskDecorator(ScopedValue<?>... keys) {
        this.keys = keys;
    }

    static class ScopedValueMap {

        public static final ScopedValue<Tracer> TRACER_SCOPED_VALUE = ScopedValue.newInstance();
        private final Map<ScopedValue<?>, Object> bindingsSnapshot;

        public ScopedValueMap() {
            bindingsSnapshot = Maps.newHashMap();
            if (TRACER_SCOPED_VALUE.isBound()) {
                var latestTracer = TRACER_SCOPED_VALUE.get();
                bindingsSnapshot.put(TRACER_SCOPED_VALUE, new Tracer(latestTracer.traceId(), UUID.randomUUID().toString(), latestTracer.spanId()));
            } else {
                bindingsSnapshot.put(TRACER_SCOPED_VALUE, new Tracer(UUID.randomUUID().toString(), UUID.randomUUID().toString(), null));
            }
        }

        /**
         * here is a function that can be used to capture the current bindings of a set of ScopedValues
         * @param scopedValue scopedValue
         * @param <T> T
         */
        <T> void put(ScopedValue<T> scopedValue) {
            if (scopedValue.isBound()) {
                bindingsSnapshot.put(scopedValue, scopedValue.get());
            }
        }

        <T> T get(ScopedValue<T> scopedValue) {
            if (bindingsSnapshot.containsKey(scopedValue)) {
                return (T) bindingsSnapshot.get(scopedValue);
            }
            throw new RuntimeException(STR."ScopedValue not found: \{scopedValue}");
        }
    }

    @Nonnull
    @Override
    public Runnable decorate(@Nonnull Runnable runnable) {
        // capture the current bindings of a set of ScopedValues
        var bindingsSnapshot = new ScopedValueMap();
        for (var key : keys) {
            bindingsSnapshot.put(key);
        }
        /*
         * Here's the key, propagating snapshot state between virtual threads via bindingsSnapshot,
         * which bypasses the StructuredTaskScope limitation.
         *
         * Also, the full implementation will be compatible with ThreadLocal scenarios
         * (like https://github.com/micrometer-metrics/context-propagation or https://github.com/oldratlee/log4j2-ttl-thread-context-map/blob/master/src/main/java/com/alibaba/ttl/log4j2/TtlThreadContextMap.java)
         *
         * In jdk 21, Thread is defined as Platform Thread, and the following three scenarios are my targets this time around
         * 1. virtual threads -> platform threads
         * 2. platform threads -> virtual threads
         * 3. virtual threads -> virtual threads
         */
        return () -> {
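            // TRACER_SCOPED_VALUE is declared in the nested ScopedValueMap class, so it must be qualified here
            ScopedValue.Carrier carrier = ScopedValue.where(ScopedValueMap.TRACER_SCOPED_VALUE, bindingsSnapshot.get(ScopedValueMap.TRACER_SCOPED_VALUE));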
            for (var key : keys) {
                carrier = carrier.where(key, bindingsSnapshot.get(key));
            }
            carrier.run(runnable);
        };
    }
}

@Slf4j
class ServiceImpl {
    public void xxx() {
        try (var executor = new SimpleAsyncTaskExecutor("pg-")) {
            executor.setVirtualThreads(true);
            var decorator = new PropagatingTaskDecorator();
            executor.setTaskDecorator(decorator);
            executor.setThreadFactory(Thread.ofVirtual().name("v-").factory());

            executor.submit(() -> {
                log.info("{}", xxx); // here log4j2 needs to print trace info via %X{traceId},
                                     // but how could MDC work for ScopedValue?
                                     // Up to now, I override the ThreadContextMap and
                                     // try to handle virtual and platform threads in
                                     // different ways.
            });
        }
    }
}

Can someone do me a favor?


There is 1 answer

igor.zh

Welcome to fabulous SO and thank you for the interesting question!

First, from the standpoint of ScopedValue usage your code is, in my opinion, quite correct. I wasn't able to understand what the bindingsSnapshot variable and the ScopedValueMap class in your code are for, but if you declare the TRACER_SCOPED_VALUE constant like

private static final ScopedValue<String> TRACER_SCOPED_VALUE = ScopedValue.newInstance();

, create a Carrier with a value in the decorate method rather than inside the Runnable

ScopedValue.Carrier carrier = ScopedValue.where(TRACER_SCOPED_VALUE, "Long live SO!");

, then upon the execution of the task you would be able to retrieve this value:

String value = TRACER_SCOPED_VALUE.get();

So, I believe, it fulfills your goal of propagating a state (TRACER_SCOPED_VALUE) from "a thread" (the one that executes ServiceImpl.xxx, the executor's submit, and the decorator's decorate) to "a virtual thread" started by the executor.

The following snippet demonstrates sharing a value between a thread that submits a task to an Executor and a thread that executes this task:

import org.springframework.core.task.SimpleAsyncTaskExecutor;
import org.springframework.core.task.TaskDecorator;

...

private static final ScopedValue<String> SCOPED_VALUE = ScopedValue.newInstance();

private static class ScopedValueDecorator implements TaskDecorator {

    public Runnable decorate(Runnable runnable) {
        final ScopedValue.Carrier carrier = ScopedValue.where(SCOPED_VALUE, "Long Live SO!"); 
        System.out.println(Thread.currentThread().getName() + " maps " + carrier.get(SCOPED_VALUE));
        return () -> {
            carrier.run(runnable);
        };          
    }
}

public static void main(String[] args) {
    try (SimpleAsyncTaskExecutor executor = new SimpleAsyncTaskExecutor("scope-value-")) {
        executor.setTaskDecorator(new ScopedValueDecorator());
        executor.submit(() -> {
            System.out.println(Thread.currentThread().getName() + " retrieves " + SCOPED_VALUE.get());
        });
    }
}

Note that the usage of virtual threads is not required. SimpleAsyncTaskExecutor, however, does not pool threads, so creating a platform thread per task could be too expensive.

As for MDC, its adapters in the currently available latest versions (BasicMDCAdapter, LogbackMDCAdapter) use ThreadLocal, not ScopedValue. ThreadLocal and ScopedValue, although they target a similar problem, are technically unrelated things. So, the answer to

how could MDC work for ScopedValue

is "MDC doesn't".

EDIT: It does not, unless you go ahead and implement your own MDCAdapter. An example of such an implementation is discussed in "Logback: availability of MDCs in forks created inside a StructuredTaskScope".
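
Short of writing a full MDCAdapter, a ThreadLocal-backed context can also be carried across threads by wrapping the task: copy the map on the submitting thread, install it on the worker thread, and restore the previous state afterwards. The following is only a minimal sketch of that pattern under stated assumptions. FAKE_MDC is a hypothetical stand-in for a ThreadLocal-backed MDC, and the names MdcPropagationSketch, propagating and runDemo are made up for illustration. With real SLF4J the corresponding calls would presumably be MDC.getCopyOfContextMap() and MDC.setContextMap(...).

```java
import java.util.HashMap;
import java.util.Map;
import java.util.concurrent.atomic.AtomicReference;

public class MdcPropagationSketch {

    // Hypothetical stand-in for a ThreadLocal-backed MDC (assumption, not a real logging API).
    static final ThreadLocal<Map<String, String>> FAKE_MDC =
            ThreadLocal.withInitial(HashMap::new);

    // Wraps a task so that it runs with a copy of the submitting thread's context map.
    static Runnable propagating(Runnable task) {
        // Executed on the submitting thread: take a defensive copy of its map.
        Map<String, String> snapshot = new HashMap<>(FAKE_MDC.get());
        return () -> {
            // Executed on the worker thread: remember the old map, install the snapshot.
            Map<String, String> previous = FAKE_MDC.get();
            FAKE_MDC.set(new HashMap<>(snapshot));
            try {
                task.run();
            } finally {
                // Restore, so a reused (pooled) thread does not leak the context.
                FAKE_MDC.set(previous);
            }
        };
    }

    // Sets a traceId on the calling thread and returns what a worker thread observes.
    static String runDemo() {
        FAKE_MDC.get().put("traceId", "trace-42");
        AtomicReference<String> seen = new AtomicReference<>();
        Thread worker = new Thread(propagating(
                () -> seen.set(FAKE_MDC.get().get("traceId"))));
        worker.start();
        try {
            worker.join();
        } catch (InterruptedException e) {
            Thread.currentThread().interrupt();
        }
        return seen.get();
    }

    public static void main(String[] args) {
        System.out.println("worker saw traceId=" + runDemo());
    }
}
```

The same wrapper could be returned from a TaskDecorator's decorate method, and on JDK 21 the worker could just as well be a virtual thread. Without the wrapper, the worker would start with its own empty map and see no traceId.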