d-byrne commited on
Commit
9680934
1 Parent(s): a39796c

Update README.md

Browse files
Files changed (1) hide show
  1. README.md +5 -7
README.md CHANGED
@@ -17,7 +17,7 @@ This model is trained on the Jumanji snake environment
17
 
18
  ### How to use
19
 
20
- Go to the jumanji repo for the primary model and requirements.
21
 
22
  ```
23
  pip install --quiet -U pip -r ../requirements/requirements-train.txt ../.
@@ -26,7 +26,6 @@ pip install --quiet -U pip -r ../requirements/requirements-train.txt ../.
26
  Below is an example script for loading and running the Jumanji model
27
 
28
  ```python
29
-
30
  import pickle
31
  import joblib
32
 
@@ -39,17 +38,16 @@ from jumanji.training.setup_train import setup_agent, setup_env
39
  from jumanji.training.utils import first_from_device
40
 
41
  # initialise the config
42
- with initialize(version_base=None, config_path="../jumanji/training/configs"):
43
- cfg = compose(config_name="config.yaml", overrides=["env=connector", "agent=a2c"])
44
 
45
  # get model state from HF
46
- REPO_ID = "YOUR_REPO_ID"
47
  FILENAME = "Snake-v1_training_state"
48
 
49
  model_weights = hf_hub_download(repo_id=REPO_ID, filename=FILENAME)
50
 
51
- # load the model
52
- with open("training_state","rb") as f:
53
  training_state = pickle.load(f)
54
 
55
  params = first_from_device(training_state.params_state.params)
 
17
 
18
  ### How to use
19
 
20
+ Go to the jumanji repo for the primary model and requirements. Clone the repo and navigate to the root directory.
21
 
22
  ```
23
  pip install --quiet -U pip -r ../requirements/requirements-train.txt ../.
 
26
  Below is an example script for loading and running the Jumanji model
27
 
28
  ```python
 
29
  import pickle
30
  import joblib
31
 
 
38
  from jumanji.training.utils import first_from_device
39
 
40
  # initialise the config
41
+ with initialize(version_base=None, config_path="jumanji/training/configs"):
42
+ cfg = compose(config_name="config.yaml", overrides=["env=snake", "agent=a2c"])
43
 
44
  # get model state from HF
45
+ REPO_ID = "d-byrne/snake-v1_training_state"
46
  FILENAME = "Snake-v1_training_state"
47
 
48
  model_weights = hf_hub_download(repo_id=REPO_ID, filename=FILENAME)
49
 
50
+ with open(model_weights,"rb") as f:
 
51
  training_state = pickle.load(f)
52
 
53
  params = first_from_device(training_state.params_state.params)