[Project] Threepark

[Threepark] 3. 모델 서빙 - (1) Flask 작성

mingyung 2024. 5. 25. 19:08

이번에는 GPU서버에 올릴 Flask 를 작성한다.

Flask에서는 모델을 로드하고, 백엔드 API에 필요한 AI작업을 수행한다.

GPU비용 문제로, GPU가 꼭 필요한 경우만 이 서버에서 동작한다.

소개

Flask에서는 다음을 수행한다.

1. generate comment : 일기 작성 내용을 바탕으로 응원 문구를 생성한다.

2. generate image : 백엔드로부터 이미지 생성 프롬프트를 받고, 튜닝된 디퓨전 모델을 로드하여 이미지를 생성한다.

3. emotion classification : 일기 작성 내용을 바탕으로 감정 분석한다.

4. recommend music : 감정 분석 결과와 크롤링을 통해 수집한 음악 데이터의 감정분석 결과를 사용해 유사도를 통한 음악 추천, 결과 총 5가지 반환

 

 

Image

이미지를 생성하기 위한 프롬프트를 받아서 이미지 생성, S3버킷에 저장하고, URL정보를 응답으로 반환한다.

이미지를 생성하는 함수를 먼저 다른 파일에 작성하자.

#generate_image.py 에 모델을 로드하고, 이미지를 생성하는 함수를 만들었다.
from diffusers import StableDiffusionPipeline
import torch

# 모델 로드 및 디바이스 설정
model_path = '모델 위치'  # FineTuning Model Path
pipe = StableDiffusionPipeline.from_pretrained("runwayml/stable-diffusion-v1-5")
device = torch.device('cuda')
pipe.unet.load_attn_procs(model_path)
pipe.to(device)

# Negative prompt 설정
neg_prompt = '''FastNegativeV2,(bad-artist:1.0), (loli:1.2),
    (worst quality, low quality:1.4), (bad_prompt_version2:0.8),
    bad-hands-5,lowres, bad anatomy, bad hands, ((text)), (watermark),
    error, missing fingers, extra digit, fewer digits, cropped,
    worst quality, low quality, normal quality, ((username)), blurry,
    (extra limbs), bad-artist-anime, badhandv4, EasyNegative,
    ng_deepnegative_v1_75t, verybadimagenegative_v1.3, BadDream,
    (three hands:1.1),(three legs:1.1),(more than two hands:1.4),
    (more than two legs,:1.2),badhandv4,EasyNegative,ng_deepnegative_v1_75t,verybadimagenegative_v1.3,(worst quality, low quality:1.4),text,words,logo,watermark,
    '''

def get_image(prompt):
    image = pipe(prompt, negative_prompt=neg_prompt,num_inference_steps=30, guidance_scale=7.5).images[0]
    return image

대체로 15초정도 소요되고, 늦어지면 30초 이내로 생성된다. 하지만 사실 꽤 긴 시간이므로 이를 잘 고려해서 응답을 주고받는 백엔드 서버에서 응답 대기 시간을 늘려주자.

#app.py

# S3에 연결
s3 = boto3.client('s3',
                  aws_access_key_id=S3_ACESS_KEY,
                  aws_secret_access_key=S3_SECRET_ACCESS_KEY)

# 이미지 생성 요청에 대한 작업
@app.route('/get_image', methods=['POST'])
async def process_image_request():
    try:
        # 받은 요청의 데이터를 확인
        request_data = request.json
        prompt = request_data.get('prompt')
        
        #받은 프롬프트로 이미지를 생성한다.
        image = get_image(prompt)
        
        # 이미지를 저장한다.
        image_key = str(uuid.uuid4())
        buffered = io.BytesIO()
        image.save(buffered, format="JPEG")
        buffered.seek(0)
        s3.upload_fileobj(buffered, Bucket=S3_BUCKET_NAME, Key=f'images/{image_key}.jpg', ExtraArgs={'ContentType':'image/jpeg'})
        image_url = f'https://{S3_BUCKET_NAME}.s3.{AWS_S3_REGION_NAME}.amazonaws.com/images/{image_key}.jpg'
        buffered.close()
        
        # 저장 후 이미지의 URL을 응답한다.
        return jsonify({'image_url': image_url}), 200
    except Exception as e:
        print("Exception occurred in process_request:", e)
        return jsonify({"error": str(e)}), 500

 

Comment

프로세스는 위와 같다. 모델을 로드하여 작업을 수행하는 함수를 하나 만들고, API요청에 대한 작업을 app.py에 작성한다.

#모델 로드
#gpt model
print('gpt_load')
gpt_device = torch.device("cuda:0")
gpt_model = GPT2LMHeadModel.from_pretrained('모델').to(gpt_device)
gpt_tokenizer = PreTrainedTokenizerFast.from_pretrained('모델')
U_TKN = '<usr>'
S_TKN = '<sys>'
BOS = '</s>'
EOS = '</s>'
MASK = '<unused0>'
SENT = '<unused1>'
PAD = '<pad>'

def get_comment(input_text): #koGPT2 모델을 활용하여 입력된 질문에 대한 대답을 생성하는 함수
    q = input_text
    a = ""
    sent = ""
    while True:
        input_ids = torch.LongTensor(gpt_tokenizer.encode(U_TKN + q + SENT + sent + S_TKN + a)).unsqueeze(dim=0).to(gpt_device)
        pred = gpt_model(input_ids)
        pred = pred.logits
        gen = gpt_tokenizer.convert_ids_to_tokens(torch.argmax(pred, dim=-1).squeeze().tolist())[-1]
        if gen == EOS:
            break
        a += gen.replace("▁", " ")
    return a

 

# app.py
@app.route('/get_comment', methods=['POST'])
async def process_comment_request():
    try:
        request_data = request.json
        content = request_data.get('content')
        comment = get_comment(content)
        
        return jsonify({'comment': comment}), 200
    except Exception as e:
        print("Exception occurred in process_request:", e)
        return jsonify({"error": str(e)}), 500

 

Emotion

def get_emotion_label(content):
    emotion_pred = inference(content)
    max_value = max(emotion_pred)
    max_index = emotion_pred .index(max_value)
    return emotion_pred, emotion_arr[max_index]


def get_music(content):
    emotion_pred, max_index=get_emotion_label(content)
    df_user_sentiment = pd.DataFrame([emotion_pred],columns=emotion_arr)
    user_emotion_str = df_user_sentiment.apply(lambda x: ' '.join(map(str, x)), axis=1)
    music_emotion_str = final_emotion[emotion_arr].apply(lambda x: ' '.join(map(str, x)), axis=1)

    tfidf = TfidfVectorizer()
    user_tfidf_matrix = tfidf.fit_transform(user_emotion_str)
    music_tfidf_matrix = tfidf.transform(music_emotion_str)

    cosine_sim = cosine_similarity(user_tfidf_matrix, music_tfidf_matrix)
    
    most_similar_song_index = cosine_sim.argmax()
    most_similar_song_info = final_emotion.iloc[most_similar_song_index]

    num_additional_recommendations = 4
    similar_songs_indices = cosine_sim.argsort()[0][-num_additional_recommendations-1:-1][::-1]
    similar_songs_info = final_emotion.iloc[similar_songs_indices]

    return most_similar_song_info, similar_songs_info
@app.route('/get_sentiment', methods=['POST'])
async def process_sentiment_request():
    try:
        request_data = request.json
        content = request_data.get('content')
        _, emotion_label = get_emotion_label(content)
        
        return jsonify({'emotion_label': emotion_label}), 200
    except Exception as e:
        print("Exception occurred in process_request:", e)
        return jsonify({"error": str(e)}), 500

@app.route('/get_music', methods=['POST'])
async def process_music_request():
    try:
        request_data = request.json
        content = request_data.get('content')
        most_similar_song_info, similar_songs_info = get_music(content)
        
        response_data = {
            'most_similar_song': {
                'title': most_similar_song_info[0],
                'artist': most_similar_song_info[1],
                'genre': most_similar_song_info[2]
            },
            'similar_songs': [{
                'title': song_info[0],
                'artist': song_info[1],
                'genre': song_info[2]
            } for song_info in similar_songs_info.values]
        }
        return jsonify(response_data), 200
    except Exception as e:
        print("Exception occurred in process_request:", e)
        return jsonify({"error": str(e)}), 500