
Introduce cuda_p2p based fused_all_gather_matmul and fused_matmul_reduce_scatter #126634

Closed
yifuwang wants to merge 9 commits

Conversation

pytorch-bot bot commented May 19, 2024

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/126634

Note: Links to docs will display an error until the docs builds have been completed.

❗ 1 Active SEV

There is 1 currently active SEV. If your PR is affected, please view it below:

✅ You can merge normally! (3 Unrelated Failures)

As of commit 29e6b1f with merge base ff65b18:

FLAKY - The following jobs failed but were likely due to flakiness present on trunk:

UNSTABLE - The following job failed but was likely due to flakiness present on trunk and has been marked as unstable:

This comment was automatically generated by Dr. CI and updates every 15 minutes.

…_matmul_reduce_scatter"


## Context
See context [here](#122163).

## This PR
Introduces `cuda_p2p`-based `fused_all_gather_matmul` and `fused_matmul_reduce_scatter` dispatcher ops, which perform micro-pipelined tensor parallelism (TP) for `all-gather -> matmul` and `matmul -> reduce-scatter` respectively.
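
A minimal call sketch, not taken from this page: the reference formula quoted later in the review implies `fused_all_gather_matmul(A_shard, Bs, gather_dim, group)` returns the gathered tensor plus one matmul output per `B`; the `torch.ops.cuda_p2p` namespace, the string group name, and the `fused_matmul_reduce_scatter` argument order are assumptions.

```python
# Hedged sketch: must run on ranks of a process group backed by cuda_p2p.
import torch

A_shard = torch.randn(128, 64, device="cuda")  # local shard of A
B = torch.randn(64, 256, device="cuda")        # replicated weight

# all-gather -> matmul: the gathered A plus one output per weight in Bs
A, (out,) = torch.ops.cuda_p2p.fused_all_gather_matmul(
    A_shard, [B], 0, "0"  # gather_dim=0; group name assumed to be a string
)

# matmul -> reduce-scatter: the reduce-scattered matmul result
# (argument order assumed by analogy with the op above)
rs_out = torch.ops.cuda_p2p.fused_matmul_reduce_scatter(
    A, B, "sum", 0, "0"   # reduce op, scatter_dim, group name: assumptions
)
```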

Fusion vs. decomposition - in principle, the micro-pipelining could be achieved via decomposition. In practice, however, Inductor today can't handle the decomposed patterns well. So instead of performing the decomposition in Inductor, we fuse the patterns that would be decomposed and dispatch them to corresponding operators that handle the decomposition + micro-pipelining.
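
To make the decomposition concrete, here is a minimal single-process sketch of the data flow being pipelined; `p2p_chunk_from_rank` is a hypothetical stand-in for the backend's P2P copy, and the real op overlaps chunk i+1's copy with chunk i's matmul on separate CUDA streams:

```python
import torch

def p2p_chunk_from_rank(A_shard: torch.Tensor, rank: int) -> torch.Tensor:
    # Hypothetical stand-in: the real cuda_p2p backend would read rank
    # `rank`'s shard directly over NVLink; reusing the local shard lets
    # this sketch run in a single process.
    return A_shard

def all_gather_matmul_decomposed(A_shard, B, group_size):
    # The fused op pipelines this loop: while chunk i's matmul runs, the
    # copy for chunk i+1 proceeds on another stream. Serial here for clarity.
    outs = [p2p_chunk_from_rank(A_shard, r) @ B for r in range(group_size)]
    return torch.cat(outs, dim=0)  # == all_gather(A_shard, dim=0) @ B
```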

cc mrshenli pritamdamania87 zhaojuanmao satgera gqchen aazzolini osalpekar jiayisuse H-Huang kwen2501 awgu penguinwu fegin XilunWu wanchaol fduwjj wz337 tianyu-l wconstab yf225 chauhang d4l3k

@Chillee (Contributor) left a comment:

I do feel like we should be able to provide a higher level API for this 🤔 It would be nice if it could be the same API for both allgather and reduce_scatter.
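
A purely hypothetical sketch of what such a unified entry point could look like; neither the wrapper nor its argument names come from the PR:

```python
# Hypothetical only: illustrates "one API for both collectives" on top of
# the two ops this PR introduces (argument order assumed, as above).
import torch

def fused_collective_matmul(A, B, *, collective: str, dim: int, group_name: str):
    if collective == "all_gather":        # all-gather -> matmul
        _, (out,) = torch.ops.cuda_p2p.fused_all_gather_matmul(
            A, [B], dim, group_name)
        return out
    if collective == "reduce_scatter":    # matmul -> reduce-scatter
        return torch.ops.cuda_p2p.fused_matmul_reduce_scatter(
            A, B, "sum", dim, group_name)
    raise ValueError(f"unknown collective: {collective!r}")
```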

torch/distributed/_cuda_p2p/__init__.py (two resolved review threads)
@wanchaol (Contributor) left a comment:

looks awesome!



```python
@contextmanager
def test_with_non_cuda_p2p_group():
```
A contributor commented:
nit: should these test utils move to the torch.testing package instead?

```python
# Reference/meta formula: all-gather A_shard along gather_dim, then matmul
# the gathered result with each B in Bs.
ag_shape = list(A_shard.shape)
ag_shape[gather_dim] *= group_size
ag_out = A_shard.new_empty(ag_shape)
return ag_out, [ag_out @ B for B in Bs]
```
A contributor asked:
For meta formulas, I'm wondering if this matmul would actually incur computation or just call the matmul meta kernel (I guess it's the latter?).

@yifuwang (Author) replied:
Yeah we are calling the meta kernels to deduce device, shape, and strides.
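
For reference, the latter is easy to see: a matmul between `meta` tensors dispatches to the matmul meta kernel, which computes only output shape, dtype, and strides and never touches data.

```python
import torch

a = torch.empty(128, 64, device="meta")
b = torch.empty(64, 256, device="meta")
out = a @ b                   # meta kernel: shape/stride inference only
print(out.shape, out.device)  # torch.Size([128, 256]) meta
```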

…_matmul_reduce_scatter"


## Context
See context [here](#122163).

## This PR
Introduces `cuda_p2p` based `fused_all_gather_matmul` and `fused_matmul_reduce_scatter` dispatcher ops which performs micro-pipelining TP for `all-gather -> matmul` and `matmul -> reduce-scatter` respectively.

Fusion vs. decomposition - in principle, the micro-pipelining is achieved via decomposition. However, in practice, today Inductor can't deal with the decomposed patterns well. So instead performing decomposition in Inductor, we fuse the patterns to be decomposed and dispatch them to corresponding operators that handle decomposition + micropipelining.

cc mrshenli pritamdamania87 zhaojuanmao satgera gqchen aazzolini osalpekar jiayisuse H-Huang kwen2501 awgu penguinwu fegin XilunWu wanchaol fduwjj wz337 tianyu-l wconstab yf225 chauhang d4l3k

[ghstack-poisoned]
…_matmul_reduce_scatter"


## Context
See context [here](#122163).

## This PR
Introduces `cuda_p2p` based `fused_all_gather_matmul` and `fused_matmul_reduce_scatter` dispatcher ops which performs micro-pipelining TP for `all-gather -> matmul` and `matmul -> reduce-scatter` respectively.

Fusion vs. decomposition - in principle, the micro-pipelining is achieved via decomposition. However, in practice, today Inductor can't deal with the decomposed patterns well. So instead performing decomposition in Inductor, we fuse the patterns to be decomposed and dispatch them to corresponding operators that handle decomposition + micropipelining.

cc mrshenli pritamdamania87 zhaojuanmao satgera gqchen aazzolini osalpekar jiayisuse H-Huang kwen2501 awgu penguinwu fegin XilunWu wanchaol fduwjj wz337 tianyu-l wconstab yf225 chauhang d4l3k

[ghstack-poisoned]
…_matmul_reduce_scatter"


## Context
See context [here](#122163).

## This PR
Introduces `cuda_p2p` based `fused_all_gather_matmul` and `fused_matmul_reduce_scatter` dispatcher ops which performs micro-pipelining TP for `all-gather -> matmul` and `matmul -> reduce-scatter` respectively.

Fusion vs. decomposition - in principle, the micro-pipelining is achieved via decomposition. However, in practice, today Inductor can't deal with the decomposed patterns well. So instead performing decomposition in Inductor, we fuse the patterns to be decomposed and dispatch them to corresponding operators that handle decomposition + micropipelining.

cc mrshenli pritamdamania87 zhaojuanmao satgera gqchen aazzolini osalpekar jiayisuse H-Huang kwen2501 awgu penguinwu fegin XilunWu wanchaol fduwjj wz337 tianyu-l wconstab yf225 chauhang d4l3k

[ghstack-poisoned]
yifuwang added a commit that referenced this pull request on May 21, 2024.
…_matmul_reduce_scatter"


## Context
See context [here](#122163).

## This PR
Introduces `cuda_p2p` based `fused_all_gather_matmul` and `fused_matmul_reduce_scatter` dispatcher ops which performs micro-pipelining TP for `all-gather -> matmul` and `matmul -> reduce-scatter` respectively.

Fusion vs. decomposition - in principle, the micro-pipelining is achieved via decomposition. However, in practice, today Inductor can't deal with the decomposed patterns well. So instead performing decomposition in Inductor, we fuse the patterns to be decomposed and dispatch them to corresponding operators that handle decomposition + micropipelining.

cc mrshenli pritamdamania87 zhaojuanmao satgera gqchen aazzolini osalpekar jiayisuse H-Huang kwen2501 awgu penguinwu fegin XilunWu wanchaol fduwjj wz337 tianyu-l wconstab yf225 chauhang d4l3k

[ghstack-poisoned]
…_matmul_reduce_scatter"


## Context
See context [here](#122163).

## This PR
Introduces `cuda_p2p` based `fused_all_gather_matmul` and `fused_matmul_reduce_scatter` dispatcher ops which performs micro-pipelining TP for `all-gather -> matmul` and `matmul -> reduce-scatter` respectively.

Fusion vs. decomposition - in principle, the micro-pipelining is achieved via decomposition. However, in practice, today Inductor can't deal with the decomposed patterns well. So instead performing decomposition in Inductor, we fuse the patterns to be decomposed and dispatch them to corresponding operators that handle decomposition + micropipelining.

cc mrshenli pritamdamania87 zhaojuanmao satgera gqchen aazzolini osalpekar jiayisuse H-Huang kwen2501 awgu penguinwu fegin XilunWu wanchaol fduwjj wz337 tianyu-l wconstab yf225 chauhang d4l3k

[ghstack-poisoned]
@yifuwang added the `topic: not user facing` (topic category) label on May 29, 2024.
@yifuwang (Author) commented:

@pytorchbot merge

@pytorch-bot added the `ciflow/trunk` (trigger trunk jobs on your pull request) label on May 30, 2024.
@pytorchmergebot (Collaborator) commented:

Merge started

Your change will be merged once all checks pass (ETA 0-4 Hours).

Learn more about merging in the wiki.

Questions? Feedback? Please reach out to the PyTorch DevX Team

Advanced Debugging: check the merge workflow status here.

Labels

ciflow/trunk (trigger trunk jobs on your pull request), Merged, oncall: distributed (add this issue/PR to distributed oncall triage queue), topic: not user facing (topic category)

4 participants