cpu: rv64: brgemm: add bias fusion for rv64 brgemm kernel #5150
zhangjian29 wants to merge 2 commits
Conversation
```cpp
const float beta_val = first_kpos ? 0.0f : 1.0f;
brgemm_kernel_execute(
        brg_kernel, A, B, C, valid_ow, beta_val);
const float *bias_ptr = first_kpos && jcp.with_bias
```
This misses bias for padded convolution because `first_kpos` can be true for a BRGEMM call that does not cover the full OW range. For example, `./benchdnn --conv --mode=C --dir=FWD_B --dt=f32 --stag=acdb --wtag=cdba --dtag=acdb mb1ic16ih20iw20oc16oh20ow20kh3kw3sh1sw1ph1pw1` will fail.
You're right, thanks for catching this. In padded convolutions the first BRGEMM call may only cover a partial OW range (e.g., with pw=1, kw=3: kw=0 covers OW[1..OW-1]), so positions like OW[0] never receive bias.
- Fix: Instead of fusing bias into the first BRGEMM call, the output is now initialized to bias values when !with_sum && with_bias, then all BRGEMM calls accumulate with beta=1. This covers every OW position regardless of padding. For with_sum && with_bias, a scalar bias add runs after the loop (same as the original code on main).
The first_kpos variable and bias pointer passing are removed from the conv path entirely. The JIT kernel's bias fusion is still used for the inner product path, where K-split doesn't have this partial-coverage issue.
Description
This PR introduces fused bias addition in the RV64 BRGEMM JIT kernel, eliminating the separate scalar bias loop that previously ran after the BRGEMM computation. When bias is present, the bias vector is now added to the accumulators inside the kernel using RVV vector operations before storing results to C, following the same pattern as `rvv_gemm_f32`.

This initial version provides:
- bias fusion inside the BRGEMM JIT kernel
- `brgemm_kernel_execute` with correct per-M-tile offset
- updated `rvv_brgemm_convolution_fwd` and `rvv_brgemm_inner_product_fwd` callers

Key Features
- Bias is added to the accumulators with `vfadd_vv` before the C-store phase, avoiding a separate scalar pass over the output
- A null check (`beq reg_bias, x0`) skips the bias path; when no bias is present, the overhead is a single branch instruction
- `f32` is supported in this initial version
- The `brgemm_kernel_execute` signature adds `ptr_bias = nullptr` as a default parameter, so existing callers (e.g., Winograd) are unaffected

Implementation Details
The bias vector has length M (one element per output channel/row). In the JIT kernel:
- The bias pointer is read from `brgemm_kernel_params_t::ptr_bias` (offset 56) into callee-saved registers
- The bias values are loaded into `v_tmp` (LMUL=m4) and added to all accumulator vectors (`v_c0`-`v_c3`) via `vfadd_vv`

The bias is only applied on the first K-block (`kb == 0`); subsequent K-blocks pass `nullptr`. In convolution, bias is passed only on the first kernel-position call; in inner product, bias is passed on the first `brgemm_kernel_execute` call.

For the inner product split-M path (when `MB < nthr`), each thread handles a subset of M tiles with the correctly offset bias pointer (`bia + m_offset`), eliminating the per-thread scalar bias loop entirely.

Checklist
General
Performance improvements
All experiments are performed on a Spacemit X60 platform with VLEN=128. We draw comparisons among:
- the `main` branch (bias added in a separate scalar loop after BRGEMM)
- this PR (bias fused into the BRGEMM kernel)

Correctness Evaluation
Test command:
All `brgemm:rvv`, `gemm:rvv` and `jit_1x1:rvv` tests pass.

Single-Core Performance
Inner Product (brgemm:rvv)
IP Transformer_lt total: 464.02 ms → 443.12 ms (+4.5%)
Convolution (brgemm:rvv)
Convolution performance is largely unchanged because the bias loop overhead was already small relative to the GEMM compute time.
8-Core Performance
Inner Product (brgemm:rvv, split-M path)
The 8-core split-M path shows a +38% improvement for `resnet:ip1` (MB=1, OC=1000). When MB < nthr, each thread handles a subset of M tiles. The previous scalar bias loop ran per-thread over all MB rows; fusing it into the kernel eliminates this overhead and replaces scalar ops with vector ops (LMUL=m4).

Convolution (brgemm:rvv)
Dispatch Coverage
Covered cases dispatch to `brgemm:rvv`; the remaining cases dispatch to `jit_1x1:rvv`, `gemm:rvv`, or `wino:rvv`.

Key Observations
Future Plans
- Support the `f16` data type when Zvfh support is added to the BRGEMM kernel