/**
 * @license
 * Copyright 2022, JsData. All rights reserved.
 *
 * This source code is licensed under the MIT license found in the
 * LICENSE file in the root directory of this source tree.
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 * ==========================================================================
 */
import { CrossValidator } from './CrossValidator';
import { Scikit1D, Scikit2D, Tensor1D } from '../types';

/**
 * Options accepted by the {@link KFold} constructor. Every field is optional.
 */
export interface KFoldParams {
  /**
   * Number of ways in which the dataset is to be split.
   * Defaults to `5`. Larger numbers of splits will result
   * in increased computational cost, since the model is
   * trained `nSplits` times during cross validation. In
   * return, more training data is available in each split,
   * which allows smaller datasets to be used more efficiently.
   */
  nSplits?: number;
  /**
   * If set to `true`, indices are shuffled before train/test
   * splitting. Defaults to `false`.
   */
  shuffle?: boolean;
  /**
   * Random seed to be used for shuffling. Ignored if `shuffle` is `false`.
   */
  randomState?: number;
}

/**
 * K-Fold cross-validator
 *
 * Generates train and test indices to split data into train/test subsets.
 * To generate these subsets, the dataset is split into k (about) evenly
 * sized chunks of consecutive elements. Each split takes another chunk
 * as test data and the remaining chunks are combined to be the training
 * data.
 *
 * Optionally, the indices can be shuffled before splitting them into chunks
 * (disabled by default).
 *
 * @example
 * ```ts
 * import { KFold } from 'scikitjs'
 *
 * const kf = new KFold({ nSplits: 3 })
 *
 * const X = tf.range(0, 7).reshape([7, 1]) as Tensor2D
 *
 * console.log('nSplits:', kf.getNumSplits())
 *
 * for (const { trainIndex, testIndex } of kf.split(X)) {
 *   try {
 *     console.log('train:', trainIndex.toString())
 *     console.log('test:', testIndex.toString())
 *   } finally {
 *     trainIndex.dispose()
 *     testIndex.dispose()
 *   }
 * }
 * ```
 */
export declare class KFold implements CrossValidator {
  /** Number of splits; see {@link KFoldParams.nSplits}. */
  nSplits: number;
  /** Whether indices are shuffled before splitting; see {@link KFoldParams.shuffle}. */
  shuffle: boolean;
  /** Seed used for shuffling; see {@link KFoldParams.randomState}. */
  randomState?: number;
  /**
   * NOTE(review): presumably an identifying name for this cross-validator.
   * The value is assigned by the implementation, which is not visible in
   * this declaration - confirm there.
   */
  name: string;
  /**
   * NOTE(review): presumably the TensorFlow.js module used to build the
   * index tensors - confirm against the implementation. Typed as `any`
   * here; narrowing it would change the public declaration.
   */
  tf: any;
  /**
   * Creates a K-Fold cross-validator.
   *
   * The options object and each of its fields may be omitted; see
   * {@link KFoldParams} for the meaning and documented default of each option.
   */
  constructor({ nSplits, shuffle, randomState }?: KFoldParams);
  /**
   * Returns the number of splits. Takes no arguments.
   *
   * NOTE(review): presumably this is the configured `nSplits` - confirm
   * against the implementation.
   *
   * @returns The number of train/test splits.
   */
  getNumSplits(): number;
  /**
   * Iterates over the train/test splits of a dataset.
   *
   * @param X - The dataset to be split.
   * @param y - NOTE(review): how this argument is used is not visible in this
   *   declaration; presumably accepted for compatibility with
   *   `CrossValidator` - confirm.
   * @param groups - NOTE(review): same caveat as for `y`.
   * @returns An iterator yielding one `{ trainIndex, testIndex }` pair of
   *   index tensors per split. The class-level example disposes both tensors
   *   once it is done with them.
   */
  split(X: Scikit2D, y?: Scikit1D, groups?: Scikit1D): IterableIterator<{
    trainIndex: Tensor1D;
    testIndex: Tensor1D;
  }>;
}