---
license: apache-2.0
tags:
- jax
- rl
- jumanji
---
# CVRP-V1
This model is an A2C agent trained on the Jumanji CVRP (Capacitated Vehicle Routing Problem) environment.

**Developed by:** InstaDeep
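
For readers new to the problem, the sketch below shows one common way to score a CVRP solution. A vehicle with a fixed capacity must serve every customer exactly once. The demand served on each depot-to-depot route must fit within that capacity, and the goal is to minimize the total distance travelled. This is a standalone illustration only. The coordinates, demands, capacity, and the helper names `route_length` and `evaluate_solution` are hypothetical and are not taken from Jumanji's environment code or reward.

```python
import math
from typing import List, Sequence, Tuple

Point = Tuple[float, float]


def route_length(coords: Sequence[Point], route: Sequence[int]) -> float:
    """Euclidean length of one route that starts and ends at the depot (node 0)."""
    path = [0, *route, 0]
    return sum(math.dist(coords[a], coords[b]) for a, b in zip(path, path[1:]))


def evaluate_solution(
    coords: Sequence[Point],
    demands: Sequence[float],
    capacity: float,
    routes: List[List[int]],
) -> float:
    """Hypothetical helper: total distance of a CVRP solution; raises ValueError if infeasible."""
    # Every customer (nodes 1..n-1) must be visited exactly once.
    visited = sorted(node for route in routes for node in route)
    if visited != list(range(1, len(coords))):
        raise ValueError("every customer must be visited exactly once")
    # The demand served on each depot-to-depot route must fit in the vehicle.
    for route in routes:
        if sum(demands[node] for node in route) > capacity:
            raise ValueError(f"route {route} exceeds the vehicle capacity")
    return sum(route_length(coords, route) for route in routes)


if __name__ == "__main__":
    # Hypothetical toy instance: depot at the origin, three customers.
    coords = [(0.0, 0.0), (3.0, 0.0), (3.0, 4.0), (0.0, 4.0)]
    demands = [0, 2, 2, 2]
    print(evaluate_solution(coords, demands, capacity=4, routes=[[1, 2], [3]]))  # 20.0
```

Serving all three customers on a single route would need a capacity of 6, so with a capacity of 4 the vehicle has to return to the depot once.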
### Model Sources
- **Repository:** [Jumanji](https://github.com/instadeepai/jumanji)
- **Paper:** TBD
### How to use
[Notebook](#)
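See the Jumanji repository for the model code and its requirements. Clone the repository, navigate to its root directory, and install the package in editable mode:
```bash
git clone https://github.com/instadeepai/jumanji.git
cd jumanji
pip install -e .
```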
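The example script below downloads the trained A2C agent from the Hugging Face Hub, rolls it out on the CVRP environment, and saves an animation and a rendered frame. It assumes the script (or notebook) lives in the repository root, because the Hydra `config_path` it uses is a relative path.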
```python
import pickle

import jax
import matplotlib.pyplot as plt
from hydra import compose, initialize
from huggingface_hub import hf_hub_download

from jumanji.training.setup_train import setup_agent, setup_env
from jumanji.training.utils import first_from_device

# Initialise the Hydra config for the CVRP environment and the A2C agent.
with initialize(version_base=None, config_path="jumanji/training/configs"):
    cfg = compose(config_name="config.yaml", overrides=["env=cvrp", "agent=a2c"])

# Download the trained model state from the Hugging Face Hub.
REPO_ID = "InstaDeepAI/jumanji-cvrp-v1-a2c-benchmark"
FILENAME = "CVRP-v1_training_state"
model_weights = hf_hub_download(repo_id=REPO_ID, filename=FILENAME)
with open(model_weights, "rb") as f:
    training_state = pickle.load(f)

# The training state holds one copy of the parameters per device; keep the first.
params = first_from_device(training_state.params_state.params)

# Build the environment and agent, then compile a deterministic (greedy) policy.
env = setup_env(cfg).unwrapped
agent = setup_agent(cfg, env)
policy = jax.jit(agent.make_policy(params.actor, stochastic=False))

# JIT-compile reset and step once, outside the rollout loop.
reset_fn = jax.jit(env.reset)
step_fn = jax.jit(env.step)

# Roll out a few episodes, recording every state for rendering.
NUM_EPISODES = 10
states = []
key = jax.random.PRNGKey(cfg.seed)
for _ in range(NUM_EPISODES):
    key, reset_key = jax.random.split(key)
    state, timestep = reset_fn(reset_key)
    while not timestep.last():
        key, action_key = jax.random.split(key)
        # The policy expects a batch dimension, so add a leading axis to every leaf.
        observation = jax.tree_util.tree_map(lambda x: x[None], timestep.observation)
        action, _ = policy(observation, action_key)
        state, timestep = step_fn(state, action.squeeze(axis=0))
        states.append(state)
    # Freeze the terminal frame to pause the GIF.
    for _ in range(10):
        states.append(state)

# Animate all recorded states as a GIF.
env.animate(states, interval=150).save("./cvrp.gif")

# Render a single state (here the last one) and save it as a PNG.
# In a Jupyter notebook, run `%matplotlib inline` first to display the figure inline.
env.render(states[-1])
plt.savefig("cvrp.png", dpi=300)
```
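
The rollout loop records a state only after each step, then appends ten copies of each terminal state. As a result, the length of `states` depends on how long each episode lasts. The toy sketch below mirrors that bookkeeping with a hypothetical `CountdownEnv` (not part of Jumanji) whose episodes always last a fixed number of steps, which makes the frame count easy to check. It is an illustration of the loop's structure under that assumption, not Jumanji's API.

```python
from dataclasses import dataclass
from typing import List, Tuple


@dataclass
class ToyTimeStep:
    """Hypothetical stand-in for a timestep: an observation and an end-of-episode flag."""
    observation: int
    done: bool

    def last(self) -> bool:
        return self.done


class CountdownEnv:
    """Hypothetical environment whose episodes end after exactly `horizon` steps."""

    def __init__(self, horizon: int) -> None:
        self.horizon = horizon

    def reset(self) -> Tuple[int, ToyTimeStep]:
        return 0, ToyTimeStep(observation=0, done=False)

    def step(self, state: int, action: int) -> Tuple[int, ToyTimeStep]:
        new_state = state + 1
        return new_state, ToyTimeStep(observation=new_state, done=new_state >= self.horizon)


def collect_frames(env: CountdownEnv, num_episodes: int, pause_frames: int = 10) -> List[int]:
    """Mirror the rollout loop: record each post-step state, then pad with the terminal state."""
    frames: List[int] = []
    for _ in range(num_episodes):
        state, timestep = env.reset()
        while not timestep.last():
            action = 0  # placeholder for the trained policy's action
            state, timestep = env.step(state, action)
            frames.append(state)
        # Repeat the terminal state so an animation pauses on it.
        frames.extend([state] * pause_frames)
    return frames


if __name__ == "__main__":
    frames = collect_frames(CountdownEnv(horizon=3), num_episodes=2)
    # Each episode contributes horizon + pause_frames = 3 + 10 = 13 frames.
    print(len(frames))  # 26
```

Real CVRP episodes vary in length, so indexing `states` at a fixed position is only safe after checking `len(states)`.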