🔥码云GVP开源项目 12k star Uniapp+ElementUI 功能强大 支持多语言、二开方便! 广告
# 5.8。示例 > 原文: [http://numba.pydata.org/numba-doc/latest/roc/examples.html](http://numba.pydata.org/numba-doc/latest/roc/examples.html) ## 5.8.1。矩阵乘法 以下是使用 HSA 内核的矩阵乘法的简单实现: ```py @roc.jit def matmul(A, B, C): i = roc.get_global_id(0) j = roc.get_global_id(1) if i >= C.shape[0] or j >= C.shape[1]: return tmp = 0 for k in range(A.shape[1]): tmp += A[i, k] * B[k, j] C[i, j] = tmp ``` 这种实现很简单直观但性能很差,因为相同的矩阵元素将从设备内存中多次加载,这很慢(某些设备可能有透明的数据缓存,但它们可能不够大,不能一次保存整个输入)。 如果我们使用阻塞算法来减少对设备内存的访问,则会更快。 HSA 为组中的工作项提供快速[共享内存](memory.html#roc-shared-memory),以协同计算任务。以下实现了使用共享内存的方形矩阵乘法的更快版本: ```py import numpy as np from numba import roc from numba import float32 from time import time as timer blocksize = 16 gridsize = 16 @roc.jit('(float32[:,:], float32[:,:], float32[:,:])') def matmulfast(A, B, C): x = roc.get_global_id(0) y = roc.get_global_id(1) tx = roc.get_local_id(0) ty = roc.get_local_id(1) sA = roc.shared.array(shape=(blocksize, blocksize), dtype=float32) sB = roc.shared.array(shape=(blocksize, blocksize), dtype=float32) if x >= C.shape[0] or y >= C.shape[1]: return tmp = 0 for i in range(gridsize): # preload sA[tx, ty] = A[x, ty + i * blocksize] sB[tx, ty] = B[tx + i * blocksize, y] # wait for preload to end roc.barrier(1) # compute loop for j in range(blocksize): tmp += sA[tx, j] * sB[j, ty] # wait for compute to end roc.barrier(1) C[x, y] = tmp N = gridsize * blocksize A = np.random.random((N, N)).astype(np.float32) B = np.random.random((N, N)).astype(np.float32) C = np.zeros_like(A) griddim = gridsize, gridsize blockdim = blocksize, blocksize with roc.register(A, B, C): ts = timer() matmulfast[griddim, blockdim](A, B, C) te = timer() print("1st GPU time:", te - ts) with roc.register(A, B, C): ts = timer() matmulfast[griddim, blockdim](A, B, C) te = timer() print("2nd GPU time:", te - ts) ts = timer() ans = np.dot(A, B) te = timer() print("CPU time:", te - ts) np.testing.assert_allclose(ans, C, rtol=1e-5) ``` 由于共享内存是有限的资源,因此代码会从输入数组一次预加载一个小块。然后,它调用 [`barrier()`](memory.html#numba.roc.barrier "numba.roc.barrier") 等待所有线程完成预加载,然后再对共享内存进行计算。它在计算后再次同步,以确保所有线程在共享内存中完成数据,然后在下一次循环迭代中覆盖它。