Skip to content

Commit

Permalink
[js/webgpu] allow uint8 tensors for webgpu (#19545)
Browse files Browse the repository at this point in the history
### Description
allow uint8 tensors for webgpu
  • Loading branch information
fs-eire committed Feb 17, 2024
1 parent 4874a41 commit 06269a3
Show file tree
Hide file tree
Showing 3 changed files with 4 additions and 3 deletions.
2 changes: 1 addition & 1 deletion js/common/lib/tensor-impl.ts
Original file line number Diff line number Diff line change
Expand Up @@ -103,7 +103,7 @@ export class Tensor implements TensorInterface {
}
case 'gpu-buffer': {
if ((type !== 'float32' && type !== 'float16' && type !== 'int32' && type !== 'int64' && type !== 'uint32' &&
type !== 'bool')) {
type !== 'uint8' && type !== 'bool')) {
throw new TypeError(`unsupported type "${type}" to create tensor from gpu buffer`);
}
this.gpuBufferData = arg0.gpuBuffer;
Expand Down
2 changes: 1 addition & 1 deletion js/common/lib/tensor.ts
Original file line number Diff line number Diff line change
Expand Up @@ -135,7 +135,7 @@ export declare namespace Tensor {
/**
* supported data types for constructing a tensor from a WebGPU buffer
*/
export type GpuBufferDataTypes = 'float32'|'float16'|'int32'|'int64'|'uint32'|'bool';
export type GpuBufferDataTypes = 'float32'|'float16'|'int32'|'int64'|'uint32'|'uint8'|'bool';

/**
* represent where the tensor data is stored
Expand Down
3 changes: 2 additions & 1 deletion js/web/lib/wasm/wasm-common.ts
Original file line number Diff line number Diff line change
Expand Up @@ -169,7 +169,8 @@ export const logLevelStringToEnum = (logLevel?: 'verbose'|'info'|'warning'|'erro
* Check whether the given tensor type is supported by GPU buffer
*/
export const isGpuBufferSupportedType = (type: Tensor.Type): type is Tensor.GpuBufferDataTypes => type === 'float32' ||
type === 'int32' || type === 'int64' || type === 'bool' || type === 'float16' || type === 'uint32';
type === 'float16' || type === 'int32' || type === 'int64' || type === 'uint32' || type === 'uint8' ||
type === 'bool';

/**
* Map string data location to integer value
Expand Down

0 comments on commit 06269a3

Please sign in to comment.