Skip to content

Add Jax distributed training guide with RNN/MLP examples#77

Open
atoniolo76 wants to merge 8 commits into
mainfrom
alessio/jax-training-example
Open

Add Jax distributed training guide with RNN/MLP examples#77
atoniolo76 wants to merge 8 commits into
mainfrom
alessio/jax-training-example

Conversation

@atoniolo76
Copy link
Copy Markdown
Contributor

Creates a Modal training script with 4 entry points: mlp_train, mlp_sample, rnn_train,rnn_sample. Adds a README explaining the advantages of Jax over PyTorch and how to setup a multi-node cluster using mesh/sharding primitives. Requires third-party library Equinox for neural network convenience.

MLP example:
Fit a basic MLP with hidden_size=64 to the x^2 function. Compute the mean-squared error as loss-function and back-propagate with Adam optimizer.

RNN example:
Next-character prediction on Chapter 32 from Moby Dick. Computes cross-entropy loss. Vocabulary is a one-hot vector of size 64.

Checklist

  • [*] Example is documented with comments throughout, in a Literate Programming style.
  • [*] Example does not require third-party dependencies to be installed locally
  • [*] Example follows the style guide
  • [*] Example pins its dependencies
    • [*] Example pins container images to a stable tag, not a dynamic tag like latest
    • [*] Example specifies a python_version for the base image, if it is used
    • [*] Example pins all dependencies to at least minor version, ~=x.y.z or ==x.y
    • [*] Example dependencies with version < 1 are pinned to patch version, ==0.y.z

(Modal's internal guide page for this repo is Multi-node examples guidance.)

image image (2) image

Outside contributors

You're great! Thanks for your contribution.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

1 participant