Skip to content

[Web] Improve large tensor loading in wasm runtime#19771

Open
MakotoUwu wants to merge 1 commit into
apache:mainfrom
MakotoUwu:ole-34/apache-tvm-webgpu-runtime-gemma4
Open

[Web] Improve large tensor loading in wasm runtime#19771
MakotoUwu wants to merge 1 commit into
apache:mainfrom
MakotoUwu:ole-34/apache-tvm-webgpu-runtime-gemma4

Conversation

@MakotoUwu

@MakotoUwu MakotoUwu commented Jun 14, 2026

Copy link
Copy Markdown

This splits out the Web/WebGPU runtime-only portion of #19766 into a smaller PR, following reviewer feedback that the compiler-side changes should be handled separately.

This PR keeps the scope to web/ runtime code:

  • make ArrayDecodeStorage tolerate f32-to-bf16 records whose payload is already native float32-sized, while preserving the existing packed-bf16 expansion path
  • load very large tensor-cache records in outer-dimension chunks; this avoids routing a single multi-hundred-MiB JS-to-wasm decode/copy call through the runtime boundary at once, while keeping the existing full-record path for smaller tensors
  • unpack kTVMFFIShape callback results as JS number arrays so chunked tensor views can pass explicit shape tuples

Concrete motivation: the Gemma 4 E2B MLC artifact has tensor-cache records at 192 MiB, 140 MiB, and 1120 MiB. The 1120 MiB record is model.embed_tokens_per_layer.q_weight with shape [262144, 1120]. Local browser validation on an Apple Silicon Chrome/WebGPU lane reproduced an abort during CPU-side arrayDecodeStorage of that record before GPU copy began. This is a JS/Wasm staging issue for very large tensor-cache records, not a claim that the target WebGPU device cannot allocate the final tensor.

The 128 MiB chunk cap is intentionally conservative and only applies as a per-call staging size for records above that threshold. Smaller records continue to use the existing full-record path.

Local validation:

  • npm run lint from web/
  • npx tsc --noEmit --pretty false from web/
  • git diff --check

I could not run the full local npm run prepwasm && npm run build path on this machine because Emscripten (emcc/emsdk) is not installed. The earlier broad PR had the Apache wasm CI pass before the compiler-side CI failures; this PR is intended to let the wasm job validate the runtime-only split independently.

@gemini-code-assist gemini-code-assist Bot left a comment

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Code Review

This pull request restructures TVM-FFI includes in wasm_runtime.cc to prevent static initialization crashes and updates ArrayDecodeStorage to tolerate uncompressed float32 weights under the 'f32-to-bf16' format. In runtime.ts, it introduces chunked record loading and copying (up to 128MB chunks) to handle large tensors efficiently, and adds support for kTVMFFIShape types. The review feedback suggests optimizing these chunking loops by utilizing the cached makeShapeTuple method on the Instance class rather than invoking the FFI this.ctx.makeShapeTuple repeatedly, which reduces redundant FFI round-trips.

Important

The consumer version of Gemini Code Assist on GitHub is being sunset. Starting June 18, 2026, new organization installations will be blocked, and all code review activity will officially cease on July 17, 2026.
For more details on the timeline and next steps, please review the Help Documentation.

Comment thread web/src/runtime.ts Outdated
Comment on lines +1469 to +1471
this.ctx.makeShapeTuple(
...chunkShape.map((value) => new Scalar(value, "int")),
),

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

We can leverage the cached makeShapeTuple method on the Instance class instead of directly calling the FFI this.ctx.makeShapeTuple on every chunk. This avoids redundant FFI round-trips to create the same shape tuple multiple times across chunks and records, improving performance.

                    this.makeShapeTuple(chunkShape),

Comment thread web/src/runtime.ts Outdated
Comment on lines +1513 to +1515
const chunkShapeTuple = this.ctx.makeShapeTuple(
...chunkShape.map((value) => new Scalar(value, "int")),
);

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

We can leverage the cached makeShapeTuple method on the Instance class instead of directly calling the FFI this.ctx.makeShapeTuple on every chunk. This avoids redundant FFI round-trips to create the same shape tuple multiple times across chunks and records, improving performance.

                  const chunkShapeTuple = this.makeShapeTuple(chunkShape);

@MakotoUwu MakotoUwu force-pushed the ole-34/apache-tvm-webgpu-runtime-gemma4 branch from 012380a to 9c4334d Compare June 14, 2026 20:31
@MakotoUwu MakotoUwu marked this pull request as ready for review June 15, 2026 09:30
Comment thread web/src/runtime.ts
artifactCache: ArtifactCacheTemplate,
signal?: AbortSignal,
) {
const maxChunkBytes = 128 * 1024 * 1024;

Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

how is the number determined, would be good to have a sense of what WebGPU runtime supports, the motivation here is not as clear

@tqchen tqchen requested a review from akaashrp June 15, 2026 11:47
@tqchen

tqchen commented Jun 15, 2026

Copy link
Copy Markdown
Member

@akaashrp would be great if you can help review

@tqchen

tqchen commented Jun 15, 2026

Copy link
Copy Markdown
Member

Please elaborate on "reorder FFI implementation includes before runtime implementation includes in the wasm single-translation-unit build, avoiding static initialization ordering issues during module startup"

Since static init ordering should not impact the startup, if there is an impact we need to know details

@tqchen

tqchen commented Jun 15, 2026

Copy link
Copy Markdown
Member

would be great to get elaboration on "load large tensor-cache records in chunks to avoid oversized JS-to-wasm decode/copy calls", is there a specific scenario that motivate this, generally, we would favor simplicity as long as such case is covered by webgpu runtime

@MakotoUwu MakotoUwu force-pushed the ole-34/apache-tvm-webgpu-runtime-gemma4 branch from 9c4334d to 1fea1f5 Compare June 15, 2026 12:41
@MakotoUwu

Copy link
Copy Markdown
Author

Thanks for the review. I updated the PR to narrow the scope and make the motivation more explicit:

  • Removed the FFI include reordering change from this PR. I agree the previous static-init wording was too strong without a concrete wasm startup trace in this split PR, so this now leaves the existing include order unchanged.
  • Kept the large tensor-cache record chunking and clarified the motivation in the PR description and code comment. The concrete case is a single tensor-cache record larger than the conservative WebGPU staging/view size used by the web runtime. The 128 MiB cap follows the existing maxStorageBufferBindingSize fallback in web/src/webgpu.ts; it is only a per-decode/per-copy view cap for loading large records, while smaller records still use the existing full-record path.

Local checks still pass after the update: npm run lint, npx tsc --noEmit --pretty false, and git diff --check.

@tqchen

tqchen commented Jun 15, 2026

Copy link
Copy Markdown
Member

Thanks, is there a real usecase for the chunking? .eg. if maxStorageBufferBindingSize limit is set really to 128MB, seems that means we canot allocate tensor larger than that anyway as a result we won't have copy in such shape. And in cases where maxStorageBufferBindingSize is larger, we don't need chunking.

@MakotoUwu MakotoUwu force-pushed the ole-34/apache-tvm-webgpu-runtime-gemma4 branch from 1fea1f5 to a83d3a4 Compare June 15, 2026 13:50
@MakotoUwu

Copy link
Copy Markdown
Author

You are right that maxStorageBufferBindingSize was not the right motivation. I updated the code comment and PR description to avoid that claim.

The concrete use case is a very large tensor-cache record that is valid for the target device but fragile as one JS/Wasm staging call. The Gemma 4 E2B artifact has three records above 128 MiB:

  • model.embed_tokens.q_weight: 192 MiB, shape [262144, 192]
  • model.embed_tokens_per_layer.q_scale: 140 MiB, shape [262144, 280]
  • model.embed_tokens_per_layer.q_weight: 1120 MiB, shape [262144, 1120]

In the local Chrome/WebGPU validation lane, the failure was deterministic on the 1120 MiB record. Instrumentation reached arrayDecodeStorage:start and aborted before arrayDecodeStorage:done, before GPU copy. So the chunking is intended to avoid one multi-hundred-MiB JS-to-wasm byte-array decode/copy call, not to work around a WebGPU allocation/binding limit. Smaller records still use the original full-record path.

@tqchen

tqchen commented Jun 15, 2026

Copy link
Copy Markdown
Member

/gemini review

@gemini-code-assist gemini-code-assist Bot left a comment

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Code Review

This pull request introduces chunked loading for large tensor-cache records to avoid massive single JS-to-wasm byte-array calls, and updates the WASM runtime to handle uncompressed float32 payloads labeled as 'f32-to-bf16'. Feedback on these changes highlights several critical issues: the custom f32-to-bf16 decoding path in C++ ignores cpu_arr->byte_offset which causes data corruption during chunked loading; the storageBytes helper in TypeScript fails for "bool" data types; a potential resource leak exists in the GPU chunked copy path if an exception occurs during view creation; and a defensive null check is needed for shapeObjPtr before loading memory.

Important

The consumer version of Gemini Code Assist on GitHub is being sunset. Starting June 18, 2026, new organization installations will be blocked, and all code review activity will officially cease on July 17, 2026.
For more details on the timeline and next steps, please review the Help Documentation.

Comment thread web/emcc/wasm_runtime.cc Outdated
// the "f32-to-bf16" tag without performing the bf16 truncation.
if (size == byte_size / 2) {
const uint16_t* bf16 = reinterpret_cast<const uint16_t*>(byte_data);
uint32_t* data = static_cast<uint32_t*>(cpu_arr->data);

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

critical

The custom f32-to-bf16 decoding path directly uses cpu_arr->data as the destination pointer, completely ignoring cpu_arr->byte_offset. When loading chunked records, chunkView is a tensor view with a non-zero byte_offset. Ignoring this offset causes all decoded chunks to be written to the very beginning of the underlying tensor, leading to silent data corruption. To fix this, the destination pointer must be offset by cpu_arr->byte_offset.

Suggested change
uint32_t* data = static_cast<uint32_t*>(cpu_arr->data);
uint32_t* data = reinterpret_cast<uint32_t*>(static_cast<char*>(cpu_arr->data) + cpu_arr->byte_offset);

Comment thread web/src/runtime.ts
Comment on lines +2362 to +2364
case TypeIndex.kTVMFFIShape: {
const shapeObjPtr = this.memory.loadPointer(valuePtr);
if (callbackArg) {

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

Add a defensive null check for shapeObjPtr before performing pointer arithmetic and loading memory. If shapeObjPtr is null (0), accessing it will cause an out-of-bounds memory access or return garbage.

Suggested change
case TypeIndex.kTVMFFIShape: {
const shapeObjPtr = this.memory.loadPointer(valuePtr);
if (callbackArg) {
case TypeIndex.kTVMFFIShape: {
const shapeObjPtr = this.memory.loadPointer(valuePtr);
if (shapeObjPtr === 0) {
return null;
}
if (callbackArg) {

Comment thread web/src/runtime.ts
Comment on lines +1331 to +1339
const storageBytes = (dtype: string) => {
const match = dtype.match(/(\d+)(?:x(\d+))?$/);
if (match === null) {
throw new Error("Cannot determine storage width of dtype " + dtype);
}
const bits = Number(match[1]);
const lanes = match[2] === undefined ? 1 : Number(match[2]);
return (bits * lanes + 7) >> 3;
};

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

The storageBytes helper uses a regular expression to extract the bit width from the dtype string. However, if the dtype is "bool", the regex will fail to match any digits and throw an error. In TVM, boolean tensors are commonly represented with the "bool" dtype (which is 1 byte / 8 bits in DLPack). Adding an explicit fallback for "bool" makes the helper more robust.

Suggested change
const storageBytes = (dtype: string) => {
const match = dtype.match(/(\d+)(?:x(\d+))?$/);
if (match === null) {
throw new Error("Cannot determine storage width of dtype " + dtype);
}
const bits = Number(match[1]);
const lanes = match[2] === undefined ? 1 : Number(match[2]);
return (bits * lanes + 7) >> 3;
};
const storageBytes = (dtype: string) => {
if (dtype === "bool") {
return 1;
}
const match = dtype.match(/(\d+)(?:x(\d+))?$/);
if (match === null) {
throw new Error("Cannot determine storage width of dtype " + dtype);
}
const bits = Number(match[1]);
const lanes = match[2] === undefined ? 1 : Number(match[2]);
return (bits * lanes + 7) >> 3;
};

Comment thread web/src/runtime.ts
Comment on lines +1512 to +1531
const [cpuView, gpuView] = this.withNewScope(() => {
return [
this.detachFromCurrentScope(
this.ctx.tensorCreateView(
cpu_arr,
chunkShapeTuple,
rec.dtype,
new Scalar(targetByteOffset, "int"),
)
),
this.detachFromCurrentScope(
this.ctx.tensorCreateView(
gpu_arr,
chunkShapeTuple,
rec.dtype,
new Scalar(targetByteOffset, "int"),
)
),
];
});

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

In the GPU copy path, cpuView and gpuView are created and immediately detached from the current scope inside the withNewScope block. If the second tensorCreateView (for gpuView) throws an exception, the first view (cpuView) will have already been detached from the scope's auto-disposal list, but the assignment to [cpuView, gpuView] will never complete. This results in a resource leak of cpuView because it is neither auto-disposed by the scope nor disposed in the finally block. To prevent this, detach both views only after both have been successfully created.

                const [cpuView, gpuView] = this.withNewScope(() => {
                  const cView = this.ctx.tensorCreateView(
                    cpu_arr,
                    chunkShapeTuple,
                    rec.dtype,
                    new Scalar(targetByteOffset, "int"),
                  );
                  const gView = this.ctx.tensorCreateView(
                    gpu_arr,
                    chunkShapeTuple,
                    rec.dtype,
                    new Scalar(targetByteOffset, "int"),
                  );
                  return [
                    this.detachFromCurrentScope(cView),
                    this.detachFromCurrentScope(gView),
                  ];
                });

@MakotoUwu MakotoUwu force-pushed the ole-34/apache-tvm-webgpu-runtime-gemma4 branch from a83d3a4 to f1e69f2 Compare June 15, 2026 15:16
@MakotoUwu

Copy link
Copy Markdown
Author

Addressed the Gemini follow-up comments in the latest commit:

  • ArrayDecodeStorage now honors cpu_arr->byte_offset in the packed bf16 expansion path, so chunk views write to the correct offset.
  • storageBytes now handles bool explicitly.
  • The GPU chunk copy path now creates both views before detaching them from the scope, avoiding a leak if the second view creation throws.
  • kTVMFFIShape callback decoding now guards against a null shape object pointer.

Local checks still pass: npm run lint, npx tsc --noEmit --pretty false, and git diff --check.

@MakotoUwu MakotoUwu force-pushed the ole-34/apache-tvm-webgpu-runtime-gemma4 branch from f1e69f2 to 1d2a3f1 Compare June 16, 2026 07:53
@MakotoUwu

Copy link
Copy Markdown
Author

Pushed a small follow-up on the new head 1d2a3f131:

  • Made withNewScope close scopes in a finally block, so failed chunk-view creation cannot leave a live scope/view behind.
  • Tightened the f32-to-bf16 decode branch to require the exact bf16 payload size before expansion, so malformed payload sizes fall through to the normal size validation path.

Local validation passes: git diff --check, npm run lint, and npx tsc --noEmit --pretty false. Remote CI is in progress on the new head.

@MakotoUwu MakotoUwu force-pushed the ole-34/apache-tvm-webgpu-runtime-gemma4 branch from 1d2a3f1 to a2359c1 Compare June 16, 2026 08:46
@MakotoUwu

Copy link
Copy Markdown
Author

Pushed a small follow-up on head a2359c159 after rerunning the downstream browser smoke with a WebLLM bundle rebuilt from this PR branch.

The smoke first exposed one more scope issue in the chunked path: this.makeShapeTuple(chunkShape) can create a cached TVM object on a cache miss, so calling it before withNewScope(...) failed with Must call beginScope to use functions that returns TVM objects. I moved both CPU and GPU chunk-path makeShapeTuple calls inside the existing withNewScope callbacks.

Local TVM web checks pass on the pushed head:

  • git diff --check
  • npm run lint
  • npx tsc --noEmit --pretty false

Downstream Chrome/WebGPU smoke also passes with a WebLLM JS bundle rebuilt against the pushed TVM source:

  • TVM PR head: a2359c159cb3fb94766335beea33f281fb9a7bba
  • Rebuilt WebLLM bundle SHA256: c357f887fde2408078e715abe4b7a6b54c719f44d87e4e9ba2cefe58e101e8a5
  • TVM wasm_runtime.bc SHA256: 8dbe21eeb30abda125e1b71464d16c8cf1f1c877be08c3a1a7275f0a27701a23
  • Model wasm used for the smoke: 70d7295dc91b622b79ceeada2c64b4c20787832631c04e3714de95db04515dfc
  • Browser: Chrome 149.0.7827.104, Apple M3 WebGPU/Metal

Prompt smoke results:

  • T1 Hi: load 11.6s, generated Hi there! How can I help you, finish length - PASS
  • T2 France one-word answer: load 5.3s, generated Paris, finish stop - PASS
  • T3 haiku: load 5.7s, generated a non-empty haiku-like response, generation 2.8s, finish stop - PASS

Boundary note: this validates the current PR-head TypeScript/WebLLM loading path with the latest available Apache model wasm. It is not a fresh model-wasm rebuild from a2359c159.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

2 participants