Update README.md
Browse files
README.md
CHANGED
@@ -17,7 +17,7 @@ This model is trained on the Jumanji snake environment
|
|
17 |
|
18 |
### How to use
|
19 |
|
20 |
-
Go to the jumanji repo for the primary model and requirements.
|
21 |
|
22 |
```
|
23 |
pip install --quiet -U pip -r ../requirements/requirements-train.txt ../.
|
@@ -26,7 +26,6 @@ pip install --quiet -U pip -r ../requirements/requirements-train.txt ../.
|
|
26 |
Below is an example script for loading and running the Jumanji model
|
27 |
|
28 |
```python
|
29 |
-
|
30 |
import pickle
|
31 |
import joblib
|
32 |
|
@@ -39,17 +38,16 @@ from jumanji.training.setup_train import setup_agent, setup_env
|
|
39 |
from jumanji.training.utils import first_from_device
|
40 |
|
41 |
# initialise the config
|
42 |
-
with initialize(version_base=None, config_path="
|
43 |
-
cfg = compose(config_name="config.yaml", overrides=["env=
|
44 |
|
45 |
# get model state from HF
|
46 |
-
REPO_ID = "
|
47 |
FILENAME = "Snake-v1_training_state"
|
48 |
|
49 |
model_weights = hf_hub_download(repo_id=REPO_ID, filename=FILENAME)
|
50 |
|
51 |
-
|
52 |
-
with open("training_state","rb") as f:
|
53 |
training_state = pickle.load(f)
|
54 |
|
55 |
params = first_from_device(training_state.params_state.params)
|
|
|
17 |
|
18 |
### How to use
|
19 |
|
20 |
+
Go to the jumanji repo for the primary model and requirements. Clone the repo and navigate to the root directory.
|
21 |
|
22 |
```
|
23 |
pip install --quiet -U pip -r ../requirements/requirements-train.txt ../.
|
|
|
26 |
Below is an example script for loading and running the Jumanji model
|
27 |
|
28 |
```python
|
|
|
29 |
import pickle
|
30 |
import joblib
|
31 |
|
|
|
38 |
from jumanji.training.utils import first_from_device
|
39 |
|
40 |
# initialise the config
|
41 |
+
with initialize(version_base=None, config_path="jumanji/training/configs"):
|
42 |
+
cfg = compose(config_name="config.yaml", overrides=["env=snake", "agent=a2c"])
|
43 |
|
44 |
# get model state from HF
|
45 |
+
REPO_ID = "d-byrne/snake-v1_training_state"
|
46 |
FILENAME = "Snake-v1_training_state"
|
47 |
|
48 |
model_weights = hf_hub_download(repo_id=REPO_ID, filename=FILENAME)
|
49 |
|
50 |
+
with open(model_weights,"rb") as f:
|
|
|
51 |
training_state = pickle.load(f)
|
52 |
|
53 |
params = first_from_device(training_state.params_state.params)
|