Skip to content

Commit eae06b5

Browse files
Add Cos to kernel registry (cpu backend) (#3907)
1 parent e39ba95 commit eae06b5

File tree

4 files changed

+60
-12
lines changed

4 files changed

+60
-12
lines changed

tfjs-backend-cpu/.vscode/launch.json

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,18 @@
1+
{
2+
// Use IntelliSense to learn about possible attributes.
3+
// Hover to view descriptions of existing attributes.
4+
// For more information, visit: https://go.microsoft.com/fwlink/?linkid=830387
5+
"version": "0.2.0",
6+
"configurations": [
7+
{
8+
"name": "Run tests",
9+
"type": "node",
10+
"request": "launch",
11+
"runtimeExecutable": "yarn",
12+
"runtimeArgs": ["test"],
13+
"sourceMaps": true,
14+
"cwd": "${workspaceRoot}",
15+
"protocol": "inspector",
16+
}
17+
]
18+
}

tfjs-backend-cpu/src/backend_cpu.ts

Lines changed: 0 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -1279,17 +1279,6 @@ export class MathBackendCPU extends KernelBackend {
12791279
return this.makeOutput(resultValues, x.shape, 'float32');
12801280
}
12811281

1282-
cos<T extends Tensor>(x: T): T {
1283-
assertNotComplex(x, 'cos');
1284-
1285-
const resultValues = new Float32Array(x.size);
1286-
const values = this.readSync(x.dataId) as TypedArray;
1287-
for (let i = 0; i < values.length; ++i) {
1288-
resultValues[i] = Math.cos(values[i]);
1289-
}
1290-
return this.makeOutput(resultValues, x.shape, 'float32');
1291-
}
1292-
12931282
tan<T extends Tensor>(x: T): T {
12941283
assertNotComplex(x, 'tan');
12951284

tfjs-backend-cpu/src/kernels/Cos.ts

Lines changed: 40 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,40 @@
1+
/**
2+
* @license
3+
* Copyright 2020 Google LLC. All Rights Reserved.
4+
* Licensed under the Apache License, Version 2.0 (the "License");
5+
* you may not use this file except in compliance with the License.
6+
* You may obtain a copy of the License at
7+
*
8+
* http://www.apache.org/licenses/LICENSE-2.0
9+
*
10+
* Unless required by applicable law or agreed to in writing, software
11+
* distributed under the License is distributed on an "AS IS" BASIS,
12+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
* See the License for the specific language governing permissions and
14+
* limitations under the License.
15+
* =============================================================================
16+
*/
17+
18+
import {Cos, CosInputs, KernelConfig, TypedArray, util} from '@tensorflow/tfjs-core';
19+
20+
import {MathBackendCPU} from '../backend_cpu';
21+
import {assertNotComplex} from '../cpu_util';
22+
23+
export const cosConfig: KernelConfig = {
24+
kernelName: Cos,
25+
backendName: 'cpu',
26+
kernelFunc: ({inputs, backend}) => {
27+
const {x} = inputs as CosInputs;
28+
const cpuBackend = backend as MathBackendCPU;
29+
assertNotComplex(x, 'cos');
30+
31+
const values = cpuBackend.data.get(x.dataId).values as TypedArray;
32+
const xSize = util.sizeFromShape(x.shape);
33+
const newValues = new Float32Array(xSize);
34+
for (let i = 0; i < xSize; ++i) {
35+
newValues[i] = Math.cos(values[i]);
36+
}
37+
const dataId = cpuBackend.write(newValues, x.shape, x.dtype);
38+
return {dataId, shape: x.shape, dtype: x.dtype};
39+
}
40+
};

tfjs-backend-cpu/src/register_all_kernels.ts

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@
1919
// the contents of this file and import only the kernels that are needed.
2020
import {KernelConfig, registerKernel} from '@tensorflow/tfjs-core';
2121

22+
import {cosConfig} from './kernels/Cos';
2223
import {dilation2dConfig} from './kernels/Dilation2D';
2324
import {dilation2dBackpropFilterConfig} from './kernels/Dilation2DBackpropFilter';
2425
import {dilation2dBackpropInputConfig} from './kernels/Dilation2DBackpropInput';
@@ -39,7 +40,7 @@ import {transposeConfig} from './kernels/Transpose';
3940

4041
// List all kernel configs here
4142
const kernelConfigs: KernelConfig[] = [
42-
dilation2dConfig, dilation2dBackpropInputConfig,
43+
cosConfig, dilation2dConfig, dilation2dBackpropInputConfig,
4344
dilation2dBackpropFilterConfig, divConfig, flipLeftRightConfig,
4445
identityConfig, maxPoolWithArgmaxConfig, maxConfig, nonMaxSuppressionV4Config,
4546
nonMaxSuppressionV5Config, padV2Config, reshapeConfig, rotateWithOffsetConfig,

0 commit comments

Comments
 (0)