Monte Carlo Tree Search for Code Generation using LLMs

Date: 2024-02-25

A slew of LLMs released to the public over the past ~12 months have shown surprising levels of knowledge & ability across multiple domains. Code generation in particular has received a lot of attention, with new models published almost monthly. Notable examples include OpenAI's GPT-4 & GPT-3.5, Microsoft's Phi-2, Meta's CodeLlama, and DeepSeek's DeepSeek Coder, alongside fine-tuned variants such as Magicoder and Phind-CodeLlama.

Many of these models are impressive at generating short snippets of code, and often fare well even for complex problem descriptions, albeit with some hallucinations due to the probabilistic nature of LLMs. In this blog post we explore the following: is it possible to achieve LLM code generation that is accurate, consistent, and relatively cheap to run?

We approach the problem by combining a pre-trained transformer model with an adapted version of the Monte-Carlo tree search (MCTS) algorithm to guide the generation, taking inspiration from the paper Planning with Large Language Models for Code Generation. We compare the effectiveness of this approach against direct sampling of the transformer’s unaided generations on the HumanEval and VerilogEval benchmarks and detail the process & observations gathered along the way. Jump ahead to the final discussion if you’re curious about the results!

Background

Language models work by predicting the next token (a short chunk of text such as a word, part of a word, or a symbol) that is most likely to follow a given prompt or sequence. They generate text based on probabilities and patterns derived from their training data, essentially acting as statistical predictors. It turns out that this approach works surprisingly well for generating all kinds of text, including source code for computer programs.
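
To make "predicting the next token" concrete, here is a minimal sketch of how raw model scores (logits) can be turned into a next-token choice using top-k and temperature sampling, the two knobs used in the experiments later in this post. The toy vocabulary, logits, and the `sample_next_token` helper are made up for illustration; real models do this over vocabularies of many thousands of tokens.

```python
import math
import random


def sample_next_token(logits, temperature=1.0, top_k=3, rng=random):
    """Top-k + temperature sampling over a {token: logit} dict.

    Returns the sampled token and the renormalised probabilities of the k
    candidates. `temperature` must be > 0 in this sketch.
    """
    # keep only the k highest-scoring candidate tokens
    top = sorted(logits.items(), key=lambda kv: kv[1], reverse=True)[:top_k]
    tokens = [tok for tok, _ in top]
    # temperature-scaled softmax (max subtracted for numerical stability)
    scaled = [logit / temperature for _, logit in top]
    m = max(scaled)
    weights = [math.exp(s - m) for s in scaled]
    total = sum(weights)
    probs = {tok: w / total for tok, w in zip(tokens, weights)}
    token = rng.choices(tokens, weights=[probs[t] for t in tokens], k=1)[0]
    return token, probs


# made-up logits for the token following "def add(a, b):\n    return a"
logits = {" +": 5.0, " -": 2.0, " *": 1.5, " /": 0.5, "\n": -1.0}
token, probs = sample_next_token(logits, temperature=1.0, top_k=3, rng=random.Random(0))
print(probs)  # only " +", " -" and " *" survive the top-3 cut
print(token)  # usually " +" (about 93% of the time with these logits)
```

Lowering `temperature` sharpens the distribution toward the highest-scoring token, and `top_k=1` reduces this to greedy decoding.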

But as impressive as they are, LLMs also come with their quirks and limitations. One well-known issue is the tendency for models to hallucinate and generate responses that are either subtly incorrect or outright made up. This is not what we want when generating code, where we expect the output to be both functionally correct & consistent over time.

Example of incorrect Python code output from DeepSeek Coder 6.7B (running on llama.cpp).

There are several ways we could mitigate the effect of hallucinations. We could, for instance, just naively sample the LLM multiple times for the same prompt and filter out the incorrect ones that fail compilation or a test bench. This is known as sampling + filtering (S+F), and it provides a reasonable baseline against which to compare alternative solutions.
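
A minimal sketch of S+F, assuming a `generate` callable that returns one sampled completion per call and a `passes_tests` checker. Both are hypothetical stand-ins for the LLM and the test bench; the toy "model" below returns a correct function body about 60% of the time.

```python
import random
from typing import Callable, List


def sample_and_filter(
    generate: Callable[[], str],
    passes_tests: Callable[[str], bool],
    n_samples: int = 20,
) -> List[str]:
    """Draw n_samples independent completions and keep only those that pass the tests."""
    candidates = [generate() for _ in range(n_samples)]
    return [c for c in candidates if passes_tests(c)]


# Toy stand-ins (hypothetical): a "model" that is right ~60% of the time.
rng = random.Random(0)


def toy_generate() -> str:
    return "return a + b" if rng.random() < 0.6 else "return a - b"


def toy_passes(body: str) -> bool:
    # exec is only acceptable for this toy; real harnesses sandbox execution
    namespace = {}
    exec(f"def add(a, b):\n    {body}\n", namespace)
    return namespace["add"](2, 3) == 5


survivors = sample_and_filter(toy_generate, toy_passes, n_samples=20)
print(f"{len(survivors)}/20 samples passed")
```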

Another approach is to search for the correct sequence of tokens using a tree-based search algorithm. Each strategy comes with its own pros and cons in improving LLM accuracy, and we'll explore them both in greater depth to learn exactly how and when they might be most useful.

Setting up with an LLM

One of our primary constraints for the experiment was to choose a smaller code LLM that is relatively cheap to run. While somewhat arbitrary, we defined this as any LLM that could run locally on an Apple M2 Pro with 32GB of unified memory and a 16-core GPU, which we consider a reasonably specced consumer-grade machine.

In selecting the specific language model to use, we looked at several popular open-source LLMs ranging from 2.7B to 34B parameters: Phi-2, CodeBooga-34B-v0.1, Magicoder-S-DS-6.7B, and deepseek-coder-6.7b-instruct. Our appraisal was largely qualitative: we ran each LLM through a series of prompts via the HuggingFace transformers library (example notebook), LM Studio chat, or llama.cpp's CLI.

The CodeBooga-34B model proved too large to fit in memory unless heavily quantized to 3 bits or below (Apple imposes a hard limit on GPU memory usage of roughly 70% of total memory), which degraded model quality beyond what we found acceptable. On the other hand, we struggled to get the smaller Phi-2 model to follow the output format instructions in our prompts, making it difficult to automatically evaluate its code output on benchmark test cases.

We ultimately settled on the mixed 5/6-bit Q5_K_M quantized version of deepseek-coder-6.7B-instruct due to its superior quality & inference speed. Running under llama.cpp, the model generated ~22.5 tokens/s on average, and we integrated it via the Python bindings provided by llama-cpp-python.

from llama_cpp import Llama

model = Llama(
    model_path="deepseek-coder-6.7b-instruct.Q5_K_M.gguf",
    n_gpu_layers=-1, # run the full model on GPU
    n_ctx=2048,
    n_batch=256,
    n_threads=10,
    logits_all=True, # include token probabilities in output
)

Note that llama-cpp-python does not yet support batched inference, so we had to run generations sequentially, one at a time. However, thanks to llama.cpp's default KV caching of the prompt, repeated runs of the same prompt were still processed relatively quickly.
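
Prompt caching pays off because consecutive calls share long prefixes (the few-shot prompt, and later in MCTS the partially generated program). The sketch below only illustrates the idea: when the cached token sequence and the new one share a prefix, only the remaining tokens need a forward pass. It is not llama.cpp's actual code, and the token IDs are made up.

```python
from typing import List, Sequence


def shared_prefix_len(cached: Sequence[int], new: Sequence[int]) -> int:
    """Number of leading tokens that two token-ID sequences have in common."""
    n = 0
    for a, b in zip(cached, new):
        if a != b:
            break
        n += 1
    return n


def tokens_to_evaluate(cached: Sequence[int], new: Sequence[int]) -> List[int]:
    """Tokens that still need a forward pass when the KV cache already holds `cached`."""
    return list(new[shared_prefix_len(cached, new):])


cached_ids = [1, 50, 51, 52, 53, 60]    # previous prompt + generation (made-up IDs)
new_ids = [1, 50, 51, 52, 53, 61, 62]   # same prefix, different continuation
print(tokens_to_evaluate(cached_ids, new_ids))  # → [61, 62]
```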

Prompting

To make the generations from deepseek-coder-6.7B-instruct easier to process and extract code from, we used few-shot prompting so that the model followed a structured format with consistent indentation. We relied on templating from the LLM prompting library outlines to build the final prompts.

import outlines

@outlines.prompt
def few_shot_prompt(examples, question):
    """
    Please answer the following question following the examples.
    Generate valid Python code by indenting 4 spaces always.

    {% for example in examples %}
    Question:
    ```
    {{ example.prompt }}
    ```
    Answer:
    ```
    {{ example.canonical_solution }}
    ```
    {% endfor %}

    Question:
    ```
    {{ question }}
    ```
    Answer:
    ```
    """

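Because each prompt ends with an opening code fence, the model's answer is everything it writes before the next fence. Passing the fence as a stop sequence (as the S+F code does) already truncates the output there; a small post-processing helper such as the hypothetical `extract_answer_code` below is a defensive fallback for generations that run past the fence.

```python
def extract_answer_code(generation: str, fence: str = "```") -> str:
    """Return the code the model wrote before closing the Answer fence.

    The few-shot prompt ends with an opening fence, so everything up to the
    first closing fence is the answer. Trailing whitespace is dropped, but the
    leading 4-space indentation of the function body is kept.
    """
    code = generation.split(fence, 1)[0]
    return code.rstrip()


raw = "    return sorted(set(l))\n```\nQuestion:\n```\ndef other(): ..."
print(repr(extract_answer_code(raw)))  # → '    return sorted(set(l))'
```
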
The HumanEval dataset

We chose to specifically test the model's Python proficiency via the HumanEval dataset which, though far from being a perfect measure, is an established benchmark in evaluating LLM code performance. HumanEval consists of 164 programming tasks in Python of varying difficulty, with each task containing at least an ID, prompt, reference solution, and test to appraise model output. Here is an example of what a single task in the dataset looks like:

Figure: an example task from the HumanEval dataset.

(source ResearchGate)
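
For reference, a HumanEval task is a record with the fields `task_id`, `prompt`, `canonical_solution`, `test`, and `entry_point`. The toy task below is made up (it is not from the dataset) but follows that layout, and the hypothetical `build_program` helper assembles a runnable script roughly the way the human-eval harness does: prompt, then completion, then the `check` function, then a call to `check` on the entry point. This sketch runs code with `exec` and no sandbox, so treat it as illustration only.

```python
toy_task = {
    "task_id": "Toy/0",
    "prompt": 'def add(a: int, b: int) -> int:\n    """Return the sum of a and b."""\n',
    "canonical_solution": "    return a + b\n",
    "test": "def check(candidate):\n    assert candidate(2, 3) == 5\n    assert candidate(-1, 1) == 0\n",
    "entry_point": "add",
}


def build_program(task: dict, completion: str) -> str:
    """Concatenate prompt + completion + tests + check call into one script."""
    return (
        task["prompt"] + completion + "\n"
        + task["test"] + "\n"
        + f"check({task['entry_point']})\n"
    )


def passes(task: dict, completion: str) -> bool:
    """Run the assembled script; any exception (e.g. a failed assert) counts as a fail."""
    try:
        exec(build_program(task, completion), {})
        return True
    except Exception:
        return False


print(passes(toy_task, toy_task["canonical_solution"]))  # → True
print(passes(toy_task, "    return a - b\n"))             # → False
```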

The dataset also provides the evaluate_functional_correctness script to calculate the pass@k metric for LLM-generated code samples. pass@k is defined as the probability that at least one of k generated samples for a task passes its unit tests, averaged over all tasks in the benchmark. It can be seen as the LLM's overall "problem pass rate" across the entire HumanEval problem set.
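
With n samples drawn per task, of which c pass, the unbiased estimator introduced alongside HumanEval is pass@k = 1 − C(n−c, k) / C(n, k), averaged over tasks. A minimal sketch using Python's `math.comb` (the human-eval package ships its own, numerically stabler implementation); the three tasks in the usage example are hypothetical.

```python
from math import comb


def pass_at_k(n: int, c: int, k: int) -> float:
    """Unbiased pass@k for one task: n samples drawn, c of them correct."""
    if n - c < k:
        return 1.0  # every size-k subset must contain at least one correct sample
    return 1.0 - comb(n - c, k) / comb(n, k)


def mean_pass_at_k(results, k: int) -> float:
    """Average pass@k over tasks; `results` is a list of (n, c) pairs."""
    return sum(pass_at_k(n, c, k) for n, c in results) / len(results)


# three hypothetical tasks, 20 samples each
results = [(20, 20), (20, 5), (20, 0)]
print(round(mean_pass_at_k(results, 1), 4))   # → 0.4167
print(round(mean_pass_at_k(results, 20), 4))  # → 0.6667
```

Note that pass@1 reduces to c/n, and with only 20 samples per task every k up to 20 can be estimated from the same set of generations.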

Generating S+F baselines

With the S+F strategy, we expected a higher number of samples generated per prompt to increase the LLM's accuracy. In an ideal world, we would've generated upwards of ~200 samples, but due to both hardware & time constraints, we limited the generations to 20 samples per task. This has the downside of higher variance in the ensuing pass@k calculations, but we did not expect it to be so significant as to distort the experiment results.

Below is the code for generating the S+F baselines with hyperparameters:

from humaneval import get_prompts_with_ids
from human_eval.data import write_jsonl
      
N_SAMPLES = 20
prompts_ids = get_prompts_with_ids()
# prompts are already templated to include few-shot examples
for prompt, task_id in prompts_ids:
    samples = []
    for _ in range(N_SAMPLES):
        # hyperparams: top-3 sampling, 1.0 temp, 256 max tokens
        output = model(
            prompt=prompt, max_tokens=256, temperature=1, top_k=3, stop=["```"]
        )
        res = output["choices"][0]["text"]
        item = dict(task_id=task_id, completion=res)
        samples.append(item)

    # save generated samples in jsonl out file
    write_jsonl("few_shot_baselines_256_top_3.jsonl", samples, append=True)

and the resulting pass@k metrics:

| | pass@1 | pass@5 | pass@20 | time for 20 samples |
|---|---|---|---|---|
| S+F | 62.38% | 84.60% | 90.74% | ~4h |

Discussion

Running the HumanEval-provided pass@k evaluation script on LLM-generated samples.

Code generation and tree search

Having established the baseline with S+F, we turned to an alternative approach treating LLM code generation as a search problem instead. Here the goal is to find the correct sequence of tokens that passes the tests defined for each HumanEval task.

We start by modeling the problem using a decision tree. Given the initial prompt description (the root), we recursively map out all the possible next tokens that could follow (branches or children) until a specified maximum depth is reached.

A visualisation of a small decision tree. Red represents the root node (i.e. the initial prompt), blue represents child nodes (subsequent tokens), and green represents the final set of all possible outcomes.

Finding the solution with this method simply requires selecting the best option from the final list of outcomes, but this is only feasible when the outcome space is small. A full tree with branching factor b (the number of children at each node) and depth d has b^d final outcomes, so the computation needed to map it grows exponentially with depth and polynomially with the branching factor. Even slight changes in these values can meaningfully affect the feasibility of the search.
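
A quick way to see the blow-up: a full tree with branching factor b and depth d has b^d leaves (the final outcomes) and 1 + b + ... + b^d nodes in total. The helper name below is just for illustration.

```python
def full_tree_size(branching: int, depth: int) -> tuple:
    """Return (number of leaves, total number of nodes) for a full tree."""
    leaves = branching ** depth
    total = sum(branching ** level for level in range(depth + 1))
    return leaves, total


for b in (2, 3, 4):
    print(b, full_tree_size(b, depth=10))
# 2 (1024, 2047)
# 3 (59049, 88573)
# 4 (1048576, 1398101)
```

Going from two to four children per node at the same depth multiplies the number of outcomes by roughly a thousand.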

Though difficult to estimate, we'd expect code generation to have a depth & branching factor on the order of hundreds, if not thousands (in principle, every token in the model's vocabulary is a possible child). For reference, a typical game of Go takes an average of 211 turns to complete, with an average branching factor of ~250. This equates to roughly 250^211, or on the order of 10^500, possible outcomes, already far more than can be processed in any reasonable length of time.
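
The Go figure is easy to check with logarithms, since log10(b^d) = d · log10(b); Python's arbitrary-precision integers also let us count the digits exactly.

```python
import math

avg_turns = 211   # average game length cited above
branching = 250   # average branching factor cited above

exponent = avg_turns * math.log10(branching)
print(f"{branching}^{avg_turns} ≈ 10^{exponent:.0f}")  # → 250^211 ≈ 10^506

# exact digit count of the integer 250**211
print(len(str(branching ** avg_turns)))  # → 506
```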

It is thus impractical to apply decision trees to code generation without also adopting some kind of strategy to significantly reduce the branching factor & resulting search space. This is precisely the purpose of MCTS, which applies the Upper Confidence Bound heuristic to identify the "winning" branches in a tree without computing the entire tree.

Example of MCTS. Note the differing depths of the various subtrees.

The result is an asymmetric expansion of the tree, where only a small subset of all possibilities is computed, balancing exploitation of known good options with exploration of unknown, potentially better ones. This makes it possible to search even very large trees effectively within a reasonable length of time.
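
For comparison with the LLM-specific variant used later in this post, the classic UCT score (UCB1 applied to trees) adds an exploration bonus c · sqrt(ln N / n) to a child's average reward, where N is the parent's visit count and n the child's. A minimal sketch, with c = sqrt(2) as a common textbook default rather than a value from this experiment:

```python
import math


def uct_score(child_value_sum: float, child_visits: int, parent_visits: int,
              c: float = math.sqrt(2)) -> float:
    """UCB1 score used by classic MCTS (UCT); unvisited children are tried first."""
    if child_visits == 0:
        return math.inf
    exploit = child_value_sum / child_visits                            # average reward so far
    explore = c * math.sqrt(math.log(parent_visits) / child_visits)    # shrinks with visits
    return exploit + explore


# a well-explored strong child vs. a rarely tried weaker one
print(uct_score(child_value_sum=7.0, child_visits=10, parent_visits=12))
print(uct_score(child_value_sum=0.5, child_visits=1, parent_visits=12))
```

Here the rarely tried child scores higher despite its lower average reward, which is exactly the exploration pressure that keeps the search from fixating on one branch.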

Adopting MCTS with LLMs

MCTS works by repeatedly cycling through four logical stages: selection, expansion, evaluation (or simulation), and back-propagation. Unlike traditional MCTS, however, we rely on the probabilities/predictions of an LLM to carry out these steps.

High-level diagram of steps in conventional Monte-Carlo tree search (source Wikipedia).

In the selection phase, the MCTS algorithm decides which node or branch to explore/exploit next by assigning a kind of "priority" to each node. Greater weight is typically given to nodes that either have a higher value (currently observed performance) or fewer visits (i.e. less explored, and hence more uncertain). For our case of LLM code generation, this priority is calculated as:

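$$\text{P-UCB}(s, a) = Q(s, a) + \beta(s) \cdot P(a \mid s) \cdot \frac{\sqrt{\log N(s)}}{1 + N(s, a)}, \qquad \beta(s) = \log\left(\frac{N(s) + c_{\text{base}} + 1}{c_{\text{base}}}\right) + c$$

where s is the current (parent) node, a a candidate next token, Q(s, a) the value of the child node reached by a, P(a | s) the LLM's probability for token a, N(s) and N(s, a) the visit counts of the parent and the child, and c_base and c exploration constants,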

and the node with the highest priority is selected:

$$a^{*} = \arg\max_{a} \; \text{P-UCB}(s, a)$$

This formulation adapts UCB with transformer probabilities to weight the exploration, as described in Section D.1 of Planning with Large Language Models for Code Generation. The choice of which nodes to explore is a function of three main parameters: the node value Q (to be elaborated further below), the node token probability P (provided by the LLM), and an exploration term that captures how often we've previously visited the node.
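
To see how the token probability steers exploration, here is a standalone sketch of the P-UCB score with made-up numbers. The constants `C_BASE` and `C` are placeholders (this post does not state the values used), while the score itself follows the P-UCB rule described in this section.

```python
import math

C_BASE = 10.0  # placeholder exploration constant (assumption)
C = 4.0        # placeholder exploration constant (assumption)


def p_ucb(value: float, prob: float, child_visits: int, parent_visits: int) -> float:
    """Node value plus an exploration bonus weighted by the LLM's token probability."""
    beta = math.log((parent_visits + C_BASE + 1) / C_BASE) + C
    return value + beta * prob * math.sqrt(math.log(parent_visits)) / (1 + child_visits)


parent_visits = 5
children = {
    # token: (value Q, LLM probability P, visits)
    " return": (0.50, 0.80, 3),
    " if": (0.00, 0.15, 0),
    " for": (0.00, 0.05, 0),
}
scores = {tok: p_ucb(q, p, n, parent_visits) for tok, (q, p, n) in children.items()}
print({tok: round(s, 2) for tok, s in scores.items()})
# → {' return': 1.63, ' if': 0.85, ' for': 0.28}
print(repr(max(scores, key=scores.get)))  # → ' return'
```

Unvisited but low-probability children receive a proportionally small bonus, so they are explored only after their high-probability siblings have accumulated visits. This is why miscalibrated probabilities can keep the search away from the right branch.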

Once a node has been selected, it is then expanded with k new child nodes in the expansion phase. Because only a single branch is selected at any given point in time, MCTS develops the search tree asymmetrically by disregarding all other nodes. In this experiment, the child nodes are obtained from the top 3 most probable next token options given by an LLM.

Expansion is followed by evaluation. Here the value of the node is determined via the reward or value function: a quantitative measure for ranking nodes, such that choosing higher-value nodes yields better end outcomes. However, each node in the tree represents an incomplete intermediate state that is often difficult to evaluate directly. We therefore estimate a node's "expected" end outcome by simulating the remaining branch expansions.

We achieve this simulation by having an LLM complete the remaining sequence (program) from the node's partial state. The completed program is then run against the task's tests from the HumanEval dataset to obtain a pass rate between 0 and 1, which becomes the node's value. Code that passes more tests receives a higher value and is consequently explored more in future iterations through the tree.

But due to the nested nature of search trees, any high-value nodes discovered deep in the hierarchy might not be reached again in successive searches (e.g. if the previous parent branches are never chosen). To ensure that this doesn't happen, a node's reward value is propagated to all its parents up to the tree's root. This is the final back-propagation phase that completes a single iteration of MCTS, and once finished we return to the top (root) of the tree to restart the cycle.

Coding the MCTS algorithm

We use a simple Node class to represent the nodes of the Monte-Carlo tree. Remember that a node is equivalent to a branch: it represents a possible token (a chunk of text) that could come next in the generated sequence. Each node stores its generation state, number of visits, value, LLM token probability, and children.

class Node:
    def __init__(self, prob, state, parent):
        self.value = 0  # max reward obtainable from node
        self.prob = prob  # input for the P-UCB calculation
        self.state = state  # full generated text sequence up til node
        self._children = []
        self._parent = parent
        self.visits = 0

    def backprop(self, value):
        # only propagate if new reward is greater than current max
        if value > self.value:
            self.value = value
            if self._parent is not None:
                self._parent.backprop(value)

For each HumanEval task, the tree is expanded and evaluated for a fixed number of iterations (called rollouts). In each rollout we perform selection, expansion, evaluation, and back-propagation, and at the end of all iterations return the best found solution.

from math import exp

max_rollouts = 128 # max number of iterations through tree
top_k = 3 # number of child tokens (branches) to generate per node
# prompts: (prompt, task_id) pairs, e.g. from get_prompts_with_ids()
for prompt, task_id in prompts:
    # cache of generated programs => rewards
    program_dict = {}
    # initialise tree with HumanEval prompt
    root = Node(prob=1, state=prompt, parent=None)

    for i in range(max_rollouts):
        curr_node = root # always start new iteration from root
        curr_node.visits += 1

We start by continuously selecting the best nodes in the tree according to the P-UCB metric explained earlier, until we arrive at a leaf node that has not yet been expanded & evaluated.

        # selection: descend via max P-UCB until reaching an unexpanded leaf
        while len(curr_node._children) > 0:
            curr_node = p_ucb_select(curr_node, curr_node._children)
            curr_node.visits += 1

Once we've landed on a candidate leaf node, we expand it by requesting the top-k next tokens from the LLM and creating a child node for each. Since the LLM returns log-probabilities, we apply exp() to convert them back into probabilities.

        # expansion
        tokens = get_top_k_tokens(curr_node, top_k) # top_k = 3
        child_nodes = [
            Node(exp(logprob), state=(curr_node.state + token), parent=curr_node)
            for (token, logprob) in tokens
        ]
        curr_node._children = child_nodes

We then evaluate the selected node by generating a program from its current state. This is done using a greedy search, but can easily be replaced with a beam search or similar algorithm (we believe that code generation should be simulated using deterministic search as opposed to non-deterministic sampling, but this should be tested).
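
Greedy simulation appends the single most probable token at every step until a stop sequence or the token budget is reached. The sketch below is a hypothetical `greedy_complete` that takes a `next_token_logprobs(text)` function in place of the LLM (here a toy lookup table), so it runs on its own; in the post, `llm_generate` plays this role. Treating the stop sequence as a single token is a simplification.

```python
from typing import Callable, Dict


def greedy_complete(
    state: str,
    next_token_logprobs: Callable[[str], Dict[str, float]],
    stop: str = "```",
    max_tokens: int = 256,
) -> str:
    """Extend `state` by always taking the highest-probability next token."""
    text = state
    for _ in range(max_tokens):
        candidates = next_token_logprobs(text)
        if not candidates:
            break
        token = max(candidates, key=candidates.get)  # deterministic argmax
        if token == stop:
            break
        text += token
    return text[len(state):]  # only the newly generated continuation


# toy "model": a fixed table from the text so far to next-token log-probs
TOY_TABLE = {
    "def inc(x):\n": {"    return": -0.1, "    pass": -2.5},
    "def inc(x):\n    return": {" x + 1": -0.2, " x": -1.9},
    "def inc(x):\n    return x + 1": {"```": -0.05, "\n": -3.0},
}
completion = greedy_complete("def inc(x):\n", lambda t: TOY_TABLE.get(t, {}))
print(repr(completion))  # → '    return x + 1'
```

Because the argmax is deterministic, the same state always yields the same program, which is what makes caching generated programs worthwhile.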

The generated program is evaluated against the respective tests from HumanEval to obtain the node's value.

        # evaluation
        reward = match_cached_programs(curr_node.state, program_dict)
        # only run generation if node state not found in cached programs
        if reward == -1:
            generated_program = llm_generate(curr_node.state)
            # run generated program against HumanEval test cases
            reward = calculate_reward(task_id, generated_program)
            # cache generated program and its corresponding pass rate
            program_dict[generated_program] = reward
          

Finally, we back-propagate the reward up the tree to all its parents (including the root). We stop when a reward of 1.0 is found, as this represents a fully correct solution and thus there is no more exploration to do.

        # backprop reward up the tree
        curr_node.backprop(reward)

        # early termination if correct program is found
        if reward == 1:
            break

P-UCB selection:

from math import inf, log, sqrt

# exploration constants for beta(s); placeholder values, not stated in this post
c_base = 10
c = 4

def p_ucb_select(parent_node, child_nodes):
    s_visits = parent_node.visits
    # visit-dependent exploration weight beta(s)
    beta = log((s_visits + c_base + 1) / c_base) + c

    # find the child node with the highest P-UCB value
    max_p_ucb = -inf
    max_node = None
    for node in child_nodes:
        # value term + exploration bonus weighted by the LLM token probability
        p_ucb = node.value + beta * node.prob * sqrt(log(s_visits)) / (
            1 + node.visits
        )
        if p_ucb > max_p_ucb:
            max_node = node
            max_p_ucb = p_ucb
    return max_node # return max(P-UCB) node

Node expansion:

# fetch the top 3 highest probability token candidates from the LLM
def get_top_k_tokens(curr_node, k=3):
    output = model(prompt=curr_node.state, max_tokens=1, temperature=1, logprobs=k)
    output_probs = output["choices"][0]["logprobs"]["top_logprobs"][0]
    return output_probs.items()

Reward function:

def calculate_reward(task_id, completion, timeout=10):
    problem = human_eval_problems[task_id]
    split_tests = problem["test"]
    results = []
    for test in split_tests:
        res = check_correctness(test, completion, timeout)
        results.append(res["passed"])

    return sum(results) / len(results) # set test pass rate as reward
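
`calculate_reward` iterates over `problem["test"]` as if each task's tests had already been split into separate units. In the raw HumanEval data, `test` is a single string defining `check(candidate)` with several asserts, and the post doesn't show how it was split. Here is one hypothetical way to do it; it only handles single-line `assert` statements and drops any other lines inside `check`, so it is an approximation.

```python
def split_check_function(test_src: str) -> list:
    """Turn one `check(candidate)` with N single-line asserts into N check functions."""
    lines = test_src.splitlines()
    start = next(i for i, line in enumerate(lines) if line.startswith("def check("))
    preamble = "\n".join(lines[:start])  # e.g. METADATA or helper imports
    asserts = [line for line in lines[start + 1:] if line.strip().startswith("assert ")]
    return [f"{preamble}\ndef check(candidate):\n{line}\n" for line in asserts]


test_src = (
    "METADATA = {}\n\n"
    "def check(candidate):\n"
    "    assert candidate(1) == 2\n"
    "    assert candidate(-1) == 0\n"
)
pieces = split_check_function(test_src)
print(len(pieces))  # → 2

# score a buggy candidate (x -> 2x) piece by piece to get a partial pass rate
results = []
for piece in pieces:
    namespace = {}
    exec(piece, namespace)  # toy only: no sandboxing
    try:
        namespace["check"](lambda x: 2 * x)
        results.append(True)
    except AssertionError:
        results.append(False)
print(sum(results) / len(results))  # → 0.5
```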

Visualizing the MCTS Tree

Below is the tree graph for HumanEval-113 using MCTS. Each time step represents a single iteration (rollout) of the MCTS algorithm, with the nodes selected at each iteration colored in red. For this particular problem, the final solution was found on the 13th rollout.

Beating the baseline on HumanEval

Recall that S+F achieved the following metrics with 2048 context size, top-3 sampling, 1.0 temperature, and 256 max tokens:

| | pass@1 | pass@5 | pass@20 | total time (20 samples/task) |
|---|---|---|---|---|
| S+F | 62.38% | 84.60% | 90.74% | ~4h |

Below are the results for MCTS with 128 max rollouts, 2048 context, top-3 probabilities, 1.0 temperature, and 256 max tokens:

| | pass rate | avg. unique generations | avg. rollouts | time |
|---|---|---|---|---|
| MCTS | 92.59% | 3.62 | 15.46 | ~1h 30m |

Discussions

  1. Technically, the avg. rollouts rather than no. of unique generations should be used to compare against the number of samples generated per problem. But in practice, many rollouts explore the same solution and thus much of the compute can be saved via caching of previously generated programs, reward calculations and token probabilities.
  2. One may suggest that we could also stop generation in S+F as soon as a correct solution is found. However, because sampling is probabilistic and cannot be replicated across successive generations, we need the full set of samples for every problem to properly estimate pass@k for various values of k. Alternatively, we could sample until success and compute the empirical expected k to compare with the average rollouts of MCTS, but we settle for estimating pass@k.
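
For the alternative mentioned in the second point, one rough proxy that reuses existing S+F data is to treat each draw as an independent trial with per-task success probability p ≈ c/n, so the expected number of samples until the first success is 1/p (the mean of a geometric distribution). This is only an approximation: it plugs in a noisy estimate of p and is undefined for tasks with no passing samples. The task counts below are made up.

```python
import math


def expected_samples_until_success(n: int, c: int) -> float:
    """Geometric-distribution estimate 1/p with p = c/n; inf if nothing passed."""
    return math.inf if c == 0 else n / c


per_task = [(20, 20), (20, 10), (20, 4), (20, 0)]  # hypothetical (n, c) per task
estimates = [expected_samples_until_success(n, c) for n, c in per_task]
solved = [e for e in estimates if math.isfinite(e)]
print(estimates)                             # → [1.0, 2.0, 5.0, inf]
print(round(sum(solved) / len(solved), 2))  # → 2.67
```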

Extending the experiment to VerilogEval

Using MCTS on the HumanEval dataset proved quite successful, outperforming the roughly equivalent 5-sample S+F baseline by a decent margin, and even the 20-sample baseline, at much lower compute cost. But given the raw strength of deepseek-coder-6.7B-instruct on HumanEval and Python, applying MCTS didn't significantly improve absolute quality over the best S+F results.

We thus wanted to know what would happen with problems the LLM wasn't well trained on, and how effectively MCTS could improve the model's code quality compared to standalone S+F. To this end, we adapted our experiment to the VerilogEval dataset, which follows a format very similar to HumanEval's but tests proficiency in the lesser-known Verilog language instead of Python.

A typical task within VerilogEval (source VerilogEval). Tests are also present within the dataset but aren't shown in the diagram above.

Because we expected the deepseek LLM to be less capable in Verilog, we needed to adjust a few generation parameters before running both S+F and MCTS. We increased the context length & max number of generated tokens to account for the greater verbosity of Verilog code. For S+F, we also increased k in top-k sampling from 3 to 50 to enable more diverse model outputs, and for MCTS we expanded the top 5 tokens at each node.

S+F with 4096 context size, top-50 sampling, 1.0 temperature, and 1024 max tokens:

| | pass@1 | pass@5 | pass@20 | time for 20 samples |
|---|---|---|---|---|
| S+F | 29.71% | 42.58% | 51.30% | ~7h |

MCTS with 128 max rollouts, 4096 context, top-5 probabilities, 1.0 temperature, 1024 max tokens:

| | pass rate | avg. unique generations | avg. rollouts | time |
|---|---|---|---|---|
| MCTS | 46.75% | 18.18 | 73.46 | ~8h 30m |

Discussion

What have we learned?

Our results from comparing S+F performance against MCTS in both the HumanEval and VerilogEval datasets show that when dealing with domains where an LLM is relatively proficient (e.g. Python), combining the LLM's predictive capabilities with a specialised algorithm such as MCTS does indeed produce higher quality results with significantly less compute.

The story changes however as we transition to areas where the LLM is much weaker (e.g. Verilog). In this scenario, MCTS doesn't do well due to its reliance on the LLM’s (usually incorrect) token probabilities to determine the next branches to explore. On the other hand, S+F enables a more consistent boost in LLM performance despite being the simpler approach.

It's worth noting that for Verilog code generation, the overall pass rate remained low irrespective of the approach. This highlights the importance of fine-tuning language models for specific tasks as one of the most effective means to improve accuracy. Even the reasonably small deepseek-coder-6.7b-instruct model showed impressive quality and speed, despite running on a single consumer machine at 5-bit quantisation.

Future considerations

Although MCTS' good performance on HumanEval didn't carry through to the VerilogEval dataset, we believe there are opportunities to make it more robust to fluctuations in the underlying LLM ability by tweaking its reward function & selection/expansion policies:

  1. Reward: Picking a better heuristic to judge the validity of a partial program would guide the search more effectively. For example, rather than relying solely on tests to assign reward values, we could apply negative values for nodes that fail compilation, or that represent tokens outside a pre-specified set of acceptable grammar.
  2. Selection: As the model's token probabilities become more miscalibrated the harder the domain, we can lower the weight given to the transformer probability during node selection.
  3. Expansion: Similar to selection, unreliable token probabilities can lead MCTS to expand potentially unproductive branches. One simple improvement would be to expand a larger number of nodes at each iteration (e.g. top-50 instead of top-3), but this doesn't always lead to better exploration if only a small number of nodes have outsized likelihoods relative to the rest. It might instead be worth trying to sample from rather than directly choose the top-k probability nodes, even if it breaks the deterministic guarantee of MCTS.
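
A minimal sketch of how these three ideas might look in code. Everything here is hypothetical and untested against the experiments above: `shaped_reward` uses Python's built-in `compile` as a cheap syntax check (for Verilog one would need an external compiler instead), `tempered_priors` flattens the LLM probabilities with an exponent `alpha` < 1 before they enter P-UCB, and `sample_children` draws k distinct tokens in proportion to their probability instead of always taking the top k.

```python
import random
from typing import Dict, List, Optional


def shaped_reward(program: str, test_pass_rate: Optional[float]) -> float:
    """-1 for code that doesn't even parse; otherwise the test pass rate (0 if no tests ran)."""
    try:
        compile(program, "<candidate>", "exec")
    except SyntaxError:
        return -1.0
    return test_pass_rate if test_pass_rate is not None else 0.0


def tempered_priors(probs: Dict[str, float], alpha: float = 0.5) -> Dict[str, float]:
    """Raise probabilities to the power alpha (< 1 flattens them) and renormalise."""
    powered = {tok: p ** alpha for tok, p in probs.items()}
    total = sum(powered.values())
    return {tok: v / total for tok, v in powered.items()}


def sample_children(probs: Dict[str, float], k: int, rng: random.Random) -> List[str]:
    """Draw k distinct tokens without replacement, each draw weighted by probability."""
    pool = dict(probs)
    chosen = []
    while pool and len(chosen) < k:
        tokens = list(pool)
        tok = rng.choices(tokens, weights=[pool[t] for t in tokens], k=1)[0]
        chosen.append(tok)
        del pool[tok]
    return chosen


probs = {"assign": 0.90, "always": 0.06, "wire": 0.03, "reg": 0.01}
print(shaped_reward("def f(:\n    pass", None))            # → -1.0
print(tempered_priors(probs)["always"] > probs["always"])  # → True
print(sample_children(probs, k=2, rng=random.Random(0)))
```

Note that with the max-based `backprop` in the Node class (values start at 0), a negative reward would never propagate, so reward shaping of this kind would also need a different initial value or backup rule.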

On a final note, going beyond pure MCTS, we are cautiously optimistic about the general approach of interweaving LLM capabilities with traditional/alternative algorithms (or even other LLMs) to achieve greater robustness & accuracy. Recent research on combining LLMs with evolutionary algorithms, self-rewarding LLMs, and other systemic strategies shows promise in this direction, and we expect to continue developing both MCTS & these ideas in future work.