이번에는 GPU서버에 올릴 Flask 를 작성한다.
Flask에서는 모델을 로드하고, 백엔드 API에 필요한 AI작업을 수행한다.
GPU비용 문제로, GPU가 꼭 필요한 경우만 이 서버에서 동작한다.
소개
Flask에서는 다음을 수행한다.
1. generate comment : 일기 작성 내용을 바탕으로 응원 문구를 생성한다.
2. generate image : 백엔드로부터 이미지 생성 프롬프트를 받고, 튜닝된 디퓨전 모델을 로드하여 이미지를 생성한다.
3. emotion classification : 일기 작성 내용을 바탕으로 감정 분석한다.
4. recommend music : 감정 분석 결과와 크롤링을 통해 수집한 음악 데이터의 감정분석 결과를 사용해 유사도를 통한 음악 추천, 결과 총 5가지 반환
Image
이미지를 생성하기 위한 프롬프트를 받아서 이미지 생성, S3버킷에 저장하고, URL정보를 응답으로 반환한다.
이미지를 생성하는 함수를 먼저 다른 파일에 작성하자.
#generate_image.py 에 모델을 로드하고, 이미지를 생성하는 함수를 만들었다.
from diffusers import StableDiffusionPipeline
import torch
# 모델 로드 및 디바이스 설정
model_path = '모델 위치' # FineTuning Model Path
pipe = StableDiffusionPipeline.from_pretrained("runwayml/stable-diffusion-v1-5")
device = torch.device('cuda')
pipe.unet.load_attn_procs(model_path)
pipe.to(device)
# Negative prompt 설정
neg_prompt = '''FastNegativeV2,(bad-artist:1.0), (loli:1.2),
(worst quality, low quality:1.4), (bad_prompt_version2:0.8),
bad-hands-5,lowres, bad anatomy, bad hands, ((text)), (watermark),
error, missing fingers, extra digit, fewer digits, cropped,
worst quality, low quality, normal quality, ((username)), blurry,
(extra limbs), bad-artist-anime, badhandv4, EasyNegative,
ng_deepnegative_v1_75t, verybadimagenegative_v1.3, BadDream,
(three hands:1.1),(three legs:1.1),(more than two hands:1.4),
(more than two legs,:1.2),badhandv4,EasyNegative,ng_deepnegative_v1_75t,verybadimagenegative_v1.3,(worst quality, low quality:1.4),text,words,logo,watermark,
'''
def get_image(prompt):
image = pipe(prompt, negative_prompt=neg_prompt,num_inference_steps=30, guidance_scale=7.5).images[0]
return image
대체로 15초정도 소요되고, 늦어지면 30초 이내로 생성된다. 하지만 사실 꽤 긴 시간이므로 이를 잘 고려해서 응답을 주고받는 백엔드 서버에서 응답 대기 시간을 늘려주자.
#app.py
# S3에 연결
s3 = boto3.client('s3',
aws_access_key_id=S3_ACESS_KEY,
aws_secret_access_key=S3_SECRET_ACCESS_KEY)
# 이미지 생성 요청에 대한 작업
@app.route('/get_image', methods=['POST'])
async def process_image_request():
try:
# 받은 요청의 데이터를 확인
request_data = request.json
prompt = request_data.get('prompt')
#받은 프롬프트로 이미지를 생성한다.
image = get_image(prompt)
# 이미지를 저장한다.
image_key = str(uuid.uuid4())
buffered = io.BytesIO()
image.save(buffered, format="JPEG")
buffered.seek(0)
s3.upload_fileobj(buffered, Bucket=S3_BUCKET_NAME, Key=f'images/{image_key}.jpg', ExtraArgs={'ContentType':'image/jpeg'})
image_url = f'https://{S3_BUCKET_NAME}.s3.{AWS_S3_REGION_NAME}.amazonaws.com/images/{image_key}.jpg'
buffered.close()
# 저장 후 이미지의 URL을 응답한다.
return jsonify({'image_url': image_url}), 200
except Exception as e:
print("Exception occurred in process_request:", e)
return jsonify({"error": str(e)}), 500
Comment
프로세스는 위와 같다. 모델을 로드하여 작업을 수행하는 함수를 하나 만들고, API요청에 대한 작업을 app.py에 작성한다.
#모델 로드
#gpt model
print('gpt_load')
gpt_device = torch.device("cuda:0")
gpt_model = GPT2LMHeadModel.from_pretrained('모델').to(gpt_device)
gpt_tokenizer = PreTrainedTokenizerFast.from_pretrained('모델')
U_TKN = '<usr>'
S_TKN = '<sys>'
BOS = '</s>'
EOS = '</s>'
MASK = '<unused0>'
SENT = '<unused1>'
PAD = '<pad>'
def get_comment(input_text): #koGPT2 모델을 활용하여 입력된 질문에 대한 대답을 생성하는 함수
q = input_text
a = ""
sent = ""
while True:
input_ids = torch.LongTensor(gpt_tokenizer.encode(U_TKN + q + SENT + sent + S_TKN + a)).unsqueeze(dim=0).to(gpt_device)
pred = gpt_model(input_ids)
pred = pred.logits
gen = gpt_tokenizer.convert_ids_to_tokens(torch.argmax(pred, dim=-1).squeeze().tolist())[-1]
if gen == EOS:
break
a += gen.replace("▁", " ")
return a
# app.py
@app.route('/get_comment', methods=['POST'])
async def process_comment_request():
try:
request_data = request.json
content = request_data.get('content')
comment = get_comment(content)
return jsonify({'comment': comment}), 200
except Exception as e:
print("Exception occurred in process_request:", e)
return jsonify({"error": str(e)}), 500
Emotion
def get_emotion_label(content):
emotion_pred = inference(content)
max_value = max(emotion_pred)
max_index = emotion_pred .index(max_value)
return emotion_pred, emotion_arr[max_index]
def get_music(content):
emotion_pred, max_index=get_emotion_label(content)
df_user_sentiment = pd.DataFrame([emotion_pred],columns=emotion_arr)
user_emotion_str = df_user_sentiment.apply(lambda x: ' '.join(map(str, x)), axis=1)
music_emotion_str = final_emotion[emotion_arr].apply(lambda x: ' '.join(map(str, x)), axis=1)
tfidf = TfidfVectorizer()
user_tfidf_matrix = tfidf.fit_transform(user_emotion_str)
music_tfidf_matrix = tfidf.transform(music_emotion_str)
cosine_sim = cosine_similarity(user_tfidf_matrix, music_tfidf_matrix)
most_similar_song_index = cosine_sim.argmax()
most_similar_song_info = final_emotion.iloc[most_similar_song_index]
num_additional_recommendations = 4
similar_songs_indices = cosine_sim.argsort()[0][-num_additional_recommendations-1:-1][::-1]
similar_songs_info = final_emotion.iloc[similar_songs_indices]
return most_similar_song_info, similar_songs_info
@app.route('/get_sentiment', methods=['POST'])
async def process_sentiment_request():
try:
request_data = request.json
content = request_data.get('content')
_, emotion_label = get_emotion_label(content)
return jsonify({'emotion_label': emotion_label}), 200
except Exception as e:
print("Exception occurred in process_request:", e)
return jsonify({"error": str(e)}), 500
@app.route('/get_music', methods=['POST'])
async def process_music_request():
try:
request_data = request.json
content = request_data.get('content')
most_similar_song_info, similar_songs_info = get_music(content)
response_data = {
'most_similar_song': {
'title': most_similar_song_info[0],
'artist': most_similar_song_info[1],
'genre': most_similar_song_info[2]
},
'similar_songs': [{
'title': song_info[0],
'artist': song_info[1],
'genre': song_info[2]
} for song_info in similar_songs_info.values]
}
return jsonify(response_data), 200
except Exception as e:
print("Exception occurred in process_request:", e)
return jsonify({"error": str(e)}), 500
'[Project] Threepark' 카테고리의 다른 글
[Threepark] 3. 백엔드 구현 - (8) DRF 개발 | API문서 Swagger (0) | 2024.05.25 |
---|---|
[Threepark] 백엔드 서버와 핵심 기능 구현 (0) | 2024.05.21 |
[Capstone Design] 3. 백엔드 구현 - (7) DRF 개발 | VIEW (0) | 2024.05.21 |
[Capstone Design] 3. 백엔드 구현 - (6) DRF 개발 | PERMISSION (0) | 2024.05.21 |
[Capstone Design] 3. 백엔드 구현 - (5) DRF 개발 | SERIALIZER (0) | 2024.05.21 |