{ "cells": [ { "cell_type": "markdown", "id": "f5662507-f32c-4537-b0d8-b9234768b2c2", "metadata": {}, "source": [ "# HOOPS AI: Use the Dataset Explorer to navigate the dataset\n", "\n", "\n", "The `dataset` module provides a comprehensive framework for exploring, navigating, and loading CAD model datasets for machine learning applications. It consists of two primary components that work together to simplify data handling:\n", "\n", "1. **DatasetExplorer** - For exploring and querying dataset contents\n", "2. **DatasetLoader** - For loading and preparing datasets for machine learning training\n", "\n", "These components are designed to work with the processed data from the `cadaccess` and `cadencoder` modules, as well as the outputs from the flow pipeline system. They provide high-level abstractions that allow users to focus on machine learning tasks rather than data handling complexities.\n", "\n", "## DatasetExplorer\n", "\n", "The `DatasetExplorer` class (`dataset_explorer.py`) provides methods for exploring and querying datasets stored in Zarr format (.dataset) with accompanying metadata (.infoset) in Parquet files. 
This class focuses on data discovery, filtering, and statistical analysis.\n", "\n", "### Key Methods\n", "\n", "#### Data Discovery and Metadata\n", "\n", "- `available_groups() -> set`: Returns the set of available dataset groups (faces, edges, file, etc.)\n", "- `get_descriptions(table_name: str, key_id: Optional[int] = None, use_wildchar: Optional[bool] = False) -> pd.DataFrame`: Retrieves metadata descriptions (labels, face types, edge types, etc.)\n", "- `get_parquet_info_by_code(file_id_code: int)`: Returns rows from the Parquet file for a specific file ID code\n", "- `get_file_info_all() -> pd.DataFrame`: Returns all file info from the Parquet metadata\n", "\n", "#### Data Distribution Analysis\n", "\n", "- `create_distribution(key: str, bins: int = 10, group: str = \"faces\") -> Dict[str, Any]`: Computes histograms of data distributions using Dask for parallel processing\n", "\n", "#### Data Filtering and Selection\n", "\n", "- `get_file_list(group: str, where: Callable[[xr.Dataset], xr.DataArray]) -> List[str]`: Returns file IDs matching a boolean filter condition\n", "- `file_dataset(file_id_code: int, group: str) -> xr.Dataset`: Returns a subset of the dataset for a specific file\n", "- `build_membership_matrix(group: str, key: str, bins_or_categories: Union[int, List, np.ndarray], as_counts: bool = False) -> tuple[np.ndarray, np.ndarray, np.ndarray]`: Builds a file-by-bin membership matrix for stratified splitting\n", "- `decode_file_id_code(code: int) -> str`: Converts an integer file ID code to the original string identifier" ] }, { "cell_type": "code", "execution_count": 1, "id": "021c4256-d764-4983-aae8-4a87498b483e", "metadata": { "tags": [] }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "ℹ️ Using TEST LICENSE (expires December 8, 2025 - 37 days remaining)\n", " For production use, obtain your own license from Tech Soft 3D\n", "======================================================================\n", "✓ HOOPS AI License: 
Valid (TEST LICENSE - expires Dec 6, 2025)\n", "======================================================================\n" ] } ], "source": [ "import hoops_ai\n", "import os\n", "\n", "hoops_ai.set_license(hoops_ai.use_test_license())" ] }, { "cell_type": "code", "execution_count": 2, "id": "381fd5db-d080-4004-a103-fd5d358ee504", "metadata": { "tags": [] }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "C:\\Users\\LuisSalazar\\Documents\\MAIN\\MLProject\\repo\\HOOPS-AI-tutorials\\packages\\flows\\cadsynth_aag\n", "[DatasetExplorer] Default local cluster started: \n" ] }, { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "ff6d70a731c24cf2a703a8dc28a18450", "version_major": 2, "version_minor": 0 }, "text/plain": [ "Processing file info: 0%| | 0/162412 [00:00 id name description table_name\n", "0 0 Plane not set face_types\n", "1 1 Cylinder not set face_types\n", "2 2 Cone not set face_types\n", "3 14 Extrusion not set face_types\n", "4 3 Sphere not set face_types\n", "5 4 Torus not set face_types\n", "6 5 Nurbs not set face_types\n", "7 13 Revolution not set face_types\n" ] } ], "source": [ "face_type_description = explorer.get_descriptions(\"face_types\")\n", "print(type(face_type_description), face_type_description)" ] }, { "cell_type": "code", "execution_count": 5, "id": "1fd33405-d646-487a-9b81-f2bd9d0510a3", "metadata": { "tags": [] }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ " name description \\\n", "0 0007c9b910d876090f12c2cad80794df C:\\Temp\\Cadsynth_aag\\step\\00035533.stp \n", "\n", " subset id table_name \n", "0 train 25 file_info \n" ] } ], "source": [ "# Get and print the file's metadata (its row in the Parquet file-info table)\n", "file_id = 25\n", "df_info = explorer.get_parquet_info_by_code(file_id)\n", "print(type(df_info), df_info)" ] }, { "cell_type": "code", "execution_count": 6, "id": "af3a683e-c6f1-408e-9c0c-3e7ad9446feb", "metadata": { "tags": [] }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": 
[ "Datasets (Table of Content) for file ID '25':\n", " [faceface] DATA: a3_distance, Shape: (2025, 64), Dims: ('faceface_flat', 'bins_d'), Size: 129600\n", " [faceface] DATA: d2_distance, Shape: (2025, 64), Dims: ('faceface_flat', 'bins_a'), Size: 129600\n", " [faceface] DATA: extended_adjacency, Shape: (2025,), Dims: ('faceface_flat',), Size: 2025\n", " [faceface] DATA: face_pair_edges_path, Shape: (2025, 32), Dims: ('faceface_flat', 'dim_path'), Size: 64800\n", " [faceface] DATA: face_x, Shape: (2025,), Dims: ('faceface_flat',), Size: 2025\n", " [faceface] DATA: face_y, Shape: (2025,), Dims: ('faceface_flat',), Size: 2025\n", " [faceface] DATA: file_id_code_faceface, Shape: (2025,), Dims: ('faceface_flat',), Size: 2025\n", " [faces] DATA: face_areas, Shape: (45,), Dims: ('face',), Size: 45\n", " [faces] DATA: face_indices, Shape: (45,), Dims: ('face',), Size: 45\n", " [faces] DATA: face_labels, Shape: (45,), Dims: ('face',), Size: 45\n", " [faces] DATA: face_loops, Shape: (45,), Dims: ('face',), Size: 45\n", " [faces] DATA: face_neighborscount, Shape: (45,), Dims: ('face',), Size: 45\n", " [faces] DATA: face_types, Shape: (45,), Dims: ('face',), Size: 45\n", " [faces] DATA: face_uv_grids, Shape: (45, 5, 5, 7), Dims: ('face', 'uv_x', 'uv_y', 'component'), Size: 7875\n", " [faces] DATA: file_id_code_faces, Shape: (45,), Dims: ('face',), Size: 45\n", " [file] DATA: duration_dglconvert, Shape: (1,), Dims: ('file',), Size: 1\n", " [file] DATA: file_id_code_file, Shape: (1,), Dims: ('file',), Size: 1\n", " [file] DATA: size_cadfile, Shape: (1,), Dims: ('file',), Size: 1\n", " [file] DATA: size_dglfile, Shape: (1,), Dims: ('file',), Size: 1\n", " [edges] DATA: edge_convexities, Shape: (128,), Dims: ('edge',), Size: 128\n", " [edges] DATA: edge_dihedral_angles, Shape: (128,), Dims: ('edge',), Size: 128\n", " [edges] DATA: edge_indices, Shape: (128,), Dims: ('edge',), Size: 128\n", " [edges] DATA: edge_lengths, Shape: (128,), Dims: ('edge',), Size: 128\n", " [edges] DATA: 
edge_types, Shape: (128,), Dims: ('edge',), Size: 128\n", " [edges] DATA: edge_u_grids, Shape: (128, 5, 6), Dims: ('edge', 'dim_x', 'component'), Size: 3840\n", " [edges] DATA: file_id_code_edges, Shape: (128,), Dims: ('edge',), Size: 128\n", " [graph] DATA: destination, Shape: (128,), Dims: ('edge',), Size: 128\n", " [graph] DATA: file_id_code_graph, Shape: (128,), Dims: ('edge',), Size: 128\n", " [graph] DATA: num_nodes, Shape: (128,), Dims: ('edge',), Size: 128\n", " [graph] DATA: source, Shape: (128,), Dims: ('edge',), Size: 128\n", "\n", "type of file_data_arrays \n", "type of array_areas \n", "brep surfaces (45,)\n" ] } ], "source": [ "# Access every dataset group for this file\n", "file_datasetGroup = {grp: explorer.file_dataset(file_id_code=file_id, group=grp) for grp in groups}\n", "\n", "print(f\"Datasets (Table of Contents) for file ID '{file_id}':\")\n", "for grp, ds in file_datasetGroup.items():\n", "    for name, da in ds.data_vars.items():\n", "        print(f\" [{grp}] DATA: {name}, Shape: {da.shape}, Dims: {da.dims}, Size: {da.size}\")\n", "print()\n", "\n", "file_dataset = file_datasetGroup[\"faces\"]\n", "print(\"type of file_dataset\", type(file_dataset))\n", "\n", "# Load the face areas into memory (one value per B-rep face) and print their shape\n", "array_areas = file_dataset[\"face_areas\"].data.compute()\n", "print(\"type of array_areas\", type(array_areas))\n", "print(\"brep surfaces\", array_areas.shape)" ] }, { "cell_type": "code", "execution_count": 7, "id": "b6ed3a30-beef-4f26-a17a-46538c5f4fc2", "metadata": { "tags": [] }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "numpy array shape (45, 5, 5, 7)\n" ] } ], "source": [ "# .data is the underlying (Dask-backed) array; .compute() loads it into memory as a NumPy array\n", "uv_grid_data = file_dataset[\"face_uv_grids\"].data.compute()\n", "\n", "print(\"numpy array shape\", uv_grid_data.shape)" ] }, { "cell_type": "code", "execution_count": 8, "id": "66844184-060d-433e-855a-4e94fcc27e2d", "metadata": { "scrolled": true, "tags": [] }, "outputs": [], "source": [ 
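"# A minimal sketch, assuming uv_grid_data is the (face, uv_x, uv_y, component) NumPy array\n", "# from the previous cell: flatten faces and UV samples, then report the value range of each\n", "# component channel (what each channel encodes is not documented in this notebook).\n", "import numpy as np\n", "\n", "n_faces, n_u, n_v, n_components = uv_grid_data.shape\n", "samples = uv_grid_data.reshape(-1, n_components)  # one row per (face, u, v) sample\n", "print(\"faces:\", n_faces, \"| samples per face:\", n_u * n_v)\n", "print(\"per-component min:\", np.round(samples.min(axis=0), 3))\n", "print(\"per-component max:\", np.round(samples.max(axis=0), 3))\n", "\n", 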
"#print(uv_grid_data)" ] }, { "cell_type": "code", "execution_count": 9, "id": "d13f4ae0-82e8-4369-93b7-99c163d64b7c", "metadata": { "tags": [] }, "outputs": [ { "data": { "text/plain": [ " name \\\n", "0 00001d44b8eb37fe4bb6ec4e89746ea3 \n", "1 000093c3b74c8076fbba20bf8613e2a2 \n", "2 00012dc0e22bd178d7c6a12436734130 \n", "3 00013001df401fe06d5213ef6fc9e581 \n", "4 00015d0e0e285c5089a485360b34db1a \n", "... ... \n", "162407 fffe278b406958997fff4c1f76b72ee0 \n", "162408 fffe3fba779b5fa5543c4a82c3cfbabb \n", "162409 fffee3edd46107ca808a958811dcb9a8 \n", "162410 ffff1ac0abc28d3ede6b7bfb6434cdf4 \n", "162411 ffffdc1d0cb629065026fa55a97ca314 \n", "\n", " description subset id \\\n", "0 C:\\Temp\\Cadsynth_aag\\step\\00021388.stp train 0 \n", "1 C:\\Temp\\Cadsynth_aag\\step\\20221121_154647_2127... validation 1 \n", "2 C:\\Temp\\Cadsynth_aag\\step\\00067737.stp validation 2 \n", "3 C:\\Temp\\Cadsynth_aag\\step\\00058353.stp validation 3 \n", "4 C:\\Temp\\Cadsynth_aag\\step\\00048571.stp train 4 \n", "... ... ... ... \n", "162407 C:\\Temp\\Cadsynth_aag\\step\\20221123_142528_2390... test 162407 \n", "162408 C:\\Temp\\Cadsynth_aag\\step\\00091954.stp train 162408 \n", "162409 C:\\Temp\\Cadsynth_aag\\step\\20221124_154714_1362... validation 162409 \n", "162410 C:\\Temp\\Cadsynth_aag\\step\\00065860.stp train 162410 \n", "162411 C:\\Temp\\Cadsynth_aag\\step\\00016373.stp train 162411 \n", "\n", " table_name \n", "0 file_info \n", "1 file_info \n", "2 file_info \n", "3 file_info \n", "4 file_info \n", "... ... 
\n", "162407 file_info \n", "162408 file_info \n", "162409 file_info \n", "162410 file_info \n", "162411 file_info \n", "\n", "[162412 rows x 5 columns]" ] }, "execution_count": 9, "metadata": {}, "output_type": "execute_result" } ], "source": [ "explorer.get_file_info_all()" ] }, { "cell_type": "code", "execution_count": 10, "id": "9b977880-8df0-4eda-a43d-57fe16a625e1", "metadata": { "tags": [] }, "outputs": [], "source": [ "# Visualization libraries\n", "import matplotlib.pyplot as plt\n", "\n", "def print_distribution_info(dist, title=\"Distribution\"):\n", "    \"\"\"Plot how many files fall into each bin of a distribution returned by create_distribution().\"\"\"\n", "    # Number of files whose data falls into each bin\n", "    dist['file_count'] = [bin_files.size for bin_files in dist['file_id_codes_in_bins']]\n", "\n", "    # Visualization with matplotlib\n", "    fig, ax = plt.subplots(figsize=(12, 4))\n", "\n", "    bin_centers = 0.5 * (dist['bin_edges'][1:] + dist['bin_edges'][:-1])\n", "    ax.bar(bin_centers, dist['file_count'], width=(dist['bin_edges'][1] - dist['bin_edges'][0]),\n", "           alpha=0.7, color='steelblue', edgecolor='black', linewidth=1)\n", "\n", "    # Annotate each non-empty bin with its file count\n", "    for center, count in zip(bin_centers, dist['file_count']):\n", "        if count > 0:\n", "            ax.text(center, count + 0.5, f\"{count}\", ha='center', va='bottom', fontsize=8)\n", "\n", "    ax.set_xlabel('Value')\n", "    ax.set_ylabel('Number of files')\n", "    ax.set_title(f'{title} Histogram')\n", "    ax.grid(True, linestyle='--', alpha=0.7)\n", "\n", "    plt.tight_layout()\n", "    plt.show()" ] }, { "cell_type": "code", "execution_count": 11, "id": "477d5cad-0bcb-4575-9079-55f16327ac67", "metadata": { "tags": [] }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Face labels distribution created in 2.31 seconds\n", "\n" ] }, { "data": { "image/png": "", "text/plain": [ "
" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "import time\n", "start_time = time.time()\n", "face_dist = explorer.create_distribution(key=\"face_labels\", bins=None, group=\"faces\")\n", "print(f\"Face labels distribution created in {(time.time() - start_time):.2f} seconds\\n\")\n", "print_distribution_info(face_dist, title=\"Labels\")\n" ] }, { "cell_type": "code", "execution_count": 12, "id": "42e37fde-a9db-40c7-baf0-ae183bda4abe", "metadata": { "scrolled": true, "tags": [] }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ " id name description table_name\n", "14 0 no-label not set face_labels\n", "15 7 2sides_through_step not set face_labels\n", "16 11 rectangular_blind_slot not set face_labels\n", "17 18 circular blind step not set face_labels\n", "18 5 6sides_passage not set face_labels\n", "19 9 rectangular_blind_step not set face_labels\n", "20 12 rectangular_pocket not set face_labels\n", "21 13 triangular_pocket not set face_labels\n", "22 14 6sides_pocket not set face_labels\n", "23 21 circular end pocket not set face_labels\n", "24 23 blind hole not set face_labels\n", "25 2 triangular_through_slot not set face_labels\n", "26 17 through hole not set face_labels\n", "27 3 rectangular_passage not set face_labels\n", "28 16 circular through slot not set face_labels\n", "29 4 triangular_passage not set face_labels\n", "30 6 rectangular_through_step not set face_labels\n", "31 10 triangular_blind_step not set face_labels\n", "32 24 fillet not set face_labels\n", "33 1 rectangular_through_slot not set face_labels\n", "34 15 chamfer not set face_labels\n", "35 19 horizontal circular end blind slot not set face_labels\n", "36 22 o-ring not set face_labels\n", "37 8 slanted_through_step not set face_labels\n", "38 20 vertical circular end blind slot not set face_labels\n" ] } ], "source": [ "print(explorer.get_descriptions(\"face_labels\"))" ] }, { "cell_type": "code", "execution_count": 13, "id": 
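"2f6b9c1e-7a4d-4e8b-9c3a-5d1e0f7b8a62", "metadata": { "tags": [] }, "outputs": [], "source": [ "# A minimal sketch, assuming get_descriptions() returns the id/name/description/table_name\n", "# columns printed above: build an id -> name lookup for the face labels, so integer labels\n", "# (such as the 15 used in the filter below) can be reported by name.\n", "labels_df = explorer.get_descriptions(\"face_labels\")\n", "label_names = dict(zip(labels_df[\"id\"].astype(int), labels_df[\"name\"]))\n", "\n", "print(len(label_names), \"face labels\")\n", "print(\"label 15 ->\", label_names.get(15))" ] }, { "cell_type": "code", "execution_count": 13, "id": 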
"cf864c67-0f5f-4b7d-8037-b2803e180fa8", "metadata": { "tags": [] }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Face labels distribution created in 4.20 seconds\n", "\n" ] }, { "data": { "image/png": "", "text/plain": [ "
" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "start_time = time.time()\n", "dist = explorer.create_distribution(key=\"num_nodes\", bins=12, group=\"graph\")\n", "print(f\"num_nodes distribution created in {(time.time() - start_time):.2f} seconds\\n\")\n", "print_distribution_info(dist, title=\"B-rep Face Count (num_nodes)\")" ] }, { "cell_type": "markdown", "id": "83793d5b-93c9-423a-be2e-0d543e33d15f", "metadata": { "tags": [] }, "source": [ "# Filter: gather files that satisfy a given condition" ] }, { "cell_type": "code", "execution_count": 14, "id": "39715be3-a1ed-43db-a950-6def2d0a64bf", "metadata": { "tags": [] }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Filtering completed in 0.21 seconds\n", "Found 32234 files with face_labels == 15 (chamfer)\n", "\n", "[ 10 12 15 ... 162392 162398 162401]\n" ] } ], "source": [ "start_time = time.time()\n", "\n", "# Condition evaluated on the faces group: True where a face is labeled 15 (chamfer)\n", "label_is_chamfer = lambda ds: ds['face_labels'] == 15\n", "\n", "filelist = explorer.get_file_list(group=\"faces\", where=label_is_chamfer)\n", "print(f\"Filtering completed in {(time.time() - start_time):.2f} seconds\")\n", "print(f\"Found {len(filelist)} files with face_labels == 15 (chamfer)\\n\")\n", "print(filelist)" ] }, { "cell_type": "markdown", "id": "30cfce7a-515a-4ebe-87bd-2286235a823b", "metadata": {}, "source": [ "# Query data for a single file" ] }, { "cell_type": "code", "execution_count": 15, "id": "83350a38-9bde-47a3-84a5-2bd47724a5e0", "metadata": { "tags": [] }, "outputs": [], "source": [ "def demo_query_single_file(explorer, file_id):\n", "    \"\"\"Show how to access and query dataset details for a single file.\"\"\"\n", "    import time\n", "\n", "    print(\"=== Single File Dataset Access ===\")\n", "    # Get and print the Parquet metadata for this file\n", "    df_info = explorer.get_parquet_info_by_code(file_id)\n", "    print(\"Files info:\")\n", "    for column in df_info.columns:\n", "        print(f\"Column: {column}\")\n", "        for value in df_info[column]:\n", "            print(f\" {value}\")\n", "    print()\n", "\n", "    # Access every dataset group for this file\n", "    groups = [\"faces\", \"file\", \"edges\", \"graph\"]\n", "    datasets = {grp: explorer.file_dataset(file_id_code=file_id, group=grp) for grp in groups}\n", "\n", "    print(f\"Datasets for file ID '{file_id}':\")\n", "    for grp, ds in datasets.items():\n", "        for name, da in ds.data_vars.items():\n", "            print(f\" [{grp}] VARIABLE: {name}, Shape: {da.shape}, Dims: {da.dims}, Size: {da.size}\")\n", "    print()\n", "\n", "    # Select the UV grid of face index 2 and load it into memory, timing the query\n", "    start_time = time.time()\n", "    uv_grid_data = datasets[\"faces\"][\"face_uv_grids\"].isel(face=2)\n", "    print(\"uv_grids data for face index 2:\")\n", "    np_uvgrid = uv_grid_data.data.compute()\n", "    print(f\"Query took {(time.time() - start_time):.2f} seconds\\n\")" ] }, { "cell_type": "code", "execution_count": 16, "id": "4dd615c0-6659-4407-b3a8-951770560272", "metadata": { "tags": [] }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "=== Single File Dataset Access ===\n", "Files info:\n", "Column: name\n", " 07207dfd094fe0ebe9368ded4c271b23\n", "Column: description\n", " C:\\Temp\\Cadsynth_aag\\step\\20221124_154714_17096.step\n", "Column: subset\n", " test\n", "Column: id\n", " 4500\n", "Column: table_name\n", " file_info\n", "\n", "Datasets for file ID '4500':\n", " [faces] VARIABLE: face_areas, Shape: (38,), Dims: ('face',), Size: 38\n", " [faces] VARIABLE: face_indices, Shape: (38,), Dims: ('face',), Size: 38\n", " [faces] VARIABLE: face_labels, Shape: (38,), Dims: ('face',), Size: 38\n", " [faces] VARIABLE: face_loops, Shape: (38,), Dims: ('face',), Size: 38\n", " [faces] VARIABLE: face_neighborscount, Shape: (38,), Dims: ('face',), Size: 38\n", " [faces] VARIABLE: face_types, Shape: (38,), Dims: ('face',), Size: 38\n", " [faces] VARIABLE: face_uv_grids, Shape: (38, 5, 5, 7), Dims: ('face', 'uv_x', 'uv_y', 'component'), Size: 6650\n", " [faces] VARIABLE: file_id_code_faces, Shape: (38,), Dims: 
('face',), Size: 38\n", " [file] VARIABLE: duration_dglconvert, Shape: (1,), Dims: ('file',), Size: 1\n", " [file] VARIABLE: file_id_code_file, Shape: (1,), Dims: ('file',), Size: 1\n", " [file] VARIABLE: size_cadfile, Shape: (1,), Dims: ('file',), Size: 1\n", " [file] VARIABLE: size_dglfile, Shape: (1,), Dims: ('file',), Size: 1\n", " [edges] VARIABLE: edge_convexities, Shape: (96,), Dims: ('edge',), Size: 96\n", " [edges] VARIABLE: edge_dihedral_angles, Shape: (96,), Dims: ('edge',), Size: 96\n", " [edges] VARIABLE: edge_indices, Shape: (96,), Dims: ('edge',), Size: 96\n", " [edges] VARIABLE: edge_lengths, Shape: (96,), Dims: ('edge',), Size: 96\n", " [edges] VARIABLE: edge_types, Shape: (96,), Dims: ('edge',), Size: 96\n", " [edges] VARIABLE: edge_u_grids, Shape: (96, 5, 6), Dims: ('edge', 'dim_x', 'component'), Size: 2880\n", " [edges] VARIABLE: file_id_code_edges, Shape: (96,), Dims: ('edge',), Size: 96\n", " [graph] VARIABLE: destination, Shape: (96,), Dims: ('edge',), Size: 96\n", " [graph] VARIABLE: file_id_code_graph, Shape: (96,), Dims: ('edge',), Size: 96\n", " [graph] VARIABLE: num_nodes, Shape: (96,), Dims: ('edge',), Size: 96\n", " [graph] VARIABLE: source, Shape: (96,), Dims: ('edge',), Size: 96\n", "\n", "uv_grids data for face index 2:\n", "Query took 1.13 seconds\n", "\n" ] } ], "source": [ "demo_query_single_file(explorer, file_id=4500)" ] }, { "cell_type": "markdown", "id": "ca936a97-a46e-4b43-bebc-7c3623d1dff7", "metadata": {}, "source": [ "# Create subsets (train, validation, test) based on the label distribution" ] }, { "cell_type": "code", "execution_count": 17, "id": "9b25d773-2a63-45b9-8b3f-8a21c751c2e7", "metadata": { "tags": [] }, "outputs": [], "source": [ "\n", "def demo_stratified_splits(explorer):\n", "    \"\"\"Show building a membership matrix and performing stratified splits.\"\"\"\n", "    import time\n", "\n", "    import numpy as np\n", "    from iterstrat.ml_stratifiers import MultilabelStratifiedShuffleSplit\n", "\n", "    print(\"=== Membership Matrix and Data Splitting ===\")\n", "    start_time = time.time()\n", "\n", "    # Look up the label table by wildcard match on \"label\" and use its table name as the key\n", "    df_label = explorer.get_descriptions(\"label\", None, True)\n", "    label_key = df_label[\"table_name\"].iloc[0]\n", "\n", "    # File-level labels live in the \"file\" group, face-level labels in \"faces\"\n", "    group = \"file\" if label_key == \"file_label\" else \"faces\"\n", "\n", "    # Membership matrix: one row per file, one column per label category\n", "    matrix, file_codes, _ = explorer.build_membership_matrix(group=group, key=label_key, bins_or_categories=None, as_counts=False)\n", "\n", "    # First split: 70% train, 30% temporary (n_splits=1, so take the single split)\n", "    msss = MultilabelStratifiedShuffleSplit(n_splits=1, test_size=0.30, random_state=42)\n", "    train_idx, temp_idx = next(msss.split(np.arange(len(matrix))[:, None], matrix))\n", "\n", "    # Second split of the temporary set: 50% validation, 50% test (about 15% each overall)\n", "    msss2 = MultilabelStratifiedShuffleSplit(n_splits=1, test_size=0.5, random_state=43)\n", "    val_sub, test_sub = next(msss2.split(np.arange(len(temp_idx))[:, None], matrix[temp_idx]))\n", "    # Map positions within the temporary set back to rows of the full matrix\n", "    val_idx = temp_idx[val_sub]\n", "    test_idx = temp_idx[test_sub]\n", "\n", "    print(\"Train file IDs:\", file_codes[train_idx].shape)\n", "    print(\"Validation file IDs:\", file_codes[val_idx].shape)\n", "    print(\"Test file IDs:\", file_codes[test_idx].shape)\n", "    print(f\"Stratified Splitting completed in {(time.time() - start_time):.2f} seconds\")\n", "    print()" ] }, { "cell_type": "code", "execution_count": 18, "id": "a4fc2bb1-25ab-47c0-b513-da792ab4b3ae", "metadata": { "tags": [] }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "=== Membership Matrix and Data Splitting ===\n", "Train file IDs: (113479,)\n", "Validation file IDs: (24453,)\n", "Test file IDs: (24480,)\n", "Stratified Splitting completed in 15.14 seconds\n", "\n" ] } ], "source": [ "demo_stratified_splits(explorer)" ] } ], "metadata": { "kernelspec": { "display_name": "HOOPS AI (CPU)", "language": "python", "name": "hoops_ai_cpu" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, 
"file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.9.21" } }, "nbformat": 4, "nbformat_minor": 5 }