当前位置: 首页>>代码示例>>Golang>>正文


Golang blas64.Gemm函数代码示例

本文整理汇总了Golang中github.com/gonum/blas/blas64.Gemm函数的典型用法代码示例。如果您正苦于以下问题:Golang Gemm函数的具体用法?Golang Gemm怎么用?Golang Gemm使用的例子?那么恭喜您, 这里精选的函数代码示例或许可以为您提供帮助。


在下文中一共展示了Gemm函数的15个代码示例,这些例子默认根据受欢迎程度排序。您可以为喜欢或者感觉有用的代码点赞,您的评价将有助于系统推荐出更棒的Golang代码示例。

示例1: constructH

// constructH explicitly forms the m×m product of the k elementary reflectors
// H_i = I - tau[i] * v_i * v_iᵀ whose vectors are stored in v.
//
// With store == lapack.ColumnWise the vectors are the columns of v (v is m×k);
// otherwise they are the rows of v (v is k×m). With direct == lapack.Forward
// the result is H = H_0 * H_1 * ... * H_{k-1}; otherwise it is
// H = H_{k-1} * ... * H_1 * H_0.
func constructH(tau []float64, v blas64.General, store lapack.StoreV, direct lapack.Direct) blas64.General {
	m := v.Rows
	k := v.Cols
	if store == lapack.RowWise {
		m, k = k, m
	}
	square := func() blas64.General {
		return blas64.General{
			Rows:   m,
			Cols:   m,
			Stride: m,
			Data:   make([]float64, m*m),
		}
	}
	setIdentity := func(a blas64.General) {
		for j := range a.Data {
			a.Data[j] = 0
		}
		for j := 0; j < m; j++ {
			a.Data[j*m+j] = 1
		}
	}

	h := square()
	setIdentity(h)

	// Scratch storage reused across iterations: the current reflector, its
	// vector, and a copy of h so that Gemm never writes into one of its inputs.
	hi := square()
	hcopy := square()
	vec := blas64.Vector{Inc: 1, Data: make([]float64, m)}

	for i := 0; i < k; i++ {
		// Extract the i-th reflector vector. Elements are addressed through
		// v.Stride (not v.Cols) so matrices with padded rows are read correctly.
		if store == lapack.ColumnWise {
			for j := 0; j < m; j++ {
				vec.Data[j] = v.Data[j*v.Stride+i]
			}
		} else {
			for j := 0; j < m; j++ {
				vec.Data[j] = v.Data[i*v.Stride+j]
			}
		}

		// hi = I - tau[i] * v * vᵀ
		setIdentity(hi)
		blas64.Ger(-tau[i], vec, vec, hi)

		copy(hcopy.Data, h.Data)
		if direct == lapack.Forward {
			// H = H * H_i in forward mode.
			blas64.Gemm(blas.NoTrans, blas.NoTrans, 1, hcopy, hi, 0, h)
		} else {
			// H = H_i * H in backward mode.
			blas64.Gemm(blas.NoTrans, blas.NoTrans, 1, hi, hcopy, 0, h)
		}
	}
	return h
}
开发者ID:jacobxk,项目名称:lapack,代码行数:60,代码来源:general.go

示例2: convolveR

// convolveR writes into out the directional derivative of the convolution
// along the direction v, evaluated at input in, where inR is the input's
// direction.
//
// By the product rule the result is inR * Fᵀ. If v has a direction FR for the
// filters, in * FRᵀ is added. If v has a direction for the biases, that
// direction is added to every output row.
func (c *ConvLayer) convolveR(v autofunc.RVector, in, inR linalg.Vector, out *Tensor3) {
	inMat := c.inputToMatrix(in)
	inMatR := c.inputToMatrix(inR)

	// Filter weights are stored packed, one filter per row, with each row as
	// long as a row of the unrolled input. Their stride is therefore
	// inMat.Cols; inMat.Stride describes the input buffer, not the filters.
	filterMat := blas64.General{
		Rows:   c.FilterCount,
		Cols:   inMat.Cols,
		Stride: inMat.Cols,
		Data:   c.FilterVar.Vector,
	}
	outMat := blas64.General{
		Rows:   out.Width * out.Height,
		Cols:   out.Depth,
		Stride: out.Depth,
		Data:   out.Data,
	}

	// First term: inR * Fᵀ (overwrites out).
	blas64.Gemm(blas.NoTrans, blas.Trans, 1, inMatR, filterMat, 0, outMat)

	if filterRV, ok := v[c.FilterVar]; ok {
		// Second term: in * FRᵀ, accumulated into out.
		filterMatR := blas64.General{
			Rows:   c.FilterCount,
			Cols:   inMat.Cols,
			Stride: inMat.Cols,
			Data:   filterRV,
		}
		blas64.Gemm(blas.NoTrans, blas.Trans, 1, inMat, filterMatR, 1, outMat)
	}

	if biasRV, ok := v[c.Biases]; ok {
		// Add the bias direction to every output row.
		biasVec := blas64.Vector{Inc: 1, Data: biasRV}
		for i := 0; i < len(out.Data); i += outMat.Cols {
			outRow := out.Data[i : i+outMat.Cols]
			outVec := blas64.Vector{Inc: 1, Data: outRow}
			blas64.Axpy(len(outRow), 1, biasVec, outVec)
		}
	}
}
开发者ID:unixpickle,项目名称:weakai,代码行数:35,代码来源:conv_layer.go

示例3: propagateSingle

// propagateSingle back-propagates upstream, the gradient with respect to the
// layer's output for a single input, through the convolution.
//
// When downstream is non-nil, the gradient with respect to input is written
// into it. When grad has an entry for the filter variable, the filter gradient
// is accumulated into that entry.
func (c *convLayerResult) propagateSingle(input, upstream, downstream linalg.Vector,
	grad autofunc.Gradient) {
	// Each filter occupies one packed row of filterLen weights.
	filterLen := c.Layer.FilterWidth * c.Layer.FilterHeight * c.Layer.InputDepth
	filterShaped := func(data []float64) blas64.General {
		return blas64.General{
			Rows:   len(c.Layer.Filters),
			Cols:   filterLen,
			Stride: filterLen,
			Data:   data,
		}
	}

	upstreamMat := blas64.General{
		Rows:   c.Layer.OutputWidth() * c.Layer.OutputHeight(),
		Cols:   c.Layer.OutputDepth(),
		Stride: c.Layer.OutputDepth(),
		Data:   upstream,
	}

	if downstream != nil {
		// Unrolled-input gradient = upstream * F. The inputToMatrix result is
		// only used as an output buffer here (Gemm is called with beta = 0).
		inDeriv := c.Layer.inputToMatrix(input)
		blas64.Gemm(blas.NoTrans, blas.NoTrans, 1, upstreamMat,
			filterShaped(c.Layer.FilterVar.Vector), 0, inDeriv)
		// NOTE(review): NewTensor3Col presumably folds the unrolled patches back
		// into the input layout — confirm against its definition.
		flattened := NewTensor3Col(c.Layer.InputWidth, c.Layer.InputHeight,
			c.Layer.InputDepth, inDeriv.Data, c.Layer.FilterWidth,
			c.Layer.FilterHeight, c.Layer.Stride)
		copy(downstream, flattened.Data)
	}

	filterGrad, ok := grad[c.Layer.FilterVar]
	if !ok {
		return
	}
	// Filter gradient += upstreamᵀ * unrolled input.
	inMatrix := c.Layer.inputToMatrix(input)
	blas64.Gemm(blas.Trans, blas.NoTrans, 1, upstreamMat, inMatrix, 1, filterShaped(filterGrad))
}
开发者ID:unixpickle,项目名称:weakai,代码行数:35,代码来源:conv_layer.go

示例4: testDlarfx

// testDlarfx checks impl.Dlarfx for one configuration. It compares the result
// with an explicit product of H = I - tau * v * vᵀ and C (H * C for the left
// side, C * H for the right). It also checks that nothing outside the m×n
// region of C was written.
func testDlarfx(t *testing.T, impl Dlarfxer, side blas.Side, m, n, extra int, rnd *rand.Rand) {
	const tol = 1e-13

	c := randomGeneral(m, n, n+extra, rnd)
	cWant := randomGeneral(m, n, n+extra, rnd)
	tau := rnd.NormFloat64()

	// H acts on the rows of C from the left and on its columns from the
	// right, so its order is m or n respectively.
	order := n
	if side == blas.Left {
		order = m
	}
	v := randomSlice(order, rnd)
	h := eye(order, order+extra)
	vec := blas64.Vector{Inc: 1, Data: v}
	blas64.Ger(-tau, vec, vec, h)

	if side != blas.Left {
		blas64.Gemm(blas.NoTrans, blas.NoTrans, 1, c, h, 0, cWant)
	} else {
		blas64.Gemm(blas.NoTrans, blas.NoTrans, 1, h, c, 0, cWant)
	}

	// Allocate work only when H has order greater than 10. Its length is the
	// dimension of C that H does not act on.
	var work []float64
	if order > 10 {
		workLen := m
		if side == blas.Left {
			workLen = n
		}
		work = make([]float64, workLen)
	}

	impl.Dlarfx(side, m, n, v, tau, c.Data, c.Stride, work)

	prefix := fmt.Sprintf("Case side=%v, m=%v, n=%v, extra=%v", side, m, n, extra)

	// Check any invalid modifications of c.
	if !generalOutsideAllNaN(c) {
		t.Errorf("%v: out-of-range write to C\n%v", prefix, c.Data)
	}
	if !equalApproxGeneral(c, cWant, tol) {
		t.Errorf("%v: unexpected C\n%v", prefix, c.Data)
	}
}
开发者ID:rawlingsj,项目名称:gofabric8,代码行数:48,代码来源:dlarfx.go

示例5: Mul

// Mul takes the matrix product of a and b, placing the result in the receiver.
//
// See the Muler interface for more information.
func (m *Dense) Mul(a, b Matrix) {
	ar, ac := a.Dims()
	br, bc := b.Dims()

	if ac != br {
		panic(ErrShape)
	}

	m.reuseAs(ar, bc)

	// When the receiver is also an operand, compute into a workspace and copy
	// it into m on return, so the operand is not overwritten while still read.
	var w *Dense
	if m != a && m != b {
		w = m
	} else {
		w = getWorkspace(ar, bc, false)
		defer func() {
			m.Copy(w)
			putWorkspace(w)
		}()
	}

	// Both operands expose their backing storage: use BLAS directly.
	if a, ok := a.(RawMatrixer); ok {
		if b, ok := b.(RawMatrixer); ok {
			amat, bmat := a.RawMatrix(), b.RawMatrix()
			blas64.Gemm(blas.NoTrans, blas.NoTrans, 1, amat, bmat, 0, w.Mat)
			return
		}
	}

	// Both operands support row/column extraction: each entry is the dot
	// product of a row of a with a column of b.
	if a, ok := a.(Vectorer); ok {
		if b, ok := b.(Vectorer); ok {
			row := make([]float64, ac)
			col := make([]float64, br)
			for r := 0; r < ar; r++ {
				// Row r of a is the same for every column of b, so extract it
				// once per output row rather than once per output element.
				aRow := blas64.Vector{Inc: 1, Data: a.Row(row, r)}
				dataTmp := w.Mat.Data[r*w.Mat.Stride : r*w.Mat.Stride+bc]
				for c := 0; c < bc; c++ {
					dataTmp[c] = blas64.Dot(ac,
						aRow,
						blas64.Vector{Inc: 1, Data: b.Col(col, c)},
					)
				}
			}
			return
		}
	}

	// Generic fallback through element access.
	row := make([]float64, ac)
	for r := 0; r < ar; r++ {
		for i := range row {
			row[i] = a.At(r, i)
		}
		for c := 0; c < bc; c++ {
			var v float64
			for i, e := range row {
				v += e * b.At(i, c)
			}
			w.Mat.Data[r*w.Mat.Stride+c] = v
		}
	}
}
开发者ID:drewlanenga,项目名称:matrix,代码行数:62,代码来源:dense_arithmetic.go

示例6: testDorghr

// testDorghr checks impl.Dorghr for one configuration. A random matrix is
// reduced with Dgehrd, a reference Q is built explicitly from the resulting
// reflectors, and the output of Dorghr is checked three ways. It must not
// write outside A, it must be orthonormal, and it must match the reference Q
// within tol.
func testDorghr(t *testing.T, impl Dorghrer, n, ilo, ihi, extra int, optwork bool, rnd *rand.Rand) {
	const tol = 1e-14

	// Construct the matrix A with elementary reflectors and scalar factors tau.
	a := randomGeneral(n, n, n+extra, rnd)
	var tau []float64
	if n > 1 {
		tau = nanSlice(n - 1)
	}
	work := nanSlice(max(1, n)) // Minimum work for Dgehrd.
	impl.Dgehrd(n, ilo, ihi, a.Data, a.Stride, tau, work, len(work))

	// Reference Q = H_ilo * ... * H_{ihi-1}, where reflector j has a unit
	// entry at j+1 and its remaining entries below that in column j of A.
	q := eye(n, n)
	qCopy := cloneGeneral(q)
	for j := ilo; j < ihi; j++ {
		vec := blas64.Vector{Inc: 1, Data: make([]float64, n)}
		vec.Data[j+1] = 1
		for i := j + 2; i <= ihi; i++ {
			vec.Data[i] = a.Data[i*a.Stride+j]
		}
		hj := eye(n, n)
		blas64.Ger(-tau[j], vec, vec, hj)
		copy(qCopy.Data, q.Data)
		blas64.Gemm(blas.NoTrans, blas.NoTrans, 1, qCopy, hj, 0, q)
	}

	switch {
	case optwork:
		// Workspace query first, then allocate the reported optimal size.
		work = nanSlice(1)
		impl.Dorghr(n, ilo, ihi, a.Data, a.Stride, tau, work, -1)
		work = nanSlice(int(work[0]))
	default:
		work = nanSlice(max(1, ihi-ilo))
	}
	impl.Dorghr(n, ilo, ihi, a.Data, a.Stride, tau, work, len(work))

	prefix := fmt.Sprintf("Case n=%v, ilo=%v, ihi=%v, extra=%v, optwork=%v", n, ilo, ihi, extra, optwork)
	if !generalOutsideAllNaN(a) {
		t.Errorf("%v: out-of-range write to A\n%v", prefix, a.Data)
	}
	if !isOrthonormal(a) {
		t.Errorf("%v: A is not orthogonal\n%v", prefix, a.Data)
	}
	for i := 0; i < n; i++ {
		aRow := a.Data[i*a.Stride : i*a.Stride+n]
		qRow := q.Data[i*q.Stride : i*q.Stride+n]
		for j, aij := range aRow {
			qij := qRow[j]
			if math.Abs(aij-qij) > tol {
				t.Errorf("%v: unexpected value of A[%v,%v]. want %v, got %v", prefix, i, j, qij, aij)
			}
		}
	}
}
开发者ID:rawlingsj,项目名称:gofabric8,代码行数:56,代码来源:dorghr.go

示例7: QFromQR

// QFromQR extracts the m×m orthonormal matrix Q from a QR decomposition.
//
// Q is formed explicitly as H_0 * H_1 * ... * H_{c-1}, where
// H_i = I - tau[i] * v_i * v_iᵀ. Each v_i has a unit entry at position i and
// takes its remaining entries from column i of qr.qr below the diagonal.
func (m *Dense) QFromQR(qr *QR) {
	r, c := qr.qr.Dims()
	m.reuseAs(r, r)

	// Set Q = I.
	for i := 0; i < r; i++ {
		row := m.mat.Data[i*m.mat.Stride : i*m.mat.Stride+r]
		for j := range row {
			row[j] = 0
		}
		row[i] = 1
	}

	// Construct Q from the elementary reflectors.
	h := blas64.General{
		Rows:   r,
		Cols:   r,
		Stride: r,
		Data:   make([]float64, r*r),
	}
	qCopy := getWorkspace(r, r, false)
	// Return the scratch matrix to the pool when done, as Mul does.
	defer putWorkspace(qCopy)
	v := blas64.Vector{
		Inc:  1,
		Data: make([]float64, r),
	}
	for i := 0; i < c; i++ {
		// Set h = I.
		for j := range h.Data {
			h.Data[j] = 0
		}
		for j := 0; j < r; j++ {
			h.Data[j*r+j] = 1
		}

		// Set the vector data as the elementary reflector.
		for j := 0; j < i; j++ {
			v.Data[j] = 0
		}
		v.Data[i] = 1
		for j := i + 1; j < r; j++ {
			v.Data[j] = qr.qr.mat.Data[j*qr.qr.mat.Stride+i]
		}

		// h = I - tau[i] * v * vᵀ, then Q = Q * h. qCopy keeps Gemm's
		// output from aliasing its input.
		blas64.Ger(-qr.tau[i], v, v, h)
		qCopy.Copy(m)
		blas64.Gemm(blas.NoTrans, blas.NoTrans,
			1, qCopy.mat, h,
			0, m.mat)
	}
}
开发者ID:yonglehou,项目名称:matrix,代码行数:54,代码来源:qr.go

示例8: QFromLQ

// QFromLQ extracts the n×n orthonormal matrix Q from an LQ decomposition.
//
// Q is formed explicitly as H_{r-1} * ... * H_1 * H_0, where
// H_i = I - tau[i] * v_i * v_iᵀ. Each v_i has a unit entry at position i and
// takes its remaining entries from row i of lq.lq to the right of the diagonal.
func (m *Dense) QFromLQ(lq *LQ) {
	r, c := lq.lq.Dims()
	m.reuseAs(c, c)

	// Set Q = I.
	for i := 0; i < c; i++ {
		row := m.mat.Data[i*m.mat.Stride : i*m.mat.Stride+c]
		for j := range row {
			row[j] = 0
		}
		row[i] = 1
	}

	// Construct Q from the elementary reflectors.
	h := blas64.General{
		Rows:   c,
		Cols:   c,
		Stride: c,
		Data:   make([]float64, c*c),
	}
	qCopy := getWorkspace(c, c, false)
	// Return the scratch matrix to the pool when done, as Mul does.
	defer putWorkspace(qCopy)
	v := blas64.Vector{
		Inc:  1,
		Data: make([]float64, c),
	}
	for i := 0; i < r; i++ {
		// Set h = I.
		for j := range h.Data {
			h.Data[j] = 0
		}
		for j := 0; j < c; j++ {
			h.Data[j*c+j] = 1
		}

		// Set the vector data as the elementary reflector.
		for j := 0; j < i; j++ {
			v.Data[j] = 0
		}
		v.Data[i] = 1
		for j := i + 1; j < c; j++ {
			v.Data[j] = lq.lq.mat.Data[i*lq.lq.mat.Stride+j]
		}

		// h = I - tau[i] * v * vᵀ, then Q = h * Q. qCopy keeps Gemm's
		// output from aliasing its input.
		blas64.Ger(-lq.tau[i], v, v, h)
		qCopy.Copy(m)
		blas64.Gemm(blas.NoTrans, blas.NoTrans,
			1, h, qCopy.mat,
			0, m.mat)
	}
}
开发者ID:yonglehou,项目名称:matrix,代码行数:54,代码来源:lq.go

示例9: Eval

// Eval returns a matrix literal holding the product Left * Right.
//
// Both operands are evaluated to literals, converted to general dense
// matrices and multiplied with Gemm into a freshly allocated r×c result.
func (m1 *Mul) Eval() MatrixLiteral {
	// A type switch on the operand literal kinds could replace the
	// unconditional conversion to general matrices below.
	lhs, rhs := m1.Left.Eval(), m1.Right.Eval()
	left, right := lhs.AsGeneral(), rhs.AsGeneral()

	rows, cols := m1.Dims()
	product := blas64.General{
		Rows:   rows,
		Cols:   cols,
		Stride: cols,
		Data:   make([]float64, rows*cols),
	}
	blas64.Gemm(blas.NoTrans, blas.NoTrans, 1, left, right, 0, product)
	return &General{product}
}
开发者ID:jonlawlor,项目名称:matrixexp,代码行数:21,代码来源:mul.go

示例10: convolve

// convolve applies the layer's filters to in and adds the biases, writing the
// result into out.
//
// inputToMatrix unrolls the input into one row per output position, so the
// convolution is the single product inMat * Fᵀ (F holds one filter per row).
// The bias vector is then added to every output row.
func (c *ConvLayer) convolve(in linalg.Vector, out *Tensor3) {
	inMat := c.inputToMatrix(in)

	// Filter weights are stored packed, one filter per row, with each row as
	// long as a row of the unrolled input. Their stride is therefore
	// inMat.Cols; inMat.Stride describes the input buffer, not the filters.
	filterMat := blas64.General{
		Rows:   c.FilterCount,
		Cols:   inMat.Cols,
		Stride: inMat.Cols,
		Data:   c.FilterVar.Vector,
	}
	outMat := blas64.General{
		Rows:   out.Width * out.Height,
		Cols:   out.Depth,
		Stride: out.Depth,
		Data:   out.Data,
	}
	blas64.Gemm(blas.NoTrans, blas.Trans, 1, inMat, filterMat, 0, outMat)

	biasVec := blas64.Vector{Inc: 1, Data: c.Biases.Vector}
	for i := 0; i < len(out.Data); i += outMat.Cols {
		outRow := out.Data[i : i+outMat.Cols]
		outVec := blas64.Vector{Inc: 1, Data: outRow}
		blas64.Axpy(len(outRow), 1, biasVec, outVec)
	}
}
开发者ID:unixpickle,项目名称:weakai,代码行数:23,代码来源:conv_layer.go

示例11: dlatrdCheckDecomposition

// dlatrdCheckDecomposition forms Qᵀ * aGen * Q and checks nb of its rows
// against the reduced matrix a (leading dimension lda) and the off-diagonal
// values e. For blas.Upper the last nb rows are checked; otherwise the first nb.
//
// In each checked row:
//   - the diagonal entry must match a;
//   - the superdiagonal (Upper) or subdiagonal (Lower) entry must equal e, and
//     a must hold 1 at that position;
//   - the opposite neighbouring diagonal is not checked;
//   - every other entry must be zero.
//
// All comparisons use an absolute tolerance of 1e-10. t and tau are not used.
func dlatrdCheckDecomposition(t *testing.T, uplo blas.Uplo, n, nb int, e, tau, a []float64, lda int, aGen, q blas64.General) bool {
	const tol = 1e-10
	// far reports a difference larger than tol. It is phrased as "> tol" so
	// that NaN differences are not treated as failures.
	far := func(x, y float64) bool { return math.Abs(x-y) > tol }
	square := func() blas64.General {
		return blas64.General{
			Rows:   n,
			Cols:   n,
			Stride: n,
			Data:   make([]float64, n*n),
		}
	}

	// ans = (Qᵀ * A) * Q.
	tmp := square()
	ans := square()
	blas64.Gemm(blas.Trans, blas.NoTrans, 1, q, aGen, 0, tmp)
	blas64.Gemm(blas.NoTrans, blas.NoTrans, 1, tmp, q, 0, ans)

	if uplo == blas.Upper {
		for i := n - 1; i >= n-nb; i-- {
			for j := 0; j < n; j++ {
				got := ans.Data[i*ans.Stride+j]
				if i == j {
					if far(got, a[i*lda+j]) {
						return false
					}
				} else if j == i+1 {
					if far(a[i*lda+j], 1) || far(got, e[i]) {
						return false
					}
				} else if j != i-1 && far(got, 0) {
					return false
				}
			}
		}
		return true
	}

	for i := 0; i < nb; i++ {
		for j := 0; j < n; j++ {
			got := ans.Data[i*ans.Stride+j]
			if i == j {
				if far(got, a[i*lda+j]) {
					return false
				}
			} else if j == i-1 {
				if far(a[i*lda+j], 1) || far(got, e[i-1]) {
					return false
				}
			} else if j != i+1 && far(got, 0) {
				return false
			}
		}
	}
	return true
}
开发者ID:rawlingsj,项目名称:gofabric8,代码行数:73,代码来源:dlatrd.go

示例12: DlarfbTest


//.........这里部分代码省略.........
						a := make([]float64, ma*lda)
						for i := 0; i < ma; i++ {
							for j := 0; j < lda; j++ {
								a[i*lda+j] = rnd.Float64()
							}
						}
						k := min(ma, na)

						// H is always ma x ma
						var m, n, rowsWork int
						switch {
						default:
							panic("not implemented")
						case side == blas.Left:
							m = test.ma
							n = test.cdim
							rowsWork = n
						case side == blas.Right:
							m = test.cdim
							n = test.ma
							rowsWork = m
						}

						// Use dgeqr2 to find the v vectors
						tau := make([]float64, na)
						work := make([]float64, na)
						impl.Dgeqr2(ma, k, a, lda, tau, work)

						// Correct the v vectors based on the direct and store
						vMatTmp := extractVMat(ma, na, a, lda, lapack.Forward, lapack.ColumnWise)
						vMat := constructVMat(vMatTmp, store, direct)
						v := vMat.Data
						ldv := vMat.Stride

						// Use dlarft to find the t vector
						ldt := test.ldt
						if ldt == 0 {
							ldt = k
						}
						tm := make([]float64, k*ldt)

						impl.Dlarft(direct, store, ma, k, v, ldv, tau, tm, ldt)

						// Generate c matrix
						ldc := test.ldc
						if ldc == 0 {
							ldc = n
						}
						c := make([]float64, m*ldc)
						for i := 0; i < m; i++ {
							for j := 0; j < ldc; j++ {
								c[i*ldc+j] = rnd.Float64()
							}
						}
						cCopy := make([]float64, len(c))
						copy(cCopy, c)

						ldwork := k
						work = make([]float64, rowsWork*k)

						// Call Dlarfb with this information
						impl.Dlarfb(side, trans, direct, store, m, n, k, v, ldv, tm, ldt, c, ldc, work, ldwork)

						h := constructH(tau, vMat, store, direct)

						cMat := blas64.General{
							Rows:   m,
							Cols:   n,
							Stride: ldc,
							Data:   make([]float64, m*ldc),
						}
						copy(cMat.Data, cCopy)
						ans := blas64.General{
							Rows:   m,
							Cols:   n,
							Stride: ldc,
							Data:   make([]float64, m*ldc),
						}
						copy(ans.Data, cMat.Data)
						switch {
						default:
							panic("not implemented")
						case side == blas.Left && trans == blas.NoTrans:
							blas64.Gemm(blas.NoTrans, blas.NoTrans, 1, h, cMat, 0, ans)
						case side == blas.Left && trans == blas.Trans:
							blas64.Gemm(blas.Trans, blas.NoTrans, 1, h, cMat, 0, ans)
						case side == blas.Right && trans == blas.NoTrans:
							blas64.Gemm(blas.NoTrans, blas.NoTrans, 1, cMat, h, 0, ans)
						case side == blas.Right && trans == blas.Trans:
							blas64.Gemm(blas.NoTrans, blas.Trans, 1, cMat, h, 0, ans)
						}
						if !floats.EqualApprox(ans.Data, c, 1e-14) {
							t.Errorf("Cas %v mismatch. Want %v, got %v.", cas, ans.Data, c)
						}
					}
				}
			}
		}
	}
}
开发者ID:rawlingsj,项目名称:gofabric8,代码行数:101,代码来源:dlarfb.go

示例13: DlarfTest


//.........这里部分代码省略.........

			lastr: 0,
			lastc: 1,

			tau: 2,
		},
		{
			m:   10,
			n:   10,
			ldc: 10,

			incv:  4,
			lastv: 6,

			lastr: 9,
			lastc: 8,

			tau: 2,
		},
	} {
		// Construct a random matrix.
		c := make([]float64, test.ldc*test.m)
		for i := 0; i <= test.lastr; i++ {
			for j := 0; j <= test.lastc; j++ {
				c[i*test.ldc+j] = rand.Float64()
			}
		}
		cCopy := make([]float64, len(c))
		copy(cCopy, c)
		cCopy2 := make([]float64, len(c))
		copy(cCopy2, c)

		// Test with side right.
		sz := max(test.m, test.n) // so v works for both right and left side.
		v := make([]float64, test.incv*sz+1)
		// Fill with nonzero entries up until lastv.
		for i := 0; i <= test.lastv; i++ {
			v[i*test.incv] = rand.Float64()
		}
		// Construct h explicitly to compare.
		h := make([]float64, test.n*test.n)
		for i := 0; i < test.n; i++ {
			h[i*test.n+i] = 1
		}
		hMat := blas64.General{
			Rows:   test.n,
			Cols:   test.n,
			Stride: test.n,
			Data:   h,
		}
		vVec := blas64.Vector{
			Inc:  test.incv,
			Data: v,
		}
		blas64.Ger(-test.tau, vVec, vVec, hMat)

		// Apply multiplication (2nd copy is to avoid aliasing).
		cMat := blas64.General{
			Rows:   test.m,
			Cols:   test.n,
			Stride: test.ldc,
			Data:   cCopy,
		}
		cMat2 := blas64.General{
			Rows:   test.m,
			Cols:   test.n,
			Stride: test.ldc,
			Data:   cCopy2,
		}
		blas64.Gemm(blas.NoTrans, blas.NoTrans, 1, cMat2, hMat, 0, cMat)

		// cMat now stores the true answer. Compare with the function call.
		work := make([]float64, sz)
		impl.Dlarf(blas.Right, test.m, test.n, v, test.incv, test.tau, c, test.ldc, work)
		if !floats.EqualApprox(c, cMat.Data, 1e-14) {
			t.Errorf("Dlarf mismatch right, case %v. Want %v, got %v", i, cMat.Data, c)
		}

		// Test on the left side.
		copy(c, cCopy2)
		copy(cCopy, c)
		// Construct h.
		h = make([]float64, test.m*test.m)
		for i := 0; i < test.m; i++ {
			h[i*test.m+i] = 1
		}
		hMat = blas64.General{
			Rows:   test.m,
			Cols:   test.m,
			Stride: test.m,
			Data:   h,
		}
		blas64.Ger(-test.tau, vVec, vVec, hMat)
		blas64.Gemm(blas.NoTrans, blas.NoTrans, 1, hMat, cMat2, 0, cMat)
		impl.Dlarf(blas.Left, test.m, test.n, v, test.incv, test.tau, c, test.ldc, work)
		if !floats.EqualApprox(c, cMat.Data, 1e-14) {
			t.Errorf("Dlarf mismatch left, case %v. Want %v, got %v", i, cMat.Data, c)
		}
	}
}
开发者ID:RomainVabre,项目名称:origin,代码行数:101,代码来源:dlarf.go

示例14: checkPLU

// checkPLU checks that the PLU factorization contained in factorized matches
// the original matrix contained in original.
//
// factorized is m×n with leading dimension lda. Its strictly lower part holds L
// (with an implied unit diagonal) and its upper part, diagonal included, holds
// U. Row i was interchanged with row ipiv[i]. ok is the status reported by the
// factorization; it must be false exactly when U has a zero on its diagonal.
// Entries are compared with absolute tolerance tol. When print is true the
// mismatch message includes both matrices.
func checkPLU(t *testing.T, ok bool, m, n, lda int, ipiv []int, factorized, original []float64, tol float64, print bool) {
	mn := min(m, n)

	var hasZeroDiagonal bool
	for i := 0; i < mn; i++ {
		if factorized[i*lda+i] == 0 {
			hasZeroDiagonal = true
			break
		}
	}
	if hasZeroDiagonal && ok {
		t.Error("Has a zero diagonal but returned ok")
	}
	if !hasZeroDiagonal && !ok {
		t.Error("Non-zero diagonal but returned !ok")
	}

	// Split factorized into unit lower-trapezoidal L (m×mn) and
	// upper-trapezoidal U (mn×n).
	l := make([]float64, m*mn)
	ldl := mn
	u := make([]float64, mn*n)
	ldu := n
	for i := 0; i < m; i++ {
		for j := 0; j < n; j++ {
			v := factorized[i*lda+j]
			switch {
			case i == j:
				l[i*ldl+i] = 1
				u[i*ldu+i] = v
			case i > j:
				l[i*ldl+j] = v
			case i < j:
				u[i*ldu+j] = v
			}
		}
	}

	// LU = L * U.
	LU := blas64.General{
		Rows:   m,
		Cols:   n,
		Stride: n,
		Data:   make([]float64, m*n),
	}
	U := blas64.General{
		Rows:   mn,
		Cols:   n,
		Stride: ldu,
		Data:   u,
	}
	L := blas64.General{
		Rows:   m,
		Cols:   mn,
		Stride: ldl,
		Data:   l,
	}
	blas64.Gemm(blas.NoTrans, blas.NoTrans, 1, L, U, 0, LU)

	// Build P by applying the recorded row interchanges to the identity in
	// reverse order.
	p := make([]float64, m*m)
	ldp := m
	for i := 0; i < m; i++ {
		p[i*ldp+i] = 1
	}
	for i := len(ipiv) - 1; i >= 0; i-- {
		v := ipiv[i]
		// Keyed fields: go vet's composites check rejects unkeyed literals of
		// struct types declared in other packages.
		blas64.Swap(m, blas64.Vector{Inc: 1, Data: p[i*ldp:]}, blas64.Vector{Inc: 1, Data: p[v*ldp:]})
	}
	P := blas64.General{
		Rows:   m,
		Cols:   m,
		Stride: m,
		Data:   p,
	}

	// aComp = P * LU. Entries past column n in each row keep the values
	// copied from factorized and are compared as well.
	aComp := blas64.General{
		Rows:   m,
		Cols:   n,
		Stride: lda,
		Data:   make([]float64, m*lda),
	}
	copy(aComp.Data, factorized)
	blas64.Gemm(blas.NoTrans, blas.NoTrans, 1, P, LU, 0, aComp)
	if !floats.EqualApprox(aComp.Data, original, tol) {
		if print {
			t.Errorf("PLU multiplication does not match original matrix.\nWant: %v\nGot: %v", original, aComp.Data)
			return
		}
		t.Error("PLU multiplication does not match original matrix.")
	}
}
开发者ID:jacobxk,项目名称:lapack,代码行数:89,代码来源:dgetf2.go

示例15: Dgelq2Test

// Dgelq2Test checks impl.Dgelq2 on a range of shapes and leading dimensions
// (lda == 0 means lda = n). For each case it verifies that the Q built from the
// computed reflectors has orthonormal rows and that L * Q reproduces the
// original A.
func Dgelq2Test(t *testing.T, impl Dgelq2er) {
	for c, test := range []struct {
		m, n, lda int
	}{
		{1, 1, 0},
		{2, 2, 0},
		{3, 2, 0},
		{2, 3, 0},
		{1, 12, 0},
		{2, 6, 0},
		{3, 4, 0},
		{4, 3, 0},
		{6, 2, 0},
		{1, 12, 0},
		{1, 1, 20},
		{2, 2, 20},
		{3, 2, 20},
		{2, 3, 20},
		{1, 12, 20},
		{2, 6, 20},
		{3, 4, 20},
		{4, 3, 20},
		{6, 2, 20},
		{1, 12, 20},
	} {
		n := test.n
		m := test.m
		lda := test.lda
		if lda == 0 {
			lda = test.n
		}
		k := min(m, n)
		tau := make([]float64, k)
		for i := range tau {
			tau[i] = rand.Float64()
		}
		work := make([]float64, m)
		for i := range work {
			work[i] = rand.Float64()
		}
		a := make([]float64, m*lda)
		for i := 0; i < m*lda; i++ {
			a[i] = rand.Float64()
		}
		aCopy := make([]float64, len(a))
		copy(aCopy, a)
		impl.Dgelq2(m, n, a, lda, tau, work)

		Q := constructQ("LQ", m, n, a, lda, tau)

		// Check that the rows of Q are orthonormal. A row has Q.Cols entries,
		// so both the norm and the dot products run over Q.Cols elements.
		for i := 0; i < Q.Rows; i++ {
			nrm := blas64.Nrm2(Q.Cols, blas64.Vector{Inc: 1, Data: Q.Data[i*Q.Stride:]})
			if math.Abs(nrm-1) > 1e-14 {
				t.Errorf("Q not normal. Norm is %v", nrm)
			}
			for j := 0; j < i; j++ {
				dot := blas64.Dot(Q.Cols,
					blas64.Vector{Inc: 1, Data: Q.Data[i*Q.Stride:]},
					blas64.Vector{Inc: 1, Data: Q.Data[j*Q.Stride:]},
				)
				if math.Abs(dot) > 1e-14 {
					t.Errorf("Q not orthogonal. Dot is %v", dot)
				}
			}
		}

		// L is the lower-trapezoidal part of the factored a.
		L := blas64.General{
			Rows:   m,
			Cols:   n,
			Stride: n,
			Data:   make([]float64, m*n),
		}
		for i := 0; i < m; i++ {
			for j := 0; j <= min(i, n-1); j++ {
				L.Data[i*L.Stride+j] = a[i*lda+j]
			}
		}

		// ans = L * Q. Padding past column n keeps the original values.
		ans := blas64.General{
			Rows:   m,
			Cols:   n,
			Stride: lda,
			Data:   make([]float64, m*lda),
		}
		copy(ans.Data, aCopy)
		blas64.Gemm(blas.NoTrans, blas.NoTrans, 1, L, Q, 0, ans)
		if !floats.EqualApprox(aCopy, ans.Data, 1e-14) {
			t.Errorf("Case %v, LQ mismatch. Want %v, got %v.", c, aCopy, ans.Data)
		}
	}
}
开发者ID:RomainVabre,项目名称:origin,代码行数:92,代码来源:dgelq2.go


注:本文中的github.com/gonum/blas/blas64.Gemm函数示例由纯净天空整理自Github/MSDocs等开源代码及文档管理平台,相关代码片段筛选自各路编程大神贡献的开源项目,源码版权归原作者所有,传播和使用请参考对应项目的License;未经允许,请勿转载。