Accelerating ZK Proving with WebGPU: Techniques and Challenges
Written by Jaehun Kim and Jason Park


Client-side proving is crucial for enabling privacy-preserving ZK applications, yet there are still two main bottlenecks: time and space (memory constraints). There are many technical breakthroughs that address these constraints, such as:

  • Incrementally verifiable computation (IVC) techniques
  • Delegating proof generation
    • Full proof delegation to trusted execution environments (TEEs)
    • Full proof delegation via coSNARKs (e.g. zkSaaS)
    • Partial proof delegation by separating a circuit into private and public computations (Aztec’s zk-rollup)
  • Streaming the witness (Gemini, Ligetron)

Yet another approach is leveraging GPUs, which significantly boost performance for parallelizable tasks. Ingonyama’s ICICLE, a CUDA library for ZKPs, is a prime example of this approach and has been integrated into various ZK frameworks. However, CUDA-based solutions don’t work on most mobile GPUs.

One of the more prominent frameworks for leveraging GPUs on mobile devices is WebGPU, a powerful abstraction that works with GPUs on Android and iOS, as well as most other platforms (Linux, Windows, macOS, etc). As its name suggests, it works in major browsers by default (Chrome, Edge, and Brave; Safari enabled it in iOS 18.2, and Firefox will enable it soon), and it can also be integrated into native iOS and Android applications.

Recently, we have been integrating WebGPU with Starkware’s Stwo prover and have found 5x improvements for evaluating constraint polynomials and 2x for the overall proving pipeline. In this post, we’d like to share some of our learnings along the way and hopefully convince the readers to integrate it into their prover frameworks.

WebGPU Fundamentals: Computation Workflow and Memory Hierarchy

Before diving into specific implementations and optimizations, let us gently introduce some fundamental concepts of GPUs. A very simplified workflow of a program using the GPU looks as follows:

workflow.png

In a nutshell, the CPU provides the GPU with a program (called a compute shader) along with some inputs. The GPU copies the program to each of its many cores, and each core executes the program in parallel. After all the cores finish executing, the GPU returns the results to the CPU. As the red text in the figure suggests, there are many operations besides the execution itself, and we will later show that they are one of the major bottlenecks of offloading computation to the GPU.
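To make this handshake concrete, below is a condensed host-side sketch using the wgpu crate (the Rust implementation of WebGPU). It is an outline only, not code from our integration: pipeline and bind-group creation are elided, the function name and the pre-created staging buffer are assumptions, bytemuck is assumed as a dependency, and exact descriptor fields vary between wgpu versions.

// Sketch only: assumes a compute pipeline and bind group were created earlier,
// that `data_buffer` is the storage buffer bound to the shader, and that
// `staging_buffer` is a MAP_READ | COPY_DST buffer of the same size.
fn run_on_gpu(
    device: &wgpu::Device,
    queue: &wgpu::Queue,
    pipeline: &wgpu::ComputePipeline,
    bind_group: &wgpu::BindGroup,
    data_buffer: &wgpu::Buffer,
    staging_buffer: &wgpu::Buffer,
    input: &[u32],
) -> Vec<u32> {
    let size = (input.len() * std::mem::size_of::<u32>()) as u64;

    // 1. CPU -> GPU: upload the input into the storage buffer.
    queue.write_buffer(data_buffer, 0, bytemuck::cast_slice(input));

    // 2. Record the compute work: one dispatch of the compute shader.
    let mut encoder = device.create_command_encoder(&Default::default());
    {
        let mut pass = encoder.begin_compute_pass(&Default::default());
        pass.set_pipeline(pipeline);
        pass.set_bind_group(0, bind_group, &[]);
        pass.dispatch_workgroups(1, 1, 1);
    }

    // 3. GPU -> CPU: copy the result into a mappable staging buffer.
    encoder.copy_buffer_to_buffer(data_buffer, 0, staging_buffer, 0, size);
    queue.submit(Some(encoder.finish()));

    // 4. Wait for the GPU, then map and read the staging buffer.
    let slice = staging_buffer.slice(..);
    slice.map_async(wgpu::MapMode::Read, |_| {});
    device.poll(wgpu::Maintain::Wait);
    let view = slice.get_mapped_range();
    let result = bytemuck::cast_slice::<u8, u32>(&view).to_vec();
    drop(view);
    staging_buffer.unmap();
    result
}

Steps 1, 3, and 4 are exactly the "extra" operations highlighted in red above: they are pure overhead from the GPU's point of view, which is why they matter so much for overall performance.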

memory-hierarchy.png

Another concept to understand is the memory hierarchy of GPUs. The specific terminology differs among GPU vendors, but focusing on the terms used in WebGPU, a GPU is basically made up of multiple workgroups, each of which is made up of multiple work items. Each work item has a thread with access to its own local memory, which is very fast to access (think 1 cycle) since it's closest to the thread. A work item also has access to shared memory, which can be accessed by any work item in the same workgroup. It is designed to act as a cache among the threads of the same workgroup and provides low-latency memory access due to its physical proximity to the threads. Finally, all work items of any workgroup can access the global memory, which is the slowest to access as it is located separately in VRAM and thus requires the most cycles. In general, access to global memory is also not cached, which means that every time global memory is accessed, the GPU needs to retrieve data from VRAM. A notable exception is when one instantiates a read-only buffer in global memory, which will be cached by the GPU and thus will be faster to access.

In short, a GPU has access to 3 different levels in the memory hierarchy and also supports caching. Below is a table that shows the variable types for each level that can be used when writing WGSL, a shader language for WebGPU.

| Variable Type | Hierarchy | Size | Access | Speed (cycles are estimates) |
| --- | --- | --- | --- | --- |
| var<private> | Work Item | Per-thread register/memory | Read-Write | Fastest (~1 cycle) |
| var<workgroup> | Workgroup | ~16KB | Read-Write | Fast (~5-20 cycles) |
| var<uniform> | Global | ~64KB | Read-only | Medium (cached, ~20-100 cycles) |
| var<storage> | Global | 128MB+ (depends on VRAM size) | Read-Write | Slowest (uncached, ~200-800 cycles) |

For a more detailed explainer on WebGPU, we recommend going through this post: WebGPU-All of the cores, none of the canvas.

Now that we have some understanding of how GPUs work, let's dive into how to use the tools WebGPU offers by optimizing an NTT (a finite-field FFT) step by step.

Notes from Implementing NTT (Number Theoretic Transform) in WebGPU

The Number Theoretic Transform (NTT) is essentially a Fast Fourier Transform (FFT) performed over a finite field instead of complex numbers. In zk-proof protocols, various operations on polynomials are required, and NTT can be used to interpolate or evaluate these polynomials.

In this section, we’ll build an NTT in WGSL from scratch. We’ll start with a basic single-threaded version and then step-by-step add multi-threading and workgroup memory optimizations. WebGPU compute shaders let you tap into the GPU’s parallel power for tasks like NTT. Each step includes clear WGSL code snippets and explanations.

Rust implementation

fft-0.png

This Rust implementation uses the Cooley–Tukey algorithm. It breaks down the NTT into successive stages where each stage applies butterfly operations to combine elements, efficiently transforming the data in-place. For more details, check out the Wikipedia article on the Cooley–Tukey FFT algorithm. Since NTT is an FFT defined over a finite field, the principles behind NTT and FFT are the same.

fn ntt(data: &mut [u32], twiddles: &[u32]) {
    let n = data.len();
    let mut stride = 1;
    let mut twiddle_offset = 0;
    while stride < n {
        let half_jump = stride;
        let jump = stride * 2;
        for base in (0..n).step_by(jump) {
            for j in 0..half_jump {
                let even_index = base + j;
                let odd_index = even_index + half_jump;
                let a = data[even_index];
                let b = data[odd_index];
                let tw = twiddles[twiddle_offset + j];
                let res = butterfly(a, b, tw);
                data[even_index] = res.even;
                data[odd_index] = res.odd;
            }
        }
        twiddle_offset += half_jump;
        stride *= 2;
    }
}
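The butterfly helper is not shown above. As a rough stand-in, here is a minimal Cooley–Tukey butterfly over a toy 31-bit prime field (p = 2^31 - 1); the actual field arithmetic, reduction strategy, and struct layout depend on the prover, so treat this purely as an illustration.

const P: u64 = (1 << 31) - 1; // toy modulus for illustration

struct Butterfly {
    even: u32,
    odd: u32,
}

fn butterfly(a: u32, b: u32, tw: u32) -> Butterfly {
    // t = b * tw (mod p)
    let t = (b as u64 * tw as u64) % P;
    Butterfly {
        // even = a + t (mod p), odd = a - t (mod p)
        even: ((a as u64 + t) % P) as u32,
        odd: ((a as u64 + P - t) % P) as u32,
    }
}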

Single-Threaded WGSL implementation

This is the simplest version of our NTT implementation. It's designed to run on a single GPU thread, meaning all the butterfly operations are executed sequentially. The coefficients of the polynomial are loaded into the data storage variable before the program begins. They are computed in-place, so the evaluation result of the polynomial is also stored in the data storage variable. After execution ends, this result must be copied back to CPU memory.

Although the underlying logic is identical, this single-threaded WGSL implementation performs slower than the single-threaded Rust version, even when disregarding the CPU-to-GPU buffer overhead and the latency incurred when invoking the GPU.

This is because GPUs are optimized for executing many simple operations concurrently, boasting thousands of execution units compared to the few dozen typically found in modern CPUs. However, the performance of an individual GPU execution unit is poorer, which ultimately results in slower overall performance for this implementation. Additionally, while CPUs operate with a unified memory system, GPUs feature a multi-tiered memory hierarchy. In this case, the in-place computations on the data storage variable are carried out in global memory, which is the slowest memory type in the GPU, further impacting performance.

@group(0) @binding(0)
var<storage, read_write> data: array<u32>;

@group(0) @binding(1)
var<storage, read> twiddles: array<u32>;

const N: u32 = 8u; // NTT size

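// `butterfly` is assumed to be defined elsewhere in the shader: it applies one
// Cooley–Tukey butterfly in place to data[base + j] and data[base + j + halfJump],
// using twiddles[twiddleOffset + j] (see the Rust reference above).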
@compute @workgroup_size(1, 1, 1)
fn single_thread_ntt(@builtin(local_invocation_id) local_id : vec3<u32>) {
    var stride: u32 = 1u;
    var twiddleOffset: u32 = 0u;

    while (stride < N) {
        let halfJump: u32 = stride;
        let jump: u32 = stride * 2u;

        for (var base: u32 = 0u; base < N; base = base + jump) {
            for (var j: u32 = 0u; j < halfJump; j = j + 1u) {
                butterfly(base, j, halfJump, twiddleOffset);
            }
        }
        // Update the twiddle offset and stride
        twiddleOffset = twiddleOffset + halfJump;
        stride = stride * 2u;

        // No need to synchronize since it's running on a single thread
    }
}

Multi-threaded version

A workgroup is a small collection of threads. In WebGPU, parallelization is determined by the size of the workgroup and the number of workgroups dispatched. For instance, if you run a workgroup of size 64 using the dispatch_workgroups call shown below, you get 64 * 12 = 768 threads.

compute_pass.dispatch_workgroups(3, 2, 2); // e.g., spawn 3 * 2 * 2 = 12 workgroups

The following example demonstrates parallelizing the task using a single workgroup. There are 4 threads, and in each stride iteration of the NTT, the threads divide the indices among themselves for processing, synchronize afterwards, and then repeat this process. In this code, we invoke storageBarrier() to synchronize access to buffers in the storage address space among invocations within a single workgroup.

fft-1.png

This example has the advantage of using multiple threads; however, it still accesses the data storage variable every time (again, the storage variable lives in global memory, the slowest memory type on the GPU) and requires all threads to synchronize on every stride, leaving room for further optimization.

const NUM_THREADS: u32 = 4u;

// 1 Workgroup, 4 threads per workgroup, 4 threads total
@compute @workgroup_size(NUM_THREADS, 1, 1)
fn single_workgroup_ntt(@builtin(local_invocation_id) local_id: vec3<u32>) {
    var stride: u32 = 1u;
    var twiddleOffset: u32 = 0u;

    while (stride < N) {
        let halfJump: u32 = stride;
        let jump: u32 = stride * 2u;

        // Total number of base iterations (each iteration increments by jump)
        let numBases: u32 = N / jump;
        let threadId: u32 = local_id.x; // 0 ~ (NUM_THREADS-1)

        // Each thread starts at its index and processes every 4th iteration
        for (var i: u32 = threadId; i < numBases; i = i + NUM_THREADS) {
            let base: u32 = i * jump;
            for (var j: u32 = 0u; j < halfJump; j = j + 1u) {
                butterfly(base, j, halfJump, twiddleOffset);
            }
        }
        // Update the twiddle offset and stride
        twiddleOffset = twiddleOffset + halfJump;
        stride = stride * 2u;

        // Use storageBarrier() to synchronize threads for
        // correct access to the 'data' storage variable.
        storageBarrier();
    }
}

Using multiple workgroups

fft-2.png

In the CT algorithm, because computations can be performed independently up to a certain point, they can be isolated and executed per workgroup. This approach allows the use of workgroup memory, which is faster to access than storage variables. Also, this implementation does not require all threads to synchronize on every stride, which is a good improvement. In this code, we invoke workgroupBarrier() to synchronize access to buffers in the workgroup address space among invocations within a single workgroup.

The code proceeds as follows:

  1. Each workgroup performs the NTT butterfly operations only on its own assigned range (baseIndex to baseIndex + subBlockSize - 1).
  2. It repeats this for partialStages stages, and then the dispatch finishes.
  3. Once the first log2(subBlockSize) stages are done, we need to start merging pairs of sub-blocks (forming larger sub-blocks).

At this point, different workgroups may access the same memory range, so cross-workgroup synchronization is required.

Standard WGSL does not provide a global barrier (i.e., no built-in way to synchronize across different workgroups). Thus, we typically end the dispatch → let the CPU (host) synchronize everything → then launch another dispatch to handle the remaining stages.

// launch multi_workgroup_ntt() using 2 workgroups
compute_pass.set_pipeline(&self.multi_workgroup_ntt);
compute_pass.dispatch_workgroups(2, 1, 1);

// let cpu synchronize everything, then launch another dispatch
compute_pass.set_pipeline(&self.single_workgroup_ntt);
compute_pass.dispatch_workgroups(1, 1, 1);

const N: u32 = 8u;
const NUM_THREADS: u32 = 4u;
var<workgroup> workgroup_data: array<u32, N>;

// External assumption: for instance, if we dispatch 2 workgroups,
//   => totalWorkgroups = 2
//   => subBlockSize = N / totalWorkgroups = 4
//   => partialStages = log2(4) = 2 (meaning we can proceed up to 2 stages without cross-workgroup conflicts)
// Because the input is in bit-reversed order, each "4-sized section" will not conflict with others
// for the first 2 stages. After that, sub-blocks must merge, requiring global sync.

// WGSL entry points cannot take arbitrary value parameters, so the two
// configuration values below are written as pipeline-overridable constants.
// How many workgroups to use (e.g., 2)
override totalWorkgroups: u32 = 2u;
// The "maximum number of stages" to be processed in this dispatch (e.g., 2)
override partialStages: u32 = 2u;

@compute @workgroup_size(NUM_THREADS, 1, 1)
fn multi_workgroup_ntt(
    @builtin(local_invocation_id) local_id: vec3<u32>,
    @builtin(workgroup_id) workgroup_id: vec3<u32>,
) {
    // =====================================================================
    // Here, it is assumed that there is code to copy data from storage memory (data)
    // to workgroup memory (workgroup_data) at the beginning.
    // =====================================================================

    // Current workgroup ID
    let gId = workgroup_id.x; // 0 ~ (totalWorkgroups-1)
    // Thread ID within a single workgroup
    let tId = local_id.x;     // 0 ~ (NUM_THREADS-1)

    // subBlockSize: The size of the physical section this workgroup will handle
    let subBlockSize = N / totalWorkgroups;

    // baseIndex: The start index in memory for this workgroup
    let baseIndex = gId * subBlockSize;

    var stage: u32 = 0u;
    var stride: u32 = 1u;
    var twiddleOffset: u32 = 0u;

    // Repeat partialStages times (e.g., 2 times)
    while (stage < partialStages) {
        let halfJump = stride;
        let jump = stride * 2u;

        // Traverse base in increments of 'jump' only within this sub-block
        let numBasesInSubBlock = subBlockSize / jump;

        // Parallel processing: the NUM_THREADS in one workgroup split up the base loop
        var i = tId;
        while (i < numBasesInSubBlock) {
            let base = baseIndex + i * jump;
            for (var j: u32 = 0u; j < halfJump; j = j + 1u) {
                butterfly_workgroup(base, j, halfJump, twiddleOffset);
            }
            i = i + NUM_THREADS;
        }

        twiddleOffset = twiddleOffset + halfJump;
        stride = stride * 2u;
        stage = stage + 1u;

        // Use workgroupBarrier() to synchronize threads for
        // correct access to the 'workgroup_data' workgroup variable.
        workgroupBarrier();
    }

    // =====================================================================
    // Here, it is assumed that there is code to copy the computed results
    // from workgroup memory (workgroup_data) back to storage memory (data).
    // =====================================================================

    // By this point:
    //   - Each workgroup has finished partialStages (=2) stages of NTT on its own subBlockSize section
    //   - There was no conflict between different workgroups (thanks to the bit-reversed ordering).
    //
    // The remaining stages (log2(N) - partialStages) involve merging sub-blocks,
    // which is not safe to do in a single dispatch without cross-workgroup sync.
    // Typically, the CPU will issue another dispatch to finish the remaining stages.
}

More things to consider

Trade-offs between workgroup size and overall performance

The size of a workgroup is critical for fully leveraging the GPU’s parallel processing capabilities. If a workgroup is too small, the hardware’s parallel execution units remain underutilized, leading to suboptimal performance. Conversely, overly large workgroups may overconsume resources like registers and shared memory, which reduces the number of workgroups that can execute concurrently. As a general guideline, start with 64 threads, or a multiple of the underlying hardware’s warp/wavefront size (e.g., 32 on NVIDIA, 64 on AMD), to maximize parallelism.
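As a minimal sketch (assuming a 1-D shader compiled with @workgroup_size(64) and an illustrative helper name), the host side then only has to pick how many workgroups to dispatch:

const WORKGROUP_SIZE: u32 = 64; // must match @workgroup_size(...) in the shader

// Number of workgroups needed to cover `n_items` elements, one element per thread.
// The shader is expected to bounds-check its global invocation index.
fn workgroup_count(n_items: u32) -> u32 {
    (n_items + WORKGROUP_SIZE - 1) / WORKGROUP_SIZE
}

// e.g. compute_pass.dispatch_workgroups(workgroup_count(1 << 20), 1, 1);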

Minimize data transfers between CPU and GPU

Profiling your WebGPU code often reveals that copying buffers between the CPU and GPU (in either direction) can be a major performance bottleneck. The overhead of these data transfers can quickly eclipse other optimizations, negating potential gains. As a practical approach, it is generally more efficient to reduce the volume of data exchanged and increase the computational workload on the GPU. By restructuring your algorithms to perform as much processing as possible on the GPU, you can minimize these costly transfers and achieve significant performance improvements.

Challenges in Using WebGPU

Shader Language Constraints

Now that we are familiar with how to code with WGSL for maximum performance, let’s explore some of its limitations by comparing it to other popular shader languages:

| Feature | WGSL (WebGPU) | CUDA (NVIDIA) | MSL (Apple Metal) | HLSL (DirectX) | GLSL (OpenGL/Vulkan) |
| --- | --- | --- | --- | --- | --- |
| Scalar Types | f32, i32, u32, bool (❌ no f64/i64) | f16, f32, f64, i8, i16, i32, i64, u8, u16, u32, u64, bool | f16, f32, i8, i16, i32, i64, u8, u16, u32, u64, bool | f16, f32, f64, i8, i16, i32, i64, u8, u16, u32, u64, bool | f16, f32, f64, i8, i16, i32, i64, u8, u16, u32, u64, bool |
| Arbitrary Length Arrays | ❌ (allowed only in the storage memory space) | ✅ (via malloc, new, vector<T>) | ✅ (via device array<T>) | | ❌ (allowed only for Vulkan GLSL with buffer storage) |
| Implicit Scalar Conversion | ❌ (explicit casting required) | ✅ (e.g., int → float) | | | |
| Recursion | ❌ | | | | ❌ (allowed only for Vulkan GLSL using certain extensions) |
| Function Overloading | ❌ | | | | |
| Cyclic Dependencies | ❌ | | | | |

Some highlights:

  • WGSL doesn’t natively support 64-bit integers, so any integer wider than 32 bits needs to be decomposed into 32-bit limbs (see the sketch after this list).
  • Very limited support for arbitrary-length arrays: a runtime-sized array can only be created as the last element of a buffer in the storage memory space (i.e. global memory).
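To give a feel for what that decomposition looks like, here is a minimal Rust sketch (illustrative names only) of splitting a u64 into two u32 limbs and adding with a carry, which is essentially what the WGSL code has to spell out by hand:

// Split a u64 into (lo, hi) 32-bit limbs, e.g. before uploading it to the GPU.
fn split_u64(x: u64) -> (u32, u32) {
    ((x & 0xFFFF_FFFF) as u32, (x >> 32) as u32)
}

// 64-bit addition emulated with 32-bit limbs, mirroring what a shader would do.
fn add_limbs(a: (u32, u32), b: (u32, u32)) -> (u32, u32) {
    let (lo, carry) = a.0.overflowing_add(b.0);
    let hi = a.1.wrapping_add(b.1).wrapping_add(carry as u32);
    (lo, hi)
}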

Limited Library Ecosystem

Another challenge is the lack of libraries compared to incumbent ecosystems like CUDA, where there are optimized implementations of common GPU computations (cuDNN, cuFFT, cuBLAS, etc). Also, WGSL has no native support for file imports, which means that to use shader libraries, developers need to manually concatenate shader code, raising the risk of errors such as clashing function definitions (WGSL has no overloading) and cyclic dependencies.
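For illustration, a hedged sketch of what that manual composition tends to look like on the host side (the file names and constants are hypothetical):

// Shader "modules" are plain strings that get concatenated before the pipeline
// is created; there is no import resolution or name-collision check beyond what
// the WGSL compiler reports on the combined source.
const FIELD_WGSL: &str = include_str!("field.wgsl"); // modular arithmetic helpers
const NTT_WGSL: &str = include_str!("ntt.wgsl");     // calls into field.wgsl

fn composed_shader_source() -> String {
    format!("{FIELD_WGSL}\n{NTT_WGSL}")
}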

On the bright side, there have been efforts to create an extension of the WGSL language (WESL), which will first and foremost provide import functionality to solve the composition issue above, and also add features such as conditional compilation and runtime variable insertion.

Integrating WebGPU into ZK Frameworks

Optimizing Computation vs. Data Transfer

In general, computation on the GPU is fast and cost-efficient, whereas transferring data between the CPU and GPU can be a major bottleneck (benchmarks). In zk circuits or zkVMs, this could mean offloading the entire proving process to the GPU instead of offloading individual components, especially when the input and output data are small relative to the intermediate data generated.

Another consideration is that the amount of memory mobile GPUs have access to varies widely, so to support a wide range of mobile devices one should take a hybrid approach and decide at runtime whether to use the GPU. WebGPU has an API that exposes the maximum size of memory the underlying GPU has access to, which can be used to determine whether the GPU is capable of running the computation.
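As a rough sketch of that check using the wgpu crate (the function name and threshold logic are illustrative, not a prescribed policy):

// Decide whether a workload that needs `required_bytes` of GPU-resident data
// fits within the limits the adapter reports.
fn can_offload_to_gpu(adapter: &wgpu::Adapter, required_bytes: u64) -> bool {
    let limits = adapter.limits();
    // max_buffer_size bounds any single buffer; max_storage_buffer_binding_size
    // bounds how much of it a shader can bind as one storage buffer.
    required_bytes <= limits.max_buffer_size
        && required_bytes <= limits.max_storage_buffer_binding_size as u64
}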

Considerations for Native Builds

When targeting native builds instead of the browser, however, the cost of data transfer drops to near zero. This is because many modern mobile devices use unified memory for the CPU and GPU, so there is no need to copy the data to a separate buffer when transferring it. (Ingonyama has a nice POC demonstrating this on Apple Silicon.)

Although WebGPU offers a configuration for this case, developers must still implement safeguards, as not all devices share unified memory.

Dynamic Invocation

WebGPU’s flexibility makes it suitable for both high-end and lower-end devices. To harness its full potential, implementations should dynamically adjust based on the device’s limits. A basic approach is to set a minimum capability threshold for GPU use, while more sophisticated methods could selectively run different components on the GPU based on memory limits.

Challenges in Integrating WebGPU for ZK

High Maintenance Costs

Porting existing proving code to WGSL, continuously updating it as the original implementations evolve, and going through additional audits can be resource intensive. And as new zkVMs, zkDSLs, and theoretical breakthroughs continue to emerge, keeping up with that maintenance burden becomes harder and harder. One potential solution is to build a production-ready library of common shader routines (similar to Ingonyama’s ICICLE), which would streamline integration and reduce duplicated effort.

Lack of Benchmarks and Test Frameworks

We also lack benchmarks for measuring the performance of ZK frameworks. Such benchmarks would be valuable for measuring the peak memory usage when running a prover and for determining how large a circuit (for zk circuits) or program (for zkVMs) can be while remaining viable to prove with WebGPU. Once WebGPU is integrated into a ZK framework, a test framework that spans multiple mobile devices would also make it easier to accurately assess performance gains.

Conclusion: The Road Ahead for WebGPU in ZK

As discussed in the beginning of this post, client-side proving is key to unlocking privacy-preserving zero-knowledge applications, and WebGPU is emerging as a powerful tool to accelerate this process. By harnessing GPU parallelism, we can offload and expedite the most compute-intensive aspects of proving.

Despite challenges like high maintenance overhead and the current lack of comprehensive benchmarks, the path forward is promising. Collaborative efforts—such as the development of shared shader libraries and unified testing frameworks—can overcome these roadblocks. As the community works together, we can look forward to more efficient, scalable, and accessible privacy solutions for everyone.