hoops_ai.ml.FlowTrainer
- class hoops_ai.ml.FlowTrainer(flowmodel=None, datasetLoader=None, batch_size=64, num_workers=0, experiment_name='UNKNOWN_Experiment', accelerator='cpu', devices='auto', gradient_clip_val=1.0, max_epochs=100, learning_rate=None, result_dir=None, resume_checkpoint_path=None, early_stopping=True, **trainer_kwargs)
Bases: object
Orchestrates FlowModel training, checkpointing, evaluation, and dataset purification.
This class wires a split DatasetLoader into PyTorch Lightning using the FlowModel interface methods for sample loading, collation, model retrieval, and metric persistence.
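FlowTrainer calls into the flowmodel through a fixed set of names: collate_function, load_model_input_from_files, retrieve_model, model_name, and metrics, as listed in the flowmodel parameter description. A quick pre-flight check can therefore catch an incomplete model before training starts. This is a hypothetical sketch. missing_flowmodel_attrs and PartialModel are illustrative names, not part of hoops_ai. The check only confirms that each attribute exists, not that its signature matches what FlowTrainer expects.

```python
from typing import List

# Attribute names taken from the flowmodel parameter description.
REQUIRED_FLOWMODEL_ATTRS = (
    "collate_function",
    "load_model_input_from_files",
    "retrieve_model",
    "model_name",
    "metrics",
)


def missing_flowmodel_attrs(flowmodel: object) -> List[str]:
    """Return the required attribute names that `flowmodel` does not define (hypothetical helper)."""
    return [name for name in REQUIRED_FLOWMODEL_ATTRS if not hasattr(flowmodel, name)]


class PartialModel:
    """Hypothetical stand-in that implements only part of the interface."""

    model_name = "toy"

    def collate_function(self, batch):
        return batch


missing = missing_flowmodel_attrs(PartialModel())
if missing:
    print(f"flowmodel is missing: {', '.join(missing)}")
# prints: flowmodel is missing: load_model_input_from_files, retrieve_model, metrics
```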
- Parameters:
flowmodel (FlowModel) – Initialized FlowModel instance. Must provide collate_function, load_model_input_from_files, retrieve_model, model_name, and metrics.
datasetLoader (DatasetLoader) – DatasetLoader instance that has already been split into train, validation, and test subsets.
batch_size (int) – Batch size used for train/validation/test DataLoaders.
num_workers (int) – Number of DataLoader worker processes.
experiment_name (str) – Experiment identifier used in output folder naming.
accelerator (str) – Lightning accelerator passed to Trainer (for example: cpu, gpu).
devices – Lightning devices setting (for example: auto, 1, [0, 1]).
gradient_clip_val (float) – Gradient clipping threshold forwarded to Lightning Trainer.
max_epochs (int) – Maximum number of training epochs.
learning_rate (float | None) – Trainer-level learning rate. When explicitly set, it overrides the learning rate from any resumed checkpoint. The override is applied to the optimizer param groups at train start, after Lightning restores the checkpoint's optimizer state. When None (the default), the checkpoint's saved learning rate is preserved on resume; for a fresh run, the model's own default learning rate is used.
result_dir (str) – Root output directory where training artifacts are written.
resume_checkpoint_path (str | None) – Optional default checkpoint path for train(). It is used when train() is called without its own resume_checkpoint_path argument.
early_stopping (dict[str, Any] | bool | None) – Early stopping behavior:
  - True (default): enable early stopping with default settings.
  - False or None: disable early stopping.
  - dict: enable early stopping and override the defaults with any arguments accepted by EarlyStopping, plus reset_patience_on_resume.
**trainer_kwargs – Additional keyword arguments passed directly into Lightning Trainer. Provided callbacks are appended to built-in callbacks.
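The early_stopping and trainer_kwargs options can be pictured as a small normalization step that runs before the Lightning Trainer is built. This is a minimal sketch under stated assumptions, not FlowTrainer's actual code:

* resolve_early_stopping, merge_callbacks, and DEFAULT_EARLY_STOPPING are hypothetical.
* The default values are placeholders, because the real defaults are not documented here.
* Plain strings stand in for callback objects.
* monitor, patience, and mode are genuine Lightning EarlyStopping arguments.
* reset_patience_on_resume is split off because the parameter description lists it separately from the EarlyStopping arguments.

```python
from typing import Any, Dict, List, Optional, Tuple, Union

# Placeholder defaults (hypothetical); monitor/patience/mode are real
# Lightning EarlyStopping arguments, but these values are illustrative.
DEFAULT_EARLY_STOPPING: Dict[str, Any] = {"monitor": "val_loss", "patience": 10, "mode": "min"}


def resolve_early_stopping(
    option: Union[bool, Dict[str, Any], None],
) -> Optional[Tuple[Dict[str, Any], bool]]:
    """Return (EarlyStopping kwargs, reset_patience_on_resume), or None when disabled."""
    if option is None or option is False:
        return None
    settings = dict(DEFAULT_EARLY_STOPPING)  # copy so the defaults are never mutated
    reset_patience = False  # placeholder default
    if isinstance(option, dict):
        overrides = dict(option)
        # Not an EarlyStopping argument, so pull it out before merging.
        reset_patience = bool(overrides.pop("reset_patience_on_resume", reset_patience))
        settings.update(overrides)
    elif option is not True:
        raise TypeError(f"unsupported early_stopping value: {option!r}")
    return settings, reset_patience


def merge_callbacks(builtin: List[Any], trainer_kwargs: Dict[str, Any]) -> List[Any]:
    """Remove 'callbacks' from trainer_kwargs and append them after the built-in ones."""
    user = trainer_kwargs.pop("callbacks", None) or []
    if not isinstance(user, (list, tuple)):
        user = [user]  # accept a single callback as well as a list
    return list(builtin) + list(user)


print(resolve_early_stopping({"patience": 3, "reset_patience_on_resume": True}))
# prints: ({'monitor': 'val_loss', 'patience': 3, 'mode': 'min'}, True)

kwargs = {"callbacks": ["MyLogger"], "log_every_n_steps": 5}
print(merge_callbacks(["ModelCheckpoint", "EarlyStopping"], kwargs))
# prints: ['ModelCheckpoint', 'EarlyStopping', 'MyLogger']
print(kwargs)
# prints: {'log_every_n_steps': 5}
```

Appending rather than replacing means that a user-supplied callback list cannot accidentally drop the built-in checkpointing or early stopping callbacks.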
- purify(num_processes=1, chunks_per_process=1)
Purifies the datasets, in parallel when num_processes > 1 and otherwise in a single process. The method runs a forward/backward pass with a batch size of 1 to detect failures such as NaN values or crashes. It writes a JSON log file for each processed chunk.
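One way to picture chunked purification is in three steps. First, split the samples into num_processes * chunks_per_process chunks. Next, run a check on every sample in a chunk. Finally, write one JSON report per chunk.

This is a hypothetical sketch. split_into_chunks, check_chunk, and toy_step are illustrative names. The chunk-count formula is an assumption. toy_step stands in for the real forward/backward pass. Chunks are processed sequentially here; with num_processes > 1 they could instead be dispatched to worker processes, for example with multiprocessing.Pool.

```python
import json
import math
import os
import tempfile
from typing import Any, Callable, Dict, List, Sequence


def split_into_chunks(items: Sequence[Any], num_processes: int = 1, chunks_per_process: int = 1) -> List[List[Any]]:
    """Split items into num_processes * chunks_per_process contiguous, near-equal chunks."""
    n_chunks = max(1, num_processes * chunks_per_process)
    size, extra = divmod(len(items), n_chunks)
    chunks, start = [], 0
    for i in range(n_chunks):
        end = start + size + (1 if i < extra else 0)  # first `extra` chunks get one more item
        chunks.append(list(items[start:end]))
        start = end
    return [chunk for chunk in chunks if chunk]  # drop empty chunks when items are scarce


def check_chunk(chunk_id: int, chunk: List[Any], step: Callable[[Any], float], out_dir: str) -> Dict[str, Any]:
    """Run `step` on each sample, record NaN losses and exceptions, and write a JSON report."""
    report: Dict[str, Any] = {"chunk": chunk_id, "ok": [], "failed": {}}
    for sample in chunk:
        try:
            loss = step(sample)
        except Exception as exc:  # a crash marks the sample as failed
            report["failed"][str(sample)] = repr(exc)
            continue
        if math.isnan(loss):
            report["failed"][str(sample)] = "nan"
        else:
            report["ok"].append(str(sample))
    with open(os.path.join(out_dir, f"chunk_{chunk_id}.json"), "w") as fh:
        json.dump(report, fh, indent=2)
    return report


def toy_step(sample: int) -> float:
    """Hypothetical stand-in for a batch-size-1 forward/backward pass."""
    if sample == 5:
        raise RuntimeError("boom")
    return float("nan") if sample == 3 else float(sample)


chunks = split_into_chunks(list(range(7)), num_processes=2, chunks_per_process=2)
print(chunks)  # prints: [[0, 1], [2, 3], [4, 5], [6]]

with tempfile.TemporaryDirectory() as out_dir:
    reports = [check_chunk(i, chunk, toy_step, out_dir) for i, chunk in enumerate(chunks)]
    print(sorted(os.listdir(out_dir)))  # prints: ['chunk_0.json', 'chunk_1.json', 'chunk_2.json', 'chunk_3.json']

failed = {name: reason for report in reports for name, reason in report["failed"].items()}
print(failed)  # prints: {'3': 'nan', '5': "RuntimeError('boom')"}
```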
- test(trained_model_path)
Tests the model using the specified checkpoint.
- Parameters:
trained_model_path (str) – Path to the checkpoint file to evaluate.
- train(resume_checkpoint_path=None, train_shuffle=True, train_seed=None)
Trains the model and returns the path to the best checkpoint.
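The resume-related behavior documented for the constructor covers two things: the resume_checkpoint_path fallback and the learning_rate override. Two small helpers can sketch both. Both helpers are hypothetical, not hoops_ai functions. The param groups are plain dicts shaped like torch.optim.Optimizer.param_groups, where each group stores its own "lr". In a real Lightning setup, the override would run after checkpoint state is restored, for instance from a callback's on_train_start hook iterating over trainer.optimizers.

```python
from typing import Any, Dict, List, Optional


def resolve_resume_path(train_arg: Optional[str], ctor_default: Optional[str]) -> Optional[str]:
    """An explicit train() argument wins; otherwise fall back to the constructor value."""
    return train_arg if train_arg is not None else ctor_default


def apply_lr_override(param_groups: List[Dict[str, Any]], learning_rate: Optional[float]) -> None:
    """Set every param group's "lr" in place, unless learning_rate is None."""
    if learning_rate is None:
        return  # keep whatever LR was restored from the checkpoint (or the model default)
    for group in param_groups:
        group["lr"] = learning_rate


# Plain dicts shaped like torch.optim.Optimizer.param_groups after a checkpoint restore.
groups = [{"lr": 1e-4, "weight_decay": 0.0}, {"lr": 1e-5, "weight_decay": 0.01}]

apply_lr_override(groups, None)
print([g["lr"] for g in groups])  # prints: [0.0001, 1e-05]

apply_lr_override(groups, 3e-4)
print([g["lr"] for g in groups])  # prints: [0.0003, 0.0003]

# "runs/last.ckpt" is a hypothetical example path.
print(resolve_resume_path(None, "runs/last.ckpt"))  # prints: runs/last.ckpt
```

Applying the override after the restore matters because restoring optimizer state also restores each group's saved "lr". Setting the value earlier would simply be overwritten.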
- Parameters:
resume_checkpoint_path (str | None) – Optional checkpoint path to resume training from.
train_shuffle (bool) – Whether to shuffle the training DataLoader. For reproducibility experiments, use train_shuffle=False.
train_seed (int | None) – Optional seed passed into the Lightning model. It is used by Embedding._deterministic_train_seed(…).
- Return type: