@@ -168,17 +168,33 @@ std::vector<std::unique_ptr<ComputeCapability>> GetCapability::Execute() {
168
168
auto connected_clusters = GetConnectedClusters (graph_viewer_, ng_clusters);
169
169
170
170
int no_of_clusters = 0 ;
171
+ std::vector<NodeIndex> prev_cluster;
172
+ bool try_next_cluster = false ;
171
173
172
174
for (auto this_cluster : connected_clusters) {
175
+ bool omit_subgraph = false ;
176
+ if (try_next_cluster) {
177
+ // no need to check previous cluster
178
+ for (auto idx : prev_cluster) {
179
+ if ((std::find (this_cluster.begin (), this_cluster.end (), idx)) == this_cluster.end ()) {
180
+ this_cluster.emplace_back (idx);
181
+ }
182
+ }
183
+ try_next_cluster = false ;
184
+ }
185
+
173
186
// If subgraph has less then three, graph is considered trivial unless its an epctx cluster
174
- if (this_cluster.size () < 3 ) {
187
+ if (!try_next_cluster && this_cluster.size () < 3 ) {
175
188
bool is_epctx_node = false ;
176
189
for (auto node_idx : this_cluster) {
177
190
if (graph_viewer_.GetNode (node_idx)->OpType () == " EPContext" )
178
191
is_epctx_node = true ;
179
192
}
180
- if (!is_epctx_node)
181
- continue ;
193
+ if (!is_epctx_node) {
194
+ omit_subgraph = true ;
195
+ prev_cluster = this_cluster;
196
+ try_next_cluster = true ;
197
+ }
182
198
}
183
199
184
200
std::vector<std::string> cluster_graph_inputs, cluster_inputs, cluster_outputs;
@@ -190,7 +206,7 @@ std::vector<std::unique_ptr<ComputeCapability>> GetCapability::Execute() {
190
206
cluster_inputs,
191
207
cluster_outputs);
192
208
193
- bool omit_subgraph = false ;
209
+
194
210
// Omitting zero dim subgraphs
195
211
for (auto index : this_cluster) {
196
212
const Node* node = graph_viewer_.GetNode (index);
0 commit comments