SVN / public / code / lecture-11 / embeddings_test.py

Revision 2035
Date: 2026-01-20T16:05:16+01:00
Committer: hb1003
Download
import argparse
import torch

# Word pairs whose embeddings are compared, in the order they are printed.
WORD_PAIRS = (
    ("internet", "web"),
    ("internet", "beach"),
    ("beach", "surfing"),
)


def word_similarity(embeddings, word_a, word_b):
    """Return the cosine similarity between the embeddings of two words.

    Args:
        embeddings: Mapping from word to embedding tensor, as returned by
            ``torch.load`` on the embeddings file. Each value is assumed to
            hold a single vector (TODO confirm against the script that
            writes the file).
        word_a: First word to look up.
        word_b: Second word to look up.

    Returns:
        The cosine similarity as a Python float (nominally in [-1, 1]).

    Raises:
        KeyError: If either word is missing from ``embeddings``.
    """
    # Flatten so a vector stored as shape (d,) or (1, d) is compared as a
    # whole; with dim=0 a (1, d) tensor would otherwise produce d
    # per-component values instead of a single similarity.
    vec_a = embeddings[word_a].flatten()
    vec_b = embeddings[word_b].flatten()
    return torch.cosine_similarity(vec_a, vec_b, dim=0).item()


def main(argv=None):
    """Load an embeddings file and print similarities for WORD_PAIRS.

    Args:
        argv: Command-line arguments excluding the program name; ``None``
            lets argparse read ``sys.argv[1:]``.

    Raises:
        SystemExit: On bad arguments, or when a word from WORD_PAIRS is
            missing from the embeddings file.
    """
    argparser = argparse.ArgumentParser(
        description="Print cosine similarities between selected word embeddings."
    )
    argparser.add_argument("embeddings_file", help="Path to embeddings file")
    args = argparser.parse_args(argv)

    print("Loading embeddings from {}".format(args.embeddings_file))
    # NOTE(review): torch.load unpickles the file, so only load files from a
    # trusted source. map_location="cpu" lets tensors that were saved on a
    # GPU load on CPU-only machines as well.
    embeddings = torch.load(args.embeddings_file, map_location="cpu")

    # Check every word up front so a missing one yields a readable message
    # instead of a KeyError traceback. dict.fromkeys de-duplicates while
    # keeping the first-seen order.
    words = dict.fromkeys(word for pair in WORD_PAIRS for word in pair)
    missing = [word for word in words if word not in embeddings]
    if missing:
        raise SystemExit(
            "Words missing from {}: {}".format(
                args.embeddings_file, ", ".join(missing)
            )
        )

    print()
    for word_a, word_b in WORD_PAIRS:
        similarity = word_similarity(embeddings, word_a, word_b)
        print(
            'Cosine similarity between "{}" and "{}": {:.4f}'.format(
                word_a, word_b, similarity
            )
        )
    print()


if __name__ == "__main__":
    main()