This time, I write the Flask server that will run on the GPU instance.
The Flask server loads the models and performs the AI tasks that the backend API needs.
Because of GPU cost, only the tasks that actually require a GPU run on this server.
Overview
The Flask server handles the following tasks (a rough skeleton of the app is sketched right after this list):
1. generate comment : generates a short comforting message based on the diary entry.
2. generate image : receives an image-generation prompt from the backend, loads the fine-tuned diffusion model, and generates an image.
3. emotion classification : analyzes the emotion of the diary entry.
4. recommend music : compares the diary's emotion-analysis result with the emotion-analysis results of crawled music data and recommends songs by similarity, returning 5 songs in total.
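To give an overview before going through each task, this is roughly how app.py is laid out. It is only a sketch under my own assumptions (the port and the run block are illustrative, not taken from the project); the actual handlers are filled in one by one in the sections below.

# app.py -- rough skeleton (illustrative only; the real handlers follow below)
from flask import Flask, request, jsonify

app = Flask(__name__)

# The four GPU-bound endpoints this server exposes.
# Each handler delegates to a function that wraps a loaded model.
@app.route('/get_comment', methods=['POST'])
async def process_comment_request(): ...    # KoGPT2 comment generation

@app.route('/get_image', methods=['POST'])
async def process_image_request(): ...      # Stable Diffusion image generation + S3 upload

@app.route('/get_sentiment', methods=['POST'])
async def process_sentiment_request(): ...  # emotion classification

@app.route('/get_music', methods=['POST'])
async def process_music_request(): ...      # music recommendation by emotion similarity

if __name__ == '__main__':
    app.run(host='0.0.0.0', port=5000)      # port is an assumption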
Image
The server receives the prompt for image generation, generates the image, stores it in the S3 bucket, and returns the URL in the response.
Let's first write the image-generation function in a separate file.
# generate_image.py — load the model and define the function that generates the image.
from diffusers import StableDiffusionPipeline
import torch
# Load the model and set up the device
model_path = 'model path'  # FineTuning Model Path
pipe = StableDiffusionPipeline.from_pretrained("runwayml/stable-diffusion-v1-5")
device = torch.device('cuda')
pipe.unet.load_attn_procs(model_path)
pipe.to(device)
# Set the negative prompt
neg_prompt = '''FastNegativeV2,(bad-artist:1.0), (loli:1.2),
(worst quality, low quality:1.4), (bad_prompt_version2:0.8),
bad-hands-5,lowres, bad anatomy, bad hands, ((text)), (watermark),
error, missing fingers, extra digit, fewer digits, cropped,
worst quality, low quality, normal quality, ((username)), blurry,
(extra limbs), bad-artist-anime, badhandv4, EasyNegative,
ng_deepnegative_v1_75t, verybadimagenegative_v1.3, BadDream,
(three hands:1.1),(three legs:1.1),(more than two hands:1.4),
(more than two legs,:1.2),badhandv4,EasyNegative,ng_deepnegative_v1_75t,verybadimagenegative_v1.3,(worst quality, low quality:1.4),text,words,logo,watermark,
'''
def get_image(prompt):
    # Run the pipeline with the negative prompt and return the first generated image (a PIL image)
    image = pipe(prompt, negative_prompt=neg_prompt, num_inference_steps=30, guidance_scale=7.5).images[0]
    return image
Generation usually takes about 15 seconds and, at the slowest, finishes within about 30 seconds. That is still quite a long time, so take it into account and increase the response timeout on the backend server that calls this endpoint.
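For reference, the caller side could look roughly like this. It is only a sketch under my own assumptions (the GPU server host/port and the 60-second timeout are placeholders, and I am assuming the backend calls this endpoint with the requests library):

import requests

GPU_SERVER_URL = "http://<gpu-server-host>:5000"  # placeholder, not the real address

def request_image(prompt: str) -> str:
    # Give the GPU server enough time to finish the 15-30 s generation step.
    resp = requests.post(
        f"{GPU_SERVER_URL}/get_image",
        json={"prompt": prompt},
        timeout=60,  # generous read timeout for the slow generation step
    )
    resp.raise_for_status()
    return resp.json()["image_url"]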
# app.py
from flask import Flask, request, jsonify
import boto3
import io
import uuid

from generate_image import get_image

app = Flask(__name__)

# Connect to S3 (the keys, bucket name and region come from the environment/config)
s3 = boto3.client('s3',
                  aws_access_key_id=S3_ACCESS_KEY,
                  aws_secret_access_key=S3_SECRET_ACCESS_KEY)

# Handler for image-generation requests
@app.route('/get_image', methods=['POST'])
async def process_image_request():
    try:
        # Read the incoming request data
        request_data = request.json
        prompt = request_data.get('prompt')
        # Generate an image from the received prompt
        image = get_image(prompt)
        # Store the image in S3 under a random key
        image_key = str(uuid.uuid4())
        buffered = io.BytesIO()
        image.save(buffered, format="JPEG")
        buffered.seek(0)
        s3.upload_fileobj(buffered, Bucket=S3_BUCKET_NAME, Key=f'images/{image_key}.jpg', ExtraArgs={'ContentType': 'image/jpeg'})
        image_url = f'https://{S3_BUCKET_NAME}.s3.{AWS_S3_REGION_NAME}.amazonaws.com/images/{image_key}.jpg'
        buffered.close()
        # Respond with the URL of the stored image
        return jsonify({'image_url': image_url}), 200
    except Exception as e:
        print("Exception occurred in process_request:", e)
        return jsonify({"error": str(e)}), 500
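One thing to note: the URL built above assumes the objects under images/ are publicly readable. If the bucket is private, a presigned URL could be returned instead. A minimal sketch, assuming the same s3 client and bucket variables (the 1-hour expiry is an arbitrary choice):

# Alternative: return a time-limited presigned URL instead of the public bucket URL.
image_url = s3.generate_presigned_url(
    'get_object',
    Params={'Bucket': S3_BUCKET_NAME, 'Key': f'images/{image_key}.jpg'},
    ExpiresIn=3600,  # valid for 1 hour (arbitrary)
)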
Comment
The process is the same as above: write one function that loads the model and performs the task, and write the handler for the API request in app.py.
# Load the models
# KoGPT2 comment-generation model
import torch
from transformers import GPT2LMHeadModel, PreTrainedTokenizerFast

print('gpt_load')
gpt_device = torch.device("cuda:0")
gpt_model = GPT2LMHeadModel.from_pretrained('model path').to(gpt_device)  # fine-tuned KoGPT2 weights
gpt_tokenizer = PreTrainedTokenizerFast.from_pretrained('model path')

# Special tokens used when the model was fine-tuned
U_TKN = '<usr>'
S_TKN = '<sys>'
BOS = '</s>'
EOS = '</s>'
MASK = '<unused0>'
SENT = '<unused1>'
PAD = '<pad>'

def get_comment(input_text):
    # Use the fine-tuned KoGPT2 model to generate a reply to the input text,
    # one token at a time (greedy decoding) until the EOS token appears.
    q = input_text
    a = ""
    sent = ""
    with torch.no_grad():
        while True:
            input_ids = torch.LongTensor(gpt_tokenizer.encode(U_TKN + q + SENT + sent + S_TKN + a)).unsqueeze(dim=0).to(gpt_device)
            pred = gpt_model(input_ids)
            pred = pred.logits
            gen = gpt_tokenizer.convert_ids_to_tokens(torch.argmax(pred, dim=-1).squeeze().tolist())[-1]
            if gen == EOS:
                break
            a += gen.replace("▁", " ")
    return a
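As a side note, the same greedy decoding could also be expressed with the transformers generate() API instead of the manual loop. This is only an alternative sketch, not the code used in the project (the 128-token cap is an arbitrary choice):

def get_comment_with_generate(input_text):
    # Alternative sketch: greedy decoding via model.generate() instead of the manual loop.
    prompt = U_TKN + input_text + SENT + S_TKN
    input_ids = torch.LongTensor(gpt_tokenizer.encode(prompt)).unsqueeze(0).to(gpt_device)
    output = gpt_model.generate(
        input_ids,
        max_new_tokens=128,                                    # arbitrary upper bound
        do_sample=False,                                       # greedy, like the loop above
        eos_token_id=gpt_tokenizer.convert_tokens_to_ids(EOS),
        pad_token_id=gpt_tokenizer.convert_tokens_to_ids(PAD),
    )
    # Decode only the newly generated tokens after the prompt.
    return gpt_tokenizer.decode(output[0][input_ids.shape[-1]:], skip_special_tokens=True)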
# app.py
@app.route('/get_comment', methods=['POST'])
async def process_comment_request():
    try:
        request_data = request.json
        content = request_data.get('content')
        comment = get_comment(content)
        return jsonify({'comment': comment}), 200
    except Exception as e:
        print("Exception occurred in process_request:", e)
        return jsonify({"error": str(e)}), 500
Emotion
# inference(), emotion_arr (the list of emotion labels) and final_emotion
# (the crawled music data with emotion scores) are assumed to be prepared
# in the loading step above, alongside the models.
import pandas as pd
from sklearn.feature_extraction.text import TfidfVectorizer
from sklearn.metrics.pairwise import cosine_similarity

def get_emotion_label(content):
    # Run the emotion classifier and return the probability vector plus the top label.
    emotion_pred = inference(content)
    max_value = max(emotion_pred)
    max_index = emotion_pred.index(max_value)
    return emotion_pred, emotion_arr[max_index]

def get_music(content):
    # Compare the diary's emotion vector with each crawled song's emotion vector
    # (TF-IDF over the stringified scores + cosine similarity) and pick 1 + 4 songs.
    emotion_pred, _ = get_emotion_label(content)
    df_user_sentiment = pd.DataFrame([emotion_pred], columns=emotion_arr)
    user_emotion_str = df_user_sentiment.apply(lambda x: ' '.join(map(str, x)), axis=1)
    music_emotion_str = final_emotion[emotion_arr].apply(lambda x: ' '.join(map(str, x)), axis=1)
    tfidf = TfidfVectorizer()
    user_tfidf_matrix = tfidf.fit_transform(user_emotion_str)
    music_tfidf_matrix = tfidf.transform(music_emotion_str)
    cosine_sim = cosine_similarity(user_tfidf_matrix, music_tfidf_matrix)
    # The single most similar song ...
    most_similar_song_index = cosine_sim.argmax()
    most_similar_song_info = final_emotion.iloc[most_similar_song_index]
    # ... plus the next 4 most similar songs (5 recommendations in total).
    num_additional_recommendations = 4
    similar_songs_indices = cosine_sim.argsort()[0][-num_additional_recommendations-1:-1][::-1]
    similar_songs_info = final_emotion.iloc[similar_songs_indices]
    return most_similar_song_info, similar_songs_info
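To make the similarity step concrete, here is a tiny toy example of the underlying idea: represent the diary and each song as an emotion-score vector and rank songs by cosine similarity. Everything in it (the three emotion labels, the made-up scores and song titles) is dummy data for illustration only, and it compares the numeric vectors directly instead of going through the TF-IDF string step used above.

import numpy as np
import pandas as pd
from sklearn.metrics.pairwise import cosine_similarity

# Dummy data: 3 emotion labels and 3 songs with made-up emotion scores.
emotion_arr = ['happy', 'sad', 'angry']
final_emotion = pd.DataFrame({
    'title': ['song A', 'song B', 'song C'],
    'happy': [0.8, 0.1, 0.3],
    'sad':   [0.1, 0.8, 0.3],
    'angry': [0.1, 0.1, 0.4],
})
user_pred = [[0.7, 0.2, 0.1]]           # pretend output of the emotion classifier

sim = cosine_similarity(user_pred, final_emotion[emotion_arr].values)
ranked = np.argsort(sim[0])[::-1]       # song indices, most similar first
print(final_emotion.iloc[ranked[:2]])   # top song + additional recommendations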
@app.route('/get_sentiment', methods=['POST'])
async def process_sentiment_request():
    try:
        request_data = request.json
        content = request_data.get('content')
        _, emotion_label = get_emotion_label(content)
        return jsonify({'emotion_label': emotion_label}), 200
    except Exception as e:
        print("Exception occurred in process_request:", e)
        return jsonify({"error": str(e)}), 500

@app.route('/get_music', methods=['POST'])
async def process_music_request():
    try:
        request_data = request.json
        content = request_data.get('content')
        most_similar_song_info, similar_songs_info = get_music(content)
        response_data = {
            'most_similar_song': {
                'title': most_similar_song_info[0],
                'artist': most_similar_song_info[1],
                'genre': most_similar_song_info[2]
            },
            'similar_songs': [{
                'title': song_info[0],
                'artist': song_info[1],
                'genre': song_info[2]
            } for song_info in similar_songs_info.values]
        }
        return jsonify(response_data), 200
    except Exception as e:
        print("Exception occurred in process_request:", e)
        return jsonify({"error": str(e)}), 500
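Finally, a quick way to smoke-test the four endpoints from any machine that can reach the GPU server. This is only a sketch with placeholder values (the host, port, diary text and prompt are all made up):

import requests

BASE = "http://<gpu-server-host>:5000"  # placeholder address
diary = "Today was a long day, but the evening walk helped."  # made-up diary text

for path, payload in [
    ("/get_comment",   {"content": diary}),
    ("/get_sentiment", {"content": diary}),
    ("/get_music",     {"content": diary}),
    ("/get_image",     {"prompt": "a calm evening walk, watercolor style"}),  # made-up prompt
]:
    resp = requests.post(BASE + path, json=payload, timeout=60)
    print(path, resp.status_code, list(resp.json().keys()))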