{
  "path": "/gpu-random-forest-ml.html",
  "site": "at://did:plc:x67qh7v3fd7znbdhauc45ng3/site.standard.publication/3mjcd2t6afe25",
  "$type": "site.standard.document",
  "title": "Random Forest ML on GPU",
  "updatedAt": "2026-04-17T00:00:00.000Z",
  "bskyPostRef": {
    "cid": "bafyreiehp5yc57dqpdsskvuxfafhk7s7f3p732vggxewyaftawfst6otki",
    "uri": "at://did:plc:x67qh7v3fd7znbdhauc45ng3/app.bsky.feed.post/3mjosq32uws27"
  },
  "publishedAt": "2026-04-17T00:00:00.000Z",
  "textContent": "In [my recent post on Rastair][rastair-post],\nwe looked at some performance best-practices and optimizations\nfor [Rastair], a bioinformatics tool that I'm currently working on.\nOne of the slowest parts of the tool\nis running machine-learning inference on a Random Forest model.\nIn this post,\nI want to describe what we use ML for,\nand how we moved inference to a GPU compute shader\nto make it fast.\n\n[rastair-post]: https://deterministic.space/rastair.html \"Notes on Rastair, a variant and methylation caller\"\n\nContext\n\nRastair needs to make decisions on data that is prone to errors:\nWe analyze short genome sequencing reads\nwhich have been stochastically aligned to a reference\nand that give us many points of evidence for what base a position in a sample actually is.\nThey include some quality metrics (per-read and per-position),\ninformation on possible insertions and deletions,\nand many other flags.\nWe also know a general error rate for the instruments used.\nBut some evidence we look at can also be interpreted in two different ways.\n\nThe typical case in Rastair is this:\nWe have a position that is C in the reference\nand we have 30 reads at this position.\nSome show C (agree with the reference),\nbut some show T.\nSince Rastair deals with [TAPS] sequence data,\nwe know that a change from C to T can also be evidence for methylation[^m].\nSo we now need to decide:\nIs this a variant, or is this a methylation position, or can it even be both?\n\n[TAPS]: https://www.nature.com/articles/s41587-019-0041-2 \"TAPS paper in Nature Biotechnology\"\n[^m]: Simply put: A flag on a position.\n\nFor some cases we can pretty directly decide:\nWe know, e.g., that the C to T conversion only happens in CG contexts,\nand on the original top strand (a flag on the read).\nThis helps exclude cases,\nbut it doesn't yet help us be certain about when a position is truly one or the other.\n\nUsing ML\n\nWe can filter out some obvious cases with hard 
thresholds,\nbut for the ambiguous positions\nwe want something that can weigh many pieces of evidence at once.\nThis is where machine learning comes in.\n\n[Benjamin] designed a set of features\nthat capture the relevant information about a position:\nthings like base quality, mapping quality, depth ratios,\nthe surrounding sequence context, and so on.\nWe feed these into a Random Forest model\nthat outputs a score for each position,\nwhich we then convert into a probability\nusing [Platt scaling][platt] (also Benjamin's work).\n\n[Benjamin]: https://www.ludwig.ox.ac.uk/team/benjamin-schuster-bockler \"Benjamin Schuster-Böckler\"\n[platt]: https://en.wikipedia.org/wiki/Platt_scaling \"Platt scaling on Wikipedia\"\n\nWe actually run three separate models:\none for methylation in CG context,\none for de-novo methylation[^denovo],\nand one for everything else (variants).\nEach model is trained on data where we know the ground truth.\n\n[^denovo]: A position where a variant creates a new CG site.\n  For example: A G is changed to C and the following base is also G.\n  Then in this sample, there is now a CG where in the reference it was GG.\n  This CG can be methylated.\n\nWhat's a Random Forest\n\nA Random Forest is an ensemble of decision trees.\nA single decision tree is simple:\nat each node, you look at one feature,\ncompare it to a threshold,\nand go left or right.\nWhen you reach a leaf, you get a prediction.\nEasy to understand, fast to evaluate.\nBut a single tree tends to \"overfit\"[^of].\n\n[^of]: It learns the training data too well and then makes poor predictions on new data\n  because it memorized noise rather than the underlying pattern.\n\nThe \"forest\" part fixes this:\nyou train many trees (e.g., 400),\neach on a slightly different random subset of the training data\nand a random subset of features.\nAt inference time,\nyou run the input through all the trees\nand average their predictions.\nThis is surprisingly effective[^rf]\nand gives you a model 
that generalizes well\nwithout needing a lot of tuning.\n\n[^rf]: It's one of those techniques that is easy to underestimate because it's conceptually simple.\n  I had never used this before this project and I'm amazed how quickly it yielded great results.\n\nFor our use case, this is a good fit:\nThe model is fast to evaluate (just comparisons and a mean),\nit handles mixed feature types well,\nand it doesn't require a GPU or a deep learning framework to train,\njust some data in arrays.\n\nThere are many Rust crates that can train and run an RF.\nWe ended up using the [biosphere] crate,\nbecause it seemed simple and purposeful enough\nwhile also being quite fast.\n\n[biosphere]: https://github.com/mlondschien/biosphere/ \"biosphere: Simple, fast random forests.\"\n\nHow to verify\n\nHow do we know the model actually makes good calls?\nWe compare against [Genome in a Bottle][giab] (GIAB),\na well-characterized reference dataset\nthat serves as a \"ground truth\" for benchmarking variant callers.\nBenjamin wrote R scripts to evaluate our calls against this reference,\nand I later ported that comparison to Rust.\nThis basically works by reading in the VCF file from GIAB\nand the one that Rastair produces,\nand comparing overlap, false-positive, and false-negative counts.\n\nLooking at 45× coverage data\nat high-confidence regions of the GIAB reference call set,\nRastair achieves an F1 score of 98.9%.\nThis is on par with other state-of-the-art tools.\nSee our [paper] for more details.\n\n[giab]: https://www.nist.gov/programs-projects/genome-bottle \"Genome in a Bottle Consortium\"\n[paper]: https://www.biorxiv.org/content/10.64898/2026.03.19.712983v1 \"Rastair: an integrated variant and methylation caller\"\n\nUsing a Compute Shader\n\nSo we have a Random Forest model that makes good predictions.\nThe problem is that we need to run it a lot:\nRastair processes millions of positions,\nand for each position we might have multiple alternative alleles\nthat each 
need to be scored.\nWhen profiling with [samply],\nwe saw that most time was spent in biosphere\ndoing float comparisons and pointer chasing.\nSince Rastair already parallelizes across CPU cores\n(see [my previous post][rastair-post] for details),\nCPU usage is at 100% and our only options to make it faster are:\nDo less work or do it somewhere else.\nWe already tried to do less work by adding some very broad filters[^prefilter].\nSo the question was:\nCan we throw a GPU at this?\n\n[samply]: https://github.com/mstange/samply/ \"samply is a command line CPU profiler which uses the Firefox profiler as its UI\"\n[^prefilter]: E.g., don't run ML on positions with too little coverage because it would just say \"no\".\n\nRandom Forest inference is a good candidate for GPU acceleration:\neach (sample, tree) pair is completely independent,\nthe operation is simple (comparisons and memory lookups),\nand we have large batches to amortize the overhead.\nWe went with [wgpu],\nwhich compiles compute shaders written in [WGSL] to Metal, Vulkan, and DX12.\nThis means the same code runs on my MacBook (Metal),\na Linux workstation with an NVIDIA or AMD card (Vulkan),\nor even a Windows machine (DX12, untested).\n\n[wgpu]: https://wgpu.rs/ \"wgpu is a safe and portable graphics library for Rust based on the WebGPU API. 
It is suitable for general purpose graphics and compute on the GPU.\"\n[WGSL]: https://www.w3.org/TR/WGSL/ \"WebGPU Shading Language\"\n\nFlattening the forest\n\nThe original RandomForest in biosphere\nstores trees as a heap-allocated recursive structure.\nThis makes building them easy when training,\nbut it's not great for shipping to a GPU.\nThe first step was to convert each tree\ninto a flat array of nodes in BFS (breadth-first) order\nwith explicit child indices\n([code][flat_forest.rs]).\nEach node is 16 bytes:\n\n[flat_forest.rs]: https://github.com/Softleif/biosphere/blob/1d7c621fa54860a9b1d1807f0d6137b0c4aaafea/src/flat_forest.rs#L50-L57\n\nThe #[repr(C)] is doing real work here:\nit guarantees a fixed memory layout\nso we can use [bytemuck] to cast the entire node slice to raw bytes\nand upload it directly to the GPU.\nThe WGSL shader defines the same struct layout,\nso the same bytes are interpreted identically on both sides\nwith no serialization or conversion step needed.\n\n[bytemuck]: https://docs.rs/bytemuck/1.25.0/bytemuck/ \"bytemuck, a crate for mucking around with piles of bytes\"\n\nThe value field does double duty:\nit's the split threshold for internal nodes\nand the leaf prediction for leaf nodes.\nleft < 0 tells you which case you're in.\nThis keeps the struct at exactly 16 bytes,\nwhich means 4 nodes fit in a single 64-byte cache line.\n\nCrucially, we use explicit child indices\nrather than the implicit 2i+1 / 2i+2 layout\nyou might remember from textbook binary heaps.\nOur real decision trees are rarely balanced,\nand the implicit layout would require exponential padding\nfor deep, sparse trees.\nWith explicit indices, any tree shape works\nwithout wasting memory.\n\nI used Claude Code to implement this step.\nIt's the kind of well-defined data structure transformation\nthat works well with AI assistance,\nand it worked basically on the first try.\n\nOne more tweak:\nAll trees are padded to the same max_tree_size\nso the GPU can index into them 
uniformly:\ntree t, node n lives at nodes[t * max_tree_size + n].\nThe padding slots are dummy leaves with value = 0.0,\nso even if traversal somehow lands on one, it contributes nothing.\n\nThe shaders\n\nThe GPU work happens in two compute shaders, both written in [WGSL].\n\nTraverse\n: Dispatch as (ceil(n_samples / wg_size), n_trees, 1).\n  Each GPU thread handles one (sample, tree) pair:\n  it walks the flat node array from root to leaf,\n  comparing features to thresholds,\n  and writes the leaf value to a per-tree prediction buffer.\n\nReduce\n: Dispatch as (ceil(n_samples / wg_size), 1, 1).\n  Each thread averages all per-tree predictions for one sample\n  into the final output.\n\nThe shaders are short[^shaders] and I was pleasantly surprised\nat how straightforward WGSL is for this kind of work.\nThe traverse kernel is essentially the same loop\nas the CPU version,\njust with GPU thread indexing instead of a for loop over samples.\n\n[^shaders]: ~100 line",
  "canonicalUrl": "https://deterministic.space/gpu-random-forest-ml.html"
}