docs should mention limitation of sbert #8


Closed

thiswillbeyourgithub opened this issue Sep 4, 2023 · 4 comments
Labels
documentation Improvements or additions to documentation

Comments


thiswillbeyourgithub commented Sep 4, 2023

Hi,

Time and time again I encounter people disappointed by the effectiveness of the sentence-transformers models. The usual reason is that the models have a very short "max sequence length" (256 for the default model) and everything after that is silently clipped.

Given that this happens silently, I think most people are not aware of it. And the multilingual models have an even shorter limit!

I've brought this up several times here, on langchain, and there too.

So I think it would be good to mention this in the README.md.
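
For anyone who wants to see the clipping in action, here is a quick sketch (the model choice is just an example; all-MiniLM-L6-v2 has a 256-token limit): the embedding of a long text comes out identical to that of a much shorter prefix, because everything past the limit is ignored.

import numpy as np
from sentence_transformers import SentenceTransformer

# quick check of silent truncation (model choice is just an example)
model = SentenceTransformer("all-MiniLM-L6-v2")
print(model.max_seq_length)  # 256

long_text = "some repeated words " * 500   # far beyond the limit
short_text = "some repeated words " * 100  # also beyond the limit, but much shorter
emb_long = model.encode(long_text)
emb_short = model.encode(short_text)
print(np.allclose(emb_long, emb_short))  # True: the extra text changed nothing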

And if anyone is down for writing a simple wrapper that does a rolling average/maxpooling/whatever-pooling of the input instead of clipping it, that would be awesome! Such a workaround can't possibly be worse than just clipping the input, right?

Cheers and llm is great!

(related to simonw/llm#220)

simonw added a commit that referenced this issue Sep 8, 2023

simonw commented Sep 8, 2023

Thanks, that's a good tip - I've added it to the "usage" section.

simonw closed this as completed Sep 8, 2023
simonw added a commit that referenced this issue Sep 8, 2023
simonw added the documentation label Sep 8, 2023

thiswillbeyourgithub commented Sep 24, 2023

I'm sharing my own 'rolling' sbert script to avoid clipping the sentences. It's seemingly functional but not very elegant (a class would be better, of course), but I just hope it helps someone:

import numpy as np
from sentence_transformers import SentenceTransformer


model = SentenceTransformer("all-mpnet-base-v2")

# sbert silently crops any token above the max_seq_length,
# so we embed overlapping windows and maxpool them. The normalization
# happens afterwards.
def encode_sentences(sentences):
    max_len = model.get_max_seq_length()

    if not isinstance(max_len, int):
        # the clip model has a different way to use the encoder
        # sources : https://github.com/UKPLab/sentence-transformers/issues/1269
        assert "clip" in str(model).lower(), f"sbert model with no 'max_seq_length' attribute and not clip: '{model}'"
        max_len = 77
        encode = model._first_module().processor.tokenizer.encode
    else:
        if hasattr(model.tokenizer, "encode"):
            # most models
            encode = model.tokenizer.encode
        else:
            # word embeddings models like glove
            encode = model.tokenizer.tokenize

    assert isinstance(max_len, int), "max_len must be int"
    n23 = (max_len * 2) // 3
    add_sent = []  # additional sentences
    add_sent_idx = []  # indices to keep track of sub sentences

    for i, s in enumerate(sentences):
        # skip if the sentence is short
        length = len(encode(s))
        if length <= max_len:
            continue

        # otherwise, split the sentence at regular interval
        # then do the embedding of each
        # and finally maxpool those sub embeddings together
        # the renormalization happens later in the code
        sub_sentences = []
        words = s.split(" ")
        avg_tkn = length / len(words)
        j = int(max_len / avg_tkn * 0.8)  # start at 80% of the estimated word count for max_len
        while len(encode(" ".join(words))) > max_len:

            # if reached max length, use that minus one word
            until_j = len(encode(" ".join(words[:j])))
            if until_j >= max_len:
                jjj = 1
                while len(encode(" ".join(words[:j-jjj]))) >= max_len:
                    jjj += 1
                sub_sentences.append(" ".join(words[:j-jjj]))

                # remove first word until 1/3 of the max_token was removed
                # this way we have a rolling window
                jj = int((max_len // 3) / avg_tkn * 0.8)
                while len(encode(" ".join(words[jj:j-jjj]))) > n23:
                    jj += 1
                words = words[jj:]

                j = int(max_len / avg_tkn * 0.8)
            else:
                diff = abs(max_len - until_j)
                if diff > 10:
                    j += max(1, int(10 / avg_tkn))
                else:
                    j += 1

        sub_sentences.append(" ".join(words))

        sentences[i] = " "  # discard this sentence as we will keep only
        # the sub sentences maxpooled

        # remove empty text just in case
        while "" in sub_sentences:
            sub_sentences.remove("")
        assert sum([len(encode(ss)) > max_len for ss in sub_sentences]) == 0, f"error when splitting long sentences: {sub_sentences}"
        add_sent.extend(sub_sentences)
        add_sent_idx.extend([i] * len(sub_sentences))

    if add_sent:
        sent_check = [
                len(encode(s)) > max_len
                for s in sentences
                ]
        addsent_check = [
                len(encode(s)) > max_len
                for s in add_sent
                ]
        assert sum(sent_check + addsent_check) == 0, (
            f"The rolling average failed apparently:\n{sent_check}\n{addsent_check}")
    vectors = model.encode(
            sentences=sentences + add_sent,
            show_progress_bar=True,
            output_value="sentence_embedding",
            convert_to_numpy=True,
            normalize_embeddings=False,
            )

    if add_sent:
        # at the position of the original sentence (not split)
        # add the vectors of the corresponding sub_sentence
        # then return only the 'maxpooled' section
        assert len(add_sent) == len(add_sent_idx), (
            "Invalid add_sent length")
        offset = len(sentences)
        for sid in list(set(add_sent_idx)):
            id_range = [i for i, j in enumerate(add_sent_idx) if j == sid]
            add_sent_vec = vectors[
                    offset + min(id_range): offset + max(id_range) + 1, :]
            vectors[sid] = np.amax(add_sent_vec, axis=0)
        return vectors[:offset]
    else:
        return vectors


edit: fixed the code :/
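
For reference, a hypothetical usage example (the texts and expected shape are illustrative; all-mpnet-base-v2 produces 768-dimensional vectors):

texts = [
    "a short sentence that fits within the limit",
    "a very long document " + "with many repeated words " * 500,
]
embeddings = encode_sentences(texts)
print(embeddings.shape)  # (2, 768): one un-normalized vector per input text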

@Jakobhenningjensen

@thiswillbeyourgithub

I like the idea of that "rolling window".

Is it the way to do it, or is it just an alternative to clipping, i.e. how well does it work?

@thiswillbeyourgithub (Author)

Keep in mind that my implementation is pretty naive and can certainly be vastly optimized, but the idea is there. An implementation for langchain can be found here. A non-langchain implementation in code I use regularly can be found here.

I don't know if I found the best way by chance, but probably not. In my example I did maxpooling, but I could have done meanpooling instead. That also raises the question of L1 vs L2 if doing a normalization. One could also think about an exponential decay of the importance of each new token of text, etc. (see the sketch below).
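
To illustrate, here is a minimal sketch of those pooling choices (pool_windows and window_vectors are made-up names, not from any library):

import numpy as np

def pool_windows(window_vectors, method="max"):
    # window_vectors: array of shape (n_windows, dim), one embedding per window
    if method == "max":
        pooled = np.amax(window_vectors, axis=0)
    elif method == "mean":
        pooled = np.mean(window_vectors, axis=0)
    else:
        raise ValueError(f"unknown pooling method: {method}")
    # L2-normalize so the pooled vector is comparable via cosine similarity
    return pooled / np.linalg.norm(pooled)

Maxpooling keeps the strongest signal from any window, while meanpooling smooths across windows; both get L2-normalized so downstream cosine similarity still works.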

In my experience: maxpooling or meanpooling seems to work fine.
In my opinion: any kind of rolling window is in theory vastly superior to silently cropping, and the difference between windowing methods is probably negligible compared to cropping.

More tests with proper metrics would be needed to find out which is best, and whether the enhancement is real or just a placebo that degrades results.
