Train Model
Script for training and evaluating a residual U-Net model for image segmentation. The RhizoNet training process includes logging, checkpointing, and metrics evaluation.
Dependencies:
- PyTorch Lightning
- MONAI
- scikit-image
- wandb
- Usage:
pip install rhizonet
train_rhizonet --config_file ./setup_files/setup-train.json --gpus 2 --strategy "ddp" --accelerator "gpu"
- main()
- train_model(args)
Train and evaluate RhizoNet on a specified dataset.
- Parameters
args (Namespace) – Command-line arguments containing:
- config_file (str): Path to the JSON configuration file of the model.
- gpus (int): Number of GPU nodes to use for training.
- strategy (str): Training strategy to use (e.g., 'ddp', 'dp').
- accelerator (str): Whether to train on 'cpu' or 'gpu'.
Returns: None
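The four documented arguments could be collected with a parser along the following lines. This is a minimal sketch, not the parser RhizoNet actually ships: the function name `build_parser` and all default values are assumptions made for illustration; only the argument names and types come from the description above.

```python
import argparse


def build_parser() -> argparse.ArgumentParser:
    """Hypothetical parser mirroring the four documented arguments."""
    parser = argparse.ArgumentParser(description="Train RhizoNet (illustrative sketch)")
    parser.add_argument("--config_file", type=str, required=True,
                        help="Path to the JSON configuration file of the model")
    # The defaults below are assumptions; the real script may differ.
    parser.add_argument("--gpus", type=int, default=1,
                        help="Number of GPU nodes to use for training")
    parser.add_argument("--strategy", type=str, default="ddp",
                        help="Training strategy, e.g. 'ddp' or 'dp'")
    parser.add_argument("--accelerator", type=str, default="gpu",
                        help="'cpu' or 'gpu'")
    return parser


# Parse an explicit list instead of sys.argv so the sketch runs anywhere.
args = build_parser().parse_args(
    ["--config_file", "./setup_files/setup-train.json", "--gpus", "2", "--strategy", "ddp"]
)
print(args.config_file, args.gpus, args.strategy, args.accelerator)
```

The resulting `Namespace` has the same shape as the `args` object that `train_model(args)` is documented to receive.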
Notes
The evaluation results are saved to the path specified by 'save_path' in the configuration file.
Training and validation metrics are also logged to the Weights & Biases (wandb) project 'rhizonet'.
Predictions for the full-size images in 'pred_data_dir' are generated after training and saved in the 'save_path' directory.
Metrics (accuracy, precision, recall, and IoU) are evaluated on the full-size images in 'pred_data_dir', and the results are saved in a metrics.json file.
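For reference, the four metrics named above can be derived from the confusion counts of a predicted mask against its ground truth. The sketch below is a simplified binary (foreground/background) illustration in plain Python, not RhizoNet's own evaluation code; the function name `binary_segmentation_metrics` and the toy masks are hypothetical.

```python
import json
from typing import Dict, Sequence


def binary_segmentation_metrics(pred: Sequence[int], target: Sequence[int]) -> Dict[str, float]:
    """Compute accuracy, precision, recall and IoU for flattened 0/1 masks."""
    if len(pred) != len(target):
        raise ValueError("pred and target must contain the same number of pixels")

    # Confusion counts over all pixels.
    tp = sum(1 for p, t in zip(pred, target) if p == 1 and t == 1)
    fp = sum(1 for p, t in zip(pred, target) if p == 1 and t == 0)
    fn = sum(1 for p, t in zip(pred, target) if p == 0 and t == 1)
    tn = sum(1 for p, t in zip(pred, target) if p == 0 and t == 0)

    def safe_div(num: int, den: int) -> float:
        # Return 0.0 when a metric is undefined (empty denominator).
        return num / den if den else 0.0

    return {
        "accuracy": safe_div(tp + tn, tp + tn + fp + fn),
        "precision": safe_div(tp, tp + fp),
        "recall": safe_div(tp, tp + fn),
        "iou": safe_div(tp, tp + fp + fn),
    }


# Toy flattened masks: 1 = foreground pixel, 0 = background pixel.
pred = [1, 1, 0, 0, 1, 0]
target = [1, 0, 0, 1, 1, 0]
metrics = binary_segmentation_metrics(pred, target)
print(json.dumps(metrics, indent=2))
```

IoU ignores true negatives, so on images dominated by background it is usually a stricter measure than accuracy.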
- Example:
- If 2 GPU nodes are available, run this script with:
python train.py --gpus 2 --strategy "ddp" --config_file "./setup_files/setup-train.json"
- If 1 GPU node is available, run this script with:
python train.py --gpus 1 --strategy "dp" --config_file "./setup_files/setup-train.json"