Fixing the OOM Problem

hjr067 2025. 4. 19. 02:01

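The feature I had built with WebSocket + FastAPI kept failing with OOM errors, so I decided to change the architecture altogether: model.fit() moves out of the FastAPI process and into a Celery worker, with Redis in between.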

 

 

[🧱 Overall architecture]

📱 Client
   ↓ WebSocket request
🌐 FastAPI
   → dispatches a Celery task
   → opens the WebSocket and receives/forwards logs
🐇 Celery Worker
   → runs model.fit()
   → stores logs in Redis or pushes them directly over WebSocket
🐘 Redis
   → task queue / log transport

✅ Components

Role                   File                  Description
Run Celery task        tasks/train.py        Model training
Publish logs to Redis  utils/pubsub.py       Publishes training logs
Receive via WebSocket  ws/log_consumer.py    Subscribes to Redis and forwards to the client
FastAPI router         routers/ws_train.py   Training request → task dispatch
Celery configuration   celery_worker.py      Runs the Celery worker
Service integration    docker-compose.yml    Runs FastAPI + Redis + Celery

 

🔧 Step 1: Redis + Celery setup

📦 Installation

Add to requirements.txt or install with pip:

celery[redis]
redis

 

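🗂️ Step 2: Example directory layout

app/
├── main.py                  ← FastAPI entry point
├── celery_worker.py         ← Celery app and worker entry point
├── tasks/
│   └── train.py             ← training logic
├── utils/
│   └── pubsub.py            ← Redis publish/subscribe helpers
├── routers/
│   └── ws_train.py          ← WebSocket route that starts training
└── ws/
    └── handler.py           ← WebSocket handling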

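🧠 Step 3: Celery configuration (app/celery_worker.py)

# app/celery_worker.py
import os

from celery import Celery

REDIS_URL = os.getenv("REDIS_URL", "redis://redis:6379/0")

celery_app = Celery(
    "worker",
    broker=REDIS_URL,
    backend=REDIS_URL,
    # Import the task module at worker startup so the task gets registered.
    include=["app.tasks.train"],
)

# The key must equal the task's registered name (module path + function name).
# If it does not match, the task goes to the default queue and a worker
# started with -Q training never picks it up.
celery_app.conf.task_routes = {
    "app.tasks.train.run_training": {"queue": "training"},
}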
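
Since the whole point of this change was OOM, it may also be worth capping memory on the worker side. The sketch below lists a few Celery settings that could help; the dict name MEMORY_SAFE_SETTINGS and all values are my own assumptions for illustration, not something measured on this project.

```python
# Hypothetical settings dict; values are illustrative assumptions, tune per machine.
MEMORY_SAFE_SETTINGS = {
    # Recycle the worker process after every task so memory held by
    # TensorFlow is returned to the OS between training runs.
    "worker_max_tasks_per_child": 1,
    # Replace the worker process once its resident memory exceeds this
    # amount (Celery expects the value in KiB; about 2 GiB here).
    "worker_max_memory_per_child": 2_000_000,
    # Do not reserve extra tasks per process while a long training job runs.
    "worker_prefetch_multiplier": 1,
}

# In app/celery_worker.py this would be applied with:
#     celery_app.conf.update(MEMORY_SAFE_SETTINGS)
```

Note that worker_max_memory_per_child is only checked after a task finishes, so it limits accumulation across tasks rather than stopping a single run that is already too large.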

 

🧪 Step 4: Defining the training task (tasks/train.py)

# app/tasks/train.py
import tensorflow as tf

from app.celery_worker import celery_app
from app.utils.pubsub import publish_log


@celery_app.task
def run_training(user_id, model_code, epochs, batch_size, learning_rate):
    # WARNING: exec() runs arbitrary user-supplied code inside the worker.
    # Only do this with trusted input or inside an isolated sandbox.
    exec_globals = {}
    exec(model_code, exec_globals)

    # The submitted code is expected to define these three names.
    model = exec_globals["model"]
    x_train = exec_globals["x_train"]
    y_train = exec_globals["y_train"]

    model.compile(
        optimizer=tf.keras.optimizers.Adam(learning_rate=learning_rate),
        loss="categorical_crossentropy",
        metrics=["accuracy"],
    )

    # Train one epoch per fit() call so a log can be published after each epoch.
    for epoch in range(epochs):
        history = model.fit(x_train, y_train, epochs=1, batch_size=batch_size, verbose=0)
        acc = round(float(history.history["accuracy"][0]) * 100, 2)
        loss = round(float(history.history["loss"][0]), 4)

        publish_log(f"user:{user_id}", {
            "type": "epoch_log",
            "epoch": epoch + 1,
            "accuracy": acc,
            "loss": loss,
        })

    # Tell the subscriber that no more logs will follow.
    publish_log(f"user:{user_id}", {
        "type": "final",
        "status": "Training complete",
    })

    return {"message": "done"}
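
The task assumes the submitted code defines model, x_train, and y_train, and fails with a bare KeyError if it does not. A minimal sketch of a stricter loader follows; load_user_namespace and REQUIRED_NAMES are hypothetical names I made up for illustration, and the exec() security caveat still applies.

```python
REQUIRED_NAMES = ("model", "x_train", "y_train")


def load_user_namespace(model_code: str) -> dict:
    """Run the user's code and return its globals, failing early on missing names."""
    namespace: dict = {}
    # WARNING: exec() runs arbitrary code; use only with trusted input or a sandbox.
    exec(model_code, namespace)
    missing = [name for name in REQUIRED_NAMES if name not in namespace]
    if missing:
        raise ValueError(f"model_code must define: {', '.join(missing)}")
    return namespace


# Usage with a stand-in snippet (a real submission would build a Keras model).
namespace = load_user_namespace("model = 'stub'\nx_train = [1]\ny_train = [0]")
```

Raising a ValueError with the missing names gives the client a readable error message instead of a KeyError on a single name.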

✅ app/utils/pubsub.py

# app/utils/pubsub.py
import redis
import os
import json

REDIS_URL = os.getenv("REDIS_URL", "redis://redis:6379/0")
r = redis.Redis.from_url(REDIS_URL)

def publish_log(channel: str, data: dict):
    r.publish(channel, json.dumps(data))

def subscribe_log(channel: str):
    pubsub = r.pubsub()
    pubsub.subscribe(channel)
    return pubsub
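
To show how the object returned by subscribe_log might be consumed, here is a small sketch that polls it and stops at the "final" message. iter_logs and FakePubSub are hypothetical names; FakePubSub only imitates the get_message call of redis-py's PubSub so the example runs without a Redis server.

```python
import json
from typing import Iterator


def iter_logs(pubsub, poll_timeout: float = 1.0) -> Iterator[dict]:
    """Yield decoded log dicts until a message with type "final" arrives."""
    while True:
        message = pubsub.get_message(ignore_subscribe_messages=True, timeout=poll_timeout)
        if message is None:
            continue  # nothing published yet, poll again
        log = json.loads(message["data"])  # payload arrives as UTF-8 bytes
        yield log
        if log.get("type") == "final":
            return


class FakePubSub:
    """Hypothetical stand-in for redis-py's PubSub, for demonstration only."""

    def __init__(self, payloads):
        self._queue = [json.dumps(p).encode("utf-8") for p in payloads]

    def get_message(self, ignore_subscribe_messages=True, timeout=0.0):
        if not self._queue:
            return None
        return {"type": "message", "data": self._queue.pop(0)}


fake = FakePubSub([
    {"type": "epoch_log", "epoch": 1, "accuracy": 91.2, "loss": 0.31},
    {"type": "final", "status": "done"},
])
logs = list(iter_logs(fake))
```

With a real Redis connection, iter_logs(subscribe_log("user:1")) would be used the same way, as long as the publisher eventually sends a "final" message.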

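🔁 Step 5: Calling the Celery task from FastAPI (main.py or ws_train.py)

# app/routers/ws_train.py
import asyncio
import json

from fastapi import APIRouter, WebSocket, WebSocketDisconnect

from app.tasks.train import run_training
from app.utils.pubsub import subscribe_log

router = APIRouter()


@router.websocket("/ws/train")
async def websocket_train(websocket: WebSocket):
    await websocket.accept()
    pubsub = None
    try:
        data = await websocket.receive_json()
        user_id = str(data["user_id"])
        model_code = data["code"]
        form = data["form"]

        # Subscribe before dispatching so the first logs are not missed.
        pubsub = subscribe_log(f"user:{user_id}")

        run_training.delay(
            user_id,
            model_code,
            form["epochs"],
            form["batch_size"],
            form["learning_rate"],
        )

        while True:
            # get_message() blocks, so run it in a thread (Python 3.9+)
            # to keep the event loop free for other connections.
            message = await asyncio.to_thread(
                pubsub.get_message, ignore_subscribe_messages=True, timeout=1
            )
            if message is None:
                continue
            text = message["data"].decode("utf-8")
            await websocket.send_text(text)
            # Stop once the worker reports that training has finished.
            if json.loads(text).get("type") == "final":
                break

        await websocket.close()
    except WebSocketDisconnect:
        pass  # the client went away; there is nobody left to notify
    except Exception as e:
        await websocket.send_json({"error": str(e)})
        await websocket.close()
    finally:
        # Release the Redis connection held by this subscription.
        if pubsub is not None:
            pubsub.close()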
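
For reference, the handler reads user_id, code, and form from the first message and then streams JSON logs back. The sketch below shows that message shape from the client's side; build_train_request and describe_log are hypothetical helper names, and the actual client in this project may look different.

```python
import json


def build_train_request(user_id, code, epochs, batch_size, learning_rate) -> str:
    """Serialize the first WebSocket message in the shape the handler reads."""
    return json.dumps({
        "user_id": user_id,
        "code": code,
        "form": {
            "epochs": epochs,
            "batch_size": batch_size,
            "learning_rate": learning_rate,
        },
    })


def describe_log(raw: str) -> str:
    """Turn one server message into a line suitable for display."""
    msg = json.loads(raw)
    if "error" in msg:
        return f"error: {msg['error']}"
    if msg.get("type") == "epoch_log":
        return f"epoch {msg['epoch']}: acc={msg['accuracy']}% loss={msg['loss']}"
    if msg.get("type") == "final":
        return f"finished: {msg['status']}"
    return f"unknown message: {raw}"


request_text = build_train_request(7, "model = None", 3, 32, 0.001)
```

A client would send request_text right after connecting to /ws/train and pass every received text frame through describe_log.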

 

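📡 Step 6: Real-time log delivery

Redis Pub/Sub (strongly recommended)

  • Publish training logs to a Redis channel
  • The FastAPI WebSocket handler subscribes to that channel and forwards each message to the client in real time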
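
One property of Pub/Sub worth keeping in mind is that messages are not stored: only clients subscribed at publish time receive them. The toy model below illustrates this with plain Python; InMemoryBroker is a hypothetical class for explanation only, not how Redis is implemented.

```python
from collections import defaultdict


class InMemoryBroker:
    """Toy model of Pub/Sub delivery: no history, only current subscribers receive."""

    def __init__(self):
        self._subscribers = defaultdict(list)

    def subscribe(self, channel: str) -> list:
        inbox: list = []
        self._subscribers[channel].append(inbox)
        return inbox

    def publish(self, channel: str, message: str) -> int:
        # Returns how many subscribers received the message, like Redis PUBLISH.
        for inbox in self._subscribers[channel]:
            inbox.append(message)
        return len(self._subscribers[channel])


broker = InMemoryBroker()
delivered_before = broker.publish("user:1", "epoch 1")  # nobody listening: dropped
inbox = broker.subscribe("user:1")
delivered_after = broker.publish("user:1", "epoch 2")   # reaches the new subscriber
```

This is why the WebSocket handler should subscribe before the task starts publishing; logs that must survive a reconnect would need something persistent such as a Redis list or stream.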

 

🚀 Running the Celery worker

celery -A app.celery_worker.celery_app worker --loglevel=info -Q training
 

 

docker-compose.yml

# docker-compose.yml
version: '3.8'

services:
  fastapi_app:
    build: .
    container_name: fastapi_app
    ports:
      - "8000:8000"
    depends_on:
      - redis
    env_file:
      - .env
    volumes:
      - .:/app

  redis:
    image: redis:alpine
    container_name: redis
    ports:
      - "6379:6379"

  celery_worker:
    build: .
    container_name: celery_worker
    command: celery -A app.celery_worker.celery_app worker --loglevel=info -Q training
    depends_on:
      - redis
    env_file:
      - .env
    volumes:
      - .:/app