JittorMirror/src/utils/cache_compile.cc

349 lines
11 KiB
C++

// ***************************************************************
// Copyright (c) 2020 Jittor. Authors: Dun Liang <randonlang@gmail.com>. All Rights Reserved.
// This file is subject to the terms and conditions defined in
// file 'LICENSE.txt', which is part of this source code package.
// ***************************************************************
#include <fstream>
#include <streambuf>
#include "misc/hash.h"
#include "utils/cache_compile.h"
namespace jittor {
namespace jit_compiler {
#ifndef TEST
// Return the entire contents of `fname`, or an empty string when the file
// cannot be opened.
string read_all(const string& fname) {
    std::ifstream stream(fname);
    if (!stream)
        return "";
    std::istreambuf_iterator<char> begin(stream);
    std::istreambuf_iterator<char> end;
    return string(begin, end);
}
// Overwrite `fname` with `src`. Open/write failures are silently ignored.
void write(const string& fname, const string& src) {
    std::ofstream out(fname);
    out << src;
}
// True when `fname` can be opened for reading.
bool file_exist(const string& fname) {
    return std::ifstream(fname).good();
}
#endif
// Join two '/'-separated path fragments.
// An absolute `b` (leading '/') is returned unchanged. Otherwise a single '/'
// is inserted between a non-empty `a` and `b` unless `a` already ends in '/'.
string join(string a, string b) {
    constexpr char kSep = '/';
    const bool b_is_absolute = b.size() > 0 && b[0] == kSep;
    if (b_is_absolute)
        return b;
    const bool needs_sep = a.size() > 0 && a[a.size() - 1] != kSep;
    return needs_sep ? a + kSep + b : a + b;
}
// Parse a compiler command line into its input files, its output file and
// extra options.
//   cmd         - full command line; its first space-separated word (the
//                 compiler itself) is skipped.
//   input_names - receives, in order, every word that is neither an option
//                 nor an option argument, with single quotes stripped.
//   output_name - receives the argument of "-o <name>"; must be empty on entry.
//                 Only the separated form "-o name" is recognised; "-oname" is
//                 treated as an ordinary option and skipped.
//   extra       - extra["I"] receives each include directory, written as
//                 -Ixxx, -I'xxx' or -I xxx.
// "-x" consumes its argument (attached, or separated by spaces). Any other
// "-..." word is skipped. Single-quoted spans may contain spaces.
// CHECK/ASSERT fail when cmd contains no space, when "-o " or "-I " is
// followed only by spaces, when "-o" appears twice, when a single quote is
// unterminated, or when no output name or no input name is found.
void find_names(string cmd, vector<string>& input_names, string& output_name, map<string,vector<string>>& extra) {
size_t i=0;
// Skip the first word (the compiler).
while (i<cmd.size() && cmd[i] != ' ') i++;
CHECK(i<cmd.size());
// pass(j): advance j to the first space that is not inside single quotes,
// or to the end of cmd. A j that already points at a space is left unchanged.
auto pass = [&](size_t& j) {
while (j<cmd.size()) {
if (cmd[j]=='\'') {
// Jump over the quoted span, including both quote characters.
j++;
while (j<cmd.size() && cmd[j]!='\'') j++;
ASSERT(j<cmd.size());
j++;
continue;
}
while (j<cmd.size() && cmd[j]!=' ' && cmd[j]!='\'') j++;
if (j<cmd.size()) {
if (cmd[j]==' ') break;
if (cmd[j]=='\'') continue;
}
}
};
// substr(i, j): copy cmd[i, j) with every single-quote character removed.
auto substr = [&](size_t i, size_t j) -> string {
string s;
for (size_t k=i; k<j; k++)
if (cmd[k]!='\'') s += cmd[k];
return s;
};
while (i<cmd.size()) {
if (cmd[i] == ' ') {
i++;
continue;
}
if (cmd[i] == '-') {
if (i+2<cmd.size() && cmd[i+1]=='o' && cmd[i+2]==' ') {
auto j=i+3;
while (j<cmd.size() && cmd[j] == ' ') j++;
CHECK(j<cmd.size());
auto k=j;
pass(k);
CHECK(j<k && output_name.size()==0);
// -o xxx
// i  j  k
output_name = substr(j, k);
i = k;
continue;
} else if (i+2<cmd.size() && cmd[i+1]=='I') {
// -Ixxx -I'xxx' -I xxx
size_t j=i+2;
while (j<cmd.size() && cmd[j]==' ') j++;
size_t k=j;
pass(k);
CHECK(j<k);
auto inc = substr(j, k);
// Record the include directory and continue after it.
i = k;
extra["I"].push_back(inc);
continue;
} else if (i+2<cmd.size() && cmd[i+1]=='x') {
// -x takes one argument, attached (-xc++) or after spaces (-x c++);
// skip the option together with that argument.
size_t j=i+2;
while (j<cmd.size() && cmd[j]==' ') j++;
size_t k=j;
pass(k);
i = k;
continue;
} else {
// Any other option: skip just this word.
pass(i);
continue;
}
}
// Plain word: an input file.
auto j=i;
pass(j);
input_names.push_back(substr(i, j));
i = j;
}
CHECK(output_name.size() && input_names.size())
<< "output_name: " << output_name
<< " input_names: " << input_names << "\n" << cmd;
}
// Return the index just past a comment that starts at src[i], or i itself
// when no comment starts there (or i is at/after the end of src).
//  - "// ..." is skipped through and including the terminating '\n'
//    (to the end of src if there is none).
//  - "/* ... */" is skipped through the closing "*/"; an unterminated block
//    comment consumes the rest of src.
// String and character literals are not recognised, so a "/*" inside a
// literal is still treated as a comment opener.
size_t skip_comments(const string& src, size_t i) {
    if (i + 1 >= src.size() || src[i] != '/')
        return i;
    if (src[i+1] == '/') {
        size_t j = i + 2;
        while (j < src.size() && src[j] != '\n') j++;
        return j < src.size() ? j + 1 : j;
    }
    if (src[i+1] == '*') {
        // The closing "*/" must begin after the opening "/*". Searching from
        // i+2 stops "/*/" (the opener's '*' followed by '/') from being taken
        // as a complete comment.
        size_t j = i + 2;
        while (j + 1 < src.size() && !(src[j] == '*' && src[j+1] == '/')) j++;
        return j + 1 < src.size() ? j + 2 : src.size();
    }
    return i;
}
// Scan C/C++ source text and append to input_names, in order of appearance,
// the name of every quoted include (`#include "name"`).
// Angle-bracket includes and empty names are ignored, as are "test.h" and
// "helper_cuda.h". Comments are skipped with skip_comments(); string literals
// are not parsed.
void process(string src, vector<string>& input_names) {
    // Separators inside one line, and line terminators (CRLF-safe).
    auto is_blank = [](char c) { return c == ' ' || c == '\t'; };
    auto is_eol = [](char c) { return c == '\n' || c == '\r'; };
    for (size_t i=0; i<src.size(); i++) {
        // Skip a whole run of consecutive comments. Skipping only one would
        // leave the scanner inside the next comment, so "//a\n//#include" would
        // report an include that is commented out.
        while (i < src.size()) {
            size_t next = skip_comments(src, i);
            if (next == i) break;
            i = next;
        }
        if (i>=src.size()) break;
        if (src[i] != '#') continue;
        // #include "a.h"
        // i       jk    l
        // The directive name ends at a blank OR at end of line. Otherwise
        // "#endif\n#include \"b.h\"" would read as one token and b.h would
        // never be recorded.
        size_t j=i+1;
        while (j<src.size() && !is_blank(src[j]) && !is_eol(src[j])) j++;
        if (j>=src.size()) return;
        if (is_eol(src[j])) { i = j; continue; }
        size_t k=j+1;
        while (k<src.size() && is_blank(src[k])) k++;
        if (k>=src.size()) return;
        // Directive with no argument on this line: don't consume the next line.
        if (is_eol(src[k])) { i = k; continue; }
        size_t l=k+1;
        while (l<src.size() && !is_blank(src[l]) && !is_eol(src[l])) l++;
        // Require a quoted, non-empty argument: "x" spans at least 3 chars.
        if (j-i==8 && src.compare(i, 8, "#include") == 0
            && l-k >= 3 && src[k]=='"' && src[l-1]=='"') {
            auto inc = src.substr(k+1, l-k-2);
            if (inc != "test.h" && inc != "helper_cuda.h") {
                LOGvvvv << "Found include" << inc;
                input_names.push_back(inc);
            }
        }
        i=l;
    }
}
// Run compiler command `cmd` unless its stored cache key shows the output is
// up to date.
//   cmd         - compiler command; parsed by find_names() for the input files,
//                 the "-o" output and the "-I" include directories.
//   cache_path  - if non-empty, the command runs as "cd <cache_path> && cmd";
//                 includes starting with "jit/" or "gen/" resolve under it.
//   jittor_path - other relative includes resolve under <jittor_path>/src,
//                 then under each -I directory in order.
// The cache key is the command text followed by one "# <file>: <hash64>" line
// per input file and per (transitively) quoted #include. It is stored in
// "<output>.key". The command runs when that file is missing/empty or when the
// key has changed.
// Returns true if the command was executed.
// NOTE(review): "<output>.key" and the input/include files are read relative to
// the current working directory, while the command itself runs inside
// cache_path. Presumably callers pass absolute paths; confirm.
// NOTE(review): the output file itself is never checked for existence, so a
// deleted output whose key is unchanged is not rebuilt.
bool cache_compile(const string& cmd, const string& cache_path, const string& jittor_path) {
vector<string> input_names;
map<string,vector<string>> extra;
string output_name;
find_names(cmd, input_names, output_name, extra);
string output_cache_key;
bool ran = false;
output_cache_key = read_all(output_name+".key");
string cd_cmd = cache_path.size() ? "cd " + cache_path + " && " + cmd : cmd;
// No previous key: compile unconditionally (the key is computed below).
if (output_cache_key.size() == 0) {
LOGvv << "Cache key of" << output_name << "not found.";
LOGvvv << "Run cmd:" << cmd;
system_with_check(cd_cmd.c_str());
ran = true;
}
string cache_key = cmd;
cache_key += "\n";
unordered_set<string> processed;
auto src_path = join(jittor_path, "src");
const auto& extra_include = extra["I"];
// input_names grows while it is iterated: each file's quoted includes are
// appended and hashed in later iterations, so dependencies are followed
// transitively. `processed` avoids hashing the same path twice.
for (size_t i=0; i<input_names.size(); i++) {
if (processed.count(input_names[i]) != 0)
continue;
processed.insert(input_names[i]);
auto src = read_all(input_names[i]);
ASSERT(src.size()) << "Source read failed:" << input_names[i];
auto hash = S(hash64(src));
vector<string> new_names;
process(src, new_names);
for (auto& name : new_names) {
// Resolve the include: generated files under cache_path, absolute
// paths as-is, everything else under <jittor_path>/src ...
string full_name;
if (name.substr(0, 4) == "jit/" || name.substr(0, 4) == "gen/")
full_name = join(cache_path, name);
else if (name.size() && name[0]=='/')
full_name = name;
else
full_name = join(src_path, name);
// ... falling back to the -I directories, in command-line order.
if (!file_exist(full_name)) {
bool found = 0;
for (const auto& inc : extra_include) {
full_name = join(inc, name);
if (file_exist(full_name)) {
found = 1;
break;
}
}
ASSERT(found) << "Include file" << name << "not found in" << extra_include
>> "\nCommands:" << cmd;
LOGvvvv << "Include file found:" << full_name;
}
input_names.push_back(full_name);
}
cache_key += "# ";
cache_key += input_names[i];
cache_key += ": ";
cache_key += hash;
cache_key += "\n";
}
// A previous key existed but no longer matches: recompile.
if (output_cache_key.size() != 0 && output_cache_key != cache_key) {
LOGvv << "Cache key of" << output_name << "changed.";
LOGvvv << "Run cmd:" << cmd;
system_with_check(cd_cmd.c_str());
ran = true;
}
// Persist the new key whenever it differs from what was stored.
if (output_cache_key != cache_key) {
LOGvvvv << "Prev cache key" << output_cache_key;
LOGvvvv << "Write cache key" << output_name+".key:\n" >> cache_key;
write(output_name+".key", cache_key);
}
if (!ran)
LOGvv << "Command cached:" << cmd;
return ran;
}
} // jit_compiler
} // jittor
#ifdef TEST
#include "test.h"
static unordered_map<string,string> files;
namespace jittor {
namespace jit_compiler {
// Test doubles for the disk helpers (TEST builds only): every "file" lives
// in the in-memory `files` map, keyed by its path.

// Contents of `fname`, or "" when it is not in the map. No entry is created.
string read_all(const string& fname) {
    const auto hit = files.find(fname);
    if (hit == files.end())
        return "";
    return hit->second;
}
// Create or overwrite the entry for `fname`.
void write(const string& fname, const string& src) {
    files[fname] = src;
}
// True when `fname` has an entry in the map.
bool file_exist(const string& fname) {
    return files.find(fname) != files.end();
}
}
}
// Parse `cmd` with find_names and CHECK that the output name, the ordered
// input names and the extra-option map (e.g. "I" -> include dirs) equal the
// expected values.
void test_find_names(string cmd, vector<string> input_names, string output_name, map<string,vector<string>> extra={}) {
    LOGvv << cmd;
    vector<string> got_inputs;
    string got_output;
    map<string,vector<string>> got_extra;
    jittor::jit_compiler::find_names(cmd, got_inputs, got_output, got_extra);

    CHECKop(got_output,==,output_name);

    CHECKop(got_inputs.size(),==,input_names.size());
    for (size_t idx=0; idx<got_inputs.size(); idx++)
        CHECKop(got_inputs[idx],==,input_names[idx]);

    CHECKop(got_extra.size(),==,extra.size());
    for (auto& entry : extra) {
        const auto& want = entry.second;
        auto& have = got_extra[entry.first];
        CHECKop(have.size(),==,want.size());
        for (size_t idx=0; idx<have.size(); idx++)
            CHECKop(have[idx],==,want[idx]);
    }
}
// Expect find_names to raise an error (through its CHECK/ASSERT) on `cmd`.
void test_find_nams_error(string cmd) {
    auto parse = [&]() {
        vector<string> inputs;
        string output;
        map<string, vector<string>> extras;
        jittor::jit_compiler::find_names(cmd, inputs, output, extras);
    };
    expect_error(parse);
}
// Run process() over `src` and CHECK that the includes it reports equal
// `files`, in order. (The parameter `files` shadows the global in-memory map.)
void test_process(string src, vector<string> files) {
    vector<string> found;
    jittor::jit_compiler::process(src, found);
    CHECK(files.size() == found.size());
    for (size_t idx=0; idx<files.size(); idx++)
        CHECKop(files[idx],==,found[idx]);
}
// Self-test entry point (TEST builds only). It exercises find_names, process
// and cache_compile, the latter against the in-memory `files` map instead of
// the disk.
void test_main() {
    using jittor::jit_compiler::cache_compile;
    test_find_names("g++ a.cc b.cc -afdsf -xvs c.o -o asd",
        {"a.cc", "b.cc", "c.o"}, "asd");
    test_find_names("g++ -o asd a.cc b.cc -afdsf -xvs c.o",
        {"a.cc", "b.cc", "c.o"}, "asd");
    test_find_names("g++ -o asd 'a ().cc' b.cc -afdsf -xvs c.o",
        {"a ().cc", "b.cc", "c.o"}, "asd");
    test_find_nams_error("g++ -o");
    test_find_nams_error("g++ -o ");
    test_find_nams_error("g++ -o asd");
    // test include -I option
    test_find_names("g++ a.cc b.cc -I/a/b -I'/a a/b' -I 'a/ a/' -afdsf -xvs c.o -o asd",
        {"a.cc", "b.cc", "c.o"}, "asd", {{"I",{"/a/b","/a a/b","a/ a/"}}});
    test_process("", {});
    test_process("#inc <asd>", {});
    test_process("#include <asd>", {});
    test_process("#include \"asd\"", {"asd"});
    test_process("//#include \"asd\"", {});
    test_process("/*#include \"asd\"*/", {});
    test_process("#include \"asd\"\n#include \"zxc\"", {"asd", "zxc"});
    // First call compiles and writes the key; an identical call is cached;
    // changing a header, a source or the command recompiles.
    files = {{"src/a.h", "xxx"}, {"src/a.cc", "#include \"a.h\"\nxxx"}};
    CHECK(cache_compile("echo src/a.cc -o a.o"));
    CHECK(files.count("a.o.key"));
    CHECK(!cache_compile("echo src/a.cc -o a.o"));
    files["src/a.h"] ="xxxx";
    CHECK(cache_compile("echo src/a.cc -o a.o"));
    files["src/a.cc"] ="xxxx";
    CHECK(cache_compile("echo src/a.cc -o a.o"));
    CHECK(cache_compile("echo src/a.cc -ff -o a.o"));
    // test include
    files = {{"ex/a.h", "xxx"}, {"src/a.cc", "#include \"a.h\"\nxxx"}};
    CHECK(cache_compile("echo src/a.cc -Iex -o a.o"));
    CHECK(files.count("a.o.key"));
    // find() returns string::npos when the substring is absent. The previous
    // `>= 0` test on the unsigned result was always true.
    CHECK(files["a.o.key"].find("ex/a.h") != string::npos);
    // Without -Iex the include "a.h" cannot be resolved, so this must fail.
    expect_error([&]() {
        cache_compile("echo src/a.cc -o a.o");
    });
}
#endif