Skip to content

Trainer

Description

netspresso.trainer.trainer.Trainer

Bases: NetsPressoBase

__init__(token_handler, task=None, yaml_path=None)

Initialize the Trainer.

Parameters:

Name Type Description Default
task Union[str, Task]

The type of task (classification, detection, segmentation). Either 'task' or 'yaml_path' must be provided, but not both.

None
yaml_path str

Path to the YAML configuration file. Either 'task' or 'yaml_path' must be provided, but not both.

None

set_augmentation_config(train_transforms=None, inference_transforms=None)

Set the augmentation configuration for training.

Parameters:

Name Type Description Default
train_transforms List

List of transforms for training. Defaults to None.

None
inference_transforms List

List of transforms for inference. Defaults to None.

None

set_dataset_config(name, root_path, train_image='images/train', train_label='labels/train', valid_image='images/valid', valid_label='labels/valid', test_image='images/valid', test_label='labels/valid', id_mapping=None)

Set the dataset configuration for the Trainer.

Parameters:

Name Type Description Default
name str

The name of the dataset.

required
root_path str

Root directory of the dataset.

required
train_image str

The directory for training images. Should be a path relative to the root directory. Defaults to "images/train".

'images/train'
train_label str

The directory for training labels. Should be a path relative to the root directory. Defaults to "labels/train".

'labels/train'
valid_image str

The directory for validation images. Should be a path relative to the root directory. Defaults to "images/valid".

'images/valid'
valid_label str

The directory for validation labels. Should be a path relative to the root directory. Defaults to "labels/valid".

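'labels/valid'
test_image str

The directory for test images. Should be a path relative to the root directory. Defaults to "images/valid".

'images/valid'
test_label str

The directory for test labels. Should be a path relative to the root directory. Defaults to "labels/valid".

'labels/valid'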
id_mapping Union[List[str], Dict[str, str]]

ID mapping for the dataset. Defaults to None.

None

set_environment_config(seed=1, num_workers=4)

Set the environment configuration.

Parameters:

Name Type Description Default
seed int

Random seed. Defaults to 1.

1
num_workers int

The number of multi-processing workers to be used by the data loader. Defaults to 4.

4

set_fx_model(fx_model_path)

Set the FX model path for retraining.

Parameters:

Name Type Description Default
fx_model_path str

The path to the FX model.

required

Raises:

Type Description
ValueError

If the model is not set. Please use 'set_model_config' for model setup.

set_logging_config(project_id=None, output_dir='./outputs', tensorboard=True, csv=False, image=True, stdout=True, save_optimizer_state=True, validation_epoch=10, save_checkpoint_epoch=None)

Set the logging configuration.

Parameters:

Name Type Description Default
project_id str

Project name to save the experiment. If None, it is set as {task}_{model} (e.g. segmentation_segformer).

None
output_dir str

Root directory for saving the experiment. Defaults to "./outputs".

'./outputs'
tensorboard bool

Whether to log to TensorBoard. Defaults to True.

True
csv bool

Whether to save the result in csv format. Defaults to False.

False
image bool

Whether to save the validation results. It is ignored if the task is classification. Defaults to True.

True
stdout bool

Whether to log the standard output. Defaults to True.

True
save_optimizer_state bool

Whether to save optimizer state with model checkpoint to resume training. Defaults to True.

True
validation_epoch int

Validation frequency in total training process. Defaults to 10.

10
save_checkpoint_epoch int

Checkpoint saving frequency in total training process. Defaults to None.

None

set_model_config(model_name, img_size, use_pretrained=True, load_head=False, path=None, fx_model_path=None, optimizer_path=None)

Set the model configuration for the Trainer.

Parameters:

Name Type Description Default
model_name str

Name of the model.

required
img_size int

Image size for the model.

required
use_pretrained bool

Whether to use a pre-trained model. Defaults to True.

True
load_head bool

Whether to load the model head. Defaults to False.

False
path str

Path to the model. Defaults to None.

None
fx_model_path str

Path to the FX model. Defaults to None.

None
optimizer_path str

Path to the optimizer. Defaults to None.

None

Raises:

Type Description
ValueError

If the specified model is not supported for the current task.

set_training_config(optimizer, scheduler, epochs=3, batch_size=8)

Set the training configuration.

Parameters:

Name Type Description Default
optimizer

The optimizer configuration.

required
scheduler

The learning rate scheduler configuration.

required
epochs int

The total number of epochs for training the model. Defaults to 3.

3
batch_size int

The number of samples in a single input batch. Defaults to 8.

8

train(gpus, project_name, output_dir='./outputs')

Train the model with the specified configuration.

Parameters:

Name Type Description Default
gpus str

GPU ids to use, separated by commas.

required
project_name str

Project name to save the experiment.

required
output_dir str

Root directory for saving the experiment. Defaults to "./outputs".

'./outputs'

Returns:

Name Type Description
Dict TrainerMetadata

A dictionary containing information about the training.

Examples

Training

from netspresso import NetsPresso
from netspresso.enums import Task
from netspresso.trainer.augmentations import Resize
from netspresso.trainer.optimizers import AdamW
from netspresso.trainer.schedulers import CosineAnnealingWarmRestartsWithCustomWarmUp

# Authenticate with an API key (the recommended method).
# API tokens are issued at: https://account.netspresso.ai/api-token
netspresso = NetsPresso(api_key="YOUR_API_KEY")

# Note: email/password login will be deprecated soon.
# netspresso = NetsPresso(email="YOUR_EMAIL", password="YOUR_PASSWORD")

# Step 1: create a trainer for the object-detection task.
trainer = netspresso.trainer(task=Task.OBJECT_DETECTION)

# Step 2: configure the training run.
# Step 2-1: dataset. Image/label directories are relative to root_path.
dataset_options = {
    "name": "traffic_sign_config_example",
    "root_path": "/root/traffic-sign",
    "train_image": "images/train",
    "train_label": "labels/train",
    "valid_image": "images/valid",
    "valid_label": "labels/valid",
    "id_mapping": ["prohibitory", "danger", "mandatory", "other"],
}
trainer.set_dataset_config(**dataset_options)

# Step 2-2: model. Print the model names available for this task first,
# e.g. ['EfficientFormer', 'YOLOX-S', ...]
print(trainer.available_models)
trainer.set_model_config(model_name="YOLOX-S", img_size=512)

# Step 2-3: augmentation, set separately for training and inference.
trainer.set_augmentation_config(train_transforms=[Resize()], inference_transforms=[Resize()])

# Step 2-4: optimizer, learning-rate scheduler, epochs and batch size.
optimizer = AdamW(lr=6e-3)
scheduler = CosineAnnealingWarmRestartsWithCustomWarmUp(warmup_epochs=10)
trainer.set_training_config(optimizer, scheduler, epochs=40, batch_size=16)

# Step 3: start training on GPUs 0 and 1 and keep the returned metadata.
training_result = trainer.train(gpus="0, 1", project_name="project_sample")

Retraining

from netspresso import NetsPresso
from netspresso.trainer.optimizers import AdamW
from netspresso.trainer.schedulers import CosineAnnealingWarmRestartsWithCustomWarmUp

# Login with API key (recommended)
# Get your API token from: https://account.netspresso.ai/api-token
netspresso = NetsPresso(api_key="YOUR_API_KEY")

# Note: Email/password login will be deprecated soon
# netspresso = NetsPresso(email="YOUR_EMAIL", password="YOUR_PASSWORD")

# 1. Declare trainer
# Passing yaml_path (instead of task) initializes the trainer from an
# existing YAML configuration file.
trainer = netspresso.trainer(yaml_path="./temp/hparams.yaml")

# 2. Set config for retraining
# 2-1. FX Model
trainer.set_fx_model(fx_model_path="./temp/FX_MODEL_PATH.pt")

# 2-2. Training
# set_training_config requires both an optimizer and a scheduler; omitting
# the scheduler raises a TypeError for the missing required argument.
trainer.set_training_config(
    epochs=30,
    batch_size=16,
    optimizer=AdamW(lr=6e-3),
    scheduler=CosineAnnealingWarmRestartsWithCustomWarmUp(warmup_epochs=10),
)

# 3. Train
trainer.train(gpus="0, 1", project_name="project_retrain_sample")