Docs: script to auto-generate ggml operations docs (#14598)
* Docs: script to auto-generate ggml operations docs * Review: formatting changes + change github action * Use built-in types instead of typing * docs : add BLAS and Metal ops --------- Co-authored-by: Georgi Gerganov <ggerganov@gmail.com>
This commit is contained in:
parent
a457551332
commit
11ee0fea2a
|
@ -0,0 +1,40 @@
|
|||
name: Update Operations Documentation
|
||||
|
||||
on:
|
||||
push:
|
||||
paths:
|
||||
- 'docs/ops/**'
|
||||
- 'scripts/create_ops_docs.py'
|
||||
pull_request:
|
||||
paths:
|
||||
- 'docs/ops/**'
|
||||
- 'scripts/create_ops_docs.py'
|
||||
|
||||
jobs:
|
||||
update-ops-docs:
|
||||
runs-on: ubuntu-latest
|
||||
|
||||
steps:
|
||||
- name: Checkout repository
|
||||
uses: actions/checkout@v4
|
||||
|
||||
- name: Set up Python
|
||||
uses: actions/setup-python@v5
|
||||
with:
|
||||
python-version: '3.x'
|
||||
|
||||
- name: Generate operations documentation to temporary file
|
||||
run: |
|
||||
mkdir -p /tmp/ops_check
|
||||
./scripts/create_ops_docs.py /tmp/ops_check/ops.md
|
||||
|
||||
- name: Check if docs/ops.md matches generated version
|
||||
run: |
|
||||
if ! diff -q docs/ops.md /tmp/ops_check/ops.md; then
|
||||
echo "Operations documentation (docs/ops.md) is not up to date with the backend CSV files."
|
||||
echo "To fix: run ./scripts/create_ops_docs.py and commit the updated docs/ops.md along with your changes"
|
||||
echo "Differences found:"
|
||||
diff docs/ops.md /tmp/ops_check/ops.md || true
|
||||
exit 1
|
||||
fi
|
||||
echo "Operations documentation is up to date."
|
|
@ -0,0 +1,95 @@
|
|||
# GGML Operations
|
||||
|
||||
List of GGML operations and backend support status.
|
||||
|
||||
Legend:
|
||||
- ✅ Fully supported by this backend
|
||||
- 🟡 Partially supported by this backend
|
||||
- ❌ Not supported by this backend
|
||||
|
||||
| Operation | BLAS | CPU | CUDA | Metal |
|
||||
|-----------|------|------|------|------|
|
||||
| ABS | ❌ | ✅ | 🟡 | ❌ |
|
||||
| ACC | ❌ | ✅ | ✅ | ✅ |
|
||||
| ADD | ❌ | ✅ | ✅ | 🟡 |
|
||||
| ADD1 | ❌ | ✅ | ✅ | ❌ |
|
||||
| ARANGE | ❌ | ✅ | ✅ | ✅ |
|
||||
| ARGMAX | ❌ | ✅ | ✅ | ✅ |
|
||||
| ARGSORT | ❌ | ✅ | ✅ | ✅ |
|
||||
| CLAMP | ❌ | ✅ | ✅ | 🟡 |
|
||||
| CONCAT | ❌ | ✅ | 🟡 | ✅ |
|
||||
| CONT | ❌ | ✅ | 🟡 | ✅ |
|
||||
| CONV_2D_DW | ❌ | ✅ | ✅ | ❌ |
|
||||
| CONV_TRANSPOSE_1D | ❌ | ✅ | ✅ | ✅ |
|
||||
| CONV_TRANSPOSE_2D | ❌ | ✅ | ✅ | ❌ |
|
||||
| COS | ❌ | ✅ | ✅ | 🟡 |
|
||||
| COUNT_EQUAL | ❌ | ✅ | ✅ | ❌ |
|
||||
| CPY | ❌ | 🟡 | 🟡 | 🟡 |
|
||||
| CROSS_ENTROPY_LOSS | ❌ | ✅ | ✅ | ❌ |
|
||||
| CROSS_ENTROPY_LOSS_BACK | ❌ | ✅ | ✅ | ❌ |
|
||||
| DIAG_MASK_INF | ❌ | ✅ | ✅ | 🟡 |
|
||||
| DIV | ❌ | ✅ | ✅ | 🟡 |
|
||||
| DUP | ❌ | ✅ | 🟡 | 🟡 |
|
||||
| ELU | ❌ | ✅ | ❌ | 🟡 |
|
||||
| EXP | ❌ | ✅ | 🟡 | ❌ |
|
||||
| FLASH_ATTN_EXT | ❌ | ✅ | 🟡 | 🟡 |
|
||||
| GATED_LINEAR_ATTN | ❌ | ✅ | ✅ | ❌ |
|
||||
| GEGLU | ❌ | ✅ | ✅ | 🟡 |
|
||||
| GEGLU_ERF | ❌ | ✅ | ✅ | 🟡 |
|
||||
| GEGLU_QUICK | ❌ | ✅ | ✅ | 🟡 |
|
||||
| GELU | ❌ | ✅ | 🟡 | 🟡 |
|
||||
| GELU_ERF | ❌ | ✅ | 🟡 | 🟡 |
|
||||
| GELU_QUICK | ❌ | ✅ | 🟡 | 🟡 |
|
||||
| GET_ROWS | ❌ | ✅ | 🟡 | ✅ |
|
||||
| GET_ROWS_BACK | ❌ | 🟡 | 🟡 | ❌ |
|
||||
| GROUP_NORM | ❌ | ✅ | ✅ | ✅ |
|
||||
| HARDSIGMOID | ❌ | ✅ | 🟡 | ❌ |
|
||||
| HARDSWISH | ❌ | ✅ | 🟡 | ❌ |
|
||||
| IM2COL | ❌ | ✅ | ✅ | 🟡 |
|
||||
| L2_NORM | ❌ | ✅ | ✅ | ✅ |
|
||||
| LEAKY_RELU | ❌ | ✅ | ✅ | ✅ |
|
||||
| LOG | ❌ | ✅ | ✅ | ❌ |
|
||||
| MEAN | ❌ | ✅ | ✅ | ✅ |
|
||||
| MUL | ❌ | ✅ | ✅ | 🟡 |
|
||||
| MUL_MAT | 🟡 | 🟡 | 🟡 | 🟡 |
|
||||
| MUL_MAT_ID | ❌ | ✅ | ✅ | ✅ |
|
||||
| NEG | ❌ | ✅ | 🟡 | 🟡 |
|
||||
| NORM | ❌ | ✅ | ✅ | 🟡 |
|
||||
| OPT_STEP_ADAMW | ❌ | ✅ | ✅ | ❌ |
|
||||
| OUT_PROD | 🟡 | 🟡 | 🟡 | ❌ |
|
||||
| PAD | ❌ | ✅ | ✅ | ✅ |
|
||||
| PAD_REFLECT_1D | ❌ | ✅ | ❌ | ✅ |
|
||||
| POOL_2D | ❌ | ✅ | ✅ | ✅ |
|
||||
| REGLU | ❌ | ✅ | ✅ | 🟡 |
|
||||
| RELU | ❌ | ✅ | 🟡 | 🟡 |
|
||||
| REPEAT | ❌ | ✅ | 🟡 | ✅ |
|
||||
| REPEAT_BACK | ❌ | ✅ | ✅ | ❌ |
|
||||
| RMS_NORM | ❌ | ✅ | ✅ | 🟡 |
|
||||
| RMS_NORM_BACK | ❌ | ✅ | ✅ | ❌ |
|
||||
| RMS_NORM_MUL | ❌ | ✅ | ✅ | ✅ |
|
||||
| ROPE | ❌ | ✅ | ✅ | ✅ |
|
||||
| ROPE_BACK | ❌ | ✅ | ✅ | ❌ |
|
||||
| RWKV_WKV6 | ❌ | ✅ | ✅ | ✅ |
|
||||
| RWKV_WKV7 | ❌ | ✅ | ✅ | ✅ |
|
||||
| SCALE | ❌ | ✅ | ✅ | ✅ |
|
||||
| SET | ❌ | ✅ | ❌ | ✅ |
|
||||
| SET_ROWS | ❌ | 🟡 | ❌ | 🟡 |
|
||||
| SGN | ❌ | ✅ | 🟡 | ❌ |
|
||||
| SIGMOID | ❌ | ✅ | 🟡 | 🟡 |
|
||||
| SILU | ❌ | ✅ | 🟡 | 🟡 |
|
||||
| SILU_BACK | ❌ | ✅ | ✅ | ❌ |
|
||||
| SIN | ❌ | ✅ | ✅ | 🟡 |
|
||||
| SOFT_MAX | ❌ | ✅ | ✅ | ✅ |
|
||||
| SOFT_MAX_BACK | ❌ | 🟡 | 🟡 | ❌ |
|
||||
| SQR | ❌ | ✅ | ✅ | 🟡 |
|
||||
| SQRT | ❌ | ✅ | ✅ | 🟡 |
|
||||
| SSM_CONV | ❌ | ✅ | ✅ | ✅ |
|
||||
| SSM_SCAN | ❌ | ✅ | ✅ | ✅ |
|
||||
| STEP | ❌ | ✅ | 🟡 | ❌ |
|
||||
| SUB | ❌ | ✅ | ✅ | 🟡 |
|
||||
| SUM | ❌ | ✅ | ✅ | ❌ |
|
||||
| SUM_ROWS | ❌ | ✅ | ✅ | ✅ |
|
||||
| SWIGLU | ❌ | ✅ | ✅ | 🟡 |
|
||||
| TANH | ❌ | ✅ | 🟡 | 🟡 |
|
||||
| TIMESTEP_EMBEDDING | ❌ | ✅ | ✅ | ✅ |
|
||||
| UPSCALE | ❌ | ✅ | ✅ | 🟡 |
|
File diff suppressed because it is too large
Load Diff
File diff suppressed because it is too large
Load Diff
File diff suppressed because it is too large
Load Diff
File diff suppressed because it is too large
Load Diff
|
@ -0,0 +1,196 @@
|
|||
#!/usr/bin/env python3
|
||||
|
||||
"""
|
||||
This script parses docs/ops/*.csv and creates the ops.md, which is a table documenting supported operations on various ggml backends.
|
||||
"""
|
||||
import csv
|
||||
import logging
|
||||
import sys
|
||||
from pathlib import Path
|
||||
from collections import defaultdict
|
||||
|
||||
|
||||
class DocsGenerator:
|
||||
def __init__(self, ggml_root: str, output_filename: str = "ops.md"):
|
||||
self.ggml_root = Path(ggml_root)
|
||||
self.ops_dir = self.ggml_root / "docs" / "ops"
|
||||
self.output_filename = output_filename
|
||||
self.backend_support: dict[str, dict[str, list[bool]]] = defaultdict(
|
||||
lambda: defaultdict(list)
|
||||
)
|
||||
self.all_operations: set[str] = set()
|
||||
self.all_backends: set[str] = set()
|
||||
self.logger = logging.getLogger(__name__)
|
||||
|
||||
def parse_support_files(self) -> None:
|
||||
if not self.ops_dir.exists():
|
||||
self.logger.warning(f"ops directory not found: {self.ops_dir}")
|
||||
return
|
||||
|
||||
self.logger.info(f"Parsing support files from {self.ops_dir}...")
|
||||
|
||||
for support_file in self.ops_dir.glob("*.csv"):
|
||||
self.logger.info(f" Reading: {support_file.name}")
|
||||
self._parse_support_file(support_file)
|
||||
|
||||
def _parse_support_file(self, file_path: Path) -> None:
|
||||
try:
|
||||
with open(file_path, "r", newline='') as f:
|
||||
reader = csv.DictReader(f)
|
||||
|
||||
for row in reader:
|
||||
# Skip rows that don't have support mode
|
||||
if row.get('test_mode') != 'support':
|
||||
continue
|
||||
|
||||
backend_name = row.get('backend_name', '').strip()
|
||||
operation = row.get('op_name', '').strip()
|
||||
supported_str = row.get('error_message', '').strip() # "yes" or "no"
|
||||
backend_reg_name = row.get('backend_reg_name', '').strip()
|
||||
|
||||
# Skip invalid or error operations
|
||||
if not operation or not backend_name or operation in [
|
||||
"CONTEXT_ERROR",
|
||||
"BUILD_ERROR",
|
||||
]:
|
||||
continue
|
||||
|
||||
is_supported = supported_str.lower() == "yes"
|
||||
|
||||
# Use backend_reg_name for grouping, fallback to backend_name
|
||||
backend_key = backend_reg_name if backend_reg_name else backend_name
|
||||
|
||||
self.all_backends.add(backend_key)
|
||||
self.backend_support[backend_key][operation].append(is_supported)
|
||||
self.all_operations.add(operation)
|
||||
|
||||
except Exception as e:
|
||||
self.logger.error(f" Error parsing {file_path}: {e}")
|
||||
|
||||
def get_backend_support_status(self, backend: str, operation: str) -> str:
|
||||
support_list = self.backend_support[backend].get(operation, [])
|
||||
|
||||
if not support_list:
|
||||
return "unsupported"
|
||||
|
||||
all_supported = all(support_list)
|
||||
any_supported = any(support_list)
|
||||
|
||||
if all_supported:
|
||||
return "supported"
|
||||
elif any_supported:
|
||||
return "partially supported"
|
||||
else:
|
||||
return "unsupported"
|
||||
|
||||
def get_support_status(self, operation: str) -> str:
|
||||
if operation not in self.all_operations:
|
||||
return "unsupported"
|
||||
|
||||
support_count = 0
|
||||
total_backends = len(self.all_backends)
|
||||
|
||||
for backend in self.all_backends:
|
||||
if self.backend_support[backend].get(operation, False):
|
||||
support_count += 1
|
||||
|
||||
if support_count == 0:
|
||||
return "unsupported"
|
||||
elif support_count == total_backends:
|
||||
return "supported"
|
||||
else:
|
||||
return "partially supported"
|
||||
|
||||
def get_support_symbol(self, status: str) -> str:
|
||||
symbols = {"supported": "✅", "partially supported": "🟡", "unsupported": "❌"}
|
||||
return symbols.get(status, "❓")
|
||||
|
||||
def generate_markdown(self) -> str:
|
||||
lines = []
|
||||
|
||||
lines.append("# GGML Operations")
|
||||
lines.append("")
|
||||
lines.append("List of GGML operations and backend support status.")
|
||||
lines.append("")
|
||||
lines.append("Legend:")
|
||||
lines.append("- ✅ Fully supported by this backend")
|
||||
lines.append("- 🟡 Partially supported by this backend")
|
||||
lines.append("- ❌ Not supported by this backend")
|
||||
lines.append("")
|
||||
|
||||
backends = sorted(self.all_backends)
|
||||
header = "| Operation |"
|
||||
for backend in backends:
|
||||
header += f" {backend} |"
|
||||
|
||||
separator = "|-----------|"
|
||||
for _ in backends:
|
||||
separator += "------|"
|
||||
|
||||
lines.append(header)
|
||||
lines.append(separator)
|
||||
|
||||
sorted_operations = sorted(self.all_operations)
|
||||
|
||||
for operation in sorted_operations:
|
||||
row = f"| {operation:>32} |"
|
||||
|
||||
for backend in backends:
|
||||
status = self.get_backend_support_status(backend, operation)
|
||||
if status == "supported":
|
||||
symbol = "✅"
|
||||
elif status == "partially supported":
|
||||
symbol = "🟡"
|
||||
else:
|
||||
symbol = "❌"
|
||||
row += f" {symbol} |"
|
||||
|
||||
lines.append(row)
|
||||
|
||||
lines.append("")
|
||||
|
||||
return "\n".join(lines)
|
||||
|
||||
def run(self) -> None:
|
||||
self.logger.info("Parsing GGML operation support files...")
|
||||
self.parse_support_files()
|
||||
|
||||
if not self.all_operations:
|
||||
self.logger.error(
|
||||
"No operations found. Make sure to run test-backend-ops support --output csv > docs/ops/file.csv first."
|
||||
)
|
||||
return
|
||||
|
||||
self.logger.info(
|
||||
f"Found {len(self.all_operations)} operations across {len(self.all_backends)} backends"
|
||||
)
|
||||
|
||||
self.logger.info("Generating markdown...")
|
||||
markdown_content = self.generate_markdown()
|
||||
|
||||
docs_dir = self.ggml_root / "docs"
|
||||
docs_dir.mkdir(exist_ok=True)
|
||||
|
||||
ops_file = docs_dir / self.output_filename
|
||||
with open(ops_file, "w") as f:
|
||||
f.write(markdown_content)
|
||||
|
||||
self.logger.info(f"Generated: {ops_file}")
|
||||
self.logger.info(f"Operations: {len(self.all_operations)}")
|
||||
self.logger.info(f"Backends: {len(self.all_backends)}")
|
||||
|
||||
|
||||
def main():
|
||||
logging.basicConfig(level=logging.INFO)
|
||||
|
||||
if len(sys.argv) > 1:
|
||||
output_filename = sys.argv[1]
|
||||
else:
|
||||
output_filename = "ops.md"
|
||||
|
||||
generator = DocsGenerator(".", output_filename)
|
||||
generator.run()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
|
@ -317,10 +317,11 @@ enum test_mode {
|
|||
MODE_TEST,
|
||||
MODE_PERF,
|
||||
MODE_GRAD,
|
||||
MODE_SUPPORT,
|
||||
};
|
||||
|
||||
// Output format support similar to llama-bench
|
||||
enum output_formats { CONSOLE, SQL };
|
||||
enum output_formats { CONSOLE, SQL, CSV };
|
||||
|
||||
static const char * output_format_str(output_formats format) {
|
||||
switch (format) {
|
||||
|
@ -328,6 +329,8 @@ static const char * output_format_str(output_formats format) {
|
|||
return "console";
|
||||
case SQL:
|
||||
return "sql";
|
||||
case CSV:
|
||||
return "csv";
|
||||
default:
|
||||
GGML_ABORT("invalid output format");
|
||||
}
|
||||
|
@ -338,6 +341,8 @@ static bool output_format_from_str(const std::string & s, output_formats & forma
|
|||
format = CONSOLE;
|
||||
} else if (s == "sql") {
|
||||
format = SQL;
|
||||
} else if (s == "csv") {
|
||||
format = CSV;
|
||||
} else {
|
||||
return false;
|
||||
}
|
||||
|
@ -360,6 +365,8 @@ struct test_result {
|
|||
double bandwidth_gb_s;
|
||||
size_t memory_kb;
|
||||
int n_runs;
|
||||
std::string device_description;
|
||||
std::string backend_reg_name;
|
||||
|
||||
test_result() {
|
||||
// Initialize with default values
|
||||
|
@ -384,7 +391,7 @@ struct test_result {
|
|||
test_result(const std::string & backend_name, const std::string & op_name, const std::string & op_params,
|
||||
const std::string & test_mode, bool supported, bool passed, const std::string & error_message = "",
|
||||
double time_us = 0.0, double flops = 0.0, double bandwidth_gb_s = 0.0, size_t memory_kb = 0,
|
||||
int n_runs = 0) :
|
||||
int n_runs = 0, const std::string & device_description = "", const std::string & backend_reg_name = "") :
|
||||
backend_name(backend_name),
|
||||
op_name(op_name),
|
||||
op_params(op_params),
|
||||
|
@ -396,7 +403,9 @@ struct test_result {
|
|||
flops(flops),
|
||||
bandwidth_gb_s(bandwidth_gb_s),
|
||||
memory_kb(memory_kb),
|
||||
n_runs(n_runs) {
|
||||
n_runs(n_runs),
|
||||
device_description(device_description),
|
||||
backend_reg_name(backend_reg_name) {
|
||||
// Set test time
|
||||
time_t t = time(NULL);
|
||||
char buf[32];
|
||||
|
@ -410,7 +419,8 @@ struct test_result {
|
|||
static const std::vector<std::string> & get_fields() {
|
||||
static const std::vector<std::string> fields = {
|
||||
"test_time", "build_commit", "backend_name", "op_name", "op_params", "test_mode", "supported",
|
||||
"passed", "error_message", "time_us", "flops", "bandwidth_gb_s", "memory_kb", "n_runs"
|
||||
"passed", "error_message", "time_us", "flops", "bandwidth_gb_s", "memory_kb", "n_runs",
|
||||
"device_description", "backend_reg_name"
|
||||
};
|
||||
return fields;
|
||||
}
|
||||
|
@ -444,7 +454,9 @@ struct test_result {
|
|||
std::to_string(flops),
|
||||
std::to_string(bandwidth_gb_s),
|
||||
std::to_string(memory_kb),
|
||||
std::to_string(n_runs) };
|
||||
std::to_string(n_runs),
|
||||
device_description,
|
||||
backend_reg_name };
|
||||
}
|
||||
};
|
||||
|
||||
|
@ -633,6 +645,8 @@ struct console_printer : public printer {
|
|||
print_test_console(result);
|
||||
} else if (result.test_mode == "perf") {
|
||||
print_perf_console(result);
|
||||
} else if (result.test_mode == "support") {
|
||||
print_support_console(result);
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -799,6 +813,17 @@ struct console_printer : public printer {
|
|||
}
|
||||
printf("\n");
|
||||
}
|
||||
|
||||
void print_support_console(const test_result & result) {
|
||||
printf(" %s(%s): ", result.op_name.c_str(), result.op_params.c_str());
|
||||
fflush(stdout);
|
||||
|
||||
if (result.supported) {
|
||||
printf("\033[1;32mSUPPORTED\033[0m\n");
|
||||
} else {
|
||||
printf("\033[1;31mNOT SUPPORTED\033[0m\n");
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
struct sql_printer : public printer {
|
||||
|
@ -841,12 +866,39 @@ struct sql_printer : public printer {
|
|||
}
|
||||
};
|
||||
|
||||
struct csv_printer : public printer {
|
||||
void print_header() override {
|
||||
std::vector<std::string> fields = test_result::get_fields();
|
||||
for (size_t i = 0; i < fields.size(); i++) {
|
||||
printf("\"%s\"%s", fields[i].c_str(), i < fields.size() - 1 ? "," : "");
|
||||
}
|
||||
printf("\n");
|
||||
}
|
||||
|
||||
void print_test_result(const test_result & result) override {
|
||||
std::vector<std::string> values = result.get_values();
|
||||
for (size_t i = 0; i < values.size(); i++) {
|
||||
// Escape quotes and wrap in quotes for CSV
|
||||
std::string escaped_value = values[i];
|
||||
size_t pos = 0;
|
||||
while ((pos = escaped_value.find("\"", pos)) != std::string::npos) {
|
||||
escaped_value.replace(pos, 1, "\"\"");
|
||||
pos += 2;
|
||||
}
|
||||
printf("\"%s\"%s", escaped_value.c_str(), i < values.size() - 1 ? "," : "");
|
||||
}
|
||||
printf("\n");
|
||||
}
|
||||
};
|
||||
|
||||
static std::unique_ptr<printer> create_printer(output_formats format) {
|
||||
switch (format) {
|
||||
case CONSOLE:
|
||||
return std::make_unique<console_printer>();
|
||||
case SQL:
|
||||
return std::make_unique<sql_printer>();
|
||||
case CSV:
|
||||
return std::make_unique<csv_printer>();
|
||||
}
|
||||
GGML_ABORT("invalid output format");
|
||||
}
|
||||
|
@ -928,7 +980,7 @@ struct test_case {
|
|||
std::vector<ggml_tensor *> sentinels;
|
||||
|
||||
void add_sentinel(ggml_context * ctx) {
|
||||
if (mode == MODE_PERF || mode == MODE_GRAD) {
|
||||
if (mode == MODE_PERF || mode == MODE_GRAD || mode == MODE_SUPPORT) {
|
||||
return;
|
||||
}
|
||||
ggml_tensor * sentinel = ::ggml_new_tensor_1d(ctx, GGML_TYPE_F32, sentinel_size);
|
||||
|
@ -1153,15 +1205,12 @@ struct test_case {
|
|||
return true;
|
||||
}
|
||||
|
||||
// check if backends support op
|
||||
if (!ggml_backend_supports_op(backend, out)) {
|
||||
// Create test result for unsupported performance test
|
||||
test_result result(ggml_backend_name(backend), current_op_name, vars(), "perf", false, false,
|
||||
"not supported");
|
||||
|
||||
if (output_printer) {
|
||||
output_printer->print_test_result(result);
|
||||
}
|
||||
output_printer->print_test_result(result);
|
||||
|
||||
return true;
|
||||
}
|
||||
|
@ -1266,6 +1315,38 @@ struct test_case {
|
|||
return true;
|
||||
}
|
||||
|
||||
bool eval_support(ggml_backend_t backend, const char * op_name, printer * output_printer) {
|
||||
mode = MODE_SUPPORT;
|
||||
|
||||
static const size_t graph_nodes = 8192;
|
||||
|
||||
ggml_init_params params = {
|
||||
/* .mem_size = */ ggml_tensor_overhead()*128 + ggml_graph_overhead_custom(graph_nodes, false),
|
||||
/* .mem_base = */ NULL,
|
||||
/* .no_alloc = */ true,
|
||||
};
|
||||
ggml_context_ptr ctx(ggml_init(params)); // smart ptr
|
||||
GGML_ASSERT(ctx);
|
||||
|
||||
ggml_tensor * out = build_graph(ctx.get());
|
||||
std::string current_op_name = op_desc(out);
|
||||
if (op_name != nullptr && current_op_name != op_name) {
|
||||
return true;
|
||||
}
|
||||
|
||||
bool supported = ggml_backend_supports_op(backend, out);
|
||||
|
||||
std::string device_desc = ggml_backend_dev_description(ggml_backend_get_device(backend));
|
||||
std::string backend_reg_name = ggml_backend_reg_name(ggml_backend_dev_backend_reg(ggml_backend_get_device(backend)));
|
||||
|
||||
test_result result(ggml_backend_name(backend), current_op_name, vars(), "support", supported, supported,
|
||||
supported ? "yes" : "no", 0.0, 0.0, 0.0, 0, 0, device_desc, backend_reg_name);
|
||||
|
||||
output_printer->print_test_result(result);
|
||||
|
||||
return true;
|
||||
}
|
||||
|
||||
bool eval_grad(ggml_backend_t backend, const char * op_name, printer * output_printer) {
|
||||
mode = MODE_GRAD;
|
||||
const std::vector<float> expect = grad_expect();
|
||||
|
@ -5599,17 +5680,27 @@ static bool test_backend(ggml_backend_t backend, test_mode mode, const char * op
|
|||
return true;
|
||||
}
|
||||
|
||||
if (mode == MODE_SUPPORT) {
|
||||
auto test_cases = make_test_cases_eval();
|
||||
filter_test_cases(test_cases, params_filter);
|
||||
for (auto & test : test_cases) {
|
||||
test->eval_support(backend, op_name, output_printer);
|
||||
}
|
||||
return true;
|
||||
}
|
||||
|
||||
GGML_ABORT("fatal error");
|
||||
}
|
||||
|
||||
static void usage(char ** argv) {
|
||||
printf("Usage: %s [mode] [-o <op>] [-b <backend>] [-p <params regex>] [--output <console|sql>]\n", argv[0]);
|
||||
printf("Usage: %s [mode] [-o <op>] [-b <backend>] [-p <params regex>] [--output <console|sql|csv>]\n", argv[0]);
|
||||
printf(" valid modes:\n");
|
||||
printf(" - test (default, compare with CPU backend for correctness)\n");
|
||||
printf(" - grad (compare gradients from backpropagation with method of finite differences)\n");
|
||||
printf(" - perf (performance evaluation)\n");
|
||||
printf(" - support (probe backend operation support)\n");
|
||||
printf(" op names for -o are as given by ggml_op_desc() (e.g. ADD, MUL_MAT, etc)\n");
|
||||
printf(" --output specifies output format (default: console)\n");
|
||||
printf(" --output specifies output format (default: console, options: console, sql, csv)\n");
|
||||
}
|
||||
|
||||
int main(int argc, char ** argv) {
|
||||
|
@ -5626,6 +5717,8 @@ int main(int argc, char ** argv) {
|
|||
mode = MODE_PERF;
|
||||
} else if (strcmp(argv[i], "grad") == 0) {
|
||||
mode = MODE_GRAD;
|
||||
} else if (strcmp(argv[i], "support") == 0) {
|
||||
mode = MODE_SUPPORT;
|
||||
} else if (strcmp(argv[i], "-o") == 0) {
|
||||
if (i + 1 < argc) {
|
||||
op_name_filter = argv[++i];
|
||||
|
|
Loading…
Reference in New Issue