Skip to content

Commit

Permalink
Render symbolic shape specialization (#39)
Browse files Browse the repository at this point in the history
  • Loading branch information
ezyang authored May 20, 2024
1 parent 41830f5 commit d2505df
Show file tree
Hide file tree
Showing 4 changed files with 84 additions and 16 deletions.
24 changes: 20 additions & 4 deletions src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -112,6 +112,9 @@ pub fn parse_path(path: &PathBuf, config: ParseConfig) -> anyhow::Result<ParseOu
let mut metrics_index: CompilationMetricsIndex = FxIndexMap::default();
let stack_index: RefCell<StackIndex> = RefCell::new(FxHashMap::default());

let symbolic_shape_specialization_index: RefCell<SymbolicShapeSpecializationIndex> =
RefCell::new(FxHashMap::default());

// Store results in an output Vec<PathBuf, String>
let mut output: Vec<(PathBuf, String)> = Vec::new();

Expand Down Expand Up @@ -145,7 +148,12 @@ pub fn parse_path(path: &PathBuf, config: ParseConfig) -> anyhow::Result<ParseOu
})
.peekable();

let mut all_parsers = default_parsers(&tt, &stack_index);
let mut all_parsers = default_parsers(&tt);
all_parsers.push(Box::new(crate::parsers::CompilationMetricsParser {
tt: &tt,
stack_index: &stack_index,
symbolic_shape_specialization_index: &symbolic_shape_specialization_index,
})); // TODO: use own tt instances
all_parsers.extend(config.custom_parsers);

while let Some((lineno, line)) = iter.next() {
Expand All @@ -156,7 +164,7 @@ pub fn parse_path(path: &PathBuf, config: ParseConfig) -> anyhow::Result<ParseOu
let start = Instant::now();

let Some(caps) = re_glog.captures(&line) else {
eprintln!("Failed to parse glog prefix on line {}", lineno);
multi.suspend(|| eprintln!("Failed to parse glog prefix on line {}", lineno));
stats.fail_glog += 1;
continue;
};
Expand Down Expand Up @@ -274,11 +282,11 @@ pub fn parse_path(path: &PathBuf, config: ParseConfig) -> anyhow::Result<ParseOu
}
Err(err) => match parser.name() {
"dynamo_guards" => {
eprintln!("Failed to parse guards json: {}", err);
multi.suspend(|| eprintln!("Failed to parse guards json: {}", err));
stats.fail_dynamo_guards_json += 1;
}
name => {
eprintln!("Parser {name} failed: {err}");
multi.suspend(|| eprintln!("Parser {name} failed: {err}"));
stats.fail_parser += 1;
}
},
Expand All @@ -290,6 +298,14 @@ pub fn parse_path(path: &PathBuf, config: ParseConfig) -> anyhow::Result<ParseOu
unknown_stack_trie.insert(stack, None);
}

if let Some(specialization) = e.symbolic_shape_specialization {
symbolic_shape_specialization_index
.borrow_mut()
.entry(e.compile_id.clone())
.or_default()
.push(specialization);
}

if let Some(m) = e.compilation_metrics {
let compile_id_dir: PathBuf = e
.compile_id
Expand Down
38 changes: 26 additions & 12 deletions src/parsers.rs
Original file line number Diff line number Diff line change
Expand Up @@ -307,9 +307,16 @@ impl StructuredLogParser for LinkParser {
}
}

fn format_stack(stack: &StackSummary) -> String {
let mut trie = StackTrieNode::default();
trie.insert_no_terminal(stack.to_vec());
trie.fmt(None).unwrap()
}

pub struct CompilationMetricsParser<'t> {
tt: &'t TinyTemplate<'t>,
stack_index: &'t RefCell<StackIndex>,
pub tt: &'t TinyTemplate<'t>,
pub stack_index: &'t RefCell<StackIndex>,
pub symbolic_shape_specialization_index: &'t RefCell<SymbolicShapeSpecializationIndex>,
}
impl StructuredLogParser for CompilationMetricsParser<'_> {
fn name(&self) -> &'static str {
Expand Down Expand Up @@ -341,16 +348,27 @@ impl StructuredLogParser for CompilationMetricsParser<'_> {
.stack_index
.borrow()
.get(&cid)
.map_or("".to_string(), |stack| {
let mut trie = StackTrieNode::default();
trie.insert_no_terminal(stack.to_vec());
trie.fmt(None).unwrap()
});
.map_or("".to_string(), format_stack);
let specializations = self
.symbolic_shape_specialization_index
.borrow_mut()
.remove(&cid)
.unwrap_or(Vec::new())
.drain(..)
.map(|spec| SymbolicShapeSpecializationContext {
symbol: spec.symbol.unwrap_or("".to_string()),
sources: spec.sources.unwrap_or(Vec::new()),
value: spec.value.unwrap_or("".to_string()),
user_stack_html: format_stack(&spec.user_stack.unwrap_or(Vec::new())),
stack_html: format_stack(&spec.stack.unwrap_or(Vec::new())),
})
.collect();
let context = CompilationMetricsContext {
css: crate::CSS,
m: &m,
compile_id: id,
stack_html: stack_html,
symbolic_shape_specializations: specializations,
};
let output = self.tt.render(&filename, &context)?;
simple_file_output(&filename, lineno, compile_id, &output)
Expand Down Expand Up @@ -401,10 +419,7 @@ impl StructuredLogParser for AOTAutogradBackwardCompilationMetricsParser<'_> {
}

// Register your parser here
pub fn default_parsers<'t>(
tt: &'t TinyTemplate<'t>,
stack_index: &'t RefCell<StackIndex>,
) -> Vec<Box<dyn StructuredLogParser + 't>> {
pub fn default_parsers<'t>(tt: &'t TinyTemplate<'t>) -> Vec<Box<dyn StructuredLogParser + 't>> {
// We need to use Box wrappers here because vecs in Rust need to have known size
let result: Vec<Box<dyn StructuredLogParser>> = vec![
Box::new(SentinelFileParser::new("optimize_ddp_split_graph", |e| {
Expand All @@ -430,7 +445,6 @@ pub fn default_parsers<'t>(
Box::new(DynamoGuardParser { tt }),
Box::new(InductorOutputCodeParser),
Box::new(OptimizeDdpSplitChildParser),
Box::new(CompilationMetricsParser { tt, stack_index }), // TODO: use own tt instances
Box::new(AOTAutogradBackwardCompilationMetricsParser { tt }), // TODO: use own tt instances
Box::new(LinkParser),
];
Expand Down
15 changes: 15 additions & 0 deletions src/templates.rs
Original file line number Diff line number Diff line change
Expand Up @@ -259,6 +259,21 @@ pub static TEMPLATE_COMPILATION_METRICS: &str = r#"
{{ for op in m.non_compliant_ops }}
<li> <code> {op} </code> </li>
{{ endfor }}
<h2>Symbolic shape specializations</h2>
<table>
<tr>
<th>Sym</th> <th>Source(s)</th> <th>Value</th> <th>User stack</th> <th>Framework stack</th>
</tr>
{{ for spec in symbolic_shape_specializations }}
<tr>
<td>{spec.symbol}</td>
<td>{{ for source in spec.sources }}{source}<br>{{ endfor }}</td>
<td>{spec.value}</td>
<td>{spec.user_stack_html | format_unescaped}</td>
<td>{spec.stack_html | format_unescaped}</td>
</tr>
{{ endfor }}
</table>
</body>
</html>
"#;
Expand Down
23 changes: 23 additions & 0 deletions src/types.rs
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,8 @@ use std::sync::Mutex;
pub type ParseOutput = Vec<(PathBuf, String)>;
pub type CompilationMetricsIndex = FxIndexMap<Option<CompileId>, Vec<CompilationMetricsMetadata>>;
pub type StackIndex = FxHashMap<Option<CompileId>, StackSummary>; // NB: attempt is always 0 here
pub type SymbolicShapeSpecializationIndex =
FxHashMap<Option<CompileId>, Vec<SymbolicShapeSpecializationMetadata>>;

pub type FxIndexMap<K, V> = IndexMap<K, V, BuildHasherDefault<FxHasher>>;

Expand Down Expand Up @@ -265,6 +267,16 @@ pub struct AOTAutogradBackwardCompilationMetricsMetadata {
pub fail_reason: Option<String>,
}

#[derive(Debug, Deserialize, Serialize)]
pub struct SymbolicShapeSpecializationMetadata {
pub symbol: Option<String>,
pub sources: Option<Vec<String>>,
pub value: Option<String>,
pub reason: Option<String>,
pub stack: Option<StackSummary>,
pub user_stack: Option<StackSummary>,
}

#[derive(Debug, Serialize)]
pub struct AOTAutogradBackwardCompilationMetricsContext<'e> {
pub m: &'e AOTAutogradBackwardCompilationMetricsMetadata,
Expand All @@ -278,6 +290,7 @@ pub struct CompilationMetricsContext<'e> {
pub css: &'static str,
pub compile_id: String,
pub stack_html: String,
pub symbolic_shape_specializations: Vec<SymbolicShapeSpecializationContext>,
}

#[derive(Debug, Serialize)]
Expand Down Expand Up @@ -360,6 +373,7 @@ pub struct Envelope {
Option<AOTAutogradBackwardCompilationMetricsMetadata>,
pub graph_dump: Option<GraphDumpMetadata>,
pub link: Option<LinkMetadata>,
pub symbolic_shape_specialization: Option<SymbolicShapeSpecializationMetadata>,
}

#[derive(Debug, Deserialize, Serialize)]
Expand All @@ -385,3 +399,12 @@ pub struct IndexContext {
pub num_breaks: usize,
pub custom_header_html: String,
}

#[derive(Debug, Serialize)]
pub struct SymbolicShapeSpecializationContext {
pub symbol: String,
pub sources: Vec<String>,
pub value: String,
pub user_stack_html: String,
pub stack_html: String,
}

0 comments on commit d2505df

Please sign in to comment.