sharpening the saw

A friend of mine called the other day and asked, basically, how to loop over a tree. The actual problem was an SQL table with N fields, each field having its own enum type.
His question was how to print out all the possible records for such a table.

While it’s a simple task, there are many ways to do it, each with different upsides and downsides, so I decided to code up several implementations to show him the possibilities.

So the data structure is something along the lines of a tuple of (field name, allowed values) pairs:

table = (
  ('a', ('a1','a2')),
  ('b', ('b1', 'b2', 'b3')),
  ('c', ('c1', 'c2', 'c3')),
  ('d', ('d1', 'd2')),
)
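
Since each record picks exactly one value per field, the set of records is the Cartesian product of the value tuples, and its size is the product of their lengths. A quick sanity check on the example table; `record_count` is a hypothetical helper, not part of the original code:

```python
from functools import reduce
import operator

# The example table: (field name, allowed enum values) pairs.
table = (
    ('a', ('a1', 'a2')),
    ('b', ('b1', 'b2', 'b3')),
    ('c', ('c1', 'c2', 'c3')),
    ('d', ('d1', 'd2')),
)

def record_count(t):
    """Hypothetical helper: number of possible records = product of enum sizes."""
    return reduce(operator.mul, (len(values) for _, values in t), 1)

print(record_count(table))  # 2 * 3 * 3 * 2 = 36
```

Every extra field multiplies this number, which is why the memory behaviour of each approach matters.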

The first solution that comes to mind is the simple recursive algorithm:

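def construct_tree(t, depth=0):
    # Base case: the last field, every value becomes a one-element row.
    if depth == len(t)-1:
        return [[x] for x in t[depth][1]]

    # Build all rows for the remaining fields first, then prefix
    # each of them with every value of the current field.
    retl = []
    futures = construct_tree(t, depth+1)
    for attr_val in t[depth][1]:
        for f in futures:
            retl.append([attr_val] + f)

    return retl

print(','.join(f[0] for f in table))
for l in construct_tree(table):
    print(','.join(l))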

This works and is easy to understand, but it consumes a lot of memory, since it builds the full tree as a list of lists before returning anything.
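
One way to see the cost, as a rough sketch: measure the peak allocation with the standard `tracemalloc` module while the recursive version builds its result for two synthetic tables. `make_table` and `peak_bytes` are hypothetical helpers, and the exact byte counts depend on the Python version, so only the trend matters:

```python
import tracemalloc

def construct_tree(t, depth=0):
    # Same recursive algorithm as above: returns every row in one big list.
    if depth == len(t) - 1:
        return [[x] for x in t[depth][1]]
    futures = construct_tree(t, depth + 1)
    return [[attr_val] + f for attr_val in t[depth][1] for f in futures]

def make_table(n_fields, n_values):
    # Hypothetical synthetic table: n_fields fields with n_values values each.
    return tuple(
        ('f%d' % i, tuple('f%d_%d' % (i, j) for j in range(n_values)))
        for i in range(n_fields)
    )

def peak_bytes(t):
    """Return (row count, peak traced bytes) while building the full list."""
    tracemalloc.start()
    rows = construct_tree(t)
    _, peak = tracemalloc.get_traced_memory()
    tracemalloc.stop()
    return len(rows), peak

small_rows, small_peak = peak_bytes(make_table(4, 4))  # 4**4 = 256 rows
big_rows, big_peak = peak_bytes(make_table(6, 4))      # 4**6 = 4096 rows
print(small_rows, small_peak)
print(big_rows, big_peak)
```

The peak grows along with the row count, because every row is alive at the same time.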

Another solution is to avoid recursion completely and use a separate list of per-field value indexes to keep track of where you are in the tree.

def construct_path(idx_map, t):
    """return the current path through the tree as a list"""
    # Build the list eagerly: idx_map keeps changing after each yield,
    # so a lazy generator here could read the wrong indexes.
    return [t[fieldno][1][field_val_idx] for fieldno, field_val_idx in enumerate(idx_map)]

def reset_subtree(from_idx, idx_map, t):
    # zero the value index of every field after from_idx
    for idx in range(from_idx+1, len(t)):
        idx_map[idx] = 0

def construct_tree(t):
    idx_map = [0]*len(t)  # index of the current value of each field
    cur_field = len(t)-1

    while True:
        if cur_field == len(t)-1:
            # we have reached the leaf node, yield the whole path
            yield construct_path(idx_map, t)

        if idx_map[cur_field] < len(t[cur_field][1])-1:
            # we still have some work at this depth
            idx_map[cur_field] += 1
            # always jump to the end after changing index
            cur_field = len(t)-1
        else:
            # can't increment this field anymore, try previous if any
            cur_field -= 1
            if cur_field >= 0:
                reset_subtree(cur_field, idx_map, t)
            else:
                # there is no previous field
                break

print(','.join(f[0] for f in table))
for l in construct_tree(table):
    print(','.join(l))

Performance-wise this is a lot better, since only the current path is kept in memory, but it is certainly harder to write and understand.
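
The `idx_map` list in the iterative version works like an odometer: the last field ticks fastest and carries into the previous field when it wraps around. Looked at that way, record number n is simply n written in a mixed-radix number system whose digit bases are the field sizes, which suggests another iterative approach. This is a sketch of that swapped-in technique, not the code above; `record_at` is a hypothetical helper:

```python
# The example table: (field name, allowed enum values) pairs.
table = (
    ('a', ('a1', 'a2')),
    ('b', ('b1', 'b2', 'b3')),
    ('c', ('c1', 'c2', 'c3')),
    ('d', ('d1', 'd2')),
)

def record_at(n, t):
    """Hypothetical helper: decode record number n (0-based) into a row.

    Each field is one 'digit' whose base is the number of its values;
    the last field is the least significant digit, so it changes fastest.
    """
    row = []
    for _, values in reversed(t):
        n, digit = divmod(n, len(values))
        row.append(values[digit])
    row.reverse()
    return row

def construct_tree(t):
    # Total number of records is the product of the field sizes.
    total = 1
    for _, values in t:
        total *= len(values)
    for n in range(total):
        yield record_at(n, t)

print(','.join(f[0] for f in table))
for row in construct_tree(table):
    print(','.join(row))
```

It visits the records in the same order, and any single record can be computed directly from its number, which makes it easy to resume or split the work at an arbitrary position. The price is one `divmod` per field for every record.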

So how do you get the performance of the iterative solution while keeping the simplicity of the recursive one? The answer is to use Python’s generators instead of lists:

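import collections

def construct_tree(t, buf=None):
    # buf holds the values chosen so far, one per field
    if buf is None:
        buf = collections.deque()

    for x in t[len(buf)][1]:
        buf.append(x)
        if len(buf) == len(t):
            # leaf node, stop recursion; yield a copy because buf
            # keeps changing after the caller receives the row
            yield tuple(buf)
        else:
            # recurse into the next field and pass its rows through
            for path in construct_tree(t, buf):
                yield path
        buf.pop()

print(','.join(f[0] for f in table))
for e in construct_tree(table):
    print(','.join(e))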
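
Because records come out one at a time, a consumer can stop early and the rest of the tree is never generated. A sketch using the standard `itertools.islice` to take only the first three records (the same rows the Prolog run below prints through `head -3`). It repeats the generator so the block runs on its own, written for Python 3 with `yield from`:

```python
import collections
import itertools

# The example table: (field name, allowed enum values) pairs.
table = (
    ('a', ('a1', 'a2')),
    ('b', ('b1', 'b2', 'b3')),
    ('c', ('c1', 'c2', 'c3')),
    ('d', ('d1', 'd2')),
)

def construct_tree(t, buf=None):
    # buf holds the values chosen so far, one per field.
    if buf is None:
        buf = collections.deque()
    for value in t[len(buf)][1]:
        buf.append(value)
        if len(buf) == len(t):
            yield tuple(buf)  # copy, so the caller may keep the row
        else:
            yield from construct_tree(t, buf)
        buf.pop()

# islice stops pulling after three rows; the generator is simply left suspended.
first_three = list(itertools.islice(construct_tree(table), 3))
for row in first_three:
    print(','.join(row))
```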

Since I have lately been re-learning Prolog, I also wanted to write a solution in that language, as this seems to be an ideal task for it:

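attribute(a, [a1, a2]).
attribute(b, [b1, b2, b3]).
attribute(c, [c1, c2, c3]).
attribute(d, [d1, d2]).

% fgen(Field, X): X is one of the allowed values of Field;
% member/2 enumerates them on backtracking.
fgen(Field, X):-
    attribute(Field, L1),
    member(X, L1).

% Failure-driven loop: print one record, then fail to force
% backtracking into the next combination.
table:-
    fgen(a, X), fgen(b, Y), fgen(c, Z), fgen(d, I),
    write(X),write(','),write(Y),write(','),write(Z),write(','),write(I),nl,fail.
% Succeed once every combination has been printed.
table.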

And here’s how to run it:

hadara@hadara-laptop:~/code$ prolog -s table.pl -t table -q | head -3
a1,b1,c1,d1
a1,b1,c1,d2
a1,b1,c2,d1

Update
While learning Erlang I noticed that its list comprehension syntax allows several generators in a single comprehension. You can think of them as nested loops, so my specific example can be written in Erlang like this:

[{X1, X2, X3, X4} || X1 <- [a1,a2], X2 <- [b1,b2,b3], X3 <- [c1,c2,c3], X4 <- [d1,d2]].

Sure enough, Python’s list comprehension syntax allows the same:

[(x1,x2,x3,x4) for x1 in ('a1','a2') for x2 in ('b1','b2','b3') for x3 in ('c1','c2','c3') for x4 in ('d1','d2')]

or, much more memory-efficiently, as a generator expression that produces the tuples one at a time:

((x1,x2,x3,x4) for x1 in ('a1','a2') for x2 in ('b1','b2','b3') for x3 in ('c1','c2','c3') for x4 in ('d1','d2'))
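
In both forms the `for` clauses nest from left to right, exactly like nested loops, so the rightmost field varies fastest. A small check of that equivalence (the variable names are illustrative):

```python
# Comprehension with four for clauses, as in the text.
rows_comp = [(x1, x2, x3, x4)
             for x1 in ('a1', 'a2')
             for x2 in ('b1', 'b2', 'b3')
             for x3 in ('c1', 'c2', 'c3')
             for x4 in ('d1', 'd2')]

# The same thing written as explicit nested loops.
rows_loops = []
for x1 in ('a1', 'a2'):
    for x2 in ('b1', 'b2', 'b3'):
        for x3 in ('c1', 'c2', 'c3'):
            for x4 in ('d1', 'd2'):
                rows_loops.append((x1, x2, x3, x4))

print(rows_comp == rows_loops)  # True
print(rows_comp[:2])  # [('a1', 'b1', 'c1', 'd1'), ('a1', 'b1', 'c1', 'd2')]
```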

Of course this would be useless in real life, since you do not want to hard-code an expression for a specific number of fields. Luckily, you can use recursion inside the comprehension to generalize it.
In Erlang it would look like this:

build_tree([Head]) -> [[X] || X <- Head];
build_tree([Head|Tail]) ->
    [[X1] ++ X2 || X1 <- Head, X2 <- build_tree(Tail)].

and in Python (using generator expressions instead of list comprehensions):

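def construct_tree(x):
    """x is a list of value lists, one per field"""
    if len(x) == 1:
        # last field: each value becomes a one-element row
        return ([i] for i in x[0])

    # prefix every value of the first field to every row of the rest
    return ([i] + j for i in x[0] for j in construct_tree(x[1:]))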
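
This version takes a list of value lists rather than the (field name, values) pairs used earlier, so the names have to be split off before calling it. The standard library’s `itertools.product` computes the same Cartesian product lazily, which makes an easy cross-check. A sketch, assuming the `table` layout from the top of the post:

```python
import itertools

# The example table: (field name, allowed enum values) pairs.
table = (
    ('a', ('a1', 'a2')),
    ('b', ('b1', 'b2', 'b3')),
    ('c', ('c1', 'c2', 'c3')),
    ('d', ('d1', 'd2')),
)

def construct_tree(x):
    # Base case: one field left, each value is a one-element row.
    if len(x) == 1:
        return ([i] for i in x[0])
    # Prefix every value of the first field to every row of the rest.
    return ([i] + j for i in x[0] for j in construct_tree(x[1:]))

values = [vals for _, vals in table]
print(','.join(name for name, _ in table))
for row in construct_tree(values):
    print(','.join(row))

# itertools.product yields tuples in the same order (last field fastest).
same = [list(p) for p in itertools.product(*values)] == list(construct_tree(values))
print(same)  # True
```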
