Emulating the call stack in Java

The setup

I was recently thinking about the challenges of converting an arbitrary recursive algorithm into an iterative one. Famous problems like finding the nth number in the Fibonacci series have both recursive and iterative solutions, and in such cases the two approaches attack the problem from different angles. A commonly held opinion (which I would agree with) is that the recursive algorithm is more elegant and expressive: the solution matches the problem definition more closely, so it's easier to grasp. However, recursive algorithms can take up more memory, because every active function call occupies a frame on the call stack until it returns. That memory grows with the depth of the recursion. Naive recursive Fibonacci, for instance, makes an exponential number of calls in total, but only about n of them are on the stack at any one time. Techniques like tail recursion can reduce the stack overhead to practically zero in languages whose compilers perform tail-call elimination (the standard JVM does not), but they are not universally applicable. In comparison, the iterative solution for the Fibonacci problem needs just a couple of variables.
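To make the contrast concrete, here is a minimal sketch of the approaches in Java. The class FibonacciDemo and its callCount counter are illustrative names of my own. The naive recursive version makes exponentially many calls while never holding more than about n frames on the stack; the tail-recursive variant shows why tail recursion alone does not help in Java; and the iterative version needs only two running values.

```java
// Illustrative sketch; FibonacciDemo and callCount are hypothetical names.
final class FibonacciDemo {
    // Counts how many times fibRecursive is invoked (for illustration only).
    static long callCount = 0;

    // Naive recursion that mirrors the definition F(n) = F(n-1) + F(n-2).
    // Exponentially many calls in total, but the stack is at most ~n frames deep.
    static long fibRecursive(int n) {
        callCount++;
        if (n < 2) {
            return n;
        }
        return fibRecursive(n - 1) + fibRecursive(n - 2);
    }

    // Tail-recursive style: the recursive call is the last action.
    // Caveat: standard JVMs do not eliminate tail calls, so in Java this
    // still uses one stack frame per step.
    static long fibTail(int n, long a, long b) {
        if (n == 0) {
            return a;
        }
        return fibTail(n - 1, b, a + b);
    }

    // Iterative version: keeps only the last two values, so O(1) extra memory.
    static long fibIterative(int n) {
        if (n == 0) {
            return 0;
        }
        long previous = 0; // F(0)
        long current = 1;  // F(1)
        for (int i = 2; i <= n; i++) {
            long next = previous + current;
            previous = current;
            current = next;
        }
        return current;
    }
}
```

For example, fibRecursive(10) returns 55 after 177 calls in total, while fibIterative(10) reaches the same answer in a single loop of nine iterations.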

So can every recursive solution be converted into a non-recursive one that does away with the memory overhead? Unfortunately, no. Although many recursive algorithms do have such alternatives, many do not. The one technique that works for any recursive algorithm is to manage an explicit stack data structure yourself, and that usually ends up taking approximately the same amount of memory as the recursion did.
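As a small illustration of that general technique on a simpler problem, here is a sketch that sums the values in a binary tree, first recursively and then with an explicit stack. The Node class and method names are hypothetical. I've used java.util.ArrayDeque as the stack, which the Javadoc for the legacy Stack class recommends in its place.

```java
import java.util.ArrayDeque;
import java.util.Deque;

// Hypothetical minimal binary tree, used only for this illustration.
final class TreeSumDemo {
    static final class Node {
        final int value;
        final Node left;
        final Node right;

        Node(int value, Node left, Node right) {
            this.value = value;
            this.left = left;
            this.right = right;
        }
    }

    // Recursive version: each call gets its own frame on the JVM call stack.
    static int sumRecursive(Node node) {
        if (node == null) {
            return 0;
        }
        return node.value + sumRecursive(node.left) + sumRecursive(node.right);
    }

    // Explicit-stack version: each "recursive call" becomes a push.
    // Addition is order-independent, so we can accumulate as we pop.
    static int sumWithStack(Node root) {
        int total = 0;
        Deque<Node> stack = new ArrayDeque<>();
        // ArrayDeque rejects null elements, hence the null checks before pushing.
        if (root != null) {
            stack.push(root);
        }
        while (!stack.isEmpty()) {
            Node node = stack.pop();
            total += node.value;
            if (node.right != null) {
                stack.push(node.right);
            }
            if (node.left != null) {
                stack.push(node.left);
            }
        }
        return total;
    }
}
```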

The reason is that by using a stack we are effectively emulating the language's own mechanism for function calls. To track function calls during execution, virtually every running program uses a structure called the call stack. As the name suggests, this is a stack (LIFO) data structure. Each entry, or frame, holds the state of one active function call: its arguments, its local variables and its return address, i.e. the point in the caller where execution should resume. For example, when function 1 calls function 2, the program needs to resume execution within function 1 once function 2 completes, so it must retrieve from somewhere the address of the next instruction in function 1. So when we modify our program to use a stack explicitly instead of making recursive function calls, we are in effect replicating this internal call stack mechanism.
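To see the return-address part of this in action, here is a sketch of my own (not part of the permutation code later on) that evaluates a recursive factorial with a hand-built call stack. Each hypothetical Frame records its argument n plus a resumePoint field that plays the role of the return address: it says whether the frame is starting fresh or resuming after its "recursive call" has returned. A separate returnValue variable carries each callee's result back to its caller.

```java
import java.util.ArrayDeque;
import java.util.Deque;

// Sketch: emulating the call stack for factorial(n) = n * factorial(n - 1).
// FactorialCallStack, Frame and resumePoint are hypothetical names.
final class FactorialCallStack {
    private static final class Frame {
        final int n;     // the "argument" of this call
        int resumePoint; // 0 = start of the function, 1 = after the recursive call returns

        Frame(int n) {
            this.n = n;
        }
    }

    static long factorial(int n) {
        Deque<Frame> callStack = new ArrayDeque<>();
        callStack.push(new Frame(n));
        long returnValue = 0; // carries each callee's result back to its caller

        while (!callStack.isEmpty()) {
            Frame frame = callStack.peek();
            if (frame.resumePoint == 0) {
                if (frame.n <= 1) {
                    // Base case: "return 1" pops this frame.
                    returnValue = 1;
                    callStack.pop();
                } else {
                    // Record where to resume, then "call" factorial(n - 1).
                    frame.resumePoint = 1;
                    callStack.push(new Frame(frame.n - 1));
                }
            } else {
                // Resumed after the callee returned: finish n * factorial(n - 1).
                returnValue = frame.n * returnValue;
                callStack.pop();
            }
        }
        return returnValue;
    }
}
```

The resume point is needed here because the caller still has work to do (the multiplication) after its callee returns.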

Diving in

Musing on this, I decided to take a recursive algorithm and make it non-recursive by implementing the call stack myself, to better understand how it works. The problem I've chosen is finding all the distinct permutations of a string, that is, every string that can be obtained by rearranging its characters.

The recursive approach is shown below. The idea is that Perm(abc) = (a concat Perm(bc)) + (b concat Perm(ac)) + (c concat Perm(ab)), where "a concat Perm(bc)" means a prepended to each permutation of bc, and + means set union.

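import java.util.ArrayList;
import java.util.HashSet;
import java.util.List;
import java.util.Set;

// Entry point: converts the string into a list of characters and starts the recursion.
public static Set<String> getPermutations(String input) {
    List<Character> allCharacters = new ArrayList<>();
    for (Character c : input.toCharArray()) {
        allCharacters.add(c);
    }

    return getPermutationsInternal(allCharacters, "");
}

// Returns every permutation of allCharacters, each preceded by prefix.
// A Set is used so that strings with repeated characters don't yield duplicates.
private static Set<String> getPermutationsInternal(List<Character> allCharacters, String prefix) {
    Set<String> result = new HashSet<>();
    // Base case: one character left, so there is exactly one arrangement.
    if (allCharacters.size() == 1) {
        result.add(prefix + allCharacters.get(0).toString());
        return result;
    }

    // Recursive case: fix each character in turn as the next one,
    // then permute the remaining characters.
    for (int i = 0; i < allCharacters.size(); i++) {
        char prefixChar = allCharacters.get(i);
        List<Character> characters = new ArrayList<>();
        for (int j = 0; j < allCharacters.size(); j++) {
            if (j != i) {
                characters.add(allCharacters.get(j));
            }
        }
        result.addAll(getPermutationsInternal(characters, prefix + prefixChar));
    }

    return result;
}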

Simple enough. Now to make this non-recursive. The first step is to identify what information needs to be stored in a stack entry. The getPermutationsInternal() method above has two input parameters, and each emulated call needs access to them, so a stack frame should look something like this:

// One emulated call of getPermutationsInternal(): it holds the method's two arguments.
private static class StackFrame {
    final List<Character> allCharacters;
    final String prefix;

    public StackFrame(List<Character> allCharacters, String prefix) {
        this.allCharacters = allCharacters;
        this.prefix = prefix;
    }
}

Next, we modify getPermutationsInternal() to read its arguments from a stack frame. We also need to pass it the call stack, so that it can push additional frames onto it.

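import java.util.Stack;

// Processes one emulated call. Instead of calling itself recursively, it pushes
// a new frame for each sub-problem; only the base case returns any results.
private static Set<String> getPermutationsInternal(StackFrame sf, Stack<StackFrame> callStack) {
    // Unpack the "arguments" from the frame.
    List<Character> allCharacters = sf.allCharacters;
    String prefix = sf.prefix;
    Set<String> result = new HashSet<>();

    // Base case: unchanged from the recursive version.
    if (allCharacters.size() == 1) {
        result.add(prefix + allCharacters.get(0).toString());
        return result;
    }

    for (int i = 0; i < allCharacters.size(); i++) {
        char prefixChar = allCharacters.get(i);
        List<Character> characters = new ArrayList<>();
        for (int j = 0; j < allCharacters.size(); j++) {
            if (j != i) {
                characters.add(allCharacters.get(j));
            }
        }
        // Was: result.addAll(getPermutationsInternal(characters, prefix + prefixChar));
        callStack.push(new StackFrame(characters, prefix + prefixChar));
    }

    // Non-base-case frames contribute nothing directly; their child frames do.
    return result;
}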

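Critically, we have changed the statements where the method called itself into stack push operations, emulating what the method calls would have done internally. Notice, though, that the frames carry no return address. In the recursive version each call merges its children's results into its own set, but here every base case hands its permutation straight back to the driver loop, so no emulated call ever needs to resume its caller.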

Finally, we need code that prepares the initial state of the call stack and then repeatedly pops a frame and calls getPermutationsInternal() with it.

// Driver loop: seeds the stack with the initial "call", then keeps popping
// frames and processing them until no emulated calls remain.
public static Set<String> getPermutationsNonRecursive(String input) {
    List<Character> allCharacters = new ArrayList<>();
    for (Character c : input.toCharArray()) {
        allCharacters.add(c);
    }

    Set<String> result = new HashSet<>();
    Stack<StackFrame> callStack = new Stack<>();

    // Equivalent to the initial call getPermutationsInternal(allCharacters, "").
    callStack.push(new StackFrame(allCharacters, ""));

    while (!callStack.empty()) {
        StackFrame sf = callStack.pop();
        result.addAll(getPermutationsInternal(sf, callStack));
    }

    return result;
}

Readers will observe that this is essentially an iterative depth-first search (DFS) of a graph, if you imagine each function call as a node and each push as an edge to a child call.
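To see the depth-first order directly, here is a self-contained variant of the iterative version; it is a sketch of mine rather than the code above. It swaps in java.util.ArrayDeque for the legacy Stack class, uses a Java 16+ record (the hypothetical Frame) in place of StackFrame, and collects results in a LinkedHashSet so the order in which base cases are reached is preserved. The class and method names are hypothetical.

```java
import java.util.ArrayDeque;
import java.util.ArrayList;
import java.util.Deque;
import java.util.LinkedHashSet;
import java.util.List;
import java.util.Set;

// Hypothetical variant of getPermutationsNonRecursive(): same algorithm, but it
// records the order in which complete permutations (DFS leaves) are reached.
final class PermutationsDfs {
    // Record standing in for StackFrame (requires Java 16 or later).
    private record Frame(List<Character> remaining, String prefix) {}

    static Set<String> permutationsInVisitOrder(String input) {
        List<Character> allCharacters = new ArrayList<>();
        for (char c : input.toCharArray()) {
            allCharacters.add(c);
        }

        Set<String> result = new LinkedHashSet<>(); // keeps insertion order, drops duplicates
        Deque<Frame> callStack = new ArrayDeque<>();
        callStack.push(new Frame(allCharacters, ""));

        while (!callStack.isEmpty()) {
            Frame frame = callStack.pop();
            List<Character> remaining = frame.remaining();
            if (remaining.size() == 1) {
                // Base case: a complete permutation, i.e. a leaf of the DFS tree.
                result.add(frame.prefix() + remaining.get(0));
                continue;
            }
            // "Recursive calls": one child frame per choice of next character.
            for (int i = 0; i < remaining.size(); i++) {
                List<Character> rest = new ArrayList<>(remaining);
                char chosen = rest.remove(i); // remove(int index) returns the removed element
                callStack.push(new Frame(rest, frame.prefix() + chosen));
            }
        }
        return result;
    }
}
```

For "abc" the permutations come out as cba, cab, bca, bac, acb, abc. Because the stack is LIFO, the frame pushed last (the one that chose c first) is explored all the way down before its siblings, which is exactly how an iterative DFS behaves. For "aab" only three distinct strings come out, since the set removes the duplicates.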

Coda

This was a fun little exercise, and it helped me gain a more intuitive understanding of how the call stack drives program execution. It's such a critical part of program execution, yet programmers are rarely exposed to it directly and as a result may not fully understand it. To be clear, I have only scratched the surface here, and I may explore program execution internals further in future posts.