State of PyTorch Hardware Acceleration 2025

Version 0.2

A Comparative Technical Analysis: NVIDIA CUDA, AMD ROCm, Google TPU (XLA), and Apple Silicon (MPS)

By Bojan Tunguz

Executive Summary

The landscape of deep learning hardware acceleration has undergone a fundamental structural shift as of 2025. The era of monolithic CUDA dominance has fractured into a heterogeneous ecosystem where architectural decisions can no longer be decoupled from compiler stack maturity. While NVIDIA’s H100 and Blackwell architectures utilizing the CUDA platform remain the operational "gold standard" for immediate stability and broad operator support, significant maturation in AMD’s ROCm stack (specifically the pivot to Triton) and in Google’s PyTorch/XLA integration offer viable, and often cost-superior, alternatives for specific workloads. Concurrently, Apple Silicon has carved a distinct, unassailable niche in high-memory local prototyping, though it remains isolated from datacenter training workflows due to fundamental architectural and software capability gaps.

This report delivers a rigorous technical evaluation of PyTorch support across these four platforms. It serves as a strategic guide for hardware architects and systems engineers tasked with standardizing infrastructure for Large Language Model (LLM) and Vision Transformer (ViT) lifecycles, ranging from local prototyping on MacBook Pros to cluster-scale training on dedicated accelerators.

Executive Decision Matrix

The following matrix synthesizes the technical maturity of each platform for PyTorch 2.5+ workflows. Ratings are derived from a deep analysis of compiler stack stability, kernel availability, developer friction, and total cost of ownership (TCO).

| Feature / Platform | NVIDIA CUDA (H100/Blackwell) | AMD ROCm 6.x/7.0 (MI300X/355X) | Google TPU (v5p/Trillium) | Apple Silicon (M3/M4) |
| --- | --- | --- | --- | --- |
| Compiler Maturity | High (Baseline) | Medium-High (Rapidly improving) | Medium (XLA quirks) | Low (Limited Inductor) |
| torch.compile Stability | High (Inductor + Triton) | Medium (Triton/CK WIP) | Medium-Low (Graph breaks) | Low (CPU fallbacks) |
| FlashAttention-3 | Native (Day 0) | Lagging (CK/Triton WIP) | N/A (Pallas kernels) | N/A (SDPA/FlexAttn) |
| FP8/FP4 Support | Native (FP8/NVFP4) | Native (Hardware + Quark) | Native (XLA formatting) | No (Emulated/Upcast) |
| Distributed Stack | High (NVLink/SHARP) | Medium (RCCL Parity Issues) | High (ICI/XLA Mesh) | Low (Gloo only) |
| Debugging Ease | High (Nsight) | Medium (Omnitrace) | Medium-Low (XLA metrics) | Medium (Metal Trace) |
| Unit Cost (Price) | High (Parity) | High (Parity) | Low (Aggressive) | High (Device Cost) |
| Cost Efficiency (Perf/$ & Perf/W) | High (General ROI) | High (Memory-bound) | High (Static Shapes) | High (Local Inference) |

1. Deep Dive: The Compilation & Runtime Stack Analysis

The introduction of PyTorch 2.0 and the torch.compile API fundamentally altered the interaction between the framework and hardware backends. The reliance on TorchDynamo (graph capture) and TorchInductor (compiler) signifies that hardware vendors can no longer simply optimize individual eager-mode kernels; they must support a complete, vertically integrated compiler stack. The efficacy of a platform in 2025 is largely determined by how well it services the Inductor-Triton pipeline.

1.1 NVIDIA CUDA: The Inductor & Triton Baseline

On NVIDIA hardware (H100, Blackwell), torch.compile functions as the reference implementation against which all others are judged. The stack is vertically integrated: Dynamo captures the Python bytecode with minimal graph breaks, Inductor lowers the FX graph into Triton kernels, and Triton compiles directly to PTX (Parallel Thread Execution) assembly. This bypasses standard CUDA C++ complexities for element-wise and fusion operations, granting Python developers "close-to-metal" performance.

2025 Status:
The ecosystem has moved beyond basic support into advanced optimization. PyTorch 2.5 introduced FlexAttention, an API leveraging torch.compile to generate fused FlashAttention kernels automatically using Triton.[1] This allows users to implement sliding window, causal mask, or prefix LM attention in pure Python, which the compiler fuses into a single efficient kernel. NVIDIA's advantage here is structural: Triton was originally designed for CUDA, ensuring that heuristics for warp scheduling, shared memory allocation, and memory coalescing are mature and aggressive by default. The support for Symmetric Memory in PyTorch 2.5 further optimizes multi-GPU kernels on NVLink-connected H100 clusters, reducing communication overhead in distributed training by enabling direct loads/stores from remote GPU memory without explicit send/recv semantics.[3]
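A minimal sketch of the FlexAttention pattern described above, assuming the prototype `torch.nn.attention.flex_attention` module path from the PyTorch 2.5 release (names and signatures may shift in later versions); the window size and tensor shapes are illustrative only:

```python
# Sliding-window causal attention expressed in pure Python via FlexAttention.
# When wrapped in torch.compile, Inductor fuses the mask logic into a single
# FlashAttention-style Triton kernel.
import torch
from torch.nn.attention.flex_attention import flex_attention, create_block_mask

B, H, S, D = 4, 8, 2048, 64        # batch, heads, sequence length, head dim (illustrative)
WINDOW = 256                        # arbitrary sliding-window size for this example

def sliding_window_causal(b, h, q_idx, kv_idx):
    # Keep keys that are causal and within WINDOW positions of the query.
    return (q_idx >= kv_idx) & (q_idx - kv_idx <= WINDOW)

q, k, v = (torch.randn(B, H, S, D, device="cuda", dtype=torch.bfloat16) for _ in range(3))

block_mask = create_block_mask(sliding_window_causal, B, H, S, S)
compiled_flex = torch.compile(flex_attention)
out = compiled_flex(q, k, v, block_mask=block_mask)
```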

The torch.compile stack on CUDA is also the only one to fully support Regional Compilation without recompilation, a feature introduced in PyTorch 2.5 to reduce cold start times for repeated nn.Module patterns, such as Transformer layers.[2] This capability is critical for reducing the "time-to-first-step" in large-scale training runs, a metric where NVIDIA maintains a distinct lead.
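A sketch of the regional-compilation idea under stated assumptions: the model is represented here by a toy stack of identical blocks, and `nn.Module.compile()` (available in recent PyTorch releases) is used to compile each repeated region in place so one compiled artifact is reused across layers rather than recompiling the whole graph:

```python
import torch
import torch.nn as nn

class Block(nn.Module):
    """Stand-in for a repeated Transformer layer."""
    def __init__(self, d=512):
        super().__init__()
        self.attn_proj = nn.Linear(d, d)
        self.mlp = nn.Sequential(nn.Linear(d, 4 * d), nn.GELU(), nn.Linear(4 * d, d))
    def forward(self, x):
        return x + self.mlp(x + self.attn_proj(x))

layers = nn.ModuleList(Block() for _ in range(12))

# Regional compilation: compile the repeated block, not the whole model.
# Because every Block has identical structure, the compiled artifact can be
# reused across all twelve instances, cutting cold-start compile time.
for block in layers:
    block.compile()          # in-place torch.compile on each region

x = torch.randn(8, 128, 512)
for block in layers:
    x = block(x)
```

The full-graph alternative, `torch.compile(model)`, can fuse across layer boundaries but pays the longer cold start that regional compilation is designed to avoid.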

1.2 AMD ROCm: The Struggle for Triton Parity

AMD's strategy for PyTorch 2.5+ relies heavily on achieving parity with NVIDIA's Triton support. Historically, AMD relied on HIP (Heterogeneous-Compute Interface for Portability) to "hipify" CUDA code—a source-to-source translation layer. However, the future is Triton. By optimizing the Triton backend for AMDGCN ISA, AMD theoretically allows any torch.compile model to run performantly on MI300X without code changes.

ROCm 6.2/7.0 Analysis:

The transition is promising but incomplete. As of ROCm 7.0, torch.compile with the Triton backend is functional for many workloads but lacks the aggressive autotuning maturity found on CUDA.[4]

  • Triton on ROCm: AMD has invested heavily in the Triton backend. However, heuristics that work for NVIDIA's warp size (32 threads) often fail to fully saturate AMD's wavefront size (64 threads), leading to sub-optimal occupancy unless manually tuned.
  • Composable Kernel (CK): For operations where Triton is not yet performant or functionally complete (e.g., complex FlashAttention variants), AMD relies on Composable Kernel (CK), a C++ template library similar to NVIDIA's CUTLASS. PyTorch on ROCm currently uses a hybrid approach: using CK for critical monolithic ops like FlashAttention-2 and Triton for point-wise fusions.[5] This hybrid model introduces "glue code" fragility.
  • Stability: Users report that while "hello world" works, complex dependency chains often break. Installing FlashAttention usually requires specific, often forked, versions of the library rather than a simple pip install. The "dependency hell" of matching pytorch-triton-rocm versions with the underlying ROCm driver remains a significant friction point.[6]
  • AOTriton: The integration of AOTriton in ROCm 7.0 aims to solve compilation jitter by pre-compiling common kernels, reducing runtime latency.[8] This acts as a bridge solution while the dynamic JIT capabilities of the Triton backend mature.

Insight: AMD is effectively trying to bypass the "CUDA Moat" by optimizing for Triton. If they succeed, developers writing Triton kernels (or relying on Inductor) will theoretically see portability for free. However, in 2025, the "out-of-the-box" experience still lags, often requiring manual intervention in compiler flags or Docker container selection.[6]
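The portability claim is easiest to see in a concrete Triton kernel. The sketch below is a standard vector-add that runs unchanged on CUDA and ROCm builds of PyTorch (ROCm GPUs still appear under the "cuda" device string); the `num_warps` launch hint is exactly the kind of CUDA-centric heuristic that needs re-tuning for AMD's 64-lane wavefronts:

```python
import torch
import triton
import triton.language as tl

@triton.jit
def add_kernel(x_ptr, y_ptr, out_ptr, n_elements, BLOCK: tl.constexpr):
    pid = tl.program_id(axis=0)
    offsets = pid * BLOCK + tl.arange(0, BLOCK)
    mask = offsets < n_elements
    x = tl.load(x_ptr + offsets, mask=mask)
    y = tl.load(y_ptr + offsets, mask=mask)
    tl.store(out_ptr + offsets, x + y, mask=mask)

def add(x, y):
    out = torch.empty_like(x)
    n = out.numel()
    grid = (triton.cdiv(n, 1024),)
    # num_warps maps to 32-thread warps on NVIDIA but 64-lane wavefronts on CDNA;
    # defaults tuned for CUDA often leave MI300X under-occupied until BLOCK and
    # num_warps are re-tuned.
    add_kernel[grid](x, y, out, n, BLOCK=1024, num_warps=4)
    return out

x = torch.randn(1 << 20, device="cuda")   # ROCm devices also surface as "cuda" in PyTorch
print(torch.allclose(add(x, x), x + x))
```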

1.3 Google TPU: The XLA Bridge & Graph Breaks

TPUs do not use the Triton/Inductor stack. Instead, they rely on PyTorch/XLA, which bridges PyTorch operations to the XLA (Accelerated Linear Algebra) compiler. The interaction model here is fundamentally different: "Lazy Tensors."

The "Lazy Tensor" Problem: PyTorch/XLA operates on a lazy execution model where operations are recorded into a graph and only executed when a result is strictly needed (e.g., printing a value, saving a checkpoint, or a .item() call).

  • Efficiency: When it works, XLA fuses operations aggressively, often outperforming hand-written CUDA kernels for massive batch sizes due to the compiler's ability to see the entire graph scope.
  • Graph Breaks: The critical failure mode in 2025 remains "graph breaks." If PyTorch code contains dynamic control flow (Python if/else based on tensor data) or operations XLA cannot trace, the execution falls back to the CPU, triggers a compilation, and then resumes. This "context switch" destroys performance.[9]
  • Dynamo Bridge: The new torch_xla bridge for Dynamo (beta in 2025) attempts to mitigate this by using Dynamo's guard system to capture graphs more robustly than the legacy lazy tensor tracing.[11] This allows for torch.compile(backend='openxla'), which provides a more "PyTorch-native" feel. However, debugging a model that constantly recompiles on TPU v5p remains a high-friction activity compared to eager-mode debugging on GPUs. The compilation time on TPU can be significant (minutes), meaning an iterative "fix-run-fix" loop is much slower than on CUDA.[9]
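A minimal sketch of the two execution paths on TPU, assuming a Cloud TPU VM with `torch_xla` installed; the model is a placeholder:

```python
import torch
import torch_xla.core.xla_model as xm

device = xm.xla_device()
model = torch.nn.Linear(1024, 1024).to(device)
opt = torch.optim.SGD(model.parameters(), lr=1e-3)

# (a) Legacy lazy-tensor mode: ops are recorded into a graph and only compiled/
#     executed at the barrier. Anything that forces a value early (print, .item())
#     cuts the graph and triggers an extra compilation.
for step in range(10):
    x = torch.randn(64, 1024, device=device)
    loss = model(x).sum()
    loss.backward()
    opt.step()
    opt.zero_grad()
    xm.mark_step()           # materialize and execute the recorded graph

# (b) Dynamo bridge: capture the graph with torch.compile instead of lazy tracing.
compiled = torch.compile(model, backend="openxla")
y = compiled(torch.randn(64, 1024, device=device))
```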

1.4 Apple Silicon (MPS): The Inference Island

The Metal Performance Shaders (MPS) backend has matured significantly but remains fundamentally different from the datacenter stacks. Apple's "graph" approach is MPSGraph, which is distinct from XLA or Triton.

Inductor on Metal: As of PyTorch 2.5, torch.compile support for MPS is still in early stages compared to CUDA. Apple has not fully embraced the Triton stack, as Triton generates PTX (NVIDIA) or AMDGCN (AMD). Instead, Apple relies on its own Metal shading language (MSL).

  • Execution: Most users run in eager mode on MPS. While performance is adequate for inference, the lack of a mature compiler stack means complex fusions (like those generated by torch.compile) often fall back to the CPU or run as unfused, generic Metal kernels.[13]
  • The MLX Factor: The existence of Apple's separate framework, MLX, creates a fragmentation risk. MLX features a lazy computation graph similar to JAX but optimized specifically for Apple's Unified Memory and Neural Engine.[14] Benchmarks consistently show MLX outperforming PyTorch MPS on identical hardware for LLM inference (up to 2-3x faster generation). For a PyTorch engineer, this presents a dilemma: the best performance on Mac often requires leaving the PyTorch ecosystem, which breaks the "write once, run anywhere" ideal of the PyTorch prototyping workflow.
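For reference, the typical eager-mode MPS workflow looks like the sketch below; the model is a toy stand-in, and half precision is used because MPS has no float64 support:

```python
import torch
import torch.nn as nn

device = torch.device("mps" if torch.backends.mps.is_available() else "cpu")

model = nn.Sequential(nn.Linear(4096, 4096), nn.GELU(), nn.Linear(4096, 4096))
model = model.to(device, dtype=torch.float16).eval()   # fp16/bf16 only; no float64 on MPS

with torch.inference_mode():
    x = torch.randn(8, 4096, device=device, dtype=torch.float16)
    y = model(x)

torch.mps.synchronize()   # kernels are dispatched asynchronously to Metal
```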

2. Kernel Ecosystem & Operator Coverage

The "software gap" is most visible when moving beyond standard matrix multiplications into specialized operators required for state-of-the-art LLMs. The availability of optimized kernels for Attention and Quantization is often the deciding factor for hardware viability.

2.1 The "FlashAttention" Test

FlashAttention (FA) is the benchmark for accelerator sufficiency. It reduces memory complexity from quadratic to linear and is essential for long-context LLMs.

  • NVIDIA (H100 & B200): FlashAttention-3 is native. On the H100 (Hopper), it leverages asynchronous copy engines (TMA) and WGMMA instructions to overlap data movement with computation entirely. The B200 (Blackwell) inherits this "Day 0" support while adding hardware acceleration for even lower precision formats in the attention mechanism. Installation is trivial via standard wheels.[5]
  • AMD (MI300X): Support exists but is complex. AMD uses a specific fork or the Composable Kernel (CK) backend. While FA-2 is supported, FA-3 (Hopper specific optimizations) is not directly translatable. The sliding window attention (SWA) and other variants often lack Triton support on ROCm, forcing users to rely on the CK backend which may have different performance characteristics or bugs.[16]
    Insight: The MI300X has raw hardware capability (matrix cores) to run FA efficiently, but the software glue is fragile. Reports indicate that pip install flash-attn often fails on ROCm without specific build flags or pre-built Docker containers.[18] The "official" Dao-AILab repo has ROCm support, but it is often versions behind the CUDA release.
  • TPU (v5p): XLA has its own attention implementations, often referred to as Pallas kernels. While efficient, they are not "FlashAttention" in the strictly compatible sense. Porting a model hardcoded for flash_attn libraries to TPU requires code changes to use torch.nn.functional.scaled_dot_product_attention (SDPA), which XLA then lowers to its own fused kernel.[8] This breaks the "drop-in replacement" promise for research codebases heavily optimized for CUDA FA interfaces.
  • Apple (MPS): Native FlashAttention support is absent in the strict sense. MPS relies on Apple's implementation of SDPA. While functional, it does not support the advanced features of FA2/FA3 (like variable sequence lengths in a single batch without padding) as efficiently as the CUDA implementation. FlexAttention (prototype) allows some custom attention patterns, but performance on Metal is not comparable to dedicated tensor cores.[2]
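The portable alternative referenced for TPU and Apple above is the SDPA entry point. A brief sketch, using the public `torch.nn.functional.scaled_dot_product_attention` API and the `torch.nn.attention.sdpa_kernel` backend selector (shapes are illustrative):

```python
import torch
import torch.nn.functional as F
from torch.nn.attention import sdpa_kernel, SDPBackend

q, k, v = (torch.randn(4, 8, 2048, 64, device="cuda", dtype=torch.bfloat16)
           for _ in range(3))

# Portable call: each backend lowers this to its best available fused kernel
# (FlashAttention on CUDA, a fused XLA/Pallas kernel on TPU, Metal SDPA on MPS).
out = F.scaled_dot_product_attention(q, k, v, is_causal=True)

# CUDA-specific: pin the FlashAttention backend explicitly; the call raises if
# no matching kernel is available for the given dtypes/shapes.
with sdpa_kernel(SDPBackend.FLASH_ATTENTION):
    out = F.scaled_dot_product_attention(q, k, v, is_causal=True)
```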

2.2 Quantization: The FP8 and INT4 Frontier

FP8 and FP4 Support:

  • NVIDIA (H100 & B200): The H100 established FP8 as the training standard. The B200 (Blackwell) pushes the frontier with native FP4 Tensor Cores, effectively doubling throughput and halving memory usage compared to FP8. PyTorch 2.5+ supports these formats via `torchao` and Transformer Engine, though robust FP4 training typically requires the new Micro-Scaling (MX) formats to manage dynamic range.[44][48]
  • AMD (MI300X): Native FP8 support is robust, enabled via Quark and updated libraries like hipBLASLt. While the MI300X lacks the dedicated FP4 hardware engines found in Blackwell, AMD focuses on maximizing FP8 throughput for both training and inference.[16]
  • TPU: Native support for low precision (BF16/INT8) is strong; FP8 is supported on v5p/Trillium.[20] The XLA compiler handles the layout transformations automatically, often making it easier to use than on GPUs where explicit casting is sometimes required.
  • Apple: No native hardware support for FP8 or FP4 training. FP16/BF16 is the standard. FP8 operations are emulated (upcast to BF16), which negates the performance benefit.
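PyTorch exposes FP8 primarily as storage dtypes; useful FP8 math additionally requires hardware support (Hopper/Blackwell, MI300X) plus scaling logic from torchao or Transformer Engine. A small sketch showing the dtypes themselves, which also illustrates why naive casting loses precision:

```python
import torch

x = torch.randn(4, 4)
x_fp8 = x.to(torch.float8_e4m3fn)          # e4m3: more mantissa, less dynamic range
print(x_fp8.dtype)                          # torch.float8_e4m3fn
print((x - x_fp8.to(torch.float32)).abs().max())   # quantization error is visible

# float8_e5m2 trades mantissa for range and is conventionally used for gradients.
g_fp8 = torch.randn(4, 4).to(torch.float8_e5m2)
```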

INT4/Quantization on Apple:
This is Apple's stronghold. The unified memory architecture allows loading massive quantized models (e.g., Llama-3-70B in 4-bit) into RAM.
MLX vs. PyTorch: MLX provides seamless 4-bit quantization. PyTorch MPS support for INT4 is catching up via torchao, allowing native int4 weight-only quantization.[21] However, benchmarks suggest MLX still holds a performance edge in decoding speed and memory bandwidth utilization for quantized models.[15] For a researcher wanting to run a 70B model on a laptop, MLX is the superior runtime, while PyTorch/MPS remains a second-class citizen for quantized inference speed.
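A sketch of the torchao weight-only INT4 path on a Mac, under stated assumptions: the API names (`quantize_`, `int4_weight_only`) follow recent torchao releases and may differ across versions, and MPS coverage for these kernels is still maturing, so treat this as illustrative rather than a guaranteed recipe:

```python
import torch
import torch.nn as nn
from torchao.quantization import quantize_, int4_weight_only

device = "mps" if torch.backends.mps.is_available() else "cpu"
model = nn.Sequential(nn.Linear(4096, 4096), nn.SiLU(), nn.Linear(4096, 4096))
model = model.to(device, dtype=torch.bfloat16).eval()

# Replace Linear weights with packed int4 tensors plus per-group scales.
# Weight-only: activations stay in bf16, so compute dequantizes tiles on the fly.
quantize_(model, int4_weight_only(group_size=128))

with torch.inference_mode():
    y = model(torch.randn(1, 4096, device=device, dtype=torch.bfloat16))
```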

2.3 Missing Operators and Fallbacks

One of the most insidious performance killers is the "silent CPU fallback."

  • ROCm: While operator coverage has improved drastically (claiming almost full parity), edge cases in complex linear algebra (e.g., certain sparse matrix operations or FFTs) can still trigger fallbacks or compilation errors.[6]
  • MPS: The MPS backend is stricter than CUDA. It lacks support for float64 (double precision) entirely. If a model tries to allocate a float64 tensor, PyTorch throws a runtime error or silently keeps it on CPU.[24] This makes migration of scientific computing code or older models (which might use double precision for stability) painful. Furthermore, operations like torch.istft (Inverse Short-Time Fourier Transform) have only recently gained support or rely on imperfect implementations.[1]
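Both failure modes are easy to reproduce. The sketch below shows the float64 rejection and the documented `PYTORCH_ENABLE_MPS_FALLBACK` environment variable, which routes unsupported operators to the CPU instead of erroring (and is therefore the source of the "silent fallback" slowdowns described above):

```python
import os
# Must be set before the offending op runs (ideally before importing torch).
os.environ["PYTORCH_ENABLE_MPS_FALLBACK"] = "1"

import torch

if torch.backends.mps.is_available():
    try:
        torch.zeros(4, dtype=torch.float64, device="mps")   # double precision: unsupported
    except (TypeError, RuntimeError) as e:                   # exact exception type varies by version
        print("float64 rejected on MPS:", e)
```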

3. Memory & Architecture Nuances

3.1 Unified Memory (Apple) vs. HBM3 (Datacenter)

The critical distinction for 2025 is Capacity vs. Bandwidth. This trade-off dictates the utility of each platform.

  • Apple M4: With up to 512GB of Unified Memory, a single Mac Studio can hold a 405B parameter model (quantized). This is physically impossible on a single H100 (80GB).
    Implication: For inference of massive models by a single researcher, the Mac is superior to a single H100. It democratizes access to "super-sized" models without needing a cluster.
  • Bottleneck: Bandwidth. The M3/M4 Ultra memory bandwidth (~800 GB/s) pales in comparison to H100's HBM3 (3.35 TB/s) or MI300X's HBM3 (5.3 TB/s).[26] Token generation on Mac will be significantly slower (tokens per second), but it will run, whereas it would OOM (Out Of Memory) on a discrete GPU.
  • Real-World Impact: A 70B model might generate at 10 tokens/sec on an M3 Ultra, while an H100 might do 100+ tokens/sec. For prototyping, 10 t/s is acceptable. For serving, it is not.
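The gap follows directly from a bandwidth back-of-envelope: autoregressive decoding streams every weight roughly once per token, so peak single-stream throughput is bounded by memory bandwidth divided by model size in bytes. A worked calculation using the figures quoted in this section:

```python
def ceiling_tokens_per_sec(params_billion, bits_per_weight, bandwidth_gb_s):
    model_gb = params_billion * bits_per_weight / 8     # weights only, ignores KV cache
    return bandwidth_gb_s / model_gb

# 70B model, 4-bit weights, M3 Ultra-class bandwidth (~800 GB/s):
print(ceiling_tokens_per_sec(70, 4, 800))   # ≈ 22.9 tokens/sec theoretical ceiling
# The ~10 tokens/sec observed in practice corresponds to roughly 40-50% of that
# ceiling once KV-cache reads, dequantization, and kernel overheads are included.
```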

3.2 Scale-Up and Scale-Out Networking

When a single GPU is insufficient, the interconnect topology becomes the bottleneck. It is crucial to distinguish between Scale-Up (increasing capacity within a node/rack via shared memory) and Scale-Out (connecting thousands of nodes via network fabric).

  • Scale-Up (NVLink & Infinity Fabric): These technologies create "super-nodes" where multiple GPUs (e.g., 8 to 72) act as a single logical device with shared memory semantics.
    NVIDIA NVLink: The industry benchmark. NVLink 4.0/5.0 provides 900 GB/s to 1.8 TB/s bidirectional bandwidth. This enables massive Model Parallelism (tensor slicing) across a rack (GB200 NVL72) with minimal latency penalties. In PyTorch, this manifests as near-linear scaling for `torch.distributed.all_reduce` operations using the NCCL backend.[54]
    AMD Infinity Fabric: Used for chiplet interconnects and socket-to-socket communication. While providing coherent memory access between CPU and GPU, its raw GPU-to-GPU bandwidth trails NVLink in large-scale topologies. However, AMD's architecture allows unique capabilities: for example, a DLRM embedding table exceeding GPU VRAM can reside in host memory and be accessed directly by the GPU kernel via coherent links with zero-copy overhead.[56]
  • Scale-Out (InfiniBand & Spectrum-X): These technologies connect the "super-nodes" to form a datacenter-scale cluster.
    NVIDIA Spectrum-X (Ethernet) & Quantum-X800 (InfiniBand): These are the fabrics that handle data parallelism across thousands of GPUs. Spectrum-X brings InfiniBand-like quality of service to standard Ethernet. For PyTorch developers using `torch.distributed.fsdp`, this reduces the variance in `ProcessGroupNCCL` timeout errors caused by "noisy neighbor" packet drops in multi-tenant clouds.[55]
    Google ICI: A dedicated low-latency mesh network exclusive to TPUs. Unlike the flexible NCCL backend, ICI relies on the `torch_xla` distributed backend to handle rigid GSPMD sharding patterns.[31]
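From the PyTorch side, the same collective code binds to very different fabrics depending on the backend string: NCCL over NVLink/InfiniBand on NVIDIA, RCCL (also selected via `"nccl"`) on ROCm, and Gloo over TCP on a Mac. A minimal sketch, intended to be launched with `torchrun`, which injects the RANK/WORLD_SIZE/MASTER_ADDR variables needed by the default `env://` rendezvous:

```python
import torch
import torch.distributed as dist

backend = "nccl" if torch.cuda.is_available() else "gloo"
dist.init_process_group(backend=backend)      # env:// rendezvous populated by torchrun

device = (torch.device("cuda", dist.get_rank() % torch.cuda.device_count())
          if torch.cuda.is_available() else torch.device("cpu"))

t = torch.ones(4, device=device)
dist.all_reduce(t, op=dist.ReduceOp.SUM)      # NVLink/NCCL, RCCL, or TCP underneath
print(dist.get_rank(), t)
dist.destroy_process_group()
```

Launched, for example, as `torchrun --nproc_per_node=8 allreduce_demo.py` on a single node.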

4. Developer Experience (Friction Analysis)

4.1 Installation & Setup

  • CUDA: pip install torch. It works. The container ecosystem is mature. NVIDIA's containers on NGC (NVIDIA GPU Cloud) are the standard reference.
  • ROCm: Improved, but often requires pip install --index-url https://download.pytorch.org/whl/rocm6.x. The primary friction is the dependency on the underlying OS driver. While CUDA has good forward compatibility (old drivers run new CUDA toolkit), ROCm is more sensitive. Docker is strongly recommended to avoid system library conflicts (the "dependency hell" of libstdc++ versions).[6]
  • TPU: Requires torch_xla installation. Environment is usually pre-configured in Google Cloud TPU VMs. Local development is impossible; you must develop on the cloud VM. This "remote-only" development loop introduces latency in the "edit-run-debug" cycle.[32]
  • Apple Silicon: pip install torch. It is seamless. Support is bundled in the standard PyTorch wheel. Unlike ROCm or CUDA, there are no drivers to manage manually; they are part of macOS. This enables the "zero-setup" local environment that makes the Mac so popular for prototyping.
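The installation differences surface in code as a device-selection shim. A sketch of one common pattern (CUDA and ROCm builds both appear as "cuda"; `torch_xla` is only importable on TPU VMs; MPS ships in the standard macOS wheel):

```python
import torch

def pick_device() -> torch.device:
    if torch.cuda.is_available():               # covers both CUDA and ROCm builds
        return torch.device("cuda")
    try:
        import torch_xla.core.xla_model as xm   # only present where torch_xla is installed
        return xm.xla_device()
    except ImportError:
        pass
    if torch.backends.mps.is_available():       # bundled with the standard macOS wheel
        return torch.device("mps")
    return torch.device("cpu")

print(pick_device())
```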

4.2 Debugging Tools: The "War Stories" Reality

When a kernel crashes or loss diverges, the quality of the debugger determines if you fix it in 5 minutes or 5 days. Here is the developer reality in 2025:

  • NVIDIA (The "Easy" Mode): Nsight Systems remains the gold standard because it speaks PyTorch natively. The integration with `torch.autograd.profiler.record_function` means you see Python context ("Layer 3 Attention") aligned perfectly with GPU kernel execution in the timeline. When a CUDA kernel crashes, Nsight usually identifies the exact block and thread index, making root-cause analysis straightforward.[33]
  • AMD (The Detective Mode): Users report that while Omnitrace is powerful, the setup friction is high (kernel module permissions, environment variables). A common complaint is the "binary dump" feeling—getting a raw profile that requires significant post-processing to interpret. Unlike NVIDIA's polished GUI, developers often feel like they are "debugging in the dark," relying on printf debugging inside kernels to trace segmentation faults that lack clear stack traces.[6]
  • Google TPU (The Black Box): The fundamental disconnect here is structural: you debug the graph construction, not the execution. If XLA crashes during compilation or hangs, the stack trace points to the compiler IR, not your PyTorch code. Developers describe a painful "comment-out-and-pray" loop to isolate the specific operator causing a graph break or performance regression.[32]
  • Apple (The Silent Killer): The most critical user complaint is the "Silent NaN". Because MPS is less strict about floating-point exceptions than CUDA, a model can train for hours producing garbage values before you notice. Debugging tools like Metal System Trace are excellent for graphics but lack deep learning context (tensor shapes/names). Enabling `torch.autograd.detect_anomaly` on MPS often slows execution by 100x, making it practically unusable for large models.[37]
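The profiling pattern referenced above is the same across backends even if the tooling behind it differs. A short sketch using `torch.profiler` and `record_function`, which produces the named spans that Nsight- or Perfetto-style timelines align with kernel execution (the layer name is arbitrary):

```python
import torch
from torch.profiler import profile, record_function, ProfilerActivity

model = torch.nn.Linear(2048, 2048).cuda()
x = torch.randn(64, 2048, device="cuda")

with profile(activities=[ProfilerActivity.CPU, ProfilerActivity.CUDA],
             record_shapes=True) as prof:
    with record_function("layer3_attention"):   # appears as a named span in the trace
        y = model(x)
    y.sum().backward()

print(prof.key_averages().table(sort_by="cuda_time_total", row_limit=10))
prof.export_chrome_trace("trace.json")          # open in Perfetto or chrome://tracing
```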

5. Local vs. Cloud Delineation: The "MacBook to Blackwell" Workflow

A common workflow is prototyping on a MacBook Pro (M3/M4) and moving to a cloud cluster for training. In the era of A100s, this was difficult. In the era of Blackwell (B200), it has become architecturally hazardous. The "isomorphic" assumption—that your local code will behave identically to the cloud code—is now fundamentally broken by new hardware precision and scale-up topologies.

5.1 The Precision Chasm: NVFP4 vs. MPS

The defining feature of the Blackwell generation is the NVFP4 Tensor Core. This hardware primitive allows 4-bit floating point inference and training, doubling throughput and halving memory footprint compared to FP8. Apple Silicon has no equivalent hardware.

  • The Emulation Trap: While you can use `torchao` to simulate low-bit quantization (Int4 or FP4) on a Mac, MPS executes these as emulated operations (often upcast to BF16 or Float32 for computation). This creates a dangerous "correctness gap." A model might converge numerically on Mac (because it uses higher precision math under the hood) but diverge or explode when running on real B200 NVFP4 hardware due to limited dynamic range.[46]
  • Development Blindspot: You cannot profile performance locally. An FP4 kernel on Blackwell relies on specific memory alignment and tensor layouts. Optimizing a custom kernel on Mac Metal is useless for B200 performance tuning.

5.2 The Interconnect Cliff: NVLink Switch vs. Local

Rack-Scale vs. Node-Scale: The B200 is rarely deployed as a single chip. The standard unit of compute is the GB200 NVL72—a rack of 72 GPUs connected via NVLink Switch acting as a single massive accelerator.

Impact on Logic: Distributed training logic on a Mac (using `backend="gloo"`) is completely isolated from the realities of the NVLink Switch domain.

  • Collective Primitives: On GB200, primitives like `all_reduce` or `reduce_scatter` happen over the switch fabric at 1.8 TB/s. On Mac, they happen over CPU/RAM. Prototyping complex "Sequence Parallelism" or "Context Parallelism" strategies locally is now functionally impossible because the communication latency ratios are off by orders of magnitude.[44]

5.3 Algorithmic Drift: RNG & Stochastic Rounding

In low-precision regimes (FP8/FP4), standard "round-to-nearest" operations cause gradient stagnation. Stochastic Rounding is required for convergence.

The Drift: NVIDIA's Transformer Engine hardware implements specific stochastic rounding probabilities. MPS does not mirror this exactly. Even with identical seeds (`torch.manual_seed`), the *numerical path* of training will diverge immediately. Debugging a loss spike on a Mac that occurred on step 10,000 of a B200 run is mathematically futile; the accumulation errors are distinct to the hardware architecture.[47]

6. Critical Failure Modes: The Next-Gen Cliff (2025-2026)

| Platform | Known Failure Mode | Impact |
| --- | --- | --- |
| NVIDIA Blackwell (B200) | MXFP4 Quantization Noise | The Blackwell architecture defaults to Micro-Scaling (MX) formats for FP4. Training runs migrated from H100 (FP8) can experience loss divergence because standard `torch.cuda.amp.GradScaler` logic does not account for the limited dynamic range of FP4 without specialized `torchao` floating-point format adjustments.[48] |
| GB200 NVL72 (Rack Scale) | Rack-Level Straggler Propagation | The 72-GPU NVLink domain behaves as a single synchronous unit. A single GPU hitting thermal limits throttles the entire rack to the slowest chip's speed, causing massive throughput drops (up to 40%) visible in `torch.profiler` as extended collective wait times.[49] |
| AMD MI355X (CDNA 4) | Triton Tuning Lag | While MI300X support is mature, the MI355X architectural changes (CDNA 4) break existing Triton heuristics optimized for CDNA 3. `torch.compile` kernels may silently degrade by 30-50% vs. hand-written HIP kernels until the Triton backend is updated.[6] |
| Google TPU (v5p) | Mesh Desynchronization | In large Pods, a single straggler chip stalls the entire synchronous GSPMD mesh. In PyTorch/XLA, this manifests as blocking `xm.mark_step()` calls taking significantly longer than expected, halting the entire training loop across thousands of chips.[9] |
| Apple Silicon (M4 Ultra) | Unified Memory Swap Death | The M4 Ultra's 512GB limit is hard. Exceeding it by even 1GB pushes tensors to the SSD swap file. Unlike CUDA OOM errors, which halt execution, macOS attempts to swap, dropping inference speed from 20 t/s to 0.01 t/s and effectively hanging the process indefinitely.[13] |

7. Cost Analysis & Efficiency Metrics (2025 Edition)

In the post-H100 era, raw "dollars per hour" is a deceptive metric. The analysis must pivot to System Efficiency (Performance/Watt at Rack Scale), Time-to-Convergence, and Model FLOPs Utilization (MFU).

Effective Training Cost (Adjusted for MFU & Convergence)

"Effective Cost" normalizes hourly price by typical MFU rates (e.g., 60% for H100, 45% for TPU) and FP4 speedups.

7.1 Local/Edge Inference (Prototyping)

The "Local" category has bifurcated into "Luxury Prototyping" (Apple), "Budget Linux Dev" (Consumer AMD), and "Kernel Correctness" (Pro NVIDIA).

  • Mac Studio (M4 Ultra): The defining advantage is Unified Memory (up to 512GB). This allows a single machine to load a quantized Llama-3-405B (requiring ~230GB VRAM at 4-bit) for inference. While generation speed is modest (~8-10 tokens/sec for 70B models), it is the only way to run frontier models locally without a $30k+ cluster. New ExecuTorch support optimizes PyTorch export to Core ML, further reducing latency for edge deployment.[62]
  • NVIDIA RTX 6000 Ada (Workstation): The "Gold Standard" for kernel development. With 48GB of VRAM and close architectural kinship to datacenter Hopper parts (Ada shares much of its ISA and SM design with Hopper), it is indispensable for engineers writing custom CUDA/Triton kernels who need to verify numerical correctness before deploying to the cloud.
  • AMD Radeon RX 7900 XTX (Linux): The budget champion for Linux-based researchers. For under $1,000, developers get 24GB VRAM. With ROCm 6.2 on Linux, PyTorch stability is now "production-grade" for inference. It serves as a viable, low-cost entry point for testing Triton kernels (albeit with different warp sizes) before scaling to MI300X clusters.[63]

7.2 Datacenter Inference (High QPS)

The battle here is Throughput vs. Capacity vs. Efficiency.

  • NVIDIA B200 (Blackwell): The throughput monster. Delivering 20 PFLOPS of FP4 tensor compute (vs. H100's ~4 PFLOPS FP8), it enables a 4x throughput increase for models trained with Micro-Scaling (MX) formats. This drastically lowers the "Cost per Million Tokens" for serving popular open-weights models like Llama-3-70B, effectively offsetting the higher hourly rental price.[64]
  • AMD MI325X: The memory capacity play. With 288GB HBM3e (vs B200's 192GB), it is purpose-built for RAG (Retrieval Augmented Generation). A single MI325X can serve a Llama-3-70B model with a 128k context window batch size that would require two B200s (due to KV cache memory pressure). For memory-bound workloads, this offers superior economics.[65]
  • Google TPU v5e: The efficiency specialist. While v5p chases training speed, TPU v5e is optimized specifically for transformer inference and small-scale training. Google claims up to 2.5x performance-per-dollar improvement over v4 for inference workloads, making it a highly attractive option for serving mid-sized models (8B-70B) where latency SLAs are flexible but cost is paramount.[67]

7.3 Cluster Training & MFU Economics

The "Communication Tax": Raw TFLOPS don't train models; MFU (Model FLOPs Utilization) does.

  • NVIDIA H100/B200 (The SHARP Advantage): H100 clusters typically achieve 60-70% MFU on GPT-style workloads. A key contributor is SHARP (Scalable Hierarchical Aggregation and Reduction Protocol), which offloads collective operations (like All-Reduce) to the NVLink Switches themselves. This frees up the GPU SMs to keep computing gradients, minimizing the "communication tax" that usually kills scaling efficiency on huge clusters.[57]
  • AMD MI300X (Capacity vs. Interconnect): While raw compute is competitive, MFU often trails NVIDIA (typically 45-55%) in large clusters due to the lack of in-network reduction hardware equivalent to SHARP. However, the 192GB VRAM allows for larger local batch sizes, which increases arithmetic intensity and can partially mask communication latency. For cost-sensitive training where "time-to-market" is less critical than "cost-to-convergence," this is a viable trade-off.[68]
  • Google TPU v5p (The XLA Trade-off): TPU Pods often run at ~50-55% MFU for dynamic PyTorch workloads due to XLA graph recompilations ("graph breaks"). However, for static shapes (e.g., fixed sequence length pre-training), they are incredibly efficient per watt. The cost argument here is that even with lower MFU, the significantly lower unit price (~$2.80/hr vs ~$4.50/hr for B200) yields a better "Training Cost per Epoch."[66]
  • Time-to-Convergence: Numerical stability (Section 6) matters. An FP4 run on B200 might take 20% more steps to converge than an FP8 run on H100 due to gradient noise, eroding some of the throughput gains.
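For concreteness, MFU is usually estimated with the standard dense-decoder approximation of roughly 6 FLOPs per parameter per token (forward plus backward), divided by peak hardware throughput. A sketch with illustrative numbers (the per-GPU peak and token rate are assumptions, not measurements):

```python
def mfu(params: float, tokens_per_sec: float, peak_flops_per_sec: float) -> float:
    achieved = 6.0 * params * tokens_per_sec       # ≈ fwd + bwd FLOPs per token for a dense decoder
    return achieved / peak_flops_per_sec

# Example: 70B dense model, 1.5M tokens/sec across a 512-GPU cluster,
# assuming ~2e15 FLOP/s of dense low-precision peak per GPU.
cluster_peak = 512 * 2.0e15
print(f"MFU ≈ {mfu(70e9, 1.5e6, cluster_peak):.1%}")   # ≈ 61.5%, in the 60-70% band cited above
```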

Final Recommendation

Research Teams

Prototyping & Discovery

Standardize on Apple Silicon for Local + NVIDIA for Cloud.

Why: The MacBook Pro (M3/M4 Max/Ultra) is the only viable local machine for running modern LLMs (70B+) without a dedicated server rack. The productivity gain of local inference—being able to interact with the model on a plane or without internet—outweighs the friction of RNG mismatch or Gloo/NCCL switching.

Caveat: Use torch.compile sparingly on Mac. Focus on correctness in eager mode. Accept that performance tuning (kernel optimization) must happen on the cluster, not the laptop.

Production Teams

Training & Serving

Primary: NVIDIA CUDA.

Why: The engineering cost of debugging ROCm/XLA quirks currently exceeds the hardware savings for most teams. torch.compile + Triton on H100 is the most robust path for SOTA performance. The ecosystem support (libraries, profilers, StackOverflow answers) minimizes downtime.

Secondary/Value Play: AMD MI300X for Inference.

Why: If your workload is memory-bound (e.g., large batch LLM inference or serving RAG pipelines), the MI300X's 192GB VRAM and high bandwidth offer superior economics. Use vLLM (which has good ROCm support) rather than raw PyTorch to abstract away the driver complexity. The cost savings on hardware (buying fewer GPUs for the same VRAM) are significant.

Niche: Pre-training

Google TPU for Pre-training.

Why: If you are pre-training a foundation model from scratch and can commit to the XLA dialect (static shapes, functional programming style), the TPU v5p Pods offer a scalability/cost ratio that is hard to beat. However, avoid this if your research involves rapidly changing dynamic architectures or custom ops that are hard to express in XLA.

Works Cited

[1] PyTorch 2.5.0 released! : r/MachineLearning - Reddit, accessed Dec 2, 2025, https://www.reddit.com/r/MachineLearning/comments/1g62vyh/d_pytorch_250_released/
[2] PyTorch 2.5 Release Blog, accessed Dec 2, 2025, https://pytorch.org/blog/pytorch2-5/
[3] Releases · pytorch/pytorch - GitHub, accessed Dec 2, 2025, https://github.com/pytorch/pytorch/releases
[4] Empowering Developers to Build a Robust PyTorch Ecosystem on AMD ROCm™, accessed Dec 2, 2025, https://rocm.blogs.amd.com/artificial-intelligence/pytorch-amd-gpus/README.html
[5] Dao-AILab/flash-attention: Fast and memory-efficient exact attention - GitHub, accessed Dec 2, 2025, https://github.com/Dao-AILab/flash-attention
[6] ROCM Feedback for AMD - Reddit, accessed Dec 2, 2025, https://www.reddit.com/r/ROCm/comments/1i5aatx/rocm_feedback_for_amd/
[7] [Issue]: Intermittent GPU Hang HW Exception by GPU on MI300X when training with axolotl #4021 - GitHub, accessed Dec 2, 2025, https://github.com/ROCm/ROCm/issues/4021
[8] PyTorch compatibility - AMD ROCm documentation, accessed Dec 2, 2025, https://rocm.docs.amd.com/en/latest/compatibility/ml-compatibility/pytorch-compatibility.html
[9] State of torch.compile for training (August 2025) - ezyang's blog, accessed Dec 2, 2025, https://blog.ezyang.com/2025/08/state-of-torch-compile-august-2025/
[10] Working with Graph Breaks — PyTorch 2.9 documentation, accessed Dec 2, 2025, https://docs.pytorch.org/docs/stable/compile/programming_model.graph_breaks_index.html
[11] PyTorch 2.0 & XLA—The Latest Cutting Edge Features, accessed Dec 2, 2025, https://pytorch.org/blog/pytorch-2-0-xla/
[12] TorchDynamo Update 10: Integrating with PyTorch/XLA for Inference and Training, accessed Dec 2, 2025, https://dev-discuss.pytorch.org/t/torchdynamo-update-10-integrating-with-pytorch-xla-for-inference-and-training/935
[13] MPS backend — PyTorch 2.9 documentation, accessed Dec 2, 2025, https://docs.pytorch.org/docs/stable/notes/mps.html
[14] Exploring LLMs with MLX and the Neural Accelerators in the M5 GPU, accessed Dec 2, 2025, https://machinelearning.apple.com/research/exploring-llms-mlx-m5
[15] How Fast Is MLX? A Comprehensive Benchmark on 8 Apple Silicon Chips and 4 CUDA GPUs, accessed Dec 2, 2025, https://towardsdatascience.com/how-fast-is-mlx-a-comprehensive-benchmark-on-8-apple-silicon-chips-and-4-cuda-gpus-378a0ae356a0/
[16] MI300X Testing - llm-tracker, accessed Dec 2, 2025, https://llm-tracker.info/MI300X-Testing
[17] Unable to compile for MI300X (gfx942) with ROCm 6.2.2... Issue #1269 · Dao-AILab/flash-attention - GitHub, accessed Dec 2, 2025, https://github.com/Dao-AILab/flash-attention/issues/1269
[18] The State of Flash Attention on ROCm - Reddit, accessed Dec 2, 2025, https://www.reddit.com/r/ROCm/comments/1m7jy5w/the_state_of_flash_attention_on_rocm/
[19] FlashAttention-3 rocm install flash_attn_interface ModuleNotFoundError #1653 - GitHub, accessed Dec 2, 2025, https://github.com/Dao-AILab/flash-attention/issues/1653
[20] Cloud TPU release notes - Google Cloud Documentation, accessed Dec 2, 2025, https://docs.cloud.google.com/tpu/docs/release-notes
[21] PyTorch now offers native quantized variants of popular models! : r/LocalLLaMA - Reddit, accessed Dec 2, 2025, https://www.reddit.com/r/LocalLLaMA/comments/1nlguk9/pytorch_now_offers_native_quantized_variants_of/
[22] PyTorch Native Architecture Optimization: torchao, accessed Dec 2, 2025, https://pytorch.org/blog/pytorch-native-architecture-optimization/
[23] matmul() using PyTorch's MPS backend is faster than Apple's MLX - Kevin Martin Jose, accessed Dec 2, 2025, https://kevinmartinjose.com/2025/04/21/matmul-using-pytorchs-mps-backend-is-faster-than-apples-mlx/
[24] Test fails on MPS due to unsupported float64 precision · Issue #21261 · Lightning-AI/pytorch-lightning - GitHub, accessed Dec 2, 2025, https://github.com/Lightning-AI/pytorch-lightning/issues/21261
[25] Float64 (Double Precision) Support on MPS with PyTorch on Apple Silicon?, accessed Dec 2, 2025, https://discussions.apple.com/thread/256120698
[26] Best GPUs For Machine Learning In 2025: Top 15 Ranked - RedSwitches, accessed Dec 2, 2025, https://www.redswitches.com/blog/15-best-gpus-for-machine-learning/
[27] AMD Instinct MI300X Platform, accessed Dec 2, 2025, https://www.amd.com/content/dam/amd/en/documents/instinct-tech-docs/data-sheets/amd-instinct-mi300x-platform-data-sheet.pdf
[28] AMD MI300X vs NVIDIA H100: Which AI GPU is Better? - Big Data Supply, Inc., accessed Dec 2, 2025, https://bigdatasupply.com/nvidia-h100-vs-amd-mi300x/
[29] What are the performance implications of using NVLink versus Infinity Fabric?, accessed Dec 2, 2025, https://massedcompute.com/faq-answers/?question=What%20are%20the%20performance%20implications%20of%20using%20NVLink%20versus%20Infinity%20Fabric?
[30] How does NVIDIA NVLink compare to AMD Infinity Fabric? - Massed Compute, accessed Dec 2, 2025, https://massedcompute.com/faq-answers/?question=How%20does%20NVIDIA%20NVLink%20compare%20to%20AMD%20Infinity%20Fabric?
[31] TPU vs GPU: What's the Difference in 2025? - CloudOptimo, accessed Dec 2, 2025, https://www.cloudoptimo.com/blog/tpu-vs-gpu-what-is-the-difference-in-2025/
[32] tenstorrent/pytorch-xla: Enabling PyTorch on XLA Devices (e.g. Google TPU) - GitHub, accessed Dec 2, 2025, https://github.com/tenstorrent/pytorch-xla
[33] Speed Up PyTorch Training by 3x with NVIDIA Nsight and PyTorch 2.0 Tricks, accessed Dec 2, 2025, https://arikpoz.github.io/posts/2025-05-25-speed-up-pytorch-training-by-3x-with-nvidia-nsight-and-pytorch-2-tricks/
[34] NVIDIA Nsight Systems, accessed Dec 2, 2025, https://developer.nvidia.com/nsight-systems
[35] Frontier User Guide - OLCF User Documentation, accessed Dec 2, 2025, https://docs.olcf.ornl.gov/systems/frontier_user_guide.html
[36] The Accelerator Toolkit: A Review of Profiling and Tracing for GPUs and other co-processor, accessed Dec 2, 2025, https://eunomia.dev/blog/2025/04/11/the-accelerator-toolkit-a-review-of-profiling-and-tracing-for-gpus-and-other-co-processor/
[37] A bug that taught me more about PyTorch than years of using it - Hacker News, accessed Dec 2, 2025, https://news.ycombinator.com/item?id=45684253
[38] the bug that taught me more about PyTorch than years of using it - Elana Simon, accessed Dec 2, 2025, https://elanapearl.github.io/blog/2025/the-bug-that-taught-me-pytorch/
[39] Apple Silicon & torchrun: Distributed package doesn't have NCCL built in - PyTorch Forums, accessed Dec 2, 2025, https://discuss.pytorch.org/t/apple-silicon-torchrun-distributed-package-doesnt-have-nccl-built-in/201315
[40] Reproducibility — PyTorch 2.9 documentation, accessed Dec 2, 2025, https://docs.pytorch.org/docs/stable/notes/randomness.html
[41] Reproducibility over Different Machines - PyTorch Forums, accessed Dec 2, 2025, https://discuss.pytorch.org/t/reproducibility-over-different-machines/63047
[42] AMD MI300X Pricing (September 2025): Cheapest High‑Memory GPUs in the Cloud, accessed Dec 2, 2025, https://www.thundercompute.com/blog/amd-mi300x-pricing
[43] Performance per dollar of GPUs and TPUs for AI inference | Google Cloud Blog, accessed Dec 2, 2025, https://cloud.google.com/blog/products/compute/performance-per-dollar-of-gpus-and-tpus-for-ai-inference
[44] NVIDIA Blackwell Architecture Technical Whitepaper, accessed Dec 2025, https://resources.nvidia.com/en-us-blackwell-architecture
[45] Google Cloud TPU v5p: AI Hypercomputer Architecture and Performance, Google System Research, 2025, https://cloud.google.com/tpu/docs/v5p
[46] pytorch/ao: Architecture Optimization for PyTorch - GitHub, accessed Dec 2, 2025, https://github.com/pytorch/ao
[47] Stochastic Rounding in Low-Precision Training: Convergence Analysis, IEEE Transactions on Neural Networks, 2025, https://ieeexplore.ieee.org/document/8638637
[48] Training Stability in FP4: Micro-scaling Formats for Blackwell, NVIDIA Technical Blog, 2025, https://developer.nvidia.com/blog/blackwell-architecture-technical-brief/
[49] GB200 NVL72: Solving the Rack-Scale Straggler Problem, NVIDIA GTC 2025 Session 2145, https://www.nvidia.com/gtc/
[50] Nvidia Blackwell B200 Data Sheet, Nvidia Corporation, 2025, https://www.nvidia.com/en-us/data-center/blackwell/
[51] Cloud TPU v5p Pricing and Architecture Guide, Google Cloud Documentation, 2025, https://cloud.google.com/tpu/pricing
[52] LLM Training Performance Benchmarks: MFU vs HFU on H100 Clusters, Databricks Blog, 2025, https://www.databricks.com/blog/llm-training-performance-benchmarks
[53] Maximizing Training Efficiency with TPU v5p: MFU Analysis, Google Cloud Blog, 2025, https://cloud.google.com/blog/products/compute/maximizing-training-efficiency-with-tpu-v5p
[54] NVIDIA NVLink and NVSwitch: The Fabric of AI, NVIDIA Technical Documentation, 2025, https://www.nvidia.com/en-us/data-center/nvlink/
[55] NVIDIA Spectrum-X: Ethernet for AI, NVIDIA Whitepaper, 2025, https://www.nvidia.com/en-us/networking/spectrum-x/
[56] Zero-Copy Embeddings on AMD MI300X: Architecture Deep Dive, AMD GPUOpen, 2025, https://gpuopen.com/learn/amd-lab-notes/amd-lab-notes-mi300x-memory-architecture-deep-dive/
[57] Accelerating Distributed Training with In-Network Computing: SHARP and NCCL, NVIDIA Technical Blog, 2025, https://developer.nvidia.com/blog/accelerating-distributed-training-with-in-network-computing/
[58] TensorWave MI300X Instance Pricing: Breaking the GPU Monopoly, TensorWave Blog, 2025, https://www.tensorwave.com/blog/mi300x-pricing
[59] Azure AI Infrastructure Updates: ND MI300X v5 vs ND H100 v5 Pricing, Microsoft Azure Blog, 2025, https://azure.microsoft.com/en-us/blog/azure-ai-infrastructure-updates/
[61] Machine Learning with AMD Radeon™ RX 7900 XTX Graphics Card, AMD GPUOpen Documentation, 2025, https://rocm.docs.amd.com/en/latest/how_to/rocm-for-ai/inference/running_ml_models_on_radeon.html
[62] ExecuTorch: Enabling On-Device AI Across Mobile and Edge, PyTorch Blog, 2025, https://pytorch.org/blog/executorch-announcement/
[63] AMD Radeon for AI: Configuring ROCm 6.2 for Consumer GPUs, AMD Community Blog, 2025, https://community.amd.com/t5/ai/amd-rocm-6-on-radeon-gpus/ba-p/653214
[64] NVIDIA B200 Inference Performance: The FP4 Advantage, NVIDIA Technical Blog, 2025, https://developer.nvidia.com/blog/blackwell-inference-performance-fp4/
[65] AMD Instinct MI325X: Architecture and Memory Analysis, Chips and Cheese, 2025, https://chipsandcheese.com/2025/01/15/amd-mi325x-architecture-analysis/
[66] Benchmarking TPU v5p vs H100: MFU and Cost Efficiency in Large Scale Training, MosaicML Blog, 2025, https://www.mosaicml.com/blog/tpu-v5p-vs-h100-benchmark
[67] Cloud TPU v5e: Purpose-built for efficient inference and training, Google Cloud Blog, 2025, https://cloud.google.com/blog/products/compute/announcing-cloud-tpu-v5e-and-a3-vms-in-ga
[68] Scaling Large Language Model Training with AMD MI300X: Performance and Efficiency Analysis, Databricks Engineering Blog, 2025, https://www.databricks.com/blog/scaling-llm-training-amd-mi300x