{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "e9c5f0df",
   "metadata": {
    "lines_to_next_cell": 2
   },
   "outputs": [],
   "source": [
    "# Install MiniChain and fetch the example prompt templates (e.g. selfask.pmpt.tpl).\n",
    "!pip install -q git+https://github.com/srush/MiniChain\n",
    "!git clone https://github.com/srush/MiniChain; cp -fr MiniChain/examples/* ."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "56932926",
   "metadata": {
    "lines_to_next_cell": 2
   },
   "outputs": [],
   "source": [
    "desc = \"\"\"\n",
    "### Self-Ask\n",
    "\n",
    " Notebook implementation of the self-ask + Google tool-use prompt. [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/srush/MiniChain/blob/master/examples/selfask.ipynb)\n",
    "\n",
    " (Adapted from the [Self-Ask repo](https://github.com/ofirpress/self-ask))\n",
    "\"\"\""
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "43d24799",
   "metadata": {},
   "outputs": [],
   "source": [
    "from dataclasses import dataclass, replace\n",
    "from typing import Optional\n",
    "\n",
    "from minichain import prompt, show, OpenAI, Google"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "506b8e41",
   "metadata": {},
   "outputs": [],
   "source": [
    "# State threaded through the self-ask loop.\n",
    "@dataclass\n",
    "class State:\n",
    "    question: str                       # original user question\n",
    "    history: str = \"\"                   # transcript of follow-ups and intermediate answers\n",
    "    next_query: Optional[str] = None    # follow-up question to send to Google, if any\n",
    "    final_answer: Optional[str] = None  # set once the model commits to an answer"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "fb5fdf8e",
   "metadata": {
    "lines_to_next_cell": 1
   },
   "outputs": [],
   "source": [
    "# One self-ask step: the LM either proposes a follow-up question or\n",
    "# commits to a final answer. Generation stops at \"\\nIntermediate answer:\"\n",
    "# so that the search tool, not the LM, fills in the intermediate answer.\n",
    "@prompt(OpenAI(),\n",
    "        template_file=\"selfask.pmpt.tpl\",\n",
    "        stop_template=\"\\nIntermediate answer:\")\n",
    "def self_ask(model, state):\n",
    "    out = model(state)\n",
    "    if out.startswith(\"Follow up:\"):\n",
    "        return replace(state, next_query=out.split(\":\", 1)[1])\n",
    "    elif out.startswith(\"So the final answer is:\"):\n",
    "        return replace(state, final_answer=out.split(\":\", 1)[1])\n",
    "    # Neither marker matched: return the state unchanged rather than None.\n",
    "    return state"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "f2d66476",
   "metadata": {
    "lines_to_next_cell": 1
   },
   "outputs": [],
   "source": [
    "# Tool step: if the model asked a follow-up question, answer it with a\n",
    "# Google search and fold the result into the history for the next round.\n",
    "@prompt(Google())\n",
    "def google(model, state):\n",
    "    if state.next_query is None:\n",
    "        return state\n",
    "\n",
    "    result = model(state.next_query)\n",
    "    return State(state.question,\n",
    "                 state.history + \"\\nIntermediate answer: \" + result + \"\\n\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "e41d7a75",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Chain three rounds of self-ask followed by a Google lookup; the fixed\n",
    "# count matches the subprompts list passed to show() below.\n",
    "def selfask(question):\n",
    "    state = State(question)\n",
    "    for _ in range(3):\n",
    "        state = self_ask(state)\n",
    "        state = google(state)\n",
    "    return state"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "e6e4b06f",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Wrap the chain in a Gradio demo; each round's self_ask and google calls\n",
    "# are displayed as subprompts.\n",
    "gradio = show(selfask,\n",
    "              examples=[\"What is the zip code of the city where George Washington was born?\"],\n",
    "              subprompts=[self_ask, google] * 3,\n",
    "              description=desc,\n",
    "              out_type=\"json\")\n",
    "if __name__ == \"__main__\":\n",
    "    gradio.launch()"
   ]
  }
 ],
 "metadata": {
  "jupytext": {
   "cell_metadata_filter": "-all",
   "main_language": "python",
   "notebook_metadata_filter": "-all"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}