前面博客添加链接描述的实现逻辑是:以blockIdx.x作为矩阵Q的第i行输入,blockIdx.y × blockDim.y + threadIdx.y作为矩阵V的第j行输入,中间遍历矩阵K,V时使用for循环串行遍历。在for循环里面相当于每次只能获得QK.T的一个元素,我们的优化策略是以threadIdx.y来加速QK.T的矩阵乘法。这种做法的缺点是:如果N很大,d很小,那么会造成大量的线程浪费闲置,尤其是当d=1的时候,此时算法接近串行。
Q[i,:]=blockIdx.y,V[:,j]=blockIdx.x × blockDim.x + threadIdx.x
为此我们重新修改flash attention的实现逻辑:使用blockIdx.y来作为矩阵Q的第i行对应的索引,采取二维线程块,blockIdx.x × blockDim.x + threadIdx.x作为矩阵V的第j列索引,在遍历K,V的时候,我们引入threadIdx.y,每轮循环可以遍历BLOCK_DIM_y行K,V的元素。这样的话,如果d太小,我们就可以增大BLOCK_DIM_y的取值,加速K,V行的遍历。
#include
文章评论