Trainer
Description
netspresso.trainer.trainer.Trainer
Bases: NetsPressoBase
__init__(token_handler, task=None, yaml_path=None)
Initialize the Trainer.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| task | Union[str, Task] | The type of task (classification, detection, segmentation). Either 'task' or 'yaml_path' must be provided, but not both. | None |
| yaml_path | str | Path to the YAML configuration file. Either 'task' or 'yaml_path' must be provided, but not both. | None |
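The rule that exactly one of `task` or `yaml_path` is accepted can be sketched as a small check. `check_trainer_args` is a hypothetical helper written for illustration, not part of `netspresso`, and the exception type the real constructor raises is an assumption.

```python
from typing import Optional


def check_trainer_args(task: Optional[str] = None, yaml_path: Optional[str] = None) -> str:
    """Hypothetical check mirroring the documented rule:
    exactly one of `task` or `yaml_path` must be given."""
    if (task is None) == (yaml_path is None):
        # Both were given, or neither was: the documented rule is violated.
        raise ValueError("Provide either 'task' or 'yaml_path', but not both.")
    # Report which source the configuration would come from.
    return "task" if task is not None else "yaml_path"


print(check_trainer_args(task="detection"))                 # task
print(check_trainer_args(yaml_path="./temp/hparams.yaml"))  # yaml_path
```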
set_augmentation_config(train_transforms=None, inference_transforms=None)
Set the augmentation configuration for training.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| train_transforms | List | List of transforms for training. Defaults to None. | None |
| inference_transforms | List | List of transforms for inference. Defaults to None. | None |
set_dataset_config(name, root_path, train_image='images/train', train_label='labels/train', valid_image='images/valid', valid_label='labels/valid', test_image='images/valid', test_label='labels/valid', id_mapping=None)
Set the dataset configuration for the Trainer.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| name | str | The name of the dataset. | required |
| root_path | str | Root directory of the dataset. | required |
| train_image | str | The directory for training images. Should be a path relative to the root directory. Defaults to "images/train". | 'images/train' |
| train_label | str | The directory for training labels. Should be a path relative to the root directory. Defaults to "labels/train". | 'labels/train' |
| valid_image | str | The directory for validation images. Should be a path relative to the root directory. Defaults to "images/valid". | 'images/valid' |
| valid_label | str | The directory for validation labels. Should be a path relative to the root directory. Defaults to "labels/valid". | 'labels/valid' |
| id_mapping | Union[List[str], Dict[str, str]] | ID mapping for the dataset. Defaults to None. | None |
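How the relative split directories combine with `root_path` can be sketched with `pathlib`. `resolve_dataset_paths` is a hypothetical helper that reuses the defaults from the `set_dataset_config` signature; rejecting absolute paths is an assumption added here, not documented library behavior.

```python
from pathlib import Path
from typing import Dict


def resolve_dataset_paths(
    root_path: str,
    train_image: str = "images/train",
    train_label: str = "labels/train",
    valid_image: str = "images/valid",
    valid_label: str = "labels/valid",
    test_image: str = "images/valid",
    test_label: str = "labels/valid",
) -> Dict[str, Path]:
    """Hypothetical helper: join each relative split directory onto root_path."""
    splits = {
        "train_image": train_image,
        "train_label": train_label,
        "valid_image": valid_image,
        "valid_label": valid_label,
        "test_image": test_image,
        "test_label": test_label,
    }
    root = Path(root_path)
    resolved = {}
    for key, rel in splits.items():
        if Path(rel).is_absolute():
            # Joining an absolute path would silently discard root_path.
            raise ValueError(f"{key} must be relative to root_path, got {rel!r}")
        resolved[key] = root / rel
    return resolved


paths = resolve_dataset_paths("/root/traffic-sign")
print(paths["train_image"])  # /root/traffic-sign/images/train (on POSIX)
print(paths["test_image"])   # /root/traffic-sign/images/valid (on POSIX)
```

The check matters because `Path("/a") / "/b"` evaluates to `/b`, so an absolute split path would silently ignore `root_path`. Note also that the test split defaults to the validation directories.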
set_environment_config(seed=1, num_workers=4)
Set the environment configuration.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| seed | int | Random seed. Defaults to 1. | 1 |
| num_workers | int | The number of multi-processing workers to be used by the data loader. Defaults to 4. | 4 |
set_fx_model(fx_model_path)
Set the FX model path for retraining.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| fx_model_path | str | The path to the FX model. | required |
Raises:
| Type | Description |
|---|---|
| ValueError | If the model is not set. Please use 'set_model_config' for model setup. |
set_logging_config(project_id=None, output_dir='./outputs', tensorboard=True, csv=False, image=True, stdout=True, save_optimizer_state=True, validation_epoch=10, save_checkpoint_epoch=None)
Set the logging configuration.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| project_id | str | Project name under which the experiment is saved. If None, it is set to {task}_{model} (e.g. segmentation_segformer). | None |
| output_dir | str | Root directory for saving the experiment. Defaults to "./outputs". | './outputs' |
| tensorboard | bool | Whether to use TensorBoard. Defaults to True. | True |
| csv | bool | Whether to save the results in CSV format. Defaults to False. | False |
| image | bool | Whether to save the validation results. Ignored if the task is classification. Defaults to True. | True |
| stdout | bool | Whether to log to standard output. Defaults to True. | True |
| save_optimizer_state | bool | Whether to save the optimizer state with the model checkpoint so that training can be resumed. Defaults to True. | True |
| validation_epoch | int | Validation frequency, in epochs, over the whole training process. Defaults to 10. | 10 |
| save_checkpoint_epoch | int | Checkpoint-saving frequency, in epochs, over the whole training process. Defaults to None. | None |
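To make `validation_epoch` concrete, the sketch below lists the epochs at which validation would run. It assumes validation fires on every epoch whose 1-based index is a multiple of `validation_epoch`; the library's actual scheduling (for example, whether the final epoch is always validated, or how `save_checkpoint_epoch=None` is resolved) is not documented here and is not modeled. `validation_epochs` is a hypothetical helper.

```python
from typing import List


def validation_epochs(total_epochs: int, validation_epoch: int = 10) -> List[int]:
    """Hypothetical: 1-based epochs at which validation would run,
    assuming it fires on every multiple of `validation_epoch`."""
    if validation_epoch <= 0:
        raise ValueError("validation_epoch must be a positive integer")
    return list(range(validation_epoch, total_epochs + 1, validation_epoch))


print(validation_epochs(40))     # [10, 20, 30, 40]
print(validation_epochs(30, 7))  # [7, 14, 21, 28]
```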
set_model_config(model_name, img_size, use_pretrained=True, load_head=False, path=None, fx_model_path=None, optimizer_path=None)
Set the model configuration for the Trainer.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| model_name | str | Name of the model. | required |
| img_size | int | Image size for the model. | required |
| use_pretrained | bool | Whether to use a pre-trained model. Defaults to True. | True |
| load_head | bool | Whether to load the model head. Defaults to False. | False |
| path | str | Path to the model. Defaults to None. | None |
| fx_model_path | str | Path to the FX model. Defaults to None. | None |
| optimizer_path | str | Path to the optimizer. Defaults to None. | None |
Raises:
| Type | Description |
|---|---|
| ValueError | If the specified model is not supported for the current task. |
set_training_config(optimizer, scheduler, epochs=3, batch_size=8)
Set the training configuration.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| optimizer | | The configuration of the optimizer. | required |
| scheduler | | The configuration of the learning rate scheduler. | required |
| epochs | int | The total number of epochs for training the model. Defaults to 3. | 3 |
| batch_size | int | The number of samples in a single batch. Defaults to 8. | 8 |
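The arithmetic linking `epochs`, `batch_size`, and dataset size is sketched below for a single device. `steps_per_epoch` and its `drop_last` flag are hypothetical; whether the trainer drops a final partial batch, and whether `batch_size` is per GPU or global in multi-GPU runs, are assumptions not covered by this reference.

```python
def steps_per_epoch(num_samples: int, batch_size: int = 8, drop_last: bool = False) -> int:
    """Number of optimizer steps in one epoch on a single device.
    `drop_last` (an assumption) controls whether a final partial batch is skipped."""
    if batch_size <= 0:
        raise ValueError("batch_size must be positive")
    if drop_last:
        return num_samples // batch_size
    # Ceiling division: a final partial batch still counts as a step.
    return (num_samples + batch_size - 1) // batch_size


# Hypothetical dataset of 1,000 training images with the Training example's settings.
per_epoch = steps_per_epoch(1000, batch_size=16)
print(per_epoch)       # 63
print(per_epoch * 40)  # 2520 total steps for epochs=40
```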
train(gpus, project_name, output_dir='./outputs')
Train the model with the specified configuration.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| gpus | str | GPU ids to use, separated by commas. | required |
| project_name | str | Project name to save the experiment. | required |
Returns:
| Name | Type | Description |
|---|---|---|
| Dict | TrainerMetadata | A dictionary containing information about the training. |
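The `gpus` argument is a comma-separated string, and the examples below pass `"0, 1"` with a space. The sketch shows one way such a string can be turned into integer ids; `parse_gpu_ids` is hypothetical, and rejecting duplicates is an assumption, not documented behavior of `train`.

```python
from typing import List


def parse_gpu_ids(gpus: str) -> List[int]:
    """Hypothetical parser for a comma-separated GPU id string such as "0, 1"."""
    # int() tolerates surrounding whitespace; empty pieces (e.g. a trailing comma) are skipped.
    ids = [int(part) for part in gpus.split(",") if part.strip()]
    if len(ids) != len(set(ids)):
        raise ValueError(f"duplicate GPU ids in {gpus!r}")
    return ids


print(parse_gpu_ids("0, 1"))  # [0, 1]
print(parse_gpu_ids("0"))     # [0]
```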
Examples
Training
```python
from netspresso import NetsPresso
from netspresso.enums import Task
from netspresso.trainer.augmentations import Resize
from netspresso.trainer.optimizers import AdamW
from netspresso.trainer.schedulers import CosineAnnealingWarmRestartsWithCustomWarmUp

# Login with API key (recommended)
# Get your API token from: https://account.netspresso.ai/api-token
netspresso = NetsPresso(api_key="YOUR_API_KEY")
# Note: Email/password login will be deprecated soon
# netspresso = NetsPresso(email="YOUR_EMAIL", password="YOUR_PASSWORD")

# 1. Declare trainer
trainer = netspresso.trainer(task=Task.OBJECT_DETECTION)

# 2. Set config for training
# 2-1. Data
trainer.set_dataset_config(
    name="traffic_sign_config_example",
    root_path="/root/traffic-sign",
    train_image="images/train",
    train_label="labels/train",
    valid_image="images/valid",
    valid_label="labels/valid",
    id_mapping=["prohibitory", "danger", "mandatory", "other"],
)

# 2-2. Model
print(trainer.available_models)  # ['EfficientFormer', 'YOLOX-S', ...]
trainer.set_model_config(model_name="YOLOX-S", img_size=512)

# 2-3. Augmentation
trainer.set_augmentation_config(
    train_transforms=[Resize()],
    inference_transforms=[Resize()],
)

# 2-4. Training
optimizer = AdamW(lr=6e-3)
scheduler = CosineAnnealingWarmRestartsWithCustomWarmUp(warmup_epochs=10)
trainer.set_training_config(
    epochs=40,
    batch_size=16,
    optimizer=optimizer,
    scheduler=scheduler,
)

# 3. Train
training_result = trainer.train(gpus="0, 1", project_name="project_sample")
```
Retraining
```python
from netspresso import NetsPresso
from netspresso.trainer.optimizers import AdamW
from netspresso.trainer.schedulers import CosineAnnealingWarmRestartsWithCustomWarmUp

# Login with API key (recommended)
# Get your API token from: https://account.netspresso.ai/api-token
netspresso = NetsPresso(api_key="YOUR_API_KEY")
# Note: Email/password login will be deprecated soon
# netspresso = NetsPresso(email="YOUR_EMAIL", password="YOUR_PASSWORD")

# 1. Declare trainer
trainer = netspresso.trainer(yaml_path="./temp/hparams.yaml")

# 2. Set config for retraining
# 2-1. FX Model
trainer.set_fx_model(fx_model_path="./temp/FX_MODEL_PATH.pt")

# 2-2. Training
# `scheduler` is a required argument of set_training_config;
# the scheduler settings here mirror the Training example.
trainer.set_training_config(
    epochs=30,
    batch_size=16,
    optimizer=AdamW(lr=6e-3),
    scheduler=CosineAnnealingWarmRestartsWithCustomWarmUp(warmup_epochs=10),
)

# 3. Train
trainer.train(gpus="0, 1", project_name="project_retrain_sample")
```