Applying a function to all elements of a tree in Python

There is a tree-like data structure whose nodes contain scalar values (integers), lists, dictionaries and NumPy arrays (np.array). I need to write a Python function that applies an aggregation function to the lowest level of the tree and replaces that level with the result of the aggregation. The function should return a new data structure.

Example.

Initial data structure: {"a": [1, [2, 3, 4], [5, 6, 7]], "b": [{"c": 8, "d": 9}, {"e": 3, "f": 4}, 8]}
Aggregation function:   sum

First use:  {"a": [1, 9, 18], "b": [17, 7, 8]}
Second use: {"a": 28, "b": 32}
Third use:  60
Fourth use: 60

Answer by Stef

sum is a particularly simple example because addition is associative and commutative: you can add the numbers in any order and with any grouping, and you'll always get the same result.

So, for instance, you could first flatten the tree into a linear sequence of leaves, using depth-first search or breadth-first search, and then sum everything; you'd get the same result.
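The answer's code uses depth-first search. For comparison, here is a sketch of the breadth-first alternative, using collections.deque as a FIFO queue. breadth_first_search is a hypothetical name chosen to mirror depth_first_search, and it applies the same leaf rules (dict values are visited, strings and non-iterables are leaves). It visits the leaves in a different order, but summing them still gives 60 for the question's tree.

```python
from collections import deque


def breadth_first_search(tree):
    """Yield the leaves of a nested dict/list tree level by level (hypothetical helper)."""
    queue = deque([tree])
    while queue:
        node = queue.popleft()
        if isinstance(node, dict):
            queue.extend(node.values())  # visit values, not keys
        elif isinstance(node, (str, bytes)):
            yield node  # strings are leaves, not iterables of characters
        else:
            try:
                children = iter(node)
            except TypeError:
                yield node  # not iterable (e.g. an int): a leaf
            else:
                queue.extend(children)  # children are handled after the current level


tree = {"a": [1, [2, 3, 4], [5, 6, 7]], "b": [{"c": 8, "d": 9}, {"e": 3, "f": 4}, 8]}
print(list(breadth_first_search(tree)))  # [1, 8, 2, 3, 4, 5, 6, 7, 8, 9, 3, 4]
print(sum(breadth_first_search(tree)))   # 60
```

The shallow leaves 1 and 8 come out first, which is the only visible difference from a depth-first walk; for sum the order is irrelevant.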

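But with a different aggregation function, the grouping might matter. For example, the mean of several group means is generally not the mean of all the values when the groups have different sizes.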
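A minimal illustration with statistics.mean on a two-level list: averaging the inner group first gives a different answer from averaging all the values at once, because the inner group of three values is collapsed into a single value that then weighs the same as the lone 1.

```python
from statistics import mean

nested = [1, [2, 3, 4]]

# Flatten first, then aggregate: mean of 1, 2, 3, 4
flat_mean = mean([nested[0], *nested[1]])

# Aggregate the inner group first (mean is 3), then the outer level: mean of 1 and 3
grouped_mean = mean([nested[0], mean(nested[1])])

print(flat_mean, grouped_mean)  # 2.5 2
```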

I suggest two different functions, collapse_then_aggregate and aggregate_recursively. They return the same result for a simple aggregation function like sum, but can return different results for more complex aggregation functions.

Note how depth_first_search and aggregate_recursively walk the tree with the same recursive logic: a recursive call on each element of the iterable. Dictionaries are handled separately with if isinstance(tree, dict), because you care about the values and not the keys. Other iterables (lists, tuples, NumPy arrays) are detected with try/except around iter(); anything that is not iterable, such as an int or a NumPy scalar, is treated as a leaf. Strings are iterables too, but presumably you don't want to iterate over their individual characters (and since a one-character string iterates to itself, the recursion would never terminate), so the special case if isinstance(tree, (str, bytes)) treats strings as leaves.
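A quick check of why strings need their own case: naive_leaves below is a hypothetical walker that uses the same try/except test on iter() but has no str/bytes check. It works on lists and tuples of numbers, but a string such as "ab" yields "a", and iter("a") yields "a" again, so the recursion only stops when Python raises RecursionError.

```python
def naive_leaves(tree):
    """Hypothetical leaf walker WITHOUT the str/bytes special case (do not use)."""
    try:
        children = iter(tree)
    except TypeError:
        yield tree  # not iterable, e.g. an int: a leaf
        return
    for child in children:
        yield from naive_leaves(child)


print(list(naive_leaves([1, (2, 3)])))  # [1, 2, 3]: fine without strings

try:
    list(naive_leaves(["ab"]))
except RecursionError:
    # "ab" -> "a" -> "a" -> ... because iter("a") yields "a" again
    print("RecursionError: strings must be treated as leaves")
```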

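def depth_first_search(tree):
    # Yield the leaves, walking each subtree completely before the next one
    if isinstance(tree, dict):
        for subtree in tree.values():  # recurse into values, ignore keys
            yield from depth_first_search(subtree)
    elif isinstance(tree, (str, bytes)):
        yield tree  # strings are leaves, not iterables of characters
    else:
        try:
            tree_iter = iter(tree)
        except TypeError:
            yield tree  # not iterable (int, float, NumPy scalar): a leaf
        else:
            for subtree in tree_iter:  # list, tuple, NumPy array, ...
                yield from depth_first_search(subtree)

def collapse_then_aggregate(f, tree):
    # Flatten the whole tree into a stream of leaves, then aggregate once
    return f(depth_first_search(tree))

def aggregate_recursively(f, tree):
    # Aggregate every subtree first, then aggregate the subtree results
    if isinstance(tree, dict):
        return f(aggregate_recursively(f, subtree) for subtree in tree.values())
    elif isinstance(tree, (str, bytes)):
        return tree
    else:
        try:
            tree_iter = iter(tree)
        except TypeError:
            return tree  # a leaf is returned unchanged
        return f(aggregate_recursively(f, subtree) for subtree in tree_iter)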
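Both functions above aggregate the whole tree in a single call, which corresponds to the question's third (and fourth) use. If you need the step-by-step behaviour from the question, where each call collapses only the lowest level, here is a minimal sketch. It assumes that "lowest level" means any container whose elements are all leaves; is_leaf and aggregate_lowest_level are hypothetical names, and non-dict containers (tuples, NumPy arrays) are rebuilt as plain lists. Note that an empty container also counts as lowest level, so f is called on an empty list there.

```python
def is_leaf(node):
    """Strings/bytes and non-iterables (ints, NumPy scalars) are leaves."""
    if isinstance(node, (str, bytes)):
        return True
    if isinstance(node, dict):
        return False
    try:
        iter(node)
    except TypeError:
        return True
    return False


def aggregate_lowest_level(f, tree):
    """Return a new tree in which every container whose elements are all leaves
    is replaced by f(list of those elements); other containers are rebuilt
    (dicts as dicts, everything else as lists) with their children processed."""
    if is_leaf(tree):
        return tree  # a scalar has no lower level, so it is returned unchanged
    children = list(tree.values()) if isinstance(tree, dict) else list(tree)
    if all(is_leaf(child) for child in children):
        return f(children)
    if isinstance(tree, dict):
        return {key: aggregate_lowest_level(f, value) for key, value in tree.items()}
    return [aggregate_lowest_level(f, child) for child in children]


tree = {"a": [1, [2, 3, 4], [5, 6, 7]], "b": [{"c": 8, "d": 9}, {"e": 3, "f": 4}, 8]}
for step in range(1, 5):
    tree = aggregate_lowest_level(sum, tree)
    print(step, tree)
```

With sum, the four steps print {'a': [1, 9, 18], 'b': [17, 7, 8]}, then {'a': 28, 'b': 32}, then 60 twice, matching the question's example. The input is never modified, because every step builds new dicts and lists.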

Example usage, comparing both approaches on several aggregation functions:

from math import prod
from statistics import mean, geometric_mean

# Same shape as the question's example, but with an extra 8 in the last sublist of 'a'
tree = {'a': [1, [2, 3, 4], [5, 6, 7, 8]], 'b': [{'c':8, 'd':9}, {'e':3, 'f':4}, 8]}

for f in (sum, prod, mean, geometric_mean, list):
    for fold in (collapse_then_aggregate, aggregate_recursively):
        result = fold(f, tree)
        print(f'{f.__name__:4.4}  {fold.__name__:23}  {result}')

Results:

sum   collapse_then_aggregate  68
sum   aggregate_recursively    68
prod  collapse_then_aggregate  278691840
prod  aggregate_recursively    278691840
mean  collapse_then_aggregate  5.230769230769231
mean  aggregate_recursively    5.083333333333334
geom  collapse_then_aggregate  4.462980019474007
geom  aggregate_recursively    4.03915728944794
list  collapse_then_aggregate  [1, 2, 3, 4, 5, 6, 7, 8, 8, 9, 3, 4, 8]
list  aggregate_recursively    [[1, [2, 3, 4], [5, 6, 7, 8]], [[8, 9], [3, 4], 8]]

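Note how sum and prod give the same result with both aggregation methods, because addition and multiplication are associative and commutative, while mean, geometric_mean and list give different results: an average of averages weights each group equally regardless of its size, and list preserves the nesting instead of flattening it.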
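The two mean rows above can be checked by hand. The tree in this example has 13 leaves summing to 68, while aggregate_recursively averages each group first, so a group of two values counts as much as a group of four:

```latex
\begin{align*}
\text{flattened:}\quad \operatorname{mean} &= \tfrac{68}{13} \approx 5.2308\\
\text{recursive:}\quad m_a &= \operatorname{mean}\bigl(1,\ \operatorname{mean}(2,3,4),\ \operatorname{mean}(5,6,7,8)\bigr) = \operatorname{mean}(1,\ 3,\ 6.5) = 3.5\\
m_b &= \operatorname{mean}\bigl(\operatorname{mean}(8,9),\ \operatorname{mean}(3,4),\ 8\bigr) = \operatorname{mean}(8.5,\ 3.5,\ 8) = \tfrac{20}{3}\\
\operatorname{mean}(m_a, m_b) &= \tfrac{1}{2}\Bigl(3.5 + \tfrac{20}{3}\Bigr) = \tfrac{61}{12} \approx 5.0833
\end{align*}
```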