Add language-specific PP-OCRv5 recognition models to the OCR pipeline.

A model name may carry a "<LANG>__" prefix (for example "EN__PP-OCRv5_mobile").
The text detector strips the prefix before passing the name to its base class.
The text recognizer looks the prefix up as a LangRec member and uses LangRec.CH
when the name has no prefix. OcrOptions accepts an optional lang_type and keeps
LangRec.CH as the value used when none is supplied.

diff --git a/machine-learning/immich_ml/models/constants.py b/machine-learning/immich_ml/models/constants.py
index 10a4ae48a9..db9e7cfa4d 100644
--- a/machine-learning/immich_ml/models/constants.py
+++ b/machine-learning/immich_ml/models/constants.py
@@ -78,6 +78,14 @@ _INSIGHTFACE_MODELS = {
 _PADDLE_MODELS = {
     "PP-OCRv5_server",
     "PP-OCRv5_mobile",
+    "CH__PP-OCRv5_server",
+    "CH__PP-OCRv5_mobile",
+    "EL__PP-OCRv5_mobile",
+    "EN__PP-OCRv5_mobile",
+    "ESLAV__PP-OCRv5_mobile",
+    "KOREAN__PP-OCRv5_mobile",
+    "LATIN__PP-OCRv5_mobile",
+    "TH__PP-OCRv5_mobile",
 }
 
 SUPPORTED_PROVIDERS = [
diff --git a/machine-learning/immich_ml/models/ocr/detection.py b/machine-learning/immich_ml/models/ocr/detection.py
index 0a9d09b599..07a2f3cce2 100644
--- a/machine-learning/immich_ml/models/ocr/detection.py
+++ b/machine-learning/immich_ml/models/ocr/detection.py
@@ -23,7 +23,7 @@ class TextDetector(InferenceModel):
     identity = (ModelType.DETECTION, ModelTask.OCR)
 
     def __init__(self, model_name: str, **model_kwargs: Any) -> None:
-        super().__init__(model_name, **model_kwargs, model_format=ModelFormat.ONNX)
+        super().__init__(model_name.split("__")[-1], **model_kwargs, model_format=ModelFormat.ONNX)
         self.max_resolution = 736
         self.mean = np.array([0.5, 0.5, 0.5], dtype=np.float32)
         self.std_inv = np.float32(1.0) / (np.array([0.5, 0.5, 0.5], dtype=np.float32) * 255.0)
diff --git a/machine-learning/immich_ml/models/ocr/recognition.py b/machine-learning/immich_ml/models/ocr/recognition.py
index 0f91fc4105..af3f99dbdb 100644
--- a/machine-learning/immich_ml/models/ocr/recognition.py
+++ b/machine-learning/immich_ml/models/ocr/recognition.py
@@ -25,6 +25,7 @@ class TextRecognizer(InferenceModel):
     identity = (ModelType.RECOGNITION, ModelTask.OCR)
 
     def __init__(self, model_name: str, **model_kwargs: Any) -> None:
+        self.language = LangRec[model_name.split("__")[0]] if "__" in model_name else LangRec.CH
         self.min_score = model_kwargs.get("minScore", 0.9)
         self._empty: TextRecognitionOutput = {
             "box": np.empty(0, dtype=np.float32),
@@ -41,7 +42,7 @@ class TextRecognizer(InferenceModel):
                 engine_type=EngineType.ONNXRUNTIME,
                 ocr_version=OCRVersion.PPOCRV5,
                 task_type=TaskType.REC,
-                lang_type=LangRec.CH,
+                lang_type=self.language,
                 model_type=RapidModelType.MOBILE if "mobile" in self.model_name else RapidModelType.SERVER,
             )
         )
@@ -61,6 +62,7 @@ class TextRecognizer(InferenceModel):
                 session=session.session,
                 rec_batch_num=settings.max_batch_size.text_recognition if settings.max_batch_size is not None else 6,
                 rec_img_shape=(3, 48, 320),
+                lang_type=self.language,
             )
         )
         return session
diff --git a/machine-learning/immich_ml/models/ocr/schemas.py b/machine-learning/immich_ml/models/ocr/schemas.py
index a63c8dd8e5..78e8619a0b 100644
--- a/machine-learning/immich_ml/models/ocr/schemas.py
+++ b/machine-learning/immich_ml/models/ocr/schemas.py
@@ -20,8 +20,9 @@ class TextRecognitionOutput(TypedDict):
 
 # RapidOCR expects `engine_type`, `lang_type`, and `font_path` to be attributes
 class OcrOptions(dict[str, Any]):
-    def __init__(self, **options: Any) -> None:
+    def __init__(self, lang_type: LangRec | None = None, **options: Any) -> None:
         super().__init__(**options)
         self.engine_type = EngineType.ONNXRUNTIME
-        self.lang_type = LangRec.CH
+        # Keep LangRec.CH as the default so callers that omit lang_type see the previous value.
+        self.lang_type = lang_type if lang_type is not None else LangRec.CH
         self.font_path = None
diff --git a/web/src/lib/components/admin-settings/MachineLearningSettings.svelte b/web/src/lib/components/admin-settings/MachineLearningSettings.svelte
index 7649ee8d17..e05b5088a4 100644
--- a/web/src/lib/components/admin-settings/MachineLearningSettings.svelte
+++ b/web/src/lib/components/admin-settings/MachineLearningSettings.svelte
@@ -275,8 +275,14 @@
               name="ocr-model"
               bind:value={config.machineLearning.ocr.modelName}
               options={[
-                { value: 'PP-OCRv5_server', text: 'PP-OCRv5_server' },
-                { value: 'PP-OCRv5_mobile', text: 'PP-OCRv5_mobile' },
+                { text: 'PP-OCRv5_server (Chinese, Japanese and English)', value: 'PP-OCRv5_server' },
+                { text: 'PP-OCRv5_mobile (Chinese, Japanese and English)', value: 'PP-OCRv5_mobile' },
+                { text: 'PP-OCRv5_mobile (English-only)', value: 'EN__PP-OCRv5_mobile' },
+                { text: 'PP-OCRv5_mobile (Greek and English)', value: 'EL__PP-OCRv5_mobile' },
+                { text: 'PP-OCRv5_mobile (Korean and English)', value: 'KOREAN__PP-OCRv5_mobile' },
+                { text: 'PP-OCRv5_mobile (Latin script languages)', value: 'LATIN__PP-OCRv5_mobile' },
+                { text: 'PP-OCRv5_mobile (Russian, Belarusian, Ukrainian and English)', value: 'ESLAV__PP-OCRv5_mobile' },
+                { text: 'PP-OCRv5_mobile (Thai and English)', value: 'TH__PP-OCRv5_mobile' },
               ]}
               disabled={disabled || !config.machineLearning.enabled || !config.machineLearning.ocr.enabled}
               isEdited={config.machineLearning.ocr.modelName !== savedConfig.machineLearning.ocr.modelName}