文章目录
标准库 std::variant 用法(本例用 index() 判断类型、再用 std::get 取值,手动分派;std::visit 可以自动完成这种分派)
#include <iostream>
#include <string>
#include <variant>
// Prints the value currently held by v (either an int or a std::string),
// followed by a newline. Prints nothing if v holds neither alternative.
void print(const std::variant<int, std::string>& v) {
    // std::get_if returns a pointer to the stored value when v holds the
    // requested type, and nullptr otherwise.
    if (const int* number = std::get_if<int>(&v)) {
        std::cout << *number << '\n';
    } else if (const std::string* text = std::get_if<std::string>(&v)) {
        std::cout << *text << '\n';
    }
}
// Demo: store an int, print it, then replace it with a std::string and print.
int main() {
    std::variant<int, std::string> v;
    v.emplace<int>(42);              // now holds an int
    print(v);
    v.emplace<std::string>("3333");  // now holds a std::string
    print(v);
    return 0;
}
实现一个简单的variant
封装一个简单的 variant:用 type 字段记录 union 中当前存放的是哪种类型,以避免“类型是 int,赋的值却是 double”这种不一致。注意这一版仍需手动保证 type 与实际写入的成员一致,下一版用构造函数来保证。
#include <iostream>
// Tag recording which member of Variant's union currently holds a value.
enum VariantType {
    Int, Double };
// Minimal hand-written tagged union: `type` says whether `i` or `d` is the
// live member. The caller must keep `type` in sync with the member written.
struct Variant {
    // Default to "int 0" so a default-constructed Variant is never read while
    // its tag or value is still indeterminate.
    VariantType type = Int;
    union {
        int i = 0;  // only one union member may carry a default initializer
        double d;
    };
};
// Prints the member selected by v.type, followed by a newline.
// Nothing is printed for a tag value other than Int or Double.
void print(const Variant& v) {
    switch (v.type) {
    case Int:
        std::cout << v.i << "\n";
        break;
    case Double:
        std::cout << v.d << "\n";
        break;
    }
}
// Demo: the tag and the union member must be updated together by hand.
int main() {
    Variant v;
    v.i = 42;        // store an integer...
    v.type = Int;    // ...and record that the int member is live
    print(v);
    v.d = 3.14;      // store a double...
    v.type = Double; // ...and switch the tag accordingly
    print(v);
    return 0;
}
进一步使用构造函数区分不同的类型
#include <iostream>
#include <stdexcept>
// Tag naming which union member of Variant is active.
enum VariantType { Int = 0, Double = 1 };
// Thrown by Variant::get_int()/get_double() when the requested type is not
// the one currently stored.
class BadVariantAccess : public std::exception {
public:
    BadVariantAccess() = default;
    ~BadVariantAccess() override = default;
    const char *what() const noexcept override { return "BadVariantAccess"; }
};
// Tagged union of int/double whose constructors set the tag automatically,
// so the tag can no longer disagree with the stored member.
class Variant {
private:
    VariantType m_type;  // which union member below is live
    union {
        int m_i;
        double m_d;
    };

public:
    // Constructors: the argument type picks both the tag and the member.
    Variant(int i) : m_type(Int), m_i(i) {}
    Variant(double d) : m_type(Double), m_d(d) {}

    // Type queries.
    bool is_int() const { return m_type == Int; }
    bool is_double() const { return m_type == Double; }

    // Checked accessors: throw BadVariantAccess on a type mismatch.
    int get_int() const {
        if (!is_int()) {
            throw BadVariantAccess();
        }
        return m_i;
    }
    double get_double() const {
        if (!is_double()) {
            throw BadVariantAccess();
        }
        return m_d;
    }

    // Prints the stored value followed by a newline.
    void print() const {
        switch (m_type) {
        case Int:
            std::cout << get_int() << "\n";
            break;
        case Double:
            std::cout << get_double() << "\n";
            break;
        }
    }
};
// Demo: construct from an int, then reassign from a double.
int main() {
    try {
        Variant v{42};       // holds an int
        v.print();           // prints: 42
        v = Variant{3.14};   // now holds a double
        v.print();           // prints: 3.14
    } catch (const BadVariantAccess& e) {
        std::cerr << "Error: " << e.what() << std::endl;
    }
    return 0;
}
支持两个任意类型的variant
#include <iostream>
#include <stdexcept>
#include <string>
#include <type_traits>
#include <utility>
// 自定义异常类
// Exception thrown by Variant::get when the requested alternative is not the
// one currently stored.
struct BadVariantAccess : public std::exception {
    const char* what() const noexcept override {
        return "BadVariantAccess"; }
};
// Template Variant holding exactly one value of type T1 or T2.
template <typename T1, typename T2>
class Variant {
private:
    // Active alternative: 1 means a T1 is stored, 2 means a T2 is stored.
    int m_index;
    // Raw storage for the active alternative: sized for the larger type and
    // aligned for the stricter one. A bare char array only guarantees 1-byte
    // alignment, so placement-new of e.g. a double or std::string into it
    // could create a misaligned object. (A ternary replaces std::max so the
    // bound does not depend on <algorithm>.)
    alignas(T1) alignas(T2) char m_union[sizeof(T1) > sizeof(T2) ? sizeof(T1) : sizeof(T2)];

    // Runs the destructor of whichever alternative m_index says is stored.
    void destroy() noexcept {
        if (m_index == 1) {
            reinterpret_cast<T1*>(m_union)->~T1();
        } else if (m_index == 2) {
            reinterpret_cast<T2*>(m_union)->~T2();
        }
    }

    // Move-constructs other's stored value into this object's storage.
    // Precondition: m_index already equals other.m_index and this storage
    // currently holds no live object.
    void move_construct_from(Variant& other) noexcept {
        if (m_index == 1) {
            new (m_union) T1(std::move(*reinterpret_cast<T1*>(other.m_union)));
        } else if (m_index == 2) {
            new (m_union) T2(std::move(*reinterpret_cast<T2*>(other.m_union)));
        }
    }

public:
    // Construct holding a T1. The memory is already reserved (m_union), so
    // placement new only runs T1's constructor at that address.
    Variant(T1 value) : m_index(1) {
        new (m_union) T1(std::move(value));
    }
    // Construct holding a T2.
    Variant(T2 value) : m_index(2) {
        new (m_union) T2(std::move(value));
    }
    // Destroys the stored alternative.
    ~Variant() {
        destroy();
    }
    // Move constructor: moves other's value into this object's storage.
    // `other` keeps its index and its (moved-from) value is destroyed later.
    Variant(Variant&& other) noexcept : m_index(other.m_index) {
        move_construct_from(other);
    }
    // Move assignment: destroy our current value, then move-construct
    // other's value in its place. Only the stored alternative is rebuilt;
    // the lifetime of *this itself is never ended.
    Variant& operator=(Variant&& other) noexcept {
        if (this != &other) {
            destroy();
            m_index = other.m_index;
            move_construct_from(other);
        }
        return *this;
    }
    // Index-based type check: is<1>() / is<2>().
    template <int I>
    bool is() const {
        return m_index == I;
    }
    // Type-based type check: is<T1>() / is<T2>(). A template taking a type
    // parameter overloads with the one taking an int parameter above.
    template <typename T>
    bool is() const {
        if constexpr (std::is_same_v<T, T1>) {
            return is<1>();
        } else if constexpr (std::is_same_v<T, T2>) {
            return is<2>();
        } else {
            static_assert(std::is_same_v<T, T1> || std::is_same_v<T, T2>,
                          "T must be either T1 or T2!");
        }
    }
    // Returns a copy of the stored value for index I (1 or 2).
    // Throws BadVariantAccess if alternative I is not the active one.
    template <int I>
    auto get() const {
        if (m_index != I) {
            throw BadVariantAccess();
        }
        if constexpr (I == 1) {
            return *reinterpret_cast<const T1*>(m_union);
        } else if constexpr (I == 2) {
            return *reinterpret_cast<const T2*>(m_union);
        } else {
            static_assert(I != I, "I out of range!");
        }
    }
    // Returns a copy of the stored value of type T (T1 or T2).
    template <typename T>
    auto get() const {
        if constexpr (std::is_same_v<T, T1>) {
            return get<1>();
        } else if constexpr (std::is_same_v<T, T2>) {
            return get<2>();
        } else {
            static_assert(std::is_same_v<T, T1> || std::is_same_v<T, T2>,
                          "T must be either T1 or T2!");
        }
    }
};
// 打印函数
void print(const Variant<std::string, int>& v) {
if (v.is<1>()) {
std::cout << v.get<1>() << '\n';
} else if (v.is<2>()) {
std::cout << v.get<2>() << '\n';
}
}
// 主函数
int main() {
try {
Variant<std::string, int> v("Hello");
print(v);
v = Variant<std::string, int>(42);
print(v);
} catch (const BadVariantAccess& e) {
std::cerr << "Error: " << e.what() << std::endl;
}
return 0;
}
std::variant_alternative
标准库variant_alternative可以给定index查询具体的类型
#include <iostream>
#include <string>      // std::string is one of the alternatives below
#include <type_traits>
#include <typeinfo>    // required before using typeid
#include <variant>
// Demo of std::variant_alternative: look up the type stored at a given index.
// Note: typeid(T).name() is implementation-defined (GCC/Clang print mangled
// names such as "i" or "d").
int main() {
    using MyVariant = std::variant<int, double, std::string>;
    // Type at index 0 (int).
    using T0 = std::variant_alternative<0, MyVariant>::type;
    static_assert(std::is_same_v<T0, int>);
    std::cout << "Type at index 0: " << typeid(T0).name() << '\n';
    // Type at index 1 (double), via the _t alias.
    using T1 = std::variant_alternative_t<1, MyVariant>;
    static_assert(std::is_same_v<T1, double>);
    std::cout << "Type at index 1: " << typeid(T1).name() << '\n';
    // Type at index 2 (std::string).
    using T2 = std::variant_alternative_t<2, MyVariant>;
    static_assert(std::is_same_v<T2, std::string>);
    std::cout << "Type at index 2: " << typeid(T2).name() << '\n';
    return 0;
}
自己实现 VariantIndex(由类型查下标,是 variant_alternative“由下标查类型”的反向操作)
// Primary template: maps (variant type, alternative type) to a zero-based
// index. It is declared but never defined, so asking for a type that no
// specialization matches is a compile error.
template <typename V, typename Alt>
struct VariantIndex;

// Partial specialization: the first alternative of std::variant<First, Second>
// sits at index 0.
template <typename First, typename Second>
struct VariantIndex<std::variant<First, Second>, First> {
    static constexpr int value{0};
};

// Partial specialization: the second alternative sits at index 1.
// NOTE: when First and Second are the same type, both specializations match
// equally well and the lookup is ambiguous (compile error).
template <typename First, typename Second>
struct VariantIndex<std::variant<First, Second>, Second> {
    static constexpr int value{1};
};
支持任意数量类型的variant
#pragma once
#include <algorithm>
#include <type_traits>
#include <functional>
// 今天来实现标准库中的 variant 和 visit
// Tag type selecting, by index, which alternative a Variant constructs in
// place. Plays the same role as std::in_place_index_t<I>.
template <size_t I>
struct InPlaceIndex {
    explicit InPlaceIndex() = default;
};
// Ready-made tag object (a variable template), like std::in_place_index<I>.
// Because it is already an object, callers only supply the template argument:
//   Variant<std::string, int, double> v1(inPlaceIndex<0>, "asas");
// is equivalent to
//   Variant<std::string, int, double> v1(InPlaceIndex<0>{}, "asas");
// `inline` gives it one definition even if this header is included from
// several translation units.
template <size_t I>
inline constexpr InPlaceIndex<I> inPlaceIndex{};

// Thrown by Variant::get when the requested alternative is not the active one.
struct BadVariantAccess : std::exception {
    BadVariantAccess() = default;
    virtual ~BadVariantAccess() = default;
    const char *what() const noexcept override {
        return "BadVariantAccess";
    }
};

// Type -> index (zero-based):
//   VariantIndex<Variant<int, double>, int>::value    == 0
//   VariantIndex<Variant<int, double>, double>::value == 1
template <typename, typename>
struct VariantIndex;

// Index -> type (zero-based):
//   VariantAlternative<Variant<int, double>, 0>::type is int
//   VariantAlternative<Variant<int, double>, 1>::type is double
template <typename, size_t>
struct VariantAlternative;

// Type-safe union holding exactly one value whose type is one of Ts...
// Per-alternative operations (destroy, copy, move, visit) are dispatched at
// run time through static tables of function pointers indexed by m_index.
template <typename ...Ts>
struct Variant {
private:
    // Zero-based index into Ts... of the alternative currently stored.
    size_t m_index;
    // Raw storage big enough for the largest alternative and aligned for the
    // most strictly aligned one (std::max accepts an initializer_list).
    alignas(std::max({alignof(Ts)...})) char m_union[std::max({sizeof(Ts)...})];

    // Result type of visit(): the common type of the visitor's return types
    // over every alternative, for mutable and const access respectively.
    // `typename` is required because these names depend on template params.
    template <class Lambda>
    using VisitResult = typename std::common_type<
        typename std::invoke_result<Lambda, Ts &>::type...>::type;
    template <class Lambda>
    using ConstVisitResult = typename std::common_type<
        typename std::invoke_result<Lambda, Ts const &>::type...>::type;

    // Each *_table() returns an array with one entry per alternative; entry i
    // performs the operation for the i-th type of Ts... Captureless lambdas
    // convert implicitly to function pointers, and the `...` after a lambda
    // expands it once per type. For Variant<T0, T1> the destructor table is:
    //   { [](char *p) noexcept { reinterpret_cast<T0 *>(p)->~T0(); },
    //     [](char *p) noexcept { reinterpret_cast<T1 *>(p)->~T1(); } }
    using DestructorFunction = void(*)(char *) noexcept;
    static DestructorFunction *destructors_table() noexcept {
        static DestructorFunction function_ptrs[sizeof...(Ts)] = {
            [] (char *union_p) noexcept {
                reinterpret_cast<Ts *>(union_p)->~Ts();
            }...
        };
        return function_ptrs;
    }

    // NOTE: the copy/move entries are noexcept, so an alternative whose copy
    // or move throws causes std::terminate rather than an exception.
    using CopyConstructorFunction = void(*)(char *, char const *) noexcept;
    static CopyConstructorFunction *copy_constructors_table() noexcept {
        static CopyConstructorFunction function_ptrs[sizeof...(Ts)] = {
            [] (char *union_dst, char const *union_src) noexcept {
                new (union_dst) Ts(*reinterpret_cast<Ts const *>(union_src));
            }...
        };
        return function_ptrs;
    }

    using CopyAssignmentFunction = void(*)(char *, char const *) noexcept;
    static CopyAssignmentFunction *copy_assignment_functions_table() noexcept {
        static CopyAssignmentFunction function_ptrs[sizeof...(Ts)] = {
            [] (char *union_dst, char const *union_src) noexcept {
                *reinterpret_cast<Ts *>(union_dst) = *reinterpret_cast<Ts const *>(union_src);
            }...
        };
        return function_ptrs;
    }

    using MoveConstructorFunction = void(*)(char *, char *) noexcept;
    static MoveConstructorFunction *move_constructors_table() noexcept {
        static MoveConstructorFunction function_ptrs[sizeof...(Ts)] = {
            [] (char *union_dst, char *union_src) noexcept {
                // Cast to non-const Ts* so std::move yields Ts&& and the move
                // constructor (not the copy constructor) is selected.
                new (union_dst) Ts(std::move(*reinterpret_cast<Ts *>(union_src)));
            }...
        };
        return function_ptrs;
    }

    using MoveAssignmentFunction = void(*)(char *, char *) noexcept;
    static MoveAssignmentFunction *move_assignment_functions_table() noexcept {
        static MoveAssignmentFunction function_ptrs[sizeof...(Ts)] = {
            [] (char *union_dst, char *union_src) noexcept {
                *reinterpret_cast<Ts *>(union_dst) = std::move(*reinterpret_cast<Ts *>(union_src));
            }...
        };
        return function_ptrs;
    }

    template <class Lambda>
    using ConstVisitorFunction = ConstVisitResult<Lambda> (*)(char const *, Lambda &&);
    template <class Lambda>
    static ConstVisitorFunction<Lambda> *const_visitors_table() noexcept {
        static ConstVisitorFunction<Lambda> function_ptrs[sizeof...(Ts)] = {
            // The declared return type must match the function-pointer type,
            // and the visitor's result must actually be returned.
            [] (char const *union_p, Lambda &&lambda) -> ConstVisitResult<Lambda> {
                return std::invoke(std::forward<Lambda>(lambda),
                                   *reinterpret_cast<Ts const *>(union_p));
            }...
        };
        return function_ptrs;
    }

    template <class Lambda>
    using VisitorFunction = VisitResult<Lambda> (*)(char *, Lambda &&);
    template <class Lambda>
    static VisitorFunction<Lambda> *visitors_table() noexcept {
        static VisitorFunction<Lambda> function_ptrs[sizeof...(Ts)] = {
            [] (char *union_p, Lambda &&lambda) -> VisitResult<Lambda> {
                return std::invoke(std::forward<Lambda>(lambda),
                                   *reinterpret_cast<Ts *>(union_p));
            }...
        };
        return function_ptrs;
    }

public:
    // Converting constructor: accepts a value whose type is exactly one of
    // Ts... (std::disjunction is C++17; C++20 could use a requires-clause).
    // Inside the class body the injected name `Variant` means Variant<Ts...>.
    template <typename T, typename std::enable_if<
        std::disjunction<std::is_same<T, Ts>...>::value,
        int>::type = 0>
    Variant(T value) : m_index(VariantIndex<Variant, T>::value) {
        new (m_union) T(std::move(value));
    }

    // Copy constructor: copy-construct that's active alternative.
    Variant(Variant const &that) : m_index(that.m_index) {
        copy_constructors_table()[index()](m_union, that.m_union);
    }

    // Copy assignment. If both sides hold the same alternative, assign the
    // value; otherwise destroy ours and copy-construct theirs, because the
    // storage must never be assigned as a type it does not currently hold.
    Variant &operator=(Variant const &that) {
        if (this == &that) {
            return *this;
        }
        if (m_index == that.m_index) {
            copy_assignment_functions_table()[index()](m_union, that.m_union);
        } else {
            destructors_table()[index()](m_union);
            copy_constructors_table()[that.index()](m_union, that.m_union);
            m_index = that.m_index;
        }
        return *this;
    }

    // Move constructor: move-construct that's active alternative. `that`
    // keeps its index; its moved-from value is destroyed by its destructor.
    Variant(Variant &&that) noexcept : m_index(that.m_index) {
        move_constructors_table()[index()](m_union, that.m_union);
    }

    // Move assignment, same strategy as copy assignment.
    Variant &operator=(Variant &&that) noexcept {
        if (this == &that) {
            return *this;
        }
        if (m_index == that.m_index) {
            move_assignment_functions_table()[index()](m_union, that.m_union);
        } else {
            destructors_table()[index()](m_union);
            move_constructors_table()[that.index()](m_union, that.m_union);
            m_index = that.m_index;
        }
        return *this;
    }

    // In-place constructor selected by an InPlaceIndex<I> tag: constructs
    // alternative I directly from value_args.
    template <size_t I, typename ...Args>
    explicit Variant(InPlaceIndex<I>, Args &&...value_args) : m_index(I) {
        new (m_union) typename VariantAlternative<Variant, I>::type
            (std::forward<Args>(value_args)...);
    }

    // Destroys the active alternative.
    ~Variant() noexcept {
        destructors_table()[index()](m_union);
    }

    // Calls lambda with a reference to the stored value and returns its
    // result converted to the common result type over all alternatives.
    // Only single-variant visitation is implemented.
    template <class Lambda>
    VisitResult<Lambda> visit(Lambda &&lambda) {
        return visitors_table<Lambda>()[index()](m_union, std::forward<Lambda>(lambda));
    }
    // Const overload: the visitor receives a const reference.
    template <class Lambda>
    ConstVisitResult<Lambda> visit(Lambda &&lambda) const {
        return const_visitors_table<Lambda>()[index()](m_union, std::forward<Lambda>(lambda));
    }

    // Zero-based index of the active alternative.
    constexpr size_t index() const noexcept {
        return m_index;
    }

    // Like std::holds_alternative: true if T's index equals the stored index.
    template <typename T>
    constexpr bool holds_alternative() const noexcept {
        return VariantIndex<Variant, T>::value == index();
    }

    // Checked access by index; throws BadVariantAccess on mismatch.
    template <size_t I>
    typename VariantAlternative<Variant, I>::type &get() {
        static_assert(I < sizeof...(Ts), "I out of range!");
        if (m_index != I)
            throw BadVariantAccess();
        return *reinterpret_cast<typename VariantAlternative<Variant, I>::type *>(m_union);
    }
    // Checked access by type.
    template <typename T>
    T &get() {
        return get<VariantIndex<Variant, T>::value>();
    }
    template <size_t I>
    typename VariantAlternative<Variant, I>::type const &get() const {
        static_assert(I < sizeof...(Ts), "I out of range!");
        if (m_index != I)
            throw BadVariantAccess();
        return *reinterpret_cast<typename VariantAlternative<Variant, I>::type const *>(m_union);
    }
    template <typename T>
    T const &get() const {
        return get<VariantIndex<Variant, T>::value>();
    }

    // Non-throwing access: a pointer to the value, or nullptr on mismatch.
    template <size_t I>
    typename VariantAlternative<Variant, I>::type *get_if() {
        static_assert(I < sizeof...(Ts), "I out of range!");
        if (m_index != I)
            return nullptr;
        return reinterpret_cast<typename VariantAlternative<Variant, I>::type *>(m_union);
    }
    template <typename T>
    T *get_if() {
        return get_if<VariantIndex<Variant, T>::value>();
    }
    template <size_t I>
    typename VariantAlternative<Variant, I>::type const *get_if() const {
        static_assert(I < sizeof...(Ts), "I out of range!");
        if (m_index != I)
            return nullptr;
        return reinterpret_cast<typename VariantAlternative<Variant, I>::type const *>(m_union);
    }
    template <typename T>
    T const *get_if() const {
        return get_if<VariantIndex<Variant, T>::value>();
    }
};

// Index -> type. Index 0 is the first type; otherwise drop the first type and
// look up index I - 1 in the rest. (A type can only be named via using/typename.)
template <typename T, typename ...Ts>
struct VariantAlternative<Variant<T, Ts...>, 0> {
    using type = T;
};
template <size_t I, typename T, typename ...Ts>
struct VariantAlternative<Variant<T, Ts...>, I> {
    using type = typename VariantAlternative<Variant<Ts...>, I - 1>::type;
};

// Type -> index for any number of alternatives. If T is the first type its
// index is 0; otherwise drop the first type, search the rest, and add 1.
template <typename T, typename ...Ts>
struct VariantIndex<Variant<T, Ts...>, T> {
    static constexpr size_t value = 0;
};
template <typename T0, typename T, typename ...Ts>
struct VariantIndex<Variant<T0, Ts...>, T> {
    static constexpr size_t value = VariantIndex<Variant<Ts...>, T>::value + 1;
};
std::common_type<typename std::invoke_result<Lambda, Ts &>::type...>::type
它的作用是 推导出 Lambda 作用于所有 Ts 类型的返回值的公共类型。
std::invoke_result<Lambda, Ts &>::type
std::invoke_result<F, Args...>::type 用于获取 调用 F(Args...) 之后的返回值类型。
Lambda 是一个可调用对象(比如 lambda 函数)。
Ts… 是 Variant 的所有可能类型。
Ts & 代表 Lambda 会以 Ts 类型的左值引用作为参数进行调用。
typename std::invoke_result<Lambda, Ts &>::type...
会展开为:
typename std::invoke_result<Lambda, T1 &>::type,
typename std::invoke_result<Lambda, T2 &>::type,
typename std::invoke_result<Lambda, T3 &>::type,
...
即,Lambda 作用于 T1 &, T2 &, T3 &, … 后得到的所有返回值类型。
std::common_type<T1, T2, T3, …>::type
std::common_type<T1, T2, …>::type 计算出 所有 T1, T2, T3, … 的公共类型。
std::common_type<int, double>::type // → double
std::common_type<float, double>::type // → double
std::common_type<int, float, double>::type // → double
在 visit 函数中,std::invoke_result 要求 Lambda 能以每一个 Ts & 调用(否则编译失败),std::common_type 再把这些返回类型合并成一个公共类型,作为 visit 的返回类型。
我们用 int 和 double 两种类型举例(std::string 不支持 val + 1,所以下例没有包含它):
#include <iostream>
#include <type_traits>
#include <string>
// Empty placeholder standing in for the real Variant; main() below does not
// use it.
template <typename...>
struct Variant {};
// Shows how invoke_result + common_type compute a visitor's combined return
// type: int + 1 is int, double + 1 is double, and their common type is double.
int main() {
    auto add_one = [](auto &val) -> decltype(val + 1) {
        return val + 1;
    };
    using FromInt = std::invoke_result_t<decltype(add_one), int &>;        // int
    using FromDouble = std::invoke_result_t<decltype(add_one), double &>;  // double
    using ResultType = std::common_type_t<FromInt, FromDouble>;
    std::cout << std::is_same<ResultType, double>::value << std::endl;  // prints 1: deduced as double
    return 0;
}