Skip to content

Commit 32bdc3e

Browse files
authored
Merge pull request #1031 from cloudflare/webgpu_do
webgpu: ensure api restrictions
2 parents 9e4a9ca + 57b6e4a commit 32bdc3e

12 files changed

+133
-20
lines changed

src/workerd/api/global-scope.c++

+17-1
Original file line numberDiff line numberDiff line change
@@ -786,4 +786,20 @@ double Performance::now() {
786786
return dateNow();
787787
}
788788

789-
} // namespace workerd::api
789+
#ifdef WORKERD_EXPERIMENTAL_ENABLE_WEBGPU
790+
jsg::Ref<api::gpu::GPU> Navigator::getGPU(CompatibilityFlags::Reader flags) {
791+
// is this a durable object?
792+
KJ_IF_MAYBE (actor, IoContext::current().getActor()) {
793+
JSG_REQUIRE(actor->getPersistent() != nullptr, TypeError,
794+
"webgpu api is only available in Durable Objects (no storage)");
795+
} else {
796+
JSG_FAIL_REQUIRE(TypeError, "webgpu api is only available in Durable Objects");
797+
};
798+
799+
JSG_REQUIRE(flags.getWebgpu(), TypeError, "webgpu needs the webgpu compatibility flag set");
800+
801+
return jsg::alloc<api::gpu::GPU>();
802+
}
803+
#endif
804+
805+
} // namespace workerd::api

src/workerd/api/global-scope.h

+1-1
Original file line numberDiff line numberDiff line change
@@ -44,7 +44,7 @@ class Navigator: public jsg::Object {
4444
public:
4545
kj::StringPtr getUserAgent() { return "Cloudflare-Workers"_kj; }
4646
#ifdef WORKERD_EXPERIMENTAL_ENABLE_WEBGPU
47-
jsg::Ref<api::gpu::GPU> getGPU() { return jsg::alloc<api::gpu::GPU>(); }
47+
jsg::Ref<api::gpu::GPU> getGPU(CompatibilityFlags::Reader flags);
4848
#endif
4949

5050
JSG_RESOURCE_TYPE(Navigator) {

src/workerd/api/gpu/gpu.c++

+1
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
// https://opensource.org/licenses/Apache-2.0
44

55
#include "gpu.h"
6+
#include "workerd/jsg/exception.h"
67
#include <dawn/dawn_proc.h>
78

89
namespace workerd::api::gpu {

src/workerd/api/gpu/webgpu-buffer-test.gpu-wd-test

+8-1
Original file line numberDiff line numberDiff line change
@@ -7,8 +7,15 @@ const unitTests :Workerd.Config = (
77
modules = [
88
(name = "worker", esModule = embed "webgpu-buffer-test.js")
99
],
10+
durableObjectNamespaces = [
11+
(className = "DurableObjectExample", uniqueKey = "210bd0cbd803ef7883a1ee9d86cce06e"),
12+
],
13+
durableObjectStorage = (inMemory = void),
14+
bindings = [
15+
(name = "ns", durableObjectNamespace = "DurableObjectExample"),
16+
],
1017
compatibilityDate = "2023-01-15",
11-
compatibilityFlags = ["experimental", "nodejs_compat"],
18+
compatibilityFlags = ["experimental", "nodejs_compat", "webgpu"],
1219
)
1320
),
1421
],

src/workerd/api/gpu/webgpu-buffer-test.js

+20-4
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,14 @@
1-
import { deepEqual, ok } from "node:assert";
1+
import { deepEqual, ok, equal } from "node:assert";
22

33
// run manually for now
44
// bazel run --//src/workerd/io:enable_experimental_webgpu //src/workerd/server:workerd -- test `realpath ./src/workerd/api/gpu/webgpu-buffer-test.gpu-wd-test` --verbose --experimental
55

6-
export const read_sync_stack = {
7-
async test(ctrl, env, ctx) {
6+
export class DurableObjectExample {
7+
constructor(state) {
8+
this.state = state;
9+
}
10+
11+
async fetch() {
812
ok(navigator.gpu);
913
const adapter = await navigator.gpu.requestAdapter();
1014
ok(adapter);
@@ -56,6 +60,18 @@ export const read_sync_stack = {
5660
const copyArrayBuffer = gpuReadBuffer.getMappedRange();
5761
ok(copyArrayBuffer);
5862

59-
deepEqual(new Uint8Array(copyArrayBuffer), new Uint8Array([ 0, 1, 2, 3 ]));
63+
deepEqual(new Uint8Array(copyArrayBuffer), new Uint8Array([0, 1, 2, 3]));
64+
65+
return new Response("OK");
66+
}
67+
}
68+
69+
export const buffer_mapping = {
70+
async test(ctrl, env, ctx) {
71+
let id = env.ns.idFromName("A");
72+
let obj = env.ns.get(id);
73+
let res = await obj.fetch("http://foo/test");
74+
let text = await res.text();
75+
equal(text, "OK");
6076
},
6177
};

src/workerd/api/gpu/webgpu-compute-test.gpu-wd-test

+8-1
Original file line numberDiff line numberDiff line change
@@ -7,8 +7,15 @@ const unitTests :Workerd.Config = (
77
modules = [
88
(name = "worker", esModule = embed "webgpu-compute-test.js")
99
],
10+
durableObjectNamespaces = [
11+
(className = "DurableObjectExample", uniqueKey = "210bd0cbd803ef7883a1ee9d86cce06e"),
12+
],
13+
durableObjectStorage = (inMemory = void),
14+
bindings = [
15+
(name = "ns", durableObjectNamespace = "DurableObjectExample"),
16+
],
1017
compatibilityDate = "2023-01-15",
11-
compatibilityFlags = ["experimental", "nodejs_compat"],
18+
compatibilityFlags = ["experimental", "nodejs_compat", "webgpu"],
1219
)
1320
),
1421
],

src/workerd/api/gpu/webgpu-compute-test.js

+19-3
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,14 @@
1-
import { deepEqual, ok } from "node:assert";
1+
import { deepEqual, ok, equal } from "node:assert";
22

33
// run manually for now
44
// bazel run --//src/workerd/io:enable_experimental_webgpu //src/workerd/server:workerd -- test `realpath ./src/workerd/api/gpu/webgpu-compute-test.gpu-wd-test` --verbose --experimental
55

6-
export const read_sync_stack = {
7-
async test(ctrl, env, ctx) {
6+
export class DurableObjectExample {
7+
constructor(state) {
8+
this.state = state;
9+
}
10+
11+
async fetch() {
812
ok(navigator.gpu);
913
if (!("gpu" in navigator)) {
1014
console.log(
@@ -271,5 +275,17 @@ export const read_sync_stack = {
271275
new Float32Array(arrayBuffer),
272276
new Float32Array([2, 2, 50, 60, 114, 140])
273277
);
278+
279+
return new Response("OK");
280+
}
281+
}
282+
283+
export const compute_shader = {
284+
async test(ctrl, env, ctx) {
285+
let id = env.ns.idFromName("A");
286+
let obj = env.ns.get(id);
287+
let res = await obj.fetch("http://foo/test");
288+
let text = await res.text();
289+
equal(text, "OK");
274290
},
275291
};

src/workerd/api/gpu/webgpu-errors-test.gpu-wd-test

+8-1
Original file line numberDiff line numberDiff line change
@@ -7,8 +7,15 @@ const unitTests :Workerd.Config = (
77
modules = [
88
(name = "worker", esModule = embed "webgpu-errors-test.js")
99
],
10+
durableObjectNamespaces = [
11+
(className = "DurableObjectExample", uniqueKey = "210bd0cbd803ef7883a1ee9d86cce06e"),
12+
],
13+
durableObjectStorage = (inMemory = void),
14+
bindings = [
15+
(name = "ns", durableObjectNamespace = "DurableObjectExample"),
16+
],
1017
compatibilityDate = "2023-01-15",
11-
compatibilityFlags = ["experimental", "nodejs_compat"],
18+
compatibilityFlags = ["experimental", "nodejs_compat", "webgpu"],
1219
)
1320
),
1421
],

src/workerd/api/gpu/webgpu-errors-test.js

+19-3
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,14 @@
1-
import { ok } from "node:assert";
1+
import { ok, equal } from "node:assert";
22

33
// run manually for now
44
// bazel run --//src/workerd/io:enable_experimental_webgpu //src/workerd/server:workerd -- test `realpath ./src/workerd/api/gpu/webgpu-errors-test.gpu-wd-test` --verbose --experimental
55

6-
export const read_sync_stack = {
7-
async test(ctrl, env, ctx) {
6+
export class DurableObjectExample {
7+
constructor(state) {
8+
this.state = state;
9+
}
10+
11+
async fetch() {
812
ok(navigator.gpu);
913

1014
const adapter = await navigator.gpu.requestAdapter();
@@ -80,5 +84,17 @@ export const read_sync_stack = {
8084

8185
// ensure callback with error was indeed called
8286
ok(callbackCalled);
87+
88+
return new Response("OK");
89+
}
90+
}
91+
92+
export const error_handling = {
93+
async test(ctrl, env, ctx) {
94+
let id = env.ns.idFromName("A");
95+
let obj = env.ns.get(id);
96+
let res = await obj.fetch("http://foo/test");
97+
let text = await res.text();
98+
equal(text, "OK");
8399
},
84100
};

src/workerd/api/gpu/webgpu-write-test.gpu-wd-test

+8-1
Original file line numberDiff line numberDiff line change
@@ -7,8 +7,15 @@ const unitTests :Workerd.Config = (
77
modules = [
88
(name = "worker", esModule = embed "webgpu-write-test.js")
99
],
10+
durableObjectNamespaces = [
11+
(className = "DurableObjectExample", uniqueKey = "210bd0cbd803ef7883a1ee9d86cce06e"),
12+
],
13+
durableObjectStorage = (inMemory = void),
14+
bindings = [
15+
(name = "ns", durableObjectNamespace = "DurableObjectExample"),
16+
],
1017
compatibilityDate = "2023-01-15",
11-
compatibilityFlags = ["experimental", "nodejs_compat"],
18+
compatibilityFlags = ["experimental", "nodejs_compat", "webgpu"],
1219
)
1320
),
1421
],

src/workerd/api/gpu/webgpu-write-test.js

+20-4
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,14 @@
1-
import { ok, deepEqual } from "node:assert";
1+
import { ok, deepEqual, equal } from "node:assert";
22

33
// run manually for now
44
// bazel run --//src/workerd/io:enable_experimental_webgpu //src/workerd/server:workerd -- test `realpath ./src/workerd/api/gpu/webgpu-write-test.gpu-wd-test` --verbose --experimental
55

6-
export const read_sync_stack = {
7-
async test(ctrl, env, ctx) {
6+
export class DurableObjectExample {
7+
constructor(state) {
8+
this.state = state;
9+
}
10+
11+
async fetch() {
812
ok(navigator.gpu);
913
const adapter = await navigator.gpu.requestAdapter();
1014
ok(adapter);
@@ -23,6 +27,18 @@ export const read_sync_stack = {
2327

2428
// Write bytes to buffer.
2529
new Uint8Array(arrayBuffer).set([0, 1, 2, 3]);
26-
deepEqual(new Uint8Array(arrayBuffer), new Uint8Array([ 0, 1, 2, 3 ]));
30+
deepEqual(new Uint8Array(arrayBuffer), new Uint8Array([0, 1, 2, 3]));
31+
32+
return new Response("OK");
33+
}
34+
}
35+
36+
export const buffer_write = {
37+
async test(ctrl, env, ctx) {
38+
let id = env.ns.idFromName("A");
39+
let obj = env.ns.get(id);
40+
let res = await obj.fetch("http://foo/test");
41+
let text = await res.text();
42+
equal(text, "OK");
2743
},
2844
};

src/workerd/io/compatibility-date.capnp

+4
Original file line numberDiff line numberDiff line change
@@ -340,4 +340,8 @@ struct CompatibilityFlags @0x8f8c1b68151b6cef {
340340
$compatEnableFlag("rtti_api")
341341
$experimental;
342342
# Enables the `workerd:rtti` module for querying runtime-type-information from JavaScript.
343+
344+
webgpu @35 :Bool
345+
$compatEnableFlag("webgpu")
346+
$experimental;
343347
}

0 commit comments

Comments
 (0)