# UnlearnDiffAtk / app.py
# Hugging Face Space by xinchen9 — commit 41d84e3: "[Update]Change backend of Gaudi"
import gradio as gr
import os
import requests
import json
import base64
from io import BytesIO
from huggingface_hub import login
from PIL import Image
# Address of the inference backend that serves the /prepare and /udiff routes.
# Either value can be overridden through an environment variable. The defaults
# are the previously hard-coded backend address.
myip = os.environ.get("UDIFF_BACKEND_IP", "146.152.224.103")
myport = int(os.environ.get("UDIFF_BACKEND_PORT", "8080"))
# True when a SPACE_ID environment variable is present (Hugging Face Spaces).
# NOTE(review): is_spaces / is_shared_ui are not read anywhere in the visible code.
is_spaces = "SPACE_ID" in os.environ
is_shared_ui = False
from css_html_js import custom_css
from about import (
CITATION_BUTTON_LABEL,
CITATION_BUTTON_TEXT,
EVALUATION_QUEUE_TEXT,
INTRODUCTION_TEXT,
LLM_BENCHMARKS_TEXT,
TITLE,
)
def process_image_from_binary(img_stream):
    """Decode a base64-encoded image payload into a PIL image.

    Args:
        img_stream: Base64 text (str or bytes) of an encoded image, or a
            falsy value (``None``, ``""``, ``b""``) when there is no image.

    Returns:
        The decoded ``PIL.Image.Image``, or ``None`` when ``img_stream`` is
        empty or ``None``.
    """
    if not img_stream:
        # Also catches "" / b"": decoding those yields no bytes, and
        # Image.open would raise instead of reporting "no image".
        print("no image binary")
        return None
    image_data = base64.b64decode(img_stream)
    img = Image.open(BytesIO(image_data))
    # Image.open is lazy. Force decoding here so corrupt data fails at this
    # point rather than later, when the image is rendered.
    img.load()
    return img
def _call_backend(endpoint, diffusion_model_id, concept, steps, attack_id, prompt_key, img_key):
    """POST an attack request to the backend and unpack its prompt/image reply.

    Shared by ``execute_prepare`` and ``execute_udiff``. They differ only in
    the route they call and the response keys they read.

    Args:
        endpoint: Backend route name, e.g. ``"prepare"`` or ``"udiff"``.
        diffusion_model_id: Unlearned-model choice from the UI.
        concept: Unlearned-concept choice from the UI.
        steps: Number of attack steps from the UI.
        attack_id: Attack index from the UI.
        prompt_key: Response-JSON key holding the prompt text.
        img_key: Response-JSON key holding the base64-encoded image.

    Returns:
        ``(prompt, img)``: the prompt string and a PIL image (or ``None``).
        A non-200 response yields ``("", None)``.

    Raises:
        requests.RequestException: on connection errors or timeouts.
    """
    print(f"my IP is {myip}, my port is {myport}")
    print(f"my input is diffusion_model_id: {diffusion_model_id}, concept: {concept}, steps: {steps}")
    url = f"http://{myip}:{myport}/{endpoint}"
    payload = {"diffusion_model_id": diffusion_model_id, "concept": concept, "steps": steps, "attack_id": attack_id}
    # (connect timeout, read timeout): allow the backend up to 20 minutes to reply.
    response = requests.post(url, json=payload, timeout=(10, 1200))
    print(f"result: {response}")
    if response.status_code != 200:
        print(f"Request failed with status code {response.status_code}")
        return "", None
    response_json = response.json()
    # Log only the keys. The image field is a large base64 blob that would
    # flood the logs if printed in full.
    print(f"response keys: {sorted(response_json)}")
    prompt = response_json.get(prompt_key, "")
    img = process_image_from_binary(response_json.get(img_key))
    return prompt, img


def execute_prepare(diffusion_model_id, concept, steps, attack_id):
    """Fetch the un-attacked baseline: the input prompt and the image it generates.

    Returns:
        ``(prompt, img)``, or ``("", None)`` if the backend returns a non-200 status.
    """
    return _call_backend(
        "prepare", diffusion_model_id, concept, steps, attack_id,
        prompt_key="input_prompt", img_key="no_attack_img",
    )


def execute_udiff(diffusion_model_id, concept, steps, attack_id):
    """Run UnlearnDiffAtk: fetch the adversarial prompt and the image it generates.

    Returns:
        ``(prompt, img)``, or ``("", None)`` if the backend returns a non-200 status.
    """
    return _call_backend(
        "udiff", diffusion_model_id, concept, steps, attack_id,
        prompt_key="output_prompt", img_key="attack_img",
    )
# Extra CSS rules for classes and ids such as .instruction, .arrow, #img_1..#img_4,
# #mdStyle and #titleCenter.
# NOTE(review): nothing in the visible code reads `css`. The gr.Blocks UI is
# built with `custom_css` from css_html_js. Either wire this in or remove it.
css = (
    "\n"
    ".instruction{position: absolute; top: 0;right: 0;margin-top: 0px !important}\n"
    ".arrow{position: absolute;top: 0;right: -110px;margin-top: -8px !important}\n"
    "#component-4, #component-3, #component-10{min-height: 0}\n"
    ".duplicate-button img{margin: 0}\n"
    "#img_1, #img_2, #img_3, #img_4{height:15rem}\n"
    "#mdStyle{font-size: 0.7rem}\n"
    "#titleCenter {text-align:center}\n"
)
# Gradio UI. The user picks an unlearned model, a concept, an attack index and
# a step count. "Attack prepare!" fetches the baseline prompt/image;
# "UnlearnDiffAtk!" fetches the adversarial prompt/image.
with gr.Blocks(css=custom_css) as demo:
    gr.HTML(TITLE)
    gr.Markdown(INTRODUCTION_TEXT, elem_classes="markdown-text")

    # Attack configuration controls.
    with gr.Row() as udiff:
        with gr.Row():
            drop = gr.Dropdown(
                ["Object-Church", "Object-Parachute", "Object-Garbage_Truck", "Style-VanGogh", "Nudity"],
                label="Unlearning undesirable concepts",
            )
        with gr.Column():
            drop_model = gr.Dropdown(["ESD", "FMN"], label="Unlearned DMs")
        with gr.Column():
            atk_idx = gr.Textbox(label="attack index")
        with gr.Column():
            shown_columns_step = gr.Slider(
                0, 100, value=40,
                step=1, label="Attack Steps", info="Choose between 0 and 100",
                interactive=True,
            )

    # Results side by side: baseline on the left, adversarial on the right.
    with gr.Row() as attack:
        with gr.Column(min_width=512):
            start_button = gr.Button("Attack prepare!", size="lg")
            text_input = gr.Textbox(label="Input Prompt")
            orig_img = gr.Image(
                label="Image Generated by Input Prompt",
                width=512, show_share_button=False, show_download_button=False,
            )
        with gr.Column():
            attack_button = gr.Button("UnlearnDiffAtk!", size="lg")
            text_ouput = gr.Textbox(label="Prompt Generated by UnlearnDiffAtk")
            result_img = gr.Image(
                label="Image Generated by Prompt of UnlearnDiffAtk",
                width=512, show_share_button=False, show_download_button=False,
            )

    # Input order matches the handlers' (diffusion_model_id, concept, steps, attack_id).
    start_button.click(
        fn=execute_prepare,
        inputs=[drop_model, drop, shown_columns_step, atk_idx],
        outputs=[text_input, orig_img],
        api_name="prepare",
    )
    attack_button.click(
        fn=execute_udiff,
        inputs=[drop_model, drop, shown_columns_step, atk_idx],
        outputs=[text_ouput, result_img],
        api_name="udiff",
    )
# Start the server only when run as a script (`python app.py`), not on import.
if __name__ == "__main__":
    # Route events through Gradio's queue, and bind on all interfaces.
    demo.queue().launch(server_name='0.0.0.0')