import json
import os
import pathlib
import sys

import toml

# NOTE(review): this file has been damaged by an HTML-style sanitiser that
# deleted "<tag-like>" spans inside string literals (e.g. the argument of every
# "#include <...>", the template arguments of boost::property_tree get<T>(),
# and one large span from "i<iterations" through "std::vector<size_t>" that
# swallowed the tail of add_dispatch_code, the gen_target definition, and the
# head of add_tests).  Reconstructions below are marked; anything tagged
# "TODO(review): restore from VCS" must be recovered from version control,
# not re-typed from memory.  Indentation inside generated C++ strings was
# also collapsed and has been re-chosen here.


def is_thrift_test(config):
    """Return True if this suite config carries inline thrift definitions."""
    return "thrift_definitions" in config


def get_case_name(test_suite, test_case):
    """Flat, unique name for a case: '<suite>_<case>'."""
    return f"{test_suite}_{test_case}"


def get_target_oid_func_name(test_suite, test_case):
    """Symbol name of the generated OID target function for a case."""
    case_name = get_case_name(test_suite, test_case)
    return f"oid_test_case_{case_name}"


def get_target_oil_func_name(test_suite, test_case):
    """Symbol name of the generated OIL target function for a case."""
    case_name = get_case_name(test_suite, test_case)
    return f"oil_test_case_{case_name}"


def get_namespace(test_suite):
    """C++ namespace wrapping all generated code for a suite."""
    return f"ns_{test_suite}"


def add_headers(f, custom_headers, thrift_headers):
    """Emit the #include preamble of the generated target source.

    custom_headers are emitted as <...> system includes, thrift_headers as
    "..." quoted includes.
    """
    # TODO(review): the seven standard header names below were stripped by the
    # sanitiser ("#include <...>" lost its argument) -- restore from VCS.
    f.write(
        """
#include 
#include 
#include 
#include 
#include 
#include 
#include 
"""
    )
    for header in custom_headers:
        f.write(f"#include <{header}>\n")
    for header in thrift_headers:
        f.write(f'#include "{header}"\n')


def add_test_setup(f, config):
    """Emit per-case code into the target source: an object getter plus the
    OID and OIL traceable entry points, all inside the suite's namespace."""
    ns = get_namespace(config["suite"])
    # fmt: off
    f.write(
        f"\n"
        f'{config.get("raw_definitions", "")}\n'
        f"namespace {ns} {{\n"
        f'{config.get("definitions", "")}\n'
    )
    # fmt: on

    def define_traceable_func(name, params, body):
        # extern "C" + noinline keeps the symbol unmangled and the call frame
        # intact so the tracer can attach to it by name.
        return (
            f"\n"
            f'  extern "C" {{\n'
            f"  void __attribute__((noinline)) {name}({params}) {{\n"
            f"{body}"
            f"  }}\n"
            f"  }}\n"
        )

    cases = config["cases"]
    for case_name, case in cases.items():
        # Generate the getter for an object of this type.  arg_types, when
        # present, overrides the tuple element types returned by the getter.
        param_types = ", ".join(
            f"std::remove_cvref_t<{param}>" for param in case["param_types"]
        )
        if "arg_types" in case:
            arg_types = ", ".join(case["arg_types"])
        else:
            arg_types = param_types
        f.write(
            f"\n"
            f"  std::tuple<{arg_types}> get_{case_name}() {{\n"
            f'{case["setup"]}\n'
            f"  }}\n"
        )

        # Generate the oid and oil targets.
        params_str = ", ".join(
            f"{param} a{i}" for i, param in enumerate(case["param_types"])
        )

        # OID target: print each argument's address, then the function name
        # as a sync marker.
        oid_func_body = "".join(
            f"    std::cout << (uintptr_t)(&a{i}) << std::endl;\n"
            for i in range(len(case["param_types"]))
        )
        oid_func_body += "    std::cout << BOOST_CURRENT_FUNCTION << std::endl;\n"

        f.write(
            define_traceable_func(
                get_target_oid_func_name(config["suite"], case_name),
                params_str,
                oid_func_body,
            )
        )

        # OIL target: measure each argument in-process and emit a JSON
        # results document on stdout.
        oil_func_body = (
            f"\n"
            f"  ObjectIntrospection::options opts{{\n"
            f'    .configFilePath = std::getenv("CONFIG_FILE_PATH"),\n'
            f"    .debugLevel = 3,\n"
            f'    .sourceFileDumpPath = "oil_jit_code.cpp",\n'
            f"    .forceJIT = true,\n"
            f"}};"
        )
        oil_func_body += '  std::cout << "{\\"results\\": [" << std::endl;\n'
        oil_func_body += '  std::cout << "," << std::endl;\n'.join(
            f"    size_t size{i} = 0;\n"
            f"    auto ret{i} = ObjectIntrospection::getObjectSize(&a{i}, &size{i}, opts);\n"
            f'    std::cout << "{{\\"ret\\": " << ret{i} << ", \\"size\\": " << size{i} << "}}" << std::endl;\n'
            for i in range(len(case["param_types"]))
        )
        oil_func_body += '  std::cout << "]}" << std::endl;\n'

        f.write(
            define_traceable_func(
                get_target_oil_func_name(config["suite"], case_name),
                params_str,
                oil_func_body,
            )
        )

    f.write(f"}} // namespace {ns}\n")


def add_common_code(f):
    """Emit the main() boilerplate shared by all generated targets: argument
    parsing for `oid/oil CASE [ITER]` with a 1000-iteration default."""
    f.write(
        """
int main(int argc, char *argv[]) {
  if (argc < 3 || argc > 4) {
    std::cerr << "usage: " << argv[0] << " oid/oil CASE [ITER]" << std::endl;
    return -1;
  }
  std::string mode = argv[1];
  std::string test_case = argv[2];

  int iterations = 1000;
  if (argc == 4) {
    std::istringstream iss(argv[3]);
    iss >> iterations;
    if (iss.fail())
      iterations = 1000;
  }
"""
    )


def add_dispatch_code(f, config):
    """Emit the if-chain in the target's main() mapping a CASE name to its
    getter and per-mode entry point."""
    ns = get_namespace(config["suite"])
    for case_name in config["cases"]:
        case_str = get_case_name(config["suite"], case_name)
        oil_func_name = get_target_oil_func_name(config["suite"], case_name)
        oid_func_name = get_target_oid_func_name(config["suite"], case_name)
        # TODO(review): everything from "<iterations" in the loop header to the
        # end of this dispatch snippet -- presumably the per-iteration calls to
        # oid_func_name / oil_func_name on the unpacked tuple `val`, selected
        # by `mode` -- was swallowed by the sanitiser.  Restore from VCS.
        # (oil_func_name / oid_func_name are kept bound above because the
        # swallowed f-string text referenced them.)
        f.write(
            f'  if (test_case == "{case_str}") {{\n'
            f"    auto val = {ns}::get_{case_name}();\n"
            f"    for (int i=0; i"
        )


# TODO(review): gen_target(output_target, test_configs), called from main(),
# is not present in this file as recovered -- it was most likely inside the
# span the sanitiser swallowed between add_dispatch_code and add_tests.
# Restore it from VCS.


def add_tests(f, config):
    """Emit one gtest per case into the runner source (called by gen_runner).

    TODO(review): the head of this function -- its original signature, skip
    handling via generate_skip, and the code that runs the target and parses
    its stdout into `result_json` -- was swallowed by the sanitiser; only the
    verification tail survives.  The loop header below is a minimal
    reconstruction (the tail uses both `case_name` and `case`); restore the
    real text from VCS.
    """
    for case_name, case in config["cases"].items():
        # Collect the per-argument sizes reported by the OIL target's JSON
        # output and assert each call returned success (0).
        # NOTE(review): the three get<...>/vector<...> template arguments were
        # stripped by the sanitiser; reconstructed from the declared C++
        # variable types on the same lines.
        f.write(
            f"  std::vector<size_t> sizes;\n"
            f'  for (const auto& each : result_json.get_child("results")) {{\n'
            f"    const auto& result = each.second;\n"
            f'    int oilResult = result.get<int>("ret");\n'
            f'    size_t oilSize = result.get<size_t>("size");\n'
            f"    ASSERT_EQ(oilResult, 0);\n"
            f"    sizes.push_back(oilSize);\n"
            f"  }}"
        )

        if "expect_json" in case:
            # Validate the expected JSON at generation time so a typo in a
            # .toml fails fast here rather than at C++ compile/run time.
            try:
                json.loads(case["expect_json"])
            except json.decoder.JSONDecodeError as error:
                print(
                    f"\x1b[31m`expect_json` value for test case {config['suite']}.{case_name} was invalid JSON: {error}\x1b[0m",
                    file=sys.stderr,
                )
                sys.exit(1)

            # Compare each measured size against staticSize + dynamicSize of
            # the corresponding expected node.
            # NOTE(review): the two node.get<size_t> template arguments were
            # stripped by the sanitiser; reconstructed from `size_t
            # expected_size` on the preceding line.
            f.write(
                f"\n"
                f"  std::stringstream expected_json_ss;\n"
                f'  expected_json_ss << R"--({case["expect_json"]})--";\n'
                f"  bpt::ptree expected_json;\n"
                f"  bpt::read_json(expected_json_ss, expected_json);\n"
                f"  auto sizes_it = sizes.begin();\n"
                f"  for (auto it = expected_json.begin(); it != expected_json.end(); ++it, ++sizes_it) {{\n"
                f"    auto node = it->second;\n"
                f'    size_t expected_size = node.get<size_t>("staticSize");\n'
                f'    expected_size += node.get<size_t>("dynamicSize");\n'
                f"    EXPECT_EQ(*sizes_it, expected_size);\n"
                f"  }}\n"
            )

        f.write(f"}}\n")


def generate_skip(case, specific):
    """Return gtest code that conditionally GTEST_SKIPs this case.

    `specific` names the tool ("oid"/"oil"); a per-tool "<tool>_skip" key
    takes precedence over the generic "skip" key.  A string value becomes
    the skip message; a bare truthy value skips silently.  Skips are
    bypassed at runtime when run_skipped_tests is set.
    """
    possibly_skip = ""
    skip_reason = case.get("skip", False)
    specific_skip_reason = case.get(f"{specific}_skip", False)
    if specific_skip_reason or skip_reason:
        possibly_skip += "  if (!run_skipped_tests) {\n"
        possibly_skip += "    GTEST_SKIP()"
        # isinstance (not type(...) ==) is the idiomatic check and also
        # accepts str subclasses.
        if isinstance(specific_skip_reason, str):
            possibly_skip += f' << "{specific_skip_reason}"'
        elif isinstance(skip_reason, str):
            possibly_skip += f' << "{skip_reason}"'
        possibly_skip += ";\n"
        possibly_skip += "  }\n"
    return possibly_skip


def gen_runner(output_runner_name, test_configs):
    """Write the gtest runner source that exercises every suite's cases."""
    with open(output_runner_name, "w") as f:
        # TODO(review): the seven "#include <...>" header names below were
        # stripped by the sanitiser -- restore from VCS.
        f.write(
            "#include \n"
            "#include \n"
            "#include \n"
            "#include \n"
            "#include \n"
            "#include \n"
            "#include \n"
            '#include "runner_common.h"\n'
            "\n"
            "namespace ba = boost::asio;\n"
            "namespace bpt = boost::property_tree;\n"
            "\n"
            "using ::testing::MatchesRegex;\n"
            "\n"
            "extern bool run_skipped_tests;\n"
        )
        for config in test_configs:
            add_tests(f, config)


def gen_thrift(test_configs):
    """Write a <suite>.thrift file for each suite that defines thrift types."""
    for config in test_configs:
        if not is_thrift_test(config):
            continue
        output_thrift_name = f"{config['suite']}.thrift"
        with open(output_thrift_name, "w") as f:
            f.write(config["thrift_definitions"])
        print(f"Thrift out: {output_thrift_name}")


def main():
    """CLI entry point: gen_tests.py OUTPUT_TARGET OUTPUT_RUNNER INPUT1 [...].

    Inputs are .toml suite files or directories to walk for them; each suite
    name (the file stem) must be unique.
    """
    if len(sys.argv) < 4:
        print("Usage: gen_tests.py OUTPUT_TARGET OUTPUT_RUNNER INPUT1 [INPUT2 ...]")
        sys.exit(1)

    output_target = sys.argv[1]
    output_runner = sys.argv[2]
    inputs = sys.argv[3:]

    print(f"Output target: {output_target}")
    print(f"Output runner: {output_runner}")
    print(f"Input files: {inputs}")

    test_configs = []
    test_suites = set()
    while len(inputs) > 0:
        test_path = inputs.pop()
        if test_path.endswith(".toml"):
            test_suite = pathlib.Path(test_path).stem
            if test_suite in test_suites:
                raise Exception(f"Test suite {test_suite} is defined multiple times")
            test_suites.add(test_suite)
            config = toml.load(test_path)
            config["suite"] = test_suite
            test_configs += [config]
        elif os.path.isdir(test_path):
            # Queue every .toml found under the directory for the next
            # iterations of the while loop.
            for root, _dirs, files in os.walk(test_path):
                for name in files:
                    if name.endswith(".toml"):
                        path = os.path.join(root, name)
                        # BUG FIX: this print was missing its f-prefix, so
                        # the literal text "{path}" was printed.
                        print(f"Found definition file at {path}")
                        inputs.append(path)
        else:
            raise Exception(
                "Test definition inputs must have the '.toml' extension or be a directory"
            )

    gen_target(output_target, test_configs)
    gen_runner(output_runner, test_configs)
    gen_thrift(test_configs)


if __name__ == "__main__":
    main()