Skip to content

Commit

Permalink
[js/web] use ApiTensor insteadof onnxjs Tensor in TensorResultValidat…
Browse files Browse the repository at this point in the history
…or (#19358)

### Description
use ApiTensor insteadof onnxjs Tensor in TensorResultValidator. Make
test runner less depend on onnxjs classes.
  • Loading branch information
fs-eire committed Feb 21, 2024
1 parent 3fe2c13 commit 70567a4
Show file tree
Hide file tree
Showing 2 changed files with 13 additions and 17 deletions.
26 changes: 10 additions & 16 deletions js/web/test/test-runner.ts
Original file line number Diff line number Diff line change
Expand Up @@ -39,10 +39,6 @@ const ONNXRUNTIME_THRESHOLD_RELATIVE_ERROR = 1.00001;
*/
const now = (typeof performance !== 'undefined' && performance.now) ? () => performance.now() : Date.now;

function toInternalTensor(tensor: ort.Tensor): Tensor {
return new Tensor(
tensor.dims, tensor.type as Tensor.DataType, undefined, undefined, tensor.data as Tensor.NumberType);
}
function fromInternalTensor(tensor: Tensor): ort.Tensor {
return new ort.Tensor(tensor.type, tensor.data as ort.Tensor.DataType, tensor.dims);
}
Expand Down Expand Up @@ -330,6 +326,10 @@ export class TensorResultValidator {
}

checkTensorResult(actual: Tensor[], expected: Tensor[]): void {
this.checkApiTensorResult(actual.map(fromInternalTensor), expected.map(fromInternalTensor));
}

checkApiTensorResult(actual: ort.Tensor[], expected: ort.Tensor[]): void {
// check output size
expect(actual.length, 'size of output tensors').to.equal(expected.length);

Expand All @@ -347,10 +347,6 @@ export class TensorResultValidator {
}
}

checkApiTensorResult(actual: ort.Tensor[], expected: ort.Tensor[]): void {
this.checkTensorResult(actual.map(toInternalTensor), expected.map(toInternalTensor));
}

checkNamedTensorResult(actual: Record<string, ort.Tensor>, expected: Test.NamedTensor[]): void {
// check output size
expect(Object.getOwnPropertyNames(actual).length, 'size of output tensors').to.equal(expected.length);
Expand All @@ -364,7 +360,7 @@ export class TensorResultValidator {
}

// This function check whether 2 tensors should be considered as 'match' or not
areEqual(actual: Tensor, expected: Tensor): boolean {
areEqual(actual: ort.Tensor, expected: ort.Tensor): boolean {
if (!actual || !expected) {
return false;
}
Expand Down Expand Up @@ -392,13 +388,13 @@ export class TensorResultValidator {

switch (actualType) {
case 'string':
return this.strictEqual(actual.stringData, expected.stringData);
return this.strictEqual(actual.data, expected.data);

case 'float32':
case 'float64':
return this.floatEqual(
actual.numberData as number[] | Float32Array | Float64Array,
expected.numberData as number[] | Float32Array | Float64Array);
actual.data as number[] | Float32Array | Float64Array,
expected.data as number[] | Float32Array | Float64Array);

case 'uint8':
case 'int8':
Expand All @@ -409,10 +405,8 @@ export class TensorResultValidator {
case 'int64':
case 'bool':
return TensorResultValidator.integerEqual(
actual.numberData as number[] | Uint8Array | Int8Array | Uint16Array | Int16Array | Uint32Array |
Int32Array,
expected.numberData as number[] | Uint8Array | Int8Array | Uint16Array | Int16Array | Uint32Array |
Int32Array);
actual.data as number[] | Uint8Array | Int8Array | Uint16Array | Int16Array | Uint32Array | Int32Array,
expected.data as number[] | Uint8Array | Int8Array | Uint16Array | Int16Array | Uint32Array | Int32Array);

default:
throw new Error('type not implemented or not supported');
Expand Down
4 changes: 3 additions & 1 deletion js/web/test/unittests/backends/webgl/test-conv-new.ts
Original file line number Diff line number Diff line change
Expand Up @@ -893,7 +893,9 @@ describe('New Conv tests', () => {
const expected = cpuConv(
inputTensor, kernelTensor, biasTensor, testData.autoPad, testData.dilations, testData.pads,
testData.strides);
if (!validator.areEqual(actual, expected)) {
try {
validator.checkTensorResult([actual], [expected]);
} catch {
console.log(actual.dims, `[${actual.numberData.slice(0, 20).join(',')},...]`);
console.log(expected.dims, `[${expected.numberData.slice(0, 20).join(',')},...]`);
throw new Error('Expected and Actual did not match');
Expand Down

0 comments on commit 70567a4

Please sign in to comment.