hoops_ai.dataset.GraphDataset
- class hoops_ai.dataset.GraphDataset(*args, **kwargs)
Bases: Dataset
PyTorch Dataset class that uses a DatasetLoader object to create DGL graphs dynamically instead of loading pre-computed graphs from disk. It integrates the collate function from the GraphClassification flow model.
This dataset can either:
1. Create graphs on the fly from raw data stored in the DatasetExplorer
2. Fall back to loading pre-computed DGL graphs from disk if available
- Parameters:
- clear_cache()
Clear the in-memory cache.
- collate_function(batch)
Collate function for batching, similar to GraphClassification.collate_function.
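As a rough illustration of what a collate function does, the sketch below merges a list of per-sample pairs into one batch structure. It is not the hoops_ai implementation: the name `collate_pairs` and the `(graph, label)` sample layout are assumptions made for this example, and plain Python lists stand in for the batched graph and label tensor that a DGL-based collate function would typically produce (for instance via `dgl.batch`).

```python
from typing import Any, Dict, List, Sequence, Tuple


def collate_pairs(batch: Sequence[Tuple[Any, int]]) -> Dict[str, List[Any]]:
    """Hypothetical collate sketch: split (graph, label) samples into two lists.

    Assumption: each sample is a (graph, label) pair. A real DGL collate
    function would usually merge the graphs with dgl.batch and stack the
    labels into a tensor instead of returning Python lists.
    """
    graphs = [graph for graph, _ in batch]
    labels = [label for _, label in batch]
    return {"graphs": graphs, "labels": labels}


# Strings stand in for graph objects in this sketch.
example = collate_pairs([("graph_0", 1), ("graph_1", 0)])
```

A DataLoader calls its collate function once per batch, so the function receives a list of whatever `__getitem__` returns for the sampled indices.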
- enable_cache(enable=True)
Enable or disable in-memory caching of loaded graphs.
- Parameters:
enable (bool) – Whether to enable caching
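The caching behaviour described by `enable_cache` and `clear_cache` can be pictured with the minimal sketch below. It is a generic pattern under stated assumptions, not the library's code: `CachedItems` and `build_item` are hypothetical names, the cache is assumed to be a per-index dictionary, and it is assumed to start disabled.

```python
from typing import Any, Callable, Dict


class CachedItems:
    """Hypothetical stand-in showing an in-memory cache toggle."""

    def __init__(self, build_item: Callable[[int], Any], size: int) -> None:
        self._build_item = build_item      # expensive per-index construction
        self._size = size
        self._cache_enabled = False        # assumption: caching is off by default
        self._cache: Dict[int, Any] = {}

    def enable_cache(self, enable: bool = True) -> None:
        # Toggling does not discard entries already stored in this sketch.
        self._cache_enabled = enable

    def clear_cache(self) -> None:
        # Drop every stored item so later accesses rebuild from scratch.
        self._cache.clear()

    def __len__(self) -> int:
        return self._size

    def __getitem__(self, idx: int) -> Any:
        if self._cache_enabled and idx in self._cache:
            return self._cache[idx]
        item = self._build_item(idx)
        if self._cache_enabled:
            self._cache[idx] = item
        return item
```

With caching enabled, repeated access to the same index builds the item only once, which trades memory for speed across epochs. Note that a cache held in the parent process is generally not shared with DataLoader worker processes when `num_workers` is greater than 0.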
- get_dataloader(batch_size=32, shuffle=True, num_workers=0, drop_last=True, use_prefetch=True)
Create a DataLoader for this dataset.
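One practical consequence of the defaults in the signature (`batch_size=32`, `drop_last=True`) is worth spelling out. Assuming `get_dataloader` forwards these arguments to a standard PyTorch `DataLoader`, the final incomplete batch is discarded, so a dataset with fewer samples than `batch_size` yields no batches at all. The helper below is a hypothetical illustration of that arithmetic and is not part of hoops_ai.

```python
import math


def num_batches(num_samples: int, batch_size: int = 32, drop_last: bool = True) -> int:
    """Number of batches per epoch under PyTorch DataLoader semantics."""
    if batch_size <= 0:
        raise ValueError("batch_size must be positive")
    if drop_last:
        # The trailing partial batch is discarded.
        return num_samples // batch_size
    # The trailing partial batch is kept as a smaller batch.
    return math.ceil(num_samples / batch_size)


print(num_batches(100))                   # 3
print(num_batches(100, drop_last=False))  # 4
print(num_batches(20))                    # 0
```

For small validation or test splits, passing `drop_last=False` is usually the safer choice so that no samples are silently skipped.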
- Parameters:
- Returns:
DataLoader instance
- Return type: