from dataclasses import dataclass
from enum import Enum
from typing import Optional
from netspresso.np_qai.options.common import CommonOptions
class Framework(str, Enum):
    """Supported model frameworks."""

    PYTORCH = "pytorch"
    ONNX = "onnx"
    ONNXRUNTIME = "onnxruntime"
    AIMET = "aimet"
    TENSORFLOW = "tensorflow"
    TFLITE = "tensorflow_lite"
    COREML = "coreml"
    TENSORRT = "tensorrt"
    QNN = "qnn"
class Extension(str, Enum):
    """Recognized model file extensions."""

    ONNX = ".onnx"
    PT = ".pt"
    AIMET = ".aimet"
    H5 = ".h5"
class Runtime(str, Enum):
    """Target runtimes for compilation."""

    TFLITE = "tflite"
    QNN_LIB_AARCH64_ANDROID = "qnn_lib_aarch64_android"
    QNN_CONTEXT_BINARY = "qnn_context_binary"
    ONNX = "onnx"
    PRECOMPILED_QNN_ONNX = "precompiled_qnn_onnx"
class QuantizeFullType(str, Enum):
    """Full quantization types; ``wXaY`` denotes X-bit weights with Y-bit activations."""

    INT8 = "int8"
    INT16 = "int16"
    W8A16 = "w8a16"
    W4A8 = "w4a8"
    W4A16 = "w4a16"
class QuantizeWeightType(str, Enum):
    """Weight-only quantization types."""

    FP16 = "float16"
@dataclass
class CompileOptions(CommonOptions):
    """Compile options for the model.

    Note:
        For details, see `CompileOptions in QAI Hub API <https://app.aihub.qualcomm.com/docs/hub/api.html#compile-options>`_.
    """

    target_runtime: Optional[Runtime] = Runtime.TFLITE
    output_names: Optional[str] = None
    truncate_64bit_tensors: Optional[bool] = False
    truncate_64bit_io: Optional[bool] = False
    force_channel_last_input: Optional[str] = None
    force_channel_last_output: Optional[str] = None
    quantize_full_type: Optional[QuantizeFullType] = None
    quantize_weight_type: Optional[QuantizeWeightType] = None
    quantize_io: Optional[bool] = False
    quantize_io_type: Optional[str] = None
    qnn_graph_name: Optional[str] = None
    qnn_context_binary_vtcm: Optional[str] = None
    qnn_context_binary_optimization_level: Optional[int] = None
    def to_cli_string(self) -> str:
        """Render the options as a QAI Hub command-line argument string."""
        args = []
        if self.compute_unit is not None:
            compute_units = ",".join(unit.name.lower() for unit in self.compute_unit)
            args.append(f"--compute_unit {compute_units}")
        if self.target_runtime is not None:
            # Use .value explicitly: formatting a (str, Enum) member in an
            # f-string yields "Runtime.TFLITE" instead of "tflite" on Python 3.11+.
            args.append(f"--target_runtime {self.target_runtime.value}")
        if self.output_names is not None:
            # Normalize whitespace-separated names to a comma-separated list.
            output_names_str = ",".join(self.output_names.split())
            args.append(f'--output_names "{output_names_str}"')
        if self.truncate_64bit_tensors:
            args.append("--truncate_64bit_tensors")
        if self.truncate_64bit_io:
            args.append("--truncate_64bit_io")
        if self.force_channel_last_input is not None:
            args.append(f'--force_channel_last_input "{self.force_channel_last_input}"')
        if self.force_channel_last_output is not None:
            args.append(f'--force_channel_last_output "{self.force_channel_last_output}"')
        if self.quantize_full_type is not None:
            args.append(f"--quantize_full_type {self.quantize_full_type.value}")
        if self.quantize_weight_type is not None:
            args.append(f"--quantize_weight_type {self.quantize_weight_type.value}")
        if self.quantize_io:
            args.append("--quantize_io")
        if self.quantize_io_type is not None:
            args.append(f"--quantize_io_type {self.quantize_io_type}")
        if self.qnn_graph_name is not None:
            args.append(f"--qnn_graph_name {self.qnn_graph_name}")
        if self.qnn_context_binary_vtcm is not None:
            args.append(f"--qnn_context_binary_vtcm {self.qnn_context_binary_vtcm}")
        if self.qnn_context_binary_optimization_level is not None:
            args.append(f"--qnn_context_binary_optimization_level {self.qnn_context_binary_optimization_level}")
        return " ".join(args)