# PyTorch with WireGuard

By Toshihito Kikuchi

[PyTorch DDP](https://docs.pytorch.org/tutorials/intermediate/ddp_tutorial.html) (= Distributed Data Parallel) is becoming the industry standard for parallelizing model training across multiple GPUs and machines. Since one of our missions is to provide scalable GPU compute, it’s essential to support PyTorch DDP on Kinesis Network.

Does Kinesis Network support PyTorch DDP? Yes, but it isn't quite "plug-and-play" yet; it requires a bit of manual tuning to get everything running smoothly. We will fully integrate it soon, but for now you need some special configuration to run it. In this article, I’d like to explain why it’s a little bit tricky and what we're implementing behind the scenes.

### The challenge in PyTorch networking

Okay, I know you have a model to train. You wrote a Python script using [DistributedDataParallel](https://docs.pytorch.org/docs/stable/generated/torch.nn.parallel.DistributedDataParallel.html), something like this:

```python
import os
import torch
import torch.nn as nn
import torch.optim as optim
import torch.distributed as dist
from torch.nn.parallel import DistributedDataParallel as DDP

def setup():
    # Initialize the process group
    # NCCL is the standard backend for NVIDIA GPUs
    dist.init_process_group(backend="nccl")
    torch.cuda.set_device(int(os.environ["LOCAL_RANK"]))

def cleanup():
    dist.destroy_process_group()

def run_training():
    setup()

    local_rank = int(os.environ["LOCAL_RANK"])
    steps = int(os.environ["STEPS"])
    rank = int(os.environ["RANK"])
    device = torch.device(f"cuda:{local_rank}")

    # 1. Define a tiny model
    model = nn.Linear(10, 10).to(device)
    model = DDP(model, device_ids=[local_rank])

    # 2. Setup Loss and Optimizer
    loss_fn = nn.MSELoss()
    optimizer = optim.SGD(model.parameters(), lr=0.001)

    # 3. Simple training loop (Synthetic data)
    print(f"[Rank {rank}] Starting training...")

    for step in range(steps):
        # Create random data on the fly
        inputs = torch.randn(20, 10).to(device)
        labels = torch.randn(20, 10).to(device)

        optimizer.zero_grad()
        outputs = model(inputs)
        loss = loss_fn(outputs, labels)
        loss.backward()
        optimizer.step()

        if step % 10 == 0 and rank == 0:
            print(f"Step {step} | Loss: {loss.item():.4f}")

    print(f"[Rank {rank}] Training complete.")
    cleanup()

if __name__ == "__main__":
    run_training()
```

We usually use `torchrun` to kick off a distributed training process. Below is the output of running a single-node, single-GPU training inside a container with detailed logging enabled.

```log
root@2f3d44d8b0ea:/app# NCCL_DEBUG=INFO LOGLEVEL=DEBUG STEPS=30 torchrun \
  --nnodes=1 --nproc_per_node=1 --node_rank=0 \
  --rdzv_id=test_job --rdzv_backend=c10d --rdzv_endpoint=127.0.0.1:29500 \
  train.py
I0406 17:13:28.244000 1265 torch/distributed/run.py:735] Using nproc_per_node=1.
I0406 17:13:28.340000 1265 torch/distributed/launcher/api.py:131] Using default numa options = None
I0406 17:13:28.340000 1265 torch/distributed/launcher/api.py:224] Starting elastic_operator with launch configs:
I0406 17:13:28.340000 1265 torch/distributed/launcher/api.py:224]   entrypoint               : train.py
I0406 17:13:28.340000 1265 torch/distributed/launcher/api.py:224]   min_nodes                : 1
I0406 17:13:28.340000 1265 torch/distributed/launcher/api.py:224]   max_nodes                : 1
I0406 17:13:28.340000 1265 torch/distributed/launcher/api.py:224]   nproc_per_node           : 1
I0406 17:13:28.340000 1265 torch/distributed/launcher/api.py:224]   run_id                   : test_job
I0406 17:13:28.340000 1265 torch/distributed/launcher/api.py:224]   rdzv_backend             : c10d
I0406 17:13:28.340000 1265 torch/distributed/launcher/api.py:224]   rdzv_endpoint            : 127.0.0.1:29500
I0406 17:13:28.340000 1265 torch/distributed/launcher/api.py:224]   rdzv_configs             : {'timeout': 900}
I0406 17:13:28.340000 1265 torch/distributed/launcher/api.py:224]   max_restarts             : 0
I0406 17:13:28.340000 1265 torch/distributed/launcher/api.py:224]   monitor_interval         : 0.1
I0406 17:13:28.340000 1265 torch/distributed/launcher/api.py:224]   log_dir                  : /tmp/torchelastic_z0qxtusm
I0406 17:13:28.340000 1265 torch/distributed/launcher/api.py:224]   metrics_cfg              : {}
I0406 17:13:28.340000 1265 torch/distributed/launcher/api.py:224]   event_log_handler        : null
I0406 17:13:28.340000 1265 torch/distributed/launcher/api.py:224]   numa_options             : None
I0406 17:13:28.340000 1265 torch/distributed/launcher/api.py:224]   signals_to_handle        : SIGTERM,SIGINT,SIGHUP,SIGQUIT
I0406 17:13:28.340000 1265 torch/distributed/launcher/api.py:224]   duplicate_stdout_filters : []
I0406 17:13:28.340000 1265 torch/distributed/launcher/api.py:224]   duplicate_stderr_filters : []
I0406 17:13:28.340000 1265 torch/distributed/launcher/api.py:224]
I0406 17:13:28.344000 1265 torch/distributed/elastic/agent/server/api.py:898] [default] starting workers for entrypoint: python3
I0406 17:13:28.345000 1265 torch/distributed/elastic/agent/server/api.py:693] [default] Rendezvous'ing worker group
I0406 17:13:28.481000 1265 torch/distributed/elastic/agent/server/api.py:539] [default] Rendezvous complete for workers. Result:
I0406 17:13:28.481000 1265 torch/distributed/elastic/agent/server/api.py:539]   restart_count=0
I0406 17:13:28.481000 1265 torch/distributed/elastic/agent/server/api.py:539]   master_addr=2f3d44d8b0ea
I0406 17:13:28.481000 1265 torch/distributed/elastic/agent/server/api.py:539]   master_port=43093
I0406 17:13:28.481000 1265 torch/distributed/elastic/agent/server/api.py:539]   group_rank=0
I0406 17:13:28.481000 1265 torch/distributed/elastic/agent/server/api.py:539]   group_world_size=1
I0406 17:13:28.481000 1265 torch/distributed/elastic/agent/server/api.py:539]   local_ranks=[0]
I0406 17:13:28.481000 1265 torch/distributed/elastic/agent/server/api.py:539]   role_ranks=[0]
I0406 17:13:28.481000 1265 torch/distributed/elastic/agent/server/api.py:539]   global_ranks=[0]
I0406 17:13:28.481000 1265 torch/distributed/elastic/agent/server/api.py:539]   role_world_sizes=[1]
I0406 17:13:28.481000 1265 torch/distributed/elastic/agent/server/api.py:539]   global_world_sizes=[1]
I0406 17:13:28.481000 1265 torch/distributed/elastic/agent/server/api.py:539]   event_log_handler=null
I0406 17:13:28.481000 1265 torch/distributed/elastic/agent/server/api.py:539]
I0406 17:13:28.482000 1265 torch/distributed/elastic/agent/server/api.py:701] [default] Starting worker group
I0406 17:13:28.482000 1265 torch/distributed/elastic/agent/server/local_elastic_agent.py:299] use_agent_store: True
I0406 17:13:28.482000 1265 torch/distributed/elastic/agent/server/local_elastic_agent.py:195] Environment variable 'TORCHELASTIC_ENABLE_FILE_TIMER' not found. Do not start FileTimerServer.
I0406 17:13:28.482000 1265 torch/distributed/elastic/agent/server/local_elastic_agent.py:239] Environment variable 'TORCHELASTIC_HEALTH_CHECK_PORT' not found. Do not start health check.
2f3d44d8b0ea:1297:1297 [0] NCCL INFO ENV/Plugin: Could not find: libnccl-env.so
2f3d44d8b0ea:1297:1297 [0] NCCL INFO Bootstrap: Using eth0:172.17.0.2<0>
2f3d44d8b0ea:1297:1297 [0] NCCL INFO cudaDriverVersion 12080
2f3d44d8b0ea:1297:1297 [0] NCCL INFO NCCL version 2.28.9+cuda12.9
2f3d44d8b0ea:1297:1297 [0] NCCL INFO Comm config Blocking set to 1
2f3d44d8b0ea:1297:1297 [0] NCCL INFO NET/Plugin: Could not find: libnccl-net.so
2f3d44d8b0ea:1297:1297 [0] NCCL INFO Failed to open libibverbs.so[.1]
2f3d44d8b0ea:1297:1297 [0] NCCL INFO transport/net_ib.cc:852 -> 3
2f3d44d8b0ea:1297:1297 [0] NCCL INFO Failed to initialize NET plugin IB
2f3d44d8b0ea:1297:1297 [0] NCCL INFO NET/Socket : Using [0]eth0:172.17.0.2<0> [1]wg0:10.100.10.1<0>
2f3d44d8b0ea:1297:1297 [0] NCCL INFO Initialized NET plugin Socket
2f3d44d8b0ea:1297:1297 [0] NCCL INFO Assigned NET plugin Socket to comm
2f3d44d8b0ea:1297:1297 [0] NCCL INFO Using network Socket
2f3d44d8b0ea:1297:1297 [0] NCCL INFO ncclCommInitRankConfig comm 0x3cd2a790 rank 0 nranks 1 cudaDev 0 nvmlDev 0 busId 70 commId 0xbda3ad566558fee7 - Init START
2f3d44d8b0ea:1297:1297 [0] NCCL INFO RAS client listening socket at ::1<28028>
2f3d44d8b0ea:1297:1297 [0] NCCL INFO Bootstrap timings total 0.001224 (create 0.000070, send 0.000246, recv 0.000273, ring 0.000001, delay 0.000000)
2f3d44d8b0ea:1297:1297 [0] NCCL INFO NCCL_IGNORE_DISABLED_P2P set by environment to 1.
2f3d44d8b0ea:1297:1297 [0] NCCL INFO ncclTopoGetCpuAffinity: Affinity for GPU 0 is empty, ignoring. (GPU affinity =  ; CPU affinity = 0-27).
2f3d44d8b0ea:1297:1297 [0] NCCL INFO comm 0x3cd2a790 rank 0 nRanks 1 nNodes 1 localRanks 1 localRank 0 MNNVL 0
2f3d44d8b0ea:1297:1297 [0] NCCL INFO Channel 00/64 : 0
2f3d44d8b0ea:1297:1297 [0] NCCL INFO Channel 01/64 : 0
2f3d44d8b0ea:1297:1297 [0] NCCL INFO Channel 02/64 : 0
...snip...
2f3d44d8b0ea:1297:1297 [0] NCCL INFO Channel 63/64 : 0
2f3d44d8b0ea:1297:1297 [0] NCCL INFO Trees [0] -1/-1/-1->0->-1 [1] -1/-1/-1->0->-1 [2] -1/-1/-1->0->-1 [3] -1/-1/-1->0->-1 [4] -1/-1/-1->0->-1 [5] -1/-1/-1->0->-1 [6] -1/-1/-1->0->-1 [7] -1/-1/-1->0->-1 [8] -1/-1/-1->0->-1 [9] -1/-1/-1->0->-1 [10] -1/-1/-1->0->-1 [11] -1/-1/-1->0->-1 [12] -1/-1/-1->0->-1 [13] -1/-1/-1->0->-1 [14] -1/-1/-1->0->-1 [15] -1/-1/-1->0->-1 [16] -1/-1/-1->0->-1 [17] -1/-1/-1->0->-1 [18] -1/-1/-1->0->-1 [19] -1/-1/-1->0->-1 [20] -1/-1/-1->0->-1 [21] -1/-1/-1->0->-1 [22] -1/-1/-1->0->-1 [23] -1/-1/-1->0->-1 [24] -1/-1/-1->0->-1 [25] -1/-1/-1->0->-1 [26] -1/-1/-1->0->-1 [27] -1/-1/-1->0->-1 [28] -1/-1/-1->0->-1 [29] -1/-1/-1->0->-1 [30] -1/-1/-1->0->-1 [31] -1/-1/-1->0->-1 [32] -1/-1/-1->0->-1 [33] -1/-1/-1->0->-1 [34] -1/-1/-1->0->-1 [35] -1/-1/-1->0->-1 [36] -1/-1/-1->0->-1 [37] -1/-1/-1->0->-1 [38] -1/-1/-1->0->-1 [39] -1/-1/-1->0->-1 [40] -1/-1/-1->0->-1 [41] -1/-1/-1->0->-1 [42] -1/-1/-1->0->-1 [43] -1/-1/-1->0->-1 [44] -1/-1/-1->0->-1 [45] -1/-1/-1->0->-1 [46] -1/-1/-1->0->-1 [47
2f3d44d8b0ea:1297:1297 [0] NCCL INFO P2P Chunksize set to 524288
2f3d44d8b0ea:1297:1297 [0] NCCL INFO PROFILER/Plugin: Could not find: libnccl-profiler.so
2f3d44d8b0ea:1297:1297 [0] NCCL INFO Check P2P Type isAllDirectP2p 1 directMode 0 isAllCudaP2p 1
2f3d44d8b0ea:1297:1331 [0] NCCL INFO [Proxy Service] Device 0 CPU core 10
2f3d44d8b0ea:1297:1332 [0] NCCL INFO [Proxy Service UDS] Device 0 CPU core 9
2f3d44d8b0ea:1297:1297 [0] NCCL INFO TUNER/Plugin: Could not find: libnccl-tuner.so
2f3d44d8b0ea:1297:1297 [0] NCCL INFO 64 coll channels, 64 collnet channels, 0 nvls channels, 64 p2p channels, 64 p2p channels per peer
2f3d44d8b0ea:1297:1297 [0] NCCL INFO CC Off, workFifoBytes 1048576
2f3d44d8b0ea:1297:1297 [0] NCCL INFO ncclCommInitRankConfig comm 0x3cd2a790 rank 0 nranks 1 cudaDev 0 nvmlDev 0 busId 70 commId 0xbda3ad566558fee7 - Init COMPLETE
2f3d44d8b0ea:1297:1297 [0] NCCL INFO Init timings - ncclCommInitRankConfig: rank 0 nranks 1 total 0.17 (kernels 0.13, alloc 0.00, bootstrap 0.00, allgathers 0.00, topo 0.00, graphs 0.00, connections 0.02, rest 0.00)
[Rank 0] Starting training...
Step 0 | Loss: 1.1732
Step 10 | Loss: 1.1904
Step 20 | Loss: 1.5663
[Rank 0] Training complete.
2f3d44d8b0ea:1297:1297 [0] NCCL INFO comm 0x3cd2a790 rank 0 nranks 1 cudaDev 0 busId 70 - Destroy COMPLETE
2f3d44d8b0ea:1297:1297 [0] NCCL INFO ENV/Plugin: Closing env plugin ncclEnvDefault
I0406 17:13:32.492000 1265 torch/distributed/elastic/agent/server/api.py:917] [default] worker group successfully finished. Waiting 300 seconds for other agents to finish.
I0406 17:13:32.493000 1265 torch/distributed/elastic/agent/server/api.py:970] Local worker group finished (WorkerState.SUCCEEDED). Waiting 300 seconds for other agents to finish
I0406 17:13:32.493000 1265 torch/distributed/elastic/agent/server/api.py:984] Done waiting for other agents. Elapsed: 0.0004515647888183594 seconds
root@2f3d44d8b0ea:/app#
```

It looks pretty straightforward. Can we just specify `--nnodes` and `--node_rank` accordingly to run this on multiple GPU nodes? Well, unfortunately not; that’s why I’m writing this article.

The challenge is networking. In the example above, I specified `--rdzv_endpoint=127.0.0.1:29500`. Does this mean each node communicates with this endpoint during the training process, as in a hub-and-spoke topology? The answer is no. As the name implies, it’s a rendezvous point: every node meets there to retrieve the actual endpoint to communicate with. In the output above you can see `master_addr=2f3d44d8b0ea` and `master_port=43093`; that is the actual endpoint each node talks to. The master port, 43093 in this case, is dynamically assigned, which means we cannot predict which port needs to be opened. This is a problem in a Docker environment, because ports have to be published when the container is created, and at that point we don't know which port will be chosen. Should we use the host network with `--network=host`, or publish a huge port range like `-p 1000-65535:1000-65535`? Clearly, that design is far from ideal. We want our containers to be contained as much as possible.
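To make this concrete, a two-node launch would look roughly like the sketch below, where `<node0-address>` is a placeholder for an address at which both nodes can reach node 0. Every node points at the same rendezvous endpoint, and the rendezvous then hands out the master address and the dynamically assigned master port used for the rest of the job.

```
# On node 0, which also hosts the rendezvous endpoint
torchrun --nnodes=2 --nproc_per_node=1 --node_rank=0 \
  --rdzv_id=test_job --rdzv_backend=c10d --rdzv_endpoint=<node0-address>:29500 \
  train.py

# On node 1
torchrun --nnodes=2 --nproc_per_node=1 --node_rank=1 \
  --rdzv_id=test_job --rdzv_backend=c10d --rdzv_endpoint=<node0-address>:29500 \
  train.py
```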

In [the earlier article](https://docs.kinesis.network/blog/reaching-out-to-home-computers), I wrote that we leverage [WireGuard](https://www.wireguard.com/) to bring home computers onto the network. WireGuard can help us out of this PyTorch situation too: once we set up a WireGuard network across the Docker containers, they can communicate with one another through a single UDP port. This journey starts here.

### WireGuard Setup

Setting up a WireGuard network is pretty easy, so I'll skip the detailed steps in this article. Basically you need two things: 1) create a conf file such as /etc/wireguard/wg0.conf, and 2) run `wg-quick up wg0`. In addition, when you create the Docker container, you need to specify `--cap-add=NET_ADMIN` and publish a UDP port, and you need to install the packages `wireguard-tools iptables iproute2` inside the container. The WireGuard driver lives on the host as part of the Linux kernel, but you still need the client tools to use it.
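For reference, here is a minimal sketch of what that setup looks like. The listen port, keys, and peer entries below are placeholders; the actual peer list depends on your topology.

```
# When creating the container (sketch):
#   docker run ... --cap-add=NET_ADMIN -p 51820:51820/udp ...
# Inside the container:
#   apt-get install -y wireguard-tools iptables iproute2

# /etc/wireguard/wg0.conf
[Interface]
Address = 10.100.10.1/24
ListenPort = 51820
PrivateKey = <this-node-private-key>

[Peer]
PublicKey = <peer-public-key>
AllowedIPs = 10.100.10.2/32
Endpoint = <peer-host>:51820
PersistentKeepalive = 25

# Then bring the interface up:
#   wg-quick up wg0
```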

Once the setup is done, you will see a virtual interface like `wg0`:

```log
root@2f3d44d8b0ea:/app# ip -4 a
1: lo: <LOOPBACK,UP,LOWER_UP> mtu 65536 qdisc noqueue state UNKNOWN group default qlen 1000
    inet 127.0.0.1/8 scope host lo
       valid_lft forever preferred_lft forever
2: eth0@if32: <BROADCAST,MULTICAST,UP,LOWER_UP> mtu 1500 qdisc noqueue state UP group default  link-netnsid 0
    inet 172.17.0.2/16 brd 172.17.255.255 scope global eth0
       valid_lft forever preferred_lft forever
3: wg0: <POINTOPOINT,NOARP,UP,LOWER_UP> mtu 1420 qdisc noqueue state UNKNOWN group default qlen 1000
    inet 10.100.10.1/24 scope global wg0
       valid_lft forever preferred_lft forever
```

### Debugging with GDB

Let's run `torchrun` with the WireGuard network address 10.100.10.1 as the rendezvous point.

```log
root@2f3d44d8b0ea:/app# NCCL_DEBUG=INFO LOGLEVEL=DEBUG STEPS=30 torchrun \
  --nnodes=1 --nproc_per_node=1 --node_rank=0 \
  --rdzv_id=test_job --rdzv_backend=c10d --rdzv_endpoint=10.100.10.1:29500 \
  train.py
I0406 17:54:35.221000 1339 torch/distributed/run.py:735] Using nproc_per_node=1.
I0406 17:54:35.316000 1339 torch/distributed/launcher/api.py:131] Using default numa options = None
I0406 17:54:35.317000 1339 torch/distributed/launcher/api.py:224] Starting elastic_operator with launch configs:
I0406 17:54:35.317000 1339 torch/distributed/launcher/api.py:224]   entrypoint               : train.py
I0406 17:54:35.317000 1339 torch/distributed/launcher/api.py:224]   min_nodes                : 1
I0406 17:54:35.317000 1339 torch/distributed/launcher/api.py:224]   max_nodes                : 1
I0406 17:54:35.317000 1339 torch/distributed/launcher/api.py:224]   nproc_per_node           : 1
I0406 17:54:35.317000 1339 torch/distributed/launcher/api.py:224]   run_id                   : test_job
I0406 17:54:35.317000 1339 torch/distributed/launcher/api.py:224]   rdzv_backend             : c10d
I0406 17:54:35.317000 1339 torch/distributed/launcher/api.py:224]   rdzv_endpoint            : 10.100.10.1:29500
I0406 17:54:35.317000 1339 torch/distributed/launcher/api.py:224]   rdzv_configs             : {'timeout': 900}
I0406 17:54:35.317000 1339 torch/distributed/launcher/api.py:224]   max_restarts             : 0
I0406 17:54:35.317000 1339 torch/distributed/launcher/api.py:224]   monitor_interval         : 0.1
I0406 17:54:35.317000 1339 torch/distributed/launcher/api.py:224]   log_dir                  : /tmp/torchelastic_q6zcchrr
I0406 17:54:35.317000 1339 torch/distributed/launcher/api.py:224]   metrics_cfg              : {}
I0406 17:54:35.317000 1339 torch/distributed/launcher/api.py:224]   event_log_handler        : null
I0406 17:54:35.317000 1339 torch/distributed/launcher/api.py:224]   numa_options             : None
I0406 17:54:35.317000 1339 torch/distributed/launcher/api.py:224]   signals_to_handle        : SIGTERM,SIGINT,SIGHUP,SIGQUIT
I0406 17:54:35.317000 1339 torch/distributed/launcher/api.py:224]   duplicate_stdout_filters : []
I0406 17:54:35.317000 1339 torch/distributed/launcher/api.py:224]   duplicate_stderr_filters : []
I0406 17:54:35.317000 1339 torch/distributed/launcher/api.py:224]
[E406 17:55:24.823650196 socket.cpp:1028] [c10d] The client socket has timed out after 60000ms while trying to connect to (10.100.10.1, 29500).
[W406 17:55:24.824122039 TCPStore.cpp:340] [c10d] TCP client failed to connect/validate to host 10.100.10.1:29500 - retrying (try=0, timeout=60000ms, delay=46179ms): The client socket has timed out after 60000ms while trying to connect to (10.100.10.1, 29500).
Exception raised from throwTimeoutError at /pytorch/torch/csrc/distributed/c10d/socket.cpp:1030 (most recent call first):
frame #0: c10::Error::Error(c10::SourceLocation, std::__cxx11::basic_string<char, std::char_traits<char>, std::allocator<char> >) + 0x9d (0x71c56a97205d in /usr/local/lib/python3.12/dist-packages/torch/lib/libc10.so)
frame #1: <unknown function> + 0x16daa1c (0x71c4c09eea1c in /usr/local/lib/python3.12/dist-packages/torch/lib/libtorch_cpu.so)
frame #2: <unknown function> + 0x6b25497 (0x71c4c5e39497 in /usr/local/lib/python3.12/dist-packages/torch/lib/libtorch_cpu.so)
frame #3: <unknown function> + 0x6b256cf (0x71c4c5e396cf in /usr/local/lib/python3.12/dist-packages/torch/lib/libtorch_cpu.so)
frame #4: <unknown function> + 0x6b25b57 (0x71c4c5e39b57 in /usr/local/lib/python3.12/dist-packages/torch/lib/libtorch_cpu.so)
frame #5: <unknown function> + 0x6a87113 (0x71c4c5d9b113 in /usr/local/lib/python3.12/dist-packages/torch/lib/libtorch_cpu.so)
frame #6: c10d::TCPStore::TCPStore(std::__cxx11::basic_string<char, std::char_traits<char>, std::allocator<char> >, c10d::TCPStoreOptions const&) + 0x41d (0x71c4c5da1ead in /usr/local/lib/python3.12/dist-packages/torch/lib/libtorch_cpu.so)
frame #7: <unknown function> + 0xecbf51 (0x71c4d5872f51 in /usr/local/lib/python3.12/dist-packages/torch/lib/libtorch_python.so)
frame #8: <unknown function> + 0x411060 (0x71c4d4db8060 in /usr/local/lib/python3.12/dist-packages/torch/lib/libtorch_python.so)
frame #9: /usr/bin/python3() [0x581e9f]
frame #10: _PyObject_MakeTpCall + 0x13e (0x54903e in /usr/bin/python3)
...snip...
frame #29: <unknown function> + 0x2a1ca (0x71c59dc851ca in /lib/x86_64-linux-gnu/libc.so.6)
frame #30: __libc_start_main + 0x8b (0x71c59dc8528b in /lib/x86_64-linux-gnu/libc.so.6)
frame #31: _start + 0x25 (0x6576c5 in /usr/bin/python3)

[E406 17:56:55.628826474 socket.cpp:1028] [c10d] The client socket has timed out after 60000ms while trying to connect to (10.100.10.1, 29500).
[E406 17:56:55.628974143 TCPStore.cpp:328] [c10d] TCP client failed to connect/validate to host 10.100.10.1:29500 - timed out (try=1, timeout=60000ms): The client socket has timed out after 60000ms while trying to connect to (10.100.10.1, 29500).
Exception raised from throwTimeoutError at /pytorch/torch/csrc/distributed/c10d/socket.cpp:1030 (most recent call first):
frame #0: c10::Error::Error(c10::SourceLocation, std::__cxx11::basic_string<char, std::char_traits<char>, std::allocator<char> >) + 0x9d (0x71c56a97205d in /usr/local/lib/python3.12/dist-packages/torch/lib/libc10.so)
frame #1: <unknown function> + 0x16daa1c (0x71c4c09eea1c in /usr/local/lib/python3.12/dist-packages/torch/lib/libtorch_cpu.so)
frame #2: <unknown function> + 0x6b25497 (0x71c4c5e39497 in /usr/local/lib/python3.12/dist-packages/torch/lib/libtorch_cpu.so)
frame #3: <unknown function> + 0x6b256cf (0x71c4c5e396cf in /usr/local/lib/python3.12/dist-packages/torch/lib/libtorch_cpu.so)
frame #4: <unknown function> + 0x6b25b57 (0x71c4c5e39b57 in /usr/local/lib/python3.12/dist-packages/torch/lib/libtorch_cpu.so)
frame #5: <unknown function> + 0x6a87113 (0x71c4c5d9b113 in /usr/local/lib/python3.12/dist-packages/torch/lib/libtorch_cpu.so)
frame #6: c10d::TCPStore::TCPStore(std::__cxx11::basic_string<char, std::char_traits<char>, std::allocator<char> >, c10d::TCPStoreOptions const&) + 0x41d (0x71c4c5da1ead in /usr/local/lib/python3.12/dist-packages/torch/lib/libtorch_cpu.so)
frame #7: <unknown function> + 0xecbf51 (0x71c4d5872f51 in /usr/local/lib/python3.12/dist-packages/torch/lib/libtorch_python.so)
frame #8: <unknown function> + 0x411060 (0x71c4d4db8060 in /usr/local/lib/python3.12/dist-packages/torch/lib/libtorch_python.so)
frame #9: /usr/bin/python3() [0x581e9f]
frame #10: _PyObject_MakeTpCall + 0x13e (0x54903e in /usr/bin/python3)
...snip...
frame #30: __libc_start_main + 0x8b (0x71c59dc8528b in /lib/x86_64-linux-gnu/libc.so.6)
frame #31: _start + 0x25 (0x6576c5 in /usr/bin/python3)

Traceback (most recent call last):
  File "/usr/local/lib/python3.12/dist-packages/torch/distributed/elastic/rendezvous/c10d_rendezvous_backend.py", line 156, in _create_tcp_store
    store = TCPStore(
            ^^^^^^^^^
torch.distributed.DistNetworkError: The client socket has timed out after 60000ms while trying to connect to (10.100.10.1, 29500).

The above exception was the direct cause of the following exception:

Traceback (most recent call last):
  File "/usr/local/bin/torchrun", line 6, in <module>
    sys.exit(main())
             ^^^^^^
  File "/usr/local/lib/python3.12/dist-packages/torch/distributed/elastic/multiprocessing/errors/__init__.py", line 362, in wrapper
    return f(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.12/dist-packages/torch/distributed/run.py", line 990, in main
    run(args)
  File "/usr/local/lib/python3.12/dist-packages/torch/distributed/run.py", line 981, in run
    elastic_launch(
  File "/usr/local/lib/python3.12/dist-packages/torch/distributed/launcher/api.py", line 170, in __call__
    return launch_agent(self._config, self._entrypoint, list(args))
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.12/dist-packages/torch/distributed/launcher/api.py", line 284, in launch_agent
    rdzv_handler=rdzv_registry.get_rendezvous_handler(rdzv_parameters),
                 ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.12/dist-packages/torch/distributed/elastic/rendezvous/registry.py", line 96, in get_rendezvous_handler
    return handler_registry.create_handler(params)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.12/dist-packages/torch/distributed/elastic/rendezvous/api.py", line 377, in create_handler
    handler = creator(params)
              ^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.12/dist-packages/torch/distributed/elastic/rendezvous/registry.py", line 46, in _create_c10d_handler
    backend, store = create_backend(params)
                     ^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.12/dist-packages/torch/distributed/elastic/rendezvous/c10d_rendezvous_backend.py", line 254, in create_backend
    store = _create_tcp_store(params)
            ^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.12/dist-packages/torch/distributed/elastic/rendezvous/c10d_rendezvous_backend.py", line 180, in _create_tcp_store
    raise RendezvousConnectionError(
torch.distributed.elastic.rendezvous.api.RendezvousConnectionError: The connection to the C10d store has failed. See inner exception for details.
```

It failed with a timeout! It seems the process cannot connect to the rendezvous point 10.100.10.1:29500.

The log says `throwTimeoutError` was raised at [this line](https://github.com/pytorch/pytorch/blob/v2.11.0/torch/csrc/distributed/c10d/socket.cpp#L1030). Is it time to file an issue there? No, that’s not what this blog does. It’s time to attach a debugger!

As you may know, `torchrun` is just a Python script that kicks off `torch.distributed.run`.

```log
root@2f3d44d8b0ea:/app# which torchrun
/usr/local/bin/torchrun
root@2f3d44d8b0ea:/app# cat /usr/local/bin/torchrun
#!/usr/bin/python3
import sys
from torch.distributed.run import main
if __name__ == '__main__':
    sys.argv[0] = sys.argv[0].removesuffix('.exe')
    sys.exit(main())
```

To debug it, launch Python under gdb, run the script, and break in while it’s stuck, before the timeout exception is thrown.

```log
root@2f3d44d8b0ea:/app# NCCL_DEBUG=INFO LOGLEVEL=DEBUG STEPS=30 gdb -q /usr/bin/python3
Reading symbols from /usr/bin/python3...
(No debugging symbols found in /usr/bin/python3)
(gdb) set pagination off
(gdb) r /usr/local/bin/torchrun --nnodes=1 --nproc_per_node=1 --node_rank=0 --rdzv_id=test_job --rdzv_backend=c10d --rdzv_endpoint=10.100.10.1:29500 train.py
Starting program: /usr/bin/python3 /usr/local/bin/torchrun --nnodes=1 --nproc_per_node=1 --node_rank=0 --rdzv_id=test_job --rdzv_backend=c10d --rdzv_endpoint=10.100.10.1:29500 train.py
warning: Error disabling address space randomization: Operation not permitted
[Thread debugging using libthread_db enabled]
Using host libthread_db library "/lib/x86_64-linux-gnu/libthread_db.so.1".
[New Thread 0x7733f25ff6c0 (LWP 1528)]
...snip...
[New Thread 0x7733e55e56c0 (LWP 1554)]
I0406 18:11:48.680000 1525 torch/distributed/run.py:735] Using nproc_per_node=1.
[New Thread 0x7733dfe2a6c0 (LWP 1555)]
I0406 18:11:48.788000 1525 torch/distributed/launcher/api.py:131] Using default numa options = None
I0406 18:11:48.789000 1525 torch/distributed/launcher/api.py:224] Starting elastic_operator with launch configs:
I0406 18:11:48.789000 1525 torch/distributed/launcher/api.py:224]   entrypoint               : train.py
I0406 18:11:48.789000 1525 torch/distributed/launcher/api.py:224]   min_nodes                : 1
I0406 18:11:48.789000 1525 torch/distributed/launcher/api.py:224]   max_nodes                : 1
I0406 18:11:48.789000 1525 torch/distributed/launcher/api.py:224]   nproc_per_node           : 1
I0406 18:11:48.789000 1525 torch/distributed/launcher/api.py:224]   run_id                   : test_job
I0406 18:11:48.789000 1525 torch/distributed/launcher/api.py:224]   rdzv_backend             : c10d
I0406 18:11:48.789000 1525 torch/distributed/launcher/api.py:224]   rdzv_endpoint            : 10.100.10.1:29500
I0406 18:11:48.789000 1525 torch/distributed/launcher/api.py:224]   rdzv_configs             : {'timeout': 900}
I0406 18:11:48.789000 1525 torch/distributed/launcher/api.py:224]   max_restarts             : 0
I0406 18:11:48.789000 1525 torch/distributed/launcher/api.py:224]   monitor_interval         : 0.1
I0406 18:11:48.789000 1525 torch/distributed/launcher/api.py:224]   log_dir                  : /tmp/torchelastic_iju9cg3h
I0406 18:11:48.789000 1525 torch/distributed/launcher/api.py:224]   metrics_cfg              : {}
I0406 18:11:48.789000 1525 torch/distributed/launcher/api.py:224]   event_log_handler        : null
I0406 18:11:48.789000 1525 torch/distributed/launcher/api.py:224]   numa_options             : None
I0406 18:11:48.789000 1525 torch/distributed/launcher/api.py:224]   signals_to_handle        : SIGTERM,SIGINT,SIGHUP,SIGQUIT
I0406 18:11:48.789000 1525 torch/distributed/launcher/api.py:224]   duplicate_stdout_filters : []
I0406 18:11:48.789000 1525 torch/distributed/launcher/api.py:224]   duplicate_stderr_filters : []
I0406 18:11:48.789000 1525 torch/distributed/launcher/api.py:224]
^C
Thread 1 "pt_elastic" received signal SIGINT, Interrupt.
0x0000773541c2dadf in __GI___clock_nanosleep (clock_id=clock_id@entry=0, flags=flags@entry=0, req=0x7ffc61d9df30, rem=0x0)
    at ../sysdeps/unix/sysv/linux/clock_nanosleep.c:78
warning: 78     ../sysdeps/unix/sysv/linux/clock_nanosleep.c: No such file or directory
(gdb) bt
#0  0x0000773541c2dadf in __GI___clock_nanosleep (clock_id=clock_id@entry=0, flags=flags@entry=0, req=0x7ffc61d9df30, rem=0x0)
    at ../sysdeps/unix/sysv/linux/clock_nanosleep.c:78
#1  0x0000773541c3aa27 in __GI___nanosleep (req=<optimized out>, rem=<optimized out>) at ../sysdeps/unix/sysv/linux/nanosleep.c:25
#2  0x0000773469c3849a in c10d::detail::(anonymous namespace)::SocketConnectOp::tryConnect(int) ()
   from /usr/local/lib/python3.12/dist-packages/torch/lib/libtorch_cpu.so
#3  0x0000773469c396cf in c10d::detail::(anonymous namespace)::SocketConnectOp::run() [clone .constprop.0] ()
   from /usr/local/lib/python3.12/dist-packages/torch/lib/libtorch_cpu.so
#4  0x0000773469c39b57 in c10d::detail::Socket::connect(std::__cxx11::basic_string<char, std::char_traits<char>, std::allocator<char> > const&, unsigned short, c10d::detail::SocketOptions const&) () from /usr/local/lib/python3.12/dist-packages/torch/lib/libtorch_cpu.so
#5  0x0000773469b9b113 in c10d::detail::TCPClient::connect(c10d::detail::SocketAddress const&, c10d::TCPStoreOptions const&, std::shared_ptr<c10d::Backoff>)
    () from /usr/local/lib/python3.12/dist-packages/torch/lib/libtorch_cpu.so
#6  0x0000773469ba1ead in c10d::TCPStore::TCPStore(std::__cxx11::basic_string<char, std::char_traits<char>, std::allocator<char> >, c10d::TCPStoreOptions const&) () from /usr/local/lib/python3.12/dist-packages/torch/lib/libtorch_cpu.so
#7  0x0000773479672f51 in pybind11::cpp_function::initialize<pybind11::detail::initimpl::factory<torch::distributed::c10d::(anonymous namespace)::c10d_init(_object*, _object*)::{lambda(std::__cxx11::basic_string<char, std::char_traits<char>, std::allocator<char> > const&, unsigned short, std::optional<int>, bool, std::chrono::duration<long, std::ratio<1l, 1000l> >, bool, bool, std::optional<int>, bool)#1}, pybind11::detail::void_type (*)(), c10::intrusive_ptr<c10d::TCPStore, c10::detail::intrusive_target_default_null_type<c10d::TCPStore> > (std::__cxx11::basic_string<char, std::char_traits<char>, std::allocator<char> > const&, unsigned short, std::optional<int>, bool, std::chrono::duration<long, std::ratio<1l, 1000l> >, bool, bool, std::optional<int>, bool), pybind11::detail::void_type ()>::execute<pybind11::class_<c10d::TCPStore, c10::intrusive_ptr<c10d::TCPStore, c10::detail::intrusive_target_default_null_type<c10d::TCPStore> > >, pybind11::arg, pybind11::arg, pybind11::arg_v, pybind11::arg_v, pybind11::arg_v, pybind11::arg_v, pybind11::arg_v, pybind11::arg_v, pybind11::arg_v, char [24]>(pybind11::class_<c10d::TCPStore, c10::intrusive_ptr<c10d::TCPStore, c10::detail::intrusive_target_default_null_type<c10d::TCPStore> > >&, pybind11::arg const&, pybind11::arg const&, pybind11::arg_v const&, pybind11::arg_v const&, pybind11::arg_v const&, pybind11::arg_v const&, pybind11::arg_v const&, pybind11::arg_v const&, pybind11::arg_v const&, char const (&) [24]) &&::{lambda(pybind11::detail::value_and_holder&, std::__cxx11::basic_string<char, std::char_traits<char>, std::allocator<char> > const&, unsigned short, std::optional<int>, bool, std::chrono::duration<long, std::ratio<1l, 1000l> >, bool, bool, std::optional<int>, bool)#1}, void, pybind11::detail::value_and_holder&, std::__cxx11::basic_string<char, std::char_traits<char>, std::allocator<char> > const&, unsigned short, std::optional<int>, bool, std::chrono::duration<long, std::ratio<1l, 1000l> >, bool, bool, std::optional<int>, bool, pybind11::name, pybind11::is_method, pybind11::sibling, pybind11::detail::is_new_style_constructor, pybind11::arg, pybind11::arg, pybind11::arg_v, pybind11::arg_v, pybind11::arg_v, pybind11::arg_v, pybind11::arg_v, pybind11::arg_v, pybind11::arg_v, char [24]>(pybind11::class_<c10d::TCPStore, c10::intrusive_ptr<c10d::TCPStore, c10::detail::intrusive_target_default_null_type<c10d::TCPStore> > >&&, void (*)(pybind11::detail::value_and_holder&, std::__cxx11::basic_string<char, std::char_traits<char>, std::allocator<char> > const&, unsigned short, std::optional<int>, bool, std::chrono::duration<long, std::ratio<1l, 1000l> >, bool, bool, std::optional<int>, bool), pybind11::name const&, pybind11::is_method const&, pybind11::sibling const&, pybind11::detail::is_new_style_constructor const&, pybind11::arg const&, pybind11::arg const&, pybind11::arg_v const&, pybind11::arg_v const&, pybind11::arg_v const&, pybind11::arg_v const&, pybind11::arg_v const&, pybind11::arg_v const&, pybind11::arg_v const&, char const (&) [24])::{lambda(pybind11::detail::function_call&)#1}::_FUN(pybind11::detail::function_call&) ()
   from /usr/local/lib/python3.12/dist-packages/torch/lib/libtorch_python.so
#8  0x0000773478bb8060 in pybind11::cpp_function::dispatcher(_object*, _object*, _object*) ()
   from /usr/local/lib/python3.12/dist-packages/torch/lib/libtorch_python.so
...snip...
#28 0x00000000006bc3ed in Py_BytesMain ()
#29 0x0000773541b6b1ca in __libc_start_call_main (main=main@entry=0x518930, argc=argc@entry=9, argv=argv@entry=0x7ffc61d9fdf8)
    at ../sysdeps/nptl/libc_start_call_main.h:58
#30 0x0000773541b6b28b in __libc_start_main_impl (main=0x518930, argc=9, argv=0x7ffc61d9fdf8, init=<optimized out>, fini=<optimized out>,
    rtld_fini=<optimized out>, stack_end=0x7ffc61d9fde8) at ../csu/libc-start.c:360
#31 0x00000000006576c5 in _start ()
(gdb)
```

It’s running [`SocketConnectOp::tryConnect`](https://github.com/pytorch/pytorch/blob/v2.11.0/torch/csrc/distributed/c10d/socket.cpp#L824), which calls the standard `connect` function via `tryConnectCore`, and that call is expected to fail. Let’s double-check. You can just set a breakpoint there.

```log
root@2f3d44d8b0ea:/app# NCCL_DEBUG=INFO LOGLEVEL=DEBUG STEPS=30 gdb -q /usr/bin/python3
Reading symbols from /usr/bin/python3...
(No debugging symbols found in /usr/bin/python3)
(gdb) set pagination off
(gdb) b SocketConnectOp::tryConnect
Function "SocketConnectOp::tryConnect" not defined.
Make breakpoint pending on future shared library load? (y or [n]) y
Breakpoint 1 (SocketConnectOp::tryConnect) pending.
(gdb) r /usr/local/bin/torchrun --nnodes=1 --nproc_per_node=1 --node_rank=0 --rdzv_id=test_job --rdzv_backend=c10d --rdzv_endpoint=10.100.10.1:29500 train.py
Starting program: /usr/bin/python3 /usr/local/bin/torchrun --nnodes=1 --nproc_per_node=1 --node_rank=0 --rdzv_id=test_job --rdzv_backend=c10d --rdzv_endpoint=10.100.10.1:29500 train.py
warning: Error disabling address space randomization: Operation not permitted
[Thread debugging using libthread_db enabled]
Using host libthread_db library "/lib/x86_64-linux-gnu/libthread_db.so.1".
[New Thread 0x73fda9dff6c0 (LWP 1786)]
...
[New Thread 0x73fd9cde56c0 (LWP 1812)]
I0406 19:03:29.615000 1783 torch/distributed/run.py:735] Using nproc_per_node=1.
[New Thread 0x73fd977156c0 (LWP 1813)]
I0406 19:03:29.701000 1783 torch/distributed/launcher/api.py:131] Using default numa options = None
I0406 19:03:29.702000 1783 torch/distributed/launcher/api.py:224] Starting elastic_operator with launch configs:
I0406 19:03:29.702000 1783 torch/distributed/launcher/api.py:224]   entrypoint               : train.py
I0406 19:03:29.702000 1783 torch/distributed/launcher/api.py:224]   min_nodes                : 1
I0406 19:03:29.702000 1783 torch/distributed/launcher/api.py:224]   max_nodes                : 1
I0406 19:03:29.702000 1783 torch/distributed/launcher/api.py:224]   nproc_per_node           : 1
I0406 19:03:29.702000 1783 torch/distributed/launcher/api.py:224]   run_id                   : test_job
I0406 19:03:29.702000 1783 torch/distributed/launcher/api.py:224]   rdzv_backend             : c10d
I0406 19:03:29.702000 1783 torch/distributed/launcher/api.py:224]   rdzv_endpoint            : 10.100.10.1:29500
I0406 19:03:29.702000 1783 torch/distributed/launcher/api.py:224]   rdzv_configs             : {'timeout': 900}
I0406 19:03:29.702000 1783 torch/distributed/launcher/api.py:224]   max_restarts             : 0
I0406 19:03:29.702000 1783 torch/distributed/launcher/api.py:224]   monitor_interval         : 0.1
I0406 19:03:29.702000 1783 torch/distributed/launcher/api.py:224]   log_dir                  : /tmp/torchelastic_wofb5igg
I0406 19:03:29.702000 1783 torch/distributed/launcher/api.py:224]   metrics_cfg              : {}
I0406 19:03:29.702000 1783 torch/distributed/launcher/api.py:224]   event_log_handler        : null
I0406 19:03:29.702000 1783 torch/distributed/launcher/api.py:224]   numa_options             : None
I0406 19:03:29.702000 1783 torch/distributed/launcher/api.py:224]   signals_to_handle        : SIGTERM,SIGINT,SIGHUP,SIGQUIT
I0406 19:03:29.702000 1783 torch/distributed/launcher/api.py:224]   duplicate_stdout_filters : []
I0406 19:03:29.702000 1783 torch/distributed/launcher/api.py:224]   duplicate_stderr_filters : []
I0406 19:03:29.702000 1783 torch/distributed/launcher/api.py:224]

Thread 1 "pt_elastic" hit Breakpoint 1, 0x000073fe214380b0 in c10d::detail::(anonymous namespace)::SocketConnectOp::tryConnect(int) ()
   from /usr/local/lib/python3.12/dist-packages/torch/lib/libtorch_cpu.so
(gdb) b connect
Breakpoint 2 at 0x73fe2139b070 (17 locations)
(gdb) c
Continuing.

Thread 1 "pt_elastic" hit Breakpoint 2.17, __libc_connect (fd=24, addr=addr@entry=..., len=len@entry=16) at ../sysdeps/unix/sysv/linux/connect.c:24
warning: 24     ../sysdeps/unix/sysv/linux/connect.c: No such file or directory
(gdb) bt
#0  __libc_connect (fd=24, addr=addr@entry=..., len=len@entry=16) at ../sysdeps/unix/sysv/linux/connect.c:24
#1  0x000073fef94f7d04 in reopen (statp=statp@entry=0x73fef95b8680 <_res>, terrno=terrno@entry=0x7ffd0220c058, ns=ns@entry=0) at ./resolv/res_send.c:856
#2  0x000073fef94f88ee in send_dg (ansp2_malloced=<optimized out>, resplen2=<optimized out>, anssizp2=<optimized out>, ansp2=<optimized out>,
    anscp=<optimized out>, gotsomewhere=<synthetic pointer>, v_circuit=<synthetic pointer>, ns=<optimized out>, terrno=0x7ffd0220c058,
    anssizp=0x7ffd0220c190, ansp=0x7ffd0220c048, buflen2=<optimized out>, buf2=<optimized out>, buflen=<optimized out>, buf=<optimized out>,
    statp=<optimized out>) at ./resolv/res_send.c:957
#3  __GI___res_context_send (ctx=ctx@entry=0x187d3590, buf=buf@entry=0x7ffd0220c240 "\267\252\001", buflen=<optimized out>, buf2=buf2@entry=0x0,
    buflen2=buflen2@entry=0, ans=<optimized out>, ans@entry=0x7ffd0220ca80 "", anssiz=<optimized out>, ansp=<optimized out>, ansp2=<optimized out>,
    nansp2=<optimized out>, resplen2=<optimized out>, ansp2_malloced=<optimized out>) at ./resolv/res_send.c:373
#4  0x000073fef94f6217 in __GI___res_context_query (ctx=ctx@entry=0x187d3590, name=name@entry=0x7ffd0220ce80 "1.10.100.10.in-addr.arpa", class=class@entry=1,
    type=type@entry=12, answer=answer@entry=0x7ffd0220ca80 "", anslen=anslen@entry=1024, answerp=0x7ffd0220c718, answerp2=0x0, nanswerp2=0x0, resplen2=0x0,
    answerp2_malloced=0x0) at ./resolv/res_query.c:218
#5  0x000073fef94ef564 in __GI__nss_dns_gethostbyaddr2_r (addr=<optimized out>, len=<optimized out>, af=<optimized out>, result=0x7ffd0220d460,
    buffer=<optimized out>, buflen=<optimized out>, errnop=0x73fef93ac278, h_errnop=0x7ffd0220d444, ttlp=0x0) at nss_dns/dns-host.c:576
#6  0x000073fef94efa59 in __GI__nss_dns_gethostbyaddr_r (addr=<optimized out>, len=<optimized out>, af=<optimized out>, result=<optimized out>,
    buffer=<optimized out>, buflen=<optimized out>, errnop=0x73fef93ac278, h_errnop=0x7ffd0220d444) at nss_dns/dns-host.c:630
#7  0x000073fef9509f2c in __gethostbyaddr_r (addr=addr@entry=0x187b0d58, len=len@entry=16, type=type@entry=10, resbuf=resbuf@entry=0x7ffd0220d460,
    buffer=<optimized out>, buflen=<optimized out>, result=<optimized out>, h_errnop=<optimized out>) at ../nss/getXXbyYY_r.c:273
#8  0x000073fef950ba86 in gni_host_inet_name (addrlen=<optimized out>, flags=<optimized out>, hostlen=1025, host=0x7ffd0220df60 "", sa=0x187b0d50,
    tmpbuf=0x7ffd0220d4a0) at ./nss/getnameinfo.c:243
#9  gni_host_inet (addrlen=<optimized out>, flags=<optimized out>, hostlen=1025, host=0x7ffd0220df60 "", sa=0x187b0d50, tmpbuf=0x7ffd0220d4a0)
    at ./nss/getnameinfo.c:382
#10 gni_host (addrlen=<optimized out>, flags=<optimized out>, hostlen=<optimized out>, host=<optimized out>, sa=<optimized out>, tmpbuf=<optimized out>)
    at ./nss/getnameinfo.c:424
#11 __GI_getnameinfo (sa=0x187b0d50, addrlen=<optimized out>, host=0x7ffd0220df60 "", hostlen=1025, serv=0x7ffd0220dd70 "", servlen=32, flags=<optimized out>)
    at ./nss/getnameinfo.c:538
#12 0x000073fe214368bb in c10d::detail::formatSockAddr[abi:cxx11](sockaddr const*, unsigned int) ()
   from /usr/local/lib/python3.12/dist-packages/torch/lib/libtorch_cpu.so
#13 0x000073fe2143c03d in void fmt::v12::detail::value<fmt::v12::context>::format_custom<addrinfo>(void*, fmt::v12::parse_context<char>&, fmt::v12::context&)
    () from /usr/local/lib/python3.12/dist-packages/torch/lib/libtorch_cpu.so
#14 0x000073fe1c2fba18 in fmt::v12::vformat[abi:cxx11](fmt::v12::basic_string_view<char>, fmt::v12::basic_format_args<fmt::v12::context>) ()
   from /usr/local/lib/python3.12/dist-packages/torch/lib/libtorch_cpu.so
#15 0x000073fe21436748 in c10d::detail::SocketImpl::SocketImpl(int, addrinfo const&) () from /usr/local/lib/python3.12/dist-packages/torch/lib/libtorch_cpu.so
#16 0x000073fe21438990 in c10d::detail::(anonymous namespace)::SocketConnectOp::tryConnect(int) ()
...
(gdb) c
Continuing.

Thread 1 "pt_elastic" hit Breakpoint 2.17, __libc_connect (fd=23, addr=..., len=28) at ../sysdeps/unix/sysv/linux/connect.c:24
24      in ../sysdeps/unix/sysv/linux/connect.c
(gdb) bt
#0  __libc_connect (fd=23, addr=..., len=28) at ../sysdeps/unix/sysv/linux/connect.c:24
#1  0x000073fe214389e3 in c10d::detail::(anonymous namespace)::SocketConnectOp::tryConnect(int) ()
   from /usr/local/lib/python3.12/dist-packages/torch/lib/libtorch_cpu.so
#2  0x000073fe214396cf in c10d::detail::(anonymous namespace)::SocketConnectOp::run() [clone .constprop.0] ()
   from /usr/local/lib/python3.12/dist-packages/torch/lib/libtorch_cpu.so
#3  0x000073fe21439b57 in c10d::detail::Socket::connect(std::__cxx11::basic_string<char, std::char_traits<char>, std::allocator<char> > const&, unsigned short, c10d::detail::SocketOptions const&) () from /usr/local/lib/python3.12/dist-packages/torch/lib/libtorch_cpu.so
#4  0x000073fe2139b113 in c10d::detail::TCPClient::connect(c10d::detail::SocketAddress const&, c10d::TCPStoreOptions const&, std::shared_ptr<c10d::Backoff>)
    () from /usr/local/lib/python3.12/dist-packages/torch/lib/libtorch_cpu.so
#5  0x000073fe213a1ead in c10d::TCPStore::TCPStore(std::__cxx11::basic_string<char, std::char_traits<char>, std::allocator<char> >, c10d::TCPStoreOptions const&) () from /usr/local/lib/python3.12/dist-packages/torch/lib/libtorch_cpu.so
#6  0x000073fe30e72f51 in pybind11::cpp_function::initialize<pybind11::detail::initimpl::factory<torch::distributed::c10d::(anonymous namespace)::c10d_init(_object*, _object*)::{lambda(std::__cxx11::basic_string<char, std::char_traits<char>, std::allocator<char> > const&, unsigned short, std::optional<int>, bool, std::chrono::duration<long, std::ratio<1l, 1000l> >, bool, bool, std::optional<int>, bool)#1}, pybind11::detail::void_type (*)(), c10::intrusive_ptr<c10d::TCPStore, c10::detail::intrusive_target_default_null_type<c10d::TCPStore> > (std::__cxx11::basic_string<char, std::char_traits<char>, std::allocator<char> > const&, unsigned short, std::optional<int>, bool, std::chrono::duration<long, std::ratio<1l, 1000l> >, bool, bool, std::optional<int>, bool), pybind11::detail::void_type ()>::execute<pybind11::class_<c10d::TCPStore, c10::intrusive_ptr<c10d::TCPStore, c10::detail::intrusive_target_default_null_type<c10d::TCPStore> > >, pybind11::arg, pybind11::arg, pybind11::arg_v, pybind11::arg_v, pybind11::arg_v, pybind11::arg_v, pybind11::arg_v, pybind11::arg_v, pybind11::arg_v, char [24]>(pybind11::class_<c10d::TCPStore, c10::intrusive_ptr<c10d::TCPStore, c10::detail::intrusive_target_default_null_type<c10d::TCPStore> > >&, pybind11::arg const&, pybind11::arg const&, pybind11::arg_v const&, pybind11::arg_v const&, pybind11::arg_v const&, pybind11::arg_v const&, pybind11::arg_v const&, pybind11::arg_v const&, pybind11::arg_v const&, char const (&) [24]) &&::{lambda(pybind11::detail::value_and_holder&, std::__cxx11::basic_string<char, std::char_traits<char>, std::allocator<char> > const&, unsigned short, std::optional<int>, bool, std::chrono::duration<long, std::ratio<1l, 1000l> >, bool, bool, std::optional<int>, bool)#1}, void, pybind11::detail::value_and_holder&, std::__cxx11::basic_string<char, std::char_traits<char>, std::allocator<char> > const&, unsigned short, std::optional<int>, bool, std::chrono::duration<long, std::ratio<1l, 1000l> >, bool, bool, std::optional<int>, bool, pybind11::name, pybind11::is_method, pybind11::sibling, pybind11::detail::is_new_style_constructor, pybind11::arg, pybind11::arg, pybind11::arg_v, pybind11::arg_v, pybind11::arg_v, pybind11::arg_v, pybind11::arg_v, pybind11::arg_v, pybind11::arg_v, char [24]>(pybind11::class_<c10d::TCPStore, c10::intrusive_ptr<c10d::TCPStore, c10::detail::intrusive_target_default_null_type<c10d::TCPStore> > >&&, void (*)(pybind11::detail::value_and_holder&, std::__cxx11::basic_string<char, std::char_traits<char>, std::allocator<char> > const&, unsigned short, std::optional<int>, bool, std::chrono::duration<long, std::ratio<1l, 1000l> >, bool, bool, std::optional<int>, bool), pybind11::name const&, pybind11::is_method const&, pybind11::sibling const&, pybind11::detail::is_new_style_constructor const&, pybind11::arg const&, pybind11::arg const&, pybind11::arg_v const&, pybind11::arg_v const&, pybind11::arg_v const&, pybind11::arg_v const&, pybind11::arg_v const&, pybind11::arg_v const&, pybind11::arg_v const&, char const (&) [24])::{lambda(pybind11::detail::function_call&)#1}::_FUN(pybind11::detail::function_call&) ()
   from /usr/local/lib/python3.12/dist-packages/torch/lib/libtorch_python.so
#7  0x000073fe303b8060 in pybind11::cpp_function::dispatcher(_object*, _object*, _object*) ()
   from /usr/local/lib/python3.12/dist-packages/torch/lib/libtorch_python.so
...
(gdb) b *0x000073fe214389e3
Breakpoint 3 at 0x73fe214389e3
(gdb) c
Continuing.

Thread 1 "pt_elastic" hit Breakpoint 3, 0x000073fe214389e3 in c10d::detail::(anonymous namespace)::SocketConnectOp::tryConnect(int) ()
   from /usr/local/lib/python3.12/dist-packages/torch/lib/libtorch_cpu.so
(gdb) p $eax
$1 = -1
(gdb) p (int)errno
$2 = 115

```

It hit twice. The first one comes from `getnameinfo`, which looks unrelated, so we skip it. The second one was called from `SocketConnectOp::tryConnect` and returned -1 with `errno` = 115 (`EINPROGRESS`), which is what a non-blocking `connect` reports when the attempt cannot complete immediately. This is the call we're interested in, and what we want to see are its parameters. Here’s the disassembly of where we are, immediately after the call to `connect`.

```log
(gdb) disas 0x000073fe214380b0
...
   0x000073fe214389ca <+2330>:  mov    -0x5a8(%rbp),%rax
   0x000073fe214389d1 <+2337>:  mov    0x10(%rax),%edx
   0x000073fe214389d4 <+2340>:  mov    0x18(%rax),%rsi
   0x000073fe214389d8 <+2344>:  mov    0x50(%rbx),%rax
   0x000073fe214389dc <+2348>:  mov    (%rax),%edi
   0x000073fe214389de <+2350>:  call   0x73fe1b9afb00 <connect@plt>
=> 0x000073fe214389e3 <+2355>:  test   %eax,%eax
```

What we want is the 2nd parameter, `const struct sockaddr *addr`, which is passed via `$rsi` in the [System V ABI](https://wiki.osdev.org/System_V_ABI). Since `$rsi` has already been clobbered by the time we stop, we need to restore its value from the stack.

```log
(gdb) x/1g $rbp-0x5a8
0x7ffd0220e7c8: 0x00000000187b0d20
(gdb) x/1g 0x00000000187b0d20+0x18
0x187b0d38:     0x00000000187b0d50
(gdb) x/28xb 0x00000000187b0d50
0x187b0d50:     0x0a    0x00    0x73    0x3c    0x00    0x00    0x00    0x00
0x187b0d58:     0x00    0x00    0x00    0x00    0x00    0x00    0x00    0x00
0x187b0d60:     0x00    0x00    0xff    0xff    0x0a    0x64    0x0a    0x01
0x187b0d68:     0x00    0x00    0x00    0x00
```

How do we read this? The address family is `0x0a 0x00`, which is 10 in host byte order, i.e. `AF_INET6`. The address bytes ending in `0xff 0xff 0x0a 0x64 0x0a 0x01` form an IPv4-mapped IPv6 address, `::ffff:10.100.10.1`, and the port is `0x73 0x3c` in network byte order, which is 0x733c = 29500. So the script is simply trying to connect to the rendezvous point we specified, 10.100.10.1:29500, and the connection fails. For that to succeed, somebody must be listening on the endpoint. Let's find out whether anybody is.

```
root@2f3d44d8b0ea:/app# ss -paln | grep 29500
root@2f3d44d8b0ea:/app#
```

Okay, nobody is listening on the endpoint; that's why `connect` failed.
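As an aside, if you don't want to decode a `sockaddr_in6` by hand, a few lines of Python will do it. This is just a sanity check of the arithmetic above; the hex string is copied from the gdb dump.

```python
import ipaddress
import socket
import struct

# struct sockaddr_in6 bytes copied from the gdb dump (one line per 8-byte row)
raw = bytes.fromhex(
    "0a00733c00000000"   # sin6_family, sin6_port, sin6_flowinfo
    "0000000000000000"   # sin6_addr[0:8]
    "0000ffff0a640a01"   # sin6_addr[8:16]
    "00000000"           # sin6_scope_id
)

family = struct.unpack_from("<H", raw, 0)[0]  # host (little-endian) byte order
port   = struct.unpack_from("!H", raw, 2)[0]  # network (big-endian) byte order
addr   = ipaddress.IPv6Address(raw[8:24])     # the 16-byte sin6_addr field

print(family == socket.AF_INET6)  # True (10 == AF_INET6 on Linux)
print(port)                       # 29500
print(addr.ipv4_mapped or addr)   # 10.100.10.1
```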

The next thing to do is to look at the working case. We know this works with 127.0.0.1, so let’s see whether the endpoint is listened on in that case.

```log
root@2f3d44d8b0ea:/app# NCCL_DEBUG=INFO LOGLEVEL=DEBUG STEPS=30 gdb -q /usr/bin/python3
Reading symbols from /usr/bin/python3...
(No debugging symbols found in /usr/bin/python3)
(gdb) set pagination off
(gdb) b SocketConnectOp::tryConnect
Function "SocketConnectOp::tryConnect" not defined.
Make breakpoint pending on future shared library load? (y or [n]) y
Breakpoint 1 (SocketConnectOp::tryConnect) pending.
(gdb) r /usr/local/bin/torchrun --nnodes=1 --nproc_per_node=1 --node_rank=0 --rdzv_id=test_job --rdzv_backend=c10d --rdzv_endpoint=127.0.0.1:29500 train.py
Starting program: /usr/bin/python3 /usr/local/bin/torchrun --nnodes=1 --nproc_per_node=1 --node_rank=0 --rdzv_id=test_job --rdzv_backend=c10d --rdzv_endpoint=127.0.0.1:29500 train.py
warning: Error disabling address space randomization: Operation not permitted
[Thread debugging using libthread_db enabled]
Using host libthread_db library "/lib/x86_64-linux-gnu/libthread_db.so.1".
[New Thread 0x7b4c309ff6c0 (LWP 1830)]
...
[New Thread 0x7b4c239e56c0 (LWP 1856)]
I0406 19:22:02.314000 1827 torch/distributed/run.py:735] Using nproc_per_node=1.
[New Thread 0x7b4c1e32a6c0 (LWP 1857)]
I0406 19:22:02.402000 1827 torch/distributed/launcher/api.py:131] Using default numa options = None
I0406 19:22:02.403000 1827 torch/distributed/launcher/api.py:224] Starting elastic_operator with launch configs:
I0406 19:22:02.403000 1827 torch/distributed/launcher/api.py:224]   entrypoint               : train.py
I0406 19:22:02.403000 1827 torch/distributed/launcher/api.py:224]   min_nodes                : 1
I0406 19:22:02.403000 1827 torch/distributed/launcher/api.py:224]   max_nodes                : 1
I0406 19:22:02.403000 1827 torch/distributed/launcher/api.py:224]   nproc_per_node           : 1
I0406 19:22:02.403000 1827 torch/distributed/launcher/api.py:224]   run_id                   : test_job
I0406 19:22:02.403000 1827 torch/distributed/launcher/api.py:224]   rdzv_backend             : c10d
I0406 19:22:02.403000 1827 torch/distributed/launcher/api.py:224]   rdzv_endpoint            : 127.0.0.1:29500
I0406 19:22:02.403000 1827 torch/distributed/launcher/api.py:224]   rdzv_configs             : {'timeout': 900}
I0406 19:22:02.403000 1827 torch/distributed/launcher/api.py:224]   max_restarts             : 0
I0406 19:22:02.403000 1827 torch/distributed/launcher/api.py:224]   monitor_interval         : 0.1
I0406 19:22:02.403000 1827 torch/distributed/launcher/api.py:224]   log_dir                  : /tmp/torchelastic_2oqcd4j3
I0406 19:22:02.403000 1827 torch/distributed/launcher/api.py:224]   metrics_cfg              : {}
I0406 19:22:02.403000 1827 torch/distributed/launcher/api.py:224]   event_log_handler        : null
I0406 19:22:02.403000 1827 torch/distributed/launcher/api.py:224]   numa_options             : None
I0406 19:22:02.403000 1827 torch/distributed/launcher/api.py:224]   signals_to_handle        : SIGTERM,SIGINT,SIGHUP,SIGQUIT
I0406 19:22:02.403000 1827 torch/distributed/launcher/api.py:224]   duplicate_stdout_filters : []
I0406 19:22:02.403000 1827 torch/distributed/launcher/api.py:224]   duplicate_stderr_filters : []
I0406 19:22:02.403000 1827 torch/distributed/launcher/api.py:224]
[New Thread 0x7b4c10dff6c0 (LWP 1858)]

Thread 1 "pt_elastic" hit Breakpoint 1, 0x00007b4ca80380b0 in c10d::detail::(anonymous namespace)::SocketConnectOp::tryConnect(int) ()
   from /usr/local/lib/python3.12/dist-packages/torch/lib/libtorch_cpu.so
(gdb)
```

And look at this: the endpoint is there.

```
root@2f3d44d8b0ea:/app# ss -paln | grep 29500
tcp   LISTEN 0      4096                              *:29500                *:*    users:(("pt_elastic",pid=1827,fd=30))
```

So the problem is that the script doesn’t start listening on the endpoint when the address is 10.100.10.1, while it does when it is 127.0.0.1.

Which function should we set a breakpoint on? Probably `listen` or `bind`? Well, in this case I did some homework for you already, and it turned out `bind` was the one. So let’s do it.

```log
(gdb) del
Delete all breakpoints, watchpoints, tracepoints, and catchpoints? (y or n) y
(gdb) b bind
Breakpoint 2 at 0x7b4caa493780 (9 locations)
(gdb) r /usr/local/bin/torchrun --nnodes=1 --nproc_per_node=1 --node_rank=0 --rdzv_id=test_job --rdzv_backend=c10d --rdzv_endpoint=127.0.0.1:29500 train.py
The program being debugged has been started already.
Start it from the beginning? (y or n) y
Starting program: /usr/bin/python3 /usr/local/bin/torchrun --nnodes=1 --nproc_per_node=1 --node_rank=0 --rdzv_id=test_job --rdzv_backend=c10d --rdzv_endpoint=127.0.0.1:29500 train.py
warning: Error disabling address space randomization: Operation not permitted
[Thread debugging using libthread_db enabled]
Using host libthread_db library "/lib/x86_64-linux-gnu/libthread_db.so.1".
[New Thread 0x70b1e8fff6c0 (LWP 1863)]
...
[New Thread 0x70b1dbfe56c0 (LWP 1889)]

Thread 1 "python3" hit Breakpoint 2.8, 0x000070b26fdda030 in torch::jit::slot_dict_impl<torch::jit::detail::ParameterPolicy>::bind(pybind11::module_ const&, char const*) () from /usr/local/lib/python3.12/dist-packages/torch/lib/libtorch_python.so
(gdb) c
Continuing.

Thread 1 "python3" hit Breakpoint 2.7, 0x000070b26fdd9820 in torch::jit::slot_dict_impl<torch::jit::detail::BufferPolicy>::bind(pybind11::module_ const&, char const*) () from /usr/local/lib/python3.12/dist-packages/torch/lib/libtorch_python.so
(gdb) c
Continuing.

Thread 1 "python3" hit Breakpoint 2.6, 0x000070b26fdd9010 in torch::jit::slot_dict_impl<torch::jit::detail::ModulePolicy>::bind(pybind11::module_ const&, char const*) () from /usr/local/lib/python3.12/dist-packages/torch/lib/libtorch_python.so
(gdb) c
Continuing.
I0406 19:27:40.820000 1862 torch/distributed/run.py:735] Using nproc_per_node=1.
[New Thread 0x70b1d682a6c0 (LWP 1890)]

Thread 1 "pt_elastic" hit Breakpoint 2.9, __GI_bind () at ../sysdeps/unix/syscall-template.S:120
warning: 120    ../sysdeps/unix/syscall-template.S: No such file or directory
(gdb) bt
#0  __GI_bind () at ../sysdeps/unix/syscall-template.S:120
#1  0x000070b300eca4eb in ?? () from /lib/x86_64-linux-gnu/libcuda.so.1
#2  0x000070b300f18602 in ?? () from /lib/x86_64-linux-gnu/libcuda.so.1
#3  0x000070b337a34a76 in ?? () from /usr/local/lib/python3.12/dist-packages/torch/lib/../../nvidia/cuda_runtime/lib/libcudart.so.12
#4  0x000070b337a39638 in ?? () from /usr/local/lib/python3.12/dist-packages/torch/lib/../../nvidia/cuda_runtime/lib/libcudart.so.12
#5  0x000070b338605ed3 in __pthread_once_slow (once_control=0x70b337cb2220, init_routine=0x70b337a395f0) at ./nptl/pthread_once.c:116
#6  0x000070b337a88ed9 in ?? () from /usr/local/lib/python3.12/dist-packages/torch/lib/../../nvidia/cuda_runtime/lib/libcudart.so.12
#7  0x000070b337a356ff in ?? () from /usr/local/lib/python3.12/dist-packages/torch/lib/../../nvidia/cuda_runtime/lib/libcudart.so.12
#8  0x000070b337a4d99a in cudaGetDeviceCount () from /usr/local/lib/python3.12/dist-packages/torch/lib/../../nvidia/cuda_runtime/lib/libcudart.so.12
#9  0x000070b3379ad842 in c10::cuda::device_count() () from /usr/local/lib/python3.12/dist-packages/torch/lib/libc10_cuda.so
#10 0x000070b26ffed402 in THCPModule_getDeviceCount_wrap(_object*, _object*) () from /usr/local/lib/python3.12/dist-packages/torch/lib/libtorch_python.so
#11 0x0000000000581a8a in ?? ()
...
#25 0x00000000006bc3ed in Py_BytesMain ()
#26 0x000070b33858e1ca in __libc_start_call_main (main=main@entry=0x518930, argc=argc@entry=9, argv=argv@entry=0x7ffd9cdf2e68)
    at ../sysdeps/nptl/libc_start_call_main.h:58
#27 0x000070b33858e28b in __libc_start_main_impl (main=0x518930, argc=9, argv=0x7ffd9cdf2e68, init=<optimized out>, fini=<optimized out>,
    rtld_fini=<optimized out>, stack_end=0x7ffd9cdf2e58) at ../csu/libc-start.c:360
#28 0x00000000006576c5 in _start ()
(gdb) c
Continuing.
I0406 19:27:50.442000 1862 torch/distributed/launcher/api.py:131] Using default numa options = None
I0406 19:27:50.443000 1862 torch/distributed/launcher/api.py:224] Starting elastic_operator with launch configs:
I0406 19:27:50.443000 1862 torch/distributed/launcher/api.py:224]   entrypoint               : train.py
I0406 19:27:50.443000 1862 torch/distributed/launcher/api.py:224]   min_nodes                : 1
I0406 19:27:50.443000 1862 torch/distributed/launcher/api.py:224]   max_nodes                : 1
I0406 19:27:50.443000 1862 torch/distributed/launcher/api.py:224]   nproc_per_node           : 1
I0406 19:27:50.443000 1862 torch/distributed/launcher/api.py:224]   run_id                   : test_job
I0406 19:27:50.443000 1862 torch/distributed/launcher/api.py:224]   rdzv_backend             : c10d
I0406 19:27:50.443000 1862 torch/distributed/launcher/api.py:224]   rdzv_endpoint            : 127.0.0.1:29500
I0406 19:27:50.443000 1862 torch/distributed/launcher/api.py:224]   rdzv_configs             : {'timeout': 900}
I0406 19:27:50.443000 1862 torch/distributed/launcher/api.py:224]   max_restarts             : 0
I0406 19:27:50.443000 1862 torch/distributed/launcher/api.py:224]   monitor_interval         : 0.1
I0406 19:27:50.443000 1862 torch/distributed/launcher/api.py:224]   log_dir                  : /tmp/torchelastic_lqlqv2qz
I0406 19:27:50.443000 1862 torch/distributed/launcher/api.py:224]   metrics_cfg              : {}
I0406 19:27:50.443000 1862 torch/distributed/launcher/api.py:224]   event_log_handler        : null
I0406 19:27:50.443000 1862 torch/distributed/launcher/api.py:224]   numa_options             : None
I0406 19:27:50.443000 1862 torch/distributed/launcher/api.py:224]   signals_to_handle        : SIGTERM,SIGINT,SIGHUP,SIGQUIT
I0406 19:27:50.443000 1862 torch/distributed/launcher/api.py:224]   duplicate_stdout_filters : []
I0406 19:27:50.443000 1862 torch/distributed/launcher/api.py:224]   duplicate_stderr_filters : []
I0406 19:27:50.443000 1862 torch/distributed/launcher/api.py:224]

Thread 1 "pt_elastic" hit Breakpoint 2.9, __GI_bind () at ../sysdeps/unix/syscall-template.S:120
120     in ../sysdeps/unix/syscall-template.S

(gdb) bt
#0  __GI_bind () at ../sysdeps/unix/syscall-template.S:120
#1  0x000070b26b907b2d in uv.tcp_bind () from /usr/local/lib/python3.12/dist-packages/torch/lib/libtorch_cpu.so
#2  0x000070b2605c0918 in c10d::detail::UvTcpServer::makeWithPort(uv_loop_s*, unsigned short, bool) ()
   from /usr/local/lib/python3.12/dist-packages/torch/lib/libtorch_cpu.so
#3  0x000070b2605b78fe in c10d::detail::LibUVStoreDaemon::init(c10d::TCPStoreOptions const&) ()
   from /usr/local/lib/python3.12/dist-packages/torch/lib/libtorch_cpu.so
#4  0x000070b2605b793f in c10d::detail::create_libuv_tcpstore_backend(c10d::TCPStoreOptions const&) ()
   from /usr/local/lib/python3.12/dist-packages/torch/lib/libtorch_cpu.so
#5  0x000070b26059dc85 in c10d::detail::TCPServer::start(c10d::TCPStoreOptions const&) ()
   from /usr/local/lib/python3.12/dist-packages/torch/lib/libtorch_cpu.so
#6  0x000070b2605a1bf2 in c10d::TCPStore::TCPStore(std::__cxx11::basic_string<char, std::char_traits<char>, std::allocator<char> >, c10d::TCPStoreOptions const&) () from /usr/local/lib/python3.12/dist-packages/torch/lib/libtorch_cpu.so
#7  0x000070b270072f51 in pybind11::cpp_function::initialize<pybind11::detail::initimpl::factory<torch::distributed::c10d::(anonymous namespace)::c10d_init(_object*, _object*)::{lambda(std::__cxx11::basic_string<char, std::char_traits<char>, std::allocator<char> > const&, unsigned short, std::optional<int>, bool, std::chrono::duration<long, std::ratio<1l, 1000l> >, bool, bool, std::optional<int>, bool)#1}, pybind11::detail::void_type (*)(), c10::intrusive_ptr<c10d::TCPStore, c10::detail::intrusive_target_default_null_type<c10d::TCPStore> > (std::__cxx11::basic_string<char, std::char_traits<char>, std::allocator<char> > const&, unsigned short, std::optional<int>, bool, std::chrono::duration<long, std::ratio<1l, 1000l> >, bool, bool, std::optional<int>, bool), pybind11::detail::void_type ()>::execute<pybind11::class_<c10d::TCPStore, c10::intrusive_ptr<c10d::TCPStore, c10::detail::intrusive_target_default_null_type<c10d::TCPStore> > >, pybind11::arg, pybind11::arg, pybind11::arg_v, pybind11::arg_v, pybind11::arg_v, pybind11::arg_v, pybind11::arg_v, pybind11::arg_v, pybind11::arg_v, char [24]>(pybind11::class_<c10d::TCPStore, c10::intrusive_ptr<c10d::TCPStore, c10::detail::intrusive_target_default_null_type<c10d::TCPStore> > >&, pybind11::arg const&, pybind11::arg const&, pybind11::arg_v const&, pybind11::arg_v const&, pybind11::arg_v const&, pybind11::arg_v const&, pybind11::arg_v const&, pybind11::arg_v const&, pybind11::arg_v const&, char const (&) [24]) &&::{lambda(pybind11::detail::value_and_holder&, std::__cxx11::basic_string<char, std::char_traits<char>, std::allocator<char> > const&, unsigned short, std::optional<int>, bool, std::chrono::duration<long, std::ratio<1l, 1000l> >, bool, bool, std::optional<int>, bool)#1}, void, pybind11::detail::value_and_holder&, std::__cxx11::basic_string<char, std::char_traits<char>, std::allocator<char> > const&, unsigned short, std::optional<int>, bool, std::chrono::duration<long, std::ratio<1l, 1000l> >, bool, bool, std::optional<int>, bool, pybind11::name, pybind11::is_method, pybind11::sibling, pybind11::detail::is_new_style_constructor, pybind11::arg, pybind11::arg, pybind11::arg_v, pybind11::arg_v, pybind11::arg_v, pybind11::arg_v, pybind11::arg_v, pybind11::arg_v, pybind11::arg_v, char [24]>(pybind11::class_<c10d::TCPStore, c10::intrusive_ptr<c10d::TCPStore, c10::detail::intrusive_target_default_null_type<c10d::TCPStore> > >&&, void (*)(pybind11::detail::value_and_holder&, std::__cxx11::basic_string<char, std::char_traits<char>, std::allocator<char> > const&, unsigned short, std::optional<int>, bool, std::chrono::duration<long, std::ratio<1l, 1000l> >, bool, bool, std::optional<int>, bool), pybind11::name const&, pybind11::is_method const&, pybind11::sibling const&, pybind11::detail::is_new_style_constructor const&, pybind11::arg const&, pybind11::arg const&, pybind11::arg_v const&, pybind11::arg_v const&, pybind11::arg_v const&, pybind11::arg_v const&, pybind11::arg_v const&, pybind11::arg_v const&, pybind11::arg_v const&, char const (&) [24])::{lambda(pybind11::detail::function_call&)#1}::_FUN(pybind11::detail::function_call&) ()
   from /usr/local/lib/python3.12/dist-packages/torch/lib/libtorch_python.so
#8  0x000070b26f5b8060 in pybind11::cpp_function::dispatcher(_object*, _object*, _object*) ()
   from /usr/local/lib/python3.12/dist-packages/torch/lib/libtorch_python.so
...
#28 0x00000000006bc3ed in Py_BytesMain ()
#29 0x000070b33858e1ca in __libc_start_call_main (main=main@entry=0x518930, argc=argc@entry=9, argv=argv@entry=0x7ffd9cdf2e68)
    at ../sysdeps/nptl/libc_start_call_main.h:58
#30 0x000070b33858e28b in __libc_start_main_impl (main=0x518930, argc=9, argv=0x7ffd9cdf2e68, init=<optimized out>, fini=<optimized out>,
    rtld_fini=<optimized out>, stack_end=0x7ffd9cdf2e58) at ../csu/libc-start.c:360
#31 0x00000000006576c5 in _start ()
```

Okay, we got it. In this successful scenario, the process starts listening on the rendezvous endpoint through `TCPServer::start`.

Now, what happens if the address is 10.100.10.1? If you look carefully at the earlier debugger output, `connect` was called inside `TCPStore::TCPStore` through `TCPClient::connect`, so we know `TCPStore::TCPStore` is definitely reached. Let’s see whether `TCPServer::start` is called as well.

```log
(gdb) del
Delete all breakpoints, watchpoints, tracepoints, and catchpoints? (y or n) y
(gdb) b TCPStore::TCPStore
Breakpoint 3 at 0x70b2605a1a90
(gdb) b TCPServer::start
Breakpoint 4 at 0x70b26059d990
(gdb) b TCPClient::connect
Breakpoint 5 at 0x70b26059b070
(gdb) r /usr/local/bin/torchrun --nnodes=1 --nproc_per_node=1 --node_rank=0 --rdzv_id=test_job --rdzv_backend=c10d --rdzv_endpoint=10.100.10.1:29500 train.py
The program being debugged has been started already.
Start it from the beginning? (y or n) y
Starting program: /usr/bin/python3 /usr/local/bin/torchrun --nnodes=1 --nproc_per_node=1 --node_rank=0 --rdzv_id=test_job --rdzv_backend=c10d --rdzv_endpoint=10.100.10.1:29500 train.py
warning: Error disabling address space randomization: Operation not permitted
[Thread debugging using libthread_db enabled]
Using host libthread_db library "/lib/x86_64-linux-gnu/libthread_db.so.1".
[New Thread 0x7b62f01ff6c0 (LWP 1900)]
...
[New Thread 0x7b62e31e56c0 (LWP 1926)]
I0406 19:36:25.975000 1899 torch/distributed/run.py:735] Using nproc_per_node=1.
[New Thread 0x7b62dda2a6c0 (LWP 1927)]
I0406 19:36:26.088000 1899 torch/distributed/launcher/api.py:131] Using default numa options = None
I0406 19:36:26.091000 1899 torch/distributed/launcher/api.py:224] Starting elastic_operator with launch configs:
I0406 19:36:26.091000 1899 torch/distributed/launcher/api.py:224]   entrypoint               : train.py
I0406 19:36:26.091000 1899 torch/distributed/launcher/api.py:224]   min_nodes                : 1
I0406 19:36:26.091000 1899 torch/distributed/launcher/api.py:224]   max_nodes                : 1
I0406 19:36:26.091000 1899 torch/distributed/launcher/api.py:224]   nproc_per_node           : 1
I0406 19:36:26.091000 1899 torch/distributed/launcher/api.py:224]   run_id                   : test_job
I0406 19:36:26.091000 1899 torch/distributed/launcher/api.py:224]   rdzv_backend             : c10d
I0406 19:36:26.091000 1899 torch/distributed/launcher/api.py:224]   rdzv_endpoint            : 10.100.10.1:29500
I0406 19:36:26.091000 1899 torch/distributed/launcher/api.py:224]   rdzv_configs             : {'timeout': 900}
I0406 19:36:26.091000 1899 torch/distributed/launcher/api.py:224]   max_restarts             : 0
I0406 19:36:26.091000 1899 torch/distributed/launcher/api.py:224]   monitor_interval         : 0.1
I0406 19:36:26.091000 1899 torch/distributed/launcher/api.py:224]   log_dir                  : /tmp/torchelastic_e8gckjjs
I0406 19:36:26.091000 1899 torch/distributed/launcher/api.py:224]   metrics_cfg              : {}
I0406 19:36:26.091000 1899 torch/distributed/launcher/api.py:224]   event_log_handler        : null
I0406 19:36:26.091000 1899 torch/distributed/launcher/api.py:224]   numa_options             : None
I0406 19:36:26.091000 1899 torch/distributed/launcher/api.py:224]   signals_to_handle        : SIGTERM,SIGINT,SIGHUP,SIGQUIT
I0406 19:36:26.091000 1899 torch/distributed/launcher/api.py:224]   duplicate_stdout_filters : []
I0406 19:36:26.091000 1899 torch/distributed/launcher/api.py:224]   duplicate_stderr_filters : []
I0406 19:36:26.091000 1899 torch/distributed/launcher/api.py:224]

Thread 1 "pt_elastic" hit Breakpoint 3, 0x00007b63677a1a90 in c10d::TCPStore::TCPStore(std::__cxx11::basic_string<char, std::char_traits<char>, std::allocator<char> >, c10d::TCPStoreOptions const&) () from /usr/local/lib/python3.12/dist-packages/torch/lib/libtorch_cpu.so
(gdb) c
Continuing.

Thread 1 "pt_elastic" hit Breakpoint 5, 0x00007b636779b070 in c10d::detail::TCPClient::connect(c10d::detail::SocketAddress const&, c10d::TCPStoreOptions const&, std::shared_ptr<c10d::Backoff>) () from /usr/local/lib/python3.12/dist-packages/torch/lib/libtorch_cpu.so
(gdb)
```

See? We reached `TCPClient::connect` without ever hitting `TCPServer::start`. This is the problem. Let’s check the code and see where `TCPServer::start` is called. The code is [here](https://github.com/pytorch/pytorch/blob/v2.11.0/torch/csrc/distributed/c10d/TCPStore.cpp#L269). It’s behind the check `if (opts.isServer) {`. Does this mean `isServer` was `false` in our case? Let’s confirm it in the debugger.

```log
(gdb) del
Delete all breakpoints, watchpoints, tracepoints, and catchpoints? (y or n) y
(gdb) b TCPStore::TCPStore
Breakpoint 6 at 0x7b63677a1a90
(gdb) r /usr/local/bin/torchrun --nnodes=1 --nproc_per_node=1 --node_rank=0 --rdzv_id=test_job --rdzv_backend=c10d --rdzv_endpoint=10.100.10.1:29500 train.py
The program being debugged has been started already.
Start it from the beginning? (y or n) y
Starting program: /usr/bin/python3 /usr/local/bin/torchrun --nnodes=1 --nproc_per_node=1 --node_rank=0 --rdzv_id=test_job --rdzv_backend=c10d --rdzv_endpoint=10.100.10.1:29500 train.py
warning: Error disabling address space randomization: Operation not permitted
[Thread debugging using libthread_db enabled]
Using host libthread_db library "/lib/x86_64-linux-gnu/libthread_db.so.1".
[New Thread 0x71c1955ff6c0 (LWP 1929)]
...
[New Thread 0x71c1885e56c0 (LWP 1955)]
I0406 19:39:42.387000 1928 torch/distributed/run.py:735] Using nproc_per_node=1.
[New Thread 0x71c182f156c0 (LWP 1956)]
I0406 19:39:42.502000 1928 torch/distributed/launcher/api.py:131] Using default numa options = None
I0406 19:39:42.505000 1928 torch/distributed/launcher/api.py:224] Starting elastic_operator with launch configs:
I0406 19:39:42.505000 1928 torch/distributed/launcher/api.py:224]   entrypoint               : train.py
I0406 19:39:42.505000 1928 torch/distributed/launcher/api.py:224]   min_nodes                : 1
I0406 19:39:42.505000 1928 torch/distributed/launcher/api.py:224]   max_nodes                : 1
I0406 19:39:42.505000 1928 torch/distributed/launcher/api.py:224]   nproc_per_node           : 1
I0406 19:39:42.505000 1928 torch/distributed/launcher/api.py:224]   run_id                   : test_job
I0406 19:39:42.505000 1928 torch/distributed/launcher/api.py:224]   rdzv_backend             : c10d
I0406 19:39:42.505000 1928 torch/distributed/launcher/api.py:224]   rdzv_endpoint            : 10.100.10.1:29500
I0406 19:39:42.505000 1928 torch/distributed/launcher/api.py:224]   rdzv_configs             : {'timeout': 900}
I0406 19:39:42.505000 1928 torch/distributed/launcher/api.py:224]   max_restarts             : 0
I0406 19:39:42.505000 1928 torch/distributed/launcher/api.py:224]   monitor_interval         : 0.1
I0406 19:39:42.505000 1928 torch/distributed/launcher/api.py:224]   log_dir                  : /tmp/torchelastic_d435hy84
I0406 19:39:42.505000 1928 torch/distributed/launcher/api.py:224]   metrics_cfg              : {}
I0406 19:39:42.505000 1928 torch/distributed/launcher/api.py:224]   event_log_handler        : null
I0406 19:39:42.505000 1928 torch/distributed/launcher/api.py:224]   numa_options             : None
I0406 19:39:42.505000 1928 torch/distributed/launcher/api.py:224]   signals_to_handle        : SIGTERM,SIGINT,SIGHUP,SIGQUIT
I0406 19:39:42.505000 1928 torch/distributed/launcher/api.py:224]   duplicate_stdout_filters : []
I0406 19:39:42.505000 1928 torch/distributed/launcher/api.py:224]   duplicate_stderr_filters : []
I0406 19:39:42.505000 1928 torch/distributed/launcher/api.py:224]

Thread 1 "pt_elastic" hit Breakpoint 6, 0x000071c20cba1a90 in c10d::TCPStore::TCPStore(std::__cxx11::basic_string<char, std::char_traits<char>, std::allocator<char> >, c10d::TCPStoreOptions const&) () from /usr/local/lib/python3.12/dist-packages/torch/lib/libtorch_cpu.so
(gdb) disas $rip
Dump of assembler code for function _ZN4c10d8TCPStoreC2ENSt7__cxx1112basic_stringIcSt11char_traitsIcESaIcEEERKNS_15TCPStoreOptionsE:
=> 0x000071c20cba1a90 <+0>:     push   %rbp
   0x000071c20cba1a91 <+1>:     pxor   %xmm0,%xmm0
...
   0x000071c20cba1bbf <+303>:   call   0x71c20cc32ba0 <_ZN4c10d6detail6Socket10initializeEv> <<<< Socket::initialize()
   0x000071c20cba1bc4 <+308>:   mov    -0x608(%rbp),%rsi
   0x000071c20cba1bcb <+315>:   movzwl (%rsi),%eax
   0x000071c20cba1bce <+318>:   cmpb   $0x0,0x2(%rsi) <<<< opts.isServer
   0x000071c20cba1bd2 <+322>:   mov    %ax,0x38(%rbx)
   0x000071c20cba1bd6 <+326>:   je     0x71c20cba1df1 <_ZN4c10d8TCPStoreC2ENSt7__cxx1112basic_stringIcSt11char_traitsIcESaIcEEERKNS_15TCPStoreOptionsE+865>
   0x000071c20cba1bdc <+332>:   lea    -0x240(%rbp),%rax
   0x000071c20cba1be3 <+339>:   mov    %rax,%rdi
   0x000071c20cba1be6 <+342>:   mov    %rax,-0x620(%rbp)
   0x000071c20cba1bed <+349>:   call   0x71c20cb9d990 <_ZN4c10d6detail9TCPServer5startERKNS_15TCPStoreOptionsE> <<<< TCPServer::start
   0x000071c20cba1bf2 <+354>:   mov    0x48(%rbx),%rdi
   0x000071c20cba1bf6 <+358>:   movdqa -0x240(%rbp),%xmm3
   0x000071c20cba1bfe <+366>:   movups %xmm3,0x40(%rbx)
   0x000071c20cba1c02 <+370>:   test   %rdi,%rdi
...
```

As I commented inline, `cmpb $0x0,0x2(%rsi)` is checking the `isServer` flag.

```log
(gdb) b *0x000071c20cba1bce
Breakpoint 7 at 0x71c20cba1bce
(gdb) c
Continuing.

Thread 1 "pt_elastic" hit Breakpoint 7, 0x000071c20cba1bce in c10d::TCPStore::TCPStore(std::__cxx11::basic_string<char, std::char_traits<char>, std::allocator<char> >, c10d::TCPStoreOptions const&) () from /usr/local/lib/python3.12/dist-packages/torch/lib/libtorch_cpu.so
(gdb) x/1b $rsi+2
0x7ffcc728a192: 0
```

It’s zero! This is why we skip listening on the rendezvous endpoint!

And it comes from the parameter `opts`. Where does that come from? Who instantiates `TCPStore`?

```log
(gdb) bt
#0  0x000071c20cba1bce in c10d::TCPStore::TCPStore(std::__cxx11::basic_string<char, std::char_traits<char>, std::allocator<char> >, c10d::TCPStoreOptions const&) () from /usr/local/lib/python3.12/dist-packages/torch/lib/libtorch_cpu.so
#1  0x000071c21c672f51 in pybind11::cpp_function::initialize<pybind11::detail::initimpl::factory<torch::distributed::c10d::(anonymous namespace)::c10d_init(_object*, _object*)::{lambda(std::__cxx11::basic_string<char, std::char_traits<char>, std::allocator<char> > const&, unsigned short, std::optional<int>, bool, std::chrono::duration<long, std::ratio<1l, 1000l> >, bool, bool, std::optional<int>, bool)#1}, pybind11::detail::void_type (*)(), c10::intrusive_ptr<c10d::TCPStore, c10::detail::intrusive_target_default_null_type<c10d::TCPStore> > (std::__cxx11::basic_string<char, std::char_traits<char>, std::allocator<char> > const&, unsigned short, std::optional<int>, bool, std::chrono::duration<long, std::ratio<1l, 1000l> >, bool, bool, std::optional<int>, bool), pybind11::detail::void_type ()>::execute<pybind11::class_<c10d::TCPStore, c10::intrusive_ptr<c10d::TCPStore, c10::detail::intrusive_target_default_null_type<c10d::TCPStore> > >, pybind11::arg, pybind11::arg, pybind11::arg_v, pybind11::arg_v, pybind11::arg_v, pybind11::arg_v, pybind11::arg_v, pybind11::arg_v, pybind11::arg_v, char [24]>(pybind11::class_<c10d::TCPStore, c10::intrusive_ptr<c10d::TCPStore, c10::detail::intrusive_target_default_null_type<c10d::TCPStore> > >&, pybind11::arg const&, pybind11::arg const&, pybind11::arg_v const&, pybind11::arg_v const&, pybind11::arg_v const&, pybind11::arg_v const&, pybind11::arg_v const&, pybind11::arg_v const&, pybind11::arg_v const&, char const (&) [24]) &&::{lambda(pybind11::detail::value_and_holder&, std::__cxx11::basic_string<char, std::char_traits<char>, std::allocator<char> > const&, unsigned short, std::optional<int>, bool, std::chrono::duration<long, std::ratio<1l, 1000l> >, bool, bool, std::optional<int>, bool)#1}, void, pybind11::detail::value_and_holder&, std::__cxx11::basic_string<char, std::char_traits<char>, std::allocator<char> > const&, unsigned short, std::optional<int>, bool, std::chrono::duration<long, std::ratio<1l, 1000l> >, bool, bool, std::optional<int>, bool, pybind11::name, pybind11::is_method, pybind11::sibling, pybind11::detail::is_new_style_constructor, pybind11::arg, pybind11::arg, pybind11::arg_v, pybind11::arg_v, pybind11::arg_v, pybind11::arg_v, pybind11::arg_v, pybind11::arg_v, pybind11::arg_v, char [24]>(pybind11::class_<c10d::TCPStore, c10::intrusive_ptr<c10d::TCPStore, c10::detail::intrusive_target_default_null_type<c10d::TCPStore> > >&&, void (*)(pybind11::detail::value_and_holder&, std::__cxx11::basic_string<char, std::char_traits<char>, std::allocator<char> > const&, unsigned short, std::optional<int>, bool, std::chrono::duration<long, std::ratio<1l, 1000l> >, bool, bool, std::optional<int>, bool), pybind11::name const&, pybind11::is_method const&, pybind11::sibling const&, pybind11::detail::is_new_style_constructor const&, pybind11::arg const&, pybind11::arg const&, pybind11::arg_v const&, pybind11::arg_v const&, pybind11::arg_v const&, pybind11::arg_v const&, pybind11::arg_v const&, pybind11::arg_v const&, pybind11::arg_v const&, char const (&) [24])::{lambda(pybind11::detail::function_call&)#1}::_FUN(pybind11::detail::function_call&) ()
   from /usr/local/lib/python3.12/dist-packages/torch/lib/libtorch_python.so
#2  0x000071c21bbb8060 in pybind11::cpp_function::dispatcher(_object*, _object*, _object*) ()
   from /usr/local/lib/python3.12/dist-packages/torch/lib/libtorch_python.so
#3  0x0000000000581e9f in ?? ()
```

The symbol of frame #1 is insanely long. Note the `pybind11` namespace: this is the Python binding, which means a Python script is instantiating `TCPStore` in C++, and `opts.isServer` also comes from Python.
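
As a quick cross-check (a minimal sketch of my own, not part of torchrun), you can instantiate the same binding directly from Python; the `is_master` argument is the Python-side flag that presumably ends up as `opts.isServer` on the C++ side.

```python
from datetime import timedelta
import torch.distributed as dist

# is_master=True takes the "server" path: bind and listen on the port.
# With is_master=False the constructor only connects, which is the code path
# we have been chasing in this session. Port 29501 is arbitrary.
store = dist.TCPStore("127.0.0.1", 29501, is_master=True,
                      timeout=timedelta(seconds=30))
store.set("who_is_listening", "me")
print(store.get("who_is_listening"))  # b'me'
```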

### Debugging with PDB

How do we debug Python? Do we need VSCode or some fancy IDE? That’s not what this blog does. Should we debug the Python binding itself? Well, that’s too ambitious.

One of the great features Python provides, compared to other scripting languages, is a built-in console debugger, [pdb](https://docs.python.org/3/library/pdb.html). You don’t need any additional components to live-debug Python code. Very handy.

To use pdb, you modify your script to break at the beginning. Since we’re running our script with `torchrun`, we make a small modification to `torchrun` itself, adding a single `breakpoint()` line before `main()`.

```python
root@2f3d44d8b0ea:/app# cat /usr/local/bin/torchrun
#!/usr/bin/python3
import sys
from torch.distributed.run import main
if __name__ == '__main__':
    breakpoint() # <---- add this line
    sys.argv[0] = sys.argv[0].removesuffix('.exe')
    sys.exit(main())
```

Where does Python call into C++? A formal approach would be to start from the log line `Starting elastic_operator with launch configs:`, which comes from [this line](https://github.com/pytorch/pytorch/blob/v2.11.0/torch/distributed/launcher/api.py#L225). In our case, however, let’s simply search the repo for the keyword `TCPStore`. The function [`_create_tcp_store`](https://github.com/pytorch/pytorch/blob/v2.11.0/torch/distributed/elastic/rendezvous/c10d_rendezvous_backend.py#L156) looks suspicious. Since Python is a scripting language, you can set a breakpoint precisely on a source line.

```log
root@2f3d44d8b0ea:/app# NCCL_DEBUG=INFO LOGLEVEL=DEBUG STEPS=30 torchrun   --nnodes=1 --nproc_per_node=1 --node_rank=0   --rdzv_id=test_job --rdzv_backend=c10d --rdzv_endpoint=10.100.10.1:29500   train.py
> /usr/local/bin/torchrun(6)<module>()
-> sys.argv[0] = sys.argv[0].removesuffix('.exe')
(Pdb) b /usr/local/lib/python3.12/dist-packages/torch/distributed/elastic/rendezvous/c10d_rendezvous_backend.py:154
Breakpoint 1 at /usr/local/lib/python3.12/dist-packages/torch/distributed/elastic/rendezvous/c10d_rendezvous_backend.py:154
(Pdb) c
I0406 20:29:58.666000 1992 torch/distributed/run.py:735] Using nproc_per_node=1.
I0406 20:29:58.778000 1992 torch/distributed/launcher/api.py:131] Using default numa options = None
I0406 20:29:58.781000 1992 torch/distributed/launcher/api.py:224] Starting elastic_operator with launch configs:
I0406 20:29:58.781000 1992 torch/distributed/launcher/api.py:224]   entrypoint               : train.py
I0406 20:29:58.781000 1992 torch/distributed/launcher/api.py:224]   min_nodes                : 1
I0406 20:29:58.781000 1992 torch/distributed/launcher/api.py:224]   max_nodes                : 1
I0406 20:29:58.781000 1992 torch/distributed/launcher/api.py:224]   nproc_per_node           : 1
I0406 20:29:58.781000 1992 torch/distributed/launcher/api.py:224]   run_id                   : test_job
I0406 20:29:58.781000 1992 torch/distributed/launcher/api.py:224]   rdzv_backend             : c10d
I0406 20:29:58.781000 1992 torch/distributed/launcher/api.py:224]   rdzv_endpoint            : 10.100.10.1:29500
I0406 20:29:58.781000 1992 torch/distributed/launcher/api.py:224]   rdzv_configs             : {'timeout': 900}
I0406 20:29:58.781000 1992 torch/distributed/launcher/api.py:224]   max_restarts             : 0
I0406 20:29:58.781000 1992 torch/distributed/launcher/api.py:224]   monitor_interval         : 0.1
I0406 20:29:58.781000 1992 torch/distributed/launcher/api.py:224]   log_dir                  : /tmp/torchelastic_fw4c6ft_
I0406 20:29:58.781000 1992 torch/distributed/launcher/api.py:224]   metrics_cfg              : {}
I0406 20:29:58.781000 1992 torch/distributed/launcher/api.py:224]   event_log_handler        : null
I0406 20:29:58.781000 1992 torch/distributed/launcher/api.py:224]   numa_options             : None
I0406 20:29:58.781000 1992 torch/distributed/launcher/api.py:224]   signals_to_handle        : SIGTERM,SIGINT,SIGHUP,SIGQUIT
I0406 20:29:58.781000 1992 torch/distributed/launcher/api.py:224]   duplicate_stdout_filters : []
I0406 20:29:58.781000 1992 torch/distributed/launcher/api.py:224]   duplicate_stderr_filters : []
I0406 20:29:58.781000 1992 torch/distributed/launcher/api.py:224]
> /usr/local/lib/python3.12/dist-packages/torch/distributed/elastic/rendezvous/c10d_rendezvous_backend.py(154)_create_tcp_store()
-> for is_server in [is_host, False]:
(Pdb) bt
  /usr/local/bin/torchrun(7)<module>()
-> sys.exit(main())
  /usr/local/lib/python3.12/dist-packages/torch/distributed/elastic/multiprocessing/errors/__init__.py(362)wrapper()
-> return f(*args, **kwargs)
  /usr/local/lib/python3.12/dist-packages/torch/distributed/run.py(990)main()
-> run(args)
  /usr/local/lib/python3.12/dist-packages/torch/distributed/run.py(981)run()
-> elastic_launch(
  /usr/local/lib/python3.12/dist-packages/torch/distributed/launcher/api.py(170)__call__()
-> return launch_agent(self._config, self._entrypoint, list(args))
  /usr/local/lib/python3.12/dist-packages/torch/distributed/launcher/api.py(284)launch_agent()
-> rdzv_handler=rdzv_registry.get_rendezvous_handler(rdzv_parameters),
  /usr/local/lib/python3.12/dist-packages/torch/distributed/elastic/rendezvous/registry.py(96)get_rendezvous_handler()
-> return handler_registry.create_handler(params)
  /usr/local/lib/python3.12/dist-packages/torch/distributed/elastic/rendezvous/api.py(377)create_handler()
-> handler = creator(params)
  /usr/local/lib/python3.12/dist-packages/torch/distributed/elastic/rendezvous/registry.py(46)_create_c10d_handler()
-> backend, store = create_backend(params)
  /usr/local/lib/python3.12/dist-packages/torch/distributed/elastic/rendezvous/c10d_rendezvous_backend.py(254)create_backend()
-> store = _create_tcp_store(params)
> /usr/local/lib/python3.12/dist-packages/torch/distributed/elastic/rendezvous/c10d_rendezvous_backend.py(154)_create_tcp_store()
-> for is_server in [is_host, False]:
(Pdb) p is_host
False
(Pdb) p cfg_is_host
None
(Pdb) p host
'10.100.10.1'
```

Okay, we got it. `_create_tcp_store` is instantiating `TCPStore` with `is_master=is_server`, which is `False`.

Where does this `is_host` come from? Since `cfg_is_host` is `None`, it must come from [`_matches_machine_hostname`](https://github.com/pytorch/pytorch/blob/v2.11.0/torch/distributed/elastic/rendezvous/utils.py#L117). Let’s step through this function line by line.

```log
root@2f3d44d8b0ea:/app# NCCL_DEBUG=INFO LOGLEVEL=DEBUG STEPS=30 torchrun   --nnodes=1 --nproc_per_node=1 --node_rank=0   --rdzv_id=test_job --rdzv_backend=c10d --rdzv_endpoint=10.100.10.1:29500   train.py
> /usr/local/bin/torchrun(6)<module>()
-> sys.argv[0] = sys.argv[0].removesuffix('.exe')
(Pdb) b /usr/local/lib/python3.12/dist-packages/torch/distributed/elastic/rendezvous/utils.py:146
Breakpoint 1 at /usr/local/lib/python3.12/dist-packages/torch/distributed/elastic/rendezvous/utils.py:146
(Pdb) c
I0406 20:35:35.452000 2021 torch/distributed/run.py:735] Using nproc_per_node=1.
I0406 20:35:35.549000 2021 torch/distributed/launcher/api.py:131] Using default numa options = None
I0406 20:35:35.551000 2021 torch/distributed/launcher/api.py:224] Starting elastic_operator with launch configs:
I0406 20:35:35.551000 2021 torch/distributed/launcher/api.py:224]   entrypoint               : train.py
I0406 20:35:35.551000 2021 torch/distributed/launcher/api.py:224]   min_nodes                : 1
I0406 20:35:35.551000 2021 torch/distributed/launcher/api.py:224]   max_nodes                : 1
I0406 20:35:35.551000 2021 torch/distributed/launcher/api.py:224]   nproc_per_node           : 1
I0406 20:35:35.551000 2021 torch/distributed/launcher/api.py:224]   run_id                   : test_job
I0406 20:35:35.551000 2021 torch/distributed/launcher/api.py:224]   rdzv_backend             : c10d
I0406 20:35:35.551000 2021 torch/distributed/launcher/api.py:224]   rdzv_endpoint            : 10.100.10.1:29500
I0406 20:35:35.551000 2021 torch/distributed/launcher/api.py:224]   rdzv_configs             : {'timeout': 900}
I0406 20:35:35.551000 2021 torch/distributed/launcher/api.py:224]   max_restarts             : 0
I0406 20:35:35.551000 2021 torch/distributed/launcher/api.py:224]   monitor_interval         : 0.1
I0406 20:35:35.551000 2021 torch/distributed/launcher/api.py:224]   log_dir                  : /tmp/torchelastic_dek76zg_
I0406 20:35:35.551000 2021 torch/distributed/launcher/api.py:224]   metrics_cfg              : {}
I0406 20:35:35.551000 2021 torch/distributed/launcher/api.py:224]   event_log_handler        : null
I0406 20:35:35.551000 2021 torch/distributed/launcher/api.py:224]   numa_options             : None
I0406 20:35:35.551000 2021 torch/distributed/launcher/api.py:224]   signals_to_handle        : SIGTERM,SIGINT,SIGHUP,SIGQUIT
I0406 20:35:35.551000 2021 torch/distributed/launcher/api.py:224]   duplicate_stdout_filters : []
I0406 20:35:35.551000 2021 torch/distributed/launcher/api.py:224]   duplicate_stderr_filters : []
I0406 20:35:35.551000 2021 torch/distributed/launcher/api.py:224]
> /usr/local/lib/python3.12/dist-packages/torch/distributed/elastic/rendezvous/utils.py(146)_matches_machine_hostname()
-> if host == this_host:
(Pdb) p host
'10.100.10.1'
(Pdb) p this_host
'2f3d44d8b0ea'
(Pdb) n
> /usr/local/lib/python3.12/dist-packages/torch/distributed/elastic/rendezvous/utils.py(149)_matches_machine_hostname()
-> addr_list = socket.getaddrinfo(
(Pdb)
> /usr/local/lib/python3.12/dist-packages/torch/distributed/elastic/rendezvous/utils.py(150)_matches_machine_hostname()
-> this_host, None, proto=socket.IPPROTO_TCP, flags=socket.AI_CANONNAME
(Pdb)
> /usr/local/lib/python3.12/dist-packages/torch/distributed/elastic/rendezvous/utils.py(149)_matches_machine_hostname()
-> addr_list = socket.getaddrinfo(
(Pdb)
> /usr/local/lib/python3.12/dist-packages/torch/distributed/elastic/rendezvous/utils.py(152)_matches_machine_hostname()
-> for addr_info in addr_list:
(Pdb) p addr_list
[(<AddressFamily.AF_INET: 2>, <SocketKind.SOCK_STREAM: 1>, 6, '2f3d44d8b0ea', ('172.17.0.2', 0))]
(Pdb)
```

Now we clearly see the problem. First, we get the hostname with `gethostname()`, which is `2f3d44d8b0ea` (it matches the container’s ID). Then we resolve the IP addresses associated with that hostname via `getaddrinfo`, which returns only `172.17.0.2`, the one mapped to the default bridge interface (eth0). Since it doesn’t match `10.100.10.1`, torchrun concludes, “I’m not the master. Somebody else should start listening on the rendezvous endpoint.”
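
You can reproduce the failing check in a few lines outside of torchrun (a minimal sketch of the same logic; the hostname and addresses are just the ones from this session):

```python
import socket

this_host = socket.gethostname()   # '2f3d44d8b0ea' inside this container
addr_list = socket.getaddrinfo(
    this_host, None, proto=socket.IPPROTO_TCP, flags=socket.AI_CANONNAME
)
print([info[4][0] for info in addr_list])   # ['172.17.0.2'] -- no 10.100.10.1

# The WireGuard address never shows up, so the "am I the host?" test fails:
print("10.100.10.1" in {info[4][0] for info in addr_list})   # False
```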

Ideally, `_matches_machine_hostname` should iterate over all IP addresses assigned to the machine and check whether any of them matches the host. We may consider sending a PR upstream, but for now, is there a good way to work around this behavior?

There is. Look at the beginning of `_create_tcp_store`: the decision can be overridden through `cfg_is_host`, which comes from `params`.

```python
    cfg_is_host = params.get_as_bool("is_host")
    # If the user has explicitly specified whether our process should host the
    # the store, respect it.
    if cfg_is_host is not None:
        is_host = cfg_is_host
    # Otherwise try to determine whether we are the host based on our hostname
    # and IP address.
    else:
        is_host = _matches_machine_hostname(host)
```

Where is this `RendezvousParameters` created? It’s far above `_matches_machine_hostname`, in the function [`run`](https://github.com/pytorch/pytorch/blob/v2.11.0/torch/distributed/run.py#L980). `torchrun` takes a rarely used parameter, `--rdzv_conf`, where we can specify key-value pairs; the key we want to set is `is_host`. This looks promising. Let’s try it.
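
For reference, those `--rdzv_conf` key-value pairs land in the `RendezvousParameters` object that `_create_tcp_store` receives, which is where `get_as_bool("is_host")` finds them. Here’s a rough sketch; the constructor call below is my paraphrase of what torchrun builds, so treat the exact arguments as an assumption:

```python
from torch.distributed.elastic.rendezvous import RendezvousParameters

# Approximately what torchrun constructs from our command line plus
# --rdzv_conf=is_host=1 (values are illustrative).
params = RendezvousParameters(
    backend="c10d",
    endpoint="10.100.10.1:29500",
    run_id="test_job",
    min_nodes=1,
    max_nodes=1,
    is_host=1,   # extra keyword arguments become rendezvous config entries
)
print(params.get_as_bool("is_host"))  # True
```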

```log
root@2f3d44d8b0ea:/app# NCCL_DEBUG=INFO LOGLEVEL=DEBUG STEPS=30 torchrun \
  --nnodes=1 --nproc_per_node=1 --node_rank=0 \
  --rdzv_id=test_job --rdzv_backend=c10d --rdzv_endpoint=10.100.10.1:29500 \
  --rdzv_conf=is_host=1 \
  train.py
I0406 20:52:48.947000 2060 torch/distributed/run.py:735] Using nproc_per_node=1.
I0406 20:52:49.029000 2060 torch/distributed/launcher/api.py:131] Using default numa options = None
I0406 20:52:49.030000 2060 torch/distributed/launcher/api.py:224] Starting elastic_operator with launch configs:
I0406 20:52:49.030000 2060 torch/distributed/launcher/api.py:224]   entrypoint               : train.py
I0406 20:52:49.030000 2060 torch/distributed/launcher/api.py:224]   min_nodes                : 1
I0406 20:52:49.030000 2060 torch/distributed/launcher/api.py:224]   max_nodes                : 1
I0406 20:52:49.030000 2060 torch/distributed/launcher/api.py:224]   nproc_per_node           : 1
I0406 20:52:49.030000 2060 torch/distributed/launcher/api.py:224]   run_id                   : test_job
I0406 20:52:49.030000 2060 torch/distributed/launcher/api.py:224]   rdzv_backend             : c10d
I0406 20:52:49.030000 2060 torch/distributed/launcher/api.py:224]   rdzv_endpoint            : 10.100.10.1:29500
I0406 20:52:49.030000 2060 torch/distributed/launcher/api.py:224]   rdzv_configs             : {'is_host': 'True', 'timeout': 900}
I0406 20:52:49.030000 2060 torch/distributed/launcher/api.py:224]   max_restarts             : 0
I0406 20:52:49.030000 2060 torch/distributed/launcher/api.py:224]   monitor_interval         : 0.1
I0406 20:52:49.030000 2060 torch/distributed/launcher/api.py:224]   log_dir                  : /tmp/torchelastic_ccaq0pji
I0406 20:52:49.030000 2060 torch/distributed/launcher/api.py:224]   metrics_cfg              : {}
I0406 20:52:49.030000 2060 torch/distributed/launcher/api.py:224]   event_log_handler        : null
I0406 20:52:49.030000 2060 torch/distributed/launcher/api.py:224]   numa_options             : None
I0406 20:52:49.030000 2060 torch/distributed/launcher/api.py:224]   signals_to_handle        : SIGTERM,SIGINT,SIGHUP,SIGQUIT
I0406 20:52:49.030000 2060 torch/distributed/launcher/api.py:224]   duplicate_stdout_filters : []
I0406 20:52:49.030000 2060 torch/distributed/launcher/api.py:224]   duplicate_stderr_filters : []
I0406 20:52:49.030000 2060 torch/distributed/launcher/api.py:224]
I0406 20:52:49.055000 2060 torch/distributed/elastic/agent/server/api.py:898] [default] starting workers for entrypoint: python3
I0406 20:52:49.056000 2060 torch/distributed/elastic/agent/server/api.py:693] [default] Rendezvous'ing worker group
I0406 20:52:49.283000 2060 torch/distributed/elastic/agent/server/api.py:539] [default] Rendezvous complete for workers. Result:
I0406 20:52:49.283000 2060 torch/distributed/elastic/agent/server/api.py:539]   restart_count=0
I0406 20:52:49.283000 2060 torch/distributed/elastic/agent/server/api.py:539]   master_addr=2f3d44d8b0ea
I0406 20:52:49.283000 2060 torch/distributed/elastic/agent/server/api.py:539]   master_port=36587
I0406 20:52:49.283000 2060 torch/distributed/elastic/agent/server/api.py:539]   group_rank=0
I0406 20:52:49.283000 2060 torch/distributed/elastic/agent/server/api.py:539]   group_world_size=1
I0406 20:52:49.283000 2060 torch/distributed/elastic/agent/server/api.py:539]   local_ranks=[0]
I0406 20:52:49.283000 2060 torch/distributed/elastic/agent/server/api.py:539]   role_ranks=[0]
I0406 20:52:49.283000 2060 torch/distributed/elastic/agent/server/api.py:539]   global_ranks=[0]
I0406 20:52:49.283000 2060 torch/distributed/elastic/agent/server/api.py:539]   role_world_sizes=[1]
I0406 20:52:49.283000 2060 torch/distributed/elastic/agent/server/api.py:539]   global_world_sizes=[1]
I0406 20:52:49.283000 2060 torch/distributed/elastic/agent/server/api.py:539]   event_log_handler=null
I0406 20:52:49.283000 2060 torch/distributed/elastic/agent/server/api.py:539]
I0406 20:52:49.283000 2060 torch/distributed/elastic/agent/server/api.py:701] [default] Starting worker group
I0406 20:52:49.283000 2060 torch/distributed/elastic/agent/server/local_elastic_agent.py:299] use_agent_store: True
I0406 20:52:49.283000 2060 torch/distributed/elastic/agent/server/local_elastic_agent.py:195] Environment variable 'TORCHELASTIC_ENABLE_FILE_TIMER' not found. Do not start FileTimerServer.
I0406 20:52:49.283000 2060 torch/distributed/elastic/agent/server/local_elastic_agent.py:239] Environment variable 'TORCHELASTIC_HEALTH_CHECK_PORT' not found. Do not start health check.
2f3d44d8b0ea:2092:2092 [0] NCCL INFO ENV/Plugin: Could not find: libnccl-env.so
2f3d44d8b0ea:2092:2092 [0] NCCL INFO Bootstrap: Using eth0:172.17.0.2<0>
2f3d44d8b0ea:2092:2092 [0] NCCL INFO cudaDriverVersion 12080
2f3d44d8b0ea:2092:2092 [0] NCCL INFO NCCL version 2.28.9+cuda12.9
2f3d44d8b0ea:2092:2092 [0] NCCL INFO Comm config Blocking set to 1
2f3d44d8b0ea:2092:2092 [0] NCCL INFO NET/Plugin: Could not find: libnccl-net.so
2f3d44d8b0ea:2092:2092 [0] NCCL INFO Failed to open libibverbs.so[.1]
2f3d44d8b0ea:2092:2092 [0] NCCL INFO transport/net_ib.cc:852 -> 3
2f3d44d8b0ea:2092:2092 [0] NCCL INFO Failed to initialize NET plugin IB
2f3d44d8b0ea:2092:2092 [0] NCCL INFO NET/Socket : Using [0]eth0:172.17.0.2<0> [1]wg0:10.100.10.1<0>
2f3d44d8b0ea:2092:2092 [0] NCCL INFO Initialized NET plugin Socket
2f3d44d8b0ea:2092:2092 [0] NCCL INFO Assigned NET plugin Socket to comm
2f3d44d8b0ea:2092:2092 [0] NCCL INFO Using network Socket
2f3d44d8b0ea:2092:2092 [0] NCCL INFO ncclCommInitRankConfig comm 0x230fcf20 rank 0 nranks 1 cudaDev 0 nvmlDev 0 busId 70 commId 0xda9011adca28b7b9 - Init START
2f3d44d8b0ea:2092:2092 [0] NCCL INFO RAS client listening socket at ::1<28028>
2f3d44d8b0ea:2092:2092 [0] NCCL INFO Bootstrap timings total 0.001490 (create 0.000058, send 0.000233, recv 0.000596, ring 0.000001, delay 0.000000)
2f3d44d8b0ea:2092:2092 [0] NCCL INFO NCCL_IGNORE_DISABLED_P2P set by environment to 1.
2f3d44d8b0ea:2092:2092 [0] NCCL INFO ncclTopoGetCpuAffinity: Affinity for GPU 0 is empty, ignoring. (GPU affinity =  ; CPU affinity = 0-27).
2f3d44d8b0ea:2092:2092 [0] NCCL INFO comm 0x230fcf20 rank 0 nRanks 1 nNodes 1 localRanks 1 localRank 0 MNNVL 0
2f3d44d8b0ea:2092:2092 [0] NCCL INFO Channel 00/64 : 0
...
2f3d44d8b0ea:2092:2092 [0] NCCL INFO Channel 63/64 : 0
2f3d44d8b0ea:2092:2092 [0] NCCL INFO Trees [0] -1/-1/-1->0->-1 [1] -1/-1/-1->0->-1 [2] -1/-1/-1->0->-1 [3] -1/-1/-1->0->-1 [4] -1/-1/-1->0->-1 [5] -1/-1/-1->0->-1 [6] -1/-1/-1->0->-1 [7] -1/-1/-1->0->-1 [8] -1/-1/-1->0->-1 [9] -1/-1/-1->0->-1 [10] -1/-1/-1->0->-1 [11] -1/-1/-1->0->-1 [12] -1/-1/-1->0->-1 [13] -1/-1/-1->0->-1 [14] -1/-1/-1->0->-1 [15] -1/-1/-1->0->-1 [16] -1/-1/-1->0->-1 [17] -1/-1/-1->0->-1 [18] -1/-1/-1->0->-1 [19] -1/-1/-1->0->-1 [20] -1/-1/-1->0->-1 [21] -1/-1/-1->0->-1 [22] -1/-1/-1->0->-1 [23] -1/-1/-1->0->-1 [24] -1/-1/-1->0->-1 [25] -1/-1/-1->0->-1 [26] -1/-1/-1->0->-1 [27] -1/-1/-1->0->-1 [28] -1/-1/-1->0->-1 [29] -1/-1/-1->0->-1 [30] -1/-1/-1->0->-1 [31] -1/-1/-1->0->-1 [32] -1/-1/-1->0->-1 [33] -1/-1/-1->0->-1 [34] -1/-1/-1->0->-1 [35] -1/-1/-1->0->-1 [36] -1/-1/-1->0->-1 [37] -1/-1/-1->0->-1 [38] -1/-1/-1->0->-1 [39] -1/-1/-1->0->-1 [40] -1/-1/-1->0->-1 [41] -1/-1/-1->0->-1 [42] -1/-1/-1->0->-1 [43] -1/-1/-1->0->-1 [44] -1/-1/-1->0->-1 [45] -1/-1/-1->0->-1 [46] -1/-1/-1->0->-1 [47
2f3d44d8b0ea:2092:2092 [0] NCCL INFO P2P Chunksize set to 524288
2f3d44d8b0ea:2092:2092 [0] NCCL INFO PROFILER/Plugin: Could not find: libnccl-profiler.so
2f3d44d8b0ea:2092:2092 [0] NCCL INFO Check P2P Type isAllDirectP2p 1 directMode 0 isAllCudaP2p 1
2f3d44d8b0ea:2092:2126 [0] NCCL INFO [Proxy Service] Device 0 CPU core 2
2f3d44d8b0ea:2092:2127 [0] NCCL INFO [Proxy Service UDS] Device 0 CPU core 12
2f3d44d8b0ea:2092:2092 [0] NCCL INFO TUNER/Plugin: Could not find: libnccl-tuner.so
2f3d44d8b0ea:2092:2092 [0] NCCL INFO 64 coll channels, 64 collnet channels, 0 nvls channels, 64 p2p channels, 64 p2p channels per peer
2f3d44d8b0ea:2092:2092 [0] NCCL INFO CC Off, workFifoBytes 1048576
2f3d44d8b0ea:2092:2092 [0] NCCL INFO ncclCommInitRankConfig comm 0x230fcf20 rank 0 nranks 1 cudaDev 0 nvmlDev 0 busId 70 commId 0xda9011adca28b7b9 - Init COMPLETE
2f3d44d8b0ea:2092:2092 [0] NCCL INFO Init timings - ncclCommInitRankConfig: rank 0 nranks 1 total 0.17 (kernels 0.13, alloc 0.00, bootstrap 0.00, allgathers 0.00, topo 0.00, graphs 0.00, connections 0.02, rest 0.00)
[Rank 0] Starting training...
Step 0 | Loss: 1.3868
Step 10 | Loss: 1.5847
Step 20 | Loss: 1.5615
[Rank 0] Training complete.
2f3d44d8b0ea:2092:2092 [0] NCCL INFO comm 0x230fcf20 rank 0 nranks 1 cudaDev 0 busId 70 - Destroy COMPLETE
2f3d44d8b0ea:2092:2092 [0] NCCL INFO ENV/Plugin: Closing env plugin ncclEnvDefault
I0406 20:52:52.994000 2060 torch/distributed/elastic/agent/server/api.py:917] [default] worker group successfully finished. Waiting 300 seconds for other agents to finish.
I0406 20:52:52.994000 2060 torch/distributed/elastic/agent/server/api.py:970] Local worker group finished (WorkerState.SUCCEEDED). Waiting 300 seconds for other agents to finish
I0406 20:52:52.995000 2060 torch/distributed/elastic/agent/server/api.py:984] Done waiting for other agents. Elapsed: 0.00042891502380371094 seconds
root@2f3d44d8b0ea:/app#
```

It worked like a charm!

Now it’s time to run a training job on two nodes.

### Training on two nodes

This is the first node: the master node, rank 0.

```log
root@2f3d44d8b0ea:/app# NCCL_DEBUG=INFO LOGLEVEL=DEBUG STEPS=30 torchrun \
  --nnodes=2 --nproc_per_node=1 --node_rank=0 \
  --rdzv_id=test_job --rdzv_backend=c10d --rdzv_endpoint=10.100.10.1:29500 \
  --rdzv_conf=is_host=1 \
  train.py
I0406 20:59:43.723000 2129 torch/distributed/run.py:735] Using nproc_per_node=1.
I0406 20:59:43.825000 2129 torch/distributed/launcher/api.py:131] Using default numa options = None
I0406 20:59:43.826000 2129 torch/distributed/launcher/api.py:224] Starting elastic_operator with launch configs:
I0406 20:59:43.826000 2129 torch/distributed/launcher/api.py:224]   entrypoint               : train.py
I0406 20:59:43.826000 2129 torch/distributed/launcher/api.py:224]   min_nodes                : 2
I0406 20:59:43.826000 2129 torch/distributed/launcher/api.py:224]   max_nodes                : 2
I0406 20:59:43.826000 2129 torch/distributed/launcher/api.py:224]   nproc_per_node           : 1
I0406 20:59:43.826000 2129 torch/distributed/launcher/api.py:224]   run_id                   : test_job
I0406 20:59:43.826000 2129 torch/distributed/launcher/api.py:224]   rdzv_backend             : c10d
I0406 20:59:43.826000 2129 torch/distributed/launcher/api.py:224]   rdzv_endpoint            : 10.100.10.1:29500
I0406 20:59:43.826000 2129 torch/distributed/launcher/api.py:224]   rdzv_configs             : {'is_host': 'True', 'timeout': 900}
I0406 20:59:43.826000 2129 torch/distributed/launcher/api.py:224]   max_restarts             : 0
I0406 20:59:43.826000 2129 torch/distributed/launcher/api.py:224]   monitor_interval         : 0.1
I0406 20:59:43.826000 2129 torch/distributed/launcher/api.py:224]   log_dir                  : /tmp/torchelastic_ji8a86n0
I0406 20:59:43.826000 2129 torch/distributed/launcher/api.py:224]   metrics_cfg              : {}
I0406 20:59:43.826000 2129 torch/distributed/launcher/api.py:224]   event_log_handler        : null
I0406 20:59:43.826000 2129 torch/distributed/launcher/api.py:224]   numa_options             : None
I0406 20:59:43.826000 2129 torch/distributed/launcher/api.py:224]   signals_to_handle        : SIGTERM,SIGINT,SIGHUP,SIGQUIT
I0406 20:59:43.826000 2129 torch/distributed/launcher/api.py:224]   duplicate_stdout_filters : []
I0406 20:59:43.826000 2129 torch/distributed/launcher/api.py:224]   duplicate_stderr_filters : []
I0406 20:59:43.826000 2129 torch/distributed/launcher/api.py:224]
I0406 20:59:43.854000 2129 torch/distributed/elastic/agent/server/api.py:898] [default] starting workers for entrypoint: python3
I0406 20:59:43.855000 2129 torch/distributed/elastic/agent/server/api.py:693] [default] Rendezvous'ing worker group
I0406 20:59:51.030000 2129 torch/distributed/elastic/agent/server/api.py:539] [default] Rendezvous complete for workers. Result:
I0406 20:59:51.030000 2129 torch/distributed/elastic/agent/server/api.py:539]   restart_count=0
I0406 20:59:51.030000 2129 torch/distributed/elastic/agent/server/api.py:539]   master_addr=2f3d44d8b0ea
I0406 20:59:51.030000 2129 torch/distributed/elastic/agent/server/api.py:539]   master_port=44403
I0406 20:59:51.030000 2129 torch/distributed/elastic/agent/server/api.py:539]   group_rank=0
I0406 20:59:51.030000 2129 torch/distributed/elastic/agent/server/api.py:539]   group_world_size=2
I0406 20:59:51.030000 2129 torch/distributed/elastic/agent/server/api.py:539]   local_ranks=[0]
I0406 20:59:51.030000 2129 torch/distributed/elastic/agent/server/api.py:539]   role_ranks=[0]
I0406 20:59:51.030000 2129 torch/distributed/elastic/agent/server/api.py:539]   global_ranks=[0]
I0406 20:59:51.030000 2129 torch/distributed/elastic/agent/server/api.py:539]   role_world_sizes=[2]
I0406 20:59:51.030000 2129 torch/distributed/elastic/agent/server/api.py:539]   global_world_sizes=[2]
I0406 20:59:51.030000 2129 torch/distributed/elastic/agent/server/api.py:539]   event_log_handler=null
I0406 20:59:51.030000 2129 torch/distributed/elastic/agent/server/api.py:539]
I0406 20:59:51.030000 2129 torch/distributed/elastic/agent/server/api.py:701] [default] Starting worker group
I0406 20:59:51.030000 2129 torch/distributed/elastic/agent/server/local_elastic_agent.py:299] use_agent_store: True
I0406 20:59:51.031000 2129 torch/distributed/elastic/agent/server/local_elastic_agent.py:195] Environment variable 'TORCHELASTIC_ENABLE_FILE_TIMER' not found. Do not start FileTimerServer.
I0406 20:59:51.031000 2129 torch/distributed/elastic/agent/server/local_elastic_agent.py:239] Environment variable 'TORCHELASTIC_HEALTH_CHECK_PORT' not found. Do not start health check.
2f3d44d8b0ea:2161:2161 [0] NCCL INFO ENV/Plugin: Could not find: libnccl-env.so
2f3d44d8b0ea:2161:2161 [0] NCCL INFO Bootstrap: Using eth0:172.17.0.2<0>
2f3d44d8b0ea:2161:2161 [0] NCCL INFO cudaDriverVersion 12080
2f3d44d8b0ea:2161:2161 [0] NCCL INFO NCCL version 2.28.9+cuda12.9
2f3d44d8b0ea:2161:2161 [0] NCCL INFO Comm config Blocking set to 1
2f3d44d8b0ea:2161:2161 [0] NCCL INFO NET/Plugin: Could not find: libnccl-net.so
2f3d44d8b0ea:2161:2161 [0] NCCL INFO Failed to open libibverbs.so[.1]
2f3d44d8b0ea:2161:2161 [0] NCCL INFO transport/net_ib.cc:852 -> 3
2f3d44d8b0ea:2161:2161 [0] NCCL INFO Failed to initialize NET plugin IB
2f3d44d8b0ea:2161:2161 [0] NCCL INFO NET/Socket : Using [0]eth0:172.17.0.2<0> [1]wg0:10.100.10.1<0>
2f3d44d8b0ea:2161:2161 [0] NCCL INFO Initialized NET plugin Socket
2f3d44d8b0ea:2161:2161 [0] NCCL INFO Assigned NET plugin Socket to comm
2f3d44d8b0ea:2161:2161 [0] NCCL INFO Using network Socket
2f3d44d8b0ea:2161:2161 [0] NCCL INFO ncclCommInitRankConfig comm 0x3afce880 rank 0 nranks 2 cudaDev 0 nvmlDev 0 busId 70 commId 0x622b058bf1d66d35 - Init START
```

This is the second node, rank 1. Ouch! It failed!

```log
root@dd3bc9f2214d:/app# NCCL_DEBUG=INFO LOGLEVEL=DEBUG STEPS=30 torchrun \
  --nnodes=2 --nproc_per_node=1 --node_rank=1 \
  --rdzv_id=test_job --rdzv_backend=c10d --rdzv_endpoint=10.100.10.1:29500 \
  train.py
I0406 20:59:49.909000 131 torch/distributed/run.py:735] Using nproc_per_node=1.
I0406 20:59:49.992000 131 torch/distributed/launcher/api.py:131] Using default numa options = None
I0406 20:59:49.992000 131 torch/distributed/launcher/api.py:224] Starting elastic_operator with launch configs:
I0406 20:59:49.992000 131 torch/distributed/launcher/api.py:224]   entrypoint               : train.py
I0406 20:59:49.992000 131 torch/distributed/launcher/api.py:224]   min_nodes                : 2
I0406 20:59:49.992000 131 torch/distributed/launcher/api.py:224]   max_nodes                : 2
I0406 20:59:49.992000 131 torch/distributed/launcher/api.py:224]   nproc_per_node           : 1
I0406 20:59:49.992000 131 torch/distributed/launcher/api.py:224]   run_id                   : test_job
I0406 20:59:49.992000 131 torch/distributed/launcher/api.py:224]   rdzv_backend             : c10d
I0406 20:59:49.992000 131 torch/distributed/launcher/api.py:224]   rdzv_endpoint            : 10.100.10.1:29500
I0406 20:59:49.992000 131 torch/distributed/launcher/api.py:224]   rdzv_configs             : {'timeout': 900}
I0406 20:59:49.992000 131 torch/distributed/launcher/api.py:224]   max_restarts             : 0
I0406 20:59:49.992000 131 torch/distributed/launcher/api.py:224]   monitor_interval         : 0.1
I0406 20:59:49.992000 131 torch/distributed/launcher/api.py:224]   log_dir                  : /tmp/torchelastic_bxkvajbb
I0406 20:59:49.992000 131 torch/distributed/launcher/api.py:224]   metrics_cfg              : {}
I0406 20:59:49.992000 131 torch/distributed/launcher/api.py:224]   event_log_handler        : null
I0406 20:59:49.992000 131 torch/distributed/launcher/api.py:224]   numa_options             : None
I0406 20:59:49.992000 131 torch/distributed/launcher/api.py:224]   signals_to_handle        : SIGTERM,SIGINT,SIGHUP,SIGQUIT
I0406 20:59:49.992000 131 torch/distributed/launcher/api.py:224]   duplicate_stdout_filters : []
I0406 20:59:49.992000 131 torch/distributed/launcher/api.py:224]   duplicate_stderr_filters : []
I0406 20:59:49.992000 131 torch/distributed/launcher/api.py:224]
I0406 20:59:50.023000 131 torch/distributed/elastic/agent/server/api.py:898] [default] starting workers for entrypoint: python3
I0406 20:59:50.023000 131 torch/distributed/elastic/agent/server/api.py:693] [default] Rendezvous'ing worker group
I0406 20:59:51.032000 131 torch/distributed/elastic/agent/server/api.py:539] [default] Rendezvous complete for workers. Result:
I0406 20:59:51.032000 131 torch/distributed/elastic/agent/server/api.py:539]   restart_count=0
I0406 20:59:51.032000 131 torch/distributed/elastic/agent/server/api.py:539]   master_addr=2f3d44d8b0ea
I0406 20:59:51.032000 131 torch/distributed/elastic/agent/server/api.py:539]   master_port=44403
I0406 20:59:51.032000 131 torch/distributed/elastic/agent/server/api.py:539]   group_rank=1
I0406 20:59:51.032000 131 torch/distributed/elastic/agent/server/api.py:539]   group_world_size=2
I0406 20:59:51.032000 131 torch/distributed/elastic/agent/server/api.py:539]   local_ranks=[0]
I0406 20:59:51.032000 131 torch/distributed/elastic/agent/server/api.py:539]   role_ranks=[1]
I0406 20:59:51.032000 131 torch/distributed/elastic/agent/server/api.py:539]   global_ranks=[1]
I0406 20:59:51.032000 131 torch/distributed/elastic/agent/server/api.py:539]   role_world_sizes=[2]
I0406 20:59:51.032000 131 torch/distributed/elastic/agent/server/api.py:539]   global_world_sizes=[2]
I0406 20:59:51.032000 131 torch/distributed/elastic/agent/server/api.py:539]   event_log_handler=null
I0406 20:59:51.032000 131 torch/distributed/elastic/agent/server/api.py:539]
I0406 20:59:51.032000 131 torch/distributed/elastic/agent/server/api.py:701] [default] Starting worker group
I0406 20:59:51.032000 131 torch/distributed/elastic/agent/server/local_elastic_agent.py:299] use_agent_store: True
I0406 20:59:51.033000 131 torch/distributed/elastic/agent/server/local_elastic_agent.py:195] Environment variable 'TORCHELASTIC_ENABLE_FILE_TIMER' not found. Do not start FileTimerServer.
I0406 20:59:51.033000 131 torch/distributed/elastic/agent/server/local_elastic_agent.py:239] Environment variable 'TORCHELASTIC_HEALTH_CHECK_PORT' not found. Do not start health check.
[W406 20:59:52.360496683 socket.cpp:764] [c10d] The IPv6 network addresses of (2f3d44d8b0ea, 44403) cannot be retrieved (gai error: -2 - Name or service not known).
[W406 20:59:52.816683350 socket.cpp:764] [c10d] The IPv6 network addresses of (2f3d44d8b0ea, 44403) cannot be retrieved (gai error: -2 - Name or service not known).
[W406 20:59:53.625858136 socket.cpp:764] [c10d] The IPv6 network addresses of (2f3d44d8b0ea, 44403) cannot be retrieved (gai error: -2 - Name or service not known).
[W406 20:59:54.305711072 socket.cpp:764] [c10d] The IPv6 network addresses of (2f3d44d8b0ea, 44403) cannot be retrieved (gai error: -2 - Name or service not known).
[W406 20:59:56.854604631 socket.cpp:764] [c10d] The IPv6 network addresses of (2f3d44d8b0ea, 44403) cannot be retrieved (gai error: -2 - Name or service not known).
[W406 20:59:58.498181734 socket.cpp:764] [c10d] The IPv6 network addresses of (2f3d44d8b0ea, 44403) cannot be retrieved (gai error: -2 - Name or service not known).
[W406 21:00:02.234917616 socket.cpp:764] [c10d] The IPv6 network addresses of (2f3d44d8b0ea, 44403) cannot be retrieved (gai error: -2 - Name or service not known).
[W406 21:00:10.139460584 socket.cpp:764] [c10d] The IPv6 network addresses of (2f3d44d8b0ea, 44403) cannot be retrieved (gai error: -2 - Name or service not known).
[W406 21:00:17.403585181 socket.cpp:764] [c10d] The IPv6 network addresses of (2f3d44d8b0ea, 44403) cannot be retrieved (gai error: -2 - Name or service not known).
[W406 21:00:25.605543230 socket.cpp:764] [c10d] The IPv6 network addresses of (2f3d44d8b0ea, 44403) cannot be retrieved (gai error: -2 - Name or service not known).
[W406 21:00:36.109063693 socket.cpp:764] [c10d] The IPv6 network addresses of (2f3d44d8b0ea, 44403) cannot be retrieved (gai error: -2 - Name or service not known).
[W406 21:01:06.103270816 socket.cpp:764] [c10d] The IPv6 network addresses of (2f3d44d8b0ea, 44403) cannot be retrieved (gai error: -2 - Name or service not known).
[W406 21:01:32.333383773 socket.cpp:764] [c10d] The IPv6 network addresses of (2f3d44d8b0ea, 44403) cannot be retrieved (gai error: -2 - Name or service not known).
[W406 21:02:48.901471959 socket.cpp:764] [c10d] The IPv6 network addresses of (2f3d44d8b0ea, 44403) cannot be retrieved (gai error: -2 - Name or service not known).
[W406 21:03:22.483614915 socket.cpp:764] [c10d] The IPv6 network addresses of (2f3d44d8b0ea, 44403) cannot be retrieved (gai error: -2 - Name or service not known).
[W406 21:04:02.944138445 socket.cpp:764] [c10d] The IPv6 network addresses of (2f3d44d8b0ea, 44403) cannot be retrieved (gai error: -2 - Name or service not known).
[W406 21:05:13.293170306 socket.cpp:764] [c10d] The IPv6 network addresses of (2f3d44d8b0ea, 44403) cannot be retrieved (gai error: -2 - Name or service not known).
[W406 21:06:36.582248302 socket.cpp:764] [c10d] The IPv6 network addresses of (2f3d44d8b0ea, 44403) cannot be retrieved (gai error: -2 - Name or service not known).
[W406 21:07:41.396091809 socket.cpp:764] [c10d] The IPv6 network addresses of (2f3d44d8b0ea, 44403) cannot be retrieved (gai error: -2 - Name or service not known).
[W406 21:08:12.654681204 socket.cpp:764] [c10d] The IPv6 network addresses of (2f3d44d8b0ea, 44403) cannot be retrieved (gai error: -2 - Name or service not known).
[W406 21:08:54.086316393 socket.cpp:764] [c10d] The IPv6 network addresses of (2f3d44d8b0ea, 44403) cannot be retrieved (gai error: -2 - Name or service not known).
[E406 21:08:54.086357761 socket.cpp:1028] [c10d] The client socket has timed out after 600000ms while trying to connect to (2f3d44d8b0ea, 44403).
[W406 21:08:54.086697964 TCPStore.cpp:340] [c10d] TCP client failed to connect/validate to host 2f3d44d8b0ea:44403 - retrying (try=0, timeout=600000ms, delay=74560ms): The client socket has timed out after 600000ms while trying to connect to (2f3d44d8b0ea, 44403).
Exception raised from throwTimeoutError at /pytorch/torch/csrc/distributed/c10d/socket.cpp:1030 (most recent call first):
frame #0: c10::Error::Error(c10::SourceLocation, std::__cxx11::basic_string<char, std::char_traits<char>, std::allocator<char> >) + 0x9d (0x7cadc417205d in /usr/local/lib/python3.12/dist-packages/torch/lib/libc10.so)
frame #1: <unknown function> + 0x16daa1c (0x7cace79eea1c in /usr/local/lib/python3.12/dist-packages/torch/lib/libtorch_cpu.so)
frame #2: <unknown function> + 0x6b25497 (0x7cacece39497 in /usr/local/lib/python3.12/dist-packages/torch/lib/libtorch_cpu.so)
frame #3: <unknown function> + 0x6b256cf (0x7cacece396cf in /usr/local/lib/python3.12/dist-packages/torch/lib/libtorch_cpu.so)
frame #4: <unknown function> + 0x6b25b57 (0x7cacece39b57 in /usr/local/lib/python3.12/dist-packages/torch/lib/libtorch_cpu.so)
frame #5: <unknown function> + 0x6a87113 (0x7cacecd9b113 in /usr/local/lib/python3.12/dist-packages/torch/lib/libtorch_cpu.so)
frame #6: c10d::TCPStore::TCPStore(std::__cxx11::basic_string<char, std::char_traits<char>, std::allocator<char> >, c10d::TCPStoreOptions const&) + 0x41d (0x7cacecda1ead in /usr/local/lib/python3.12/dist-packages/torch/lib/libtorch_cpu.so)
frame #7: <unknown function> + 0xecbf51 (0x7cacfc872f51 in /usr/local/lib/python3.12/dist-packages/torch/lib/libtorch_python.so)
frame #8: <unknown function> + 0x411060 (0x7cacfbdb8060 in /usr/local/lib/python3.12/dist-packages/torch/lib/libtorch_python.so)
frame #9: /usr/bin/python3() [0x581e9f]
...
frame #32: _start + 0x25 (0x6576c5 in /usr/bin/python3)

[W406 21:10:08.720419645 socket.cpp:764] [c10d] The IPv6 network addresses of (2f3d44d8b0ea, 44403) cannot be retrieved (gai error: -2 - Name or service not known).
[W406 21:10:48.462175738 socket.cpp:764] [c10d] The IPv6 network addresses of (2f3d44d8b0ea, 44403) cannot be retrieved (gai error: -2 - Name or service not known).
[W406 21:11:42.498706496 socket.cpp:764] [c10d] The IPv6 network addresses of (2f3d44d8b0ea, 44403) cannot be retrieved (gai error: -2 - Name or service not known).
[W406 21:12:12.849740252 socket.cpp:764] [c10d] The IPv6 network addresses of (2f3d44d8b0ea, 44403) cannot be retrieved (gai error: -2 - Name or service not known).
[W406 21:12:43.260536804 socket.cpp:764] [c10d] The IPv6 network addresses of (2f3d44d8b0ea, 44403) cannot be retrieved (gai error: -2 - Name or service not known).
[W406 21:13:41.678641445 socket.cpp:764] [c10d] The IPv6 network addresses of (2f3d44d8b0ea, 44403) cannot be retrieved (gai error: -2 - Name or service not known).
[W406 21:14:14.137672492 socket.cpp:764] [c10d] The IPv6 network addresses of (2f3d44d8b0ea, 44403) cannot be retrieved (gai error: -2 - Name or service not known).
[W406 21:15:41.308584356 socket.cpp:764] [c10d] The IPv6 network addresses of (2f3d44d8b0ea, 44403) cannot be retrieved (gai error: -2 - Name or service not known).
[W406 21:16:48.530796444 socket.cpp:764] [c10d] The IPv6 network addresses of (2f3d44d8b0ea, 44403) cannot be retrieved (gai error: -2 - Name or service not known).
[W406 21:17:19.359159001 socket.cpp:764] [c10d] The IPv6 network addresses of (2f3d44d8b0ea, 44403) cannot be retrieved (gai error: -2 - Name or service not known).
[W406 21:18:45.552803370 socket.cpp:764] [c10d] The IPv6 network addresses of (2f3d44d8b0ea, 44403) cannot be retrieved (gai error: -2 - Name or service not known).
[W406 21:19:24.849525746 socket.cpp:764] [c10d] The IPv6 network addresses of (2f3d44d8b0ea, 44403) cannot be retrieved (gai error: -2 - Name or service not known).
[E406 21:19:24.849598964 socket.cpp:1028] [c10d] The client socket has timed out after 600000ms while trying to connect to (2f3d44d8b0ea, 44403).
[E406 21:19:24.849748526 TCPStore.cpp:328] [c10d] TCP client failed to connect/validate to host 2f3d44d8b0ea:44403 - timed out (try=1, timeout=600000ms): The client socket has timed out after 600000ms while trying to connect to (2f3d44d8b0ea, 44403).
Exception raised from throwTimeoutError at /pytorch/torch/csrc/distributed/c10d/socket.cpp:1030 (most recent call first):
frame #0: c10::Error::Error(c10::SourceLocation, std::__cxx11::basic_string<char, std::char_traits<char>, std::allocator<char> >) + 0x9d (0x7cadc417205d in /usr/local/lib/python3.12/dist-packages/torch/lib/libc10.so)
frame #1: <unknown function> + 0x16daa1c (0x7cace79eea1c in /usr/local/lib/python3.12/dist-packages/torch/lib/libtorch_cpu.so)
frame #2: <unknown function> + 0x6b25497 (0x7cacece39497 in /usr/local/lib/python3.12/dist-packages/torch/lib/libtorch_cpu.so)
frame #3: <unknown function> + 0x6b256cf (0x7cacece396cf in /usr/local/lib/python3.12/dist-packages/torch/lib/libtorch_cpu.so)
frame #4: <unknown function> + 0x6b25b57 (0x7cacece39b57 in /usr/local/lib/python3.12/dist-packages/torch/lib/libtorch_cpu.so)
frame #5: <unknown function> + 0x6a87113 (0x7cacecd9b113 in /usr/local/lib/python3.12/dist-packages/torch/lib/libtorch_cpu.so)
frame #6: c10d::TCPStore::TCPStore(std::__cxx11::basic_string<char, std::char_traits<char>, std::allocator<char> >, c10d::TCPStoreOptions const&) + 0x41d (0x7cacecda1ead in /usr/local/lib/python3.12/dist-packages/torch/lib/libtorch_cpu.so)
frame #7: <unknown function> + 0xecbf51 (0x7cacfc872f51 in /usr/local/lib/python3.12/dist-packages/torch/lib/libtorch_python.so)
frame #8: <unknown function> + 0x411060 (0x7cacfbdb8060 in /usr/local/lib/python3.12/dist-packages/torch/lib/libtorch_python.so)
frame #9: /usr/bin/python3() [0x581e9f]
...
frame #31: __libc_start_main + 0x8b (0x7cadc4e2128b in /lib/x86_64-linux-gnu/libc.so.6)
frame #32: _start + 0x25 (0x6576c5 in /usr/bin/python3)

Traceback (most recent call last):
  File "/app/train.py", line 54, in <module>
    run_training()
  File "/app/train.py", line 18, in run_training
    setup()
  File "/app/train.py", line 11, in setup
    dist.init_process_group(backend="nccl")
  File "/usr/local/lib/python3.12/dist-packages/torch/distributed/c10d_logger.py", line 83, in wrapper
    return func(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.12/dist-packages/torch/distributed/c10d_logger.py", line 97, in wrapper
    func_return = func(*args, **kwargs)
                  ^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.12/dist-packages/torch/distributed/distributed_c10d.py", line 1831, in init_process_group
    store, rank, world_size = next(rendezvous_iterator)
                              ^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.12/dist-packages/torch/distributed/rendezvous.py", line 280, in _env_rendezvous_handler
    store = _create_c10d_store(
            ^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.12/dist-packages/torch/distributed/rendezvous.py", line 190, in _create_c10d_store
    return TCPStore(
           ^^^^^^^^^
torch.distributed.DistNetworkError: The client socket has timed out after 600000ms while trying to connect to (2f3d44d8b0ea, 44403).
E0406 21:19:25.365000 131 torch/distributed/elastic/multiprocessing/api.py:986] failed (exitcode: 1) local_rank: 0 (pid: 161) of binary: /usr/bin/python3
I0406 21:19:25.376000 131 torch/distributed/elastic/multiprocessing/errors/__init__.py:375] ('local_rank %s FAILED with no error file. Decorate your entrypoint fn with @record for traceback info. See: https://pytorch.org/docs/stable/elastic/errors.html', 1)
Traceback (most recent call last):
  File "/usr/local/bin/torchrun", line 6, in <module>
    sys.exit(main())
             ^^^^^^
  File "/usr/local/lib/python3.12/dist-packages/torch/distributed/elastic/multiprocessing/errors/__init__.py", line 362, in wrapper
    return f(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.12/dist-packages/torch/distributed/run.py", line 990, in main
    run(args)
  File "/usr/local/lib/python3.12/dist-packages/torch/distributed/run.py", line 981, in run
    elastic_launch(
  File "/usr/local/lib/python3.12/dist-packages/torch/distributed/launcher/api.py", line 170, in __call__
    return launch_agent(self._config, self._entrypoint, list(args))
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.12/dist-packages/torch/distributed/launcher/api.py", line 317, in launch_agent
    raise ChildFailedError(
torch.distributed.elastic.multiprocessing.errors.ChildFailedError:
============================================================
train.py FAILED
------------------------------------------------------------
Failures:
  <NO_OTHER_FAILURES>
------------------------------------------------------------
Root Cause (first observed failure):
[0]:
  time      : 2026-04-06_21:19:25
  host      : dd3bc9f2214d
  rank      : 1 (local_rank: 0)
  exitcode  : 1 (pid: 161)
  error_file: <N/A>
  traceback : To enable traceback see: https://pytorch.org/docs/stable/elastic/errors.html
============================================================
root@dd3bc9f2214d:/app#
```

The reason is obvious.  The message `TCP client failed to connect/validate to host 2f3d44d8b0ea:44403` shows that the client tried to reach the master via its hostname, which the other node cannot resolve.

Looking at the first node’s log carefully, you see `master_addr=2f3d44d8b0ea`. In other words, the first node advertised itself by hostname, whereas we want it to advertise its IP address.
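
We can reproduce the failure directly. Here is a minimal sketch, run from the second node: calling `socket.getaddrinfo` on the master's container hostname raises the same `gai error: -2` that floods the log above (`2f3d44d8b0ea` is simply the master container's ID).

```python
import socket

# From the second node, the master's container hostname does not resolve,
# matching the "gai error: -2 - Name or service not known" warnings above.
try:
    socket.getaddrinfo("2f3d44d8b0ea", 44403)
except socket.gaierror as err:
    print(err)  # [Errno -2] Name or service not known
```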

Let’s find out where `torchrun` determines this value. The log itself tells us the location: `torch/distributed/elastic/agent/server/api.py:539`, [here](https://github.com/pytorch/pytorch/blob/v2.11.0/torch/distributed/elastic/agent/server/api.py#L539). The line `master_addr = spec.master_addr or rdzv_info.bootstrap_store_info.master_addr` picks the address, and it’s easy to confirm with pdb.

```log
root@2f3d44d8b0ea:/app# NCCL_DEBUG=INFO LOGLEVEL=DEBUG STEPS=30 torchrun   --nnodes=1 --nproc_per_node=1 --node_rank=0   --rdzv_id=test_job --rdzv_backend=c10d --rdzv_endpoint=10.100.10.1:29500   --rdzv_conf=is_host=True   train.py
> /usr/local/bin/torchrun(6)<module>()
-> sys.argv[0] = sys.argv[0].removesuffix('.exe')
(Pdb) b /usr/local/lib/python3.12/dist-packages/torch/distributed/elastic/agent/server/api.py:524
Breakpoint 1 at /usr/local/lib/python3.12/dist-packages/torch/distributed/elastic/agent/server/api.py:524
(Pdb) c
I0406 21:51:17.139000 2208 torch/distributed/run.py:735] Using nproc_per_node=1.
I0406 21:51:17.236000 2208 torch/distributed/launcher/api.py:131] Using default numa options = None
I0406 21:51:17.238000 2208 torch/distributed/launcher/api.py:224] Starting elastic_operator with launch configs:
I0406 21:51:17.238000 2208 torch/distributed/launcher/api.py:224]   entrypoint               : train.py
I0406 21:51:17.238000 2208 torch/distributed/launcher/api.py:224]   min_nodes                : 1
I0406 21:51:17.238000 2208 torch/distributed/launcher/api.py:224]   max_nodes                : 1
I0406 21:51:17.238000 2208 torch/distributed/launcher/api.py:224]   nproc_per_node           : 1
I0406 21:51:17.238000 2208 torch/distributed/launcher/api.py:224]   run_id                   : test_job
I0406 21:51:17.238000 2208 torch/distributed/launcher/api.py:224]   rdzv_backend             : c10d
I0406 21:51:17.238000 2208 torch/distributed/launcher/api.py:224]   rdzv_endpoint            : 10.100.10.1:29500
I0406 21:51:17.238000 2208 torch/distributed/launcher/api.py:224]   rdzv_configs             : {'is_host': 'True', 'timeout': 900}
I0406 21:51:17.238000 2208 torch/distributed/launcher/api.py:224]   max_restarts             : 0
I0406 21:51:17.238000 2208 torch/distributed/launcher/api.py:224]   monitor_interval         : 0.1
I0406 21:51:17.238000 2208 torch/distributed/launcher/api.py:224]   log_dir                  : /tmp/torchelastic_qk97alc2
I0406 21:51:17.238000 2208 torch/distributed/launcher/api.py:224]   metrics_cfg              : {}
I0406 21:51:17.238000 2208 torch/distributed/launcher/api.py:224]   event_log_handler        : null
I0406 21:51:17.238000 2208 torch/distributed/launcher/api.py:224]   numa_options             : None
I0406 21:51:17.238000 2208 torch/distributed/launcher/api.py:224]   signals_to_handle        : SIGTERM,SIGINT,SIGHUP,SIGQUIT
I0406 21:51:17.238000 2208 torch/distributed/launcher/api.py:224]   duplicate_stdout_filters : []
I0406 21:51:17.238000 2208 torch/distributed/launcher/api.py:224]   duplicate_stderr_filters : []
I0406 21:51:17.238000 2208 torch/distributed/launcher/api.py:224]
I0406 21:51:17.265000 2208 torch/distributed/elastic/agent/server/api.py:898] [default] starting workers for entrypoint: python3
I0406 21:51:17.267000 2208 torch/distributed/elastic/agent/server/api.py:693] [default] Rendezvous'ing worker group
> /usr/local/lib/python3.12/dist-packages/torch/distributed/elastic/agent/server/api.py(524)_rendezvous()
-> self._store = store
(Pdb) p spec.master_addr
None
(Pdb) p rdzv_info.bootstrap_store_info.master_addr
'2f3d44d8b0ea'
(Pdb)
```

So it comes from `rdzv_info.bootstrap_store_info.master_addr`. Who put the hostname there? We need to step into the line `spec.rdzv_handler.next_rendezvous()`.

```log
root@2f3d44d8b0ea:/app# NCCL_DEBUG=INFO LOGLEVEL=DEBUG STEPS=30 torchrun   --nnodes=1 --nproc_per_node=1 --node_rank=0   --rdzv_id=test_job --rdzv_backend=c10d --rdzv_endpoint=10.100.10.1:29500   --rdzv_conf=is_host=True   train.py
> /usr/local/bin/torchrun(6)<module>()
-> sys.argv[0] = sys.argv[0].removesuffix('.exe')
(Pdb) b /usr/local/lib/python3.12/dist-packages/torch/distributed/elastic/agent/server/api.py:514
Breakpoint 1 at /usr/local/lib/python3.12/dist-packages/torch/distributed/elastic/agent/server/api.py:514
(Pdb) c
I0406 23:42:36.311000 2320 torch/distributed/run.py:735] Using nproc_per_node=1.
I0406 23:42:36.408000 2320 torch/distributed/launcher/api.py:131] Using default numa options = None
I0406 23:42:36.411000 2320 torch/distributed/launcher/api.py:224] Starting elastic_operator with launch configs:
I0406 23:42:36.411000 2320 torch/distributed/launcher/api.py:224]   entrypoint               : train.py
I0406 23:42:36.411000 2320 torch/distributed/launcher/api.py:224]   min_nodes                : 1
I0406 23:42:36.411000 2320 torch/distributed/launcher/api.py:224]   max_nodes                : 1
I0406 23:42:36.411000 2320 torch/distributed/launcher/api.py:224]   nproc_per_node           : 1
I0406 23:42:36.411000 2320 torch/distributed/launcher/api.py:224]   run_id                   : test_job
I0406 23:42:36.411000 2320 torch/distributed/launcher/api.py:224]   rdzv_backend             : c10d
I0406 23:42:36.411000 2320 torch/distributed/launcher/api.py:224]   rdzv_endpoint            : 10.100.10.1:29500
I0406 23:42:36.411000 2320 torch/distributed/launcher/api.py:224]   rdzv_configs             : {'is_host': 'True', 'timeout': 900}
I0406 23:42:36.411000 2320 torch/distributed/launcher/api.py:224]   max_restarts             : 0
I0406 23:42:36.411000 2320 torch/distributed/launcher/api.py:224]   monitor_interval         : 0.1
I0406 23:42:36.411000 2320 torch/distributed/launcher/api.py:224]   log_dir                  : /tmp/torchelastic_fv9249b9
I0406 23:42:36.411000 2320 torch/distributed/launcher/api.py:224]   metrics_cfg              : {}
I0406 23:42:36.411000 2320 torch/distributed/launcher/api.py:224]   event_log_handler        : null
I0406 23:42:36.411000 2320 torch/distributed/launcher/api.py:224]   numa_options             : None
I0406 23:42:36.411000 2320 torch/distributed/launcher/api.py:224]   signals_to_handle        : SIGTERM,SIGINT,SIGHUP,SIGQUIT
I0406 23:42:36.411000 2320 torch/distributed/launcher/api.py:224]   duplicate_stdout_filters : []
I0406 23:42:36.411000 2320 torch/distributed/launcher/api.py:224]   duplicate_stderr_filters : []
I0406 23:42:36.411000 2320 torch/distributed/launcher/api.py:224]
I0406 23:42:36.443000 2320 torch/distributed/elastic/agent/server/api.py:898] [default] starting workers for entrypoint: python3
I0406 23:42:36.446000 2320 torch/distributed/elastic/agent/server/api.py:693] [default] Rendezvous'ing worker group
> /usr/local/lib/python3.12/dist-packages/torch/distributed/elastic/agent/server/api.py(514)_rendezvous()
-> rdzv_info = spec.rdzv_handler.next_rendezvous()
(Pdb) s
--Call--
> /usr/local/lib/python3.12/dist-packages/torch/distributed/elastic/rendezvous/dynamic_rendezvous.py(1148)next_rendezvous()
-> def next_rendezvous(self) -> RendezvousInfo:
(Pdb) b /usr/local/lib/python3.12/dist-packages/torch/distributed/elastic/rendezvous/dynamic_rendezvous.py:1206
Breakpoint 2 at /usr/local/lib/python3.12/dist-packages/torch/distributed/elastic/rendezvous/dynamic_rendezvous.py:1206
(Pdb) c
> /usr/local/lib/python3.12/dist-packages/torch/distributed/elastic/rendezvous/dynamic_rendezvous.py(1206)next_rendezvous()
-> if self._bootstrap_store_info is None:
(Pdb) p self._bootstrap_store_info
None
(Pdb) self._this_node.addr
'2f3d44d8b0ea'
```

Alright, the hostname is in `self._this_node.addr`, which is later passed to `RendezvousStoreInfo.build` to construct `_bootstrap_store_info`. Who sets `self._this_node`? It happens in [`DynamicRendezvousHandler.__init__`](https://github.com/pytorch/pytorch/blob/v2.11.0/torch/distributed/elastic/rendezvous/dynamic_rendezvous.py#L1086).

```log
root@2f3d44d8b0ea:/app# NCCL_DEBUG=INFO LOGLEVEL=DEBUG STEPS=30 torchrun   --nnodes=1 --nproc_per_node=1 --node_rank=0   --rdzv_id=test_job --rdzv_backend=c10d --rdzv_endpoint=10.100.10.1:29500   --rdzv_conf=is_host=True   train.py
> /usr/local/bin/torchrun(6)<module>()
-> sys.argv[0] = sys.argv[0].removesuffix('.exe')
(Pdb) b /usr/local/lib/python3.12/dist-packages/torch/distributed/elastic/rendezvous/dynamic_rendezvous.py:1072
Breakpoint 1 at /usr/local/lib/python3.12/dist-packages/torch/distributed/elastic/rendezvous/dynamic_rendezvous.py:1072
(Pdb) c
I0406 23:46:37.870000 2351 torch/distributed/run.py:735] Using nproc_per_node=1.
I0406 23:46:37.976000 2351 torch/distributed/launcher/api.py:131] Using default numa options = None
I0406 23:46:37.979000 2351 torch/distributed/launcher/api.py:224] Starting elastic_operator with launch configs:
I0406 23:46:37.979000 2351 torch/distributed/launcher/api.py:224]   entrypoint               : train.py
I0406 23:46:37.979000 2351 torch/distributed/launcher/api.py:224]   min_nodes                : 1
I0406 23:46:37.979000 2351 torch/distributed/launcher/api.py:224]   max_nodes                : 1
I0406 23:46:37.979000 2351 torch/distributed/launcher/api.py:224]   nproc_per_node           : 1
I0406 23:46:37.979000 2351 torch/distributed/launcher/api.py:224]   run_id                   : test_job
I0406 23:46:37.979000 2351 torch/distributed/launcher/api.py:224]   rdzv_backend             : c10d
I0406 23:46:37.979000 2351 torch/distributed/launcher/api.py:224]   rdzv_endpoint            : 10.100.10.1:29500
I0406 23:46:37.979000 2351 torch/distributed/launcher/api.py:224]   rdzv_configs             : {'is_host': 'True', 'timeout': 900}
I0406 23:46:37.979000 2351 torch/distributed/launcher/api.py:224]   max_restarts             : 0
I0406 23:46:37.979000 2351 torch/distributed/launcher/api.py:224]   monitor_interval         : 0.1
I0406 23:46:37.979000 2351 torch/distributed/launcher/api.py:224]   log_dir                  : /tmp/torchelastic_8d10hlwv
I0406 23:46:37.979000 2351 torch/distributed/launcher/api.py:224]   metrics_cfg              : {}
I0406 23:46:37.979000 2351 torch/distributed/launcher/api.py:224]   event_log_handler        : null
I0406 23:46:37.979000 2351 torch/distributed/launcher/api.py:224]   numa_options             : None
I0406 23:46:37.979000 2351 torch/distributed/launcher/api.py:224]   signals_to_handle        : SIGTERM,SIGINT,SIGHUP,SIGQUIT
I0406 23:46:37.979000 2351 torch/distributed/launcher/api.py:224]   duplicate_stdout_filters : []
I0406 23:46:37.979000 2351 torch/distributed/launcher/api.py:224]   duplicate_stderr_filters : []
I0406 23:46:37.979000 2351 torch/distributed/launcher/api.py:224]
> /usr/local/lib/python3.12/dist-packages/torch/distributed/elastic/rendezvous/dynamic_rendezvous.py(1072)__init__()
-> if not settings.run_id:
(Pdb) p node
2f3d44d8b0ea_2351_0
(Pdb) p node.addr
'2f3d44d8b0ea'
(Pdb) bt
  /usr/local/bin/torchrun(7)<module>()
-> sys.exit(main())
  /usr/local/lib/python3.12/dist-packages/torch/distributed/elastic/multiprocessing/errors/__init__.py(362)wrapper()
-> return f(*args, **kwargs)
  /usr/local/lib/python3.12/dist-packages/torch/distributed/run.py(990)main()
-> run(args)
  /usr/local/lib/python3.12/dist-packages/torch/distributed/run.py(981)run()
-> elastic_launch(
  /usr/local/lib/python3.12/dist-packages/torch/distributed/launcher/api.py(170)__call__()
-> return launch_agent(self._config, self._entrypoint, list(args))
  /usr/local/lib/python3.12/dist-packages/torch/distributed/launcher/api.py(284)launch_agent()
-> rdzv_handler=rdzv_registry.get_rendezvous_handler(rdzv_parameters),
  /usr/local/lib/python3.12/dist-packages/torch/distributed/elastic/rendezvous/registry.py(96)get_rendezvous_handler()
-> return handler_registry.create_handler(params)
  /usr/local/lib/python3.12/dist-packages/torch/distributed/elastic/rendezvous/api.py(377)create_handler()
-> handler = creator(params)
  /usr/local/lib/python3.12/dist-packages/torch/distributed/elastic/rendezvous/registry.py(48)_create_c10d_handler()
-> return create_handler(store, backend, params)
  /usr/local/lib/python3.12/dist-packages/torch/distributed/elastic/rendezvous/dynamic_rendezvous.py(1436)create_handler()
-> return DynamicRendezvousHandler.from_backend(
  /usr/local/lib/python3.12/dist-packages/torch/distributed/elastic/rendezvous/dynamic_rendezvous.py(1062)from_backend()
-> return cls(node, settings, backend.name, store, state_holder)
> /usr/local/lib/python3.12/dist-packages/torch/distributed/elastic/rendezvous/dynamic_rendezvous.py(1072)__init__()
-> if not settings.run_id:
(Pdb)
```

And `node` comes from `cls._node_desc_generator.generate(local_addr)` in `from_backend`. Let’s look at this `generate` function.

```log
root@2f3d44d8b0ea:/app# NCCL_DEBUG=INFO LOGLEVEL=DEBUG STEPS=30 torchrun   --nnodes=1 --nproc_per_node=1 --node_rank=0   --rdzv_id=test_job --rdzv_backend=c10d --rdzv_endpoint=10.100.10.1:29500   --rdzv_conf=is_host=True   train.py
> /usr/local/bin/torchrun(6)<module>()
-> sys.argv[0] = sys.argv[0].removesuffix('.exe')
(Pdb) b /usr/local/lib/python3.12/dist-packages/torch/distributed/elastic/rendezvous/dynamic_rendezvous.py:1049
Breakpoint 1 at /usr/local/lib/python3.12/dist-packages/torch/distributed/elastic/rendezvous/dynamic_rendezvous.py:1049
(Pdb) c
I0406 23:48:24.537000 2381 torch/distributed/run.py:735] Using nproc_per_node=1.
I0406 23:48:24.633000 2381 torch/distributed/launcher/api.py:131] Using default numa options = None
I0406 23:48:24.635000 2381 torch/distributed/launcher/api.py:224] Starting elastic_operator with launch configs:
I0406 23:48:24.635000 2381 torch/distributed/launcher/api.py:224]   entrypoint               : train.py
I0406 23:48:24.635000 2381 torch/distributed/launcher/api.py:224]   min_nodes                : 1
I0406 23:48:24.635000 2381 torch/distributed/launcher/api.py:224]   max_nodes                : 1
I0406 23:48:24.635000 2381 torch/distributed/launcher/api.py:224]   nproc_per_node           : 1
I0406 23:48:24.635000 2381 torch/distributed/launcher/api.py:224]   run_id                   : test_job
I0406 23:48:24.635000 2381 torch/distributed/launcher/api.py:224]   rdzv_backend             : c10d
I0406 23:48:24.635000 2381 torch/distributed/launcher/api.py:224]   rdzv_endpoint            : 10.100.10.1:29500
I0406 23:48:24.635000 2381 torch/distributed/launcher/api.py:224]   rdzv_configs             : {'is_host': 'True', 'timeout': 900}
I0406 23:48:24.635000 2381 torch/distributed/launcher/api.py:224]   max_restarts             : 0
I0406 23:48:24.635000 2381 torch/distributed/launcher/api.py:224]   monitor_interval         : 0.1
I0406 23:48:24.635000 2381 torch/distributed/launcher/api.py:224]   log_dir                  : /tmp/torchelastic_9e8y2vz6
I0406 23:48:24.635000 2381 torch/distributed/launcher/api.py:224]   metrics_cfg              : {}
I0406 23:48:24.635000 2381 torch/distributed/launcher/api.py:224]   event_log_handler        : null
I0406 23:48:24.635000 2381 torch/distributed/launcher/api.py:224]   numa_options             : None
I0406 23:48:24.635000 2381 torch/distributed/launcher/api.py:224]   signals_to_handle        : SIGTERM,SIGINT,SIGHUP,SIGQUIT
I0406 23:48:24.635000 2381 torch/distributed/launcher/api.py:224]   duplicate_stdout_filters : []
I0406 23:48:24.635000 2381 torch/distributed/launcher/api.py:224]   duplicate_stderr_filters : []
I0406 23:48:24.635000 2381 torch/distributed/launcher/api.py:224]
> /usr/local/lib/python3.12/dist-packages/torch/distributed/elastic/rendezvous/dynamic_rendezvous.py(1049)from_backend()
-> node = cls._node_desc_generator.generate(local_addr)
(Pdb) p local_addr
None
(Pdb) s
--Call--
> /usr/local/lib/python3.12/dist-packages/torch/distributed/elastic/rendezvous/dynamic_rendezvous.py(261)generate()
-> def generate(self, local_addr: str | None = None) -> _NodeDesc:
(Pdb)
```

`local_addr` is `None`. Now look at the function [`generate`](https://github.com/pytorch/pytorch/blob/v2.11.0/torch/distributed/elastic/rendezvous/dynamic_rendezvous.py#L261): it evaluates `local_addr or socket.getfqdn()`. This is where PyTorch falls back to the hostname, even though the variable name is `local_addr`.
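
You can see the fallback behavior outside of PyTorch as well. A minimal sketch, run inside the container: `socket.getfqdn()` (like `socket.gethostname()`) returns the container ID, which is exactly the value that ended up in `node.addr`.

```python
import socket

# Inside a Docker container the hostname is the container ID and there is no
# fully qualified name, so getfqdn() returns the bare container ID.
print(socket.gethostname())  # '2f3d44d8b0ea' in this environment
print(socket.getfqdn())      # same value, hence node.addr == '2f3d44d8b0ea'
```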

Now the question is how to override it. In other words, how do we put the IP address into `local_addr`? We backtrack a little further. In `create_handler`, `params.local_addr` is passed to `DynamicRendezvousHandler.from_backend`, and `params` is a `RendezvousParameters`. We already know this class, right? It’s where we stored `is_host` earlier.

So we specify `--rdzv_conf=is_host=True,local_addr=10.100.10.1`?

```log
root@2f3d44d8b0ea:/app# NCCL_DEBUG=INFO LOGLEVEL=DEBUG STEPS=30 torchrun   --nnodes=1 --nproc_per_node=1 --node_rank=0   --rdzv_id=test_job --rdzv_backend=c10d --rdzv_endpoint=10.100.10.1:29500   --rdzv_conf=is_host=True,local_addr=10.100.10.1   train.py
I0407 03:21:26.187000 2508 torch/distributed/run.py:735] Using nproc_per_node=1.
I0407 03:21:26.296000 2508 torch/distributed/launcher/api.py:131] Using default numa options = None
I0407 03:21:26.298000 2508 torch/distributed/launcher/api.py:224] Starting elastic_operator with launch configs:
I0407 03:21:26.298000 2508 torch/distributed/launcher/api.py:224]   entrypoint               : train.py
I0407 03:21:26.298000 2508 torch/distributed/launcher/api.py:224]   min_nodes                : 1
I0407 03:21:26.298000 2508 torch/distributed/launcher/api.py:224]   max_nodes                : 1
I0407 03:21:26.298000 2508 torch/distributed/launcher/api.py:224]   nproc_per_node           : 1
I0407 03:21:26.298000 2508 torch/distributed/launcher/api.py:224]   run_id                   : test_job
I0407 03:21:26.298000 2508 torch/distributed/launcher/api.py:224]   rdzv_backend             : c10d
I0407 03:21:26.298000 2508 torch/distributed/launcher/api.py:224]   rdzv_endpoint            : 10.100.10.1:29500
I0407 03:21:26.298000 2508 torch/distributed/launcher/api.py:224]   rdzv_configs             : {'is_host': 'True', 'local_addr': '10.100.10.1', 'timeout': 900}
I0407 03:21:26.298000 2508 torch/distributed/launcher/api.py:224]   max_restarts             : 0
I0407 03:21:26.298000 2508 torch/distributed/launcher/api.py:224]   monitor_interval         : 0.1
I0407 03:21:26.298000 2508 torch/distributed/launcher/api.py:224]   log_dir                  : /tmp/torchelastic_zhdsgvr7
I0407 03:21:26.298000 2508 torch/distributed/launcher/api.py:224]   metrics_cfg              : {}
I0407 03:21:26.298000 2508 torch/distributed/launcher/api.py:224]   event_log_handler        : null
I0407 03:21:26.298000 2508 torch/distributed/launcher/api.py:224]   numa_options             : None
I0407 03:21:26.298000 2508 torch/distributed/launcher/api.py:224]   signals_to_handle        : SIGTERM,SIGINT,SIGHUP,SIGQUIT
I0407 03:21:26.298000 2508 torch/distributed/launcher/api.py:224]   duplicate_stdout_filters : []
I0407 03:21:26.298000 2508 torch/distributed/launcher/api.py:224]   duplicate_stderr_filters : []
I0407 03:21:26.298000 2508 torch/distributed/launcher/api.py:224]
Traceback (most recent call last):
  File "/usr/local/bin/torchrun", line 6, in <module>
    sys.exit(main())
             ^^^^^^
  File "/usr/local/lib/python3.12/dist-packages/torch/distributed/elastic/multiprocessing/errors/__init__.py", line 362, in wrapper
    return f(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.12/dist-packages/torch/distributed/run.py", line 990, in main
    run(args)
  File "/usr/local/lib/python3.12/dist-packages/torch/distributed/run.py", line 981, in run
    elastic_launch(
  File "/usr/local/lib/python3.12/dist-packages/torch/distributed/launcher/api.py", line 170, in __call__
    return launch_agent(self._config, self._entrypoint, list(args))
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.12/dist-packages/torch/distributed/launcher/api.py", line 264, in launch_agent
    rdzv_parameters = RendezvousParameters(
                      ^^^^^^^^^^^^^^^^^^^^^
TypeError: torch.distributed.elastic.rendezvous.api.RendezvousParameters() got multiple values for keyword argument 'local_addr'
root@2f3d44d8b0ea:/app#
```

Close, but it failed with `TypeError: torch.distributed.elastic.rendezvous.api.RendezvousParameters() got multiple values for keyword argument 'local_addr'`, because PyTorch instantiates `RendezvousParameters` as shown below. `local_addr` is already passed explicitly, so we cannot also put it in `config.rdzv_configs`.

```python
    rdzv_parameters = RendezvousParameters(
        backend=config.rdzv_backend,
        endpoint=config.rdzv_endpoint,
        run_id=config.run_id,
        min_nodes=config.min_nodes,
        max_nodes=config.max_nodes,
        local_addr=config.local_addr,
        **config.rdzv_configs,
    )
```

Where does `config.local_addr` come from? It comes from `args.local_addr` in `config_from_args`: `torchrun` accepts a `--local_addr` parameter directly.

Let’s try one more time! This is from the master node.

```log
root@2f3d44d8b0ea:/app# NCCL_DEBUG=INFO LOGLEVEL=DEBUG NCCL_SOCKET_IFNAME=wg0 STEPS=30 torchrun   --nnodes=2 --nproc_per_node=1 --node_rank=0   --rdzv_id=test_job --rdzv_backend=c10d --rdzv_endpoint=10.100.10.1:29500   --rdzv_conf=is_host=1   --local_addr=10.100.10.1   /app/train.py
I0407 03:33:29.820000 2911 torch/distributed/run.py:735] Using nproc_per_node=1.
I0407 03:33:29.926000 2911 torch/distributed/launcher/api.py:131] Using default numa options = None
I0407 03:33:29.928000 2911 torch/distributed/launcher/api.py:224] Starting elastic_operator with launch configs:
I0407 03:33:29.928000 2911 torch/distributed/launcher/api.py:224]   entrypoint               : /app/train.py
I0407 03:33:29.928000 2911 torch/distributed/launcher/api.py:224]   min_nodes                : 2
I0407 03:33:29.928000 2911 torch/distributed/launcher/api.py:224]   max_nodes                : 2
I0407 03:33:29.928000 2911 torch/distributed/launcher/api.py:224]   nproc_per_node           : 1
I0407 03:33:29.928000 2911 torch/distributed/launcher/api.py:224]   run_id                   : test_job
I0407 03:33:29.928000 2911 torch/distributed/launcher/api.py:224]   rdzv_backend             : c10d
I0407 03:33:29.928000 2911 torch/distributed/launcher/api.py:224]   rdzv_endpoint            : 10.100.10.1:29500
I0407 03:33:29.928000 2911 torch/distributed/launcher/api.py:224]   rdzv_configs             : {'is_host': '1', 'timeout': 900}
I0407 03:33:29.928000 2911 torch/distributed/launcher/api.py:224]   max_restarts             : 0
I0407 03:33:29.928000 2911 torch/distributed/launcher/api.py:224]   monitor_interval         : 0.1
I0407 03:33:29.928000 2911 torch/distributed/launcher/api.py:224]   log_dir                  : /tmp/torchelastic_p9blggcp
I0407 03:33:29.928000 2911 torch/distributed/launcher/api.py:224]   metrics_cfg              : {}
I0407 03:33:29.928000 2911 torch/distributed/launcher/api.py:224]   event_log_handler        : null
I0407 03:33:29.928000 2911 torch/distributed/launcher/api.py:224]   numa_options             : None
I0407 03:33:29.928000 2911 torch/distributed/launcher/api.py:224]   signals_to_handle        : SIGTERM,SIGINT,SIGHUP,SIGQUIT
I0407 03:33:29.928000 2911 torch/distributed/launcher/api.py:224]   duplicate_stdout_filters : []
I0407 03:33:29.928000 2911 torch/distributed/launcher/api.py:224]   duplicate_stderr_filters : []
I0407 03:33:29.928000 2911 torch/distributed/launcher/api.py:224]
I0407 03:33:29.956000 2911 torch/distributed/elastic/agent/server/api.py:898] [default] starting workers for entrypoint: python3
I0407 03:33:29.956000 2911 torch/distributed/elastic/agent/server/api.py:693] [default] Rendezvous'ing worker group
I0407 03:33:39.096000 2911 torch/distributed/elastic/agent/server/api.py:539] [default] Rendezvous complete for workers. Result:
I0407 03:33:39.096000 2911 torch/distributed/elastic/agent/server/api.py:539]   restart_count=0
I0407 03:33:39.096000 2911 torch/distributed/elastic/agent/server/api.py:539]   master_addr=10.100.10.1
I0407 03:33:39.096000 2911 torch/distributed/elastic/agent/server/api.py:539]   master_port=35121
I0407 03:33:39.096000 2911 torch/distributed/elastic/agent/server/api.py:539]   group_rank=0
I0407 03:33:39.096000 2911 torch/distributed/elastic/agent/server/api.py:539]   group_world_size=2
I0407 03:33:39.096000 2911 torch/distributed/elastic/agent/server/api.py:539]   local_ranks=[0]
I0407 03:33:39.096000 2911 torch/distributed/elastic/agent/server/api.py:539]   role_ranks=[0]
I0407 03:33:39.096000 2911 torch/distributed/elastic/agent/server/api.py:539]   global_ranks=[0]
I0407 03:33:39.096000 2911 torch/distributed/elastic/agent/server/api.py:539]   role_world_sizes=[2]
I0407 03:33:39.096000 2911 torch/distributed/elastic/agent/server/api.py:539]   global_world_sizes=[2]
I0407 03:33:39.096000 2911 torch/distributed/elastic/agent/server/api.py:539]   event_log_handler=null
I0407 03:33:39.096000 2911 torch/distributed/elastic/agent/server/api.py:539]
I0407 03:33:39.096000 2911 torch/distributed/elastic/agent/server/api.py:701] [default] Starting worker group
I0407 03:33:39.096000 2911 torch/distributed/elastic/agent/server/local_elastic_agent.py:299] use_agent_store: True
I0407 03:33:39.097000 2911 torch/distributed/elastic/agent/server/local_elastic_agent.py:195] Environment variable 'TORCHELASTIC_ENABLE_FILE_TIMER' not found. Do not start FileTimerServer.
I0407 03:33:39.097000 2911 torch/distributed/elastic/agent/server/local_elastic_agent.py:239] Environment variable 'TORCHELASTIC_HEALTH_CHECK_PORT' not found. Do not start health check.
2f3d44d8b0ea:2943:2943 [0] NCCL INFO ENV/Plugin: Could not find: libnccl-env.so
2f3d44d8b0ea:2943:2943 [0] NCCL INFO NCCL_SOCKET_IFNAME set by environment to wg0
2f3d44d8b0ea:2943:2943 [0] NCCL INFO Bootstrap: Using wg0:10.100.10.1<0>
2f3d44d8b0ea:2943:2943 [0] NCCL INFO cudaDriverVersion 12080
2f3d44d8b0ea:2943:2943 [0] NCCL INFO NCCL version 2.28.9+cuda12.9
2f3d44d8b0ea:2943:2943 [0] NCCL INFO Comm config Blocking set to 1
2f3d44d8b0ea:2943:2943 [0] NCCL INFO NET/Plugin: Could not find: libnccl-net.so
2f3d44d8b0ea:2943:2943 [0] NCCL INFO Failed to open libibverbs.so[.1]
2f3d44d8b0ea:2943:2943 [0] NCCL INFO transport/net_ib.cc:852 -> 3
2f3d44d8b0ea:2943:2943 [0] NCCL INFO Failed to initialize NET plugin IB
2f3d44d8b0ea:2943:2943 [0] NCCL INFO NCCL_SOCKET_IFNAME set by environment to wg0
2f3d44d8b0ea:2943:2943 [0] NCCL INFO NET/Socket : Using [0]wg0:10.100.10.1<0>
2f3d44d8b0ea:2943:2943 [0] NCCL INFO Initialized NET plugin Socket
2f3d44d8b0ea:2943:2943 [0] NCCL INFO Assigned NET plugin Socket to comm
2f3d44d8b0ea:2943:2943 [0] NCCL INFO Using network Socket
2f3d44d8b0ea:2943:2943 [0] NCCL INFO ncclCommInitRankConfig comm 0x271b62c0 rank 0 nranks 2 cudaDev 0 nvmlDev 0 busId 70 commId 0x77d5d096c29cef94 - Init START
2f3d44d8b0ea:2943:2943 [0] NCCL INFO RAS client listening socket at ::1<28028>
2f3d44d8b0ea:2943:2943 [0] NCCL INFO Bootstrap timings total 0.187971 (create 0.000071, send 0.000237, recv 0.183873, ring 0.000632, delay 0.000000)
2f3d44d8b0ea:2943:2943 [0] NCCL INFO NCCL_IGNORE_DISABLED_P2P set by environment to 1.
2f3d44d8b0ea:2943:2943 [0] NCCL INFO ncclTopoGetCpuAffinity: Affinity for GPU 0 is empty, ignoring. (GPU affinity =  ; CPU affinity = 0-27).
2f3d44d8b0ea:2943:2943 [0] NCCL INFO comm 0x271b62c0 rank 0 nRanks 2 nNodes 2 localRanks 1 localRank 0 MNNVL 0
2f3d44d8b0ea:2943:2943 [0] NCCL INFO Channel 00/02 : 0 1
2f3d44d8b0ea:2943:2943 [0] NCCL INFO Channel 01/02 : 0 1
2f3d44d8b0ea:2943:2943 [0] NCCL INFO Trees [0] 1/-1/-1->0->-1 [1] -1/-1/-1->0->1
2f3d44d8b0ea:2943:2943 [0] NCCL INFO P2P Chunksize set to 131072
2f3d44d8b0ea:2943:2943 [0] NCCL INFO PROFILER/Plugin: Could not find: libnccl-profiler.so
2f3d44d8b0ea:2943:2943 [0] NCCL INFO Check P2P Type isAllDirectP2p 1 directMode 0 isAllCudaP2p 1
2f3d44d8b0ea:2943:2977 [0] NCCL INFO [Proxy Service] Device 0 CPU core 5
2f3d44d8b0ea:2943:2978 [0] NCCL INFO [Proxy Service UDS] Device 0 CPU core 24
2f3d44d8b0ea:2943:2943 [0] NCCL INFO TUNER/Plugin: Could not find: libnccl-tuner.so
2f3d44d8b0ea:2943:2943 [0] NCCL INFO threadThresholds 8/8/64 | 16/8/64 | 512 | 512
2f3d44d8b0ea:2943:2943 [0] NCCL INFO 2 coll channels, 2 collnet channels, 0 nvls channels, 2 p2p channels, 2 p2p channels per peer
2f3d44d8b0ea:2943:2943 [0] NCCL INFO CC Off, workFifoBytes 1048576
2f3d44d8b0ea:2943:2943 [0] NCCL INFO ncclCommInitRankConfig comm 0x271b62c0 rank 0 nranks 2 cudaDev 0 nvmlDev 0 busId 70 commId 0x77d5d096c29cef94 - Init COMPLETE
2f3d44d8b0ea:2943:2943 [0] NCCL INFO Init timings - ncclCommInitRankConfig: rank 0 nranks 2 total 0.34 (kernels 0.13, alloc 0.00, bootstrap 0.19, allgathers 0.00, topo 0.01, graphs 0.00, connections 0.00, rest 0.00)
2f3d44d8b0ea:2943:2979 [0] NCCL INFO [Proxy Progress] Device 0 CPU core 10
2f3d44d8b0ea:2943:2943 [0] NCCL INFO Channel 00/0 : 1[0] -> 0[0] [receive] via NET/Socket/0
2f3d44d8b0ea:2943:2943 [0] NCCL INFO Channel 01/0 : 1[0] -> 0[0] [receive] via NET/Socket/0
2f3d44d8b0ea:2943:2943 [0] NCCL INFO Channel 00/0 : 0[0] -> 1[0] [send] via NET/Socket/0
2f3d44d8b0ea:2943:2943 [0] NCCL INFO Channel 01/0 : 0[0] -> 1[0] [send] via NET/Socket/0
2f3d44d8b0ea:2943:2943 [0] NCCL INFO Connected all rings, use ring PXN 0 GDR 0
[Rank 0] Starting training...
Step 0 | Loss: 1.4985
Step 10 | Loss: 1.1718
Step 20 | Loss: 1.5557
[Rank 0] Training complete.
2f3d44d8b0ea:2943:2943 [0] NCCL INFO comm 0x271b62c0 rank 0 nranks 2 cudaDev 0 busId 70 - Destroy COMPLETE
2f3d44d8b0ea:2943:2943 [0] NCCL INFO ENV/Plugin: Closing env plugin ncclEnvDefault
I0407 03:33:43.109000 2911 torch/distributed/elastic/agent/server/api.py:917] [default] worker group successfully finished. Waiting 300 seconds for other agents to finish.
I0407 03:33:43.109000 2911 torch/distributed/elastic/agent/server/api.py:970] Local worker group finished (WorkerState.SUCCEEDED). Waiting 300 seconds for other agents to finish
I0407 03:33:43.119000 2911 torch/distributed/elastic/agent/server/api.py:984] Done waiting for other agents. Elapsed: 0.00843501091003418 seconds
```

And this is from the second node.  It all worked!

```log
root@dd3bc9f2214d:/app# NCCL_DEBUG=INFO LOGLEVEL=DEBUG NCCL_SOCKET_IFNAME=wg0 STEPS=30 torchrun   --nnodes=2 --nproc_per_node=1 --node_rank=1   --rdzv_id=test_job --rdzv_backend=c10d --rdzv_endpoint=10.100.10.1:29500   --local_addr=10.100.10.2   /app/train.py
I0407 03:33:37.836000 490 torch/distributed/run.py:735] Using nproc_per_node=1.
I0407 03:33:37.916000 490 torch/distributed/launcher/api.py:131] Using default numa options = None
I0407 03:33:37.917000 490 torch/distributed/launcher/api.py:224] Starting elastic_operator with launch configs:
I0407 03:33:37.917000 490 torch/distributed/launcher/api.py:224]   entrypoint               : /app/train.py
I0407 03:33:37.917000 490 torch/distributed/launcher/api.py:224]   min_nodes                : 2
I0407 03:33:37.917000 490 torch/distributed/launcher/api.py:224]   max_nodes                : 2
I0407 03:33:37.917000 490 torch/distributed/launcher/api.py:224]   nproc_per_node           : 1
I0407 03:33:37.917000 490 torch/distributed/launcher/api.py:224]   run_id                   : test_job
I0407 03:33:37.917000 490 torch/distributed/launcher/api.py:224]   rdzv_backend             : c10d
I0407 03:33:37.917000 490 torch/distributed/launcher/api.py:224]   rdzv_endpoint            : 10.100.10.1:29500
I0407 03:33:37.917000 490 torch/distributed/launcher/api.py:224]   rdzv_configs             : {'timeout': 900}
I0407 03:33:37.917000 490 torch/distributed/launcher/api.py:224]   max_restarts             : 0
I0407 03:33:37.917000 490 torch/distributed/launcher/api.py:224]   monitor_interval         : 0.1
I0407 03:33:37.917000 490 torch/distributed/launcher/api.py:224]   log_dir                  : /tmp/torchelastic_667h1_sh
I0407 03:33:37.917000 490 torch/distributed/launcher/api.py:224]   metrics_cfg              : {}
I0407 03:33:37.917000 490 torch/distributed/launcher/api.py:224]   event_log_handler        : null
I0407 03:33:37.917000 490 torch/distributed/launcher/api.py:224]   numa_options             : None
I0407 03:33:37.917000 490 torch/distributed/launcher/api.py:224]   signals_to_handle        : SIGTERM,SIGINT,SIGHUP,SIGQUIT
I0407 03:33:37.917000 490 torch/distributed/launcher/api.py:224]   duplicate_stdout_filters : []
I0407 03:33:37.917000 490 torch/distributed/launcher/api.py:224]   duplicate_stderr_filters : []
I0407 03:33:37.917000 490 torch/distributed/launcher/api.py:224]
I0407 03:33:37.947000 490 torch/distributed/elastic/agent/server/api.py:898] [default] starting workers for entrypoint: python3
I0407 03:33:37.948000 490 torch/distributed/elastic/agent/server/api.py:693] [default] Rendezvous'ing worker group
I0407 03:33:39.097000 490 torch/distributed/elastic/agent/server/api.py:539] [default] Rendezvous complete for workers. Result:
I0407 03:33:39.097000 490 torch/distributed/elastic/agent/server/api.py:539]   restart_count=0
I0407 03:33:39.097000 490 torch/distributed/elastic/agent/server/api.py:539]   master_addr=10.100.10.1
I0407 03:33:39.097000 490 torch/distributed/elastic/agent/server/api.py:539]   master_port=35121
I0407 03:33:39.097000 490 torch/distributed/elastic/agent/server/api.py:539]   group_rank=1
I0407 03:33:39.097000 490 torch/distributed/elastic/agent/server/api.py:539]   group_world_size=2
I0407 03:33:39.097000 490 torch/distributed/elastic/agent/server/api.py:539]   local_ranks=[0]
I0407 03:33:39.097000 490 torch/distributed/elastic/agent/server/api.py:539]   role_ranks=[1]
I0407 03:33:39.097000 490 torch/distributed/elastic/agent/server/api.py:539]   global_ranks=[1]
I0407 03:33:39.097000 490 torch/distributed/elastic/agent/server/api.py:539]   role_world_sizes=[2]
I0407 03:33:39.097000 490 torch/distributed/elastic/agent/server/api.py:539]   global_world_sizes=[2]
I0407 03:33:39.097000 490 torch/distributed/elastic/agent/server/api.py:539]   event_log_handler=null
I0407 03:33:39.097000 490 torch/distributed/elastic/agent/server/api.py:539]
I0407 03:33:39.097000 490 torch/distributed/elastic/agent/server/api.py:701] [default] Starting worker group
I0407 03:33:39.097000 490 torch/distributed/elastic/agent/server/local_elastic_agent.py:299] use_agent_store: True
I0407 03:33:39.098000 490 torch/distributed/elastic/agent/server/local_elastic_agent.py:195] Environment variable 'TORCHELASTIC_ENABLE_FILE_TIMER' not found. Do not start FileTimerServer.
I0407 03:33:39.098000 490 torch/distributed/elastic/agent/server/local_elastic_agent.py:239] Environment variable 'TORCHELASTIC_HEALTH_CHECK_PORT' not found. Do not start health check.
dd3bc9f2214d:520:520 [0] NCCL INFO ENV/Plugin: Could not find: libnccl-env.so
dd3bc9f2214d:520:520 [0] NCCL INFO cudaDriverVersion 12080
dd3bc9f2214d:520:520 [0] NCCL INFO NCCL_SOCKET_IFNAME set by environment to wg0
dd3bc9f2214d:520:520 [0] NCCL INFO Bootstrap: Using wg0:10.100.10.2<0>
dd3bc9f2214d:520:520 [0] NCCL INFO NCCL version 2.28.9+cuda12.9
dd3bc9f2214d:520:520 [0] NCCL INFO Comm config Blocking set to 1
dd3bc9f2214d:520:520 [0] NCCL INFO NET/Plugin: Could not find: libnccl-net.so
dd3bc9f2214d:520:520 [0] NCCL INFO Failed to open libibverbs.so[.1]
dd3bc9f2214d:520:520 [0] NCCL INFO transport/net_ib.cc:852 -> 3
dd3bc9f2214d:520:520 [0] NCCL INFO Failed to initialize NET plugin IB
dd3bc9f2214d:520:520 [0] NCCL INFO NCCL_SOCKET_IFNAME set by environment to wg0
dd3bc9f2214d:520:520 [0] NCCL INFO NET/Socket : Using [0]wg0:10.100.10.2<0>
dd3bc9f2214d:520:520 [0] NCCL INFO Initialized NET plugin Socket
dd3bc9f2214d:520:520 [0] NCCL INFO Assigned NET plugin Socket to comm
dd3bc9f2214d:520:520 [0] NCCL INFO Using network Socket
dd3bc9f2214d:520:520 [0] NCCL INFO ncclCommInitRankConfig comm 0x26af6510 rank 1 nranks 2 cudaDev 0 nvmlDev 0 busId 70 commId 0x77d5d096c29cef94 - Init START
dd3bc9f2214d:520:520 [0] NCCL INFO RAS client listening socket at ::1<28028>
dd3bc9f2214d:520:520 [0] NCCL INFO Bootstrap timings total 0.007266 (create 0.000047, send 0.001406, recv 0.003567, ring 0.000768, delay 0.000001)
dd3bc9f2214d:520:520 [0] NCCL INFO NCCL_IGNORE_DISABLED_P2P set by environment to 1.
dd3bc9f2214d:520:520 [0] NCCL INFO ncclTopoGetCpuAffinity: Affinity for GPU 0 is empty, ignoring. (GPU affinity =  ; CPU affinity = 0-27).
dd3bc9f2214d:520:520 [0] NCCL INFO comm 0x26af6510 rank 1 nRanks 2 nNodes 2 localRanks 1 localRank 0 MNNVL 0
dd3bc9f2214d:520:520 [0] NCCL INFO Trees [0] -1/-1/-1->1->0 [1] 0/-1/-1->1->-1
dd3bc9f2214d:520:520 [0] NCCL INFO P2P Chunksize set to 131072
dd3bc9f2214d:520:520 [0] NCCL INFO PROFILER/Plugin: Could not find: libnccl-profiler.so
dd3bc9f2214d:520:520 [0] NCCL INFO Check P2P Type isAllDirectP2p 1 directMode 0 isAllCudaP2p 1
dd3bc9f2214d:520:553 [0] NCCL INFO [Proxy Service] Device 0 CPU core 17
dd3bc9f2214d:520:554 [0] NCCL INFO [Proxy Service UDS] Device 0 CPU core 11
dd3bc9f2214d:520:520 [0] NCCL INFO TUNER/Plugin: Could not find: libnccl-tuner.so
dd3bc9f2214d:520:520 [0] NCCL INFO threadThresholds 8/8/64 | 16/8/64 | 512 | 512
dd3bc9f2214d:520:520 [0] NCCL INFO 2 coll channels, 2 collnet channels, 0 nvls channels, 2 p2p channels, 2 p2p channels per peer
dd3bc9f2214d:520:520 [0] NCCL INFO ncclCommInitRankConfig comm 0x26af6510 rank 1 nranks 2 cudaDev 0 nvmlDev 0 busId 70 commId 0x77d5d096c29cef94 - Init COMPLETE
dd3bc9f2214d:520:520 [0] NCCL INFO Init timings - ncclCommInitRankConfig: rank 1 nranks 2 total 0.17 (kernels 0.15, alloc 0.00, bootstrap 0.01, allgathers 0.00, topo 0.01, graphs 0.00, connections 0.00, rest 0.00)
dd3bc9f2214d:520:555 [0] NCCL INFO [Proxy Progress] Device 0 CPU core 27
dd3bc9f2214d:520:520 [0] NCCL INFO Channel 00/0 : 0[0] -> 1[0] [receive] via NET/Socket/0
dd3bc9f2214d:520:520 [0] NCCL INFO Channel 01/0 : 0[0] -> 1[0] [receive] via NET/Socket/0
dd3bc9f2214d:520:520 [0] NCCL INFO Channel 00/0 : 1[0] -> 0[0] [send] via NET/Socket/0
dd3bc9f2214d:520:520 [0] NCCL INFO Channel 01/0 : 1[0] -> 0[0] [send] via NET/Socket/0
dd3bc9f2214d:520:520 [0] NCCL INFO Connected all rings, use ring PXN 0 GDR 0
[Rank 1] Starting training...
[Rank 1] Training complete.
dd3bc9f2214d:520:520 [0] NCCL INFO comm 0x26af6510 rank 1 nranks 2 cudaDev 0 busId 70 - Destroy COMPLETE
dd3bc9f2214d:520:520 [0] NCCL INFO ENV/Plugin: Closing env plugin ncclEnvDefault
I0407 03:33:43.118000 490 torch/distributed/elastic/agent/server/api.py:917] [default] worker group successfully finished. Waiting 300 seconds for other agents to finish.
I0407 03:33:43.118000 490 torch/distributed/elastic/agent/server/api.py:970] Local worker group finished (WorkerState.SUCCEEDED). Waiting 300 seconds for other agents to finish
I0407 03:33:43.120000 490 torch/distributed/elastic/agent/server/api.py:984] Done waiting for other agents. Elapsed: 0.0019180774688720703 seconds
```

### Conclusion

We just confirmed that PyTorch DDP works over a WireGuard network established between Docker containers.  Two special configurations are needed:

* Specify `--rdzv_conf=is_host=1` on the master node, because PyTorch does not look at secondary IP addresses when checking whether the rendezvous endpoint is the node itself
* Specify `--local_addr` on every node so that nodes advertise themselves by IP address instead of by hostname

Specifying these parameters by hand, along with the other standard DDP parameters such as `rdzv_endpoint` or `nnodes`, is error-prone, so I created [a sample Docker image](https://github.com/kinesis-network/docker-image-samples/tree/main/11-torchrun) that semi-automates these configurations.
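
For illustration, here is a minimal sketch of what such automation can look like; it is not the actual contents of the sample image. It derives the node's WireGuard address with a connected UDP socket and fills in the flags discussed above. The `MASTER_IP`, `NODE_RANK`, and `NNODES` environment variables are hypothetical names used only for this sketch.

```python
import os
import socket
import sys

def wireguard_ip(master_ip: str) -> str:
    """Return the local IP the kernel would use to reach master_ip (the wg0 address here)."""
    s = socket.socket(socket.AF_INET, socket.SOCK_DGRAM)
    s.connect((master_ip, 9))  # UDP connect sends no packets; it only picks a route
    ip = s.getsockname()[0]
    s.close()
    return ip

def main() -> None:
    master_ip = os.environ["MASTER_IP"]       # hypothetical, e.g. 10.100.10.1
    node_rank = int(os.environ["NODE_RANK"])  # hypothetical, 0 for the master node
    local_ip = wireguard_ip(master_ip)
    cmd = [
        "torchrun",
        "--nnodes", os.environ.get("NNODES", "2"),
        "--nproc_per_node", "1",
        "--node_rank", str(node_rank),
        "--rdzv_id", "test_job",
        "--rdzv_backend", "c10d",
        "--rdzv_endpoint", f"{master_ip}:29500",
        "--local_addr", local_ip,
    ]
    if node_rank == 0:
        cmd += ["--rdzv_conf", "is_host=1"]
    cmd += sys.argv[1:]  # training script and its arguments
    os.execvp(cmd[0], cmd)

if __name__ == "__main__":
    main()
```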

The question remains: is this a bug in PyTorch that we should report?

My answer is yes.  `--rdzv_conf=is_host=1` is a great workaround, but PyTorch should check all of the IP addresses assigned to the machine.  At the same time, I haven't found a clean solution yet.  It's surprisingly difficult to enumerate all IP addresses on Linux from Python.  One solution an AI suggested is [`fcntl.ioctl`](https://docs.python.org/3/library/fcntl.html#fcntl.ioctl), but I suspect it would be rejected as too "C-style" or not portable enough.  I'll think it through further.
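
For the record, here is a sketch of one way to flesh out that ioctl suggestion, assuming a Linux host: walk `socket.if_nameindex()` and ask the kernel for each interface's IPv4 address with `SIOCGIFADDR`. It works, but it is easy to see why it might be considered too low-level for PyTorch core.

```python
import fcntl
import socket
import struct

SIOCGIFADDR = 0x8915  # from <linux/sockios.h>

def ipv4_addresses() -> dict[str, str]:
    """Map interface name -> IPv4 address for interfaces that have one (Linux only)."""
    s = socket.socket(socket.AF_INET, socket.SOCK_DGRAM)
    result = {}
    for _, name in socket.if_nameindex():
        try:
            packed = fcntl.ioctl(
                s.fileno(),
                SIOCGIFADDR,
                struct.pack("256s", name.encode()[:15]),
            )
            result[name] = socket.inet_ntoa(packed[20:24])
        except OSError:
            pass  # interface is down or has no IPv4 address
    s.close()
    return result

# e.g. {'lo': '127.0.0.1', 'eth0': '172.17.0.2', 'wg0': '10.100.10.1'}
print(ipv4_addresses())
```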

Anyway, Kinesis now supports PyTorch DDP.  If you have a model to train, go to [https://portal.kinesis.network](https://portal.kinesis.network/) and run it!
