import { PoolClient, Pool, escapeIdentifier } from 'pg';
import { Repositories } from './Repositories';
import { InvalidUpdateValueError } from '@squiz/dx-common-lib';

/** Read side of a repository over records of type `T`. */
export interface Reader<T> {
  find(item: Partial<T>): Promise<T[]>;
  findOne(id: string | Partial<T>): Promise<T | undefined>;
}

/** Write side of a repository over records of type `T`. */
export interface Writer<T> {
  create(value: Partial<T>): Promise<T>;
  update(where: Partial<T>, newValue: Partial<T>): Promise<T[]>;
  delete(where: Partial<T>): Promise<number>;
}

/** A repository that can both read and write records of type `T`. */
export type Repository<T> = Reader<T> & Writer<T>;

/** One page of results together with the total number of matching rows. */
export type PageResult<T> = {
  items: T[];
  totalCount: number;
  pageSize: number;
};

export type SortDirection = 'desc' | 'asc';

/** Page size used by `getPageRaw` when the caller passes `null` (or nothing). */
export const DEFAULT_PAGE_SIZE = 20;

/**
 * Base class for PostgreSQL-backed repositories that map a model shape (`SHAPE`) onto a single
 * table, using a model-property → column-name mapping supplied by the subclass.
 *
 * All values are sent as bound parameters (`$1`, `$2`, …); identifiers are quoted with
 * `escapeIdentifier`. Rows read back are converted into model instances via `classRef`.
 * Methods that take a `transactionClient` run on it; otherwise a client is checked out of
 * `pool` for the query and released afterwards.
 */
export abstract class AbstractRepository<SHAPE extends object, DATA_CLASS extends SHAPE>
  implements Reader<SHAPE>, Writer<SHAPE>
{
  /** The table name, already quoted so it can be interpolated into SQL. */
  protected tableName: string;
  /** object where the key is the model property name and the value is the sql column name */
  protected modelPropertyToSqlColumn: { [key in keyof SHAPE]: string };
  /** object where the key is the sql column name and the value is the model property name */
  protected sqlColumnToModelProperty: { [key: string]: string };

  /**
   * @param repositories - Exposed to subclasses as `this.repositories`; not used by this base class.
   * @param pool - Pool from which clients are checked out when no transaction client is given.
   * @param tableName - Unquoted table name; it is quoted here.
   * @param mapping - Model property name → SQL column name for every property of `SHAPE`.
   * @param classRef - Constructor used to turn raw data into model instances.
   */
  constructor(
    protected repositories: Repositories,
    protected pool: Pool,
    tableName: string,
    mapping: { [key in keyof SHAPE]: string },
    protected classRef: { new (data?: Record<string, unknown>): DATA_CLASS },
  ) {
    this.tableName = escapeIdentifier(tableName);
    this.modelPropertyToSqlColumn = mapping;

    // Build the reverse lookup used when hydrating rows back into models.
    const columnToProperty: { [key: string]: string } = {};
    for (const [modelProp, columnName] of Object.entries(mapping) as [string, string][]) {
      columnToProperty[columnName] = modelProp;
    }
    this.sqlColumnToModelProperty = columnToProperty;
  }

  /** Checks a client out of the pool; the caller must release it. */
  protected async getConnection(): Promise<PoolClient> {
    return await this.pool.connect();
  }

  /** Returns the SQL column mapped to `property`, or `undefined` when it has no (own) mapping. */
  private lookupMappedColumn(property: PropertyKey): string | undefined {
    const column: unknown = this.modelPropertyToSqlColumn[property as keyof SHAPE];
    // The typeof test also rejects inherited Object.prototype members such as "constructor".
    return typeof column === 'string' && column.length > 0 ? column : undefined;
  }

  /**
   * Returns the SQL column mapped to `property`.
   * @throws Error when the property has no column mapping.
   */
  private requireMappedColumn(property: PropertyKey): string {
    const column = this.lookupMappedColumn(property);
    if (column === undefined) {
      throw new Error(`Property "${String(property)}" is not mapped to a column of table ${this.tableName}`);
    }
    return column;
  }

  /**
   * Turns a (partial) model into parallel arrays for a parameterised INSERT/UPDATE:
   * quoted column names, `$n` placeholders numbered from 1, and the values to bind.
   * Properties without a column mapping are skipped.
   */
  private sanitiseValue(value: Partial<SHAPE>) {
    const mapped: [string, unknown][] = [];
    for (const [property, propertyValue] of Object.entries(value)) {
      const column = this.lookupMappedColumn(property);
      if (column !== undefined) {
        mapped.push([column, propertyValue]);
      }
    }
    return {
      columns: mapped.map(([column]) => escapeIdentifier(column)),
      bindingParams: mapped.map((_, index) => `$${index + 1}`),
      values: mapped.map(([, boundValue]) => boundValue),
    };
  }

  /**
   * Inserts a row built from `value` and returns the inserted row as a model.
   *
   * `value` is first passed through `classRef`, so what gets inserted is the data-class
   * instance's own enumerable properties (restricted to those with a column mapping).
   */
  async create(value: SHAPE, transactionClient: PoolClient | null = null): Promise<SHAPE> {
    const valueAsClass = new this.classRef(value as Record<string, unknown>);
    const { columns, bindingParams, values } = this.sanitiseValue(valueAsClass);
    const result = await this.executeQuery(
      `INSERT INTO ${this.tableName} (${columns.join(', ')}) VALUES (${bindingParams.join(', ')}) RETURNING *`,
      values,
      transactionClient,
    );
    return result[0];
  }

  /**
   * Updates every row matching `where` (equality on each property, AND-ed) with `newValue`.
   *
   * @returns The updated rows as models.
   * @throws InvalidUpdateValueError when `newValue` has no properties with a column mapping.
   * @throws Error when `where` is empty or references an unmapped property.
   */
  async update(
    where: Partial<SHAPE>,
    newValue: Exclude<Partial<SHAPE>, Record<string, never>>,
    transactionClient: PoolClient | null = null,
  ): Promise<SHAPE[]> {
    const { columns, bindingParams, values } = this.sanitiseValue(newValue);
    // Checked after mapping: an object whose keys are all unmapped would otherwise yield `SET  WHERE`.
    if (columns.length === 0) {
      throw new InvalidUpdateValueError('Failed updating the repository, update values cannot be empty');
    }
    const setClause = columns.map((column, i) => `${column} = ${bindingParams[i]}`).join(', ');
    // WHERE placeholders continue numbering after the SET placeholders.
    const whereString = this.createWhereStringFromPartialModel(where, columns.length);
    return this.executeQuery(
      `UPDATE ${this.tableName} SET ${setClause} WHERE ${whereString} RETURNING *`,
      [...values, ...Object.values(where)],
      transactionClient,
    );
  }

  /**
   * Deletes every row matching `where` (equality on each property, AND-ed).
   *
   * @returns The number of deleted rows (0 if the driver reports none).
   * @throws Error when `where` is empty or references an unmapped property.
   */
  async delete(where: Partial<SHAPE>, transactionClient: PoolClient | null = null): Promise<number> {
    // Build the clause before acquiring a connection so invalid input never touches the pool.
    const whereString = this.createWhereStringFromPartialModel(where);
    const client = transactionClient ?? (await this.getConnection());
    try {
      const result = await client.query(`DELETE FROM ${this.tableName} WHERE ${whereString}`, Object.values(where));
      return result.rowCount ?? 0;
    } finally {
      if (!transactionClient) {
        client.release();
      }
    }
  }

  /**
   * Builds `"col1" = $n AND "col2" = $n+1 …` for the keys of `values`, in `Object.keys` order.
   * Callers must bind `Object.values(values)` (same order) starting at placeholder `initialIndex + 1`.
   *
   * @param initialIndex - Number of placeholders already used earlier in the statement.
   * @throws Error when `values` has no keys or a key has no column mapping.
   */
  protected createWhereStringFromPartialModel(values: Partial<SHAPE>, initialIndex = 0): string {
    const keys = Object.keys(values);
    if (keys.length === 0) {
      throw new Error(`Values cannot be an empty object. It must have at least one property`);
    }
    return keys
      .map((key, index) => `${escapeIdentifier(this.requireMappedColumn(key))} = $${initialIndex + index + 1}`)
      .join(' AND ');
  }

  /**
   * Runs `query` with bound `values` and returns the raw result rows.
   * Uses `transactionClient` when given; otherwise checks out a pooled client and always releases it.
   */
  protected async executeQueryRaw(
    query: string,
    values: any[],
    transactionClient: PoolClient | null = null,
  ): Promise<any[]> {
    const client = transactionClient ?? (await this.getConnection());
    try {
      const result = await client.query(query, values);
      return result.rows;
    } finally {
      if (!transactionClient) {
        client.release();
      }
    }
  }

  /** Runs `query` and converts each returned row into a model instance. */
  async executeQuery(query: string, values: any[], transactionClient: PoolClient | null = null): Promise<SHAPE[]> {
    const rows = await this.executeQueryRaw(query, values, transactionClient);
    return rows.map((row) => this.createAndHydrateModel(row));
  }

  /**
   * Renames a row's SQL columns to model property names and constructs a model via `classRef`.
   * Columns that have no mapping (e.g. extra columns from a custom query) are ignored.
   */
  protected createAndHydrateModel(row: any): SHAPE {
    const inputData: Record<string, unknown> = {};
    for (const [column, value] of Object.entries(row)) {
      const property = this.sqlColumnToModelProperty[column];
      // Skip unmapped columns; previously they were all written under the key "undefined".
      if (typeof property !== 'string') {
        continue;
      }
      inputData[property] = value;
    }
    return new this.classRef(inputData);
  }

  /** Returns the first row matching `item` (equality on each property), or `undefined` if none. */
  async findOne(item: Partial<SHAPE>, transactionClient?: PoolClient): Promise<SHAPE | undefined> {
    const rows = await this.executeQuery(
      `SELECT * FROM ${this.tableName} WHERE ${this.createWhereStringFromPartialModel(item)} LIMIT 1`,
      Object.values(item),
      transactionClient,
    );
    return rows[0];
  }

  /** Returns every row matching `item` (equality on each property, AND-ed). */
  async find(item: Partial<SHAPE>): Promise<SHAPE[]> {
    return this.executeQuery(
      `SELECT * FROM ${this.tableName} WHERE ${this.createWhereStringFromPartialModel(item)}`,
      Object.values(item),
    );
  }

  /** Returns every row in the table. */
  async findAll(): Promise<SHAPE[]> {
    return this.executeQuery(`SELECT * FROM ${this.tableName}`, []);
  }

  /** Counts rows matching `item`, or all rows when `item` is `null`. */
  async getCount(item: Partial<SHAPE> | null = null): Promise<number> {
    if (item === null) {
      return this.getCountRaw('', []);
    }
    return this.getCountRaw(`WHERE ${this.createWhereStringFromPartialModel(item)}`, Object.values(item));
  }

  /**
   * Returns one page of rows, optionally filtered by `item` (equality on each property).
   * The table is referenced with itself as the alias (`tableRef` = the quoted table name).
   */
  async getPage(
    pageNumber: number,
    sortBy: (keyof SHAPE)[] = [],
    direction: SortDirection = 'asc',
    pageSize: number | null = null,
    item: Partial<SHAPE> | null = null,
  ): Promise<PageResult<SHAPE>> {
    const whereClause = item ? `WHERE ${this.createWhereStringFromPartialModel(item)}` : '';
    return this.getPageRaw(
      pageNumber,
      sortBy,
      direction,
      whereClause,
      this.tableName,
      Object.values(item ?? {}),
      pageSize,
    );
  }

  /**
   * Counts rows using a caller-supplied `whereClause` (including the `WHERE` keyword, or '')
   * and optional `tableRef` placed after the table name.
   */
  async getCountRaw(whereClause: string = '', values: any[] = [], tableRef: string = ''): Promise<number> {
    const rows = await this.executeQueryRaw(
      `SELECT COUNT(*) FROM ${this.tableName} ${tableRef} ${whereClause}`,
      values,
    );
    // COUNT(*) is a bigint; node-postgres returns int8 as a string unless a type parser is configured.
    return parseInt(rows[0].count, 10);
  }

  /**
   * Returns one page of rows plus the total match count.
   *
   * @param pageNumber - 1-based page number.
   * @param sortBy - Model properties to sort by; `direction` applies to each of them.
   * @param whereClause - Either '' or a clause starting with `WHERE`, whose placeholders are bound from `values`.
   * @param tableRef - Text placed right after the table name (e.g. an alias).
   * @param values - Bound values for `whereClause`; not mutated.
   * @param pageSize - Rows per page; `null` means {@link DEFAULT_PAGE_SIZE}.
   * @param searchFields - Optional string properties matched with `LIKE '%value%'`, OR-ed together
   *   and AND-ed onto `whereClause`.
   * @throws Error on a page number/size below 1 (or not a number), an invalid direction,
   *   an unmapped sort/search property, or a non-string search value.
   */
  async getPageRaw(
    pageNumber: number,
    sortBy: (keyof SHAPE)[] = [],
    direction: SortDirection = 'asc',
    whereClause: string = '',
    tableRef: string = '',
    values: any[] = [],
    pageSize: number | null = null,
    searchFields: Partial<SHAPE> | null = null,
  ): Promise<PageResult<SHAPE>> {
    const size = pageSize ?? DEFAULT_PAGE_SIZE;
    // `!(x >= 1)` rather than `x <= 0` so NaN (or any non-numeric value that slipped past the
    // type system) is rejected too: both numbers are interpolated into the SQL text below.
    if (!(pageNumber >= 1)) {
      throw new Error(`Page number value cannot be less than 1`);
    }
    if (!(size >= 1)) {
      throw new Error(`Page size value cannot be less than 1`);
    }

    // Work on a copy so appending search parameters never mutates the caller's array.
    const queryValues = [...values];
    const orderByClause = this.buildOrderByClause(sortBy, direction);
    const filterClause = this.buildSearchWhereClause(whereClause, searchFields, queryValues);
    const offset = (pageNumber - 1) * size;

    const items = await this.executeQuery(
      `SELECT * FROM ${this.tableName} ${tableRef} ${filterClause} ${orderByClause} OFFSET ${offset} LIMIT ${size}`,
      queryValues,
    );
    return {
      items,
      totalCount: await this.getCountRaw(filterClause, queryValues, tableRef),
      pageSize: size,
    };
  }

  /**
   * Builds the ORDER BY clause for a page query ('' when `sortBy` is empty), applying
   * `direction` to every sort column rather than only the last one.
   */
  private buildOrderByClause(sortBy: (keyof SHAPE)[], direction: SortDirection): string {
    if (sortBy.length === 0) {
      return '';
    }
    // `direction` ends up in the SQL text, so whitelist it at runtime instead of trusting the static type.
    const normalisedDirection = String(direction).toLowerCase();
    if (normalisedDirection !== 'asc' && normalisedDirection !== 'desc') {
      throw new Error(`Sort direction must be 'asc' or 'desc'`);
    }
    const terms = sortBy.map((key) => `${escapeIdentifier(this.requireMappedColumn(key))} ${normalisedDirection}`);
    return `ORDER BY ${terms.join(', ')}`;
  }

  /**
   * Returns `whereClause` extended with an OR-group of `LIKE '%value%'` conditions for
   * `searchFields`, appending each bound pattern to `values` (which the caller owns).
   * Returns `whereClause` unchanged when there are no search fields.
   */
  private buildSearchWhereClause(
    whereClause: string,
    searchFields: Partial<SHAPE> | null,
    values: unknown[],
  ): string {
    if (searchFields === null) {
      return whereClause;
    }
    const conditions: string[] = [];
    for (const [key, value] of Object.entries(searchFields)) {
      if (typeof value !== 'string') {
        throw new Error(`Search field ${key} needs to be of type string`);
      }
      const column = escapeIdentifier(this.requireMappedColumn(key));
      values.push(`%${value}%`);
      conditions.push(`${column} LIKE $${values.length}`);
    }
    if (conditions.length === 0) {
      return whereClause;
    }
    const searchGroup = `(${conditions.join(' OR ')})`;
    // whereClause is expected to be '' or to start with WHERE (as getPage builds it).
    return whereClause.trim().length > 0 ? `${whereClause} AND ${searchGroup}` : `WHERE ${searchGroup}`;
  }
}