当前位置: 首页>>代码示例>>Java>>正文


Java DataBuffer类代码示例

本文整理汇总了Java中org.nd4j.linalg.api.buffer.DataBuffer类的典型用法代码示例。如果您想了解Java DataBuffer类的具体用法、使用方式或实际使用示例，这里精选的代码示例或许可以为您提供帮助。


DataBuffer类属于org.nd4j.linalg.api.buffer包，下文共展示了DataBuffer类的15个代码示例，这些示例默认按受欢迎程度排序。您可以为喜欢或者觉得有用的代码点赞，您的评价将有助于系统推荐出更好的Java代码示例。

示例1: copy

import org.nd4j.linalg.api.buffer.DataBuffer; //导入依赖的package包/类
/**
 * Copies the elements of {@code x} into {@code y} (BLAS-style {@code y <- x}).
 *
 * <p>Buffers whose data type is {@link DataBuffer#DOUBLE} go through the
 * double-precision path; every other data type goes through the single-precision path.
 *
 * @param x the source array
 * @param y the destination array, overwritten with the contents of {@code x}
 * @return {@code y}
 */
@Override
public INDArray copy(INDArray x, INDArray y) {
    DataTypeValidation.assertSameDataType(x, y);
    // NOTE(review): the copy only reaches y if asFloat()/asDouble() expose y's backing
    // array rather than a detached copy — confirm against the DataBuffer implementation.
    boolean doublePrecision = x.data().dataType().equals(DataBuffer.DOUBLE);
    if (!doublePrecision) {
        JavaBlas.rcopy(x.length(), x.data().asFloat(), x.offset(), x.secondaryStride(),
                y.data().asFloat(), y.offset(), y.secondaryStride());
        return y;
    }
    JavaBlas.rcopy(x.length(), x.data().asDouble(), x.offset(), x.secondaryStride(),
            y.data().asDouble(), y.offset(), y.secondaryStride());
    return y;
}
 
开发者ID:wlin12,项目名称:JNN,代码行数:28,代码来源:BlasWrapper.java

示例2: applyTransformToOrigin

import org.nd4j.linalg.api.buffer.DataBuffer; //导入依赖的package包/类
/**
 * Applies this element-wise transformation to element {@code i} of {@code origin} and
 * writes the result back into {@code origin} at the same index.
 *
 * <p>Complex arrays store the complex result as-is. For real arrays, a NaN or infinite
 * result is replaced by {@code Nd4j.EPS_THRESHOLD} before it is stored.
 *
 * @param origin the array to read the element from and write the result into
 * @param i      the index of the element to apply the transform to
 */
@Override
public void applyTransformToOrigin(INDArray origin, int i) {
    // Dispatch on the array type: the complex path casts origin to IComplexNDArray,
    // so that is the type the check must test for.
    if (origin instanceof IComplexNDArray) {
        IComplexNDArray complexOrigin = (IComplexNDArray) origin;
        IComplexNumber transformed = apply(origin, getFromOrigin(origin, i), i);
        complexOrigin.putScalar(i, transformed);
        return;
    }

    Number f = apply(origin, getFromOrigin(origin, i), i);
    double val = f.doubleValue();
    // Keep NaN/Infinity out of the array.
    if (Double.isNaN(val) || Double.isInfinite(val)) {
        val = Nd4j.EPS_THRESHOLD;
    }
    // The same putScalar(int, double) call is used for every buffer data type.
    origin.putScalar(i, val);
}
 
开发者ID:wlin12,项目名称:JNN,代码行数:25,代码来源:BaseElementWiseOp.java

示例3: cumsum

import org.nd4j.linalg.api.buffer.DataBuffer; //导入依赖的package包/类
/**
 * Returns a function that replaces each element of its input with the running sum of
 * all elements up to and including it, in index order ({@code getDouble(i)} order).
 *
 * <p>The input array is modified in place and is also the function's return value.
 *
 * @return an in-place cumulative-sum function
 */
public static Function<INDArray,INDArray> cumsum() {
    return new Function<INDArray, INDArray>() {
        @Override
        public INDArray apply(INDArray input) {
            // Accumulate in double precision regardless of the buffer's data type.
            double runningSum = 0.0;
            for (int i = 0; i < input.length(); i++) {
                runningSum += input.getDouble(i);
                input.putScalar(i, runningSum);
            }
            return input;
        }
    };
}
 
开发者ID:wlin12,项目名称:JNN,代码行数:18,代码来源:DimensionFunctions.java

示例4: nrm2

import org.nd4j.linalg.api.buffer.DataBuffer; //导入依赖的package包/类
/**
 * Computes the Euclidean (2-)norm of the complex array {@code x}.
 *
 * <p>FLOAT buffers are handled by {@code NativeBlas.scnrm2}, DOUBLE buffers by
 * {@code NativeBlas.dznrm2}.
 *
 * @param x the complex array whose norm is computed
 * @return the 2-norm of {@code x}
 * @throws IllegalStateException if the buffer is neither FLOAT nor DOUBLE
 */
@Override
public double nrm2(IComplexNDArray x) {
    // NOTE(review): the sibling complex wrappers asum() and dotc() pass offset()/2 and
    // blasOffset() respectively, while this method passes offset() unchanged — confirm
    // which index convention NativeBlas expects for complex buffers.
    boolean isFloat = x.data().dataType().equals(DataBuffer.FLOAT);
    if (isFloat) {
        return NativeBlas.scnrm2(x.length(), x.data().asFloat(), x.offset(), x.secondaryStride());
    }
    boolean isDouble = x.data().dataType().equals(DataBuffer.DOUBLE);
    if (!isDouble) {
        throw new IllegalStateException("Illegal data type");
    }
    return NativeBlas.dznrm2(x.length(), x.data().asDouble(), x.offset(), x.secondaryStride());
}
 
开发者ID:wlin12,项目名称:JNN,代码行数:20,代码来源:BlasWrapper.java

示例5: dotc

import org.nd4j.linalg.api.buffer.DataBuffer; //导入依赖的package包/类
/**
 * Computes the conjugated dot product {@code x^H * y} of two complex arrays
 * (BLAS {@code ?dotc}: the elements of {@code x} are conjugated).
 *
 * @param x the first operand, conjugated by the BLAS routine
 * @param y the second operand
 * @return the result wrapped in a {@code ComplexFloat} for FLOAT buffers or a
 *         {@code ComplexDouble} for DOUBLE buffers
 * @throws IllegalStateException if the buffers are neither FLOAT nor DOUBLE
 */
@Override
public IComplexNumber dotc(IComplexNDArray x, IComplexNDArray y) {
    DataTypeValidation.assertSameDataType(x, y);

    boolean isFloat = x.data().dataType().equals(DataBuffer.FLOAT);
    if (!isFloat && !x.data().dataType().equals(DataBuffer.DOUBLE)) {
        throw new IllegalStateException("Illegal data type");
    }

    if (isFloat) {
        return new ComplexFloat(NativeBlas.cdotc(
                x.length(),
                x.data().asFloat(), x.blasOffset(), x.secondaryStride(),
                y.data().asFloat(), y.blasOffset(), y.secondaryStride()));
    }
    return new ComplexDouble(NativeBlas.zdotc(
            x.length(),
            x.data().asDouble(), x.blasOffset(), x.secondaryStride(),
            y.data().asDouble(), y.blasOffset(), y.secondaryStride()));
}
 
开发者ID:wlin12,项目名称:JNN,代码行数:28,代码来源:BlasWrapper.java

示例6: swap

import org.nd4j.linalg.api.buffer.DataBuffer; //导入依赖的package包/类
/**
 * Exchanges the contents of {@code x} and {@code y} (BLAS-style {@code x <-> y}).
 *
 * <p>FLOAT buffers use the single-precision path; every other data type is handled as
 * double precision.
 *
 * @param x the first array
 * @param y the second array
 * @return {@code y}
 */
@Override
public INDArray swap(INDArray x, INDArray y) {
    DataTypeValidation.assertSameDataType(x, y);
    // NOTE(review): the swap only reaches x and y if asFloat()/asDouble() expose the
    // backing arrays rather than detached copies — confirm.
    boolean singlePrecision = x.data().dataType().equals(DataBuffer.FLOAT);
    if (singlePrecision) {
        JavaBlas.rswap(x.length(),
                x.data().asFloat(), x.offset(), x.secondaryStride(),
                y.data().asFloat(), y.offset(), y.secondaryStride());
    } else {
        JavaBlas.rswap(x.length(),
                x.data().asDouble(), x.offset(), x.secondaryStride(),
                y.data().asDouble(), y.offset(), y.secondaryStride());
    }
    return y;
}
 
开发者ID:wlin12,项目名称:JNN,代码行数:28,代码来源:BlasWrapper.java

示例7: iamax

import org.nd4j.linalg.api.buffer.DataBuffer; //导入依赖的package包/类
/**
 * Returns the 0-based index of the element of {@code x} with the largest absolute value
 * (BLAS {@code i?amax}).
 *
 * @param x the array to search
 * @return the 0-based index of the absolute-value maximum
 * @throws IllegalStateException if the buffer is neither FLOAT nor DOUBLE
 */
@Override
public int iamax(INDArray x) {
    final int oneBasedIndex;
    if (x.data().dataType().equals(DataBuffer.FLOAT)) {
        oneBasedIndex = NativeBlas.isamax(x.length(), x.data().asFloat(), x.offset(), x.secondaryStride());
    } else if (x.data().dataType().equals(DataBuffer.DOUBLE)) {
        oneBasedIndex = NativeBlas.idamax(x.length(), x.data().asDouble(), x.offset(), x.secondaryStride());
    } else {
        throw new IllegalStateException("Illegal data type");
    }
    // BLAS i?amax returns a 1-based (Fortran) index; convert it to 0-based.
    return oneBasedIndex - 1;
}
 
开发者ID:wlin12,项目名称:JNN,代码行数:24,代码来源:BlasWrapper.java

示例8: asum

import org.nd4j.linalg.api.buffer.DataBuffer; //导入依赖的package包/类
/**
 * Computes the BLAS complex absolute sum of {@code x}, i.e. the sum over all elements of
 * {@code |Re(x_i)| + |Im(x_i)|} (BLAS {@code scasum}/{@code dzasum}).
 *
 * @param x the complex array to sum
 * @return the absolute sum
 * @throws IllegalStateException if the buffer is neither FLOAT nor DOUBLE
 */
@Override
public double asum(IComplexNDArray x) {
    // The offset is halved before being handed to BLAS — presumably converting a
    // real-valued buffer offset into a complex-element index (confirm).
    boolean isFloat = x.data().dataType().equals(DataBuffer.FLOAT);
    boolean isDouble = !isFloat && x.data().dataType().equals(DataBuffer.DOUBLE);

    if (isFloat) {
        return NativeBlas.scasum(x.length(), x.data().asFloat(), x.offset() / 2, x.secondaryStride());
    }
    if (isDouble) {
        return NativeBlas.dzasum(x.length(), x.data().asDouble(), x.offset() / 2, x.secondaryStride());
    }
    throw new IllegalStateException("Illegal data type");
}
 
开发者ID:wlin12,项目名称:JNN,代码行数:23,代码来源:BlasWrapper.java

示例9: posv

import org.nd4j.linalg.api.buffer.DataBuffer; //导入依赖的package包/类
/**
 * Solves {@code A * X = B} for a symmetric positive-definite {@code A} using LAPACK
 * {@code ?posv}. Per LAPACK, {@code B} is overwritten with the solution {@code X} and
 * {@code A} with its Cholesky factor.
 *
 * @param uplo whether the upper ('U') or lower ('L') triangle of {@code A} is used
 * @param A    the coefficient matrix ({@code A.rows()} is used as both n and lda)
 * @param B    the right-hand sides ({@code B.columns()} is used as nrhs)
 * @throws IllegalStateException     if the buffers are neither FLOAT nor DOUBLE
 * @throws LapackArgumentException   if LAPACK reports a failure (via {@code checkInfo}),
 *                                   or if {@code A} is not positive definite
 */
@Override
public void posv(char uplo, INDArray A, INDArray B) {
    int n = A.rows();
    int nrhs = B.columns();
    DataTypeValidation.assertSameDataType(A, B);

    final int info;
    final String routine;
    if (A.data().dataType().equals(DataBuffer.FLOAT)) {
        routine = "SPOSV";
        info = NativeBlas.sposv(
                uplo,
                n,
                nrhs,
                A.data().asFloat(),
                A.offset(),
                A.rows(),
                B.data().asFloat(),
                B.offset(),
                B.rows());
    } else if (A.data().dataType().equals(DataBuffer.DOUBLE)) {
        routine = "DPOSV";
        info = NativeBlas.dposv(
                uplo,
                n,
                nrhs,
                A.data().asDouble(),
                A.offset(),
                A.rows(),
                B.data().asDouble(),
                B.offset(),
                B.rows());
    } else {
        // Do not let a placeholder info value reach checkInfo and masquerade as a LAPACK
        // argument error; fail the same way the other wrappers do.
        throw new IllegalStateException("Illegal data type");
    }

    checkInfo(routine, info);
    if (info > 0) {
        // A positive info from ?posv is the order of the leading minor that is not
        // positive definite.
        throw new LapackArgumentException(routine,
                "Leading minor of order " + info + " of A is not positive definite.");
    }
}
 
开发者ID:wlin12,项目名称:JNN,代码行数:35,代码来源:BlasWrapper.java

示例10: syev

import org.nd4j.linalg.api.buffer.DataBuffer; //导入依赖的package包/类
/**
 * Computes eigenvalues (and, depending on {@code jobz}, eigenvectors) of the symmetric
 * matrix {@code a} with LAPACK {@code ?syev}, writing the eigenvalues into {@code w}.
 *
 * <p>FLOAT buffers use {@code ssyev}; every other data type uses {@code dsyev}.
 *
 * @param jobz LAPACK job flag selecting whether eigenvectors are computed
 * @param uplo whether the upper or lower triangle of {@code a} is used
 * @param a    the symmetric matrix ({@code a.rows()} is used as both n and lda)
 * @param w    the array that receives the eigenvalues
 * @return the LAPACK info code; only non-positive values are returned
 * @throws LapackConvergenceException if LAPACK reports a positive info (no convergence)
 */
@Override
public int syev(char jobz, char uplo, INDArray a, INDArray w) {
    DataTypeValidation.assertSameDataType(a, w);

    int info = a.data().dataType().equals(DataBuffer.FLOAT)
            ? NativeBlas.ssyev(jobz, uplo, a.rows(),
                    a.data().asFloat(), a.offset(), a.rows(),
                    w.data().asFloat(), w.offset())
            : NativeBlas.dsyev(jobz, uplo, a.rows(),
                    a.data().asDouble(), a.offset(), a.rows(),
                    w.data().asDouble(), w.offset());

    if (info <= 0) {
        return info;
    }
    throw new LapackConvergenceException("SYEV",
            "Eigenvalues could not be computed " + info
                    + " off-diagonal elements did not converge");
}
 
开发者ID:wlin12,项目名称:JNN,代码行数:39,代码来源:BlasWrapper.java

示例11: alloc

import org.nd4j.linalg.api.buffer.DataBuffer; //导入依赖的package包/类
/**
 * Allocates device memory for {@code ndarray} and uploads its elements into it.
 *
 * <p>The host data pointer starts at the array's offset, converted to bytes with
 * {@code size(ndarray)}. When the array's length equals its buffer's length, the host
 * data is read contiguously; otherwise it is read with the array's major stride. The
 * device-side copy is always contiguous.
 *
 * @param ndarray the array to upload
 * @return the device pointer holding {@code ndarray.length()} elements
 */
public static Pointer alloc(JCublasNDArray ndarray) {
    Pointer devicePointer = new Pointer();

    Pointer hostData = ndarray.data().dataType().equals(DataBuffer.FLOAT)
            ? Pointer.to(ndarray.data().asFloat()).withByteOffset(ndarray.offset() * size(ndarray))
            : Pointer.to(ndarray.data().asDouble()).withByteOffset(ndarray.offset() * size(ndarray));

    JCublas.cublasAlloc(ndarray.length(), size(ndarray), devicePointer);

    // An array covering its whole buffer is read with unit stride; otherwise step through
    // the host data by the major stride so only this array's elements are copied.
    int hostStride = ndarray.length() == ndarray.data().length() ? 1 : ndarray.majorStride();
    JCublas.cublasSetVector(ndarray.length(), size(ndarray), hostData, hostStride, devicePointer, 1);

    return devicePointer;
}
 
开发者ID:wlin12,项目名称:JNN,代码行数:46,代码来源:SimpleJCublas.java

示例12: copy

import org.nd4j.linalg.api.buffer.DataBuffer; //导入依赖的package包/类
/**
 * Copies the complex array {@code x} into {@code y} using cuBLAS.
 *
 * <p>Both arrays are uploaded with {@code alloc}, copied on the device
 * ({@code cublasScopy} for FLOAT buffers, {@code cublasDcopy} otherwise), read back
 * into {@code y}'s host buffer with {@code getData}, and the device memory is freed.
 *
 * @param x the origin; must be a {@code JCublasComplexNDArray}
 * @param y the destination; must be a {@code JCublasComplexNDArray}
 */
public static void copy(IComplexNDArray x, IComplexNDArray y) {
    DataTypeValidation.assertSameDataType(x, y);
    JCublas.cublasInit();

    JCublasComplexNDArray xC = (JCublasComplexNDArray) x;
    JCublasComplexNDArray yC = (JCublasComplexNDArray) y;

    Pointer xCPointer = alloc(xC);
    Pointer yCPointer = alloc(yC);
    try {
        // NOTE(review): x.length() is passed as the element count of a real-valued copy,
        // while complex arrays presumably store two reals per element — confirm the whole
        // array is copied (cublasCcopy/cublasZcopy may be what was intended).
        if (xC.data().dataType().equals(DataBuffer.FLOAT)) {
            JCublas.cublasScopy(
                    x.length(),
                    xCPointer,
                    1,
                    yCPointer,
                    1);
            getData(yC, yCPointer, Pointer.to(yC.data().asFloat()));
        } else {
            JCublas.cublasDcopy(
                    x.length(),
                    xCPointer,
                    1,
                    yCPointer,
                    1);
            getData(yC, yCPointer, Pointer.to(yC.data().asDouble()));
        }
    } finally {
        // Release device memory even if the copy or the read-back throws.
        free(xCPointer, yCPointer);
    }
}
 
开发者ID:wlin12,项目名称:JNN,代码行数:45,代码来源:SimpleJCublas.java

示例13: cumsumi

import org.nd4j.linalg.api.buffer.DataBuffer; //导入依赖的package包/类
/**
 * Cumulative sum along a dimension.
 *
 * <p>Behavior by case, as implemented:
 * <ul>
 *   <li>vectors: each element is replaced in place by the running sum up to and
 *       including it, and {@code this} is returned;</li>
 *   <li>{@code dimension == Integer.MAX_VALUE} or the last dimension: the running sum is
 *       computed over a flattened copy ({@code ravel().dup()}) and that copy is
 *       returned; {@code this} is not modified;</li>
 *   <li>any other dimension: {@code cumsumi(0)} is applied to each vector along
 *       {@code dimension}, and {@code this} is returned.</li>
 * </ul>
 *
 * @param dimension the dimension to perform cumulative sum along
 * @return the cumulative sum along the specified dimension
 */
@Override
public INDArray cumsumi(int dimension) {
    if (isVector()) {
        // Accumulate in double precision regardless of the buffer's data type.
        double s = 0.0;
        for (int i = 0; i < length; i++) {
            s += getDouble(i);
            putScalar(i, s);
        }
    } else if (dimension == Integer.MAX_VALUE || dimension == shape.length - 1) {
        // NOTE(review): for the last dimension of a multi-vector array this accumulates
        // across the whole flattened array rather than per vector, and returns a flat copy
        // instead of updating this array in place — confirm this is intended.
        INDArray flattened = ravel().dup();
        double prevVal = flattened.getDouble(0);
        for (int i = 1; i < flattened.length(); i++) {
            double d = prevVal + flattened.getDouble(i);
            flattened.putScalar(i, d);
            prevVal = d;
        }
        return flattened;
    } else {
        // Each vector along the dimension is itself a vector, so cumsumi(0) takes the
        // vector path above.
        for (int i = 0; i < vectorsAlongDimension(dimension); i++) {
            INDArray vec = vectorAlongDimension(i, dimension);
            vec.cumsumi(0);
        }
    }

    return this;
}
 
开发者ID:wlin12,项目名称:JNN,代码行数:45,代码来源:BaseNDArray.java

示例14: dot

import org.nd4j.linalg.api.buffer.DataBuffer; //导入依赖的package包/类
/**
 * Computes the dot product {@code x^T * y} of two real arrays.
 *
 * @param x the first operand
 * @param y the second operand
 * @return the dot product, computed in the precision of the buffers' data type
 * @throws IllegalStateException if the buffers are neither FLOAT nor DOUBLE
 */
@Override
public double dot(INDArray x, INDArray y) {
    DataTypeValidation.assertSameDataType(x, y);

    boolean isFloat = x.data().dataType().equals(DataBuffer.FLOAT);
    if (!isFloat && !x.data().dataType().equals(DataBuffer.DOUBLE)) {
        throw new IllegalStateException("Illegal data type");
    }

    if (isFloat) {
        return JavaBlas.rdot(x.length(),
                x.data().asFloat(), x.offset(), x.secondaryStride(),
                y.data().asFloat(), y.offset(), y.secondaryStride());
    }
    return JavaBlas.rdot(x.length(),
            x.data().asDouble(), x.offset(), x.secondaryStride(),
            y.data().asDouble(), y.offset(), y.secondaryStride());
}
 
开发者ID:wlin12,项目名称:JNN,代码行数:31,代码来源:BlasWrapper.java

示例15: addi

import org.nd4j.linalg.api.buffer.DataBuffer; //导入依赖的package包/类
/**
 * In-place addition of two matrices: stores {@code this + other} into {@code result}.
 *
 * <p>If either operand is a scalar, the work is delegated to the scalar
 * {@code addi(double, INDArray)} overload of the non-scalar operand. When {@code result}
 * is the same object as {@code this} or {@code other}, the sum is accumulated with BLAS
 * {@code axpy} (adding the other operand with coefficient 1). Otherwise it is computed
 * element by element over the arrays' linear views.
 *
 * @param other  the second ndarray to add
 * @param result the result ndarray
 * @return the result of the addition
 */
@Override
public INDArray addi(INDArray other, INDArray result) {
    if (other.isScalar()) {
        // this + scalar -> result. Calling result.addi(...) here would add the scalar to
        // result's previous contents, which is only correct when result == this.
        return addi(other.getDouble(0), result);
    }
    if (isScalar()) {
        return other.addi(getDouble(0), result);
    }

    if (result == this) {
        // result (== this) += other
        if (data.dataType().equals(DataBuffer.DOUBLE)) {
            Nd4j.getBlasWrapper().axpy(1.0, other, result);
        } else {
            Nd4j.getBlasWrapper().axpy(1.0f, other, result);
        }
    } else if (result == other) {
        // result (== other) += this
        if (data.dataType().equals(DataBuffer.DOUBLE)) {
            Nd4j.getBlasWrapper().axpy(1.0, this, result);
        } else {
            Nd4j.getBlasWrapper().axpy(1.0f, this, result);
        }
    } else {
        INDArray resultLinear = result.linearView();
        INDArray otherLinear = other.linearView();
        INDArray linear = linearView();
        for (int i = 0; i < resultLinear.length(); i++) {
            resultLinear.putScalar(i, otherLinear.getDouble(i) + linear.getDouble(i));
        }
    }

    return result;
}
 
开发者ID:wlin12,项目名称:JNN,代码行数:47,代码来源:BaseNDArray.java


注:本文中的org.nd4j.linalg.api.buffer.DataBuffer类示例由纯净天空整理自Github/MSDocs等开源代码及文档管理平台,相关代码片段筛选自各路编程大神贡献的开源项目,源码版权归原作者所有,传播和使用请参考对应项目的License;未经允许,请勿转载。