## Join our community
This repository is designed to support developers in building and training their own geotagging models. The geotagging model architecture provided here allows for customization and training. Additionally, we publish datasets that are well-suited for training in different geolocation detection scenarios.
The current models achieve a median error of 30 km (haversine distance) on the top 10% most relevant texts. Challenges aimed at improving the model's performance are open in the repository issues.
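For reference, the evaluation metric is the haversine (great-circle) distance between the predicted and the true coordinates. A minimal implementation:

```python
import math

def haversine_km(lat1, lon1, lat2, lon2):
    """Great-circle distance between two (lat, lon) points in kilometers."""
    R = 6371.0  # mean Earth radius in km
    phi1, phi2 = math.radians(lat1), math.radians(lat2)
    dphi = math.radians(lat2 - lat1)
    dlmb = math.radians(lon2 - lon1)
    a = math.sin(dphi / 2) ** 2 + math.cos(phi1) * math.cos(phi2) * math.sin(dlmb / 2) ** 2
    return 2 * R * math.asin(math.sqrt(a))
```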
<details>
<summary>Click to unfold the geotagging model architecture diagram.</summary>

```mermaid
%%{init:{'theme':'neutral'}}%%
flowchart TD
subgraph "ByT5 classifier"
a("Input text") --> b("Input_ids")
subgraph "byt5(T5EncoderModel)"
b("Input_ids") --> c("byt5.encoder.inp_input_ids")
subgraph "byt5.encoder(T5Stack)"
c("byt5.encoder.inp_input_ids") --> d("byt5.encoder.embed_tokens")
subgraph "byt5.encoder.embed_tokens (Embedding)"
d("byt5.encoder.embed_tokens") --> f("embedding")
e("byt5.encoder.embed_tokens.inp_weights") --> f("embedding") --> g("byt5.encoder.embed_tokens.out_0")
end
g("byt5.encoder.embed_tokens.out_0") --> h("byt5.encoder.dropout(Dropout)") --> i("byt5.encoder.block.0(T5Block)") --> j("byt5.encoder.block.1(T5Block)")
j("byt5.encoder.block.1(T5Block)") --> k("byt5.encoder.block.2(T5Block)<br><br>...<br><br>byt5.encoder.block.10(T5Block)") --> l("byt5.encoder.block.11(T5Block)") --> m("byt5.encoder.final_layer_norm(T5LayerNorm)")
m("byt5.encoder.final_layer_norm(T5LayerNorm)") --> n("byt5.encoder.dropout(Dropout)") --> o("byt5.encoder.out_0")
end
o("byt5.encoder.out_0") --> p("byt5.out_0")
end
p("byt5.out_0") --> q("(Linear)")
end
q("(Linear)") --> r("logits")
```
</details>
## Train your text-to-location model
Ensure that the following dependencies are installed in your environment to build and train your geotagging model:
```
transformers==4.29.1
tqdm==4.63.2
pandas==1.4.4
pytorch==1.7.1
```
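If you use pip, the pinned versions can be installed in one step (note that PyTorch is published on PyPI as `torch`, so `pytorch==1.7.1` corresponds to `torch==1.7.1` there):

```bash
pip install transformers==4.29.1 tqdm==4.63.2 pandas==1.4.4 torch==1.7.1
```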
To train your geotagging model using the ByT5-encoder-based approach, execute the following script:

```bash
python train_model.py --train_input_file <training_file> --test_input_file <test_file> --do_train true --do_test true --load_clustering .
```
Refer to the `train_model.py` file for a comprehensive list of available parameters.
For each input text, the trained model returns predicted coordinates together with a confidence score, for example:

```json
{
  "text": "These kittens need homes and are located in the Omaha area! They have their shots and are spayed/neutered. They need to be gone by JAN 1st! Please Retweet to help spread the word!",
  "geotagging": {
    "lat": 41.257160,
    "lon": -95.995102,
    "confidence": 0.9950085878372192
  }
}
```
The same prediction can also be represented as GeoJSON, here as a bounding polygon plus the predicted point:

```json
{
  "type": "FeatureCollection",
  "features": [
    {
      "type": "Feature",
      "id": 1,
      "properties": {
        "ID": 0
      },
      "geometry": {
        "type": "Polygon",
        "coordinates": [
          [
            [-96.296363, 41.112793],
            [-96.296363, 41.345177],
            [-95.786877, 41.345177],
            [-95.786877, 41.112793],
            [-96.296363, 41.112793]
          ]
        ]
      }
    },
    {
      "type": "Feature",
      "id": 2,
      "properties": {
        "ID": 0
      },
      "geometry": {
        "type": "Point",
        "coordinates": [-95.995102, 41.257160]
      }
    }
  ]
}
```
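A prediction can be wrapped into such a FeatureCollection programmatically. A minimal sketch (the half-degree box size is an arbitrary illustration, not the model's actual uncertainty region):

```python
import json

def prediction_to_geojson(lat, lon, box=0.25):
    """Wrap a predicted point and a surrounding bounding box as GeoJSON.

    `box` is an illustrative half-width in degrees, not a model output.
    """
    west, east = lon - box, lon + box
    south, north = lat - box, lat + box
    return {
        "type": "FeatureCollection",
        "features": [
            {
                "type": "Feature",
                "id": 1,
                "properties": {"ID": 0},
                "geometry": {
                    "type": "Polygon",
                    # GeoJSON uses (lon, lat) order; the ring is closed.
                    "coordinates": [[
                        [west, south], [west, north], [east, north],
                        [east, south], [west, south],
                    ]],
                },
            },
            {
                "type": "Feature",
                "id": 2,
                "properties": {"ID": 0},
                "geometry": {"type": "Point", "coordinates": [lon, lat]},
            },
        ],
    }

print(json.dumps(prediction_to_geojson(41.257160, -95.995102), indent=2))
```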
Our team has curated two comprehensive datasets for two distinct training approaches. These datasets are intended for use in training and validating the models. Share your training results in the repository issues.
## Regions dataset
The goal of the Regions approach is to cover the most populated regions around the world. The dataset:
- is an annotated corpus of 500k texts, each with its respective geocoordinates
- covers 123 regions
- includes 5,000 tweets per location
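The `--load_clustering` argument of `train_model.py` suggests that coordinates are discretized into location classes before classification. As an illustration only (the repository's actual clustering may differ, and scikit-learn is not among the pinned dependencies), here is one way to derive 123 clusters matching the regions:

```python
import pandas as pd
from sklearn.cluster import KMeans

df = pd.read_csv("train.csv")  # hypothetical path; columns: text, lat, lon

# Group coordinates into 123 clusters, one per region.
# Note: KMeans on raw lat/lon ignores the Earth's curvature; it is used
# here purely as a sketch.
kmeans = KMeans(n_clusters=123, random_state=0).fit(df[["lat", "lon"]])
df["cluster_id"] = kmeans.labels_             # classification target per text
cluster_centroids = kmeans.cluster_centers_   # (123, 2) array of (lat, lon)
```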
## Seasons dataset
The goal of the Seasons approach is to identify the correlation between the time and date of a post, its content, and its location. Time zone differences, as well as the seasonality of events, can be analyzed and used to predict the location. For example, snow is more likely to appear in the Northern Hemisphere, especially in December; rock concerts are more likely to happen in the evening and in bigger cities, so the time of a post about a concert can be used to identify the author's time zone and narrow down the list of potential locations. A minimal sketch of deriving such time features follows the dataset summary below.
The dataset:
- is a JSON file of >600,000 texts
- collected over the span of 12 months
- covers 15 different time zones
- focuses on 6 countries (Cuba, Iran, Russia, North Korea, Syria, Venezuela)
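To exploit the time and seasonality signal described above, posting timestamps can be turned into model features. A minimal sketch, assuming each record carries a UTC `timestamp` field (the actual field names in the dataset may differ):

```python
import pandas as pd

df = pd.read_json("seasons.json")  # hypothetical path
ts = pd.to_datetime(df["timestamp"], utc=True)  # assumed field name

# Hour of day hints at the author's time zone: a post written in the local
# evening appears at different UTC hours in different zones.
df["utc_hour"] = ts.dt.hour
# Month captures seasonality, e.g. snow-related posts cluster in the
# Northern Hemisphere's winter months.
df["month"] = ts.dt.month
```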
**Your custom data.** The geotagging model supports training and testing on custom datasets. Prepare your data in CSV format with the following columns: `text`, `lat`, and `lon`, as illustrated below.
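For illustration, a file in this layout might look as follows (the rows are invented examples for format only; the first reuses the prediction example above):

```csv
text,lat,lon
"These kittens need homes and are located in the Omaha area!",41.257160,-95.995102
"Snow is falling over Red Square this morning.",55.753930,37.620795
```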
The geotagging model incorporates confidence estimation to assess the reliability of predicted coordinates. The `confidence` field in the output (see the example above) ranges from `0.0` to `1.0`; higher values indicate greater confidence.
For detailed information on confidence estimation and on using the model for geotagging predictions, refer to the `inference.py` file, which provides an example script demonstrating the model architecture and the integration of confidence estimation.
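As a rough illustration of how such a confidence score can be produced (softmax probability of the top class, reusing `ByT5Classifier` and the `cluster_centroids` from the sketches above; the repository's `inference.py` may differ):

```python
import torch
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("google/byt5-small")  # assumed checkpoint
model = ByT5Classifier(num_classes=123)  # sketch class from the architecture section
model.eval()

text = "These kittens need homes and are located in the Omaha area!"
inputs = tokenizer(text, return_tensors="pt")

with torch.no_grad():
    logits = model(inputs["input_ids"], inputs["attention_mask"])
probs = torch.softmax(logits, dim=-1)
confidence, cluster_id = probs.max(dim=-1)

# Map the predicted class back to coordinates, e.g. via the cluster
# centroids from the clustering sketch in the Regions section.
lat, lon = cluster_centroids[cluster_id.item()]
print({"lat": float(lat), "lon": float(lon), "confidence": float(confidence)})
```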
Feel free to explore the code, adapt it to your specific requirements, and integrate it into your projects. If you have any questions or require assistance, please don't hesitate to reach out. We highly appreciate your feedback and are dedicated to continuously enhancing the geotagging models.