Using the visitor pattern instead of "dynamic_cast" in C++

Visitor 패턴을 사용하여 C++ dynamic_cast 지우기

C++에서 다형성을 만족하는 클래스들에게 각각 독립적인 로직을 갖는 함수를 사용하기 위해서는 가상 함수 테이블을 사용하는 것이 일반적이다. 클래스의 수가 늘어날 때 그리고 함수의 수가 늘어날 때를 고려하거나 다수의 클래스가 같은 로직을 사용할 때의 경우를 고려하여 다양한 디자인 패턴들이 소개가 되었다. 그 중에서도 특히 클래스의 수는 늘어날 가능성이 매우 희박하면서도, 클래스 특수한(class specific) 함수들의 수는 늘어나기 쉬운 경우에 대해서 일관된 로직 플로우 또는 패턴을 마련해야 할 필요성이 생겼다. 이 글에서 naïve한 접근 방식을 소개하고 이를 보안한 Visitor 패턴을 사용해 어떻게 구현하였는지 서술하려고 한다.

다음 코드조각은 전체 실험에서 사용할 기본적인 클래스 정의를 나타낸다. Base라는 한 개의 순수 가상 클래스를 정의하였고, 이를 상속받은 8개의 클래스들을 정의하였다. Visitor를 받아들이기 위한 Invite라는 멤버 함수가 가상함수로 정의되어 있다. 상속받은 자식 클래스에서도 역시 재정의하였다. 이 때, 각 Derived 클래스의 Invite 함수의 스코프에서 *this의 타입형은 자기자신 클래스의 레퍼런스가 될 것이고, 이 를 정보로 삼아 함수 오버로딩을 활용하여 각 Visitor에서 클래스 특화 함수를 호출할 수 있게 하였다.

class Base {
  protected:
    // 아주 많은 데이터
    int data[100];
  public:
    Base() {}
    virtual ~Base() {}
    
    virtual void Func() = 0;
    virtual void Invite(Visitor& visitor) {
        visitor.Visit(*this);
    }

    virtual int Id() = 0;
};

#define MAKE_DERIVED(id) \
class Derived##id : public Base { \
public: \
    int cast_id; \
    Derived##id() : Base() {} \
    Derived##id(int a) : Base(), cast_id(a) {} \
    virtual ~Derived##id() {} \
    virtual void Func() override { } \
    virtual void Invite(Visitor& visitor) override { \
        visitor.Visit(*this); \
    } \
    virtual int Id() override { \
        return cast_id; \
    } \
};

MAKE_DERIVED(A)
MAKE_DERIVED(B)
MAKE_DERIVED(C)
MAKE_DERIVED(D)
MAKE_DERIVED(E)
MAKE_DERIVED(F)
MAKE_DERIVED(G)
MAKE_DERIVED(H)

/// After preprocessing the above macro
class DerivedA : public Base {
  public:
    int cast_id;
    DerivedA() : Base() {}
    DerivedA(int a) : Base(), cast_id(a) {}
    virtual ~DerivedA() {}
    virtual void Func() override { }
    virtual void Invite(Visitor& visitor) override {
        visitor.Visit(*this);
    }
    virtual int Id() override {
        return cast_id;
    }
};


다음은 테스트를 하기 위한 데이터 준비작업이다. 10,000,000개의 클래스 인스턴스를 준비하였다.

const static size_t test_size = 10000000;

auto now = std::chrono::high_resolution_clock::now;
auto begin = now();

std::random_device rd;
std::mt19937 mt(rd());
std::uniform_int_distribution<> dis(0, 7);
std::vector<std::unique_ptr<Base>> v;
v.reserve(test_size);
for(size_t i = 0 ; i < test_size ; i++) {
    int id = dis(mt);
    switch (id) {
    case 0:
        v.emplace_back(new DerivedA(1 * 10 + dis(mt)));
        break;
    case 1:
        v.emplace_back(new DerivedB(2 * 10 + dis(mt)));
        break;        
    case 2:
        v.emplace_back(new DerivedC(3 * 10 + dis(mt)));
        break;
    case 3:
        v.emplace_back(new DerivedD(4 * 10 + dis(mt)));
        break;        
    case 4:
        v.emplace_back(new DerivedE(5 * 10 + dis(mt)));
        break;        
    case 5:
        v.emplace_back(new DerivedF(6 * 10 + dis(mt)));
        break;        
    case 6:
        v.emplace_back(new DerivedG(7 * 10 + dis(mt)));
        break;        
    case 7:
        v.emplace_back(new DerivedH(8 * 10 + dis(mt)));
        break;
    default:
        break;
    }
}
auto end = now();
auto duration =
std::chrono::duration_cast<std::chrono::microseconds>(end - begin).count();
std::cout << "Test preparation\n" << std::setw(20)
          << std::right << duration << "us" << std::endl;

가장 naïve한 접근 방법은, 부모 클래스 타입의 포인터에 dynamic_cast를 사용하는 것이다. 이러한 cascading dyamic_cast 방식은 많은 RTTI에 기반한 비교연산을 동반하기 때문에 성능이 좋을 것이라고 기대하기 어렵다.

auto begin = now();
int cnt[8];
for(auto&& i : v) {
    auto&& p = i.get();
    if(auto ptr = dynamic_cast<DerivedA*>(p)) {
        cnt[0]++;
    } else if(auto ptr = dynamic_cast<DerivedB*>(p)) {
        cnt[1]++;
    } else if(auto ptr = dynamic_cast<DerivedC*>(p)) {
        cnt[2]++;
    } else if(auto ptr = dynamic_cast<DerivedD*>(p)) {
        cnt[3]++;
    } else if(auto ptr = dynamic_cast<DerivedE*>(p)) {
        cnt[4]++;
    } else if(auto ptr = dynamic_cast<DerivedF*>(p)) {
        cnt[5]++;
    } else if(auto ptr = dynamic_cast<DerivedG*>(p)) {
        cnt[6]++;
    } else if(auto ptr = dynamic_cast<DerivedH*>(p)) {
        cnt[7]++;
    } else {
        throw std::bad_cast();
    }
}

auto end = now();
auto duration
  = std::chrono::duration_cast<std::chrono::microseconds>(end - begin).count();
std::cout << "Dynamic cast\n" << std::setw(20)
          << std::right << duration << "us" << std::endl;
std::cout << "---------validation----------\n";
std::cout << cnt[0] << " : " << cnt[1] << " : " << cnt[2] << " |\n"
          << cnt[3] << " : " << cnt[4] << " : " << cnt[5] << " |\n"
          << cnt[6] << " : " << cnt[7] << std::endl;
std::cout << "-----------------------------\n\n";

다음 방법은 Visitor를 활용한 방법이다. 사용할 Visitor의 인터페이스는 다음과 같다. 먼저 각 Derived class마다 Visit 함수를 지니고 있고, 각 Visit 함수 내부에서는 가상 함수인 DoWork함수를 호출하게 된다. Visit함수와 DoWork함수는 derived 클래스 별로 오버로딩이 되어있어서 타입을 알맞게 찾을 수 있다. 즉, 전체적인 플로우는 다음과 같다. Base 클래스와 Derived 클래스들의 vtable을 사용해 Visitor를 인자로 받으면, 자기 자신의 타입에 해당하는 Visit 함수를 호출하게 된다. 이 때 Derived 클래스의 구체적인 타입은 알 수 있는 상태이기 때문에 알맞은 DoWork 함수를 호출할 수 있다. 이 때 모든 가상 함수를 오버라이딩 했다면, Base 타입의 객체를 받은 DoWork함수는 필요없고, 또 호출되지 않을 것이다. 대신에 이를 기본행동을 지정하는 함수를 만들 경우에 사용할 수 있다. 특정 클래스, DerivedA와 DerivedB에만 특정 로직을 동작시키고, 그 외에는 기본 로직을 활용하고 싶다면 Base를 받는 경우를 기본 로직을 위한 함수로 설정할 수 있다. 이럴 경우 Visitor 인터페이스의 DoWork 함수들의 기본 로직을 아래 코드조각의 맨 하단 부분처럼 바꿔야 할 필요가 있을 것이다.

#define DERIVED_FORWARD_DECL(id) \
class Derived##id;

#define MAKE_VISIT_FUNC(id) \
void Visit(const Derived##id& tgt) { \
    DoWork(tgt); \
}

#define MAKE_WORK_FUNC(id, logic) \
virtual void DoWork(const Derived##id& tgt) { \
    logic \
}

class Visitor {
  public:
    virtual ~Visitor() {}

    void Visit(const Base& tgt) {
        DoWork(tgt);
    }

    MAKE_VISIT_FUNC(A)
    MAKE_VISIT_FUNC(B)
    MAKE_VISIT_FUNC(C)
    MAKE_VISIT_FUNC(D)
    MAKE_VISIT_FUNC(E)
    MAKE_VISIT_FUNC(F)
    MAKE_VISIT_FUNC(G)
    MAKE_VISIT_FUNC(H)

    virtual void DoWork(const Base& tgt) {
        throw std::logic_error("Not Reachable");
    }
    MAKE_WORK_FUNC(A, ;)
    MAKE_WORK_FUNC(B, ;)
    MAKE_WORK_FUNC(C, ;)
    MAKE_WORK_FUNC(D, ;)
    MAKE_WORK_FUNC(E, ;)
    MAKE_WORK_FUNC(F, ;)
    MAKE_WORK_FUNC(G, ;)
    MAKE_WORK_FUNC(H, ;)

};

/// DoWork에 default function을 사용할 경우
void Visiter::DoWork(const DerivedB& tgt) {
    // explicit upcasting
    DoWork(static_cast<const Base&>(tgt);
}

다음은 cascading dynamic_cast를 사용했을 때와 같은 작동을 하는 CountVisitor에 대한 정의와 테스트 코드이다.

auto begin = now();
class CountVisitor : public Visitor {
public:
    CountVisitor() : cnt{0} {}
    virtual ~CountVisitor() = default;

    MAKE_WORK_FUNC(A, cnt[0]++;)
    MAKE_WORK_FUNC(B, cnt[1]++;)
    MAKE_WORK_FUNC(C, cnt[2]++;)
    MAKE_WORK_FUNC(D, cnt[3]++;)
    MAKE_WORK_FUNC(E, cnt[4]++;)
    MAKE_WORK_FUNC(F, cnt[5]++;)
    MAKE_WORK_FUNC(G, cnt[6]++;)
    MAKE_WORK_FUNC(H, cnt[7]++;)

    int cnt[8];
};

CountVisitor cv;
for(auto&& i : v) {
    i->Invite(cv);
}       

auto end = now();
auto duration
    = std::chrono::duration_cast<std::chrono::microseconds>(end - begin).count();
std::cout << "Dynamic cast\n" << std::setw(20)
          << std::right << duration << "us" << std::endl;
std::cout << "---------validation----------\n";
std::cout << cnt[0] << " : " << cnt[1] << " : " << cnt[2] << " |\n"
          << cnt[3] << " : " << cnt[4] << " : " << cnt[5] << " |\n"
          << cnt[6] << " : " << cnt[7] << std::endl;
std::cout << "-----------------------------\n\n";

이를 바탕으로 수행 시간을 측정해 보았다. 테스트 환경은 인텔 i5-8500이며 Ubuntu 18.04 Linux 5.0.0-27에서 측정되었다. g++ 7.4.0버전에 -O3 최적화 옵션으로 컴파일 하였다. Cascading dynamic_cast의 경우 약 1,000,000 마이크로초 (= 1초)가 소요되었다. 1,100,000 마이크로초가 넘는 테스트 회수도 있었다. Visitor를 사용할 경우, 150,000 마이크로초에서 160,000 마이크로초 사이의 시간이 소요되었다. 약 6배 정도의 성능 차이가 있는 것을 확인하였다. 다음은 특정 테스트 회수의 결과이다.

Test preparation
             2088129us
Dynamic cast
             1087401us
---------validation----------
1249520 : 1249369 : 1249623 |
1250041 : 1249819 : 1251194 |
1250905 : 1249529           |
-----------------------------
Visitor
              156315us
---------validation----------
1249520 : 1249369 : 1249623 |
1250041 : 1249819 : 1251194 |
1250905 : 1249529           |
-----------------------------

최적화 옵션을 사용하지 않았을 경우는 그 차이가 1.6배정도 (Visitor가 더 빠름) 으로 줄어들기는 하지만 그럼에도 차이가 유의미하다고 판단할 수 있을 것이다. -O1 또는 -O2의 경우에도 -O3와 유사한 비율의 성능 차이를 확인할 수 있었다.

Type checking + dynamic_cast

C++이 dynamic_cast를 하기위해 런타임 타입 정보를 얻는 방법에 대해서는 규정을 하고있지는 않지만, 각 클래스 내부에 타입에 대한 (엄격히 관리된) 정보를 int등의 타입으로 유지한다면 int타입에 대한 비교 연산이 RTTI을 사용한 dynamic_cast보다 빠를 것 같았다. 물론 타입에 대한 정보를 클래스 내부에 유지하는 것은 권장할만한 행동은 아니지만, 어째뜬 우리는 int타입의 값을 보고 클래스의 실제 타입을 추론해 볼 것이다. 테스트 케이스를 생성한 코드 또는 Base & Derived 클래스를 정의한 코드를 보면 정수의 값으로 타입 정보를 저장하였다. 10의 자리의 수를 보고 타입을 추론할 수 있을 것이다. 이를 바탕으로 다음과 같이 코드를 작성할 수 있다.

auto begin = now();
int cnt[8] = {0};
for(auto&& i : v) {
    auto&& p = i.get();  
    int id = p->Id() / 10;          
    switch (id) {
    case 1:
        {
            auto ptr = dynamic_cast<DerivedA*>(p);
            cnt[0]++;
        }
        break;
    case 2:
        {
            auto ptr = dynamic_cast<DerivedB*>(p);
            cnt[1]++;
        }
        break;        
    case 3:
        {
            auto ptr = dynamic_cast<DerivedC*>(p);
            cnt[2]++;
        }
        break;
    case 4:
        {
            auto ptr = dynamic_cast<DerivedD*>(p);
            cnt[3]++;
        }
        break;        
    case 5:
        {
            auto ptr = dynamic_cast<DerivedE*>(p);
            cnt[4]++;
        }
        break;        
    case 6:
        {
            auto ptr = dynamic_cast<DerivedF*>(p);
            cnt[5]++;
        }
        break;        
    case 7:
        {
            auto ptr = dynamic_cast<DerivedG*>(p);
            cnt[6]++;
        }
        break;        
    case 8:
        {
            auto ptr = dynamic_cast<DerivedH*>(p);
            cnt[7]++;
        }
        break;
    default:
        break;
    }
}

auto end = now();
auto duration
    = std::chrono::duration_cast<std::chrono::microseconds>(end - begin).count();
std::cout << "Dynamic cast + with id checking\n" << std::setw(20) << std::right << duration << "us" << std::endl;
std::cout << "---------validation----------\n";
std::cout << cnt[0] << " : " << cnt[1] << " : " << cnt[2] << " |\n"
          << cnt[3] << " : " << cnt[4] << " : " << cnt[5] << " |\n"
          << cnt[6] << " : " << cnt[7] << std::endl;
std::cout << "-----------------------------\n\n";

이를 같은 환경에서 테스트했을 때, 200,000마이크로초 근처였다. Visitor보다는 빠르지 않았어도, dynamic_cast를 여러번 한 경우에 비해 짧은 수행시간을 보여주었다.

정리

dynamic_cast를 사용하지 않는 것이 권장되어도, (그리고 클래스 설계를 dynamic_cast를 쓰지 않는 방법으로 수정하는 것이 권장되더라도) 런타임에 타입을 체크하거나 상속받은 자식 클래스에 있는 특수한 함수 또는 성질에 접근해야할 필요성을 완전히 지우는 것은 노력이 필요한 일이다. 이 글에서 dynamic_cast를 cascade 처럼 사용하지 않고 비슷한 역할을 수행할 수 있는 방법 중 하나를 소개해보았다. 이 방법론을 사용하다면, 클래스의 종류가 늘어나지 않고 함수를 계속 추가해 나가야 할 때 좀 더 편하고 그리고 동작시간이 빠른 코드 및 구현을 얻어낼 수 있을 것으로 기대한다.