hoops_ai.dataset.GraphDataset

class hoops_ai.dataset.GraphDataset(*args, **kwargs)

Bases: Dataset

PyTorch Dataset class that uses a DatasetLoader object to create DGL graphs dynamically instead of loading pre-computed graphs from disk. It integrates the collate function from the GraphClassification flow model.

This dataset can either:

  1. Create graphs on the fly from raw data stored in the DatasetExplorer, or

  2. Fall back to loading pre-computed DGL graphs from disk, if available.
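The fallback order described above can be sketched in plain Python. This is a minimal, stdlib-only illustration under stated assumptions, not the hoops_ai implementation: `OnTheFlyGraphSource`, `builder`, and `precomputed` are hypothetical names, and a graph is modeled as a dict holding a node count and an edge list instead of a real DGL graph.

```python
from typing import Callable, Dict, Optional

# Hypothetical stand-in for a DGL graph: a node count plus an edge list.
Graph = Dict[str, object]


class OnTheFlyGraphSource:
    """Sketch of 'build from raw data first, else fall back to disk'.

    `builder` and `precomputed` are illustrative, not hoops_ai APIs.
    """

    def __init__(
        self,
        builder: Optional[Callable[[int], Optional[Graph]]],
        precomputed: Dict[int, Graph],
    ) -> None:
        self.builder = builder          # turns raw data into a graph (may return None)
        self.precomputed = precomputed  # graphs assumed to be already saved on disk

    def get_graph(self, idx: int) -> Graph:
        # 1. Try to create the graph on the fly from raw data.
        if self.builder is not None:
            graph = self.builder(idx)
            if graph is not None:
                return graph
        # 2. Fall back to a pre-computed graph, if one exists.
        if idx in self.precomputed:
            return self.precomputed[idx]
        raise KeyError(f"no raw data or pre-computed graph for sample {idx}")


# Usage: raw data exists only for even indices; index 1 was pre-computed.
def build_from_raw(idx: int) -> Optional[Graph]:
    if idx % 2 == 0:
        return {"num_nodes": 3, "edges": [(0, 1), (1, 2)], "source": "raw"}
    return None


source = OnTheFlyGraphSource(
    build_from_raw, {1: {"num_nodes": 2, "edges": [(0, 1)], "source": "disk"}}
)
print(source.get_graph(0)["source"])  # raw
print(source.get_graph(1)["source"])  # disk
```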

clear_cache()

Clear the in-memory cache.

collate_function(batch)

Collate a list of samples into a single batch, similar to GraphClassification.collate_function.

Parameters:

batch (List[Dict[str, Any]]) – List of samples from __getitem__

Returns:

Batched data dictionary

Return type:

Dict[str, Any]
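Graph collate functions typically merge the per-sample graphs into one disjoint graph, which is the idea behind `dgl.batch`. The sketch below shows that merge in plain Python; it is an assumption-laden illustration, not the GraphClassification code: the sample keys `"graph"` and `"label"`, the output keys, and the dict-based graph representation are all hypothetical.

```python
from typing import Any, Dict, List, Tuple


def collate_graphs(batch: List[Dict[str, Any]]) -> Dict[str, Any]:
    """Merge per-sample graphs into one disjoint 'batched' graph.

    Node ids of each graph are shifted by the number of nodes that precede
    it, so no edge connects two different samples. The 'graph'/'label'
    keys are assumptions for illustration.
    """
    edges: List[Tuple[int, int]] = []
    num_nodes_per_graph: List[int] = []
    labels: List[Any] = []
    offset = 0
    for sample in batch:
        graph = sample["graph"]
        # Re-index this sample's edges into the batched graph's id space.
        for u, v in graph["edges"]:
            edges.append((u + offset, v + offset))
        num_nodes_per_graph.append(graph["num_nodes"])
        offset += graph["num_nodes"]
        labels.append(sample["label"])
    return {
        "graph": {
            "num_nodes": offset,
            "edges": edges,
            "batch_num_nodes": num_nodes_per_graph,  # lets you split results per sample
        },
        "labels": labels,
    }
```

Keeping the per-sample node counts makes it possible to map node-level outputs of the batched graph back to the individual samples, for example when pooling node features into one vector per graph.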

enable_cache(enable=True)

Enable or disable in-memory caching of loaded graphs.

Parameters:

enable (bool) – Whether to enable caching
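The interaction between enable_cache() and clear_cache() can be sketched as a small wrapper around an expensive loader. This is a hypothetical, stdlib-only illustration: `GraphCache`, `loader`, and the `loads` counter are invented names, and whether the real class keeps or discards cached graphs when caching is disabled is not stated here (this sketch keeps them until clear_cache() is called).

```python
from typing import Any, Callable, Dict


class GraphCache:
    """Hypothetical sketch of an optional in-memory cache for loaded graphs."""

    def __init__(self, loader: Callable[[int], Any]) -> None:
        self.loader = loader            # expensive call, e.g. building a graph
        self.enabled = False
        self._cache: Dict[int, Any] = {}
        self.loads = 0                  # how many times the loader actually ran

    def enable_cache(self, enable: bool = True) -> None:
        # Toggle caching; existing entries are kept in this sketch.
        self.enabled = enable

    def clear_cache(self) -> None:
        # Drop every cached graph.
        self._cache.clear()

    def get(self, idx: int) -> Any:
        if self.enabled and idx in self._cache:
            return self._cache[idx]     # cache hit: no reload
        self.loads += 1
        value = self.loader(idx)
        if self.enabled:
            self._cache[idx] = value
        return value
```

Note that with num_workers > 0, PyTorch gives each DataLoader worker process its own copy of the dataset object, so a cache filled inside a worker is not shared with the main process or with other workers.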

get_dataloader(batch_size=32, shuffle=True, num_workers=0, drop_last=True, use_prefetch=True)

Create a DataLoader for this dataset.

Parameters:
  • batch_size (int) – Batch size

  • shuffle (bool) – Whether to shuffle the data

  • num_workers (int) – Number of worker processes

  • drop_last (bool) – Whether to drop the last incomplete batch

  • use_prefetch (bool) – Whether to use a prefetch loader, if one is available

Returns:

DataLoader instance

Return type:

torch.utils.data.DataLoader
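How batch_size, shuffle, and drop_last split the samples into batches can be illustrated without PyTorch. The helper below is a hypothetical sketch that follows the semantics of torch.utils.data.DataLoader's default samplers; it does not call get_dataloader() or PyTorch, and `plan_batches` and `seed` are invented for illustration.

```python
import random
from typing import List, Optional


def plan_batches(
    num_samples: int,
    batch_size: int = 32,
    shuffle: bool = True,
    drop_last: bool = True,
    seed: Optional[int] = None,
) -> List[List[int]]:
    """Sketch of how sample indices are grouped into batches."""
    indices = list(range(num_samples))
    if shuffle:
        # A real DataLoader reshuffles at the start of every epoch.
        random.Random(seed).shuffle(indices)
    batches = [indices[i:i + batch_size] for i in range(0, num_samples, batch_size)]
    if drop_last and batches and len(batches[-1]) < batch_size:
        batches.pop()  # discard the final, incomplete batch
    return batches
```

With the default drop_last=True, 100 samples and batch_size=32 give 3 batches (the last 4 samples are skipped), and a dataset smaller than batch_size yields no batches at all. Pass drop_last=False, for example `dataset.get_dataloader(batch_size=32, drop_last=False)`, when every sample must be seen, such as during evaluation.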

remove_indices(local_indices_to_remove)

Remove items by local subset index.

Parameters:

local_indices_to_remove (List[int]) – List of indices to remove (local to this subset)
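The distinction between local and global indices can be sketched as follows, assuming (hypothetically) that a subset stores the global sample ids it covers and that callers address samples by their position in the subset. `SubsetIndex` and `global_ids` are illustrative names, not hoops_ai attributes.

```python
from typing import List


class SubsetIndex:
    """Hypothetical sketch: a subset keeps global ids, callers use local positions."""

    def __init__(self, global_ids: List[int]) -> None:
        self.global_ids = list(global_ids)

    def remove_indices(self, local_indices_to_remove: List[int]) -> None:
        # Local indices are positions in *this* subset (0..len-1), not global ids.
        to_drop = set(local_indices_to_remove)
        self.global_ids = [
            gid for pos, gid in enumerate(self.global_ids) if pos not in to_drop
        ]
```

Filtering against a set of positions in one pass interprets every index relative to the subset as it was before the call. Deleting items one at a time would shift the positions of later items and remove the wrong samples.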