本文整理汇总了Java中org.nd4j.linalg.api.buffer.DataBuffer类的典型用法代码示例。如果您正苦于以下问题：Java DataBuffer类的具体用法是什么？Java DataBuffer怎么用？有哪些Java DataBuffer的使用示例？那么恭喜您，这里精选的类代码示例或许可以为您提供帮助。
DataBuffer类属于org.nd4j.linalg.api.buffer包，下文中一共展示了DataBuffer类的15个代码示例，这些示例默认按受欢迎程度排序。您可以为喜欢或者觉得有用的代码点赞，您的评价将有助于系统推荐出更优质的Java代码示例。
示例1: copy
import org.nd4j.linalg.api.buffer.DataBuffer; //导入依赖的package包/类
/**
 * Copies the elements of {@code x} into {@code y} (BLAS-style {@code y <- x}).
 *
 * <p>Both arrays must share a data type. DOUBLE buffers go through the double overload of
 * {@code JavaBlas.rcopy}; every other data type goes through the float overload.
 *
 * <p>NOTE(review): the copy only reaches {@code y} if {@code asFloat()}/{@code asDouble()}
 * expose the buffer's backing array rather than a copy — confirm for each buffer type.
 *
 * @param x the source array
 * @param y the destination array
 * @return {@code y}
 */
@Override
public INDArray copy(INDArray x, INDArray y) {
    DataTypeValidation.assertSameDataType(x, y);
    boolean doublePrecision = x.data().dataType().equals(DataBuffer.DOUBLE);
    if (!doublePrecision) {
        JavaBlas.rcopy(x.length(),
                x.data().asFloat(), x.offset(), x.secondaryStride(),
                y.data().asFloat(), y.offset(), y.secondaryStride());
    } else {
        JavaBlas.rcopy(x.length(),
                x.data().asDouble(), x.offset(), x.secondaryStride(),
                y.data().asDouble(), y.offset(), y.secondaryStride());
    }
    return y;
}
示例2: applyTransformToOrigin
import org.nd4j.linalg.api.buffer.DataBuffer; //导入依赖的package包/类
/**
 * Applies this element-wise transform to element {@code i} of {@code origin} and writes the
 * result back into {@code origin}.
 *
 * <p>Complex arrays receive the complex result as-is. For real arrays, a NaN or infinite
 * result is replaced by {@code Nd4j.EPS_THRESHOLD} before it is stored.
 *
 * @param origin the array to read the element from and write the result to
 * @param i      the index of the element to apply the transform to
 */
@Override
public void applyTransformToOrigin(INDArray origin, int i) {
    // Test for the array type that is actually cast below; testing for IComplexNumber (a
    // number type, not an array type) would route complex arrays into the real branch.
    if (origin instanceof IComplexNDArray) {
        IComplexNDArray complexOrigin = (IComplexNDArray) origin;
        IComplexNumber transformed = apply(origin, getFromOrigin(origin, i), i);
        complexOrigin.putScalar(i, transformed);
        return;
    }

    Number f = apply(origin, getFromOrigin(origin, i), i);
    double val = f.doubleValue();
    // Keep non-finite results out of the array.
    if (Double.isNaN(val) || Double.isInfinite(val)) {
        val = Nd4j.EPS_THRESHOLD;
    }
    // The same putScalar call is used for every data type, so no data-type branch is needed.
    origin.putScalar(i, val);
}
示例3: cumsum
import org.nd4j.linalg.api.buffer.DataBuffer; //导入依赖的package包/类
/**
 * Returns a function that replaces each element of its input with the running (cumulative)
 * sum of all elements up to and including it.
 *
 * <p>The input array is modified in place and then returned. The sum is accumulated in a
 * {@code double}, whatever the array's data type.
 *
 * @return an in-place cumulative-sum function
 */
public static Function<INDArray, INDArray> cumsum() {
    return new Function<INDArray, INDArray>() {
        @Override
        public INDArray apply(INDArray input) {
            double s = 0.0;
            for (int i = 0; i < input.length(); i++) {
                // Every data type is read through getDouble, so no per-element type check.
                s += input.getDouble(i);
                input.putScalar(i, s);
            }
            return input;
        }
    };
}
示例4: nrm2
import org.nd4j.linalg.api.buffer.DataBuffer; //导入依赖的package包/类
/**
 * Euclidean norm of a complex vector, computed with {@code NativeBlas.scnrm2} for FLOAT
 * buffers or {@code NativeBlas.dznrm2} for DOUBLE buffers.
 *
 * <p>NOTE(review): the raw {@code x.offset()} is passed here, while other complex BLAS
 * calls in this codebase scale it ({@code offset() / 2} or {@code blasOffset()}) — confirm
 * which convention NativeBlas expects for complex data.
 *
 * @param x the complex vector
 * @return the 2-norm of {@code x}
 * @throws IllegalStateException if the buffer is neither FLOAT nor DOUBLE
 */
@Override
public double nrm2(IComplexNDArray x) {
    DataBuffer buffer = x.data();
    if (buffer.dataType().equals(DataBuffer.FLOAT)) {
        return NativeBlas.scnrm2(x.length(), buffer.asFloat(), x.offset(), x.secondaryStride());
    }
    if (buffer.dataType().equals(DataBuffer.DOUBLE)) {
        return NativeBlas.dznrm2(x.length(), buffer.asDouble(), x.offset(), x.secondaryStride());
    }
    throw new IllegalStateException("Illegal data type");
}
示例5: dotc
import org.nd4j.linalg.api.buffer.DataBuffer; //导入依赖的package包/类
/**
 * Conjugated complex dot product (BLAS {@code ?dotc}: the sum of {@code conj(x_i) * y_i}).
 *
 * <p>FLOAT buffers use {@code NativeBlas.cdotc} and the result is wrapped in a
 * {@code ComplexFloat}. DOUBLE buffers use {@code NativeBlas.zdotc} and the result is
 * wrapped in a {@code ComplexDouble}.
 *
 * @param x the first complex vector
 * @param y the second complex vector; must have the same data type as {@code x}
 * @return the conjugated dot product
 * @throws IllegalStateException if the buffer is neither FLOAT nor DOUBLE
 */
@Override
public IComplexNumber dotc(IComplexNDArray x, IComplexNDArray y) {
    DataTypeValidation.assertSameDataType(x, y);
    if (x.data().dataType().equals(DataBuffer.FLOAT)) {
        return new ComplexFloat(NativeBlas.cdotc(x.length(),
                x.data().asFloat(), x.blasOffset(), x.secondaryStride(),
                y.data().asFloat(), y.blasOffset(), y.secondaryStride()));
    }
    if (x.data().dataType().equals(DataBuffer.DOUBLE)) {
        return new ComplexDouble(NativeBlas.zdotc(x.length(),
                x.data().asDouble(), x.blasOffset(), x.secondaryStride(),
                y.data().asDouble(), y.blasOffset(), y.secondaryStride()));
    }
    throw new IllegalStateException("Illegal data type");
}
示例6: swap
import org.nd4j.linalg.api.buffer.DataBuffer; //导入依赖的package包/类
/**
 * Exchanges the contents of {@code x} and {@code y} (BLAS-style {@code x <-> y}).
 *
 * <p>FLOAT buffers use the float overload of {@code JavaBlas.rswap}; every other data type
 * uses the double overload.
 *
 * <p>NOTE(review): the swap only reaches the arrays if {@code asFloat()}/{@code asDouble()}
 * expose the backing arrays rather than copies — confirm.
 *
 * @param x the first array
 * @param y the second array; must have the same data type as {@code x}
 * @return {@code y}
 */
@Override
public INDArray swap(INDArray x, INDArray y) {
    DataTypeValidation.assertSameDataType(x, y);
    boolean singlePrecision = x.data().dataType().equals(DataBuffer.FLOAT);
    if (singlePrecision) {
        JavaBlas.rswap(x.length(),
                x.data().asFloat(), x.offset(), x.secondaryStride(),
                y.data().asFloat(), y.offset(), y.secondaryStride());
    } else {
        JavaBlas.rswap(x.length(),
                x.data().asDouble(), x.offset(), x.secondaryStride(),
                y.data().asDouble(), y.offset(), y.secondaryStride());
    }
    return y;
}
示例7: iamax
import org.nd4j.linalg.api.buffer.DataBuffer; //导入依赖的package包/类
/**
 * Index of the element with the largest absolute value, computed with
 * {@code NativeBlas.isamax} (FLOAT) or {@code NativeBlas.idamax} (DOUBLE).
 *
 * <p>One is subtracted from the native result, presumably to convert BLAS's 1-based
 * (Fortran) index to a 0-based Java index.
 *
 * @param x the array to search
 * @return the index reported by BLAS, minus one
 * @throws IllegalStateException if the buffer is neither FLOAT nor DOUBLE
 */
@Override
public int iamax(INDArray x) {
    DataBuffer buffer = x.data();
    int blasIndex;
    if (buffer.dataType().equals(DataBuffer.FLOAT)) {
        blasIndex = NativeBlas.isamax(x.length(), buffer.asFloat(), x.offset(), x.secondaryStride());
    } else if (buffer.dataType().equals(DataBuffer.DOUBLE)) {
        blasIndex = NativeBlas.idamax(x.length(), buffer.asDouble(), x.offset(), x.secondaryStride());
    } else {
        throw new IllegalStateException("Illegal data type");
    }
    return blasIndex - 1;
}
示例8: asum
import org.nd4j.linalg.api.buffer.DataBuffer; //导入依赖的package包/类
/**
 * Sum of the absolute values of a complex vector, computed with {@code NativeBlas.scasum}
 * (FLOAT) or {@code NativeBlas.dzasum} (DOUBLE).
 *
 * <p>The array offset is halved before it is passed; presumably this converts from
 * real-element units to complex-element units — confirm against NativeBlas.
 *
 * @param x the complex vector
 * @return the BLAS {@code ?asum} result for {@code x}
 * @throws IllegalStateException if the buffer is neither FLOAT nor DOUBLE
 */
@Override
public double asum(IComplexNDArray x) {
    DataBuffer buffer = x.data();
    if (buffer.dataType().equals(DataBuffer.FLOAT)) {
        return NativeBlas.scasum(x.length(), buffer.asFloat(), x.offset() / 2, x.secondaryStride());
    }
    if (buffer.dataType().equals(DataBuffer.DOUBLE)) {
        return NativeBlas.dzasum(x.length(), buffer.asDouble(), x.offset() / 2, x.secondaryStride());
    }
    throw new IllegalStateException("Illegal data type");
}
示例9: posv
import org.nd4j.linalg.api.buffer.DataBuffer; //导入依赖的package包/类
/**
 * Solves {@code A * X = B} for a symmetric positive-definite {@code A} using LAPACK
 * {@code ?POSV}: {@code NativeBlas.sposv} for FLOAT buffers, {@code NativeBlas.dposv} for
 * DOUBLE buffers.
 *
 * <p>Per LAPACK, B is overwritten with the solution and A with its Cholesky factor.
 * NOTE(review): that only reaches these arrays if NativeBlas writes through to their
 * backing storage — confirm.
 *
 * @param uplo 'U' or 'L': which triangle of {@code A} is referenced (LAPACK convention)
 * @param A    the n-by-n coefficient matrix
 * @param B    the right-hand sides, one per column
 * @throws IllegalStateException     if the buffer is neither FLOAT nor DOUBLE
 * @throws LapackArgumentException   if LAPACK reports that {@code A} is not positive definite
 */
@Override
public void posv(char uplo, INDArray A, INDArray B) {
    int n = A.rows();
    int nrhs = B.columns();
    DataTypeValidation.assertSameDataType(A, B);

    String routine;
    int info;
    if (A.data().dataType().equals(DataBuffer.FLOAT)) {
        routine = "SPOSV";
        info = NativeBlas.sposv(uplo, n, nrhs,
                A.data().asFloat(), A.offset(), A.rows(),
                B.data().asFloat(), B.offset(), B.rows());
    } else if (A.data().dataType().equals(DataBuffer.DOUBLE)) {
        routine = "DPOSV";
        info = NativeBlas.dposv(uplo, n, nrhs,
                A.data().asDouble(), A.offset(), A.rows(),
                B.data().asDouble(), B.offset(), B.rows());
    } else {
        // Fail explicitly instead of handing a fake LAPACK status (-1) to checkInfo.
        throw new IllegalStateException("Illegal data type");
    }

    checkInfo(routine, info);
    if (info > 0) {
        // LAPACK's INFO > 0 is the order of the leading minor that is not positive definite.
        throw new LapackArgumentException(routine,
                "Leading minor of order " + info + " of A is not positive definite.");
    }
}
示例10: syev
import org.nd4j.linalg.api.buffer.DataBuffer; //导入依赖的package包/类
/**
 * Eigen-decomposition of a symmetric matrix using LAPACK {@code ?SYEV}:
 * {@code NativeBlas.ssyev} for FLOAT buffers, {@code NativeBlas.dsyev} for every other type.
 *
 * @param jobz LAPACK JOBZ flag (whether eigenvectors are computed)
 * @param uplo LAPACK UPLO flag (which triangle of {@code a} is referenced)
 * @param a    the symmetric matrix
 * @param w    receives the eigenvalues; must have the same data type as {@code a}
 * @return the LAPACK INFO status (never positive; a positive status raises instead)
 * @throws LapackConvergenceException if LAPACK reports a positive INFO (no convergence)
 */
@Override
public int syev(char jobz, char uplo, INDArray a, INDArray w) {
    DataTypeValidation.assertSameDataType(a, w);
    final int info = a.data().dataType().equals(DataBuffer.FLOAT)
            ? NativeBlas.ssyev(jobz, uplo, a.rows(),
                    a.data().asFloat(), a.offset(), a.rows(),
                    w.data().asFloat(), w.offset())
            : NativeBlas.dsyev(jobz, uplo, a.rows(),
                    a.data().asDouble(), a.offset(), a.rows(),
                    w.data().asDouble(), w.offset());
    if (info <= 0) {
        return info;
    }
    throw new LapackConvergenceException("SYEV",
            "Eigenvalues could not be computed " + info
                    + " off-diagonal elements did not converge");
}
示例11: alloc
import org.nd4j.linalg.api.buffer.DataBuffer; //导入依赖的package包/类
/**
 * Allocates device memory for {@code ndarray.length()} elements with
 * {@code JCublas.cublasAlloc} and uploads the array's elements with
 * {@code JCublas.cublasSetVector}.
 *
 * <p>The host source starts at {@code offset() * size(ndarray)} bytes into the array's
 * buffer. If the array covers its whole buffer, the host data is read with stride 1;
 * otherwise it is read with stride {@code majorStride()}. The device side is always
 * written contiguously.
 *
 * <p>NOTE(review): the cuBLAS status codes returned by cublasAlloc/cublasSetVector are
 * ignored here.
 *
 * @param ndarray the ndarray to allocate
 * @return the allocated device pointer
 */
public static Pointer alloc(JCublasNDArray ndarray) {
    final Pointer devicePointer = new Pointer();
    final boolean isFloat = ndarray.data().dataType().equals(DataBuffer.FLOAT);
    // Host pointer positioned at the first element of this (possibly offset) array.
    final Pointer hostData = isFloat
            ? Pointer.to(ndarray.data().asFloat()).withByteOffset(ndarray.offset() * size(ndarray))
            : Pointer.to(ndarray.data().asDouble()).withByteOffset(ndarray.offset() * size(ndarray));

    JCublas.cublasAlloc(ndarray.length(), size(ndarray), devicePointer);

    // Only step by majorStride() when the array does not span the whole buffer.
    final int hostStride = ndarray.length() == ndarray.data().length() ? 1 : ndarray.majorStride();
    JCublas.cublasSetVector(ndarray.length(), size(ndarray), hostData, hostStride, devicePointer, 1);
    return devicePointer;
}
示例12: copy
import org.nd4j.linalg.api.buffer.DataBuffer; //导入依赖的package包/类
/**
 * Copies the complex array {@code x} into {@code y} through cuBLAS.
 *
 * <p>Both arrays are uploaded to freshly allocated device buffers. The copy uses
 * {@code cublasScopy} for FLOAT buffers and {@code cublasDcopy} for every other type, and the
 * device result is read back into {@code y}'s buffer with {@code getData}. The device buffers
 * are released in a {@code finally} block so they are not leaked if the copy or the
 * read-back throws.
 *
 * <p>NOTE(review): {@code x.length()} elements are copied with the real-valued S/D copy
 * routines; if complex values occupy two real slots each, this may cover only part of the
 * storage — confirm against {@code alloc}/{@code size} for complex arrays.
 *
 * @param x the origin
 * @param y the destination
 */
public static void copy(IComplexNDArray x, IComplexNDArray y) {
    DataTypeValidation.assertSameDataType(x, y);
    JCublas.cublasInit();
    JCublasComplexNDArray xC = (JCublasComplexNDArray) x;
    JCublasComplexNDArray yC = (JCublasComplexNDArray) y;
    Pointer xCPointer = alloc(xC);
    Pointer yCPointer = alloc(yC);
    try {
        if (xC.data().dataType().equals(DataBuffer.FLOAT)) {
            JCublas.cublasScopy(x.length(), xCPointer, 1, yCPointer, 1);
            getData(yC, yCPointer, Pointer.to(yC.data().asFloat()));
        } else {
            JCublas.cublasDcopy(x.length(), xCPointer, 1, yCPointer, 1);
            getData(yC, yCPointer, Pointer.to(yC.data().asDouble()));
        }
    } finally {
        // Always release the device buffers, even when the copy or read-back fails.
        free(xCPointer, yCPointer);
    }
}
示例13: cumsumi
import org.nd4j.linalg.api.buffer.DataBuffer; //导入依赖的package包/类
/**
 * Cumulative sum along a dimension.
 *
 * <ul>
 *   <li>If this array is a vector, each element is replaced by the running sum up to and
 *       including it, and {@code this} is returned.</li>
 *   <li>If {@code dimension == Integer.MAX_VALUE}, a running sum is taken over a flattened
 *       duplicate ({@code ravel().dup()}) of the whole array, and that flattened array is
 *       returned.</li>
 *   <li>Otherwise {@code cumsumi(0)} is applied to every vector along {@code dimension},
 *       including the last dimension, and {@code this} is returned. This is in place only
 *       if {@code vectorAlongDimension} returns views — confirm.</li>
 * </ul>
 *
 * @param dimension the dimension to perform cumulative sum along
 * @return the cumulative sum along the specified dimension
 */
@Override
public INDArray cumsumi(int dimension) {
    if (isVector()) {
        double s = 0.0;
        for (int i = 0; i < length; i++) {
            // Every data type is read through getDouble, so no per-element type check.
            s += getDouble(i);
            putScalar(i, s);
        }
    } else if (dimension == Integer.MAX_VALUE) {
        INDArray flattened = ravel().dup();
        double prevVal = flattened.getDouble(0);
        for (int i = 1; i < flattened.length(); i++) {
            double d = prevVal + flattened.getDouble(i);
            flattened.putScalar(i, d);
            prevVal = d;
        }
        return flattened;
    } else {
        // The last dimension is handled like any other: one running sum per vector, not a
        // single running sum that crosses vector boundaries in a flattened copy.
        for (int i = 0; i < vectorsAlongDimension(dimension); i++) {
            INDArray vec = vectorAlongDimension(i, dimension);
            vec.cumsumi(0);
        }
    }
    return this;
}
示例14: dot
import org.nd4j.linalg.api.buffer.DataBuffer; //导入依赖的package包/类
/**
 * Dot product {@code x^T * y}, computed with {@code JavaBlas.rdot} on the arrays' float
 * (FLOAT) or double (DOUBLE) data.
 *
 * @param x the first vector
 * @param y the second vector; must have the same data type as {@code x}
 * @return the dot product
 * @throws IllegalStateException if the buffer is neither FLOAT nor DOUBLE
 */
@Override
public double dot(INDArray x, INDArray y) {
    DataTypeValidation.assertSameDataType(x, y);
    if (x.data().dataType().equals(DataBuffer.FLOAT)) {
        return JavaBlas.rdot(x.length(),
                x.data().asFloat(), x.offset(), x.secondaryStride(),
                y.data().asFloat(), y.offset(), y.secondaryStride());
    }
    if (x.data().dataType().equals(DataBuffer.DOUBLE)) {
        return JavaBlas.rdot(x.length(),
                x.data().asDouble(), x.offset(), x.secondaryStride(),
                y.data().asDouble(), y.offset(), y.secondaryStride());
    }
    throw new IllegalStateException("Illegal data type");
}
示例15: addi
import org.nd4j.linalg.api.buffer.DataBuffer; //导入依赖的package包/类
/**
 * Element-wise addition {@code result = this + other}.
 *
 * <ul>
 *   <li>If {@code other} is a scalar, its value is added to this array, writing into
 *       {@code result}.</li>
 *   <li>If this array is a scalar, its value is added to {@code other}, writing into
 *       {@code result}.</li>
 *   <li>If {@code result} is {@code this} or {@code other}, the sum is accumulated with an
 *       axpy of 1.0 (1.0f for non-DOUBLE buffers).</li>
 *   <li>Otherwise the sum is written element by element through linear views.</li>
 * </ul>
 *
 * @param other  the second ndarray to add
 * @param result the result ndarray
 * @return the result of the addition
 */
@Override
public INDArray addi(INDArray other, INDArray result) {
    if (other.isScalar()) {
        // Compute this + scalar. Starting from result would ignore this array's values
        // whenever result is a different array.
        return this.addi(other.getDouble(0), result);
    }
    if (isScalar()) {
        return other.addi(getDouble(0), result);
    }

    if (result == this) {
        if (data.dataType().equals(DataBuffer.DOUBLE)) {
            Nd4j.getBlasWrapper().axpy(1.0, other, result);
        } else {
            Nd4j.getBlasWrapper().axpy(1.0f, other, result);
        }
    } else if (result == other) {
        if (data.dataType().equals(DataBuffer.DOUBLE)) {
            Nd4j.getBlasWrapper().axpy(1.0, this, result);
        } else {
            Nd4j.getBlasWrapper().axpy(1.0f, this, result);
        }
    } else {
        INDArray resultLinear = result.linearView();
        INDArray otherLinear = other.linearView();
        INDArray linear = linearView();
        for (int i = 0; i < resultLinear.length(); i++) {
            resultLinear.putScalar(i, otherLinear.getDouble(i) + linear.getDouble(i));
        }
    }
    return result;
}