본문 바로가기
개인 프로젝트/논문리뷰

LMIM 적용 코드

by 응_비 2026. 3. 5.

최종

 

Dataset 준비
      ↓
Model 로드
      ↓
Training pipeline 실행
      ↓
Pretraining 정상 수행
      ↓
Loss 감소 확인  ← 지금 여기

 

  • LMIM 코드 실행 성공
  • shape 오류 해결 성공
  • pretraining 정상 진행 확인
  • loss 감소 확인
  • repo 실험 가능 상태 확보

Epoch: [0] [624/625] eta: 0:00:00 lr: 0.000016 loss: 0.0554 (0.3921) time: 1.5275 data: 1.2081 max mem: 3910 [15:10:31.712165] Epoch: [0] Total time: 0:09:03 (0.8699 s / it) [15:10:31.712299] Averaged stats: lr: 0.000016 loss: 0.0554 (0.3921) [15:10:31.714950] Training time 0:09:03

%cd /content
!git clone https://github.com/zhangyifei01/LMIM.git
!ls /content
!ls /content/LMIM
!ls /content/LMIM/lmim_pretrain

from pathlib import Path
import shutil, re

ENGINE_PATH = Path("/content/LMIM/lmim_pretrain/engine_pretrain.py")
BACKUP_PATH = Path("/content/LMIM/lmim_pretrain/engine_pretrain.py.bak_onetime")

assert ENGINE_PATH.exists(), f"engine_pretrain.py not found: {ENGINE_PATH}"

# backup
if not BACKUP_PATH.exists():
    shutil.copy(ENGINE_PATH, BACKUP_PATH)
    print(f"[OK] backup created: {BACKUP_PATH}")
else:
    print(f"[OK] backup already exists: {BACKUP_PATH}")

text = ENGINE_PATH.read_text()

original_exact = """        samples = udata[0].to(device, non_blocking=True)
        samples_aug = udata[1].to(device, non_blocking=True)
"""

patched_block = """        samples = udata[0].to(device, non_blocking=True)

        # HOTFIX:
        # ImageFolder returns (image, label), but LMIM expects (image, augmented_image).
        samples_aug = samples.clone()

        print("[DEBUG] samples.shape =", samples.shape)
        print("[DEBUG] samples_aug.shape =", samples_aug.shape)
"""

patched = False

if original_exact in text:
    text = text.replace(original_exact, patched_block, 1)
    patched = True
    print("[OK] exact patch applied.")
else:
    pattern = r"""
(\s*)samples\s*=\s*udata\[0\]\.to\(device,\s*non_blocking=True\)\s*
\1samples_aug\s*=\s*udata\[1\]\.to\(device,\s*non_blocking=True\)
"""
    replacement = r"""\1samples = udata[0].to(device, non_blocking=True)

\1# HOTFIX:
\1# ImageFolder returns (image, label), but LMIM expects (image, augmented_image).
\1samples_aug = samples.clone()

\1print("[DEBUG] samples.shape =", samples.shape)
\1print("[DEBUG] samples_aug.shape =", samples_aug.shape)"""
    new_text, count = re.subn(pattern, replacement, text, count=1, flags=re.VERBOSE)
    if count == 1:
        text = new_text
        patched = True
        print("[OK] regex patch applied.")

if not patched:
    print("[INFO] patch target not found. It may already be patched.")
else:
    ENGINE_PATH.write_text(text)
    print(f"[OK] saved patched file: {ENGINE_PATH}")

updated = ENGINE_PATH.read_text()
idx = updated.find("samples = udata[0]")
print("\n[PATCHED SNIPPET]")
print(updated[idx:idx+400])
%cd /content/LMIM/lmim_pretrain

!python -u main_pretrain.py \
  --data_path /content/LMIM/data/IIIT5K \
  --output_dir /content/LMIM/outputs/pretrain_test \
  --log_dir /content/LMIM/logs/pretrain_test \
  --model mae_vit_base_patch4 \
  --batch_size 8 \
  --epochs 1 \
  --num_workers 0 \
  --device cuda
from pathlib import Path
import shutil

MAIN_PATH = Path("/content/LMIM/lmim_pretrain/main_pretrain.py")
BACKUP_PATH = Path("/content/LMIM/lmim_pretrain/main_pretrain.py.bak_savefix")

if not BACKUP_PATH.exists():
    shutil.copy(MAIN_PATH, BACKUP_PATH)
    print(f"[OK] backup created: {BACKUP_PATH}")
else:
    print(f"[OK] backup already exists: {BACKUP_PATH}")

text = MAIN_PATH.read_text()

old = """        if args.output_dir and (epoch + 1) % 5 == 0:
            misc.save_model(
                args=args,
                model=model,
                model_without_ddp=model_without_ddp,
                optimizer=optimizer,
                loss_scaler=loss_scaler,
                epoch=epoch)
"""

new = """        if args.output_dir and (((epoch + 1) % 5 == 0) or ((epoch + 1) == args.epochs)):
            misc.save_model(
                args=args,
                model=model,
                model_without_ddp=model_without_ddp,
                optimizer=optimizer,
                loss_scaler=loss_scaler,
                epoch=epoch)
"""

if old in text:
    text = text.replace(old, new, 1)
    MAIN_PATH.write_text(text)
    print("[OK] save condition patched.")
else:
    print("[INFO] exact block not found. Check manually.")

print("\n[CHECK]")
updated = MAIN_PATH.read_text()
idx = updated.find("if args.output_dir")
print(updated[idx:idx+350])
%cd /content/LMIM/lmim_pretrain

!python -u main_pretrain.py \
  --data_path /content/LMIM/data/IIIT5K \
  --output_dir /content/LMIM/outputs/pretrain_test_1ep_save \
  --log_dir /content/LMIM/logs/pretrain_test_1ep_save \
  --model mae_vit_base_patch4 \
  --batch_size 8 \
  --epochs 1 \
  --num_workers 0 \
  --device cuda
!ls -R /content/LMIM/outputs/pretrain_test_1ep_save
!ls -R /content/LMIM/logs/pretrain_test_1ep_save

# =========================
# 0. clean start
# =========================
%cd /content
!rm -rf /content/LMIM

# =========================
# 1. clone repo
# =========================
!git clone https://github.com/zhangyifei01/LMIM.git /content/LMIM
%cd /content/LMIM

# =========================
# 2. mount drive
# =========================
from google.colab import drive
drive.mount('/content/drive')

# =========================
# 3. connect dataset
# =========================
import os, shutil

DATA_DIR = "/content/drive/MyDrive/datasets/IIIT5K"
LINK_PATH = "/content/LMIM/data/IIIT5K"

os.makedirs("/content/LMIM/data", exist_ok=True)

if os.path.islink(LINK_PATH):
    os.unlink(LINK_PATH)
elif os.path.exists(LINK_PATH):
    shutil.rmtree(LINK_PATH)

os.symlink(DATA_DIR, LINK_PATH)

print("linked:", LINK_PATH)
print("DATA_DIR exists:", os.path.exists(DATA_DIR))
print("LINK_PATH exists:", os.path.exists(LINK_PATH))
print("files:", os.listdir(LINK_PATH)[:20])
# =========================
# 4. install dependencies
# =========================
!pip uninstall -y numpy opencv-python opencv-python-headless opencv-contrib-python timm
!pip install -q --no-cache-dir \
  "numpy==1.26.4" \
  "timm==0.4.12" \
  "opencv-python==4.8.1.78" \
  "imgaug==0.4.0" \
  "lmdb" \
  "easydict" \
  "tensorboardX" \
  "pyyaml" \
  "scipy" \
  "tqdm" \
  "nltk"
%cd /content/LMIM

import os, re

# =========================
# 5. patch compatibility
# =========================
root = "/content/LMIM"
patched = []

for current_root, _, files in os.walk(root):
    for fn in files:
        if not fn.endswith(".py"):
            continue

        path = os.path.join(current_root, fn)
        with open(path, "r", encoding="utf-8", errors="ignore") as f:
            text = f.read()

        new_text = text
        new_text = new_text.replace("from torch._six import inf", "from torch import inf")
        new_text = new_text.replace("dtype=np.float)", "dtype=float)")
        new_text = new_text.replace("dtype=np.float,", "dtype=float,")
        new_text = new_text.replace("dtype=np.float\n", "dtype=float\n")
        new_text = re.sub(r",\s*qk_scale\s*=\s*[^,\)\n]+", "", new_text)
        new_text = re.sub(r"qk_scale\s*=\s*[^,\)\n]+,\s*", "", new_text)

        if new_text != text:
            with open(path, "w", encoding="utf-8") as f:
                f.write(new_text)
            patched.append(path)

print("patched files:")
for p in patched:
    print("-", p)

# =========================
# 6. patch main_pretrain.py
# =========================
file_path = "/content/LMIM/lmim_pretrain/main_pretrain.py"

with open(file_path, "r", encoding="utf-8") as f:
    text = f.read()

text = text.replace(
    "def main(args):\n    if isinstance(args.data_path, list) and len(args.data_path) == 1:\n        args.data_path = args.data_path[0]\n",
    "def main(args):\n"
)

marker = "def main(args):"
inject = """def main(args):
    if isinstance(args.data_path, list) and len(args.data_path) == 1:
        args.data_path = args.data_path[0]
"""

if marker in text and "if isinstance(args.data_path, list) and len(args.data_path) == 1:" not in text:
    text = text.replace(marker, inject, 1)

with open(file_path, "w", encoding="utf-8") as f:
    f.write(text)

print("patched:", file_path)

%cd /content
!rm -rf LMIM
!git clone https://github.com/ayumiymk/LMIM.git
!ls /content/LMIM
!ls /content/LMIM/data
%cd /content
!git clone https://github.com/ayumiymk/LMIM.git
%cd /content
!rm -rf /content/LMIM
!git clone https://github.com/zhangyifei01/LMIM.git /content/LMIM
!ls /content/LMIM
from google.colab import drive
drive.mount('/content/drive')
DATA_DIR = "/content/drive/MyDrive/datasets/IIIT5K"
import os, shutil

DATA_DIR = "/content/drive/MyDrive/datasets/IIIT5K"
LINK_PATH = "/content/LMIM/data/IIIT5K"

os.makedirs("/content/LMIM/data", exist_ok=True)

if os.path.islink(LINK_PATH):
    os.unlink(LINK_PATH)
elif os.path.exists(LINK_PATH):
    shutil.rmtree(LINK_PATH)

os.symlink(DATA_DIR, LINK_PATH)

print("linked:", LINK_PATH)
print("check:", os.listdir(LINK_PATH))
%cd /content/LMIM

!pip install -q timm
!pip install -q opencv-python
!pip install -q lmdb
!pip install -q easydict
!pip install -q scipy
!pip install -q pyyaml
!pip install -q tqdm
!pip install -q nltk
!python lmim_pretrain/main_pretrain.py \
--batch_size 8 \
--epochs 1 \
--model mae_vit_base_patch4 \
--data_path /content/LMIM/data/IIIT5K \
--output_dir /content/LMIM/outputs/pretrain_test \
--log_dir /content/LMIM/logs/pretrain_test
import os

root = "/content/LMIM"
patched = []

for current_root, _, files in os.walk(root):
    for fn in files:
        if fn.endswith(".py"):
            path = os.path.join(current_root, fn)
            with open(path, "r", encoding="utf-8", errors="ignore") as f:
                text = f.read()

            new_text = text.replace("from torch._six import inf", "from torch import inf")

            if new_text != text:
                with open(path, "w", encoding="utf-8") as f:
                    f.write(new_text)
                patched.append(path)

print("patched files:")
for p in patched:
    print("-", p)
!grep -R "torch._six" -n /content/LMIM || true
!python /content/LMIM/lmim_pretrain/main_pretrain.py \
  --batch_size 8 \
  --epochs 1 \
  --model mae_vit_base_patch4 \
  --data_path /content/LMIM/data/IIIT5K \
  --output_dir /content/LMIM/outputs/pretrain_test \
  --log_dir /content/LMIM/logs/pretrain_test
%cd /content/LMIM

# ----------------------------
# 1. install required packages
# ----------------------------
!pip install -q --no-cache-dir \
  "numpy==1.26.4" \
  "timm==0.4.12" \
  "opencv-python==4.8.1.78" \
  "imgaug==0.4.0" \
  "lmdb" \
  "easydict" \
  "tensorboardX" \
  "pyyaml" \
  "scipy" \
  "tqdm" \
  "nltk"

# ----------------------------
# 2. patch known compatibility issues
# ----------------------------
import os
import re

root = "/content/LMIM"

patched = []

for current_root, _, files in os.walk(root):
    for fn in files:
        if not fn.endswith(".py"):
            continue

        path = os.path.join(current_root, fn)
        with open(path, "r", encoding="utf-8", errors="ignore") as f:
            text = f.read()

        new_text = text

        # (a) torch._six -> torch
        new_text = new_text.replace("from torch._six import inf", "from torch import inf")

        # (b) np.float -> float (but keep np.float32 intact)
        new_text = new_text.replace("dtype=np.float)", "dtype=float)")
        new_text = new_text.replace("dtype=np.float,", "dtype=float,")
        new_text = new_text.replace("dtype=np.float\n", "dtype=float\n")

        # (c) remove qk_scale argument for old timm compatibility
        new_text = re.sub(r",\s*qk_scale\s*=\s*[^,\)\n]+", "", new_text)
        new_text = re.sub(r"qk_scale\s*=\s*[^,\)\n]+,\s*", "", new_text)

        if new_text != text:
            with open(path, "w", encoding="utf-8") as f:
                f.write(new_text)
            patched.append(path)

print("patched files:")
for p in patched:
    print("-", p)

# ----------------------------
# 3. patch data_path list -> string
# ----------------------------
main_file = "/content/LMIM/lmim_pretrain/main_pretrain.py"

with open(main_file, "r", encoding="utf-8") as f:
    text = f.read()

target = "def main(args):"
inject = """def main(args):
    if isinstance(args.data_path, list) and len(args.data_path) == 1:
        args.data_path = args.data_path[0]
"""

if target in text and "if isinstance(args.data_path, list) and len(args.data_path) == 1:" not in text:
    text = text.replace(target, inject, 1)

with open(main_file, "w", encoding="utf-8") as f:
    f.write(text)

print("patched main_pretrain.py data_path handling")

# ----------------------------
# 4. verify critical imports
# ----------------------------
import numpy as np
import timm
import cv2
import imgaug
import lmdb
import tensorboardX

print("numpy:", np.__version__)
print("timm:", timm.__version__)
print("opencv:", cv2.__version__)
print("imgaug:", imgaug.__version__)
print("lmdb: OK")
print("tensorboardX: OK")
file_path = "/content/LMIM/lmim_pretrain/main_pretrain.py"

with open(file_path, "r", encoding="utf-8") as f:
    text = f.read()

marker = "def main(args):"
inject = """def main(args):
    if isinstance(args.data_path, list) and len(args.data_path) == 1:
        args.data_path = args.data_path[0]
"""

if marker in text:
    # 기존 잘못/중복 삽입 방지
    text = text.replace(
        "def main(args):\n    if isinstance(args.data_path, list) and len(args.data_path) == 1:\n        args.data_path = args.data_path[0]\n",
        "def main(args):\n"
    )
    text = text.replace(marker, inject, 1)

with open(file_path, "w", encoding="utf-8") as f:
    f.write(text)

print("patched:", file_path)
!python /content/LMIM/lmim_pretrain/main_pretrain.py \
  --batch_size 8 \
  --epochs 1 \
  --model mae_vit_base_patch4 \
  --data_path /content/LMIM/data/IIIT5K \
  --output_dir /content/LMIM/outputs/pretrain_test \
  --log_dir /content/LMIM/logs/pretrain_test

 

지금 상태에서는 “그냥 IIIT5K 전체를 pretrain에 넣는 것”보다, train / test 역할을 분리해서 생각하는 게 맞아.

핵심은 이거야.

LMIM의 원래 흐름

  1. Pretrain: 큰 텍스트 이미지 데이터로 자기지도학습
  2. Finetune / Train: 라벨 있는 학습 데이터로 STR 학습
  3. Test / Eval: test split으로 성능 확인

LMIM 공식 저장소 설명에서도 영문 pre-train 데이터는 Union14M, 그리고 IIIT5K 같은 데이터셋은 benchmark(평가용) 로 소개돼 있어. 즉, IIIT5K는 원래 “본격 pretrain용 대형 데이터”가 아니라 검증/학습·평가용 벤치마크 성격이 강해. (GitHub)

그래서 네가 지금 한 건 이렇게 보면 돼.

  • 지금 실행한 main_pretrain.py
    • “LMIM pretraining 코드가 도는지” 확인하는 스모크 테스트
    • 연구적으로는 가능하지만,
    • IIIT5K 전체를 pretrain에 쓰면 데이터가 너무 작고, test 성격 데이터와 섞일 위험이 있어서 논문식 실험으로는 부적절해 (GitHub)

다만 무조건 pretrain 단계에서 train/test를 나누는 것보다는,
정확히는 pretrain / finetune / test의 역할을 분리해야 해.


지금 네 상황에서 가장 현실적인 정리

1) 빠른 실험용

IIIT5K만 쓸 거면:

  • train split → finetune/train
  • test split → eval/test

이렇게 써야 해.

그리고 pretrain은 생략하거나,
정말 해보고 싶으면 IIIT5K train만 가지고 toy pretrain 정도로만 돌려봐.
하지만 이건 “논문 재현”이라기보다 코드 확인용 실험에 가까워.

2) 논문 흐름에 더 가깝게

  • Pretrain: 큰 unlabeled text image 데이터
  • Finetune: IIIT5K train
  • Eval: IIIT5K test

이게 더 맞는 구조야. LMIM은 원래 그런 식의 self-supervised STR 흐름이야. (GitHub)


네가 지금 바로 해야 할 판단

경우 A. “일단 돌아가게 하고 싶다”

그러면:

  • pretrain은 일단 toy 실험
  • 이후 IIIT5K train/test 나눠서 학습/평가 코드로 넘어가기

경우 B. “논문 발표/PPT용으로 맞게 설명하고 싶다”

그러면 발표에서는 이렇게 말하면 돼:

“현재는 IIIT5K로 LMIM pretraining 코드 동작을 확인하는 수준의 예비 실험을 수행했고,
실제 실험에서는 pretraining용 대규모 데이터와 benchmark용 IIIT5K train/test를 분리하는 것이 타당하다.”

이렇게 말하면 완전 자연스러워.


아주 중요

IIIT5K는 원래 train / test split이 존재하는 데이터셋이야. 변환본 설명에도 cropped word images가 train/test로 나뉘어 있음이 명시돼 있어. (GitHub)

그래서 네 폴더가 이미 이런 식이면:

  • /content/drive/MyDrive/datasets/IIIT5K/train
  • /content/drive/MyDrive/datasets/IIIT5K/test

이 구조를 기준으로 써야 해.


한 줄 결론

지금처럼 IIIT5K 전체를 pretrain에 넣는 건 “코드 점검용”으로는 가능하지만, 제대로 된 실험 흐름이라면 pretrain용 데이터와 IIIT5K train/test를 분리해서 써야 해.

원하면 내가 다음으로 바로
“네 현재 폴더 구조 기준으로 train/test 나눠서 finetune + eval 돌리는 Colab 코드” 형태로 이어서 정리해줄게.

이렇게 만들 거야.

목표

  1. IIIT5K의 train만 pretrain에 연결
  2. 나중에 test는 evaluation용으로 따로 보관
  3. 우선은 데이터 경로를 명확히 분리해서 헷갈리지 않게 만들기

즉, 지금처럼

  • /content/LMIM/data/IIIT5K -> IIIT5K 전체

로 두지 말고,

  • /content/LMIM/data/IIIT5K_train -> train만 연결
  • /content/LMIM/data/IIIT5K_test -> test만 연결

이렇게 갈 거야.


1단계: 현재 데이터 구조 확인

먼저 진짜로 train, test가 어떻게 들어있는지 확인해.

import os

DATA_ROOT = "/content/drive/MyDrive/datasets/IIIT5K"

print("DATA_ROOT exists:", os.path.exists(DATA_ROOT))
print("root contents:", os.listdir(DATA_ROOT))

train_dir = os.path.join(DATA_ROOT, "train")
test_dir = os.path.join(DATA_ROOT, "test")

print("train exists:", os.path.exists(train_dir))
print("test exists:", os.path.exists(test_dir))

if os.path.exists(train_dir):
    print("train sample:", os.listdir(train_dir)[:10])

if os.path.exists(test_dir):
    print("test sample:", os.listdir(test_dir)[:10])

2단계: LMIM/data 아래에 train/test 심볼릭 링크 따로 만들기

이제 이걸 분리해서 연결.

import os
import shutil

DATA_ROOT = "/content/drive/MyDrive/datasets/IIIT5K"
LMIM_DATA_ROOT = "/content/LMIM/data"

TRAIN_SRC = os.path.join(DATA_ROOT, "train")
TEST_SRC = os.path.join(DATA_ROOT, "test")

TRAIN_LINK = os.path.join(LMIM_DATA_ROOT, "IIIT5K_train")
TEST_LINK = os.path.join(LMIM_DATA_ROOT, "IIIT5K_test")

os.makedirs(LMIM_DATA_ROOT, exist_ok=True)

def safe_link(src, dst):
    if os.path.islink(dst):
        os.unlink(dst)
    elif os.path.exists(dst):
        shutil.rmtree(dst)
    os.symlink(src, dst)
    print(f"linked: {dst} -> {src}")

safe_link(TRAIN_SRC, TRAIN_LINK)
safe_link(TEST_SRC, TEST_LINK)

print("LMIM data contents:", os.listdir(LMIM_DATA_ROOT))
print("train link check:", os.listdir(TRAIN_LINK)[:10])
print("test link check:", os.listdir(TEST_LINK)[:10])

3단계: pretrain은 train만 사용

지금 네 상황에서는 toy pretrain이니까 train만 넣자.

!python /content/LMIM/lmim_pretrain/main_pretrain.py \
  --batch_size 8 \
  --epochs 1 \
  --model mae_vit_base_patch4 \
  --data_path /content/LMIM/data/IIIT5K_train \
  --output_dir /content/LMIM/outputs/pretrain_train_only \
  --log_dir /content/LMIM/logs/pretrain_train_only

4단계: 데이터 개수도 한번 확인

이건 실험 기록용으로 좋다.

import os

train_path = "/content/LMIM/data/IIIT5K_train"
test_path = "/content/LMIM/data/IIIT5K_test"

train_files = [f for f in os.listdir(train_path) if os.path.isfile(os.path.join(train_path, f))]
test_files = [f for f in os.listdir(test_path) if os.path.isfile(os.path.join(test_path, f))]

print("num train images:", len(train_files))
print("num test images:", len(test_files))
print("train sample:", train_files[:5])
print("test sample:", test_files[:5])

왜 이렇게 짜는 거냐면

지금 네가 헷갈리는 포인트가 이거야.

기존 방식

  • IIIT5K 전체를 /content/LMIM/data/IIIT5K에 연결
  • 그러면 train/test 구분 없이 다 들어감

이건 코드 확인용으로는 되지만,
실험 관점에서는 평가 데이터가 섞일 수 있어서 깔끔하지 않음

내가 제안한 방식

  • IIIT5K_train
  • IIIT5K_test

를 아예 분리

그러면 네가 나중에 PPT에서 설명할 때도 명확해져.

“IIIT5K 데이터셋의 train split은 학습/예비 pretraining에 사용하고, test split은 평가용으로 분리하였다.”

이 문장이 바로 가능해.


네 다음 단계는 사실 더 중요해

지금은 pretrain 코드만 돌리는 상태잖아.
근데 진짜 보고 싶은 건 보통 이거야.

  1. 이미지 넣음
  2. 모델이 텍스트를 어떻게 예측했는지 확인
  3. gt와 pred 비교
  4. 맞았는지 틀렸는지 CSV로 저장

즉, 너한테 필요한 건 사실 단순 pretrain보다도
test 이미지에 대해 inference 돌리고 결과 보는 코드에 더 가까워.


그래서 내 추천 순서

지금 바로 할 것

위 코드로

  • train/test 분리
  • train만 pretrain

그 다음

  • test 셋에서 예측 결과 뽑는 코드
  • image_name / gt / pred / correct 형태 CSV 생성

이걸 해야 네가
“LMIM이 실제로 어떻게 인식했는지”
볼 수 있어.


진짜 바로 쓸 수 있게 한 셀로 정리하면

import os
import shutil

# -----------------------------
# 1. paths
# -----------------------------
DATA_ROOT = "/content/drive/MyDrive/datasets/IIIT5K"
LMIM_DATA_ROOT = "/content/LMIM/data"

TRAIN_SRC = os.path.join(DATA_ROOT, "train")
TEST_SRC = os.path.join(DATA_ROOT, "test")

TRAIN_LINK = os.path.join(LMIM_DATA_ROOT, "IIIT5K_train")
TEST_LINK = os.path.join(LMIM_DATA_ROOT, "IIIT5K_test")

# -----------------------------
# 2. checks
# -----------------------------
assert os.path.exists(DATA_ROOT), f"DATA_ROOT not found: {DATA_ROOT}"
assert os.path.exists(TRAIN_SRC), f"TRAIN_SRC not found: {TRAIN_SRC}"
assert os.path.exists(TEST_SRC), f"TEST_SRC not found: {TEST_SRC}"

os.makedirs(LMIM_DATA_ROOT, exist_ok=True)

# -----------------------------
# 3. safe symlink
# -----------------------------
def safe_link(src, dst):
    if os.path.islink(dst):
        os.unlink(dst)
    elif os.path.exists(dst):
        shutil.rmtree(dst)
    os.symlink(src, dst)
    print(f"linked: {dst} -> {src}")

safe_link(TRAIN_SRC, TRAIN_LINK)
safe_link(TEST_SRC, TEST_LINK)

# -----------------------------
# 4. count files
# -----------------------------
train_files = [f for f in os.listdir(TRAIN_LINK) if os.path.isfile(os.path.join(TRAIN_LINK, f))]
test_files = [f for f in os.listdir(TEST_LINK) if os.path.isfile(os.path.join(TEST_LINK, f))]

print("num train images:", len(train_files))
print("num test images:", len(test_files))
print("train sample:", train_files[:5])
print("test sample:", test_files[:5])

그리고 이어서:

!python /content/LMIM/lmim_pretrain/main_pretrain.py \
  --batch_size 8 \
  --epochs 1 \
  --model mae_vit_base_patch4 \
  --data_path /content/LMIM/data/IIIT5K_train \
  --output_dir /content/LMIM/outputs/pretrain_train_only \
  --log_dir /content/LMIM/logs/pretrain_train_only

지금 너한테 더 필요한 건 다음 단계인 “test 이미지 넣고 gt/pred 비교하는 inference 코드” 쪽이야. 그걸 바로 이어서 짜줄게.

댓글