Comments (1)
BTW, here is the script I used to split the ONNX tests; with some tweaking, it can be used for the TF tests as well:
split_tests.py
import sys, os, string
# Intended directory layout of the split tests:
# onnx/
#   parse/
#   verify/
#   include/
#   models/
def collect_test_cases(files):
    """Extract every ``TEST_CASE`` body from the given C++ test sources.

    A test case starts at a line beginning with ``TEST_CASE`` and ends at the
    next line beginning with ``}`` (a closing brace in column 0). The lines
    in between, including both delimiter lines, are kept verbatim.

    Args:
        files: iterable of paths to C++ source files.

    Returns:
        A list of test-case source strings, in file order.
    """
    test_cases = []
    for path in files:
        test_case = None
        # Use a context manager so every file handle is closed promptly.
        with open(path) as f:
            for line in f:
                if line.startswith('TEST_CASE'):
                    test_case = line
                elif test_case:
                    test_case += line
                if line.startswith('}'):
                    if test_case:
                        test_cases.append(test_case)
                    test_case = None
    return test_cases
def get_function_parameter(case, names):
    """Return the string-literal argument of the first matching call.

    For each name in *names* (tried in order), look for a call written as
    ``name("...")`` inside *case* and return the text between the quotes.

    Args:
        case: C++ source text of a single test case.
        names: candidate function names, in priority order.

    Returns:
        The quoted argument, or None if none of the names is called with a
        terminated string literal.
    """
    for name in names:
        # Match the call syntax, not just the bare name, so a test whose own
        # name contains e.g. "parse_onnx" is not mistaken for the call.
        opener = name + '("'
        start = case.find(opener)
        if start == -1:
            continue
        i = start + len(opener)
        end = case.find('"', i)
        if end == -1:
            # Unterminated literal: slicing with -1 would return garbage.
            continue
        return case[i:end]
    return None
def group_by(l, select):
    """Bucket the items of *l* under the key that *select* computes for each.

    Keys appear in first-seen order, and items keep their relative order
    inside each bucket.
    """
    buckets = {}
    for element in l:
        buckets.setdefault(select(element), []).append(element)
    return buckets
def removesuffix(s):
    """Drop everything from the last ``.`` in *s* onward.

    A string containing no dot is returned unchanged.
    """
    head, dot, _ = s.rpartition('.')
    return head if dot else s
def write_to(filename, content):
    """Write *content* to *filename*, replacing any existing file.

    Missing parent directories are created first, so the script can write
    into output directories (e.g. ``verify/`` or ``include/``) that do not
    exist yet.
    """
    parent = os.path.dirname(filename)
    if parent:
        os.makedirs(parent, exist_ok=True)
    with open(filename, 'w') as f:
        f.write(content)
# Include-guard wrapper for generated headers: ${name} is replaced with an
# identifier derived from the header's file name and ${content} with the
# header body (both supplied by write_header).
header_guard_template = string.Template('''
#ifndef MIGRAPHX_GUARD_TEST_ONNX_${name}_HPP
#define MIGRAPHX_GUARD_TEST_ONNX_${name}_HPP
${content}
#endif
''')
# C++ body of include/onnx_verify_utils.hpp (written by main via write_header,
# which adds the include guard). norm_test and mvn_test compile the program
# for the "ref" target; all three helpers evaluate the program and return its
# last output as a flat std::vector. The standard headers are listed so the
# header compiles on its own, not only after the test file's own includes.
verify_test_util = '''
#include <numeric>
#include <vector>
#include <migraphx/program.hpp>
#include <migraphx/register_target.hpp>
#include <migraphx/make_op.hpp>
#include <migraphx/common.hpp>
template <typename T = float>
std::vector<T> norm_test(const std::vector<size_t>& x_dims,
std::vector<T>& scale,
std::vector<T>& bias,
migraphx::program p)
{
p.compile(migraphx::make_target("ref"));
migraphx::shape s_x{migraphx::shape::get_type<T>{}, x_dims};
migraphx::shape s_s{migraphx::shape::get_type<T>{}, {scale.size()}};
migraphx::shape s_b{migraphx::shape::get_type<T>{}, {bias.size()}};
std::vector<T> x(s_x.elements());
std::iota(std::begin(x), std::end(x), 1);
migraphx::parameter_map pp;
pp["x"] = migraphx::argument(s_x, x.data());
pp["scale"] = migraphx::argument(s_s, scale.data());
pp["bias"] = migraphx::argument(s_b, bias.data());
auto result = p.eval(pp).back();
std::vector<T> result_vector;
result.visit([&](auto output) { result_vector.assign(output.begin(), output.end()); });
return result_vector;
}
template <typename T = float>
std::vector<T> mvn_test(std::vector<size_t> data_lens, migraphx::program p)
{
p.compile(migraphx::make_target("ref"));
migraphx::shape data_shape(migraphx::shape::get_type<T>{}, std::move(data_lens));
std::vector<T> data(data_shape.elements());
std::iota(begin(data), end(data), 0);
migraphx::parameter_map pm;
pm["data"] = migraphx::argument(data_shape, data.data());
auto result = p.eval(pm).back();
std::vector<T> result_vector;
result.visit([&](auto output) { result_vector.assign(output.begin(), output.end()); });
return result_vector;
}
inline std::vector<float> gen_trilu_test(const migraphx::shape& s, const migraphx::program& p)
{
// input data filled with values 1 to nelements
std::vector<float> x_data(s.elements());
std::iota(x_data.begin(), x_data.end(), 1);
migraphx::parameter_map pp;
pp["x"] = migraphx::argument(s, x_data.data());
auto result = p.eval(pp).back();
std::vector<float> result_vector;
result.visit([&](auto output) { result_vector.assign(output.begin(), output.end()); });
return result_vector;
}
'''
def write_header(filename, content):
    """Write *content* to *filename* wrapped in an include guard.

    The guard identifier is the file's basename without its last extension,
    upper-cased, with every character that is not an ASCII letter or digit
    replaced by ``_`` (e.g. ``my-utils.hpp`` -> ``MY_UTILS``).
    """
    basename = os.path.basename(filename)
    stem = removesuffix(basename).upper()
    # Macro names may only contain [A-Z0-9_]; e.g. '-' or a remaining '.'
    # would otherwise produce an ill-formed #ifndef/#define pair.
    name = ''.join(c if c.isascii() and c.isalnum() else '_' for c in stem)
    write_to(filename, header_guard_template.substitute(content=content, name=name))
# Maps a helper name to the header that declares it: any generated test file
# whose text mentions the helper gets an #include for that header.
include_guide = {
    'norm_test': 'onnx_verify_utils.hpp',
    'mvn_test': 'onnx_verify_utils.hpp',
    'gen_trilu_test': 'onnx_verify_utils.hpp',
}


def create_includes(case):
    """Return the ``#include`` lines needed by the C++ source *case*.

    A header from ``include_guide`` is needed when its key occurs anywhere in
    *case* (plain substring match). Each header is emitted at most once, in
    ``include_guide`` order, as one ``#include <...>`` line.

    Returns:
        The include lines as a single string ("" when none are needed).
    """
    headers = []
    for key, include in include_guide.items():
        # Collect into a list and accumulate, rather than overwriting a
        # string and substring-searching it, so several distinct headers
        # can all be emitted and dedup is exact.
        if key in case and include not in headers:
            headers.append(include)
    return ''.join('#include <{}>\n'.format(h) for h in headers)
# Skeleton of every generated verify test source: ${includes} receives the
# helper-header #include lines chosen by create_includes and ${test_case}
# the newline-joined TEST_CASE bodies (both filled in by write_case).
parse_template = string.Template('''
#include <migraphx/register_target.hpp>
#include <migraphx/verify.hpp>
#include <onnx_test.hpp>
${includes}
${test_case}
''')
# Contents of verify/main.cpp (written by main): the single entry point for
# the split test files. NOTE(review): assumes test::run from test.hpp runs
# the TEST_CASEs compiled in the other split files — confirm.
main_test = '''
#include <test.hpp>
int main(int argc, const char* argv[]) { test::run(argc, argv); }
'''
def write_case(p, name, cases):
    """Emit ``<p>/<name>.cpp`` containing *cases* plus the headers they use."""
    body = '\n'.join(cases)
    source = parse_template.substitute(
        test_case=body,
        includes=create_includes(body),
    )
    write_to(os.path.join(p, name + '.cpp'), source)
def get_onnx_file(case, *, fallback='misc'):
    """Return the grouping key for *case*, derived from its ONNX model file.

    The key is the string argument of the first ``optimize_onnx("...")`` or
    ``parse_onnx("...")`` call. Its last extension is removed and ``/`` is
    replaced by ``_`` so the key can be used as a file name.

    Args:
        case: C++ source of one TEST_CASE.
        fallback: key returned when no model file is found, so the case is
            still written (to ``<fallback>.cpp``) instead of being lost.

    Returns:
        The grouping key string.
    """
    param = get_function_parameter(case, ['optimize_onnx', 'parse_onnx'])
    if not param:
        # Warn but keep going: without this early return, removesuffix(None)
        # would raise a TypeError and abort the whole split.
        print("No onnx file found for test case:\n", case)
        return fallback
    basename = removesuffix(param)
    return basename.replace('/', '_')
def main():
    """Split ``<onnx_dir>/verify_onnx.cpp`` into one source file per model.

    Usage: ``python split_tests.py <onnx_dir>``

    Writes ``<onnx_dir>/verify/<model>.cpp`` for each group of test cases,
    a shared ``verify/main.cpp`` entry point, and
    ``<onnx_dir>/include/onnx_verify_utils.hpp`` with the helpers those
    tests use.
    """
    args = sys.argv
    if len(args) < 2:
        sys.exit('usage: split_tests.py <onnx_dir>')
    onnx_dir = args[1]
    verify_onnx = os.path.join(onnx_dir, 'verify_onnx.cpp')
    # TODO: onnx_test.cpp and onnx_rnn_test.cpp are not split yet; presumably
    # their cases would go under parse/ along with onnx_test.hpp and
    # onnx_test_utils.hpp headers, which this script does not generate.
    group_cases = group_by(collect_test_cases([verify_onnx]), get_onnx_file)
    verify_dir = os.path.join(onnx_dir, 'verify')
    for key, cases in group_cases.items():
        write_case(verify_dir, key, cases)
    write_to(os.path.join(verify_dir, 'main.cpp'), main_test)
    include_dir = os.path.join(onnx_dir, 'include')
    write_header(os.path.join(include_dir, 'onnx_verify_utils.hpp'), verify_test_util)


if __name__ == "__main__":
    main()
from amdmigraphx.
Related Issues (20)
- Add Splitk GEMM test that would split pointwise module
- [Issue]: Fatal error: 'half/half.hpp' file not found HOT 1
- Can't build a Docker image with RX7900 XTX GPU
- Add TF tests to test package
- [WIP] Create hipBLASLt workplan
- Add onnx parser for SkipSimplifiedLayerNormalization
- Add onnx parser for GroupQueryAttention
- Add JIT kernel to implement GroupQueryAttention
- Concat broadcast rewrite to handle more cases
- instruction::replace(const shape&) fails when multiple inputs originate from the same instruction
- MI300: run_high_level_pipeline: Invalid MLIR created HOT 5
- Simplify `log(softmax(x))`
- Add onnx parser for SimplifiedLayerNormalization
- Limit GPU memory allocation of literals to be a predefined value HOT 2
- Investigate HIP memory allocation scheme
- [Issue]: failed at hipMemGetInfo(&free, &total)
- [Issue]: no member named 'exception_ptr' in namespace 'std'
- Longformer reshape simplification
- Scatter elements simplification
Recommended Projects
-
React
A declarative, efficient, and flexible JavaScript library for building user interfaces.
-
Vue.js
🖖 Vue.js is a progressive, incrementally-adoptable JavaScript framework for building UI on the web.
-
Typescript
TypeScript is a superset of JavaScript that compiles to clean JavaScript output.
-
TensorFlow
An Open Source Machine Learning Framework for Everyone
-
Django
The Web framework for perfectionists with deadlines.
-
Laravel
A PHP framework for web artisans
-
D3
Bring data to life with SVG, Canvas and HTML. 📊📈🎉
-
Recommended Topics
-
javascript
JavaScript (JS) is a lightweight interpreted programming language with first-class functions.
-
web
Something interesting about the web. A new door to the world.
-
server
A server is a program made to process requests and deliver data to clients.
-
Machine learning
Machine learning is a way of modeling and interpreting data that allows a piece of software to respond intelligently.
-
Visualization
Something interesting about visualization: data as art.
-
Game
Something interesting about games, to make everyone happy.
Recommended Organizations
-
Facebook
We are working to build community through open source technology. NB: members must have two-factor auth.
-
Microsoft
Open source projects and samples from Microsoft.
-
Google
Google ❤️ Open Source for everyone.
-
Alibaba
Alibaba Open Source for everyone
-
D3
Data-Driven Documents codes.
-
Tencent
China tencent open source team.
from amdmigraphx.