當前位置: 首頁>>代碼示例>>TypeScript>>正文


TypeScript tfjs-core.scalar函數代碼示例

本文整理匯總了TypeScript中@tensorflow/tfjs-core.scalar函數的典型用法代碼示例。如果您正苦於以下問題：TypeScript scalar函數的具體用法？TypeScript scalar怎麼用？TypeScript scalar使用的例子？那麼，這裏精選的函數代碼示例或許可以為您提供幫助。


在下文中一共展示了scalar函數的8個代碼示例，這些例子默認根據受歡迎程度排序。您可以為喜歡或者感覺有用的代碼點讚，您的評價將有助於系統推薦出更棒的TypeScript代碼示例。

示例1: describe

/**
 * Spec suite for the arithmetic op executor.
 *
 * Every binary arithmetic op must forward its two resolved inputs, in order,
 * to the same-named `tfc` function.
 */
describe('arithmetic', () => {
  const input1 = [tfc.scalar(1)];
  const input2 = [tfc.scalar(1)];
  const context = new ExecutionContext({}, {});
  let node: Node;

  // A fresh node per spec keeps the `op` assignment made inside each spec
  // from leaking into the next one.
  beforeEach(() => {
    node = {
      name: 'test',
      op: '',
      category: 'arithmetic',
      inputNames: ['input1', 'input2'],
      inputs: [],
      // Presumably binds param `a` to input index 0 and `b` to index 1.
      params: {a: createTensorAttr(0), b: createTensorAttr(1)},
      children: []
    };
  });

  describe('executeOp', () => {
    const binaryOps = [
      'add', 'mul', 'div', 'sub', 'maximum', 'minimum', 'pow',
      'squaredDifference', 'mod', 'floorDiv'
    ];

    for (const op of binaryOps) {
      it(`should call tfc.${op}`, () => {
        // The `as 'add'` cast only narrows the string for spyOn's key typing.
        const spy = spyOn(tfc, op as 'add');
        node.op = op;

        executeOp(node, {input1, input2}, context);

        expect(spy).toHaveBeenCalledWith(input1[0], input2[0]);
      });
    }
  });
});
開發者ID:oveddan,項目名稱:tfjs-converter,代碼行數:32,代碼來源:arithmetic_executor_test.ts

示例2: gLoss

 /**
  * Generator loss computed from the discriminator's predictions on
  * generated samples.
  *
  * - `lossType === 'LeastSq loss'`: mean((generatedPred - 1)^2)
  * - any other `lossType`: -mean(log(generatedPred))
  *
  * NOTE(review): the log branch assumes predictions are strictly positive
  * (e.g. sigmoid outputs) — confirm against the discriminator.
  *
  * @param generatedPred Discriminator outputs for generated samples.
  * @returns The scalar loss tensor.
  */
 gLoss(generatedPred: tf.Tensor1D) {
   const useLeastSquares = this.lossType === 'LeastSq loss';
   if (!useLeastSquares) {
     const meanLog = generatedPred.log().mean();
     return meanLog.mul(tf.scalar(-1)) as tf.Scalar;
   }
   const residual = generatedPred.sub(tf.scalar(1));
   return residual.square().mean() as tf.Scalar;
 }
開發者ID:deepkapha,項目名稱:dklabs.github.io,代碼行數:7,代碼來源:ganlab_models.ts

示例3: it

        /**
         * Builds a tiny graph (placeholder `input` + `const` -> `switch`
         * node `output`), flags it `withControlFlow`, runs it through
         * `executeAsync`, and asserts the order in which
         * `operations.executeOp` was invoked for each node.
         *
         * The spec is a plain async function: Jasmine awaits the returned
         * promise. Taking a `done` callback as well is not supported by
         * Jasmine for async functions, so it was removed.
         */
        it('should execute control flow graph', async () => {
          inputNode = {
            inputNames: [],
            inputs: [],
            children: [],
            name: 'input',
            op: 'placeholder',
            category: 'graph',
            params: {}
          };
          constNode = {
            inputNames: [],
            inputs: [],
            children: [],
            name: 'const',
            op: 'const',
            category: 'graph',
            params: {}
          };
          outputNode = {
            inputNames: ['input', 'const'],
            inputs: [inputNode, constNode],
            children: [],
            name: 'output',
            op: 'switch',
            category: 'control',
            params: {}
          };
          inputNode.children.push(outputNode);
          constNode.children.push(outputNode);
          graphWithControlFlow = {
            inputs: [constNode, inputNode],
            nodes:
                {'input': inputNode, 'const': constNode, 'output': outputNode},
            outputs: [outputNode],
            withControlFlow: true,
            withDynamicShape: false,
            placeholders: [inputNode]
          };

          executor = new GraphExecutor(graphWithControlFlow);
          const inputTensor = tfc.scalar(1);
          const constTensor = tfc.scalar(2);
          executor.weightMap = {const : [constTensor]};
          // Stub op execution: const nodes yield the weight, all others the
          // fed input, so only the scheduling order is under test.
          const spy =
              spyOn(operations, 'executeOp').and.callFake((node: Node) => {
                return node.op === 'const' ? [constTensor] : [inputTensor];
              });

          await executor.executeAsync({input: [inputTensor]});

          expect(spy.calls.allArgs()).toEqual([
            [inputNode, jasmine.any(Object), jasmine.any(ExecutionContext)],
            [outputNode, jasmine.any(Object), jasmine.any(ExecutionContext)],
            [constNode, jasmine.any(Object), jasmine.any(ExecutionContext)],
          ]);
        });
開發者ID:oveddan,項目名稱:tfjs-converter,代碼行數:58,代碼來源:graph_executor_test.ts

示例4: switch

/**
 * Executes the graph-category ops handled by this executor: `const`,
 * `placeholder`, identity-like ops, `snapshot`, `shape`, `size`, `rank`,
 * `noop` and `print`.
 *
 * @param node The node to execute; `node.op` selects the behavior.
 * @param tensorMap Tensors available so far, keyed by node name.
 * @param context Execution context used to resolve params and tensors.
 * @returns The node's output tensors.
 * @throws TypeError if `node.op` is not one of the ops listed above.
 */
export let executeOp: OpExecutor = (node: Node, tensorMap: NamedTensorsMap,
                                    context: ExecutionContext):
                                       tfc.Tensor[] => {
  switch (node.op) {
    case 'const': {
      return tensorMap[node.name];
    }
    case 'placeholder': {
      const def =
          getParamValue('default', node, tensorMap, context) as tfc.Tensor;
      // Use the tensor resolved for this node if there is one; otherwise
      // fall back to the node's 'default' param.
      return [getTensor(node.name, tensorMap, context) || def];
    }
    case 'identity':
    case 'stopGradient':
    case 'fakeQuantWithMinMaxVars':  // This op is currently ignored.
      return [getParamValue('x', node, tensorMap, context) as tfc.Tensor];
    case 'snapshot': {
      const snapshot =
          getParamValue('x', node, tensorMap, context) as tfc.Tensor;
      // Output a copy rather than the input tensor itself.
      return [snapshot.clone()];
    }
    case 'shape':
      return [tfc.tensor1d(
          (getParamValue('x', node, tensorMap, context) as tfc.Tensor).shape,
          'int32')];
    case 'size':
      return [tfc.scalar(
          (getParamValue('x', node, tensorMap, context) as tfc.Tensor).size,
          'int32')];
    case 'rank':
      return [tfc.scalar(
          (getParamValue('x', node, tensorMap, context) as tfc.Tensor).rank,
          'int32')];
    case 'noop':
      return [];
    case 'print': {
      const input = getParamValue('x', node, tensorMap, context) as tfc.Tensor;
      const data =
          getParamValue('data', node, tensorMap, context) as tfc.Tensor[];
      const message =
          getParamValue('message', node, tensorMap, context) as string;
      const summarize =
          getParamValue('summarize', node, tensorMap, context) as number;
      console.warn(
          'The graph has a tf.print() operation, ' +
          'usually used for debugging, which slows down performance.');
      console.log(message);
      // Log the first `summarize` values of EACH data tensor (previously
      // data[0] was printed on every iteration).
      for (const tensor of data) {
        console.log(
            Array.prototype.slice.call(tensor.dataSync()).slice(0, summarize));
      }
      // print is a pass-through: the 'x' input is returned unchanged.
      return [input];
    }

    default:
      throw new TypeError(`Node type ${node.op} is not implemented`);
  }
};
開發者ID:oveddan,項目名稱:tfjs-converter,代碼行數:55,代碼來源:graph_executor.ts

示例5: describe

/**
 * Spec suite for the logical op executor.
 *
 * - Binary comparison/logical ops must forward both inputs in order to the
 *   same-named `tfc` function.
 * - `logicalNot` must forward its single input.
 * - `where` must pass the condition (input 3) first, then inputs 1 and 2.
 */
describe('logical', () => {
  const input1 = [tfc.scalar(1)];
  const input2 = [tfc.scalar(2)];
  const context = new ExecutionContext({}, {});
  let node: Node;

  // Rebuilt before every spec, so per-spec changes to op, inputNames or
  // params do not carry over.
  beforeEach(() => {
    node = {
      name: 'test',
      op: '',
      category: 'logical',
      inputNames: ['input1', 'input2'],
      inputs: [],
      params: {a: createTensorAttr(0), b: createTensorAttr(1)},
      children: []
    };
  });

  describe('executeOp', () => {
    const binaryOps = [
      'equal', 'notEqual', 'greater', 'greaterEqual', 'less', 'lessEqual',
      'logicalAnd', 'logicalOr'
    ];

    for (const op of binaryOps) {
      it(`should call tfc.${op}`, () => {
        const spy = spyOn(tfc, op as 'equal');
        node.op = op;

        executeOp(node, {input1, input2}, context);

        expect(spy).toHaveBeenCalledWith(input1[0], input2[0]);
      });
    }

    describe('logicalNot', () => {
      it('should call tfc.logicalNot', () => {
        const spy = spyOn(tfc, 'logicalNot');
        node.op = 'logicalNot';

        executeOp(node, {input1}, context);

        expect(spy).toHaveBeenCalledWith(input1[0]);
      });
    });

    describe('where', () => {
      it('should call tfc.where', () => {
        const spy = spyOn(tfc, 'where');
        node.op = 'where';
        node.inputNames = ['input1', 'input2', 'input3'];
        // The condition tensor comes from the third input.
        node.params['condition'] = createTensorAttr(2);
        const input3 = [tfc.scalar(1)];

        executeOp(node, {input1, input2, input3}, context);

        expect(spy).toHaveBeenCalledWith(input3[0], input1[0], input2[0]);
      });
    });
  });
});
開發者ID:oveddan,項目名稱:tfjs-converter,代碼行數:54,代碼來源:logical_executor_test.ts

示例6: it

      /**
       * `linspace` must read start/stop/num as numbers from inputs 0, 1 and 2
       * and forward them to `tfc.linspace`.
       * NOTE(review): `input2`, `node` and `context` come from the enclosing
       * suite (not shown); `input2` is presumably a scalar 1 — confirm.
       */
      it('should call tfc.linspace', () => {
        const linspaceSpy = spyOn(tfc, 'linspace');
        node.op = 'linspace';
        node.params['start'] = createNumberAttrFromIndex(0);
        node.params['stop'] = createNumberAttrFromIndex(1);
        node.params['num'] = createNumberAttrFromIndex(2);
        node.inputNames = ['input', 'input2', 'input3'];

        const startInput = [tfc.scalar(0)];
        const numInput = [tfc.scalar(2)];
        executeOp(node, {input: startInput, input2, input3: numInput}, context);

        expect(linspaceSpy).toHaveBeenCalledWith(0, 1, 2);
      });
開發者ID:oveddan,項目名稱:tfjs-converter,代碼行數:13,代碼來源:creation_executor_test.ts

示例7: dLoss

 // Define losses.
 /**
  * Discriminator loss from its predictions on real (`truePred`) and
  * generated (`generatedPred`) samples.
  *
  * - `lossType === 'LeastSq loss'`:
  *     mean((truePred - 1)^2) + mean(generatedPred^2)
  * - any other `lossType`:
  *     -(mean(0.95 * log(truePred)) + mean(log(1 - generatedPred)))
  *
  * NOTE(review): the 0.95 factor on the real term looks like one-sided label
  * smoothing, and the log branch assumes predictions lie in (0, 1) — confirm.
  *
  * @returns The scalar loss tensor.
  */
 dLoss(truePred: tf.Tensor1D, generatedPred: tf.Tensor1D) {
   if (this.lossType !== 'LeastSq loss') {
     const realLogTerm = truePred.log().mul(tf.scalar(0.95)).mean();
     const fakeLogTerm = tf.sub(tf.scalar(1), generatedPred).log().mean();
     return tf.add(realLogTerm, fakeLogTerm).mul(tf.scalar(-1)) as tf.Scalar;
   }
   const realSqTerm = truePred.sub(tf.scalar(1)).square().mean();
   const fakeSqTerm = generatedPred.square().mean();
   return tf.add(realSqTerm, fakeSqTerm) as tf.Scalar;
 }
開發者ID:deepkapha,項目名稱:dklabs.github.io,代碼行數:14,代碼來源:ganlab_models.ts

示例8: it

      /**
       * A `tensorArrayWrite` node must write its tensor into the TensorArray
       * registered on the context, growing it to size 1.
       * NOTE(review): `input1`, `node` and `context` come from the enclosing
       * suite (not shown).
       */
      it('should write the tensor to tensorArray', async () => {
        const tensorArray =
            new TensorArray('', 'int32', 5, [], true, false, true);
        context.addTensorArray(tensorArray);

        node.op = 'tensorArrayWrite';
        node.params['tensorArrayId'] = createNumberAttrFromIndex(0);
        node.params['index'] = createNumberAttrFromIndex(1);
        node.params['tensor'] = createTensorAttr(2);
        // Input 0 -> array id, input 1 -> write index, input 2 -> tensor.
        node.inputNames = ['input2', 'input3', 'input1'];

        const arrayIdInput = [scalar(tensorArray.id)];
        const writeIndexInput = [scalar(0)];
        await executeOp(
            node, {input1, input2: arrayIdInput, input3: writeIndexInput},
            context);

        expect(tensorArray.size()).toEqual(1);
      });
開發者ID:oveddan,項目名稱:tfjs-converter,代碼行數:15,代碼來源:control_executor_test.ts


注:本文中的@tensorflow/tfjs-core.scalar函數示例由純淨天空整理自Github/MSDocs等開源代碼及文檔管理平台,相關代碼片段篩選自各路編程大神貢獻的開源項目,源碼版權歸原作者所有,傳播和使用請參考對應項目的License;未經允許,請勿轉載。