diff --git a/README.md b/README.md index 30b33c2..c30abc0 100644 --- a/README.md +++ b/README.md @@ -134,6 +134,28 @@ const router = new Router(openrpcDocument, methodHandlerMapping); const router = new Router(openrpcDocument, { mockMode: true }); ``` +###### router plugins (`x-implementedBy` + client context) + +```typescript +import { Router, plugins } from "@open-rpc/server-js"; + +const router = new Router(openrpcDocument, methodHandlerMapping, { + plugins: [plugins.implementedByPlugin()], +}); +``` + +You can also pass `routerOptions` through `Server`: + +```typescript +const server = new Server({ + openrpcDocument, + methodMapping, + routerOptions: { + plugins: [plugins.implementedByPlugin()], + }, +}); +``` + ##### Creating Transports ###### IPC @@ -190,6 +212,47 @@ const wsFromHttpsTransport = new WebSocketServerTransport(webSocketFromHttpsOpti const wsTransport = new WebSocketServerTransport(webSocketOptions); // Accepts http transport as well. ``` +###### Bidirectional `x-implementedBy` example + +This repository includes a minimal server/client pair that demonstrates: +- server-only methods (`"x-implementedBy": ["server"]`) +- client-only methods (`"x-implementedBy": ["client"]`) +- methods implemented by both (`"x-implementedBy": ["server", "client"]`) + +Run in separate terminals: + +```bash +npm run example:bidirectional:server +``` + +```bash +npm run example:bidirectional:client +``` + +Example sources: +- `src/examples/bidirectional/openrpc.ts` +- `src/examples/bidirectional/server.ts` +- `src/examples/bidirectional/client.ts` + +###### `outboundHandler` example + +This repository also includes a minimal `outboundHandler` example where the server +proactively calls connected client methods on an interval. + +Run in separate terminals: + +```bash +npm run example:outbound:server +``` + +```bash +npm run example:outbound:client +``` + +Example sources: +- `src/examples/bidirectional/server-outbound.ts` +- `src/examples/bidirectional/client-outbound.ts` + ###### Add components as you go ``` const server = new Server(); diff --git a/package.json b/package.json index cab87c1..2b30142 100644 --- a/package.json +++ b/package.json @@ -18,6 +18,10 @@ "test": "npm run build && npm run test:unit", "test:unit": "jest --coverage", "build": "tsc", + "example:bidirectional:server": "npm run build && node build/examples/bidirectional/server.js", + "example:bidirectional:client": "npm run build && node build/examples/bidirectional/client.js", + "example:outbound:server": "npm run build && node build/examples/bidirectional/server-outbound.js", + "example:outbound:client": "npm run build && node build/examples/bidirectional/client-outbound.js", "watch:build": "tsc --watch", "watch:test": "jest --watch", "lint": "eslint . --ext .ts", diff --git a/src/examples/bidirectional/client-outbound.ts b/src/examples/bidirectional/client-outbound.ts new file mode 100644 index 0000000..25fbff9 --- /dev/null +++ b/src/examples/bidirectional/client-outbound.ts @@ -0,0 +1,108 @@ +import WebSocket from "ws"; + +const URL = "ws://localhost:9851"; + +interface PendingRequest { + resolve: (value: unknown) => void; + reject: (err: Error) => void; +} + +const ws = new WebSocket(URL); +const pending = new Map(); +let nextRequestId = 0; + +function sendRequest(method: string, params: unknown[]): Promise { + return new Promise((resolve, reject) => { + const id = `client-${nextRequestId++}`; + pending.set(id, { resolve, reject }); + ws.send(JSON.stringify({ + id, + jsonrpc: "2.0", + method, + params, + })); + }); +} + +function sendResult(id: string, result: unknown) { + ws.send(JSON.stringify({ + id, + jsonrpc: "2.0", + result, + })); +} + +function handleIncomingRequest(payload: any) { + if (!payload.id) { + return; + } + + if (payload.method === "clientHello") { + const name = payload.params?.[0]; + sendResult(payload.id, `Hello ${name} (from outbound client).`); + return; + } + + if (payload.method === "bounce") { + const text = payload.params?.[0]; + const result = `[outbound client bounce] ${text}`; + console.log("received outbound call:", result); + sendResult(payload.id, result); + return; + } + + ws.send(JSON.stringify({ + id: payload.id, + jsonrpc: "2.0", + error: { + code: -32601, + message: `Unknown method "${payload.method}"`, + }, + })); +} + +function handleIncomingResponse(payload: any) { + if (!payload.id) { + return; + } + const pendingRequest = pending.get(payload.id); + if (!pendingRequest) { + return; + } + + pending.delete(payload.id); + if (payload.error) { + pendingRequest.reject(new Error(payload.error.message || "Unknown JSON-RPC error")); + return; + } + pendingRequest.resolve(payload.result); +} + +ws.on("message", (raw) => { + const payload = JSON.parse(raw.toString()); + if (payload.method) { + handleIncomingRequest(payload); + return; + } + handleIncomingResponse(payload); +}); + +ws.on("open", async () => { + try { + const response = await sendRequest("serverCallsClient", ["Bob"]); + console.log("serverCallsClient:", response); + console.log("waiting 6 seconds for outboundHandler calls..."); + setTimeout(() => ws.close(), 6000); + } catch (err) { + console.error("Client request failed:", err); + ws.close(); + } +}); + +ws.on("error", (err) => { + console.error("WebSocket error:", err); +}); + +ws.on("close", () => { + process.exit(0); +}); diff --git a/src/examples/bidirectional/client.ts b/src/examples/bidirectional/client.ts new file mode 100644 index 0000000..6b512b7 --- /dev/null +++ b/src/examples/bidirectional/client.ts @@ -0,0 +1,117 @@ +import WebSocket from "ws"; + +const URL = "ws://localhost:9850"; + +interface PendingRequest { + resolve: (value: unknown) => void; + reject: (err: Error) => void; +} + +const ws = new WebSocket(URL); +const pending = new Map(); +let nextRequestId = 0; + +function sendRequest(method: string, params: unknown[]): Promise { + return new Promise((resolve, reject) => { + const id = `client-${nextRequestId++}`; + pending.set(id, { resolve, reject }); + ws.send(JSON.stringify({ + id, + jsonrpc: "2.0", + method, + params, + })); + }); +} + +function sendResult(id: string, result: unknown) { + ws.send(JSON.stringify({ + id, + jsonrpc: "2.0", + result, + })); +} + +function sendError(id: string, message: string) { + ws.send(JSON.stringify({ + id, + jsonrpc: "2.0", + error: { + code: -32601, + message, + }, + })); +} + +function handleIncomingRequest(payload: any) { + if (!payload.id) { + return; + } + + if (payload.method === "clientHello") { + const name = payload.params?.[0]; + sendResult(payload.id, `Hello ${name} (from client).`); + return; + } + + if (payload.method === "bounce") { + const text = payload.params?.[0]; + sendResult(payload.id, `[client bounce] ${text}`); + return; + } + + sendError(payload.id, `Unknown method "${payload.method}"`); +} + +function handleIncomingResponse(payload: any) { + if (!payload.id) { + return; + } + + const pendingRequest = pending.get(payload.id); + if (!pendingRequest) { + return; + } + + pending.delete(payload.id); + if (payload.error) { + pendingRequest.reject(new Error(payload.error.message || "Unknown JSON-RPC error")); + return; + } + + pendingRequest.resolve(payload.result); +} + +ws.on("message", (raw) => { + const payload = JSON.parse(raw.toString()); + if (payload.method) { + handleIncomingRequest(payload); + return; + } + handleIncomingResponse(payload); +}); + +ws.on("open", async () => { + try { + const hello = await sendRequest("serverHello", ["Alice"]); + console.log("serverHello:", hello); + + const serverUsedClient = await sendRequest("serverCallsClient", ["Alice"]); + console.log("serverCallsClient:", serverUsedClient); + + const bounce = await sendRequest("bounce", ["Hello from client"]); + console.log("bounce:", bounce); + } catch (err) { + console.error("Client call failed:", err); + } finally { + ws.close(); + } +}); + +ws.on("error", (err) => { + console.error("WebSocket error:", err); +}); + +ws.on("close", () => { + process.exit(0); +}); diff --git a/src/examples/bidirectional/openrpc.ts b/src/examples/bidirectional/openrpc.ts new file mode 100644 index 0000000..2f93d66 --- /dev/null +++ b/src/examples/bidirectional/openrpc.ts @@ -0,0 +1,41 @@ +import { OpenrpcDocument as OpenRPC } from "@open-rpc/meta-schema"; + +const bidirectionalOpenRPCDocument: OpenRPC = { + openrpc: "1.2.6", + info: { + title: "Bidirectional Example", + version: "1.0.0", + }, + methods: [ + { + name: "serverHello", + summary: "Implemented by the server only.", + params: [{ name: "name", schema: { type: "string" } }], + result: { name: "message", schema: { type: "string" } }, + "x-implemented-by": ["server"], + }, + { + name: "clientHello", + summary: "Implemented by the client only.", + params: [{ name: "name", schema: { type: "string" } }], + result: { name: "message", schema: { type: "string" } }, + "x-implemented-by": ["client"], + }, + { + name: "bounce", + summary: "Implemented by both client and server.", + params: [{ name: "text", schema: { type: "string" } }], + result: { name: "message", schema: { type: "string" } }, + "x-implemented-by": ["server", "client"], + }, + { + name: "serverCallsClient", + summary: "Server method that calls client methods via injected client proxy.", + params: [{ name: "name", schema: { type: "string" } }], + result: { name: "message", schema: { type: "string" } }, + "x-implemented-by": ["server"], + }, + ], +}; + +export default bidirectionalOpenRPCDocument; diff --git a/src/examples/bidirectional/server-outbound.ts b/src/examples/bidirectional/server-outbound.ts new file mode 100644 index 0000000..de2ed41 --- /dev/null +++ b/src/examples/bidirectional/server-outbound.ts @@ -0,0 +1,68 @@ +import { parseOpenRPCDocument } from "@open-rpc/schema-utils-js"; +import { OpenrpcDocument as OpenRPC } from "@open-rpc/meta-schema"; +import { Router } from "../../router"; +import WebSocketTransport, { ConnectedClient } from "../../transports/websocket"; +import bidirectionalOpenRPCDocument from "./openrpc"; + +const PORT = 9851; + +async function startOutboundServer() { + const openrpcDocument = await parseOpenRPCDocument( + JSON.stringify(bidirectionalOpenRPCDocument), + ) as OpenRPC; + + const router = new Router(openrpcDocument, { + serverHello: async (name: string) => `Hello ${name} (from outbound server).`, + bounce: async (text: string) => `[outbound server bounce] ${text}`, + serverCallsClient: async ( + name: string, + client: { + clientHello: (value: string) => Promise; + bounce: (value: string) => Promise; + }, + ) => { + const helloFromClient = await client.clientHello(name); + const bounceFromClient = await client.bounce(`request/response ping for ${name}`); + return `Request/response path -> ${helloFromClient} | ${bounceFromClient}`; + }, + }); + + const outboundHandler = async (clients: ConnectedClient[]) => { + await Promise.all(clients.map(async (client) => { + try { + const message = await client.methods.bounce("scheduled ping from outboundHandler"); + console.log(`[outboundHandler] ${client.id}: ${String(message)}`); + } catch (err) { + console.error(`[outboundHandler] failed for ${client.id}:`, err); + } + })); + }; + + const transport = new WebSocketTransport({ + middleware: [], + port: PORT, + outboundHandler, + outboundIntervalMs: 2000, + }); + + transport.addRouter(router); + await transport.start(); + console.log(`OutboundHandler example server listening on ws://localhost:${PORT}`); + + const shutdown = async () => { + await transport.stop(); + process.exit(0); + }; + + process.once("SIGINT", () => { + void shutdown(); + }); + process.once("SIGTERM", () => { + void shutdown(); + }); +} + +void startOutboundServer().catch((err) => { + console.error("Failed to start outbound example server:", err); + process.exit(1); +}); diff --git a/src/examples/bidirectional/server.ts b/src/examples/bidirectional/server.ts new file mode 100644 index 0000000..53883f0 --- /dev/null +++ b/src/examples/bidirectional/server.ts @@ -0,0 +1,55 @@ +import { parseOpenRPCDocument } from "@open-rpc/schema-utils-js"; +import { OpenrpcDocument as OpenRPC } from "@open-rpc/meta-schema"; +import { Router } from "../../router"; +import WebSocketTransport from "../../transports/websocket"; +import bidirectionalOpenRPCDocument from "./openrpc"; + +const PORT = 9850; + +async function startServer() { + const openrpcDocument = await parseOpenRPCDocument( + JSON.stringify(bidirectionalOpenRPCDocument), + ) as OpenRPC; + + const router = new Router(openrpcDocument, { + serverHello: async (name: string) => `Hello ${name} (from server).`, + bounce: async (text: string) => `[server bounce] ${text}`, + serverCallsClient: async ( + name: string, + client: { + clientHello: (value: string) => Promise; + bounce: (value: string) => Promise; + }, + ) => { + const helloFromClient = await client.clientHello(name); + const bounceFromClient = await client.bounce(`ping from server for ${name}`); + return `Server called client -> ${helloFromClient} | ${bounceFromClient}`; + }, + }); + + const transport = new WebSocketTransport({ + middleware: [], + port: PORT, + }); + + transport.addRouter(router); + await transport.start(); + console.log(`Bidirectional example server listening on ws://localhost:${PORT}`); + + const shutdown = async () => { + await transport.stop(); + process.exit(0); + }; + + process.once("SIGINT", () => { + void shutdown(); + }); + process.once("SIGTERM", () => { + void shutdown(); + }); +} + +void startServer().catch((err) => { + console.error("Failed to start bidirectional example server:", err); + process.exit(1); +}); diff --git a/src/index.ts b/src/index.ts index 762830a..0234e3f 100644 --- a/src/index.ts +++ b/src/index.ts @@ -2,6 +2,7 @@ import Server, { ServerOptions } from "./server"; import { Router } from "./router"; import { JSONRPCError } from "./error"; export * as transports from "./transports" +export * as plugins from "./plugins" export { Server, diff --git a/src/plugins.ts b/src/plugins.ts new file mode 100644 index 0000000..4fd00a2 --- /dev/null +++ b/src/plugins.ts @@ -0,0 +1,44 @@ +import { MethodObject } from "@open-rpc/meta-schema"; +import { RouterPlugin } from "./router"; + +const getImplementedBy = (methodObject?: MethodObject): string[] => { + if (!methodObject) { + return []; + } + + const implementedBy = (methodObject as MethodObject & { [key: string]: unknown })["x-implemented-by"]; + if (implementedBy === undefined) { + return ["server"]; + } + + if (implementedBy instanceof Array) { + return implementedBy.filter((role): role is string => typeof role === "string"); + } + + return []; +}; + +export const implementedByPlugin = (): RouterPlugin => ({ + name: "implemented-by", + isMethodImplemented: ({ methodName, methodObject, context }) => { + if (methodName === "rpc.discover") { + return true; + } + + const participant = (context?.participant as string | undefined) || "server"; + return getImplementedBy(methodObject).includes(participant); + }, + listMethods: ({ methods, context }) => { + const participant = (context?.participant as string | undefined) || "server"; + return methods + .filter((method) => getImplementedBy(method).includes(participant)) + .map((method) => method.name); + }, + mapHandlerParams: ({ paramsAsArray, context }) => { + if (!context || context.client === undefined) { + return paramsAsArray; + } + return [...paramsAsArray, context.client]; + }, +}); + diff --git a/src/router.test.ts b/src/router.test.ts index 6e259dc..beb2307 100644 --- a/src/router.test.ts +++ b/src/router.test.ts @@ -11,6 +11,7 @@ import { } from "@open-rpc/meta-schema"; import { JSONRPCError } from "./error"; import { JSONRPCErrorObject } from "./transports/server-transport"; +import { implementedByPlugin } from "./plugins"; const jsf = require("json-schema-faker"); // eslint-disable-line @@ -144,6 +145,82 @@ describe("router", () => { const { result } = await router.call("addition", [6, 2]); expect(typeof result).toBe("number"); }); + + it("does not expose x-implemented-by client methods as inbound server methods", async () => { + const exampleWithClientMethod = _.cloneDeep(parsedExample); + const additionMethod = (exampleWithClientMethod.methods as MethodObject[]) + .find((method) => method.name === "addition") as MethodObject; + (additionMethod as MethodObject & { [key: string]: unknown })["x-implemented-by"] = ["client"]; + const router = new Router(exampleWithClientMethod, makeMethodMapping(exampleWithClientMethod.methods as MethodObject[])); + const { error } = await router.call("addition", [2, 2]); + expect((error as JSONRPCErrorObject).code).toBe(-32601); + }); + + it("reports methods implemented by a given participant", async () => { + const exampleWithClientMethod = _.cloneDeep(parsedExample); + const methods = exampleWithClientMethod.methods as MethodObject[]; + const additionMethod = methods.find((method) => method.name === "addition") as MethodObject; + const subtractionMethod = methods.find((method) => method.name === "subtraction") as MethodObject; + (additionMethod as MethodObject & { [key: string]: unknown })["x-implemented-by"] = ["client"]; + (subtractionMethod as MethodObject & { [key: string]: unknown })["x-implemented-by"] = ["server", "client"]; + + const router = new Router(exampleWithClientMethod, makeMethodMapping(methods)); + expect(router.getMethodsImplementedBy("client")).toEqual(expect.arrayContaining(["addition", "subtraction"])); + expect(router.getMethodsImplementedBy("server")).toContain("subtraction"); + }); + + it("defaults x-implemented-by to server when extension is omitted", async () => { + const exampleWithoutExtension = _.cloneDeep(parsedExample); + const additionMethod = (exampleWithoutExtension.methods as MethodObject[]) + .find((method) => method.name === "addition") as MethodObject; + delete (additionMethod as MethodObject & { [key: string]: unknown })["x-implemented-by"]; + + const router = new Router(exampleWithoutExtension, makeMethodMapping(exampleWithoutExtension.methods as MethodObject[])); + expect(router.isMethodImplemented("addition")).toBe(true); + const { result } = await router.call("addition", [2, 2]); + expect(result).toBe(4); + }); + + it("does not include rpc.discover in participant method listings", () => { + const router = new Router(parsedExample, makeMethodMapping(parsedExample.methods as MethodObject[])); + expect(router.getMethodsImplementedBy("server")).not.toContain("rpc.discover"); + }); + + it("plugin reports available methods using plugin context", () => { + const exampleWithRoles = _.cloneDeep(parsedExample); + const methods = exampleWithRoles.methods as MethodObject[]; + const additionMethod = methods.find((method) => method.name === "addition") as MethodObject; + const subtractionMethod = methods.find((method) => method.name === "subtraction") as MethodObject; + (additionMethod as MethodObject & { [key: string]: unknown })["x-implemented-by"] = ["client"]; + (subtractionMethod as MethodObject & { [key: string]: unknown })["x-implemented-by"] = ["server", "client"]; + + const router = new Router( + exampleWithRoles, + makeMethodMapping(methods), + { plugins: [implementedByPlugin()] }, + ); + expect(router.getAvailableMethods({ participant: "client" })).toEqual(expect.arrayContaining(["addition", "subtraction"])); + expect(router.getAvailableMethods({ participant: "server" })).toContain("subtraction"); + expect(router.getAvailableMethods({ participant: "server" })).not.toContain("addition"); + }); + + it("plugin appends context client as final method arg", async () => { + const expectedClient = { id: "client-1" }; + const exampleForContext = _.cloneDeep(parsedExample); + const methodMapping = makeMethodMapping(exampleForContext.methods as MethodObject[]); + methodMapping.addition = async (a: number, b: number, client: unknown) => { + expect(client).toBe(expectedClient); + return a + b; + }; + + const router = new Router( + exampleForContext, + methodMapping, + { plugins: [implementedByPlugin()] }, + ); + const { result } = await router.call("addition", [2, 2], { client: expectedClient }); + expect(result).toBe(4); + }); } }); diff --git a/src/router.ts b/src/router.ts index afc5cb3..a16f92a 100644 --- a/src/router.ts +++ b/src/router.ts @@ -20,6 +20,40 @@ export interface MockModeSettings { } export type TMethodHandler = (...args: any) => Promise; +export interface RouterCallContext { + client?: unknown; + [key: string]: unknown; +} + +export interface RouterPluginMethodImplementationContext { + methodName: string; + methodObject?: MethodObject; + hasLocalHandler: boolean; + openrpcDocument: OpenrpcDocument; + context?: RouterCallContext; +} + +export interface RouterPluginInvocationContext extends RouterPluginMethodImplementationContext { + params: any; + paramsAsArray: any[]; +} + +export interface RouterPluginListMethodsContext { + methods: MethodObject[]; + openrpcDocument: OpenrpcDocument; + context?: RouterCallContext; +} + +export interface RouterPlugin { + name: string; + isMethodImplemented?: (context: RouterPluginMethodImplementationContext) => boolean | undefined; + mapHandlerParams?: (context: RouterPluginInvocationContext) => any[] | undefined; + listMethods?: (context: RouterPluginListMethodsContext) => string[] | undefined; +} + +export interface RouterOptions { + plugins?: RouterPlugin[]; +} const toArray = (method?: MethodObject, params?: Record) => { if (!method) { @@ -53,10 +87,34 @@ export class Router { } private methods: MethodMapping; private methodCallValidator: MethodCallValidator; + private plugins: RouterPlugin[]; + + private getMethodObject(methodName: string): MethodObject | undefined { + return (this.openrpcDocument.methods as MethodObject[]).find((m) => m.name === methodName); + } + + private getImplementedBy(methodName: string): string[] { + const methodObject = this.getMethodObject(methodName); + if (!methodObject) { + return []; + } + + const implementedBy = (methodObject as MethodObject & { [key: string]: unknown })["x-implemented-by"]; + if (implementedBy === undefined) { + return ["server"]; + } + + if (implementedBy instanceof Array) { + return implementedBy.filter((role): role is string => typeof role === "string"); + } + + return []; + } constructor( private openrpcDocument: OpenrpcDocument, methodMapping: MethodMapping | MockModeSettings, + options: RouterOptions = {}, ) { if (methodMapping.mockMode) { this.methods = this.buildMockMethodMapping(openrpcDocument.methods as MethodObject[]); @@ -66,9 +124,14 @@ export class Router { this.methods["rpc.discover"] = this.serviceDiscoveryHandler.bind(this); this.methodCallValidator = new MethodCallValidator(openrpcDocument); + this.plugins = options.plugins || []; } - public async call(methodName: string, params: any) { + public async call(methodName: string, params: any, context?: RouterCallContext) { + if (!this.isMethodImplemented(methodName, context)) { + return Router.methodNotFoundHandler(methodName); + } + const validationErrors = this.methodCallValidator.validate(methodName, params); if (validationErrors instanceof MethodNotFoundError) { @@ -79,11 +142,37 @@ export class Router { return this.invalidParamsHandler(validationErrors); } - const methodObject = (this.openrpcDocument.methods as MethodObject[]).find((m) => m.name === methodName) as MethodObject; + const methodObject = this.getMethodObject(methodName) as MethodObject; - const paramsAsArray = params instanceof Array ? params : toArray(methodObject, params); + let paramsAsArray = params instanceof Array ? params : toArray(methodObject, params); try { + let paramsMappedByPlugin = false; + for (const plugin of this.plugins) { + if (!plugin.mapHandlerParams) { + continue; + } + + const mappedParams = plugin.mapHandlerParams({ + context, + hasLocalHandler: this.methods[methodName] !== undefined, + methodName, + methodObject, + openrpcDocument: this.openrpcDocument, + params, + paramsAsArray, + }); + + if (mappedParams !== undefined) { + paramsAsArray = mappedParams; + paramsMappedByPlugin = true; + } + } + + if (!paramsMappedByPlugin && context && context.client !== undefined) { + paramsAsArray = [...paramsAsArray, context.client]; + } + return { result: await this.methods[methodName](...paramsAsArray) }; } catch (e) { if (e instanceof JSONRPCError) { @@ -93,8 +182,66 @@ export class Router { } } - public isMethodImplemented(methodName: string): boolean { - return this.methods[methodName] !== undefined; + public isMethodImplemented(methodName: string, context?: RouterCallContext): boolean { + const methodObject = (this.openrpcDocument.methods as MethodObject[]).find((m) => m.name === methodName); + const hasLocalHandler = this.methods[methodName] !== undefined; + + if (!hasLocalHandler) { + return false; + } + + if (methodName === "rpc.discover") { + return true; + } + + for (const plugin of this.plugins) { + if (!plugin.isMethodImplemented) { + continue; + } + const pluginResult = plugin.isMethodImplemented({ + context, + hasLocalHandler, + methodName, + methodObject, + openrpcDocument: this.openrpcDocument, + }); + if (pluginResult === false) { + return false; + } + } + + return this.getImplementedBy(methodName).includes("server"); + } + + public getMethodsImplementedBy(participant: string): string[] { + return (this.openrpcDocument.methods as MethodObject[]) + .filter((method) => this.getImplementedBy(method.name).includes(participant)) + .map((method) => method.name) + .filter((methodName) => methodName !== "rpc.discover"); + } + + public getAvailableMethods(context?: RouterCallContext): string[] { + const methods = this.openrpcDocument.methods as MethodObject[]; + + for (const plugin of this.plugins) { + if (!plugin.listMethods) { + continue; + } + const methodList = plugin.listMethods({ + context, + methods, + openrpcDocument: this.openrpcDocument, + }); + + if (methodList !== undefined) { + return methodList.filter((methodName) => methodName !== "rpc.discover"); + } + } + + return methods + .map((method) => method.name) + .filter((methodName) => methodName !== "rpc.discover") + .filter((methodName) => this.isMethodImplemented(methodName, context)); } private serviceDiscoveryHandler(): Promise { diff --git a/src/server.test.ts b/src/server.test.ts index 868763e..d508723 100644 --- a/src/server.test.ts +++ b/src/server.test.ts @@ -54,7 +54,7 @@ describe('Server', () => { const server = new Server({ openrpcDocument: {} as any, methodMapping: mapping }); // Verify Router was created with expected args - expect(require('./router').Router).toHaveBeenCalledWith({} as any, mapping); + expect(require('./router').Router).toHaveBeenCalledWith({} as any, mapping, undefined); expect((server as any).routers).toHaveLength(1); // Restore original Router @@ -136,7 +136,7 @@ describe('Server', () => { const router = server.addRouter({} as any, {} as any); - expect(require('./router').Router).toHaveBeenCalledWith({}, {} as any); + expect(require('./router').Router).toHaveBeenCalledWith({}, {} as any, undefined); expect(transport.addRouter).toHaveBeenCalledWith(router); expect((server as any).routers).toContain(router); @@ -144,6 +144,21 @@ describe('Server', () => { require('./router').Router = originalRouter; }); + it('passes router options into addRouter and constructor path', () => { + const originalRouter = require('./router').Router; + const mockRouter = createTestRouter(); + require('./router').Router = jest.fn().mockReturnValue(mockRouter); + + const routerOptions = { plugins: [{ name: "test-plugin" }] } as any; + const server = new Server({ openrpcDocument: {} as any, methodMapping: {} as any, routerOptions }); + server.addRouter({} as any, {} as any, routerOptions); + + expect(require('./router').Router).toHaveBeenNthCalledWith(1, {} as any, {} as any, routerOptions); + expect(require('./router').Router).toHaveBeenNthCalledWith(2, {} as any, {} as any, routerOptions); + + require('./router').Router = originalRouter; + }); + it('deregisters router and detaches from transports in removeRouter', () => { const server = new Server({ openrpcDocument: {} as any }); diff --git a/src/server.ts b/src/server.ts index 0a6cb4e..151d707 100644 --- a/src/server.ts +++ b/src/server.ts @@ -1,4 +1,4 @@ -import { Router, MethodMapping } from "./router"; +import { Router, MethodMapping, RouterOptions } from "./router"; import { OpenrpcDocument as OpenRPC } from "@open-rpc/meta-schema"; import Transports, {ServerTransport, TransportOptions, TransportClasses, TransportNames } from "./transports"; @@ -15,6 +15,7 @@ export interface ServerOptions { openrpcDocument: OpenRPC; transportConfigs?: TransportConfig[]; methodMapping?: MethodMapping | MockModeOptions; + routerOptions?: RouterOptions; } export default class Server { @@ -26,6 +27,7 @@ export default class Server { this.addRouter( options.openrpcDocument, options.methodMapping, + options.routerOptions, ); } @@ -57,8 +59,8 @@ export default class Server { this.addTransport(transport); } - public addRouter(openrpcDocument: OpenRPC, methodMapping: MethodMapping | MockModeOptions) { - const router = new Router(openrpcDocument, methodMapping); + public addRouter(openrpcDocument: OpenRPC, methodMapping: MethodMapping | MockModeOptions, routerOptions?: RouterOptions) { + const router = new Router(openrpcDocument, methodMapping, routerOptions); this.routers.push(router); this.transports.forEach((transport) => transport.addRouter(router)); diff --git a/src/transports/server-transport.test.ts b/src/transports/server-transport.test.ts index 9d06b29..74c1643 100644 --- a/src/transports/server-transport.test.ts +++ b/src/transports/server-transport.test.ts @@ -37,6 +37,22 @@ describe("Server transport test", () => { expect(result.result).toBe(42); }); + it("passes context into router selection and call", async () => { + const t = new DummyTransport(); + const ctx = { client: { id: "ctx-client" } }; + const isMethodImplemented = jest.fn().mockReturnValue(true); + const call = jest.fn().mockResolvedValue({ result: 42 }); + const fakeRouter = { + isMethodImplemented, + call, + } as unknown as import("../router").Router; + + t.addRouter(fakeRouter); + await t['routerHandler']({ jsonrpc: "2.0", id: "2", method: "bar", params: [] }, ctx); + expect(isMethodImplemented).toHaveBeenCalledWith("bar", ctx); + expect(call).toHaveBeenCalledWith("bar", [], ctx); + }); + it("covers the no router configured branch in routerHandler", async () => { class DummyTransport extends ServerTransport {} const t = new DummyTransport(); diff --git a/src/transports/server-transport.ts b/src/transports/server-transport.ts index 2ed0ebd..2e908dc 100644 --- a/src/transports/server-transport.ts +++ b/src/transports/server-transport.ts @@ -1,4 +1,4 @@ -import { Router } from "../router"; +import { Router, RouterCallContext } from "../router"; export interface JSONRPCRequest { jsonrpc: string; @@ -41,13 +41,13 @@ export abstract class ServerTransport { throw new Error("Transport missing stop implementation"); } - protected async routerHandler({ id, method, params }: JSONRPCRequest): Promise { + protected async routerHandler({ id, method, params }: JSONRPCRequest, context?: RouterCallContext): Promise { if (this.routers.length === 0) { console.warn("transport method called without a router configured."); // tslint:disable-line throw new Error("No router configured"); } - const routerForMethod = this.routers.find((r) => r.isMethodImplemented(method)); + const routerForMethod = this.routers.find((r) => r.isMethodImplemented(method, context)); let res = { id, @@ -63,11 +63,11 @@ export abstract class ServerTransport { } else { res = { ...res, - ...await routerForMethod.call(method, params) + ...await routerForMethod.call(method, params, context) }; } return res; } } -export default ServerTransport; \ No newline at end of file +export default ServerTransport; diff --git a/src/transports/websocket.integration.test.ts b/src/transports/websocket.integration.test.ts new file mode 100644 index 0000000..7d2a4da --- /dev/null +++ b/src/transports/websocket.integration.test.ts @@ -0,0 +1,117 @@ +import WebSocket from "ws"; +import { parseOpenRPCDocument } from "@open-rpc/schema-utils-js"; +import { OpenrpcDocument as OpenRPC } from "@open-rpc/meta-schema"; +import WebSocketTransport from "./websocket"; +import { Router } from "../router"; + +describe("websocket integration", () => { + jest.setTimeout(15000); + + it("supports bidirectional calls over websocket", async () => { + const openrpcDocument = await parseOpenRPCDocument(JSON.stringify({ + openrpc: "1.2.6", + info: { + title: "WebSocket integration test", + version: "1.0.0", + }, + methods: [ + { + name: "add", + params: [ + { name: "a", schema: { type: "number" } }, + { name: "b", schema: { type: "number" } }, + ], + result: { name: "sum", schema: { type: "number" } }, + "x-implemented-by": ["server"], + }, + { + name: "clientDouble", + params: [{ name: "value", schema: { type: "number" } }], + result: { name: "doubled", schema: { type: "number" } }, + "x-implemented-by": ["client"], + }, + { + name: "callClientDouble", + params: [{ name: "value", schema: { type: "number" } }], + result: { name: "result", schema: { type: "number" } }, + "x-implemented-by": ["server"], + }, + ], + })) as OpenRPC; + + const transport = new WebSocketTransport({ + middleware: [], + port: 9720, + }); + + const router = new Router(openrpcDocument, { + add: async (a: number, b: number) => a + b, + callClientDouble: async ( + value: number, + client: { clientDouble: (input: number) => Promise }, + ) => client.clientDouble(value), + clientDouble: async () => 0, + }); + + transport.addRouter(router); + await transport.start(); + + const ws = new WebSocket("ws://localhost:9720"); + + const pending = new Map void; reject: (error: Error) => void }>(); + let nextRequestId = 0; + + const sendRequest = (method: string, params: any[]) => new Promise((resolve, reject) => { + const id = `client-${nextRequestId++}`; + pending.set(id, { resolve, reject }); + ws.send(JSON.stringify({ id, jsonrpc: "2.0", method, params })); + }); + + const messageHandler = (raw: WebSocket.Data) => { + const payload = JSON.parse(raw.toString()); + + if (payload.method === "clientDouble") { + ws.send(JSON.stringify({ + id: payload.id, + jsonrpc: "2.0", + result: payload.params[0] * 2, + })); + return; + } + + if (payload.id && (payload.result !== undefined || payload.error)) { + const pendingRequest = pending.get(payload.id); + if (!pendingRequest) { + return; + } + pending.delete(payload.id); + + if (payload.error) { + pendingRequest.reject(new Error(payload.error.message)); + return; + } + + pendingRequest.resolve(payload.result); + } + }; + + ws.on("message", messageHandler); + + await new Promise((resolve, reject) => { + ws.on("open", resolve); + ws.on("error", reject); + }); + + try { + const sum = await sendRequest("add", [2, 3]); + expect(sum).toBe(5); + + const doubled = await sendRequest("callClientDouble", [7]); + expect(doubled).toBe(14); + } finally { + ws.removeListener("message", messageHandler); + ws.close(); + await transport.stop(); + } + }); +}); diff --git a/src/transports/websocket.test.ts b/src/transports/websocket.test.ts index f3e74b4..2708c4a 100644 --- a/src/transports/websocket.test.ts +++ b/src/transports/websocket.test.ts @@ -371,4 +371,231 @@ describe("WebSocket transport", () => { }); expect((transport as any).options.timeout).toBe(5000); }); + + it("passes a client proxy into method handlers", async () => { + const simpleMathExample = await parseOpenRPCDocument(examples.simpleMath); + (simpleMathExample.methods as any[]).push({ + name: "notify", + params: [{ name: "value", schema: { type: "integer" } }], + result: { name: "notified", schema: { type: "integer" } }, + "x-implemented-by": ["client"], + }); + + const transport = new WebSocketTransport({ + middleware: [], + port: 9712, + }); + + const router = new Router(simpleMathExample, { + addition: async (a: number, b: number, client: { notify: (value: number) => Promise }) => { + return client.notify(a + b); + }, + subtraction: async (a: number, b: number) => a - b, + notify: async () => 0, + }); + + transport.addRouter(router); + await transport.start(); + + const ws = new WebSocket("ws://localhost:9712"); + + await new Promise((resolve, reject) => { + ws.on("message", (raw: WebSocket.Data) => { + const payload = JSON.parse(raw.toString()); + if (payload.method === "notify") { + ws.send(JSON.stringify({ + id: payload.id, + jsonrpc: "2.0", + result: payload.params[0] * 2, + })); + return; + } + + expect(payload.result).toBe(8); + resolve(); + }); + ws.on("error", reject); + ws.on("open", () => { + ws.send(JSON.stringify({ + id: "invoke-addition", + jsonrpc: "2.0", + method: "addition", + params: [2, 2], + })); + }); + }); + + ws.close(); + await transport.stop(); + }); + + it("runs outboundHandler with connected clients", async () => { + const simpleMathExample = await parseOpenRPCDocument(examples.simpleMath); + (simpleMathExample.methods as any[]).push({ + name: "notify", + params: [{ name: "value", schema: { type: "integer" } }], + result: { name: "notified", schema: { type: "integer" } }, + "x-implemented-by": ["client"], + }); + + let hasSentNotify = false; + const outboundHandler = jest.fn(async (clients) => { + if (clients.length === 0 || hasSentNotify) { + return; + } + hasSentNotify = true; + await clients[0].methods.notify(10); + }); + + const transport = new WebSocketTransport({ + middleware: [], + port: 9713, + outboundHandler, + outboundIntervalMs: 50, + }); + + const router = new Router(simpleMathExample, { + addition: async (a: number, b: number) => a + b, + subtraction: async (a: number, b: number) => a - b, + notify: async () => 0, + }); + + transport.addRouter(router); + await transport.start(); + + const ws = new WebSocket("ws://localhost:9713"); + await new Promise((resolve, reject) => { + ws.on("message", (raw: WebSocket.Data) => { + const payload = JSON.parse(raw.toString()); + if (payload.method === "notify") { + ws.send(JSON.stringify({ + id: payload.id, + jsonrpc: "2.0", + result: payload.params[0], + })); + resolve(); + } + }); + ws.on("error", reject); + }); + + expect(outboundHandler).toHaveBeenCalled(); + ws.close(); + await transport.stop(); + }); + + + it("rejects handler client proxy calls when client returns JSON-RPC error", async () => { + const simpleMathExample = await parseOpenRPCDocument(examples.simpleMath); + (simpleMathExample.methods as any[]).push({ + name: "notify", + params: [{ name: "value", schema: { type: "integer" } }], + result: { name: "notified", schema: { type: "integer" } }, + "x-implemented-by": ["client"], + }); + + const transport = new WebSocketTransport({ + middleware: [], + port: 9714, + }); + + const router = new Router(simpleMathExample, { + addition: async (a: number, b: number, client: { notify: (value: number) => Promise }) => { + return client.notify(a + b); + }, + subtraction: async (a: number, b: number) => a - b, + notify: async () => 0, + }); + + transport.addRouter(router); + await transport.start(); + + const ws = new WebSocket("ws://localhost:9714"); + + await new Promise((resolve, reject) => { + ws.on("message", (raw: WebSocket.Data) => { + const payload = JSON.parse(raw.toString()); + if (payload.method === "notify") { + ws.send(JSON.stringify({ + id: payload.id, + jsonrpc: "2.0", + error: { + code: 1234, + message: "client side failure", + data: { reason: "boom" }, + }, + })); + return; + } + + expect(payload.error).toBeDefined(); + expect(payload.error.code).toBe(6969); + resolve(); + }); + ws.on("error", reject); + ws.on("open", () => { + ws.send(JSON.stringify({ + id: "invoke-addition-error", + jsonrpc: "2.0", + method: "addition", + params: [2, 2], + })); + }); + }); + + ws.close(); + await transport.stop(); + }); + + it("cleans up pending client requests when socket closes", async () => { + const transport = new WebSocketTransport({ + middleware: [], + port: 9715, + }); + + const reject = jest.fn(); + const resolve = jest.fn(); + const mockSocket = { + removeAllListeners: jest.fn(), + }; + + (transport as any).pendingClientRequests.set("request-1", { + socket: mockSocket, + reject, + resolve, + }); + (transport as any).clientDetails.set(mockSocket, { id: "client-1", methods: {} }); + + (transport as any).handleClientClose(mockSocket); + + expect(mockSocket.removeAllListeners).toHaveBeenCalled(); + expect(reject).toHaveBeenCalledWith(new Error("WebSocket connection closed")); + expect((transport as any).pendingClientRequests.size).toBe(0); + expect((transport as any).clientDetails.size).toBe(0); + }); + + it("does not resolve pending client requests when response id is missing", () => { + const transport = new WebSocketTransport({ + middleware: [], + port: 9716, + }); + + const reject = jest.fn(); + const resolve = jest.fn(); + (transport as any).pendingClientRequests.set("request-2", { + socket: {}, + reject, + resolve, + }); + + (transport as any).resolvePendingClientRequest({ + jsonrpc: "2.0", + result: "ok", + }); + + expect(resolve).not.toHaveBeenCalled(); + expect(reject).not.toHaveBeenCalled(); + expect((transport as any).pendingClientRequests.size).toBe(1); + }); + }); diff --git a/src/transports/websocket.ts b/src/transports/websocket.ts index b0db41b..640b08c 100644 --- a/src/transports/websocket.ts +++ b/src/transports/websocket.ts @@ -3,9 +3,24 @@ import { json as jsonParser } from "body-parser"; import connect, { HandleFunction, Server as ConnectApp } from "connect"; import http2, { Http2SecureServer, SecureServerOptions } from "http2"; import http from "http"; -import ServerTransport, { JSONRPCRequest } from "./server-transport"; +import ServerTransport, { JSONRPCRequest, JSONRPCResponse } from "./server-transport"; import WebSocket from "ws"; +export interface ClientMethods { + [methodName: string]: (...params: any[]) => Promise; +} + +export interface ConnectedClient { + id: string; + methods: ClientMethods; +} + +interface PendingClientRequest { + socket: WebSocket; + resolve: (result: any) => void; + reject: (error: Error) => void; +} + export interface WebSocketServerTransportOptions extends SecureServerOptions { middleware: HandleFunction[]; port: number; @@ -13,12 +28,19 @@ export interface WebSocketServerTransportOptions extends SecureServerOptions { allowHTTP1?: boolean; app?: ConnectApp; timeout?: number; + outboundHandler?: (clients: ConnectedClient[]) => Promise | void; + outboundIntervalMs?: number; } export default class WebSocketServerTransport extends ServerTransport { private static defaultCorsOptions = { origin: "*" }; private server: Http2SecureServer | http.Server; private wss: WebSocket.Server; + private pendingClientRequests = new Map(); + private clientDetails = new Map(); + private outboundInterval?: NodeJS.Timeout; + private nextClientId = 0; + private nextRequestId = 0; constructor(private options: WebSocketServerTransportOptions) { super(); @@ -51,24 +73,47 @@ export default class WebSocketServerTransport extends ServerTransport { this.wss = new WebSocket.Server({ server: this.server as any }); this.wss.on("connection", (ws: WebSocket) => { - ws.on( - "message", - (message: string) => this.webSocketRouterHandler(JSON.parse(message), ws.send.bind(ws)), - ); - ws.on("close", () => ws.removeAllListeners()); + const client = { + id: `client-${this.nextClientId++}`, + methods: this.buildClientMethodsProxy(ws), + }; + this.clientDetails.set(ws, client); + + ws.on("message", (message: WebSocket.Data) => { + void this.handleWebSocketMessage(message, ws); + }); + ws.on("close", () => this.handleClientClose(ws)); }); } public async start(): Promise { - return new Promise((resolve, reject) => { + await new Promise((resolve, reject) => { this.server.listen(this.options.port, (err?: Error) => { if (err) return reject(err); resolve(); }); }); + + if (this.options.outboundHandler) { + const intervalMs = this.options.outboundIntervalMs ?? 1000; + this.outboundInterval = setInterval(() => { + void Promise.resolve(this.options.outboundHandler!(Array.from(this.clientDetails.values()))) + .catch(() => undefined); + }, intervalMs); + } } public async stop(): Promise { + if (this.outboundInterval) { + clearInterval(this.outboundInterval); + this.outboundInterval = undefined; + } + + this.pendingClientRequests.forEach(({ reject }) => { + reject(new Error("WebSocket connection closed")); + }); + this.pendingClientRequests.clear(); + // First sweep, soft close this.wss.clients.forEach((socket) => { socket.close(); @@ -90,13 +135,88 @@ export default class WebSocketServerTransport extends ServerTransport { }); } - private async webSocketRouterHandler(req: any, respondWith: any) { - let result = null; - if (req instanceof Array) { - result = await Promise.all(req.map((r: JSONRPCRequest) => super.routerHandler(r))); - } else { - result = await super.routerHandler(req); + private handleClientClose(ws: WebSocket) { + ws.removeAllListeners(); + this.clientDetails.delete(ws); + this.pendingClientRequests.forEach((pendingRequest, requestId) => { + if (pendingRequest.socket === ws) { + pendingRequest.reject(new Error("WebSocket connection closed")); + this.pendingClientRequests.delete(requestId); + } + }); + } + + private buildClientMethodsProxy(ws: WebSocket): ClientMethods { + const methodNames = Array.from(new Set(this.routers.flatMap((router) => router.getMethodsImplementedBy("client")))); + + return methodNames.reduce((methods: ClientMethods, methodName: string) => { + methods[methodName] = (...params: any[]) => this.callClientMethod(ws, methodName, params); + return methods; + }, {}); + } + + private callClientMethod(ws: WebSocket, method: string, params: any[]): Promise { + return new Promise((resolve, reject) => { + const id = `server-${this.nextRequestId++}`; + this.pendingClientRequests.set(id, { socket: ws, resolve, reject }); + ws.send(JSON.stringify({ id, jsonrpc: "2.0", method, params }), (err) => { + if (err) { + this.pendingClientRequests.delete(id); + reject(err); + } + }); + }); + } + + private async handleWebSocketMessage(message: WebSocket.Data, ws: WebSocket) { + const messageAsString = typeof message === "string" ? message : message.toString(); + const payload = JSON.parse(messageAsString); + + if (payload instanceof Array) { + const batchResponses = await Promise.all(payload.map((reqOrRes: any) => this.routePayload(reqOrRes, ws))); + const filteredResponses = batchResponses.filter((response) => response !== null); + if (filteredResponses.length > 0) { + ws.send(JSON.stringify(filteredResponses)); + } + return; + } + + const response = await this.routePayload(payload, ws); + if (response) { + ws.send(JSON.stringify(response)); + } + } + + private async routePayload(payload: any, ws: WebSocket): Promise { + if (this.isResponsePayload(payload)) { + this.resolvePendingClientRequest(payload); + return null; } - respondWith(JSON.stringify(result)); + + return super.routerHandler(payload as JSONRPCRequest, { client: this.clientDetails.get(ws)?.methods }); + } + + private resolvePendingClientRequest(payload: JSONRPCResponse) { + if (!payload.id) { + return; + } + + const pendingRequest = this.pendingClientRequests.get(payload.id); + if (!pendingRequest) { + return; + } + this.pendingClientRequests.delete(payload.id); + + if (payload.error) { + pendingRequest.reject(new Error(payload.error.message)); + return; + } + + pendingRequest.resolve(payload.result); + } + + private isResponsePayload(payload: JSONRPCRequest | JSONRPCResponse): payload is JSONRPCResponse { + return (payload as JSONRPCResponse).result !== undefined + || (payload as JSONRPCResponse).error !== undefined; } }