mirror of https://github.com/inclusionAI/AReaL

# Copyright 2025 Ant Group Inc.
# Copyright 2024 Wei Fu & Zhiyu Mei
# Licensed under the Apache License, Version 2.0 (the "License").

import asyncio
import contextlib
import dataclasses
import enum
import json
import os
import pickle
import re
import time
from collections import defaultdict
from statistics import mean
from typing import TYPE_CHECKING, Callable, Dict, List, Optional, Tuple, Union

import numpy as np
import psutil
import pynvml
import torch
import tqdm
import tqdm.asyncio

import realhf.base.constants as constants
import realhf.base.logging as logging

if TYPE_CHECKING:
    from realhf.api.core.config import ModelName

logger = logging.getLogger("benchmark")

IF_MARK = False


@dataclasses.dataclass
class RolloutStat:
    submitted: int = 0
    accepted: int = 0
    running: int = 0


def mock_time_mark(name, identifier, t, step):
    if IF_MARK:
        logger.debug(f"*{name}* #{identifier}# ${t}$ ns step &{step}&")


def time_mark(name, identifier, step=0):
    if IF_MARK:
        logger.debug(
            f"*{name}* #{identifier}# ${int(time.time_ns())}$ ns step &{step}&"
        )


def parse_time_mark_in_line(line, name, step_range=None):
    if f"*{name}*" in line:
        identifier, t, step = (
            line.split("#")[1],
            int(line.split("$")[1]),
            int(line.split("&")[1]),
        )
        if step_range:
            if step >= step_range[1] or step < step_range[0]:
                return None
        return identifier, t
    else:
        return None
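

def _example_parse_time_mark():
    # Hedged usage sketch (the mark name and identifier below are
    # hypothetical): build a log line in the exact format emitted by
    # `time_mark` and parse it back.
    line = "*train_step* #worker0# $1700000000000000000$ ns step &3&"
    assert parse_time_mark_in_line(line, "train_step") == (
        "worker0",
        1700000000000000000,
    )
    # Marks whose step falls outside [lo, hi) of `step_range` are dropped.
    assert parse_time_mark_in_line(line, "train_step", step_range=(0, 3)) is None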


def parse_time_mark_in_file(file, name, step_range=None):
    time_points = defaultdict(list)
    with open(file, "r") as f:
        count = 0
        res_count = 0
        for line in f.readlines():
            count += 1
            res = parse_time_mark_in_line(line, name, step_range=step_range)
            if res is not None:
                res_count += 1
                identifier, time_point = res
                time_points[identifier].append(time_point)
    return time_points


def parse_time_mark_in_dir(dir, name, step_range=None):
    time_points = {}
    for file in os.listdir(dir):
        file_path = os.path.join(dir, file)
        tpsf = parse_time_mark_in_file(file_path, name, step_range=step_range)
        for k, v in tpsf.items():
            if k not in time_points:
                time_points[k] = v
            else:
                time_points[k].extend(v)
    return time_points


MATPLOTLIB_COLORS = [
    "red",
    "blue",
    "green",
    "yellow",
    "orange",
    "purple",
    "pink",
    "black",
    "brown",
    "gray",
    "cyan",
    "magenta",
    "lime",
    "olive",
    "navy",
]


def summary_time_points(
    start_keys,
    end_keys,
    identifiers,
    dir_name=None,
    file_name=None,
    start_time=None,
    figsize=(12, 4),
    end_time=None,
    step_range=None,
    save_fig_path="time_points.png",
    draw_boundary=False,
):
    """Plot and summarize time marks parsed from logs."""
    import matplotlib.pyplot as plt

    assert file_name or dir_name, "dir or file name must be specified"
    all_time_points = {}
    if file_name is None:
        for k in start_keys:
            all_time_points[k] = parse_time_mark_in_dir(
                dir_name, k, step_range=step_range
            )
        for k in end_keys:
            all_time_points[k] = parse_time_mark_in_dir(
                dir_name, k, step_range=step_range
            )
    else:
        for k in start_keys:
            all_time_points[k] = parse_time_mark_in_file(
                file_name, k, step_range=step_range
            )
        for k in end_keys:
            all_time_points[k] = parse_time_mark_in_file(
                file_name, k, step_range=step_range
            )

    fig, ax = plt.subplots(1, 1, figsize=figsize)
    ax.set_ylim(-1, len(identifiers))
    ax.set_yticks(list(range(len(identifiers))))
    ax.set_yticklabels(identifiers)

    label_set = {sk: False for sk in start_keys}
    infos = {}
    min_time = None
    max_time = None
    for id_index, identifier in enumerate(identifiers):
        time_sum = {}
        time_list = {}
        for start_key_idx, (start_key, end_key) in enumerate(zip(start_keys, end_keys)):
            time_sum[start_key] = 0
            time_list[start_key] = []
            try:
                start_time_points = np.array(all_time_points[start_key][identifier])
                end_time_points = np.array(all_time_points[end_key][identifier])
            except KeyError:
                continue
            assert len(start_time_points) == len(end_time_points)

            if start_time is not None and end_time is not None:
                valid_indices_st = np.where(start_time_points > start_time)
                valid_indices_et = np.where(start_time_points < end_time)
                valid_indices = np.intersect1d(valid_indices_st, valid_indices_et)
                start_time_points = start_time_points[valid_indices]
                end_time_points = end_time_points[valid_indices]

            # Plot each (start, end) time-point pair as a horizontal bar.
            for stp, etp in zip(list(start_time_points), list(end_time_points)):
                min_time = stp if min_time is None else min(min_time, stp)
                max_time = etp if max_time is None else max(max_time, etp)
                time_sum[start_key] += etp - stp
                time_list[start_key].append(etp - stp)

                if label_set[start_key] is False:
                    label = start_key
                    label_set[start_key] = True
                else:
                    label = None

                ax.barh(
                    y=id_index,
                    width=etp - stp,
                    left=stp,
                    height=0.8,
                    color=MATPLOTLIB_COLORS[start_key_idx],
                    label=label,
                )

                if draw_boundary:
                    ax.plot(
                        [stp, stp],
                        [id_index - 0.4, id_index + 0.4],
                        color="black",
                        linestyle="-",
                        linewidth=0.5,
                    )
                    ax.plot(
                        [etp, etp],
                        [id_index - 0.4, id_index + 0.4],
                        color="black",
                        linestyle="-",
                        linewidth=0.5,
                    )

        infos[identifier] = (time_sum, time_list)

    ax.set_xlim(min_time, max_time)
    total_width = max_time - min_time
    xticks = np.arange(
        min_time - total_width // 12, max_time - total_width // 12, 10 * 1e9
    )
    xtick_labels = [f"{int((i//1e9)%1000)}" for i in xticks]
    ax.set_xticks(xticks)
    ax.set_xticklabels(xtick_labels)

    # Summarize the time cost percentage of each key per identifier.
    for id_index, identifier in enumerate(identifiers):
        print("=" * 30)
        print(f"Identifier {identifier} time cost percent:")
        bubble_time = 100
        time_sum, time_list = infos[identifier]
        for k in time_sum:
            time_perc = round(time_sum[k] / (max_time - min_time) * 100, 2)
            # Durations are in nanoseconds; divide by 1e6 to report milliseconds.
            avg_val = (
                round(mean(time_list[k]) / 1e6, 2) if len(time_list[k]) > 0 else "-"
            )
            max_val = (
                round(max(time_list[k]) / 1e6, 2) if len(time_list[k]) > 0 else "-"
            )
            min_val = (
                round(min(time_list[k]) / 1e6, 2) if len(time_list[k]) > 0 else "-"
            )

            bubble_time -= time_perc
            print(
                f"{k} -- {time_perc} %, "
                f"avg, min, max = {avg_val}, {min_val}, {max_val} ms, "
                f"sum, n = {round(time_sum[k]/1e6, 2)} ms, {len(time_list[k])}"
            )
        print(f"bubble time -- {round(bubble_time, 2)}%")

    plt.legend(loc=(1.01, 0.0))
    plt.tight_layout()

    plt.savefig(save_fig_path)
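

def _example_summary_time_points():
    # Hedged usage sketch: the key names, identifiers, and log directory are
    # hypothetical. Each "train_fwd" mark is paired positionally with the
    # "train_fwd_end" mark that shares its identifier, producing one bar row
    # per identifier in the saved figure.
    summary_time_points(
        start_keys=["train_fwd"],
        end_keys=["train_fwd_end"],
        identifiers=["worker0", "worker1"],
        dir_name="/path/to/logs",
        save_fig_path="timeline.png",
    )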


def gpu_utilization_monitor(worker_idx: int, interval: float, ttl: float):
    pynvml.nvmlInit()
    # Assumes 8 GPUs per node; workers are mapped round-robin onto local GPUs.
    gpu_idx = worker_idx % 8
    tik = time.time()
    while time.time() - tik < ttl:
        handle = pynvml.nvmlDeviceGetHandleByIndex(gpu_idx)
        utilization = pynvml.nvmlDeviceGetUtilizationRates(handle)
        memory_info = pynvml.nvmlDeviceGetMemoryInfo(handle)
        total_memory = memory_info.total / (1024**2)  # Convert bytes to megabytes.
        used_memory = memory_info.used / (1024**2)
        memory_usage_percentage = (used_memory / total_memory) * 100
        logger.debug(
            f"Worker Index {worker_idx}, GPU {gpu_idx}: "
            f"Compute Utilization - {utilization.gpu}%, "
            f"Total Memory - {total_memory:.2f}MB, Used Memory - {used_memory:.2f}MB, "
            f"Memory Usage - {memory_usage_percentage:.2f}%"
        )
        time.sleep(interval)
    pynvml.nvmlShutdown()
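

def _example_gpu_monitor_thread():
    # Hedged usage sketch: run the monitor in a daemon thread (here sampling
    # every 10 s for one hour; the numbers are illustrative) so it never
    # blocks the worker's main loop.
    import threading

    threading.Thread(
        target=gpu_utilization_monitor, args=(0, 10.0, 3600.0), daemon=True
    ).start()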


# Helper function to calculate training FLOPs using the formula from the
# Megatron-LM paper.
def calculate_llama_train_flops(
    checkpoint_activations_factor: int,
    batch_size: int,
    seqlens: List[int],
    num_layers: int,
    hidden_size: int,
    intermediate_size: int,
    vocab_size: int,
):
    return checkpoint_activations_factor * calculate_llama_forward_flops(
        batch_size,
        seqlens,
        num_layers,
        hidden_size,
        intermediate_size,
        vocab_size,
    )


def calculate_llama_forward_flops(
    batch_size: int,
    seqlens: List[int],
    num_layers: int,
    hidden_size: int,
    intermediate_size: int,
    vocab_size: int,
):
    assert len(seqlens) == batch_size
    attn_flops = sum(x**2 for x in seqlens) * hidden_size
    return (
        2
        * num_layers
        * (
            4 * sum(seqlens) * hidden_size**2
            + 2 * attn_flops
            + 3 * sum(seqlens) * hidden_size * intermediate_size
        )
        + 4 * sum(seqlens) * vocab_size * hidden_size
    )
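

def _example_forward_flops():
    # Hedged sanity check with LLaMA-7B-like hyperparameters (illustrative
    # values, not tied to any particular checkpoint). For short sequences the
    # result is close to 2 * n_params * n_tokens, i.e. on the order of 1e13
    # FLOPs here.
    flops = calculate_llama_forward_flops(
        batch_size=2,
        seqlens=[512, 1024],
        num_layers=32,
        hidden_size=4096,
        intermediate_size=11008,
        vocab_size=32000,
    )
    print(f"forward FLOPs: {flops:.3e}")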


def calculate_llama_gen_flops(
    batch_size,
    prompt_lens,
    gen_len,
    num_layers,
    hidden_size,
    intermediate_size,
    vocab_size,
):
    flops = calculate_llama_forward_flops(
        batch_size,
        prompt_lens,
        num_layers,
        hidden_size,
        intermediate_size,
        vocab_size,
    )
    for i in range(gen_len):
        prefix_lens = [x + i for x in prompt_lens]
        flops += (
            2
            * num_layers
            * (
                4 * batch_size * hidden_size**2
                + 2 * (sum(prefix_lens) + batch_size) * hidden_size
                + 3 * batch_size * hidden_size * intermediate_size
            )
            + 4 * batch_size * vocab_size * hidden_size
        )
    return flops


#################### CUDA Kernel Time Marking Start ####################
# Used to create timeline plots.


class CUDATimeMarkType(enum.Enum):
    forward = "forward"
    backward = "backward"
    optim_step = "optim_step"
    comm = "comm"
    misc = "misc"
    mem_layout = "memory_layout"


@dataclasses.dataclass
class TimeMarkEntry:
    name: str
    model_name: "ModelName"
    type_: CUDATimeMarkType
    start_time: int
    end_time: int


TIME_MARK_DB = []


def cuda_tmark(name: str, type_: CUDATimeMarkType):
    if os.getenv("REAL_CUDA_TMARK", None) == "1":

        def wrapper(f: Callable):

            def _wrapped_f(*args, **kwargs):
                import torch

                if constants.use_cuda():
                    from realhf.base.constants import _model_name

                    torch.cuda.synchronize()
                    tik = time.time_ns()
                    res = f(*args, **kwargs)
                    torch.cuda.synchronize()
                    tok = time.time_ns()
                    global TIME_MARK_DB
                    TIME_MARK_DB.append(
                        TimeMarkEntry(name, _model_name, type_, tik, tok)
                    )
                else:
                    res = f(*args, **kwargs)
                return res

            return _wrapped_f

    else:

        def wrapper(f):
            return f

    return wrapper
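

# A hedged usage sketch: with REAL_CUDA_TMARK=1 exported before launch, each
# call to the decorated function appends a TimeMarkEntry to TIME_MARK_DB;
# otherwise the decorator is a no-op. The function below is illustrative and
# not part of the original API.
@cuda_tmark("example_matmul", CUDATimeMarkType.forward)
def _example_tmarked_matmul(a: torch.Tensor, b: torch.Tensor) -> torch.Tensor:
    return a @ b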


@contextlib.contextmanager
def cuda_tmarked(name: str, type_: CUDATimeMarkType):
    if os.getenv("REAL_CUDA_TMARK", None) == "1":
        import torch

        if constants.use_cuda():
            from realhf.base.constants import _model_name

            torch.cuda.synchronize()
            tik = time.time_ns()
    yield
    if os.getenv("REAL_CUDA_TMARK", None) == "1":
        if constants.use_cuda():
            torch.cuda.synchronize()
            tok = time.time_ns()
            global TIME_MARK_DB
            TIME_MARK_DB.append(TimeMarkEntry(name, _model_name, type_, tik, tok))
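

def _example_cuda_tmarked_region():
    # Hedged usage sketch: wrap any region in `cuda_tmarked` to record its
    # span (with CUDA synchronization) under the `misc` category when
    # REAL_CUDA_TMARK=1. The region name is illustrative.
    x = torch.randn(128, 128)
    with cuda_tmarked("example_region", CUDATimeMarkType.misc):
        _ = x @ x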


def fetch_latest_tmark():
    global TIME_MARK_DB
    return TIME_MARK_DB[-1]


def dump_tmark_db(worker_idx):
    if os.getenv("REAL_CUDA_TMARK", None) != "1":
        return
    fn = os.path.join(
        constants.LOG_ROOT,
        constants.experiment_name(),
        constants.trial_name(),
        f"time_marks{worker_idx}.pkl",
    )
    global TIME_MARK_DB
    with open(fn, "wb") as f:
        pickle.dump(TIME_MARK_DB, f)
    TIME_MARK_DB.clear()
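

# Hedged sketch for offline analysis (the path below is hypothetical; the
# layout mirrors what `dump_tmark_db` writes):
#
# with open("/path/to/logs/time_marks0.pkl", "rb") as f:
#     entries: List[TimeMarkEntry] = pickle.load(f)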


#################### CUDA Kernel Time Marking End ####################

#################### CUDA Kernel Time Statistics Start ####################
# Categorizing CUDA kernels into computation, communication, memory IO, and
# MISC/IDLE, used to plot the percentage of time spent on each category and
# to show how much we can improve over vanilla parallel strategies.

COMPUTE_KERNEL_KEYS = [
    "elementwise_kernel",
    "gemm",
    "aten::",
    "at::native::",
    "flash",
    "backward_kernel",
    "reduce_kernel",
    "multi_tensor_apply",
    "gae_kernel",
    "gemvx::kernel",
    "cublas",
    "cudnn",
    "cutlass",
]

P2P_COMM_KERNEL_KEYS = [
    "ncclDevKernel_SendRecv",
]

COLL_COMM_KERNEL_KEYS = [
    "ncclDevKernel_AllReduce",
    "ncclDevKernel_ReduceScatter",
    "ncclDevKernel_AllGather",
]

MEM_KERNEL_KEYS = [
    "Memcpy",
    "cleanup",
    "Memset",
]

MISC_KERNEL_KEYS = [
    "at_cuda_detail",
    "CudaCodeGen",
]


class CUDAKernelTimeCategory(enum.Enum):
    COMPUTE = "compute"
    P2P_COMM = "p2p_comm"
    COLL_COMM = "coll_comm"
    MEM = "memoryIO"
    IDLE = "idle"
    MISC = "misc"

    @classmethod
    def from_name(cls, name):
        # Order matters: MEM & COMM keys are easier to identify, so check
        # them before the broad COMPUTE keys.
        if any(k in name for k in MEM_KERNEL_KEYS):
            return cls.MEM
        if any(k in name for k in P2P_COMM_KERNEL_KEYS):
            return cls.P2P_COMM
        if any(k in name for k in COLL_COMM_KERNEL_KEYS):
            return cls.COLL_COMM
        if any(k in name for k in MISC_KERNEL_KEYS):
            return cls.MISC
        if any(k in name for k in COMPUTE_KERNEL_KEYS):
            return cls.COMPUTE
        raise NotImplementedError(f"Unknown kernel type. Name is `{name}`")
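

def _example_kernel_categories():
    # Hedged illustration with hypothetical kernel names: the substring
    # matching in `from_name` routes each name to the expected category.
    assert CUDAKernelTimeCategory.from_name("Memcpy DtoH") == CUDAKernelTimeCategory.MEM
    assert (
        CUDAKernelTimeCategory.from_name("ncclDevKernel_AllReduce_Sum")
        == CUDAKernelTimeCategory.COLL_COMM
    )
    assert (
        CUDAKernelTimeCategory.from_name("ampere_fp16_s16816gemm")
        == CUDAKernelTimeCategory.COMPUTE
    )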


class CUDAKernelTimeStat:  # in us

    def __init__(self, world_size, **kwargs):
        self.world_size = world_size
        for k in CUDAKernelTimeCategory:
            setattr(self, k.value, kwargs.get(k.value, 0))

    @property
    def total(self):
        return sum([getattr(self, k.value) for k in CUDAKernelTimeCategory])

    def percentage(self) -> Dict:
        return {
            k.value: getattr(self, k.value) / self.total for k in CUDAKernelTimeCategory
        }

    def __add__(self, other):
        return CUDAKernelTimeStat(
            world_size=self.world_size + other.world_size,
            **{
                k.value: getattr(self, k.value) + getattr(other, k.value)
                for k in CUDAKernelTimeCategory
            },
        )

    def __truediv__(self, x):
        assert self.world_size % x == 0
        return CUDAKernelTimeStat(
            world_size=self.world_size // x,
            **{k.value: getattr(self, k.value) / x for k in CUDAKernelTimeCategory},
        )

    def gpu_average(self):
        return self / self.world_size

    def __repr__(self):
        import tabulate

        headers = [
            "",
            "Total",
            "Computation",
            "P2P Comm",
            "Collective Comm",
            "Memory IO",
            "Idle",
            "Misc",
        ]
        line1 = [
            "Time (s)",
            self.total / 1e6,
            *[getattr(self, k.value) / 1e6 for k in CUDAKernelTimeCategory],
        ]
        line1 = [f"{x:.2f}" if isinstance(x, float) else x for x in line1]
        line2 = [
            "Percentage (%)",
            "-",
            *[f"{self.percentage()[k.value]:.2%}" for k in CUDAKernelTimeCategory],
        ]
        tab_str = tabulate.tabulate(
            [headers, line1, line2],
            headers="firstrow",
            tablefmt="fancy_grid",
            stralign="center",
        )
        return (
            f" Number of GPUs: {self.world_size} ".center(
                len(tab_str.split("\n")[0]), "="
            )
            + "\n"
            + tab_str
        )


@dataclasses.dataclass
class KernelEventEntry:
    ts: int
    tid: int
    dur: int
    category: CUDAKernelTimeCategory


@dataclasses.dataclass
class KernelEventBoundary:
    ts: int
    is_start: bool
    category: CUDAKernelTimeCategory


def kernelStatFromEvents(
    entries: List[KernelEventEntry],
    global_start_ts,
    global_end_ts,
):
    events: List[KernelEventBoundary] = []
    for entry in entries:
        events.append(KernelEventBoundary(entry.ts, True, entry.category))
        events.append(KernelEventBoundary(entry.ts + entry.dur, False, entry.category))
    # A trick to account for idle time spent waiting for other processes.
    events.append(
        KernelEventBoundary(global_start_ts, True, CUDAKernelTimeCategory.IDLE)
    )
    events.append(
        KernelEventBoundary(global_end_ts, False, CUDAKernelTimeCategory.IDLE)
    )

    events.sort(key=lambda x: x.ts)

    times = {k: 0 for k in CUDAKernelTimeCategory}
    active = {k: 0 for k in CUDAKernelTimeCategory}

    current_time = events[0].ts

    for i in range(len(events)):
        next_time = events[i].ts

        # Priority: compute > communication > memory > misc > idle.
        if i > 0 and next_time != current_time:
            duration = next_time - current_time
            if active[CUDAKernelTimeCategory.COMPUTE] > 0:
                times[CUDAKernelTimeCategory.COMPUTE] += duration
            elif active[CUDAKernelTimeCategory.COLL_COMM] > 0:
                times[CUDAKernelTimeCategory.COLL_COMM] += duration
            elif active[CUDAKernelTimeCategory.P2P_COMM] > 0:
                times[CUDAKernelTimeCategory.P2P_COMM] += duration
            elif active[CUDAKernelTimeCategory.MEM] > 0:
                times[CUDAKernelTimeCategory.MEM] += duration
            elif active[CUDAKernelTimeCategory.MISC] > 0:
                times[CUDAKernelTimeCategory.MISC] += duration
            else:
                times[CUDAKernelTimeCategory.IDLE] += duration
        active[events[i].category] += 1 if events[i].is_start else -1
        current_time = next_time

    assert all(v == 0 for v in active.values()), active
    return CUDAKernelTimeStat(world_size=1, **{k.value: v for k, v in times.items()})
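

def _example_kernel_stat():
    # Hedged synthetic sketch: one 10 us compute kernel inside a 30 us window
    # leaves 20 us of idle time (timestamps in microseconds, as in Chrome
    # traces). The event values are illustrative.
    entries = [
        KernelEventEntry(ts=10, tid=0, dur=10, category=CUDAKernelTimeCategory.COMPUTE)
    ]
    stat = kernelStatFromEvents(entries, global_start_ts=0, global_end_ts=30)
    assert stat.compute == 10 and stat.idle == 20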


async def _load_events_async(file_path, semaphore) -> Tuple[List[Dict], int]:
    import aiofiles

    async with semaphore:
        pid = int(file_path.rstrip(".json").split("_r")[-1])
        async with aiofiles.open(file_path, "r") as f:
            content = await f.read()
        events = json.loads(content)["traceEvents"]
        events = list(
            filter(
                lambda x: "cat" in x and x["cat"] in ["gpu_user_annotation", "kernel"],
                events,
            )
        )
        # Replace the pid with the actual process index, ranging from 0 to #GPUs - 1.
        for ev in events:
            ev["pid"] = pid
        return events, pid


async def _load_all_events(root_dir, mfc_name) -> Dict[int, List[Dict]]:
    trace_file_paths = []
    for fn in os.listdir(root_dir):
        if not fn.startswith(mfc_name):
            continue
        trace_file_paths.append(os.path.join(root_dir, fn))

    # The JSON files can be large, up to 2GB. Load them concurrently.
    semaphore = asyncio.Semaphore(8)
    tasks = [_load_events_async(file_path, semaphore) for file_path in trace_file_paths]

    all_events = {}
    for coro in tqdm.asyncio.tqdm(
        asyncio.as_completed(tasks), total=len(tasks), desc="Loading JSON files"
    ):
        try:
            events, pid = await coro
            all_events[pid] = events
        except Exception as e:
            print(f"Error loading JSON file: {e}")

    return all_events


def kernelStatFromTrace(root_dir: str, mfc_name: str):
    cache_file = os.path.join(root_dir, f"_cached_{mfc_name}.json")
    if os.path.exists(cache_file):
        logger.info(f'Loading trace JSON files of MFC "{mfc_name}" from cache...')
        with open(cache_file, "r") as f:
            all_events = json.load(f)
        all_events = {int(pid): v for pid, v in all_events.items()}
    else:
        if not any(fn.startswith(mfc_name) for fn in os.listdir(root_dir)):
            raise RuntimeError(
                f"No trace file found for the given MFC name: {mfc_name}."
            )

        load_json_tik = time.perf_counter()
        logger.info(
            f'Loading trace JSON files of MFC "{mfc_name}" concurrently from {root_dir}...'
        )
        all_events: Dict[int, List[Dict]] = asyncio.run(
            _load_all_events(root_dir, mfc_name)
        )
        logger.info(
            f"{len(all_events)} JSON files loaded. "
            f"Time consumption: {time.perf_counter() - load_json_tik:.2f} secs. "
            f"Processing..."
        )

        with open(cache_file, "w") as f:
            json.dump(all_events, f)

    # To remove the wait time from NCCL send/recv, collect send/recv kernel
    # annotations. These annotations look like "nccl:send 0->1". For each
    # annotation, find the execution time of its pair. The actual execution
    # time should be the minimum of the two.
    send_recv_annotations = {
        pid: [
            ev
            for ev in events
            if ev["cat"] == "gpu_user_annotation"
            and (
                ev["name"].startswith("nccl:recv")
                or ev["name"].startswith("nccl:send")
            )
        ]
        for pid, events in all_events.items()
    }
    for events in send_recv_annotations.values():
        events.sort(key=lambda x: x["ts"])

    send_recv_time = {pid: [] for pid, events in send_recv_annotations.items()}

    def _matches_next_sr(type_, src, dst):
        if type_ == "send":
            annot = send_recv_annotations[dst][0]
            m = re.match(r"nccl:recv (\d+)<-(\d+)", annot["name"])
            if not m:
                return False
            peer_dst, peer_src = map(int, m.groups())
            if peer_src != src or peer_dst != dst:
                return False
            return True
        else:
            assert type_ == "recv"
            annot = send_recv_annotations[src][0]
            m = re.match(r"nccl:send (\d+)->(\d+)", annot["name"])
            if not m:
                return False
            peer_src, peer_dst = map(int, m.groups())
            if peer_src != src or peer_dst != dst:
                return False
            return True

    def resolve_next_sr_time(pid):
        # Resolve send/recv times recursively, just like a Tetris game.
        annot = send_recv_annotations[pid][0]
        if annot["name"].startswith("nccl:send"):
            src, dst = map(
                int, re.match(r"nccl:send (\d+)->(\d+)", annot["name"]).groups()
            )
            assert src == pid, (src, pid)
            while not _matches_next_sr("send", src, dst):
                resolve_next_sr_time(dst)
        else:
            assert annot["name"].startswith("nccl:recv")
            dst, src = map(
                int, re.match(r"nccl:recv (\d+)<-(\d+)", annot["name"]).groups()
            )
            assert dst == pid, (dst, pid)
            while not _matches_next_sr("recv", src, dst):
                resolve_next_sr_time(src)
        ev1, ev2 = send_recv_annotations[src].pop(0), send_recv_annotations[dst].pop(0)
        dur = min(ev1["dur"], ev2["dur"])
        send_recv_time[src].append(dur)
        send_recv_time[dst].append(dur)

    for pid in tqdm.tqdm(all_events, desc="Resolving send/recv times"):
        while len(send_recv_annotations[pid]) > 0:
            resolve_next_sr_time(pid)

    kernel_events: Dict[int, List[KernelEventEntry]] = defaultdict(list)
    global_start = min(ev["ts"] for events in all_events.values() for ev in events)
    global_end = max(ev["ts"] for events in all_events.values() for ev in events)
    for pid in tqdm.tqdm(all_events, desc="Processing events"):
        for ev in all_events[pid]:
            if ev["cat"] != "kernel":
                continue
            assert ev["dur"] > 0, ev
            cat = CUDAKernelTimeCategory.from_name(ev["name"])
            if cat == CUDAKernelTimeCategory.P2P_COMM:
                assert len(send_recv_time[pid]) > 0
                ev["dur"] = send_recv_time[pid].pop(0)
            kernel_events[pid].append(
                KernelEventEntry(
                    ts=ev["ts"], tid=ev["tid"], dur=ev["dur"], category=cat
                )
            )
    assert all(len(times) == 0 for times in send_recv_time.values()), [
        len(times) == 0 for times in send_recv_time.values()
    ]
    for events in kernel_events.values():
        events.sort(key=lambda x: x.ts)

    x = None
    for events in tqdm.tqdm(
        kernel_events.values(),
        total=len(kernel_events),
        desc="Gathering kernel time stats for all processes...",
    ):
        stats = kernelStatFromEvents(events, global_start, global_end)
        if x is None:
            x = stats
        else:
            x = x + stats
    return x
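

# Hedged usage sketch (the directory and MFC name are hypothetical). Trace
# files are expected to be torch-profiler Chrome traces whose names start
# with the MFC name and contain "_r<pid>" before the ".json" suffix:
#
# stats = kernelStatFromTrace("/path/to/traces", "train_step")
# print(stats.gpu_average())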