From b0aa7a8c9363a779b80774da4844d812ec782952 Mon Sep 17 00:00:00 2001 From: Jake Hillion Date: Thu, 20 Apr 2023 10:30:59 -0700 Subject: [PATCH] TypedDataSegment: implementation --- oi/CodeGen.cpp | 237 +++++++++++++++++++++++++++++++++------ oi/ContainerInfo.cpp | 4 + oi/ContainerInfo.h | 1 + oi/Features.h | 1 + oi/FuncGen.cpp | 258 +++++++++++++++++++++++++++++++++++++++++++ oi/FuncGen.h | 6 + oi/OITraceCode.cpp | 2 +- oi/OIUtils.cpp | 8 ++ 8 files changed, 483 insertions(+), 34 deletions(-) diff --git a/oi/CodeGen.cpp b/oi/CodeGen.cpp index 6105bd3..f251182 100644 --- a/oi/CodeGen.cpp +++ b/oi/CodeGen.cpp @@ -19,6 +19,8 @@ #include #include +#include +#include #include "oi/FuncGen.h" #include "oi/Headers.h" @@ -51,7 +53,26 @@ void defineMacros(std::string& code) { code += R"( #define SAVE_SIZE(val) #define SAVE_DATA(val) StoreData(val, returnArg) +)"; + } else { + code += R"( +#define SAVE_SIZE(val) AddData(val, returnArg) +#define SAVE_DATA(val) +)"; + } +} +void defineArray(std::string& code) { + code += R"( +template +struct OIArray { + T vals[N]; +}; +)"; +} + +void defineJitLog(std::string& code) { + code += R"( #define JLOG(str) \ do { \ if (__builtin_expect(logFile, 0)) { \ @@ -65,32 +86,26 @@ void defineMacros(std::string& code) { __jlogptr((uintptr_t)ptr); \ } \ } while (false) - -template -struct OIArray { - T vals[N]; -}; )"; - } else { - code += R"( -#define SAVE_SIZE(val) AddData(val, returnArg) -#define SAVE_DATA(val) -#define JLOG(str) -#define JLOGPTR(ptr) -)"; - } } -void addIncludes(const TypeGraph& typeGraph, std::string& code) { - // Required for the offsetof() macro - code += "#include \n"; - - // TODO deduplicate containers +void addIncludes(const TypeGraph& typeGraph, + FeatureSet features, + std::string& code) { + std::set includes{"cstddef"}; + if (features[Feature::TypedDataSegment]) { + includes.emplace("functional"); + } for (const Type& t : typeGraph.finalTypes) { if (const auto* c = dynamic_cast(&t)) { - code += "#include <" + c->containerInfo_.header + ">\n"; + includes.emplace(c->containerInfo_.header); } } + for (const auto& include : includes) { + code += "#include <"; + code += include; + code += ">\n"; + } } void genDeclsClass(const Class& c, std::string& code) { @@ -367,10 +382,9 @@ void getContainerSizeFuncDef(const Container& c, std::string& code) { // - implement hash for ContainerInfo // - use ref static std::unordered_set usedContainers{}; - if (usedContainers.find(&c.containerInfo_) != usedContainers.end()) { + if (!usedContainers.insert(&c.containerInfo_).second) { return; } - usedContainers.insert(&c.containerInfo_); auto fmt = boost::format(c.containerInfo_.codegen.func) % c.containerInfo_.typeName; @@ -399,6 +413,136 @@ void addGetSizeFuncDefs(const TypeGraph& typeGraph, } } } + +void addStandardTypeHandlers(std::string& code) { + code += R"( + template + StaticTypes::Unit + getSizeType(const T &t, typename TypeHandler::type returnArg) { + JLOG("obj @"); + JLOGPTR(&t); + return TypeHandler::getSizeType(t, returnArg); + } + )"; + + code += R"( + template + struct TypeHandler> { + using type = StaticTypes::List::type>; + static StaticTypes::Unit getSizeType( + const OIArray &container, + typename TypeHandler>::type returnArg) { + auto tail = returnArg.write(N); + for (size_t i=0; i::getSizeType(container.vals[i], ret); + }); + } + return tail.finish(); + } + }; + )"; +} + +void getClassTypeHandler(const Class& c, std::string& code) { + std::string funcName = "getSizeType"; + + std::string typeStaticType; + { + size_t pairs = 0; + + for (size_t i = 0; i < c.members.size(); i++) { + const auto& member = c.members[i]; + + if (i != c.members.size() - 1) { + typeStaticType += "StaticTypes::Pair::type") % + c.name() % member.name) + .str(); + + if (i != c.members.size() - 1) { + typeStaticType += ", "; + } + } + typeStaticType += std::string(pairs, '>'); + + if (typeStaticType.empty()) { + typeStaticType = "StaticTypes::Unit"; + } + } + + std::string traverser; + { + if (!c.members.empty()) { + traverser = "auto ret = returnArg"; + } + for (size_t i = 0; i < c.members.size(); i++) { + const auto& member = c.members[i]; + + if (i != c.members.size() - 1) { + traverser += "\n .delegate([&t](auto ret) {"; + traverser += "\n return OIInternal::getSizeType(t." + + member.name + ", ret);"; + traverser += "\n})"; + } else { + traverser += ";"; + traverser += + "\nreturn OIInternal::getSizeType(t." + member.name + ", ret);"; + } + } + + if (traverser.empty()) { + traverser = "return returnArg;"; + } + } + + code += (boost::format(R"( +template +class TypeHandler { + public: + using type = %2%; + static StaticTypes::Unit %3%( + const %1%& t, + typename TypeHandler::type returnArg) { + %4% + } +}; +)") % c.name() % + typeStaticType % funcName % traverser) + .str(); +} + +void getContainerTypeHandler(const Container& c, std::string& code) { + static std::unordered_set usedContainers{}; + if (!usedContainers.insert(&c.containerInfo_).second) { + return; + } + + const auto& handler = c.containerInfo_.codegen.handler; + if (handler.empty()) { + LOG(ERROR) << "`codegen.handler` must be specified for all containers " + "under \"-ftyped-data-segment\", not specified for \"" + + c.containerInfo_.typeName + "\""; + throw std::runtime_error("missing `codegen.handler`"); + } + auto fmt = boost::format(c.containerInfo_.codegen.handler) % + c.containerInfo_.typeName; + code += fmt.str(); +} + +void addTypeHandlers(const TypeGraph& typeGraph, std::string& code) { + for (const Type& t : typeGraph.finalTypes) { + if (const auto* c = dynamic_cast(&t)) { + getClassTypeHandler(*c, code); + } else if (const auto* con = dynamic_cast(&t)) { + getContainerTypeHandler(*con, code); + } + } +} } // namespace bool CodeGen::generate(drgn_type* drgnType, std::string& code) { @@ -435,8 +579,22 @@ bool CodeGen::generate(drgn_type* drgnType, std::string& code) { }; code = headers::OITraceCode_cpp; - defineMacros(code); - addIncludes(typeGraph_, code); + if (!config_.features[Feature::TypedDataSegment]) { + defineMacros(code); + } + addIncludes(typeGraph_, config_.features, code); + defineArray(code); + defineJitLog(code); // TODO: feature gate this + + if (config_.features[Feature::TypedDataSegment]) { + FuncGen::DefineDataSegmentDataBuffer(code); + FuncGen::DefineStaticTypes(code); + code += "using namespace ObjectIntrospection;\n"; + + code += "namespace OIInternal {\nnamespace {\n"; + FuncGen::DefineBasicTypeHandlers(code); + code += "} // namespace\n} // namespace OIInternal\n"; + } /* * The purpose of the anonymous namespace within `OIInternal` is that @@ -448,29 +606,42 @@ bool CodeGen::generate(drgn_type* drgnType, std::string& code) { * process faster. */ code += "namespace OIInternal {\nnamespace {\n"; - FuncGen::DefineEncodeData(code); - FuncGen::DefineEncodeDataSize(code); - FuncGen::DefineStoreData(code); - FuncGen::DefineAddData(code); + if (!config_.features[Feature::TypedDataSegment]) { + FuncGen::DefineEncodeData(code); + FuncGen::DefineEncodeDataSize(code); + FuncGen::DefineStoreData(code); + FuncGen::DefineAddData(code); + } FuncGen::DeclareGetContainer(code); genDecls(typeGraph_, code); genDefs(typeGraph_, code); genStaticAsserts(typeGraph_, code); - addStandardGetSizeFuncDecls(code); - addGetSizeFuncDecls(typeGraph_, code); + if (config_.features[Feature::TypedDataSegment]) { + addStandardTypeHandlers(code); + addTypeHandlers(typeGraph_, code); + } else { + addStandardGetSizeFuncDecls(code); + addGetSizeFuncDecls(typeGraph_, code); - addStandardGetSizeFuncDefs(code); - addGetSizeFuncDefs(typeGraph_, symbols_, - config_.features[Feature::PolymorphicInheritance], code); + addStandardGetSizeFuncDefs(code); + addGetSizeFuncDefs(typeGraph_, symbols_, + config_.features[Feature::PolymorphicInheritance], code); + } assert(typeGraph_.rootTypes().size() == 1); Type& rootType = typeGraph_.rootTypes()[0]; code += "\nusing __ROOT_TYPE__ = " + rootType.name() + ";\n"; code += "} // namespace\n} // namespace OIInternal\n"; - FuncGen::DefineTopLevelGetSizeRef(code, SymbolService::getTypeName(drgnType)); + if (config_.features[Feature::TypedDataSegment]) { + FuncGen::DefineTopLevelGetSizeRefTyped( + code, SymbolService::getTypeName(drgnType)); + } else { + FuncGen::DefineTopLevelGetSizeRef(code, + SymbolService::getTypeName(drgnType)); + } if (VLOG_IS_ON(3)) { VLOG(3) << "Generated trace code:\n"; diff --git a/oi/ContainerInfo.cpp b/oi/ContainerInfo.cpp index ff9a6cf..f8a154f 100644 --- a/oi/ContainerInfo.cpp +++ b/oi/ContainerInfo.cpp @@ -259,6 +259,10 @@ ContainerInfo::ContainerInfo(const fs::path& path) { } else { throw std::runtime_error("`codegen.decl` is a required field"); } + if (std::optional str = + codegenToml["handler"].value()) { + codegen.handler = std::move(*str); + } } ContainerInfo::ContainerInfo(std::string typeName_, diff --git a/oi/ContainerInfo.h b/oi/ContainerInfo.h index b869f28..72ca541 100644 --- a/oi/ContainerInfo.h +++ b/oi/ContainerInfo.h @@ -30,6 +30,7 @@ struct ContainerInfo { struct Codegen { std::string decl; std::string func; + std::string handler = ""; }; explicit ContainerInfo(const std::filesystem::path& path); // Throws diff --git a/oi/Features.h b/oi/Features.h index 5b89c51..6a9b591 100644 --- a/oi/Features.h +++ b/oi/Features.h @@ -26,6 +26,7 @@ X(GenPaddingStats, "gen-padding-stats") \ X(CaptureThriftIsset, "capture-thrift-isset") \ X(TypeGraph, "type-graph") \ + X(TypedDataSegment, "typed-data-segment") \ X(PolymorphicInheritance, "polymorphic-inheritance") namespace ObjectIntrospection { diff --git a/oi/FuncGen.cpp b/oi/FuncGen.cpp index 41fcc88..af68eab 100644 --- a/oi/FuncGen.cpp +++ b/oi/FuncGen.cpp @@ -266,6 +266,45 @@ void FuncGen::DefineTopLevelGetSizeRef(std::string& testCode, testCode.append(fmt.str()); } +void FuncGen::DefineTopLevelGetSizeRefTyped(std::string& testCode, + const std::string& rawType) { + std::string func = R"( + #pragma GCC diagnostic push + #pragma GCC diagnostic ignored "-Wunknown-attributes" + /* RawType: %1% */ + void __attribute__((used, retain)) getSize_%2$016x(const OIInternal::__ROOT_TYPE__& t) + #pragma GCC diagnostic pop + { + pointers.initialize(); + pointers.add((uintptr_t)&t); + auto data = reinterpret_cast(dataBase); + data[0] = oidMagicId; + data[1] = cookieValue; + data[2] = 0; + size_t dataSegOffset = 3 * sizeof(uintptr_t); + JLOG("%1% @"); + JLOGPTR(&t); + using DataBufferType = OIInternal::TypeHandler::type; + DataBufferType db = DataBuffer::DataSegment(dataSegOffset); + StaticTypes::Unit out = OIInternal::getSizeType(t, db); + StaticTypes::Unit final = out.template cast, + StaticTypes::VarInt + >>() + .write(123456789) + .write(123456789); + dataSegOffset = final.offset(); + data[2] = dataSegOffset; + dataBase += dataSegOffset; + } + )"; + + boost::format fmt = + boost::format(func) % rawType % std::hash{}(rawType); + testCode.append(fmt.str()); +} + void FuncGen::DefineTopLevelGetSizeRefRet(std::string& testCode, const std::string& rawType) { std::string func = R"( @@ -391,3 +430,222 @@ void FuncGen::DeclareGetContainer(std::string& testCode) { )"; testCode.append(func); } + +void FuncGen::DefineDataSegmentDataBuffer(std::string& testCode) { + constexpr std::string_view func = R"( + namespace ObjectIntrospection { + namespace DataBuffer { + class DataBuffer { + protected: + void write_byte(uint8_t); + }; + class DataSegment: public DataBuffer { + public: + DataSegment(size_t offset) : buf(dataBase + offset) {} + void write_byte(uint8_t byte) { + // TODO: Change the inputs to dataBase / dataEnd to improve this check + if (buf < (dataBase + dataSize)) { + *buf = byte; + } + buf++; + } + size_t offset() { + return buf - dataBase; + } + private: + uint8_t* buf; + }; + } // namespace DataBuffer + } // namespace ObjectIntrospection + )"; + + testCode.append(func); +} + +void FuncGen::DefineBasicTypeHandlers(std::string& testCode) { + constexpr std::string_view handlers = R"( + template + struct TypeHandler { + private: + static auto choose_type() { + if constexpr(std::is_pointer_v) { + return std::type_identity, + StaticTypes::Sum, + typename TypeHandler>::type + >>>(); + } else { + return std::type_identity>(); + } + } + public: + using type = typename decltype(choose_type())::type; + static StaticTypes::Unit getSizeType( + const T& t, + typename TypeHandler::type returnArg) { + if constexpr(std::is_pointer_v) { + JLOG("ptr val @"); + JLOGPTR(t); + auto r0 = returnArg.write((uintptr_t)t); + if (t && pointers.add((uintptr_t)t)) { + return r0.template delegate<1>([&t](auto ret) { + if constexpr (!std::is_void>::value) { + return TypeHandler>::getSizeType(*t, ret); + } else { + return ret; + } + }); + } else { + return r0.template delegate<0>(std::identity()); + } + } else { + return returnArg; + } + } + }; + template + class TypeHandler { + public: + using type = StaticTypes::Unit; + }; + )"; + + testCode.append(handlers); +} + +void FuncGen::DefineStaticTypes(std::string& testCode) { + constexpr std::string_view unitType = R"( + template + class Unit { + public: + Unit(DataBuffer db) : _buf(db) {} + + size_t offset() { + return _buf.offset(); + } + + template + T cast() { + return T(_buf); + } + + template + Unit + delegate(F const& cb) { + return cb(*this); + } + + private: + DataBuffer _buf; + }; + )"; + + constexpr std::string_view varintType = R"( + template + class VarInt { + public: + VarInt(DataBuffer db) : _buf(db) {} + + Unit write(uint64_t val) { + while (val >= 128) { + _buf.write_byte(0x80 | (val & 0x7f)); + val >>= 7; + } + _buf.write_byte(uint8_t(val)); + return Unit(_buf); + } + + private: + DataBuffer _buf; + }; + )"; + + constexpr std::string_view pairType = R"( + template + class Pair { + public: + Pair(DataBuffer db) : _buf(db) {} + template + T2 write(U val) { + Unit second = T1(_buf).write(val); + return second.template cast(); + } + template + T2 delegate(F const& cb) { + T1 first = T1(_buf); + Unit second = cb(first); + return second.template cast(); + } + private: + DataBuffer _buf; + }; + )"; + + constexpr std::string_view sumType = R"( + template + class Sum { + private: + template + struct Selector; + template + struct Selector { + using type = typename std::conditional::type>::type; + }; + template + struct Selector { + using type = int; + }; + public: + Sum(DataBuffer db) : _buf(db) {} + template + typename Selector::type write() { + Pair, typename Selector::type> buf(_buf); + return buf.write(I); + } + template + Unit delegate(F const& cb) { + auto tail = write(); + return cb(tail); + } + private: + DataBuffer _buf; + }; + )"; + + constexpr std::string_view listType = R"( + template + class ListContents { + public: + ListContents(DataBuffer db) : _buf(db) {} + + template + ListContents delegate(F const& cb) { + T head = T(_buf); + Unit tail = cb(head); + return tail.template cast>(); + } + + Unit finish() { + return { _buf }; + } + private: + DataBuffer _buf; + }; + + template + using List = Pair, ListContents>; + )"; + + testCode.append("namespace ObjectIntrospection {\n"); + testCode.append("namespace StaticTypes {\n"); + + testCode.append(unitType); + testCode.append(varintType); + testCode.append(pairType); + testCode.append(sumType); + testCode.append(listType); + + testCode.append("} // namespace StaticTypes {\n"); + testCode.append("} // namespace ObjectIntrospection {\n"); +} diff --git a/oi/FuncGen.h b/oi/FuncGen.h index 48f0095..59ec165 100644 --- a/oi/FuncGen.h +++ b/oi/FuncGen.h @@ -56,6 +56,8 @@ class FuncGen { static void DefineTopLevelGetSizeRef(std::string& testCode, const std::string& rawType); + static void DefineTopLevelGetSizeRefTyped(std::string& testCode, + const std::string& rawType); static void DefineTopLevelGetSizeRefRet(std::string& testCode, const std::string& type); @@ -65,4 +67,8 @@ class FuncGen { static void DefineGetSizeTypedValueFunc(std::string& testCode, const std::string& ctype); + + static void DefineDataSegmentDataBuffer(std::string& testCode); + static void DefineStaticTypes(std::string& testCode); + static void DefineBasicTypeHandlers(std::string& testCode); }; diff --git a/oi/OITraceCode.cpp b/oi/OITraceCode.cpp index 5ea879c..c20f15a 100644 --- a/oi/OITraceCode.cpp +++ b/oi/OITraceCode.cpp @@ -35,7 +35,7 @@ #define C10_USING_CUSTOM_GENERATED_MACROS // These globals are set by oid, see end of OIDebugger::compileCode() -extern uintptr_t dataBase; +extern uint8_t* dataBase; extern size_t dataSize; extern uintptr_t cookieValue; extern int logFile; diff --git a/oi/OIUtils.cpp b/oi/OIUtils.cpp index 66305ab..da5eec1 100644 --- a/oi/OIUtils.cpp +++ b/oi/OIUtils.cpp @@ -160,6 +160,14 @@ std::optional processConfigFile( featuresSet[k] = true; } } + + if (featuresSet[Feature::TypedDataSegment] && + !featuresSet[Feature::TypeGraph]) { + featuresSet[Feature::TypeGraph] = true; + LOG(WARNING) << "TypedDataSegment feature requires TypeGraph feature to be " + "enabled, enabling now."; + } + return featuresSet; }