Study/Python

[Tensorflow] Object Detection API로 학습&추론 -1

Omoknooni 2023. 4. 2. 23:22

이제 API로 인식시킬 데이터셋을 학습시켜보자

 

학습까지의 과정

데이터 수집 > 데이터 라벨링 > 단일 csv 파일 생성 > TFRecord 생성 > Label map 생성 > pre-trained model 다운로드 & config 설정 > 이미지 학습

 

 

데이터 수집

학습에 사용할 이미지를 수집하는 단계로 적당한 크기의 객체가 들어있는 이미지를 수집한다.

 

데이터 라벨링

학습에 사용하기 위해 수집한 이미지 데이터를 라벨링하는 작업이 필요하다.

LableImg라는 오픈소스 툴을 이용해 작업을 진행한다.

 

# labelImg 설치
pip install labelImg

# labelImg 실행
labelImg

 

labelImg 설치 후 실행, 'Open Dir' 로 이미지가 저장된 폴더 선택하면 폴더 내의 모든 이미지를 불러온다.

 

이미지를 불러온 후, 'Create RectBox' 클릭 혹은 단축키 'w'로 이미지 캡쳐모드로 들어간다.

학습할 부분을 선택 후, 다음 팝업창에서 해당 부분을 어떤 객체로 인식할 것인지 입력한다.

학습할 부분을 선택하는 것을 이 작업을 라벨링이라고 한다.

 

 

이후 해당 이미지 내에서 모든 학습할 부분을 캡쳐하고 'Save'로 저장한다.

저장하면 해당 이미지에서 라벨링한 내용들이 같은 폴더내의 같은 이름의 xml 파일로 저장된다

 

이제 이 이미지 라벨링 데이터셋을 학습(train)에 이용할 것평가확인(test)에 사용할 것으로 분리해줘야 한다.

이 둘의 비율은 보통 train 7: test 3 내지 train 8: test 2 정도로 해준다고 한다.

단일 csv 파일 생성

데이터 라벨링의 결과로 각 이미지 별로 xml파일이 생성되었다. 

이 xml파일 내의 값들을 사용처 별로 각각 하나의 csv 파일로 통합시켜줘야 한다.

즉, train용으로 분리한 데이터셋에서 csv 파일 한 개, test용으로 분리한 데이터셋에서 csv 파일 한 개를 생성해줘야한다.

 

생성된 csv파일의 내부값은 다음과 같다. 라벨링된 각 객체의 x,y좌표 값이 들어가있는 것을 볼 수 있다.

 

 

Label map 생성

tensorflow가 동작하기 위해서 라벨맵(Label map)을 설정해줘야한다.

라벨맵은 인식할 객체의 정보를 담은 파일로 위의 데이터셋에서 라벨링했던 내용을 포함한다.

 

 id값은 1부터 시작하며, 파일 확장자는 pbtxt로 저장한다.

item {
    id: 1
    name: 'cat'
}

item {
    id: 2
    name: 'dog'
}

 

 

 

TFRecord 생성

TFRecord 파일은 tensorflow 학습을 하는데 필요한 데이터들을 보관하기 위한 바이너리 데이터 포맷으로, 학습을 시키기 위해서 tfrecord파일을 생성해줘야한다.

 

앞서 생성한 csv 파일을 기반으로 TFRecord파일을 생성한다.

 

code : https://github.com/datitran/raccoon_dataset/blob/master/generate_tfrecord.py

위 코드에서 앞서 생성한 라벨맵을 바탕으로 class_text_to_int()의 내용을 수정한다.

def class_text_to_int(row_label):
    if row_label == 'cat':	# 라벨맵 객체의 name
        return 1			# 반환값은 해당 객체의 id
    elif row_label == 'dog':
        return 2
    else:
    	None

 

# test용 데이터셋 TFRecord 생성
python generate_tfrecord.py --csv_input=[test용 데이터셋의 csv파일 경로] --image_dir=[test용 데이터셋의 경로] --output_path=test.tfrecord

# train용 데이터셋 TFRecord 생성
python generate_tfrecord.py --csv_input=[train용 데이터셋의 csv파일 경로] --image_dir=[train용 데이터셋의 경로] --output_path=train.tfrecord

 

 

pre-trained model 다운로드

학습에는 많은 데이터가 필요한데, 일반 개인단위에서는 구하기도 힘들고 오래걸리기 마련이다.

따라서, 대규모 학습 데이터를 기반으로 생성된 모델을 배포해서 개인 등 일반 사용자가 쉽게 사용할 수 있도록 했다.

공식적으로 배포하는 사전 훈련된 모델(pre-trained model)은 Tensorflow 2 Detection Model Zoo 에서 찾을 수 있다.  

우리는 사전에 학습된 모델을 가져와서 우리가 학습시킬 데이터를 추가로 학습을 시키는 구조를 사용할 것이다.

평균 탐색시간(Speed)과 평균 객체 인식률(COCO mAP)을 비교해 적절한 모델을 다운로드

모델의 구조는 아래와 같다

├─ssd_mobilenet_v2_fpnlite_640x640_coco17_tpu-8
│  │  pipeline.config
│  │
│  ├─checkpoint
│  │      checkpoint
│  │      ckpt-0.data-00000-of-00001
│  │      ckpt-0.index
│  │
│  └─saved_model
│      │  saved_model.pb
│      │
│      └─variables
│              variables.data-00000-of-00001
│              variables.index

checkpoint가 우리가 학습을 시작할 pre-trained model의 경로라고 보면 된다.

 

pipeline.config 설정

학습에 사용할 설정파일인 pipeline.config에서 모델 파라미터를 설정한다. pre-trained 모델에 있는 pipeline.config는 대규모 데이터를 학습했을 당시의 모델 파라미터값으로 우리가 학습을 할 때에는 이 값을 학습 환경에 적절하게 조절해 주어야 한다.

 

이 파일에서 주요하게 변경할 부분은 다음과 같다.

# 주요 수정 구간
num_classes: [label map에서 작성한 인식할 객체의 종류 갯수]
...
batch_size: [연산 한 번에 들어가는 데이터의 크기, 적당히 1정도, 
메모리 용량이나 GPU 메모리에 여유가 있는 경우 4까지도 괜찮음, 너무 크면 OOM발생]
...
fine_tune_checkpoint: [학습을 시작할 checkpoint의 경로, pre-trained 모델의 checkpoint/ckpt-0]
num_steps: [학습 step수]
fine_tune_checkpoint_type: "detection" [객체를 detection할 것이므로 detection으로 수정]
...
train_input_reader {
  label_map_path: [label map.pbtxt의 경로]
  tf_record_input_reader {
    input_path: [train용 tfrecord의 경로]
  }
}
...
eval_input_reader {
  label_map_path: [label map.pbtxt의 경로]
  shuffle: false
  num_epochs: 1
  tf_record_input_reader {
    input_path: [test용 tfrecord의 경로]
  }
}

 

학습

위의 pipeline.config 세팅까지 완료되면 학습을 시작

python model_main_tf2.py —-pipeline_config_path=[pipeline.config의 경로] -—model_dir=[학습 결과물들이 저장될 경로] —-logtostderr

 

학습이 시작되면 중간중간 진행된 step 수와 loss, learning rate 값들이 출력된다

 

이러한 학습 과정을 모니터링 할 수 있는 툴로 tensorboard가 있다.

API 설치시 같이 설치되므로 학습 실행 시켜두고 다음 커맨드 실행

tensorboard --logdir=[학습 결과물들이 저장될 경로]

 

실행되면 localhost에 6006(기본 포트값)번 포트로 텐서보드가 실행된다