import { HfInference } from "@huggingface/inference";
import { Embeddings, EmbeddingsParams } from "./base.js";
import { getEnvironmentVariable } from "../util/env.js";

/**
 * Constructor options for {@link HuggingFaceInferenceEmbeddings}.
 */
export interface HuggingFaceInferenceEmbeddingsParams extends EmbeddingsParams {
  /**
   * Hugging Face API key. When omitted, the value of the
   * `HUGGINGFACEHUB_API_KEY` environment variable is used.
   */
  apiKey?: string;
  /**
   * Model id passed to `featureExtraction`. Defaults to
   * `"sentence-transformers/distilbert-base-nli-mean-tokens"`.
   */
  model?: string;
}

/**
 * Embeddings implementation that calls the Hugging Face Inference API
 * `featureExtraction` task through an `HfInference` client.
 */
export class HuggingFaceInferenceEmbeddings
  extends Embeddings
  implements HuggingFaceInferenceEmbeddingsParams
{
  apiKey?: string;

  model: string;

  client: HfInference;

  /**
   * @param fields - Optional configuration; forwarded to the base
   *   `Embeddings` constructor (an empty object is passed when omitted).
   */
  constructor(fields?: HuggingFaceInferenceEmbeddingsParams) {
    super(fields ?? {});
    this.model =
      fields?.model ?? "sentence-transformers/distilbert-base-nli-mean-tokens";
    this.apiKey =
      fields?.apiKey ?? getEnvironmentVariable("HUGGINGFACEHUB_API_KEY");
    this.client = new HfInference(this.apiKey);
  }

  /**
   * Embeds a batch of texts with a single `featureExtraction` request,
   * routed through `this.caller` (inherited from the base class).
   *
   * @param texts - Texts to embed. Newlines are replaced with spaces first.
   * @returns One embedding vector per input text; `[]` for empty input
   *   (no request is made in that case).
   */
  async _embed(texts: string[]): Promise<number[][]> {
    // Avoid a pointless remote call for an empty batch.
    if (texts.length === 0) {
      return [];
    }
    // Replace newlines, which can negatively affect performance.
    const clean = texts.map((text) => text.replace(/\n/g, " "));
    // NOTE(review): this cast assumes the model returns one pooled vector per
    // input (true for sentence-transformers models); models without pooling
    // may return a different nesting — confirm for non-default models.
    return this.caller.call(() =>
      this.client.featureExtraction({
        model: this.model,
        inputs: clean,
      })
    ) as Promise<number[][]>;
  }

  /**
   * Embeds a single query string.
   *
   * @param document - The query text.
   * @returns The embedding vector for `document`.
   * @throws Error if the API returned no embedding for the query.
   */
  async embedQuery(document: string): Promise<number[]> {
    const [embedding] = await this._embed([document]);
    if (embedding === undefined) {
      throw new Error("HuggingFaceInferenceEmbeddings: no embedding returned for query");
    }
    return embedding;
  }

  /**
   * Embeds a list of documents in one batched request.
   *
   * @param documents - Texts to embed.
   * @returns One embedding vector per document, in input order.
   */
  embedDocuments(documents: string[]): Promise<number[][]> {
    return this._embed(documents);
  }
}