本文整理汇总了C++中af::matmul方法的典型用法代码示例。如果您正苦于以下问题:C++ af::matmul方法的具体用法?C++ af::matmul怎么用?C++ af::matmul使用的例子?那么恭喜您, 这里精选的方法代码示例或许可以为您提供帮助。您也可以进一步了解该方法所在类af
的用法示例。
在下文中一共展示了af::matmul方法的2个代码示例,这些例子默认根据受欢迎程度排序。您可以为喜欢或者感觉有用的代码点赞,您的评价将有助于系统推荐出更棒的C++代码示例。
示例1: randu
TEST(MatrixMultiply, RhsBroadcastBatched)
{
const int M = 512;
const int K = 512;
const int N = 10;
const int D2 = 2;
const int D3 = 3;
for (int d3 = 1; d3 <= D3; d3 *= D3) {
for (int d2 = 1; d2 <= D2; d2 *= D2) {
array a = randu(M, K, d2, d3);
array b = randu(K, N);
array c = matmul(a, b);
for (int j = 0; j < d3; j++) {
for (int i = 0; i < d2; i++) {
array a_ij = a(span, span, i, j);
array c_ij = c(span, span, i, j);
array res = matmul(a_ij, b);
EXPECT_LT(max<float>(abs(c_ij - res)), 1E-3)
<< " for d2 = " << d2 << " for d3 = " << d3;
}
}
}
}
}
示例2: cppMatMulCheck
void cppMatMulCheck(string TestFile)
{
if (noDoubleTests<T>()) return;
vector<dim4> numDims;
vector<vector<T> > hData;
vector<vector<T> > tests;
readTests<T,T,int>(TestFile, numDims, hData, tests);
array a(numDims[0], &hData[0].front());
array b(numDims[1], &hData[1].front());
dim4 atdims = numDims[0];
{
dim_t f = atdims[0];
atdims[0] = atdims[1];
atdims[1] = f;
}
dim4 btdims = numDims[1];
{
dim_t f = btdims[0];
btdims[0] = btdims[1];
btdims[1] = f;
}
array aT = moddims(a, atdims.ndims(), atdims.get());
array bT = moddims(b, btdims.ndims(), btdims.get());
vector<array> out(tests.size());
if(isBVector) {
out[0] = matmul(aT, b, AF_MAT_NONE, AF_MAT_NONE);
out[1] = matmul(bT, a, AF_MAT_NONE, AF_MAT_NONE);
out[2] = matmul(b, a, AF_MAT_TRANS, AF_MAT_NONE);
out[3] = matmul(bT, aT, AF_MAT_NONE, AF_MAT_TRANS);
out[4] = matmul(b, aT, AF_MAT_TRANS, AF_MAT_TRANS);
}
else {
out[0] = matmul(a, b, AF_MAT_NONE, AF_MAT_NONE);
out[1] = matmul(a, bT, AF_MAT_NONE, AF_MAT_TRANS);
out[2] = matmul(a, bT, AF_MAT_TRANS, AF_MAT_NONE);
out[3] = matmul(aT, bT, AF_MAT_TRANS, AF_MAT_TRANS);
}
for(size_t i = 0; i < tests.size(); i++) {
dim_t elems = out[i].elements();
vector<T> h_out(elems);
out[i].host((void*)&h_out.front());
if (false == equal(h_out.begin(), h_out.end(), tests[i].begin())) {
cout << "Failed test " << i << "\nCalculated: " << endl;
copy(h_out.begin(), h_out.end(), ostream_iterator<T>(cout, ", "));
cout << "Expected: " << endl;
copy(tests[i].begin(), tests[i].end(), ostream_iterator<T>(cout, ", "));
FAIL();
}
}
}