ML/머신러닝

샴 네트워크(Siamese Network)와 삼중항 손실 (Triplet loss)

KAU 2020. 3. 6. 17:47

샴 네트워크(Siamese Network)

샴 네트워크 구조

샴 네트워크는 무엇일까?

두 개의 입력에 대해 독립적으로 두 개의 합성곱 신경망을 실행한 뒤 비교하는 아이디어이다.

 

샴 네트워크 구조를 간단히 설명하자면 기존의 컨볼루션 네트워크를 통해서 피처맵을 뽑아낸다. 

두개의 이미지에서 피처맵을 뽑아낸 이후에 거리를 계산해 본다. 

거리가 작다면 두사람이 비슷하다는것이고 

거리가 크다면 두 사람이 다른 사람이라는것이다.

 거리는 두 벡터사이의 노름으로 정의함

 

  • 두 네트워크에 두 사진을 입력으로 넣고 합성곱 신경망으로 인코딩을 시킨다.
  • 만약에 두 사람이 비슷한 사람이라면 인코딩 사이의 거리 값은 작아야한다.
  • 만약에 두 사람이 다른 사람이라면 인코딩 사이의 거리 값은 커야한다.
  • 위 조건을 만족시키도록 학습을 시켜야 한다.

삼중항 손실 (Triplet loss)

학습을 통해서 긍정이미지에 대한 거리는 줄이고 부정이미지에 대한 거리는 늘리는것이다

간단하게 말해 앵커 이미지와 긍정이미지 그리고 부정이미지를 본다는 의미다. 

 

긍정이미지와 부정이미지? 무슨 뜻일까 

쉽게 말해 동일 인물이면 positive 다른인물이면 Negative

기준이 되는 이미지를  A , 긍정 이미지를  P, 부정 이미지를  N으로 두면 

긍정이미지의 거리가 부정이미지의 거리보다 작아야 한다.

하지만 위 식은 조금 더 수정해줄 필요가 있다.

 

positive와의 거리차 : 0.3 / negative와의 거리차 : 0.31

 

위와 같은 경우에 두 이미지의 차이가 있다고 할 수 있을까?

분명 negative와의 거리가 positive와의 거리보다 멀지만 이 수치만을 봤을 때 anchor와 positive는 같은 클래스, negative는 다른 클래스라고는 단언할 수 없을 것 이다.

마치 같은 클래스의 데이터인 것처럼 거리 차가차가 얼마 나지 않는다.

여기서 α는 positive와의 거리보다 negative의 거리가 다른 클래스라는 것을 보증할만큼 충분히 멀다는 걸

보증해 줄 파라미터입니다.

수정된 식

 

max 함수를 이용한 단일 손실함수
전체 손실함수

훈련세트를 만들 때 학습하기 어렵게 만들어야한다. 

그 이유는 무작위로 추출을 통하면 d(A,P)+a<d(A,N)의 조건이 너무 간단하게 만족이되어

학습이 제대로 이루어 지지 않기 때문이다.

 

그렇다면 어떻게 해야하는가

학습이 어렵게 거리가 비슷한 이미지를 선택하면 된다.
다른 인물일 때 거리가 길다
반대로 동일인물일 때는 거리가 짧게 나오는 모습