Skip to content

Commit 446841a

Browse files
committed
Applied Linting
1 parent 16a9d66 commit 446841a

File tree

7 files changed

+34
-111
lines changed

7 files changed

+34
-111
lines changed

src/nvidia_resiliency_ext/attribution/mcp_integration/README.md

Lines changed: 0 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -96,12 +96,7 @@ order = global_registry.get_execution_order(["log_analyzer", "fr_analyzer", "com
9696

9797
**Tool Types**:
9898
1. **Module tools**: One per registered module (`log_analyzer`, `fr_analyzer`, etc.)
99-
<<<<<<< HEAD
10099
2. **Utility tools**: `status`, `get_result`
101-
=======
102-
2. **Pipeline tool**: `run_pipeline` for multi-module execution
103-
3. **Utility tools**: `status`, `get_result`
104-
>>>>>>> bfd729b (Add MCP integration and changes in `attribution` modules to run with MCP)
105100

106101
**Resource Pattern**:
107102
```
@@ -118,11 +113,7 @@ Example: attribution://log_analyzer/f47ac10b-58cc-4372-a567-0e02b2c3d479
118113
#### NVRxMCPClient
119114
- Connects to a single MCP server
120115
- Async context manager pattern
121-
<<<<<<< HEAD
122116
- Methods: `run_module()`, `get_result()`
123-
=======
124-
- Methods: `run_module()`, `run_pipeline()`, `get_result()`
125-
>>>>>>> bfd729b (Add MCP integration and changes in `attribution` modules to run with MCP)
126117

127118
#### MultiServerClient
128119
- Manages multiple MCP servers

src/nvidia_resiliency_ext/attribution/mcp_integration/mcp_client.py

Lines changed: 0 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -16,13 +16,7 @@
1616
from mcp.client.session import ClientSession
1717
from mcp.client.stdio import StdioServerParameters, stdio_client
1818

19-
<<<<<<< HEAD
2019
from nvidia_resiliency_ext.attribution.mcp_integration.registry import deserialize_result
21-
=======
22-
from nvidia_resiliency_ext.attribution.mcp_integration.registry import (
23-
deserialize_result,
24-
)
25-
>>>>>>> bfd729b (Add MCP integration and changes in `attribution` modules to run with MCP)
2620

2721
logger = logging.getLogger(__name__)
2822

@@ -145,11 +139,7 @@ async def run_module(self, module_name: str, **kwargs) -> Dict[str, Any]:
145139
result_str = await self.call_tool(module_name, arguments)
146140
return deserialize_result(result_str)
147141

148-
<<<<<<< HEAD
149142
async def get_result(self, result_id: str) -> Dict[str, Any]:
150-
=======
151-
async def get_result(self, result_id: str) -> Dict[str, Any]:
152-
>>>>>>> bfd729b (Add MCP integration and changes in `attribution` modules to run with MCP)
153143
"""
154144
Retrieve a cached result by ID.
155145
@@ -213,11 +203,7 @@ class MultiServerClient:
213203

214204
def __init__(self):
215205
"""Initialize the multi-server client."""
216-
<<<<<<< HEAD
217206
self.servers: Dict[str, NVRxMCPClient] = {}
218-
=======
219-
self.servers: Dict[str, ClientSession] = {}
220-
>>>>>>> bfd729b (Add MCP integration and changes in `attribution` modules to run with MCP)
221207
self.module_to_server: Dict[str, str] = {}
222208

223209
def add_server(self, server_name: str, server_command: List[str]):
@@ -228,11 +214,7 @@ def add_server(self, server_name: str, server_command: List[str]):
228214
server_name: Name for the server
229215
server_command: Command to start the server
230216
"""
231-
<<<<<<< HEAD
232217
self.servers[server_name] = NVRxMCPClient(server_command)
233-
=======
234-
self.servers[server_name] = ClientSession(server_command)
235-
>>>>>>> bfd729b (Add MCP integration and changes in `attribution` modules to run with MCP)
236218

237219
async def connect_all(self):
238220
"""Connect to all registered servers."""

src/nvidia_resiliency_ext/attribution/mcp_integration/mcp_server.py

Lines changed: 0 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -9,10 +9,6 @@
99
import asyncio
1010
import json
1111
import logging
12-
<<<<<<< HEAD
13-
=======
14-
import uuid
15-
>>>>>>> bfd729b (Add MCP integration and changes in `attribution` modules to run with MCP)
1612
from typing import Any, Dict, List, Optional
1713

1814
from mcp.server import Server
@@ -202,11 +198,6 @@ async def _handle_module_execution(
202198
"""Execute a single attribution module."""
203199
# Apply default values from input schema
204200
arguments_with_defaults = self.registry.apply_defaults(module_name, arguments)
205-
<<<<<<< HEAD
206-
207-
=======
208-
209-
>>>>>>> bfd729b (Add MCP integration and changes in `attribution` modules to run with MCP)
210201
# Get or create module instance
211202
if module_name not in self.module_instances:
212203
# Convert arguments to argparse.Namespace
@@ -242,19 +233,11 @@ async def _handle_module_execution(
242233

243234
return [TextContent(type="text", text=serialize_result(response))]
244235

245-
<<<<<<< HEAD
246-
=======
247-
248-
>>>>>>> bfd729b (Add MCP integration and changes in `attribution` modules to run with MCP)
249236
async def run(self):
250237
"""Run the MCP server."""
251238
import os
252239

253-
<<<<<<< HEAD
254240
logger.info("Starting NVRX Attribution MCP Server")
255-
=======
256-
logger.info(f"Starting NVRX Attribution MCP Server")
257-
>>>>>>> bfd729b (Add MCP integration and changes in `attribution` modules to run with MCP)
258241
logger.info(f"Registered modules: {self.registry.list_modules()}, pid: {os.getpid()}")
259242

260243
async with stdio_server() as (read_stream, write_stream):
@@ -264,8 +247,4 @@ async def run(self):
264247

265248
def run_sync(self):
266249
"""Run the server synchronously."""
267-
<<<<<<< HEAD
268-
asyncio.run(self.run())
269-
=======
270250
asyncio.run(self.run())
271-
>>>>>>> bfd729b (Add MCP integration and changes in `attribution` modules to run with MCP)

src/nvidia_resiliency_ext/attribution/mcp_integration/module_definitions.py

Lines changed: 0 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -9,10 +9,6 @@
99
from nvidia_resiliency_ext.attribution.mcp_integration.registry import global_registry
1010
from nvidia_resiliency_ext.attribution.trace_analyzer.fr_attribution import CollectiveAnalyzer
1111

12-
<<<<<<< HEAD
13-
14-
=======
15-
>>>>>>> bfd729b (Add MCP integration and changes in `attribution` modules to run with MCP)
1612
def register_all_modules():
1713
"""Register all NVRX attribution modules with the global registry."""
1814

@@ -75,10 +71,6 @@ def register_all_modules():
7571
dependencies=[],
7672
)
7773

78-
<<<<<<< HEAD
79-
80-
=======
81-
>>>>>>> bfd729b (Add MCP integration and changes in `attribution` modules to run with MCP)
8274
def create_args_from_dict(module_name: str, config: dict) -> argparse.Namespace:
8375
"""
8476
Create an argparse.Namespace from a configuration dictionary.

src/nvidia_resiliency_ext/attribution/mcp_integration/registry.py

Lines changed: 0 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -4,18 +4,10 @@
44
"""
55

66
import hashlib
7-
<<<<<<< HEAD
87
import json
98
import logging
109
from dataclasses import asdict, dataclass, is_dataclass
1110
from typing import Any, Dict, List, Optional, Type
12-
=======
13-
import inspect
14-
import json
15-
import logging
16-
from dataclasses import asdict, dataclass, is_dataclass
17-
from typing import Any, Callable, Dict, List, Optional, Type
18-
>>>>>>> bfd729b (Add MCP integration and changes in `attribution` modules to run with MCP)
1911

2012
from nvidia_resiliency_ext.attribution.base import NVRxAttribution
2113

@@ -85,16 +77,13 @@ def register(
8577
)
8678
self._modules[name] = metadata
8779

88-
<<<<<<< HEAD
8980
def unregister(self, name: str):
9081
"""Unregister a module."""
9182
if name in self._modules:
9283
del self._modules[name]
9384
else:
9485
raise ValueError(f"Module '{name}' not registered")
9586

96-
=======
97-
>>>>>>> bfd729b (Add MCP integration and changes in `attribution` modules to run with MCP)
9887
def get_module_metadata(self, name: str) -> Optional[ModuleMetadata]:
9988
"""Get metadata for a registered module."""
10089
return self._modules.get(name)
@@ -127,7 +116,6 @@ def apply_defaults(self, module_name: str, arguments: Dict[str, Any]) -> Dict[st
127116
metadata = self._modules.get(module_name)
128117
if not metadata:
129118
return arguments
130-
<<<<<<< HEAD
131119

132120
# Create a copy to avoid modifying the original
133121
result = dict(arguments)
@@ -136,26 +124,11 @@ def apply_defaults(self, module_name: str, arguments: Dict[str, Any]) -> Dict[st
136124
input_schema = metadata.input_schema
137125
properties = input_schema.get("properties", {})
138126

139-
=======
140-
141-
# Create a copy to avoid modifying the original
142-
result = dict(arguments)
143-
144-
# Get the properties from the input schema
145-
input_schema = metadata.input_schema
146-
properties = input_schema.get("properties", {})
147-
148-
>>>>>>> bfd729b (Add MCP integration and changes in `attribution` modules to run with MCP)
149127
# Apply defaults for missing arguments
150128
for param_name, param_schema in properties.items():
151129
if param_name not in result and "default" in param_schema:
152130
result[param_name] = param_schema["default"]
153131
logger.debug(f"Applied default for {param_name}: {param_schema['default']}")
154-
<<<<<<< HEAD
155-
156-
=======
157-
158-
>>>>>>> bfd729b (Add MCP integration and changes in `attribution` modules to run with MCP)
159132
return result
160133

161134
def cache_result(self, module_name: str, arguments: Dict[str, Any], result: Any):

src/nvidia_resiliency_ext/attribution/mcp_integration/server_launcher.py

Lines changed: 0 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -57,13 +57,8 @@ def main():
5757
all_modules = global_registry.list_modules()
5858
for module in list(all_modules):
5959
if module not in args.modules:
60-
<<<<<<< HEAD
6160
global_registry.unregister(module)
6261
logger.info(f"Unregistered module: {module}")
63-
=======
64-
# Remove from registry (simplified - in production, use proper filtering)
65-
logger.info(f"Skipping module: {module}")
66-
>>>>>>> bfd729b (Add MCP integration and changes in `attribution` modules to run with MCP)
6762
logger.info(f"Enabled modules: {args.modules}")
6863
else:
6964
logger.info(f"Enabled modules: {global_registry.list_modules()}")

src/nvidia_resiliency_ext/attribution/trace_analyzer/fr_attribution.py

Lines changed: 34 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@
1414
from collections import Counter, defaultdict
1515
from dataclasses import dataclass
1616
from pathlib import Path
17-
from typing import Dict, List, Tuple
17+
from typing import Dict, List, Tuple
1818

1919
from nvidia_resiliency_ext.attribution.base import AttributionState, NVRxAttribution
2020
from nvidia_resiliency_ext.attribution.utils import capture_logs
@@ -390,20 +390,22 @@ def group_collectives_by_windows(self):
390390
)
391391
already_participated = pg_window_participants[pg_window_key] & ranks_with_current_pg
392392
previous_participants = pg_window_participants[pg_window_key]
393-
393+
394394
has_previous_participants = len(previous_participants) > 0
395395
has_significant_new_ranks = len(ranks_with_current_pg - previous_participants) >= 2
396-
396+
397397
# Create new window if:
398398
# 1. Some ranks have already participated (same ranks coming back), OR
399399
# 2. We have previous participants and mostly/completely new ranks (different batch)
400400
should_create_new_window = False
401-
401+
402402
if current_pg not in pgs_with_active_ranks_last_iter:
403403
# PG was inactive - check if we need a new window
404-
if already_participated or (has_previous_participants and has_significant_new_ranks):
404+
if already_participated or (
405+
has_previous_participants and has_significant_new_ranks
406+
):
405407
should_create_new_window = True
406-
408+
407409
if should_create_new_window:
408410
# We're starting a new window/phase
409411
pg_window_counter[current_pg] += 1
@@ -554,8 +556,8 @@ def matching_collectives_per_process_group(collective_group):
554556
if c.state != 'scheduled':
555557
continue
556558
rank_counts['appeared'].append(c.file_id)
557-
# if get_correct_seq_id(c) <= max_completed_collective_seq_id:
558-
# rank_counts['mismatched'].append(c.file_id)
559+
# if get_correct_seq_id(c) <= max_completed_collective_seq_id:
560+
# rank_counts['mismatched'].append(c.file_id)
559561
appeared_rank_counts = Counter(rank_counts['appeared'])
560562
# Ranks with less number of enqueued collectives than max_enqueued_collective_seq_id -> host not making expected progress
561563
for rank_id in self.pg_configs[process_group]['ranks']:
@@ -717,24 +719,34 @@ def get_correct_seq_id(collective):
717719
for key, collective_group in self.collective_groups.items():
718720
logger.debug(f"key: {key}, collective_group: {collective_group}")
719721
matching_collectives_per_process_group((key, collective_group))
720-
722+
721723
# Cross-window matching: if the same PG has missing ranks in different windows,
722724
# try to match them across windows
723-
pg_all_windows = defaultdict(list) # pg_id -> list of (window_idx, identified_ranks, missing_ranks)
724-
725+
pg_all_windows = defaultdict(
726+
list
727+
) # pg_id -> list of (window_idx, identified_ranks, missing_ranks)
728+
725729
for pg_id, entries in missing_pg.items():
726730
for entry in entries:
727731
# entry format: (pg_id, pg_desc, op_type, size, dtype, total_nranks, identified_ranks, missing_ranks)
728732
pg_desc = entry[1] # e.g., "default_pg,0" or "default_pg,1"
729733
identified_ranks_str = entry[6]
730734
missing_ranks_str = entry[7]
731-
732-
identified_ranks = set(map(int, identified_ranks_str.split(','))) if identified_ranks_str else set()
733-
missing_ranks = set(map(int, missing_ranks_str.split(','))) if missing_ranks_str else set()
734-
735+
736+
identified_ranks = (
737+
set(map(int, identified_ranks_str.split(',')))
738+
if identified_ranks_str
739+
else set()
740+
)
741+
missing_ranks = (
742+
set(map(int, missing_ranks_str.split(','))) if missing_ranks_str else set()
743+
)
744+
735745
window_idx = int(pg_desc.split(',')[-1]) if ',' in pg_desc else 0
736-
pg_all_windows[pg_id].append((window_idx, identified_ranks, missing_ranks, entry))
737-
746+
pg_all_windows[pg_id].append(
747+
(window_idx, identified_ranks, missing_ranks, entry)
748+
)
749+
738750
# For each PG with multiple windows, try to match missing ranks across windows
739751
merged_missing_pg = defaultdict(list)
740752
for pg_id, windows_data in pg_all_windows.items():
@@ -743,19 +755,19 @@ def get_correct_seq_id(collective):
743755
for _, _, _, entry in windows_data:
744756
merged_missing_pg[pg_id].append(entry)
745757
continue
746-
758+
747759
# Multiple windows for this PG - try to match across windows
748760
all_identified = set()
749761
all_missing = set()
750762
representative_entry = windows_data[0][3] # Use first window's entry as template
751-
763+
752764
for window_idx, identified, missing, entry in windows_data:
753765
all_identified.update(identified)
754766
all_missing.update(missing)
755-
767+
756768
# Ranks that are identified in at least one window should not be considered missing
757769
truly_missing = all_missing - all_identified
758-
770+
759771
if truly_missing:
760772
# Create merged entry with truly missing ranks
761773
merged_entry = list(representative_entry)
@@ -767,7 +779,7 @@ def get_correct_seq_id(collective):
767779
# No truly missing ranks after cross-window matching
768780
# Don't add to merged_missing_pg (it's complete now)
769781
pass
770-
782+
771783
return completed_pg, merged_missing_pg
772784

773785
completed_pg, missing_pg = match_collectives()
@@ -944,7 +956,6 @@ def find_valid_paths(graph, start_node, visited):
944956
logger.debug(f"unique_paths: {unique_paths}")
945957
return grouped_pgs
946958

947-
948959
def process_file(self, filepath: str):
949960
"""
950961
Process a single file to extract collective operations and other metadata

0 commit comments

Comments
 (0)