CodeGen: Add support for key capture

This commit is contained in:
Alastair Robertson 2023-09-22 06:35:33 -07:00 committed by Alastair Robertson
parent 3446339358
commit 4afa2ff190
2 changed files with 118 additions and 4 deletions

View File

@ -32,6 +32,7 @@
#include "type_graph/DrgnParser.h"
#include "type_graph/EnforceCompatibility.h"
#include "type_graph/Flattener.h"
#include "type_graph/KeyCapture.h"
#include "type_graph/NameGen.h"
#include "type_graph/Prune.h"
#include "type_graph/RemoveMembers.h"
@ -46,6 +47,7 @@ namespace oi::detail {
using type_graph::AddChildren;
using type_graph::AddPadding;
using type_graph::AlignmentCalc;
using type_graph::CaptureKeys;
using type_graph::Class;
using type_graph::Container;
using type_graph::DrgnParser;
@ -53,6 +55,7 @@ using type_graph::DrgnParserOptions;
using type_graph::EnforceCompatibility;
using type_graph::Enum;
using type_graph::Flattener;
using type_graph::KeyCapture;
using type_graph::Member;
using type_graph::NameGen;
using type_graph::Primitive;
@ -96,12 +99,17 @@ void defineMacros(std::string& code) {
}
}
void defineArray(std::string& code) {
void defineInternalTypes(std::string& code) {
code += R"(
template<typename T, int N>
struct OIArray {
T vals[N];
};
// Just here to give a different type name to containers whose keys we'll capture
template <typename T>
struct OICaptureKeys : public T {
};
)";
}
@ -831,6 +839,8 @@ void genContainerTypeHandler(FeatureSet features,
return;
}
code += c.codegen.extra;
// TODO: Move this check into the ContainerInfo parsing once always enabled.
const auto& func = c.codegen.traversalFunc;
const auto& processors = c.codegen.processors;
@ -859,6 +869,10 @@ void genContainerTypeHandler(FeatureSet features,
if (!templateParams.empty())
containerWithTypes += '>';
if (c.captureKeys) {
containerWithTypes = "OICaptureKeys<" + containerWithTypes + ">";
}
code += "template <typename DB";
types = 0, values = 0;
for (const auto& p : templateParams) {
@ -875,6 +889,11 @@ void genContainerTypeHandler(FeatureSet features,
code += containerWithTypes;
code += "> {\n";
if (c.captureKeys) {
code += " static constexpr bool captureKeys = true;\n";
} else {
code += " static constexpr bool captureKeys = false;\n";
}
code += " using type = ";
if (processors.empty()) {
code += "types::st::Unit<DB>";
@ -931,9 +950,80 @@ void genContainerTypeHandler(FeatureSet features,
code += "};\n\n";
}
void addCaptureKeySupport(std::string& code) {
code += R"(
template <typename DB, typename T>
class CaptureKeyHandler {
public:
using type = types::st::Sum<DB, types::st::VarInt<DB>, types::st::VarInt<DB>>;
static auto captureKey(const T& key, auto returnArg) {
// Save scalars keys directly, otherwise save pointers for complex types
if constexpr (std::is_scalar_v<T>) {
return returnArg.template write<0>().write(static_cast<uint64_t>(key));
}
return returnArg.template write<1>().write(reinterpret_cast<uintptr_t>(&key));
}
};
template <bool CaptureKeys, typename DB, typename T>
auto maybeCaptureKey(const T& key, auto returnArg) {
if constexpr (CaptureKeys) {
return returnArg.delegate([&key](auto ret) {
return CaptureKeyHandler<DB, T>::captureKey(key, ret);
});
} else {
return returnArg;
}
}
template <typename DB, typename T>
static constexpr inst::ProcessorInst CaptureKeysProcessor{
CaptureKeyHandler<DB, T>::type::describe,
[](result::Element& el, std::function<void(inst::Inst)> stack_ins, ParsedData d) {
if constexpr (std::is_same_v<
typename CaptureKeyHandler<DB, T>::type,
types::st::List<DB, types::st::VarInt<DB>>>) {
// String
auto& str = el.data.emplace<std::string>();
auto list = std::get<ParsedData::List>(d.val);
size_t strlen = list.length;
for (size_t i = 0; i < strlen; i++) {
auto value = list.values().val;
auto c = std::get<ParsedData::VarInt>(value).value;
str.push_back(c);
}
} else {
auto sum = std::get<ParsedData::Sum>(d.val);
if (sum.index == 0) {
el.data = oi::result::Element::Scalar{std::get<ParsedData::VarInt>(sum.value().val).value};
} else {
el.data = oi::result::Element::Pointer{std::get<ParsedData::VarInt>(sum.value().val).value};
}
}
}
};
template <bool CaptureKeys, typename DB, typename T>
static constexpr auto maybeCaptureKeysProcessor() {
if constexpr (CaptureKeys) {
return std::array<inst::ProcessorInst, 1>{
CaptureKeysProcessor<DB, T>,
};
}
else {
return std::array<inst::ProcessorInst, 0>{};
}
}
)";
}
void addStandardTypeHandlers(TypeGraph& typeGraph,
FeatureSet features,
std::string& code) {
if (features[Feature::TreeBuilderV2])
addCaptureKeySupport(code);
// Provide a wrapper function, getSizeType, to infer T instead of having to
// explicitly specify it with TypeHandler<DB, T>::getSizeType every time.
code += R"(
@ -983,6 +1073,10 @@ void CodeGen::addTypeHandlers(const TypeGraph& typeGraph, std::string& code) {
} else if (const auto* con = dynamic_cast<const Container*>(&t)) {
genContainerTypeHandler(config_.features, definedContainers_,
con->containerInfo_, con->templateParams, code);
} else if (const auto* cap = dynamic_cast<const CaptureKeys*>(&t)) {
genContainerTypeHandler(config_.features, definedContainers_,
cap->containerInfo(),
cap->container().templateParams, code);
}
}
}
@ -1061,10 +1155,13 @@ void CodeGen::transform(TypeGraph& typeGraph) {
// Calculate alignment before removing members, as those members may have an
// influence on the class' overall alignment.
pm.addPass(AlignmentCalc::createPass());
pm.addPass(RemoveMembers::createPass(config_.membersToStub));
if (!config_.features[Feature::TreeBuilderV2]) {
if (!config_.features[Feature::TreeBuilderV2])
pm.addPass(EnforceCompatibility::createPass());
}
if (config_.features[Feature::TreeBuilderV2] &&
!config_.keysToCapture.empty())
pm.addPass(KeyCapture::createPass(config_.keysToCapture, containerInfos_));
// Add padding to fill in the gaps of removed members and ensure their
// alignments
@ -1094,7 +1191,7 @@ void CodeGen::generate(
defineMacros(code);
}
addIncludes(typeGraph, config_.features, code);
defineArray(code);
defineInternalTypes(code);
FuncGen::DefineJitLog(code, config_.features);
if (config_.features[Feature::TypedDataSegment]) {

View File

@ -54,6 +54,23 @@ struct TypeHandler<DB, %1% <T0>> {
};
"""
extra = """
template <typename DB, typename CharT, typename Traits, typename Allocator>
class CaptureKeyHandler<DB, std::__cxx11::basic_string<CharT, Traits, Allocator>> {
public:
// List of characters
using type = types::st::List<DB, types::st::VarInt<DB>>;
static auto captureKey(const std::__cxx11::basic_string<CharT, Traits, Allocator>& key, auto returnArg) {
auto tail = returnArg.write(key.size());
for (auto c : key) {
tail = returnArg.write((uintptr_t)c);
}
return tail.finish();
}
};
"""
traversal_func = """
bool sso = ((uintptr_t)container.data() <
(uintptr_t)(&container + sizeof(std::__cxx11::basic_string<T0>))) &&