本文整理汇总了Python中DB.patternFittedBlocks方法的典型用法代码示例。如果您正苦于以下问题：Python DB.patternFittedBlocks方法的具体用法？Python DB.patternFittedBlocks怎么用？Python DB.patternFittedBlocks使用的例子？那么恭喜您，这里精选的方法代码示例或许可以为您提供帮助。您也可以进一步了解该方法所在模块DB的用法示例。
在下文中一共展示了DB.patternFittedBlocks方法的1个代码示例，这些例子默认根据受欢迎程度排序。您可以为喜欢或者感觉有用的代码点赞，您的评价将有助于系统推荐出更棒的Python代码示例。
示例1：gemms
# 需要导入模块: import DB [as 别名]
# 或者: from DB import patternFittedBlocks [as 别名]
def gemms(self, matrices):
nameToIndex = dict()
for i,matrix in enumerate(matrices):
nameToIndex[matrix.name] = i
sparsityPatterns = [matrix.spp for matrix in matrices]
# Eliminate irrelevant entries in the matrix multiplication
equivalentSparsityPatterns = Sparse.equivalentMultiplicationPatterns(sparsityPatterns)
# Fit the equivalent sparsity pattern tightly for each memory block
fittedBlocks = [DB.patternFittedBlocks(matrix.blocks, equivalentSparsityPatterns[i]) for i, matrix in enumerate(matrices)]
# Find the actual implementation pattern, i.e. dense matrix -> dense pattern
implementationPatterns = [matrix.getImplementationPattern(fittedBlocks[i], equivalentSparsityPatterns[i]) for i, matrix in enumerate(matrices)]
# Determine the matrix multiplication order based on the implementation pattern
chainOrder, dummy = Sparse.sparseMatrixChainOrder(implementationPatterns)
self.nonZeroFlops += Sparse.calculateOptimalSparseFlops(equivalentSparsityPatterns)
# convert matrix chain order to postfix
stack = list()
output = list()
stack.append((0, len(matrices)-1))
while len(stack) > 0:
current = stack[-1]
i = current[0]
j = current[1]
stack.pop()
if (i == j): # matrix chain is a single matrix
output.append(matrices[i]) # post the matrix A_i
else: # subproblem A_i * ... * A_j
output.append('*') # post a multiplication
k = chainOrder[current[0], current[1]] # split position A_i..k * A_(k+1)j
stack.append((current[0], k)) # post A_i..k
stack.append((k+1, current[1])) # post A_(k+1)j
# parse postfix
operands = list()
tempCounter = len(self.temps)
while len(output) > 0:
top = output.pop()
if top != '*':
operands.append(top)
else:
# In the following we generate instructions for op1 * op2.
op2 = operands.pop()
op1 = operands.pop()
blocks1 = fittedBlocks[nameToIndex[op1.name]] if nameToIndex.has_key(op1.name) else op1.blocks
blocks2 = fittedBlocks[nameToIndex[op2.name]] if nameToIndex.has_key(op2.name) else op2.blocks
# Manage temporary variables
if len(output) > 0:
if len(self.temps) == 0:
tempCounter += 1
self.temps.append(DB.MatrixInfo(self.tempBaseName + str(tempCounter)))
result = self.temps.pop()
resultRequiredReals = result.requiredReals
spp1 = implementationPatterns[nameToIndex[op1.name]] if nameToIndex.has_key(op1.name) else op1.spp
spp2 = implementationPatterns[nameToIndex[op2.name]] if nameToIndex.has_key(op2.name) else op2.spp
result = DB.MatrixInfo(result.name, op1.rows, op2.cols, sparsityPattern = spp1 * spp2)
result.fitBlocksToSparsityPattern()
result.generateMemoryLayout(self.arch, alignStartrow=True)
resultName = result.name
operands.append(result)
beta = 0
result.requiredReals = max(resultRequiredReals, result.requiredReals)
else:
beta = 1
result = self.kernel
resultName = self.resultName
ops = []
writes = []
# op1 and op2 may be partitioned in several blocks.
# Here we split the blocks of op1 and op2 in sums, i.e.
# op1 * op2 = (op11 + op12 + ... + op1m) * (op21 + op22 + ... + op2n)
# = op11*op21 + op11*op22 + ...
# opij is a matrix where the j-th block of opi is nonzero.
# E.g. the first block of op1 (e.g. a 3x3 matrix) is given by (1, 2, 0, 2), then
# 0 0 0
# op11 = x x 0
# 0 0 0
for i1, block1 in enumerate(blocks1):
for i2, block2 in enumerate(blocks2):
# op1k * op2l is only nonzero if the columns of op1k and
# the rows of op2l intersect.
self.__gemm(op1.name, block1, op1.blocks[i1], op2.name, block2, op2.blocks[i2], resultName, result.blocks[0], beta, ops, writes)
# Reorder ops in order to find betas
if len(writes) > 0 and beta == 0:
targetCard = result.blocks[0].ld * result.blocks[0].cols()
mdsIn = MDS.maxDisjointSet(writes, targetCard)
mdsOut = list( set(range(len(writes))).difference(set(mdsIn)) )
order = mdsIn + mdsOut
memsetInterval = set(range(targetCard))
for m in mdsIn:
memsetInterval.difference_update(set( [i + j*result.blocks[0].ld for j in range(writes[m].startcol, writes[m].stopcol) for i in range(writes[m].startrow, writes[m].stoprow)] ))
memsetInterval = list(memsetInterval)
ranges = []
for key, group in itertools.groupby(enumerate(memsetInterval), lambda(index, value): value - index):
group = map(operator.itemgetter(1), group)
#.........这里部分代码省略.........