import os, subprocess, sys, textwrap, glob
def sh(cmd: str):
print("\n$", cmd)
subprocess.check_call(cmd, shell=True)
# ========= 1) Clone + install =========
os.chdir("/content")
if os.path.exists("LMIM"):
sh("rm -rf LMIM")
sh("git clone https://github.com/zhangyifei01/LMIM.git")
os.chdir("/content/LMIM")
sh("pip -q install -r requirements.txt")
sh("pip -q install lmdb timm gdown")
# ========= 2) Set your LMDB paths =========
# ✅ 여기만 너 데이터 경로로 바꿔!
TRAIN_LMDB = "/content/your_train_lmdb" # e.g. "/content/drive/MyDrive/str_lmdb/train"
EVAL_LMDB = "/content/your_eval_lmdb" # e.g. "/content/drive/MyDrive/str_lmdb/IIIT5k_3000"
OUTDIR = "/content/lmim_out"
os.makedirs(OUTDIR, exist_ok=True)
assert os.path.exists(TRAIN_LMDB), f"TRAIN_LMDB not found: {TRAIN_LMDB}"
assert os.path.exists(EVAL_LMDB), f"EVAL_LMDB not found: {EVAL_LMDB}"
print("OK dataset paths")
# ========= 3) Train (1 epoch sanity) =========
os.chdir("/content/LMIM/lmim_finetune")
train_cmd = f"""
OMP_NUM_THREADS=1 torchrun --standalone --nproc_per_node=1 run_class_finetuning.py \
--model simmim_vit_small_patch4_32x128 \
--data_path "{TRAIN_LMDB}" \
--eval_data_path "{EVAL_LMDB}" \
--output_dir "{OUTDIR}" \
--batch_size 64 \
--opt adamw \
--opt_betas 0.9 0.999 \
--weight_decay 0.05 \
--data_set image_lmdb \
--nb_classes 97 \
--smoothing 0. \
--max_len 25 \
--epochs 1 \
--warmup_epochs 0 \
--drop 0.1 \
--attn_drop_rate 0.1 \
--num_samples 20000 \
--fixed_encoder_layers 0 \
--decoder_name tf_decoder \
--beam_width 0
"""
sh(textwrap.dedent(train_cmd).strip())
# ========= 4) Find checkpoint =========
ckpts = sorted(glob.glob(os.path.join(OUTDIR, "checkpoint-*.pth")), reverse=True)
if not ckpts:
raise RuntimeError(f"No checkpoint found in {OUTDIR}. Training may have failed.")
CKPT = ckpts[0]
print("Using checkpoint:", CKPT)
# ========= 5) Eval (prints accuracy in logs) =========
eval_cmd = f"""
OMP_NUM_THREADS=1 torchrun --standalone --nproc_per_node=1 run_class_finetuning.py \
--model simmim_vit_small_patch4_32x128 \
--data_path "{EVAL_LMDB}" \
--eval_data_path "{EVAL_LMDB}" \
--output_dir "{OUTDIR}/eval" \
--batch_size 128 \
--opt adamw \
--opt_betas 0.9 0.999 \
--weight_decay 0.05 \
--data_set image_lmdb \
--nb_classes 97 \
--smoothing 0. \
--max_len 25 \
--resume "{CKPT}" \
--eval \
--dist_eval \
--num_samples 1000000 \
--fixed_encoder_layers 0 \
--decoder_name tf_decoder \
--beam_width 0
"""
sh(textwrap.dedent(eval_cmd).strip())
print("\nDONE. 위 eval 로그에 accuracy/CER가 출력돼.")import csv
from difflib import SequenceMatcher
# 문자열 유사도 계산
def similarity(a, b):
return SequenceMatcher(None, a, b).ratio()
# CER 계산 (간단 버전)
def cer(gt, pred):
import numpy as np
dp = np.zeros((len(gt)+1, len(pred)+1))
for i in range(len(gt)+1):
dp[i][0] = i
for j in range(len(pred)+1):
dp[0][j] = j
for i in range(1, len(gt)+1):
for j in range(1, len(pred)+1):
cost = 0 if gt[i-1] == pred[j-1] else 1
dp[i][j] = min(
dp[i-1][j] + 1,
dp[i][j-1] + 1,
dp[i-1][j-1] + cost
)
return dp[len(gt)][len(pred)] / len(gt)
# 결과 파일 (LMIM inference 결과 저장된 csv)
input_file = "result.csv"
output_file = "fail_cases.csv"
fail_cases = []
with open(input_file, newline='', encoding='utf-8') as f:
reader = csv.DictReader(f)
for row in reader:
gt = row["gt"]
pred = row["pred"]
error = cer(gt, pred)
# 실패 기준 (CER > 0.3)
if error > 0.3:
fail_cases.append({
"img": row["img"],
"gt": gt,
"pred": pred,
"cer": round(error, 3)
})
# CER 기준으로 정렬
fail_cases = sorted(fail_cases, key=lambda x: -x["cer"])
with open(output_file, "w", newline='', encoding='utf-8') as f:
writer = csv.DictWriter(f, fieldnames=["img", "gt", "pred", "cer"])
writer.writeheader()
writer.writerows(fail_cases)
print(f"총 실패 케이스: {len(fail_cases)}개 저장 완료")https://github.com/zhangyifei01/LMIM
👉 목표: LMIM 결과로 오답(실패 케이스) 자동 추출
⸻
🔥 1. LMIM inference 실행 (기본 커맨드)
👉 LMIM repo 안에서 실행
python lmim_finetune/run_class_finetuning.py \
--eval \
--resume ./checkpoint.pth \
--data_path ./your_data_lmdb \
--eval_data_path ./your_data_lmdb \
--batch_size 32 \
--data_set image_lmdb \
--model vit_base_patch16 \
--nb_classes 36
👉 결과:
• 콘솔에 accuracy 나옴
• (근데 우리는 오답을 뽑아야 함)
⸻
🔥 2. 🔑 핵심: 오답 자동 추출 코드
👉 이게 너 논문의 시작임
📄 extract_fail_cases.py
import csv
from difflib import SequenceMatcher
# 문자열 유사도 계산
def similarity(a, b):
return SequenceMatcher(None, a, b).ratio()
# CER 계산 (간단 버전)
def cer(gt, pred):
import numpy as np
dp = np.zeros((len(gt)+1, len(pred)+1))
for i in range(len(gt)+1):
dp[i][0] = i
for j in range(len(pred)+1):
dp[0][j] = j
for i in range(1, len(gt)+1):
for j in range(1, len(pred)+1):
cost = 0 if gt[i-1] == pred[j-1] else 1
dp[i][j] = min(
dp[i-1][j] + 1,
dp[i][j-1] + 1,
dp[i-1][j-1] + cost
)
return dp[len(gt)][len(pred)] / len(gt)
# 결과 파일 (LMIM inference 결과 저장된 csv)
input_file = "result.csv"
output_file = "fail_cases.csv"
fail_cases = []
with open(input_file, newline='', encoding='utf-8') as f:
reader = csv.DictReader(f)
for row in reader:
gt = row["gt"]
pred = row["pred"]
error = cer(gt, pred)
# 실패 기준 (CER > 0.3)
if error > 0.3:
fail_cases.append({
"img": row["img"],
"gt": gt,
"pred": pred,
"cer": round(error, 3)
})
# CER 기준으로 정렬
fail_cases = sorted(fail_cases, key=lambda x: -x["cer"])
with open(output_file, "w", newline='', encoding='utf-8') as f:
writer = csv.DictWriter(f, fieldnames=["img", "gt", "pred", "cer"])
writer.writeheader()
writer.writerows(fail_cases)
print(f"총 실패 케이스: {len(fail_cases)}개 저장 완료")
⸻
🔥 3. LMIM 결과를 CSV로 저장하는 방법 (필수)
👉 LMIM 기본 코드에는 “pred 저장”이 없음 → 직접 추가해야 함
📄 수정 위치
run_class_finetuning.py
👉 prediction 부분에 추가:
# 예측 결과 저장용
results = []
# inference loop 안에서
for i, (images, target) in enumerate(data_loader):
output = model(images)
pred = output.argmax(dim=1)
for j in range(len(pred)):
results.append({
"img": "unknown", # 필요하면 path 추가
"gt": str(target[j].item()),
"pred": str(pred[j].item())
})
# 끝나고 csv 저장
import csv
with open("result.csv", "w", newline='', encoding='utf-8') as f:
writer = csv.DictWriter(f, fieldnames=["img", "gt", "pred"])
writer.writeheader()
writer.writerows(results)
⸻
🚀 4. 전체 흐름 (이대로 하면 됨)
LMIM eval 실행
→ result.csv 생성
→ extract_fail_cases.py 실행
→ fail_cases.csv 생성
→ 논문 쓸 재료 확보
⸻
🎯 5. 이걸로 뭘 할 수 있냐
👉 fail_cases.csv 보면 바로 보임:
• 손글씨에서 틀림 많음
• 숫자/금액 틀림
• 특정 글자 (0/O, 1/I) 헷갈림
👉 이게 그대로 논문:
"LMIM은 hand-written financial text에서 CER가 높다"
→ "우리는 이를 개선한다"
⸻
🔥 한 줄 정리
👉
LMIM 돌린다 → 결과 저장 → CER로 오답 뽑는다
→ 그게 논문이다
⸻
원하면 다음 단계로
👉 LMDB 없이 바로 돌리는 코드로 바꿔주기
👉 손글씨 데이터셋 추천 + 다운로드
까지 바로 이어서 도와줄게 🔥
'개인 프로젝트 > 논문리뷰' 카테고리의 다른 글
| LMIM +IIIT5K 논문 (0) | 2026.03.02 |
|---|---|
| Linguistics-aware_Masked_Image_Modeling_for_Self-supervised_Scene_Text_Recognition_CVPR_2025_paper (0) | 2026.03.02 |
| LMIM 논문방향 (0) | 2026.02.25 |
| LMIM 논문 개요 정리 (0) | 2026.02.24 |
| MIM 코드 (0) | 2026.02.06 |
댓글