Train

class Trainer(token_handler: TokenHandler, task: str | Task | None = None, yaml_path: str | None = None)[source]

Bases: NetsPressoBase

set_dataset_config(name: str, root_path: str, train_image: str = 'images/train', train_label: str = 'labels/train', valid_image: str = 'images/valid', valid_label: str = 'labels/valid', test_image: str = 'images/valid', test_label: str = 'labels/valid', id_mapping: List[str] | Dict[str, str] | str | None = None)[source]

Set the dataset configuration for the Trainer.

Parameters:
  • name (str) – The name of the dataset.

  • root_path (str) – Root directory of the dataset.

  • train_image (str, optional) – The directory for training images. Must be a path relative to root_path. Defaults to “images/train”.

  • train_label (str, optional) – The directory for training labels. Must be a path relative to root_path. Defaults to “labels/train”.

  • valid_image (str, optional) – The directory for validation images. Must be a path relative to root_path. Defaults to “images/valid”.

  • valid_label (str, optional) – The directory for validation labels. Must be a path relative to root_path. Defaults to “labels/valid”.

  • test_image (str, optional) – The directory for test images. Must be a path relative to root_path. Defaults to “images/valid”.

  • test_label (str, optional) – The directory for test labels. Must be a path relative to root_path. Defaults to “labels/valid”.

  • id_mapping (Union[List[str], Dict[str, str], str], optional) – ID mapping for the dataset. Defaults to None.
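As a minimal sketch of how the relative directory arguments relate to root_path, the helper below joins each one onto the root. The name resolve_dataset_dirs is hypothetical and not part of NetsPresso; it only illustrates the directory layout that the defaults assume.

```python
from pathlib import PurePosixPath
from typing import Dict


def resolve_dataset_dirs(
    root_path: str,
    train_image: str = "images/train",
    train_label: str = "labels/train",
    valid_image: str = "images/valid",
    valid_label: str = "labels/valid",
) -> Dict[str, str]:
    """Hypothetical helper (not a NetsPresso API).

    Joins each relative directory onto root_path to show the
    absolute locations implied by the dataset configuration.
    """
    root = PurePosixPath(root_path)
    relative = {
        "train_image": train_image,
        "train_label": train_label,
        "valid_image": valid_image,
        "valid_label": valid_label,
    }
    return {key: str(root / rel) for key, rel in relative.items()}


dirs = resolve_dataset_dirs("/root/traffic-sign")
for key, path in dirs.items():
    print(f"{key}: {path}")
```

With the defaults, a dataset rooted at /root/traffic-sign is expected to contain images/train, labels/train, images/valid, and labels/valid beneath it.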

set_model_config(model_name: str, img_size: int, use_pretrained: bool = True, load_head: bool = False, path: str | None = None, fx_model_path: str | None = None, optimizer_path: str | None = None)[source]

Set the model configuration for the Trainer.

Parameters:
  • model_name (str) – Name of the model.

  • img_size (int) – Image size for the model.

  • use_pretrained (bool, optional) – Whether to use a pre-trained model. Defaults to True.

  • load_head (bool, optional) – Whether to load the model head. Defaults to False.

  • path (str, optional) – Path to the model. Defaults to None.

  • fx_model_path (str, optional) – Path to the FX model. Defaults to None.

  • optimizer_path (str, optional) – Path to the optimizer. Defaults to None.

Raises:

ValueError – If the specified model is not supported for the current task.

set_fx_model(fx_model_path: str)[source]

Set the FX model path for retraining.

Parameters:

fx_model_path (str) – The path to the FX model.

Raises:

ValueError – If the model is not set. Please use ‘set_model_config’ for model setup.

set_training_config(optimizer, scheduler, epochs: int = 3, batch_size: int = 8)[source]

Set the training configuration.

Parameters:
  • optimizer – The optimizer configuration.

  • scheduler – The learning rate scheduler configuration.

  • epochs (int, optional) – The total number of epochs for training the model. Defaults to 3.

  • batch_size (int, optional) – The number of samples in a single input batch. Defaults to 8.
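To get a feel for how epochs and batch_size interact, the sketch below estimates the number of batches processed over a run. It assumes the final partial batch of each epoch is kept rather than dropped, which this reference does not state; total_training_steps is a hypothetical helper, not a NetsPresso function.

```python
import math


def total_training_steps(num_samples: int, epochs: int = 3, batch_size: int = 8) -> int:
    """Hypothetical helper: estimate batches processed over a full run.

    Assumption: the last, possibly smaller, batch of each epoch is kept.
    """
    steps_per_epoch = math.ceil(num_samples / batch_size)
    return epochs * steps_per_epoch


# 1000 training images with the settings used in the Train example below
steps = total_training_steps(num_samples=1000, epochs=40, batch_size=16)
print(steps)
```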

set_augmentation_config(train_transforms: List | None = None, inference_transforms: List | None = None)[source]

Set the augmentation configuration for training.

Parameters:
  • train_transforms (List, optional) – List of transforms for training. Defaults to None.

  • inference_transforms (List, optional) – List of transforms for inference. Defaults to None.

set_logging_config(project_id: str | None = None, output_dir: str = './outputs', tensorboard: bool = True, csv: bool = False, image: bool = True, stdout: bool = True, save_optimizer_state: bool = True, validation_epoch: int = 10, save_checkpoint_epoch: int | None = None)[source]

Set the logging configuration.

Parameters:
  • project_id (str, optional) – Project name to save the experiment. If None, it is set as {task}_{model} (e.g. segmentation_segformer).

  • output_dir (str, optional) – Root directory for saving the experiment. Defaults to “./outputs”.

  • tensorboard (bool, optional) – Whether to log to TensorBoard. Defaults to True.

  • csv (bool, optional) – Whether to save the results in CSV format. Defaults to False.

  • image (bool, optional) – Whether to save the validation results as images. Ignored if the task is classification. Defaults to True.

  • stdout (bool, optional) – Whether to log the standard output. Defaults to True.

  • save_optimizer_state (bool, optional) – Whether to save optimizer state with model checkpoint to resume training. Defaults to True.

  • validation_epoch (int, optional) – Validation frequency, in epochs, over the training run. Defaults to 10.

  • save_checkpoint_epoch (int, optional) – Checkpoint saving frequency, in epochs, over the training run. Defaults to None.
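The sketch below illustrates the default project_id naming and one plausible reading of the two frequency parameters. It assumes epochs are counted from 1, that an event fires whenever the epoch number is a multiple of the interval, and that None means no periodic event; none of this is confirmed by the reference, and both helper names are hypothetical.

```python
from typing import List, Optional


def default_project_id(task: str, model: str) -> str:
    """Mirror the documented fallback: {task}_{model}."""
    return f"{task}_{model}"


def epochs_at_interval(total_epochs: int, interval: Optional[int]) -> List[int]:
    """Hypothetical helper: epochs at which a periodic event would fire.

    Assumptions: epochs are 1-indexed, the event fires on multiples of
    `interval`, and `None` disables the periodic event.
    """
    if interval is None:
        return []
    return [epoch for epoch in range(1, total_epochs + 1) if epoch % interval == 0]


project_id = default_project_id("segmentation", "segformer")
validation_epochs = epochs_at_interval(total_epochs=40, interval=10)
checkpoint_epochs = epochs_at_interval(total_epochs=40, interval=None)
print(project_id, validation_epochs, checkpoint_epochs)
```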

set_environment_config(seed: int = 1, num_workers: int = 4)[source]

Set the environment configuration.

Parameters:
  • seed (int, optional) – Random seed. Defaults to 1.

  • num_workers (int, optional) – The number of multi-processing workers to be used by the data loader. Defaults to 4.
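The Example section does not call set_logging_config or set_environment_config, so here is a minimal sketch of how the two might be used together. The wrapper apply_run_settings and the RecordingTrainer stand-in are hypothetical; the stand-in only records the keyword arguments so the snippet runs without a NetsPresso login. With a real Trainer, the same keyword arguments would be passed to the methods documented above.

```python
from typing import Any, Dict


def apply_run_settings(trainer: Any, output_dir: str = "./outputs", seed: int = 1) -> None:
    """Hypothetical wrapper: forwards to the two documented setters."""
    trainer.set_logging_config(
        output_dir=output_dir,
        tensorboard=True,
        csv=True,
        validation_epoch=5,
        save_checkpoint_epoch=10,
    )
    trainer.set_environment_config(seed=seed, num_workers=4)


class RecordingTrainer:
    """Stand-in for a Trainer (not a NetsPresso class): records what it receives."""

    def __init__(self) -> None:
        self.calls: Dict[str, Dict[str, Any]] = {}

    def set_logging_config(self, **kwargs: Any) -> None:
        self.calls["logging"] = kwargs

    def set_environment_config(self, **kwargs: Any) -> None:
        self.calls["environment"] = kwargs


fake = RecordingTrainer()
apply_run_settings(fake, output_dir="./outputs/exp1", seed=42)
print(fake.calls)
```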

train(gpus: str, project_name: str, output_dir: str | None = './outputs') → TrainerMetadata[source]

Train the model with the specified configuration.

Parameters:
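  • gpus (str) – GPU IDs to use, separated by commas (e.g. “0, 1”).

  • project_name (str) – Project name to save the experiment.

  • output_dir (str, optional) – Root directory for saving the experiment. Defaults to “./outputs”.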

Returns:

Metadata containing information about the training.

Return type:

TrainerMetadata
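The gpus argument is a comma-separated string rather than a list. As a rough illustration of that format, the hypothetical helper below turns such a string into integer device IDs; how NetsPresso itself parses the value is not documented here.

```python
from typing import List


def parse_gpu_ids(gpus: str) -> List[int]:
    """Hypothetical helper (not a NetsPresso API).

    Splits a comma-separated GPU string such as "0, 1" into integer IDs,
    ignoring surrounding whitespace and empty fragments.
    """
    return [int(part) for part in gpus.split(",") if part.strip()]


gpu_ids = parse_gpu_ids("0, 1")
print(gpu_ids)
```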

Example

Train

from netspresso import NetsPresso
from netspresso.enums import Task
from netspresso.trainer.augmentations import Resize
from netspresso.trainer.optimizers import AdamW
from netspresso.trainer.schedulers import CosineAnnealingWarmRestartsWithCustomWarmUp


netspresso = NetsPresso(email="YOUR_EMAIL", password="YOUR_PASSWORD")

# 1. Declare trainer
trainer = netspresso.trainer(task=Task.OBJECT_DETECTION)

# 2. Set config for training
# 2-1. Data
trainer.set_dataset_config(
   name="traffic_sign_config_example",
   root_path="/root/traffic-sign",
   train_image="images/train",
   train_label="labels/train",
   valid_image="images/valid",
   valid_label="labels/valid",
   id_mapping=["prohibitory", "danger", "mandatory", "other"],
)

# 2-2. Model
print(trainer.available_models)  # ['EfficientFormer', 'YOLOX-S', ...]
trainer.set_model_config(model_name="YOLOX-S", img_size=512)

# 2-3. Augmentation
trainer.set_augmentation_config(
   train_transforms=[Resize()],
   inference_transforms=[Resize()],
)

# 2-4. Training
optimizer = AdamW(lr=6e-3)
scheduler = CosineAnnealingWarmRestartsWithCustomWarmUp(warmup_epochs=10)
trainer.set_training_config(
   epochs=40,
   batch_size=16,
   optimizer=optimizer,
   scheduler=scheduler,
)

# 3. Train
training_result = trainer.train(gpus="0, 1", project_name="project_sample")

Retrain

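from netspresso import NetsPresso
from netspresso.trainer.optimizers import AdamW
from netspresso.trainer.schedulers import CosineAnnealingWarmRestartsWithCustomWarmUp


netspresso = NetsPresso(email="YOUR_EMAIL", password="YOUR_PASSWORD")

# 1. Declare trainer
trainer = netspresso.trainer(yaml_path="./temp/hparams.yaml")

# 2. Set config for retraining
# 2-1. FX Model
trainer.set_fx_model(fx_model_path="./temp/FX_MODEL_PATH.pt")

# 2-2. Training
# set_training_config requires both an optimizer and a scheduler
optimizer = AdamW(lr=6e-3)
scheduler = CosineAnnealingWarmRestartsWithCustomWarmUp(warmup_epochs=10)
trainer.set_training_config(
   epochs=30,
   batch_size=16,
   optimizer=optimizer,
   scheduler=scheduler,
)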

# 3. Train
trainer.train(gpus="0, 1", project_name="project_retrain_sample")