Train
- class Trainer(token_handler: TokenHandler, task: str | Task | None = None, yaml_path: str | None = None)
Bases: NetsPressoBase
- set_dataset_config(name: str, root_path: str, train_image: str = 'images/train', train_label: str = 'labels/train', valid_image: str = 'images/valid', valid_label: str = 'labels/valid', test_image: str = 'images/valid', test_label: str = 'labels/valid', id_mapping: List[str] | Dict[str, str] | str | None = None)
Set the dataset configuration for the Trainer.
- Parameters:
name (str) – The name of the dataset.
root_path (str) – Root directory of the dataset.
train_image (str, optional) – The directory for training images. Should be a path relative to the root directory. Defaults to “images/train”.
train_label (str, optional) – The directory for training labels. Should be a path relative to the root directory. Defaults to “labels/train”.
valid_image (str, optional) – The directory for validation images. Should be a path relative to the root directory. Defaults to “images/valid”.
valid_label (str, optional) – The directory for validation labels. Should be a path relative to the root directory. Defaults to “labels/valid”.
test_image (str, optional) – The directory for test images. Should be a path relative to the root directory. Defaults to “images/valid”.
test_label (str, optional) – The directory for test labels. Should be a path relative to the root directory. Defaults to “labels/valid”.
id_mapping (Union[List[str], Dict[str, str], str], optional) – ID mapping for the dataset. Defaults to None.
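Because every image and label directory is given relative to root_path, it can help to resolve and check them before configuring the Trainer. The following is a minimal sketch of that idea; the helper resolve_dataset_dirs is hypothetical and not part of NetsPresso.

```python
from pathlib import Path
from typing import Dict, List, Tuple


def resolve_dataset_dirs(
    root_path: str, **relative_dirs: str
) -> Tuple[Dict[str, Path], List[str]]:
    """Join each relative directory onto root_path (hypothetical helper).

    Returns the resolved paths and the keys whose directory does not exist.
    """
    root = Path(root_path)
    resolved = {key: root / rel for key, rel in relative_dirs.items()}
    missing = [key for key, path in resolved.items() if not path.is_dir()]
    return resolved, missing


if __name__ == "__main__":
    dirs, missing = resolve_dataset_dirs(
        "/root/traffic-sign",
        train_image="images/train",
        train_label="labels/train",
        valid_image="images/valid",
        valid_label="labels/valid",
    )
    for key, path in dirs.items():
        print(f"{key}: {path}")
    if missing:
        print("Missing directories:", ", ".join(missing))
```

The same keyword names as set_dataset_config are used here only for readability; the helper accepts any set of relative directories.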
- set_model_config(model_name: str, img_size: int, use_pretrained: bool = True, load_head: bool = False, path: str | None = None, fx_model_path: str | None = None, optimizer_path: str | None = None)
Set the model configuration for the Trainer.
- Parameters:
model_name (str) – Name of the model.
img_size (int) – Image size for the model.
use_pretrained (bool, optional) – Whether to use a pre-trained model. Defaults to True.
load_head (bool, optional) – Whether to load the model head. Defaults to False.
path (str, optional) – Path to the model. Defaults to None.
fx_model_path (str, optional) – Path to the FX model. Defaults to None.
optimizer_path (str, optional) – Path to the optimizer. Defaults to None.
- Raises:
ValueError – If the specified model is not supported for the current task.
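The ValueError above can be pictured as a membership check against the models available for the current task. This is only a sketch of that behavior, not the library's implementation; ensure_supported and its arguments are hypothetical names.

```python
from typing import Sequence


def ensure_supported(model_name: str, available_models: Sequence[str], task: str) -> str:
    """Return model_name if it is available for the task, else raise ValueError.

    Hypothetical stand-in for the check described in the Raises section.
    """
    if model_name not in available_models:
        raise ValueError(
            f"Model '{model_name}' is not supported for task '{task}'. "
            f"Available models: {list(available_models)}"
        )
    return model_name


if __name__ == "__main__":
    models = ["EfficientFormer", "YOLOX-S"]
    print(ensure_supported("YOLOX-S", models, "detection"))
```

In practice, printing trainer.available_models before calling set_model_config gives the list to choose from.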
- set_fx_model(fx_model_path: str)
Set the FX model path for retraining.
- Parameters:
fx_model_path (str) – The path to the FX model.
- Raises:
ValueError – If the model is not set. Please use ‘set_model_config’ for model setup.
- set_training_config(optimizer, scheduler, epochs: int = 3, batch_size: int = 8)
Set the training configuration.
- Parameters:
optimizer – The optimizer configuration.
scheduler – The learning rate scheduler configuration.
epochs (int, optional) – The total number of epochs for training the model. Defaults to 3.
batch_size (int, optional) – The number of samples in a single input batch. Defaults to 8.
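To get a feel for how epochs and batch_size interact, the number of optimizer steps can be estimated from the dataset size. The sketch below assumes each epoch visits every sample once and keeps the final partial batch; whether the Trainer drops the last batch is not stated here, so treat the result as an estimate. The function name is hypothetical.

```python
import math


def total_training_steps(num_samples: int, epochs: int = 3, batch_size: int = 8) -> int:
    """Estimate optimizer steps, assuming the last partial batch is kept."""
    steps_per_epoch = math.ceil(num_samples / batch_size)
    return steps_per_epoch * epochs


if __name__ == "__main__":
    # 1000 training images with the values used in the training example
    print(total_training_steps(1000, epochs=40, batch_size=16))
```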
- set_augmentation_config(train_transforms: List | None = None, inference_transforms: List | None = None)
Set the augmentation configuration for training.
- Parameters:
train_transforms (List, optional) – List of transforms for training. Defaults to None.
inference_transforms (List, optional) – List of transforms for inference. Defaults to None.
- set_logging_config(project_id: str | None = None, output_dir: str = './outputs', tensorboard: bool = True, csv: bool = False, image: bool = True, stdout: bool = True, save_optimizer_state: bool = True, validation_epoch: int = 10, save_checkpoint_epoch: int | None = None)
Set the logging configuration.
- Parameters:
project_id (str, optional) – Project name to save the experiment. If None, it is set as {task}_{model} (e.g. segmentation_segformer).
output_dir (str, optional) – Root directory for saving the experiment. Defaults to “./outputs”.
tensorboard (bool, optional) – Whether to use the tensorboard. Defaults to True.
csv (bool, optional) – Whether to save the result in csv format. Defaults to False.
image (bool, optional) – Whether to save the validation results. It is ignored if the task is classification. Defaults to True.
stdout (bool, optional) – Whether to log the standard output. Defaults to True.
save_optimizer_state (bool, optional) – Whether to save optimizer state with model checkpoint to resume training. Defaults to True.
validation_epoch (int, optional) – How often validation runs during training, in epochs. Defaults to 10.
save_checkpoint_epoch (int, optional) – How often a checkpoint is saved during training, in epochs. Defaults to None.
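Two of these settings are easy to reason about with a small sketch: the fallback project name and the validation schedule. The code below assumes that validation runs every validation_epoch epochs, counted from 1; that reading is an assumption rather than documented behavior, and both helper names are hypothetical.

```python
from typing import List, Optional


def default_project_id(task: str, model: str, project_id: Optional[str] = None) -> str:
    """Fall back to {task}_{model} when no project_id is given."""
    return project_id if project_id is not None else f"{task}_{model}"


def validation_epochs(total_epochs: int, validation_epoch: int = 10) -> List[int]:
    """Epochs (1-based) at which validation would run under the stated assumption."""
    return list(range(validation_epoch, total_epochs + 1, validation_epoch))


if __name__ == "__main__":
    print(default_project_id("segmentation", "segformer"))
    print(validation_epochs(40))
```

Under this reading, a run shorter than validation_epoch epochs would never trigger periodic validation, so it may be worth lowering validation_epoch for short runs.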
- set_environment_config(seed: int = 1, num_workers: int = 4)
Set the environment configuration.
- Parameters:
seed (int, optional) – Random seed. Defaults to 1.
num_workers (int, optional) – The number of multi-processing workers to be used by the data loader. Defaults to 4.
- train(gpus: str, project_name: str, output_dir: str | None = './outputs') → TrainerMetadata
Train the model with the specified configuration.
- Parameters:
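gpus (str) – GPU ids to use, separated by commas.
project_name (str) – Project name to save the experiment.
output_dir (str, optional) – Root directory for saving the experiment. Defaults to “./outputs”.
- Returns:
A TrainerMetadata object containing information about the training.
- Return type:
TrainerMetadata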
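The gpus argument is a single comma-separated string rather than a list. As a rough illustration of what such a string encodes, the sketch below splits it into integer ids; parse_gpu_ids is a hypothetical helper, and how the Trainer itself interprets the string is not shown in this reference.

```python
from typing import List


def parse_gpu_ids(gpus: str) -> List[int]:
    """Split a comma-separated GPU id string into integers (hypothetical helper)."""
    ids = [int(part) for part in gpus.split(",") if part.strip()]
    if not ids:
        raise ValueError("At least one GPU id is required.")
    return ids


if __name__ == "__main__":
    print(parse_gpu_ids("0, 1"))
```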
Example
Train
from netspresso import NetsPresso
from netspresso.enums import Task
from netspresso.trainer.augmentations import Resize
from netspresso.trainer.optimizers import AdamW
from netspresso.trainer.schedulers import CosineAnnealingWarmRestartsWithCustomWarmUp

netspresso = NetsPresso(email="YOUR_EMAIL", password="YOUR_PASSWORD")

# 1. Declare trainer
trainer = netspresso.trainer(task=Task.OBJECT_DETECTION)

# 2. Set config for training
# 2-1. Data
trainer.set_dataset_config(
    name="traffic_sign_config_example",
    root_path="/root/traffic-sign",
    train_image="images/train",
    train_label="labels/train",
    valid_image="images/valid",
    valid_label="labels/valid",
    id_mapping=["prohibitory", "danger", "mandatory", "other"],
)

# 2-2. Model
print(trainer.available_models)  # ['EfficientFormer', 'YOLOX-S', ...]
trainer.set_model_config(model_name="YOLOX-S", img_size=512)

# 2-3. Augmentation
trainer.set_augmentation_config(
    train_transforms=[Resize()],
    inference_transforms=[Resize()],
)

# 2-4. Training
optimizer = AdamW(lr=6e-3)
scheduler = CosineAnnealingWarmRestartsWithCustomWarmUp(warmup_epochs=10)
trainer.set_training_config(
    epochs=40,
    batch_size=16,
    optimizer=optimizer,
    scheduler=scheduler,
)

# 3. Train
training_result = trainer.train(gpus="0, 1", project_name="project_sample")
Retrain
from netspresso import NetsPresso
from netspresso.trainer.optimizers import AdamW
from netspresso.trainer.schedulers import CosineAnnealingWarmRestartsWithCustomWarmUp

netspresso = NetsPresso(email="YOUR_EMAIL", password="YOUR_PASSWORD")

# 1. Declare trainer
trainer = netspresso.trainer(yaml_path="./temp/hparams.yaml")

# 2. Set config for retraining
# 2-1. FX Model
trainer.set_fx_model(fx_model_path="./temp/FX_MODEL_PATH.pt")

# 2-2. Training
# set_training_config requires both an optimizer and a scheduler
trainer.set_training_config(
    epochs=30,
    batch_size=16,
    optimizer=AdamW(lr=6e-3),
    scheduler=CosineAnnealingWarmRestartsWithCustomWarmUp(warmup_epochs=10),
)

# 3. Train
trainer.train(gpus="0, 1", project_name="project_retrain_sample")