+3

[Practical Series] Late Chunking - Improvements in RAG chunking

Mở đầu.

Xin chào mọi người đến với bài tiếp theo của practical Series. Nhân tiện bài trước đó của mình có điểm qua một số phương pháp Chunking (ở đây), tập trung hơn một chút về agentic chunking. Thì đến bài này chúng ta sẽ điểm qua 1 kỹ thuật khác là Late Chunking giúp cải thiện trong phase Chunking và Retrieval của RAG.

Important : Trước khi đọc bài viết, mình xin được nhắc lại, đây không phải là một Paper Explain. Mục đích của bài viết chỉ là tìm hiểu kiến thức mới và diễn giải nó theo cách hiểu của mình, mình sẽ chỉ điểm qua những thông tin mình cho là quan trọng nên sẽ có nhiều đoạn bị tối nghĩa, không đủ clear. Khuyến khích mọi người đọc từ paper và các nguồn ở reference để hiểu rõ thêm.

Late Chunking.

Đặt vấn đề.

Let's goo. Oke thì paper ở đây. (Trước khi đi sâu hơn thì hi vọng mọi người đá qua AbstractMethod của nó trước để có cái nhìn tổng quan).

Đầu tiên thì ở bài trước đó mình có mô tả những cách chunking thường được sử dụng. Ví dụ như hình sau : image.png Chúng ta sẽ thường chia nhỏ văn bản thành các đoạn text nhỏ rồi embed các đoạn text (hay chunk) thành embeddings . Tuy nhiên, các chunk này được embed một cách độc lập điều này dẫn đến việc các chunk embeddings bị mất đi contextual information. Đá qua một ví dụ nhỏ ở bài trước để hiểu việc mất thông tin này 🫠

Python là một ngôn ngữ lập trình phổ biến. Nó được sử dụng rộng rãi trong khoa học dữ liệu, trí tuệ nhân tạo, và phát triển web. Guido van Rossum là người tạo ra nó vào năm 1991.

Nếu chiến thuật chunk ở đây chúng ta sử dụng là chunk theo từng câu sẽ tách thành 3 câu :

  • Python là một ngôn ngữ lập trình phổ biến.
  • Nó được sử dụng rộng rãi trong khoa học dữ liệu, trí tuệ nhân tạo, và phát triển web.
  • Guido van Rossum là người tạo ra nó vào năm 1991.

ở câu 2, 3 ở đây là Python , tuy nhiên do embed một cách độc lập nên không thể biết nó đại diện cho điều gì ➜ hiện tượng mất đi contextual information.

Để giải quyết việc này thì author có đề cập tới một phương pháp mới là Late Chunking. 🤨🤨

Cách hoạt động.

Oke, nào lan man thế đủ rồi, bắt đầu vào chủ đề chính nhé 😊.

Author có đề xuất một phương pháp tiếp cận mới mang tên Late chunking . Y như cái tên nó là việc chia chunk muộn, thay vì chia nhỏ text rồi embed độc lập thì tác giải tận dụng long context embedding model để embed toàn bộ văn bản rồi mới đưa ra chunk. Điều này kỳ vọng sẽ giảm bớt việc mất mát thông tin của văn bản.

Nhìn hình ảnh để hình dung dễ hơn nhé. image.png

Oke, nó ý tưởng cũng chỉ có thế thôi, để hiểu rõ hơn một chút nữa, chúng ta xem qua lại một chút về Sentence Transformer và Thuật toán của Late Chunking nhé.

Sentence Transformer

Nhắc lại một chút về Sentence Transformer nhé 😗 image.png Nhìn qua hình ảnh, chúng ta cũng hình dung được phần nào cách hoạt động của nó rồi nhỉ. 😉

  • Đầu tiên thì sử dụng BERT để embedding từng token (Token embeddings này đã có các thông tin từ các Token embeddings khác - mọi người có thể xem lại phase encode của Transformer để nhớ lại).
  • Sau đó sử dụng mean pool tổng hợp các token embeddings thành một sentence embedidng duy nhất.

Oke, đơn giản thì nó chỉ có thế thôi. Để hiểu rõ hơn nữa về việc nó embed hay mean pool kiểu gì thì ở ref mình có kèm link nha 😄.

Thuật toán.

Late Chunking.

Oke thuật toán của nó thì sẽ như thế này :

image.png

Nhìn mấy cái mô tả này hơi nhức đầu nhỉ🥲. Đơn giản nó một chút xuống nào 🤔

  • (c1, . . . , cn) ← Chunker(T, S) : Đơn giản chỉ là áp dụng chiến thuật Chunk để đưa ra các chunk trước.
  • (ϑ1, . . . , ϑm) ← Model(τ1, . . . , τm) : Sau đó thì sử dụng Long Context Embedding Model để embed các token lại.
  • Còn 2 đoạn for còn lại chỉ đơn giản là : Gom các token embedding theo từng chunk ở bước 1 và dùng mean pool đối với List chunk đã được gom đó.

Hừm, điều này có thể sử dụng được khi mà Context Length của embedding model đủ lớn (bao trọn được toàn bộ văn bản).

Long Late Chunking.

Tuy nhiên sẽ có những trường hợp mà văn bản quá dài, context length của model không thể chứa đủ, thì làm thế nào. Ngoài ra còn có memory khi mà số lượng token tăng lên rất nhiều nữa.🥲 Điều này làm cho việc embed toàn bộ văn bản là bất khả thi. 🥲

Ở đây Author có đề cập tới một thuật toán nữa giải quyết vấn đề này Long Late Chunking image.png

Đơn giản hoá nó một chút 😂. Oke, chúng ta sẽ chỉ cần quan tâm đến dòng 14 to 16 mà thuật toán đề cập. Đơn giản chỉ là lợi dụng Overlap Token w để giảm thiểu việc mất mát thông tin giữa các macro chunk (là chunk to - chứa các chunk nhỏ và max length là l max ) với nhau.

Hừm, ở đây thì author đang lợi dụng Overlap token w để làm cầu nối giữa các macro chunk. Vậy thì có một số nảy sinh với overlap token w này.

Liệu Overlap Token W có chứa đủ thông tin ngữ cảnh quan trọng để làm cầu nối không ?

Đúng thì ở đây không có gì đảm bảo 100% rằng w luôn luôn chứa ngữ cảnh quan trọng tuy nhiên thì ở đây việc sử dụng overlap token w đang dựa trên một giải định :

  • Trong nhiều văn bản, ngữ cảnh thường liền mạch (ví dụ, một đoạn văn, một câu dài, hoặc mối liên kết logic giữa các câu). Token w được chọn nằm ở cuối một khối và đầu khối tiếp theo, tức là ở vị trí có khả năng mang ngữ cảnh cần thiết.
  • ω không chỉ đơn thuần là một phần của khối trước, mà còn cung cấp thông tin cho khối tiếp theo. Do đó, ω có vai trò như một cầu nối để giảm thiểu mất mát ngữ cảnh.

Vậy chọn overlap token w như nào thì tối ưu ?

  • Rõ ràng, Nếu văn bản bị phân chia ở vị trí ngẫu nhiên hoặc không hợp lý (ví dụ, giữa một câu dài), thì 𝜔 có thể không chứa đủ ngữ cảnh quan trọng.
  • Cần một chiến lược chia chunk hợp lý (ví dụ, dựa trên dấu câu hoặc cấu trúc ngữ pháp) để tăng khả năng 𝜔 nắm giữ ngữ cảnh hữu ích.

Thực hành.

Okeee, đã quá nhiều lý thuyết ở trên rồi, chúng ta triển khai một tẹo code để đỡ chán nhé 🫠🫠.

Setup một chút thì ở đây mình sẽ dùng mô hình của author trên huggingface nhé

Chúng ta sẽ thử chunk theo từng câu và tính cosine thử nhé.

from sentence_transformers import SentenceTransformer
# import numpy as np
from sklearn.metrics.pairwise import cosine_similarity
# cos_sim = lambda x, y: np.dot(x, y) / (np.linalg.norm(x) * np.linalg.norm(y))
import torch
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
embedding_model = SentenceTransformer('jinaai/jina-embeddings-v2-base-en', trust_remote_code=True).to(device)
text = "Berlin is the capital and largest city of Germany, both by area and by population. Its more than 3,85 million inhabitants make it the European Union's most populous city, as measured by population within city limits. The city is also one of the states of Germany, and is the third smallest state in the country in terms of area."

sentences = [str(chunk + '.').strip() for chunk in text.split('.') if chunk]

embedding_traditional = embedding_model.encode(sentences)
query = "Berlin"
query_embedding = embedding_model.encode(query)

for sentence, embedding in zip(sentences, embedding_traditional):
  print(f"Sentence: {sentence}")
  print(f"Cosine similarity: {cosine_similarity([embedding], [query_embedding])}")

Kết quả là :

Sentence: Berlin is the capital and largest city of Germany, both by area and by population.
Cosine similarity: [[0.8486219]]
Sentence: Its more than 3,85 million inhabitants make it the European Union's most populous city, as measured by population within city limits.
Cosine similarity: [[0.7001531]]
Sentence: The city is also one of the states of Germany, and is the third smallest state in the country in terms of area.
Cosine similarity: [[0.7534554]]

Có thể nhìn thấy rằng câu 1 và câu 3 có các thông tin bổ sung cho Berlin như là nước Đức, thành phố, diện tích, ... thì cosine similarity nó lớn hơn. Còn câu 2, nếu đọc cả đoạn văn thì rõ ràng cũng đang nói đến Berlin, nhưng điểm lại không cao, do khi embed độc lập thì nó không có chút thông tin nào liên quan đến Berlin hay Đức , .... cả thì đương nhiên score của nó sẽ thấp hơn.

Oke thế thì implement thử Late Chunking và xem có gì thay đổi nhé 😃.

from sentence_transformers import SentenceTransformer
import torch
from sklearn.metrics.pairwise import cosine_similarity

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
embedding_model = SentenceTransformer('jinaai/jina-embeddings-v2-base-en', trust_remote_code=True).to(device)
transformer_layer = embedding_model._first_module()
pooling_layer = embedding_model._last_module()

text = "Berlin is the capital and largest city of Germany, both by area and by population. Its more than 3,85 million inhabitants make it the European Union's most populous city, as measured by population within city limits. The city is also one of the states of Germany, and is the third smallest state in the country in terms of area."
list_chunks = [chunk + '.' for chunk in text.split('.') if chunk]

# Step 1: Tokenize the entire text
tokens = embedding_model.tokenizer(text, return_tensors='pt', padding=False, truncation=False).to(device)


# Step 2: Get token embeddings
with torch.no_grad():
    outputs = transformer_layer({'input_ids': tokens['input_ids'], 'attention_mask': tokens['attention_mask']})
    token_embeddings = outputs['token_embeddings']

# Step 3: Use pooling layer for chunks
sentence_embeddings = []
current_token_idx = 1  # skip CLS token

for chunk in list_chunks:
    chunk_tokens = embedding_model.tokenizer(chunk, return_tensors='pt', padding=True, truncation=True).to(device)
    chunk_length = chunk_tokens['input_ids'].shape[1] - 2  # Remove CLS and SEP tokens
    
    chunk_embeddings = token_embeddings[:, current_token_idx:current_token_idx+chunk_length]
    chunk_attention_mask = chunk_tokens['attention_mask'][:, 1: -1] # Remove CLS and SEP tokens
    
    sentence_embedding = torch.mean(chunk_embeddings, dim=1)  # Mean pooling
    sentence_embedding = sentence_embedding.squeeze(0)  # Remove batch dimension
    
    # # Use pooling layer
    # features = {}
    # features['token_embeddings'] = chunk_embeddings # Add batch dimension
    # features['attention_mask'] = chunk_attention_mask
    # features['sentence_embedding'] = torch.mean(chunk_embeddings, dim=1)  # Mean pooling
    # sentence_embedding = pooling_layer(features)['sentence_embedding']
    # sentence_embedding = sentence_embedding.squeeze(0)  # Remove batch dimension

    sentence_embeddings.append(sentence_embedding)
    current_token_idx += chunk_length

sentence_embeddings = torch.stack(sentence_embeddings)

# Step 4: Process query using pooling layer
query = "Berlin"
query_tokens = embedding_model.tokenizer(query, return_tensors='pt', padding=True, truncation=True).to(device)

with torch.no_grad():
    query_outputs = transformer_layer({'input_ids': query_tokens['input_ids'], 
                                     'attention_mask': query_tokens['attention_mask']})
    query_embedding = query_outputs['token_embeddings']

# use pooling layer
query_embedding = torch.mean(query_embedding, dim=1)  # Mean pooling
query_embedding = query_embedding.squeeze(0)  # Remove batch dimension


for sentence, embedding in zip(list_chunks, sentence_embeddings):
    print(f"Sentence: {sentence}")
    print(f"Cosine similarity: {cosine_similarity(embedding.cpu().numpy().reshape(1, -1), query_embedding.cpu().numpy().reshape(1, -1))}")

Kết quả :

Sentence: Berlin is the capital and largest city of Germany, both by area and by population.
Cosine similarity: [[0.85716665]]
Sentence:  Its more than 3,85 million inhabitants make it the European Union's most populous city, as measured by population within city limits.
Cosine similarity: [[0.8255315]]
Sentence:  The city is also one of the states of Germany, and is the third smallest state in the country in terms of area.
Cosine similarity: [[0.8503671]]

Hừm, nó có tăng lên một chút ở các câu ở dưới. Chứng tỏ việc embed toàn bộ văn bản cũng đã cải thiện được một phần chunk embeddings.

Kết luận.

Oke, Kết luận một chút nào. Dường như Late Chunking đang cố gắng cải thiện embedding bằng cách embed cả document vào rồi mới chia token vô chunk. Tuy nhiên nó có một vài nhược điểm:

  • Phụ thuộc hoàn toàn vào long context embedding model.
  • Không dùng được với mọi model embedding.
  • Mình có test thử thêm với model embedding cho tiếng việt như dangvantuan/vietnamese-embedding thì consine khi embedding truyền thống câu 1 sẽ là 0.78, câu 2 là 0.078, câu 3 là 0.42. Áp dụng vào late chunking thì cosine câu 1, 2 , 3 sẽ loanh quanh ở 0.59 -> 0.65. Mình cảm giác không đẩy được cao lên nữa. Có thể do chất lượng model cũng như cách triển khai basic code của mình chưa đủ tốt, chưa bao gồm đủ các case.

Đồng ý rằng đây là ý tưởng hay, tuy nhiên cũng phải tùy vào từng bài toán và tính ổn định thì mới chọn những chiến thuật chunking khác nhau. Nếu mà được và budget đủ dùng thì mình vẫn thích dùng propositional-retrieval hơn 🤣

10 phút đọc cũng đã đủ dài cho một bài view trước khi mọi người đọc ngẫm sâu hơn vào paper rồi, hẹn mọi người ở các bài viết sau. Nếu hay thì hãy upvote và bookmark cho mình với nào 😊

Reference.


All rights reserved

Viblo
Hãy đăng ký một tài khoản Viblo để nhận được nhiều bài viết thú vị hơn.
Đăng kí