From 98008f9054b954fdd1eca10e2d81fd6a75c3ed89 Mon Sep 17 00:00:00 2001 From: Alastair Robertson Date: Thu, 2 Nov 2023 06:05:35 -0700 Subject: [PATCH] TypeGraph: Add IdentifyContainers mutator pass --- oi/type_graph/CMakeLists.txt | 1 + oi/type_graph/IdentifyContainers.cpp | 88 ++++++++++++ oi/type_graph/IdentifyContainers.h | 59 ++++++++ test/CMakeLists.txt | 1 + test/test_identify_containers.cpp | 193 +++++++++++++++++++++++++++ test/type_graph_utils.cpp | 24 ++++ test/type_graph_utils.h | 1 + 7 files changed, 367 insertions(+) create mode 100644 oi/type_graph/IdentifyContainers.cpp create mode 100644 oi/type_graph/IdentifyContainers.h create mode 100644 test/test_identify_containers.cpp diff --git a/oi/type_graph/CMakeLists.txt b/oi/type_graph/CMakeLists.txt index 4e8f622..940c031 100644 --- a/oi/type_graph/CMakeLists.txt +++ b/oi/type_graph/CMakeLists.txt @@ -5,6 +5,7 @@ add_library(type_graph DrgnParser.cpp EnforceCompatibility.cpp Flattener.cpp + IdentifyContainers.cpp KeyCapture.cpp NameGen.cpp PassManager.cpp diff --git a/oi/type_graph/IdentifyContainers.cpp b/oi/type_graph/IdentifyContainers.cpp new file mode 100644 index 0000000..a0622f2 --- /dev/null +++ b/oi/type_graph/IdentifyContainers.cpp @@ -0,0 +1,88 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include "IdentifyContainers.h" + +#include + +#include "TypeGraph.h" +#include "oi/ContainerInfo.h" + +namespace oi::detail::type_graph { + +Pass IdentifyContainers::createPass( + const std::vector>& containers) { + auto fn = [&containers](TypeGraph& typeGraph, NodeTracker&) { + IdentifyContainers typeId{typeGraph, containers}; + for (auto& type : typeGraph.rootTypes()) { + type = typeId.mutate(type); + } + }; + + return Pass("IdentifyContainers", fn); +} + +bool IdentifyContainers::isAllocator(Type& t) { + auto* c = dynamic_cast(&t); + if (!c) + return false; + + // Maybe add more checks for an allocator. + // For now, just test for the presence of an "allocate" function + for (const auto& func : c->functions) { + if (func.name == "allocate") { + return true; + } + } + return false; +} + +IdentifyContainers::IdentifyContainers( + TypeGraph& typeGraph, + const std::vector>& containers) + : tracker_(typeGraph.size()), + typeGraph_(typeGraph), + containers_(containers) { +} + +Type& IdentifyContainers::mutate(Type& type) { + if (Type* mutated = tracker_.get(type)) + return *mutated; + + Type& mutated = type.accept(*this); + tracker_.set(type, mutated); + return mutated; +} + +Type& IdentifyContainers::visit(Class& c) { + for (const auto& containerInfo : containers_) { + if (!std::regex_search(c.fqName(), containerInfo->matcher)) { + continue; + } + + auto& container = typeGraph_.makeType(*containerInfo, c.size()); + container.templateParams = c.templateParams; + + tracker_.set(c, container); + RecursiveMutator::visit(container); + return container; + } + + tracker_.set(c, c); + RecursiveMutator::visit(c); + return c; +} + +} // namespace oi::detail::type_graph diff --git a/oi/type_graph/IdentifyContainers.h b/oi/type_graph/IdentifyContainers.h new file mode 100644 index 0000000..6436701 --- /dev/null +++ b/oi/type_graph/IdentifyContainers.h @@ -0,0 +1,59 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#pragma once + +#include +#include +#include + +#include "NodeTracker.h" +#include "PassManager.h" +#include "Types.h" +#include "Visitor.h" +#include "oi/ContainerInfo.h" + +namespace oi::detail::type_graph { + +class TypeGraph; + +/* + * IdentifyContainers + * + * Walks a flattened type graph and replaces type nodes based on container + * definition TOML files. + */ +class IdentifyContainers : public RecursiveMutator { + public: + static Pass createPass( + const std::vector>& containers); + static bool isAllocator(Type& t); + + IdentifyContainers( + TypeGraph& typeGraph, + const std::vector>& containers); + + using RecursiveMutator::mutate; + + Type& mutate(Type& type) override; + Type& visit(Class& c) override; + + private: + MutationTracker tracker_; + TypeGraph& typeGraph_; + const std::vector>& containers_; +}; + +} // namespace oi::detail::type_graph diff --git a/test/CMakeLists.txt b/test/CMakeLists.txt index 24b8f68..287ebcd 100644 --- a/test/CMakeLists.txt +++ b/test/CMakeLists.txt @@ -43,6 +43,7 @@ add_executable(test_type_graph test_drgn_parser.cpp test_enforce_compatibility.cpp test_flattener.cpp + test_identify_containers.cpp test_key_capture.cpp test_name_gen.cpp test_node_tracker.cpp diff --git a/test/test_identify_containers.cpp b/test/test_identify_containers.cpp new file mode 100644 index 0000000..7a2eab1 --- /dev/null +++ b/test/test_identify_containers.cpp @@ -0,0 +1,193 @@ +#include + +#include "oi/ContainerInfo.h" +#include "oi/type_graph/IdentifyContainers.h" +#include "oi/type_graph/Types.h" +#include "test/type_graph_utils.h" + +using namespace type_graph; + +namespace { +void test(std::string_view input, std::string_view expectedAfter) { + ::test(IdentifyContainers::createPass(getContainerInfos()), input, + expectedAfter); +} +}; // namespace + +TEST(IdentifyContainers, Container) { + test(R"( +[0] Class: std::vector (size: 24) + Param + Primitive: int32_t + Member: a (offset: 0) + Primitive: int32_t +)", + R"( +[1] Container: std::vector (size: 24) + Param + Primitive: int32_t +)"); +} + +TEST(IdentifyContainers, ContainerInClass) { + test(R"( +[0] Class: MyClass (size: 0) + Param +[1] Class: std::vector (size: 24) + Param + Primitive: int32_t + Parent (offset: 0) +[2] Class: std::vector (size: 24) + Param + Primitive: int32_t + Member: a (offset: 0) +[3] Class: std::vector (size: 24) + Param + Primitive: int32_t +)", + R"( +[0] Class: MyClass (size: 0) + Param +[4] Container: std::vector (size: 24) + Param + Primitive: int32_t + Parent (offset: 0) +[5] Container: std::vector (size: 24) + Param + Primitive: int32_t + Member: a (offset: 0) +[6] Container: std::vector (size: 24) + Param + Primitive: int32_t +)"); +} + +TEST(IdentifyContainers, ContainerInContainer) { + test(R"( +[0] Class: std::vector (size: 24) + Param +[1] Class: std::vector (size: 24) + Param + Primitive: int32_t +)", + R"( +[2] Container: std::vector (size: 24) + Param +[3] Container: std::vector (size: 24) + Param + Primitive: int32_t +)"); +} + +TEST(IdentifyContainers, ContainerInContainer2) { + test(R"( +[0] Container: std::vector (size: 24) + Param +[1] Class: std::vector (size: 24) + Param + Primitive: int32_t +)", + R"( +[0] Container: std::vector (size: 24) + Param +[2] Container: std::vector (size: 24) + Param + Primitive: int32_t +)"); +} + +TEST(IdentifyContainers, ContainerInArray) { + test(R"( +[0] Array: (length: 2) +[1] Class: std::vector (size: 24) + Param + Primitive: int32_t +)", + R"( +[0] Array: (length: 2) +[2] Container: std::vector (size: 24) + Param + Primitive: int32_t +)"); +} + +TEST(IdentifyContainers, ContainerInTypedef) { + test(R"( +[0] Typedef: MyAlias +[1] Class: std::vector (size: 24) + Param + Primitive: int32_t +)", + R"( +[0] Typedef: MyAlias +[2] Container: std::vector (size: 24) + Param + Primitive: int32_t +)"); +} + +TEST(IdentifyContainers, ContainerInPointer) { + test(R"( +[0] Pointer +[1] Class: std::vector (size: 24) + Param + Primitive: int32_t +)", + R"( +[0] Pointer +[2] Container: std::vector (size: 24) + Param + Primitive: int32_t +)"); +} + +TEST(IdentifyContainers, ContainerDuplicate) { + test(R"( +[0] Class: std::vector (size: 24) + Param + Primitive: int32_t + Member: a (offset: 0) + Primitive: int32_t + [0] +)", + R"( +[1] Container: std::vector (size: 24) + Param + Primitive: int32_t + [1] +)"); +} + +TEST(IdentifyContainers, CycleClass) { + test(R"( +[0] Class: ClassA (size: 0) + Member: x (offset: 0) +[1] Class: ClassB (size: 0) + Param + [0] +)", + R"( +[0] Class: ClassA (size: 0) + Member: x (offset: 0) +[1] Class: ClassB (size: 0) + Param + [0] +)"); +} + +TEST(IdentifyContainers, CycleContainer) { + test(R"( +[0] Class: ClassA (size: 0) + Member: x (offset: 0) +[1] Class: std::vector (size: 0) + Param + [0] +)", + R"( +[0] Class: ClassA (size: 0) + Member: x (offset: 0) +[2] Container: std::vector (size: 0) + Param + [0] +)"); +} diff --git a/test/type_graph_utils.cpp b/test/type_graph_utils.cpp index 7f7071a..815a46a 100644 --- a/test/type_graph_utils.cpp +++ b/test/type_graph_utils.cpp @@ -59,6 +59,30 @@ void testNoChange(type_graph::Pass pass, std::string_view input) { test(pass, input, input); } +std::vector> getContainerInfos() { + auto std_vector = + std::make_unique("std::vector", SEQ_TYPE, "vector"); + std_vector->stubTemplateParams = {1}; + + auto std_map = std::make_unique("std::map", SEQ_TYPE, "map"); + std_map->stubTemplateParams = {2, 3}; + + auto std_list = + std::make_unique("std::list", SEQ_TYPE, "list"); + std_list->stubTemplateParams = {1}; + + auto std_pair = + std::make_unique("std::pair", SEQ_TYPE, "list"); + + std::vector> containers; + containers.emplace_back(std::move(std_vector)); + containers.emplace_back(std::move(std_map)); + containers.emplace_back(std::move(std_list)); + containers.emplace_back(std::move(std_pair)); + + return containers; +} + Container getVector(NodeId id) { static ContainerInfo info{"std::vector", SEQ_TYPE, "vector"}; info.stubTemplateParams = {1}; diff --git a/test/type_graph_utils.h b/test/type_graph_utils.h index 57654da..216b6b3 100644 --- a/test/type_graph_utils.h +++ b/test/type_graph_utils.h @@ -23,6 +23,7 @@ void test(type_graph::Pass pass, void testNoChange(type_graph::Pass pass, std::string_view input); +std::vector> getContainerInfos(); type_graph::Container getVector(type_graph::NodeId id = 0); type_graph::Container getMap(type_graph::NodeId id = 0); type_graph::Container getList(type_graph::NodeId id = 0);