Random Forest ML on GPU

Pascal’s scribbles blog April 17, 2026

In my recent post on Rastair, a bioinformatics tool that I'm currently working on, we looked at some performance best practices and optimizations. One of the slowest parts of the tool is running machine-learning inference on a Random Forest model. In this post, I want to describe what we use ML for, and how we moved inference to a GPU compute shader to make it fast.

Context

Rastair needs to make decisions on error-prone data: We analyze short sequencing reads that have been stochastically aligned to a reference genome, and together they give us many points of evidence for which base a sample actually has at a given position. This evidence includes quality metrics (per read and per position), information on possible insertions and deletions, and many other flags. We also know the general error rate of the sequencing instruments used. But some of the evidence we look at can be interpreted in two different ways.

The typical case in Rastair is this: We have a position that is C in the reference and 30 reads covering it. Some show C (agreeing with the reference), but some show T. Since Rastair deals with TAPS sequencing data, we know that a change from C to T can also be evidence of methylation[^m]. So we need to decide: Is this a variant, a methylated position, or can it even be both?

[^m]: Simply put: A flag on a position.

For some cases we can decide fairly directly: We know, e.g., that the C-to-T conversion only happens in CG contexts and on the original top strand (a flag on the read). This helps exclude cases, but it doesn't make us certain whether a position is truly one or the other.
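The two hard rules above can be sketched as a small predicate. This is an illustration under stated assumptions, not Rastair's code: the function name, the byte-slice reference, and the `original_top` flag are hypothetical, and everything else Rastair weighs (qualities, indels, other read flags) is ignored.

```rust
/// Hypothetical hard filter: can an observed T at `pos` be methylation evidence?
/// Encodes only the two rules from the text:
/// 1. the reference has a C followed by a G at `pos` (CG context), and
/// 2. the read comes from the original top strand.
fn could_be_methylation(reference: &[u8], pos: usize, observed: u8, original_top: bool) -> bool {
    let in_cg_context = reference.get(pos..pos + 2) == Some(&b"CG"[..]);
    in_cg_context && observed == b'T' && original_top
}
```

Anything that passes such a filter is still ambiguous; the filter only removes cases where methylation is impossible.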

Using ML

We can filter out some obvious cases with hard thresholds, but for the ambiguous positions we want something that can weigh many pieces of evidence at once. This is where machine learning comes in.

Benjamin designed a set of features that capture the relevant information about a position: things like base quality, mapping quality, depth ratios, the surrounding sequence context, and so on. We feed these into a Random Forest model that outputs a score for each position, which we then convert into a probability using Platt scaling (also Benjamin's work).
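Platt scaling itself is small: a logistic curve fitted to the model's raw scores. The sketch below uses the standard parameterization P(y = 1 | f) = 1 / (1 + exp(A*f + B)). The function name is hypothetical, and the values of `a` and `b` used with it are placeholders, since the fitted values Rastair uses aren't given here.

```rust
/// Platt scaling: map a raw model score to a probability,
/// P(y = 1 | score) = 1 / (1 + exp(a * score + b)).
/// `a` and `b` are fitted on held-out data (placeholders here).
fn platt_probability(score: f64, a: f64, b: f64) -> f64 {
    1.0 / (1.0 + (a * score + b).exp())
}
```

`a` and `b` are usually found by fitting a logistic regression of the true labels on held-out scores; with a negative `a`, higher forest scores map to higher probabilities.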

We actually run three separate models: one for methylation in CG context, one for de-novo methylation[^denovo], and one for everything else (variants). Each model is trained on data where we know the ground truth.

[^denovo]: A position where a variant creates a new CG site. For example, a G is changed to a C and the following base is also a G. In this sample, there is now a CG where the reference has GG. This CG can be methylated.
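To make the de-novo case concrete, here is a hypothetical check for whether a single-base substitution creates a CG that the reference lacks. It covers the footnote's example (a new C in front of a reference G) and the mirror case (a new G after a reference C). The function names are invented for illustration.

```rust
/// True if `s` contains the dinucleotide "CG" starting at index `i`.
fn has_cg_at(s: &[u8], i: usize) -> bool {
    s.get(i..i + 2) == Some(&b"CG"[..])
}

/// Hypothetical check: does substituting `alt` at `pos` create a CG that the
/// reference does not already have at the same place?
fn creates_new_cg(reference: &[u8], pos: usize, alt: u8) -> bool {
    if pos >= reference.len() {
        return false;
    }
    // Apply the single-base substitution to a copy of the reference.
    let mut sample = reference.to_vec();
    sample[pos] = alt;
    // The changed base can be the C of a CG (starting at pos)
    // or the G of a CG (starting at pos - 1).
    (pos.saturating_sub(1)..=pos).any(|i| has_cg_at(&sample, i) && !has_cg_at(reference, i))
}
```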

What's a Random Forest

A Random Forest is an ensemble of decision trees. A single decision tree is simple: at each node, you look at one feature, compare it to a threshold, and go left or right. When you reach a leaf, you get a prediction. Easy to understand, fast to evaluate. But a single tree tends to "overfit"[^of].

[^of]: It learns the training data too well and then makes poor predictions on new data because it memorized noise rather than the underlying pattern.
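A single tree can be sketched as a recursive enum, the same heap-allocated shape that comes up again when trees are flattened for the GPU. This is a simplified stand-in rather than biosphere's actual types, and the convention of going left when `x[feature] <= threshold` is an assumption.

```rust
/// One decision tree: internal nodes compare one feature to a threshold,
/// leaves hold a prediction.
enum Tree {
    Split { feature: usize, threshold: f32, left: Box<Tree>, right: Box<Tree> },
    Leaf(f32),
}

impl Tree {
    /// Walk from the root to a leaf and return its prediction.
    fn predict(&self, x: &[f32]) -> f32 {
        let mut node = self;
        loop {
            match node {
                Tree::Leaf(value) => return *value,
                Tree::Split { feature, threshold, left, right } => {
                    // Assumed convention: go left when the feature is <= the threshold.
                    node = if x[*feature] <= *threshold { &**left } else { &**right };
                }
            }
        }
    }
}
```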

The "forest" part fixes this: you train many trees (e.g., 400), each on a slightly different random subset of the training data and a random subset of features. At inference time, you run the input through all the trees and average their predictions. This is surprisingly effective[^rf] and gives you a model that generalizes well without needing a lot of tuning.

[^rf]: It's one of those techniques that is easy to underestimate because it's conceptually simple. I had never used this before this project and I'm amazed how quickly it yielded great results.
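The two sources of randomness and the averaging step can be sketched with the standard library alone. `XorShift64` is a stand-in for a proper RNG crate, and all names are hypothetical. Note that the classic Random Forest algorithm draws a fresh feature subset at every split rather than once per tree; the sketch only shows the sampling step itself.

```rust
/// Tiny xorshift64 generator so the sketch needs no external crate.
/// The seed must be non-zero.
struct XorShift64(u64);

impl XorShift64 {
    fn next_u64(&mut self) -> u64 {
        let mut x = self.0;
        x ^= x << 13;
        x ^= x >> 7;
        x ^= x << 17;
        self.0 = x;
        x
    }
    /// Index in 0..n (modulo bias ignored for brevity).
    fn below(&mut self, n: usize) -> usize {
        (self.next_u64() % n as u64) as usize
    }
}

/// Bootstrap sample: draw `n_rows` row indices with replacement.
fn bootstrap_rows(n_rows: usize, rng: &mut XorShift64) -> Vec<usize> {
    (0..n_rows).map(|_| rng.below(n_rows)).collect()
}

/// Pick `k` distinct feature indices with a partial Fisher-Yates shuffle.
fn feature_subset(n_features: usize, k: usize, rng: &mut XorShift64) -> Vec<usize> {
    let k = k.min(n_features);
    let mut idx: Vec<usize> = (0..n_features).collect();
    for i in 0..k {
        let j = i + rng.below(n_features - i);
        idx.swap(i, j);
    }
    idx.truncate(k);
    idx
}

/// Inference: the forest's output is the mean of the per-tree predictions.
fn forest_mean(per_tree: &[f32]) -> f32 {
    per_tree.iter().sum::<f32>() / per_tree.len() as f32
}
```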

For our use case, this is a good fit: The model is fast to evaluate (just comparisons and a mean), it handles mixed feature types well, and it doesn't require a GPU or a deep learning framework to train, just some data in arrays.

There are many Rust crates that can train and run a Random Forest. We ended up using the biosphere crate because it seemed simple and focused enough while also being quite fast.

How to verify

How do we know the model actually makes good calls? We compare against Genome in a Bottle (GIAB), a well-characterized reference dataset that serves as a "ground truth" for benchmarking variant callers. Benjamin wrote R scripts to evaluate our calls against this reference, and I later ported that comparison to Rust. It works by reading in the VCF file from GIAB and the one that Rastair produces, and counting overlapping calls (true positives), false positives, and false negatives.
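A minimal version of that comparison, assuming exact matching on (CHROM, POS, REF, ALT): parse each VCF data line into a set entry, then count intersections and differences. This is a sketch, not the ported Rastair code; real benchmarking also normalizes variant representations, handles multi-allelic sites, and restricts the comparison to high-confidence regions.

```rust
use std::collections::HashSet;

/// A call reduced to the fields used for matching (an assumed simplification).
#[derive(Clone, Debug, PartialEq, Eq, Hash)]
struct Call {
    chrom: String,
    pos: u64,
    reference: String,
    alt: String,
}

#[derive(Debug, PartialEq)]
struct Counts {
    true_pos: usize,  // in both truth and calls
    false_pos: usize, // called, but not in truth
    false_neg: usize, // in truth, but not called
}

/// Parse CHROM, POS, REF and ALT from a VCF data line; header lines yield None.
fn parse_vcf_line(line: &str) -> Option<Call> {
    if line.starts_with('#') {
        return None;
    }
    let mut cols = line.split('\t');
    let chrom = cols.next()?.to_string();
    let pos = cols.next()?.parse::<u64>().ok()?;
    let _id = cols.next()?; // ID column is not used for matching
    let reference = cols.next()?.to_string();
    let alt = cols.next()?.to_string();
    Some(Call { chrom, pos, reference, alt })
}

fn compare(truth: &HashSet<Call>, calls: &HashSet<Call>) -> Counts {
    let true_pos = calls.intersection(truth).count();
    Counts {
        true_pos,
        false_pos: calls.len() - true_pos,
        false_neg: truth.len() - true_pos,
    }
}
```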

Looking at 45× coverage data in high-confidence regions of the GIAB reference call set, Rastair achieves an F1 score of 98.9%, which is on par with other state-of-the-art tools. See our paper for more details.
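For reference, F1 is the harmonic mean of precision and recall, and it can be written directly in terms of the three counts from the comparison (TP, FP, FN):

```latex
\text{precision} = \frac{TP}{TP + FP}, \qquad
\text{recall} = \frac{TP}{TP + FN}, \qquad
F_1 = \frac{2 \cdot \text{precision} \cdot \text{recall}}{\text{precision} + \text{recall}}
    = \frac{2\,TP}{2\,TP + FP + FN}
```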

Using a Compute Shader

So we have a Random Forest model that makes good predictions. The problem is that we need to run it a lot: Rastair processes millions of positions, and for each position we might have multiple alternative alleles that each need to be scored. When profiling with samply, we saw that most time was spent in biosphere doing float comparisons and pointer chasing. Since Rastair already parallelizes across CPU cores (see my previous post for details), CPU usage is at 100% and our only options to make it faster are: Do less work or do it somewhere else. We already tried to do less work by adding some very broad filters[^prefilter]. So the question was: Can we throw a GPU at this?

[^prefilter]: E.g., don't run ML on positions with too little coverage because it would just say "no".
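The prefilter and the batching that makes GPU offload worthwhile can be sketched together: skip positions below a minimum depth, then pack the remaining feature vectors into one contiguous row-major buffer that can be uploaded in one go. `Candidate`, `MIN_DEPTH`, and its value are hypothetical, not Rastair's actual types or thresholds.

```rust
/// Hypothetical per-position candidate with its extracted feature vector.
struct Candidate {
    depth: u32,
    features: Vec<f32>,
}

/// Hypothetical coverage threshold below which the model would just say "no".
const MIN_DEPTH: u32 = 5;

/// Drop low-coverage candidates and pack the survivors' features into one
/// row-major buffer (n_kept x n_features). Also returns the original indices
/// of the kept candidates so results can be mapped back.
fn build_batch(cands: &[Candidate], n_features: usize) -> (Vec<f32>, Vec<usize>) {
    let mut buf = Vec::with_capacity(cands.len() * n_features);
    let mut kept = Vec::new();
    for (i, c) in cands.iter().enumerate() {
        if c.depth < MIN_DEPTH {
            continue; // prefilter: not worth running the model
        }
        assert_eq!(c.features.len(), n_features, "feature vectors must have equal length");
        buf.extend_from_slice(&c.features);
        kept.push(i);
    }
    (buf, kept)
}
```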

Random Forest inference is a good candidate for GPU acceleration: each (sample, tree) pair is completely independent, the operation is simple (comparisons and memory lookups), and we have large batches to amortize the overhead. We went with wgpu, which compiles compute shaders written in WGSL to Metal, Vulkan, and DX12. This means the same code runs on my MacBook (Metal), a Linux workstation with an NVIDIA or AMD card (Vulkan), or even a Windows machine (DX12, untested).

Flattening the forest

The original RandomForest in biosphere stores trees as heap-allocated recursive structures. This makes building them easy during training, but it's not great for shipping to a GPU. The first step was to convert each tree into a flat array of nodes in BFS (breadth-first) order with explicit child indices. Each node is 16 bytes.
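Based on the description in this section (16 bytes, `#[repr(C)]`, a `value` field that doubles as threshold and leaf prediction, a `left` index that is negative for leaves), a node could look like the sketch below. The `feature` and `right` field names, the field order, and the `leaf`/`is_leaf` helpers are assumptions.

```rust
/// Assumed 16-byte flat node layout.
#[repr(C)]
#[derive(Clone, Copy, Debug, PartialEq)]
struct FlatNode {
    /// Feature index compared at this node (unused for leaves).
    feature: u32,
    /// Split threshold for internal nodes, prediction for leaves.
    value: f32,
    /// Index of the left child in the flat array, or -1 for a leaf.
    left: i32,
    /// Index of the right child in the flat array (unused for leaves).
    right: i32,
}

impl FlatNode {
    fn leaf(value: f32) -> Self {
        FlatNode { feature: 0, value, left: -1, right: -1 }
    }
    fn is_leaf(&self) -> bool {
        self.left < 0
    }
}

// Four 4-byte fields and no padding: 16 bytes, so four nodes fit in a 64-byte line.
const _: () = assert!(std::mem::size_of::<FlatNode>() == 16);
```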

The #[repr(C)] is doing real work here: it guarantees a fixed memory layout so we can use bytemuck to cast the entire node slice to raw bytes and upload it directly to the GPU. The WGSL shader defines the same struct layout, so the same bytes are interpreted identically on both sides with no serialization or conversion step needed.
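To show what the cast amounts to, here is the same byte view written with only the standard library. This is an illustrative sketch: `bytemuck::cast_slice` performs this conversion without `unsafe` in user code by checking the preconditions through its `Pod` trait, and the struct below is an assumed layout (feature index, value, left and right child indices).

```rust
/// Assumed 16-byte node layout.
#[repr(C)]
#[derive(Clone, Copy)]
struct FlatNode {
    feature: u32,
    value: f32,
    left: i32,
    right: i32,
}

/// View a node slice as raw bytes, e.g. for a GPU buffer upload.
fn as_bytes(nodes: &[FlatNode]) -> &[u8] {
    // SAFETY: FlatNode is #[repr(C)] and consists of four 4-byte plain-old-data
    // fields, so it has no padding and every byte is initialized. The returned
    // slice borrows `nodes` and covers exactly size_of_val(nodes) bytes.
    unsafe { std::slice::from_raw_parts(nodes.as_ptr() as *const u8, std::mem::size_of_val(nodes)) }
}
```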

The value field does double duty: it's the split threshold for internal nodes and the leaf prediction for leaf nodes. left < 0 tells you which case you're in. This keeps the struct at exactly 16 bytes, which means 4 nodes fit in a single 64-byte cache line.

Crucially, we use explicit child indices rather than the implicit 2i+1 / 2i+2 layout you might remember from textbook binary heaps. Our real decision trees are rarely balanced, and the implicit layout would require exponential padding for deep, sparse trees. With explicit indices, any tree shape works without wasting memory.
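BFS flattening with explicit child indices can be sketched with a queue: each time an internal node is dequeued, two slots are reserved at the end of the output for its children, and the node records their indices. The `Tree` enum is a simplified stand-in for biosphere's trees and the node fields are assumed, so this is not the code from the post.

```rust
use std::collections::VecDeque;

/// Simplified recursive tree (a stand-in, not biosphere's actual type).
enum Tree {
    Split { feature: u32, threshold: f32, left: Box<Tree>, right: Box<Tree> },
    Leaf(f32),
}

/// Assumed 16-byte node: `left < 0` marks a leaf.
#[repr(C)]
#[derive(Clone, Copy, Debug, PartialEq)]
struct FlatNode {
    feature: u32,
    value: f32,
    left: i32,
    right: i32,
}

/// Flatten `root` into breadth-first order with explicit child indices.
fn flatten_bfs(root: &Tree) -> Vec<FlatNode> {
    let blank = FlatNode { feature: 0, value: 0.0, left: -1, right: -1 };
    let mut out = vec![blank]; // slot 0 is reserved for the root
    let mut queue = VecDeque::new(); // (node, slot reserved for it in `out`)
    queue.push_back((root, 0usize));
    while let Some((node, slot)) = queue.pop_front() {
        match node {
            Tree::Leaf(v) => out[slot] = FlatNode { value: *v, ..blank },
            Tree::Split { feature, threshold, left, right } => {
                // Reserve adjacent slots for both children; they are filled in
                // when dequeued, which keeps the whole array in BFS order.
                let l = out.len();
                out.push(blank);
                out.push(blank);
                out[slot] = FlatNode { feature: *feature, value: *threshold, left: l as i32, right: l as i32 + 1 };
                queue.push_back((&**left, l));
                queue.push_back((&**right, l + 1));
            }
        }
    }
    out
}
```

For a tree with a leaf on the left and a single split on the right, this yields 5 nodes; the implicit 2i+1 / 2i+2 layout would already need 7 slots, and the gap grows exponentially with depth.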

I used Claude Code to implement this step. It's the kind of well-defined data structure transformation that works well with AI assistance, and it worked basically on the first try.

One more tweak: All trees are padded to the same max_tree_size so the GPU can index into them uniformly: tree t, node n lives at nodes[t * max_tree_size + n]. The padding slots are dummy leaves with value = 0.0, so even if traversal somehow lands on one, it contributes nothing.
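Padding and the uniform indexing scheme can be sketched as follows, assuming the child indices stored in each node are local to their tree (so traversal adds `t * max_tree_size`). Function names are hypothetical.

```rust
/// Assumed 16-byte node: `left < 0` marks a leaf.
#[repr(C)]
#[derive(Clone, Copy, Debug, PartialEq)]
struct FlatNode {
    feature: u32,
    value: f32,
    left: i32,
    right: i32,
}

/// Dummy leaf used for padding: a leaf (left < 0) that predicts 0.0.
const PAD: FlatNode = FlatNode { feature: 0, value: 0.0, left: -1, right: -1 };

/// Concatenate flattened trees into one buffer, padding each to `max_tree_size`.
/// Returns the buffer and `max_tree_size`.
fn pack_forest(trees: &[Vec<FlatNode>]) -> (Vec<FlatNode>, usize) {
    let max_tree_size = trees.iter().map(Vec::len).max().unwrap_or(0);
    let mut nodes = Vec::with_capacity(trees.len() * max_tree_size);
    for tree in trees {
        nodes.extend_from_slice(tree);
        // Fill the rest of this tree's fixed-size slot with dummy leaves.
        nodes.resize(nodes.len() + (max_tree_size - tree.len()), PAD);
    }
    (nodes, max_tree_size)
}

/// Node `node` of tree `tree` lives at this index in the packed buffer.
fn node_index(tree: usize, node: usize, max_tree_size: usize) -> usize {
    tree * max_tree_size + node
}
```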

The shaders

The GPU work happens in two compute shaders, both written in WGSL.

Traverse: Dispatched as (ceil(n_samples / wg_size), n_trees, 1). Each GPU thread handles one (sample, tree) pair: it walks the flat node array from root to leaf, comparing features to thresholds, and writes the leaf value to a per-tree prediction buffer.
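Each traverse invocation can be written out in plain Rust to show what it computes. The buffer layouts (`features` row-major, per-tree predictions stored tree-major), the node fields, and the left-on-`<=` convention are assumptions, not taken from the actual WGSL. The bounds check matters because rounding the dispatch up leaves some invocations in the last workgroup without a sample.

```rust
/// Assumed 16-byte node: `left < 0` marks a leaf.
#[repr(C)]
#[derive(Clone, Copy)]
struct FlatNode {
    feature: u32,
    value: f32,
    left: i32,
    right: i32,
}

/// Workgroup counts for the traverse pass: (ceil(n_samples / wg_size), n_trees, 1).
fn traverse_dispatch(n_samples: u32, n_trees: u32, wg_size: u32) -> (u32, u32, u32) {
    (n_samples.div_ceil(wg_size), n_trees, 1)
}

/// CPU rendering of one traverse invocation for its (sample, tree) pair.
fn traverse_one(
    nodes: &[FlatNode],
    max_tree_size: usize,
    features: &[f32],
    n_features: usize,
    n_samples: usize,
    sample: usize,
    tree: usize,
    per_tree: &mut [f32],
) {
    if sample >= n_samples {
        return; // padding invocation in the last workgroup
    }
    let x = &features[sample * n_features..(sample + 1) * n_features];
    let base = tree * max_tree_size; // this tree's slot in the padded buffer
    let mut n = 0usize; // tree-local index, starting at the root
    loop {
        let node = nodes[base + n];
        if node.left < 0 {
            // Leaf: `value` is the prediction.
            per_tree[tree * n_samples + sample] = node.value;
            return;
        }
        // Internal node: `value` is the threshold.
        let child = if x[node.feature as usize] <= node.value { node.left } else { node.right };
        n = child as usize;
    }
}
```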

Reduce: Dispatched as (ceil(n_samples / wg_size), 1, 1). Each thread averages all per-tree predictions for one sample into the final output.
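The reduce step in the same style: one invocation per sample sums that sample's per-tree predictions and divides by the tree count. This assumes a tree-major layout, `per_tree[tree * n_samples + sample]`, in which neighbouring invocations read neighbouring addresses, which generally suits GPU memory access. The layout and names are assumptions.

```rust
/// Workgroup counts for the reduce pass: (ceil(n_samples / wg_size), 1, 1).
fn reduce_dispatch(n_samples: u32, wg_size: u32) -> (u32, u32, u32) {
    (n_samples.div_ceil(wg_size), 1, 1)
}

/// CPU rendering of one reduce invocation: mean over all trees for one sample.
fn reduce_one(per_tree: &[f32], n_trees: usize, n_samples: usize, sample: usize, out: &mut [f32]) {
    if sample >= n_samples {
        return; // padding invocation in the last workgroup
    }
    let mut sum = 0.0f32;
    for t in 0..n_trees {
        sum += per_tree[t * n_samples + sample];
    }
    out[sample] = sum / n_trees as f32;
}
```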

The shaders are short[^shaders] and I was pleasantly surprised at how straightforward WGSL is for this kind of work. The traverse kernel is essentially the same loop as the CPU version, just with GPU thread indexing instead of a for loop over samples.

[^shaders]: ~100 lines.
