当前位置: 首页>>代码示例>>Java>>正文


Java JCublas.cublasInit方法代码示例

本文整理汇总了Java中jcuda.jcublas.JCublas.cublasInit方法的典型用法代码示例。如果您正苦于以下问题:Java JCublas.cublasInit方法的具体用法?Java JCublas.cublasInit怎么用?Java JCublas.cublasInit使用的例子?那么, 这里精选的方法代码示例或许可以为您提供帮助。您也可以进一步了解该方法所在jcuda.jcublas.JCublas的用法示例。


在下文中一共展示了JCublas.cublasInit方法的10个代码示例,这些例子默认根据受欢迎程度排序。您可以为喜欢或者感觉有用的代码点赞,您的评价将有助于系统推荐出更棒的Java代码示例。

示例1: scal

import jcuda.jcublas.JCublas; //导入方法依赖的package包/类
/**
 * Scales a complex single-precision ndarray in place by a complex factor using
 * {@code cublasCscal}, then copies the result back into the host buffer.
 *
 * @param alpha the complex scale factor
 * @param x the complex ndarray to scale; its data is overwritten with the scaled values
 * @return {@code x} itself
 */
public static IComplexNDArray scal(IComplexFloat alpha, IComplexNDArray x) {
    JCublasComplexNDArray target = (JCublasComplexNDArray) x;
    DataTypeValidation.assertFloat(x);
    JCublas.cublasInit();

    Pointer device = alloc(target);

    int count = x.length();
    float re = alpha.realComponent().floatValue();
    float im = alpha.imaginaryComponent().floatValue();
    jcuda.cuComplex factor = jcuda.cuComplex.cuCmplx(re, im);
    JCublas.cublasCscal(count, factor, device, 1);

    // Pull the scaled values from the device back into the ndarray's host storage.
    Pointer host = Pointer.to(target.data().asFloat());
    getData(target, device, host);

    free(device);
    return x;
}
 
开发者ID:wlin12,项目名称:JNN,代码行数:28,代码来源:SimpleJCublas.java

示例2: axpy

import jcuda.jcublas.JCublas; //导入方法依赖的package包/类
/**
 * Computes {@code y = alpha * x + y} for double-precision ndarrays via {@code cublasDaxpy}.
 * Only {@code y} is modified; its host buffer is refreshed from the device afterwards.
 *
 * @param alpha the factor applied to {@code x}
 * @param x the input vector (read only)
 * @param y the accumulator vector, updated in place
 */
public static void axpy(double alpha, INDArray x, INDArray y) {
    DataTypeValidation.assertDouble(x, y);
    JCublas.cublasInit();
    JCublasNDArray source = (JCublasNDArray) x;
    JCublasNDArray accumulator = (JCublasNDArray) y;

    Pointer dSource = alloc(source);
    Pointer dAccumulator = alloc(accumulator);

    int count = x.length();
    JCublas.cublasDaxpy(count, alpha, dSource, 1, dAccumulator, 1);

    // axpy writes only y, so only y's host copy needs to be synchronized.
    Pointer hostAccumulator = Pointer.to(accumulator.data().asDouble());
    getData(accumulator, dAccumulator, hostAccumulator);
    free(dSource, dAccumulator);
}
 
开发者ID:wlin12,项目名称:JNN,代码行数:24,代码来源:SimpleJCublas.java

示例3: iamax

import jcuda.jcublas.JCublas; //导入方法依赖的package包/类
/**
 * Returns the 0-based index of the element with the largest absolute value in {@code x}.
 * This follows the cuBLAS {@code i?amax} semantics: on ties the first such index wins.
 *
 * <p>Float buffers are searched with {@code cublasIsamax}, all others with
 * {@code cublasIdamax}, so double-precision data is interpreted correctly.
 *
 * @param x the vector to search
 * @return the 0-based index of the max-magnitude element; per the cuBLAS documentation
 *     an empty input yields 0 from cuBLAS, i.e. -1 here
 */
public static int iamax(INDArray x) {
    JCublas.cublasInit();

    JCublasNDArray xC = (JCublasNDArray) x;
    Pointer xCPointer = alloc(xC);
    try {
        int oneBasedIndex;
        if (xC.data().dataType().equals(DataBuffer.FLOAT)) {
            oneBasedIndex = JCublas.cublasIsamax(x.length(), xCPointer, 1);
        } else {
            oneBasedIndex = JCublas.cublasIdamax(x.length(), xCPointer, 1);
        }
        // cuBLAS reports Fortran-style 1-based indices; convert to Java's 0-based indexing.
        return oneBasedIndex - 1;
    } finally {
        // Release the device buffer even if the cuBLAS call throws.
        free(xCPointer);
    }
}
 
开发者ID:wlin12,项目名称:JNN,代码行数:22,代码来源:SimpleJCublas.java

示例4: saxpy

import jcuda.jcublas.JCublas; //导入方法依赖的package包/类
/**
 * Computes {@code y = alpha * x + y} for single-precision ndarrays via {@code cublasSaxpy}.
 * Only {@code y} is modified; its host buffer is refreshed from the device afterwards.
 *
 * @param alpha the factor applied to {@code x}
 * @param x the input vector (read only)
 * @param y the accumulator vector, updated in place
 */
public static void saxpy(float alpha, INDArray x, INDArray y) {
    DataTypeValidation.assertFloat(x, y);
    JCublas.cublasInit();
    JCublasNDArray source = (JCublasNDArray) x;
    JCublasNDArray accumulator = (JCublasNDArray) y;

    Pointer dSource = alloc(source);
    Pointer dAccumulator = alloc(accumulator);

    int count = x.length();
    JCublas.cublasSaxpy(count, alpha, dSource, 1, dAccumulator, 1);

    // saxpy writes only y, so only y's host copy needs to be synchronized.
    Pointer hostAccumulator = Pointer.to(accumulator.data().asFloat());
    getData(accumulator, dAccumulator, hostAccumulator);
    free(dSource, dAccumulator);
}
 
开发者ID:wlin12,项目名称:JNN,代码行数:24,代码来源:SimpleJCublas.java

示例5: gemv

import jcuda.jcublas.JCublas; //导入方法依赖的package包/类
/**
 * Double-precision matrix-vector product {@code C = alpha * A * B + beta * C} via
 * {@code cublasDgemv} (no transpose).
 *
 * <p>A's buffer is handed to cuBLAS with leading dimension {@code A.rows()}; B and C
 * are accessed with unit stride.
 *
 * @param A the matrix operand
 * @param B the input vector
 * @param C the output vector, updated in place
 * @param alpha the factor applied to {@code A * B}
 * @param beta the factor applied to the existing contents of {@code C}
 * @return {@code C} itself
 */
public static INDArray gemv(INDArray A, INDArray B, INDArray C, double alpha, double beta) {
    DataTypeValidation.assertDouble(A, B, C);
    JCublas.cublasInit();

    JCublasNDArray matrix = (JCublasNDArray) A;
    JCublasNDArray vector = (JCublasNDArray) B;
    JCublasNDArray result = (JCublasNDArray) C;

    Pointer dMatrix = alloc(matrix);
    Pointer dVector = alloc(vector);
    Pointer dResult = alloc(result);

    final char noTranspose = 'N';
    final int unitStride = 1;
    int rows = A.rows();
    int cols = A.columns();
    JCublas.cublasDgemv(noTranspose, rows, cols, alpha, dMatrix, rows,
            dVector, unitStride, beta, dResult, unitStride);

    // Only C is written by gemv; bring its new contents back to the host.
    getData(result, dResult, Pointer.to(result.data().asDouble()));
    free(dMatrix, dVector, dResult);

    return C;
}
 
开发者ID:wlin12,项目名称:JNN,代码行数:43,代码来源:SimpleJCublas.java

示例6: gemm

import jcuda.jcublas.JCublas; //导入方法依赖的package包/类
/**
 * Complex single-precision matrix multiply {@code C = a * A * B + b * C} via
 * {@code cublasCgemm} (no transposes).
 *
 * <p>The dimensions passed to cuBLAS are m = {@code C.rows()}, n = {@code C.columns()} and
 * k = {@code A.columns()}. The leading dimensions are the respective {@code rows()}.
 *
 * @param A the left matrix operand
 * @param B the right matrix operand
 * @param a the complex factor applied to {@code A * B}
 * @param C the output matrix, updated in place
 * @param b the complex factor applied to the existing contents of {@code C}
 * @return {@code C} itself
 */
public static IComplexNDArray gemm(IComplexNDArray A, IComplexNDArray B, IComplexFloat a, IComplexNDArray C
        , IComplexFloat b) {
    DataTypeValidation.assertFloat(A, B, C);
    JCublas.cublasInit();

    JCublasComplexNDArray cA = (JCublasComplexNDArray) A;
    JCublasComplexNDArray cB = (JCublasComplexNDArray) B;
    JCublasComplexNDArray cC = (JCublasComplexNDArray) C;

    // Build each scalar from its own complex number (alpha from a, beta from b). Doing this
    // before allocating device memory means a bad scalar cannot leak device buffers.
    cuComplex alpha = cuComplex.cuCmplx(a.realComponent().floatValue(), a.imaginaryComponent().floatValue());
    cuComplex beta = cuComplex.cuCmplx(b.realComponent().floatValue(), b.imaginaryComponent().floatValue());

    Pointer cAPointer = alloc(cA);
    Pointer cBPointer = alloc(cB);
    Pointer cCPointer = alloc(cC);
    try {
        JCublas.cublasCgemm(
                'n', // transa
                'n', // transb
                cC.rows(),    // m
                cC.columns(), // n
                cA.columns(), // k
                alpha,
                cAPointer, // A
                A.rows(),  // lda
                cBPointer, // B
                B.rows(),  // ldb
                beta,
                cCPointer, // C
                C.rows()); // ldc

        getData(cC, cCPointer, Pointer.to(cC.data().asFloat()));
    } finally {
        // Always release device memory, even if the multiply or the copy-back fails.
        free(cAPointer, cBPointer, cCPointer);
    }

    return C;
}
 
开发者ID:wlin12,项目名称:JNN,代码行数:49,代码来源:SimpleJCublas.java

示例7: copy

import jcuda.jcublas.JCublas; //导入方法依赖的package包/类
/**
 * Copies the complex ndarray {@code x} into {@code y} on the device, then refreshes
 * {@code y}'s host buffer.
 *
 * <p>{@code x.length()} counts complex elements, so the complex BLAS routines
 * ({@code cublasCcopy} for float data, {@code cublasZcopy} otherwise) are used. A real
 * {@code ?copy} of {@code x.length()} scalars would move only half of the interleaved
 * real/imaginary data.
 *
 * @param x the origin
 * @param y the destination, overwritten with the contents of {@code x}
 */
public static void copy(IComplexNDArray x, IComplexNDArray y) {
    DataTypeValidation.assertSameDataType(x, y);
    JCublas.cublasInit();

    JCublasComplexNDArray xC = (JCublasComplexNDArray) x;
    JCublasComplexNDArray yC = (JCublasComplexNDArray) y;

    Pointer xCPointer = alloc(xC);
    Pointer yCPointer = alloc(yC);
    try {
        if (xC.data().dataType().equals(DataBuffer.FLOAT)) {
            JCublas.cublasCcopy(x.length(), xCPointer, 1, yCPointer, 1);
            getData(yC, yCPointer, Pointer.to(yC.data().asFloat()));
        } else {
            JCublas.cublasZcopy(x.length(), xCPointer, 1, yCPointer, 1);
            getData(yC, yCPointer, Pointer.to(yC.data().asDouble()));
        }
    } finally {
        // Release device memory even if the copy or the transfer back fails.
        free(xCPointer, yCPointer);
    }
}
 
开发者ID:wlin12,项目名称:JNN,代码行数:45,代码来源:SimpleJCublas.java

示例8: swap

import jcuda.jcublas.JCublas; //导入方法依赖的package包/类
/**
 * Exchanges the contents of {@code x} and {@code y} using {@code cublasSswap} for float
 * data or {@code cublasDswap} otherwise.
 *
 * <p>A swap writes both vectors, so the host buffers of both {@code x} and {@code y} are
 * refreshed from the device afterwards.
 *
 * @param x the first vector, receives the old contents of {@code y}
 * @param y the second vector, receives the old contents of {@code x}
 */
public static void swap(INDArray x, INDArray y) {
    DataTypeValidation.assertSameDataType(x, y);
    JCublas.cublasInit();

    JCublasNDArray xC = (JCublasNDArray) x;
    JCublasNDArray yC = (JCublasNDArray) y;
    Pointer xCPointer = alloc(xC);
    Pointer yCPointer = alloc(yC);
    try {
        if (xC.data().dataType().equals(DataBuffer.FLOAT)) {
            JCublas.cublasSswap(xC.length(), xCPointer, 1, yCPointer, 1);
            getData(xC, xCPointer, Pointer.to(xC.data().asFloat()));
            getData(yC, yCPointer, Pointer.to(yC.data().asFloat()));
        } else {
            JCublas.cublasDswap(xC.length(), xCPointer, 1, yCPointer, 1);
            getData(xC, xCPointer, Pointer.to(xC.data().asDouble()));
            getData(yC, yCPointer, Pointer.to(yC.data().asDouble()));
        }
    } finally {
        // Release device memory even if the swap or a transfer back fails.
        free(xCPointer, yCPointer);
    }
}
 
开发者ID:wlin12,项目名称:JNN,代码行数:42,代码来源:SimpleJCublas.java

示例9: nrm2

import jcuda.jcublas.JCublas; //导入方法依赖的package包/类
/**
 * Returns the Euclidean (L2) norm of {@code x}.
 *
 * <p>Float buffers use {@code cublasSnrm2}. Other buffers use {@code cublasDnrm2}, and that
 * result is narrowed to {@code float} to keep this method's return type.
 *
 * @param x the vector whose norm is computed (not modified)
 * @return the L2 norm of {@code x}
 */
public static float nrm2(INDArray x) {
    JCublas.cublasInit();
    JCublasNDArray xC = (JCublasNDArray) x;
    Pointer xCPointer = alloc(xC);
    try {
        if (xC.data().dataType().equals(DataBuffer.FLOAT)) {
            return JCublas.cublasSnrm2(x.length(), xCPointer, 1);
        }
        return (float) JCublas.cublasDnrm2(x.length(), xCPointer, 1);
    } finally {
        // Use the same release helper as the other routines, and always release.
        free(xCPointer);
    }
}
 
开发者ID:wlin12,项目名称:JNN,代码行数:16,代码来源:SimpleJCublas.java

示例10: ger

import jcuda.jcublas.JCublas; //导入方法依赖的package包/类
/**
 * Double-precision rank-1 update {@code C = alpha * A * transpose(B) + C} via
 * {@code cublasDger}.
 *
 * <p>{@code A} and {@code B} are read as contiguous vectors x (first {@code C.rows()}
 * elements) and y (first {@code C.columns()} elements). {@code C} is passed with leading
 * dimension {@code C.rows()} and is updated in place.
 *
 * @param A the vector x
 * @param B the vector y
 * @param C the matrix to update
 * @param alpha the factor applied to the outer product
 * @return {@code C} itself
 * @throws IllegalArgumentException if {@code A} or {@code B} is too short for {@code C}'s shape
 */
public static INDArray ger(INDArray A, INDArray B, INDArray C, double alpha) {
    DataTypeValidation.assertDouble(A, B, C);
    JCublas.cublasInit();

    // The update's shape is fixed by C. x must cover its rows and y its columns.
    // Both are read with unit stride, so cuBLAS never reads past the vectors.
    int m = C.rows();
    int n = C.columns();
    if (A.length() < m || B.length() < n) {
        throw new IllegalArgumentException("ger: A needs at least " + m + " elements and B at least "
                + n + " (got " + A.length() + " and " + B.length() + ")");
    }

    JCublasNDArray aC = (JCublasNDArray) A;
    JCublasNDArray bC = (JCublasNDArray) B;
    JCublasNDArray cC = (JCublasNDArray) C;

    Pointer aCPointer = alloc(aC);
    Pointer bCPointer = alloc(bC);
    Pointer cCPointer = alloc(cC);
    try {
        JCublas.cublasDger(
                m,         // m
                n,         // n
                alpha,     // alpha
                aCPointer, // x
                1,         // incx
                bCPointer, // y
                1,         // incy
                cCPointer, // A (the matrix being updated)
                m          // lda
        );

        getData(cC, cCPointer, Pointer.to(cC.data().asDouble()));
    } finally {
        // Release device memory even if the update or the copy-back fails.
        free(aCPointer, bCPointer, cCPointer);
    }

    return C;
}
 
开发者ID:wlin12,项目名称:JNN,代码行数:31,代码来源:SimpleJCublas.java


注:本文中的jcuda.jcublas.JCublas.cublasInit方法示例由纯净天空整理自Github/MSDocs等开源代码及文档管理平台,相关代码片段筛选自各路编程大神贡献的开源项目,源码版权归原作者所有,传播和使用请参考对应项目的License;未经允许,请勿转载。