import {getRequestTypeName, REQUEST_METHOD} from "@gongt/ts-stl-library/request/request-method";
import {escapeRegExp} from "@gongt/ts-stl-library/strings/escape-regexp";
import {NextFunction, Request, RequestHandler, Response, Router} from "express-serve-static-core";

/**
 * Shape of a host name accepted by `allowHosts()`: lowercase, dot-separated
 * labels made of letters, digits and inner hyphens (the last label may also
 * contain inner underscores). No scheme, port or path is allowed.
 */
const validHostname = /^(([a-z\d]|[a-z\d][a-z\d\-]*[a-z\d])\.)*([a-z\d]|[a-z\d][a-z\d\-_]*[a-z\d])$/;

/**
 * Fluent builder for an Express middleware that emits CORS response headers.
 *
 * Configure it with `allowHosts()`, `allowMethods()`, `exposeHeaders()` and
 * `allowCredentials()`, then obtain the handler with `getMiddleware()` or
 * attach it directly with `mount()`. The configuration is captured when
 * `getMiddleware()` is called; later changes do not affect handlers that were
 * already created.
 */
export class CrossDomainMiddleware {
	/** Method names sent in `Access-Control-Allow-Methods` (empty means `*`). */
	private methods: string[] = [];
	/** Host names whose origins may be echoed in `Access-Control-Allow-Origin`. */
	private hosts: string[] = [];
	/** Value sent in `Access-Control-Allow-Credentials`: `"true"` or `"false"`. */
	private credentials: string = "true";
	/** Header names sent in `Access-Control-Expose-Headers`. */
	private headers: string[] = [];

	public constructor() {
	}

	/**
	 * Build the request handler from the current configuration.
	 *
	 * Origin policy:
	 * - no allowed hosts and credentials disabled: `Access-Control-Allow-Origin: *`;
	 * - no allowed hosts and credentials enabled: the request's `Origin` is echoed
	 *   back (a wildcard is not usable together with credentials);
	 * - allowed hosts configured: the `Origin` is echoed only when its host is one
	 *   of the allowed hosts or a subdomain of one, with any scheme and port.
	 *
	 * The handler always calls `next()`; it never ends the response itself.
	 *
	 * @returns An Express-compatible request handler.
	 */
	public getMiddleware(): RequestHandler {
		const methodString = this.methods.join(', ') || '*';
		const exposeHeaders = this.headers.join(', ');

		const standardHeader: {[name: string]: string} = {
			'Access-Control-Max-Age': '86400',
			'Access-Control-Allow-Credentials': this.credentials,
			'Access-Control-Allow-Methods': methodString,
		};
		if (exposeHeaders) {
			standardHeader['Access-Control-Expose-Headers'] = exposeHeaders;
		}

		// `credentials` is the string "true"/"false"; both are truthy, so it has
		// to be compared explicitly instead of being used as a boolean.
		const credentialsAllowed = this.credentials === 'true';
		const restrictHosts = this.hosts.length > 0;
		const mustEchoOrigin = restrictHosts || credentialsAllowed;
		if (!mustEchoOrigin) {
			standardHeader['Access-Control-Allow-Origin'] = '*';
		}

		// Anchored on both ends so that look-alike origins such as
		// "https://example.com.evil.net" or "https://notexample.com" cannot match
		// an allowed "example.com". Subdomains of an allowed host still match.
		const originPattern = restrictHosts
			? new RegExp(
				`^[a-z][a-z\\d+.\\-]*://(?:[a-z\\d_\\-]+\\.)*(?:${this.hosts.map(escapeRegExp).join('|')})(?::\\d+)?$`,
			)
			: null;

		return function corsHandler(req: Request, res: Response, next: NextFunction) {
			if (mustEchoOrigin) {
				const origin = req.header('origin');
				// Requests without an Origin header get no Allow-Origin header at all
				// (previously the literal string "undefined" could be sent).
				if (origin && (originPattern === null || originPattern.test(origin))) {
					res.header('Access-Control-Allow-Origin', origin);
				}
			}

			res.header(standardHeader);

			// Allow whatever request headers the preflight asks for.
			const reqHeaders = req.header('Access-Control-Request-Headers');
			if (reqHeaders) {
				res.header('Access-Control-Allow-Headers', reqHeaders);
			}

			// The response depends on these request headers, so caches must key on them.
			res.vary('origin')
				.vary('Access-Control-Request-Headers');

			next();
		};
	}

	/**
	 * Enable or disable credentialed cross-origin requests (enabled by default).
	 *
	 * @param allow `true` to send `Access-Control-Allow-Credentials: true`,
	 *              `false` to send `false`.
	 * @returns This instance, for chaining.
	 */
	public allowCredentials(allow: boolean) {
		this.credentials = allow ? 'true' : 'false';
		return this;
	}

	/**
	 * Add response header names to `Access-Control-Expose-Headers`.
	 *
	 * @param headers Header names to append to the exposed list.
	 * @returns This instance, for chaining.
	 */
	public exposeHeaders(...headers: string[]) {
		this.headers = this.headers.concat(headers);
		return this;
	}

	/**
	 * Add HTTP methods to `Access-Control-Allow-Methods`. `OPTIONS` is added
	 * automatically if it is not registered yet; duplicates are ignored.
	 *
	 * @param methods Methods to allow.
	 * @returns This instance, for chaining.
	 * @throws TypeError If `getRequestTypeName` yields no name for a method.
	 */
	public allowMethods(...methods: REQUEST_METHOD[]) {
		if (this.methods.indexOf(getRequestTypeName(REQUEST_METHOD.OPTIONS)) === -1) {
			methods.unshift(REQUEST_METHOD.OPTIONS);
		}
		methods.forEach(e => {
			const name = getRequestTypeName(e);
			if (!name) {
				throw new TypeError("unknown method: " + e);
			}
			if (this.methods.indexOf(name) !== -1) {
				return;
			}
			return this.methods.push(name);
		});
		return this;
	}

	/**
	 * Add host names whose origins are allowed. Each entry must be a bare,
	 * lowercase host name (see `validHostname`), without scheme or port.
	 *
	 * @param hosts Host names to allow.
	 * @returns This instance, for chaining.
	 * @throws TypeError On the first entry that is not a valid host name; entries
	 *                   before it have already been added.
	 */
	public allowHosts(...hosts: string[]) {
		hosts.forEach(e => {
			if (!validHostname.test(e)) {
				throw new TypeError("invalid hostname: " + e);
			}
			return this.hosts.push(e);
		});
		return this;
	}

	/**
	 * Attach a freshly built middleware to a router.
	 *
	 * @param router Router (or app) to attach to.
	 * @param path   Optional mount path; when omitted the middleware is attached
	 *               without a path.
	 * @returns This instance, for chaining.
	 */
	public mount(router: Router, path?: string) {
		if (path) {
			router.use(path, this.getMiddleware());
		} else {
			router.use(this.getMiddleware());
		}
		return this;
	}
}