codegen: split genClassTypeHandler

This commit is contained in:
Jake Hillion 2023-08-10 13:34:06 -07:00 committed by Jake Hillion
parent ff31971bd3
commit 063646a1d9
2 changed files with 143 additions and 117 deletions

View File

@ -579,7 +579,7 @@ void addStandardTypeHandlers(std::string& code) {
JLOGPTR(&t);
return TypeHandler<DB, T>::getSizeType(t, returnArg);
}
)";
)";
code += R"(
template<typename DB, typename T0, long unsigned int N>
@ -597,21 +597,140 @@ void addStandardTypeHandlers(std::string& code) {
return tail.finish();
}
};
)";
)";
}
// Find the last member that isn't padding's index. Return -1 if no such member.
size_t getLastNonPaddingMemberIndex(const std::vector<Member>& members) {
for (size_t i = members.size() - 1; i != (size_t)-1; --i) {
const auto& el = members[i];
if (!el.name.starts_with(AddPadding::MemberPrefix))
return i;
}
return -1;
}
} // namespace
void CodeGen::getClassTypeHandler(const Class& c, std::string& code) {
// Generate the function body that walks the type. Uses the monadic
// `delegate()` form to handle each field except for the last. The last field
// is handled explicitly by passing it to `getSizeType`, as we must consume
// the entire type instead of delegating the next part.
void CodeGen::genClassTraversalFunction(const Class& c, std::string& code) {
std::string funcName = "getSizeType";
std::string extras;
code += " static types::st::Unit<DB> ";
code += funcName;
code += "(\n const ";
code += c.name();
code += "& t,\n typename TypeHandler<DB, ";
code += c.name();
code += ">::type returnArg) {\n";
const Member* thriftIssetMember = nullptr;
if (const auto it = thriftIssetMembers_.find(&c);
it != thriftIssetMembers_.end()) {
thriftIssetMember = it->second;
}
extras += (boost::format(R"(
size_t emptySize = code.size();
size_t lastNonPaddingElement = getLastNonPaddingMemberIndex(c.members);
for (size_t i = 0; i < lastNonPaddingElement + 1; i++) {
const auto& member = c.members[i];
if (member.name.starts_with(AddPadding::MemberPrefix)) {
continue;
}
if (code.size() == emptySize) {
code += " auto ret = returnArg";
}
if (thriftIssetMember != nullptr && thriftIssetMember != &member) {
code += "\n .write(getThriftIsset(t, " + std::to_string(i) + "))";
}
if (i != lastNonPaddingElement) {
code +=
"\n .delegate([&t](auto ret) { return "
"OIInternal::getSizeType<DB>(t.";
code += member.name;
code += ", ret); })";
} else {
code += ";";
code += "\n return OIInternal::getSizeType<DB>(t." + member.name +
", ret);\n";
}
}
if (code.size() == emptySize) {
code += " return returnArg;";
}
code += " }\n";
}
// Generate the static type for the class's representation in the data buffer.
// For `class { int a,b,c; }` we generate (DB omitted for clarity):
// Pair<TypeHandler<int>::type,
// Pair<TypeHandler<int>::type,
// TypeHandler<int>::type
// >>
void CodeGen::genClassStaticType(const Class& c, std::string& code) {
const Member* thriftIssetMember = nullptr;
if (const auto it = thriftIssetMembers_.find(&c);
it != thriftIssetMembers_.end()) {
thriftIssetMember = it->second;
}
size_t lastNonPaddingElement = getLastNonPaddingMemberIndex(c.members);
size_t pairs = 0;
size_t emptySize = code.size();
for (size_t i = 0; i < lastNonPaddingElement + 1; i++) {
const auto& member = c.members[i];
if (member.name.starts_with(AddPadding::MemberPrefix)) {
continue;
}
if (i != lastNonPaddingElement) {
code += "types::st::Pair<DB, ";
pairs++;
}
if (thriftIssetMember != nullptr && thriftIssetMember != &member) {
// Return an additional VarInt before every field except for __isset
// itself.
pairs++;
if (i == lastNonPaddingElement) {
code += "types::st::Pair<DB, types::st::VarInt<DB>, ";
} else {
code += "types::st::VarInt<DB>, types::st::Pair<DB, ";
}
}
code +=
(boost::format("typename TypeHandler<DB, decltype(%1%::%2%)>::type") %
c.name() % member.name)
.str();
if (i != lastNonPaddingElement) {
code += ", ";
}
}
code += std::string(pairs, '>');
if (code.size() == emptySize) {
code += "types::st::Unit<DB>";
}
}
void CodeGen::genClassTypeHandler(const Class& c, std::string& code) {
std::string helpers;
if (const auto it = thriftIssetMembers_.find(&c);
it != thriftIssetMembers_.end()) {
const Member& thriftIssetMember = *it->second;
helpers += (boost::format(R"(
static int getThriftIsset(const %1%& t, size_t i) {
using thrift_data = apache::thrift::TStructDataStorage<%2%>;
@ -622,118 +741,22 @@ void CodeGen::getClassTypeHandler(const Class& c, std::string& code) {
return t.%3%.get(idx);
}
)") % c.name() %
c.fqName() % thriftIssetMember->name)
)") % c.name() % c.fqName() %
thriftIssetMember.name)
.str();
}
size_t lastNonPaddingElement = -1;
for (size_t i = 0; i < c.members.size(); i++) {
const auto& el = c.members[i];
if (!el.name.starts_with(AddPadding::MemberPrefix)) {
lastNonPaddingElement = i;
}
}
// Generate the static type for the class's representation in the data buffer.
// For `class { int a,b,c; }` we generate (DB omitted for clarity):
// Pair<TypeHandler<int>::type,
// Pair<TypeHandler<int>::type,
// TypeHandler<int>::type
// >>
std::string typeStaticType;
{
size_t pairs = 0;
for (size_t i = 0; i < lastNonPaddingElement + 1; i++) {
const auto& member = c.members[i];
if (member.name.starts_with(AddPadding::MemberPrefix)) {
continue;
}
if (i != lastNonPaddingElement) {
typeStaticType += "types::st::Pair<DB, ";
pairs++;
}
if (thriftIssetMember != nullptr && thriftIssetMember != &member) {
// Return an additional VarInt before every field except for __isset
// itself.
pairs++;
if (i == lastNonPaddingElement) {
typeStaticType += "types::st::Pair<DB, types::st::VarInt<DB>, ";
} else {
typeStaticType += "types::st::VarInt<DB>, types::st::Pair<DB, ";
}
}
typeStaticType +=
(boost::format("typename TypeHandler<DB, decltype(%1%::%2%)>::type") %
c.name() % member.name)
.str();
if (i != lastNonPaddingElement) {
typeStaticType += ", ";
}
}
typeStaticType += std::string(pairs, '>');
if (typeStaticType.empty()) {
typeStaticType = "types::st::Unit<DB>";
}
}
// Generate the function body that walks the type. Uses the monadic
// `delegate()` form to handle each field except for the last. The last field
// is handled explicitly by passing it to `getSizeType`, as we must consume
// the entire type instead of delegating the next part.
std::string traverser;
{
for (size_t i = 0; i < lastNonPaddingElement + 1; i++) {
const auto& member = c.members[i];
if (member.name.starts_with(AddPadding::MemberPrefix)) {
continue;
}
if (traverser.empty()) {
traverser = "auto ret = returnArg";
}
if (thriftIssetMember != nullptr && thriftIssetMember != &member) {
traverser += "\n .write(getThriftIsset(t, " + std::to_string(i) + "))";
}
if (i != lastNonPaddingElement) {
traverser += "\n .delegate([&t](auto ret) {";
traverser += "\n return OIInternal::getSizeType<DB>(t." +
member.name + ", ret);";
traverser += "\n})";
} else {
traverser += ";";
traverser +=
"\nreturn OIInternal::getSizeType<DB>(t." + member.name + ", ret);";
}
}
if (traverser.empty()) {
traverser = "return returnArg;";
}
}
code += (boost::format(R"(
template <typename DB>
class TypeHandler<DB, %1%> {%2%
public:
using type = %3%;
static types::st::Unit<DB> %4%(
const %1%& t,
typename TypeHandler<DB, %1%>::type returnArg) {
%5%
}
};
)") % c.name() %
extras % typeStaticType % funcName % traverser)
.str();
code += "template <typename DB>\n";
code += "class TypeHandler<DB, ";
code += c.name();
code += "> {\n";
code += helpers;
code += " public:\n";
code += " using type = ";
genClassStaticType(c, code);
code += ";\n";
genClassTraversalFunction(c, code);
code += "};\n";
}
namespace {
@ -763,7 +786,7 @@ void getContainerTypeHandler(std::unordered_set<const ContainerInfo*>& used,
void CodeGen::addTypeHandlers(const TypeGraph& typeGraph, std::string& code) {
for (const Type& t : typeGraph.finalTypes) {
if (const auto* c = dynamic_cast<const Class*>(&t)) {
getClassTypeHandler(*c, code);
genClassTypeHandler(*c, code);
} else if (const auto* con = dynamic_cast<const Container*>(&t)) {
getContainerTypeHandler(definedContainers_, *con, code);
}

View File

@ -76,7 +76,10 @@ class CodeGen {
std::string& code) const;
void addTypeHandlers(const type_graph::TypeGraph& typeGraph,
std::string& code);
void getClassTypeHandler(const type_graph::Class& c, std::string& code);
void genClassTypeHandler(const type_graph::Class& c, std::string& code);
void genClassStaticType(const type_graph::Class& c, std::string& code);
void genClassTraversalFunction(const type_graph::Class& c, std::string& code);
};
} // namespace oi::detail