mirror of
https://github.com/TheFunny/ArisuAutoSweeper
synced 2026-06-09 20:04:52 +00:00
Upload code
This commit is contained in:
@@ -0,0 +1,181 @@
|
||||
import re
|
||||
from dataclasses import dataclass
|
||||
from functools import cached_property
|
||||
from typing import ClassVar
|
||||
|
||||
import module.config.server as server
|
||||
from module.exception import ScriptError
|
||||
|
||||
# Characters removed from a name before it is compared: spaces, newlines, tabs,
# and ASCII / CJK punctuation.
# - Every square bracket inside the character class is escaped or written as a
#   \uXXXX escape, so the class cannot be terminated early by a stray `]`.
# - Full-width forms (comma, colon, exclamation mark, question mark, parentheses,
#   square brackets) are spelled as \uXXXX escapes so they survive encoding or
#   Unicode-normalization round-trips of this file.
REGEX_PUNCTUATION = re.compile(
    r'[ ,.\'"“”。:!?·•\-—/\\\n\t()\[\]「」『』【】《》'
    r'\uff0c\uff1a\uff01\uff1f\uff08\uff09\uff3b\uff3d]'
)
|
||||
|
||||
|
||||
def parse_name(n):
    """
    Normalize a name so that it can be compared regardless of punctuation and case.

    Args:
        n: Any object, it is converted with `str()` first.

    Returns:
        str: The text with every `REGEX_PUNCTUATION` match removed, in lowercase.
    """
    text = str(n)
    without_punctuation = REGEX_PUNCTUATION.sub('', text)
    return without_punctuation.lower()
|
||||
|
||||
|
||||
@dataclass
class Keyword:
    """
    A keyword that has one display name per language.

    Creating an instance registers it in the class-level `instances` dict under its `id`,
    so that `find()` can look it up later.
    """
    id: int
    name: str
    cn: str
    en: str
    jp: str
    cht: str
    es: str

    # ------------------------------------------------------------------
    # Instance attributes and methods
    # ------------------------------------------------------------------

    @cached_property
    def ch(self) -> str:
        # Alias of `cn`
        return self.cn

    @cached_property
    def cn_parsed(self) -> str:
        return parse_name(self.cn)

    @cached_property
    def en_parsed(self) -> str:
        return parse_name(self.en)

    @cached_property
    def jp_parsed(self) -> str:
        return parse_name(self.jp)

    @cached_property
    def cht_parsed(self) -> str:
        return parse_name(self.cht)

    @cached_property
    def es_parsed(self) -> str:
        # Must be built from `es`. Building it from `cht` made lookups in `es`
        # match against the `cht` name instead.
        return parse_name(self.es)

    def __str__(self):
        return f'{self.__class__.__name__}({self.name})'

    __repr__ = __str__

    def __eq__(self, other):
        # Compared by string form, so a keyword also equals the string 'ClassName(name)'
        return str(self) == str(other)

    def __hash__(self):
        return hash(self.name)

    def __bool__(self):
        # Always truthy, so `if keyword:` only tests "is not None"
        return True

    def _keywords_to_find(self, lang: str = None, ignore_punctuation=True):
        """
        Args:
            lang: Language to take the name from. If None, use `server.lang`.
            ignore_punctuation: True to return the `*_parsed` names instead of the raw names.

        Returns:
            list[str]: The name in `lang` if `lang` is in `server.VALID_LANG` and this class
                has a name for it, otherwise the names in all languages.
        """
        if lang is None:
            lang = server.lang

        languages = ('cn', 'en', 'jp', 'cht', 'es')
        suffix = '_parsed' if ignore_punctuation else ''
        # A valid lang without a name attribute falls back to all languages
        # instead of returning None, which callers cannot iterate.
        if lang in server.VALID_LANG and lang in languages:
            return [getattr(self, f'{lang}{suffix}')]
        return [getattr(self, f'{language}{suffix}') for language in languages]

    # ------------------------------------------------------------------
    # Class attributes and methods
    #
    # Note that dataclasses inherited `Keyword` must override `instances` attribute,
    # or `instances` will still be a class attribute of base class.
    #
    #     @dataclass
    #     class DungeonNav(Keyword):
    #         instances: ClassVar = {}
    # ------------------------------------------------------------------

    # Key: instance ID. Value: instance object.
    instances: ClassVar = {}

    def __post_init__(self):
        # Register on the concrete class, so subclasses that override `instances`
        # keep their own registry.
        self.__class__.instances[self.id] = self

    @classmethod
    def _compare(cls, name, keyword):
        """
        Comparison used by `find()`, exact match here.
        """
        return name == keyword

    @classmethod
    def find(cls, name, lang: str = None, ignore_punctuation=True):
        """
        Args:
            name: Name in any server or instance id.
            lang: Lang to find from
                None to search the names from current server only.
            ignore_punctuation: True to remove punctuations and turn into lowercase before searching.

        Returns:
            Keyword instance.

        Raises:
            ScriptError: If nothing found.
        """
        # Already a keyword
        if isinstance(name, Keyword):
            return name
        # Probably an ID
        if isinstance(name, int) or (isinstance(name, str) and name.isdigit()):
            try:
                return cls.instances[int(name)]
            except (KeyError, ValueError):
                # KeyError: no such ID.
                # ValueError: str.isdigit() accepts characters like '²' that int() rejects.
                # Either way, fall through and treat it as a name.
                pass
        # Probably a variable name
        if isinstance(name, str) and '_' in name:
            for instance in cls.instances.values():
                if name == instance.name:
                    return instance
        # Probably an in-game name
        if ignore_punctuation:
            name = parse_name(name)
        else:
            name = str(name)
        instance: Keyword
        for instance in cls.instances.values():
            for keyword in instance._keywords_to_find(
                    lang=lang, ignore_punctuation=ignore_punctuation):
                if cls._compare(name, keyword):
                    return instance

        # Not found
        raise ScriptError(f'Cannot find a {cls.__name__} instance that matches "{name}"')
|
||||
@@ -0,0 +1,76 @@
|
||||
from pponnxcr import TextSystem as TextSystem_
|
||||
|
||||
from module.base.decorator import cached_property
|
||||
from module.exception import ScriptError
|
||||
|
||||
# Key: in-game language name. Value: OCR model name.
# Languages that are not listed here are passed through unchanged by lang2model().
# NOTE(review): Traditional Chinese is keyed as 'tw' here, while `Keyword` names the
# field `cht` - confirm which one `server.lang` actually uses.
DIC_LANG_TO_MODEL = dict(
    cn='zhs',
    en='en',
    jp='ja',
    tw='zht',
)
|
||||
|
||||
|
||||
def lang2model(lang: str) -> str:
    """
    Convert an in-game language name to an OCR model name.

    Args:
        lang: In-game language name, defined in VALID_LANG

    Returns:
        str: Model name, defined in pponnxcr.utility.
            `lang` itself if it is not in DIC_LANG_TO_MODEL.
    """
    if lang in DIC_LANG_TO_MODEL:
        return DIC_LANG_TO_MODEL[lang]
    return lang
|
||||
|
||||
|
||||
def model2lang(model: str) -> str:
    """
    Convert an OCR model name back to an in-game language name.

    Args:
        model: Model name, defined in pponnxcr.utility

    Returns:
        str: In-game language name, defined in VALID_LANG.
            The first language in DIC_LANG_TO_MODEL that maps to `model`,
            or `model` itself if none does.
    """
    candidates = (lang for lang, name in DIC_LANG_TO_MODEL.items() if model == name)
    return next(candidates, model)
|
||||
|
||||
|
||||
class TextSystem(TextSystem_):
    """
    `pponnxcr.TextSystem` whose text recognizer has `rec_batch_num` set to 1.
    """

    def __init__(self, *args, **kwargs):
        """
        Args:
            *args: Passed to `pponnxcr.TextSystem` unchanged.
            **kwargs: Passed to `pponnxcr.TextSystem` unchanged.
        """
        super().__init__(*args, **kwargs)
        # presumably the recognition batch size - confirm against pponnxcr
        recognizer = self.text_recognizer
        recognizer.rec_batch_num = 1
|
||||
|
||||
|
||||
class OcrModel:
    """
    Holder of OCR models.

    Each model is a cached property named after the model, so a model is only
    created the first time it is requested.
    """

    def get_by_model(self, model: str) -> TextSystem:
        """
        Args:
            model: Model name, which is the name of an attribute of this class.

        Returns:
            TextSystem:

        Raises:
            ScriptError: If this object has no attribute called `model`.
        """
        try:
            return getattr(self, model)
        except AttributeError:
            raise ScriptError(f'OCR model "{model}" does not exists')

    def get_by_lang(self, lang: str) -> TextSystem:
        """
        Args:
            lang: In-game language name, converted to a model name with lang2model().

        Returns:
            TextSystem:

        Raises:
            ScriptError: If this object has no attribute called like the model of `lang`.
        """
        try:
            return getattr(self, lang2model(lang))
        except AttributeError:
            raise ScriptError(f'OCR model under lang "{lang}" does not exists')

    @cached_property
    def zhs(self):
        return TextSystem('zhs')

    @cached_property
    def en(self):
        return TextSystem('en')

    @cached_property
    def ja(self):
        return TextSystem('ja')

    @cached_property
    def zht(self):
        return TextSystem('zht')
|
||||
|
||||
|
||||
# Module-level OcrModel instance, models are created on first access of their attribute.
OCR_MODEL = OcrModel()
|
||||
@@ -0,0 +1,418 @@
|
||||
import re
|
||||
import time
|
||||
from datetime import timedelta
|
||||
from typing import Optional
|
||||
|
||||
import numpy as np
|
||||
from pponnxcr.predict_system import BoxedResult
|
||||
|
||||
import module.config.server as server
|
||||
from module.base.button import ButtonWrapper
|
||||
from module.base.decorator import cached_property
|
||||
from module.base.utils import area_pad, corner2area, crop, float2str
|
||||
from module.exception import ScriptError
|
||||
from module.logger import logger
|
||||
from module.ocr.keyword import Keyword
|
||||
from module.ocr.models import OCR_MODEL, TextSystem
|
||||
from module.ocr.utils import merge_buttons
|
||||
|
||||
|
||||
class OcrResultButton:
    """
    An OCR result presented like a button, optionally bound to the keyword it matched.
    """

    def __init__(self, boxed_result: BoxedResult, matched_keyword: Optional[Keyword]):
        """
        Args:
            boxed_result: BoxedResult from ppocr-onnx
            matched_keyword: Keyword object or None
        """
        box = boxed_result.box
        self.area = box
        self.search = area_pad(self.area, pad=-20)
        # self.color =
        self.button = box

        # `name` is the keyword's string form when matched, otherwise the raw OCR text
        self.matched_keyword = matched_keyword
        if matched_keyword is None:
            self.name = boxed_result.ocr_text
        else:
            self.name = str(matched_keyword)

        self.text = boxed_result.ocr_text
        self.score = boxed_result.score

    def __str__(self):
        return self.name

    __repr__ = __str__

    def __eq__(self, other):
        # Compared by string form
        return str(self) == str(other)

    def __hash__(self):
        return hash(self.name)

    def __bool__(self):
        # Always truthy
        return True

    @property
    def is_keyword_matched(self) -> bool:
        """
        Returns:
            bool: True if this result is bound to a keyword.
        """
        return self.matched_keyword is not None
|
||||
|
||||
|
||||
class Ocr:
    """
    Run OCR on the area of a button, and optionally match the text against keywords.

    Subclasses customize the behaviour by overriding `pre_process`, `after_process`,
    `format_result` and `filter_detected`.
    """
    # Merge results with box distance <= thres
    merge_thres_x = 0
    merge_thres_y = 0

    def __init__(self, button: ButtonWrapper, lang=None, name=None):
        """
        Args:
            button: Button whose `area` is cropped from the image before OCR.
            lang: If None, use in-game language
            name: If None, use button.name
        """
        if lang is None:
            lang = server.lang
        if name is None:
            name = button.name

        self.button: ButtonWrapper = button
        self.lang: str = lang
        self.name: str = name

    @cached_property
    def model(self) -> TextSystem:
        """
        Returns:
            TextSystem: OCR model of `self.lang`.
        """
        return OCR_MODEL.get_by_lang(self.lang)

    def pre_process(self, image):
        """
        Hook called on every image before it is fed to the OCR model.
        Will be overridden, the base implementation returns the image unchanged.

        Args:
            image (np.ndarray): Shape (height, width, channel)

        Returns:
            np.ndarray: Image to feed to the OCR model.
        """
        return image

    def after_process(self, result):
        """
        Hook called on every OCR text.
        The base implementation collapses any text that starts with 'UID' to 'UID'.

        Args:
            result (str): OCR text

        Returns:
            str:
        """
        if result.startswith('UID'):
            result = 'UID'
        return result

    def format_result(self, result):
        """
        Hook that converts the OCR text to the final result.
        Will be overridden, the base implementation returns the text unchanged.
        """
        return result

    def _log_change(self, attr, func, before):
        """
        Apply `func` to `before` and log the change if there is any.

        Args:
            attr (str): Name of the step, shown in log.
            func (callable): Function that accepts `before` and returns the new value.
            before:

        Returns:
            The return value of `func`.
        """
        after = func(before)
        if after != before:
            logger.attr(f'{self.name} {attr}', f'{before} -> {after}')
        return after

    def ocr_single_line(self, image, direct_ocr=False):
        """
        Args:
            image: Screenshot, or the image to recognize if `direct_ocr` is True.
            direct_ocr: True to ignore `button` attribute and feed the image to OCR model without cropping.

        Returns:
            OCR text after `after_process` and `format_result`.
        """
        # pre process
        start_time = time.time()
        if not direct_ocr:
            image = crop(image, self.button.area)
        image = self.pre_process(image)
        # ocr
        result, _ = self.model.ocr_single_line(image)
        # after process
        result = self._log_change('after', self.after_process, result)
        result = self._log_change('format', self.format_result, result)
        logger.attr(name='%s %ss' % (self.name, float2str(time.time() - start_time)),
                    text=str(result))
        return result

    def ocr_multi_lines(self, image_list):
        """
        Args:
            image_list: List of images, each one is recognized as a single line without cropping.

        Returns:
            list[tuple]: List of (result, score), `result` is the OCR text
                after `after_process` and `format_result`.
        """
        # pre process
        start_time = time.time()
        image_list = [self.pre_process(image) for image in image_list]
        # ocr
        result_list = self.model.ocr_lines(image_list)
        # after process
        result_list = [(self.after_process(result), score) for result, score in result_list]
        result_list = [(self.format_result(result), score) for result, score in result_list]
        logger.attr(name="%s %ss" % (self.name, float2str(time.time() - start_time)),
                    text=str([result for result, _ in result_list]))
        return result_list

    def filter_detected(self, result: BoxedResult) -> bool:
        """
        Hook called on every detected box in `detect_and_ocr`.
        Return False to drop result.
        """
        return True

    def detect_and_ocr(self, image, direct_ocr=False) -> list[BoxedResult]:
        """
        Args:
            image:
            direct_ocr: True to ignore `button` attribute and feed the image to OCR model without cropping.

        Returns:
            list[BoxedResult]: Results that passed `filter_detected`, merged by `merge_buttons`,
                with `after_process` applied to their text.
        """
        # pre process
        start_time = time.time()
        if not direct_ocr:
            image = crop(image, self.button.area)
        image = self.pre_process(image)
        # ocr
        results: list[BoxedResult] = self.model.detect_and_ocr(image)
        # after process
        for result in results:
            if not direct_ocr:
                # Image was cropped at button.area, add its upper-left corner back to the box
                result.box += self.button.area[:2]
            result.box = tuple(corner2area(result.box))

        results = [result for result in results if self.filter_detected(result)]
        results = merge_buttons(results, thres_x=self.merge_thres_x, thres_y=self.merge_thres_y)
        for result in results:
            result.ocr_text = self.after_process(result.ocr_text)

        logger.attr(name='%s %ss' % (self.name, float2str(time.time() - start_time)),
                    text=str([result.ocr_text for result in results]))
        return results

    def _match_result(
            self,
            result: str,
            keyword_classes,
            lang: str = None,
            ignore_punctuation=True,
            ignore_digit=True):
        """
        Args:
            result (str):
            keyword_classes: `Keyword` class or classes inherited `Keyword`, or a list of them.
            lang: Passed to `Keyword.find()`
            ignore_punctuation: Passed to `Keyword.find()`
            ignore_digit: True to never match a text that consists of digits only.

        Returns:
            If matched, return `Keyword` object or objects inherited `Keyword`
            If not match, return None
        """
        if not isinstance(keyword_classes, list):
            keyword_classes = [keyword_classes]

        # Digits will be considered as the index of keyword
        if ignore_digit:
            if result.isdigit():
                return None

        # Try the keyword classes in order, first match wins
        for keyword_class in keyword_classes:
            try:
                matched = keyword_class.find(
                    result,
                    lang=lang,
                    ignore_punctuation=ignore_punctuation
                )
                return matched
            except ScriptError:
                continue

        return None

    def matched_single_line(
            self,
            image,
            keyword_classes,
            lang: str = None,
            ignore_punctuation=True
    ) -> Optional[Keyword]:
        """
        Args:
            image: Image to detect
            keyword_classes: `Keyword` class or classes inherited `Keyword`, or a list of them.
            lang:
            ignore_punctuation:

        Returns:
            Keyword: The matched keyword, or None if the text didn't match known keywords.
        """
        result = self.ocr_single_line(image)

        result = self._match_result(
            result,
            keyword_classes=keyword_classes,
            lang=lang,
            ignore_punctuation=ignore_punctuation,
        )

        logger.attr(name=f'{self.name} matched',
                    text=result)
        return result

    def matched_multi_lines(
            self,
            image_list,
            keyword_classes,
            lang: str = None,
            ignore_punctuation=True
    ) -> list[Keyword]:
        """
        Args:
            image_list:
            keyword_classes: `Keyword` class or classes inherited `Keyword`, or a list of them.
            lang:
            ignore_punctuation:

        Returns:
            List of matched keywords, in the order of `image_list`.
            OCR result which didn't matched known keywords will be dropped.
        """
        results = self.ocr_multi_lines(image_list)

        # ocr_multi_lines() returns (text, score), only the text takes part in matching
        results = [self._match_result(
            text,
            keyword_classes=keyword_classes,
            lang=lang,
            ignore_punctuation=ignore_punctuation,
        ) for text, _ in results]
        # _match_result() returns None if nothing matched
        results = [result for result in results if result is not None]

        logger.attr(name=f'{self.name} matched',
                    text=results)
        return results

    def _product_button(
            self,
            boxed_result: BoxedResult,
            keyword_classes,
            lang: str = None,
            ignore_punctuation=True,
            ignore_digit=True
    ) -> OcrResultButton:
        """
        Convert a BoxedResult to an OcrResultButton, bound to the keyword that its text matched if any.
        """
        matched_keyword = self._match_result(
            boxed_result.ocr_text,
            keyword_classes=keyword_classes,
            lang=lang,
            ignore_punctuation=ignore_punctuation,
            ignore_digit=ignore_digit,
        )
        button = OcrResultButton(boxed_result, matched_keyword)
        return button

    def matched_ocr(self, image, keyword_classes, direct_ocr=False) -> list[OcrResultButton]:
        """
        Args:
            image: Screenshot
            keyword_classes: `Keyword` class or classes inherited `Keyword`, or a list of them.
            direct_ocr: True to ignore `button` attribute and feed the image to OCR model without cropping.

        Returns:
            List of matched OcrResultButton.
            OCR result which didn't matched known keywords will be dropped.
        """
        results = self.detect_and_ocr(image, direct_ocr=direct_ocr)

        results = [self._product_button(result, keyword_classes) for result in results]
        results = [result for result in results if result.is_keyword_matched]

        logger.attr(name=f'{self.name} matched',
                    text=results)
        return results
|
||||
|
||||
|
||||
class Digit(Ocr):
    """
    OCR of a single integer, uses the 'en' model by default.
    """

    def __init__(self, button: ButtonWrapper, lang='en', name=None):
        super().__init__(button, lang=lang, name=name)

    def format_result(self, result) -> int:
        """
        Take the first run of digits in the OCR text.

        Returns:
            int: 0 if the text has no digit.
        """
        result = super().after_process(result)
        logger.attr(name=self.name, text=str(result))

        found = re.search(r'(\d+)', result)
        if found is None:
            logger.warning(f'No digit found in {result}')
            return 0
        return int(found.group(1))
|
||||
|
||||
|
||||
class DigitCounter(Ocr):
    """
    OCR of a counter like `14/15`, uses the 'en' model by default.
    """

    def __init__(self, button: ButtonWrapper, lang='en', name=None):
        super().__init__(button, lang=lang, name=name)

    def format_result(self, result) -> tuple[int, int, int]:
        """
        Do OCR on a counter, such as `14/15`, and returns 14, 1, 15

        Returns:
            tuple[int, int, int]: (current, total - current, total),
                or (0, 0, 0) if the text has no `<digits>/<digits>`.
        """
        result = super().after_process(result)
        logger.attr(name=self.name, text=str(result))

        found = re.search(r'(\d+)/(\d+)', result)
        if found is None:
            logger.warning(f'No digit counter found in {result}')
            return 0, 0, 0

        current, total = (int(part) for part in found.groups())
        # current = min(current, total)
        return current, total - current, total
|
||||
|
||||
|
||||
class Duration(Ocr):
    """
    OCR of a duration like `18d 2h 13m 30s`.
    """

    @classmethod
    def timedelta_regex(cls, lang):
        """
        Args:
            lang (str): 'cn' or 'en'

        Returns:
            re.Pattern: Regex with optional groups `days`, `hours`, `minutes`, `seconds`.

        Raises:
            KeyError: If `lang` is neither 'cn' nor 'en'.
        """
        patterns = {
            'cn': r'^(?P<prefix>.*?)'
                  r'((?P<days>\d{1,2})\s*天\s*)?'
                  r'((?P<hours>\d{1,2})\s*小时\s*)?'
                  r'((?P<minutes>\d{1,2})\s*分钟\s*)?'
                  r'((?P<seconds>\d{1,2})\s*秒)?'
                  r'(?P<suffix>[^天时钟秒]*?)$',
            'en': r'^(?P<prefix>.*?)'
                  r'((?P<days>\d{1,2})\s*d\s*)?'
                  r'((?P<hours>\d{1,2})\s*h\s*)?'
                  r'((?P<minutes>\d{1,2})\s*m\s*)?'
                  r'((?P<seconds>\d{1,2})\s*s)?'
                  r'(?P<suffix>[^dhms]*?)$'
        }
        pattern = patterns[lang]
        return re.compile(pattern)

    def after_process(self, result):
        """
        Strip surrounding dots and commas, and fix a zero hour that was recognized as letter O.
        """
        cleaned = super().after_process(result)
        cleaned = cleaned.strip('.,。,')
        cleaned = cleaned.replace('Oh', '0h')
        cleaned = cleaned.replace('oh', '0h')
        return cleaned

    def format_result(self, result: str) -> timedelta:
        """
        Do OCR on a duration, such as `18d 2h 13m 30s`, `2h`, `13m 30s`, `9s`

        Returns:
            timedelta: Units that are missing in the text count as 0.
                An empty timedelta if the regex didn't match.
        """
        matched = self.timedelta_regex(self.lang).search(result)
        if not matched:
            return timedelta()

        units = ('days', 'hours', 'minutes', 'seconds')
        values = {unit: self._sanitize_number(matched.group(unit)) for unit in units}
        return timedelta(**values)

    @staticmethod
    def _sanitize_number(number) -> int:
        """
        Convert a regex group to int, a group that didn't participate in the match (None) is 0.
        """
        return 0 if number is None else int(number)
|
||||
@@ -0,0 +1,128 @@
|
||||
import itertools
|
||||
|
||||
from pponnxcr.predict_system import BoxedResult
|
||||
|
||||
from module.base.utils import area_in_area, area_offset
|
||||
|
||||
|
||||
def area_cross_area(area1, area2, thres_x=20, thres_y=20):
    """
    Whether two areas overlap, or are no more than the given thresholds apart.

    Args:
        area1: (upper_left_x, upper_left_y, bottom_right_x, bottom_right_y).
        area2: (upper_left_x, upper_left_y, bottom_right_x, bottom_right_y).
        thres_x: Horizontal gap that still counts as crossing.
        thres_y: Vertical gap that still counts as crossing.

    Returns:
        bool:
    """
    # https://www.yiiven.cn/rect-is-intersection.html
    # Compare twice the distance between centers with the sum of the sizes, on each axis.
    left_a, top_a, right_a, bottom_a = area1
    left_b, top_b, right_b, bottom_b = area2

    near_x = abs(right_b + left_b - right_a - left_a) <= right_a - left_a + right_b - left_b + thres_x * 2
    if not near_x:
        return near_x
    return abs(bottom_b + top_b - bottom_a - top_a) <= bottom_a - top_a + bottom_b - top_b + thres_y * 2
|
||||
|
||||
|
||||
def _merge_area(area1, area2):
|
||||
xa1, ya1, xa2, ya2 = area1
|
||||
xb1, yb1, xb2, yb2 = area2
|
||||
return min(xa1, xb1), min(ya1, yb1), max(xa2, xb2), max(ya2, yb2)
|
||||
|
||||
|
||||
def _merge_boxed_result(left: BoxedResult, right: BoxedResult) -> BoxedResult:
    """
    Merge `right` into `left`, `left` is modified in place.

    Args:
        left: Its box grows to contain both boxes, its text gets the text of `right` appended.
        right: Left untouched.

    Returns:
        BoxedResult: `left`
    """
    merged_box = _merge_area(left.box, right.box)
    merged_text = left.ocr_text + right.ocr_text
    left.box = merged_box
    left.ocr_text = merged_text
    return left
|
||||
|
||||
|
||||
def merge_buttons(buttons: list[BoxedResult], thres_x=20, thres_y=20) -> list[BoxedResult]:
    """
    Merge OCR results that are close to each other.

    Args:
        buttons: Results to merge, they are modified in place when merged.
        thres_x: Merge results with horizontal box distance <= `thres_x`
        thres_y: Merge results with vertical box distance <= `thres_y`

    Returns:
        list[BoxedResult]: `buttons` itself if both thresholds are <= 0,
            otherwise the results that were not merged into another one.
    """
    if thres_x <= 0 and thres_y <= 0:
        return buttons

    # Key: box at the time of the call. Value: result object.
    by_box = {item.box: item for item in buttons}
    absorbed = set()
    # combinations() takes a snapshot of the pairs, so each pair carries the original box as key
    # while the result objects carry the boxes that have grown so far.
    for (_, keeper), (donor_box, donor) in itertools.combinations(by_box.items(), 2):
        if not area_cross_area(keeper.box, donor.box, thres_x=thres_x, thres_y=thres_y):
            continue
        _merge_boxed_result(keeper, donor)
        absorbed.add(donor_box)

    return [item for box, item in by_box.items() if box not in absorbed]
|
||||
|
||||
|
||||
# def pair_buttons(
|
||||
# group1: list["OcrResultButton"],
|
||||
# group2: list["OcrResultButton"],
|
||||
# relative_area: tuple[int, int, int, int]
|
||||
# ) -> t.Generator["OcrResultButton", "OcrResultButton"]:
|
||||
# pass
|
||||
|
||||
def pair_buttons(group1, group2, relative_area):
    """
    Pair buttons in group1 with those in group2 in the relative_area.

    Args:
        group1 (list[OcrResultButton]):
        group2 (list[OcrResultButton]):
        relative_area (tuple[int, int, int, int]):
            Area to search in, relative to the upper-left corner of the button in group1.

    Yields:
        OcrResultButton, OcrResultButton: A button in group1, and a button in group2 paired with it.
    """
    for anchor in group1:
        search_area = area_offset(relative_area, offset=anchor.area[:2])
        yield from (
            (anchor, candidate)
            for candidate in group2
            if area_in_area(candidate.area, search_area, threshold=0)
        )
|
||||
|
||||
|
||||
def split_and_pair_buttons(buttons, split_func, relative_area):
    """
    Split buttons into 2 groups, then pair buttons in group1 with those in group2 in the relative_area.

    Args:
        buttons (list[OcrResultButton]):
        split_func (callable):
            A function that accepts an OcrResultButton object returns a bool,
            button that has a True return join group1, False join group2.
        relative_area (tuple[int, int, int, int]):

    Yields:
        OcrResultButton, OcrResultButton:
    """
    anchors = list(filter(split_func, buttons))
    candidates = list(itertools.filterfalse(split_func, buttons))
    yield from pair_buttons(anchors, candidates, relative_area)
|
||||
|
||||
|
||||
def split_and_pair_button_attr(buttons, split_func, relative_area):
    """
    Pair buttons in group1 with those in group2 in the relative_area,
    and treat group2 as the BUTTON attribute of group1.

    Args:
        buttons (list[OcrResultButton]):
        split_func (callable):
            A function that accepts an OcrResultButton object returns a bool,
            button that has a True return join group1, False join group2.
        relative_area (tuple[int, int, int, int]):

    Yields:
        OcrResultButton: Button in group1, its `button` attribute is replaced
            with the one of the paired button in group2.
    """
    pairs = split_and_pair_buttons(buttons, split_func, relative_area)
    for owner, target in pairs:
        owner.button = target.button
        yield owner
|
||||
Reference in New Issue
Block a user