d-byrne committed on
Commit
a39796c
1 Parent(s): 3976171

added readme instructions

Browse files
Files changed (1) hide show
  1. README.md +85 -0
README.md CHANGED
@@ -1,3 +1,88 @@
1
  ---
2
  license: apache-2.0
3
  ---
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
  ---
2
  license: apache-2.0
3
  ---
4
+
5
+ # Snake-V1
6
+ This model was trained on the Jumanji Snake environment.
7
+
8
+
9
+ **Developed by:** InstaDeep
10
+
11
+ ### Model Sources
12
+
13
+ <!-- Provide the basic links for the model. -->
14
+
15
+ - **Repository:** [Jumanji](https://github.com/instadeepai/jumanji)
16
+ - **Paper:** TBD
17
+
18
+ ### How to use
19
+
20
+ See the Jumanji repository for the primary model code and its requirements.
21
+
22
+ ```
23
+ pip install --quiet -U pip -r ../requirements/requirements-train.txt ../.
24
+ ```
25
+
26
+ Below is an example script for loading and running the Jumanji model:
27
+
28
+ ```python
29
+
30
+ import pickle
31
+ import joblib
32
+
33
+ import jax
34
+ from hydra import compose, initialize
35
+ from huggingface_hub import hf_hub_download
36
+
37
+
38
+ from jumanji.training.setup_train import setup_agent, setup_env
39
+ from jumanji.training.utils import first_from_device
40
+
41
+ # initialise the config
42
+ with initialize(version_base=None, config_path="../jumanji/training/configs"):
43
+ cfg = compose(config_name="config.yaml", overrides=["env=connector", "agent=a2c"])
44
+
45
+ # get model state from HF
46
+ REPO_ID = "YOUR_REPO_ID"
47
+ FILENAME = "Snake-v1_training_state"
48
+
49
+ model_weights = hf_hub_download(repo_id=REPO_ID, filename=FILENAME)
50
+
51
+ # load the model
52
+ with open("training_state","rb") as f:
53
+ training_state = pickle.load(f)
54
+
55
+ params = first_from_device(training_state.params_state.params)
56
+ env = setup_env(cfg).unwrapped
57
+ agent = setup_agent(cfg, env)
58
+ policy = jax.jit(agent.make_policy(params.actor, stochastic = False))
59
+
60
+ # roll out a few episodes
61
+ NUM_EPISODES = 10
62
+
63
+ states = []
64
+ key = jax.random.PRNGKey(cfg.seed)
65
+ for episode in range(NUM_EPISODES):
66
+ key, reset_key = jax.random.split(key)
67
+ state, timestep = jax.jit(env.reset)(reset_key)
68
+ while not timestep.last():
69
+ key, action_key = jax.random.split(key)
70
+ observation = jax.tree_util.tree_map(lambda x: x[None], timestep.observation)
71
+ action, _ = policy(observation, action_key)
72
+ state, timestep = jax.jit(env.step)(state, action.squeeze(axis=0))
73
+ states.append(state)
74
+ # Freeze the terminal frame to pause the GIF.
75
+ for _ in range(10):
76
+ states.append(state)
77
+
78
+ # animate a GIF
79
+ env.animate(states, interval=150).save("./snake.gif")
80
+
81
+ # save PNG
82
+ import matplotlib.pyplot as plt
83
+ %matplotlib inline
84
+ env.render(states[117])
85
+ plt.savefig("connector.png", dpi=300)
86
+
87
+ ```
88
+