Source code for netspresso.trainer.trainer

import warnings
from pathlib import Path
from typing import Any, Dict, List, Optional, Union

from loguru import logger
from omegaconf import OmegaConf

from netspresso.base import NetsPressoBase
from netspresso.clients.auth import TokenHandler
from netspresso.clients.launcher import launcher_client_v2
from netspresso.enums import Framework, Optimizer, Scheduler, ServiceTask, Status, Task
from netspresso.exceptions.trainer import (
    BaseDirectoryNotFoundException,
    DirectoryNotFoundException,
    FailedTrainingException,
    FileNotFoundErrorException,
    NotSetDatasetException,
    NotSetModelException,
    NotSupportedModelException,
    NotSupportedTaskException,
    RetrainingFunctionException,
    TaskOrYamlPathException,
)
from netspresso.metadata.common import InputShape
from netspresso.metadata.trainer import TrainerMetadata
from netspresso.trainer.augmentations import AUGMENTATION_CONFIG_TYPE, AugmentationConfig, Transform
from netspresso.trainer.data import DATA_CONFIG_TYPE, ImageLabelPathConfig, PathConfig
from netspresso.trainer.models import (
    CLASSIFICATION_MODELS,
    DETECTION_MODELS,
    SEGMENTATION_MODELS,
    CheckpointConfig,
    ModelConfig,
)
from netspresso.trainer.trainer_configs import TrainerConfigs
from netspresso.trainer.training import TRAINING_CONFIG_TYPE, EnvironmentConfig, LoggingConfig, ScheduleConfig
from netspresso.utils import FileHandler
from netspresso.utils.metadata import MetadataHandler


class Trainer(NetsPressoBase):
    def __init__(
        self, token_handler: TokenHandler, task: Optional[Union[str, Task]] = None, yaml_path: Optional[str] = None
    ) -> None:
        """Initialize the Trainer.

        Args:
            token_handler (TokenHandler): Token handler used for authenticating with NetsPresso services.
            task (Union[str, Task], optional): The type of task (classification, detection, segmentation). Either 'task' or 'yaml_path' must be provided, but not both.
            yaml_path (str, optional): Path to the YAML configuration file. Either 'task' or 'yaml_path' must be provided, but not both.
        """
        self.token_handler = token_handler
        self.deprecated_names = {
            "EfficientFormer": "EfficientFormer-L1",
            "MobileNetV3_Small": "MobileNetV3-S",
            "MobileNetV3_Large": "MobileNetV3-L",
            "ViT-Tiny": "ViT-T",
            "MixNet-Small": "MixNet-S",
            "MixNet-Medium": "MixNet-M",
            "MixNet-Large": "MixNet-L",
            "PIDNet": "PIDNet-S",
        }

        if (task is not None) == (yaml_path is not None):
            raise TaskOrYamlPathException()

        if task is not None:
            self._initialize_from_task(task)
        elif yaml_path is not None:
            self._initialize_from_yaml(yaml_path)

    def _initialize_from_task(self, task: Union[str, Task]) -> None:
        """Initialize the Trainer object based on the provided task.

        Args:
            task (Union[str, Task]): The task for which the Trainer is initialized.
        """
        self.task = self._validate_task(task)
        self.available_models = list(self._get_available_models().keys())
        self.data = None
        self.model = None
        self.training = TRAINING_CONFIG_TYPE[self.task]()
        self.augmentation = AUGMENTATION_CONFIG_TYPE[self.task]()
        self.logging = LoggingConfig()
        self.environment = EnvironmentConfig()

    def _initialize_from_yaml(self, yaml_path: str) -> None:
        """Initialize the Trainer object based on the configuration provided in a YAML file.

        Args:
            yaml_path (str): The path to the YAML file containing the configuration.
        """
        hparams = OmegaConf.load(yaml_path)
        hparams["model"].pop("single_task_model")

        metadata_path = Path(yaml_path).parent / "metadata.json"
        metadata = FileHandler.load_json(metadata_path)
        self.model_name = metadata["model_info"]["model"]
        self.img_size = hparams["augmentation"]["img_size"]

        self.task = hparams["data"]["task"]
        self.available_models = list(self._get_available_models().keys())
        self.data = DATA_CONFIG_TYPE[self.task](**hparams["data"])
        self.model = ModelConfig(**hparams["model"])
        self.training = ScheduleConfig(**hparams["training"])
        self.augmentation = AugmentationConfig(**hparams["augmentation"])
        self.logging = LoggingConfig(**hparams["logging"])
        self.environment = EnvironmentConfig(**hparams["environment"])

    def _validate_task(self, task: Union[str, Task]):
        """Validate the provided task.

        Args:
            task (Union[str, Task]): The task to be validated.

        Raises:
            NotSupportedTaskException: If the provided task is not supported.

        Returns:
            Task: The validated task.
        """
        available_tasks = [task.value for task in Task]
        if task not in available_tasks:
            raise NotSupportedTaskException(available_tasks, task)
        return task

    def _validate_config(self):
        """Validate the configuration setup.

        Raises:
            NotSetDatasetException: Raised if the dataset is not set. Use `set_dataset_config` or `set_dataset_config_with_yaml` to set the dataset configuration.
            NotSetModelException: Raised if the model is not set. Use `set_model_config` or `set_model_config_with_yaml` to set the model configuration.
        """
        if self.data is None:
            raise NotSetDatasetException()
        if self.model is None:
            raise NotSetModelException()

    def _get_available_models(self) -> Dict[str, Any]:
        """Get available models for the current task, excluding deprecated model names.

        Returns:
            Dict[str, Any]: A dictionary mapping model names to available models.
        """
        available_models = {
            "classification": CLASSIFICATION_MODELS,
            "detection": DETECTION_MODELS,
            "segmentation": SEGMENTATION_MODELS,
        }[self.task]

        # Filter out deprecated names
        filtered_models = {
            name: config for name, config in available_models.items() if name not in self.deprecated_names
        }

        return filtered_models

    def _get_available_models_w_deprecated_names(self) -> Dict[str, Any]:
        """Get available models for the current task, including deprecated model names.

        Returns:
            Dict[str, Any]: A dictionary mapping model names to available models.
        """
        available_models = {
            "classification": CLASSIFICATION_MODELS,
            "detection": DETECTION_MODELS,
            "segmentation": SEGMENTATION_MODELS,
        }[self.task]

        return available_models
    def set_dataset_config(
        self,
        name: str,
        root_path: str,
        train_image: str = "images/train",
        train_label: str = "labels/train",
        valid_image: str = "images/valid",
        valid_label: str = "labels/valid",
        test_image: str = "images/valid",
        test_label: str = "labels/valid",
        id_mapping: Optional[Union[List[str], Dict[str, str], str]] = None,
    ):
        """Set the dataset configuration for the Trainer.

        Args:
            name (str): The name of the dataset.
            root_path (str): Root directory of the dataset.
            train_image (str, optional): The directory for training images. Should be a relative path to the root directory. Defaults to "images/train".
            train_label (str, optional): The directory for training labels. Should be a relative path to the root directory. Defaults to "labels/train".
            valid_image (str, optional): The directory for validation images. Should be a relative path to the root directory. Defaults to "images/valid".
            valid_label (str, optional): The directory for validation labels. Should be a relative path to the root directory. Defaults to "labels/valid".
            test_image (str, optional): The directory for test images. Should be a relative path to the root directory. Defaults to "images/valid".
            test_label (str, optional): The directory for test labels. Should be a relative path to the root directory. Defaults to "labels/valid".
            id_mapping (Union[List[str], Dict[str, str]], optional): ID mapping for the dataset. Defaults to None.
        """
        common_config = {
            "name": name,
            "path": PathConfig(
                root=root_path,
                train=ImageLabelPathConfig(image=train_image, label=train_label),
                valid=ImageLabelPathConfig(image=valid_image, label=valid_label),
                test=ImageLabelPathConfig(image=test_image, label=test_label),
            ),
            "id_mapping": id_mapping,
        }
        self.data = DATA_CONFIG_TYPE[self.task](**common_config)
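    # Usage sketch for set_dataset_config(): the image/label arguments are resolved relative
    # to `root_path`. The dataset name, paths, and id_mapping values below are illustrative
    # assumptions, not values shipped with the library.
    #
    #     trainer.set_dataset_config(
    #         name="my_dataset",
    #         root_path="./datasets/my_dataset",
    #         train_image="images/train",
    #         train_label="labels/train",
    #         valid_image="images/valid",
    #         valid_label="labels/valid",
    #         id_mapping=["cat", "dog"],
    #     )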
    def check_paths_exist(self, base_path):
        paths = [
            "images/train",
            "images/valid",
            "id_mapping.json",
        ]

        # Check for the existence of required directories and files
        for relative_path in paths:
            path = Path(base_path) / relative_path
            if not path.exists():
                if path.suffix:  # It's a file
                    raise FileNotFoundErrorException(relative_path)
                else:  # It's a directory
                    raise DirectoryNotFoundException(relative_path)

    def find_paths(self, base_path: str, search_dir, split: str) -> str:
        base_dir = Path(base_path)
        if not base_dir.exists():
            raise BaseDirectoryNotFoundException(base_dir)

        result_paths = []
        dir_path = base_dir / search_dir
        if dir_path.exists() and dir_path.is_dir():
            for item in dir_path.iterdir():
                if (item.is_dir() or item.is_file()) and split in item.name:
                    result_paths.append(item.as_posix())

        # Return the first path under `search_dir` whose name contains the split keyword
        return result_paths[0]

    def set_dataset(self, dataset_root_path: str):
        dataset_name = Path(dataset_root_path).name
        root_path = Path(dataset_root_path).resolve().as_posix()
        self.check_paths_exist(root_path)
        images_train = self.find_paths(root_path, "images", "train")
        images_valid = self.find_paths(root_path, "images", "valid")
        labels_train = self.find_paths(root_path, "labels", "train")
        labels_valid = self.find_paths(root_path, "labels", "valid")
        id_mapping = FileHandler.load_json(f"{root_path}/id_mapping.json")

        self.set_dataset_config(
            name=dataset_name,
            root_path=dataset_root_path,
            train_image=images_train,
            train_label=labels_train,
            valid_image=images_valid,
            valid_label=labels_valid,
            id_mapping=id_mapping,
        )
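    # Directory layout sketch assumed by set_dataset(): check_paths_exist() requires
    # "images/train", "images/valid", and "id_mapping.json" under the root, and find_paths()
    # looks up the matching "labels" splits. A hypothetical root (the name is illustrative):
    #
    #     my_dataset/
    #         id_mapping.json
    #         images/
    #             train/ ...
    #             valid/ ...
    #         labels/
    #             train/ ...
    #             valid/ ...
    #
    #     trainer.set_dataset("./datasets/my_dataset")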
    def set_model_config(
        self,
        model_name: str,
        img_size: int,
        use_pretrained: bool = True,
        load_head: bool = False,
        path: Optional[str] = None,
        fx_model_path: Optional[str] = None,
        optimizer_path: Optional[str] = None,
    ):
        """Set the model configuration for the Trainer.

        Args:
            model_name (str): Name of the model.
            img_size (int): Image size for the model.
            use_pretrained (bool, optional): Whether to use a pre-trained model. Defaults to True.
            load_head (bool, optional): Whether to load the model head. Defaults to False.
            path (str, optional): Path to the model. Defaults to None.
            fx_model_path (str, optional): Path to the FX model. Defaults to None.
            optimizer_path (str, optional): Path to the optimizer. Defaults to None.

        Raises:
            NotSupportedModelException: If the specified model is not supported for the current task.
        """
        if model_name in self.deprecated_names:
            warnings.filterwarnings("default", category=DeprecationWarning)
            warnings.warn(
                f"The model name '{model_name}' is deprecated and will be removed in a future version. "
                f"Please use '{self.deprecated_names[model_name]}' instead.",
                DeprecationWarning,
                stacklevel=2,
            )
            warnings.filterwarnings("ignore", category=DeprecationWarning)

        self.model_name = model_name
        model = self._get_available_models_w_deprecated_names().get(model_name)
        self.img_size = img_size
        self.logging.sample_input_size = [img_size, img_size]

        if model is None:
            raise NotSupportedModelException()

        self.model = model(
            checkpoint=CheckpointConfig(
                use_pretrained=use_pretrained,
                load_head=load_head,
                path=path,
                fx_model_path=fx_model_path,
                optimizer_path=optimizer_path,
            )
        )
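    # Usage sketch for set_model_config(): `model_name` must be one of
    # `trainer.available_models` for the current task. The model name and image size
    # below are illustrative assumptions.
    #
    #     trainer.set_model_config(model_name="YOLOX-S", img_size=640)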
    def set_fx_model(self, fx_model_path: str):
        """Set the FX model path for retraining.

        Args:
            fx_model_path (str): The path to the FX model.

        Raises:
            RetrainingFunctionException: If the model is not set. Please use 'set_model_config' for model setup.
        """
        if not self.model:
            raise RetrainingFunctionException()

        self.model.checkpoint.path = None
        self.model.checkpoint.fx_model_path = fx_model_path
    def set_training_config(
        self,
        optimizer,
        scheduler,
        epochs: int = 3,
        batch_size: int = 8,
    ):
        """Set the training configuration.

        Args:
            optimizer: The configuration of the optimizer.
            scheduler: The configuration of the learning rate scheduler.
            epochs (int, optional): The total number of epochs for training the model. Defaults to 3.
            batch_size (int, optional): The number of samples in a single batch input. Defaults to 8.
        """
        self.training = ScheduleConfig(
            epochs=epochs,
            optimizer=optimizer.asdict(),
            scheduler=scheduler.asdict(),
        )
        self.environment.batch_size = batch_size
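    # Contract sketch for the optimizer/scheduler arguments: set_training_config() only
    # calls .asdict() on them, and initialize_metadata() later reads the "lr" and "name"
    # keys of the resulting optimizer dict (and "name" of the scheduler dict). A hypothetical
    # stand-in object, not one of the library's bundled config classes, could look like:
    #
    #     from dataclasses import dataclass, asdict
    #
    #     @dataclass
    #     class MyOptimizerConfig:  # illustrative only
    #         name: str = "adamw"
    #         lr: float = 6e-5
    #
    #         def asdict(self):
    #             return asdict(self)
    #
    #     trainer.set_training_config(
    #         optimizer=MyOptimizerConfig(), scheduler=my_scheduler, epochs=10, batch_size=32
    #     )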
    def set_augmentation_config(
        self,
        train_transforms: Optional[List] = None,
        inference_transforms: Optional[List] = None,
    ):
        """Set the augmentation configuration for training.

        Args:
            train_transforms (List, optional): List of transforms for training. Defaults to None.
            inference_transforms (List, optional): List of transforms for inference. Defaults to None.
        """
        self.augmentation = AugmentationConfig(
            train=train_transforms,
            inference=inference_transforms,
        )
    def set_logging_config(
        self,
        project_id: Optional[str] = None,
        output_dir: str = "./outputs",
        tensorboard: bool = True,
        csv: bool = False,
        image: bool = True,
        stdout: bool = True,
        save_optimizer_state: bool = True,
        validation_epoch: int = 10,
        save_checkpoint_epoch: Optional[int] = None,
    ):
        """Set the logging configuration.

        Args:
            project_id (str, optional): Project name to save the experiment. If None, it is set as {task}_{model} (e.g. segmentation_segformer).
            output_dir (str, optional): Root directory for saving the experiment. Defaults to "./outputs".
            tensorboard (bool, optional): Whether to use TensorBoard. Defaults to True.
            csv (bool, optional): Whether to save the results in CSV format. Defaults to False.
            image (bool, optional): Whether to save the validation results. Ignored if the task is classification. Defaults to True.
            stdout (bool, optional): Whether to log to standard output. Defaults to True.
            save_optimizer_state (bool, optional): Whether to save the optimizer state with the model checkpoint so training can be resumed. Defaults to True.
            validation_epoch (int, optional): Validation frequency (in epochs) over the total training process. Defaults to 10.
            save_checkpoint_epoch (int, optional): Checkpoint saving frequency (in epochs) over the total training process. Defaults to None.
        """
        self.logging = LoggingConfig(
            project_id=project_id,
            output_dir=output_dir,
            tensorboard=tensorboard,
            csv=csv,
            image=image,
            stdout=stdout,
            save_optimizer_state=save_optimizer_state,
            validation_epoch=validation_epoch,
            save_checkpoint_epoch=save_checkpoint_epoch,
        )
    def set_environment_config(self, seed: int = 1, num_workers: int = 4):
        """Set the environment configuration.

        Args:
            seed (int, optional): Random seed. Defaults to 1.
            num_workers (int, optional): The number of multi-processing workers to be used by the data loader. Defaults to 4.
        """
        self.environment = EnvironmentConfig(seed=seed, num_workers=num_workers)
    def _change_transforms(self, transforms: List[Transform]):
        """Update the 'size' attribute in the given list of transforms with the specified image size.

        Args:
            transforms (List[Transform]): The list of transforms to be updated.

        Returns:
            List[Transform]: The list of transforms with the 'size' attribute updated to the specified image size.
        """
        field_name_to_check = "size"

        if transforms is None:
            return transforms

        for transform in transforms:
            field_type = transform.__annotations__.get(field_name_to_check)
            if field_type == List:
                transform.size = [self.img_size, self.img_size]
            elif field_type == int:
                transform.size = self.img_size

        return transforms

    def _apply_img_size(self):
        """Apply the specified image size to the augmentation configurations.

        This method updates the 'img_size' attribute in the augmentation configurations,
        including 'train.transforms', 'train.mix_transforms', and 'inference.transforms'.
        """
        self.augmentation.img_size = self.img_size
        self.augmentation.train = self._change_transforms(self.augmentation.train)
        self.augmentation.inference = self._change_transforms(self.augmentation.inference)

    def _get_available_options(self):
        self.token_handler.validate_token()

        options_response = launcher_client_v2.converter.read_framework_options(
            access_token=self.token_handler.tokens.access_token,
            framework=Framework.ONNX,
        )
        available_options = options_response.data

        # TODO: Will be removed when we support DLC in the future
        available_options = [
            available_option for available_option in available_options if available_option.framework != "dlc"
        ]

        return available_options

    def _get_status_by_training_summary(self, status):
        status_mapping = {
            "success": Status.COMPLETED,
            "stop": Status.STOPPED,
            "error": Status.ERROR,
            "": Status.IN_PROGRESS,
        }
        return status_mapping.get(status, Status.IN_PROGRESS)

    def initialize_metadata(self, output_dir):
        def create_metadata_with_status(status, error_message=None):
            metadata = TrainerMetadata()
            metadata.status = status
            if error_message:
                logger.error(error_message)
            return metadata

        try:
            metadata = TrainerMetadata()
        except Exception as e:
            error_message = f"An unexpected error occurred during metadata initialization: {e}"
            metadata = create_metadata_with_status(Status.ERROR, error_message)
        except KeyboardInterrupt:
            warning_message = "Training task was interrupted by the user."
            metadata = create_metadata_with_status(Status.STOPPED, warning_message)
        finally:
            metadata.update_output_dir(output_dir.resolve().as_posix())
            metadata.update_model_info(
                task=self.task,
                model=self.model_name,
                dataset=self.data.name,
                input_shapes=[InputShape(batch=1, channel=3, dimension=[self.img_size, self.img_size])],
            )
            metadata.update_training_info(
                epochs=self.training.epochs,
                batch_size=self.environment.batch_size,
                learning_rate=self.training.optimizer["lr"],
                optimizer=Optimizer.to_display_name(self.training.optimizer["name"]),
                scheduler=Scheduler.to_display_name(self.training.scheduler["name"]),
            )
            metadata.update_hparams(hparams=(output_dir / "hparams.yaml").resolve().as_posix())
            MetadataHandler.save_metadata(data=metadata, folder_path=output_dir)

        return metadata

    def find_best_model_paths(self, destination_folder: Path):
        best_fx_paths_set = set()
        for pattern in ["*best_fx.pt", "*best.pt"]:
            best_fx_paths_set.update(destination_folder.glob(pattern))
        best_fx_paths = list(best_fx_paths_set)
        best_onnx_paths = list(destination_folder.glob("*best.onnx"))

        return best_fx_paths, best_onnx_paths

    def create_runtime_config(self, yaml_path):
        hparams = OmegaConf.load(yaml_path)

        preprocess = hparams.augmentation.inference
        for _preprocess in preprocess:
            if hasattr(_preprocess, "size") and _preprocess.size:
                _preprocess.size = _preprocess.size[0]
            if _preprocess.name == "resize":
                _preprocess.resize_criteria = "long"

        if hparams.model.task == Task.IMAGE_CLASSIFICATION:
            visualize = {"params": {"class_map": hparams.data.id_mapping, "pallete": None}}
        elif hparams.model.task == Task.OBJECT_DETECTION:
            visualize = {"params": {"class_map": hparams.data.id_mapping, "normalized": False, "brightness_factor": 1.5}}
        elif hparams.model.task == Task.SEMANTIC_SEGMENTATION:
            visualize = {"params": {"class_map": hparams.data.id_mapping, "pallete": None, "normalized": False, "brightness_factor": 1.5}}

        _config = {
            "task": hparams.model.task,
            "preprocess": preprocess,
            "postprocess": hparams.model.postprocessor,
            "visualize": visualize,
        }

        name = "runtime"
        config = OmegaConf.create({name: _config})
        save_path = Path(yaml_path).parent / f"{name}.yaml"
        OmegaConf.save(config=config, f=save_path)

        return save_path
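    # Sketch of the runtime.yaml written by create_runtime_config(): the task, the
    # size-normalized inference preprocess list, the model postprocessor, and a
    # task-specific "visualize" section are nested under a top-level "runtime" key.
    # The concrete values below are illustrative for a detection model:
    #
    #     runtime:
    #       task: detection
    #       preprocess: [...]
    #       postprocess: {...}
    #       visualize:
    #         params:
    #           class_map: {...}
    #           normalized: false
    #           brightness_factor: 1.5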
    def train(self, gpus: str, project_name: str, output_dir: Optional[str] = "./outputs") -> TrainerMetadata:
        """Train the model with the specified configuration.

        Args:
            gpus (str): GPU ids to use, separated by commas.
            project_name (str): Project name to save the experiment.
            output_dir (str, optional): Root directory for saving the experiment. Defaults to "./outputs".

        Returns:
            TrainerMetadata: Metadata describing the training run.
        """
        from netspresso_trainer import train_with_yaml

        self._validate_config()
        self._apply_img_size()

        project_name = project_name if project_name else f"{self.task}_{self.model_name}".lower()
        destination_folder = Path(output_dir) / project_name
        destination_folder = FileHandler.create_unique_folder(folder_path=destination_folder)
        metadata = self.initialize_metadata(output_dir=destination_folder)

        try:
            self.logging.output_dir = output_dir
            self.logging.project_id = destination_folder.name
            self.logging_dir = Path(self.logging.output_dir) / self.logging.project_id / "version_0"
            self.environment.gpus = gpus

            configs = TrainerConfigs(
                self.data,
                self.augmentation,
                self.model,
                self.training,
                self.logging,
                self.environment,
            )
            train_with_yaml(
                gpus=gpus,
                data=configs.data,
                augmentation=configs.augmentation,
                model=configs.model,
                training=configs.training,
                logging=configs.logging,
                environment=configs.environment,
            )

            available_options = self._get_available_options()
            metadata.update_available_options(available_options)
        except Exception as e:
            e = FailedTrainingException(error_log=e.args[0])
            metadata = self.handle_error(metadata, ServiceTask.TRAINING, e.args[0])
        except KeyboardInterrupt:
            metadata = self.handle_stop(metadata, ServiceTask.TRAINING)
        finally:
            FileHandler.remove_folder(configs.temp_folder)
            logger.info(f"Removed {configs.temp_folder} folder.")

            FileHandler.move_and_cleanup_folders(source_folder=self.logging_dir, destination_folder=destination_folder)
            logger.info(f"Files in {self.logging_dir} were moved to {destination_folder}.")

            runtime_config_path = self.create_runtime_config(yaml_path=destination_folder / "hparams.yaml")
            metadata.runtime = runtime_config_path.as_posix()

            training_summary = FileHandler.load_json(file_path=destination_folder / "training_summary.json")
            metadata.update_training_result(training_summary=training_summary)
            status = self._get_status_by_training_summary(training_summary.get("status"))
            metadata.update_status(status=status)

            if status == Status.ERROR:
                error_stats = training_summary.get("error_stats", "")
                e = FailedTrainingException(error_log=error_stats)
                metadata.update_message(exception_detail=e.args[0])

            best_fx_paths, best_onnx_paths = self.find_best_model_paths(destination_folder)
            if best_fx_paths:
                metadata.update_best_fx_model_path(best_fx_model_path=best_fx_paths[0].resolve().as_posix())
            if best_onnx_paths:
                metadata.update_best_onnx_model_path(best_onnx_model_path=best_onnx_paths[0].resolve().as_posix())
            MetadataHandler.save_metadata(data=metadata, folder_path=destination_folder)

        return metadata
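# End-to-end usage sketch. The token handler, dataset path, model name, and config objects
# below are illustrative assumptions; only the method calls mirror the class above.
#
#     trainer = Trainer(token_handler=token_handler, task="detection")
#     trainer.set_dataset("./datasets/my_dataset")
#     trainer.set_model_config(model_name="YOLOX-S", img_size=640)
#     trainer.set_training_config(optimizer=my_optimizer, scheduler=my_scheduler, epochs=10, batch_size=32)
#     metadata = trainer.train(gpus="0", project_name="my_project")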