|
2 | 2 | import requests
|
3 | 3 | import boto3
|
4 | 4 | import os
|
| 5 | +import io |
| 6 | + |
| 7 | +from streamlit_cropper import st_cropper |
| 8 | +from PIL import Image, ImageDraw |
5 | 9 |
|
6 | 10 |
|
7 | 11 | ALB_URL = os.environ.get('ALB_URL')
|
|
12 | 16 | s3 = boto3.client('s3')
|
13 | 17 |
|
14 | 18 | st.set_page_config(
|
15 |
| - page_title='Gen AI - Image Variation', |
| 19 | + page_title='Gen AI - Image Replace', |
16 | 20 | page_icon = 'images/aws_favi.png',
|
17 | 21 | # layout = 'wide'
|
18 | 22 | )
|
19 | 23 | st.title('이미지 교체')
|
20 | 24 | st.write('주변 배경과 일치하도록 변경하여 이미지를 수정합니다.')
|
21 | 25 |
|
| 26 | +if "mask_enable" not in st.session_state: |
| 27 | + st.session_state.mask_enable = False |
| 28 | + |
22 | 29 | uploaded_file = st.file_uploader("파일을 선택하세요", type=['png', 'jpg'])
|
23 | 30 | if uploaded_file is not None:
|
24 |
| - bytes_data = uploaded_file.getvalue() |
25 |
| - st.image(bytes_data) |
| 31 | + # print(st.session_state.mask_enable) |
| 32 | + st.checkbox("이미지 마스크 지정", key="mask_enable") |
| 33 | + if st.session_state.mask_enable: |
| 34 | + img = Image.open(uploaded_file) |
| 35 | + width, height = img.size |
26 | 36 |
|
27 |
| -m_input = st.text_area( |
28 |
| - '이미지에서 남기고 싶은 오브젝트를 서술합니다 예) car, phone, bag', |
29 |
| - '', |
30 |
| - # height=100 |
31 |
| -) |
32 |
| -q_input = st.text_area( |
33 |
| - '남기고 싶은 오프젝트 이외에 배경에 대해서 정의합니다', |
34 |
| - '', |
35 |
| - # height=100 |
36 |
| -) |
| 37 | + cropped_box = st_cropper( |
| 38 | + img, |
| 39 | + realtime_update=True, |
| 40 | + box_color='#0000FF', |
| 41 | + aspect_ratio=None, |
| 42 | + return_type='box' |
| 43 | + ) |
| 44 | + |
| 45 | + left = cropped_box['left'] |
| 46 | + top = cropped_box['top'] |
| 47 | + right = left + cropped_box['width'] |
| 48 | + bottom = top + cropped_box['height'] |
| 49 | + shape = (left, top, right, bottom) |
| 50 | + |
| 51 | + masked_image = Image.new('RGB', (width, height), color=(255, 255, 255)) |
| 52 | + temp_image = ImageDraw.Draw(masked_image) |
| 53 | + temp_image.rectangle(shape, fill=(0, 0, 0)) |
| 54 | + |
| 55 | + # st.image(masked_image) |
| 56 | + |
| 57 | + st.write('이미지에서 남기고 싶은 오브젝트') |
| 58 | + cropped_image = img.crop(shape) |
| 59 | + _ = cropped_image.thumbnail((150,150)) |
| 60 | + st.image(cropped_image) |
| 61 | + else: |
| 62 | + bytes_data = uploaded_file.getvalue() |
| 63 | + st.image(bytes_data) |
| 64 | + |
| 65 | + m_input = st.text_area( |
| 66 | + '이미지에서 남기고 싶은 오브젝트를 지정합니다 예) car, phone, bag', |
| 67 | + '', |
| 68 | + # height=100 |
| 69 | + ) |
| 70 | + |
| 71 | + q_input = st.text_area( |
| 72 | + '지정한 오프젝트가 보일 배경에 대해서 정의합니다', |
| 73 | + '', |
| 74 | + # height=100 |
| 75 | + ) |
37 | 76 |
|
38 | 77 | with st.form('submit_form', clear_on_submit=True):
|
39 | 78 | submitted = st.form_submit_button('Submit')
|
40 | 79 | if submitted:
|
41 | 80 | with st.spinner('Loading...'):
|
| 81 | + uploaded_file.seek(0) |
| 82 | + s3.upload_fileobj( |
| 83 | + uploaded_file, |
| 84 | + BUCKET_NAME, |
| 85 | + f'images/{uploaded_file.name}' |
| 86 | + ) |
| 87 | + |
| 88 | + if st.session_state.mask_enable: |
| 89 | + in_mem_file = io.BytesIO() |
| 90 | + masked_image.save(in_mem_file, format='PNG') |
| 91 | + in_mem_file.seek(0) |
| 92 | + |
| 93 | + file_name, ext = os.path.splitext(uploaded_file.name) |
| 94 | + masked_image_name = f'images/masked_{file_name}.png' |
42 | 95 | s3.upload_fileobj(
|
43 |
| - uploaded_file, |
| 96 | + in_mem_file, |
44 | 97 | BUCKET_NAME,
|
45 |
| - f'images/{uploaded_file.name}' |
| 98 | + masked_image_name |
46 | 99 | )
|
47 | 100 |
|
| 101 | + data = { |
| 102 | + 'name': f'images/{uploaded_file.name}', |
| 103 | + 'prompt': q_input, |
| 104 | + 'mask_image': masked_image_name |
| 105 | + } |
| 106 | + else: |
48 | 107 | data = {
|
49 | 108 | 'name': f'images/{uploaded_file.name}',
|
50 | 109 | 'prompt': q_input,
|
51 | 110 | 'mask_prompt': m_input
|
52 | 111 | }
|
53 |
| - response = requests.post(API_URL, json=data) |
54 |
| - result = response.text |
| 112 | + |
| 113 | + response = requests.post(API_URL, json=data) |
| 114 | + result = response.text |
55 | 115 |
|
56 |
| - print(result) |
57 |
| - image_object = s3.get_object(Bucket=BUCKET_NAME, Key=f'images/{result}') |
58 |
| - st.image(image_object['Body'].read()) |
| 116 | + print(result) |
| 117 | + image_object = s3.get_object(Bucket=BUCKET_NAME, Key=f'images/{result}') |
| 118 | + st.image(image_object['Body'].read()) |
59 | 119 |
|
0 commit comments