hoops_ai.ml.EXPERIMENTAL.flow_inference

Classes

FlowInference(cad_loader, flowmodel[, log_file])

class hoops_ai.ml.EXPERIMENTAL.flow_inference.FlowInference(cad_loader, flowmodel, log_file='training_errors.log')

Bases: object

Parameters:
  • cad_loader (CADLoader)

  • flowmodel (FlowModel)

  • log_file (str)
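The three methods documented below are meant to be called in a fixed order: load a checkpoint, preprocess an input file, then predict. The following is a minimal sketch of that order. `classify_file` and `InferenceLike` are hypothetical names introduced here for illustration. The helper accepts any object that exposes the three documented methods, so the sketch does not depend on how `CADLoader` or `FlowModel` are constructed.

```python
from typing import Any, Protocol

import numpy as np


class InferenceLike(Protocol):
    """Hypothetical protocol: the three FlowInference methods documented here."""

    def load_from_checkpoint(self, checkpoint_path: str) -> Any: ...
    def preprocess(self, file_path: str) -> Any: ...
    def predict_and_postprocess(self, batch: Any) -> np.ndarray: ...


def classify_file(inference: InferenceLike, checkpoint_path: str, file_path: str) -> np.ndarray:
    """Hypothetical helper: run the documented methods in order on one file."""
    # 1. Load the trained model from the checkpoint.
    inference.load_from_checkpoint(checkpoint_path)
    # 2. Turn the input file into a model-ready batch (Dict[str, torch.Tensor]).
    batch = inference.preprocess(file_path)
    # 3. Predict and return the class indices as a NumPy array.
    return inference.predict_and_postprocess(batch)
```

With the real class, you would pass a `FlowInference(cad_loader, flowmodel)` instance as `inference`. To process many files, call `load_from_checkpoint` once and then loop over `preprocess` and `predict_and_postprocess`, instead of reloading the checkpoint for every file.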

load_from_checkpoint(checkpoint_path)

Loads a model from a checkpoint.

Parameters:

checkpoint_path (str) – Path to the model checkpoint file.

predict_and_postprocess(batch)

Runs the loaded model on a preprocessed batch and post-processes the output into predicted class indices.

Parameters:

batch (Dict[str, torch.Tensor]) – Preprocessed input batch.

Returns:

Predicted class indices.

Return type:

np.ndarray
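This method returns class indices as an `np.ndarray`. A common way to obtain indices from per-class scores is to take the argmax over the class axis. The sketch below illustrates that step under stated assumptions. Whether `FlowInference` post-processes in exactly this way is not documented here. The score shape `(num_items, num_classes)` and the name `scores_to_class_indices` are also assumptions. If the scores arrive as a `torch.Tensor`, you would first convert them with `.detach().cpu().numpy()`.

```python
import numpy as np


def scores_to_class_indices(scores: np.ndarray) -> np.ndarray:
    """Hypothetical post-processing step: pick the highest-scoring class.

    Assumes `scores` has shape (num_items, num_classes), e.g. one row of
    scores per predicted item.
    """
    scores = np.asarray(scores)
    if scores.ndim != 2:
        raise ValueError(f"expected a 2-D score array, got shape {scores.shape}")
    # argmax over the class axis gives one integer class index per row.
    return scores.argmax(axis=1)


logits = np.array([[0.1, 2.0, -1.0],
                   [3.0, 0.5, 0.2]])
print(scores_to_class_indices(logits))  # [1 0]
```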

preprocess(file_path)

Preprocesses the input file to prepare it for the model.

Parameters:

file_path (str) – Path to the input DGL graph file in binary (.bin) format.

Returns:

Preprocessed batch ready for prediction.

Return type:

Dict[str, torch.Tensor]
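`preprocess` takes one DGL `.bin` file at a time. When you classify a whole folder, you first need the list of input paths. The following is a minimal sketch for collecting them. `find_graph_files` is a hypothetical helper, and it assumes that the graphs are stored one per file with a `.bin` extension directly inside the given directory.

```python
from pathlib import Path
from typing import List


def find_graph_files(directory: str) -> List[str]:
    """Hypothetical helper: list the .bin graph files in a directory.

    Assumes each DGL graph passed to preprocess() is stored as its own
    .bin file; adjust the glob pattern if your pipeline differs.
    """
    root = Path(directory)
    if not root.is_dir():
        raise NotADirectoryError(directory)
    # Keep regular files only, sorted so repeated runs use a stable order.
    return sorted(str(p) for p in root.glob("*.bin") if p.is_file())
```

Each returned path can then be passed to `preprocess`, and the resulting batch to `predict_and_postprocess`.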