diff --git a/oi/type_graph/AddChildren.cpp b/oi/type_graph/AddChildren.cpp index 91dd17b..0ff8fc9 100644 --- a/oi/type_graph/AddChildren.cpp +++ b/oi/type_graph/AddChildren.cpp @@ -32,7 +32,7 @@ using ref = std::reference_wrapper; namespace oi::detail::type_graph { Pass AddChildren::createPass(DrgnParser& drgnParser, SymbolService& symbols) { - auto fn = [&drgnParser, &symbols](TypeGraph& typeGraph) { + auto fn = [&drgnParser, &symbols](TypeGraph& typeGraph, NodeTracker&) { AddChildren pass(typeGraph, drgnParser); pass.enumerateChildClasses(symbols); for (auto& type : typeGraph.rootTypes()) { diff --git a/oi/type_graph/AddPadding.cpp b/oi/type_graph/AddPadding.cpp index 885337a..79fbd53 100644 --- a/oi/type_graph/AddPadding.cpp +++ b/oi/type_graph/AddPadding.cpp @@ -25,7 +25,7 @@ using ref = std::reference_wrapper; namespace oi::detail::type_graph { Pass AddPadding::createPass() { - auto fn = [](TypeGraph& typeGraph) { + auto fn = [](TypeGraph& typeGraph, NodeTracker&) { AddPadding pass(typeGraph); for (auto& type : typeGraph.rootTypes()) { pass.accept(type); diff --git a/oi/type_graph/AlignmentCalc.cpp b/oi/type_graph/AlignmentCalc.cpp index da26cd6..31a7a30 100644 --- a/oi/type_graph/AlignmentCalc.cpp +++ b/oi/type_graph/AlignmentCalc.cpp @@ -25,7 +25,7 @@ using ref = std::reference_wrapper; namespace oi::detail::type_graph { Pass AlignmentCalc::createPass() { - auto fn = [](TypeGraph& typeGraph) { + auto fn = [](TypeGraph& typeGraph, NodeTracker&) { AlignmentCalc alignmentCalc; alignmentCalc.calculateAlignments(typeGraph.rootTypes()); }; diff --git a/oi/type_graph/EnforceCompatibility.cpp b/oi/type_graph/EnforceCompatibility.cpp index d3dc249..00919b3 100644 --- a/oi/type_graph/EnforceCompatibility.cpp +++ b/oi/type_graph/EnforceCompatibility.cpp @@ -27,8 +27,8 @@ namespace oi::detail::type_graph { Pass EnforceCompatibility::createPass() { - auto fn = [](TypeGraph& typeGraph) { - EnforceCompatibility pass{typeGraph.resetTracker()}; + auto fn = [](TypeGraph& typeGraph, NodeTracker& tracker) { + EnforceCompatibility pass{tracker}; for (auto& type : typeGraph.rootTypes()) { pass.accept(type); } diff --git a/oi/type_graph/Flattener.cpp b/oi/type_graph/Flattener.cpp index 62a32e3..23c73f4 100644 --- a/oi/type_graph/Flattener.cpp +++ b/oi/type_graph/Flattener.cpp @@ -21,8 +21,8 @@ namespace oi::detail::type_graph { Pass Flattener::createPass() { - auto fn = [](TypeGraph& typeGraph) { - Flattener flattener{typeGraph.resetTracker()}; + auto fn = [](TypeGraph& typeGraph, NodeTracker& tracker) { + Flattener flattener{tracker}; for (auto& type : typeGraph.rootTypes()) { flattener.accept(type); } diff --git a/oi/type_graph/NameGen.cpp b/oi/type_graph/NameGen.cpp index 7b9ea67..8ae8e36 100644 --- a/oi/type_graph/NameGen.cpp +++ b/oi/type_graph/NameGen.cpp @@ -23,7 +23,7 @@ using ref = std::reference_wrapper; namespace oi::detail::type_graph { Pass NameGen::createPass() { - auto fn = [](TypeGraph& typeGraph) { + auto fn = [](TypeGraph& typeGraph, NodeTracker&) { NameGen nameGen; nameGen.generateNames(typeGraph.rootTypes()); }; diff --git a/oi/type_graph/NodeTracker.h b/oi/type_graph/NodeTracker.h index 9c3a030..baf7bdf 100644 --- a/oi/type_graph/NodeTracker.h +++ b/oi/type_graph/NodeTracker.h @@ -59,6 +59,12 @@ class NodeTracker { std::fill(visited_.begin(), visited_.end(), false); } + /* + * resize + * + * Resizes the underlying vector to the requested size, with the same + * semantics as std::vector::resize(). + */ void resize(size_t size) { visited_.resize(size); } @@ -67,4 +73,39 @@ class NodeTracker { std::vector visited_; }; +/* + * NodeTrackerHolder + * + * Wrapper which ensures that the contained NodeTracker has been reset before + * allowing access to it. + */ +class NodeTrackerHolder { + public: + /* + * Implicit ctor from NodeTracker + */ + NodeTrackerHolder(NodeTracker& tracker) : tracker_(tracker) { + } + + /* + * get + * + * Returns a reference to a NodeTracker which has been reset, i.e. one in + * which all nodes are marked "not visited". + */ + NodeTracker& get() { + tracker_.reset(); + return tracker_; + } + + NodeTracker& get(size_t size) { + tracker_.reset(); + tracker_.resize(size); + return tracker_; + } + + private: + NodeTracker& tracker_; +}; + } // namespace oi::detail::type_graph diff --git a/oi/type_graph/PassManager.cpp b/oi/type_graph/PassManager.cpp index dff4a28..a9c3bc4 100644 --- a/oi/type_graph/PassManager.cpp +++ b/oi/type_graph/PassManager.cpp @@ -20,6 +20,7 @@ #include #include +#include "NodeTracker.h" #include "Printer.h" #include "TypeGraph.h" @@ -28,8 +29,8 @@ using ref = std::reference_wrapper; namespace oi::detail::type_graph { -void Pass::run(TypeGraph& typeGraph) { - fn_(typeGraph); +void Pass::run(TypeGraph& typeGraph, NodeTrackerHolder tracker) { + fn_(typeGraph, tracker.get(typeGraph.size())); } void PassManager::addPass(Pass p) { @@ -37,11 +38,11 @@ void PassManager::addPass(Pass p) { } namespace { -void print(const TypeGraph& typeGraph) { +void print(const TypeGraph& typeGraph, NodeTrackerHolder tracker) { if (!VLOG_IS_ON(1)) return; std::stringstream out; - Printer printer{out, typeGraph.resetTracker(), typeGraph.size()}; + Printer printer{out, tracker.get(typeGraph.size()), typeGraph.size()}; for (const auto& type : typeGraph.rootTypes()) { printer.print(type); } @@ -54,19 +55,21 @@ void print(const TypeGraph& typeGraph) { const std::string separator = "----------------"; void PassManager::run(TypeGraph& typeGraph) { + NodeTracker tracker; + VLOG(1) << separator; VLOG(1) << "Parsed Type Graph:"; VLOG(1) << separator; - print(typeGraph); + print(typeGraph, tracker); VLOG(1) << separator; for (size_t i = 0; i < passes_.size(); i++) { auto& pass = passes_[i]; LOG(INFO) << "Running pass (" << i + 1 << "/" << passes_.size() << "): " << pass.name(); - pass.run(typeGraph); + pass.run(typeGraph, tracker); VLOG(1) << separator; - print(typeGraph); + print(typeGraph, tracker); VLOG(1) << separator; } } diff --git a/oi/type_graph/PassManager.h b/oi/type_graph/PassManager.h index 5183d94..7dbc850 100644 --- a/oi/type_graph/PassManager.h +++ b/oi/type_graph/PassManager.h @@ -21,6 +21,8 @@ namespace oi::detail::type_graph { +class NodeTrackerHolder; +class NodeTracker; class TypeGraph; class Type; @@ -30,12 +32,13 @@ class Type; * TODO */ class Pass { - using PassFn = std::function; + using PassFn = + std::function; public: Pass(std::string name, PassFn fn) : name_(std::move(name)), fn_(fn) { } - void run(TypeGraph& typeGraph); + void run(TypeGraph& typeGraph, NodeTrackerHolder tracker); std::string& name() { return name_; }; diff --git a/oi/type_graph/Prune.cpp b/oi/type_graph/Prune.cpp index 5d5db80..903a8c2 100644 --- a/oi/type_graph/Prune.cpp +++ b/oi/type_graph/Prune.cpp @@ -21,8 +21,8 @@ namespace oi::detail::type_graph { Pass Prune::createPass() { - auto fn = [](TypeGraph& typeGraph) { - Prune pass{typeGraph.resetTracker()}; + auto fn = [](TypeGraph& typeGraph, NodeTracker& tracker) { + Prune pass{tracker}; for (auto& type : typeGraph.rootTypes()) { pass.accept(type); } diff --git a/oi/type_graph/RemoveMembers.cpp b/oi/type_graph/RemoveMembers.cpp index e5e4eb3..9602b87 100644 --- a/oi/type_graph/RemoveMembers.cpp +++ b/oi/type_graph/RemoveMembers.cpp @@ -22,7 +22,7 @@ namespace oi::detail::type_graph { Pass RemoveMembers::createPass( const std::vector>& membersToIgnore) { - auto fn = [&membersToIgnore](TypeGraph& typeGraph) { + auto fn = [&membersToIgnore](TypeGraph& typeGraph, NodeTracker&) { RemoveMembers removeMembers{membersToIgnore}; for (auto& type : typeGraph.rootTypes()) { removeMembers.accept(type); diff --git a/oi/type_graph/RemoveTopLevelPointer.cpp b/oi/type_graph/RemoveTopLevelPointer.cpp index 07a2ecd..9ecaaaa 100644 --- a/oi/type_graph/RemoveTopLevelPointer.cpp +++ b/oi/type_graph/RemoveTopLevelPointer.cpp @@ -20,7 +20,7 @@ namespace oi::detail::type_graph { Pass RemoveTopLevelPointer::createPass() { - auto fn = [](TypeGraph& typeGraph) { + auto fn = [](TypeGraph& typeGraph, NodeTracker&) { RemoveTopLevelPointer pass; pass.removeTopLevelPointers(typeGraph.rootTypes()); }; diff --git a/oi/type_graph/TopoSorter.cpp b/oi/type_graph/TopoSorter.cpp index 4ad9605..a2c882f 100644 --- a/oi/type_graph/TopoSorter.cpp +++ b/oi/type_graph/TopoSorter.cpp @@ -23,7 +23,7 @@ using ref = std::reference_wrapper; namespace oi::detail::type_graph { Pass TopoSorter::createPass() { - auto fn = [](TypeGraph& typeGraph) { + auto fn = [](TypeGraph& typeGraph, NodeTracker&) { TopoSorter sorter; sorter.sort(typeGraph.rootTypes()); typeGraph.finalTypes = std::move(sorter.sortedTypes()); diff --git a/oi/type_graph/TypeGraph.cpp b/oi/type_graph/TypeGraph.cpp index 13a6b6b..3d513c4 100644 --- a/oi/type_graph/TypeGraph.cpp +++ b/oi/type_graph/TypeGraph.cpp @@ -17,12 +17,6 @@ namespace oi::detail::type_graph { -NodeTracker& TypeGraph::resetTracker() const noexcept { - tracker_.reset(); - tracker_.resize(size()); - return tracker_; -} - template <> Primitive& TypeGraph::makeType(Primitive::Kind kind) { switch (kind) { diff --git a/oi/type_graph/TypeGraph.h b/oi/type_graph/TypeGraph.h index 36b430a..542d770 100644 --- a/oi/type_graph/TypeGraph.h +++ b/oi/type_graph/TypeGraph.h @@ -19,7 +19,6 @@ #include #include -#include "NodeTracker.h" #include "Types.h" namespace oi::detail::type_graph { @@ -48,8 +47,6 @@ class TypeGraph { rootTypes_.push_back(type); } - NodeTracker& resetTracker() const noexcept; - // Override of the generic makeType function that returns singleton Primitive // objects template @@ -88,7 +85,6 @@ class TypeGraph { std::vector> rootTypes_; // Store all type objects in vectors for ownership. Order is not significant. std::vector> types_; - mutable NodeTracker tracker_; NodeId next_id_ = 0; }; diff --git a/oi/type_graph/TypeIdentifier.cpp b/oi/type_graph/TypeIdentifier.cpp index df5e196..a576c8d 100644 --- a/oi/type_graph/TypeIdentifier.cpp +++ b/oi/type_graph/TypeIdentifier.cpp @@ -22,9 +22,8 @@ namespace oi::detail::type_graph { Pass TypeIdentifier::createPass( const std::vector& passThroughTypes) { - auto fn = [&passThroughTypes](TypeGraph& typeGraph) { - TypeIdentifier typeId{typeGraph.resetTracker(), typeGraph, - passThroughTypes}; + auto fn = [&passThroughTypes](TypeGraph& typeGraph, NodeTracker& tracker) { + TypeIdentifier typeId{tracker, typeGraph, passThroughTypes}; for (auto& type : typeGraph.rootTypes()) { typeId.accept(type); } diff --git a/test/test_add_children.cpp b/test/test_add_children.cpp index 1ccb19f..2844759 100644 --- a/test/test_add_children.cpp +++ b/test/test_add_children.cpp @@ -2,6 +2,7 @@ #include "oi/SymbolService.h" #include "oi/type_graph/AddChildren.h" +#include "oi/type_graph/NodeTracker.h" #include "oi/type_graph/Printer.h" #include "oi/type_graph/TypeGraph.h" #include "test_drgn_parser.h" @@ -22,12 +23,13 @@ std::string AddChildrenTest::run(std::string_view function, Type& type = drgnParser.parse(drgnRoot); typeGraph.addRoot(type); + NodeTracker tracker; auto pass = AddChildren::createPass(drgnParser, *symbols_); - pass.run(typeGraph); + pass.run(typeGraph, tracker); std::stringstream out; - Printer printer{out, typeGraph.resetTracker(), typeGraph.size()}; + Printer printer{out, tracker, typeGraph.size()}; printer.print(type); return out.str(); diff --git a/test/test_drgn_parser.cpp b/test/test_drgn_parser.cpp index c3c70b9..0c3a69e 100644 --- a/test/test_drgn_parser.cpp +++ b/test/test_drgn_parser.cpp @@ -6,6 +6,7 @@ // TODO needed?: #include "oi/ContainerInfo.h" #include "oi/OIParser.h" +#include "oi/type_graph/NodeTracker.h" #include "oi/type_graph/Printer.h" #include "oi/type_graph/TypeGraph.h" #include "oi/type_graph/Types.h" @@ -55,7 +56,8 @@ std::string DrgnParserTest::run(std::string_view function, Type& type = drgnParser.parse(drgnRoot); std::stringstream out; - Printer printer{out, typeGraph.resetTracker(), typeGraph.size()}; + NodeTracker tracker; + Printer printer{out, tracker, typeGraph.size()}; printer.print(type); return out.str(); diff --git a/test/test_node_tracker.cpp b/test/test_node_tracker.cpp index 9c59103..b34e83f 100644 --- a/test/test_node_tracker.cpp +++ b/test/test_node_tracker.cpp @@ -82,3 +82,33 @@ TEST(NodeTrackerTest, LargeIds) { EXPECT_TRUE(tracker.visit(myclass1)); EXPECT_TRUE(tracker.visit(myclass2)); } + +TEST(NodeTrackerTest, NodeTrackerHolder) { + Class myclass{0, Class::Kind::Class, "myclass", 0}; + Array myarray{1, myclass, 3}; + + NodeTracker baseTracker_doNotUse; + NodeTrackerHolder holder{baseTracker_doNotUse}; + + { + auto& tracker = holder.get(); + // First visit + EXPECT_FALSE(tracker.visit(myarray)); + EXPECT_FALSE(tracker.visit(myclass)); + + // Second visit + EXPECT_TRUE(tracker.visit(myarray)); + EXPECT_TRUE(tracker.visit(myclass)); + } + + { + auto& tracker = holder.get(); + // First visit, fresh tracker + EXPECT_FALSE(tracker.visit(myarray)); + EXPECT_FALSE(tracker.visit(myclass)); + + // Second visit, fresh tracker + EXPECT_TRUE(tracker.visit(myarray)); + EXPECT_TRUE(tracker.visit(myclass)); + } +} diff --git a/test/type_graph_utils.cpp b/test/type_graph_utils.cpp index 3a96552..7f7071a 100644 --- a/test/type_graph_utils.cpp +++ b/test/type_graph_utils.cpp @@ -3,6 +3,7 @@ #include #include "oi/ContainerInfo.h" +#include "oi/type_graph/NodeTracker.h" #include "oi/type_graph/PassManager.h" #include "oi/type_graph/Printer.h" #include "oi/type_graph/TypeGraph.h" @@ -21,7 +22,8 @@ void check(const TypeGraph& typeGraph, std::string_view expected, std::string_view comment) { std::stringstream out; - type_graph::Printer printer(out, typeGraph.resetTracker(), typeGraph.size()); + NodeTracker tracker; + type_graph::Printer printer(out, tracker, typeGraph.size()); for (const auto& type : typeGraph.rootTypes()) { printer.print(type); @@ -47,7 +49,8 @@ void test(type_graph::Pass pass, // Validate input formatting check(typeGraph, input, "parsing input graph"); - pass.run(typeGraph); + NodeTracker tracker; + pass.run(typeGraph, tracker); check(typeGraph, expectedAfter, "after running pass"); }