import type { CredentialsParams, RepoDesignation } from "../types/public";
import { omit } from "../utils/omit";
import { toRepoId } from "../utils/toRepoId";
import { typedEntries } from "../utils/typedEntries";
import { downloadFile } from "./download-file";
import { fileExists } from "./file-exists";
import { promisesQueue } from "../utils/promisesQueue";
import type { SetRequired } from "../vendor/type-fest/set-required";

export const SAFETENSORS_FILE = "model.safetensors";
export const SAFETENSORS_INDEX_FILE = "model.safetensors.index.json";
/// We advise model/library authors to use the filenames above for convention inside model repos,
/// but in some situations safetensors weights have different filenames.
export const RE_SAFETENSORS_FILE = /\.safetensors$/;
export const RE_SAFETENSORS_INDEX_FILE = /\.safetensors\.index\.json$/;
/// Matches shard filenames such as `model-00001-of-00005.safetensors`.
export const RE_SAFETENSORS_SHARD_FILE =
	/^(?<prefix>(?<basePrefix>.*?)[_-])(?<shard>\d{5,6})-of-(?<total>\d{5,6})\.safetensors$/;

/** Named groups captured by {@link RE_SAFETENSORS_SHARD_FILE}. */
export interface SafetensorsShardFileInfo {
	/** Everything before the shard number, including the trailing `_` or `-` separator. */
	prefix: string;
	/** `prefix` without its trailing separator. */
	basePrefix: string;
	/** Zero-padded shard number, kept as a string. */
	shard: string;
	/** Zero-padded shard count, kept as a string. */
	total: string;
}

/**
 * Splits a shard filename into its components.
 *
 * @param filename - File name (without directory) to test against {@link RE_SAFETENSORS_SHARD_FILE}.
 * @returns The captured groups, or `null` when `filename` is not a shard filename.
 */
export function parseSafetensorsShardFilename(filename: string): SafetensorsShardFileInfo | null {
	const match = RE_SAFETENSORS_SHARD_FILE.exec(filename);
	if (match && match.groups) {
		return {
			prefix: match.groups["prefix"],
			basePrefix: match.groups["basePrefix"],
			shard: match.groups["shard"],
			total: match.groups["total"],
		};
	}
	return null;
}

const PARALLEL_DOWNLOADS = 20;
const MAX_HEADER_LENGTH = 25_000_000; // 25MB
const MAX_CONFIG_LENGTH = 10_000_000; // 10MB — config.json is typically small; cap to avoid large memory use
const MAX_SHARD_COUNT = 10_000; // well above any real sharded model; blocks crafted index with millions of entries
const GPTQ_QWEIGHT_SUFFIX = "qweight";
const GPTQ_AWQ_AUXILIARY_SUFFIXES = ["qzeros", "g_idx", "scales"];

/** Raised when a safetensors file or index cannot be fetched or is structurally invalid. */
class SafetensorParseError extends Error {}

type FileName = string;

export type TensorName = string;
export type Dtype =
	| "F64"
	| "F32"
	| "F16"
	| "F8_E4M3"
	| "F8_E5M2"
	| "E8M0"
	| "F6_E3M2"
	| "F6_E2M3"
	| "F4"
	| "FP4"
	| "BF16"
	| "I64"
	| "I32"
	| "I16"
	| "I8"
	| "U16"
	| "U8"
	| "UE8"
	| "BOOL";

export interface TensorInfo {
	dtype: Dtype;
	shape: number[];
	data_offsets: [number, number];
}

export type SafetensorsFileHeader = Record<TensorName, TensorInfo> & {
	__metadata__: { total_parameters?: string | number } & Record<string, string>;
};

export interface SafetensorsIndexJson {
	dtype?: string;
	/// ^there's sometimes a dtype but it looks inconsistent.
	metadata?: { total_parameters?: string | number } & Record<string, string>;
	/// ^ why the naming inconsistency?
	weight_map: Record<TensorName, FileName>;
}

export type SafetensorsShardedHeaders = Record<FileName, SafetensorsFileHeader>;

export type SafetensorsParseFromRepo =
	| {
			sharded: false;
			header: SafetensorsFileHeader;
			parameterCount?: Partial<Record<Dtype, number>>;
			parameterTotal?: number;
			filepaths: string[];
	  }
	| {
			sharded: true;
			index: SafetensorsIndexJson;
			headers: SafetensorsShardedHeaders;
			parameterCount?: Partial<Record<Dtype, number>>;
			parameterTotal?: number;
			filepaths: string[];
	  };

/** Options shared by every internal helper that downloads a file from the repo. */
type RepoFetchParams = {
	repo: RepoDesignation;
	revision?: string;
	hubUrl?: string;
	/**
	 * Custom fetch function to use instead of the default one, for example to use a proxy or edit headers.
	 */
	fetch?: typeof fetch;
} & Partial<CredentialsParams>;

/** True for plain JSON objects (not `null`, not arrays). */
function isJsonObject(value: unknown): value is Record<string, unknown> {
	return typeof value === "object" && value !== null && !Array.isArray(value);
}

/**
 * Fetches and parses model config.json
 *
 * Best-effort: any download or parse failure, or a config that is not a JSON object,
 * yields `null` so that callers can proceed without quantization information.
 */
async function fetchModelConfig(params: RepoFetchParams): Promise<ModelConfig | null> {
	try {
		const configBlob = await downloadFile({
			...params,
			path: "config.json",
		});

		if (!configBlob) {
			return null;
		}

		const config: unknown = JSON.parse(await configBlob.slice(0, MAX_CONFIG_LENGTH).text());
		return isJsonObject(config) ? (config as ModelConfig) : null;
	} catch {
		// Config file might not exist or be inaccessible
		return null;
	}
}

/**
 * Downloads the JSON header of a single safetensors file.
 *
 * The first 8 bytes are read as a little-endian unsigned 64-bit header length; the
 * header itself is the JSON document of that length that immediately follows.
 *
 * @param path - File path inside the repo.
 * @throws SafetensorParseError when the file cannot be fetched, the length prefix is
 * missing/zero/too large, or the header is not a JSON object.
 */
async function parseSingleFile(path: string, params: RepoFetchParams): Promise<SafetensorsFileHeader> {
	const blob = await downloadFile({ ...params, path });

	if (!blob) {
		throw new SafetensorParseError(`Failed to parse file ${path}: failed to fetch safetensors header length.`);
	}

	const bufLengthOfHeaderLE = await blob.slice(0, 8).arrayBuffer();
	if (bufLengthOfHeaderLE.byteLength < 8) {
		// A truncated file would otherwise surface as a RangeError from DataView.
		throw new SafetensorParseError(`Failed to parse file ${path}: safetensors header is malformed.`);
	}
	const lengthOfHeader = new DataView(bufLengthOfHeaderLE).getBigUint64(0, true);
	// ^little-endian
	if (lengthOfHeader <= 0) {
		throw new SafetensorParseError(`Failed to parse file ${path}: safetensors header is malformed.`);
	}
	if (lengthOfHeader > MAX_HEADER_LENGTH) {
		throw new SafetensorParseError(
			`Failed to parse file ${path}: safetensor header is too big. Maximum supported size is ${MAX_HEADER_LENGTH} bytes.`,
		);
	}

	let header: unknown;
	try {
		header = JSON.parse(await blob.slice(8, 8 + Number(lengthOfHeader)).text());
	} catch {
		throw new SafetensorParseError(`Failed to parse file ${path}: safetensors header is not valid JSON.`);
	}
	// Only the top-level shape is validated; individual tensor entries are checked where they are used.
	if (!isJsonObject(header)) {
		throw new SafetensorParseError(`Failed to parse file ${path}: safetensors header is malformed.`);
	}
	return header as SafetensorsFileHeader;
}

/**
 * Downloads and parses a safetensors index JSON file.
 *
 * @param path - Path of the index file inside the repo.
 * @throws SafetensorParseError when the file cannot be fetched or is not a JSON object.
 */
async function parseShardedIndex(path: string, params: RepoFetchParams): Promise<SafetensorsIndexJson> {
	const indexBlob = await downloadFile({
		...params,
		path,
	});

	if (!indexBlob) {
		throw new SafetensorParseError(`Failed to parse file ${path}: failed to fetch safetensors index.`);
	}

	let index: unknown;
	try {
		index = JSON.parse(await indexBlob.slice(0, MAX_HEADER_LENGTH).text());
	} catch {
		throw new SafetensorParseError(`Failed to parse file ${path}: not a valid JSON.`);
	}
	if (!isJsonObject(index)) {
		throw new SafetensorParseError(`Failed to parse file ${path}: not a valid JSON.`);
	}
	return index as unknown as SafetensorsIndexJson;
}

/**
 * Fetches the header of every distinct shard referenced by `index.weight_map`.
 *
 * Shard paths are resolved relative to the directory of the index file. Downloads run
 * through `promisesQueue` with a concurrency of `PARALLEL_DOWNLOADS`.
 *
 * @param path - Path of the index file inside the repo.
 * @param index - Parsed index whose `weight_map` values are shard filenames.
 * @returns A map from shard filename (as written in the index) to its parsed header.
 * @throws SafetensorParseError when `weight_map` is missing, lists more than
 * `MAX_SHARD_COUNT` files, or contains a filename that is not a plain relative path.
 */
async function fetchAllHeaders(
	path: string,
	index: SafetensorsIndexJson,
	params: RepoFetchParams,
): Promise<SafetensorsShardedHeaders> {
	const pathPrefix = path.slice(0, path.lastIndexOf("/") + 1);

	const weightMap: unknown = index.weight_map;
	if (!isJsonObject(weightMap)) {
		throw new SafetensorParseError(`Failed to parse file ${path}: safetensors index has no valid weight_map.`);
	}
	const filenames = [...new Set(Object.values(weightMap))];

	if (filenames.length > MAX_SHARD_COUNT) {
		throw new SafetensorParseError(
			`Too many shard files (${filenames.length}). Maximum supported is ${MAX_SHARD_COUNT}.`,
		);
	}

	// The index is remote, untrusted JSON: reject anything that could escape the index directory.
	const shardNames: string[] = [];
	for (const filename of filenames) {
		if (typeof filename !== "string") {
			throw new SafetensorParseError(`Failed to parse file ${path}: weight_map contains a non-string filename.`);
		}
		if (filename.includes("..") || filename.startsWith("/") || filename.includes("://")) {
			throw new SafetensorParseError(`Unsafe shard filename in weight_map: "${filename}"`);
		}
		shardNames.push(filename);
	}

	const shardedMap: SafetensorsShardedHeaders = Object.fromEntries(
		await promisesQueue(
			shardNames.map(
				(filename) => async () =>
					[filename, await parseSingleFile(pathPrefix + filename, params)] satisfies [string, SafetensorsFileHeader],
			),
			PARALLEL_DOWNLOADS,
		),
	);
	return shardedMap;
}

/**
 * Normalizes a `total_parameters` metadata value to a number.
 *
 * @returns `undefined` when the value is absent, zero/empty, or not a finite base-10 number.
 */
function parseTotalParameters(value: string | number | undefined): number | undefined {
	if (!value) return undefined;
	const parsed = typeof value === "number" ? value : parseInt(value, 10);
	return Number.isFinite(parsed) ? parsed : undefined;
}

/**
 * Analyze model.safetensors.index.json or model.safetensors from a model hosted
 * on Hugging Face using smart range requests to extract its metadata.
 */
export async function parseSafetensorsMetadata(
	params: {
		/** Only models are supported */
		repo: RepoDesignation;
		/**
		 * Relative file path to safetensors file inside `repo`. Defaults to `SAFETENSORS_FILE` or `SAFETENSORS_INDEX_FILE` (whichever one exists).
		 */
		path?: string;
		/**
		 * Will include SafetensorsParseFromRepo["parameterCount"], an object containing the number of parameters for each DType
		 *
		 * @default false
		 */
		computeParametersCount: true;
		hubUrl?: string;
		revision?: string;
		/**
		 * Custom fetch function to use instead of the default one, for example to use a proxy or edit headers.
		 */
		fetch?: typeof fetch;
	} & Partial<CredentialsParams>,
): Promise<SetRequired<SafetensorsParseFromRepo, "parameterCount">>;
export async function parseSafetensorsMetadata(
	params: {
		/** Only models are supported */
		repo: RepoDesignation;
		path?: string;
		/**
		 * Will include SafetensorsParseFromRepo["parameterCount"], an object containing the number of parameters for each DType
		 *
		 * @default false
		 */
		computeParametersCount?: boolean;
		hubUrl?: string;
		revision?: string;
		/**
		 * Custom fetch function to use instead of the default one, for example to use a proxy or edit headers.
		 */
		fetch?: typeof fetch;
	} & Partial<CredentialsParams>,
): Promise<SafetensorsParseFromRepo>;
export async function parseSafetensorsMetadata(
	params: {
		repo: RepoDesignation;
		path?: string;
		computeParametersCount?: boolean;
		hubUrl?: string;
		revision?: string;
		/**
		 * Custom fetch function to use instead of the default one, for example to use a proxy or edit headers.
		 */
		fetch?: typeof fetch;
	} & Partial<CredentialsParams>,
): Promise<SafetensorsParseFromRepo> {
	const repoId = toRepoId(params.repo);

	if (repoId.type !== "model") {
		throw new TypeError("Only model repos should contain safetensors files.");
	}

	// Fetch model config for quantization information
	const modelConfig = params.computeParametersCount ? await fetchModelConfig(params) : null;
	const quantConfig = modelConfig?.quantization_config ?? modelConfig?.text_config?.quantization_config;

	// An empty string is treated like an omitted path.
	const requestedPath = params.path || undefined;
	const explicitSingleFile = requestedPath !== undefined && RE_SAFETENSORS_FILE.test(requestedPath);
	const explicitIndexFile = requestedPath !== undefined && RE_SAFETENSORS_INDEX_FILE.test(requestedPath);

	// When the caller explicitly points at an index file, do not probe for the default
	// single file: its mere presence in the repo must not make us parse the index JSON
	// as a binary safetensors file.
	if (explicitSingleFile || (!explicitIndexFile && (await fileExists({ ...params, path: SAFETENSORS_FILE })))) {
		const path = requestedPath ?? SAFETENSORS_FILE;
		const header = await parseSingleFile(path, params);

		const paramStats = params.computeParametersCount
			? {
					parameterCount: computeNumOfParamsByDtypeSingleFile(header, quantConfig),
					/// shortcut: get param count directly from metadata
					/// (`__metadata__` is optional in the file format, hence the optional chaining)
					parameterTotal: parseTotalParameters(header.__metadata__?.total_parameters),
			  }
			: undefined;

		return {
			sharded: false,
			header,
			...paramStats,
			filepaths: [path],
		};
	} else if (explicitIndexFile || (await fileExists({ ...params, path: SAFETENSORS_INDEX_FILE }))) {
		const path = requestedPath ?? SAFETENSORS_INDEX_FILE;
		const index = await parseShardedIndex(path, params);
		const shardedMap = await fetchAllHeaders(path, index, params);
		const pathPrefix = path.slice(0, path.lastIndexOf("/") + 1);

		const paramStats = params.computeParametersCount
			? {
					parameterCount: computeNumOfParamsByDtypeSharded(shardedMap, quantConfig),
					/// shortcut: get param count directly from metadata
					parameterTotal: parseTotalParameters(index.metadata?.total_parameters),
			  }
			: undefined;

		return {
			sharded: true,
			index,
			headers: shardedMap,
			...paramStats,
			filepaths: [path, ...Object.keys(shardedMap).map((filename) => pathPrefix + filename)],
		};
	} else {
		throw new Error("model id does not seem to contain safetensors weights");
	}
}

export interface QuantizationConfig {
	quant_method?: string;
	modules_to_not_convert?: string[];
	bits?: number;
	load_in_4bit?: boolean;
	load_in_8bit?: boolean;
	// compressed-tensors specific
	format?: string;
	// NOTE(review): value shape reconstructed from the only field read in this file
	// (`weights.num_bits`) — confirm against the compressed-tensors config schema.
	config_groups?: Record<string, { weights?: { num_bits?: number } }>;
}

export interface ModelConfig {
	quantization_config?: QuantizationConfig;
	text_config?: { quantization_config?: QuantizationConfig };
}

/**
 * @internal
 * Glob match without RegExp: splits pattern on `*` and checks that each literal
 * segment appears in order within `str`. Avoids RegExp entirely (no ReDoS risk,
 * no SyntaxError from attacker-controlled patterns in config.json).
 *
 * The first segment must be a prefix of `str`, the last one a suffix, and the
 * segments may not overlap. A pattern without `*` must equal `str` exactly.
 */
export function globMatch(pattern: string, str: string): boolean {
	const parts = pattern.split("*");
	if (parts.length === 1) {
		return pattern === str;
	}

	const head = parts[0] ?? "";
	const tail = parts[parts.length - 1] ?? "";
	if (!str.startsWith(head) || !str.endsWith(tail)) return false;

	let pos = head.length;
	// Middle segments must fit between the prefix and the start of the suffix.
	const end = str.length - tail.length;
	for (const part of parts.slice(1, -1)) {
		const idx = str.indexOf(part, pos);
		if (idx === -1 || idx + part.length > end) return false;
		pos = idx + part.length;
	}
	return pos <= end;
}

/**
 * Determines if a tensor is quantized based on quantization config and tensor name.
 *
 * Python's transformers uses plain substring matching for `modules_to_not_convert`,
 * so bare names like `"lm_head"` must match `"model.lm_head.weight"`. When the
 * pattern contains a `*` we fall back to proper glob matching for flexibility.
 *
 * @returns `false` without a config; `true` when the config lists no exclusions;
 * otherwise `true` only if no exclusion pattern matches `tensorName`.
 */
export function isQuantizedTensor(tensorName: string, quantConfig?: QuantizationConfig): boolean {
	if (!quantConfig) return false;

	// config.json is untrusted: tolerate a `modules_to_not_convert` that is not a string array.
	const patterns: unknown = quantConfig.modules_to_not_convert;
	if (!Array.isArray(patterns) || patterns.length === 0) return true;

	return !patterns.some(
		(pattern: unknown) =>
			typeof pattern === "string" &&
			(pattern.includes("*") ? globMatch(pattern, tensorName) : tensorName.includes(pattern)),
	);
}

/**
 * Gets the parameter multiplier for a quantized tensor based on quantization method
 *
 * The multiplier is the number of logical parameters packed into one stored element
 * of the tensor; it is `1` for tensors that are not quantized.
 */
function getQuantizationMultiplier(tensorName: string, dtype: Dtype, quantConfig?: QuantizationConfig): number {
	if (!quantConfig || !isQuantizedTensor(tensorName, quantConfig)) {
		return 1;
	}

	const quantMethod = quantConfig.quant_method?.toLowerCase();

	switch (quantMethod) {
		case "mxfp4":
			if (dtype === "U8" && tensorName.includes("_blocks")) {
				return 2;
			}
			return 1;

		case "gptq":
		case "awq":
			if (getTensorSuffix(tensorName) === GPTQ_QWEIGHT_SUFFIX) {
				// Missing or non-positive `bits` falls back to 4.
				const bits = quantConfig.bits && quantConfig.bits > 0 ? quantConfig.bits : 4;
				return Math.max(1, Math.floor(32 / bits));
			}
			if (quantConfig.bits === 4 && dtype === "U8") {
				return 2;
			}
			if (quantConfig.bits === 2 && dtype === "U8") {
				return 4;
			}
			return 1;

		case "compressed-tensors":
			if (quantConfig.format === "pack-quantized" && dtype === "I32") {
				// Use the first config group that declares a bit width; default to 4 bits.
				const numBits =
					Object.values(quantConfig.config_groups ?? {}).find((g) => g.weights?.num_bits)?.weights?.num_bits ?? 4;
				return Math.floor(32 / numBits);
			}
			return 1;

		case "bitsandbytes":
			if (quantConfig.load_in_4bit && dtype === "U8") {
				return 2;
			}
			return 1;

		default:
			if (dtype === "U8" && (quantConfig.load_in_4bit || quantConfig.bits === 4)) {
				return 2;
			}
			return 1;
	}
}

/**
 * Sums the number of parameters per dtype for one safetensors header.
 *
 * Skipped entries: GPTQ/AWQ auxiliary tensors, tensors with an empty or malformed
 * `shape`, tensors whose element count is not finite, and tensors whose quantization
 * multiplier is `0`.
 */
function computeNumOfParamsByDtypeSingleFile(
	header: SafetensorsFileHeader,
	quantConfig?: QuantizationConfig,
): Partial<Record<Dtype, number>> {
	const counter: Partial<Record<Dtype, number>> = {};
	const tensors = omit(header, "__metadata__");

	for (const [tensorName, v] of typedEntries(tensors)) {
		if (shouldSkipTensor(tensorName, quantConfig)) {
			continue;
		}
		// The header is unvalidated remote JSON: ignore entries without a usable shape.
		if (!v || !Array.isArray(v.shape) || v.shape.length === 0) {
			continue;
		}

		const elements = v.shape.reduce((a, b) => a * b);
		if (!Number.isFinite(elements)) {
			continue;
		}
		const multiplier = quantConfig ? getQuantizationMultiplier(tensorName, v.dtype, quantConfig) : 1;
		if (multiplier === 0) {
			continue;
		}
		counter[v.dtype] = (counter[v.dtype] ?? 0) + elements * multiplier;
	}
	return counter;
}

/** Aggregates the per-dtype parameter counts of every shard header. */
function computeNumOfParamsByDtypeSharded(
	shardedMap: SafetensorsShardedHeaders,
	quantConfig?: QuantizationConfig,
): Partial<Record<Dtype, number>> {
	const counter: Partial<Record<Dtype, number>> = {};
	for (const header of Object.values(shardedMap)) {
		for (const [k, v] of typedEntries(computeNumOfParamsByDtypeSingleFile(header, quantConfig))) {
			counter[k] = (counter[k] ?? 0) + (v ?? 0);
		}
	}
	return counter;
}

/** Returns the part of `tensorName` after its last `.`, or the whole name when it has none. */
function getTensorSuffix(tensorName: string): string {
	const lastDotIndex = tensorName.lastIndexOf(".");
	return lastDotIndex === -1 ? tensorName : tensorName.slice(lastDotIndex + 1);
}

/**
 * Tells whether a tensor must be left out of the parameter count.
 *
 * Only applies to quantized tensors of GPTQ/AWQ checkpoints whose suffix is one of
 * `GPTQ_AWQ_AUXILIARY_SUFFIXES`.
 */
function shouldSkipTensor(tensorName: string, quantConfig?: QuantizationConfig): boolean {
	if (!quantConfig) return false;

	const quantMethod = quantConfig.quant_method?.toLowerCase();
	if (quantMethod !== "gptq" && quantMethod !== "awq") return false;

	if (!isQuantizedTensor(tensorName, quantConfig)) return false;

	const suffix = getTensorSuffix(tensorName);
	return suffix !== GPTQ_QWEIGHT_SUFFIX && GPTQ_AWQ_AUXILIARY_SUFFIXES.includes(suffix);
}