본문 바로가기
개인 프로젝트/논문리뷰

LMIM 코드

by 응_비 2026. 2. 25.
import os, subprocess, sys, textwrap, glob

def sh(cmd: str):
    print("\n$", cmd)
    subprocess.check_call(cmd, shell=True)

# ========= 1) Clone + install =========
os.chdir("/content")
if os.path.exists("LMIM"):
    sh("rm -rf LMIM")
sh("git clone https://github.com/zhangyifei01/LMIM.git")
os.chdir("/content/LMIM")

sh("pip -q install -r requirements.txt")
sh("pip -q install lmdb timm gdown")

# ========= 2) Set your LMDB paths =========
# ✅ 여기만 너 데이터 경로로 바꿔!
TRAIN_LMDB = "/content/your_train_lmdb"   # e.g. "/content/drive/MyDrive/str_lmdb/train"
EVAL_LMDB  = "/content/your_eval_lmdb"    # e.g. "/content/drive/MyDrive/str_lmdb/IIIT5k_3000"

OUTDIR = "/content/lmim_out"
os.makedirs(OUTDIR, exist_ok=True)

assert os.path.exists(TRAIN_LMDB), f"TRAIN_LMDB not found: {TRAIN_LMDB}"
assert os.path.exists(EVAL_LMDB),  f"EVAL_LMDB not found: {EVAL_LMDB}"
print("OK dataset paths")

# ========= 3) Train (1 epoch sanity) =========
os.chdir("/content/LMIM/lmim_finetune")

train_cmd = f"""
OMP_NUM_THREADS=1 torchrun --standalone --nproc_per_node=1 run_class_finetuning.py \
  --model simmim_vit_small_patch4_32x128 \
  --data_path "{TRAIN_LMDB}" \
  --eval_data_path "{EVAL_LMDB}" \
  --output_dir "{OUTDIR}" \
  --batch_size 64 \
  --opt adamw \
  --opt_betas 0.9 0.999 \
  --weight_decay 0.05 \
  --data_set image_lmdb \
  --nb_classes 97 \
  --smoothing 0. \
  --max_len 25 \
  --epochs 1 \
  --warmup_epochs 0 \
  --drop 0.1 \
  --attn_drop_rate 0.1 \
  --num_samples 20000 \
  --fixed_encoder_layers 0 \
  --decoder_name tf_decoder \
  --beam_width 0
"""
sh(textwrap.dedent(train_cmd).strip())

# ========= 4) Find checkpoint =========
ckpts = sorted(glob.glob(os.path.join(OUTDIR, "checkpoint-*.pth")), reverse=True)
if not ckpts:
    raise RuntimeError(f"No checkpoint found in {OUTDIR}. Training may have failed.")
CKPT = ckpts[0]
print("Using checkpoint:", CKPT)

# ========= 5) Eval (prints accuracy in logs) =========
eval_cmd = f"""
OMP_NUM_THREADS=1 torchrun --standalone --nproc_per_node=1 run_class_finetuning.py \
  --model simmim_vit_small_patch4_32x128 \
  --data_path "{EVAL_LMDB}" \
  --eval_data_path "{EVAL_LMDB}" \
  --output_dir "{OUTDIR}/eval" \
  --batch_size 128 \
  --opt adamw \
  --opt_betas 0.9 0.999 \
  --weight_decay 0.05 \
  --data_set image_lmdb \
  --nb_classes 97 \
  --smoothing 0. \
  --max_len 25 \
  --resume "{CKPT}" \
  --eval \
  --dist_eval \
  --num_samples 1000000 \
  --fixed_encoder_layers 0 \
  --decoder_name tf_decoder \
  --beam_width 0
"""
sh(textwrap.dedent(eval_cmd).strip())

print("\nDONE. 위 eval 로그에 accuracy/CER가 출력돼.")

import csv
from difflib import SequenceMatcher

# 문자열 유사도 계산
def similarity(a, b):
    return SequenceMatcher(None, a, b).ratio()

# CER 계산 (간단 버전)
def cer(gt, pred):
    import numpy as np
    dp = np.zeros((len(gt)+1, len(pred)+1))
    for i in range(len(gt)+1):
        dp[i][0] = i
    for j in range(len(pred)+1):
        dp[0][j] = j

    for i in range(1, len(gt)+1):
        for j in range(1, len(pred)+1):
            cost = 0 if gt[i-1] == pred[j-1] else 1
            dp[i][j] = min(
                dp[i-1][j] + 1,
                dp[i][j-1] + 1,
                dp[i-1][j-1] + cost
            )
    return dp[len(gt)][len(pred)] / len(gt)


# 결과 파일 (LMIM inference 결과 저장된 csv)
input_file = "result.csv"
output_file = "fail_cases.csv"

fail_cases = []

with open(input_file, newline='', encoding='utf-8') as f:
    reader = csv.DictReader(f)
    for row in reader:
        gt = row["gt"]
        pred = row["pred"]

        error = cer(gt, pred)

        # 실패 기준 (CER > 0.3)
        if error > 0.3:
            fail_cases.append({
                "img": row["img"],
                "gt": gt,
                "pred": pred,
                "cer": round(error, 3)
            })

# CER 기준으로 정렬
fail_cases = sorted(fail_cases, key=lambda x: -x["cer"])

with open(output_file, "w", newline='', encoding='utf-8') as f:
    writer = csv.DictWriter(f, fieldnames=["img", "gt", "pred", "cer"])
    writer.writeheader()
    writer.writerows(fail_cases)

print(f"총 실패 케이스: {len(fail_cases)}개 저장 완료")

https://github.com/zhangyifei01/LMIM

👉 목표: LMIM 결과로 오답(실패 케이스) 자동 추출



🔥 1. LMIM inference 실행 (기본 커맨드)

👉 LMIM repo 안에서 실행

python lmim_finetune/run_class_finetuning.py \
  --eval \
  --resume ./checkpoint.pth \
  --data_path ./your_data_lmdb \
  --eval_data_path ./your_data_lmdb \
  --batch_size 32 \
  --data_set image_lmdb \
  --model vit_base_patch16 \
  --nb_classes 36

👉 결과:
• 콘솔에 accuracy 나옴
• (근데 우리는 오답을 뽑아야 함)



🔥 2. 🔑 핵심: 오답 자동 추출 코드

👉 이게 너 논문의 시작임

📄 extract_fail_cases.py

import csv
from difflib import SequenceMatcher

# 문자열 유사도 계산
def similarity(a, b):
    return SequenceMatcher(None, a, b).ratio()

# CER 계산 (간단 버전)
def cer(gt, pred):
    import numpy as np
    dp = np.zeros((len(gt)+1, len(pred)+1))
    for i in range(len(gt)+1):
        dp[i][0] = i
    for j in range(len(pred)+1):
        dp[0][j] = j

    for i in range(1, len(gt)+1):
        for j in range(1, len(pred)+1):
            cost = 0 if gt[i-1] == pred[j-1] else 1
            dp[i][j] = min(
                dp[i-1][j] + 1,
                dp[i][j-1] + 1,
                dp[i-1][j-1] + cost
            )
    return dp[len(gt)][len(pred)] / len(gt)


# 결과 파일 (LMIM inference 결과 저장된 csv)
input_file = "result.csv"
output_file = "fail_cases.csv"

fail_cases = []

with open(input_file, newline='', encoding='utf-8') as f:
    reader = csv.DictReader(f)
    for row in reader:
        gt = row["gt"]
        pred = row["pred"]

        error = cer(gt, pred)

        # 실패 기준 (CER > 0.3)
        if error > 0.3:
            fail_cases.append({
                "img": row["img"],
                "gt": gt,
                "pred": pred,
                "cer": round(error, 3)
            })

# CER 기준으로 정렬
fail_cases = sorted(fail_cases, key=lambda x: -x["cer"])

with open(output_file, "w", newline='', encoding='utf-8') as f:
    writer = csv.DictWriter(f, fieldnames=["img", "gt", "pred", "cer"])
    writer.writeheader()
    writer.writerows(fail_cases)

print(f"총 실패 케이스: {len(fail_cases)}개 저장 완료")




🔥 3. LMIM 결과를 CSV로 저장하는 방법 (필수)

👉 LMIM 기본 코드에는 “pred 저장”이 없음 → 직접 추가해야 함

📄 수정 위치

run_class_finetuning.py

👉 prediction 부분에 추가:

# 예측 결과 저장용
results = []

# inference loop 안에서
for i, (images, target) in enumerate(data_loader):
    output = model(images)
    pred = output.argmax(dim=1)

    for j in range(len(pred)):
        results.append({
            "img": "unknown",  # 필요하면 path 추가
            "gt": str(target[j].item()),
            "pred": str(pred[j].item())
        })

# 끝나고 csv 저장
import csv
with open("result.csv", "w", newline='', encoding='utf-8') as f:
    writer = csv.DictWriter(f, fieldnames=["img", "gt", "pred"])
    writer.writeheader()
    writer.writerows(results)




🚀 4. 전체 흐름 (이대로 하면 됨)

LMIM eval 실행
→ result.csv 생성
→ extract_fail_cases.py 실행
→ fail_cases.csv 생성
→ 논문 쓸 재료 확보




🎯 5. 이걸로 뭘 할 수 있냐

👉 fail_cases.csv 보면 바로 보임:
• 손글씨에서 틀림 많음
• 숫자/금액 틀림
• 특정 글자 (0/O, 1/I) 헷갈림

👉 이게 그대로 논문:

"LMIM은 hand-written financial text에서 CER가 높다"
→ "우리는 이를 개선한다"




🔥 한 줄 정리

👉
LMIM 돌린다 → 결과 저장 → CER로 오답 뽑는다
→ 그게 논문이다



원하면 다음 단계로
👉 LMDB 없이 바로 돌리는 코드로 바꿔주기
👉 손글씨 데이터셋 추천 + 다운로드

까지 바로 이어서 도와줄게 🔥

댓글