diff --git a/include/oi/types/st.h b/include/oi/types/st.h index 5465aae..3bc7a04 100644 --- a/include/oi/types/st.h +++ b/include/oi/types/st.h @@ -80,7 +80,7 @@ class Unit { } template - Unit delegate(F const& cb) { + Unit consume(F const& cb) { return cb(*this); } @@ -131,6 +131,11 @@ class VarInt { return Unit(_buf); } + template + Unit consume(F const& cb) { + return cb(*this); + } + #ifdef DEFINE_DESCRIBE static constexpr types::dy::VarInt describe{}; #endif @@ -165,6 +170,11 @@ class Pair { return second.template cast(); } + template + Unit consume(F const& cb) { + return cb(*this); + } + #ifdef DEFINE_DESCRIBE static constexpr types::dy::Pair describe{T1::describe, T2::describe}; #endif @@ -217,6 +227,11 @@ class Sum { return cb(tail); } + template + Unit consume(F const& cb) { + return cb(*this); + } + #ifdef DEFINE_DESCRIBE private: static constexpr std::array members{ @@ -274,6 +289,11 @@ class List : Pair, ListContents>(db) { } + template + Unit consume(F const& cb) { + return cb(*this); + } + #ifdef DEFINE_DESCRIBE public: static constexpr types::dy::List describe{T::describe}; diff --git a/oi/CodeGen.cpp b/oi/CodeGen.cpp index 14c3f52..4c9d687 100644 --- a/oi/CodeGen.cpp +++ b/oi/CodeGen.cpp @@ -639,8 +639,8 @@ size_t getLastNonPaddingMemberIndex(const std::vector& members) { // Generate the function body that walks the type. Uses the monadic // `delegate()` form to handle each field except for the last. The last field -// is handled explicitly by passing it to `getSizeType`, as we must consume -// the entire type instead of delegating the next part. +// instead uses `consume()` as we must not accidentally handle the first half +// of a pair as the last field. void CodeGen::genClassTraversalFunction(const Class& c, std::string& code) { std::string funcName = "getSizeType"; @@ -667,30 +667,28 @@ void CodeGen::genClassTraversalFunction(const Class& c, std::string& code) { } if (code.size() == emptySize) { - code += " auto ret = returnArg"; + code += " return returnArg"; } if (thriftIssetMember != nullptr && thriftIssetMember != &member) { code += "\n .write(getThriftIsset(t, " + std::to_string(i) + "))"; } - if (i != lastNonPaddingElement) { - code += - "\n .delegate([&t](auto ret) { return " - "OIInternal::getSizeType(t."; - code += member.name; - code += ", ret); })"; + code += "\n ."; + if (i == lastNonPaddingElement) { + code += "consume"; } else { - code += ";"; - code += "\n return OIInternal::getSizeType(t." + member.name + - ", ret);\n"; + code += "delegate"; } + code += "([&t](auto ret) { return OIInternal::getSizeType(t."; + code += member.name; + code += ", ret); })"; } if (code.size() == emptySize) { code += " return returnArg;"; } - code += " }\n"; + code += ";\n }\n"; } // Generate the static type for the class's representation in the data buffer.