Jintao Zhang
Publications
HybridStitch: Pixel and Timestep Level Model Stitching for Diffusion Acceleration
Diffusion models have demonstrated remarkable ability in Text-to-Image (T2I) generation. Despite their high output quality, they incur heavy computation overhead, especially large models with tens of billions of parameters. Prior work has shown that replacing the large model with a smaller one for part of the denoising steps still maintains generation quality. However, these methods only save computation at the granularity of whole timesteps, ignoring differences in compute demand within a single timestep. In this work, we propose HybridStitch, a new T2I generation paradigm that treats generation like editing. Specifically, we introduce a hybrid stage that jointly uses both the large model and the small model. HybridStitch separates the image into two regions: one that is relatively easy to render, enabling an early transition to the smaller model, and another that is more complex and therefore requires refinement by the large model. HybridStitch employs the small model to construct a coarse sketch while using the large model to edit and refine the complex regions. According to our evaluation, HybridStitch achieves a 1.83$\times$ speedup on Stable Diffusion 3, which is faster than all existing mixture-of-models methods.
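As a rough illustration of the pixel-level idea, the sketch below blends two per-pixel predictions using a binary region mask. The function name `stitch_predictions`, the mask, and the toy values are assumptions made for illustration; how HybridStitch actually selects the complex regions is not described in the abstract.

```python
from typing import List

Grid = List[List[float]]
Mask = List[List[bool]]


def stitch_predictions(small: Grid, large: Grid, complex_mask: Mask) -> Grid:
    """Use the large model's value where the mask is True, else the small model's."""
    return [
        [l if m else s for s, l, m in zip(row_s, row_l, row_m)]
        for row_s, row_l, row_m in zip(small, large, complex_mask)
    ]


# Toy 2x2 "images": stand-ins for the two models' denoising outputs.
small = [[0.1, 0.2], [0.3, 0.4]]
large = [[0.9, 0.8], [0.7, 0.6]]
# Hypothetical mask: only the top-right pixel is treated as complex.
mask = [[False, True], [False, False]]

stitched = stitch_predictions(small, large, mask)
# Fraction of pixels that would need the large model in this toy case.
large_fraction = sum(m for row in mask for m in row) / 4
```

In this toy setting the large model contributes only the masked pixel, which is the intuition behind saving compute within a single timestep.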
SageBwd: A Trainable Low-bit Attention
Low-bit attention, such as SageAttention, has emerged as an effective approach for accelerating model inference, but its applicability to training remains poorly understood. In prior work, we introduced SageBwd, a trainable INT8 attention that quantizes six of seven attention matrix multiplications while preserving fine-tuning performance. However, SageBwd exhibited a persistent performance gap to full-precision attention (FPA) during pre-training. In this work, we investigate why this gap occurs and demonstrate that SageBwd matches full-precision attention during pre-training. Through experiments and theoretical analysis, we reach several conclusions: (i) QK-norm is necessary for stable training at large tokens per step, (ii) quantization errors primarily arise from the backward-pass score gradient dS, (iii) reducing tokens per step enables SageBwd to match FPA performance in pre-training, and (iv) K-smoothing remains essential for training stability, while Q-smoothing provides limited benefit during pre-training.
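To make the role of K-smoothing concrete, here is a minimal sketch that assumes K-smoothing means subtracting the per-channel mean of K across tokens before quantization. This shifts every score in a query's row by the same constant, so the softmax is unchanged, while the smaller value range lets INT8 resolve finer differences. The symmetric per-tensor quantizer, all function names, and the toy inputs are assumptions for illustration, not SageBwd's actual kernels.

```python
import math
from typing import List, Tuple

Matrix = List[List[float]]


def softmax(row: List[float]) -> List[float]:
    m = max(row)
    exps = [math.exp(x - m) for x in row]
    total = sum(exps)
    return [e / total for e in exps]


def matmul_transposed(a: Matrix, b: Matrix) -> Matrix:
    """Return a @ b^T for row-major lists."""
    return [[sum(x * y for x, y in zip(ra, rb)) for rb in b] for ra in a]


def smooth_keys(k: Matrix) -> Matrix:
    """Subtract the per-channel mean over tokens (assumed form of K-smoothing)."""
    means = [sum(col) / len(k) for col in zip(*k)]
    return [[x - mu for x, mu in zip(row, means)] for row in k]


def quantize_int8(m: Matrix) -> Tuple[List[List[int]], float]:
    """Symmetric per-tensor INT8 quantization: returns integer matrix and scale."""
    max_abs = max(abs(x) for row in m for x in row)
    scale = max_abs / 127.0 if max_abs > 0 else 1.0
    q = [[max(-127, min(127, round(x / scale))) for x in row] for row in m]
    return q, scale


def int8_scores(q: Matrix, k: Matrix) -> Matrix:
    """Quantize Q and K, multiply in integers, then dequantize the scores."""
    q_int, q_scale = quantize_int8(q)
    k_int, k_scale = quantize_int8(k)
    s_int = matmul_transposed(q_int, k_int)
    return [[v * q_scale * k_scale for v in row] for row in s_int]


def max_err(a: Matrix, b: Matrix) -> float:
    return max(abs(x - y) for ra, rb in zip(a, b) for x, y in zip(ra, rb))


# Toy inputs: keys share a large common offset, which wastes INT8 range.
q = [[0.5, -1.0], [1.5, 0.25]]
k = [[10.2, 9.8], [10.0, 10.1], [9.7, 10.3]]

exact = [softmax(r) for r in matmul_transposed(q, k)]
smoothed_exact = [softmax(r) for r in matmul_transposed(q, smooth_keys(k))]
plain_int8 = [softmax(r) for r in int8_scores(q, k)]
smooth_int8 = [softmax(r) for r in int8_scores(q, smooth_keys(k))]

err_plain = max_err(plain_int8, exact)
err_smooth = max_err(smooth_int8, exact)
```

On this toy input the smoothed variant gives the smaller attention-probability error. That is only an illustration of the mechanism, not a reproduction of the paper's results.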
SLA2: Sparse-Linear Attention with Learnable Routing and QAT
Sparse-Linear Attention (SLA) combines sparse and linear attention to accelerate diffusion models and has shown strong performance in video generation. However, (i) SLA relies on a heuristic split that assigns computations to the sparse or linear branch based on attention-weight magnitude, which can be suboptimal. Additionally, (ii) after formally analyzing the attention error in SLA, we identify a mismatch between SLA and a direct decomposition into sparse and linear attention. We propose SLA2, which introduces (I) a learnable router that dynamically selects whether each attention computation should use sparse or linear attention, (II) a more faithful and direct sparse-linear attention formulation that uses a learnable ratio to combine the sparse and linear attention branches, and (III) a sparse + low-bit attention design, where low-bit attention is introduced via quantization-aware fine-tuning to reduce quantization error. Experiments show that on video diffusion models, SLA2 can achieve 97% attention sparsity and deliver an 18.6x attention speedup while preserving generation quality.
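The idea of combining a sparse branch and a linear branch with a ratio can be sketched for a single query as follows. Everything specific here is an assumption for illustration: the top-k selection stands in for SLA2's learnable router, the `elu(x) + 1` feature map is one common choice for linear attention, and `ratio` is a fixed number rather than a learned parameter.

```python
import math
from typing import List

Vector = List[float]


def dot(a: Vector, b: Vector) -> float:
    return sum(x * y for x, y in zip(a, b))


def sparse_branch(q: Vector, keys: List[Vector], values: List[Vector], keep: List[int]) -> Vector:
    """Softmax attention restricted to the key indices in `keep`."""
    scores = [dot(q, keys[j]) for j in keep]
    m = max(scores)
    weights = [math.exp(s - m) for s in scores]
    z = sum(weights)
    dim = len(values[0])
    return [sum(w * values[j][d] for w, j in zip(weights, keep)) / z for d in range(dim)]


def feature_map(x: Vector) -> Vector:
    """elu(x) + 1, a positive feature map often used for linear attention."""
    return [v + 1.0 if v > 0 else math.exp(v) for v in x]


def linear_branch(q: Vector, keys: List[Vector], values: List[Vector], rest: List[int]) -> Vector:
    """Kernelized attention over the key indices in `rest` (looped for clarity)."""
    dim = len(values[0])
    phi_q = feature_map(q)
    num, den = [0.0] * dim, 0.0
    for j in rest:
        w = dot(phi_q, feature_map(keys[j]))
        den += w
        for d in range(dim):
            num[d] += w * values[j][d]
    return [n / den for n in num] if den > 0 else [0.0] * dim


def sparse_linear_attention(q, keys, values, top_k: int, ratio: float) -> Vector:
    """Blend the two branches: ratio * sparse + (1 - ratio) * linear."""
    scores = [dot(q, k) for k in keys]
    order = sorted(range(len(keys)), key=lambda j: scores[j], reverse=True)
    keep, rest = order[:top_k], order[top_k:]
    s = sparse_branch(q, keys, values, keep)
    l = linear_branch(q, keys, values, rest)
    return [ratio * a + (1.0 - ratio) * b for a, b in zip(s, l)]


q = [1.0, 0.5]
keys = [[1.0, 0.0], [0.0, 1.0], [0.5, 0.5], [-1.0, 0.2]]
values = [[1.0, 0.0], [0.0, 1.0], [0.5, 0.5], [0.2, 0.8]]
out = sparse_linear_attention(q, keys, values, top_k=2, ratio=0.8)
```

With `top_k` equal to the number of keys and `ratio=1.0`, this sketch reduces to ordinary softmax attention, which is a convenient sanity check.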
Residual Context Diffusion Language Models
Diffusion Large Language Models (dLLMs) have emerged as a promising alternative to purely autoregressive language models because they can decode multiple tokens in parallel. However, state-of-the-art block-wise dLLMs rely on a "remasking" mechanism that decodes only the most confident tokens and discards the rest, effectively wasting computation. We demonstrate that recycling computation from the discarded tokens is beneficial, as these tokens retain contextual information useful for subsequent decoding iterations. In light of this, we propose Residual Context Diffusion (RCD), a module that converts these discarded token representations into contextual residuals and injects them back for the next denoising step. RCD uses a decoupled two-stage training pipeline to bypass the memory bottlenecks associated with backpropagation. We validate our method on both long CoT reasoning (SDAR) and short CoT instruction following (LLaDA) models. We demonstrate that a standard dLLM can be efficiently converted to the RCD paradigm with only ~1 billion tokens. RCD consistently improves frontier dLLMs by 5-10 points in accuracy across a wide range of benchmarks, with minimal extra computation overhead. Notably, on the most challenging AIME tasks, RCD nearly doubles baseline accuracy and requires up to 4-5x fewer denoising steps at equivalent accuracy.
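A toy sketch of one remasking step may help. It commits the most confident masked positions and, instead of throwing the rest away, keeps a residual vector for each discarded position. Computing the residual as the probability-weighted average of token embeddings is an assumption made here for illustration; the abstract does not specify how RCD builds or injects its residuals. The names `decode_step`, `MASK`, and `num_commit` are likewise hypothetical.

```python
from typing import Dict, List, Tuple

MASK = -1  # hypothetical sentinel for a still-masked position


def decode_step(
    tokens: List[int],
    probs: List[List[float]],
    embeddings: List[List[float]],
    num_commit: int,
) -> Tuple[List[int], Dict[int, List[float]]]:
    """Commit the `num_commit` most confident masked positions; keep residuals for the rest."""
    masked = [i for i, t in enumerate(tokens) if t == MASK]
    # Confidence of a position = probability of its most likely token.
    by_confidence = sorted(masked, key=lambda i: max(probs[i]), reverse=True)
    commit = set(by_confidence[:num_commit])

    new_tokens = list(tokens)
    for i in commit:
        new_tokens[i] = probs[i].index(max(probs[i]))

    dim = len(embeddings[0])
    residuals: Dict[int, List[float]] = {}
    for i in masked:
        if i in commit:
            continue
        # Assumed residual: expected token embedding under the predicted distribution.
        residuals[i] = [
            sum(p * embeddings[t][d] for t, p in enumerate(probs[i]))
            for d in range(dim)
        ]
    return new_tokens, residuals


# Vocabulary of 3 tokens with 2-dimensional toy embeddings.
embeddings = [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]]
tokens = [2, MASK, MASK, MASK]
probs = [
    [0.0, 0.0, 1.0],     # position 0 is already decoded
    [0.9, 0.05, 0.05],   # confident
    [0.4, 0.35, 0.25],   # uncertain, will stay masked
    [0.1, 0.7, 0.2],     # fairly confident
]

new_tokens, residuals = decode_step(tokens, probs, embeddings, num_commit=2)
```

In a plain remasking scheme the prediction at the uncertain position would simply be dropped. Here it survives as a residual that a later step could consume.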