Coder Social home page Coder Social logo

Comments (1)

pfultz2 avatar pfultz2 commented on July 26, 2024

BTW, here is the script I used to split the ONNX tests; with some tweaking, it can be used for the TF tests as well:

split_tests.py
import sys, os, string

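# Target layout of the split ONNX test directory (this script currently
# writes only verify/ and include/):
# onnx/
#   parse/
#   verify/
#   include/
#   models/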

def collect_test_cases(files):
    """Extract every ``TEST_CASE`` block from the given C++ source files.

    A block starts at a line beginning with ``TEST_CASE`` and ends at the
    next line that begins with ``}`` (a closing brace in column 0). Lines
    outside such blocks are ignored.

    :param files: iterable of paths to C++ test sources.
    :returns: list of strings, each holding one complete test case with its
        original line endings, in file order.
    """
    test_cases = []
    for path in files:
        current = None  # lines of the case being collected, or None
        # A context manager closes each file promptly instead of leaking the
        # handle until garbage collection.
        with open(path) as f:
            for line in f:
                if line.startswith('TEST_CASE'):
                    current = [line]
                elif current:
                    current.append(line)
                # Only a brace in column 0 ends the case; this assumes inner
                # blocks are indented so they don't terminate it early.
                if line.startswith('}'):
                    if current:
                        test_cases.append(''.join(current))
                    current = None
    return test_cases

def get_function_parameter(case, names):
    """Return the string-literal argument of the first matching call.

    Looks for a call of the form ``name("...")`` in *case*, trying each
    name in order, and returns the text between the quotes.

    :param case: source text of a single test case.
    :param names: candidate function names, in priority order.
    :returns: the quoted first argument of the first matching call, or
        ``None`` if no name is called with a string-literal first argument.
    """
    for name in names:
        # Match the call token itself rather than any occurrence of the
        # name, so mentions in comments are not mistaken for the call.
        token = name + '("'
        start = case.find(token)
        if start < 0:
            continue
        start += len(token)
        end = case.find('"', start)
        if end < 0:
            # Unterminated literal: nothing sensible to extract.
            continue
        return case[start:end]
    return None


def group_by(l, select):
    """Bucket the items of *l* by the key *select* computes for each one.

    :param l: iterable of items to group.
    :param select: callable mapping an item to its (hashable) group key.
    :returns: dict mapping each key to the list of its items; keys and
        items both keep first-seen order.
    """
    groups = {}
    for element in l:
        groups.setdefault(select(element), []).append(element)
    return groups

def removesuffix(s):
    """Drop everything from the last ``'.'`` onward, if *s* contains a dot.

    ``'a.b.onnx'`` becomes ``'a.b'``; a string without a dot is returned
    unchanged. Unlike :func:`os.path.splitext`, a leading dot also counts
    (``'.hidden'`` becomes ``''``).
    """
    head, dot, _tail = s.rpartition('.')
    return head if dot else s

def write_to(filename, content):
    """Write *content* to *filename*, replacing any existing file.

    Missing parent directories are created first, so output directories
    such as ``verify/`` or ``include/`` need not exist beforehand.
    """
    directory = os.path.dirname(filename)
    # dirname is '' for a bare file name; makedirs('') would raise.
    if directory:
        os.makedirs(directory, exist_ok=True)
    with open(filename, 'w') as f:
        f.write(content)

# Wraps generated header text in an include guard. Substituted with
# ``name`` (the upper-cased file name without extension, as computed in
# write_header) and ``content`` (the header body).
header_guard_template = string.Template('''
#ifndef MIGRAPHX_GUARD_TEST_ONNX_${name}_HPP
#define MIGRAPHX_GUARD_TEST_ONNX_${name}_HPP

${content}

#endif
''')

# C++ helpers shared by the generated verify tests. main() writes this out
# as include/onnx_verify_utils.hpp, and include_guide maps norm_test,
# mvn_test and gen_trilu_test to that header. The standard headers and
# register_target.hpp are included so the header is self-contained.
verify_test_util = '''
#include <migraphx/program.hpp>
#include <migraphx/make_op.hpp>
#include <migraphx/common.hpp>
#include <migraphx/register_target.hpp>
#include <numeric>
#include <utility>
#include <vector>

template <typename T = float>
std::vector<T> norm_test(const std::vector<size_t>& x_dims,
                         std::vector<T>& scale,
                         std::vector<T>& bias,
                         migraphx::program p)
{
    p.compile(migraphx::make_target("ref"));

    migraphx::shape s_x{migraphx::shape::get_type<T>{}, x_dims};
    migraphx::shape s_s{migraphx::shape::get_type<T>{}, {scale.size()}};
    migraphx::shape s_b{migraphx::shape::get_type<T>{}, {bias.size()}};

    std::vector<T> x(s_x.elements());
    std::iota(std::begin(x), std::end(x), 1);

    migraphx::parameter_map pp;
    pp["x"]     = migraphx::argument(s_x, x.data());
    pp["scale"] = migraphx::argument(s_s, scale.data());
    pp["bias"]  = migraphx::argument(s_b, bias.data());

    auto result = p.eval(pp).back();

    std::vector<T> result_vector;
    result.visit([&](auto output) { result_vector.assign(output.begin(), output.end()); });

    return result_vector;
}

template <typename T = float>
std::vector<T> mvn_test(std::vector<size_t> data_lens, migraphx::program p)
{
    p.compile(migraphx::make_target("ref"));

    migraphx::shape data_shape(migraphx::shape::get_type<T>{}, std::move(data_lens));
    std::vector<T> data(data_shape.elements());
    std::iota(begin(data), end(data), 0);

    migraphx::parameter_map pm;
    pm["data"] = migraphx::argument(data_shape, data.data());

    auto result = p.eval(pm).back();
    std::vector<T> result_vector;
    result.visit([&](auto output) { result_vector.assign(output.begin(), output.end()); });

    return result_vector;
}

inline std::vector<float> gen_trilu_test(const migraphx::shape& s, const migraphx::program& p)
{
    // input data filled with values 1 to nelements
    std::vector<float> x_data(s.elements());
    std::iota(x_data.begin(), x_data.end(), 1);

    migraphx::parameter_map pp;
    pp["x"] = migraphx::argument(s, x_data.data());

    auto result = p.eval(pp).back();
    std::vector<float> result_vector;
    result.visit([&](auto output) { result_vector.assign(output.begin(), output.end()); });
    return result_vector;
}
'''

def write_header(filename, content):
    """Write *content* to *filename*, wrapped in an include guard.

    The guard macro is derived from the file name: its extension is
    stripped, the rest is upper-cased and spliced into
    header_guard_template.
    """
    guard_name = removesuffix(os.path.basename(filename)).upper()
    guarded = header_guard_template.substitute(name=guard_name, content=content)
    write_to(filename, guarded)


# Maps a helper name that may appear in a test case to the header that
# declares it. Matching is by substring, so a test whose name merely
# contains a key (e.g. ``layer_norm_test``) also pulls in that header.
include_guide = {
    'norm_test': 'onnx_verify_utils.hpp',
    'mvn_test': 'onnx_verify_utils.hpp',
    'gen_trilu_test': 'onnx_verify_utils.hpp',
}

def create_includes(case):
    """Return the ``#include`` lines needed by the given test-case text.

    Every header from include_guide whose key occurs in *case* is included
    exactly once, in include_guide order.

    :returns: one ``#include <...>`` line (newline-terminated) per needed
        header, or ``''`` if no helper is referenced.
    """
    headers = []
    for key, include in include_guide.items():
        # Several helpers share one header; keep each header only once.
        if key in case and include not in headers:
            headers.append(include)
    # Emit every needed header, not just the last one matched.
    return ''.join('#include <{}>\n'.format(h) for h in headers)

# Skeleton for one generated verify-test source file. ``includes`` receives
# the helper-header lines from create_includes() and ``test_case`` the
# newline-joined TEST_CASE bodies of one group (see write_case).
parse_template = string.Template('''
#include <migraphx/register_target.hpp>
#include <migraphx/verify.hpp>
#include <onnx_test.hpp>
${includes}

${test_case}

''')

# Single ``main`` written to verify/main.cpp; it hands off to test::run.
# The per-group files produced from parse_template define no main of their own.
main_test = '''
#include <test.hpp>

int main(int argc, const char* argv[]) { test::run(argc, argv); }
'''


def write_case(p, name, cases):
    """Write one group of test cases to ``<p>/<name>.cpp``.

    The cases are joined with newlines and placed into parse_template,
    preceded by whatever helper-header includes they reference.
    """
    body = '\n'.join(cases)
    source = parse_template.substitute(
        includes=create_includes(body),
        test_case=body,
    )
    write_to(os.path.join(p, name + '.cpp'), source)


def get_onnx_file(case):
    """Return the group key of a test case, derived from its ONNX file.

    The key is the first string argument of ``optimize_onnx`` or
    ``parse_onnx`` with its extension dropped and ``/`` replaced by ``_``,
    so it can serve as a flat ``.cpp`` file name.

    :raises ValueError: if the case loads no ONNX file (or an empty name),
        since it then cannot be assigned to an output file.
    """
    param = get_function_parameter(case, ['optimize_onnx', 'parse_onnx'])
    if not param:
        # Without a file name there is no group key; fail loudly and show the
        # offending case instead of crashing later on None.
        raise ValueError("No onnx file found for test case:\n" + case)
    basename = removesuffix(param)
    return basename.replace('/', '_')

def main():
    """Split ``<onnx_dir>/verify_onnx.cpp`` into one source file per ONNX file.

    Usage: ``split_tests.py <onnx_dir>``. Test cases are grouped by the ONNX
    file they load and written to ``<onnx_dir>/verify/<key>.cpp``. The
    script also writes ``verify/main.cpp`` and the shared helper header
    ``include/onnx_verify_utils.hpp``.
    """
    if len(sys.argv) < 2:
        sys.exit("usage: split_tests.py <onnx_dir>")
    onnx_dir = sys.argv[1]
    # TODO: onnx_test.cpp and onnx_rnn_test.cpp (and the onnx_test.hpp /
    # onnx_test_utils.hpp headers) are not split by this script yet.
    verify_onnx = os.path.join(onnx_dir, 'verify_onnx.cpp')
    group_cases = group_by(collect_test_cases([verify_onnx]), get_onnx_file)

    verify_dir = os.path.join(onnx_dir, 'verify')
    include_dir = os.path.join(onnx_dir, 'include')
    # The output directories may not exist yet in the source tree.
    os.makedirs(verify_dir, exist_ok=True)
    os.makedirs(include_dir, exist_ok=True)

    for key, cases in group_cases.items():
        write_case(verify_dir, key, cases)
    write_to(os.path.join(verify_dir, 'main.cpp'), main_test)
    write_header(os.path.join(include_dir, 'onnx_verify_utils.hpp'), verify_test_util)


if __name__ == "__main__":
    main()

from amdmigraphx.

Related Issues (20)

Recommend Projects

  • React photo React

    A declarative, efficient, and flexible JavaScript library for building user interfaces.

  • Vue.js photo Vue.js

    🖖 Vue.js is a progressive, incrementally-adoptable JavaScript framework for building UI on the web.

  • Typescript photo Typescript

    TypeScript is a superset of JavaScript that compiles to clean JavaScript output.

  • TensorFlow photo TensorFlow

    An Open Source Machine Learning Framework for Everyone

  • Django photo Django

    The Web framework for perfectionists with deadlines.

  • D3 photo D3

    Bring data to life with SVG, Canvas and HTML. 📊📈🎉

Recommend Topics

  • javascript

    JavaScript (JS) is a lightweight interpreted programming language with first-class functions.

  • web

    Something interesting about the web. A new door to the world.

  • server

    A server is a program made to process requests and deliver data to clients.

  • Machine learning

    Machine learning is a way of modeling and interpreting data that allows a piece of software to respond intelligently.

  • Game

    Something interesting about games, to make everyone happy.

Recommend Org

  • Facebook photo Facebook

    We are working to build community through open source technology. NB: members must have two-factor auth.

  • Microsoft photo Microsoft

    Open source projects and samples from Microsoft.

  • Google photo Google

    Google ❤️ Open Source for everyone.

  • D3 photo D3

    Data-Driven Documents codes.