【C++】手搓variant

标准库 std::variant 用法(本例先用 index() 判断当前类型,再用 std::get 取值;std::visit 可以自动完成这种按类型分发)

#include <iostream>
#include <variant>

// Prints the value currently held by `v`, followed by a newline.
// v.index() reports which alternative is active: 0 -> int, 1 -> std::string.
void print(const std::variant<int, std::string>& v) {
    switch (v.index()) {
        case 0:
            std::cout << std::get<0>(v) << '\n';
            break;
        case 1:
            std::cout << std::get<1>(v) << '\n';
            break;
        default:
            // No other index is possible except variant_npos (valueless);
            // in that case nothing is printed.
            break;
    }
}

// Demo: the same std::variant object first holds an int, then a std::string.
int main() {
    std::variant<int, std::string> value;  // default-constructs the int alternative

    value = 42;      // active alternative: int
    print(value);

    value = "3333";  // active alternative: std::string (built from the literal)
    print(value);

    return 0;
}

实现一个简单的variant

封装一个简单的 variant:用 type 字段记录 union 中当前存放的是哪种类型,避免出现“类型标记是 int,实际写入的却是 double”这类不一致。

#include <iostream>

// Tag telling which member of the tagged union is active.
enum VariantType { Int = 0, Double = 1 };

// Hand-written tagged union: `type` records which member of the anonymous
// union currently holds the value (Int -> i, Double -> d).
// Callers must keep `type` and the written member in sync; reading a union
// member other than the one last written is undefined behavior.
struct Variant {
    // Default to a well-defined state (an int holding 0) so a
    // default-constructed Variant never carries an indeterminate tag that
    // code branching on `type` would read.
    VariantType type = Int;
    union {
        int i = 0;  // at most one union member may have a default initializer
        double d;
    };
};

// Prints the union member selected by v.type, followed by a newline.
// Any other tag value prints nothing.
void print(const Variant& v) {
    switch (v.type) {
        case Int:
            std::cout << v.i << "\n";
            break;
        case Double:
            std::cout << v.d << "\n";
            break;
    }
}

// Demo: the caller writes a union member and sets the matching tag by hand.
int main() {
    Variant v;

    v.i = 42;        // store an int ...
    v.type = Int;    // ... and record that the int member is active
    print(v);

    v.d = 3.14;      // store a double ...
    v.type = Double; // ... and record that the double member is active
    print(v);

    return 0;
}

进一步:用构造函数根据参数类型自动设置类型标记,并在取值时检查类型是否匹配

#include <iostream>
#include <stdexcept>

// Tag telling which member of Variant's union is active.
enum VariantType { Int = 0, Double = 1 };

// Thrown by Variant::get_int / Variant::get_double when the requested type
// is not the one currently stored.
class BadVariantAccess : public std::exception {
public:
    BadVariantAccess() = default;
    ~BadVariantAccess() override = default;

    // Fixed message; never throws.
    const char *what() const noexcept override { return "BadVariantAccess"; }
};

// Tagged union storing either an int or a double. The constructors set the
// tag (m_type) together with the matching union member; the typed getters
// throw BadVariantAccess when asked for the type that is not stored.
class Variant {
public:
    // The argument's type selects the alternative (implicit on purpose).
    Variant(int i) : m_type(Int), m_i(i) {}
    Variant(double d) : m_type(Double), m_d(d) {}

    // Type queries.
    bool is_int() const { return m_type == Int; }
    bool is_double() const { return m_type == Double; }

    // Checked accessors.
    int get_int() const {
        if (!is_int()) {
            throw BadVariantAccess();
        }
        return m_i;
    }

    double get_double() const {
        if (!is_double()) {
            throw BadVariantAccess();
        }
        return m_d;
    }

    // Writes the stored value to std::cout, followed by a newline.
    void print() const {
        if (is_int()) {
            std::cout << get_int() << "\n";
            return;
        }
        if (is_double()) {
            std::cout << get_double() << "\n";
        }
    }

private:
    VariantType m_type;
    union {
        int m_i;
        double m_d;
    };
};

int main() {
    
    
    try {
    
    
        Variant v(42);  // 存储整数
        v.print();      // 输出: 42

        v = Variant(3.14); // 存储浮点数
        v.print();         // 输出: 3.14
    } catch (const BadVariantAccess& e) {
    
    
        std::cerr << "Error: " << e.what() << std::endl;
    }

    return 0;
}

用模板实现支持任意两种类型的 variant

#include <iostream>
#include <stdexcept>
#include <string>
#include <type_traits>
#include <utility>

// 自定义异常类
struct BadVariantAccess : public std::exception {
    
    
    const char* what() const noexcept override {
    
     return "BadVariantAccess"; }
};

// Variant that holds exactly one value of type T1 or T2.
//
// The value lives in a raw byte buffer (m_union) and is created and destroyed
// by hand. Placement new constructs a T inside memory that is already
// reserved, and the destructor is called explicitly. m_index records which
// alternative is alive: 1 => T1, 2 => T2.
//
// A move constructor and a move assignment are user-declared, so the implicit
// copy constructor and copy assignment are deleted: this Variant is move-only.
template <typename T1, typename T2>
class Variant {
   private:
    int m_index;
    // The buffer must be large enough for either alternative and aligned for
    // both. A plain char array only guarantees alignment 1, and placement-new
    // of a T into storage that is not suitably aligned for T is undefined
    // behavior. When several alignas specifiers are given, the strictest one
    // applies. The size uses ?: so this block does not depend on <algorithm>.
    alignas(T1) alignas(T2) char m_union[sizeof(T1) > sizeof(T2) ? sizeof(T1) : sizeof(T2)];

   public:
    // Construct holding a T1. The argument is taken by value and then moved
    // into the buffer, so passing a temporary costs a move, not a copy.
    Variant(T1 value) : m_index(1) {
        new (m_union) T1(std::move(value));
    }

    // Construct holding a T2.
    Variant(T2 value) : m_index(2) {
        new (m_union) T2(std::move(value));
    }

    // Destroy whichever alternative is currently alive.
    ~Variant() {
        if (m_index == 1) {
            reinterpret_cast<T1*>(m_union)->~T1();
        } else if (m_index == 2) {
            reinterpret_cast<T2*>(m_union)->~T2();
        }
    }

    // Move constructor: move-construct the same alternative out of `other`.
    // `other` keeps its index and a moved-from value, and is destroyed later
    // as usual. Declared noexcept: a throwing move of T1/T2 calls
    // std::terminate.
    Variant(Variant&& other) noexcept : m_index(other.m_index) {
        if (m_index == 1) {
            new (m_union) T1(std::move(*reinterpret_cast<T1*>(other.m_union)));
        } else if (m_index == 2) {
            new (m_union) T2(std::move(*reinterpret_cast<T2*>(other.m_union)));
        }
    }

    // Move assignment: end the lifetime of *this, then rebuild it in place by
    // move construction from `other`.
    Variant& operator=(Variant&& other) noexcept {
        if (this != &other) {
            this->~Variant();
            new (this) Variant(std::move(other));
        }
        return *this;
    }

    // Is the I-th alternative (1 or 2) the active one?
    template <int I>
    bool is() const {
        return m_index == I;
    }

    // Overload chosen when the template argument is a type instead of a
    // number (function templates overload on the kind of template parameter):
    // is<T1>() == is<1>(), is<T2>() == is<2>().
    template <typename T>
    bool is() const {
        if constexpr (std::is_same_v<T, T1>) {
            return is<1>();
        } else if constexpr (std::is_same_v<T, T2>) {
            return is<2>();
        } else {
            static_assert(std::is_same_v<T, T1> || std::is_same_v<T, T2>,
                          "T must be either T1 or T2!");
        }
    }

    // Return a copy of the I-th alternative.
    // Throws BadVariantAccess if it is not the active one.
    template <int I>
    auto get() const {
        if (m_index != I) {
            throw BadVariantAccess();
        }
        if constexpr (I == 1) {
            return *reinterpret_cast<const T1*>(m_union);
        } else if constexpr (I == 2) {
            return *reinterpret_cast<const T2*>(m_union);
        } else {
            // `I != I` is dependent, so it only fires for a bad instantiation.
            static_assert(I != I, "I out of range!");
        }
    }

    // Type-based access: get<T1>() == get<1>(), get<T2>() == get<2>().
    template <typename T>
    auto get() const {
        if constexpr (std::is_same_v<T, T1>) {
            return get<1>();
        } else if constexpr (std::is_same_v<T, T2>) {
            return get<2>();
        } else {
            static_assert(std::is_same_v<T, T1> || std::is_same_v<T, T2>,
                          "T must be either T1 or T2!");
        }
    }
};

// 打印函数
void print(const Variant<std::string, int>& v) {
    
    
    if (v.is<1>()) {
    
    
        std::cout << v.get<1>() << '\n';
    } else if (v.is<2>()) {
    
    
        std::cout << v.get<2>() << '\n';
    }
}

// 主函数
int main() {
    
    
    try {
    
    
        Variant<std::string, int> v("Hello");
        print(v);

        v = Variant<std::string, int>(42);
        print(v);
    } catch (const BadVariantAccess& e) {
    
    
        std::cerr << "Error: " << e.what() << std::endl;
    }

    return 0;
}

std::variant_alternative

标准库variant_alternative可以给定index查询具体的类型

#include <iostream>
#include <string>
#include <type_traits>
#include <typeinfo>
#include <variant>

// Demo: std::variant_alternative<I, V>::type (or the _t alias) names the
// I-th alternative type of a variant.
int main() {
    using MyVariant = std::variant<int, double, std::string>;

    using T0 = std::variant_alternative<0, MyVariant>::type;  // long form
    using T1 = std::variant_alternative_t<1, MyVariant>;      // _t alias
    using T2 = std::variant_alternative_t<2, MyVariant>;

    static_assert(std::is_same_v<T0, int>);
    static_assert(std::is_same_v<T1, double>);
    static_assert(std::is_same_v<T2, std::string>);

    std::cout << "Type at index 0: " << typeid(T0).name() << '\n';
    std::cout << "Type at index 1: " << typeid(T1).name() << '\n';
    std::cout << "Type at index 2: " << typeid(T2).name() << '\n';

    return 0;
}

自己实现 VariantIndex(由类型查下标,与 variant_alternative 由下标查类型的方向正好相反)

// VariantIndex<std::variant<Ts...>, T>::value is the 0-based position of T
// among the alternatives Ts.
//
// Primary template: intentionally left undefined, so asking for a type that
// is not one of the alternatives fails to compile.
template <typename Variant, typename T>
struct VariantIndex;

// T is the first alternative, so its index is 0. This is more specialized
// than the recursive case below, so it wins whenever both match.
template <typename T, typename... Rest>
struct VariantIndex<std::variant<T, Rest...>, T> {
    static constexpr int value = 0;
};

// T is not first: drop the first alternative and add 1 to the index of T in
// the remaining list. This handles any number of alternatives; the former
// two-alternative-only version is the special case std::variant<T1, T2>.
template <typename First, typename T, typename... Rest>
struct VariantIndex<std::variant<First, Rest...>, T> {
    static constexpr int value = 1 + VariantIndex<std::variant<Rest...>, T>::value;
};

支持任意数量类型的variant

#pragma once

#include <algorithm>
#include <type_traits>
#include <functional>

// 今天来实现标准库中的 variant 和 visit

// Tag type that selects the alternative at position I when constructing a
// Variant in place (the counterpart of std::in_place_index_t<I>).
// The default constructor is explicit, so a bare `{}` cannot silently be
// copy-list-initialized into this tag.
template <size_t I>
struct InPlaceIndex {
    explicit InPlaceIndex() = default;
};

// Ready-made tag object, the counterpart of std::in_place_index<I>.
// A variable template is already a variable: you only supply the template
// argument, with no extra construction:
//     Variant<std::string, int, double> v1(inPlaceIndex<0>, "asas");
// is equivalent to
//     Variant<std::string, int, double> v1(InPlaceIndex<0>{}, "asas");
// The explicit `{}` initializer (as in the standard library's
// `inline constexpr in_place_index_t<I> in_place_index{};`) guarantees the
// constexpr variable is initialized. `inline` gives it a single definition
// when this header is included from several translation units.
template <size_t I>
inline constexpr InPlaceIndex<I> inPlaceIndex{};

// Exception thrown by Variant::get when the requested alternative is not the
// one currently stored (the counterpart of std::bad_variant_access).
struct BadVariantAccess : std::exception {
    BadVariantAccess() = default;
    ~BadVariantAccess() override = default;

    // Fixed message; never throws.
    const char *what() const noexcept override { return "BadVariantAccess"; }
};
// Type -> index: VariantIndex<V, T>::value is the 0-based position of T in V.
template <typename VariantT, typename Alternative>
struct VariantIndex;

// Index -> type: VariantAlternative<V, I>::type is the I-th (0-based) type in V.
template <typename VariantT, size_t Index>
struct VariantAlternative;

// VariantAlternative<Variant<int, double>, 0>::type = int;
// VariantAlternative<Variant<int, double>, 1>::type = double;
// VariantIndex<Variant<int, double>, int>::value = 0;
// VariantIndex<Variant<int, double>, double>::value = 1;

template <typename ...Ts>
struct Variant {
    
    
private:
    size_t m_index;
	
	//max的参数支持initialize list,所以加上个花括号
	//优化对齐,有些c语言要求要按照最大类型对齐,这里对齐到8字节
    alignas(std::max({
    
    alignof(Ts)...})) char m_union[std::max({
    
    sizeof(Ts)...})];

    using DestructorFunction = void(*)(char *) noexcept;

	/*
	等价于
	static DestructorFunction function_ptrs[2] = {
            [] (char *union_p) noexcept {
                reinterpret_cast<T0 *>(union_p)->~T0();
            },
			[] (char *union_p) noexcept {
                reinterpret_cast<T1 *>(union_p)->~T1();
            },
        };
	
	*/
	//lambda可以隐式转化为函数指针
    static DestructorFunction *destructors_table() noexcept {
    
    
        static DestructorFunction function_ptrs[sizeof...(Ts)] = {
    
    
            [] (char *union_p) noexcept {
    
    
                reinterpret_cast<Ts *>(union_p)->~Ts();
            }...
        };
        return function_ptrs;
    }

    using CopyConstructorFunction = void(*)(char *, char const *) noexcept;

    static CopyConstructorFunction *copy_constructors_table() noexcept {
    
    
        static CopyConstructorFunction function_ptrs[sizeof...(Ts)] = {
    
    
            [] (char *union_dst, char const *union_src) noexcept {
    
    
                new (union_dst) Ts(*reinterpret_cast<Ts const *>(union_src));
            }...
        };
        return function_ptrs;
    }

    using CopyAssignmentFunction = void(*)(char *, char const *) noexcept;

    static CopyAssignmentFunction *copy_assigment_functions_table() noexcept {
    
    
        static CopyAssignmentFunction function_ptrs[sizeof...(Ts)] = {
    
    
            [] (char *union_dst, char const *union_src) noexcept {
    
    
                *reinterpret_cast<Ts *>(union_dst) = *reinterpret_cast<Ts const *>(union_src);
            }...
        };
        return function_ptrs;
    }

    using MoveConstructorFunction = void(*)(char *, char *) noexcept;

    static MoveConstructorFunction *move_constructors_table() noexcept {
    
    
        static MoveConstructorFunction function_ptrs[sizeof...(Ts)] = {
    
    
            [] (char *union_dst, char *union_src) noexcept {
    
    
                new (union_dst) Ts(std::move(*reinterpret_cast<Ts const *>(union_src)));
            }...
        };
        return function_ptrs;
    }

    using MoveAssignmentFunction = void(*)(char *, char *) noexcept;

    static MoveAssignmentFunction *move_assigment_functions_table() noexcept {
    
    
        static MoveAssignmentFunction function_ptrs[sizeof...(Ts)] = {
    
    
            [] (char *union_dst, char *union_src) noexcept {
    
    
                *reinterpret_cast<Ts *>(union_dst) = std::move(*reinterpret_cast<Ts *>(union_src));
            }...
        };
        return function_ptrs;
    }

    template <class Lambda>
    using ConstVisitorFunction = std::common_type<typename std::invoke_result<Lambda, Ts const &>::type...>::type(*)(char const *, Lambda &&);

    template <class Lambda>
    static ConstVisitorFunction<Lambda> *const_visitors_table() noexcept {
    
    
        static ConstVisitorFunction<Lambda> function_ptrs[sizeof...(Ts)] = {
    
    
            [] (char const *union_p, Lambda &&lambda) -> typename std::invoke_result<Lambda, Ts const &>::type {
    
    
                std::invoke(std::forward<Lambda>(lambda),
                            *reinterpret_cast<Ts const *>(union_p));
            }...
        };
        return function_ptrs;
    }

    template <class Lambda>
    using VisitorFunction = std::common_type<typename std::invoke_result<Lambda, Ts &>::type...>::type(*)(char *, Lambda &&);

    template <class Lambda>
    static VisitorFunction<Lambda> *visitors_table() noexcept {
    
    
        static VisitorFunction<Lambda> function_ptrs[sizeof...(Ts)] = {
    
    
			//提取返回类型,cpp14可以用auto,cpp11只能这么写
            [] (char *union_p, Lambda &&lambda) -> std::common_type<typename std::invoke_result<Lambda, Ts &>::type...>::type {
    
    
                return std::invoke(std::forward<Lambda>(lambda),
                                   *reinterpret_cast<Ts *>(union_p));
            }...
        };
        return function_ptrs;
    }

public:
	/*
		T能转成Ts的任意类型就行
		template <typename T>
			retuires (std::is_convertable<T,Ts>||...)
		Variant(T value) : m_index(VariantIndex<Variant, T>::value) {
        T *p = reinterpret_cast<T *>(m_union);
        new (p) T(value);
    }
	C++14语法使用std::disjunction
	
	在模板类中定义某些函数,可以不加模板参数,因为cpp有个类型名称注入,就在类体内,不用写。不用在构造函数里面写下面的东西
	Variant<Ts...>(T value)
	*/
    template <typename T, typename std::enable_if<
        std::disjunction<std::is_same<T, Ts>...>::value,
        int>::type = 0>
    Variant(T value) : m_index(VariantIndex<Variant, T>::value) {
    
    
        T *p = reinterpret_cast<T *>(m_union);
        new (p) T(value);
    }

    Variant(Variant const &that) : m_index(that.m_index) {
    
    
        copy_constructors_table()[index()](m_union, that.m_union);
    }

    Variant &operator=(Variant const &that) {
    
    
        m_index = that.m_index;
        copy_assigment_functions_table()[index()](m_union, that.m_union);
    }

    Variant(Variant &&that) : m_index(that.m_index) {
    
    
        move_constructors_table()[index()](m_union, that.m_union);
    }

    Variant &operator=(Variant &&that) {
    
    
        m_index = that.m_index;
        move_assigment_functions_table()[index()](m_union, that.m_union);
    }
	
	//构造函数重载,只是为了重载而重载
    template <size_t I, typename ...Args>
    explicit Variant(InPlaceIndex<I>, Args &&...value_args) : m_index(I) {
    
    
        new (m_union) typename VariantAlternative<Variant, I>::type
            (std::forward<Args>(value_args)...);
    }
	
	//定义了析构函数,要把拷贝删除
	Variant(Variant const &= delete;
	Variant &operator=(Variant const&) = delete;
	
	//定义析构函数
	//析构函数永远不会抛出异常,加上noexcept
    ~Variant() noexcept {
    
    
        destructors_table()[index()](m_union);
    }

	//注意visit有带const和不带const的
	//const函数,返回的是const的指针或者引用
    template <class Lambda>
	//std::common_type<typename std::invoke_result<Lambda, Ts &>::type...>::type它的作用是 推导出 Lambda 作用于所有 Ts 类型的返回值的公共类型
    std::common_type<typename std::invoke_result<Lambda, Ts &>::type...>::type visit(Lambda &&lambda) {
    
    
        // 由于时间原因,暂时没有实现支持多参数的std::visit,决定留作回家作业,供学有余力的同学自己尝试实现
        return visitors_table<Lambda>()[index()](m_union, std::forward<Lambda>(lambda));
    }

    template <class Lambda>
    std::common_type<typename std::invoke_result<Lambda, Ts const &>::type...>::type visit(Lambda &&lambda) const {
    
    
        return const_visitors_table<Lambda>()[index()](m_union, std::forward<Lambda>(lambda));
    }

    constexpr size_t index() const noexcept {
    
    
        return m_index;
    }
	
	//将支持下标和类型的改造成类似标准库的holds_alternative
	//T在variant的下标是不是我保存的下标
    template <typename T>
    constexpr bool holds_alternative() const noexcept {
    
    
        return VariantIndex<Variant, T>::value == index();
    }

    template <size_t I>
    typename VariantAlternative<Variant, I>::type &get() {
    
    
        static_assert(I < sizeof...(Ts), "I out of range!");
        if (m_index != I)
            throw BadVariantAccess();
        return *reinterpret_cast<typename VariantAlternative<Variant, I>::type *>(m_union);
    }

    template <typename T>
    T &get() {
    
    
        return get<VariantIndex<Variant, T>::value>();
    }

    template <size_t I>
    typename VariantAlternative<Variant, I>::type const &get() const {
    
    
        static_assert(I < sizeof...(Ts), "I out of range!");
        if (m_index != I)
            throw BadVariantAccess();
        return *reinterpret_cast<typename VariantAlternative<Variant, I>::type const *>(m_union);
    }

    template <typename T>
    T const &get() const {
    
    
        return get<VariantIndex<Variant, T>::value>();
    }

	//新增get_if
    template <size_t I>
    typename VariantAlternative<Variant, I>::type *get_if() {
    
    
        static_assert(I < sizeof...(Ts), "I out of range!");
        if (m_index != I)
            return nullptr;
        return reinterpret_cast<typename VariantAlternative<Variant, I>::type *>(m_union);
    }

    template <typename T>
    T *get_if() {
    
    
        return get_if<VariantIndex<Variant, T>::value>();
    }

    template <size_t I>
    typename VariantAlternative<Variant, I>::type const *get_if() const {
    
    
        static_assert(I < sizeof...(Ts), "I out of range!");
        if (m_index != I)
            return nullptr;
        return reinterpret_cast<typename VariantAlternative<Variant, I>::type const *>(m_union);
    }

    template <typename T>
    T const *get_if() const {
    
    
        return get_if<VariantIndex<Variant, T>::value>();
    }
};

// Index -> type. VariantAlternative<Variant<Ts...>, I>::type is the I-th
// (0-based) type in Ts. The result is exposed as a member alias, since a
// type can only be "returned" through `using`/typedef.

// Base case: position 0 is the head of the list.
template <typename Head, typename ...Tail>
struct VariantAlternative<Variant<Head, Tail...>, 0> {
    using type = Head;
};

// Recursive case: drop the head and look up position I - 1 in the tail.
template <size_t I, typename Head, typename ...Tail>
struct VariantAlternative<Variant<Head, Tail...>, I> {
    using type = typename VariantAlternative<Variant<Tail...>, I - 1>::type;
};

// Type -> index for any number of alternatives.
// If T is the first type, its index is 0. Otherwise strip the first type and
// add 1 to the index of T in the remaining list. A T that does not occur ends
// at the undefined primary template, i.e. a compile error.
template <typename T, typename ...Rest>
struct VariantIndex<Variant<T, Rest...>, T> {
    static constexpr size_t value = 0;
};

template <typename Head, typename T, typename ...Rest>
struct VariantIndex<Variant<Head, Rest...>, T> {
    static constexpr size_t value = 1 + VariantIndex<Variant<Rest...>, T>::value;
};

(原文此处有一张示意图,图片已缺失)

std::common_type<typename std::invoke_result<Lambda, Ts &>::type...>::type

它的作用是 推导出 Lambda 作用于所有 Ts 类型的返回值的公共类型。

std::invoke_result<Lambda, Ts &>::type
std::invoke_result<F, Args...>::type 用于获取以 Args... 为参数调用 F 之后的返回值类型。
Lambda 是一个可调用对象(比如 lambda 函数)。
Ts… 是 Variant 的所有可能类型。
Ts & 代表 Lambda 会以 Ts 类型的左值引用作为参数进行调用。

typename std::invoke_result<Lambda, Ts &>::type...
会展开为:
typename std::invoke_result<Lambda, T1 &>::type,
typename std::invoke_result<Lambda, T2 &>::type,
typename std::invoke_result<Lambda, T3 &>::type,
...

即,Lambda 作用于 T1 &, T2 &, T3 &, … 后得到的所有返回值类型。

std::common_type<T1, T2, T3, …>::type

std::common_type<T1, T2, …>::type 计算出 所有 T1, T2, T3, … 的公共类型。

std::common_type<int, double>::type  // → double
std::common_type<float, double>::type // → double
std::common_type<int, float, double>::type // → double

在 visit 函数中,std::invoke_result 要求 Lambda 能以每一种 Ts 为参数调用(否则编译失败);std::common_type 再把这些返回类型合并成一个公共类型,作为 visit 的返回类型(若不存在公共类型,同样编译失败)。

我们用 Variant<int, double, std::string> 举例:

#include <iostream>
#include <type_traits>
#include <string>

// Empty stand-in for Variant; only its name matters in this example.
template <typename... Ts>
struct Variant {};

// Demo: compute the common return type of a lambda over int& and double&.
int main() {
    // Generic lambda whose return type follows its argument:
    // int& -> int, double& -> double.
    auto add_one = [](auto &val) -> decltype(val + 1) { return val + 1; };
    using Lambda = decltype(add_one);

    using ResultType = std::common_type_t<
        std::invoke_result_t<Lambda, int &>,
        std::invoke_result_t<Lambda, double &>>;

    // Prints 1: the common type of int and double is double.
    std::cout << std::is_same<ResultType, double>::value << std::endl;

    return 0;
}

参考