We fine-tuned Llama 405B on AMD GPUs

publish.obsidian.md

209 points by felarof 6 hours ago

felarof 5 hours ago

Hey HN, we recently fine-tuned the llama3.1 405B model on 8xAMD MI300x GPUs using JAX instead of PyTorch. JAX's advanced sharding APIs allowed us to achieve great performance. Check out our blog post to learn about the cool sharding tricks we used. We've also open-sourced the code: https://github.com/felafax/felafax

We're a small startup building AI infra for fine-tuning and serving LLMs on non-NVIDIA hardware (TPUs, AMD, Trainium).

Problem: Many companies are trying to get PyTorch working on AMD GPUs, but we believe this is a treacherous path. PyTorch is deeply intertwined with the NVIDIA ecosystem in a lot of ways (e.g., `torch.cuda` or scaled_dot_product_attention is an NVIDIA CUDA kernel exposed as a PyTorch function). So, to get PyTorch code running on non-NVIDIA hardware, there's a lot of "de-NVIDIAfying" that needs to be done.

Solution: We believe JAX is a better fit for non-NVIDIA hardware. In JAX, ML model code compiles to hardware-independent HLO graphs, which are then optimized by the XLA compiler before hardware-specific optimization. This clean separation allowed us to run the same LLaMA3 JAX code both on Google TPUs and AMD GPUs with no changes.
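
For illustration, a minimal JAX sketch (not the felafax training code) of what that separation buys: the same jitted function and sharding spec run against whatever devices the local backend exposes, whether those are TPU cores or ROCm GPUs.

    import numpy as np
    import jax
    import jax.numpy as jnp
    from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

    # jax.devices() returns whatever accelerators the backend exposes
    # (TPU cores, ROCm GPUs, or CPU); nothing below names the hardware.
    mesh = Mesh(np.asarray(jax.devices()), axis_names=("data",))
    batch_sharding = NamedSharding(mesh, P("data", None))  # split the batch axis across devices

    @jax.jit
    def forward(w, x):
        # jit traces this to an HLO graph; XLA then compiles it for whichever backend it runs on
        return jnp.tanh(x @ w)

    x = jax.device_put(jnp.ones((8 * len(jax.devices()), 128)), batch_sharding)
    w = jnp.ones((128, 64))  # small enough to just replicate
    print(forward(w, x).shape)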

Our strategy as a company is to invest upfront in porting models to JAX, then leverage its framework and XLA kernels to extract maximum performance from non-NVIDIA backends. This is why we first ported Llama 3.1 from PyTorch to JAX, and now the same JAX model works great on TPUs and runs perfectly on AMD GPUs.

We'd love to hear your thoughts on our vision and repo!

  • ipsum2 3 hours ago

    I, and several others, had no problem running PyTorch on AMD GPUs, with no code changes from CUDA. Check out MosaicML's blog posts: https://www.databricks.com/blog/training-llms-scale-amd-mi25...

    • felarof 3 hours ago

      Ahh, interesting, will take a look!

      Curious what the steps are to run PyTorch on AMD (does it work out of the box with the PyTorch+ROCm docker image)? Does torch.compile work smoothly?

      • anthonix1 2 hours ago

        Yeah I would suggest taking a look at PyTorch on AMD before saying stuff like "scaled_dot_product_attention is an NVIDIA CUDA kernel exposed as a PyTorch function", because that is demonstrably false.

        Also, FWIW, I would suggest getting a small Llama 3.1 model training fast before trying to do a big 405B model -- faster to iterate and almost everything you'll learn on the small models will scale to the 405B.

        • felarof 2 hours ago

          Thanks for the feedback! I appreciate you pointing that out. My understanding was based on the PyTorch documentation for scaled_dot_product_attention (https://pytorch.org/docs/stable/generated/torch.nn.functiona...): "The function may call optimized kernels for improved performance when using the CUDA backend. For all other backends, the PyTorch implementation will be used."

          I was also trying to make a broader point about the lack of transparency (in performance and lower-level implementation) in PyTorch when running on NVIDIA vs. non-NVIDIA hardware.

          • anthonix1 43 minutes ago

            I just asked an instance of Hermes 3 Llama 3.1 405B running on AMD GPUs: "does pytorch scaled dot product attention run on AMD GPUs?":

            "Yes, PyTorch's scaled dot product attention can run on AMD GPUs. PyTorch supports AMD GPUs through the ROCm (Radeon Open Compute) platform. To use PyTorch with an AMD GPU, you need to install the ROCm version of PyTorch, which is specifically built for AMD GPUs [...]"

            And it proceeded to give the steps to follow to install and run, with example Python code to demonstrate it. One slight nitpick is that it referred to an older URL with the --index-url to install torch with pip, but otherwise it was correct.
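
            For anyone who wants to check this directly, a minimal smoke test (the note about fused kernels is my understanding of current PyTorch, not a quote from the docs): the same call runs on CUDA builds, ROCm builds, or CPU, using fused kernels where the backend provides them and the reference implementation otherwise.

                import torch
                import torch.nn.functional as F

                # Same Python call regardless of backend; "cuda" is also the device
                # name used by ROCm builds of PyTorch, where it maps to HIP.
                device = "cuda" if torch.cuda.is_available() else "cpu"
                q = torch.randn(1, 8, 128, 64, device=device)  # (batch, heads, seq_len, head_dim)
                k = torch.randn_like(q)
                v = torch.randn_like(q)

                out = F.scaled_dot_product_attention(q, k, v, is_causal=True)
                print(out.shape, out.device)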

    • mistymountains 2 hours ago

      Again, the problem is custom kernels in CUDA. It’s not straightforward for many applications (LLMs are probably the most straightforward).

  • anthonix1 3 hours ago

    Does JAX have its own implementations of matmul, flash attention, etc.? Or does it use the ROCm implementations like PyTorch does (e.g., hipBLASLt, Composable Kernel FA, etc.)?

    Not too familiar with JAX, but the abysmal PyTorch training perf on MI300x is in large part attributable to the slow perf of the ROCm libraries it is using under the hood.

  • germanjoey 4 hours ago

    How are you verifying accuracy for your JAX port of Llama 3.1?

    IMHO, the main reason to use pytorch is actually that the original model used pytorch. What can seem to be identical logic between different model versions may actually cause model drift when infinitesimal floating point errors accumulate due to the huge scale of the data. My experience is that debugging an accuracy mismatch like this in a big model is a torturous ordeal beyond the 10th circle of hell.

    • felarof an hour ago

      Good question. We used a new AI+math-based testing tool (benchify.com) to run comparison tests, but we are working on building more robust infrastructure for this, since translating models from PyTorch to JAX is core to our strategy.

      That said, translating models from one framework to another is not an uncommon path. HuggingFace translates Google's Gemma family of models from JAX to PyTorch, and a ton of people use those ports.

    • credit_guy 3 hours ago

      When you say "model versions", do you mean different quantizations of the model? Then it's not floating point errors that accumulate. Different quantizations of the model are different models. People will call such a model something like Meta-Llama-3.1-8B-Instruct--q4_0, claiming that it's just a "version" of the Meta-Llama-3.1-8B-Instruct. But it's just a lie. It's not the same model, and you should not expect the same results. There is no reason to debug the differences, what exactly would you expect to find, and what action would you envision to take once you find what you are looking for? However, is the quantized version still a useful LLM? Absolutely. Most people don't have an A100 to run the original model, so a quantized version is better than nothing.

    • srcreigh 3 hours ago

      Very fascinating, can you explain more about a time when this happened?

      Like what area was affected by FP errors, why were they introduced (was it something like refactoring PyTorch code?), and how was this determined to be the cause?

  • llm_trw 4 hours ago

    Does this work on the consumer grade cards like the 7900 XTX?

    And by work I don't mean: spend two weeks trying to get the drivers set up and never update the server again.

  • cameron_b 4 hours ago

    I'm glad to see a full implementation on AMD hardware.

    I'm not familiar with JAX, but the idea of an abstraction layer that makes it easier to get working on whatever hardware is available seems really valuable. Bringing back some competitiveness to the ecosystem will be a big win for workload mobility.

    I suspect that price/performance across implementations will be highly dependent on contract details, but do you intend to publish some comparisons in the future?

  • anthonix1 3 hours ago

    Any direct comparisons to 8xH100? 2 toks/sec seems very slow!

    I haven't done any LoRA training on MI300x myself, but I have done Llama 3.1 full training on 8xMI300x and got pretty close to 8xH100 performance with my own kernels (ROCm is just too slow).

    • felarof an hour ago

      Oops, my calculation was wrong. Let me add an edit to the blog, thanks for pointing it out!

      My train step was taking 30s.

      And I was using a batch size of 16 and a sequence length of 64, which works out to (16*64/30) ≈ 34 tokens per second (for fine-tuning in JAX eager mode).

      (I haven't done comparison with 8XH100)

  • jgalt212 5 hours ago

    Is there some cost rule of thumb to compare Nvidia, AMD, and Google TPU?

    • felarof 3 hours ago

      Good question. There's no good single metric, given that performance depends on the software stack (JAX vs PyTorch) + optimizations.

      But my take: performance per dollar of TPU > AMD > NVIDIA.

  • ngcc_hk 5 hours ago

    Given it is a migration, is there an actual comparison of the same model on PyTorch vs your version? The comparison table there seems to be on the technical side.

    Also any technical issues encountered?

    • felarof an hour ago

      We have a few technical issues that we still need to address:

      1) This entire fine-tuning run was done in JAX eager mode. I kept running out of memory (OOM) when trying to `jax.jit` the entire training step. Even gradual `jax.jit` didn't work.

      2) The current version doesn't have gradient accumulation, and with a batch size of just 16, that's not ideal. I'm working on implementing gradient accumulation next (a rough sketch follows this list).

      3) We still haven't found a good way to load large sequence-length data (like 32k sequence length). Currently, before sharding the training batch across GPUs, it ends up loading the entire batch onto a single GPU’s VRAM and causes OOM issues.
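
      For (2), a minimal gradient-accumulation sketch in JAX (hypothetical names and a pluggable loss_fn, not the felafax implementation): split the batch into micro-batches, sum gradients with lax.scan, and average at the end.

          import jax
          import jax.numpy as jnp

          def accumulated_grads(loss_fn, params, batch, num_microbatches):
              # Reshape the leading (batch) axis of every array into micro-batches.
              micro = jax.tree_util.tree_map(
                  lambda x: x.reshape((num_microbatches, -1) + x.shape[1:]), batch)

              def body(carry, mb):
                  acc_loss, acc_grads = carry
                  loss, grads = jax.value_and_grad(loss_fn)(params, mb)
                  acc_grads = jax.tree_util.tree_map(jnp.add, acc_grads, grads)
                  return (acc_loss + loss, acc_grads), None

              init = (jnp.zeros(()), jax.tree_util.tree_map(jnp.zeros_like, params))
              (total_loss, total_grads), _ = jax.lax.scan(body, init, micro)
              mean = lambda t: t / num_microbatches
              return mean(total_loss), jax.tree_util.tree_map(mean, total_grads)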

manojlds 4 hours ago

Thought this was a post from Obsidian at first. Why haven't they done the GitHub.com vs GitHub.io thing yet?

  • codetrotter 3 hours ago

    Looking at the URL has me thinking that this confusion would be resolved if HN added a small piece of logic to treat the domain publish.obsidian.md specially, just like it already does for pages served under forbes.com/sites, which are not written by Forbes staff themselves.

    So instead of showing the domain as obsidian.md, HN would show the domain for this link as publish.obsidian.md

    Maybe something for dang to consider if he sees this comment?

  • gbraad 4 hours ago

    Same thought here. Why would Obsidian bother with AI? Oh wait, this is publish? So this is what $8 per month gets you? I am amazed, as I would have at least expected a subhost: [username].publish.obsidian.md

    • felarof 3 hours ago

      Yeah, used Obsidian Publish.

      But I'm struggling to get a custom domain to work with it (I've emailed support).

abalaji 5 hours ago

@dang: could we get the URL to include the username, since this isn't about Obsidian itself but rather a user-generated blog?

  • m00x 4 hours ago

    It's strange that HN didn't include the full domain "publish.obsidian.md".

  • meiraleal 4 hours ago

    That's something obsidian should fix if they care about not looking like they are being impersonated on HN.

    • viraptor an hour ago

      Obsidian can't do anything about it. It's HN chopping up the url

3abiton 5 hours ago

Firstly, great work! I dabbled with AMD GPUs and ROCm support a year ago, and it was obvious AMD was still a long way from catching up with Nvidia. While opting for JAX is an interesting approach, what were the challenges for you in deviating from PyTorch (being the standard library for ML)?

  • felarof 3 hours ago

    A few weeks ago, I did a Show HN explaining our journey: https://news.ycombinator.com/item?id=41512142.

    We initially started with the goal of fine-tuning LLaMA 3 on TPUs, but PyTorch XLA was clunky, so we decided to rewrite the model in JAX. That said, as mentioned earlier in the thread, we also believe JAX is a better platform for non-NVIDIA GPUs and want to build on JAX+OpenXLA as the foundation of our infra for non-NVIDIA hardware.

  • 6y56h56 5 hours ago

    I cannot get AMD ROCm running on my Debian 12 system, which is what I think is causing Ollama to use the CPU instead of the GPU. So I guess there is still a long way to go.

    • jchw 2 hours ago

      At the risk of pissing people off, I think you may be better served by a distribution that provides a more up-to-date kernel. Debian 12 will give you Linux 6.1 LTS, which is probably OK if you're using an older Radeon card, but I've heard support for the 7900 XT/X series is a bit dicey and beyond that (e.g. Radeon 890M) non-existent.

      If there were improvements on the AMDGPU DRM driver side, you would not see them in Debian any time soon, as the 6.1 LTS kernel will be stuck with roughly whatever shipped January of last year. This is just a shortcoming in the Linux kernel, due to its lack of any kind of stable ABI for drivers.

      Of course it is possible this would help nothing or even hurt. My experience running stable (or even newer) kernels has been quite good, though. I run stable or newer across a few devices and run into hiccups not more than once every few years, which is definitely worth it to be able to get new driver improvements years in advance.

      (FWIW Debian is not even supported by ROCm[1]... although distros with even older kernels are. But, even if ROCm works, I can't imagine you will get ideal hardware support when running older kernels. I am not sure if ROCm has some workaround for enterprise Linux distributions specifically, but it feels like they must, given how many of their customers in the datacenter are likely to want to use them.)

      [1]: https://rocm.docs.amd.com/en/latest/compatibility/compatibil...

    • llm_trw 2 hours ago

      Like everything in machine learning, it only really runs on Ubuntu 22.04. Anything else is unsupported and you need to spend weeks tinkering to get it to work, then never upgrade.

    • ants_everywhere 4 hours ago

      I've had more luck with the ROCm docker container. I run it via k8s. It was pretty painless to set up and has been mostly painless since. Prior to that it was nearly impossible to get Jax running reliably on ROCm.

      Even with the container, you have to be careful installing Python libraries because they can still break things.

      • lenova 4 hours ago

        I just recently went down the AMD GPU + ROCm rabbit hole as well. ROCm 6.2 was just released in August of this year and brings much better support, though, as the above poster mentioned, it isn't merged into most recent OSes yet.

        This Github repo is good for tracking the latest Ubuntu + ROCm install process: https://github.com/nktice/AMD-AI

        • latchkey an hour ago

          That's a nice repo of random installation notes. Very helpful, thanks!

    • superkuh 4 hours ago

      You'd probably have a lot better luck using llama.cpp's Vulkan acceleration (not ROCm) as the backend for Ollama. It is incomparably easier to set up and maintain than ROCm. You can actually do it on your computer's normal OS instead of inside a bunch of containers/VMs where the system libs are entirely customized to run just that one application.

      AMD's support window for consumer cards is very, very short. By the time the stack is stable enough for a new card to run, the card is no longer supported. In 2021 I bought an AMD GPU that had come out 3 years earlier, and 1 year after I bought it (4 years since release) they dropped ROCm support.

latchkey 5 hours ago

Nice work! I was just playing with the inference side of things with 405B myself this weekend [0].

I'm not convinced that 'torch.cuda' is really that bad, since the AMD version of PyTorch just translates that for you. It's more of a naming problem than anything. Fact is that it is just as easy to grab the rocm:pytorch container as it is the rocm:jax container.
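
For what it's worth, a quick sanity check on a ROCm build of PyTorch (assuming the ROCm wheel or the rocm:pytorch container is what's installed):

    import torch

    # On ROCm builds the torch.cuda namespace is backed by HIP, so the
    # CUDA-sounding calls work unchanged on AMD GPUs.
    print(torch.version.hip)              # a HIP version string on ROCm builds, None on CUDA builds
    print(torch.cuda.is_available())      # True when an AMD GPU is visible
    print(torch.cuda.get_device_name(0))  # reports the AMD device name

    x = torch.randn(1024, 1024, device="cuda")  # lands on the AMD GPU under ROCm
    print((x @ x.T).sum())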

I don't see very many numbers posted. What MFU did you get?

[0] https://x.com/HotAisle/status/1837580046732874026

  • felarof 3 hours ago

    Nice!

    I need to calculate MFU. GPU and VRAM details can be found in the repo: https://dub.sh/amd-405b-res.

    I plan to reattempt the training run next weekend and JIT the entire training step to calculate MFU then.
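
    For reference, a back-of-the-envelope estimate using the usual ~6 * params FLOPs-per-token approximation for forward+backward (the MI300X peak-FLOPS number below is an assumption, and LoRA does less work than full fine-tuning, so treat it as a ballpark only):

        # Rough MFU estimate, not a measured number.
        params         = 405e9          # Llama 3.1 405B
        tokens_per_sec = 16 * 64 / 30   # batch 16, seq len 64, ~30 s per step (from this thread)
        num_gpus       = 8
        peak_flops_gpu = 1.3e15         # assumed ~1.3 PFLOPS BF16 peak per MI300X

        achieved_flops = 6 * params * tokens_per_sec        # ~6N FLOPs per trained token
        mfu = achieved_flops / (num_gpus * peak_flops_gpu)
        print(f"MFU ~= {mfu:.1%}")                          # roughly 0.8% with these numbers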

yeahwhatever10 5 hours ago

Where is the performance data?

  • felarof 3 hours ago

    (author here, sorry for the delay in replying, was stuck in back-to-back meetings)

    I updated our github repo to include GPU, VRAM utilization data (https://github.com/felafax/felafax?tab=readme-ov-file#amd-40...)

    Note: we couldn't run the JIT-compiled version of the 405B model due to our code/VRAM constraints (we need to investigate this further). The entire training run was executed in JAX eager mode, so there is significant potential for performance improvements.

    GPU utilization across the board was still ~30-40% even with eager mode, which is quite good! With JIT, I think the GPU util can easily shoot up to ~50-60%.
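
    One common pattern for trimming peak memory when jitting a whole train step is buffer donation; a minimal sketch below, assuming an optax optimizer and a stand-in loss_fn (illustrative only, not the felafax code):

        from functools import partial
        import jax
        import jax.numpy as jnp
        import optax

        def loss_fn(params, batch):
            # Stand-in loss for the sketch; the real one would run the model.
            preds = batch["x"] @ params["w"]
            return jnp.mean((preds - batch["y"]) ** 2)

        optimizer = optax.adamw(1e-4)

        # Donating params/opt_state lets XLA reuse their buffers for the updated
        # values instead of keeping old and new copies alive at the same time.
        @partial(jax.jit, donate_argnums=(0, 1))
        def train_step(params, opt_state, batch):
            loss, grads = jax.value_and_grad(loss_fn)(params, batch)
            updates, opt_state = optimizer.update(grads, opt_state, params)
            params = optax.apply_updates(params, updates)
            return params, opt_state, loss

        params = {"w": jnp.zeros((128, 1))}
        opt_state = optimizer.init(params)
        batch = {"x": jnp.ones((16, 128)), "y": jnp.ones((16, 1))}
        params, opt_state, loss = train_step(params, opt_state, batch)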