hoops_ai.dataset.DatasetLoader

class hoops_ai.dataset.DatasetLoader(merged_store_path=None, parquet_file_path=None, item_loader_func=None, graph_files=None, labels=None, file_code_start=0)

Bases: object

A framework-agnostic dataset class that:
  • Internally creates a DatasetExplorer from a .dataset file and infoset files, based on any available group/key

  • Builds a membership matrix for multi-label stratification

  • Splits data into train/validation/test subsets

  • Provides get_dataset(…) to get a CADDataset object

  • Offers remove_indices(…) returning a map of old -> new indices
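
To make the membership-matrix bullet concrete, the following is a minimal, self-contained sketch of what such a matrix looks like for multi-label data. It does not call hoops_ai: `toy_membership_matrix`, the sample file IDs, and the category names are all hypothetical and exist only for illustration. The library's own DatasetExplorer.build_membership_matrix(…) may differ in signature and output.

```python
from typing import Dict, Hashable, List, Sequence, Set


def toy_membership_matrix(
    file_labels: Dict[int, Set[Hashable]],
    categories: Sequence[Hashable],
) -> List[List[int]]:
    """Return one row per file and one column per category (1 = present)."""
    rows = []
    for file_id in sorted(file_labels):
        present = file_labels[file_id]
        rows.append([1 if cat in present else 0 for cat in categories])
    return rows


# Hypothetical data: three files, each carrying one or more labels.
file_labels = {0: {"hole", "pocket"}, 1: {"slot"}, 2: {"hole"}}
categories = ["hole", "pocket", "slot"]
matrix = toy_membership_matrix(file_labels, categories)

for row in matrix:
    print(row)
```

A multi-label stratified split then tries to keep the column sums of this matrix roughly proportional across the train, validation, and test subsets.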

Parameters:
  • merged_store_path (str | None)

  • parquet_file_path (str | None)

  • item_loader_func (Callable | None)

  • graph_files (List[str] | None)

  • labels (List[Any] | None)

  • file_code_start (int)

available_arrays(group_name)

Get all available arrays for a specific group.

Parameters:

group_name (str)

Return type:

set

available_groups()

Get all available groups in the dataset.

Return type:

set

close_resources(clear_split_history=True)

Close and cleanup resources, particularly the DatasetExplorer instance.

Parameters:

clear_split_history (bool) – If True, also clears the split history

diagnose_file_codes_mismatch(file_codes=None)

Diagnostic method to help understand file code and ID mismatches. Call this method when experiencing issues with file code mapping.

find_group_for_array(array_name)

Find which group contains a specific array. Returns None if the array is not found in any group.

Parameters:

array_name (str)

Return type:

str | None

get_available_stratification_keys()

Get all available keys that can be used for stratification, grouped by their containing group.

Return type:

dict

get_dataset(subset, key=None)

Return a framework-agnostic CADDataset for 'train', 'validation', or 'test'.

Parameters:
  • subset (str) – One of 'train', 'validation', or 'test'

  • key (str | None) – Optional key to specify which split to use if multiple splits exist. If None, uses the current active split.

Returns:

A framework-agnostic dataset containing the requested subset

Return type:

CADDataset

get_file_codes()

Get the array of file ID codes for the current dataset.

Returns:

Array of file ID codes, or None if not available. Entry i of the array is the file code of the file at index i.

Return type:

np.ndarray

get_subset(subset, key=None)
Parameters:
  • subset (str)

  • key (str | None)

Return type:

CADDataset

remove_indices(indices_to_remove)

Removes the given global indices from data_files/label_datas. Automatically updates subset_indices if a split has been done. Returns a dict mapping old_index -> new_index for items that remain.
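
The old_index -> new_index mapping can be pictured with a short sketch. This is an illustration under stated assumptions, not the library's implementation: `toy_index_remap` is a hypothetical helper, and it assumes the remaining items keep their relative order and are renumbered contiguously from 0.

```python
from typing import Dict, Iterable


def toy_index_remap(n_items: int, indices_to_remove: Iterable[int]) -> Dict[int, int]:
    """Map each surviving old index to its new position after removal."""
    removed = set(indices_to_remove)
    remap: Dict[int, int] = {}
    next_index = 0
    for old_index in range(n_items):
        if old_index in removed:
            continue  # removed items get no entry in the map
        remap[old_index] = next_index
        next_index += 1
    return remap


# Five items, of which indices 1 and 3 are removed.
remap = toy_index_remap(5, [1, 3])

# Translate subset indices that were recorded before the removal.
train_before = [0, 1, 4]
train_after = [remap[i] for i in train_before if i in remap]
```

A map of this shape is what lets externally stored indices (for example, cached predictions keyed by index) be translated after items are dropped.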

reset_split_state()

Reset the split state to allow for a new split with different parameters. This preserves the previous split results in _split_history.

set(item_loader_func)
Parameters:

item_loader_func (Callable | None)

set_item_loader_func(item_loader_func)

Inject/replace the item loader function.

This is designed for the common workflow:

    loader.split(…)
    train_ds = loader.get_subset("train")
    loader.set_item_loader_func(my_loader)

Subsets returned earlier remain valid because they delegate item loading to the parent DatasetLoader at access time.
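
The delegation described above can be sketched without hoops_ai. `ToyLoader` and `ToySubset` below are hypothetical stand-ins, not the library's classes. They only show why a subset created before the loader function is set still picks it up: the subset holds a reference to its parent and looks the function up at access time.

```python
from typing import Any, Callable, List, Optional


class ToyLoader:
    """Hypothetical parent that owns the files and the item loader function."""

    def __init__(self, files: List[str]) -> None:
        self.files = files
        self.item_loader_func: Optional[Callable[[str], Any]] = None

    def set_item_loader_func(self, func: Optional[Callable[[str], Any]]) -> None:
        self.item_loader_func = func

    def load(self, index: int) -> Any:
        if self.item_loader_func is None:
            raise RuntimeError("no item loader function has been set")
        return self.item_loader_func(self.files[index])

    def get_subset(self, indices: List[int]) -> "ToySubset":
        return ToySubset(self, indices)


class ToySubset:
    """Hypothetical subset that stores indices and defers loading to the parent."""

    def __init__(self, parent: ToyLoader, indices: List[int]) -> None:
        self._parent = parent
        self._indices = indices

    def __len__(self) -> int:
        return len(self._indices)

    def __getitem__(self, position: int) -> Any:
        # The loader function is resolved here, at access time.
        return self._parent.load(self._indices[position])


loader = ToyLoader(["part_a", "part_b", "part_c"])
train_ds = loader.get_subset([0, 2])      # created before any loader is set
loader.set_item_loader_func(str.upper)    # injected afterwards
first_item = train_ds[0]
```

Because the subset never copies the function, replacing it on the parent changes what every existing subset returns on its next access.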

Parameters:

item_loader_func (Callable | None)

split(key='label', group='machining', categories=None, train=0.8, validation=0.1, test=0.1, random_state=42, force_reset=False)

Split the dataset into train/validation/test.

The meaning of key depends on how this loader was constructed:

Direct mode (constructor receives graph_files=[…]):
  • key="random": random split.

  • key="label": stratified split using the provided labels=[…] list.

DatasetExplorer mode (legacy, uses parquet_file_path):
  • key="random": random split over all available files.

  • Otherwise: multi-label stratified split using DatasetExplorer.build_membership_matrix(…) with the provided group, key, and categories.

Notes

  • Splits are deterministic for the same parameters (including random_state).

  • If you call split again with different parameters, the previous split is archived in _split_history (keyed by the key you passed).
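
The determinism note can be illustrated with a small single-label stratified split. This is a sketch under stated assumptions, not the library's algorithm: `toy_stratified_split` is hypothetical, it uses only the standard library, and it does not attempt the multi-label stratification that DatasetExplorer mode performs.

```python
import random
from collections import defaultdict
from typing import Dict, Hashable, List, Sequence


def toy_stratified_split(
    labels: Sequence[Hashable],
    train: float = 0.8,
    validation: float = 0.1,
    random_state: int = 42,
) -> Dict[str, List[int]]:
    """Split indices per label so each subset keeps the label proportions."""
    rng = random.Random(random_state)  # seeded, so results are repeatable
    by_label: Dict[Hashable, List[int]] = defaultdict(list)
    for index, label in enumerate(labels):
        by_label[label].append(index)

    subsets: Dict[str, List[int]] = {"train": [], "validation": [], "test": []}
    for label in sorted(by_label, key=str):  # fixed order keeps it deterministic
        indices = by_label[label]
        rng.shuffle(indices)
        n_train = int(round(train * len(indices)))
        n_val = int(round(validation * len(indices)))
        subsets["train"] += indices[:n_train]
        subsets["validation"] += indices[n_train:n_train + n_val]
        subsets["test"] += indices[n_train + n_val:]  # remainder goes to test
    return subsets


labels = ["a"] * 10 + ["b"] * 10
first = toy_stratified_split(labels, random_state=42)
second = toy_stratified_split(labels, random_state=42)
sizes = (len(first["train"]), len(first["validation"]), len(first["test"]))
```

Calling the function twice with the same arguments yields identical subsets, which is the property the first note describes; changing random_state generally produces a different assignment.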

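Parameters:
  • key – defaults to 'label'

  • group – defaults to 'machining'

  • categories – defaults to None

  • train – defaults to 0.8

  • validation – defaults to 0.1

  • test – defaults to 0.1

  • random_state – defaults to 42

  • force_reset – defaults to False
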
Return type:

Tuple[int, int, int]

validate_configuration()

Validate the current configuration and return a summary of available data. Useful for debugging and ensuring the dataset is properly configured.

Return type:

dict