diff --git a/cli.ts b/cli.ts index 8c7716dc..9e942c29 100644 --- a/cli.ts +++ b/cli.ts @@ -57,14 +57,14 @@ const main = async (): Promise => { return process.exit(0); } - const webviewServer = new WebviewServer(); - const server = new MainServer(webviewServer, args); - // The main server inserts webview server address to the root HTML, so we'll - // need to wait for it to listen otherwise the address will be null. - await webviewServer.listen(typeof args["webview-port"] !== "undefined" && parseInt(args["webview-port"], 10) || 8444); - await server.listen(typeof args.port !== "undefined" && parseInt(args.port, 10) || 8443); - console.log(`Main server serving ${server.address}`); - console.log(`Webview server serving ${webviewServer.address}`); + const webviewServer = new WebviewServer(typeof args["webview-port"] !== "undefined" && parseInt(args["webview-port"], 10) || 8444); + const server = new MainServer(typeof args.port !== "undefined" && parseInt(args.port, 10) || 8443, webviewServer, args); + const [webviewAddress, serverAddress] = await Promise.all([ + webviewServer.listen(), + server.listen() + ]); + console.log(`Main server serving ${serverAddress}`); + console.log(`Webview server serving ${webviewAddress}`); }; main().catch((error) => { diff --git a/server.ts b/server.ts index 2d4f9699..cc5d2a09 100644 --- a/server.ts +++ b/server.ts @@ -56,7 +56,9 @@ export abstract class Server { // The underlying web server. protected readonly server: http.Server; - public constructor() { + private listenPromise: Promise | undefined; + + public constructor(private readonly port: number) { this.server = http.createServer(async (request, response): Promise => { try { if (request.method !== "GET") { @@ -89,14 +91,19 @@ export abstract class Server { requestPath: string, ): Promise<[string | Buffer, http.OutgoingHttpHeaders]>; - public listen(port: number): Promise { - return new Promise((resolve, reject) => { - this.server.on("error", reject); - this.server.listen(port, resolve); - }); + public listen(): Promise { + if (!this.listenPromise) { + this.listenPromise = new Promise((resolve, reject) => { + this.server.on("error", reject); + this.server.listen(this.port, () => { + resolve(this.address()); + }); + }); + } + return this.listenPromise; } - public get address(): string { + public address(): string { const address = this.server.address(); const endpoint = typeof address !== "string" ? ((address.address === "::" ? "localhost" : address.address) + ":" + address.port) @@ -121,8 +128,8 @@ export class MainServer extends Server { private readonly services = new ServiceCollection(); - public constructor(private readonly webviewServer: WebviewServer, args: ParsedArgs) { - super(); + public constructor(port: number, private readonly webviewServer: WebviewServer, args: ParsedArgs) { + super(port); this.server.on("upgrade", async (request, socket) => { const protocol = this.createProtocol(request, socket); @@ -175,7 +182,7 @@ export class MainServer extends Server { const remoteAuthority = request.headers.host as string; const transformer = getUriTransformer(remoteAuthority); - const webviewEndpoint = this.webviewServer.address; + const webviewEndpoint = await this.webviewServer.listen(); const cwd = process.env.VSCODE_CWD || process.cwd(); const workspacePath = parsedUrl.query.workspace as string | undefined;