
Conversation

alfuyao1986

No description provided.


meta-cla bot commented Aug 25, 2025

Hi @alfuyao1986!

Thank you for your pull request and welcome to our community.

Action Required

In order to merge any pull request (code, docs, etc.), we require contributors to sign our Contributor License Agreement, and we don't seem to have one on file for you.

Process

In order for us to review and merge your suggested changes, please sign at https://code.facebook.com/cla. If you are contributing on behalf of someone else (eg your employer), the individual CLA may not be sufficient and your employer may need to sign the corporate CLA.

Once the CLA is signed, our tooling will perform checks and validations. Afterwards, the pull request will be tagged with CLA signed. The tagging process may take up to 1 hour after signing. Please give it that time before contacting us about it.

If you have received this in error or have any questions, please contact us at cla@meta.com. Thanks!

@tianyu-l (Contributor) left a comment

Please justify the value of this change, following https://github.com/pytorch/torchtitan/blob/main/CONTRIBUTING.md#proof-of-value

In particular, why is fake data better than the default c4 / c4_test?


def __iter__(self) -> Iterator[Tuple[Dict[str, torch.Tensor], torch.Tensor]]:
    while True:
        inputs = torch.randint(
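
The excerpt above is truncated mid-call. A minimal sketch of what such a random-token dataset might look like, assuming hypothetical names (RandomTokenDataset, vocab_size, seq_len) and output keys that are not confirmed by the excerpt:

from typing import Dict, Iterator, Tuple

import torch
from torch.utils.data import IterableDataset


class RandomTokenDataset(IterableDataset):
    """An endless stream of random token ids; no downloads or files needed."""

    def __init__(self, vocab_size: int = 32000, seq_len: int = 2048) -> None:
        self.vocab_size = vocab_size
        self.seq_len = seq_len

    def __iter__(self) -> Iterator[Tuple[Dict[str, torch.Tensor], torch.Tensor]]:
        while True:
            # Draw seq_len + 1 tokens so inputs and labels can be shifted by one.
            tokens = torch.randint(0, self.vocab_size, (self.seq_len + 1,))
            yield {"input": tokens[:-1]}, tokens[1:]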
Contributor

This is fake data, not "synthetic" data.

Author

Oh, how about calling it random data? The goal is to remove the dataset dependency for quick performance benchmarking.

Author

The default c4 has two problems:

  1. Although it is small, the user still has to download it before testing, and in rapid debugging and reruns it is possible to hit the HF request limit. An unstable network is another case that disrupts smooth development. I had to make local changes like this so I could develop without worrying about the dataset; I guess many users have had a similar experience.
  2. With larger models and bigger batch sizes, runs will easily loop back over the data, and for the same reason as 1, it may be rate-limited or time-consuming in many cases for the user to download a very large dataset.

Contributor

A random dataset is usually very useful when debugging the CPU overhead brought by data loading, though I'm not sure we already have such a use case. Multimodal may benefit from a random dataset.

Contributor

@alfuyao1986

Although it is small, the user still has to download it before testing, and in rapid debugging and reruns it is possible to hit the HF request limit. An unstable network is another case that disrupts smooth development. I had to make local changes like this so I could develop without worrying about the dataset; I guess many users have had a similar experience.

We have c4_test stored in the repo
https://github.com/pytorch/torchtitan/tree/main/tests/assets/c4_test

With larger models and bigger batch sizes, runs will easily loop back over the data, and for the same reason as 1, it may be rate-limited or time-consuming in many cases for the user to download a very large dataset.

What would be the advantage of using random / fake data versus looping back on c4_test?

@fegin

Multimodal may benefit from a random dataset.

As we don't have multimodal training, the main thing I'd like to understand is the benefit of adding random data on top of the existing c4_test.

Contributor

A random dataset generally lets you skip the overhead of data loading, such as actually reading from disk; this is unrelated to whether the dataset is large or small. But as mentioned above, this may become more useful once dataloader overhead starts to be a big thing. As for development efficiency, I haven't encountered such an issue, so I should not be the one to answer.

This is solely my opinion.

Author

Oh, given that c4_test is already pre-stored in the repo, it should be fine for most cases; I am actually completely fine with using the pre-stored c4_test dataset. Two more considerations to bring up for discussion:

  1. A random dataset can usually stress the whole stack better, numerically and computationally, than a small repeated dataset, though it is debatable whether this additional stress is practically realistic and necessary.
  2. Other frameworks (MaxText, Megatron-LM) do provide "synthetic/mock" data options for fast benchmarking, so from an ease-of-comparison point of view it may be better to have a matching option.

Contributor

Got it, thanks for the context!

From my perspective, the value of this dataset is somewhat limited, given that we already have c4_test, which doesn't involve randomness and so has become a standard way of numerical testing even when parallelism / world size changes.

That said, if people have a strong opinion about adding this dataset, I'm OK too. In that case, I would suggest making a new builder function & file instead of piggybacking on the existing build_hf_dataloader. I understand that would make it harder to switch to this new dataset from the config, but that's not a good reason to reuse it.
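
A minimal sketch of what such a separate builder could look like, assuming a hypothetical build_random_dataloader in its own file and the RandomTokenDataset sketched earlier in this thread; the name and signature are illustrative assumptions, not the actual torchtitan interface:

from torch.utils.data import DataLoader


def build_random_dataloader(
    batch_size: int, seq_len: int, vocab_size: int
) -> DataLoader:
    # Hypothetical builder kept separate from build_hf_dataloader, as suggested above.
    dataset = RandomTokenDataset(vocab_size=vocab_size, seq_len=seq_len)
    # No sampler or sharding is needed: each rank can draw its own random stream.
    return DataLoader(dataset, batch_size=batch_size)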

Contributor

Yes, we definitely shouldn't use build_hf_dataloader for a random dataset. There is actually another benefit of a random dataset (when it has a deterministic option): debugging checkpoint issues. Given that the dataloader is controlled by another package, having a random dataset with a deterministic option will make debugging checkpoint inconsistency easier; at the very least we can rule out the dataset/dataloader as the problem.
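
A hedged illustration of the deterministic option described above: seeding a torch.Generator makes every rerun produce the same token stream, so any divergence after a checkpoint resume points at model/optimizer state rather than the data. The class extends the RandomTokenDataset sketched earlier; the seed handling is an assumption about how it could be wired, not existing torchtitan code.

import torch


class DeterministicRandomTokenDataset(RandomTokenDataset):
    def __init__(self, vocab_size: int = 32000, seq_len: int = 2048, seed: int = 0) -> None:
        super().__init__(vocab_size=vocab_size, seq_len=seq_len)
        self.seed = seed

    def __iter__(self):
        # A fixed seed reproduces the identical stream on every run, ruling out
        # the dataset/dataloader when debugging checkpoint inconsistency.
        gen = torch.Generator().manual_seed(self.seed)
        while True:
            tokens = torch.randint(
                0, self.vocab_size, (self.seq_len + 1,), generator=gen
            )
            yield {"input": tokens[:-1]}, tokens[1:]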


meta-cla bot commented Aug 25, 2025

Thank you for signing our Contributor License Agreement. We can now accept your code for this (and any) Meta Open Source project. Thanks!

meta-cla bot added the "CLA Signed" label Aug 25, 2025