leedoming commited on
Commit
d8e1a3b
1 Parent(s): 0478ee4

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +104 -139
app.py CHANGED
@@ -7,15 +7,12 @@ from io import BytesIO
7
  import time
8
  import json
9
  import numpy as np
10
- import cv2
11
- from inference_sdk import InferenceHTTPClient
12
- import matplotlib.pyplot as plt
13
- import base64
14
 
15
  # Load model and tokenizer
16
  @st.cache_resource
17
  def load_model():
18
- model, preprocess_val, tokenizer = open_clip.create_model_and_transforms('hf-hub:Marqo/marqo-fashionSigLIP')
 
19
  device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
20
  model.to(device)
21
  return model, preprocess_val, tokenizer, device
@@ -25,161 +22,106 @@ model, preprocess_val, tokenizer, device = load_model()
25
  # Load and process data
26
  @st.cache_data
27
  def load_data():
28
- with open('musinsa-final.json', 'r', encoding='utf-8') as f:
29
  return json.load(f)
30
 
31
  data = load_data()
32
 
33
  # Helper functions
34
- @st.cache_data
35
- def download_and_process_image(image_url):
36
- try:
37
- response = requests.get(image_url)
38
- response.raise_for_status() # Raises an HTTPError for bad responses
39
- image = Image.open(BytesIO(response.content))
40
-
41
- # Convert image to RGB mode if it's in RGBA mode
42
- if image.mode == 'RGBA':
43
- image = image.convert('RGB')
44
-
45
- return image
46
- except requests.RequestException as e:
47
- st.error(f"Error downloading image: {e}")
48
- return None
49
- except Exception as e:
50
- st.error(f"Error processing image: {e}")
 
51
  return None
52
 
53
- def get_image_embedding(image):
54
  image_tensor = preprocess_val(image).unsqueeze(0).to(device)
 
55
  with torch.no_grad():
56
  image_features = model.encode_image(image_tensor)
57
  image_features /= image_features.norm(dim=-1, keepdim=True)
58
- return image_features.cpu().numpy()
59
 
60
- def setup_roboflow_client(api_key):
61
- return InferenceHTTPClient(
62
- api_url="https://outline.roboflow.com",
63
- api_key=api_key
64
- )
65
-
66
- def segment_image(image_path, client):
67
- try:
68
- # 이미지 파일 읽기
69
- with open(image_path, "rb") as image_file:
70
- image_data = image_file.read()
71
-
72
- # 이미지를 base64로 인코딩
73
- encoded_image = base64.b64encode(image_data).decode('utf-8')
74
-
75
- # 원본 이미지 로드
76
- image = cv2.imread(image_path)
77
- image = cv2.resize(image, (800, 600))
78
- mask = np.zeros(image.shape, dtype=np.uint8)
79
-
80
- # Roboflow API 호출
81
- results = client.infer(encoded_image, model_id="closet/1")
82
-
83
- # 결과가 이미 딕셔너리인 경우 JSON 파싱 단계 제거
84
- if isinstance(results, dict):
85
- predictions = results.get('predictions', [])
86
- else:
87
- # 문자열인 경우에만 JSON 파싱
88
- predictions = json.loads(results).get('predictions', [])
89
-
90
- if predictions:
91
- for prediction in predictions:
92
- points = prediction['points']
93
- pts = np.array([[p['x'], p['y']] for p in points], np.int32)
94
- scale_x = image.shape[1] / results['image']['width']
95
- scale_y = image.shape[0] / results['image']['height']
96
- pts = pts * [scale_x, scale_y]
97
- pts = pts.astype(np.int32)
98
- pts = pts.reshape((-1, 1, 2))
99
- cv2.fillPoly(mask, [pts], color=(255, 255, 255)) # White mask
100
-
101
- segmented_image = cv2.bitwise_and(image, mask)
102
- else:
103
- st.warning("No predictions found in the image. Returning original image.")
104
- segmented_image = image
105
-
106
- return Image.fromarray(cv2.cvtColor(segmented_image, cv2.COLOR_BGR2RGB))
107
- except Exception as e:
108
- st.error(f"Error in segmentation: {str(e)}")
109
- # 원본 이미지를 다시 읽어 반환
110
- return Image.open(image_path)
111
 
112
  @st.cache_data
113
- def process_database_cached(data):
114
  database_embeddings = []
115
  database_info = []
 
116
  for item in data:
117
  image_url = item['이미지 링크'][0]
118
- product_id = item.get('\ufeff상품 ID') or item.get('상품 ID')
119
-
120
- image = download_and_process_image(image_url)
121
- if image is None:
122
- continue
123
-
124
- # Save the image temporarily
125
- temp_path = f"temp_{product_id}.jpg"
126
- image.save(temp_path, 'JPEG')
127
-
128
- database_info.append({
129
- 'id': product_id,
130
- 'category': item['카테고리'],
131
- 'brand': item['브랜드명'],
132
- 'name': item['제품명'],
133
- 'price': item['정가'],
134
- 'discount': item['할인율'],
135
- 'image_url': image_url,
136
- 'temp_path': temp_path
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
137
  })
138
-
139
- return database_info
140
 
141
- def process_database(client, data):
142
- database_info = process_database_cached(data)
143
- database_embeddings = []
144
-
145
- for item in database_info:
146
- segmented_image = segment_image(item['temp_path'], client)
147
- embedding = get_image_embedding(segmented_image)
148
- database_embeddings.append(embedding)
149
-
150
- return np.vstack(database_embeddings), database_info
151
 
152
  # Streamlit app
153
- st.title("Fashion Search App with Segmentation")
154
-
155
- # API Key input
156
- api_key = st.text_input("Enter your Roboflow API Key", type="password")
157
-
158
- if api_key:
159
- CLIENT = setup_roboflow_client(api_key)
160
-
161
- # Initialize database_embeddings and database_info
162
- database_embeddings, database_info = process_database(CLIENT, data)
163
-
164
- uploaded_file = st.file_uploader("Choose an image...", type="jpg")
165
- if uploaded_file is not None:
166
- image = Image.open(uploaded_file)
167
- st.image(image, caption='Uploaded Image', use_column_width=True)
168
-
169
- if st.button('Find Similar Items'):
170
- with st.spinner('Processing...'):
171
- # Save uploaded image temporarily
172
- temp_path = "temp_upload.jpg"
173
- image.save(temp_path)
174
-
175
- # Segment the uploaded image
176
- segmented_image = segment_image(temp_path, CLIENT)
177
- st.image(segmented_image, caption='Segmented Image', use_column_width=True)
178
-
179
- # Get embedding for segmented image
180
- query_embedding = get_image_embedding(segmented_image)
181
  similar_images = find_similar_images(query_embedding)
182
-
183
  st.subheader("Similar Items:")
184
  for img in similar_images:
185
  col1, col2 = st.columns(2)
@@ -192,5 +134,28 @@ if api_key:
192
  st.write(f"Price: {img['info']['price']}")
193
  st.write(f"Discount: {img['info']['discount']}%")
194
  st.write(f"Similarity: {img['similarity']:.2f}")
195
- else:
196
- st.warning("Please enter your Roboflow API Key to use the app.")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
7
  import time
8
  import json
9
  import numpy as np
 
 
 
 
10
 
11
  # Load model and tokenizer
12
  @st.cache_resource
13
  def load_model():
14
+ model, preprocess_train, preprocess_val = open_clip.create_model_and_transforms('hf-hub:Marqo/marqo-fashionSigLIP')
15
+ tokenizer = open_clip.get_tokenizer('hf-hub:Marqo/marqo-fashionSigLIP')
16
  device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
17
  model.to(device)
18
  return model, preprocess_val, tokenizer, device
 
22
  # Load and process data
23
  @st.cache_data
24
  def load_data():
25
+ with open('./musinsa-final.json', 'r', encoding='utf-8') as f:
26
  return json.load(f)
27
 
28
  data = load_data()
29
 
30
  # Helper functions
31
+ def load_image_from_url(url, max_retries=3):
32
+ for attempt in range(max_retries):
33
+ try:
34
+ response = requests.get(url, timeout=10)
35
+ response.raise_for_status()
36
+ img = Image.open(BytesIO(response.content)).convert('RGB')
37
+ return img
38
+ except (requests.RequestException, Image.UnidentifiedImageError) as e:
39
+ #st.warning(f"Attempt {attempt + 1} failed: {str(e)}")
40
+ if attempt < max_retries - 1:
41
+ time.sleep(1)
42
+ else:
43
+ #st.error(f"Failed to load image from {url} after {max_retries} attempts")
44
+ return None
45
+
46
+ def get_image_embedding_from_url(image_url):
47
+ image = load_image_from_url(image_url)
48
+ if image is None:
49
  return None
50
 
 
51
  image_tensor = preprocess_val(image).unsqueeze(0).to(device)
52
+
53
  with torch.no_grad():
54
  image_features = model.encode_image(image_tensor)
55
  image_features /= image_features.norm(dim=-1, keepdim=True)
 
56
 
57
+ return image_features.cpu().numpy()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
58
 
59
  @st.cache_data
60
+ def process_database():
61
  database_embeddings = []
62
  database_info = []
63
+
64
  for item in data:
65
  image_url = item['이미지 링크'][0]
66
+ embedding = get_image_embedding_from_url(image_url)
67
+
68
+ if embedding is not None:
69
+ database_embeddings.append(embedding)
70
+ database_info.append({
71
+ 'id': item['\ufeff상품 ID'],
72
+ 'category': item['카테고리'],
73
+ 'brand': item['브랜드명'],
74
+ 'name': item['제품명'],
75
+ 'price': item['정가'],
76
+ 'discount': item['할인율'],
77
+ 'image_url': image_url
78
+ })
79
+ else:
80
+ st.warning(f"Skipping item {item['상품 ID']} due to image loading failure")
81
+
82
+ if database_embeddings:
83
+ return np.vstack(database_embeddings), database_info
84
+ else:
85
+ st.error("No valid embeddings were generated.")
86
+ return None, None
87
+
88
+ database_embeddings, database_info = process_database()
89
+
90
+ def get_text_embedding(text):
91
+ text_tokens = tokenizer([text]).to(device)
92
+
93
+ with torch.no_grad():
94
+ text_features = model.encode_text(text_tokens)
95
+ text_features /= text_features.norm(dim=-1, keepdim=True)
96
+
97
+ return text_features.cpu().numpy()
98
+
99
+ def find_similar_images(query_embedding, top_k=5):
100
+ similarities = np.dot(database_embeddings, query_embedding.T).squeeze()
101
+ top_indices = np.argsort(similarities)[::-1][:top_k]
102
+
103
+ results = []
104
+ for idx in top_indices:
105
+ results.append({
106
+ 'info': database_info[idx],
107
+ 'similarity': similarities[idx]
108
  })
 
 
109
 
110
+ return results
 
 
 
 
 
 
 
 
 
111
 
112
  # Streamlit app
113
+ st.title("Fashion Search App")
114
+
115
+ search_type = st.radio("Search by:", ("Image URL", "Text"))
116
+
117
+ if search_type == "Image URL":
118
+ query_image_url = st.text_input("Enter image URL:")
119
+ if st.button("Search by Image"):
120
+ if query_image_url:
121
+ query_embedding = get_image_embedding_from_url(query_image_url)
122
+ if query_embedding is not None:
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
123
  similar_images = find_similar_images(query_embedding)
124
+ st.image(query_image_url, caption="Query Image", use_column_width=True)
125
  st.subheader("Similar Items:")
126
  for img in similar_images:
127
  col1, col2 = st.columns(2)
 
134
  st.write(f"Price: {img['info']['price']}")
135
  st.write(f"Discount: {img['info']['discount']}%")
136
  st.write(f"Similarity: {img['similarity']:.2f}")
137
+ else:
138
+ st.error("Failed to process the image. Please try another URL.")
139
+ else:
140
+ st.warning("Please enter an image URL.")
141
+
142
+ else: # Text search
143
+ query_text = st.text_input("Enter search text:")
144
+ if st.button("Search by Text"):
145
+ if query_text:
146
+ text_embedding = get_text_embedding(query_text)
147
+ similar_images = find_similar_images(text_embedding)
148
+ st.subheader("Similar Items:")
149
+ for img in similar_images:
150
+ col1, col2 = st.columns(2)
151
+ with col1:
152
+ st.image(img['info']['image_url'], use_column_width=True)
153
+ with col2:
154
+ st.write(f"Name: {img['info']['name']}")
155
+ st.write(f"Brand: {img['info']['brand']}")
156
+ st.write(f"Category: {img['info']['category']}")
157
+ st.write(f"Price: {img['info']['price']}")
158
+ st.write(f"Discount: {img['info']['discount']}%")
159
+ st.write(f"Similarity: {img['similarity']:.2f}")
160
+ else:
161
+ st.warning("Please enter a search text.")